Buckets:
| # Copyright (c) 2025 SandAI. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Test nested compile scenarios: various combinations of torch.compile and magi_compile""" | |
| import os | |
| import pytest | |
| import torch | |
| import torch.nn as nn | |
| from magi_compiler import magi_compile | |
| from magi_compiler.config import CompileMode, get_compile_config | |
| DEVICE = "cuda" | |
| HIDDEN_SIZE = 8 | |
| TOLERANCE = 1e-3 | |
| # ============ Helper Functions ============ | |
| def is_torch_compiled(module: nn.Module) -> bool: | |
| """ | |
| Check if a module is compiled by torch.compile | |
| Two ways: | |
| 1. torch.compile(instance) -> OptimizedModule | |
| 2. @torch.compile def forward -> forward has _torchdynamo_orig_callable | |
| Note: @torch.compiler.disable also sets _torchdynamo_orig_callable, | |
| but additionally sets _torchdynamo_disable=True, need to exclude | |
| """ | |
| if type(module).__name__ == "OptimizedModule": | |
| return True | |
| forward_method = type(module).forward | |
| if hasattr(forward_method, "_torchdynamo_orig_callable"): | |
| if not getattr(forward_method, "_torchdynamo_disable", False): | |
| return True | |
| return False | |
| def is_torch_disabled(module: nn.Module) -> bool: | |
| """Check if forward is decorated with @torch.compiler.disable""" | |
| return getattr(type(module).forward, "_torchdynamo_disable", False) | |
| def assert_torch_compiled(module: nn.Module, msg: str = ""): | |
| assert is_torch_compiled(module), ( | |
| f"Expected torch.compile'd. type={type(module).__name__}, " | |
| f"has _torchdynamo_orig_callable={hasattr(type(module).forward, '_torchdynamo_orig_callable')}. {msg}" | |
| ) | |
| def assert_not_torch_compiled_or_disabled(module: nn.Module, msg: str = ""): | |
| assert not is_torch_compiled(module), ( | |
| f"Expected NOT torch.compile'd. type={type(module).__name__}, " | |
| f"has _torchdynamo_orig_callable={hasattr(type(module).forward, '_torchdynamo_orig_callable')}. {msg}" | |
| ) | |
| def assert_magi_compiled(module: nn.Module, msg: str = ""): | |
| assert hasattr(module, "_magi"), f"Missing _magi (MagiCompileState). {msg}" | |
| state = module._magi | |
| assert state is not None, f"_magi is None (compilation disabled). {msg}" | |
| assert state.compiled_code is not None, f"compiled_code is None. {msg}" | |
| def assert_not_magi_compiled(module: nn.Module, msg: str = ""): | |
| if hasattr(module, "_magi") and module._magi is not None: | |
| assert module._magi.compiled_code is None, f"compiled_code should be None. {msg}" | |
| def assert_torch_disabled(module: nn.Module, msg: str = ""): | |
| assert is_torch_disabled(module), ( | |
| f"Expected @torch.compiler.disable. " | |
| f"_torchdynamo_disable={getattr(type(module).forward, '_torchdynamo_disable', False)}. {msg}" | |
| ) | |
| # ============ Fixtures ============ | |
| def set_magi_compile_mode(): | |
| """Set compile_mode=MAGI_COMPILE during tests""" | |
| config = get_compile_config() | |
| old_value = config.compile_mode | |
| config.compile_mode = CompileMode.MAGI_COMPILE | |
| config.cache_root_dir = os.environ.get("MAGI_COMPILE_CACHE_ROOT_DIR", config.cache_root_dir) | |
| print(f"set magi compile_mode: {config.compile_mode}, cache root dir: {config.cache_root_dir}") | |
| yield | |
| config.compile_mode = old_value | |
| # ============ torch.compile Nested Behavior ============ | |
| def test_torch_compile_nested(): | |
| """torch.compile nested: inner compiled OptimizedModule as opaque node""" | |
| class InnerBlock(nn.Module): | |
| def __init__(self, hidden_size): | |
| super().__init__() | |
| self.linear = nn.Linear(hidden_size, hidden_size) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| return torch.relu(self.linear(x)) | |
| class OuterModel(nn.Module): | |
| def __init__(self, hidden_size): | |
| super().__init__() | |
| self.inner = InnerBlock(hidden_size) | |
| self.output = nn.Linear(hidden_size, hidden_size) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| x = self.inner(x) | |
| return self.output(x) | |
| model = OuterModel(HIDDEN_SIZE).to(DEVICE) | |
| x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE) | |
| with torch.no_grad(): | |
| baseline = model(x) | |
| model.inner = torch.compile(model.inner, fullgraph=False, dynamic=True) | |
| assert_torch_compiled(model.inner) | |
| with torch.no_grad(): | |
| inner_compiled_out = model(x) | |
| assert torch.allclose(baseline, inner_compiled_out, atol=TOLERANCE, rtol=TOLERANCE) | |
| compiled_model = torch.compile(model, fullgraph=False, dynamic=True) | |
| assert_torch_compiled(compiled_model) | |
| assert_torch_compiled(compiled_model.inner) | |
| with torch.no_grad(): | |
| nested_out = compiled_model(x) | |
| assert torch.allclose(baseline, nested_out, atol=TOLERANCE, rtol=TOLERANCE) | |
| def test_torch_compile_with_disable_inner(): | |
| """torch.compile + @torch.compiler.disable: disabled functions cause graph break""" | |
| class InnerBlock(nn.Module): | |
| def __init__(self, hidden_size): | |
| super().__init__() | |
| self.linear = nn.Linear(hidden_size, hidden_size) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| return torch.relu(self.linear(x)) | |
| class OuterModel(nn.Module): | |
| def __init__(self, hidden_size): | |
| super().__init__() | |
| self.inner = InnerBlock(hidden_size) | |
| self.output = nn.Linear(hidden_size, hidden_size) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| x = self.inner(x) | |
| return self.output(x) | |
| model = OuterModel(HIDDEN_SIZE).to(DEVICE) | |
| x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE) | |
| with torch.no_grad(): | |
| baseline = model(x) | |
| compiled_model = torch.compile(model, fullgraph=False, dynamic=True) | |
| assert_torch_compiled(compiled_model) | |
| assert_not_torch_compiled_or_disabled(compiled_model.inner) | |
| with torch.no_grad(): | |
| compiled_out = compiled_model(x) | |
| assert torch.allclose(baseline, compiled_out, atol=TOLERANCE, rtol=TOLERANCE) | |
| # ============ torch.compile + magi_compile Nested ============ | |
| def test_nested_torch_compile_magi_compile(): | |
| """Outer torch.compile + inner magi_compile""" | |
| class InnerMagiBlock(nn.Module): | |
| def __init__(self, hidden_size: int): | |
| super().__init__() | |
| self.linear1 = nn.Linear(hidden_size, hidden_size * 4) | |
| self.linear2 = nn.Linear(hidden_size * 4, hidden_size) | |
| self.norm = nn.LayerNorm(hidden_size) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| residual = x | |
| x = self.norm(x) | |
| x = self.linear1(x) | |
| x = torch.nn.functional.gelu(x) | |
| x = self.linear2(x) | |
| return x + residual | |
| class OuterModel(nn.Module): | |
| def __init__(self, hidden_size: int, num_layers: int = 2): | |
| super().__init__() | |
| self.embed = nn.Linear(hidden_size, hidden_size) | |
| self.blocks = nn.ModuleList([InnerMagiBlock(hidden_size) for _ in range(num_layers)]) | |
| self.output = nn.Linear(hidden_size, hidden_size) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| x = self.embed(x) | |
| for block in self.blocks: | |
| x = block(x) | |
| return self.output(x) | |
| num_layers = 2 | |
| model = OuterModel(HIDDEN_SIZE, num_layers=num_layers).to(DEVICE) | |
| x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE) | |
| for i, block in enumerate(model.blocks): | |
| assert hasattr(block, "_magi_compiled"), f"block {i} should be @magi_compile'd" | |
| assert block._magi_compiled is True | |
| with torch.no_grad(): | |
| baseline = model(x) | |
| for i, block in enumerate(model.blocks): | |
| assert_magi_compiled(block) | |
| assert_not_torch_compiled_or_disabled(block) | |
| compiled_model = torch.compile(model, fullgraph=False, dynamic=True) | |
| assert_torch_compiled(compiled_model) | |
| assert compiled_model._orig_mod is model | |
| with torch.no_grad(): | |
| compiled_out = compiled_model(x) | |
| for i, block in enumerate(model.blocks): | |
| assert block._magi_compiled is True | |
| assert_magi_compiled(block) | |
| assert torch.allclose(baseline, compiled_out, atol=TOLERANCE, rtol=TOLERANCE) | |
| def test_nested_torch_compile_multiple_magi_compile(): | |
| """Outer torch.compile with multiple magi_compile modules""" | |
| class MagiBlock1(nn.Module): | |
| def __init__(self, hidden_size): | |
| super().__init__() | |
| self.linear = nn.Linear(hidden_size, hidden_size) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| return torch.relu(self.linear(x)) | |
| class MagiBlock2(nn.Module): | |
| def __init__(self, hidden_size): | |
| super().__init__() | |
| self.linear = nn.Linear(hidden_size, hidden_size) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| return torch.relu(self.linear(x)) | |
| class OuterModel(nn.Module): | |
| def __init__(self, hidden_size): | |
| super().__init__() | |
| self.block1 = MagiBlock1(hidden_size) | |
| self.block2 = MagiBlock2(hidden_size) | |
| self.output = nn.Linear(hidden_size, hidden_size) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| x = self.block1(x) | |
| x = self.block2(x) | |
| return self.output(x) | |
| model = OuterModel(HIDDEN_SIZE).to(DEVICE) | |
| x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE) | |
| with torch.no_grad(): | |
| baseline = model(x) | |
| assert_magi_compiled(model.block1) | |
| assert_magi_compiled(model.block2) | |
| assert_not_torch_compiled_or_disabled(model.block1) | |
| assert_not_torch_compiled_or_disabled(model.block2) | |
| compiled_model = torch.compile(model, fullgraph=False, dynamic=True) | |
| assert_torch_compiled(compiled_model) | |
| assert_not_torch_compiled_or_disabled(model.block1) | |
| assert_not_torch_compiled_or_disabled(model.block2) | |
| with torch.no_grad(): | |
| compiled_out = compiled_model(x) | |
| assert torch.allclose(baseline, compiled_out, atol=TOLERANCE, rtol=TOLERANCE) | |
| # ============ torch.compile Decorator + magi_compile Nested ============ | |
| def test_decorator_torch_compile_on_forward(): | |
| """@torch.compile decorate forward: module type unchanged, but is_torch_compiled returns True""" | |
| class MyModel(nn.Module): | |
| def __init__(self, hidden_size): | |
| super().__init__() | |
| self.linear = nn.Linear(hidden_size, hidden_size) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| return torch.relu(self.linear(x)) | |
| model = MyModel(HIDDEN_SIZE).to(DEVICE) | |
| x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE) | |
| # eager baseline | |
| with torch.no_grad(): | |
| baseline = model(x) | |
| # Create version with @torch.compile forward | |
| class CompiledModel(nn.Module): | |
| def __init__(self, hidden_size): | |
| super().__init__() | |
| self.linear = nn.Linear(hidden_size, hidden_size) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| return torch.relu(self.linear(x)) | |
| compiled_model = CompiledModel(HIDDEN_SIZE).to(DEVICE) | |
| compiled_model.load_state_dict(model.state_dict()) | |
| assert type(compiled_model).__name__ == "CompiledModel" | |
| assert_torch_compiled(compiled_model) | |
| with torch.no_grad(): | |
| out = compiled_model(x) | |
| assert torch.allclose(baseline, out, atol=TOLERANCE, rtol=TOLERANCE) | |
| def test_decorator_nested_torch_compile_forward_magi_inner(): | |
| """Outer forward @torch.compile + inner @magi_compile""" | |
| class InnerBlock(nn.Module): | |
| def __init__(self, hidden_size): | |
| super().__init__() | |
| self.linear = nn.Linear(hidden_size, hidden_size) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| return torch.relu(self.linear(x)) | |
| class OuterModel(nn.Module): | |
| def __init__(self, hidden_size): | |
| super().__init__() | |
| self.inner = InnerBlock(hidden_size) | |
| self.output = nn.Linear(hidden_size, hidden_size) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| x = self.inner(x) | |
| return self.output(x) | |
| model = OuterModel(HIDDEN_SIZE).to(DEVICE) | |
| x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE) | |
| # eager baseline | |
| with torch.no_grad(): | |
| baseline = model(x) | |
| # Create magi inner + torch.compile forward outer version | |
| MagiInnerBlock = magi_compile()(InnerBlock) | |
| class CompiledOuterModel(nn.Module): | |
| def __init__(self, hidden_size): | |
| super().__init__() | |
| self.inner = MagiInnerBlock(hidden_size) | |
| self.output = nn.Linear(hidden_size, hidden_size) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| x = self.inner(x) | |
| return self.output(x) | |
| compiled_model = CompiledOuterModel(HIDDEN_SIZE).to(DEVICE) | |
| compiled_model.load_state_dict(model.state_dict()) | |
| assert_torch_compiled(compiled_model) | |
| with torch.no_grad(): | |
| out = compiled_model(x) | |
| assert_magi_compiled(compiled_model.inner) | |
| assert_not_torch_compiled_or_disabled(compiled_model.inner) | |
| assert torch.allclose(baseline, out, atol=TOLERANCE, rtol=TOLERANCE) | |
| # ============ torch._dynamo.config Correctness Verification ============ | |
| def test_dynamo_config_nested_patch_restore(): | |
| """Verify config.patch() correctly restores to previous value when nested""" | |
| import torch._dynamo.config as config | |
| # Record initial value | |
| initial_value = config.assume_static_by_default | |
| # Simulate outer compile setting dynamic=True (assume_static_by_default=False) | |
| with config.patch(assume_static_by_default=False): | |
| assert config.assume_static_by_default is False, "outer patch should set value to False" | |
| # Simulate inner magi_compile restoring default (assume_static_by_default=True) | |
| with config.patch(assume_static_by_default=True): | |
| assert config.assume_static_by_default is True, "inner patch should set value to True" | |
| # After inner exits, should restore to outer value | |
| assert config.assume_static_by_default is False, "after inner exit should restore to outer value False" | |
| # After outer exits, should restore to initial value | |
| assert ( | |
| config.assume_static_by_default == initial_value | |
| ), f"after outer exit should restore to initial value {initial_value}" | |
| def test_dynamo_config_multiple_options_patch(): | |
| """Verify correctness when patching multiple config options simultaneously""" | |
| import torch._dynamo.config as config | |
| # Record initial values | |
| initial_assume_static = config.assume_static_by_default | |
| initial_suppress_errors = config.suppress_errors | |
| initial_verbose = config.verbose | |
| # Patch multiple config options simultaneously | |
| with config.patch( | |
| assume_static_by_default=not initial_assume_static, | |
| suppress_errors=not initial_suppress_errors, | |
| verbose=not initial_verbose, | |
| ): | |
| # Verify all config options are modified | |
| assert config.assume_static_by_default == (not initial_assume_static), "assume_static_by_default should be modified" | |
| assert config.suppress_errors == (not initial_suppress_errors), "suppress_errors should be modified" | |
| assert config.verbose == (not initial_verbose), "verbose should be modified" | |
| # Nested patch for partial config options | |
| with config.patch(assume_static_by_default=initial_assume_static): | |
| assert config.assume_static_by_default == initial_assume_static, "inner should restore assume_static_by_default" | |
| # Other config options should keep outer patch values | |
| assert config.suppress_errors == (not initial_suppress_errors), "suppress_errors should keep outer value" | |
| assert config.verbose == (not initial_verbose), "verbose should keep outer value" | |
| # After inner exits, assume_static_by_default should restore to outer patch value | |
| assert config.assume_static_by_default == ( | |
| not initial_assume_static | |
| ), "after inner exit should restore to outer patch value" | |
| # After outer exits, all config options should restore to initial values | |
| assert config.assume_static_by_default == initial_assume_static, "after outer exit assume_static_by_default should restore" | |
| assert config.suppress_errors == initial_suppress_errors, "after outer exit suppress_errors should restore" | |
| assert config.verbose == initial_verbose, "after outer exit verbose should restore" | |
| def test_dynamo_config_restore_on_exception(): | |
| """Verify config correctly restores when exception is raised inside with block""" | |
| import torch._dynamo.config as config | |
| # Record initial value | |
| initial_value = config.assume_static_by_default | |
| # Test single-layer patch restore on exception | |
| try: | |
| with config.patch(assume_static_by_default=not initial_value): | |
| assert config.assume_static_by_default == (not initial_value), "patch should take effect" | |
| raise RuntimeError("test exception") | |
| except RuntimeError: | |
| pass | |
| # After exception config should restore | |
| assert config.assume_static_by_default == initial_value, "after single exception should restore to initial value" | |
| # Test nested patch restore on inner exception | |
| try: | |
| with config.patch(assume_static_by_default=False): | |
| assert config.assume_static_by_default is False, "outer patch should take effect" | |
| try: | |
| with config.patch(assume_static_by_default=True): | |
| assert config.assume_static_by_default is True, "inner patch should take effect" | |
| raise ValueError("inner test exception") | |
| except ValueError: | |
| pass | |
| # After inner exception caught, should restore to outer value | |
| assert config.assume_static_by_default is False, "after inner exception should restore to outer value" | |
| except Exception: | |
| pytest.fail("outer should not catch exception") | |
| # Final should restore to initial value | |
| assert config.assume_static_by_default == initial_value, "after nested exception should restore to initial value" | |
| if __name__ == "__main__": | |
| pytest.main([__file__, "-v", "-s"]) | |
Xet Storage Details
- Size:
- 19 kB
- Xet hash:
- 5afd727abb325d7a2f4ba52384483276124a4548dde16950ced4047e45dc70ec
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.