Buckets:

Mercity/FluxDistill / scripts /28_benchmark_all.py
Pranav2748's picture
download
raw
5.91 kB
"""(A) Save the deployable NVFP4 checkpoint + (B) benchmark battery for the report.
Matrix: bf16 {batch 1,2,4}x{512,1024} ; NVFP4-fused {batch 1}x{512,1024} (fused rotary is batch=1).
Records per-step / end-to-end / img-per-s / VRAM / speedup to outputs/nvfp4/benchmark.json.
Also saves the converted model to outputs/nvfp4/deploy/klein4b_nvfp4_fused.safetensors.
"""
import json, time, statistics as st, os
import torch
from safetensors.torch import save_file
from flux2distill.model_utils import load_pipeline
from flux2distill.nunchaku_export import quantize_pack_nvfp4
from nunchaku.models.transformers.transformer_flux2 import NunchakuFlux2Transformer2DModel
from nunchaku.models.linear import SVDQW4A4Linear
# Rank handed to both the nunchaku model patch and the NVFP4 quantizer below.
RANK = 128
# Single prompt reused (replicated per batch element) for every benchmark run.
P = "a photorealistic storefront at golden hour, neon sign, wet pavement reflections, 50mm"
pipe = load_pipeline(device="cuda")
# The transformer is the component that gets timed and later converted;
# put it in eval mode and disable gradients on all of its parameters.
tf = pipe.transformer
tf.eval()
tf.requires_grad_(False)
# Wall-clock duration of every wrapped transformer forward call (seconds).
_step = []
def timed(fwd):
    """Wrap ``fwd`` so each call appends its wall-clock duration to ``_step``.

    The CUDA device is synchronized before starting the clock and again before
    stopping it, so the recorded time covers the GPU work launched by ``fwd``.

    Args:
        fwd: Callable to instrument (here: the transformer's ``forward``).

    Returns:
        A callable with the same call behaviour as ``fwd`` that also records
        the elapsed seconds of every call in the module-level ``_step`` list.
    """
    def _timed_forward(*args, **kwargs):
        torch.cuda.synchronize()
        started = time.perf_counter()
        out = fwd(*args, **kwargs)
        torch.cuda.synchronize()
        _step.append(time.perf_counter() - started)
        return out

    return _timed_forward
def run(B, H, W, reps=3):
    """Benchmark the pipeline for one batch-size / resolution configuration.

    Executes ``reps + 1`` generations (4 inference steps, guidance scale 1.0);
    the first one is a warm-up and is excluded from the statistics. Each
    repetition seeds a CPU generator with its repetition index.

    Args:
        B: Batch size; the prompt ``P`` is replicated ``B`` times.
        H: Image height passed to the pipeline.
        W: Image width passed to the pipeline.
        reps: Number of measured repetitions (must be >= 1).

    Returns:
        ``None`` if any repetition raises ``torch.cuda.OutOfMemoryError``,
        otherwise a dict with ``batch``, ``res`` (``"WxH"``), ``total_s``
        (median end-to-end seconds), ``img_per_s``, ``step_ms`` (median
        per-forward milliseconds, or ``None`` when no forward timings were
        recorded) and ``vram_gb`` (largest peak allocation, in 1e9 bytes).

    Raises:
        ValueError: If ``reps`` is smaller than 1.
    """
    # Fail fast: with no measured repetitions the medians below are undefined,
    # and previously this only surfaced after a full warm-up generation.
    if reps < 1:
        raise ValueError(f"reps must be >= 1, got {reps}")
    totals, peaks, steps = [], [], []
    for r in range(reps + 1):
        _step.clear()
        torch.cuda.reset_peak_memory_stats()
        g = torch.Generator("cpu").manual_seed(r)
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        try:
            pipe(prompt=[P] * B, num_inference_steps=4, guidance_scale=1.0, height=H, width=W, generator=g)
        except torch.cuda.OutOfMemoryError:
            return None
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - t0
        if r == 0:  # skip warmup
            continue
        totals.append(elapsed)
        peaks.append(torch.cuda.max_memory_allocated() / 1e9)
        # _step is only filled when the transformer forward is wrapped by
        # timed(); without that wrapper there is nothing to summarize.
        if _step:
            # The first recorded forward of a repetition is dropped when more
            # than one exists (presumably to exclude first-step overhead).
            steps.append(st.median(_step[1:]) if len(_step) > 1 else _step[0])
    tot = st.median(totals)
    vram = max(peaks)
    stp = st.median(steps) if steps else None
    return {"batch": B, "res": f"{W}x{H}", "total_s": round(tot, 3), "img_per_s": round(B / tot, 3),
            "step_ms": round(stp * 1000, 1) if stp is not None else None, "vram_gb": round(vram, 1)}
# Collected benchmark rows, one list per model variant.
results = {"bf16": [], "nvfp4_fused": []}
print("=== bf16 baseline matrix ===")
# Instrument the transformer by shadowing its forward with an instance
# attribute; deleting that attribute afterwards restores the class method.
tf.forward = timed(tf.forward)
# Square resolutions only: the same value is used for height and width.
for side, batch in [(s, n) for s in (512, 1024) for n in (1, 2, 4)]:
    row = run(batch, side, side)
    if not row:
        # run() returns None when the configuration ran out of CUDA memory.
        print(f" bf16 b{batch} {side} OOM")
        continue
    results["bf16"].append(row)
    print(" ", row)
del tf.forward
# ---- convert to NVFP4-fused + save checkpoint ----
print("=== converting to NVFP4-fused (480s) ===")
# Snapshot the bf16 weights before the model is patched. Keys are
# (block type, block index, module path inside the block): "d" entries come
# from tf.transformer_blocks, "s" entries from tf.single_transformer_blocks.
src = {}
for i, b in enumerate(tf.transformer_blocks):
    a = b.attn
    # Separate q/k/v projection weights are stacked along dim 0 into one matrix.
    src[("d", i, "attn.to_qkv")] = torch.cat([a.to_q.weight, a.to_k.weight, a.to_v.weight], 0).clone()
    src[("d", i, "attn.to_out.0")] = a.to_out[0].weight.clone()
    src[("d", i, "attn.to_added_qkv")] = torch.cat([a.add_q_proj.weight, a.add_k_proj.weight, a.add_v_proj.weight], 0).clone()
    src[("d", i, "attn.to_add_out")] = a.to_add_out.weight.clone()
    src[("d", i, "ff.linear_in")] = b.ff.linear_in.weight.clone()
    src[("d", i, "ff.linear_out")] = b.ff.linear_out.weight.clone()
    src[("d", i, "ff_context.linear_in")] = b.ff_context.linear_in.weight.clone()
    src[("d", i, "ff_context.linear_out")] = b.ff_context.linear_out.weight.clone()
for i, b in enumerate(tf.single_transformer_blocks):
    # Single blocks keep one fused input matrix and one fused output matrix;
    # they are sliced per target module further down.
    src[("s", i, "Win")] = b.attn.to_qkv_mlp_proj.weight.clone()
    src[("s", i, "Wout")] = b.attn.to_out.weight.clone()
# Re-type the existing module in place, then let nunchaku swap in its layers.
tf.__class__ = NunchakuFlux2Transformer2DModel
tf._patch_model(precision="nvfp4", rank=RANK, torch_dtype=torch.bfloat16)
wtscales = {}
t0 = time.perf_counter()
# All writes below are in-place updates of module parameters/buffers, so no
# autograd tracking is wanted.
with torch.no_grad():
    for name, m in tf.named_modules():
        if not isinstance(m, SVDQW4A4Linear):
            continue
        m.to_empty(device="cuda")
        if m.bias is not None:
            m.bias.zero_()
        parts = name.split(".")
        if parts[0] not in ("transformer_blocks", "single_transformer_blocks"):
            # Previously any other parent was silently treated as a single block.
            raise ValueError(f"unexpected SVDQW4A4Linear outside the transformer blocks: {name!r}")
        idx = int(parts[1])
        local = ".".join(parts[2:])
        if parts[0] == "transformer_blocks":
            W = src[("d", idx, local)]
        elif local == "attn.qkv_proj":
            W = src[("s", idx, "Win")][:m.out_features]
        elif local == "attn.mlp_fc1":
            # NOTE(review): 9216 is the hard-coded row offset of the MLP part
            # inside the fused input matrix - confirm it matches this model.
            W = src[("s", idx, "Win")][9216:9216 + m.out_features]
        elif local == "attn.out_proj":
            W = src[("s", idx, "Wout")][:, :m.in_features]
        elif local == "attn.mlp_fc2":
            # NOTE(review): 3072 is the hard-coded column offset of the MLP part
            # inside the fused output matrix - confirm it matches this model.
            W = src[("s", idx, "Wout")][:, 3072:3072 + m.in_features]
        else:
            # Without this branch an unmatched name would reuse the previous
            # iteration's W and quantize the wrong weights without any error.
            raise ValueError(f"no source weight mapping for {name!r}")
        bufs = quantize_pack_nvfp4(W.float().cuda(), rank=RANK, refine=3)
        for n in ("qweight", "wscales", "wcscales", "proj_down", "proj_up", "smooth_factor"):
            target = getattr(m, n)
            target.copy_(bufs[n].reshape(target.shape).to(target.dtype))
        m.smooth_factor_orig.copy_(
            bufs["smooth_factor"].reshape(m.smooth_factor_orig.shape).to(m.smooth_factor_orig.dtype))
        m.wtscale = bufs["wtscale"]
        # Collected per module name; stored later in the checkpoint metadata.
        wtscales[name] = bufs["wtscale"]
# Drop the bf16 snapshot so its memory can be returned before benchmarking.
del src
torch.cuda.empty_cache()
print(f"converted in {time.perf_counter()-t0:.0f}s")
# Write the converted transformer to disk as a single safetensors file.
os.makedirs("outputs/nvfp4/deploy", exist_ok=True)
# Every tensor is made contiguous and moved to the CPU before serialization.
state = {}
for key, tensor in tf.state_dict().items():
    state[key] = tensor.contiguous().cpu()
# safetensors metadata values are strings, hence the JSON / str conversions.
meta = {
    "wtscales": json.dumps(wtscales),
    "rank": str(RANK),
    "precision": "nvfp4",
    "model": "klein-4B",
    "format": "NunchakuFlux2Transformer2DModel state_dict",
}
save_file(state, "outputs/nvfp4/deploy/klein4b_nvfp4_fused.safetensors", metadata=meta)
size_bytes = sum(tensor.numel() * tensor.element_size() for tensor in state.values())
print(f"saved checkpoint: outputs/nvfp4/deploy/klein4b_nvfp4_fused.safetensors ({size_bytes/1e9:.1f}GB)")
print("=== NVFP4-fused matrix (batch=1) ===")
# Instrument the (now converted) transformer so per-forward timings are recorded.
tf.forward = timed(tf.forward)
for H in (512, 1024):
    r = run(1, H, H)
    if r:
        results["nvfp4_fused"].append(r)
        print(" ", r)
    else:
        # Report out-of-memory configurations the same way the bf16 loop does
        # instead of dropping them silently.
        print(f" nvfp4 b1 {H} OOM")
# speedup vs bf16 same (batch1) config
bf = {x["res"]: x for x in results["bf16"] if x["batch"] == 1}
for x in results["nvfp4_fused"]:
    if x["res"] in bf:
        base = bf[x["res"]]
        x["speedup_vs_bf16"] = round(base["total_s"] / x["total_s"], 3)
        x["vram_reduction"] = round(1 - x["vram_gb"] / base["vram_gb"], 3)
# Make sure the report directory exists, and close the file deterministically
# (the handle was previously left open for the garbage collector).
os.makedirs("outputs/nvfp4", exist_ok=True)
with open("outputs/nvfp4/benchmark.json", "w", encoding="utf-8") as fh:
    json.dump(results, fh, indent=2)
print("\nsaved -> outputs/nvfp4/benchmark.json")
print(json.dumps(results, indent=2))

Xet Storage Details

Size:
5.91 kB
·
Xet hash:
bb0ebd2ed503818f0888c89bfb8cf44baa07b673431479f29ec5570de0829eb3

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.