leideng's picture
download
raw
6.97 kB
import contextvars
import inspect
import logging
import os
import sys
import types
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
import torch
from sglang.srt.compilation.compilation_config import CompilationConfig
logger = logging.getLogger(__name__)
_COMPILE_ENABLED = contextvars.ContextVar("_COMPILE_ENABLED", default=False)
@contextmanager
def set_compiled(enabled: bool = True):
token = _COMPILE_ENABLED.set(enabled)
try:
yield
finally:
_COMPILE_ENABLED.reset(token)
@dataclass
class IntermediateTensors:
"""For all pipeline stages except the last, we need to return the hidden
states and residuals to be sent to the next stage. This data structure
contains the hidden states and residuals for a request.
Each stage also needs to handle its own finished_sending and
finished_recving in case of kv transfer.
"""
tensors: dict[str, torch.Tensor]
# [req_ids]
finished_sending: Optional[set[str]] = None
finished_recving: Optional[set[str]] = None
def __init__(self, tensors):
# manually define this function, so that
# Dynamo knows `IntermediateTensors()` comes from this file.
# Otherwise, dataclass will generate this function by evaluating
# a string, and we will lose the information about the source file.
self.tensors = tensors
def __getitem__(self, key: Union[str, slice]):
if isinstance(key, str):
return self.tensors[key]
elif isinstance(key, slice):
return self.__class__({k: v[key] for k, v in self.tensors.items()})
def __setitem__(self, key: str, value: torch.Tensor):
self.tensors[key] = value
def items(self):
return self.tensors.items()
def __len__(self):
return len(self.tensors)
def __eq__(self, other: object):
return isinstance(other, self.__class__) and self
def __repr__(self) -> str:
return f"IntermediateTensors(tensors={self.tensors})"
def _normalize_dims(dims, ndim: int):
dims = [dims] if isinstance(dims, int) else list(dims)
return [d if d >= 0 else ndim + d for d in dims]
class _MaybeIntermediateTensors:
"""Duck-typed check to support your IntermediateTensors without importing."""
def __init__(self, obj):
self.is_intermediate = hasattr(obj, "tensors") and isinstance(
getattr(obj, "tensors"), dict
)
self.obj = obj
def _mark_dynamic_on_value(val, dims):
if isinstance(val, torch.Tensor):
torch._dynamo.mark_dynamic(val, _normalize_dims(dims, val.ndim))
else:
mit = _MaybeIntermediateTensors(val)
if mit.is_intermediate:
for t in mit.obj.tensors.values():
torch._dynamo.mark_dynamic(t, _normalize_dims(dims, t.ndim))
# else: ignore (None or non-tensor)
def _infer_dynamic_arg_dims_from_annotations(forward_fn):
sig = inspect.signature(forward_fn)
dyn = {}
for name, p in sig.parameters.items():
ann = p.annotation
# Accept torch.Tensor / Optional[torch.Tensor] / your IntermediateTensors types by name
if (
ann is torch.Tensor
or getattr(getattr(ann, "__args__", [None])[0], "__name__", "") == "Tensor"
):
dyn[name] = 0
elif getattr(ann, "__name__", "") in ("IntermediateTensors",) or any(
getattr(a, "__name__", "") == "IntermediateTensors"
for a in getattr(ann, "__args__", [])
):
dyn[name] = 0
if not dyn:
raise ValueError("No dynamic dims inferred; pass dynamic_arg_dims explicitly.")
return dyn
def install_torch_compiled(
module: torch.nn.Module,
*,
dynamic_arg_dims: dict[str, Union[int, list[int]]] | None = None,
backend_factory: Optional[Callable[[torch.fx.GraphModule, list], Callable]] = None,
compile_config: CompilationConfig = None,
fullgraph: bool = True,
graph_pool: Any = None,
):
unbound_fwd = module.__class__.forward
if not callable(unbound_fwd):
raise TypeError("module.__class__.forward must be callable")
original_code = unbound_fwd.__code__
dyn_map = dynamic_arg_dims or _infer_dynamic_arg_dims_from_annotations(unbound_fwd)
if backend_factory is None:
from sglang.srt.compilation.backend import SGLangBackend
backend_factory = lambda gm, ex: SGLangBackend(compile_config, graph_pool)(
gm, ex
)
compiled_codes: list[type(original_code)] = []
state = {"compiled": False, "compiled_callable": None}
def bytecode_hook(old_code, new_code):
if old_code is not original_code:
return
frame = sys._getframe()
while frame and frame.f_back:
frame = frame.f_back
if (
frame.f_code.co_name == "_compile"
and os.path.basename(frame.f_code.co_filename) == "convert_frame.py"
):
break
try:
dynamo_frame = frame.f_locals["frame"]
except Exception:
return
if dynamo_frame.f_code is not old_code:
return
if dynamo_frame.f_locals.get("self") is not module:
return
compiled_codes.append(new_code)
torch._dynamo.convert_frame.register_bytecode_hook(bytecode_hook)
def _ensure_compiled(self, *args, **kwargs):
"""Compile on first use (with flag ON)."""
if state["compiled"]:
return
# Mark dynamic dims only when we are about to compile
sig = inspect.signature(unbound_fwd)
ba = sig.bind(self, *args, **kwargs)
ba.apply_defaults()
for name, dims in (dyn_map or {}).items():
if name in ba.arguments:
val = ba.arguments[name]
if val is not None:
_mark_dynamic_on_value(val, dims)
# Avoid cross-instance cache reuse
torch._dynamo.eval_frame.remove_from_cache(unbound_fwd.__code__)
bound = types.MethodType(unbound_fwd, self)
compiled_callable = torch.compile(
bound, fullgraph=fullgraph, backend=backend_factory
)
# Trigger Dynamo so bytecode hook can capture
compiled_callable(*args, **kwargs)
state["compiled"] = True
state["compiled_callable"] = compiled_callable
def trampoline(self, *args, **kwargs):
use_compiled = _COMPILE_ENABLED.get()
if use_compiled:
if not state["compiled"]:
_ensure_compiled(self, *args, **kwargs)
compiled_callable = state["compiled_callable"]
return compiled_callable(*args, **kwargs)
else:
# Explicitly run the original uncompiled forward
return unbound_fwd(self, *args, **kwargs)
module.forward = types.MethodType(trampoline, module)
return module

Xet Storage Details

Size:
6.97 kB
·
Xet hash:
5d3dd246f129bbee6ce79cf4cd1c09936beace09fddb7983771ac3e9caab8193

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.