code: complete eval pipeline (7 metrics + per-class + Wilcoxon) + Swin-UNet/TransUNet networks; remove backups/obsolete
"""Generation-side Precision/Recall (Kynkäänniemi et al.) + Density/Coverage (Naeem et al.) on
InceptionV3 (pytorch-fid) features. Precision = fidelity (fake samples inside the real manifold);
Recall = diversity (real samples covered by the fake manifold). No sklearn: kNN via torch.cdist."""
| import os, sys, json, random | |
| import numpy as np, torch | |
| from PIL import Image | |
| sys.path.insert(0, "/home/wzhang/LSC/Code/NPJ") | |
| from framework.synth.pixdiff.fd_loss import InceptionFeatures | |
# Root of the preprocessed segmentation data; each dataset is read from
# {DR}/{dataset_dir}/{protocol}/... (see DSETS and the paths built below).
DR = "/home/wzhang/LSC/Dataset/Segmentation/processed_unified"
# Short dataset key -> (dataset directory under DR, protocol/split sub-directory).
DSETS = {
    "isic": ("medsegdb_isic2018", "holdout"),
    "kvasir": ("kvasir_seg", "official"),
    "busi": ("busi", "fold01"),
}
# Synthetic-image generators to evaluate (used in the synth_fid_{bk}_{dk} folder name).
BKS = ["jit", "pixelgen", "deco", "pixeldit"]
# Fall back to CPU when no GPU is available instead of failing at model load.
dev = "cuda" if torch.cuda.is_available() else "cpu"
CAP = 2000  # max images read per folder; larger folders are randomly subsampled
K = 5  # neighbour rank used for the k-NN ball radii in PRDC
# Feature extractor (InceptionV3 per the module docstring), in eval mode.
# NOTE(review): its output is assumed to be (batch, feature_dim) -- confirm.
inc = InceptionFeatures().to(dev).eval()
def feats(d, cap=CAP, bs=64):
    """Extract feature-extractor outputs for the images in directory ``d``.

    Image files (.png/.jpg/.jpeg, case-insensitive) are listed in sorted
    order. When there are more than ``cap`` of them, a reproducible random
    subset of ``cap`` files is drawn with seed 0. Each image is converted
    to RGB, resized to 256x256, scaled to [0, 1] as a (3, 256, 256) float
    tensor, and fed to ``inc`` in batches of ``bs``.

    Args:
        d: directory containing the images.
        cap: maximum number of images to use.
        bs: batch size for the feature extractor.

    Returns:
        CPU tensor of the concatenated ``inc`` outputs, one row per image
        (NOTE(review): assumes ``inc`` returns (batch, dim) -- confirm).

    Raises:
        ValueError: if ``d`` contains no matching image files.
    """
    fs = sorted(f for f in os.listdir(d) if f.lower().endswith((".png", ".jpg", ".jpeg")))
    if not fs:
        # torch.cat([]) below would otherwise fail with an unhelpful message.
        raise ValueError(f"no .png/.jpg/.jpeg images found in {d}")
    if len(fs) > cap:
        # A private RNG seeded with 0 draws the same subset as
        # `random.seed(0); random.sample(...)`, without resetting the
        # global `random` state as a side effect.
        fs = random.Random(0).sample(fs, cap)
    out = []
    with torch.no_grad():
        for start in range(0, len(fs), bs):
            batch = []
            for f in fs[start:start + bs]:
                # Context manager releases the file handle promptly.
                with Image.open(os.path.join(d, f)) as im:
                    rgb = im.convert("RGB").resize((256, 256))
                # np.array copies into a writable buffer, so torch.from_numpy
                # does not warn about a read-only array.
                batch.append(torch.from_numpy(np.array(rgb)).permute(2, 0, 1).float() / 255.)
            out.append(inc(torch.stack(batch).to(dev)).cpu())
    return torch.cat(out)
def knn_radius(X, k):
    """Return, for each row of ``X``, the distance to its k-th nearest other row.

    Self-distances are masked with +inf so a point is never its own neighbour.

    Args:
        X: 2-D tensor of points, one point per row.
        k: neighbour rank (1 = nearest other point).

    Returns:
        1-D tensor of length ``X.shape[0]`` with the k-NN radius of each row.

    Raises:
        ValueError: if ``k`` is not in ``[1, X.shape[0] - 1]``. With too few
            points the radius would be +inf (or kthvalue would fail), which
            would silently place every query inside every ball.
    """
    n = X.shape[0]
    if not 1 <= k < n:
        raise ValueError(f"k={k} requires 1 <= k < number of points ({n})")
    d = torch.cdist(X, X)
    d.fill_diagonal_(float("inf"))
    return d.kthvalue(k, dim=1).values
def prdc(R, F, k=K):
    """Compute (precision, recall, density, coverage) of fake vs. real features.

    Precision: fraction of fake rows inside at least one real k-NN ball.
    Recall: fraction of real rows inside at least one fake k-NN ball.
    Density: mean number of real balls containing each fake row, divided by ``k``.
    Coverage: fraction of real rows whose own k-NN ball contains a fake row.
    Ball membership is tested with ``<=`` against the radius.

    Args:
        R: real features, one sample per row.
        F: fake (generated) features, one sample per row.
        k: neighbour rank used for the ball radii.

    Returns:
        Tuple of Python floats ``(precision, recall, density, coverage)``.
    """
    real = R.to(dev)
    fake = F.to(dev)
    real_radius = knn_radius(real, k)
    fake_radius = knn_radius(fake, k)
    dist = torch.cdist(real, fake)  # dist[i, j] = distance(real_i, fake_j)
    # in_real_ball[i, j]: fake_j lies within real_i's k-NN radius; shared by
    # precision, density and coverage, so it is built once.
    in_real_ball = dist <= real_radius.unsqueeze(1)
    # in_fake_ball[i, j]: real_i lies within fake_j's k-NN radius.
    in_fake_ball = dist <= fake_radius.unsqueeze(0)
    precision = in_real_ball.any(dim=0).float().mean().item()
    recall = in_fake_ball.any(dim=1).float().mean().item()
    density = (in_real_ball.sum(dim=0).float().mean() / k).item()
    coverage = in_real_ball.any(dim=1).float().mean().item()
    return precision, recall, density, coverage
# Destination for the per-(dataset, generator) PRDC scores.
out_path = "/home/wzhang/LSC/Code/NPJ/logs/fidviz/gen_prdc.json"
# Create the log directory up front so a missing folder cannot crash the run
# only after all features have been extracted.
os.makedirs(os.path.dirname(out_path), exist_ok=True)

# Real reference features: the train images of each dataset.
realf = {}
for dk, (ds, proto) in DSETS.items():
    realf[dk] = feats(f"{DR}/{ds}/{proto}/train/images")
    print(f"[real] {dk}: {realf[dk].shape}", flush=True)

# Score each generator's synthetic folder against the matching real set;
# folders that do not exist are reported and skipped.
res = {}
for bk in BKS:
    for dk, (ds, proto) in DSETS.items():
        sd = f"{DR}/{ds}/{proto}/synth_fid_{bk}_{dk}/images"
        if not os.path.isdir(sd):
            print(f"[skip] {dk} {bk}", flush=True)
            continue
        F = feats(sd)
        p, r, de, c = prdc(realf[dk], F)
        key = f"{dk}_{bk}"
        res[key] = {"precision": round(p, 3), "recall": round(r, 3), "density": round(de, 3), "coverage": round(c, 3)}
        print(f"[PRDC] {dk} {bk}: {res[key]}", flush=True)

# The context manager guarantees the JSON is flushed and the handle closed.
with open(out_path, "w") as fh:
    json.dump(res, fh, indent=2)
print("PRDC_DONE", flush=True)