jiadisu
Switch back to Docker SDK with local pkgs
e6066e8
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import gc
import inspect
import os
from contextlib import contextmanager
from typing import Callable, TypeVar, get_args, get_origin, overload
from unittest.mock import patch
import magi_compiler.utils.envs as envs
import torch
from magi_compiler.cuda.cudart import pin_memory_in_place
from magi_compiler.magi_compiler_base import MagiCompilerBase
from magi_compiler.utils import compilation_counter, magi_logger
from magi_compiler.utils.compile_time_monitor import CompileMonitor
from torch import distributed as dist
from torch import nn
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
from .config import CompileConfig, CompileMode, get_compile_config
# =============================================================================
# Workaround: TorchInductor autotune get_raw_stream
# =============================================================================
# TorchInductor autotune code blocks may reference get_raw_stream() without
# defining it, causing "name 'get_raw_stream' is not defined" at runtime.
# Register it as a builtin so the exec'd autotune snippets can always find it.
def _patch_get_raw_stream():
try:
import builtins
from torch._C import _cuda_getCurrentRawStream as _get_raw_stream
except Exception:
return
if not hasattr(builtins, "get_raw_stream"):
builtins.get_raw_stream = _get_raw_stream
_patch_get_raw_stream()
# =============================================================================
# Dynamo Config Isolation
# =============================================================================
# Capture the default dynamo config at module load time (before any torch.compile).
# This ensures we have a "clean" baseline config that hasn't been modified by
# external torch.compile calls (e.g., with dynamic=True).
_DEFAULT_DYNAMO_CONFIG: dict = torch._dynamo.config.get_config_copy()
@contextmanager
def _isolated_dynamo_config():
"""
Context manager that provides an isolated dynamo config environment.
"""
with torch._dynamo.config.patch(**_DEFAULT_DYNAMO_CONFIG):
yield
_T = TypeVar("_T", bound=type[nn.Module])
_W = TypeVar("_W", bound="MagiCompilerBase")


@overload
def magi_compile(
    *,
    model_tag: str | None = None,
    dynamic_arg_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[], bool] | None = None,
    config_patch: Callable[[CompileConfig], CompileConfig] | None = None,
) -> Callable[[_T], _T]:
    ...


@overload
def magi_compile(cls: _T) -> _T:
    ...


def magi_compile(
    cls: _T | None = None,
    *,
    model_tag: str | None = None,
    dynamic_arg_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[], bool] | None = None,
    config_patch: Callable[[CompileConfig], CompileConfig] | None = None,
) -> Callable[[_T], _T] | _T:
    """
    A decorator to add support for compiling the forward method of a class.

    Usage:

    1. use directly as a decorator without arguments:

    ```python
    @magi_compile
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    ```

    2. use as a decorator with arguments:

    ```python
    @magi_compile(dynamic_arg_dims={"x": 0, "y": 0})
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    ```

    Arguments:
        - model_tag: optional tag in cache path (e.g. "wan_ti2v"). If not set, class name is used.
          Path segment: model_{idx}_{model_tag}_rank_{rank}.
        - dynamic_arg_dims: a dictionary that maps argument names to the dynamic
          dimensions of the argument. The dynamic dimensions can be either a single
          integer or a list of integers.
        - enable_if: a zero-argument function that returns a boolean value indicating whether to
          compile the model or not. This is useful if you want to compile the model only when
          certain conditions are met.
        - config_patch: optional function that receives the current CompileConfig and returns the
          CompileConfig to use for instances of the decorated class.

    Returns:
        The decorated class when used as ``@magi_compile``; otherwise a class decorator.

    Notes:
        - dynamic_arg_dims will be inferred from the type annotation of the forward method if not provided,
          if the argument is annotated as `torch.Tensor` or `Optional[torch.Tensor]`,
          the first dimension will be marked as dynamic.
        - if an argument is `None`, it should always be passed as `None` during
          the lifetime of the model, otherwise, it cannot be captured as a single
          computation graph.
        - the returned decorator can be applied to several classes; dimensions inferred for one
          class are never reused for another.
    """

    def cls_decorator_helper(cls: _T) -> _T:
        # Validate `forward` first: everything below inspects its signature.
        assert hasattr(cls, "forward"), "decorated class should have a forward method."

        # Resolve into a local name instead of rebinding the enclosing `dynamic_arg_dims`.
        # Rebinding would make dims inferred for one class stick to this decorator object
        # and be wrongly applied to the next class it decorates.
        resolved_dims = dynamic_arg_dims or _infer_dynamic_arg_dims(cls)

        # Accuracy check
        assert len(resolved_dims) > 0, (
            "No dynamic dimensions found in the forward method of " f"{cls}. Please provide dynamic_arg_dims explicitly."
        )
        forward_params = inspect.signature(cls.forward).parameters
        for k in resolved_dims:
            assert k in forward_params, f"Argument {k} not found in the forward method of {cls}"
        return _magi_compile(cls, resolved_dims, enable_if, config_patch, model_tag=model_tag)

    if cls is not None:
        # use `magi_compile` as a decorator without arguments, cls is the class to be decorated
        assert isinstance(cls, type)
        return cls_decorator_helper(cls)
    return cls_decorator_helper
def offload(obj):
    """Return ``obj`` with every contained tensor moved to the CPU.

    Recurses through dicts, lists and tuples (including namedtuples). Tensors
    are replaced by ``tensor.cpu()``; any other object is returned unchanged.
    Dicts are rebuilt as plain ``dict``; lists and tuples keep their concrete
    type.
    """
    if isinstance(obj, torch.Tensor):
        return obj.cpu()
    if isinstance(obj, dict):
        return {k: offload(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        moved = [offload(item) for item in obj]
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            # namedtuple constructors take one positional argument per field,
            # not a single iterable, so the items must be unpacked.
            return type(obj)(*moved)
        return type(obj)(moved)
    return obj
def _magi_compile(
    cls: _T,
    dynamic_arg_dims: dict[str, int | list[int]],
    enable_if: Callable[[], bool] | None = None,
    config_patch: Callable[[CompileConfig], CompileConfig] | None = None,
    model_tag: str | None = None,
) -> _T:
    """
    Patch ``cls`` in place so that calling an instance compiles its forward method.

    The class is returned unchanged when ``MagiCompilerBase`` is already one of its direct
    bases. Otherwise ``MagiCompilerBase`` is appended to ``cls.__bases__`` and ``__init__`` /
    ``__call__`` are replaced by wrappers; when the global compile config enables
    ``offload_config.model_cpu_offload`` at decoration time, ``_apply`` is wrapped as well.

    Arguments:
        - cls: the class to patch.
        - dynamic_arg_dims: maps forward-argument names to the dimension(s) marked dynamic
          before the first compilation. Negative dimensions count from the end.
        - enable_if: optional zero-argument predicate evaluated in ``__init__``; compilation is
          skipped for the instance when it returns a falsy value.
        - config_patch: optional function applied to a private copy of the compile config in
          ``__init__``; its return value becomes the instance's ``compile_config``.
        - model_tag: stored as ``compile_config.model_tag``; defaults to the class name.

    Returns:
        ``cls`` itself (mutated).
    """
    if MagiCompilerBase in cls.__bases__:
        return cls

    # take care of method resolution order, make sure super().__init__ is called on the base class
    # other than MagiCompilerBase
    cls.__bases__ = cls.__bases__ + (MagiCompilerBase,)

    if get_compile_config().offload_config.model_cpu_offload:
        magi_logger.info(f"Enabling CPU offload for {cls}")
        _orig_apply = cls._apply

        def _cpu_apply(self, fn):
            # Only the first `_apply` call performs the offload; later calls pass straight through.
            if getattr(self, "_magi_offloaded_once", False):
                return _orig_apply(self, fn)

            # First, move all parameters/buffers to CPU
            def _force_cpu(t):
                return fn(t).cpu()

            _orig_apply(self, _force_cpu)

            # create shared memory tensors for all parameters/buffers on CPU
            if dist.is_initialized():
                local_rank = int(os.environ.get("LOCAL_RANK", 0))
                full_state_dict = self.state_dict()
                grouped_params = {}  # {dtype: [(name, tensor), ...]}
                for name, tensor in full_state_dict.items():
                    if tensor.device.type == 'cpu':
                        dt = tensor.dtype
                        if dt not in grouped_params:
                            grouped_params[dt] = []
                        grouped_params[dt].append((name, tensor))

                shared_state_dict = {}
                # Keep the mapped buffers referenced from the module.
                self._magi_giant_buffers = []
                dist.barrier()

                for dtype, param_list in grouped_params.items():
                    dtype_str = str(dtype).split('.')[-1]
                    shared_bin_path = (
                        f"{envs.MAGI_SHARED_BIN_PATH}/magi_model_shared_{dtype_str}_{self.__class__.__name__}.bin"
                    )
                    total_numel = sum(t.numel() for _, t in param_list)

                    # Local rank 0 packs all tensors of this dtype into one flat file.
                    if local_rank == 0:
                        flat_buffer = torch.zeros(total_numel, dtype=dtype)
                        offset = 0
                        for _, tensor in param_list:
                            numel = tensor.numel()
                            flat_buffer[offset : offset + numel].copy_(tensor.view(-1))
                            offset += numel

                        # dtypes numpy cannot represent are reinterpreted as same-width integers
                        # before being written out.
                        if dtype == torch.bfloat16:
                            flat_buffer.view(torch.int16).numpy().tofile(shared_bin_path)
                        elif dtype.itemsize == 1 and dtype.is_floating_point:
                            # fp8
                            flat_buffer.view(torch.uint8).numpy().tofile(shared_bin_path)
                        else:
                            flat_buffer.numpy().tofile(shared_bin_path)
                        del flat_buffer
                        gc.collect()

                    # Wait until the file is fully written before any rank maps it.
                    dist.barrier()

                    giant_shared_tensor = torch.from_file(
                        shared_bin_path, shared=True, size=total_numel, dtype=dtype, device="cpu"
                    )
                    self._magi_giant_buffers.append(giant_shared_tensor)
                    pin_memory_in_place(giant_shared_tensor)

                    # Re-create every parameter/buffer as a view into the mapped buffer.
                    offset = 0
                    for name, original_tensor in param_list:
                        numel = original_tensor.numel()
                        shared_param = giant_shared_tensor[offset : offset + numel].view(original_tensor.shape)
                        if original_tensor.requires_grad:
                            shared_param.requires_grad_(True)
                        shared_state_dict[name] = shared_param
                        offset += numel

                    # Every rank has mapped the file by now, so local rank 0 can remove it.
                    dist.barrier()
                    if local_rank == 0:
                        if os.path.exists(shared_bin_path):
                            os.remove(shared_bin_path)

                self.load_state_dict(shared_state_dict, assign=True)
            else:

                def _pinner(t):
                    return t.pin_memory()

                _orig_apply(self, _pinner)

            self._magi_offloaded_once = True
            return self

        cls._apply = _cpu_apply

    old_init = cls.__init__

    def __init__(self: _W, *args, **kwargs):
        old_init(self, *args, **kwargs)

        # Work on a private copy from the very start, so that neither an in-place
        # `config_patch` nor the per-instance fields set below can modify the config
        # object returned by `get_compile_config()`.
        compile_config = copy.deepcopy(get_compile_config())
        if config_patch is not None:
            patched_config = config_patch(compile_config)
            if patched_config is not compile_config:
                # The patch returned a different object; copy it so this instance does not
                # alias something the patch function may hand out again.
                patched_config = copy.deepcopy(patched_config)
            compile_config = patched_config
        self.compile_config = compile_config

        enable_compile = enable_if is None or enable_if()
        self.enable_compile = self.compile_config.compile_mode != CompileMode.NONE and enable_compile
        if not self.enable_compile:
            return

        compilation_counter.num_models_seen += 1
        self.compile_config.model_idx = compilation_counter.num_models_seen
        self.compile_config.model_tag = model_tag if model_tag is not None else self.__class__.__name__

        MagiCompilerBase.__init__(self, compile_config=self.compile_config)

    cls.__init__ = __init__

    old_call = cls.__call__

    def __call__(self: _W, *args, **kwargs):
        ### Step1: Run compiled module directly if disable compile or captured before ###
        # NOTE(review): `compiled_code` is read here even when compilation is disabled for this
        # instance (in which case MagiCompilerBase.__init__ was skipped) -- confirm that
        # MagiCompilerBase provides it at class level.
        if self.compile_config.offload_config.model_cpu_offload and self.compiled_code is None:
            args = offload(args)
            kwargs = offload(kwargs)

        if not self.enable_compile or torch.compiler.is_compiling():
            # Skip compiling the model if inside the compilation process.
            return old_call(self, *args, **kwargs)

        if self.compiled_code is not None:
            # Run the compiled function if compiled code is available.
            with self.dispatch_to_compiled_fwd(mode="jit"):
                return old_call(self, *args, **kwargs)

        if envs.MAGI_AOT_COMPILE:
            # Try load AOT artifacts from cache and run directly.
            self.aot_compiled_fn = self.try_load_aot_compile_artifacts()
            if self.aot_compiled_fn is not None:
                with self.dispatch_to_compiled_fwd(mode="aot"):
                    return old_call(self, *args, **kwargs)

        ### Step2: Mark dynamic shapes for the first compilation ###
        bound_args = inspect.signature(self.__class__.forward).bind(self, *args, **kwargs)
        bound_args.apply_defaults()

        for k, dims in dynamic_arg_dims.items():
            arg = bound_args.arguments.get(k)
            if arg is None:
                continue
            dims = [dims] if isinstance(dims, int) else dims
            assert isinstance(arg, torch.Tensor), f"Unsupported dynamic dim {dims} for argument {k} with type {type(arg)}."
            # Normalise negative dims to their non-negative equivalent.
            dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
            torch._dynamo.mark_dynamic(arg, dims)

        ### Step3: Start compiling the model ###
        magi_logger.info(f"Start compiling function {self.original_code_object}")
        CompileMonitor().start(
            self.compile_config.compile_mode == CompileMode.MAGI_COMPILE, self.compile_config.debug_dump_path()
        )

        # Dynamo reuse the compilation across instances, but we need to make sure the compiled code is not reused.
        torch._dynamo.eval_frame.remove_from_cache(self.original_code_object)

        with (
            _hijack_inline_call_to_collect_traced_files(self),
            patch.object(torch.compiler.config, "dynamic_sources", self.compile_config.dynamic_sources),
            patch.object(torch._dynamo.config, "enable_cpp_symbolic_shape_guards", False),
            # Let mark_dynamic take effect for tensors reached through module attribute chains
            # (the default True forces module-property tensors to static shapes and ignores mark_dynamic).
            patch.object(torch._dynamo.config, "force_nn_module_property_static_shapes", False),
            patch.dict(
                os.environ, {"TORCHINDUCTOR_CACHE_DIR": (self.compile_config.cache_dump_path() / "inductor_cache").as_posix()}
            ),
        ):
            if envs.MAGI_AOT_COMPILE:
                self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
                self.aot_compiled_fn.save_compiled_function(self.aot_compilation_path)
                with self.dispatch_to_compiled_fwd(mode="aot"):
                    output = old_call(self, *args, **kwargs)
            else:
                with patch.object(self, "forward", self.jit_compile):
                    output = old_call(self, *args, **kwargs)

        return output

    # Wrap the whole __call__ in both @torch.compiler.disable and _isolated_dynamo_config so that
    # magi compile keeps working independently even when nested inside an outer torch.compile.
    isolated_call = _isolated_dynamo_config()(__call__)
    cls.__call__ = torch.compiler.disable(isolated_call)

    return cls
# Collect every source file Dynamo touches while tracing, so that the model is re-compiled when any of them changes:
# 1. the file containing the top-level forward function;
# 2. the file of every function Dynamo inlines. Each time Dynamo sees a function call it inlines the callee through
#    InliningInstructionTranslator.inline_call_, so that method is hijacked to record the callee's file.
def _hijack_inline_call_to_collect_traced_files(owner: _W):
    """Build a patcher that records the file of every code object Dynamo inlines.

    The file of ``owner.original_code_object`` is added to
    ``owner.compile_config.traced_files`` immediately. The returned
    ``unittest.mock`` patcher, while active, replaces
    ``InliningInstructionTranslator.inline_call_`` with a wrapper that adds the
    ``co_filename`` of the translator's ``f_code`` to the same set and then
    delegates to the original method.
    """
    top_level_file = owner.original_code_object.co_filename
    owner.compile_config.traced_files.add(top_level_file)

    original_inline_call = InliningInstructionTranslator.inline_call_

    def recording_inline_call(translator):
        inlined_file = translator.f_code.co_filename
        owner.compile_config.traced_files.add(inlined_file)
        return original_inline_call(translator)

    return patch.object(InliningInstructionTranslator, "inline_call_", recording_inline_call)
def _infer_dynamic_arg_dims(cls: _T) -> dict[str, int | list[int]]:
    """Infer dynamic dimensions from the annotations of ``cls.forward``.

    Every parameter annotated as ``torch.Tensor`` or ``torch.Tensor | None``
    (``Optional[torch.Tensor]`` compares equal to the latter) is mapped to
    dimension ``0``. String annotations, e.g. those produced by
    ``from __future__ import annotations``, are resolved first when possible.

    Returns:
        A dict ``{parameter_name: 0}``; empty when no parameter qualifies.
    """
    try:
        # eval_str=True turns string annotations back into real objects so
        # they can be compared with torch.Tensor below.
        sig = inspect.signature(cls.forward, eval_str=True)
    except Exception:
        # Evaluating arbitrary annotation strings can fail in many ways
        # (e.g. names imported only for type checking). Fall back to the raw
        # signature, which simply leaves string annotations unmatched.
        sig = inspect.signature(cls.forward)

    tensor_annotations = (torch.Tensor, torch.Tensor | None)
    inferred_dynamic_arg_dims = {
        name: 0 for name, param in sig.parameters.items() if param.annotation in tensor_annotations
    }
    magi_logger.info(f"Inferred dynamic dimensions for forward method of {cls}: {list(inferred_dynamic_arg_dims.keys())}")
    return inferred_dynamic_arg_dims
def _get_num_outputs_from_return_annotation(fn: Callable) -> int:
"""
Get the number of outputs from the function's return type annotation.
Returns:
- 1 if the return type is a single Tensor
- N if the return type is tuple[Tensor, Tensor, ...] with N elements
- 1 if no annotation or unrecognized annotation (default to single output)
"""
sig = inspect.signature(fn)
return_annotation = sig.return_annotation
if return_annotation is inspect.Parameter.empty:
return 1
# Check if it's a tuple type (e.g., tuple[Tensor, Tensor])
origin = get_origin(return_annotation)
if origin is tuple:
args = get_args(return_annotation)
# Filter out ellipsis (for variable-length tuples like tuple[Tensor, ...])
if args and args[-1] is not ...:
return len(args)
return 1
return 1
def _generate_op_name(fn: Callable) -> str:
"""
Generate a unique operator name from function's name and source file.
The generated name follows the format: namespace::op_name
- namespace: derived from the source file path (module-like structure)
- op_name: the function name
Example:
Function `_my_custom_op` in file `/path/to/my_module.py`
-> "my_module::_my_custom_op"
"""
import re
from pathlib import Path
func_name = fn.__name__
# Get the source file path
try:
source_file = inspect.getfile(fn)
# Extract the file stem (without extension) as namespace
namespace = Path(source_file).stem
# Clean up namespace: replace invalid characters with underscores
namespace = re.sub(r"[^a-zA-Z0-9_]", "_", namespace)
except (TypeError, OSError):
# If we can't get the source file, use a default namespace
namespace = "magi_custom"
return f"{namespace}::{func_name}"
def _create_identity_meta_fn(fn: Callable) -> Callable:
    """
    Build the default "identity" meta function for ``fn``.

    The generated function assumes that:
    - the number of outputs is the one given by ``fn``'s return annotation
    - the i-th output has the metadata (shape, dtype, device) of the i-th tensor
      argument, taken in parameter order and ignoring a parameter named ``self``

    For a function with signature:
        def my_op(a: Tensor, b: Tensor, scale: float) -> tuple[Tensor, Tensor]:
    the generated meta function returns:
        (torch.empty_like(a), torch.empty_like(b))

    The generated function raises ValueError when it is called with fewer tensor
    arguments than there are outputs.
    """
    num_outputs = _get_num_outputs_from_return_annotation(fn)
    sig = inspect.signature(fn)
    # Parameters eligible as metadata templates, in declaration order.
    candidate_names = [name for name in sig.parameters if name != "self"]

    def identity_meta_fn(*args, **kwargs):
        # Map every parameter name to the value it receives in this call.
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()

        # Keep the leading `num_outputs` tensor arguments.
        bound_values = (bound.arguments.get(name) for name in candidate_names)
        tensor_args = [value for value in bound_values if isinstance(value, torch.Tensor)][:num_outputs]

        if len(tensor_args) < num_outputs:
            raise ValueError(
                f"identity_meta_fn requires at least {num_outputs} tensor inputs to match "
                f"{num_outputs} outputs, but only found {len(tensor_args)} tensor inputs. "
                f"Please provide a custom infer_output_meta_fn."
            )

        # One empty tensor per output; a single output is returned unwrapped.
        outputs = tuple(torch.empty_like(t) for t in tensor_args)
        return outputs[0] if num_outputs == 1 else outputs

    return identity_meta_fn
def _create_meta_fn_from_param_names(fn: Callable, param_names: list[str]) -> Callable:
"""
Create a meta function that returns torch.empty_like() for each specified parameter.
This is useful when output tensors have the same shape/dtype/device as specific input
parameters, but not necessarily in positional order.
Example:
param_names = ["weight", "bias"]
def my_op(grad: Tensor, weight: Tensor, bias: Tensor) -> tuple[Tensor, Tensor]:
...
Generated meta function returns:
(torch.empty_like(weight), torch.empty_like(bias))
"""
sig = inspect.signature(fn)
def meta_fn(*args, **kwargs):
# Bind arguments to get a mapping of param_name -> value
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()
# Collect tensors for each specified parameter name
tensor_outputs = []
for name in param_names:
if name not in bound.arguments:
raise ValueError(
f"Parameter '{name}' not found in function signature. "
f"Available parameters: {list(bound.arguments.keys())}"
)
arg = bound.arguments[name]
if not isinstance(arg, torch.Tensor):
raise ValueError(
f"Parameter '{name}' is not a Tensor (got {type(arg).__name__}). "
f"infer_output_meta_fn list should only contain tensor parameter names."
)
tensor_outputs.append(torch.empty_like(arg))
# Return single tensor or tuple based on number of outputs
if len(tensor_outputs) == 1:
return tensor_outputs[0]
return tuple(tensor_outputs)
return meta_fn
def magi_register_custom_op(
    name: str | None = None,
    mutates_args: tuple[str, ...] = (),
    infer_output_meta_fn: Callable | list[str] | None = None,
    setup_context_fn: Callable | None = None,
    backward_fn: Callable | None = None,
):
    """
    One decorator that registers a custom operator with PyTorch's library.

    It bundles three registration steps:
    - @torch.library.custom_op
    - @torch.library.register_fake
    - fn.register_autograd

    Arguments:
        name: Fully qualified operator name (e.g., "namespace::op_name").
            When None, the name is generated from the function name and its source file.
        mutates_args: Names of the arguments the operator mutates.
        infer_output_meta_fn: How output tensor metadata (shape, dtype, device) is
            inferred for tracing.
            - None (default): each output takes the metadata of the corresponding tensor
              input (1st output matches 1st tensor input, 2nd matches 2nd, etc.).
            - list[str]: names of the parameters whose metadata the outputs copy.
              E.g., ["weight", "bias"] means output[0] has the same shape as `weight`
              and output[1] has the same shape as `bias`.
            - Callable: custom function with the same signature as the op that returns
              torch.empty_like() tensors matching the expected output shapes.
        setup_context_fn: Function that saves tensors/values for backward.
            Signature: setup_context_fn(ctx, inputs, output)
            Only used when backward_fn is given.
        backward_fn: Function that computes gradients.
            Signature: backward_fn(ctx, *grad_outputs) -> tuple of gradients

    Returns:
        A decorator that returns the registered custom operator.

    Examples:
        1. Basic usage (forward only, auto-generated name and meta function):

        >>> @magi_register_custom_op()
        ... def my_relu(x: torch.Tensor) -> torch.Tensor:
        ...     return torch.maximum(x, torch.zeros_like(x))

        2. Multiple outputs with explicit output metadata via parameter names:

        >>> @magi_register_custom_op(
        ...     infer_output_meta_fn=["weight", "bias"],  # output shapes match weight and bias
        ... )
        ... def compute_gradients(
        ...     grad_output: torch.Tensor,
        ...     weight: torch.Tensor,
        ...     bias: torch.Tensor,
        ... ) -> tuple[torch.Tensor, torch.Tensor]:
        ...     grad_weight = grad_output.sum(dim=0).view_as(weight)
        ...     grad_bias = grad_output.sum(dim=0).view_as(bias)
        ...     return grad_weight, grad_bias

        3. Full custom op with autograd support:

        >>> def _square_meta(x: torch.Tensor) -> torch.Tensor:
        ...     return torch.empty_like(x)
        ...
        >>> def _square_setup_context(ctx, inputs, output):
        ...     (x,) = inputs
        ...     ctx.save_for_backward(x)
        ...
        >>> def _square_backward(ctx, grad_output):
        ...     (x,) = ctx.saved_tensors
        ...     return grad_output * 2 * x
        ...
        >>> @magi_register_custom_op(
        ...     name="my_ops::square",
        ...     infer_output_meta_fn=_square_meta,
        ...     setup_context_fn=_square_setup_context,
        ...     backward_fn=_square_backward,
        ... )
        ... def square(x: torch.Tensor) -> torch.Tensor:
        ...     return x * x
    """

    def _resolve_meta_fn(fn: Callable) -> Callable:
        # Pick the fake/meta implementation matching the kind of `infer_output_meta_fn`.
        if infer_output_meta_fn is None:
            return _create_identity_meta_fn(fn)
        if isinstance(infer_output_meta_fn, list):
            return _create_meta_fn_from_param_names(fn, infer_output_meta_fn)
        return infer_output_meta_fn

    def decorator(fn: Callable) -> Callable:
        # Fall back to a generated name when the caller did not supply one.
        op_name = _generate_op_name(fn) if name is None else name

        # Step 1: create the custom op itself.
        make_op = torch.library.custom_op(op_name, mutates_args=mutates_args)
        registered_op = make_op(fn)

        # Step 2: attach the output-metadata (fake) implementation.
        register_fake = torch.library.register_fake(op_name)
        register_fake(_resolve_meta_fn(fn))

        # Step 3: autograd support is optional and keyed on backward_fn.
        if backward_fn is not None:
            registered_op.register_autograd(backward_fn, setup_context=setup_context_fn)

        return registered_op

    return decorator