code: complete eval pipeline (7 metrics + per-class + Wilcoxon) + Swin-UNet/TransUNet networks; remove backups/obsolete
"""Smoke: port DeCo (PixNerDiT) or PixelDiT (PixDiT) into the PixDiff mask-concat scaffolding.

Backbone-agnostic decouple: build with in=img+cond, out defaults to in, take x_pred[:, :img_ch].

Usage: python smoke_backbone.py {deco|pixeldit}
"""
import os
import sys

# The project root has to be on sys.path before the `framework` imports below.
sys.path.insert(0, "/home/wzhang/LSC/Code/NPJ")

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from framework.synth.pixdiff.conditioning import build_conditioner
from framework.synth.pixdiff.data import MaskCondGenDataset

# Reject a missing or unknown backbone name before any dataset or model is
# built, so a typo fails immediately with a readable message instead of an
# IndexError or a late exit after the data has already been loaded.
_VALID_BACKBONES = ("deco", "pixeldit")
if len(sys.argv) < 2:
    raise SystemExit("usage: python smoke_backbone.py {deco|pixeldit}")
BK = sys.argv[1]
if BK not in _VALID_BACKBONES:
    raise SystemExit("backbone must be deco|pixeldit")

# Directories that are added to sys.path when the matching backbone is imported.
DECO = "/home/wzhang/LSC/Code/NPJ/sota/DeCo"
PIXELDIT = "/home/wzhang/LSC/Code/NPJ/sota/PixelDiT"
dev = "cuda"
DR = "/home/wzhang/LSC/Dataset/Segmentation/processed_unified"

ds = MaskCondGenDataset(DR, "medsegdb_isic2018", "holdout", img_size=256,
                        train_fraction=0.02, fraction_seed=0)
# The conditioner is called on masks; its output is concatenated with the
# noisy image along the channel axis before being fed to the backbone.
cond = build_conditioner("onehot", ds.num_classes).to(dev)
img_ch, K = ds.in_channels, cond.cond_channels
# Backbone input width = image channels + conditioning channels.
in_tot = img_ch + K
print(f"[{BK}] ds n={len(ds)} img_ch={img_ch} K={K} in_tot={in_tot}", flush=True)
# Build the requested backbone. Each branch first puts the backbone's source
# directory on sys.path and only then imports it, so the other repository is
# never touched. Both networks take image + conditioning channels as input.
if BK == "pixeldit":
    sys.path.insert(0, PIXELDIT)
    from pixdit_core.pixeldit_c2i import PixDiT
    pixeldit_kwargs = dict(in_channels=in_tot, num_groups=12, hidden_size=768,
                           pixel_hidden_size=16, patch_depth=12, pixel_depth=4,
                           patch_size=16, num_classes=1)
    net = PixDiT(**pixeldit_kwargs).to(dev)
elif BK == "deco":
    sys.path.insert(0, os.path.join(DECO, "src", "models", "transformer"))
    from dit_c2i_DeCo import PixNerDiT
    deco_kwargs = dict(in_channels=in_tot, patch_size=16, num_groups=12,
                       hidden_size=768, hidden_size_x=32, num_blocks=13,
                       num_cond_blocks=12, num_classes=1)
    net = PixNerDiT(**deco_kwargs).to(dev)
else:
    raise SystemExit("backbone must be deco|pixeldit")

# Report the parameter count in millions.
n_params = sum(p.numel() for p in net.parameters())
print(f"[{BK}] params={n_params / 1e6:.1f}M", flush=True)

opt = torch.optim.AdamW(net.parameters(), lr=1e-4)
dl = DataLoader(ds, batch_size=4, shuffle=True, drop_last=True, num_workers=2)
# Shared iterator over the loader; get_batch() advances and restarts it.
it = iter(dl)
def get_batch():
    """Return the next ``(image, mask)`` pair from the module-level loader ``dl``.

    Advances the shared module-level iterator ``it``. When it is exhausted, a
    fresh pass over ``dl`` is started, so the function can be called any
    number of times.

    A batch may be a dict with ``"image"`` and ``"mask"`` keys, or an
    indexable sequence whose first two items are the image and the mask.

    Raises:
        RuntimeError: if ``dl`` produces no batch even on a fresh pass
            (for example, fewer samples than one full batch with
            ``drop_last=True``).
    """
    global it
    try:
        batch = next(it)
    except StopIteration:
        # End of a pass: restart the loader and try once more.
        it = iter(dl)
        try:
            batch = next(it)
        except StopIteration:
            # A fresh iterator that is already empty means the loader can
            # never yield; surface that clearly instead of StopIteration.
            raise RuntimeError("data loader yielded no batches") from None
    if isinstance(batch, dict):
        return batch["image"], batch["mask"]
    return batch[0], batch[1]
# Short optimisation run: a handful of steps to confirm that the forward
# pass, the loss and the backward pass all work with the chosen backbone.
n_steps = 20
net.train()
for step in range(n_steps):
    img, msk = get_batch()
    img, msk = img.to(dev), msk.to(dev)
    # Per-sample time in (0, 1): sigmoid of a Gaussian draw with mean -0.8
    # and std 0.8, reshaped so it broadcasts over the remaining three axes.
    t = torch.sigmoid(torch.randn(img.size(0), device=dev) * 0.8 - 0.8).view(-1, 1, 1, 1)
    e = torch.randn_like(img)
    # Linear interpolation between noise (t=0) and the clean image (t=1).
    z = t * img + (1 - t) * e
    # Shared denominator for target and prediction; the clamp keeps the
    # division bounded as t approaches 1.
    denom = (1 - t).clamp_min(5e-2)
    v = (img - z) / denom
    c = cond(msk)
    # Every sample gets class label 0.
    y = torch.zeros(img.size(0), dtype=torch.long, device=dev)
    out = net(torch.cat([z, c], dim=1), t.flatten(), y)
    # Explicit raise (not `assert`) so the check still runs under `python -O`.
    if out.dim() != 4 or out.shape[1] < img_ch:
        raise AssertionError(f"bad out shape {tuple(out.shape)}")
    # Only the leading image channels of the output are used as the prediction.
    x_pred = out[:, :img_ch]
    v_pred = (x_pred - z) / denom
    loss = ((v - v_pred) ** 2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    # Log every fifth step and always the final one.
    if step % 5 == 0 or step == n_steps - 1:
        print(f"[{BK}] step {step:2d} loss {loss.item():.4f}", flush=True)
# Sampling check: integrate from pure noise (t=0) towards t=1 with ten equal
# Euler steps, conditioned on the first two masks of the last training batch.
net.eval()
with torch.no_grad():
    msk0 = msk[:2]
    c0 = cond(msk0)
    z = torch.randn(2, img_ch, 256, 256, device=dev)
    ts = torch.linspace(0, 1, 11).tolist()
    for tc, t_next in zip(ts[:-1], ts[1:]):
        dt = t_next - tc
        t_vec = torch.full((2,), tc, device=dev)
        y0 = torch.zeros(2, dtype=torch.long, device=dev)
        # Keep only the image channels of the network output.
        x_hat = net(torch.cat([z, c0], dim=1), t_vec, y0)[:, :img_ch]
        # Same clamped denominator as in training, then one Euler update.
        velocity = (x_hat - z) / max(1 - tc, 5e-2)
        z = z + velocity * dt
    print(f"[{BK}] sample ok shape={tuple(z.shape)} range=({z.min():.2f},{z.max():.2f})", flush=True)
print(f"SMOKE_{BK.upper()}_PASS", flush=True)