privateboss committed on
Commit
285e9ec
·
verified ·
1 Parent(s): 6797b6d

Create Trained_PPO_Agent.py

Browse files
Files changed (1) hide show
  1. Trained_PPO_Agent.py +103 -0
Trained_PPO_Agent.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ from Snake_EnvAndAgent import SnakeGameEnv
3
+ from PPO_Model import PPOAgent
4
+ import os
5
+ import time
6
+ import numpy as np
7
+ from plot_utility_Trained_Agent import init_live_plot, update_live_plot, save_live_plot_final, smooth_curve
8
+
9
+ PLAY_CONFIG = {
10
+     'grid_size': 30,
11
+     'model_path_prefix': 'snake_ppo_models/ppo_snake_final',
12
+     'num_episodes_to_play': 100,
13
+     'render_fps': 10,
14
+     'live_plot_interval_episodes': 1,
15
+     'plot_smoothing_factor': 0.8
16
+ }
17
+
18
+ PLOT_SAVE_DIR = 'snake_ppo_plots'
19
+ os.makedirs(PLOT_SAVE_DIR, exist_ok=True)
20
+
21
+
22
+ def play_agent():
23
+     print("Initializing environment for playback...")
24
+     env = SnakeGameEnv(render_mode='human')
25
+    
26
+     if PLAY_CONFIG['render_fps'] > 0:
27
+         env.metadata["render_fps"] = PLAY_CONFIG['render_fps']
28
+
29
+     obs_shape = env.observation_space.shape
30
+     action_size = env.action_space.n
31
+
32
+     agent = PPOAgent(
33
+         observation_space_shape=obs_shape,
34
+         action_space_size=action_size,
35
+         actor_lr=3e-4,
36
+         critic_lr=3e-4,
37
+         hidden_layer_sizes=[512, 512, 512]
38
+     )
39
+
40
+     print(f"loading models from: {PLAY_CONFIG['model_path_prefix']}")
41
+    
42
+     load_success = agent.load_models(PLAY_CONFIG['model_path_prefix'])
43
+    
44
+     if not load_success:
45
+         print("\nFATAL ERROR: Failed to load trained models from disk. The agent CANNOT perform as trained. Exiting playback!.")
46
+         env.close()
47
+         return
48
+
49
+     print("--- Trained models loaded successfully. Loading playback. ---")
50
+
51
+     print("Starting agent playback...")
52
+    
53
+     episode_rewards_playback = []
54
+    
55
+     fig, ax, line = init_live_plot(PLOT_SAVE_DIR, filename="live_playback_rewards_plot.png")
56
+     ax.set_title('Live Playback Progress (Episode Rewards)')
57
+     ax.set_xlabel('Episode')
58
+     ax.set_ylabel('Total Reward')
59
+
60
+     for episode in range(PLAY_CONFIG['num_episodes_to_play']):
61
+         state, info = env.reset()
62
+         done = False
63
+         episode_reward = 0
64
+         steps = 0
65
+
66
+         while not done:
67
+             action, _, _ = agent.choose_action(state)
68
+
69
+             next_state, reward, terminated, truncated, info = env.step(action)
70
+             episode_reward += reward
71
+             state = next_state
72
+             steps += 1
73
+             done = terminated or truncated
74
+
75
+         episode_rewards_playback.append(episode_reward)
76
+        
77
+         print(f"Episode {episode + 1}: Total Reward = {episode_reward:.2f}, Score = {info['score']}, Steps = {steps}")
78
+        
79
+         if (episode + 1) % PLAY_CONFIG['live_plot_interval_episodes'] == 0:
80
+             current_episodes = list(range(1, len(episode_rewards_playback) + 1))
81
+             smoothed_rewards = smooth_curve(episode_rewards_playback, factor=PLAY_CONFIG['plot_smoothing_factor'])
82
+             update_live_plot(fig, ax, line, current_episodes, smoothed_rewards,
83
+                              current_timestep=episode + 1,
84
+                              total_timesteps=PLAY_CONFIG['num_episodes_to_play'])
85
+
86
+         time.sleep(0.5)
87
+
88
+     env.close()
89
+     print("\nPlayback finished.")
90
+    
91
+     current_episodes = list(range(1, len(episode_rewards_playback) + 1))
92
+     smoothed_rewards = smooth_curve(episode_rewards_playback, factor=PLAY_CONFIG['plot_smoothing_factor'])
93
+     update_live_plot(fig, ax, line, current_episodes, smoothed_rewards,
94
+                      current_timestep=PLAY_CONFIG['num_episodes_to_play'],
95
+                      total_timesteps=PLAY_CONFIG['num_episodes_to_play'])
96
+     save_live_plot_final(fig, ax)
97
+
98
+     avg_playback_reward = np.mean(episode_rewards_playback)
99
+     print(f"Average Reward over {PLAY_CONFIG['num_episodes_to_play']} playback episodes: {avg_playback_reward:.2f}")
100
+
101
+
102
+ if __name__ == "__main__":
103
+     play_agent()