download
raw
10 kB
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
from typing import Callable, TypeVar
from torch import nn
from ._api import (
_check_dynamic_arg_dims,
_infer_dynamic_arg_dims,
_magi_compile_class,
_magi_compile_function,
_magi_compile_instance,
)
from ._magi_register_custom_op import _magi_register_custom_op_impl
from .config import CompileConfig
# Type variables for the three kinds of target `magi_compile` accepts, so its
# annotated return type mirrors the kind of object that was passed in.
# _T: an nn.Module subclass (the class object itself, not an instance).
_T = TypeVar("_T", bound=type[nn.Module])
# _F: any callable (standalone function or method).
_F = TypeVar("_F", bound=Callable)
# _M: an nn.Module instance.
_M = TypeVar("_M", bound=nn.Module)
def magi_compile(
    obj: _T | _M | _F | None = None,
    *,
    model_tag: str | None = None,
    dynamic_arg_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[], bool] | None = None,
    config_patch: Callable[[CompileConfig], CompileConfig] | None = None,
) -> _T | _M | _F | Callable[[_T | _M | _F], _T | _M | _F]:
    """
    Compile target objects (nn.Module classes, modules, functions, or methods).

    Supported target types
    ----------------------
    1) Class (must be an `nn.Module` subclass):
        - Affects all instances of the annotated class.
        - Compilation dispatch enters via `__call__`, while compiled execution replaces `forward` logic.
        - Example:
            @magi_compile
            class MyModel(nn.Module):
                def forward(self, x): return x

    2) Function (Standalone):
        - Wraps a callable with MagiCompiler's dispatch logic.
        - Useful for non-member functions or general callables.
        - Example:
            @magi_compile
            def my_func(x): return x

    3) Instance (nn.Module):
        - Compiles a single instance specifically.
        - Avoids affecting other instances by creating an instance-specific subclass.
        - Example:
            model = MyModel()
            model = magi_compile(model)

    4) Method (Bound/Unbound):
        - Wraps a specific function attribute (e.g., `model.forward`).
        - Enables focused compilation of specific object behaviors.
        - Example:
            model = MyModel()
            model.forward = magi_compile(model.forward)

    Usage Styles
    ------------
    The compiler supports both declarative (decorator) and imperative (function call) styles.

    A) Decorator Style:
        - Example:
            @magi_compile(dynamic_arg_dims={"x": 0})
            class MyModel(nn.Module): ...

            class MyModel(nn.Module):
                @magi_compile
                def forward(self, x): ...

    B) Imperative Style:
        - Apply directly to an existing object:
            model = magi_compile(model, dynamic_arg_dims={"x": 0})

    C) Factory Style:
        - Configure a compiler first, then apply to multiple objects:
            compiler = magi_compile(dynamic_arg_dims={"x": 0})
            model = compiler(model)
            cls = compiler(MyModel)

    Arguments
    ---------
    - obj: The class, module instance, or callable to compile. When omitted (None), a
      pre-configured `functools.partial` of `magi_compile` is returned (decorator-with-arguments
      and factory styles).
    - dynamic_arg_dims: Dictionary mapping argument names to dynamic dimensions (int or list[int]).
    - model_tag: Optional tag for caching path (defaults to class/function name).
    - enable_if: Callable returning bool; compilation happens only if this returns True.
    - config_patch: Optional callable that takes a `CompileConfig` and returns a `CompileConfig`.
      It is forwarded unchanged to the underlying compile helper (presumably to adjust the
      compile configuration for this target only -- confirm against `_api`).

    Returns
    -------
    - When `obj` is None: a `functools.partial` bound to the given keyword options.
    - Otherwise: whatever the class / instance / function compile helper returns for `obj`.

    Raises
    ------
    - AssertionError: if `obj` is a class that is not an `nn.Module` subclass, or if no dynamic
      dimensions were supplied or could be inferred.
    - TypeError: if `obj` is neither a class, an `nn.Module` instance, nor a callable.

    Notes
    -----
    - If `dynamic_arg_dims` is omitted, it is inferred from type annotations:
      `torch.Tensor` arguments default to dynamic dimension 0.
      An empty dict is treated the same as omitting the argument.
    - Consistency: For graph stability, maintain consistent input types (e.g., avoid switching between Tensor and None).
    """
    if obj is None:
        # No target yet: hand back a pre-configured compiler that accepts the target later.
        return functools.partial(
            magi_compile,
            model_tag=model_tag,
            dynamic_arg_dims=dynamic_arg_dims,
            enable_if=enable_if,
            config_patch=config_patch,
        )

    # 1. Determine target function for dynamic dim inference
    is_module_class = inspect.isclass(obj)
    is_module_instance = not is_module_class and isinstance(obj, nn.Module)
    if is_module_class:
        assert issubclass(obj, nn.Module), f"Expected nn.Module subclass, got {obj}"
        target_func = obj.forward
        context_name = f"forward method of {obj.__name__}"
    elif is_module_instance:
        target_func = obj.forward
        context_name = f"forward method of instance {obj.__class__.__name__}"
    elif callable(obj):
        target_func = obj
        # Not every callable carries `__name__` (e.g. functools.partial objects or
        # instances defining `__call__`); fall back to the type name so building
        # this diagnostic label can never raise AttributeError.
        display_name = getattr(obj, "__name__", type(obj).__name__)
        context_name = f"function/method {display_name}"
    else:
        raise TypeError(f"Unsupported type for magi_compile: {type(obj)}")

    # 2. Infer dynamic dims
    # A falsy value (None or an empty dict) triggers inference from the target's signature.
    inferred_dims = dynamic_arg_dims or _infer_dynamic_arg_dims(target_func, context_name)
    assert (
        len(inferred_dims) > 0
    ), f"No dynamic dimensions found in {context_name}. Please provide dynamic_arg_dims explicitly."
    _check_dynamic_arg_dims(inferred_dims, target_func)

    # 3. Logic based on type
    if is_module_class:
        return _magi_compile_class(obj, inferred_dims, enable_if, config_patch, model_tag)
    if is_module_instance:
        return _magi_compile_instance(obj, inferred_dims, enable_if, config_patch, model_tag)
    return _magi_compile_function(obj, inferred_dims, enable_if, config_patch, model_tag)
def magi_register_custom_op(
    name: str | None = None,
    mutates_args: tuple[str, ...] = (),
    infer_output_meta_fn: Callable | list[str] | None = None,
    setup_context_fn: Callable | None = None,
    backward_fn: Callable | None = None,
    is_compute_sensitive: bool = False,
    is_subgraph_boundary: bool = False,
):
    """
    Register a custom operator with PyTorch's library through a single decorator.

    One call stands in for the usual three-step registration:
        - @torch.library.custom_op
        - @torch.library.register_fake
        - fn.register_autograd

    Every argument is forwarded by keyword, unchanged, to `_magi_register_custom_op_impl`.

    Arguments:
        name: Fully qualified operator name (e.g., "namespace::op_name").
            When None, a name is auto-generated from the function name and source file.
        mutates_args: Names of the arguments the operator mutates.
        infer_output_meta_fn: How output tensor metadata (shape, dtype, device) is derived
            for tracing.
            - None (default): each output is assumed to share the metadata of the
              corresponding input tensor (1st output matches 1st tensor input, 2nd matches
              2nd, etc.).
            - list[str]: names of the parameters whose metadata the outputs take.
              E.g., ["weight", "bias"] means output[0] has the same shape as `weight` and
              output[1] has the same shape as `bias`.
            - Callable: a function with the same signature as the op that returns
              torch.empty_like() tensors matching the expected output shapes.
        setup_context_fn: Saves tensors/values needed by the backward pass.
            Signature: setup_context_fn(ctx, inputs, output)
        backward_fn: Computes gradients.
            Signature: backward_fn(ctx, *grad_outputs) -> tuple of gradients
        is_compute_sensitive: If True, flags the operator as compute-intensive (e.g., MatMul,
            Attention). During activation recomputation (rematerialization), outputs of
            compute-sensitive ops are prioritized for saving rather than recomputing,
            since recomputing them would be expensive.
        is_subgraph_boundary: If True, the FX graph is split at this operator during
            compilation. Each sub-graph between boundary operators is compiled independently
            by Inductor, enabling piecewise compilation and more flexible scheduling
            (e.g., for CPU offloading or overlapping computation with data transfer).

    Returns:
        The registered custom operator function.

    Examples:
        1. Basic usage (forward only, auto-generated name and meta function):

        >>> @magi_register_custom_op()
        ... def my_relu(x: torch.Tensor) -> torch.Tensor:
        ...     return torch.maximum(x, torch.zeros_like(x))

        2. Multiple outputs with explicit output metadata via parameter names:

        >>> @magi_register_custom_op(
        ...     infer_output_meta_fn=["weight", "bias"],  # output shapes match weight and bias
        ... )
        ... def compute_gradients(
        ...     grad_output: torch.Tensor,
        ...     weight: torch.Tensor,
        ...     bias: torch.Tensor,
        ... ) -> tuple[torch.Tensor, torch.Tensor]:
        ...     grad_weight = grad_output.sum(dim=0).view_as(weight)
        ...     grad_bias = grad_output.sum(dim=0).view_as(bias)
        ...     return grad_weight, grad_bias

        3. Full custom op with autograd support:

        >>> def _square_meta(x: torch.Tensor) -> torch.Tensor:
        ...     return torch.empty_like(x)
        ...
        >>> def _square_setup_context(ctx, inputs, output):
        ...     (x,) = inputs
        ...     ctx.save_for_backward(x)
        ...
        >>> def _square_backward(ctx, grad_output):
        ...     (x,) = ctx.saved_tensors
        ...     return grad_output * 2 * x
        ...
        >>> @magi_register_custom_op(
        ...     name="my_ops::square",
        ...     infer_output_meta_fn=_square_meta,
        ...     setup_context_fn=_square_setup_context,
        ...     backward_fn=_square_backward,
        ... )
        ... def square(x: torch.Tensor) -> torch.Tensor:
        ...     return x * x
    """
    # Collect the options once, then forward them all by keyword to the implementation.
    registration_options = {
        "name": name,
        "mutates_args": mutates_args,
        "infer_output_meta_fn": infer_output_meta_fn,
        "setup_context_fn": setup_context_fn,
        "backward_fn": backward_fn,
        "is_compute_sensitive": is_compute_sensitive,
        "is_subgraph_boundary": is_subgraph_boundary,
    }
    return _magi_register_custom_op_impl(**registration_options)

Xet Storage Details

Size:
10 kB
·
Xet hash:
ae6381bf4e68c7449f60aa1f8f0acc10f73a34493bc5531690787cfaeecc6d2e

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.