# Speedup ideation — NVFP4-fused klein-4B on Blackwell (sm_120)
Where we are: the fused NVFP4 model runs at **~1.9× end-to-end / 2.3× per-step** vs bf16, with 26% less
VRAM. The per-step transformer speedup (2.3×) is already near the per-layer GEMM ceiling (2.5×), so the
remaining headroom is NOT in the GEMM; it is in everything else. This doc is a prioritized menu for the
next session. We are happy to depart from Nunchaku's literature where it pays.
## 0. Where the time actually goes (measure first, always)
From the data: at 512px/4-step the bf16 step is ~92 ms and the fused step ~42 ms. The fused run takes
0.264 s end-to-end, of which 4 × 42 = 168 ms is denoise, so **~96 ms (36%) is NON-transformer: VAE decode,
text-encode, latent prep, and launch overhead.** That fixed overhead is the single biggest cap on
end-to-end speedup for the small 4B.
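
The time budget above can be reproduced, and extended to what-if scenarios, in a few lines of Python. This is a minimal sketch that uses only the 512px/4-step numbers quoted in this section; the helper name `end_to_end_ms` and the what-if speedup factors are illustrative assumptions, not measurements.

```python
# Time budget at 512px / 4 steps, using the numbers quoted above (milliseconds).
TOTAL_MS = 264.0          # fused end-to-end
STEP_MS = 42.0            # fused per-step transformer time
STEPS = 4

denoise_ms = STEPS * STEP_MS            # 168 ms
fixed_ms = TOTAL_MS - denoise_ms        # ~96 ms of non-transformer work
fixed_share = fixed_ms / TOTAL_MS       # ~0.36


def end_to_end_ms(denoise_speedup: float = 1.0, fixed_speedup: float = 1.0) -> float:
    """Amdahl-style estimate: scale each slice by its own (hypothetical) speedup."""
    return denoise_ms / denoise_speedup + fixed_ms / fixed_speedup


# Ceiling if the transformer cost nothing at all: only the fixed slice remains.
ceiling = TOTAL_MS / fixed_ms           # ~2.75x over today's fused run

if __name__ == "__main__":
    print(f"fixed overhead: {fixed_ms:.0f} ms ({fixed_share:.0%})")
    print(f"max extra speedup from the transformer alone: {ceiling:.2f}x")
    # Hypothetical what-if: halve the fixed overhead, leave denoise untouched.
    print(f"fixed overhead halved: {end_to_end_ms(fixed_speedup=2.0):.0f} ms")
```

The ceiling line is the useful one: even an infinitely fast transformer would leave the ~96 ms fixed slice, which is why the fixed-overhead items rank so high for this model size.
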
**Action 0:** instrument a per-stage timeline (text-encode | per-step{qkv|attn|ff} | vae-decode | python
glue) with CUDA events. Every idea below should be justified against this breakdown.
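
One possible shape for the Action 0 timeline is a small context-manager timer. The sketch below is an assumption about how it could be wired up, not the project's actual tooling: `StageTimer` and the stage names are hypothetical, and the `time.sleep` calls stand in for real pipeline stages. It uses `torch.cuda.Event` when a GPU is available and falls back to `time.perf_counter` otherwise, so it also runs on a CPU-only machine.

```python
import time
from collections import defaultdict
from contextlib import contextmanager

try:  # torch is optional here so the sketch also runs without a GPU stack
    import torch
    HAVE_CUDA = torch.cuda.is_available()
except ImportError:
    torch = None
    HAVE_CUDA = False


class StageTimer:
    """Accumulates per-stage time; uses CUDA events when a GPU is present."""

    def __init__(self):
        self.totals_ms = defaultdict(float)
        self._pending = []  # (name, start_event, end_event), resolved in summary()

    @contextmanager
    def stage(self, name):
        if HAVE_CUDA:
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            try:
                yield
            finally:
                end.record()
                self._pending.append((name, start, end))
        else:
            t0 = time.perf_counter()
            try:
                yield
            finally:
                self.totals_ms[name] += (time.perf_counter() - t0) * 1e3

    def summary(self):
        if HAVE_CUDA:
            torch.cuda.synchronize()  # events must complete before elapsed_time()
            for name, start, end in self._pending:
                self.totals_ms[name] += start.elapsed_time(end)
            self._pending.clear()
        return dict(self.totals_ms)


if __name__ == "__main__":
    timer = StageTimer()
    with timer.stage("text-encode"):
        time.sleep(0.002)          # stand-in for the real text encoder call
    for _ in range(4):
        with timer.stage("step/attn"):
            time.sleep(0.001)      # stand-in for one denoise step's attention
    with timer.stage("vae-decode"):
        time.sleep(0.002)
    print(timer.summary())
```

Deferring `elapsed_time` to `summary()` avoids a synchronize inside every stage, which would itself distort the launch-overhead picture. The "python glue" slice can then be estimated as wall-clock total minus the sum of the GPU stages.
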
## Tier 1 — high impact, low risk (do these first)
1. **CUDA Graphs over the denoise loop.** The 4B's GEMMs are small, so the path is partly *launch-bound*
(dozens of tiny kernel launches × 4 steps). Capturing the per-step transformer (or the whole 4-step
loop) into a CUDA graph eliminates launch latency and CPU dispatch. Biggest, cheapest win at small
batch/res. Expect 10–30% end-to-end. Risk: shapes are frozen at capture time, which is fine here because
the resolution is fixed.
2. **Quantize the text encoder (Qwen3).** It's bf16 now and a real chunk of the fixed overhead. Nunchaku
ships `svdq-int4-Qwen3-text` (→ NVFP4 on Blackwell). Alternatively, cache prompt embeddings across a
batch/session (encode once), which cuts the text-encode slice to near zero on repeated prompts. Expect
5–15% end-to-end.
3. **Quantize / fuse the VAE decoder.** Conv-heavy, bf16 now. FP8 convs on Blackwell, or tile+fuse the
decoder. The VAE is a fixed per-image cost that dominates more as the transformer gets faster.
4. **Fix the fused path's batch=1 limitation.** The packed-rotary path asserts
`rot.shape[0]*rot.shape[1]==M`; make the rotary broadcast over batch (replicate per batch, or index by
token-mod-N). This unlocks batched throughput, which is critical for serving (right now bf16 at batch 4
beats fused at batch 1 on img/s at some sizes). Expect large throughput gains at batch>1.
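
The "index by token-mod-N" option in item 4 can be illustrated with plain Python. This is a sketch under stated assumptions: the flattened activation rows are taken to be batch-major (all N tokens of image 0, then image 1, and so on), the rotary table is treated as one row per token, and the helper names `rotary_row_index` and `gather_rotary` are hypothetical. The real packed-rotary layout in the kernel may differ.

```python
from typing import List, Sequence


def rotary_row_index(m: int, tokens_per_image: int) -> int:
    """Map a flattened row m in [0, B*N) to its rotary row in [0, N)."""
    return m % tokens_per_image


def gather_rotary(rot: Sequence[Sequence[float]], batch: int) -> List[Sequence[float]]:
    """Replicate-per-batch view: same result as indexing by token-mod-N."""
    n = len(rot)
    return [rot[rotary_row_index(m, n)] for m in range(batch * n)]


if __name__ == "__main__":
    # Hypothetical tiny table: N = 3 tokens, 2 rotary values per token.
    rot = [(1.0, 0.0), (0.5, 0.5), (0.0, 1.0)]
    rows = gather_rotary(rot, batch=2)      # M = 2 * 3 = 6 rows
    assert len(rows) == 6
    assert rows[4] == rot[1]                # row 4 is token 1 of image 1
    print(rows)
```

Indexing inside the kernel avoids materializing B copies of the table, while replication keeps the kernel unchanged at the cost of extra memory traffic; both give the same values per row.
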
## Tier 2 — kernel-level (medium effort, we build Nunchaku from source now)
5. **Output-quantization fusion (act-quant of layer N+1 in layer N's epilogue).** The kernel already has
a `qout` output ("packed quantized input for the next layer"). If the FLUX2 path doesn't use it, each
Linear pays a separate `quantize_w4a4_act_fuse_lora` pass. Fusing the next layer's act-quant into the
current GEMM's epilogue removes a full activation read/write per layer. Expect 5–15% per-step.
6. **GEMM tile / warp-config autotune for sm_120.** The kernel hardcodes `warp_n=128`, comp tiles, etc.
GB20x has a specific SM count / register file / L2; sweep `warp_n ∈ {64,128,256}`, K-unroll, and the
CTA tiling per layer shape (our 5 shapes: 3072², 3072→27648, 12288→3072, 3072→18432, 9216→3072).
A small-shape vs large-shape split likely wins. Expect 5–20% on the GEMM.
7. **TMA (Tensor Memory Accelerator) loads.** If the kernel still uses cp.async/manual loads, switch the
weight/activation/scale loads to TMA bulk async copies with multicast — overlaps load with MMA, frees
registers. Blackwell-native. Expect 5–15% on memory-bound (small-M) layers.
8. **Rank Pareto + mixed rank.** r128 costs ~11% more per layer than r64 and buys 0.0303 vs 0.0364 on the
quality metric. Map the quality/speed frontier; use **per-layer rank** (high rank only on the few hard
layers; we have per-layer wrecon). This cuts the low-rank tax where it doesn't buy quality. Expect 3–8%
per-step at iso-quality.
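
The per-shape sweep in item 6 could be driven by a small harness like the one below. It is a sketch only: the five shapes and the `warp_n` candidates come from the text above, but the `k_unroll` values, the `autotune` helper, and especially `fake_bench_ms` are hypothetical. In practice `bench_ms` would build and time the real kernel for each configuration.

```python
from itertools import product
from typing import Callable, Dict, Tuple

Shape = Tuple[int, int]        # (K, N) of a Linear layer
Config = Tuple[int, int]       # (warp_n, k_unroll); k_unroll values are hypothetical

# The five layer shapes listed above, as (in_features, out_features).
SHAPES = [(3072, 3072), (3072, 27648), (12288, 3072), (3072, 18432), (9216, 3072)]
WARP_N = (64, 128, 256)
K_UNROLL = (1, 2, 4)


def autotune(bench_ms: Callable[[Shape, Config], float]) -> Dict[Shape, Config]:
    """Return the fastest config per shape; bench_ms is the caller's timing hook."""
    best = {}
    for shape in SHAPES:
        candidates = product(WARP_N, K_UNROLL)
        best[shape] = min(candidates, key=lambda cfg: bench_ms(shape, cfg))
    return best


def fake_bench_ms(shape: Shape, cfg: Config) -> float:
    """Toy cost model for demonstration only; not a measurement."""
    k, n = shape
    warp_n, k_unroll = cfg
    # Pretend wide-N layers like wide tiles and deep-K layers like more unrolling.
    tile_penalty = abs(warp_n - (256 if n > 8192 else 64)) / 64
    unroll_penalty = abs(k_unroll - (4 if k > 8192 else 1))
    return 1.0 + tile_penalty + unroll_penalty


if __name__ == "__main__":
    for shape, cfg in autotune(fake_bench_ms).items():
        print(shape, "->", cfg)
```

Keeping the result as a per-shape table is what makes a small-shape vs large-shape split possible: the dispatcher can look up the config by (K, N) at load time.
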
## Tier 3 — Blackwell-architecture-specific (higher effort, biggest ceiling)
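9. **`tcgen05` 5th-gen MMA + Tensor Memory (TMEM).** Confirm the FP4 GEMM uses Blackwell's **hardware
block-scaling** (NVFP4 microscale applied inside the MMA itself, `kind::mxf4nvf4`) rather than scaling
emulated in software. Caveat to check against the PTX ISA docs before planning a port: `tcgen05.mma` and
TMEM are listed for the sm_100-class parts, whereas on sm_120 the block-scaled FP4 path is exposed through
`mma.sync` with `.kind::mxf4nvf4.block_scale`, and `wgmma` is Hopper-only (sm_90a). If the kernel is on
an emulated-scaling path, moving it to the hardware block-scaled MMA is the single biggest kernel lever
(the 5th-gen cores roughly double FP4 throughput vs emulated scaling).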
10. **FP8 attention (the real attention lever).** Attention is `attention_fp16` today and dominates at
high res (O(N²), bf16/fp16). Move QKᵀ and the AV matmul to **FP8** (E4M3) à la FlashAttention-3 /
SageAttention on Blackwell for roughly 2× faster attention. At 1024px this is likely the #1 end-to-end
win. Keep softmax in fp32. Risk: quality loss on the score matrix, so validate on the probes.
11. **FP4 for the AV / V-projection.** Softmax-weighted V is more robust than QKᵀ; FP4 the V path while
keeping scores FP8. Squeezes attention further. Quality-gated.
12. **FP6 (E3M2) as a quality/speed knob.** Blackwell does FP6 natively. For the hardest layers (or the
activations), FP6 weights/acts give bf16-ish quality at ~1.5× — a middle rung between NVFP4 and bf16.
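
Before touching a kernel, the score-matrix risk in item 10 can be probed numerically. The sketch below fake-quantizes Q and K to FP8 E4M3 with a single per-tensor scale and compares the softmax output against the unquantized reference. It is a simplified assumption-laden model: real FP8 attention kernels typically use finer-grained scaling, the toy sizes and random inputs are not the model's, and all helper names are hypothetical.

```python
import math
import random

E4M3_MAX = 448.0          # largest finite value of the FP8 E4M3 (fn) format
E4M3_MIN_EXP = -6         # exponent of the smallest normal number


def fake_quant_e4m3(x: float) -> float:
    """Round x to the nearest E4M3-representable value (saturating)."""
    if x == 0.0:
        return 0.0
    mag = min(abs(x), E4M3_MAX)
    exp = max(math.frexp(mag)[1] - 1, E4M3_MIN_EXP)   # floor(log2(mag)), clamped
    step = 2.0 ** (exp - 3)                           # 3 mantissa bits
    q = min(round(mag / step) * step, E4M3_MAX)
    return math.copysign(q, x)


def quantize_tensor(rows):
    """Per-tensor scaling so the largest magnitude maps to E4M3_MAX."""
    amax = max(abs(v) for row in rows for v in row) or 1.0
    scale = amax / E4M3_MAX
    return [[fake_quant_e4m3(v / scale) for v in row] for row in rows], scale


def scores(q_rows, k_rows):
    """Plain QK^T in Python floats (stands in for higher-precision accumulation)."""
    return [[sum(a * b for a, b in zip(q, k)) for k in k_rows] for q in q_rows]


def softmax(row):
    m = max(row)
    e = [math.exp(v - m) for v in row]
    s = sum(e)
    return [v / s for v in e]


if __name__ == "__main__":
    rng = random.Random(0)
    n, d = 8, 16                                   # toy sizes, not the model's
    q = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
    k = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
    ref = [softmax([s / math.sqrt(d) for s in row]) for row in scores(q, k)]
    (q8, sq), (k8, sk) = quantize_tensor(q), quantize_tensor(k)
    approx = [softmax([s * sq * sk / math.sqrt(d) for s in row])
              for row in scores(q8, k8)]
    err = max(abs(a - b) for ra, rb in zip(ref, approx) for a, b in zip(ra, rb))
    print(f"max |softmax diff| with E4M3 Q/K: {err:.4f}")
```

Running the same comparison on captured Q/K tensors from the probes, rather than Gaussian noise, would give a cheap first signal on whether the quality gate is likely to pass.
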
## Tier 4 — numerical / format (cheap to try, quality-gated)
13. **Real per-channel `wcscales`.** We set wcscales=1. A true per-output-channel scale tightens the
NVFP4 weight quant → same quality at **lower rank** → less low-rank tax (compounds with #8).
14. **Activation group size.** NVFP4 acts are group-16. The micro-scale loads have overhead; test g32
(MXFP4-style) for the activation path where quality allows — fewer scale loads.
15. **Smoothing/rotation for the residual.** We use no-smooth (right for W4A8). For W4A4-NVFP4 a light
Hadamard rotation of the residual can shrink outliers → smaller residual → more energy into the
cheap low-rank branch, or lower wrecon at fixed rank.
16. **Fewer steps.** Orthogonal but huge: the model is 4-step distilled; test 3-step / 2-step quality.
Dropping to 3 or 2 steps removes 25% or 50% of denoise time, respectively. Pure pipeline lever, no kernel
work.
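
The group-size trade-off in item 14 can be seen on a toy row: larger groups mean fewer scales to load, but one outlier then sets the scale for more neighbours. The sketch below is a simplified model, not the kernel's quantizer. It uses the FP4 E2M1 magnitude grid, keeps the per-group scale in full float precision (real formats store it in a narrow type), breaks rounding ties toward the smaller magnitude, and all helper names are hypothetical.

```python
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)   # FP4 E2M1 magnitudes


def quant_dequant_group(values):
    """Quantize one group to FP4 with a shared scale, then dequantize."""
    amax = max(abs(v) for v in values)
    if amax == 0.0:
        return list(values)
    scale = amax / E2M1_GRID[-1]      # kept in float here; an assumption of this sketch
    out = []
    for v in values:
        mag = min(E2M1_GRID, key=lambda g: abs(abs(v) / scale - g))
        out.append(mag * scale if v >= 0 else -mag * scale)
    return out


def quant_dequant_row(row, group):
    """Group-wise FP4 quantization; returns (dequantized row, number of scales)."""
    out = []
    for i in range(0, len(row), group):
        out.extend(quant_dequant_group(row[i:i + group]))
    return out, (len(row) + group - 1) // group


def sq_error(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


if __name__ == "__main__":
    # One outlier followed by small values: the hard case for large groups.
    row = [6.0] + [0.1] * 31
    for group in (16, 32):
        deq, n_scales = quant_dequant_row(row, group)
        print(f"g{group}: {n_scales} scales, squared error {sq_error(row, deq):.3f}")
```

For a K = 3072 row this is 192 scales at g16 versus 96 at g32, so the load saving is real; the toy row shows the cost side, where the g32 group flushes every small value that shares a group with the outlier.
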
## Measurement plan (record for the report)
For each idea: (a) the per-stage CUDA-event breakdown before/after, (b) per-step & end-to-end at
512/1024 × batch{1,2,4}, (c) VRAM, (d) image-quality check on the 4 probes (text + hand are the
sensitive ones), (e) the velocity-loss if it changes the math. Always vs the bf16 baseline AND the
current NVFP4-fused baseline. Keep the matrix in `outputs/nvfp4/benchmark.json`.
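
One way to keep the matrix consistent across ideas is to generate the empty records up front and fill them in as runs complete. The sketch below is an assumed layout, not an existing schema: the field names, the `"candidate"` variant label, and the helpers `empty_matrix` and `save` are hypothetical. Only the resolutions, batch sizes, the (a) to (e) items, and the output path come from the plan above.

```python
import json
from itertools import product
from pathlib import Path

RESOLUTIONS = (512, 1024)
BATCHES = (1, 2, 4)
VARIANTS = ("bf16", "nvfp4-fused", "candidate")   # both baselines plus the idea under test


def empty_matrix(idea: str):
    """One record per (variant, resolution, batch); metrics start as None."""
    return [
        {
            "idea": idea,
            "variant": variant,
            "resolution": res,
            "batch": batch,
            "stage_ms": None,        # (a) per-stage CUDA-event breakdown
            "step_ms": None,         # (b) per-step time
            "end_to_end_ms": None,   # (b) end-to-end time
            "vram_gb": None,         # (c)
            "probe_quality": None,   # (d) result on the 4 probes
            "velocity_loss": None,   # (e) only if the idea changes the math
        }
        for variant, res, batch in product(VARIANTS, RESOLUTIONS, BATCHES)
    ]


def save(records, path="outputs/nvfp4/benchmark.json"):
    """Write the records as JSON, creating parent directories if needed."""
    out = Path(path)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(records, indent=2))
    return out


if __name__ == "__main__":
    records = empty_matrix("cuda-graphs")
    print(len(records), "cells to fill")
```

Leaving unmeasured cells as `None` rather than omitting them makes gaps in the matrix visible when the report is assembled.
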
## Rough priority (impact × 1/effort)
CUDA graphs (1) · TE/VAE quant (2,3) · batch fix (4) · output-fusion (5) — first.
Then FP8 attention (10) and tcgen05 verification (9) for the high-res ceiling.
Then tile autotune (6), TMA (7), rank Pareto (8), wcscales (13).
