liangsu9988's picture
Uploaded using `kernel-builder`.
1fea443 verified
Raw
History Blame
27.5 kB
#!/usr/bin/env python3
"""Benchmark flashrt-qkv-cache-rope against a PyTorch eager postprocess chain."""
from __future__ import annotations
import argparse
import ctypes
import ctypes.util
import importlib
import json
import math
import os
import sys
from dataclasses import asdict, dataclass
from pathlib import Path
import torch
ROOT = Path(__file__).resolve().parents[2]
PACKAGE = ROOT / "flashrt-qkv-cache-rope"
REGISTRATION_INCLUDE = (
ROOT.parent
/ "kernels"
/ "kernel-builder"
/ "src"
/ "pyproject"
/ "templates"
/ "torch"
)
SHAPES = {
"small": (1, 64, 8, 128),
"wan_1k": (1, 1024, 24, 128),
"wan_2520": (1, 2520, 24, 128),
"wan_4096": (1, 4096, 24, 128),
"vl_512": (1, 512, 16, 128),
}
SHAPE_GROUPS = {
"smoke": ["small"],
"headline": ["wan_1k", "wan_2520", "vl_512"],
"all": list(SHAPES.keys()),
}
@dataclass
class Result:
shape: str
batch: int
seq_len: int
heads: int
head_dim: int
flashrt_us: float
torch_eager_us: float
speedup_vs_eager: float
q_p99_abs: float
k_p99_abs: float
q_cosine: float
k_cosine: float
status: str
class SourceOps:
def __init__(self, namespace: str) -> None:
self._ops = getattr(torch.ops, namespace)
def decode_q_norm_rope_stage_bf16(self, q_pre, q_w, cos, sin, eps=1e-6, q_out=None):
if q_out is None:
q_out = torch.empty_like(q_pre)
self._ops.decode_q_norm_rope_stage_bf16(q_pre, q_w, cos, sin, float(eps), q_out)
return q_out
def decode_k_norm_rope_kvwrite_bf16(self, k_pre, v_pre, k_w, cos, sin, eps=1e-6, k_out=None, v_out=None):
if k_out is None:
k_out = torch.empty_like(k_pre)
if v_out is None:
v_out = torch.empty_like(v_pre)
self._ops.decode_k_norm_rope_kvwrite_bf16(k_pre, v_pre, k_w, cos, sin, float(eps), k_out, v_out)
return k_out, v_out
def decode_k_norm_rope_kvwrite_devpos_bf16(self, k_pre, v_pre, k_w, cos, sin, cur_pos, k_cache, v_cache, eps=1e-6):
self._ops.decode_k_norm_rope_kvwrite_devpos_bf16(k_pre, v_pre, k_w, cos, sin, cur_pos, float(eps), k_cache, v_cache)
return k_cache, v_cache
def qkv_split_norm_rope_bf16(
self, packed, q_w, k_w, freqs_re, freqs_im, heads, head_dim, rope_seq_len=None, eps=1e-6, q_out=None, k_out=None
):
if rope_seq_len is None:
rope_seq_len = packed.shape[1]
if q_out is None:
q_out = torch.empty((packed.shape[0], packed.shape[1], heads, head_dim), device=packed.device, dtype=torch.bfloat16)
if k_out is None:
k_out = torch.empty_like(q_out)
self._ops.qkv_split_norm_rope_bf16(
packed, q_w, k_w, freqs_re, freqs_im, int(heads), int(head_dim),
int(rope_seq_len), float(eps), q_out, k_out
)
return q_out, k_out
def qkv_split_joint3_cat_bf16(
self,
packed_v,
qkv_v_bias,
norm_v_q_weight,
norm_v_k_weight,
freqs_re,
freqs_im,
packed_a,
norm_a_q_weight,
norm_a_k_weight,
packed_u,
norm_u_q_weight,
norm_u_k_weight,
heads,
head_dim,
q_cat_out,
k_cat_out,
v_cat_out,
rope_seq_len=None,
eps_v=1e-6,
eps_a=1e-6,
eps_u=1e-6,
):
if rope_seq_len is None:
rope_seq_len = packed_v.shape[1]
self._ops.qkv_split_joint3_cat_bf16(
packed_v,
qkv_v_bias,
norm_v_q_weight,
norm_v_k_weight,
freqs_re,
freqs_im,
packed_a,
norm_a_q_weight,
norm_a_k_weight,
packed_u,
norm_u_q_weight,
norm_u_k_weight,
int(heads),
int(head_dim),
int(rope_seq_len),
float(eps_v),
float(eps_a),
float(eps_u),
q_cat_out,
k_cat_out,
v_cat_out,
)
return q_cat_out, k_cat_out, v_cat_out
def qkv_split_rope_kvcache_bf16(
self,
packed_qkv,
rope,
q_heads,
kv_heads,
head_dim,
cache_offset,
q_out=None,
k_cache=None,
v_cache=None,
max_seq_len=None,
):
batch, seq_len, _ = packed_qkv.shape
if q_out is None:
q_out = torch.empty(
(batch, seq_len, q_heads, head_dim),
device=packed_qkv.device,
dtype=torch.bfloat16,
)
if k_cache is None or v_cache is None:
if max_seq_len is None:
max_seq_len = cache_offset + seq_len
cache_shape = (batch, max_seq_len, kv_heads, head_dim)
if k_cache is None:
k_cache = torch.empty(cache_shape, device=packed_qkv.device, dtype=torch.bfloat16)
if v_cache is None:
v_cache = torch.empty(cache_shape, device=packed_qkv.device, dtype=torch.bfloat16)
self._ops.qkv_split_rope_kvcache_bf16(
packed_qkv,
rope,
int(q_heads),
int(kv_heads),
int(head_dim),
int(cache_offset),
q_out,
k_cache,
v_cache,
)
return q_out, k_cache, v_cache
def _preload_cublaslt() -> None:
for parent in Path(torch.__file__).resolve().parents:
candidate = parent / "nvidia" / "cublas" / "lib" / "libcublasLt.so.12"
if candidate.exists():
ctypes.CDLL(str(candidate), mode=ctypes.RTLD_GLOBAL)
return
library = ctypes.util.find_library("cublasLt")
if library:
ctypes.CDLL(library, mode=ctypes.RTLD_GLOBAL)
def _current_arch_list() -> str:
major, minor = torch.cuda.get_device_capability(0)
return f"{major}.{minor}"
def load_source_ops() -> SourceOps:
from torch.utils.cpp_extension import load
if not REGISTRATION_INCLUDE.is_dir():
raise RuntimeError(f"missing kernel-builder registration include: {REGISTRATION_INCLUDE}")
_preload_cublaslt()
os.environ.setdefault("TORCH_CUDA_ARCH_LIST", _current_arch_list())
namespace = "flashrt_qkv_cache_rope_benchmark"
load(
name=namespace,
sources=[
str(PACKAGE / "torch-ext" / "torch_binding.cpp"),
str(PACKAGE / "csrc" / "qkv_cache_rope.cu"),
],
extra_include_paths=[str(PACKAGE / "csrc"), str(REGISTRATION_INCLUDE)],
extra_cflags=["-O3", "-DCUDA_KERNEL"],
extra_cuda_cflags=["-O3", "--expt-relaxed-constexpr", "-DCUDA_KERNEL"],
verbose=False,
)
return SourceOps(namespace)
def load_installed_ops(artifact: str | None):
if artifact:
sys.path.insert(0, artifact)
try:
return importlib.import_module("flashrt_qkv_cache_rope")
finally:
if artifact:
sys.path.remove(artifact)
def make_freqs(seq_len: int, head_dim: int):
theta = torch.randn((seq_len, head_dim // 2), device="cuda", dtype=torch.float32)
return torch.cos(theta).contiguous(), torch.sin(theta).contiguous()
def make_interleaved_rope(seq_len: int, head_dim: int):
theta = torch.randn((seq_len, head_dim // 2), device="cuda", dtype=torch.float32)
cos = torch.cos(theta).to(torch.bfloat16)
sin = torch.sin(theta).to(torch.bfloat16)
return torch.stack([cos, sin], dim=-1).reshape(seq_len, head_dim).contiguous()
def make_case(batch: int, seq_len: int, heads: int, head_dim: int):
dim = heads * head_dim
packed = torch.randn((batch, seq_len, 3 * dim), device="cuda", dtype=torch.bfloat16)
q_w = (1.0 + 0.1 * torch.randn((dim,), device="cuda", dtype=torch.bfloat16)).contiguous()
k_w = (1.0 + 0.1 * torch.randn((dim,), device="cuda", dtype=torch.bfloat16)).contiguous()
freqs_re, freqs_im = make_freqs(seq_len, head_dim)
q_out = torch.empty((batch, seq_len, heads, head_dim), device="cuda", dtype=torch.bfloat16)
k_out = torch.empty_like(q_out)
return packed, q_w, k_w, freqs_re, freqs_im, q_out, k_out
def make_decode_case(heads: int):
q = torch.randn((heads, 128), device="cuda", dtype=torch.bfloat16)
k = torch.randn((heads, 128), device="cuda", dtype=torch.bfloat16)
v = torch.randn((heads, 128), device="cuda", dtype=torch.bfloat16)
q_w = (1.0 + 0.1 * torch.randn((128,), device="cuda", dtype=torch.bfloat16)).contiguous()
k_w = (1.0 + 0.1 * torch.randn((128,), device="cuda", dtype=torch.bfloat16)).contiguous()
theta = torch.randn((64,), device="cuda", dtype=torch.float32)
cos = torch.cos(theta).to(torch.bfloat16).contiguous()
sin = torch.sin(theta).to(torch.bfloat16).contiguous()
return q, k, v, q_w, k_w, cos, sin
def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float):
rms = torch.rsqrt(torch.mean(x.float() * x.float(), dim=-1, keepdim=True) + eps)
return x.float() * rms * weight.float()
def apply_pair_rope(x: torch.Tensor, freqs_re: torch.Tensor, freqs_im: torch.Tensor):
batch, seq_len, heads, head_dim = x.shape
pair = x.float().reshape(batch, seq_len, heads, head_dim // 2, 2)
re = pair[..., 0]
im = pair[..., 1]
fr = freqs_re.view(1, seq_len, 1, head_dim // 2)
fi = freqs_im.view(1, seq_len, 1, head_dim // 2)
out = torch.empty_like(pair.float())
out[..., 0] = re * fr - im * fi
out[..., 1] = re * fi + im * fr
return out.reshape(batch, seq_len, heads, head_dim).to(torch.bfloat16)
def apply_interleaved_pair_rope(x: torch.Tensor, rope: torch.Tensor):
batch, seq_len, heads, head_dim = x.shape
pair = x.float().reshape(batch, seq_len, heads, head_dim // 2, 2)
re = pair[..., 0]
im = pair[..., 1]
rope_pair = rope[:seq_len].float().reshape(seq_len, head_dim // 2, 2)
cos = rope_pair[..., 0].view(1, seq_len, 1, head_dim // 2)
sin = rope_pair[..., 1].view(1, seq_len, 1, head_dim // 2)
out = torch.empty_like(pair.float())
out[..., 0] = re * cos - im * sin
out[..., 1] = re * sin + im * cos
return out.reshape(batch, seq_len, heads, head_dim).to(torch.bfloat16)
def apply_rotate_half_rope_128(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
xf = x.float()
out = torch.empty_like(xf)
c = cos.float().view(1, 64)
s = sin.float().view(1, 64)
out[:, :64] = xf[:, :64] * c - xf[:, 64:] * s
out[:, 64:] = xf[:, 64:] * c + xf[:, :64] * s
return out.to(torch.bfloat16)
def torch_ref(packed, q_w, k_w, freqs_re, freqs_im, heads, head_dim, eps):
batch, seq_len, _ = packed.shape
dim = heads * head_dim
q = packed[:, :, :dim]
k = packed[:, :, dim : 2 * dim]
qn = rms_norm(q, q_w, eps).to(torch.bfloat16).view(batch, seq_len, heads, head_dim)
kn = rms_norm(k, k_w, eps).to(torch.bfloat16).view(batch, seq_len, heads, head_dim)
return apply_pair_rope(qn, freqs_re, freqs_im), apply_pair_rope(kn, freqs_re, freqs_im)
def torch_ref_bias(packed, qkv_bias, q_w, k_w, freqs_re, freqs_im, heads, head_dim, eps):
batch, seq_len, _ = packed.shape
dim = heads * head_dim
biased = packed.float() + qkv_bias.float().view(1, 1, 3 * dim)
q = biased[:, :, :dim]
k = biased[:, :, dim : 2 * dim]
v = biased[:, :, 2 * dim :].to(torch.bfloat16).view(batch, seq_len, heads, head_dim)
qn = rms_norm(q, q_w, eps).to(torch.bfloat16).view(batch, seq_len, heads, head_dim)
kn = rms_norm(k, k_w, eps).to(torch.bfloat16).view(batch, seq_len, heads, head_dim)
return apply_pair_rope(qn, freqs_re, freqs_im), apply_pair_rope(kn, freqs_re, freqs_im), v
def torch_ref_no_rope(packed, q_w, k_w, heads, head_dim, eps):
batch, seq_len, _ = packed.shape
dim = heads * head_dim
q = packed[:, :, :dim]
k = packed[:, :, dim : 2 * dim]
v = packed[:, :, 2 * dim :].view(batch, seq_len, heads, head_dim)
qn = rms_norm(q, q_w, eps).to(torch.bfloat16).view(batch, seq_len, heads, head_dim)
kn = rms_norm(k, k_w, eps).to(torch.bfloat16).view(batch, seq_len, heads, head_dim)
return qn, kn, v
def torch_ref_decode(x, weight, cos, sin, eps):
return apply_rotate_half_rope_128(rms_norm(x, weight, eps).to(torch.bfloat16), cos, sin)
def torch_ref_kvcache(packed_qkv, rope, q_heads, kv_heads, head_dim):
batch, seq_len, _ = packed_qkv.shape
q_dim = q_heads * head_dim
kv_dim = kv_heads * head_dim
q = packed_qkv[:, :, :q_dim].view(batch, seq_len, q_heads, head_dim)
k = packed_qkv[:, :, q_dim : q_dim + kv_dim].view(batch, seq_len, kv_heads, head_dim)
v = packed_qkv[:, :, q_dim + kv_dim :].view(batch, seq_len, kv_heads, head_dim)
return apply_interleaved_pair_rope(q, rope), apply_interleaved_pair_rope(k, rope), v
def make_joint3_case(video_len: int, action_len: int, und_len: int, heads: int, head_dim: int):
packed_v, v_q_w, v_k_w, freqs_re, freqs_im, _, _ = make_case(1, video_len, heads, head_dim)
packed_a, a_q_w, a_k_w, _, _, _, _ = make_case(1, action_len, heads, head_dim)
packed_u, u_q_w, u_k_w, _, _, _, _ = make_case(1, und_len, heads, head_dim)
dim = heads * head_dim
qkv_v_bias = (0.02 * torch.randn((3 * dim,), device="cuda", dtype=torch.bfloat16)).contiguous()
total = video_len + action_len + und_len
q_cat = torch.empty((1, total, heads, head_dim), device="cuda", dtype=torch.bfloat16)
k_cat = torch.empty_like(q_cat)
v_cat = torch.empty_like(q_cat)
return (
packed_v,
qkv_v_bias,
v_q_w,
v_k_w,
freqs_re,
freqs_im,
packed_a,
a_q_w,
a_k_w,
packed_u,
u_q_w,
u_k_w,
q_cat,
k_cat,
v_cat,
)
def torch_ref_joint3(
packed_v,
qkv_v_bias,
v_q_w,
v_k_w,
freqs_re,
freqs_im,
packed_a,
a_q_w,
a_k_w,
packed_u,
u_q_w,
u_k_w,
heads,
head_dim,
eps,
):
qv, kv, vv = torch_ref_bias(packed_v, qkv_v_bias, v_q_w, v_k_w, freqs_re, freqs_im, heads, head_dim, eps)
qa, ka, va = torch_ref_no_rope(packed_a, a_q_w, a_k_w, heads, head_dim, eps)
qu, ku, vu = torch_ref_no_rope(packed_u, u_q_w, u_k_w, heads, head_dim, eps)
return torch.cat([qv, qa, qu], dim=1), torch.cat([kv, ka, ku], dim=1), torch.cat([vv, va, vu], dim=1)
def time_us(fn, warmup: int, iters: int) -> float:
for _ in range(warmup):
fn()
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(iters):
fn()
end.record()
torch.cuda.synchronize()
return start.elapsed_time(end) * 1000.0 / iters
def percentile(x: torch.Tensor, q: float) -> torch.Tensor:
flat = x.flatten()
k = max(1, min(flat.numel(), math.ceil(q * flat.numel())))
return flat.kthvalue(k).values
def metrics(got, expected):
diff = (got.float() - expected.float()).abs().flatten()
return float(percentile(diff, 0.99).item()), float(
torch.nn.functional.cosine_similarity(got.float().flatten(), expected.float().flatten(), dim=0).item()
)
def run_one(ops, name: str, shape: tuple[int, int, int, int], args) -> Result:
batch, seq_len, heads, head_dim = shape
packed, q_w, k_w, freqs_re, freqs_im, q_out, k_out = make_case(*shape)
eps = args.eps
got_q, got_k = ops.qkv_split_norm_rope_bf16(
packed, q_w, k_w, freqs_re, freqs_im, heads, head_dim, seq_len, eps, q_out, k_out
)
exp_q, exp_k = torch_ref(packed, q_w, k_w, freqs_re, freqs_im, heads, head_dim, eps)
q_p99, q_cos = metrics(got_q, exp_q)
k_p99, k_cos = metrics(got_k, exp_k)
flashrt_us = time_us(
lambda: ops.qkv_split_norm_rope_bf16(
packed, q_w, k_w, freqs_re, freqs_im, heads, head_dim, seq_len, eps, q_out, k_out
),
args.warmup,
args.iters,
)
eager_us = time_us(
lambda: torch_ref(packed, q_w, k_w, freqs_re, freqs_im, heads, head_dim, eps),
args.warmup,
args.iters,
)
status = "PASS" if q_p99 <= args.p99_abs_limit and k_p99 <= args.p99_abs_limit else "FAIL"
return Result(
shape=name,
batch=batch,
seq_len=seq_len,
heads=heads,
head_dim=head_dim,
flashrt_us=flashrt_us,
torch_eager_us=eager_us,
speedup_vs_eager=eager_us / flashrt_us,
q_p99_abs=q_p99,
k_p99_abs=k_p99,
q_cosine=q_cos,
k_cosine=k_cos,
status=status,
)
def run_joint3(ops, name: str, video_len: int, action_len: int, und_len: int, heads: int, head_dim: int, args) -> Result:
case = make_joint3_case(video_len, action_len, und_len, heads, head_dim)
(
packed_v,
qkv_v_bias,
v_q_w,
v_k_w,
freqs_re,
freqs_im,
packed_a,
a_q_w,
a_k_w,
packed_u,
u_q_w,
u_k_w,
q_cat,
k_cat,
v_cat,
) = case
eps = args.eps
got_q, got_k, _ = ops.qkv_split_joint3_cat_bf16(
packed_v,
qkv_v_bias,
v_q_w,
v_k_w,
freqs_re,
freqs_im,
packed_a,
a_q_w,
a_k_w,
packed_u,
u_q_w,
u_k_w,
heads,
head_dim,
q_cat,
k_cat,
v_cat,
video_len,
eps,
eps,
eps,
)
exp_q, exp_k, _ = torch_ref_joint3(
packed_v,
qkv_v_bias,
v_q_w,
v_k_w,
freqs_re,
freqs_im,
packed_a,
a_q_w,
a_k_w,
packed_u,
u_q_w,
u_k_w,
heads,
head_dim,
eps,
)
q_p99, q_cos = metrics(got_q, exp_q)
k_p99, k_cos = metrics(got_k, exp_k)
flashrt_us = time_us(
lambda: ops.qkv_split_joint3_cat_bf16(
packed_v,
qkv_v_bias,
v_q_w,
v_k_w,
freqs_re,
freqs_im,
packed_a,
a_q_w,
a_k_w,
packed_u,
u_q_w,
u_k_w,
heads,
head_dim,
q_cat,
k_cat,
v_cat,
video_len,
eps,
eps,
eps,
),
args.warmup,
args.iters,
)
eager_us = time_us(
lambda: torch_ref_joint3(
packed_v,
qkv_v_bias,
v_q_w,
v_k_w,
freqs_re,
freqs_im,
packed_a,
a_q_w,
a_k_w,
packed_u,
u_q_w,
u_k_w,
heads,
head_dim,
eps,
),
args.warmup,
args.iters,
)
status = "PASS" if q_p99 <= args.p99_abs_limit and k_p99 <= args.p99_abs_limit else "FAIL"
return Result(
shape=name,
batch=1,
seq_len=video_len + action_len + und_len,
heads=heads,
head_dim=head_dim,
flashrt_us=flashrt_us,
torch_eager_us=eager_us,
speedup_vs_eager=eager_us / flashrt_us,
q_p99_abs=q_p99,
k_p99_abs=k_p99,
q_cosine=q_cos,
k_cosine=k_cos,
status=status,
)
def run_decode_q(ops, name: str, heads: int, args) -> Result:
q, _, _, q_w, _, cos, sin = make_decode_case(heads)
q_out = torch.empty_like(q)
eps = args.eps
got = ops.decode_q_norm_rope_stage_bf16(q, q_w, cos, sin, eps, q_out)
exp = torch_ref_decode(q, q_w, cos, sin, eps)
q_p99, q_cos = metrics(got, exp)
flashrt_us = time_us(
lambda: ops.decode_q_norm_rope_stage_bf16(q, q_w, cos, sin, eps, q_out),
args.warmup,
args.iters,
)
eager_us = time_us(lambda: torch_ref_decode(q, q_w, cos, sin, eps), args.warmup, args.iters)
status = "PASS" if q_p99 <= args.p99_abs_limit else "FAIL"
return Result(
shape=name,
batch=1,
seq_len=1,
heads=heads,
head_dim=128,
flashrt_us=flashrt_us,
torch_eager_us=eager_us,
speedup_vs_eager=eager_us / flashrt_us,
q_p99_abs=q_p99,
k_p99_abs=0.0,
q_cosine=q_cos,
k_cosine=1.0,
status=status,
)
def run_decode_kv(ops, name: str, heads: int, devpos: bool, args) -> Result:
_, k, v, _, k_w, cos, sin = make_decode_case(heads)
k_slot = torch.empty_like(k)
v_slot = torch.empty_like(v)
eps = args.eps
exp_k = torch_ref_decode(k, k_w, cos, sin, eps)
if devpos:
pos = 3
k_cache = torch.empty((8, heads, 128), device="cuda", dtype=torch.bfloat16)
v_cache = torch.empty_like(k_cache)
cur_pos = torch.tensor([pos], device="cuda", dtype=torch.int32)
def flashrt_fn():
return ops.decode_k_norm_rope_kvwrite_devpos_bf16(k, v, k_w, cos, sin, cur_pos, k_cache, v_cache, eps)
def eager_fn():
k_cache[pos].copy_(torch_ref_decode(k, k_w, cos, sin, eps))
v_cache[pos].copy_(v)
return k_cache, v_cache
flashrt_fn()
got_k = k_cache[pos]
got_v = v_cache[pos]
else:
def flashrt_fn():
return ops.decode_k_norm_rope_kvwrite_bf16(k, v, k_w, cos, sin, eps, k_slot, v_slot)
def eager_fn():
k_slot.copy_(torch_ref_decode(k, k_w, cos, sin, eps))
v_slot.copy_(v)
return k_slot, v_slot
got_k, got_v = flashrt_fn()
k_p99, k_cos = metrics(got_k, exp_k)
v_p99, v_cos = metrics(got_v, v)
flashrt_us = time_us(flashrt_fn, args.warmup, args.iters)
eager_us = time_us(eager_fn, args.warmup, args.iters)
status = "PASS" if k_p99 <= args.p99_abs_limit and v_p99 == 0.0 else "FAIL"
return Result(
shape=name,
batch=1,
seq_len=1,
heads=heads,
head_dim=128,
flashrt_us=flashrt_us,
torch_eager_us=eager_us,
speedup_vs_eager=eager_us / flashrt_us,
q_p99_abs=v_p99,
k_p99_abs=k_p99,
q_cosine=v_cos,
k_cosine=k_cos,
status=status,
)
def run_kvcache_gqa(
ops,
name: str,
batch: int,
seq_len: int,
q_heads: int,
kv_heads: int,
head_dim: int,
args,
) -> Result:
qkv_dim = (q_heads + 2 * kv_heads) * head_dim
packed = torch.randn((batch, seq_len, qkv_dim), device="cuda", dtype=torch.bfloat16)
rope = make_interleaved_rope(seq_len, head_dim)
cache_offset = 2
max_seq_len = cache_offset + seq_len + 2
q_out = torch.empty((batch, seq_len, q_heads, head_dim), device="cuda", dtype=torch.bfloat16)
k_cache = torch.empty((batch, max_seq_len, kv_heads, head_dim), device="cuda", dtype=torch.bfloat16)
v_cache = torch.empty_like(k_cache)
got_q, got_k_cache, got_v_cache = ops.qkv_split_rope_kvcache_bf16(
packed,
rope,
q_heads,
kv_heads,
head_dim,
cache_offset,
q_out,
k_cache,
v_cache,
)
exp_q, exp_k, exp_v = torch_ref_kvcache(packed, rope, q_heads, kv_heads, head_dim)
sl = slice(cache_offset, cache_offset + seq_len)
q_p99, q_cos = metrics(got_q, exp_q)
k_p99, k_cos = metrics(got_k_cache[:, sl], exp_k)
v_p99, v_cos = metrics(got_v_cache[:, sl], exp_v)
def flashrt_fn():
return ops.qkv_split_rope_kvcache_bf16(
packed,
rope,
q_heads,
kv_heads,
head_dim,
cache_offset,
q_out,
k_cache,
v_cache,
)
def eager_fn():
exp_q_local, exp_k_local, exp_v_local = torch_ref_kvcache(packed, rope, q_heads, kv_heads, head_dim)
q_out.copy_(exp_q_local)
k_cache[:, sl].copy_(exp_k_local)
v_cache[:, sl].copy_(exp_v_local)
return q_out, k_cache, v_cache
flashrt_us = time_us(flashrt_fn, args.warmup, args.iters)
eager_us = time_us(eager_fn, args.warmup, args.iters)
status = (
"PASS"
if q_p99 <= args.p99_abs_limit and k_p99 <= args.p99_abs_limit and v_p99 == 0.0
else "FAIL"
)
return Result(
shape=name,
batch=batch,
seq_len=seq_len,
heads=q_heads,
head_dim=head_dim,
flashrt_us=flashrt_us,
torch_eager_us=eager_us,
speedup_vs_eager=eager_us / flashrt_us,
q_p99_abs=max(q_p99, v_p99),
k_p99_abs=k_p99,
q_cosine=min(q_cos, v_cos),
k_cosine=k_cos,
status=status,
)
def write_markdown(path: Path, results: list[Result]) -> None:
lines = [
"| Shape | B,L,H,D | FlashRT us | Eager us | vs eager | Q p99 | K p99 | Q cosine | K cosine | Status |",
"|---|---:|---:|---:|---:|---:|---:|---:|---:|---|",
]
for r in results:
lines.append(
f"| {r.shape} | {r.batch},{r.seq_len},{r.heads},{r.head_dim} | "
f"{r.flashrt_us:.3f} | {r.torch_eager_us:.3f} | {r.speedup_vs_eager:.2f}x | "
f"{r.q_p99_abs:.6f} | {r.k_p99_abs:.6f} | {r.q_cosine:.8f} | "
f"{r.k_cosine:.8f} | {r.status} |"
)
path.write_text("\n".join(lines) + "\n")
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--backend", choices=["source", "installed"], default="source")
parser.add_argument("--artifact", default=None)
parser.add_argument("--shapes", choices=sorted(SHAPE_GROUPS), default="smoke")
parser.add_argument("--warmup", type=int, default=5)
parser.add_argument("--iters", type=int, default=20)
parser.add_argument("--eps", type=float, default=1e-6)
parser.add_argument("--p99-abs-limit", type=float, default=0.015625)
parser.add_argument("--output", default=None)
parser.add_argument("--markdown", default=None)
args = parser.parse_args()
if not torch.cuda.is_available():
raise SystemExit("CUDA is required")
torch.manual_seed(37)
ops = load_source_ops() if args.backend == "source" else load_installed_ops(args.artifact)
results = [run_one(ops, name, SHAPES[name], args) for name in SHAPE_GROUPS[args.shapes]]
if args.shapes in ("smoke", "all"):
results.append(run_joint3(ops, "joint3_small", 64, 8, 4, 8, 128, args))
results.append(run_kvcache_gqa(ops, "pi05_decoder_gqa_kvcache", 1, 10, 8, 1, 256, args))
if args.shapes in ("headline", "all"):
results.append(run_joint3(ops, "joint3_vla", 2520, 16, 16, 24, 128, args))
results.append(run_decode_q(ops, "decode_q_stage_h24", 24, args))
results.append(run_decode_kv(ops, "decode_kvwrite_h8", 8, False, args))
results.append(run_decode_kv(ops, "decode_kvwrite_devpos_h8", 8, True, args))
if args.shapes == "headline":
results.append(run_kvcache_gqa(ops, "pi05_decoder_gqa_kvcache", 1, 10, 8, 1, 256, args))
for r in results:
print(
f"{r.status} {r.shape}: flashrt={r.flashrt_us:.3f}us "
f"eager={r.torch_eager_us:.3f}us speedup={r.speedup_vs_eager:.2f}x "
f"q_p99={r.q_p99_abs:.6f} k_p99={r.k_p99_abs:.6f}"
)
if args.output:
Path(args.output).parent.mkdir(parents=True, exist_ok=True)
Path(args.output).write_text(json.dumps([asdict(r) for r in results], indent=2) + "\n")
if args.markdown:
Path(args.markdown).parent.mkdir(parents=True, exist_ok=True)
write_markdown(Path(args.markdown), results)
if any(r.status != "PASS" for r in results):
raise SystemExit(1)
if __name__ == "__main__":
main()