# Speedup ideation: NVFP4-fused klein-4B on Blackwell (sm_120)

Where we are: the fused NVFP4 model runs at **~1.9× end-to-end / 2.3× per-step** vs bf16, with −26% VRAM.
The per-step transformer (2.3×) is already near the per-layer GEMM ceiling (2.5×). The gap to "more"
is NOT the GEMM; it's everything else. This doc is a prioritized menu for the next session. We are
happy to depart from Nunchaku's literature where it pays.

## 0. Where the time actually goes (measure first, always)

From the data: at 512px/4-step the bf16 step is ~92 ms and the fused step ~42 ms. End-to-end fused is 0.264 s, of which
4×42 = 168 ms is denoise, so **~96 ms (36%) is NON-transformer: VAE decode + text-encode + latent prep +
launch overhead.** That fixed overhead is the single biggest cap on end-to-end speedup for the small 4B.
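
The arithmetic above can be checked, and extended to a few what-if cases, with a small sketch. The latency numbers are the ones quoted in this section; the function names and the model itself (overhead stays constant when the step count or per-step time changes) are illustrative assumptions, not measurements.

```python
# Sketch: fixed-overhead arithmetic for the 512px / 4-step fused run.
# Inputs come from this section; the "overhead stays constant" model is an assumption.

E2E_MS = 264.0    # fused end-to-end latency
STEP_MS = 42.0    # fused per-step transformer time
STEPS = 4


def fixed_overhead_ms(e2e_ms=E2E_MS, step_ms=STEP_MS, steps=STEPS):
    """Non-transformer time: whatever is left after the denoise loop."""
    return e2e_ms - steps * step_ms


def projected_e2e_ms(step_speedup=1.0, steps=STEPS, overhead_scale=1.0):
    """Projected latency if each step gets `step_speedup`x faster, the step
    count changes, or the fixed overhead is scaled by `overhead_scale`."""
    overhead = fixed_overhead_ms() * overhead_scale
    return overhead + steps * STEP_MS / step_speedup


overhead = fixed_overhead_ms()
print(f"fixed overhead: {overhead:.0f} ms ({overhead / E2E_MS:.0%} of end-to-end)")
# Ceiling on further end-to-end gain if the transformer cost nothing at all.
print(f"ceiling from transformer-only work: {E2E_MS / overhead:.2f}x")
print(f"3-step run: {projected_e2e_ms(steps=3):.0f} ms")
print(f"halved overhead: {projected_e2e_ms(overhead_scale=0.5):.0f} ms")
```

Under this simple model, transformer-only work cannot improve on the current fused run by more than 264/96 = 2.75×, which is the argument for attacking the overhead first.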

**Action 0:** instrument a per-stage timeline (text-encode | per-step{qkv|attn|ff} | vae-decode | python
glue) with CUDA events. Every idea below should be justified against this breakdown.
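
As a starting point for Action 0, here is a minimal sketch of the bookkeeping only. It uses host wall-clock time with an optional synchronize hook, not CUDA events; `StageTimer`, the stage names, and the sleep-based workload are all hypothetical. On the GPU the timestamps would more likely come from pairs of `torch.cuda.Event(enable_timing=True)` read back with `elapsed_time`, which avoids a device sync around every stage.

```python
import time
from collections import defaultdict
from contextlib import contextmanager


class StageTimer:
    """Accumulates wall-clock time per named stage.

    `sync` is an optional callable run before each timestamp; on a GPU it would
    be a device synchronize so that queued kernels are included (assumption).
    """

    def __init__(self, sync=None):
        self.sync = sync or (lambda: None)
        self.totals_ms = defaultdict(float)
        self.counts = defaultdict(int)

    @contextmanager
    def stage(self, name):
        self.sync()
        start = time.perf_counter()
        try:
            yield
        finally:
            self.sync()
            self.totals_ms[name] += (time.perf_counter() - start) * 1e3
            self.counts[name] += 1

    def report(self):
        """Return {stage: (total_ms, share_of_total)}, most expensive first."""
        total = sum(self.totals_ms.values()) or 1.0
        rows = sorted(self.totals_ms.items(), key=lambda kv: -kv[1])
        return {name: (ms, ms / total) for name, ms in rows}


# Stand-in workload: sleeps play the role of pipeline stages.
timer = StageTimer()
with timer.stage("text-encode"):
    time.sleep(0.002)
for _ in range(4):
    with timer.stage("step/attn"):
        time.sleep(0.001)
with timer.stage("vae-decode"):
    time.sleep(0.003)

for name, (ms, share) in timer.report().items():
    print(f"{name:12s} {ms:7.2f} ms  {share:5.1%}  (n={timer.counts[name]})")
```

Keeping the per-stage totals and call counts separate makes it easy to report both the per-step cost and the per-image share from one run.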

## Tier 1: high impact, low risk (do these first)

1. **CUDA Graphs over the denoise loop.** The 4B's GEMMs are small, so the path is partly *launch-bound*
   (dozens of tiny kernel launches × 4 steps). Capture the per-step transformer (or the whole 4-step
   loop) into a CUDA graph to eliminate launch latency and CPU dispatch. Biggest, cheapest win at small
   batch/res. Expect 10–30% end-to-end. Risk: shapes are frozen at capture time (fine here, since res is fixed).
2. **Quantize the text encoder (Qwen3).** It's bf16 now and a real chunk of the fixed overhead. Nunchaku
   ships `svdq-int4-Qwen3-text` (→ NVFP4 on Blackwell). Alternatively, cache prompt embeddings across a
   batch/session (encode-once), which cuts the text-encode slice to near-zero on repeat prompts. Expect 5–15% end-to-end.
3. **Quantize / fuse the VAE decoder.** Conv-heavy, bf16 now. Try FP8 convs on Blackwell, or tile+fuse the
   decoder. The VAE is a fixed per-image cost whose share grows as the transformer gets faster.
4. **Fix the fused path's batch=1 limitation.** The packed-rotary asserts `rot.shape[0]*rot.shape[1]==M`;
   make the rotary broadcast over batch (replicate per-batch, or index by token-mod-N). This unlocks batched
   throughput, which is critical for serving (right now bf16 b4 beats fused b1 on img/s at some sizes). Expect
   large throughput gains at batch>1.
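
The two options in item 4 (replicate per-batch vs index by token-mod-N) can be sketched in plain Python. This assumes the flattened GEMM rows are laid out batch-major, so row `m` belongs to token `m % N`; the helper names and the placeholder table are hypothetical and do not reflect the actual kernel interface.

```python
# Sketch: make a per-token rotary table usable for a flattened (batch * tokens) GEMM.
# `rot` has one row per token (N rows); the GEMM sees M = batch * N rows.
# Assumption: rows are batch-major, i.e. row m = b * N + t.

def rotary_row_for(m, n_tokens):
    """Token-mod-N lookup: GEMM row m reuses rotary row m % N."""
    return m % n_tokens


def replicate_rotary(rot, batch):
    """Explicit alternative: tile the table so it has batch * N rows."""
    return [row for _ in range(batch) for row in rot]


def check_rows(rot_rows, m_rows, n_tokens):
    """Relaxed form of the batch=1 assert: M must be a whole number of sequences."""
    if rot_rows != n_tokens or m_rows % n_tokens != 0:
        raise ValueError(f"M={m_rows} is not a multiple of N={n_tokens}")
    return m_rows // n_tokens  # inferred batch size


N, BATCH = 4, 3
rot = [(f"cos{t}", f"sin{t}") for t in range(N)]   # placeholder per-token entries
M = BATCH * N

tiled = replicate_rotary(rot, BATCH)
# Both strategies must hand every GEMM row the same rotary entry.
same = all(tiled[m] == rot[rotary_row_for(m, N)] for m in range(M))
print("rows:", M, "batch:", check_rows(len(rot), M, N), "strategies agree:", same)
```

The modulo lookup avoids materializing `batch` copies of the table, while replication keeps the kernel's indexing unchanged; which is cheaper depends on where the table lives in the fused kernel.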

## Tier 2: kernel-level (medium effort; we build Nunchaku from source now)

5. **Output-quantization fusion (act-quant of layer N+1 in layer N's epilogue).** The kernel already has
   a `qout` output ("packed quantized input for the next layer"). If the FLUX2 path doesn't use it, each
   Linear pays a separate `quantize_w4a4_act_fuse_lora` pass. Fusing the next layer's act-quant into the
   current GEMM's epilogue removes a full activation read/write per layer. Expect 5–15% per-step.
6. **GEMM tile / warp-config autotune for sm_120.** The kernel hardcodes `warp_n=128`, comp tiles, etc.
   GB20x has a specific SM count / register file / L2; sweep `warp_n ∈ {64,128,256}`, K-unroll, and the
   CTA tiling per layer shape (our 5 shapes: 3072², 3072→27648, 12288→3072, 3072→18432, 9216→3072).
   A small-shape vs large-shape split likely wins. Expect 5–20% on the GEMM.
7. **TMA (Tensor Memory Accelerator) loads.** If the kernel still uses cp.async/manual loads, switch the
   weight/activation/scale loads to TMA bulk async copies, which overlap load with MMA and free
   registers. Multicast depends on thread-block clusters, so confirm it is actually available on sm_120
   before designing around it. Expect 5–15% on memory-bound (small-M) layers.
8. **Rank Pareto + mixed rank.** r128 costs ~11% more per-layer than r64 and improves the quality metric
   from 0.0364 to 0.0303. Map the quality/speed frontier; use **per-layer rank** (high rank only on the
   few hard layers; we have per-layer wrecon). This cuts the low-rank tax where it doesn't buy quality.
   Expect 3–8% per-step at iso-quality.
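
Mapping the frontier in item 8 amounts to dropping every configuration that is beaten on both cost and error. A minimal sketch follows, assuming the quoted 0.0303 / 0.0364 figures are an error-style metric (lower is better) and normalizing cost to r64 = 1.00. The third row is an invented placeholder, included only to show a dominated point being filtered out.

```python
# Sketch: keep only rank configs that are not dominated on (cost, error).
# r64 / r128 use the numbers quoted in item 8 (cost normalized to r64 = 1.00).
# The "mixed-placeholder" row is made up for illustration, NOT a measurement.

def pareto_frontier(points):
    """points: list of (name, cost, error); lower is better on both axes."""
    frontier = []
    best_error = float("inf")
    for name, cost, error in sorted(points, key=lambda p: (p[1], p[2])):
        if error < best_error:  # strictly better than every cheaper config
            frontier.append((name, cost, error))
            best_error = error
    return frontier


candidates = [
    ("r64", 1.00, 0.0364),
    ("r128", 1.11, 0.0303),
    ("mixed-placeholder", 1.12, 0.0340),   # hypothetical, dominated by r128
]

frontier = pareto_frontier(candidates)
for name, cost, error in frontier:
    print(f"{name:18s} cost={cost:.2f} error={error:.4f}")
```

A per-layer rank assignment would add one candidate per mix, so the same filter applies once each mix has a measured cost and error.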

## Tier 3: Blackwell-architecture-specific (higher effort, biggest ceiling)

9. **`tcgen05` 5th-gen MMA + Tensor Memory (TMEM).** Confirm the FP4 GEMM uses Blackwell's `tcgen05.mma`
   with TMEM accumulation and **hardware block-scaling** (`mma.kind::mxf4nvf4` / NVFP4 microscale in the
   MMA itself), not a Hopper-style `wgmma` path. Check the target first: `tcgen05` and TMEM are documented
   for the datacenter Blackwell parts (sm_100 family), and on sm_120 the hardware block-scaled path may
   instead be the `mma.sync` form with `kind::mxf4nvf4` block scaling, in which case that is the path to
   confirm. If the kernel is on an older path with emulated scaling, porting to the native block-scaled MMA
   is the single biggest kernel lever (the 5th-gen cores roughly double FP4 throughput vs emulated scaling).
10. **FP8 attention (the real attention lever).** Attention is `attention_fp16` today and dominates at
    high res (O(N²), bf16/fp16). Move QKᵀ and the AV matmul to **FP8** (E4M3) à la FlashAttention-3 /
    SageAttention on Blackwell, for roughly 2× faster attention. At 1024px this is likely the #1
    end-to-end win. Keep softmax in fp32. Risk: quality on the score matrix; validate on the probes.
11. **FP4 for the AV / V-projection.** Softmax-weighted V is more robust than QKᵀ; FP4 the V path while
    keeping scores FP8. Squeezes attention further. Quality-gated.
12. **FP6 (E3M2) as a quality/speed knob.** Blackwell does FP6 natively. For the hardest layers (or the
    activations), FP6 weights/acts are expected to give bf16-ish quality at ~1.5×, a middle rung between
    NVFP4 and bf16.

## Tier 4: numerical / format (cheap to try, quality-gated)

13. **Real per-channel `wcscales`.** We set wcscales=1. A true per-output-channel scale tightens the
    NVFP4 weight quant, giving the same quality at **lower rank** and so less low-rank tax (compounds with #8).
14. **Activation group size.** NVFP4 acts are group-16. The micro-scale loads have overhead; test g32
    (MXFP4-style) for the activation path where quality allows, for fewer scale loads.
15. **Smoothing/rotation for the residual.** We use no-smooth (right for W4A8). For W4A4-NVFP4 a light
    Hadamard rotation of the residual can shrink outliers, giving a smaller residual and more energy in the
    cheap low-rank branch, or lower wrecon at fixed rank.
16. **Fewer steps.** Orthogonal but huge: the model is 4-step distilled; test 3-step / 2-step quality.
    Dropping to 3 or 2 steps removes 25% or 50% of denoise time. Pure pipeline lever, no kernel work.
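
To make the outlier argument in item 15 concrete, here is a toy sketch of an orthonormal Hadamard rotation applied to one row with a single large channel. The row values are invented for illustration. In a real pipeline the inverse rotation would have to be folded into the adjacent weights so the layer output is unchanged, which this sketch does not attempt.

```python
import math


def hadamard_rotate(x):
    """Orthonormal fast Walsh-Hadamard transform; len(x) must be a power of two."""
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    y = list(x)
    h = 1
    while h < n:
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = y[i], y[i + h]
                y[i], y[i + h] = a + b, a - b   # butterfly
        h *= 2
    scale = 1.0 / math.sqrt(n)                  # keeps the L2 norm unchanged
    return [v * scale for v in y]


def l2(v):
    return math.sqrt(sum(t * t for t in v))


# Toy row (invented values): small entries plus one large outlier channel.
row = [0.1, -0.2, 0.05, 8.0, 0.0, 0.15, -0.1, 0.2]
rotated = hadamard_rotate(row)

print("max |x| before:", max(abs(v) for v in row))
print("max |x| after: ", round(max(abs(v) for v in rotated), 3))
print("norm before/after:", round(l2(row), 4), round(l2(rotated), 4))
```

The rotation spreads the outlier's energy across all channels while preserving the norm, so a shared 4-bit scale per group wastes less of its range on one extreme value.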

## Measurement plan (record for the report)

For each idea, record: (a) the per-stage CUDA-event breakdown before/after, (b) per-step and end-to-end at
512/1024 × batch{1,2,4}, (c) VRAM, (d) an image-quality check on the 4 probes (text + hand are the
sensitive ones), (e) the velocity-loss if it changes the math. Always compare against the bf16 baseline AND the
current NVFP4-fused baseline. Keep the matrix in `outputs/nvfp4/benchmark.json`.
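
One possible shape for a row of that matrix is sketched below. The field names and the `append_record` helper are hypothetical; only the file path and the 512px numbers come from this doc. Unmeasured fields are left as `None` rather than filled in, and the example writes to a temporary directory instead of the real file.

```python
import json
import tempfile
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Optional


@dataclass
class BenchRecord:
    idea: str                 # e.g. "cuda-graphs"
    baseline: str             # "bf16" or "nvfp4-fused"
    resolution: int           # 512 or 1024
    batch: int                # 1, 2 or 4
    step_ms: float
    e2e_ms: float
    stage_ms: dict            # per-stage breakdown, item (a)
    vram_gb: Optional[float] = None
    probes_ok: Optional[bool] = None        # image-quality check on the 4 probes
    velocity_loss: Optional[float] = None   # only when the math changes


def append_record(path, record):
    """Append one record to a JSON list on disk and return the new row count."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    rows = json.loads(path.read_text()) if path.exists() else []
    rows.append(asdict(record))
    path.write_text(json.dumps(rows, indent=2))
    return len(rows)


with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp) / "outputs" / "nvfp4" / "benchmark.json"
    rec = BenchRecord(
        idea="current-fused", baseline="nvfp4-fused", resolution=512, batch=1,
        step_ms=42.0, e2e_ms=264.0,
        stage_ms={"denoise": 168.0, "non-transformer": 96.0},
    )
    n = append_record(out, rec)
    saved = json.loads(out.read_text())

print(n, saved[0]["idea"], saved[0]["stage_ms"])
```

Storing one flat record per (idea, baseline, resolution, batch) cell keeps before/after comparisons to a simple filter over the list.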

## Rough priority (impact × 1/effort)

First: CUDA graphs (1) · TE/VAE quant (2,3) · batch fix (4) · output-fusion (5).
Then FP8 attention (10) and tcgen05 verification (9) for the high-res ceiling.
Then tile autotune (6), TMA (7), rank Pareto (8), wcscales (13).