Buckets:

Mercity/FluxDistill / scripts /36_gen_batched.py
Pranav2748's picture
download
raw
2.27 kB
"""Batched diffusers generation (+timing) for teacher / BFL — batch>1 for speed.
(Our fused NVFP4 kernel can't batch; it uses scripts/35 at batch=1.)
Usage: python3 -u scripts/36_gen_batched.py MODE OUT_DIR TAG [N] [BATCH] [RES]
MODE: teacher | bfl
"""
import sys, json, os, time, statistics as st, torch
from flux2distill.model_utils import load_pipeline
# --- CLI arguments (usage line is in the module docstring) ---
MODE, OUT, TAG = sys.argv[1], sys.argv[2], sys.argv[3]
N = int(sys.argv[4]) if len(sys.argv) > 4 else 256    # number of prompts to render
B = int(sys.argv[5]) if len(sys.argv) > 5 else 8      # prompts per pipeline call
RES = int(sys.argv[6]) if len(sys.argv) > 6 else 512  # used as both height and width
# Fail fast, before the slow pipeline load: any MODE other than 'bfl' would
# silently run the teacher path (and write timing under a misleading tag).
if MODE not in ('teacher', 'bfl'):
    sys.exit(f"MODE must be 'teacher' or 'bfl', got {MODE!r}")
# B == 0 makes range(0, n, B) raise and B < 0 renders nothing at all --
# either way only discovered after the pipeline has been loaded.
if B < 1:
    sys.exit(f"BATCH must be >= 1, got {B}")
os.makedirs(OUT, exist_ok=True)
# JSON is UTF-8 by spec: don't depend on the locale default, and close the handle.
with open('outputs/eval/prompts.json', encoding='utf-8') as fh:
    prompts = json.load(fh)[:N]
print(f"=== batched gen MODE={MODE} OUT={OUT} N={len(prompts)} B={B} RES={RES} ===", flush=True)
# Load the base diffusers pipeline onto the GPU. In 'bfl' mode its transformer
# is replaced by the one from BFL's single-file NVFP4 checkpoint (loaded with
# torch_dtype=bfloat16); any other MODE keeps what load_pipeline() built.
pipe = load_pipeline(device='cuda')
if MODE == 'bfl':
    from diffusers import Flux2Transformer2DModel

    bfl_ckpt = "models/klein-4b-nvfp4/flux-2-klein-4b-nvfp4.safetensors"
    bfl_transformer = Flux2Transformer2DModel.from_single_file(bfl_ckpt, torch_dtype=torch.bfloat16)
    pipe.transformer = bfl_transformer.to('cuda')
    print("loaded BFL nvfp4 transformer", flush=True)

# Inference only: eval mode, and no gradients tracked for the transformer's parameters.
denoiser = pipe.transformer
denoiser.eval()
denoiser.requires_grad_(False)
# --- Generation: one pipeline call per slice of B prompts. ---
# Reset after loading so the reported peak VRAM covers generation only.
torch.cuda.reset_peak_memory_stats()
# Monotonic clock for progress lines, matching the per-image timer below
# (time.time() can jump if the wall clock is adjusted).
t0 = time.perf_counter()
per_img = []  # seconds per image: each batch's wall time divided by its size
for i in range(0, len(prompts), B):
    batch = prompts[i:i + B]
    # One generator per prompt, seeded with its idx, so each sample's initial
    # noise is tied to the prompt rather than to its position in the batch.
    gens = [torch.Generator('cuda').manual_seed(d['idx']) for d in batch]
    ps = [d['prompt'] for d in batch]
    # Synchronize on both sides so the timer covers finished GPU work,
    # not just asynchronous kernel launches.
    torch.cuda.synchronize()
    ts = time.perf_counter()
    imgs = pipe(prompt=ps, num_inference_steps=4, guidance_scale=1.0, height=RES, width=RES, generator=gens).images
    torch.cuda.synchronize()
    per_img.append((time.perf_counter() - ts) / len(batch))
    # PNG encoding happens after the timer stops, so it is excluded from per_img.
    for d, im in zip(batch, imgs):
        im.save(os.path.join(OUT, f"{d['idx']:05d}.png"))
    done = i + len(batch)
    # Report every 4th batch, and always the final one.
    if (i // B) % 4 == 0 or done == len(prompts):
        print(f" {done}/{len(prompts)} ({time.perf_counter()-t0:.0f}s)", flush=True)
# --- Timing summary ---
# Drop the first batch (presumably warm-up cost) when more than one batch ran.
# sp is empty only when no prompts were rendered, giving a null median.
sp = per_img[1:] if len(per_img) > 1 else per_img
timing = {"mode": MODE, "tag": TAG, "res": RES, "batch": B, "n": len(prompts),
          "s_per_img_median_batched": round(st.median(sp), 4) if sp else None,
          # max_memory_allocated counts tensor allocations since the reset, in decimal GB.
          "peak_vram_gb": round(torch.cuda.max_memory_allocated() / 1e9, 2)}
# Close (and therefore flush) the report before announcing completion.
with open(f"outputs/eval/timing_{TAG}.json", "w", encoding="utf-8") as fh:
    json.dump(timing, fh, indent=2)
print(f"DONE {MODE} -> {OUT} | {timing}", flush=True)

Xet Storage Details

Size:
2.27 kB
·
Xet hash:
a34f04402c620c90e43e9b076db3a90077cd531ee22215c8e9e2f094c8419419

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.