| """Train a deep RL baseline directly against the local Pyre environment. | |
| This script makes the environment contract explicit: | |
| - Observation: encoded from `PyreObservation.map_state` into a fixed-length vector | |
| - Action: fixed discrete action table with a runtime validity mask from `available_actions_hint` | |
| - Reward: the environment's composite reward returned by `PyreEnvironment.step()` | |
| It uses a self-contained NumPy actor-critic implementation so it can run in | |
| this repository without external ML dependencies. | |
| Examples: | |
| python examples/train_rl_agent.py --episodes 150 --difficulty easy | |
| python examples/train_rl_agent.py --episodes 300 --difficulty-schedule easy,medium | |
| python examples/train_rl_agent.py --episodes 200 --difficulty medium --observation-mode full | |
| python examples/train_rl_agent.py --describe-only | |
| """ | |
from __future__ import annotations

import argparse
import csv
import json
import math
import re
from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Sequence

import numpy as np

from pyre_env.models import PyreAction, PyreObservation
from pyre_env.server.pyre_env_environment import PyreEnvironment

MAX_GRID_W = 24
MAX_GRID_H = 24
MAX_DOORS = 16
DIRECTIONS = ("north", "south", "west", "east")
WINDS = ("CALM", "NORTH", "SOUTH", "WEST", "EAST")
DIFFICULTIES = ("easy", "medium", "hard")

MOVE_KEYS = [f"move(direction='{d}')" for d in DIRECTIONS]
LOOK_KEYS = [f"look(direction='{d}')" for d in DIRECTIONS]
WAIT_KEY = "wait()"
OPEN_KEYS = [f"door(target_id='door_{i}', door_state='open')" for i in range(1, MAX_DOORS + 1)]
CLOSE_KEYS = [f"door(target_id='door_{i}', door_state='close')" for i in range(1, MAX_DOORS + 1)]
ACTION_KEYS = MOVE_KEYS + LOOK_KEYS + [WAIT_KEY] + OPEN_KEYS + CLOSE_KEYS
ACTION_DIM = len(ACTION_KEYS)
ACTION_TO_INDEX = {key: idx for idx, key in enumerate(ACTION_KEYS)}

_MOVE_RE = re.compile(r"move\(direction='(north|south|west|east)'\)")
_LOOK_RE = re.compile(r"look\(direction='(north|south|west|east)'\)")
_DOOR_RE = re.compile(r"door\(target_id='(door_(\d+))', door_state='(open|close)'\)")
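
# Action-table layout, derived from ACTION_KEYS above: indices 0-3 are moves,
# 4-7 are looks, 8 is wait, 9-24 are door-open slots, and 25-40 are door-close
# slots, giving ACTION_DIM == 41 with MAX_DOORS == 16. For example:
#
#     ACTION_TO_INDEX["move(direction='north')"]                        # -> 0
#     ACTION_TO_INDEX["door(target_id='door_1', door_state='open')"]    # -> 9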


def _one_hot(index: int, size: int) -> np.ndarray:
    arr = np.zeros(size, dtype=np.float32)
    if 0 <= index < size:
        arr[index] = 1.0
    return arr


def action_index_to_env_action(index: int) -> PyreAction:
    if 0 <= index < 4:
        return PyreAction(action="move", direction=DIRECTIONS[index])
    if 4 <= index < 8:
        return PyreAction(action="look", direction=DIRECTIONS[index - 4])
    if index == 8:
        return PyreAction(action="wait")
    if 9 <= index < 9 + MAX_DOORS:
        door_id = f"door_{index - 8}"
        return PyreAction(action="door", target_id=door_id, door_state="open")
    door_slot = index - (9 + MAX_DOORS)
    door_id = f"door_{door_slot + 1}"
    return PyreAction(action="door", target_id=door_id, door_state="close")


def build_action_mask(observation: PyreObservation) -> np.ndarray:
    mask = np.zeros(ACTION_DIM, dtype=np.float32)
    for hint in observation.available_actions_hint:
        idx = ACTION_TO_INDEX.get(hint)
        if idx is not None:
            mask[idx] = 1.0
            continue
        match = _MOVE_RE.fullmatch(hint)
        if match:
            mask[ACTION_TO_INDEX[f"move(direction='{match.group(1)}')"]] = 1.0
            continue
        match = _LOOK_RE.fullmatch(hint)
        if match:
            mask[ACTION_TO_INDEX[f"look(direction='{match.group(1)}')"]] = 1.0
            continue
        match = _DOOR_RE.fullmatch(hint)
        if match:
            door_id = match.group(1)
            door_num = int(match.group(2))
            state = match.group(3)
            if 1 <= door_num <= MAX_DOORS:
                mask[ACTION_TO_INDEX[f"door(target_id='{door_id}', door_state='{state}')"]] = 1.0
    if mask.sum() == 0:
        mask[ACTION_TO_INDEX[WAIT_KEY]] = 1.0
    return mask
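
# The mask is indexed the same way as ACTION_KEYS, so callers can recover the
# currently legal action indices directly, for example:
#
#     mask = build_action_mask(observation)
#     valid = np.flatnonzero(mask > 0.0)   # indices of legal actions this step
#
# When no hint matches the fixed table, only wait() stays enabled.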


class ObservationEncoder:
    """Encode Pyre observations into a fixed-size float vector."""

    def __init__(self, mode: str = "visible"):
        if mode not in {"visible", "full"}:
            raise ValueError(f"Unsupported observation mode: {mode}")
        self.mode = mode
        self.base_dim = MAX_GRID_W * MAX_GRID_H * 10 + 22

    def encode(self, observation: PyreObservation) -> np.ndarray:
        map_state = observation.map_state
        if map_state is None:
            raise ValueError("PyreObservation.map_state is required for RL training.")
        cell_one_hot = np.zeros((MAX_GRID_H, MAX_GRID_W, 6), dtype=np.float32)
        fire_channel = np.zeros((MAX_GRID_H, MAX_GRID_W), dtype=np.float32)
        smoke_channel = np.zeros((MAX_GRID_H, MAX_GRID_W), dtype=np.float32)
        visible_channel = np.zeros((MAX_GRID_H, MAX_GRID_W), dtype=np.float32)
        agent_channel = np.zeros((MAX_GRID_H, MAX_GRID_W), dtype=np.float32)
        visible = {(x, y) for x, y in map_state.visible_cells}
        for y in range(map_state.grid_h):
            for x in range(map_state.grid_w):
                if self.mode == "visible" and (x, y) not in visible and (x, y) != (map_state.agent_x, map_state.agent_y):
                    continue
                i = y * map_state.grid_w + x
                cell_type = int(map_state.cell_grid[i])
                if 0 <= cell_type <= 5:
                    cell_one_hot[y, x, cell_type] = 1.0
                fire_channel[y, x] = float(map_state.fire_grid[i])
                smoke_channel[y, x] = float(map_state.smoke_grid[i])
                visible_channel[y, x] = 1.0 if (x, y) in visible else 0.0
        if 0 <= map_state.agent_x < MAX_GRID_W and 0 <= map_state.agent_y < MAX_GRID_H:
            agent_channel[map_state.agent_y, map_state.agent_x] = 1.0
        grid_features = np.concatenate(
            [
                cell_one_hot.reshape(-1),
                fire_channel.reshape(-1),
                smoke_channel.reshape(-1),
                visible_channel.reshape(-1),
                agent_channel.reshape(-1),
            ]
        )
        metadata = observation.metadata or {}
        wind_dir = str(metadata.get("wind_dir", map_state.wind_dir or "CALM")).upper()
        difficulty = str(metadata.get("difficulty", "medium")).lower()
        wind_index = WINDS.index(wind_dir) if wind_dir in WINDS else 0
        difficulty_index = DIFFICULTIES.index(difficulty) if difficulty in DIFFICULTIES else 1
        global_features = np.concatenate(
            [
                np.array(
                    [
                        float(observation.agent_health) / 100.0,
                        float(map_state.agent_health) / 100.0,
                        float(map_state.step_count) / max(1, map_state.max_steps),
                        float(map_state.fire_spread_rate),
                        float(map_state.humidity),
                        float(map_state.agent_x) / max(1, map_state.grid_w - 1),
                        float(map_state.agent_y) / max(1, map_state.grid_h - 1),
                        float(metadata.get("nearest_exit_distance", MAX_GRID_W + MAX_GRID_H) or 0.0) / float(MAX_GRID_W + MAX_GRID_H),
                        float(metadata.get("reachable_exit_count", 0.0)) / 4.0,
                        float(metadata.get("visible_cell_count", 0.0)) / float(MAX_GRID_W * MAX_GRID_H),
                        float(metadata.get("fire_sources", 0.0)) / 5.0,
                        {"none": 0.0, "light": 0.33, "moderate": 0.66, "heavy": 1.0}.get(observation.smoke_level, 0.0),
                        1.0 if map_state.agent_alive else 0.0,
                        1.0 if map_state.agent_evacuated else 0.0,
                    ],
                    dtype=np.float32,
                ),
                _one_hot(wind_index, len(WINDS)),
                _one_hot(difficulty_index, len(DIFFICULTIES)),
            ]
        )
        return np.concatenate([grid_features, global_features]).astype(np.float32)

    def describe(self, history_length: int) -> str:
        grid_text = (
            f"Observation mode `{self.mode}` encodes a {MAX_GRID_W}x{MAX_GRID_H} padded map with "
            "10 channels per cell: 6-way cell type one-hot, fire intensity, smoke intensity, visible mask, and agent mask."
        )
        if self.mode == "visible":
            visibility_text = "Only currently visible cells are populated; unseen cells stay zeroed."
        else:
            visibility_text = "The full ground-truth map is exposed for curriculum/debug use."
        return (
            f"{grid_text} {visibility_text} "
            f"Global features add health, step progress, fire parameters, position, exit-distance metadata, smoke severity, wind, and difficulty. "
            f"{history_length} encoded frames are stacked, so the network input dimension is {self.base_dim * history_length}."
        )


def softmax_with_mask(logits: np.ndarray, mask: np.ndarray) -> np.ndarray:
    masked_logits = np.where(mask > 0.0, logits, -1e9)
    max_logits = np.max(masked_logits, axis=1, keepdims=True)
    exps = np.exp(masked_logits - max_logits) * mask
    denom = np.sum(exps, axis=1, keepdims=True)
    denom = np.where(denom <= 0.0, 1.0, denom)
    return exps / denom
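
# Example (illustrative): with logits [[1.0, 2.0, 3.0]] and mask [[1, 0, 1]],
# the masked middle action gets probability 0 and the other two probabilities
# form a softmax over their own logits, summing to 1. An all-zero mask row
# returns all zeros instead of dividing by zero; callers fall back to wait()
# in that case.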


class AdamOptimizer:
    def __init__(self, params: Dict[str, np.ndarray], lr: float = 3e-4, beta1: float = 0.9, beta2: float = 0.999):
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = 1e-8
        self.t = 0
        self.m = {k: np.zeros_like(v) for k, v in params.items()}
        self.v = {k: np.zeros_like(v) for k, v in params.items()}

    def step(self, params: Dict[str, np.ndarray], grads: Dict[str, np.ndarray], clip_norm: float = 1.0) -> None:
        total_norm_sq = 0.0
        for grad in grads.values():
            total_norm_sq += float(np.sum(grad * grad))
        total_norm = math.sqrt(total_norm_sq)
        scale = 1.0
        if total_norm > clip_norm:
            scale = clip_norm / (total_norm + 1e-8)
        self.t += 1
        for name, param in params.items():
            grad = grads[name] * scale
            self.m[name] = self.beta1 * self.m[name] + (1.0 - self.beta1) * grad
            self.v[name] = self.beta2 * self.v[name] + (1.0 - self.beta2) * (grad * grad)
            m_hat = self.m[name] / (1.0 - self.beta1 ** self.t)
            v_hat = self.v[name] / (1.0 - self.beta2 ** self.t)
            params[name] -= self.lr * m_hat / (np.sqrt(v_hat) + self.eps)


class PolicyValueNetwork:
    def __init__(self, input_dim: int, action_dim: int, rng: np.random.Generator, hidden_sizes: Sequence[int] = (256, 128)):
        h1, h2 = hidden_sizes
        self.params: Dict[str, np.ndarray] = {
            "w1": self._init_weight(rng, input_dim, h1),
            "b1": np.zeros(h1, dtype=np.float32),
            "w2": self._init_weight(rng, h1, h2),
            "b2": np.zeros(h2, dtype=np.float32),
            "wp": self._init_weight(rng, h2, action_dim),
            "bp": np.zeros(action_dim, dtype=np.float32),
            "wv": self._init_weight(rng, h2, 1),
            "bv": np.zeros(1, dtype=np.float32),
        }
        self.optimizer = AdamOptimizer(self.params)

    @staticmethod
    def _init_weight(rng: np.random.Generator, in_dim: int, out_dim: int) -> np.ndarray:
        # Xavier/Glorot-style scaling keeps the tanh layers well-conditioned at init.
        scale = math.sqrt(2.0 / max(1, in_dim + out_dim))
        return (rng.standard_normal((in_dim, out_dim)) * scale).astype(np.float32)

    def forward(self, x: np.ndarray) -> tuple[np.ndarray, np.ndarray, Dict[str, np.ndarray]]:
        z1 = x @ self.params["w1"] + self.params["b1"]
        h1 = np.tanh(z1)
        z2 = h1 @ self.params["w2"] + self.params["b2"]
        h2 = np.tanh(z2)
        logits = h2 @ self.params["wp"] + self.params["bp"]
        values = (h2 @ self.params["wv"] + self.params["bv"]).reshape(-1)
        cache = {"x": x, "h1": h1, "h2": h2}
        return logits, values, cache

    def predict(self, x: np.ndarray, mask: np.ndarray) -> tuple[np.ndarray, float]:
        logits, values, _ = self.forward(x[None, :])
        probs = softmax_with_mask(logits, mask[None, :])[0]
        return probs, float(values[0])

    def update(
        self,
        states: np.ndarray,
        masks: np.ndarray,
        actions: np.ndarray,
        returns: np.ndarray,
        advantages: np.ndarray,
        value_coef: float = 0.5,
    ) -> Dict[str, float]:
        logits, values, cache = self.forward(states)
        probs = softmax_with_mask(logits, masks)
        batch_size = max(1, states.shape[0])
        grad_logits = probs.copy()
        grad_logits[np.arange(batch_size), actions] -= 1.0
        grad_logits *= advantages[:, None] / batch_size
        grad_logits *= masks
        grad_values = ((values - returns)[:, None] * value_coef) / batch_size
        grads: Dict[str, np.ndarray] = {}
        grads["wp"] = cache["h2"].T @ grad_logits
        grads["bp"] = np.sum(grad_logits, axis=0)
        grads["wv"] = cache["h2"].T @ grad_values
        grads["bv"] = np.sum(grad_values, axis=0)
        dh2 = grad_logits @ self.params["wp"].T + grad_values @ self.params["wv"].T
        dz2 = dh2 * (1.0 - cache["h2"] ** 2)
        grads["w2"] = cache["h1"].T @ dz2
        grads["b2"] = np.sum(dz2, axis=0)
        dh1 = dz2 @ self.params["w2"].T
        dz1 = dh1 * (1.0 - cache["h1"] ** 2)
        grads["w1"] = cache["x"].T @ dz1
        grads["b1"] = np.sum(dz1, axis=0)
        self.optimizer.step(self.params, grads, clip_norm=1.0)
        chosen_probs = np.clip(probs[np.arange(batch_size), actions], 1e-8, 1.0)
        policy_loss = float(-np.mean(advantages * np.log(chosen_probs)))
        value_loss = float(0.5 * np.mean((values - returns) ** 2))
        entropy = float(-np.mean(np.sum(np.where(probs > 0.0, probs * np.log(np.clip(probs, 1e-8, 1.0)), 0.0), axis=1)))
        return {
            "policy_loss": policy_loss,
            "value_loss": value_loss,
            "entropy": entropy,
            "mean_value": float(np.mean(values)),
        }

    def save(self, path: Path, metadata: Dict[str, object]) -> None:
        path.parent.mkdir(parents=True, exist_ok=True)
        arrays = {name: value for name, value in self.params.items()}
        arrays["metadata_json"] = np.array(json.dumps(metadata))
        np.savez(path, **arrays)
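
# There is no matching load() helper in this script; a minimal sketch for
# reloading a checkpoint written by save(), assuming the default hidden sizes
# and the .npz layout above, would be:
#
#     data = np.load("artifacts/pyre_actor_critic.npz")
#     metadata = json.loads(str(data["metadata_json"]))
#     net = PolicyValueNetwork(metadata["input_dim"], metadata["action_dim"], np.random.default_rng(0))
#     net.params = {k: data[k] for k in ("w1", "b1", "w2", "b2", "wp", "bp", "wv", "bv")}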


@dataclass
class Trajectory:
    states: List[np.ndarray]
    masks: List[np.ndarray]
    actions: List[int]
    rewards: List[float]
    values: List[float]
    evacuated: bool
    final_health: float
    steps: int
    total_reward: float


def compute_gae(
    rewards: Sequence[float],
    values: Sequence[float],
    gamma: float,
    gae_lambda: float,
) -> tuple[np.ndarray, np.ndarray]:
    rewards_arr = np.asarray(rewards, dtype=np.float32)
    values_arr = np.asarray(values, dtype=np.float32)
    advantages = np.zeros(len(rewards_arr), dtype=np.float32)
    gae = 0.0
    next_value = 0.0
    for i in range(len(rewards_arr) - 1, -1, -1):
        delta = rewards_arr[i] + gamma * next_value - values_arr[i]
        gae = delta + gamma * gae_lambda * gae
        advantages[i] = gae
        next_value = values_arr[i]
    returns = advantages + values_arr
    return returns.astype(np.float32), advantages.astype(np.float32)
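
# Worked example (illustrative): for a two-step episode with rewards [0.0, 1.0],
# values [0.5, 0.2], gamma=0.99, gae_lambda=0.95 and a terminal bootstrap of 0:
#
#     delta_1 = 1.0 + 0.99 * 0.0 - 0.2 = 0.8      -> advantage_1 = 0.8
#     delta_0 = 0.0 + 0.99 * 0.2 - 0.5 = -0.302   -> advantage_0 = -0.302 + 0.99 * 0.95 * 0.8 = 0.4504
#
# and returns = advantages + values.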


def select_action(
    network: PolicyValueNetwork,
    state_vec: np.ndarray,
    mask: np.ndarray,
    rng: np.random.Generator,
    greedy: bool = False,
) -> tuple[int, float]:
    probs, value = network.predict(state_vec, mask)
    valid_indices = np.flatnonzero(mask > 0.0)
    if len(valid_indices) == 0:
        return ACTION_TO_INDEX[WAIT_KEY], value
    if greedy:
        best_local = int(np.argmax(probs[valid_indices]))
        return int(valid_indices[best_local]), value
    return int(rng.choice(np.arange(len(probs)), p=probs)), value


def build_stacked_state(frames: deque[np.ndarray]) -> np.ndarray:
    return np.concatenate(list(frames), dtype=np.float32)


def run_episode(
    env: PyreEnvironment,
    network: PolicyValueNetwork,
    encoder: ObservationEncoder,
    rng: np.random.Generator,
    difficulty: str,
    history_length: int,
    greedy: bool = False,
) -> Trajectory:
    observation = env.reset(difficulty=difficulty)
    zero_frame = np.zeros(encoder.base_dim, dtype=np.float32)
    frames: deque[np.ndarray] = deque([zero_frame.copy() for _ in range(history_length)], maxlen=history_length)
    frames.append(encoder.encode(observation))
    states: List[np.ndarray] = []
    masks: List[np.ndarray] = []
    actions: List[int] = []
    rewards: List[float] = []
    values: List[float] = []
    total_reward = 0.0
    final_health = observation.agent_health
    evacuated = False
    steps = 0
    while True:
        state_vec = build_stacked_state(frames)
        mask = build_action_mask(observation)
        action_idx, value = select_action(network, state_vec, mask, rng, greedy=greedy)
        action = action_index_to_env_action(action_idx)
        next_obs = env.step(action)
        reward = float(next_obs.reward or 0.0)
        states.append(state_vec)
        masks.append(mask)
        actions.append(action_idx)
        rewards.append(reward)
        values.append(value)
        total_reward += reward
        steps += 1
        final_health = next_obs.agent_health
        evacuated = next_obs.agent_evacuated
        frames.append(encoder.encode(next_obs))
        observation = next_obs
        if next_obs.done:
            break
    return Trajectory(
        states=states,
        masks=masks,
        actions=actions,
        rewards=rewards,
        values=values,
        evacuated=evacuated,
        final_health=final_health,
        steps=steps,
        total_reward=total_reward,
    )


def evaluate_policy(
    env: PyreEnvironment,
    network: PolicyValueNetwork,
    encoder: ObservationEncoder,
    rng: np.random.Generator,
    difficulty: str,
    history_length: int,
    episodes: int,
) -> Dict[str, float]:
    rewards = []
    evacuations = 0
    lengths = []
    for _ in range(episodes):
        traj = run_episode(env, network, encoder, rng, difficulty, history_length, greedy=True)
        rewards.append(traj.total_reward)
        lengths.append(traj.steps)
        evacuations += int(traj.evacuated)
    return {
        "eval_reward_mean": float(np.mean(rewards)) if rewards else 0.0,
        "eval_reward_max": float(np.max(rewards)) if rewards else 0.0,
        "eval_success_rate": float(evacuations / max(1, episodes)),
        "eval_steps_mean": float(np.mean(lengths)) if lengths else 0.0,
    }


def expand_difficulty_schedule(schedule_text: str, episodes: int) -> List[str]:
    stages = [part.strip().lower() for part in schedule_text.split(",") if part.strip()]
    if not stages:
        stages = ["medium"]
    for stage in stages:
        if stage not in DIFFICULTIES:
            raise ValueError(f"Invalid difficulty in schedule: {stage}")
    segment = max(1, episodes // len(stages))
    expanded: List[str] = []
    for stage in stages:
        expanded.extend([stage] * segment)
    while len(expanded) < episodes:
        expanded.append(stages[-1])
    return expanded[:episodes]
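
# Example (illustrative): expand_difficulty_schedule("easy,medium", 5) returns
# ["easy", "easy", "medium", "medium", "medium"]: each stage gets
# episodes // len(stages) slots and the final stage absorbs the remainder.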


def describe_environment_contract(encoder: ObservationEncoder, history_length: int) -> str:
    action_text = (
        f"Action space has {ACTION_DIM} fixed discrete actions: 4 moves, 4 looks, wait, "
        f"{MAX_DOORS} door-open slots, and {MAX_DOORS} door-close slots. "
        "A per-step mask from `available_actions_hint` prevents invalid actions."
    )
    reward_text = (
        "Reward comes directly from the environment's composite rubric: time penalty, exit progress, "
        "progress regression penalty, safe-progress bonus, danger penalty, health-drain penalty, "
        "strategic door bonus, exploration bonus, plus terminal evacuation/death/timeout/near-miss/time bonuses."
    )
    return "\n".join(
        [
            "Pyre RL contract",
            encoder.describe(history_length),
            action_text,
            reward_text,
        ]
    )


def _moving_average(values: Sequence[float], window: int) -> List[float]:
    if not values:
        return []
    out: List[float] = []
    run = 0.0
    q: deque[float] = deque()
    for value in values:
        q.append(float(value))
        run += float(value)
        if len(q) > window:
            run -= q.popleft()
        out.append(run / len(q))
    return out


def save_metrics_csv(path: Path, rows: List[Dict[str, float | int | str]]) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    if not rows:
        return
    with path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)


def save_training_graph(path: Path, episode_rows: List[Dict[str, float | int | str]], eval_rows: List[Dict[str, float | int | str]]) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    if not episode_rows:
        return
    width = 1200
    height = 720
    margin_left = 80
    margin_right = 40
    margin_top = 50
    margin_bottom = 60
    plot_w = width - margin_left - margin_right
    plot_h = height - margin_top - margin_bottom
    episodes = [int(r["episode"]) for r in episode_rows]
    rewards = [float(r["reward"]) for r in episode_rows]
    reward_ma = _moving_average(rewards, 20)
    success_ma = _moving_average([float(r["evacuated"]) for r in episode_rows], 20)
    all_reward_values = rewards + reward_ma + [float(r["reward_mean"]) for r in eval_rows] + [float(r["reward_max"]) for r in eval_rows]
    y_min = min(all_reward_values) if all_reward_values else -1.0
    y_max = max(all_reward_values) if all_reward_values else 1.0
    if abs(y_max - y_min) < 1e-6:
        y_min -= 1.0
        y_max += 1.0
    y_pad = 0.1 * (y_max - y_min)
    y_min -= y_pad
    y_max += y_pad
    max_episode = max(episodes) if episodes else 1

    def x_pos(ep: float) -> float:
        return margin_left + (float(ep) - 1.0) / max(1.0, max_episode - 1.0) * plot_w

    def y_pos_reward(value: float) -> float:
        return margin_top + (y_max - float(value)) / max(1e-6, (y_max - y_min)) * plot_h

    def y_pos_success(value: float) -> float:
        return margin_top + (1.0 - float(value)) * plot_h

    def polyline(points: List[tuple[float, float]]) -> str:
        return " ".join(f"{x:.1f},{y:.1f}" for x, y in points)

    reward_points = [(x_pos(ep), y_pos_reward(val)) for ep, val in zip(episodes, rewards)]
    reward_ma_points = [(x_pos(ep), y_pos_reward(val)) for ep, val in zip(episodes, reward_ma)]
    success_points = [(x_pos(ep), y_pos_success(val)) for ep, val in zip(episodes, success_ma)]
    eval_points = [(x_pos(float(r["episode"])), y_pos_success(float(r["success_rate"]))) for r in eval_rows]
    episode_ticks = [1, max_episode // 4, max_episode // 2, (3 * max_episode) // 4, max_episode]
    episode_ticks = sorted(set(t for t in episode_ticks if t >= 1))
    reward_ticks = [y_min + (y_max - y_min) * i / 4.0 for i in range(5)]
    success_ticks = [0.0, 0.25, 0.5, 0.75, 1.0]
    svg = []
    svg.append(f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" viewBox="0 0 {width} {height}">')
    svg.append('<rect width="100%" height="100%" fill="#f7f5ef"/>')
    svg.append('<text x="80" y="28" font-family="Georgia, serif" font-size="24" fill="#1d2a38">Pyre RL Training</text>')
    svg.append('<text x="80" y="48" font-family="Georgia, serif" font-size="13" fill="#5b6770">Reward curves on left axis, success rate on right axis</text>')
    svg.append(f'<rect x="{margin_left}" y="{margin_top}" width="{plot_w}" height="{plot_h}" fill="#fffdf8" stroke="#d1c9b8"/>')
    for tick in episode_ticks:
        x = x_pos(float(tick))
        svg.append(f'<line x1="{x:.1f}" y1="{margin_top}" x2="{x:.1f}" y2="{margin_top + plot_h}" stroke="#ece7db" />')
        svg.append(f'<text x="{x:.1f}" y="{height - 24}" text-anchor="middle" font-family="Georgia, serif" font-size="12" fill="#5b6770">{tick}</text>')
    for tick in reward_ticks:
        y = y_pos_reward(tick)
        svg.append(f'<line x1="{margin_left}" y1="{y:.1f}" x2="{margin_left + plot_w}" y2="{y:.1f}" stroke="#ece7db" />')
        svg.append(f'<text x="{margin_left - 10}" y="{y + 4:.1f}" text-anchor="end" font-family="Georgia, serif" font-size="12" fill="#8a4b08">{tick:.1f}</text>')
    for tick in success_ticks:
        y = y_pos_success(tick)
        svg.append(f'<text x="{margin_left + plot_w + 10}" y="{y + 4:.1f}" font-family="Georgia, serif" font-size="12" fill="#0d5b6b">{tick:.2f}</text>')
    svg.append(f'<polyline fill="none" stroke="#c5bfb1" stroke-width="1.5" points="{polyline(reward_points)}"/>')
    svg.append(f'<polyline fill="none" stroke="#c1661c" stroke-width="3" points="{polyline(reward_ma_points)}"/>')
    svg.append(f'<polyline fill="none" stroke="#127a8a" stroke-width="3" points="{polyline(success_points)}"/>')
    for x, y in eval_points:
        svg.append(f'<circle cx="{x:.1f}" cy="{y:.1f}" r="4.5" fill="#0d5b6b" stroke="#ffffff" stroke-width="1.5"/>')
    legend_y = height - 18
    svg.append(f'<line x1="80" y1="{legend_y}" x2="110" y2="{legend_y}" stroke="#c1661c" stroke-width="3"/>')
    svg.append(f'<text x="118" y="{legend_y + 4}" font-family="Georgia, serif" font-size="12" fill="#1d2a38">Reward moving average</text>')
    svg.append(f'<line x1="300" y1="{legend_y}" x2="330" y2="{legend_y}" stroke="#127a8a" stroke-width="3"/>')
    svg.append(f'<text x="338" y="{legend_y + 4}" font-family="Georgia, serif" font-size="12" fill="#1d2a38">Success moving average</text>')
    svg.append(f'<line x1="510" y1="{legend_y}" x2="540" y2="{legend_y}" stroke="#c5bfb1" stroke-width="1.5"/>')
    svg.append(f'<text x="548" y="{legend_y + 4}" font-family="Georgia, serif" font-size="12" fill="#1d2a38">Episode reward</text>')
    svg.append(f'<circle cx="700" cy="{legend_y}" r="4.5" fill="#0d5b6b" stroke="#ffffff" stroke-width="1.5"/>')
    svg.append(f'<text x="712" y="{legend_y + 4}" font-family="Georgia, serif" font-size="12" fill="#1d2a38">Eval success checkpoints</text>')
    svg.append("</svg>")
    path.write_text("\n".join(svg), encoding="utf-8")


def train(args: argparse.Namespace) -> None:
    rng = np.random.default_rng(args.seed)
    encoder = ObservationEncoder(mode=args.observation_mode)
    difficulty_schedule = expand_difficulty_schedule(args.difficulty_schedule, args.episodes)
    input_dim = encoder.base_dim * args.history_length
    network = PolicyValueNetwork(input_dim=input_dim, action_dim=ACTION_DIM, rng=rng)
    env = PyreEnvironment(max_steps=args.max_steps)
    print(describe_environment_contract(encoder, args.history_length))
    print("")
    batch_states: List[np.ndarray] = []
    batch_masks: List[np.ndarray] = []
    batch_actions: List[int] = []
    batch_returns: List[np.ndarray] = []
    batch_advantages: List[np.ndarray] = []
    reward_window: deque[float] = deque(maxlen=20)
    success_window: deque[float] = deque(maxlen=20)
    episode_metrics: List[Dict[str, float | int | str]] = []
    eval_metrics_rows: List[Dict[str, float | int | str]] = []
    for episode_idx in range(args.episodes):
        difficulty = difficulty_schedule[episode_idx] if args.difficulty_schedule else args.difficulty
        traj = run_episode(
            env=env,
            network=network,
            encoder=encoder,
            rng=rng,
            difficulty=difficulty,
            history_length=args.history_length,
            greedy=False,
        )
        returns, advantages = compute_gae(traj.rewards, traj.values, args.gamma, args.gae_lambda)
        batch_states.extend(traj.states)
        batch_masks.extend(traj.masks)
        batch_actions.extend(traj.actions)
        batch_returns.append(returns)
        batch_advantages.append(advantages)
        reward_window.append(traj.total_reward)
        success_window.append(float(traj.evacuated))
        episode_metrics.append(
            {
                "episode": episode_idx + 1,
                "difficulty": difficulty,
                "reward": round(traj.total_reward, 4),
                "evacuated": int(traj.evacuated),
                "steps": traj.steps,
                "final_health": round(traj.final_health, 2),
                "reward_mean_20": round(float(np.mean(reward_window)), 4),
                "success_rate_20": round(float(np.mean(success_window)), 4),
            }
        )
        print(
            f"episode={episode_idx + 1:04d} difficulty={difficulty:<6} "
            f"steps={traj.steps:03d} reward={traj.total_reward:+8.3f} "
            f"evacuated={int(traj.evacuated)} health={traj.final_health:6.1f}"
        )
        should_update = (episode_idx + 1) % args.update_every == 0 or (episode_idx + 1) == args.episodes
        if should_update and batch_states:
            states_arr = np.asarray(batch_states, dtype=np.float32)
            masks_arr = np.asarray(batch_masks, dtype=np.float32)
            actions_arr = np.asarray(batch_actions, dtype=np.int64)
            returns_arr = np.concatenate(batch_returns).astype(np.float32)
            advantages_arr = np.concatenate(batch_advantages).astype(np.float32)
            advantages_arr = (advantages_arr - advantages_arr.mean()) / (advantages_arr.std() + 1e-8)
            network.optimizer.lr = args.learning_rate
            metrics = {}
            for _ in range(args.update_epochs):
                order = rng.permutation(len(states_arr))
                for start in range(0, len(states_arr), args.minibatch_size):
                    idx = order[start:start + args.minibatch_size]
                    metrics = network.update(
                        states=states_arr[idx],
                        masks=masks_arr[idx],
                        actions=actions_arr[idx],
                        returns=returns_arr[idx],
                        advantages=advantages_arr[idx],
                        value_coef=args.value_coef,
                    )
            print(
                f"update episodes={episode_idx + 1:04d} samples={len(states_arr):05d} "
                f"reward_mean20={np.mean(reward_window):+8.3f} success20={np.mean(success_window):.2f} "
                f"policy_loss={metrics['policy_loss']:+.4f} value_loss={metrics['value_loss']:.4f} "
                f"entropy={metrics['entropy']:.4f}"
            )
            batch_states.clear()
            batch_masks.clear()
            batch_actions.clear()
            batch_returns.clear()
            batch_advantages.clear()
        should_eval = args.eval_every > 0 and ((episode_idx + 1) % args.eval_every == 0 or (episode_idx + 1) == args.episodes)
        if should_eval:
            eval_metrics = evaluate_policy(
                env=env,
                network=network,
                encoder=encoder,
                rng=rng,
                difficulty=args.eval_difficulty,
                history_length=args.history_length,
                episodes=args.eval_episodes,
            )
            print(
                f"eval episodes={episode_idx + 1:04d} difficulty={args.eval_difficulty:<6} "
                f"reward_mean={eval_metrics['eval_reward_mean']:+8.3f} "
                f"reward_max={eval_metrics['eval_reward_max']:+8.3f} "
                f"success={eval_metrics['eval_success_rate']:.2f} "
                f"steps={eval_metrics['eval_steps_mean']:.1f}"
            )
            eval_metrics_rows.append(
                {
                    "episode": episode_idx + 1,
                    "difficulty": args.eval_difficulty,
                    "reward_mean": round(eval_metrics["eval_reward_mean"], 4),
                    "reward_max": round(eval_metrics["eval_reward_max"], 4),
                    "success_rate": round(eval_metrics["eval_success_rate"], 4),
                    "steps_mean": round(eval_metrics["eval_steps_mean"], 4),
                }
            )
    if args.output:
        output_path = Path(args.output)
        network.save(
            output_path,
            metadata={
                "observation_mode": args.observation_mode,
                "history_length": args.history_length,
                "episodes": args.episodes,
                "difficulty": args.difficulty,
                "difficulty_schedule": args.difficulty_schedule,
                "gamma": args.gamma,
                "gae_lambda": args.gae_lambda,
                "learning_rate": args.learning_rate,
                "update_epochs": args.update_epochs,
                "minibatch_size": args.minibatch_size,
                "action_dim": ACTION_DIM,
                "input_dim": input_dim,
            },
        )
        print(f"saved model={output_path}")
        if args.save_metrics:
            metrics_path = output_path.with_suffix(".csv")
            save_metrics_csv(metrics_path, episode_metrics)
            print(f"saved metrics={metrics_path}")
        if args.save_graph:
            graph_path = output_path.with_suffix(".svg")
            save_training_graph(graph_path, episode_metrics, eval_metrics_rows)
            print(f"saved graph={graph_path}")


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Train a NumPy actor-critic baseline for Pyre.")
    parser.add_argument("--episodes", type=int, default=120, help="Training episodes.")
    parser.add_argument("--difficulty", type=str, default="easy", choices=DIFFICULTIES)
    parser.add_argument(
        "--difficulty-schedule",
        type=str,
        default="easy,medium",
        help="Comma-separated curriculum, expanded evenly across episodes.",
    )
    parser.add_argument("--eval-difficulty", type=str, default="medium", choices=DIFFICULTIES)
    parser.add_argument("--eval-episodes", type=int, default=5)
    parser.add_argument("--eval-every", type=int, default=20)
    parser.add_argument("--update-every", type=int, default=5, help="Episodes per policy update.")
    parser.add_argument("--update-epochs", type=int, default=3, help="Gradient passes over each on-policy batch.")
    parser.add_argument("--minibatch-size", type=int, default=256, help="Samples per gradient step.")
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--gae-lambda", type=float, default=0.95)
    parser.add_argument("--learning-rate", type=float, default=3e-4)
    parser.add_argument("--value-coef", type=float, default=0.5)
    parser.add_argument("--history-length", type=int, default=4)
    parser.add_argument("--max-steps", type=int, default=150)
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--observation-mode", type=str, default="visible", choices=("visible", "full"))
    parser.add_argument("--output", type=str, default="artifacts/pyre_actor_critic.npz")
    parser.add_argument("--save-metrics", action="store_true", help="Save per-episode metrics as CSV beside the model.")
    parser.add_argument("--save-graph", action="store_true", help="Save an SVG training graph beside the model.")
    parser.add_argument("--describe-only", action="store_true", help="Print observation/action/reward definitions and exit.")
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    encoder = ObservationEncoder(mode=args.observation_mode)
    if args.describe_only:
        print(describe_environment_contract(encoder, args.history_length))
        return
    train(args)


if __name__ == "__main__":
    main()