| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import time |
| | import dataclasses |
| | from typing import Tuple, List, Dict, Optional |
| | import copy |
| |
|
| | import rich.console |
| | import rich.table |
| |
|
| | import torch |
| | import kernelkit as kk |
| |
|
| | |
| | from kernels import get_kernel, get_local_kernel |
| | flash_mla = get_kernel("drbh/tmp-kernel-123") |
| |
|
| | import lib |
| | from lib import TestParam |
| | from lib import RawTestParamForDecode as RawTestParam |
| | import ref |
| |
|
| | """ |
| | Generate testcase for unit test |
| | """ |
| |
|
def gen_testcase() -> List[RawTestParam]:
    """Build the full list of decode testcases.

    Three families are produced and concatenated in this order:
      1. correctness cases  -- broad sweep over shapes/features, checked against the reference;
      2. corner cases       -- invalid-index / zero-seqlen / attn-sink combinations;
      3. performance cases  -- large shapes, benchmarked (num_runs defaults from RawTestParam).

    Returns:
        correctness_cases + corner_cases + performance_cases, as RawTestParam objects.
    """
    correctness_cases = []
    corner_cases = []
    # Feature sweep. The "extra" K/topk path and per-request topk lengths are only
    # exercised for d_qk == 512 (presumably the only head dim supporting them --
    # TODO confirm against the kernel).
    for d_qk in [576, 512]:
        for have_extra_k in ([False, True] if d_qk == 512 else [False]):
            # extra_topk lengths only make sense when the extra-K path is on.
            for have_extra_topk_len in ([False, True] if have_extra_k else [False]):
                for have_topk_len in ([False, True] if d_qk == 512 else [False]):
                    for h_q in [64, 128]:
                        # Positional args appear to be (b, h_q, s_q, h_kv, s_k, is_varlen,
                        # topk) -- verify against RawTestParam's definition.
                        cur_correctness_cases = [
                            RawTestParam(b, h_q, s_q, 1, s_k, is_varlen, topk,
                                         have_topk_length=have_topk_len,
                                         enable_attn_sink=True,
                                         extra_s_k=extra_s_k,
                                         extra_topk=extra_topk,
                                         block_size=block_size,
                                         extra_block_size=extra_block_size,
                                         have_extra_topk_length=have_extra_topk_len,
                                         d_qk=d_qk,
                                         check_correctness=True,
                                         num_runs=0)
                            # (s_k, topk, block_size) triples mix aligned and
                            # deliberately odd block sizes (69, 61, ...).
                            for (s_k, topk, block_size) in [
                                (512, 64, 2),
                                (512, 64, 64),
                                (512, 64, 69),
                                (1024, 576, 2),
                                (1024, 576, 61),
                                (2046, 2048, 2),
                                (2046, 2048, 64),
                                (2046, 2048, 576)
                            ]
                            # Same triples for the extra-K path; (None, None, None)
                            # disables it when have_extra_k is False.
                            for (extra_s_k, extra_topk, extra_block_size) in ([
                                (512, 64, 2),
                                (512, 64, 64),
                                (512, 64, 69),
                                (1024, 576, 2),
                                (1024, 576, 61),
                                (2046, 2048, 2),
                                (2046, 2048, 64),
                                (2046, 2048, 576)
                            ] if have_extra_k else [(None, None, None)])
                            for b in [4, 74, 321]
                            for s_q in [1, 3]
                            # Non-varlen is only sampled at b == 74 and only when
                            # neither topk-length feature is on, to bound case count.
                            for is_varlen in ([True, False] if (b == 74 and not have_topk_len and not have_extra_topk_len) else [True])
                        ]
                        correctness_cases.extend(cur_correctness_cases)

                        cur_corner_cases = [
                            RawTestParam(b, h_q, s_q, 1, s_k, is_varlen, topk,
                                         is_all_indices_invalid=is_all_indices_invalid,
                                         have_zero_seqlen_k=have_zero_seqlen_k,
                                         have_topk_length=have_topk_len,
                                         enable_attn_sink=enable_attn_sink,
                                         extra_s_k=extra_s_k,
                                         extra_topk=extra_topk,
                                         block_size=block_size,
                                         extra_block_size=extra_block_size,
                                         have_extra_topk_length=have_extra_topk_len,
                                         d_qk=d_qk,
                                         check_correctness=True,
                                         num_runs=0,
                                         )
                            for (s_k, topk, block_size) in [
                                (512, 64, 61),
                                (650, 576, 53),
                            ]
                            for (extra_s_k, extra_topk, extra_block_size) in ([
                                (512, 64, 61),
                                (650, 576, 53),
                            ] if have_extra_k else [(None, None, None)])
                            for b in [4, 74, 321]
                            for s_q in [3]
                            for is_varlen in ([True, False] if (b == 74 and not have_topk_len and not have_extra_topk_len) else [True])
                            for is_all_indices_invalid in [True, False]
                            for have_zero_seqlen_k in [True, False]
                            for enable_attn_sink in [True, False]
                            # Keep only combinations exercising at least one corner feature.
                            if (is_all_indices_invalid or have_zero_seqlen_k or enable_attn_sink)
                        ]
                        corner_cases.extend(cur_corner_cases)

    # Performance templates: b=0 is a placeholder, replaced per batch size below.
    base_and_bszs = [

        (RawTestParam(0, 128, 2, 1, 32768, True, topk=2048, d_qk=576), [2, 64, 74, 128]),

        (RawTestParam(0, 64, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=512, block_size=256, extra_block_size=64), [2, 64, 74, 128, 74*2, 256]),

        (RawTestParam(0, 128, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=64), [2, 64, 74, 128, 74*2, 256]),

        (RawTestParam(0, 64, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=2, have_extra_topk_length=True), [2, 64, 74, 128, 74*2, 256]),

        (RawTestParam(0, 128, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=2, have_extra_topk_length=True), [2, 64, 74, 128, 74*2, 256]),
    ]
    performance_cases = [
        # Expand each template across its batch sizes.
        dataclasses.replace(base, b=b)
        for base, bszs in base_and_bszs
        for b in bszs
    ] + [
        # Additional very-large-topk benchmarks.
        RawTestParam(74*2, h_q, 2, 1, 32768, True, topk=16384, d_qk=d_qk)
        for h_q in [64, 128]
        for d_qk in [512, 576]
    ]

    return correctness_cases + corner_cases + performance_cases
| |
|
| |
|
@dataclasses.dataclass
class Result:
    """Outcome of a single test_flash_mla run: a correctness flag plus
    performance metrics (all-zero when the case was correctness-only,
    i.e. num_runs == 0).

    Note: instances are constructed positionally elsewhere in this file,
    so the field order below is part of the interface.
    """
    is_correct: bool                 # reference comparison passed (True when not checked)
    compute_memory_ratio: float      # theoretical FLOP / byte ratio of the workload
    time_usage_per_us: float         # end-to-end time per run, microseconds
    splitkv_time_usage_us: float     # split-KV kernel time, microseconds (0.0 if absent)
    combine_time_usage_us: float     # combine kernel time, microseconds (0.0 if absent)
    achieved_tflops: float           # measured TFLOP/s
    achieved_gBps: float             # measured memory bandwidth, GB/s
| |
|
# Module-level counter used to hand out a fresh seed (via _counter.next()) to
# each TestParam whose seed is -1 -- presumably so reruns stay deterministic
# per-case; see test_flash_mla.
_counter = kk.Counter()
| |
|
@torch.inference_mode()
def test_flash_mla(p: TestParam) -> Result:
    """Run one FlashMLA decode testcase: optionally check correctness against
    the reference implementation, and optionally benchmark the kernel.

    Args:
        p: decode test parameters (p.decode must be set). If p.seed is -1 a
           fresh deterministic seed is taken from the module-level counter.

    Returns:
        A Result; performance fields are all zero when p.num_runs == 0.
    """
    if p.seed == -1:
        global _counter
        p.seed = _counter.next()
    assert p.decode

    print("================")
    print(f"Running on {p}")
    torch.cuda.empty_cache()

    t = lib.generate_testcase_for_decode(p)

    tile_scheduler_metadata, _ = flash_mla.get_mla_metadata()

    def run_decode():
        return lib.run_flash_mla_decode(p, t, tile_scheduler_metadata, None)

    if p.check_correctness:
        # One eager run (fully synchronized) to capture the answer tensors
        # before any benchmarking perturbs device state.
        torch.cuda.synchronize()
        out_ans, lse_ans = run_decode()
        torch.cuda.synchronize()

    # NOTE(fix): the original code assigned a placeholder Result and then
    # immediately re-assigned it in both branches below; the dead initial
    # assignment has been removed.
    if p.num_runs == 0:
        # Correctness-only case: no timing, report zeroed performance metrics.
        performance_result = Result(True, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
    else:
        result = kk.bench_kineto(run_decode, p.num_runs)

        splitkv_kernel_name = "flash_fwd_splitkv_mla_fp8_sparse_kernel"
        combine_kernel_name = "flash_fwd_mla_combine_kernel"

        # Per-kernel time in microseconds, or None when the kernel did not run
        # (e.g. no combine pass was needed).
        kernel_time_usages_us: Dict[str, Optional[float]] = {}

        def pick_kernel_time_usage(kernel_name: str):
            # NOTE(fix): renamed the local from `t` to `matches` -- it shadowed
            # the outer testcase variable `t`.
            matches = [kernel_name in s for s in result.get_kernel_names()]
            if any(matches):
                # The substring must identify exactly one profiled kernel.
                assert sum(matches) == 1
                kernel_time_usages_us[kernel_name] = result.get_kernel_time(kernel_name) * 1e6
            else:
                kernel_time_usages_us[kernel_name] = None

        pick_kernel_time_usage(splitkv_kernel_name)
        pick_kernel_time_usage(combine_kernel_name)

        def have_kernel(name: str):
            return kernel_time_usages_us[name] is not None

        if kk.is_using_profiling_tools():
            # Timing is meaningless under an external profiler; use a dummy 1s.
            e2e_time_usage_us = 1e6
        else:
            assert have_kernel(splitkv_kernel_name)
            if have_kernel(combine_kernel_name):
                e2e_time_usage_us = result.get_e2e_time(splitkv_kernel_name, combine_kernel_name) * 1e6
            else:
                # No combine kernel ran: end-to-end time is just the splitkv kernel.
                e2e_time_usage_us = kernel_time_usages_us[splitkv_kernel_name]
        assert e2e_time_usage_us is not None

        flops_and_mem_vol = lib.count_flop_and_mem_vol_for_decode(p, t)

        e2e_time_usage_s = e2e_time_usage_us / 1e6
        theoretical_compute_memory_ratio = flops_and_mem_vol.flop / flops_and_mem_vol.mem_vol
        achieved_tflops = flops_and_mem_vol.flop / e2e_time_usage_s / 1e12
        achieved_gBps = flops_and_mem_vol.mem_vol / e2e_time_usage_s / 1e9

        def print_kernel_time_usage(name: str, short_name: str):
            if kernel_time_usages_us[name] is not None:
                print(f'{short_name} time: {kernel_time_usages_us[name]:.1f} us')

        print(f'Compute/Memory: {theoretical_compute_memory_ratio:.2f}')
        print(f'Time (per): {e2e_time_usage_us:.1f} us')
        print_kernel_time_usage(splitkv_kernel_name, "Splitkv")
        print_kernel_time_usage(combine_kernel_name, "Combine")
        print(f'TFlops: {achieved_tflops:.1f}')
        print(f'GB/s: {achieved_gBps:.0f}')

        # `or 0.0` maps a missing (None) kernel time to 0.0 for the report.
        performance_result = Result(True, theoretical_compute_memory_ratio, e2e_time_usage_us, kernel_time_usages_us[splitkv_kernel_name] or 0.0, kernel_time_usages_us[combine_kernel_name] or 0.0, achieved_tflops, achieved_gBps)

    is_correct = True
    if p.check_correctness:
        torch.cuda.synchronize()
        with torch.profiler.record_function("reference_flash_mla"):
            out_ref, lse_ref = ref.ref_sparse_attn_decode(p, t)

        is_out_correct = kk.check_is_allclose("out", out_ans, out_ref, abs_tol=1e-3, rel_tol=2.01/128, cos_diff_tol=5e-6)
        is_lse_correct = kk.check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01/65536)
        is_correct &= is_out_correct and is_lse_correct

    performance_result.is_correct = is_correct
    return performance_result
| |
|
| |
|
def main():
    """Run every generated testcase, print a rich summary table, and exit
    non-zero if any correctness case failed."""
    import sys

    dtype = torch.bfloat16
    device = torch.device("cuda:0")
    torch.set_default_dtype(dtype)
    torch.set_default_device(device)
    torch.cuda.set_device(device)
    torch.set_float32_matmul_precision('high')
    torch.set_num_threads(32)

    raw_testcases = gen_testcase()
    testcases = [t.to_test_param() for t in raw_testcases]

    print(f"{kk.colors['CYAN_BG']}{len(testcases)} testcases to run{kk.colors['CLEAR']}")

    is_no_cooldown = lib.is_no_cooldown()
    num_testcases_len = len(str(len(testcases)))
    failed_cases = []
    results: List[Tuple[TestParam, Result]] = []
    for testcase_idx, testcase in enumerate(testcases):
        # Let the GPU cool down between benchmarked runs.
        # NOTE(fix): was `testcase != testcases[0]`, a value comparison that
        # would also skip cooldown for any later case equal to the first;
        # "not the first iteration" is expressed by index instead.
        if testcase_idx > 0 and testcase.num_runs > 0 and not is_no_cooldown:
            time.sleep(0.3)
        print(f"[{testcase_idx+1:{num_testcases_len}d}/{len(testcases)}, {testcase_idx/len(testcases)*100:3.0f}%] ", end='')
        result = test_flash_mla(testcase)
        results.append((testcase, result))
        if not result.is_correct:
            # NOTE(fix): the original exited immediately here, which made the
            # summary table and the failed-case listing below unreachable
            # whenever anything failed. Failures are now collected and the
            # process exits non-zero after the full report.
            failed_cases.append(testcase)

    console = rich.console.Console(width=120)
    table = rich.table.Table(show_header=True, header_style="bold cyan")
    table.add_column("topk")
    table.add_column("Bsz")
    table.add_column("h_q&k")
    table.add_column("sq")
    table.add_column("sk")
    table.add_column("d_qk")
    table.add_column("Feats")
    table.add_column("C/M")
    table.add_column("TFlops")
    table.add_column("GBps")
    table.add_column("us")
    table.add_column(" ")

    for testcase, result in results:
        assert testcase.decode
        topk_str = f"{testcase.topk}" if testcase.decode.extra_topk is None else f"{testcase.topk}+{testcase.decode.extra_topk}"
        table.add_row(
            topk_str,
            str(testcase.decode.b),
            f"{testcase.h_q:3d} {testcase.h_kv}",
            str(testcase.s_q),
            str(testcase.s_kv),
            str(testcase.d_qk),
            # Feature flags: indexing a two-char string by a bool yields
            # ' ' (off) or the letter (on) for Varlen / topk-Length / Extra-length.
            " V"[testcase.decode.is_varlen] + " L"[testcase.have_topk_length] + " E"[testcase.decode.have_extra_topk_length],
            f"{result.compute_memory_ratio:3.0f}",
            f"{result.achieved_tflops:3.0f}",
            f"{result.achieved_gBps:4.0f}",
            f"{result.time_usage_per_us:4.1f}",
            "" if result.is_correct else "X"
        )
    console.print(table)

    def geomean(l) -> float:
        import numpy
        return numpy.exp(numpy.mean(numpy.log(l)))

    num_correct_testcases = [result.is_correct for t, result in results if t.check_correctness].count(True)
    num_correctness_cases = sum([1 for t in testcases if t.check_correctness])
    if num_correct_testcases == num_correctness_cases:
        print(f"{kk.colors['GREEN_BG']}{num_correct_testcases}/{num_correctness_cases} correctness cases passed{kk.colors['CLEAR']}")
    else:
        print(f"{kk.colors['RED_BG']}{num_correct_testcases}/{num_correctness_cases} correctness cases passed{kk.colors['CLEAR']}")
        for t in failed_cases:
            print(f"\t{t},")

    # Geomean over benchmarked cases only (correctness-only cases report 0 TFlops).
    valid_achieved_tflops = [result.achieved_tflops for _, result in results if result.achieved_tflops > 0.1]
    if len(valid_achieved_tflops) > 0:
        achieved_tflops_geomean = geomean(valid_achieved_tflops)
        print(f"TFlops geomean: {achieved_tflops_geomean:.1f}")

    if failed_cases:
        sys.exit(1)
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|