# Why MagiCompiler?

## 1. Compiler Overview

### 1.1 Background

We have long faced several significant challenges in model optimization:

1. **Blurred Acceleration Boundaries:** It is unclear how much optimization is needed to reach "extreme" performance.
2. **Complex Performance Tuning:** Optimization strategies are often tightly coupled to model architectures, which demands extensive, repetitive manual work.
3. **Deficient Optimization Tools:** The infrastructure lacks mechanisms for graph-level optimizations such as operator substitution and computation/communication overlap.

MagiCompiler addresses these issues as follows:

* **Addressing Challenge 1:** It adopts **whole-graph compilation**, going beyond the boundaries of individual `TransformerLayer`s to maximize the scope of kernel fusion.
* **Addressing Challenge 2:** It integrates infrastructure optimizations directly into the compiler, with features such as `AutoCudaGraph` and `AutoCheckpointing` (WIP).
* **Addressing Challenge 3:** It uses the dynamic-to-static capability of **Dynamo** to capture the `fx.Graph` IR from eager-mode code and run optimization passes at the IR level.
#### Illustrative Example

```python
import torch
from torch import nn

from magi_compiler import magi_compile


@magi_compile()
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1024, 1024, device="cuda")

    @torch.no_grad()
    def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        # The elementwise ops x + y - z + 1 are fused into a single Triton kernel.
        return self.linear(x + y - z + 1)


def magi_compiler_demo():
    model = TinyModel()
    x = torch.randn(1024, 1024, device="cuda")
    y = torch.randn(1024, 1024, device="cuda")
    z = torch.randn(1024, 1024, device="cuda")
    model(x, y, z)
```

**Optimized Code (Triton Kernel):** The three elementwise operations in `x + y - z + 1` are fused into one pointwise kernel that loads each input once and stores a single output; the `nn.Linear` matmul is not part of this kernel.

```python
triton_poi_fused_add_sub_0 = async_compile.triton('triton_poi_fused_add_sub_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 1048576},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'B8F4209CBFC2377D6AF9CF3C65D610CA2B56C138A443862350DE1E56F5BF54C3', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_sub_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.load(in_ptr1 + (x0), xmask)
    tmp3 = tl.load(in_ptr2 + (x0), xmask)
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 - tmp3
    tmp5 = 1.0
    tmp6 = tmp4 + tmp5
    tl.store(out_ptr0 + (x0), tmp6, xmask)
''', device_str='cuda')
```

### 1.2 Frontend (Dynamo)

![Dynamo](./assets/why_magicompiler_1_dynamo.jpeg)

* **PyFrameObject (Dynamic Call Stack):**
  * Represents the execution context of a function call. Python creates a new `PyFrameObject` for every call.
* **PyCodeObject (Static Bytecode):**
  * The compiled product of Python source code; it is static and read-only. A function has a single `PyCodeObject` no matter how many times it is invoked.

Dynamo hooks into CPython frame evaluation (PEP 523): before a frame runs, it looks up previously compiled versions of the frame's code. Conceptually, a function `f` handled by Dynamo behaves like the following pseudocode:

```python
def f(x, mod):
    # Reuse a compiled version whose guards (shapes, dtypes, ...) accept these inputs.
    for guard, transformed_code in f.compiled_entries:
        if guard(x, mod):
            return transformed_code(x, mod)
    try:
        # Trace the bytecode into an FX graph, compile it, and record new guards.
        guard, transformed_code = compile_and_optimize(x, mod)
        f.compiled_entries.append([guard, transformed_code])
        return transformed_code(x, mod)
    except FailToCompileError:
        # Fall back to the original code of f, run eagerly.
        y = mod(x)
        z = torch.log(y)
        return z
```

#### Symbolic Shape

MagiCompiler specifically targets the Transformer architecture and supports user-specified `dynamic_arg_dims` (typically the `seq_len` dimension), which are traced as symbolic sizes instead of being specialized to concrete values.
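The effect of `dynamic_arg_dims` on the guard-and-cache loop above can be illustrated with a minimal pure-Python sketch. It is not MagiCompiler's or Dynamo's implementation: `GuardedCache`, `make_shape_guard`, and `fake_compile` are hypothetical names, shapes are plain tuples standing in for tensor sizes, and the "compiled code" is an ordinary function. Static dimensions are guarded on their exact values, while dynamic dimensions only need to be at least 2 and keep the equality relations seen at compile time (similar to the `Eq(s33, s6)` guards in the log below). The sketch ignores Dynamo's automatic-dynamic heuristic, which would mark a changing dimension as dynamic after the first recompilation.

```python
from typing import Callable, Dict, List, Tuple

Shape = Tuple[int, ...]
Guard = Callable[[Dict[str, Shape]], bool]


def make_shape_guard(example: Dict[str, Shape], dynamic_arg_dims: Dict[str, int]) -> Guard:
    """Build a guard from the argument shapes seen at compile time (hypothetical helper)."""
    # Dims marked dynamic whose compile-time size is 0 or 1 are specialized (kept static).
    dyn_dims = {n: d for n, d in dynamic_arg_dims.items() if n in example and example[n][d] >= 2}
    # Compile-time sizes of the dynamic dims, used to preserve equality relations.
    dyn_sizes = {n: example[n][d] for n, d in dyn_dims.items()}

    def guard(args: Dict[str, Shape]) -> bool:
        if args.keys() != example.keys():
            return False
        for name, shape in args.items():
            ref = example[name]
            if len(shape) != len(ref):
                return False
            for dim, (size, ref_size) in enumerate(zip(shape, ref)):
                if dyn_dims.get(name) == dim:
                    # Dynamic dim: any size >= 2 is accepted.
                    if size < 2:
                        return False
                elif size != ref_size:
                    # Static dim: must match the compile-time value exactly.
                    return False
        # Dynamic dims that were equal at compile time must still be equal.
        for a, size_a in dyn_sizes.items():
            for b, size_b in dyn_sizes.items():
                if size_a == size_b and args[a][dyn_dims[a]] != args[b][dyn_dims[b]]:
                    return False
        return True

    return guard


class GuardedCache:
    """Mimics the compiled_entries loop: reuse the first entry whose guard passes, else compile."""

    def __init__(self, compile_fn: Callable, dynamic_arg_dims: Dict[str, int]):
        self.compile_fn = compile_fn
        self.dynamic_arg_dims = dynamic_arg_dims
        self.entries: List[Tuple[Guard, Callable]] = []
        self.num_compiles = 0

    def __call__(self, **shapes: Shape):
        for guard, compiled in self.entries:
            if guard(shapes):
                return compiled(shapes)
        guard = make_shape_guard(shapes, self.dynamic_arg_dims)
        compiled = self.compile_fn(shapes)
        self.entries.append((guard, compiled))
        self.num_compiles += 1
        return compiled(shapes)


def fake_compile(example_shapes: Dict[str, Shape]) -> Callable:
    # Stand-in for tracing + Inductor: the "kernel" just returns x's shape.
    return lambda shapes: shapes["x"]


static = GuardedCache(fake_compile, dynamic_arg_dims={})
dynamic = GuardedCache(fake_compile, dynamic_arg_dims={"x": 0, "y": 0, "z": 0})
for seq_len in (1024, 2048, 4096):
    args = {name: (seq_len, 1024) for name in ("x", "y", "z")}
    static(**args)
    dynamic(**args)
print(static.num_compiles, dynamic.num_compiles)  # 3 1
```

With every size static, each new `seq_len` misses all cached guards and triggers a recompilation; marking dimension 0 as dynamic lets one compiled entry serve all three sequence lengths.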
**Example:**

```python
import torch
from torch import nn

from magi_compiler import magi_compile


# Dimension 0 (seq_len) of x, y, and z is traced symbolically.
@magi_compile(dynamic_arg_dims={"x": 0, "y": 0, "z": 0})
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1024, 1024, device="cuda")

    @torch.no_grad()
    def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        return self.linear(x + y - z + 1)
```

**Guard Mechanism and Elimination in Symbolic Shape Deduction:** Dynamo creates one symbol per dynamic dimension (`s33`, `s6`, `s21`), each with the value range `[2, int_oo]` because sizes 0 and 1 are specialized. Since `x + y - z` requires matching sizes, it adds equality guards such as `Eq(s33, s6)` and then replaces the redundant symbols, so all three dimensions collapse into the single symbol `s21`:

```log
I1204 16:31:35.745000 1859360 torch/_dynamo/symbolic_convert.py:3842] [0/0] Step 1: torchdynamo start tracing inner /usr/local/lib/python3.12/dist-packages/torch/_dynamo/external_utils.py:66
I1204 16:31:35.746000 1859360 torch/fx/experimental/symbolic_shapes.py:3775] [0/0] create_env
I1204 16:31:35.781000 1859360 torch/fx/experimental/symbolic_shapes.py:5120] [0/0] create_symbol s33 = 1024 for L['args'][0].size()[0] [2, int_oo] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_dynamo/variables/builder.py:3501 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s33" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1204 16:31:35.785000 1859360 torch/fx/experimental/symbolic_shapes.py:5120] [0/0] create_symbol s6 = 1024 for L['args'][1].size()[0] [2, int_oo] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_dynamo/variables/builder.py:3501 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s6" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1204 16:31:35.794000 1859360 torch/fx/experimental/symbolic_shapes.py:7213] [0/0] eval Eq(s33, s6) [guard added] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s33, s6)"
I1204 16:31:35.795000 1859360 torch/fx/experimental/symbolic_shapes.py:6792] [0/0] set_replacement s6 = s33 (solve) VR[2, int_oo]
I1204 16:31:35.800000 1859360 torch/fx/experimental/symbolic_shapes.py:5120] [0/0] create_symbol s21 = 1024 for L['args'][2].size()[0] [2, int_oo] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_dynamo/variables/builder.py:3501 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s21" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1204 16:31:35.806000 1859360 torch/fx/experimental/symbolic_shapes.py:7213] [0/0] eval Eq(s33, s21) [guard added] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s33, s21)"
I1204 16:31:35.807000 1859360 torch/fx/experimental/symbolic_shapes.py:6792] [0/0] set_replacement s33 = s21 (solve) VR[2, int_oo]
I1204 16:31:35.828000 1859360 torch/_dynamo/symbolic_convert.py:4059] [0/0] Step 1: torchdynamo done tracing inner (RETURN_VALUE)
I1204 16:31:35.837000 1859360 torch/fx/experimental/symbolic_shapes.py:6792] [0/0] set_replacement s6 = s21 (find) VR[2, int_oo]
```

### 1.3 Backend (Inductor, MagiBackend, etc.)

![Backend Architecture](./assets/why_magicompiler_2_arch.png)

MagiCompiler hooks into the `torch.compile` pipeline through the following components:

* **`custom_partitioner_fn`:** Splits the joint graph into forward and backward graphs and decides which intermediate results are saved for the backward pass.
* **`post_grad_custom_pre_pass`:** Runs optimization passes on the whole graph (graph pattern matching and rewriting).
* **`PartitionFunc`:** Implements custom subgraph partitioning, using attention operators as split points.
![Partition](./assets/why_magicompiler_3_partition.png)

* **`post_grad_custom_post_pass`:** Runs optimization passes on each subgraph (e.g., computation/communication overlap).

---

## 2. Best Practices

### 2.1 Model Adaptation

MagiCompiler has certain limitations: it requires whole-graph capture and does not support implicit graph breaks. Models therefore need manual adaptation in the following scenarios:

**1. Data-Dependent Control Flow or CPU/GPU Synchronization**

A Python branch on a tensor value, such as `x.mean() > 0.5` below, forces a GPU-to-CPU synchronization and cannot be captured in a single graph:

```python
import torch

from magi_compiler import magi_compile


@magi_compile
class MeanModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        x = x.cos().sin()
        # Data-dependent branch: the `if` needs the value of x.mean() on the CPU.
        if x.mean() > 0.5:
            x = x - 1
        return x * y
```

> **Note:** In typical Transformer models, some pre- and post-processing operations inevitably contain such patterns. The recommended practice with `magi_compiler` is therefore to perform **whole-graph capture at the `TransformerBlock` level**, since `TransformerBlock` computation accounts for over 95% of the total workload.

**2. Custom Operators (e.g., FlashAttention, FlexFlashAttention, MoE kernels)**

* **Operator Registration:** Custom kernels must be registered as PyTorch custom operators, together with a fake implementation that infers output metadata during tracing. Commonly used operators such as FlashAttention (FA) and FlexFlashAttention (FFA) are already registered.

```python
import torch


# Operator registration
@torch.library.custom_op("athena::flash_attn_func", mutates_args=())
def flash_attn_func(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    ...


# Fake implementation: infers the output shape/dtype without running the kernel
@flash_attn_func.register_fake
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(query)


# Calling the registered operators (q, k, v and ffa_handler come from the model code)
self_attn_out = torch.ops.athena.flash_attn_func(q, k, v)
out, _ = torch.ops.athena.flex_flash_attn_func(q, k, v, q_ranges=ffa_handler.q_ranges, k_ranges=ffa_handler.k_ranges)
```

* **Unit Testing:** Production code should include an independent unit test for each registered operator. `torch.library.opcheck` checks that the registration, including the fake implementation, is consistent with the real kernel.

```python
import pytest
import torch


@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("seq_len", [1024, 2048, 4096])
@pytest.mark.parametrize("query_head", [48])
@pytest.mark.parametrize("kv_head", [4, 8])
@pytest.mark.parametrize("head_dim", [128, 256])
def test_fake_fa3(batch_size, seq_len, query_head, kv_head, head_dim):
    q = torch.randn((batch_size, seq_len, query_head, head_dim), device="cuda", dtype=torch.bfloat16)
    k = torch.randn((batch_size, seq_len, kv_head, head_dim), device="cuda", dtype=torch.bfloat16)
    v = torch.randn((batch_size, seq_len, kv_head, head_dim), device="cuda", dtype=torch.bfloat16)
    torch.library.opcheck(torch.ops.athena.flash_attn_func, (q, k, v))
```

### 2.2 Debugging Methods

When debugging, first answer two questions:

* Does the bug originate in the compiler at all?
* If so, which compiler component causes it?

![Debugging](./assets/why_magicompiler_4_debug.png)

`CompileConfig` controls the backend, compile mode, CUDA Graph mode, and passes, so these options can be toggled one at a time to narrow down the faulty component:

```python
from pydantic import BaseModel, Field

# CompileMode, CudaGraphMode, and PassConfig are MagiCompiler types (definitions omitted).


class CompileConfig(BaseModel):
    # Basic configs
    backend: str = Field("inductor", description="Compilation backend.")
    compile_mode: CompileMode = Field(CompileMode.MAGI_COMPILE, description="Compilation mode.")
    ...

    # Cudagraph configs
    cudagraph_mode: CudaGraphMode = Field(CudaGraphMode.NONE, description="Cudagraph mode.")
    ...

    # Pass configs
    pass_config: PassConfig = Field(PassConfig(), description="Pass configuration.")
    ...
```

### 2.3 Profiling Results

For details, see the [**Wan2.2 Benchmark**](Wan2.2Benchmark.md).

---

## References

1. [PyTorch 2.0 Overview](https://docs.pytorch.org/assets/pytorch2-2.pdf)
2. [TorchDynamo: An Experiment in Dynamic Python Bytecode Transformation](https://dev-discuss.pytorch.org/t/torchdynamo-an-experiment-in-dynamic-python-bytecode-transformation/361)
3. [Depyf Walkthrough](https://depyf.readthedocs.io/en/latest/walk_through.html)
4. [Getting Started with PyTorch Compiler](https://docs.pytorch.org/docs/main/torch.compiler_get_started.html)