pyre_env / examples /train_sb3_agent.py
Akshaykumarbm's picture
Upload folder using huggingface_hub
443c22e verified
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from pyre_env.models import PyreAction, PyreObservation
from pyre_env.server.pyre_env_environment import PyreEnvironment
import torch as th
import sys
import os
sys.path.append(os.getcwd())
class PyreGymEnv(gym.Env):
"""Gymnasium wrapper for PyreEnvironment."""
def __init__(self, difficulty="easy", max_steps=150, observation_mode="visible"):
super().__init__()
self.env = PyreEnvironment(max_steps=max_steps)
self.difficulty = difficulty
self.observation_mode = observation_mode
# Action space:
# 0-3: Move (N, S, W, E)
# 4-7: Look (N, S, W, E)
# 8: Wait
# 9-24: Open Door 1-16
# 25-40: Close Door 1-16
self.action_space = spaces.Discrete(41)
# Observation space: Multi-input
# 1. Grid: 24x24x7 (Floor, Wall, Door_Open, Door_Closed, Exit, Obstacle, Fire, Smoke)
# 2. Global: [health, oxygen, step_progress, fire_spread, humidity, agent_x, agent_y, nearest_exit_dist, is_coughing]
# 3. Heat Sensor: 3x3
self.observation_space = spaces.Dict({
"grid": spaces.Box(low=0, high=1, shape=(7, 24, 24), dtype=np.float32),
"global": spaces.Box(low=0, high=1, shape=(9,), dtype=np.float32),
"heat": spaces.Box(low=0, high=1, shape=(1, 3, 3), dtype=np.float32)
})
def _get_obs(self, pyre_obs: PyreObservation):
map_state = pyre_obs.map_state
w, h = map_state.grid_w, map_state.grid_h
# Build 7-channel grid
# Channels: 0:Wall, 1:Door_Open, 2:Door_Closed, 3:Exit, 4:Obstacle, 5:Fire, 6:Smoke
# (Floor is implicit as all zeros in other channels)
grid = np.zeros((7, 24, 24), dtype=np.float32)
visible = {(x, y) for x, y in map_state.visible_cells}
for y in range(h):
for x in range(w):
if self.observation_mode == "visible" and (x, y) not in visible and (x, y) != (map_state.agent_x, map_state.agent_y):
continue
i = y * w + x
ct = map_state.cell_grid[i]
if ct == 1: grid[0, y, x] = 1.0 # Wall
elif ct == 2: grid[1, y, x] = 1.0 # Door Open
elif ct == 3: grid[2, y, x] = 1.0 # Door Closed
elif ct == 4: grid[3, y, x] = 1.0 # Exit
elif ct == 5: grid[4, y, x] = 1.0 # Obstacle
grid[5, y, x] = float(map_state.fire_grid[i])
grid[6, y, x] = float(map_state.smoke_grid[i])
# Global features
metadata = pyre_obs.metadata or {}
nearest_exit = float(metadata.get("nearest_exit_distance", 48) or 48.0) / 48.0
global_feats = np.array([
float(pyre_obs.agent_health) / 100.0,
float(pyre_obs.oxygen_level) / 100.0,
float(map_state.step_count) / float(map_state.max_steps),
float(map_state.fire_spread_rate),
float(map_state.humidity),
float(map_state.agent_x) / 24.0,
float(map_state.agent_y) / 24.0,
nearest_exit,
1.0 if pyre_obs.is_coughing else 0.0
], dtype=np.float32)
# Heat sensor
heat = np.array(pyre_obs.heat_sensor, dtype=np.float32).reshape(1, 3, 3)
return {
"grid": grid,
"global": global_feats,
"heat": heat
}
def reset(self, seed=None, options=None):
super().reset(seed=seed)
difficulty = options.get("difficulty", self.difficulty) if options else self.difficulty
pyre_obs = self.env.reset(seed=seed, difficulty=difficulty)
return self._get_obs(pyre_obs), {}
def step(self, action_idx):
# Map Discrete action to PyreAction
if action_idx < 4:
dirs = ["north", "south", "west", "east"]
action = PyreAction(action="move", direction=dirs[action_idx])
elif action_idx < 8:
dirs = ["north", "south", "west", "east"]
action = PyreAction(action="look", direction=dirs[action_idx - 4])
elif action_idx == 8:
action = PyreAction(action="wait")
elif action_idx < 9 + 16:
action = PyreAction(action="door", target_id=f"door_{action_idx - 8}", door_state="open")
else:
action = PyreAction(action="door", target_id=f"door_{action_idx - 24}", door_state="close")
pyre_obs = self.env.step(action)
obs = self._get_obs(pyre_obs)
reward = pyre_obs.reward
terminated = pyre_obs.done
truncated = False # Step limit handled by env.done
return obs, reward, terminated, truncated, {"pyre_obs": pyre_obs}
if __name__ == "__main__":
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--episodes", type=int, default=1500, help="Total episodes to train across all levels")
parser.add_argument("--difficulty", type=str, default="curriculum", help="easy, medium, hard, random, or curriculum")
parser.add_argument("--output", type=str, default="artifacts/ppo_pyre_multilevel")
args = parser.parse_args()
from gymnasium.wrappers import RecordEpisodeStatistics
# Custom wrapper to handle difficulty changes
class MultiLevelWrapper(gym.Wrapper):
def __init__(self, env, mode="curriculum"):
super().__init__(env)
self.mode = mode
self.current_difficulty = "easy"
self.step_count = 0
self.total_steps = 0
def reset(self, **kwargs):
if self.mode == "random":
self.current_difficulty = np.random.choice(["easy", "medium", "hard"])
elif self.mode == "curriculum":
if self.total_steps < 0.33 * total_training_steps:
self.current_difficulty = "easy"
elif self.total_steps < 0.66 * total_training_steps:
self.current_difficulty = "medium"
else:
self.current_difficulty = "hard"
else:
self.current_difficulty = self.mode
# Extract options from kwargs if present, or create new
options = kwargs.get("options")
if options is None:
options = {}
options["difficulty"] = self.current_difficulty
kwargs["options"] = options
return self.env.reset(**kwargs)
def step(self, action):
obs, reward, term, trunc, info = self.env.step(action)
self.total_steps += 1
info["difficulty"] = self.current_difficulty
return obs, reward, term, trunc, info
total_training_steps = args.episodes * 60
env = PyreGymEnv(difficulty="easy") # Base difficulty
env = MultiLevelWrapper(env, mode=args.difficulty)
env = RecordEpisodeStatistics(env)
# Custom CNN policy for the grid
# Increased network capacity for multiple levels
policy_kwargs = dict(
activation_fn=th.nn.ReLU,
net_arch=dict(pi=[256, 128], qf=[256, 128])
)
model = PPO(
"MultiInputPolicy",
env,
verbose=1,
tensorboard_log="./ppo_pyre_tensorboard/",
learning_rate=2e-4, # Slightly lower LR for stability across levels
n_steps=2048,
batch_size=128,
n_epochs=10,
gamma=0.99,
gae_lambda=0.95,
clip_range=0.2,
ent_coef=0.02, # Higher entropy to encourage exploration in procedural maps
)
print(f"Starting multi-level training (mode: {args.difficulty})...")
# Add a simple callback to log episode rewards to a CSV
from stable_baselines3.common.callbacks import BaseCallback
import csv
from pathlib import Path
class CSVLogCallback(BaseCallback):
def __init__(self, filename):
super().__init__()
self.filename = filename
self.results = []
def _on_step(self):
# Check every step for finished episodes
for info in self.locals.get("infos", []):
if "episode" in info:
self.results.append({
"step": self.num_timesteps,
"reward": info["episode"]["r"],
"length": info["episode"]["l"]
})
return True
def _on_rollout_end(self):
# Save every rollout
if self.results:
with open(self.filename, "w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=["step", "reward", "length"])
writer.writeheader()
writer.writerows(self.results)
return True
csv_path = args.output + ".csv"
callback = CSVLogCallback(csv_path)
model.learn(total_timesteps=args.episodes * 50, callback=callback)
model.save(args.output)
print(f"Model saved to {args.output}")
print(f"Metrics saved to {csv_path}")
# Generate a quick SVG graph if we have results
if callback.results:
try:
from examples.train_rl_agent import save_training_graph
# Mocking the row format expected by the baseline plotter
rows = [{"episode": i, "reward": r["reward"], "evacuated": 0} for i, r in enumerate(callback.results)]
save_training_graph(Path(args.output + ".svg"), rows, [])
print(f"Graph saved to {args.output}.svg")
except Exception as e:
print(f"Could not generate SVG automatically: {e}")
print("CSV is available at " + csv_path)