code: complete eval pipeline (7 metrics + per-class + Wilcoxon) + Swin-UNet/TransUNet networks; remove backups/obsolete
1a18f22 verified | """FD-Loss post-training for a PixDiff generator (env: seggen, GPU 5). | |
| Starts from a base PixDiff checkpoint and continues training with: | |
| total = flow_matching_loss + fd_weight * normalized_FD_loss | |
| where the FD term matches the generated x0 feature distribution to the real-image | |
| reference distribution (Inception space), gated to low-noise timesteps (where x0 is | |
| a meaningful image). This targets the blur/distribution gap the MSE objective leaves. | |
| Run from project root (…/NPJ): | |
| CUDA_VISIBLE_DEVICES=5 python -m framework.synth.pixdiff.train_fd \ | |
| --base_ckpt pretrained/pixdiff/kvasir_seg_official_f1.0.pt \ | |
| --data_root /home/wzhang/LSC/Dataset/Segmentation/processed_unified \ | |
| --dataset kvasir_seg --protocol official \ | |
| --epochs 200 --lr 2e-5 --fd_weight 0.5 \ | |
| --out_ckpt pretrained/pixdiff/kvasir_seg_official_f1.0_fd.pt | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import os | |
| import sys | |
| import time | |
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))) | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from framework.synth.pixdiff.data import MaskCondGenDataset | |
| from framework.synth.pixdiff.conditioning import build_conditioner | |
| from framework.synth.pixdiff.mask_jit import MaskDenoiser | |
| from framework.synth.pixdiff.fd_loss import ( | |
| InceptionFeatures, FeatureQueue, compute_frechet_distance_loss, | |
| precompute_sigma_ref_sqrt, compute_ref_stats, | |
| ) | |
def get_args(argv=None):
    """Parse the command-line options for FD-Loss post-training.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is what the script entry
            point relies on; passing a list lets callers and tests drive the
            parser without touching process state.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: raised by argparse when a required option is missing or a
            value fails validation (e.g. an ``--amp`` outside its choices).
    """
    p = argparse.ArgumentParser("PixDiff FD-Loss post-training")

    # Checkpoint and data selection.
    p.add_argument("--base_ckpt", required=True)
    p.add_argument("--data_root", required=True)
    p.add_argument("--dataset", required=True)
    p.add_argument("--protocol", required=True)
    p.add_argument("--train_fraction", type=float, default=1.0)
    p.add_argument("--fraction_seed", type=int, default=0)

    # Optimisation schedule.
    p.add_argument("--epochs", type=int, default=200)
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--lr", type=float, default=2e-5)
    p.add_argument("--num_workers", type=int, default=6)
    p.add_argument("--amp", default="bf16", choices=["bf16", "fp16", "fp32"])

    # FD-Loss knobs
    p.add_argument("--fd_weight", type=float, default=0.5)
    p.add_argument("--fd_gate_t", type=float, default=0.5, help="apply FD only when t>=this (low noise)")
    p.add_argument("--queue_size", type=int, default=512)
    p.add_argument("--fd_norm_eps", type=float, default=1e-2)

    # Optional perceptual terms; a weight of 0 leaves the term switched off.
    p.add_argument("--lpips_weight", type=float, default=0.0)
    p.add_argument("--dino_weight", type=float, default=0.0)
    p.add_argument("--percep_gate_t", type=float, default=0.5, help="apply perceptual only when t>=this")

    p.add_argument("--ref_stats", default="", help="npz of (mu,sigma); auto path + compute if empty")
    p.add_argument("--ema_decay", type=float, default=0.9999)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--out_ckpt", required=True)
    p.add_argument("--log_interval", type=int, default=20)
    return p.parse_args(argv)
def main():
    """Continue training a PixDiff generator with an added Frechet-distance term.

    Steps, in order: parse CLI arguments, build the training dataset/loader,
    rebuild the model (and its EMA weights) from ``--base_ckpt``, load or
    compute the reference feature statistics, run ``--epochs`` epochs of
    optimisation, and write the result to ``--out_ckpt``.

    Per step the optimised loss is
    ``fm_loss + fd_weight * fd / (fd.detach() + fd_norm_eps)`` plus the optional
    LPIPS / DINO perceptual terms. The FD term is only evaluated on samples with
    ``t >= fd_gate_t``, and only once the feature queue reports it is ready.
    """
    a = get_args()
    torch.manual_seed(a.seed)
    device = "cuda"
    # "fp32" is not in the mapping, so amp_dtype is None and autocast is skipped.
    # NOTE(review): the fp16 path runs without a GradScaler; confirm that is intended.
    amp_dtype = {"bf16": torch.bfloat16, "fp16": torch.float16}.get(a.amp)

    # ---- data ----
    ds = MaskCondGenDataset(a.data_root, a.dataset, a.protocol, img_size=256,
                            train_fraction=a.train_fraction, fraction_seed=a.fraction_seed)
    img_ch, n_cls = ds.in_channels, ds.num_classes
    loader = DataLoader(ds, batch_size=a.batch_size, shuffle=True, drop_last=True,
                        num_workers=a.num_workers, pin_memory=True, persistent_workers=a.num_workers > 0)
    print(f"[fd] {a.dataset}/{a.protocol} n={len(ds)} in_ch={img_ch} num_classes={n_cls}", flush=True)
    if img_ch != 3:
        print("[fd][warn] Inception expects 3ch; non-RGB dataset — FD features may be weak.", flush=True)

    # ---- model from base ckpt ----
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    ckpt = torch.load(a.base_ckpt, map_location="cpu", weights_only=False)
    cond = build_conditioner(ckpt.get("conditioner", "onehot"), n_cls)
    model = MaskDenoiser(ckpt["model_name"], ckpt["img_size"], ckpt["img_channels"], cond,
                         noise_scale=ckpt.get("noise_scale", 1.0), ema_decay=a.ema_decay,
                         backbone=ckpt.get("backbone", "jit")).to(device)
    model.load_state_dict(ckpt["state_dict"])
    # Reuse the checkpoint's EMA weights when present, otherwise start a fresh EMA.
    model._ema = [e.to(device) for e in ckpt["ema"]] if ckpt.get("ema") is not None else None
    if model._ema is None:
        model.ema_init()
    print(f"[fd] loaded base {a.base_ckpt}", flush=True)

    # ---- FD machinery ----
    inception = InceptionFeatures().to(device).eval()
    queue = FeatureQueue(size=a.queue_size, feat_dim=inception.feat_dim).to(device)
    percep = None
    if a.lpips_weight > 0 or a.dino_weight > 0:
        # Imported lazily so the perceptual module is only needed when a weight is set.
        from framework.synth.pixdiff.perceptual import PerceptualLoss
        percep = PerceptualLoss(use_lpips=a.lpips_weight > 0, use_dino=a.dino_weight > 0, device=device)
        print(f"[fd] perceptual ON lpips_w={a.lpips_weight} dino_w={a.dino_weight} gate_t={a.percep_gate_t}", flush=True)

    # Derive the default stats path from the checkpoint stem. splitext only touches the
    # final extension, so an --out_ckpt without ".pt" can no longer map onto itself and a
    # ".pt" substring inside a directory name is left alone.
    ref_path = a.ref_stats or os.path.splitext(a.out_ckpt)[0] + "_refstats.npz"
    if os.path.isfile(ref_path):
        # Close the npz archive once both arrays have been read.
        with np.load(ref_path) as rs:
            mu_ref_np, sigma_ref_np = rs["mu"], rs["sigma"]
        print(f"[fd] loaded ref stats {ref_path}", flush=True)
    else:
        print("[fd] computing reference stats from real train images...", flush=True)
        # Same split as training, but unshuffled and with both flips disabled.
        ref_loader = DataLoader(MaskCondGenDataset(a.data_root, a.dataset, a.protocol, img_size=256,
                                                   train_fraction=a.train_fraction, fraction_seed=a.fraction_seed,
                                                   hflip=False, vflip=False),
                                batch_size=a.batch_size, shuffle=False, num_workers=a.num_workers)
        mu_ref_np, sigma_ref_np, nref = compute_ref_stats(ref_loader, inception, device)
        os.makedirs(os.path.dirname(os.path.abspath(ref_path)) or ".", exist_ok=True)
        np.savez(ref_path, mu=mu_ref_np, sigma=sigma_ref_np)
        print(f"[fd] ref stats from {nref} imgs -> {ref_path}", flush=True)
    mu_ref = torch.tensor(mu_ref_np, device=device, dtype=torch.float64)
    sigma_ref = torch.tensor(sigma_ref_np, device=device, dtype=torch.float64)
    sigma_ref_sqrt = precompute_sigma_ref_sqrt(sigma_ref)

    opt = torch.optim.AdamW(model._trainable(), lr=a.lr, weight_decay=0.0)
    os.makedirs(os.path.dirname(os.path.abspath(a.out_ckpt)) or ".", exist_ok=True)

    def save():
        """Write weights, EMA and the metadata used to rebuild the model to a.out_ckpt."""
        torch.save(
            {
                "model_name": ckpt["model_name"],
                "img_size": ckpt["img_size"],
                "img_channels": img_ch,
                "num_classes": n_cls,
                "conditioner": ckpt.get("conditioner", "onehot"),
                "noise_scale": ckpt.get("noise_scale", 1.0),
                # This script rebuilds the model with ckpt.get("backbone", "jit"); store the
                # value so re-loading the output does not silently fall back to "jit".
                "backbone": ckpt.get("backbone", "jit"),
                "state_dict": model.state_dict(),
                "ema": model._ema,
                "args": vars(a),
            },
            a.out_ckpt,
        )
        print(f"[fd] saved {a.out_ckpt}", flush=True)

    step = 0
    for epoch in range(a.epochs):
        model.train()
        t0 = time.time()
        run_fm = run_fdraw = 0.0
        for batch in loader:
            img = batch["image"].to(device, non_blocking=True)
            mask = batch["mask"].to(device, non_blocking=True)
            opt.zero_grad(set_to_none=True)
            # NOTE(review): SOURCE indentation was lost; autocast is assumed to wrap only
            # the forward pass, with the FD/perceptual terms evaluated outside it — confirm.
            with torch.autocast("cuda", dtype=amp_dtype) if amp_dtype else _null():
                fm_loss, x_pred, t = model(img, mask, return_pred=True)

            # FD term on predicted clean image, gated to low noise
            gate = t >= a.fd_gate_t
            fd_loss = torch.zeros((), device=device)
            fd_raw = 0.0
            if int(gate.sum()) >= 2:
                xg = x_pred[gate].float()
                # Map from [-1, 1] to [0, 1] before feature extraction.
                feats = inception((xg.clamp(-1, 1) + 1) / 2)  # (Ng,2048), grad flows
                if queue.is_ready():
                    mu, sigma = queue.build_feats_stats(feats)
                    fd = compute_frechet_distance_loss(mu_ref, sigma_ref, mu, sigma, sigma_ref_sqrt)
                    fd_raw = float(fd)
                    # Dividing by the detached value keeps the term's magnitude near 1
                    # while gradients still flow through the numerator.
                    fd_loss = fd / (fd.detach() + a.fd_norm_eps)
                # Enqueue on every gated step, otherwise the queue could never become ready.
                queue.enqueue(feats)
            total = fm_loss + a.fd_weight * fd_loss

            pl_lpips = pl_dino = 0.0
            if percep is not None:
                pgate = t >= a.percep_gate_t
                if int(pgate.sum()) >= 1:
                    pld = percep(x_pred[pgate].float(), img[pgate].float())
                    if "lpips" in pld:
                        total = total + a.lpips_weight * pld["lpips"]
                        pl_lpips = float(pld["lpips"])
                    if "dino" in pld:
                        total = total + a.dino_weight * pld["dino"]
                        pl_dino = float(pld["dino"])

            total.backward()
            opt.step()
            model.ema_update()

            run_fm += float(fm_loss)
            run_fdraw += fd_raw
            step += 1
            if step % a.log_interval == 0:
                print(f"[fd] ep{epoch} step{step} fm={float(fm_loss):.4f} fd_raw={fd_raw:.1f} "
                      f"lpips={pl_lpips:.3f} dino={pl_dino:.3f} qready={queue.is_ready()}", flush=True)
        # Averages are over all steps of the epoch, including those where FD was skipped.
        print(f"[fd] epoch {epoch} fm={run_fm/max(1,len(loader)):.4f} "
              f"fd_raw={run_fdraw/max(1,len(loader)):.1f} ({time.time()-t0:.1f}s)", flush=True)
    # NOTE(review): assumed to run once after the final epoch (nesting not recoverable
    # from SOURCE) — confirm it was not meant to checkpoint every epoch.
    save()
class _null:
    """No-op context manager used in place of ``torch.autocast`` when AMP is off."""

    def __enter__(self):
        """Enter the context and hand back this instance."""
        return self

    def __exit__(self, *exc_info):
        """Leave the context; returning False lets any exception propagate."""
        return False
# Run training only when executed as a script / via ``python -m``, not on import.
if __name__ == "__main__":
    main()