| |
"""Render a lateral-shift sweep for one scene and save it as a comparison grid.

Each row of the grid corresponds to one lateral shift from
``cfg.eval.lateral_shifts``; the two columns are [MapGS render | GT render of
the deviated view]. This showcases the out-of-trajectory simulation use case
(§4.8). ``--frame`` defaults to the middle frame of the scene.

    python scripts/demo_extrapolation.py --config configs/synthetic_smoke.yaml \\
        --ckpt runs/mapgs_synth/ckpt_final.pt --scene 0 --out demo.png
"""
|
|
| import argparse |
|
|
| import numpy as np |
| import torch |
|
|
| from mapgs.config import load_config |
| from mapgs.data import SyntheticDataset |
| from mapgs.data.synthetic import render_deviated_gt |
| from mapgs.eval.evaluator import Evaluator |
| from mapgs.losses import perturb_pose |
| from mapgs.model import MapGS |
| from mapgs.model.dynamic import place_dynamic_gaussians |
|
|
|
|
@torch.no_grad()
def _render_sweep(model, ev, sample, scene, frame, cfg):
    """Render MapGS and the GT deviated view for every configured lateral shift.

    Runs under ``torch.no_grad()``: this is pure inference, and tensors that
    track gradients cannot be converted to NumPy for saving.

    Returns:
        A list with one CPU tensor per entry of ``cfg.eval.lateral_shifts``.
        Each tensor is the MapGS render and the GT render concatenated along
        the last axis (side by side).
    """
    g, dyn = ev._decode(sample)
    # When the decoder yields dynamic objects, place them at the requested frame.
    if dyn is None:
        g_f = g
    else:
        g_f = place_dynamic_gaussians(
            g, dyn["box_centers"], dyn["box_rots"], dyn["canon_idx"], frame)
    # Camera index 1 is used for both the intrinsics and the base pose.
    Kc = scene.K[1].to(cfg.device)
    base = scene.cam2world[frame, 1].to(cfg.device)
    h, w = cfg.data.height, cfg.data.width

    rows = []
    for sh in cfg.eval.lateral_shifts:
        dev = perturb_pose(base, lateral=float(sh))
        out = ev.ras.render(g_f, Kc[None], dev[None], h, w)
        if model.uses_features:
            pred = model.feature_to_rgb(out.color)
        else:
            pred = out.color[:, :3].clamp(0, 1)
        pred = pred[0]  # drop the batch dimension of the single rendered view
        gt, _ = render_deviated_gt(scene, frame, dev.cpu(), cfg.device, h, w, ev.ras)
        # Move both halves to CPU: a no-op for CPU tensors, and it avoids a
        # device mismatch in torch.cat should the GT renderer return a tensor
        # on cfg.device.
        rows.append(torch.cat([pred.detach().cpu(), gt.detach().cpu()], dim=-1))
    return rows


def main(argv=None):
    """Render the lateral-shift comparison grid for one scene and write it to disk.

    Args:
        argv: Optional list of command-line arguments. Defaults to ``None``,
            which makes argparse read ``sys.argv[1:]``, so running this as a
            script behaves as before.

    Exits via ``SystemExit`` (argparse usage error, code 2) for a negative or
    out-of-range ``--frame``, and with a message when the config lists no
    lateral shifts.
    """
    ap = argparse.ArgumentParser(
        description="Render a lateral-shift sweep (MapGS vs. GT deviated view) for one scene.")
    ap.add_argument("--config", type=str, default="configs/synthetic_smoke.yaml")
    ap.add_argument("--override", nargs="*", default=[])
    ap.add_argument("--ckpt", type=str, default=None)
    ap.add_argument("--scene", type=int, default=0)
    ap.add_argument("--frame", type=int, default=None)
    ap.add_argument("--out", type=str, default="demo_extrapolation.png")
    args = ap.parse_args(argv)
    # Reject negative frames up front; tensor indexing would silently wrap them.
    if args.frame is not None and args.frame < 0:
        ap.error(f"--frame must be non-negative, got {args.frame}")

    # Imported lazily, presumably so imageio is only needed when the demo runs.
    import imageio.v2 as imageio

    cfg = load_config(args.config, args.override)
    if not list(cfg.eval.lateral_shifts):
        # torch.cat on an empty row list would otherwise fail with an opaque error.
        raise SystemExit("cfg.eval.lateral_shifts is empty; nothing to render")

    model = MapGS(cfg).to(cfg.device)
    if args.ckpt:
        # NOTE(security): weights_only=False unpickles arbitrary Python objects;
        # only load checkpoints from trusted sources.
        state = torch.load(args.ckpt, map_location=cfg.device, weights_only=False)
        model.load_state_dict(state["model"])
    model.eval()

    ds = SyntheticDataset(cfg, "train", device=cfg.device)
    ev = Evaluator(model, cfg, device=cfg.device)
    sample = ds[args.scene]
    scene = ds.get_scene(args.scene)
    frame = args.frame if args.frame is not None else cfg.data.num_frames // 2
    num_frames = scene.cam2world.shape[0]
    if frame >= num_frames:
        ap.error(f"--frame {frame} is out of range for a scene with {num_frames} frames")

    rows = _render_sweep(model, ev, sample, scene, frame, cfg)
    # Rows stack vertically (one per shift); permute to HWC for image writing.
    grid = torch.cat(rows, dim=-2).permute(1, 2, 0).clamp(0, 1).numpy()
    imageio.imwrite(args.out, (grid * 255).astype(np.uint8))
    print(f"wrote {args.out} (rows = shifts {list(cfg.eval.lateral_shifts)} m; left=MapGS, right=GT)")
|
|
|
|
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|