#!/usr/bin/env python3
# Source-viewer metadata: file size 8,140 bytes; revision 8cf92b3.
"""Train MapNuRec (ver.neurec) — per-pixel feed-forward 3DGS warm-started from
Depth-Anything-V2, with MapGS map-grounding. Input = N context views -> per-pixel
Gaussians (metric world frame) -> gsplat render to held-out views. Losses: photometric
L1+SSIM, ② map-depth (metric anchor on ground = the MapGS contribution that gives the
feed-forward depth its metric scale), L_vert mono-depth prior. Eval on held-out SCENES
(av2/val). Everything is metric/centered — no scene-scale normalization needed."""
import argparse, time, os, random
import numpy as np
import torch
import torch.nn.functional as F
from gsplat import rasterization
from mapgs.config import load_config
from mapgs.data import UnifiedClipDataset
from mapgs.hdmap.rasterize_map import rasterize_map_depth
from mapgs.losses import mapdepth_loss
from mapgs.eval.metrics import psnr, ssim
from mapnurec import MapNuRec
# Target device for the model and every batch tensor (main passes it to prep/evaluate).
DEV = "cuda"
def prep(s, n_in, device):
    """Build the per-step batch dict from one dataset sample, placed on ``device``.

    The model input is the first ``n_in`` context views (image, intrinsics,
    camera-to-world pose). Every supervision view is kept as a held-out render
    target, together with ``s.ground``. Poses are passed through untouched: the
    dataset already centers them into a local metric frame (no rotation).
    """
    batch = {}
    # Context views: only the first n_in are used as model input.
    for key, attr in (("in_img", "ctx_images"), ("in_K", "ctx_K"), ("in_c2w", "ctx_c2w")):
        batch[key] = getattr(s, attr)[:n_in].to(device)
    # Supervision targets and the ground map are moved over whole.
    for key, attr in (
        ("sup_img", "sup_images"),
        ("sup_K", "sup_K"),
        ("sup_c2w", "sup_c2w"),
        ("ground", "ground"),
    ):
        batch[key] = getattr(s, attr).to(device)
    return batch
def render(g, c2w, K, H, W):
    """Splat the Gaussians ``g`` into the cameras (``c2w``, ``K``) at ``H`` x ``W``.

    Returns ``(rgb, depth)``:
      * rgb: colors clamped to [0, 1] and permuted to [S, 3, H, W];
      * depth: the fourth rendered channel (render_mode "RGB+ED") as [S, H, W].
    """
    # viewmats is the inverse of camera-to-world, i.e. world-to-camera.
    viewmats = torch.inverse(c2w)
    rendered, _alphas, _meta = rasterization(
        means=g["means"],
        quats=g["quats"],
        scales=g["scales"],
        opacities=g["opacities"],
        colors=g["colors"],
        viewmats=viewmats,
        Ks=K,
        width=W,
        height=H,
        near_plane=0.01,
        far_plane=500.0,
        render_mode="RGB+ED",
    )
    color = rendered[..., :3].clamp(0, 1)
    return color.permute(0, 3, 1, 2), rendered[..., 3]
def ssi_disp(pred_depth, mono_disp, mask, min_px=256):
    """Scale/shift-invariant L1 between rendered disparity and a mono disparity.

    Only pixels where ``mask`` holds and ``pred_depth > 1e-3`` are used. Each side
    is turned into disparity space: the prediction becomes ``1 / pred_depth``, and
    ``mono_disp`` is taken as already being a disparity. Each side is then centered
    on its median and divided by its mean absolute deviation from that median,
    before the L1 is taken.

    NOTE(review): the selected pixels of all views are pooled into one
    normalization, not normalized per view. Confirm this is intended when
    ``mono_disp`` comes from a per-image relative-depth model.

    Returns a scalar tensor. When fewer than ``min_px`` pixels are valid it
    returns a zero that stays attached to ``pred_depth``'s autograd graph.
    """
    valid = mask & (pred_depth > 1e-3)
    if int(valid.sum()) < min_px:
        return pred_depth.sum() * 0.0

    def _standardize(x):
        centered = x - x.median()
        return centered / centered.abs().mean().clamp_min(1e-6)

    rendered_disp = pred_depth[valid].reciprocal()
    target_disp = mono_disp[valid]
    return F.l1_loss(_standardize(rendered_disp), _standardize(target_disp))
@torch.no_grad()
def evaluate(model, ds, n, n_in, device):
    """Score held-out rendering on the first ``n`` clips of ``ds``.

    For each clip, the first ``n_in`` context views are fed to ``model``. The
    predicted Gaussians are rendered into the clip's supervision views and compared
    with PSNR/SSIM. Clips whose PSNR is non-finite (NaN/inf) are excluded from all
    statistics.

    The model runs in eval mode and is always switched back to train mode
    afterwards, even if an error is raised. The training loop in this script relies
    on that final ``train()`` call.

    Returns:
        ``(mean PSNR, mean SSIM, population std of PSNR, number of clips counted)``.
        Each statistic is 0.0 when no clip produced a finite PSNR.
    """
    model.eval()
    psnrs, ssims = [], []
    try:
        for i in range(min(n, len(ds))):
            d = prep(ds[i], n_in, device)
            g = model(d["in_img"], d["in_K"], d["in_c2w"])
            rgb, _ = render(g, d["sup_c2w"], d["sup_K"], *d["sup_img"].shape[-2:])
            p = float(psnr(rgb, d["sup_img"]))
            if not np.isfinite(p):
                continue  # a degenerate render would poison the averages
            psnrs.append(p)
            ssims.append(float(ssim(rgb, d["sup_img"])))
    finally:
        model.train()

    count = len(psnrs)
    if count == 0:
        return 0.0, 0.0, 0.0, 0
    mean_psnr = sum(psnrs) / count
    std_psnr = (sum((x - mean_psnr) ** 2 for x in psnrs) / count) ** 0.5
    return mean_psnr, sum(ssims) / count, std_psnr, count
def _parse_args():
    """Parse the command-line options for one training run."""
    ap = argparse.ArgumentParser(description="Train MapNuRec with MapGS map grounding.")
    ap.add_argument("--roots", default="/mnt/william/data/unified/av2,/mnt/william/data/unified/waymo",
                    help="comma-separated training dataset roots")
    ap.add_argument("--val-roots", default="/mnt/william/data/unified/av2",
                    help="comma-separated held-out (val split) dataset roots")
    ap.add_argument("--iters", type=int, default=2000)
    ap.add_argument("--n-in", type=int, default=10, help="number of context views fed to the model")
    ap.add_argument("--height", type=int, default=448)
    ap.add_argument("--width", type=int, default=784)
    ap.add_argument("--lr-head", type=float, default=3e-4, help="LR for every parameter outside model.da")
    ap.add_argument("--lr-da", type=float, default=2e-5,
                    help="LR for the warm-started DA-V2 backbone (model.da); kept gentle")
    ap.add_argument("--wd", type=float, default=0.0)
    ap.add_argument("--lam-md", type=float, default=0.5, help="weight of the map-depth loss")
    ap.add_argument("--lam-vert", type=float, default=0.05, help="weight of the mono-disparity prior")
    ap.add_argument("--vert-ramp", type=int, default=400,
                    help="step from which the mono-disparity prior is switched on (hard start)")
    ap.add_argument("--eval-clips", type=int, default=48)
    ap.add_argument("--eval-every", type=int, default=500)
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--out", default="/mnt/william/runs/mapnurec.safetensors",
                    help="latest-checkpoint path; the best one is written next to it with a _best suffix")
    return ap.parse_args()


def _best_checkpoint_path(out):
    """Return ``<stem>_best<ext>`` for ``out``, defaulting the extension to .safetensors.

    Splitting on the real extension means the best checkpoint can never collide
    with ``out`` itself. Directory names are also never rewritten.
    """
    stem, ext = os.path.splitext(out)
    return f"{stem}_best{ext or '.safetensors'}"


def _make_optimizer(model, args):
    """Build AdamW with two groups: DA-V2 backbone (``model.da``) at lr_da, the rest at lr_head."""
    da_ids = {id(p) for p in model.da.parameters()}
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.AdamW([
        {"params": [p for p in trainable if id(p) in da_ids], "lr": args.lr_da},
        {"params": [p for p in trainable if id(p) not in da_ids], "lr": args.lr_head},
    ], betas=(0.9, 0.95), weight_decay=args.wd)


def main():
    """Train MapNuRec and report held-out-scene PSNR/SSIM before, during and after training.

    Two checkpoints are kept: the latest one at ``--out`` and the best one (by
    held-out PSNR, warm start included) next to it.
    """
    from mapgs.losses import Tempering
    from safetensors.torch import save_file

    args = _parse_args()
    H, W = args.height, args.width
    random.seed(args.seed); np.random.seed(args.seed); torch.manual_seed(args.seed)
    model = MapNuRec().to(DEV)
    cfg = load_config(overrides=["data.name=unified", f"data.root={args.roots}",
                                 f"data.height={H}", f"data.width={W}", "model.tokens.n_map=2048"])
    ds = UnifiedClipDataset(cfg, roots=args.roots.split(","), split="train", n_sup_views=6)
    vds = UnifiedClipDataset(cfg, roots=args.val_roots.split(","), split="val", n_sup_views=6)
    # Sanity check that the "held-out" scenes really are disjoint from training.
    train_names = {os.path.basename(c.rstrip("/")) for c in ds.clips}
    leak = sum(os.path.basename(c.rstrip("/")) in train_names for c in vds.clips)
    print(f"MapNuRec | train {len(ds)} | held-out val scenes {len(vds)} | overlap {leak} | {H}x{W} n_in {args.n_in}", flush=True)
    if leak:
        print(f"WARNING: {leak} val clip(s) share a basename with a training clip; "
              f"held-out numbers are not clean.", flush=True)
    if len(ds) == 0:
        raise SystemExit(f"no training clips found under --roots {args.roots!r}")

    temper = Tempering(cfg.loss, cfg.model.tokens, args.iters)
    opt = _make_optimizer(model, args)

    # Create the output directory up front so a long run cannot die at its first save.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    best_path = _best_checkpoint_path(args.out)

    b_ps, b_ss, b_sd, b_n = evaluate(model, vds, args.eval_clips, args.n_in, DEV)
    print(f"BEFORE (warm-start DA-V2, no train): held-out-SCENE PSNR {b_ps:.2f}±{b_sd:.2f} SSIM {b_ss:.3f} (n={b_n})", flush=True)
    # The warm start is the first "best". Saving it guarantees that best_path
    # exists and matches the reported BEST value even if training never improves.
    best = b_ps
    save_file(model.state_dict(), best_path)

    # Dedicated RNG so that data order is reproducible and does not disturb any
    # global random state the dataset may consume.
    order_rng = random.Random(args.seed)
    order = []
    skipped = 0
    t = time.time()
    for step in range(args.iters):
        if not order:
            # New epoch: visit every training clip once, in a freshly shuffled order,
            # instead of sweeping the concatenated roots one after another.
            order = list(range(len(ds)))
            order_rng.shuffle(order)
        eps = temper.eps(step)
        d = prep(ds[order.pop()], args.n_in, DEV)
        g = model(d["in_img"], d["in_K"], d["in_c2w"])
        rgb, depth = render(g, d["sup_c2w"], d["sup_K"], H, W)
        l_rgb = F.l1_loss(rgb, d["sup_img"]) + 0.1 * (1 - ssim(rgb, d["sup_img"]))
        with torch.no_grad():
            md, mask = rasterize_map_depth(d["ground"], d["sup_K"], d["sup_c2w"], H, W)
        l_md = mapdepth_loss(depth, md, mask, eps, cfg.loss.huber_delta) if mask.any() else depth.sum() * 0
        if step >= args.vert_ramp and args.lam_vert > 0:
            # DA-V2 disparity on the target views is a fixed target. Computing it
            # without autograd avoids building and then discarding a backbone graph.
            with torch.no_grad():
                mono = model.disp(d["sup_img"])
            l_vert = ssi_disp(depth, mono, (~mask) & (depth > 1e-3))
        else:
            l_vert = depth.sum() * 0
        loss = l_rgb + args.lam_md * l_md + args.lam_vert * l_vert

        opt.zero_grad(set_to_none=True)
        stepped = False
        if torch.isfinite(loss):
            loss.backward()
            gn = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            if torch.isfinite(gn):
                opt.step()
                stepped = True
        if not stepped:
            skipped += 1  # non-finite loss or gradient norm: update dropped

        if step % 50 == 0 or step < 4:
            a, b = float(F.softplus(model.aff[0])), float(F.softplus(model.aff[1]))
            print(f"it {step:5d} | loss {float(loss):.4f} rgb {float(l_rgb):.4f} md {float(l_md):.4f} "
                  f"vert {float(l_vert):.4f} | aff({a:.3f},{b:.3f}) G {g['means'].shape[0]//1000}k "
                  f"| skipped {skipped} | {time.time()-t:.0f}s", flush=True)
        if step > 0 and step % args.eval_every == 0:
            e_ps, e_ss, e_sd, e_n = evaluate(model, vds, args.eval_clips, args.n_in, DEV)
            tag = ""
            if e_ps > best:
                best = e_ps; save_file(model.state_dict(), best_path); tag = " *best"
            save_file(model.state_dict(), args.out)  # latest (10h crash recovery)
            print(f"  [eval @ {step}] held-out-SCENE PSNR {e_ps:.2f}±{e_sd:.2f} SSIM {e_ss:.3f} (n={e_n}){tag} | {time.time()-t:.0f}s", flush=True)

    a_ps, a_ss, a_sd, a_n = evaluate(model, vds, args.eval_clips, args.n_in, DEV)
    if a_ps > best:
        best = a_ps; save_file(model.state_dict(), best_path)
    save_file(model.state_dict(), args.out)
    print(f"\nAFTER ({args.iters} it): held-out-SCENE PSNR {a_ps:.2f}±{a_sd:.2f} SSIM {a_ss:.3f} (n={a_n})", flush=True)
    print(f"=> BEFORE {b_ps:.2f} -> AFTER {a_ps:.2f} | BEST {best:.2f} -> {best_path}", flush=True)
if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    main()
|