# flash-mla/tests/test_flash_mla_sparse_prefill.py
# Uploaded by medmekk using huggingface_hub (commit ccef021, verified).
import time
import sys
import torch
import kernelkit as kk
from lib import TestParam
import lib
import ref
# Module-level seed source: run_test() calls _counter.next() to assign a seed
# to any TestParam whose seed is still -1.
_counter = kk.Counter()
@torch.inference_mode()
def run_test(p: TestParam) -> bool:
    """Run one sparse-prefill test case and report whether it passed.

    The case is generated from ``p``, the forward kernel is launched once,
    and then, depending on ``p``:

    * ``p.num_runs > 0``: the launch is profiled with ``kk.bench_kineto`` and
      a one-line timing / throughput summary is printed.
    * ``p.check_correctness``: the kernel outputs are compared against
      ``ref.ref_sparse_attn_fwd``.

    Args:
        p: Test parameters. If ``p.seed`` is -1 it is overwritten in place
            with the next value of the module-level counter.

    Returns:
        The combined result of the three ``kk.check_is_allclose`` checks when
        ``p.check_correctness`` is set, otherwise ``True``.
    """
    # A seed of -1 means "not chosen yet": take the next value from the
    # module-level counter.
    if p.seed == -1:
        p.seed = _counter.next()

    print("================")
    print(f"Running on {p}")
    torch.cuda.empty_cache()

    testcase = lib.generate_testcase(p)
    torch.cuda.synchronize()

    def launch_kernel():
        return lib.run_flash_mla_sparse_fwd(p, testcase, False)

    # This first launch produces the outputs that are validated below.
    kernel_out, kernel_max_logits, kernel_lse = launch_kernel()
    torch.cuda.synchronize()

    if p.num_runs > 0:
        cost = lib.count_flop_and_mem_vol(p, testcase)
        profile = kk.bench_kineto(launch_kernel, num_tests=p.num_runs)
        # NOTE(review): the value is printed as microseconds after *1e6, so it
        # is presumably in seconds -- confirm against kernelkit.
        elapsed = profile.get_kernel_time("sparse_attn_fwd")
        tflops = cost.fwd_flop / elapsed / 1e12
        tbps = cost.fwd_mem_vol / elapsed / 1e12
        print(f"Prefill: {elapsed*1e6:4.0f} us, {tflops:6.1f} TFlops, {tbps:4.2f} TBps")

    if not p.check_correctness:
        return True

    torch.cuda.synchronize()
    _ref_out, ref_out_fp32, ref_max_logits, ref_lse = ref.ref_sparse_attn_fwd(p, testcase)
    # Turn -inf entries of the reference LSE into +inf before comparing.
    # NOTE(review): presumably this matches the kernel's output convention for
    # rows without any valid key -- confirm.
    ref_lse[ref_lse == float("-inf")] = float("+inf")
    torch.cuda.synchronize()

    # Evaluate all three checks unconditionally (no short-circuit) so each one
    # gets a chance to run and report.
    out_ok = kk.check_is_allclose(
        "out", kernel_out.float(), ref_out_fp32,
        abs_tol=8e-4, rel_tol=3.01/128, cos_diff_tol=7e-6,
    )
    max_logits_ok = kk.check_is_allclose(
        "max_logits", kernel_max_logits, ref_max_logits,
        abs_tol=1e-6, rel_tol=2.01/65536,
    )
    lse_ok = kk.check_is_allclose(
        "lse", kernel_lse, ref_lse,
        abs_tol=1e-6, rel_tol=2.01/65536,
    )
    return out_ok & max_logits_ok & lse_ok
if __name__ == '__main__':
    # Script entry point: configure torch defaults, build the list of test
    # cases, run each one through run_test(), and exit with status 1 if any
    # case failed.
    device = torch.device("cuda:0")
    torch.set_default_dtype(torch.bfloat16)
    torch.set_default_device(device)
    torch.cuda.set_device(device)
    torch.set_float32_matmul_precision('high')

    # Values swept by every correctness / corner-case group below.
    d_qk_options = [512, 576]
    h_q_options = [128, 64]

    correctness_cases = [
        TestParam(s_q, s_kv, topk, h_q=h_q, num_runs=0, d_qk=d_qk)
        for d_qk in d_qk_options
        for h_q in h_q_options
        for s_kv, topk in [
            # Regular shapes
            (128, 128),
            (256, 256),
            (512, 512),
            # Irregular shapes
            (592, 128),
            (1840, 256),
            (1592, 384),
            (1521, 512),
            # Irregular shapes with OOB TopK
            (95, 128),
            (153, 256),
            (114, 384),
        ]
        for s_q in [1, 62, 213]
    ]

    correctness_cases_with_features = [
        TestParam(
            s_q, s_kv, topk,
            h_q=h_q,
            num_runs=0,
            have_attn_sink=have_attn_sink,
            have_topk_length=have_topk_length,
            d_qk=d_qk,
        )
        for d_qk in d_qk_options
        for h_q in h_q_options
        for s_kv, topk in [
            (592, 128),
            (1840, 256),
            (1592, 384),
            (1521, 512),
            (95, 128),
            (153, 256),
            (114, 384),
        ]
        for s_q in [62, 213]
        # NOTE(review): have_sink_lse is never passed to TestParam, so this
        # loop only makes every combination appear twice. It is kept so the
        # number and order of cases stay unchanged; either forward it to
        # TestParam or drop the loop once the intent is confirmed.
        for have_sink_lse in [False, True]
        for have_attn_sink in [False, True]
        for have_topk_length in [False, True]
    ]

    corner_cases = [
        TestParam(
            s_q, s_kv, topk,
            h_q=h_q,
            is_all_indices_invalid=True,
            num_runs=0,
            have_attn_sink=True,
            have_topk_length=True,
            d_qk=d_qk,
        )
        for d_qk in d_qk_options
        for h_q in h_q_options
        for s_q, s_kv, topk in [
            (1, 128, 128),
            (1, 256, 256),
            (1234, 4321, 4096),
            (4096, 2048, 2048),
        ]
    ] + [
        # In these cases, some blocks may not have any valid topk indices
        TestParam(
            s_q, s_kv, topk,
            h_q=h_q,
            is_all_indices_invalid=False,
            num_runs=0,
            have_attn_sink=True,
            have_topk_length=True,
            d_qk=d_qk,
        )
        for d_qk in d_qk_options
        for h_q in h_q_options
        for s_kv, topk in [
            (32, 2048),
            (64, 8192),
        ]
        for s_q in [1, 1024]
    ] + [
        # In this testcase, s_q is really large, so we cannot put it on the
        # second dimension of grid shape
        TestParam(
            70000, 256, 256,
            h_q=h_q,
            check_correctness=False,
            num_runs=0,
            have_attn_sink=True,
            have_topk_length=True,
            d_qk=d_qk,
        )
        for d_qk in d_qk_options
        for h_q in h_q_options
    ]

    # Each template is (d_qk, h_q, topk, list of s_kv values to sweep).
    performance_case_templates = [
        # V3.2
        (576, 128, 2048, [8192, 32768, 65536, 98304, 131072]),
        # MODEL1 CONFIG1
        (512, 64, 512, [8192, 32768, 49152, 65536]),
        # MODEL1 CONFIG2
        (512, 128, 1024, [8192, 32768, 49152, 65536]),
    ]
    performance_cases = [
        TestParam(s_q, s_kv, topk, h_q=h_q, d_qk=d_qk, have_attn_sink=True)
        for (d_qk, h_q, topk, s_kv_list) in performance_case_templates
        for s_q in [4096]
        for s_kv in s_kv_list
    ]

    testcases = (
        correctness_cases
        + correctness_cases_with_features
        + corner_cases
        + performance_cases
    )

    is_no_cooldown = lib.is_no_cooldown()
    failed_cases = []
    for index, test in enumerate(testcases):
        # Pause briefly before every benchmarked case except the very first
        # test. The position check uses the index rather than comparing
        # TestParam objects, so it does not depend on TestParam.__eq__.
        if index > 0 and test.num_runs > 0 and not is_no_cooldown:
            time.sleep(0.3)
        if not run_test(test):
            failed_cases.append(test)

    if failed_cases:
        print(f"\033[31m\033[1m{len(failed_cases)} / {len(testcases)} cases failed:\033[0m")
        for case in failed_cases:
            print(f" {case}")
        sys.exit(1)
    else:
        print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m")