#!/usr/bin/env python3
"""Measure how much DEPTH/GEOMETRY error limits MapVGGT novel-view PSNR.

Loads the finetuned + base checkpoints, runs on the 40 scene-disjoint Waymo
val clips, and computes concrete numbers for the 5 questions in the brief:

  Q1   self-recon (render INPUT views) vs novel-view PSNR on the easiest clips
  Q2   VGGT depth abs-error vs HD-map ground depth on the input views
  Q3   pixel displacement induced by that depth error when reprojecting
       input ground pixels into the nearest supervision view
  Q3b  per-clip correlation of novel PSNR with target-context distance and
       with depth error
  Q4   finetuned-backbone vs frozen VGGT depth error
"""
import os
import sys
import copy
import math
import faulthandler

faulthandler.enable()

import numpy as np
import torch
import torch.nn.functional as F
from gsplat import rasterization
from safetensors.torch import load_file

sys.path.insert(0, "/mnt/william/_vggt_omega_repo")
from mapgs.config import load_config
from mapgs.data import UnifiedClipDataset
from mapgs.hdmap.rasterize_map import rasterize_map_depth
from mapgs.eval.metrics import psnr, ssim
from mapvggt import MapVGGT
from mapvggt.heads import cat_gaussians
from mapnurec.model import lift_to_world

DEV = "cuda"
H, W, N_IN = 256, 448, 8
ROOT = "/mnt/william/data/unified/waymo"


def _med(xs):
    """Median of ``xs`` as a float, or NaN (without a numpy warning) if empty."""
    return float(np.median(xs)) if len(xs) else float("nan")


def build_val(cfg, n_val=40):
    """Return a shallow copy of the Waymo dataset holding one clip per val segment.

    A segment id is the first two ``_``-separated tokens of a clip's basename.
    The first ``n_val`` segment ids (sorted) form the val set, and the first clip
    encountered for each such segment is kept.
    """
    full = UnifiedClipDataset(cfg, roots=[ROOT], split="train", n_sup_views=6)

    def segid(p):
        return "_".join(os.path.basename(p.rstrip("/")).split("_")[:2])

    segs = sorted({segid(c) for c in full.clips})
    val_segs = set(segs[:n_val])
    seen, vclips = set(), []
    for c in full.clips:
        sgi = segid(c)
        if sgi in val_segs and sgi not in seen:
            seen.add(sgi)
            vclips.append(c)
    # Shallow copy: shares everything with `full` except the clip list.
    vds = copy.copy(full)
    vds.clips = vclips
    return vds


def prep(s, device):
    """Move the fields of one dataset sample that this script uses onto ``device``.

    Context tensors are truncated to the first ``N_IN`` views; anchor tensors
    get a leading batch dim of 1.
    """
    return dict(
        in_img=s.ctx_images[:N_IN].to(device),
        in_K=s.ctx_K[:N_IN].to(device),
        in_c2w=s.ctx_c2w[:N_IN].to(device),
        sup_img=s.sup_images.to(device),
        sup_K=s.sup_K.to(device),
        sup_c2w=s.sup_c2w.to(device),
        ground=s.ground.to(device),
        ap=s.anchor_pos[None].to(device),
        at=s.anchor_type[None].to(device),
        an=s.anchor_normal[None].to(device),
    )


def render_static(gsm, K, c2w):
    """Render gsm (no dynamics) into the given views, one view at a time.

    Returns rgb [S,3,H,W] clamped to [0, 1], and expected depth [S,H,W].
    """
    rgbs, deps = [], []
    for i in range(c2w.shape[0]):
        out, _, _ = rasterization(
            means=gsm["means"], quats=gsm["quats"], scales=gsm["scales"],
            opacities=gsm["opacities"], colors=gsm["colors"],
            viewmats=torch.inverse(c2w[i:i + 1]), Ks=K[i:i + 1],
            width=W, height=H, near_plane=0.01, far_plane=500.0,
            render_mode="RGB+ED")
        rgbs.append(out[0, ..., :3].clamp(0, 1).permute(2, 0, 1))
        deps.append(out[0, ..., 3])
    return torch.stack(rgbs), torch.stack(deps)


def cam_centers(c2w):
    """Camera centers (translation column) of a batch of c2w matrices."""
    return c2w[:, :3, 3]


def target_context_dist(d):
    """Mean over sup views of min distance to any context camera center."""
    cc = cam_centers(d["in_c2w"])   # [Vc,3]
    sc = cam_centers(d["sup_c2w"])  # [Vs,3]
    dist = torch.cdist(sc, cc).min(dim=1).values
    return float(dist.mean())


@torch.no_grad()
def per_clip_depth_error(model, d):
    """Per-pixel VGGT metric depth vs map ground depth, on ground pixels of the INPUT views.

    Returns a dict of abs-error stats (m) overall and per GT-depth bin
    (near <20 / mid 20-40 / far 40-200), plus ground-pixel coverage.
    Returns None when the map has no ground pixels in these views.
    """
    z, _ = model.vggt_depth(d["in_img"])  # [Vc,H,W] metric, clamped [1.5,120]
    md, mask = rasterize_map_depth(d["ground"], d["in_K"], d["in_c2w"], H, W)  # [Vc,H,W]
    if not mask.any():
        return None
    err = (z - md).abs()[mask]
    gt = md[mask]
    out = {}
    out["n_ground_px"] = int(mask.sum())
    out["frac_ground"] = float(mask.float().mean())
    out["med_all"] = float(err.median())
    out["mean_all"] = float(err.mean())
    out["med_rel"] = float((err / gt.clamp_min(1e-3)).median())
    for name, lo, hi in [("near", 0, 20), ("mid", 20, 40), ("far", 40, 200)]:
        m2 = (gt >= lo) & (gt < hi)
        if m2.any():
            out[f"med_{name}"] = float(err[m2].median())
            out[f"mean_{name}"] = float(err[m2].mean())
            out[f"n_{name}"] = int(m2.sum())
            out[f"gtmean_{name}"] = float(gt[m2].mean())
    return out


def _project(xyz, w2c, K):
    """Project world points into one pinhole camera.

    ``xyz`` is flattened to [N,3]. ``w2c`` is used via its top-left 3x3 rotation
    and translation column, and ``K`` as a 3x3 intrinsics matrix.
    Returns (uv [N,2] pixel coords, zc [N] camera-frame depth).
    """
    p = xyz.reshape(-1, 3)
    pc = (w2c[:3, :3] @ p.T).T + w2c[:3, 3]
    uv = (K @ pc.T).T
    z_ = uv[:, 2].clamp_min(1e-4)  # avoid divide-by-zero; points behind are filtered by caller
    return torch.stack([uv[:, 0] / z_, uv[:, 1] / z_], -1), pc[:, 2]


def _q1_self_vs_novel(mf, samples, n_easy=8):
    """Q1: novel-view vs self-reconstruction PSNR/SSIM on the ``n_easy`` easiest clips."""
    print("\n==== Q1: self-recon (render INPUT views) vs novel-view PSNR ====", flush=True)
    print("idx dist novelPSNR selfPSNR novelSSIM selfSSIM", flush=True)
    easy_rows = []
    for (i, d, dist) in samples[:n_easy]:
        gsm = mf(d["in_img"], d["in_K"], d["in_c2w"], d["ap"], d["at"], d["an"])
        # novel: render the (static) gsm into the sup views
        rgb_n, _ = render_static(gsm, d["sup_K"], d["sup_c2w"])
        p_n = float(psnr(rgb_n, d["sup_img"]))
        s_n = float(ssim(rgb_n, d["sup_img"]))
        # self-recon: render gsm back into the INPUT views it was built from
        rgb_s, _ = render_static(gsm, d["in_K"], d["in_c2w"])
        p_s = float(psnr(rgb_s, d["in_img"]))
        s_s = float(ssim(rgb_s, d["in_img"]))
        easy_rows.append((i, dist, p_n, p_s, s_n, s_s))
        print(f"{i:3d} {dist:6.3f} {p_n:7.2f} {p_s:7.2f} {s_n:.3f} {s_s:.3f}", flush=True)
    if not easy_rows:
        return
    en = np.array([r[2] for r in easy_rows])
    es = np.array([r[3] for r in easy_rows])
    print(f"EASY mean: novel {en.mean():.2f} self-recon {es.mean():.2f} "
          f"gap {es.mean() - en.mean():.2f} dB", flush=True)


def _q2_depth_error(mf, samples):
    """Q2: aggregate VGGT-vs-map ground depth error over all clips.

    Returns the per-clip stats list aligned with ``samples`` (None where a clip
    has no ground pixels) so later sections can reuse it instead of re-running
    the model.
    """
    print("\n==== Q2: VGGT depth abs-error vs HD-map ground depth (m), all val clips ====",
          flush=True)
    agg = {k: [] for k in ["med_all", "mean_all", "med_rel", "med_near", "mean_near",
                           "med_mid", "mean_mid", "med_far", "mean_far"]}
    npx = {"near": 0, "mid": 0, "far": 0}
    gtm = {"near": [], "mid": [], "far": []}
    per_clip = []
    for (i, d, dist) in samples:
        st = per_clip_depth_error(mf, d)
        per_clip.append(st)
        if st is None:
            continue
        for k in agg:
            if k in st:
                agg[k].append(st[k])
        for b in ["near", "mid", "far"]:
            npx[b] += st.get(f"n_{b}", 0)
            if f"gtmean_{b}" in st:
                gtm[b].append(st[f"gtmean_{b}"])

    def report(label, key):
        v = np.array(agg[key])
        if len(v):
            print(f" {label:18s} median-of-clip {np.median(v):6.2f} mean {v.mean():6.2f} m",
                  flush=True)

    report("abs err (all gnd)", "med_all")
    report("mean abs err all", "mean_all")
    if agg["med_rel"]:
        print(f" relative err (all) median-of-clip {np.median(agg['med_rel']) * 100:.1f}%",
              flush=True)
    report("near <20m (med)", "med_near")
    report("near <20m (mean)", "mean_near")
    report("mid 20-40m (med)", "med_mid")
    report("mid 20-40m (mean)", "mean_mid")
    report("far >40m (med)", "med_far")
    report("far >40m (mean)", "mean_far")
    for b in ["near", "mid", "far"]:
        if gtm[b]:
            print(f" [{b}] ground px total {npx[b]} mean-GT-depth {np.mean(gtm[b]):.1f}m",
                  flush=True)
    return per_clip


def _q3_reproj_displacement(mf, samples):
    """Q3: pixel shift caused by depth error when reprojecting into a sup view."""
    print("\n==== Q3: gaussian reprojection displacement from depth error (px) ====", flush=True)
    # For each val clip: lift INPUT pixels once with predicted VGGT depth and once
    # with HD-map ground depth, reproject both into the sup view nearest the input
    # rig, and measure the pixel shift on ground pixels.
    # NOTE(review): projections falling outside the target image are not excluded.
    disp_meds, disp_p90s = [], []
    for (i, d, dist) in samples:
        z, _ = mf.vggt_depth(d["in_img"])
        md, mask = rasterize_map_depth(d["ground"], d["in_K"], d["in_c2w"], H, W)
        if not mask.any():
            continue
        # world points from predicted depth and from map (true ground) depth
        xyz_pred = lift_to_world(z, d["in_K"], d["in_c2w"])  # [Vc,H,W,3]
        xyz_true = lift_to_world(md, d["in_K"], d["in_c2w"])
        # choose the single sup view closest to the input rig centroid
        cc_in = cam_centers(d["in_c2w"])
        cc_sup = cam_centers(d["sup_c2w"])
        tgt = int(torch.cdist(cc_in.mean(0, keepdim=True), cc_sup).argmin())
        w2c = torch.inverse(d["sup_c2w"][tgt])
        K_t = d["sup_K"][tgt]
        uv_p, zc_p = _project(xyz_pred, w2c, K_t)
        uv_t, zc_t = _project(xyz_true, w2c, K_t)
        # keep ground pixels that land in front of the target camera in both lifts
        m = mask.reshape(-1) & (zc_t > 0.1) & (zc_p > 0.1)
        if not m.any():
            continue
        shift = (uv_p[m] - uv_t[m]).norm(dim=-1)
        disp_meds.append(float(shift.median()))
        disp_p90s.append(float(shift.quantile(0.9)))
    print(" reproj displacement on ground px (into nearest sup view):", flush=True)
    if not disp_meds:
        print("  no clip had valid ground pixels in front of the target view", flush=True)
        return
    print(f" median-of-clip median shift {_med(disp_meds):.2f} px "
          f"mean {np.mean(disp_meds):.2f} px", flush=True)
    print(f" median-of-clip p90 shift {_med(disp_p90s):.2f} px", flush=True)


def _q3b_correlations(mf, samples, depth_stats):
    """Q3b: correlate per-clip novel PSNR with target-context dist and depth error.

    ``depth_stats`` is the per-clip list returned by :func:`_q2_depth_error`,
    aligned with ``samples``.
    """
    print("\n==== Q3b: per-clip novelPSNR vs dist and vs depth-displacement ====", flush=True)
    rows = []
    for (i, d, dist), st in zip(samples, depth_stats):
        gsm = mf(d["in_img"], d["in_K"], d["in_c2w"], d["ap"], d["at"], d["an"])
        rgb_n, _ = render_static(gsm, d["sup_K"], d["sup_c2w"])
        p_n = float(psnr(rgb_n, d["sup_img"]))
        de = st["med_all"] if st else float("nan")
        rows.append((i, dist, p_n, de))
    # columns: 0 = target-context dist, 1 = novel PSNR, 2 = median depth error
    arr = np.array([(r[1], r[2], r[3]) for r in rows], dtype=float).reshape(-1, 3)
    valid = ~np.isnan(arr[:, 2])
    if not valid.any():
        print(" no clip has ground pixels; skipping correlations", flush=True)
        return

    def corr(a, b):
        a, b = a[valid], b[valid]
        m = ~np.isnan(a) & ~np.isnan(b)
        if m.sum() < 2:  # Pearson r is undefined for fewer than 2 points
            return float("nan")
        return float(np.corrcoef(a[m], b[m])[0, 1])

    print(f" PSNR vs target-context dist r = {corr(arr[:, 0], arr[:, 1]):.3f}", flush=True)
    print(f" PSNR vs median depth-error r = {corr(arr[:, 2], arr[:, 1]):.3f}", flush=True)
    print(f" dist vs median depth-error r = {corr(arr[:, 0], arr[:, 2]):.3f}", flush=True)
    # split easy/hard by dist median
    med_d = np.median(arr[valid, 0])
    easymask = valid & (arr[:, 0] <= med_d)
    hardmask = valid & (arr[:, 0] > med_d)
    print(f" EASY (dist<={med_d:.2f}): PSNR {arr[easymask, 1].mean():.2f} "
          f"depth-err {arr[easymask, 2].mean():.2f}m", flush=True)
    print(f" HARD (dist >{med_d:.2f}): PSNR {arr[hardmask, 1].mean():.2f} "
          f"depth-err {arr[hardmask, 2].mean():.2f}m", flush=True)


def _median_ground_errs(model, gt_cache):
    """Per-clip median |VGGT depth - map depth| over ground pixels (clips without ground skipped)."""
    errs = []
    for (d, md, mask) in gt_cache:
        if not mask.any():
            continue
        z, _ = model.vggt_depth(d["in_img"])
        errs.append(float((z[mask] - md[mask]).abs().median()))
    return errs


def _measure_backbone(ckpt, finetune, gt_cache):
    """Load a backbone-only MapVGGT (optionally from ``ckpt``), measure it, then free it.

    Returns (per-clip errors, number of ``vggt.`` keys in ``ckpt``; 0 if ``ckpt`` is None).
    """
    m = MapVGGT(with_map=False, with_dyn=False, finetune_backbone=finetune).to(DEV)
    nk = 0
    if ckpt is not None:
        sd = load_file(ckpt)
        nk = sum(1 for k in sd if k.startswith("vggt."))
        r = m.load_state_dict(sd, strict=False)
        print(f" {os.path.basename(ckpt)}: missing={len(r.missing_keys)} "
              f"unexpected={len(r.unexpected_keys)} vggt_keys={nk}", flush=True)
        del sd
    m.eval()
    errs = _median_ground_errs(m, gt_cache)
    del m
    torch.cuda.empty_cache()
    return errs, nk


# Pure evaluation: disable autograd so forwards through the finetuned backbone do
# not record graphs (large, avoidable GPU memory).
@torch.no_grad()
def main():
    cfg = load_config(overrides=["data.name=unified", f"data.root={ROOT}",
                                 f"data.height={H}", f"data.width={W}",
                                 "model.tokens.n_map=2048"])
    vds = build_val(cfg)
    print(f"# val clips: {len(vds.clips)}", flush=True)
    if not vds.clips:
        print("no val clips found; nothing to measure", flush=True)
        return

    # ---- load finetuned full model ----
    print("loading abl_full_best (finetuned backbone+heads)...", flush=True)
    mf = MapVGGT(with_map=True, with_dyn=True, finetune_backbone=True).to(DEV)
    sdf = load_file("/mnt/william/runs/abl_full_best.safetensors")
    miss = mf.load_state_dict(sdf, strict=False)
    print(f" full: missing={len(miss.missing_keys)} unexpected={len(miss.unexpected_keys)}",
          flush=True)
    mf.eval()
    nvggt = sum(1 for k in sdf if k.startswith("vggt."))
    del sdf  # only the key count is needed later

    # precompute per-clip prep + target-context distance, sort EASY..HARD
    samples = []
    for i in range(len(vds.clips)):
        d = prep(vds[i], DEV)
        samples.append((i, d, target_context_dist(d)))
    samples.sort(key=lambda x: x[2])
    dists = [s[2] for s in samples]
    print(f"# target-context dist: min {min(dists):.3f} med {np.median(dists):.3f} "
          f"max {max(dists):.3f}", flush=True)

    _q1_self_vs_novel(mf, samples)
    depth_stats = _q2_depth_error(mf, samples)
    _q3_reproj_displacement(mf, samples)
    _q3b_correlations(mf, samples, depth_stats)

    # ============ Q4: finetuned vs frozen-backbone depth ============
    print("\n==== Q4: finetuned-backbone vs FROZEN VGGT depth error ====", flush=True)
    # The full model's per-clip median |z - map| over ground pixels is exactly Q2's
    # med_all (same clips, same order), so reuse it instead of re-running the model.
    full = [st["med_all"] for st in depth_stats if st is not None]
    # Free the full model before loading the next backbone
    # (avoids 3 resident 1B models -> OOM).
    del mf
    torch.cuda.empty_cache()
    # map-depth GT on input views is cheap (no VGGT); share it across backbones.
    gt_cache = []
    for (i, d, dist) in samples:
        md, mask = rasterize_map_depth(d["ground"], d["in_K"], d["in_c2w"], H, W)
        gt_cache.append((d, md, mask))

    froz, _ = _measure_backbone(None, False, gt_cache)
    base, base_nk = _measure_backbone("/mnt/william/runs/abl_base_best.safetensors", True,
                                      gt_cache)
    print(" median-of-clip ground depth abs-err (m):", flush=True)
    print(f" FROZEN VGGT-Omega {_med(froz):.3f}", flush=True)
    print(f" abl_base (ft backbone) {_med(base):.3f} (vggt keys in ckpt: {base_nk})", flush=True)
    print(f" abl_full (ft bb+heads) {_med(full):.3f} (vggt keys in ckpt: {nvggt})", flush=True)


if __name__ == "__main__":
    main()