Buckets:
| """Generate eval images for ONE model into a per-mode dir (one model per process -> no OOM). | |
| Modes: | |
| teacher bf16 teacher (reference for LPIPS/PSNR + FID-vs-teacher) | |
| fq:RANK fake-quant NVFP4 W4A4 at given rank: 0 = plain NVFP4 (no low-rank branch), | |
| >0 = SVDQuant. Recipe = plain-SVD, no-smooth, refine=3, w/a group-16 NVFP4. | |
| (fake-quant = numerically the real-kernel quality; slow but correct.) | |
| bfl BFL official NVFP4 (not handled by this script: rejected as unknown MODE; to be wired in scripts/33 once loading is solved) | |
| Reads outputs/eval/prompts.json (override the path with the PROMPTS_JSON env var); seed = idx (paired across modes). Saves {OUT}/{idx:05d}.png. | |
| Usage: python3 -u scripts/32_gen_eval.py MODE OUT_DIR [START] [COUNT] [RES] | |
| """ | |
| import sys, json, os, time, torch | |
| from flux2distill.model_utils import load_pipeline | |
# ---- CLI arguments ---------------------------------------------------------
# Positional: MODE OUT_DIR [START] [COUNT] [RES]. Only MODE and OUT_DIR are
# required; fail with a usage line instead of a bare IndexError when missing.
if len(sys.argv) < 3:
    raise SystemExit(f"usage: python3 -u {sys.argv[0]} MODE OUT_DIR [START] [COUNT] [RES]")
MODE = sys.argv[1]
OUT = sys.argv[2]
START = int(sys.argv[3]) if len(sys.argv) > 3 else 0
COUNT = int(sys.argv[4]) if len(sys.argv) > 4 else 10**9  # default: everything from START on
RES = int(sys.argv[5]) if len(sys.argv) > 5 else 512  # passed as both height and width
os.makedirs(OUT, exist_ok=True)

# ---- Prompts ---------------------------------------------------------------
# The prompt file path can be overridden through the PROMPTS_JSON env var.
# Only the [START, START + COUNT) slice is kept for this process.
# The file is opened in a context manager so the handle is closed right after
# parsing (the previous json.load(open(...)) left it to the garbage collector).
prompts_path = os.environ.get('PROMPTS_JSON', 'outputs/eval/prompts.json')
with open(prompts_path, encoding='utf-8') as prompts_file:
    prompts = json.load(prompts_file)[START:START + COUNT]
print(f"=== gen MODE={MODE} OUT={OUT} N={len(prompts)} RES={RES} ===", flush=True)

# ---- Model -----------------------------------------------------------------
# Load the pipeline on the GPU; the transformer is put in eval mode with
# gradients disabled since this script only runs inference.
pipe = load_pipeline(device='cuda')
tf = pipe.transformer
tf.eval().requires_grad_(False)
def gen(prompt, seed):
    """Generate one image for `prompt` using a CUDA generator seeded with `seed`.

    Runs the module-level `pipe` under bf16 autocast with 4 inference steps,
    guidance scale 1.0 and a square RES x RES output, and returns the first
    entry of the pipeline result's `images`.
    """
    rng = torch.Generator('cuda')
    rng.manual_seed(seed)
    sampling = dict(
        num_inference_steps=4,
        guidance_scale=1.0,
        height=RES,
        width=RES,
        generator=rng,
    )
    with torch.autocast('cuda', dtype=torch.bfloat16):
        image = pipe(prompt=prompt, **sampling).images[0]
    return image
# ---- Mode dispatch ---------------------------------------------------------
# 'fq:RANK' -> replace the transformer's target linears with fake-quant NVFP4
#              W4A4 (RANK == 0: no low-rank branch; RANK > 0: SVDQuant).
# 'teacher' -> leave the loaded pipeline untouched.
# Anything else aborts with an "unknown MODE" message.
if MODE.startswith('fq:'):
    # Turn a malformed or negative rank into a clean one-line exit rather than
    # a ValueError traceback / an undefined rank reaching the quantizer.
    try:
        rank = int(MODE.split(':')[1])
    except ValueError:
        raise SystemExit(f"bad MODE {MODE}: expected fq:RANK with an integer RANK >= 0") from None
    if rank < 0:
        raise SystemExit(f"bad MODE {MODE}: expected fq:RANK with an integer RANK >= 0")

    # Imported lazily so 'teacher' runs never import the quantization module.
    from flux2distill.svdquant import collect_act_stats, apply_svdquant_from_stats, target_linear_names
    names = target_linear_names(tf)
    absmax, _, handles = collect_act_stats(tf, names, with_gram=False)
    print("calibrating act stats (4 fwd passes)...", flush=True)
    # Calibration: run up to 4 prompts through the pipeline (fixed seed 0) so
    # the statistics behind `absmax` get populated.
    # NOTE(review): `prompts` is already sliced by START/COUNT, so runs with a
    # different START calibrate on different prompts -- confirm this is intended.
    try:
        for d in prompts[:4]:
            gen(d['prompt'], 0)
    finally:
        # Always detach the collection handles, even if a calibration pass
        # raised, so they never linger on the model.
        for h in handles:
            h.remove()
    apply_svdquant_from_stats(tf, absmax, rank=rank, whiten=False, smooth=False, refine_iters=3,
                              w_bits=4, a_bits=4, w_group=16, a_group=16,
                              w_fmt='nvfp4', a_fmt='nvfp4', grams=None)
    tag = 'plain-NVFP4 (rank0)' if rank == 0 else f'SVDQuant-NVFP4 rank{rank}'
    print(f"applied fake-quant NVFP4 W4A4 {tag}, no-smooth plain-SVD refine=3", flush=True)
elif MODE == 'teacher':
    pass  # reference model: nothing to modify
else:
    raise SystemExit(f"unknown MODE {MODE}")
# ---- Generation loop -------------------------------------------------------
# One image per prompt, seeded with the prompt's idx so images are paired
# across modes. Existing outputs are skipped, which makes the run resumable.
t0 = time.time()
done = 0
for d in prompts:
    out_path = os.path.join(OUT, f"{d['idx']:05d}.png")
    if os.path.exists(out_path):
        continue  # already produced by an earlier (possibly interrupted) run
    # Save to a temp file and rename it into place: since the resume check
    # above trusts any existing file, a save interrupted mid-write must never
    # leave a truncated PNG under the final name. The temp name keeps the
    # '.png' suffix so the image's save() sees the same extension as before,
    # and includes the pid so concurrent processes cannot collide.
    tmp_path = os.path.join(OUT, f".{d['idx']:05d}.{os.getpid()}.tmp.png")
    try:
        gen(d['prompt'], d['idx']).save(tmp_path)
        os.replace(tmp_path, out_path)
    finally:
        # Only true when the save or rename failed; drop the partial file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    done += 1
    # `done` counts newly generated images only (skipped ones are excluded).
    if done % 25 == 0:
        print(f" {done}/{len(prompts)} ({time.time()-t0:.0f}s)", flush=True)
print(f"DONE {MODE} -> {OUT} ({done} new imgs, {time.time()-t0:.0f}s)", flush=True)
Xet Storage Details
- Size:
- 2.89 kB
- Xet hash:
- e6d468b839800a62964b12af4ac2b59ab93d58f7ebb3dd5a65e2b34463258aa9
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.