Buckets:
"""Batched diffusers generation (+timing) for teacher / BFL — batch>1 for speed.
(Our fused NVFP4 kernel can't batch; it uses scripts/35 at batch=1.)

Usage: python3 -u scripts/36_gen_batched.py MODE OUT_DIR TAG [N] [BATCH] [RES]
  MODE     teacher | bfl
  OUT_DIR  directory for the PNGs, named <idx:05d>.png (created if missing)
  TAG      suffix of the timing file outputs/eval/timing_<TAG>.json
  N        number of prompts taken from outputs/eval/prompts.json (default 256)
  BATCH    prompts per pipeline call (default 8, must be >= 1)
  RES      square output height/width in pixels (default 512)
"""
import sys, json, os, time, statistics as st, torch
from flux2distill.model_utils import load_pipeline

# Validate the command line before the (slow) pipeline load so a bad
# invocation fails immediately with the usage text.
if len(sys.argv) < 4:
    sys.exit(__doc__)
MODE, OUT, TAG = sys.argv[1], sys.argv[2], sys.argv[3]
if MODE not in ('teacher', 'bfl'):
    # Any other value would silently run the teacher transformer while the
    # timing JSON records it under the unrecognized MODE label.
    sys.exit(f"unknown MODE {MODE!r}; expected 'teacher' or 'bfl'\n{__doc__}")
N = int(sys.argv[4]) if len(sys.argv) > 4 else 256    # prompts to generate
B = int(sys.argv[5]) if len(sys.argv) > 5 else 8      # prompts per pipeline call
RES = int(sys.argv[6]) if len(sys.argv) > 6 else 512  # height == width, pixels
if B < 1:
    # range(0, n, B) raises on B == 0 and yields nothing for B < 0.
    sys.exit(f"BATCH must be >= 1, got {B}")

os.makedirs(OUT, exist_ok=True)
# The generation loop reads 'idx' (seed + output filename) and 'prompt' from
# each record. JSON is UTF-8 by spec, so don't rely on the locale encoding.
with open('outputs/eval/prompts.json', encoding='utf-8') as fh:
    prompts = json.load(fh)[:N]
print(f"=== batched gen MODE={MODE} OUT={OUT} N={len(prompts)} B={B} RES={RES} ===", flush=True)
pipe = load_pipeline(device='cuda')
# Optionally swap the pipeline's transformer for the BFL NVFP4 checkpoint,
# then freeze whichever transformer is in use (inference only).
if MODE == 'bfl':
    from diffusers import Flux2Transformer2DModel
    BFL = "models/klein-4b-nvfp4/flux-2-klein-4b-nvfp4.safetensors"
    tf = Flux2Transformer2DModel.from_single_file(BFL, torch_dtype=torch.bfloat16)
    # Replace the pipeline's transformer before moving the new one to the GPU.
    # load_pipeline(device='cuda') presumably already placed the original
    # transformer on the GPU. Calling .to('cuda') first would keep both
    # resident at once. Once unreferenced, the old module can be freed (assumes
    # nothing else holds a reference to it — TODO confirm in load_pipeline).
    pipe.transformer = tf
    pipe.transformer.to('cuda')
    print("loaded BFL nvfp4 transformer", flush=True)
pipe.transformer.eval().requires_grad_(False)
# Generate every prompt in chunks of B, timing each pipeline call and saving
# one PNG per prompt. Peak-VRAM stats are reset here, so the reported peak
# covers generation only, not model loading.
torch.cuda.reset_peak_memory_stats()
per_img = []  # one entry per pipeline call: wall seconds / images in the call
t0 = time.time()
for batch_no, start in enumerate(range(0, len(prompts), B)):
    chunk = prompts[start:start + B]
    texts = [rec['prompt'] for rec in chunk]
    # One generator per prompt, seeded by that prompt's idx.
    seeded = [torch.Generator('cuda').manual_seed(rec['idx']) for rec in chunk]

    # Synchronize on both sides so the interval covers all queued GPU work.
    torch.cuda.synchronize()
    started = time.perf_counter()
    images = pipe(prompt=texts, num_inference_steps=4, guidance_scale=1.0,
                  height=RES, width=RES, generator=seeded).images
    torch.cuda.synchronize()
    per_img.append((time.perf_counter() - started) / len(chunk))

    for rec, image in zip(chunk, images):
        image.save(os.path.join(OUT, f"{rec['idx']:05d}.png"))
    # Progress line on every 4th pipeline call.
    if batch_no % 4 == 0:
        print(f" {start+len(chunk)}/{len(prompts)} ({time.time()-t0:.0f}s)", flush=True)
# Summarize per-image latency and peak VRAM, then write them to
# outputs/eval/timing_<TAG>.json.
# The first pipeline call is dropped (presumably warm-up overhead) unless it
# is the only one. A trailing partial batch contributes one sample averaged
# over fewer images. The median is None when there were no prompts at all.
sp = per_img[1:] if len(per_img) > 1 else per_img
timing = {"mode": MODE, "tag": TAG, "res": RES, "batch": B, "n": len(prompts),
          "s_per_img_median_batched": round(st.median(sp), 4) if sp else None,
          # decimal GB (1e9 bytes) of tensor memory allocated since the reset
          "peak_vram_gb": round(torch.cuda.max_memory_allocated() / 1e9, 2)}
# Close (and flush) the file deterministically instead of relying on
# garbage collection of an anonymous file object.
with open(f"outputs/eval/timing_{TAG}.json", "w", encoding="utf-8") as fh:
    json.dump(timing, fh, indent=2)
print(f"DONE {MODE} -> {OUT} | {timing}", flush=True)
Xet Storage Details
- Size:
- 2.27 kB
- Xet hash:
- a34f04402c620c90e43e9b076db3a90077cd531ee22215c8e9e2f094c8419419
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.