File size: 8,140 Bytes
8cf92b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#!/usr/bin/env python3
"""Train MapNuRec (ver.neurec) — per-pixel feed-forward 3DGS warm-started from
Depth-Anything-V2, with MapGS map-grounding. Input = N context views -> per-pixel
Gaussians (metric world frame) -> gsplat render to held-out views. Losses: photometric
L1+SSIM, ② map-depth (metric anchor on ground = the MapGS contribution that gives the
feed-forward depth its metric scale), L_vert mono-depth prior. Eval on held-out SCENES
(av2/val). Everything is metric/centered — no scene-scale normalization needed."""
import argparse, time, os, random
import numpy as np
import torch
import torch.nn.functional as F
from gsplat import rasterization

from mapgs.config import load_config
from mapgs.data import UnifiedClipDataset
from mapgs.hdmap.rasterize_map import rasterize_map_depth
from mapgs.losses import mapdepth_loss
from mapgs.eval.metrics import psnr, ssim
from mapnurec import MapNuRec

# Device string handed to `.to(...)` for the model and every per-clip tensor in this script.
# Hard-coded: there is no CLI flag or CPU fallback.
DEV = "cuda"


def prep(s, n_in, device):
    """Turn one dataset sample into a dict of tensors living on ``device``.

    The first ``n_in`` context views (images, intrinsics, camera-to-world poses)
    are the network input; every supervision view is kept as a held-out render
    target. The sample's ``ground`` tensor is passed through as well. Poses are
    used as stored -- per the original author's note the dataset has already
    centred them into a local metric frame without rotating them.

    Returns a dict with keys ``in_img``, ``in_K``, ``in_c2w``, ``sup_img``,
    ``sup_K``, ``sup_c2w`` and ``ground``.
    """
    def head(t):
        # Keep only the leading n_in context views, then move to the target device.
        return t[:n_in].to(device)

    batch = {}
    batch["in_img"] = head(s.ctx_images)
    batch["in_K"] = head(s.ctx_K)
    batch["in_c2w"] = head(s.ctx_c2w)
    batch["sup_img"] = s.sup_images.to(device)
    batch["sup_K"] = s.sup_K.to(device)
    batch["sup_c2w"] = s.sup_c2w.to(device)
    batch["ground"] = s.ground.to(device)
    return batch


def render(g, c2w, K, H, W):
    """Rasterize the Gaussian set ``g`` into the given cameras with gsplat.

    ``g`` is a dict providing ``means``, ``quats``, ``scales``, ``opacities`` and
    ``colors``. ``c2w`` holds camera-to-world poses and is inverted to obtain the
    view matrices gsplat expects; ``K`` holds the intrinsics; ``H``/``W`` give the
    output resolution.

    Returns ``(rgb, depth)``: the first three rendered channels clamped to
    [0, 1] and moved channels-first, and the fourth channel, which
    ``render_mode="RGB+ED"`` appends as depth. The original annotations give the
    shapes as [S,3,H,W] and [S,H,W].
    """
    gaussians = dict(
        means=g["means"],
        quats=g["quats"],
        scales=g["scales"],
        opacities=g["opacities"],
        colors=g["colors"],
    )
    rendered, _alphas, _meta = rasterization(
        **gaussians,
        viewmats=torch.inverse(c2w),
        Ks=K,
        width=W,
        height=H,
        near_plane=0.01,
        far_plane=500.0,
        render_mode="RGB+ED",
    )
    colour = rendered[..., :3].clamp(0, 1)
    rgb = colour.permute(0, 3, 1, 2)  # channels-last -> channels-first
    depth = rendered[..., 3]
    return rgb, depth


def ssi_disp(pred_depth, mono_disp, mask, min_px=256):
    """Shift- and scale-normalised L1 between predicted inverse depth and a disparity map.

    Only pixels that are set in ``mask`` and have ``pred_depth > 1e-3`` take part.
    Both signals are normalised independently (subtract the median, divide by the
    mean absolute deviation from it, floored at 1e-6) before the L1 comparison.

    Args:
        pred_depth: predicted depth tensor.
        mono_disp: reference disparity tensor, indexed with the same mask as
            ``pred_depth``.
        mask: boolean tensor selecting candidate pixels.
        min_px: minimum number of usable pixels; below it the loss is a zero
            that is still derived from ``pred_depth``.

    Returns:
        A scalar tensor.
    """
    usable = mask & (pred_depth > 1e-3)
    if int(usable.sum()) < min_px:
        # Too little support for a stable median / scale estimate.
        return pred_depth.sum() * 0.0

    def _standardize(x):
        centered = x - x.median()
        return centered / centered.abs().mean().clamp_min(1e-6)

    pred_disp = 1.0 / pred_depth[usable]
    ref_disp = mono_disp[usable]
    return F.l1_loss(_standardize(pred_disp), _standardize(ref_disp))


@torch.no_grad()
def evaluate(model, ds, n, n_in, device):
    """Render held-out views for up to ``n`` clips of ``ds`` and aggregate PSNR/SSIM.

    For each clip the first ``n_in`` context views are fed to ``model``, the
    resulting Gaussians are rendered into the clip's supervision cameras at the
    supervision images' own resolution, and PSNR/SSIM are computed against those
    images. A clip whose PSNR is NaN or infinite is dropped from both averages.

    The model is put in eval mode for the duration of the call and afterwards
    returned to whatever mode it was in before, even if a clip raises, so
    evaluation never changes or leaks the caller's train/eval state.

    Args:
        model: callable module taking ``(images, K, c2w)`` and returning the
            Gaussian dict consumed by ``render``.
        ds: indexable dataset of clip samples accepted by ``prep``.
        n: maximum number of clips to evaluate (capped at ``len(ds)``).
        n_in: number of context views given to the model.
        device: device the sample tensors are moved to.

    Returns:
        ``(mean_psnr, mean_ssim, psnr_std, n_counted)``. ``psnr_std`` is the
        population standard deviation. When no clip is counted, all three
        statistics are 0.0 and ``n_counted`` is 0.
    """
    import math

    # Fall back to "training" (the previous unconditional behaviour) for objects
    # that do not expose the nn.Module ``training`` flag.
    was_training = getattr(model, "training", True)
    model.eval()
    psnrs, ssims = [], []
    try:
        for i in range(min(n, len(ds))):
            d = prep(ds[i], n_in, device)
            g = model(d["in_img"], d["in_K"], d["in_c2w"])
            rgb, _ = render(g, d["sup_c2w"], d["sup_K"], *d["sup_img"].shape[-2:])
            p = float(psnr(rgb, d["sup_img"]))
            s = float(ssim(rgb, d["sup_img"]))
            if math.isfinite(p):  # skip NaN / +-inf PSNR clips entirely
                psnrs.append(p)
                ssims.append(s)
    finally:
        model.train(was_training)

    counted = len(psnrs)
    denom = max(1, counted)  # avoids 0/0 when every clip was skipped
    mean_psnr = sum(psnrs) / denom
    psnr_std = (sum((x - mean_psnr) ** 2 for x in psnrs) / denom) ** 0.5
    return mean_psnr, sum(ssims) / denom, psnr_std, counted


def main():
    """Train MapNuRec and report held-out-scene PSNR/SSIM before and after training.

    Parses the CLI arguments, builds the train and validation
    ``UnifiedClipDataset`` instances, evaluates the untrained model as a
    baseline, then runs ``--iters`` optimisation steps (one clip per step) on
    ``l_rgb + lam_md * l_md + lam_vert * l_vert``. Every ``--eval-every`` steps,
    and once more after the loop, the model is evaluated on the validation set;
    the latest weights are written to ``--out`` and the best-PSNR weights to a
    sibling ``_best`` path.
    """
    ap = argparse.ArgumentParser()
    # Comma-separated dataset roots; split on "," below.
    ap.add_argument("--roots", default="/mnt/william/data/unified/av2,/mnt/william/data/unified/waymo")
    ap.add_argument("--val-roots", default="/mnt/william/data/unified/av2")
    ap.add_argument("--iters", type=int, default=2000)
    ap.add_argument("--n-in", type=int, default=10)
    ap.add_argument("--height", type=int, default=448)
    ap.add_argument("--width", type=int, default=784)
    ap.add_argument("--lr-head", type=float, default=3e-4)
    ap.add_argument("--lr-da", type=float, default=2e-5)            # gentle on the warm-started backbone
    ap.add_argument("--wd", type=float, default=0.0)
    # Weights of the map-depth and mono-disparity terms in the total loss.
    ap.add_argument("--lam-md", type=float, default=0.5)
    ap.add_argument("--lam-vert", type=float, default=0.05)
    # First step at which the mono-disparity term is computed.
    # NOTE(review): despite the name this is a hard on/off gate, not a gradual ramp.
    ap.add_argument("--vert-ramp", type=int, default=400)
    ap.add_argument("--eval-clips", type=int, default=48)
    ap.add_argument("--eval-every", type=int, default=500)
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--out", default="/mnt/william/runs/mapnurec.safetensors")
    args = ap.parse_args()
    H, W = args.height, args.width
    # Seed Python, NumPy and PyTorch RNGs from the same value.
    random.seed(args.seed); np.random.seed(args.seed); torch.manual_seed(args.seed)

    model = MapNuRec().to(DEV)
    cfg = load_config(overrides=["data.name=unified", f"data.root={args.roots}",
                                 f"data.height={H}", f"data.width={W}", "model.tokens.n_map=2048"])
    ds = UnifiedClipDataset(cfg, roots=args.roots.split(","), split="train", n_sup_views=6)
    vds = UnifiedClipDataset(cfg, roots=args.val_roots.split(","), split="val", n_sup_views=6)
    # Leak check: count validation clips whose directory basename also appears among
    # the training clips. The count is only printed, never enforced.
    tids = {os.path.basename(c.rstrip("/")) for c in ds.clips}
    leak = sum(os.path.basename(c.rstrip("/")) in tids for c in vds.clips)
    print(f"MapNuRec | train {len(ds)} | held-out val scenes {len(vds)} | overlap {leak} | {H}x{W} n_in {args.n_in}", flush=True)

    from mapgs.losses import Tempering
    # Supplies the per-step `eps` passed to mapdepth_loss below.
    temper = Tempering(cfg.loss, cfg.model.tokens, args.iters)
    # Two optimizer groups: parameters under `model.da` use --lr-da, every other
    # trainable parameter uses --lr-head. Parameters with requires_grad=False are excluded.
    da_ids = {id(p) for p in model.da.parameters()}
    opt = torch.optim.AdamW([
        {"params": [p for p in model.parameters() if id(p) in da_ids and p.requires_grad], "lr": args.lr_da},
        {"params": [p for p in model.parameters() if id(p) not in da_ids and p.requires_grad], "lr": args.lr_head},
    ], betas=(0.9, 0.95), weight_decay=args.wd)

    # Baseline: validation metrics of the model before any optimisation step.
    b_ps, b_ss, b_sd, b_n = evaluate(model, vds, args.eval_clips, args.n_in, DEV)
    print(f"BEFORE (warm-start DA-V2, no train): held-out-SCENE PSNR {b_ps:.2f}±{b_sd:.2f} SSIM {b_ss:.3f} (n={b_n})", flush=True)

    from safetensors.torch import save_file
    # `best` starts at the baseline PSNR, so a "_best" file is only written once
    # training beats the untrained model.
    # NOTE(review): if --out does not contain ".safetensors", best_path equals args.out
    # and the "latest" save below overwrites the best checkpoint.
    best_path = args.out.replace(".safetensors", "_best.safetensors"); best = b_ps
    t = time.time()
    for step in range(args.iters):
        eps = temper.eps(step)
        # One clip per step, cycling through the training set in index order.
        d = prep(ds[step % len(ds)], args.n_in, DEV)
        g = model(d["in_img"], d["in_K"], d["in_c2w"])
        rgb, depth = render(g, d["sup_c2w"], d["sup_K"], H, W)
        # Photometric term: L1 plus 0.1 * (1 - SSIM) against the held-out views.
        l_rgb = F.l1_loss(rgb, d["sup_img"]) + 0.1 * (1 - ssim(rgb, d["sup_img"]))
        # Map depth target and its validity mask are built without gradient tracking.
        with torch.no_grad():
            md, mask = rasterize_map_depth(d["ground"], d["sup_K"], d["sup_c2w"], H, W)
        # Skip the map-depth term when the mask is empty; the zero stays attached to `depth`.
        l_md = mapdepth_loss(depth, md, mask, eps, cfg.loss.huber_delta) if mask.any() else depth.sum() * 0
        if step >= args.vert_ramp and args.lam_vert > 0:
            mono = model.disp(d["sup_img"]).detach()                 # DA-V2 disparity on the target view
            # Applied only where the map mask is NOT set and rendered depth is positive.
            l_vert = ssi_disp(depth, mono, (~mask) & (depth > 1e-3))
        else:
            l_vert = depth.sum() * 0
        loss = l_rgb + args.lam_md * l_md + args.lam_vert * l_vert
        opt.zero_grad(set_to_none=True)
        # Guard against NaN/inf: skip backward on a non-finite loss, and skip the
        # optimizer step when the (clipped, max-norm 1.0) gradient norm is non-finite.
        if torch.isfinite(loss):
            loss.backward()
            gn = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            if torch.isfinite(gn):
                opt.step()
        # Log the first four steps and then every 50th.
        if step % 50 == 0 or step < 4:
            # softplus of the two entries of model.aff, printed as "aff(a,b)".
            a, b = float(F.softplus(model.aff[0])), float(F.softplus(model.aff[1]))
            # "G" is the size of the first dimension of g['means'], in thousands.
            print(f"it {step:5d} | loss {float(loss):.4f} rgb {float(l_rgb):.4f} md {float(l_md):.4f} "
                  f"vert {float(l_vert):.4f} | aff({a:.3f},{b:.3f}) G {g['means'].shape[0]//1000}k | {time.time()-t:.0f}s", flush=True)
        # Periodic validation (never at step 0, which the baseline already covers).
        if step > 0 and step % args.eval_every == 0:
            e_ps, e_ss, e_sd, e_n = evaluate(model, vds, args.eval_clips, args.n_in, DEV)
            tag = ""
            if e_ps > best:
                best = e_ps; save_file(model.state_dict(), best_path); tag = " *best"
            save_file(model.state_dict(), args.out)                  # latest (10h crash recovery)
            print(f"  [eval @ {step}] held-out-SCENE PSNR {e_ps:.2f}±{e_sd:.2f} SSIM {e_ss:.3f} (n={e_n}){tag} | {time.time()-t:.0f}s", flush=True)

    # Final evaluation after the last step; may still update the best checkpoint.
    a_ps, a_ss, a_sd, a_n = evaluate(model, vds, args.eval_clips, args.n_in, DEV)
    if a_ps > best:
        best = a_ps; save_file(model.state_dict(), best_path)
    save_file(model.state_dict(), args.out)
    print(f"\nAFTER ({args.iters} it): held-out-SCENE PSNR {a_ps:.2f}±{a_sd:.2f} SSIM {a_ss:.3f} (n={a_n})", flush=True)
    print(f"=> BEFORE {b_ps:.2f} -> AFTER {a_ps:.2f} | BEST {best:.2f} -> {best_path}", flush=True)


# Run the training entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()