| |
| """Visualize the trained MapTokenGS: held-out novel-view reconstruction (GT vs |
| predicted RGB vs predicted depth) and an out-of-trajectory simulation strip |
| (render the scene from laterally-shifted cameras the ego never drove through). |
| Saves PNGs under runs/.""" |
| import argparse, numpy as np, torch |
| import matplotlib; matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
|
|
| from tokengs import options |
| from safetensors.torch import load_file |
| from mapgs.config import load_config |
| from mapgs.data import UnifiedClipDataset |
| from mapgs.eval.metrics import psnr, ssim |
| import scripts.finetune_maptokengs as FT |
|
|
# Device string used for the model and for every tensor this script moves or
# creates; it is hard-coded, so the script expects a CUDA device to be present.
dev = "cuda"
|
|
|
|
def to_img(t):
    """Convert a channel-first image tensor into a uint8 array for ``imshow``.

    Args:
        t: torch tensor indexed as (channel, row, column). Values outside
            [0, 1] are clamped before scaling.

    Returns:
        ``np.uint8`` array indexed as (row, column, channel) with values in
        0..255. The integer cast truncates, so 0.5 maps to 127.
    """
    # detach() first: Tensor.numpy() raises for tensors that require grad.
    chw = t.detach().clamp(0, 1)
    hwc = chw.permute(1, 2, 0).cpu().numpy()
    return (hwc * 255).astype(np.uint8)
|
|
|
|
def depth_img(d):
    """Colourise a depth tensor with matplotlib's turbo colormap.

    The colour range spans the 2nd to 98th percentile of the strictly
    positive, finite depths, so a few extreme values do not wash out the
    image. When no such value exists the range falls back to [0, 1].

    Args:
        d: torch tensor of depths. Non-positive and non-finite entries are
            ignored when estimating the range but are still colourised.

    Returns:
        ``np.uint8`` array with the shape of ``d`` plus a trailing RGB axis.
    """
    # detach() first: Tensor.numpy() raises for tensors that require grad.
    depth = d.detach().cpu().numpy()
    # Build the mask once; +inf would otherwise drag the upper percentile to
    # infinity and collapse every finite depth onto the lowest colour.
    usable = depth[np.isfinite(depth) & (depth > 0)]
    if usable.size:
        lo, hi = np.percentile(usable, [2, 98])
    else:
        lo, hi = 0, 1
    # Floor the span so a constant depth map does not divide by zero.
    span = max(hi - lo, 1e-6)
    normed = np.clip((depth - lo) / span, 0, 1)
    # The colormap returns RGBA floats in [0, 1]; drop alpha and scale.
    return (plt.cm.turbo(normed)[..., :3] * 255).astype(np.uint8)
|
|
|
|
def _select_best_clip(m, ds, n_scan):
    """Score up to ``n_scan`` leading clips and return the highest-PSNR one.

    Returns:
        ``(psnr, index, (d, g, grp, inst, rgb, dep))`` for the winning clip,
        where ``d`` is the prepared sample and the rest are the model outputs
        and renders produced for it.

    Raises:
        RuntimeError: if no clip yielded a usable (non-NaN) PSNR, e.g. when
            ``n_scan`` is 0 or the dataset is empty.
    """
    best = None
    for i in range(n_scan):
        try:
            sample = ds[i]
        except IndexError:
            # Asked to scan more clips than the dataset holds.
            print(f"dataset exhausted after {i} clips", flush=True)
            break
        d = FT.prep(sample, 8, 1.0, dev)
        with torch.no_grad():
            g = m.forward_reconstruction_mapped(d["mi"], d["anchors"], d["atype"], d["anormal"])
            grp = m.gaussian_group()[0]
            inst = m.gaussian_instance()[0]
            rgb, dep, _alpha = FT.render_dyn(
                m, g, grp, inst, d["mi"].decoder.cam_view,
                d["mi"].decoder.intrinsics, d["sup_frame"], d["dyn"], gain=1.0)
        p = float(psnr(rgb.clamp(0, 1), d["gt"]))
        if np.isnan(p):
            # A NaN score can never win a comparison; skip it explicitly.
            continue
        if best is None or p > best[0]:
            best = (p, i, (d, g, grp, inst, rgb, dep))
    if best is None:
        raise RuntimeError("no clip produced a valid PSNR; check --scan and the dataset roots")
    return best


def _save_nvs_figure(d, rgb, dep, p, i, src):
    """Save the per-view grid: ground truth, predicted RGB, predicted depth."""
    gt = d["gt"]
    V = gt.shape[0]
    # squeeze=False keeps ``ax`` two-dimensional even when there is one view.
    fig, ax = plt.subplots(V, 3, figsize=(3 * 3.4, V * 1.9), squeeze=False)
    for v in range(V):
        pv = float(psnr(rgb[v:v+1].clamp(0, 1), gt[v:v+1]))
        ax[v, 0].imshow(to_img(gt[v]))
        ax[v, 0].set_ylabel(f"view {v}", fontsize=9)
        ax[v, 1].imshow(to_img(rgb[v]))
        ax[v, 1].set_title(f"PSNR {pv:.1f}", fontsize=8)
        ax[v, 2].imshow(depth_img(dep[v]))
        for cell in ax[v]:
            cell.set_xticks([])
            cell.set_yticks([])
    headers = ["GT (held-out)", "MapTokenGS pred", "pred depth"]
    # Column 0 also carries its label as a regular axes title, exactly as the
    # previous layout did; every column gets the bold header above that.
    ax[0, 0].set_title(headers[0], fontsize=10)
    for c, t in enumerate(headers):
        ax[0, c].text(0.5, 1.25, t, transform=ax[0, c].transAxes, ha="center", fontsize=11, weight="bold")
    fig.suptitle(f"Held-out novel-view reconstruction — {src} clip {i} — mean PSNR {p:.2f}", y=0.995, fontsize=12)
    fig.tight_layout(rect=[0, 0, 1, 0.97])
    out1 = "/mnt/william/runs/viz_nvs.png"
    fig.savefig(out1, dpi=110, bbox_inches="tight")
    plt.close(fig)
    print("saved", out1, flush=True)


def _save_oot_figure(m, d, g, grp, inst, i, src):
    """Save a strip of renders from laterally shifted copies of the reference pose."""
    from mapgs.losses.extrap import perturb_pose

    refK = d["ref_K"]
    # Pack the four intrinsics entries (two diagonal, two from the last column).
    intr = torch.stack([refK[0, 0], refK[1, 1], refK[0, 2], refK[1, 2]])[None, None].to(dev)

    # The gaussians to render do not depend on the camera shift, so build them
    # once instead of once per shift.
    # NOTE(review): assumes place_dynamics_14d is deterministic — confirm.
    gg = g[0]
    if d["dyn"] is not None:
        from maptokengs import place_dynamics_14d
        f0 = int(d["sup_frame"][0])
        gg = place_dynamics_14d(g[0], grp, inst, f0, d["dyn"]["c"], d["dyn"]["R"],
                                d["dyn"]["canon"], d["dyn"]["valid"], d["dyn"]["radius"], gain=1.0)

    shifts = [-0.6, -0.3, 0.0, 0.3, 0.6]
    fig2, ax2 = plt.subplots(1, len(shifts), figsize=(len(shifts) * 3.0, 2.6))
    for k, s in enumerate(shifts):
        P = perturb_pose(d["ref_c2w_norm"], lateral=s, yaw_deg=0.0)
        cv = torch.inverse(P).transpose(0, 1)[None, None].to(dev)
        with torch.no_grad():
            o = m.render_gaussians(gg[None], cv, intr)
        ax2[k].imshow(to_img(o["images_pred"][0, 0].clamp(0, 1)))
        ax2[k].set_xticks([])
        ax2[k].set_yticks([])
        ax2[k].set_title(("ego path" if s == 0 else f"{'+' if s>0 else ''}{s} lat"), fontsize=10,
                         weight=("bold" if s == 0 else "normal"))
    fig2.suptitle(f"Out-of-trajectory simulation — render from cameras off the driven path ({src} clip {i})", fontsize=11)
    fig2.tight_layout(rect=[0, 0, 1, 0.93])
    out2 = "/mnt/william/runs/viz_oot.png"
    fig2.savefig(out2, dpi=110, bbox_inches="tight")
    plt.close(fig2)
    print("saved", out2, flush=True)


def main():
    """Load MapTokenGS, pick the best-scoring clip, and save two figures.

    Steps: parse CLI options, build the model and load the checkpoint, open
    the unified AV2 + Waymo dataset, score the first ``--scan`` clips by PSNR
    against ``d["gt"]``, then write ``viz_nvs.png`` (GT / predicted RGB /
    predicted depth per view) and ``viz_oot.png`` (renders from laterally
    shifted cameras) under ``/mnt/william/runs``.

    Raises:
        RuntimeError: if no scanned clip produced a valid PSNR.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", default="/mnt/william/runs/maptokengs_highres_best.safetensors",
                    help="safetensors checkpoint to load")
    ap.add_argument("--scan", type=int, default=20,
                    help="number of leading dataset clips to score")
    ap.add_argument("--s", type=float, default=1.5,
                    help="value assigned to the model's cur_s attribute")
    ap.add_argument("--H", type=int, default=448, help="image height")
    ap.add_argument("--W", type=int, default=768, help="image width")
    args = ap.parse_args()
    H, W = args.H, args.W

    # The registry entry may be a factory or a ready-made options object.
    opt = options.config_defaults["train_dl3dv_base"]
    if callable(opt):
        opt = opt()
    opt.img_size = (H, W)
    m = FT.MapTokenGS(opt, s0=0.5, s_max=2.0, max_instances=8, dyn_per_instance=32).to(dev).eval()
    # strict=False tolerates key mismatches; report them so a partially
    # loaded model is not visualised unnoticed.
    report = m.load_state_dict(load_file(args.ckpt), strict=False)
    missing = getattr(report, "missing_keys", None) or []
    unexpected = getattr(report, "unexpected_keys", None) or []
    if missing or unexpected:
        print(f"warning: checkpoint key mismatch: {len(missing)} missing, {len(unexpected)} unexpected",
              flush=True)
    m.cur_s = args.s

    cfg = load_config(overrides=["data.name=unified",
                                 "data.root=/mnt/william/data/unified/av2,/mnt/william/data/unified/waymo",
                                 f"data.height={H}", f"data.width={W}", "model.tokens.n_map=2048",
                                 "data.max_instances=8"])
    # NOTE(review): the dataset is opened with split="train" while the figures
    # are labelled "held-out" — confirm the views in d["gt"] are withheld from
    # the model input.
    ds = UnifiedClipDataset(cfg, roots=["/mnt/william/data/unified/av2", "/mnt/william/data/unified/waymo"],
                            split="train", n_sup_views=4)

    p, i, (d, g, grp, inst, rgb, dep) = _select_best_clip(m, ds, args.scan)
    src = "waymo" if "waymo" in ds.clips[i] else "av2"
    print(f"chosen clip idx {i} ({src}) | held-out PSNR {p:.2f} SSIM {float(ssim(rgb.clamp(0,1), d['gt'])):.3f}", flush=True)

    _save_nvs_figure(d, rgb, dep, p, i, src)
    _save_oot_figure(m, d, g, grp, inst, i, src)
|
|
|
|
# Run the visualisation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|