rahul7star/MagiGPU-storage / MagiCompiler /magi_compiler /_magi_register_custom_op.py
download
raw
7.5 kB
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, get_args, get_origin
import torch
from .config import get_compile_config
def _get_num_outputs_from_return_annotation(fn: Callable) -> int:
"""
Get the number of outputs from the function's return type annotation.
Returns:
- 1 if the return type is a single Tensor
- N if the return type is tuple[Tensor, Tensor, ...] with N elements
- 1 if no annotation or unrecognized annotation (default to single output)
"""
sig = inspect.signature(fn)
return_annotation = sig.return_annotation
if return_annotation is inspect.Parameter.empty:
return 1
# Check if it's a tuple type (e.g., tuple[Tensor, Tensor])
origin = get_origin(return_annotation)
if origin is tuple:
args = get_args(return_annotation)
# Filter out ellipsis (for variable-length tuples like tuple[Tensor, ...])
if args and args[-1] is not ...:
return len(args)
return 1
return 1
def _generate_op_name(fn: Callable) -> str:
    """
    Build an operator name of the form ``{namespace}::{function_name}`` for ``fn``.

    The namespace is the stem of the file that defines ``fn``, with every
    character outside ``[a-zA-Z0-9_]`` replaced by an underscore. For example,
    a function ``my_op`` defined in ``my_module.py`` yields ``"my_module::my_op"``.

    If the defining file cannot be determined (``inspect.getfile`` raises
    ``TypeError`` or ``OSError``), the namespace is ``"magi_custom"``.
    """
    import re
    from pathlib import Path

    op = fn.__name__

    try:
        # File name without directory and extension, e.g. "/a/b/my_module.py" -> "my_module".
        file_stem = Path(inspect.getfile(fn)).stem
        # Anything that is not an ASCII letter, digit, or underscore becomes "_".
        namespace = re.sub(r"[^a-zA-Z0-9_]", "_", file_stem)
    except (TypeError, OSError):
        # No source file available (e.g. built-in callables): use the default namespace.
        namespace = "magi_custom"

    return f"{namespace}::{op}"
def _create_identity_meta_fn(fn: Callable) -> Callable:
    """
    Build the default "identity" meta function for ``fn``.

    The number of outputs is taken from ``fn``'s return annotation. When the
    returned meta function is called, it binds its arguments to ``fn``'s
    signature, takes the first ``num_outputs`` tensor-valued parameters in
    signature order (a parameter named ``self`` is skipped), and returns
    ``torch.empty_like()`` of each one: a single tensor for one output,
    otherwise a tuple.

    The meta function raises ``ValueError`` when fewer tensor arguments than
    outputs are available.
    """
    num_outputs = _get_num_outputs_from_return_annotation(fn)
    sig = inspect.signature(fn)
    # Candidate template parameters, in declaration order, without 'self'.
    candidate_names = [name for name in sig.parameters if name != "self"]

    def identity_meta_fn(*args, **kwargs):
        # Map every parameter name to its value, filling in defaults.
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()

        # Tensor-valued arguments in signature order, truncated to the number
        # of outputs that need a template.
        values = (bound.arguments.get(name) for name in candidate_names)
        templates = [value for value in values if isinstance(value, torch.Tensor)][:num_outputs]

        if len(templates) < num_outputs:
            raise ValueError(
                f"identity_meta_fn requires at least {num_outputs} tensor inputs to match "
                f"{num_outputs} outputs, but only found {len(templates)} tensor inputs. "
                f"Please provide a custom infer_output_meta_fn."
            )

        # Each output mirrors the metadata of the corresponding input tensor.
        outputs = tuple(torch.empty_like(template) for template in templates)
        return outputs[0] if num_outputs == 1 else outputs

    return identity_meta_fn
def _create_meta_fn_from_param_names(fn: Callable, param_names: list[str]) -> Callable:
    """
    Create a meta function that returns torch.empty_like() for each specified parameter.

    The names are checked against ``fn``'s signature immediately, so a
    misspelled name fails when the meta function is created instead of on its
    first invocation.

    Args:
        fn: Target function to inspect
        param_names: List of parameter names to use as output templates

    Returns:
        Meta function that maps specified input params to output tensors. It
        returns a single tensor when one name is given, otherwise a tuple with
        one tensor per name, in the order of ``param_names``.

    Raises:
        ValueError: If a name in ``param_names`` is not a parameter of ``fn``.
            The returned meta function additionally raises ValueError when
            called with a non-Tensor value for one of the named parameters.
    """
    sig = inspect.signature(fn)

    # Snapshot the names so that later mutation of the caller's list cannot
    # bypass the validation performed here.
    template_names = tuple(param_names)
    for name in template_names:
        if name not in sig.parameters:
            raise ValueError(
                f"Parameter '{name}' not found in function signature. "
                f"Available parameters: {list(sig.parameters)}"
            )

    def meta_fn(*args, **kwargs):
        # Bind arguments to get a mapping of param_name -> value
        bound = sig.bind(*args, **kwargs)
        # After apply_defaults() every signature parameter has an entry, so the
        # lookups below cannot miss for the names validated above.
        bound.apply_defaults()

        tensor_outputs = []
        for name in template_names:
            arg = bound.arguments[name]
            if not isinstance(arg, torch.Tensor):
                raise ValueError(
                    f"Parameter '{name}' is not a Tensor (got {type(arg).__name__}). "
                    f"infer_output_meta_fn list should only contain tensor parameter names."
                )
            tensor_outputs.append(torch.empty_like(arg))

        # Return single tensor or tuple based on number of outputs
        if len(tensor_outputs) == 1:
            return tensor_outputs[0]
        return tuple(tensor_outputs)

    return meta_fn
def _magi_register_custom_op_impl(
    name: str | None = None,
    mutates_args: tuple[str, ...] = (),
    infer_output_meta_fn: Callable | list[str] | None = None,
    setup_context_fn: Callable | None = None,
    backward_fn: Callable | None = None,
    is_compute_sensitive: bool = False,
    is_subgraph_boundary: bool = False,
):
    """
    Build a decorator that registers a function as a torch custom operator.

    Args:
        name: Operator name passed to ``torch.library.custom_op``. When None,
            a name is generated from the decorated function.
        mutates_args: Forwarded to ``torch.library.custom_op``.
        infer_output_meta_fn: How the fake (meta) implementation is obtained:
            None uses the default identity meta function, a list is treated as
            parameter names whose tensors serve as output templates, and any
            other value is registered directly as the meta function.
        setup_context_fn: Passed as ``setup_context`` when autograd is
            registered; only used if ``backward_fn`` is provided.
        backward_fn: When not None, registered via ``register_autograd``.
        is_compute_sensitive: When true, the operator name is appended to the
            compile config's ``recompute_config.custom_compute_sensitive_ops``.
        is_subgraph_boundary: When true, the operator name is appended to the
            compile config's ``splitting_ops``.

    Returns:
        A decorator that takes the function and returns the object produced by
        ``torch.library.custom_op``.
    """

    def _select_meta_fn(fn: Callable) -> Callable:
        # Resolve the three accepted forms of infer_output_meta_fn.
        if infer_output_meta_fn is None:
            return _create_identity_meta_fn(fn)
        if isinstance(infer_output_meta_fn, list):
            return _create_meta_fn_from_param_names(fn, infer_output_meta_fn)
        return infer_output_meta_fn

    def decorator(fn: Callable) -> Callable:
        # Use the explicit name when given, otherwise derive one from fn.
        op_name = _generate_op_name(fn) if name is None else name

        # Record the operator in the compile config before registering it.
        if is_compute_sensitive:
            get_compile_config().recompute_config.custom_compute_sensitive_ops.append(op_name)
        if is_subgraph_boundary:
            get_compile_config().splitting_ops.append(op_name)

        # 1) Register the operator itself.
        make_op = torch.library.custom_op(op_name, mutates_args=mutates_args)
        registered_op = make_op(fn)

        # 2) Register the fake implementation used to infer output metadata.
        fake_impl = _select_meta_fn(fn)
        torch.library.register_fake(op_name)(fake_impl)

        # 3) Autograd support is optional and requires a backward function.
        if backward_fn is not None:
            registered_op.register_autograd(backward_fn, setup_context=setup_context_fn)

        return registered_op

    return decorator

Xet Storage Details

Size:
7.5 kB
·
Xet hash:
88b0cf978b449e65b27b5ec1b0700672387377e23ba78ed27bacf0249087a104

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.