# flash-mla / tests / test_flash_mla_sparse_decoding.py
# Uploaded via huggingface_hub by medmekk (commit ccef021, verified)
# /// script
# dependencies = [
# "numpy",
# "torch",
# "kernels",
# "triton",
# "rich",
# ]
# ///
import time
import dataclasses
from typing import Tuple, List, Dict, Optional
import copy
import rich.console
import rich.table
import torch
import kernelkit as kk
# import flash_mla
from kernels import get_kernel, get_local_kernel
flash_mla = get_kernel("drbh/tmp-kernel-123")
import lib
from lib import TestParam
from lib import RawTestParamForDecode as RawTestParam
import ref
"""
Generate testcase for unit test
"""
def gen_testcase() -> List[RawTestParam]:
    """Build the full suite of decode testcases.

    Returns, in order:
      1. correctness sweeps over (d_qk, extra-K, topk-length, h_q) feature combos,
      2. corner cases (all-invalid indices / zero-length K / attn-sink),
      3. performance cases (production configs + peak-perf configs).
    """
    correctness_cases = []
    corner_cases = []
    # The extra-K path (extra_s_k / extra_topk) only exists for d_qk == 512;
    # for d_qk == 576 the inner feature toggles collapse to [False].
    for d_qk in [576, 512]:
        for have_extra_k in ([False, True] if d_qk == 512 else [False]):
            for have_extra_topk_len in ([False, True] if have_extra_k else [False]):
                for have_topk_len in ([False, True] if d_qk == 512 else [False]):
                    for h_q in [64, 128]:
                        # Correctness-only cases: check_correctness=True, num_runs=0.
                        cur_correctness_cases = [
                            RawTestParam(b, h_q, s_q, 1, s_k, is_varlen, topk,
                                         have_topk_length=have_topk_len,
                                         enable_attn_sink=True,
                                         extra_s_k=extra_s_k,
                                         extra_topk=extra_topk,
                                         block_size=block_size,
                                         extra_block_size=extra_block_size,
                                         have_extra_topk_length=have_extra_topk_len,
                                         d_qk=d_qk,
                                         check_correctness=True,
                                         num_runs=0)
                            for (s_k, topk, block_size) in [
                                (512, 64, 2),
                                (512, 64, 64),
                                (512, 64, 69),
                                (1024, 576, 2),
                                (1024, 576, 61),
                                (2046, 2048, 2),
                                (2046, 2048, 64),
                                (2046, 2048, 576)
                            ]
                            # Extra-K shape sweep mirrors the primary sweep; a
                            # (None, None, None) placeholder disables the path.
                            for (extra_s_k, extra_topk, extra_block_size) in ([
                                (512, 64, 2),
                                (512, 64, 64),
                                (512, 64, 69),
                                (1024, 576, 2),
                                (1024, 576, 61),
                                (2046, 2048, 2),
                                (2046, 2048, 64),
                                (2046, 2048, 576)
                            ] if have_extra_k else [(None, None, None)])
                            for b in [4, 74, 321]
                            for s_q in [1, 3]
                            # Non-varlen only sampled at b == 74 and only when
                            # neither topk-length feature is active.
                            for is_varlen in ([True, False] if (b == 74 and not have_topk_len and not have_extra_topk_len) else [True])
                        ]
                        correctness_cases.extend(cur_correctness_cases)
                        # Corner cases: at least one of the three degenerate
                        # flags must be set (see the trailing `if`).
                        cur_corner_cases = [
                            RawTestParam(b, h_q, s_q, 1, s_k, is_varlen, topk,
                                         is_all_indices_invalid=is_all_indices_invalid,
                                         have_zero_seqlen_k=have_zero_seqlen_k,
                                         have_topk_length=have_topk_len,
                                         enable_attn_sink=enable_attn_sink,
                                         extra_s_k=extra_s_k,
                                         extra_topk=extra_topk,
                                         block_size=block_size,
                                         extra_block_size=extra_block_size,
                                         have_extra_topk_length=have_extra_topk_len,
                                         d_qk=d_qk,
                                         check_correctness=True,
                                         num_runs=0,
                                         )
                            for (s_k, topk, block_size) in [
                                (512, 64, 61),
                                (650, 576, 53),
                            ]
                            for (extra_s_k, extra_topk, extra_block_size) in ([
                                (512, 64, 61),
                                (650, 576, 53),
                            ] if have_extra_k else [(None, None, None)])
                            for b in [4, 74, 321]
                            for s_q in [3]
                            for is_varlen in ([True, False] if (b == 74 and not have_topk_len and not have_extra_topk_len) else [True])
                            for is_all_indices_invalid in [True, False]
                            for have_zero_seqlen_k in [True, False]
                            for enable_attn_sink in [True, False]
                            if (is_all_indices_invalid or have_zero_seqlen_k or enable_attn_sink)
                        ]
                        corner_cases.extend(cur_corner_cases)
    # Production configs: each base is replicated across a list of batch sizes
    # (b=0 in the base is a placeholder overwritten by dataclasses.replace).
    base_and_bszs = [
        # V3.2
        (RawTestParam(0, 128, 2, 1, 32768, True, topk=2048, d_qk=576), [2, 64, 74, 128]),
        # MODEL1 CONFIG1
        (RawTestParam(0, 64, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=512, block_size=256, extra_block_size=64), [2, 64, 74, 128, 74*2, 256]),
        # MODEL1 CONFIG2
        (RawTestParam(0, 128, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=64), [2, 64, 74, 128, 74*2, 256]),
        # MODEL1 CONFIG3
        (RawTestParam(0, 64, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=2, have_extra_topk_length=True), [2, 64, 74, 128, 74*2, 256]),
        # MODEL1 CONFIG4
        (RawTestParam(0, 128, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=2, have_extra_topk_length=True), [2, 64, 74, 128, 74*2, 256]),
    ]
    performance_cases = [
        # Production cases
        dataclasses.replace(base, b=b)
        for base, bszs in base_and_bszs
        for b in bszs
    ] + [
        # Peak perf cases
        RawTestParam(74*2, h_q, 2, 1, 32768, True, topk=16384, d_qk=d_qk)
        for h_q in [64, 128]
        for d_qk in [512, 576]
    ]
    return correctness_cases + corner_cases + performance_cases
@dataclasses.dataclass
class Result:
    """Outcome of one testcase run.

    The performance fields are all 0.0 when the testcase only checked
    correctness (num_runs == 0).
    """
    is_correct: bool                 # True when correctness passed (or was skipped)
    compute_memory_ratio: float      # theoretical FLOP / memory-byte ratio
    time_usage_per_us: float         # end-to-end time per run, microseconds
    splitkv_time_usage_us: float     # split-KV kernel time, microseconds (0.0 if absent)
    combine_time_usage_us: float     # combine kernel time, microseconds (0.0 if absent)
    achieved_tflops: float           # measured TFLOP/s over the e2e time
    achieved_gBps: float             # measured GB/s over the e2e time
# Module-level counter used to derive a fresh per-testcase seed when p.seed == -1.
_counter = kk.Counter()
@torch.inference_mode()
def test_flash_mla(p: TestParam) -> Result:
    """Run a single decode testcase.

    Depending on the flags in ``p`` this (1) runs the kernel once and compares
    output/LSE against the reference implementation (``p.check_correctness``),
    and/or (2) benchmarks the kernel for ``p.num_runs`` iterations.

    Returns a Result; the performance fields are all zero when
    ``p.num_runs == 0``, and ``is_correct`` is True when the correctness
    check is skipped.
    """
    if p.seed == -1:
        # Derive a fresh, deterministic seed per testcase.
        global _counter
        p.seed = _counter.next()
    assert p.decode
    print("================")
    print(f"Running on {p}")
    torch.cuda.empty_cache()
    t = lib.generate_testcase_for_decode(p)
    # NOTE(review): called with no arguments — confirm this kernel build's
    # get_mla_metadata() really takes none (upstream FlashMLA's version
    # requires cache-seqlen/head arguments).
    tile_scheduler_metadata, _ = flash_mla.get_mla_metadata()
    def run_decode():
        return lib.run_flash_mla_decode(p, t, tile_scheduler_metadata, None)
    # We first run the kernel once to generate output data for the correctness test.
    # We must do this first, otherwise when allocating tensors for storing answers,
    # it may re-use memory that contains the correct answer, leading to false positives.
    if p.check_correctness:
        torch.cuda.synchronize()
        out_ans, lse_ans = run_decode()
        torch.cuda.synchronize()
    # torch.set_printoptions(profile='full')
    # print(tile_scheduler_metadata.tile_scheduler_metadata[:, :7])
    # We run the performance test before generating the answer for the
    # correctness test to avoid interference.
    # (Fix: the original built this all-zero Result twice — once here and once
    # in a redundant `if p.num_runs == 0` branch.)
    performance_result = Result(True, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
    if p.num_runs != 0:
        result = kk.bench_kineto(run_decode, p.num_runs)
        splitkv_kernel_name = "flash_fwd_splitkv_mla_fp8_sparse_kernel"
        combine_kernel_name = "flash_fwd_mla_combine_kernel"
        # Get individual kernel time usages (None = kernel did not appear in trace)
        kernel_time_usages_us: Dict[str, Optional[float]] = {}
        def pick_kernel_time_usage(kernel_name: str):
            # (Fix: renamed the match list from `t`, which shadowed the
            # testcase data `t` in the enclosing scope.)
            matches = [kernel_name in s for s in result.get_kernel_names()]
            if any(matches):
                # Substring must identify exactly one kernel in the trace.
                assert sum(matches) == 1
                kernel_time_usages_us[kernel_name] = result.get_kernel_time(kernel_name) * 1e6
            else:
                kernel_time_usages_us[kernel_name] = None
        pick_kernel_time_usage(splitkv_kernel_name)
        pick_kernel_time_usage(combine_kernel_name)
        # Get E2E time usages
        def have_kernel(name: str):
            return kernel_time_usages_us[name] is not None
        if kk.is_using_profiling_tools():
            # Timings are meaningless under an external profiler; use a dummy.
            e2e_time_usage_us = 1e6
        else:
            assert have_kernel(splitkv_kernel_name)
            if have_kernel(combine_kernel_name):
                e2e_time_usage_us = result.get_e2e_time(splitkv_kernel_name, combine_kernel_name) * 1e6
            else:
                # No combine kernel launched: splitkv time is the e2e time.
                e2e_time_usage_us = kernel_time_usages_us[splitkv_kernel_name]
        assert e2e_time_usage_us is not None
        flops_and_mem_vol = lib.count_flop_and_mem_vol_for_decode(p, t)
        e2e_time_usage_s = e2e_time_usage_us / 1e6
        # (Fix: typo `theoritical` -> `theoretical`; local name only.)
        theoretical_compute_memory_ratio = flops_and_mem_vol.flop / flops_and_mem_vol.mem_vol
        achieved_tflops = flops_and_mem_vol.flop / e2e_time_usage_s / 1e12
        achieved_gBps = flops_and_mem_vol.mem_vol / e2e_time_usage_s / 1e9
        def print_kernel_time_usage(name: str, short_name: str):
            if kernel_time_usages_us[name] is not None:
                print(f'{short_name} time: {kernel_time_usages_us[name]:.1f} us')
        print(f'Compute/Memory: {theoretical_compute_memory_ratio:.2f}')
        print(f'Time (per): {e2e_time_usage_us:.1f} us')
        print_kernel_time_usage(splitkv_kernel_name, "Splitkv")
        print_kernel_time_usage(combine_kernel_name, "Combine")
        print(f'TFlops: {achieved_tflops:.1f}')
        print(f'GB/s: {achieved_gBps:.0f}')
        performance_result = Result(True, theoretical_compute_memory_ratio, e2e_time_usage_us, kernel_time_usages_us[splitkv_kernel_name] or 0.0, kernel_time_usages_us[combine_kernel_name] or 0.0, achieved_tflops, achieved_gBps)
    is_correct = True
    if p.check_correctness:
        torch.cuda.synchronize()
        with torch.profiler.record_function("reference_flash_mla"):
            out_ref, lse_ref = ref.ref_sparse_attn_decode(p, t)
        is_out_correct = kk.check_is_allclose("out", out_ans, out_ref, abs_tol=1e-3, rel_tol=2.01/128, cos_diff_tol=5e-6)
        is_lse_correct = kk.check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01/65536)
        is_correct &= is_out_correct and is_lse_correct
    performance_result.is_correct = is_correct
    return performance_result
def main():
    """Run every generated testcase, print a rich summary table and a TFlops
    geomean, and exit non-zero if any correctness case failed."""
    import sys
    dtype = torch.bfloat16
    device = torch.device("cuda:0")
    torch.set_default_dtype(dtype)
    torch.set_default_device(device)
    torch.cuda.set_device(device)
    torch.set_float32_matmul_precision('high')
    torch.set_num_threads(32)
    raw_testcases = gen_testcase()
    testcases = [t.to_test_param() for t in raw_testcases]
    print(f"{kk.colors['CYAN_BG']}{len(testcases)} testcases to run{kk.colors['CLEAR']}")
    is_no_cooldown = lib.is_no_cooldown()
    num_testcases_len = len(str(len(testcases)))
    failed_cases = []
    results: List[Tuple[TestParam, Result]] = []
    for testcase_idx, testcase in enumerate(testcases):
        # (Fix: skip the cooldown by position; the original compared
        # `testcase != testcases[0]`, which also skipped the cooldown for any
        # later duplicate of the first testcase.)
        if testcase_idx > 0 and testcase.num_runs > 0 and not is_no_cooldown:
            time.sleep(0.3)  # Cooldown
        print(f"[{testcase_idx+1:{num_testcases_len}d}/{len(testcases)}, {testcase_idx/len(testcases)*100:3.0f}%] ", end='')
        result = test_flash_mla(testcase)
        results.append((testcase, result))
        if not result.is_correct:
            # (Fix: the original called sys.exit(1) here, aborting on the first
            # failure — which made the summary table and the failed-case report
            # below unreachable. Record the failure and keep going; the
            # non-zero exit happens after the report.)
            failed_cases.append(testcase)
    console = rich.console.Console(width=120)
    table = rich.table.Table(show_header=True, header_style="bold cyan")
    table.add_column("topk")
    table.add_column("Bsz")
    table.add_column("h_q&k")
    table.add_column("sq")
    table.add_column("sk")
    table.add_column("d_qk")
    table.add_column("Feats")
    table.add_column("C/M")
    table.add_column("TFlops")
    table.add_column("GBps")
    table.add_column("us")
    table.add_column(" ")
    for testcase, result in results:
        assert testcase.decode
        topk_str = f"{testcase.topk}" if testcase.decode.extra_topk is None else f"{testcase.topk}+{testcase.decode.extra_topk}"
        table.add_row(
            topk_str,
            str(testcase.decode.b),
            f"{testcase.h_q:3d} {testcase.h_kv}",
            str(testcase.s_q),
            str(testcase.s_kv),
            str(testcase.d_qk),
            # Feature flags: V=varlen, L=topk-length, E=extra-topk-length
            " V"[testcase.decode.is_varlen] + " L"[testcase.have_topk_length] + " E"[testcase.decode.have_extra_topk_length],
            f"{result.compute_memory_ratio:3.0f}",
            f"{result.achieved_tflops:3.0f}",
            f"{result.achieved_gBps:4.0f}",
            f"{result.time_usage_per_us:4.1f}",
            "" if result.is_correct else "X"
        )
    console.print(table)
    def geomean(l) -> float:
        import numpy
        return numpy.exp(numpy.mean(numpy.log(l)))
    num_correct_testcases = [result.is_correct for t, result in results if t.check_correctness].count(True)
    num_correctness_cases = sum([1 for t in testcases if t.check_correctness])
    if num_correct_testcases == num_correctness_cases:
        print(f"{kk.colors['GREEN_BG']}{num_correct_testcases}/{num_correctness_cases} correctness cases passed{kk.colors['CLEAR']}")
    else:
        print(f"{kk.colors['RED_BG']}{num_correct_testcases}/{num_correctness_cases} correctness cases passed{kk.colors['CLEAR']}")
        for t in failed_cases:
            print(f"\t{t},")
    valid_achieved_tflops = [result.achieved_tflops for _, result in results if result.achieved_tflops > 0.1]
    if len(valid_achieved_tflops) > 0:
        achieved_tflops_geomean = geomean(valid_achieved_tflops)  # > 0.1 to prune out correctness cases
        print(f"TFlops geomean: {achieved_tflops_geomean:.1f}")
    if failed_cases:
        sys.exit(1)
if __name__ == "__main__":
main()