#!/usr/bin/env python3
# mapvggt/scripts/neurec_eval.py — uploaded by ChenmingWu via huggingface_hub (commit 8cf92b3).
"""MapGS ver.neurec — zero-shot held-out-view eval of NVIDIA InstantNuRec on our
unified clips. Feeds 18 posed input frames (6 timesteps x 3 cams) -> per-pixel
Gaussians -> gsplat render to held-out views -> PSNR/SSIM. Validates the input/
frame conventions before any finetuning. Inputs follow the JIT call spec:
rgb (1,V,H,W,3) [0,1] channel-last; c2w (1,V,4,4) translation*scene_rescale;
fov (1,V,2) rad; rays (1,V,H,W,6) world; distance_to_depth_scale (1,V,H,W,1);
camera_idxs (1,V) int64. Outputs: gs_xyz, gs_rot(wxyz), gs_scale(lin), gs_dens
(sigmoid), gs_rgb, semantic, normals, affine — gaussians live in the *scene_rescale
frame, so render cameras are scaled identically."""
import os, glob, math, argparse
import numpy as np
import torch
import imageio.v2 as imageio
from huggingface_hub import hf_hub_download
from gsplat import rasterization
from mapgs.geometry.cameras import resize_with_intrinsics
from mapgs.eval.metrics import psnr, ssim
DEV = "cuda"
# semantic class ids (KelvinSemanticClass): 0 OTHERS,1 EGO,2 SKY,3 ROAD,4 MOVABLE
def load_jit():
    """Download and load the InstantNuRec TorchScript model.

    The Hugging Face token is taken from the ``HF_TOKEN`` environment variable when
    set; otherwise ``None`` is passed and huggingface_hub falls back to a locally
    saved login (e.g. from ``huggingface-cli login``) instead of raising KeyError.

    Returns:
        ``(jit, V, H, W, rescale)``: the loaded module in eval mode on ``DEV``; the
        number of input views and the image height/width it expects (read from
        ``jit.static_core``); and the scene rescale factor applied to camera
        translations.
    """
    path = hf_hub_download("nvidia/instant-nurec", "instant_nurec.pt",
                           token=os.environ.get("HF_TOKEN"))
    # Zero static/dynamic fusion specializations, i.e. the TorchScript fuser does not
    # specialize kernels (presumably to avoid shape-dependent recompilation).
    torch.jit.set_fusion_strategy([("STATIC", 0), ("DYNAMIC", 0)])
    jit = torch.jit.load(path, map_location=DEV).eval()
    core = jit.static_core
    V = int(core.expected_v.item())
    H = int(core.expected_h.item())
    W = int(core.expected_w.item())
    rescale = float(core.scene_rescale_buffer.item())
    return jit, V, H, W, rescale
def _read_img(clip_dir, f, v, K, H, W):
    """Load frame ``f`` of camera ``v`` as a (3, h, w) float tensor and resize it.

    Looks for ``images/{f:03d}_{v}.jpg`` first, then ``.png``. Grayscale images are
    replicated to three channels and any alpha channel is dropped, so the model
    always receives RGB. Pixel values are divided by 255 (assumes 8-bit images).

    Returns:
        The result of ``resize_with_intrinsics(img, K, H, W)``; callers in this file
        unpack it as ``(img, K)``.

    Raises:
        FileNotFoundError: if neither the .jpg nor the .png file exists.
    """
    base = os.path.join(clip_dir, "images", f"{f:03d}_{v}")
    path = base + ".jpg"
    if not os.path.exists(path):
        path = base + ".png"
    if not os.path.exists(path):
        raise FileNotFoundError(f"no image for frame {f} camera {v}: {base}.jpg/.png")
    arr = np.asarray(imageio.imread(path))
    if arr.ndim == 2:  # grayscale -> RGB, otherwise permute(2, 0, 1) below fails
        arr = np.repeat(arr[..., None], 3, axis=-1)
    arr = np.ascontiguousarray(arr[..., :3])  # drop alpha from RGBA PNGs
    img = torch.from_numpy(arr).float().permute(2, 0, 1) / 255.0
    return resize_with_intrinsics(img, K, H, W)
def load_clip(clip_dir, V_in, H, W, n_sup=6):
    """Load the model-input views and held-out supervision views of one clip.

    Input views: evenly spaced timesteps over the clip, all cameras per timestep,
    truncated to exactly ``V_in`` views (e.g. 18 views / 3 cameras = 6 timesteps).
    Supervision views: ``n_sup`` (frame, camera) pairs taken from frames that are
    NOT input frames, cycling over those frames and over cameras. If every frame
    is an input frame, the input frames are reused, and the supervision views are
    then not actually held out. All poses are re-centred on the mean input-camera
    position.

    Args:
        clip_dir: clip directory containing ``meta.pt`` and ``images/``.
        V_in: number of input views the model expects.
        H, W: target image size.
        n_sup: number of supervision views.

    Returns:
        dict with ``in_img``/``in_K``/``in_c2w``/``in_cam`` (``V_in`` views) and
        ``sup_img``/``sup_K``/``sup_c2w`` (``n_sup`` views), all on ``DEV``;
        intrinsics and poses are float32.

    Raises:
        ValueError: if the clip cannot supply ``V_in`` distinct input views.
    """
    # meta.pt is a pickled dict (weights_only=False runs the unpickler): trusted data only.
    meta = torch.load(os.path.join(clip_dir, "meta.pt"), weights_only=False)
    F, Vc = meta["num_frames"], meta["num_cameras"]
    # Ceil so a camera count that does not divide V_in still fills V_in views (the
    # last timestep is then used partially); equals V_in // Vc when Vc divides V_in.
    n_fr = max(1, math.ceil(V_in / Vc))
    in_frames = sorted({int(round(x)) for x in np.linspace(0, F - 1, n_fr)})
    in_pairs = [(f, v) for f in in_frames for v in range(Vc)][:V_in]
    if len(in_pairs) != V_in:
        # Rounding collapses duplicate frames when F < n_fr. The JIT presumably
        # requires exactly expected_v views, so fail early with a clear message.
        raise ValueError(
            f"{clip_dir}: only {len(in_pairs)} input views available "
            f"({F} frames x {Vc} cameras), model expects {V_in}")
    in_set = set(in_frames)
    out_frames = [f for f in range(F) if f not in in_set] or in_frames
    sup_pairs = [(out_frames[k % len(out_frames)], k % Vc) for k in range(n_sup)]

    def gather(pairs):
        """Read image/intrinsics/pose/camera-id for (frame, camera) pairs, stacked on DEV."""
        imgs, Ks, c2ws, cams = [], [], [], []
        for f, v in pairs:
            img, K = _read_img(clip_dir, f, v, meta["K"][v], H, W)
            imgs.append(img)
            Ks.append(K)
            c2ws.append(meta["cam2world"][f, v].float())
            cams.append(v)
        # Ks cast to float32 like the poses (gsplat's rasterizer works in float32).
        return (torch.stack(imgs).to(DEV), torch.stack(Ks).float().to(DEV),
                torch.stack(c2ws).to(DEV), torch.tensor(cams, device=DEV))

    in_img, in_K, in_c2w, in_cam = gather(in_pairs)
    sup_img, sup_K, sup_c2w, _ = gather(sup_pairs)
    # Local frame: subtract the mean input-camera position from all translations.
    origin = in_c2w[:, :3, 3].mean(0)
    in_c2w = in_c2w.clone()
    in_c2w[:, :3, 3] -= origin
    sup_c2w = sup_c2w.clone()
    sup_c2w[:, :3, 3] -= origin
    return dict(in_img=in_img, in_K=in_K, in_c2w=in_c2w, in_cam=in_cam,
                sup_img=sup_img, sup_K=sup_K, sup_c2w=sup_c2w)
def nurec_forward(jit, d, rescale, H, W):
    """Run InstantNuRec on a clip's input views and return flattened Gaussian params.

    Args:
        jit: TorchScript model from ``load_jit``.
        d: dict from ``load_clip``; reads ``in_img`` (V,3,H,W), ``in_K`` (presumably
            (V,3,3) pinhole intrinsics — TODO confirm), ``in_c2w`` (V,4,4) and
            ``in_cam`` (V,) camera indices.
        rescale: scene rescale factor applied to the c2w translations fed to the JIT.
        H, W: image height/width of the input views.

    Returns:
        ``(xyz, rot, scale, density, rgb, sem)``, i.e. the first six JIT outputs
        reshaped to (-1,3), (-1,4), (-1,3), (-1,), (-1,3), (-1,). Each is flattened
        independently; ``sem`` assumes one value per Gaussian (TODO confirm it is a
        class id, not per-class logits). Positions are in the scene_rescale frame
        (per the module docstring).
    """
    img, K, c2w, cam = d["in_img"], d["in_K"], d["in_c2w"], d["in_cam"]
    V = img.shape[0]
    rgb = img.permute(0, 2, 3, 1).contiguous()[None].float() # 1,V,H,W,3
    # Camera-pose input: translations scaled by scene_rescale.
    c2w_s = c2w.clone(); c2w_s[:, :3, 3] *= rescale; c2w_s = c2w_s[None].float()
    fx, fy = K[:, 0, 0], K[:, 1, 1]
    # Full horizontal/vertical FoV in radians; uses W/2, H/2 (not cx, cy), i.e.
    # assumes a centred principal point.
    fov = torch.stack([2 * torch.atan2(torch.full_like(fx, W / 2), fx),
    2 * torch.atan2(torch.full_like(fy, H / 2), fy)], -1)[None]
    ys, xs = torch.meshgrid(torch.arange(H, device=DEV), torch.arange(W, device=DEV), indexing="ij")
    # Sample at pixel centres.
    px, py = (xs + 0.5).float(), (ys + 0.5).float()
    # Camera-frame ray directions with z = 1 (not unit length).
    dc = torch.stack([(px - K[:, 0, 2, None, None]) / fx[:, None, None],
    (py - K[:, 1, 2, None, None]) / fy[:, None, None],
    torch.ones(V, H, W, device=DEV)], -1) # V,H,W,3 (OpenCV +z)
    # Rotate directions to world; their norm stays ||dc|| (no normalization here).
    dirs_world = torch.einsum("vij,vhwj->vhwi", c2w[:, :3, :3], dc)
    origins = c2w[:, :3, 3][:, None, None, :].expand(V, H, W, 3) # un-rescaled (per spec)
    rays = torch.cat([origins, dirs_world], -1)[None].float()
    # NOTE(review): 1/||dc|| is the distance-to-depth factor for UNIT-length ray
    # directions, but dirs_world above is not normalized — confirm which direction
    # convention the JIT expects.
    d2d = (1.0 / dc.norm(dim=-1, keepdim=True))[None].float()
    cidx = cam.to(torch.int64)[None]
    out = jit(rgb, c2w_s, fov, rays, d2d, cidx)
    # Only the first six outputs are used (module docstring also lists normals, affine).
    gs_xyz, gs_rot, gs_scale, gs_dens, gs_rgb, sem = out[0], out[1], out[2], out[3], out[4], out[5]
    return (gs_xyz.reshape(-1, 3), gs_rot.reshape(-1, 4), gs_scale.reshape(-1, 3),
    gs_dens.reshape(-1), gs_rgb.reshape(-1, 3), sem.reshape(-1))
def render(g, sup_c2w, sup_K, rescale, H, W, drop_sem=(1, 4)):
    """Rasterize predicted Gaussians into the supervision (held-out) cameras.

    Args:
        g: ``(xyz, quat, scale, opacity, color, sem)`` as returned by ``nurec_forward``.
        sup_c2w: (S,4,4) camera-to-world poses in the un-rescaled, re-centred frame.
        sup_K: per-view intrinsics at resolution (H, W), passed to gsplat as ``Ks``.
        rescale: scene rescale factor; the Gaussians live in the rescaled frame, so
            camera translations are multiplied by the same factor.
        H, W: output image size.
        drop_sem: iterable of semantic class ids whose Gaussians are removed before
            rendering (default: 1 EGO and 4 MOVABLE, for static novel-view synthesis).

    Returns:
        (S,3,H,W) RGB renders clamped to [0, 1].
    """
    xyz, quat, scale, opac, color, sem = g
    drop = torch.as_tensor(tuple(drop_sem), dtype=sem.dtype, device=sem.device)
    keep = ~torch.isin(sem, drop)  # empty drop_sem keeps every Gaussian
    xyz, quat, scale, opac, color = xyz[keep], quat[keep], scale[keep], opac[keep], color[keep]
    c2w_s = sup_c2w.clone()
    c2w_s[:, :3, 3] *= rescale  # gaussians are in the *rescale frame
    # gsplat's rasterizer works in float32; Ks are not cast anywhere upstream.
    viewmats = torch.inverse(c2w_s).float()  # world-to-camera
    out, _, _ = rasterization(means=xyz, quats=quat, scales=scale, opacities=opac, colors=color,
                              viewmats=viewmats, Ks=sup_K.float(), width=W, height=H,
                              near_plane=0.01, far_plane=1e3, render_mode="RGB")
    return out.clamp(0, 1).permute(0, 3, 1, 2)  # S,H,W,3 -> S,3,H,W
def main():
    """CLI entry point: zero-shot held-out-view PSNR/SSIM over the first N clips.

    Raises:
        SystemExit: if no clip under ``--val-root`` has a ``meta.pt`` (previously a
            ZeroDivisionError when averaging an empty result list).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--val-root", default="/mnt/william/data/unified/av2/val",
                    help="directory whose entries are unified clips (each with meta.pt)")
    ap.add_argument("--n-clips", type=int, default=8,
                    help="evaluate the first N entries of --val-root (sorted by name)")
    ap.add_argument("--save-viz", default="",
                    help="if set, save a GT (top) vs render (bottom) grid of the first evaluated clip")
    args = ap.parse_args()
    jit, V, H, W, rescale = load_jit()
    print(f"InstantNuRec JIT: V={V} H={H} W={W} scene_rescale={rescale}", flush=True)
    clips = sorted(glob.glob(os.path.join(args.val_root, "*")))[: args.n_clips]
    ps, ss = [], []
    for i, c in enumerate(clips):
        if not os.path.exists(os.path.join(c, "meta.pt")):
            continue
        d = load_clip(c, V, H, W)
        with torch.no_grad():
            g = nurec_forward(jit, d, rescale, H, W)
            rgb = render(g, d["sup_c2w"], d["sup_K"], rescale, H, W)
            p, s = float(psnr(rgb, d["sup_img"])), float(ssim(rgb, d["sup_img"]))
        ps.append(p)
        ss.append(s)
        print(f" clip {i} ({os.path.basename(c)}): held-out PSNR {p:.2f} SSIM {s:.3f} | "
              f"gaussians {g[0].shape[0]//1000}k", flush=True)
        # First *evaluated* clip (index 0 may have been skipped for lacking meta.pt).
        if args.save_viz and len(ps) == 1:
            import torchvision
            grid = torch.cat([d["sup_img"][:4], rgb[:4]], 0)
            torchvision.utils.save_image(grid, args.save_viz, nrow=4)
    if not ps:
        raise SystemExit(f"no evaluable clips (with meta.pt) under {args.val_root}")
    print(f"\n=== InstantNuRec ZERO-SHOT on held-out av2/val scenes (n={len(ps)}) ===", flush=True)
    print(f"PSNR {sum(ps)/len(ps):.2f} SSIM {sum(ss)/len(ss):.3f}", flush=True)


if __name__ == "__main__":
    main()