Buckets:
| """(A) Save the deployable NVFP4 checkpoint + (B) benchmark battery for the report. | |
| Matrix: bf16 {batch 1,2,4}x{512,1024} ; NVFP4-fused {batch 1}x{512,1024} (fused rotary is batch=1). | |
| Records per-step / end-to-end / img-per-s / VRAM / speedup to outputs/nvfp4/benchmark.json. | |
| Also saves the converted model to outputs/nvfp4/deploy/klein4b_nvfp4_fused.safetensors. | |
| """ | |
| import json, time, statistics as st, os | |
| import torch | |
| from safetensors.torch import save_file | |
| from flux2distill.model_utils import load_pipeline | |
| from flux2distill.nunchaku_export import quantize_pack_nvfp4 | |
| from nunchaku.models.transformers.transformer_flux2 import NunchakuFlux2Transformer2DModel | |
| from nunchaku.models.linear import SVDQW4A4Linear | |
# Rank handed to both the Nunchaku model patch and the NVFP4 quantizer below.
RANK = 128

# Single fixed prompt; run() repeats it once per batch element.
P = "a photorealistic storefront at golden hour, neon sign, wet pavement reflections, 50mm"

# Per-call wall-clock durations (seconds) appended by the timed() wrapper;
# run() clears this list at the start of every repetition.
_step = []

pipe = load_pipeline(device="cuda")
tf = pipe.transformer
# Inference only: switch to eval mode and stop tracking gradients.
tf.eval()
tf.requires_grad_(False)
def timed(fwd):
    """Return a wrapper around ``fwd`` that records each call's duration.

    The wrapper synchronizes CUDA, starts a ``time.perf_counter`` clock,
    calls ``fwd`` with the given arguments, synchronizes again and appends
    the elapsed seconds to the module-level ``_step`` list.  The result of
    ``fwd`` is returned unchanged.
    """
    def wrapper(*args, **kwargs):
        # Synchronize before starting the clock and again before stopping it,
        # so the interval is bracketed by two device syncs.
        torch.cuda.synchronize()
        started = time.perf_counter()
        out = fwd(*args, **kwargs)
        torch.cuda.synchronize()
        _step.append(time.perf_counter() - started)
        return out

    return wrapper
def run(B, H, W, reps=3):
    """Benchmark the pipeline for one (batch, resolution) configuration.

    Executes ``reps + 1`` pipeline calls (4 inference steps, guidance scale
    1.0, CPU generator seeded with the repetition index).  The first call is
    a warmup and is not recorded.

    Args:
        B: batch size; the prompt ``P`` is repeated ``B`` times.
        H: image height passed to the pipeline.
        W: image width passed to the pipeline.
        reps: number of measured repetitions after the warmup.

    Returns:
        ``None`` if any call raises ``torch.cuda.OutOfMemoryError``; otherwise
        a dict with ``batch``, ``res`` (``"{W}x{H}"``), ``total_s`` (median
        end-to-end seconds), ``img_per_s`` (``B`` / median seconds),
        ``step_ms`` (median over repetitions of the per-repetition step time)
        and ``vram_gb`` (largest peak allocation seen, in GB).

    Note:
        The per-step time is read from ``_step``, so something must append to
        it during the pipeline call (in this script, a forward wrapped with
        ``timed``); with no entries the lookup raises ``IndexError``.
    """
    samples = []
    for rep in range(reps + 1):
        _step.clear()
        torch.cuda.reset_peak_memory_stats()
        gen = torch.Generator("cpu").manual_seed(rep)

        torch.cuda.synchronize()
        started = time.perf_counter()
        try:
            pipe(prompt=[P] * B, num_inference_steps=4, guidance_scale=1.0, height=H, width=W, generator=gen)
        except torch.cuda.OutOfMemoryError:
            return None
        torch.cuda.synchronize()

        if rep == 0:
            # Warmup repetition: executed but not recorded.
            continue

        elapsed = time.perf_counter() - started
        peak_gb = torch.cuda.max_memory_allocated() / 1e9
        # Drop the first timed call of the repetition when there are several;
        # with a single entry use it as-is.
        step_s = _step[0] if len(_step) <= 1 else st.median(_step[1:])
        samples.append((elapsed, peak_gb, step_s))

    tot = st.median([elapsed for elapsed, _, _ in samples])
    vram = max(peak_gb for _, peak_gb, _ in samples)
    stp = st.median([step_s for _, _, step_s in samples])
    return {
        "batch": B,
        "res": f"{W}x{H}",
        "total_s": round(tot, 3),
        "img_per_s": round(B / tot, 3),
        "step_ms": round(stp * 1000, 1),
        "vram_gb": round(vram, 1),
    }
# Benchmark rows collected per model variant; written to JSON at the end.
results = {"bf16": [], "nvfp4_fused": []}

print("=== bf16 baseline matrix ===")
# Install the timing wrapper as an instance attribute on the transformer.
tf.forward = timed(tf.forward)
for side in (512, 1024):
    for batch in (1, 2, 4):
        row = run(batch, side, side)
        if not row:
            # run() returns None when the configuration ran out of memory.
            print(f" bf16 b{batch} {side} OOM")
            continue
        results["bf16"].append(row)
        print(" ", row)
# Remove the instance-level wrapper again before the model is converted.
del tf.forward
# ---- convert to NVFP4-fused + save checkpoint ----
print("=== converting to NVFP4-fused (480s) ===")

# Snapshot (clone) the weights needed for quantization before the model is
# patched.  Keys are (block kind, block index, name), where "d" entries come
# from tf.transformer_blocks and "s" entries from tf.single_transformer_blocks.
src = {}
for i, b in enumerate(tf.transformer_blocks):
    a = b.attn
    # q/k/v projections are concatenated along dim 0 into one fused weight.
    src[("d", i, "attn.to_qkv")] = torch.cat([a.to_q.weight, a.to_k.weight, a.to_v.weight], 0).clone()
    src[("d", i, "attn.to_out.0")] = a.to_out[0].weight.clone()
    src[("d", i, "attn.to_added_qkv")] = torch.cat(
        [a.add_q_proj.weight, a.add_k_proj.weight, a.add_v_proj.weight], 0
    ).clone()
    src[("d", i, "attn.to_add_out")] = a.to_add_out.weight.clone()
    src[("d", i, "ff.linear_in")] = b.ff.linear_in.weight.clone()
    src[("d", i, "ff.linear_out")] = b.ff.linear_out.weight.clone()
    src[("d", i, "ff_context.linear_in")] = b.ff_context.linear_in.weight.clone()
    src[("d", i, "ff_context.linear_out")] = b.ff_context.linear_out.weight.clone()
for i, b in enumerate(tf.single_transformer_blocks):
    src[("s", i, "Win")] = b.attn.to_qkv_mlp_proj.weight.clone()
    src[("s", i, "Wout")] = b.attn.to_out.weight.clone()

# Re-type the existing module in place and let Nunchaku patch its layers.
tf.__class__ = NunchakuFlux2Transformer2DModel
tf._patch_model(precision="nvfp4", rank=RANK, torch_dtype=torch.bfloat16)

wtscales = {}
t0 = time.time()
for name, m in tf.named_modules():
    if not isinstance(m, SVDQW4A4Linear):
        continue
    m.to_empty(device="cuda")
    if m.bias is not None:
        # No source bias is copied in, so the freshly allocated bias is zeroed.
        # NOTE(review): presumably the original layers are bias-free - confirm.
        m.bias.zero_()

    parts = name.split(".")
    bt = "d" if parts[0] == "transformer_blocks" else "s"
    idx = int(parts[1])
    local = ".".join(parts[2:])

    if bt == "d":
        W = src[("d", idx, local)]
    # Single blocks: slice the fused input/output weights per sub-module.
    # NOTE(review): 9216 and 3072 are hard-coded split offsets, presumably the
    # qkv width and the attention width of this model - confirm before reuse.
    elif local == "attn.qkv_proj":
        W = src[("s", idx, "Win")][:m.out_features]
    elif local == "attn.mlp_fc1":
        W = src[("s", idx, "Win")][9216:9216 + m.out_features]
    elif local == "attn.out_proj":
        W = src[("s", idx, "Wout")][:, :m.in_features]
    elif local == "attn.mlp_fc2":
        W = src[("s", idx, "Wout")][:, 3072:3072 + m.in_features]
    else:
        # Without this branch an unknown module name would either hit a
        # NameError or silently quantize the previous module's weight.
        raise ValueError(f"no source weight mapping for quantized module {name!r}")

    bufs = quantize_pack_nvfp4(W.float().cuda(), rank=RANK, refine=3)
    for n in ("qweight", "wscales", "wcscales", "proj_down", "proj_up", "smooth_factor"):
        target = getattr(m, n)
        target.data.copy_(bufs[n].reshape(target.shape).to(target.dtype))
    m.smooth_factor_orig.data.copy_(
        bufs["smooth_factor"].reshape(m.smooth_factor_orig.shape).to(m.smooth_factor_orig.dtype)
    )
    # wtscale is stored on the module and also collected for the checkpoint metadata.
    m.wtscale = bufs["wtscale"]
    wtscales[name] = bufs["wtscale"]

# The cloned source weights are no longer needed; release them.
del src
torch.cuda.empty_cache()
print(f"converted in {time.time()-t0:.0f}s")
os.makedirs("outputs/nvfp4/deploy", exist_ok=True)

# Move every tensor of the converted model to CPU in contiguous layout for saving.
sd = {}
for key, tensor in tf.state_dict().items():
    sd[key] = tensor.contiguous().cpu()

# Metadata values are strings; the per-module wtscales are JSON-encoded.
meta = {
    "wtscales": json.dumps(wtscales),
    "rank": str(RANK),
    "precision": "nvfp4",
    "model": "klein-4B",
    "format": "NunchakuFlux2Transformer2DModel state_dict",
}
save_file(sd, "outputs/nvfp4/deploy/klein4b_nvfp4_fused.safetensors", metadata=meta)

# Total payload size: element count times bytes per element, over all tensors.
ckpt_bytes = sum(tensor.numel() * tensor.element_size() for tensor in sd.values())
print(f"saved checkpoint: outputs/nvfp4/deploy/klein4b_nvfp4_fused.safetensors ({ckpt_bytes/1e9:.1f}GB)")
print("=== NVFP4-fused matrix (batch=1) ===")
# Time the converted model's forward the same way as the bf16 baseline.
tf.forward = timed(tf.forward)
for H in (512, 1024):
    r = run(1, H, H)
    if r:
        results["nvfp4_fused"].append(r)
        print(" ", r)
    else:
        # run() returns None on CUDA OOM; report it instead of silently
        # dropping the row, matching the bf16 loop's behaviour.
        print(f" nvfp4_fused b1 {H} OOM")
# speedup vs bf16 same (batch1) config
# Index the batch-1 bf16 rows by resolution string for direct lookup.
bf = {x["res"]: x for x in results["bf16"] if x["batch"] == 1}
for x in results["nvfp4_fused"]:
    base = bf.get(x["res"])
    if base is None:
        # No bf16 batch-1 row at this resolution (e.g. it hit OOM): nothing to compare.
        continue
    x["speedup_vs_bf16"] = round(base["total_s"] / x["total_s"], 3)
    x["vram_reduction"] = round(1 - x["vram_gb"] / base["vram_gb"], 3)

# Use a context manager so the report file is flushed and closed deterministically.
with open("outputs/nvfp4/benchmark.json", "w", encoding="utf-8") as fh:
    json.dump(results, fh, indent=2)
print("\nsaved -> outputs/nvfp4/benchmark.json")
print(json.dumps(results, indent=2))
Xet Storage Details
- Size:
- 5.91 kB
- Xet hash:
- bb0ebd2ed503818f0888c89bfb8cf44baa07b673431479f29ec5570de0829eb3
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.