rahul7star/MagiGPU-storage / MagiCompiler /tests /api_tests /test_register_custom_op.py
download
raw
23 kB
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for @magi_register_custom_op decorator functionality.
This module tests:
- Basic custom op registration (forward only)
- Custom op with infer_output_meta_fn for torch.compile tracing
- Custom op with autograd support (setup_context + backward)
- Full custom op with all components
- Multiple outputs support
- Integration with magi_compile decorator
"""
import tempfile
from unittest.mock import patch
import pytest
import torch
from torch import nn
from torch.testing import assert_close
from magi_compiler.api import magi_compile, magi_register_custom_op
from magi_compiler.config import CompileConfig, CompileMode
class TestBasicRegistration:
"""Tests for basic custom op registration without autograd."""
def test_forward_only(self):
"""Test registering a custom op with only forward implementation."""
@magi_register_custom_op(name="test::forward_only_op", mutates_args=())
def _forward_only_op(x: torch.Tensor) -> torch.Tensor:
return x * 2 + 1
x = torch.randn(4, 8)
output = _forward_only_op(x)
expected = x * 2 + 1
assert_close(output, expected)
def test_multiple_inputs(self):
"""Test custom op with multiple input tensors."""
@magi_register_custom_op(name="test::multi_input_op", mutates_args=())
def _multi_input_op(a: torch.Tensor, b: torch.Tensor, scale: float) -> torch.Tensor:
return (a + b) * scale
a = torch.randn(4, 8)
b = torch.randn(4, 8)
scale = 2.5
output = _multi_input_op(a, b, scale)
expected = (a + b) * scale
assert_close(output, expected)
class TestInferOutputMeta:
"""Tests for custom op with infer_output_meta_fn."""
def test_with_infer_output_meta(self):
"""Test that infer_output_meta_fn is correctly registered for tracing."""
def _scaled_add_infer_output_meta(x: torch.Tensor, y: torch.Tensor, scale: float) -> torch.Tensor:
return torch.empty_like(x)
@magi_register_custom_op(
name="test::scaled_add_op", mutates_args=(), infer_output_meta_fn=_scaled_add_infer_output_meta
)
def _scaled_add_op(x: torch.Tensor, y: torch.Tensor, scale: float) -> torch.Tensor:
return (x + y) * scale
x = torch.randn(4, 8)
y = torch.randn(4, 8)
scale = 3.0
output = _scaled_add_op(x, y, scale)
expected = (x + y) * scale
assert_close(output, expected)
def test_multiple_outputs_infer_meta(self):
"""Test infer_output_meta_fn with multiple outputs."""
def _split_op_infer_output_meta(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
half_size = x.shape[-1] // 2
return (x.new_empty((*x.shape[:-1], half_size)), x.new_empty((*x.shape[:-1], half_size)))
@magi_register_custom_op(name="test::split_op", mutates_args=(), infer_output_meta_fn=_split_op_infer_output_meta)
def _split_op(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
half_size = x.shape[-1] // 2
# NOTE: Output cannot share the same memory with input
return torch.clone(x[..., :half_size]), torch.clone(x[..., half_size:])
x = torch.randn(4, 8)
out1, out2 = _split_op(x)
assert out1.shape == (4, 4)
assert out2.shape == (4, 4)
assert_close(out1, x[..., :4])
assert_close(out2, x[..., 4:])
class TestAutograd:
"""Tests for custom op with autograd support."""
def test_with_autograd(self):
"""Test custom op with setup_context and backward functions."""
def _square_infer_output_meta(x: torch.Tensor) -> torch.Tensor:
return torch.empty_like(x)
def _square_setup_context(ctx, inputs, output):
(x,) = inputs
ctx.save_for_backward(x)
def _square_backward(ctx, grad_output):
(x,) = ctx.saved_tensors
return grad_output * 2 * x
@magi_register_custom_op(
name="test::square_op",
mutates_args=(),
infer_output_meta_fn=_square_infer_output_meta,
setup_context_fn=_square_setup_context,
backward_fn=_square_backward,
)
def _square_op(x: torch.Tensor) -> torch.Tensor:
return x * x
x = torch.randn(4, 8, requires_grad=True)
output = _square_op(x)
loss = output.sum()
loss.backward()
# Gradient of x^2 is 2x
expected_grad = 2 * x
assert_close(x.grad, expected_grad)
def test_autograd_multiple_inputs(self):
"""Test autograd with multiple input tensors."""
def _weighted_sum_infer_output_meta(a: torch.Tensor, b: torch.Tensor, weight: float) -> torch.Tensor:
return torch.empty_like(a)
def _weighted_sum_setup_context(ctx, inputs, output):
a, b, weight = inputs
ctx.save_for_backward(a, b)
ctx.weight = weight
def _weighted_sum_backward(ctx, grad_output):
a, b = ctx.saved_tensors
weight = ctx.weight
grad_a = grad_output * weight
grad_b = grad_output * (1 - weight)
return grad_a, grad_b, None # None for non-tensor input
@magi_register_custom_op(
name="test::weighted_sum_op",
mutates_args=(),
infer_output_meta_fn=_weighted_sum_infer_output_meta,
setup_context_fn=_weighted_sum_setup_context,
backward_fn=_weighted_sum_backward,
)
def _weighted_sum_op(a: torch.Tensor, b: torch.Tensor, weight: float) -> torch.Tensor:
return a * weight + b * (1 - weight)
a = torch.randn(4, 8, requires_grad=True)
b = torch.randn(4, 8, requires_grad=True)
weight = 0.7
output = _weighted_sum_op(a, b, weight)
loss = output.sum()
loss.backward()
expected_grad_a = torch.ones_like(a) * weight
expected_grad_b = torch.ones_like(b) * (1 - weight)
assert_close(a.grad, expected_grad_a)
assert_close(b.grad, expected_grad_b)
def test_autograd_multiple_outputs(self):
"""Test autograd with multiple output tensors."""
def _split_scale_infer_output_meta(x: torch.Tensor, scale: float) -> tuple[torch.Tensor, torch.Tensor]:
half = x.shape[-1] // 2
return (x.new_empty((*x.shape[:-1], half)), x.new_empty((*x.shape[:-1], half)))
def _split_scale_setup_context(ctx, inputs, output):
x, scale = inputs
ctx.save_for_backward(x)
ctx.scale = scale
ctx.half = x.shape[-1] // 2
def _split_scale_backward(ctx, grad_out1, grad_out2):
(x,) = ctx.saved_tensors
scale = ctx.scale
# Reconstruct gradient for x
grad_x = torch.cat([grad_out1 * scale, grad_out2 * scale], dim=-1)
return grad_x, None
@magi_register_custom_op(
name="test::split_scale_op",
mutates_args=(),
infer_output_meta_fn=_split_scale_infer_output_meta,
setup_context_fn=_split_scale_setup_context,
backward_fn=_split_scale_backward,
)
def _split_scale_op(x: torch.Tensor, scale: float) -> tuple[torch.Tensor, torch.Tensor]:
half = x.shape[-1] // 2
return x[..., :half] * scale, x[..., half:] * scale
x = torch.randn(4, 8, requires_grad=True)
scale = 2.0
out1, out2 = _split_scale_op(x, scale)
loss = out1.sum() + out2.sum()
loss.backward()
expected_grad = torch.ones_like(x) * scale
assert_close(x.grad, expected_grad)
class TestAutoGeneratedName:
"""Tests for auto-generated operator name when name is not provided."""
def test_auto_name_single_output(self):
"""Test auto-generated name with single tensor output."""
@magi_register_custom_op()
def _auto_name_single_op(x: torch.Tensor) -> torch.Tensor:
return x * 2
def fn(x):
return _auto_name_single_op(x)
compiled_fn = torch.compile(fn, backend="eager")
x = torch.randn(4, 8)
output = compiled_fn(x)
expected = x * 2
assert_close(output, expected)
def test_auto_name_multiple_outputs(self):
"""Test auto-generated name with multiple tensor outputs."""
@magi_register_custom_op()
def _auto_name_multi_out_op(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
return torch.clone(a + 1), torch.clone(b + 2)
def fn(a, b):
return _auto_name_multi_out_op(a, b)
compiled_fn = torch.compile(fn, backend="eager")
a = torch.randn(3, 5)
b = torch.randn(3, 5)
out1, out2 = compiled_fn(a, b)
assert_close(out1, a + 1)
assert_close(out2, b + 2)
def test_auto_name_with_autograd(self):
"""Test auto-generated name with autograd support."""
def _auto_grad_setup_context(ctx, inputs, output):
(x,) = inputs
ctx.save_for_backward(x)
def _auto_grad_backward(ctx, grad_output):
(x,) = ctx.saved_tensors
return grad_output * 2 * x
@magi_register_custom_op(setup_context_fn=_auto_grad_setup_context, backward_fn=_auto_grad_backward)
def _auto_name_square_op(x: torch.Tensor) -> torch.Tensor:
return x * x
x = torch.randn(4, 8, requires_grad=True)
output = _auto_name_square_op(x)
loss = output.sum()
loss.backward()
expected_grad = 2 * x
assert_close(x.grad, expected_grad)
class TestDefaultIdentityMetaFn:
"""Tests for the default identity meta function when infer_output_meta_fn is not provided."""
def test_single_output_default_meta(self):
"""Test default meta function with single tensor output."""
@magi_register_custom_op(name="test::default_meta_single")
def _default_meta_single_op(x: torch.Tensor) -> torch.Tensor:
return x * 2
def fn(x):
return _default_meta_single_op(x)
compiled_fn = torch.compile(fn, backend="eager")
x = torch.randn(4, 8)
output = compiled_fn(x)
expected = x * 2
assert_close(output, expected)
def test_single_output_multiple_inputs_default_meta(self):
"""Test default meta function with multiple inputs but single tensor output."""
@magi_register_custom_op(name="test::default_meta_multi_in")
def _default_meta_multi_in_op(a: torch.Tensor, b: torch.Tensor, scale: float) -> torch.Tensor:
return (a + b) * scale
def fn(a, b, scale):
return _default_meta_multi_in_op(a, b, scale)
compiled_fn = torch.compile(fn, backend="eager")
a = torch.randn(4, 8)
b = torch.randn(4, 8)
scale = 2.5
output = compiled_fn(a, b, scale)
expected = (a + b) * scale
assert_close(output, expected)
def test_multiple_outputs_default_meta(self):
"""Test default meta function with multiple tensor outputs."""
@magi_register_custom_op(name="test::default_meta_multi_out")
def _default_meta_multi_out_op(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# Clone to avoid aliasing issues
return torch.clone(x * 2), torch.clone(y * 3)
def fn(x, y):
return _default_meta_multi_out_op(x, y)
compiled_fn = torch.compile(fn, backend="eager")
x = torch.randn(4, 8)
y = torch.randn(4, 8)
out1, out2 = compiled_fn(x, y)
assert_close(out1, x * 2)
assert_close(out2, y * 3)
def test_three_outputs_default_meta(self):
"""Test default meta function with three tensor outputs."""
@magi_register_custom_op(name="test::default_meta_three_out", mutates_args=())
def _default_meta_three_out_op(
a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return torch.clone(a + 1), torch.clone(b + 2), torch.clone(c + 3)
def fn(a, b, c):
return _default_meta_three_out_op(a, b, c)
compiled_fn = torch.compile(fn, backend="eager")
a = torch.randn(2, 4)
b = torch.randn(2, 4)
c = torch.randn(2, 4)
out1, out2, out3 = compiled_fn(a, b, c)
assert_close(out1, a + 1)
assert_close(out2, b + 2)
assert_close(out3, c + 3)
def test_default_meta_with_non_tensor_args(self):
"""Test default meta function correctly skips non-tensor arguments."""
@magi_register_custom_op(name="test::default_meta_mixed_args")
def _default_meta_mixed_args_op(
scale: float, x: torch.Tensor, offset: int, y: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.clone(x * scale + offset), torch.clone(y * scale + offset)
def fn(scale, x, offset, y):
return _default_meta_mixed_args_op(scale, x, offset, y)
compiled_fn = torch.compile(fn, backend="eager")
scale = 2.0
x = torch.randn(3, 5)
offset = 10
y = torch.randn(3, 5)
out1, out2 = compiled_fn(scale, x, offset, y)
assert_close(out1, x * scale + offset)
assert_close(out2, y * scale + offset)
class TestTorchCompileIntegration:
"""Tests for integration with torch.compile."""
def test_custom_op_in_compiled_function(self):
"""Test that custom op works inside a torch.compile'd function."""
def _double_infer_output_meta(x: torch.Tensor) -> torch.Tensor:
return torch.empty_like(x)
@magi_register_custom_op(name="test::double_op", mutates_args=(), infer_output_meta_fn=_double_infer_output_meta)
def _double_op(x: torch.Tensor) -> torch.Tensor:
return x * 2
def fn(x):
y = _double_op(x)
return y + 1
compiled_fn = torch.compile(fn, backend="eager")
x = torch.randn(4, 8)
output = compiled_fn(x)
expected = x * 2 + 1
assert_close(output, expected)
def test_custom_op_with_autograd_in_compiled_function(self):
"""Test custom op with autograd inside torch.compile'd function."""
def _cube_infer_output_meta(x: torch.Tensor) -> torch.Tensor:
return torch.empty_like(x)
def _cube_setup_context(ctx, inputs, output):
(x,) = inputs
ctx.save_for_backward(x)
def _cube_backward(ctx, grad_output):
(x,) = ctx.saved_tensors
return grad_output * 3 * x * x
@magi_register_custom_op(
name="test::cube_op",
mutates_args=(),
infer_output_meta_fn=_cube_infer_output_meta,
setup_context_fn=_cube_setup_context,
backward_fn=_cube_backward,
)
def _cube_op(x: torch.Tensor) -> torch.Tensor:
return x * x * x
def fn(x):
return _cube_op(x)
compiled_fn = torch.compile(fn, backend="eager")
x = torch.randn(4, 8, requires_grad=True)
output = compiled_fn(x)
loss = output.sum()
loss.backward()
# Gradient of x^3 is 3x^2
expected_grad = 3 * x * x
assert_close(x.grad, expected_grad)
@pytest.fixture()
def magi_compile_config():
"""Fixture to set up a clean compile configuration for magi_compile tests."""
compile_config = CompileConfig(compile_mode=CompileMode.TORCH_COMPILE, cache_root_dir=tempfile.mkdtemp())
with patch("magi_compiler._api.get_compile_config") as mock_get_config, patch("torch.distributed.get_rank") as mock_rank:
mock_get_config.return_value = compile_config
mock_rank.return_value = 0
yield compile_config
import shutil
shutil.rmtree(compile_config.cache_root_dir, ignore_errors=True)
class TestMagiCompileIntegration:
"""Tests for integration with magi_compile decorator."""
def test_custom_op_in_magi_compiled_module(self, magi_compile_config):
"""Test that custom op works inside a magi_compile'd nn.Module."""
def _triple_infer_output_meta(x: torch.Tensor) -> torch.Tensor:
return torch.empty_like(x)
@magi_register_custom_op(name="test::triple_op", mutates_args=(), infer_output_meta_fn=_triple_infer_output_meta)
def _triple_op(x: torch.Tensor) -> torch.Tensor:
return x * 3
@magi_compile(dynamic_arg_dims={"x": 0})
class TripleModule(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return _triple_op(x) + 1
model = TripleModule()
x = torch.randn(4, 8)
output = model(x)
expected = x * 3 + 1
assert_close(output, expected)
# Test with different batch size to exercise dynamic shapes
x2 = torch.randn(8, 8)
output2 = model(x2)
expected2 = x2 * 3 + 1
assert_close(output2, expected2)
def test_custom_op_with_autograd_in_magi_compiled_module(self, magi_compile_config):
"""Test custom op with autograd inside a magi_compile'd nn.Module."""
def _square_v2_infer_output_meta(x: torch.Tensor) -> torch.Tensor:
return torch.empty_like(x)
def _square_v2_setup_context(ctx, inputs, output):
(x,) = inputs
ctx.save_for_backward(x)
def _square_v2_backward(ctx, grad_output):
(x,) = ctx.saved_tensors
return grad_output * 2 * x
@magi_register_custom_op(
name="test::square_v2_op",
mutates_args=(),
infer_output_meta_fn=_square_v2_infer_output_meta,
setup_context_fn=_square_v2_setup_context,
backward_fn=_square_v2_backward,
)
def _square_v2_op(x: torch.Tensor) -> torch.Tensor:
return x * x
@magi_compile(dynamic_arg_dims={"x": 0})
class SquareModule(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return _square_v2_op(x)
model = SquareModule()
x = torch.randn(4, 8, requires_grad=True)
output = model(x)
loss = output.sum()
loss.backward()
# Gradient of x^2 is 2x
expected_grad = 2 * x
assert_close(x.grad, expected_grad)
def test_custom_op_with_linear_in_magi_compiled_module(self, magi_compile_config):
"""Test custom op combined with nn.Linear inside a magi_compile'd module."""
def _relu_custom_infer_output_meta(x: torch.Tensor) -> torch.Tensor:
return torch.empty_like(x)
@magi_register_custom_op(
name="test::relu_custom_op", mutates_args=(), infer_output_meta_fn=_relu_custom_infer_output_meta
)
def _relu_custom_op(x: torch.Tensor) -> torch.Tensor:
return torch.relu(x)
@magi_compile(dynamic_arg_dims={"x": 0})
class LinearReluModule(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(8, 8)
@torch.no_grad()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return _relu_custom_op(self.linear(x))
model = LinearReluModule()
x = torch.randn(4, 8)
output = model(x)
expected = torch.relu(model.linear(x))
assert_close(output, expected)
def test_multiple_custom_ops_in_magi_compiled_module(self, magi_compile_config):
"""Test multiple custom ops used together inside a magi_compile'd module."""
def _add_one_infer_output_meta(x: torch.Tensor) -> torch.Tensor:
return torch.empty_like(x)
def _mul_two_infer_output_meta(x: torch.Tensor) -> torch.Tensor:
return torch.empty_like(x)
@magi_register_custom_op(name="test::add_one_op", mutates_args=(), infer_output_meta_fn=_add_one_infer_output_meta)
def _add_one_op(x: torch.Tensor) -> torch.Tensor:
return x + 1
@magi_register_custom_op(name="test::mul_two_op", mutates_args=(), infer_output_meta_fn=_mul_two_infer_output_meta)
def _mul_two_op(x: torch.Tensor) -> torch.Tensor:
return x * 2
@magi_compile(dynamic_arg_dims={"x": 0})
class ChainedOpsModule(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
# (x + 1) * 2
return _mul_two_op(_add_one_op(x))
model = ChainedOpsModule()
x = torch.randn(4, 8)
output = model(x)
expected = (x + 1) * 2
assert_close(output, expected)
def test_custom_op_multiple_outputs_in_magi_compiled_module(self, magi_compile_config):
"""Test custom op with multiple outputs inside a magi_compile'd module."""
def _split_v2_infer_output_meta(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
half = x.shape[-1] // 2
return (x.new_empty((*x.shape[:-1], half)), x.new_empty((*x.shape[:-1], half)))
@magi_register_custom_op(name="test::split_v2_op", mutates_args=(), infer_output_meta_fn=_split_v2_infer_output_meta)
def _split_v2_op(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
half = x.shape[-1] // 2
return torch.clone(x[..., :half]), torch.clone(x[..., half:])
@magi_compile(dynamic_arg_dims={"x": 0})
class SplitModule(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
a, b = _split_v2_op(x)
return a + 1, b * 2
model = SplitModule()
x = torch.randn(4, 8)
out1, out2 = model(x)
assert out1.shape == (4, 4)
assert out2.shape == (4, 4)
assert_close(out1, x[..., :4] + 1)
assert_close(out2, x[..., 4:] * 2)
if __name__ == "__main__":
pytest.main([__file__, "-v"])

Xet Storage Details

Size:
23 kB
·
Xet hash:
9f5309c4070955e41168b68283be0e4c0f089e85bf9ae0e97a1a997e6f351e73

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.