Buckets:
| # Copyright (c) 2026 SandAI. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| Tests for @magi_compile decorator functionality. | |
| """ | |
| import shutil | |
| import tempfile | |
| import time | |
| from typing import Tuple | |
| from unittest.mock import MagicMock, patch | |
| import pytest | |
| import torch | |
| from torch import nn | |
| from torch.testing import assert_close | |
| from magi_compiler.api import magi_compile | |
| from magi_compiler.config import CompileConfig, CompileMode | |
| def compile_config_fixture(): | |
| """Fixture to set up a clean compile configuration for each test.""" | |
| cache_dir = tempfile.mkdtemp() | |
| compile_config = CompileConfig(compile_mode=CompileMode.TORCH_COMPILE, cache_root_dir=cache_dir) | |
| with patch("magi_compiler.config.get_compile_config") as mock_get_config, patch("torch.distributed.get_rank") as mock_rank: | |
| mock_get_config.return_value = compile_config | |
| mock_rank.return_value = 0 | |
| yield compile_config | |
| shutil.rmtree(cache_dir, ignore_errors=True) | |
| class TestDynamicArgDims: | |
| """Tests for dynamic_arg_dims inference and handling.""" | |
| def test_automatic_inference(self): | |
| """Test that dynamic_arg_dims is automatically inferred.""" | |
| # Base class | |
| class BaseModel(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.linear = nn.Linear(10, 5) | |
| def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
| return self.linear(x + y) | |
| # 1. Class level | |
| class ClsBaseModel(BaseModel): | |
| pass | |
| cls_model = ClsBaseModel() | |
| # 2. Function level | |
| class FuncBaseModel(BaseModel): | |
| def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
| return super().forward(x, y) | |
| func_model = FuncBaseModel() | |
| # 3. Instance level | |
| inst_model = BaseModel() | |
| inst_model = magi_compile(inst_model) | |
| # 4. Method level | |
| mtd_model = BaseModel() | |
| mtd_model.forward = magi_compile(mtd_model.forward) | |
| models = [cls_model, func_model, inst_model, mtd_model] | |
| for model in models: | |
| # Test with different batch sizes to ensure dynamic shape works | |
| x1 = torch.randn(4, 10) | |
| y1 = torch.randn(4, 10) | |
| output1 = model(x1, y1) | |
| assert output1.shape == (4, 5) | |
| x2 = torch.randn(8, 10) | |
| y2 = torch.randn(8, 10) | |
| output2 = model(x2, y2) | |
| assert output2.shape == (8, 5) | |
| def test_negative_index(self): | |
| """Test that negative indices in dynamic_arg_dims are handled correctly.""" | |
| class BaseDynamicLastDimModel(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.out_features = 5 | |
| self._weight = None | |
| self._bias = None | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| in_features = x.size(-1) | |
| if self._weight is None or self._weight.size(1) != in_features: | |
| self._weight = torch.randn(self.out_features, in_features, device=x.device, dtype=x.dtype) | |
| self._bias = torch.randn(self.out_features, device=x.device, dtype=x.dtype) | |
| return torch.matmul(x, self._weight.t()) + self._bias | |
| # 1. Class level | |
| class ClsBaseDynamicLastDimModel(BaseDynamicLastDimModel): | |
| pass | |
| cls_model = ClsBaseDynamicLastDimModel() | |
| # 2. Function level | |
| class FuncBaseDynamicLastDimModel(BaseDynamicLastDimModel): | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| return super().forward(x) | |
| func_model = FuncBaseDynamicLastDimModel() | |
| # 3. Instance level | |
| inst_model = BaseDynamicLastDimModel() | |
| inst_model = magi_compile(inst_model, dynamic_arg_dims={"x": -1}) | |
| # 4. Method level | |
| mtd_model = BaseDynamicLastDimModel() | |
| mtd_model.forward = magi_compile(mtd_model.forward, dynamic_arg_dims={"x": -1}) | |
| models = [cls_model, func_model, inst_model, mtd_model] | |
| for model in models: | |
| x1 = torch.randn(4, 10) | |
| x2 = torch.randn(4, 15) | |
| output1 = model(x1) | |
| output2 = model(x2) | |
| assert output1.shape == (4, 5) | |
| assert output2.shape == (4, 5) | |
| # Accessing weight depends on which level we patched | |
| m = model.obj if hasattr(model, "obj") else model | |
| if isinstance(m, nn.Module): | |
| assert m._weight.size(1) == 15 | |
| def test_nested_dataclass(self): | |
| """Test that dynamic_arg_dims can handle nested dataclasses correctly.""" | |
| from dataclasses import dataclass | |
| class InnerConfig: | |
| tensor_val: torch.Tensor | |
| other_val: int = 1 | |
| class OuterConfig: | |
| inner: InnerConfig | |
| flag: bool = True | |
| class DataclassModel(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.linear = nn.Linear(10, 5) | |
| def forward(self, config: OuterConfig) -> torch.Tensor: | |
| x = config.inner.tensor_val | |
| return self.linear(x) * config.inner.other_val | |
| class CompiledDataclassModel(DataclassModel): | |
| pass | |
| model = CompiledDataclassModel() | |
| # Test with batch size 4 | |
| x1 = torch.randn(4, 10) | |
| config1 = OuterConfig(inner=InnerConfig(tensor_val=x1)) | |
| out1 = model(config1) | |
| assert out1.shape == (4, 5) | |
| # Test with batch size 8 | |
| x2 = torch.randn(8, 10) | |
| config2 = OuterConfig(inner=InnerConfig(tensor_val=x2)) | |
| out2 = model(config2) | |
| assert out2.shape == (8, 5) | |
| class TestCompilationCorrectness: | |
| """Tests verifying that compiled models produce correct outputs.""" | |
| def test_compiled_matches_native_forward_complex(self): | |
| """Test the correctness of ComplexModel across all 12 configuration combinations. | |
| 12 combinations = 4 objects (Class, Function, Instance, Method) * 3 ways (Decorator, Imperative, Factory) | |
| """ | |
| size = 32 | |
| x = torch.randn(2, size) | |
| # Define base structure for testing different styles | |
| class ComplexBlock(nn.Module): | |
| def __init__(self, size): | |
| super().__init__() | |
| self.ln = nn.LayerNorm(size) | |
| self.linear = nn.Linear(size, size) | |
| self.act = nn.GELU() | |
| def forward(self, x): | |
| return self.act(self.linear(self.ln(x))) | |
| class ComplexModel(nn.Module): | |
| def __init__(self, size): | |
| super().__init__() | |
| self.blocks = nn.Sequential(*[ComplexBlock(size) for _ in range(3)]) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| for block in self.blocks: | |
| x = x + block(x) | |
| return x | |
| # Prepare native model | |
| native_model = ComplexModel(size) | |
| native_output = native_model(x) | |
| # Prepare factory compiler | |
| compiler = magi_compile(dynamic_arg_dims={"x": 0}) | |
| # --- Test Matrix --- | |
| # 1. Class | |
| class ClsDeco(ComplexModel): | |
| pass | |
| class ClsImp(ComplexModel): | |
| pass | |
| ClsImp = magi_compile(ClsImp, dynamic_arg_dims={"x": 0}) | |
| class ClsFact(ComplexModel): | |
| pass | |
| # 2. Instance | |
| inst_deco = ComplexModel(size) # N/A: Instance cannot be decorated directly | |
| inst_imp = magi_compile(ComplexModel(size), dynamic_arg_dims={"x": 0}) | |
| inst_fact = compiler(ComplexModel(size)) | |
| # 3. Function (Using model forward) | |
| def func_deco(x: torch.Tensor): | |
| return native_model(x) | |
| def func_imp_inner(x: torch.Tensor): | |
| return native_model(x) | |
| func_imp = magi_compile(func_imp_inner, dynamic_arg_dims={"x": 0}) | |
| func_fact_inner = lambda x: native_model(x) | |
| func_fact = compiler(func_fact_inner) | |
| # 4. Method | |
| class MtdDeco(ComplexModel): | |
| def forward(self, x): | |
| return super().forward(x) | |
| mtd_deco = MtdDeco(size) | |
| mtd_imp = ComplexModel(size) | |
| mtd_imp.forward = magi_compile(mtd_imp.forward, dynamic_arg_dims={"x": 0}) | |
| mtd_fact = ComplexModel(size) | |
| mtd_fact.forward = compiler(mtd_fact.forward) | |
| # 3. Unified verification function | |
| def _check(m_or_f, name): | |
| out = m_or_f(x) | |
| assert_close(out, native_output, rtol=1e-3, atol=1e-3, msg=f"Failed at {name}") | |
| # Helper function: synchronize state | |
| def _sync(model): | |
| if hasattr(model, "load_state_dict"): | |
| model.load_state_dict(native_model.state_dict()) | |
| return model | |
| _check(_sync(ClsDeco(size)), "ClsDeco") | |
| _check(_sync(ClsImp(size)), "ClsImp") | |
| _check(_sync(ClsFact(size)), "ClsFact") | |
| # For instances, they are new instances upon construction and require weight synchronization | |
| _check(_sync(inst_imp), "inst_imp") | |
| _check(_sync(inst_fact), "inst_fact") | |
| _check(func_deco, "func_deco") | |
| _check(func_imp, "func_imp") | |
| _check(func_fact, "func_fact") | |
| # For methods, mtd_... also requires synchronization | |
| _check(_sync(mtd_deco), "mtd_deco") | |
| _check(_sync(mtd_imp), "mtd_imp") | |
| _check(_sync(mtd_fact), "mtd_fact") | |
| def test_nested_function_calls(self): | |
| """Test compilation of model with nested function calls.""" | |
| class NestedModel(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.conv = nn.Conv2d(3, 8, kernel_size=1) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| x = self._preprocess(x) | |
| return self.conv(x) | |
| def _preprocess(self, x: torch.Tensor) -> torch.Tensor: | |
| return x * 2 + 1 | |
| x = torch.randn(2, 3, 8, 8) | |
| native_model = NestedModel() | |
| native_output = native_model.forward(x) | |
| # Class level | |
| class ClsLevelCompiledNestedModel(NestedModel): | |
| pass | |
| cls_model = ClsLevelCompiledNestedModel() | |
| cls_model.load_state_dict(native_model.state_dict()) | |
| cls_level_output = cls_model.forward(x) | |
| # Instance level | |
| inst_model = NestedModel() | |
| inst_model.load_state_dict(native_model.state_dict()) | |
| inst_model = magi_compile(inst_model, dynamic_arg_dims={"x": 0}) | |
| inst_level_output = inst_model.forward(x) | |
| assert_close(cls_level_output, native_output, rtol=1e-3, atol=1e-3) | |
| assert_close(inst_level_output, native_output, rtol=1e-3, atol=1e-3) | |
| def test_multiple_outputs(self): | |
| """Test compilation of model with multiple outputs.""" | |
| class MultiOutputModel(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.linear1 = nn.Linear(10, 5) | |
| self.linear2 = nn.Linear(10, 5) | |
| def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | |
| return self.linear1(x), self.linear2(x) | |
| x = torch.randn(4, 10) | |
| native_model = MultiOutputModel() | |
| native_output1, native_output2 = native_model(x) | |
| # Class level | |
| class ClsLevelCompiledMultiOutputModel(MultiOutputModel): | |
| pass | |
| cls_model = ClsLevelCompiledMultiOutputModel() | |
| cls_model.load_state_dict(native_model.state_dict()) | |
| cls_level_output1, cls_level_output2 = cls_model(x) | |
| assert_close(cls_level_output1, native_output1, rtol=1e-3, atol=1e-3) | |
| assert_close(cls_level_output2, native_output2, rtol=1e-3, atol=1e-3) | |
| class TestRecompilation: | |
| """Tests for source code change detection and recompilation.""" | |
| def test_source_change_triggers_recompile(self, tmpdir): | |
| """Test that modifying source code triggers recompilation.""" | |
| import importlib | |
| import sys | |
| model_config = MagicMock() | |
| # Create a temporary source file for the model | |
| src_file = tmpdir.join("test_model.py") | |
| src_content = """ | |
| import torch | |
| from torch import nn | |
| class TempModel(nn.Module): | |
| def __init__(self, *, model_config): | |
| super().__init__() | |
| self.linear = nn.Linear(10, 5) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| return self.linear(x) | |
| """ | |
| src_file.write(src_content) | |
| sys.path.insert(0, str(tmpdir)) | |
| import test_model | |
| from test_model import TempModel | |
| CompiledTempModel = magi_compile(dynamic_arg_dims={"x": 0})(TempModel) | |
| model1 = CompiledTempModel(model_config=model_config) | |
| x = torch.randn(4, 10) | |
| output1 = model1(x) | |
| # Modify the source file to change the forward logic | |
| modified_src = """ | |
| import torch | |
| from torch import nn | |
| class TempModel(nn.Module): | |
| def __init__(self, *, model_config): | |
| super().__init__() | |
| self.linear = nn.Linear(10, 5) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| return self.linear(x) * 2 | |
| """ | |
| src_file.write(modified_src) | |
| importlib.reload(test_model) | |
| from test_model import TempModel | |
| CompiledTempModel2 = magi_compile(dynamic_arg_dims={"x": 0})(TempModel) | |
| model2 = CompiledTempModel2(model_config=model_config) | |
| output2 = model2(x) | |
| # Outputs should be different due to the *2 multiplication | |
| assert not torch.allclose(output1, output2, rtol=1e-5, atol=1e-5) | |
| # Verify compiled output matches native forward | |
| native_output = model2.forward(x) | |
| assert_close(output2, native_output, rtol=1e-5, atol=1e-5) | |
| class TestPerformanceImprovementConsistency: | |
| def test_simple_model_performance_improvement_consistency(self): | |
| """Test that a simple model achieves similar performance across all compilation levels.""" | |
| class SimpleModel(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.dim = 8 | |
| self.layers = nn.ModuleList([nn.Linear(self.dim, self.dim) for _ in range(2)]) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| for layer in self.layers: | |
| res = x | |
| x = layer(x) | |
| x = torch.nn.functional.gelu(x) | |
| x = x + res | |
| return x | |
| DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |
| SEQ_LEN = 8 | |
| test_input = torch.randn(SEQ_LEN, 8).to(DEVICE) | |
| native_model = SimpleModel().to(DEVICE) | |
| class ClsSimpleModel(SimpleModel): | |
| pass | |
| cls_model = ClsSimpleModel().to(DEVICE) | |
| cls_model.load_state_dict(native_model.state_dict()) | |
| inst_model = SimpleModel().to(DEVICE) | |
| inst_model.load_state_dict(native_model.state_dict()) | |
| inst_model = magi_compile(inst_model, dynamic_arg_dims={"x": 0}) | |
| # Basic correctness across levels | |
| with torch.no_grad(): | |
| native_out = native_model(test_input) | |
| cls_out = cls_model(test_input) | |
| inst_out = inst_model(test_input) | |
| assert_close(cls_out, native_out, rtol=1e-3, atol=1e-3) | |
| assert_close(inst_out, native_out, rtol=1e-3, atol=1e-3) | |
| def test_simple_model_timing_class_function_instance_method(self): | |
| """Lightweight timing sanity: Class / Function / Instance / Method entrypoints.""" | |
| class SimpleModel(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.dim = 32 | |
| self.layers = nn.ModuleList([nn.Linear(self.dim, self.dim) for _ in range(4)]) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| for layer in self.layers: | |
| res = x | |
| x = layer(x) | |
| x = torch.nn.functional.gelu(x) | |
| x = x + res | |
| return x | |
| device = torch.device("cuda:0") | |
| seq_len = 16 | |
| test_input = torch.randn(seq_len, 32, device=device) | |
| native = SimpleModel().to(device).eval() | |
| class ClsSimpleModel(SimpleModel): | |
| pass | |
| cls_model = ClsSimpleModel().to(device).eval() | |
| cls_model.load_state_dict(native.state_dict()) | |
| inst_model = SimpleModel().to(device).eval() | |
| inst_model.load_state_dict(native.state_dict()) | |
| inst_model = magi_compile(inst_model, dynamic_arg_dims={"x": 0}) | |
| func_native = SimpleModel().to(device).eval() | |
| func_native.load_state_dict(native.state_dict()) | |
| def func_entry(x: torch.Tensor) -> torch.Tensor: | |
| return func_native(x) | |
| mtd_model = SimpleModel().to(device).eval() | |
| mtd_model.load_state_dict(native.state_dict()) | |
| mtd_model.forward = magi_compile(mtd_model.forward, dynamic_arg_dims={"x": 0}) | |
| def _bench(callable_obj, label: str, warmup: int = 5, iters: int = 200) -> float: | |
| with torch.no_grad(): | |
| for _ in range(warmup): | |
| callable_obj(test_input) | |
| torch.cuda.synchronize() | |
| start = time.perf_counter() | |
| with torch.no_grad(): | |
| for _ in range(iters): | |
| callable_obj(test_input) | |
| torch.cuda.synchronize() | |
| elapsed = time.perf_counter() - start | |
| print(f"{label}: {elapsed:.4f}s") | |
| return elapsed | |
| t_class = _bench(cls_model, "class") | |
| t_func = _bench(func_entry, "function") | |
| t_inst = _bench(inst_model, "instance") | |
| t_mtd = _bench(mtd_model, "method") | |
| compiled_times = [t_class, t_func, t_inst, t_mtd] | |
| max_compiled = max(compiled_times) | |
| min_compiled = min(compiled_times) | |
| assert max_compiled / min_compiled < 1.2, ( | |
| "Magi entry timings diverged too much: " | |
| f"class={t_class:.4f}s, function={t_func:.4f}s, instance={t_inst:.4f}s, method={t_mtd:.4f}s" | |
| ) | |
Xet Storage Details
- Size:
- 19.5 kB
- Xet hash:
- 997249a78724221f3aa780cd31bb6217f3bbb20a42bde0aaa38747c8a73d1210
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.