| | import matplotlib |
| | matplotlib.use('Agg') |
| | import matplotlib.pyplot as plt |
| | import matplotlib.image as mpimg |
| | import numpy as np |
| | import os |
| |
|
| | class FireFighterRenderer: |
| | def __init__(self, save_dir=None): |
| | self.grid_size = 4 |
| | self.save_dir = save_dir |
| | self.frame_idx = 0 |
| | self.sprite_dir = "assets/sprites" |
| | self.rewards = [] |
| |
|
| | self.fig, (self.ax_grid, self.ax_plot) = plt.subplots(1, 2, figsize=(10, 5), |
| | gridspec_kw={'width_ratios': [1, 1]}) |
| |
|
| | |
| | def load(name): |
| | path = os.path.join(self.sprite_dir, f"{name}.png") |
| | img = mpimg.imread(path) |
| | assert img is not None, f"Failed to load: {path}" |
| | return img |
| |
|
| | self.sprites = { |
| | "robot_white": load("robot_white"), |
| | "robot_blue": load("robot_blue"), |
| | "robot_green": load("robot_green"), |
| | "bucket": load("bucket"), |
| | "fire": load("fire"), |
| | "goal": load("goal"), |
| | "wall": load("wall"), |
| | } |
| |
|
| | if save_dir: |
| | os.makedirs(save_dir, exist_ok=True) |
| |
|
| | def render(self, agent_pos, has_bucket, fire_out, bucket_pos, fire_pos, goal_pos, walls, reward): |
| | self.ax_grid.clear() |
| | self.ax_plot.clear() |
| | self.rewards.append(reward) |
| |
|
| | self.ax_grid.set_xlim(0, self.grid_size) |
| | self.ax_grid.set_ylim(0, self.grid_size) |
| | self.ax_grid.set_xticks([]) |
| | self.ax_grid.set_yticks([]) |
| | self.ax_grid.set_aspect('equal') |
| |
|
| | |
| | for x in range(self.grid_size): |
| | for y in range(self.grid_size): |
| | self.ax_grid.add_patch(plt.Rectangle((y, self.grid_size - 1 - x), 1, 1, |
| | edgecolor='black', facecolor='white', linewidth=1)) |
| |
|
| | |
| | def draw(sprite_name, x, y): |
| | sprite = self.sprites[sprite_name] |
| | self.ax_grid.imshow(sprite, |
| | extent=(y, y + 1, self.grid_size - 1 - x, self.grid_size - x), |
| | zorder=10) |
| |
|
| | |
| | for wx, wy in walls: |
| | draw("wall", wx, wy) |
| |
|
| | if not fire_out: |
| | fx, fy = fire_pos |
| | draw("fire", fx, fy) |
| |
|
| | bx, by = bucket_pos |
| | draw("bucket", bx, by) |
| |
|
| | gx, gy = goal_pos |
| | draw("goal", gx, gy) |
| |
|
| | ax, ay = agent_pos |
| | if has_bucket and not fire_out: |
| | robot_color = "robot_blue" |
| | elif fire_out: |
| | robot_color = "robot_green" |
| | else: |
| | robot_color = "robot_white" |
| | draw(robot_color, ax, ay) |
| |
|
| | self.ax_grid.set_title(f"Step {self.frame_idx}") |
| |
|
| | |
| | self.ax_plot.plot(np.cumsum(self.rewards), color='green', marker='o') |
| | self.ax_plot.set_title("Cumulative Reward") |
| | self.ax_plot.set_xlabel("Step") |
| | self.ax_plot.set_ylabel("Total Reward") |
| | self.ax_plot.grid(True) |
| |
|
| | if self.save_dir: |
| | frame_path = os.path.join(self.save_dir, f"frame_{self.frame_idx:03d}.png") |
| | self.fig.tight_layout() |
| | self.fig.savefig(frame_path) |
| | self.frame_idx += 1 |
| | else: |
| | plt.pause(0.3) |
| | plt.draw() |
| |
|
| | def close(self): |
| | plt.close(self.fig) |
| |
|
| |
|