mapvggt / scripts /q4_backbone_depth.py
ChenmingWu's picture
Upload folder using huggingface_hub
8cf92b3 verified
Raw
History Blame Contribute Delete
3.13 kB
#!/usr/bin/env python3
"""Q4 (lean): finetuned-backbone vs frozen VGGT-Omega ground-depth error.
Loads ONE backbone at a time, streams clips (no GPU caching), frees, repeats."""
import os, sys, copy, faulthandler
faulthandler.enable()
import numpy as np
import torch
sys.path.insert(0, "/mnt/william/_vggt_omega_repo")
from mapgs.config import load_config
from mapgs.data import UnifiedClipDataset
from mapgs.hdmap.rasterize_map import rasterize_map_depth
from mapvggt import MapVGGT
from safetensors.torch import load_file
# Device that the model and every per-clip tensor are moved to.
DEV = "cuda"
# Raster / input resolution: used as data.height / data.width in the config
# overrides and as the output size of rasterize_map_depth.
H = 256
W = 448
# Number of leading context views fed to the backbone per clip.
N_IN = 8
# Dataset root handed to both the config and UnifiedClipDataset.
ROOT = "/mnt/william/data/unified/waymo"
def build_val(cfg):
    """Return a one-clip-per-segment validation view of the ROOT "train" split.

    Clips are grouped by a segment id made of the first two ``_``-separated
    fields of the clip directory's basename. The 40 lexicographically smallest
    segment ids are kept, and for each of them only the first clip (in dataset
    order) is retained. The result is a shallow copy of the full dataset whose
    ``clips`` list is replaced, so all other dataset state is shared.
    """
    def _segment_of(clip_path):
        leaf = os.path.basename(clip_path.rstrip("/"))
        return "_".join(leaf.split("_")[:2])

    train_ds = UnifiedClipDataset(cfg, roots=[ROOT], split="train", n_sup_views=6)
    all_segments = sorted(set(_segment_of(clip) for clip in train_ds.clips))
    held_out = set(all_segments[:40])

    # dict keeps insertion order, and setdefault only stores the first clip per segment.
    first_clip_of = {}
    for clip in train_ds.clips:
        seg = _segment_of(clip)
        if seg in held_out:
            first_clip_of.setdefault(seg, clip)

    val_ds = copy.copy(train_ds)
    val_ds.clips = list(first_clip_of.values())
    return val_ds
def _median_or_nan(values):
    """Return the median of ``values`` as a float, or ``nan`` if it is empty.

    ``np.median([])`` also yields ``nan``, but only after emitting a
    RuntimeWarning ("Mean of empty slice").
    """
    return float(np.median(values)) if values else float("nan")


def measure(model, val, near_m=20.0, far_m=40.0):
    """Per-clip median ground-depth abs-err (m), also near/far split aggregated.

    For every clip in ``val`` the first ``N_IN`` context views (images,
    intrinsics, camera-to-world poses) and the ground map are moved to ``DEV``.
    The ground map is rasterized to an ``H x W`` depth map plus validity mask
    with ``rasterize_map_depth``, and the absolute error of the model's depth is
    taken on the masked pixels. Clips whose mask is empty are skipped.

    Args:
        model: object exposing ``vggt_depth(images) -> (depth, _)``. Assumes
            ``depth`` has the same shape as the raster mask so it can be
            indexed by it — confirm against MapVGGT.
        val: dataset with a ``clips`` list, indexable by position; each sample
            provides ``ctx_images``, ``ctx_K``, ``ctx_c2w`` and ``ground``.
        near_m: pixels with ground-truth depth ``< near_m`` form the near bucket.
        far_m: pixels with ground-truth depth ``>= far_m`` form the far bucket.

    Returns:
        ``(all, near, far)``: the median over clips of each clip's median
        absolute error, for all masked pixels and for the near/far buckets.
        A bucket to which no clip contributed is reported as ``nan``.
    """
    meds, near, far = [], [], []
    # Evaluation only: no autograd graph for rasterization or the forward pass.
    with torch.no_grad():
        for i in range(len(val.clips)):
            s = val[i]
            in_img = s.ctx_images[:N_IN].to(DEV)
            in_K = s.ctx_K[:N_IN].to(DEV)
            in_c2w = s.ctx_c2w[:N_IN].to(DEV)
            ground = s.ground.to(DEV)
            md, mask = rasterize_map_depth(ground, in_K, in_c2w, H, W)
            if mask.any():
                z, _ = model.vggt_depth(in_img)
                gt = md[mask]
                err = (z[mask] - gt).abs()
                meds.append(float(err.median()))
                near_sel = gt < near_m
                far_sel = gt >= far_m
                if near_sel.any():
                    near.append(float(err[near_sel].median()))
                if far_sel.any():
                    far.append(float(err[far_sel].median()))
                del z, gt, err, near_sel, far_sel
            # Free this clip's device tensors on every path (also for skipped
            # clips), so they are never alive while the next clip's tensors
            # are being allocated.
            del s, in_img, in_K, in_c2w, ground, md, mask
            torch.cuda.empty_cache()
    return _median_or_nan(meds), _median_or_nan(near), _median_or_nan(far)
def run(ckpt, finetune, label):
    """Evaluate one backbone configuration and print its ground-depth errors.

    Builds a fresh ``MapVGGT`` (``with_map=False``, ``with_dyn=False``) on
    ``DEV``, optionally overlays weights from a safetensors checkpoint
    (non-strict), evaluates it with ``measure`` on the clips returned by
    ``build_val``, and releases the model afterwards so that only one backbone
    occupies the GPU at a time.

    Args:
        ckpt: path to a ``.safetensors`` checkpoint, or a falsy value to keep
            whatever weights the ``MapVGGT`` constructor provides.
        finetune: forwarded to ``MapVGGT`` as ``finetune_backbone``.
        label: tag prefixed to every printed line.
    """
    m = MapVGGT(with_map=False, with_dyn=False, finetune_backbone=finetune).to(DEV)
    try:
        nk = 0  # number of checkpoint keys under the "vggt." prefix
        if ckpt:
            sd = load_file(ckpt)
            nk = sum(1 for k in sd if k.startswith("vggt."))
            r = m.load_state_dict(sd, strict=False)
            # load_state_dict (default assign=False) copies values into the
            # model's own tensors, so the host-side dict is dead weight for the
            # rest of the evaluation — release it now instead of at return.
            del sd
            print(f" {label}: missing={len(r.missing_keys)} unexpected={len(r.unexpected_keys)} vggt_keys={nk}", flush=True)
            if finetune and nk == 0:
                # strict=False would otherwise hide that no backbone weight was
                # taken from this checkpoint.
                print(f" WARNING {label}: checkpoint has no 'vggt.' keys; "
                      "backbone weights were not loaded from it", flush=True)
        m.eval()
        val = build_val(load_config(overrides=["data.name=unified", f"data.root={ROOT}",
                                               f"data.height={H}", f"data.width={W}", "model.tokens.n_map=2048"]))
        a, n, f = measure(m, val)
        print(f" {label}: median-of-clip depth err all={a:.3f} near<20m={n:.3f} far>40m={f:.3f} (vggt_keys={nk})", flush=True)
    finally:
        # Release the backbone even if evaluation fails.
        del m
        torch.cuda.empty_cache()
def main():
    """Print the Q4 table: frozen VGGT vs. the two finetuned ablation backbones.

    Each configuration is handed to ``run`` in turn (checkpoint path, finetune
    flag, label), and a ``Q4_DONE`` marker is printed once all have finished.
    """
    print("==== Q4: ground-depth abs-err (m), median over clips ====", flush=True)
    configurations = (
        (None, False, "FROZEN_VGGT"),
        ("/mnt/william/runs/abl_base_best.safetensors", True, "abl_base_ft"),
        ("/mnt/william/runs/abl_full_best.safetensors", True, "abl_full_ft"),
    )
    for ckpt_path, finetune_backbone, tag in configurations:
        run(ckpt_path, finetune_backbone, tag)
    print("Q4_DONE", flush=True)


if __name__ == "__main__":
    main()