mapvggt / scripts /demo_extrapolation.py
ChenmingWu's picture
Upload folder using huggingface_hub
8cf92b3 verified
Raw
History Blame Contribute Delete
2.71 kB
#!/usr/bin/env python3
"""Render a lateral-shift sweep for one scene and save a comparison grid:
rows = lateral shift, cols = [MapGS render | GT deviated view]. Showcases the
out-of-trajectory simulation use case (§4.8).
python scripts/demo_extrapolation.py --config configs/synthetic_smoke.yaml \\
--ckpt runs/mapgs_synth/ckpt_final.pt --scene 0 --out demo.png
"""
import argparse
import numpy as np
import torch
from mapgs.config import load_config
from mapgs.data import SyntheticDataset
from mapgs.data.synthetic import render_deviated_gt
from mapgs.eval.evaluator import Evaluator
from mapgs.losses import perturb_pose
from mapgs.model import MapGS
from mapgs.model.dynamic import place_dynamic_gaussians
def main():
    """Render a lateral-shift sweep for one scene and write a comparison grid.

    For every shift in ``cfg.eval.lateral_shifts`` the camera-1 pose of the
    selected frame is perturbed via ``perturb_pose(..., lateral=shift)``. The
    model render (left) and ``render_deviated_gt`` output (right) are
    concatenated into one row. Rows are stacked top-to-bottom and saved to
    ``--out`` as an 8-bit image.

    Raises:
        SystemExit: if the config lists no lateral shifts (nothing to render).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", type=str, default="configs/synthetic_smoke.yaml")
    ap.add_argument("--override", nargs="*", default=[])
    ap.add_argument("--ckpt", type=str, default=None)
    ap.add_argument("--scene", type=int, default=0)
    ap.add_argument("--frame", type=int, default=None)
    ap.add_argument("--out", type=str, default="demo_extrapolation.png")
    args = ap.parse_args()
    # Deferred import: imageio is only needed to write the final image.
    import imageio.v2 as imageio

    cfg = load_config(args.config, args.override)
    shifts = list(cfg.eval.lateral_shifts)
    if not shifts:
        # torch.cat([]) below would otherwise fail with an opaque error.
        raise SystemExit("cfg.eval.lateral_shifts is empty; nothing to render")

    model = MapGS(cfg).to(cfg.device)
    if args.ckpt:
        # NOTE(security): weights_only=False unpickles arbitrary objects;
        # only pass checkpoints from trusted sources.
        state = torch.load(args.ckpt, map_location=cfg.device, weights_only=False)
        model.load_state_dict(state["model"])
    model.eval()

    ds = SyntheticDataset(cfg, "train", device=cfg.device)
    ev = Evaluator(model, cfg, device=cfg.device)
    s = ds[args.scene]
    scene = ds.get_scene(args.scene)
    # Default to the middle frame of the sequence.
    frame = args.frame if args.frame is not None else cfg.data.num_frames // 2
    height, width = cfg.data.height, cfg.data.width

    rows = []
    # Inference only: without no_grad the rendered tensors may carry autograd
    # history, and Tensor.numpy() below refuses tensors that require grad.
    with torch.no_grad():
        g, dyn = ev._decode(s)
        # When the scene has dynamic content, place those Gaussians at `frame`.
        g_f = g if dyn is None else place_dynamic_gaussians(
            g, dyn["box_centers"], dyn["box_rots"], dyn["canon_idx"], frame)
        # Camera index 1 supplies both K and the base pose.
        Kc = scene.K[1].to(cfg.device)
        base = scene.cam2world[frame, 1].to(cfg.device)
        for sh in shifts:
            dev = perturb_pose(base, lateral=float(sh))
            out = ev.ras.render(g_f, Kc[None], dev[None], height, width)
            if model.uses_features:
                pred = model.feature_to_rgb(out.color)[0]
            else:
                pred = out.color[:, :3].clamp(0, 1)[0]
            gt, _ = render_deviated_gt(
                scene, frame, dev.cpu(), cfg.device, height, width, ev.ras)
            # [3, H, 2W]: model render on the left, ground truth on the right.
            # .cpu() on gt is a no-op if it is already on the CPU.
            rows.append(torch.cat([pred.cpu(), gt.cpu()], dim=-1))

    grid = torch.cat(rows, dim=-2).permute(1, 2, 0).clamp(0, 1).numpy()
    # Round (rather than truncate) when quantizing [0, 1] floats to 8 bits.
    imageio.imwrite(args.out, np.rint(grid * 255).astype(np.uint8))
    print(f"wrote {args.out} (rows = shifts {shifts} m; left=MapGS, right=GT)")
# Script entry point; indentation restored so main() runs only when executed directly.
if __name__ == "__main__":
    main()