# SnakeAI_TF_PPO_V0 / PPO_Trainer.py
# (originally uploaded by privateboss, revision 2df2f26)
import gymnasium as gym
from Snake_EnvAndAgent import SnakeGameEnv
from PPO_Model import PPOAgent
import numpy as np
import time
import os
import json # For saving/loading training state
from plot_utility_Trainer import plot_rewards, smooth_curve, init_live_plot, update_live_plot, save_live_plot_final
# ---------------------------------------------------------------------------
# Run configuration: every tunable for a training run lives in this one dict.
# ---------------------------------------------------------------------------
HYPERPARAMETERS = {
    'grid_size': 30,  # This is used for environment initialization
    'actor_lr': 0.0003,   # learning rate passed to the PPO actor optimizer
    'critic_lr': 0.0003,  # learning rate passed to the PPO critic optimizer
    'gamma': 0.99,        # reward discount factor (forwarded to PPOAgent)
    'gae_lambda': 0.95,   # GAE lambda (forwarded to PPOAgent)
    'clip_epsilon': 0.2,  # PPO clipping range (forwarded to PPOAgent)
    'num_epochs_per_update': 10,  # optimization epochs per learning update
    'batch_size': 64,
    'num_steps_per_rollout': 2048,  # Number of steps to collect before a learning update
    'total_timesteps': 10_000_000,  # Total environmental steps to train for
    'hidden_layer_sizes': [512, 512, 512],
    'save_interval_timesteps': 400000,  # Save models every N total timesteps
    'log_interval_episodes': 10,  # Log training progress every N episodes
    'render_training': False,  # Set to True to see rendering during training (will slow down)
    'render_fps_limit': 10,  # Limits render FPS, if 0, renders as fast as possible (can be too fast)
    'plot_smoothing_factor': 0.9,  # For smoothing the reward plot
    'live_plot_interval_episodes': 100,  # Update live plot every N episodes
    'resume_training': True  # Set to True to attempt to resume from latest checkpoint
}
# Directory for saving models and plots
MODEL_SAVE_DIR = 'snake_ppo_models'
PLOT_SAVE_DIR = 'snake_ppo_plots'
# JSON file holding resumable progress (timesteps, episode count, reward history).
TRAINING_STATE_FILE = os.path.join(MODEL_SAVE_DIR, 'training_state.json')
# Module-level side effect: ensure output directories exist before training starts.
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
os.makedirs(PLOT_SAVE_DIR, exist_ok=True)
print(f"Model save directory created/checked: {os.path.abspath(MODEL_SAVE_DIR)}")
print(f"Plot save directory created/checked: {os.path.abspath(PLOT_SAVE_DIR)}")
def save_training_state(total_timesteps_trained, episode_count, all_episode_rewards, plot_rewards_history, path=None):
    """Persist training progress to a JSON file so a run can be resumed later.

    Args:
        total_timesteps_trained: Environment steps completed so far.
        episode_count: Episodes completed so far.
        all_episode_rewards: Per-episode total rewards (full history).
        plot_rewards_history: Averaged rewards, one entry per log interval.
        path: Destination file; defaults to TRAINING_STATE_FILE.
    """
    if path is None:
        path = TRAINING_STATE_FILE
    # Cast to builtin types: plot_rewards_history entries come from
    # np.mean(...).round(2) and are numpy float64 scalars, which
    # json.dump cannot serialize (raises TypeError).
    state = {
        'total_timesteps_trained': int(total_timesteps_trained),
        'episode_count': int(episode_count),
        'all_episode_rewards': [float(r) for r in all_episode_rewards],
        'plot_rewards_history': [float(r) for r in plot_rewards_history]
    }
    with open(path, 'w') as f:
        json.dump(state, f)
    print(f"Training state saved to {path}")
def load_training_state(path=None):
    """Load previously saved training progress for resuming a run.

    Args:
        path: State JSON file to read; defaults to TRAINING_STATE_FILE.

    Returns:
        Tuple of (total_timesteps_trained, episode_count,
        all_episode_rewards, plot_rewards_history). Falls back to
        (0, 0, [], []) when the file is missing, unreadable, or corrupt,
        so a bad state file never aborts the training run.
    """
    if path is None:
        path = TRAINING_STATE_FILE
    if os.path.exists(path):
        try:
            with open(path, 'r') as f:
                state = json.load(f)
        except (json.JSONDecodeError, OSError) as e:
            # A corrupt/unreadable state file should not crash resume logic.
            print(f"Could not read training state from {path} ({e}); starting fresh.")
            return 0, 0, [], []
        print(f"Training state loaded from {path}")
        # .get with defaults tolerates state files from older versions
        # that may lack some keys.
        return state.get('total_timesteps_trained', 0), \
            state.get('episode_count', 0), \
            state.get('all_episode_rewards', []), \
            state.get('plot_rewards_history', [])
    return 0, 0, [], []
def train_agent():
    """Run the full PPO training loop for the Snake environment.

    Builds the environment and agent from HYPERPARAMETERS, optionally resumes
    from the newest checkpoint found in MODEL_SAVE_DIR, then alternates between
    collecting rollouts of experience and PPO learning updates until the
    configured total number of environment timesteps is reached.  Progress is
    logged, plotted live, and checkpointed periodically; final models and plots
    are saved when the loop finishes.
    """
    print(f"Current working directory: {os.getcwd()}")
    print("Initializing environment and agent...")
    # Only open a render window when rendering was requested.
    render_mode = 'human' if HYPERPARAMETERS['render_training'] else None
    env = SnakeGameEnv(render_mode=render_mode)
    if HYPERPARAMETERS['render_training'] and HYPERPARAMETERS['render_fps_limit'] > 0:
        env.metadata["render_fps"] = HYPERPARAMETERS['render_fps_limit']
    obs_shape = env.observation_space.shape
    action_size = env.action_space.n
    agent = PPOAgent(
        observation_space_shape=obs_shape,
        action_space_size=action_size,
        actor_lr=HYPERPARAMETERS['actor_lr'],
        critic_lr=HYPERPARAMETERS['critic_lr'],
        gamma=HYPERPARAMETERS['gamma'],
        gae_lambda=HYPERPARAMETERS['gae_lambda'],
        clip_epsilon=HYPERPARAMETERS['clip_epsilon'],
        num_epochs_per_update=HYPERPARAMETERS['num_epochs_per_update'],
        batch_size=HYPERPARAMETERS['batch_size'],
        hidden_layer_sizes=HYPERPARAMETERS['hidden_layer_sizes']
    )
    total_timesteps_trained = 0  # env steps completed so far (persists across resumes)
    episode_count = 0
    all_episode_rewards = []     # per-episode total reward, full history
    plot_rewards_history = []    # averaged reward, one point per log interval
    last_saved_timesteps = 0     # timestep count at the most recent periodic save
    # --- Resume Training Logic ---
    if HYPERPARAMETERS['resume_training']:
        print("Attempting to resume training...")
        latest_checkpoint = None  # (timestep, checkpoint name prefix)
        # Checkpoint actor files are named "<prefix>_<timestep>_actor.keras";
        # scan for the one with the highest numeric timestep.
        for f in os.listdir(MODEL_SAVE_DIR):
            if f.endswith('_actor.keras'):
                try:
                    timestep_str = f.split('_')[-2]
                    timestep = int(timestep_str)
                    if latest_checkpoint is None or timestep > latest_checkpoint[0]:
                        latest_checkpoint = (timestep, f.replace('_actor.keras', ''))
                except ValueError:
                    # Non-numeric part (e.g. the "final" checkpoint) - skip it.
                    continue
        if latest_checkpoint:
            print(f"Found latest checkpoint: {latest_checkpoint[1]}")
            if agent.load_models(latest_checkpoint[1]):
                total_timesteps_trained, episode_count, all_episode_rewards, plot_rewards_history = load_training_state()
                last_saved_timesteps = total_timesteps_trained
                print(f"Resumed from Timestep: {total_timesteps_trained}, Episode: {episode_count}")
            else:
                print("Failed to load models. Starting new training run.")
                HYPERPARAMETERS['resume_training'] = False
        else:
            print("No previous checkpoints found. Starting new training run.")
            HYPERPARAMETERS['resume_training'] = False
    print("Starting training loop...")
    start_time = time.time()
    fig, ax, line = init_live_plot(PLOT_SAVE_DIR, filename="live_ppo_training_progress.png")
    # When resuming, seed the live plot with the restored reward history.
    if HYPERPARAMETERS['resume_training'] and len(plot_rewards_history) > 1:
        episodes_for_plot = [i * HYPERPARAMETERS['log_interval_episodes'] for i in range(len(plot_rewards_history))]
        smoothed_rewards = smooth_curve(plot_rewards_history, factor=HYPERPARAMETERS['plot_smoothing_factor'])
        update_live_plot(fig, ax, line, episodes_for_plot, smoothed_rewards,
                         current_timestep=total_timesteps_trained,
                         total_timesteps=HYPERPARAMETERS['total_timesteps'])
    # Outer loop: one iteration = one rollout collection + one learning update.
    while total_timesteps_trained < HYPERPARAMETERS['total_timesteps']:
        current_rollout_steps = 0
        # Collect up to num_steps_per_rollout transitions, possibly spanning
        # several episodes, before running a PPO update.
        while current_rollout_steps < HYPERPARAMETERS['num_steps_per_rollout'] and \
                total_timesteps_trained + current_rollout_steps < HYPERPARAMETERS['total_timesteps']:
            state, info = env.reset()
            current_action_mask = info['action_mask']
            done = False
            current_episode_reward = 0
            # Step through one episode (or until the rollout budget runs out).
            while not done and current_rollout_steps < HYPERPARAMETERS['num_steps_per_rollout'] and \
                    total_timesteps_trained + current_rollout_steps < HYPERPARAMETERS['total_timesteps']:
                action, log_prob, value = agent.choose_action(state, current_action_mask)
                next_state, reward, terminated, truncated, info = env.step(action)
                current_episode_reward += reward
                next_action_mask = info['action_mask']
                # --- NEW: PASS ACTION MASK TO AGENT'S REMEMBER METHOD ---
                agent.remember(state, action, reward, next_state, terminated, log_prob, value, current_action_mask)
                state = next_state
                current_action_mask = next_action_mask
                current_rollout_steps += 1
                done = terminated or truncated
                if done:
                    episode_count += 1
                    all_episode_rewards.append(current_episode_reward)
                    # Periodic console log + reward-history point.
                    if episode_count % HYPERPARAMETERS['log_interval_episodes'] == 0:
                        avg_reward_last_n_episodes = np.mean(all_episode_rewards[-HYPERPARAMETERS['log_interval_episodes']:]).round(2)
                        plot_rewards_history.append(avg_reward_last_n_episodes)
                        elapsed_time = time.time() - start_time
                        print(f"Timestep: {total_timesteps_trained + current_rollout_steps}/{HYPERPARAMETERS['total_timesteps']} | "
                              f"Episode: {episode_count} | "
                              f"Avg Reward (last {HYPERPARAMETERS['log_interval_episodes']}): {avg_reward_last_n_episodes} | "
                              f"Total Score (this ep): {info['score']} | "
                              f"Time: {elapsed_time:.2f}s")
                    # Less frequent live-plot refresh (interval is a multiple
                    # of the log interval, so history is up to date here).
                    if episode_count % HYPERPARAMETERS['live_plot_interval_episodes'] == 0:
                        if len(plot_rewards_history) > 1:
                            episodes_for_plot = [i * HYPERPARAMETERS['log_interval_episodes'] for i in range(len(plot_rewards_history))]
                            smoothed_rewards = smooth_curve(plot_rewards_history, factor=HYPERPARAMETERS['plot_smoothing_factor'])
                            update_live_plot(fig, ax, line, episodes_for_plot, smoothed_rewards,
                                             current_timestep=total_timesteps_trained + current_rollout_steps,
                                             total_timesteps=HYPERPARAMETERS['total_timesteps'])
                    # Brief pause so a rendered game-over state is visible.
                    if HYPERPARAMETERS['render_training'] and done:
                        time.sleep(0.5)
                    break
        total_timesteps_trained += current_rollout_steps
        if len(agent.states) > 0:
            print(f" --- Agent learning at Total Timestep {total_timesteps_trained} (collected {len(agent.states)} steps in rollout) ---")
            agent.learn()
        else:
            print(f" --- No data collected in current rollout, skipping learning ---")
        # Periodic checkpoint: fires once each time another save-interval
        # boundary is crossed (rollouts rarely land exactly on a boundary).
        if total_timesteps_trained >= HYPERPARAMETERS['save_interval_timesteps'] and \
                (total_timesteps_trained // HYPERPARAMETERS['save_interval_timesteps']) > \
                (last_saved_timesteps // HYPERPARAMETERS['save_interval_timesteps']):
            save_path_timesteps = (total_timesteps_trained // HYPERPARAMETERS['save_interval_timesteps']) * HYPERPARAMETERS['save_interval_timesteps']
            print(f"--- Triggering periodic save at calculated timestep: {save_path_timesteps} ---")
            agent.save_models(os.path.join(MODEL_SAVE_DIR, f"ppo_snake_{save_path_timesteps}"))
            save_training_state(total_timesteps_trained, episode_count, all_episode_rewards, plot_rewards_history)
            last_saved_timesteps = save_path_timesteps
    print("\nTraining finished!")
    print(f"--- Triggering final save at total_timesteps: {total_timesteps_trained} ---")
    agent.save_models(os.path.join(MODEL_SAVE_DIR, "ppo_snake_final"))
    save_training_state(total_timesteps_trained, episode_count, all_episode_rewards, plot_rewards_history)
    env.close()
    print("Generating final performance plot...")
    episodes_for_plot = [i * HYPERPARAMETERS['log_interval_episodes'] for i in range(len(plot_rewards_history))]
    smoothed_rewards = smooth_curve(plot_rewards_history, factor=HYPERPARAMETERS['plot_smoothing_factor'])
    update_live_plot(fig, ax, line, episodes_for_plot, smoothed_rewards,
                     current_timestep=total_timesteps_trained,
                     total_timesteps=HYPERPARAMETERS['total_timesteps'])
    save_live_plot_final(fig, ax)
    plot_rewards(smoothed_rewards, HYPERPARAMETERS['log_interval_episodes'], PLOT_SAVE_DIR, "ppo_training_progress_final.png", show_plot=False)
# Script entry point: start (or resume) training only when run directly,
# not when this module is imported.
if __name__ == "__main__":
    train_agent()