| |
| """Benchmark flashrt-qkv-cache-rope against a PyTorch eager postprocess chain.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import ctypes |
| import ctypes.util |
| import importlib |
| import json |
| import math |
| import os |
| import sys |
| from dataclasses import asdict, dataclass |
| from pathlib import Path |
|
|
| import torch |
|
|
|
|
# Directory two levels above the folder that contains this script.
ROOT = Path(__file__).resolve().parents[2]
# Kernel package whose sources are compiled by the "source" backend.
PACKAGE = ROOT.joinpath("flashrt-qkv-cache-rope")
# Include directory passed to the extension build; it lives in a sibling
# checkout next to ROOT.
REGISTRATION_INCLUDE = ROOT.parent.joinpath(
    "kernels",
    "kernel-builder",
    "src",
    "pyproject",
    "templates",
    "torch",
)
|
|
# Named benchmark shapes, each a (batch, seq_len, heads, head_dim) tuple.
SHAPES = dict(
    small=(1, 64, 8, 128),
    wan_1k=(1, 1024, 24, 128),
    wan_2520=(1, 2520, 24, 128),
    wan_4096=(1, 4096, 24, 128),
    vl_512=(1, 512, 16, 128),
)
# Selectable groups for the --shapes option; values are keys of SHAPES.
SHAPE_GROUPS = dict(
    smoke=["small"],
    headline=["wan_1k", "wan_2520", "vl_512"],
    all=list(SHAPES),
)
|
|
|
|
@dataclass
class Result:
    """One benchmark row: problem size, timings, accuracy metrics and verdict.

    Field order is significant: it is the positional constructor order and the
    key order produced by ``dataclasses.asdict`` for the JSON report.
    """

    # Label of the benchmarked case.
    shape: str

    # Problem dimensions.
    batch: int
    seq_len: int
    heads: int
    head_dim: int

    # Mean time per call in microseconds, and the eager/fused ratio.
    flashrt_us: float
    torch_eager_us: float
    speedup_vs_eager: float

    # 99th-percentile absolute error against the eager reference.
    q_p99_abs: float
    k_p99_abs: float

    # Cosine similarity against the eager reference.
    q_cosine: float
    k_cosine: float

    # "PASS" or "FAIL".
    status: str
|
|
|
|
class SourceOps:
    """Convenience wrappers over the operators of one ``torch.ops`` namespace.

    Every method forwards to the operator of the same name, coercing scalar
    arguments with ``int``/``float`` and allocating any output tensor the
    caller did not pass. The output tensors are returned after the call; the
    operators are presumably expected to fill them in place (not visible
    from this file).
    """

    def __init__(self, namespace: str) -> None:
        # Look the namespace up once; all wrappers dispatch through it.
        self._ops = getattr(torch.ops, namespace)

    def decode_q_norm_rope_stage_bf16(self, q_pre, q_w, cos, sin, eps=1e-6, q_out=None):
        """Run the decode Q op and return the output tensor.

        ``q_out`` defaults to a fresh tensor shaped like ``q_pre``.
        """
        out = torch.empty_like(q_pre) if q_out is None else q_out
        self._ops.decode_q_norm_rope_stage_bf16(q_pre, q_w, cos, sin, float(eps), out)
        return out

    def decode_k_norm_rope_kvwrite_bf16(self, k_pre, v_pre, k_w, cos, sin, eps=1e-6, k_out=None, v_out=None):
        """Run the decode K/V-write op and return ``(k_out, v_out)``.

        Missing outputs are allocated to match ``k_pre`` and ``v_pre``.
        """
        k_dst = torch.empty_like(k_pre) if k_out is None else k_out
        v_dst = torch.empty_like(v_pre) if v_out is None else v_out
        self._ops.decode_k_norm_rope_kvwrite_bf16(k_pre, v_pre, k_w, cos, sin, float(eps), k_dst, v_dst)
        return k_dst, v_dst

    def decode_k_norm_rope_kvwrite_devpos_bf16(self, k_pre, v_pre, k_w, cos, sin, cur_pos, k_cache, v_cache, eps=1e-6):
        """Run the device-position decode K/V-write op; return the caches.

        Note the operator takes ``eps`` before the caches, unlike this
        wrapper's own signature.
        """
        self._ops.decode_k_norm_rope_kvwrite_devpos_bf16(
            k_pre, v_pre, k_w, cos, sin, cur_pos, float(eps), k_cache, v_cache
        )
        return k_cache, v_cache

    def qkv_split_norm_rope_bf16(
        self, packed, q_w, k_w, freqs_re, freqs_im, heads, head_dim, rope_seq_len=None, eps=1e-6, q_out=None, k_out=None
    ):
        """Run the packed-QKV split/norm/RoPE op and return ``(q_out, k_out)``.

        ``rope_seq_len`` defaults to ``packed.shape[1]``. A missing ``q_out``
        is allocated as bf16 ``(packed.shape[0], packed.shape[1], heads,
        head_dim)``; a missing ``k_out`` mirrors ``q_out``.
        """
        rows, tokens = packed.shape[0], packed.shape[1]
        if rope_seq_len is None:
            rope_seq_len = tokens
        if q_out is None:
            q_out = torch.empty((rows, tokens, heads, head_dim), device=packed.device, dtype=torch.bfloat16)
        if k_out is None:
            k_out = torch.empty_like(q_out)
        self._ops.qkv_split_norm_rope_bf16(
            packed,
            q_w,
            k_w,
            freqs_re,
            freqs_im,
            int(heads),
            int(head_dim),
            int(rope_seq_len),
            float(eps),
            q_out,
            k_out,
        )
        return q_out, k_out

    def qkv_split_joint3_cat_bf16(
        self,
        packed_v,
        qkv_v_bias,
        norm_v_q_weight,
        norm_v_k_weight,
        freqs_re,
        freqs_im,
        packed_a,
        norm_a_q_weight,
        norm_a_k_weight,
        packed_u,
        norm_u_q_weight,
        norm_u_k_weight,
        heads,
        head_dim,
        q_cat_out,
        k_cat_out,
        v_cat_out,
        rope_seq_len=None,
        eps_v=1e-6,
        eps_a=1e-6,
        eps_u=1e-6,
    ):
        """Run the three-stream joint op and return the three output tensors.

        The outputs must be supplied by the caller. ``rope_seq_len`` defaults
        to ``packed_v.shape[1]``.
        """
        rope_len = packed_v.shape[1] if rope_seq_len is None else rope_seq_len
        # Tensor inputs first, then the scalar configuration, then the outputs.
        tensor_inputs = (
            packed_v,
            qkv_v_bias,
            norm_v_q_weight,
            norm_v_k_weight,
            freqs_re,
            freqs_im,
            packed_a,
            norm_a_q_weight,
            norm_a_k_weight,
            packed_u,
            norm_u_q_weight,
            norm_u_k_weight,
        )
        scalars = (int(heads), int(head_dim), int(rope_len), float(eps_v), float(eps_a), float(eps_u))
        self._ops.qkv_split_joint3_cat_bf16(*tensor_inputs, *scalars, q_cat_out, k_cat_out, v_cat_out)
        return q_cat_out, k_cat_out, v_cat_out

    def qkv_split_rope_kvcache_bf16(
        self,
        packed_qkv,
        rope,
        q_heads,
        kv_heads,
        head_dim,
        cache_offset,
        q_out=None,
        k_cache=None,
        v_cache=None,
        max_seq_len=None,
    ):
        """Run the split/RoPE/KV-cache op and return ``(q_out, k_cache, v_cache)``.

        ``packed_qkv`` must be 3-D. Missing tensors are allocated as bf16 on
        the device of ``packed_qkv``. Missing caches get ``max_seq_len`` rows
        along dim 1, or ``cache_offset + seq_len`` when ``max_seq_len`` is
        not given.
        """
        batch, seq_len, _ = packed_qkv.shape
        device = packed_qkv.device
        if q_out is None:
            q_out = torch.empty((batch, seq_len, q_heads, head_dim), device=device, dtype=torch.bfloat16)
        if k_cache is None or v_cache is None:
            capacity = cache_offset + seq_len if max_seq_len is None else max_seq_len
            cache_shape = (batch, capacity, kv_heads, head_dim)
            if k_cache is None:
                k_cache = torch.empty(cache_shape, device=device, dtype=torch.bfloat16)
            if v_cache is None:
                v_cache = torch.empty(cache_shape, device=device, dtype=torch.bfloat16)
        self._ops.qkv_split_rope_kvcache_bf16(
            packed_qkv,
            rope,
            int(q_heads),
            int(kv_heads),
            int(head_dim),
            int(cache_offset),
            q_out,
            k_cache,
            v_cache,
        )
        return q_out, k_cache, v_cache
|
|
|
|
| def _preload_cublaslt() -> None: |
| for parent in Path(torch.__file__).resolve().parents: |
| candidate = parent / "nvidia" / "cublas" / "lib" / "libcublasLt.so.12" |
| if candidate.exists(): |
| ctypes.CDLL(str(candidate), mode=ctypes.RTLD_GLOBAL) |
| return |
| library = ctypes.util.find_library("cublasLt") |
| if library: |
| ctypes.CDLL(library, mode=ctypes.RTLD_GLOBAL) |
|
|
|
|
def _current_arch_list() -> str:
    """Return the compute capability of CUDA device 0 as ``"major.minor"``."""
    capability = torch.cuda.get_device_capability(0)
    return "{}.{}".format(*capability)
|
|
|
|
def load_source_ops() -> SourceOps:
    """JIT-build the extension from the package sources and wrap its ops.

    Raises:
        RuntimeError: if the registration include directory is missing.
    """
    from torch.utils.cpp_extension import load

    if not REGISTRATION_INCLUDE.is_dir():
        raise RuntimeError(f"missing kernel-builder registration include: {REGISTRATION_INCLUDE}")
    _preload_cublaslt()
    # Target only the local GPU unless the caller already chose architectures.
    os.environ.setdefault("TORCH_CUDA_ARCH_LIST", _current_arch_list())

    namespace = "flashrt_qkv_cache_rope_benchmark"
    source_files = [
        PACKAGE / "torch-ext" / "torch_binding.cpp",
        PACKAGE / "csrc" / "qkv_cache_rope.cu",
    ]
    include_dirs = [PACKAGE / "csrc", REGISTRATION_INCLUDE]
    load(
        name=namespace,
        sources=[str(path) for path in source_files],
        extra_include_paths=[str(path) for path in include_dirs],
        extra_cflags=["-O3", "-DCUDA_KERNEL"],
        extra_cuda_cflags=["-O3", "--expt-relaxed-constexpr", "-DCUDA_KERNEL"],
        verbose=False,
    )
    # The extension name doubles as the torch.ops namespace looked up here.
    return SourceOps(namespace)
|
|
|
|
def load_installed_ops(artifact: str | None):
    """Import and return the ``flashrt_qkv_cache_rope`` module.

    When ``artifact`` is truthy it is put at the front of ``sys.path`` for
    the duration of the import and removed afterwards, whether or not the
    import succeeds.
    """
    module_name = "flashrt_qkv_cache_rope"
    if not artifact:
        return importlib.import_module(module_name)
    sys.path.insert(0, artifact)
    try:
        return importlib.import_module(module_name)
    finally:
        sys.path.remove(artifact)
|
|
|
|
def make_freqs(seq_len: int, head_dim: int):
    """Return random float32 ``(cos, sin)`` tables of shape (seq_len, head_dim // 2) on CUDA."""
    angles = torch.randn((seq_len, head_dim // 2), device="cuda", dtype=torch.float32)
    freqs_re = torch.cos(angles).contiguous()
    freqs_im = torch.sin(angles).contiguous()
    return freqs_re, freqs_im
|
|
|
|
def make_interleaved_rope(seq_len: int, head_dim: int):
    """Return a random bf16 RoPE table of shape (seq_len, head_dim) on CUDA.

    Each row alternates ``cos, sin`` for consecutive rotation angles.
    """
    angles = torch.randn((seq_len, head_dim // 2), device="cuda", dtype=torch.float32)
    cos_sin = [torch.cos(angles).to(torch.bfloat16), torch.sin(angles).to(torch.bfloat16)]
    # Stacking on a new trailing axis and flattening interleaves the two tables.
    table = torch.stack(cos_sin, dim=-1)
    return table.reshape(seq_len, head_dim).contiguous()
|
|
|
|
def make_case(batch: int, seq_len: int, heads: int, head_dim: int):
    """Build random CUDA inputs and output buffers for the split/norm/RoPE benchmark.

    Returns ``(packed, q_w, k_w, freqs_re, freqs_im, q_out, k_out)``: a bf16
    packed tensor of width ``3 * heads * head_dim``, two bf16 weights of
    length ``heads * head_dim``, the float32 frequency tables, and two
    uninitialised bf16 buffers of shape (batch, seq_len, heads, head_dim).
    """
    width = heads * head_dim
    bf16_cuda = dict(device="cuda", dtype=torch.bfloat16)

    # Random draws stay in this order so a fixed seed yields the same data.
    packed = torch.randn((batch, seq_len, 3 * width), **bf16_cuda)
    q_w = (1.0 + 0.1 * torch.randn((width,), **bf16_cuda)).contiguous()
    k_w = (1.0 + 0.1 * torch.randn((width,), **bf16_cuda)).contiguous()
    freqs_re, freqs_im = make_freqs(seq_len, head_dim)

    q_out = torch.empty((batch, seq_len, heads, head_dim), **bf16_cuda)
    k_out = torch.empty_like(q_out)
    return packed, q_w, k_w, freqs_re, freqs_im, q_out, k_out
|
|
|
|
def make_decode_case(heads: int):
    """Build random CUDA inputs for the decode benchmarks.

    The head dimension is fixed at 128 (64 rotation angles). Returns
    ``(q, k, v, q_w, k_w, cos, sin)``: three bf16 (heads, 128) tensors, two
    bf16 weights of length 128, and bf16 cos/sin vectors of length 64.
    """
    bf16_cuda = dict(device="cuda", dtype=torch.bfloat16)

    # Random draws stay in this order so a fixed seed yields the same data.
    q = torch.randn((heads, 128), **bf16_cuda)
    k = torch.randn((heads, 128), **bf16_cuda)
    v = torch.randn((heads, 128), **bf16_cuda)
    q_w = (1.0 + 0.1 * torch.randn((128,), **bf16_cuda)).contiguous()
    k_w = (1.0 + 0.1 * torch.randn((128,), **bf16_cuda)).contiguous()
    angles = torch.randn((64,), device="cuda", dtype=torch.float32)

    cos = torch.cos(angles).to(torch.bfloat16).contiguous()
    sin = torch.sin(angles).to(torch.bfloat16).contiguous()
    return q, k, v, q_w, k_w, cos, sin
|
|
|
|
def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float):
    """Return the float32 RMS-normalised ``x`` scaled by ``weight``.

    The mean of squares is taken over the last dimension; ``eps`` is added
    before the reciprocal square root.
    """
    xf = x.float()
    mean_square = torch.mean(xf * xf, dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + eps)
    return xf * inv_rms * weight.float()
|
|
|
|
def apply_pair_rope(x: torch.Tensor, freqs_re: torch.Tensor, freqs_im: torch.Tensor):
    """Rotate adjacent element pairs of ``x`` and return the result as bf16.

    ``x`` is (batch, seq_len, heads, head_dim). Elements ``2i`` and ``2i+1``
    of the last dimension form one complex number, which is multiplied by
    ``freqs_re + i * freqs_im``. Both tables are viewed as
    (1, seq_len, 1, head_dim // 2), so they broadcast over batch and heads.
    """
    batch, seq_len, heads, head_dim = x.shape
    half = head_dim // 2
    pairs = x.float().reshape(batch, seq_len, heads, half, 2)
    real, imag = pairs[..., 0], pairs[..., 1]
    cos = freqs_re.view(1, seq_len, 1, half)
    sin = freqs_im.view(1, seq_len, 1, half)
    # Complex product (real + i*imag) * (cos + i*sin), re-interleaved on the last axis.
    rotated = torch.stack((real * cos - imag * sin, real * sin + imag * cos), dim=-1)
    return rotated.reshape(batch, seq_len, heads, head_dim).to(torch.bfloat16)
|
|
|
|
def apply_interleaved_pair_rope(x: torch.Tensor, rope: torch.Tensor):
    """Rotate adjacent element pairs of ``x`` using an interleaved cos/sin table.

    ``x`` is (batch, seq_len, heads, head_dim). Only the first ``seq_len``
    rows of ``rope`` are used; each row is read as ``head_dim // 2``
    consecutive ``(cos, sin)`` pairs. The result is returned as bf16.
    """
    batch, seq_len, heads, head_dim = x.shape
    half = head_dim // 2
    pairs = x.float().reshape(batch, seq_len, heads, half, 2)
    real, imag = pairs[..., 0], pairs[..., 1]
    table = rope[:seq_len].float().reshape(seq_len, half, 2)
    cos = table[..., 0].view(1, seq_len, 1, half)
    sin = table[..., 1].view(1, seq_len, 1, half)
    rotated = torch.stack((real * cos - imag * sin, real * sin + imag * cos), dim=-1)
    return rotated.reshape(batch, seq_len, heads, head_dim).to(torch.bfloat16)
|
|
|
|
def apply_rotate_half_rope_128(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Apply rotate-half RoPE to 128-wide rows and return the result as bf16.

    Column ``i`` of the first half is paired with column ``i + 64``.
    ``cos`` and ``sin`` hold 64 values each, shared by every row of ``x``.
    """
    xf = x.float()
    lower, upper = xf[:, :64], xf[:, 64:]
    c = cos.float().view(1, 64)
    s = sin.float().view(1, 64)
    rotated = torch.cat((lower * c - upper * s, upper * c + lower * s), dim=1)
    return rotated.to(torch.bfloat16)
|
|
|
|
def torch_ref(packed, q_w, k_w, freqs_re, freqs_im, heads, head_dim, eps):
    """Eager reference for the split/norm/RoPE op; returns ``(q, k)`` in bf16.

    Q and K are the first and second ``heads * head_dim`` slices of the
    packed last dimension. Each is RMS-normalised across that whole slice,
    rounded to bf16, split into heads and rotated with ``apply_pair_rope``.
    """
    batch, seq_len, _ = packed.shape
    width = heads * head_dim
    head_shape = (batch, seq_len, heads, head_dim)

    def normed(start, weight):
        # Round to bf16 between normalisation and rotation.
        chunk = packed[:, :, start : start + width]
        return rms_norm(chunk, weight, eps).to(torch.bfloat16).view(*head_shape)

    q_normed = normed(0, q_w)
    k_normed = normed(width, k_w)
    return apply_pair_rope(q_normed, freqs_re, freqs_im), apply_pair_rope(k_normed, freqs_re, freqs_im)
|
|
|
|
def torch_ref_bias(packed, qkv_bias, q_w, k_w, freqs_re, freqs_im, heads, head_dim, eps):
    """Eager reference for the biased stream; returns ``(q, k, v)`` in bf16.

    The bias is added in float32 across the whole packed width. Q and K are
    then RMS-normalised, rounded to bf16 and rotated with
    ``apply_pair_rope``; V is the biased third slice rounded to bf16.
    """
    batch, seq_len, _ = packed.shape
    width = heads * head_dim
    head_shape = (batch, seq_len, heads, head_dim)

    with_bias = packed.float() + qkv_bias.float().view(1, 1, 3 * width)
    q_part = with_bias[:, :, :width]
    k_part = with_bias[:, :, width : 2 * width]
    v_part = with_bias[:, :, 2 * width :]

    v = v_part.to(torch.bfloat16).view(*head_shape)
    q_normed = rms_norm(q_part, q_w, eps).to(torch.bfloat16).view(*head_shape)
    k_normed = rms_norm(k_part, k_w, eps).to(torch.bfloat16).view(*head_shape)
    q_rot = apply_pair_rope(q_normed, freqs_re, freqs_im)
    k_rot = apply_pair_rope(k_normed, freqs_re, freqs_im)
    return q_rot, k_rot, v
|
|
|
|
def torch_ref_no_rope(packed, q_w, k_w, heads, head_dim, eps):
    """Eager reference for a stream without RoPE; returns ``(q, k, v)``.

    Q and K are RMS-normalised and rounded to bf16. V is the third packed
    slice, reshaped into heads and otherwise left as is.
    """
    batch, seq_len, _ = packed.shape
    width = heads * head_dim
    head_shape = (batch, seq_len, heads, head_dim)

    q_part, k_part, v_part = (packed[:, :, i * width : (i + 1) * width] for i in range(2)), None, None
    q_part, k_part = q_part
    v_part = packed[:, :, 2 * width :]

    v = v_part.view(*head_shape)
    q_normed = rms_norm(q_part, q_w, eps).to(torch.bfloat16).view(*head_shape)
    k_normed = rms_norm(k_part, k_w, eps).to(torch.bfloat16).view(*head_shape)
    return q_normed, k_normed, v
|
|
|
|
def torch_ref_decode(x, weight, cos, sin, eps):
    """Eager decode reference: RMS-normalise, round to bf16, then rotate-half RoPE."""
    normed = rms_norm(x, weight, eps).to(torch.bfloat16)
    return apply_rotate_half_rope_128(normed, cos, sin)
|
|
|
|
def torch_ref_kvcache(packed_qkv, rope, q_heads, kv_heads, head_dim):
    """Eager reference for the split/RoPE/KV-cache op; returns ``(q, k, v)``.

    The packed last dimension is laid out as Q (``q_heads`` heads), then K
    and V (``kv_heads`` heads each). Q and K are rotated with the
    interleaved RoPE table; V is returned as a reshaped view.
    """
    batch, seq_len, _ = packed_qkv.shape
    k_start = q_heads * head_dim
    v_start = k_start + kv_heads * head_dim

    q = packed_qkv[:, :, :k_start].view(batch, seq_len, q_heads, head_dim)
    k = packed_qkv[:, :, k_start:v_start].view(batch, seq_len, kv_heads, head_dim)
    v = packed_qkv[:, :, v_start:].view(batch, seq_len, kv_heads, head_dim)

    q_rot = apply_interleaved_pair_rope(q, rope)
    k_rot = apply_interleaved_pair_rope(k, rope)
    return q_rot, k_rot, v
|
|
|
|
def make_joint3_case(video_len: int, action_len: int, und_len: int, heads: int, head_dim: int):
    """Build random CUDA inputs and output buffers for the three-stream benchmark.

    Each stream (video, action, understanding) gets batch-1 packed inputs
    and weights from ``make_case``. Only the video stream keeps its
    frequency tables and gets a QKV bias. The three ``*_cat`` buffers span
    the concatenated sequence length.
    """
    # Streams are built in this order so a fixed seed yields the same data.
    packed_v, v_q_w, v_k_w, freqs_re, freqs_im = make_case(1, video_len, heads, head_dim)[:5]
    packed_a, a_q_w, a_k_w = make_case(1, action_len, heads, head_dim)[:3]
    packed_u, u_q_w, u_k_w = make_case(1, und_len, heads, head_dim)[:3]

    width = heads * head_dim
    qkv_v_bias = (0.02 * torch.randn((3 * width,), device="cuda", dtype=torch.bfloat16)).contiguous()

    total_len = video_len + action_len + und_len
    q_cat = torch.empty((1, total_len, heads, head_dim), device="cuda", dtype=torch.bfloat16)
    k_cat = torch.empty_like(q_cat)
    v_cat = torch.empty_like(q_cat)

    video = (packed_v, qkv_v_bias, v_q_w, v_k_w, freqs_re, freqs_im)
    action = (packed_a, a_q_w, a_k_w)
    understanding = (packed_u, u_q_w, u_k_w)
    return (*video, *action, *understanding, q_cat, k_cat, v_cat)
|
|
|
|
def torch_ref_joint3(
    packed_v,
    qkv_v_bias,
    v_q_w,
    v_k_w,
    freqs_re,
    freqs_im,
    packed_a,
    a_q_w,
    a_k_w,
    packed_u,
    u_q_w,
    u_k_w,
    heads,
    head_dim,
    eps,
):
    """Eager reference for the three-stream op; returns concatenated ``(q, k, v)``.

    The video stream goes through the bias + RoPE reference, the action and
    understanding streams through the no-RoPE reference. The results are
    concatenated along the sequence dimension in that order.
    """
    per_stream = [
        torch_ref_bias(packed_v, qkv_v_bias, v_q_w, v_k_w, freqs_re, freqs_im, heads, head_dim, eps),
        torch_ref_no_rope(packed_a, a_q_w, a_k_w, heads, head_dim, eps),
        torch_ref_no_rope(packed_u, u_q_w, u_k_w, heads, head_dim, eps),
    ]
    # Transpose [(q, k, v) per stream] into (all q, all k, all v).
    q_parts, k_parts, v_parts = zip(*per_stream)
    return torch.cat(list(q_parts), dim=1), torch.cat(list(k_parts), dim=1), torch.cat(list(v_parts), dim=1)
|
|
|
|
def time_us(fn, warmup: int, iters: int) -> float:
    """Return the mean time of one ``fn()`` call in microseconds.

    ``fn`` is called ``warmup`` times untimed, the device is synchronised,
    and then ``iters`` back-to-back calls are bracketed by a pair of CUDA
    events.

    Args:
        fn: Zero-argument callable to benchmark.
        warmup: Number of untimed calls made first; values <= 0 skip warm-up.
        iters: Number of timed calls; must be positive.

    Raises:
        ValueError: If ``iters`` is not positive. Checked before any work
            runs, so a bad ``--iters`` fails at once with a clear message
            instead of a ZeroDivisionError after the benchmark has run.
    """
    if iters <= 0:
        raise ValueError(f"iters must be a positive integer, got {iters}")
    for _ in range(warmup):
        fn()
    # Drain the warm-up work so it is not attributed to the timed region.
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    # The end event must have completed before elapsed_time can be read.
    torch.cuda.synchronize()
    # elapsed_time reports milliseconds; convert to microseconds per call.
    return start.elapsed_time(end) * 1000.0 / iters
|
|
|
|
def percentile(x: torch.Tensor, q: float) -> torch.Tensor:
    """Return the ``ceil(q * n)``-th smallest element of ``x`` as a 0-dim tensor.

    The 1-based rank is clamped to ``[1, n]``, where ``n`` is the number of
    elements in ``x``.
    """
    values = x.flatten()
    count = values.numel()
    rank = math.ceil(q * count)
    if rank > count:
        rank = count
    if rank < 1:
        rank = 1
    return values.kthvalue(rank).values
|
|
|
|
def metrics(got, expected):
    """Compare two tensors in float32.

    Returns ``(p99_abs_error, cosine_similarity)`` as Python floats. The
    cosine similarity is taken over the flattened tensors.
    """
    got_f = got.float()
    expected_f = expected.float()
    abs_error = (got_f - expected_f).abs().flatten()
    p99 = percentile(abs_error, 0.99).item()
    cosine = torch.nn.functional.cosine_similarity(got_f.flatten(), expected_f.flatten(), dim=0).item()
    return float(p99), float(cosine)
|
|
|
|
def run_one(ops, name: str, shape: tuple[int, int, int, int], args) -> Result:
    """Check and time the split/norm/RoPE op for one (batch, seq_len, heads, head_dim) shape.

    Q and K from the fused op are compared with ``torch_ref`` first, then
    both paths are timed. The row passes when both p99 absolute errors are
    within ``args.p99_abs_limit``.
    """
    batch, seq_len, heads, head_dim = shape
    packed, q_w, k_w, freqs_re, freqs_im, q_out, k_out = make_case(batch, seq_len, heads, head_dim)
    eps = args.eps
    limit = args.p99_abs_limit

    def flashrt_fn():
        return ops.qkv_split_norm_rope_bf16(
            packed, q_w, k_w, freqs_re, freqs_im, heads, head_dim, seq_len, eps, q_out, k_out
        )

    def eager_fn():
        return torch_ref(packed, q_w, k_w, freqs_re, freqs_im, heads, head_dim, eps)

    # Accuracy first, then timing.
    got_q, got_k = flashrt_fn()
    exp_q, exp_k = eager_fn()
    q_p99, q_cos = metrics(got_q, exp_q)
    k_p99, k_cos = metrics(got_k, exp_k)

    flashrt_us = time_us(flashrt_fn, args.warmup, args.iters)
    eager_us = time_us(eager_fn, args.warmup, args.iters)

    within_limit = all(err <= limit for err in (q_p99, k_p99))
    return Result(
        shape=name,
        batch=batch,
        seq_len=seq_len,
        heads=heads,
        head_dim=head_dim,
        flashrt_us=flashrt_us,
        torch_eager_us=eager_us,
        speedup_vs_eager=eager_us / flashrt_us,
        q_p99_abs=q_p99,
        k_p99_abs=k_p99,
        q_cosine=q_cos,
        k_cosine=k_cos,
        status="PASS" if within_limit else "FAIL",
    )
|
|
|
|
def run_joint3(ops, name: str, video_len: int, action_len: int, und_len: int, heads: int, head_dim: int, args) -> Result:
    """Check and time the three-stream joint op against the eager reference.

    Q, K and V of the concatenated output are all compared with
    ``torch_ref_joint3``. V is included because the op writes ``v_cat``
    (with the video-stream bias applied), so a wrong V must not report
    PASS. ``Result`` has no V columns, so V is folded into the Q columns
    (worse p99, lower cosine) as ``run_kvcache_gqa`` does.

    Args:
        ops: Backend exposing ``qkv_split_joint3_cat_bf16``.
        name: Label stored in the result row.
        video_len: Video stream length; also passed as the RoPE sequence length.
        action_len: Action stream length.
        und_len: Understanding stream length.
        heads: Number of heads.
        head_dim: Per-head dimension.
        args: Parsed options; reads ``eps``, ``warmup``, ``iters`` and
            ``p99_abs_limit``.

    Returns:
        A ``Result`` with batch 1 and ``seq_len`` equal to the sum of the
        three stream lengths. Status is ``"PASS"`` only when the Q, K and V
        p99 absolute errors are all within ``args.p99_abs_limit``.
    """
    (
        packed_v,
        qkv_v_bias,
        v_q_w,
        v_k_w,
        freqs_re,
        freqs_im,
        packed_a,
        a_q_w,
        a_k_w,
        packed_u,
        u_q_w,
        u_k_w,
        q_cat,
        k_cat,
        v_cat,
    ) = make_joint3_case(video_len, action_len, und_len, heads, head_dim)
    eps = args.eps
    limit = args.p99_abs_limit

    # Leading arguments shared, in this order, by the fused op and the reference.
    inputs = (
        packed_v,
        qkv_v_bias,
        v_q_w,
        v_k_w,
        freqs_re,
        freqs_im,
        packed_a,
        a_q_w,
        a_k_w,
        packed_u,
        u_q_w,
        u_k_w,
        heads,
        head_dim,
    )

    def flashrt_fn():
        # Trailing arguments: outputs, RoPE length, then eps for each stream.
        return ops.qkv_split_joint3_cat_bf16(*inputs, q_cat, k_cat, v_cat, video_len, eps, eps, eps)

    def eager_fn():
        return torch_ref_joint3(*inputs, eps)

    # Accuracy first, then timing.
    got_q, got_k, got_v = flashrt_fn()
    exp_q, exp_k, exp_v = eager_fn()
    q_p99, q_cos = metrics(got_q, exp_q)
    k_p99, k_cos = metrics(got_k, exp_k)
    v_p99, v_cos = metrics(got_v, exp_v)

    flashrt_us = time_us(flashrt_fn, args.warmup, args.iters)
    eager_us = time_us(eager_fn, args.warmup, args.iters)

    # The biased V is rounded to bf16, so it gets the same tolerance as Q/K
    # rather than an exact-match requirement.
    within_limit = q_p99 <= limit and k_p99 <= limit and v_p99 <= limit
    return Result(
        shape=name,
        batch=1,
        seq_len=video_len + action_len + und_len,
        heads=heads,
        head_dim=head_dim,
        flashrt_us=flashrt_us,
        torch_eager_us=eager_us,
        speedup_vs_eager=eager_us / flashrt_us,
        q_p99_abs=max(q_p99, v_p99),
        k_p99_abs=k_p99,
        q_cosine=min(q_cos, v_cos),
        k_cosine=k_cos,
        status="PASS" if within_limit else "FAIL",
    )
|
|
|
|
def run_decode_q(ops, name: str, heads: int, args) -> Result:
    """Check and time the decode Q op against ``torch_ref_decode``.

    Only Q is produced, so the K columns of the row carry the placeholders
    ``0.0`` (p99) and ``1.0`` (cosine).
    """
    q, _, _, q_w, _, cos, sin = make_decode_case(heads)
    q_out = torch.empty_like(q)
    eps = args.eps

    def flashrt_fn():
        return ops.decode_q_norm_rope_stage_bf16(q, q_w, cos, sin, eps, q_out)

    def eager_fn():
        return torch_ref_decode(q, q_w, cos, sin, eps)

    # Accuracy first, then timing.
    q_p99, q_cos = metrics(flashrt_fn(), eager_fn())
    flashrt_us = time_us(flashrt_fn, args.warmup, args.iters)
    eager_us = time_us(eager_fn, args.warmup, args.iters)

    if q_p99 <= args.p99_abs_limit:
        status = "PASS"
    else:
        status = "FAIL"
    return Result(
        shape=name,
        batch=1,
        seq_len=1,
        heads=heads,
        head_dim=128,
        flashrt_us=flashrt_us,
        torch_eager_us=eager_us,
        speedup_vs_eager=eager_us / flashrt_us,
        q_p99_abs=q_p99,
        k_p99_abs=0.0,
        q_cosine=q_cos,
        k_cosine=1.0,
        status=status,
    )
|
|
|
|
def run_decode_kv(ops, name: str, heads: int, devpos: bool, args) -> Result:
    """Check and time a decode K/V-write op against an eager equivalent.

    With ``devpos`` the device-position op writes into 8-slot caches and
    slot 3 is inspected; otherwise the plain op writes into standalone K/V
    buffers. K must be within ``args.p99_abs_limit`` of the reference and V
    must match its input exactly. The V metrics are reported in the Q
    columns of the row.
    """
    _, k, v, _, k_w, cos, sin = make_decode_case(heads)
    k_slot = torch.empty_like(k)
    v_slot = torch.empty_like(v)
    eps = args.eps
    exp_k = torch_ref_decode(k, k_w, cos, sin, eps)

    if not devpos:

        def flashrt_fn():
            return ops.decode_k_norm_rope_kvwrite_bf16(k, v, k_w, cos, sin, eps, k_slot, v_slot)

        def eager_fn():
            k_slot.copy_(torch_ref_decode(k, k_w, cos, sin, eps))
            v_slot.copy_(v)
            return k_slot, v_slot

        got_k, got_v = flashrt_fn()
    else:
        pos = 3
        k_cache = torch.empty((8, heads, 128), device="cuda", dtype=torch.bfloat16)
        v_cache = torch.empty_like(k_cache)
        cur_pos = torch.tensor([pos], device="cuda", dtype=torch.int32)

        def flashrt_fn():
            return ops.decode_k_norm_rope_kvwrite_devpos_bf16(k, v, k_w, cos, sin, cur_pos, k_cache, v_cache, eps)

        def eager_fn():
            # Index the cache on every call so the timed eager work includes it.
            k_cache[pos].copy_(torch_ref_decode(k, k_w, cos, sin, eps))
            v_cache[pos].copy_(v)
            return k_cache, v_cache

        flashrt_fn()
        got_k, got_v = k_cache[pos], v_cache[pos]

    # Metrics are taken before eager_fn overwrites the same buffers during timing.
    k_p99, k_cos = metrics(got_k, exp_k)
    v_p99, v_cos = metrics(got_v, v)
    flashrt_us = time_us(flashrt_fn, args.warmup, args.iters)
    eager_us = time_us(eager_fn, args.warmup, args.iters)

    passed = k_p99 <= args.p99_abs_limit and v_p99 == 0.0
    return Result(
        shape=name,
        batch=1,
        seq_len=1,
        heads=heads,
        head_dim=128,
        flashrt_us=flashrt_us,
        torch_eager_us=eager_us,
        speedup_vs_eager=eager_us / flashrt_us,
        q_p99_abs=v_p99,
        k_p99_abs=k_p99,
        q_cosine=v_cos,
        k_cosine=k_cos,
        status="PASS" if passed else "FAIL",
    )
|
|
|
|
def run_kvcache_gqa(
    ops,
    name: str,
    batch: int,
    seq_len: int,
    q_heads: int,
    kv_heads: int,
    head_dim: int,
    args,
) -> Result:
    """Check and time the split/RoPE/KV-cache op with separate Q and KV head counts.

    K and V are written at a cache offset of 2 into caches holding two
    spare rows past the written window. Q and K must be within
    ``args.p99_abs_limit`` of the reference and V must match exactly. V is
    folded into the Q columns of the row (worse p99, lower cosine).
    """
    cache_offset = 2
    packed_width = (q_heads + 2 * kv_heads) * head_dim
    packed = torch.randn((batch, seq_len, packed_width), device="cuda", dtype=torch.bfloat16)
    rope = make_interleaved_rope(seq_len, head_dim)

    max_seq_len = cache_offset + seq_len + 2
    q_out = torch.empty((batch, seq_len, q_heads, head_dim), device="cuda", dtype=torch.bfloat16)
    k_cache = torch.empty((batch, max_seq_len, kv_heads, head_dim), device="cuda", dtype=torch.bfloat16)
    v_cache = torch.empty_like(k_cache)
    # Cache rows that the op is expected to fill.
    sl = slice(cache_offset, cache_offset + seq_len)

    def flashrt_fn():
        return ops.qkv_split_rope_kvcache_bf16(
            packed,
            rope,
            q_heads,
            kv_heads,
            head_dim,
            cache_offset,
            q_out,
            k_cache,
            v_cache,
        )

    def eager_fn():
        ref_q, ref_k, ref_v = torch_ref_kvcache(packed, rope, q_heads, kv_heads, head_dim)
        q_out.copy_(ref_q)
        k_cache[:, sl].copy_(ref_k)
        v_cache[:, sl].copy_(ref_v)
        return q_out, k_cache, v_cache

    # Accuracy first, then timing.
    got_q, got_k_cache, got_v_cache = flashrt_fn()
    exp_q, exp_k, exp_v = torch_ref_kvcache(packed, rope, q_heads, kv_heads, head_dim)
    q_p99, q_cos = metrics(got_q, exp_q)
    k_p99, k_cos = metrics(got_k_cache[:, sl], exp_k)
    v_p99, v_cos = metrics(got_v_cache[:, sl], exp_v)

    flashrt_us = time_us(flashrt_fn, args.warmup, args.iters)
    eager_us = time_us(eager_fn, args.warmup, args.iters)

    limit = args.p99_abs_limit
    passed = q_p99 <= limit and k_p99 <= limit and v_p99 == 0.0
    return Result(
        shape=name,
        batch=batch,
        seq_len=seq_len,
        heads=q_heads,
        head_dim=head_dim,
        flashrt_us=flashrt_us,
        torch_eager_us=eager_us,
        speedup_vs_eager=eager_us / flashrt_us,
        q_p99_abs=max(q_p99, v_p99),
        k_p99_abs=k_p99,
        q_cosine=min(q_cos, v_cos),
        k_cosine=k_cos,
        status="PASS" if passed else "FAIL",
    )
|
|
|
|
def write_markdown(path: Path, results: list[Result]) -> None:
    """Write ``results`` to ``path`` as a Markdown table, one row per result."""
    header = [
        "| Shape | B,L,H,D | FlashRT us | Eager us | vs eager | Q p99 | K p99 | Q cosine | K cosine | Status |",
        "|---|---:|---:|---:|---:|---:|---:|---:|---:|---|",
    ]
    rows = [
        (
            f"| {r.shape} | {r.batch},{r.seq_len},{r.heads},{r.head_dim} | "
            f"{r.flashrt_us:.3f} | {r.torch_eager_us:.3f} | {r.speedup_vs_eager:.2f}x | "
            f"{r.q_p99_abs:.6f} | {r.k_p99_abs:.6f} | {r.q_cosine:.8f} | "
            f"{r.k_cosine:.8f} | {r.status} |"
        )
        for r in results
    ]
    path.write_text("\n".join(header + rows) + "\n")
|
|
|
|
def main() -> None:
    """Parse options, run the selected benchmarks, then report and gate on the results.

    Prints one summary line per benchmark and optionally writes JSON and
    Markdown reports. Exits with status 1 if any benchmark did not pass,
    and with the message "CUDA is required" when CUDA is unavailable.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--backend", choices=["source", "installed"], default="source")
    parser.add_argument("--artifact", default=None)
    parser.add_argument("--shapes", choices=sorted(SHAPE_GROUPS), default="smoke")
    parser.add_argument("--warmup", type=int, default=5)
    parser.add_argument("--iters", type=int, default=20)
    parser.add_argument("--eps", type=float, default=1e-6)
    parser.add_argument("--p99-abs-limit", type=float, default=0.015625)
    parser.add_argument("--output", default=None)
    parser.add_argument("--markdown", default=None)
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise SystemExit("CUDA is required")
    torch.manual_seed(37)
    if args.backend == "source":
        ops = load_source_ops()
    else:
        ops = load_installed_ops(args.artifact)

    group = args.shapes
    with_smoke_extras = group in ("smoke", "all")
    with_headline_extras = group in ("headline", "all")

    def gqa_kvcache() -> Result:
        return run_kvcache_gqa(ops, "pi05_decoder_gqa_kvcache", 1, 10, 8, 1, 256, args)

    results = []
    for shape_name in SHAPE_GROUPS[group]:
        results.append(run_one(ops, shape_name, SHAPES[shape_name], args))
    if with_smoke_extras:
        results.append(run_joint3(ops, "joint3_small", 64, 8, 4, 8, 128, args))
        results.append(gqa_kvcache())
    if with_headline_extras:
        results.append(run_joint3(ops, "joint3_vla", 2520, 16, 16, 24, 128, args))
        results.append(run_decode_q(ops, "decode_q_stage_h24", 24, args))
        results.append(run_decode_kv(ops, "decode_kvwrite_h8", 8, False, args))
        results.append(run_decode_kv(ops, "decode_kvwrite_devpos_h8", 8, True, args))
        # "all" already ran the GQA KV-cache case above; run it here for "headline" only.
        if group == "headline":
            results.append(gqa_kvcache())

    for r in results:
        print(
            f"{r.status} {r.shape}: flashrt={r.flashrt_us:.3f}us "
            f"eager={r.torch_eager_us:.3f}us speedup={r.speedup_vs_eager:.2f}x "
            f"q_p99={r.q_p99_abs:.6f} k_p99={r.k_p99_abs:.6f}"
        )

    if args.output:
        json_path = Path(args.output)
        json_path.parent.mkdir(parents=True, exist_ok=True)
        payload = json.dumps([asdict(r) for r in results], indent=2)
        json_path.write_text(payload + "\n")
    if args.markdown:
        markdown_path = Path(args.markdown)
        markdown_path.parent.mkdir(parents=True, exist_ok=True)
        write_markdown(markdown_path, results)

    # A non-zero exit status lets callers gate on accuracy.
    if not all(r.status == "PASS" for r in results):
        raise SystemExit(1)
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|