Buckets:

Mercity/FluxDistill / scripts /32_gen_eval.py
Pranav2748's picture
download
raw
2.89 kB
"""Generate eval images for ONE model into a per-mode dir (one model per process -> no OOM).
Modes:
  teacher   bf16 teacher (reference for LPIPS/PSNR + FID-vs-teacher)
  fq:RANK   fake-quant NVFP4 W4A4 at given rank: 0 = plain NVFP4 (no low-rank branch),
            >0 = SVDQuant. Recipe = plain-SVD, no-smooth, refine=3, w/a group-16 NVFP4.
            (fake-quant = numerically the real-kernel quality; slow but correct.)
  bfl       BFL official NVFP4 (wired in scripts/33 once loading is solved)

Reads outputs/eval/prompts.json (override the path with the PROMPTS_JSON environment
variable); seed = idx (paired across modes). Saves {OUT}/{idx:05d}.png.

Usage: python3 -u scripts/32_gen_eval.py MODE OUT_DIR [START] [COUNT] [RES]
"""
import sys, json, os, time, torch
from flux2distill.model_utils import load_pipeline
# ---- Command line: MODE OUT_DIR [START] [COUNT] [RES] -----------------------
if len(sys.argv) < 3:
    # Exit with a usage line instead of an IndexError traceback.
    raise SystemExit("usage: python3 -u scripts/32_gen_eval.py MODE OUT_DIR [START] [COUNT] [RES]")
MODE = sys.argv[1]  # 'teacher' or 'fq:RANK'
OUT = sys.argv[2]   # directory that receives the generated PNGs
START = int(sys.argv[3]) if len(sys.argv) > 3 else 0      # slice offset into the prompt list
COUNT = int(sys.argv[4]) if len(sys.argv) > 4 else 10**9  # default: everything from START on
RES = int(sys.argv[5]) if len(sys.argv) > 5 else 512      # image height and width

# Reject an unsupported MODE now, before the pipeline is loaded, so a typo
# does not cost a full model load. The message matches the later dispatch check.
if MODE != 'teacher' and not MODE.startswith('fq:'):
    raise SystemExit(f"unknown MODE {MODE}")

os.makedirs(OUT, exist_ok=True)

# ---- Prompts ----------------------------------------------------------------
# Each entry is a dict; the rest of the script reads its 'prompt' and 'idx' keys.
prompts_path = os.environ.get('PROMPTS_JSON', 'outputs/eval/prompts.json')
with open(prompts_path, encoding='utf-8') as fh:  # close the handle deterministically
    prompts = json.load(fh)[START:START + COUNT]
print(f"=== gen MODE={MODE} OUT={OUT} N={len(prompts)} RES={RES} ===", flush=True)

# ---- Model ------------------------------------------------------------------
pipe = load_pipeline(device='cuda')
tf = pipe.transformer
# Inference only: eval mode, and no gradients tracked on the transformer.
tf.eval().requires_grad_(False)
def gen(prompt, seed):
    """Run the pipeline once for ``prompt`` and return the first output image.

    A fresh CUDA generator seeded with ``seed`` is passed to the pipeline, so
    the same (prompt, seed) pair uses the same seed in every mode. Sampling
    uses 4 inference steps and guidance scale 1.0 at RES x RES, under bf16
    autocast.
    """
    rng = torch.Generator('cuda')
    rng.manual_seed(seed)
    call_kwargs = dict(
        prompt=prompt,
        num_inference_steps=4,
        guidance_scale=1.0,
        height=RES,
        width=RES,
        generator=rng,
    )
    with torch.autocast('cuda', dtype=torch.bfloat16):
        result = pipe(**call_kwargs)
    return result.images[0]
# ---- Mode dispatch: optionally replace the transformer's linears with
# fake-quantized versions before any eval image is generated. ----
if MODE.startswith('fq:'):
    try:
        rank = int(MODE.split(':')[1])
    except ValueError:
        # 'fq:' or 'fq:abc' -> clear message instead of a traceback.
        raise SystemExit(f"bad MODE {MODE}: expected fq:RANK with an integer RANK") from None
    from flux2distill.svdquant import collect_act_stats, apply_svdquant_from_stats, target_linear_names
    names = target_linear_names(tf)
    # Returns the stats container plus hook handles that must be removed after
    # calibration (the middle return value is unused here; with_gram=False).
    absmax, _, handles = collect_act_stats(tf, names, with_gram=False)
    print("calibrating act stats (4 fwd passes)...", flush=True)
    # NOTE(review): `prompts` is already sliced by START/COUNT, so calibration
    # uses the first (up to) 4 prompts of *this shard*; shards launched with
    # different START values calibrate on different prompts. Confirm this is
    # intended if shards are meant to share one quantized model.
    try:
        for d in prompts[:4]:
            gen(d['prompt'], 0)  # output discarded; the pass only feeds the hooks
    finally:
        # Always detach the stat hooks, even if a calibration pass fails.
        for h in handles:
            h.remove()
    apply_svdquant_from_stats(tf, absmax, rank=rank, whiten=False, smooth=False, refine_iters=3,
                              w_bits=4, a_bits=4, w_group=16, a_group=16,
                              w_fmt='nvfp4', a_fmt='nvfp4', grams=None)
    tag = 'plain-NVFP4 (rank0)' if rank == 0 else f'SVDQuant-NVFP4 rank{rank}'
    print(f"applied fake-quant NVFP4 W4A4 {tag}, no-smooth plain-SVD refine=3", flush=True)
elif MODE == 'teacher':
    pass  # unmodified pipeline is the reference
else:
    raise SystemExit(f"unknown MODE {MODE}")
# ---- Generation loop: one PNG per prompt, resumable. ----
t0 = time.time()
done = 0  # images written by this run (already-existing files are not counted)
for d in prompts:
    f = os.path.join(OUT, f"{d['idx']:05d}.png")
    if os.path.exists(f):
        continue  # resume support: keep images from an earlier run
    # Write to a sibling temp file and move it into place, so an interrupted
    # save never leaves a truncated NNNNN.png that the check above would then
    # skip forever. The temp name does not end in .png, hence format="PNG".
    tmp = f + ".part"
    try:
        gen(d['prompt'], d['idx']).save(tmp, format="PNG")  # seed = idx
        os.replace(tmp, f)
    except BaseException:
        # Also covers Ctrl-C: drop the partial file, then propagate.
        if os.path.exists(tmp):
            os.remove(tmp)
        raise
    done += 1
    if done % 25 == 0:
        print(f" {done}/{len(prompts)} ({time.time()-t0:.0f}s)", flush=True)
print(f"DONE {MODE} -> {OUT} ({done} new imgs, {time.time()-t0:.0f}s)", flush=True)

Xet Storage Details

Size:
2.89 kB
·
Xet hash:
e6d468b839800a62964b12af4ac2b59ab93d58f7ebb3dd5a65e2b34463258aa9

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.