Punit71 committed on
Commit
1346b12
·
verified ·
1 Parent(s): 491a664

Upload folder using huggingface_hub

Browse files
env/__pycache__/firefighter_env.cpython-310.pyc ADDED
Binary file (3.04 kB). View file
 
env/__pycache__/renderer.cpython-310.pyc ADDED
Binary file (3.12 kB). View file
 
env/firefighter_env.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # firefighter_env.py
2
+ import copy
3
+ import gymnasium as gym
4
+ from gymnasium import spaces
5
+ import numpy as np
6
+
7
class FireFighterEnv(gym.Env):
    """Small stochastic grid world: fetch a bucket, put out a fire, reach the goal.

    Positions are ``(row, col)`` pairs on a ``grid_size`` x ``grid_size`` board;
    action 0 ("Up") decreases the first coordinate and action 3 ("Right")
    increases the second one.

    Observation: ``[row, col, has_bucket, fire_out]`` as an ``int32`` array.

    Reward for one step (contributions are summed):
        * -5  if the move would enter a wall cell (the agent stays in place)
        * +10 the first time the agent stands on the bucket cell
        * +10 when the agent stands on the fire cell while holding the bucket
        * +10 on reaching the goal with the fire out, -10 with it still
          burning; the episode terminates in both cases.

    The episode is truncated once ``max_steps`` steps have been taken.
    """

    metadata = {"render_modes": ["human"], "render_fps": 4}

    # Action id -> (d_row, d_col): 0=Up, 1=Down, 2=Left, 3=Right.
    _MOVES = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}

    # Probability that the requested action is the one actually executed;
    # otherwise one of the three other actions is picked uniformly.
    _INTENDED_MOVE_PROB = 0.8

    def __init__(self, render_mode=None):
        super().__init__()

        self.grid_size = 4
        self.max_steps = 60

        self.agent_start = (0, 0)
        self.bucket_pos = (1, 1)
        self.fire_pos = (1, 3)
        self.goal_pos = (3, 3)
        self.walls = {(1, 2), (2, 1)}

        self.action_space = spaces.Discrete(len(self._MOVES))

        # (row, col, has_bucket, fire_out); the coordinate ranges follow
        # grid_size instead of repeating the literal 4.
        self.observation_space = spaces.MultiDiscrete(
            [self.grid_size, self.grid_size, 2, 2]
        )

        self.render_mode = render_mode
        # Initialise the episode attributes so the env is usable right away.
        self.reset()

    def reset(self, seed=None, options=None):
        """Start a new episode.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset``; it seeds
                ``self.np_random``, which drives the stochastic transitions.
            options: Accepted for API compatibility; unused.

        Returns:
            ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)

        self.agent_pos = list(self.agent_start)
        self.has_bucket = False
        self.fire_out = False
        self.steps = 0

        return self._get_obs(), {}

    def step(self, action):
        """Apply ``action`` (possibly replaced by a random slip) and score it.

        Args:
            action: Integer in ``{0, 1, 2, 3}`` (Up, Down, Left, Right).

        Returns:
            ``(observation, reward, terminated, truncated, info)``.

        Raises:
            KeyError: If ``action`` is not one of the four known actions.
        """
        # Plain int so numpy scalars / 0-d arrays work as dictionary keys.
        action = int(action)
        d_row, d_col = self._MOVES[action]
        self.steps += 1

        # Stochastic transition.  The env's own generator is used (not the
        # global numpy RNG) so that reset(seed=...) makes episodes reproducible.
        if self.np_random.random() > self._INTENDED_MOVE_PROB:
            alternatives = [a for a in self._MOVES if a != action]
            d_row, d_col = self._MOVES[int(self.np_random.choice(alternatives))]

        # Clamp to the board; min/max keeps the coordinates plain Python ints.
        row, col = self.agent_pos
        last = self.grid_size - 1
        new_row = min(max(int(row) + d_row, 0), last)
        new_col = min(max(int(col) + d_col, 0), last)

        if (new_row, new_col) in self.walls:
            # Walls block the move and cost a penalty.
            reward = -5
        else:
            self.agent_pos = [new_row, new_col]
            reward = 0

        here = tuple(self.agent_pos)

        # Bucket collection (only rewarded once).
        if here == self.bucket_pos and not self.has_bucket:
            self.has_bucket = True
            reward += 10

        # Extinguish fire (requires the bucket, only rewarded once).
        if here == self.fire_pos and self.has_bucket and not self.fire_out:
            self.fire_out = True
            reward += 10

        # Reaching the goal always ends the episode; the sign of the bonus
        # depends on whether the fire was put out first.
        terminated = here == self.goal_pos
        if terminated:
            reward += 10 if self.fire_out else -10

        truncated = self.steps >= self.max_steps

        return self._get_obs(), reward, terminated, truncated, {}

    def _get_obs(self):
        """Return ``[row, col, has_bucket, fire_out]`` as an int32 array."""
        return np.array(
            [*self.agent_pos, int(self.has_bucket), int(self.fire_out)],
            dtype=np.int32,
        )

    def save_state(self):
        """Return a deep copy of this environment to be used as a snapshot."""
        return copy.deepcopy(self)

    def reset_to(self, saved_env):
        """Restore the state captured by ``save_state``.

        The random generator that is currently live on this environment is
        kept rather than taken from the snapshot, so restoring does not replay
        the same transition noise, and the snapshot can be restored from
        repeatedly without being modified.

        Args:
            saved_env: An environment previously returned by ``save_state``.
        """
        live_rng = self.np_random
        self.__dict__.update(saved_env.__dict__)
        # Own copy of the only mutable per-episode attribute.
        self.agent_pos = list(saved_env.agent_pos)
        self.np_random = live_rng

    def render(self):
        """No-op; this class does not draw anything itself."""
        pass

    def close(self):
        """No-op; the environment holds no external resources."""
        pass
env/renderer.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib
2
+ matplotlib.use('Agg')
3
+ import matplotlib.pyplot as plt
4
+ import matplotlib.image as mpimg
5
+ import numpy as np
6
+ import os
7
+
8
class FireFighterRenderer:
    """Draw the grid world next to a cumulative-reward plot using matplotlib.

    Frames are written to ``save_dir`` as ``frame_NNN.png`` when a directory
    is given; otherwise each frame is shown via ``plt.pause`` / ``plt.draw``.
    """

    # File stems (``<name>.png``) expected inside ``sprite_dir``.
    _SPRITE_NAMES = (
        "robot_white",
        "robot_blue",
        "robot_green",
        "bucket",
        "fire",
        "goal",
        "wall",
    )

    def __init__(self, save_dir=None, grid_size=4, sprite_dir="assets/sprites"):
        """Load the sprites and create the two-panel figure.

        Args:
            save_dir: Directory for rendered frames (created if missing).
                When falsy, frames are displayed instead of saved.
            grid_size: Number of cells per side of the square grid.
            sprite_dir: Directory containing the sprite PNG files.

        Raises:
            Whatever ``matplotlib.image.imread`` raises when a sprite cannot
            be read (e.g. the file is missing).
        """
        self.grid_size = grid_size
        self.save_dir = save_dir
        self.frame_idx = 0
        self.sprite_dir = sprite_dir
        self.rewards = []

        # Everything that can fail runs before the figure is created, so a
        # missing sprite does not leave an orphaned pyplot figure open.
        self.sprites = {
            name: mpimg.imread(os.path.join(self.sprite_dir, f"{name}.png"))
            for name in self._SPRITE_NAMES
        }

        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        self.fig, (self.ax_grid, self.ax_plot) = plt.subplots(
            1, 2, figsize=(10, 5), gridspec_kw={'width_ratios': [1, 1]}
        )

    def _draw_sprite(self, sprite_name, row, col):
        """Paint one sprite into grid cell ``(row, col)``; row 0 is at the top."""
        top = self.grid_size - row
        self.ax_grid.imshow(
            self.sprites[sprite_name],
            extent=(col, col + 1, top - 1, top),
            zorder=10,
        )

    def render(self, agent_pos, has_bucket, fire_out, bucket_pos, fire_pos, goal_pos, walls, reward):
        """Draw one frame for the given state and record ``reward``.

        Args:
            agent_pos: ``(row, col)`` of the agent.
            has_bucket: Whether the agent holds the bucket.
            fire_out: Whether the fire has been extinguished.
            bucket_pos: ``(row, col)`` of the bucket.
            fire_pos: ``(row, col)`` of the fire.
            goal_pos: ``(row, col)`` of the goal.
            walls: Iterable of ``(row, col)`` wall cells.
            reward: Reward of the latest step; appended to the history that
                feeds the cumulative-reward plot.
        """
        self.ax_grid.clear()
        self.ax_plot.clear()
        self.rewards.append(reward)

        self.ax_grid.set_xlim(0, self.grid_size)
        self.ax_grid.set_ylim(0, self.grid_size)
        self.ax_grid.set_xticks([])
        self.ax_grid.set_yticks([])
        self.ax_grid.set_aspect('equal')

        # White background grid
        for row in range(self.grid_size):
            for col in range(self.grid_size):
                self.ax_grid.add_patch(
                    plt.Rectangle(
                        (col, self.grid_size - 1 - row), 1, 1,
                        edgecolor='black', facecolor='white', linewidth=1,
                    )
                )

        # Static elements first, the robot last.
        for wall_row, wall_col in walls:
            self._draw_sprite("wall", wall_row, wall_col)

        if not fire_out:
            self._draw_sprite("fire", *fire_pos)

        self._draw_sprite("bucket", *bucket_pos)
        self._draw_sprite("goal", *goal_pos)

        # Robot colour encodes progress: green once the fire is out, blue
        # while carrying the bucket, white otherwise.
        if fire_out:
            robot_sprite = "robot_green"
        elif has_bucket:
            robot_sprite = "robot_blue"
        else:
            robot_sprite = "robot_white"
        self._draw_sprite(robot_sprite, *agent_pos)

        self.ax_grid.set_title(f"Step {self.frame_idx}")

        # Reward plot
        self.ax_plot.plot(np.cumsum(self.rewards), color='green', marker='o')
        self.ax_plot.set_title("Cumulative Reward")
        self.ax_plot.set_xlabel("Step")
        self.ax_plot.set_ylabel("Total Reward")
        self.ax_plot.grid(True)

        if self.save_dir:
            frame_path = os.path.join(self.save_dir, f"frame_{self.frame_idx:03d}.png")
            self.fig.tight_layout()
            self.fig.savefig(frame_path)
        else:
            plt.pause(0.3)
            plt.draw()

        # Advance on every frame (not only when saving) so the "Step N"
        # title also counts up in interactive mode.
        self.frame_idx += 1

    def close(self):
        """Close the figure owned by this renderer."""
        plt.close(self.fig)
106
+
env/utils.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
def format_observation(obs):
    """Turn an indexable observation into a labelled dictionary.

    The first two entries are passed through unchanged as ``'x'`` and ``'y'``;
    the third and fourth are converted to ``bool`` as ``'has_bucket'`` and
    ``'fire_out'``.
    """
    labelled = {}
    labelled['x'] = obs[0]
    labelled['y'] = obs[1]
    labelled['has_bucket'] = bool(obs[2])
    labelled['fire_out'] = bool(obs[3])
    return labelled
8
+
9
def is_terminal_state(obs):
    """Report whether ``obs`` is at cell (3, 3) with its fire flag equal to 1.

    NOTE(review): an observation at (3, 3) whose fire flag is 0 yields a falsy
    result here - confirm that matches how callers define "terminal".
    """
    position = (obs[0], obs[1])
    if position != (3, 3):
        return False
    return obs[3] == 1