"""Same-mask JiT(native) vs JiT-FD comparison: does FD-perceptual sharpen the synth?

Samples JiT-FD on the SAME no-aug f50 masks the native JiT align-set used, then
builds [mask | real | JiT native | JiT-FD] grids for ISIC + Kvasir.

Each grid has one row per source: the conditioning mask, the real train image,
the native JiT sample and the JiT-FD sample. It has up to ``MAX_COLS`` columns.
Each column is a case chosen at evenly spaced positions from the sorted base
names present in all three of:
  * the real train images,
  * the native JiT synth set,
  * the JiT-FD synth set.
Grids are written to ``/tmp/jit_vs_fd_<dataset>.png``. ``JIT_VS_FD_DONE`` is
printed when every dataset has been processed.
"""
import os
import subprocess

import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend; must be selected before importing pyplot
import matplotlib.pyplot as plt
from PIL import Image

ROOT = "/home/wzhang/LSC/Code/NPJ"
DR = "/home/wzhang/LSC/Dataset/Segmentation/processed_unified"
PY = "/opt/anaconda3/envs/seggen/bin/python"
# dataset key -> (dataset dir name, split protocol, image count that
# N_ALIGN is divided by to obtain the sampler's --train_fraction).
DSETS = {"isic": ("medsegdb_isic2018", "holdout", 2582),
         "kvasir": ("kvasir_seg", "official", 800)}

N_ALIGN = 50              # number of masks in the "f50" align set
MIN_SAMPLES_TO_SKIP = 40  # reuse an existing sample dir holding at least this many files
MAX_COLS = 6              # max number of cases (columns) per grid
IMG_SIZE = 256            # every panel is resized to IMG_SIZE x IMG_SIZE


def sample(ckpt, ds, proto, frac, out):
    """Run the pixdiff sampler for ``ckpt`` and write its outputs to ``out``.

    Sampling is skipped when ``out/images`` already holds at least
    ``MIN_SAMPLES_TO_SKIP`` entries, so re-running the script reuses earlier
    results. The sampler runs in a subprocess with ``cwd=ROOT`` on GPU 0.

    Raises:
        subprocess.CalledProcessError: if the sampler exits non-zero.
    """
    img_dir = os.path.join(out, "images")
    if os.path.isdir(img_dir) and len(os.listdir(img_dir)) >= MIN_SAMPLES_TO_SKIP:
        print(f"[skip-sample] {out} exists", flush=True)
        return
    env = dict(os.environ, CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES="0",
               TORCHDYNAMO_DISABLE="1", PYTHONPATH=".", OMP_NUM_THREADS="4")
    cmd = [PY, "-m", "framework.synth.pixdiff.sample",
           "--ckpt", ckpt, "--data_root", DR, "--dataset", ds,
           "--protocol", proto, "--train_fraction", str(frac),
           "--fraction_seed", "0", "--n_per_mask", "1", "--num_steps", "50",
           "--out_dir", out]
    subprocess.run(cmd, env=env, cwd=ROOT, check=True)
    print(f"[sampled] {out}", flush=True)


def fmap(d):
    """Map case base name -> PNG path for the files in ``d/images``.

    The key is the file stem up to the first ``__``. When several files share
    a key, the alphabetically first one wins. Returns ``{}`` if
    ``d/images`` does not exist.
    """
    img_dir = os.path.join(d, "images")
    if not os.path.isdir(img_dir):
        return {}
    m = {}
    for f in sorted(os.listdir(img_dir)):
        if f.endswith(".png"):
            m.setdefault(f[:-4].split("__")[0], os.path.join(img_dir, f))
    return m


def rgb(p):
    """Load image ``p`` as an RGB uint8 array resized to IMG_SIZE x IMG_SIZE."""
    with Image.open(p) as im:
        return np.asarray(im.convert("RGB").resize((IMG_SIZE, IMG_SIZE)))


def gray(p):
    """Load image ``p`` as a grayscale ("L") array resized to IMG_SIZE x IMG_SIZE."""
    with Image.open(p) as im:
        return np.asarray(im.convert("L").resize((IMG_SIZE, IMG_SIZE)))


def _pick_cases(names, k=MAX_COLS):
    """Return up to ``k`` evenly spaced items from sorted ``names``.

    Both the first and last names are included. Returns ``[]`` for empty input.
    """
    names = sorted(names)
    n = min(k, len(names))
    if n == 0:
        return []
    if n == 1:
        return names[:1]
    return [names[round(i * (len(names) - 1) / (n - 1))] for i in range(n)]


def _render_grid(dk, cases, ri, rm, nat, fd, out):
    """Draw the [mask | real | JiT native | JiT-FD] grid for ``cases`` and save to ``out``.

    ``ri``/``rm`` are the real image/mask dirs. ``nat``/``fd`` map case name to
    synth PNG path. Cells that fail to load are drawn as a white "n/a" panel,
    and the reason is logged.
    """
    # (row label, row kind, case -> path lookup for synth rows)
    # NOTE(review): the CJK text in the last label and in the title needs a
    # CJK-capable font; with matplotlib's default DejaVu Sans it may render as
    # empty boxes -- confirm the font config of the plotting environment.
    rows = [("Conditioning mask", "mask", None),
            ("Real", "real", None),
            ("JiT (native, P1)", "synth", nat),
            ("JiT-FD (FD-感知)", "synth", fd)]
    ncol = len(cases)
    # squeeze=False keeps ``ax`` 2-D even when only one case is available.
    fig, ax = plt.subplots(len(rows), ncol, figsize=(ncol * 2.1, len(rows) * 2.15),
                           squeeze=False)
    try:
        for r, (lab, kind, paths) in enumerate(rows):
            for c, bs in enumerate(cases):
                a = ax[r][c]
                try:
                    mk = gray(f"{rm}/{bs}.png")
                    if kind == "mask":
                        a.imshow(mk, cmap="gray")
                    else:
                        img = rgb(f"{ri}/{bs}.png") if kind == "real" else rgb(paths[bs])
                        a.imshow(img)
                        # Outline the conditioning mask on every image row.
                        a.contour((mk > 127).astype(float), levels=[0.5],
                                  colors=["#19f04b"], linewidths=0.9)
                except Exception as e:  # best-effort: one bad cell must not kill the grid
                    print(f"[warn] {dk} {lab!r} {bs}: {e}", flush=True)
                    a.imshow(np.ones((IMG_SIZE, IMG_SIZE, 3)))
                    a.text(0.5, 0.5, "n/a", ha="center", va="center", transform=a.transAxes)
                a.set_xticks([])
                a.set_yticks([])
                for s in a.spines.values():
                    s.set_visible(False)
                if c == 0:
                    a.set_ylabel(lab, fontsize=11, rotation=90, va="center", labelpad=8,
                                 color=("#111" if r < 2 else "#1a3b8b"),
                                 fontweight=("bold" if r == 3 else "normal"))
        fig.suptitle(f"{dk.upper()} — 原生 JiT vs JiT-FD(同掩码):FD-感知精修是否更锐?",
                     fontsize=12)
        fig.tight_layout(rect=[0.03, 0, 1, 0.95])
        fig.savefig(out, dpi=150, bbox_inches="tight", facecolor="white")
    finally:
        plt.close(fig)  # release the figure; otherwise every grid stays in memory


def main():
    """Sample JiT-FD per dataset (if needed), then render one comparison grid each."""
    for dk, (ds, proto, tot) in DSETS.items():
        f50 = N_ALIGN / tot
        base = f"{DR}/{ds}/{proto}"
        fd_out = f"{base}/synth_alignfd_jitfd_{dk}"
        # Relative ckpt path; the sampler subprocess runs with cwd=ROOT.
        sample(f"pretrained/pixdiff/p1_jitfd_{dk}.pt", ds, proto, f50, fd_out)

        ri, rm = f"{base}/train/images", f"{base}/train/masks"
        nat = fmap(f"{base}/synth_align_jit_{dk}")
        fd = fmap(fd_out)
        real = {os.path.splitext(f)[0] for f in os.listdir(ri) if f.endswith(".png")}
        cases = _pick_cases(real & set(nat) & set(fd))
        if not cases:
            print(f"[skip-grid] {dk}: no cases shared by real/native/FD sets", flush=True)
            continue

        out = f"/tmp/jit_vs_fd_{dk}.png"
        _render_grid(dk, cases, ri, rm, nat, fd, out)
        print(f"[grid] {out}", flush=True)
    print("JIT_VS_FD_DONE", flush=True)


if __name__ == "__main__":
    main()