| import pytest |
| import torch |
| import torch.nn.functional as F |
| from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size |
| from torch import Tensor |
|
|
| |
# Skip the whole module unless a CUDA device with compute capability exactly
# (10, 0) is present. is_available() is checked first because
# get_device_capability() raises (rather than returning) when no CUDA device
# can be initialized, which would turn the skip into a collection error.
if not torch.cuda.is_available() or torch.cuda.get_device_capability() != (10, 0):
    pytest.skip(
        reason="Cutlass MLA Requires compute capability of 10.",
        allow_module_level=True,
    )
|
|
|
|
| def ref_mla( |
| out: Tensor, |
| query: Tensor, |
| kv_cache: Tensor, |
| scale: float, |
| block_tables: Tensor, |
| seq_lens: Tensor, |
| ): |
| bs, num_heads, v_head_dim = out.shape |
| head_dim = query.shape[2] |
|
|
| for i in range(bs): |
| |
| kv = kv_cache[block_tables[i]] |
| kv = kv.view(1, -1, head_dim)[:, : seq_lens[i]] |
| v = kv[:, :, :v_head_dim] |
|
|
| q = query[i].view(num_heads, 1, head_dim) |
| o = F.scaled_dot_product_attention(q, kv, v, scale=scale, enable_gqa=True) |
| out[i] = o.view(num_heads, v_head_dim) |
|
|
| return out |
|
|
|
|
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("mean_seq_len", [128, 1024, 4096])
@pytest.mark.parametrize("bs", [1, 2, 4])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize("block_size", [1, 16, 64, 128])
@pytest.mark.parametrize("num_heads", [16, 32, 64, 128])
@pytest.mark.parametrize("num_kv_splits", [-1, 1])
def test_cutlass_mla_decode(
    dtype: torch.dtype,
    mean_seq_len: int,
    bs: int,
    varlen: bool,
    block_size: int,
    num_heads: int,
    num_kv_splits: int,
):
    """Check ``cutlass_mla_decode`` against the SDPA-based ``ref_mla``.

    Builds random queries and a random paged KV cache, then compares the
    kernel output to the reference with ``atol=rtol=1e-2``.

    The reference consumes the full 576-wide query. The kernel receives the
    query split into ``q_nope`` (first 512 channels) and ``q_pe`` (last 64
    channels).

    The default dtype and device are changed only for the duration of this
    test and restored afterwards, so they do not leak into other tests.
    """
    prev_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        with torch.device("cuda"):
            torch.manual_seed(42)

            d = 576  # full query / latent width consumed by the reference
            h_q = num_heads
            dv = 512  # output width, and the width of the q_nope split

            # Softmax scale comes from a 128 + 64 = 192 query/key head size,
            # not from d. NOTE(review): presumably mirrors the model's
            # original qk head dim — confirm.
            q_nope_dim = 128
            q_pe_dim = 64
            scale = (q_nope_dim + q_pe_dim) ** (-0.5)

            if varlen:
                # Random lengths around mean_seq_len, clipped to >= 2 tokens.
                seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2)
                seq_lens = seq_lens.clip(2).to(torch.int32)
            else:
                seq_lens = torch.full((bs,), mean_seq_len, dtype=torch.int32)
            max_seq_len = seq_lens.max().item()
            block_num = (max_seq_len + block_size - 1) // block_size

            # Round block_num up so block_num * block_size is a multiple of 128
            # (every parametrized block_size divides 128).
            # NOTE(review): presumably a kernel page-table alignment
            # requirement — confirm.
            pack_factor = 128 // block_size
            block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor

            # Queries are scaled up; presumably to produce a peaked softmax
            # that stresses numerics — confirm intent.
            q = torch.randn(bs, h_q, d) * 100.0
            # Random (possibly repeated) page ids, all valid rows of kv_cache.
            block_table = torch.randint(
                0, bs * block_num, (bs, block_num), dtype=torch.int32
            )
            kv_cache = torch.randn(block_table.numel(), block_size, d)

            workspace_size = cutlass_mla_get_workspace_size(
                block_num * block_size, bs, num_kv_splits=num_kv_splits
            )
            workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

            # q_nope is a non-contiguous (bs, h_q, dv) view of an (h_q, bs, dv)
            # buffer, so the kernel is fed a strided batch/head layout.
            q_nope = torch.empty((h_q, bs, dv)).transpose(0, 1)
            q_nope.copy_(q[:, :, :dv])
            q_pe = q[:, :, dv:].clone()

            out_ref = q.new_zeros(bs, h_q, dv)
            ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
            out = cutlass_mla_decode(
                q_nope,
                q_pe,
                kv_cache,
                seq_lens,
                block_table,
                workspace,
                scale,
                num_kv_splits,
            )

            torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
    finally:
        # Restore the process-wide default so later tests are unaffected.
        torch.set_default_dtype(prev_dtype)
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly fails
    # (non-zero exit) when any test fails, instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))
|
|