# Why MagiCompiler?

## 1. Compiler Overview

### 1.1 Background

We have long encountered several significant challenges in model optimization:

1.  **Blurred Acceleration Boundaries:** It is unclear how far optimization must go, and across what scope, to reach "extreme" performance.
2.  **Complex Performance Tuning:** Optimization strategies are often tightly coupled to model architectures, which requires extensive and repetitive manual work.
3.  **Deficiency in Optimization Tools:** The existing infrastructure lacks mechanisms for graph-level optimizations, such as operator substitution and computation/communication overlap.

MagiCompiler addresses these issues through the following approaches:

*   **Addressing Challenge 1:** It adopts **whole-graph compilation**, compiling across `TransformerLayer` boundaries to maximize the scope of kernel fusion.
*   **Addressing Challenge 2:** It builds infrastructure optimizations directly into the compiler, with features such as `AutoCudaGraph` and `AutoCheckpointing` (WIP).
*   **Addressing Challenge 3:** It uses the dynamic-to-static capability of **Dynamo** to capture an FX graph (`torch.fx.Graph`) IR from eager-mode code, and runs optimization passes on that IR.

#### Illustrative Example

```python
import torch
from torch import nn, no_grad

from magi_compiler import magi_compile


@magi_compile()
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1024, 1024, device="cuda")

    @no_grad()
    def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        # The elementwise part (x + y - z + 1) is what gets fused into one kernel.
        return self.linear(x + y - z + 1)


def magi_compiler_demo():
    model = TinyModel()
    x = torch.randn(1024, 1024, device="cuda")
    y = torch.randn(1024, 1024, device="cuda")
    z = torch.randn(1024, 1024, device="cuda")
    # The first call triggers compilation; later calls reuse the compiled code.
    model(x, y, z)


if __name__ == "__main__":
    magi_compiler_demo()
```

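**Generated Code (Triton Kernel):**

Inductor fuses the three elementwise operations in `x + y - z + 1` into the single pointwise kernel below: each input is loaded once and the result is stored once, instead of materializing intermediate tensors. The matmul inside `nn.Linear` is not part of this kernel.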

```python
triton_poi_fused_add_sub_0 = async_compile.triton('triton_poi_fused_add_sub_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 1048576},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'B8F4209CBFC2377D6AF9CF3C65D610CA2B56C138A443862350DE1E56F5BF54C3', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_sub_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.load(in_ptr1 + (x0), xmask)
    tmp3 = tl.load(in_ptr2 + (x0), xmask)
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 - tmp3
    tmp5 = 1.0
    tmp6 = tmp4 + tmp5
    tl.store(out_ptr0 + (x0), tmp6, xmask)
''', device_str='cuda')
```
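
To read the generated kernel, it helps to see its indexing spelled out. The plain-Python emulation below is illustrative only (Triton runs the lanes of each block in parallel on the GPU): program `pid` handles the `XBLOCK` elements starting at `pid * XBLOCK`, and `xmask` disables lanes past `xnumel` in the last, partially filled block. In the real kernel `XBLOCK` is picked by Inductor's launch heuristics; the value used here is arbitrary.

```python
import math


def fused_add_sub_reference(in0, in1, in2, xblock=128):
    """Pure-Python emulation of triton_poi_fused_add_sub_0's indexing (illustrative only)."""
    xnumel = len(in0)
    out = [0.0] * xnumel
    num_programs = math.ceil(xnumel / xblock)  # 1D grid: one program per block
    for pid in range(num_programs):            # plays the role of tl.program_id(0)
        xoffset = pid * xblock
        for lane in range(xblock):             # plays the role of tl.arange(0, XBLOCK)
            x0 = xoffset + lane
            if x0 < xnumel:                    # xmask: skip lanes past the end
                tmp2 = in0[x0] + in1[x0]
                tmp4 = tmp2 - in2[x0]
                out[x0] = tmp4 + 1.0
    return out


# Three elements with XBLOCK=2: two programs, the second one half-masked.
print(fused_add_sub_reference([1.0, 2.0, 3.0], [1.0, 1.0, 1.0], [0.5, 0.5, 0.5], xblock=2))
# [2.5, 3.5, 4.5]
```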

### 1.2 Frontend (Dynamo)

![Dynamo](./assets/why_magicompiler_1_dynamo.jpeg)

*   **PyFrameObject (Dynamic Call Stack):**
    *   The runtime context of a single function call: its local variables, value stack, and current instruction. Python creates a new `PyFrameObject` for each function call.
*   **PyCodeObject (Static Bytecode):**
    *   The compiled bytecode of a function, together with its constants and names. It is static and read-only: a single `PyCodeObject` is shared by every invocation of the function.
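
The difference is observable from plain Python (CPython, no Dynamo involved). In the sketch below, two calls to the same function produce two distinct frame objects that share one code object. Because the code object is shared across calls, it is a natural place for Dynamo to attach its cache of compiled entries.

```python
import inspect


def probe(value):
    # inspect.currentframe() returns the PyFrameObject of the running call (CPython).
    frame = inspect.currentframe()
    return frame, value * 2


frame_a, result_a = probe(1)
frame_b, result_b = probe(2)

# Two calls give two distinct frame objects (both are still referenced, so none is reused).
print(frame_a is frame_b)  # False
# Both frames execute the same shared code object.
print(frame_a.f_code is frame_b.f_code is probe.__code__)  # True
print(result_a, result_b)  # 2 4
```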

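Dynamo installs a custom frame-evaluation hook (PEP 523). Before CPython executes a frame, Dynamo inspects the frame's `PyCodeObject`, symbolically evaluates its bytecode to extract an FX graph plus guards on the inputs, and runs transformed code in place of the original. The pseudocode below, adapted from the depyf walkthrough listed in the References, shows the resulting behavior of a function `f`:

```python
# Conceptual pseudocode: what a call to f(x, mod) effectively does once Dynamo
# has intercepted its frame. compile_and_optimize / FailToCompileError are placeholders.
def f(x, mod):
    # 1. Reuse a previously compiled version whose guards accept these inputs.
    for guard, transformed_code in f.compiled_entries:
        if guard(x, mod):
            return transformed_code(x, mod)
    # 2. Cache miss: trace the bytecode, compile the captured graph, record its guards.
    try:
        guard, transformed_code = compile_and_optimize(x, mod)
        f.compiled_entries.append([guard, transformed_code])
        return transformed_code(x, mod)
    # 3. Compilation failed: fall back to running the original code eagerly.
    except FailToCompileError:
        y = mod(x)
        z = torch.log(y)
        return z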
```
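
To make the guard/cache loop concrete, here is a small runnable toy in plain Python. It is a simplified model, not Dynamo code: the hypothetical "compiler" `specialize_on_length` specializes on the input length (standing in for a static shape), the guard re-checks that length, and inputs it cannot handle fall back to eager execution.

```python
class GuardedCache:
    """Toy model of a per-function cache of (guard, compiled_fn) entries."""

    def __init__(self, fn, compile_fn):
        self.fn = fn                  # original "eager" function
        self.compile_fn = compile_fn  # returns (guard, compiled_fn) or raises
        self.entries = []             # like f.compiled_entries in the pseudocode
        self.num_compiles = 0
        self.num_fallbacks = 0

    def __call__(self, xs):
        # 1. Try cached entries whose guard accepts this input.
        for guard, compiled in self.entries:
            if guard(xs):
                return compiled(xs)
        # 2. Cache miss: "compile" a version specialized to this input.
        try:
            guard, compiled = self.compile_fn(self.fn, xs)
        except NotImplementedError:
            # 3. Cannot compile: run the original function eagerly.
            self.num_fallbacks += 1
            return self.fn(xs)
        self.num_compiles += 1
        self.entries.append((guard, compiled))
        return compiled(xs)


def specialize_on_length(fn, xs):
    """Hypothetical compiler: bakes len(xs) into the result, like a static shape."""
    n = len(xs)
    if n == 0:
        raise NotImplementedError("empty inputs are not supported")

    def guard(ys):
        return len(ys) == n  # the specialization must still hold

    def compiled(ys):
        return fn(ys)  # a real backend would emit optimized code here

    return guard, compiled


def total(xs):
    return sum(xs)


cached = GuardedCache(total, specialize_on_length)
cached([1, 2, 3])  # miss -> compile #1 (length 3)
cached([4, 5, 6])  # guard hit -> reuse compile #1
cached([7, 8])     # guard miss -> compile #2 (length 2)
cached([])         # compile fails -> eager fallback
print(cached.num_compiles, cached.num_fallbacks)  # 2 1
```

Every new length costs a recompilation in this toy, which is exactly the cost that symbolic shapes are meant to avoid.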

#### Symbolic Shape

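MagiCompiler specifically targets the Transformer architecture and supports custom `dynamic_arg_dims`: a mapping from argument name to the dimension that should be treated as symbolic (typically `seq_len`), so that inputs differing only in that dimension can share one compiled graph.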

**Example:**

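```python
import torch
from torch import nn, no_grad

from magi_compiler import magi_compile


# Dimension 0 (seq_len) of x, y and z is treated as symbolic.
@magi_compile(dynamic_arg_dims={"x": 0, "y": 0, "z": 0})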
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1024, 1024, device="cuda")

    @no_grad()
    def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        return self.linear(x + y - z + 1)
```
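
A rough way to see what marking a dimension dynamic buys: the toy below (plain Python, not MagiCompiler code) models the cache key of a compiled graph. Specialized dimensions are part of the key, so every new `seq_len` forces a recompile; a dimension listed as dynamic is replaced by a symbol, except for sizes 0 and 1, which Dynamo still specializes (hence the `[2, int_oo]` ranges in the log below).

```python
def count_compiles(shapes, dynamic_dims):
    """Toy model: number of distinct graphs needed for a stream of input shapes."""
    cache_keys = set()
    for shape in shapes:
        key = tuple(
            # A dynamic dim becomes a symbol, unless it is 0 or 1 (still specialized).
            "sym" if i in dynamic_dims and d >= 2 else d
            for i, d in enumerate(shape)
        )
        cache_keys.add(key)
    return len(cache_keys)


# Inputs of shape (seq_len, 1024) with varying seq_len, as for TinyModel above.
shapes = [(seq_len, 1024) for seq_len in (1024, 2048, 4096, 1024)]
print(count_compiles(shapes, dynamic_dims=set()))  # 3: one graph per distinct seq_len
print(count_compiles(shapes, dynamic_dims={0}))    # 1: dim 0 is symbolic
```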

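**Guard Mechanism and Elimination in Symbolic Shape Deduction:**

The Dynamo log below, produced by the `dynamic_arg_dims` example above, shows symbolic shapes at work: one symbol is created per dynamic dimension (`s33`, `s6`, `s21`), broadcasting in `x + y - z` adds equality guards (`Eq(s33, s6)`, `Eq(s33, s21)`), and each guard lets the shape environment replace one symbol by another (`set_replacement`), so all three dimensions end up expressed by the single symbol `s21`.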

```log
I1204 16:31:35.745000 1859360 torch/_dynamo/symbolic_convert.py:3842] [0/0] Step 1: torchdynamo start tracing inner /usr/local/lib/python3.12/dist-packages/torch/_dynamo/external_utils.py:66
I1204 16:31:35.746000 1859360 torch/fx/experimental/symbolic_shapes.py:3775] [0/0] create_env
I1204 16:31:35.781000 1859360 torch/fx/experimental/symbolic_shapes.py:5120] [0/0] create_symbol s33 = 1024 for L['args'][0].size()[0] [2, int_oo] return self.linear(x + y - z + 1)  # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_dynamo/variables/builder.py:3501 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s33" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1204 16:31:35.785000 1859360 torch/fx/experimental/symbolic_shapes.py:5120] [0/0] create_symbol s6 = 1024 for L['args'][1].size()[0] [2, int_oo] return self.linear(x + y - z + 1)  # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_dynamo/variables/builder.py:3501 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s6" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1204 16:31:35.794000 1859360 torch/fx/experimental/symbolic_shapes.py:7213] [0/0] eval Eq(s33, s6) [guard added] return self.linear(x + y - z + 1)  # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s33, s6)"
I1204 16:31:35.795000 1859360 torch/fx/experimental/symbolic_shapes.py:6792] [0/0] set_replacement s6 = s33 (solve) VR[2, int_oo]
I1204 16:31:35.800000 1859360 torch/fx/experimental/symbolic_shapes.py:5120] [0/0] create_symbol s21 = 1024 for L['args'][2].size()[0] [2, int_oo] return self.linear(x + y - z + 1)  # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_dynamo/variables/builder.py:3501 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s21" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1204 16:31:35.806000 1859360 torch/fx/experimental/symbolic_shapes.py:7213] [0/0] eval Eq(s33, s21) [guard added] return self.linear(x + y - z + 1)  # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s33, s21)"
I1204 16:31:35.807000 1859360 torch/fx/experimental/symbolic_shapes.py:6792] [0/0] set_replacement s33 = s21 (solve) VR[2, int_oo]
I1204 16:31:35.828000 1859360 torch/_dynamo/symbolic_convert.py:4059] [0/0] Step 1: torchdynamo done tracing inner (RETURN_VALUE)
I1204 16:31:35.837000 1859360 torch/fx/experimental/symbolic_shapes.py:6792] [0/0] set_replacement s6 = s21 (find) VR[2, int_oo]
```
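
The replacement bookkeeping can be replayed with a small toy model. It is a simplified sketch for illustration, not PyTorch's `ShapeEnv`; in particular, which symbol survives each replacement is copied from the log rather than computed.

```python
class ToyShapeEnv:
    """Toy symbolic-shape bookkeeping for illustration; NOT PyTorch's ShapeEnv."""

    def __init__(self):
        self.hints = {}         # symbol -> concrete size seen while tracing
        self.guards = []        # (lhs, rhs) pairs, each meaning Eq(lhs, rhs)
        self.replacements = {}  # symbol -> symbol it has been replaced by

    def create_symbol(self, name, hint):
        # Mirrors the [2, int_oo] range in the log: sizes 0/1 are specialized instead.
        if hint < 2:
            raise ValueError("sizes 0 and 1 are specialized, not made symbolic")
        self.hints[name] = hint
        return name

    def add_eq_guard(self, lhs, rhs):
        # Broadcasting x + y needs equal sizes, so the graph is only valid if Eq holds.
        self.guards.append((lhs, rhs))

    def find(self, sym):
        # Follow the replacement chain to the surviving symbol.
        while sym in self.replacements:
            sym = self.replacements[sym]
        return sym

    def set_replacement(self, sym, target):
        # Once a guard proves sym == target, sym can be rewritten as target everywhere.
        root, new_root = self.find(sym), self.find(target)
        if root != new_root:
            self.replacements[root] = new_root

    def guards_hold(self, sizes):
        # In this toy, a cached graph is reusable only if every guard still holds.
        return all(sizes[lhs] == sizes[rhs] for lhs, rhs in self.guards) and all(
            value >= 2 for value in sizes.values()
        )


# Replay the log for x, y, z of shape (1024, 1024) with dim 0 dynamic.
env = ToyShapeEnv()
s33 = env.create_symbol("s33", 1024)  # L['args'][0].size()[0]
s6 = env.create_symbol("s6", 1024)    # L['args'][1].size()[0]
env.add_eq_guard(s33, s6)             # eval Eq(s33, s6) [guard added]
env.set_replacement(s6, s33)          # set_replacement s6 = s33
s21 = env.create_symbol("s21", 1024)  # L['args'][2].size()[0]
env.add_eq_guard(s33, s21)            # eval Eq(s33, s21) [guard added]
env.set_replacement(s33, s21)         # set_replacement s33 = s21

# Resolving s6 through the chain gives s21, as in the final "(find)" log line.
print({s: env.find(s) for s in ("s33", "s6", "s21")})
# {'s33': 's21', 's6': 's21', 's21': 's21'}
print(env.guards_hold({"s33": 4096, "s6": 4096, "s21": 4096}))  # True
print(env.guards_hold({"s33": 4096, "s6": 2048, "s21": 4096}))  # False
```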

### 1.3 Backend (Inductor, MagiBackend, etc.)

![Backend Architecture](./assets/why_magicompiler_2_arch.png)

MagiCompiler hooks into the `torch.compile` pipeline through the following extension points:

*   **`custom_partitioner_fn`:** Splits the joint graph into forward and backward graphs and decides which intermediate results are saved for the backward pass.
*   **`post_grad_custom_pre_pass`:** Runs optimization passes on the whole graph (graph pattern matching and rewriting).
*   **`PartitionFunc`:** Implements custom subgraph partitioning, using attention operators as split points.

![Partition](./assets/why_magicompiler_3_partition.png)

*   **`post_grad_custom_post_pass`:** Executes pass optimizations at the subgraph level (computation/communication overlap).
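
The idea behind `PartitionFunc` can be sketched on a flat list of operator names. This is a hypothetical toy, not MagiCompiler's implementation (which operates on FX graphs), and it assumes each attention op becomes its own subgraph while the operators between attention ops are grouped together, giving subgraph-level passes such as `post_grad_custom_post_pass` self-contained units to work on.

```python
def split_at_attention(ops, is_attention=lambda op: "attn" in op):
    """Split a linear op sequence into subgraphs, isolating each attention op."""
    subgraphs, current = [], []
    for op in ops:
        if is_attention(op):
            if current:              # close the subgraph before the split point
                subgraphs.append(current)
                current = []
            subgraphs.append([op])   # the attention op stands alone
        else:
            current.append(op)
    if current:                      # trailing ops after the last split point
        subgraphs.append(current)
    return subgraphs


# Two transformer layers (hypothetical op names).
ops = ["norm", "qkv_proj", "flash_attn", "o_proj", "mlp",
       "norm", "qkv_proj", "flash_attn", "o_proj", "mlp"]
for i, sub in enumerate(split_at_attention(ops)):
    print(i, sub)
# 0 ['norm', 'qkv_proj']
# 1 ['flash_attn']
# 2 ['o_proj', 'mlp', 'norm', 'qkv_proj']
# 3 ['flash_attn']
# 4 ['o_proj', 'mlp']
```

Note that the non-attention subgraphs span layer boundaries, which is consistent with compiling the whole graph rather than one `TransformerLayer` at a time.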

---

## 2. Best Practices

### 2.1 Model Adaptation

MagiCompiler has certain limitations: it requires whole-graph capture and cannot handle implicit graph breaks (cases where Dynamo would normally split the graph and run the unsupported code eagerly). Consequently, manual adaptation is required in the following scenarios:

**1. Data-Dependent Control Flow or CPU/GPU Synchronization**

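```python
import torch

from magi_compiler import magi_compile


@magi_compile
class MeanModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        x = x.cos().sin()
        # Problem: branching on a tensor value calls bool(x.mean() > 0.5), which forces a
        # GPU-to-CPU sync and cannot be captured in a single graph (a graph break).
        if x.mean() > 0.5:
            x = x - 1
        return x * y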
```
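
If the branch has to stay inside the compiled region, one common rewrite (a general PyTorch technique, not a MagiCompiler-specific API) is to express the data-dependent `if` with `torch.where`, so the condition stays a tensor: no `bool()` conversion, no GPU-to-CPU synchronization, and no graph break. The trade-off is that both branches are computed. The class name `MeanModuleBranchless` is illustrative.

```python
import torch


class MeanModuleBranchless(torch.nn.Module):
    """Same math as MeanModule, without Python control flow on tensor values."""

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        x = x.cos().sin()
        # The 0-dim boolean condition broadcasts over x and never leaves the device.
        x = torch.where(x.mean() > 0.5, x - 1, x)
        return x * y


def mean_module_eager(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Reference: the original data-dependent version, run eagerly.
    x = x.cos().sin()
    if x.mean() > 0.5:
        x = x - 1
    return x * y
```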

> **Note:** In typical Transformer models, some pre- and post-processing code that cannot be captured in a single graph is unavoidable. Therefore, the recommended practice for `magi_compiler` is to perform **whole-graph capture at the `TransformerBlock` level**, since `TransformerBlock` computations account for over 95% of the total workload.

**2. Custom Operators (e.g., FlashAttention, FlexFlashAttention, MoE kernels)**

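*   **Operator Registration:** Custom kernels must be registered as PyTorch custom operators, together with a fake implementation for shape inference, so that Dynamo can capture them as opaque nodes instead of breaking the graph. Commonly used operators such as FlashAttention (FA) and FlexFlashAttention (FFA) are already registered.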

```python
import torch

# Operator registration
@torch.library.custom_op("athena::flash_attn_func", mutates_args=())
def flash_attn_func(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    ...  # call the actual FlashAttention kernel here

# Fake (meta) implementation: shape/dtype inference only, used while tracing
@flash_attn_func.register_fake
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(query)

# Calling the registered operators (e.g. inside a module's forward)
self_attn_out = torch.ops.athena.flash_attn_func(q, k, v)
# flex_flash_attn_func is registered the same way (registration not shown)
out, _ = torch.ops.athena.flex_flash_attn_func(q, k, v, q_ranges=ffa_handler.q_ranges, k_ranges=ffa_handler.k_ranges)
```

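*   **Unit Testing:** Each registered operator should have its own unit test in the production environment. `torch.library.opcheck` verifies that a registration, including its fake implementation, is consistent with the real kernel.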

```python
import pytest
import torch

@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("seq_len", [1024, 2048, 4096])
@pytest.mark.parametrize("query_head", [48])
@pytest.mark.parametrize("kv_head", [4, 8])
@pytest.mark.parametrize("head_dim", [128, 256])
def test_fake_fa3(batch_size, seq_len, query_head, kv_head, head_dim):
    q = torch.randn((batch_size, seq_len, query_head, head_dim), device="cuda", dtype=torch.bfloat16)
    k = torch.randn((batch_size, seq_len, kv_head, head_dim), device="cuda", dtype=torch.bfloat16)
    v = torch.randn((batch_size, seq_len, kv_head, head_dim), device="cuda", dtype=torch.bfloat16)
    torch.library.opcheck(torch.ops.athena.flash_attn_func, (q, k, v))
```

### 2.2 Debugging Methods

Key questions for debugging:
*   Is the bug originating from the compiler?
*   Which specific component of the compiler is causing the bug?

![Debugging](./assets/why_magicompiler_4_debug.png)

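```python
from pydantic import BaseModel, Field

# Excerpt: CompileMode, CudaGraphMode and PassConfig are defined in magi_compiler.
# Changing these fields one at a time helps locate the component behind a bug.
class CompileConfig(BaseModel):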
    # Basic configs
    backend: str = Field("inductor", description="Compilation backend.")
    compile_mode: CompileMode = Field(CompileMode.MAGI_COMPILE, description="Compilation mode.")
    ...

    # Cudagraph configs
    cudagraph_mode: CudaGraphMode = Field(CudaGraphMode.NONE, description="Cudagraph mode.")
    ...

    # Pass configs
    pass_config: PassConfig = Field(PassConfig(), description="Pass configuration.")
    ...
```
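
One way to answer the second question is to bisect the compilation stack: rerun the model with progressively more of the compiler enabled and stop at the first configuration whose output diverges from eager. The sketch below is hypothetical. The stage list and the override keys (`custom_passes`, `cudagraphs`) are assumptions to be mapped onto the real `CompileConfig` fields (`backend`, `pass_config`, `cudagraph_mode`); `"eager"` and `"aot_eager"` are standard `torch.compile` debugging backends, but whether MagiCompiler's `backend` field accepts them is also an assumption.

```python
from typing import Callable, Dict, List, Optional, Tuple

Stage = Tuple[str, Dict[str, object]]

# Hypothetical stages, from least to most of the compiler enabled.
# "custom_passes" / "cudagraphs" are placeholder keys, not real CompileConfig fields.
STAGES: List[Stage] = [
    ("dynamo capture only", {"backend": "eager"}),
    ("+ AOTAutograd", {"backend": "aot_eager"}),
    ("+ Inductor codegen", {"backend": "inductor"}),
    ("+ custom passes", {"backend": "inductor", "custom_passes": True}),
    ("+ CUDA graphs", {"backend": "inductor", "custom_passes": True, "cudagraphs": True}),
]


def first_failing_stage(
    stages: List[Stage],
    matches_eager: Callable[[Dict[str, object]], bool],
) -> Optional[str]:
    """Return the first stage whose output diverges from eager, or None if all match."""
    for name, overrides in stages:
        if not matches_eager(overrides):
            return name
    return None


# Example: pretend the bug only appears once custom passes are enabled.
def fake_run(overrides: Dict[str, object]) -> bool:
    return not overrides.get("custom_passes", False)


print(first_failing_stage(STAGES, fake_run))  # + custom passes
```

In practice `matches_eager` would build a `CompileConfig` from the overrides, run the model, and compare the result against an eager reference within a tolerance.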

### 2.3 Profiling Results

For further details, please refer to the [**Wan2.2 Benchmark**](Wan2.2Benchmark.md).

---

## References

1.  [PyTorch 2.0 Overview](https://docs.pytorch.org/assets/pytorch2-2.pdf)
2.  [TorchDynamo: An Experiment in Dynamic Python Bytecode Transformation](https://dev-discuss.pytorch.org/t/torchdynamo-an-experiment-in-dynamic-python-bytecode-transformation/361)
3.  [Depyf Walkthrough](https://depyf.readthedocs.io/en/latest/walk_through.html)
4.  [Getting Started with PyTorch Compiler](https://docs.pytorch.org/docs/main/torch.compiler_get_started.html)