| |
| """MapGS ver.neurec — zero-shot held-out-view eval of NVIDIA InstantNuRec on our |
| unified clips. Feeds 18 posed input frames (6 timesteps x 3 cams) -> per-pixel |
| Gaussians -> gsplat render to held-out views -> PSNR/SSIM. Validates the input/ |
| frame conventions before any finetuning. Inputs follow the JIT call spec: |
| rgb (1,V,H,W,3) [0,1] channel-last; c2w (1,V,4,4) translation*scene_rescale; |
| fov (1,V,2) rad; rays (1,V,H,W,6) world; distance_to_depth_scale (1,V,H,W,1); |
| camera_idxs (1,V) int64. Outputs: gs_xyz, gs_rot(wxyz), gs_scale(lin), gs_dens |
| (sigmoid), gs_rgb, semantic, normals, affine — gaussians live in the *scene_rescale |
| frame, so render cameras are scaled identically.""" |
| import os, glob, math, argparse |
| import numpy as np |
| import torch |
| import imageio.v2 as imageio |
| from huggingface_hub import hf_hub_download |
| from gsplat import rasterization |
| from mapgs.geometry.cameras import resize_with_intrinsics |
| from mapgs.eval.metrics import psnr, ssim |
|
|
# Torch device string: the JIT module is loaded onto it and every clip tensor
# gathered by load_clip is moved to it.
DEV = "cuda"
| |
|
|
|
|
def load_jit():
    """Download and load the InstantNuRec TorchScript module.

    Fetches ``instant_nurec.pt`` from the ``nvidia/instant-nurec`` Hugging
    Face repo, loads it onto ``DEV`` and switches it to eval mode.

    Returns:
        Tuple ``(jit, V, H, W, rescale)``: the loaded module followed by the
        values read from ``jit.static_core`` (``expected_v``, ``expected_h``,
        ``expected_w`` as ints and ``scene_rescale_buffer`` as a float).
    """
    # HF_TOKEN is optional: when it is unset (or empty) pass None so that
    # huggingface_hub falls back to the locally cached login instead of this
    # function dying with a bare KeyError.
    token = os.environ.get("HF_TOKEN") or None
    ckpt_path = hf_hub_download("nvidia/instant-nurec", "instant_nurec.pt", token=token)
    # Process-wide TorchScript setting, applied before the module is loaded.
    # NOTE(review): depth 0 for both strategies presumably disables fuser
    # re-specialisation for this model -- confirm against the model card.
    torch.jit.set_fusion_strategy([("STATIC", 0), ("DYNAMIC", 0)])
    jit = torch.jit.load(ckpt_path, map_location=DEV).eval()
    core = jit.static_core
    V = int(core.expected_v.item())
    H = int(core.expected_h.item())
    W = int(core.expected_w.item())
    rescale = float(core.scene_rescale_buffer.item())
    return jit, V, H, W, rescale
|
|
|
|
def _read_img(clip_dir, f, v, K, H, W):
    """Load frame ``f`` of camera ``v`` and resize it to ``H`` x ``W``.

    The image is looked up at ``<clip_dir>/images/<f:03d>_<v>.jpg``; when that
    file does not exist the ``.png`` sibling is used instead. Pixels are
    converted to a float channel-first tensor scaled by 1/255.

    Returns:
        Whatever ``resize_with_intrinsics`` returns for the image and ``K``
        (the caller unpacks it as an ``(image, intrinsics)`` pair).
    """
    stem = os.path.join(clip_dir, "images", f"{f:03d}_{v}")
    jpg_path = stem + ".jpg"
    if os.path.exists(jpg_path):
        path = jpg_path
    else:
        path = stem + ".png"
    pixels = np.asarray(imageio.imread(path))
    # Channel-last pixels -> channel-first float tensor in [0, 1] for 8-bit input.
    chw = torch.from_numpy(pixels).float().permute(2, 0, 1) / 255.0
    return resize_with_intrinsics(chw, K, H, W)
|
|
|
|
def load_clip(clip_dir, V_in, H, W, n_sup=6):
    """Load one clip and split it into model-input views and held-out views.

    Args:
        clip_dir: Clip directory containing ``meta.pt`` and ``images/``.
        V_in: Number of input views to gather for the model.
        H, W: Target image size passed on to ``_read_img``.
        n_sup: Number of held-out (supervision) views to gather.

    Returns:
        Dict with ``in_img``, ``in_K``, ``in_c2w``, ``in_cam`` for the input
        views and ``sup_img``, ``sup_K``, ``sup_c2w`` for the held-out views,
        all on ``DEV``. Camera translations of both sets are shifted so the
        mean input-camera position is the origin.
    """
    # NOTE: weights_only=False unpickles arbitrary Python objects -- only point
    # this at clip directories you trust.
    m = torch.load(os.path.join(clip_dir, "meta.pt"), weights_only=False)
    F, Vc = m["num_frames"], m["num_cameras"]
    # Ceil division: when V_in is not a multiple of the camera count, take one
    # extra timestep and let the [:V_in] truncation trim it, rather than
    # silently gathering fewer than V_in views. Identical to floor division
    # whenever V_in is divisible by Vc.
    n_fr = max(1, -(-V_in // Vc))
    # Evenly spaced input timesteps; rounding can collide on short clips, so
    # duplicates are dropped.
    in_frames = sorted({int(round(x)) for x in np.linspace(0, F - 1, n_fr)})
    in_pairs = [(f, v) for f in in_frames for v in range(Vc)][:V_in]
    used = set(in_frames)
    # Held-out timesteps are those not fed to the model; if none remain, fall
    # back to the input timesteps.
    out_frames = [f for f in range(F) if f not in used] or in_frames
    sup_pairs = [(out_frames[k % len(out_frames)], k % Vc) for k in range(n_sup)]

    def gather(pairs):
        """Stack images, intrinsics, poses and camera ids for (frame, cam) pairs."""
        imgs, Ks, c2ws, cams = [], [], [], []
        for f, v in pairs:
            img, K = _read_img(clip_dir, f, v, m["K"][v], H, W)
            imgs.append(img)
            Ks.append(K)
            c2ws.append(m["cam2world"][f, v].float())
            cams.append(v)
        return (torch.stack(imgs).to(DEV), torch.stack(Ks).to(DEV),
                torch.stack(c2ws).to(DEV), torch.tensor(cams, device=DEV))

    in_img, in_K, in_c2w, in_cam = gather(in_pairs)
    sup_img, sup_K, sup_c2w, _ = gather(sup_pairs)
    # Recentre the scene on the mean input-camera position; the same shift is
    # applied to the held-out cameras so both sets stay in one frame.
    origin = in_c2w[:, :3, 3].mean(0)
    in_c2w = in_c2w.clone()
    in_c2w[:, :3, 3] -= origin
    sup_c2w = sup_c2w.clone()
    sup_c2w[:, :3, 3] -= origin
    return dict(in_img=in_img, in_K=in_K, in_c2w=in_c2w, in_cam=in_cam,
                sup_img=sup_img, sup_K=sup_K, sup_c2w=sup_c2w)
|
|
|
|
def nurec_forward(jit, d, rescale, H, W):
    """Build the JIT call inputs from a loaded clip and run the model.

    Args:
        jit: Callable taking ``(rgb, c2w, fov, rays, d2d, camera_idxs)``.
        d: Dict from ``load_clip``; ``in_img`` (V,C,H,W), ``in_K`` (V,3,3),
            ``in_c2w`` (V,4,4) and ``in_cam`` (V,) are read.
        rescale: Factor applied to the camera translations handed to the model.
        H, W: Image height and width used for the pixel grid and the FOV.

    Returns:
        Tuple ``(xyz, rot, scale, density, rgb, semantic)`` taken from the
        first six model outputs and flattened to ``(N,3)``, ``(N,4)``,
        ``(N,3)``, ``(N,)``, ``(N,3)`` and ``(N,)``.
    """
    img, K, c2w, cam = d["in_img"], d["in_K"], d["in_c2w"], d["in_cam"]
    V = img.shape[0]
    # Allocate helper tensors on the device the inputs already live on rather
    # than on a module-level global, so the function follows its arguments.
    dev = K.device
    # Channel-first -> channel-last with a leading batch axis of 1.
    rgb = img.permute(0, 2, 3, 1).contiguous()[None].float()
    # Scale only the translation; the clone keeps the caller's poses untouched.
    c2w_s = c2w.clone()
    c2w_s[:, :3, 3] *= rescale
    c2w_s = c2w_s[None].float()
    fx, fy = K[:, 0, 0], K[:, 1, 1]
    # Full horizontal/vertical field of view in radians from the focal lengths.
    fov = torch.stack([2 * torch.atan2(torch.full_like(fx, W / 2), fx),
                       2 * torch.atan2(torch.full_like(fy, H / 2), fy)], -1)[None]
    ys, xs = torch.meshgrid(torch.arange(H, device=dev), torch.arange(W, device=dev),
                            indexing="ij")
    # Sample rays through pixel centres.
    px, py = (xs + 0.5).float(), (ys + 0.5).float()
    # Camera-frame directions with z = 1 (not unit length).
    dc = torch.stack([(px - K[:, 0, 2, None, None]) / fx[:, None, None],
                      (py - K[:, 1, 2, None, None]) / fy[:, None, None],
                      torch.ones(V, H, W, device=dev)], -1)
    dirs_world = torch.einsum("vij,vhwj->vhwi", c2w[:, :3, :3], dc)
    # NOTE(review): ray origins use the *unscaled* translation while the poses
    # passed to the model are multiplied by ``rescale``, and the directions are
    # left unnormalised -- confirm both against the JIT call spec.
    origins = c2w[:, :3, 3][:, None, None, :].expand(V, H, W, 3)
    rays = torch.cat([origins, dirs_world], -1)[None].float()
    # Per-pixel factor 1 / |dc|, passed as distance_to_depth_scale.
    d2d = (1.0 / dc.norm(dim=-1, keepdim=True))[None].float()
    cidx = cam.to(torch.int64)[None]
    out = jit(rgb, c2w_s, fov, rays, d2d, cidx)
    gs_xyz, gs_rot, gs_scale, gs_dens, gs_rgb, sem = out[0], out[1], out[2], out[3], out[4], out[5]
    return (gs_xyz.reshape(-1, 3), gs_rot.reshape(-1, 4), gs_scale.reshape(-1, 3),
            gs_dens.reshape(-1), gs_rgb.reshape(-1, 3), sem.reshape(-1))
|
|
|
|
def render(g, sup_c2w, sup_K, rescale, H, W, drop_sem=(1, 4)):
    """Rasterize the predicted Gaussians into the given cameras.

    Args:
        g: Tuple ``(xyz, quat, scale, opacity, color, semantic)`` as returned
            by ``nurec_forward``.
        sup_c2w: Camera-to-world matrices of the render views, shape (N,4,4).
        sup_K: Intrinsics of the render views.
        rescale: Factor applied to the camera translations before inversion,
            matching the scaling applied to the model's input poses.
        H, W: Output image height and width.
        drop_sem: Semantic labels whose Gaussians are removed before rendering.

    Returns:
        Rendered colours clamped to [0, 1], with the last axis of the
        rasterizer output moved to position 1.
    """
    means, quats, scales, opacities, colors, labels = g
    # Collect every Gaussian carrying one of the dropped labels, then invert.
    dropped = torch.zeros_like(labels, dtype=torch.bool)
    for label in drop_sem:
        dropped |= labels == label
    keep = ~dropped
    means = means[keep]
    quats = quats[keep]
    scales = scales[keep]
    opacities = opacities[keep]
    colors = colors[keep]
    # Scale the render cameras' translations on a copy so the caller's poses
    # are not modified.
    cams = sup_c2w.clone()
    cams[:, :3, 3] *= rescale
    world_to_cam = torch.inverse(cams)
    # Only the first element of the rasterizer's 3-tuple (the colours) is used.
    rendered, _, _ = rasterization(
        means=means, quats=quats, scales=scales, opacities=opacities, colors=colors,
        viewmats=world_to_cam, Ks=sup_K, width=W, height=H,
        near_plane=0.01, far_plane=1e3, render_mode="RGB",
    )
    return rendered.clamp(0, 1).permute(0, 3, 1, 2)
|
|
|
|
def main():
    """Evaluate InstantNuRec zero-shot on held-out views of validation clips.

    Command-line options:
        --val-root: Directory whose entries are clip directories.
        --n-clips: Number of (sorted) directory entries to consider.
        --save-viz: Optional image path; when set, a grid of up to four
            ground-truth views above their renders is written for the first
            evaluated clip.

    Prints per-clip PSNR/SSIM and their means over all evaluated clips.

    Raises:
        SystemExit: If none of the considered entries contains ``meta.pt``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--val-root", default="/mnt/william/data/unified/av2/val")
    ap.add_argument("--n-clips", type=int, default=8)
    ap.add_argument("--save-viz", default="")
    args = ap.parse_args()
    jit, V, H, W, rescale = load_jit()
    print(f"InstantNuRec JIT: V={V} H={H} W={W} scene_rescale={rescale}", flush=True)
    clips = sorted(glob.glob(os.path.join(args.val_root, "*")))[: args.n_clips]
    ps, ss = [], []
    for i, c in enumerate(clips):
        # Entries without metadata are not clips; skip them.
        if not os.path.exists(os.path.join(c, "meta.pt")):
            continue
        d = load_clip(c, V, H, W)
        with torch.no_grad():
            g = nurec_forward(jit, d, rescale, H, W)
            rgb = render(g, d["sup_c2w"], d["sup_K"], rescale, H, W)
        p, s = float(psnr(rgb, d["sup_img"])), float(ssim(rgb, d["sup_img"]))
        ps.append(p)
        ss.append(s)
        print(f" clip {i} ({os.path.basename(c)}): held-out PSNR {p:.2f} SSIM {s:.3f} | "
              f"gaussians {g[0].shape[0]//1000}k", flush=True)
        # Save the visualisation for the first clip that was actually
        # evaluated; keying on i == 0 wrote nothing when entry 0 was skipped.
        if args.save_viz and len(ps) == 1:
            import torchvision
            grid = torch.cat([d["sup_img"][:4], rgb[:4]], 0)
            torchvision.utils.save_image(grid, args.save_viz, nrow=4)
    if not ps:
        # Nothing was evaluated: report it instead of dividing by zero below.
        raise SystemExit(f"no clips with meta.pt found under {args.val_root}")
    print(f"\n=== InstantNuRec ZERO-SHOT on held-out av2/val scenes (n={len(ps)}) ===", flush=True)
    print(f"PSNR {sum(ps)/len(ps):.2f} SSIM {sum(ss)/len(ss):.3f}", flush=True)
|
|
|
|
# Run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|