# SnakeAI_TF_PPO_V0 / Trained_PPO_Agent.py
# Uploaded by privateboss ("Upload 8 files", commit 2df2f26, verified)
import gymnasium as gym
from Snake_EnvAndAgent import SnakeGameEnv
from PPO_Model import PPOAgent
import os
import time
import numpy as np
from plot_utility_Trained_Agent import init_live_plot, update_live_plot, save_live_plot_final, smooth_curve
# Settings that control the playback run of the trained PPO snake agent.
PLAY_CONFIG = dict(
    grid_size=30,                      # board dimension; NOTE(review): not read anywhere in this file — confirm intent
    model_path_prefix='snake_ppo_models/ppo_snake_final',
    num_episodes_to_play=100,          # how many greedy episodes to roll out
    render_fps=10,                     # >0 caps render speed; 0 disables the per-step delay
    live_plot_interval_episodes=1,     # refresh the live plot every N episodes
    plot_smoothing_factor=0.8,         # exponential smoothing for the reward curve
)

# Directory where playback reward plots are written; created up front so
# later saves cannot fail on a missing path.
PLOT_SAVE_DIR = 'snake_ppo_plots'
os.makedirs(PLOT_SAVE_DIR, exist_ok=True)
def play_agent():
    """Run the trained PPO agent in the Snake environment for playback.

    Rolls out ``PLAY_CONFIG['num_episodes_to_play']`` greedy episodes in a
    human-rendered env, live-plotting smoothed per-episode rewards, then
    saves the final plot and prints the average reward. Returns None.
    Exits early (after closing the env) if the trained weights fail to load.
    """
    print("Initializing environment for playback...")
    # NOTE(review): PLAY_CONFIG['grid_size'] is never passed to SnakeGameEnv —
    # confirm whether the constructor should receive it.
    env = SnakeGameEnv(render_mode='human')
    if PLAY_CONFIG['render_fps'] > 0:
        env.metadata["render_fps"] = PLAY_CONFIG['render_fps']

    # Build an agent shell matching the env; LRs are irrelevant for playback
    # but required by the constructor. Hidden sizes must match the checkpoint.
    agent = PPOAgent(
        observation_space_shape=env.observation_space.shape,
        action_space_size=env.action_space.n,
        actor_lr=3e-4,
        critic_lr=3e-4,
        hidden_layer_sizes=[512, 512, 512]
    )

    print(f"loading models from: {PLAY_CONFIG['model_path_prefix']}")
    if not agent.load_models(PLAY_CONFIG['model_path_prefix']):
        print("\nFATAL ERROR: Failed to load trained models from disk. The agent CANNOT perform as trained. Exiting playback!.")
        env.close()
        return
    print("--- Trained models loaded successfully. Loading playback. ---")
    print("Starting agent playback...")

    episode_rewards_playback = []
    fig, ax, line = init_live_plot(PLOT_SAVE_DIR, filename="live_playback_rewards_plot.png")
    ax.set_title('Live Playback Progress (Episode Rewards)')
    ax.set_xlabel('Episode')
    ax.set_ylabel('Total Reward')

    for episode in range(PLAY_CONFIG['num_episodes_to_play']):
        episode_reward, score, steps = _run_episode(env, agent)
        episode_rewards_playback.append(episode_reward)
        print(f"Episode {episode + 1}: Total Reward = {episode_reward:.2f}, Score = {score}, Steps = {steps}")
        if (episode + 1) % PLAY_CONFIG['live_plot_interval_episodes'] == 0:
            _refresh_plot(fig, ax, line, episode_rewards_playback, episode + 1)
            time.sleep(0.5)  # brief pause so the refreshed plot stays visible

    env.close()
    print("\nPlayback finished.")

    # Final plot refresh + save with the complete reward history.
    _refresh_plot(fig, ax, line, episode_rewards_playback,
                  PLAY_CONFIG['num_episodes_to_play'])
    save_live_plot_final(fig, ax)

    # Guard the empty case: np.mean([]) emits a RuntimeWarning and yields nan
    # if num_episodes_to_play is ever set to 0.
    avg_playback_reward = np.mean(episode_rewards_playback) if episode_rewards_playback else 0.0
    print(f"Average Reward over {PLAY_CONFIG['num_episodes_to_play']} playback episodes: {avg_playback_reward:.2f}")


def _run_episode(env, agent):
    """Play one episode to termination; return (total_reward, final_score, steps).

    Uses the env's action mask (from ``info['action_mask']``) on every step and
    sleeps between steps when a positive render_fps is configured.
    """
    state, info = env.reset()
    current_action_mask = info['action_mask']
    done = False
    episode_reward = 0
    steps = 0
    while not done:
        action, _, _ = agent.choose_action(state, current_action_mask)
        state, reward, terminated, truncated, info = env.step(action)
        episode_reward += reward
        steps += 1
        done = terminated or truncated
        current_action_mask = info['action_mask']
        if PLAY_CONFIG['render_fps'] > 0:
            time.sleep(1 / env.metadata["render_fps"])
    return episode_reward, info['score'], steps


def _refresh_plot(fig, ax, line, rewards, current_episode):
    """Smooth the reward history and push it to the live matplotlib figure."""
    episodes = list(range(1, len(rewards) + 1))
    smoothed = smooth_curve(rewards, factor=PLAY_CONFIG['plot_smoothing_factor'])
    update_live_plot(fig, ax, line, episodes, smoothed,
                     current_timestep=current_episode,
                     total_timesteps=PLAY_CONFIG['num_episodes_to_play'])
# Script entry point: run trained-agent playback only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    play_agent()