#!/usr/bin/env python3
"""Canonical loader + evaluator for a released MapVGGT (+RefineUNet) checkpoint.

Reconstructs the model, loads the ``model.*`` / ``unet.*`` state dict, and
reports held-out-SCENE PSNR/SSIM on a segment-disjoint Waymo val split. This is
the reference inference path for the released weights (verifies they
round-trip).

Env: VGGT_OMEGA_REPO (vggt-omega clone), MAPVGGT_VGGT_CKPT (base VGGT-Omega weights).
"""
import argparse
import copy
import math
import os
import statistics as st
import sys

import torch
from safetensors.torch import load_file

from mapgs.config import load_config
from mapgs.data import UnifiedClipDataset
from mapgs.eval.metrics import psnr, ssim
from mapvggt import MapVGGT
from mapvggt.refine import RefineUNet
from scripts.train_mapvggt_refine import render_rda
from scripts.train_mapvggt_full import prep

DEV = "cuda"


def _load_prefixed(module, state, prefix, label):
    """Load the ``prefix``-keyed entries of ``state`` into ``module`` (non-strict).

    Keys are matched by ``prefix`` and the prefix is stripped before loading.
    Loading stays non-strict (as in the original script), but every mismatch is
    reported on stderr so a checkpoint that does not actually match the
    architecture cannot pass unnoticed.
    """
    sub_state = {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}
    if not sub_state:
        print(f"warning: checkpoint has no '{prefix}*' keys; {label} keeps the "
              f"weights its constructor produced", file=sys.stderr)
    missing, unexpected = module.load_state_dict(sub_state, strict=False)
    if missing:
        print(f"warning: {label}: {len(missing)} key(s) missing from checkpoint, "
              f"e.g. {list(missing)[:5]}", file=sys.stderr)
    if unexpected:
        print(f"warning: {label}: {len(unexpected)} unexpected key(s) in checkpoint, "
              f"e.g. {list(unexpected)[:5]}", file=sys.stderr)


def load_mapvggt_refine(ckpt_path):
    """Reconstruct MapVGGT + RefineUNet from a refine checkpoint (model.*/unet.* keys).

    Args:
        ckpt_path: Path to a safetensors file whose tensors are keyed
            ``model.<name>`` (MapVGGT) and ``unet.<name>`` (RefineUNet).

    Returns:
        ``(model, unet)``, both moved to ``DEV`` and put in eval mode.
    """
    model = MapVGGT(with_map=False, with_dyn=False, finetune_backbone=True).to(DEV).eval()
    unet = RefineUNet().to(DEV).eval()
    state = load_file(ckpt_path)
    _load_prefixed(model, state, "model.", "MapVGGT")
    _load_prefixed(unet, state, "unet.", "RefineUNet")
    return model, unet


def _segment_id(clip_path):
    """Return the segment id of a clip: the first two '_'-separated tokens of its basename."""
    return "_".join(os.path.basename(clip_path.rstrip("/")).split("_")[:2])


@torch.no_grad()
def main():
    """Evaluate a checkpoint on the held-out segments and print mean PSNR/SSIM.

    The val split is the first ``--val-segs`` segment ids (sorted), with exactly
    one clip (the first one encountered) kept per segment. Clips whose PSNR is
    NaN or infinite are skipped for both metrics.

    Raises:
        SystemExit: if no clip produced a finite PSNR, so no statistic can be
            reported.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", default="/mnt/william/runs/mapvggt_refine_best.safetensors")
    ap.add_argument("--roots", default="/mnt/william/data/unified/waymo")
    ap.add_argument("--val-segs", type=int, default=40)
    ap.add_argument("--n-in", type=int, default=8)
    ap.add_argument("--height", type=int, default=256)
    ap.add_argument("--width", type=int, default=448)
    args = ap.parse_args()
    H, W = args.height, args.width

    model, unet = load_mapvggt_refine(args.ckpt)
    cfg = load_config(overrides=["data.name=unified", f"data.root={args.roots}",
                                 f"data.height={H}", f"data.width={W}",
                                 "model.tokens.n_map=2048"])
    full = UnifiedClipDataset(cfg, roots=args.roots.split(","), split="train", n_sup_views=6)

    # Held-out segments: the first N segment ids in sorted order.
    segs = sorted({_segment_id(c) for c in full.clips})
    val_segs = set(segs[:args.val_segs])

    # Keep only the first clip of each held-out segment.
    seen = set()
    val_clips = []
    for clip in full.clips:
        seg = _segment_id(clip)
        if seg in val_segs and seg not in seen:
            seen.add(seg)
            val_clips.append(clip)

    # Shallow copy so the full dataset's own clip list is left untouched.
    vds = copy.copy(full)
    vds.clips = val_clips

    psnrs, ssims = [], []
    # Index through the dataset (not the clip list): items come from vds[i].
    for i in range(len(vds.clips)):
        d = prep(vds[i], args.n_in, DEV)
        g = model(d["in_img"], d["in_K"], d["in_c2w"])
        rgb, dep, al = render_rda(g, d["sup_c2w"], d["sup_K"], H, W)
        ref = unet(rgb, dep, al)
        p = float(psnr(ref, d["sup_img"]))
        if math.isfinite(p):  # drop NaN / +-inf so they cannot poison the mean
            psnrs.append(p)
            ssims.append(float(ssim(ref, d["sup_img"])))

    if not psnrs:
        # statistics.mean() would raise an opaque StatisticsError here.
        raise SystemExit(
            f"no finite PSNR from {len(vds.clips)} val clip(s) "
            f"({len(val_segs)} segment(s) selected); nothing to report")

    mean_psnr = st.mean(psnrs)
    std_psnr = st.pstdev(psnrs)
    print(f"loaded {os.path.basename(args.ckpt)} | held-out-SCENE val (n={len(psnrs)}): "
          f"PSNR {mean_psnr:.2f}±{std_psnr:.2f} SSIM {st.mean(ssims):.3f}")


if __name__ == "__main__":
    main()