| import gymnasium as gym | |
| from Snake_EnvAndAgent import SnakeGameEnv | |
| from PPO_Model import PPOAgent | |
| import os | |
| import time | |
| import numpy as np | |
| from plot_utility_Trained_Agent import init_live_plot, update_live_plot, save_live_plot_final, smooth_curve | |
# Playback configuration: everything tunable about how the trained agent
# is evaluated lives in this single dict.
PLAY_CONFIG = {
    'grid_size': 30,                        # board side length used by the environment
    'model_path_prefix': 'snake_ppo_models/ppo_snake_final',  # checkpoint file prefix to load
    'num_episodes_to_play': 100,            # number of evaluation episodes to run
    'render_fps': 10,                       # frame-rate cap for rendering; <= 0 disables throttling
    'live_plot_interval_episodes': 1,       # refresh the live reward plot every N episodes
    'plot_smoothing_factor': 0.8            # exponential smoothing applied to the plotted curve
}

# Directory for saved playback plots; created eagerly so later writes can't fail
# on a missing folder.
PLOT_SAVE_DIR = 'snake_ppo_plots'
os.makedirs(PLOT_SAVE_DIR, exist_ok=True)
def _refresh_playback_plot(fig, ax, line, rewards, episodes_done, total_episodes):
    """Redraw the live reward plot with a smoothed copy of *rewards*.

    rewards: list of per-episode total rewards collected so far.
    episodes_done / total_episodes: progress numbers forwarded to the plot
    utility for its progress annotation.
    """
    episode_axis = list(range(1, len(rewards) + 1))
    smoothed = smooth_curve(rewards, factor=PLAY_CONFIG['plot_smoothing_factor'])
    update_live_plot(fig, ax, line, episode_axis, smoothed,
                     current_timestep=episodes_done,
                     total_timesteps=total_episodes)


def play_agent():
    """Load the trained PPO snake agent and run it for rendered evaluation episodes.

    Reads all settings from PLAY_CONFIG and writes the live/final reward plot
    into PLOT_SAVE_DIR. Returns None. If the saved models cannot be loaded the
    environment is closed and playback is aborted early, since an untrained
    agent would not reflect training results.
    """
    print("Initializing environment for playback...")
    env = SnakeGameEnv(render_mode='human')
    # A non-positive render_fps means "no throttling"; otherwise override the
    # environment's advertised frame rate so the per-step sleep below matches it.
    if PLAY_CONFIG['render_fps'] > 0:
        env.metadata["render_fps"] = PLAY_CONFIG['render_fps']

    obs_shape = env.observation_space.shape
    action_size = env.action_space.n
    # Network hyperparameters must match the training run so the checkpoint
    # shapes line up; the learning rates are irrelevant during playback.
    agent = PPOAgent(
        observation_space_shape=obs_shape,
        action_space_size=action_size,
        actor_lr=3e-4,
        critic_lr=3e-4,
        hidden_layer_sizes=[512, 512, 512]
    )

    print(f"loading models from: {PLAY_CONFIG['model_path_prefix']}")
    load_success = agent.load_models(PLAY_CONFIG['model_path_prefix'])
    if not load_success:
        print("\nFATAL ERROR: Failed to load trained models from disk. The agent CANNOT perform as trained. Exiting playback!.")
        env.close()
        return
    print("--- Trained models loaded successfully. Loading playback. ---")

    print("Starting agent playback...")
    episode_rewards_playback = []
    fig, ax, line = init_live_plot(PLOT_SAVE_DIR, filename="live_playback_rewards_plot.png")
    ax.set_title('Live Playback Progress (Episode Rewards)')
    ax.set_xlabel('Episode')
    ax.set_ylabel('Total Reward')

    total_episodes = PLAY_CONFIG['num_episodes_to_play']
    for episode in range(total_episodes):
        state, info = env.reset()
        current_action_mask = info['action_mask']
        done = False
        episode_reward = 0
        steps = 0
        while not done:
            action, _, _ = agent.choose_action(state, current_action_mask)
            next_state, reward, terminated, truncated, info = env.step(action)
            episode_reward += reward
            state = next_state
            steps += 1
            done = terminated or truncated
            current_action_mask = info['action_mask']
            # Throttle stepping so the human-mode render stays watchable.
            if PLAY_CONFIG['render_fps'] > 0:
                time.sleep(1 / env.metadata["render_fps"])

        episode_rewards_playback.append(episode_reward)
        print(f"Episode {episode + 1}: Total Reward = {episode_reward:.2f}, Score = {info['score']}, Steps = {steps}")

        if (episode + 1) % PLAY_CONFIG['live_plot_interval_episodes'] == 0:
            _refresh_playback_plot(fig, ax, line, episode_rewards_playback,
                                   episode + 1, total_episodes)
            time.sleep(0.5)  # brief pause so the refreshed plot is visible

    env.close()
    print("\nPlayback finished.")

    # Final redraw with the complete reward history, then persist the figure.
    _refresh_playback_plot(fig, ax, line, episode_rewards_playback,
                           total_episodes, total_episodes)
    save_live_plot_final(fig, ax)

    # Guard the empty case (num_episodes_to_play == 0): np.mean([]) would emit
    # a RuntimeWarning and print 'nan'.
    avg_playback_reward = np.mean(episode_rewards_playback) if episode_rewards_playback else 0.0
    print(f"Average Reward over {PLAY_CONFIG['num_episodes_to_play']} playback episodes: {avg_playback_reward:.2f}")
| if __name__ == "__main__": | |
| play_agent() |