| """The intermediate representation.""" | |
| import dataclasses | |
| import inspect | |
| import warnings | |
| from typing import List, Optional, Union | |
| from sglang.global_config import global_config | |
| from sglang.lang.choices import ChoicesSamplingMethod | |
# Regex patterns for typed constrained generation.
# NOTE(review): presumably selected from a generation's ``dtype`` (int / float /
# bool / str); that mapping is not in this file — confirm at the use site.

# Optional sign, one or more digits, then any trailing spaces/newlines.
REGEX_INT = r"[-+]?[0-9]+[ \n]*"
# Optional sign, optional integer digits and decimal point, at least one digit,
# then any trailing spaces/newlines.
REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+[ \n]*"
# The literal word True or False.
REGEX_BOOL = r"(True|False)"
# A double-quoted run of word characters and whitespace only (no punctuation or
# escapes inside the quotes). The more permissive r"\".*\"" is avoided because
# it triggers bugs in the interegular package.
REGEX_STR = r"\"[\w\d\s]*\""  # bugs with regex r"\".*\"" in interegular pkg
@dataclasses.dataclass
class SglSamplingParams:
    """Sampling parameters for a generation call, plus per-backend adapters.

    Each ``to_*_kwargs`` method translates these fields into the keyword
    arguments for one backend, dropping fields that backend does not take.
    """

    max_new_tokens: int = 128
    min_new_tokens: int = 0
    n: int = 1
    # Empty tuples (not lists) as defaults: dataclasses reject mutable defaults.
    stop: Union[str, List[str]] = ()
    stop_token_ids: Optional[List[int]] = ()
    stop_regex: Optional[Union[str, List[str]]] = ()
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = -1  # -1 means disable
    min_p: float = 0.0
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    ignore_eos: bool = False
    return_logprob: Optional[bool] = None
    # Plain None (not a one-element tuple) so an unset value is falsy.
    logprob_start_len: Optional[int] = None
    top_logprobs_num: Optional[int] = None
    return_text_in_logprobs: Optional[bool] = None
    json_schema: Optional[str] = None

    # For constrained generation. ``dtype`` is not forwarded by any
    # ``to_*_kwargs`` method; ``regex`` is forwarded only by ``to_srt_kwargs``
    # (the other adapters warn and ignore it).
    dtype: Optional[str] = None
    regex: Optional[str] = None

    def clone(self):
        """Return a shallow copy with every field, including ``dtype`` and ``regex``.

        Container fields such as ``stop`` are shared with the original, not copied.
        """
        return dataclasses.replace(self)

    def to_openai_kwargs(self):
        """Kwargs for the OpenAI backend; warns if ``regex`` is set."""
        # OpenAI does not support top_k, so we drop it here
        if self.regex is not None:
            warnings.warn("Regular expression is not supported in the OpenAI backend.")
        return {
            "max_tokens": self.max_new_tokens,
            "max_completion_tokens": self.max_new_tokens,
            "n": self.n,
            "stop": self.stop or None,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
        }

    def to_vertexai_kwargs(self):
        """Kwargs for the VertexAI backend; a non-positive ``top_k`` maps to None."""
        if self.regex is not None:
            warnings.warn(
                "Regular expression is not supported in the VertexAI backend."
            )
        # stop_sequences takes a list of strings; wrap a bare string the same
        # way to_anthropic_kwargs does.
        stop = [self.stop] if isinstance(self.stop, str) else self.stop
        return {
            "candidate_count": 1,
            "max_output_tokens": self.max_new_tokens,
            "stop_sequences": stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k if self.top_k > 0 else None,
        }

    def to_anthropic_kwargs(self):
        """Kwargs for the Anthropic backend; ``stop`` is always sent as a sequence."""
        # Anthropic does not support frequency_penalty or presence_penalty, so we drop it here
        if self.regex is not None:
            warnings.warn(
                "Regular expression is not supported in the Anthropic backend."
            )
        return {
            "max_tokens": self.max_new_tokens,
            "stop_sequences": (
                self.stop if isinstance(self.stop, (list, tuple)) else [self.stop]
            ),
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
        }

    def to_litellm_kwargs(self):
        """Kwargs for the LiteLLM backend; warns if ``regex`` is set."""
        if self.regex is not None:
            warnings.warn("Regular expression is not supported in the LiteLLM backend.")
        return {
            "max_tokens": self.max_new_tokens,
            "stop": self.stop or None,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
        }

    def to_srt_kwargs(self):
        """Kwargs for the SRT runtime, including ``regex`` and ``json_schema``."""
        return {
            "max_new_tokens": self.max_new_tokens,
            "min_new_tokens": self.min_new_tokens,
            "n": self.n,
            "stop": self.stop,
            "stop_token_ids": self.stop_token_ids,
            "stop_regex": self.stop_regex,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "min_p": self.min_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "ignore_eos": self.ignore_eos,
            "regex": self.regex,
            "json_schema": self.json_schema,
        }
class SglFunction:
    """An SGL program: a Python function whose first parameter is the state ``s``.

    The remaining parameters of ``func`` are the program's arguments. The
    program can be run (``run``, ``run_batch``, ``__call__``), traced, cached,
    or compiled against a backend.
    """

    def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None):
        self.func = func
        self.num_api_spec_tokens = num_api_spec_tokens
        self.bind_arguments = bind_arguments or {}
        self.pin_prefix_rid = None

        # Parse arguments. The first positional parameter is reserved for the
        # program state; check emptiness first so a zero-argument function
        # fails this assertion instead of raising IndexError.
        argspec = inspect.getfullargspec(func)
        assert argspec.args and argspec.args[0] == "s", 'The first argument must be "s"'
        self.arg_names = argspec.args[1:]
        self.arg_defaults = argspec.defaults if argspec.defaults is not None else []

    def bind(self, **kwargs):
        """Return a new SglFunction with ``kwargs`` pre-bound as program arguments.

        Existing bindings are kept (new values override them), and
        ``num_api_spec_tokens`` is carried over to the returned function.
        """
        assert all(key in self.arg_names for key in kwargs)
        new_bind_dict = {**self.bind_arguments, **kwargs}
        return SglFunction(
            self.func,
            num_api_spec_tokens=self.num_api_spec_tokens,
            bind_arguments=new_bind_dict,
        )

    @staticmethod
    def _default_sampling_params(stop, stop_token_ids, stop_regex, **params):
        """Build program-level SglSamplingParams, turning None list args into fresh lists."""
        # avoid using [] as the default arg: https://nikos7am.com/posts/mutable-default-arguments/
        return SglSamplingParams(
            stop=[] if stop is None else stop,
            stop_token_ids=[] if stop_token_ids is None else stop_token_ids,
            stop_regex=[] if stop_regex is None else stop_regex,
            **params,
        )

    def run(
        self,
        *args,
        max_new_tokens: int = 128,
        n: int = 1,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        stop_regex: Optional[Union[str, List[str]]] = None,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        ignore_eos: bool = False,
        return_logprob: Optional[bool] = None,
        logprob_start_len: Optional[int] = None,
        top_logprobs_num: Optional[int] = None,
        return_text_in_logprobs: Optional[bool] = None,
        stream: bool = False,
        backend=None,
        use_thread: bool = True,
        **kwargs,
    ):
        """Run the program once via ``run_program``.

        Positional ``args`` and extra ``kwargs`` are forwarded to ``run_program``
        unchanged; the named sampling options become the program-level default
        sampling params. ``backend`` falls back to ``global_config.default_backend``.
        """
        from sglang.lang.interpreter import run_program

        default_sampling_para = self._default_sampling_params(
            stop=stop,
            stop_token_ids=stop_token_ids,
            stop_regex=stop_regex,
            max_new_tokens=max_new_tokens,
            n=n,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            ignore_eos=ignore_eos,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            return_text_in_logprobs=return_text_in_logprobs,
        )
        backend = backend or global_config.default_backend
        return run_program(
            self,
            backend,
            args,
            kwargs,
            default_sampling_para,
            stream,
            use_thread=use_thread,
        )

    def run_batch(
        self,
        batch_kwargs,
        *,
        max_new_tokens: int = 128,
        n: int = 1,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        stop_regex: Optional[Union[str, List[str]]] = None,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        ignore_eos: bool = False,
        return_logprob: Optional[bool] = None,
        logprob_start_len: Optional[int] = None,
        top_logprobs_num: Optional[int] = None,
        return_text_in_logprobs: Optional[bool] = None,
        backend=None,
        num_threads: Union[str, int] = "auto",
        progress_bar: bool = False,
        generator_style: bool = False,
    ):
        """Run the program over a batch of argument sets via ``run_program_batch``.

        ``batch_kwargs`` is a list/tuple of per-program argument dicts, or of
        positional-argument lists/tuples that are mapped onto ``arg_names``.
        Returns ``[]`` for an empty batch.

        Raises:
            ValueError: if a positional argument list has too few or too many
                values for the function signature.
        """
        from sglang.lang.interpreter import run_program_batch

        assert isinstance(batch_kwargs, (list, tuple))
        if len(batch_kwargs) == 0:
            return []

        if not isinstance(batch_kwargs[0], dict):
            num_programs = len(batch_kwargs)
            min_args = len(self.arg_names) - len(self.arg_defaults)
            # change the list of argument values to dict of arg_name -> arg_value
            batch_kwargs = [
                {self.arg_names[i]: v for i, v in enumerate(arg_values)}
                for arg_values in batch_kwargs
                if isinstance(arg_values, (list, tuple))
                and min_args <= len(arg_values) <= len(self.arg_names)
            ]
            # Entries that did not fit the signature were filtered out above;
            # a shorter result means at least one mismatched.
            if len(batch_kwargs) != num_programs:
                raise ValueError("Given arguments mismatch the SGL function signature")

        default_sampling_para = self._default_sampling_params(
            stop=stop,
            stop_token_ids=stop_token_ids,
            stop_regex=stop_regex,
            max_new_tokens=max_new_tokens,
            n=n,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            ignore_eos=ignore_eos,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            return_text_in_logprobs=return_text_in_logprobs,
        )
        backend = backend or global_config.default_backend
        return run_program_batch(
            self,
            backend,
            batch_kwargs,
            default_sampling_para,
            num_threads,
            progress_bar,
            generator_style=generator_style,
        )

    def trace(self, *, backend=None, **kwargs):
        """Trace the program with keyword ``kwargs`` via ``trace_program``."""
        from sglang.lang.tracer import trace_program

        backend = backend or global_config.default_backend
        return trace_program(self, kwargs, backend)

    def cache(self, backend=None):
        """Cache the program on ``backend`` via ``cache_program``."""
        from sglang.lang.interpreter import cache_program

        backend = backend or global_config.default_backend
        return cache_program(self, backend)

    def compile(self, *, backend=None):
        """Compile the program via ``compile_func`` (``backend`` is passed through as-is)."""
        from sglang.lang.compiler import compile_func

        return compile_func(self, backend)

    def __call__(self, *args, **kwargs):
        """Run the program, or trace it when called inside an active tracing scope."""
        from sglang.lang.tracer import TracingScope

        tracing_scope = TracingScope.get_current_scope()
        if tracing_scope is None:
            return self.run(*args, **kwargs)
        kwargs["backend"] = tracing_scope.tracer_state.backend
        # NOTE(review): trace() accepts keyword arguments only, so any
        # positional args here raise TypeError.
        return self.trace(*args, **kwargs)
class SglExpr:
    """Base class for every node in the SGL intermediate representation.

    Each instance takes a unique, increasing ``node_id`` from the class-wide
    counter ``node_ct``. ``+`` with another node or a plain string builds an
    ``SglExprList``.
    """

    node_ct = 0

    def __init__(self):
        self.node_id = SglExpr.node_ct
        SglExpr.node_ct += 1
        self.prev_node = None
        self.pid = None

    def __add__(self, other):
        rhs = SglConstantText(other) if isinstance(other, str) else other
        assert isinstance(rhs, SglExpr)
        return self.concatenate_ir(self, rhs)

    def __radd__(self, other):
        lhs = SglConstantText(other) if isinstance(other, str) else other
        assert isinstance(lhs, SglExpr), f"{lhs}"
        return self.concatenate_ir(lhs, self)

    def concatenate_ir(self, a, b):
        """Join ``a`` and ``b`` into a new flat SglExprList, splicing in existing lists."""
        head = a.expr_list if isinstance(a, SglExprList) else [a]
        tail = b.expr_list if isinstance(b, SglExprList) else [b]
        return SglExprList(head + tail)

    def print_graph_dfs(self):
        """Render the graph reachable from this node, one line per node, dependencies first."""
        lines = []
        seen = set()

        def visit(node):
            if node is None or node in seen:
                return
            seen.add(node)

            # Emit everything this node depends on before the node itself.
            if node.prev_node is not None:
                visit(node.prev_node)
            if isinstance(node, SglExprList):
                for child in node.expr_list:
                    visit(child)
            elif isinstance(node, SglVariable):
                visit(node.source)

            if isinstance(node, (SglFork, SglGetForkItem)):
                # Fork reprs already mention their parent node.
                lines.append(f"%{node.node_id} = {node}\n")
            elif node.prev_node is not None:
                lines.append(
                    f"%{node.node_id} = %{node.prev_node.node_id} + " + str(node) + "\n"
                )
            else:
                lines.append(f"%{node.node_id} = " + str(node) + "\n")

        visit(self)
        return "".join(lines)
class SglExprList(SglExpr):
    """An ordered sequence of IR nodes, as produced by concatenation."""

    def __init__(self, expr_list: List[SglExpr]):
        super().__init__()
        self.expr_list = expr_list

    def __repr__(self):
        return f"ExprList({self.expr_list!r})"
class SglArgument(SglExpr):
    """A named program argument wrapped as an IR node.

    Length, indexing, ``int()`` and ``bool()`` delegate to ``value``;
    formatting inside an f-string is rejected.
    """

    def __init__(self, name: str, value: str):
        super().__init__()
        self.name, self.value = name, value

    def __repr__(self):
        return f"Argument(name={self.name}, value={self.value!r})"

    def __len__(self):
        return len(self.value)

    def __getitem__(self, i):
        return self.value[i]

    # NOTE(review): Python requires __int__ to return an int and __bool__ to
    # return a bool. Both return ``value`` unconverted, so a value of another
    # type (e.g. None or a str) makes int(arg) / bool(arg) raise TypeError.
    # Confirm whether that is intended.
    def __int__(self):
        return self.value

    def __bool__(self):
        return self.value

    def __format__(self, *args):
        # Interpolating into an f-string is refused: the tracer cannot follow it.
        message = "Cannot put argument inside a f-string. This is not compatible with the tracer. "
        raise TypeError(message)
class SglImage(SglExpr):
    """IR node referencing an image by the ``path`` the caller supplied."""

    def __init__(self, path: str):
        # Initialize base-node state (node_id, prev_node, pid) like every other
        # SglExpr subclass; without it those attributes do not exist.
        super().__init__()
        self.path = path

    def __repr__(self) -> str:
        return f"SglImage({self.path})"
class SglVideo(SglExpr):
    """IR node referencing a video by ``path``, with a requested ``num_frames``."""

    def __init__(self, path: str, num_frames: int):
        # Initialize base-node state (node_id, prev_node, pid) like every other
        # SglExpr subclass; without it those attributes do not exist.
        super().__init__()
        self.path = path
        self.num_frames = num_frames

    def __repr__(self) -> str:
        return f"SglVideo({self.path}, {self.num_frames})"
class SglGen(SglExpr):
    """IR node for a model generation call, carrying its own sampling params."""

    def __init__(
        self,
        name: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
        min_new_tokens: Optional[int] = None,
        n: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        stop_regex: Optional[Union[str, List[str]]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        min_p: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        ignore_eos: Optional[bool] = None,
        return_logprob: Optional[bool] = None,
        logprob_start_len: Optional[int] = None,
        top_logprobs_num: Optional[int] = None,
        return_text_in_logprobs: Optional[bool] = None,
        dtype: Optional[type] = None,
        regex: Optional[str] = None,
        json_schema: Optional[str] = None,
    ):
        """Call the model to generate.

        Every sampling option is stored verbatim in ``self.sampling_params``;
        None means "not set for this call" (presumably filled from
        program-level defaults elsewhere — confirm in the interpreter). See
        docs/backend/sampling_params.md for the meaning of each argument.
        """
        super().__init__()
        self.name = name
        params = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            n=n,
            stop=stop,
            stop_token_ids=stop_token_ids,
            stop_regex=stop_regex,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            ignore_eos=ignore_eos,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            return_text_in_logprobs=return_text_in_logprobs,
            json_schema=json_schema,
            dtype=dtype,
            regex=regex,
        )
        self.sampling_params = params

    def __repr__(self):
        return "Gen('{}')".format(self.name)
class SglConstantText(SglExpr):
    """IR node holding a literal piece of text."""

    def __init__(self, value: str):
        super().__init__()
        self.value = value

    def __repr__(self):
        return f"Constant({self.value!r})"
class SglRoleBegin(SglExpr):
    """IR node opening a section for the chat ``role`` it names."""

    def __init__(self, role: str):
        super().__init__()
        self.role = role

    def __repr__(self):
        return "RoleBegin({})".format(self.role)
class SglRoleEnd(SglExpr):
    """IR node closing a section for the chat ``role`` it names."""

    def __init__(self, role: str):
        super().__init__()
        self.role = role

    def __repr__(self):
        return "RoleEnd({})".format(self.role)
class SglSelect(SglExpr):
    """IR node choosing among ``choices``, labelled ``name``.

    ``temperature`` and ``choices_method`` are stored for the executor; how
    they are applied is not defined in this class.
    """

    def __init__(
        self,
        name: str,
        choices: List[str],
        temperature: float,
        choices_method: ChoicesSamplingMethod,
    ):
        super().__init__()
        self.name, self.choices = name, choices
        self.temperature, self.choices_method = temperature, choices_method

    def __repr__(self):
        return (
            f"Select({self.name}, choices={self.choices}, "
            f"choices_method={self.choices_method})"
        )
class SglFork(SglExpr):
    """IR fork node; ``number`` is presumably the branch count — confirm in the executor.

    ``__repr__`` reads ``prev_node.node_id``, so it requires ``prev_node`` to be set.
    """

    def __init__(self, number: int, position_ids_offset=None):
        super().__init__()
        self.number = number
        self.position_ids_offset = position_ids_offset

    def __repr__(self):
        parent_id = self.prev_node.node_id
        return (
            "Fork(%{}, number={}, position_ids_offset={})".format(
                parent_id, self.number, self.position_ids_offset
            )
        )
class SglGetForkItem(SglExpr):
    """IR node selecting branch ``index`` of a fork; its repr requires ``prev_node``."""

    def __init__(self, index: int):
        super().__init__()
        self.index = index

    def __repr__(self):
        parent_id = self.prev_node.node_id
        return f"GetForkItem(%{parent_id}, index={self.index})"
class SglVariable(SglExpr):
    """IR node naming a value produced by another node, ``source``."""

    def __init__(self, name: str, source):
        super().__init__()
        self.name, self.source = name, source

    def __repr__(self):
        return "Variable('{}', source=%{})".format(self.name, self.source.node_id)
class SglVarScopeBegin(SglExpr):
    """IR node opening the variable scope called ``name``."""

    def __init__(self, name: str):
        super().__init__()
        self.name = name

    def __repr__(self):
        return "VarScopeBegin('{}')".format(self.name)
class SglVarScopeEnd(SglExpr):
    """IR node closing the variable scope called ``name``."""

    def __init__(self, name: str):
        super().__init__()
        self.name = name

    def __repr__(self):
        return "VarScopeEnd('{}')".format(self.name)
class SglConcateAndAppend(SglExpr):
    """IR node carrying a collection of ``states`` to concatenate and append."""

    def __init__(self, states):
        super().__init__()
        self.states = states

    def __repr__(self):
        return "ConcatenateAndAppend('{}')".format(self.states)
class SglCommitLazy(SglExpr):
    """Marker IR node with no payload; construction is inherited from SglExpr."""

    def __repr__(self):
        return "CommitLazy()"
class SglSeparateReasoning(SglExpr):
    """Wraps ``expr`` and derives a reasoning-output name from it.

    ``name`` becomes ``"<gen/select name>_reasoning_content"``. Nested
    ``SglExprList`` nodes are searched in order, and the last gen/select found
    wins. If none is found, ``name`` stays None.
    """

    def __init__(self, model_type: str, expr: SglExpr):
        super().__init__()
        self.model_type = model_type
        self.expr = expr
        self.name = None
        self._process_expr(expr)

    def process_name_for_reasoning(self, name):
        """Return ``name`` suffixed with ``_reasoning_content``; raise ValueError if it is empty."""
        if name:
            return f"{name}_reasoning_content"
        raise ValueError("name must be provided")

    def _process_expr(self, expr):
        if isinstance(expr, (SglGen, SglSelect)):
            self.name = self.process_name_for_reasoning(expr.name)
        elif isinstance(expr, SglExprList):
            for child in expr.expr_list:
                self._process_expr(child)

    def __repr__(self):
        return "SeparateReasoning(model_type={}, name={})".format(
            self.model_type, self.name
        )
Xet Storage Details
- Size:
- 20.2 kB
- Xet hash:
- 89f6d9da2a6a538aedf544051b7e7653eb1b9195095afd569ba6ce75e9ba9e0e
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.