Upload folder using huggingface_hub
Browse files
evaluation/__pycache__/generate_video.cpython-310.pyc
ADDED
|
Binary file (1.11 kB). View file
|
|
|
evaluation/evaluate_mcts.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
from env.firefighter_env import FireFighterEnv
|
| 4 |
+
from env.renderer import FireFighterRenderer
|
| 5 |
+
import copy
|
| 6 |
+
from tqdm import trange
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
SAVE_DIR = "assets/mcts_frames"
|
| 10 |
+
GIF_PATH = "assets/firefighter_mcts_success.gif"
|
| 11 |
+
os.makedirs(SAVE_DIR, exist_ok=True)
|
| 12 |
+
|
| 13 |
+
class MCTSAgent:
|
| 14 |
+
def __init__(self, env, n_simulations=50):
|
| 15 |
+
self.env = env
|
| 16 |
+
self.n_simulations = n_simulations
|
| 17 |
+
|
| 18 |
+
def rollout(self, env_copy):
|
| 19 |
+
total_reward = 0
|
| 20 |
+
for _ in range(10):
|
| 21 |
+
action = env_copy.action_space.sample()
|
| 22 |
+
_, reward, terminated, truncated, _ = env_copy.step(action)
|
| 23 |
+
total_reward += reward
|
| 24 |
+
if terminated or truncated:
|
| 25 |
+
break
|
| 26 |
+
return total_reward
|
| 27 |
+
|
| 28 |
+
def select_action(self, state):
|
| 29 |
+
action_rewards = np.zeros(self.env.action_space.n)
|
| 30 |
+
|
| 31 |
+
for action in range(self.env.action_space.n):
|
| 32 |
+
reward_sum = 0
|
| 33 |
+
for _ in range(self.n_simulations):
|
| 34 |
+
env_copy = copy.deepcopy(self.env)
|
| 35 |
+
env_copy.reset_to(state)
|
| 36 |
+
_, reward, _, _, _ = env_copy.step(action)
|
| 37 |
+
reward += self.rollout(env_copy)
|
| 38 |
+
reward_sum += reward
|
| 39 |
+
action_rewards[action] = reward_sum / self.n_simulations
|
| 40 |
+
|
| 41 |
+
return np.argmax(action_rewards)
|
| 42 |
+
|
| 43 |
+
# Run one visualized episode
|
| 44 |
+
def run_mcts_episode():
|
| 45 |
+
env = FireFighterEnv()
|
| 46 |
+
obs, _ = env.reset()
|
| 47 |
+
agent = MCTSAgent(env)
|
| 48 |
+
renderer = FireFighterRenderer(save_dir=SAVE_DIR)
|
| 49 |
+
|
| 50 |
+
total_reward = 0
|
| 51 |
+
|
| 52 |
+
for _ in range(60):
|
| 53 |
+
state = env.save_state()
|
| 54 |
+
action = agent.select_action(state)
|
| 55 |
+
obs, reward, terminated, truncated, _ = env.step(action)
|
| 56 |
+
total_reward += reward
|
| 57 |
+
|
| 58 |
+
renderer.render(env.agent_pos, env.has_bucket, env.fire_out,
|
| 59 |
+
env.bucket_pos, env.fire_pos, env.goal_pos, env.walls, reward)
|
| 60 |
+
if terminated or truncated:
|
| 61 |
+
break
|
| 62 |
+
|
| 63 |
+
renderer.close()
|
| 64 |
+
print(f"Total Reward (MCTS): {total_reward}")
|
| 65 |
+
|
| 66 |
+
# Generate GIF
|
| 67 |
+
frame_files = sorted([f for f in os.listdir(SAVE_DIR) if f.endswith(".png")])
|
| 68 |
+
frames = [Image.open(os.path.join(SAVE_DIR, f)) for f in frame_files]
|
| 69 |
+
frames[0].save(GIF_PATH, format='GIF', append_images=frames[1:], save_all=True, duration=300, loop=0)
|
| 70 |
+
print(f"✅ MCTS animation saved to {GIF_PATH}")
|
| 71 |
+
|
| 72 |
+
if __name__ == "__main__":
|
| 73 |
+
run_mcts_episode()
|
| 74 |
+
|
evaluation/generate_video.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import imageio
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
def generate_gif(frame_dir="assets/assets_video", output_path="assets/firefighter_episode.gif", fps=3, cleanup=True):
|
| 5 |
+
files = sorted(f for f in os.listdir(frame_dir) if f.endswith(".png"))
|
| 6 |
+
if not files:
|
| 7 |
+
print("⚠️ No frames found to generate GIF.")
|
| 8 |
+
return
|
| 9 |
+
|
| 10 |
+
images = [imageio.v2.imread(os.path.join(frame_dir, f)) for f in files]
|
| 11 |
+
imageio.mimsave(output_path, images, fps=fps)
|
| 12 |
+
print(f"✅ GIF saved to {output_path}")
|
| 13 |
+
|
| 14 |
+
if cleanup:
|
| 15 |
+
for f in files:
|
| 16 |
+
os.remove(os.path.join(frame_dir, f))
|
| 17 |
+
print("🧹 Frame PNGs deleted after GIF creation.")
|
| 18 |
+
|
evaluation/plot_rewards.py
ADDED
|
File without changes
|
evaluation/visualize_trajectories.py
ADDED
|
File without changes
|