| import argparse |
| import copy |
| import itertools |
| import os |
|
|
| import torch |
| import triton |
| from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size |
|
|
| from sglang.srt.utils import get_device_capability |
|
|
| |
# True when either the CI or the GITHUB_ACTIONS environment variable is set to
# "true" (case-insensitive); both default to "false" when unset.
IS_CI = any(
    os.getenv(env_name, "false").lower() == "true"
    for env_name in ("CI", "GITHUB_ACTIONS")
)

# Under CI the sweep collapses to a single (batch size, sequence length) point;
# otherwise the full grid is benchmarked.
bs_range = [1] if IS_CI else [1, 8, 32, 64, 128, 256]
qlen_range = [64] if IS_CI else [1, 64, 128, 256, 512, 1024, 2048, 4096, 8192]

# Cartesian product of the two ranges: one (batch_size, seq_len) tuple per run.
configs = list(itertools.product(bs_range, qlen_range))
|
|
|
|
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len"],
        x_vals=configs,
        x_log=False,
        line_arg="provider",
        line_vals=[
            "128 heads",
            "64 heads",
            "32 heads",
            "16 heads",
        ],
        line_names=[
            "128 heads",
            "64 heads",
            "32 heads",
            "16 heads",
        ],
        styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
        ylabel="GB/s",
        plot_name="cutlass mla",
        args={},
    )
)
def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
    """Time ``cutlass_mla_decode`` for one configuration and report GB/s.

    Args:
        batch_size: Number of rows in the sequence-length tensor, the block
            table and the ``qr`` query tensor.
        seq_len: Value every entry of the sequence-length tensor is filled with.
        provider: Line label such as ``"128 heads"``; the query head count is
            derived from the number it contains.
        block_size: Size of the second dimension of the KV cache tensor, i.e.
            how many cache slots one block holds. Must be in ``[1, 128]``
            because ``128 // block_size`` is used as a divisor below.
        num_kv_splits: Forwarded unchanged to
            ``cutlass_mla_get_workspace_size`` and ``cutlass_mla_decode``.

    Returns:
        A 3-tuple of throughputs in GB/s, derived from the three latencies
        returned by ``do_bench_cudagraph`` for quantiles 0.5, 0.8 and 0.2, in
        that order (so the second entry is the lower bound and the third the
        upper bound of the throughput band).

    Raises:
        ValueError: If ``block_size`` is outside ``[1, 128]`` or ``provider``
            contains none of the known head counts.
    """
    # Fail early with a readable message: block_size > 128 would make
    # pack_factor zero (ZeroDivisionError below) and non-positive values yield
    # meaningless block counts.
    if not 1 <= block_size <= 128:
        raise ValueError(f"block_size must be between 1 and 128, got {block_size}")

    d = 576  # last dimension of the KV cache
    dn = 64  # last dimension of qr; qn gets the remaining d - dn
    dv = 512  # only used in the traffic estimate (presumably the output head dim — confirm)

    # The head count is whichever known number appears in the provider label;
    # "128" is checked first, in dict insertion order.
    h_q_map = {
        "128": 128,
        "64": 64,
        "32": 32,
        "16": 16,
    }
    parsed_h_q = next(
        (value for key, value in h_q_map.items() if key in provider), None
    )
    if parsed_h_q is None:
        raise ValueError(f"Unknown head configuration in provider: {provider}")
    h_q = parsed_h_q

    seq_lens = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda")
    max_seq_len = seq_lens.max().item()
    # Ceil-divide: number of blocks needed to cover the longest sequence.
    block_num = (max_seq_len + block_size - 1) // block_size

    # Round block_num up to a multiple of (128 // block_size).
    # NOTE(review): presumably a layout requirement of the kernel — confirm.
    pack_factor = 128 // block_size
    block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor

    # qn is allocated heads-first and transposed at the call site, so the
    # kernel receives a (batch, heads, d - dn) view that is not contiguous.
    # Both query tensors are scaled by 100.
    qn = (
        torch.randn(h_q, batch_size, d - dn, dtype=torch.bfloat16, device="cuda")
        * 100.0
    )
    qr = torch.randn(batch_size, h_q, dn, dtype=torch.bfloat16, device="cuda") * 100.0

    # Random block indices in [0, batch_size * block_num), which is exactly the
    # number of blocks allocated in kv_cache below.
    block_table = torch.randint(
        0,
        batch_size * block_num,
        (batch_size, block_num),
        dtype=torch.int32,
        device="cuda",
    )
    kv_cache = torch.randn(
        block_table.numel(), block_size, d, dtype=torch.bfloat16, device="cuda"
    )

    workspace_size = cutlass_mla_get_workspace_size(
        block_num * block_size, batch_size, num_kv_splits=num_kv_splits
    )
    workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
        lambda: cutlass_mla_decode(
            qn.transpose(0, 1),
            qr,
            kv_cache,
            seq_lens,
            block_table,
            workspace,
            1.44,  # NOTE(review): presumably the softmax scale — confirm
            num_kv_splits,
        ),
        quantiles=quantiles,
    )

    # Traffic estimate: both query tensors, a second term scaled by dv / d,
    # plus the whole KV cache.
    q_size = qn.numel() * qn.element_size() + qr.numel() * qr.element_size()
    total_bytes = (
        q_size + q_size * dv / d + kv_cache.numel() * kv_cache.element_size()
    )

    def to_gbps(latency_ms):
        """Convert a latency in milliseconds into GB/s for ``total_bytes``."""
        return total_bytes * 1e-9 / (latency_ms * 1e-3)

    # The slower latency (max_ms) gives the lower throughput, hence the swap.
    return to_gbps(ms), to_gbps(max_ms), to_gbps(min_ms)
|
|
|
|
def build_arg_parser():
    """Create the command-line parser for the benchmark sweep.

    Returns:
        An ``argparse.ArgumentParser`` exposing ``--block-sizes`` (default
        ``[1, 32, 64, 128]``) and ``--num-kv-splits`` (default ``[-1]``), each
        accepting one or more integers.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--block-sizes",
        nargs="+",
        type=int,
        default=[1, 32, 64, 128],
        help="List of block sizes",
    )
    parser.add_argument(
        "--num-kv-splits",
        nargs="+",
        type=int,
        default=[-1],
        help="List of num_kv_splits values",
    )
    return parser


def main(argv=None):
    """Run the benchmark for every (block size, num_kv_splits) combination.

    Under CI the run is skipped, with an explanatory message, when the device
    capability cannot be determined or its major version is below 10.

    Args:
        argv: Argument strings to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.
    """
    args = build_arg_parser().parse_args(argv)

    if IS_CI:
        major, minor = get_device_capability()
        if major is None or major < 10:
            print("Skipping Cutlass MLA benchmark in CI environment")
            if major is not None:
                print(
                    f"Cutlass MLA requires compute capability 10.0+, but found {major}.{minor}"
                )
            else:
                print("Could not determine device capability")
            return

    # Block sizes form the outer loop and KV splits the inner loop.
    for block_size, kv_split in itertools.product(
        args.block_sizes, args.num_kv_splits
    ):
        print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
        benchmark.run(
            print_data=True,
            block_size=block_size,
            num_kv_splits=kv_split,
        )
    print("Benchmark finished!")


if __name__ == "__main__":
    main()
|
|