"""Corrected FID: resize REAL train images to 256 (synth already 256) so pytorch_fid's
collate doesn't choke on variable native sizes (Kvasir/BUSI). FID(real256, synth) per pair."""
import json
import os
import re
import subprocess

ROOT = "/home/wzhang/LSC/Code/NPJ"
DR = "/home/wzhang/LSC/Dataset/Segmentation/processed_unified"
PY = "/opt/anaconda3/envs/seggen/bin/python"
# dataset key -> (dataset directory under DR, split/protocol directory)
DSETS = {
    "isic": ("medsegdb_isic2018", "holdout"),
    "kvasir": ("kvasir_seg", "official"),
    "busi": ("busi", "fold01"),
}
# generator backbones whose synthetic sets are evaluated
BKS = ["jit", "pixelgen", "deco", "pixeldit"]

IMG_EXTS = (".png", ".jpg", ".jpeg")
# Accepts signed values and scientific notation: a near-identical pair can print
# e.g. "1.2e-05" (or a tiny negative value from numerical error), which the old
# pattern [0-9.]+ silently truncated to "1.2".
FID_RE = re.compile(r"FID:\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)")


def parse_fid(output):
    """Return the last ``FID: <value>`` number found in *output* as a float.

    Returns None when no FID line is present (e.g. the subprocess failed).
    """
    matches = FID_RE.findall(output)
    return float(matches[-1]) if matches else None


def resize_real_images(src, dst, size=256):
    """Write a ``size`` x ``size`` RGB PNG copy of every image in *src* into *dst*.

    Outputs that already exist are reused, so reruns are cheap. Each image is first
    written to a ``.part`` file and then atomically renamed. An interrupted run
    therefore never leaves a truncated PNG that later runs would skip and feed to
    pytorch_fid.

    Returns the number of ``.png`` files in *dst* after the pass. This may include
    files from earlier runs if *src* has since changed.
    """
    # Third-party; imported here so the pure helpers above import without Pillow.
    from PIL import Image

    os.makedirs(dst, exist_ok=True)
    for name in sorted(os.listdir(src)):
        if not name.lower().endswith(IMG_EXTS):
            continue
        out = os.path.join(dst, os.path.splitext(name)[0] + ".png")
        if os.path.exists(out):
            continue
        tmp = out + ".part"
        with Image.open(os.path.join(src, name)) as img:
            # Explicit format: the .part suffix would otherwise defeat PIL's
            # extension-based format detection.
            img.convert("RGB").resize((size, size)).save(tmp, format="PNG")
        os.replace(tmp, out)
    return sum(1 for f in os.listdir(dst) if f.endswith(".png"))


def compute_fid(real_dir, synth_dir, device="cuda", batch_size=50, gpu="0"):
    """Run ``python -m pytorch_fid real_dir synth_dir`` in the seggen env on one GPU.

    Returns ``(fid_or_None, completed_process)``. The process is returned so the
    caller can report stderr/stdout on failure.
    """
    env = dict(os.environ, CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES=gpu)
    proc = subprocess.run(
        [PY, "-m", "pytorch_fid", real_dir, synth_dir,
         "--device", device, "--batch-size", str(batch_size)],
        capture_output=True,
        text=True,
        env=env,
    )
    return parse_fid(proc.stdout + proc.stderr), proc


def main():
    """Resize the real sets, score every (dataset, backbone) pair, and save the JSON."""
    real256 = {}
    for dk, (ds, proto) in DSETS.items():
        dst = f"/tmp/real256_{dk}"
        count = resize_real_images(f"{DR}/{ds}/{proto}/train/images", dst)
        real256[dk] = dst
        print(f"[resize] real {dk}: {count} imgs", flush=True)

    fid = {}
    for bk in BKS:
        for dk, (ds, proto) in DSETS.items():
            synth = f"{DR}/{ds}/{proto}/synth_fid_{bk}_{dk}/images"
            # Fewer than 100 files means generation is missing or incomplete.
            if not os.path.isdir(synth) or len(os.listdir(synth)) < 100:
                print(f"[skip] {dk} {bk}: synth missing/small", flush=True)
                continue
            value, proc = compute_fid(real256[dk], synth)
            if value is None:
                print(f"[FAIL] {dk} {bk}: {(proc.stderr or proc.stdout)[-300:]}", flush=True)
                continue
            fid[f"{dk}_{bk}"] = round(value, 2)
            print(f"[FID] {dk} {bk} = {value}", flush=True)

    # Create the log dir up front: otherwise a missing directory would crash here,
    # after all the (expensive) FID runs have already finished.
    out_dir = f"{ROOT}/logs/fidviz"
    os.makedirs(out_dir, exist_ok=True)
    with open(f"{out_dir}/fid_results.json", "w") as fh:
        json.dump(fid, fh, indent=2)
    print("FID_FIXED_DONE", json.dumps(fid), flush=True)


if __name__ == "__main__":
    main()