Buckets:

Mercity/FluxDistill / scripts /25_fused_4b_speed.py
Pranav2748's picture
download
raw
4.15 kB
"""FULLY-FUSED klein-4B end-to-end speed: convert the diffusers Flux2Transformer to the
Nunchaku fused model (fused_qkv_norm_rottary + attention_fp16 + SVDQ W4A4 Linears) and time
the full pipeline vs bf16. Weights are dummy (real non-block weights kept; block Linears are
dummy-quantized) — GEMM/attention timing is value-independent, so the SPEED is real; this
isolates what FUSION buys over the Linears-only swap (scripts/24).
Usage: PYTHONPATH=. python3 scripts/25_fused_4b_speed.py [H] [W] [steps] [rank]
"""
import json
import os
import statistics as st
import sys
import time

import torch
from flux2distill.model_utils import load_pipeline
from nunchaku.models.transformers.transformer_flux2 import NunchakuFlux2Transformer2DModel
from nunchaku.models.linear import SVDQW4A4Linear
def _int_arg(position, default):
    """Return the CLI argument at *position* as an int, or *default* when it is absent."""
    return int(sys.argv[position]) if len(sys.argv) > position else default


# Positional command-line knobs, in the order given in the module usage line:
# height, width, number of inference steps, and the rank passed to the model patch.
H = _int_arg(1, 512)
W = _int_arg(2, 512)
STEPS = _int_arg(3, 4)
RANK = _int_arg(4, 128)
PROMPT = "a photorealistic storefront at golden hour, neon sign, wet pavement reflections, 50mm"

device_name = torch.cuda.get_device_name(0)
print(f"=== FUSED klein-4B end-to-end | {W}x{H} | {STEPS} steps | rank {RANK} | {device_name} ===")

# Load the pipeline on the GPU; its transformer is switched to eval mode with
# gradients disabled, since this script only runs inference.
pipe = load_pipeline(device="cuda")
tf = pipe.transformer
tf.eval().requires_grad_(False)
# Wall-clock durations (seconds) of every timed forward call; cleared by run().
_step = []


def timer(fwd):
    """Return a wrapper around *fwd* that records each call's duration in ``_step``.

    The CUDA device is synchronized immediately before the clock starts and
    again before it stops, so the recorded time includes the GPU work launched
    by the call rather than only the time needed to enqueue it.

    Args:
        fwd: The callable to time (here, a transformer ``forward``).

    Returns:
        A callable with the same arguments and return value as *fwd*.
    """
    def timed_call(*args, **kwargs):
        torch.cuda.synchronize()
        started = time.perf_counter()
        result = fwd(*args, **kwargs)
        torch.cuda.synchronize()
        _step.append(time.perf_counter() - started)
        return result

    return timed_call
def run(seed=0):
    """Run the pipeline once and report timing and peak memory.

    Args:
        seed: Seed for the CPU ``torch.Generator`` handed to the pipeline.

    Returns:
        A tuple ``(total_seconds, peak_allocated_gb, step_seconds)``.
        ``step_seconds`` is the median of the forward durations recorded in
        ``_step`` with the first entry dropped, or the single recorded
        duration when only one forward call was made.
    """
    _step.clear()
    torch.cuda.reset_peak_memory_stats()
    generator = torch.Generator("cpu").manual_seed(seed)

    torch.cuda.synchronize()
    started = time.perf_counter()
    # Intentionally no autocast here: the FP4 kernel needs bf16/fp16 activations,
    # while autocast runs norms in fp32, which would feed fp32 activations into
    # the fused qkv and trip a kernel assert. The Nunchaku pipeline runs pure bf16.
    pipe(
        prompt=PROMPT,
        num_inference_steps=STEPS,
        guidance_scale=1.0,
        height=H,
        width=W,
        generator=generator,
    )
    torch.cuda.synchronize()

    elapsed = time.perf_counter() - started
    peak_gb = torch.cuda.max_memory_allocated() / 1e9
    if len(_step) > 1:
        per_step = st.median(_step[1:])
    else:
        per_step = _step[0]
    return elapsed, peak_gb, per_step
# ---- bf16 baseline ----
# Shadow the class-level forward with a timed wrapper stored on the instance.
tf.forward = timer(tf.forward)
run(99)  # first pass; its result is discarded and not part of the statistics
bf = [run(seed) for seed in range(3)]
bf_totals, bf_peaks, bf_steps = zip(*bf)
bf_t = st.median(bf_totals)
bf_v = bf_peaks[0]  # peak memory is taken from the first measured run
bf_s = st.median(bf_steps)
print(f"bf16 : {bf_t:.3f}s step={bf_s*1000:.0f}ms VRAM={bf_v:.1f}GB")
# Remove the instance-level wrapper so the class forward is visible again
# before the instance's class is swapped.
del tf.forward
# ---- convert to the FUSED Nunchaku model ----
# Re-class the existing transformer in place, then let the Nunchaku model patch itself.
tf.__class__ = NunchakuFlux2Transformer2DModel
tf._patch_model(precision="nvfp4", rank=RANK, torch_dtype=torch.bfloat16)

# Fill every SVDQ linear with dummy values: only the speed is measured, so the
# numbers themselves do not matter (see the module docstring).
svdq_layers = [mod for mod in tf.modules() if isinstance(mod, SVDQW4A4Linear)]
for layer in svdq_layers:
    layer.to_empty(device="cuda")  # materialize the meta-created buffers
    with torch.no_grad():
        layer.qweight.random_(-128, 127)
        layer.wscales.copy_(torch.ones_like(layer.wscales))
        layer.smooth_factor.fill_(1.0)
        layer.smooth_factor_orig.fill_(1.0)
        layer.proj_down.normal_(0, 0.02)
        layer.proj_up.normal_(0, 0.02)
        if layer.wcscales is not None:
            layer.wcscales.fill_(1.0)
        if layer.bias is not None:
            layer.bias.zero_()
nq = len(svdq_layers)
torch.cuda.empty_cache()
print(f"converted to fused NunchakuFlux2: {nq} SVDQ Linears, fused attention")
# ---- fused run ----
tf.forward = timer(tf.forward)  # now wraps NunchakuFlux2.forward
run(99)  # first pass; its result is discarded and not part of the statistics
fz = [run(i) for i in range(3)]
fz_t, fz_v, fz_s = st.median([x[0] for x in fz]), fz[0][1], st.median([x[2] for x in fz])
sp = bf_t / fz_t  # end-to-end speedup: median bf16 time over median fused time
print(f"nvfp4 fused : {fz_t:.3f}s step={fz_s*1000:.0f}ms VRAM={fz_v:.1f}GB -> {sp:.2f}x end-to-end vs bf16")
print(" (Linears-only swap was 1.24x @512 / 1.18x @1024 — fusion delta = the win)")

results = {
    "res": f"{W}x{H}", "steps": STEPS, "rank": RANK,
    "bf16_s": round(bf_t, 3), "bf16_step_ms": round(bf_s*1000, 1), "bf16_vram_gb": round(bf_v, 1),
    "fused_s": round(fz_t, 3), "fused_step_ms": round(fz_s*1000, 1), "fused_vram_gb": round(fz_v, 1),
    "speedup": round(sp, 3),
}
out_path = "outputs/nvfp4/fused_e2e_speed.json"
# Create the output directory if needed so a missing folder cannot discard the
# results after the whole benchmark has already run.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
# Context manager guarantees the file is flushed and closed.
with open(out_path, "w", encoding="utf-8") as fh:
    json.dump(results, fh, indent=2)
print(f"saved -> {out_path}")

Xet Storage Details

Size:
4.15 kB
·
Xet hash:
d5d19176aea885e43735d3a2ea223cdc4571918242de03655a100a207d6d6905

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.