code: complete eval pipeline (7 metrics + per-class + Wilcoxon) + Swin-UNet/TransUNet networks; remove backups/obsolete
"""Same-mask JiT (native) vs JiT-FD comparison: does the FD-perceptual variant sharpen the synthetic images?

Samples JiT-FD on the SAME no-aug f50 masks that the native JiT align-set used, then builds
[mask | real | JiT native | JiT-FD] grids for ISIC and Kvasir.
"""
| import os, subprocess | |
| import numpy as np | |
| import matplotlib; matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
# Working directory for the sampler subprocess.
ROOT = "/home/wzhang/LSC/Code/NPJ"

# Root folder that holds one sub-folder per dataset / protocol.
DR = "/home/wzhang/LSC/Dataset/Segmentation/processed_unified"

# Python interpreter used to launch the sampler module.
PY = "/opt/anaconda3/envs/seggen/bin/python"

# Short dataset key -> (dataset folder, protocol folder, total count).
# The total is only used to derive the sampling fraction 50 / total
# (presumably the training-set size -- TODO confirm).
DSETS = dict(
    isic=("medsegdb_isic2018", "holdout", 2582),
    kvasir=("kvasir_seg", "official", 800),
)
def sample(ckpt, ds, proto, frac, out, *, min_images=40, num_steps=50,
           n_per_mask=1, fraction_seed=0, gpu="0"):
    """Run the pixdiff sampler as a subprocess unless its output already exists.

    Args:
        ckpt: Checkpoint path passed to ``--ckpt`` (resolved relative to ``ROOT``,
            which is the subprocess working directory).
        ds: Dataset name passed to ``--dataset``.
        proto: Protocol name passed to ``--protocol``.
        frac: Training fraction passed to ``--train_fraction``.
        out: Output directory passed to ``--out_dir``.
        min_images: Sampling is skipped when ``<out>/images`` already holds at
            least this many entries.
        num_steps: Value passed to ``--num_steps``.
        n_per_mask: Value passed to ``--n_per_mask``.
        fraction_seed: Value passed to ``--fraction_seed``.
        gpu: Value exported as ``CUDA_VISIBLE_DEVICES`` for the subprocess.

    Raises:
        subprocess.CalledProcessError: If the sampler exits with a non-zero code.
    """
    img_dir = os.path.join(out, "images")
    # Cheap resume check: a sufficiently populated output folder counts as done.
    if os.path.isdir(img_dir) and len(os.listdir(img_dir)) >= min_images:
        print(f"[skip-sample] {out} exists", flush=True)
        return

    env = dict(
        os.environ,
        CUDA_DEVICE_ORDER="PCI_BUS_ID",
        CUDA_VISIBLE_DEVICES=str(gpu),
        TORCHDYNAMO_DISABLE="1",
        PYTHONPATH=".",
        OMP_NUM_THREADS="4",
    )
    cmd = [
        PY, "-m", "framework.synth.pixdiff.sample",
        "--ckpt", ckpt,
        "--data_root", DR,
        "--dataset", ds,
        "--protocol", proto,
        "--train_fraction", str(frac),
        "--fraction_seed", str(fraction_seed),
        "--n_per_mask", str(n_per_mask),
        "--num_steps", str(num_steps),
        "--out_dir", out,
    ]
    # check=True: a failed sampler run aborts the script instead of producing an empty grid.
    subprocess.run(cmd, env=env, cwd=ROOT, check=True)
    print(f"[sampled] {out}", flush=True)
def fmap(d):
    """Map each case id to one PNG path found under ``<d>/images``.

    The case id is the file name without its ``.png`` suffix, cut at the first
    ``"__"``. When several files share a case id, the one that sorts first by
    file name is kept. Returns an empty dict when ``<d>/images`` is not a
    directory.
    """
    img_dir = os.path.join(d, "images")
    if not os.path.isdir(img_dir):
        return {}

    by_case = {}
    for name in sorted(os.listdir(img_dir)):
        if not name.endswith(".png"):
            continue
        case_id = name[:-4].split("__")[0]
        # First hit wins; later files with the same case id are ignored.
        if case_id not in by_case:
            by_case[case_id] = os.path.join(img_dir, name)
    return by_case
def rgb(p, size=(256, 256)):
    """Load the image at path ``p`` as an RGB array resized to ``size``.

    Args:
        p: Path of the image file.
        size: Target size handed to ``Image.resize``; defaults to ``(256, 256)``.

    Returns:
        The array produced by ``np.asarray`` on the converted, resized image.
    """
    # The context manager closes the underlying file handle deterministically.
    with Image.open(p) as im:
        return np.asarray(im.convert("RGB").resize(size))
def gray(p, size=(256, 256)):
    """Load the image at path ``p`` as a single-channel ("L") array resized to ``size``.

    Args:
        p: Path of the image file.
        size: Target size handed to ``Image.resize``; defaults to ``(256, 256)``.

    Returns:
        The array produced by ``np.asarray`` on the converted, resized image.
    """
    # The context manager closes the underlying file handle deterministically.
    with Image.open(p) as im:
        return np.asarray(im.convert("L").resize(size))
# For each dataset: sample JiT-FD (unless already done), pick up to six cases that exist
# in the real, native-JiT and JiT-FD sets, and save a 4-row comparison grid to /tmp.
for dk, (ds, proto, tot) in DSETS.items():
    f50 = 50 / tot  # train fraction handed to the sampler: 50 out of `tot`
    base = f"{DR}/{ds}/{proto}"
    fd_out = f"{base}/synth_alignfd_jitfd_{dk}"
    sample(f"pretrained/pixdiff/p1_jitfd_{dk}.pt", ds, proto, f50, fd_out)

    ri, rm = f"{base}/train/images", f"{base}/train/masks"
    nat = fmap(f"{base}/synth_align_jit_{dk}")
    fd = fmap(fd_out)

    # Only cases present in all three sources can be compared side by side.
    real_ids = {os.path.splitext(f)[0] for f in os.listdir(ri) if f.endswith(".png")}
    common = sorted(real_ids & set(nat) & set(fd))
    if not common:
        # Without this guard the column selection below would index an empty list.
        print(f"[skip-grid] {dk}: no case shared by real / native / FD sets", flush=True)
        continue

    # Spread the selected columns evenly over the sorted case list.
    ncol = min(6, len(common))
    idx = [round(i * (len(common) - 1) / (ncol - 1)) for i in range(ncol)] if ncol > 1 else [0]
    cases = [common[i] for i in idx]

    # Each row is (label, source); source is a sentinel string or a case-id -> path dict.
    # NOTE(review): the labels/title contain CJK text; confirm the Matplotlib font in use has those glyphs.
    rows = [("Conditioning mask", "mask"), ("Real", "real"), ("JiT (native, P1)", nat), ("JiT-FD (FD-感知)", fd)]
    # squeeze=False keeps `ax` two-dimensional even when there is a single column.
    fig, ax = plt.subplots(len(rows), ncol, figsize=(ncol * 2.1, len(rows) * 2.15), squeeze=False)
    for r, (lab, src) in enumerate(rows):
        for c, bs in enumerate(cases):
            a = ax[r][c]
            try:
                mk = gray(f"{rm}/{bs}.png")
                if src == "mask":
                    a.imshow(mk, cmap="gray")
                elif src == "real":
                    a.imshow(rgb(f"{ri}/{bs}.png"))
                else:
                    a.imshow(rgb(src[bs]))
                if src != "mask":
                    # Overlay the mask outline on every image row.
                    a.contour((mk > 127).astype(float), levels=[0.5], colors=["#19f04b"], linewidths=0.9)
            except Exception as err:
                # Best effort: keep the grid, mark the cell, and report why it failed.
                print(f"[n/a] {dk} / {lab} / {bs}: {err!r}", flush=True)
                a.imshow(np.ones((256, 256, 3)))
                a.text(0.5, 0.5, "n/a", ha="center", va="center", transform=a.transAxes)
            a.set_xticks([])
            a.set_yticks([])
            for s in a.spines.values():
                s.set_visible(False)
            if c == 0:
                a.set_ylabel(lab, fontsize=11, rotation=90, va="center", labelpad=8,
                             color=("#111" if r < 2 else "#1a3b8b"),
                             fontweight=("bold" if r == 3 else "normal"))

    fig.suptitle(f"{dk.upper()} — 原生 JiT vs JiT-FD(同掩码):FD-感知精修是否更锐?", fontsize=12)
    plt.tight_layout(rect=[0.03, 0, 1, 0.95])
    out = f"/tmp/jit_vs_fd_{dk}.png"
    plt.savefig(out, dpi=150, bbox_inches="tight", facecolor="white")
    plt.close(fig)  # release the figure before the next dataset
    print(f"[grid] {out}", flush=True)
print("JIT_VS_FD_DONE", flush=True)