mapvggt / scripts /train_mapvggt.py
ChenmingWu's picture
Upload folder using huggingface_hub
8cf92b3 verified
Raw
History Blame Contribute Delete
7.25 kB
#!/usr/bin/env python3
"""Train MapVGGT -- per-pixel feed-forward 3DGS warm-started from VGGT-Omega (1B-512),
replacing TokenGS. Input = N context views -> VGGT metric depth -> per-pixel world
Gaussians -> gsplat render to held-out views. Losses: photometric L1+SSIM, MapGS
map-depth (metric ground anchor) + sub-surface free-space + L_vert mono-disp prior.
Train on Waymo; hold out ONE clip (the last) for val. Backbone frozen by default (only the head trains);
pass --finetune-backbone to also fine-tune VGGT at a low learning rate (--lr-vggt)."""
import argparse, time, os, random, copy
import numpy as np
import torch
import torch.nn.functional as F
from mapgs.config import load_config
from mapgs.data import UnifiedClipDataset
from mapgs.hdmap.rasterize_map import rasterize_map_depth
from mapgs.losses import mapdepth_loss
from mapgs.eval.metrics import psnr, ssim
from mapvggt import MapVGGT
from scripts.train_mapnurec import prep, render, ssi_disp
# Device for the model and every prepared batch; hard-coded, so the script requires CUDA.
DEV = "cuda"
@torch.no_grad()
def evaluate(model, ds, n, n_in, device):
    """Score ``model`` on the first ``n`` clips of ``ds`` by rendering their supervision views.

    For each clip, ``prep`` builds ``n_in`` context views, the model predicts
    Gaussians from them, and ``render`` draws those Gaussians into the clip's
    supervision cameras, which are compared against the ground-truth images.
    Clips whose PSNR or SSIM is non-finite (e.g. a degenerate render) are skipped.

    Args:
        model: the MapVGGT model; must expose ``finetune_backbone`` and ``vggt``.
        ds: dataset indexable by clip position and exposing a ``clips`` list.
        n: maximum number of clips to evaluate (capped at ``len(ds.clips)``).
        n_in: number of context views passed to ``prep``.
        device: device the prepared batch is moved to.

    Returns:
        ``(mean_psnr, mean_ssim, psnr_std, n_valid)``. ``psnr_std`` is the
        population standard deviation over valid clips. The means and std
        are 0.0 when no clip is valid.

    The model is switched to eval mode for scoring and always put back into
    train mode afterwards (keeping the VGGT backbone in eval mode when it is
    frozen), even if scoring raises.
    """
    import math

    model.eval()
    ps, ss = [], []
    try:
        for i in range(min(n, len(ds.clips))):
            d = prep(ds[i], n_in, device)
            g = model(d["in_img"], d["in_K"], d["in_c2w"])
            rgb, _ = render(g, d["sup_c2w"], d["sup_K"], *d["sup_img"].shape[-2:])
            p = float(psnr(rgb, d["sup_img"]))
            s = float(ssim(rgb, d["sup_img"]))
            # Drop the clip if EITHER metric is NaN/inf so it cannot poison the averages.
            if math.isfinite(p) and math.isfinite(s):
                ps.append(p)
                ss.append(s)
    finally:
        model.train()
        if not model.finetune_backbone:
            # A frozen backbone stays in eval mode (e.g. for norm/dropout layers) during training.
            model.vggt.eval()
    k = max(1, len(ps))  # ps and ss always have the same length here
    mp = sum(ps) / k
    sd = (sum((x - mp) ** 2 for x in ps) / k) ** 0.5
    return mp, sum(ss) / k, sd, len(ps)
def main():
    """Train MapVGGT's Gaussian head (and optionally the VGGT backbone) on Waymo clips.

    The last clip of the dataset is held out for validation and excluded from
    training. Every ``--eval-every`` iterations (disabled when <= 0) the model is
    scored on that clip. The latest trainable weights are written to ``--out``,
    and the best-PSNR weights to ``<out stem>_best<out ext>``. Only trainable
    tensors are persisted: the head, plus the backbone with ``--finetune-backbone``.

    Raises:
        SystemExit: if fewer than two clips are found (one to train on, one to hold out).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--roots", default="/mnt/william/data/unified/waymo")  # comma-separated dataset roots
    ap.add_argument("--iters", type=int, default=4000)
    ap.add_argument("--n-in", type=int, default=8)
    ap.add_argument("--height", type=int, default=256)  # multiple of 16 (VGGT patch)
    ap.add_argument("--width", type=int, default=448)
    ap.add_argument("--lr-head", type=float, default=3e-4)
    ap.add_argument("--lr-vggt", type=float, default=1e-5)
    ap.add_argument("--finetune-backbone", action="store_true")
    ap.add_argument("--lam-md", type=float, default=0.5)  # ② map-depth metric anchor
    ap.add_argument("--lam-fs", type=float, default=0.1)  # ② sub-surface free-space
    ap.add_argument("--lam-vert", type=float, default=0.05)  # L_vert mono-disp prior
    ap.add_argument("--vert-ramp", type=int, default=400)  # L_vert is off before this step (hard switch)
    ap.add_argument("--eval-every", type=int, default=250)  # <= 0 disables periodic eval
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--out", default="/mnt/william/runs/mapvggt.safetensors")
    args = ap.parse_args()
    H, W = args.height, args.width
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Create the checkpoint directory up front so a bad --out fails now,
    # not at the first save after hundreds of training iterations.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    model = MapVGGT(finetune_backbone=args.finetune_backbone).to(DEV)
    cfg = load_config(overrides=["data.name=unified", f"data.root={args.roots}",
                                 f"data.height={H}", f"data.width={W}", "model.tokens.n_map=2048"])
    full = UnifiedClipDataset(cfg, roots=args.roots.split(","), split="train", n_sup_views=6)
    # One clip is held out, so at least two are needed; otherwise the
    # `step % len(ds.clips)` indexing below would divide by zero.
    if len(full.clips) < 2:
        raise SystemExit(
            f"need at least 2 clips (train + 1 held-out val), found {len(full.clips)} under {args.roots}"
        )
    # hold out ONE clip for val (the last, deterministic), exclude it from training.
    val_clip = full.clips[-1]
    # Shallow copies: all attributes are shared with `full` except the rebound clip list.
    ds = copy.copy(full)
    ds.clips = full.clips[:-1]
    vds = copy.copy(full)
    vds.clips = [val_clip]
    print(f"MapVGGT(VGGT-Omega-1B) | train {len(ds.clips)} | val 1 clip [{os.path.basename(val_clip)}] | "
          f"{H}x{W} n_in {args.n_in} | finetune_backbone={args.finetune_backbone}", flush=True)

    from mapgs.losses import Tempering
    temper = Tempering(cfg.loss, cfg.model.tokens, args.iters)

    # Param groups: every non-VGGT trainable param at --lr-head, plus (only when
    # fine-tuning) the VGGT backbone at --lr-vggt.
    vggt_ids = {id(p) for p in model.vggt.parameters()}
    groups = [{"params": [p for p in model.parameters() if id(p) not in vggt_ids and p.requires_grad],
               "lr": args.lr_head}]
    if args.finetune_backbone:
        groups.append({"params": [p for p in model.vggt.parameters() if p.requires_grad], "lr": args.lr_vggt})
    opt = torch.optim.AdamW(groups, betas=(0.9, 0.95), weight_decay=0.0)
    # Flat list of optimised params for gradient clipping; built once instead of every step.
    train_params = [p for grp in groups for p in grp["params"]]

    b_ps, b_ss, _, _ = evaluate(model, vds, 1, args.n_in, DEV)
    print(f"BEFORE (warm-start VGGT-Omega, head untrained): val PSNR {b_ps:.2f} SSIM {b_ss:.3f}", flush=True)

    from safetensors.torch import save_file
    # splitext (rather than str.replace) keeps best_path distinct from args.out
    # even when --out does not end in ".safetensors".
    stem, ext = os.path.splitext(args.out)
    best_path = f"{stem}_best{ext}"
    best = b_ps  # a checkpoint must beat the warm-start to be saved as best
    best_saved = False

    # only persist trainable tensors (head, + backbone if finetuned) to keep ckpts small
    def trainable_sd():
        if args.finetune_backbone:
            return model.state_dict()
        return {k: v for k, v in model.state_dict().items() if not k.startswith("vggt.")}

    t = time.time()
    for step in range(args.iters):
        eps = temper.eps(step)
        d = prep(ds[step % len(ds.clips)], args.n_in, DEV)
        g = model(d["in_img"], d["in_K"], d["in_c2w"])
        rgb, depth = render(g, d["sup_c2w"], d["sup_K"], H, W)
        l_rgb = F.l1_loss(rgb, d["sup_img"]) + 0.1 * (1 - ssim(rgb, d["sup_img"]))

        with torch.no_grad():
            md, mask = rasterize_map_depth(d["ground"], d["sup_K"], d["sup_c2w"], H, W)
        has_map = bool(mask.any())
        # With no map pixels, use a graph-connected zero so the loss terms stay tensors.
        l_md = mapdepth_loss(depth, md, mask, eps, cfg.loss.huber_delta) if has_map else depth.sum() * 0
        # sub-surface free-space: penalize rendered depth FARTHER than the ground (gaussians below road)
        l_fs = F.relu(depth - md)[mask].mean() if has_map else depth.sum() * 0

        if step >= args.vert_ramp and args.lam_vert > 0:
            # The mono-depth target is a constant; computing it without autograd avoids
            # building (and holding) a backbone graph when the backbone is fine-tuned.
            with torch.no_grad():
                mono = model.vggt_depth(d["sup_img"])[0]
            mono_disp = 1.0 / mono.clamp(min=1e-3)
            # Applied only off the map mask and where the rendered depth is positive.
            l_vert = ssi_disp(depth, mono_disp, (~mask) & (depth > 1e-3))
        else:
            l_vert = depth.sum() * 0

        loss = l_rgb + args.lam_md * l_md + args.lam_fs * l_fs + args.lam_vert * l_vert
        opt.zero_grad(set_to_none=True)
        # Skip the update entirely on a non-finite loss or gradient norm.
        if torch.isfinite(loss):
            loss.backward()
            gn = torch.nn.utils.clip_grad_norm_(train_params, 1.0)
            if torch.isfinite(gn):
                opt.step()

        if step % 50 == 0 or step < 4:
            print(f"it {step:5d} | loss {float(loss):.4f} rgb {float(l_rgb):.4f} md {float(l_md):.4f} "
                  f"fs {float(l_fs):.4f} vert {float(l_vert):.4f} G {g['means'].shape[0]//1000}k | "
                  f"{time.time()-t:.0f}s", flush=True)

        if args.eval_every > 0 and step > 0 and step % args.eval_every == 0:
            e_ps, e_ss, _, _ = evaluate(model, vds, 1, args.n_in, DEV)
            tag = ""
            if e_ps > best:
                best = e_ps
                save_file(trainable_sd(), best_path)
                best_saved = True
                tag = " *best"
            save_file(trainable_sd(), args.out)
            print(f" [eval @ {step}] val PSNR {e_ps:.2f} SSIM {e_ss:.3f}{tag} | {time.time()-t:.0f}s", flush=True)

    a_ps, a_ss, _, _ = evaluate(model, vds, 1, args.n_in, DEV)
    if a_ps > best:
        best = a_ps
        save_file(trainable_sd(), best_path)
        best_saved = True
    save_file(trainable_sd(), args.out)
    print(f"\nAFTER ({args.iters} it): val PSNR {a_ps:.2f} SSIM {a_ss:.3f}", flush=True)
    print(f"=> BEFORE {b_ps:.2f} -> AFTER {a_ps:.2f} | BEST {best:.2f} -> {best_path}", flush=True)
    if not best_saved:
        # Otherwise the line above would point at a file this run never wrote (or a stale one).
        print(f"   (no evaluation beat the warm-start; {best_path} was not written this run)", flush=True)
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()