Buckets:
| """FULLY-FUSED klein-4B end-to-end speed: convert the diffusers Flux2Transformer to the | |
| Nunchaku fused model (fused_qkv_norm_rottary + attention_fp16 + SVDQ W4A4 Linears) and time | |
| the full pipeline vs bf16. Weights are dummy (real non-block weights kept; block Linears are | |
| dummy-quantized) — GEMM/attention timing is value-independent, so the SPEED is real; this | |
| isolates what FUSION buys over the Linears-only swap (scripts/24). | |
| Usage: PYTHONPATH=. python3 scripts/25_fused_4b_speed.py [H] [W] [steps] [rank] | |
| """ | |
| import sys, time, json, statistics as st | |
| import torch | |
| from flux2distill.model_utils import load_pipeline | |
| from nunchaku.models.transformers.transformer_flux2 import NunchakuFlux2Transformer2DModel | |
| from nunchaku.models.linear import SVDQW4A4Linear | |
def _int_arg(pos, default):
    """Return ``sys.argv[pos]`` parsed as an int, or ``default`` when it was not supplied."""
    return default if len(sys.argv) <= pos else int(sys.argv[pos])


# Positional CLI arguments: [H] [W] [steps] [rank]
H = _int_arg(1, 512)
W = _int_arg(2, 512)
STEPS = _int_arg(3, 4)
RANK = _int_arg(4, 128)
PROMPT = "a photorealistic storefront at golden hour, neon sign, wet pavement reflections, 50mm"
print(f"=== FUSED klein-4B end-to-end | {W}x{H} | {STEPS} steps | rank {RANK} | {torch.cuda.get_device_name(0)} ===")
pipe = load_pipeline(device="cuda")
tf = pipe.transformer
# Inference only: eval mode and no autograd bookkeeping on any parameter.
tf.eval()
tf.requires_grad_(False)
# Per-call transformer forward latencies in seconds (appended by timer(), cleared by run()).
_step = []
def timer(fwd):
    """Wrap ``fwd`` so each call's GPU-synchronized wall time is appended to ``_step``.

    The CUDA device is synchronized before the clock starts and again before it
    stops, so kernels queued by the call are included in the measurement.
    ``fwd``'s return value is passed through unchanged.
    """
    def _timed(*args, **kwargs):
        torch.cuda.synchronize()
        began = time.perf_counter()
        result = fwd(*args, **kwargs)
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - began
        _step.append(elapsed)
        return result

    return _timed
def run(seed=0):
    """Run one full pipeline generation and return ``(wall_s, peak_vram_gb, step_s)``.

    * ``wall_s`` is the GPU-synchronized end-to-end time of the ``pipe(...)`` call.
    * ``peak_vram_gb`` is ``torch.cuda.max_memory_allocated()`` (the peak stats
      are reset at the start of this call) divided by 1e9.
    * ``step_s`` is the median of the per-forward times recorded in ``_step``,
      excluding the first call; if only one call was recorded, it is that
      call's time.

    ``seed`` seeds the CPU generator passed to the pipeline.
    """
    _step.clear()
    torch.cuda.reset_peak_memory_stats()
    rng = torch.Generator("cpu").manual_seed(seed)
    torch.cuda.synchronize()
    started = time.perf_counter()
    # Deliberately no autocast. The FP4 kernel needs bf16/fp16 activations, but
    # autocast runs norms in fp32, which feeds fp32 activations into the fused
    # qkv and trips the kernel's assert. The Nunchaku pipeline runs pure bf16.
    pipe(
        prompt=PROMPT,
        num_inference_steps=STEPS,
        guidance_scale=1.0,
        height=H,
        width=W,
        generator=rng,
    )
    torch.cuda.synchronize()
    wall_s = time.perf_counter() - started
    peak_gb = torch.cuda.max_memory_allocated() / 1e9
    after_first = _step[1:]
    step_s = st.median(after_first) if after_first else _step[0]
    return wall_s, peak_gb, step_s
# ---- bf16 baseline ----
# The instance attribute shadows the class-level forward, so nn.Module.__call__
# routes every transformer call through the timing wrapper.
tf.forward = timer(tf.forward)
run(99)  # warm-up run; its result is discarded
bf = [run(i) for i in range(3)]
bf_totals, bf_peaks, bf_steps = zip(*bf)
bf_t = st.median(bf_totals)
bf_v = bf_peaks[0]
bf_s = st.median(bf_steps)
print(f"bf16 : {bf_t:.3f}s step={bf_s*1000:.0f}ms VRAM={bf_v:.1f}GB")
# Drop the instance attribute so the class forward is visible again before the class swap.
del tf.forward
# ---- convert to the FUSED Nunchaku model ----
# Swap the live module's class in place, then let Nunchaku's _patch_model rebuild
# it at the requested precision/rank (SVDQ W4A4 Linears + fused attention).
tf.__class__ = NunchakuFlux2Transformer2DModel
tf._patch_model(precision="nvfp4", rank=RANK, torch_dtype=torch.bfloat16)
nq = 0
for m in tf.modules():
    if not isinstance(m, SVDQW4A4Linear):
        continue
    # Materialize the meta-created buffers. to_empty allocates storage without
    # initializing it, so every tensor is filled with dummy values below.
    m.to_empty(device="cuda")
    with torch.no_grad():
        # random_'s upper bound is exclusive: use 128 so 127 is also drawn
        # (the full signed-byte range [-128, 127]).
        m.qweight.random_(-128, 128)
        m.wscales.copy_(torch.ones_like(m.wscales))
        m.smooth_factor.fill_(1.0)
        m.smooth_factor_orig.fill_(1.0)
        m.proj_down.normal_(0, 0.02)
        m.proj_up.normal_(0, 0.02)
        if m.wcscales is not None:
            m.wcscales.fill_(1.0)
        if m.bias is not None:
            m.bias.zero_()
    nq += 1
if nq == 0:
    # Without any quantized Linears the "fused" timing would silently describe an
    # unquantized model and produce a misleading speedup, so fail loudly instead.
    raise RuntimeError("_patch_model produced no SVDQW4A4Linear modules; nothing to benchmark")
torch.cuda.empty_cache()
print(f"converted to fused NunchakuFlux2: {nq} SVDQ Linears, fused attention")
# ---- fused run ----
from pathlib import Path

tf.forward = timer(tf.forward)  # now wraps NunchakuFlux2.forward
run(99)  # warm-up run; its result is discarded
fz = [run(i) for i in range(3)]
fz_t, fz_v, fz_s = st.median([x[0] for x in fz]), fz[0][1], st.median([x[2] for x in fz])
sp = bf_t / fz_t  # end-to-end speedup of the fused model over the bf16 baseline
print(f"nvfp4 fused : {fz_t:.3f}s step={fz_s*1000:.0f}ms VRAM={fz_v:.1f}GB -> {sp:.2f}x end-to-end vs bf16")
print(" (Linears-only swap was 1.24x @512 / 1.18x @1024 — fusion delta = the win)")
out_path = Path("outputs/nvfp4/fused_e2e_speed.json")
# Create the output directory if needed; otherwise a fresh checkout raises
# FileNotFoundError here, after the whole benchmark, and the results are lost.
out_path.parent.mkdir(parents=True, exist_ok=True)
# A context manager guarantees the file is flushed and closed.
with open(out_path, "w") as fh:
    json.dump({"res": f"{W}x{H}", "steps": STEPS, "rank": RANK,
               "bf16_s": round(bf_t, 3), "bf16_step_ms": round(bf_s*1000, 1), "bf16_vram_gb": round(bf_v, 1),
               "fused_s": round(fz_t, 3), "fused_step_ms": round(fz_s*1000, 1), "fused_vram_gb": round(fz_v, 1),
               "speedup": round(sp, 3)},
              fh, indent=2)
print("saved -> outputs/nvfp4/fused_e2e_speed.json")
Xet Storage Details
- Size:
- 4.15 kB
- Xet hash:
- d5d19176aea885e43735d3a2ea223cdc4571918242de03655a100a207d6d6905
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.