mapvggt / scripts /depth_limiter_probe.py
ChenmingWu's picture
Upload folder using huggingface_hub
8cf92b3 verified
Raw
History Blame Contribute Delete
13.6 kB
#!/usr/bin/env python3
"""Measure how much DEPTH/GEOMETRY error limits MapVGGT novel-view PSNR.
Loads the finetuned + base checkpoints, runs on the 40 scene-disjoint Waymo val
clips, and computes concrete numbers for the 5 questions in the brief."""
import os, sys, copy, math, faulthandler
faulthandler.enable()
import numpy as np
import torch
import torch.nn.functional as F
from gsplat import rasterization
sys.path.insert(0, "/mnt/william/_vggt_omega_repo")
from mapgs.config import load_config
from mapgs.data import UnifiedClipDataset
from mapgs.hdmap.rasterize_map import rasterize_map_depth
from mapgs.eval.metrics import psnr, ssim
from mapvggt import MapVGGT
from mapvggt.heads import cat_gaussians
from mapnurec.model import lift_to_world
# Device every tensor and model is moved to.
DEV = "cuda"
# Render/config image size (H, W) and how many context views `prep` keeps as model input.
H, W, N_IN = 256, 448, 8
# Unified-format Waymo data root passed to both the config and the dataset.
ROOT = "/mnt/william/data/unified/waymo"
def build_val(cfg):
    """Build the scene-disjoint validation subset.

    Loads the full ``train`` split under ``ROOT`` and groups clips by segment id
    (the first two ``_``-separated tokens of the clip directory basename). It
    keeps the first 40 segment ids in sorted order and, for each of them, the
    first clip in dataset order.

    Returns a shallow copy of the full dataset whose ``clips`` list is replaced
    by the selected clips.
    """
    dataset = UnifiedClipDataset(cfg, roots=[ROOT], split="train", n_sup_views=6)

    def segment_of(path):
        stem = os.path.basename(path.rstrip("/"))
        return "_".join(stem.split("_")[:2])

    held_out = set(sorted({segment_of(c) for c in dataset.clips})[:40])
    # dict preserves insertion order, so clips stay in first-occurrence order
    picked = {}
    for clip in dataset.clips:
        seg = segment_of(clip)
        if seg in held_out:
            picked.setdefault(seg, clip)
    val = copy.copy(dataset)
    val.clips = list(picked.values())
    return val
def prep(s, device):
    """Move one dataset sample onto *device* as a flat dict of tensors.

    Only the first ``N_IN`` context views are kept as model inputs. All
    supervision views and the ground tensor are moved as-is. Anchor
    position/type/normal tensors gain a leading dimension of size 1.
    """
    def to_dev(t):
        return t.to(device)

    def ctx(t):
        return to_dev(t[:N_IN])

    def batched(t):
        return to_dev(t[None])

    return {
        "in_img": ctx(s.ctx_images),
        "in_K": ctx(s.ctx_K),
        "in_c2w": ctx(s.ctx_c2w),
        "sup_img": to_dev(s.sup_images),
        "sup_K": to_dev(s.sup_K),
        "sup_c2w": to_dev(s.sup_c2w),
        "ground": to_dev(s.ground),
        "ap": batched(s.anchor_pos),
        "at": batched(s.anchor_type),
        "an": batched(s.anchor_normal),
    }
def render_static(gsm, K, c2w):
    """Render the static gaussians *gsm* into every camera of (K, c2w), one view at a time.

    Args:
        gsm: dict with "means", "quats", "scales", "opacities" and "colors",
            forwarded to gsplat's ``rasterization`` (no dynamic components).
        K: per-view intrinsics, indexed along dim 0.
        c2w: per-view camera-to-world poses, indexed along dim 0; inverted to
            view matrices here.

    Returns:
        ``(rgb, depth)``: rgb ``[S,3,H,W]`` clamped to [0, 1], and the 4th
        ("ED") channel of the "RGB+ED" render as depth ``[S,H,W]``.
    """
    colors, depths = [], []
    for view in range(c2w.shape[0]):
        sel = slice(view, view + 1)
        rendered, _, _ = rasterization(
            means=gsm["means"],
            quats=gsm["quats"],
            scales=gsm["scales"],
            opacities=gsm["opacities"],
            colors=gsm["colors"],
            viewmats=torch.inverse(c2w[sel]),
            Ks=K[sel],
            width=W,
            height=H,
            near_plane=0.01,
            far_plane=500.0,
            render_mode="RGB+ED",
        )
        frame = rendered[0]  # [H, W, 4]
        colors.append(frame[..., :3].clamp(0, 1).permute(2, 0, 1))
        depths.append(frame[..., 3])
    return torch.stack(colors), torch.stack(depths)
def cam_centers(c2w):
    """Return camera centres, i.e. the translation column of camera-to-world poses.

    Accepts a single pose (``[4,4]`` or ``[3,4]``) or any batch ``[..., 4, 4]``
    and returns ``[..., 3]``. For the ``[N,4,4]`` case this equals
    ``c2w[:, :3, 3]``.
    """
    return c2w[..., :3, 3]
def target_context_dist(d):
    """Average distance from each supervision (target) camera to its closest context camera.

    Both sets of centres come from ``cam_centers`` on ``d["sup_c2w"]`` and
    ``d["in_c2w"]``. Returns a Python float.
    """
    context = cam_centers(d["in_c2w"])    # [Vc,3]
    targets = cam_centers(d["sup_c2w"])   # [Vs,3]
    pairwise = torch.cdist(targets, context)  # [Vs,Vc]
    nearest, _ = pairwise.min(dim=1)
    return float(nearest.mean())
@torch.no_grad()
def per_clip_depth_error(model, d):
    """Compare VGGT depth on the INPUT views against HD-map ground depth.

    The map depth and ground mask come from ``rasterize_map_depth``. Absolute
    errors are taken only on ground pixels. The original note says the VGGT
    depth is metric and clamped to [1.5, 120] m; that clamp lives in
    ``model.vggt_depth``.

    Returns ``None`` when no pixel is ground. Otherwise returns a dict with:
        - n_ground_px, frac_ground: ground pixel count and coverage fraction;
        - med_all, mean_all: median / mean abs error (m) over all ground px;
        - med_rel: median of abs error / GT depth (GT clamped to >= 1e-3);
        - for each band near [0,20), mid [20,40), far [40,200) in GT metres
          that contains pixels: med_*, mean_*, n_*, gtmean_*.
    Ground pixels with GT depth >= 200 m count toward the *_all stats but fall
    in no band.
    """
    pred, _ = model.vggt_depth(d["in_img"])  # [Vc,H,W]
    gt_depth, ground = rasterize_map_depth(d["ground"], d["in_K"], d["in_c2w"], H, W)  # [Vc,H,W]
    if not ground.any():
        return None
    abs_err = (pred - gt_depth).abs()[ground]
    gt = gt_depth[ground]
    stats = {
        "n_ground_px": int(ground.sum()),
        "frac_ground": float(ground.float().mean()),
        "med_all": float(abs_err.median()),
        "mean_all": float(abs_err.mean()),
        "med_rel": float((abs_err / gt.clamp_min(1e-3)).median()),
    }
    bands = (("near", 0, 20), ("mid", 20, 40), ("far", 40, 200))
    for band, lo, hi in bands:
        sel = (gt >= lo) & (gt < hi)
        if not sel.any():
            continue
        band_err = abs_err[sel]
        stats[f"med_{band}"] = float(band_err.median())
        stats[f"mean_{band}"] = float(band_err.mean())
        stats[f"n_{band}"] = int(sel.sum())
        stats[f"gtmean_{band}"] = float(gt[sel].mean())
    return stats
def main():
cfg = load_config(overrides=["data.name=unified", f"data.root={ROOT}",
f"data.height={H}", f"data.width={W}", "model.tokens.n_map=2048"])
vds = build_val(cfg)
print(f"# val clips: {len(vds.clips)}", flush=True)
# ---- load finetuned full model ----
print("loading abl_full_best (finetuned backbone+heads)...", flush=True)
mf = MapVGGT(with_map=True, with_dyn=True, finetune_backbone=True).to(DEV)
from safetensors.torch import load_file
sdf = load_file("/mnt/william/runs/abl_full_best.safetensors")
miss = mf.load_state_dict(sdf, strict=False)
print(f" full: missing={len(miss.missing_keys)} unexpected={len(miss.unexpected_keys)}", flush=True)
mf.eval()
# precompute per-clip prep + target-context distance, sort EASY..HARD
samples = []
for i in range(len(vds.clips)):
d = prep(vds[i], DEV)
samples.append((i, d, target_context_dist(d)))
samples.sort(key=lambda x: x[2])
dists = [s[2] for s in samples]
print(f"# target-context dist: min {min(dists):.3f} med {np.median(dists):.3f} max {max(dists):.3f}", flush=True)
# ============ Q1: self-recon vs novel-view PSNR, EASY clips ============
print("\n==== Q1: self-recon (render INPUT views) vs novel-view PSNR ====", flush=True)
print("idx dist novelPSNR selfPSNR novelSSIM selfSSIM", flush=True)
easy_rows = []
n_easy = 8
for (i, d, dist) in samples[:n_easy]:
gsm = mf(d["in_img"], d["in_K"], d["in_c2w"], d["ap"], d["at"], d["an"])
# novel: render the (static) gsm into sup views
rgb_n, _ = render_static(gsm, d["sup_K"], d["sup_c2w"])
p_n = float(psnr(rgb_n, d["sup_img"])); s_n = float(ssim(rgb_n, d["sup_img"]))
# self-recon: render gsm into the INPUT views (n_in of them)
rgb_s, _ = render_static(gsm, d["in_K"], d["in_c2w"])
p_s = float(psnr(rgb_s, d["in_img"])); s_s = float(ssim(rgb_s, d["in_img"]))
easy_rows.append((i, dist, p_n, p_s, s_n, s_s))
print(f"{i:3d} {dist:6.3f} {p_n:7.2f} {p_s:7.2f} {s_n:.3f} {s_s:.3f}", flush=True)
en = np.array([r[2] for r in easy_rows]); es = np.array([r[3] for r in easy_rows])
print(f"EASY mean: novel {en.mean():.2f} self-recon {es.mean():.2f} gap {es.mean()-en.mean():.2f} dB", flush=True)
# ============ Q2: VGGT depth accuracy vs map ground depth ============
print("\n==== Q2: VGGT depth abs-error vs HD-map ground depth (m), all val clips ====", flush=True)
agg = {k: [] for k in ["med_all", "mean_all", "med_rel",
"med_near", "mean_near", "med_mid", "mean_mid", "med_far", "mean_far"]}
npx = {"near": 0, "mid": 0, "far": 0}
gtm = {"near": [], "mid": [], "far": []}
for (i, d, dist) in samples:
st = per_clip_depth_error(mf, d)
if st is None:
continue
for k in agg:
if k in st:
agg[k].append(st[k])
for b in ["near", "mid", "far"]:
npx[b] += st.get(f"n_{b}", 0)
if f"gtmean_{b}" in st:
gtm[b].append(st[f"gtmean_{b}"])
def report(label, key):
v = np.array(agg[key])
if len(v): print(f" {label:18s} median-of-clip {np.median(v):6.2f} mean {v.mean():6.2f} m", flush=True)
report("abs err (all gnd)", "med_all")
report("mean abs err all", "mean_all")
print(f" relative err (all) median-of-clip {np.median(agg['med_rel'])*100:.1f}%", flush=True)
report("near <20m (med)", "med_near")
report("near <20m (mean)", "mean_near")
report("mid 20-40m (med)", "med_mid")
report("far >40m (med)", "med_far")
report("far >40m (mean)", "mean_far")
for b in ["near", "mid", "far"]:
if gtm[b]:
print(f" [{b}] ground px total {npx[b]} mean-GT-depth {np.mean(gtm[b]):.1f}m", flush=True)
# ============ Q3: ghosting -- pixel displacement from depth error at eval baseline ============
print("\n==== Q3: gaussian reprojection displacement from depth error (px) ====", flush=True)
# For each val clip: lift INPUT pixels with VGGT depth z and with z+err(map), reproject
# both into the nearest sup view, measure pixel shift on ground pixels.
disp_meds, disp_p90s = [], []
for (i, d, dist) in samples:
z, _ = mf.vggt_depth(d["in_img"])
md, mask = rasterize_map_depth(d["ground"], d["in_K"], d["in_c2w"], H, W)
if not mask.any():
continue
# world points from predicted depth and from map (true ground) depth
xyz_pred = lift_to_world(z, d["in_K"], d["in_c2w"]) # [Vc,H,W,3]
xyz_true = lift_to_world(md, d["in_K"], d["in_c2w"])
# reproject both into the nearest sup view (by camera center)
cc_in = cam_centers(d["in_c2w"]); cc_sup = cam_centers(d["sup_c2w"])
# choose the single sup view closest to the input rig centroid
tgt = int(torch.cdist(cc_in.mean(0, keepdim=True), cc_sup).argmin())
c2w_t = d["sup_c2w"][tgt]; K_t = d["sup_K"][tgt]
w2c = torch.inverse(c2w_t)
def proj(xyz):
p = xyz.reshape(-1, 3)
pc = (w2c[:3, :3] @ p.T).T + w2c[:3, 3]
uv = (K_t @ pc.T).T
z_ = uv[:, 2].clamp_min(1e-4)
return torch.stack([uv[:, 0] / z_, uv[:, 1] / z_], -1), pc[:, 2]
uv_p, zc_p = proj(xyz_pred); uv_t, zc_t = proj(xyz_true)
m = mask.reshape(-1) & (zc_t > 0.1) & (zc_p > 0.1)
if not m.any():
continue
shift = (uv_p[m] - uv_t[m]).norm(dim=-1)
disp_meds.append(float(shift.median())); disp_p90s.append(float(shift.quantile(0.9)))
print(f" reproj displacement on ground px (into nearest sup view):", flush=True)
print(f" median-of-clip median shift {np.median(disp_meds):.2f} px mean {np.mean(disp_meds):.2f} px", flush=True)
print(f" median-of-clip p90 shift {np.median(disp_p90s):.2f} px", flush=True)
# correlate per-clip depth error and displacement with novel PSNR + dist (HARD vs disocclusion)
print("\n==== Q3b: per-clip novelPSNR vs dist and vs depth-displacement ====", flush=True)
rows = []
for (i, d, dist) in samples:
gsm = mf(d["in_img"], d["in_K"], d["in_c2w"], d["ap"], d["at"], d["an"])
rgb_n, _ = render_static(gsm, d["sup_K"], d["sup_c2w"])
p_n = float(psnr(rgb_n, d["sup_img"]))
st = per_clip_depth_error(mf, d)
de = st["med_all"] if st else float("nan")
rows.append((i, dist, p_n, de))
arr = np.array([(r[1], r[2], r[3]) for r in rows], dtype=float)
valid = ~np.isnan(arr[:, 2])
def corr(a, b):
a, b = a[valid], b[valid]
m = ~np.isnan(a) & ~np.isnan(b)
return float(np.corrcoef(a[m], b[m])[0, 1])
print(f" PSNR vs target-context dist r = {corr(arr[:,0], arr[:,1]):.3f}", flush=True)
print(f" PSNR vs median depth-error r = {corr(arr[:,2], arr[:,1]):.3f}", flush=True)
print(f" dist vs median depth-error r = {corr(arr[:,0], arr[:,2]):.3f}", flush=True)
# split easy/hard by dist median
med_d = np.median(arr[valid, 0])
easymask = valid & (arr[:, 0] <= med_d); hardmask = valid & (arr[:, 0] > med_d)
print(f" EASY (dist<={med_d:.2f}): PSNR {arr[easymask,1].mean():.2f} depth-err {arr[easymask,2].mean():.2f}m", flush=True)
print(f" HARD (dist >{med_d:.2f}): PSNR {arr[hardmask,1].mean():.2f} depth-err {arr[hardmask,2].mean():.2f}m", flush=True)
# ============ Q4: finetuned vs frozen-backbone depth ============
print("\n==== Q4: finetuned-backbone vs FROZEN VGGT depth error ====", flush=True)
# precompute per-clip map-depth GT on input views (cheap, no VGGT) so we can free
# each backbone before loading the next (avoids 3 resident 1B models -> OOM).
gt_cache = []
for (i, d, dist) in samples:
md, mask = rasterize_map_depth(d["ground"], d["in_K"], d["in_c2w"], H, W)
gt_cache.append((d, md, mask))
# full model already loaded -> measure it first, then free it.
full = []
for (d, md, mask) in gt_cache:
if not mask.any():
continue
zu, _ = mf.vggt_depth(d["in_img"]); full.append(float((zu[mask]-md[mask]).abs().median()))
nvggt = sum(1 for k in sdf if k.startswith("vggt."))
del mf; torch.cuda.empty_cache()
def measure_backbone(ckpt, finetune):
m = MapVGGT(with_map=False, with_dyn=False, finetune_backbone=finetune).to(DEV)
nk = 0
if ckpt is not None:
sd = load_file(ckpt)
nk = sum(1 for k in sd if k.startswith("vggt."))
r = m.load_state_dict(sd, strict=False)
print(f" {os.path.basename(ckpt)}: missing={len(r.missing_keys)} "
f"unexpected={len(r.unexpected_keys)} vggt_keys={nk}", flush=True)
m.eval()
errs = []
for (d, md, mask) in gt_cache:
if not mask.any():
continue
z, _ = m.vggt_depth(d["in_img"]); errs.append(float((z[mask]-md[mask]).abs().median()))
del m; torch.cuda.empty_cache()
return errs, nk
froz, _ = measure_backbone(None, False)
base, base_nk = measure_backbone("/mnt/william/runs/abl_base_best.safetensors", True)
print(f" median-of-clip ground depth abs-err (m):", flush=True)
print(f" FROZEN VGGT-Omega {np.median(froz):.3f}", flush=True)
print(f" abl_base (ft backbone) {np.median(base):.3f} (vggt keys in ckpt: {base_nk})", flush=True)
print(f" abl_full (ft bb+heads) {np.median(full):.3f} (vggt keys in ckpt: {nvggt})", flush=True)
if __name__ == "__main__":
main()