| |
| """Q4 (lean): finetuned-backbone vs frozen VGGT-Omega ground-depth error. |
| Loads ONE backbone at a time, streams clips (no GPU caching), frees, repeats.""" |
| import os, sys, copy, faulthandler |
| faulthandler.enable() |
| import numpy as np |
| import torch |
|
|
| sys.path.insert(0, "/mnt/william/_vggt_omega_repo") |
| from mapgs.config import load_config |
| from mapgs.data import UnifiedClipDataset |
| from mapgs.hdmap.rasterize_map import rasterize_map_depth |
| from mapvggt import MapVGGT |
| from safetensors.torch import load_file |
|
|
# Torch device string: the model and every per-clip input tensor are moved here.
DEV = "cuda"
# Frame height and width; used both as config overrides (data.height/data.width)
# and as the output size requested from the map-depth rasterizer.
H = 256
W = 448
# Number of leading context views sliced from each sample.
N_IN = 8
# Dataset root handed to the dataset constructor and to the config overrides.
ROOT = "/mnt/william/data/unified/waymo"
|
|
|
|
def build_val(cfg):
    """Assemble the validation subset: one clip for each of the first 40 segments.

    A ``UnifiedClipDataset`` is created over ``ROOT`` (``split="train"``,
    ``n_sup_views=6``).  A segment id is derived from every clip path as the
    first two ``_``-separated tokens of the path's basename.  The 40 smallest
    ids (in sorted order) are selected, and for each of them the first clip
    encountered in ``clips`` order is kept.

    Returns:
        A shallow copy of the dataset whose ``clips`` attribute is replaced by
        the reduced clip list.
    """

    def segment_of(clip_path):
        # Basename of the path (trailing "/" ignored), cut to its first two
        # "_"-separated tokens.
        stem = os.path.basename(clip_path.rstrip("/"))
        return "_".join(stem.split("_")[:2])

    dataset = UnifiedClipDataset(cfg, roots=[ROOT], split="train", n_sup_views=6)

    ordered_segments = sorted({segment_of(path) for path in dataset.clips})
    # Segments still waiting for their representative clip.
    pending = set(ordered_segments[:40])

    chosen = []
    for path in dataset.clips:
        segment = segment_of(path)
        if segment in pending:
            # First clip of this segment wins; later ones are ignored.
            pending.discard(segment)
            chosen.append(path)

    subset = copy.copy(dataset)
    subset.clips = chosen
    return subset
|
|
|
|
def _median_or_nan(values):
    """Return the median of *values*, or NaN when the list is empty.

    ``np.median([])`` also yields NaN but emits a RuntimeWarning on the way;
    handling the empty case explicitly keeps the evaluation log clean.
    """
    return np.median(values) if values else np.float64("nan")


def _clip_depth_errors(model, sample):
    """Compute ground-depth abs-error medians for one clip.

    Returns ``None`` when the rasterized ground mask is empty, otherwise a
    tuple ``(median, near_median, far_median)`` of Python floats, where the
    near/far entries are ``None`` if no masked pixel falls in that range
    (reference depth < 20 and >= 40 respectively).

    Every tensor created here is a local of this function, so all of them are
    released when it returns -- on the early-exit path as well.
    """
    in_img = sample.ctx_images[:N_IN].to(DEV)
    in_k = sample.ctx_K[:N_IN].to(DEV)
    in_c2w = sample.ctx_c2w[:N_IN].to(DEV)
    ground = sample.ground.to(DEV)

    map_depth, mask = rasterize_map_depth(ground, in_k, in_c2w, H, W)
    if not mask.any():
        # Nothing to compare against; skip before paying for a forward pass.
        return None

    with torch.no_grad():
        pred_depth, _ = model.vggt_depth(in_img)

    target = map_depth[mask]
    abs_err = (pred_depth[mask] - target).abs()

    near_sel = target < 20
    far_sel = target >= 40
    near_med = float(abs_err[near_sel].median()) if near_sel.any() else None
    far_med = float(abs_err[far_sel].median()) if far_sel.any() else None
    return float(abs_err.median()), near_med, far_med


def measure(model, val):
    """Per-clip median ground-depth abs-err (m), also near/far split aggregated.

    Args:
        model: object exposing ``vggt_depth(images)`` that returns a pair whose
            first element is the predicted depth.
        val: dataset-like object with a ``clips`` list and integer indexing;
            each sample provides ``ctx_images``, ``ctx_K``, ``ctx_c2w`` and
            ``ground``.

    Returns:
        ``(all, near, far)``: the median over clips of the per-clip median
        absolute error, computed over all masked pixels, over pixels with
        reference depth < 20, and over pixels with reference depth >= 40.
        An entry is NaN when no clip contributed to it.
    """
    meds, near, far = [], [], []
    # Iterate by index over len(val.clips): only `clips` and `__getitem__`
    # are relied upon, not the dataset's own __len__/__iter__.
    for i in range(len(val.clips)):
        stats = _clip_depth_errors(model, val[i])
        # All per-clip tensors went out of scope with the helper call, so the
        # cached blocks can actually be handed back (skipped clips included).
        torch.cuda.empty_cache()
        if stats is None:
            continue
        clip_med, near_med, far_med = stats
        meds.append(clip_med)
        if near_med is not None:
            near.append(near_med)
        if far_med is not None:
            far.append(far_med)
    return _median_or_nan(meds), _median_or_nan(near), _median_or_nan(far)
|
|
|
|
def run(ckpt, finetune, label):
    """Evaluate one backbone configuration and print its depth-error summary.

    Args:
        ckpt: path to a safetensors checkpoint, or a falsy value to evaluate
            the model as constructed (no weights loaded here).
        finetune: forwarded to ``MapVGGT`` as ``finetune_backbone``.
        label: tag prefixed to every line printed for this configuration.

    The model is built on ``DEV``, optionally loaded non-strictly from
    ``ckpt`` (reporting missing/unexpected key counts and how many checkpoint
    keys start with ``"vggt."``), switched to eval mode, measured on the
    validation subset, and finally dropped before emptying the CUDA cache.
    """
    model = MapVGGT(with_map=False, with_dyn=False, finetune_backbone=finetune).to(DEV)

    n_vggt_keys = 0
    if ckpt:
        state = load_file(ckpt)
        n_vggt_keys = sum(1 for key in state if key.startswith("vggt."))
        # strict=False: the checkpoint and this model need not share every key.
        result = model.load_state_dict(state, strict=False)
        print(
            f" {label}: missing={len(result.missing_keys)} unexpected={len(result.unexpected_keys)} vggt_keys={n_vggt_keys}",
            flush=True,
        )
    model.eval()

    overrides = [
        "data.name=unified",
        f"data.root={ROOT}",
        f"data.height={H}",
        f"data.width={W}",
        "model.tokens.n_map=2048",
    ]
    val = build_val(load_config(overrides=overrides))

    err_all, err_near, err_far = measure(model, val)
    print(
        f" {label}: median-of-clip depth err all={err_all:.3f} near<20m={err_near:.3f} far>40m={err_far:.3f} (vggt_keys={n_vggt_keys})",
        flush=True,
    )

    # Drop the only reference to the model so its memory can be reclaimed
    # before the next configuration is built.
    del model
    torch.cuda.empty_cache()
|
|
|
|
def main():
    """Run the Q4 comparison: the frozen backbone, then two finetuned checkpoints.

    Configurations are evaluated strictly one after another through ``run``;
    a header line is printed first and ``Q4_DONE`` once all have finished.
    """
    print("==== Q4: ground-depth abs-err (m), median over clips ====", flush=True)

    # (checkpoint path or None, finetune_backbone flag, label)
    jobs = (
        (None, False, "FROZEN_VGGT"),
        ("/mnt/william/runs/abl_base_best.safetensors", True, "abl_base_ft"),
        ("/mnt/william/runs/abl_full_best.safetensors", True, "abl_full_ft"),
    )
    for ckpt, finetune, label in jobs:
        run(ckpt, finetune, label)

    print("Q4_DONE", flush=True)
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|