Spaces:
Runtime error
Runtime error
| # Copyright (c) 2025 SandAI. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import copy | |
| import gc | |
| import inspect | |
| import os | |
| from contextlib import contextmanager | |
| from typing import Callable, TypeVar, get_args, get_origin, overload | |
| from unittest.mock import patch | |
| import magi_compiler.utils.envs as envs | |
| import torch | |
| from magi_compiler.cuda.cudart import pin_memory_in_place | |
| from magi_compiler.magi_compiler_base import MagiCompilerBase | |
| from magi_compiler.utils import compilation_counter, magi_logger | |
| from magi_compiler.utils.compile_time_monitor import CompileMonitor | |
| from torch import distributed as dist | |
| from torch import nn | |
| from torch._dynamo.symbolic_convert import InliningInstructionTranslator | |
| from .config import CompileConfig, CompileMode, get_compile_config | |
| # ============================================================================= | |
| # Workaround: TorchInductor autotune get_raw_stream | |
| # ============================================================================= | |
| # TorchInductor autotune code blocks may reference get_raw_stream() without | |
| # defining it, causing "name 'get_raw_stream' is not defined" at runtime. | |
| # Register it as a builtin so the exec'd autotune snippets can always find it. | |
| def _patch_get_raw_stream(): | |
| try: | |
| import builtins | |
| from torch._C import _cuda_getCurrentRawStream as _get_raw_stream | |
| except Exception: | |
| return | |
| if not hasattr(builtins, "get_raw_stream"): | |
| builtins.get_raw_stream = _get_raw_stream | |
| _patch_get_raw_stream() | |
| # ============================================================================= | |
| # Dynamo Config Isolation | |
| # ============================================================================= | |
| # Capture the default dynamo config at module load time (before any torch.compile). | |
| # This ensures we have a "clean" baseline config that hasn't been modified by | |
| # external torch.compile calls (e.g., with dynamic=True). | |
| _DEFAULT_DYNAMO_CONFIG: dict = torch._dynamo.config.get_config_copy() | |
| def _isolated_dynamo_config(): | |
| """ | |
| Context manager that provides an isolated dynamo config environment. | |
| """ | |
| with torch._dynamo.config.patch(**_DEFAULT_DYNAMO_CONFIG): | |
| yield | |
_T = TypeVar("_T", bound=type[nn.Module])
_W = TypeVar("_W", bound="MagiCompilerBase")


@overload
def magi_compile(cls: _T) -> _T:
    ...


@overload
def magi_compile(
    *,
    model_tag: str | None = None,
    dynamic_arg_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[], bool] | None = None,
    config_patch: Callable[[CompileConfig], CompileConfig] | None = None,
) -> Callable[[_T], _T]:
    ...


def magi_compile(
    cls: _T | None = None,
    *,
    model_tag: str | None = None,
    dynamic_arg_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[], bool] | None = None,
    config_patch: Callable[[CompileConfig], CompileConfig] | None = None,
) -> Callable[[_T], _T] | _T:
    """
    A decorator to add support for compiling the forward method of a class.

    Usage:
        1. use directly as a decorator without arguments:
        ```python
        @magi_compile
        class MyModel(nn.Module):
            def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
        ```
        2. use as a decorator with arguments:
        ```python
        @magi_compile(dynamic_arg_dims={"x": 0, "y": 0})
        class MyModel(nn.Module):
            def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
        ```

    Arguments:
        - cls: the class to decorate when used without arguments; must be a class.
        - model_tag: optional tag in cache path (e.g. "wan_ti2v"). If not set, class name is used.
          Path segment: model_{idx}_{model_tag}_rank_{rank}.
        - dynamic_arg_dims: a dictionary that maps argument names to the dynamic
          dimensions of the argument. The dynamic dimensions can be either a single
          integer or a list of integers.
        - enable_if: a zero-argument function that returns a boolean value indicating
          whether to compile the model or not. This is useful if you want to compile
          the model only when certain conditions are met.
        - config_patch: a function that receives the compile config and returns the
          config to use for the decorated class. It is forwarded unchanged to
          `_magi_compile`.

    Returns:
        The decorated class when called with `cls`, otherwise a class decorator.

    Raises:
        AssertionError: if the class has no `forward`, if no dynamic dimensions are
        given or inferred, or if a key of `dynamic_arg_dims` is not a parameter of
        `forward`.

    Notes:
        - dynamic_arg_dims will be inferred from the type annotation of the forward method if not provided,
          if the argument is annotated as `torch.Tensor` or `Optional[torch.Tensor]`,
          the first dimension will be marked as dynamic.
        - if an argument is `None`, it should always be passed as `None` during
          the lifetime of the model, otherwise, it cannot be captured as a single
          computation graph.
    """

    def cls_decorator_helper(cls: _T) -> _T:
        # Check for `forward` first: both the inference helper and the
        # signature inspection below read `cls.forward`.
        assert hasattr(cls, "forward"), "decorated class should have a forward method."

        # Resolve into a local instead of rebinding the enclosing argument, so a
        # decorator object reused on several classes never carries the dims
        # inferred for one class over to the next.
        resolved_dims = dynamic_arg_dims or _infer_dynamic_arg_dims(cls)
        assert len(resolved_dims) > 0, (
            "No dynamic dimensions found in the forward method of " f"{cls}. Please provide dynamic_arg_dims explicitly."
        )

        forward_params = inspect.signature(cls.forward).parameters
        for k in resolved_dims:
            assert k in forward_params, f"Argument {k} not found in the forward method of {cls}"

        return _magi_compile(cls, resolved_dims, enable_if, config_patch, model_tag=model_tag)

    if cls is not None:
        # use `magi_compile` as a decorator without arguments, cls is the class to be decorated
        assert isinstance(cls, type)
        return cls_decorator_helper(cls)
    return cls_decorator_helper
def offload(obj):
    """Return a copy of ``obj`` in which every tensor has been moved to CPU.

    Tensors are moved with ``.cpu()``. Dicts, lists and tuples are rebuilt
    recursively: dicts come back as plain ``dict``; lists and tuples keep
    their concrete type. Anything else is returned unchanged.

    Namedtuples are rebuilt with positional fields, because their constructor
    does not accept a single iterable the way ``tuple``/``list`` do.
    """
    if isinstance(obj, torch.Tensor):
        return obj.cpu()
    if isinstance(obj, dict):
        return {k: offload(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        moved = [offload(item) for item in obj]
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            # namedtuple: one positional argument per field.
            return type(obj)(*moved)
        return type(obj)(moved)
    return obj
def _magi_compile(
    cls: _T,
    dynamic_arg_dims: dict[str, int | list[int]],
    enable_if: Callable[[], bool] | None = None,
    config_patch: Callable[[CompileConfig], CompileConfig] | None = None,
    model_tag: str | None = None,
) -> _T:
    """
    Patch ``cls`` in place so that calling an instance compiles its forward method.

    The class is modified as follows:
      - ``MagiCompilerBase`` is appended to ``cls.__bases__``.
      - If the global compile config enables ``offload_config.model_cpu_offload``,
        ``cls._apply`` is wrapped so that the first ``_apply`` call moves
        parameters/buffers to CPU and pins them (through one shared file-backed
        buffer per dtype when ``torch.distributed`` is initialized).
      - ``cls.__init__`` is wrapped to attach a per-instance ``compile_config``
        and, when compilation is enabled, to initialize ``MagiCompilerBase``.
      - ``cls.__call__`` is wrapped to mark dynamic dims, compile on first use
        and dispatch to the compiled forward afterwards.

    Arguments:
        cls: the class to patch.
        dynamic_arg_dims: maps ``forward`` argument names to the dimension(s)
            to mark dynamic; negative dims are counted from the end.
        enable_if: optional zero-argument predicate; compilation is skipped for
            an instance when it returns a falsy value.
        config_patch: optional function that receives the compile config and
            returns the config to use for the instance.
        model_tag: stored as ``compile_config.model_tag``; defaults to the
            instance's class name.

    Returns:
        ``cls`` itself (returned untouched if ``MagiCompilerBase`` is already a
        direct base).
    """
    if MagiCompilerBase in cls.__bases__:
        return cls

    # take care of method resolution order, make sure super().__init__ is called on the base class
    # other than MagiCompilerBase
    cls.__bases__ = cls.__bases__ + (MagiCompilerBase,)

    if get_compile_config().offload_config.model_cpu_offload:
        magi_logger.info(f"Enabling CPU offload for {cls}")
        _orig_apply = cls._apply

        def _cpu_apply(self, fn):
            # Only the first `_apply` is intercepted; later calls go straight through.
            if getattr(self, "_magi_offloaded_once", False):
                return _orig_apply(self, fn)

            # First, move all parameters/buffers to CPU
            def _force_cpu(t):
                return fn(t).cpu()

            _orig_apply(self, _force_cpu)

            # create shared memory tensors for all parameters/buffers on CPU
            if dist.is_initialized():
                local_rank = int(os.environ.get("LOCAL_RANK", 0))
                full_state_dict = self.state_dict()

                grouped_params = {}  # {dtype: [(name, tensor), ...]}
                for name, tensor in full_state_dict.items():
                    if tensor.device.type == 'cpu':
                        grouped_params.setdefault(tensor.dtype, []).append((name, tensor))

                shared_state_dict = {}
                # Keep the per-dtype backing tensors referenced from the module.
                self._magi_giant_buffers = []
                dist.barrier()

                for dtype, param_list in grouped_params.items():
                    dtype_str = str(dtype).split('.')[-1]
                    shared_bin_path = (
                        f"{envs.MAGI_SHARED_BIN_PATH}/magi_model_shared_{dtype_str}_{self.__class__.__name__}.bin"
                    )
                    total_numel = sum(t.numel() for _, t in param_list)

                    if local_rank == 0:
                        # Local rank 0 flattens every tensor of this dtype into one
                        # buffer and dumps it to the shared file.
                        flat_buffer = torch.zeros(total_numel, dtype=dtype)
                        offset = 0
                        for _, tensor in param_list:
                            numel = tensor.numel()
                            flat_buffer[offset : offset + numel].copy_(tensor.view(-1))
                            offset += numel

                        # Dtypes without a numpy counterpart are reinterpreted as a
                        # same-width integer type before dumping the raw bytes.
                        if dtype == torch.bfloat16:
                            flat_buffer.view(torch.int16).numpy().tofile(shared_bin_path)
                        elif dtype.itemsize == 1 and dtype.is_floating_point:
                            # fp8
                            flat_buffer.view(torch.uint8).numpy().tofile(shared_bin_path)
                        else:
                            flat_buffer.numpy().tofile(shared_bin_path)
                        del flat_buffer
                        gc.collect()

                    # The file must be fully written before any rank maps it.
                    dist.barrier()
                    giant_shared_tensor = torch.from_file(
                        shared_bin_path, shared=True, size=total_numel, dtype=dtype, device="cpu"
                    )
                    self._magi_giant_buffers.append(giant_shared_tensor)
                    pin_memory_in_place(giant_shared_tensor)

                    # Carve per-parameter views out of the shared buffer, in the
                    # same order in which they were packed.
                    offset = 0
                    for name, original_tensor in param_list:
                        numel = original_tensor.numel()
                        shared_param = giant_shared_tensor[offset : offset + numel].view(original_tensor.shape)
                        if original_tensor.requires_grad:
                            shared_param.requires_grad_(True)
                        shared_state_dict[name] = shared_param
                        offset += numel

                    # Every rank has mapped the file by now, so it can be removed.
                    # NOTE(review): the nesting of this barrier/cleanup inside the
                    # per-dtype loop is inferred (the path is per dtype) — confirm.
                    dist.barrier()
                    if local_rank == 0:
                        if os.path.exists(shared_bin_path):
                            os.remove(shared_bin_path)

                self.load_state_dict(shared_state_dict, assign=True)
            else:

                def _pinner(t):
                    return t.pin_memory()

                _orig_apply(self, _pinner)

            self._magi_offloaded_once = True
            return self

        cls._apply = _cpu_apply

    old_init = cls.__init__

    def __init__(self: _W, *args, **kwargs):
        old_init(self, *args, **kwargs)

        # Copy the shared config *before* handing it to config_patch, so that a
        # patch which mutates its argument in place cannot modify the original
        # compile config.
        compile_config = copy.deepcopy(get_compile_config())
        if config_patch is not None:
            patched = config_patch(compile_config)
            if patched is not compile_config:
                # The patch returned a different object, which may be shared
                # elsewhere; copy it since it is mutated per instance below.
                patched = copy.deepcopy(patched)
            compile_config = patched
        self.compile_config = compile_config

        enable_compile = enable_if is None or enable_if()
        self.enable_compile = self.compile_config.compile_mode != CompileMode.NONE and enable_compile
        if not self.enable_compile:
            return

        compilation_counter.num_models_seen += 1
        self.compile_config.model_idx = compilation_counter.num_models_seen
        self.compile_config.model_tag = model_tag if model_tag is not None else self.__class__.__name__
        MagiCompilerBase.__init__(self, compile_config=self.compile_config)

    cls.__init__ = __init__

    old_call = cls.__call__

    def __call__(self: _W, *args, **kwargs):
        ### Step1: Run compiled module directly if disable compile or captured before ###
        # NOTE(review): `compiled_code` is presumably set by MagiCompilerBase.__init__,
        # which is skipped when compilation is disabled; getattr keeps this check from
        # raising AttributeError in that case — confirm against MagiCompilerBase.
        if self.compile_config.offload_config.model_cpu_offload and getattr(self, "compiled_code", None) is None:
            args = offload(args)
            kwargs = offload(kwargs)

        if not self.enable_compile or torch.compiler.is_compiling():
            # Skip compiling the model if inside the compilation process.
            return old_call(self, *args, **kwargs)

        if self.compiled_code is not None:
            # Run the compiled function if compiled code is available.
            with self.dispatch_to_compiled_fwd(mode="jit"):
                return old_call(self, *args, **kwargs)

        if envs.MAGI_AOT_COMPILE:
            # Try load AOT artifacts from cache and run directly.
            self.aot_compiled_fn = self.try_load_aot_compile_artifacts()
            if self.aot_compiled_fn is not None:
                with self.dispatch_to_compiled_fwd(mode="aot"):
                    return old_call(self, *args, **kwargs)

        ### Step2: Mark dynamic shapes for the first compilation ###
        bound_args = inspect.signature(self.__class__.forward).bind(self, *args, **kwargs)
        bound_args.apply_defaults()

        for k, dims in dynamic_arg_dims.items():
            arg = bound_args.arguments.get(k)
            if arg is None:
                continue
            dims = [dims] if isinstance(dims, int) else dims
            assert isinstance(arg, torch.Tensor), f"Unsupported dynamic dim {dims} for argument {k} with type {type(arg)}."
            # Normalize negative dims against the actual rank of the argument.
            dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
            torch._dynamo.mark_dynamic(arg, dims)

        ### Step3: Start compiling the model ###
        magi_logger.info(f"Start compiling function {self.original_code_object}")
        CompileMonitor().start(
            self.compile_config.compile_mode == CompileMode.MAGI_COMPILE, self.compile_config.debug_dump_path()
        )

        # Dynamo reuse the compilation across instances, but we need to make sure the compiled code is not reused.
        torch._dynamo.eval_frame.remove_from_cache(self.original_code_object)

        with (
            _hijack_inline_call_to_collect_traced_files(self),
            patch.object(torch.compiler.config, "dynamic_sources", self.compile_config.dynamic_sources),
            patch.object(torch._dynamo.config, "enable_cpp_symbolic_shape_guards", False),
            # Let mark_dynamic take effect on tensors reached through module attribute chains
            # (the default, True, forces module-property tensors to static shapes and ignores mark_dynamic).
            patch.object(torch._dynamo.config, "force_nn_module_property_static_shapes", False),
            patch.dict(
                os.environ, {"TORCHINDUCTOR_CACHE_DIR": (self.compile_config.cache_dump_path() / "inductor_cache").as_posix()}
            ),
        ):
            if envs.MAGI_AOT_COMPILE:
                self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
                self.aot_compiled_fn.save_compiled_function(self.aot_compilation_path)
                with self.dispatch_to_compiled_fwd(mode="aot"):
                    output = old_call(self, *args, **kwargs)
            else:
                with patch.object(self, "forward", self.jit_compile):
                    output = old_call(self, *args, **kwargs)

        return output

    # Wrap the whole __call__ in both @torch.compiler.disable and _isolated_dynamo_config,
    # so that magi compile keeps working independently even when nested inside an outer torch.compile.
    isolated_call = _isolated_dynamo_config()(__call__)
    cls.__call__ = torch.compiler.disable(isolated_call)

    return cls
| # Collect all relevant files traced by Dynamo, re-compile the model when any of these files change. | |
| # 1. the file containing the top-level forward function | |
| # 2. hijack function to know all the functions called during Dynamo tracing, every time Dynamo sees a function call, it will inline | |
| # the function by calling InliningInstructionTranslator.inline_call_ | |
def _hijack_inline_call_to_collect_traced_files(owner: _W):
    """Return a patcher that records the source file of every call Dynamo inlines.

    The file containing ``owner.original_code_object`` is added to
    ``owner.compile_config.traced_files`` immediately. The returned
    ``patch.object`` replaces ``InliningInstructionTranslator.inline_call_``
    with a wrapper that adds the translator's ``f_code.co_filename`` to the
    same set and then delegates to the original implementation. The patch is
    only active while the returned object is entered.
    """
    owner.compile_config.traced_files.add(owner.original_code_object.co_filename)

    original_inline_call = InliningInstructionTranslator.inline_call_

    def recording_inline_call(translator):
        # Record where the inlined code lives, then fall through unchanged.
        owner.compile_config.traced_files.add(translator.f_code.co_filename)
        return original_inline_call(translator)

    return patch.object(InliningInstructionTranslator, "inline_call_", recording_inline_call)
def _infer_dynamic_arg_dims(cls: _T) -> dict[str, int | list[int]]:
    """Infer dynamic dims from the annotations of ``cls.forward``.

    Every parameter annotated as ``torch.Tensor`` or ``torch.Tensor | None``
    (``Optional[torch.Tensor]`` compares equal to the latter) is mapped to
    dimension ``0``.

    String annotations, as produced by ``from __future__ import annotations``,
    are resolved with ``typing.get_type_hints`` first. If resolution fails,
    those parameters are left unmatched, which is what happens to any other
    unrecognized annotation.

    Returns:
        A dict mapping each matching parameter name to ``0``; it may be empty.
    """
    from typing import get_type_hints

    forward = cls.forward
    parameters = inspect.signature(forward).parameters

    resolved_hints: dict = {}
    if any(isinstance(param.annotation, str) for param in parameters.values()):
        try:
            resolved_hints = get_type_hints(forward)
        except Exception:
            # Evaluating arbitrary annotation strings can fail in many ways
            # (unknown names, syntax errors, ...); inference is best-effort,
            # so fall back to the raw annotations.
            resolved_hints = {}

    tensor_annotations = (torch.Tensor, torch.Tensor | None)
    inferred_dynamic_arg_dims = {}
    for name, param in parameters.items():
        annotation = param.annotation
        if isinstance(annotation, str):
            # Only string annotations are replaced, so non-string annotations
            # are matched against their original objects.
            annotation = resolved_hints.get(name, annotation)
        if annotation in tensor_annotations:
            inferred_dynamic_arg_dims[name] = 0

    magi_logger.info(f"Inferred dynamic dimensions for forward method of {cls}: {list(inferred_dynamic_arg_dims.keys())}")
    return inferred_dynamic_arg_dims
| def _get_num_outputs_from_return_annotation(fn: Callable) -> int: | |
| """ | |
| Get the number of outputs from the function's return type annotation. | |
| Returns: | |
| - 1 if the return type is a single Tensor | |
| - N if the return type is tuple[Tensor, Tensor, ...] with N elements | |
| - 1 if no annotation or unrecognized annotation (default to single output) | |
| """ | |
| sig = inspect.signature(fn) | |
| return_annotation = sig.return_annotation | |
| if return_annotation is inspect.Parameter.empty: | |
| return 1 | |
| # Check if it's a tuple type (e.g., tuple[Tensor, Tensor]) | |
| origin = get_origin(return_annotation) | |
| if origin is tuple: | |
| args = get_args(return_annotation) | |
| # Filter out ellipsis (for variable-length tuples like tuple[Tensor, ...]) | |
| if args and args[-1] is not ...: | |
| return len(args) | |
| return 1 | |
| return 1 | |
| def _generate_op_name(fn: Callable) -> str: | |
| """ | |
| Generate a unique operator name from function's name and source file. | |
| The generated name follows the format: namespace::op_name | |
| - namespace: derived from the source file path (module-like structure) | |
| - op_name: the function name | |
| Example: | |
| Function `_my_custom_op` in file `/path/to/my_module.py` | |
| -> "my_module::_my_custom_op" | |
| """ | |
| import re | |
| from pathlib import Path | |
| func_name = fn.__name__ | |
| # Get the source file path | |
| try: | |
| source_file = inspect.getfile(fn) | |
| # Extract the file stem (without extension) as namespace | |
| namespace = Path(source_file).stem | |
| # Clean up namespace: replace invalid characters with underscores | |
| namespace = re.sub(r"[^a-zA-Z0-9_]", "_", namespace) | |
| except (TypeError, OSError): | |
| # If we can't get the source file, use a default namespace | |
| namespace = "magi_custom" | |
| return f"{namespace}::{func_name}" | |
def _create_identity_meta_fn(fn: Callable) -> Callable:
    """
    Build the default meta function for ``fn``: outputs mirror the leading tensor inputs.

    The number of outputs N comes from ``fn``'s return annotation. The
    generated function binds its arguments to ``fn``'s signature, takes the
    first N tensor arguments in parameter order (a parameter named ``self`` is
    skipped) and returns ``torch.empty_like`` of each.

    For example, for
        def my_op(a: Tensor, b: Tensor, scale: float) -> tuple[Tensor, Tensor]:
    the generated function returns
        (torch.empty_like(a), torch.empty_like(b))

    A single output is returned as a bare tensor rather than a 1-tuple. The
    generated function raises ``ValueError`` when fewer than N tensor
    arguments are available.
    """
    num_outputs = _get_num_outputs_from_return_annotation(fn)
    sig = inspect.signature(fn)
    # Candidate parameters in declaration order, without 'self'.
    candidate_names = tuple(name for name in sig.parameters if name != "self")

    def identity_meta_fn(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()

        # Keep only tensor-valued arguments, then take the leading `num_outputs`.
        bound_values = (bound.arguments.get(name) for name in candidate_names)
        tensor_args = [value for value in bound_values if isinstance(value, torch.Tensor)][:num_outputs]

        if len(tensor_args) < num_outputs:
            raise ValueError(
                f"identity_meta_fn requires at least {num_outputs} tensor inputs to match "
                f"{num_outputs} outputs, but only found {len(tensor_args)} tensor inputs. "
                f"Please provide a custom infer_output_meta_fn."
            )

        if num_outputs == 1:
            return torch.empty_like(tensor_args[0])
        return tuple(torch.empty_like(tensor) for tensor in tensor_args)

    return identity_meta_fn
| def _create_meta_fn_from_param_names(fn: Callable, param_names: list[str]) -> Callable: | |
| """ | |
| Create a meta function that returns torch.empty_like() for each specified parameter. | |
| This is useful when output tensors have the same shape/dtype/device as specific input | |
| parameters, but not necessarily in positional order. | |
| Example: | |
| param_names = ["weight", "bias"] | |
| def my_op(grad: Tensor, weight: Tensor, bias: Tensor) -> tuple[Tensor, Tensor]: | |
| ... | |
| Generated meta function returns: | |
| (torch.empty_like(weight), torch.empty_like(bias)) | |
| """ | |
| sig = inspect.signature(fn) | |
| def meta_fn(*args, **kwargs): | |
| # Bind arguments to get a mapping of param_name -> value | |
| bound = sig.bind(*args, **kwargs) | |
| bound.apply_defaults() | |
| # Collect tensors for each specified parameter name | |
| tensor_outputs = [] | |
| for name in param_names: | |
| if name not in bound.arguments: | |
| raise ValueError( | |
| f"Parameter '{name}' not found in function signature. " | |
| f"Available parameters: {list(bound.arguments.keys())}" | |
| ) | |
| arg = bound.arguments[name] | |
| if not isinstance(arg, torch.Tensor): | |
| raise ValueError( | |
| f"Parameter '{name}' is not a Tensor (got {type(arg).__name__}). " | |
| f"infer_output_meta_fn list should only contain tensor parameter names." | |
| ) | |
| tensor_outputs.append(torch.empty_like(arg)) | |
| # Return single tensor or tuple based on number of outputs | |
| if len(tensor_outputs) == 1: | |
| return tensor_outputs[0] | |
| return tuple(tensor_outputs) | |
| return meta_fn | |
def magi_register_custom_op(
    name: str | None = None,
    mutates_args: tuple[str, ...] = (),
    infer_output_meta_fn: Callable | list[str] | None = None,
    setup_context_fn: Callable | None = None,
    backward_fn: Callable | None = None,
):
    """
    One decorator that registers a function as a PyTorch custom operator.

    It performs, in order:
        1. ``torch.library.custom_op`` registration of the function,
        2. ``torch.library.register_fake`` registration of an output-meta function,
        3. ``register_autograd`` on the resulting op, only when ``backward_fn`` is given.

    Arguments:
        name: Fully qualified operator name (e.g., "namespace::op_name").
            When None, the name is generated from the function name and its source file.
        mutates_args: Names of the arguments the operator mutates.
        infer_output_meta_fn: How output metadata (shape, dtype, device) is derived for tracing.
            - None (default): each output mirrors the corresponding tensor input
              (1st output <-> 1st tensor input, 2nd <-> 2nd, ...).
            - list[str]: parameter names whose metadata the outputs copy.
              E.g., ["weight", "bias"] means output[0] looks like `weight` and
              output[1] looks like `bias`.
            - anything else: used as the meta function as-is; expected to be a callable
              with the op's signature that returns torch.empty_like()-style tensors.
        setup_context_fn: Saves tensors/values for backward.
            Signature: setup_context_fn(ctx, inputs, output)
        backward_fn: Computes gradients.
            Signature: backward_fn(ctx, *grad_outputs) -> tuple of gradients

    Returns:
        A decorator; applying it returns the registered custom operator.

    Examples:
        1. Forward only, with generated name and default meta function:

        >>> @magi_register_custom_op()
        ... def my_relu(x: torch.Tensor) -> torch.Tensor:
        ...     return torch.maximum(x, torch.zeros_like(x))

        2. Several outputs whose metadata follows named parameters:

        >>> @magi_register_custom_op(
        ...     infer_output_meta_fn=["weight", "bias"],  # outputs look like weight and bias
        ... )
        ... def compute_gradients(
        ...     grad_output: torch.Tensor,
        ...     weight: torch.Tensor,
        ...     bias: torch.Tensor,
        ... ) -> tuple[torch.Tensor, torch.Tensor]:
        ...     grad_weight = grad_output.sum(dim=0).view_as(weight)
        ...     grad_bias = grad_output.sum(dim=0).view_as(bias)
        ...     return grad_weight, grad_bias

        3. Explicit name, custom meta function and autograd support:

        >>> def _square_meta(x: torch.Tensor) -> torch.Tensor:
        ...     return torch.empty_like(x)
        ...
        >>> def _square_setup_context(ctx, inputs, output):
        ...     (x,) = inputs
        ...     ctx.save_for_backward(x)
        ...
        >>> def _square_backward(ctx, grad_output):
        ...     (x,) = ctx.saved_tensors
        ...     return grad_output * 2 * x
        ...
        >>> @magi_register_custom_op(
        ...     name="my_ops::square",
        ...     infer_output_meta_fn=_square_meta,
        ...     setup_context_fn=_square_setup_context,
        ...     backward_fn=_square_backward,
        ... )
        ... def square(x: torch.Tensor) -> torch.Tensor:
        ...     return x * x
    """

    def decorator(fn: Callable) -> Callable:
        # Fall back to a generated name when none was supplied.
        op_name = _generate_op_name(fn) if name is None else name

        # 1) Register the operator itself.
        registered_op = torch.library.custom_op(op_name, mutates_args=mutates_args)(fn)

        # 2) Pick the meta function according to what the caller supplied.
        if isinstance(infer_output_meta_fn, list):
            meta_fn = _create_meta_fn_from_param_names(fn, infer_output_meta_fn)
        elif infer_output_meta_fn is None:
            meta_fn = _create_identity_meta_fn(fn)
        else:
            meta_fn = infer_output_meta_fn
        torch.library.register_fake(op_name)(meta_fn)

        # 3) Autograd is opt-in: nothing to do without a backward function.
        if backward_fn is None:
            return registered_op
        registered_op.register_autograd(backward_fn, setup_context=setup_context_fn)
        return registered_op

    return decorator