Spaces:
Runtime error
Runtime error
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """测试嵌套 compile 场景:torch.compile 与 magi_compile 的各种组合""" | |
import os

import pytest
import torch
import torch.nn as nn

from magi_compiler import magi_compile
from magi_compiler.config import CompileMode, get_compile_config
# Shared settings for the model tests below: the device that models and inputs
# are placed on, the hidden width of every toy layer, and the atol/rtol used
# when comparing a compiled output against its reference.
DEVICE: str = "cuda"
HIDDEN_SIZE: int = 64
TOLERANCE: float = 1e-3
# ============ Helpers ============
def is_torch_compiled(module: nn.Module) -> bool:
    """Report whether ``module`` has been compiled with torch.compile.

    Two forms are recognised:

    1. ``torch.compile(instance)`` -- the returned wrapper's class is named
       ``OptimizedModule`` (matched by class name, not ``isinstance``).
    2. ``@torch.compile`` on ``forward`` -- the class-level ``forward`` carries
       a ``_torchdynamo_orig_callable`` attribute.

    ``@torch.compiler.disable`` also attaches ``_torchdynamo_orig_callable``, but
    additionally sets ``_torchdynamo_disable=True``; such forwards are reported
    as *not* compiled. Only the class attribute ``type(module).forward`` is
    inspected, so a ``forward`` overridden on the instance is not seen.
    """
    cls = type(module)
    if cls.__name__ == "OptimizedModule":
        return True
    forward_fn = cls.forward
    is_dynamo_wrapped = hasattr(forward_fn, "_torchdynamo_orig_callable")
    # Short-circuits exactly like the nested-if form: the disable flag is only
    # consulted when the dynamo marker is present.
    return is_dynamo_wrapped and not getattr(forward_fn, "_torchdynamo_disable", False)
def is_torch_disabled(module: nn.Module) -> bool:
    """Return the ``_torchdynamo_disable`` marker of the class-level ``forward``.

    ``@torch.compiler.disable`` sets this marker to ``True``; when the attribute
    is absent the result is ``False``.
    """
    forward_fn = type(module).forward
    try:
        return forward_fn._torchdynamo_disable
    except AttributeError:
        return False
def assert_torch_compiled(module: nn.Module, msg: str = ""):
    """Assert that ``is_torch_compiled(module)`` holds.

    Args:
        module: module (or OptimizedModule wrapper) to inspect.
        msg: extra context appended to the failure message.
    """
    compiled = is_torch_compiled(module)
    cls = type(module)
    assert compiled, (
        f"Expected torch.compile'd. type={cls.__name__}, "
        f"has _torchdynamo_orig_callable={hasattr(cls.forward, '_torchdynamo_orig_callable')}. {msg}"
    )
def assert_not_torch_compiled_or_disabled(module: nn.Module, msg: str = ""):
    """Assert that ``module`` is not torch.compile'd.

    A ``forward`` wrapped by ``@torch.compiler.disable`` also passes, because
    ``is_torch_compiled`` treats disabled forwards as not compiled -- hence the
    "or disabled" in the name.

    Args:
        module: module to inspect.
        msg: extra context appended to the failure message.
    """
    compiled = is_torch_compiled(module)
    cls = type(module)
    assert not compiled, (
        f"Expected NOT torch.compile'd. type={cls.__name__}, "
        f"has _torchdynamo_orig_callable={hasattr(cls.forward, '_torchdynamo_orig_callable')}. {msg}"
    )
def assert_magi_compiled(module: nn.Module, msg: str = ""):
    """Assert that ``module`` exposes a non-None ``compiled_code`` attribute.

    ``compiled_code`` is the attribute these tests use as the marker of a
    magi-compiled module; ``msg`` is appended to either failure message.
    """
    assert hasattr(module, "compiled_code"), f"Missing compiled_code. {msg}"
    code = module.compiled_code
    assert code is not None, f"compiled_code is None. {msg}"
def assert_not_magi_compiled(module: nn.Module, msg: str = ""):
    """Assert that ``module`` has no magi compiled code.

    Passes when ``compiled_code`` is absent or ``None``; ``msg`` is appended to
    the failure message.
    """
    if not hasattr(module, "compiled_code"):
        return
    assert module.compiled_code is None, f"compiled_code should be None. {msg}"
def assert_torch_disabled(module: nn.Module, msg: str = ""):
    """Assert that the class-level ``forward`` is wrapped by ``@torch.compiler.disable``.

    Args:
        module: module to inspect.
        msg: extra context appended to the failure message.
    """
    disabled = is_torch_disabled(module)
    forward_fn = type(module).forward
    assert disabled, (
        f"Expected @torch.compiler.disable. "
        f"_torchdynamo_disable={getattr(forward_fn, '_torchdynamo_disable', False)}. {msg}"
    )
# ============ Fixtures ============
@pytest.fixture(autouse=True)
def set_magi_compile_mode():
    """Run each test in this module with ``compile_mode=MAGI_COMPILE``.

    When ``MAGI_COMPILE_CACHE_ROOT_DIR`` is set in the environment it also
    becomes the ``cache_root_dir``. No test requests this fixture by name, so it
    must be ``autouse`` to take effect. Both overridden settings are restored on
    teardown so the shared compile config does not leak into other tests.
    """
    config = get_compile_config()
    old_compile_mode = config.compile_mode
    old_cache_root_dir = config.cache_root_dir

    config.compile_mode = CompileMode.MAGI_COMPILE
    config.cache_root_dir = os.environ.get("MAGI_COMPILE_CACHE_ROOT_DIR", config.cache_root_dir)
    print(f"set magi compile mode: {config.compile_mode}, cache root dir: {config.cache_root_dir}")
    try:
        yield
    finally:
        # Restore every field this fixture changed, not just compile_mode.
        config.compile_mode = old_compile_mode
        config.cache_root_dir = old_cache_root_dir
# ============ torch.compile nesting behaviour ============
def test_torch_compile_nested():
    """torch.compile on a model whose submodule is already torch.compile'd.

    Two stages are checked against the eager output: compiling only the inner
    block in place, then additionally compiling the whole model on top of it.
    Both the outer wrapper and the inner submodule must report as compiled.
    """

    class InnerBlock(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    class OuterModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.inner = InnerBlock(hidden_size)
            self.output = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            hidden = self.inner(x)
            return self.output(hidden)

    def run(module):
        # Inference-only forward on the shared input.
        with torch.no_grad():
            return module(inputs)

    def matches_reference(out):
        return torch.allclose(reference, out, atol=TOLERANCE, rtol=TOLERANCE)

    model = OuterModel(HIDDEN_SIZE).to(DEVICE)
    inputs = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)
    reference = run(model)

    # Stage 1: compile just the inner block, in place.
    model.inner = torch.compile(model.inner, fullgraph=False, dynamic=True)
    assert_torch_compiled(model.inner)
    assert matches_reference(run(model))

    # Stage 2: compile the outer model around the already-compiled inner block.
    compiled_model = torch.compile(model, fullgraph=False, dynamic=True)
    assert_torch_compiled(compiled_model)
    assert_torch_compiled(compiled_model.inner)
    assert matches_reference(run(compiled_model))
def test_torch_compile_with_disable_inner():
    """torch.compile over a model whose inner ``forward`` uses ``@torch.compiler.disable``.

    The disabled forward is left out of the outer graph (a graph break). The
    outer wrapper must report as compiled, and the inner module as disabled,
    not compiled. The output must still match eager.
    """

    class InnerBlock(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        @torch.compiler.disable
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    class OuterModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.inner = InnerBlock(hidden_size)
            self.output = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.inner(x)
            return self.output(x)

    model = OuterModel(HIDDEN_SIZE).to(DEVICE)
    x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)
    with torch.no_grad():
        baseline = model(x)

    compiled_model = torch.compile(model, fullgraph=False, dynamic=True)
    assert_torch_compiled(compiled_model)
    # The inner forward carries the disable marker, so it must count as
    # disabled and therefore as not compiled.
    assert_torch_disabled(compiled_model.inner)
    assert_not_torch_compiled_or_disabled(compiled_model.inner)

    with torch.no_grad():
        compiled_out = compiled_model(x)
    assert torch.allclose(baseline, compiled_out, atol=TOLERANCE, rtol=TOLERANCE)
# ============ torch.compile + magi_compile nesting ============
def test_nested_torch_compile_magi_compile():
    """Outer torch.compile around a model whose repeated blocks are ``@magi_compile()``'d.

    Every block must keep its magi state (``enable_compile`` and ``compiled_code``)
    after the parent is wrapped in torch.compile and run. The nested output must
    match the output of the same model without the outer torch.compile.
    """

    @magi_compile()
    class InnerMagiBlock(nn.Module):
        def __init__(self, hidden_size: int):
            super().__init__()
            self.linear1 = nn.Linear(hidden_size, hidden_size * 4)
            self.linear2 = nn.Linear(hidden_size * 4, hidden_size)
            self.norm = nn.LayerNorm(hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            residual = x
            x = self.norm(x)
            x = self.linear1(x)
            x = torch.nn.functional.gelu(x)
            x = self.linear2(x)
            return x + residual

    class OuterModel(nn.Module):
        def __init__(self, hidden_size: int, num_layers: int = 2):
            super().__init__()
            self.embed = nn.Linear(hidden_size, hidden_size)
            self.blocks = nn.ModuleList([InnerMagiBlock(hidden_size) for _ in range(num_layers)])
            self.output = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.embed(x)
            for block in self.blocks:
                x = block(x)
            return self.output(x)

    num_layers = 2
    model = OuterModel(HIDDEN_SIZE, num_layers=num_layers).to(DEVICE)
    x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)

    # The class decorator must have marked every block as magi-compilable.
    for i, block in enumerate(model.blocks):
        assert hasattr(block, "enable_compile"), f"block {i} lacks enable_compile; is @magi_compile() applied?"
        assert block.enable_compile is True, f"block {i}: enable_compile={block.enable_compile}"

    # Reference run: magi_compile only, no outer torch.compile.
    with torch.no_grad():
        baseline = model(x)
    for i, block in enumerate(model.blocks):
        assert_magi_compiled(block, f"block {i}")
        assert_not_torch_compiled_or_disabled(block, f"block {i}")

    compiled_model = torch.compile(model, fullgraph=False, dynamic=True)
    assert_torch_compiled(compiled_model)
    assert compiled_model._orig_mod is model

    with torch.no_grad():
        compiled_out = compiled_model(x)
    # The outer torch.compile must not strip the magi state of the inner blocks.
    for i, block in enumerate(model.blocks):
        assert block.enable_compile is True, f"block {i}: enable_compile={block.enable_compile}"
        assert_magi_compiled(block, f"block {i}")
    assert torch.allclose(baseline, compiled_out, atol=TOLERANCE, rtol=TOLERANCE)
def test_nested_torch_compile_multiple_magi_compile():
    """Outer torch.compile around a model holding two distinct ``@magi_compile()``'d block classes.

    Both blocks must stay magi-compiled and must not turn into torch.compile'd
    modules. This must hold after the parent is wrapped and after it is run.
    The nested output must match the run without the outer torch.compile.
    """

    @magi_compile()
    class MagiBlock1(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    @magi_compile()
    class MagiBlock2(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    class OuterModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.block1 = MagiBlock1(hidden_size)
            self.block2 = MagiBlock2(hidden_size)
            self.output = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.block1(x)
            x = self.block2(x)
            return self.output(x)

    model = OuterModel(HIDDEN_SIZE).to(DEVICE)
    x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)

    # Reference run: magi_compile only.
    with torch.no_grad():
        baseline = model(x)
    assert_magi_compiled(model.block1)
    assert_magi_compiled(model.block2)
    assert_not_torch_compiled_or_disabled(model.block1)
    assert_not_torch_compiled_or_disabled(model.block2)

    compiled_model = torch.compile(model, fullgraph=False, dynamic=True)
    assert_torch_compiled(compiled_model)
    # Wrapping the parent must not turn the children into torch.compile'd modules.
    assert_not_torch_compiled_or_disabled(model.block1)
    assert_not_torch_compiled_or_disabled(model.block2)

    with torch.no_grad():
        compiled_out = compiled_model(x)
    # Running under the outer torch.compile must leave the magi state intact.
    assert_magi_compiled(model.block1)
    assert_magi_compiled(model.block2)
    assert torch.allclose(baseline, compiled_out, atol=TOLERANCE, rtol=TOLERANCE)
# ============ Decorator-style torch.compile + magi_compile nesting ============
def test_decorator_torch_compile_on_forward():
    """``@torch.compile`` on ``forward``: the module keeps its own class, yet is_torch_compiled is True.

    The decorated model is loaded with the weights of an undecorated twin and
    must reproduce that twin's eager output.
    """

    class MyModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    model = MyModel(HIDDEN_SIZE).to(DEVICE)
    x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)

    # Eager baseline.
    with torch.no_grad():
        baseline = model(x)

    # Same architecture, but forward itself is decorated with torch.compile
    # (same options as every other torch.compile call in this module).
    class CompiledModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        @torch.compile(fullgraph=False, dynamic=True)
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    compiled_model = CompiledModel(HIDDEN_SIZE).to(DEVICE)
    compiled_model.load_state_dict(model.state_dict())

    # Decorating forward does not wrap the instance in an OptimizedModule;
    # detection must rely on the dynamo marker on forward.
    assert type(compiled_model).__name__ == "CompiledModel"
    assert_torch_compiled(compiled_model)

    with torch.no_grad():
        out = compiled_model(x)
    assert torch.allclose(baseline, out, atol=TOLERANCE, rtol=TOLERANCE)
def test_decorator_nested_torch_compile_forward_magi_inner():
    """Outer ``forward`` decorated with ``@torch.compile`` around an inner ``magi_compile()``'d block.

    The outer model must report as torch.compile'd. After a forward pass the
    inner block must be magi-compiled and not torch.compile'd. The output must
    match an eager twin carrying the same weights.
    """

    class InnerBlock(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    class OuterModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.inner = InnerBlock(hidden_size)
            self.output = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.inner(x)
            return self.output(x)

    model = OuterModel(HIDDEN_SIZE).to(DEVICE)
    x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)

    # Eager baseline.
    with torch.no_grad():
        baseline = model(x)

    # magi-compiled inner block + torch.compile-decorated outer forward.
    MagiInnerBlock = magi_compile()(InnerBlock)

    class CompiledOuterModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.inner = MagiInnerBlock(hidden_size)
            self.output = nn.Linear(hidden_size, hidden_size)

        @torch.compile(fullgraph=False, dynamic=True)
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.inner(x)
            return self.output(x)

    compiled_model = CompiledOuterModel(HIDDEN_SIZE).to(DEVICE)
    compiled_model.load_state_dict(model.state_dict())
    assert_torch_compiled(compiled_model)

    with torch.no_grad():
        out = compiled_model(x)
    assert_magi_compiled(compiled_model.inner)
    assert_not_torch_compiled_or_disabled(compiled_model.inner)
    assert torch.allclose(baseline, out, atol=TOLERANCE, rtol=TOLERANCE)
| # ============ torch._dynamo.config 正确性验证 ============ | |
| def test_dynamo_config_nested_patch_restore(): | |
| """验证 config.patch() 嵌套时能正确恢复到上一层的值""" | |
| import torch._dynamo.config as config | |
| # 记录初始值 | |
| initial_value = config.assume_static_by_default | |
| # 模拟外层 compile 设置 dynamic=True (assume_static_by_default=False) | |
| with config.patch(assume_static_by_default=False): | |
| assert config.assume_static_by_default is False, "外层 patch 应将值设为 False" | |
| # 模拟内层 magi_compile 恢复默认 (assume_static_by_default=True) | |
| with config.patch(assume_static_by_default=True): | |
| assert config.assume_static_by_default is True, "内层 patch 应将值设为 True" | |
| # 内层退出后,应该恢复到外层的值 | |
| assert config.assume_static_by_default is False, "内层退出后应恢复到外层值 False" | |
| # 外层退出后,应该恢复到初始值 | |
| assert config.assume_static_by_default == initial_value, f"外层退出后应恢复到初始值 {initial_value}" | |
| def test_dynamo_config_multiple_options_patch(): | |
| """验证同时 patch 多个配置项时的正确性""" | |
| import torch._dynamo.config as config | |
| # 记录初始值 | |
| initial_assume_static = config.assume_static_by_default | |
| initial_suppress_errors = config.suppress_errors | |
| initial_verbose = config.verbose | |
| # 同时 patch 多个配置项 | |
| with config.patch( | |
| assume_static_by_default=not initial_assume_static, | |
| suppress_errors=not initial_suppress_errors, | |
| verbose=not initial_verbose, | |
| ): | |
| # 验证所有配置项都已修改 | |
| assert config.assume_static_by_default == (not initial_assume_static), "assume_static_by_default 应被修改" | |
| assert config.suppress_errors == (not initial_suppress_errors), "suppress_errors 应被修改" | |
| assert config.verbose == (not initial_verbose), "verbose 应被修改" | |
| # 嵌套 patch 部分配置项 | |
| with config.patch(assume_static_by_default=initial_assume_static): | |
| assert config.assume_static_by_default == initial_assume_static, "内层应恢复 assume_static_by_default" | |
| # 其他配置项应保持外层 patch 的值 | |
| assert config.suppress_errors == (not initial_suppress_errors), "suppress_errors 应保持外层值" | |
| assert config.verbose == (not initial_verbose), "verbose 应保持外层值" | |
| # 内层退出后,assume_static_by_default 应恢复到外层 patch 的值 | |
| assert config.assume_static_by_default == (not initial_assume_static), "内层退出后应恢复到外层 patch 值" | |
| # 外层退出后,所有配置项都应恢复到初始值 | |
| assert config.assume_static_by_default == initial_assume_static, "外层退出后 assume_static_by_default 应恢复" | |
| assert config.suppress_errors == initial_suppress_errors, "外层退出后 suppress_errors 应恢复" | |
| assert config.verbose == initial_verbose, "外层退出后 verbose 应恢复" | |
| def test_dynamo_config_restore_on_exception(): | |
| """验证在 with 块内抛出异常时配置能正确恢复""" | |
| import torch._dynamo.config as config | |
| # 记录初始值 | |
| initial_value = config.assume_static_by_default | |
| # 测试单层 patch 在异常时的恢复 | |
| try: | |
| with config.patch(assume_static_by_default=not initial_value): | |
| assert config.assume_static_by_default == (not initial_value), "patch 应生效" | |
| raise RuntimeError("测试异常") | |
| except RuntimeError: | |
| pass | |
| # 异常后配置应恢复 | |
| assert config.assume_static_by_default == initial_value, "单层异常后应恢复到初始值" | |
| # 测试嵌套 patch 在内层异常时的恢复 | |
| try: | |
| with config.patch(assume_static_by_default=False): | |
| assert config.assume_static_by_default is False, "外层 patch 应生效" | |
| try: | |
| with config.patch(assume_static_by_default=True): | |
| assert config.assume_static_by_default is True, "内层 patch 应生效" | |
| raise ValueError("内层测试异常") | |
| except ValueError: | |
| pass | |
| # 内层异常捕获后,应恢复到外层值 | |
| assert config.assume_static_by_default is False, "内层异常后应恢复到外层值" | |
| except Exception: | |
| pytest.fail("外层不应捕获到异常") | |
| # 最终应恢复到初始值 | |
| assert config.assume_static_by_default == initial_value, "嵌套异常后应恢复到初始值" | |
if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file directly reports
    # failures to the calling shell / CI instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__, "-v", "-s"]))