"""Action-sensitivity diagnostic for the world model. From the SAME seed, roll out four dreams with the action FORCED to a fixed value (idle / left / right / jump) and compare. If the four rows look the same and divergence is tiny, the model is ignoring your buttons (controls not learned). If they diverge, controls work and any failure is drift / seeding. Run from the repo root: python3 src/action_test.py python3 src/action_test.py --steps 60 --epoch 30 """ import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).parent)) import argparse import glob import random import imageio import numpy as np import torch from hydra.utils import instantiate from omegaconf import OmegaConf from agent import Agent from data.episode import Episode from models.diffusion import DiffusionSampler from utils import get_path_agent_ckpt NUM_ACTIONS = 6 UPSCALE = 4 # (label, action_id) — held constant for the whole rollout PROBES = [("idle", 0), ("left", 1), ("right", 2), ("jump", 3)] def main(): parser = argparse.ArgumentParser() parser.add_argument("--run-dir", type=str, default=None) parser.add_argument("--epoch", type=int, default=-1) parser.add_argument("--steps", type=int, default=48) parser.add_argument("--fps", type=int, default=10) args = parser.parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") run_dir = Path(args.run_dir) if args.run_dir else sorted(Path("outputs").glob("*/*"))[-1] ckpt = get_path_agent_ckpt(run_dir / "checkpoints", args.epoch) print(f"run dir: {run_dir}") print(f"checkpoint: {ckpt}") if not OmegaConf.has_resolver("eval"): OmegaConf.register_new_resolver("eval", eval) cfg = OmegaConf.load(run_dir / "config" / "trainer.yaml") agent = Agent(instantiate(cfg.agent, num_actions=NUM_ACTIONS)).to(device).eval() agent.load(ckpt) sampler = DiffusionSampler(agent.denoiser, instantiate(cfg.world_model_env.diffusion_sampler)) ctx = agent.denoiser.cfg.inner_model.num_steps_conditioning # one high-motion seed paths = [p for p in glob.glob("dataset/tobu/test/**/*.pt", recursive=True) if Path(p).stem.isdigit()] assert paths, "no test episodes under dataset/tobu/test" best = None for _ in range(200): ep = Episode.load(random.choice(paths)) if len(ep) < ctx + args.steps + 2: continue s = random.randrange(len(ep) - ctx - args.steps - 1) clip = ep.obs[s + ctx : s + ctx + args.steps] m = (clip[1:] - clip[:-1]).abs().mean().item() if best is None or m > best[0]: best = (m, ep, s) motion, ep, start = best print(f"seed motion {motion:.4f}, start {start}") n = len(PROBES) obs0 = ep.obs[start : start + ctx].unsqueeze(0).repeat(n, 1, 1, 1, 1).to(device) act0 = ep.act[start : start + ctx].unsqueeze(0).repeat(n, 1).to(device) forced = torch.tensor([a for _, a in PROBES], device=device) obs_buf, act_buf = obs0.clone(), act0.clone() frames = [] with torch.no_grad(): for t in range(args.steps): act_buf[:, -1] = forced next_obs, _ = sampler.sample(obs_buf, act_buf) next_obs = next_obs.clamp(-1, 1) obs_buf = obs_buf.roll(-1, dims=1) act_buf = act_buf.roll(-1, dims=1) obs_buf[:, -1] = next_obs frames.append(next_obs.cpu()) seq = torch.stack(frames) # [T,n,3,H,W] # numeric divergence vs idle (row 0) print("\nmean |frame - idle_frame| over the rollout (0 = identical to idle):") for i, (label, _) in enumerate(PROBES): div = (seq[:, i] - seq[:, 0]).abs().mean().item() bar = "#" * int(div * 400) print(f" {label:>5}: {div:.4f} {bar}") print("\nIf left/right/jump are all ~0.00, the model is IGNORING actions.\n") def to_uint8(x): # [T,3,H,W] -> [T,H,W,3] x = ((x + 1) / 2).permute(0, 2, 3, 1).numpy() x = (x * 255).clip(0, 255).astype(np.uint8) return x.repeat(UPSCALE, axis=1).repeat(UPSCALE, axis=2) rows = [to_uint8(seq[:, i]) for i in range(n)] # each [T,H,W,3] T, H, W, _ = rows[0].shape hsep = np.full((T, 6, W, 3), 200, np.uint8) grid = [] for i, r in enumerate(rows): grid.append(r) if i < n - 1: grid.append(hsep) out_frames = np.concatenate(grid, axis=1) # stack rows vertically out = f"action_test_epoch_{ckpt.stem.split('_')[-1]}.gif" try: imageio.mimsave(out, list(out_frames), fps=args.fps, loop=0) except TypeError: imageio.mimsave(out, list(out_frames), duration=1000 / args.fps, loop=0) print(f"wrote {out} (rows top->bottom: {', '.join(l for l, _ in PROBES)})") if __name__ == "__main__": main()