
Speedup ideation — NVFP4-fused klein-4B on Blackwell (sm_120)

Where we are: the fused NVFP4 model runs ~1.9× faster end-to-end and ~2.3× faster per step than bf16, with 26% less VRAM. The per-step speedup (2.3×) is already close to the per-layer GEMM ceiling (2.5×), so the remaining headroom is NOT in the GEMM; it is in everything else. This doc is a prioritized menu for the next session. We are happy to depart from Nunchaku's published approach where that pays off.
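A quick Amdahl-style check of "the gap is NOT the GEMM", plugging in the 512px/4-step figures from section 0 (92 ms bf16 step, ~42 ms fused step, 0.264 s fused end-to-end). It is a sketch under two simplifying assumptions: the non-transformer overhead stays constant, and a step can reach exactly the 2.5× per-layer GEMM ceiling.

```python
# Amdahl-style budget for the 512px / 4-step pipeline.
# Inputs are the measurements quoted in this doc; everything else is arithmetic.
# Assumption: the non-transformer overhead is constant and unaffected by GEMM work.

STEPS = 4
BF16_STEP_MS = 92.0
FUSED_STEP_MS = 42.0
FUSED_E2E_MS = 264.0
GEMM_CEILING = 2.5  # per-layer NVFP4 vs bf16 GEMM speedup ceiling


def fixed_overhead_ms(e2e_ms: float, step_ms: float, steps: int) -> float:
    """Everything outside the denoise loop (VAE, text encode, glue, ...)."""
    return e2e_ms - steps * step_ms


def e2e_ms(overhead_ms: float, step_ms: float, steps: int) -> float:
    """End-to-end latency given a fixed overhead and a per-step cost."""
    return overhead_ms + steps * step_ms


overhead = fixed_overhead_ms(FUSED_E2E_MS, FUSED_STEP_MS, STEPS)
# Best case for kernel work alone: every step runs at the GEMM ceiling.
at_ceiling = e2e_ms(overhead, BF16_STEP_MS / GEMM_CEILING, STEPS)
# Upper bound if the denoise loop cost nothing at all.
denoise_free_speedup = FUSED_E2E_MS / overhead

print(f"fixed overhead: {overhead:.0f} ms ({overhead / FUSED_E2E_MS:.0%} of e2e)")
print(f"e2e at GEMM ceiling: {at_ceiling:.1f} ms "
      f"({FUSED_E2E_MS / at_ceiling:.2f}x over today's fused run)")
print(f"ceiling with a zero-cost denoise loop: {denoise_free_speedup:.2f}x")
```

Under these assumptions, pushing every step to the GEMM ceiling only takes the fused run from 264 ms to about 243 ms (roughly 1.09×), while the fixed overhead alone caps any transformer-only work at 2.75× over today's fused pipeline.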

0. Where the time actually goes (measure first, always)

From the data: at 512px and 4 steps, a bf16 step takes 92 ms and a fused step ~42 ms. The fused pipeline takes 0.264 s end-to-end, of which 4 × 42 = 168 ms is denoising, so **96 ms (36%) is NON-transformer: VAE decode + text-encode + latent prep + launch overhead.** That fixed overhead is the single biggest cap on end-to-end speedup for the small 4B. Action 0: instrument a per-stage timeline (text-encode | per-step{qkv|attn|ff} | vae-decode | python glue) with CUDA events. Justify every idea below against this breakdown.

Tier 1 — high impact, low risk (do these first)

  1. CUDA Graphs over the denoise loop. The 4B's GEMMs are small, so the path is partly launch-bound (dozens of tiny kernel launches per step, × 4 steps). Capture the per-step transformer (or the whole 4-step loop) in a CUDA graph to eliminate launch latency and CPU dispatch. This is the biggest, cheapest win at small batch/resolution. Expect 10–30% end-to-end. Risk: shapes are fixed at capture time, which is fine because resolution is fixed; inputs must be copied into the captured static buffers, and each (resolution, batch) pair needs its own graph.
  2. Quantize the text encoder (Qwen3). It is bf16 now and a real chunk of the fixed overhead. Nunchaku ships svdq-int4-Qwen3-text (or an NVFP4 equivalent on Blackwell). Alternatively, cache prompt embeddings across a batch or session (encode once), which cuts the text-encode slice to near zero for repeated prompts. Expect 5–15% end-to-end.
  3. Quantize / fuse the VAE decoder. It is conv-heavy and bf16 now. Options: FP8 convs on Blackwell, or tiling and fusing the decoder. The VAE is a fixed per-image cost whose share grows as the transformer gets faster.
  4. Fix the fused path's batch=1 limitation. The packed-rotary path asserts rot.shape[0]*rot.shape[1] == M, where M is the flattened row count (batch × tokens); make the rotary broadcast over batch (replicate it per batch element, or index by token mod N, with N the per-image token count). This unlocks batched throughput, which is critical for serving: right now bf16 at batch 4 beats fused at batch 1 on img/s at some sizes. Expect large throughput gains at batch>1.
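A sketch of the two rotary options from the batch=1 fix, assuming the activations are flattened batch-major to M = batch × N rows (each image's N tokens contiguous) and that the packed rotary table holds one entry per token of a single image. The function names are hypothetical; in the fused kernel the same modulo would be applied to the row index inside the rotary epilogue.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def rotary_for_rows(rot: Sequence[T], m_rows: int) -> List[T]:
    """Map each of the M flattened activation rows to its rotary entry.

    Assumes row r is token (r % N) of image (r // N), where N = len(rot).
    At batch 1 this is the identity (M == N), which is all the current assert
    allows; for batch > 1 the same table is reused for every image.
    """
    n_tokens = len(rot)
    if n_tokens == 0 or m_rows % n_tokens != 0:
        raise ValueError(f"M={m_rows} is not a whole number of {n_tokens}-token images")
    return [rot[r % n_tokens] for r in range(m_rows)]


def rotary_replicated(rot: Sequence[T], batch: int) -> List[T]:
    """The other option: materialize one copy of the table per batch element."""
    return list(rot) * batch


if __name__ == "__main__":
    rot = ["cs0", "cs1", "cs2"]  # stand-ins for per-token (cos, sin) entries
    rows = rotary_for_rows(rot, m_rows=6)  # 2 images x 3 tokens each
    print(rows)  # ['cs0', 'cs1', 'cs2', 'cs0', 'cs1', 'cs2']
    print(rows == rotary_replicated(rot, batch=2))  # True
```

Both give identical rows; the modulo version avoids storing and reading batch copies of the table, which is why it tends to suit a kernel better, while replication keeps the existing shape check trivially satisfiable.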

Tier 2 — kernel-level (medium effort, we build Nunchaku from source now)

  5. Output-quantization fusion (act-quant of layer N+1 in layer N's epilogue). The kernel already has a qout output ("packed quantized input for the next layer"). If the FLUX2 path doesn't use it, each Linear pays for a separate quantize_w4a4_act_fuse_lora pass. Fusing the next layer's act-quant into the current GEMM's epilogue removes a full activation read/write per layer. Expect 5–15% per-step.
  6. GEMM tile / warp-config autotune for sm_120. The kernel hardcodes warp_n=128, comp tiles, etc. GB20x has its own SM count, register-file size, and L2 size; sweep warp_n ∈ {64,128,256}, K-unroll, and the CTA tiling per layer shape (our 5 shapes: 3072→3072, 3072→27648, 12288→3072, 3072→18432, 9216→3072). Separate configs for small and large shapes likely win. Expect 5–20% on the GEMM.
  7. TMA (Tensor Memory Accelerator) loads. If the kernel still uses cp.async/manual loads, switch the weight/activation/scale loads to TMA bulk async copies (with multicast where thread-block clusters allow it). This overlaps loads with MMA and frees registers. TMA dates from Hopper (sm_90) and is available on Blackwell. Expect 5–15% on memory-bound (small-M) layers.
  8. Rank Pareto + mixed rank. r128 costs ~11% more than r64 per layer, for an error of 0.0303 vs 0.0364 (lower is better). Map the quality/speed frontier and use per-layer rank: high rank only on the few hard layers, which the per-layer wrecon identifies. This cuts the low-rank tax where it doesn't buy quality. Expect 3–8% per-step at iso-quality.
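One way to turn per-layer wrecon numbers into a mixed-rank plan is a greedy marginal-gain allocation: start every layer at its cheapest rank, then repeatedly buy the affordable upgrade with the largest error reduction per unit of added cost until the cost budget runs out. This is a standard knapsack heuristic, not something the current pipeline does, and it assumes per-layer errors are roughly additive. The layer names, costs, and errors in the example are made up; only the ~11% r64→r128 cost ratio comes from the rank Pareto item above.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass(frozen=True)
class LayerOption:
    rank: int
    cost: float    # relative per-layer time at this rank (hypothetical units)
    wrecon: float  # per-layer reconstruction error at this rank (lower is better)


def allocate_ranks(layers: Dict[str, List[LayerOption]],
                   budget: float) -> Tuple[Dict[str, int], float]:
    """Greedy mixed-rank plan under a total-cost budget.

    `layers` maps a layer name to its options sorted by increasing rank. Every
    layer starts at its cheapest option (returned even if that already exceeds
    the budget); then the affordable upgrade with the best error reduction per
    unit of extra cost is taken, one step at a time, until none remains.
    """
    choice = {name: 0 for name in layers}
    spent = sum(opts[0].cost for opts in layers.values())
    while True:
        best = None  # (gain, name, extra_cost)
        for name, opts in layers.items():
            i = choice[name]
            if i + 1 >= len(opts):
                continue  # already at the highest rank offered
            extra = opts[i + 1].cost - opts[i].cost
            saved = opts[i].wrecon - opts[i + 1].wrecon
            if saved <= 0 or spent + extra > budget:
                continue
            gain = saved / extra if extra > 0 else float("inf")
            if best is None or gain > best[0]:
                best = (gain, name, extra)
        if best is None:
            return {n: layers[n][i].rank for n, i in choice.items()}, spent
        _, name, extra = best
        choice[name] += 1
        spent += extra


if __name__ == "__main__":
    # Illustrative numbers only: r128 costs ~11% more than r64 per layer.
    layers = {
        "hard_layer": [LayerOption(64, 1.00, 0.100), LayerOption(128, 1.11, 0.060)],
        "easy_layer": [LayerOption(64, 1.00, 0.020), LayerOption(128, 1.11, 0.019)],
    }
    ranks, spent = allocate_ranks(layers, budget=2.15)
    print(ranks, round(spent, 2))  # {'hard_layer': 128, 'easy_layer': 64} 2.11
```

With room for only one upgrade, the budget goes to the layer where the extra rank buys the most error reduction, which is the "high rank only on the few hard layers" policy.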

Tier 3 — Blackwell-architecture-specific (higher effort, biggest ceiling)

  9. tcgen05 5th-gen MMA + Tensor Memory (TMEM). Confirm the FP4 GEMM uses Blackwell's native block-scaled MMA (mma.kind::mxf4nvf4, with the NVFP4 microscale applied inside the MMA itself), not a path that applies scales in software. Check which instruction family applies to our part first: in the PTX ISA, tcgen05.mma with TMEM accumulation targets datacenter Blackwell (sm_100 family), wgmma is Hopper-only (sm_90a), and sm_120 exposes block-scaled FP4 through warp-level mma.sync with kind::mxf4nvf4. If the kernel is on an emulated-scaling path, porting to the native block-scaled MMA is the single biggest kernel lever (native scaling roughly doubles FP4 throughput vs emulated scaling).
  10. FP8 attention (the real attention lever). Attention runs in attention_fp16 today and dominates at high resolution (O(N²) in tokens, bf16/fp16). Move QKᵀ and the AV matmul to FP8 (E4M3), as in FlashAttention-3 / SageAttention, for roughly 2× attention throughput on Blackwell. At 1024px this is likely the #1 end-to-end win. Keep the softmax in fp32. Risk: quality of the score matrix; validate on the probes.
  11. FP4 for the AV / V-projection. The softmax-weighted V product is more robust than QKᵀ, so use FP4 on the V path while keeping scores in FP8. This squeezes attention further. Quality-gated.
  12. FP6 (E3M2) as a quality/speed knob. Blackwell supports FP6 natively. For the hardest layers (or the activations), FP6 weights/acts give near-bf16 quality at ~1.5×, a middle rung between NVFP4 and bf16.
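Before touching the kernel, the FP8-attention quality risk can be screened numerically with fake quantization: round Q, K, and V to the E4M3 grid, then compute the scores and softmax in full (float64) precision and compare the attention output to an unquantized reference. This pure-Python sketch enumerates the E4M3 "FN" values used by FP8 tensor cores (bias 7, largest finite value 448, no infinities) and rounds to nearest, with ties going to the smaller magnitude rather than to even. It uses a single per-tensor scale, whereas FlashAttention-3 / SageAttention-style kernels use finer-grained scaling, so treat it as a rough screen, not a model of any kernel. The random inputs are synthetic; the real check is still the probe images.

```python
import math
import random


def e4m3_values():
    """All finite non-negative FP8 E4M3 (FN) values: bias 7, 3 mantissa bits."""
    vals = set()
    for e in range(16):
        for m in range(8):
            if e == 15 and m == 7:
                continue  # this bit pattern encodes NaN in E4M3FN
            if e == 0:
                vals.add((m / 8) * 2.0 ** -6)  # subnormals
            else:
                vals.add((1 + m / 8) * 2.0 ** (e - 7))
    return sorted(vals)


E4M3 = e4m3_values()
E4M3_MAX = E4M3[-1]  # 448.0


def round_e4m3(x):
    """Nearest E4M3 value, saturating at +-448 (ties go to the smaller magnitude)."""
    mag = min(abs(x), E4M3_MAX)
    nearest = min(E4M3, key=lambda v: abs(v - mag))
    return math.copysign(nearest, x)


def fake_quant(mat):
    """Per-tensor scale so max |x| maps to 448, round to E4M3, then dequantize."""
    amax = max(abs(v) for row in mat for v in row) or 1.0
    scale = amax / E4M3_MAX
    return [[round_e4m3(v / scale) * scale for v in row] for row in mat]


def attention(q, k, v):
    """softmax(Q K^T / sqrt(d)) V with the softmax in full precision."""
    d = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        mx = max(scores)  # subtract the max for a stable softmax
        w = [math.exp(s - mx) for s in scores]
        z = sum(w)
        out.append([sum(wj / z * vj[c] for wj, vj in zip(w, v)) for c in range(len(v[0]))])
    return out


def rel_error(ref, approx):
    """Frobenius-norm relative error between two matrices."""
    num = math.sqrt(sum((a - b) ** 2 for ra, rb in zip(ref, approx) for a, b in zip(ra, rb)))
    den = math.sqrt(sum(a * a for ra in ref for a in ra))
    return num / den


def rand_matrix(rng, rows, cols):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


if __name__ == "__main__":
    rng = random.Random(0)
    q, k, v = (rand_matrix(rng, 16, 32) for _ in range(3))
    ref = attention(q, k, v)
    fp8 = attention(fake_quant(q), fake_quant(k), fake_quant(v))
    print(f"relative output error with FP8 Q/K/V: {rel_error(ref, fp8):.3%}")
```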

Tier 4 — numerical / format (cheap to try, quality-gated)

  13. Real per-channel wcscales. We currently set wcscales=1. A true per-output-channel scale tightens the NVFP4 weight quantization, which gives the same quality at lower rank and therefore less low-rank tax (compounds with #8).
  14. Activation group size. NVFP4 activations use group-16, and the micro-scale loads have overhead. Where quality allows, test g32 (MXFP4-style) on the activation path, which halves the number of scale loads.
  15. Smoothing/rotation for the residual. We use no smoothing (right for W4A8). For W4A4-NVFP4, a light Hadamard rotation of the residual can shrink outliers, giving a smaller residual and so either more energy in the cheap low-rank branch or lower wrecon at fixed rank.
  16. Fewer steps. Orthogonal but huge: the model is 4-step distilled; test 3-step and 2-step quality. Going to 3 steps cuts denoise time by 25%, going to 2 steps by 50%. Pure pipeline lever, no kernel work.
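The group-size trade-off in the activation item can be illustrated with a fake-quant sketch of block-scaled FP4: each contiguous group of g values gets one scale (absmax / 6, so the group's largest element maps to E2M1's maximum of 6), and every element rounds to the E2M1 grid. Simplifications, stated plainly: scales are kept in full precision here, whereas real NVFP4 stores them as FP8 E4M3 (plus a per-tensor FP32 scale) and MXFP4 uses power-of-two E8M0 scales; rounding ties go to the smaller magnitude. The input vector is synthetic, built to put an outlier in one half.

```python
E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # non-negative FP4 (E2M1) values


def round_e2m1(x):
    """Nearest E2M1 value with the sign of x (x already scaled so |x| <= 6)."""
    mag = min(E2M1, key=lambda v: abs(v - abs(x)))
    return mag if x >= 0 else -mag


def block_fake_quant(xs, group):
    """Quantize-dequantize xs in contiguous groups, one absmax scale per group."""
    out, n_scales = [], 0
    for i in range(0, len(xs), group):
        block = xs[i:i + group]
        scale = max(abs(v) for v in block) / 6.0 or 1.0  # all-zero block: any scale works
        out.extend(round_e2m1(v / scale) * scale for v in block)
        n_scales += 1
    return out, n_scales


def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


if __name__ == "__main__":
    # 16 small activations, then one outlier and 15 zeros (synthetic data).
    xs = [0.1 * (i + 1) for i in range(16)] + [10.0] + [0.0] * 15
    for g in (16, 32):
        q, n_scales = block_fake_quant(xs, g)
        print(f"g{g}: {n_scales} scale(s), MSE {mse(xs, q):.5f}")
```

Halving the group size doubles the scale traffic (2 scales vs 1 here) but lets each group adapt to its own range: with g32, the outlier's scale pushes most of the small values to 0 or to a coarse step. Whether g32 is acceptable for real activations is exactly the quality gate the item describes.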

Measurement plan (record for the report)

For each idea, record:

  (a) the per-stage CUDA-event breakdown, before and after;
  (b) per-step and end-to-end time at 512/1024 px × batch {1, 2, 4};
  (c) VRAM;
  (d) an image-quality check on the 4 probes (the text and hand probes are the sensitive ones);
  (e) the velocity loss, if the idea changes the math.

Always compare against both the bf16 baseline AND the current NVFP4-fused baseline. Keep the matrix in outputs/nvfp4/benchmark.json.
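A sketch of keeping the measurement matrix in outputs/nvfp4/benchmark.json as an append-only JSON list, one record per (variant, resolution, batch) cell. The record fields and helper names are hypothetical; only the path, the resolutions, the batch sizes, and the two baselines come from the plan above. Writing to a temporary file and renaming it over the original keeps the JSON intact if a run dies mid-write.

```python
import itertools
import json
import os
import tempfile
from pathlib import Path

RESOLUTIONS = (512, 1024)
BATCHES = (1, 2, 4)
BASELINES = ("bf16", "nvfp4_fused")  # both baselines, per the plan


def sweep(candidate):
    """All (variant, resolution, batch) cells: the two baselines plus one candidate."""
    return list(itertools.product(BASELINES + (candidate,), RESOLUTIONS, BATCHES))


def append_record(path, record):
    """Append one result dict to the JSON list at `path`; returns the new length.

    Writes a temp file in the same directory and renames it over the original,
    so an interrupted run never leaves a half-written benchmark.json.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    records = json.loads(path.read_text()) if path.exists() else []
    records.append(record)
    fd, tmp = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
    with os.fdopen(fd, "w") as f:
        json.dump(records, f, indent=2)
    os.replace(tmp, path)  # atomic rename on POSIX
    return len(records)


if __name__ == "__main__":
    cells = sweep("cuda_graphs")
    print(len(cells), cells[0])  # 18 ('bf16', 512, 1)
    record = {  # hypothetical schema mirroring items (a)-(e)
        "variant": "cuda_graphs", "res": 512, "batch": 1,
        "stages_ms": {}, "step_ms": None, "e2e_ms": None,
        "vram_gb": None, "probes_ok": None, "velocity_loss": None,
    }
    with tempfile.TemporaryDirectory() as d:  # demo only; real runs use the repo path
        print(append_record(Path(d) / "outputs/nvfp4/benchmark.json", record))  # 1
```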

Rough priority (impact × 1/effort)

First: CUDA graphs (1) · text-encoder/VAE quant (2, 3) · batch fix (4) · output fusion (5). Then FP8 attention (10) and tcgen05 verification (9) for the high-res ceiling. Then tile autotune (6), TMA (7), rank Pareto (8), wcscales (13).
