mapvggt / scripts /visualize_maptokengs.py
ChenmingWu's picture
Upload folder using huggingface_hub
8cf92b3 verified
Raw
History Blame Contribute Delete
5.94 kB
#!/usr/bin/env python3
"""Visualize the trained MapTokenGS: held-out novel-view reconstruction (GT vs
predicted RGB vs predicted depth) and an out-of-trajectory simulation strip
(render the scene from laterally-shifted cameras the ego never drove through).
Saves PNGs under runs/."""
import argparse, numpy as np, torch
import matplotlib; matplotlib.use("Agg")
import matplotlib.pyplot as plt
from tokengs import options
from safetensors.torch import load_file
from mapgs.config import load_config
from mapgs.data import UnifiedClipDataset
from mapgs.eval.metrics import psnr, ssim
import scripts.finetune_maptokengs as FT
# Device string used for the model (`.to(dev)`) and passed to `FT.prep`;
# hard-coded to CUDA, so the script needs a CUDA-capable GPU to run.
dev = "cuda"
def to_img(t):  # [3,H,W] -> HWC uint8
    """Convert a channel-first image tensor into an HWC ``uint8`` array.

    Values are clamped to [0, 1], scaled by 255 and truncated (not rounded)
    when cast to ``uint8``.

    Args:
        t: image tensor laid out as [C, H, W] (the axes are permuted 1, 2, 0).

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` laid out as [H, W, C].
    """
    # detach() so tensors that still track gradients can be converted;
    # Tensor.numpy() raises on tensors that require grad.
    hwc = t.detach().clamp(0, 1).permute(1, 2, 0).cpu().numpy()
    return (hwc * 255).astype(np.uint8)
def depth_img(d):  # [H,W] -> colored HWC
    """Colorize a depth map with the turbo colormap.

    The depth is normalized to [0, 1] between the 2nd and 98th percentile of
    its valid pixels, then mapped through ``plt.cm.turbo``; the alpha channel
    of the colormap output is dropped.

    Args:
        d: depth tensor; indexed as a 2-D [H, W] map by the caller.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` with a trailing RGB axis.
    """
    d = d.detach().cpu().numpy()
    # Only finite, strictly positive depths define the color range: a +inf
    # depth would otherwise pass the `> 0` test and corrupt the percentiles.
    valid = np.isfinite(d) & (d > 0)
    if valid.any():
        lo, hi = np.percentile(d[valid], [2, 98])
    else:
        # No usable depth at all: fall back to a unit range.
        lo, hi = 0, 1
    # Guard against a zero-width range (constant depth) before dividing.
    dn = np.clip((d - lo) / max(hi - lo, 1e-6), 0, 1)
    return (plt.cm.turbo(dn)[..., :3] * 255).astype(np.uint8)
def main():
    """Render and save the two MapTokenGS visualization figures.

    Steps:
      1. Parse CLI flags and build the model from ``--ckpt`` (loaded with
         ``strict=False``, so mismatched keys are not an error).
      2. Build the unified dataset and scan its first ``--scan`` items,
         keeping the one with the highest PSNR against ``d["gt"]``.
      3. Save ``viz_nvs.png``: one row per view, columns GT | pred | depth.
      4. Save ``viz_oot.png``: the scene rendered from laterally shifted
         copies of the reference camera pose.

    Both PNGs are written under ``--out-dir`` (default ``/mnt/william/runs``).

    Raises:
        RuntimeError: if no scanned clip yields a comparable (non-NaN) PSNR.
    """
    import os  # local import: only needed for the output paths below

    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", default="/mnt/william/runs/maptokengs_highres_best.safetensors")
    ap.add_argument("--scan", type=int, default=20)
    ap.add_argument("--s", type=float, default=1.5)
    ap.add_argument("--H", type=int, default=448)
    ap.add_argument("--W", type=int, default=768)
    ap.add_argument("--out-dir", default="/mnt/william/runs",
                    help="directory the PNGs are written to")
    args = ap.parse_args()
    if args.scan < 1:
        # Without at least one candidate there is nothing to visualize.
        ap.error("--scan must be >= 1")
    H, W = args.H, args.W
    os.makedirs(args.out_dir, exist_ok=True)

    # ---- model ----
    opt = options.config_defaults["train_dl3dv_base"]
    if callable(opt):
        opt = opt()
    opt.img_size = (H, W)
    m = FT.MapTokenGS(opt, s0=0.5, s_max=2.0, max_instances=8, dyn_per_instance=32).to(dev).eval()
    m.load_state_dict(load_file(args.ckpt), strict=False)
    m.cur_s = args.s

    # ---- data (roots listed once so the config override and the dataset agree) ----
    roots = ["/mnt/william/data/unified/av2", "/mnt/william/data/unified/waymo"]
    cfg = load_config(overrides=["data.name=unified",
                                 "data.root=" + ",".join(roots),
                                 f"data.height={H}", f"data.width={W}", "model.tokens.n_map=2048",
                                 "data.max_instances=8"])
    ds = UnifiedClipDataset(cfg, roots=roots, split="train", n_sup_views=4)

    # pick the best-PSNR clip among the first --scan (a clean, representative example)
    best_psnr, best_idx, best_state = float("-inf"), None, None
    for i in range(args.scan):
        d = FT.prep(ds[i], 8, 1.0, dev)
        with torch.no_grad():
            g = m.forward_reconstruction_mapped(d["mi"], d["anchors"], d["atype"], d["anormal"])
            grp = m.gaussian_group()[0]
            inst = m.gaussian_instance()[0]
            rgb, dep, _ = FT.render_dyn(m, g, grp, inst, d["mi"].decoder.cam_view,
                                        d["mi"].decoder.intrinsics, d["sup_frame"], d["dyn"], gain=1.0)
        p = float(psnr(rgb.clamp(0, 1), d["gt"]))
        # A NaN PSNR compares False here and is therefore never selected.
        if p > best_psnr:
            best_psnr, best_idx, best_state = p, i, (d, g, grp, inst, rgb, dep)
    if best_state is None:
        raise RuntimeError(
            f"none of the first {args.scan} clips produced a comparable (non-NaN) PSNR")
    p, i = best_psnr, best_idx
    d, g, grp, inst, rgb, dep = best_state
    src = "waymo" if "waymo" in ds.clips[i] else "av2"
    print(f"chosen clip idx {i} ({src}) | held-out PSNR {p:.2f} SSIM {float(ssim(rgb.clamp(0,1), d['gt'])):.3f}", flush=True)

    # ---- panel 1: held-out NVS grid (rows = views, cols = GT | Pred | Depth) ----
    gt = d["gt"]
    V = gt.shape[0]
    # squeeze=False keeps `ax` two-dimensional even when there is a single view.
    fig, ax = plt.subplots(V, 3, figsize=(3 * 3.4, V * 1.9), squeeze=False)
    for v in range(V):
        pv = float(psnr(rgb[v:v+1].clamp(0, 1), gt[v:v+1]))
        ax[v, 0].imshow(to_img(gt[v]))
        ax[v, 0].set_ylabel(f"view {v}", fontsize=9)
        ax[v, 1].imshow(to_img(rgb[v]))
        ax[v, 1].set_title(f"PSNR {pv:.1f}", fontsize=8)
        ax[v, 2].imshow(depth_img(dep[v]))
        for c in range(3):
            ax[v, c].set_xticks([])
            ax[v, c].set_yticks([])
    # One bold header per column, placed above the per-view PSNR titles.
    for c, t in enumerate(["GT (held-out)", "MapTokenGS pred", "pred depth"]):
        ax[0, c].text(0.5, 1.25, t, transform=ax[0, c].transAxes, ha="center", fontsize=11, weight="bold")
    fig.suptitle(f"Held-out novel-view reconstruction — {src} clip {i} — mean PSNR {p:.2f}", y=0.995, fontsize=12)
    fig.tight_layout(rect=[0, 0, 1, 0.97])
    out1 = os.path.join(args.out_dir, "viz_nvs.png")
    fig.savefig(out1, dpi=110, bbox_inches="tight")
    plt.close(fig)
    print("saved", out1, flush=True)

    # ---- panel 2: out-of-trajectory simulation (lateral shifts from input view 0) ----
    from mapgs.losses.extrap import perturb_pose
    dyn = d["dyn"]
    if dyn is not None:
        # Imported once here rather than on every loop iteration.
        from maptokengs import place_dynamics_14d
    refK = d["ref_K"]
    # Packs K[0,0], K[1,1], K[0,2], K[1,2] -- presumably (fx, fy, cx, cy) of a
    # pinhole intrinsics matrix; confirm against the renderer's convention.
    intr = torch.stack([refK[0, 0], refK[1, 1], refK[0, 2], refK[1, 2]])[None, None].to(dev)
    f0 = int(d["sup_frame"][0]) if dyn is not None else 0
    shifts = [-0.6, -0.3, 0.0, 0.3, 0.6]
    fig2, ax2 = plt.subplots(1, len(shifts), figsize=(len(shifts) * 3.0, 2.6))
    for k, s in enumerate(shifts):
        P = perturb_pose(d["ref_c2w_norm"], lateral=s, yaw_deg=0.0)
        # Inverted and transposed pose, with two leading singleton axes added.
        cv = torch.inverse(P).transpose(0, 1)[None, None].to(dev)
        gg = g[0]
        if dyn is not None:
            gg = place_dynamics_14d(g[0], grp, inst, f0, dyn["c"], dyn["R"],
                                    dyn["canon"], dyn["valid"], dyn["radius"], gain=1.0)
        with torch.no_grad():
            o = m.render_gaussians(gg[None], cv, intr)
        ax2[k].imshow(to_img(o["images_pred"][0, 0].clamp(0, 1)))
        ax2[k].set_xticks([])
        ax2[k].set_yticks([])
        ax2[k].set_title(("ego path" if s == 0 else f"{'+' if s>0 else ''}{s} lat"), fontsize=10,
                         weight=("bold" if s == 0 else "normal"))
    fig2.suptitle(f"Out-of-trajectory simulation — render from cameras off the driven path ({src} clip {i})", fontsize=11)
    fig2.tight_layout(rect=[0, 0, 1, 0.93])
    out2 = os.path.join(args.out_dir, "viz_oot.png")
    fig2.savefig(out2, dpi=110, bbox_inches="tight")
    plt.close(fig2)
    print("saved", out2, flush=True)
# Run the visualization only when executed as a script, not when imported.
if __name__ == "__main__":
    main()