# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for nested compile scenarios: combinations of torch.compile and magi_compile."""

import os

import pytest
import torch
import torch.nn as nn
from magi_compiler import magi_compile
from magi_compiler.config import CompileMode, get_compile_config

DEVICE = "cuda"
HIDDEN_SIZE = 64
TOLERANCE = 1e-3


# ============ Helpers ============


def is_torch_compiled(module: nn.Module) -> bool:
    """Return True if ``module`` looks like it was compiled with torch.compile.

    Two forms are recognised:

    1. ``torch.compile(instance)``: the instance's type is named ``OptimizedModule``.
    2. ``@torch.compile`` on ``forward``: the class-level ``forward`` carries
       ``_torchdynamo_orig_callable``.

    Note: ``@torch.compiler.disable`` also sets ``_torchdynamo_orig_callable`` but
    additionally sets ``_torchdynamo_disable=True``; that case is excluded here.
    """
    if type(module).__name__ == "OptimizedModule":
        return True

    forward_method = type(module).forward
    has_orig_callable = hasattr(forward_method, "_torchdynamo_orig_callable")
    is_disabled = getattr(forward_method, "_torchdynamo_disable", False)
    return bool(has_orig_callable and not is_disabled)


def is_torch_disabled(module: nn.Module) -> bool:
    """Return True if the class-level ``forward`` is marked with ``_torchdynamo_disable``."""
    # bool() keeps the return value consistent with the annotation even if the
    # attribute holds a truthy non-bool object.
    return bool(getattr(type(module).forward, "_torchdynamo_disable", False))


def assert_torch_compiled(module: nn.Module, msg: str = "") -> None:
    """Assert that ``is_torch_compiled(module)`` holds; ``msg`` is appended to the failure text."""
    assert is_torch_compiled(module), (
        f"Expected torch.compile'd. type={type(module).__name__}, "
        f"has _torchdynamo_orig_callable={hasattr(type(module).forward, '_torchdynamo_orig_callable')}. {msg}"
    )


def assert_not_torch_compiled_or_disabled(module: nn.Module, msg: str = "") -> None:
    """Assert that ``is_torch_compiled(module)`` is False.

    A module whose ``forward`` is marked ``_torchdynamo_disable`` also passes, because
    ``is_torch_compiled`` excludes that case.
    """
    assert not is_torch_compiled(module), (
        f"Expected NOT torch.compile'd. type={type(module).__name__}, "
        f"has _torchdynamo_orig_callable={hasattr(type(module).forward, '_torchdynamo_orig_callable')}. {msg}"
    )


def assert_magi_compiled(module: nn.Module, msg: str = "") -> None:
    """Assert that ``module`` has a non-None ``compiled_code`` attribute."""
    assert hasattr(module, "compiled_code"), f"Missing compiled_code. {msg}"
    assert module.compiled_code is not None, f"compiled_code is None. {msg}"


def assert_not_magi_compiled(module: nn.Module, msg: str = "") -> None:
    """Assert that ``module`` either lacks ``compiled_code`` or has it set to None."""
    if hasattr(module, "compiled_code"):
        assert module.compiled_code is None, f"compiled_code should be None. {msg}"


def assert_torch_disabled(module: nn.Module, msg: str = "") -> None:
    """Assert that ``is_torch_disabled(module)`` holds; ``msg`` is appended to the failure text."""
    assert is_torch_disabled(module), (
        f"Expected @torch.compiler.disable. "
        f"_torchdynamo_disable={getattr(type(module).forward, '_torchdynamo_disable', False)}. {msg}"
    )


# ============ Fixtures ============


@pytest.fixture(autouse=True)
def set_magi_compile_mode():
    """Run every test with compile_mode=MAGI_COMPILE and restore the config afterwards.

    ``cache_root_dir`` is overridden from the ``MAGI_COMPILE_CACHE_ROOT_DIR`` environment
    variable when it is set. Both overridden fields are put back after the test so the
    override does not leak past this fixture.
    """
    config = get_compile_config()
    old_compile_mode = config.compile_mode
    old_cache_root_dir = config.cache_root_dir

    config.compile_mode = CompileMode.MAGI_COMPILE
    config.cache_root_dir = os.environ.get("MAGI_COMPILE_CACHE_ROOT_DIR", config.cache_root_dir)
    print(f"set magi compile mode: {config.compile_mode}, cache root dir: {config.cache_root_dir}")

    yield

    # Teardown: restore every field this fixture modified.
    config.compile_mode = old_compile_mode
    config.cache_root_dir = old_cache_root_dir


# ============ Nested torch.compile behaviour ============


def test_torch_compile_nested():
    """Nested torch.compile: the already-compiled inner OptimizedModule acts as an opaque node."""

    class InnerBlock(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    class OuterModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.inner = InnerBlock(hidden_size)
            self.output = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.inner(x)
            return self.output(x)

    model = OuterModel(HIDDEN_SIZE).to(DEVICE)
    x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)

    with torch.no_grad():
        baseline = model(x)

    # Compile only the inner module first and compare with the eager baseline.
    model.inner = torch.compile(model.inner, fullgraph=False, dynamic=True)
    assert_torch_compiled(model.inner)

    with torch.no_grad():
        inner_compiled_out = model(x)
    assert torch.allclose(baseline, inner_compiled_out, atol=TOLERANCE, rtol=TOLERANCE)

    # Then compile the outer module on top of the already-compiled inner one.
    compiled_model = torch.compile(model, fullgraph=False, dynamic=True)
    assert_torch_compiled(compiled_model)
    assert_torch_compiled(compiled_model.inner)

    with torch.no_grad():
        nested_out = compiled_model(x)
    assert torch.allclose(baseline, nested_out, atol=TOLERANCE, rtol=TOLERANCE)


def test_torch_compile_with_disable_inner():
    """torch.compile + @torch.compiler.disable: the disabled function produces a graph break."""

    class InnerBlock(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        @torch.compiler.disable
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    class OuterModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.inner = InnerBlock(hidden_size)
            self.output = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.inner(x)
            return self.output(x)

    model = OuterModel(HIDDEN_SIZE).to(DEVICE)
    x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)

    with torch.no_grad():
        baseline = model(x)

    compiled_model = torch.compile(model, fullgraph=False, dynamic=True)
    assert_torch_compiled(compiled_model)
    assert_not_torch_compiled_or_disabled(compiled_model.inner)

    with torch.no_grad():
        compiled_out = compiled_model(x)
    assert torch.allclose(baseline, compiled_out, atol=TOLERANCE, rtol=TOLERANCE)


# ============ torch.compile + magi_compile nesting ============


def test_nested_torch_compile_magi_compile():
    """Outer torch.compile + inner magi_compile."""

    @magi_compile()
    class InnerMagiBlock(nn.Module):
        def __init__(self, hidden_size: int):
            super().__init__()
            self.linear1 = nn.Linear(hidden_size, hidden_size * 4)
            self.linear2 = nn.Linear(hidden_size * 4, hidden_size)
            self.norm = nn.LayerNorm(hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            residual = x
            x = self.norm(x)
            x = self.linear1(x)
            x = torch.nn.functional.gelu(x)
            x = self.linear2(x)
            return x + residual

    class OuterModel(nn.Module):
        def __init__(self, hidden_size: int, num_layers: int = 2):
            super().__init__()
            self.embed = nn.Linear(hidden_size, hidden_size)
            self.blocks = nn.ModuleList([InnerMagiBlock(hidden_size) for _ in range(num_layers)])
            self.output = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.embed(x)
            for block in self.blocks:
                x = block(x)
            return self.output(x)

    num_layers = 2
    model = OuterModel(HIDDEN_SIZE, num_layers=num_layers).to(DEVICE)
    x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)

    for block in model.blocks:
        assert hasattr(block, "enable_compile")
        assert block.enable_compile is True

    with torch.no_grad():
        baseline = model(x)

    # Checked only after the first forward pass above has run.
    for block in model.blocks:
        assert_magi_compiled(block)
        assert_not_torch_compiled_or_disabled(block)

    compiled_model = torch.compile(model, fullgraph=False, dynamic=True)
    assert_torch_compiled(compiled_model)
    assert compiled_model._orig_mod is model

    with torch.no_grad():
        compiled_out = compiled_model(x)

    # The inner blocks must still report as magi-compiled after the outer compile ran.
    for block in model.blocks:
        assert block.enable_compile is True
        assert_magi_compiled(block)

    assert torch.allclose(baseline, compiled_out, atol=TOLERANCE, rtol=TOLERANCE)


def test_nested_torch_compile_multiple_magi_compile():
    """Outer torch.compile containing several magi_compile modules."""

    @magi_compile()
    class MagiBlock1(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    @magi_compile()
    class MagiBlock2(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    class OuterModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.block1 = MagiBlock1(hidden_size)
            self.block2 = MagiBlock2(hidden_size)
            self.output = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.block1(x)
            x = self.block2(x)
            return self.output(x)

    model = OuterModel(HIDDEN_SIZE).to(DEVICE)
    x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)

    with torch.no_grad():
        baseline = model(x)

    assert_magi_compiled(model.block1)
    assert_magi_compiled(model.block2)
    assert_not_torch_compiled_or_disabled(model.block1)
    assert_not_torch_compiled_or_disabled(model.block2)

    compiled_model = torch.compile(model, fullgraph=False, dynamic=True)
    assert_torch_compiled(compiled_model)
    assert_not_torch_compiled_or_disabled(model.block1)
    assert_not_torch_compiled_or_disabled(model.block2)

    with torch.no_grad():
        compiled_out = compiled_model(x)
    assert torch.allclose(baseline, compiled_out, atol=TOLERANCE, rtol=TOLERANCE)


# ============ torch.compile as a decorator + magi_compile nesting ============


def test_decorator_torch_compile_on_forward():
    """@torch.compile on forward: the module type is unchanged but is_torch_compiled returns True."""

    class MyModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    model = MyModel(HIDDEN_SIZE).to(DEVICE)
    x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)

    # Eager baseline.
    with torch.no_grad():
        baseline = model(x)

    # Same architecture, but with forward decorated by @torch.compile.
    class CompiledModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        @torch.compile(fullgraph=False, dynamic=True)
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    compiled_model = CompiledModel(HIDDEN_SIZE).to(DEVICE)
    # Share weights with the eager model so the outputs are comparable.
    compiled_model.load_state_dict(model.state_dict())

    assert type(compiled_model).__name__ == "CompiledModel"
    assert_torch_compiled(compiled_model)

    with torch.no_grad():
        out = compiled_model(x)
    assert torch.allclose(baseline, out, atol=TOLERANCE, rtol=TOLERANCE)


def test_decorator_nested_torch_compile_forward_magi_inner():
    """Outer forward decorated with @torch.compile + inner @magi_compile module."""

    class InnerBlock(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    class OuterModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.inner = InnerBlock(hidden_size)
            self.output = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.inner(x)
            return self.output(x)

    model = OuterModel(HIDDEN_SIZE).to(DEVICE)
    x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)

    # Eager baseline.
    with torch.no_grad():
        baseline = model(x)

    # Build the variant: magi-compiled inner block + @torch.compile on the outer forward.
    MagiInnerBlock = magi_compile()(InnerBlock)

    class CompiledOuterModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.inner = MagiInnerBlock(hidden_size)
            self.output = nn.Linear(hidden_size, hidden_size)

        @torch.compile(fullgraph=False, dynamic=True)
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.inner(x)
            return self.output(x)

    compiled_model = CompiledOuterModel(HIDDEN_SIZE).to(DEVICE)
    # Share weights with the eager model so the outputs are comparable.
    compiled_model.load_state_dict(model.state_dict())

    assert_torch_compiled(compiled_model)

    with torch.no_grad():
        out = compiled_model(x)

    assert_magi_compiled(compiled_model.inner)
    assert_not_torch_compiled_or_disabled(compiled_model.inner)
    assert torch.allclose(baseline, out, atol=TOLERANCE, rtol=TOLERANCE)


# ============ torch._dynamo.config correctness checks ============


def test_dynamo_config_nested_patch_restore():
    """Nested config.patch() calls restore the value of the enclosing level on exit."""
    import torch._dynamo.config as config

    # Record the initial value.
    initial_value = config.assume_static_by_default

    # Simulate an outer compile with dynamic=True (assume_static_by_default=False).
    with config.patch(assume_static_by_default=False):
        assert config.assume_static_by_default is False, "外层 patch 应将值设为 False"

        # Simulate an inner magi_compile restoring the default (assume_static_by_default=True).
        with config.patch(assume_static_by_default=True):
            assert config.assume_static_by_default is True, "内层 patch 应将值设为 True"

        # After the inner patch exits, the outer value must be back.
        assert config.assume_static_by_default is False, "内层退出后应恢复到外层值 False"

    # After the outer patch exits, the initial value must be back.
    assert config.assume_static_by_default == initial_value, f"外层退出后应恢复到初始值 {initial_value}"


def test_dynamo_config_multiple_options_patch():
    """Patching several options at once, then re-patching a subset, restores correctly."""
    import torch._dynamo.config as config

    # Record the initial values.
    initial_assume_static = config.assume_static_by_default
    initial_suppress_errors = config.suppress_errors
    initial_verbose = config.verbose

    # Patch several options at the same time.
    with config.patch(
        assume_static_by_default=not initial_assume_static,
        suppress_errors=not initial_suppress_errors,
        verbose=not initial_verbose,
    ):
        # All options must have been flipped.
        assert config.assume_static_by_default == (not initial_assume_static), "assume_static_by_default 应被修改"
        assert config.suppress_errors == (not initial_suppress_errors), "suppress_errors 应被修改"
        assert config.verbose == (not initial_verbose), "verbose 应被修改"

        # Nested patch touching only one of the options.
        with config.patch(assume_static_by_default=initial_assume_static):
            assert config.assume_static_by_default == initial_assume_static, "内层应恢复 assume_static_by_default"
            # The other options must keep the outer patch values.
            assert config.suppress_errors == (not initial_suppress_errors), "suppress_errors 应保持外层值"
            assert config.verbose == (not initial_verbose), "verbose 应保持外层值"

        # After the inner patch exits, assume_static_by_default returns to the outer patch value.
        assert config.assume_static_by_default == (not initial_assume_static), "内层退出后应恢复到外层 patch 值"

    # After the outer patch exits, every option must be back at its initial value.
    assert config.assume_static_by_default == initial_assume_static, "外层退出后 assume_static_by_default 应恢复"
    assert config.suppress_errors == initial_suppress_errors, "外层退出后 suppress_errors 应恢复"
    assert config.verbose == initial_verbose, "外层退出后 verbose 应恢复"


def test_dynamo_config_restore_on_exception():
    """config.patch() restores values when an exception is raised inside the with block."""
    import torch._dynamo.config as config

    # Record the initial value.
    initial_value = config.assume_static_by_default

    # Single-level patch: the value must be restored when the body raises.
    with pytest.raises(RuntimeError):
        with config.patch(assume_static_by_default=not initial_value):
            assert config.assume_static_by_default == (not initial_value), "patch 应生效"
            raise RuntimeError("测试异常")

    # The config must be restored after the exception.
    assert config.assume_static_by_default == initial_value, "单层异常后应恢复到初始值"

    # Nested patch: the inner body raises, the outer level must stay intact.
    try:
        with config.patch(assume_static_by_default=False):
            assert config.assume_static_by_default is False, "外层 patch 应生效"

            with pytest.raises(ValueError):
                with config.patch(assume_static_by_default=True):
                    assert config.assume_static_by_default is True, "内层 patch 应生效"
                    raise ValueError("内层测试异常")

            # Once the inner exception is handled, the outer value must be back.
            assert config.assume_static_by_default is False, "内层异常后应恢复到外层值"
    except AssertionError:
        # Let failed assertions surface with their own message instead of being
        # replaced by the generic failure below.
        raise
    except Exception:
        pytest.fail("外层不应捕获到异常")

    # Finally the initial value must be back.
    assert config.assume_static_by_default == initial_value, "嵌套异常后应恢复到初始值"


if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])