| | import dataclasses |
| | import os |
| | import enum |
| | from typing import List, Optional |
| | import random |
| |
|
| | import torch |
| | import kernelkit as kk |
| | |
| | from kernels import get_kernel, get_local_kernel |
| | flash_mla = get_kernel("drbh/tmp-kernel-123") |
| |
|
| | import quant |
| |
|
class TestTarget(enum.Enum):
    """
    Kind of test being run.

    Presumably FWD selects the sparse forward (prefill) path and DECODE the KV-cache decoding path.
    """

    FWD = 0     # sparse forward (prefill) test
    DECODE = 1  # decoding test
| |
|
@dataclasses.dataclass
class ExtraTestParamForDecode:
    """
    Decode-only test parameters, carried in TestParam.decode.

    The "extra_*" fields describe an optional second KV cache. The decode test-case generator only
    builds it when extra_topk is set. In that case a missing extra_s_k defaults to 2 * extra_topk
    and a missing extra_block_size defaults to block_size.

    Field order matters: instances are also constructed positionally.
    """
    b: int                                  # batch size
    is_varlen: bool                         # per-sequence KV lengths are drawn randomly around s_kv
    have_zero_seqlen_k: bool                # some sequences get a KV length of 0
    extra_s_k: Optional[int] = None         # KV length of the extra cache
    extra_topk: Optional[int] = None        # number of sparse indices into the extra cache
    block_size: int = 64                    # page (block) size of the main KV cache
    extra_block_size: Optional[int] = None  # page (block) size of the extra KV cache
    have_extra_topk_length: bool = False    # generate topk lengths for the extra cache
| | |
@dataclasses.dataclass
class TestParam:
    """
    Parameters shared by the sparse forward (prefill) and decoding tests.

    `decode` holds the decode-only parameters. The decode test-case generator requires it to be set;
    the forward generator does not read it.

    Field order matters: instances are also constructed positionally.
    """
    s_q: int                                # number of query tokens (per sequence in decode tests)
    s_kv: int                               # number of KV tokens (nominal per-sequence length in decode tests)
    topk: int                               # number of sparse KV indices per query
    h_q: int = 128                          # number of query heads
    h_kv: int = 1                           # number of KV heads
    d_qk: int = 512                         # head dimension of q / k
    d_v: int = 512                          # head dimension of v / output
    seed: int = -1                          # forwarded to kk.set_random_seed
    check_correctness: bool = True          # not read in this section; presumably used by the test driver
    is_all_indices_invalid: bool = False    # inject invalid sparse indices
    num_runs: int = 10                      # not read in this section; presumably benchmark repetitions
    have_attn_sink: bool = False            # generate a per-head attention sink
    have_topk_length: bool = False          # generate topk lengths
    decode: Optional[ExtraTestParamForDecode] = None  # decode-only parameters
| |
|
@dataclasses.dataclass
class RawTestParamForDecode:
    """
    "Flattened" test parameters for the decoding test.

    To stay compatible with TestParam, the test script nests the decode-only parameters inside
    TestParam.decode, which is inconvenient when constructing test cases. This class lists every
    decoding parameter at a single level. `to_test_param` converts it into the nested TestParam form.
    """
    b: int
    h_q: int
    s_q: int
    h_kv: int
    s_kv: int
    is_varlen: bool
    topk: int
    is_all_indices_invalid: bool = False
    have_zero_seqlen_k: bool = False
    have_topk_length: bool = False
    enable_attn_sink: bool = True           # becomes TestParam.have_attn_sink
    extra_s_k: Optional[int] = None
    extra_topk: Optional[int] = None
    block_size: int = 64
    extra_block_size: Optional[int] = None
    have_extra_topk_length: bool = False
    d_qk: int = 576
    d_v: int = 512
    check_correctness: bool = True
    num_runs: int = 10
    seed: int = -1

    def to_test_param(self) -> TestParam:
        """Build the equivalent nested TestParam (decode-only fields go into TestParam.decode)."""
        decode_params = ExtraTestParamForDecode(
            b=self.b,
            is_varlen=self.is_varlen,
            have_zero_seqlen_k=self.have_zero_seqlen_k,
            extra_s_k=self.extra_s_k,
            extra_topk=self.extra_topk,
            block_size=self.block_size,
            extra_block_size=self.extra_block_size,
            have_extra_topk_length=self.have_extra_topk_length,
        )
        return TestParam(
            s_q=self.s_q,
            s_kv=self.s_kv,
            topk=self.topk,
            h_q=self.h_q,
            h_kv=self.h_kv,
            d_qk=self.d_qk,
            d_v=self.d_v,
            seed=self.seed,
            check_correctness=self.check_correctness,
            is_all_indices_invalid=self.is_all_indices_invalid,
            num_runs=self.num_runs,
            have_attn_sink=self.enable_attn_sink,
            have_topk_length=self.have_topk_length,
            decode=decode_params,
        )
| | |
@dataclasses.dataclass
class Testcase:
    """
    Inputs for one sparse forward (prefill) test, as produced by generate_testcase.

    p            parameters the case was generated from
    dOut         output gradient, [s_q, h_q, d_v]
    q            queries, [s_q, h_q, d_qk]
    kv           keys/values, [s_kv, h_kv, d_qk]
    indices      sparse KV indices per query, [s_q, h_kv, topk]; values outside [0, s_kv) are invalid
    sm_scale     softmax scale
    attn_sink    optional per-head attention-sink values, [h_q]
    topk_length  optional number of leading indices used per query, [s_q]
    """
    p: TestParam
    dOut: torch.Tensor
    q: torch.Tensor
    kv: torch.Tensor
    indices: torch.Tensor
    sm_scale: float
    attn_sink: Optional[torch.Tensor]
    topk_length: Optional[torch.Tensor]
| |
|
def _randperm_batch(batch_size: int, perm_range: torch.Tensor, perm_size: int, paddings: List[int]) -> torch.Tensor:
    """
    Generate random permutations in batch.

    The returned int32 tensor, denoted `res`, has shape [batch_size, perm_size]. In every row `i`,
    the entries below `perm_range[i]` are unique and lie in `[0, perm_range[i])`.

    If `perm_range[i] < perm_size` for some `i`, row `i` contains every value in
    `[0, perm_range[i])`, and the remaining slots are filled with padding. The padding is
    `paddings[0]` when `paddings` has exactly one element; otherwise each slot gets a value drawn
    uniformly at random from `paddings`.

    Deterministic algorithms are temporarily enabled while the rows are built. They are always
    disabled again on exit, even when an exception is raised, so a failed call cannot leave the
    global mode switched on. The caller must not already have them enabled (asserted on entry).
    """
    assert not torch.are_deterministic_algorithms_enabled()
    torch.use_deterministic_algorithms(True)
    try:
        perm_range_max = max(int(torch.max(perm_range).item()), perm_size)
        rand = torch.rand(batch_size, perm_range_max, dtype=torch.float32)
        # Out-of-range positions get -inf, so topk picks every in-range position before any of them.
        out_of_range = torch.arange(0, perm_range_max).broadcast_to(batch_size, perm_range_max) >= perm_range.view(batch_size, 1)
        rand[out_of_range] = float("-inf")
        res = rand.topk(perm_size, dim=-1, sorted=True).indices.to(torch.int32)
        # Any selected position >= perm_range[i] came from a -inf slot and is replaced by padding.
        pad_mask = res >= perm_range.view(batch_size, 1)
        if len(paddings) == 1:
            res[pad_mask] = paddings[0]
        else:
            fillers = torch.tensor(paddings, dtype=torch.int32).index_select(0, torch.randint(0, len(paddings), (res.numel(), ), dtype=torch.int32))
            res.masked_scatter_(pad_mask, fillers)
    finally:
        # Always restore the global mode; otherwise later calls would trip the entry assert.
        torch.use_deterministic_algorithms(False)
    return res
| |
|
def generate_testcase(t: TestParam) -> Testcase:
    """
    Build random inputs for the sparse forward (prefill) test.

    Produced tensors (bf16 unless noted):
      q            [s_q, h_q, d_qk]
      kv           [s_kv, h_kv, d_qk]
      dOut         [s_q, h_q, d_v]
      indices      [s_q, h_kv, topk], int32. Each (query, kv-head) row holds topk distinct random
                   values from [0, s_kv) when s_kv >= topk. Otherwise the remaining slots are filled
                   with out-of-range values from `invalid_indices_candidate`.
      attn_sink    [h_q] float32 with some entries set to +/-inf; only if t.have_attn_sink
      topk_length  [s_q] int32 in [0, topk]; only if t.have_topk_length

    q, kv, dOut and indices are passed through kk.non_contiguousify (presumably so the kernel is
    exercised with non-contiguous strides).
    """
    kk.set_random_seed(t.seed)
    q = torch.randn((t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16)/10 + (random.random()-0.5)/10
    kv = torch.randn((t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16)/10 + (random.random()-0.5)/10
    do = torch.randn((t.s_q, t.h_q, t.d_v), dtype=torch.bfloat16)/10 + (random.random()-0.5)/10

    q.clamp_(-10, 10)
    kv.clamp_(-10, 10)
    do.clamp_(-10, 10)

    # Index values used as padding / invalid markers. Only values really outside [0, s_kv) are kept:
    # for s_kv > 114514 some of these fixed constants would otherwise be valid token indices.
    # For s_kv <= 114514 the list is exactly the full constant list.
    invalid_indices_candidate = [
        idx for idx in [-2147483648, -123456, -1, t.s_kv, 114514, 1919810, 2147480000, 2147483647]
        if idx < 0 or idx >= t.s_kv
    ]
    # One independent row of indices per (query, kv-head) pair. The previous code generated only
    # s_q rows and could not be viewed as [s_q, h_kv, topk] when h_kv > 1; for h_kv == 1 this is
    # unchanged.
    num_rows = t.s_q * t.h_kv
    indices = _randperm_batch(num_rows, torch.full((num_rows, ), t.s_kv, dtype=torch.int32), t.topk, invalid_indices_candidate).view(t.s_q, t.h_kv, t.topk)

    if t.is_all_indices_invalid:
        # NOTE(review): despite the flag name, only a random subset of query rows is overwritten
        # (those whose N(0, 1) draw is below -2, roughly 2%). All of them get the same randomly
        # chosen invalid value.
        all_indices_invalid_mask = torch.randn(t.s_q, device='cpu') < -2
        indices[all_indices_invalid_mask[:, None, None].broadcast_to(indices.shape)] = random.choice(invalid_indices_candidate)
    indices = indices.to(q.device)

    attn_sink = None
    if t.have_attn_sink:
        attn_sink = torch.randn((t.h_q, ), dtype=torch.float32)
        # Push some heads' sink to -inf / +inf to cover the extreme cases.
        mask = torch.randn((t.h_q, ), dtype=torch.float32)
        attn_sink[mask < -0.5] = float("-inf")
        attn_sink[mask > +0.5] = float("+inf")

    topk_length = None
    if t.have_topk_length:
        # Drawn from [0, max(topk + 1, 64)) then clamped, so for small topk many rows use the full topk.
        topk_length = torch.randint(0, max(t.topk + 1, 64), (t.s_q, ), dtype=torch.int32, device=q.device).clamp_max(t.topk)

    q = kk.non_contiguousify(q)
    kv = kk.non_contiguousify(kv)
    do = kk.non_contiguousify(do)
    indices = kk.non_contiguousify(indices)

    return Testcase(
        p=t,
        dOut=do,
        q=q,
        kv=kv,
        indices=indices,
        sm_scale=0.5,
        attn_sink=attn_sink,
        topk_length=topk_length
    )
| |
|
| |
|
@dataclasses.dataclass
class KVScope:
    """
    A paged KV cache plus the sparse indices a decode test uses to address it.

    As built by generate_testcase_for_decode in this file:
      cache_seqlens        [b] int32, KV length of every sequence
      block_table          [b, blocks_per_seq] int32, randomly shuffled block ids owned by each sequence
      blocked_k            [num_blocks, block_size, h_kv, d_qk]; unreferenced tokens are NaN-filled by
                           the generator. After quant_and_dequant_ it holds the de-quantized values.
      abs_indices          [b, s_q, topk] int32, token positions within each sequence (-1 = invalid)
      indices_in_kvcache   abs_indices translated by quant.abs_indices2indices_in_kvcache
      topk_length          optional [b] int32, number of leading index slots used per batch entry
      blocked_k_quantized  FP8 copy of blocked_k; None until quant_and_dequant_ is called
    """
    t: TestParam
    cache_seqlens: torch.Tensor
    block_table: torch.Tensor
    blocked_k: torch.Tensor
    abs_indices: torch.Tensor
    indices_in_kvcache: torch.Tensor
    topk_length: Optional[torch.Tensor]
    blocked_k_quantized: Optional[torch.Tensor] = None

    def quant_and_dequant_(self):
        """
        Quantize blocked_k to Flash MLA's FP8 KV-cache layout, then replace blocked_k with the
        de-quantized result.

        The quantized tensor is kept in blocked_k_quantized and is what the kernel receives. The
        quantization error alone can be large enough to hide a wrong kernel. Code that reads
        blocked_k afterwards (presumably the reference implementation) therefore sees the same
        rounded values.

        The layout depends on the head dimension:
          - d_qk == 576 uses V32_FP8Sparse.
          - d_qk == 512 uses MODEL1_FP8Sparse.
          - Anything else fails an assertion.
        """
        head_dim = self.t.d_qk
        layout = None
        if head_dim == 576:
            layout = quant.FP8KVCacheLayout.V32_FP8Sparse
        elif head_dim == 512:
            assert self.abs_indices is not None
            layout = quant.FP8KVCacheLayout.MODEL1_FP8Sparse
        else:
            assert False
        self.blocked_k_quantized = quant.quantize_k_cache(self.blocked_k, layout)
        self.blocked_k = quant.dequantize_k_cache(self.blocked_k_quantized, layout)

    def get_kvcache_for_flash_mla(self) -> torch.Tensor:
        """Return the quantized KV cache to hand to Flash MLA (requires quant_and_dequant_ first)."""
        assert self.blocked_k_quantized is not None, "Please call `quant_and_dequant_` first before calling `get_kvcache_for_flash_mla`"
        return self.blocked_k_quantized

    def apply_perm(self, perm: torch.Tensor) -> "KVScope":
        """
        Return a new KVScope whose per-batch tensors are reordered by `perm`, for the batch-invariance test.

        t, blocked_k and blocked_k_quantized are shared with the original, not copied.
        """
        permuted_topk_length = None if self.topk_length is None else self.topk_length[perm]
        return dataclasses.replace(
            self,
            cache_seqlens=self.cache_seqlens[perm],
            block_table=self.block_table[perm],
            abs_indices=self.abs_indices[perm],
            indices_in_kvcache=self.indices_in_kvcache[perm],
            topk_length=permuted_topk_length,
        )
| | |
@dataclasses.dataclass
class TestcaseForDecode:
    """
    Inputs for one sparse decoding test, as produced by generate_testcase_for_decode.

    p               parameters the case was generated from (p.decode is set)
    q               queries, [b, s_q, h_q, d_qk]
    attn_sink       optional per-head attention-sink values, [h_q]
    sm_scale        softmax scale
    kv_scope        main KV cache and its sparse indices
    extra_kv_scope  optional second KV cache, present only when p.decode.extra_topk is set
    """
    p: TestParam
    q: torch.Tensor
    attn_sink: Optional[torch.Tensor]
    sm_scale: float
    kv_scope: KVScope
    extra_kv_scope: Optional[KVScope]
| |
|
def generate_testcase_for_decode(t: TestParam) -> TestcaseForDecode:
    """
    Build random inputs for the sparse decoding test.

    Produces:
      * q of shape [b, s_q, h_q, d_qk], clamped to [-1, 1];
      * an optional per-head attention sink [h_q] (float32) with some entries set to +/-inf;
      * the main KVScope, built from t.s_kv, t.decode.block_size and t.topk;
      * an optional extra KVScope when t.decode.extra_topk is set.

    Both KV scopes are quantized to FP8 and then de-quantized in place (KVScope.quant_and_dequant_).

    Side effect: when extra_topk is set but extra_s_k / extra_block_size are None, they are written
    back into `t.decode` (extra_s_k = 2 * extra_topk, extra_block_size = block_size).
    """
    kk.set_random_seed(t.seed)
    assert t.h_q % t.h_kv == 0
    assert t.decode is not None

    q = torch.randn((t.decode.b, t.s_q, t.h_q, t.d_qk))
    q.clamp_(min=-1.0, max=1.0)

    attn_sink = None
    if t.have_attn_sink:
        attn_sink = torch.randn((t.h_q, ), dtype=torch.float32)
        inf_mask = torch.randn((t.h_q, ), dtype=torch.float32)
        attn_sink[inf_mask > 0.5] = float("inf")
        attn_sink[inf_mask < -0.5] = float("-inf")

    def generate_one_k_scope(s_k: int, block_size: int, topk: int, is_varlen: bool, have_zero_seqlen: bool, is_all_indices_invalid: bool, have_topk_length: bool) -> KVScope:
        """
        Build one paged KV cache plus sparse indices into it.

        KV tokens not referenced by any (unmasked) index are filled with NaN, presumably so that a
        kernel reading an unselected token produces NaN output and fails the check.
        """
        b = t.decode.b
        # Per-sequence KV lengths: a fixed s_k, or (varlen) ~N(s_k, (s_k/2)^2) clipped below at s_q.
        cache_seqlens_cpu = torch.full((b,), s_k, dtype=torch.int32, device='cpu')
        if is_varlen:
            for i in range(b):
                cache_seqlens_cpu[i] = max(random.normalvariate(s_k, s_k / 2), t.s_q)

        if have_zero_seqlen:
            zeros_mask = torch.randn(b, dtype=torch.float32, device='cpu') > 0
            cache_seqlens_cpu[zeros_mask] = 0

        # Round the longest sequence up to a multiple of 4 blocks (at least one such chunk).
        max_seqlen_alignment = 4 * block_size
        max_seqlen_pad = max(kk.cdiv(int(cache_seqlens_cpu.max().item()), max_seqlen_alignment), 1) * max_seqlen_alignment
        cache_seqlens = cache_seqlens_cpu.cuda()

        assert max_seqlen_pad % block_size == 0
        # Every sequence owns max_seqlen_pad // block_size blocks; block ids are shuffled across the batch.
        block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
        block_table = block_table.view(-1)[torch.randperm(block_table.numel())].view(b, -1)

        blocked_k = kk.gen_non_contiguous_randn_tensor((block_table.numel(), block_size, t.h_kv, t.d_qk)) / 10
        blocked_k.clamp_(min=-1.0, max=1.0)

        abs_indices = torch.empty((b, t.s_q, topk), dtype=torch.int32)
        if is_all_indices_invalid:
            abs_indices.fill_(-1)
        else:
            abs_indices[:] = _randperm_batch(b*t.s_q, cache_seqlens.repeat_interleave(t.s_q), topk, [-1]).view(b, t.s_q, topk)
        indices_in_kvcache = quant.abs_indices2indices_in_kvcache(abs_indices, block_table, block_size)

        topk_length = torch.randint(0, topk+1, (b, ), dtype=torch.int32, device=q.device) if have_topk_length else None

        # Indices the kernel actually uses: slots at or beyond topk_length[batch] are disabled (-1).
        if have_topk_length:
            indices_in_kvcache_masked = indices_in_kvcache.clone()
            beyond_length = torch.arange(0, topk).view(1, 1, topk).broadcast_to(b, t.s_q, topk) >= topk_length.view(b, 1, 1)
            indices_in_kvcache_masked[beyond_length] = -1
        else:
            indices_in_kvcache_masked = indices_in_kvcache

        # Poison every token that no used index points at. The mask is indexed along dim 0 of the
        # [num_tokens, h_kv, d_qk] view, so its length is the token count. The previous code used
        # num_tokens * h_kv, which fails for h_kv > 1.
        blocked_k = blocked_k.view(-1, t.h_kv, t.d_qk)
        nonused_indices_mask = torch.ones(blocked_k.size(0), dtype=torch.bool, device='cpu')
        # NOTE(review): -1 entries (invalid / disabled slots) index the LAST token here, so that token
        # is never poisoned. Left unchanged because downstream reference code may gather index -1
        # and rely on it being finite. Confirm before changing.
        nonused_indices_mask[indices_in_kvcache_masked] = False
        blocked_k[nonused_indices_mask, :, :] = float("nan")
        blocked_k = blocked_k.view(-1, block_size, t.h_kv, t.d_qk)

        block_table = kk.non_contiguousify(block_table)
        abs_indices = kk.non_contiguousify(abs_indices)
        indices_in_kvcache = kk.non_contiguousify(indices_in_kvcache)
        return KVScope(t, cache_seqlens, block_table, blocked_k, abs_indices, indices_in_kvcache, topk_length)

    kv_scope0 = generate_one_k_scope(t.s_kv, t.decode.block_size, t.topk, t.decode.is_varlen, t.decode.have_zero_seqlen_k, t.is_all_indices_invalid, t.have_topk_length)
    kv_scope0.quant_and_dequant_()
    if t.decode.extra_topk is not None:
        # Fill in defaults for the extra cache (written back into t.decode).
        if t.decode.extra_s_k is None:
            t.decode.extra_s_k = t.decode.extra_topk*2
        if t.decode.extra_block_size is None:
            t.decode.extra_block_size = t.decode.block_size
        kv_scope1 = generate_one_k_scope(t.decode.extra_s_k, t.decode.extra_block_size, t.decode.extra_topk, t.decode.is_varlen, t.decode.have_zero_seqlen_k, t.is_all_indices_invalid, t.decode.have_extra_topk_length)
        kv_scope1.quant_and_dequant_()
    else:
        # Extra-cache settings without extra_topk would be silently ignored; reject them.
        assert t.decode.extra_block_size is None and t.decode.extra_s_k is None and not t.decode.have_extra_topk_length
        kv_scope1 = None

    # NOTE(review): exponent -0.55 rather than the usual -0.5; presumably deliberate so a kernel
    # that hard-codes 1/sqrt(d) is caught. Confirm.
    sm_scale = t.d_qk ** -0.55

    q = kk.non_contiguousify(q)
    return TestcaseForDecode(t, q, attn_sink, sm_scale, kv_scope0, kv_scope1)
| |
|
| |
|
def run_flash_mla_sparse_fwd(p: TestParam, t: Testcase, return_p_sum: bool):
    """
    Run flash_mla_sparse_fwd on a prefill test case and return its result unchanged.

    `p` is not read here. `return_p_sum` must be False (asserted).
    """
    assert not return_p_sum
    kernel_kwargs = dict(
        sm_scale=t.sm_scale,
        attn_sink=t.attn_sink,
        topk_length=t.topk_length,
    )
    return flash_mla.flash_mla_sparse_fwd(t.q, t.kv, t.indices, **kernel_kwargs)
| |
|
def run_flash_mla_decode(p: TestParam, t: TestcaseForDecode, tile_scheduler_metadata, num_splits):
    """
    Run flash_mla_with_kvcache on a decode test case and return its result unchanged.

    Arguments are passed positionally in the kernel's expected order. The literal `None, None` and
    `False, True` arguments are forwarded verbatim; their meaning is defined by
    flash_mla_with_kvcache, which is not visible here. Extra-cache arguments are None when the test
    case has no extra KV scope.
    """
    assert p.decode is not None
    main_scope = t.kv_scope
    main_k_cache = main_scope.get_kvcache_for_flash_mla()

    extra_scope = t.extra_kv_scope
    if extra_scope is None:
        extra_k_cache = None
        extra_indices = None
        extra_topk_length = None
    else:
        extra_k_cache = extra_scope.get_kvcache_for_flash_mla()
        extra_indices = extra_scope.indices_in_kvcache
        extra_topk_length = extra_scope.topk_length

    return flash_mla.flash_mla_with_kvcache(
        t.q,
        main_k_cache,
        None, None, p.d_v,
        tile_scheduler_metadata, num_splits,
        t.sm_scale, False, True,
        main_scope.indices_in_kvcache,
        t.attn_sink,
        extra_k_cache,
        extra_indices,
        main_scope.topk_length,
        extra_topk_length,
    )
| |
|
| |
|
@dataclasses.dataclass
class FlopsAndMemVolStatistics:
    """FLOP count and memory volume of one sparse forward (prefill) call."""
    fwd_flop: float     # floating-point operations of the forward pass
    fwd_mem_vol: float  # memory traffic of the forward pass (presumably bytes)
| |
|
def count_flop_and_mem_vol(p: TestParam, t: Testcase) -> FlopsAndMemVolStatistics:
    """
    Estimate the FLOP count and memory volume of one sparse forward (prefill) call.

    FLOPs: 2 * h_q * (d_qk + d_v) for every attended (query, index) slot. The slot count is
    s_q * topk, or sum(topk_length) when per-query lengths are given.

    Memory volume (factor 2 per element, presumably bf16 bytes):
      - one d_qk-wide KV row for every valid index, i.e. within [0, s_kv) and within topk_length;
      - plus q (d_qk) and the output (d_v) for every query head.
    """
    if t.topk_length is None:
        total_topk = p.s_q * p.topk
    else:
        total_topk = t.topk_length.sum().item()

    in_bounds = (t.indices >= 0) & (t.indices < p.s_kv)
    if t.topk_length is not None:
        slot_ids = torch.arange(p.topk)[None, None, :].broadcast_to(p.s_q, p.h_kv, p.topk)
        in_bounds &= slot_ids < t.topk_length[:, None, None]
    num_valid_indices = in_bounds.sum().item()

    fwd_flop = 2 * total_topk * p.h_q * (p.d_qk + p.d_v)
    kv_bytes = num_valid_indices * p.d_qk * 2
    q_and_out_bytes = p.s_q * p.h_q * (p.d_qk + p.d_v) * 2
    return FlopsAndMemVolStatistics(fwd_flop=fwd_flop, fwd_mem_vol=kv_bytes + q_and_out_bytes)
| |
|
@dataclasses.dataclass
class FlopsAndMemVolStatisticsForDecode:
    """FLOP count and memory volume of one sparse decoding call."""
    flop: float     # floating-point operations
    mem_vol: float  # memory traffic (presumably bytes)
| |
|
def count_flop_and_mem_vol_for_decode(p: TestParam, t: TestcaseForDecode) -> FlopsAndMemVolStatisticsForDecode:
    """
    Estimate the FLOP count and memory volume of one sparse decoding call.

    FLOPs: 2 * h_q * (d_qk + d_v) per attended (batch, query, index) slot, summed over the main
    and the optional extra KV scope. Slots at or beyond a batch entry's topk_length are not attended.

    Memory volume has three parts:
      - q (d_qk) and the output (d_v), at 2 units per element;
      - kv_token_size per distinct KV-cache token referenced by a scope. A token is counted once
        even if several queries select it.
      - Negative indices mark invalid or disabled slots and are not counted as tokens.
    """
    assert p.decode
    b = p.decode.b

    def get_num_attended_tokens(kv_scope: KVScope) -> int:
        # Number of (batch, query, index) slots attended in this scope.
        topk = kv_scope.indices_in_kvcache.shape[-1]
        if kv_scope.topk_length is None:
            return b * p.s_q * topk
        # topk_length is per batch entry and applies to every query of that entry.
        return int(kv_scope.topk_length.sum().item()) * p.s_q

    def get_num_retrieved_tokens(kv_scope: KVScope) -> int:
        # Number of distinct KV-cache tokens that must be loaded for this scope.
        indices = kv_scope.indices_in_kvcache
        if kv_scope.topk_length is not None:
            batch, s_q, topk = indices.shape
            within_length = torch.arange(0, topk, device=indices.device).view(1, 1, topk).broadcast_to(batch, s_q, topk) < kv_scope.topk_length.view(batch, 1, 1)
            indices = indices[within_length]
        # Negative entries (the generator uses -1) are not tokens. Previously the -1 sentinel was
        # counted as one extra "retrieved" token whenever any slot was invalid or beyond topk_length.
        return int(indices[indices >= 0].unique().numel())

    num_attended_tokens = get_num_attended_tokens(t.kv_scope) + (get_num_attended_tokens(t.extra_kv_scope) if t.extra_kv_scope is not None else 0)
    num_retrieved_tokens = get_num_retrieved_tokens(t.kv_scope) + (get_num_retrieved_tokens(t.extra_kv_scope) if t.extra_kv_scope is not None else 0)

    compute_flop = 2 * p.h_q * num_attended_tokens * (p.d_qk + p.d_v)
    # Per-token size of the FP8 KV cache. Presumably the byte size of the layout chosen in
    # KVScope.quant_and_dequant_: V32_FP8Sparse for d_qk == 576, MODEL1_FP8Sparse for d_qk == 512.
    kv_token_size = 656 if p.d_qk == 576 else 576
    mem_vol = sum([
        2 * b * p.s_q * p.h_q * p.d_qk,         # q
        num_retrieved_tokens * kv_token_size,   # KV cache
        2 * b * p.s_q * p.h_q * p.d_v,          # output
    ])
    return FlopsAndMemVolStatisticsForDecode(
        compute_flop,
        mem_vol
    )
| |
|
def is_no_cooldown() -> bool:
    """Return True when NO_COOLDOWN is set to '1', 'yes' or 'y' (case-insensitive), else False."""
    flag = os.getenv('NO_COOLDOWN', '')
    return flag.lower() in ('1', 'yes', 'y')
| |
|