Upload folder using huggingface_hub
Browse files- env/__pycache__/firefighter_env.cpython-310.pyc +0 -0
- env/__pycache__/renderer.cpython-310.pyc +0 -0
- env/firefighter_env.py +102 -0
- env/renderer.py +106 -0
- env/utils.py +10 -0
env/__pycache__/firefighter_env.cpython-310.pyc
ADDED
|
Binary file (3.04 kB). View file
|
|
|
env/__pycache__/renderer.cpython-310.pyc
ADDED
|
Binary file (3.12 kB). View file
|
|
|
env/firefighter_env.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# firefighter_env.py
|
| 2 |
+
import copy
|
| 3 |
+
import gymnasium as gym
|
| 4 |
+
from gymnasium import spaces
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
class FireFighterEnv(gym.Env):
    """A 4x4 grid world where the agent must grab a water bucket,
    extinguish a fire, and then reach the goal cell.

    Observation: (row, col, has_bucket, fire_out) via MultiDiscrete([4, 4, 2, 2]).
    Actions: 0=Up, 1=Down, 2=Left, 3=Right.
    Transitions are stochastic: with probability ~0.2 the agent slips and
    moves in a uniformly random *other* direction.
    Rewards: -5 wall bump, +10 bucket pickup, +10 extinguishing the fire,
    +10 reaching the goal with the fire out / -10 with it still burning
    (reaching the goal always terminates the episode).
    """

    metadata = {"render_modes": ["human"], "render_fps": 4}

    def __init__(self, render_mode=None):
        super().__init__()

        self.grid_size = 4
        self.max_steps = 60  # truncation horizon

        # Fixed layout, (row, col) coordinates.
        self.agent_start = (0, 0)
        self.bucket_pos = (1, 1)
        self.fire_pos = (1, 3)
        self.goal_pos = (3, 3)
        self.walls = {(1, 2), (2, 1)}

        # Actions: 0=Up, 1=Down, 2=Left, 3=Right
        self.action_space = spaces.Discrete(4)

        # Observation: (x, y, has_bucket, fire_out)
        self.observation_space = spaces.MultiDiscrete([4, 4, 2, 2])

        self.render_mode = render_mode
        self.reset()

    def reset(self, seed=None, options=None):
        """Reset the episode. Returns (observation, info)."""
        super().reset(seed=seed)  # seeds self.np_random

        self.agent_pos = list(self.agent_start)
        self.has_bucket = False
        self.fire_out = False
        self.steps = 0

        return self._get_obs(), {}

    def step(self, action):
        """Advance one step. Returns (obs, reward, terminated, truncated, info)."""
        self.steps += 1
        x, y = self.agent_pos

        move = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}
        dx, dy = move[action]

        # Stochastic transitions: ~20% chance of slipping to another action.
        # BUG FIX: use the gymnasium-managed RNG (self.np_random) instead of
        # the global np.random, so reset(seed=...) actually makes episodes
        # reproducible.
        if self.np_random.random() > 0.8:
            slip = [a for a in move if a != action]
            dx, dy = move[int(self.np_random.choice(slip))]

        new_x = int(np.clip(x + dx, 0, self.grid_size - 1))
        new_y = int(np.clip(y + dy, 0, self.grid_size - 1))

        if (new_x, new_y) in self.walls:
            reward = -5  # bumped into a wall; position unchanged
        else:
            self.agent_pos = [new_x, new_y]
            reward = 0

        # Bucket collection (one-time bonus)
        if tuple(self.agent_pos) == self.bucket_pos and not self.has_bucket:
            self.has_bucket = True
            reward += 10

        # Extinguish fire (requires the bucket; one-time bonus)
        if tuple(self.agent_pos) == self.fire_pos and self.has_bucket and not self.fire_out:
            self.fire_out = True
            reward += 10

        # Reaching the goal always ends the episode; success only if fire is out.
        if tuple(self.agent_pos) == self.goal_pos:
            terminated = True
            reward += 10 if self.fire_out else -10
        else:
            terminated = False

        truncated = self.steps >= self.max_steps

        return self._get_obs(), reward, terminated, truncated, {}

    def _get_obs(self):
        """Pack (x, y, has_bucket, fire_out) into an int32 array."""
        return np.array([*self.agent_pos, int(self.has_bucket), int(self.fire_out)], dtype=np.int32)

    def save_state(self):
        """Return a deep-copied snapshot of the whole env (use with reset_to)."""
        return copy.deepcopy(self)

    def reset_to(self, saved_env):
        """Restore this env's state from a snapshot made by save_state()."""
        self.__dict__.update(saved_env.__dict__)

    def render(self):
        # Rendering is handled externally (see env/renderer.py).
        pass

    def close(self):
        pass
|
env/renderer.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib
|
| 2 |
+
matplotlib.use('Agg')
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import matplotlib.image as mpimg
|
| 5 |
+
import numpy as np
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
class FireFighterRenderer:
    """Matplotlib-based renderer for FireFighterEnv.

    Draws the 4x4 grid with sprites on the left axis and the cumulative
    reward curve on the right axis. When ``save_dir`` is given, frames are
    written as numbered PNGs; otherwise the figure is shown interactively.
    """

    def __init__(self, save_dir=None, sprite_dir="assets/sprites"):
        """
        Args:
            save_dir: directory to write frame PNGs into; None = interactive.
            sprite_dir: directory containing the sprite PNGs
                (robot_white, robot_blue, robot_green, bucket, fire, goal, wall).
                IMPROVEMENT: now a parameter (old hard-coded default kept) so
                the renderer works from any working directory.
        """
        self.grid_size = 4
        self.save_dir = save_dir
        self.frame_idx = 0
        self.sprite_dir = sprite_dir
        self.rewards = []

        self.fig, (self.ax_grid, self.ax_plot) = plt.subplots(1, 2, figsize=(10, 5),
                                                              gridspec_kw={'width_ratios': [1, 1]})

        # Load sprites once up front; fail fast on a missing file.
        def load(name):
            path = os.path.join(self.sprite_dir, f"{name}.png")
            img = mpimg.imread(path)
            assert img is not None, f"Failed to load: {path}"
            return img

        self.sprites = {
            "robot_white": load("robot_white"),
            "robot_blue": load("robot_blue"),
            "robot_green": load("robot_green"),
            "bucket": load("bucket"),
            "fire": load("fire"),
            "goal": load("goal"),
            "wall": load("wall"),
        }

        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

    def render(self, agent_pos, has_bucket, fire_out, bucket_pos, fire_pos, goal_pos, walls, reward):
        """Draw one frame. Positions are (row, col) with row 0 at the top."""
        self.ax_grid.clear()
        self.ax_plot.clear()
        self.rewards.append(reward)

        self.ax_grid.set_xlim(0, self.grid_size)
        self.ax_grid.set_ylim(0, self.grid_size)
        self.ax_grid.set_xticks([])
        self.ax_grid.set_yticks([])
        self.ax_grid.set_aspect('equal')

        # White background grid
        for x in range(self.grid_size):
            for y in range(self.grid_size):
                self.ax_grid.add_patch(plt.Rectangle((y, self.grid_size - 1 - x), 1, 1,
                                                     edgecolor='black', facecolor='white', linewidth=1))

        # Place a sprite at grid cell (x=row, y=col); row 0 maps to the top.
        def draw(sprite_name, x, y):
            sprite = self.sprites[sprite_name]
            self.ax_grid.imshow(sprite,
                                extent=(y, y + 1, self.grid_size - 1 - x, self.grid_size - x),
                                zorder=10)

        # Draw elements
        for wx, wy in walls:
            draw("wall", wx, wy)

        if not fire_out:
            fx, fy = fire_pos
            draw("fire", fx, fy)

        bx, by = bucket_pos
        draw("bucket", bx, by)

        gx, gy = goal_pos
        draw("goal", gx, gy)

        ax, ay = agent_pos
        # Robot color encodes progress: blue = carrying bucket, green = fire out.
        if has_bucket and not fire_out:
            robot_color = "robot_blue"
        elif fire_out:
            robot_color = "robot_green"
        else:
            robot_color = "robot_white"
        draw(robot_color, ax, ay)

        self.ax_grid.set_title(f"Step {self.frame_idx}")

        # Reward plot
        self.ax_plot.plot(np.cumsum(self.rewards), color='green', marker='o')
        self.ax_plot.set_title("Cumulative Reward")
        self.ax_plot.set_xlabel("Step")
        self.ax_plot.set_ylabel("Total Reward")
        self.ax_plot.grid(True)

        if self.save_dir:
            frame_path = os.path.join(self.save_dir, f"frame_{self.frame_idx:03d}.png")
            self.fig.tight_layout()
            self.fig.savefig(frame_path)
        else:
            plt.pause(0.3)
            plt.draw()
        # BUG FIX: advance the counter in both modes so the "Step N" title
        # isn't stuck at 0 during interactive rendering. Saved-frame
        # numbering is unchanged (increment still happens after savefig).
        self.frame_idx += 1

    def close(self):
        plt.close(self.fig)
|
| 106 |
+
|
env/utils.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def format_observation(obs):
    """Turn a raw (x, y, has_bucket, fire_out) observation into a labeled dict.

    The two flag entries are coerced to plain booleans; the coordinates are
    passed through unchanged.
    """
    x, y, bucket_flag, fire_flag = obs[0], obs[1], obs[2], obs[3]
    return {
        'x': x,
        'y': y,
        'has_bucket': bool(bucket_flag),
        'fire_out': bool(fire_flag),
    }
|
| 8 |
+
|
| 9 |
+
def is_terminal_state(obs, goal=(3, 3)):
    """Return True when the agent stands on the goal cell with the fire out.

    Args:
        obs: sequence (x, y, has_bucket, fire_out).
        goal: (row, col) of the goal cell. Defaults to (3, 3), matching
            FireFighterEnv's fixed layout, so existing callers are unchanged;
            generalized to a parameter for alternative layouts.
    """
    return (obs[0], obs[1]) == tuple(goal) and obs[3] == 1
|