Spaces:
Sleeping
Sleeping
| import gymnasium as gym | |
| import numpy as np | |
| from gymnasium import spaces | |
| from pyre_env.models import PyreAction, PyreObservation | |
| from pyre_env.server.pyre_env_environment import PyreEnvironment | |
| import torch as th | |
| import sys | |
| import os | |
| sys.path.append(os.getcwd()) | |
class PyreGymEnv(gym.Env):
    """Gymnasium wrapper for PyreEnvironment.

    Action space -- ``Discrete(41)``:
        0-3    move  (north, south, west, east)
        4-7    look  (north, south, west, east)
        8      wait
        9-24   open  door_1 .. door_16
        25-40  close door_1 .. door_16

    Observation space -- ``Dict`` with:
        "grid":   (7, 24, 24) float32, channels-first. Channels 0-4 are
                  cell-type flags (Wall, Door_Open, Door_Closed, Exit,
                  Obstacle); channels 5-6 copy the fire / smoke values.
                  Floor is implicit (all cell-type channels are 0).
        "global": (9,) float32: health, oxygen, step_progress, fire_spread,
                  humidity, agent_x, agent_y, nearest_exit_dist, is_coughing.
        "heat":   (1, 3, 3) float32, ``pyre_obs.heat_sensor`` reshaped.

    NOTE(review): every Box declares [0, 1], but fire_spread_rate, humidity,
    heat readings and exit distances > MAX_EXIT_DISTANCE are passed through
    unclipped -- confirm the environment keeps them in range.
    """

    # Fixed observation canvas; maps are written into its top-left corner.
    GRID_SIZE = 24
    # Normaliser (and fallback when unknown) for the nearest-exit distance.
    MAX_EXIT_DISTANCE = 48.0
    # Direction order shared by the move and look action ranges.
    DIRECTIONS = ("north", "south", "west", "east")
    NUM_DOORS = 16
    # cell_grid code -> "grid" channel; codes not listed (e.g. floor) set no flag.
    CELL_CHANNELS = {1: 0, 2: 1, 3: 2, 4: 3, 5: 4}

    def __init__(self, difficulty="easy", max_steps=150, observation_mode="visible"):
        """Create the wrapped PyreEnvironment and declare the Gym spaces.

        Args:
            difficulty: default difficulty passed to ``PyreEnvironment.reset``;
                ``reset(options={"difficulty": ...})`` overrides it per episode.
            max_steps: forwarded to ``PyreEnvironment``.
            observation_mode: "visible" renders only ``map_state.visible_cells``
                (plus the agent's own cell) into "grid"; any other value
                renders the whole map.
        """
        super().__init__()
        self.env = PyreEnvironment(max_steps=max_steps)
        self.difficulty = difficulty
        self.observation_mode = observation_mode

        n = self.GRID_SIZE
        # 9 non-door actions + open/close for each door = 41.
        self.action_space = spaces.Discrete(9 + 2 * self.NUM_DOORS)
        self.observation_space = spaces.Dict({
            "grid": spaces.Box(low=0, high=1, shape=(7, n, n), dtype=np.float32),
            "global": spaces.Box(low=0, high=1, shape=(9,), dtype=np.float32),
            "heat": spaces.Box(low=0, high=1, shape=(1, 3, 3), dtype=np.float32),
        })

    def _get_obs(self, pyre_obs: PyreObservation):
        """Convert a PyreObservation into the Dict observation described above.

        Maps smaller than GRID_SIZE are zero-padded; larger maps are cropped
        to the top-left GRID_SIZE x GRID_SIZE window instead of raising
        IndexError.
        """
        map_state = pyre_obs.map_state
        w, h = map_state.grid_w, map_state.grid_h
        n = self.GRID_SIZE

        grid = np.zeros((7, n, n), dtype=np.float32)
        visible = {(x, y) for x, y in map_state.visible_cells}
        agent_pos = (map_state.agent_x, map_state.agent_y)
        only_visible = self.observation_mode == "visible"
        for y in range(min(h, n)):
            for x in range(min(w, n)):
                if only_visible and (x, y) not in visible and (x, y) != agent_pos:
                    continue
                # The flat per-cell lists are indexed row-major with width w.
                i = y * w + x
                channel = self.CELL_CHANNELS.get(map_state.cell_grid[i])
                if channel is not None:
                    grid[channel, y, x] = 1.0
                grid[5, y, x] = float(map_state.fire_grid[i])
                grid[6, y, x] = float(map_state.smoke_grid[i])

        metadata = pyre_obs.metadata or {}
        exit_dist = metadata.get("nearest_exit_distance")
        if exit_dist is None:
            # Unknown distance: report the farthest normalised value.
            exit_dist = self.MAX_EXIT_DISTANCE
        # A distance of 0 (agent on an exit) must stay 0 -- a `value or 48`
        # fallback would turn it into the maximum distance.

        max_steps = float(map_state.max_steps)
        step_progress = float(map_state.step_count) / max_steps if max_steps > 0 else 0.0

        global_feats = np.array([
            float(pyre_obs.agent_health) / 100.0,
            float(pyre_obs.oxygen_level) / 100.0,
            step_progress,
            float(map_state.fire_spread_rate),
            float(map_state.humidity),
            float(map_state.agent_x) / n,
            float(map_state.agent_y) / n,
            float(exit_dist) / self.MAX_EXIT_DISTANCE,
            1.0 if pyre_obs.is_coughing else 0.0,
        ], dtype=np.float32)

        heat = np.array(pyre_obs.heat_sensor, dtype=np.float32).reshape(1, 3, 3)
        return {"grid": grid, "global": global_feats, "heat": heat}

    def reset(self, seed=None, options=None):
        """Reset the underlying env.

        ``options["difficulty"]``, when given, overrides ``self.difficulty``.
        Returns ``(observation, {})``.
        """
        super().reset(seed=seed)
        difficulty = (options or {}).get("difficulty", self.difficulty)
        pyre_obs = self.env.reset(seed=seed, difficulty=difficulty)
        return self._get_obs(pyre_obs), {}

    def _to_pyre_action(self, action_idx):
        """Translate a Discrete(41) index into a PyreAction.

        Raises:
            ValueError: if the index is outside the action space.
        """
        # Accept numpy integers / 0-d arrays as produced by RL libraries.
        idx = int(action_idx)
        n_actions = 9 + 2 * self.NUM_DOORS
        if not 0 <= idx < n_actions:
            raise ValueError(f"action index {idx} outside [0, {n_actions})")
        if idx < 4:
            return PyreAction(action="move", direction=self.DIRECTIONS[idx])
        if idx < 8:
            return PyreAction(action="look", direction=self.DIRECTIONS[idx - 4])
        if idx == 8:
            return PyreAction(action="wait")
        if idx < 9 + self.NUM_DOORS:
            # 9..24 -> door_1..door_16
            return PyreAction(action="door", target_id=f"door_{idx - 8}", door_state="open")
        # 25..40 -> door_1..door_16
        return PyreAction(
            action="door",
            target_id=f"door_{idx - 8 - self.NUM_DOORS}",
            door_state="close",
        )

    def step(self, action_idx):
        """Apply one discrete action.

        Returns the Gymnasium 5-tuple ``(obs, reward, terminated, truncated, info)``.
        ``info["pyre_obs"]`` holds the raw PyreObservation.
        """
        pyre_obs = self.env.step(self._to_pyre_action(action_idx))
        obs = self._get_obs(pyre_obs)
        # NOTE(review): a missing (None) reward is treated as 0.
        reward = 0.0 if pyre_obs.reward is None else float(pyre_obs.reward)
        terminated = bool(pyre_obs.done)
        # The step limit is enforced inside PyreEnvironment and surfaces only
        # via `done`, so time-outs are reported as termination.
        # NOTE(review): reporting them as `truncated` would let PPO bootstrap
        # past the limit, but the env gives no signal to tell the cases apart.
        truncated = False
        return obs, reward, terminated, truncated, {"pyre_obs": pyre_obs}
if __name__ == "__main__":
    # Training entry point: PPO on PyreGymEnv with a fixed, random, or
    # curriculum (easy -> medium -> hard) difficulty schedule.
    import argparse
    import csv
    from pathlib import Path

    from gymnasium.wrappers import RecordEpisodeStatistics
    from stable_baselines3 import PPO
    from stable_baselines3.common.callbacks import BaseCallback

    parser = argparse.ArgumentParser()
    parser.add_argument("--episodes", type=int, default=1500, help="Total episodes to train across all levels")
    parser.add_argument(
        "--difficulty",
        type=str,
        default="curriculum",
        choices=["easy", "medium", "hard", "random", "curriculum"],
        help="easy, medium, hard, random, or curriculum",
    )
    parser.add_argument(
        "--steps-per-episode",
        type=int,
        default=50,
        help="Assumed average episode length; training budget = episodes * this",
    )
    parser.add_argument("--output", type=str, default="artifacts/ppo_pyre_multilevel")
    args = parser.parse_args()

    # A single budget drives both the curriculum schedule and model.learn().
    # If the two disagree, the stage boundaries shift: e.g. a 60-step
    # schedule over a 50-step run leaves "hard" only ~21% of training.
    total_training_steps = args.episodes * args.steps_per_episode
    if total_training_steps <= 0:
        parser.error("--episodes and --steps-per-episode must be positive")

    class MultiLevelWrapper(gym.Wrapper):
        """Choose the difficulty passed to the wrapped env on every reset.

        Modes:
            "random":     uniform choice of easy/medium/hard per episode.
            "curriculum": easy for the first ~1/3 of total_training_steps,
                          medium for the next ~1/3, then hard.
            otherwise:    the mode string itself is used as the difficulty.
        """

        def __init__(self, env, mode="curriculum", total_training_steps=None):
            super().__init__(env)
            if mode == "curriculum" and (total_training_steps is None or total_training_steps <= 0):
                raise ValueError("curriculum mode needs a positive total_training_steps")
            self.mode = mode
            self.total_training_steps = total_training_steps
            self.current_difficulty = "easy"
            # Environment steps taken through this wrapper; drives the curriculum.
            self.total_steps = 0

        def _pick_difficulty(self):
            """Return the difficulty for the next episode according to self.mode."""
            if self.mode == "random":
                return str(np.random.choice(["easy", "medium", "hard"]))
            if self.mode == "curriculum":
                if self.total_steps < 0.33 * self.total_training_steps:
                    return "easy"
                if self.total_steps < 0.66 * self.total_training_steps:
                    return "medium"
                return "hard"
            return self.mode

        def reset(self, **kwargs):
            self.current_difficulty = self._pick_difficulty()
            # Copy so a caller-supplied options dict is not mutated.
            options = dict(kwargs.get("options") or {})
            options["difficulty"] = self.current_difficulty
            kwargs["options"] = options
            return self.env.reset(**kwargs)

        def step(self, action):
            obs, reward, term, trunc, info = self.env.step(action)
            self.total_steps += 1
            info["difficulty"] = self.current_difficulty
            return obs, reward, term, trunc, info

    env = PyreGymEnv(difficulty="easy")  # Base difficulty; overridden on every reset
    env = MultiLevelWrapper(env, mode=args.difficulty, total_training_steps=total_training_steps)
    env = RecordEpisodeStatistics(env)

    # Wider MLP heads for multiple levels. Note this is not a CNN: the float32
    # grid is not treated as an image by SB3's MultiInputPolicy extractor, so
    # it is flattened. PPO's actor-critic reads the "pi"/"vf" keys ("qf" is
    # only meaningful for off-policy algorithms such as SAC/TD3).
    policy_kwargs = dict(
        activation_fn=th.nn.ReLU,
        net_arch=dict(pi=[256, 128], vf=[256, 128]),
    )
    model = PPO(
        "MultiInputPolicy",
        env,
        policy_kwargs=policy_kwargs,
        verbose=1,
        tensorboard_log="./ppo_pyre_tensorboard/",
        learning_rate=2e-4,  # Slightly lower LR for stability across levels
        n_steps=2048,
        batch_size=128,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        ent_coef=0.02,  # Higher entropy to encourage exploration in procedural maps
    )
    print(f"Starting multi-level training (mode: {args.difficulty})...")

    class CSVLogCallback(BaseCallback):
        """Collect finished-episode stats from step infos and write them to CSV.

        Episodes are detected via the "episode" entry that
        RecordEpisodeStatistics adds to the info dict. The whole file is
        rewritten after each rollout and once more when training ends.
        """

        FIELDNAMES = ["step", "reward", "length"]

        def __init__(self, filename):
            super().__init__()
            self.filename = Path(filename)
            self.results = []

        def _on_step(self):
            for info in self.locals.get("infos", []):
                if "episode" in info:
                    self.results.append({
                        "step": self.num_timesteps,
                        "reward": info["episode"]["r"],
                        "length": info["episode"]["l"],
                    })
            return True

        def _flush(self):
            """Rewrite the CSV with everything collected so far (no-op if empty)."""
            if not self.results:
                return
            self.filename.parent.mkdir(parents=True, exist_ok=True)
            with self.filename.open("w", newline="", encoding="utf-8") as f:
                writer = csv.DictWriter(f, fieldnames=self.FIELDNAMES)
                writer.writeheader()
                writer.writerows(self.results)

        def _on_rollout_end(self):
            self._flush()

        def _on_training_end(self):
            self._flush()

    # Make sure e.g. "artifacts/" exists before anything is written there.
    Path(args.output).parent.mkdir(parents=True, exist_ok=True)
    csv_path = args.output + ".csv"
    callback = CSVLogCallback(csv_path)
    model.learn(total_timesteps=total_training_steps, callback=callback)
    model.save(args.output)
    print(f"Model saved to {args.output}")
    print(f"Metrics saved to {csv_path}")

    # Best-effort SVG graph; the CSV remains the source of truth if this fails.
    if callback.results:
        try:
            from examples.train_rl_agent import save_training_graph
            # Mocking the row format expected by the baseline plotter
            rows = [{"episode": i, "reward": r["reward"], "evacuated": 0} for i, r in enumerate(callback.results)]
            save_training_graph(Path(args.output + ".svg"), rows, [])
            print(f"Graph saved to {args.output}.svg")
        except Exception as e:
            print(f"Could not generate SVG automatically: {e}")
            print("CSV is available at " + csv_path)