| import contextvars | |
| import inspect | |
| import logging | |
| import os | |
| import sys | |
| import types | |
| from contextlib import contextmanager | |
| from dataclasses import dataclass | |
| from typing import Any, Callable, Optional, Union | |
| import torch | |
| from sglang.srt.compilation.compilation_config import CompilationConfig | |
| logger = logging.getLogger(__name__) | |
# Context-local switch read by the forward trampoline installed below:
# True selects the torch.compile'd path, False (the default) the original
# uncompiled forward.
_COMPILE_ENABLED = contextvars.ContextVar("_COMPILE_ENABLED", default=False)


@contextmanager
def set_compiled(enabled: bool = True):
    """Temporarily set the compile flag for the duration of a ``with`` block.

    The previous value is restored on exit, including when the body raises,
    so nested uses unwind correctly.

    Args:
        enabled: Value to store in ``_COMPILE_ENABLED`` while the block runs.
    """
    # Without the @contextmanager decorator this generator function cannot be
    # used in a ``with`` statement at all.
    token = _COMPILE_ENABLED.set(enabled)
    try:
        yield
    finally:
        # reset(token) restores whatever value was active before set().
        _COMPILE_ENABLED.reset(token)
@dataclass
class IntermediateTensors:
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.

    Each stage also needs to handle its own finished_sending and
    finished_recving in case of kv transfer.
    """

    # Named tensors carried between stages.
    tensors: dict[str, torch.Tensor]
    # [req_ids]
    finished_sending: Optional[set[str]] = None
    finished_recving: Optional[set[str]] = None

    def __init__(self, tensors):
        # manually define this function, so that
        # Dynamo knows `IntermediateTensors()` comes from this file.
        # Otherwise, dataclass will generate this function by evaluating
        # a string, and we will lose the information about the source file.
        self.tensors = tensors

    def __getitem__(self, key: Union[str, slice]):
        """Look up one tensor by name, or slice every tensor at once.

        A ``str`` key returns the stored tensor. A ``slice`` key returns a new
        instance whose tensors are each sliced with ``key``; the new instance
        is built through ``__init__``, so only ``tensors`` is carried over.
        Any other key type falls through and returns ``None``.
        """
        if isinstance(key, str):
            return self.tensors[key]
        elif isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})

    def __setitem__(self, key: str, value: torch.Tensor):
        """Store ``value`` under ``key``."""
        self.tensors[key] = value

    def items(self):
        """Return the ``(name, tensor)`` view of the underlying dict."""
        return self.tensors.items()

    def __len__(self):
        """Return the number of stored tensors."""
        return len(self.tensors)

    def __eq__(self, other: object):
        """Return True when ``other`` holds the same names with equal tensors.

        The previous implementation returned ``self`` for any instance of the
        same class, so two objects with different contents compared as equal
        (or as unequal when ``self`` was empty and therefore falsy).
        """
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
        # torch.equal requires matching size and elements.
        return all(
            torch.equal(tensor, other.tensors[name])
            for name, tensor in self.tensors.items()
        )

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"
def _normalize_dims(dims, ndim: int):
    """Return ``dims`` as a list of non-negative dimension indices.

    Args:
        dims: A single int or an iterable of ints; negative entries count
            from the end.
        ndim: Rank used to resolve negative entries.

    Returns:
        A new list in which every negative entry ``d`` is replaced by
        ``ndim + d``. No range checking is performed.
    """
    if isinstance(dims, int):
        candidates = [dims]
    else:
        candidates = list(dims)

    resolved = []
    for dim in candidates:
        if dim >= 0:
            resolved.append(dim)
        else:
            resolved.append(ndim + dim)
    return resolved
class _MaybeIntermediateTensors:
    """Duck-typed probe for IntermediateTensors-like objects.

    Wraps ``obj`` and records in ``is_intermediate`` whether it exposes a
    ``tensors`` attribute that is a ``dict``; no concrete class is imported
    or checked.
    """

    def __init__(self, obj):
        self.obj = obj
        # A missing attribute yields None, which is never a dict.
        candidate = getattr(obj, "tensors", None)
        self.is_intermediate = isinstance(candidate, dict)
def _mark_dynamic_on_value(val, dims):
    """Mark ``dims`` as dynamic on ``val`` for Dynamo.

    A ``torch.Tensor`` is marked directly. An object that looks like
    IntermediateTensors (has a dict-valued ``tensors`` attribute) has each of
    its tensors marked, with ``dims`` resolved against that tensor's own rank.
    Any other value (``None``, non-tensor) is ignored.
    """
    if isinstance(val, torch.Tensor):
        torch._dynamo.mark_dynamic(val, _normalize_dims(dims, val.ndim))
        return

    wrapped = _MaybeIntermediateTensors(val)
    if not wrapped.is_intermediate:
        # Nothing to mark on unrelated values.
        return

    for tensor in wrapped.obj.tensors.values():
        torch._dynamo.mark_dynamic(tensor, _normalize_dims(dims, tensor.ndim))
def _infer_dynamic_arg_dims_from_annotations(forward_fn):
    """Infer which parameters of ``forward_fn`` should have dim 0 marked dynamic.

    A parameter is selected when its annotation is ``torch.Tensor``, is a type
    named ``IntermediateTensors``, or is a parameterized annotation (e.g.
    ``Optional[...]`` / ``Union[...]``) with any type argument named
    ``Tensor`` or ``IntermediateTensors``. Matching inside type arguments is
    by ``__name__`` only, so nothing needs to be imported here.

    Every type argument is inspected, not just the first, so the result does
    not depend on argument order (``Union[None, torch.Tensor]`` matches) and
    an annotation with an empty ``__args__`` no longer raises ``IndexError``.

    Args:
        forward_fn: Callable whose signature is inspected.

    Returns:
        Dict mapping each selected parameter name to ``0``.

    Raises:
        ValueError: If no parameter is selected.
    """
    inferred = {}
    for name, param in inspect.signature(forward_fn).parameters.items():
        ann = param.annotation
        # Names of the annotation's type arguments; empty for plain types.
        type_args = getattr(ann, "__args__", None) or ()
        arg_names = [getattr(arg, "__name__", "") for arg in type_args]

        is_tensor = ann is torch.Tensor or "Tensor" in arg_names
        is_intermediate = (
            getattr(ann, "__name__", "") == "IntermediateTensors"
            or "IntermediateTensors" in arg_names
        )
        if is_tensor or is_intermediate:
            inferred[name] = 0

    if not inferred:
        raise ValueError("No dynamic dims inferred; pass dynamic_arg_dims explicitly.")
    return inferred
def install_torch_compiled(
    module: torch.nn.Module,
    *,
    dynamic_arg_dims: dict[str, Union[int, list[int]]] | None = None,
    backend_factory: Optional[Callable[[torch.fx.GraphModule, list], Callable]] = None,
    compile_config: CompilationConfig = None,
    fullgraph: bool = True,
    graph_pool: Any = None,
):
    """Replace ``module.forward`` with a trampoline that can run a compiled forward.

    On every call the trampoline reads ``_COMPILE_ENABLED``. When the flag is
    off it calls the class's original ``forward``. When it is on, the first
    call builds a ``torch.compile``'d version of that forward bound to the
    calling instance, and that call and all later ones dispatch to it.

    Args:
        module: Instance that receives the trampoline as its ``forward``
            attribute.
        dynamic_arg_dims: Maps ``forward`` parameter names to the dim (or
            dims) to mark dynamic just before compiling. When omitted or
            empty, the mapping is inferred from the ``forward`` annotations.
        backend_factory: Passed to ``torch.compile`` as ``backend``. When
            ``None``, a default is used that builds
            ``SGLangBackend(compile_config, graph_pool)`` for each graph.
        compile_config: Used only by the default backend.
        fullgraph: Forwarded to ``torch.compile``.
        graph_pool: Used only by the default backend.

    Returns:
        The same ``module`` object.

    Raises:
        TypeError: If ``module.__class__.forward`` is not callable.
        ValueError: If dims have to be inferred and the inference helper
            finds none.
    """
    forward_fn = module.__class__.forward
    if not callable(forward_fn):
        raise TypeError("module.__class__.forward must be callable")

    # Code object of the class-level forward as it is right now; the bytecode
    # hook uses identity against it to filter out unrelated compilations.
    forward_code = forward_fn.__code__

    if dynamic_arg_dims:
        dims_by_arg = dynamic_arg_dims
    else:
        dims_by_arg = _infer_dynamic_arg_dims_from_annotations(forward_fn)

    if backend_factory is None:
        from sglang.srt.compilation.backend import SGLangBackend

        backend_factory = lambda graph, example_inputs: SGLangBackend(
            compile_config, graph_pool
        )(graph, example_inputs)

    # Transformed code objects reported by Dynamo for this module's forward.
    # They are collected here but not read anywhere else in this function.
    captured_codes: list[type(forward_code)] = []

    # Lazy-compilation state shared by the two closures below.
    is_compiled = False
    compiled_fn = None

    def bytecode_hook(old_code, new_code):
        if old_code is not forward_code:
            return

        # Walk outwards through the caller chain, stopping at a frame named
        # `_compile` that lives in a file called convert_frame.py. If none is
        # found the walk ends on the outermost frame.
        cursor = sys._getframe()
        while cursor and cursor.f_back:
            cursor = cursor.f_back
            code = cursor.f_code
            if (
                code.co_name == "_compile"
                and os.path.basename(code.co_filename) == "convert_frame.py"
            ):
                break

        # That frame is expected to hold the frame being compiled in a local
        # called `frame`; give up quietly when it does not.
        try:
            dynamo_frame = cursor.f_locals["frame"]
        except Exception:
            return

        # Record only compilations of this exact code object running with
        # this exact module instance as `self`.
        if (
            dynamo_frame.f_code is old_code
            and dynamo_frame.f_locals.get("self") is module
        ):
            captured_codes.append(new_code)

    torch._dynamo.convert_frame.register_bytecode_hook(bytecode_hook)

    def _ensure_compiled(self, *args, **kwargs):
        """Compile on first use (with flag ON)."""
        nonlocal is_compiled, compiled_fn
        if is_compiled:
            return

        # Mark dynamic dims only when we are about to compile
        bound_args = inspect.signature(forward_fn).bind(self, *args, **kwargs)
        bound_args.apply_defaults()
        for arg_name, dims in (dims_by_arg or {}).items():
            value = bound_args.arguments.get(arg_name)
            if value is not None:
                _mark_dynamic_on_value(value, dims)

        # Avoid cross-instance cache reuse
        torch._dynamo.eval_frame.remove_from_cache(forward_fn.__code__)

        bound_forward = types.MethodType(forward_fn, self)
        candidate = torch.compile(
            bound_forward, fullgraph=fullgraph, backend=backend_factory
        )

        # Trigger Dynamo so bytecode hook can capture
        candidate(*args, **kwargs)

        # Publish only after the warm-up call succeeded, so a failed compile
        # is retried on the next call.
        compiled_fn = candidate
        is_compiled = True

    def trampoline(self, *args, **kwargs):
        if not _COMPILE_ENABLED.get():
            # Explicitly run the original uncompiled forward
            return forward_fn(self, *args, **kwargs)

        if not is_compiled:
            _ensure_compiled(self, *args, **kwargs)
        # The compiled callable is already bound to an instance, so `self`
        # is not passed again.
        return compiled_fn(*args, **kwargs)

    module.forward = types.MethodType(trampoline, module)
    return module
Xet Storage Details
- Size:
- 6.97 kB
- Xet hash:
- 5d3dd246f129bbee6ce79cf4cd1c09936beace09fddb7983771ac3e9caab8193
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.