# Why MagiCompiler?
## 1. Compiler Overview
### 1.1 Background
We have long encountered several significant challenges in model optimization:
1. **Blurred Acceleration Boundaries:** It is unclear how far optimization must go before performance can be called "extreme."
2. **Complex Performance Tuning:** Optimization strategies are often tightly coupled with model architectures, necessitating extensive and repetitive manual intervention.
3. **Deficiency in Optimization Tools:** The infrastructure lacks sufficient mechanisms for computational graph-level optimizations, such as operator substitution and communication overlap.
MagiCompiler addresses these issues through the following approaches:
* **Addressing Challenge 1:** It adopts **whole-graph compilation**, which extends past the boundaries of a single `TransformerLayer` to maximize the scope of kernel fusion.
* **Addressing Challenge 2:** It integrates infrastructure optimizations directly into MagiCompiler, implementing features such as `AutoCudaGraph` and `AutoCheckpointing` (WIP).
* **Addressing Challenge 3:** It leverages the dynamic-to-static capabilities provided by **Dynamo**, capturing the FX graph (`fx.Graph`) IR from eager-mode code so that optimization passes can run at the IR level.
#### Illustrative Example
```python
import torch
import torch.nn as nn
from magi_compiler import magi_compile


@magi_compile()
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1024, 1024, device="cuda")

    @torch.no_grad()
    def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        # The pointwise add/sub/add chain feeding the linear layer is what gets fused.
        return self.linear(x + y - z + 1)


def magi_compiler_demo():
    model = TinyModel()
    x = torch.randn(1024, 1024, device="cuda")
    y = torch.randn(1024, 1024, device="cuda")
    z = torch.randn(1024, 1024, device="cuda")
    model(x, y, z)
```
**Optimized Code (Triton Kernel):**
```python
triton_poi_fused_add_sub_0 = async_compile.triton('triton_poi_fused_add_sub_0', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 1048576},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'B8F4209CBFC2377D6AF9CF3C65D610CA2B56C138A443862350DE1E56F5BF54C3', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_sub_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.load(in_ptr1 + (x0), xmask)
    tmp3 = tl.load(in_ptr2 + (x0), xmask)
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 - tmp3
    tmp5 = 1.0
    tmp6 = tmp4 + tmp5
    tl.store(out_ptr0 + (x0), tmp6, xmask)
''', device_str='cuda')
```
### 1.2 Frontend (Dynamo)
![Dynamo](./assets/why_magicompiler_1_dynamo.jpeg)
* **PyFrameObject (Dynamic Call Stack):**
  * Represents the execution context of a function call. Python creates a new `PyFrameObject` for each call.
* **PyCodeObject (Static Bytecode):**
  * The compiled form of Python code, which is static and read-only. A single `PyCodeObject` exists regardless of how many times the function is invoked.
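
The frame/code distinction can be observed from plain Python. The following is a minimal sketch using only the standard library; the function name `traced` is illustrative and not part of Dynamo or MagiCompiler.

```python
import inspect


def traced():
    # Return the frame object that Python created for this particular call.
    return inspect.currentframe()


frame_a = traced()
frame_b = traced()

# Every call gets its own frame object ...
distinct_frames = frame_a is not frame_b
# ... while both frames point at the same static code object.
shared_code = frame_a.f_code is frame_b.f_code is traced.__code__
```

Because the code object is shared across calls, it is a natural place to attach compiled artifacts, while the per-call frame supplies the concrete arguments that guards are checked against.
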
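```python
# Conceptual pseudocode: how Dynamo wraps `f` with guarded, cached compilations.
def f(x, mod):
    # Fast path: reuse a previously compiled entry whose guard still holds.
    for guard, transformed_code in f.compiled_entries:
        if guard(x, mod):
            return transformed_code(x, mod)
    try:
        # Cache miss: trace and compile, then store the result with its guard.
        guard, transformed_code = compile_and_optimize(x, mod)
        f.compiled_entries.append([guard, transformed_code])
        return transformed_code(x, mod)
    except FailToCompileError:
        # Fallback: run the original eager body.
        y = mod(x)
        z = torch.log(y)
        return z
```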
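
A runnable toy version of the pseudocode above may help show why guards matter for shapes. This is only a sketch in plain Python: `guarded_cache` and `specialize_on_length` are hypothetical names, and the "compiler" simply specializes on the input length the way a static-shape compilation would.

```python
def guarded_cache(compile_fn):
    """Wrap compile_fn(*args) -> (guard, compiled) in a guard-checked cache."""
    entries = []

    def call(*args):
        for guard, compiled in entries:
            if guard(*args):  # cache hit: the guard still holds
                return compiled(*args)
        guard, compiled = compile_fn(*args)  # cache miss: "compile" again
        entries.append((guard, compiled))
        return compiled(*args)

    call.entries = entries
    return call


def specialize_on_length(xs):
    # Hypothetical "compiler" that bakes len(xs) into the generated code.
    n = len(xs)
    guard = lambda ys: len(ys) == n
    compiled = lambda ys: sum(ys) / n
    return guard, compiled


mean = guarded_cache(specialize_on_length)
first = mean([1, 2, 3])
second = mean([4, 5, 6])   # same length: the cached entry is reused
third = mean([1, 2])       # new length: the guard fails and a recompile happens
num_compilations = len(mean.entries)
```

Each distinct length costs one compilation in this toy model, which is the cost that treating a dimension as symbolic is meant to avoid.
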
#### Symbolic Shape
MagiCompiler specifically targets the Transformer architecture and supports custom `dynamic_arg_dims`, which map an argument name to the dimension that should be treated as dynamic (typically `seq_len`).
**Example:**
```python
import torch
import torch.nn as nn
from magi_compiler import magi_compile


# Dimension 0 of x, y, and z is treated as dynamic.
@magi_compile(dynamic_arg_dims={"x": 0, "y": 0, "z": 0})
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1024, 1024, device="cuda")

    @torch.no_grad()
    def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        return self.linear(x + y - z + 1)
```
**Guard Mechanism and Elimination in Symbolic Shape Deduction:**
```log
I1204 16:31:35.745000 1859360 torch/_dynamo/symbolic_convert.py:3842] [0/0] Step 1: torchdynamo start tracing inner /usr/local/lib/python3.12/dist-packages/torch/_dynamo/external_utils.py:66
I1204 16:31:35.746000 1859360 torch/fx/experimental/symbolic_shapes.py:3775] [0/0] create_env
I1204 16:31:35.781000 1859360 torch/fx/experimental/symbolic_shapes.py:5120] [0/0] create_symbol s33 = 1024 for L['args'][0].size()[0] [2, int_oo] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_dynamo/variables/builder.py:3501 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s33" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1204 16:31:35.785000 1859360 torch/fx/experimental/symbolic_shapes.py:5120] [0/0] create_symbol s6 = 1024 for L['args'][1].size()[0] [2, int_oo] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_dynamo/variables/builder.py:3501 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s6" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1204 16:31:35.794000 1859360 torch/fx/experimental/symbolic_shapes.py:7213] [0/0] eval Eq(s33, s6) [guard added] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s33, s6)"
I1204 16:31:35.795000 1859360 torch/fx/experimental/symbolic_shapes.py:6792] [0/0] set_replacement s6 = s33 (solve) VR[2, int_oo]
I1204 16:31:35.800000 1859360 torch/fx/experimental/symbolic_shapes.py:5120] [0/0] create_symbol s21 = 1024 for L['args'][2].size()[0] [2, int_oo] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_dynamo/variables/builder.py:3501 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s21" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1204 16:31:35.806000 1859360 torch/fx/experimental/symbolic_shapes.py:7213] [0/0] eval Eq(s33, s21) [guard added] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s33, s21)"
I1204 16:31:35.807000 1859360 torch/fx/experimental/symbolic_shapes.py:6792] [0/0] set_replacement s33 = s21 (solve) VR[2, int_oo]
I1204 16:31:35.828000 1859360 torch/_dynamo/symbolic_convert.py:4059] [0/0] Step 1: torchdynamo done tracing inner (RETURN_VALUE)
I1204 16:31:35.837000 1859360 torch/fx/experimental/symbolic_shapes.py:6792] [0/0] set_replacement s6 = s21 (find) VR[2, int_oo]
```
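
The log shows three symbols (`s33`, `s6`, `s21`) being created and then collapsed into one through `Eq` guards and `set_replacement` steps. The sketch below replays those replacements with a simple union-find style table. It is a simplified illustration of the idea, not the actual `ShapeEnv` implementation in PyTorch, and the helper names are chosen for this example only.

```python
replacements = {}


def set_replacement(old, new):
    # Record that `old` should be rewritten to `new`.
    replacements[old] = new


def find(sym):
    # Follow replacements to the canonical symbol, compressing the path on the way.
    path = []
    while sym in replacements:
        path.append(sym)
        sym = replacements[sym]
    for visited in path:
        replacements[visited] = sym
    return sym


# Mirrors the two "(solve)" lines in the log.
set_replacement("s6", "s33")   # from the guard Eq(s33, s6)
set_replacement("s33", "s21")  # from the guard Eq(s33, s21)

# After resolution every argument shares one symbol, matching "s6 = s21 (find)".
canonical = {s: find(s) for s in ("s33", "s6", "s21")}
```

With all three sizes resolved to a single symbol, the broadcast checks in `x + y - z` no longer need separate runtime guards per argument pair.
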
### 1.3 Backend (Inductor, MagiBackend, etc.)
![Backend Architecture](./assets/why_magicompiler_2_arch.png)
MagiCompiler hijacks the `torch.compile` pipeline through the following components:
* **`custom_partitioner_fn`:** Segments the forward and backward computational graphs and determines which intermediate results are transmitted to the backward pass.
* **`post_grad_custom_pre_pass`:** Performs pass optimizations at the whole-graph level (computational graph matching and rewriting).
* **`PartitionFunc`:** Implements custom subgraph partitioning logic, using attention operators as split points.
![Partition](./assets/why_magicompiler_3_partition.png)
* **`post_grad_custom_post_pass`:** Executes pass optimizations at the subgraph level (computation/communication overlap).
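
To make the partitioning idea concrete, here is a minimal sketch that splits a linear sequence of operator names at attention ops. It rests on several assumptions: real FX graphs are DAGs rather than lists, the operator names are hypothetical, and placing each attention op in its own partition is an illustrative convention rather than a confirmed detail of `PartitionFunc`.

```python
def split_at(ops, is_split_point):
    """Split an op sequence into partitions; each split-point op gets its own partition."""
    partitions, current = [], []
    for op in ops:
        if is_split_point(op):
            if current:
                partitions.append(current)
                current = []
            partitions.append([op])
        else:
            current.append(op)
    if current:
        partitions.append(current)
    return partitions


# Hypothetical op sequence for one Transformer block.
ops = ["layernorm", "qkv_proj", "flash_attn", "out_proj", "add", "layernorm", "mlp", "add"]
parts = split_at(ops, lambda op: op == "flash_attn")
```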
---
## 2. Best Practices
### 2.1 Model Adaptation
MagiCompiler has certain limitations: it requires whole-graph capture and does not support implicit graph breaks. Consequently, manual adaptation is required in specific scenarios:
**1. Computational Graph Dependencies or CPU/GPU Synchronization**
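```python
import torch
from magi_compiler import magi_compile


@magi_compile
class MeanModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        x = x.cos().sin()
        # Data-dependent branch: evaluating the condition in Python needs the
        # tensor value on the host, which forces a GPU-to-CPU sync and cannot
        # be captured in a single graph.
        if x.mean() > 0.5:
            x = x - 1
        return x * y
```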
> **Note:** In typical Transformer models, certain pre/post-processing operations are unavoidable. Therefore, the recommended practice for `magi_compiler` is to perform **whole-graph capture at the `TransformerBlock` level**, as `TransformerBlock` computations constitute over 95% of the total workload.
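
For a small data-dependent branch like the one in `MeanModule`, one possible adaptation is to express the selection as a tensor operation so that no Python-level condition is evaluated. The sketch below uses `torch.where` on CPU tensors to show that the two formulations agree. Whether this is the preferred adaptation for MagiCompiler is an assumption, and the function names are illustrative.

```python
import torch


def branchy(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    x = x.cos().sin()
    if x.mean() > 0.5:  # Python branch on a tensor value
        x = x - 1
    return x * y


def branch_free(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    x = x.cos().sin()
    # Both candidates are computed and the selection happens inside the graph.
    x = torch.where(x.mean() > 0.5, x - 1, x)
    return x * y


y = torch.ones(4)
taken = torch.zeros(4)                # sin(cos(0)) is about 0.84, so the branch is taken
not_taken = torch.full((4,), 1.5708)  # sin(cos(pi/2)) is about 0, so it is not
same_when_taken = torch.equal(branchy(taken, y), branch_free(taken, y))
same_when_not_taken = torch.equal(branchy(not_taken, y), branch_free(not_taken, y))
```

The trade-off is that both sides of the selection are always computed, so this only suits cheap branches.
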
**2. Custom Operators (e.g., FlashAttention, FlexFlashAttention, MoE kernels)**
* **Operator Registration:** An operator registration mechanism is provided. Commonly used operators such as FlashAttention (FA) and FlexFlashAttention (FFA) are already registered.
```python
import torch


# Operator registration
@torch.library.custom_op("athena::flash_attn_func", mutates_args=())
def flash_attn_func(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    ...


# Fake implementation: infers the output shape and dtype without running the kernel
@flash_attn_func.register_fake
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(query)


# Call sites (q, k, v, and ffa_handler come from the surrounding model code)
self_attn_out = torch.ops.athena.flash_attn_func(q, k, v)
out, _ = torch.ops.athena.flex_flash_attn_func(q, k, v, q_ranges=ffa_handler.q_ranges, k_ranges=ffa_handler.k_ranges)
```
* **Unit Testing:** Operators used in production should come with independent unit tests.
```python
import pytest
import torch


@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("seq_len", [1024, 2048, 4096])
@pytest.mark.parametrize("query_head", [48])
@pytest.mark.parametrize("kv_head", [4, 8])
@pytest.mark.parametrize("head_dim", [128, 256])
def test_fake_fa3(batch_size, seq_len, query_head, kv_head, head_dim):
    q = torch.randn((batch_size, seq_len, query_head, head_dim), device="cuda", dtype=torch.bfloat16)
    k = torch.randn((batch_size, seq_len, kv_head, head_dim), device="cuda", dtype=torch.bfloat16)
    v = torch.randn((batch_size, seq_len, kv_head, head_dim), device="cuda", dtype=torch.bfloat16)
    # opcheck verifies that the custom op, including its fake implementation, is registered correctly.
    torch.library.opcheck(torch.ops.athena.flash_attn_func, (q, k, v))
```
### 2.2 Debugging Methods
Key questions for debugging:
* Is the bug originating from the compiler?
* Which specific component of the compiler is causing the bug?
![Debugging](./assets/why_magicompiler_4_debug.png)
```python
class CompileConfig(BaseModel):
    # Basic configs
    backend: str = Field("inductor", description="Compilation backend.")
    compile_mode: CompileMode = Field(CompileMode.MAGI_COMPILE, description="Compilation mode.")
    ...
    # Cudagraph configs
    cudagraph_mode: CudaGraphMode = Field(CudaGraphMode.NONE, description="Cudagraph mode.")
    ...
    # Pass configs
    pass_config: PassConfig = Field(PassConfig(), description="Pass configuration.")
    ...
```
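
One common way to answer the two debugging questions is to bisect: run the same input through progressively more of the compiler stack and stop at the first stage whose output diverges from an eager reference. The helper below is a hypothetical sketch of that loop and is not part of MagiCompiler. The stage names follow the standard `torch.compile` backends (`eager`, `aot_eager`, `inductor`); whether the `backend` field of `CompileConfig` accepts all of them is an assumption. The toy lambdas stand in for real model runs.

```python
def first_divergent_stage(stages, reference, tol=1e-3):
    """Return the name of the first stage whose output deviates from `reference`."""
    for name, run in stages:
        if abs(run() - reference) > tol:
            return name
    return None


# Stages ordered from least to most compiler involvement (toy outputs).
stages = [
    ("eager", lambda: 1.0),      # graph capture only
    ("aot_eager", lambda: 1.0),  # adds AOTAutograd tracing
    ("inductor", lambda: 1.25),  # adds code generation; pretend it is wrong here
]
culprit = first_divergent_stage(stages, reference=1.0)
```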
### 2.3 Profiling Results
For further details, please refer to the [**Wan2.2 Benchmark**](Wan2.2Benchmark.md).
---
## References
1. [PyTorch 2.0 Overview](https://docs.pytorch.org/assets/pytorch2-2.pdf)
2. [TorchDynamo: An Experiment in Dynamic Python Bytecode Transformation](https://dev-discuss.pytorch.org/t/torchdynamo-an-experiment-in-dynamic-python-bytecode-transformation/361)
3. [Depyf Walkthrough](https://depyf.readthedocs.io/en/latest/walk_through.html)
4. [Getting Started with PyTorch Compiler](https://docs.pytorch.org/docs/main/torch.compiler_get_started.html)