neural-boy / src /render_rollout.py
sachinkumarsingh's picture
Upload folder using huggingface_hub
d548197 verified
Raw
History Blame Contribute Delete
5.19 kB
"""Render world-model rollouts to a gif, headless.
Run from the repo root:
python3 src/render_rollout.py # latest run, latest checkpoint
python3 src/render_rollout.py --steps 80 --num-seeds 4
python3 src/render_rollout.py --run-dir outputs/2026-06-13/xx-xx-xx --epoch 10
Output: rollout_epoch_XXXXX.gif — one row per seed: model dream (left) vs real
frames (right), both driven by the same recorded action sequence. Seeds are
chosen from test episodes by motion score, so blank-sky segments are skipped.
"""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
import argparse
import glob
import random
import imageio
import numpy as np
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf
from agent import Agent
from data.episode import Episode
from models.diffusion import DiffusionSampler
from utils import get_path_agent_ckpt
# Number of actions, forwarded to the agent config as `num_actions` when the
# agent is instantiated in main().
# NOTE(review): presumably has to match the action space the checkpoint was
# trained with — confirm against the training config.
NUM_ACTIONS = 6
# Integer factor by which every rendered frame is enlarged along both height
# and width (each pixel is repeated this many times per axis).
UPSCALE = 4
def _latest_run_dir(root="outputs"):
    """Return the lexicographically last ``<root>/*/*`` directory, or exit if none exists."""
    runs = sorted(p for p in Path(root).glob("*/*") if p.is_dir())
    if not runs:
        raise SystemExit(f"no run directories found under {root}/ (pass --run-dir explicitly)")
    return runs[-1]


def _select_seeds(episodes, ctx, steps, num_seeds, samples_per_episode=20):
    """Pick the ``num_seeds`` highest-motion (motion, episode_index, start) windows.

    Up to ``samples_per_episode`` random start offsets are drawn from every
    episode long enough to hold ``ctx`` context frames plus ``steps`` target
    frames. Each window is scored by the mean absolute difference between
    consecutive target frames. The result is sorted by descending score.
    """
    candidates = []
    for ei, ep in enumerate(episodes):
        max_start = len(ep) - ctx - steps - 1
        if max_start < 1:
            continue  # episode too short for this window length
        for start in random.sample(range(max_start), min(samples_per_episode, max_start)):
            clip = ep.obs[start + ctx : start + ctx + steps]
            # A single-frame clip has no consecutive pairs; the mean of an empty
            # tensor would be NaN and break the ordering, so score it as 0.
            motion = (clip[1:] - clip[:-1]).abs().mean().item() if len(clip) > 1 else 0.0
            candidates.append((motion, ei, start))
    candidates.sort(reverse=True)
    return candidates[:num_seeds]


def _to_uint8(frames, upscale):
    """Convert a list of [N,3,H,W] tensors into an upscaled [T,N,H,W,3] uint8 array.

    Values are mapped with ``(x + 1) / 2 * 255`` and clipped to 0..255.
    NOTE(review): this assumes inputs live in [-1, 1]; dreamed frames are
    clamped to that range by the caller, real frames are assumed to be — confirm.
    """
    stacked = torch.stack(frames).cpu()  # [T,N,3,H,W]
    arr = ((stacked + 1) / 2).permute(0, 1, 3, 4, 2).numpy()
    arr = (arr * 255).clip(0, 255).astype(np.uint8)
    return arr.repeat(upscale, axis=2).repeat(upscale, axis=3)


def _compose_frames(dream, real, vgap=8, hgap=6):
    """Lay out ``dream | real`` per seed, stacking seeds vertically, one image per timestep.

    Both inputs are [T,N,H,W,3] uint8 arrays. A white ``vgap``-pixel column
    separates dream from real; a black ``hgap``-pixel band separates seed rows.
    """
    T, N, H, _, C = dream.shape
    vsep = np.full((T, N, H, vgap, C), 255, np.uint8)
    rows = np.concatenate([dream, vsep, real], axis=3)
    # The separator is identical for every timestep, so build it once.
    hsep = np.zeros((hgap, rows.shape[3], C), np.uint8)
    frames = []
    for t in range(T):
        parts = []
        for i in range(N):
            if i:
                parts.append(hsep)
            parts.append(rows[t, i])
        frames.append(np.concatenate(parts, axis=0))
    return frames


def main():
    """Render dreamed vs. real rollouts for a trained agent into a gif.

    Parses the command line, loads the agent checkpoint and diffusion sampler
    from a hydra run directory, selects high-motion windows from the test
    episodes, rolls the world model forward under the recorded actions, and
    writes ``rollout_epoch_<epoch>.gif`` to the current working directory.

    Exits with an error message when the arguments are invalid, no run
    directory or test episode can be found, or every episode is too short for
    the requested ``--steps``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--run-dir", type=str, default=None, help="hydra run dir (default: latest under outputs/)")
    parser.add_argument("--epoch", type=int, default=-1, help="agent checkpoint epoch (default: latest)")
    parser.add_argument("--steps", type=int, default=80, help="number of frames to dream")
    parser.add_argument("--num-seeds", type=int, default=4, help="rollouts rendered in parallel (one row each)")
    parser.add_argument("--fps", type=int, default=10)
    args = parser.parse_args()

    # Reject values that would otherwise fail deep inside torch/imageio with an
    # unrelated-looking error (empty stack, division by zero).
    if args.steps < 1:
        parser.error("--steps must be >= 1")
    if args.num_seeds < 1:
        parser.error("--num-seeds must be >= 1")
    if args.fps < 1:
        parser.error("--fps must be >= 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    run_dir = Path(args.run_dir) if args.run_dir else _latest_run_dir("outputs")
    ckpt = get_path_agent_ckpt(run_dir / "checkpoints", args.epoch)
    print(f"run dir: {run_dir}")
    print(f"checkpoint: {ckpt}")

    # NOTE(security): this exposes Python's eval() to `${eval:...}` expressions
    # in the loaded config, i.e. the config can execute arbitrary code. Only
    # point --run-dir at run directories you trust.
    if not OmegaConf.has_resolver("eval"):
        OmegaConf.register_new_resolver("eval", eval)
    cfg = OmegaConf.load(run_dir / "config" / "trainer.yaml")

    agent = Agent(instantiate(cfg.agent, num_actions=NUM_ACTIONS)).to(device).eval()
    agent.load(ckpt)
    sampler = DiffusionSampler(agent.denoiser, instantiate(cfg.world_model_env.diffusion_sampler))
    ctx = agent.denoiser.cfg.inner_model.num_steps_conditioning  # number of conditioning frames

    # Candidate seeds come from test episodes whose file stem is purely numeric.
    paths = [p for p in glob.glob("dataset/tobu/test/**/*.pt", recursive=True) if Path(p).stem.isdigit()]
    if not paths:
        raise SystemExit("no test episodes found under dataset/tobu/test")
    episodes = [Episode.load(p) for p in paths]

    seeds = _select_seeds(episodes, ctx, args.steps, args.num_seeds)
    if not seeds:
        raise SystemExit("episodes too short for the requested --steps")
    for motion, ei, start in seeds:
        print(f"seed: episode {paths[ei]} start {start} motion {motion:.4f}")
    n = len(seeds)

    # Rolling context windows, one row per seed: ctx observations and ctx actions.
    obs_buf = torch.stack([episodes[ei].obs[s : s + ctx] for _, ei, s in seeds]).to(device)
    act_buf = torch.stack([episodes[ei].act[s : s + ctx] for _, ei, s in seeds]).to(device)

    dream, real = [], []
    with torch.no_grad():
        for t in range(args.steps):
            # The newest action slot always holds the recorded action taken at
            # the last context frame, so dream and real share the same controls.
            for i, (_, ei, s) in enumerate(seeds):
                act_buf[i, -1] = episodes[ei].act[s + ctx - 1 + t]
            next_obs, _ = sampler.sample(obs_buf, act_buf)
            next_obs = next_obs.clamp(-1, 1)
            # Shift both windows left by one; the wrapped-around last slot is
            # overwritten here (obs) or at the top of the next iteration (act).
            obs_buf = obs_buf.roll(-1, dims=1)
            act_buf = act_buf.roll(-1, dims=1)
            obs_buf[:, -1] = next_obs
            dream.append(next_obs.cpu())
            real.append(torch.stack([episodes[ei].obs[s + ctx + t] for _, ei, s in seeds]))

    frames = _compose_frames(_to_uint8(dream, UPSCALE), _to_uint8(real, UPSCALE))

    out = f"rollout_epoch_{ckpt.stem.split('_')[-1]}.gif"
    try:
        imageio.mimsave(out, frames, fps=args.fps, loop=0)
    except TypeError:  # newer imageio dropped `fps` for gif; duration is in ms per frame
        imageio.mimsave(out, frames, duration=1000 / args.fps, loop=0)
    print(f"wrote {out} ({args.steps} frames, {n} seeds; per row: dream left / real right)")
# Script entry point: run the renderer only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()