Buckets:
"""Real low-bit kernel speed for klein-4B's Linear shapes on this Blackwell card.

Our 4B SVDQuant models are fake-quant (bf16) — they measure quality only. To get a REAL
speed number we run the actual compiled kernels at klein-4B's exact Linear shapes and rebuild
the per-step transformer latency:
  - bf16                : torch nn.Linear (baseline)
  - NVFP4 W4A4          : nunchaku SVDQW4A4Linear(precision='nvfp4') — the real fused FP4 GEMM
                          + low-rank branch (this is what Exp3's W4A4 model deploys as)
  - FP8 (W·A both e4m3) : torch._scaled_mm — proxy for the W4+FP8-act variant's residual compute
                          (Exp4); + the bf16 low-rank branch
Low-rank branch rank is swept (r64, r128) to show the real 'rank tax'. Timing is per-shape,
weighted by how many such Linears klein-4B has, summed to one denoise-step latency.

Usage: PYTHONPATH=. python3 scripts/23_nvfp4_kernel_bench.py [T]   (T = token count, default 1536)
"""
import json
import os
import sys
import time

import torch
from nunchaku.models.linear import SVDQW4A4Linear
# Device and dtype shared by every tensor/module built by the *_fn factories below.
dev = "cuda"
dt = torch.bfloat16

# Token count per forward pass; taken from the first CLI argument when present.
T = int(sys.argv[1]) if len(sys.argv) > 1 else 1536  # 512px = 512 txt + 1024 img tokens

# (in_features, out_features): count — klein-4B's 100 quantized Linears
SHAPES = [
    ((3072, 27648), 20),
    ((12288, 3072), 20),
    ((3072, 18432), 10),
    ((3072, 3072), 40),
    ((9216, 3072), 10),
]

print(f"=== real low-bit kernel bench | {torch.cuda.get_device_name(0)} | T={T} tokens ===")
def bench(fn, iters=100, warm=20):
    """Return the mean wall-clock time, in seconds, of one call to ``fn``.

    ``fn`` is called ``warm`` times untimed, then ``iters`` times between two
    ``time.perf_counter()`` readings. When CUDA is available the device is
    synchronized before the first reading and before the second one, so the
    measured interval includes the queued GPU work rather than only the
    launch overhead.

    Args:
        fn: Zero-argument callable to time.
        iters: Number of timed calls; must be at least 1.
        warm: Number of untimed warm-up calls.

    Raises:
        ValueError: If ``iters`` is smaller than 1 (the mean would be undefined).
    """
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    # Resolve the sync function up front so the availability check is not timed.
    # Without CUDA there is nothing to wait for, and torch.cuda.synchronize()
    # would raise instead of returning.
    if torch.cuda.is_available():
        sync = torch.cuda.synchronize
    else:
        def sync():
            return None

    for _ in range(warm):
        fn()
    sync()

    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()  # inside the timed region on purpose: wait for the last launch to finish
    return (time.perf_counter() - start) / iters
def bf16_fn(i, o):
    """Build the baseline: a ``torch.nn.Linear(i, o)`` forward pass.

    The layer (with bias) and a random input of shape ``(1, T, i)`` are created
    once on the module-level device ``dev`` in dtype ``dt``; only the forward
    call happens inside the returned closure.

    Args:
        i: Input feature count.
        o: Output feature count.

    Returns:
        A zero-argument callable that runs ``lin(x)``.
    """
    lin = torch.nn.Linear(i, o, bias=True).to(dev, dt)
    x = torch.randn(1, T, i, device=dev, dtype=dt)
    return lambda: lin(x)
def nvfp4_fn(i, o, r):
    """Build a nunchaku ``SVDQW4A4Linear`` (precision ``"nvfp4"``) forward pass.

    The module's buffers are overwritten in place with synthetic values
    (random packed weights, unit scales, small-normal low-rank projections).
    NOTE(review): presumably the values are placeholders because only kernel
    latency is measured, not output quality — confirm.

    Args:
        i: Input feature count.
        o: Output feature count.
        r: Rank passed to the module's low-rank branch.

    Returns:
        A zero-argument callable that runs the module on a fixed random input
        of shape ``(1, T, i)``.
    """
    m = SVDQW4A4Linear(i, o, rank=r, bias=True, precision="nvfp4", torch_dtype=dt, device=dev)
    with torch.no_grad():
        # In-place initialisation of every tensor the forward pass reads.
        m.qweight.random_(-128, 127)
        m.wscales.copy_(torch.ones_like(m.wscales))
        m.smooth_factor.fill_(1.0)
        m.smooth_factor_orig.fill_(1.0)
        m.proj_down.normal_(0, 0.02)
        m.proj_up.normal_(0, 0.02)
        m.wcscales.fill_(1.0)
    x = torch.randn(1, T, i, device=dev, dtype=dt)
    return lambda: m(x)
def fp8_fn(i, o, r):
    """Build an FP8 (e4m3) GEMM plus bf16 low-rank branch forward pass.

    The weight is cast to ``float8_e4m3fn`` once, outside the closure. The
    activation cast happens inside the closure, so its cost is part of every
    timed call, together with the ``torch._scaled_mm`` GEMM and the rank-``r``
    bf16 correction ``(x @ D) @ U.t()``.

    Args:
        i: Input feature count (GEMM K dimension).
        o: Output feature count (GEMM N dimension).
        r: Rank of the low-rank branch.

    Returns:
        A zero-argument callable producing a ``(T, o)`` tensor in dtype ``dt``.
    """
    # (N,K) row-major; W.t() = (K,N) col-major
    W = torch.randn(o, i, device=dev).to(torch.float8_e4m3fn)
    # Unit per-tensor scales for both operands, as float32 scalars.
    sA = torch.ones((), device=dev, dtype=torch.float32)
    sB = torch.ones((), device=dev, dtype=torch.float32)
    # Low-rank factors: D projects down (i -> r), U projects up (r -> o).
    D = torch.randn(i, r, device=dev, dtype=dt)
    U = torch.randn(o, r, device=dev, dtype=dt)
    x = torch.randn(T, i, device=dev, dtype=dt)

    def f():
        xf = x.to(torch.float8_e4m3fn)
        y = torch._scaled_mm(xf, W.t(), scale_a=sA, scale_b=sB, out_dtype=dt)
        y = y + (x @ D) @ U.t()  # bf16 low-rank branch
        return y

    return f
# Benchmarked variants: name -> factory(in_features, out_features) that returns a
# zero-argument callable running one forward pass.
methods = {
    "bf16": lambda i, o: bf16_fn(i, o),
    "nvfp4_w4a4_r64": lambda i, o: nvfp4_fn(i, o, 64),
    "nvfp4_w4a4_r128": lambda i, o: nvfp4_fn(i, o, 128),
    "fp8_r64": lambda i, o: fp8_fn(i, o, 64),
    "fp8_r128": lambda i, o: fp8_fn(i, o, 128),
}

# smoke: confirm each method runs on the smallest shape before the full sweep
for name, mk in list(methods.items()):  # iterate over a copy; failing entries are removed
    try:
        f = mk(3072, 3072)
        f()
        torch.cuda.synchronize()  # surface asynchronous kernel errors here, not in the sweep
    except Exception as e:
        # Deliberately broad: any failure just drops this variant from the sweep.
        print(f" [skip {name}] {type(e).__name__}: {str(e)[:120]}")
        methods.pop(name)

# Sweep every shape with every surviving method.
totals = {m: 0.0 for m in methods}  # microseconds per step, weighted by layer count
per_shape = {}
for (i, o), cnt in SHAPES:
    row = {}
    for name, mk in methods.items():
        us = bench(mk(i, o)) * 1e6  # bench() returns seconds per call
        row[name] = us
        totals[name] += us * cnt
    per_shape[f"{i}x{o}x{cnt}"] = {k: round(v, 1) for k, v in row.items()}
    print(f" {i:>5}->{o:<6} x{cnt:<2}: " + " ".join(f"{n}={row[n]:6.1f}us" for n in methods))

print("\n================ per denoise-step transformer (sum of 100 Linears) ================")
# If the bf16 baseline was skipped, the 1.0 fallback keeps the division defined,
# but the printed and saved "speedup" figures are then not meaningful.
base = totals.get("bf16", 1.0)
for name in methods:
    ms = totals[name] / 1e3
    print(f" {name:18s}: {ms:7.2f} ms/step ({base/totals[name]:.2f}x vs bf16)")

os.makedirs("outputs/nvfp4", exist_ok=True)
results = {
    "T": T,
    "per_step_ms": {m: round(totals[m] / 1e3, 3) for m in methods},
    "speedup_vs_bf16": {m: round(base / totals[m], 3) for m in methods},
    "per_shape_us": per_shape,
}
# Context manager so the file is flushed and closed even if serialisation fails.
with open("outputs/nvfp4/kernel_speed.json", "w") as fh:
    json.dump(results, fh, indent=2)
print("saved -> outputs/nvfp4/kernel_speed.json")
Xet Storage Details
- Size:
- 4.53 kB
- Xet hash:
- eca5560d2fc774c8f5ed75924f2fb31801fdf73c9e8caabea213b911e026cee9
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.