| import torch |
| import triton |
| import triton.language as tl |
| from sgl_kernel import concat_mla_k as concat_mla_k_cuda |
|
|
# Torch device reported by Triton's active driver; fn_triton asserts that all
# of its tensor arguments live on this device.
DEVICE = triton.runtime.driver.active.get_active_torch_device()


# Problem dimensions shared by every implementation in this script.
num_local_heads = 128  # size of dim 1 of `k` and `k_nope`
qk_nope_head_dim = 128  # width of the leading slice of k's last dim (from k_nope)
qk_rope_head_dim = 64  # width of the trailing slice of k's last dim (from k_rope)
|
|
|
|
def create_data(num_tokens):
    """Allocate the benchmark tensors for ``num_tokens`` tokens.

    Returns a dict with three CUDA bfloat16 tensors:

    * ``k``      -- uninitialised output of shape
      ``(num_tokens, num_local_heads, qk_nope_head_dim + qk_rope_head_dim)``.
    * ``k_nope`` -- the first ``qk_nope_head_dim`` columns of a random buffer
      whose last dimension is 128 elements wider, i.e. a strided view.
    * ``k_rope`` -- the last ``qk_rope_head_dim`` columns of a random buffer
      of shape ``(num_tokens, 1, 128 + qk_rope_head_dim)``, also a view.
    """
    tensor_opts = dict(dtype=torch.bfloat16, device="cuda")
    pad = 128  # extra columns in the backing buffers, so the inputs are slices

    # Allocation order matters for RNG reproducibility: nope first, then rope.
    nope_storage = torch.randn(
        (num_tokens, num_local_heads, qk_nope_head_dim + pad), **tensor_opts
    )
    rope_storage = torch.randn(
        (num_tokens, 1, pad + qk_rope_head_dim), **tensor_opts
    )
    out = torch.empty(
        (num_tokens, num_local_heads, qk_nope_head_dim + qk_rope_head_dim),
        **tensor_opts,
    )

    return {
        "k": out,
        "k_nope": nope_storage[..., :qk_nope_head_dim],
        "k_rope": rope_storage[..., -qk_rope_head_dim:],
    }
|
|
|
|
def fn_torch(k, k_nope, k_rope):
    """Reference implementation: fill ``k`` in place with indexed assignment.

    The first ``qk_nope_head_dim`` entries of k's last dimension receive
    ``k_nope``; the remaining entries receive ``k_rope`` (broadcast by the
    assignment where its shape is smaller).
    """
    nope_cols = slice(None, qk_nope_head_dim)
    rope_cols = slice(qk_nope_head_dim, None)
    k[..., nope_cols] = k_nope
    k[..., rope_cols] = k_rope
|
|
|
|
def fn_hack_non_strided(k, k_nope, k_rope):
    """Fill ``k`` with two flat, contiguous writes instead of strided ones.

    The result does NOT have the per-head layout that ``fn_torch`` produces:
    the first ``k_nope.numel()`` elements of the flattened ``k`` receive the
    flattened ``k_nope``, and the remaining elements are viewed as
    ``(k_rope.numel(), -1)`` rows, each row filled with one element of
    ``k_rope``.  Every element of ``k`` is written exactly once.
    NOTE(review): presumably a bandwidth baseline for the strided copies --
    confirm with the benchmark's author.

    Args:
        k: contiguous output tensor, modified in place.
        k_nope: source for the leading part of the flattened ``k``.
        k_rope: source whose elements are repeated across the trailing part.

    Raises:
        RuntimeError: if ``k`` cannot be viewed as 1-D (non-contiguous), or
            if the trailing part is not divisible into ``k_rope.numel()`` rows.
    """
    # view(-1) (rather than flatten()) guarantees we alias k's storage; a
    # silent copy here would turn every write below into a no-op on k.
    flat_k = k.view(-1)
    nope_count = k_nope.numel()
    flat_k[:nope_count] = k_nope.flatten()

    rope_rows = flat_k[nope_count:].view(k_rope.numel(), -1)
    # In-place copy: plain assignment to the name would only rebind the local
    # and leave k untouched.  The (N, 1) source broadcasts across each row.
    rope_rows.copy_(k_rope.flatten()[:, None])
|
|
|
|
@torch.compile(dynamic=True)
def fn_torch_compiled(k, k_nope, k_rope):
    """``fn_torch`` wrapped in ``torch.compile`` with ``dynamic=True``."""
    return fn_torch(k=k, k_nope=k_nope, k_rope=k_rope)
|
|
|
|
def fn_cuda(k, k_nope, k_rope):
    """Run the ``concat_mla_k`` kernel imported from ``sgl_kernel``.

    The arguments are forwarded unchanged.  This script compares ``k`` after
    this call against the output of ``fn_torch``, so the kernel is expected
    to write its result into ``k`` in place.
    """
    concat_mla_k_cuda(k, k_nope, k_rope)
|
|
|
|
@triton.jit
def fn_triton_kernel(
    k_ptr,
    k_nope_ptr,
    k_rope_ptr,
    num_tokens,
    QK_NOPE_HEAD_DIM: tl.constexpr,
    QK_ROPE_HEAD_DIM: tl.constexpr,
    NUM_LOCAL_HEADS: tl.constexpr,
    K_NOPE_STRIDE_0: tl.constexpr,
    K_NOPE_STRIDE_1: tl.constexpr,
    K_STRIDE_0: tl.constexpr,
    K_STRIDE_1: tl.constexpr,
    K_ROPE_STRIDE_0: tl.constexpr,
    BLOCK_ROWS: tl.constexpr,
):
    """Copy ``k_nope`` and ``k_rope`` into ``k`` for a block of tokens.

    Each program handles ``BLOCK_ROWS`` consecutive tokens and all
    ``NUM_LOCAL_HEADS`` heads.  Tokens at or beyond ``num_tokens`` are masked
    out of every load and store.  Offsets add the last-dimension index with
    no stride factor, so a unit stride in the last dimension is assumed for
    all three tensors.
    NOTE(review): offsets use whatever integer width ``tl.arange`` yields --
    confirm it cannot overflow for the largest inputs in use.
    """
    block_idx = tl.program_id(axis=0)

    rows = block_idx * BLOCK_ROWS + tl.arange(0, BLOCK_ROWS)
    in_range = rows < num_tokens

    # Broadcastable (token, head, column) index pieces and the shared mask.
    row_ix = rows[:, None, None]
    row_ok = in_range[:, None, None]
    heads = tl.arange(0, NUM_LOCAL_HEADS)
    head_ix = heads[None, :, None]

    # Destination offset of element 0 of every (token, head) row of k.
    k_row_base = row_ix * K_STRIDE_0 + head_ix * K_STRIDE_1

    # Nope part: k[token, head, :QK_NOPE_HEAD_DIM] <- k_nope[token, head, :].
    nope_cols = tl.arange(0, QK_NOPE_HEAD_DIM)[None, None, :]
    nope_src = row_ix * K_NOPE_STRIDE_0 + head_ix * K_NOPE_STRIDE_1 + nope_cols
    nope_tile = tl.load(k_nope_ptr + nope_src, mask=row_ok)
    tl.store(k_ptr + (k_row_base + nope_cols), nope_tile, mask=row_ok)

    # Rope part: the source offset has no head term, so one row per token is
    # loaded and broadcast across all heads by the store.
    rope_cols = tl.arange(0, QK_ROPE_HEAD_DIM)[None, None, :]
    rope_src = row_ix * K_ROPE_STRIDE_0 + rope_cols
    rope_tile = tl.load(k_rope_ptr + rope_src, mask=row_ok)
    tl.store(
        k_ptr + (k_row_base + rope_cols + QK_NOPE_HEAD_DIM),
        rope_tile,
        mask=row_ok,
    )
|
|
|
|
def fn_triton(k, k_nope, k_rope):
    """Launch ``fn_triton_kernel`` over all tokens of ``k``.

    All three tensors must live on ``DEVICE``.  The grid is one-dimensional
    with one program per ``BLOCK_ROWS`` tokens (16 here); head dimensions and
    tensor strides are passed to the kernel as compile-time constants.
    """
    assert all(t.device == DEVICE for t in (k, k_nope, k_rope))
    num_tokens, _heads, _dim = k.shape

    def grid(meta):
        return (triton.cdiv(num_tokens, meta["BLOCK_ROWS"]),)

    constexprs = dict(
        QK_NOPE_HEAD_DIM=qk_nope_head_dim,
        QK_ROPE_HEAD_DIM=qk_rope_head_dim,
        NUM_LOCAL_HEADS=num_local_heads,
        K_NOPE_STRIDE_0=k_nope.stride(0),
        K_NOPE_STRIDE_1=k_nope.stride(1),
        K_STRIDE_0=k.stride(0),
        K_STRIDE_1=k.stride(1),
        K_ROPE_STRIDE_0=k_rope.stride(0),
        BLOCK_ROWS=16,
    )
    fn_triton_kernel[grid](k, k_nope, k_rope, num_tokens, **constexprs)
|
|
|
|
def execute_and_get_output(f, data):
    """Run ``f(**data)`` on a zeroed ``data["k"]`` and return a copy of it.

    ``data["k"]`` is cleared first so stale results from an earlier run
    cannot leak through.  An ``AssertionError`` is raised when the sum of the
    output is still zero after the call, as a sanity check that ``f`` wrote
    something.
    """
    out = data["k"]
    out.zero_()
    f(**data)
    assert out.sum().item() != 0
    return out.clone()
|
|
|
|
# Correctness gate, run once at import time: the CUDA kernel must reproduce
# the indexed-assignment reference bit-for-bit on a 32768-token input.
torch.manual_seed(0)
data = create_data(num_tokens=32768)
output_ref, output_exp = (
    execute_and_get_output(impl, data) for impl in (fn_torch, fn_cuda)
)

# Any differing element fails the check; the message carries both outputs,
# their absolute difference and the indices where they disagree.
if (output_ref != output_exp).any():
    abs_delta = (output_ref - output_exp).abs()
    raise AssertionError(
        f"{output_ref=} {output_exp=} "
        f"{abs_delta=} "
        f"{torch.argwhere(abs_delta != 0.0)=} "
    )
|
|
|
|
# Single source of truth for the benchmarked implementations; insertion order
# is the order of the lines in the report.
_PROVIDERS = {
    "torch": fn_torch,
    "torch_compiled": fn_torch_compiled,
    "triton": fn_triton,
    "hack_non_strided": fn_hack_non_strided,
    "cuda": fn_cuda,
}


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],
        x_vals=[2048, 4096, 8192, 16384, 32768],
        x_log=False,
        line_arg="provider",
        line_vals=list(_PROVIDERS),
        line_names=list(_PROVIDERS),
        ylabel="ms",
        # Previously "vector-add-performance", a leftover that mislabelled
        # the report for this concat benchmark.
        plot_name="concat-mla-k-performance",
        args={},
    )
)
def benchmark(num_tokens, provider):
    """Time one implementation on freshly created data.

    Args:
        num_tokens: number of tokens passed to ``create_data``.
        provider: key selecting the implementation to time; one of
            ``"torch"``, ``"torch_compiled"``, ``"triton"``,
            ``"hack_non_strided"`` or ``"cuda"``.

    Returns:
        The values ``triton.testing.do_bench`` reports for the 0.5, 0.2 and
        0.8 quantiles, as the tuple ``(ms, min_ms, max_ms)``.

    Raises:
        KeyError: if ``provider`` is not a known implementation.
    """
    data = create_data(num_tokens=num_tokens)
    fn = _PROVIDERS[provider]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: fn(**data), quantiles=[0.5, 0.2, 0.8]
    )
    return ms, min_ms, max_ms
|
|
|
|
# Bracket the benchmark run with CUDA profiler start/stop calls.  The stop
# call sits in `finally` so it is still issued if the benchmark raises.
torch.cuda.cudart().cudaProfilerStart()
try:
    benchmark.run(print_data=True, show_plots=True)
finally:
    torch.cuda.cudart().cudaProfilerStop()
|
|