| from __future__ import annotations | |
| import os | |
| from contextlib import contextmanager | |
| from dataclasses import dataclass | |
| from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Sequence, Union | |
| import torch | |
| from sglang.srt.layers.dp_attention import set_dp_buffer_len | |
| if TYPE_CHECKING: | |
| from sglang.srt.model_executor.forward_batch_info import ForwardBatch | |
# Opt-in profiling switch: any non-zero integer in SGLANG_OPERATIONS_ENABLE_PROFILE
# turns on the torch-profiler / NVTX range annotations in `_annotate_region`.
# `nvtx` is only imported when profiling is requested.
_ENABLE_PROFILE = int(os.environ.get("SGLANG_OPERATIONS_ENABLE_PROFILE", "0")) != 0
if _ENABLE_PROFILE:
    import nvtx
def execute_operations(inputs, operations):
    """Run ``operations`` on a single batch and return the final output.

    The operation list is split into stages at every ``YieldOperation``; with
    only one batch there is nothing to interleave with, so the stages simply
    run back-to-back.

    Args:
        inputs: keyword inputs of the first operation; must contain
            ``"forward_batch"`` (read by the stage executor).
        operations: flat list of callables / ``YieldOperation`` markers.

    Returns:
        Whatever the last operation returned.
    """
    executor = _StageExecutor(
        "primary", _convert_operations_to_stages(operations), inputs=inputs
    )
    remaining = executor.num_stages
    while remaining > 0:
        executor.next()
        remaining -= 1
    assert executor.done
    return executor.output
def execute_overlapped_operations(
    inputs_arr: Sequence,
    operations_arr: Sequence,
    delta_stages: Sequence[int],
) -> Sequence:
    """Interleave the stages of exactly two batches, "a" and "b".

    Batch "b" lags batch "a" by ``delta_stages[1]`` stages: "a" first runs
    alone for that many stages, then the two alternate one stage at a time
    (a, b, a, b, ...), and finally "b" runs alone to catch up.

    Args:
        inputs_arr: pair of input dicts, one per batch.
        operations_arr: pair of operation lists, one per batch.
        delta_stages: pair of stage offsets; the first must be 0.

    Returns:
        ``[output_a, output_b]``.
    """
    # Deliberately limited to two batches; generalize if multi-batch overlap is needed.
    inputs_first, inputs_second = inputs_arr
    ops_first, ops_second = operations_arr
    offset_first, lag = delta_stages
    assert offset_first == 0

    stages_first = _convert_operations_to_stages(ops_first)
    stages_second = _convert_operations_to_stages(ops_second)
    first = _StageExecutor("a", stages_first, inputs=inputs_first)
    second = _StageExecutor("b", stages_second, inputs=inputs_second)

    def _schedule():
        # Warm-up: "a" runs alone for `lag` stages.
        for _ in range(lag):
            yield first
        # Steady state: one stage of "a", then one stage of "b".
        for _ in range(first.num_stages - lag):
            yield first
            yield second
        # Drain: "b" runs the stages it is still behind on.
        for _ in range(lag):
            yield second

    for executor in _schedule():
        executor.next()

    assert first.done and second.done
    return [first.output, second.output]
class YieldOperation:
    """Marker placed in an operation list to end the current stage.

    Operation lists are split into stages at these markers; when two batches
    are overlapped, each marker is a point where execution switches to the
    other batch.
    """


@dataclass
class ExecutionOperation:
    """A callable step of a stage, paired with the label used for profiling.

    The stage executor invokes ``fn(state=..., **previous_output)``.
    ``@dataclass`` provides the keyword constructor
    ``ExecutionOperation(debug_name=..., fn=...)`` that callers rely on.
    """

    debug_name: str
    fn: Callable


# Entries allowed in an operation list: a stage marker, a wrapped operation,
# or a bare callable.
Operation = Union[YieldOperation, ExecutionOperation, Callable]
# One stage: the operations run back-to-back between two yield points.
Stage = List[ExecutionOperation]
class _StageExecutor:
    """Step one batch through its stages, one stage per ``next()`` call.

    Each operation receives the shared ``state`` object plus the previous
    operation's output dict as keyword arguments; its return value becomes the
    input of the following operation.  Once every stage has run, ``output``
    holds the last operation's return value.
    """

    def __init__(self, debug_name: str, stages: List[Stage], inputs: dict):
        self._debug_name = debug_name
        self._stages = stages
        self._index = 0
        self._stage_state = _StateDict()
        self._stage_output = inputs

        # Handling DP attention: remember this batch's DP buffer sizes so they
        # can be re-applied before every stage.  NOTE(review): presumably
        # set_dp_buffer_len sets shared state, so re-applying per stage keeps
        # interleaved executors from seeing each other's sizes — verify.
        forward_batch: ForwardBatch = inputs["forward_batch"]
        self._global_dp_buffer_len = forward_batch.global_dp_buffer_len
        self._local_dp_buffer_len = forward_batch.input_ids.shape[0]
        self._global_num_tokens = forward_batch.global_num_tokens_cpu

    def next(self):
        """Run the next stage and advance the stage index."""
        assert not self.done
        stage = self._stages[self._index]

        if self._global_dp_buffer_len is not None:
            set_dp_buffer_len(
                self._global_dp_buffer_len,
                self._local_dp_buffer_len,
                self._global_num_tokens,
            )

        with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
            for op in stage:
                with _annotate_region(debug_name=op.debug_name):
                    # An op returning None forwards no kwargs to the next op.
                    self._stage_output = op.fn(
                        state=self._stage_state,
                        **(
                            self._stage_output if self._stage_output is not None else {}
                        ),
                    )

        self._index += 1

    @property
    def output(self):
        """Return value of the last operation; only valid once ``done``."""
        assert self.done
        return self._stage_output

    @property
    def done(self):
        """Whether every stage has been executed."""
        return self._index >= self.num_stages

    @property
    def num_stages(self):
        """Total number of stages this executor runs."""
        return len(self._stages)
@contextmanager
def _annotate_region(debug_name):
    """Context manager that labels the enclosed code for profilers.

    When ``_ENABLE_PROFILE`` is set, the region is recorded as both a
    ``torch.autograd.profiler.record_function`` range and an
    ``nvtx.annotate`` range named ``debug_name``; otherwise it is a no-op.
    ``@contextmanager`` is required because callers use it in ``with``
    statements.
    """
    if _ENABLE_PROFILE:
        with torch.autograd.profiler.record_function(debug_name):
            with nvtx.annotate(debug_name):
                yield
    else:
        yield
class _StateKeyError(AttributeError, KeyError):
    """Raised when a missing ``_StateDict`` entry is read or deleted.

    Deriving from ``AttributeError`` keeps the attribute protocol intact
    (``hasattr``, ``getattr(obj, name, default)``, and the probing done by
    ``copy``/``pickle``).  Deriving from ``KeyError`` keeps callers that caught
    the previous ``KeyError`` working.
    """


class _StateDict:
    """Write-once attribute namespace shared by the operations of one executor.

    Entries are set and read as attributes (``state.foo = x`` / ``state.foo``).
    When assertions are enabled, assigning to an existing entry fails, so a
    value must be ``pop``-ed or deleted before it can be replaced.
    """

    def __init__(self):
        self._data = {}

    def __setattr__(self, key, value):
        # `_data` is the only real instance attribute; every other name is an entry.
        if key == "_data":
            super().__setattr__(key, value)
            return
        assert (
            key not in self._data
        ), f"`{key}` already exists, are you sure you want to override it?"
        self._data[key] = value

    def __getattr__(self, item):
        # Only reached when normal lookup fails.  Read `_data` via `__dict__`:
        # on an instance created without `__init__` (copy/pickle reconstruction),
        # `self._data` would re-enter this method forever.
        data = self.__dict__.get("_data")
        if data is None:
            raise AttributeError(item)
        try:
            return data[item]
        except KeyError:
            raise _StateKeyError(item) from None

    def __delattr__(self, item):
        try:
            del self._data[item]
        except KeyError:
            raise _StateKeyError(item) from None

    def pop(self, item):
        """Remove entry ``item`` and return its value (``KeyError`` if absent)."""
        return self._data.pop(item)

    def update(self, values: Dict[str, Any]):
        """Add every pair of ``values``; each key must not already exist."""
        for k, v in values.items():
            setattr(self, k, v)

    def get(self, item):
        """Return entry ``item``, or ``None`` if it is absent."""
        return self._data.get(item)

    def clear(self, expect_keys: Sequence[str]):
        """Drop all entries after checking that exactly ``expect_keys`` remain.

        Raises:
            Exception: if the remaining keys differ from ``expect_keys``; this
                usually means some state was not released early enough.
        """
        if set(self._data.keys()) != set(expect_keys):
            raise Exception(
                f"Unexpected keys when clearing. This may indicate you do not release memory early enough but leave it to here. {list(self._data.keys())=} {expect_keys=}"
            )
        self._data.clear()
def _convert_operations_to_stages(operations: List[Operation]) -> List[Stage]:
    """Turn a flat operation list into stages delimited by ``YieldOperation``.

    Entries are first normalized with ``_decorate_operations``, then split at
    each ``YieldOperation``; the markers themselves are dropped.  Every stage
    must be non-empty, so a leading marker or two adjacent markers fail the
    assertion, while a trailing marker is accepted.
    """

    def is_stage_boundary(op) -> bool:
        return isinstance(op, YieldOperation)

    decorated = _decorate_operations(operations)
    stages = list(_chunk_by_separator(decorated, is_stage_boundary))
    assert not any(len(stage) == 0 for stage in stages)
    return stages
def _chunk_by_separator(
    items: List[Any], is_separator: Callable[[Any], bool]
) -> Generator[List[Any], None, None]:
    """Lazily split ``items`` into lists at every element matching ``is_separator``.

    Separators are not included in the output.  Every separator emits the
    chunk collected so far, even when that chunk is empty.  The final chunk is
    emitted only if it is non-empty.
    """
    chunk: List[Any] = []
    for element in items:
        if not is_separator(element):
            chunk.append(element)
            continue
        yield chunk
        chunk = []
    if chunk:
        yield chunk
def _decorate_operations(operations: List[Operation], debug_name_prefix: str = ""):
    """Normalize every entry of ``operations`` with ``_decorate_operation``.

    Returns a new list in the same order; ``debug_name_prefix`` is forwarded
    unchanged to each entry.
    """
    return [
        _decorate_operation(operation=entry, debug_name_prefix=debug_name_prefix)
        for entry in operations
    ]
def _decorate_operation(operation: Operation, debug_name_prefix: str):
    """Normalize one operation-list entry into a stage element.

    - ``YieldOperation`` markers are returned unchanged.
    - An ``ExecutionOperation`` is not wrapped a second time: a double wrap
      would make ``fn`` a non-callable ``ExecutionOperation``.  Instead,
      ``debug_name_prefix`` is prepended to its existing debug name.
    - Any other callable is wrapped in an ``ExecutionOperation``.  Its debug
      name is ``debug_name_prefix`` plus the callable's ``__name__`` with a
      leading ``"op_"`` removed, or ``"unknown"`` if it has no ``__name__``.
    """
    if isinstance(operation, YieldOperation):
        return operation
    if isinstance(operation, ExecutionOperation):
        if not debug_name_prefix:
            return operation
        return ExecutionOperation(
            debug_name=debug_name_prefix + operation.debug_name,
            fn=operation.fn,
        )
    name = getattr(operation, "__name__", "unknown")
    # Strip only the leading "op_" prefix; str.replace would also remove inner
    # occurrences (e.g. "op_stop_x" -> "stx").
    if name.startswith("op_"):
        name = name[len("op_"):]
    return ExecutionOperation(debug_name=debug_name_prefix + name, fn=operation)
Xet Storage Details
- Size:
- 6.09 kB
- Xet hash:
- 45f54edd08d009301aa1dfd02baa865a5dfd9c3acb4c14bc7e58a0ec17e25c34
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.