Spaces:
Sleeping
Sleeping
| """Action-sensitivity diagnostic for the world model. | |
| From the SAME seed, roll out four dreams with the action FORCED to a fixed | |
| value (idle / left / right / jump) and compare. If the four rows look the same | |
| and divergence is tiny, the model is ignoring your buttons (controls not | |
| learned). If they diverge, controls work and any failure is drift / seeding. | |
| Run from the repo root: | |
| python3 src/action_test.py | |
| python3 src/action_test.py --steps 60 --epoch 30 | |
| """ | |
| import sys | |
| from pathlib import Path | |
| sys.path.insert(0, str(Path(__file__).parent)) | |
| import argparse | |
| import glob | |
| import random | |
| import imageio | |
| import numpy as np | |
| import torch | |
| from hydra.utils import instantiate | |
| from omegaconf import OmegaConf | |
| from agent import Agent | |
| from data.episode import Episode | |
| from models.diffusion import DiffusionSampler | |
| from utils import get_path_agent_ckpt | |
# Size of the agent's discrete action space; handed to the agent config at load time.
NUM_ACTIONS = 6
# Integer pixel-repeat factor applied to every frame written to the output GIF.
UPSCALE = 4

# Probe actions as (label, action_id). Each probe holds its action constant for
# its whole rollout; the first entry is the baseline the others are diffed against.
PROBES = [
    ("idle", 0),
    ("left", 1),
    ("right", 2),
    ("jump", 3),
]
def _pick_seed_clip(paths, ctx, steps, tries=200):
    """Sample random test windows and keep the one with the most real motion.

    Draws up to ``tries`` random (episode, start) pairs from ``paths`` using the
    module-level ``random`` RNG. Episodes too short to hold ``ctx`` context frames
    plus ``steps`` future frames (and a margin) are skipped.

    Returns:
        ``(motion, episode, start)`` for the best window, where ``motion`` is the
        mean absolute frame-to-frame difference over the ``steps`` real frames
        that follow the context window.

    Raises:
        SystemExit: if none of the sampled episodes was long enough.
    """
    best = None
    for _ in range(tries):
        ep = Episode.load(random.choice(paths))
        if len(ep) < ctx + steps + 2:
            continue
        start = random.randrange(len(ep) - ctx - steps - 1)
        clip = ep.obs[start + ctx : start + ctx + steps]
        # A single-frame clip has no frame-to-frame differences to average.
        motion = (clip[1:] - clip[:-1]).abs().mean().item() if len(clip) > 1 else 0.0
        if best is None or motion > best[0]:
            best = (motion, ep, start)
    if best is None:
        sys.exit(
            f"none of {tries} sampled test episodes is long enough for "
            f"ctx={ctx} + steps={steps}; try a smaller --steps"
        )
    return best


@torch.no_grad()
def _rollout(sampler, obs0, act0, action, steps, noise_seed):
    """Autoregressively dream ``steps`` frames from one context, holding ``action`` fixed.

    ``obs0``/``act0`` are a batch-of-one context window (built in ``main`` as
    ``[1, ctx, ...]`` slices of an episode). The torch RNG is reset to
    ``noise_seed`` first so every probe consumes the same sampler noise and the
    probes differ only by their action. NOTE(review): this assumes the sampler
    draws its noise from torch's global RNG — confirm if the sampler changes.

    Returns:
        A CPU tensor stacking the ``steps`` generated frames (each clamped to
        [-1, 1]) along a new leading time axis.
    """
    torch.manual_seed(noise_seed)
    obs_buf, act_buf = obs0.clone(), act0.clone()
    frames = []
    for _ in range(steps):
        # Force the newest action slot; older slots fill with it as the window rolls.
        act_buf[:, -1] = action
        next_obs, _ = sampler.sample(obs_buf, act_buf)
        next_obs = next_obs.clamp(-1, 1)
        # Slide the context window by one. The action slot that wraps around is
        # overwritten at the top of the next iteration.
        obs_buf = obs_buf.roll(-1, dims=1)
        act_buf = act_buf.roll(-1, dims=1)
        obs_buf[:, -1] = next_obs
        frames.append(next_obs[0].cpu())
    return torch.stack(frames)


def _to_uint8(x):
    """Map a [T, 3, H, W] tensor in [-1, 1] to [T, H*UPSCALE, W*UPSCALE, 3] uint8."""
    x = ((x + 1) / 2).permute(0, 2, 3, 1).numpy()
    x = (x * 255).clip(0, 255).astype(np.uint8)
    return x.repeat(UPSCALE, axis=1).repeat(UPSCALE, axis=2)


def _stack_rows(rows, gap=6, fill=200):
    """Stack [T, H, W, 3] uint8 clips top-to-bottom with a ``gap``-pixel gray bar between them."""
    T, _, W, _ = rows[0].shape
    sep = np.full((T, gap, W, 3), fill, np.uint8)
    parts = []
    for i, row in enumerate(rows):
        if i:
            parts.append(sep)
        parts.append(row)
    return np.concatenate(parts, axis=1)


def _save_gif(path, frames, fps):
    """Write ``frames`` as an endlessly looping GIF at ``fps`` frames per second."""
    try:
        imageio.mimsave(path, list(frames), fps=fps, loop=0)
    except TypeError:
        # Fallback for imageio backends that reject ``fps``; they take a
        # per-frame ``duration`` in milliseconds instead.
        imageio.mimsave(path, list(frames), duration=1000 / fps, loop=0)


def main():
    """Roll out one seed context under each forced action in PROBES and compare.

    Prints the mean absolute per-pixel difference between each probe's rollout
    and the idle (first) probe's rollout, then writes
    ``action_test_epoch_<N>.gif`` to the working directory with one row per
    probe, top to bottom in PROBES order. All probes share the same diffusion
    noise, so a model that ignores actions produces (near-)identical rows.
    """
    parser = argparse.ArgumentParser(description="Action-sensitivity diagnostic for the world model.")
    parser.add_argument("--run-dir", type=str, default=None, help="run directory (default: last of outputs/*/*)")
    parser.add_argument("--epoch", type=int, default=-1)
    parser.add_argument("--steps", type=int, default=48, help="frames to dream per probe")
    parser.add_argument("--fps", type=int, default=10, help="GIF playback rate")
    parser.add_argument(
        "--seed", type=int, default=None,
        help="RNG seed for seed-clip choice and sampler noise (default: random, printed)",
    )
    args = parser.parse_args()
    if args.steps < 1:
        parser.error("--steps must be >= 1")
    if args.fps < 1:
        parser.error("--fps must be >= 1")

    seed = args.seed if args.seed is not None else random.randrange(2**31)
    random.seed(seed)
    print(f"rng seed: {seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.run_dir:
        run_dir = Path(args.run_dir)
    else:
        # Assumes run dirs are named so lexicographic order is chronological
        # (e.g. outputs/<date>/<time>) — the last one is taken as the newest.
        runs = sorted(p for p in Path("outputs").glob("*/*") if p.is_dir())
        if not runs:
            parser.error("no run directories under outputs/; pass --run-dir")
        run_dir = runs[-1]
    ckpt = get_path_agent_ckpt(run_dir / "checkpoints", args.epoch)
    print(f"run dir: {run_dir}")
    print(f"checkpoint: {ckpt}")

    # Presumably needed for ${eval:...} interpolations in the saved config.
    # NOTE: this evaluates Python from the config file — only use trusted run dirs.
    if not OmegaConf.has_resolver("eval"):
        OmegaConf.register_new_resolver("eval", eval)
    cfg = OmegaConf.load(run_dir / "config" / "trainer.yaml")
    agent = Agent(instantiate(cfg.agent, num_actions=NUM_ACTIONS)).to(device).eval()
    agent.load(ckpt)
    sampler = DiffusionSampler(agent.denoiser, instantiate(cfg.world_model_env.diffusion_sampler))
    ctx = agent.denoiser.cfg.inner_model.num_steps_conditioning

    # One high-motion seed clip from the test split.
    paths = [p for p in glob.glob("dataset/tobu/test/**/*.pt", recursive=True) if Path(p).stem.isdigit()]
    if not paths:
        sys.exit("no test episodes under dataset/tobu/test")
    motion, ep, start = _pick_seed_clip(paths, ctx, args.steps)
    print(f"seed motion {motion:.4f}, start {start}")

    obs0 = ep.obs[start : start + ctx].unsqueeze(0).to(device)
    act0 = ep.act[start : start + ctx].unsqueeze(0).to(device)
    # One rollout per probe, each under the same noise seed, stacked on axis 1.
    seq = torch.stack(
        [_rollout(sampler, obs0, act0, action, args.steps, seed) for _, action in PROBES],
        dim=1,
    )  # [T, n, 3, H, W]

    # Numeric divergence vs idle (row 0).
    print("\nmean |frame - idle_frame| over the rollout (0 = identical to idle):")
    for i, (label, _) in enumerate(PROBES):
        div = (seq[:, i] - seq[:, 0]).abs().mean().item()
        bar = "#" * min(int(div * 400), 80)  # cap so large values don't wrap the terminal
        print(f" {label:>5}: {div:.4f} {bar}")
    print("\nIf left/right/jump are all ~0.00, the model is IGNORING actions.\n")

    out_frames = _stack_rows([_to_uint8(seq[:, i]) for i in range(len(PROBES))])
    out = f"action_test_epoch_{ckpt.stem.split('_')[-1]}.gif"
    _save_gif(out, out_frames, args.fps)
    print(f"wrote {out} (rows top->bottom: {', '.join(l for l, _ in PROBES)})")
if __name__ == "__main__":
    # Script entry point: main() returns None, so the process exits with status 0.
    raise SystemExit(main())