File size: 18,388 Bytes
e6066e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""测试嵌套 compile 场景:torch.compile 与 magi_compile 的各种组合"""

import os

import pytest
import torch
import torch.nn as nn
from magi_compiler import magi_compile
from magi_compiler.config import CompileMode, get_compile_config

# Device that every model and input tensor in this file is moved to / created on.
DEVICE = "cuda"
# Feature width handed to the test modules and used as the last dimension of the inputs.
HIDDEN_SIZE = 64
# Passed as both atol and rtol to torch.allclose when comparing outputs.
TOLERANCE = 1e-3


# ============ 辅助函数 ============


def is_torch_compiled(module: nn.Module) -> bool:
    """
    Report whether ``module`` looks like it was compiled with torch.compile.

    Two forms are recognised:
    1. ``torch.compile(instance)``: the object's type is named ``OptimizedModule``.
    2. ``@torch.compile`` applied to ``forward``: the class-level ``forward``
       carries a ``_torchdynamo_orig_callable`` attribute.

    Note: ``@torch.compiler.disable`` also sets ``_torchdynamo_orig_callable``
    but additionally sets ``_torchdynamo_disable=True``; a ``forward`` with a
    truthy ``_torchdynamo_disable`` is therefore not counted as compiled.
    """
    module_type = type(module)
    if module_type.__name__ == "OptimizedModule":
        return True

    forward_fn = module_type.forward
    wrapped_by_dynamo = hasattr(forward_fn, "_torchdynamo_orig_callable")
    marked_disabled = getattr(forward_fn, "_torchdynamo_disable", False)
    return wrapped_by_dynamo and not marked_disabled


def is_torch_disabled(module: nn.Module) -> bool:
    """Check whether the class-level ``forward`` carries the marker set by ``@torch.compiler.disable``.

    Returns the ``_torchdynamo_disable`` attribute of ``type(module).forward``,
    or ``False`` when that attribute is absent.
    """
    forward_fn = type(module).forward
    return getattr(forward_fn, "_torchdynamo_disable", False)


def assert_torch_compiled(module: nn.Module, msg: str = ""):
    """Assert that ``is_torch_compiled(module)`` holds.

    On failure the message reports the module's type name and whether its
    class-level ``forward`` has ``_torchdynamo_orig_callable``; ``msg`` is appended.
    """
    cls = type(module)
    assert is_torch_compiled(module), (
        f"Expected torch.compile'd. type={cls.__name__}, "
        f"has _torchdynamo_orig_callable={hasattr(cls.forward, '_torchdynamo_orig_callable')}. {msg}"
    )


def assert_not_torch_compiled_or_disabled(module: nn.Module, msg: str = ""):
    """Assert that ``is_torch_compiled(module)`` is false.

    On failure the message reports the module's type name and whether its
    class-level ``forward`` has ``_torchdynamo_orig_callable``; ``msg`` is appended.
    """
    cls = type(module)
    assert not is_torch_compiled(module), (
        f"Expected NOT torch.compile'd. type={cls.__name__}, "
        f"has _torchdynamo_orig_callable={hasattr(cls.forward, '_torchdynamo_orig_callable')}. {msg}"
    )


def assert_magi_compiled(module: nn.Module, msg: str = ""):
    """Assert that ``module`` has a ``compiled_code`` attribute and that it is not ``None``.

    ``msg`` is appended to either failure message.
    """
    has_compiled_code = hasattr(module, "compiled_code")
    assert has_compiled_code, f"Missing compiled_code. {msg}"

    compiled_code = module.compiled_code
    assert compiled_code is not None, f"compiled_code is None. {msg}"


def assert_not_magi_compiled(module: nn.Module, msg: str = ""):
    """Assert that ``module.compiled_code`` is ``None`` whenever the attribute exists.

    A module without a ``compiled_code`` attribute passes unconditionally.
    ``msg`` is appended to the failure message.
    """
    if not hasattr(module, "compiled_code"):
        return
    assert module.compiled_code is None, f"compiled_code should be None. {msg}"


def assert_torch_disabled(module: nn.Module, msg: str = ""):
    """Assert that ``is_torch_disabled(module)`` holds.

    On failure the message reports the ``_torchdynamo_disable`` value found on
    the class-level ``forward`` (``False`` when absent); ``msg`` is appended.
    """
    cls = type(module)
    assert is_torch_disabled(module), (
        "Expected @torch.compiler.disable. "
        f"_torchdynamo_disable={getattr(cls.forward, '_torchdynamo_disable', False)}. {msg}"
    )


# ============ Fixtures ============


@pytest.fixture(autouse=True)
def set_magi_compile_mode():
    """Run every test with ``compile_mode=MAGI_COMPILE`` and restore the config afterwards.

    The cache root directory is taken from the ``MAGI_COMPILE_CACHE_ROOT_DIR``
    environment variable when it is set; otherwise the current value is kept.
    Both ``compile_mode`` and ``cache_root_dir`` are put back to their previous
    values on teardown so the override does not leak past the test.
    """
    config = get_compile_config()
    old_compile_mode = config.compile_mode
    old_cache_root_dir = config.cache_root_dir

    config.compile_mode = CompileMode.MAGI_COMPILE
    config.cache_root_dir = os.environ.get("MAGI_COMPILE_CACHE_ROOT_DIR", config.cache_root_dir)
    print(f"set magi compile mode: {config.compile_mode}, cache root dir: {config.cache_root_dir}")
    try:
        yield
    finally:
        # Restore everything this fixture touched, not only compile_mode.
        config.compile_mode = old_compile_mode
        config.cache_root_dir = old_cache_root_dir


# ============ torch.compile 嵌套行为 ============


def test_torch_compile_nested():
    """Nested torch.compile: an outer compiled model containing an already-compiled inner module.

    The inner ``OptimizedModule`` is intended to act as an opaque node for the
    outer compile. Outputs of both compiled variants are compared with eager.
    """

    class InnerBlock(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            projected = self.linear(x)
            return torch.relu(projected)

    class OuterModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.inner = InnerBlock(hidden_size)
            self.output = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.output(self.inner(x))

    model = OuterModel(HIDDEN_SIZE).to(DEVICE)
    inputs = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)

    # Eager reference output.
    with torch.no_grad():
        eager_out = model(inputs)

    # Stage 1: only the inner submodule is compiled; the outer model stays eager.
    model.inner = torch.compile(model.inner, fullgraph=False, dynamic=True)
    assert_torch_compiled(model.inner)
    with torch.no_grad():
        inner_only_out = model(inputs)
    assert torch.allclose(eager_out, inner_only_out, atol=TOLERANCE, rtol=TOLERANCE)

    # Stage 2: compile the outer model on top of the compiled inner submodule.
    outer_compiled = torch.compile(model, fullgraph=False, dynamic=True)
    assert_torch_compiled(outer_compiled)
    assert_torch_compiled(outer_compiled.inner)
    with torch.no_grad():
        nested_out = outer_compiled(inputs)

    assert torch.allclose(eager_out, nested_out, atol=TOLERANCE, rtol=TOLERANCE)


def test_torch_compile_with_disable_inner():
    """torch.compile around a module whose ``forward`` is ``@torch.compiler.disable``'d.

    The disabled function is intended to cause a graph break; the test checks
    that the inner module is not reported as compiled and that the compiled
    outer model still matches the eager output.
    """

    class InnerBlock(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        @torch.compiler.disable
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            projected = self.linear(x)
            return torch.relu(projected)

    class OuterModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.inner = InnerBlock(hidden_size)
            self.output = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.output(self.inner(x))

    model = OuterModel(HIDDEN_SIZE).to(DEVICE)
    inputs = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)

    # Eager reference output.
    with torch.no_grad():
        eager_out = model(inputs)

    outer_compiled = torch.compile(model, fullgraph=False, dynamic=True)
    assert_torch_compiled(outer_compiled)
    assert_not_torch_compiled_or_disabled(outer_compiled.inner)
    with torch.no_grad():
        compiled_out = outer_compiled(inputs)

    assert torch.allclose(eager_out, compiled_out, atol=TOLERANCE, rtol=TOLERANCE)


# ============ torch.compile + magi_compile 嵌套 ============


def test_nested_torch_compile_magi_compile():
    """Outer ``torch.compile`` wrapping a model whose inner blocks are ``magi_compile``'d.

    Flow: build the model, check the ``enable_compile`` flag on each block, run
    once to obtain the baseline, check each block's ``compiled_code``, then wrap
    the whole model in ``torch.compile`` and compare its output to the baseline.
    """

    @magi_compile()
    class InnerMagiBlock(nn.Module):
        def __init__(self, hidden_size: int):
            super().__init__()
            self.linear1 = nn.Linear(hidden_size, hidden_size * 4)
            self.linear2 = nn.Linear(hidden_size * 4, hidden_size)
            self.norm = nn.LayerNorm(hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            residual = x
            x = self.norm(x)
            x = self.linear1(x)
            x = torch.nn.functional.gelu(x)
            x = self.linear2(x)
            return x + residual

    class OuterModel(nn.Module):
        def __init__(self, hidden_size: int, num_layers: int = 2):
            super().__init__()
            self.embed = nn.Linear(hidden_size, hidden_size)
            self.blocks = nn.ModuleList([InnerMagiBlock(hidden_size) for _ in range(num_layers)])
            self.output = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.embed(x)
            for block in self.blocks:
                x = block(x)
            return self.output(x)

    num_layers = 2
    model = OuterModel(HIDDEN_SIZE, num_layers=num_layers).to(DEVICE)
    x = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)

    # Each decorated block must expose the enable_compile flag, set to True.
    for block in model.blocks:
        assert hasattr(block, "enable_compile")
        assert block.enable_compile is True

    # Baseline: the model called directly, before it is wrapped in torch.compile.
    with torch.no_grad():
        baseline = model(x)

    # compiled_code is only checked after the first forward call above.
    for block in model.blocks:
        assert_magi_compiled(block)
        assert_not_torch_compiled_or_disabled(block)

    compiled_model = torch.compile(model, fullgraph=False, dynamic=True)
    assert_torch_compiled(compiled_model)
    assert compiled_model._orig_mod is model

    with torch.no_grad():
        compiled_out = compiled_model(x)

    # Running under the outer torch.compile must leave the blocks' magi state in place.
    for block in model.blocks:
        assert block.enable_compile is True
        assert_magi_compiled(block)

    assert torch.allclose(baseline, compiled_out, atol=TOLERANCE, rtol=TOLERANCE)


def test_nested_torch_compile_multiple_magi_compile():
    """Outer ``torch.compile`` around a model that contains several ``magi_compile`` modules."""

    @magi_compile()
    class MagiBlock1(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    @magi_compile()
    class MagiBlock2(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    class OuterModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.block1 = MagiBlock1(hidden_size)
            self.block2 = MagiBlock2(hidden_size)
            self.output = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.block1(x)
            x = self.block2(x)
            return self.output(x)

    model = OuterModel(HIDDEN_SIZE).to(DEVICE)
    inputs = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)

    # Reference output from calling the model before the outer torch.compile.
    with torch.no_grad():
        reference_out = model(inputs)

    # After the first call both blocks carry compiled_code but are not torch.compile'd.
    assert_magi_compiled(model.block1)
    assert_magi_compiled(model.block2)
    assert_not_torch_compiled_or_disabled(model.block1)
    assert_not_torch_compiled_or_disabled(model.block2)

    outer_compiled = torch.compile(model, fullgraph=False, dynamic=True)
    assert_torch_compiled(outer_compiled)
    # Wrapping the outer model must not make the inner blocks look torch.compile'd.
    assert_not_torch_compiled_or_disabled(model.block1)
    assert_not_torch_compiled_or_disabled(model.block2)
    with torch.no_grad():
        compiled_out = outer_compiled(inputs)

    assert torch.allclose(reference_out, compiled_out, atol=TOLERANCE, rtol=TOLERANCE)


# ============ torch.compile 使用装饰器 + magi_compile 嵌套 ============


def test_decorator_torch_compile_on_forward():
    """``@torch.compile`` on ``forward``: the module type is unchanged but ``is_torch_compiled`` is True."""

    class MyModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            projected = self.linear(x)
            return torch.relu(projected)

    eager_model = MyModel(HIDDEN_SIZE).to(DEVICE)
    inputs = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)

    # Eager reference output.
    with torch.no_grad():
        eager_out = eager_model(inputs)

    # Same architecture, but with forward decorated by @torch.compile.
    class CompiledModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        @torch.compile(fullgraph=False, dynamic=True)
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            projected = self.linear(x)
            return torch.relu(projected)

    decorated_model = CompiledModel(HIDDEN_SIZE).to(DEVICE)
    # Share the eager model's weights so the outputs are comparable.
    decorated_model.load_state_dict(eager_model.state_dict())

    assert type(decorated_model).__name__ == "CompiledModel"
    assert_torch_compiled(decorated_model)

    with torch.no_grad():
        decorated_out = decorated_model(inputs)

    assert torch.allclose(eager_out, decorated_out, atol=TOLERANCE, rtol=TOLERANCE)


def test_decorator_nested_torch_compile_forward_magi_inner():
    """Outer ``forward`` decorated with ``@torch.compile`` plus an inner ``magi_compile`` module."""

    class InnerBlock(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    class OuterModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.inner = InnerBlock(hidden_size)
            self.output = nn.Linear(hidden_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.inner(x)
            return self.output(x)

    eager_model = OuterModel(HIDDEN_SIZE).to(DEVICE)
    inputs = torch.randn(4, 16, HIDDEN_SIZE, device=DEVICE)

    # Reference output, taken before magi_compile is applied to InnerBlock.
    with torch.no_grad():
        reference_out = eager_model(inputs)

    # Build the variant: magi_compile'd inner block + @torch.compile'd outer forward.
    magi_inner_cls = magi_compile()(InnerBlock)

    class CompiledOuterModel(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.inner = magi_inner_cls(hidden_size)
            self.output = nn.Linear(hidden_size, hidden_size)

        @torch.compile(fullgraph=False, dynamic=True)
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.inner(x)
            return self.output(x)

    nested_model = CompiledOuterModel(HIDDEN_SIZE).to(DEVICE)
    # Share the reference model's weights so the outputs are comparable.
    nested_model.load_state_dict(eager_model.state_dict())

    assert_torch_compiled(nested_model)

    with torch.no_grad():
        nested_out = nested_model(inputs)

    # The inner block's compiled_code is checked only after the forward call above.
    assert_magi_compiled(nested_model.inner)
    assert_not_torch_compiled_or_disabled(nested_model.inner)
    assert torch.allclose(reference_out, nested_out, atol=TOLERANCE, rtol=TOLERANCE)


# ============ torch._dynamo.config 正确性验证 ============


def test_dynamo_config_nested_patch_restore():
    """验证 config.patch() 嵌套时能正确恢复到上一层的值"""
    import torch._dynamo.config as config

    # 记录初始值
    initial_value = config.assume_static_by_default

    # 模拟外层 compile 设置 dynamic=True (assume_static_by_default=False)
    with config.patch(assume_static_by_default=False):
        assert config.assume_static_by_default is False, "外层 patch 应将值设为 False"

        # 模拟内层 magi_compile 恢复默认 (assume_static_by_default=True)
        with config.patch(assume_static_by_default=True):
            assert config.assume_static_by_default is True, "内层 patch 应将值设为 True"

        # 内层退出后,应该恢复到外层的值
        assert config.assume_static_by_default is False, "内层退出后应恢复到外层值 False"

    # 外层退出后,应该恢复到初始值
    assert config.assume_static_by_default == initial_value, f"外层退出后应恢复到初始值 {initial_value}"


def test_dynamo_config_multiple_options_patch():
    """验证同时 patch 多个配置项时的正确性"""
    import torch._dynamo.config as config

    # 记录初始值
    initial_assume_static = config.assume_static_by_default
    initial_suppress_errors = config.suppress_errors
    initial_verbose = config.verbose

    # 同时 patch 多个配置项
    with config.patch(
        assume_static_by_default=not initial_assume_static,
        suppress_errors=not initial_suppress_errors,
        verbose=not initial_verbose,
    ):
        # 验证所有配置项都已修改
        assert config.assume_static_by_default == (not initial_assume_static), "assume_static_by_default 应被修改"
        assert config.suppress_errors == (not initial_suppress_errors), "suppress_errors 应被修改"
        assert config.verbose == (not initial_verbose), "verbose 应被修改"

        # 嵌套 patch 部分配置项
        with config.patch(assume_static_by_default=initial_assume_static):
            assert config.assume_static_by_default == initial_assume_static, "内层应恢复 assume_static_by_default"
            # 其他配置项应保持外层 patch 的值
            assert config.suppress_errors == (not initial_suppress_errors), "suppress_errors 应保持外层值"
            assert config.verbose == (not initial_verbose), "verbose 应保持外层值"

        # 内层退出后,assume_static_by_default 应恢复到外层 patch 的值
        assert config.assume_static_by_default == (not initial_assume_static), "内层退出后应恢复到外层 patch 值"

    # 外层退出后,所有配置项都应恢复到初始值
    assert config.assume_static_by_default == initial_assume_static, "外层退出后 assume_static_by_default 应恢复"
    assert config.suppress_errors == initial_suppress_errors, "外层退出后 suppress_errors 应恢复"
    assert config.verbose == initial_verbose, "外层退出后 verbose 应恢复"


def test_dynamo_config_restore_on_exception():
    """验证在 with 块内抛出异常时配置能正确恢复"""
    import torch._dynamo.config as config

    # 记录初始值
    initial_value = config.assume_static_by_default

    # 测试单层 patch 在异常时的恢复
    try:
        with config.patch(assume_static_by_default=not initial_value):
            assert config.assume_static_by_default == (not initial_value), "patch 应生效"
            raise RuntimeError("测试异常")
    except RuntimeError:
        pass

    # 异常后配置应恢复
    assert config.assume_static_by_default == initial_value, "单层异常后应恢复到初始值"

    # 测试嵌套 patch 在内层异常时的恢复
    try:
        with config.patch(assume_static_by_default=False):
            assert config.assume_static_by_default is False, "外层 patch 应生效"
            try:
                with config.patch(assume_static_by_default=True):
                    assert config.assume_static_by_default is True, "内层 patch 应生效"
                    raise ValueError("内层测试异常")
            except ValueError:
                pass
            # 内层异常捕获后,应恢复到外层值
            assert config.assume_static_by_default is False, "内层异常后应恢复到外层值"
    except Exception:
        pytest.fail("外层不应捕获到异常")

    # 最终应恢复到初始值
    assert config.assume_static_by_default == initial_value, "嵌套异常后应恢复到初始值"


if __name__ == "__main__":
    # pytest.main returns the run's exit code; pass it on so that running this
    # file directly reports failures to the calling shell.
    raise SystemExit(pytest.main([__file__, "-v", "-s"]))