GenSeg-Baselines / code /scripts /p1 /fid_fixed.py
MaybeRichard's picture
code: complete eval pipeline (7 metrics + per-class + Wilcoxon) + Swin-UNet/TransUNet networks; remove backups/obsolete
1a18f22 verified
Raw
History Blame Contribute Delete
2.03 kB
"""Corrected FID: resize REAL train images to 256 (synth already 256) so pytorch_fid's
collate doesn't choke on variable native sizes (Kvasir/BUSI). FID(real256, synth) per pair."""
import os, re, json, subprocess
from PIL import Image
# Filesystem roots: results are written under ROOT (logs/fidviz); DR holds the
# processed datasets that both the real and synthetic image sets are read from.
ROOT = "/home/wzhang/LSC/Code/NPJ"
DR = "/home/wzhang/LSC/Dataset/Segmentation/processed_unified"

# Interpreter used to launch `python -m pytorch_fid` in a subprocess.
PY = "/opt/anaconda3/envs/seggen/bin/python"

# Short dataset key -> (dataset directory under DR, split/protocol subdirectory).
DSETS = dict(
    isic=("medsegdb_isic2018", "holdout"),
    kvasir=("kvasir_seg", "official"),
    busi=("busi", "fold01"),
)

# Backbone keys; each selects a synth_fid_<bk>_<dk>/images set to score.
BKS = "jit pixelgen deco pixeldit".split()
# Phase 1: cache 256x256 RGB PNG copies of each dataset's real train images in
# /tmp/real256_<dk>. pytorch_fid batches images with a default collate, which
# fails when images have differing native sizes, so the reference set is
# resized to match the (already 256x256) synthetic images.
# real256 maps dataset key -> directory of resized real images.
real256 = {}
for dk, (ds, proto) in DSETS.items():
    src = f"{DR}/{ds}/{proto}/train/images"
    dst = f"/tmp/real256_{dk}"
    os.makedirs(dst, exist_ok=True)
    seen = set()
    # Sorted so the choice among duplicate stems is deterministic across runs.
    for f in sorted(os.listdir(src)):
        if not f.lower().endswith((".png", ".jpg", ".jpeg")):
            continue
        stem = os.path.splitext(f)[0]
        if stem in seen:
            # Two sources (e.g. x.jpg and x.png) would map to the same x.png;
            # say so instead of silently dropping one from the reference set.
            print(f"[warn] real {dk}: duplicate stem, skipping {f}", flush=True)
            continue
        seen.add(stem)
        o = f"{dst}/{stem}.png"
        if os.path.exists(o):
            continue  # reuse the copy cached by an earlier run
        # Write to a temp name and rename into place, so an interrupted run can
        # never leave a truncated PNG that later runs would trust as cached.
        # The ".tmp" suffix is not an image extension, so stray temp files are
        # not picked up as images.
        tmp = o + ".tmp"
        with Image.open(f"{src}/{f}") as im:
            im.convert("RGB").resize((256, 256)).save(tmp, format="PNG")
        os.replace(tmp, o)
    real256[dk] = dst
    # Count only the PNGs that will actually be fed to FID.
    n_png = sum(name.endswith(".png") for name in os.listdir(dst))
    print(f"[resize] real {dk}: {n_png} imgs", flush=True)
# Phase 2: FID(real256[dk], synth) for every (backbone, dataset) pair, computed
# by running the pytorch_fid CLI under the PY interpreter. Results are stored as
# fid["<dk>_<bk>"] = FID rounded to 2 decimals and written to
# {ROOT}/logs/fidviz/fid_results.json.
fid = {}
# Loop-invariant: restrict the subprocess to CUDA device 0 in PCI bus order.
env = dict(os.environ, CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES="0")
# Parse the "FID:  <value>" line. Accept a sign and an exponent so output like
# "1.2e-05" is not truncated to "1.2" by a digits-and-dots-only pattern.
fid_re = re.compile(r"FID:\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)")
for bk in BKS:
    for dk, (ds, proto) in DSETS.items():
        synth = f"{DR}/{ds}/{proto}/synth_fid_{bk}_{dk}/images"
        # Skip synthetic sets that are absent or have fewer than 100 entries.
        if not os.path.isdir(synth) or len(os.listdir(synth)) < 100:
            print(f"[skip] {dk} {bk}: synth missing/small", flush=True)
            continue
        r = subprocess.run(
            [PY, "-m", "pytorch_fid", real256[dk], synth,
             "--device", "cuda", "--batch-size", "50"],
            capture_output=True, text=True, env=env,
        )
        m = fid_re.findall(r.stdout + r.stderr)
        # A non-zero exit means the run failed; don't trust any parsed number.
        if r.returncode == 0 and m:
            fid[f"{dk}_{bk}"] = round(float(m[-1]), 2)
            print(f"[FID] {dk} {bk} = {m[-1]}", flush=True)
        else:
            print(f"[FAIL] {dk} {bk} (exit {r.returncode}): "
                  f"{(r.stderr or r.stdout)[-300:]}", flush=True)
# Create the output directory so a missing logs/fidviz does not discard every
# result at the very end, and close the file deterministically.
out_dir = f"{ROOT}/logs/fidviz"
os.makedirs(out_dir, exist_ok=True)
with open(f"{out_dir}/fid_results.json", "w") as fh:
    json.dump(fid, fh, indent=2)
print("FID_FIXED_DONE", json.dumps(fid), flush=True)