Buckets:
"""True same-model FP4-vs-bf16 kernel speedup: time the bf16 9B Flux2 transformer.forward on
the EXACT inputs captured from the FP4 run (tmp/fwd9b.pt). Only the kernel/precision differs,
so (bf16 per-step) / (FP4 per-step) is a clean transformer-only speedup on this Blackwell card.
Loads the bf16 transformer ALONE (no text encoder / vae) so 18GB fits in 32GB with no offload.
Usage: PYTHONPATH=. python3 scripts/22_bf16_9b_bench.py [inputs.pt] [n_iter] [transformer_dir]
"""
import json
import os
import statistics as st
import sys
import time

import torch
from diffusers import Flux2Transformer2DModel
# Optional positional CLI arguments; each falls back to its default when absent:
#   1) path of the captured forward() inputs
#   2) number of timed iterations (parsed as int)
#   3) directory holding the transformer checkpoint
_cli = sys.argv[1:]
DUMP = _cli[0] if len(_cli) >= 1 else "tmp/fwd9b.pt"
NITER = int(_cli[1]) if len(_cli) >= 2 else 20
TDIR = _cli[2] if len(_cli) >= 3 else "models/klein-9b-nunchaku/transformer"
def to_cuda(x):
    """Recursively move every tensor found inside ``x`` to the CUDA device.

    Args:
        x: A tensor, a list/tuple/namedtuple/dict that may contain tensors
            (nested to any depth), or any other object.

    Returns:
        A tensor moved with ``.to("cuda")``; a list, tuple or namedtuple of the
        same type with its items converted; a plain ``dict`` with the same keys
        and converted values; or ``x`` itself for anything else.
    """
    if torch.is_tensor(x):
        return x.to("cuda")
    if isinstance(x, tuple) and hasattr(x, "_fields"):
        # namedtuple constructors take one positional argument per field, so
        # passing a single generator (as done for list/tuple) raises TypeError.
        return type(x)(*(to_cuda(v) for v in x))
    if isinstance(x, (list, tuple)):
        return type(x)(to_cuda(v) for v in x)
    if isinstance(x, dict):
        return {k: to_cuda(v) for k, v in x.items()}
    return x
# Benchmark driver: replay the captured forward() inputs through the bf16
# transformer, time each call, then print and save a summary.

OUT_JSON = "outputs/nunchaku/bf16_9b_transformer_bench.json"
WARMUP = 3  # untimed calls made before measurement starts

# Reject a non-positive iteration count before the expensive model load:
# statistics.median() raises on the empty sample it would otherwise produce.
if NITER < 1:
    raise SystemExit(f"n_iter must be >= 1 (got {NITER})")

# NOTE(review): weights_only=False unpickles arbitrary Python objects. Only
# point this at dumps you produced yourself.
d = torch.load(DUMP, map_location="cpu", weights_only=False)
args, kwargs = to_cuda(d["args"]), to_cuda(d["kwargs"])

print(f"=== bf16 9B transformer-only bench | inputs={DUMP} | {NITER} iters ===")
print(f"device={torch.cuda.get_device_name(0)} torch={torch.__version__}")

# Record the shape of every top-level tensor input (keyword, then positional).
shapes = {k: tuple(v.shape) for k, v in kwargs.items() if torch.is_tensor(v)}
shapes.update({f"arg{i}": tuple(v.shape) for i, v in enumerate(args) if torch.is_tensor(v)})
print("input tensor shapes:", shapes)

# Load the transformer in bf16 and report load time plus peak allocation.
torch.cuda.reset_peak_memory_stats()
t0 = time.perf_counter()
m = Flux2Transformer2DModel.from_pretrained(TDIR, torch_dtype=torch.bfloat16).to("cuda").eval()
torch.cuda.synchronize()
print(f"[load] bf16 transformer {time.perf_counter()-t0:.1f}s | VRAM {torch.cuda.max_memory_allocated()/1e9:.2f} GB")

times = []
with torch.no_grad():
    for _ in range(WARMUP):
        m(*args, **kwargs)
    torch.cuda.synchronize()
    for _ in range(NITER):
        # Synchronise on both sides of the call so the wall-clock interval
        # covers the GPU work of exactly one forward pass.
        torch.cuda.synchronize()
        s = time.perf_counter()
        m(*args, **kwargs)
        torch.cuda.synchronize()
        times.append(time.perf_counter() - s)

med = st.median(times)
fastest, slowest = min(times), max(times)
peak_gb = torch.cuda.max_memory_allocated() / 1e9

print("\n================ bf16 9B transformer.forward ================")
print(f"per-step (median): {med*1000:.0f}ms (min {fastest*1000:.0f}, max {slowest*1000:.0f}, n={NITER})")
print(f"throughput : {1/med:.2f} steps/s")
print(f"peak VRAM : {peak_gb:.2f} GB")

# Create the output directory first so a missing folder cannot discard the
# results of a completed run.
os.makedirs(os.path.dirname(OUT_JSON), exist_ok=True)
with open(OUT_JSON, "w") as f:
    json.dump(
        dict(
            per_step_ms=round(med * 1000, 1),
            min_ms=round(fastest * 1000, 1),
            steps_per_s=round(1 / med, 3),
            n=NITER,
            shapes={k: list(v) for k, v in shapes.items()},
            # Extra fields: previously printed but not persisted.
            max_ms=round(slowest * 1000, 1),
            peak_vram_gb=round(peak_gb, 2),
        ),
        f,
        indent=2,
    )
print("saved -> outputs/nunchaku/bf16_9b_transformer_bench.json")
Xet Storage Details
- Size: 2.77 kB
- Xet hash: b36701b56ce93d77d99cffe273e7fa10e9a54afa6c9cdb9e30b4664ffe3fa266
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.