"""FD-Loss post-training for a PixDiff generator (env: seggen, GPU 5). Starts from a base PixDiff checkpoint and continues training with: total = flow_matching_loss + fd_weight * normalized_FD_loss where the FD term matches the generated x0 feature distribution to the real-image reference distribution (Inception space), gated to low-noise timesteps (where x0 is a meaningful image). This targets the blur/distribution gap the MSE objective leaves. Run from project root (…/NPJ): CUDA_VISIBLE_DEVICES=5 python -m framework.synth.pixdiff.train_fd \ --base_ckpt pretrained/pixdiff/kvasir_seg_official_f1.0.pt \ --data_root /home/wzhang/LSC/Dataset/Segmentation/processed_unified \ --dataset kvasir_seg --protocol official \ --epochs 200 --lr 2e-5 --fd_weight 0.5 \ --out_ckpt pretrained/pixdiff/kvasir_seg_official_f1.0_fd.pt """ from __future__ import annotations import argparse import os import sys import time sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))) import numpy as np import torch from torch.utils.data import DataLoader from framework.synth.pixdiff.data import MaskCondGenDataset from framework.synth.pixdiff.conditioning import build_conditioner from framework.synth.pixdiff.mask_jit import MaskDenoiser from framework.synth.pixdiff.fd_loss import ( InceptionFeatures, FeatureQueue, compute_frechet_distance_loss, precompute_sigma_ref_sqrt, compute_ref_stats, ) def get_args(): p = argparse.ArgumentParser("PixDiff FD-Loss post-training") p.add_argument("--base_ckpt", required=True) p.add_argument("--data_root", required=True) p.add_argument("--dataset", required=True) p.add_argument("--protocol", required=True) p.add_argument("--train_fraction", type=float, default=1.0) p.add_argument("--fraction_seed", type=int, default=0) p.add_argument("--epochs", type=int, default=200) p.add_argument("--batch_size", type=int, default=32) p.add_argument("--lr", type=float, default=2e-5) p.add_argument("--num_workers", type=int, default=6) p.add_argument("--amp", default="bf16", choices=["bf16", "fp16", "fp32"]) # FD-Loss knobs p.add_argument("--fd_weight", type=float, default=0.5) p.add_argument("--fd_gate_t", type=float, default=0.5, help="apply FD only when t>=this (low noise)") p.add_argument("--queue_size", type=int, default=512) p.add_argument("--fd_norm_eps", type=float, default=1e-2) p.add_argument("--lpips_weight", type=float, default=0.0) p.add_argument("--dino_weight", type=float, default=0.0) p.add_argument("--percep_gate_t", type=float, default=0.5, help="apply perceptual only when t>=this") p.add_argument("--ref_stats", default="", help="npz of (mu,sigma); auto path + compute if empty") p.add_argument("--ema_decay", type=float, default=0.9999) p.add_argument("--seed", type=int, default=0) p.add_argument("--out_ckpt", required=True) p.add_argument("--log_interval", type=int, default=20) return p.parse_args() def main(): a = get_args() torch.manual_seed(a.seed) device = "cuda" amp_dtype = {"bf16": torch.bfloat16, "fp16": torch.float16}.get(a.amp) # ---- data ---- ds = MaskCondGenDataset(a.data_root, a.dataset, a.protocol, img_size=256, train_fraction=a.train_fraction, fraction_seed=a.fraction_seed) img_ch, n_cls = ds.in_channels, ds.num_classes loader = DataLoader(ds, batch_size=a.batch_size, shuffle=True, drop_last=True, num_workers=a.num_workers, pin_memory=True, persistent_workers=a.num_workers > 0) print(f"[fd] {a.dataset}/{a.protocol} n={len(ds)} in_ch={img_ch} num_classes={n_cls}", flush=True) if img_ch != 3: print("[fd][warn] Inception expects 3ch; non-RGB dataset — FD features may be weak.", flush=True) # ---- model from base ckpt ---- ckpt = torch.load(a.base_ckpt, map_location="cpu", weights_only=False) cond = build_conditioner(ckpt.get("conditioner", "onehot"), n_cls) model = MaskDenoiser(ckpt["model_name"], ckpt["img_size"], ckpt["img_channels"], cond, noise_scale=ckpt.get("noise_scale", 1.0), ema_decay=a.ema_decay, backbone=ckpt.get("backbone", "jit")).to(device) model.load_state_dict(ckpt["state_dict"]) model._ema = [e.to(device) for e in ckpt["ema"]] if ckpt.get("ema") is not None else None if model._ema is None: model.ema_init() print(f"[fd] loaded base {a.base_ckpt}", flush=True) # ---- FD machinery ---- inception = InceptionFeatures().to(device).eval() queue = FeatureQueue(size=a.queue_size, feat_dim=inception.feat_dim).to(device) percep = None if a.lpips_weight > 0 or a.dino_weight > 0: from framework.synth.pixdiff.perceptual import PerceptualLoss percep = PerceptualLoss(use_lpips=a.lpips_weight > 0, use_dino=a.dino_weight > 0, device=device) print(f"[fd] perceptual ON lpips_w={a.lpips_weight} dino_w={a.dino_weight} gate_t={a.percep_gate_t}", flush=True) ref_path = a.ref_stats or a.out_ckpt.replace(".pt", "_refstats.npz") if os.path.isfile(ref_path): rs = np.load(ref_path); mu_ref_np, sigma_ref_np = rs["mu"], rs["sigma"] print(f"[fd] loaded ref stats {ref_path}", flush=True) else: print("[fd] computing reference stats from real train images...", flush=True) ref_loader = DataLoader(MaskCondGenDataset(a.data_root, a.dataset, a.protocol, img_size=256, train_fraction=a.train_fraction, fraction_seed=a.fraction_seed, hflip=False, vflip=False), batch_size=a.batch_size, shuffle=False, num_workers=a.num_workers) mu_ref_np, sigma_ref_np, nref = compute_ref_stats(ref_loader, inception, device) os.makedirs(os.path.dirname(os.path.abspath(ref_path)) or ".", exist_ok=True) np.savez(ref_path, mu=mu_ref_np, sigma=sigma_ref_np) print(f"[fd] ref stats from {nref} imgs -> {ref_path}", flush=True) mu_ref = torch.tensor(mu_ref_np, device=device, dtype=torch.float64) sigma_ref = torch.tensor(sigma_ref_np, device=device, dtype=torch.float64) sigma_ref_sqrt = precompute_sigma_ref_sqrt(sigma_ref) opt = torch.optim.AdamW(model._trainable(), lr=a.lr, weight_decay=0.0) os.makedirs(os.path.dirname(os.path.abspath(a.out_ckpt)) or ".", exist_ok=True) def save(): torch.save({"model_name": ckpt["model_name"], "img_size": ckpt["img_size"], "img_channels": img_ch, "num_classes": n_cls, "conditioner": ckpt.get("conditioner", "onehot"), "noise_scale": ckpt.get("noise_scale", 1.0), "state_dict": model.state_dict(), "ema": model._ema, "args": vars(a)}, a.out_ckpt) print(f"[fd] saved {a.out_ckpt}", flush=True) step = 0 for epoch in range(a.epochs): model.train(); t0 = time.time(); run_fm = run_fd = run_fdraw = 0.0 for batch in loader: img = batch["image"].to(device, non_blocking=True) mask = batch["mask"].to(device, non_blocking=True) opt.zero_grad(set_to_none=True) with torch.autocast("cuda", dtype=amp_dtype) if amp_dtype else _null(): fm_loss, x_pred, t = model(img, mask, return_pred=True) # FD term on predicted clean image, gated to low noise gate = t >= a.fd_gate_t fd_loss = torch.zeros((), device=device); fd_raw = 0.0 if int(gate.sum()) >= 2: xg = x_pred[gate].float() feats = inception((xg.clamp(-1, 1) + 1) / 2) # (Ng,2048), grad flows if queue.is_ready(): mu, sigma = queue.build_feats_stats(feats) fd = compute_frechet_distance_loss(mu_ref, sigma_ref, mu, sigma, sigma_ref_sqrt) fd_raw = float(fd); fd_loss = fd / (fd.detach() + a.fd_norm_eps) queue.enqueue(feats) total = fm_loss + a.fd_weight * fd_loss pl_lpips = pl_dino = 0.0 if percep is not None: pgate = t >= a.percep_gate_t if int(pgate.sum()) >= 1: pld = percep(x_pred[pgate].float(), img[pgate].float()) if "lpips" in pld: total = total + a.lpips_weight * pld["lpips"]; pl_lpips = float(pld["lpips"]) if "dino" in pld: total = total + a.dino_weight * pld["dino"]; pl_dino = float(pld["dino"]) total.backward(); opt.step(); model.ema_update() run_fm += float(fm_loss); run_fd += float(fd_loss); run_fdraw += fd_raw; step += 1 if step % a.log_interval == 0: print(f"[fd] ep{epoch} step{step} fm={float(fm_loss):.4f} fd_raw={fd_raw:.1f} " f"lpips={pl_lpips:.3f} dino={pl_dino:.3f} qready={queue.is_ready()}", flush=True) print(f"[fd] epoch {epoch} fm={run_fm/max(1,len(loader)):.4f} " f"fd_raw={run_fdraw/max(1,len(loader)):.1f} ({time.time()-t0:.1f}s)", flush=True) save() class _null: def __enter__(self): return self def __exit__(self, *a): return False if __name__ == "__main__": main()