| import multiprocessing | |
| from concurrent.futures import ThreadPoolExecutor | |
| from queue import Queue | |
| from typing import List, Union | |
| from sglang.global_config import global_config | |
| from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program | |
| from sglang.lang.ir import SglArgument, SglExpr, SglSamplingParams, SglVariable | |
def compile_func(function, backend):
    """Trace ``function`` against ``backend`` and wrap the trace for replay.

    Args:
        function: Object exposing ``trace(backend=...)``; it is also handed to
            the resulting ``CompiledFunction``.
        backend: Backend forwarded to ``function.trace``.

    Returns:
        A ``CompiledFunction`` built from the tracer and the function.
    """
    return CompiledFunction(function.trace(backend=backend), function)
class CompiledFunction:
    """Dependency graph of a traced function that can be replayed on a backend.

    The constructor walks the traced expressions backwards from
    ``tracer.last_node``, wraps each one in a ``CompGraphNode`` and orders the
    nodes topologically, so ``run``/``run_batch`` can submit every expression
    after the expressions it depends on.
    """

    def __init__(self, tracer, function):
        """Build and order the computation graph.

        Args:
            tracer: Trace result; ``tracer.last_node`` and ``tracer.variables``
                are read while building the graph.
            function: The traced function. Its ``bind_arguments`` are merged
                into the call arguments at run time, and it is passed to
                ``cache_program`` by ``run_batch``.
        """
        self.function = function
        self.last_node = CompGraphNode(tracer.last_node)
        self.expr_to_node = {}
        self.build_graph(tracer)
        self.topological_sort()

    def _get_or_add_node(self, expr, visited):
        """Return the graph node wrapping ``expr``, creating it on first sight."""
        if expr not in visited:
            visited.add(expr)
            node = CompGraphNode(expr)
            self.nodes.append(node)
            self.expr_to_node[expr] = node
        return self.expr_to_node[expr]

    def build_graph(self, tracer):
        """Discover all nodes reachable from the last node and link them.

        ``self.nodes`` doubles as the breadth-first work list: it is scanned by
        index while newly discovered nodes are appended to its end. Each
        expression's ``pid`` is rewritten in place to a dense index assigned in
        visiting order.

        Args:
            tracer: Trace result providing ``last_node`` and ``variables``.
        """
        self.nodes = [self.last_node]
        self.expr_to_node[tracer.last_node] = self.last_node

        rename_pid = {}
        visited = {tracer.last_node}
        head = 0
        while head < len(self.nodes):
            cur_node = self.nodes[head]
            head += 1
            cur_expr = cur_node.expr

            # Edge to the expression that precedes this one.
            prev_expr = cur_expr.prev_node
            if prev_expr is not None:
                prev_node = self._get_or_add_node(prev_expr, visited)
                cur_node.prev_node = prev_node
                prev_node.add_next_node(cur_node)

            # Edge to the expression that produces a variable's value.
            if isinstance(cur_expr, SglVariable):
                if cur_expr.name in tracer.variables:
                    source_expr = tracer.variables[cur_expr.name].source
                else:
                    source_expr = cur_expr.source
                source_node = self._get_or_add_node(source_expr, visited)
                cur_node.source_node = source_node
                source_node.add_next_node(cur_node)

            # Renumber pids densely, in order of first appearance.
            if cur_expr.pid not in rename_pid:
                rename_pid[cur_expr.pid] = len(rename_pid)
            cur_expr.pid = rename_pid[cur_expr.pid]

    def topological_sort(self):
        """Reorder ``self.nodes`` so every node follows its dependencies.

        A node depends on its ``prev_node`` and its ``source_node`` (when set).
        Nodes become ready in first-in-first-out order, so the result list
        itself serves as the work queue.

        Raises:
            ValueError: If some nodes can never become ready (the graph has a
                cycle); previously such nodes were silently dropped.
        """
        in_degree = {}
        order = []
        for node in self.nodes:
            in_degree[node] = (node.prev_node is not None) + (
                node.source_node is not None
            )
            if in_degree[node] == 0:
                order.append(node)

        cursor = 0
        while cursor < len(order):
            for successor in order[cursor].next_nodes:
                in_degree[successor] -= 1
                if in_degree[successor] == 0:
                    order.append(successor)
            cursor += 1

        if len(order) != len(self.nodes):
            raise ValueError(
                "Cannot topologically sort the computation graph: "
                f"{len(self.nodes) - len(order)} node(s) are part of a cycle."
            )
        self.nodes = order

    def print_graph(
        self,
    ):
        """Print every node, one per line, in the current node order."""
        for node in self.nodes:
            print(node)

    def run_internal(
        self,
        backend,
        kwargs,
        default_sampling_para,
    ):
        """Replay the graph once and return the state of the final stream.

        One ``StreamExecutor`` is created per distinct ``pid``; only the
        executor owning the last node receives ``kwargs`` as its arguments.

        Args:
            backend: Backend handed to every ``StreamExecutor``.
            kwargs: Argument values; ``SglArgument`` expressions are replaced
                by ``kwargs[expr.name]``.
            default_sampling_para: Sampling parameters handed to every
                ``StreamExecutor``.

        Returns:
            A ``ProgramState`` wrapping the executor of the last node's stream.

        Raises:
            KeyError: If an ``SglArgument`` names a key missing from ``kwargs``.
        """
        last_pid = self.last_node.expr.pid
        stream_executors = {}
        for pid in {node.expr.pid for node in self.nodes}:
            arguments = kwargs if pid == last_pid else {}
            stream_executors[pid] = StreamExecutor(
                backend, arguments, default_sampling_para, None, False
            )

        for node in self.nodes:
            expr = node.expr
            if isinstance(expr, SglVariable):
                # Submit a fresh copy so the per-run executor reference is not
                # stored on the traced expression shared between runs.
                source_pid = node.source_node.expr.pid
                expr = SglVariable(expr.name, expr.source)
                expr.source_stream_executor = stream_executors[source_pid]
            elif isinstance(expr, SglArgument):
                # Substitute the placeholder with the caller-supplied value.
                expr = kwargs[expr.name]
            stream_executors[node.expr.pid].submit(expr)

        for stream_executor in stream_executors.values():
            stream_executor.end()
        return ProgramState(stream_executors[last_pid])

    def run(
        self,
        *,
        max_new_tokens: int = 128,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        **kwargs,
    ):
        """Run the compiled function once.

        The sampling keyword arguments are forwarded unchanged to
        ``SglSamplingParams``. ``self.function.bind_arguments`` is merged into
        ``kwargs`` and takes precedence over caller-supplied values.

        Args:
            backend: Backend to run on; falls back to
                ``global_config.default_backend`` when falsy.
            **kwargs: Argument values for the traced function.

        Returns:
            The ``ProgramState`` produced by ``run_internal``.
        """
        backend = backend or global_config.default_backend
        kwargs.update(self.function.bind_arguments)
        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )
        return self.run_internal(backend, kwargs, default_sampling_para)

    def run_batch(
        self,
        batch_kwargs,
        *,
        max_new_tokens: int = 128,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        num_threads: Union[str, int] = "auto",
    ):
        """Run the compiled function once per argument dict in ``batch_kwargs``.

        The sampling keyword arguments are forwarded unchanged to
        ``SglSamplingParams``. As in ``run``, ``self.function.bind_arguments``
        is merged into every item's arguments (bound values win); the merge
        works on copies, so the caller's dicts are left unmodified.

        Args:
            batch_kwargs: List or tuple of argument dicts, one per run.
            backend: Backend to run on; falls back to
                ``global_config.default_backend`` when falsy.
            num_threads: Worker thread count, or ``"auto"`` to use
                ``multiprocessing.cpu_count()``. Capped at the batch size.

        Returns:
            A list with one ``ProgramState`` per item, in input order; an
            empty list when ``batch_kwargs`` is empty.
        """
        assert isinstance(batch_kwargs, (list, tuple))
        if len(batch_kwargs) == 0:
            return []
        assert isinstance(batch_kwargs[0], dict)

        backend = backend or global_config.default_backend
        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )

        # Keep batch runs consistent with `run`: bound arguments must be
        # visible to the SglArgument substitution in `run_internal`.
        bound_arguments = self.function.bind_arguments
        batch_arguments = [
            {**arguments, **bound_arguments} for arguments in batch_kwargs
        ]

        # Extract prefix by tracing and cache it
        if len(batch_arguments) > 1:
            cache_program(self.function, backend)

        # Run all programs
        if num_threads == "auto":
            num_threads = multiprocessing.cpu_count()
        num_threads = min(num_threads, len(batch_arguments))

        if num_threads == 1:
            rets = [
                self.run_internal(backend, arguments, default_sampling_para)
                for arguments in batch_arguments
            ]
        else:
            with ThreadPoolExecutor(num_threads) as executor:
                futures = [
                    executor.submit(
                        self.run_internal, backend, arguments, default_sampling_para
                    )
                    for arguments in batch_arguments
                ]
                rets = [f.result() for f in futures]
            rets[-1].sync()

        return rets
class CompGraphNode:
    """A single expression in the computation graph, with its graph links.

    Attributes are plain references: ``prev_node`` and ``source_node`` point
    at the nodes this one depends on, and ``next_nodes`` lists the nodes that
    depend on this one.
    """

    def __init__(
        self, expr: SglExpr, prev_node=None, next_nodes=None, source_node=None
    ):
        """Wrap ``expr``; a falsy ``next_nodes`` is replaced by a new list."""
        self.expr = expr
        self.prev_node = prev_node
        self.source_node = source_node
        self.next_nodes = next_nodes if next_nodes else []

    def add_next_node(self, other):
        """Record ``other`` as a node that depends on this one."""
        self.next_nodes.append(other)

    def __repr__(self):
        """Render as ``stream <pid>: %<id> = [%<prev id> + ]<expr repr>``."""
        expr = self.expr
        prefix = f"stream {expr.pid:2d}: %{expr.node_id} = "
        if self.prev_node is None:
            return prefix + repr(expr)
        return f"{prefix}%{self.prev_node.expr.node_id} + {expr!r}"
Xet Storage Details
- Size:
- 7.6 kB
- Xet hash:
- 209f675345c01f29e34daaf9cd35cc536c79c77ff9d108382602dc3215effd5e
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.