import dataclasses
import os
import enum
from typing import List, Optional
import random
import torch
import kernelkit as kk
# import flash_mla
from kernels import get_kernel, get_local_kernel
flash_mla = get_kernel("drbh/tmp-kernel-123")
import quant
class TestTarget(enum.Enum):
    """Which kernel entry point a testcase exercises."""
    FWD = 0     # sparse-attention forward (see run_flash_mla_sparse_fwd)
    DECODE = 1  # decoding with paged KV cache (see run_flash_mla_decode)
@dataclasses.dataclass
class ExtraTestParamForDecode:
    """Decode-only test parameters, embedded into TestParam.decode."""
    b: int  # batch size
    is_varlen: bool  # if True, per-batch KV lengths are drawn from N(s_kv, s_kv/2), floored at s_q
    have_zero_seqlen_k: bool  # if True, ~half of the batch entries get cache_seqlen == 0
    extra_s_k: Optional[int] = None  # KV length of the extra (second) KV scope; defaults to extra_topk*2 when extra_topk is set
    extra_topk: Optional[int] = None  # top-k of the extra KV scope; None disables the extra scope entirely
    block_size: int = 64  # paged-KV block size of the primary KV scope
    extra_block_size: Optional[int] = None  # block size of the extra KV scope; defaults to block_size
    have_extra_topk_length: bool = False  # whether the extra scope carries a per-batch topk_length tensor
@dataclasses.dataclass
class TestParam:
    """Parameters describing one sparse-attention testcase (prefill, or decode when `decode` is set)."""
    s_q: int  # number of query tokens
    s_kv: int  # number of KV tokens
    topk: int  # number of KV indices attended per query token
    h_q: int = 128  # number of query heads
    h_kv: int = 1  # number of KV heads
    d_qk: int = 512  # Q/K head dim
    d_v: int = 512  # V head dim
    seed: int = -1 # -1: to be filled automatically
    check_correctness: bool = True  # whether the driver should verify outputs (not consumed in this chunk)
    is_all_indices_invalid: bool = False # All indices are invalid, i.e., all indices are set to a large number (e.g., 2147483647)
    num_runs: int = 10  # number of runs (consumed by the driver outside this chunk)
    have_attn_sink: bool = False  # generate a per-head attention-sink tensor (some entries forced to +/-inf)
    have_topk_length: bool = False  # generate a per-query/per-batch valid-topk length tensor
    decode: Optional[ExtraTestParamForDecode] = None  # decode-only parameters; None for prefill tests
@dataclasses.dataclass
class RawTestParamForDecode:
    """
    "Flattened" test parameters for the decoding test.

    In our test script, to maintain compatibility with TestParam, we embed decode-only
    parameters into TestParam.decode, which is not very convenient when constructing
    testcases. So here we have a "flattened" version of the test parameters for the
    decoding test.
    """
    b: int  # batch size
    h_q: int  # number of query heads
    s_q: int  # number of query tokens per batch entry
    h_kv: int  # number of KV heads
    s_kv: int  # number of KV tokens
    is_varlen: bool  # randomize per-batch KV lengths
    topk: int  # KV indices attended per query token
    is_all_indices_invalid: bool = False
    have_zero_seqlen_k: bool = False
    have_topk_length: bool = False
    enable_attn_sink: bool = True  # maps onto TestParam.have_attn_sink
    extra_s_k: Optional[int] = None
    extra_topk: Optional[int] = None
    block_size: int = 64
    extra_block_size: Optional[int] = None
    have_extra_topk_length: bool = False
    d_qk: int = 576 # Q/K head dim (= dv + RoPE dim)
    d_v: int = 512 # V head dim
    check_correctness: bool = True
    num_runs: int = 10
    seed: int = -1

    def to_test_param(self) -> TestParam:
        """Convert to the nested TestParam / ExtraTestParamForDecode representation."""
        # Use keyword arguments throughout: the previous all-positional construction
        # (13 positional fields plus a positional ExtraTestParamForDecode) silently
        # mis-assigns fields if either dataclass is ever reordered or extended.
        return TestParam(
            s_q=self.s_q,
            s_kv=self.s_kv,
            topk=self.topk,
            h_q=self.h_q,
            h_kv=self.h_kv,
            d_qk=self.d_qk,
            d_v=self.d_v,
            seed=self.seed,
            check_correctness=self.check_correctness,
            is_all_indices_invalid=self.is_all_indices_invalid,
            num_runs=self.num_runs,
            have_attn_sink=self.enable_attn_sink,
            have_topk_length=self.have_topk_length,
            decode=ExtraTestParamForDecode(
                b=self.b,
                is_varlen=self.is_varlen,
                have_zero_seqlen_k=self.have_zero_seqlen_k,
                extra_s_k=self.extra_s_k,
                extra_topk=self.extra_topk,
                block_size=self.block_size,
                extra_block_size=self.extra_block_size,
                have_extra_topk_length=self.have_extra_topk_length,
            ),
        )
@dataclasses.dataclass
class Testcase:
    """A materialized prefill testcase produced by `generate_testcase`."""
    p: TestParam  # parameters this testcase was generated from
    # NOTE(review): dOut is presumably the upstream gradient for a backward test — confirm with the caller
    dOut: torch.Tensor # [s_q, h_q, d_v]
    q: torch.Tensor # [s_q, h_q, d_qk]
    kv: torch.Tensor # [s_kv, h_kv, d_qk]
    indices: torch.Tensor # [s_q, h_kv, topk]
    sm_scale: float  # softmax scale (fixed to 0.5 by generate_testcase)
    attn_sink: Optional[torch.Tensor] # [h_q]
    topk_length: Optional[torch.Tensor] # [s_q]
def _randperm_batch(batch_size: int, perm_range: torch.Tensor, perm_size: int, paddings: List[int]) -> torch.Tensor:
    """
    Generate random permutations in batch.

    The returned tensor, denoted as `res`, has a shape of [batch_size, perm_size].
    `0 <= res[i, :] < perm_range[i]` holds, and values within each row are unique.

    If, for some `i`, `perm_range[i] < perm_size` holds, then `res[i, :]` contains
    values in `[0, perm_range[i])` as many as possible, and the remaining positions
    are filled with values drawn from `paddings`.
    """
    assert not torch.are_deterministic_algorithms_enabled()
    torch.use_deterministic_algorithms(True)
    # try/finally so the global deterministic-algorithms flag is restored even if
    # generation raises; previously an exception here would leave the flag set and
    # trip the precondition assert on every subsequent call.
    try:
        perm_range_max = max(int(torch.max(perm_range).item()), perm_size)
        rand = torch.rand(batch_size, perm_range_max, dtype=torch.float32)
        rand[torch.arange(0, perm_range_max).broadcast_to(batch_size, perm_range_max) >= perm_range.view(batch_size, 1)] = float("-inf") # Fill invalid positions, so that the following `topk` operators will select positions within `perm_range` first
        res = rand.topk(perm_size, dim=-1, sorted=True).indices.to(torch.int32)
        out_of_range = res >= perm_range.view(batch_size, 1)
        if len(paddings) == 1:
            res[out_of_range] = paddings[0]
        else:
            # Replace each out-of-range slot with a padding value picked uniformly at random
            fillers = torch.tensor(paddings, dtype=torch.int32).index_select(0, torch.randint(0, len(paddings), (res.numel(), ), dtype=torch.int32))
            res.masked_scatter_(out_of_range, fillers)
    finally:
        torch.use_deterministic_algorithms(False)
    return res
def generate_testcase(t: TestParam) -> Testcase:
    """
    Materialize a prefill testcase (Q/KV/dOut tensors, sparse indices, optional
    attention sink and topk_length) from `t`, reproducibly seeded by `t.seed`.
    """
    kk.set_random_seed(t.seed)

    def _rand_activation(shape) -> torch.Tensor:
        # Small-magnitude bf16 values plus a random per-tensor offset, clamped to [-10, 10].
        x = torch.randn(shape, dtype=torch.bfloat16) / 10 + (random.random() - 0.5) / 10
        x.clamp_(-10, 10)
        return x

    q = _rand_activation((t.s_q, t.h_q, t.d_qk))
    kv = _rand_activation((t.s_kv, t.h_kv, t.d_qk))
    grad_out = _rand_activation((t.s_q, t.h_q, t.d_v))

    # Out-of-range values (negative, == s_kv, near INT32_MAX, ...) a correct kernel must ignore
    invalid_indices_candidate = [-2147483648, -123456, -1, t.s_kv, 114514, 1919810, 2147480000, 2147483647]
    per_row_range = torch.full((t.s_q, ), t.s_kv, dtype=torch.int32)
    indices = _randperm_batch(t.s_q, per_row_range, t.topk, invalid_indices_candidate)
    indices = indices.view(t.s_q, t.h_kv, t.topk)
    if t.is_all_indices_invalid:
        # Rows where a standard-normal draw is < -2 (~2% of rows) get every index
        # overwritten with one randomly chosen invalid value
        rows_to_invalidate = torch.randn(t.s_q, device='cpu') < -2
        indices[rows_to_invalidate[:, None, None].broadcast_to(indices.shape)] = random.choice(invalid_indices_candidate)
    indices = indices.to(q.device)

    attn_sink = None
    if t.have_attn_sink:
        attn_sink = torch.randn((t.h_q, ), dtype=torch.float32)
        extreme_mask = torch.randn((t.h_q, ), dtype=torch.float32)
        # Push some heads to +/- infinity to exercise extreme sink values
        attn_sink[extreme_mask < -0.5] = float("-inf")
        attn_sink[extreme_mask > +0.5] = float("+inf")

    topk_length = None
    if t.have_topk_length:
        upper_bound = max(t.topk + 1, 64)
        topk_length = torch.randint(0, upper_bound, (t.s_q, ), dtype=torch.int32, device=q.device).clamp_max(t.topk)

    q, kv, grad_out, indices = map(kk.non_contiguousify, (q, kv, grad_out, indices))
    return Testcase(
        p=t,
        dOut=grad_out,
        q=q,
        kv=kv,
        indices=indices,
        sm_scale=0.5, # Otherwise dK is too small compared to dV
        attn_sink=attn_sink,
        topk_length=topk_length
    )
@dataclasses.dataclass
class KVScope:
    # One paged KV cache plus the index/metadata tensors needed to attend over it.
    t: TestParam  # owning test parameters
    cache_seqlens: torch.Tensor  # [b] per-batch KV length
    block_table: torch.Tensor  # [b, num_blocks_per_batch] logical -> physical block ids
    blocked_k: torch.Tensor  # [num_blocks, block_size, h_kv, d_qk] paged K storage
    abs_indices: torch.Tensor  # [b, s_q, topk] absolute token indices (-1 = padding)
    indices_in_kvcache: torch.Tensor  # [b, s_q, topk] indices into the paged cache
    topk_length: Optional[torch.Tensor]  # [b] valid prefix length of each row's topk, or None
    blocked_k_quantized: Optional[torch.Tensor] = None  # set by quant_and_dequant_
    def quant_and_dequant_(self):
        """
        For FP8 cases, we need to quantize the KV cache for Flash MLA.
        Besides, the quantization error may be too large to be distinguished from wrong kernels, so we de-quantize kvcache here to mitigate quantization error
        """
        # Layout is selected purely by head dim: 576 -> V32, 512 -> MODEL1.
        fp8_kvcache_layout = None
        if self.t.d_qk == 576:
            fp8_kvcache_layout = quant.FP8KVCacheLayout.V32_FP8Sparse
        elif self.t.d_qk == 512:
            assert self.abs_indices is not None
            fp8_kvcache_layout = quant.FP8KVCacheLayout.MODEL1_FP8Sparse
        else:
            assert False
        self.blocked_k_quantized = quant.quantize_k_cache(self.blocked_k, fp8_kvcache_layout)
        # Keep the round-tripped (quantize -> dequantize) cache as the reference KV
        blocked_k_dequantized = quant.dequantize_k_cache(self.blocked_k_quantized, fp8_kvcache_layout)
        self.blocked_k = blocked_k_dequantized
    def get_kvcache_for_flash_mla(self) -> torch.Tensor:
        """
        Return the quantized blocked_k for Flash MLA
        """
        assert self.blocked_k_quantized is not None, "Please call `quant_and_dequant_` first before calling `get_kvcache_for_flash_mla`"
        return self.blocked_k_quantized
    def apply_perm(self, perm: torch.Tensor) -> "KVScope":
        """
        Apply a batch permutation to this KVScope. Used for batch-invariance test
        """
        # blocked_k / blocked_k_quantized are shared block storage addressed through
        # block_table, so only the per-batch tensors need permuting.
        new_kvscope = KVScope(
            self.t,
            self.cache_seqlens[perm],
            self.block_table[perm],
            self.blocked_k,
            self.abs_indices[perm],
            self.indices_in_kvcache[perm],
            self.topk_length[perm] if self.topk_length is not None else None,
            self.blocked_k_quantized
        )
        return new_kvscope
@dataclasses.dataclass
class TestcaseForDecode:
    """A materialized decoding testcase produced by `generate_testcase_for_decode`."""
    p: TestParam  # parameters this testcase was generated from
    q: torch.Tensor # [b, s_q, h_q, d_qk]
    attn_sink: Optional[torch.Tensor] # [h_q]
    sm_scale: float  # softmax scale
    kv_scope: KVScope  # primary paged KV cache + indices
    extra_kv_scope: Optional[KVScope]  # second KV scope; present only when extra_topk is configured
def generate_testcase_for_decode(t: TestParam) -> TestcaseForDecode:
    """
    Materialize a decoding testcase from `t`: a query tensor, an optional per-head
    attention sink, and one or two paged KV scopes (the second one only when
    `t.decode.extra_topk` is set).
    """
    kk.set_random_seed(t.seed)
    assert t.h_q % t.h_kv == 0
    assert t.decode is not None
    q = torch.randn((t.decode.b, t.s_q, t.h_q, t.d_qk))
    q.clamp_(min=-1.0, max=1.0)
    attn_sink = None
    if t.have_attn_sink:
        # Random per-head sink values; heads whose second draw is beyond +/-0.5
        # are forced to +/-inf to exercise extreme sink handling
        attn_sink = torch.randn((t.h_q, ), dtype=torch.float32)
        inf_mask = torch.randn((t.h_q, ), dtype=torch.float32)
        attn_sink[inf_mask > 0.5] = float("inf")
        attn_sink[inf_mask < -0.5] = float("-inf")
    def generate_one_k_scope(s_k: int, block_size: int, topk: int, is_varlen: bool, have_zero_seqlen: bool, is_all_indices_invalid: bool, have_topk_length: bool) -> KVScope:
        # Build one paged KV cache (shuffled block table + blocked K storage) and
        # the per-(batch, query) indices into it.
        b = t.decode.b # type: ignore
        cache_seqlens_cpu = torch.full((b,), s_k, dtype=torch.int32, device='cpu')
        if is_varlen:
            # Per-batch KV lengths drawn from N(s_k, s_k/2), floored at s_q
            for i in range(b):
                cache_seqlens_cpu[i] = max(random.normalvariate(s_k, s_k / 2), t.s_q)
        if have_zero_seqlen:
            # Roughly half of the batch entries get an empty KV cache
            zeros_mask = torch.randn(b, dtype=torch.float32, device='cpu') > 0
            cache_seqlens_cpu[zeros_mask] = 0
        # Pad the max seqlen up to a multiple of 4*block_size tokens
        max_seqlen_alignment = 4 * block_size
        max_seqlen_pad = max(kk.cdiv(int(cache_seqlens_cpu.max().item()), max_seqlen_alignment), 1) * max_seqlen_alignment
        cache_seqlens = cache_seqlens_cpu.cuda()
        assert max_seqlen_pad % block_size == 0
        # Shuffle block ids so physical block order differs from logical order
        block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
        block_table = block_table.view(-1)[torch.randperm(block_table.numel())].view(b, -1)
        blocked_k = kk.gen_non_contiguous_randn_tensor((block_table.numel(), block_size, t.h_kv, t.d_qk)) / 10
        blocked_k.clamp_(min=-1.0, max=1.0)
        abs_indices = torch.empty((b, t.s_q, topk), dtype=torch.int32)
        if is_all_indices_invalid:
            abs_indices.fill_(-1)
        else:
            # Unique in-range indices per (batch, query) row; -1 pads rows whose seqlen < topk
            abs_indices[:] = _randperm_batch(b*t.s_q, cache_seqlens.repeat_interleave(t.s_q), topk, [-1]).view(b, t.s_q, topk)
        indices_in_kvcache = quant.abs_indices2indices_in_kvcache(abs_indices, block_table, block_size)
        topk_length = torch.randint(0, topk+1, (b, ), dtype=torch.int32, device=q.device) if have_topk_length else None
        # Mask nonused KV as NaN
        if have_topk_length:
            # Positions beyond each batch's topk_length are treated as unused
            indices_in_kvcache_masked = indices_in_kvcache.clone()
            indices_in_kvcache_masked[torch.arange(0, topk).view(1, 1, topk).broadcast_to(b, t.s_q, topk) >= (topk_length.view(b, 1, 1) if have_topk_length else topk)] = -1
        else:
            indices_in_kvcache_masked = indices_in_kvcache
        # Presumably so any read of a token no query attends to surfaces as NaN in
        # the kernel's output — confirm against the correctness checker.
        # NOTE(review): -1 entries in indices_in_kvcache_masked index the LAST slot
        # of nonused_indices_mask below, leaving that slot un-NaN'ed; benign (a real
        # value never causes a spurious failure) but possibly unintended.
        blocked_k = blocked_k.view(-1, t.h_kv, t.d_qk)
        nonused_indices_mask = torch.ones(blocked_k.size(0)*blocked_k.size(1), dtype=torch.bool, device='cpu')
        nonused_indices_mask[indices_in_kvcache_masked] = False
        blocked_k[nonused_indices_mask, :, :] = float("nan")
        blocked_k = blocked_k.view(-1, block_size, t.h_kv, t.d_qk)
        block_table = kk.non_contiguousify(block_table)
        abs_indices = kk.non_contiguousify(abs_indices)
        indices_in_kvcache = kk.non_contiguousify(indices_in_kvcache)
        return KVScope(t, cache_seqlens, block_table, blocked_k, abs_indices, indices_in_kvcache, topk_length)
    kv_scope0 = generate_one_k_scope(t.s_kv, t.decode.block_size, t.topk, t.decode.is_varlen, t.decode.have_zero_seqlen_k, t.is_all_indices_invalid, t.have_topk_length)
    kv_scope0.quant_and_dequant_()
    if t.decode.extra_topk is not None:
        # Fill in defaults for the extra scope's parameters.
        # NOTE(review): this mutates the caller's `t.decode` in place — confirm callers tolerate that.
        if t.decode.extra_s_k is None:
            t.decode.extra_s_k = t.decode.extra_topk*2
        if t.decode.extra_block_size is None:
            t.decode.extra_block_size = t.decode.block_size
        kv_scope1 = generate_one_k_scope(t.decode.extra_s_k, t.decode.extra_block_size, t.decode.extra_topk, t.decode.is_varlen, t.decode.have_zero_seqlen_k, t.is_all_indices_invalid, t.decode.have_extra_topk_length)
        kv_scope1.quant_and_dequant_()
    else:
        assert t.decode.extra_block_size is None and t.decode.extra_s_k is None and not t.decode.have_extra_topk_length
        kv_scope1 = None
    # NOTE(review): exponent is -0.55, not the conventional -0.5 — presumably intentional; confirm
    sm_scale = t.d_qk ** -0.55
    q = kk.non_contiguousify(q)
    return TestcaseForDecode(t, q, attn_sink, sm_scale, kv_scope0, kv_scope1)
def run_flash_mla_sparse_fwd(p: TestParam, t: Testcase, return_p_sum: bool):
    """Invoke the sparse-attention forward kernel on testcase `t`."""
    # This wrapper does not support returning the probability sum
    assert not return_p_sum
    optional_args = {
        "sm_scale": t.sm_scale,
        "attn_sink": t.attn_sink,
        "topk_length": t.topk_length,
    }
    return flash_mla.flash_mla_sparse_fwd(t.q, t.kv, t.indices, **optional_args)
def run_flash_mla_decode(p: TestParam, t: TestcaseForDecode, tile_scheduler_metadata, num_splits):
    """Invoke the decoding kernel with paged KV cache(s) on testcase `t`."""
    assert p.decode is not None
    # NOTE(review): the call is mostly positional; the labels below are inferred from
    # this test code — confirm against flash_mla_with_kvcache's actual signature.
    return flash_mla.flash_mla_with_kvcache(
        t.q,
        t.kv_scope.get_kvcache_for_flash_mla(),  # quantized paged KV cache
        None, None, p.d_v,
        tile_scheduler_metadata, num_splits,
        t.sm_scale, False, True,
        t.kv_scope.indices_in_kvcache,
        t.attn_sink,
        t.extra_kv_scope.get_kvcache_for_flash_mla() if t.extra_kv_scope is not None else None,
        t.extra_kv_scope.indices_in_kvcache if t.extra_kv_scope is not None else None,
        t.kv_scope.topk_length,
        t.extra_kv_scope.topk_length if t.extra_kv_scope is not None and t.extra_kv_scope.topk_length is not None else None
    )
@dataclasses.dataclass
class FlopsAndMemVolStatistics:
    """
    FLOPs and memory volume statistics for prefilling
    """
    fwd_flop: float  # total floating-point operations of the forward pass
    fwd_mem_vol: float  # estimated bytes moved by the forward pass (assumes 2-byte elements)
def count_flop_and_mem_vol(p: TestParam, t: Testcase) -> FlopsAndMemVolStatistics:
    """Estimate forward FLOPs and memory traffic for a prefill testcase."""
    # Number of (query, key) pairs actually attended
    if t.topk_length is None:
        total_topk = p.s_q * p.topk
    else:
        total_topk = t.topk_length.sum().item()
    # An index contributes to memory traffic only if it is in [0, s_kv)
    # and within its row's topk_length prefix
    valid = (t.indices >= 0) & (t.indices < p.s_kv)
    if t.topk_length is not None:
        position = torch.arange(p.topk)[None, None, :].broadcast_to(p.s_q, p.h_kv, p.topk)
        valid &= position < t.topk_length[:, None, None]
    num_valid_indices = valid.sum().item()
    fwd_flop = 2 * total_topk * p.h_q * (p.d_qk + p.d_v)
    # 2 bytes/element: K reads for valid indices, plus Q reads and O writes
    fwd_mem_vol = num_valid_indices * p.d_qk * 2 + p.s_q * p.h_q * (p.d_qk + p.d_v) * 2
    return FlopsAndMemVolStatistics(fwd_flop, fwd_mem_vol)
@dataclasses.dataclass
class FlopsAndMemVolStatisticsForDecode:
    """
    FLOPs and memory volume statistics for decoding
    """
    flop: float  # total attention FLOPs
    mem_vol: float  # estimated bytes moved (Q reads + unique K reads + O writes)
def count_flop_and_mem_vol_for_decode(p: TestParam, t: TestcaseForDecode) -> FlopsAndMemVolStatisticsForDecode:
    """Estimate decode FLOPs and memory traffic, summed over both KV scopes."""
    assert p.decode
    b = p.decode.b

    def attended_token_count(scope: KVScope) -> int:
        # Count (query token, attended KV token) pairs, honoring topk_length if present
        k = scope.indices_in_kvcache.shape[-1]
        if scope.topk_length is None:
            return b * p.s_q * k
        return int(scope.topk_length.sum().item()) * p.s_q

    def retrieved_token_count(scope: KVScope) -> int:
        # Count unique KV slots touched (duplicates across queries read once)
        idx = scope.indices_in_kvcache
        if scope.topk_length is not None:
            idx = idx.clone()
            bb, sq, k = idx.shape
            beyond = torch.arange(0, k, device=idx.device).view(1, 1, k).broadcast_to(bb, sq, k) >= scope.topk_length.view(bb, 1, 1)
            idx[beyond] = -1
        return idx.unique().numel()  # type: ignore

    scopes = [t.kv_scope] + ([t.extra_kv_scope] if t.extra_kv_scope is not None else [])
    num_attended_tokens = sum(attended_token_count(s) for s in scopes)
    num_retrieved_tokens = sum(retrieved_token_count(s) for s in scopes)
    compute_flop = 2 * p.h_q * num_attended_tokens * (p.d_qk + p.d_v)
    kv_token_size = 656 if p.d_qk == 576 else 576 # Assume FP8 KV Cache
    mem_vol = sum([
        2 * b * p.s_q * p.h_q * p.d_qk, # Q
        num_retrieved_tokens * kv_token_size, # K
        2 * b * p.s_q * p.h_q * p.d_v, # O
    ])
    return FlopsAndMemVolStatisticsForDecode(compute_flop, mem_vol)
def is_no_cooldown() -> bool:
    """Whether benchmark cooldown is disabled via the NO_COOLDOWN env var."""
    flag = os.environ.get('NO_COOLDOWN', '')
    return flag.lower() in ('1', 'yes', 'y')
|