Buckets:

Mercity/FluxDistill / scripts /23_nvfp4_kernel_bench.py
Pranav2748's picture
download
raw
4.53 kB
"""Real low-bit kernel speed for klein-4B's Linear shapes on this Blackwell card.
Our 4B SVDQuant models are fake-quant (bf16) — they measure quality only. To get a REAL
speed number we run the actual compiled kernels at klein-4B's exact Linear shapes and rebuild
the per-step transformer latency:
- bf16 : torch nn.Linear (baseline)
- NVFP4 W4A4 : nunchaku SVDQW4A4Linear(precision='nvfp4') — the real fused FP4 GEMM
+ low-rank branch (this is what Exp3's W4A4 model deploys as)
- FP8 (W·A both e4m3) : torch._scaled_mm — proxy for the W4+FP8-act variant's residual compute
(Exp4); + the bf16 low-rank branch
Low-rank branch rank is swept (r64, r128) to show the real 'rank tax'. Timing is per-shape,
weighted by how many such Linears klein-4B has, summed to one denoise-step latency.
Usage: PYTHONPATH=. python3 scripts/23_nvfp4_kernel_bench.py [T] (T = token count, default 1536)
"""
import os, sys, time, json
import torch
from nunchaku.models.linear import SVDQW4A4Linear
# Target device and the dtype used for every bf16 tensor/module below.
dev, dt = "cuda", torch.bfloat16
# Tokens per denoise step, optionally overridden by argv[1].
# Default 1536: at 512px this is 512 text + 1024 image tokens.
T = 1536 if len(sys.argv) <= 1 else int(sys.argv[1])
# klein-4B's 100 quantized Linears as ((in_features, out_features), how many share that shape).
SHAPES = [
    ((3072, 27648), 20),
    ((12288, 3072), 20),
    ((3072, 18432), 10),
    ((3072, 3072), 40),
    ((9216, 3072), 10),
]
print(f"=== real low-bit kernel bench | {torch.cuda.get_device_name(0)} | T={T} tokens ===")
def bench(fn, iters=100, warm=20):
    """Return the mean wall-clock seconds per call of ``fn()``.

    ``warm`` untimed warm-up calls run first, then ``iters`` timed calls. Every call runs
    under ``torch.no_grad()`` so autograd bookkeeping (e.g. the trainable weights of the
    bf16 ``nn.Linear`` baseline would otherwise record a backward graph) is not counted as
    inference latency. When CUDA is available the device is synchronized before the clock
    starts and before it stops, so asynchronously launched kernels are fully included;
    without CUDA the plain host timing is returned.

    Args:
        fn: zero-argument callable to time.
        iters: number of timed calls; must be >= 1.
        warm: number of untimed warm-up calls.

    Returns:
        Mean seconds per timed call (float).

    Raises:
        ValueError: if ``iters`` < 1 (the mean would divide by zero).
    """
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")
    # Kernel launches are asynchronous: wait for the GPU at both clock edges.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
    with torch.no_grad():
        for _ in range(warm):
            fn()
        sync()
        start = time.perf_counter()
        for _ in range(iters):
            fn()
        sync()
        return (time.perf_counter() - start) / iters
def bf16_fn(i, o):
    """Return a zero-arg closure running a bf16 ``nn.Linear(i, o, bias=True)``.

    This is the unquantized baseline. The input is a (1, T, i) bf16 random tensor on ``dev``.
    """
    layer = torch.nn.Linear(i, o, bias=True).to(dev, dt)
    inp = torch.randn(1, T, i, device=dev, dtype=dt)

    def run():
        return layer(inp)

    return run
def nvfp4_fn(i, o, r):
    """Return a zero-arg closure running nunchaku's SVDQW4A4Linear in NVFP4 mode.

    The layer maps i -> o with a bias and a rank-``r`` low-rank branch, in bf16 I/O.
    Its parameters are overwritten with arbitrary finite values because only speed is
    measured here, not output quality. The input is a (1, T, i) bf16 random tensor on ``dev``.
    """
    layer = SVDQW4A4Linear(i, o, rank=r, bias=True, precision="nvfp4",
                           torch_dtype=dt, device=dev)
    with torch.no_grad():
        # Random integers in [-128, 126] (random_'s upper bound is exclusive).
        # NOTE(review): presumably packed 4-bit codes, so any byte pattern is a valid
        # timing input — confirm against nunchaku's layout.
        layer.qweight.random_(-128, 127)
        layer.wscales.copy_(torch.ones_like(layer.wscales))
        for smooth in (layer.smooth_factor, layer.smooth_factor_orig):
            smooth.fill_(1.0)
        for proj in (layer.proj_down, layer.proj_up):
            proj.normal_(0, 0.02)
        layer.wcscales.fill_(1.0)
    inp = torch.randn(1, T, i, device=dev, dtype=dt)

    def run():
        return layer(inp)

    return run
def fp8_fn(i, o, r):
    """Return a zero-arg closure timing an FP8 (e4m3 x e4m3) GEMM plus a bf16 rank-``r`` branch.

    This is a proxy for the W4 + FP8-activation variant's compute. The main i -> o product
    runs through ``torch._scaled_mm`` with unit per-tensor scales and bf16 output. The
    low-rank correction ``(x @ D) @ U.T`` is then accumulated into that output in bf16.

    Args:
        i: in_features.
        o: out_features.
        r: rank of the low-rank branch.

    Returns:
        Zero-argument callable returning the (T, o) bf16 result.
    """
    W = torch.randn(o, i, device=dev).to(torch.float8_e4m3fn)  # (N,K) row-major; W.t() = (K,N) col-major
    sA = torch.ones((), device=dev, dtype=torch.float32)
    sB = torch.ones((), device=dev, dtype=torch.float32)
    D = torch.randn(i, r, device=dev, dtype=dt)
    U = torch.randn(o, r, device=dev, dtype=dt)
    x = torch.randn(T, i, device=dev, dtype=dt)

    def f():
        # NOTE(review): a plain cast with unit scale. The amax/scale computation that real
        # dynamic activation quantization needs is not included in this proxy's timing.
        xf = x.to(torch.float8_e4m3fn)
        y = torch._scaled_mm(xf, W.t(), scale_a=sA, scale_b=sB, out_dtype=dt)
        # bf16 low-rank branch, accumulated in place. addmm_ computes
        # y = y + (x @ D) @ U.T as one matmul with beta=1. This avoids materialising a
        # second (T, o) tensor plus a separate elementwise add over the full output, which
        # inflated the measured rank tax.
        y.addmm_(x @ D, U.t())
        return y

    return f
# Benchmark registry: method name -> factory(in_features, out_features) returning a
# zero-arg closure to time. Insertion order fixes the column order of the printed report.
methods = dict(
    bf16=lambda i, o: bf16_fn(i, o),
    nvfp4_w4a4_r64=lambda i, o: nvfp4_fn(i, o, 64),
    nvfp4_w4a4_r128=lambda i, o: nvfp4_fn(i, o, 128),
    fp8_r64=lambda i, o: fp8_fn(i, o, 64),
    fp8_r128=lambda i, o: fp8_fn(i, o, 128),
)
# Smoke test: build and run each method once on the smallest shape (3072 -> 3072).
# Any method whose kernel fails here is dropped, so the full sweep only times methods
# that actually run on this GPU/software stack.
for name in list(methods):
    try:
        probe = methods[name](3072, 3072)
        probe()
        torch.cuda.synchronize()
    except Exception as e:  # deliberate best-effort: report and skip, don't abort
        print(f" [skip {name}] {type(e).__name__}: {str(e)[:120]}")
        del methods[name]
# Time every surviving method on every klein-4B Linear shape. Per-step totals weight each
# shape's latency by how many Linears of that shape the model has.
totals = dict.fromkeys(methods, 0.0)
per_shape = {}
for (d_in, d_out), n_layers in SHAPES:
    timings = {}
    for meth, factory in methods.items():
        timings[meth] = bench(factory(d_in, d_out)) * 1e6  # seconds -> microseconds
        totals[meth] += timings[meth] * n_layers
    per_shape[f"{d_in}x{d_out}x{n_layers}"] = {k: round(v, 1) for k, v in timings.items()}
    cells = " ".join(f"{meth_name}={timings[meth_name]:6.1f}us" for meth_name in methods)
    print(f" {d_in:>5}->{d_out:<6} x{n_layers:<2}: " + cells)
# ---- Report: per-step latency and speedup relative to the bf16 baseline ----
def _speedup(base_total, total):
    """Return ``base_total / total``, or None when no meaningful ratio exists.

    The ratio is undefined when the bf16 baseline was dropped by the smoke test
    (``base_total`` is None) or when a total is zero (e.g. an empty SHAPES list). Returning
    None avoids reporting a ratio against a made-up baseline.
    """
    if not base_total or not total:
        return None
    return base_total / total


n_linears = sum(cnt for _, cnt in SHAPES)  # 100 for klein-4B's table
print(f"\n================ per denoise-step transformer (sum of {n_linears} Linears) ================")
base = totals.get("bf16")  # None when the bf16 baseline was skipped
speedups = {m: _speedup(base, totals[m]) for m in methods}
for name in methods:
    ms = totals[name] / 1e3  # microseconds -> milliseconds
    rel = "n/a" if speedups[name] is None else f"{speedups[name]:.2f}x"
    print(f" {name:18s}: {ms:7.2f} ms/step ({rel} vs bf16)")
os.makedirs("outputs/nvfp4", exist_ok=True)
report = {
    "T": T,
    "per_step_ms": {m: round(totals[m] / 1e3, 3) for m in methods},
    "speedup_vs_bf16": {m: (None if s is None else round(s, 3)) for m, s in speedups.items()},
    "per_shape_us": per_shape,
}
# `with` guarantees the file is flushed and closed even if json.dump raises.
with open("outputs/nvfp4/kernel_speed.json", "w", encoding="utf-8") as fh:
    json.dump(report, fh, indent=2)
print("saved -> outputs/nvfp4/kernel_speed.json")

Xet Storage Details

Size:
4.53 kB
·
Xet hash:
eca5560d2fc774c8f5ed75924f2fb31801fdf73c9e8caabea213b911e026cee9

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.