#!/usr/bin/env python3
"""MapGS ver.neurec — zero-shot held-out-view eval of NVIDIA InstantNuRec on our unified clips.

Feeds 18 posed input frames (6 timesteps x 3 cams) -> per-pixel Gaussians -> gsplat render
to held-out views -> PSNR/SSIM. Validates the input/ frame conventions before any finetuning.

Inputs follow the JIT call spec:
  rgb (1,V,H,W,3) [0,1] channel-last; c2w (1,V,4,4) translation*scene_rescale;
  fov (1,V,2) rad; rays (1,V,H,W,6) world; distance_to_depth_scale (1,V,H,W,1);
  camera_idxs (1,V) int64.
Outputs: gs_xyz, gs_rot(wxyz), gs_scale(lin), gs_dens (sigmoid), gs_rgb, semantic, normals,
affine — gaussians live in the *scene_rescale frame, so render cameras are scaled identically.
"""
import argparse
import glob
import math
import os

import imageio.v2 as imageio
import numpy as np
import torch
from gsplat import rasterization
from huggingface_hub import hf_hub_download

from mapgs.eval.metrics import psnr, ssim
from mapgs.geometry.cameras import resize_with_intrinsics

DEV = "cuda"
# semantic class ids (KelvinSemanticClass): 0 OTHERS,1 EGO,2 SKY,3 ROAD,4 MOVABLE


def load_jit():
    """Download and load the InstantNuRec TorchScript module.

    Returns:
        Tuple ``(jit, V, H, W, rescale)``: the module in eval mode on ``DEV``, the three
        integers read from ``static_core.expected_v/h/w``, and the float read from
        ``static_core.scene_rescale_buffer``.
    """
    # HF_TOKEN is optional: an unset or empty variable is passed as None so that
    # huggingface_hub can fall back to its own credential lookup instead of this
    # script dying with a bare KeyError.
    token = os.environ.get("HF_TOKEN") or None
    p = hf_hub_download("nvidia/instant-nurec", "instant_nurec.pt", token=token)
    # Fusion depth 0 for both strategies.
    # NOTE(review): presumably to keep the TorchScript fuser from re-specialising — confirm.
    torch.jit.set_fusion_strategy([("STATIC", 0), ("DYNAMIC", 0)])
    jit = torch.jit.load(p, map_location=DEV).eval()
    sc = jit.static_core
    V, H, W = int(sc.expected_v.item()), int(sc.expected_h.item()), int(sc.expected_w.item())
    rescale = float(sc.scene_rescale_buffer.item())
    return jit, V, H, W, rescale


def _read_img(clip_dir, f, v, K, H, W):
    """Read ``images/{f:03d}_{v}.jpg`` (or ``.png``) and resize it together with ``K``.

    Args:
        clip_dir: clip directory containing an ``images`` sub-directory.
        f: frame index.
        v: camera index.
        K: intrinsics handed through to ``resize_with_intrinsics``.
        H, W: target size handed through to ``resize_with_intrinsics``.

    Returns:
        Whatever ``resize_with_intrinsics`` returns; callers unpack it as ``(img, K)``.
    """
    base = os.path.join(clip_dir, "images", f"{f:03d}_{v}")
    path = base + ".jpg" if os.path.exists(base + ".jpg") else base + ".png"
    arr = np.asarray(imageio.imread(path))
    # Normalise the channel layout to (h, w, 3): grayscale files decode to 2-D arrays
    # (which would break the permute below) and RGBA files carry a fourth channel.
    if arr.ndim == 2:
        arr = arr[..., None]
    if arr.shape[-1] == 1:
        arr = np.repeat(arr, 3, axis=-1)
    arr = arr[..., :3]
    img = torch.from_numpy(arr).float().permute(2, 0, 1) / 255.0  # channel-first, /255
    return resize_with_intrinsics(img, K, H, W)


def load_clip(clip_dir, V_in, H, W, n_sup=6):
    """Load the input views and the supervision (evaluation) views of one clip.

    Args:
        clip_dir: clip directory holding ``meta.pt`` and ``images/``.
        V_in: number of input views to select.
        H, W: image size every view is resized to.
        n_sup: number of supervision views.

    Returns:
        Dict with ``in_img``, ``in_K``, ``in_c2w``, ``in_cam`` for the input views and
        ``sup_img``, ``sup_K``, ``sup_c2w`` for the supervision views, all on ``DEV``.
        Camera translations are shifted by the mean input-camera position.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects — only point this
    # script at trusted clip directories.
    m = torch.load(os.path.join(clip_dir, "meta.pt"), weights_only=False)
    F, Vc = m["num_frames"], m["num_cameras"]
    n_fr = max(1, V_in // Vc)  # 18//3 = 6 timesteps
    # Evenly spaced frame indices; rounding can collide when F < n_fr, hence the set.
    in_frames = sorted(set(int(round(x)) for x in np.linspace(0, F - 1, n_fr)))
    # NOTE(review): this yields fewer than V_in pairs when V_in is not a multiple of Vc
    # or when the clip has fewer than n_fr frames — confirm the model tolerates that.
    in_pairs = [(f, v) for f in in_frames for v in range(Vc)][:V_in]
    # Frames not used as input; falls back to the input frames when none are left
    # (in which case the "held-out" views are not actually held out).
    out_frames = [f for f in range(F) if f not in in_frames] or in_frames
    sup_pairs = [(out_frames[k % len(out_frames)], k % Vc) for k in range(n_sup)]

    def gather(pairs):
        """Stack images, intrinsics, cam2world matrices and camera ids for ``pairs``."""
        imgs, Ks, c2ws, cams = [], [], [], []
        for f, v in pairs:
            img, K = _read_img(clip_dir, f, v, m["K"][v], H, W)
            imgs.append(img)
            Ks.append(K)
            c2ws.append(m["cam2world"][f, v].float())
            cams.append(v)
        return (torch.stack(imgs).to(DEV), torch.stack(Ks).to(DEV),
                torch.stack(c2ws).to(DEV), torch.tensor(cams, device=DEV))

    in_img, in_K, in_c2w, in_cam = gather(in_pairs)
    sup_img, sup_K, sup_c2w, _ = gather(sup_pairs)
    origin = in_c2w[:, :3, 3].mean(0)  # local frame
    in_c2w = in_c2w.clone()
    in_c2w[:, :3, 3] -= origin
    sup_c2w = sup_c2w.clone()
    sup_c2w[:, :3, 3] -= origin
    return dict(in_img=in_img, in_K=in_K, in_c2w=in_c2w, in_cam=in_cam,
                sup_img=sup_img, sup_K=sup_K, sup_c2w=sup_c2w)


def nurec_forward(jit, d, rescale, H, W):
    """Build the model inputs from a ``load_clip`` dict and run the TorchScript module.

    Args:
        jit: the loaded TorchScript module.
        d: dict providing ``in_img``, ``in_K``, ``in_c2w`` and ``in_cam``.
        rescale: factor applied to the camera translations passed as ``c2w``.
        H, W: image size used for the pixel grid and the field of view.

    Returns:
        Tuple ``(xyz, rot, scale, dens, rgb, sem)``: the first six model outputs,
        flattened to ``(N, 3)``, ``(N, 4)``, ``(N, 3)``, ``(N,)``, ``(N, 3)``, ``(N,)``.
    """
    img, K, c2w, cam = d["in_img"], d["in_K"], d["in_c2w"], d["in_cam"]
    V = img.shape[0]
    rgb = img.permute(0, 2, 3, 1).contiguous()[None].float()  # 1,V,H,W,3
    c2w_s = c2w.clone()
    c2w_s[:, :3, 3] *= rescale
    c2w_s = c2w_s[None].float()
    fx, fy = K[:, 0, 0], K[:, 1, 1]
    # Full horizontal / vertical field of view in radians from the focal lengths.
    fov = torch.stack([2 * torch.atan2(torch.full_like(fx, W / 2), fx),
                       2 * torch.atan2(torch.full_like(fy, H / 2), fy)], -1)[None]
    ys, xs = torch.meshgrid(torch.arange(H, device=DEV), torch.arange(W, device=DEV),
                            indexing="ij")
    px, py = (xs + 0.5).float(), (ys + 0.5).float()  # pixel centres
    dc = torch.stack([(px - K[:, 0, 2, None, None]) / fx[:, None, None],
                      (py - K[:, 1, 2, None, None]) / fy[:, None, None],
                      torch.ones(V, H, W, device=DEV)], -1)  # V,H,W,3 (OpenCV +z)
    # Rotate the camera-frame directions into the world frame (directions keep z=1 in
    # the camera frame, i.e. they are not unit length).
    dirs_world = torch.einsum("vij,vhwj->vhwi", c2w[:, :3, :3], dc)
    origins = c2w[:, :3, 3][:, None, None, :].expand(V, H, W, 3)  # un-rescaled (per spec)
    rays = torch.cat([origins, dirs_world], -1)[None].float()
    d2d = (1.0 / dc.norm(dim=-1, keepdim=True))[None].float()
    cidx = cam.to(torch.int64)[None]
    out = jit(rgb, c2w_s, fov, rays, d2d, cidx)
    gs_xyz, gs_rot, gs_scale, gs_dens, gs_rgb, sem = out[0], out[1], out[2], out[3], out[4], out[5]
    return (gs_xyz.reshape(-1, 3), gs_rot.reshape(-1, 4), gs_scale.reshape(-1, 3),
            gs_dens.reshape(-1), gs_rgb.reshape(-1, 3), sem.reshape(-1))


def render(g, sup_c2w, sup_K, rescale, H, W, drop_sem=(1, 4)):
    """Rasterise the Gaussians into the supervision cameras with gsplat.

    Args:
        g: tuple ``(xyz, quat, scale, opac, color, sem)`` as returned by ``nurec_forward``.
        sup_c2w: camera-to-world matrices of the views to render.
        sup_K: intrinsics of the views to render.
        rescale: factor applied to the camera translations before inversion.
        H, W: output image size.
        drop_sem: semantic ids whose Gaussians are removed before rendering.

    Returns:
        Rendered images clamped to [0, 1], permuted to ``(S, 3, H, W)``.
    """
    xyz, quat, scale, opac, color, sem = g
    keep = torch.ones_like(sem, dtype=torch.bool)
    for c in drop_sem:  # drop EGO + MOVABLE for static NVS
        keep &= sem != c
    xyz, quat, scale, opac, color = xyz[keep], quat[keep], scale[keep], opac[keep], color[keep]
    c2w_s = sup_c2w.clone()
    c2w_s[:, :3, 3] *= rescale  # gaussians are in *rescale frame
    viewmats = torch.inverse(c2w_s)
    out, _, _ = rasterization(means=xyz, quats=quat, scales=scale, opacities=opac,
                              colors=color, viewmats=viewmats, Ks=sup_K, width=W, height=H,
                              near_plane=0.01, far_plane=1e3, render_mode="RGB")
    return out.clamp(0, 1).permute(0, 3, 1, 2)  # S,3,H,W


def main():
    """Evaluate the model on the first ``--n-clips`` directories under ``--val-root``.

    Prints per-clip and mean PSNR/SSIM. Directories without ``meta.pt`` are skipped.

    Raises:
        SystemExit: when no directory with a ``meta.pt`` was found, so there is nothing
            to average.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--val-root", default="/mnt/william/data/unified/av2/val")
    ap.add_argument("--n-clips", type=int, default=8)
    ap.add_argument("--save-viz", default="")
    args = ap.parse_args()

    jit, V, H, W, rescale = load_jit()
    print(f"InstantNuRec JIT: V={V} H={H} W={W} scene_rescale={rescale}", flush=True)
    clips = sorted(glob.glob(os.path.join(args.val_root, "*")))[: args.n_clips]
    ps, ss = [], []
    for i, c in enumerate(clips):
        if not os.path.exists(os.path.join(c, "meta.pt")):
            continue
        d = load_clip(c, V, H, W)
        with torch.no_grad():
            g = nurec_forward(jit, d, rescale, H, W)
            rgb = render(g, d["sup_c2w"], d["sup_K"], rescale, H, W)
        p, s = float(psnr(rgb, d["sup_img"])), float(ssim(rgb, d["sup_img"]))
        # Visualise the first clip that is actually evaluated, not index 0 of the
        # listing, which may be a directory without meta.pt.
        is_first_evaluated = not ps
        ps.append(p)
        ss.append(s)
        print(f" clip {i} ({os.path.basename(c)}): held-out PSNR {p:.2f} SSIM {s:.3f} | "
              f"gaussians {g[0].shape[0]//1000}k", flush=True)
        if args.save_viz and is_first_evaluated:
            import torchvision
            # Top row: ground truth, bottom row: renders (up to four views each).
            grid = torch.cat([d["sup_img"][:4], rgb[:4]], 0)
            torchvision.utils.save_image(grid, args.save_viz, nrow=4)

    if not ps:
        # Nothing was evaluated; exit with a clear message instead of dividing by zero.
        raise SystemExit(f"no clips with meta.pt found under {args.val_root}")
    print(f"\n=== InstantNuRec ZERO-SHOT on held-out av2/val scenes (n={len(ps)}) ===", flush=True)
    print(f"PSNR {sum(ps)/len(ps):.2f} SSIM {sum(ss)/len(ss):.3f}", flush=True)


if __name__ == "__main__":
    main()