# Source: GenSeg-Baselines/code/scripts/p1/make_jit_vs_fd.py (3.86 kB)
# Commit 1a18f22 (verified), MaybeRichard:
#   "code: complete eval pipeline (7 metrics + per-class + Wilcoxon) + Swin-UNet/TransUNet networks; remove backups/obsolete"
"""Same-mask JiT(native) vs JiT-FD comparison: does FD-perceptual sharpen the synth?
Samples JiT-FD on the SAME no-aug f50 masks the native JiT align-set used, builds
[mask | real | JiT native | JiT-FD] grids for ISIC + Kvasir."""
import os, subprocess
import numpy as np
import matplotlib; matplotlib.use("Agg")
import matplotlib.pyplot as plt
from PIL import Image
# ROOT is the working directory the sampler subprocess is launched from;
# DR is the data root handed to the sampler and the prefix of every dataset path built below.
ROOT = "/home/wzhang/LSC/Code/NPJ"; DR = "/home/wzhang/LSC/Dataset/Segmentation/processed_unified"
# Interpreter used to run the sampler module.
PY = "/opt/anaconda3/envs/seggen/bin/python"
# Short key -> (dataset directory name, protocol directory name, total).
# The total only serves to express "50 samples" as a train fraction (50 / total).
# NOTE(review): the totals are presumably the training-set sizes — confirm.
DSETS = {"isic": ("medsegdb_isic2018", "holdout", 2582), "kvasir": ("kvasir_seg", "official", 800)}
def sample(ckpt, ds, proto, frac, out, min_images=40):
    """Run the pixdiff sampler for one dataset unless its output already exists.

    Args:
        ckpt: Checkpoint path passed to the sampler as ``--ckpt``.
        ds: Dataset name passed as ``--dataset``.
        proto: Protocol name passed as ``--protocol``.
        frac: Train fraction passed as ``--train_fraction``.
        out: Output directory passed as ``--out_dir``; samples are expected
            under ``<out>/images``.
        min_images: Number of PNGs in ``<out>/images`` at or above which the
            run is considered done and skipped.

    Raises:
        subprocess.CalledProcessError: If the sampler exits with a non-zero status.
    """
    img_dir = os.path.join(out, "images")
    if os.path.isdir(img_dir):
        # Count only PNGs: those are the files the grid builder consumes, so
        # stray or temporary files must not make an incomplete run look done.
        n_done = sum(1 for f in os.listdir(img_dir) if f.endswith(".png"))
        if n_done >= min_images:
            print(f"[skip-sample] {out} exists", flush=True)
            return

    env = dict(os.environ, CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES="0",
               TORCHDYNAMO_DISABLE="1", PYTHONPATH=".", OMP_NUM_THREADS="4")
    cmd = [PY, "-m", "framework.synth.pixdiff.sample", "--ckpt", ckpt, "--data_root", DR,
           "--dataset", ds, "--protocol", proto, "--train_fraction", str(frac),
           "--fraction_seed", "0", "--n_per_mask", "1", "--num_steps", "50", "--out_dir", out]
    # check=True: a failed sampling run must abort the script rather than
    # silently produce a grid from missing images.
    subprocess.run(cmd, env=env, cwd=ROOT, check=True)
    print(f"[sampled] {out}", flush=True)
def fmap(d):
    """Index the PNG files under ``<d>/images`` by case name.

    The case name is the file stem up to the first ``"__"``. When several
    files share a case name, the first one in sorted filename order wins.

    Args:
        d: Directory expected to contain an ``images`` subdirectory.

    Returns:
        Dict mapping case name to the full path of its PNG; empty when
        ``<d>/images`` does not exist.
    """
    img_dir = os.path.join(d, "images")
    by_case = {}
    if not os.path.isdir(img_dir):
        return by_case
    for fname in sorted(os.listdir(img_dir)):
        if fname.endswith(".png"):
            # setdefault keeps the first (lowest-sorting) file per case.
            by_case.setdefault(fname[:-4].split("__")[0], os.path.join(img_dir, fname))
    return by_case
def rgb(p):
    """Load the image at ``p``, convert it to RGB and return it as a 256x256 array."""
    picture = Image.open(p)
    picture = picture.convert("RGB")
    picture = picture.resize((256, 256))
    return np.asarray(picture)
def gray(p):
    """Load the image at ``p``, convert it to 8-bit grayscale ("L") and return it as a 256x256 array."""
    picture = Image.open(p)
    picture = picture.convert("L")
    picture = picture.resize((256, 256))
    return np.asarray(picture)
# For each dataset: sample JiT-FD, then save a [mask | real | JiT native | JiT-FD]
# grid over cases that have a real image and both kinds of synthetic sample.
for dk, (ds, proto, tot) in DSETS.items():
    f50 = 50 / tot  # train fraction handed to the sampler
    base = f"{DR}/{ds}/{proto}"
    fd_out = f"{base}/synth_alignfd_jitfd_{dk}"
    sample(f"pretrained/pixdiff/p1_jitfd_{dk}.pt", ds, proto, f50, fd_out)

    ri, rm = f"{base}/train/images", f"{base}/train/masks"
    nat = fmap(f"{base}/synth_align_jit_{dk}")
    fd = fmap(fd_out)
    real = {os.path.splitext(f)[0] for f in os.listdir(ri) if f.endswith(".png")}
    common = sorted(real & set(nat) & set(fd))
    if not common:
        # Nothing to plot; indexing an empty list / a zero-column grid would crash.
        print(f"[skip-grid] {dk}: no case present in real, native JiT and JiT-FD sets", flush=True)
        continue

    # Pick up to 6 cases spread evenly over the sorted case names.
    ncol = min(6, len(common))
    idx = [round(i * (len(common) - 1) / (ncol - 1)) for i in range(ncol)] if ncol > 1 else [0]
    cases = [common[i] for i in idx]

    rows = [("Conditioning mask", "mask"), ("Real", "real"), ("JiT (native, P1)", nat), ("JiT-FD (FD-感知)", fd)]
    # squeeze=False keeps ax two-dimensional even when there is a single column.
    fig, ax = plt.subplots(len(rows), ncol, figsize=(ncol * 2.1, len(rows) * 2.15), squeeze=False)
    for r, (lab, src) in enumerate(rows):
        for c, bs in enumerate(cases):
            a = ax[r][c]
            try:
                mk = gray(f"{rm}/{bs}.png")
                if src == "mask":
                    a.imshow(mk, cmap="gray")
                elif src == "real":
                    a.imshow(rgb(f"{ri}/{bs}.png"))
                else:
                    a.imshow(rgb(src[bs]))
                if src not in ("mask",):
                    # Overlay the mask outline on every image row.
                    a.contour((mk > 127).astype(float), levels=[0.5], colors=["#19f04b"], linewidths=0.9)
            except Exception as err:
                # Best effort: keep the grid, show a placeholder, but report why.
                print(f"[warn] {dk}/{bs} ({lab}): {err!r}", flush=True)
                a.imshow(np.ones((256, 256, 3)))
                a.text(0.5, 0.5, "n/a", ha="center", va="center", transform=a.transAxes)
            a.set_xticks([])
            a.set_yticks([])
            for s in a.spines.values():
                s.set_visible(False)
            if c == 0:
                a.set_ylabel(lab, fontsize=11, rotation=90, va="center", labelpad=8,
                             color=("#111" if r < 2 else "#1a3b8b"), fontweight=("bold" if r == 3 else "normal"))

    fig.suptitle(f"{dk.upper()} — 原生 JiT vs JiT-FD(同掩码):FD-感知精修是否更锐?", fontsize=12)
    fig.tight_layout(rect=[0.03, 0, 1, 0.95])
    out = f"/tmp/jit_vs_fd_{dk}.png"
    fig.savefig(out, dpi=150, bbox_inches="tight", facecolor="white")
    plt.close(fig)  # release the figure before building the next dataset's grid
    print(f"[grid] {out}", flush=True)
print("JIT_VS_FD_DONE", flush=True)