GenSeg-Baselines / code /scripts /p1 /smoke_pixelgen.py
MaybeRichard's picture
code: complete eval pipeline (7 metrics + per-class + Wilcoxon) + Swin-UNet/TransUNet networks; remove backups/obsolete
1a18f22 verified
Raw
History Blame Contribute Delete
3.43 kB
"""Smoke test: port PixelGen's JiT denoiser into the PixDiff mask-concat scaffolding.
Validates: import on server, in=img+cond/out=img build, flow-matching train step, sampling.
Tiny: 20 train steps + 10-step sample on ~50 ISIC images. No checkpoint written."""
import os, sys
sys.path.insert(0, "/home/wzhang/LSC/Code/NPJ") # framework.*
PG = "/home/wzhang/LSC/Code/NPJ/sota/PixelGen"
sys.path.insert(0, os.path.join(PG, "src", "models", "transformer")) # JiT.py
import torch, torch.nn as nn
from torch.utils.data import DataLoader
from JiT import JiT_models, FinalLayer # PixelGen denoiser
from framework.synth.pixdiff.conditioning import build_conditioner
from framework.synth.pixdiff.data import MaskCondGenDataset
# ---- data, conditioner, denoiser, optimizer ---------------------------------
dev = "cuda"
DR = "/home/wzhang/LSC/Dataset/Segmentation/processed_unified"

# Tiny ISIC-2018 subset; train_fraction/fraction_seed presumably subsample the
# "holdout" split deterministically (~50 images per the module docstring).
ds = MaskCondGenDataset(
    DR,
    "medsegdb_isic2018",
    "holdout",
    img_size=256,
    train_fraction=0.02,
    fraction_seed=0,
)
print(f"[ds] n={len(ds)} in_ch={ds.in_channels} num_classes={ds.num_classes}", flush=True)

cond = build_conditioner("onehot", ds.num_classes).to(dev)
img_ch = ds.in_channels
K = cond.cond_channels

# The denoiser reads [noisy image | mask condition] stacked on channels but must
# emit image channels only, so the stock output head is replaced when it would
# produce img_ch + K channels.
net = JiT_models["JiT-B/16"](input_size=256, in_channels=img_ch + K, num_classes=1).to(dev)
if net.out_channels != img_ch:
    net.out_channels = img_ch
    new_head = FinalLayer(net.hidden_size, net.patch_size, img_ch).to(dev)
    # Zero the projection and adaLN modulation so the new head's output starts at 0
    # (presumably mirroring JiT's own head init -- confirm against JiT.initialize_weights).
    for zeroed in (
        new_head.linear.weight,
        new_head.linear.bias,
        new_head.adaLN_modulation[-1].weight,
        new_head.adaLN_modulation[-1].bias,
    ):
        nn.init.constant_(zeroed, 0)
    net.final_layer = new_head

param_millions = sum(p.numel() for p in net.parameters()) / 1e6
print(f"[net] PixelGen JiT-B/16 in={img_ch+K} out={net.out_channels} params={param_millions:.1f}M", flush=True)

opt = torch.optim.AdamW(net.parameters(), lr=1e-4)
dl = DataLoader(ds, batch_size=4, shuffle=True, drop_last=True, num_workers=2)
it = iter(dl)
def get_batch():
    """Return the next ``(image, mask)`` pair from the module-level loader.

    When the module-level iterator ``it`` is exhausted, it is rebuilt from
    ``dl`` so callers can draw batches indefinitely. A batch may be a dict
    with ``"image"``/``"mask"`` keys, or a sequence whose first two entries
    are the image and the mask (any further entries are ignored).
    """
    global it
    exhausted = object()
    batch = next(it, exhausted)
    if batch is exhausted:
        # Start a fresh epoch; if dl yields nothing at all, StopIteration propagates.
        it = iter(dl)
        batch = next(it)
    return (batch["image"], batch["mask"]) if isinstance(batch, dict) else (batch[0], batch[1])
# ---- flow-matching training: the net predicts x, the loss is on velocity ----
N_TRAIN_STEPS = 20
LOG_EVERY = 5
T_EPS = 5e-2  # lower clamp on (1 - t) keeps the x -> v conversion bounded near t = 1

net.train()
for step in range(N_TRAIN_STEPS):
    img, msk = get_batch()
    img, msk = img.to(dev), msk.to(dev)
    # Logit-normal timestep t = sigmoid(N(-0.8, 0.8^2)); t=1 is clean data, t=0 pure noise.
    t = torch.sigmoid(torch.randn(img.size(0), device=dev) * 0.8 - 0.8).view(-1, 1, 1, 1)
    e = torch.randn_like(img)
    z = t * img + (1 - t) * e
    denom = (1 - t).clamp_min(T_EPS)
    v = (img - z) / denom  # target velocity
    c = cond(msk)
    y = torch.zeros(img.size(0), dtype=torch.long, device=dev)  # single dummy class
    x_pred = net(torch.cat([z, c], dim=1), t.flatten(), y)
    v_pred = (x_pred - z) / denom  # same x -> v mapping as the target
    loss = ((v - v_pred) ** 2).mean()
    # A smoke test must fail loudly rather than keep training on NaN/Inf.
    if not torch.isfinite(loss):
        raise RuntimeError(f"[train] non-finite loss at step {step}: {loss.item()}")
    loss.backward()
    opt.step()
    opt.zero_grad()
    if step % LOG_EVERY == 0 or step == N_TRAIN_STEPS - 1:
        print(f"[train] step {step:2d} loss {loss.item():.4f}", flush=True)
# ---- Euler sampling of the flow ODE from t=0 (noise) to t=1 (data) ----------
N_SAMPLE_STEPS = 10

net.eval()
with torch.no_grad():
    # Condition on the first two masks of the last training batch.
    msk0 = msk[:2]
    c0 = cond(msk0)
    # c0 is concatenated with z along dim 1, so every other dim must match.
    # Derive the sample count and spatial size from c0 instead of repeating 2/256.
    n_smp = c0.size(0)
    z = torch.randn(n_smp, img_ch, *c0.shape[-2:], device=dev)
    y0 = torch.zeros(n_smp, dtype=torch.long, device=dev)
    ts = torch.linspace(0, 1, N_SAMPLE_STEPS + 1).tolist()
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        tt = torch.full((n_smp,), t_cur, device=dev)
        xp = net(torch.cat([z, c0], dim=1), tt, y0)
        # x-prediction -> velocity (x - z) / (1 - t), with the same 5e-2 clamp as training.
        z = z + (xp - z) / max(1 - t_cur, 5e-2) * (t_next - t_cur)

# Do not report PASS if sampling diverged.
if not torch.isfinite(z).all():
    raise RuntimeError("[sample] non-finite values in sampled images")
print(f"[sample] ok shape={tuple(z.shape)} range=({z.min():.2f},{z.max():.2f})", flush=True)
print("SMOKE_PIXELGEN_PASS", flush=True)