privateboss committed on
Commit
2df2f26
·
verified ·
1 Parent(s): 8ef74a0

Upload 8 files

Browse files
Dumb_Agent.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import gymnasium as gym
from Snake_EnvAndAgent import SnakeGameEnv
import pygame
import time

if __name__ == "__main__":
    # Smoke-test driver: runs a uniformly random policy in the Snake
    # environment for a handful of episodes and reports basic stats.
    env = SnakeGameEnv(render_mode='human')

    episodes = 5
    for episode in range(episodes):
        obs, info = env.reset()
        total_reward = 0
        steps = 0

        print(f"--- Starting Episode {episode + 1} ---")

        while True:
            # No learned policy here — just sample a random legal action id.
            action = env.action_space.sample()
            obs, reward, terminated, truncated, info = env.step(action)

            total_reward += reward
            steps += 1

            if terminated or truncated:
                break

        print(f"Episode {episode + 1} finished in {steps} steps with total reward: {total_reward:.2f}")
        print(f"Final Score: {info['score']}")

    env.close()
    print("Environment test finished.")
Environment_Constants.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
GRID_SIZE = 30   # Number of cells per side of the square play field.
CELL_SIZE = 30   # Pixel size of one grid cell when rendering.

SCREEN_WIDTH = GRID_SIZE * CELL_SIZE
SCREEN_HEIGHT = GRID_SIZE * CELL_SIZE

# RGB colour triples used by the pygame renderer.
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
GREEN = (0, 255, 0)
RED = (255, 0, 0)
BLUE = (0, 0, 255)

# Direction vectors as (dx, dy) grid offsets; y grows downward on screen.
UP = (0, -1)
DOWN = (0, 1)
LEFT = (-1, 0)
RIGHT = (1, 0)

FPS = 10  # Render frame rate cap for human mode.

# Reward-shaping constants for the RL environment.
REWARD_FOOD = 60        # Eating a food pellet.
REWARD_COLLISION = -60  # Hitting a wall or the snake's own body.
REWARD_STEP = -0.1      # Small per-step penalty to discourage stalling.
OBSERVATION_SPACE_SIZE = 11  # Length of the flat binary observation vector.
PPO_Model.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import keras
3
+ from keras import layers, Model
4
+ import numpy as np
5
+ import tensorflow_probability as tfp
6
+ import os
7
+ import traceback
8
+
9
+ tfd = tfp.distributions
10
+
11
@tf.keras.utils.register_keras_serializable()
class Actor(Model):
    """Policy network: maps observations to unnormalized action logits.

    Args:
        obs_shape: Shape of a single observation (no batch dimension).
        action_size: Number of discrete actions (width of the logits layer).
        hidden_layer_sizes: Widths of the fully-connected ReLU hidden layers.
            Default is an immutable tuple (a mutable list default would be
            shared across instances).
    """

    def __init__(self, obs_shape, action_size, hidden_layer_sizes=(512, 512, 512), **kwargs):
        super().__init__(**kwargs)
        # Flatten only when the observation is multi-dimensional. The dummy
        # call builds the layer eagerly so weights exist before save/load.
        # tuple(obs_shape) guards against obs_shape arriving as a list after
        # a get_config()/JSON round-trip, where `(1,) + obs_shape` would fail.
        if len(obs_shape) > 1:
            self.flatten = layers.Flatten(input_shape=obs_shape)
            self.flatten(tf.zeros((1,) + tuple(obs_shape)))
        else:
            self.flatten = None

        self.dense_layers = []
        for size in hidden_layer_sizes:
            self.dense_layers.append(layers.Dense(size, activation='relu'))
        self.logits = layers.Dense(action_size)

        # Kept for get_config() so the model round-trips through Keras
        # serialization; stored as a list to match JSON-serialized configs.
        self._obs_shape = obs_shape
        self._action_size = action_size
        self._hidden_layer_sizes = list(hidden_layer_sizes)

    def call(self, inputs):
        """Forward pass. Returns raw (unmasked) logits, shape (batch, action_size)."""
        x = self.flatten(inputs) if self.flatten else inputs
        for layer in self.dense_layers:
            x = layer(x)
        return self.logits(x)

    def get_config(self):
        """Serialization support for tf.keras.models.load_model."""
        config = super().get_config()
        config.update({
            'obs_shape': self._obs_shape,
            'action_size': self._action_size,
            'hidden_layer_sizes': self._hidden_layer_sizes
        })
        return config
44
+
45
@tf.keras.utils.register_keras_serializable()
class Critic(Model):
    """Value network: maps observations to a scalar state-value estimate.

    Args:
        obs_shape: Shape of a single observation (no batch dimension).
        hidden_layer_sizes: Widths of the fully-connected ReLU hidden layers.
            Default is an immutable tuple (a mutable list default would be
            shared across instances).
    """

    def __init__(self, obs_shape, hidden_layer_sizes=(512, 512, 512), **kwargs):
        super().__init__(**kwargs)
        # Flatten only when the observation is multi-dimensional; the dummy
        # call builds the layer eagerly so weights exist before save/load.
        # tuple(obs_shape) guards against obs_shape arriving as a list after
        # a get_config()/JSON round-trip, where `(1,) + obs_shape` would fail.
        if len(obs_shape) > 1:
            self.flatten = layers.Flatten(input_shape=obs_shape)
            self.flatten(tf.zeros((1,) + tuple(obs_shape)))
        else:
            self.flatten = None

        self.dense_layers = []
        for size in hidden_layer_sizes:
            self.dense_layers.append(layers.Dense(size, activation='relu'))
        self.value = layers.Dense(1)

        # Kept for get_config() so the model round-trips through Keras
        # serialization; stored as a list to match JSON-serialized configs.
        self._obs_shape = obs_shape
        self._hidden_layer_sizes = list(hidden_layer_sizes)

    def call(self, inputs):
        """Forward pass. Returns values of shape (batch, 1)."""
        x = self.flatten(inputs) if self.flatten else inputs
        for layer in self.dense_layers:
            x = layer(x)
        return self.value(x)

    def get_config(self):
        """Serialization support for tf.keras.models.load_model."""
        config = super().get_config()
        config.update({
            'obs_shape': self._obs_shape,
            'hidden_layer_sizes': self._hidden_layer_sizes
        })
        return config
76
+
77
class PPOAgent:
    """Proximal Policy Optimization agent with discrete action masking.

    Owns separate actor/critic networks, an on-policy rollout buffer,
    GAE advantage estimation, and the clipped-surrogate policy update.
    """

    def __init__(self, observation_space_shape, action_space_size,
                 actor_lr=3e-4, critic_lr=3e-4, gamma=0.99,
                 gae_lambda=0.95, clip_epsilon=0.2,
                 num_epochs_per_update=10, batch_size=64,
                 hidden_layer_sizes=(512, 512, 512)):
        # hidden_layer_sizes default is a tuple, not a list, to avoid the
        # shared mutable-default pitfall.
        self.gamma = gamma                # discount factor
        self.gae_lambda = gae_lambda      # GAE bias/variance trade-off
        self.clip_epsilon = clip_epsilon  # PPO surrogate clip range
        self.num_epochs_per_update = num_epochs_per_update
        self.batch_size = batch_size

        self.observation_space_shape = observation_space_shape
        self.action_space_size = action_space_size

        self.actor = Actor(observation_space_shape, action_space_size, hidden_layer_sizes=hidden_layer_sizes)
        self.critic = Critic(observation_space_shape, hidden_layer_sizes=hidden_layer_sizes)

        self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_lr)
        self.critic_optimizer = tf.keras.optimizers.Adam(learning_rate=critic_lr)

        self._clear_memory()

        # One dummy forward pass so both networks are built (weights created)
        # before any save/load happens.
        dummy_obs = tf.zeros((1,) + tuple(observation_space_shape), dtype=tf.float32)
        self.actor(dummy_obs)
        self.critic(dummy_obs)

    def _clear_memory(self):
        """Reset the on-policy rollout buffer (called at init and after learn)."""
        self.states = []
        self.actions = []
        self.rewards = []
        self.next_states = []
        self.dones = []
        self.log_probs = []
        self.values = []
        self.action_masks = []

    def remember(self, state, action, reward, next_state, done, log_prob, value, action_mask):
        """Append one transition (including the mask used to pick the action)."""
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.next_states.append(next_state)
        self.dones.append(done)
        self.log_probs.append(log_prob)
        self.values.append(value)
        self.action_masks.append(action_mask)

    @tf.function
    def _choose_action_tf(self, observation, action_mask):
        """Graph-compiled sampling: masked categorical over actor logits.

        Returns (action, log_prob, value) tensors with a leading batch dim of 1.
        """
        observation = tf.expand_dims(tf.convert_to_tensor(observation, dtype=tf.float32), 0)

        pi_logits = self.actor(observation)

        # Illegal actions get a huge negative logit so their probability ~ 0.
        masked_logits = tf.where(action_mask, pi_logits, -1e9)

        value = self.critic(observation)

        distribution = tfd.Categorical(logits=masked_logits)

        action = distribution.sample()
        log_prob = distribution.log_prob(action)

        return action, log_prob, value

    def choose_action(self, observation, action_mask):
        """Sample an action for one observation.

        Returns (action, log_prob, value); action/log_prob are shape-(1,)
        numpy arrays, value is a python float.
        """
        action_tensor, log_prob_tensor, value_tensor = self._choose_action_tf(observation, tf.constant(action_mask, dtype=tf.bool))
        return action_tensor.numpy(), log_prob_tensor.numpy(), value_tensor.numpy()[0, 0]

    def calculate_advantages_and_returns(self):
        """Compute GAE(lambda) advantages and bootstrapped returns for the buffer."""
        rewards = np.array(self.rewards, dtype=np.float32)
        values = np.array(self.values, dtype=np.float32)
        dones = np.array(self.dones, dtype=np.float32)

        # Bootstrap from the critic only if the rollout did not end terminally.
        last_next_state_value = self.critic(tf.expand_dims(tf.convert_to_tensor(self.next_states[-1], dtype=tf.float32), 0)).numpy()[0, 0] if not dones[-1] else 0
        # V(s_{t+1}) per step; the (1 - dones) factor below zeroes it across
        # episode boundaries, so reusing values[1:] is safe.
        next_values = np.append(values[1:], last_next_state_value)

        advantages = []
        returns = []

        last_advantage = 0
        # Backward recursion. Build reversed lists and flip once at the end
        # (O(n)) instead of insert(0, ...) per step (O(n^2)).
        for t in reversed(range(len(rewards))):
            delta = rewards[t] + self.gamma * next_values[t] * (1 - dones[t]) - values[t]
            advantage = delta + self.gae_lambda * self.gamma * (1 - dones[t]) * last_advantage
            advantages.append(advantage)
            returns.append(advantage + values[t])
            last_advantage = advantage
        advantages.reverse()
        returns.reverse()

        return np.array(advantages, dtype=np.float32), np.array(returns, dtype=np.float32)

    def learn(self):
        """Run num_epochs_per_update PPO epochs over the buffer, then clear it."""
        if not self.states:
            return

        states = tf.convert_to_tensor(np.array(self.states), dtype=tf.float32)
        actions = tf.convert_to_tensor(np.array(self.actions), dtype=tf.int32)
        old_log_probs = tf.convert_to_tensor(np.array(self.log_probs), dtype=tf.float32)

        action_masks = tf.convert_to_tensor(np.array(self.action_masks), dtype=tf.bool)

        advantages, returns = self.calculate_advantages_and_returns()
        # Standardize advantages for stable gradients.
        advantages = (advantages - tf.reduce_mean(advantages)) / (tf.math.reduce_std(advantages) + 1e-8)

        dataset = tf.data.Dataset.from_tensor_slices((states, actions, old_log_probs, advantages, returns, action_masks))
        dataset = dataset.shuffle(buffer_size=len(self.states)).batch(self.batch_size)

        for _ in range(self.num_epochs_per_update):
            for batch_states, batch_actions, batch_old_log_probs, batch_advantages, batch_returns, batch_action_masks in dataset:

                with tf.GradientTape() as tape:
                    current_logits = self.actor(batch_states)

                    # Re-apply the masks recorded at collection time so the
                    # new distribution matches the one actions were drawn from.
                    masked_logits = tf.where(batch_action_masks, current_logits, -1e9)
                    new_distribution = tfd.Categorical(logits=masked_logits)

                    new_log_probs = new_distribution.log_prob(batch_actions)
                    ratio = tf.exp(new_log_probs - batch_old_log_probs)

                    # Clipped surrogate objective (PPO).
                    surrogate1 = ratio * batch_advantages
                    surrogate2 = tf.clip_by_value(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * batch_advantages

                    actor_loss = -tf.reduce_mean(tf.minimum(surrogate1, surrogate2))

                actor_grads = tape.gradient(actor_loss, self.actor.trainable_variables)
                self.actor_optimizer.apply_gradients(zip(actor_grads, self.actor.trainable_variables))

                with tf.GradientTape() as tape:
                    # Critic outputs shape (B, 1); squeeze to (B,) so the
                    # subtraction against (B,) returns does not broadcast to
                    # a (B, B) matrix (which silently corrupted the loss).
                    new_values = tf.squeeze(self.critic(batch_states), axis=-1)
                    critic_loss = tf.reduce_mean(tf.square(new_values - batch_returns))

                critic_grads = tape.gradient(critic_loss, self.critic.trainable_variables)
                self.critic_optimizer.apply_gradients(zip(critic_grads, self.critic.trainable_variables))

        self._clear_memory()

    def save_models(self, path):
        """Save actor/critic to '<path>_actor.keras' / '<path>_critic.keras'.

        Failures are reported (with traceback) but not raised, so a bad save
        does not kill a long training run.
        """
        actor_save_path = f"{path}_actor.keras"
        critic_save_path = f"{path}_critic.keras"
        print(f"\n--- Attempting to save models ---")
        print(f"Target Actor path: {os.path.abspath(actor_save_path)}")
        print(f"Target Critic path: {os.path.abspath(critic_save_path)}")
        try:
            self.actor.save(actor_save_path)
            print(f"Actor model saved successfully to {os.path.abspath(actor_save_path)}")
        except Exception as e:
            print(f"ERROR: Failed to save Actor model to {os.path.abspath(actor_save_path)}")
            print(f"Reason: {e}")
            traceback.print_exc()
        try:
            self.critic.save(critic_save_path)
            print(f"Critic model saved successfully to {os.path.abspath(critic_save_path)}")
        except Exception as e:
            print(f"ERROR: Failed to save Critic model to {os.path.abspath(critic_save_path)}")
            print(f"Reason: {e}")
            traceback.print_exc()
        print(f"--- Models save process completed ---\n")

    def load_models(self, path):
        """Load actor/critic from '<path>_actor.keras' / '<path>_critic.keras'.

        Returns True only if BOTH models loaded; on failure the agent keeps
        its freshly-initialized (untrained) networks.
        """
        actor_load_path = f"{path}_actor.keras"
        critic_load_path = f"{path}_critic.keras"
        actor_loaded_ok = False
        critic_loaded_ok = False

        # Custom classes must be registered for deserialization.
        custom_objects = {
            'Actor': Actor,
            'Critic': Critic
        }

        try:
            self.actor = tf.keras.models.load_model(actor_load_path, custom_objects=custom_objects)
            actor_loaded_ok = True
            print(f"Actor model loaded from: {os.path.abspath(actor_load_path)}")
        except Exception as e:
            print(f"ERROR: Failed to load Actor model from {os.path.abspath(actor_load_path)}")
            print(f"Reason: {e}")
            traceback.print_exc()

        try:
            self.critic = tf.keras.models.load_model(critic_load_path, custom_objects=custom_objects)
            critic_loaded_ok = True
            print(f"Critic model loaded from: {os.path.abspath(critic_load_path)}")
        except Exception as e:
            print(f"ERROR: Failed to load Critic model from {os.path.abspath(critic_load_path)}")
            print(f"Reason: {e}")
            traceback.print_exc()

        if actor_loaded_ok and critic_loaded_ok:
            print(f"All PPO models loaded successfully from '{path}'.")
            return True
        else:
            print(f"Warning: One or both models failed to load. The agent will use untrained models.")
            return False
PPO_Trainer.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ from Snake_EnvAndAgent import SnakeGameEnv
3
+ from PPO_Model import PPOAgent
4
+ import numpy as np
5
+ import time
6
+ import os
7
+ import json # For saving/loading training state
8
+ from plot_utility_Trainer import plot_rewards, smooth_curve, init_live_plot, update_live_plot, save_live_plot_final
9
+
10
# All tunable training settings in one place. train_agent() reads (and for
# 'resume_training' also mutates) this dict at runtime.
HYPERPARAMETERS = {
    'grid_size': 30, # This is used for environment initialization
    'actor_lr': 0.0003,
    'critic_lr': 0.0003,
    'gamma': 0.99,
    'gae_lambda': 0.95,
    'clip_epsilon': 0.2,
    'num_epochs_per_update': 10,
    'batch_size': 64,
    'num_steps_per_rollout': 2048, # Number of steps to collect before a learning update
    'total_timesteps': 10_000_000, # Total environmental steps to train for
    'hidden_layer_sizes': [512, 512, 512],
    'save_interval_timesteps': 400000, # Save models every N total timesteps
    'log_interval_episodes': 10, # Log training progress every N episodes
    'render_training': False, # Set to True to see rendering during training (will slow down)
    'render_fps_limit': 10, # Limits render FPS, if 0, renders as fast as possible (can be too fast)
    'plot_smoothing_factor': 0.9, # For smoothing the reward plot
    'live_plot_interval_episodes': 100, # Update live plot every N episodes
    'resume_training': True # Set to True to attempt to resume from latest checkpoint
}

# Directory for saving models and plots
MODEL_SAVE_DIR = 'snake_ppo_models'
PLOT_SAVE_DIR = 'snake_ppo_plots'
# JSON file holding counters and reward history so interrupted runs can resume.
TRAINING_STATE_FILE = os.path.join(MODEL_SAVE_DIR, 'training_state.json')

# Create output directories up front (no-op if they already exist).
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
os.makedirs(PLOT_SAVE_DIR, exist_ok=True)
print(f"Model save directory created/checked: {os.path.abspath(MODEL_SAVE_DIR)}")
print(f"Plot save directory created/checked: {os.path.abspath(PLOT_SAVE_DIR)}")
40
+
41
def save_training_state(total_timesteps_trained, episode_count, all_episode_rewards, plot_rewards_history):
    """Persist training counters and reward history to TRAINING_STATE_FILE as JSON."""
    snapshot = {
        'total_timesteps_trained': total_timesteps_trained,
        'episode_count': episode_count,
        'all_episode_rewards': all_episode_rewards,
        'plot_rewards_history': plot_rewards_history,
    }
    with open(TRAINING_STATE_FILE, 'w') as state_file:
        json.dump(snapshot, state_file)
    print(f"Training state saved to {TRAINING_STATE_FILE}")
51
+
52
def load_training_state():
    """Load counters and reward history from TRAINING_STATE_FILE.

    Returns (total_timesteps_trained, episode_count, all_episode_rewards,
    plot_rewards_history); zeros/empty lists when no state file exists.
    """
    # Guard clause: fresh run when nothing has been saved yet.
    if not os.path.exists(TRAINING_STATE_FILE):
        return 0, 0, [], []
    with open(TRAINING_STATE_FILE, 'r') as state_file:
        state = json.load(state_file)
    print(f"Training state loaded from {TRAINING_STATE_FILE}")
    return (state['total_timesteps_trained'],
            state['episode_count'],
            state['all_episode_rewards'],
            state['plot_rewards_history'])
62
+
63
+
64
def train_agent():
    """Full PPO training loop.

    Sets up the environment and agent, optionally resumes from the latest
    checkpoint, then alternates rollout collection and learning updates,
    with periodic checkpointing and live reward plotting.
    """
    print(f"Current working directory: {os.getcwd()}")
    print("Initializing environment and agent...")

    render_mode = 'human' if HYPERPARAMETERS['render_training'] else None

    env = SnakeGameEnv(render_mode=render_mode)

    if HYPERPARAMETERS['render_training'] and HYPERPARAMETERS['render_fps_limit'] > 0:
        env.metadata["render_fps"] = HYPERPARAMETERS['render_fps_limit']

    obs_shape = env.observation_space.shape
    action_size = env.action_space.n

    agent = PPOAgent(
        observation_space_shape=obs_shape,
        action_space_size=action_size,
        actor_lr=HYPERPARAMETERS['actor_lr'],
        critic_lr=HYPERPARAMETERS['critic_lr'],
        gamma=HYPERPARAMETERS['gamma'],
        gae_lambda=HYPERPARAMETERS['gae_lambda'],
        clip_epsilon=HYPERPARAMETERS['clip_epsilon'],
        num_epochs_per_update=HYPERPARAMETERS['num_epochs_per_update'],
        batch_size=HYPERPARAMETERS['batch_size'],
        hidden_layer_sizes=HYPERPARAMETERS['hidden_layer_sizes']
    )

    total_timesteps_trained = 0
    episode_count = 0
    all_episode_rewards = []
    plot_rewards_history = []
    last_saved_timesteps = 0

    # --- Resume Training Logic ---
    if HYPERPARAMETERS['resume_training']:
        print("Attempting to resume training...")
        latest_checkpoint = None
        # Periodic checkpoints are named "ppo_snake_<timestep>_actor.keras";
        # non-numeric stems (e.g. the "final" save) raise ValueError and are
        # skipped, so only numbered checkpoints are considered for resume.
        for f in os.listdir(MODEL_SAVE_DIR):
            if f.endswith('_actor.keras'):
                try:
                    timestep_str = f.split('_')[-2]
                    timestep = int(timestep_str)
                    if latest_checkpoint is None or timestep > latest_checkpoint[0]:
                        latest_checkpoint = (timestep, f.replace('_actor.keras', ''))
                except ValueError:
                    continue

        if latest_checkpoint:
            print(f"Found latest checkpoint: {latest_checkpoint[1]}")
            if agent.load_models(latest_checkpoint[1]):
                total_timesteps_trained, episode_count, all_episode_rewards, plot_rewards_history = load_training_state()
                last_saved_timesteps = total_timesteps_trained
                print(f"Resumed from Timestep: {total_timesteps_trained}, Episode: {episode_count}")
            else:
                print("Failed to load models. Starting new training run.")
                HYPERPARAMETERS['resume_training'] = False
        else:
            print("No previous checkpoints found. Starting new training run.")
            HYPERPARAMETERS['resume_training'] = False

    print("Starting training loop...")
    start_time = time.time()

    fig, ax, line = init_live_plot(PLOT_SAVE_DIR, filename="live_ppo_training_progress.png")
    # On resume, redraw the plot from the restored history before training.
    if HYPERPARAMETERS['resume_training'] and len(plot_rewards_history) > 1:
        episodes_for_plot = [i * HYPERPARAMETERS['log_interval_episodes'] for i in range(len(plot_rewards_history))]
        smoothed_rewards = smooth_curve(plot_rewards_history, factor=HYPERPARAMETERS['plot_smoothing_factor'])
        update_live_plot(fig, ax, line, episodes_for_plot, smoothed_rewards,
                         current_timestep=total_timesteps_trained,
                         total_timesteps=HYPERPARAMETERS['total_timesteps'])

    while total_timesteps_trained < HYPERPARAMETERS['total_timesteps']:
        current_rollout_steps = 0

        # Collect up to num_steps_per_rollout transitions (possibly spanning
        # several episodes) before each learning update.
        while current_rollout_steps < HYPERPARAMETERS['num_steps_per_rollout'] and \
              total_timesteps_trained + current_rollout_steps < HYPERPARAMETERS['total_timesteps']:

            state, info = env.reset()
            current_action_mask = info['action_mask']

            done = False
            current_episode_reward = 0

            while not done and current_rollout_steps < HYPERPARAMETERS['num_steps_per_rollout'] and \
                  total_timesteps_trained + current_rollout_steps < HYPERPARAMETERS['total_timesteps']:

                action, log_prob, value = agent.choose_action(state, current_action_mask)

                next_state, reward, terminated, truncated, info = env.step(action)
                current_episode_reward += reward

                next_action_mask = info['action_mask']

                # Store the mask that was active when the action was sampled
                # so learn() can re-mask the logits identically.
                agent.remember(state, action, reward, next_state, terminated, log_prob, value, current_action_mask)

                state = next_state
                current_action_mask = next_action_mask

                current_rollout_steps += 1

                done = terminated or truncated

                if done:
                    episode_count += 1
                    all_episode_rewards.append(current_episode_reward)

                    if episode_count % HYPERPARAMETERS['log_interval_episodes'] == 0:
                        avg_reward_last_n_episodes = np.mean(all_episode_rewards[-HYPERPARAMETERS['log_interval_episodes']:]).round(2)
                        plot_rewards_history.append(avg_reward_last_n_episodes)

                        elapsed_time = time.time() - start_time
                        print(f"Timestep: {total_timesteps_trained + current_rollout_steps}/{HYPERPARAMETERS['total_timesteps']} | "
                              f"Episode: {episode_count} | "
                              f"Avg Reward (last {HYPERPARAMETERS['log_interval_episodes']}): {avg_reward_last_n_episodes} | "
                              f"Total Score (this ep): {info['score']} | "
                              f"Time: {elapsed_time:.2f}s")

                    if episode_count % HYPERPARAMETERS['live_plot_interval_episodes'] == 0:
                        if len(plot_rewards_history) > 1:
                            episodes_for_plot = [i * HYPERPARAMETERS['log_interval_episodes'] for i in range(len(plot_rewards_history))]
                            smoothed_rewards = smooth_curve(plot_rewards_history, factor=HYPERPARAMETERS['plot_smoothing_factor'])
                            update_live_plot(fig, ax, line, episodes_for_plot, smoothed_rewards,
                                             current_timestep=total_timesteps_trained + current_rollout_steps,
                                             total_timesteps=HYPERPARAMETERS['total_timesteps'])

                    if HYPERPARAMETERS['render_training'] and done:
                        time.sleep(0.5)

                    break

        total_timesteps_trained += current_rollout_steps

        if len(agent.states) > 0:
            print(f" --- Agent learning at Total Timestep {total_timesteps_trained} (collected {len(agent.states)} steps in rollout) ---")
            agent.learn()
        else:
            print(f" --- No data collected in current rollout, skipping learning ---")

        # Save once per crossed save_interval boundary (rollouts rarely land
        # exactly on a multiple of the interval).
        if total_timesteps_trained >= HYPERPARAMETERS['save_interval_timesteps'] and \
           (total_timesteps_trained // HYPERPARAMETERS['save_interval_timesteps']) > \
           (last_saved_timesteps // HYPERPARAMETERS['save_interval_timesteps']):

            save_path_timesteps = (total_timesteps_trained // HYPERPARAMETERS['save_interval_timesteps']) * HYPERPARAMETERS['save_interval_timesteps']
            print(f"--- Triggering periodic save at calculated timestep: {save_path_timesteps} ---")
            agent.save_models(os.path.join(MODEL_SAVE_DIR, f"ppo_snake_{save_path_timesteps}"))
            save_training_state(total_timesteps_trained, episode_count, all_episode_rewards, plot_rewards_history)
            last_saved_timesteps = save_path_timesteps

    print("\nTraining finished!")
    print(f"--- Triggering final save at total_timesteps: {total_timesteps_trained} ---")
    agent.save_models(os.path.join(MODEL_SAVE_DIR, "ppo_snake_final"))
    save_training_state(total_timesteps_trained, episode_count, all_episode_rewards, plot_rewards_history)

    env.close()

    print("Generating final performance plot...")
    episodes_for_plot = [i * HYPERPARAMETERS['log_interval_episodes'] for i in range(len(plot_rewards_history))]
    smoothed_rewards = smooth_curve(plot_rewards_history, factor=HYPERPARAMETERS['plot_smoothing_factor'])
    update_live_plot(fig, ax, line, episodes_for_plot, smoothed_rewards,
                     current_timestep=total_timesteps_trained,
                     total_timesteps=HYPERPARAMETERS['total_timesteps'])
    save_live_plot_final(fig, ax)

    plot_rewards(smoothed_rewards, HYPERPARAMETERS['log_interval_episodes'], PLOT_SAVE_DIR, "ppo_training_progress_final.png", show_plot=False)

if __name__ == "__main__":
    train_agent()
Snake_EnvAndAgent.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ from gymnasium import spaces
3
+ import random
4
+ import pygame
5
+ import numpy as np
6
+ import collections
7
+ from collections import deque
8
+ from Environment_Constants import (
9
+ GRID_SIZE, CELL_SIZE, SCREEN_WIDTH, SCREEN_HEIGHT,
10
+ WHITE, BLACK, GREEN, RED, BLUE,
11
+ UP, DOWN, LEFT, RIGHT,
12
+ REWARD_FOOD, REWARD_COLLISION, REWARD_STEP,
13
+ FPS, OBSERVATION_SPACE_SIZE
14
+ )
15
+
16
class SnakeGameEnv(gym.Env):
    """Gymnasium Snake environment with relative controls and action masking.

    Actions are relative to the current heading: 0 = straight, 1 = turn
    right, 2 = turn left. Observations are an 11-element binary vector
    (danger flags, food direction flags, heading one-hot). The info dict
    from reset()/step() carries 'score', 'snake_length' and 'action_mask'.
    """
    metadata = {'render_modes': ['human', 'rgb_array'], 'render_fps': FPS}

    def __init__(self, render_mode=None):
        super().__init__()
        self.grid_size = GRID_SIZE
        self.cell_size = CELL_SIZE
        self.screen_width = SCREEN_WIDTH
        self.screen_height = SCREEN_HEIGHT

        # Relative action scheme: straight / right turn / left turn.
        self.action_space = spaces.Discrete(3)

        self.observation_space = spaces.Box(low=0, high=1,
                                            shape=(OBSERVATION_SPACE_SIZE,),
                                            dtype=np.float32)

        self.render_mode = render_mode
        self.window = None  # pygame window, created lazily on first render
        self.clock = None   # pygame clock for FPS limiting

        self._init_game_state()

    def _init_game_state(self):
        """Place a length-3 snake at the centre, heading UP, and spawn food."""
        self.snake = deque()
        self.head = (self.grid_size // 2, self.grid_size // 2)
        self.snake.append(self.head)
        # Body extends downward from the head (y grows downward).
        self.snake.append((self.head[0], self.head[1] + 1))
        self.snake.append((self.head[0], self.head[1] + 2))

        self.direction = UP
        self.score = 0
        self.food = self._place_food()
        self.game_over = False
        self.steps_since_food = 0  # starvation counter, reset on eating
        self.length = len(self.snake)


    def _place_food(self):
        """Rejection-sample a random cell not occupied by the snake.

        NOTE(review): uses the module-level `random`, not self.np_random,
        so reset(seed=...) does not make food placement reproducible —
        confirm whether seeded determinism is required.
        """
        while True:
            x = random.randrange(self.grid_size)
            y = random.randrange(self.grid_size)
            food_pos = (x, y)
            if food_pos not in self.snake:
                return food_pos

    def _is_position_safe_for_observation(self, pos):
        """True if pos is inside the grid and not on the body (head excluded)."""
        px, py = pos
        if not (0 <= px < self.grid_size and 0 <= py < self.grid_size):
            return False
        if pos in list(self.snake)[1:]:
            return False
        return True

    def _get_observation(self):
        """Build the 11-dim binary feature vector.

        Layout: [0..2] danger straight/right/left, [3..6] food above/below/
        left/right of the head, [7..10] one-hot of the current heading.
        """
        obs = np.zeros(OBSERVATION_SPACE_SIZE, dtype=np.float32)

        hx, hy = self.head

        # Resolve the relative directions from the absolute heading.
        if self.direction == UP:
            dir_straight = UP
            dir_right = RIGHT
            dir_left = LEFT
        elif self.direction == DOWN:
            dir_straight = DOWN
            dir_right = LEFT
            dir_left = RIGHT
        elif self.direction == LEFT:
            dir_straight = LEFT
            dir_right = UP
            dir_left = DOWN
        elif self.direction == RIGHT:
            dir_straight = RIGHT
            dir_right = DOWN
            dir_left = UP

        check_pos_straight = (hx + dir_straight[0], hy + dir_straight[1])
        check_pos_right = (hx + dir_right[0], hy + dir_right[1])
        check_pos_left = (hx + dir_left[0], hy + dir_left[1])

        obs[0] = 1 if not self._is_position_safe_for_observation(check_pos_straight) else 0
        obs[1] = 1 if not self._is_position_safe_for_observation(check_pos_right) else 0
        obs[2] = 1 if not self._is_position_safe_for_observation(check_pos_left) else 0

        fx, fy = self.food
        if fy < hy: obs[3] = 1  # food above the head
        if fy > hy: obs[4] = 1  # food below the head
        if fx < hx: obs[5] = 1  # food to the left
        if fx > hx: obs[6] = 1  # food to the right

        if self.direction == UP: obs[7] = 1
        elif self.direction == DOWN: obs[8] = 1
        elif self.direction == LEFT: obs[9] = 1
        elif self.direction == RIGHT: obs[10] = 1

        return obs

    def _get_action_mask(self):
        """Boolean mask over [straight, right, left].

        False marks a move that would immediately collide per the same rules
        step() uses. If everything is masked, fallback logic re-enables at
        least one action so the policy distribution stays well-defined.
        """
        mask = np.array([True, True, True], dtype=bool)
        hx, hy = self.head

        # Index 0 = keep heading; 1 and 2 filled with right/left turn below.
        potential_directions = [
            self.direction,
            None,
            None
        ]

        if self.direction == UP:
            potential_directions[1] = RIGHT
            potential_directions[2] = LEFT
        elif self.direction == DOWN:
            potential_directions[1] = LEFT
            potential_directions[2] = RIGHT
        elif self.direction == LEFT:
            potential_directions[1] = UP
            potential_directions[2] = DOWN
        elif self.direction == RIGHT:
            potential_directions[1] = DOWN
            potential_directions[2] = UP

        def _is_potential_move_illegal(pos_to_check, current_snake, food_pos):
            # Wall collision.
            if not (0 <= pos_to_check[0] < self.grid_size and 0 <= pos_to_check[1] < self.grid_size):
                return True

            # Body collision, tail cell excluded (handled separately below).
            if pos_to_check in list(current_snake)[:-1]:
                return True

            # Tail cell counts as a collision unless the food sits there —
            # mirrors step()'s rule. NOTE(review): _place_food never puts
            # food on the snake, so the food==tail branch looks unreachable.
            if pos_to_check == current_snake[-1]:
                if pos_to_check != food_pos:
                    return True


            return False

        for action_idx, new_dir in enumerate(potential_directions):
            dx, dy = new_dir
            potential_head = (hx + dx, hy + dy)
            if _is_potential_move_illegal(potential_head, self.snake, self.food):
                mask[action_idx] = False

        if not np.any(mask):
            print(f"Warning: All actions masked out at head {self.head}, direction {self.direction}, food {self.food}. Attempting to find a fallback action.")
            found_fallback = False
            for i in range(3): # Check Straight, Right, Left
                dx, dy = potential_directions[i]
                potential_head = (hx + dx, hy + dy)
                if not _is_potential_move_illegal(potential_head, self.snake, self.food):
                    mask[i] = True
                    found_fallback = True

            if not found_fallback:
                # Last resort: enable one arbitrary action so sampling from
                # the masked distribution cannot deadlock.
                mask[np.random.choice(3)] = True
                print("Critical Warning: No legal actions found even after fallback logic. Enabling a random action to prevent deadlock.")

        return mask

    def reset(self, seed=None, options=None):
        """Standard gymnasium reset; returns (observation, info)."""
        super().reset(seed=seed)
        self._init_game_state()
        observation = self._get_observation()
        info = self._get_info()

        if not np.any(info['action_mask']):
            print("Warning: No valid actions found in initial reset state.")

        if self.render_mode == 'human':
            self._render_frame()
        return observation, info

    def _get_info(self):
        """Returns environment information, including the action mask."""
        return {
            "score": self.score,
            "snake_length": len(self.snake),
            "action_mask": self._get_action_mask()
        }

    def step(self, action):
        """Advance one tick. action: 0 = straight, 1 = right turn, 2 = left turn.

        Returns the standard gymnasium 5-tuple
        (observation, reward, terminated, truncated, info).
        """
        new_direction = self.direction

        # Map the relative action onto an absolute heading.
        if action == 1:
            if self.direction == UP: new_direction = RIGHT
            elif self.direction == DOWN: new_direction = LEFT
            elif self.direction == LEFT: new_direction = UP
            elif self.direction == RIGHT: new_direction = DOWN
        elif action == 2:
            if self.direction == UP: new_direction = LEFT
            elif self.direction == DOWN: new_direction = RIGHT
            elif self.direction == LEFT: new_direction = DOWN
            elif self.direction == RIGHT: new_direction = UP
        elif action != 0:
            raise ValueError(f"Received invalid action={action} which is not part of the action space.")

        self.direction = new_direction

        hx, hy = self.head
        dx, dy = self.direction
        new_head = (hx + dx, hy + dy)

        reward = REWARD_STEP
        terminated = False
        truncated = False
        # Collision checks mirror _get_action_mask's rules: wall, body
        # (tail excluded), then tail unless food occupies it.
        if not (0 <= new_head[0] < self.grid_size and 0 <= new_head[1] < self.grid_size):
            terminated = True
            reward = REWARD_COLLISION

        elif new_head in list(self.snake)[:-1]:
            terminated = True
            reward = REWARD_COLLISION
        elif new_head == self.snake[-1] and new_head != self.food:
            terminated = True
            reward = REWARD_COLLISION

        if terminated:
            self.game_over = True
        else:
            self.snake.appendleft(new_head)
            self.head = new_head

            if new_head == self.food:
                self.score += 1
                self.length += 1
                reward = REWARD_FOOD
                self.food = self._place_food()
                self.steps_since_food = 0
            else:
                # No food eaten: drop the tail so length stays constant.
                self.snake.pop()
                self.steps_since_food += 1

            # Starvation timeout. NOTE(review): sets BOTH terminated and
            # truncated; gymnasium convention for time limits is
            # truncated-only — confirm intent before relying on `terminated`.
            if self.steps_since_food >= self.grid_size * self.grid_size * 1.5:
                terminated = True
                truncated = True
                reward = REWARD_COLLISION


        observation = self._get_observation()
        info = self._get_info()

        if self.render_mode == 'human':
            self._render_frame()

        return observation, reward, terminated, truncated, info

    def _render_frame(self):
        """Draw the current state: pygame window in 'human' mode, or return
        an (H, W, 3) uint8 array in 'rgb_array' mode."""
        # Lazy pygame initialization on the first human-mode frame.
        if self.window is None and self.render_mode == 'human':
            pygame.init()
            pygame.display.init()
            self.window = pygame.display.set_mode((self.screen_width, self.screen_height))
            pygame.display.set_caption("Snake AI Training")
        if self.clock is None and self.render_mode == 'human':
            self.clock = pygame.time.Clock()

        if self.render_mode == 'human':
            self.window.fill(BLACK)

            pygame.draw.rect(self.window, RED, (self.food[0] * self.cell_size,
                                                self.food[1] * self.cell_size,
                                                self.cell_size, self.cell_size))

            # Head drawn blue, body segments green.
            for i, segment in enumerate(self.snake):
                color = BLUE if i == 0 else GREEN
                pygame.draw.rect(self.window, color, (segment[0] * self.cell_size,
                                                      segment[1] * self.cell_size,
                                                      self.cell_size, self.cell_size))

            # Grid lines.
            for x in range(0, self.screen_width, self.cell_size):
                pygame.draw.line(self.window, WHITE, (x, 0), (x, self.screen_height))
            for y in range(0, self.screen_height, self.cell_size):
                pygame.draw.line(self.window, WHITE, (0, y), (self.screen_width, y))

            font = pygame.font.Font(None, 25)
            text = font.render(f"Score: {self.score}", True, WHITE)
            self.window.blit(text, (5, 5))

            pygame.event.pump()
            pygame.display.flip()
            # Throttle to the configured frame rate.
            self.clock.tick(self.metadata["render_fps"])
        elif self.render_mode == "rgb_array":
            surf = pygame.Surface((self.screen_width, self.screen_height))
            surf.fill(BLACK)
            pygame.draw.rect(surf, RED, (self.food[0] * self.cell_size,
                                         self.food[1] * self.cell_size,
                                         self.cell_size, self.cell_size))
            for i, segment in enumerate(self.snake):
                color = BLUE if i == 0 else GREEN
                pygame.draw.rect(surf, color, (segment[0] * self.cell_size,
                                               segment[1] * self.cell_size,
                                               self.cell_size, self.cell_size))
            # pixels3d is (W, H, 3); transpose to the conventional (H, W, 3).
            return np.transpose(np.array(pygame.surfarray.pixels3d(surf)), axes=(1, 0, 2))

    def close(self):
        """Tear down the pygame window and clock (safe to call repeatedly)."""
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()
            self.window = None
            self.clock = None
Trained_PPO_Agent.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ from Snake_EnvAndAgent import SnakeGameEnv
3
+ from PPO_Model import PPOAgent
4
+ import os
5
+ import time
6
+ import numpy as np
7
+ from plot_utility_Trained_Agent import init_live_plot, update_live_plot, save_live_plot_final, smooth_curve
8
+
9
# Playback settings: board size, where the trained weights live, and how the
# live reward plot behaves while episodes are replayed.
PLAY_CONFIG = {
    'grid_size': 30,                                          # board is grid_size x grid_size cells
    'model_path_prefix': 'snake_ppo_models/ppo_snake_final',  # prefix used when the trainer saved the networks
    'num_episodes_to_play': 100,                              # episodes to replay with the trained policy
    'render_fps': 10,                                         # frames per second for human rendering (0 = env default)
    'live_plot_interval_episodes': 1,                         # refresh the live plot every N episodes
    'plot_smoothing_factor': 0.8                              # exponential smoothing factor for the reward curve
}

# Directory where playback plots are written; created up front so savefig works.
PLOT_SAVE_DIR = 'snake_ppo_plots'
os.makedirs(PLOT_SAVE_DIR, exist_ok=True)
20
+
21
+
22
def play_agent():
    """Load a trained PPO agent and replay it in the Snake environment.

    Renders each episode in a human-visible window, logs per-episode
    reward/score/steps, keeps a live smoothed-reward plot updated during
    playback, and saves the final plot to disk.
    """
    print("Initializing environment for playback...")
    env = SnakeGameEnv(render_mode='human')

    # Override the environment's render speed if a positive FPS was configured.
    if PLAY_CONFIG['render_fps'] > 0:
        env.metadata["render_fps"] = PLAY_CONFIG['render_fps']

    # Rebuild the agent with the same architecture it was trained with so the
    # saved weights fit the networks.
    agent = PPOAgent(
        observation_space_shape=env.observation_space.shape,
        action_space_size=env.action_space.n,
        actor_lr=3e-4,
        critic_lr=3e-4,
        hidden_layer_sizes=[512, 512, 512]
    )

    print(f"loading models from: {PLAY_CONFIG['model_path_prefix']}")
    if not agent.load_models(PLAY_CONFIG['model_path_prefix']):
        print("\nFATAL ERROR: Failed to load trained models from disk. The agent CANNOT perform as trained. Exiting playback!.")
        env.close()
        return

    print("--- Trained models loaded successfully. Loading playback. ---")
    print("Starting agent playback...")

    reward_history = []

    # Live plot of smoothed per-episode rewards, refreshed during playback.
    fig, ax, line = init_live_plot(PLOT_SAVE_DIR, filename="live_playback_rewards_plot.png")
    ax.set_title('Live Playback Progress (Episode Rewards)')
    ax.set_xlabel('Episode')
    ax.set_ylabel('Total Reward')

    total_episodes = PLAY_CONFIG['num_episodes_to_play']
    for episode_idx in range(total_episodes):
        obs, info = env.reset()
        mask = info['action_mask']

        episode_reward = 0
        step_count = 0
        finished = False

        while not finished:
            # Trained policy picks the action; log-prob and value are unused here.
            action, _, _ = agent.choose_action(obs, mask)

            obs, reward, terminated, truncated, info = env.step(action)
            episode_reward += reward
            step_count += 1
            finished = terminated or truncated

            mask = info['action_mask']

            # Slow playback to the configured frame rate so a human can follow.
            if PLAY_CONFIG['render_fps'] > 0:
                time.sleep(1 / env.metadata["render_fps"])

        reward_history.append(episode_reward)
        print(f"Episode {episode_idx + 1}: Total Reward = {episode_reward:.2f}, Score = {info['score']}, Steps = {step_count}")

        # Periodic live-plot refresh.
        if (episode_idx + 1) % PLAY_CONFIG['live_plot_interval_episodes'] == 0:
            xs = list(range(1, len(reward_history) + 1))
            ys = smooth_curve(reward_history, factor=PLAY_CONFIG['plot_smoothing_factor'])
            update_live_plot(fig, ax, line, xs, ys,
                             current_timestep=episode_idx + 1,
                             total_timesteps=total_episodes)

        # Brief pause between episodes.
        time.sleep(0.5)

    env.close()
    print("\nPlayback finished.")

    # Final plot refresh, then persist it to disk.
    xs = list(range(1, len(reward_history) + 1))
    ys = smooth_curve(reward_history, factor=PLAY_CONFIG['plot_smoothing_factor'])
    update_live_plot(fig, ax, line, xs, ys,
                     current_timestep=total_episodes,
                     total_timesteps=total_episodes)
    save_live_plot_final(fig, ax)

    avg_playback_reward = np.mean(reward_history)
    print(f"Average Reward over {PLAY_CONFIG['num_episodes_to_play']} playback episodes: {avg_playback_reward:.2f}")
108
+
109
# Run playback only when executed as a script, not on import.
if __name__ == "__main__":
    play_agent()
plot_utility_Trained_Agent.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import os
4
+ import time
5
+
6
def smooth_curve(points, factor=0.9):
    """Exponentially smooth a sequence of values.

    The first output equals the first input; each subsequent output blends
    the previous smoothed value (weight *factor*) with the current raw
    value (weight 1 - factor). Returns [] for empty input.
    """
    if not points:
        return []
    smoothed = [points[0]]
    for value in points[1:]:
        smoothed.append(smoothed[-1] * factor + value * (1 - factor))
    return smoothed
13
+
14
def plot_rewards(rewards_history, log_interval, save_dir, filename="Trained_Agent_rewards_plot.png", show_plot=True):
    """Plot average reward per logging interval and save the figure as PNG.

    Args:
        rewards_history: One averaged reward value per logging interval.
        log_interval: Episodes between logged values; scales the x-axis.
        save_dir: Directory the PNG is written to (created if missing).
        filename: Output file name inside save_dir.
        show_plot: When True, display the figure (blocking); otherwise the
            figure is closed after saving.
    """
    os.makedirs(save_dir, exist_ok=True)

    fig = plt.figure(figsize=(12, 6))
    episodes = [i * log_interval for i in range(1, len(rewards_history) + 1)]
    plt.plot(episodes, rewards_history, label='Average Reward')
    plt.xlabel('Episodes')
    plt.ylabel('Average Reward')
    plt.title('Trained Agent Live Reward (Average Reward per Episode)')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()

    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path)
    print(f"Plot saved to: {os.path.abspath(save_path)}")

    if show_plot:
        plt.show()
    else:
        # Close explicitly so repeated calls don't accumulate open figures
        # (matplotlib keeps every un-closed figure alive).
        plt.close(fig)
34
+
35
+
36
def init_live_plot(save_dir, filename="live_rewards_plot.png"):
    """Create an interactive matplotlib figure for live reward plotting.

    Returns (fig, ax, line). The final save path is stashed on the axes as
    ``_save_path_final`` so ``save_live_plot_final`` can pick it up later.
    """
    # Ensure the directory exists now, so the deferred final save can't fail
    # on a missing path.
    os.makedirs(save_dir, exist_ok=True)

    plt.ion()  # interactive mode: updates render without blocking
    fig, ax = plt.subplots(figsize=(12, 6))
    line, = ax.plot([], [], label='Smoothed Average Reward')
    ax.set_xlabel('Episodes')
    ax.set_ylabel('Average Reward')
    ax.set_title('Live Reward for Trained Agent')
    ax.grid(True)
    ax.legend()
    plt.tight_layout()

    ax._save_path_final = os.path.join(save_dir, filename)

    return fig, ax, line
51
+
52
def update_live_plot(fig, ax, line, episodes, smoothed_rewards, current_timestep=None, total_timesteps=None):
    """Refresh the live reward plot with new data and redraw it.

    Args:
        fig, ax, line: Handles created by ``init_live_plot``.
        episodes: X-axis values (episode indices).
        smoothed_rewards: Y-axis values, same length as ``episodes``.
        current_timestep, total_timesteps: When both are given, shown in the
            plot title as progress.
    """
    # Nothing to draw yet.
    if not episodes or not smoothed_rewards:
        return

    line.set_data(episodes, smoothed_rewards)

    ax.set_xlim(0, max(episodes) * 1.05)

    # Pad the y-range by 10% of the data span, with a small floor so a flat
    # curve still gets visible headroom. Unlike multiplying min/max by
    # 0.9/1.1, this stays correct when rewards are negative (the old scaling
    # could invert the limits for all-negative data).
    lo, hi = min(smoothed_rewards), max(smoothed_rewards)
    pad = max((hi - lo) * 0.1, 0.05)
    ax.set_ylim(lo - pad, hi + pad)

    if current_timestep is not None and total_timesteps is not None:
        ax.set_title(f'Live Agent Progress (Timestep: {current_timestep:,}/{total_timesteps:,})')

    fig.canvas.draw()
    fig.canvas.flush_events()
    # Yield briefly so the GUI event loop can process the redraw.
    time.sleep(0.01)
75
+
76
def save_live_plot_final(fig, ax):
    """Leave interactive mode, save the live plot to its final path, close it.

    The save path is the one ``init_live_plot`` stashed on the axes; if none
    is present, the figure is just closed without saving.
    """
    plt.ioff()
    save_path = getattr(ax, '_save_path_final', None)
    if save_path:
        # Save this specific figure; plt.savefig would save whatever figure
        # happens to be "current", which may not be `fig`.
        fig.savefig(save_path)
        print(f"Final live plot saved to: {os.path.abspath(save_path)}")
    plt.close(fig)
    # The original called plt.show() here, after the figure was already
    # closed — there is nothing left to display, so that call is removed.
plot_utility_Trainer.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import os
4
+ import time
5
+
6
def smooth_curve(points, factor=0.9):
    """Exponentially smooth a list of values.

    Output[0] == input[0]; every later element is the previous smoothed
    value weighted by *factor* plus the raw value weighted by (1 - factor).
    An empty input yields an empty list.
    """
    if not points:
        return []
    result = [points[0]]
    for raw in points[1:]:
        result.append(result[-1] * factor + raw * (1 - factor))
    return result
14
+
15
def plot_rewards(rewards_history, log_interval, save_dir, filename="rewards_plot.png", show_plot=True):
    """Plot average reward per logging interval and save the figure as PNG.

    Args:
        rewards_history: One averaged reward value per logging interval.
        log_interval: Episodes between logged values; scales the x-axis.
        save_dir: Directory the PNG is written to (created if missing).
        filename: Output file name inside save_dir.
        show_plot: When True, display the figure (blocking); otherwise the
            figure is closed after saving.
    """
    os.makedirs(save_dir, exist_ok=True)

    fig = plt.figure(figsize=(12, 6))
    episodes = [i * log_interval for i in range(1, len(rewards_history) + 1)]
    plt.plot(episodes, rewards_history, label='Average Reward')
    plt.xlabel('Episodes')
    plt.ylabel('Average Reward')
    plt.title('PPO Training Progress (Average Reward per Episode)')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()

    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path)
    print(f"Plot saved to: {os.path.abspath(save_path)}")

    if show_plot:
        plt.show()
    else:
        # Close explicitly so repeated calls don't accumulate open figures
        # (matplotlib keeps every un-closed figure alive).
        plt.close(fig)
35
+
36
def init_live_plot(save_dir, filename="live_rewards_plot.png"):
    """Create an interactive matplotlib figure for live training plots.

    Returns (fig, ax, line). The final save path is stashed on the axes as
    ``_save_path_final`` so ``save_live_plot_final`` can pick it up later.
    """
    # Ensure the directory exists now, so the deferred final save can't fail
    # on a missing path.
    os.makedirs(save_dir, exist_ok=True)

    plt.ion()  # interactive mode: updates render without blocking training
    fig, ax = plt.subplots(figsize=(12, 6))
    line, = ax.plot([], [], label='Smoothed Average Reward')
    ax.set_xlabel('Episodes')
    ax.set_ylabel('Average Reward')
    ax.set_title('Live PPO Training Progress')
    ax.grid(True)
    ax.legend()
    plt.tight_layout()

    ax._save_path_final = os.path.join(save_dir, filename)

    return fig, ax, line
51
+
52
def update_live_plot(fig, ax, line, episodes, smoothed_rewards, current_timestep=None, total_timesteps=None):
    """Refresh the live training plot with new data and redraw it.

    Args:
        fig, ax, line: Handles created by ``init_live_plot``.
        episodes: X-axis values (episode indices).
        smoothed_rewards: Y-axis values, same length as ``episodes``.
        current_timestep, total_timesteps: When both are given, shown in the
            plot title as training progress.
    """
    # Nothing to draw yet.
    if not episodes or not smoothed_rewards:
        return

    line.set_data(episodes, smoothed_rewards)

    ax.set_xlim(0, max(episodes) * 1.05)

    # Pad the y-range by 10% of the data span, with a small floor so a flat
    # curve still gets visible headroom. Unlike multiplying min/max by
    # 0.9/1.1, this stays correct when rewards are negative (the old scaling
    # could invert the limits for all-negative data).
    lo, hi = min(smoothed_rewards), max(smoothed_rewards)
    pad = max((hi - lo) * 0.1, 0.05)
    ax.set_ylim(lo - pad, hi + pad)

    if current_timestep is not None and total_timesteps is not None:
        ax.set_title(f'Live PPO Training Progress (Timestep: {current_timestep:,}/{total_timesteps:,})')

    fig.canvas.draw()
    fig.canvas.flush_events()
    # Yield briefly so the GUI event loop can process the redraw.
    time.sleep(0.01)
+
76
def save_live_plot_final(fig, ax):
    """Leave interactive mode, save the live plot to its final path, close it.

    The save path is the one ``init_live_plot`` stashed on the axes; if none
    is present, the figure is just closed without saving.
    """
    plt.ioff()
    save_path = getattr(ax, '_save_path_final', None)
    if save_path:
        # Save this specific figure; plt.savefig would save whatever figure
        # happens to be "current", which may not be `fig`.
        fig.savefig(save_path)
        print(f"Final live plot saved to: {os.path.abspath(save_path)}")
    plt.close(fig)
    # The original called plt.show() here, after the figure was already
    # closed — there is nothing left to display, so that call is removed.
+ plt.show()