# Why MagiCompiler?
## 1. Compiler Overview
### 1.1 Background
We have long encountered several significant challenges in model optimization:
1. **Blurred Acceleration Boundaries:** It is unclear how far optimization must go to achieve "extreme" performance.
2. **Complex Performance Tuning:** Optimization strategies are tightly coupled with model architectures, requiring extensive and repetitive manual intervention.
3. **Deficient Optimization Tooling:** The infrastructure lacks mechanisms for computational graph-level optimizations, such as operator substitution and communication overlap.
MagiCompiler addresses these issues through the following approaches:
* **Addressing Challenge 1:** It adopts **whole-graph compilation**, compiling across `TransformerLayer` boundaries to maximize the scope of kernel fusion.
* **Addressing Challenge 2:** It integrates infrastructure optimizations directly into MagiCompiler, implementing features such as `AutoCudaGraph` and `AutoCheckpointing` (WIP).
* **Addressing Challenge 3:** It leverages the dynamic-to-static capabilities provided by **Dynamo**, capturing `fx.graph` IR in eager mode to perform pass optimizations at the IR level.
#### Illustrative Example
```python
import torch
from torch import nn

from magi_compiler import magi_compile

@magi_compile()
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1024, 1024, device="cuda")

    @torch.no_grad()
    def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        return self.linear(x + y - z + 1)

def magi_compiler_demo():
    model = TinyModel()
    x = torch.randn(1024, 1024, device="cuda")
    y = torch.randn(1024, 1024, device="cuda")
    z = torch.randn(1024, 1024, device="cuda")
    model(x, y, z)
```
**Optimized Code (Triton Kernel):** Inductor fuses the elementwise `x + y - z + 1` into a single pointwise Triton kernel (`triton_poi_fused_add_sub_0`); the `nn.Linear` matmul is dispatched separately.
```python
triton_poi_fused_add_sub_0 = async_compile.triton('triton_poi_fused_add_sub_0', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 1048576},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'B8F4209CBFC2377D6AF9CF3C65D610CA2B56C138A443862350DE1E56F5BF54C3', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_sub_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.load(in_ptr1 + (x0), xmask)
    tmp3 = tl.load(in_ptr2 + (x0), xmask)
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 - tmp3
    tmp5 = 1.0
    tmp6 = tmp4 + tmp5
    tl.store(out_ptr0 + (x0), tmp6, xmask)
''', device_str='cuda')
```
### 1.2 Frontend (Dynamo)

* **PyFrameObject (Dynamic Call Stack):**
    * Represents the execution context of a function call. Python creates a new `PyFrameObject` for each function call.
* **PyCodeObject (Static Bytecode):**
    * The compiled product of Python source code, which is static and read-only. A single `PyCodeObject` exists regardless of how many times the function is invoked.
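As a minimal illustration of the distinction (plain CPython standard library, no Dynamo involved):
```python
import dis

def g(x):
    return x + 1

# One static PyCodeObject per function, shared by every call...
print(g.__code__)
# ...and dis shows the bytecode that Dynamo rewrites (via the
# PEP 523 frame-evaluation hook) before a frame is executed.
dis.dis(g)
```
Conceptually, the transformed function behaves like the guard-dispatch loop below (a simplified sketch: `compile_and_optimize` and `FailToCompileError` stand in for Dynamo internals):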
```python
def f(x, mod):
    # Fast path: reuse a cached compilation whose guards still hold
    # for the current inputs.
    for guard, transformed_code in f.compiled_entries:
        if guard(x, mod):
            return transformed_code(x, mod)
    try:
        # Guard miss: trace and compile a new entry for these inputs.
        guard, transformed_code = compile_and_optimize(x, mod)
        f.compiled_entries.append([guard, transformed_code])
        return transformed_code(x, mod)
    except FailToCompileError:
        # Fallback: run the original eager code.
        y = mod(x)
        z = torch.log(y)
        return z
```
#### Symbolic Shape
MagiCompiler specifically targets the Transformer architecture and supports custom `dynamic_arg_dims` (typically for `seq_len`).
**Example:**
```python
@magi_compile(dynamic_arg_dims={"x": 0, "y": 0, "z": 0})
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1024, 1024, device="cuda")

    @torch.no_grad()
    def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        return self.linear(x + y - z + 1)
```
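For orientation, this corresponds to stock `torch.compile`'s dynamic-dimension marking; below is a minimal sketch using the public `torch._dynamo.mark_dynamic` API (exactly how `magi_compile` wires up `dynamic_arg_dims` internally is an assumption here):
```python
import torch

x = torch.randn(1024, 1024, device="cuda")
# Mark dim 0 as dynamic so Dynamo traces it as a symbolic size
# (e.g. s0) instead of specializing the graph to 1024.
torch._dynamo.mark_dynamic(x, 0)
```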
**Guard Mechanism and Elimination in Symbolic Shape Deduction:** In the log below, Dynamo creates one symbol per dynamic input dimension (`s33`, `s6`, `s21`); the elementwise `x + y - z` requires these sizes to match, so equality guards (`Eq(s33, s6)`, `Eq(s33, s21)`) are added and the solver unifies all three into a single symbol via `set_replacement`.
```log
I1204 16:31:35.745000 1859360 torch/_dynamo/symbolic_convert.py:3842] [0/0] Step 1: torchdynamo start tracing inner /usr/local/lib/python3.12/dist-packages/torch/_dynamo/external_utils.py:66
I1204 16:31:35.746000 1859360 torch/fx/experimental/symbolic_shapes.py:3775] [0/0] create_env
I1204 16:31:35.781000 1859360 torch/fx/experimental/symbolic_shapes.py:5120] [0/0] create_symbol s33 = 1024 for L['args'][0].size()[0] [2, int_oo] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_dynamo/variables/builder.py:3501 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s33" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1204 16:31:35.785000 1859360 torch/fx/experimental/symbolic_shapes.py:5120] [0/0] create_symbol s6 = 1024 for L['args'][1].size()[0] [2, int_oo] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_dynamo/variables/builder.py:3501 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s6" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1204 16:31:35.794000 1859360 torch/fx/experimental/symbolic_shapes.py:7213] [0/0] eval Eq(s33, s6) [guard added] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s33, s6)"
I1204 16:31:35.795000 1859360 torch/fx/experimental/symbolic_shapes.py:6792] [0/0] set_replacement s6 = s33 (solve) VR[2, int_oo]
I1204 16:31:35.800000 1859360 torch/fx/experimental/symbolic_shapes.py:5120] [0/0] create_symbol s21 = 1024 for L['args'][2].size()[0] [2, int_oo] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_dynamo/variables/builder.py:3501 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s21" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1204 16:31:35.806000 1859360 torch/fx/experimental/symbolic_shapes.py:7213] [0/0] eval Eq(s33, s21) [guard added] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s33, s21)"
I1204 16:31:35.807000 1859360 torch/fx/experimental/symbolic_shapes.py:6792] [0/0] set_replacement s33 = s21 (solve) VR[2, int_oo]
I1204 16:31:35.828000 1859360 torch/_dynamo/symbolic_convert.py:4059] [0/0] Step 1: torchdynamo done tracing inner (RETURN_VALUE)
I1204 16:31:35.837000 1859360 torch/fx/experimental/symbolic_shapes.py:6792] [0/0] set_replacement s6 = s21 (find) VR[2, int_oo]
```
### 1.3 Backend (Inductor, MagiBackend, etc.)

MagiCompiler hooks into the `torch.compile` pipeline through the following components (a sketch of the corresponding stock-Inductor hooks follows the list):
* **`custom_partitioner_fn`:** Partitions the forward and backward computational graphs and decides which intermediate results are passed to the backward pass.
* **`post_grad_custom_pre_pass`:** Performs pass optimizations at the whole-graph level (computational graph matching and rewriting).
* **`PartitionFunc`:** Implements custom subgraph partitioning logic, using attention operations as splitting points.

* **`post_grad_custom_post_pass`:** Executes pass optimizations at the subgraph level (computation/communication overlap).
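For orientation, the two pass hooks exist in stock Inductor under `torch._inductor.config`; the sketch below shows how a custom pass would be registered there (`my_graph_pass` is a hypothetical pass, not MagiCompiler's implementation):
```python
import torch
import torch._inductor.config as inductor_config

def my_graph_pass(graph: torch.fx.Graph) -> None:
    # Hypothetical pass: walk the post-grad FX graph and rewrite
    # matched node patterns in place (the hook mutates the graph,
    # it does not return a new one).
    for node in graph.nodes:
        ...

# Inductor invokes these callables on the post-grad graph, before
# and after its own fusion passes respectively.
inductor_config.post_grad_custom_pre_pass = my_graph_pass
inductor_config.post_grad_custom_post_pass = my_graph_pass
```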
---
## 2. Best Practices
### 2.1 Model Adaptation
MagiCompiler has certain limitations, such as mandatory whole-graph capture and lack of support for implicit graph breaks. Consequently, manual adaptation is required in specific scenarios:
**1. Computational Graph Dependencies or CPU/GPU Synchronization**
```python
@magi_compile
class MeanModule(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor):
        x = x.cos().sin()
        # Data-dependent Python control flow: evaluating the condition
        # forces a GPU-to-CPU sync and breaks whole-graph capture.
        if x.mean() > 0.5:
            x = x - 1
        return x * y
```
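One common workaround (a general PyTorch pattern, not MagiCompiler-specific) is to keep the branch on-device with `torch.where`, which removes both the CPU/GPU sync and the graph break:
```python
def forward(self, x: torch.Tensor, y: torch.Tensor):
    x = x.cos().sin()
    # The 0-dim boolean condition stays on the GPU and broadcasts,
    # so no host-side sync or graph break is needed.
    x = torch.where(x.mean() > 0.5, x - 1, x)
    return x * y
```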
> **Note:** In typical Transformer models, certain pre/post-processing operations are unavoidable. The recommended practice for `magi_compiler` is therefore to perform **whole-graph capture at the `TransformerBlock` level**, since `TransformerBlock` computations constitute over 95% of the total workload.
**2. Custom Operators (e.g., FlashAttention, FlexFlashAttention, MoE kernels)**
* **Operator Registration:** Registration logic for custom operators is provided; commonly used operators such as FlashAttention (FA) and FlexFlashAttention (FFA) are already registered.
```python
# Operator registration: wrap the raw kernel as a custom op so that
# Dynamo/Inductor treat it as an opaque node instead of tracing into it.
@torch.library.custom_op("athena::flash_attn_func", mutates_args=())
def flash_attn_func(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    ...

# Fake (meta) implementation: deduces the output shape/dtype so the op
# can be traced with FakeTensors during compilation.
@flash_attn_func.register_fake
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(query)

# Call sites go through torch.ops so the registered op is used.
self_attn_out = torch.ops.athena.flash_attn_func(q, k, v)
out, _ = torch.ops.athena.flex_flash_attn_func(
    q, k, v, q_ranges=ffa_handler.q_ranges, k_ranges=ffa_handler.k_ranges
)
```
* **Unit Testing:** Independent unit tests for operators should be provided in the production environment.
```python
import pytest
import torch

@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("seq_len", [1024, 2048, 4096])
@pytest.mark.parametrize("query_head", [48])
@pytest.mark.parametrize("kv_head", [4, 8])
@pytest.mark.parametrize("head_dim", [128, 256])
def test_fake_fa3(batch_size, seq_len, query_head, kv_head, head_dim):
    q = torch.randn((batch_size, seq_len, query_head, head_dim), device="cuda", dtype=torch.bfloat16)
    k = torch.randn((batch_size, seq_len, kv_head, head_dim), device="cuda", dtype=torch.bfloat16)
    v = torch.randn((batch_size, seq_len, kv_head, head_dim), device="cuda", dtype=torch.bfloat16)
    # opcheck validates the op's schema, fake/meta, and autograd registrations.
    torch.library.opcheck(torch.ops.athena.flash_attn_func, (q, k, v))
```
### 2.2 Debugging Methods
Key questions for debugging:
* Is the bug originating from the compiler?
* Which specific component of the compiler is causing the bug?

```python
class CompileConfig(BaseModel):
    # Basic configs
    backend: str = Field("inductor", description="Compilation backend.")
    compile_mode: CompileMode = Field(CompileMode.MAGI_COMPILE, description="Compilation mode.")
    ...
    # Cudagraph configs
    cudagraph_mode: CudaGraphMode = Field(CudaGraphMode.NONE, description="Cudagraph mode.")
    ...
    # Pass configs
    pass_config: PassConfig = Field(PassConfig(), description="Pass configuration.")
    ...
```
### 2.3 Profiling Results
For further details, please refer to the [**Wan2.2 Benchmark**](Wan2.2Benchmark.md).
---
## References
1. [PyTorch 2.0 Overview](https://docs.pytorch.org/assets/pytorch2-2.pdf)
2. [TorchDynamo: An Experiment in Dynamic Python Bytecode Transformation](https://dev-discuss.pytorch.org/t/torchdynamo-an-experiment-in-dynamic-python-bytecode-transformation/361)
3. [Depyf Walkthrough](https://depyf.readthedocs.io/en/latest/walk_through.html)
4. [Getting Started with PyTorch Compiler](https://docs.pytorch.org/docs/main/torch.compiler_get_started.html)