| |
| """Self-contained visual-quality evaluation (FID / CLIP-text / CLIP-img / LPIPS). |
| |
| This is a standalone re-implementation of the visual-quality half of the |
| PixelGen ``eval_depth_metrics.py`` engine, with the DepthAnything dependency |
| removed so it runs anywhere with only ``torch / torchvision / transformers / |
| pytorch_fid / lpips`` installed. |
| |
| It compares a folder of generated RGB images against the paired ground-truth |
| RGB images + captions of the evaluation set, and reports: |
| |
| * FID (Inception-V3 2048-d features) -- lower is better |
| * CLIP-text cos(CLIP_img(gen), CLIP_text(caption)) -- higher is better |
| * CLIP-img cos(CLIP_img(gen), CLIP_img(real)) -- higher is better |
| * LPIPS (alex) perceptual distance gen vs real -- lower is better |
| |
| Matching convention |
| ------------------- |
| For each generated file the *stem* ``sa_XXXXXX`` is used to look up the GT |
| image ``<image_root>/<stem>.{jpg,jpeg,png,webp}`` (upper-case ``.JPG/.JPEG/.PNG`` are also accepted) and caption ``<image_root>/<stem>.txt``. |
| Generated files may be named either ``<stem>.png`` or, for single-control |
| outputs, ``<stem>_<suffix>.png`` (e.g. ``sa_000201_seg.png``) -- pass the |
| suffix via ``--control_suffix seg``. |
| |
| Examples |
| -------- |
| # Our three-control run, segmentation-only outputs (sa_xxxxxx_seg.png): |
| python eval/eval_visual_quality.py \ |
| --gen_dir /.../val/all_modes_eval2000/iter_10000 \ |
| --name ours_seg --control_suffix seg \ |
| --image_root t2i/data/blip/extracted_new/sa_000201 \ |
| --clip_model t2i/pretrained/clip-vit-large-patch14 \ |
| --metrics fid clip_img lpips \ |
| --output_json outputs/vq_ours_seg.json |
| |
| # Edge-only checkpoint outputs (sa_xxxxxx.png produced by infer): |
| python eval/eval_visual_quality.py \ |
| --gen_dir /.../val/edge_eval2000/iter_12000 --name edge_iter12000 \ |
| --metrics fid clip_img lpips --output_json outputs/vq_edge.json |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import re |
| import time |
| from typing import Dict, List, Optional |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from PIL import Image |
| from torch.utils.data import DataLoader, Dataset |
| from torchvision import transforms |
|
|
# Image file extensions recognised for ground-truth images. Order matters:
# lookups that iterate this tuple return the first extension that exists.
IMAGE_EXTS = (
    ".jpg",
    ".jpeg",
    ".png",
    ".JPG",
    ".JPEG",
    ".PNG",
    ".webp",
)
|
|
|
|
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the visual-quality evaluator.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) makes argparse fall back to ``sys.argv[1:]``, which
            is what the script does when run from the shell. Passing a list
            lets other code drive the parser programmatically.

    Returns:
        The parsed namespace. ``--gen_dir`` and ``--name`` are repeatable and
        are stored as the lists ``gen_dirs`` and ``names``.
    """
    p = argparse.ArgumentParser(description="Visual quality: FID / CLIP / LPIPS.")
    p.add_argument("--gen_dir", action="append", default=[], dest="gen_dirs",
                   help="Generated image folder. Repeatable to compare many runs.")
    p.add_argument("--name", action="append", default=[], dest="names",
                   help="Display name per --gen_dir (must match count if given).")
    p.add_argument("--control_suffix", default="",
                   help="If set, only match files named sa_xxxxxx_<suffix>.png.")
    p.add_argument("--image_root",
                   default="t2i/data/blip/extracted_new/sa_000201",
                   help="GT eval folder containing <stem>.txt + <stem>.<img ext>.")
    p.add_argument("--clip_model",
                   default="t2i/pretrained/clip-vit-large-patch14",
                   help="CLIP checkpoint name or local path passed to from_pretrained.")
    p.add_argument("--metrics", nargs="+",
                   choices=["fid", "clip_text", "clip_img", "lpips"],
                   default=["fid", "clip_text", "clip_img", "lpips"],
                   help="Metrics to compute (default: all four).")
    p.add_argument("--resolution", type=int, default=512,
                   help="Resize + center-crop size used for the FID and LPIPS inputs.")
    p.add_argument("--batch_size", type=int, default=16,
                   help="Batch size for every model forward pass.")
    p.add_argument("--max_samples", type=int, default=-1,
                   help="Evaluate at most this many matched samples (<= 0 means all).")
    p.add_argument("--device", default="cuda:0",
                   help="Torch device string the models run on.")
    p.add_argument("--num_workers", type=int, default=4,
                   help="DataLoader worker processes for image loading.")
    p.add_argument("--output_json", default="outputs/visual_quality.json",
                   help="Where the JSON report is written.")
    return p.parse_args(argv)
|
|
|
|
def find_with_exts(root: str, stem: str) -> Optional[str]:
    """Return the first existing ``<root>/<stem><ext>`` path, or ``None``.

    Extensions are tried in the order they appear in ``IMAGE_EXTS``.
    """
    candidates = (os.path.join(root, stem + ext) for ext in IMAGE_EXTS)
    # The generator is lazy, so probing stops at the first path that exists.
    return next(filter(os.path.exists, candidates), None)
|
|
|
|
def gen_pattern(control_suffix: str) -> re.Pattern:
    """Build the filename regex for generated images.

    Without a suffix the pattern accepts ``sa_<digits>.png``; with one it
    accepts ``sa_<digits>_<suffix>.png`` (the suffix is matched literally).
    The ``sa_<digits>`` part is exposed as the named group ``stem``.
    """
    # Optional "_<suffix>" tail; escaped so regex metacharacters stay literal.
    tail = f"_{re.escape(control_suffix)}" if control_suffix else ""
    return re.compile(rf"^(?P<stem>sa_\d+){tail}\.png$")
|
|
|
|
def index_gen_dir(gen_dir: str, control_suffix: str) -> Dict[str, str]:
    """Map each generated image's stem to its full path inside *gen_dir*.

    Only directory entries accepted by ``gen_pattern(control_suffix)`` are
    indexed; entries are visited in sorted filename order.
    """
    pat = gen_pattern(control_suffix)
    hits = filter(None, map(pat.match, sorted(os.listdir(gen_dir))))
    # m.string is the original filename that was matched.
    return {m.group("stem"): os.path.join(gen_dir, m.string) for m in hits}
|
|
|
|
class _PathDataset(Dataset):
    """Dataset that loads image files by path and returns them as tensors.

    Every image is converted to RGB, passed through torchvision's
    ``Resize(resolution)`` and ``CenterCrop(resolution)``, and turned into a
    tensor with ``ToTensor``.
    """

    def __init__(self, paths: List[str], resolution: int):
        self.paths = paths
        steps = [
            transforms.Resize(resolution),
            transforms.CenterCrop(resolution),
            transforms.ToTensor(),
        ]
        self.tx = transforms.Compose(steps)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # The context manager closes the file once the tensor has been built.
        with Image.open(self.paths[idx]) as handle:
            rgb = handle.convert("RGB")
            return self.tx(rgb)
|
|
|
|
def _build_inception(device):
    """Create pytorch_fid's InceptionV3 (2048-d output block) on *device* in eval mode."""
    from pytorch_fid.inception import InceptionV3

    pool_block = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
    net = InceptionV3([pool_block])
    net = net.to(device)
    return net.eval()
|
|
|
|
@torch.no_grad()
def _inception_acts(model, paths, resolution, batch_size, device, n_workers=4) -> np.ndarray:
    """Run *model* over the images at *paths* and stack the activations.

    Images are loaded through ``_PathDataset`` in their given order (no
    shuffling). The first output of the model is taken for each batch, its two
    trailing dimensions are squeezed, and all batches are concatenated along
    axis 0 into one NumPy array.
    """
    loader = DataLoader(
        _PathDataset(paths, resolution),
        batch_size=batch_size,
        shuffle=False,
        num_workers=n_workers,
        pin_memory=True,
    )
    chunks = [
        model(batch.to(device, non_blocking=True))[0]
        .squeeze(-1)
        .squeeze(-1)
        .cpu()
        .numpy()
        for batch in loader
    ]
    return np.concatenate(chunks, axis=0)
|
|
|
|
def _fid(acts_a, acts_b) -> float:
    """Return the Frechet distance between two sets of activations.

    Args:
        acts_a: Array of shape ``(n_a, dim)``, one activation vector per row.
        acts_b: Array of shape ``(n_b, dim)``.

    Returns:
        The distance computed by pytorch_fid's ``calculate_frechet_distance``
        from the per-set mean vectors and covariance matrices.
    """
    from pytorch_fid.fid_score import calculate_frechet_distance

    def _stats(acts):
        # Upcast first: the activations arrive as float32, and taking the mean
        # in float32 while np.cov works in float64 mixes precisions. Doing
        # both in float64 matches the statistics of the reference pipeline.
        acts = np.asarray(acts, dtype=np.float64)
        return acts.mean(axis=0), np.cov(acts, rowvar=False)

    mu_a, sig_a = _stats(acts_a)
    mu_b, sig_b = _stats(acts_b)
    return float(calculate_frechet_distance(mu_a, sig_a, mu_b, sig_b))
|
|
|
|
def _build_clip(model_name, device):
    """Load a CLIP model (moved to *device*, eval mode) and its processor.

    Returns:
        ``(model, processor)``, both loaded with ``from_pretrained(model_name)``.
    """
    from transformers import CLIPModel, CLIPProcessor

    model = CLIPModel.from_pretrained(model_name).to(device).eval()
    processor = CLIPProcessor.from_pretrained(model_name)
    return model, processor
|
|
|
|
def _as_tensor(out) -> torch.Tensor:
    """Coerce a CLIP feature-call result into a plain tensor.

    A tensor is returned unchanged. Otherwise the first attribute among
    ``image_embeds``, ``text_embeds``, ``pooler_output`` and
    ``last_hidden_state`` that holds a tensor is returned.

    Raises:
        TypeError: if *out* is not a tensor and none of those attributes is one.
    """
    if isinstance(out, torch.Tensor):
        return out
    fields = ("image_embeds", "text_embeds", "pooler_output", "last_hidden_state")
    values = (getattr(out, field, None) for field in fields)
    found = next((v for v in values if isinstance(v, torch.Tensor)), None)
    if found is None:
        raise TypeError(f"cannot coerce CLIP output {type(out).__name__}")
    return found
|
|
|
|
@torch.no_grad()
def _clip_img(model, proc, paths, batch_size, device) -> torch.Tensor:
    """Embed the images at *paths* with CLIP and L2-normalise the result.

    Args:
        model: Object exposing ``get_image_features(**inputs)``.
        proc: Processor called as ``proc(images=..., return_tensors="pt")``.
        paths: Image file paths, embedded in this order.
        batch_size: Number of images per forward pass.
        device: Device the processed inputs are moved to.

    Returns:
        CPU tensor with one unit-norm embedding per path, in input order.
    """
    embeds = []
    for i in range(0, len(paths), batch_size):
        imgs = []
        for p in paths[i:i + batch_size]:
            # Image.open keeps the file open until the image is closed.
            # convert() hands back a loaded copy, so the handle is released
            # here instead of being left to the garbage collector.
            with Image.open(p) as im:
                imgs.append(im.convert("RGB"))
        inp = proc(images=imgs, return_tensors="pt").to(device)
        e = F.normalize(_as_tensor(model.get_image_features(**inp)), dim=-1)
        embeds.append(e.cpu())
    return torch.cat(embeds, 0)
|
|
|
|
@torch.no_grad()
def _clip_text(model, proc, caps, batch_size, device) -> torch.Tensor:
    """Embed the captions in *caps* with CLIP and L2-normalise the result.

    Captions are tokenised in batches with padding and truncation to 77
    tokens. The returned CPU tensor has one unit-norm row per caption, in
    input order.
    """
    chunks = []
    for start in range(0, len(caps), batch_size):
        batch = caps[start:start + batch_size]
        tokens = proc(text=batch, return_tensors="pt",
                      padding=True, truncation=True, max_length=77)
        tokens = tokens.to(device)
        feats = _as_tensor(model.get_text_features(**tokens))
        chunks.append(F.normalize(feats, dim=-1).cpu())
    return torch.cat(chunks, 0)
|
|
|
|
def _match_samples(gen_index: Dict[str, str], image_root: str,
                   need_caption: bool, max_samples: int) -> Dict[str, str]:
    """Map every usable stem to its ground-truth image path, in sorted stem order.

    A stem is usable when a GT image exists for it and, if *need_caption* is
    set, its ``.txt`` caption exists too. At most *max_samples* stems are
    returned when that value is positive.
    """
    matched: Dict[str, str] = {}
    for stem in sorted(gen_index):
        if 0 < max_samples <= len(matched):
            break
        real_path = find_with_exts(image_root, stem)
        if real_path is None:
            continue
        if need_caption and not os.path.exists(os.path.join(image_root, stem + ".txt")):
            continue
        matched[stem] = real_path
    return matched


def _read_captions(image_root: str, stems: List[str]) -> List[str]:
    """Read the stripped text of ``<image_root>/<stem>.txt`` for every stem."""
    caps = []
    for stem in stems:
        with open(os.path.join(image_root, stem + ".txt"), encoding="utf-8", errors="ignore") as f:
            caps.append(f.read().strip())
    return caps


def _unique_run_name(name: str, taken) -> str:
    """Return *name*, or ``name#k`` with the smallest k >= 2 not already in *taken*."""
    if name not in taken:
        return name
    k = 2
    while f"{name}#{k}" in taken:
        k += 1
    return f"{name}#{k}"


def _lpips_mean(lpips_model, gen_paths, real_paths, args, device) -> float:
    """Average LPIPS distance over index-aligned (generated, real) image pairs."""
    def _loader(paths):
        return DataLoader(_PathDataset(paths, args.resolution), batch_size=args.batch_size,
                          num_workers=args.num_workers, pin_memory=True, shuffle=False)

    vals = []
    with torch.no_grad():
        # Both loaders are unshuffled and use the same batch size, so zip()
        # yields matching batches of generated and real images.
        for xg, xr in zip(_loader(gen_paths), _loader(real_paths)):
            # Rescale the loaded tensors as x * 2 - 1 before the LPIPS call.
            xg = (xg * 2 - 1).to(device)
            xr = (xr * 2 - 1).to(device)
            vals.append(lpips_model(xg, xr).view(-1).cpu())
    return float(torch.cat(vals).mean())


def _print_summary(results, cols, w: int = 10) -> None:
    """Print one table row per run with the metric columns in *cols*."""
    header = f"{'run':>22s} | " + " | ".join(f"{c:>{w}s}" for c in cols)
    print("\n" + header + "\n" + "-" * len(header))
    for tag, rec in results.items():
        print(f"{tag:>22s} | " + " | ".join(f"{rec.get(c, float('nan')):>{w}.4f}" for c in cols))
    print("\nDirections: FID lower better | CLIP-text/img higher better | LPIPS lower better")


def main() -> None:
    """Evaluate every ``--gen_dir`` and write a JSON report plus a text table.

    For each generated folder the stems are matched against the ground-truth
    folder, the requested metrics are computed, and the per-run record is
    stored under the run's display name. Runs with no matched stems are
    skipped with a message.

    Raises:
        ValueError: if no ``--gen_dir`` is given, or the number of ``--name``
            values differs from the number of ``--gen_dir`` values.
    """
    args = parse_args()
    if args.names and len(args.names) != len(args.gen_dirs):
        raise ValueError("--name count must match --gen_dir count")
    if not args.gen_dirs:
        raise ValueError("at least one --gen_dir is required")
    device = torch.device(args.device)
    os.makedirs(os.path.dirname(args.output_json) or ".", exist_ok=True)

    # Deduplicate the requested metrics in one fixed order, so the JSON
    # metadata and the table columns are identical from run to run (iterating
    # a set of strings is not order-stable across interpreter invocations).
    metric_order = ("fid", "clip_text", "clip_img", "lpips")
    metrics = [m for m in metric_order if m in args.metrics]
    want_clip_text = "clip_text" in metrics
    want_clip_img = "clip_img" in metrics
    need_clip = want_clip_text or want_clip_img
    need_incep = "fid" in metrics
    need_lpips = "lpips" in metrics

    clip_model = clip_proc = inception = lpips_model = None
    if need_clip:
        print(f"[vq] loading CLIP {args.clip_model}")
        clip_model, clip_proc = _build_clip(args.clip_model, device)
    if need_incep:
        print("[vq] loading InceptionV3")
        inception = _build_inception(device)
    if need_lpips:
        print("[vq] loading LPIPS(alex)")
        import lpips
        lpips_model = lpips.LPIPS(net="alex").to(device).eval()

    results: Dict[str, dict] = {}
    for i, gen_dir in enumerate(args.gen_dirs):
        # normpath drops trailing separators on every platform before the
        # last path component is used as the default display name.
        name = args.names[i] if args.names else os.path.basename(os.path.normpath(gen_dir))
        gen_index = index_gen_dir(gen_dir, args.control_suffix)

        # One GT lookup per stem; the resolved path is reused below.
        matched = _match_samples(gen_index, args.image_root, want_clip_text, args.max_samples)
        if not matched:
            print(f"[vq][skip] {name}: 0 matched stems in {gen_dir}")
            continue

        # Two folders can share a basename (e.g. .../a/iter_10000 and
        # .../b/iter_10000); keep both records instead of overwriting one.
        unique = _unique_run_name(name, results)
        if unique != name:
            print(f"[vq][warn] run name {name!r} already used; reporting as {unique!r}")
            name = unique

        stems = list(matched)
        gen_paths = [gen_index[s] for s in stems]
        real_paths = list(matched.values())
        caps = _read_captions(args.image_root, stems) if want_clip_text else []

        rec = {"n_samples": len(stems), "gen_dir": gen_dir}
        print(f"\n===== {name} ({len(stems)} samples) =====")

        if need_clip:
            gen_e = _clip_img(clip_model, clip_proc, gen_paths, args.batch_size, device)
            if want_clip_img:
                real_e = _clip_img(clip_model, clip_proc, real_paths, args.batch_size, device)
                # Embeddings are unit-norm, so the row-wise dot product is the cosine.
                rec["clip_img"] = float((gen_e * real_e).sum(-1).mean())
            if want_clip_text:
                text_e = _clip_text(clip_model, clip_proc, caps, args.batch_size, device)
                rec["clip_text"] = float((gen_e * text_e).sum(-1).mean())

        if need_incep:
            real_acts = _inception_acts(inception, real_paths, args.resolution,
                                        args.batch_size, device, args.num_workers)
            gen_acts = _inception_acts(inception, gen_paths, args.resolution,
                                       args.batch_size, device, args.num_workers)
            rec["fid"] = _fid(real_acts, gen_acts)

        if need_lpips:
            rec["lpips"] = _lpips_mean(lpips_model, gen_paths, real_paths, args, device)

        results[name] = rec
        print(" " + " ".join(f"{k}={v:.4f}" for k, v in rec.items() if isinstance(v, float)))

    blob = {"metadata": {"image_root": args.image_root, "resolution": args.resolution,
                         "clip_model": args.clip_model if need_clip else None,
                         "control_suffix": args.control_suffix,
                         "metrics": metrics},
            "results": results}
    with open(args.output_json, "w", encoding="utf-8") as f:
        json.dump(blob, f, indent=2)
    print(f"\n[vq] wrote {args.output_json}")

    _print_summary(results, metrics)


if __name__ == "__main__":
    main()
|
|