pyre_env / examples /train_rl_agent.py
Akshaykumarbm's picture
Upload folder using huggingface_hub
443c22e verified
"""Train a deep RL baseline directly against the local Pyre environment.
This script makes the environment contract explicit:
- Observation: encoded from `PyreObservation.map_state` into a fixed-length vector
- Action: fixed discrete action table with a runtime validity mask from `available_actions_hint`
- Reward: the environment's composite reward returned by `PyreEnvironment.step()`
It uses a self-contained NumPy actor-critic implementation so it can run in
this repository without external ML dependencies.
Examples:
python examples/train_rl_agent.py --episodes 150 --difficulty easy
python examples/train_rl_agent.py --episodes 300 --difficulty-schedule easy,medium
python examples/train_rl_agent.py --episodes 200 --difficulty medium --observation-mode full
python examples/train_rl_agent.py --describe-only
"""
from __future__ import annotations
import argparse
import csv
import json
import math
import re
from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Sequence
import numpy as np
from pyre_env.models import PyreAction, PyreObservation
from pyre_env.server.pyre_env_environment import PyreEnvironment
# Padded map window supported by the observation encoder, and the number of door
# slots reserved in the fixed action table.
MAX_GRID_W = 24
MAX_GRID_H = 24
MAX_DOORS = 16
# Ordered vocabularies: positions in these tuples are used as one-hot / action indices.
DIRECTIONS = ("north", "south", "west", "east")
WINDS = ("CALM", "NORTH", "SOUTH", "WEST", "EAST")
DIFFICULTIES = ("easy", "medium", "hard")
# Fixed discrete action table. Index layout:
#   0-3 move, 4-7 look, 8 wait,
#   9 .. 8+MAX_DOORS open door_1..door_MAX_DOORS,
#   9+MAX_DOORS .. 8+2*MAX_DOORS close door_1..door_MAX_DOORS.
MOVE_KEYS = [f"move(direction='{d}')" for d in DIRECTIONS]
LOOK_KEYS = [f"look(direction='{d}')" for d in DIRECTIONS]
WAIT_KEY = "wait()"
OPEN_KEYS = [f"door(target_id='door_{i}', door_state='open')" for i in range(1, MAX_DOORS + 1)]
CLOSE_KEYS = [f"door(target_id='door_{i}', door_state='close')" for i in range(1, MAX_DOORS + 1)]
ACTION_KEYS = [*MOVE_KEYS, *LOOK_KEYS, WAIT_KEY, *OPEN_KEYS, *CLOSE_KEYS]
ACTION_DIM = len(ACTION_KEYS)
# Reverse lookup from action string to its table index.
ACTION_TO_INDEX = dict(zip(ACTION_KEYS, range(ACTION_DIM)))
# Patterns for textual hints in `available_actions_hint`.
_MOVE_RE = re.compile(r"move\(direction='(north|south|west|east)'\)")
_LOOK_RE = re.compile(r"look\(direction='(north|south|west|east)'\)")
_DOOR_RE = re.compile(r"door\(target_id='(door_(\d+))', door_state='(open|close)'\)")
def _one_hot(index: int, size: int) -> np.ndarray:
    """Return a float32 vector of length ``size`` with 1.0 at position ``index``.

    An index outside ``[0, size)`` (including negative values) yields all zeros
    rather than wrapping around.
    """
    encoded = np.zeros(size, dtype=np.float32)
    if not 0 <= index < size:
        return encoded
    encoded[index] = 1.0
    return encoded
def action_index_to_env_action(index: int) -> PyreAction:
    """Convert an index into the fixed action table into a ``PyreAction``.

    The layout mirrors ``ACTION_KEYS``:

    * ``[0, D)`` move in ``DIRECTIONS[index]`` (``D = len(DIRECTIONS)``);
    * ``[D, 2D)`` look in ``DIRECTIONS[index - D]``;
    * ``2D`` wait;
    * next ``MAX_DOORS`` slots open ``door_1`` .. ``door_MAX_DOORS``;
    * last ``MAX_DOORS`` slots close ``door_1`` .. ``door_MAX_DOORS``.

    Raises:
        IndexError: if ``index`` is outside ``[0, ACTION_DIM)``. Previously such
            indices silently produced nonexistent door ids (e.g. ``door_-25``).
    """
    if not 0 <= index < ACTION_DIM:
        raise IndexError(f"action index {index} is outside [0, {ACTION_DIM})")
    num_dirs = len(DIRECTIONS)
    wait_index = 2 * num_dirs
    open_start = wait_index + 1
    close_start = open_start + MAX_DOORS
    if index < num_dirs:
        return PyreAction(action="move", direction=DIRECTIONS[index])
    if index < wait_index:
        return PyreAction(action="look", direction=DIRECTIONS[index - num_dirs])
    if index == wait_index:
        return PyreAction(action="wait")
    if index < close_start:
        # Door ids are 1-based: the first open slot targets door_1.
        return PyreAction(action="door", target_id=f"door_{index - open_start + 1}", door_state="open")
    return PyreAction(action="door", target_id=f"door_{index - close_start + 1}", door_state="close")
def _canonical_action_index(hint: str) -> int | None:
    """Map a move/look/door hint onto its fixed-table index, or None if it has no slot.

    Door hints whose number is outside ``[1, MAX_DOORS]``, or whose id is not in the
    canonical ``door_<n>`` form used by the table (e.g. ``door_01``), map to None
    instead of raising ``KeyError``.
    """
    match = _MOVE_RE.fullmatch(hint)
    if match:
        return ACTION_TO_INDEX.get(f"move(direction='{match.group(1)}')")
    match = _LOOK_RE.fullmatch(hint)
    if match:
        return ACTION_TO_INDEX.get(f"look(direction='{match.group(1)}')")
    match = _DOOR_RE.fullmatch(hint)
    if match and 1 <= int(match.group(2)) <= MAX_DOORS:
        return ACTION_TO_INDEX.get(f"door(target_id='{match.group(1)}', door_state='{match.group(3)}')")
    return None


def build_action_mask(observation: PyreObservation) -> np.ndarray:
    """Build a 0/1 float32 mask over ``ACTION_KEYS`` from ``available_actions_hint``.

    Hints that exactly match a table key are enabled directly. Otherwise the
    move/look/door patterns are tried; hints that cannot be mapped onto the fixed
    table are ignored. If nothing ends up enabled, ``wait()`` is enabled so the
    policy always has at least one legal action.
    """
    mask = np.zeros(ACTION_DIM, dtype=np.float32)
    for hint in observation.available_actions_hint:
        idx = ACTION_TO_INDEX.get(hint)
        if idx is None:
            idx = _canonical_action_index(hint)
        if idx is not None:
            mask[idx] = 1.0
    if not mask.any():
        mask[ACTION_TO_INDEX[WAIT_KEY]] = 1.0
    return mask
class ObservationEncoder:
    """Encode Pyre observations into a fixed-size float vector.

    The vector has ``base_dim`` entries laid out as:

    * cell-type one-hot of shape ``(MAX_GRID_H, MAX_GRID_W, 6)``, flattened row-major;
    * fire, smoke, visible-mask and agent-mask planes, each ``MAX_GRID_H * MAX_GRID_W``;
    * 22 global features: 14 scalars, a 5-way wind one-hot and a 3-way difficulty one-hot.

    Maps smaller than the padded window leave the remaining cells zero.
    """

    def __init__(self, mode: str = "visible"):
        """Create an encoder.

        Args:
            mode: ``"visible"`` populates only currently visible cells (plus the
                agent's own cell); ``"full"`` populates every map cell.

        Raises:
            ValueError: for any other ``mode``.
        """
        if mode not in {"visible", "full"}:
            raise ValueError(f"Unsupported observation mode: {mode}")
        self.mode = mode
        # 10 grid channels per cell plus 22 global features.
        self.base_dim = MAX_GRID_W * MAX_GRID_H * 10 + 22

    def encode(self, observation: PyreObservation) -> np.ndarray:
        """Return the float32 feature vector (length ``base_dim``) for ``observation``.

        Raises:
            ValueError: if ``observation.map_state`` is missing, or if the map is
                larger than the ``MAX_GRID_W x MAX_GRID_H`` padded window.
        """
        map_state = observation.map_state
        if map_state is None:
            raise ValueError("PyreObservation.map_state is required for RL training.")
        if map_state.grid_w > MAX_GRID_W or map_state.grid_h > MAX_GRID_H:
            raise ValueError(
                f"Map size {map_state.grid_w}x{map_state.grid_h} exceeds the encoder's padded "
                f"{MAX_GRID_W}x{MAX_GRID_H} window; increase MAX_GRID_W/MAX_GRID_H."
            )
        grid_features = self._encode_grid(map_state)
        global_features = self._encode_globals(observation, map_state)
        return np.concatenate([grid_features, global_features]).astype(np.float32)

    def _encode_grid(self, map_state) -> np.ndarray:
        """Fill the padded per-cell planes from ``map_state`` and flatten them."""
        cell_one_hot = np.zeros((MAX_GRID_H, MAX_GRID_W, 6), dtype=np.float32)
        fire_channel = np.zeros((MAX_GRID_H, MAX_GRID_W), dtype=np.float32)
        smoke_channel = np.zeros((MAX_GRID_H, MAX_GRID_W), dtype=np.float32)
        visible_channel = np.zeros((MAX_GRID_H, MAX_GRID_W), dtype=np.float32)
        agent_channel = np.zeros((MAX_GRID_H, MAX_GRID_W), dtype=np.float32)
        visible = {(x, y) for x, y in map_state.visible_cells}
        agent_pos = (map_state.agent_x, map_state.agent_y)
        for y in range(map_state.grid_h):
            for x in range(map_state.grid_w):
                # In partial-observability mode, only visible cells and the agent's
                # own cell are revealed; everything else stays zero.
                if self.mode == "visible" and (x, y) not in visible and (x, y) != agent_pos:
                    continue
                # The map's flat grids are row-major with stride grid_w.
                i = y * map_state.grid_w + x
                cell_type = int(map_state.cell_grid[i])
                if 0 <= cell_type <= 5:
                    cell_one_hot[y, x, cell_type] = 1.0
                fire_channel[y, x] = float(map_state.fire_grid[i])
                smoke_channel[y, x] = float(map_state.smoke_grid[i])
                visible_channel[y, x] = 1.0 if (x, y) in visible else 0.0
        if 0 <= map_state.agent_x < MAX_GRID_W and 0 <= map_state.agent_y < MAX_GRID_H:
            agent_channel[map_state.agent_y, map_state.agent_x] = 1.0
        return np.concatenate(
            [
                cell_one_hot.reshape(-1),
                fire_channel.reshape(-1),
                smoke_channel.reshape(-1),
                visible_channel.reshape(-1),
                agent_channel.reshape(-1),
            ]
        )

    @staticmethod
    def _metadata_float(metadata, key: str, default: float) -> float:
        """Read ``metadata[key]`` as float, using ``default`` when missing or None."""
        value = metadata.get(key)
        return default if value is None else float(value)

    @staticmethod
    def _encode_globals(observation: PyreObservation, map_state) -> np.ndarray:
        """Build the 22 non-spatial features (scalars + wind/difficulty one-hots)."""
        metadata = observation.metadata or {}
        wind_dir = str(metadata.get("wind_dir", map_state.wind_dir or "CALM")).upper()
        difficulty = str(metadata.get("difficulty", "medium")).lower()
        wind_index = WINDS.index(wind_dir) if wind_dir in WINDS else 0
        difficulty_index = DIFFICULTIES.index(difficulty) if difficulty in DIFFICULTIES else 1
        max_distance = float(MAX_GRID_W + MAX_GRID_H)
        # A None exit distance means no known exit; encode it as the farthest
        # distance rather than 0.0, which would claim the agent is on an exit.
        exit_distance = ObservationEncoder._metadata_float(metadata, "nearest_exit_distance", max_distance)
        scalars = np.array(
            [
                float(observation.agent_health) / 100.0,
                float(map_state.agent_health) / 100.0,
                float(map_state.step_count) / max(1, map_state.max_steps),
                float(map_state.fire_spread_rate),
                float(map_state.humidity),
                float(map_state.agent_x) / max(1, map_state.grid_w - 1),
                float(map_state.agent_y) / max(1, map_state.grid_h - 1),
                exit_distance / max_distance,
                ObservationEncoder._metadata_float(metadata, "reachable_exit_count", 0.0) / 4.0,
                ObservationEncoder._metadata_float(metadata, "visible_cell_count", 0.0) / float(MAX_GRID_W * MAX_GRID_H),
                ObservationEncoder._metadata_float(metadata, "fire_sources", 0.0) / 5.0,
                {"none": 0.0, "light": 0.33, "moderate": 0.66, "heavy": 1.0}.get(observation.smoke_level, 0.0),
                1.0 if map_state.agent_alive else 0.0,
                1.0 if map_state.agent_evacuated else 0.0,
            ],
            dtype=np.float32,
        )
        return np.concatenate(
            [
                scalars,
                _one_hot(wind_index, len(WINDS)),
                _one_hot(difficulty_index, len(DIFFICULTIES)),
            ]
        )

    def describe(self, history_length: int) -> str:
        """Return a human-readable summary of the encoding for ``history_length`` stacked frames."""
        grid_text = (
            f"Observation mode `{self.mode}` encodes a {MAX_GRID_W}x{MAX_GRID_H} padded map with "
            "10 channels per cell: 6-way cell type one-hot, fire intensity, smoke intensity, visible mask, and agent mask."
        )
        if self.mode == "visible":
            visibility_text = "Only currently visible cells are populated; unseen cells stay zeroed."
        else:
            visibility_text = "The full ground-truth map is exposed for curriculum/debug use."
        return (
            f"{grid_text} {visibility_text} "
            f"Global features add health, step progress, fire parameters, position, exit-distance metadata, smoke severity, wind, and difficulty. "
            f"{history_length} encoded frames are stacked, so the network input dimension is {self.base_dim * history_length}."
        )
def softmax_with_mask(logits: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Row-wise softmax over ``logits`` restricted to entries where ``mask > 0``.

    Both arguments are 2-D ``(batch, actions)``. Masked-out entries get probability
    0; a row with no allowed entries comes back as all zeros instead of NaN.
    """
    allowed = mask > 0.0
    shifted = np.where(allowed, logits, -1e9)
    # Subtract the row max for numerical stability before exponentiating.
    shifted = shifted - shifted.max(axis=1, keepdims=True)
    weights = np.exp(shifted) * mask
    totals = weights.sum(axis=1, keepdims=True)
    safe_totals = np.where(totals <= 0.0, 1.0, totals)
    return weights / safe_totals
class AdamOptimizer:
    """Adam optimizer with global-norm gradient clipping for dicts of NumPy arrays.

    Parameters are updated in place, so arrays held elsewhere (e.g. a network's
    ``params`` dict) see the new values without reassignment.
    """

    def __init__(
        self,
        params: Dict[str, np.ndarray],
        lr: float = 3e-4,
        beta1: float = 0.9,
        beta2: float = 0.999,
        eps: float = 1e-8,
    ):
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        # Number of applied (non-skipped) steps; drives bias correction.
        self.t = 0
        # First- and second-moment estimates, one buffer per parameter name.
        self.m = {k: np.zeros_like(v) for k, v in params.items()}
        self.v = {k: np.zeros_like(v) for k, v in params.items()}

    def step(self, params: Dict[str, np.ndarray], grads: Dict[str, np.ndarray], clip_norm: float = 1.0) -> None:
        """Apply one Adam update to ``params`` in place.

        Args:
            params: parameter arrays, keyed like the dict given to ``__init__``.
            grads: gradients with the same keys and shapes as ``params``.
            clip_norm: if the global L2 norm of all gradients exceeds this, every
                gradient is scaled so the global norm equals ``clip_norm``.

        If the global gradient norm is NaN or infinite, the step is skipped:
        applying it would write non-finite values into the moment buffers and
        parameters, which no later step could recover from.
        """
        # Accumulate in float64 so large float32 gradients cannot overflow the norm.
        total_norm = math.sqrt(
            sum(float(np.sum(np.square(grad, dtype=np.float64))) for grad in grads.values())
        )
        if not math.isfinite(total_norm):
            return
        scale = clip_norm / (total_norm + 1e-8) if total_norm > clip_norm else 1.0
        self.t += 1
        bias1 = 1.0 - self.beta1 ** self.t
        bias2 = 1.0 - self.beta2 ** self.t
        for name, param in params.items():
            grad = grads[name] * scale
            self.m[name] = self.beta1 * self.m[name] + (1.0 - self.beta1) * grad
            self.v[name] = self.beta2 * self.v[name] + (1.0 - self.beta2) * (grad * grad)
            m_hat = self.m[name] / bias1
            v_hat = self.v[name] / bias2
            param -= self.lr * m_hat / (np.sqrt(v_hat) + self.eps)
class PolicyValueNetwork:
    """Shared-trunk actor-critic MLP implemented with NumPy.

    Architecture: ``x -> tanh(x W1 + b1) -> tanh(. W2 + b2)``, followed by a linear
    policy head producing ``action_dim`` logits (softmaxed under the action mask)
    and a linear scalar value head. Gradients are derived by hand in
    :meth:`update` and applied with :class:`AdamOptimizer`.
    """

    def __init__(self, input_dim: int, action_dim: int, rng: np.random.Generator, hidden_sizes: Sequence[int] = (256, 128)):
        # Exactly two hidden layers are supported.
        h1, h2 = hidden_sizes
        self.params: Dict[str, np.ndarray] = {
            "w1": self._init_weight(rng, input_dim, h1),
            "b1": np.zeros(h1, dtype=np.float32),
            "w2": self._init_weight(rng, h1, h2),
            "b2": np.zeros(h2, dtype=np.float32),
            "wp": self._init_weight(rng, h2, action_dim),
            "bp": np.zeros(action_dim, dtype=np.float32),
            "wv": self._init_weight(rng, h2, 1),
            "bv": np.zeros(1, dtype=np.float32),
        }
        self.optimizer = AdamOptimizer(self.params)

    @staticmethod
    def _init_weight(rng: np.random.Generator, in_dim: int, out_dim: int) -> np.ndarray:
        """Glorot-normal init: N(0, 2 / (in_dim + out_dim)), cast to float32."""
        scale = math.sqrt(2.0 / max(1, in_dim + out_dim))
        return (rng.standard_normal((in_dim, out_dim)) * scale).astype(np.float32)

    def forward(self, x: np.ndarray) -> tuple[np.ndarray, np.ndarray, Dict[str, np.ndarray]]:
        """Run a batch ``x`` of shape (batch, input_dim).

        Returns:
            ``(logits, values, cache)``: logits of shape (batch, action_dim), values
            of shape (batch,), and the activations needed for backprop.
        """
        z1 = x @ self.params["w1"] + self.params["b1"]
        h1 = np.tanh(z1)
        z2 = h1 @ self.params["w2"] + self.params["b2"]
        h2 = np.tanh(z2)
        logits = h2 @ self.params["wp"] + self.params["bp"]
        values = (h2 @ self.params["wv"] + self.params["bv"]).reshape(-1)
        cache = {"x": x, "h1": h1, "h2": h2}
        return logits, values, cache

    def predict(self, x: np.ndarray, mask: np.ndarray) -> tuple[np.ndarray, float]:
        """Return masked action probabilities and the value estimate for one state."""
        logits, values, _ = self.forward(x[None, :])
        probs = softmax_with_mask(logits, mask[None, :])[0]
        return probs, float(values[0])

    def update(
        self,
        states: np.ndarray,
        masks: np.ndarray,
        actions: np.ndarray,
        returns: np.ndarray,
        advantages: np.ndarray,
        value_coef: float = 0.5,
        entropy_coef: float = 0.0,
    ) -> Dict[str, float]:
        """Take one Adam step on a minibatch of on-policy samples.

        Minimises ``-mean(A * log pi(a|s)) + value_coef * 0.5 * mean((V - R)^2)
        - entropy_coef * mean(H(pi(.|s)))``. With the default ``entropy_coef=0.0``
        there is no entropy regularisation (the original behaviour); a small positive
        value discourages premature policy collapse.

        Returns:
            Diagnostics computed from the pre-update forward pass: ``policy_loss``,
            ``value_loss``, ``entropy`` and ``mean_value``.
        """
        logits, values, cache = self.forward(states)
        probs = softmax_with_mask(logits, masks)
        batch_size = max(1, states.shape[0])
        rows = np.arange(batch_size)

        log_probs = np.log(np.clip(probs, 1e-8, 1.0))
        # Per-row policy entropy; masked actions have probability 0 and add nothing.
        row_entropy = -np.sum(np.where(probs > 0.0, probs * log_probs, 0.0), axis=1)

        # Policy-gradient term: d/dz [-A * log softmax(z)_a] = A * (p - onehot(a)).
        grad_logits = probs.copy()
        grad_logits[rows, actions] -= 1.0
        grad_logits *= advantages[:, None] / batch_size
        if entropy_coef:
            # d/dz_j [-c * H] = c * p_j * (log p_j + H)
            grad_logits += (entropy_coef / batch_size) * probs * (log_probs + row_entropy[:, None])
        # Masked logits were replaced by -1e9 inside the softmax, so they get no gradient.
        grad_logits *= masks
        grad_values = ((values - returns)[:, None] * value_coef) / batch_size

        grads: Dict[str, np.ndarray] = {}
        grads["wp"] = cache["h2"].T @ grad_logits
        grads["bp"] = np.sum(grad_logits, axis=0)
        grads["wv"] = cache["h2"].T @ grad_values
        grads["bv"] = np.sum(grad_values, axis=0)
        dh2 = grad_logits @ self.params["wp"].T + grad_values @ self.params["wv"].T
        dz2 = dh2 * (1.0 - cache["h2"] ** 2)  # tanh'(z) = 1 - tanh(z)^2
        grads["w2"] = cache["h1"].T @ dz2
        grads["b2"] = np.sum(dz2, axis=0)
        dh1 = dz2 @ self.params["w2"].T
        dz1 = dh1 * (1.0 - cache["h1"] ** 2)
        grads["w1"] = cache["x"].T @ dz1
        grads["b1"] = np.sum(dz1, axis=0)
        self.optimizer.step(self.params, grads, clip_norm=1.0)

        chosen_probs = np.clip(probs[rows, actions], 1e-8, 1.0)
        policy_loss = float(-np.mean(advantages * np.log(chosen_probs)))
        value_loss = float(0.5 * np.mean((values - returns) ** 2))
        entropy = float(np.mean(row_entropy))
        return {
            "policy_loss": policy_loss,
            "value_loss": value_loss,
            "entropy": entropy,
            "mean_value": float(np.mean(values)),
        }

    def save(self, path: Path, metadata: Dict[str, object]) -> None:
        """Write all parameters plus JSON ``metadata`` to an ``.npz`` archive.

        The parent directory is created if needed. ``np.savez`` appends ``.npz`` when
        ``path`` lacks that suffix. Metadata is stored as a 0-d string array under
        the ``metadata_json`` key.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        arrays = dict(self.params)
        arrays["metadata_json"] = np.array(json.dumps(metadata))
        np.savez(path, **arrays)
@dataclass
class Trajectory:
    """Everything recorded while rolling out a single episode.

    The five list fields are parallel: entry ``t`` of each belongs to step ``t``.
    """

    # Stacked network input seen before each action.
    states: List[np.ndarray]
    # Action-validity mask used when choosing each action.
    masks: List[np.ndarray]
    # Index into the fixed action table chosen at each step.
    actions: List[int]
    # Environment reward received after each action.
    rewards: List[float]
    # Critic value estimate for each recorded state.
    values: List[float]
    # Episode summary fields.
    evacuated: bool
    final_health: float
    steps: int
    total_reward: float
def compute_gae(
    rewards: Sequence[float],
    values: Sequence[float],
    gamma: float,
    gae_lambda: float,
    last_value: float = 0.0,
) -> tuple[np.ndarray, np.ndarray]:
    """Compute GAE(lambda) advantages and lambda-returns for one trajectory.

    Args:
        rewards: per-step rewards r_0 .. r_{T-1}.
        values: critic estimates V(s_0) .. V(s_{T-1}); same length as ``rewards``.
        gamma: discount factor.
        gae_lambda: GAE bias/variance trade-off parameter.
        last_value: bootstrap value V(s_T) for the state after the last step. The
            default 0.0 treats the end of the trajectory as terminal. Pass the critic's
            estimate instead when the episode was truncated.

    Returns:
        ``(returns, advantages)`` as float32 arrays of length ``len(rewards)``, where
        ``returns = advantages + values``.

    Raises:
        ValueError: if ``rewards`` and ``values`` differ in shape.
    """
    rewards_arr = np.asarray(rewards, dtype=np.float32)
    values_arr = np.asarray(values, dtype=np.float32)
    if rewards_arr.shape != values_arr.shape:
        raise ValueError(
            f"rewards and values must have the same shape, got {rewards_arr.shape} and {values_arr.shape}"
        )
    advantages = np.zeros(len(rewards_arr), dtype=np.float32)
    gae = 0.0
    next_value = float(last_value)
    # Sweep backwards: A_t = delta_t + gamma * lambda * A_{t+1}.
    for i in range(len(rewards_arr) - 1, -1, -1):
        delta = rewards_arr[i] + gamma * next_value - values_arr[i]
        gae = delta + gamma * gae_lambda * gae
        advantages[i] = gae
        next_value = values_arr[i]
    returns = advantages + values_arr
    return returns.astype(np.float32), advantages.astype(np.float32)
def select_action(
    network: PolicyValueNetwork,
    state_vec: np.ndarray,
    mask: np.ndarray,
    rng: np.random.Generator,
    greedy: bool = False,
) -> tuple[int, float]:
    """Choose an action index for ``state_vec`` under ``mask``.

    Probabilities are renormalised over the valid actions (``mask > 0``). If the
    network produced non-finite or all-zero probabilities for them (e.g. after a
    divergent update), a uniform distribution over the valid actions is used
    instead of letting ``rng.choice`` raise.

    Returns:
        ``(action_index, value_estimate)``. With no valid actions at all, the wait
        action is returned.
    """
    probs, value = network.predict(state_vec, mask)
    valid_indices = np.flatnonzero(mask > 0.0)
    if len(valid_indices) == 0:
        return ACTION_TO_INDEX[WAIT_KEY], value
    valid_probs = np.asarray(probs, dtype=np.float64)[valid_indices]
    if np.all(np.isfinite(valid_probs)):
        valid_probs = np.clip(valid_probs, 0.0, None)
        total = float(valid_probs.sum())
    else:
        total = 0.0
    if total <= 0.0:
        valid_probs = np.full(len(valid_indices), 1.0 / len(valid_indices))
    else:
        valid_probs = valid_probs / total
    if greedy:
        return int(valid_indices[int(np.argmax(valid_probs))]), value
    return int(rng.choice(valid_indices, p=valid_probs)), value
def build_stacked_state(frames: deque[np.ndarray]) -> np.ndarray:
    """Join the frame history (oldest first) into one float32 network input vector."""
    ordered_frames = [frame for frame in frames]
    stacked = np.concatenate(ordered_frames, dtype=np.float32)
    return stacked
def run_episode(
    env: PyreEnvironment,
    network: PolicyValueNetwork,
    encoder: ObservationEncoder,
    rng: np.random.Generator,
    difficulty: str,
    history_length: int,
    greedy: bool = False,
) -> Trajectory:
    """Play one episode from ``env.reset(difficulty=...)`` until the env reports done.

    The frame history starts as ``history_length`` zero frames, so the first stacked
    state is the initial observation preceded by zero padding. Actions are sampled
    from the masked policy, or taken greedily when ``greedy`` is set. Each step's
    stacked state, mask, action, reward and value estimate is recorded.
    """
    observation = env.reset(difficulty=difficulty)
    blank_frame = np.zeros(encoder.base_dim, dtype=np.float32)
    frames: deque[np.ndarray] = deque(
        (blank_frame.copy() for _ in range(history_length)),
        maxlen=history_length,
    )
    frames.append(encoder.encode(observation))

    record = Trajectory(
        states=[],
        masks=[],
        actions=[],
        rewards=[],
        values=[],
        evacuated=False,
        final_health=observation.agent_health,
        steps=0,
        total_reward=0.0,
    )
    finished = False
    while not finished:
        stacked = build_stacked_state(frames)
        valid = build_action_mask(observation)
        chosen, estimate = select_action(network, stacked, valid, rng, greedy=greedy)
        result = env.step(action_index_to_env_action(chosen))
        step_reward = float(result.reward or 0.0)

        record.states.append(stacked)
        record.masks.append(valid)
        record.actions.append(chosen)
        record.rewards.append(step_reward)
        record.values.append(estimate)
        record.total_reward += step_reward
        record.steps += 1
        record.final_health = result.agent_health
        record.evacuated = result.agent_evacuated

        # Slide the history window forward with the newest observation.
        frames.append(encoder.encode(result))
        observation = result
        finished = result.done
    return record
def evaluate_policy(
    env: PyreEnvironment,
    network: PolicyValueNetwork,
    encoder: ObservationEncoder,
    rng: np.random.Generator,
    difficulty: str,
    history_length: int,
    episodes: int,
) -> Dict[str, float]:
    """Roll out ``episodes`` greedy episodes and summarise reward, success and length.

    Every statistic is 0.0 when no episodes are run.
    """
    episode_rewards: List[float] = []
    episode_lengths: List[int] = []
    successes = 0
    for _ in range(episodes):
        rollout = run_episode(env, network, encoder, rng, difficulty, history_length, greedy=True)
        episode_rewards.append(rollout.total_reward)
        episode_lengths.append(rollout.steps)
        successes += int(rollout.evacuated)
    ran_any = bool(episode_rewards)
    return {
        "eval_reward_mean": float(np.mean(episode_rewards)) if ran_any else 0.0,
        "eval_reward_max": float(np.max(episode_rewards)) if ran_any else 0.0,
        "eval_success_rate": float(successes / max(1, episodes)),
        "eval_steps_mean": float(np.mean(episode_lengths)) if ran_any else 0.0,
    }
def expand_difficulty_schedule(schedule_text: str, episodes: int) -> List[str]:
    """Expand a comma-separated curriculum into one difficulty per episode.

    Each stage gets ``episodes // len(stages)`` consecutive episodes (at least one).
    Any remainder is filled with the last stage, and the result is truncated to
    ``episodes`` entries. An empty schedule means ``"medium"``.

    Raises:
        ValueError: if a stage is not one of ``DIFFICULTIES``.
    """
    normalized = [token.strip().lower() for token in schedule_text.split(",")]
    stages = [stage for stage in normalized if stage] or ["medium"]
    unknown = next((stage for stage in stages if stage not in DIFFICULTIES), None)
    if unknown is not None:
        raise ValueError(f"Invalid difficulty in schedule: {unknown}")
    per_stage = max(1, episodes // len(stages))
    expanded = [stage for stage in stages for _ in range(per_stage)]
    shortfall = episodes - len(expanded)
    if shortfall > 0:
        expanded += [stages[-1]] * shortfall
    return expanded[:episodes]
def describe_environment_contract(encoder: ObservationEncoder, history_length: int) -> str:
    """Return a multi-line description of the observation, action and reward contract."""
    observation_text = encoder.describe(history_length)
    action_text = (
        f"Action space has {ACTION_DIM} fixed discrete actions: 4 moves, 4 looks, wait, "
        f"{MAX_DOORS} door-open slots, and {MAX_DOORS} door-close slots. "
        "A per-step mask from `available_actions_hint` prevents invalid actions."
    )
    reward_text = (
        "Reward comes directly from the environment's composite rubric: time penalty, exit progress, "
        "progress regression penalty, safe-progress bonus, danger penalty, health-drain penalty, "
        "strategic door bonus, exploration bonus, plus terminal evacuation/death/timeout/near-miss/time bonuses."
    )
    sections = ("Pyre RL contract", observation_text, action_text, reward_text)
    return "\n".join(sections)
def _moving_average(values: Sequence[float], window: int) -> List[float]:
    """Trailing moving average of ``values`` over the last ``window`` items.

    Output ``i`` is the mean of ``values[max(0, i - window + 1) : i + 1]``, so the
    first ``window - 1`` outputs average over the shorter prefix seen so far.
    Works for any iterable of numbers, including NumPy arrays (the original
    ``if not values`` check raised for multi-element arrays).

    Raises:
        ValueError: if ``window`` < 1 (previously a ``ZeroDivisionError``).
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    out: List[float] = []
    recent: deque[float] = deque()
    running = 0.0
    for raw in values:
        value = float(raw)
        recent.append(value)
        running += value
        if len(recent) > window:
            running -= recent.popleft()
        out.append(running / len(recent))
    return out
def save_metrics_csv(path: Path, rows: List[Dict[str, float | int | str]]) -> None:
    """Write ``rows`` to ``path`` as CSV, with the header taken from the first row's keys.

    The parent directory is always created. When ``rows`` is empty, no file is written.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    if rows:
        header = [*rows[0]]
        with open(path, "w", newline="", encoding="utf-8") as handle:
            csv_writer = csv.DictWriter(handle, fieldnames=header)
            csv_writer.writeheader()
            for row in rows:
                csv_writer.writerow(row)
def save_training_graph(path: Path, episode_rows: List[Dict[str, float | int | str]], eval_rows: List[Dict[str, float | int | str]]) -> None:
    """Render training curves as a standalone SVG file at ``path``.

    Args:
        path: output file; its parent directory is created if needed.
        episode_rows: per-episode dicts; the ``episode``, ``reward`` and
            ``evacuated`` keys are read.
        eval_rows: evaluation dicts; the ``episode``, ``reward_mean``,
            ``reward_max`` and ``success_rate`` keys are read.

    Draws raw episode reward and its 20-episode moving average against the left
    axis, the 20-episode moving success rate against the right (0-1) axis, and
    evaluation success rates as circles. Nothing is written when ``episode_rows``
    is empty. Note: eval ``reward_mean``/``reward_max`` only widen the reward-axis
    range; they are not drawn.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    if not episode_rows:
        return
    # Canvas size and plot-area margins, in SVG user units.
    width = 1200
    height = 720
    margin_left = 80
    margin_right = 40
    margin_top = 50
    margin_bottom = 60
    plot_w = width - margin_left - margin_right
    plot_h = height - margin_top - margin_bottom
    episodes = [int(r["episode"]) for r in episode_rows]
    rewards = [float(r["reward"]) for r in episode_rows]
    reward_ma = _moving_average(rewards, 20)
    success_ma = _moving_average([float(r["evacuated"]) for r in episode_rows], 20)
    # Reward-axis range covers raw rewards, their moving average and eval reward stats.
    all_reward_values = rewards + reward_ma + [float(r["reward_mean"]) for r in eval_rows] + [float(r["reward_max"]) for r in eval_rows]
    y_min = min(all_reward_values) if all_reward_values else -1.0
    y_max = max(all_reward_values) if all_reward_values else 1.0
    # Avoid a zero-height range when every reward value is identical.
    if abs(y_max - y_min) < 1e-6:
        y_min -= 1.0
        y_max += 1.0
    # 10% headroom above and below the data.
    y_pad = 0.1 * (y_max - y_min)
    y_min -= y_pad
    y_max += y_pad
    max_episode = max(episodes) if episodes else 1

    def x_pos(ep: float) -> float:
        # Episode 1 maps to the left edge of the plot, max_episode to the right edge.
        return margin_left + (float(ep) - 1.0) / max(1.0, max_episode - 1.0) * plot_w

    def y_pos_reward(value: float) -> float:
        # Left axis: larger rewards are drawn higher (smaller SVG y).
        return margin_top + (y_max - float(value)) / max(1e-6, (y_max - y_min)) * plot_h

    def y_pos_success(value: float) -> float:
        # Right axis: a rate in [0, 1] spans the full plot height.
        return margin_top + (1.0 - float(value)) * plot_h

    def polyline(points: List[tuple[float, float]]) -> str:
        # SVG `points` attribute: space-separated "x,y" pairs.
        return " ".join(f"{x:.1f},{y:.1f}" for x, y in points)

    reward_points = [(x_pos(ep), y_pos_reward(val)) for ep, val in zip(episodes, rewards)]
    reward_ma_points = [(x_pos(ep), y_pos_reward(val)) for ep, val in zip(episodes, reward_ma)]
    success_points = [(x_pos(ep), y_pos_success(val)) for ep, val in zip(episodes, success_ma)]
    eval_points = [(x_pos(float(r["episode"])), y_pos_success(float(r["success_rate"]))) for r in eval_rows]
    # x ticks at episode 1 and the quarter points of the run; ticks below 1 and duplicates are dropped.
    episode_ticks = [1, max_episode // 4, max_episode // 2, (3 * max_episode) // 4, max_episode]
    episode_ticks = sorted(set(t for t in episode_ticks if t >= 1))
    reward_ticks = [y_min + (y_max - y_min) * i / 4.0 for i in range(5)]
    success_ticks = [0.0, 0.25, 0.5, 0.75, 1.0]
    # Assemble the SVG document line by line; later elements draw on top of earlier ones.
    svg = []
    svg.append(f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" viewBox="0 0 {width} {height}">')
    svg.append('<rect width="100%" height="100%" fill="#f7f5ef"/>')
    svg.append('<text x="80" y="28" font-family="Georgia, serif" font-size="24" fill="#1d2a38">Pyre RL Training</text>')
    svg.append('<text x="80" y="48" font-family="Georgia, serif" font-size="13" fill="#5b6770">Reward curves on left axis, success rate on right axis</text>')
    svg.append(f'<rect x="{margin_left}" y="{margin_top}" width="{plot_w}" height="{plot_h}" fill="#fffdf8" stroke="#d1c9b8"/>')
    # Vertical grid lines and episode labels.
    for tick in episode_ticks:
        x = x_pos(float(tick))
        svg.append(f'<line x1="{x:.1f}" y1="{margin_top}" x2="{x:.1f}" y2="{margin_top + plot_h}" stroke="#ece7db" />')
        svg.append(f'<text x="{x:.1f}" y="{height - 24}" text-anchor="middle" font-family="Georgia, serif" font-size="12" fill="#5b6770">{tick}</text>')
    # Horizontal grid lines with left-axis reward labels.
    for tick in reward_ticks:
        y = y_pos_reward(tick)
        svg.append(f'<line x1="{margin_left}" y1="{y:.1f}" x2="{margin_left + plot_w}" y2="{y:.1f}" stroke="#ece7db" />')
        svg.append(f'<text x="{margin_left - 10}" y="{y + 4:.1f}" text-anchor="end" font-family="Georgia, serif" font-size="12" fill="#8a4b08">{tick:.1f}</text>')
    # Right-axis success-rate labels (no grid lines).
    for tick in success_ticks:
        y = y_pos_success(tick)
        svg.append(f'<text x="{margin_left + plot_w + 10}" y="{y + 4:.1f}" font-family="Georgia, serif" font-size="12" fill="#0d5b6b">{tick:.2f}</text>')
    svg.append(f'<polyline fill="none" stroke="#c5bfb1" stroke-width="1.5" points="{polyline(reward_points)}"/>')
    svg.append(f'<polyline fill="none" stroke="#c1661c" stroke-width="3" points="{polyline(reward_ma_points)}"/>')
    svg.append(f'<polyline fill="none" stroke="#127a8a" stroke-width="3" points="{polyline(success_points)}"/>')
    for x, y in eval_points:
        svg.append(f'<circle cx="{x:.1f}" cy="{y:.1f}" r="4.5" fill="#0d5b6b" stroke="#ffffff" stroke-width="1.5"/>')
    # Legend row along the bottom edge.
    legend_y = height - 18
    svg.append(f'<line x1="80" y1="{legend_y}" x2="110" y2="{legend_y}" stroke="#c1661c" stroke-width="3"/>')
    svg.append(f'<text x="118" y="{legend_y + 4}" font-family="Georgia, serif" font-size="12" fill="#1d2a38">Reward moving average</text>')
    svg.append(f'<line x1="300" y1="{legend_y}" x2="330" y2="{legend_y}" stroke="#127a8a" stroke-width="3"/>')
    svg.append(f'<text x="338" y="{legend_y + 4}" font-family="Georgia, serif" font-size="12" fill="#1d2a38">Success moving average</text>')
    svg.append(f'<line x1="510" y1="{legend_y}" x2="540" y2="{legend_y}" stroke="#c5bfb1" stroke-width="1.5"/>')
    svg.append(f'<text x="548" y="{legend_y + 4}" font-family="Georgia, serif" font-size="12" fill="#1d2a38">Episode reward</text>')
    svg.append(f'<circle cx="700" cy="{legend_y}" r="4.5" fill="#0d5b6b" stroke="#ffffff" stroke-width="1.5"/>')
    svg.append(f'<text x="712" y="{legend_y + 4}" font-family="Georgia, serif" font-size="12" fill="#1d2a38">Eval success checkpoints</text>')
    svg.append("</svg>")
    path.write_text("\n".join(svg), encoding="utf-8")
def train(args: argparse.Namespace) -> None:
    """Run the on-policy actor-critic training loop configured by ``args``.

    Each episode is rolled out with stochastic action sampling. Every
    ``args.update_every`` episodes (and after the final episode), the collected
    samples are used for ``args.update_epochs`` shuffled minibatch passes and then
    discarded. Greedy evaluation runs every ``args.eval_every`` episodes (and after
    the final one) when ``args.eval_every`` is positive. If ``args.output`` is set,
    the model is saved there, optionally with a CSV of per-episode metrics and an
    SVG graph written beside it.
    """
    rng = np.random.default_rng(args.seed)
    encoder = ObservationEncoder(mode=args.observation_mode)
    difficulty_schedule = expand_difficulty_schedule(args.difficulty_schedule, args.episodes)
    # history_length encoded frames are concatenated into one network input.
    input_dim = encoder.base_dim * args.history_length
    network = PolicyValueNetwork(input_dim=input_dim, action_dim=ACTION_DIM, rng=rng)
    env = PyreEnvironment(max_steps=args.max_steps)
    print(describe_environment_contract(encoder, args.history_length))
    print("")
    # On-policy sample buffers; cleared after every update.
    batch_states: List[np.ndarray] = []
    batch_masks: List[np.ndarray] = []
    batch_actions: List[int] = []
    batch_returns: List[np.ndarray] = []
    batch_advantages: List[np.ndarray] = []
    # Rolling 20-episode windows for progress logging.
    reward_window: deque[float] = deque(maxlen=20)
    success_window: deque[float] = deque(maxlen=20)
    episode_metrics: List[Dict[str, float | int | str]] = []
    eval_metrics_rows: List[Dict[str, float | int | str]] = []
    for episode_idx in range(args.episodes):
        # An empty schedule string falls back to the single --difficulty value.
        difficulty = difficulty_schedule[episode_idx] if args.difficulty_schedule else args.difficulty
        traj = run_episode(
            env=env,
            network=network,
            encoder=encoder,
            rng=rng,
            difficulty=difficulty,
            history_length=args.history_length,
            greedy=False,
        )
        # GAE is computed per episode, then the episode's samples join the batch.
        returns, advantages = compute_gae(traj.rewards, traj.values, args.gamma, args.gae_lambda)
        batch_states.extend(traj.states)
        batch_masks.extend(traj.masks)
        batch_actions.extend(traj.actions)
        batch_returns.append(returns)
        batch_advantages.append(advantages)
        reward_window.append(traj.total_reward)
        success_window.append(float(traj.evacuated))
        episode_metrics.append(
            {
                "episode": episode_idx + 1,
                "difficulty": difficulty,
                "reward": round(traj.total_reward, 4),
                "evacuated": int(traj.evacuated),
                "steps": traj.steps,
                "final_health": round(traj.final_health, 2),
                "reward_mean_20": round(float(np.mean(reward_window)), 4),
                "success_rate_20": round(float(np.mean(success_window)), 4),
            }
        )
        print(
            f"episode={episode_idx + 1:04d} difficulty={difficulty:<6} "
            f"steps={traj.steps:03d} reward={traj.total_reward:+8.3f} "
            f"evacuated={int(traj.evacuated)} health={traj.final_health:6.1f}"
        )
        should_update = (episode_idx + 1) % args.update_every == 0 or (episode_idx + 1) == args.episodes
        if should_update and batch_states:
            states_arr = np.asarray(batch_states, dtype=np.float32)
            masks_arr = np.asarray(batch_masks, dtype=np.float32)
            actions_arr = np.asarray(batch_actions, dtype=np.int64)
            returns_arr = np.concatenate(batch_returns).astype(np.float32)
            advantages_arr = np.concatenate(batch_advantages).astype(np.float32)
            # Normalise advantages across the whole batch to zero mean / unit std.
            advantages_arr = (advantages_arr - advantages_arr.mean()) / (advantages_arr.std() + 1e-8)
            network.optimizer.lr = args.learning_rate
            metrics = {}
            for _ in range(args.update_epochs):
                order = rng.permutation(len(states_arr))
                for start in range(0, len(states_arr), args.minibatch_size):
                    idx = order[start:start + args.minibatch_size]
                    metrics = network.update(
                        states=states_arr[idx],
                        masks=masks_arr[idx],
                        actions=actions_arr[idx],
                        returns=returns_arr[idx],
                        advantages=advantages_arr[idx],
                        value_coef=args.value_coef,
                    )
            # Logged losses are those of the last minibatch processed.
            # NOTE(review): `metrics` stays {} (KeyError below) if update_epochs < 1.
            print(
                f"update episodes={episode_idx + 1:04d} samples={len(states_arr):05d} "
                f"reward_mean20={np.mean(reward_window):+8.3f} success20={np.mean(success_window):.2f} "
                f"policy_loss={metrics['policy_loss']:+.4f} value_loss={metrics['value_loss']:.4f} "
                f"entropy={metrics['entropy']:.4f}"
            )
            batch_states.clear()
            batch_masks.clear()
            batch_actions.clear()
            batch_returns.clear()
            batch_advantages.clear()
        should_eval = args.eval_every > 0 and ((episode_idx + 1) % args.eval_every == 0 or (episode_idx + 1) == args.episodes)
        if should_eval:
            eval_metrics = evaluate_policy(
                env=env,
                network=network,
                encoder=encoder,
                rng=rng,
                difficulty=args.eval_difficulty,
                history_length=args.history_length,
                episodes=args.eval_episodes,
            )
            print(
                f"eval episodes={episode_idx + 1:04d} difficulty={args.eval_difficulty:<6} "
                f"reward_mean={eval_metrics['eval_reward_mean']:+8.3f} "
                f"reward_max={eval_metrics['eval_reward_max']:+8.3f} "
                f"success={eval_metrics['eval_success_rate']:.2f} "
                f"steps={eval_metrics['eval_steps_mean']:.1f}"
            )
            eval_metrics_rows.append(
                {
                    "episode": episode_idx + 1,
                    "difficulty": args.eval_difficulty,
                    "reward_mean": round(eval_metrics["eval_reward_mean"], 4),
                    "reward_max": round(eval_metrics["eval_reward_max"], 4),
                    "success_rate": round(eval_metrics["eval_success_rate"], 4),
                    "steps_mean": round(eval_metrics["eval_steps_mean"], 4),
                }
            )
    # Artifacts are written only when an output path is configured; the CSV and SVG
    # reuse that path with their own suffixes.
    if args.output:
        output_path = Path(args.output)
        network.save(
            output_path,
            metadata={
                "observation_mode": args.observation_mode,
                "history_length": args.history_length,
                "episodes": args.episodes,
                "difficulty": args.difficulty,
                "difficulty_schedule": args.difficulty_schedule,
                "gamma": args.gamma,
                "gae_lambda": args.gae_lambda,
                "learning_rate": args.learning_rate,
                "update_epochs": args.update_epochs,
                "minibatch_size": args.minibatch_size,
                "action_dim": ACTION_DIM,
                "input_dim": input_dim,
            },
        )
        print(f"saved model={output_path}")
        if args.save_metrics:
            metrics_path = output_path.with_suffix(".csv")
            save_metrics_csv(metrics_path, episode_metrics)
            print(f"saved metrics={metrics_path}")
        if args.save_graph:
            graph_path = output_path.with_suffix(".svg")
            save_training_graph(graph_path, episode_metrics, eval_metrics_rows)
            print(f"saved graph={graph_path}")
def parse_args() -> argparse.Namespace:
    """Parse and validate the command-line options for training.

    Training difficulty is resolved as follows:

    * ``--difficulty-schedule`` given: the curriculum drives training.
    * only ``--difficulty`` given: every training episode uses that difficulty.
      Previously the non-empty default schedule silently overrode ``--difficulty``.
    * neither given: the default ``easy,medium`` curriculum is used, and
      ``args.difficulty`` is ``"easy"``.

    Counts that training divides or iterates by must be at least 1. Invalid values
    exit through ``parser.error`` instead of crashing mid-training.
    """
    parser = argparse.ArgumentParser(description="Train a NumPy actor-critic baseline for Pyre.")
    parser.add_argument("--episodes", type=int, default=120, help="Training episodes.")
    parser.add_argument(
        "--difficulty",
        type=str,
        default=None,
        choices=DIFFICULTIES,
        help="Single training difficulty, used when --difficulty-schedule is not given (default: easy).",
    )
    parser.add_argument(
        "--difficulty-schedule",
        type=str,
        default=None,
        help="Comma-separated curriculum, expanded evenly across episodes (default: easy,medium unless --difficulty is given).",
    )
    parser.add_argument("--eval-difficulty", type=str, default="medium", choices=DIFFICULTIES)
    parser.add_argument("--eval-episodes", type=int, default=5)
    parser.add_argument("--eval-every", type=int, default=20)
    parser.add_argument("--update-every", type=int, default=5, help="Episodes per policy update.")
    parser.add_argument("--update-epochs", type=int, default=3, help="Gradient passes over each on-policy batch.")
    parser.add_argument("--minibatch-size", type=int, default=256, help="Samples per gradient step.")
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--gae-lambda", type=float, default=0.95)
    parser.add_argument("--learning-rate", type=float, default=3e-4)
    parser.add_argument("--value-coef", type=float, default=0.5)
    parser.add_argument("--history-length", type=int, default=4)
    parser.add_argument("--max-steps", type=int, default=150)
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--observation-mode", type=str, default="visible", choices=("visible", "full"))
    parser.add_argument("--output", type=str, default="artifacts/pyre_actor_critic.npz")
    parser.add_argument("--save-metrics", action="store_true", help="Save per-episode metrics as CSV beside the model.")
    parser.add_argument("--save-graph", action="store_true", help="Save an SVG training graph beside the model.")
    parser.add_argument("--describe-only", action="store_true", help="Print observation/action/reward definitions and exit.")
    args = parser.parse_args()
    if args.difficulty_schedule is None:
        # An empty schedule string makes training use args.difficulty for every episode.
        args.difficulty_schedule = "" if args.difficulty is not None else "easy,medium"
    if args.difficulty is None:
        args.difficulty = "easy"
    for flag, value in (
        ("--update-every", args.update_every),
        ("--update-epochs", args.update_epochs),
        ("--minibatch-size", args.minibatch_size),
        ("--history-length", args.history_length),
    ):
        if value < 1:
            parser.error(f"{flag} must be >= 1, got {value}")
    return args
def main() -> None:
    """Command-line entry point: print the RL contract, or run training."""
    args = parse_args()
    encoder = ObservationEncoder(mode=args.observation_mode)
    if not args.describe_only:
        train(args)
        return
    print(describe_environment_contract(encoder, args.history_length))


if __name__ == "__main__":
    main()