GenSeg-Baselines / code /scripts /p1 /gen_prdc.py
MaybeRichard's picture
code: complete eval pipeline (7 metrics + per-class + Wilcoxon) + Swin-UNet/TransUNet networks; remove backups/obsolete
1a18f22 verified
Raw
History Blame Contribute Delete
2.59 kB
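"""Generation-side Precision/Recall (Kynkäänniemi et al., 2019) and Density/Coverage
(Naeem et al., 2020) computed on InceptionV3 (pytorch-fid) features.
Precision = fidelity (fraction of fake samples inside the real manifold);
Recall = diversity (fraction of real samples covered by the fake manifold).
No sklearn dependency: k-NN distances are computed with torch.cdist."""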
import os, sys, json, random
import numpy as np, torch
from PIL import Image
sys.path.insert(0, "/home/wzhang/LSC/Code/NPJ")
from framework.synth.pixdiff.fd_loss import InceptionFeatures
# Root of the processed datasets; each dataset lives under DR/<dataset dir>/<split subdir>/.
DR = "/home/wzhang/LSC/Dataset/Segmentation/processed_unified"
# Short dataset key -> (dataset directory, split/protocol subdirectory).
DSETS = {"isic": ("medsegdb_isic2018", "holdout"), "kvasir": ("kvasir_seg", "official"), "busi": ("busi", "fold01")}
# Generator back-ends whose synthetic image sets are evaluated.
BKS = ["jit", "pixelgen", "deco", "pixeldit"]
# Prefer CUDA, but fall back to CPU instead of crashing on machines without a GPU.
dev = "cuda" if torch.cuda.is_available() else "cpu"
CAP = 2000  # maximum number of images sampled per directory
K = 5  # neighbourhood size for the k-NN radii used by all four metrics
# Shared feature extractor; .eval() puts layers such as dropout/batch-norm in inference behaviour.
inc = InceptionFeatures().to(dev).eval()
def feats(d, cap=CAP, batch_size=64):
    """Extract features with the module-level ``inc`` extractor for images in ``d``.

    Files ending in .png/.jpg/.jpeg (case-insensitive) are listed in sorted
    order. If there are more than ``cap`` of them, a reproducible subset of
    ``cap`` files is drawn with seed 0. Each image is converted to RGB, resized
    to 256x256, scaled to [0, 1] and fed to ``inc`` in batches of ``batch_size``.

    Args:
        d: directory containing the images.
        cap: maximum number of images to use.
        batch_size: number of images per forward pass.

    Returns:
        CPU tensor: the per-batch ``inc`` outputs concatenated along dim 0
        (presumably (num_images, feature_dim) -- TODO confirm against
        InceptionFeatures).

    Raises:
        RuntimeError: if ``d`` contains no supported image files.
    """
    fs = sorted(f for f in os.listdir(d) if f.lower().endswith((".png", ".jpg", ".jpeg")))
    if not fs:
        # Without this check, torch.cat([]) below fails with an unhelpful message.
        raise RuntimeError(f"no .png/.jpg/.jpeg images found in {d}")
    if len(fs) > cap:
        # A private RNG seeded with 0 draws the same subset as reseeding the
        # global RNG would, without clobbering global random state.
        fs = random.Random(0).sample(fs, cap)
    out = []
    with torch.no_grad():
        for i in range(0, len(fs), batch_size):
            b = []
            for f in fs[i:i + batch_size]:
                # Context manager closes the underlying file deterministically.
                with Image.open(os.path.join(d, f)) as im:
                    rgb = im.convert("RGB").resize((256, 256))
                # np.array copies into a writable buffer; torch.from_numpy warns on read-only arrays.
                b.append(torch.from_numpy(np.array(rgb)).permute(2, 0, 1).float() / 255.)
            out.append(inc(torch.stack(b).to(dev)).cpu())
    return torch.cat(out)
def knn_radius(X, k):
    """Return each row's distance to its k-th nearest *other* row of ``X``.

    Assumes ``X`` is a 2-D (num_samples, dim) feature matrix. Self-distances
    are excluded by filling the diagonal with +inf before taking the k-th
    smallest value per row, so ``k=1`` gives the nearest-neighbour distance.

    Raises:
        ValueError: if ``k`` is not in ``[1, len(X) - 1]``. With ``k == len(X)``
            the k-th value would be the +inf diagonal, which would silently
            make every radius infinite.
    """
    n = X.shape[0]
    if not 1 <= k < n:
        raise ValueError(f"k must be in [1, {n - 1}] for {n} samples, got {k}")
    d = torch.cdist(X, X)
    d.fill_diagonal_(float("inf"))
    return d.kthvalue(k, dim=1).values
def prdc(R, F, k=K):
    """Compute precision, recall, density and coverage of fake vs. real features.

    Both inputs are moved to the module-level device ``dev`` before computing.

    Args:
        R: real feature matrix, one sample per row.
        F: fake (generated) feature matrix, one sample per row.
        k: neighbourhood size for the k-NN radii.

    Returns:
        Tuple ``(precision, recall, density, coverage)`` of Python floats:
        precision -- fraction of fakes inside at least one real k-NN ball;
        recall    -- fraction of reals inside at least one fake k-NN ball;
        density   -- mean number of real k-NN balls containing each fake, / k;
        coverage  -- fraction of reals whose own k-NN ball contains a fake.
    """
    R, F = R.to(dev), F.to(dev)
    real_radii = knn_radius(R, k)
    fake_radii = knn_radius(F, k)
    dist_rf = torch.cdist(R, F)  # rows: real samples, columns: fake samples
    # in_real_ball[i, j] is True when fake j lies within real i's k-NN radius.
    # Precision, density and coverage all use it, so it is built only once.
    in_real_ball = dist_rf <= real_radii[:, None]
    # in_fake_ball[i, j] is True when real i lies within fake j's k-NN radius.
    in_fake_ball = dist_rf <= fake_radii[None, :]
    prec = in_real_ball.any(0).float().mean().item()
    rec = in_fake_ball.any(1).float().mean().item()
    dens = (in_real_ball.sum(0).float().mean() / k).item()
    cov = in_real_ball.any(1).float().mean().item()
    return prec, rec, dens, cov
# Real (train-split) features per dataset, computed once and reused for every back-end.
realf = {}
for dk, (ds, proto) in DSETS.items():
    realf[dk] = feats(f"{DR}/{ds}/{proto}/train/images")
    print(f"[real] {dk}: {realf[dk].shape}", flush=True)

# PRDC for every (back-end, dataset) pair whose synthetic image directory exists.
res = {}
for bk in BKS:
    for dk, (ds, proto) in DSETS.items():
        sd = f"{DR}/{ds}/{proto}/synth_fid_{bk}_{dk}/images"
        if not os.path.isdir(sd):
            print(f"[skip] {dk} {bk}", flush=True)
            continue
        F = feats(sd)
        p, r, de, c = prdc(realf[dk], F)
        res[f"{dk}_{bk}"] = {"precision": round(p, 3), "recall": round(r, 3), "density": round(de, 3), "coverage": round(c, 3)}
        print(f"[PRDC] {dk} {bk}: {res[f'{dk}_{bk}']}", flush=True)

out_path = "/home/wzhang/LSC/Code/NPJ/logs/fidviz/gen_prdc.json"
# Create the log directory if it is missing, and close the file deterministically.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, "w") as fh:
    json.dump(res, fh, indent=2)
print("PRDC_DONE", flush=True)