| | import time |
| | import sys |
| |
|
| | import torch |
| | import kernelkit as kk |
| |
|
| | from lib import TestParam |
| | import lib |
| | import ref |
| |
|
# Module-level seed source: run_test draws from it for cases whose seed is -1.
_counter = kk.Counter()


@torch.inference_mode()
def run_test(p: TestParam) -> bool:
    """Run one sparse-attention forward test case and report whether it passed.

    Steps, in order:
      1. If ``p.seed`` is -1, replace it in place with the next value from
         the module-level counter (the caller's ``p`` is mutated).
      2. Generate the test case and run the forward pass once.
      3. If ``p.num_runs > 0``, benchmark the forward pass and print time,
         TFlops and TBps.
      4. If ``p.check_correctness`` is set, compare the outputs against the
         reference implementation.

    Args:
        p: Parameters describing the test case.

    Returns:
        ``True`` when correctness checking is disabled or when all three
        comparisons (``out``, ``max_logits``, ``lse``) pass; ``False``
        otherwise.
    """
    # Only read/call the counter here, so no `global` declaration is needed.
    if p.seed == -1:
        p.seed = _counter.next()

    print("================")
    print(f"Running on {p}")
    torch.cuda.empty_cache()

    testcase = lib.generate_testcase(p)
    torch.cuda.synchronize()

    def run_prefill():
        return lib.run_flash_mla_sparse_fwd(p, testcase, False)

    # First run produces the tensors that are validated further down.
    kernel_out, kernel_max_logits, kernel_lse = run_prefill()
    torch.cuda.synchronize()

    if p.num_runs > 0:
        cost = lib.count_flop_and_mem_vol(p, testcase)
        profile = kk.bench_kineto(run_prefill, num_tests=p.num_runs)
        prefill_ans_time = profile.get_kernel_time("sparse_attn_fwd")
        prefill_flops = cost.fwd_flop / prefill_ans_time / 1e12
        prefill_mem_bw = cost.fwd_mem_vol / prefill_ans_time / 1e12
        print(f"Prefill: {prefill_ans_time*1e6:4.0f} us, {prefill_flops:6.1f} TFlops, {prefill_mem_bw:4.2f} TBps")

    # Nothing to validate: the case counts as passed.
    if not p.check_correctness:
        return True

    torch.cuda.synchronize()
    _ref_out, ref_out_fp32, ref_max_logits, ref_lse = ref.ref_sparse_attn_fwd(p, testcase)
    # Rewrite -inf entries of the reference LSE to +inf before comparing.
    # NOTE(review): presumably this matches the kernel's convention for rows
    # without any valid key - confirm against the kernel.
    minus_inf_mask = ref_lse == float("-inf")
    ref_lse[minus_inf_mask] = float("+inf")
    torch.cuda.synchronize()

    comparisons = (
        ("out", kernel_out.float(), ref_out_fp32,
         {"abs_tol": 8e-4, "rel_tol": 3.01/128, "cos_diff_tol": 7e-6}),
        ("max_logits", kernel_max_logits, ref_max_logits,
         {"abs_tol": 1e-6, "rel_tol": 2.01/65536}),
        ("lse", kernel_lse, ref_lse,
         {"abs_tol": 1e-6, "rel_tol": 2.01/65536}),
    )

    is_correct = True
    for label, actual, expected, tolerances in comparisons:
        # `&=` rather than `and`: every comparison still runs after a failure.
        is_correct &= kk.check_is_allclose(label, actual, expected, **tolerances)
    return is_correct
| |
|
| |
|
if __name__ == '__main__':
    # Script entry point: build the list of test cases, run each one through
    # run_test(), then print a summary and exit with status 1 if any failed.

    # Everything runs on the first CUDA device with bfloat16 as default dtype.
    device = torch.device("cuda:0")
    torch.set_default_dtype(torch.bfloat16)
    torch.set_default_device(device)
    torch.cuda.set_device(device)
    torch.set_float32_matmul_precision('high')

    # Every correctness / corner-case group sweeps the same d_qk and h_q values.
    d_qk_values = [512, 576]
    h_q_values = [128, 64]

    # Plain correctness cases (no optional features, no benchmarking).
    correctness_cases = [
        TestParam(s_q, s_kv, topk, h_q=h_q, num_runs=0, d_qk=d_qk)
        for d_qk in d_qk_values
        for h_q in h_q_values
        for s_kv, topk in [
            # s_kv == topk
            (128, 128),
            (256, 256),
            (512, 512),
            # s_kv > topk
            (592, 128),
            (1840, 256),
            (1592, 384),
            (1521, 512),
            # s_kv < topk
            (95, 128),
            (153, 256),
            (114, 384),
        ]
        for s_q in [1, 62, 213]
    ]

    # Correctness cases crossed with the optional features.
    # The original comprehension also iterated `have_sink_lse in [False, True]`
    # without ever passing it to TestParam, which only produced every case
    # twice; that dead loop has been removed.
    # NOTE(review): if TestParam is meant to take a `have_sink_lse` flag,
    # restore the loop and pass the flag through instead.
    correctness_cases_with_features = [
        TestParam(
            s_q, s_kv, topk,
            h_q=h_q,
            num_runs=0,
            have_attn_sink=have_attn_sink,
            have_topk_length=have_topk_length,
            d_qk=d_qk,
        )
        for d_qk in d_qk_values
        for h_q in h_q_values
        for s_kv, topk in [
            # s_kv > topk
            (592, 128),
            (1840, 256),
            (1592, 384),
            (1521, 512),
            # s_kv < topk
            (95, 128),
            (153, 256),
            (114, 384),
        ]
        for s_q in [62, 213]
        for have_attn_sink in [False, True]
        for have_topk_length in [False, True]
    ]

    # Corner case 1: every index is invalid.
    all_invalid_cases = [
        TestParam(
            s_q, s_kv, topk,
            h_q=h_q,
            is_all_indices_invalid=True,
            num_runs=0,
            have_attn_sink=True,
            have_topk_length=True,
            d_qk=d_qk,
        )
        for d_qk in d_qk_values
        for h_q in h_q_values
        for s_q, s_kv, topk in [
            (1, 128, 128),
            (1, 256, 256),
            (1234, 4321, 4096),
            (4096, 2048, 2048),
        ]
    ]

    # Corner case 2: topk far larger than s_kv.
    tiny_kv_cases = [
        TestParam(
            s_q, s_kv, topk,
            h_q=h_q,
            is_all_indices_invalid=False,
            num_runs=0,
            have_attn_sink=True,
            have_topk_length=True,
            d_qk=d_qk,
        )
        for d_qk in d_qk_values
        for h_q in h_q_values
        for s_kv, topk in [
            (32, 2048),
            (64, 8192),
        ]
        for s_q in [1, 1024]
    ]

    # Corner case 3: very large s_q; executed without the reference comparison.
    huge_s_q_cases = [
        TestParam(
            70000, 256, 256,
            h_q=h_q,
            check_correctness=False,
            num_runs=0,
            have_attn_sink=True,
            have_topk_length=True,
            d_qk=d_qk,
        )
        for d_qk in d_qk_values
        for h_q in h_q_values
    ]

    corner_cases = all_invalid_cases + tiny_kv_cases + huge_s_q_cases

    # Each template is (d_qk, h_q, topk, list of s_kv values to sweep).
    performance_case_templates = [
        (576, 128, 2048, [8192, 32768, 65536, 98304, 131072]),
        (512, 64, 512, [8192, 32768, 49152, 65536]),
        (512, 128, 1024, [8192, 32768, 49152, 65536]),
    ]

    # num_runs is not passed here, so TestParam's own default applies.
    performance_cases = [
        TestParam(s_q, s_kv, topk, h_q=h_q, d_qk=d_qk, have_attn_sink=True)
        for (d_qk, h_q, topk, s_kv_list) in performance_case_templates
        for s_q in [4096]
        for s_kv in s_kv_list
    ]

    testcases = correctness_cases + correctness_cases_with_features + corner_cases + performance_cases

    is_no_cooldown = lib.is_no_cooldown()
    failed_cases = []
    for index, test in enumerate(testcases):
        # Pause before every benchmarked case except the very first test.
        # The position is checked by index rather than by comparing TestParam
        # objects, so the result does not depend on TestParam's __eq__.
        if index > 0 and test.num_runs > 0 and not is_no_cooldown:
            time.sleep(0.3)
        if not run_test(test):
            failed_cases.append(test)

    if failed_cases:
        print(f"\033[31m\033[1m{len(failed_cases)} / {len(testcases)} cases failed:\033[0m")
        for case in failed_cases:
            print(f" {case}")
        # Non-zero exit status signals the failure to the calling process.
        sys.exit(1)
    else:
        print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m")
| |
|
| |
|