Spaces:
Runtime error
Runtime error
File size: 18,388 Bytes
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""测试嵌套 compile 场景:torch.compile 与 magi_compile 的各种组合"""
import os
import pytest
import torch
import torch.nn as nn
from magi_compiler import magi_compile
from magi_compiler.config import CompileMode, get_compile_config
# Device that every test model and input tensor in this file is placed on.
DEVICE = "cuda"
# Base feature width used for the layers and the last input dimension in these tests.
HIDDEN_SIZE = 64
# Passed as both atol and rtol to torch.allclose when comparing outputs.
TOLERANCE = 1e-3
# ============ 辅助函数 ============
def is_torch_compiled(module: nn.Module) -> bool:
    """Report whether ``module`` has been compiled with ``torch.compile``.

    Two forms are recognised:

    1. ``torch.compile(instance)``, which yields an ``OptimizedModule`` wrapper.
    2. ``@torch.compile`` applied to ``forward``, which leaves a
       ``_torchdynamo_orig_callable`` attribute on the class-level ``forward``.

    ``@torch.compiler.disable`` also sets ``_torchdynamo_orig_callable`` but
    additionally marks the function with ``_torchdynamo_disable=True``; such
    a ``forward`` is not counted as compiled.
    """
    module_cls = type(module)
    if module_cls.__name__ == "OptimizedModule":
        return True

    forward_fn = module_cls.forward
    wrapped_by_dynamo = hasattr(forward_fn, "_torchdynamo_orig_callable")
    explicitly_disabled = getattr(forward_fn, "_torchdynamo_disable", False)
    return wrapped_by_dynamo and not explicitly_disabled
def is_torch_disabled(module: nn.Module) -> bool:
    """Return the ``_torchdynamo_disable`` marker found on the class-level ``forward``.

    The marker is looked up on ``type(module).forward``; ``False`` is returned
    when the attribute is absent.
    """
    forward_fn = type(module).forward
    return getattr(forward_fn, "_torchdynamo_disable", False)
def assert_torch_compiled(module: nn.Module, msg: str = ""):
    """Fail unless ``is_torch_compiled(module)`` holds.

    Args:
        module: Module to inspect.
        msg: Extra context appended to the failure message.
    """
    module_cls = type(module)
    assert is_torch_compiled(module), (
        f"Expected torch.compile'd. type={module_cls.__name__}, "
        f"has _torchdynamo_orig_callable={hasattr(module_cls.forward, '_torchdynamo_orig_callable')}. {msg}"
    )
def assert_not_torch_compiled_or_disabled(module: nn.Module, msg: str = ""):
    """Fail if ``is_torch_compiled(module)`` holds.

    A module whose ``forward`` is marked disabled is not reported as compiled
    by ``is_torch_compiled``, so it passes this check as well.

    Args:
        module: Module to inspect.
        msg: Extra context appended to the failure message.
    """
    module_cls = type(module)
    assert not is_torch_compiled(module), (
        f"Expected NOT torch.compile'd. type={module_cls.__name__}, "
        f"has _torchdynamo_orig_callable={hasattr(module_cls.forward, '_torchdynamo_orig_callable')}. {msg}"
    )
def assert_magi_compiled(module: nn.Module, msg: str = ""):
    """Fail unless ``module`` has a non-``None`` ``compiled_code`` attribute.

    Args:
        module: Module to inspect.
        msg: Extra context appended to the failure message.
    """
    attribute_present = hasattr(module, "compiled_code")
    assert attribute_present, f"Missing compiled_code. {msg}"
    code = module.compiled_code
    assert code is not None, f"compiled_code is None. {msg}"
def assert_not_magi_compiled(module: nn.Module, msg: str = ""):
    """Fail if ``module`` carries a ``compiled_code`` attribute that is not ``None``.

    A module without the attribute at all passes the check.

    Args:
        module: Module to inspect.
        msg: Extra context appended to the failure message.
    """
    code = getattr(module, "compiled_code", None)
    assert code is None, f"compiled_code should be None. {msg}"
def assert_torch_disabled(module: nn.Module, msg: str = ""):
    """Fail unless ``is_torch_disabled(module)`` holds.

    Args:
        module: Module to inspect.
        msg: Extra context appended to the failure message.
    """
    forward_fn = type(module).forward
    assert is_torch_disabled(module), (
        "Expected @torch.compiler.disable. "
        f"_torchdynamo_disable={getattr(forward_fn, '_torchdynamo_disable', False)}. {msg}"
    )
# ============ Fixtures ============
@pytest.fixture(autouse=True)
def set_magi_compile_mode():
    """Run every test with ``compile_mode`` set to ``CompileMode.MAGI_COMPILE``.

    The cache root directory is taken from the ``MAGI_COMPILE_CACHE_ROOT_DIR``
    environment variable when it is set, otherwise the configured value is
    kept. Both settings are restored to their previous values after the test,
    so the override does not leak out of the test that applied it.
    """
    config = get_compile_config()
    previous_mode = config.compile_mode
    previous_cache_root = config.cache_root_dir

    config.compile_mode = CompileMode.MAGI_COMPILE
    config.cache_root_dir = os.environ.get("MAGI_COMPILE_CACHE_ROOT_DIR", config.cache_root_dir)
    print(f"set magi compile mode: {config.compile_mode}, cache root dir: {config.cache_root_dir}")

    yield

    # Teardown: undo every field this fixture touched, not just the mode.
    config.compile_mode = previous_mode
    config.cache_root_dir = previous_cache_root
# ============ torch.compile 嵌套行为 ============
def test_torch_compile_nested():
    """Nested torch.compile: a compiled inner module inside a compiled outer model.

    The inner block is compiled first and compared with the eager baseline;
    the whole model is then compiled on top of it and compared again. Per the
    original description, the already-compiled inner ``OptimizedModule`` is
    expected to act as an opaque node for the outer compile (not asserted
    directly here).
    """

    class InnerBlock(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            projected = self.linear(x)
            return torch.relu(projected)

    class OuterModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.inner = InnerBlock(hidden_size)
            self.output = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.output(self.inner(x))

    model = OuterModel(HIDDEN_SIZE).to(DEVICE)
    inputs = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)

    # Eager reference output.
    with torch.no_grad():
        expected = model(inputs)

    # Stage 1: only the inner block is compiled.
    model.inner = torch.compile(model.inner, fullgraph=False, dynamic=True)
    assert_torch_compiled(model.inner)
    with torch.no_grad():
        inner_only_out = model(inputs)
    assert torch.allclose(expected, inner_only_out, atol=TOLERANCE, rtol=TOLERANCE)

    # Stage 2: the outer model is compiled around the compiled inner block.
    compiled_model = torch.compile(model, fullgraph=False, dynamic=True)
    assert_torch_compiled(compiled_model)
    assert_torch_compiled(compiled_model.inner)
    with torch.no_grad():
        nested_out = compiled_model(inputs)
    assert torch.allclose(expected, nested_out, atol=TOLERANCE, rtol=TOLERANCE)
def test_torch_compile_with_disable_inner():
    """torch.compile around a module whose ``forward`` is ``@torch.compiler.disable``'d.

    Per the original description the disabled function is expected to cause a
    graph break; that is not asserted directly. What is checked: the outer
    model is compiled, the inner module is not reported as compiled, the inner
    ``forward`` carries the disable marker, and the output matches eager.
    """

    class InnerBlock(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        @torch.compiler.disable
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    class OuterModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.inner = InnerBlock(hidden_size)
            self.output = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.inner(x)
            return self.output(x)

    model = OuterModel(HIDDEN_SIZE).to(DEVICE)
    x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)

    # Eager reference output.
    with torch.no_grad():
        baseline = model(x)

    compiled_model = torch.compile(model, fullgraph=False, dynamic=True)
    assert_torch_compiled(compiled_model)
    assert_not_torch_compiled_or_disabled(compiled_model.inner)
    # Positively confirm the disable marker instead of only inferring it from
    # "not compiled": this is the property the test is named after.
    assert_torch_disabled(compiled_model.inner)

    with torch.no_grad():
        compiled_out = compiled_model(x)
    assert torch.allclose(baseline, compiled_out, atol=TOLERANCE, rtol=TOLERANCE)
# ============ torch.compile + magi_compile 嵌套 ============
def test_nested_torch_compile_magi_compile():
    """Outer ``torch.compile`` wrapping inner ``magi_compile``'d blocks.

    Checks, in order: each block exposes ``enable_compile is True`` after
    construction; after one eager forward each block has ``compiled_code`` and
    is not reported as torch.compile'd; the outer wrapper keeps the original
    model as ``_orig_mod``; after the compiled forward the blocks are still
    enabled and compiled; and the compiled output matches the baseline.
    """

    @magi_compile()
    class InnerMagiBlock(nn.Module):
        def __init__(self, hidden_size: int):
            super().__init__()
            self.linear1 = nn.Linear(hidden_size, hidden_size * 4)
            self.linear2 = nn.Linear(hidden_size * 4, hidden_size)
            self.norm = nn.LayerNorm(hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            residual = x
            x = self.norm(x)
            x = self.linear1(x)
            x = torch.nn.functional.gelu(x)
            x = self.linear2(x)
            return x + residual

    class OuterModel(nn.Module):
        def __init__(self, hidden_size: int, num_layers: int = 2):
            super().__init__()
            self.embed = nn.Linear(hidden_size, hidden_size)
            self.blocks = nn.ModuleList([InnerMagiBlock(hidden_size) for _ in range(num_layers)])
            self.output = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.embed(x)
            for block in self.blocks:
                x = block(x)
            return self.output(x)

    num_layers = 2
    model = OuterModel(HIDDEN_SIZE, num_layers=num_layers).to(DEVICE)
    x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)

    for block in model.blocks:
        assert hasattr(block, "enable_compile")
        assert block.enable_compile is True

    # First forward pass: serves as the reference output for the comparison below.
    with torch.no_grad():
        baseline = model(x)

    for block in model.blocks:
        assert_magi_compiled(block)
        assert_not_torch_compiled_or_disabled(block)

    compiled_model = torch.compile(model, fullgraph=False, dynamic=True)
    assert_torch_compiled(compiled_model)
    assert compiled_model._orig_mod is model

    with torch.no_grad():
        compiled_out = compiled_model(x)

    # The outer compile must not have switched the inner blocks off.
    for block in model.blocks:
        assert block.enable_compile is True
        assert_magi_compiled(block)

    assert torch.allclose(baseline, compiled_out, atol=TOLERANCE, rtol=TOLERANCE)
def test_nested_torch_compile_multiple_magi_compile():
    """Outer ``torch.compile`` over a model holding two distinct ``magi_compile``'d blocks.

    After the first forward both blocks must have ``compiled_code`` and must
    not be reported as torch.compile'd; that second property must still hold
    once the outer model is wrapped, and the compiled output must match the
    first forward's output.
    """

    @magi_compile()
    class MagiBlock1(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    @magi_compile()
    class MagiBlock2(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    class OuterModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.block1 = MagiBlock1(hidden_size)
            self.block2 = MagiBlock2(hidden_size)
            self.output = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.output(self.block2(self.block1(x)))

    model = OuterModel(HIDDEN_SIZE).to(DEVICE)
    inputs = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)

    # First forward pass: reference output.
    with torch.no_grad():
        expected = model(inputs)

    assert_magi_compiled(model.block1)
    assert_magi_compiled(model.block2)
    assert_not_torch_compiled_or_disabled(model.block1)
    assert_not_torch_compiled_or_disabled(model.block2)

    compiled_model = torch.compile(model, fullgraph=False, dynamic=True)
    assert_torch_compiled(compiled_model)
    # Wrapping the outer model must not make the inner blocks look torch.compile'd.
    assert_not_torch_compiled_or_disabled(model.block1)
    assert_not_torch_compiled_or_disabled(model.block2)

    with torch.no_grad():
        actual = compiled_model(inputs)
    assert torch.allclose(expected, actual, atol=TOLERANCE, rtol=TOLERANCE)
# ============ torch.compile 使用装饰器 + magi_compile 嵌套 ============
def test_decorator_torch_compile_on_forward():
    """``@torch.compile`` on ``forward``: the class name is unchanged but the module counts as compiled.

    An eager model provides the baseline; a second model with a decorated
    ``forward`` loads the same weights and must produce matching output.
    """

    class MyModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    eager_model = MyModel(HIDDEN_SIZE).to(DEVICE)
    inputs = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)

    # Eager reference output.
    with torch.no_grad():
        expected = eager_model(inputs)

    # Same architecture, but with forward decorated by torch.compile.
    class CompiledModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        @torch.compile(fullgraph=False, dynamic=True)
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    compiled_model = CompiledModel(HIDDEN_SIZE).to(DEVICE)
    # Share weights so the two models are numerically comparable.
    compiled_model.load_state_dict(eager_model.state_dict())

    assert type(compiled_model).__name__ == "CompiledModel"
    assert_torch_compiled(compiled_model)

    with torch.no_grad():
        actual = compiled_model(inputs)
    assert torch.allclose(expected, actual, atol=TOLERANCE, rtol=TOLERANCE)
def test_decorator_nested_torch_compile_forward_magi_inner():
    """Outer ``forward`` decorated with ``@torch.compile``, inner module wrapped by ``magi_compile``.

    The eager model provides the baseline. The compiled variant loads the same
    weights; after its forward pass the inner module must have
    ``compiled_code``, must not be reported as torch.compile'd, and the output
    must match the baseline.
    """

    class InnerBlock(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    class OuterModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.inner = InnerBlock(hidden_size)
            self.output = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.inner(x)
            return self.output(x)

    eager_model = OuterModel(HIDDEN_SIZE).to(DEVICE)
    inputs = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)

    # Eager reference output, taken before magi_compile is applied to InnerBlock.
    with torch.no_grad():
        expected = eager_model(inputs)

    # Build the variant: magi_compile'd inner block, torch.compile'd outer forward.
    MagiInnerBlock = magi_compile()(InnerBlock)

    class CompiledOuterModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.inner = MagiInnerBlock(hidden_size)
            self.output = nn.Linear(hidden_size, hidden_size)

        @torch.compile(fullgraph=False, dynamic=True)
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.inner(x)
            return self.output(x)

    compiled_model = CompiledOuterModel(HIDDEN_SIZE).to(DEVICE)
    # Share weights so the two models are numerically comparable.
    compiled_model.load_state_dict(eager_model.state_dict())
    assert_torch_compiled(compiled_model)

    with torch.no_grad():
        actual = compiled_model(inputs)

    inner_module = compiled_model.inner
    assert_magi_compiled(inner_module)
    assert_not_torch_compiled_or_disabled(inner_module)
    assert torch.allclose(expected, actual, atol=TOLERANCE, rtol=TOLERANCE)
# ============ torch._dynamo.config 正确性验证 ============
def test_dynamo_config_nested_patch_restore():
"""验证 config.patch() 嵌套时能正确恢复到上一层的值"""
import torch._dynamo.config as config
# 记录初始值
initial_value = config.assume_static_by_default
# 模拟外层 compile 设置 dynamic=True (assume_static_by_default=False)
with config.patch(assume_static_by_default=False):
assert config.assume_static_by_default is False, "外层 patch 应将值设为 False"
# 模拟内层 magi_compile 恢复默认 (assume_static_by_default=True)
with config.patch(assume_static_by_default=True):
assert config.assume_static_by_default is True, "内层 patch 应将值设为 True"
# 内层退出后,应该恢复到外层的值
assert config.assume_static_by_default is False, "内层退出后应恢复到外层值 False"
# 外层退出后,应该恢复到初始值
assert config.assume_static_by_default == initial_value, f"外层退出后应恢复到初始值 {initial_value}"
def test_dynamo_config_multiple_options_patch():
"""验证同时 patch 多个配置项时的正确性"""
import torch._dynamo.config as config
# 记录初始值
initial_assume_static = config.assume_static_by_default
initial_suppress_errors = config.suppress_errors
initial_verbose = config.verbose
# 同时 patch 多个配置项
with config.patch(
assume_static_by_default=not initial_assume_static,
suppress_errors=not initial_suppress_errors,
verbose=not initial_verbose,
):
# 验证所有配置项都已修改
assert config.assume_static_by_default == (not initial_assume_static), "assume_static_by_default 应被修改"
assert config.suppress_errors == (not initial_suppress_errors), "suppress_errors 应被修改"
assert config.verbose == (not initial_verbose), "verbose 应被修改"
# 嵌套 patch 部分配置项
with config.patch(assume_static_by_default=initial_assume_static):
assert config.assume_static_by_default == initial_assume_static, "内层应恢复 assume_static_by_default"
# 其他配置项应保持外层 patch 的值
assert config.suppress_errors == (not initial_suppress_errors), "suppress_errors 应保持外层值"
assert config.verbose == (not initial_verbose), "verbose 应保持外层值"
# 内层退出后,assume_static_by_default 应恢复到外层 patch 的值
assert config.assume_static_by_default == (not initial_assume_static), "内层退出后应恢复到外层 patch 值"
# 外层退出后,所有配置项都应恢复到初始值
assert config.assume_static_by_default == initial_assume_static, "外层退出后 assume_static_by_default 应恢复"
assert config.suppress_errors == initial_suppress_errors, "外层退出后 suppress_errors 应恢复"
assert config.verbose == initial_verbose, "外层退出后 verbose 应恢复"
def test_dynamo_config_restore_on_exception():
"""验证在 with 块内抛出异常时配置能正确恢复"""
import torch._dynamo.config as config
# 记录初始值
initial_value = config.assume_static_by_default
# 测试单层 patch 在异常时的恢复
try:
with config.patch(assume_static_by_default=not initial_value):
assert config.assume_static_by_default == (not initial_value), "patch 应生效"
raise RuntimeError("测试异常")
except RuntimeError:
pass
# 异常后配置应恢复
assert config.assume_static_by_default == initial_value, "单层异常后应恢复到初始值"
# 测试嵌套 patch 在内层异常时的恢复
try:
with config.patch(assume_static_by_default=False):
assert config.assume_static_by_default is False, "外层 patch 应生效"
try:
with config.patch(assume_static_by_default=True):
assert config.assume_static_by_default is True, "内层 patch 应生效"
raise ValueError("内层测试异常")
except ValueError:
pass
# 内层异常捕获后,应恢复到外层值
assert config.assume_static_by_default is False, "内层异常后应恢复到外层值"
except Exception:
pytest.fail("外层不应捕获到异常")
# 最终应恢复到初始值
assert config.assume_static_by_default == initial_value, "嵌套异常后应恢复到初始值"
if __name__ == "__main__":
    # Run this file's tests verbosely without output capture, and propagate
    # pytest's exit code so a failing run yields a non-zero process status.
    raise SystemExit(pytest.main([__file__, "-v", "-s"]))
|