Buckets:

Mercity/FluxDistill / scripts /21_bf16_profile.py
Pranav2748's picture
download
raw
4.65 kB
"""bf16 baseline profiler — same timing harness as 20_nunchaku_profile.py, for a stock
(unquantized) Flux2KleinPipeline. Gives the bf16 reference that the Nunchaku FP4/INT4
kernel numbers are measured against.
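Times every transformer.forward call (one per denoise step) and vae.decode, bracketing
each with torch.cuda.synchronize(). The warm-up run is excluded from the results.
Uses the same prompt / resolution / steps as the Nunchaku profile.
Usage:
    PYTHONPATH=. python3 scripts/21_bf16_profile.py [model_dir] [steps] [H] [W] [n_timed]
    model_dir  default models/klein-4b (our distilled 4B teacher)
    steps      default 4 (denoise steps)
    H, W       default 1024 x 1024 (output height, width)
    n_timed    default 3 (timed runs after the warm-up; reported numbers are medians)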
"""
import os, sys, time, json, statistics as st
import torch
from diffusers import Flux2KleinPipeline
def _argv(index, default, cast=str):
    """Return positional CLI argument *index* converted with *cast*, else *default*."""
    return cast(sys.argv[index]) if len(sys.argv) > index else default


MODEL = _argv(1, "models/klein-4b")  # model directory; its basename tags the outputs
STEPS = _argv(2, 4, int)             # num_inference_steps
H = _argv(3, 1024, int)              # output height
W = _argv(4, 1024, int)              # output width
NRUN = _argv(5, 3, int)              # timed runs after the warm-up
# Validate before the slow model load. With n_timed < 1 the aggregation step
# would index runs[0] and crash only after loading + warm-up, and non-positive
# steps / sizes cannot yield a valid profile either.
if min(STEPS, H, W, NRUN) < 1:
    sys.exit(f"invalid arguments: steps={STEPS} H={H} W={W} n_timed={NRUN} "
             "(each must be >= 1)")
TAG = os.path.basename(MODEL.rstrip("/"))
PROMPT = ("a photorealistic storefront at golden hour with a glowing neon sign that reads "
          "\"NUNCHAKU\", wet pavement reflections, shallow depth of field, 50mm")
print(f"=== bf16 baseline | {MODEL} | {W}x{H} | {STEPS} steps | {NRUN} timed runs ===")
print(f"device={torch.cuda.get_device_name(0)} | torch={torch.__version__}")
# Reset first so max_memory_allocated() below is the peak allocated while
# loading the pipeline and moving it to the GPU.
torch.cuda.reset_peak_memory_stats()
t0 = time.perf_counter()
pipe = Flux2KleinPipeline.from_pretrained(MODEL, torch_dtype=torch.bfloat16)
pipe.to("cuda")
torch.cuda.synchronize()
print(f"[load] {time.perf_counter()-t0:.1f}s | post-load VRAM {torch.cuda.max_memory_allocated()/1e9:.2f} GB")
# Per-call durations in seconds; run() clears them in place before each generation.
step_times, dec_times = [], []


def _synced_timer(fn, sink):
    """Wrap *fn* so each call is bracketed by torch.cuda.synchronize().

    The wall-clock duration of every call (seconds, time.perf_counter) is
    appended to *sink*; the wrapped call's return value is passed through.
    """
    def wrapper(*a, **k):
        torch.cuda.synchronize()
        start = time.perf_counter()
        result = fn(*a, **k)
        torch.cuda.synchronize()
        sink.append(time.perf_counter() - start)
        return result
    return wrapper


# Shadow the bound methods on the module instances so the pipeline's own
# calls are timed: transformer calls feed step_times, VAE decodes feed dec_times.
_fwd = pipe.transformer.forward
timed_fwd = _synced_timer(_fwd, step_times)
pipe.transformer.forward = timed_fwd
_dec = pipe.vae.decode
timed_dec = _synced_timer(_dec, dec_times)
pipe.vae.decode = timed_dec
def run(seed=0):
    """Generate one image with PROMPT and return its timing breakdown.

    The timing sinks and the CUDA peak-memory counter are reset first, so
    every field describes only this call. Returned keys:
      img     - first image from the pipeline output
      total   - seconds for the whole pipe() call (synchronized on both ends)
      peak    - torch.cuda.max_memory_allocated() during the call, in GB (1e9)
      steps   - copy of step_times (one entry per timed transformer call)
      denoise - sum of steps
      decode  - sum of dec_times
    """
    for sink in (step_times, dec_times):
        sink.clear()
    torch.cuda.reset_peak_memory_stats()
    gen = torch.Generator("cpu").manual_seed(seed)

    torch.cuda.synchronize()
    began = time.perf_counter()
    image = pipe(prompt=PROMPT, guidance_scale=1.0, num_inference_steps=STEPS,
                 height=H, width=W, generator=gen).images[0]
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - began

    peak_gb = torch.cuda.max_memory_allocated() / 1e9
    return {
        "img": image,
        "total": elapsed,
        "peak": peak_gb,
        "steps": list(step_times),
        "denoise": sum(step_times),
        "decode": sum(dec_times),
    }
# One untimed warm-up generation (reported but excluded), then NRUN timed runs.
print("[warmup] running (excluded) ...")
w = run(seed=999)
warm_ms = [f'{x*1000:.0f}' for x in w['steps']]
print(f"[warmup] total {w['total']:.2f}s | steps {warm_ms}ms")
runs = [run(seed=i) for i in range(NRUN)]

med = st.median
ns = len(runs[0]['steps'])
# per_step[i]: median across the timed runs of the i-th transformer call.
per_step = [med([r['steps'][i] for r in runs]) for i in range(ns)]
tot, den, dec, peak = (med([r[key] for r in runs])
                       for key in ("total", "denoise", "decode", "peak"))
# Whatever a run spends outside the timed transformer calls and VAE decode.
ovh = med([r['total'] - r['denoise'] - r['decode'] for r in runs])
# Median of the per-step medians for steps 1.. (step 0 excluded); falls back
# to step 0 when only one step ran.
stepN = med(per_step[1:]) if ns > 1 else per_step[0]
# Human-readable summary; every timing below is derived from medians across
# the NRUN timed runs.
print(f"\n================ RESULTS (median of {NRUN} runs) ================")
print(f"model / precision : {TAG} bf16 (unquantized)")
print(f"resolution / steps : {W}x{H} / {STEPS}")
print("per-step transformer : " + ", ".join(f"{s*1000:.0f}ms" for s in per_step))
# stepN is the median of the per-step medians for steps 1.. (step 0 excluded),
# not a mean, so the row is labelled "median".
print(f" step1+ median : {stepN*1000:.0f}ms")
print(f"denoise (sum steps) : {den*1000:.0f}ms -> {STEPS/den:.2f} steps/s (transformer-only)")
print(f"vae decode : {dec*1000:.0f}ms")
print(f"text-encode+overhead : {ovh*1000:.0f}ms")
print(f"END-TO-END : {tot:.2f}s -> {1/tot:.2f} img/s")
print(f"peak VRAM : {peak:.2f} GB")
# Persist the last timed run's image and a JSON summary sharing its base name
# (outputs go under outputs/nunchaku alongside the Nunchaku profile results).
os.makedirs("outputs/nunchaku", exist_ok=True)
out = f"outputs/nunchaku/{TAG}_bf16_{W}x{H}_{STEPS}s.png"
runs[-1]['img'].save(out)
res = {
    "model": TAG,
    "engine": "bf16",
    "precision": "bf16",
    "H": H,
    "W": W,
    "steps": STEPS,
    "per_step_ms": [round(x*1000, 1) for x in per_step],
    "denoise_ms": round(den*1000, 1),
    "steps_per_s": round(STEPS/den, 3),
    "decode_ms": round(dec*1000, 1),
    "overhead_ms": round(ovh*1000, 1),
    "total_s": round(tot, 3),
    "img_per_s": round(1/tot, 3),
    "peak_vram_gb": round(peak, 2),
    "device": torch.cuda.get_device_name(0),
}
json_path = os.path.splitext(out)[0] + ".json"
with open(json_path, "w") as f:
    json.dump(res, f, indent=2)
print(f"saved sample -> {out}")

Xet Storage Details

Size:
4.65 kB
·
Xet hash:
fd000a9f04dd602954169040f2a9f2b45736caa8d671354efe19973a850d3f11

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.