"""Render world-model rollouts to a gif, headless. Run from the repo root: python3 src/render_rollout.py # latest run, latest checkpoint python3 src/render_rollout.py --steps 80 --num-seeds 4 python3 src/render_rollout.py --run-dir outputs/2026-06-13/xx-xx-xx --epoch 10 Output: rollout_epoch_XXXXX.gif — one row per seed: model dream (left) vs real frames (right), both driven by the same recorded action sequence. Seeds are chosen from test episodes by motion score, so blank-sky segments are skipped. """ import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).parent)) import argparse import glob import random import imageio import numpy as np import torch from hydra.utils import instantiate from omegaconf import OmegaConf from agent import Agent from data.episode import Episode from models.diffusion import DiffusionSampler from utils import get_path_agent_ckpt NUM_ACTIONS = 6 UPSCALE = 4 def main(): parser = argparse.ArgumentParser() parser.add_argument("--run-dir", type=str, default=None, help="hydra run dir (default: latest under outputs/)") parser.add_argument("--epoch", type=int, default=-1, help="agent checkpoint epoch (default: latest)") parser.add_argument("--steps", type=int, default=80, help="number of frames to dream") parser.add_argument("--num-seeds", type=int, default=4, help="rollouts rendered in parallel (one row each)") parser.add_argument("--fps", type=int, default=10) args = parser.parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") run_dir = Path(args.run_dir) if args.run_dir else sorted(Path("outputs").glob("*/*"))[-1] ckpt = get_path_agent_ckpt(run_dir / "checkpoints", args.epoch) print(f"run dir: {run_dir}") print(f"checkpoint: {ckpt}") if not OmegaConf.has_resolver("eval"): OmegaConf.register_new_resolver("eval", eval) cfg = OmegaConf.load(run_dir / "config" / "trainer.yaml") agent = Agent(instantiate(cfg.agent, num_actions=NUM_ACTIONS)).to(device).eval() agent.load(ckpt) sampler = DiffusionSampler(agent.denoiser, instantiate(cfg.world_model_env.diffusion_sampler)) ctx = agent.denoiser.cfg.inner_model.num_steps_conditioning # 4 # Candidate seeds from test episodes, scored by how much the real clip moves paths = [p for p in glob.glob("dataset/tobu/test/**/*.pt", recursive=True) if Path(p).stem.isdigit()] assert paths, "no test episodes found under dataset/tobu/test" episodes = [Episode.load(p) for p in paths] candidates = [] for ei, ep in enumerate(episodes): max_start = len(ep) - ctx - args.steps - 1 if max_start < 1: continue for start in random.sample(range(max_start), min(20, max_start)): clip = ep.obs[start + ctx : start + ctx + args.steps] motion = (clip[1:] - clip[:-1]).abs().mean().item() candidates.append((motion, ei, start)) assert candidates, "episodes too short for the requested --steps" candidates.sort(reverse=True) seeds = candidates[: args.num_seeds] for motion, ei, start in seeds: print(f"seed: episode {paths[ei]} start {start} motion {motion:.4f}") n = len(seeds) obs_buf = torch.stack([episodes[ei].obs[s : s + ctx] for _, ei, s in seeds]).to(device) # [N,4,3,64,64] act_buf = torch.stack([episodes[ei].act[s : s + ctx] for _, ei, s in seeds]).to(device) # [N,4] dream, real = [], [] for t in range(args.steps): for i, (_, ei, s) in enumerate(seeds): act_buf[i, -1] = episodes[ei].act[s + ctx - 1 + t] with torch.no_grad(): next_obs, _ = sampler.sample(obs_buf, act_buf) next_obs = next_obs.clamp(-1, 1) obs_buf = obs_buf.roll(-1, dims=1) act_buf = act_buf.roll(-1, dims=1) obs_buf[:, -1] = next_obs dream.append(next_obs.cpu()) real.append(torch.stack([episodes[ei].obs[s + ctx + t] for _, ei, s in seeds])) def to_uint8(x): # [T,N,3,H,W] in [-1,1] -> [T,N,H,W,3] uint8, upscaled x = ((torch.stack(x) + 1) / 2).permute(0, 1, 3, 4, 2).numpy() x = (x * 255).clip(0, 255).astype(np.uint8) return x.repeat(UPSCALE, axis=2).repeat(UPSCALE, axis=3) d, r = to_uint8(dream), to_uint8(real) # [T,N,H,W,3] T, N, H, W, _ = d.shape vsep = np.full((T, N, H, 8, 3), 255, np.uint8) rows = np.concatenate([d, vsep, r], axis=3) # dream | real, per seed hsep = np.full((T, 6, rows.shape[3], 3), 0, np.uint8) frames = [] for t in range(T): parts = [] for i in range(N): parts.append(rows[t, i]) if i < N - 1: parts.append(hsep[t]) frames.append(np.concatenate(parts, axis=0)) out = f"rollout_epoch_{ckpt.stem.split('_')[-1]}.gif" try: imageio.mimsave(out, frames, fps=args.fps, loop=0) except TypeError: # newer imageio dropped `fps` for gif imageio.mimsave(out, frames, duration=1000 / args.fps, loop=0) print(f"wrote {out} ({args.steps} frames, {n} seeds; per row: dream left / real right)") if __name__ == "__main__": main()