"""Train MapNuRec (ver.neurec) — per-pixel feed-forward 3DGS warm-started from
Depth-Anything-V2, with MapGS map-grounding. Input = N context views -> per-pixel
Gaussians (metric world frame) -> gsplat render to held-out views. Losses: photometric
L1+SSIM, ② map-depth (metric anchor on ground = the MapGS contribution that gives the
feed-forward depth its metric scale), L_vert mono-depth prior. Eval on held-out SCENES
(av2/val). Everything is metric/centered — no scene-scale normalization needed."""
import argparse
import math
import os
import random
import time

import numpy as np
import torch
import torch.nn.functional as F
from gsplat import rasterization

from mapgs.config import load_config
from mapgs.data import UnifiedClipDataset
from mapgs.hdmap.rasterize_map import rasterize_map_depth
from mapgs.losses import mapdepth_loss
from mapgs.eval.metrics import psnr, ssim
from mapnurec import MapNuRec
|
|
# Device string used for the model, every training batch and every evaluation
# batch in this script. Hard-coded: there is no CPU fallback.
DEV = "cuda"
|
|
|
|
def prep(s, n_in, device):
    """Move one dataset sample to `device`, split into model inputs and render targets.

    The first `n_in` context views (images, intrinsics, camera-to-world poses) are the
    model input; all supervision views are the held-out render targets. Poses are used
    exactly as stored on the sample: per the original note, the dataset already centers
    them to a local metric frame (no rotation).

    Args:
        s: Sample exposing `ctx_images`, `ctx_K`, `ctx_c2w`, `sup_images`, `sup_K`,
            `sup_c2w` and `ground` tensors.
        n_in: Number of leading context views to keep as input.
        device: Target device for every returned tensor.

    Returns:
        Dict with keys `in_img`, `in_K`, `in_c2w`, `sup_img`, `sup_K`, `sup_c2w`,
        `ground`, in that order.
    """
    context = {
        "in_img": s.ctx_images[:n_in],
        "in_K": s.ctx_K[:n_in],
        "in_c2w": s.ctx_c2w[:n_in],
    }
    targets = {
        "sup_img": s.sup_images,
        "sup_K": s.sup_K,
        "sup_c2w": s.sup_c2w,
        "ground": s.ground,
    }
    return {key: tensor.to(device) for key, tensor in {**context, **targets}.items()}
|
|
|
|
def render(g, c2w, K, H, W):
    """Rasterize the Gaussians in `g` into the cameras described by `c2w` and `K`.

    Args:
        g: Dict with the Gaussian parameters under `means`, `quats`, `scales`,
            `opacities` and `colors`.
        c2w: Camera-to-world matrices; they are inverted to obtain the view matrices
            handed to gsplat.
        K: Camera intrinsics, passed through as `Ks`.
        H: Output image height in pixels.
        W: Output image width in pixels.

    Returns:
        `(rgb, depth)`: `rgb` is the first three rendered channels clamped to [0, 1]
        and permuted to channel-first order; `depth` is the fourth rendered channel
        (render mode "RGB+ED" — presumably gsplat's expected depth; confirm against
        the gsplat docs).
    """
    splat_args = dict(
        means=g["means"],
        quats=g["quats"],
        scales=g["scales"],
        opacities=g["opacities"],
        colors=g["colors"],
        viewmats=torch.inverse(c2w),
        Ks=K,
        width=W,
        height=H,
        near_plane=0.01,
        far_plane=500.0,
        render_mode="RGB+ED",
    )
    # Only the rendered image is used; the other two return values are discarded.
    rendered, _, _ = rasterization(**splat_args)
    color = rendered[..., :3]
    depth = rendered[..., 3]
    rgb = color.clamp(0, 1).permute(0, 3, 1, 2)
    return rgb, depth
|
|
|
|
def ssi_disp(pred_depth, mono_disp, mask, min_px=256):
    """Scale-and-shift-invariant L1 loss between rendered and reference disparity.

    The rendered depth is inverted to disparity. Both disparities are then normalised
    by subtracting their median and dividing by the mean absolute deviation from that
    median, so any positive affine transform of either input leaves the loss unchanged.

    Args:
        pred_depth: Rendered depth tensor; only pixels with depth > 1e-3 are used.
        mono_disp: Reference disparity, indexable by the same boolean mask as
            `pred_depth` (the caller passes a monocular prediction).
        mask: Boolean tensor selecting the pixels to compare.
        min_px: Minimum number of valid pixels required to compute the loss.

    Returns:
        Scalar L1 loss between the two normalised disparities, or a zero that is still
        attached to `pred_depth`'s graph when fewer than `min_px` pixels are valid.
    """

    def _normalise(x):
        # Compute the median once and reuse the centred values for the scale.
        centred = x - x.median()
        return centred / centred.abs().mean().clamp_min(1e-6)

    valid = mask & (pred_depth > 1e-3)
    if int(valid.sum()) < min_px:
        # `sum() * 0.0` rather than a fresh constant keeps the result on the graph.
        return pred_depth.sum() * 0.0
    pred_disp = 1.0 / pred_depth[valid]
    return F.l1_loss(_normalise(pred_disp), _normalise(mono_disp[valid]))
|
|
|
|
@torch.no_grad()
def evaluate(model, ds, n, n_in, device):
    """Render held-out views for the first `n` clips of `ds` and average PSNR / SSIM.

    The model is put in eval mode for the duration of the call and its previous
    train/eval mode is restored afterwards, even if evaluation raises.

    Args:
        model: Network called as `model(in_img, in_K, in_c2w)`; its output is passed
            to `render`.
        ds: Indexable dataset with `len()`; items are passed to `prep`.
        n: Maximum number of clips to evaluate (capped at `len(ds)`).
        n_in: Number of context views given to the model per clip.
        device: Device the batches are moved to.

    Returns:
        `(mean_psnr, mean_ssim, psnr_std, n_valid)`. Clips whose PSNR is NaN or
        infinite are skipped; `psnr_std` is the population standard deviation over
        the kept clips. All statistics are 0.0 when no clip is kept.
    """
    was_training = model.training
    model.eval()
    psnrs, ssims = [], []
    try:
        for i in range(min(n, len(ds))):
            d = prep(ds[i], n_in, device)
            g = model(d["in_img"], d["in_K"], d["in_c2w"])
            rgb, _ = render(g, d["sup_c2w"], d["sup_K"], *d["sup_img"].shape[-2:])
            p = float(psnr(rgb, d["sup_img"]))
            s = float(ssim(rgb, d["sup_img"]))
            if math.isfinite(p):
                psnrs.append(p)
                ssims.append(s)
    finally:
        # Restore the caller's mode instead of unconditionally forcing train mode.
        model.train(was_training)
    count = max(1, len(psnrs))  # avoids division by zero when nothing was kept
    mean_psnr = sum(psnrs) / count
    psnr_std = (sum((x - mean_psnr) ** 2 for x in psnrs) / count) ** 0.5
    mean_ssim = sum(ssims) / max(1, len(ssims))
    return mean_psnr, mean_ssim, psnr_std, len(psnrs)
|
|
|
|
def _parse_args():
    """Parse the command-line options of the training script."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--roots", default="/mnt/william/data/unified/av2,/mnt/william/data/unified/waymo")
    ap.add_argument("--val-roots", default="/mnt/william/data/unified/av2")
    ap.add_argument("--iters", type=int, default=2000)
    ap.add_argument("--n-in", type=int, default=10)
    ap.add_argument("--height", type=int, default=448)
    ap.add_argument("--width", type=int, default=784)
    ap.add_argument("--lr-head", type=float, default=3e-4)
    ap.add_argument("--lr-da", type=float, default=2e-5)
    ap.add_argument("--wd", type=float, default=0.0)
    ap.add_argument("--lam-md", type=float, default=0.5)
    ap.add_argument("--lam-vert", type=float, default=0.05)
    ap.add_argument("--vert-ramp", type=int, default=400)
    ap.add_argument("--eval-clips", type=int, default=48)
    ap.add_argument("--eval-every", type=int, default=500)
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--out", default="/mnt/william/runs/mapnurec.safetensors")
    return ap.parse_args()


def _best_checkpoint_path(out):
    """Return the best-PSNR checkpoint path that sits next to `out`.

    `_best` is inserted before the file extension, so the result always differs from
    `out` (a plain suffix `str.replace` returns `out` unchanged when the suffix is
    absent, making the "best" and "last" checkpoints overwrite each other).
    """
    stem, ext = os.path.splitext(out)
    return f"{stem}_best{ext}"


def main():
    """Train MapNuRec and report held-out PSNR / SSIM before, during and after training.

    Steps:
      1. Parse the CLI, seed the RNGs, build the model and the train / val datasets,
         and print how many val clips share a directory basename with a train clip.
      2. Build AdamW with two parameter groups: `model.da` parameters at `--lr-da`,
         all remaining trainable parameters at `--lr-head`.
      3. Evaluate once before training; that PSNR is the initial "best".
      4. For `--iters` steps: render the supervision views from the predicted
         Gaussians and optimise photometric L1 + 0.1 * (1 - SSIM), plus
         `--lam-md` * map-depth loss, plus `--lam-vert` * scale/shift-invariant
         disparity loss (enabled from step `--vert-ramp` on, on pixels outside the
         map-depth mask).
      5. Every `--eval-every` steps (0 disables periodic evaluation) and once after
         the loop, evaluate, write the latest weights to `--out` and, when PSNR
         improved, also to the `_best` sibling file.

    Raises:
        ValueError: if training was requested but the training dataset is empty.
    """
    # Imported here, as in the original script, so importing this module does not
    # require these packages.
    from mapgs.losses import Tempering
    from safetensors.torch import save_file

    args = _parse_args()
    H, W = args.height, args.width
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    best_path = _best_checkpoint_path(args.out)
    # Create the checkpoint directory up front so a missing directory fails now
    # rather than at the first checkpoint write, hundreds of iterations in.
    os.makedirs(os.path.dirname(os.path.abspath(args.out)), exist_ok=True)

    model = MapNuRec().to(DEV)
    cfg = load_config(overrides=["data.name=unified", f"data.root={args.roots}",
                                 f"data.height={H}", f"data.width={W}", "model.tokens.n_map=2048"])
    ds = UnifiedClipDataset(cfg, roots=args.roots.split(","), split="train", n_sup_views=6)
    vds = UnifiedClipDataset(cfg, roots=args.val_roots.split(","), split="val", n_sup_views=6)
    # Leak check: count val clips whose directory basename also occurs in train.
    tids = {os.path.basename(c.rstrip("/")) for c in ds.clips}
    leak = sum(os.path.basename(c.rstrip("/")) in tids for c in vds.clips)
    print(f"MapNuRec | train {len(ds)} | held-out val scenes {len(vds)} | overlap {leak} | {H}x{W} n_in {args.n_in}", flush=True)
    if args.iters > 0 and len(ds) == 0:
        # Otherwise `step % len(ds)` below dies with an opaque ZeroDivisionError.
        raise ValueError(f"no training clips found under roots: {args.roots}")

    temper = Tempering(cfg.loss, cfg.model.tokens, args.iters)
    da_ids = {id(p) for p in model.da.parameters()}
    opt = torch.optim.AdamW([
        {"params": [p for p in model.parameters() if id(p) in da_ids and p.requires_grad], "lr": args.lr_da},
        {"params": [p for p in model.parameters() if id(p) not in da_ids and p.requires_grad], "lr": args.lr_head},
    ], betas=(0.9, 0.95), weight_decay=args.wd)

    b_ps, b_ss, b_sd, b_n = evaluate(model, vds, args.eval_clips, args.n_in, DEV)
    print(f"BEFORE (warm-start DA-V2, no train): held-out-SCENE PSNR {b_ps:.2f}±{b_sd:.2f} SSIM {b_ss:.3f} (n={b_n})", flush=True)

    best = b_ps
    t = time.time()
    for step in range(args.iters):
        eps = temper.eps(step)  # step-dependent value forwarded to mapdepth_loss
        d = prep(ds[step % len(ds)], args.n_in, DEV)  # cycle through clips in order
        g = model(d["in_img"], d["in_K"], d["in_c2w"])
        rgb, depth = render(g, d["sup_c2w"], d["sup_K"], H, W)
        l_rgb = F.l1_loss(rgb, d["sup_img"]) + 0.1 * (1 - ssim(rgb, d["sup_img"]))
        with torch.no_grad():
            # The map-depth target is a fixed reference; no gradient flows into it.
            md, mask = rasterize_map_depth(d["ground"], d["sup_K"], d["sup_c2w"], H, W)
        # `depth.sum() * 0` keeps a graph-attached zero when a term is inactive.
        l_md = mapdepth_loss(depth, md, mask, eps, cfg.loss.huber_delta) if mask.any() else depth.sum() * 0
        if step >= args.vert_ramp and args.lam_vert > 0:
            with torch.no_grad():
                # The disparity prior is a constant target: skip building its graph.
                mono = model.disp(d["sup_img"])
            l_vert = ssi_disp(depth, mono, (~mask) & (depth > 1e-3))
        else:
            l_vert = depth.sum() * 0
        loss = l_rgb + args.lam_md * l_md + args.lam_vert * l_vert
        opt.zero_grad(set_to_none=True)
        if torch.isfinite(loss):
            # Skip the update entirely when the loss or gradient norm is non-finite.
            loss.backward()
            gn = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            if torch.isfinite(gn):
                opt.step()
        if step % 50 == 0 or step < 4:
            a, b = float(F.softplus(model.aff[0])), float(F.softplus(model.aff[1]))
            print(f"it {step:5d} | loss {float(loss):.4f} rgb {float(l_rgb):.4f} md {float(l_md):.4f} "
                  f"vert {float(l_vert):.4f} | aff({a:.3f},{b:.3f}) G {g['means'].shape[0]//1000}k | {time.time()-t:.0f}s", flush=True)
        if step > 0 and args.eval_every > 0 and step % args.eval_every == 0:
            e_ps, e_ss, e_sd, e_n = evaluate(model, vds, args.eval_clips, args.n_in, DEV)
            tag = ""
            if e_ps > best:
                best = e_ps
                save_file(model.state_dict(), best_path)
                tag = " *best"
            save_file(model.state_dict(), args.out)
            print(f" [eval @ {step}] held-out-SCENE PSNR {e_ps:.2f}±{e_sd:.2f} SSIM {e_ss:.3f} (n={e_n}){tag} | {time.time()-t:.0f}s", flush=True)

    a_ps, a_ss, a_sd, a_n = evaluate(model, vds, args.eval_clips, args.n_in, DEV)
    if a_ps > best:
        best = a_ps
        save_file(model.state_dict(), best_path)
    save_file(model.state_dict(), args.out)
    print(f"\nAFTER ({args.iters} it): held-out-SCENE PSNR {a_ps:.2f}±{a_sd:.2f} SSIM {a_ss:.3f} (n={a_n})", flush=True)
    print(f"=> BEFORE {b_ps:.2f} -> AFTER {a_ps:.2f} | BEST {best:.2f} -> {best_path}", flush=True)
|
|
|
|
# Run the training entry point only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
|
|