# neural-boy / src/action_test.py
# Hub commit d548197 (verified) by sachinkumarsingh: "Upload folder using huggingface_hub"; file size 4.79 kB.
"""Action-sensitivity diagnostic for the world model.
From the SAME seed, roll out four dreams with the action FORCED to a fixed
value (idle / left / right / jump) and compare. If the four rows look the same
and divergence is tiny, the model is ignoring your buttons (controls not
learned). If they diverge, controls work and any failure is drift / seeding.
Run from the repo root:
python3 src/action_test.py
python3 src/action_test.py --steps 60 --epoch 30
"""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
import argparse
import glob
import random
import imageio
import numpy as np
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf
from agent import Agent
from data.episode import Episode
from models.diffusion import DiffusionSampler
from utils import get_path_agent_ckpt
# Size of the action space handed to the agent constructor in main().
NUM_ACTIONS = 6

# Integer factor by which every output pixel is repeated along height and
# width when the comparison GIF is rendered.
UPSCALE = 4

# (label, action_id) pairs. Each entry produces one rollout in which that
# action id is held constant for every step; entry 0 ("idle") is the baseline
# the other rollouts are compared against.
PROBES = [
    ("idle", 0),
    ("left", 1),
    ("right", 2),
    ("jump", 3),
]
def main():
    """Run the action-sensitivity diagnostic and write a comparison GIF.

    Loads an agent checkpoint, picks a high-motion seed window from the test
    episodes, then rolls out one dream per entry in PROBES with the most
    recent action forced to that entry's action id at every step. Prints the
    mean absolute difference of each rollout against the first ("idle")
    rollout and writes ``action_test_epoch_<N>.gif`` to the current working
    directory, one row per probe.

    Command-line options:
        --run-dir  Run directory containing ``checkpoints/`` and
                   ``config/trainer.yaml``. Defaults to the last entry of
                   ``outputs/*/*`` in sorted order.
        --epoch    Checkpoint epoch passed to ``get_path_agent_ckpt``.
        --steps    Number of frames to dream (must be >= 1).
        --fps      Frame rate of the output GIF (must be >= 1).

    Raises:
        FileNotFoundError: no run directory or no test episodes were found.
        RuntimeError: none of the sampled episodes is long enough for the
            requested context + rollout length.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--run-dir", type=str, default=None)
    parser.add_argument("--epoch", type=int, default=-1)
    parser.add_argument("--steps", type=int, default=48)
    parser.add_argument("--fps", type=int, default=10)
    args = parser.parse_args()

    # Fail early with a usage message instead of deep inside torch.stack([])
    # (steps == 0) or on a division by zero in the GIF fallback (fps == 0).
    if args.steps < 1:
        parser.error("--steps must be >= 1")
    if args.fps < 1:
        parser.error("--fps must be >= 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if args.run_dir:
        run_dir = Path(args.run_dir)
    else:
        # Default to the last run in sorted order under outputs/<a>/<b>.
        candidates = sorted(Path("outputs").glob("*/*"))
        if not candidates:
            raise FileNotFoundError(
                "no run directory found under outputs/*/*; pass --run-dir"
            )
        run_dir = candidates[-1]

    ckpt = get_path_agent_ckpt(run_dir / "checkpoints", args.epoch)
    print(f"run dir: {run_dir}")
    print(f"checkpoint: {ckpt}")

    # NOTE(security): this registers Python's eval as a config resolver, so
    # the loaded trainer.yaml can execute arbitrary expressions. Only point
    # this script at run directories you trust.
    if not OmegaConf.has_resolver("eval"):
        OmegaConf.register_new_resolver("eval", eval)
    cfg = OmegaConf.load(run_dir / "config" / "trainer.yaml")

    agent = Agent(instantiate(cfg.agent, num_actions=NUM_ACTIONS)).to(device).eval()
    agent.load(ckpt)
    sampler = DiffusionSampler(
        agent.denoiser, instantiate(cfg.world_model_env.diffusion_sampler)
    )
    # Number of past frames/actions the denoiser is conditioned on.
    ctx = agent.denoiser.cfg.inner_model.num_steps_conditioning

    # one high-motion seed: sample up to 200 random windows and keep the one
    # whose ground-truth frames change the most from step to step.
    paths = [
        p
        for p in glob.glob("dataset/tobu/test/**/*.pt", recursive=True)
        if Path(p).stem.isdigit()
    ]
    if not paths:
        # Explicit raise (not assert) so the check survives `python -O`.
        raise FileNotFoundError("no test episodes under dataset/tobu/test")

    best = None
    for _ in range(200):
        candidate = Episode.load(random.choice(paths))
        if len(candidate) < ctx + args.steps + 2:
            continue
        s = random.randrange(len(candidate) - ctx - args.steps - 1)
        clip = candidate.obs[s + ctx : s + ctx + args.steps]
        m = (clip[1:] - clip[:-1]).abs().mean().item()
        if best is None or m > best[0]:
            best = (m, candidate, s)
    if best is None:
        # Every sampled episode was skipped as too short; without this check
        # the unpacking below would fail with an opaque TypeError.
        raise RuntimeError(
            f"no sampled test episode has at least {ctx + args.steps + 2} steps; "
            "try a smaller --steps"
        )
    motion, ep, start = best
    print(f"seed motion {motion:.4f}, start {start}")

    # Replicate the same seed context once per probe so all rollouts start
    # from identical history (repeat() already returns fresh tensors).
    n = len(PROBES)
    obs_buf = ep.obs[start : start + ctx].unsqueeze(0).repeat(n, 1, 1, 1, 1).to(device)
    act_buf = ep.act[start : start + ctx].unsqueeze(0).repeat(n, 1).to(device)
    forced = torch.tensor([a for _, a in PROBES], device=device)

    # NOTE(review): if the sampler draws independent noise per batch row,
    # some divergence will show up even when actions are ignored — confirm
    # against DiffusionSampler before reading small values as "controls work".
    frames = []
    with torch.no_grad():
        for _ in range(args.steps):
            # Overwrite the most recent action with each probe's fixed action.
            act_buf[:, -1] = forced
            next_obs, _ = sampler.sample(obs_buf, act_buf)
            next_obs = next_obs.clamp(-1, 1)
            # Slide both windows one step; the wrapped-around last slot is
            # overwritten here (obs) and at the top of the next pass (act).
            obs_buf = obs_buf.roll(-1, dims=1)
            act_buf = act_buf.roll(-1, dims=1)
            obs_buf[:, -1] = next_obs
            frames.append(next_obs.cpu())
    seq = torch.stack(frames)  # [T,n,3,H,W]

    # numeric divergence vs idle (row 0)
    print("\nmean |frame - idle_frame| over the rollout (0 = identical to idle):")
    for i, (label, _) in enumerate(PROBES):
        div = (seq[:, i] - seq[:, 0]).abs().mean().item()
        bar = "#" * int(div * 400)
        print(f" {label:>5}: {div:.4f} {bar}")
    print("\nIf left/right/jump are all ~0.00, the model is IGNORING actions.\n")

    def to_uint8(x):  # [T,3,H,W] -> [T,H,W,3]
        """Map frames from [-1, 1] to uint8 and upscale by pixel repetition."""
        x = ((x + 1) / 2).permute(0, 2, 3, 1).numpy()
        x = (x * 255).clip(0, 255).astype(np.uint8)
        return x.repeat(UPSCALE, axis=1).repeat(UPSCALE, axis=2)

    rows = [to_uint8(seq[:, i]) for i in range(n)]  # each [T,H,W,3]
    T, _, W, _ = rows[0].shape
    # 6-pixel-tall light grey band drawn between consecutive probe rows.
    hsep = np.full((T, 6, W, 3), 200, np.uint8)
    grid = []
    for i, r in enumerate(rows):
        grid.append(r)
        if i < n - 1:
            grid.append(hsep)
    out_frames = np.concatenate(grid, axis=1)  # stack rows vertically

    out = f"action_test_epoch_{ckpt.stem.split('_')[-1]}.gif"
    try:
        imageio.mimsave(out, list(out_frames), fps=args.fps, loop=0)
    except TypeError:
        # Fallback for imageio builds that reject `fps` for this writer:
        # pass the per-frame duration in milliseconds instead.
        imageio.mimsave(out, list(out_frames), duration=1000 / args.fps, loop=0)
    print(f"wrote {out} (rows top->bottom: {', '.join(label for label, _ in PROBES)})")
# Script entry point: run the diagnostic only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()