"""Train a deep RL baseline directly against the local Pyre environment. This script makes the environment contract explicit: - Observation: encoded from `PyreObservation.map_state` into a fixed-length vector - Action: fixed discrete action table with a runtime validity mask from `available_actions_hint` - Reward: the environment's composite reward returned by `PyreEnvironment.step()` It uses a self-contained NumPy actor-critic implementation so it can run in this repository without external ML dependencies. Examples: python examples/train_rl_agent.py --episodes 150 --difficulty easy python examples/train_rl_agent.py --episodes 300 --difficulty-schedule easy,medium python examples/train_rl_agent.py --episodes 200 --difficulty easy,medium,hard --observation-mode full python examples/train_rl_agent.py --describe-only """ from __future__ import annotations import argparse import csv import json import math import re from collections import deque from dataclasses import dataclass from pathlib import Path from typing import Dict, Iterable, List, Sequence import numpy as np from pyre_env.models import PyreAction, PyreObservation from pyre_env.server.pyre_env_environment import PyreEnvironment MAX_GRID_W = 24 MAX_GRID_H = 24 MAX_DOORS = 16 DIRECTIONS = ("north", "south", "west", "east") WINDS = ("CALM", "NORTH", "SOUTH", "WEST", "EAST") DIFFICULTIES = ("easy", "medium", "hard") MOVE_KEYS = [f"move(direction='{d}')" for d in DIRECTIONS] LOOK_KEYS = [f"look(direction='{d}')" for d in DIRECTIONS] WAIT_KEY = "wait()" OPEN_KEYS = [f"door(target_id='door_{i}', door_state='open')" for i in range(1, MAX_DOORS + 1)] CLOSE_KEYS = [f"door(target_id='door_{i}', door_state='close')" for i in range(1, MAX_DOORS + 1)] ACTION_KEYS = MOVE_KEYS + LOOK_KEYS + [WAIT_KEY] + OPEN_KEYS + CLOSE_KEYS ACTION_DIM = len(ACTION_KEYS) ACTION_TO_INDEX = {key: idx for idx, key in enumerate(ACTION_KEYS)} _MOVE_RE = re.compile(r"move\(direction='(north|south|west|east)'\)") _LOOK_RE = re.compile(r"look\(direction='(north|south|west|east)'\)") _DOOR_RE = re.compile(r"door\(target_id='(door_(\d+))', door_state='(open|close)'\)") def _one_hot(index: int, size: int) -> np.ndarray: arr = np.zeros(size, dtype=np.float32) if 0 <= index < size: arr[index] = 1.0 return arr def action_index_to_env_action(index: int) -> PyreAction: if 0 <= index < 4: return PyreAction(action="move", direction=DIRECTIONS[index]) if 4 <= index < 8: return PyreAction(action="look", direction=DIRECTIONS[index - 4]) if index == 8: return PyreAction(action="wait") if 9 <= index < 9 + MAX_DOORS: door_id = f"door_{index - 8}" return PyreAction(action="door", target_id=door_id, door_state="open") door_slot = index - (9 + MAX_DOORS) door_id = f"door_{door_slot + 1}" return PyreAction(action="door", target_id=door_id, door_state="close") def build_action_mask(observation: PyreObservation) -> np.ndarray: mask = np.zeros(ACTION_DIM, dtype=np.float32) for hint in observation.available_actions_hint: idx = ACTION_TO_INDEX.get(hint) if idx is not None: mask[idx] = 1.0 continue match = _MOVE_RE.fullmatch(hint) if match: mask[ACTION_TO_INDEX[f"move(direction='{match.group(1)}')"]] = 1.0 continue match = _LOOK_RE.fullmatch(hint) if match: mask[ACTION_TO_INDEX[f"look(direction='{match.group(1)}')"]] = 1.0 continue match = _DOOR_RE.fullmatch(hint) if match: door_id = match.group(1) door_num = int(match.group(2)) state = match.group(3) if 1 <= door_num <= MAX_DOORS: mask[ACTION_TO_INDEX[f"door(target_id='{door_id}', door_state='{state}')"]] = 1.0 if mask.sum() == 0: mask[ACTION_TO_INDEX[WAIT_KEY]] = 1.0 return mask class ObservationEncoder: """Encode Pyre observations into a fixed-size float vector.""" def __init__(self, mode: str = "visible"): if mode not in {"visible", "full"}: raise ValueError(f"Unsupported observation mode: {mode}") self.mode = mode self.base_dim = MAX_GRID_W * MAX_GRID_H * 10 + 22 def encode(self, observation: PyreObservation) -> np.ndarray: map_state = observation.map_state if map_state is None: raise ValueError("PyreObservation.map_state is required for RL training.") cell_one_hot = np.zeros((MAX_GRID_H, MAX_GRID_W, 6), dtype=np.float32) fire_channel = np.zeros((MAX_GRID_H, MAX_GRID_W), dtype=np.float32) smoke_channel = np.zeros((MAX_GRID_H, MAX_GRID_W), dtype=np.float32) visible_channel = np.zeros((MAX_GRID_H, MAX_GRID_W), dtype=np.float32) agent_channel = np.zeros((MAX_GRID_H, MAX_GRID_W), dtype=np.float32) visible = {(x, y) for x, y in map_state.visible_cells} for y in range(map_state.grid_h): for x in range(map_state.grid_w): if self.mode == "visible" and (x, y) not in visible and (x, y) != (map_state.agent_x, map_state.agent_y): continue i = y * map_state.grid_w + x cell_type = int(map_state.cell_grid[i]) if 0 <= cell_type <= 5: cell_one_hot[y, x, cell_type] = 1.0 fire_channel[y, x] = float(map_state.fire_grid[i]) smoke_channel[y, x] = float(map_state.smoke_grid[i]) visible_channel[y, x] = 1.0 if (x, y) in visible else 0.0 if 0 <= map_state.agent_x < MAX_GRID_W and 0 <= map_state.agent_y < MAX_GRID_H: agent_channel[map_state.agent_y, map_state.agent_x] = 1.0 grid_features = np.concatenate( [ cell_one_hot.reshape(-1), fire_channel.reshape(-1), smoke_channel.reshape(-1), visible_channel.reshape(-1), agent_channel.reshape(-1), ] ) metadata = observation.metadata or {} wind_dir = str(metadata.get("wind_dir", map_state.wind_dir or "CALM")).upper() difficulty = str(metadata.get("difficulty", "medium")).lower() wind_index = WINDS.index(wind_dir) if wind_dir in WINDS else 0 difficulty_index = DIFFICULTIES.index(difficulty) if difficulty in DIFFICULTIES else 1 global_features = np.concatenate( [ np.array( [ float(observation.agent_health) / 100.0, float(map_state.agent_health) / 100.0, float(map_state.step_count) / max(1, map_state.max_steps), float(map_state.fire_spread_rate), float(map_state.humidity), float(map_state.agent_x) / max(1, map_state.grid_w - 1), float(map_state.agent_y) / max(1, map_state.grid_h - 1), float(metadata.get("nearest_exit_distance", MAX_GRID_W + MAX_GRID_H) or 0.0) / float(MAX_GRID_W + MAX_GRID_H), float(metadata.get("reachable_exit_count", 0.0)) / 4.0, float(metadata.get("visible_cell_count", 0.0)) / float(MAX_GRID_W * MAX_GRID_H), float(metadata.get("fire_sources", 0.0)) / 5.0, {"none": 0.0, "light": 0.33, "moderate": 0.66, "heavy": 1.0}.get(observation.smoke_level, 0.0), 1.0 if map_state.agent_alive else 0.0, 1.0 if map_state.agent_evacuated else 0.0, ], dtype=np.float32, ), _one_hot(wind_index, len(WINDS)), _one_hot(difficulty_index, len(DIFFICULTIES)), ] ) return np.concatenate([grid_features, global_features]).astype(np.float32) def describe(self, history_length: int) -> str: grid_text = ( f"Observation mode `{self.mode}` encodes a {MAX_GRID_W}x{MAX_GRID_H} padded map with " "10 channels per cell: 6-way cell type one-hot, fire intensity, smoke intensity, visible mask, and agent mask." ) if self.mode == "visible": visibility_text = "Only currently visible cells are populated; unseen cells stay zeroed." else: visibility_text = "The full ground-truth map is exposed for curriculum/debug use." return ( f"{grid_text} {visibility_text} " f"Global features add health, step progress, fire parameters, position, exit-distance metadata, smoke severity, wind, and difficulty. " f"{history_length} encoded frames are stacked, so the network input dimension is {self.base_dim * history_length}." ) def softmax_with_mask(logits: np.ndarray, mask: np.ndarray) -> np.ndarray: masked_logits = np.where(mask > 0.0, logits, -1e9) max_logits = np.max(masked_logits, axis=1, keepdims=True) exps = np.exp(masked_logits - max_logits) * mask denom = np.sum(exps, axis=1, keepdims=True) denom = np.where(denom <= 0.0, 1.0, denom) return exps / denom class AdamOptimizer: def __init__(self, params: Dict[str, np.ndarray], lr: float = 3e-4, beta1: float = 0.9, beta2: float = 0.999): self.lr = lr self.beta1 = beta1 self.beta2 = beta2 self.eps = 1e-8 self.t = 0 self.m = {k: np.zeros_like(v) for k, v in params.items()} self.v = {k: np.zeros_like(v) for k, v in params.items()} def step(self, params: Dict[str, np.ndarray], grads: Dict[str, np.ndarray], clip_norm: float = 1.0) -> None: total_norm_sq = 0.0 for grad in grads.values(): total_norm_sq += float(np.sum(grad * grad)) total_norm = math.sqrt(total_norm_sq) scale = 1.0 if total_norm > clip_norm: scale = clip_norm / (total_norm + 1e-8) self.t += 1 for name, param in params.items(): grad = grads[name] * scale self.m[name] = self.beta1 * self.m[name] + (1.0 - self.beta1) * grad self.v[name] = self.beta2 * self.v[name] + (1.0 - self.beta2) * (grad * grad) m_hat = self.m[name] / (1.0 - self.beta1 ** self.t) v_hat = self.v[name] / (1.0 - self.beta2 ** self.t) params[name] -= self.lr * m_hat / (np.sqrt(v_hat) + self.eps) class PolicyValueNetwork: def __init__(self, input_dim: int, action_dim: int, rng: np.random.Generator, hidden_sizes: Sequence[int] = (256, 128)): h1, h2 = hidden_sizes self.params: Dict[str, np.ndarray] = { "w1": self._init_weight(rng, input_dim, h1), "b1": np.zeros(h1, dtype=np.float32), "w2": self._init_weight(rng, h1, h2), "b2": np.zeros(h2, dtype=np.float32), "wp": self._init_weight(rng, h2, action_dim), "bp": np.zeros(action_dim, dtype=np.float32), "wv": self._init_weight(rng, h2, 1), "bv": np.zeros(1, dtype=np.float32), } self.optimizer = AdamOptimizer(self.params) @staticmethod def _init_weight(rng: np.random.Generator, in_dim: int, out_dim: int) -> np.ndarray: scale = math.sqrt(2.0 / max(1, in_dim + out_dim)) return (rng.standard_normal((in_dim, out_dim)) * scale).astype(np.float32) def forward(self, x: np.ndarray) -> tuple[np.ndarray, np.ndarray, Dict[str, np.ndarray]]: z1 = x @ self.params["w1"] + self.params["b1"] h1 = np.tanh(z1) z2 = h1 @ self.params["w2"] + self.params["b2"] h2 = np.tanh(z2) logits = h2 @ self.params["wp"] + self.params["bp"] values = (h2 @ self.params["wv"] + self.params["bv"]).reshape(-1) cache = {"x": x, "h1": h1, "h2": h2} return logits, values, cache def predict(self, x: np.ndarray, mask: np.ndarray) -> tuple[np.ndarray, float]: logits, values, _ = self.forward(x[None, :]) probs = softmax_with_mask(logits, mask[None, :])[0] return probs, float(values[0]) def update( self, states: np.ndarray, masks: np.ndarray, actions: np.ndarray, returns: np.ndarray, advantages: np.ndarray, value_coef: float = 0.5, ) -> Dict[str, float]: logits, values, cache = self.forward(states) probs = softmax_with_mask(logits, masks) batch_size = max(1, states.shape[0]) grad_logits = probs.copy() grad_logits[np.arange(batch_size), actions] -= 1.0 grad_logits *= advantages[:, None] / batch_size grad_logits *= masks grad_values = ((values - returns)[:, None] * value_coef) / batch_size grads: Dict[str, np.ndarray] = {} grads["wp"] = cache["h2"].T @ grad_logits grads["bp"] = np.sum(grad_logits, axis=0) grads["wv"] = cache["h2"].T @ grad_values grads["bv"] = np.sum(grad_values, axis=0) dh2 = grad_logits @ self.params["wp"].T + grad_values @ self.params["wv"].T dz2 = dh2 * (1.0 - cache["h2"] ** 2) grads["w2"] = cache["h1"].T @ dz2 grads["b2"] = np.sum(dz2, axis=0) dh1 = dz2 @ self.params["w2"].T dz1 = dh1 * (1.0 - cache["h1"] ** 2) grads["w1"] = cache["x"].T @ dz1 grads["b1"] = np.sum(dz1, axis=0) self.optimizer.step(self.params, grads, clip_norm=1.0) chosen_probs = np.clip(probs[np.arange(batch_size), actions], 1e-8, 1.0) policy_loss = float(-np.mean(advantages * np.log(chosen_probs))) value_loss = float(0.5 * np.mean((values - returns) ** 2)) entropy = float(-np.mean(np.sum(np.where(probs > 0.0, probs * np.log(np.clip(probs, 1e-8, 1.0)), 0.0), axis=1))) return { "policy_loss": policy_loss, "value_loss": value_loss, "entropy": entropy, "mean_value": float(np.mean(values)), } def save(self, path: Path, metadata: Dict[str, object]) -> None: path.parent.mkdir(parents=True, exist_ok=True) arrays = {name: value for name, value in self.params.items()} arrays["metadata_json"] = np.array(json.dumps(metadata)) np.savez(path, **arrays) @dataclass class Trajectory: states: List[np.ndarray] masks: List[np.ndarray] actions: List[int] rewards: List[float] values: List[float] evacuated: bool final_health: float steps: int total_reward: float def compute_gae( rewards: Sequence[float], values: Sequence[float], gamma: float, gae_lambda: float, ) -> tuple[np.ndarray, np.ndarray]: rewards_arr = np.asarray(rewards, dtype=np.float32) values_arr = np.asarray(values, dtype=np.float32) advantages = np.zeros(len(rewards_arr), dtype=np.float32) gae = 0.0 next_value = 0.0 for i in range(len(rewards_arr) - 1, -1, -1): delta = rewards_arr[i] + gamma * next_value - values_arr[i] gae = delta + gamma * gae_lambda * gae advantages[i] = gae next_value = values_arr[i] returns = advantages + values_arr return returns.astype(np.float32), advantages.astype(np.float32) def select_action( network: PolicyValueNetwork, state_vec: np.ndarray, mask: np.ndarray, rng: np.random.Generator, greedy: bool = False, ) -> tuple[int, float]: probs, value = network.predict(state_vec, mask) valid_indices = np.flatnonzero(mask > 0.0) if len(valid_indices) == 0: return ACTION_TO_INDEX[WAIT_KEY], value if greedy: best_local = int(np.argmax(probs[valid_indices])) return int(valid_indices[best_local]), value return int(rng.choice(np.arange(len(probs)), p=probs)), value def build_stacked_state(frames: deque[np.ndarray]) -> np.ndarray: return np.concatenate(list(frames), dtype=np.float32) def run_episode( env: PyreEnvironment, network: PolicyValueNetwork, encoder: ObservationEncoder, rng: np.random.Generator, difficulty: str, history_length: int, greedy: bool = False, ) -> Trajectory: observation = env.reset(difficulty=difficulty) zero_frame = np.zeros(encoder.base_dim, dtype=np.float32) frames: deque[np.ndarray] = deque([zero_frame.copy() for _ in range(history_length)], maxlen=history_length) frames.append(encoder.encode(observation)) states: List[np.ndarray] = [] masks: List[np.ndarray] = [] actions: List[int] = [] rewards: List[float] = [] values: List[float] = [] total_reward = 0.0 final_health = observation.agent_health evacuated = False steps = 0 while True: state_vec = build_stacked_state(frames) mask = build_action_mask(observation) action_idx, value = select_action(network, state_vec, mask, rng, greedy=greedy) action = action_index_to_env_action(action_idx) next_obs = env.step(action) reward = float(next_obs.reward or 0.0) states.append(state_vec) masks.append(mask) actions.append(action_idx) rewards.append(reward) values.append(value) total_reward += reward steps += 1 final_health = next_obs.agent_health evacuated = next_obs.agent_evacuated frames.append(encoder.encode(next_obs)) observation = next_obs if next_obs.done: break return Trajectory( states=states, masks=masks, actions=actions, rewards=rewards, values=values, evacuated=evacuated, final_health=final_health, steps=steps, total_reward=total_reward, ) def evaluate_policy( env: PyreEnvironment, network: PolicyValueNetwork, encoder: ObservationEncoder, rng: np.random.Generator, difficulty: str, history_length: int, episodes: int, ) -> Dict[str, float]: rewards = [] evacuations = 0 lengths = [] for _ in range(episodes): traj = run_episode(env, network, encoder, rng, difficulty, history_length, greedy=True) rewards.append(traj.total_reward) lengths.append(traj.steps) evacuations += int(traj.evacuated) return { "eval_reward_mean": float(np.mean(rewards)) if rewards else 0.0, "eval_reward_max": float(np.max(rewards)) if rewards else 0.0, "eval_success_rate": float(evacuations / max(1, episodes)), "eval_steps_mean": float(np.mean(lengths)) if lengths else 0.0, } def expand_difficulty_schedule(schedule_text: str, episodes: int) -> List[str]: stages = [part.strip().lower() for part in schedule_text.split(",") if part.strip()] if not stages: stages = ["medium"] for stage in stages: if stage not in DIFFICULTIES: raise ValueError(f"Invalid difficulty in schedule: {stage}") segment = max(1, episodes // len(stages)) expanded: List[str] = [] for stage in stages: expanded.extend([stage] * segment) while len(expanded) < episodes: expanded.append(stages[-1]) return expanded[:episodes] def describe_environment_contract(encoder: ObservationEncoder, history_length: int) -> str: action_text = ( f"Action space has {ACTION_DIM} fixed discrete actions: 4 moves, 4 looks, wait, " f"{MAX_DOORS} door-open slots, and {MAX_DOORS} door-close slots. " "A per-step mask from `available_actions_hint` prevents invalid actions." ) reward_text = ( "Reward comes directly from the environment's composite rubric: time penalty, exit progress, " "progress regression penalty, safe-progress bonus, danger penalty, health-drain penalty, " "strategic door bonus, exploration bonus, plus terminal evacuation/death/timeout/near-miss/time bonuses." ) return "\n".join( [ "Pyre RL contract", encoder.describe(history_length), action_text, reward_text, ] ) def _moving_average(values: Sequence[float], window: int) -> List[float]: if not values: return [] out: List[float] = [] run = 0.0 q: deque[float] = deque() for value in values: q.append(float(value)) run += float(value) if len(q) > window: run -= q.popleft() out.append(run / len(q)) return out def save_metrics_csv(path: Path, rows: List[Dict[str, float | int | str]]) -> None: path.parent.mkdir(parents=True, exist_ok=True) if not rows: return with path.open("w", newline="", encoding="utf-8") as f: writer = csv.DictWriter(f, fieldnames=list(rows[0].keys())) writer.writeheader() writer.writerows(rows) def save_training_graph(path: Path, episode_rows: List[Dict[str, float | int | str]], eval_rows: List[Dict[str, float | int | str]]) -> None: path.parent.mkdir(parents=True, exist_ok=True) if not episode_rows: return width = 1260 height = 780 margin_left = 100 # extra room for rotated Y-axis label + tick values margin_right = 110 # extra room for right axis label + tick values margin_top = 70 # room for title margin_bottom = 90 # room for X-axis label + tick values + legend plot_w = width - margin_left - margin_right plot_h = height - margin_top - margin_bottom # X: plot_left=100, plot_right=1150 Y: plot_top=70, plot_bottom=690 episodes = [int(r["episode"]) for r in episode_rows] rewards = [float(r["reward"]) for r in episode_rows] reward_ma = _moving_average(rewards, 20) success_ma = _moving_average([float(r["evacuated"]) for r in episode_rows], 20) all_reward_values = rewards + reward_ma + [float(r["reward_mean"]) for r in eval_rows] + [float(r["reward_max"]) for r in eval_rows] y_min = min(all_reward_values) if all_reward_values else -1.0 y_max = max(all_reward_values) if all_reward_values else 1.0 if abs(y_max - y_min) < 1e-6: y_min -= 1.0 y_max += 1.0 y_pad = 0.1 * (y_max - y_min) y_min -= y_pad y_max += y_pad max_episode = max(episodes) if episodes else 1 plot_left = margin_left plot_right = margin_left + plot_w plot_top = margin_top plot_bottom = margin_top + plot_h def x_pos(ep: float) -> float: return plot_left + (float(ep) - 1.0) / max(1.0, max_episode - 1.0) * plot_w def y_pos_reward(value: float) -> float: return plot_top + (y_max - float(value)) / max(1e-6, (y_max - y_min)) * plot_h def y_pos_success(value: float) -> float: return plot_top + (1.0 - float(value)) * plot_h def polyline(points: List[tuple[float, float]]) -> str: return " ".join(f"{x:.1f},{y:.1f}" for x, y in points) reward_points = [(x_pos(ep), y_pos_reward(val)) for ep, val in zip(episodes, rewards)] reward_ma_points = [(x_pos(ep), y_pos_reward(val)) for ep, val in zip(episodes, reward_ma)] success_points = [(x_pos(ep), y_pos_success(val)) for ep, val in zip(episodes, success_ma)] eval_points = [(x_pos(float(r["episode"])), y_pos_success(float(r["success_rate"]))) for r in eval_rows] n_x_ticks = 8 episode_ticks = sorted(set( max(1, round(1 + i * (max_episode - 1) / n_x_ticks)) for i in range(n_x_ticks + 1) )) n_y_ticks = 6 reward_ticks = [y_min + (y_max - y_min) * i / n_y_ticks for i in range(n_y_ticks + 1)] success_ticks = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0] svg = [] svg.append(f'') # Background svg.append('') # Title + subtitle svg.append(f'Pyre RL Training') svg.append(f'Left axis: Reward | Right axis: Success Rate (0–1)') # Plot area background + border svg.append(f'') # ── Vertical grid lines + X-axis ticks ────────────────────────────────── for tick in episode_ticks: x = x_pos(float(tick)) # dashed grid line svg.append(f'') # solid tick mark on bottom axis svg.append(f'') # tick label svg.append(f'{tick}') # X-axis title x_title_x = plot_left + plot_w / 2 x_title_y = plot_bottom + 50 svg.append(f'Episode') # ── Horizontal grid lines + Left Y-axis ticks (Reward) ────────────────── for tick in reward_ticks: y = y_pos_reward(tick) # dashed grid line svg.append(f'') # solid tick mark on left axis svg.append(f'') # tick label svg.append(f'{tick:.1f}') # Left Y-axis title (rotated) — centered on plot height ly_cx = plot_left - 70 ly_cy = plot_top + plot_h / 2 svg.append(f'Reward') # ── Right Y-axis ticks (Success Rate) ─────────────────────────────────── for tick in success_ticks: y = y_pos_success(tick) # solid tick mark on right axis svg.append(f'') # tick label svg.append(f'{tick:.2f}') # Right Y-axis title (rotated) ry_cx = plot_right + 85 ry_cy = plot_top + plot_h / 2 svg.append(f'Success Rate') # ── Axis border lines (solid, on top of grid) ──────────────────────────── # Bottom axis svg.append(f'') # Left axis svg.append(f'') # Right axis svg.append(f'') # ── Data series ───────────────────────────────────────────────────────── # Raw episode reward (faint) svg.append(f'') # Reward moving average svg.append(f'') # Success moving average svg.append(f'') # Eval checkpoints for x, y in eval_points: svg.append(f'') # ── Legend ─────────────────────────────────────────────────────────────── legend_y = plot_bottom + 72 items = [ ("#c1661c", 3, False, "Reward (moving avg)"), ("#127a8a", 3, False, "Success rate (moving avg)"), ("#c5bfb1", 1.5, False, "Episode reward"), ("#0d5b6b", 0, True, "Eval success checkpoint"), ] lx = plot_left for color, sw, is_dot, label in items: if is_dot: svg.append(f'') else: svg.append(f'') svg.append(f'{label}') lx += 230 svg.append("") path.write_text("\n".join(svg), encoding="utf-8") def train(args: argparse.Namespace) -> None: rng = np.random.default_rng(args.seed) encoder = ObservationEncoder(mode=args.observation_mode) difficulty_schedule = expand_difficulty_schedule(args.difficulty_schedule, args.episodes) input_dim = encoder.base_dim * args.history_length network = PolicyValueNetwork(input_dim=input_dim, action_dim=ACTION_DIM, rng=rng) env = PyreEnvironment(max_steps=args.max_steps) print(describe_environment_contract(encoder, args.history_length)) print("") batch_states: List[np.ndarray] = [] batch_masks: List[np.ndarray] = [] batch_actions: List[int] = [] batch_returns: List[np.ndarray] = [] batch_advantages: List[np.ndarray] = [] reward_window: deque[float] = deque(maxlen=20) success_window: deque[float] = deque(maxlen=20) episode_metrics: List[Dict[str, float | int | str]] = [] eval_metrics_rows: List[Dict[str, float | int | str]] = [] for episode_idx in range(args.episodes): difficulty = difficulty_schedule[episode_idx] if args.difficulty_schedule else args.difficulty traj = run_episode( env=env, network=network, encoder=encoder, rng=rng, difficulty=difficulty, history_length=args.history_length, greedy=False, ) returns, advantages = compute_gae(traj.rewards, traj.values, args.gamma, args.gae_lambda) batch_states.extend(traj.states) batch_masks.extend(traj.masks) batch_actions.extend(traj.actions) batch_returns.append(returns) batch_advantages.append(advantages) reward_window.append(traj.total_reward) success_window.append(float(traj.evacuated)) episode_metrics.append( { "episode": episode_idx + 1, "difficulty": difficulty, "reward": round(traj.total_reward, 4), "evacuated": int(traj.evacuated), "steps": traj.steps, "final_health": round(traj.final_health, 2), "reward_mean_20": round(float(np.mean(reward_window)), 4), "success_rate_20": round(float(np.mean(success_window)), 4), } ) print( f"episode={episode_idx + 1:04d} difficulty={difficulty:<6} " f"steps={traj.steps:03d} reward={traj.total_reward:+8.3f} " f"evacuated={int(traj.evacuated)} health={traj.final_health:6.1f}" ) should_update = (episode_idx + 1) % args.update_every == 0 or (episode_idx + 1) == args.episodes if should_update and batch_states: states_arr = np.asarray(batch_states, dtype=np.float32) masks_arr = np.asarray(batch_masks, dtype=np.float32) actions_arr = np.asarray(batch_actions, dtype=np.int64) returns_arr = np.concatenate(batch_returns).astype(np.float32) advantages_arr = np.concatenate(batch_advantages).astype(np.float32) advantages_arr = (advantages_arr - advantages_arr.mean()) / (advantages_arr.std() + 1e-8) network.optimizer.lr = args.learning_rate metrics = {} for _ in range(args.update_epochs): order = rng.permutation(len(states_arr)) for start in range(0, len(states_arr), args.minibatch_size): idx = order[start:start + args.minibatch_size] metrics = network.update( states=states_arr[idx], masks=masks_arr[idx], actions=actions_arr[idx], returns=returns_arr[idx], advantages=advantages_arr[idx], value_coef=args.value_coef, ) print( f"update episodes={episode_idx + 1:04d} samples={len(states_arr):05d} " f"reward_mean20={np.mean(reward_window):+8.3f} success20={np.mean(success_window):.2f} " f"policy_loss={metrics['policy_loss']:+.4f} value_loss={metrics['value_loss']:.4f} " f"entropy={metrics['entropy']:.4f}" ) batch_states.clear() batch_masks.clear() batch_actions.clear() batch_returns.clear() batch_advantages.clear() should_eval = args.eval_every > 0 and ((episode_idx + 1) % args.eval_every == 0 or (episode_idx + 1) == args.episodes) if should_eval: eval_metrics = evaluate_policy( env=env, network=network, encoder=encoder, rng=rng, difficulty=args.eval_difficulty, history_length=args.history_length, episodes=args.eval_episodes, ) print( f"eval episodes={episode_idx + 1:04d} difficulty={args.eval_difficulty:<6} " f"reward_mean={eval_metrics['eval_reward_mean']:+8.3f} " f"reward_max={eval_metrics['eval_reward_max']:+8.3f} " f"success={eval_metrics['eval_success_rate']:.2f} " f"steps={eval_metrics['eval_steps_mean']:.1f}" ) eval_metrics_rows.append( { "episode": episode_idx + 1, "difficulty": args.eval_difficulty, "reward_mean": round(eval_metrics["eval_reward_mean"], 4), "reward_max": round(eval_metrics["eval_reward_max"], 4), "success_rate": round(eval_metrics["eval_success_rate"], 4), "steps_mean": round(eval_metrics["eval_steps_mean"], 4), } ) if args.output: output_path = Path(args.output) network.save( output_path, metadata={ "observation_mode": args.observation_mode, "history_length": args.history_length, "episodes": args.episodes, "difficulty": args.difficulty, "difficulty_schedule": args.difficulty_schedule, "gamma": args.gamma, "gae_lambda": args.gae_lambda, "learning_rate": args.learning_rate, "update_epochs": args.update_epochs, "minibatch_size": args.minibatch_size, "action_dim": ACTION_DIM, "input_dim": input_dim, }, ) print(f"saved model={output_path}") if args.save_metrics: metrics_path = output_path.with_suffix(".csv") save_metrics_csv(metrics_path, episode_metrics) print(f"saved metrics={metrics_path}") if args.save_graph: graph_path = output_path.with_suffix(".svg") save_training_graph(graph_path, episode_metrics, eval_metrics_rows) print(f"saved graph={graph_path}") # Also save PNG try: import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import matplotlib.ticker as mticker import matplotlib.patches as mpatches episodes_list = [int(r["episode"]) for r in episode_metrics] rewards_list = [float(r["reward"]) for r in episode_metrics] evacuated_list = [float(r["evacuated"]) for r in episode_metrics] diff_list = [str(r["difficulty"]) for r in episode_metrics] def _ma(vals, w=20): out, run, q = [], 0.0, [] for v in vals: q.append(v); run += v if len(q) > w: run -= q.pop(0) out.append(run / len(q)) return out reward_ma = _ma(rewards_list) success_ma = _ma(evacuated_list) eval_eps = [int(r["episode"]) for r in eval_metrics_rows] eval_succ = [float(r["success_rate"]) for r in eval_metrics_rows] diff_colors = {"easy": "#d4edda", "medium": "#fff3cd", "hard": "#f8d7da"} regions = [] if diff_list: cur, start = diff_list[0], episodes_list[0] for ep, d in zip(episodes_list[1:], diff_list[1:]): if d != cur: regions.append((start, ep, cur)); cur, start = d, ep regions.append((start, episodes_list[-1], cur)) fig, ax1 = plt.subplots(figsize=(14, 6)) ax2 = ax1.twinx() for x0, x1, diff in regions: ax1.axvspan(x0, x1, color=diff_colors.get(diff, "#eeeeee"), alpha=0.35, zorder=0) ax1.axhline(0, color="#aaaaaa", linewidth=0.8, linestyle="--", zorder=1) ax1.plot(episodes_list, rewards_list, color="#d1c7bc", linewidth=0.8, alpha=0.6, label="Episode reward", zorder=2) ax1.plot(episodes_list, reward_ma, color="#c1661c", linewidth=2.5, label="Reward (MA-20)", zorder=3) ax2.plot(episodes_list, success_ma, color="#1a7a8a", linewidth=2.5, label="Success rate (MA-20)", zorder=3) if eval_eps: ax2.scatter(eval_eps, eval_succ, color="#0d5b6b", s=60, zorder=5, marker="D", edgecolors="white", linewidths=1.2, label="Eval success") ax1.set_xlabel("Episode", fontsize=13, fontweight="bold", labelpad=8) ax1.set_ylabel("Reward", fontsize=13, fontweight="bold", color="#c1661c", labelpad=8) ax2.set_ylabel("Success Rate", fontsize=13, fontweight="bold", color="#1a7a8a", labelpad=8) ax1.tick_params(axis="y", labelcolor="#c1661c") ax2.tick_params(axis="y", labelcolor="#1a7a8a") ax2.set_ylim(-0.05, 1.05) ax2.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0, decimals=0)) ax1.grid(True, linestyle="--", linewidth=0.6, color="#dddddd", alpha=0.8) ax1.set_xlim(episodes_list[0], episodes_list[-1]) diff_patches = [mpatches.Patch(color=diff_colors[d], alpha=0.6, label=d.capitalize()) for d in ["easy", "medium", "hard"] if d in diff_list] h1, l1 = ax1.get_legend_handles_labels() h2, l2 = ax2.get_legend_handles_labels() ax1.legend(h1 + h2 + diff_patches, l1 + l2 + [p.get_label() for p in diff_patches], loc="upper left", fontsize=9, framealpha=0.85) final_sr = success_ma[-1] if success_ma else 0.0 fig.suptitle(f"Pyre NumPy A2C Training — {episodes_list[-1]} episodes | final success: {final_sr:.0%}", fontsize=14, fontweight="bold", y=1.01) fig.tight_layout() png_path = output_path.with_suffix(".png") fig.savefig(png_path, dpi=150, bbox_inches="tight") plt.close(fig) print(f"saved graph_png={png_path}") except ImportError: pass def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Train a NumPy actor-critic baseline for Pyre.") parser.add_argument("--episodes", type=int, default=120, help="Training episodes.") parser.add_argument("--difficulty", type=str, default="easy", choices=DIFFICULTIES) parser.add_argument( "--difficulty-schedule", type=str, default="easy,medium", help="Comma-separated curriculum, expanded evenly across episodes.", ) parser.add_argument("--eval-difficulty", type=str, default="medium", choices=DIFFICULTIES) parser.add_argument("--eval-episodes", type=int, default=5) parser.add_argument("--eval-every", type=int, default=20) parser.add_argument("--update-every", type=int, default=5, help="Episodes per policy update.") parser.add_argument("--update-epochs", type=int, default=3, help="Gradient passes over each on-policy batch.") parser.add_argument("--minibatch-size", type=int, default=256, help="Samples per gradient step.") parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--gae-lambda", type=float, default=0.95) parser.add_argument("--learning-rate", type=float, default=3e-4) parser.add_argument("--value-coef", type=float, default=0.5) parser.add_argument("--history-length", type=int, default=4) parser.add_argument("--max-steps", type=int, default=150) parser.add_argument("--seed", type=int, default=7) parser.add_argument("--observation-mode", type=str, default="visible", choices=("visible", "full")) parser.add_argument("--output", type=str, default="artifacts/pyre_actor_critic.npz") parser.add_argument("--save-metrics", action="store_true", help="Save per-episode metrics as CSV beside the model.") parser.add_argument("--save-graph", action="store_true", help="Save an SVG training graph beside the model.") parser.add_argument("--describe-only", action="store_true", help="Print observation/action/reward definitions and exit.") return parser.parse_args() def main() -> None: args = parse_args() encoder = ObservationEncoder(mode=args.observation_mode) if args.describe_only: print(describe_environment_contract(encoder, args.history_length)) return train(args) if __name__ == "__main__": main()