| |
| """Measure how much DEPTH/GEOMETRY error limits MapVGGT novel-view PSNR. |
| Loads the finetuned + base checkpoints, runs on the 40 scene-disjoint Waymo val |
| clips, and computes concrete numbers for the 5 questions in the brief.""" |
| import os, sys, copy, math, faulthandler |
| faulthandler.enable() |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from gsplat import rasterization |
|
|
| sys.path.insert(0, "/mnt/william/_vggt_omega_repo") |
| from mapgs.config import load_config |
| from mapgs.data import UnifiedClipDataset |
| from mapgs.hdmap.rasterize_map import rasterize_map_depth |
| from mapgs.eval.metrics import psnr, ssim |
| from mapvggt import MapVGGT |
| from mapvggt.heads import cat_gaussians |
| from mapnurec.model import lift_to_world |
|
|
# Torch device that the models and all per-clip tensors are moved to.
DEV = "cuda"

# Image height / width in pixels (passed to the config and to the rasterizers)
# and the number of leading context frames fed to the model.
H = 256
W = 448
N_IN = 8

# Dataset root handed to the config (`data.root`) and to UnifiedClipDataset.
ROOT = "/mnt/william/data/unified/waymo"
|
|
|
|
def build_val(cfg):
    """Build the held-out evaluation view of the dataset.

    Loads the dataset under ROOT with split="train" and n_sup_views=6, groups
    clips by segment id (the first two "_"-separated tokens of the clip
    directory name), keeps the 40 lexicographically smallest segment ids and
    retains only the first clip encountered for each of them.

    Returns a shallow copy of the dataset whose ``clips`` attribute is replaced
    by the selected clips (in their original dataset order).
    """
    dataset = UnifiedClipDataset(cfg, roots=[ROOT], split="train", n_sup_views=6)

    def segment_key(clip_path):
        stem = os.path.basename(clip_path.rstrip("/"))
        return "_".join(stem.split("_")[:2])

    # First clip seen per segment; dict insertion order follows dataset order.
    first_clip_of = {}
    for clip in dataset.clips:
        first_clip_of.setdefault(segment_key(clip), clip)

    held_out = set(sorted(first_clip_of)[:40])
    chosen = [clip for seg, clip in first_clip_of.items() if seg in held_out]

    # Shallow copy: shares everything with the full dataset except the clip list.
    val_view = copy.copy(dataset)
    val_view.clips = chosen
    return val_view
|
|
|
|
def prep(s, device):
    """Move the tensors of one dataset sample onto ``device`` as a flat dict.

    Context images / intrinsics / poses are truncated to the first N_IN frames;
    the three anchor tensors get a new leading axis (``[None]``) before the
    transfer. Keys: in_img, in_K, in_c2w, sup_img, sup_K, sup_c2w, ground,
    ap, at, an.
    """
    def moved(t):
        return t.to(device)

    def with_leading_axis(t):
        return t[None].to(device)

    return {
        "in_img": moved(s.ctx_images[:N_IN]),
        "in_K": moved(s.ctx_K[:N_IN]),
        "in_c2w": moved(s.ctx_c2w[:N_IN]),
        "sup_img": moved(s.sup_images),
        "sup_K": moved(s.sup_K),
        "sup_c2w": moved(s.sup_c2w),
        "ground": moved(s.ground),
        "ap": with_leading_axis(s.anchor_pos),
        "at": with_leading_axis(s.anchor_type),
        "an": with_leading_axis(s.anchor_normal),
    }
|
|
|
|
def render_static(gsm, K, c2w):
    """Render gsm (no dynamics) into the given views, one view per rasterizer call.

    ``gsm`` is a dict providing "means", "quats", "scales", "opacities" and
    "colors". Each camera-to-world pose is inverted to obtain the view matrix.
    Returns rgb [S,3,H,W] (clamped to [0, 1]) and depth [S,H,W], where depth is
    the 4th channel of the "RGB+ED" render.
    """
    colour_frames = []
    depth_frames = []
    for view in range(c2w.shape[0]):
        world_to_cam = torch.inverse(c2w[view:view + 1])
        rendered, _, _ = rasterization(
            means=gsm["means"],
            quats=gsm["quats"],
            scales=gsm["scales"],
            opacities=gsm["opacities"],
            colors=gsm["colors"],
            viewmats=world_to_cam,
            Ks=K[view:view + 1],
            width=W,
            height=H,
            near_plane=0.01,
            far_plane=500.0,
            render_mode="RGB+ED",
        )
        frame = rendered[0]
        # Channels-last render -> channels-first colour image.
        colour_frames.append(frame[..., :3].clamp(0, 1).permute(2, 0, 1))
        depth_frames.append(frame[..., 3])
    return torch.stack(colour_frames), torch.stack(depth_frames)
|
|
|
|
def cam_centers(c2w):
    """Return the first three entries of column 3 of every matrix in ``c2w``.

    For camera-to-world matrices this is the translation, i.e. the camera
    position in world coordinates.
    """
    last_column = c2w[:, :, 3]
    return last_column[:, :3]
|
|
|
|
def target_context_dist(d):
    """Mean over sup views of min distance to any context camera center."""
    context_centers = cam_centers(d["in_c2w"])
    target_centers = cam_centers(d["sup_c2w"])
    # Rows index supervision views, columns index context views.
    pairwise = torch.cdist(target_centers, context_centers)
    nearest, _ = pairwise.min(dim=1)
    return float(nearest.mean())
|
|
|
|
@torch.no_grad()
def per_clip_depth_error(model, d):
    """Compare the model's depth with rasterized map ground depth on the INPUT views.

    Only pixels flagged by the mask returned from ``rasterize_map_depth`` are
    scored. Returns None when that mask is empty; otherwise a dict with:

    * n_ground_px / frac_ground: count and fraction of masked pixels,
    * med_all / mean_all: median and mean absolute error,
    * med_rel: median of abs error divided by map depth (floored at 1e-3),
    * per depth bin "near" [0, 20), "mid" [20, 40), "far" [40, 200):
      med_<bin>, mean_<bin>, n_<bin>, gtmean_<bin>. Keys of a bin are absent
      when no masked pixel falls in it; pixels with map depth >= 200 only
      contribute to the overall statistics.
    """
    pred_depth, _ = model.vggt_depth(d["in_img"])
    map_depth, ground = rasterize_map_depth(d["ground"], d["in_K"], d["in_c2w"], H, W)
    if not ground.any():
        return None

    abs_err = (pred_depth - map_depth).abs()[ground]
    ref_depth = map_depth[ground]
    stats = {
        "n_ground_px": int(ground.sum()),
        "frac_ground": float(ground.float().mean()),
        "med_all": float(abs_err.median()),
        "mean_all": float(abs_err.mean()),
        "med_rel": float((abs_err / ref_depth.clamp_min(1e-3)).median()),
    }

    depth_bins = (("near", 0, 20), ("mid", 20, 40), ("far", 40, 200))
    for label, lower, upper in depth_bins:
        in_bin = (ref_depth >= lower) & (ref_depth < upper)
        if not in_bin.any():
            continue
        stats[f"med_{label}"] = float(abs_err[in_bin].median())
        stats[f"mean_{label}"] = float(abs_err[in_bin].mean())
        stats[f"n_{label}"] = int(in_bin.sum())
        stats[f"gtmean_{label}"] = float(ref_depth[in_bin].mean())
    return stats
|
|
|
|
@torch.no_grad()
def main():
    """Run the depth / geometry error analysis and print the results.

    Sections, in print order:

    * Q1  - self-reconstruction (render the INPUT views) vs novel-view PSNR/SSIM
            on the 8 clips whose target views are closest to the context views.
    * Q2  - model depth abs-error vs HD-map ground depth over all clips.
    * Q3  - pixel displacement obtained when predicted-depth and map-depth
            ground points are projected into one supervision view.
    * Q3b - correlations between novel-view PSNR, target-context distance and
            median depth error, plus an easy/hard split at the median distance.
    * Q4  - median ground depth error of the finetuned model vs a model built
            without a checkpoint vs the abl_base checkpoint.

    The whole function runs under ``torch.no_grad()``: it is inference only, and
    without it the outputs would keep autograd graphs (and therefore the model
    parameters) alive, so deleting a model would not release its GPU memory
    before the next one is created.

    Raises:
        ValueError: if the validation dataset contains no clips.
    """
    from safetensors.torch import load_file

    cfg = load_config(overrides=["data.name=unified", f"data.root={ROOT}",
                                 f"data.height={H}", f"data.width={W}", "model.tokens.n_map=2048"])
    vds = build_val(cfg)
    n_clips = len(vds.clips)
    print(f"# val clips: {n_clips}", flush=True)
    if n_clips == 0:
        # Fail before the (slow) checkpoint load instead of on min() of an empty list later.
        raise ValueError(f"no validation clips found under {ROOT}")

    # ---- finetuned model -------------------------------------------------
    print("loading abl_full_best (finetuned backbone+heads)...", flush=True)
    mf = MapVGGT(with_map=True, with_dyn=True, finetune_backbone=True).to(DEV)
    sdf = load_file("/mnt/william/runs/abl_full_best.safetensors")
    miss = mf.load_state_dict(sdf, strict=False)
    print(f" full: missing={len(miss.missing_keys)} unexpected={len(miss.unexpected_keys)}", flush=True)
    mf.eval()

    # ---- load every clip and order by target-context distance ------------
    samples = []
    for i in range(n_clips):
        d = prep(vds[i], DEV)
        samples.append((i, d, target_context_dist(d)))
    samples.sort(key=lambda x: x[2])
    dists = [s[2] for s in samples]
    print(f"# target-context dist: min {min(dists):.3f} med {np.median(dists):.3f} max {max(dists):.3f}", flush=True)

    # ---- Q1 ---------------------------------------------------------------
    print("\n==== Q1: self-recon (render INPUT views) vs novel-view PSNR ====", flush=True)
    print("idx dist novelPSNR selfPSNR novelSSIM selfSSIM", flush=True)
    easy_rows = []
    n_easy = 8
    novel_psnr = {}  # clip index -> novel-view PSNR, reused by Q3b
    for (i, d, dist) in samples[:n_easy]:
        gsm = mf(d["in_img"], d["in_K"], d["in_c2w"], d["ap"], d["at"], d["an"])
        # Novel views: render into the supervision cameras.
        rgb_n, _ = render_static(gsm, d["sup_K"], d["sup_c2w"])
        p_n = float(psnr(rgb_n, d["sup_img"]))
        s_n = float(ssim(rgb_n, d["sup_img"]))
        # Self reconstruction: render back into the input cameras.
        rgb_s, _ = render_static(gsm, d["in_K"], d["in_c2w"])
        p_s = float(psnr(rgb_s, d["in_img"]))
        s_s = float(ssim(rgb_s, d["in_img"]))
        novel_psnr[i] = p_n
        easy_rows.append((i, dist, p_n, p_s, s_n, s_s))
        print(f"{i:3d} {dist:6.3f} {p_n:7.2f} {p_s:7.2f} {s_n:.3f} {s_s:.3f}", flush=True)
    en = np.array([r[2] for r in easy_rows])
    es = np.array([r[3] for r in easy_rows])
    print(f"EASY mean: novel {en.mean():.2f} self-recon {es.mean():.2f} gap {es.mean()-en.mean():.2f} dB", flush=True)

    # ---- Q2 ---------------------------------------------------------------
    print("\n==== Q2: VGGT depth abs-error vs HD-map ground depth (m), all val clips ====", flush=True)
    agg = {k: [] for k in ["med_all", "mean_all", "med_rel",
                           "med_near", "mean_near", "med_mid", "mean_mid", "med_far", "mean_far"]}
    npx = {"near": 0, "mid": 0, "far": 0}
    gtm = {"near": [], "mid": [], "far": []}
    depth_stats = {}  # clip index -> stats dict (or None), reused by Q3b
    for (i, d, _) in samples:
        st = per_clip_depth_error(mf, d)
        depth_stats[i] = st
        if st is None:
            continue
        for k in agg:
            if k in st:
                agg[k].append(st[k])
        for b in ["near", "mid", "far"]:
            npx[b] += st.get(f"n_{b}", 0)
            if f"gtmean_{b}" in st:
                gtm[b].append(st[f"gtmean_{b}"])

    def report(label, key):
        v = np.array(agg[key])
        if len(v):
            print(f" {label:18s} median-of-clip {np.median(v):6.2f} mean {v.mean():6.2f} m", flush=True)

    report("abs err (all gnd)", "med_all")
    report("mean abs err all", "mean_all")
    if agg["med_rel"]:
        # Same skip-when-empty behaviour as report().
        print(f" relative err (all) median-of-clip {np.median(agg['med_rel'])*100:.1f}%", flush=True)
    report("near <20m (med)", "med_near")
    report("near <20m (mean)", "mean_near")
    report("mid 20-40m (med)", "med_mid")
    report("far >40m (med)", "med_far")
    report("far >40m (mean)", "mean_far")
    for b in ["near", "mid", "far"]:
        if gtm[b]:
            print(f" [{b}] ground px total {npx[b]} mean-GT-depth {np.mean(gtm[b]):.1f}m", flush=True)

    # ---- Q3 ---------------------------------------------------------------
    print("\n==== Q3: gaussian reprojection displacement from depth error (px) ====", flush=True)

    def project(xyz, w2c, K_t):
        """Project world points into one view -> (pixel uv [N,2], camera-space z [N])."""
        p = xyz.reshape(-1, 3)
        pc = (w2c[:3, :3] @ p.T).T + w2c[:3, 3]
        uv = (K_t @ pc.T).T
        z_ = uv[:, 2].clamp_min(1e-4)  # avoid division by ~0 for points at the camera plane
        return torch.stack([uv[:, 0] / z_, uv[:, 1] / z_], -1), pc[:, 2]

    disp_meds, disp_p90s = [], []
    for (i, d, _) in samples:
        z, _ = mf.vggt_depth(d["in_img"])
        md, mask = rasterize_map_depth(d["ground"], d["in_K"], d["in_c2w"], H, W)
        if not mask.any():
            continue
        # Lift the same pixels with predicted and with map depth.
        # NOTE(review): assumes lift_to_world yields one 3-D point per pixel in the
        # same pixel order as `mask` - confirm against mapnurec.model.
        xyz_pred = lift_to_world(z, d["in_K"], d["in_c2w"])
        xyz_true = lift_to_world(md, d["in_K"], d["in_c2w"])

        cc_in = cam_centers(d["in_c2w"])
        cc_sup = cam_centers(d["sup_c2w"])
        # Target view: the supervision camera closest to the mean input camera centre.
        tgt = int(torch.cdist(cc_in.mean(0, keepdim=True), cc_sup).argmin())
        K_t = d["sup_K"][tgt]
        w2c = torch.inverse(d["sup_c2w"][tgt])

        uv_p, zc_p = project(xyz_pred, w2c, K_t)
        uv_t, zc_t = project(xyz_true, w2c, K_t)
        # Keep ground pixels whose both liftings land in front of the target camera.
        m = mask.reshape(-1) & (zc_t > 0.1) & (zc_p > 0.1)
        if not m.any():
            continue
        shift = (uv_p[m] - uv_t[m]).norm(dim=-1)
        disp_meds.append(float(shift.median()))
        disp_p90s.append(float(shift.quantile(0.9)))
    print(" reproj displacement on ground px (into nearest sup view):", flush=True)
    if disp_meds:
        print(f" median-of-clip median shift {np.median(disp_meds):.2f} px mean {np.mean(disp_meds):.2f} px", flush=True)
        print(f" median-of-clip p90 shift {np.median(disp_p90s):.2f} px", flush=True)
    else:
        print(" (no clip had ground pixels in front of its target view)", flush=True)

    # ---- Q3b --------------------------------------------------------------
    print("\n==== Q3b: per-clip novelPSNR vs dist and vs depth-displacement ====", flush=True)
    rows = []
    for (i, d, dist) in samples:
        if i in novel_psnr:
            # Already rendered for Q1.
            p_n = novel_psnr[i]
        else:
            gsm = mf(d["in_img"], d["in_K"], d["in_c2w"], d["ap"], d["at"], d["an"])
            rgb_n, _ = render_static(gsm, d["sup_K"], d["sup_c2w"])
            p_n = float(psnr(rgb_n, d["sup_img"]))
        st = depth_stats[i]  # computed once in Q2
        de = st["med_all"] if st else float("nan")
        rows.append((i, dist, p_n, de))
    # Columns: 0 = target-context distance, 1 = novel PSNR, 2 = median depth error.
    arr = np.array([(r[1], r[2], r[3]) for r in rows], dtype=float)
    valid = ~np.isnan(arr[:, 2])

    def corr(a, b):
        a, b = a[valid], b[valid]
        m = ~np.isnan(a) & ~np.isnan(b)
        return float(np.corrcoef(a[m], b[m])[0, 1])

    print(f" PSNR vs target-context dist r = {corr(arr[:,0], arr[:,1]):.3f}", flush=True)
    print(f" PSNR vs median depth-error r = {corr(arr[:,2], arr[:,1]):.3f}", flush=True)
    print(f" dist vs median depth-error r = {corr(arr[:,0], arr[:,2]):.3f}", flush=True)

    # Split the clips with a depth measurement at their median distance.
    med_d = np.median(arr[valid, 0])
    easymask = valid & (arr[:, 0] <= med_d)
    hardmask = valid & (arr[:, 0] > med_d)
    print(f" EASY (dist<={med_d:.2f}): PSNR {arr[easymask,1].mean():.2f} depth-err {arr[easymask,2].mean():.2f}m", flush=True)
    print(f" HARD (dist >{med_d:.2f}): PSNR {arr[hardmask,1].mean():.2f} depth-err {arr[hardmask,2].mean():.2f}m", flush=True)

    # ---- Q4 ---------------------------------------------------------------
    print("\n==== Q4: finetuned-backbone vs FROZEN VGGT depth error ====", flush=True)
    # Rasterize the map depth once so all three models are scored on identical targets.
    gt_cache = []
    for (i, d, _) in samples:
        md, mask = rasterize_map_depth(d["ground"], d["in_K"], d["in_c2w"], H, W)
        gt_cache.append((d, md, mask))

    full = []
    for (d, md, mask) in gt_cache:
        if not mask.any():
            continue
        zu, _ = mf.vggt_depth(d["in_img"])
        full.append(float((zu[mask] - md[mask]).abs().median()))
    nvggt = sum(1 for k in sdf if k.startswith("vggt."))
    # Release the finetuned model before building the comparison models; under
    # no_grad no earlier output holds a graph that would keep its weights alive.
    del mf
    torch.cuda.empty_cache()

    def measure_backbone(ckpt, finetune):
        """Per-clip median ground depth error of a map-less, dynamics-less MapVGGT.

        ``ckpt`` is a safetensors path or None (no checkpoint is loaded).
        Returns (errors, number of checkpoint keys starting with "vggt.").
        """
        m = MapVGGT(with_map=False, with_dyn=False, finetune_backbone=finetune).to(DEV)
        nk = 0
        if ckpt is not None:
            sd = load_file(ckpt)
            nk = sum(1 for k in sd if k.startswith("vggt."))
            r = m.load_state_dict(sd, strict=False)
            print(f" {os.path.basename(ckpt)}: missing={len(r.missing_keys)} "
                  f"unexpected={len(r.unexpected_keys)} vggt_keys={nk}", flush=True)
        m.eval()
        errs = []
        for (d, md, mask) in gt_cache:
            if not mask.any():
                continue
            z, _ = m.vggt_depth(d["in_img"])
            errs.append(float((z[mask] - md[mask]).abs().median()))
        del m
        torch.cuda.empty_cache()
        return errs, nk

    froz, _ = measure_backbone(None, False)
    base, base_nk = measure_backbone("/mnt/william/runs/abl_base_best.safetensors", True)

    print(" median-of-clip ground depth abs-err (m):", flush=True)
    print(f" FROZEN VGGT-Omega {np.median(froz):.3f}", flush=True)
    print(f" abl_base (ft backbone) {np.median(base):.3f} (vggt keys in ckpt: {base_nk})", flush=True)
    print(f" abl_full (ft bb+heads) {np.median(full):.3f} (vggt keys in ckpt: {nvggt})", flush=True)
|
|
|
|
# Run the analysis only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|