| |
| """Canonical loader + evaluator for a released MapVGGT (+RefineUNet) checkpoint. |
| Reconstructs the model, loads the model.*/unet.* state dict, and reports held-out-SCENE |
| PSNR/SSIM on a segment-disjoint Waymo val split. This is the reference inference path for |
| the released weights (verifies they round-trip). |
| |
| Env: VGGT_OMEGA_REPO (vggt-omega clone), MAPVGGT_VGGT_CKPT (base VGGT-Omega weights). |
| """ |
| import argparse, os, copy, statistics as st |
| import torch |
| from safetensors.torch import load_file |
|
|
| from mapgs.config import load_config |
| from mapgs.data import UnifiedClipDataset |
| from mapgs.eval.metrics import psnr, ssim |
| from mapvggt import MapVGGT |
| from mapvggt.refine import RefineUNet |
| from scripts.train_mapvggt_refine import render_rda |
| from scripts.train_mapvggt_full import prep |
|
|
# Device string: both networks are moved to it and main() hands it to prep().
DEV = "cuda"


def _strip_prefix(state, prefix):
    """Return the entries of *state* whose key starts with *prefix*, with the prefix removed."""
    cut = len(prefix)
    return {key[cut:]: value for key, value in state.items() if key.startswith(prefix)}


def load_mapvggt_refine(ckpt_path):
    """Reconstruct MapVGGT + RefineUNet from a refine checkpoint (model.*/unet.* keys).

    Args:
        ckpt_path: path of a safetensors file whose tensors are named
            ``model.<param>`` for MapVGGT and ``unet.<param>`` for RefineUNet.

    Returns:
        ``(model, unet)``, both moved to ``DEV`` and switched to eval mode.

    Loading stays non-strict, so a partially matching checkpoint still loads,
    but every mismatch (no keys for a prefix, missing keys, unexpected keys) is
    reported through ``warnings.warn`` instead of being dropped silently.
    """
    import warnings

    model = MapVGGT(with_map=False, with_dyn=False, finetune_backbone=True).to(DEV).eval()
    unet = RefineUNet().to(DEV).eval()
    sd = load_file(ckpt_path)

    for prefix, net in (("model.", model), ("unet.", unet)):
        sub = _strip_prefix(sd, prefix)
        if not sub:
            # Nothing to load: the network keeps whatever its constructor produced.
            warnings.warn(f"{ckpt_path}: no '{prefix}*' tensors found; "
                          f"{type(net).__name__} weights were not loaded from the checkpoint")
        # load_state_dict returns a named tuple (missing_keys, unexpected_keys).
        result = net.load_state_dict(sub, strict=False)
        if result.missing_keys:
            warnings.warn(f"{ckpt_path}: {len(result.missing_keys)} '{prefix}*' key(s) missing "
                          f"from checkpoint, e.g. {result.missing_keys[:5]}")
        if result.unexpected_keys:
            warnings.warn(f"{ckpt_path}: {len(result.unexpected_keys)} unexpected '{prefix}*' "
                          f"key(s) in checkpoint, e.g. {result.unexpected_keys[:5]}")
    return model, unet
|
|
|
|
@torch.no_grad()
def main():
    """Evaluate a refine checkpoint and print mean PSNR/SSIM over the validation clips.

    Steps:
      1. Parse CLI flags (checkpoint path, comma-separated dataset roots, number
         of validation segments, number of input views, render height/width).
      2. Rebuild the networks with ``load_mapvggt_refine``.
      3. Build the ``split="train"`` dataset, derive a segment id per clip, take
         the first ``--val-segs`` ids in sorted order as the validation segments
         and keep the first clip encountered for each of them.
         NOTE(review): whether those segments are really excluded from training
         depends on the training script using the same rule -- confirm there.
      4. For each kept clip: run the model, render with ``render_rda``, refine
         with the UNet and score against ``d["sup_img"]``. Clips whose PSNR is
         NaN or infinite are left out of both averages.

    Raises:
        SystemExit: if no clip produced a finite PSNR, so there is nothing to average.
    """
    import math

    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", default="/mnt/william/runs/mapvggt_refine_best.safetensors")
    ap.add_argument("--roots", default="/mnt/william/data/unified/waymo")
    ap.add_argument("--val-segs", type=int, default=40)
    ap.add_argument("--n-in", type=int, default=8)
    ap.add_argument("--height", type=int, default=256)
    ap.add_argument("--width", type=int, default=448)
    args = ap.parse_args()
    H, W = args.height, args.width

    model, unet = load_mapvggt_refine(args.ckpt)
    cfg = load_config(overrides=["data.name=unified", f"data.root={args.roots}",
                                 f"data.height={H}", f"data.width={W}", "model.tokens.n_map=2048"])
    full = UnifiedClipDataset(cfg, roots=args.roots.split(","), split="train", n_sup_views=6)

    def segid(p):
        """Segment id: first two '_'-separated fields of the clip path's basename."""
        return "_".join(os.path.basename(p.rstrip("/")).split("_")[:2])

    segs = sorted({segid(c) for c in full.clips})
    val_segs = set(segs[:args.val_segs])

    # Keep only the first clip seen for each validation segment.
    seen = set()
    vclips = []
    for clip in full.clips:
        seg = segid(clip)
        if seg in val_segs and seg not in seen:
            seen.add(seg)
            vclips.append(clip)

    # Shallow copy so the original dataset object keeps its own clip list.
    vds = copy.copy(full)
    vds.clips = vclips

    ps, ss = [], []
    for i in range(len(vclips)):
        d = prep(vds[i], args.n_in, DEV)
        g = model(d["in_img"], d["in_K"], d["in_c2w"])
        rgb, dep, al = render_rda(g, d["sup_c2w"], d["sup_K"], H, W)
        ref = unet(rgb, dep, al)
        p = float(psnr(ref, d["sup_img"]))
        if math.isfinite(p):  # drop NaN / +-inf scores
            ps.append(p)
            ss.append(float(ssim(ref, d["sup_img"])))

    if not ps:
        # statistics.mean() would raise an opaque StatisticsError on an empty list.
        raise SystemExit(f"no clip with a finite PSNR among {len(vclips)} validation clip(s) "
                         f"under roots={args.roots!r}; nothing to report")

    mp = st.mean(ps)
    sd = st.pstdev(ps)
    print(f"loaded {os.path.basename(args.ckpt)} | held-out-SCENE val (n={len(ps)}): "
          f"PSNR {mp:.2f}±{sd:.2f} SSIM {st.mean(ss):.3f}")


if __name__ == "__main__":
    main()
|
|