# Why MagiCompiler?
## 1. Compiler Overview
### 1.1 Background
We have long encountered several significant challenges in model optimization:
1. **Blurred Acceleration Boundaries:** There is ambiguity regarding the extent of optimization required to achieve "extreme" performance.
2. **Complex Performance Tuning:** Optimization strategies are often tightly coupled with model architectures, necessitating extensive and repetitive manual intervention.
3. **Deficiency in Optimization Tools:** The infrastructure lacks sufficient mechanisms for computational graph-level optimizations, such as operator substitution and communication overlap.
MagiCompiler addresses these issues through the following approaches:
* **Addressing Challenge 1:** It adopts **whole-graph compilation**, extending past the boundaries of a single `TransformerLayer` to maximize the scope of kernel fusion.
* **Addressing Challenge 2:** It builds infrastructure optimizations directly into MagiCompiler, with features such as `AutoCudaGraph` and `AutoCheckpointing` (WIP).
* **Addressing Challenge 3:** It leverages the dynamic-to-static capture provided by **Dynamo**, which traces eager-mode code into `fx.Graph` IR so that optimization passes can run at the IR level.
#### Illustrative Example
```python
import torch
import torch.nn as nn

from magi_compiler import magi_compile


@magi_compile()
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1024, 1024, device="cuda")

    @torch.no_grad()
    def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        # The pointwise chain x + y - z + 1 is a candidate for fusion into one kernel.
        return self.linear(x + y - z + 1)


def magi_compiler_demo():
    model = TinyModel()
    x = torch.randn(1024, 1024, device="cuda")
    y = torch.randn(1024, 1024, device="cuda")
    z = torch.randn(1024, 1024, device="cuda")
    model(x, y, z)
```
**Optimized Code (Triton Kernel):**
```python
triton_poi_fused_add_sub_0 = async_compile.triton('triton_poi_fused_add_sub_0', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 1048576},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'B8F4209CBFC2377D6AF9CF3C65D610CA2B56C138A443862350DE1E56F5BF54C3', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_sub_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.load(in_ptr1 + (x0), xmask)
    tmp3 = tl.load(in_ptr2 + (x0), xmask)
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 - tmp3
    tmp5 = 1.0
    tmp6 = tmp4 + tmp5
    tl.store(out_ptr0 + (x0), tmp6, xmask)
''', device_str='cuda')
```
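To make the kernel's indexing easier to follow, here is a plain-Python emulation of what each Triton program instance does: it handles one `XBLOCK`-sized slice of the flattened tensors and masks out offsets past `xnumel`. The helper name `emulate_fused_add_sub` and the tiny input sizes are illustrative assumptions; this is a sketch of the access pattern, not code that Inductor generates.

```python
def emulate_fused_add_sub(in0, in1, in2, xblock):
    """Emulate triton_poi_fused_add_sub_0 on flat Python lists (illustrative only)."""
    xnumel = len(in0)
    out = [0.0] * xnumel
    num_programs = (xnumel + xblock - 1) // xblock  # ceil-div: size of the 1D launch grid
    for pid in range(num_programs):                 # plays the role of tl.program_id(0)
        xoffset = pid * xblock
        for lane in range(xblock):                  # plays the role of tl.arange(0, XBLOCK)
            x0 = xoffset + lane
            if x0 >= xnumel:                        # xmask: skip out-of-range lanes
                continue
            # One load per input and a single store, with no intermediate buffers.
            out[x0] = in0[x0] + in1[x0] - in2[x0] + 1.0
    return out


x = [1.0, 2.0, 3.0, 4.0, 5.0]
y = [10.0, 20.0, 30.0, 40.0, 50.0]
z = [0.5, 0.5, 0.5, 0.5, 0.5]
# 5 elements with xblock=4 means 2 programs; the second one is mostly masked.
result = emulate_fused_add_sub(x, y, z, xblock=4)
```

The benefit of the fusion is that `x + y - z + 1` runs in a single pass over memory: three loads and one store per element, which matches `'num_load': 3` in the kernel metadata above.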
### 1.2 Frontend (Dynamo)

* **PyFrameObject (Dynamic Call Stack):**
    * Represents the execution context of a function call. Python creates a new `PyFrameObject` for each call.
* **PyCodeObject (Static Bytecode):**
    * The compiled form of Python code, which is static and read-only. A single `PyCodeObject` is shared by all invocations of the function.
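
The distinction between the two objects can be observed directly from Python. The following sketch relies on `sys._getframe`, which is specific to CPython; the function name `record_call` is illustrative.

```python
import sys


def record_call(log):
    """Record the frame object and code object seen by this particular call."""
    frame = sys._getframe()               # the frame created for this call
    log.append((frame, frame.f_code))     # f_code is the function's code object


calls = []
record_call(calls)
record_call(calls)

(frame_a, code_a), (frame_b, code_b) = calls
same_frame = frame_a is frame_b           # False: one frame per call
same_code = code_a is code_b              # True: the bytecode is shared
```

Because the code object is shared while frames are per-call, a tool that wants to change what a function executes can do so once at the code-object level and have every later call pick it up.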
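```python
# Conceptual pseudocode for how Dynamo treats a function `f`:
# try cached compiled entries first, compile on a miss, fall back to eager on failure.
def f(x, mod):
    for guard, transformed_code in f.compiled_entries:
        if guard(x, mod):  # the assumptions made at compile time still hold
            return transformed_code(x, mod)
    try:
        guard, transformed_code = compile_and_optimize(x, mod)
        f.compiled_entries.append([guard, transformed_code])
        return transformed_code(x, mod)
    except FailToCompileError:
        # Original eager body of f, used when compilation fails.
        y = mod(x)
        z = torch.log(y)
        return z
```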
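The pseudocode above can be turned into a small runnable toy to show how guards drive cache hits and recompilation. Everything here is a stand-in: `compile_for` is a hypothetical "compiler" that specializes on the input length, and `GuardedFunction` is not part of Dynamo or MagiCompiler.

```python
from typing import Callable, List, Tuple

Guard = Callable[[list], bool]
Compiled = Callable[[list], float]


def compile_for(xs: list) -> Tuple[Guard, Compiled]:
    """Hypothetical compiler: specializes on the input length seen at trace time."""
    n = len(xs)
    guard = lambda v: len(v) == n      # assumption made during tracing
    compiled = lambda v: sum(v) / n    # n is baked in as a constant
    return guard, compiled


class GuardedFunction:
    def __init__(self):
        self.compiled_entries: List[Tuple[Guard, Compiled]] = []
        self.num_compiles = 0

    def __call__(self, xs: list) -> float:
        for guard, compiled in self.compiled_entries:
            if guard(xs):                  # cache hit: reuse the compiled code
                return compiled(xs)
        guard, compiled = compile_for(xs)  # cache miss: compile and remember
        self.num_compiles += 1
        self.compiled_entries.append((guard, compiled))
        return compiled(xs)


mean = GuardedFunction()
a = mean([1.0, 2.0, 3.0])   # compiles for length 3
b = mean([4.0, 5.0, 6.0])   # guard passes, entry is reused
c = mean([1.0, 3.0])        # guard fails, compiles again for length 2
```

In this toy, every new input length triggers another compilation because the length is baked into the compiled code. Treating a dimension as symbolic instead of constant is what avoids that kind of recompilation.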
#### Symbolic Shape
MagiCompiler specifically targets the Transformer architecture and supports custom `dynamic_arg_dims` (typically for `seq_len`).
**Example:**
```python
# Dimension 0 of x, y, and z is treated as dynamic.
@magi_compile(dynamic_arg_dims={"x": 0, "y": 0, "z": 0})
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1024, 1024, device="cuda")

    @torch.no_grad()
    def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        return self.linear(x + y - z + 1)
```
**Guard Mechanism and Elimination in Symbolic Shape Deduction:**
```log
I1204 16:31:35.745000 1859360 torch/_dynamo/symbolic_convert.py:3842] [0/0] Step 1: torchdynamo start tracing inner /usr/local/lib/python3.12/dist-packages/torch/_dynamo/external_utils.py:66
I1204 16:31:35.746000 1859360 torch/fx/experimental/symbolic_shapes.py:3775] [0/0] create_env
I1204 16:31:35.781000 1859360 torch/fx/experimental/symbolic_shapes.py:5120] [0/0] create_symbol s33 = 1024 for L['args'][0].size()[0] [2, int_oo] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_dynamo/variables/builder.py:3501 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s33" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1204 16:31:35.785000 1859360 torch/fx/experimental/symbolic_shapes.py:5120] [0/0] create_symbol s6 = 1024 for L['args'][1].size()[0] [2, int_oo] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_dynamo/variables/builder.py:3501 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s6" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1204 16:31:35.794000 1859360 torch/fx/experimental/symbolic_shapes.py:7213] [0/0] eval Eq(s33, s6) [guard added] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s33, s6)"
I1204 16:31:35.795000 1859360 torch/fx/experimental/symbolic_shapes.py:6792] [0/0] set_replacement s6 = s33 (solve) VR[2, int_oo]
I1204 16:31:35.800000 1859360 torch/fx/experimental/symbolic_shapes.py:5120] [0/0] create_symbol s21 = 1024 for L['args'][2].size()[0] [2, int_oo] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_dynamo/variables/builder.py:3501 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s21" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1204 16:31:35.806000 1859360 torch/fx/experimental/symbolic_shapes.py:7213] [0/0] eval Eq(s33, s21) [guard added] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s33, s21)"
I1204 16:31:35.807000 1859360 torch/fx/experimental/symbolic_shapes.py:6792] [0/0] set_replacement s33 = s21 (solve) VR[2, int_oo]
I1204 16:31:35.828000 1859360 torch/_dynamo/symbolic_convert.py:4059] [0/0] Step 1: torchdynamo done tracing inner (RETURN_VALUE)
I1204 16:31:35.837000 1859360 torch/fx/experimental/symbolic_shapes.py:6792] [0/0] set_replacement s6 = s21 (find) VR[2, int_oo]
```
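The log shows three symbols (`s33`, `s6`, `s21`) being created and then merged as `Eq(...)` guards are added. The bookkeeping behind that can be pictured as a union-find structure. The class below is a toy model written for this explanation, not PyTorch's `ShapeEnv`, and which symbol survives a merge is an arbitrary choice here that differs from the log.

```python
class ToyShapeEnv:
    """Toy stand-in for symbol bookkeeping; not PyTorch's ShapeEnv."""

    def __init__(self):
        self.replacements = {}   # symbol -> symbol it was replaced by
        self.guards = []         # equalities assumed during tracing

    def create_symbol(self, name):
        self.replacements[name] = name
        return name

    def find(self, sym):
        # Follow replacements to the representative, shortening the path as we go.
        while self.replacements[sym] != sym:
            self.replacements[sym] = self.replacements[self.replacements[sym]]
            sym = self.replacements[sym]
        return sym

    def guard_equal(self, a, b):
        """Record Eq(a, b) and merge the two symbols if they are still distinct."""
        ra, rb = self.find(a), self.find(b)
        if ra != rb:
            self.guards.append(f"Eq({ra}, {rb})")
            self.replacements[rb] = ra


env = ToyShapeEnv()
s33, s6, s21 = (env.create_symbol(n) for n in ("s33", "s6", "s21"))
env.guard_equal(s33, s6)    # x + y: both sizes are >= 2, so dim 0 must match
env.guard_equal(s33, s21)   # (x + y) - z: the same requirement for z
representatives = {env.find(s) for s in (s33, s6, s21)}
```

After both guards, all three dimensions are represented by a single symbol, so the traced graph only needs one free size variable for the dynamic dimension.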
### 1.3 Backend (Inductor, MagiBackend, etc.)

MagiCompiler hijacks the `torch.compile` logic through the following components:
* **`custom_partitioner_fn`:** Segments the forward and backward computational graphs and determines which intermediate results are transmitted to the backward pass.
* **`post_grad_custom_pre_pass`:** Performs pass optimizations at the whole-graph level (computational graph matching and rewriting).
* **`PartitionFunc`:** Implements custom subgraph partitioning logic, utilizing attention mechanisms as splitting points.

* **`post_grad_custom_post_pass`:** Executes pass optimizations at the subgraph level (computation/communication overlap).
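
To illustrate the idea of using attention as a splitting point, the sketch below partitions a simplified layer into subgraphs. It makes several assumptions that are not stated in this document: the graph is modeled as a flat list of operator names rather than a real FX graph, each attention operator is placed in its own partition, and `split_at_attention` is a hypothetical helper rather than MagiCompiler's `PartitionFunc`.

```python
from typing import Callable, List


def split_at_attention(ops: List[str], is_split_point: Callable[[str], bool]) -> List[List[str]]:
    """Split a linear op sequence into subgraphs, isolating each split-point op."""
    partitions: List[List[str]] = []
    current: List[str] = []
    for op in ops:
        if is_split_point(op):
            if current:                   # close the subgraph before the attention op
                partitions.append(current)
                current = []
            partitions.append([op])       # assumption: attention gets its own partition
        else:
            current.append(op)
    if current:
        partitions.append(current)
    return partitions


layer_ops = ["rms_norm", "qkv_proj", "flash_attn", "out_proj", "add", "rms_norm", "mlp", "add"]
parts = split_at_attention(layer_ops, lambda op: op.endswith("attn"))
```

Here `parts` contains three subgraphs: the operators before attention, the attention operator itself, and the operators after it. Subgraph-level passes would then run on each piece independently.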
---
## 2. Best Practices
### 2.1 Model Adaptation
MagiCompiler has certain limitations: it requires whole-graph capture and does not support implicit subgraph interruptions (graph breaks). Consequently, manual adaptation is required in specific scenarios:
**1. Computational Graph Dependencies or CPU/GPU Synchronization**
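```python
@magi_compile
class MeanModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        x = x.cos().sin()
        # Data-dependent branch: evaluating the condition in Python requires the
        # value of x.mean() on the host, so the graph cannot be captured whole.
        if x.mean() > 0.5:
            x = x - 1
        return x * y
```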
> **Note:** In typical Transformer models, certain pre/post-processing operations are unavoidable. Therefore, the recommended practice for `magi_compiler` is to perform **whole-graph capture at the `TransformerBlock` level**, as `TransformerBlock` computations constitute over 95% of the total workload.
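
One common way to remove a data-dependent branch like the one in `MeanModule` is to express it as a tensor operation, so the condition never has to be read back to the host. The sketch below uses `torch.where`, runs on CPU, and leaves out `magi_compile`; the function names `branchy` and `branchless` are illustrative, and whether this rewrite suits a given model is a judgment call.

```python
import torch


def branchy(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    h = x.cos().sin()
    if h.mean() > 0.5:          # Python `if` on a tensor: needs the value on the host
        h = h - 1
    return h * y


def branchless(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    h = x.cos().sin()
    # The condition stays a 0-dim boolean tensor and broadcasts against h.
    h = torch.where(h.mean() > 0.5, h - 1, h)
    return h * y


y = torch.ones(4)
hot = torch.zeros(4)                    # sin(cos(0)) is about 0.84, so the branch is taken
cold = torch.full((4,), torch.pi / 2)   # sin(cos(pi/2)) is about 0, so it is not
hot_matches = torch.equal(branchy(hot, y), branchless(hot, y))
cold_matches = torch.equal(branchy(cold, y), branchless(cold, y))
```

`torch.where` evaluates both `h - 1` and `h`, so the rewrite trades a small amount of extra compute for a graph with no Python-level control flow.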
**2. Custom Operators (e.g., FlashAttention, FlexFlashAttention, MoE kernels)**
* **Operator Registration:** An operator registration mechanism is provided. Commonly used operators such as FlashAttention (FA) and FlexFlashAttention (FFA) are already registered.
```python
import torch

# Operator registration
@torch.library.custom_op("athena::flash_attn_func", mutates_args=())
def flash_attn_func(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    ...

# Fake implementation: infers output metadata (shape, dtype, device) without running the kernel
@flash_attn_func.register_fake
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(query)

# Calling the registered operators
self_attn_out = torch.ops.athena.flash_attn_func(q, k, v)
out, _ = torch.ops.athena.flex_flash_attn_func(q, k, v, q_ranges=ffa_handler.q_ranges, k_ranges=ffa_handler.k_ranges)
```
* **Unit Testing:** Independent unit tests for operators should be provided in the production environment.
```python
import pytest
import torch


@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("seq_len", [1024, 2048, 4096])
@pytest.mark.parametrize("query_head", [48])
@pytest.mark.parametrize("kv_head", [4, 8])
@pytest.mark.parametrize("head_dim", [128, 256])
def test_fake_fa3(batch_size, seq_len, query_head, kv_head, head_dim):
    q = torch.randn((batch_size, seq_len, query_head, head_dim), device="cuda", dtype=torch.bfloat16)
    k = torch.randn((batch_size, seq_len, kv_head, head_dim), device="cuda", dtype=torch.bfloat16)
    v = torch.randn((batch_size, seq_len, kv_head, head_dim), device="cuda", dtype=torch.bfloat16)
    # opcheck validates the custom op registration, including its fake implementation.
    torch.library.opcheck(torch.ops.athena.flash_attn_func, (q, k, v))
```
### 2.2 Debugging Methods
Key questions for debugging:
* Is the bug originating from the compiler?
* Which specific component of the compiler is causing the bug?

```python
class CompileConfig(BaseModel):
    # Basic configs
    backend: str = Field("inductor", description="Compilation backend.")
    compile_mode: CompileMode = Field(CompileMode.MAGI_COMPILE, description="Compilation mode.")
    ...
    # Cudagraph configs
    cudagraph_mode: CudaGraphMode = Field(CudaGraphMode.NONE, description="Cudagraph mode.")
    ...
    # Pass configs
    pass_config: PassConfig = Field(PassConfig(), description="Pass configuration.")
    ...
```
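One way to approach the second question is to start from a configuration with every compiler feature off and switch features on one at a time, comparing against an eager reference after each step. The sketch below is a toy version of that loop: the feature names, `first_failing_feature`, and `fake_run` are all hypothetical and do not correspond to actual `CompileConfig` fields or MagiCompiler APIs.

```python
from typing import Callable, Dict, List, Optional


def first_failing_feature(
    features: List[str],
    run: Callable[[Dict[str, bool]], float],
    reference: float,
    tol: float = 1e-6,
) -> Optional[str]:
    """Enable features one at a time; return the first that makes the output diverge."""
    enabled = {name: False for name in features}
    for name in features:
        enabled[name] = True
        if abs(run(dict(enabled)) - reference) > tol:
            return name
    return None


def fake_run(flags: Dict[str, bool]) -> float:
    # Stand-in for "run the model with this configuration and reduce the output
    # to a scalar". A bug is planted in one feature for demonstration purposes.
    out = 1.0
    if flags["custom_passes"]:
        out += 0.5
    return out


culprit = first_failing_feature(["compile", "cudagraph", "custom_passes"], fake_run, reference=1.0)
```

If the output already diverges with the first feature enabled, the problem most likely lies in basic compilation; if it only diverges later, the feature enabled at that step is the first place to look.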
### 2.3 Profiling Results
For further details, please refer to the [**Wan2.2 Benchmark**](Wan2.2Benchmark.md).
---
## References
1. [PyTorch 2.0 Overview](https://docs.pytorch.org/assets/pytorch2-2.pdf)
2. [TorchDynamo: An Experiment in Dynamic Python Bytecode Transformation](https://dev-discuss.pytorch.org/t/torchdynamo-an-experiment-in-dynamic-python-bytecode-transformation/361)
3. [Depyf Walkthrough](https://depyf.readthedocs.io/en/latest/walk_through.html)
4. [Getting Started with PyTorch Compiler](https://docs.pytorch.org/docs/main/torch.compiler_get_started.html)