# Buckets:
"""bf16 baseline profiler — same timing harness as 20_nunchaku_profile.py, for a stock
(unquantized) Flux2KleinPipeline. It provides the bf16 reference that the Nunchaku
FP4/INT4 kernel numbers are measured against.

Times transformer.forward (per denoise step) and vae.decode, each bracketed by
torch.cuda.synchronize(). One warm-up run is excluded from the statistics. Uses the same
prompt / resolution / steps as the Nunchaku profile. Writes a sample PNG and a JSON
summary to outputs/nunchaku/.

Usage:
    PYTHONPATH=. python3 scripts/21_bf16_profile.py [model_dir] [steps] [H] [W] [n_timed]

    model_dir  default models/klein-4b (our distilled 4B teacher)
    steps      default 4;  H, W  default 1024;  n_timed  default 3
"""
import json
import os
import statistics as st
import sys
import time

import torch
from diffusers import Flux2KleinPipeline


def _int_arg(index, default):
    """Return positional CLI argument *index* parsed as an int, or *default* if absent."""
    return int(sys.argv[index]) if len(sys.argv) > index else default


# Positional CLI: [model_dir] [steps] [H] [W] [n_timed]
MODEL = sys.argv[1] if len(sys.argv) > 1 else "models/klein-4b"
STEPS = _int_arg(2, 4)
H = _int_arg(3, 1024)
W = _int_arg(4, 1024)
NRUN = _int_arg(5, 3)
TAG = os.path.basename(MODEL.rstrip("/"))  # e.g. "klein-4b"; used in output filenames

# Fail fast with a readable message. Without this check, n_timed=0 is only detected
# after the warm-up has already run, as an IndexError on runs[0] during aggregation.
for _name, _val in (("steps", STEPS), ("H", H), ("W", W), ("n_timed", NRUN)):
    if _val < 1:
        sys.exit(f"{_name} must be >= 1, got {_val}")
if not torch.cuda.is_available():
    sys.exit("CUDA device required: all timings rely on torch.cuda.synchronize()")

PROMPT = ("a photorealistic storefront at golden hour with a glowing neon sign that reads "
          "\"NUNCHAKU\", wet pavement reflections, shallow depth of field, 50mm")
print(f"=== bf16 baseline | {MODEL} | {W}x{H} | {STEPS} steps | {NRUN} timed runs ===")
print(f"device={torch.cuda.get_device_name(0)} | torch={torch.__version__}")
# Load the stock bf16 pipeline and move it to the GPU. The peak-memory counter is reset
# first, so the "post-load VRAM" figure is the allocator peak (max_memory_allocated)
# reached since that reset.
torch.cuda.reset_peak_memory_stats()
t0 = time.perf_counter()
pipe = Flux2KleinPipeline.from_pretrained(MODEL, torch_dtype=torch.bfloat16)
pipe.to("cuda")
torch.cuda.synchronize()  # wait for any pending GPU work before stopping the clock
load_seconds = time.perf_counter() - t0
load_peak_gb = torch.cuda.max_memory_allocated() / 1e9
print(f"[load] {load_seconds:.1f}s | post-load VRAM {load_peak_gb:.2f} GB")
# Timing sinks, one entry per call: transformer forwards go to step_times and VAE decodes
# to dec_times. run() empties them in place with .clear(), so the closures below keep
# appending to the same list objects.
step_times, dec_times = [], []


def _synced_timer(fn, sink):
    """Wrap *fn* so every call is bracketed by CUDA syncs and its wall time is appended to *sink*.

    The leading sync keeps earlier queued GPU work out of the interval. The trailing sync
    makes the interval include the GPU work *fn* enqueued.
    """
    def wrapper(*a, **k):
        torch.cuda.synchronize()
        start = time.perf_counter()
        result = fn(*a, **k)
        torch.cuda.synchronize()
        sink.append(time.perf_counter() - start)
        return result
    return wrapper


# Instance attributes shadow the bound methods. This assumes the pipeline reaches the
# transformer (nn.Module.__call__ -> self.forward) and the VAE through these attributes.
# The originals are kept in _fwd / _dec.
_fwd = pipe.transformer.forward
timed_fwd = _synced_timer(_fwd, step_times)
pipe.transformer.forward = timed_fwd
_dec = pipe.vae.decode
timed_dec = _synced_timer(_dec, dec_times)
pipe.vae.decode = timed_dec
def run(seed=0):
    """Generate one image with *seed* and return its timing / memory breakdown.

    Returns a dict with:
      img      -- first image returned by the pipeline
      total    -- end-to-end wall time in seconds (CUDA-synchronized at both ends)
      peak     -- max allocated CUDA memory during this call, in GB (1e9 bytes)
      steps    -- copy of the per-forward transformer times (seconds)
      denoise  -- sum of those forward times
      decode   -- sum of the VAE decode times
    """
    # Each run starts with empty timing sinks and a fresh peak-memory window.
    for sink in (step_times, dec_times):
        sink.clear()
    torch.cuda.reset_peak_memory_stats()
    generator = torch.Generator("cpu").manual_seed(seed)
    torch.cuda.synchronize()
    started = time.perf_counter()
    image = pipe(prompt=PROMPT, guidance_scale=1.0, num_inference_steps=STEPS,
                 height=H, width=W, generator=generator).images[0]
    torch.cuda.synchronize()
    total = time.perf_counter() - started
    return {
        "img": image,
        "total": total,
        "peak": torch.cuda.max_memory_allocated() / 1e9,
        "steps": list(step_times),
        "denoise": sum(step_times),
        "decode": sum(dec_times),
    }
# One warm-up generation with its own seed. It is printed for reference only; the
# statistics below use just the NRUN timed runs (seeds 0..NRUN-1).
print("[warmup] running (excluded) ...")
warm = run(seed=999)
warm_step_ms = [f'{x*1000:.0f}' for x in warm['steps']]
print(f"[warmup] total {warm['total']:.2f}s | steps {warm_step_ms}ms")
runs = [run(seed=seed) for seed in range(NRUN)]
# Aggregate the timed runs. Every metric is the median across runs, which is robust to a
# single outlier run. per_step[i] is the median time of the i-th transformer forward.
if not runs:
    raise RuntimeError("no timed runs to aggregate (n_timed must be >= 1)")
med = st.median
ns = len(runs[0]['steps'])
if ns == 0:
    raise RuntimeError("no transformer forward calls were timed in the first run")
if any(len(r['steps']) != ns for r in runs):
    raise RuntimeError(
        f"timed runs recorded different step counts: {[len(r['steps']) for r in runs]}")
per_step = [med([r['steps'][i] for r in runs]) for i in range(ns)]
tot, den = med([r['total'] for r in runs]), med([r['denoise'] for r in runs])
dec = med([r['decode'] for r in runs])
ovh = med([r['total'] - r['denoise'] - r['decode'] for r in runs])
peak = med([r['peak'] for r in runs])
# Reported as "step1+ mean": the mean of the per-step medians with step 0 excluded.
# Falls back to step 0 when only one step was timed.
stepN = st.mean(per_step[1:]) if ns > 1 else per_step[0]
# Human-readable summary of the aggregated numbers. Durations are shown in ms, except
# END-TO-END, which is in seconds; peak VRAM is in GB.
_step_list = ", ".join(f"{s*1000:.0f}ms" for s in per_step)
_report_lines = [
    f"\n================ RESULTS (median of {NRUN} runs) ================",
    f"model / precision : {TAG} bf16 (unquantized)",
    f"resolution / steps : {W}x{H} / {STEPS}",
    f"per-step transformer : {_step_list}",
    f" step1+ mean : {stepN*1000:.0f}ms",
    f"denoise (sum steps) : {den*1000:.0f}ms -> {STEPS/den:.2f} steps/s (transformer-only)",
    f"vae decode : {dec*1000:.0f}ms",
    f"text-encode+overhead : {ovh*1000:.0f}ms",
    f"END-TO-END : {tot:.2f}s -> {1/tot:.2f} img/s",
    f"peak VRAM : {peak:.2f} GB",
]
for _line in _report_lines:
    print(_line)
# Persist the last timed run's image and a JSON summary of the aggregated metrics. Both
# files share one filename stem so the sample and its numbers always pair up.
out_dir = "outputs/nunchaku"
os.makedirs(out_dir, exist_ok=True)
stem = f"{out_dir}/{TAG}_bf16_{W}x{H}_{STEPS}s"
out = f"{stem}.png"
runs[-1]['img'].save(out)
res = dict(model=TAG, engine="bf16", precision="bf16", H=H, W=W, steps=STEPS,
           per_step_ms=[round(x*1000, 1) for x in per_step], denoise_ms=round(den*1000, 1),
           steps_per_s=round(STEPS/den, 3), decode_ms=round(dec*1000, 1),
           overhead_ms=round(ovh*1000, 1), total_s=round(tot, 3), img_per_s=round(1/tot, 3),
           peak_vram_gb=round(peak, 2), device=torch.cuda.get_device_name(0))
json_out = f"{stem}.json"
with open(json_out, "w", encoding="utf-8") as f:
    json.dump(res, f, indent=2)
print(f"saved sample -> {out}")
print(f"saved metrics -> {json_out}")
# Xet Storage Details
# - Size: 4.65 kB
# - Xet hash: fd000a9f04dd602954169040f2a9f2b45736caa8d671354efe19973a850d3f11
#
# Xet efficiently stores files, intelligently splitting them into unique chunks and
# accelerating uploads and downloads. More info.