pyre_env / training /ppo /debug_episode.py
Krooz's picture
Upload folder using huggingface_hub
f5a6010 verified
"""
Debug a single Pyre episode using a trained PPO checkpoint.
Runs one episode on the specified difficulty, prints a rich step-by-step
trace, and produces a failure analysis at the end.
Usage
-----
# Debug a hard episode with the latest checkpoint
uv run python training/ppo/debug_episode.py \
--checkpoint artifacts/pyre_ppo_hard_ckpt.pt \
--difficulty hard
# Override difficulty or seed for reproducibility
uv run python training/ppo/debug_episode.py \
--checkpoint artifacts/pyre_ppo_hard_ckpt.pt \
--difficulty hard \
--seed 7
# Stochastic policy (not deterministic) to see exploration behaviour
uv run python training/ppo/debug_episode.py \
--checkpoint artifacts/pyre_ppo_hard_ckpt.pt \
--difficulty hard \
--stochastic
"""
from __future__ import annotations
import argparse
import sys
from collections import deque
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
import torch
_HERE = Path(__file__).resolve().parent
_ROOT = _HERE.parent.parent
if str(_ROOT) not in sys.path:
sys.path.insert(0, str(_ROOT))
from training.ppo.train_torch_ppo import (
ACTION_KEYS,
ACTION_DIM,
ActorCritic,
ObservationEncoder,
RolloutBuffer,
action_index_to_env_action,
build_action_mask,
)
from training.ppo.train_torch_ppo_http import HttpPyreEnv
# ── ANSI colours (degrade gracefully on Windows) ──────────────────────────────
try:
import os
_COLOUR = os.get_terminal_size().columns > 0
except Exception:
_COLOUR = False
def _c(text: str, code: str) -> str:
return f"\033[{code}m{text}\033[0m" if _COLOUR else text
RED = lambda t: _c(t, "31")
YELLOW = lambda t: _c(t, "33")
GREEN = lambda t: _c(t, "32")
CYAN = lambda t: _c(t, "36")
BOLD = lambda t: _c(t, "1")
DIM = lambda t: _c(t, "2")
# ── Mini ASCII map renderer ───────────────────────────────────────────────────
_CELL_CHAR = {0: ".", 1: "#", 2: "/", 3: "+", 4: "E", 5: "X"}
_CELL_NAME = {0: "floor", 1: "wall", 2: "door(open)", 3: "door(closed)", 4: "EXIT", 5: "obstacle"}
def render_ascii_map(ms, visible_set: set, step: int) -> str:
"""Render a small ASCII map centered on the agent (15×15 window)."""
ax, ay = ms.agent_x, ms.agent_y
gw, gh = ms.grid_w, ms.grid_h
W = 15
x0 = max(0, ax - W // 2)
y0 = max(0, ay - W // 2)
x1 = min(gw, x0 + W)
y1 = min(gh, y0 + W)
lines = [f" Map window [{x0}:{x1}, {y0}:{y1}] agent=({ax},{ay})"]
for y in range(y0, y1):
row = []
for x in range(x0, x1):
idx = y * gw + x
if x == ax and y == ay:
row.append(BOLD(CYAN("@")))
continue
ct = int(ms.cell_grid[idx])
fire = float(ms.fire_grid[idx])
smoke = float(ms.smoke_grid[idx])
ch = _CELL_CHAR.get(ct, "?")
if (x, y) not in visible_set:
row.append(DIM("·"))
elif ct == 4:
row.append(GREEN(ch))
elif fire >= 0.5:
row.append(RED("F"))
elif fire > 0.1:
row.append(YELLOW("f"))
elif smoke >= 0.5:
row.append(YELLOW("s"))
else:
row.append(ch)
lines.append(" " + " ".join(row))
return "\n".join(lines)
# ── Per-step diagnostics ──────────────────────────────────────────────────────
def _nearest_exit(ms) -> Tuple[int, Optional[List]]:
ax, ay = ms.agent_x, ms.agent_y
gw = ms.grid_w
exits = ms.exit_positions
if not exits:
return 9999, None
dists = [(abs(ax - e[0]) + abs(ay - e[1]), e) for e in exits]
dists.sort()
return dists[0]
def _fire_summary(ms) -> str:
"""Fraction of grid cells that are actively burning."""
total = ms.grid_w * ms.grid_h
burning = sum(1 for v in ms.fire_grid if float(v) >= 0.3)
blocked = [e for e in ms.exit_positions if float(ms.fire_grid[e[1] * ms.grid_w + e[0]]) >= 0.5]
return (f"burning={burning}/{total} ({100*burning/total:.1f}%) "
f"exits_blocked={len(blocked)}/{len(ms.exit_positions)}")
def _health_bar(hp: float) -> str:
filled = int(hp / 5)
bar = "█" * filled + "░" * (20 - filled)
colour = GREEN if hp > 60 else (YELLOW if hp > 30 else RED)
return colour(f"[{bar}] {hp:.1f} HP")
# ── Main debug runner ─────────────────────────────────────────────────────────
def debug_episode(
checkpoint_path: str,
difficulty: str,
server: str,
deterministic: bool,
history_length: int,
hidden_sizes: Tuple[int, ...],
seed: Optional[int],
show_map: bool,
) -> None:
device = torch.device("cpu")
encoder = ObservationEncoder(mode="visible")
input_dim = encoder.base_dim * history_length
network = ActorCritic(input_dim, ACTION_DIM, hidden_sizes).to(device)
ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
network.load_state_dict(ckpt.get("network_state", ckpt))
network.eval()
trained_eps = ckpt.get("episode", "?")
print(BOLD(f"\n{'='*70}"))
print(BOLD(f" Pyre Episode Debugger"))
print(f" checkpoint : {checkpoint_path} (trained {trained_eps} eps)")
print(f" difficulty : {difficulty} | deterministic={deterministic}")
print(f" seed : {seed if seed is not None else 'random'}")
print(BOLD(f"{'='*70}\n"))
env = HttpPyreEnv(base_url=server)
if not env.health_check():
print(RED("[error] Server not reachable. Start it with: uv run server"))
sys.exit(1)
observation = env.reset(difficulty=difficulty, seed=seed)
ms0 = observation.map_state
print(BOLD(f"── Episode Start ──────────────────────────────────────────────────────"))
print(f" Map size : {ms0.grid_w}×{ms0.grid_h} template={ms0.template_name}")
print(f" Max steps : {ms0.max_steps}")
print(f" Fire spread: {ms0.fire_spread_rate:.3f} humidity={ms0.humidity:.3f} wind={ms0.wind_dir}")
d, nearest = _nearest_exit(ms0)
print(f" Agent spawn: ({ms0.agent_x}, {ms0.agent_y}) nearest_exit={nearest} dist={d}")
print(f" Exits : {ms0.exit_positions}")
print()
zero_frame = np.zeros(encoder.base_dim, dtype=np.float32)
frames: deque = deque([zero_frame.copy() for _ in range(history_length)], maxlen=history_length)
frames.append(encoder.encode(observation))
# ── Tracking ──
step_log = [] # list of dicts, one per step
loop_positions: deque = deque(maxlen=12)
visited: set = set()
total_reward = 0.0
prev_exit_dist = d
# ── Episode loop ──
with torch.no_grad():
while True:
state_vec = np.concatenate(list(frames), dtype=np.float32)
action_mask = build_action_mask(observation, exclude_look=True)
obs_t = torch.tensor(state_vec, dtype=torch.float32, device=device).unsqueeze(0)
mask_t = torch.tensor(action_mask, dtype=torch.float32, device=device).unsqueeze(0)
action_t, log_prob_t, value_t = network.act(obs_t, mask_t, deterministic=deterministic)
action_idx = int(action_t.item())
log_prob = float(log_prob_t.item())
value_est = float(value_t.item())
env_action = action_index_to_env_action(action_idx)
next_obs = env.step(env_action)
ms = next_obs.map_state
reward = float(next_obs.reward or 0.0)
done = bool(next_obs.done)
# Compute diagnostics
exit_dist, nearest_exit = _nearest_exit(ms) if ms else (9999, None)
exit_delta = prev_exit_dist - exit_dist # positive = moved closer
prev_exit_dist = exit_dist
cur_pos = (ms.agent_x, ms.agent_y) if ms else None
is_loop = cur_pos in loop_positions if cur_pos else False
is_new = cur_pos not in visited if cur_pos else False
if cur_pos:
loop_positions.append(cur_pos)
visited.add(cur_pos)
visible_set = {tuple(v) for v in (ms.visible_cells if ms else [])}
fire_nearby = False
if ms and cur_pos:
gw = ms.grid_w
for dx, dy in ((0,1),(0,-1),(1,0),(-1,0)):
nx, ny = cur_pos[0]+dx, cur_pos[1]+dy
if 0 <= nx < ms.grid_w and 0 <= ny < ms.grid_h:
if float(ms.fire_grid[ny*gw+nx]) > 0.15:
fire_nearby = True
break
total_reward += reward
step = len(step_log) + 1
# Classify this step
if next_obs.agent_evacuated:
flag = GREEN("✓ EVACUATED")
elif not next_obs.agent_health > 0 or (done and not next_obs.agent_evacuated):
death_cause = "DIED" if not (ms and ms.agent_alive) else "TIMEOUT"
flag = RED(f"✗ {death_cause}")
elif is_loop:
flag = YELLOW("⚠ LOOP")
elif fire_nearby:
flag = YELLOW("⚠ fire-adj")
else:
flag = ""
step_log.append({
"step": step,
"action": ACTION_KEYS[action_idx],
"hp": next_obs.agent_health,
"reward": reward,
"exit_dist": exit_dist,
"exit_delta": exit_delta,
"value_est": value_est,
"log_prob": log_prob,
"is_loop": is_loop,
"is_new": is_new,
"fire_nearby": fire_nearby,
"smoke": next_obs.smoke_level,
"flag": flag,
"pos": cur_pos,
"fire_summary": _fire_summary(ms) if ms else "",
})
# Print step line
delta_str = f"{exit_delta:+.0f}" if exit_delta != 0 else " 0"
loop_warn = YELLOW(" [LOOP]") if is_loop else ""
fire_warn = YELLOW(" [fire-adj]") if fire_nearby else ""
print(
f" step {step:03d} | {ACTION_KEYS[action_idx]:<38} | "
f"hp={next_obs.agent_health:5.1f} | "
f"reward={reward:+6.3f} | "
f"dist={exit_dist:3d}({delta_str}) | "
f"val={value_est:+5.2f} | "
f"smoke={next_obs.smoke_level:<8} "
f"{flag}{loop_warn}{fire_warn}",
flush=True,
)
# Every 10 steps print fire state + optional map
if step % 10 == 0:
print(DIM(f" fire: {step_log[-1]['fire_summary']}"))
if show_map and ms:
print(render_ascii_map(ms, visible_set, step))
print()
frames.append(encoder.encode(next_obs))
observation = next_obs
if done:
break
# ── Post-episode analysis ──────────────────────────────────────────────────
print(BOLD(f"\n{'='*70}"))
print(BOLD(" Post-Episode Analysis"))
print(BOLD(f"{'='*70}"))
outcome = GREEN("EVACUATED ✓") if step_log[-1]["flag"].strip().endswith("EVACUATED ✓") or observation.agent_evacuated \
else RED("FAILED ✗")
print(f"\n Outcome : {outcome}")
print(f" Total steps : {len(step_log)}")
print(f" Total reward : {total_reward:+.3f}")
print(f" Final HP : {_health_bar(observation.agent_health)}")
# Cause of failure
if not observation.agent_evacuated:
ms_final = observation.map_state
if ms_final and not ms_final.agent_alive:
print(f"\n {RED('Cause of failure:')} Agent DIED (health depleted by smoke/fire)")
elif len(step_log) >= (ms_final.max_steps if ms_final else 100):
print(f"\n {YELLOW('Cause of failure:')} TIMEOUT — ran out of steps")
else:
print(f"\n {YELLOW('Cause of failure:')} Unknown / server reported done")
# Loop analysis
loop_steps = [s for s in step_log if s["is_loop"]]
print(f"\n Loop steps : {len(loop_steps)} / {len(step_log)} "
f"({100*len(loop_steps)/max(1,len(step_log)):.1f}%)")
if loop_steps:
positions_looped = [s["pos"] for s in loop_steps]
from collections import Counter
top = Counter(positions_looped).most_common(3)
print(f" Most-looped positions: {top}")
# Health drain
hp_start = 100.0
hp_end = observation.agent_health
drain_per_step = (hp_start - hp_end) / max(1, len(step_log))
print(f"\n HP drain : {hp_start:.0f}{hp_end:.1f} ({drain_per_step:.2f} HP/step avg)")
# Reward breakdown
positive_steps = [s for s in step_log if s["reward"] > 0]
negative_steps = [s for s in step_log if s["reward"] < 0]
worst_5 = sorted(step_log, key=lambda s: s["reward"])[:5]
best_5 = sorted(step_log, key=lambda s: s["reward"], reverse=True)[:5]
print(f"\n Positive reward steps : {len(positive_steps)}")
print(f" Negative reward steps : {len(negative_steps)}")
print(f"\n {RED('5 worst steps')}:")
for s in worst_5:
print(f" step {s['step']:03d} reward={s['reward']:+.3f} "
f"action={s['action']:<35} dist={s['exit_dist']} hp={s['hp']:.1f}")
print(f"\n {GREEN('5 best steps')}:")
for s in best_5:
print(f" step {s['step']:03d} reward={s['reward']:+.3f} "
f"action={s['action']:<35} dist={s['exit_dist']} hp={s['hp']:.1f}")
# Exit distance trend
dists = [s["exit_dist"] for s in step_log]
print(f"\n Exit distance: start={dists[0]} min={min(dists)} "
f"final={dists[-1]} never_below_5={'YES' if min(dists) < 5 else 'NO'}")
# Fire at end
ms_final = observation.map_state
if ms_final:
print(f"\n Final fire state: {_fire_summary(ms_final)}")
blocked_at_end = [e for e in ms_final.exit_positions
if float(ms_final.fire_grid[e[1]*ms_final.grid_w+e[0]]) >= 0.5]
if blocked_at_end:
print(f" {RED('All exits were fire-blocked at episode end!')} {blocked_at_end}")
else:
open_exits = [e for e in ms_final.exit_positions
if float(ms_final.fire_grid[e[1]*ms_final.grid_w+e[0]]) < 0.5]
print(f" Open exits at end: {open_exits}")
# Diagnosis
print(BOLD(f"\n{'─'*70}"))
print(BOLD(" Diagnosis"))
print(BOLD(f"{'─'*70}"))
diags = []
if len(loop_steps) / max(1, len(step_log)) > 0.15:
diags.append(RED("⚠ Heavy looping") +
f" — agent revisited positions on {len(loop_steps)} steps. "
"Policy is stuck in a local movement pattern.")
if drain_per_step > 2.5:
diags.append(RED("⚠ Rapid health drain") +
f" — {drain_per_step:.1f} HP/step average. Agent is spending "
"too many steps inside smoke/fire zones.")
if min(dists) >= 10:
diags.append(RED("⚠ Never got close to exit") +
f" — minimum exit distance was {min(dists)}. "
"Agent may be exploring in the wrong direction.")
if ms_final:
blocked_exits = [e for e in ms_final.exit_positions
if float(ms_final.fire_grid[e[1]*ms_final.grid_w+e[0]]) >= 0.5]
if len(blocked_exits) == len(ms_final.exit_positions):
diags.append(RED("⚠ All exits blocked by fire at end") +
" — either exits ignited early or agent took too long. "
"Check if fire spread too fast (p_spread > 0.4).")
if len(step_log) >= (ms_final.max_steps if ms_final else 100) - 2:
diags.append(YELLOW("⚠ Timeout") +
" — agent ran out of steps. "
"Either map is too large to navigate in 100 steps, "
"or looping wasted too many steps.")
progress_steps = sum(1 for s in step_log if s["exit_delta"] > 0)
regress_steps = sum(1 for s in step_log if s["exit_delta"] < 0)
if regress_steps > progress_steps:
diags.append(YELLOW("⚠ More regression than progress") +
f" — agent moved away from exit more often ({regress_steps}x) "
f"than toward it ({progress_steps}x).")
if not diags:
diags.append(GREEN("No obvious single failure mode. "
"Run multiple seeds to find patterns."))
for d in diags:
print(f" • {d}\n")
print(BOLD(f"{'='*70}\n"))
# ── CLI ───────────────────────────────────────────────────────────────────────
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(
description="Debug a single Pyre PPO episode with rich diagnostics",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
p.add_argument("--checkpoint", type=str, default="artifacts/pyre_ppo_hard_ckpt.pt",
help="Path to the .pt checkpoint to load")
p.add_argument("--difficulty", type=str, default="hard",
choices=("easy", "medium", "hard"))
p.add_argument("--server", type=str, default="http://localhost:8000")
p.add_argument("--stochastic", action="store_true", default=False,
help="Use stochastic policy (default: deterministic greedy)")
p.add_argument("--history-length", type=int, default=4)
p.add_argument("--hidden-sizes", type=str, default="512,256,128")
p.add_argument("--seed", type=int, default=None,
help="Optional fixed seed for reproducibility")
p.add_argument("--map", action="store_true", default=False,
help="Print ASCII map every 10 steps")
return p.parse_args()
def main() -> None:
args = parse_args()
hidden_sizes = tuple(int(x) for x in args.hidden_sizes.split(","))
debug_episode(
checkpoint_path=args.checkpoint,
difficulty=args.difficulty,
server=args.server,
deterministic=not args.stochastic,
history_length=args.history_length,
hidden_sizes=hidden_sizes,
seed=args.seed,
show_map=args.map,
)
if __name__ == "__main__":
main()