mapvggt / scripts /probe_head.py
ChenmingWu's picture
Upload folder using huggingface_hub
8cf92b3 verified
Raw
History Blame Contribute Delete
3.87 kB
#!/usr/bin/env python3
"""Probe: how much does the TRAINED per-pixel head contribute over a DEFAULT (untrained,
zero-residual) head? Renders the same val clips with (a) the trained head and (b) a freshly
re-initialized head (the constructor init: opacity logit 2.0, log-scale 0 -> default footprint, identity quaternion).
If (a)~(b), the learnable head adds ~nothing and PSNR is set by VGGT-depth + source color +
default footprint -> the bottleneck is the representation/backbone, not the head's optimization.
"""
import argparse, os, copy, sys
import numpy as np
import torch
from mapgs.config import load_config
from mapgs.data import UnifiedClipDataset
from mapgs.eval.metrics import psnr, ssim
from mapvggt import MapVGGT
# Device for the model and for every batch produced by `prep` (hard-coded: this probe
# assumes a CUDA GPU is available).
DEV = "cuda"
# Make the repo root importable so the training script's helpers can be reused.
# This must run before the `scripts.` import below.
sys.path.insert(0, "/mnt/william")
# `prep` / `render_scene` come from the training script, presumably so this probe
# prepares and renders clips the same way training evaluation does.
from scripts.train_mapvggt_full import prep, render_scene
def build_val(args):
    """Build a validation dataset holding one clip per held-out segment.

    Sorts all segment IDs and treats the first ``args.val_segs`` of them as
    validation segments. For each of those segments, only the first clip
    encountered in the full dataset (in dataset order) is kept.

    Args:
        args: parsed CLI namespace. Reads ``roots`` (comma-separated),
            ``height``, ``width`` and ``val_segs``.

    Returns:
        A shallow copy of the full ``UnifiedClipDataset`` whose ``clips`` list
        is replaced by the selected validation clips.
    """
    overrides = [
        "data.name=unified",
        f"data.root={args.roots}",
        f"data.height={args.height}",
        f"data.width={args.width}",
        "model.tokens.n_map=2048",
    ]
    cfg = load_config(overrides=overrides)
    full = UnifiedClipDataset(cfg, roots=args.roots.split(","), split="train", n_sup_views=6)

    def segid(path):
        # Segment ID = first two "_"-separated fields of the clip directory name.
        name = os.path.basename(path.rstrip("/"))
        return "_".join(name.split("_")[:2])

    # First clip seen for each segment; dicts preserve insertion (= dataset) order.
    first_clip = {}
    for clip in full.clips:
        first_clip.setdefault(segid(clip), clip)

    held_out = set(sorted(first_clip)[:args.val_segs])
    val_clips = [clip for seg, clip in first_clip.items() if seg in held_out]

    # Shallow copy: shares everything with `full` except the rebound clip list.
    val_ds = copy.copy(full)
    val_ds.clips = val_clips
    return val_ds
@torch.no_grad()
def eval_model(model, vds, n_in):
    """Render every clip in ``vds`` and score it against its supervision views.

    Args:
        model: the ``MapVGGT`` under test. It is switched to eval mode and its
            ``cur_s`` is set to ``s_max``; both changes persist after return.
        vds: dataset whose ``clips`` list defines which items are evaluated.
        n_in: number of input views passed to ``prep``.

    Returns:
        ``(mean_psnr, mean_ssim, std_psnr)`` over clips whose PSNR is finite.
        Clips with NaN/inf PSNR are excluded from all three statistics, and
        their indices are printed. Two calls may therefore average over
        different clip subsets; the printed warning makes that visible. All
        three values are NaN if no clip has a finite PSNR.
    """
    model.eval(); model.cur_s = model.s_max
    # Redundant if `vggt` is a registered submodule; kept as an explicit safeguard.
    model.vggt.eval()
    ps, ss, skipped = [], [], []
    for i in range(len(vds.clips)):
        d = prep(vds[i], n_in, DEV)
        gsm = model(d["in_img"], d["in_K"], d["in_c2w"], d["ap"], d["at"], d["an"])
        H, W = d["sup_img"].shape[-2:]
        rgb, _ = render_scene(model, gsm, d, H, W, gain=1.0)
        p = float(psnr(rgb, d["sup_img"]))
        if not np.isfinite(p):
            # Degenerate render/target pair: exclude it, but record it.
            skipped.append(i)
            continue
        ps.append(p)
        ss.append(float(ssim(rgb, d["sup_img"])))
    if skipped:
        print(f"  warning: {len(skipped)}/{len(vds.clips)} clips had non-finite PSNR "
              f"and were excluded: {skipped}", flush=True)
    if not ps:
        # Avoid np.mean([]) RuntimeWarnings; the result is NaN either way.
        nan = float("nan")
        return nan, nan, nan
    return np.mean(ps), np.mean(ss), np.std(ps)
def main():
    """Compare the trained per-pixel static head against a default-initialized one.

    Steps:
      1. Build the validation clip set (``build_val``) and keep the first
         ``--clips`` clips.
      2. Load ``--ckpt`` into a ``MapVGGT`` and evaluate it (trained head).
      3. Reset ``model.head`` in place to a zero-residual default (map/dyn heads
         keep their trained weights) and evaluate again.
      4. Print both PSNR/SSIM results and their difference.

    Raises:
        RuntimeError: if loading the checkpoint leaves any ``head.*`` parameter
            missing. With ``strict=False`` that mismatch would otherwise be
            silent, and the "trained" run would not use trained head weights,
            making the comparison meaningless.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--roots", default="/mnt/william/data/unified/waymo")
    ap.add_argument("--ckpt", default="/mnt/william/runs/abl_full_best.safetensors")
    ap.add_argument("--n-in", type=int, default=8)
    ap.add_argument("--height", type=int, default=256)
    ap.add_argument("--width", type=int, default=448)
    ap.add_argument("--val-segs", type=int, default=40)
    ap.add_argument("--clips", type=int, default=12)
    args = ap.parse_args()
    from safetensors.torch import load_file
    sd = load_file(args.ckpt)
    vds = build_val(args)
    vds.clips = vds.clips[:args.clips]
    print(f"probing {len(vds.clips)} val clips", flush=True)
    # (a) trained head
    model = MapVGGT(with_map=True, with_dyn=True, finetune_backbone=False).to(DEV)
    # strict=False tolerates keys absent from the checkpoint (presumably e.g. the
    # frozen backbone). The head under test, however, must actually be loaded.
    load_res = model.load_state_dict(sd, strict=False)
    missing_head = [k for k in load_res.missing_keys if k.startswith("head.")]
    if missing_head:
        raise RuntimeError(
            f"checkpoint {args.ckpt} does not provide the trained head weights "
            f"(missing: {missing_head}); the probe would not measure a trained head")
    print(f"loaded {args.ckpt}: {len(load_res.missing_keys)} missing, "
          f"{len(load_res.unexpected_keys)} unexpected keys (strict=False)", flush=True)
    tp, ts, tsd = eval_model(model, vds, args.n_in)
    print(f"TRAINED head: PSNR {tp:.2f}±{tsd:.2f} SSIM {ts:.3f}", flush=True)
    # (b) re-init the static head to its constructor init (zero residual, default footprint),
    # keep map/dyn heads trained (we want to isolate the per-pixel static head)
    import torch.nn as nn
    h = model.head
    # Once h[-1].weight is zero, the head's output is just h[-1].bias (assuming h[-1]
    # is an affine layer such as Linear/Conv). Zeroing the hidden layers h[0]/h[2]
    # therefore does not affect the output.
    nn.init.zeros_(h[0].weight); nn.init.zeros_(h[0].bias)
    nn.init.zeros_(h[2].weight); nn.init.zeros_(h[2].bias)
    nn.init.zeros_(h[-1].weight); nn.init.zeros_(h[-1].bias)
    # Default bias: channel 0 = opacity logit 2.0; channel 4 = 1.0, presumably the
    # quaternion w component (identity rotation) per the module docstring. The channel
    # layout itself is defined in MapVGGT and not visible here.
    h[-1].bias.data[0] = 2.0
    h[-1].bias.data[4] = 1.0
    up, us, usd = eval_model(model, vds, args.n_in)
    print(f"DEFAULT head: PSNR {up:.2f}±{usd:.2f} SSIM {us:.3f}", flush=True)
    print(f"\nHead's learned contribution: PSNR {tp-up:+.2f} dB SSIM {ts-us:+.3f}", flush=True)
if __name__ == "__main__":
    main()