# File size: 10,202 Bytes
# 443c22e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from pyre_env.models import PyreAction, PyreObservation
from pyre_env.server.pyre_env_environment import PyreEnvironment
import torch as th
import sys
import os
sys.path.append(os.getcwd())

class PyreGymEnv(gym.Env):
    """Gymnasium wrapper for PyreEnvironment.

    Translates between the Gymnasium API (discrete actions, dict-of-arrays
    observations) and the ``PyreAction`` / ``PyreObservation`` objects used by
    the wrapped ``PyreEnvironment``.

    Args:
        difficulty: Default difficulty passed to ``PyreEnvironment.reset``;
            can be overridden per episode via ``reset(options={"difficulty": ...})``.
        max_steps: Forwarded to the ``PyreEnvironment`` constructor.
        observation_mode: When ``"visible"``, only cells listed in
            ``map_state.visible_cells`` (plus the agent's own cell) are written
            into the grid observation; any other value exposes the full map.
    """

    # Direction order shared by the "move" (0-3) and "look" (4-7) actions.
    _DIRECTIONS = ("north", "south", "west", "east")
    # Cell-type code (as found in map_state.cell_grid) -> grid channel index.
    # Codes not listed here (e.g. floor) leave every structural channel at zero.
    _CELL_CHANNELS = {1: 0, 2: 1, 3: 2, 4: 3, 5: 4}
    # Fallback and normalisation constant for the nearest-exit distance.
    _MAX_EXIT_DISTANCE = 48.0

    def __init__(self, difficulty="easy", max_steps=150, observation_mode="visible"):
        super().__init__()
        self.env = PyreEnvironment(max_steps=max_steps)
        self.difficulty = difficulty
        self.observation_mode = observation_mode

        # Action space:
        # 0-3: Move (N, S, W, E)
        # 4-7: Look (N, S, W, E)
        # 8: Wait
        # 9-24: Open Door 1-16
        # 25-40: Close Door 1-16
        self.action_space = spaces.Discrete(41)

        # Observation space: Multi-input
        # 1. Grid: 7 channels x 24 x 24, channel-first
        #    (Wall, Door_Open, Door_Closed, Exit, Obstacle, Fire, Smoke;
        #    floor is implicit: all structural channels are zero)
        # 2. Global: [health, oxygen, step_progress, fire_spread, humidity,
        #             agent_x, agent_y, nearest_exit_dist, is_coughing]
        # 3. Heat Sensor: 1 x 3 x 3
        self.observation_space = spaces.Dict({
            "grid": spaces.Box(low=0, high=1, shape=(7, 24, 24), dtype=np.float32),
            "global": spaces.Box(low=0, high=1, shape=(9,), dtype=np.float32),
            "heat": spaces.Box(low=0, high=1, shape=(1, 3, 3), dtype=np.float32)
        })

    def _get_obs(self, pyre_obs: PyreObservation):
        """Convert a ``PyreObservation`` into the dict observation.

        Returns:
            A dict with keys ``"grid"`` (7, 24, 24), ``"global"`` (9,) and
            ``"heat"`` (1, 3, 3), all ``float32`` arrays.
        """
        map_state = pyre_obs.map_state
        w, h = map_state.grid_w, map_state.grid_h

        # Build 7-channel grid
        # Channels: 0:Wall, 1:Door_Open, 2:Door_Closed, 3:Exit, 4:Obstacle, 5:Fire, 6:Smoke
        # (Floor is implicit as all zeros in other channels)
        grid = np.zeros((7, 24, 24), dtype=np.float32)

        visible = {(x, y) for x, y in map_state.visible_cells}
        agent_pos = (map_state.agent_x, map_state.agent_y)
        mask_hidden = self.observation_mode == "visible"
        for y in range(h):
            for x in range(w):
                # In "visible" mode, unseen cells stay all-zero; the agent's
                # own cell is always reported.
                if mask_hidden and (x, y) not in visible and (x, y) != agent_pos:
                    continue

                i = y * w + x  # the per-cell lists are flattened row-major
                channel = self._CELL_CHANNELS.get(map_state.cell_grid[i])
                if channel is not None:
                    grid[channel, y, x] = 1.0

                grid[5, y, x] = float(map_state.fire_grid[i])
                grid[6, y, x] = float(map_state.smoke_grid[i])

        # Global features
        metadata = pyre_obs.metadata or {}
        # Only a *missing* distance falls back to the maximum. A distance of 0
        # is a legitimate value and must not be replaced by the fallback,
        # which a plain `or` would do.
        raw_exit_dist = metadata.get("nearest_exit_distance")
        if raw_exit_dist is None:
            raw_exit_dist = self._MAX_EXIT_DISTANCE
        nearest_exit = float(raw_exit_dist) / self._MAX_EXIT_DISTANCE

        # Guard against a zero step budget so normalisation cannot divide by zero.
        max_steps = float(map_state.max_steps)
        step_progress = float(map_state.step_count) / max_steps if max_steps > 0 else 0.0

        global_feats = np.array([
            float(pyre_obs.agent_health) / 100.0,
            float(pyre_obs.oxygen_level) / 100.0,
            step_progress,
            float(map_state.fire_spread_rate),
            float(map_state.humidity),
            float(map_state.agent_x) / 24.0,
            float(map_state.agent_y) / 24.0,
            nearest_exit,
            1.0 if pyre_obs.is_coughing else 0.0
        ], dtype=np.float32)

        # Heat sensor
        heat = np.array(pyre_obs.heat_sensor, dtype=np.float32).reshape(1, 3, 3)

        return {
            "grid": grid,
            "global": global_feats,
            "heat": heat
        }

    def reset(self, seed=None, options=None):
        """Start a new episode.

        Args:
            seed: Forwarded to both ``gym.Env.reset`` and the wrapped environment.
            options: Optional dict; a ``"difficulty"`` entry overrides the
                difficulty given at construction time for this episode.

        Returns:
            ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)
        difficulty = options.get("difficulty", self.difficulty) if options else self.difficulty
        pyre_obs = self.env.reset(seed=seed, difficulty=difficulty)
        return self._get_obs(pyre_obs), {}

    def step(self, action_idx):
        """Apply one discrete action and return the Gymnasium 5-tuple.

        Args:
            action_idx: Index into the ``Discrete(41)`` action space (see the
                layout documented in ``__init__``).

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``truncated`` is
            always ``False``; ``info["pyre_obs"]`` holds the raw observation.
        """
        # Normalise NumPy integer scalars to a plain int for indexing/formatting.
        action_idx = int(action_idx)

        # Map Discrete action to PyreAction
        if action_idx < 4:
            action = PyreAction(action="move", direction=self._DIRECTIONS[action_idx])
        elif action_idx < 8:
            action = PyreAction(action="look", direction=self._DIRECTIONS[action_idx - 4])
        elif action_idx == 8:
            action = PyreAction(action="wait")
        elif action_idx < 9 + 16:
            # 9..24 -> door_1..door_16
            action = PyreAction(action="door", target_id=f"door_{action_idx - 8}", door_state="open")
        else:
            # 25..40 -> door_1..door_16
            action = PyreAction(action="door", target_id=f"door_{action_idx - 24}", door_state="close")

        pyre_obs = self.env.step(action)

        obs = self._get_obs(pyre_obs)
        # NOTE(review): assumes reward may be None on some observations; it is
        # treated as 0.0 so downstream wrappers always get a float. Confirm
        # against the PyreObservation model.
        reward = 0.0 if pyre_obs.reward is None else float(pyre_obs.reward)
        terminated = bool(pyre_obs.done)
        truncated = False # Step limit handled by env.done

        return obs, reward, terminated, truncated, {"pyre_obs": pyre_obs}

if __name__ == "__main__":
    # Training entry point: wraps PyreGymEnv with a difficulty scheduler,
    # trains a PPO agent, and writes the model, a CSV of per-episode metrics
    # and (best effort) an SVG training graph next to --output.
    import argparse
    import csv
    from pathlib import Path

    from gymnasium.wrappers import RecordEpisodeStatistics
    from stable_baselines3 import PPO
    from stable_baselines3.common.callbacks import BaseCallback

    parser = argparse.ArgumentParser()
    parser.add_argument("--episodes", type=int, default=1500, help="Total episodes to train across all levels")
    parser.add_argument("--difficulty", type=str, default="curriculum", help="easy, medium, hard, random, or curriculum")
    parser.add_argument("--output", type=str, default="artifacts/ppo_pyre_multilevel")
    args = parser.parse_args()

    # Single source of truth for the training budget. The curriculum
    # thresholds and model.learn() must use the same number, otherwise the
    # easy/medium/hard phases do not split the run into the intended thirds.
    total_training_steps = args.episodes * 50

    # Custom wrapper to handle difficulty changes
    class MultiLevelWrapper(gym.Wrapper):
        """Choose the difficulty of each episode at reset time.

        Modes:
            "random": sample uniformly from easy/medium/hard every episode.
            "curriculum": easy for the first 33% of ``total_training_steps``,
                medium until 66%, hard afterwards (measured in env steps
                taken through this wrapper).
            anything else: used verbatim as a fixed difficulty.
        """

        def __init__(self, env, mode="curriculum", total_training_steps=total_training_steps):
            super().__init__(env)
            self.mode = mode
            self.total_training_steps = total_training_steps
            self.current_difficulty = "easy"
            self.total_steps = 0

        def reset(self, **kwargs):
            if self.mode == "random":
                # np.random.choice returns np.str_; hand a plain str downstream.
                self.current_difficulty = str(np.random.choice(["easy", "medium", "hard"]))
            elif self.mode == "curriculum":
                if self.total_steps < 0.33 * self.total_training_steps:
                    self.current_difficulty = "easy"
                elif self.total_steps < 0.66 * self.total_training_steps:
                    self.current_difficulty = "medium"
                else:
                    self.current_difficulty = "hard"
            else:
                self.current_difficulty = self.mode

            # Copy the caller's options (if any) so their dict is not mutated.
            options = dict(kwargs.get("options") or {})
            options["difficulty"] = self.current_difficulty
            kwargs["options"] = options

            return self.env.reset(**kwargs)

        def step(self, action):
            obs, reward, term, trunc, info = self.env.step(action)
            self.total_steps += 1
            info["difficulty"] = self.current_difficulty
            return obs, reward, term, trunc, info

    env = PyreGymEnv(difficulty="easy") # Base difficulty
    env = MultiLevelWrapper(env, mode=args.difficulty, total_training_steps=total_training_steps)
    env = RecordEpisodeStatistics(env)

    # Increased network capacity for multiple levels.
    # PPO is an actor-critic method: its net_arch keys are "pi" (policy) and
    # "vf" (value function). "qf" belongs to off-policy algorithms.
    policy_kwargs = dict(
        activation_fn=th.nn.ReLU,
        net_arch=dict(pi=[256, 128], vf=[256, 128])
    )

    model = PPO(
        "MultiInputPolicy",
        env,
        verbose=1,
        tensorboard_log="./ppo_pyre_tensorboard/",
        learning_rate=2e-4, # Slightly lower LR for stability across levels
        n_steps=2048,
        batch_size=128,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        ent_coef=0.02, # Higher entropy to encourage exploration in procedural maps
        policy_kwargs=policy_kwargs,
    )

    print(f"Starting multi-level training (mode: {args.difficulty})...")

    # Simple callback to log episode rewards to a CSV
    class CSVLogCallback(BaseCallback):
        """Collect finished-episode stats and rewrite them to a CSV file.

        Episodes are detected through the ``"episode"`` entry in the step
        infos. The whole file is rewritten at the end of every rollout and
        once more when training ends.
        """

        def __init__(self, filename):
            super().__init__()
            self.filename = filename
            self.results = []
            # Create the output directory up front so the first write cannot
            # abort training with FileNotFoundError.
            Path(filename).parent.mkdir(parents=True, exist_ok=True)

        def _on_step(self):
            # Check every step for finished episodes
            for info in self.locals.get("infos", []):
                if "episode" in info:
                    self.results.append({
                        "step": self.num_timesteps,
                        "reward": info["episode"]["r"],
                        "length": info["episode"]["l"]
                    })
            return True

        def _write_csv(self):
            if not self.results:
                return
            with open(self.filename, "w", newline="") as f:
                writer = csv.DictWriter(f, fieldnames=["step", "reward", "length"])
                writer.writeheader()
                writer.writerows(self.results)

        def _on_rollout_end(self):
            # Save every rollout
            self._write_csv()

        def _on_training_end(self):
            self._write_csv()

    csv_path = args.output + ".csv"
    callback = CSVLogCallback(csv_path)

    model.learn(total_timesteps=total_training_steps, callback=callback)

    model.save(args.output)
    print(f"Model saved to {args.output}")
    print(f"Metrics saved to {csv_path}")

    # Generate a quick SVG graph if we have results
    if callback.results:
        # Best effort: the plotter lives in an optional examples module, so
        # any failure is reported and the CSV remains the fallback artifact.
        try:
            from examples.train_rl_agent import save_training_graph
            # Mocking the row format expected by the baseline plotter
            rows = [{"episode": i, "reward": r["reward"], "evacuated": 0} for i, r in enumerate(callback.results)]
            save_training_graph(Path(args.output + ".svg"), rows, [])
            print(f"Graph saved to {args.output}.svg")
        except Exception as e:
            print(f"Could not generate SVG automatically: {e}")
            print("CSV is available at " + csv_path)