Punit71 commited on
Commit
1e2b658
·
verified ·
1 Parent(s): 1346b12

Upload folder using huggingface_hub

Browse files
evaluation/__pycache__/generate_video.cpython-310.pyc ADDED
Binary file (1.11 kB). View file
 
evaluation/evaluate_mcts.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from env.firefighter_env import FireFighterEnv
4
+ from env.renderer import FireFighterRenderer
5
+ import copy
6
+ from tqdm import trange
7
+ from PIL import Image
8
+
9
+ SAVE_DIR = "assets/mcts_frames"
10
+ GIF_PATH = "assets/firefighter_mcts_success.gif"
11
+ os.makedirs(SAVE_DIR, exist_ok=True)
12
+
13
+ class MCTSAgent:
14
+ def __init__(self, env, n_simulations=50):
15
+ self.env = env
16
+ self.n_simulations = n_simulations
17
+
18
+ def rollout(self, env_copy):
19
+ total_reward = 0
20
+ for _ in range(10):
21
+ action = env_copy.action_space.sample()
22
+ _, reward, terminated, truncated, _ = env_copy.step(action)
23
+ total_reward += reward
24
+ if terminated or truncated:
25
+ break
26
+ return total_reward
27
+
28
+ def select_action(self, state):
29
+ action_rewards = np.zeros(self.env.action_space.n)
30
+
31
+ for action in range(self.env.action_space.n):
32
+ reward_sum = 0
33
+ for _ in range(self.n_simulations):
34
+ env_copy = copy.deepcopy(self.env)
35
+ env_copy.reset_to(state)
36
+ _, reward, _, _, _ = env_copy.step(action)
37
+ reward += self.rollout(env_copy)
38
+ reward_sum += reward
39
+ action_rewards[action] = reward_sum / self.n_simulations
40
+
41
+ return np.argmax(action_rewards)
42
+
43
+ # Run one visualized episode
44
+ def run_mcts_episode():
45
+ env = FireFighterEnv()
46
+ obs, _ = env.reset()
47
+ agent = MCTSAgent(env)
48
+ renderer = FireFighterRenderer(save_dir=SAVE_DIR)
49
+
50
+ total_reward = 0
51
+
52
+ for _ in range(60):
53
+ state = env.save_state()
54
+ action = agent.select_action(state)
55
+ obs, reward, terminated, truncated, _ = env.step(action)
56
+ total_reward += reward
57
+
58
+ renderer.render(env.agent_pos, env.has_bucket, env.fire_out,
59
+ env.bucket_pos, env.fire_pos, env.goal_pos, env.walls, reward)
60
+ if terminated or truncated:
61
+ break
62
+
63
+ renderer.close()
64
+ print(f"Total Reward (MCTS): {total_reward}")
65
+
66
+ # Generate GIF
67
+ frame_files = sorted([f for f in os.listdir(SAVE_DIR) if f.endswith(".png")])
68
+ frames = [Image.open(os.path.join(SAVE_DIR, f)) for f in frame_files]
69
+ frames[0].save(GIF_PATH, format='GIF', append_images=frames[1:], save_all=True, duration=300, loop=0)
70
+ print(f"✅ MCTS animation saved to {GIF_PATH}")
71
+
72
+ if __name__ == "__main__":
73
+ run_mcts_episode()
74
+
evaluation/generate_video.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import imageio
2
+ import os
3
+
4
+ def generate_gif(frame_dir="assets/assets_video", output_path="assets/firefighter_episode.gif", fps=3, cleanup=True):
5
+ files = sorted(f for f in os.listdir(frame_dir) if f.endswith(".png"))
6
+ if not files:
7
+ print("⚠️ No frames found to generate GIF.")
8
+ return
9
+
10
+ images = [imageio.v2.imread(os.path.join(frame_dir, f)) for f in files]
11
+ imageio.mimsave(output_path, images, fps=fps)
12
+ print(f"✅ GIF saved to {output_path}")
13
+
14
+ if cleanup:
15
+ for f in files:
16
+ os.remove(os.path.join(frame_dir, f))
17
+ print("🧹 Frame PNGs deleted after GIF creation.")
18
+
evaluation/plot_rewards.py ADDED
File without changes
evaluation/visualize_trajectories.py ADDED
File without changes