Buckets:
| """Compute quant-eval metrics for each model's generated images. | |
Per model: PickScore (↑, human-preference proxy), CLIP score (↑, text alignment), LPIPS↓ and PSNR↑
vs. the teacher, FID↓ vs. MJHQ-real, and FID↓ vs. the teacher's outputs. Writes
outputs/eval/metrics.json and prints a summary table. Prompts are read from
outputs/eval/prompts.json; image filenames are integer prompt indices, and teacher images are looked
up by the zero-padded 5-digit index (e.g. 00012.png).

PickScore replaces ImageReward (whose bundled BLIP is incompatible with transformers 5.x); it is a
native CLIP-H model and a standard human-preference metric. Relative scores across models (same
prompts) are what matter for SVDQuant-vs-plain.

Usage: python3 scripts/34_metrics.py TEACHER_DIR REF_DIR MODEL_DIR [MODEL_DIR ...]
| """ | |
| import sys, json, os, glob | |
| import numpy as np, torch | |
| from PIL import Image | |
# Fail fast with a usage line instead of an IndexError on sys.argv.
# (Zero MODEL_DIRs is still accepted: only the teacher row is computed then.)
if len(sys.argv) < 3:
    sys.exit("Usage: python3 scripts/34_metrics.py TEACHER_DIR REF_DIR MODEL_DIR [MODEL_DIR ...]")

# Prompt index -> prompt text; image filenames are these indices (see idx_of).
# JSON is UTF-8, so don't depend on the platform's default encoding.
with open('outputs/eval/prompts.json', encoding='utf-8') as f:
    prompts = {d['idx']: d['prompt'] for d in json.load(f)}
TEACHER, REF, MODELS = sys.argv[1], sys.argv[2], sys.argv[3:]

# Prefer the GPU; fall back to CPU (slow, but runnable) when CUDA is unavailable.
dev = 'cuda' if torch.cuda.is_available() else 'cpu'

# LPIPS with the AlexNet backbone, used for perceptual distance to the teacher.
import lpips as lpips_lib
lpips_fn = lpips_lib.LPIPS(net='alex').to(dev).eval()

# open_clip ViT-B-32 (OpenAI weights) for the CLIP alignment score.
import open_clip
clip_model, _, clip_pre = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai')
clip_model = clip_model.to(dev).eval()
clip_tok = open_clip.get_tokenizer('ViT-B-32')

# PickScore: preprocessing comes from the base CLIP-H checkpoint, weights from PickScore_v1.
# Optional -- if either download/load fails the metric is reported as None.
from transformers import AutoModel, AutoProcessor
try:
    ps_proc = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
    ps_model = AutoModel.from_pretrained("yuvalkirstain/PickScore_v1").to(dev).eval()
    HAS_PS = True
except Exception as e:
    print("PickScore unavailable:", type(e).__name__, str(e)[:90])
    HAS_PS = False

from cleanfid import fid as cleanfid
def imgs(d):
    """Return every ``*.png`` and ``*.jpg`` file directly inside directory *d*, sorted by path."""
    found = []
    for pattern in ('*.png', '*.jpg'):
        found.extend(glob.glob(os.path.join(d, pattern)))
    found.sort()
    return found
def idx_of(p):
    """Parse the integer prompt index from an image path's filename stem (``.../00012.png`` -> 12)."""
    stem, _ext = os.path.splitext(os.path.basename(p))
    return int(stem)
def load(p):
    """Open the image file at *p* and return it converted to 3-channel RGB."""
    image = Image.open(p)
    return image.convert('RGB')
def to_t(im):
    """Convert an RGB image to a float32 tensor of shape (1, 3, H, W) scaled to [-1, 1] on ``dev``.

    [-1, 1] is the input range the LPIPS call in this script is fed.
    """
    arr = np.asarray(im).astype('float32')
    arr = arr / 127.5 - 1  # [0, 255] -> [-1, 1]
    chw = torch.from_numpy(arr).permute(2, 0, 1)  # HWC -> CHW
    return chw.unsqueeze(0).to(dev)
def pickscore(prompt, im):
    """Return the PickScore of image *im* for text *prompt* (higher = more preferred).

    Runs the PickScore model's joint CLIP forward; ``logits_per_image`` is the scaled
    image-text similarity, i.e. the PickScore. Uses the module-level ``ps_proc``,
    ``ps_model`` and ``dev``.
    """
    # Prompts longer than CLIP's 77-token context are truncated.
    inp = ps_proc(text=[prompt], images=im, return_tensors="pt", padding=True,
                  truncation=True, max_length=77).to(dev)
    # Scoring only: don't record the ViT-H forward for autograd.
    with torch.no_grad():
        return float(ps_model(**inp).logits_per_image[0, 0].item())
def per_image(model_dir, vs_teacher):
    """Average per-image quality metrics over every image in *model_dir*.

    Each image's filename stem is its prompt index (``idx_of``), looked up in the
    module-level ``prompts`` dict. Per image this collects:

    * PickScore against the prompt, when the PickScore model loaded (``HAS_PS``);
    * CLIP score: 100 * cosine similarity of open_clip ViT-B-32 image/text embeddings;
    * when *vs_teacher* is true and TEACHER has an image with the same zero-padded
      5-digit index (.png preferred, .jpg accepted), LPIPS (AlexNet) and PSNR
      (8-bit peak) against that teacher image.

    Args:
        model_dir: directory of .png/.jpg images named by prompt index.
        vs_teacher: whether to compute the teacher-relative metrics (LPIPS, PSNR).

    Returns:
        dict with keys 'PickScore', 'CLIP', 'LPIPS', 'PSNR'; each value is the mean
        rounded to 4 decimals, or None if no value was collected for that metric.

    Raises:
        ValueError: if a generated image and its teacher image differ in size.
    """
    r = {'PickScore': [], 'CLIP': [], 'LPIPS': [], 'PSNR': []}
    ps_failures = 0
    # Pure inference: skip autograd bookkeeping for the CLIP and LPIPS forwards.
    with torch.no_grad():
        for p in imgs(model_dir):
            i = idx_of(p)
            prompt = prompts[i]
            im = load(p)
            if HAS_PS:
                try:
                    r['PickScore'].append(pickscore(prompt, im))
                except Exception as e:
                    # Best-effort, but visible: a failed image drops out of the PickScore mean.
                    ps_failures += 1
                    if ps_failures == 1:
                        print(f"  PickScore failed on {p}: {type(e).__name__}: {str(e)[:90]}", flush=True)
            it = clip_pre(im)[None].to(dev)
            tt = clip_tok([prompt]).to(dev)
            imf = clip_model.encode_image(it)
            txf = clip_model.encode_text(tt)
            imf = imf / imf.norm(dim=-1, keepdim=True)
            txf = txf / txf.norm(dim=-1, keepdim=True)
            r['CLIP'].append(float((imf * txf).sum(-1).item()) * 100)
            if not vs_teacher:
                continue
            # Teacher image with the same index; .png first (original behavior), then .jpg.
            tp = None
            for ext in ('.png', '.jpg'):
                cand = os.path.join(TEACHER, f"{i:05d}{ext}")
                if os.path.exists(cand):
                    tp = cand
                    break
            if tp is None:
                continue
            tim = load(tp)
            if tim.size != im.size:
                raise ValueError(f"size mismatch: {p} is {im.size}, teacher {tp} is {tim.size}")
            r['LPIPS'].append(float(lpips_fn(to_t(im), to_t(tim)).item()))
            a = np.asarray(im).astype('float32')
            b = np.asarray(tim).astype('float32')
            mse = ((a - b) ** 2).mean()
            # Identical images have infinite PSNR; record a 99 dB cap so the mean stays finite.
            r['PSNR'].append(99.0 if mse < 1e-9 else float(10 * np.log10(255.0 ** 2 / mse)))
    if ps_failures:
        print(f"  PickScore failed on {ps_failures} image(s); its mean covers the rest", flush=True)
    return {k: (round(float(np.mean(v)), 4) if v else None) for k, v in r.items()}
# Evaluate the teacher first (no teacher-relative metrics), then every model dir.
out = {}
teacher_abs = os.path.abspath(TEACHER)
# The real-image reference set is fixed, so check it once rather than per model.
has_ref = os.path.isdir(REF) and bool(imgs(REF))
for m in [TEACHER] + MODELS:
    name = os.path.basename(m.rstrip('/'))
    if name in out:
        # Two dirs with the same basename would otherwise overwrite each other's row.
        base, k = name, 2
        while f"{base}#{k}" in out:
            k += 1
        name = f"{base}#{k}"
    n_imgs = len(imgs(m))
    print(f"--- {name} ({n_imgs} imgs) ---", flush=True)
    is_teacher = os.path.abspath(m) == teacher_abs
    pm = per_image(m, vs_teacher=not is_teacher)
    # FID needs images: an empty dir gets None rather than handing clean-fid nothing to compare.
    fid_real = round(cleanfid.compute_fid(m, REF, verbose=False), 3) if has_ref and n_imgs else None
    fid_teacher = (None if is_teacher or not n_imgs
                   else round(cleanfid.compute_fid(m, TEACHER, verbose=False), 3))
    out[name] = {**pm, 'FID_vs_real': fid_real, 'FID_vs_teacher': fid_teacher}
    print(f" {out[name]}", flush=True)
    # Rewrite after every model so a late failure doesn't discard the finished (slow) FID runs.
    with open('outputs/eval/metrics.json', 'w', encoding='utf-8') as f:
        json.dump(out, f, indent=2)

print("\n=== SUMMARY (outputs/eval/metrics.json) ===")
print(f"{'model':30s} {'PickScore↑':10s} {'CLIP↑':7s} {'LPIPS↓':7s} {'PSNR↑':7s} {'FIDreal↓':9s} {'FIDteach↓':9s}")
for name, v in out.items():
    print(f"{name:30s} {str(v['PickScore']):10s} {str(v['CLIP']):7s} {str(v['LPIPS']):7s} {str(v['PSNR']):7s} {str(v['FID_vs_real']):9s} {str(v['FID_vs_teacher']):9s}")
Xet Storage Details
- Size:
- 4.42 kB
- Xet hash:
- 556203498802d8befe74d201bb159782fe1372c043f3285521b5af707330ec4a
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.