"""Train a deep RL baseline directly against the local Pyre environment. This script makes the environment contract explicit: - Observation: encoded from `PyreObservation.map_state` into a fixed-length vector - Action: fixed discrete action table with a runtime validity mask from `available_actions_hint` - Reward: the environment's composite reward returned by `PyreEnvironment.step()` It uses a self-contained NumPy actor-critic implementation so it can run in this repository without external ML dependencies. Examples: python examples/train_rl_agent.py --episodes 150 --difficulty easy python examples/train_rl_agent.py --episodes 300 --difficulty-schedule easy,medium python examples/train_rl_agent.py --episodes 200 --difficulty medium --observation-mode full python examples/train_rl_agent.py --describe-only """ from __future__ import annotations import argparse import csv import json import math import re from collections import deque from dataclasses import dataclass from pathlib import Path from typing import Dict, Iterable, List, Sequence import numpy as np from pyre_env.models import PyreAction, PyreObservation from pyre_env.server.pyre_env_environment import PyreEnvironment MAX_GRID_W = 24 MAX_GRID_H = 24 MAX_DOORS = 16 DIRECTIONS = ("north", "south", "west", "east") WINDS = ("CALM", "NORTH", "SOUTH", "WEST", "EAST") DIFFICULTIES = ("easy", "medium", "hard") MOVE_KEYS = [f"move(direction='{d}')" for d in DIRECTIONS] LOOK_KEYS = [f"look(direction='{d}')" for d in DIRECTIONS] WAIT_KEY = "wait()" OPEN_KEYS = [f"door(target_id='door_{i}', door_state='open')" for i in range(1, MAX_DOORS + 1)] CLOSE_KEYS = [f"door(target_id='door_{i}', door_state='close')" for i in range(1, MAX_DOORS + 1)] ACTION_KEYS = MOVE_KEYS + LOOK_KEYS + [WAIT_KEY] + OPEN_KEYS + CLOSE_KEYS ACTION_DIM = len(ACTION_KEYS) ACTION_TO_INDEX = {key: idx for idx, key in enumerate(ACTION_KEYS)} _MOVE_RE = re.compile(r"move\(direction='(north|south|west|east)'\)") _LOOK_RE = re.compile(r"look\(direction='(north|south|west|east)'\)") _DOOR_RE = re.compile(r"door\(target_id='(door_(\d+))', door_state='(open|close)'\)") def _one_hot(index: int, size: int) -> np.ndarray: arr = np.zeros(size, dtype=np.float32) if 0 <= index < size: arr[index] = 1.0 return arr def action_index_to_env_action(index: int) -> PyreAction: if 0 <= index < 4: return PyreAction(action="move", direction=DIRECTIONS[index]) if 4 <= index < 8: return PyreAction(action="look", direction=DIRECTIONS[index - 4]) if index == 8: return PyreAction(action="wait") if 9 <= index < 9 + MAX_DOORS: door_id = f"door_{index - 8}" return PyreAction(action="door", target_id=door_id, door_state="open") door_slot = index - (9 + MAX_DOORS) door_id = f"door_{door_slot + 1}" return PyreAction(action="door", target_id=door_id, door_state="close") def build_action_mask(observation: PyreObservation) -> np.ndarray: mask = np.zeros(ACTION_DIM, dtype=np.float32) for hint in observation.available_actions_hint: idx = ACTION_TO_INDEX.get(hint) if idx is not None: mask[idx] = 1.0 continue match = _MOVE_RE.fullmatch(hint) if match: mask[ACTION_TO_INDEX[f"move(direction='{match.group(1)}')"]] = 1.0 continue match = _LOOK_RE.fullmatch(hint) if match: mask[ACTION_TO_INDEX[f"look(direction='{match.group(1)}')"]] = 1.0 continue match = _DOOR_RE.fullmatch(hint) if match: door_id = match.group(1) door_num = int(match.group(2)) state = match.group(3) if 1 <= door_num <= MAX_DOORS: mask[ACTION_TO_INDEX[f"door(target_id='{door_id}', door_state='{state}')"]] = 1.0 if mask.sum() == 0: mask[ACTION_TO_INDEX[WAIT_KEY]] = 1.0 return mask class ObservationEncoder: """Encode Pyre observations into a fixed-size float vector.""" def __init__(self, mode: str = "visible"): if mode not in {"visible", "full"}: raise ValueError(f"Unsupported observation mode: {mode}") self.mode = mode self.base_dim = MAX_GRID_W * MAX_GRID_H * 10 + 22 def encode(self, observation: PyreObservation) -> np.ndarray: map_state = observation.map_state if map_state is None: raise ValueError("PyreObservation.map_state is required for RL training.") cell_one_hot = np.zeros((MAX_GRID_H, MAX_GRID_W, 6), dtype=np.float32) fire_channel = np.zeros((MAX_GRID_H, MAX_GRID_W), dtype=np.float32) smoke_channel = np.zeros((MAX_GRID_H, MAX_GRID_W), dtype=np.float32) visible_channel = np.zeros((MAX_GRID_H, MAX_GRID_W), dtype=np.float32) agent_channel = np.zeros((MAX_GRID_H, MAX_GRID_W), dtype=np.float32) visible = {(x, y) for x, y in map_state.visible_cells} for y in range(map_state.grid_h): for x in range(map_state.grid_w): if self.mode == "visible" and (x, y) not in visible and (x, y) != (map_state.agent_x, map_state.agent_y): continue i = y * map_state.grid_w + x cell_type = int(map_state.cell_grid[i]) if 0 <= cell_type <= 5: cell_one_hot[y, x, cell_type] = 1.0 fire_channel[y, x] = float(map_state.fire_grid[i]) smoke_channel[y, x] = float(map_state.smoke_grid[i]) visible_channel[y, x] = 1.0 if (x, y) in visible else 0.0 if 0 <= map_state.agent_x < MAX_GRID_W and 0 <= map_state.agent_y < MAX_GRID_H: agent_channel[map_state.agent_y, map_state.agent_x] = 1.0 grid_features = np.concatenate( [ cell_one_hot.reshape(-1), fire_channel.reshape(-1), smoke_channel.reshape(-1), visible_channel.reshape(-1), agent_channel.reshape(-1), ] ) metadata = observation.metadata or {} wind_dir = str(metadata.get("wind_dir", map_state.wind_dir or "CALM")).upper() difficulty = str(metadata.get("difficulty", "medium")).lower() wind_index = WINDS.index(wind_dir) if wind_dir in WINDS else 0 difficulty_index = DIFFICULTIES.index(difficulty) if difficulty in DIFFICULTIES else 1 global_features = np.concatenate( [ np.array( [ float(observation.agent_health) / 100.0, float(map_state.agent_health) / 100.0, float(map_state.step_count) / max(1, map_state.max_steps), float(map_state.fire_spread_rate), float(map_state.humidity), float(map_state.agent_x) / max(1, map_state.grid_w - 1), float(map_state.agent_y) / max(1, map_state.grid_h - 1), float(metadata.get("nearest_exit_distance", MAX_GRID_W + MAX_GRID_H) or 0.0) / float(MAX_GRID_W + MAX_GRID_H), float(metadata.get("reachable_exit_count", 0.0)) / 4.0, float(metadata.get("visible_cell_count", 0.0)) / float(MAX_GRID_W * MAX_GRID_H), float(metadata.get("fire_sources", 0.0)) / 5.0, {"none": 0.0, "light": 0.33, "moderate": 0.66, "heavy": 1.0}.get(observation.smoke_level, 0.0), 1.0 if map_state.agent_alive else 0.0, 1.0 if map_state.agent_evacuated else 0.0, ], dtype=np.float32, ), _one_hot(wind_index, len(WINDS)), _one_hot(difficulty_index, len(DIFFICULTIES)), ] ) return np.concatenate([grid_features, global_features]).astype(np.float32) def describe(self, history_length: int) -> str: grid_text = ( f"Observation mode `{self.mode}` encodes a {MAX_GRID_W}x{MAX_GRID_H} padded map with " "10 channels per cell: 6-way cell type one-hot, fire intensity, smoke intensity, visible mask, and agent mask." ) if self.mode == "visible": visibility_text = "Only currently visible cells are populated; unseen cells stay zeroed." else: visibility_text = "The full ground-truth map is exposed for curriculum/debug use." return ( f"{grid_text} {visibility_text} " f"Global features add health, step progress, fire parameters, position, exit-distance metadata, smoke severity, wind, and difficulty. " f"{history_length} encoded frames are stacked, so the network input dimension is {self.base_dim * history_length}." ) def softmax_with_mask(logits: np.ndarray, mask: np.ndarray) -> np.ndarray: masked_logits = np.where(mask > 0.0, logits, -1e9) max_logits = np.max(masked_logits, axis=1, keepdims=True) exps = np.exp(masked_logits - max_logits) * mask denom = np.sum(exps, axis=1, keepdims=True) denom = np.where(denom <= 0.0, 1.0, denom) return exps / denom class AdamOptimizer: def __init__(self, params: Dict[str, np.ndarray], lr: float = 3e-4, beta1: float = 0.9, beta2: float = 0.999): self.lr = lr self.beta1 = beta1 self.beta2 = beta2 self.eps = 1e-8 self.t = 0 self.m = {k: np.zeros_like(v) for k, v in params.items()} self.v = {k: np.zeros_like(v) for k, v in params.items()} def step(self, params: Dict[str, np.ndarray], grads: Dict[str, np.ndarray], clip_norm: float = 1.0) -> None: total_norm_sq = 0.0 for grad in grads.values(): total_norm_sq += float(np.sum(grad * grad)) total_norm = math.sqrt(total_norm_sq) scale = 1.0 if total_norm > clip_norm: scale = clip_norm / (total_norm + 1e-8) self.t += 1 for name, param in params.items(): grad = grads[name] * scale self.m[name] = self.beta1 * self.m[name] + (1.0 - self.beta1) * grad self.v[name] = self.beta2 * self.v[name] + (1.0 - self.beta2) * (grad * grad) m_hat = self.m[name] / (1.0 - self.beta1 ** self.t) v_hat = self.v[name] / (1.0 - self.beta2 ** self.t) params[name] -= self.lr * m_hat / (np.sqrt(v_hat) + self.eps) class PolicyValueNetwork: def __init__(self, input_dim: int, action_dim: int, rng: np.random.Generator, hidden_sizes: Sequence[int] = (256, 128)): h1, h2 = hidden_sizes self.params: Dict[str, np.ndarray] = { "w1": self._init_weight(rng, input_dim, h1), "b1": np.zeros(h1, dtype=np.float32), "w2": self._init_weight(rng, h1, h2), "b2": np.zeros(h2, dtype=np.float32), "wp": self._init_weight(rng, h2, action_dim), "bp": np.zeros(action_dim, dtype=np.float32), "wv": self._init_weight(rng, h2, 1), "bv": np.zeros(1, dtype=np.float32), } self.optimizer = AdamOptimizer(self.params) @staticmethod def _init_weight(rng: np.random.Generator, in_dim: int, out_dim: int) -> np.ndarray: scale = math.sqrt(2.0 / max(1, in_dim + out_dim)) return (rng.standard_normal((in_dim, out_dim)) * scale).astype(np.float32) def forward(self, x: np.ndarray) -> tuple[np.ndarray, np.ndarray, Dict[str, np.ndarray]]: z1 = x @ self.params["w1"] + self.params["b1"] h1 = np.tanh(z1) z2 = h1 @ self.params["w2"] + self.params["b2"] h2 = np.tanh(z2) logits = h2 @ self.params["wp"] + self.params["bp"] values = (h2 @ self.params["wv"] + self.params["bv"]).reshape(-1) cache = {"x": x, "h1": h1, "h2": h2} return logits, values, cache def predict(self, x: np.ndarray, mask: np.ndarray) -> tuple[np.ndarray, float]: logits, values, _ = self.forward(x[None, :]) probs = softmax_with_mask(logits, mask[None, :])[0] return probs, float(values[0]) def update( self, states: np.ndarray, masks: np.ndarray, actions: np.ndarray, returns: np.ndarray, advantages: np.ndarray, value_coef: float = 0.5, ) -> Dict[str, float]: logits, values, cache = self.forward(states) probs = softmax_with_mask(logits, masks) batch_size = max(1, states.shape[0]) grad_logits = probs.copy() grad_logits[np.arange(batch_size), actions] -= 1.0 grad_logits *= advantages[:, None] / batch_size grad_logits *= masks grad_values = ((values - returns)[:, None] * value_coef) / batch_size grads: Dict[str, np.ndarray] = {} grads["wp"] = cache["h2"].T @ grad_logits grads["bp"] = np.sum(grad_logits, axis=0) grads["wv"] = cache["h2"].T @ grad_values grads["bv"] = np.sum(grad_values, axis=0) dh2 = grad_logits @ self.params["wp"].T + grad_values @ self.params["wv"].T dz2 = dh2 * (1.0 - cache["h2"] ** 2) grads["w2"] = cache["h1"].T @ dz2 grads["b2"] = np.sum(dz2, axis=0) dh1 = dz2 @ self.params["w2"].T dz1 = dh1 * (1.0 - cache["h1"] ** 2) grads["w1"] = cache["x"].T @ dz1 grads["b1"] = np.sum(dz1, axis=0) self.optimizer.step(self.params, grads, clip_norm=1.0) chosen_probs = np.clip(probs[np.arange(batch_size), actions], 1e-8, 1.0) policy_loss = float(-np.mean(advantages * np.log(chosen_probs))) value_loss = float(0.5 * np.mean((values - returns) ** 2)) entropy = float(-np.mean(np.sum(np.where(probs > 0.0, probs * np.log(np.clip(probs, 1e-8, 1.0)), 0.0), axis=1))) return { "policy_loss": policy_loss, "value_loss": value_loss, "entropy": entropy, "mean_value": float(np.mean(values)), } def save(self, path: Path, metadata: Dict[str, object]) -> None: path.parent.mkdir(parents=True, exist_ok=True) arrays = {name: value for name, value in self.params.items()} arrays["metadata_json"] = np.array(json.dumps(metadata)) np.savez(path, **arrays) @dataclass class Trajectory: states: List[np.ndarray] masks: List[np.ndarray] actions: List[int] rewards: List[float] values: List[float] evacuated: bool final_health: float steps: int total_reward: float def compute_gae( rewards: Sequence[float], values: Sequence[float], gamma: float, gae_lambda: float, ) -> tuple[np.ndarray, np.ndarray]: rewards_arr = np.asarray(rewards, dtype=np.float32) values_arr = np.asarray(values, dtype=np.float32) advantages = np.zeros(len(rewards_arr), dtype=np.float32) gae = 0.0 next_value = 0.0 for i in range(len(rewards_arr) - 1, -1, -1): delta = rewards_arr[i] + gamma * next_value - values_arr[i] gae = delta + gamma * gae_lambda * gae advantages[i] = gae next_value = values_arr[i] returns = advantages + values_arr return returns.astype(np.float32), advantages.astype(np.float32) def select_action( network: PolicyValueNetwork, state_vec: np.ndarray, mask: np.ndarray, rng: np.random.Generator, greedy: bool = False, ) -> tuple[int, float]: probs, value = network.predict(state_vec, mask) valid_indices = np.flatnonzero(mask > 0.0) if len(valid_indices) == 0: return ACTION_TO_INDEX[WAIT_KEY], value if greedy: best_local = int(np.argmax(probs[valid_indices])) return int(valid_indices[best_local]), value return int(rng.choice(np.arange(len(probs)), p=probs)), value def build_stacked_state(frames: deque[np.ndarray]) -> np.ndarray: return np.concatenate(list(frames), dtype=np.float32) def run_episode( env: PyreEnvironment, network: PolicyValueNetwork, encoder: ObservationEncoder, rng: np.random.Generator, difficulty: str, history_length: int, greedy: bool = False, ) -> Trajectory: observation = env.reset(difficulty=difficulty) zero_frame = np.zeros(encoder.base_dim, dtype=np.float32) frames: deque[np.ndarray] = deque([zero_frame.copy() for _ in range(history_length)], maxlen=history_length) frames.append(encoder.encode(observation)) states: List[np.ndarray] = [] masks: List[np.ndarray] = [] actions: List[int] = [] rewards: List[float] = [] values: List[float] = [] total_reward = 0.0 final_health = observation.agent_health evacuated = False steps = 0 while True: state_vec = build_stacked_state(frames) mask = build_action_mask(observation) action_idx, value = select_action(network, state_vec, mask, rng, greedy=greedy) action = action_index_to_env_action(action_idx) next_obs = env.step(action) reward = float(next_obs.reward or 0.0) states.append(state_vec) masks.append(mask) actions.append(action_idx) rewards.append(reward) values.append(value) total_reward += reward steps += 1 final_health = next_obs.agent_health evacuated = next_obs.agent_evacuated frames.append(encoder.encode(next_obs)) observation = next_obs if next_obs.done: break return Trajectory( states=states, masks=masks, actions=actions, rewards=rewards, values=values, evacuated=evacuated, final_health=final_health, steps=steps, total_reward=total_reward, ) def evaluate_policy( env: PyreEnvironment, network: PolicyValueNetwork, encoder: ObservationEncoder, rng: np.random.Generator, difficulty: str, history_length: int, episodes: int, ) -> Dict[str, float]: rewards = [] evacuations = 0 lengths = [] for _ in range(episodes): traj = run_episode(env, network, encoder, rng, difficulty, history_length, greedy=True) rewards.append(traj.total_reward) lengths.append(traj.steps) evacuations += int(traj.evacuated) return { "eval_reward_mean": float(np.mean(rewards)) if rewards else 0.0, "eval_reward_max": float(np.max(rewards)) if rewards else 0.0, "eval_success_rate": float(evacuations / max(1, episodes)), "eval_steps_mean": float(np.mean(lengths)) if lengths else 0.0, } def expand_difficulty_schedule(schedule_text: str, episodes: int) -> List[str]: stages = [part.strip().lower() for part in schedule_text.split(",") if part.strip()] if not stages: stages = ["medium"] for stage in stages: if stage not in DIFFICULTIES: raise ValueError(f"Invalid difficulty in schedule: {stage}") segment = max(1, episodes // len(stages)) expanded: List[str] = [] for stage in stages: expanded.extend([stage] * segment) while len(expanded) < episodes: expanded.append(stages[-1]) return expanded[:episodes] def describe_environment_contract(encoder: ObservationEncoder, history_length: int) -> str: action_text = ( f"Action space has {ACTION_DIM} fixed discrete actions: 4 moves, 4 looks, wait, " f"{MAX_DOORS} door-open slots, and {MAX_DOORS} door-close slots. " "A per-step mask from `available_actions_hint` prevents invalid actions." ) reward_text = ( "Reward comes directly from the environment's composite rubric: time penalty, exit progress, " "progress regression penalty, safe-progress bonus, danger penalty, health-drain penalty, " "strategic door bonus, exploration bonus, plus terminal evacuation/death/timeout/near-miss/time bonuses." ) return "\n".join( [ "Pyre RL contract", encoder.describe(history_length), action_text, reward_text, ] ) def _moving_average(values: Sequence[float], window: int) -> List[float]: if not values: return [] out: List[float] = [] run = 0.0 q: deque[float] = deque() for value in values: q.append(float(value)) run += float(value) if len(q) > window: run -= q.popleft() out.append(run / len(q)) return out def save_metrics_csv(path: Path, rows: List[Dict[str, float | int | str]]) -> None: path.parent.mkdir(parents=True, exist_ok=True) if not rows: return with path.open("w", newline="", encoding="utf-8") as f: writer = csv.DictWriter(f, fieldnames=list(rows[0].keys())) writer.writeheader() writer.writerows(rows) def save_training_graph(path: Path, episode_rows: List[Dict[str, float | int | str]], eval_rows: List[Dict[str, float | int | str]]) -> None: path.parent.mkdir(parents=True, exist_ok=True) if not episode_rows: return width = 1200 height = 720 margin_left = 80 margin_right = 40 margin_top = 50 margin_bottom = 60 plot_w = width - margin_left - margin_right plot_h = height - margin_top - margin_bottom episodes = [int(r["episode"]) for r in episode_rows] rewards = [float(r["reward"]) for r in episode_rows] reward_ma = _moving_average(rewards, 20) success_ma = _moving_average([float(r["evacuated"]) for r in episode_rows], 20) all_reward_values = rewards + reward_ma + [float(r["reward_mean"]) for r in eval_rows] + [float(r["reward_max"]) for r in eval_rows] y_min = min(all_reward_values) if all_reward_values else -1.0 y_max = max(all_reward_values) if all_reward_values else 1.0 if abs(y_max - y_min) < 1e-6: y_min -= 1.0 y_max += 1.0 y_pad = 0.1 * (y_max - y_min) y_min -= y_pad y_max += y_pad max_episode = max(episodes) if episodes else 1 def x_pos(ep: float) -> float: return margin_left + (float(ep) - 1.0) / max(1.0, max_episode - 1.0) * plot_w def y_pos_reward(value: float) -> float: return margin_top + (y_max - float(value)) / max(1e-6, (y_max - y_min)) * plot_h def y_pos_success(value: float) -> float: return margin_top + (1.0 - float(value)) * plot_h def polyline(points: List[tuple[float, float]]) -> str: return " ".join(f"{x:.1f},{y:.1f}" for x, y in points) reward_points = [(x_pos(ep), y_pos_reward(val)) for ep, val in zip(episodes, rewards)] reward_ma_points = [(x_pos(ep), y_pos_reward(val)) for ep, val in zip(episodes, reward_ma)] success_points = [(x_pos(ep), y_pos_success(val)) for ep, val in zip(episodes, success_ma)] eval_points = [(x_pos(float(r["episode"])), y_pos_success(float(r["success_rate"]))) for r in eval_rows] episode_ticks = [1, max_episode // 4, max_episode // 2, (3 * max_episode) // 4, max_episode] episode_ticks = sorted(set(t for t in episode_ticks if t >= 1)) reward_ticks = [y_min + (y_max - y_min) * i / 4.0 for i in range(5)] success_ticks = [0.0, 0.25, 0.5, 0.75, 1.0] svg = [] svg.append(f'') svg.append('') svg.append('Pyre RL Training') svg.append('Reward curves on left axis, success rate on right axis') svg.append(f'') for tick in episode_ticks: x = x_pos(float(tick)) svg.append(f'') svg.append(f'{tick}') for tick in reward_ticks: y = y_pos_reward(tick) svg.append(f'') svg.append(f'{tick:.1f}') for tick in success_ticks: y = y_pos_success(tick) svg.append(f'{tick:.2f}') svg.append(f'') svg.append(f'') svg.append(f'') for x, y in eval_points: svg.append(f'') legend_y = height - 18 svg.append(f'') svg.append(f'Reward moving average') svg.append(f'') svg.append(f'Success moving average') svg.append(f'') svg.append(f'Episode reward') svg.append(f'') svg.append(f'Eval success checkpoints') svg.append("") path.write_text("\n".join(svg), encoding="utf-8") def train(args: argparse.Namespace) -> None: rng = np.random.default_rng(args.seed) encoder = ObservationEncoder(mode=args.observation_mode) difficulty_schedule = expand_difficulty_schedule(args.difficulty_schedule, args.episodes) input_dim = encoder.base_dim * args.history_length network = PolicyValueNetwork(input_dim=input_dim, action_dim=ACTION_DIM, rng=rng) env = PyreEnvironment(max_steps=args.max_steps) print(describe_environment_contract(encoder, args.history_length)) print("") batch_states: List[np.ndarray] = [] batch_masks: List[np.ndarray] = [] batch_actions: List[int] = [] batch_returns: List[np.ndarray] = [] batch_advantages: List[np.ndarray] = [] reward_window: deque[float] = deque(maxlen=20) success_window: deque[float] = deque(maxlen=20) episode_metrics: List[Dict[str, float | int | str]] = [] eval_metrics_rows: List[Dict[str, float | int | str]] = [] for episode_idx in range(args.episodes): difficulty = difficulty_schedule[episode_idx] if args.difficulty_schedule else args.difficulty traj = run_episode( env=env, network=network, encoder=encoder, rng=rng, difficulty=difficulty, history_length=args.history_length, greedy=False, ) returns, advantages = compute_gae(traj.rewards, traj.values, args.gamma, args.gae_lambda) batch_states.extend(traj.states) batch_masks.extend(traj.masks) batch_actions.extend(traj.actions) batch_returns.append(returns) batch_advantages.append(advantages) reward_window.append(traj.total_reward) success_window.append(float(traj.evacuated)) episode_metrics.append( { "episode": episode_idx + 1, "difficulty": difficulty, "reward": round(traj.total_reward, 4), "evacuated": int(traj.evacuated), "steps": traj.steps, "final_health": round(traj.final_health, 2), "reward_mean_20": round(float(np.mean(reward_window)), 4), "success_rate_20": round(float(np.mean(success_window)), 4), } ) print( f"episode={episode_idx + 1:04d} difficulty={difficulty:<6} " f"steps={traj.steps:03d} reward={traj.total_reward:+8.3f} " f"evacuated={int(traj.evacuated)} health={traj.final_health:6.1f}" ) should_update = (episode_idx + 1) % args.update_every == 0 or (episode_idx + 1) == args.episodes if should_update and batch_states: states_arr = np.asarray(batch_states, dtype=np.float32) masks_arr = np.asarray(batch_masks, dtype=np.float32) actions_arr = np.asarray(batch_actions, dtype=np.int64) returns_arr = np.concatenate(batch_returns).astype(np.float32) advantages_arr = np.concatenate(batch_advantages).astype(np.float32) advantages_arr = (advantages_arr - advantages_arr.mean()) / (advantages_arr.std() + 1e-8) network.optimizer.lr = args.learning_rate metrics = {} for _ in range(args.update_epochs): order = rng.permutation(len(states_arr)) for start in range(0, len(states_arr), args.minibatch_size): idx = order[start:start + args.minibatch_size] metrics = network.update( states=states_arr[idx], masks=masks_arr[idx], actions=actions_arr[idx], returns=returns_arr[idx], advantages=advantages_arr[idx], value_coef=args.value_coef, ) print( f"update episodes={episode_idx + 1:04d} samples={len(states_arr):05d} " f"reward_mean20={np.mean(reward_window):+8.3f} success20={np.mean(success_window):.2f} " f"policy_loss={metrics['policy_loss']:+.4f} value_loss={metrics['value_loss']:.4f} " f"entropy={metrics['entropy']:.4f}" ) batch_states.clear() batch_masks.clear() batch_actions.clear() batch_returns.clear() batch_advantages.clear() should_eval = args.eval_every > 0 and ((episode_idx + 1) % args.eval_every == 0 or (episode_idx + 1) == args.episodes) if should_eval: eval_metrics = evaluate_policy( env=env, network=network, encoder=encoder, rng=rng, difficulty=args.eval_difficulty, history_length=args.history_length, episodes=args.eval_episodes, ) print( f"eval episodes={episode_idx + 1:04d} difficulty={args.eval_difficulty:<6} " f"reward_mean={eval_metrics['eval_reward_mean']:+8.3f} " f"reward_max={eval_metrics['eval_reward_max']:+8.3f} " f"success={eval_metrics['eval_success_rate']:.2f} " f"steps={eval_metrics['eval_steps_mean']:.1f}" ) eval_metrics_rows.append( { "episode": episode_idx + 1, "difficulty": args.eval_difficulty, "reward_mean": round(eval_metrics["eval_reward_mean"], 4), "reward_max": round(eval_metrics["eval_reward_max"], 4), "success_rate": round(eval_metrics["eval_success_rate"], 4), "steps_mean": round(eval_metrics["eval_steps_mean"], 4), } ) if args.output: output_path = Path(args.output) network.save( output_path, metadata={ "observation_mode": args.observation_mode, "history_length": args.history_length, "episodes": args.episodes, "difficulty": args.difficulty, "difficulty_schedule": args.difficulty_schedule, "gamma": args.gamma, "gae_lambda": args.gae_lambda, "learning_rate": args.learning_rate, "update_epochs": args.update_epochs, "minibatch_size": args.minibatch_size, "action_dim": ACTION_DIM, "input_dim": input_dim, }, ) print(f"saved model={output_path}") if args.save_metrics: metrics_path = output_path.with_suffix(".csv") save_metrics_csv(metrics_path, episode_metrics) print(f"saved metrics={metrics_path}") if args.save_graph: graph_path = output_path.with_suffix(".svg") save_training_graph(graph_path, episode_metrics, eval_metrics_rows) print(f"saved graph={graph_path}") def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Train a NumPy actor-critic baseline for Pyre.") parser.add_argument("--episodes", type=int, default=120, help="Training episodes.") parser.add_argument("--difficulty", type=str, default="easy", choices=DIFFICULTIES) parser.add_argument( "--difficulty-schedule", type=str, default="easy,medium", help="Comma-separated curriculum, expanded evenly across episodes.", ) parser.add_argument("--eval-difficulty", type=str, default="medium", choices=DIFFICULTIES) parser.add_argument("--eval-episodes", type=int, default=5) parser.add_argument("--eval-every", type=int, default=20) parser.add_argument("--update-every", type=int, default=5, help="Episodes per policy update.") parser.add_argument("--update-epochs", type=int, default=3, help="Gradient passes over each on-policy batch.") parser.add_argument("--minibatch-size", type=int, default=256, help="Samples per gradient step.") parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--gae-lambda", type=float, default=0.95) parser.add_argument("--learning-rate", type=float, default=3e-4) parser.add_argument("--value-coef", type=float, default=0.5) parser.add_argument("--history-length", type=int, default=4) parser.add_argument("--max-steps", type=int, default=150) parser.add_argument("--seed", type=int, default=7) parser.add_argument("--observation-mode", type=str, default="visible", choices=("visible", "full")) parser.add_argument("--output", type=str, default="artifacts/pyre_actor_critic.npz") parser.add_argument("--save-metrics", action="store_true", help="Save per-episode metrics as CSV beside the model.") parser.add_argument("--save-graph", action="store_true", help="Save an SVG training graph beside the model.") parser.add_argument("--describe-only", action="store_true", help="Print observation/action/reward definitions and exit.") return parser.parse_args() def main() -> None: args = parse_args() encoder = ObservationEncoder(mode=args.observation_mode) if args.describe_only: print(describe_environment_contract(encoder, args.history_length)) return train(args) if __name__ == "__main__": main()