# Copyright (c) 2025 SandAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import copy import gc import inspect import os from contextlib import contextmanager from typing import Callable, TypeVar, get_args, get_origin, overload from unittest.mock import patch import magi_compiler.utils.envs as envs import torch from magi_compiler.cuda.cudart import pin_memory_in_place from magi_compiler.magi_compiler_base import MagiCompilerBase from magi_compiler.utils import compilation_counter, magi_logger from magi_compiler.utils.compile_time_monitor import CompileMonitor from torch import distributed as dist from torch import nn from torch._dynamo.symbolic_convert import InliningInstructionTranslator from .config import CompileConfig, CompileMode, get_compile_config # ============================================================================= # Workaround: TorchInductor autotune get_raw_stream # ============================================================================= # TorchInductor autotune code blocks may reference get_raw_stream() without # defining it, causing "name 'get_raw_stream' is not defined" at runtime. # Register it as a builtin so the exec'd autotune snippets can always find it. 
def _patch_get_raw_stream():
    """
    Expose ``torch._C._cuda_getCurrentRawStream`` as the builtin ``get_raw_stream``.

    Best-effort: if the private torch symbol cannot be imported, nothing is patched.
    An already existing ``builtins.get_raw_stream`` is never overwritten.
    """
    import builtins

    try:
        from torch._C import _cuda_getCurrentRawStream as _get_raw_stream
    except ImportError:
        return

    if not hasattr(builtins, "get_raw_stream"):
        builtins.get_raw_stream = _get_raw_stream


_patch_get_raw_stream()


# =============================================================================
# Dynamo Config Isolation
# =============================================================================
# Capture the default dynamo config at module load time (normally before any torch.compile).
# This ensures we have a "clean" baseline config that hasn't been modified by
# external torch.compile calls (e.g., with dynamic=True).
_DEFAULT_DYNAMO_CONFIG: dict = torch._dynamo.config.get_config_copy()


@contextmanager
def _isolated_dynamo_config():
    """
    Context manager that runs its body with the dynamo config reset to the
    snapshot taken at import time (``_DEFAULT_DYNAMO_CONFIG``).

    Because it is built with ``@contextmanager`` it can also be applied as a
    function decorator, which is how ``_magi_compile`` uses it.
    """
    with torch._dynamo.config.patch(**_DEFAULT_DYNAMO_CONFIG):
        yield


_T = TypeVar("_T", bound=type[nn.Module])
_W = TypeVar("_W", bound="MagiCompilerBase")


@overload
def magi_compile(cls: _T) -> _T: ...


@overload
def magi_compile(
    *,
    model_tag: str | None = None,
    dynamic_arg_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[], bool] | None = None,
    config_patch: Callable[[CompileConfig], CompileConfig] | None = None,
) -> Callable[[_T], _T]: ...


def magi_compile(
    cls: _T | None = None,
    *,
    model_tag: str | None = None,
    dynamic_arg_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[], bool] | None = None,
    config_patch: Callable[[CompileConfig], CompileConfig] | None = None,
) -> Callable[[_T], _T] | _T:
    """
    A decorator to add support for compiling the forward method of a class.

    Usage:

    1. use directly as a decorator without arguments:

    ```python
    @magi_compile
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    ```

    2. use as a decorator with arguments:

    ```python
    @magi_compile(dynamic_arg_dims={"x": 0, "y": 0})
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    ```

    Arguments:

    - model_tag: optional tag in cache path (e.g. "wan_ti2v"). If not set, class name is used.
      Path segment: model_{idx}_{model_tag}_rank_{rank}.
    - dynamic_arg_dims: a dictionary that maps argument names to the dynamic
      dimensions of the argument. The dynamic dimensions can be either a single
      integer or a list of integers. Negative dimensions count from the end.
    - enable_if: a zero-argument function returning whether to compile the model.
      It is called once per instance, in ``__init__``. This is useful if you want
      to compile the model only when certain conditions are met.
    - config_patch: optional function that receives the global compile config and
      returns the config to use for this class; the result is deep-copied per instance.

    Notes:
    - dynamic_arg_dims will be inferred from the type annotation of the forward
      method if not provided, if the argument is annotated as `torch.Tensor` or
      `Optional[torch.Tensor]`, the first dimension will be marked as dynamic.
    - if an argument is `None`, it should always be passed as `None` during
      the lifetime of the model, otherwise, it cannot be captured as a single
      computation graph.
    - a decorator returned by ``magi_compile(...)`` may be applied to several
      classes; when dynamic_arg_dims is not given it is inferred per class.
    """

    def cls_decorator_helper(cls: _T) -> _T:
        # Accuracy check
        assert hasattr(cls, "forward"), "decorated class should have a forward method."

        # Resolve into a local: writing back into the enclosing scope would make a
        # reused decorator apply the first class's inferred dims to later classes.
        resolved_dims = dynamic_arg_dims or _infer_dynamic_arg_dims(cls)
        assert len(resolved_dims) > 0, (
            "No dynamic dimensions found in the forward method of "
            f"{cls}. Please provide dynamic_arg_dims explicitly."
        )
        forward_params = inspect.signature(cls.forward).parameters
        for k in resolved_dims:
            assert k in forward_params, f"Argument {k} not found in the forward method of {cls}"
        return _magi_compile(cls, resolved_dims, enable_if, config_patch, model_tag=model_tag)

    if cls is not None:
        # use `magi_compile` as a decorator without arguments, cls is the class to be decorated
        assert isinstance(cls, type)
        return cls_decorator_helper(cls)

    return cls_decorator_helper


def offload(obj):
    """
    Recursively move the tensors contained in ``obj`` to CPU.

    Tensors are replaced with ``tensor.cpu()``. Dicts, lists and tuples (including
    namedtuples) are rebuilt with their elements offloaded. Any other object is
    returned unchanged.
    """
    if isinstance(obj, torch.Tensor):
        return obj.cpu()
    if isinstance(obj, dict):
        return {k: offload(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        items = [offload(item) for item in obj]
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            # namedtuple constructors take their fields positionally, not as one iterable.
            return type(obj)(*items)
        return type(obj)(items)
    return obj


def _magi_compile(
    cls: _T,
    dynamic_arg_dims: dict[str, int | list[int]],
    enable_if: Callable[[], bool] | None = None,
    config_patch: Callable[[CompileConfig], CompileConfig] | None = None,
    model_tag: str | None = None,
) -> _T:
    """
    Patch ``cls`` in place so that its forward pass is compiled by MagiCompiler.

    The class gets ``MagiCompilerBase`` appended to its bases. Its ``__init__`` and
    ``__call__`` are wrapped. When CPU offload is enabled in the global compile
    config, ``_apply`` is also overridden.

    Args:
        cls: the nn.Module subclass to patch.
        dynamic_arg_dims: forward-argument name -> dynamic dim(s), already resolved.
        enable_if: optional zero-argument predicate evaluated in each ``__init__``.
        config_patch: optional transform applied to the global compile config per instance.
        model_tag: cache-path tag; defaults to the class name.

    Returns:
        ``cls`` itself. A class that already has ``MagiCompilerBase`` as a direct
        base is returned untouched.
    """
    if MagiCompilerBase in cls.__bases__:
        return cls

    # take care of method resolution order, make sure super().__init__ is called on the base class
    # other than MagiCompilerBase
    cls.__bases__ = cls.__bases__ + (MagiCompilerBase,)

    if get_compile_config().offload_config.model_cpu_offload:
        magi_logger.info(f"Enabling CPU offload for {cls}")
        _orig_apply = cls._apply

        def _cpu_apply(self, fn, *args, **kwargs):
            # Extra arguments (e.g. ``recurse`` passed by nn.Module.to_empty) are
            # forwarded untouched to the original ``_apply``.
            if getattr(self, "_magi_offloaded_once", False):
                return _orig_apply(self, fn, *args, **kwargs)

            # First, move all parameters/buffers to CPU
            def _force_cpu(t):
                return fn(t).cpu()

            _orig_apply(self, _force_cpu, *args, **kwargs)

            if dist.is_initialized():
                # Put every CPU state-dict tensor into file-backed shared buffers (one per
                # dtype) that all local ranks map, instead of each rank keeping its own copy.
                local_rank = int(os.environ.get("LOCAL_RANK", 0))
                # NOTE(review): tied parameters appear under each of their names in
                # state_dict(); they get separate slots here and are assigned back
                # separately, which may untie them -- confirm this is acceptable.
                full_state_dict = self.state_dict()
                grouped_params = {}  # {dtype: [(name, tensor), ...]}
                for name, tensor in full_state_dict.items():
                    if tensor.device.type == "cpu":
                        grouped_params.setdefault(tensor.dtype, []).append((name, tensor))

                shared_state_dict = {}
                self._magi_giant_buffers = []
                dist.barrier()
                for dtype, param_list in grouped_params.items():
                    dtype_str = str(dtype).split(".")[-1]
                    shared_bin_path = (
                        f"{envs.MAGI_SHARED_BIN_PATH}/magi_model_shared_{dtype_str}_{self.__class__.__name__}.bin"
                    )
                    total_numel = sum(t.numel() for _, t in param_list)

                    if local_rank == 0:
                        # Local rank 0 packs all tensors of this dtype into one flat buffer and dumps it.
                        flat_buffer = torch.zeros(total_numel, dtype=dtype)
                        offset = 0
                        for _, tensor in param_list:
                            numel = tensor.numel()
                            flat_buffer[offset : offset + numel].copy_(tensor.view(-1))
                            offset += numel
                        # numpy has no bfloat16 / fp8 dtypes: write raw bits through a same-width integer view.
                        if dtype == torch.bfloat16:
                            flat_buffer.view(torch.int16).numpy().tofile(shared_bin_path)
                        elif dtype.itemsize == 1 and dtype.is_floating_point:  # fp8
                            flat_buffer.view(torch.uint8).numpy().tofile(shared_bin_path)
                        else:
                            flat_buffer.numpy().tofile(shared_bin_path)
                        del flat_buffer
                        gc.collect()

                    # Nobody maps the file before rank 0 has finished writing it.
                    dist.barrier()
                    giant_shared_tensor = torch.from_file(
                        shared_bin_path, shared=True, size=total_numel, dtype=dtype, device="cpu"
                    )
                    self._magi_giant_buffers.append(giant_shared_tensor)
                    pin_memory_in_place(giant_shared_tensor)

                    # Carve per-tensor views out of the shared buffer in the same order they were packed.
                    offset = 0
                    for name, original_tensor in param_list:
                        numel = original_tensor.numel()
                        shared_param = giant_shared_tensor[offset : offset + numel].view(original_tensor.shape)
                        # NOTE(review): state_dict() returns detached tensors by default, so this is
                        # likely always False; load_state_dict(assign=True) below presumably keeps each
                        # Parameter's requires_grad -- confirm.
                        if original_tensor.requires_grad:
                            shared_param.requires_grad_(True)
                        shared_state_dict[name] = shared_param
                        offset += numel

                    # All ranks have mapped this dtype's file; rank 0 can delete it.
                    dist.barrier()
                    if local_rank == 0:
                        if os.path.exists(shared_bin_path):
                            os.remove(shared_bin_path)

                self.load_state_dict(shared_state_dict, assign=True)
            else:

                def _pinner(t):
                    return t.pin_memory()

                _orig_apply(self, _pinner, *args, **kwargs)

            self._magi_offloaded_once = True
            return self

        cls._apply = _cpu_apply

    old_init = cls.__init__

    def __init__(self: _W, *args, **kwargs):
        old_init(self, *args, **kwargs)

        compile_config = get_compile_config()
        if config_patch is not None:
            compile_config = config_patch(compile_config)
        # deepcopy the compile config to avoid modifying the original compile config
        self.compile_config = copy.deepcopy(compile_config)

        enable_compile = enable_if is None or enable_if()
        self.enable_compile = self.compile_config.compile_mode != CompileMode.NONE and enable_compile
        if not self.enable_compile:
            # MagiCompilerBase.__init__ is intentionally skipped for disabled models.
            return

        compilation_counter.num_models_seen += 1
        self.compile_config.model_idx = compilation_counter.num_models_seen
        self.compile_config.model_tag = model_tag if model_tag is not None else self.__class__.__name__
        MagiCompilerBase.__init__(self, compile_config=self.compile_config)

    cls.__init__ = __init__

    old_call = cls.__call__

    def __call__(self: _W, *args, **kwargs):
        ### Step1: Run compiled module directly if disable compile or captured before ###
        # getattr: when compilation is disabled, MagiCompilerBase.__init__ was skipped
        # above, so ``compiled_code`` may not have been set on the instance.
        if (
            self.compile_config.offload_config.model_cpu_offload
            and getattr(self, "compiled_code", None) is None
        ):
            args = offload(args)
            kwargs = offload(kwargs)

        if not self.enable_compile or torch.compiler.is_compiling():
            # Skip compiling the model if inside the compilation process.
            return old_call(self, *args, **kwargs)

        if self.compiled_code is not None:
            # Run the compiled function if compiled code is available.
            with self.dispatch_to_compiled_fwd(mode="jit"):
                return old_call(self, *args, **kwargs)

        if envs.MAGI_AOT_COMPILE:
            # Try load AOT artifacts from cache and run directly.
            self.aot_compiled_fn = self.try_load_aot_compile_artifacts()
            if self.aot_compiled_fn is not None:
                with self.dispatch_to_compiled_fwd(mode="aot"):
                    return old_call(self, *args, **kwargs)

        ### Step2: Mark dynamic shapes for the first compilation ###
        bound_args = inspect.signature(self.__class__.forward).bind(self, *args, **kwargs)
        bound_args.apply_defaults()
        for k, dims in dynamic_arg_dims.items():
            arg = bound_args.arguments.get(k)
            if arg is None:
                continue
            dims = [dims] if isinstance(dims, int) else dims
            assert isinstance(arg, torch.Tensor), f"Unsupported dynamic dim {dims} for argument {k} with type {type(arg)}."
            # Normalize negative dims to their positive index.
            dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
            torch._dynamo.mark_dynamic(arg, dims)

        ### Step3: Start compiling the model ###
        magi_logger.info(f"Start compiling function {self.original_code_object}")
        CompileMonitor().start(
            self.compile_config.compile_mode == CompileMode.MAGI_COMPILE, self.compile_config.debug_dump_path()
        )

        # Dynamo reuse the compilation across instances, but we need to make sure the compiled code is not reused.
        torch._dynamo.eval_frame.remove_from_cache(self.original_code_object)

        with (
            _hijack_inline_call_to_collect_traced_files(self),
            patch.object(torch.compiler.config, "dynamic_sources", self.compile_config.dynamic_sources),
            patch.object(torch._dynamo.config, "enable_cpp_symbolic_shape_guards", False),
            # Let mark_dynamic take effect on tensors reached through module attribute chains
            # (the default True forces module-property tensors to static shapes, ignoring mark_dynamic).
            patch.object(torch._dynamo.config, "force_nn_module_property_static_shapes", False),
            patch.dict(
                os.environ,
                {"TORCHINDUCTOR_CACHE_DIR": (self.compile_config.cache_dump_path() / "inductor_cache").as_posix()},
            ),
        ):
            if envs.MAGI_AOT_COMPILE:
                self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
                self.aot_compiled_fn.save_compiled_function(self.aot_compilation_path)
                with self.dispatch_to_compiled_fwd(mode="aot"):
                    output = old_call(self, *args, **kwargs)
            else:
                with patch.object(self, "forward", self.jit_compile):
                    output = old_call(self, *args, **kwargs)
        return output

    # Wrap the whole __call__ with @torch.compiler.disable and _isolated_dynamo_config so that
    # magi compile keeps working independently, unaffected, even when nested inside an outer torch.compile.
    isolated_call = _isolated_dynamo_config()(__call__)
    cls.__call__ = torch.compiler.disable(isolated_call)

    return cls
hijack function to know all the functions called during Dynamo tracing, every time Dynamo sees a function call, it will inline # the function by calling InliningInstructionTranslator.inline_call_ def _hijack_inline_call_to_collect_traced_files(owner: _W): owner.compile_config.traced_files.add(owner.original_code_object.co_filename) inline_call = InliningInstructionTranslator.inline_call_ def patched_inline_call(self_): code = self_.f_code owner.compile_config.traced_files.add(code.co_filename) return inline_call(self_) return patch.object(InliningInstructionTranslator, "inline_call_", patched_inline_call) def _infer_dynamic_arg_dims(cls: _T) -> dict[str, int | list[int]]: sig = inspect.signature(cls.forward) inferred_dynamic_arg_dims = {} for k, v in sig.parameters.items(): if v.annotation in [torch.Tensor, torch.Tensor | None]: inferred_dynamic_arg_dims[k] = 0 magi_logger.info(f"Inferred dynamic dimensions for forward method of {cls}: {list(inferred_dynamic_arg_dims.keys())}") return inferred_dynamic_arg_dims def _get_num_outputs_from_return_annotation(fn: Callable) -> int: """ Get the number of outputs from the function's return type annotation. Returns: - 1 if the return type is a single Tensor - N if the return type is tuple[Tensor, Tensor, ...] with N elements - 1 if no annotation or unrecognized annotation (default to single output) """ sig = inspect.signature(fn) return_annotation = sig.return_annotation if return_annotation is inspect.Parameter.empty: return 1 # Check if it's a tuple type (e.g., tuple[Tensor, Tensor]) origin = get_origin(return_annotation) if origin is tuple: args = get_args(return_annotation) # Filter out ellipsis (for variable-length tuples like tuple[Tensor, ...]) if args and args[-1] is not ...: return len(args) return 1 return 1 def _generate_op_name(fn: Callable) -> str: """ Generate a unique operator name from function's name and source file. 
The generated name follows the format: namespace::op_name - namespace: derived from the source file path (module-like structure) - op_name: the function name Example: Function `_my_custom_op` in file `/path/to/my_module.py` -> "my_module::_my_custom_op" """ import re from pathlib import Path func_name = fn.__name__ # Get the source file path try: source_file = inspect.getfile(fn) # Extract the file stem (without extension) as namespace namespace = Path(source_file).stem # Clean up namespace: replace invalid characters with underscores namespace = re.sub(r"[^a-zA-Z0-9_]", "_", namespace) except (TypeError, OSError): # If we can't get the source file, use a default namespace namespace = "magi_custom" return f"{namespace}::{func_name}" def _create_identity_meta_fn(fn: Callable) -> Callable: """ Create a default identity meta function for the given function. This identity meta function assumes that: - The number of outputs is determined by the function's return type annotation - Each output's metadata (shape, dtype, device) matches the corresponding input tensor For example, if a function has signature: def my_op(a: Tensor, b: Tensor, scale: float) -> tuple[Tensor, Tensor]: The identity meta function will return: (torch.empty_like(a), torch.empty_like(b)) """ num_outputs = _get_num_outputs_from_return_annotation(fn) sig = inspect.signature(fn) # Get parameter names, excluding 'self' if present param_names = [name for name in sig.parameters.keys() if name != "self"] def identity_meta_fn(*args, **kwargs): # Bind arguments to get a mapping of param_name -> value bound = sig.bind(*args, **kwargs) bound.apply_defaults() # Collect the first `num_outputs` tensor arguments tensor_args = [] for name in param_names: arg = bound.arguments.get(name) if isinstance(arg, torch.Tensor): tensor_args.append(arg) if len(tensor_args) >= num_outputs: break if len(tensor_args) < num_outputs: raise ValueError( f"identity_meta_fn requires at least {num_outputs} tensor inputs to match " 
f"{num_outputs} outputs, but only found {len(tensor_args)} tensor inputs. " f"Please provide a custom infer_output_meta_fn." ) # Return outputs with same metadata as the first N inputs if num_outputs == 1: return torch.empty_like(tensor_args[0]) return tuple(torch.empty_like(t) for t in tensor_args[:num_outputs]) return identity_meta_fn def _create_meta_fn_from_param_names(fn: Callable, param_names: list[str]) -> Callable: """ Create a meta function that returns torch.empty_like() for each specified parameter. This is useful when output tensors have the same shape/dtype/device as specific input parameters, but not necessarily in positional order. Example: param_names = ["weight", "bias"] def my_op(grad: Tensor, weight: Tensor, bias: Tensor) -> tuple[Tensor, Tensor]: ... Generated meta function returns: (torch.empty_like(weight), torch.empty_like(bias)) """ sig = inspect.signature(fn) def meta_fn(*args, **kwargs): # Bind arguments to get a mapping of param_name -> value bound = sig.bind(*args, **kwargs) bound.apply_defaults() # Collect tensors for each specified parameter name tensor_outputs = [] for name in param_names: if name not in bound.arguments: raise ValueError( f"Parameter '{name}' not found in function signature. " f"Available parameters: {list(bound.arguments.keys())}" ) arg = bound.arguments[name] if not isinstance(arg, torch.Tensor): raise ValueError( f"Parameter '{name}' is not a Tensor (got {type(arg).__name__}). " f"infer_output_meta_fn list should only contain tensor parameter names." ) tensor_outputs.append(torch.empty_like(arg)) # Return single tensor or tuple based on number of outputs if len(tensor_outputs) == 1: return tensor_outputs[0] return tuple(tensor_outputs) return meta_fn def magi_register_custom_op( name: str | None = None, mutates_args: tuple[str, ...] 
= (), infer_output_meta_fn: Callable | list[str] | None = None, setup_context_fn: Callable | None = None, backward_fn: Callable | None = None, ): """ A unified decorator to register a custom operator with PyTorch's library. This decorator combines the functionality of: - @torch.library.custom_op - @torch.library.register_fake - fn.register_autograd Arguments: name: The fully qualified name of the operator (e.g., "namespace::op_name"). If None, auto-generated from the function name and source file. mutates_args: Tuple of argument names that are mutated by the operator. infer_output_meta_fn: Specifies output tensor metadata (shape, dtype, device) for tracing. - None (default): Assumes each output has the same metadata as the corresponding input tensor (1st output matches 1st tensor input, 2nd matches 2nd, etc.). - list[str]: Parameter names whose metadata to use for outputs. E.g., ["weight", "bias"] means output[0] has same shape as `weight`, output[1] has same shape as `bias`. - Callable: Custom function with same signature as the op, returns torch.empty_like() tensors matching the expected output shapes. setup_context_fn: Function to save tensors/values for backward. Signature: setup_context_fn(ctx, inputs, output) backward_fn: Function to compute gradients. Signature: backward_fn(ctx, *grad_outputs) -> tuple of gradients Returns: The registered custom operator function. Examples: 1. Basic usage (forward only, auto-generated name and meta function): >>> @magi_register_custom_op() ... def my_relu(x: torch.Tensor) -> torch.Tensor: ... return torch.maximum(x, torch.zeros_like(x)) 2. Multiple outputs with explicit output metadata via parameter names: >>> @magi_register_custom_op( ... infer_output_meta_fn=["weight", "bias"], # output shapes match weight and bias ... ) ... def compute_gradients( ... grad_output: torch.Tensor, ... weight: torch.Tensor, ... bias: torch.Tensor, ... ) -> tuple[torch.Tensor, torch.Tensor]: ... 
grad_weight = grad_output.sum(dim=0).view_as(weight) ... grad_bias = grad_output.sum(dim=0).view_as(bias) ... return grad_weight, grad_bias 3. Full custom op with autograd support: >>> def _square_meta(x: torch.Tensor) -> torch.Tensor: ... return torch.empty_like(x) ... >>> def _square_setup_context(ctx, inputs, output): ... (x,) = inputs ... ctx.save_for_backward(x) ... >>> def _square_backward(ctx, grad_output): ... (x,) = ctx.saved_tensors ... return grad_output * 2 * x ... >>> @magi_register_custom_op( ... name="my_ops::square", ... infer_output_meta_fn=_square_meta, ... setup_context_fn=_square_setup_context, ... backward_fn=_square_backward, ... ) ... def square(x: torch.Tensor) -> torch.Tensor: ... return x * x """ def decorator(fn: Callable) -> Callable: # Auto-generate name if not provided op_name = name if name is not None else _generate_op_name(fn) # Step 1: Register the custom op with torch.library.custom_op registered_op = torch.library.custom_op(op_name, mutates_args=mutates_args)(fn) # Step 2: Register the output meta inference function # Determine meta_fn based on the type of infer_output_meta_fn if infer_output_meta_fn is None: meta_fn = _create_identity_meta_fn(fn) elif isinstance(infer_output_meta_fn, list): meta_fn = _create_meta_fn_from_param_names(fn, infer_output_meta_fn) else: meta_fn = infer_output_meta_fn torch.library.register_fake(op_name)(meta_fn) # Step 3: Register autograd if backward_fn is provided if backward_fn is not None: registered_op.register_autograd(backward_fn, setup_context=setup_context_fn) return registered_op return decorator