Spaces:
Sleeping
Sleeping
"""Render world-model rollouts to a gif, headless.

Run from the repo root:
    python3 src/render_rollout.py                      # latest run, latest checkpoint
    python3 src/render_rollout.py --steps 80 --num-seeds 4
    python3 src/render_rollout.py --run-dir outputs/2026-06-13/xx-xx-xx --epoch 10

Output: rollout_epoch_XXXXX.gif — one row per seed: model dream (left) vs real
frames (right), both driven by the same recorded action sequence. Seeds are
chosen from test episodes by motion score, so blank-sky segments are skipped.
"""
| import sys | |
| from pathlib import Path | |
| sys.path.insert(0, str(Path(__file__).parent)) | |
| import argparse | |
| import glob | |
| import random | |
| import imageio | |
| import numpy as np | |
| import torch | |
| from hydra.utils import instantiate | |
| from omegaconf import OmegaConf | |
| from agent import Agent | |
| from data.episode import Episode | |
| from models.diffusion import DiffusionSampler | |
| from utils import get_path_agent_ckpt | |
# Number of actions handed to the agent config when the agent is instantiated.
NUM_ACTIONS = 6
# Each pixel is repeated this many times along height and width before the gif is written.
UPSCALE = 4
def _latest_run_dir(root):
    """Return the lexicographically last ``<root>/*/*`` directory.

    Exits with a readable message instead of an ``IndexError`` when no run
    directory exists, and ignores stray files that match the glob.
    """
    run_dirs = sorted(p for p in root.glob("*/*") if p.is_dir())
    if not run_dirs:
        raise SystemExit(f"no run directories found under {root}/ (pass --run-dir)")
    return run_dirs[-1]


def _pick_seeds(episodes, ctx, steps, num_seeds, samples_per_episode=20):
    """Score random clip starts by frame-to-frame motion and keep the liveliest.

    Returns up to ``num_seeds`` tuples ``(motion, episode_index, start)`` sorted
    by decreasing motion. The list is empty when every episode is too short.
    """
    candidates = []
    for ei, ep in enumerate(episodes):
        # Room needed after `start`: ctx conditioning frames plus `steps` targets.
        # The extra -1 is the original author's margin and is kept as-is.
        max_start = len(ep) - ctx - steps - 1
        if max_start < 1:
            continue
        for start in random.sample(range(max_start), min(samples_per_episode, max_start)):
            clip = ep.obs[start + ctx : start + ctx + steps]
            if len(clip) < 2:
                # A single frame has no motion; avoid the NaN produced by the
                # mean of an empty difference, which would corrupt the sort.
                motion = 0.0
            else:
                motion = (clip[1:] - clip[:-1]).abs().mean().item()
            candidates.append((motion, ei, start))
    candidates.sort(reverse=True)
    return candidates[:num_seeds]


def _rollout(sampler, episodes, seeds, ctx, steps, device):
    """Dream ``steps`` frames per seed while replaying the recorded actions.

    Returns ``(dream, real)``: two lists of length ``steps`` holding, per step,
    the sampled frames (moved to CPU) and the matching recorded frames.
    """
    # Original author's note: obs_buf is [N,4,3,64,64] and act_buf is [N,4].
    obs_buf = torch.stack([episodes[ei].obs[s : s + ctx] for _, ei, s in seeds]).to(device)
    act_buf = torch.stack([episodes[ei].act[s : s + ctx] for _, ei, s in seeds]).to(device)
    dream, real = [], []
    with torch.no_grad():
        for t in range(steps):
            # The newest action slot always holds the recorded action that led
            # to the frame about to be predicted.
            for i, (_, ei, s) in enumerate(seeds):
                act_buf[i, -1] = episodes[ei].act[s + ctx - 1 + t]
            next_obs, _ = sampler.sample(obs_buf, act_buf)
            next_obs = next_obs.clamp(-1, 1)
            # Slide both windows; the wrapped-around action slot is overwritten
            # at the top of the next iteration.
            obs_buf = obs_buf.roll(-1, dims=1)
            act_buf = act_buf.roll(-1, dims=1)
            obs_buf[:, -1] = next_obs
            dream.append(next_obs.cpu())
            real.append(torch.stack([episodes[ei].obs[s + ctx + t] for _, ei, s in seeds]))
    return dream, real


def _to_uint8(frames):
    """Convert a list of [N,3,H,W] tensors in [-1,1] to a [T,N,H,W,3] uint8 array, upscaled."""
    x = ((torch.stack(frames) + 1) / 2).permute(0, 1, 3, 4, 2).numpy()
    x = (x * 255).clip(0, 255).astype(np.uint8)
    return x.repeat(UPSCALE, axis=2).repeat(UPSCALE, axis=3)


def _compose_frames(dream, real):
    """Lay out one image per time step: a row per seed, dream left / real right.

    Rows are split by an 8-pixel white vertical bar between dream and real, and
    separated from each other by a 6-pixel black horizontal bar.
    """
    T, N, H, _, _ = dream.shape
    vsep = np.full((T, N, H, 8, 3), 255, np.uint8)
    rows = np.concatenate([dream, vsep, real], axis=3)  # dream | real, per seed
    hsep = np.full((6, rows.shape[3], 3), 0, np.uint8)
    frames = []
    for t in range(T):
        parts = []
        for i in range(N):
            if i:
                parts.append(hsep)
            parts.append(rows[t, i])
        frames.append(np.concatenate(parts, axis=0))
    return frames


def main():
    """Parse CLI options, dream rollouts from test-episode seeds, and write a gif.

    Reads the run directory's ``config/trainer.yaml`` and agent checkpoint,
    picks the highest-motion clips from ``dataset/tobu/test``, rolls the world
    model forward with the recorded actions, and saves
    ``rollout_epoch_<epoch>.gif`` in the current working directory.

    Raises:
        SystemExit: on invalid ``--steps`` / ``--num-seeds`` / ``--fps`` (exit
            code 2 via argparse), when no run directory or test episode is
            found, or when every episode is too short for ``--steps``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--run-dir", type=str, default=None, help="hydra run dir (default: latest under outputs/)")
    parser.add_argument("--epoch", type=int, default=-1, help="agent checkpoint epoch (default: latest)")
    parser.add_argument("--steps", type=int, default=80, help="number of frames to dream")
    parser.add_argument("--num-seeds", type=int, default=4, help="rollouts rendered in parallel (one row each)")
    parser.add_argument("--fps", type=int, default=10)
    args = parser.parse_args()

    # Reject values that would otherwise fail much later with opaque errors
    # (empty torch.stack, division by zero in the gif duration fallback).
    if args.steps < 1:
        parser.error("--steps must be >= 1")
    if args.num_seeds < 1:
        parser.error("--num-seeds must be >= 1")
    if args.fps < 1:
        parser.error("--fps must be >= 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    run_dir = Path(args.run_dir) if args.run_dir else _latest_run_dir(Path("outputs"))
    ckpt = get_path_agent_ckpt(run_dir / "checkpoints", args.epoch)
    print(f"run dir: {run_dir}")
    print(f"checkpoint: {ckpt}")

    # SECURITY NOTE: this lets the YAML config execute arbitrary Python through
    # `${eval:...}`. Only point --run-dir at run directories you trust.
    if not OmegaConf.has_resolver("eval"):
        OmegaConf.register_new_resolver("eval", eval)
    cfg = OmegaConf.load(run_dir / "config" / "trainer.yaml")

    agent = Agent(instantiate(cfg.agent, num_actions=NUM_ACTIONS)).to(device).eval()
    agent.load(ckpt)
    sampler = DiffusionSampler(agent.denoiser, instantiate(cfg.world_model_env.diffusion_sampler))
    ctx = agent.denoiser.cfg.inner_model.num_steps_conditioning  # 4 per the original author's note

    # Candidate seeds from test episodes, scored by how much the real clip moves.
    paths = [p for p in glob.glob("dataset/tobu/test/**/*.pt", recursive=True) if Path(p).stem.isdigit()]
    if not paths:
        raise SystemExit("no test episodes found under dataset/tobu/test")
    episodes = [Episode.load(p) for p in paths]

    seeds = _pick_seeds(episodes, ctx, args.steps, args.num_seeds)
    if not seeds:
        raise SystemExit("episodes too short for the requested --steps")
    for motion, ei, start in seeds:
        print(f"seed: episode {paths[ei]} start {start} motion {motion:.4f}")
    n = len(seeds)

    dream, real = _rollout(sampler, episodes, seeds, ctx, args.steps, device)
    frames = _compose_frames(_to_uint8(dream), _to_uint8(real))

    out = f"rollout_epoch_{ckpt.stem.split('_')[-1]}.gif"
    try:
        imageio.mimsave(out, frames, fps=args.fps, loop=0)
    except TypeError:  # newer imageio dropped `fps` for gif; duration is in milliseconds
        imageio.mimsave(out, frames, duration=1000 / args.fps, loop=0)
    print(f"wrote {out} ({args.steps} frames, {n} seeds; per row: dream left / real right)")
# Run the renderer only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()