jiadisu
Switch back to Docker SDK with local pkgs
e6066e8


Why MagiCompiler?

1. Compiler Overview

1.1 Background

We have long faced several significant challenges in model optimization:

  1. Blurred Acceleration Boundaries: It is unclear how much optimization is required to reach "extreme" performance.
  2. Complex Performance Tuning: Optimization strategies are often tightly coupled with model architectures, necessitating extensive and repetitive manual intervention.
  3. Deficiency in Optimization Tools: The infrastructure lacks sufficient mechanisms for computational graph-level optimizations, such as operator substitution and communication overlap.

MagiCompiler addresses these issues through the following approaches:

  • Addressing Challenge 1: It adopts whole-graph compilation, compiling across TransformerLayer boundaries to maximize the scope of kernel fusion.
  • Addressing Challenge 2: It builds infrastructure optimizations directly into the compiler, such as AutoCudaGraph and AutoCheckpointing (WIP).
  • Addressing Challenge 3: It uses Dynamo's dynamic-to-static capture to record an fx.graph IR from eager-mode code, and then runs optimization passes at the IR level.

Illustrative Example

import torch
from torch import nn

from magi_compiler import magi_compile

@magi_compile()
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1024, 1024, device="cuda")

    @torch.no_grad()
    def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        return self.linear(x + y - z + 1)


def magi_compiler_demo():
    model = TinyModel()
    x = torch.randn(1024, 1024, device="cuda")
    y = torch.randn(1024, 1024, device="cuda")
    z = torch.randn(1024, 1024, device="cuda")
    model(x, y, z)

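Generated code (Triton kernel): Inductor fuses the elementwise expression x + y - z + 1 into a single pointwise kernel, triton_poi_fused_add_sub_0: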

triton_poi_fused_add_sub_0 = async_compile.triton('triton_poi_fused_add_sub_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 1048576},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'B8F4209CBFC2377D6AF9CF3C65D610CA2B56C138A443862350DE1E56F5BF54C3', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_sub_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.load(in_ptr1 + (x0), xmask)
    tmp3 = tl.load(in_ptr2 + (x0), xmask)
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 - tmp3
    tmp5 = 1.0
    tmp6 = tmp4 + tmp5
    tl.store(out_ptr0 + (x0), tmp6, xmask)
''', device_str='cuda')
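The kernel's indexing can be made concrete with a plain-Python model of what it computes. This is our own sketch, not generated code. Triton runs the lanes of each XBLOCK tile in parallel on the GPU, and Inductor's heuristics pick XBLOCK at launch. Here we loop sequentially and pass `xblock` explicitly. Each program handles one tile of XBLOCK contiguous elements, and the mask protects the final partial tile.

```python
def fused_add_sub(in0, in1, in2, xblock=4):
    """Sequential model of triton_poi_fused_add_sub_0: out = in0 + in1 - in2 + 1."""
    xnumel = len(in0)
    out = [0.0] * xnumel
    # Grid1D: one program per XBLOCK-sized tile, i.e. ceil(xnumel / XBLOCK) programs
    num_programs = (xnumel + xblock - 1) // xblock
    for pid in range(num_programs):
        xoffset = pid * xblock                           # tl.program_id(0) * XBLOCK
        for xindex in range(xoffset, xoffset + xblock):  # xoffset + tl.arange(0, XBLOCK)
            if xindex < xnumel:                          # xmask: skip out-of-bounds lanes
                tmp2 = in0[xindex] + in1[xindex]         # x + y
                tmp4 = tmp2 - in2[xindex]                # (x + y) - z
                out[xindex] = tmp4 + 1.0                 # ... + 1, single store
    return out


print(fused_add_sub([1.0, 2.0, 3.0, 4.0, 5.0], [1.0] * 5, [0.5] * 5, xblock=2))
# [2.5, 3.5, 4.5, 5.5, 6.5]
```

This also shows why fusion pays off. Each input is read once and a single output is written. Unfused eager execution would instead launch separate add, sub, and add kernels and materialize the intermediate tensors between them.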

1.2 Frontend (Dynamo)

Dynamo

  • PyFrameObject (Dynamic Call Stack):
    • Represents the execution context of a single function call (local variables, value stack, current instruction). Python creates a new PyFrameObject for every call.
  • PyCodeObject (Static Bytecode):
    • The compiled bytecode of a function, which is static and read-only. A single PyCodeObject exists no matter how many times the function is invoked.
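Plain Python is enough to see the difference. This is a stdlib illustration and does not touch Dynamo's internals. Each call gets a fresh frame, while every call shares the function's one code object:

```python
import inspect


def probe():
    """Return the frame of this call together with the code object it executes."""
    frame = inspect.currentframe()  # CPython-specific: the PyFrameObject of this call
    return frame, frame.f_code


frame_a, code_a = probe()
frame_b, code_b = probe()

print(frame_a is frame_b)        # False: each call gets its own frame
print(code_a is code_b)          # True: both calls execute the same code object
print(code_a is probe.__code__)  # True: that code object is the function's __code__
```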
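Dynamo hooks into CPython's frame evaluation API (PEP 523). Before a frame executes, Dynamo inspects the frame's code object and can substitute transformed bytecode. Conceptually, a function optimized by Dynamo behaves like the following pseudocode, where each cached entry pairs a guard (a check on the inputs) with transformed code:

def f(x, mod):
    # Reuse a compiled entry whose guard accepts the current inputs
    for guard, transformed_code in f.compiled_entries:
        if guard(x, mod):
            return transformed_code(x, mod)
    try:
        # Cache miss: capture and compile a new specialized version
        guard, transformed_code = compile_and_optimize(x, mod)
        f.compiled_entries.append([guard, transformed_code])
        return transformed_code(x, mod)
    except FailToCompileError:
        # Fall back to running the original (eager) code
        y = mod(x)
        z = torch.log(y)
        return z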
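The guard-and-cache loop can also be written as a small runnable toy. The names `with_guarded_cache` and `specialize_on_length` are our own and hypothetical, and the toy "compiler" returns the original function unchanged. Real Dynamo guards check much more than this, for example tensor dtypes, devices, shapes, and global state.

```python
from typing import Any, Callable, List, Tuple

Guard = Callable[..., bool]


def with_guarded_cache(compile_fn: Callable[..., Tuple[Guard, Callable[..., Any]]]):
    """Toy model of a per-function cache of (guard, transformed_code) entries."""

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        entries: List[Tuple[Guard, Callable[..., Any]]] = []

        def wrapper(*args):
            for guard, transformed in entries:   # cache hit: first guard that passes
                if guard(*args):
                    return transformed(*args)
            try:                                 # cache miss: "compile" a new entry
                guard, transformed = compile_fn(fn, *args)
            except NotImplementedError:          # stand-in for FailToCompileError
                return fn(*args)                 # run the original code unchanged
            entries.append((guard, transformed))
            return transformed(*args)

        wrapper.entries = entries  # exposed for inspection
        return wrapper

    return decorator


def specialize_on_length(fn, xs):
    """Hypothetical 'compiler' that specializes on len(xs) and guards on it."""
    n = len(xs)
    guard = lambda ys: isinstance(ys, list) and len(ys) == n
    return guard, fn  # a real compiler would return optimized code here


@with_guarded_cache(specialize_on_length)
def total(xs):
    return sum(xs)


print(total([1, 2, 3]), total([4, 5, 6]), len(total.entries))  # 6 15 1
print(total([1, 2]), len(total.entries))                       # 3 2
```

Two lists of the same length share one entry. A new length fails every existing guard and triggers a "recompile", which is the same behavior that symbolic shapes are designed to avoid.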

Symbolic Shape

MagiCompiler specifically targets Transformer architectures and lets you declare which argument dimensions are dynamic through dynamic_arg_dims (typically the seq_len dimension).

Example:

import torch
from torch import nn

from magi_compiler import magi_compile

# Dimension 0 (seq_len) of x, y, and z is treated as dynamic
@magi_compile(dynamic_arg_dims={"x": 0, "y": 0, "z": 0})
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1024, 1024, device="cuda")

    @torch.no_grad()
    def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        return self.linear(x + y - z + 1)

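Guards and symbol elimination during symbolic shape deduction. In the log below, Dynamo creates one symbol per dynamic dimension (s33, s6, s21). Each symbol has the value range [2, int_oo] because sizes 0 and 1 are specialized rather than treated as dynamic. Since x + y - z requires matching sizes, Dynamo adds the guards Eq(s33, s6) and Eq(s33, s21), and replacements then collapse all three symbols into the single symbol s21: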

I1204 16:31:35.745000 1859360 torch/_dynamo/symbolic_convert.py:3842] [0/0] Step 1: torchdynamo start tracing inner /usr/local/lib/python3.12/dist-packages/torch/_dynamo/external_utils.py:66
I1204 16:31:35.746000 1859360 torch/fx/experimental/symbolic_shapes.py:3775] [0/0] create_env
I1204 16:31:35.781000 1859360 torch/fx/experimental/symbolic_shapes.py:5120] [0/0] create_symbol s33 = 1024 for L['args'][0].size()[0] [2, int_oo] return self.linear(x + y - z + 1)  # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_dynamo/variables/builder.py:3501 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s33" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1204 16:31:35.785000 1859360 torch/fx/experimental/symbolic_shapes.py:5120] [0/0] create_symbol s6 = 1024 for L['args'][1].size()[0] [2, int_oo] return self.linear(x + y - z + 1)  # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_dynamo/variables/builder.py:3501 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s6" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1204 16:31:35.794000 1859360 torch/fx/experimental/symbolic_shapes.py:7213] [0/0] eval Eq(s33, s6) [guard added] return self.linear(x + y - z + 1)  # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s33, s6)"
I1204 16:31:35.795000 1859360 torch/fx/experimental/symbolic_shapes.py:6792] [0/0] set_replacement s6 = s33 (solve) VR[2, int_oo]
I1204 16:31:35.800000 1859360 torch/fx/experimental/symbolic_shapes.py:5120] [0/0] create_symbol s21 = 1024 for L['args'][2].size()[0] [2, int_oo] return self.linear(x + y - z + 1)  # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_dynamo/variables/builder.py:3501 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s21" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1204 16:31:35.806000 1859360 torch/fx/experimental/symbolic_shapes.py:7213] [0/0] eval Eq(s33, s21) [guard added] return self.linear(x + y - z + 1)  # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s33, s21)"
I1204 16:31:35.807000 1859360 torch/fx/experimental/symbolic_shapes.py:6792] [0/0] set_replacement s33 = s21 (solve) VR[2, int_oo]
I1204 16:31:35.828000 1859360 torch/_dynamo/symbolic_convert.py:4059] [0/0] Step 1: torchdynamo done tracing inner (RETURN_VALUE)
I1204 16:31:35.837000 1859360 torch/fx/experimental/symbolic_shapes.py:6792] [0/0] set_replacement s6 = s21 (find) VR[2, int_oo]
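The replacement bookkeeping in this log behaves like a union-find structure. The toy class below is our own illustration, not ShapeEnv's implementation. PyTorch applies its own rules to decide which symbol survives a replacement, and the toy does not model those rules; it simply replays the choices recorded in the log. Resolving s6 follows the chain s6 -> s33 -> s21 and then shortcuts it, which matches the final `set_replacement s6 = s21 (find)` line.

```python
class ReplacementTable:
    """Toy symbol-replacement table: equalities merge symbols, find() resolves chains."""

    def __init__(self):
        self.replacements = {}  # symbol -> symbol it was replaced by

    def find(self, sym: str) -> str:
        root = sym
        while root in self.replacements:  # follow the replacement chain
            root = self.replacements[root]
        while sym != root:                # path compression: point the chain at the root
            nxt = self.replacements[sym]
            self.replacements[sym] = root
            sym = nxt
        return root

    def set_replacement(self, old: str, new: str) -> None:
        """Record that `old` is now expressed as `new` (after resolving both to roots)."""
        old_root, new_root = self.find(old), self.find(new)
        if old_root != new_root:
            self.replacements[old_root] = new_root


table = ReplacementTable()
table.set_replacement("s6", "s33")   # from guard Eq(s33, s6)
table.set_replacement("s33", "s21")  # from guard Eq(s33, s21)
print(table.find("s6"))              # s21
print(table.replacements)            # {'s6': 's21', 's33': 's21'}
```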

1.3 Backend (Inductor, MagiBackend, etc.)

Backend Architecture

MagiCompiler hooks into the torch.compile pipeline through the following components:

  • custom_partitioner_fn: Splits the joint graph into forward and backward graphs and decides which intermediate results are saved for the backward pass.
  • post_grad_custom_pre_pass: Performs optimization passes on the whole graph (graph pattern matching and rewriting).
  • PartitionFunc: Implements custom subgraph partitioning, using attention operators as split points.

Partition

  • post_grad_custom_post_pass: Executes pass optimizations at the subgraph level (computation/communication overlap).
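Attention-based partitioning can be illustrated on a flat list of op names. This toy is not MagiCompiler's PartitionFunc, which operates on fx graphs. Each attention op becomes its own piece, and the ops between attention ops form the subgraphs that subgraph-level passes work on. One plausible motivation for this design, used in vLLM's piecewise compilation for example, is to let the non-attention subgraphs be handled uniformly (for instance by CUDA graph capture) while attention kernels stay outside them.

```python
from typing import List, Sequence, Set


def split_at_ops(ops: Sequence[str], splitting_ops: Set[str]) -> List[List[str]]:
    """Split a linear op sequence into subgraphs; each splitting op becomes its own piece."""
    pieces: List[List[str]] = []
    current: List[str] = []
    for op in ops:
        if op in splitting_ops:
            if current:
                pieces.append(current)  # close the subgraph before the split point
                current = []
            pieces.append([op])         # the attention op stays isolated
        else:
            current.append(op)
    if current:
        pieces.append(current)          # trailing subgraph after the last split point
    return pieces


ops = ["rms_norm", "qkv_proj", "flash_attn", "o_proj", "add", "rms_norm", "mlp", "add"]
print(split_at_ops(ops, {"flash_attn"}))
# [['rms_norm', 'qkv_proj'], ['flash_attn'], ['o_proj', 'add', 'rms_norm', 'mlp', 'add']]
```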

2. Best Practices

2.1 Model Adaptation

MagiCompiler has certain limitations: it requires whole-graph capture and does not support implicit graph breaks. Manual adaptation is therefore required in the following scenarios:

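1. Data-Dependent Control Flow or CPU/GPU Synchronization

Branching on a tensor value, as in the example below, requires copying the value from the GPU to the CPU. This forces a synchronization and breaks the graph:

import torch
from magi_compiler import magi_compile

@magi_compile()
class MeanModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        x = x.cos().sin()
        if x.mean() > 0.5:  # data-dependent branch: graph break + GPU->CPU sync
            x = x - 1
        return x * y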
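One common way to make MeanModule capturable is to keep the decision on the device with torch.where. This is our suggestion and is not taken from the MagiCompiler docs. The @magi_compile() decorator is left off so the sketch runs without MagiCompiler. The rewrite computes x - 1 unconditionally, which is usually cheap for elementwise ops.

```python
import torch


class MeanModuleBranchFree(torch.nn.Module):
    """Branch-free variant of MeanModule: no Python `if` on a tensor value."""

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        x = x.cos().sin()
        cond = x.mean() > 0.5            # 0-dim bool tensor, stays on x's device
        x = torch.where(cond, x - 1, x)  # the scalar condition broadcasts over x
        return x * y
```

Both branches of the original stay reachable. The condition is simply evaluated as data rather than as Python control flow, so the whole forward can be captured as one graph.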

Note: Typical Transformer models contain some pre/post-processing that cannot be captured as a single graph. The recommended practice is therefore to apply magi_compile at the TransformerBlock level, since TransformerBlock computation accounts for over 95% of the total workload.
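The recommended structure can be sketched as follows: only the block class is decorated, while embedding and output-head logic stay in eager mode. The class names are illustrative, and the block is reduced to a LayerNorm plus MLP to keep the sketch short. We added the try/except fallback (a no-op decorator) so the code runs without MagiCompiler installed; it is not part of MagiCompiler.

```python
import torch
from torch import nn

try:
    from magi_compiler import magi_compile
except ImportError:  # fallback so the sketch runs without MagiCompiler installed
    def magi_compile(*args, **kwargs):
        return lambda cls: cls  # no-op class decorator


@magi_compile()  # whole-graph capture applies only to the block
class TransformerBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.mlp(self.norm(x))


class Model(nn.Module):
    """Pre/post-processing stays in eager mode; only the blocks are compiled."""

    def __init__(self, vocab: int, dim: int, depth: int):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.blocks = nn.ModuleList([TransformerBlock(dim) for _ in range(depth)])
        self.head = nn.Linear(dim, vocab)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        x = self.embed(tokens)     # eager pre-processing
        for block in self.blocks:  # compiled region(s)
            x = block(x)
        return self.head(x)        # eager post-processing
```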

2. Custom Operators (e.g., FlashAttention, FlexFlashAttention, MoE kernels)

  • Operator Registration: MagiCompiler provides a path for registering custom operators, and commonly used operators such as FlashAttention (FA) and FlexFlashAttention (FFA) are already registered. Registering a kernel as a torch.library custom op with a fake implementation lets Dynamo trace it as a single opaque node instead of breaking the graph, while still inferring output shapes.
import torch

# Operator registration
@torch.library.custom_op("athena::flash_attn_func", mutates_args=())
def flash_attn_func(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    ...

# Fake implementation: infers output shape/dtype/device without running the kernel
@flash_attn_func.register_fake
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(query)

# Call the registered operators
self_attn_out = torch.ops.athena.flash_attn_func(q, k, v)
out, _ = torch.ops.athena.flex_flash_attn_func(q, k, v, q_ranges=ffa_handler.q_ranges, k_ranges=ffa_handler.k_ranges)
  • Unit Testing: Provide standalone unit tests for every operator used in production. torch.library.opcheck verifies that the registration, including the fake implementation, is consistent with the real kernel.
import pytest
import torch

# Assumes the module that registers torch.ops.athena.flash_attn_func has been imported
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("seq_len", [1024, 2048, 4096])
@pytest.mark.parametrize("query_head", [48])
@pytest.mark.parametrize("kv_head", [4, 8])
@pytest.mark.parametrize("head_dim", [128, 256])
def test_fake_fa3(batch_size, seq_len, query_head, kv_head, head_dim):
    q = torch.randn((batch_size, seq_len, query_head, head_dim), device="cuda", dtype=torch.bfloat16)
    k = torch.randn((batch_size, seq_len, kv_head, head_dim), device="cuda", dtype=torch.bfloat16)
    v = torch.randn((batch_size, seq_len, kv_head, head_dim), device="cuda", dtype=torch.bfloat16)
    torch.library.opcheck(torch.ops.athena.flash_attn_func, (q, k, v))

2.2 Debugging Methods

Key questions for debugging:

  • Does the bug originate in the compiler?
  • Which component of the compiler is responsible?

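Debugging

The main switches for isolating a problem live in CompileConfig (excerpt):

from pydantic import BaseModel, Field

class CompileConfig(BaseModel):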
    # Basic configs
    backend: str = Field("inductor", description="Compilation backend.")
    compile_mode: CompileMode = Field(CompileMode.MAGI_COMPILE, description="Compilation mode.")
    ...

    # Cudagraph configs
    cudagraph_mode: CudaGraphMode = Field(CudaGraphMode.NONE, description="Cudagraph mode.")
    ...

    # Pass configs
    pass_config: PassConfig = Field(PassConfig(), description="Pass configuration.")
    ...

2.3 Profiling Results

For further details, please refer to the Wan2.2 Benchmark.

