#!/usr/bin/env python3
"""Render a lateral-shift sweep for one scene and save a comparison grid.

Grid layout: rows = lateral shift, cols = [MapGS render | GT deviated view].
Showcases the out-of-trajectory simulation use case (§4.8).

    python scripts/demo_extrapolation.py --config configs/synthetic_smoke.yaml \\
        --ckpt runs/mapgs_synth/ckpt_final.pt --scene 0 --out demo.png
"""
import argparse

import numpy as np
import torch

from mapgs.config import load_config
from mapgs.data import SyntheticDataset
from mapgs.data.synthetic import render_deviated_gt
from mapgs.eval.evaluator import Evaluator
from mapgs.losses import perturb_pose
from mapgs.model import MapGS
from mapgs.model.dynamic import place_dynamic_gaussians


def _parse_args():
    """Parse the command-line options of the demo script."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", type=str, default="configs/synthetic_smoke.yaml")
    ap.add_argument("--override", nargs="*", default=[])
    ap.add_argument("--ckpt", type=str, default=None)
    ap.add_argument("--scene", type=int, default=0)
    ap.add_argument("--frame", type=int, default=None)
    ap.add_argument("--out", type=str, default="demo_extrapolation.png")
    return ap.parse_args()


def main():
    """Render one scene at every configured lateral shift and write a grid image.

    For each value in ``cfg.eval.lateral_shifts`` the base camera pose of the
    chosen frame is perturbed laterally, the decoded Gaussians are rendered
    from that pose, and the result is placed next to the ground-truth render
    of the same deviated view. All rows are stacked vertically and written to
    ``--out`` as an 8-bit image.

    Raises:
        ValueError: if ``cfg.eval.lateral_shifts`` is empty.
    """
    args = _parse_args()

    # Deferred import: argument parsing (e.g. ``--help``) happens before
    # imageio is required.
    import imageio.v2 as imageio

    cfg = load_config(args.config, args.override)

    shifts = list(cfg.eval.lateral_shifts)
    if not shifts:
        # Fail early with a readable message instead of an opaque torch.cat
        # error after the model and dataset have already been built.
        raise ValueError("cfg.eval.lateral_shifts is empty; nothing to render")

    model = MapGS(cfg).to(cfg.device)
    if args.ckpt:
        # NOTE(security): weights_only=False lets torch.load unpickle
        # arbitrary objects. Only pass checkpoints from a trusted source.
        state = torch.load(args.ckpt, map_location=cfg.device, weights_only=False)
        model.load_state_dict(state["model"])
    model.eval()

    ds = SyntheticDataset(cfg, "train", device=cfg.device)
    ev = Evaluator(model, cfg, device=cfg.device)
    s = ds[args.scene]
    scene = ds.get_scene(args.scene)
    # Default to the middle frame of the sequence when --frame is not given.
    frame = args.frame if args.frame is not None else cfg.data.num_frames // 2

    height, width = cfg.data.height, cfg.data.width

    # Inference only. model.eval() does not turn autograd off; without
    # no_grad, outputs that pass through model parameters may require grad,
    # which makes the final Tensor.numpy() call raise.
    with torch.no_grad():
        g, dyn = ev._decode(s)
        if dyn is None:
            g_f = g
        else:
            g_f = place_dynamic_gaussians(
                g, dyn["box_centers"], dyn["box_rots"], dyn["canon_idx"], frame)

        # Intrinsics and base pose are taken from camera index 1.
        Kc = scene.K[1].to(cfg.device)
        base = scene.cam2world[frame, 1].to(cfg.device)

        rows = []
        for sh in shifts:
            dev = perturb_pose(base, lateral=float(sh))
            out = ev.ras.render(g_f, Kc[None], dev[None], height, width)
            if model.uses_features:
                pred = model.feature_to_rgb(out.color)
            else:
                pred = out.color[:, :3].clamp(0, 1)
            pred = pred[0]
            gt, _ = render_deviated_gt(
                scene, frame, dev.cpu(), cfg.device, height, width, ev.ras)
            # Bring both halves to the CPU (a no-op for tensors already
            # there) so the concatenation cannot hit a device mismatch.
            row = torch.cat([pred.detach().cpu(), gt.detach().cpu()], dim=-1)  # [3, H, 2W]
            rows.append(row)

    # Stack rows vertically, move channels last, and convert to uint8.
    grid = torch.cat(rows, dim=-2).permute(1, 2, 0).clamp(0, 1).numpy()
    imageio.imwrite(args.out, (grid * 255).astype(np.uint8))
    print(f"wrote {args.out}  (rows = shifts {shifts} m; left=MapGS, right=GT)")


if __name__ == "__main__":
    main()