#!/usr/bin/env python3
"""Visualize the trained MapTokenGS.

Produces two figures:

* held-out novel-view reconstruction (GT vs predicted RGB vs predicted depth);
* an out-of-trajectory simulation strip (render the scene from
  laterally-shifted cameras the ego never drove through).

Saves PNGs under runs/ (configurable with ``--out_dir``).
"""
import argparse
import os

import numpy as np
import torch
import matplotlib

matplotlib.use("Agg")  # must be selected before pyplot is imported
import matplotlib.pyplot as plt
from tokengs import options
from safetensors.torch import load_file
from mapgs.config import load_config
from mapgs.data import UnifiedClipDataset
from mapgs.eval.metrics import psnr, ssim
import scripts.finetune_maptokengs as FT

dev = "cuda"


def to_img(t):
    """Convert a [3,H,W] tensor into an HWC uint8 array.

    Values are clamped to [0, 1] before being scaled to 0..255.
    """
    return (t.detach().clamp(0, 1).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)


def depth_img(d):
    """Colorize a [H,W] depth tensor into an HWC uint8 RGB array.

    The display range is the 2nd..98th percentile of the strictly positive
    entries (falling back to 0..1 when there are none); the normalized map is
    clipped to [0, 1] and passed through the ``turbo`` colormap.
    """
    d = d.detach().cpu().numpy()
    valid = d[d > 0]
    if valid.size:
        lo, hi = np.percentile(valid, [2, 98])
    else:
        lo, hi = 0, 1
    # Guard against a degenerate (constant-depth) range.
    dn = np.clip((d - lo) / max(hi - lo, 1e-6), 0, 1)
    return (plt.cm.turbo(dn)[..., :3] * 255).astype(np.uint8)


def _select_best_clip(m, ds, n_scan):
    """Render the first ``n_scan`` clips and return the best-PSNR one.

    Returns:
        ``(psnr, idx, (d, g, grp, inst, rgb, dep))`` for the clip whose
        rendered views have the highest PSNR against ``d["gt"]``.

    Raises:
        RuntimeError: if no clip produced a comparable PSNR (e.g.
            ``n_scan <= 0`` or every PSNR was NaN).
    """
    best_psnr, best_idx, best_state = float("-inf"), None, None
    for i in range(n_scan):
        d = FT.prep(ds[i], 8, 1.0, dev)
        with torch.no_grad():
            g = m.forward_reconstruction_mapped(d["mi"], d["anchors"], d["atype"], d["anormal"])
            grp = m.gaussian_group()[0]
            inst = m.gaussian_instance()[0]
            rgb, dep, _alp = FT.render_dyn(
                m, g, grp, inst,
                d["mi"].decoder.cam_view, d["mi"].decoder.intrinsics,
                d["sup_frame"], d["dyn"], gain=1.0,
            )
            p = float(psnr(rgb.clamp(0, 1), d["gt"]))
        # NaN never compares greater, so NaN clips are skipped.
        if p > best_psnr:
            best_psnr, best_idx, best_state = p, i, (d, g, grp, inst, rgb, dep)
    if best_state is None:
        raise RuntimeError(
            f"no usable clip among the first {n_scan} (need --scan >= 1 and a finite PSNR)"
        )
    return best_psnr, best_idx, best_state


def _save_nvs_panel(d, rgb, dep, src, i, p, out_path):
    """Save the held-out NVS grid (rows = views, cols = GT | Pred | Depth)."""
    gt = d["gt"]
    V = gt.shape[0]
    # squeeze=False keeps ``ax`` two-dimensional even when there is one view.
    fig, ax = plt.subplots(V, 3, figsize=(3 * 3.4, V * 1.9), squeeze=False)
    for v in range(V):
        pv = float(psnr(rgb[v:v + 1].clamp(0, 1), gt[v:v + 1]))
        ax[v, 0].imshow(to_img(gt[v]))
        ax[v, 0].set_ylabel(f"view {v}", fontsize=9)
        ax[v, 1].imshow(to_img(rgb[v]))
        ax[v, 1].set_title(f"PSNR {pv:.1f}", fontsize=8)
        ax[v, 2].imshow(depth_img(dep[v]))
        for c in range(3):
            ax[v, c].set_xticks([])
            ax[v, c].set_yticks([])
    headers = ["GT (held-out)", "MapTokenGS pred", "pred depth"]
    # Column 0 also carries a regular axes title, in addition to the bold
    # header text placed above every column.
    ax[0, 0].set_title(headers[0], fontsize=10)
    for c, t in enumerate(headers):
        ax[0, c].text(0.5, 1.25, t, transform=ax[0, c].transAxes, ha="center",
                      fontsize=11, weight="bold")
    fig.suptitle(f"Held-out novel-view reconstruction — {src} clip {i} — mean PSNR {p:.2f}",
                 y=0.995, fontsize=12)
    fig.tight_layout(rect=[0, 0, 1, 0.97])
    fig.savefig(out_path, dpi=110, bbox_inches="tight")
    plt.close(fig)
    print("saved", out_path, flush=True)


def _save_oot_panel(m, d, g, grp, inst, src, i, out_path):
    """Save the out-of-trajectory strip (lateral shifts of the reference pose)."""
    from mapgs.losses.extrap import perturb_pose

    refK = d["ref_K"]
    intr = torch.stack([refK[0, 0], refK[1, 1], refK[0, 2], refK[1, 2]])[None, None].to(dev)

    # The gaussians do not depend on the camera shift, so they are prepared
    # once instead of once per rendered view.
    gg = g[0]
    if d["dyn"] is not None:
        from maptokengs import place_dynamics_14d

        dyn = d["dyn"]
        f0 = int(d["sup_frame"][0])
        gg = place_dynamics_14d(g[0], grp, inst, f0, dyn["c"], dyn["R"], dyn["canon"],
                                dyn["valid"], dyn["radius"], gain=1.0)

    shifts = [-0.6, -0.3, 0.0, 0.3, 0.6]
    fig2, ax2 = plt.subplots(1, len(shifts), figsize=(len(shifts) * 3.0, 2.6))
    for k, s in enumerate(shifts):
        P = perturb_pose(d["ref_c2w_norm"], lateral=s, yaw_deg=0.0)
        cv = torch.inverse(P).transpose(0, 1)[None, None].to(dev)
        with torch.no_grad():
            o = m.render_gaussians(gg[None], cv, intr)
        ax2[k].imshow(to_img(o["images_pred"][0, 0].clamp(0, 1)))
        ax2[k].set_xticks([])
        ax2[k].set_yticks([])
        ax2[k].set_title(("ego path" if s == 0 else f"{'+' if s>0 else ''}{s} lat"),
                         fontsize=10, weight=("bold" if s == 0 else "normal"))
    fig2.suptitle(
        f"Out-of-trajectory simulation — render from cameras off the driven path ({src} clip {i})",
        fontsize=11)
    fig2.tight_layout(rect=[0, 0, 1, 0.93])
    fig2.savefig(out_path, dpi=110, bbox_inches="tight")
    plt.close(fig2)
    print("saved", out_path, flush=True)


def main():
    """Load the checkpoint, pick a clip and write ``viz_nvs.png`` / ``viz_oot.png``."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", default="/mnt/william/runs/maptokengs_highres_best.safetensors")
    ap.add_argument("--scan", type=int, default=20)
    ap.add_argument("--s", type=float, default=1.5)
    ap.add_argument("--H", type=int, default=448)
    ap.add_argument("--W", type=int, default=768)
    ap.add_argument("--out_dir", default="/mnt/william/runs",
                    help="directory the PNGs are written to")
    args = ap.parse_args()
    H, W = args.H, args.W
    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)

    opt = options.config_defaults["train_dl3dv_base"]
    opt = opt() if callable(opt) else opt
    opt.img_size = (H, W)
    m = FT.MapTokenGS(opt, s0=0.5, s_max=2.0, max_instances=8, dyn_per_instance=32).to(dev).eval()
    m.load_state_dict(load_file(args.ckpt), strict=False)
    m.cur_s = args.s

    roots = ["/mnt/william/data/unified/av2", "/mnt/william/data/unified/waymo"]
    cfg = load_config(overrides=["data.name=unified",
                                 f"data.root={','.join(roots)}",
                                 f"data.height={H}", f"data.width={W}",
                                 "model.tokens.n_map=2048", "data.max_instances=8"])
    ds = UnifiedClipDataset(cfg, roots=roots, split="train", n_sup_views=4)

    # pick the best-PSNR clip among the first --scan (a clean, representative example)
    p, i, (d, g, grp, inst, rgb, dep) = _select_best_clip(m, ds, args.scan)
    src = "waymo" if "waymo" in ds.clips[i] else "av2"
    print(f"chosen clip idx {i} ({src}) | held-out PSNR {p:.2f} SSIM {float(ssim(rgb.clamp(0,1), d['gt'])):.3f}", flush=True)

    # ---- panel 1: held-out NVS grid (rows = views, cols = GT | Pred | Depth) ----
    _save_nvs_panel(d, rgb, dep, src, i, p, os.path.join(args.out_dir, "viz_nvs.png"))

    # ---- panel 2: out-of-trajectory simulation (lateral shifts from input view 0) ----
    _save_oot_panel(m, d, g, grp, inst, src, i, os.path.join(args.out_dir, "viz_oot.png"))


if __name__ == "__main__":
    main()