Create Trained_PPO_Agent.py
Trained_PPO_Agent.py  (ADDED, +103 -0)
from Snake_EnvAndAgent import SnakeGameEnv
from PPO_Model import PPOAgent
import os
import time
import numpy as np
from plot_utility_Trained_Agent import init_live_plot, update_live_plot, save_live_plot_final, smooth_curve

# Playback settings for the trained agent.
PLAY_CONFIG = {
    'grid_size': 30,
    'model_path_prefix': 'snake_ppo_models/ppo_snake_final',
    'num_episodes_to_play': 100,
    'render_fps': 10,
    'live_plot_interval_episodes': 1,
    'plot_smoothing_factor': 0.8
}

PLOT_SAVE_DIR = 'snake_ppo_plots'
os.makedirs(PLOT_SAVE_DIR, exist_ok=True)


def play_agent():
    print("Initializing environment for playback...")
    env = SnakeGameEnv(render_mode='human')

    # Slow rendering down to a watchable speed.
    if PLAY_CONFIG['render_fps'] > 0:
        env.metadata["render_fps"] = PLAY_CONFIG['render_fps']

    obs_shape = env.observation_space.shape
    action_size = env.action_space.n

    # The network architecture must match the one used during training,
    # otherwise the saved weights cannot be loaded.
    agent = PPOAgent(
        observation_space_shape=obs_shape,
        action_space_size=action_size,
        actor_lr=3e-4,
        critic_lr=3e-4,
        hidden_layer_sizes=[512, 512, 512]
    )

    print(f"Loading models from: {PLAY_CONFIG['model_path_prefix']}")
    load_success = agent.load_models(PLAY_CONFIG['model_path_prefix'])

    if not load_success:
        print("\nFATAL ERROR: Failed to load trained models from disk. The agent cannot perform as trained. Exiting playback.")
        env.close()
        return

    print("--- Trained models loaded successfully. ---")
    print("Starting agent playback...")

    episode_rewards_playback = []

    # Live reward plot, refreshed as episodes complete.
    fig, ax, line = init_live_plot(PLOT_SAVE_DIR, filename="live_playback_rewards_plot.png")
    ax.set_title('Live Playback Progress (Episode Rewards)')
    ax.set_xlabel('Episode')
    ax.set_ylabel('Total Reward')

    for episode in range(PLAY_CONFIG['num_episodes_to_play']):
        state, info = env.reset()
        done = False
        episode_reward = 0
        steps = 0

        while not done:
            # Sample an action from the policy; the log-prob and value
            # outputs are only needed during training and are ignored here.
            action, _, _ = agent.choose_action(state)

            next_state, reward, terminated, truncated, info = env.step(action)
            episode_reward += reward
            state = next_state
            steps += 1
            done = terminated or truncated

        episode_rewards_playback.append(episode_reward)

        print(f"Episode {episode + 1}: Total Reward = {episode_reward:.2f}, Score = {info['score']}, Steps = {steps}")

        if (episode + 1) % PLAY_CONFIG['live_plot_interval_episodes'] == 0:
            current_episodes = list(range(1, len(episode_rewards_playback) + 1))
            smoothed_rewards = smooth_curve(episode_rewards_playback, factor=PLAY_CONFIG['plot_smoothing_factor'])
            update_live_plot(fig, ax, line, current_episodes, smoothed_rewards,
                             current_timestep=episode + 1,
                             total_timesteps=PLAY_CONFIG['num_episodes_to_play'])

        # Short pause between episodes so the rendered game is easier to follow.
        time.sleep(0.5)

    env.close()
    print("\nPlayback finished.")

    # Final plot update and save.
    current_episodes = list(range(1, len(episode_rewards_playback) + 1))
    smoothed_rewards = smooth_curve(episode_rewards_playback, factor=PLAY_CONFIG['plot_smoothing_factor'])
    update_live_plot(fig, ax, line, current_episodes, smoothed_rewards,
                     current_timestep=PLAY_CONFIG['num_episodes_to_play'],
                     total_timesteps=PLAY_CONFIG['num_episodes_to_play'])
    save_live_plot_final(fig, ax)

    avg_playback_reward = np.mean(episode_rewards_playback)
    print(f"Average Reward over {PLAY_CONFIG['num_episodes_to_play']} playback episodes: {avg_playback_reward:.2f}")


if __name__ == "__main__":
    play_agent()
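
Note: smooth_curve is imported from plot_utility_Trained_Agent, which is not part of this commit. Given how it is called above (a list of episode rewards and a factor of 0.8), it is presumably an exponential moving average; a minimal sketch under that assumption:

def smooth_curve(values, factor=0.8):
    """Exponentially smooth a sequence of values (higher factor = smoother)."""
    smoothed = []
    last = None
    for v in values:
        # Blend the previous smoothed value with the new raw value.
        last = v if last is None else last * factor + v * (1 - factor)
        smoothed.append(last)
    return smoothed

With factor=0.8, each plotted point carries roughly 80% of the previous smoothed value, which damps the large episode-to-episode variance typical of Snake rewards.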
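The live-plot helpers (init_live_plot, update_live_plot, save_live_plot_final) are likewise defined outside this file. The names and call sites above suggest a matplotlib figure updated in interactive mode; a rough sketch assuming those signatures (the fig._save_path attribute is an invention of this sketch, not the author's API):

import os
import matplotlib.pyplot as plt

def init_live_plot(save_dir, filename="live_plot.png"):
    plt.ion()                          # interactive mode: draw without blocking
    fig, ax = plt.subplots()
    line, = ax.plot([], [])            # empty line, populated on each update
    fig._save_path = os.path.join(save_dir, filename)
    return fig, ax, line

def update_live_plot(fig, ax, line, xs, ys, current_timestep, total_timesteps):
    line.set_data(xs, ys)
    ax.relim()                         # recompute data limits for the new points
    ax.autoscale_view()
    fig.canvas.draw_idle()
    fig.canvas.flush_events()
    fig.savefig(fig._save_path)        # keep the on-disk snapshot current

def save_live_plot_final(fig, ax):
    plt.ioff()
    fig.savefig(fig._save_path)
    plt.close(fig)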