| """The interpreter that executes SGL programs""" | |
| import asyncio | |
| import contextvars | |
| import copy | |
| import multiprocessing | |
| import queue | |
| import threading | |
| import uuid | |
| import warnings | |
| from concurrent.futures import ThreadPoolExecutor | |
| from contextlib import contextmanager | |
| from typing import Any, Callable, Dict, List, Optional | |
| import tqdm | |
| from sglang.global_config import global_config | |
| from sglang.lang.ir import ( | |
| SglCommitLazy, | |
| SglConcateAndAppend, | |
| SglConstantText, | |
| SglExpr, | |
| SglExprList, | |
| SglGen, | |
| SglImage, | |
| SglRoleBegin, | |
| SglRoleEnd, | |
| SglSelect, | |
| SglSeparateReasoning, | |
| SglVariable, | |
| SglVarScopeBegin, | |
| SglVarScopeEnd, | |
| SglVideo, | |
| ) | |
| from sglang.utils import ( | |
| encode_image_base64, | |
| encode_video_base64, | |
| get_exception_traceback, | |
| ) | |
def run_internal(state, program, func_args, func_kwargs, sync):
    """Run ``program.func`` on ``state`` and always end the stream executor.

    The value returned by the user function is stored on ``state.ret_value``.
    If the user function raises, the stream executor is still ended and the
    exception propagates unchanged; the optional sync and verbose print are
    skipped in that case.

    Args:
        state: Program state passed to ``program.func`` as first argument.
        program: Object whose ``func`` attribute is the callable to run.
        func_args: Positional arguments forwarded to ``program.func``.
        func_kwargs: Keyword arguments forwarded to ``program.func``.
        sync: If true, wait for the stream executor after ending it.
    """
    try:
        state.ret_value = program.func(state, *func_args, **func_kwargs)
    finally:
        # Runs on success and on failure; errors propagate naturally, so no
        # catch-and-reraise clause is needed.
        state.stream_executor.end()

    if sync:
        state.stream_executor.sync()

    if global_config.verbosity >= 2:
        print(state.text())
def run_program(
    program,
    backend,
    func_args,
    func_kwargs,
    default_sampling_para,
    stream,
    sync=False,
    use_thread=True,
):
    """Create a fresh ``ProgramState`` for ``program`` and run it.

    Args:
        program: Program object; its ``bind_arguments``, ``num_api_spec_tokens``
            and ``func`` attributes are used.
        backend: Backend, or an object exposing the backend as ``.endpoint``.
        func_args: Positional arguments for the program function.
        func_kwargs: Keyword arguments for the program function. The mapping
            itself is not modified.
        default_sampling_para: Default sampling parameters for the executor.
        stream: If true, the program runs in a new thread and the state is
            returned immediately; otherwise it runs in the calling thread.
        sync: Forwarded to ``run_internal``.
        use_thread: Forwarded to ``StreamExecutor``.

    Returns:
        The ``ProgramState`` associated with this run.
    """
    if hasattr(backend, "endpoint"):
        backend = backend.endpoint

    assert backend is not None, "Please specify a backend"

    # Merge into a new dict so the caller's mapping is not mutated; bound
    # arguments take precedence, as with the previous in-place ``update``.
    func_kwargs = {**func_kwargs, **program.bind_arguments}

    stream_executor = StreamExecutor(
        backend,
        func_kwargs,
        default_sampling_para,
        chat_template=None,
        stream=stream,
        num_api_spec_tokens=program.num_api_spec_tokens,
        use_thread=use_thread,
    )
    state = ProgramState(stream_executor)

    if stream:
        # Run in the background so the caller can consume the stream.
        t = threading.Thread(
            target=run_internal, args=(state, program, func_args, func_kwargs, sync)
        )
        t.start()
    else:
        run_internal(state, program, func_args, func_kwargs, sync)
    return state
def run_program_batch(
    program,
    backend,
    batch_arguments,
    default_sampling_para,
    num_threads,
    progress_bar,
    generator_style=False,
):
    """Run ``program`` once per entry of ``batch_arguments``.

    Args:
        program: Program to run.
        backend: Backend, or an object exposing the backend as ``.endpoint``.
        batch_arguments: Sequence of keyword-argument mappings, one per run.
        default_sampling_para: Default sampling parameters for every run.
        num_threads: Worker count, or ``"auto"`` to derive it from the CPU
            count. It is clamped to ``[1, len(batch_arguments)]``.
        progress_bar: If true, show a ``tqdm`` progress bar.
        generator_style: If true, return a generator that yields the states
            instead of a list.

    Returns:
        A list of program states in input order, or a generator of them when
        ``generator_style`` is true.
    """
    if hasattr(backend, "endpoint"):
        backend = backend.endpoint

    # Pre-cache the common prefix for a batch. The prefix is extracted by tracing the program.
    if len(batch_arguments) > 1 and global_config.enable_precache_with_tracing:
        cache_program(program, backend)

    # Run all programs
    if num_threads == "auto":
        num_threads = max(96, multiprocessing.cpu_count() * 16)
    # Never go below one worker: an empty batch would otherwise request a
    # zero-sized thread pool, which ThreadPoolExecutor rejects.
    num_threads = max(1, min(num_threads, len(batch_arguments)))

    if generator_style:
        return _run_program_batch_generator(
            program,
            backend,
            batch_arguments,
            default_sampling_para,
            num_threads,
            progress_bar,
        )

    if num_threads == 1:
        # Serial path; also covers the empty batch.
        iterator = tqdm.tqdm(batch_arguments) if progress_bar else batch_arguments
        return [
            run_program(
                program,
                backend,
                (),
                arguments,
                default_sampling_para,
                False,
                True,
            )
            for arguments in iterator
        ]

    pbar = tqdm.tqdm(total=len(batch_arguments)) if progress_bar else None
    try:
        with ThreadPoolExecutor(num_threads) as executor:
            futures = []
            for arguments in batch_arguments:
                future = executor.submit(
                    run_program,
                    program,
                    backend,
                    (),
                    arguments,
                    default_sampling_para,
                    False,
                    True,
                )
                if pbar is not None:
                    future.add_done_callback(lambda _: pbar.update())
                futures.append(future)

            rets = [f.result() for f in futures]
            rets[-1].sync()
    finally:
        # Close the bar even if one of the runs raised.
        if pbar is not None:
            pbar.close()

    return rets
| def _run_program_batch_generator( | |
| program, | |
| backend, | |
| batch_arguments, | |
| default_sampling_para, | |
| num_threads, | |
| progress_bar, | |
| ): | |
| """Helper function that yields results one by one using chunking to avoid overwhelming ThreadPoolExecutor.""" | |
| if num_threads == 1: | |
| iterator = tqdm.tqdm(batch_arguments) if progress_bar else batch_arguments | |
| for arguments in iterator: | |
| yield run_program( | |
| program, | |
| backend, | |
| (), | |
| arguments, | |
| default_sampling_para, | |
| False, | |
| True, | |
| ) | |
| else: | |
| pbar = tqdm.tqdm(total=len(batch_arguments)) if progress_bar else None | |
| # Process in chunks to avoid overwhelming ThreadPoolExecutor | |
| # Otherwise, ThreadPoolExecutor.submit will block after adding certain number of tasks | |
| # so we will never reach "yield" until all tasks are done | |
| chunk_size = 200 | |
| with ThreadPoolExecutor(num_threads) as executor: | |
| for chunk_start in range(0, len(batch_arguments), chunk_size): | |
| chunk_end = min(chunk_start + chunk_size, len(batch_arguments)) | |
| chunk_futures = [] | |
| # Submit chunk of tasks | |
| for i in range(chunk_start, chunk_end): | |
| future = executor.submit( | |
| run_program, | |
| program, | |
| backend, | |
| (), | |
| batch_arguments[i], | |
| default_sampling_para, | |
| False, | |
| True, | |
| ) | |
| if pbar: | |
| future.add_done_callback(lambda _: pbar.update()) | |
| chunk_futures.append(future) | |
| # Yield results from this chunk as they complete | |
| for future in chunk_futures: | |
| yield future.result() | |
| if pbar: | |
| pbar.close() | |
def cache_program(program, backend):
    """Trace ``program`` for its prefix and ask the backend to cache it.

    Nothing is cached when tracing yields a falsy prefix or one whose length
    is 64 or less.
    """
    from sglang.lang.tracer import extract_prefix_by_tracing

    traced_prefix = extract_prefix_by_tracing(program, backend)
    if not traced_prefix:
        return
    if len(traced_prefix) > 64:
        backend.cache_prefix(traced_prefix)
class StreamExecutor:
    """A stream executor that executes SGL expressions in a background thread."""

    def __init__(
        self,
        backend,
        arguments,
        default_sampling_para,
        chat_template,
        stream,
        num_api_spec_tokens=None,
        use_thread=True,
    ):
        """Initialize executor state and optionally start the worker thread.

        Args:
            backend: Backend used for generation and selection; it also
                supplies the chat template when ``chat_template`` is falsy.
            arguments: Keyword arguments of the running program.
            default_sampling_para: Base sampling parameters; per-call overrides
                are merged in ``_resolve_sampling_params``.
            chat_template: Chat template, or a falsy value to use
                ``backend.get_chat_template()``.
            stream: Whether streaming events are maintained.
            num_api_spec_tokens: When not ``None``, enables the API
                speculative-execution code paths.
            use_thread: If true, submitted expressions are queued for a worker
                thread; otherwise ``submit`` executes them inline.
        """
        from sglang.lang.backend.base_backend import BaseBackend

        self.sid = uuid.uuid4().hex
        self.backend: BaseBackend = backend
        self.arguments: Dict[str, Any] = arguments
        self.default_sampling_para = default_sampling_para
        self.stream = stream

        self.variables = {}  # Dict[name: str -> value: str]
        self.variable_event = {}  # Dict[name: str -> event: threading.Event]
        self.meta_info = {}  # Dict[name: str -> info: str]
        self.is_finished = False
        self.error_ = None

        # For completion
        self.text_ = ""  # The full text

        # For chat
        self.messages_ = []  # The messages in the OpenAI API format
        self.chat_template = chat_template or self.backend.get_chat_template()
        self.cur_role = None
        self.cur_role_begin_pos = None

        # For vision
        self.images_ = []
        self.cur_images = []

        # For fork/join
        self.fork_start_text_pos = None

        # For speculative execution
        self.num_api_spec_tokens = num_api_spec_tokens
        self.speculated_text = ""

        # Worker thread
        self.use_thread = use_thread
        if self.use_thread:
            self.queue = queue.Queue()

            def _run_worker_in_context():
                self._thread_worker_func()

            # Run the worker inside a copy of the current contextvars context.
            self.worker = threading.Thread(
                target=contextvars.copy_context().run, args=(_run_worker_in_context,)
            )
            self.worker.start()

        # For streaming
        if stream:
            self.stream_text_event = threading.Event()
            self.stream_var_event = {}
        else:
            self.stream_text_event = None
            self.stream_var_event = None

    def submit(self, expr: SglExpr):
        """Register variable events for ``expr`` and queue or execute it."""
        self._init_var_event(expr)

        if self.use_thread:
            self.queue.put(expr)
        else:
            self._execute(expr)

    def sync(self):
        """Block until every queued expression has been processed."""
        if self.use_thread:
            self.queue.join()

    def get_var(self, name):
        """Return variable ``name``, waiting for its event if one is registered."""
        if name in self.variable_event:
            self.variable_event[name].wait()
        return self.variables[name]

    def set_var(self, name, value):
        """Store ``value`` under ``name`` in the variable table."""
        self.variables[name] = value

    def get_meta_info(self, name, timeout=None):
        """Return the meta info recorded for ``name``, or ``None`` if absent.

        Raises:
            TimeoutError: If an event is registered for ``name`` and it is not
                set within ``timeout`` seconds.
        """
        if name in self.variable_event:
            got = self.variable_event[name].wait(timeout)
            if not got:
                raise TimeoutError(f"Timeout while waiting for event '{name}'")
        ret = self.meta_info.get(name, None)
        return ret

    def fork(
        self,
        size: int = 1,
        position_ids_offset: Optional[List[int]] = None,
    ):
        """Create ``size`` new executors that start from this executor's state.

        Variables, text, messages, role state and images are copied into each
        child. ``position_ids_offset`` is accepted but not used here.
        """
        if size > 1 and str(self.text_):
            self.submit(SglCommitLazy())

        # Pending work must finish before the state is snapshotted.
        self.sync()
        size = int(size)

        exes = [
            StreamExecutor(
                self.backend,
                self.arguments,
                self.default_sampling_para,
                self.chat_template,
                self.stream,
            )
            for _ in range(size)
        ]
        for exe in exes:
            exe.variables = dict(self.variables)
            exe.text_ = str(self.text_)
            exe.messages_ = list(self.messages_)
            exe.cur_role = self.cur_role
            exe.cur_role_begin_pos = self.cur_role_begin_pos
            exe.fork_start_text_pos = len(self.text_)
            exe.images_ = list(self.images_)

        # TODO(ying): handle API speculative execution

        return exes

    def text(self):
        """Wait for pending work and return the full text."""
        self.sync()
        return self.text_

    def messages(self):
        """Wait for pending work and return the chat messages."""
        self.sync()
        return self.messages_

    def error(self):
        """Wait for pending work and return the recorded error, if any."""
        self.sync()
        return self.error_

    def end(self):
        """Ask a live worker thread to stop and notify the backend."""
        if self.use_thread:
            if self.worker.is_alive():
                # ``None`` is the worker's shutdown sentinel.
                self.queue.put(None)
        self.backend.end_program(self)

    def _thread_worker_func(self):
        """Worker loop: execute queued expressions until the sentinel or an error."""
        error = None

        while True:
            expr = self.queue.get()
            if expr is None:
                self.queue.task_done()
                break

            try:
                self._execute(expr)
            except Exception as e:
                warnings.warn(f"Error in stream_executor: {get_exception_traceback()}")
                error = e
                break
            self.queue.task_done()
            if self.stream_text_event:
                self.stream_text_event.set()

        # Clean the queue and events
        if error is not None:
            try:
                # The first task_done() accounts for the expression that
                # failed; each later one for an item drained below.
                while True:
                    self.queue.task_done()
                    self.queue.get_nowait()
            except queue.Empty:
                pass
            # Release everyone waiting on a variable.
            for name in self.variable_event:
                self.variable_event[name].set()
            if self.stream_var_event:
                for name in self.stream_var_event:
                    self.stream_var_event[name].set()
            self.error_ = error

        if self.stream_text_event:
            self.stream_text_event.set()

        self.is_finished = True

    def _execute(self, other):
        """Dispatch one expression (or plain string) to its handler.

        Raises:
            ValueError: If the expression type has no handler.
        """
        if isinstance(other, str):
            other = SglConstantText(other)

        assert isinstance(other, SglExpr), f"{other}"

        if isinstance(other, SglConstantText):
            self._execute_fill(other.value)
        elif isinstance(other, SglGen):
            self._execute_gen(other)
        elif isinstance(other, SglSelect):
            self._execute_select(other)
        elif isinstance(other, SglExprList):
            for x in other.expr_list:
                self._execute(x)
        elif isinstance(other, SglRoleBegin):
            self._execute_role_begin(other)
        elif isinstance(other, SglRoleEnd):
            self._execute_role_end(other)
        elif isinstance(other, SglImage):
            self._execute_image(other)
        elif isinstance(other, SglVideo):
            self._execute_video(other)
        elif isinstance(other, SglVariable):
            self._execute_variable(other)
        elif isinstance(other, SglVarScopeBegin):
            self._execute_var_scope_begin(other)
        elif isinstance(other, SglVarScopeEnd):
            self._execute_var_scope_end(other)
        elif isinstance(other, SglCommitLazy):
            self._execute_commit_lazy_operations(other)
        elif isinstance(other, SglConcateAndAppend):
            if (
                global_config.enable_parallel_encoding
                and self.backend.support_concate_and_append
            ):
                self._execute_concatenate_and_append_kv_cache(other)
            else:
                self._execute_concatenate_and_append_text(other)
        elif isinstance(other, SglSeparateReasoning):
            self._execute_separate_reasoning(other)
        else:
            raise ValueError(f"Unknown type: {type(other)}")

    def _execute_fill(self, value: str, prefix=False):
        """Append ``value`` to the text, or hand it to the backend's spec fill.

        The text is handed to ``backend.spec_fill`` instead of being appended
        when the current role is ``"assistant"``, API speculative execution is
        enabled, the backend is a chat model and ``prefix`` is false.
        """
        value = str(value)

        if (
            self.cur_role == "assistant"
            and self.num_api_spec_tokens is not None
            and self.backend.is_chat_model
            and not prefix
        ):
            self.backend.spec_fill(value)
            return

        # Consume the matching part of the speculated text, or discard it.
        if self.speculated_text.startswith(value):
            self.speculated_text = self.speculated_text[len(value) :]
        else:
            self.speculated_text = ""

        self.text_ += value

    def _execute_image(self, expr: SglImage):
        """Encode the image at ``expr.path`` and append the image token."""
        path = expr.path

        base64_data = encode_image_base64(path)

        self.images_.append((path, base64_data))
        self.cur_images.append((path, base64_data))
        self.text_ += self.chat_template.image_token

    def _execute_video(self, expr: SglVideo):
        """Encode the video at ``expr.path`` and append the image token."""
        path = expr.path
        num_frames = expr.num_frames

        base64_data = encode_video_base64(path, num_frames)

        self.images_.append((path, base64_data))
        self.cur_images.append((path, base64_data))
        self.text_ += self.chat_template.image_token

    def _spec_gen(self, sampling_params):
        """Serve a generation from the speculated text, regenerating if needed.

        Returns:
            A ``(completion, meta_info)`` tuple; ``meta_info`` is an empty
            dict when no backend call was needed.
        """
        stop = sampling_params.stop
        max_new_tokens = sampling_params.max_new_tokens
        meta_info = {}

        def regen():
            nonlocal meta_info

            # Ask for at least ``num_api_spec_tokens`` and ignore stop strings
            # so later calls can be served from the same speculated text.
            sampling_params.max_new_tokens = max(
                sampling_params.max_new_tokens, self.num_api_spec_tokens
            )
            sampling_params.stop = None
            self.speculated_text, meta_info = self.backend.generate(
                self, sampling_params=sampling_params
            )

        def find_stop():
            # Position of the earliest stop string, or -1 when none occurs.
            if isinstance(stop, str):
                return self.speculated_text.find(stop)
            elif isinstance(stop, (tuple, list)):
                pos = -1
                for stop_str in stop:
                    stop_pos = self.speculated_text.find(stop_str)
                    if stop_pos != -1 and (pos == -1 or stop_pos < pos):
                        pos = stop_pos
                return pos
            else:
                raise Exception("Wrong type of stop in sampling parameters.")

        if stop is None:
            if len(self.speculated_text) < max_new_tokens:
                regen()
            comp = self.speculated_text[:max_new_tokens]
            self.speculated_text = self.speculated_text[max_new_tokens:]
        elif isinstance(stop, (str, list, tuple)):
            if self.speculated_text == "":
                regen()
            stop_pos = find_stop()
            if stop_pos == -1:
                # NOTE(review): this reads sampling_params.max_new_tokens,
                # which regen() may have raised, while the branch above uses
                # the original value — confirm this is intended.
                stop_pos = min(
                    sampling_params.max_new_tokens,
                    len(self.speculated_text),
                )
            comp = self.speculated_text[:stop_pos]
            self.speculated_text = self.speculated_text[stop_pos:]
        else:
            raise ValueError("Wrong type of stop in sampling parameters.")

        return comp, meta_info

    def _execute_gen(self, expr: SglGen):
        """Run a generation and record its text, variable and meta info."""
        sampling_params = self._resolve_sampling_params(expr.sampling_params)
        name = expr.name
        if not self.stream:
            if self.num_api_spec_tokens is None:
                comp, meta_info = self.backend.generate(
                    self,
                    sampling_params=sampling_params,
                )

            else:
                if self.backend.is_chat_model:
                    # Speculative execution on models with only chat interface.
                    # Store the calls into a temporary list.
                    # They will be lazily executed later.
                    comp, meta_info = self.backend.generate(
                        self,
                        sampling_params=sampling_params,
                        spec_var_name=name,
                    )
                    return

                else:  # Speculative execution on models with completion interface
                    comp, meta_info = self._spec_gen(sampling_params)
            if isinstance(comp, list):
                self.text_ += comp[0]
            else:
                assert isinstance(comp, str)
                self.text_ += comp

            self.variables[name] = comp
            self.meta_info[name] = meta_info
            self.variable_event[name].set()
        else:
            assert (
                self.num_api_spec_tokens is None
            ), "stream is not supported with api speculative execution"
            generator = self.backend.generate_stream(
                self, sampling_params=sampling_params
            )

            self.variables[name] = ""
            self.stream_var_event[name].set()

            for comp, meta_info in generator:
                self.text_ += comp
                self.variables[name] += comp
                self.meta_info[name] = meta_info
                self.stream_var_event[name].set()
                self.stream_text_event.set()

            self.variable_event[name].set()
            self.stream_var_event[name].set()

    def _execute_select(self, expr: SglSelect):
        """Let the backend pick a choice and append the decision to the text."""
        choices_decision = self.backend.select(
            self, expr.choices, expr.temperature, expr.choices_method
        )
        if expr.name is not None:
            name = expr.name
            self.variables[name] = choices_decision.decision
            self.meta_info[name] = choices_decision.meta_info
            self.variable_event[name].set()
            if self.stream_var_event:
                self.stream_var_event[name].set()
        self.text_ += choices_decision.decision

    def _execute_variable(self, expr: SglVariable):
        """Fill in the value of a variable owned by another executor."""
        src_executor = expr.source_stream_executor
        value = src_executor.get_var(expr.name)
        self._execute_fill(value)

    def _execute_role_begin(self, expr: SglRoleBegin):
        """Open a chat role and append the template prefix for it."""
        assert self.cur_role is None, "Nested roles are not allowed."

        if len(self.messages_) == 0 and expr.role != "system":
            # Insert the default system message
            default_system = self.chat_template.default_system_prompt
            if default_system:
                self._execute_role_begin(SglRoleBegin("system"))
                self._execute_fill(default_system)
                self._execute_role_end(SglRoleEnd("system"))

        self.cur_role = expr.role

        prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)

        self._execute_fill(prefix, prefix=True)
        self.cur_role_begin_pos = len(self.text_)

    def _execute_role_end(self, expr: SglRoleEnd):
        """Close the current role and append its message to ``messages_``."""
        if (
            self.cur_role == "assistant"
            and self.num_api_spec_tokens is not None
            and self.backend.is_chat_model
        ):
            # Execute the stored lazy generation calls
            self.backend.role_end_generate(self)
        self.cur_role = None

        # Text added since the role began, before the suffix is appended.
        new_text = self.text_[self.cur_role_begin_pos :].lstrip()

        _, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
        self._execute_fill(suffix)

        if self.cur_images:
            # OpenAI vision API format
            last_msg = {
                "role": expr.role,
                "content": [{"type": "text", "text": new_text}],
            }
            for image_path, image_base64_data in self.cur_images:
                last_msg["content"].append(
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{image_base64_data}"
                        },
                    }
                )
            self.messages_.append(last_msg)
            self.cur_images = []
        else:
            # OpenAI chat API format
            self.messages_.append({"role": expr.role, "content": new_text})

    def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
        """Remember the current text length as the scope's start position."""
        self.variables[expr.name] = int(len(self.text_))

    def _execute_var_scope_end(self, expr: SglVarScopeEnd):
        """Replace the stored start position by the text added since then."""
        self.variables[expr.name] = self.text_[self.variables[expr.name] :]
        self.variable_event[expr.name].set()

    def _execute_commit_lazy_operations(self, expr: SglCommitLazy):
        """Ask the backend to commit lazy operations for this executor."""
        self.backend.commit_lazy_operations(self)

    def _execute_concatenate_and_append_text(self, expr: SglConcateAndAppend):
        """Append the text each forked state produced since its fork point."""
        pieces = []
        for s in expr.states:
            exe = s.stream_executor
            exe.sync()
            pieces.append(exe.text_[exe.fork_start_text_pos :])

        self._execute_fill("".join(pieces))

    def _execute_concatenate_and_append_kv_cache(self, expr: SglConcateAndAppend):
        """Append forked text and ask the backend to concatenate the sources."""
        self_len = len(self.text_)

        # Submit to every fork first, then wait on each in turn below.
        for s in expr.states:
            exe = s.stream_executor
            exe.submit(SglCommitLazy())

        for s in expr.states:
            exe = s.stream_executor
            exe.sync()
            assert exe.fork_start_text_pos == self_len
            self.text_ += exe.text_[exe.fork_start_text_pos :]

        src_rids = [state.stream_executor.sid for state in expr.states]
        self.backend.concatenate_and_append(src_rids, self.sid)

    def _execute_separate_reasoning(self, expr: SglSeparateReasoning):
        """Split a generated variable into reasoning and normal text.

        Does nothing in streaming mode or when ``expr.expr`` is falsy.
        """
        if self.stream:
            # separate reasoning for stream is not supported
            return

        if (
            self.cur_role == "assistant"
            and self.num_api_spec_tokens is not None
            and self.backend.is_chat_model
        ):
            # Execute the stored lazy generation calls
            self.backend.role_end_generate(self)

        from sglang.srt.parser.reasoning_parser import ReasoningParser

        reasoning_parser = ReasoningParser(expr.model_type)
        other = expr.expr
        if not other:
            return
        elif isinstance(other, (SglGen, SglSelect)):
            cur_text = self.get_var(other.name)
            reasoning, normal_text = reasoning_parser.parse_non_stream(cur_text)
            reasoning_name = expr.process_name_for_reasoning(other.name)
            self.set_var(other.name, normal_text)
            self.set_var(reasoning_name, reasoning)
            # the variable is ready to be used
            self.variable_event[reasoning_name].set()
            self.text_ = self.text_[: self.cur_role_begin_pos] + normal_text
        elif isinstance(other, SglExprList):
            for x in other.expr_list:
                self._execute_separate_reasoning(
                    SglSeparateReasoning(expr.model_type, x)
                )

    def _init_var_event(self, expr):
        """Create the events for variables that ``expr`` will define."""
        if isinstance(
            expr, (SglGen, SglSelect, SglVarScopeBegin, SglSeparateReasoning)
        ):
            self.variable_event[expr.name] = threading.Event()
            if self.stream:
                self.stream_var_event[expr.name] = threading.Event()
        elif isinstance(expr, SglExprList):
            for e in expr.expr_list:
                self._init_var_event(e)

    def _resolve_sampling_params(self, sampling_params):
        """
        Construct sampling param based on default + override values

        The default values of sampling are populated in `default_sampling_para` via sgl.function.run(...sampling_args)
        , and `sampling_params` contains the override values from sgl.gen().

        Here we use default_sampling_para as the base and override the values if they exist in `sampling_params`.
        It also extends the stop tokens based on the chat template.

        Neither `default_sampling_para` nor the override object (including a
        `stop` list it holds) is modified.
        """

        # deepcopy is required because the dict has lists inside
        clone = copy.deepcopy(self.default_sampling_para)

        for item in [
            "max_new_tokens",
            "min_new_tokens",
            "n",
            "stop",
            "stop_token_ids",
            "stop_regex",
            "temperature",
            "top_p",
            "top_k",
            "min_p",
            "frequency_penalty",
            "presence_penalty",
            "ignore_eos",
            "return_logprob",
            "logprob_start_len",
            "top_logprobs_num",
            "return_text_in_logprobs",
            "dtype",
            "regex",
            "json_schema",
        ]:
            value = getattr(sampling_params, item, None)
            if value is not None:
                setattr(clone, item, value)

        template_stops = self.chat_template.stop_str
        if template_stops:
            # Overrides are assigned by reference, so ``clone.stop`` may be a
            # list owned by the caller. Build a new list instead of extending
            # in place; this also accepts a tuple or ``None`` as current stop.
            if clone.stop is None:
                current_stops = []
            elif isinstance(clone.stop, str):
                current_stops = [clone.stop]
            else:
                current_stops = list(clone.stop)
            if isinstance(template_stops, str):
                # A bare string is one stop sequence, not a list of characters.
                template_stops = [template_stops]
            clone.stop = current_stops + list(template_stops)

        return clone

    def __del__(self):
        self.end()
class ProgramState:
    """The state of an SGL program."""

    def __init__(self, stream_executor: StreamExecutor):
        self.stream_executor = stream_executor

    def _role_common(self, name: str, expr: Optional[SglExpr] = None):
        """Submit a role-wrapped expression, or return a role context manager.

        With ``expr``, a begin/expr/end expression list is submitted and
        returned. Without it, the returned context manager submits the role
        begin on entry and the role end on normal exit.
        """
        if expr is not None:
            role_expr = SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)])
            self.stream_executor.submit(role_expr)
            return role_expr
        else:

            @contextmanager
            def role_scope():
                self.stream_executor.submit(SglRoleBegin(name))
                yield
                self.stream_executor.submit(SglRoleEnd(name))

            return role_scope()

    def system(self, expr: Optional[SglExpr] = None):
        """Role helper for ``"system"``; see ``_role_common``."""
        return self._role_common("system", expr)

    def user(self, expr: Optional[SglExpr] = None):
        """Role helper for ``"user"``; see ``_role_common``."""
        return self._role_common("user", expr)

    def assistant(self, expr: Optional[SglExpr] = None):
        """Role helper for ``"assistant"``; see ``_role_common``."""
        return self._role_common("assistant", expr)

    @contextmanager
    def var_scope(self, name: str):
        """Context manager submitting a variable-scope begin and end for ``name``."""
        self.stream_executor.submit(SglVarScopeBegin(name))
        yield
        self.stream_executor.submit(SglVarScopeEnd(name))

    def fork(
        self,
        size: int = 1,
        position_ids_offset: Optional[List[int]] = None,
    ):
        """Fork the executor ``size`` times and wrap the results in a group."""
        stream_executors = self.stream_executor.fork(size, position_ids_offset)
        states = [ProgramState(x) for x in stream_executors]
        state_group = ProgramStateGroup(states, self)
        return state_group

    @contextmanager
    def copy(self, position_ids_offset: Optional[List[int]] = None):
        """Context manager yielding a single forked state; joined on exit."""
        state_group = self.fork(1, position_ids_offset)
        try:
            yield state_group[0]
        finally:
            state_group.join()

    def text(self):
        return self.stream_executor.text()

    def messages(self):
        return self.stream_executor.messages()

    def sync(self):
        return self.stream_executor.sync()

    def error(self):
        return self.stream_executor.error()

    def text_iter(self, var_name: Optional[str] = None):
        """Yield text incrementally.

        In streaming mode, yields new pieces of the full text (``var_name`` is
        ``None``) or of one variable as they appear. Otherwise yields the
        final text or variable value once.
        """
        if self.stream_executor.stream:
            prev = 0
            if var_name is None:
                event = self.stream_executor.stream_text_event
                while True:
                    event.wait()
                    event.clear()
                    out = str(self.stream_executor.text_[prev:])
                    prev += len(out)
                    if out:
                        yield out
                    if self.stream_executor.is_finished:
                        break
            else:
                # Poll until the variable's stream event has been registered.
                event = None
                while not event:
                    if var_name in self.stream_executor.stream_var_event:
                        event = self.stream_executor.stream_var_event[var_name]
                    if self.stream_executor.is_finished:
                        yield ""
                        return

                while True:
                    event.wait()
                    event.clear()
                    out = str(self.stream_executor.variables[var_name][prev:])
                    prev += len(out)
                    if out:
                        yield out
                    if self.stream_executor.variable_event[var_name].is_set():
                        break
        else:
            if var_name is None:
                yield self.text()
            else:
                yield self.get_var(var_name)

    async def text_async_iter(
        self, var_name: Optional[str] = None, return_meta_data: bool = False
    ):
        """Async counterpart of ``text_iter``.

        When ``return_meta_data`` is true and a variable is streamed, each
        item is an ``(text, meta_info)`` tuple instead of the text alone.
        """
        loop = asyncio.get_running_loop()

        if self.stream_executor.stream:
            prev = 0
            if var_name is None:
                event = self.stream_executor.stream_text_event
                while True:
                    # Wait in the default executor so the loop is not blocked.
                    await loop.run_in_executor(None, event.wait)
                    event.clear()
                    out = str(self.stream_executor.text_[prev:])
                    prev += len(out)
                    if out:
                        yield out
                    if self.stream_executor.is_finished:
                        break
            else:
                event = None
                while not event:
                    if var_name in self.stream_executor.stream_var_event:
                        event = self.stream_executor.stream_var_event[var_name]
                    if self.stream_executor.is_finished:
                        yield ""
                        return
                    if event is None:
                        # Give other tasks a turn while polling; without an
                        # await this loop would block the event loop.
                        await asyncio.sleep(0)

                while True:
                    await loop.run_in_executor(None, event.wait)
                    event.clear()
                    out = str(self.stream_executor.variables[var_name][prev:])
                    prev += len(out)
                    if out:
                        if return_meta_data:
                            yield out, self.stream_executor.meta_info[var_name]
                        else:
                            yield out
                    if self.stream_executor.variable_event[var_name].is_set():
                        break
        else:
            if var_name is None:
                yield self.text()
            else:
                yield self.get_var(var_name)

    def get_var(self, name):
        return self.stream_executor.get_var(name)

    def set_var(self, name, value):
        return self.stream_executor.set_var(name, value)

    def get_meta_info(self, name, timeout=None):
        """Return the executor's meta info for ``name``; ``timeout`` is forwarded."""
        return self.stream_executor.get_meta_info(name, timeout)

    def __iadd__(self, other):
        """Submit ``other`` to the executor (``state += expr``).

        Raises:
            ValueError: If ``other`` is ``None``.
        """
        if other is None:
            raise ValueError("Tried to append None to state.")
        self.stream_executor.submit(other)
        return self

    def __getitem__(self, name):
        return self.get_var(name)

    def __setitem__(self, name, value):
        self.set_var(name, value)

    def __contains__(self, name):
        return name in self.stream_executor.variables

    def __del__(self):
        self.stream_executor.end()

    def __repr__(self) -> str:
        return f"ProgramState({self.text()})"
class ProgramStateGroup:
    """A list of forked program states plus the state they were forked from."""

    def __init__(
        self, states: List[ProgramState], src_state: Optional[ProgramState] = None
    ):
        self.states = states
        self.src_state = src_state

    def join(self, mode: str = "gather_variable"):
        """Merge the forked states back into the source state and end them.

        ``"gather_variable"`` collects, per variable name that the source did
        not have when the join started, a list of the children's values.
        ``"concate_and_append"`` submits a concatenate-and-append expression
        to the source state. Any other mode raises ``ValueError``.
        """
        if mode not in ("gather_variable", "concate_and_append"):
            raise ValueError(f"Invalid join mode: {mode}")

        if mode == "gather_variable":
            # Copy variables back
            parent_vars = self.src_state.stream_executor.variables
            inherited_names = set(parent_vars.keys())
            for child in self.states:
                child.stream_executor.sync()
                produced = child.stream_executor.variables
                for key in set(produced.keys()) - inherited_names:
                    parent_vars.setdefault(key, []).append(produced[key])
        else:
            # Concatenate and append KV cache
            self.src_state += SglConcateAndAppend(self.states)
            # Need a sync here. Otherwise, `states` can be deleted.
            self.src_state.stream_executor.sync()

        for child in self.states:
            child.stream_executor.end()

    def __getitem__(self, i: int):
        return self.states[i]

    def __setitem__(self, i: int, value):
        # Only accepts writing back the state already stored at ``i``.
        assert self.states[i] == value

    def __iadd__(self, other):
        """Append to every state: a callable of the index, one expression, or a sequence."""
        if isinstance(other, Callable):
            # lambda function
            pick = other
        elif isinstance(other, SglExpr):

            def pick(_idx):
                return other

        elif isinstance(other, (list, tuple)):
            pick = other.__getitem__
        else:
            raise ValueError(f"Invalid value: {other}")

        for idx in range(len(self.states)):
            self.states[idx] += pick(idx)
        return self
Xet Storage Details
- Size:
- 34.8 kB
- Xet hash:
- 761196e424fe36fb609e23973612af3ba511a2d8a6f63b39c55bcb5e5272c62b
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.