GenSeg-Baselines / code /scripts /p1 /train_fd_patched.py
MaybeRichard's picture
code: complete eval pipeline (7 metrics + per-class + Wilcoxon) + Swin-UNet/TransUNet networks; remove backups/obsolete
1a18f22 verified
Raw
History Blame Contribute Delete
9.36 kB
"""FD-Loss post-training for a PixDiff generator (env: seggen, GPU 5).
Starts from a base PixDiff checkpoint and continues training with:
    total = flow_matching_loss + fd_weight * normalized_FD_loss
            [+ lpips_weight * LPIPS + dino_weight * DINO; both optional, weights default to 0]
where the FD term matches the generated x0 feature distribution to the real-image
reference distribution (Inception space), gated to low-noise timesteps (where x0 is
a meaningful image). This targets the blur/distribution gap the MSE objective leaves.
Run from project root (…/NPJ):
CUDA_VISIBLE_DEVICES=5 python -m framework.synth.pixdiff.train_fd \
--base_ckpt pretrained/pixdiff/kvasir_seg_official_f1.0.pt \
--data_root /home/wzhang/LSC/Dataset/Segmentation/processed_unified \
--dataset kvasir_seg --protocol official \
--epochs 200 --lr 2e-5 --fd_weight 0.5 \
--out_ckpt pretrained/pixdiff/kvasir_seg_official_f1.0_fd.pt
"""
from __future__ import annotations
import argparse
import os
import sys
import time
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")))
import numpy as np
import torch
from torch.utils.data import DataLoader
from framework.synth.pixdiff.data import MaskCondGenDataset
from framework.synth.pixdiff.conditioning import build_conditioner
from framework.synth.pixdiff.mask_jit import MaskDenoiser
from framework.synth.pixdiff.fd_loss import (
InceptionFeatures, FeatureQueue, compute_frechet_distance_loss,
precompute_sigma_ref_sqrt, compute_ref_stats,
)
def get_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for FD-Loss post-training.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point uses;
            passing an explicit list allows programmatic use.

    Returns:
        The parsed namespace, with one attribute per option declared below.
    """
    p = argparse.ArgumentParser("PixDiff FD-Loss post-training")

    # Checkpoint to start from and the dataset split to train on.
    p.add_argument("--base_ckpt", required=True)
    p.add_argument("--data_root", required=True)
    p.add_argument("--dataset", required=True)
    p.add_argument("--protocol", required=True)
    p.add_argument("--train_fraction", type=float, default=1.0)
    p.add_argument("--fraction_seed", type=int, default=0)

    # Optimisation settings.
    p.add_argument("--epochs", type=int, default=200)
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--lr", type=float, default=2e-5)
    p.add_argument("--num_workers", type=int, default=6)
    p.add_argument("--amp", default="bf16", choices=["bf16", "fp16", "fp32"])

    # FD-Loss knobs
    p.add_argument("--fd_weight", type=float, default=0.5)
    p.add_argument("--fd_gate_t", type=float, default=0.5, help="apply FD only when t>=this (low noise)")
    p.add_argument("--queue_size", type=int, default=512)
    p.add_argument("--fd_norm_eps", type=float, default=1e-2)

    # Optional perceptual terms; a weight of 0 leaves the term switched off.
    p.add_argument("--lpips_weight", type=float, default=0.0)
    p.add_argument("--dino_weight", type=float, default=0.0)
    p.add_argument("--percep_gate_t", type=float, default=0.5, help="apply perceptual only when t>=this")

    p.add_argument("--ref_stats", default="", help="npz of (mu,sigma); auto path + compute if empty")
    p.add_argument("--ema_decay", type=float, default=0.9999)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--out_ckpt", required=True)
    p.add_argument("--log_interval", type=int, default=20)
    return p.parse_args(argv)
def _default_ref_stats_path(out_ckpt):
    """Return the ``*_refstats.npz`` sidecar path derived from *out_ckpt*.

    Only a trailing checkpoint extension is stripped. A plain
    ``str.replace(".pt", ...)`` would also rewrite ".pt" inside directory
    names or inside ".pth", and would return the checkpoint path itself when
    the name contains no ".pt" at all.
    """
    stem, ext = os.path.splitext(out_ckpt)
    if ext.lower() not in (".pt", ".pth", ".ckpt"):
        stem = out_ckpt
    return f"{stem}_refstats.npz"


def main():
    """Continue training a PixDiff checkpoint with an added FD loss term.

    Steps, in order:
      1. Parse CLI options and build the training ``DataLoader``.
      2. Rebuild the ``MaskDenoiser`` described by the base checkpoint and
         restore its weights and EMA state (EMA is initialised if absent).
      3. Load cached reference statistics (``mu``, ``sigma``) from an npz file,
         or compute them from the un-flipped training images and cache them.
      4. Optimise ``fm_loss + fd_weight * fd_loss`` (plus optional LPIPS/DINO
         terms), where each auxiliary term only sees samples whose ``t`` is at
         or above its gate threshold.
      5. Write the output checkpoint at the end of every epoch.

    Requires a CUDA device; the device string is fixed to ``"cuda"``.
    """
    a = get_args()
    torch.manual_seed(a.seed)
    device = "cuda"
    # "fp32" is not in the table, so amp_dtype is None and autocast is skipped.
    amp_dtype = {"bf16": torch.bfloat16, "fp16": torch.float16}.get(a.amp)
    # NOTE(review): with --amp fp16 the backward pass runs without loss scaling
    # (no GradScaler); small gradients may underflow. bf16 is the default.

    # ---- data ----
    ds = MaskCondGenDataset(a.data_root, a.dataset, a.protocol, img_size=256,
                            train_fraction=a.train_fraction, fraction_seed=a.fraction_seed)
    img_ch, n_cls = ds.in_channels, ds.num_classes
    loader = DataLoader(ds, batch_size=a.batch_size, shuffle=True, drop_last=True,
                        num_workers=a.num_workers, pin_memory=True,
                        persistent_workers=a.num_workers > 0)
    print(f"[fd] {a.dataset}/{a.protocol} n={len(ds)} in_ch={img_ch} num_classes={n_cls}", flush=True)
    if img_ch != 3:
        print("[fd][warn] Inception expects 3ch; non-RGB dataset — FD features may be weak.", flush=True)

    # ---- model from base ckpt ----
    ckpt = torch.load(a.base_ckpt, map_location="cpu", weights_only=False)
    conditioner_name = ckpt.get("conditioner", "onehot")
    noise_scale = ckpt.get("noise_scale", 1.0)
    backbone = ckpt.get("backbone", "jit")
    cond = build_conditioner(conditioner_name, n_cls)
    model = MaskDenoiser(ckpt["model_name"], ckpt["img_size"], ckpt["img_channels"], cond,
                         noise_scale=noise_scale, ema_decay=a.ema_decay,
                         backbone=backbone).to(device)
    model.load_state_dict(ckpt["state_dict"])
    if ckpt.get("ema") is not None:
        model._ema = [e.to(device) for e in ckpt["ema"]]
    else:
        model._ema = None
        model.ema_init()
    print(f"[fd] loaded base {a.base_ckpt}", flush=True)

    # ---- FD machinery ----
    inception = InceptionFeatures().to(device).eval()
    queue = FeatureQueue(size=a.queue_size, feat_dim=inception.feat_dim).to(device)
    percep = None
    if a.lpips_weight > 0 or a.dino_weight > 0:
        # Imported lazily so the perceptual module is only needed when enabled.
        from framework.synth.pixdiff.perceptual import PerceptualLoss
        percep = PerceptualLoss(use_lpips=a.lpips_weight > 0, use_dino=a.dino_weight > 0, device=device)
        print(f"[fd] perceptual ON lpips_w={a.lpips_weight} dino_w={a.dino_weight} gate_t={a.percep_gate_t}", flush=True)

    # ---- reference statistics (cached next to the output checkpoint) ----
    ref_path = a.ref_stats or _default_ref_stats_path(a.out_ckpt)
    if os.path.isfile(ref_path):
        # Context manager closes the underlying npz archive once both arrays are read.
        with np.load(ref_path) as rs:
            mu_ref_np, sigma_ref_np = rs["mu"], rs["sigma"]
        print(f"[fd] loaded ref stats {ref_path}", flush=True)
    else:
        print("[fd] computing reference stats from real train images...", flush=True)
        # Same split as training, but with flips off and a fixed order.
        ref_ds = MaskCondGenDataset(a.data_root, a.dataset, a.protocol, img_size=256,
                                    train_fraction=a.train_fraction, fraction_seed=a.fraction_seed,
                                    hflip=False, vflip=False)
        ref_loader = DataLoader(ref_ds, batch_size=a.batch_size, shuffle=False,
                                num_workers=a.num_workers)
        mu_ref_np, sigma_ref_np, nref = compute_ref_stats(ref_loader, inception, device)
        os.makedirs(os.path.dirname(os.path.abspath(ref_path)) or ".", exist_ok=True)
        np.savez(ref_path, mu=mu_ref_np, sigma=sigma_ref_np)
        print(f"[fd] ref stats from {nref} imgs -> {ref_path}", flush=True)
    mu_ref = torch.tensor(mu_ref_np, device=device, dtype=torch.float64)
    sigma_ref = torch.tensor(sigma_ref_np, device=device, dtype=torch.float64)
    sigma_ref_sqrt = precompute_sigma_ref_sqrt(sigma_ref)

    opt = torch.optim.AdamW(model._trainable(), lr=a.lr, weight_decay=0.0)
    os.makedirs(os.path.dirname(os.path.abspath(a.out_ckpt)) or ".", exist_ok=True)

    def save():
        """Write the current weights, EMA state and build metadata to ``a.out_ckpt``."""
        torch.save({"model_name": ckpt["model_name"], "img_size": ckpt["img_size"],
                    "img_channels": img_ch, "num_classes": n_cls,
                    "conditioner": conditioner_name,
                    "noise_scale": noise_scale,
                    # Persisted so a reload rebuilds the same architecture; the
                    # loader above falls back to "jit" when the key is missing.
                    "backbone": backbone,
                    "state_dict": model.state_dict(), "ema": model._ema,
                    "args": vars(a)}, a.out_ckpt)
        print(f"[fd] saved {a.out_ckpt}", flush=True)

    step = 0
    for epoch in range(a.epochs):
        model.train()
        t0 = time.time()
        run_fm = run_fdraw = 0.0
        fd_steps = 0  # steps in this epoch where the FD term was actually computed
        for batch in loader:
            img = batch["image"].to(device, non_blocking=True)
            mask = batch["mask"].to(device, non_blocking=True)
            opt.zero_grad(set_to_none=True)
            # Only the denoiser forward runs under autocast; the auxiliary
            # terms below work on explicit .float() copies of the prediction.
            with torch.autocast("cuda", dtype=amp_dtype) if amp_dtype else _null():
                fm_loss, x_pred, t = model(img, mask, return_pred=True)

            # FD term on predicted clean image, gated to low noise
            gate = t >= a.fd_gate_t
            fd_loss = torch.zeros((), device=device)
            fd_raw = 0.0
            if int(gate.sum()) >= 2:
                xg = x_pred[gate].float()
                feats = inception((xg.clamp(-1, 1) + 1) / 2)  # (Ng,2048), grad flows
                if queue.is_ready():
                    mu, sigma = queue.build_feats_stats(feats)
                    fd = compute_frechet_distance_loss(mu_ref, sigma_ref, mu, sigma, sigma_ref_sqrt)
                    fd_raw = float(fd)
                    # Divide by the detached value so the term has roughly unit
                    # magnitude while gradients still flow through the numerator.
                    fd_loss = fd / (fd.detach() + a.fd_norm_eps)
                    fd_steps += 1
                queue.enqueue(feats)
            total = fm_loss + a.fd_weight * fd_loss

            pl_lpips = pl_dino = 0.0
            if percep is not None:
                pgate = t >= a.percep_gate_t
                if int(pgate.sum()) >= 1:
                    pld = percep(x_pred[pgate].float(), img[pgate].float())
                    if "lpips" in pld:
                        total = total + a.lpips_weight * pld["lpips"]
                        pl_lpips = float(pld["lpips"])
                    if "dino" in pld:
                        total = total + a.dino_weight * pld["dino"]
                        pl_dino = float(pld["dino"])

            total.backward()
            opt.step()
            model.ema_update()

            run_fm += float(fm_loss)
            run_fdraw += fd_raw
            step += 1
            if step % a.log_interval == 0:
                print(f"[fd] ep{epoch} step{step} fm={float(fm_loss):.4f} fd_raw={fd_raw:.1f} "
                      f"lpips={pl_lpips:.3f} dino={pl_dino:.3f} qready={queue.is_ready()}", flush=True)

        # fd_raw is averaged over the steps that computed it, so queue warm-up
        # and skipped batches do not pull the reported mean toward zero.
        print(f"[fd] epoch {epoch} fm={run_fm/max(1,len(loader)):.4f} "
              f"fd_raw={run_fdraw/max(1,fd_steps):.1f} ({time.time()-t0:.1f}s)", flush=True)
        # Overwrite the output checkpoint each epoch so an interrupted run
        # still leaves the most recent weights on disk.
        save()
class _null:
    """Context manager that does nothing on entry or exit.

    Serves as the stand-in for ``torch.autocast`` when no reduced-precision
    dtype is selected, so the training loop can use a single ``with`` line.
    """

    def __enter__(self):
        # The manager itself is what an ``as`` clause receives.
        return self

    def __exit__(self, *exc_info):
        # A false return value lets any in-flight exception propagate.
        return False
# Script entry point: start training only when the file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()