Buckets:

Mercity/FluxDistill / scripts /34_metrics.py
Pranav2748's picture
download
raw
4.42 kB
"""Compute quant-eval metrics for each model's generated images.
Per model: PickScore (↑ human-preference proxy), CLIP score (↑ alignment), LPIPS↓ + PSNR↑ vs
teacher, FID↓ vs MJHQ-real, FID↓ vs teacher outputs. Writes outputs/eval/metrics.json + table.
PickScore replaces ImageReward (whose bundled BLIP is incompatible with transformers 5.x); it's a
native CLIP-H model and a standard human-preference metric. Relative scores across models (same
prompts) are what matter for SVDQuant-vs-plain.
Usage: python3 scripts/34_metrics.py TEACHER_DIR REF_DIR MODEL_DIR [MODEL_DIR ...]
"""
import sys, json, os, glob
import numpy as np, torch
from PIL import Image
# --- Inputs -------------------------------------------------------------------
# prompts.json is read as a list of objects with 'idx' and 'prompt' keys and turned
# into an idx -> prompt lookup, used to pair every image with its prompt.
# The file is opened in a `with` block so the handle is closed deterministically,
# and decoded as UTF-8 (the JSON interchange encoding) regardless of locale.
with open('outputs/eval/prompts.json', encoding='utf-8') as _prompts_fh:
    prompts = {d['idx']: d['prompt'] for d in json.load(_prompts_fh)}

# Positional CLI args: teacher image dir, reference image dir, then zero or more
# model image dirs. Fail with the usage line instead of a bare IndexError.
if len(sys.argv) < 3:
    sys.exit("Usage: python3 scripts/34_metrics.py TEACHER_DIR REF_DIR MODEL_DIR [MODEL_DIR ...]")
TEACHER, REF, MODELS = sys.argv[1], sys.argv[2], sys.argv[3:]
dev = 'cuda'

# --- Metric models --------------------------------------------------------------
# LPIPS (net='alex') for the perceptual distance against teacher images.
import lpips as lpips_lib
lpips_fn = lpips_lib.LPIPS(net='alex').to(dev).eval()

# open_clip ViT-B-32 with the 'openai' weights for the CLIP alignment score.
import open_clip
clip_model, _, clip_pre = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai')
clip_model = clip_model.to(dev).eval()
clip_tok = open_clip.get_tokenizer('ViT-B-32')

# PickScore is optional: if the processor or model cannot be loaded, the script
# keeps going and HAS_PS tells the per-image code to skip that metric.
from transformers import AutoModel, AutoProcessor
try:
    ps_proc = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
    ps_model = AutoModel.from_pretrained("yuvalkirstain/PickScore_v1").to(dev).eval()
    HAS_PS = True
except Exception as e:
    print("PickScore unavailable:", type(e).__name__, str(e)[:90])
    HAS_PS = False

from cleanfid import fid as cleanfid
def imgs(d):
    """Return the sorted paths of the ``*.png`` and ``*.jpg`` files directly inside ``d``."""
    found = []
    for pattern in ('*.png', '*.jpg'):
        found.extend(glob.glob(os.path.join(d, pattern)))
    found.sort()
    return found
def idx_of(p):
    """Return the integer index encoded in the file name of path ``p``.

    The extension is stripped and the remaining stem is parsed with ``int``,
    so ``'dir/00012.png'`` gives ``12``. A non-numeric stem raises ``ValueError``.
    """
    file_name = os.path.basename(p)
    stem, _ext = os.path.splitext(file_name)
    return int(stem)
def load(p):
    """Open the image at path ``p`` and return it converted to RGB mode."""
    image = Image.open(p)
    return image.convert('RGB')
def to_t(im):
    """Convert an image to a float32 tensor on ``dev`` with a leading batch axis.

    Pixel values are mapped with ``x / 127.5 - 1`` (so 0..255 becomes -1..1), the
    channel axis is moved from last to first, and a size-1 batch axis is prepended.
    """
    pixels = np.asarray(im).astype('float32')
    scaled = pixels / 127.5 - 1
    channels_first = torch.from_numpy(scaled).permute(2, 0, 1)
    return channels_first.unsqueeze(0).to(dev)
@torch.no_grad()
def pickscore(prompt, im):
    """Return the PickScore model's image-text logit for ``im`` against ``prompt``.

    The prompt and image go through the processor together (text padded and
    truncated to 77 tokens), and the single entry of ``logits_per_image`` from the
    model's joint forward pass is returned as a Python float.
    """
    batch = ps_proc(
        text=[prompt],
        images=im,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=77,
    ).to(dev)
    logits = ps_model(**batch).logits_per_image
    return float(logits[0, 0].item())
@torch.no_grad()
def per_image(model_dir, vs_teacher):
    """Average the per-image metrics over every image found in ``model_dir``.

    Each image's file stem is parsed as an integer index and used to look up its
    prompt in ``prompts``. Per image this collects:

    * PickScore, when the PickScore model loaded (``HAS_PS``). A failure on an
      individual image is reported and that image is skipped for this metric only.
    * CLIP score: cosine similarity of the normalised image/text features, times 100.
    * LPIPS and PSNR against ``TEACHER/<idx:05d>.png``, only when ``vs_teacher`` is
      true and that teacher file exists.

    Args:
        model_dir: directory whose ``*.png`` / ``*.jpg`` files are scored.
        vs_teacher: whether to also compare each image with its teacher image.

    Returns:
        dict with keys 'PickScore', 'CLIP', 'LPIPS', 'PSNR'. Each value is the mean
        rounded to 4 decimals, or None when no sample was collected for that metric.

    Raises:
        KeyError: if an image's index has no entry in ``prompts``.
        ValueError: if an image's file stem is not an integer.
    """
    r = {'PickScore': [], 'CLIP': [], 'LPIPS': [], 'PSNR': []}
    ps_failures = 0
    for p in imgs(model_dir):
        i = idx_of(p)
        prompt = prompts[i]
        im = load(p)

        if HAS_PS:
            try:
                r['PickScore'].append(pickscore(prompt, im))
            except Exception as e:
                # Still best-effort (one bad image must not abort the run), but no
                # longer silent: the first failure is shown and all are counted so a
                # PickScore mean computed over fewer images is visible in the log.
                ps_failures += 1
                if ps_failures == 1:
                    print(f"  PickScore failed on {p}: {type(e).__name__} {str(e)[:90]}", flush=True)

        # CLIP score: cosine similarity between L2-normalised image and text features.
        it = clip_pre(im)[None].to(dev)
        tt = clip_tok([prompt]).to(dev)
        imf = clip_model.encode_image(it)
        txf = clip_model.encode_text(tt)
        imf = imf / imf.norm(dim=-1, keepdim=True)
        txf = txf / txf.norm(dim=-1, keepdim=True)
        r['CLIP'].append(float((imf * txf).sum(-1).item()) * 100)

        if vs_teacher:
            tp = os.path.join(TEACHER, f"{i:05d}.png")
            if os.path.exists(tp):
                tim = load(tp)
                r['LPIPS'].append(float(lpips_fn(to_t(im), to_t(tim)).item()))
                # PSNR on the raw 0..255 pixel values; `a - b` needs both images to
                # have compatible shapes.
                a = np.asarray(im).astype('float32')
                b = np.asarray(tim).astype('float32')
                mse = ((a - b) ** 2).mean()
                # Identical images would give log10(inf); report a fixed 99.0 instead.
                r['PSNR'].append(99.0 if mse < 1e-9 else float(10 * np.log10(255.0 ** 2 / mse)))

    if ps_failures:
        print(f"  PickScore skipped for {ps_failures} image(s); mean covers {len(r['PickScore'])}", flush=True)
    return {k: (round(float(np.mean(v)), 4) if v else None) for k, v in r.items()}
# Evaluate the teacher directory first, then every model directory from the CLI.
# Results are keyed by directory basename.
# NOTE: two directories sharing a basename share a key; the later one overwrites.
out = {}
teacher_abs = os.path.abspath(TEACHER)
# REF is the same for every iteration, so check for reference images only once.
has_ref = os.path.isdir(REF) and bool(imgs(REF))
for m in [TEACHER] + MODELS:
    name = os.path.basename(m.rstrip('/'))
    is_teacher = os.path.abspath(m) == teacher_abs
    print(f"--- {name} ({len(imgs(m))} imgs) ---", flush=True)
    # The teacher is not compared against itself (no LPIPS/PSNR/FID-vs-teacher).
    pm = per_image(m, vs_teacher=not is_teacher)
    fid_real = round(cleanfid.compute_fid(m, REF, verbose=False), 3) if has_ref else None
    fid_teacher = None if is_teacher else round(cleanfid.compute_fid(m, TEACHER, verbose=False), 3)
    out[name] = {**pm, 'FID_vs_real': fid_real, 'FID_vs_teacher': fid_teacher}
    print(f" {out[name]}", flush=True)

# Write through a context manager so the file is flushed and closed before the
# summary is printed.
with open('outputs/eval/metrics.json', 'w') as metrics_fh:
    json.dump(out, metrics_fh, indent=2)

print("\n=== SUMMARY (outputs/eval/metrics.json) ===")
print(f"{'model':30s} {'PickScore↑':10s} {'CLIP↑':7s} {'LPIPS↓':7s} {'PSNR↑':7s} {'FIDreal↓':9s} {'FIDteach↓':9s}")
for name, v in out.items():
    # Values may be None, so everything is stringified before column padding.
    print(f"{name:30s} {str(v['PickScore']):10s} {str(v['CLIP']):7s} {str(v['LPIPS']):7s} {str(v['PSNR']):7s} {str(v['FID_vs_real']):9s} {str(v['FID_vs_teacher']):9s}")

Xet Storage Details

Size:
4.42 kB
·
Xet hash:
556203498802d8befe74d201bb159782fe1372c043f3285521b5af707330ec4a

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.