Upload 8 files
Browse files- Dumb_Agent.py +40 -0
- Environment_Constants.py +23 -0
- PPO_Model.py +272 -0
- PPO_Trainer.py +231 -0
- Snake_EnvAndAgent.py +312 -0
- Trained_PPO_Agent.py +110 -0
- plot_utility_Trained_Agent.py +84 -0
- plot_utility_Trainer.py +84 -0
Dumb_Agent.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
from Snake_EnvAndAgent import SnakeGameEnv
import pygame
import time


def _run_random_episode(env):
    """Play one episode with uniformly random actions.

    Returns (steps, total_reward, info) where info is the final step's info
    dict (carries the game 'score').
    """
    obs, info = env.reset()
    total_reward = 0
    steps = 0
    terminated = truncated = False
    while not (terminated or truncated):
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        total_reward += reward
        steps += 1
    return steps, total_reward, info


if __name__ == "__main__":
    # Smoke-test the environment with a "dumb" random agent and live rendering.
    env = SnakeGameEnv(render_mode='human')

    episodes = 5
    for episode in range(episodes):
        print(f"--- Starting Episode {episode + 1} ---")
        steps, total_reward, info = _run_random_episode(env)
        print(f"Episode {episode + 1} finished in {steps} steps with total reward: {total_reward:.2f}")
        print(f"Final Score: {info['score']}")

    env.close()
    print("Environment test finished.")
|
Environment_Constants.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Board geometry: the playing field is GRID_SIZE x GRID_SIZE cells,
# each rendered as a CELL_SIZE x CELL_SIZE pixel square.
GRID_SIZE = 30
CELL_SIZE = 30

# Window size in pixels, derived from the grid.
SCREEN_WIDTH = GRID_SIZE * CELL_SIZE
SCREEN_HEIGHT = GRID_SIZE * CELL_SIZE

# RGB color tuples used by the pygame renderer.
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
GREEN = (0, 255, 0)
RED = (255, 0, 0)
BLUE = (0, 0, 255)

# Direction vectors as (dx, dy) grid offsets (UP = (0, -1) implies the
# y coordinate grows downward, as in pygame screen coordinates).
UP = (0, -1)
DOWN = (0, 1)
LEFT = (-1, 0)
RIGHT = (1, 0)

# Frame rate for 'human' render mode.
FPS = 10

# Reward shaping: large bonus for eating food, large penalty for a
# collision, and a small per-step penalty to discourage wandering.
REWARD_FOOD = 60
REWARD_COLLISION = -60
REWARD_STEP = -0.1
# Length of the 1-D feature vector the environment exposes as its observation.
OBSERVATION_SPACE_SIZE = 11
|
PPO_Model.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
import keras
|
| 3 |
+
from keras import layers, Model
|
| 4 |
+
import numpy as np
|
| 5 |
+
import tensorflow_probability as tfp
|
| 6 |
+
import os
|
| 7 |
+
import traceback
|
| 8 |
+
|
| 9 |
+
tfd = tfp.distributions
|
| 10 |
+
|
| 11 |
+
@tf.keras.utils.register_keras_serializable()
class Actor(Model):
    """Policy network: maps observations to unnormalized action logits.

    Registered as Keras-serializable so it round-trips through
    `model.save(...)` / `tf.keras.models.load_model(...)` via `get_config`.
    """

    def __init__(self, obs_shape, action_size, hidden_layer_sizes=(512, 512, 512), **kwargs):
        # NOTE: the default is an immutable tuple — the original used a
        # mutable list literal, which is shared across all calls.
        super().__init__(**kwargs)
        # Multi-dimensional observations are flattened; 1-D pass through.
        if len(obs_shape) > 1:
            self.flatten = layers.Flatten(input_shape=obs_shape)
            # Build the layer eagerly so its weights exist before first call.
            # tuple(obs_shape) guards against obs_shape arriving as a list
            # (e.g. after a get_config()/JSON round-trip).
            self.flatten(tf.zeros((1,) + tuple(obs_shape)))
        else:
            self.flatten = None

        self.dense_layers = [layers.Dense(size, activation='relu')
                             for size in hidden_layer_sizes]
        # Raw logits — no softmax; downstream code builds a Categorical.
        self.logits = layers.Dense(action_size)

        # Stored for get_config() so save/load reconstructs the same topology.
        self._obs_shape = obs_shape
        self._action_size = action_size
        self._hidden_layer_sizes = list(hidden_layer_sizes)

    def call(self, inputs):
        """Forward pass: returns a (batch, action_size) tensor of logits."""
        x = self.flatten(inputs) if self.flatten else inputs
        for layer in self.dense_layers:
            x = layer(x)
        return self.logits(x)

    def get_config(self):
        """Serialize constructor arguments for Keras model saving."""
        config = super().get_config()
        config.update({
            'obs_shape': self._obs_shape,
            'action_size': self._action_size,
            'hidden_layer_sizes': self._hidden_layer_sizes
        })
        return config
| 44 |
+
|
| 45 |
+
@tf.keras.utils.register_keras_serializable()
class Critic(Model):
    """Value network: maps observations to a scalar state-value estimate.

    Registered as Keras-serializable so it round-trips through
    `model.save(...)` / `tf.keras.models.load_model(...)` via `get_config`.
    """

    def __init__(self, obs_shape, hidden_layer_sizes=(512, 512, 512), **kwargs):
        # NOTE: the default is an immutable tuple — the original used a
        # mutable list literal, which is shared across all calls.
        super().__init__(**kwargs)
        # Multi-dimensional observations are flattened; 1-D pass through.
        if len(obs_shape) > 1:
            self.flatten = layers.Flatten(input_shape=obs_shape)
            # Build the layer eagerly; tuple(obs_shape) guards against a
            # list arriving from a get_config()/JSON round-trip.
            self.flatten(tf.zeros((1,) + tuple(obs_shape)))
        else:
            self.flatten = None

        self.dense_layers = [layers.Dense(size, activation='relu')
                             for size in hidden_layer_sizes]
        self.value = layers.Dense(1)  # single scalar value head

        # Stored for get_config() so save/load reconstructs the same topology.
        self._obs_shape = obs_shape
        self._hidden_layer_sizes = list(hidden_layer_sizes)

    def call(self, inputs):
        """Forward pass: returns a (batch, 1) tensor of state values."""
        x = self.flatten(inputs) if self.flatten else inputs
        for layer in self.dense_layers:
            x = layer(x)
        return self.value(x)

    def get_config(self):
        """Serialize constructor arguments for Keras model saving."""
        config = super().get_config()
        config.update({
            'obs_shape': self._obs_shape,
            'hidden_layer_sizes': self._hidden_layer_sizes
        })
        return config
|
| 76 |
+
|
| 77 |
+
class PPOAgent:
    """PPO agent with the clipped surrogate objective, GAE(lambda)
    advantage estimation, and invalid-action masking.

    Changes vs. the original: immutable tuple default for
    `hidden_layer_sizes`, rollout-buffer reset deduplicated into
    `_reset_buffers()`, and GAE computed by reverse fill instead of
    O(n^2) `list.insert(0, ...)`.
    """

    def __init__(self, observation_space_shape, action_space_size,
                 actor_lr=3e-4, critic_lr=3e-4, gamma=0.99,
                 gae_lambda=0.95, clip_epsilon=0.2,
                 num_epochs_per_update=10, batch_size=64,
                 hidden_layer_sizes=(512, 512, 512)):

        self.gamma = gamma                  # reward discount factor
        self.gae_lambda = gae_lambda        # GAE smoothing factor
        self.clip_epsilon = clip_epsilon    # PPO surrogate clip range
        self.num_epochs_per_update = num_epochs_per_update
        self.batch_size = batch_size

        self.observation_space_shape = observation_space_shape
        self.action_space_size = action_space_size

        self.actor = Actor(observation_space_shape, action_space_size, hidden_layer_sizes=hidden_layer_sizes)
        self.critic = Critic(observation_space_shape, hidden_layer_sizes=hidden_layer_sizes)

        self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_lr)
        self.critic_optimizer = tf.keras.optimizers.Adam(learning_rate=critic_lr)

        self._reset_buffers()

        # Dummy forward pass so both networks build their weights eagerly.
        dummy_obs = tf.zeros((1,) + tuple(observation_space_shape), dtype=tf.float32)
        self.actor(dummy_obs)
        self.critic(dummy_obs)

    def _reset_buffers(self):
        """Clear the rollout storage consumed by learn()."""
        self.states = []
        self.actions = []
        self.rewards = []
        self.next_states = []
        self.dones = []
        self.log_probs = []
        self.values = []
        self.action_masks = []

    def remember(self, state, action, reward, next_state, done, log_prob, value, action_mask):
        """Append one transition (plus its action mask) to the rollout buffer."""
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.next_states.append(next_state)
        self.dones.append(done)
        self.log_probs.append(log_prob)
        self.values.append(value)
        self.action_masks.append(action_mask)

    @tf.function
    def _choose_action_tf(self, observation, action_mask):
        """Graph-compiled sampling: mask, sample, and evaluate in one pass."""
        observation = tf.expand_dims(tf.convert_to_tensor(observation, dtype=tf.float32), 0)

        pi_logits = self.actor(observation)

        # Invalid actions get a huge negative logit -> ~zero probability.
        masked_logits = tf.where(action_mask, pi_logits, -1e9)

        value = self.critic(observation)

        distribution = tfd.Categorical(logits=masked_logits)
        action = distribution.sample()
        log_prob = distribution.log_prob(action)

        return action, log_prob, value

    def choose_action(self, observation, action_mask):
        """Sample an action; returns (action array, log-prob array, scalar value)."""
        action_tensor, log_prob_tensor, value_tensor = self._choose_action_tf(observation, tf.constant(action_mask, dtype=tf.bool))
        return action_tensor.numpy(), log_prob_tensor.numpy(), value_tensor.numpy()[0, 0]

    def calculate_advantages_and_returns(self):
        """Compute GAE(lambda) advantages and bootstrapped returns.

        Returns:
            (advantages, returns) as float32 numpy arrays over the buffer.
        """
        rewards = np.array(self.rewards, dtype=np.float32)
        values = np.array(self.values, dtype=np.float32)
        dones = np.array(self.dones, dtype=np.float32)

        # Bootstrap from the critic only if the rollout did not end terminal.
        if dones[-1]:
            last_next_state_value = 0
        else:
            last_next_state_value = self.critic(
                tf.expand_dims(tf.convert_to_tensor(self.next_states[-1], dtype=tf.float32), 0)
            ).numpy()[0, 0]
        next_values = np.append(values[1:], last_next_state_value)

        # Reverse fill (O(n)) instead of repeated list.insert(0, ...) (O(n^2)).
        advantages = np.zeros_like(rewards)
        last_advantage = 0.0
        for t in reversed(range(len(rewards))):
            delta = rewards[t] + self.gamma * next_values[t] * (1 - dones[t]) - values[t]
            last_advantage = delta + self.gae_lambda * self.gamma * (1 - dones[t]) * last_advantage
            advantages[t] = last_advantage

        returns = advantages + values
        return advantages, returns

    def learn(self):
        """Run `num_epochs_per_update` PPO epochs over the stored rollout,
        then clear the buffers. No-op when the buffer is empty."""
        if not self.states:
            return

        states = tf.convert_to_tensor(np.array(self.states), dtype=tf.float32)
        actions = tf.convert_to_tensor(np.array(self.actions), dtype=tf.int32)
        old_log_probs = tf.convert_to_tensor(np.array(self.log_probs), dtype=tf.float32)

        action_masks = tf.convert_to_tensor(np.array(self.action_masks), dtype=tf.bool)

        advantages, returns = self.calculate_advantages_and_returns()
        # Normalize advantages for update stability.
        advantages = (advantages - tf.reduce_mean(advantages)) / (tf.math.reduce_std(advantages) + 1e-8)

        dataset = tf.data.Dataset.from_tensor_slices((states, actions, old_log_probs, advantages, returns, action_masks))
        dataset = dataset.shuffle(buffer_size=len(self.states)).batch(self.batch_size)

        for _ in range(self.num_epochs_per_update):
            for batch_states, batch_actions, batch_old_log_probs, batch_advantages, batch_returns, batch_action_masks in dataset:

                # --- Actor update: clipped surrogate objective ---
                with tf.GradientTape() as tape:
                    current_logits = self.actor(batch_states)

                    # Apply the same mask used at collection time so the new
                    # policy's log-probs are comparable to the old ones.
                    masked_logits = tf.where(batch_action_masks, current_logits, -1e9)
                    new_distribution = tfd.Categorical(logits=masked_logits)

                    new_log_probs = new_distribution.log_prob(batch_actions)
                    ratio = tf.exp(new_log_probs - batch_old_log_probs)

                    surrogate1 = ratio * batch_advantages
                    surrogate2 = tf.clip_by_value(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * batch_advantages

                    actor_loss = -tf.reduce_mean(tf.minimum(surrogate1, surrogate2))

                actor_grads = tape.gradient(actor_loss, self.actor.trainable_variables)
                self.actor_optimizer.apply_gradients(zip(actor_grads, self.actor.trainable_variables))

                # --- Critic update: MSE against the bootstrapped returns ---
                with tf.GradientTape() as tape:
                    new_values = self.critic(batch_states)
                    critic_loss = tf.reduce_mean(tf.square(new_values - batch_returns))

                critic_grads = tape.gradient(critic_loss, self.critic.trainable_variables)
                self.critic_optimizer.apply_gradients(zip(critic_grads, self.critic.trainable_variables))

        self._reset_buffers()

    def save_models(self, path):
        """Save actor and critic as `{path}_actor.keras` / `{path}_critic.keras`.

        Failures are logged (with traceback) but never raised, so a save
        error cannot kill a long training run.
        """
        actor_save_path = f"{path}_actor.keras"
        critic_save_path = f"{path}_critic.keras"
        print(f"\n--- Attempting to save models ---")
        print(f"Target Actor path: {os.path.abspath(actor_save_path)}")
        print(f"Target Critic path: {os.path.abspath(critic_save_path)}")
        try:
            self.actor.save(actor_save_path)
            print(f"Actor model saved successfully to {os.path.abspath(actor_save_path)}")
        except Exception as e:
            print(f"ERROR: Failed to save Actor model to {os.path.abspath(actor_save_path)}")
            print(f"Reason: {e}")
            traceback.print_exc()
        try:
            self.critic.save(critic_save_path)
            print(f"Critic model saved successfully to {os.path.abspath(critic_save_path)}")
        except Exception as e:
            print(f"ERROR: Failed to save Critic model to {os.path.abspath(critic_save_path)}")
            print(f"Reason: {e}")
            traceback.print_exc()
        print(f"--- Models save process completed ---\n")

    def load_models(self, path):
        """Load actor/critic from `{path}_actor.keras` / `{path}_critic.keras`.

        Returns True only when BOTH models load; on any failure the agent
        keeps its freshly-initialized (untrained) networks and returns False.
        """
        actor_load_path = f"{path}_actor.keras"
        critic_load_path = f"{path}_critic.keras"
        actor_loaded_ok = False
        critic_loaded_ok = False

        # Map saved class names back to the local serializable classes.
        custom_objects = {
            'Actor': Actor,
            'Critic': Critic
        }

        try:
            self.actor = tf.keras.models.load_model(actor_load_path, custom_objects=custom_objects)
            actor_loaded_ok = True
            print(f"Actor model loaded from: {os.path.abspath(actor_load_path)}")
        except Exception as e:
            print(f"ERROR: Failed to load Actor model from {os.path.abspath(actor_load_path)}")
            print(f"Reason: {e}")
            traceback.print_exc()

        try:
            self.critic = tf.keras.models.load_model(critic_load_path, custom_objects=custom_objects)
            critic_loaded_ok = True
            print(f"Critic model loaded from: {os.path.abspath(critic_load_path)}")
        except Exception as e:
            print(f"ERROR: Failed to load Critic model from {os.path.abspath(critic_load_path)}")
            print(f"Reason: {e}")
            traceback.print_exc()

        if actor_loaded_ok and critic_loaded_ok:
            print(f"All PPO models loaded successfully from '{path}'.")
            return True
        else:
            print(f"Warning: One or both models failed to load. The agent will use untrained models.")
            return False
|
PPO_Trainer.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
from Snake_EnvAndAgent import SnakeGameEnv
|
| 3 |
+
from PPO_Model import PPOAgent
|
| 4 |
+
import numpy as np
|
| 5 |
+
import time
|
| 6 |
+
import os
|
| 7 |
+
import json # For saving/loading training state
|
| 8 |
+
from plot_utility_Trainer import plot_rewards, smooth_curve, init_live_plot, update_live_plot, save_live_plot_final
|
| 9 |
+
|
| 10 |
+
# Training configuration. Read throughout train_agent(); note that
# 'resume_training' is also mutated at runtime when no checkpoint is found.
HYPERPARAMETERS = {
    'grid_size': 30,  # This is used for environment initialization
    'actor_lr': 0.0003,
    'critic_lr': 0.0003,
    'gamma': 0.99,  # reward discount factor
    'gae_lambda': 0.95,  # GAE smoothing factor
    'clip_epsilon': 0.2,  # PPO clipped-surrogate range
    'num_epochs_per_update': 10,
    'batch_size': 64,
    'num_steps_per_rollout': 2048,  # Number of steps to collect before a learning update
    'total_timesteps': 10_000_000,  # Total environmental steps to train for
    'hidden_layer_sizes': [512, 512, 512],
    'save_interval_timesteps': 400000,  # Save models every N total timesteps
    'log_interval_episodes': 10,  # Log training progress every N episodes
    'render_training': False,  # Set to True to see rendering during training (will slow down)
    'render_fps_limit': 10,  # Limits render FPS, if 0, renders as fast as possible (can be too fast)
    'plot_smoothing_factor': 0.9,  # For smoothing the reward plot
    'live_plot_interval_episodes': 100,  # Update live plot every N episodes
    'resume_training': True  # Set to True to attempt to resume from latest checkpoint
}

# Directory for saving models and plots
MODEL_SAVE_DIR = 'snake_ppo_models'
PLOT_SAVE_DIR = 'snake_ppo_plots'
# JSON file holding resumable counters/reward history between runs.
TRAINING_STATE_FILE = os.path.join(MODEL_SAVE_DIR, 'training_state.json')

# Created at import time so later saves cannot fail on a missing directory.
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
os.makedirs(PLOT_SAVE_DIR, exist_ok=True)
print(f"Model save directory created/checked: {os.path.abspath(MODEL_SAVE_DIR)}")
print(f"Plot save directory created/checked: {os.path.abspath(PLOT_SAVE_DIR)}")
|
| 40 |
+
|
| 41 |
+
def save_training_state(total_timesteps_trained, episode_count, all_episode_rewards, plot_rewards_history):
    """Persist resumable training progress to TRAINING_STATE_FILE as JSON.

    Stores the global step/episode counters plus the full reward history so
    a later run can continue logging and plotting from where this one stopped.
    """
    with open(TRAINING_STATE_FILE, 'w') as f:
        json.dump({
            'total_timesteps_trained': total_timesteps_trained,
            'episode_count': episode_count,
            'all_episode_rewards': all_episode_rewards,
            'plot_rewards_history': plot_rewards_history,
        }, f)
    print(f"Training state saved to {TRAINING_STATE_FILE}")
|
| 51 |
+
|
| 52 |
+
def load_training_state():
    """Load saved training progress, falling back to a fresh state.

    Returns:
        (total_timesteps_trained, episode_count, all_episode_rewards,
        plot_rewards_history) — zeros/empty lists when the state file is
        missing, unreadable, or incomplete.
    """
    if os.path.exists(TRAINING_STATE_FILE):
        try:
            with open(TRAINING_STATE_FILE, 'r') as f:
                state = json.load(f)
            print(f"Training state loaded from {TRAINING_STATE_FILE}")
            return state['total_timesteps_trained'], \
                   state['episode_count'], \
                   state['all_episode_rewards'], \
                   state['plot_rewards_history']
        except (json.JSONDecodeError, KeyError, OSError) as e:
            # A corrupt or partially-written state file (e.g. from a crash
            # mid-save) should not abort training — start counters fresh.
            print(f"WARNING: Could not read training state from {TRAINING_STATE_FILE}: {e}")
    return 0, 0, [], []
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def train_agent():
    """Run the full PPO training loop for the Snake environment.

    Collects rollouts of `num_steps_per_rollout` steps, updates the agent
    after each rollout, periodically checkpoints models + training state,
    and maintains a live reward plot. Resumes from the newest checkpoint
    in MODEL_SAVE_DIR when HYPERPARAMETERS['resume_training'] is True.

    Fix vs. the original: resume loading joins the checkpoint filename with
    MODEL_SAVE_DIR — the original passed the bare `os.listdir` name, so
    `load_models` searched the current working directory and resume failed
    whenever CWD was not the model directory.
    """
    print(f"Current working directory: {os.getcwd()}")
    print("Initializing environment and agent...")

    render_mode = 'human' if HYPERPARAMETERS['render_training'] else None

    env = SnakeGameEnv(render_mode=render_mode)

    if HYPERPARAMETERS['render_training'] and HYPERPARAMETERS['render_fps_limit'] > 0:
        env.metadata["render_fps"] = HYPERPARAMETERS['render_fps_limit']

    obs_shape = env.observation_space.shape
    action_size = env.action_space.n

    agent = PPOAgent(
        observation_space_shape=obs_shape,
        action_space_size=action_size,
        actor_lr=HYPERPARAMETERS['actor_lr'],
        critic_lr=HYPERPARAMETERS['critic_lr'],
        gamma=HYPERPARAMETERS['gamma'],
        gae_lambda=HYPERPARAMETERS['gae_lambda'],
        clip_epsilon=HYPERPARAMETERS['clip_epsilon'],
        num_epochs_per_update=HYPERPARAMETERS['num_epochs_per_update'],
        batch_size=HYPERPARAMETERS['batch_size'],
        hidden_layer_sizes=HYPERPARAMETERS['hidden_layer_sizes']
    )

    total_timesteps_trained = 0
    episode_count = 0
    all_episode_rewards = []
    plot_rewards_history = []
    last_saved_timesteps = 0

    # --- Resume Training Logic ---
    if HYPERPARAMETERS['resume_training']:
        print("Attempting to resume training...")
        latest_checkpoint = None
        # Checkpoints are named "ppo_snake_<timestep>_actor.keras"; parse
        # the timestep and keep the newest. Names like "ppo_snake_final"
        # raise ValueError on int() and are skipped.
        for f in os.listdir(MODEL_SAVE_DIR):
            if f.endswith('_actor.keras'):
                try:
                    timestep_str = f.split('_')[-2]
                    timestep = int(timestep_str)
                    if latest_checkpoint is None or timestep > latest_checkpoint[0]:
                        latest_checkpoint = (timestep, f.replace('_actor.keras', ''))
                except ValueError:
                    continue

        if latest_checkpoint:
            print(f"Found latest checkpoint: {latest_checkpoint[1]}")
            # FIX: join with MODEL_SAVE_DIR — latest_checkpoint[1] is a bare
            # filename from os.listdir, not a path.
            if agent.load_models(os.path.join(MODEL_SAVE_DIR, latest_checkpoint[1])):
                total_timesteps_trained, episode_count, all_episode_rewards, plot_rewards_history = load_training_state()
                last_saved_timesteps = total_timesteps_trained
                print(f"Resumed from Timestep: {total_timesteps_trained}, Episode: {episode_count}")
            else:
                print("Failed to load models. Starting new training run.")
                HYPERPARAMETERS['resume_training'] = False
        else:
            print("No previous checkpoints found. Starting new training run.")
            HYPERPARAMETERS['resume_training'] = False

    print("Starting training loop...")
    start_time = time.time()

    fig, ax, line = init_live_plot(PLOT_SAVE_DIR, filename="live_ppo_training_progress.png")
    # When resuming, seed the live plot with the restored reward history.
    if HYPERPARAMETERS['resume_training'] and len(plot_rewards_history) > 1:
        episodes_for_plot = [i * HYPERPARAMETERS['log_interval_episodes'] for i in range(len(plot_rewards_history))]
        smoothed_rewards = smooth_curve(plot_rewards_history, factor=HYPERPARAMETERS['plot_smoothing_factor'])
        update_live_plot(fig, ax, line, episodes_for_plot, smoothed_rewards,
                         current_timestep=total_timesteps_trained,
                         total_timesteps=HYPERPARAMETERS['total_timesteps'])

    while total_timesteps_trained < HYPERPARAMETERS['total_timesteps']:
        current_rollout_steps = 0

        # Collect one rollout; episodes may span or be cut by the rollout
        # boundary (the step loop below also checks the step budget).
        while current_rollout_steps < HYPERPARAMETERS['num_steps_per_rollout'] and \
              total_timesteps_trained + current_rollout_steps < HYPERPARAMETERS['total_timesteps']:

            state, info = env.reset()
            current_action_mask = info['action_mask']

            done = False
            current_episode_reward = 0

            while not done and current_rollout_steps < HYPERPARAMETERS['num_steps_per_rollout'] and \
                  total_timesteps_trained + current_rollout_steps < HYPERPARAMETERS['total_timesteps']:

                action, log_prob, value = agent.choose_action(state, current_action_mask)

                next_state, reward, terminated, truncated, info = env.step(action)
                current_episode_reward += reward

                next_action_mask = info['action_mask']

                # Store the transition with the mask that was active when
                # the action was chosen (needed to recompute log-probs).
                agent.remember(state, action, reward, next_state, terminated, log_prob, value, current_action_mask)

                state = next_state
                current_action_mask = next_action_mask

                current_rollout_steps += 1

                done = terminated or truncated

                if done:
                    episode_count += 1
                    all_episode_rewards.append(current_episode_reward)

                    if episode_count % HYPERPARAMETERS['log_interval_episodes'] == 0:
                        avg_reward_last_n_episodes = np.mean(all_episode_rewards[-HYPERPARAMETERS['log_interval_episodes']:]).round(2)
                        plot_rewards_history.append(avg_reward_last_n_episodes)

                        elapsed_time = time.time() - start_time
                        print(f"Timestep: {total_timesteps_trained + current_rollout_steps}/{HYPERPARAMETERS['total_timesteps']} | "
                              f"Episode: {episode_count} | "
                              f"Avg Reward (last {HYPERPARAMETERS['log_interval_episodes']}): {avg_reward_last_n_episodes} | "
                              f"Total Score (this ep): {info['score']} | "
                              f"Time: {elapsed_time:.2f}s")

                    if episode_count % HYPERPARAMETERS['live_plot_interval_episodes'] == 0:
                        if len(plot_rewards_history) > 1:
                            episodes_for_plot = [i * HYPERPARAMETERS['log_interval_episodes'] for i in range(len(plot_rewards_history))]
                            smoothed_rewards = smooth_curve(plot_rewards_history, factor=HYPERPARAMETERS['plot_smoothing_factor'])
                            update_live_plot(fig, ax, line, episodes_for_plot, smoothed_rewards,
                                             current_timestep=total_timesteps_trained + current_rollout_steps,
                                             total_timesteps=HYPERPARAMETERS['total_timesteps'])

                    if HYPERPARAMETERS['render_training'] and done:
                        time.sleep(0.5)  # brief pause so the terminal frame is visible

                    break

        total_timesteps_trained += current_rollout_steps

        if len(agent.states) > 0:
            print(f" --- Agent learning at Total Timestep {total_timesteps_trained} (collected {len(agent.states)} steps in rollout) ---")
            agent.learn()
        else:
            print(f" --- No data collected in current rollout, skipping learning ---")

        # Periodic checkpoint: fires once each time the step count crosses a
        # new multiple of save_interval_timesteps.
        if total_timesteps_trained >= HYPERPARAMETERS['save_interval_timesteps'] and \
           (total_timesteps_trained // HYPERPARAMETERS['save_interval_timesteps']) > \
           (last_saved_timesteps // HYPERPARAMETERS['save_interval_timesteps']):

            save_path_timesteps = (total_timesteps_trained // HYPERPARAMETERS['save_interval_timesteps']) * HYPERPARAMETERS['save_interval_timesteps']
            print(f"--- Triggering periodic save at calculated timestep: {save_path_timesteps} ---")
            agent.save_models(os.path.join(MODEL_SAVE_DIR, f"ppo_snake_{save_path_timesteps}"))
            save_training_state(total_timesteps_trained, episode_count, all_episode_rewards, plot_rewards_history)
            last_saved_timesteps = save_path_timesteps

    print("\nTraining finished!")
    print(f"--- Triggering final save at total_timesteps: {total_timesteps_trained} ---")
    agent.save_models(os.path.join(MODEL_SAVE_DIR, "ppo_snake_final"))
    save_training_state(total_timesteps_trained, episode_count, all_episode_rewards, plot_rewards_history)

    env.close()

    print("Generating final performance plot...")
    episodes_for_plot = [i * HYPERPARAMETERS['log_interval_episodes'] for i in range(len(plot_rewards_history))]
    smoothed_rewards = smooth_curve(plot_rewards_history, factor=HYPERPARAMETERS['plot_smoothing_factor'])
    update_live_plot(fig, ax, line, episodes_for_plot, smoothed_rewards,
                     current_timestep=total_timesteps_trained,
                     total_timesteps=HYPERPARAMETERS['total_timesteps'])
    save_live_plot_final(fig, ax)

    plot_rewards(smoothed_rewards, HYPERPARAMETERS['log_interval_episodes'], PLOT_SAVE_DIR, "ppo_training_progress_final.png", show_plot=False)
|
| 229 |
+
|
| 230 |
+
# Script entry point: start (or resume) PPO training when run directly.
if __name__ == "__main__":
    train_agent()
|
Snake_EnvAndAgent.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
from gymnasium import spaces
|
| 3 |
+
import random
|
| 4 |
+
import pygame
|
| 5 |
+
import numpy as np
|
| 6 |
+
import collections
|
| 7 |
+
from collections import deque
|
| 8 |
+
from Environment_Constants import (
|
| 9 |
+
GRID_SIZE, CELL_SIZE, SCREEN_WIDTH, SCREEN_HEIGHT,
|
| 10 |
+
WHITE, BLACK, GREEN, RED, BLUE,
|
| 11 |
+
UP, DOWN, LEFT, RIGHT,
|
| 12 |
+
REWARD_FOOD, REWARD_COLLISION, REWARD_STEP,
|
| 13 |
+
FPS, OBSERVATION_SPACE_SIZE
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
class SnakeGameEnv(gym.Env):
    """Gymnasium environment for the classic Snake game on a square grid.

    Actions are *relative* to the snake's current heading:
    0 = keep going straight, 1 = turn right, 2 = turn left.

    The observation is a binary float32 vector of length
    OBSERVATION_SPACE_SIZE laid out as:
    [0..2]  danger straight / right / left (next cell is wall or body),
    [3..6]  food direction flags (see ``_get_observation``),
    [7..10] one-hot current heading (UP, DOWN, LEFT, RIGHT).

    UP/DOWN/LEFT/RIGHT are (dx, dy) unit vectors imported from
    Environment_Constants (not shown here).  The initial body is laid out at
    (hx, hy+1), (hx, hy+2) while heading UP, so UP presumably decreases y
    (screen coordinates) -- TODO confirm against Environment_Constants.
    """

    metadata = {'render_modes': ['human', 'rgb_array'], 'render_fps': FPS}

    def __init__(self, render_mode=None):
        """Create the environment.

        Args:
            render_mode: None, 'human' (pygame window, throttled to
                ``render_fps``) or 'rgb_array' (frame returned as ndarray).
        """
        super().__init__()
        self.grid_size = GRID_SIZE
        self.cell_size = CELL_SIZE
        self.screen_width = SCREEN_WIDTH
        self.screen_height = SCREEN_HEIGHT

        # Three relative actions: straight / turn right / turn left.
        self.action_space = spaces.Discrete(3)

        self.observation_space = spaces.Box(low=0, high=1,
                                            shape=(OBSERVATION_SPACE_SIZE,),
                                            dtype=np.float32)

        self.render_mode = render_mode
        # Lazily created by _render_frame on first human-mode render.
        self.window = None
        self.clock = None

        self._init_game_state()

    def _init_game_state(self):
        """(Re)initialize snake, food, score and bookkeeping for a new episode."""
        # Deque of (x, y) cells; index 0 is the head, the last entry is the tail.
        self.snake = deque()
        self.head = (self.grid_size // 2, self.grid_size // 2)
        self.snake.append(self.head)
        # Two body segments behind the head (in +y), matching the initial UP heading.
        self.snake.append((self.head[0], self.head[1] + 1))
        self.snake.append((self.head[0], self.head[1] + 2))

        self.direction = UP
        self.score = 0
        self.food = self._place_food()
        self.game_over = False
        # Step counter used for the anti-stalling (starvation) cutoff in step().
        self.steps_since_food = 0
        self.length = len(self.snake)


    def _place_food(self):
        """Return a uniformly random empty cell for the food.

        NOTE(review): uses the module-level ``random``, not ``self.np_random``,
        so the ``seed`` passed to reset() does not control food placement --
        confirm whether reproducible episodes are required.
        """
        while True:
            x = random.randrange(self.grid_size)
            y = random.randrange(self.grid_size)
            food_pos = (x, y)
            if food_pos not in self.snake:
                return food_pos

    def _is_position_safe_for_observation(self, pos):
        """True if ``pos`` is inside the grid and not on the snake body.

        The head (index 0) is excluded from the body check; unlike the action
        mask below, the tail is treated as unsafe even though it may move away.
        """
        px, py = pos
        if not (0 <= px < self.grid_size and 0 <= py < self.grid_size):
            return False
        if pos in list(self.snake)[1:]:
            return False
        return True

    def _get_observation(self):
        """Build the 11-flag binary observation vector (see class docstring)."""
        obs = np.zeros(OBSERVATION_SPACE_SIZE, dtype=np.float32)

        hx, hy = self.head

        # Resolve the absolute directions corresponding to the relative
        # "straight / right / left" moves for the current heading.
        if self.direction == UP:
            dir_straight = UP
            dir_right = RIGHT
            dir_left = LEFT
        elif self.direction == DOWN:
            dir_straight = DOWN
            dir_right = LEFT
            dir_left = RIGHT
        elif self.direction == LEFT:
            dir_straight = LEFT
            dir_right = UP
            dir_left = DOWN
        elif self.direction == RIGHT:
            dir_straight = RIGHT
            dir_right = DOWN
            dir_left = UP

        check_pos_straight = (hx + dir_straight[0], hy + dir_straight[1])
        check_pos_right = (hx + dir_right[0], hy + dir_right[1])
        check_pos_left = (hx + dir_left[0], hy + dir_left[1])

        # Danger flags: 1 when the adjacent cell is NOT safe.
        obs[0] = 1 if not self._is_position_safe_for_observation(check_pos_straight) else 0
        obs[1] = 1 if not self._is_position_safe_for_observation(check_pos_right) else 0
        obs[2] = 1 if not self._is_position_safe_for_observation(check_pos_left) else 0

        # Food direction flags relative to the head (axis-aligned comparisons;
        # obs[3]/obs[4] are smaller-y / larger-y, obs[5]/obs[6] smaller-x / larger-x).
        fx, fy = self.food
        if fy < hy: obs[3] = 1
        if fy > hy: obs[4] = 1
        if fx < hx: obs[5] = 1
        if fx > hx: obs[6] = 1

        # One-hot encoding of the current heading.
        if self.direction == UP: obs[7] = 1
        elif self.direction == DOWN: obs[8] = 1
        elif self.direction == LEFT: obs[9] = 1
        elif self.direction == RIGHT: obs[10] = 1

        return obs

    def _get_action_mask(self):
        """Return a bool[3] mask of legal relative actions [straight, right, left].

        A move is illegal when it leaves the grid or hits the body; moving into
        the tail cell is also treated as illegal unless the tail cell holds the
        food (same rule as the collision checks in step()).
        """
        mask = np.array([True, True, True], dtype=bool)
        hx, hy = self.head

        # potential_directions[i] is the absolute direction taken by action i.
        potential_directions = [
            self.direction,
            None,
            None
        ]

        if self.direction == UP:
            potential_directions[1] = RIGHT
            potential_directions[2] = LEFT
        elif self.direction == DOWN:
            potential_directions[1] = LEFT
            potential_directions[2] = RIGHT
        elif self.direction == LEFT:
            potential_directions[1] = UP
            potential_directions[2] = DOWN
        elif self.direction == RIGHT:
            potential_directions[1] = DOWN
            potential_directions[2] = UP

        def _is_potential_move_illegal(pos_to_check, current_snake, food_pos):
            # Out of bounds.
            if not (0 <= pos_to_check[0] < self.grid_size and 0 <= pos_to_check[1] < self.grid_size):
                return True

            # Collides with any segment except the tail (the tail normally moves).
            if pos_to_check in list(current_snake)[:-1]:
                return True

            # Tail cell: only legal if the food is there (food is never on the
            # snake, so in practice this branch marks tail moves illegal).
            if pos_to_check == current_snake[-1]:
                if pos_to_check != food_pos:
                    return True


            return False

        for action_idx, new_dir in enumerate(potential_directions):
            dx, dy = new_dir
            potential_head = (hx + dx, hy + dy)
            if _is_potential_move_illegal(potential_head, self.snake, self.food):
                mask[action_idx] = False

        # Fallback: never hand the agent an all-False mask, even in a dead end.
        if not np.any(mask):
            print(f"Warning: All actions masked out at head {self.head}, direction {self.direction}, food {self.food}. Attempting to find a fallback action.")
            found_fallback = False
            for i in range(3): # Check Straight, Right, Left
                dx, dy = potential_directions[i]
                potential_head = (hx + dx, hy + dy)
                if not _is_potential_move_illegal(potential_head, self.snake, self.food):
                    mask[i] = True
                    found_fallback = True

            if not found_fallback:
                # Truly boxed in: enable a random action so the agent can still act
                # (the resulting move will terminate the episode via step()).
                mask[np.random.choice(3)] = True
                print("Critical Warning: No legal actions found even after fallback logic. Enabling a random action to prevent deadlock.")

        return mask

    def reset(self, seed=None, options=None):
        """Start a new episode; returns (observation, info).

        NOTE(review): ``seed`` seeds gym's ``self.np_random`` via super(),
        but food placement uses the module-level ``random`` (see _place_food),
        so episodes are not reproducible from this seed alone.
        """
        super().reset(seed=seed)
        self._init_game_state()
        observation = self._get_observation()
        info = self._get_info()

        if not np.any(info['action_mask']):
            print("Warning: No valid actions found in initial reset state.")

        if self.render_mode == 'human':
            self._render_frame()
        return observation, info

    def _get_info(self):
        """Returns environment information, including the action mask."""
        return {
            "score": self.score,
            "snake_length": len(self.snake),
            "action_mask": self._get_action_mask()
        }

    def step(self, action):
        """Advance one tick; returns (obs, reward, terminated, truncated, info)."""
        # Translate the relative action into an absolute heading.
        new_direction = self.direction

        if action == 1:
            if self.direction == UP: new_direction = RIGHT
            elif self.direction == DOWN: new_direction = LEFT
            elif self.direction == LEFT: new_direction = UP
            elif self.direction == RIGHT: new_direction = DOWN
        elif action == 2:
            if self.direction == UP: new_direction = LEFT
            elif self.direction == DOWN: new_direction = RIGHT
            elif self.direction == LEFT: new_direction = DOWN
            elif self.direction == RIGHT: new_direction = UP
        elif action != 0:
            raise ValueError(f"Received invalid action={action} which is not part of the action space.")

        self.direction = new_direction

        hx, hy = self.head
        dx, dy = self.direction
        new_head = (hx + dx, hy + dy)

        reward = REWARD_STEP
        terminated = False
        truncated = False
        # Wall collision.
        if not (0 <= new_head[0] < self.grid_size and 0 <= new_head[1] < self.grid_size):
            terminated = True
            reward = REWARD_COLLISION

        # Body collision (tail excluded; handled separately below).
        elif new_head in list(self.snake)[:-1]:
            terminated = True
            reward = REWARD_COLLISION
        # Moving onto the tail also ends the episode unless the food sits there
        # (consistent with the action-mask rule in _get_action_mask).
        elif new_head == self.snake[-1] and new_head != self.food:
            terminated = True
            reward = REWARD_COLLISION

        if terminated:
            self.game_over = True
        else:
            # Advance: grow at the head; the tail is popped below unless we ate.
            self.snake.appendleft(new_head)
            self.head = new_head

            if new_head == self.food:
                self.score += 1
                self.length += 1
                reward = REWARD_FOOD
                self.food = self._place_food()
                self.steps_since_food = 0
            else:
                self.snake.pop()
                self.steps_since_food += 1

            # Anti-stalling cutoff: end the episode after ~1.5 grid-areas of
            # steps without eating.
            # NOTE(review): both flags are set here; Gymnasium convention is
            # that a time-limit cutoff sets only ``truncated`` -- confirm the
            # trainer's GAE/bootstrapping expects this.
            if self.steps_since_food >= self.grid_size * self.grid_size * 1.5:
                terminated = True
                truncated = True
                reward = REWARD_COLLISION


        observation = self._get_observation()
        info = self._get_info()

        if self.render_mode == 'human':
            self._render_frame()

        return observation, reward, terminated, truncated, info

    def _render_frame(self):
        """Draw the current state; returns an RGB ndarray in 'rgb_array' mode."""
        # Lazy pygame/window initialization on first human-mode render.
        if self.window is None and self.render_mode == 'human':
            pygame.init()
            pygame.display.init()
            self.window = pygame.display.set_mode((self.screen_width, self.screen_height))
            pygame.display.set_caption("Snake AI Training")
        if self.clock is None and self.render_mode == 'human':
            self.clock = pygame.time.Clock()

        if self.render_mode == 'human':
            self.window.fill(BLACK)

            # Food cell.
            pygame.draw.rect(self.window, RED, (self.food[0] * self.cell_size,
                                                self.food[1] * self.cell_size,
                                                self.cell_size, self.cell_size))

            # Snake: head in BLUE, body in GREEN.
            for i, segment in enumerate(self.snake):
                color = BLUE if i == 0 else GREEN
                pygame.draw.rect(self.window, color, (segment[0] * self.cell_size,
                                                      segment[1] * self.cell_size,
                                                      self.cell_size, self.cell_size))

            # Grid lines.
            for x in range(0, self.screen_width, self.cell_size):
                pygame.draw.line(self.window, WHITE, (x, 0), (x, self.screen_height))
            for y in range(0, self.screen_height, self.cell_size):
                pygame.draw.line(self.window, WHITE, (0, y), (self.screen_width, y))

            font = pygame.font.Font(None, 25)
            text = font.render(f"Score: {self.score}", True, WHITE)
            self.window.blit(text, (5, 5))

            pygame.event.pump()
            pygame.display.flip()
            # Throttle to the configured frame rate.
            self.clock.tick(self.metadata["render_fps"])
        elif self.render_mode == "rgb_array":
            # Off-screen render (no grid lines / score text in this mode).
            surf = pygame.Surface((self.screen_width, self.screen_height))
            surf.fill(BLACK)
            pygame.draw.rect(surf, RED, (self.food[0] * self.cell_size,
                                         self.food[1] * self.cell_size,
                                         self.cell_size, self.cell_size))
            for i, segment in enumerate(self.snake):
                color = BLUE if i == 0 else GREEN
                pygame.draw.rect(surf, color, (segment[0] * self.cell_size,
                                               segment[1] * self.cell_size,
                                               self.cell_size, self.cell_size))
            # pygame surfaces are (width, height); transpose to (H, W, 3).
            return np.transpose(np.array(pygame.surfarray.pixels3d(surf)), axes=(1, 0, 2))

    def close(self):
        """Shut down the pygame window and release display resources."""
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()
            self.window = None
            self.clock = None
|
Trained_PPO_Agent.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
from Snake_EnvAndAgent import SnakeGameEnv
|
| 3 |
+
from PPO_Model import PPOAgent
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
import numpy as np
|
| 7 |
+
from plot_utility_Trained_Agent import init_live_plot, update_live_plot, save_live_plot_final, smooth_curve
|
| 8 |
+
|
| 9 |
+
# Playback configuration for the trained agent.
# NOTE(review): 'grid_size' is defined here but never read in this script --
# the environment takes its size from Environment_Constants; confirm which
# value is authoritative.
PLAY_CONFIG = {
    'grid_size': 30,
    'model_path_prefix': 'snake_ppo_models/ppo_snake_final',  # prefix passed to PPOAgent.load_models
    'num_episodes_to_play': 100,
    'render_fps': 10,  # when > 0, overrides env.metadata["render_fps"]
    'live_plot_interval_episodes': 1,  # refresh the live plot every N episodes
    'plot_smoothing_factor': 0.8  # exponential smoothing factor for the reward curve
}

# Directory for playback plots; created eagerly at import time.
PLOT_SAVE_DIR = 'snake_ppo_plots'
os.makedirs(PLOT_SAVE_DIR, exist_ok=True)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def play_agent():
    """Load the trained PPO actor/critic and play Snake with live rendering.

    Runs ``PLAY_CONFIG['num_episodes_to_play']`` greedy-playback episodes,
    printing per-episode reward/score and updating a live matplotlib plot of
    (smoothed) episode rewards.  Returns early if the model files cannot be
    loaded.  No return value; side effects are the pygame window, stdout
    logging and the saved plot file.
    """
    print("Initializing environment for playback...")
    env = SnakeGameEnv(render_mode='human')

    # Override the environment's frame rate for human viewing.
    if PLAY_CONFIG['render_fps'] > 0:
        env.metadata["render_fps"] = PLAY_CONFIG['render_fps']

    obs_shape = env.observation_space.shape
    action_size = env.action_space.n

    # Hyperparameters here only shape the networks so the saved weights fit;
    # the learning rates are unused during playback.
    # NOTE(review): hidden_layer_sizes must match the architecture used in
    # training or load_models will fail -- confirm against PPO_Trainer.
    agent = PPOAgent(
        observation_space_shape=obs_shape,
        action_space_size=action_size,
        actor_lr=3e-4,
        critic_lr=3e-4,
        hidden_layer_sizes=[512, 512, 512]
    )

    print(f"loading models from: {PLAY_CONFIG['model_path_prefix']}")

    load_success = agent.load_models(PLAY_CONFIG['model_path_prefix'])

    if not load_success:
        print("\nFATAL ERROR: Failed to load trained models from disk. The agent CANNOT perform as trained. Exiting playback!.")
        env.close()
        return

    print("--- Trained models loaded successfully. Loading playback. ---")

    print("Starting agent playback...")

    episode_rewards_playback = []

    # Live reward plot; the title/labels set in init_live_plot are overridden
    # below for playback-specific wording.
    fig, ax, line = init_live_plot(PLOT_SAVE_DIR, filename="live_playback_rewards_plot.png")
    ax.set_title('Live Playback Progress (Episode Rewards)')
    ax.set_xlabel('Episode')
    ax.set_ylabel('Total Reward')

    for episode in range(PLAY_CONFIG['num_episodes_to_play']):
        state, info = env.reset()
        current_action_mask = info['action_mask']

        done = False
        episode_reward = 0
        steps = 0

        while not done:
            # Only the chosen action is needed for playback; log-prob and
            # value estimates are discarded.
            action, _, _ = agent.choose_action(state, current_action_mask)

            next_state, reward, terminated, truncated, info = env.step(action)
            episode_reward += reward
            state = next_state
            steps += 1
            done = terminated or truncated

            current_action_mask = info['action_mask']

            # Slow the loop to roughly real time.
            # NOTE(review): env.step already throttles via clock.tick in human
            # mode, so this sleep halves the effective frame rate -- confirm
            # the double throttle is intended.
            if PLAY_CONFIG['render_fps'] > 0:
                time.sleep(1 / env.metadata["render_fps"])

        episode_rewards_playback.append(episode_reward)

        print(f"Episode {episode + 1}: Total Reward = {episode_reward:.2f}, Score = {info['score']}, Steps = {steps}")

        # Periodically refresh the live plot with the smoothed reward curve.
        if (episode + 1) % PLAY_CONFIG['live_plot_interval_episodes'] == 0:
            current_episodes = list(range(1, len(episode_rewards_playback) + 1))
            smoothed_rewards = smooth_curve(episode_rewards_playback, factor=PLAY_CONFIG['plot_smoothing_factor'])
            update_live_plot(fig, ax, line, current_episodes, smoothed_rewards,
                             current_timestep=episode + 1,
                             total_timesteps=PLAY_CONFIG['num_episodes_to_play'])

        # Brief pause between episodes so the final frame stays visible.
        time.sleep(0.5)

    env.close()
    print("\nPlayback finished.")

    # Final plot refresh and save-to-disk.
    current_episodes = list(range(1, len(episode_rewards_playback) + 1))
    smoothed_rewards = smooth_curve(episode_rewards_playback, factor=PLAY_CONFIG['plot_smoothing_factor'])
    update_live_plot(fig, ax, line, current_episodes, smoothed_rewards,
                     current_timestep=PLAY_CONFIG['num_episodes_to_play'],
                     total_timesteps=PLAY_CONFIG['num_episodes_to_play'])
    save_live_plot_final(fig, ax)

    avg_playback_reward = np.mean(episode_rewards_playback)
    print(f"Average Reward over {PLAY_CONFIG['num_episodes_to_play']} playback episodes: {avg_playback_reward:.2f}")
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# Script entry point: run the playback loop only when executed directly.
if __name__ == "__main__":
    play_agent()
|
plot_utility_Trained_Agent.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
def smooth_curve(points, factor=0.9):
|
| 7 |
+
smoothed_points = []
|
| 8 |
+
if points:
|
| 9 |
+
smoothed_points.append(points[0])
|
| 10 |
+
for i in range(1, len(points)):
|
| 11 |
+
smoothed_points.append(smoothed_points[-1] * factor + points[i] * (1 - factor))
|
| 12 |
+
return smoothed_points
|
| 13 |
+
|
| 14 |
+
def plot_rewards(rewards_history, log_interval, save_dir, filename="Trained_Agent_rewards_plot.png", show_plot=True):
    """Save (and optionally display) a static average-reward-vs-episode plot.

    Args:
        rewards_history: one averaged reward per logging interval.
        log_interval: number of episodes per logged point (x-axis scale).
        save_dir: directory for the PNG; created if missing.
        filename: output file name inside ``save_dir``.
        show_plot: when True, also open an interactive window.
    """
    os.makedirs(save_dir, exist_ok=True)

    # x-axis: cumulative episode count at each logged point.
    x_values = [idx * log_interval for idx in range(1, len(rewards_history) + 1)]

    plt.figure(figsize=(12, 6))
    plt.plot(x_values, rewards_history, label='Average Reward')
    plt.xlabel('Episodes')
    plt.ylabel('Average Reward')
    plt.title('Trained Agent Live Reward (Average Reward per Episode)')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()

    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path)
    print(f"Plot saved to: {os.path.abspath(save_path)}")

    if show_plot:
        plt.show()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def init_live_plot(save_dir, filename="live_rewards_plot.png"):
    """Create an interactive reward figure and return ``(fig, ax, line)``.

    The line starts empty; callers feed it data via ``update_live_plot``.
    The final save destination is stashed on the axes so that
    ``save_live_plot_final`` can retrieve it later.
    """
    plt.ion()  # interactive mode so subsequent draws do not block
    fig, ax = plt.subplots(figsize=(12, 6))
    (line,) = ax.plot([], [], label='Smoothed Average Reward')

    ax.set_xlabel('Episodes')
    ax.set_ylabel('Average Reward')
    ax.set_title('Live Reward for Trained Agent')
    ax.grid(True)
    ax.legend()
    plt.tight_layout()

    # Remember where the finished plot should be written.
    ax._save_path_final = os.path.join(save_dir, filename)

    return fig, ax, line
|
| 51 |
+
|
| 52 |
+
def update_live_plot(fig, ax, line, episodes, smoothed_rewards, current_timestep=None, total_timesteps=None):
    """Refresh the live reward plot with the latest smoothed data.

    Args:
        fig, ax, line: objects returned by ``init_live_plot``.
        episodes: x values (episode numbers); no-op when empty.
        smoothed_rewards: y values, same length as ``episodes``.
        current_timestep, total_timesteps: when both given, shown in the title.
    """
    if not episodes or not smoothed_rewards:
        return

    line.set_data(episodes, smoothed_rewards)

    ax.set_xlim(0, max(episodes) * 1.05)

    # Pad the y-range by 5% of the data span (minimum 0.05 so a flat curve
    # still has visible headroom).  The previous code scaled min/max by
    # 0.9/1.1, which clips the curve whenever rewards are negative
    # (e.g. a minimum of -10 became -9, hiding the lowest points).
    lo = min(smoothed_rewards)
    hi = max(smoothed_rewards)
    pad = max((hi - lo) * 0.05, 0.05)
    ax.set_ylim(lo - pad, hi + pad)

    if current_timestep is not None and total_timesteps is not None:
        ax.set_title(f'Live Agent Progress (Timestep: {current_timestep:,}/{total_timesteps:,})')

    # Push the changes to the screen without blocking the playback loop.
    fig.canvas.draw()
    fig.canvas.flush_events()
    time.sleep(0.01)
|
| 75 |
+
|
| 76 |
+
def save_live_plot_final(fig, ax):
    """Save the live plot to its recorded path and close the figure.

    The destination path was stashed on ``ax._save_path_final`` by
    ``init_live_plot``; nothing is written if that attribute is absent.
    """
    plt.ioff()
    save_path = getattr(ax, '_save_path_final', None)
    if save_path:
        plt.savefig(save_path)
        print(f"Final live plot saved to: {os.path.abspath(save_path)}")
    plt.close(fig)
    # Kept from the original flow; with ``fig`` already closed this shows any
    # remaining figures (presumably a no-op when none exist -- confirm).
    plt.show()
|
plot_utility_Trainer.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
def smooth_curve(points, factor=0.9):
|
| 7 |
+
|
| 8 |
+
smoothed_points = []
|
| 9 |
+
if points:
|
| 10 |
+
smoothed_points.append(points[0])
|
| 11 |
+
for i in range(1, len(points)):
|
| 12 |
+
smoothed_points.append(smoothed_points[-1] * factor + points[i] * (1 - factor))
|
| 13 |
+
return smoothed_points
|
| 14 |
+
|
| 15 |
+
def plot_rewards(rewards_history, log_interval, save_dir, filename="rewards_plot.png", show_plot=True):
    """Save (and optionally display) a static training-reward plot.

    Args:
        rewards_history: one averaged reward per logging interval.
        log_interval: episodes per logged point; scales the x-axis.
        save_dir: output directory, created if missing.
        filename: PNG file name inside ``save_dir``.
        show_plot: when True, also open an interactive window (blocking).
    """
    os.makedirs(save_dir, exist_ok=True)

    plt.figure(figsize=(12, 6))
    # x-axis: cumulative episode count at each logged point.
    episodes = [i * log_interval for i in range(1, len(rewards_history) + 1)]
    plt.plot(episodes, rewards_history, label='Average Reward')
    plt.xlabel('Episodes')
    plt.ylabel('Average Reward')
    plt.title('PPO Training Progress (Average Reward per Episode)')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()

    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path)
    print(f"Plot saved to: {os.path.abspath(save_path)}")

    if show_plot:
        plt.show()
|
| 35 |
+
|
| 36 |
+
def init_live_plot(save_dir, filename="live_rewards_plot.png"):
    """Create an interactive training-progress figure.

    Returns:
        (fig, ax, line): the figure, its axes, and an initially-empty line
        that ``update_live_plot`` fills in during training.
    """
    plt.ion()  # interactive mode: later draws do not block the training loop
    fig, ax = plt.subplots(figsize=(12, 6))
    line, = ax.plot([], [], label='Smoothed Average Reward')
    ax.set_xlabel('Episodes')
    ax.set_ylabel('Average Reward')
    ax.set_title('Live PPO Training Progress')
    ax.grid(True)
    ax.legend()
    plt.tight_layout()

    # Stash the final save destination on the axes so save_live_plot_final
    # can retrieve it without extra plumbing.
    ax._save_path_final = os.path.join(save_dir, filename)

    return fig, ax, line
|
| 51 |
+
|
| 52 |
+
def update_live_plot(fig, ax, line, episodes, smoothed_rewards, current_timestep=None, total_timesteps=None):
    """Refresh the live training plot with the latest smoothed rewards.

    Args:
        fig, ax, line: objects returned by ``init_live_plot``.
        episodes: x values (episode numbers); no-op when empty.
        smoothed_rewards: y values, same length as ``episodes``.
        current_timestep, total_timesteps: when both given, shown in the title.
    """
    if not episodes or not smoothed_rewards:
        return

    line.set_data(episodes, smoothed_rewards)

    ax.set_xlim(0, max(episodes) * 1.05)

    # Pad the y-range by 5% of the data span (minimum 0.05 so a flat curve
    # still has visible headroom).  The previous code scaled min/max by
    # 0.9/1.1, which clips the curve whenever rewards are negative
    # (e.g. a minimum of -10 became -9, hiding the lowest points).
    lo = min(smoothed_rewards)
    hi = max(smoothed_rewards)
    pad = max((hi - lo) * 0.05, 0.05)
    ax.set_ylim(lo - pad, hi + pad)

    if current_timestep is not None and total_timesteps is not None:
        ax.set_title(f'Live PPO Training Progress (Timestep: {current_timestep:,}/{total_timesteps:,})')

    # Push the changes to the screen without blocking the training loop.
    fig.canvas.draw()
    fig.canvas.flush_events()
    time.sleep(0.01)
|
| 75 |
+
|
| 76 |
+
def save_live_plot_final(fig, ax):
    """Save the live training plot to its recorded path and close the figure.

    The destination path was stored on ``ax._save_path_final`` by
    ``init_live_plot``; nothing is written if that attribute is missing.
    """
    plt.ioff()
    save_path = getattr(ax, '_save_path_final', None)
    if save_path:
        plt.savefig(save_path)
        print(f"Final live plot saved to: {os.path.abspath(save_path)}")
    plt.close(fig)
    # NOTE(review): show() is called after the figure is closed; presumably a
    # no-op when no other figures remain open -- confirm this is intended.
    plt.show()
|