# daVinci-MagiHuman / pkgs / MagiCompiler / tests / test_nested_compile.py
# jiadisu
# Switch back to Docker SDK with local pkgs
# e6066e8
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""测试嵌套 compile 场景:torch.compile 与 magi_compile 的各种组合"""
import os
import pytest
import torch
import torch.nn as nn
from magi_compiler import magi_compile
from magi_compiler.config import CompileMode, get_compile_config
# Device every test tensor/model is placed on; these tests need a CUDA GPU.
DEVICE = "cuda"
# Feature dimension shared by all test models and inputs.
HIDDEN_SIZE = 64
# atol/rtol passed to torch.allclose when comparing compiled output to eager output.
TOLERANCE = 1e-3
# ============ 辅助函数 ============
def is_torch_compiled(module: nn.Module) -> bool:
    """Return True if ``module`` has been compiled with torch.compile.

    Two forms are recognized:

    1. ``torch.compile(instance)`` -> the instance's type is named ``OptimizedModule``.
    2. ``@torch.compile`` on ``forward`` -> the class-level ``forward`` carries
       ``_torchdynamo_orig_callable``.

    ``@torch.compiler.disable`` also sets ``_torchdynamo_orig_callable`` but
    additionally sets ``_torchdynamo_disable=True``; such forwards are not
    counted as compiled.
    """
    # Form 1: the instance itself was wrapped.
    if type(module).__name__ == "OptimizedModule":
        return True

    # Form 2: the class-level forward was decorated; exclude disable-wrapped ones.
    fwd = type(module).forward
    decorated = hasattr(fwd, "_torchdynamo_orig_callable")
    return decorated and not getattr(fwd, "_torchdynamo_disable", False)
def is_torch_disabled(module: nn.Module) -> bool:
    """Report whether the class-level ``forward`` is marked by ``@torch.compiler.disable``.

    Reads the ``_torchdynamo_disable`` attribute of ``type(module).forward``
    and falls back to ``False`` when the attribute is absent.
    """
    class_forward = type(module).forward
    return getattr(class_forward, "_torchdynamo_disable", False)
def assert_torch_compiled(module: nn.Module, msg: str = ""):
    """Assert that ``is_torch_compiled(module)`` holds.

    On failure the message reports the module's type name and whether its
    class-level ``forward`` has ``_torchdynamo_orig_callable``; ``msg`` is
    appended for extra context.
    """
    forward_fn = type(module).forward
    assert is_torch_compiled(module), (
        "Expected torch.compile'd. "
        f"type={type(module).__name__}, "
        f"has _torchdynamo_orig_callable={hasattr(forward_fn, '_torchdynamo_orig_callable')}. {msg}"
    )
def assert_not_torch_compiled_or_disabled(module: nn.Module, msg: str = ""):
    """Assert that ``is_torch_compiled(module)`` is False.

    Since ``is_torch_compiled`` ignores forwards flagged with
    ``_torchdynamo_disable``, a module whose forward is wrapped by
    ``@torch.compiler.disable`` also passes this check.
    """
    forward_fn = type(module).forward
    assert not is_torch_compiled(module), (
        "Expected NOT torch.compile'd. "
        f"type={type(module).__name__}, "
        f"has _torchdynamo_orig_callable={hasattr(forward_fn, '_torchdynamo_orig_callable')}. {msg}"
    )
def assert_magi_compiled(module: nn.Module, msg: str = ""):
    """Assert that ``module`` has a ``compiled_code`` attribute and that it is not None."""
    # Presence first, so a missing attribute gets its own, clearer message.
    assert hasattr(module, "compiled_code"), f"Missing compiled_code. {msg}"
    code = module.compiled_code
    assert code is not None, f"compiled_code is None. {msg}"
def assert_not_magi_compiled(module: nn.Module, msg: str = ""):
    """Assert that ``module`` has no non-None ``compiled_code``.

    A module without the attribute passes, and so does one whose
    ``compiled_code`` is None.
    """
    compiled = getattr(module, "compiled_code", None)
    assert compiled is None, f"compiled_code should be None. {msg}"
def assert_torch_disabled(module: nn.Module, msg: str = ""):
    """Assert that ``is_torch_disabled(module)`` holds.

    On failure the message shows the ``_torchdynamo_disable`` value found on
    the class-level ``forward`` (``False`` when absent), followed by ``msg``.
    """
    forward_fn = type(module).forward
    assert is_torch_disabled(module), (
        "Expected @torch.compiler.disable. "
        f"_torchdynamo_disable={getattr(forward_fn, '_torchdynamo_disable', False)}. {msg}"
    )
# ============ Fixtures ============
@pytest.fixture(autouse=True)
def set_magi_compile_mode():
    """Run every test with ``compile_mode=MAGI_COMPILE``.

    ``cache_root_dir`` is overridden from the ``MAGI_COMPILE_CACHE_ROOT_DIR``
    environment variable when it is set; otherwise the current value is kept.
    Both ``compile_mode`` and ``cache_root_dir`` live on the global compile
    config. Both are restored after each test so the overrides do not leak
    into other tests or test modules.
    """
    config = get_compile_config()
    old_compile_mode = config.compile_mode
    old_cache_root_dir = config.cache_root_dir
    config.compile_mode = CompileMode.MAGI_COMPILE
    config.cache_root_dir = os.environ.get("MAGI_COMPILE_CACHE_ROOT_DIR", config.cache_root_dir)
    print(f"set magi compile mode: {config.compile_mode}, cache root dir: {config.cache_root_dir}")
    try:
        yield
    finally:
        # Restore everything this fixture changed, not only compile_mode.
        config.compile_mode = old_compile_mode
        config.cache_root_dir = old_cache_root_dir
# ============ torch.compile 嵌套行为 ============
def test_torch_compile_nested():
    """Nested torch.compile: the already-compiled inner OptimizedModule acts as an opaque node.

    First only the inner block is compiled in place, then the whole model
    (which now contains that compiled block) is compiled again. Both outputs
    must match the eager baseline.
    """

    class InnerBlock(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    class OuterModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.inner = InnerBlock(hidden_size)
            self.output = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.inner(x)
            return self.output(x)

    model = OuterModel(HIDDEN_SIZE).to(DEVICE)
    x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)
    # Eager reference output, computed before any compilation.
    with torch.no_grad():
        baseline = model(x)

    # Step 1: compile only the inner block and replace it in place on the model.
    model.inner = torch.compile(model.inner, fullgraph=False, dynamic=True)
    assert_torch_compiled(model.inner)
    with torch.no_grad():
        inner_compiled_out = model(x)
    assert torch.allclose(baseline, inner_compiled_out, atol=TOLERANCE, rtol=TOLERANCE)

    # Step 2: compile the outer model on top of the already-compiled inner block.
    compiled_model = torch.compile(model, fullgraph=False, dynamic=True)
    assert_torch_compiled(compiled_model)
    # Attribute access through the wrapper is expected to reach the inner block
    # replaced in step 1 (presumably forwarded to the wrapped module).
    assert_torch_compiled(compiled_model.inner)
    with torch.no_grad():
        nested_out = compiled_model(x)
    assert torch.allclose(baseline, nested_out, atol=TOLERANCE, rtol=TOLERANCE)
def test_torch_compile_with_disable_inner():
    """torch.compile + @torch.compiler.disable: the disabled function produces a graph break.

    The outer model is compiled while the inner block's forward is wrapped by
    ``@torch.compiler.disable``. The inner block must carry the disable marker
    and must not be mistaken for a torch.compile'd module. The compiled output
    must match eager.
    """

    class InnerBlock(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        @torch.compiler.disable
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    class OuterModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.inner = InnerBlock(hidden_size)
            self.output = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.inner(x)
            return self.output(x)

    model = OuterModel(HIDDEN_SIZE).to(DEVICE)
    x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)
    with torch.no_grad():
        baseline = model(x)

    compiled_model = torch.compile(model, fullgraph=False, dynamic=True)
    assert_torch_compiled(compiled_model)
    # @torch.compiler.disable also sets _torchdynamo_orig_callable, so check both
    # that it is not reported as compiled AND that the disable marker is present.
    # The second check is what actually distinguishes "disabled" from "plain eager".
    assert_not_torch_compiled_or_disabled(compiled_model.inner)
    assert_torch_disabled(compiled_model.inner)
    with torch.no_grad():
        compiled_out = compiled_model(x)
    assert torch.allclose(baseline, compiled_out, atol=TOLERANCE, rtol=TOLERANCE)
# ============ torch.compile + magi_compile 嵌套 ============
def test_nested_torch_compile_magi_compile():
    """Outer torch.compile wrapping inner magi_compile blocks.

    Each ``InnerMagiBlock`` is decorated with ``@magi_compile()``. After the
    eager baseline call, each block must have ``compiled_code`` and must not be
    torch.compile'd. Wrapping the outer model in ``torch.compile`` must keep
    the blocks magi-compiled and reproduce the baseline output.
    """

    @magi_compile()
    class InnerMagiBlock(nn.Module):
        def __init__(self, hidden_size: int):
            super().__init__()
            self.linear1 = nn.Linear(hidden_size, hidden_size * 4)
            self.linear2 = nn.Linear(hidden_size * 4, hidden_size)
            self.norm = nn.LayerNorm(hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            residual = x
            x = self.norm(x)
            x = self.linear1(x)
            x = torch.nn.functional.gelu(x)
            x = self.linear2(x)
            return x + residual

    class OuterModel(nn.Module):
        def __init__(self, hidden_size: int, num_layers: int = 2):
            super().__init__()
            self.embed = nn.Linear(hidden_size, hidden_size)
            self.blocks = nn.ModuleList([InnerMagiBlock(hidden_size) for _ in range(num_layers)])
            self.output = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.embed(x)
            for block in self.blocks:
                x = block(x)
            return self.output(x)

    num_layers = 2
    model = OuterModel(HIDDEN_SIZE, num_layers=num_layers).to(DEVICE)
    x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)

    # Every magi_compile'd instance exposes an enable_compile flag set to True.
    for block in model.blocks:
        assert hasattr(block, "enable_compile")
        assert block.enable_compile is True

    # Eager baseline; afterwards each block must carry compiled_code.
    with torch.no_grad():
        baseline = model(x)
    for block in model.blocks:
        assert_magi_compiled(block)
        assert_not_torch_compiled_or_disabled(block)

    compiled_model = torch.compile(model, fullgraph=False, dynamic=True)
    assert_torch_compiled(compiled_model)
    # The wrapper must reference the original model, not a copy.
    assert compiled_model._orig_mod is model
    with torch.no_grad():
        compiled_out = compiled_model(x)

    # Running under the outer torch.compile must leave magi compilation intact.
    for block in model.blocks:
        assert block.enable_compile is True
        assert_magi_compiled(block)
    assert torch.allclose(baseline, compiled_out, atol=TOLERANCE, rtol=TOLERANCE)
def test_nested_torch_compile_multiple_magi_compile():
    """Outer torch.compile containing several magi_compile modules.

    Two separately decorated classes (with identical bodies) are used,
    presumably so that more than one magi_compile'd class appears inside a
    single torch.compile'd model.
    """

    @magi_compile()
    class MagiBlock1(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    @magi_compile()
    class MagiBlock2(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    class OuterModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.block1 = MagiBlock1(hidden_size)
            self.block2 = MagiBlock2(hidden_size)
            self.output = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.block1(x)
            x = self.block2(x)
            return self.output(x)

    model = OuterModel(HIDDEN_SIZE).to(DEVICE)
    x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)
    # Eager baseline; afterwards both blocks must be magi-compiled, not torch.compile'd.
    with torch.no_grad():
        baseline = model(x)
    assert_magi_compiled(model.block1)
    assert_magi_compiled(model.block2)
    assert_not_torch_compiled_or_disabled(model.block1)
    assert_not_torch_compiled_or_disabled(model.block2)

    compiled_model = torch.compile(model, fullgraph=False, dynamic=True)
    assert_torch_compiled(compiled_model)
    # Wrapping the outer model must not make the inner blocks look torch.compile'd.
    assert_not_torch_compiled_or_disabled(model.block1)
    assert_not_torch_compiled_or_disabled(model.block2)
    with torch.no_grad():
        compiled_out = compiled_model(x)
    assert torch.allclose(baseline, compiled_out, atol=TOLERANCE, rtol=TOLERANCE)
# ============ torch.compile 使用装饰器 + magi_compile 嵌套 ============
def test_decorator_torch_compile_on_forward():
    """@torch.compile on forward: the module type is unchanged, but is_torch_compiled returns True."""

    class MyModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    model = MyModel(HIDDEN_SIZE).to(DEVICE)
    x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)
    # eager baseline
    with torch.no_grad():
        baseline = model(x)

    # Build the variant whose forward is decorated with @torch.compile.
    class CompiledModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        @torch.compile(fullgraph=False, dynamic=True)
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    compiled_model = CompiledModel(HIDDEN_SIZE).to(DEVICE)
    # Same submodule layout (self.linear) as MyModel, so the eager weights load directly.
    compiled_model.load_state_dict(model.state_dict())
    # Decorating forward keeps the instance's own class (no wrapper type).
    assert type(compiled_model).__name__ == "CompiledModel"
    assert_torch_compiled(compiled_model)
    with torch.no_grad():
        out = compiled_model(x)
    assert torch.allclose(baseline, out, atol=TOLERANCE, rtol=TOLERANCE)
def test_decorator_nested_torch_compile_forward_magi_inner():
    """Outer forward decorated with @torch.compile + inner block wrapped with magi_compile."""

    class InnerBlock(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    class OuterModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.inner = InnerBlock(hidden_size)
            self.output = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.inner(x)
            return self.output(x)

    model = OuterModel(HIDDEN_SIZE).to(DEVICE)
    x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)
    # eager baseline
    with torch.no_grad():
        baseline = model(x)

    # Build the variant: magi-compiled inner block + @torch.compile on the outer forward.
    # magi_compile() is applied as a plain call on the already-defined class here,
    # after the eager baseline has been computed.
    MagiInnerBlock = magi_compile()(InnerBlock)

    class CompiledOuterModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.inner = MagiInnerBlock(hidden_size)
            self.output = nn.Linear(hidden_size, hidden_size)

        @torch.compile(fullgraph=False, dynamic=True)
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.inner(x)
            return self.output(x)

    compiled_model = CompiledOuterModel(HIDDEN_SIZE).to(DEVICE)
    # Same submodule attribute names (inner, output) as OuterModel, so the eager weights load.
    compiled_model.load_state_dict(model.state_dict())
    assert_torch_compiled(compiled_model)
    with torch.no_grad():
        out = compiled_model(x)
    # After the call, the inner block must be magi-compiled and not torch.compile'd itself.
    assert_magi_compiled(compiled_model.inner)
    assert_not_torch_compiled_or_disabled(compiled_model.inner)
    assert torch.allclose(baseline, out, atol=TOLERANCE, rtol=TOLERANCE)
# ============ torch._dynamo.config 正确性验证 ============
def test_dynamo_config_nested_patch_restore():
"""验证 config.patch() 嵌套时能正确恢复到上一层的值"""
import torch._dynamo.config as config
# 记录初始值
initial_value = config.assume_static_by_default
# 模拟外层 compile 设置 dynamic=True (assume_static_by_default=False)
with config.patch(assume_static_by_default=False):
assert config.assume_static_by_default is False, "外层 patch 应将值设为 False"
# 模拟内层 magi_compile 恢复默认 (assume_static_by_default=True)
with config.patch(assume_static_by_default=True):
assert config.assume_static_by_default is True, "内层 patch 应将值设为 True"
# 内层退出后,应该恢复到外层的值
assert config.assume_static_by_default is False, "内层退出后应恢复到外层值 False"
# 外层退出后,应该恢复到初始值
assert config.assume_static_by_default == initial_value, f"外层退出后应恢复到初始值 {initial_value}"
def test_dynamo_config_multiple_options_patch():
"""验证同时 patch 多个配置项时的正确性"""
import torch._dynamo.config as config
# 记录初始值
initial_assume_static = config.assume_static_by_default
initial_suppress_errors = config.suppress_errors
initial_verbose = config.verbose
# 同时 patch 多个配置项
with config.patch(
assume_static_by_default=not initial_assume_static,
suppress_errors=not initial_suppress_errors,
verbose=not initial_verbose,
):
# 验证所有配置项都已修改
assert config.assume_static_by_default == (not initial_assume_static), "assume_static_by_default 应被修改"
assert config.suppress_errors == (not initial_suppress_errors), "suppress_errors 应被修改"
assert config.verbose == (not initial_verbose), "verbose 应被修改"
# 嵌套 patch 部分配置项
with config.patch(assume_static_by_default=initial_assume_static):
assert config.assume_static_by_default == initial_assume_static, "内层应恢复 assume_static_by_default"
# 其他配置项应保持外层 patch 的值
assert config.suppress_errors == (not initial_suppress_errors), "suppress_errors 应保持外层值"
assert config.verbose == (not initial_verbose), "verbose 应保持外层值"
# 内层退出后,assume_static_by_default 应恢复到外层 patch 的值
assert config.assume_static_by_default == (not initial_assume_static), "内层退出后应恢复到外层 patch 值"
# 外层退出后,所有配置项都应恢复到初始值
assert config.assume_static_by_default == initial_assume_static, "外层退出后 assume_static_by_default 应恢复"
assert config.suppress_errors == initial_suppress_errors, "外层退出后 suppress_errors 应恢复"
assert config.verbose == initial_verbose, "外层退出后 verbose 应恢复"
def test_dynamo_config_restore_on_exception():
"""验证在 with 块内抛出异常时配置能正确恢复"""
import torch._dynamo.config as config
# 记录初始值
initial_value = config.assume_static_by_default
# 测试单层 patch 在异常时的恢复
try:
with config.patch(assume_static_by_default=not initial_value):
assert config.assume_static_by_default == (not initial_value), "patch 应生效"
raise RuntimeError("测试异常")
except RuntimeError:
pass
# 异常后配置应恢复
assert config.assume_static_by_default == initial_value, "单层异常后应恢复到初始值"
# 测试嵌套 patch 在内层异常时的恢复
try:
with config.patch(assume_static_by_default=False):
assert config.assume_static_by_default is False, "外层 patch 应生效"
try:
with config.patch(assume_static_by_default=True):
assert config.assume_static_by_default is True, "内层 patch 应生效"
raise ValueError("内层测试异常")
except ValueError:
pass
# 内层异常捕获后,应恢复到外层值
assert config.assume_static_by_default is False, "内层异常后应恢复到外层值"
except Exception:
pytest.fail("外层不应捕获到异常")
# 最终应恢复到初始值
assert config.assume_static_by_default == initial_value, "嵌套异常后应恢复到初始值"
# Allow running this file directly: verbose test names (-v), no output capture (-s).
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])