code: complete eval pipeline (7 metrics + per-class + Wilcoxon) + Swin-UNet/TransUNet networks; remove backups/obsolete
"""FID (1-2k samples) per backbone x dataset, plus a same-mask aligned visualisation.

A) fid-sample: train_fraction=1.0, mask_aug, n_per_mask -> ~1.6-2.6k synthetic
   images per (backbone, dataset); FID is computed against the real train images.
B) align-sample: a 50-mask train fraction, NO augmentation, 1 sample per mask ->
   every backbone is conditioned on the identical real masks -> aligned grid.

Afterwards pytorch_fid scores are collected per pair and
[mask | real | 4 backbones] grids are rendered. Jobs run on the GPUS pool.
"""
import os
import time
import json
import re
import subprocess

import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend, selected before pyplot is imported
import matplotlib.pyplot as plt
from PIL import Image

# Fixed locations of the code tree, the processed datasets and the interpreter.
ROOT = "/home/wzhang/LSC/Code/NPJ"
DR = "/home/wzhang/LSC/Dataset/Segmentation/processed_unified"
PY = "/opt/anaconda3/envs/seggen/bin/python"
GPUS = [0, 1, 2, 3, 4, 5]

# Relative checkpoint paths in the job commands are resolved against ROOT.
os.chdir(ROOT)
LOGD = os.path.join(ROOT, "logs", "fidviz")
os.makedirs(LOGD, exist_ok=True)
def log(m):
    """Emit a timestamped status line.

    The line is appended to ``LOGD/status.md`` and echoed (flushed) to stdout.

    Args:
        m: message; formatted with ``str()`` via the f-string.
    """
    line = f"[{time.strftime('%F %T')}] {m}"
    # Close the file deterministically instead of relying on garbage collection.
    with open(os.path.join(LOGD, "status.md"), "a", encoding="utf-8") as fh:
        fh.write(line + "\n")
    print(line, flush=True)
# Dataset key -> (dataset dir, protocol, total, n_per_mask used for FID sampling).
# "total" is used as the number of training masks: it sizes the 50-mask fraction
# and the completion thresholds of the sampling jobs (presumably the train-split
# size — TODO confirm).
DSETS = dict(
    isic=("medsegdb_isic2018", "holdout", 2582, 1),
    kvasir=("kvasir_seg", "official", 800, 2),
    busi=("busi", "fold01", 545, 3),
)
# Backbone keys (order = grid row order) and their display labels.
BKS = ["jit", "pixelgen", "deco", "pixeldit"]
LAB = dict(zip(BKS, ("JiT", "PixelGen", "DeCo", "PixelDiT")))
# Job registry: job id -> job record. The scheduler scans it in insertion order,
# so registration order is also launch-priority order.
jobs = {}


def add(jid, cmd, deps=(), done_path=None, done_min=1):
    """Register a shell job under ``jid`` in the pending state.

    Args:
        jid: unique job id (overwrites an existing entry with the same id).
        cmd: shell command string to run.
        deps: ids of jobs that must reach "done" before this one may launch.
        done_path: path whose presence marks completion; for a directory it
            must hold at least ``done_min`` entries.
        done_min: minimum entry count when ``done_path`` is a directory.
    """
    record = dict(
        cmd=cmd,
        deps=list(deps),
        done_path=done_path,
        done_min=done_min,
        state="pending",
        tries=0,
        gpu=None,
    )
    jobs[jid] = record
# Job graph per (backbone, dataset):
#   fidsamp_* (full-train sampling) -> fid_* (pytorch_fid vs real train images),
#   alignsamp_* (50-mask, unaugmented sampling shared by all backbones).
for bk in BKS:
    for dk, (ds, proto, tot, npm) in DSETS.items():
        tag = f"{bk}_{dk}"
        split_dir = f"{DR}/{ds}/{proto}"
        sampler = (
            f"{PY} -m framework.synth.pixdiff.sample --ckpt pretrained/pixdiff/p1_{tag}.pt "
            f"--data_root {DR} --dataset {ds} --protocol {proto} "
        )

        # A) FID sampling on every training mask; done once ~80% of the expected images exist.
        fid_dir = f"{split_dir}/synth_fid_{tag}"
        add(f"fidsamp_{tag}",
            sampler + f"--train_fraction 1.0 --fraction_seed 0 --n_per_mask {npm} --mask_aug "
                      f"--num_steps 50 --out_dir {fid_dir}",
            done_path=os.path.join(fid_dir, "images"),
            done_min=int(0.8 * tot * npm))

        # FID score; the .ok marker is only touched if the log actually contains "FID".
        real_dir = f"{split_dir}/train/images"
        fid_log = os.path.join(LOGD, f"fid_{tag}.log")
        fid_ok = os.path.join(LOGD, f"fid_{tag}.ok")
        add(f"fid_{tag}",
            f"{PY} -m pytorch_fid {real_dir} {fid_dir}/images --device cuda --batch-size 64 "
            f"> {fid_log} 2>&1 && grep -q FID {fid_log} && touch {fid_ok}",
            deps=[f"fidsamp_{tag}"],
            done_path=fid_ok)

        # B) Aligned sampling: fraction chosen so that ~50 masks are used, 1 image each.
        frac50 = 50 / tot
        align_dir = f"{split_dir}/synth_align_{tag}"
        add(f"alignsamp_{tag}",
            sampler + f"--train_fraction {frac50} --fraction_seed 0 --n_per_mask 1 "
                      f"--num_steps 50 --out_dir {align_dir}",
            done_path=os.path.join(align_dir, "images"),
            done_min=40)
def is_done(j):
    """Return whether job ``j``'s completion marker is satisfied.

    - ``done_path`` unset or not existing -> False.
    - ``done_path`` is a directory -> it must contain >= ``done_min`` entries
      (an unreadable directory counts as not done).
    - any other existing path (e.g. a marker file) -> True.
    """
    target = j["done_path"]
    if not target or not os.path.exists(target):
        return False
    if not os.path.isdir(target):
        return True
    try:
        entries = os.listdir(target)
    except OSError:
        return False
    return len(entries) >= j["done_min"]
# Resume support: jobs whose completion marker is already satisfied are marked
# done up front, so the scheduler never launches them again.
for job in jobs.values():
    if is_done(job):
        job["state"] = "done"
def deps_done(j):
    """Return True when every dependency of job ``j`` is in state "done".

    A job with no dependencies is always ready.
    """
    for dep in j["deps"]:
        if jobs[dep]["state"] != "done":
            return False
    return True
# ---- GPU pool scheduler ----
# Each job runs as a shell command pinned to one GPU via CUDA_VISIBLE_DEVICES,
# one job per GPU at a time. Pending jobs launch in registration order once
# their dependencies are done; a job that exits without satisfying is_done is
# relaunched until it has been tried MAX_TRIES times, then marked failed.
MAX_TRIES = 2
running = {}  # gpu id -> (job id, Popen handle, open per-job log file)
free = set(GPUS)
last = 0  # time of the last SUMMARY line (0 -> summarise on the first pass)
log(f"START {len(jobs)} jobs on {GPUS}")
while True:
    # Cascade failures: a pending job with a failed dependency can never become
    # ready. Without this it would stay "pending" forever and the exit condition
    # below (every job done or failed) would never be met.
    for jid, j in jobs.items():
        if j["state"] == "pending" and any(jobs[d]["state"] == "failed" for d in j["deps"]):
            j["state"] = "failed"
            log(f"FAILED {jid} (dependency failed)")
    if all(j["state"] in ("done", "failed") for j in jobs.values()):
        break

    # Launch ready jobs onto free GPUs.
    for jid, j in jobs.items():
        if not free:
            break
        if j["state"] != "pending" or not deps_done(j):
            continue
        if is_done(j):  # outputs may have appeared since the last check
            j["state"] = "done"
            continue
        g = free.pop()
        env = dict(os.environ, CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES=str(g),
                   TORCHDYNAMO_DISABLE="1", PYTHONPATH=".", OMP_NUM_THREADS="4")
        lf = open(os.path.join(LOGD, jid + ".log"), "a")  # closed when the job is reaped
        p = subprocess.Popen(j["cmd"], shell=True, env=env, stdout=lf,
                             stderr=subprocess.STDOUT, cwd=ROOT)
        running[g] = (jid, p, lf)
        j["state"] = "running"
        j["gpu"] = g
        j["tries"] += 1
        log(f"LAUNCH {jid} GPU{g} try{j['tries']}")

    # Reap finished jobs; success is judged by the completion marker, not the exit code.
    for g, (jid, p, lf) in list(running.items()):
        rc = p.poll()
        if rc is None:
            continue
        lf.close()
        del running[g]
        free.add(g)
        j = jobs[jid]
        if is_done(j):
            j["state"] = "done"
            log(f"DONE {jid}")
        elif j["tries"] < MAX_TRIES:
            j["state"] = "pending"
            log(f"RETRY {jid} rc={rc}")
        else:
            j["state"] = "failed"
            log(f"FAILED {jid} rc={rc}")

    if time.time() - last > 180:
        cnt = {s: sum(1 for j in jobs.values() if j["state"] == s)
               for s in ("done", "running", "pending", "failed")}
        log(f"SUMMARY {cnt}")
        last = time.time()
    time.sleep(8)
# ---- parse FID ----
# Each fid_<bk>_<dk>.log is expected to contain a line like "FID: 12.34"; the
# last such value is kept. Results are keyed "<dataset>_<backbone>".
FID_RE = re.compile(r"FID:\s*(\d+(?:\.\d*)?(?:[eE][-+]?\d+)?)")
fid = {}
for bk in BKS:
    for dk in DSETS:
        lg = os.path.join(LOGD, f"fid_{bk}_{dk}.log")
        if not os.path.exists(lg):
            continue
        # Logs may contain progress-bar bytes; never let decoding abort the run.
        with open(lg, encoding="utf-8", errors="replace") as fh:
            m = FID_RE.findall(fh.read())
        if m:
            fid[f"{dk}_{bk}"] = float(m[-1])
with open(os.path.join(LOGD, "fid_results.json"), "w", encoding="utf-8") as fh:
    json.dump(fid, fh, indent=2)
log(f"FID: {fid}")
# ---- aligned grids ([mask | real | 4 backbones], same real mask per column) ----
def rgb(p):
    """Load the image at ``p``, convert it to RGB, resize to 256x256 and return it as a NumPy array."""
    with Image.open(p) as im:
        return np.asarray(im.convert("RGB").resize((256, 256)))
def gray(p):
    """Load the image at ``p``, convert it to single-channel "L", resize to 256x256 and return it as a NumPy array."""
    with Image.open(p) as im:
        return np.asarray(im.convert("L").resize((256, 256)))
def fmap(d):
    """Map each case id to the first generated PNG for it under ``d/images``.

    The case id is the filename stem up to the first ``"__"``. Files are
    visited in sorted order and only the first file per case id is kept.
    Returns an empty dict when ``d/images`` is not a directory.
    """
    img_dir = os.path.join(d, "images")
    if not os.path.isdir(img_dir):
        return {}
    first = {}
    for name in sorted(os.listdir(img_dir)):
        if not name.endswith(".png"):
            continue
        case_id = name[:-4].split("__")[0]
        if case_id not in first:
            first[case_id] = os.path.join(img_dir, name)
    return first
# For each dataset, render a grid: rows = [mask, real, one per backbone],
# columns = up to GRID_COLS masks that every backbone generated from.
GRID_COLS = 6
for dk, (ds, proto, _tot, _npm) in DSETS.items():
    base = f"{DR}/{ds}/{proto}"
    ri, rm = f"{base}/train/images", f"{base}/train/masks"
    if not os.path.isdir(ri):
        # Don't crash the whole run (after all GPU work) over one missing dataset.
        log(f"aligned grid skipped for {dk}: missing {ri}")
        continue
    maps = {bk: fmap(f"{base}/synth_align_{bk}_{dk}") for bk in BKS}
    # Cases that have a real image AND an aligned sample from every backbone.
    common = {os.path.splitext(f)[0] for f in os.listdir(ri) if f.endswith(".png")}
    for bk in BKS:
        common &= set(maps[bk].keys())
    common = sorted(common)
    ncol = min(GRID_COLS, len(common))
    if ncol == 0:
        log(f"aligned grid skipped for {dk}: no case shared by all backbones")
        continue
    # Evenly spaced picks over the sorted cases (first and last always included).
    idx = [round(i * (len(common) - 1) / (ncol - 1)) for i in range(ncol)] if ncol > 1 else [0]
    cases = [common[i] for i in idx]
    rows = [("Conditioning mask", "mask"), ("Real image", "real")] + [(LAB[bk], bk) for bk in BKS]
    # squeeze=False keeps `ax` 2-D even with a single column; otherwise ax[r][c]
    # would index into an Axes object and raise TypeError.
    fig, ax = plt.subplots(len(rows), ncol, figsize=(ncol * 1.9, len(rows) * 1.95), squeeze=False)
    for r, (labr, kind) in enumerate(rows):
        for c, bs in enumerate(cases):
            a = ax[r, c]
            try:
                mk = gray(f"{rm}/{bs}.png")
                if kind == "mask":
                    a.imshow(mk, cmap="gray")
                else:
                    src = f"{ri}/{bs}.png" if kind == "real" else maps[kind][bs]
                    a.imshow(rgb(src))
                    # Green outline of the conditioning mask over the image.
                    a.contour((mk > 127).astype(float), levels=[0.5], colors=["#19f04b"], linewidths=1.0)
            except Exception:
                # Deliberate best-effort: an unreadable/missing file becomes an "n/a" cell.
                a.imshow(np.ones((256, 256, 3)))
                a.text(0.5, 0.5, "n/a", ha="center", va="center", transform=a.transAxes, fontsize=8)
            a.set_xticks([])
            a.set_yticks([])
            for s in a.spines.values():
                s.set_visible(False)
            if c == 0:
                a.set_ylabel(labr, fontsize=10, rotation=90, va="center", labelpad=8,
                             color=("#111" if r < 2 else "#1a3b8b"))
    fig.suptitle(f"{dk.upper()} — same-mask aligned: every backbone generates the SAME real mask (row 1)\n"
                 f"Row2=real image; rows 3-6=each backbone's mask-conditioned synthesis (green=that mask). Proves mask guidance.", fontsize=10)
    fig.tight_layout(rect=[0.02, 0, 1, 0.94])
    out_png = f"/tmp/p1_aligned_{dk}.png"
    fig.savefig(out_png, dpi=145, bbox_inches="tight", facecolor="white")
    plt.close(fig)  # pyplot keeps every figure alive until closed
    log(f"aligned grid saved {out_png}")
# Final markers: one in the status log, one plain line on stdout.
log("ALL DONE")
print("FIDVIZ_DONE", flush=True)