Anoozh-Akileswaran committed
Commit c3ec5ed · Parent: d937e11

First results from observation/return/reward norm.
Files changed (32)
  1. CNN_PPO/ppo_helpers_cnn.py +2 -1
  2. Observation_Advantage_Norm/PPO_Obser_Adva_Norm.py +0 -355
  3. Observation_Advantage_Norm_diff_combo/ppo__rew_norm_obs_diff_combo.py +1254 -0
  4. Observation_Advantage_Norm_diff_combo/ppo_rew_norm_obs_env_diff_combo.py +201 -0
  5. Observation_Advantage_Norm_diff_env/ppo__rew_norm_obs_diff_env.py +891 -0
  6. Observation_Advantage_Norm_diff_env/ppo_rew_norm_obs_env_diff_env.py +191 -0
  7. Observation_Advantage_Norm_diff_hypo/Performance config for Learning Rate of update_advantage_norm.png +0 -0
  8. Observation_Advantage_Norm_diff_hypo/Performance config for Learning Rate of update_observation_norm.png +0 -0
  9. Observation_Advantage_Norm_diff_hypo/Performance config for Learning Rate of update_return_norm.png +0 -0
  10. Observation_Advantage_Norm_diff_hypo/Performance config for Learning Rate of vanilla_ppo_update.png +0 -0
  11. Observation_Advantage_Norm_diff_hypo/Performance config for entropy coefficient of update_advantage_norm.png +0 -0
  12. Observation_Advantage_Norm_diff_hypo/Performance config for entropy coefficient of update_observation_norm.png +0 -0
  13. Observation_Advantage_Norm_diff_hypo/Performance config for entropy coefficient of update_return_norm.png +0 -0
  14. Observation_Advantage_Norm_diff_hypo/Performance config for entropy coefficient of vanilla_ppo_update.png +0 -0
  15. Observation_Advantage_Norm_diff_hypo/Performance config for gamma value of update_advantage_norm.png +0 -0
  16. Observation_Advantage_Norm_diff_hypo/Performance config for gamma value of update_observation_norm.png +0 -0
  17. Observation_Advantage_Norm_diff_hypo/Performance config for gamma value of update_return_norm.png +0 -0
  18. Observation_Advantage_Norm_diff_hypo/Performance config for gamma value of vanilla_ppo_update.png +0 -0
  19. Observation_Advantage_Norm_diff_hypo/ppo__rew_norm_obs_diff_hyp.py +890 -0
  20. Observation_Advantage_Norm/PPO_environment.py → Observation_Advantage_Norm_diff_hypo/ppo_rew_norm_obs_env_diff_hypo.py +109 -44
  21. Observation_Advantage_Norm_in_batch/ppo__rew_norm_obs_in_batch.py +829 -0
  22. Observation_Advantage_Norm_in_batch/ppo_rew_norm_obs_env_in_batch.py +163 -0
  23. Observation_Advantage_Norm_in_batch/update_advantage_norm_in_batch.png +0 -0
  24. Observation_Advantage_Norm_in_batch/update_observation_norm_in_batch.png +0 -0
  25. Observation_Advantage_Norm_in_batch/update_return_norm_in_batch.png +0 -0
  26. Observation_Advantage_Norm_in_batch/vanilla_ppo_update_in_batch.png +0 -0
  27. Observation_Advantage_Norm_running_averages/ppo__rew_norm_obs_running_average.py +893 -0
  28. Observation_Advantage_Norm_running_averages/ppo_rew_norm_obs_env_running_average.py +163 -0
  29. Observation_Advantage_Norm_running_averages/update_advantage_norm_running_average_.png +0 -0
  30. Observation_Advantage_Norm_running_averages/update_observation_norm_running_average_.png +0 -0
  31. Observation_Advantage_Norm_running_averages/update_return_norm_running_average_.png +0 -0
  32. Observation_Advantage_Norm_running_averages/vanilla_ppo_update_running_average_.png +0 -0
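All of the normalizers this commit adds (`ObservationNorm`, `AdvantageNorm`, `ReturnNorm`, and the older `ObservationScaling`) share one recurrence: Chan et al.'s parallel mean/variance update followed by a standardization with an epsilon guard. A condensed sketch of that shared logic, under the hypothetical name `RunningMoments` (the repo splits it across three near-identical classes):

```python
import torch as T

class RunningMoments:
    # Running mean/variance via the parallel-moments update used by the
    # ObservationNorm / AdvantageNorm / ReturnNorm classes in this commit.
    def __init__(self):
        self.mean, self.var, self.count = 0.0, 0.0, 1e-4

    def update(self, x: T.Tensor):
        batch_mean = T.mean(x, dim=0)
        batch_var = T.var(x, dim=0)
        batch_count = x.shape[0]
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count
        # Combine the stored moments with the batch moments (Chan's formula)
        m2 = (self.var * self.count + batch_var * batch_count
              + delta ** 2 * self.count * batch_count / tot_count)
        self.mean = self.mean + delta * batch_count / tot_count
        self.var = m2 / tot_count
        self.count = tot_count

    def normalize(self, x: T.Tensor) -> T.Tensor:
        # Epsilon keeps the division safe before any update has happened
        return (x - self.mean) / (T.sqrt(T.as_tensor(self.var)) + 1e-8)

# Usage: update with each rollout's advantages, then standardize them
norm = RunningMoments()
adv = T.randn(128)
norm.update(adv)
adv = norm.normalize(adv)
```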
CNN_PPO/ppo_helpers_cnn.py CHANGED
@@ -144,7 +144,7 @@ class Agent:
             # Shuffle indices
             idxs = T.randperm(num_samples)
             for start in range(0, num_samples, batch_size):
-                batch_idx = idxs[start:start + batch_size]
+                batch_idx = idxs[start:start + batch_size]  # array of indices for this minibatch
 
                 b_states = states[batch_idx]
                 b_actions = actions[batch_idx]
@@ -187,6 +187,7 @@ class Agent:
         self.memory.clear()
 
         return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
+        # mean total loss per minibatch update (ppo_epochs * num_samples / batch_size updates)
 
 
     def update_rbs(self):
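One caveat on the returned mean: the denominator `ppo_epochs * (num_samples / batch_size)` equals the number of optimizer steps only when `batch_size` divides `num_samples` evenly; the final short minibatch otherwise skews the average slightly. A sketch of the exact version, with a hypothetical helper name:

```python
import math

def mean_minibatch_loss(total_loss_epoch: float, ppo_epochs: int,
                        num_samples: int, batch_size: int) -> float:
    # ceil counts the trailing short minibatch too, so the average is exact
    # even when batch_size does not divide num_samples.
    n_updates = ppo_epochs * math.ceil(num_samples / batch_size)
    return total_loss_epoch / n_updates
```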
Observation_Advantage_Norm/PPO_Obser_Adva_Norm.py DELETED
@@ -1,355 +0,0 @@
-import numpy as np
-import torch as T
-import torch.nn as nn
-import torch.optim as optim
-from torch.distributions import Categorical
-
-
-class Agent():
-    # Minimal PPO-Clip agent (single full-batch update per episode, MC returns)
-    def __init__(
-        self,
-        obs_space,
-        action_space,
-        hidden,
-        gamma,
-        clip_coef,
-        lr,
-        value_coef,
-        entropy_coef,
-        seed
-    ):
-        # Initialize seed for reproducibility
-        if seed is not None:
-            np.random.seed(seed)
-            T.manual_seed(seed)
-
-        # Use GPU if available
-        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
-        self.obs_dim = int(np.prod(getattr(obs_space, "shape", (obs_space,))))
-        self.action_dim = int(getattr(action_space, "n", action_space))
-
-        # Initialize the policy and the critic networks
-        self.policy = Policy(self.obs_dim, self.action_dim, hidden).to(self.device)
-        self.critic = Critic(self.obs_dim, hidden).to(self.device)
-
-        # Set optimizer for policy and critic networks
-        self.opt = optim.Adam(
-            list(self.policy.parameters()) + list(self.critic.parameters()),
-            lr=lr
-        )
-        # Initialize the hyperparameters
-        self.gamma = gamma
-        self.clip = clip_coef
-        self.value_coef = value_coef
-        self.entropy_coef = entropy_coef
-        # Initialize the memory to store the state, action, reward, ...
-        self.memory = Memory()
-        self.observationScaling = ObservationScaling()
-        self.advantageNorm = AdvantageNorm()
-        self.total_loss = 0
-
-    def choose_action(self, observation):
-        # Returns: action, log probability, value of the state
-        state = T.as_tensor(observation, dtype=T.float32, device=self.device).view(-1)
-        with T.no_grad():
-            # Forward function (defined in Policy class)
-            dist = self.policy.next_action(state)
-            # Sample from the action distribution
-            action = dist.sample()
-            logp = dist.log_prob(action)  # log πθ(a|s)
-            # Value the current state
-            value = self.critic.evaluated_state(state)
-        return int(action.item()), float(logp.item()), float(value.item())
-
-    def remember(self, state, action, reward, done, log_prob, value, next_state):
-        # Store the info
-        with T.no_grad():
-            # Pass on the next state and have it evaluated by the critic network
-            ns = T.as_tensor(next_state, dtype=T.float32, device=self.device).view(-1)
-            next_value = self.critic.evaluated_state(ns).item()
-        self.memory.store(state, action, reward, done, log_prob, value, next_value)
-
-    """
-    def run_episode(self, env, max_steps: int, render: bool = False):
-        # Runs one episode, updates the policy once at the end
-        self.memory.clear()
-        out = env.reset()
-
-        state = out[0] if isinstance(out, tuple) else out
-
-        ep_return, ep_len = 0, 0
-
-        steps_limit = max_steps if max_steps is not None else float("inf")
-
-        while ep_len < steps_limit:
-            if render and hasattr(env, "render"):
-                env.render()
-
-            action, logp, value = self.choose_action(state)
-            step_out = env.step(action)
-            if len(step_out) == 5:
-                next_state, reward, terminated, truncated, _ = step_out
-                done = terminated or truncated
-            else:
-                next_state, reward, done, _ = step_out
-
-            self.remember(state, action, reward, done, logp, value, next_state)
-
-            ep_return += float(reward)
-            ep_len += 1
-            state = next_state
-            if done:
-                break
-
-        self._update()
-        return ep_return, ep_len
-
-    def run_episodes(self, env, n_episodes: int, max_steps: int, render: bool = False):
-        returns = []
-        for _ in range(n_episodes):
-            ep_ret, _ = self.run_episode(env, max_steps=max_steps, render=render)
-            returns.append(ep_ret)
-        return returns
-    """
-
-    def _update(self, mode, observationNorm, advantageNorm):
-        if len(self.memory.states) == 0:
-            return
-
-        states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
-        actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
-        rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
-        dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
-        old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
-        values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
-        ### Normalization happens here
-        if observationNorm:
-            self.observationScaling.update(states)
-            states = self.observationScaling.normalize(states)
-        ###
-        # Monte Carlo returns (episode-aware)
-        # Returns the discounted sum of future rewards
-        with T.no_grad():
-            returns = T.zeros_like(rewards)
-            G = 0.0
-            for t in reversed(range(rewards.size(0))):
-                G = rewards[t] + self.gamma * G * (1.0 - dones[t])
-                returns[t] = G
-        # Compute advantage + advantage normalization in-batch
-        adv = returns - values
-        if advantageNorm:
-            self.advantageNorm.update(adv)
-            adv = self.advantageNorm.normalize(adv)
-
-        # adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
-        # Recompute the distribution under the current policy
-        dist = self.policy.next_action(states)
-        new_logp = dist.log_prob(actions)
-
-        """PPO components: policy update, weighted probability distribution, clipped returns"""
-
-        # Updating the policy: update the probability distribution (i.e., compute clipped probs)
-        ratio = (new_logp - old_logp).exp()  # r_t = πθ / πθ_old
-
-        # Weighted probability distribution (according to the formula/update rule)
-        surr1 = ratio * adv
-        surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * adv
-        value_pred = self.critic.evaluated_state(states)
-        beta = 1.0
-        target_kl = 0.01
-
-        # PPO standards
-        if mode == "clip":
-            surr1 = ratio * adv
-            surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * adv
-            policy_loss = -T.min(surr1, surr2).mean()
-            print(f"Current policy loss: {policy_loss} with mode: {mode}")
-
-        elif mode == "kl_penalty":
-            # Penalize the policy if it changes too much
-            policy_loss = -(ratio * adv).mean()
-            approx_kl = (old_logp - new_logp).mean()
-            policy_loss = policy_loss + beta * approx_kl
-            # Adapt beta toward target_kl
-            if approx_kl > 1.5 * target_kl:
-                beta *= 2.0  # too big a step → increase penalty
-            elif approx_kl < 0.5 * target_kl:
-                beta *= 0.5  # too small a step → allow bigger updates
-            print(f"Current policy loss: {policy_loss} with mode: {mode}")
-
-        elif mode == "unclipped_earlystop":
-            policy_loss = -(ratio * adv).mean()
-            approx_kl = (old_logp - new_logp).mean()
-            if approx_kl.item() > 1.5 * target_kl:
-                # Skip the optimizer step this update / end further epochs
-                print(f"Current policy loss: {policy_loss} with mode: {mode}")
-                self.memory.clear()
-                return
-
-        # Loss: MSE of (return - critic value)
-        value_loss = 0.5 * (returns - value_pred).pow(2).mean()
-        # Entropy (account for randomness in action selection)
-        entropy = dist.entropy().mean()
-        # Total loss: policy loss + constant * value loss - constant * entropy
-        self.total_loss = policy_loss + self.value_coef * value_loss - self.entropy_coef * entropy
-
-        self.opt.zero_grad(set_to_none=True)
-        self.total_loss.backward()
-        self.opt.step()
-
-        self.memory.clear()
-
-
-class Policy(nn.Module):
-    def __init__(self, obs_dim: int, action_dim: int, hidden: int):
-        super().__init__()
-        self.net = nn.Sequential(
-            nn.Linear(obs_dim, hidden),
-            nn.ReLU(),
-            nn.Linear(hidden, hidden),
-            nn.ReLU(),
-            nn.Linear(hidden, action_dim)
-        )
-
-    def next_action(self, state: T.Tensor) -> Categorical:
-        # Returns the probability distribution over actions
-        if state.dim() == 1:
-            state = state.unsqueeze(0)
-        state = state.view(state.size(0), -1)
-        return Categorical(logits=self.net(state))
-
-
-class Critic(nn.Module):
-    def __init__(self, obs_dim: int, hidden: int):
-        super().__init__()
-        self.net = nn.Sequential(
-            nn.Linear(obs_dim, hidden),
-            nn.ReLU(),
-            nn.Linear(hidden, hidden),
-            nn.ReLU(),
-            nn.Linear(hidden, 1)
-        )
-
-    def evaluated_state(self, x: T.Tensor) -> T.Tensor:
-        if x.dim() == 1:
-            x = x.unsqueeze(0)
-        x = x.view(x.size(0), -1)
-        return self.net(x).squeeze(-1)
-
-
-class Memory():
-    def __init__(self):
-        self.states = []
-        self.actions = []
-        self.rewards = []
-        self.dones = []
-        self.log_probs = []
-        self.values = []
-        self.next_values = []
-
-    def store(self, state, action, reward, done, log_prob, value, next_value):
-        self.states.append(np.asarray(state, dtype=np.float32))
-        self.actions.append(int(action))
-        self.rewards.append(float(reward))
-        self.dones.append(float(done))
-        self.log_probs.append(float(log_prob))
-        self.values.append(float(value))
-        self.next_values.append(float(next_value))
-
-    """
-    # For mini-batch updates? To be implemented
-    def start_batch(self, batch_size: int):
-        n_states = len(self.states)
-        starts = np.arange(0, n_states, batch_size)
-        index = np.arange(n_states, dtype=np.int64)
-        np.random.shuffle(index)
-        return [index[s:s + batch_size] for s in starts]
-    """
-
-    def clear(self):
-        self.states = []
-        self.actions = []
-        self.rewards = []
-        self.dones = []
-        self.log_probs = []
-        self.values = []
-        self.next_values = []
-
-
-class AdvantageNorm:
-    '''
-    This class implements advantage normalization. The purpose is to normalize either across
-    batches or only within the same batch.
-    '''
-    def __init__(self):
-        self.main_mean = 0
-        self.main_var = 0
-        self.count = 1e-4
-
-    def update(self, x: T.Tensor):
-        print("I am updating the main mean and main variance")
-        batch_mean = T.mean(x, dim=0)
-        batch_var = T.var(x, dim=0)
-        batch_count = x.shape[0]
-        self._update_from_moments(batch_mean, batch_var, batch_count)
-
-    def _update_from_moments(self, batch_mean, batch_var, batch_count):
-        delta = batch_mean - self.main_mean
-        tot_count = self.count + batch_count
-        new_mean = self.main_mean + delta * batch_count / tot_count  # Update the running mean
-        m_a = self.main_var * self.count
-        m_b = batch_var * batch_count
-        M2 = m_a + m_b + np.square(delta) * self.count * batch_count / tot_count
-        new_var = M2 / tot_count  # Update the running variance
-
-        self.main_mean, self.main_var, self.count = new_mean, new_var, tot_count
-
-    def normalize(self, x):
-        print("I apply normalization on the advantages")
-        # Epsilon makes sure we never divide by zero
-        return (x - self.main_mean) / (np.sqrt(self.main_var) + 1e-8)
-
-
-class ObservationScaling:
-    def __init__(self):
-        self.main_mean = 0
-        self.main_var = 0
-        self.count = 1e-4
-
-    def update(self, x: T.Tensor):
-        print("I am updating the main mean and main variance")
-        batch_mean = T.mean(x, dim=0)
-        batch_var = T.var(x, dim=0)
-        batch_count = x.shape[0]
-        self._update_from_moments(batch_mean, batch_var, batch_count)
-
-    def _update_from_moments(self, batch_mean, batch_var, batch_count):
-        delta = batch_mean - self.main_mean
-        tot_count = self.count + batch_count
-        new_mean = self.main_mean + delta * batch_count / tot_count  # Update the running mean
-        m_a = self.main_var * self.count
-        m_b = batch_var * batch_count
-        M2 = m_a + m_b + np.square(delta) * self.count * batch_count / tot_count
-        new_var = M2 / tot_count  # Update the running variance
-
-        self.main_mean, self.main_var, self.count = new_mean, new_var, tot_count
-
-    def normalize(self, x):
-        print("I apply normalization on the observations")
-        # Epsilon makes sure we never divide by zero
-        return (x - self.main_mean) / (np.sqrt(self.main_var) + 1e-8)
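The deleted script's `kl_penalty` mode declared `beta = 1.0` as a local variable, so its doubling/halving never carried over to the next update. A minimal sketch of that adaptive-KL rule with the coefficient returned to the caller (hypothetical function name; tensors as in the deleted `_update`):

```python
import torch as T

def kl_penalty_policy_loss(ratio: T.Tensor, adv: T.Tensor,
                           old_logp: T.Tensor, new_logp: T.Tensor,
                           beta: float, target_kl: float = 0.01):
    # Unclipped surrogate plus a penalty on the approximate KL divergence
    policy_loss = -(ratio * adv).mean() + beta * (old_logp - new_logp).mean()
    approx_kl = (old_logp - new_logp).mean().item()
    # Adapt beta toward target_kl; the caller must carry beta across updates
    # for the adaptation to take effect.
    if approx_kl > 1.5 * target_kl:
        beta *= 2.0   # step too large: increase the penalty
    elif approx_kl < 0.5 * target_kl:
        beta *= 0.5   # step too small: allow bigger updates
    return policy_loss, beta
```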
Observation_Advantage_Norm_diff_combo/ppo__rew_norm_obs_diff_combo.py ADDED
@@ -0,0 +1,1254 @@
+import numpy as np
+import torch as T
+import torch.nn as nn
+import torch.optim as optim
+from torch.distributions import Categorical
+
+
+class Agent:
+    def __init__(
+        self,
+        obs_space,
+        action_space,
+        hidden,
+        gamma,
+        clip_coef,
+        lr,
+        value_coef,
+        entropy_coef,
+        seed,
+        batch_size,
+        ppo_epochs,
+        lam,
+        update_type
+    ):
+        # Initialize seed for reproducibility
+        if seed is not None:
+            np.random.seed(seed)
+            T.manual_seed(seed)
+        """
+        # For flat observations (MLP model)
+        # Use GPU if available
+        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
+        self.obs_dim = int(np.prod(getattr(obs_space, "shape", (obs_space,))))
+        self.action_dim = int(getattr(action_space, "n", action_space))
+
+        # Initialize the policy and the critic networks
+        self.policy = Policy(self.obs_dim, self.action_dim, hidden).to(self.device)
+        self.critic = Critic(self.obs_dim, hidden).to(self.device)
+        """
+        # Use GPU if available
+        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
+        self.action_dim = int(getattr(action_space, "n", action_space))
+        self.update_type = update_type
+
+        # Initialize the policy and the critic networks
+        # Pass the shape tuple directly, not the flattened dimension.
+        self.policy = Policy(obs_space.shape, self.action_dim, hidden).to(self.device)
+        self.critic = Critic(obs_space.shape, hidden).to(self.device)
+        self.observeNorm = ObservationNorm()
+        self.advantageNorm = AdvantageNorm()
+        self.returnNorm = ReturnNorm()
+
+        # Set optimizer for policy and critic networks
+        self.opt = optim.Adam(
+            list(self.policy.parameters()) + list(self.critic.parameters()),
+            lr=lr
+        )
+
+        self.gamma = gamma
+        self.clip = clip_coef
+        self.value_coef = value_coef
+        self.entropy_coef = entropy_coef
+        self.sigma_history = []
+        self.loss_history = []
+        self.policy_loss_history = []
+        self.value_loss_history = []
+        self.entropy_history = []
+        self.lam = lam
+        self.ppo_epochs = ppo_epochs
+        self.batch_size = batch_size
+
+        self.memory = Memory()
+    """
+    # Choose action and remember for flat observations (MLP model)
+    def choose_action(self, observation):
+        # Returns: action, log probability, value of the state
+        state = T.as_tensor(observation, dtype=T.float32, device=self.device).view(-1)
+        with T.no_grad():
+            # Forward function (defined in Policy class)
+            dist = self.policy.next_action(state)
+            action = dist.sample()
+            logp = dist.log_prob(action)
+            value = self.critic.evaluated_state(state)
+        return int(action.item()), float(logp.item()), float(value.item())
+
+    def remember(self, state, action, reward, done, log_prob, value, next_state):
+        with T.no_grad():
+            # Pass on the next state and have it evaluated by the critic network
+            ns = T.as_tensor(next_state, dtype=T.float32, device=self.device).view(-1)
+            next_value = self.critic.evaluated_state(ns).item()
+        self.memory.store(state, action, reward, done, log_prob, value, next_value)
+    """
+    # For CNN model
+    def choose_action(self, observation):
+        # Returns: action, log probability, value of the state
+        state = T.as_tensor(observation, dtype=T.float32, device=self.device)  # Remove .view(-1)
+        with T.no_grad():
+            # Forward function (defined in Policy class)
+            dist = self.policy.next_action(state)
+            action = dist.sample()
+            logp = dist.log_prob(action)
+            value = self.critic.evaluated_state(state)
+        return int(action.item()), float(logp.item()), float(value.item())
+
+    def remember(self, state, action, reward, done, log_prob, value, next_state):
+        with T.no_grad():
+            # Pass on the next state and have it evaluated by the critic network
+            ns = T.as_tensor(next_state, dtype=T.float32, device=self.device)  # Remove .view(-1)
+            next_value = self.critic.evaluated_state(ns).item()
+        self.memory.store(state, action, reward, done, log_prob, value, next_value)
+
+    def _update(self):
+        if self.update_type == "update_all_norm":
+            return self.update_all_norm()
+        elif self.update_type == "update_observation_advantage_norm":
+            return self.update_observation_advantage_norm()
+        elif self.update_type == "update_observation_return_norm":
+            return self.update_observation_return_norm()
+        elif self.update_type == "update_advantage_return_norm":
+            return self.update_advantage_return_norm()
+        else:
+            return self.vanilla_ppo_update()
+
+    def vanilla_ppo_update(self):
+        if len(self.memory.states) == 0:
+            return 0.0
+
+        # Convert memory to tensors
+        states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
+        actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
+        rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
+        dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
+        old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
+        values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
+
+        with T.no_grad():
+            # Compute next values (bootstrap for final step)
+            next_values = T.cat([values[1:], values[-1:].clone()])
+            deltas = rewards + self.gamma * next_values * (1 - dones) - values
+
+            # --- GAE-Lambda ---
+            adv = T.zeros_like(rewards)
+            gae = 0.0
+            for t in reversed(range(len(rewards))):
+                gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
+                adv[t] = gae
+
+            returns = adv + values
+            # Advantage normalization
+            adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
+
+        # --- PPO Multiple Epochs + Minibatch ---
+        total_loss_epoch = 0.0
+        num_samples = len(states)
+        batch_size = min(64, num_samples)
+        ppo_epochs = 4
+
+        for _ in range(ppo_epochs):
+            # Shuffle indices
+            idxs = T.randperm(num_samples)
+            for start in range(0, num_samples, batch_size):
+                batch_idx = idxs[start:start + batch_size]
+
+                b_states = states[batch_idx]
+                b_actions = actions[batch_idx]
+                b_old_logp = old_logp[batch_idx]
+                b_returns = returns[batch_idx]
+                b_adv = adv[batch_idx]
+
+                dist = self.policy.next_action(b_states)
+                new_logp = dist.log_prob(b_actions)
+                entropy = dist.entropy().mean()
+                ratio = (new_logp - b_old_logp).exp()
+
+                # --- Clipped surrogate objective ---
+                surr1 = ratio * b_adv
+                surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
+                policy_loss = -T.min(surr1, surr2).mean()
+
+                # --- Critic loss ---
+                value_pred = self.critic.evaluated_state(b_states)
+                value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
+
+                # --- Total loss ---
+                total_loss = (
+                    policy_loss +
+                    self.value_coef * value_loss -
+                    self.entropy_coef * entropy
+                )
+
+                # Debug: track individual loss components
+                self.policy_loss_history.append(policy_loss.item())
+                self.value_loss_history.append(value_loss.item())
+
+                self.opt.zero_grad(set_to_none=True)
+                total_loss.backward()
+                self.opt.step()
+
+                total_loss_epoch += total_loss.item()
+
+        # Clear memory after full PPO update
+        self.memory.clear()
+
+        return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
+
+    def update_rbs(self):
+        if len(self.memory.states) == 0:
+            return 0.0
+
+        # Convert memory to tensors
+        states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
+        actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
+        rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
+        dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
+        old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
+        values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
+
+        with T.no_grad():
+            # Compute next values (bootstrap for final step)
+            next_values = T.cat([values[1:], values[-1:].clone()])
+            deltas = rewards + self.gamma * next_values * (1 - dones) - values
+
+            # --- GAE-Lambda ---
+            adv = T.zeros_like(rewards)
+            gae = 0.0
+            for t in reversed(range(len(rewards))):
+                gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
+                adv[t] = gae
+
+            returns = adv + values
+
+            # --- Return-based normalization (RBS) ---
+            sigma_t = returns.std(unbiased=False) + 1e-8
+            returns = returns / sigma_t
+            self.sigma_history.append(sigma_t.item())
+            adv = adv / sigma_t
+            # Advantage normalization
+            adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
+
+        # --- PPO Multiple Epochs + Minibatch ---
+        total_loss_epoch = 0.0
+        num_samples = len(states)
+        batch_size = min(64, num_samples)
+        ppo_epochs = 4
+
+        for _ in range(ppo_epochs):
+            # Shuffle indices
+            idxs = T.randperm(num_samples)
+            for start in range(0, num_samples, batch_size):
+                batch_idx = idxs[start:start + batch_size]
+
+                b_states = states[batch_idx]
+                b_actions = actions[batch_idx]
+                b_old_logp = old_logp[batch_idx]
+                b_returns = returns[batch_idx]
+                b_adv = adv[batch_idx]
+
+                dist = self.policy.next_action(b_states)
+                new_logp = dist.log_prob(b_actions)
+                entropy = dist.entropy().mean()
+                ratio = (new_logp - b_old_logp).exp()
+
+                # --- Clipped surrogate objective ---
+                surr1 = ratio * b_adv
+                surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
+                policy_loss = -T.min(surr1, surr2).mean()
+
+                # --- Critic loss ---
+                value_pred = self.critic.evaluated_state(b_states)
+                value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
+
+                # --- Total loss ---
+                total_loss = (
+                    policy_loss +
+                    self.value_coef * value_loss -
+                    self.entropy_coef * entropy
+                )
+
+                # Debug: track individual loss components
+                self.policy_loss_history.append(policy_loss.item())
+                self.value_loss_history.append(value_loss.item())
+
+                self.opt.zero_grad(set_to_none=True)
+                total_loss.backward()
+                self.opt.step()
+                total_loss_epoch += total_loss.item()
+
+        # Clear memory after full PPO update
+        self.memory.clear()
+
+        return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
+
+    '''
+    Different combinations of normalization techniques, combined to test whether
+    performance improves.
+    '''
+
+    def update_all_norm(self):
+        if len(self.memory.states) == 0:
+            return 0.0
+
+        # Convert memory to tensors
+        states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
+        actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
+        rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
+        dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
+        old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
+        values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
+
+        with T.no_grad():
+            # Compute next values (bootstrap for final step)
+            next_values = T.cat([values[1:], values[-1:].clone()])
+            deltas = rewards + self.gamma * next_values * (1 - dones) - values
+
+            # --- GAE-Lambda ---
+            adv = T.zeros_like(rewards)
+            gae = 0.0
+            for t in reversed(range(len(rewards))):
+                gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
+                adv[t] = gae
+
+            # Advantage normalization
+            self.advantageNorm.update(adv)
+            adv = self.advantageNorm.normalize(adv)
+
+            returns = adv + values
+
+            # --- returns normalization ---
+            self.returnNorm.update(returns)
+            returns = self.returnNorm.normalize(returns)
+
+            # --- observation normalization ---
+            self.observeNorm.update(states)
+            states = self.observeNorm.normalize(states)
+
+        # --- PPO Multiple Epochs + Minibatch ---
+        total_loss_epoch = 0.0
+        num_samples = len(states)
+        batch_size = min(64, num_samples)
+        ppo_epochs = 4
+
+        for _ in range(ppo_epochs):
+            # Shuffle indices
+            idxs = T.randperm(num_samples)
+            for start in range(0, num_samples, batch_size):
+                batch_idx = idxs[start:start + batch_size]
+
+                b_states = states[batch_idx]
+                b_actions = actions[batch_idx]
+                b_old_logp = old_logp[batch_idx]
+                b_returns = returns[batch_idx]
+                b_adv = adv[batch_idx]
+
+                dist = self.policy.next_action(b_states)
+                new_logp = dist.log_prob(b_actions)
+                entropy = dist.entropy().mean()
+                ratio = (new_logp - b_old_logp).exp()
+
+                # --- Clipped surrogate objective ---
+                surr1 = ratio * b_adv
+                surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
+                policy_loss = -T.min(surr1, surr2).mean()
+
+                # --- Critic loss ---
+                value_pred = self.critic.evaluated_state(b_states)
+                value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
+
+                # --- Total loss ---
+                total_loss = (
+                    policy_loss +
+                    self.value_coef * value_loss -
+                    self.entropy_coef * entropy
+                )
+
+                # Debug: track individual loss components
+                self.policy_loss_history.append(policy_loss.item())
+                self.value_loss_history.append(value_loss.item())
+
+                self.opt.zero_grad(set_to_none=True)
+                total_loss.backward()
+                self.opt.step()
+                total_loss_epoch += total_loss.item()
+
+        # Clear memory after full PPO update
+        self.memory.clear()
+
+        return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
+
+    def update_observation_advantage_norm(self):
+        if len(self.memory.states) == 0:
+            return 0.0
+
+        # Convert memory to tensors
+        states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
+        actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
+        rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
+        dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
+        old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
+        values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
+
+        with T.no_grad():
+            # Compute next values (bootstrap for final step)
+            next_values = T.cat([values[1:], values[-1:].clone()])
+            deltas = rewards + self.gamma * next_values * (1 - dones) - values
+
+            # --- GAE-Lambda ---
+            adv = T.zeros_like(rewards)
+            gae = 0.0
+            for t in reversed(range(len(rewards))):
+                gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
+                adv[t] = gae
+
+            # Advantage normalization
+            self.advantageNorm.update(adv)
+            adv = self.advantageNorm.normalize(adv)
+
+            returns = adv + values
+
+            # --- observation normalization ---
+            self.observeNorm.update(states)
+            states = self.observeNorm.normalize(states)
+
+        # --- PPO Multiple Epochs + Minibatch ---
+        total_loss_epoch = 0.0
+        num_samples = len(states)
+        batch_size = min(64, num_samples)
+        ppo_epochs = 4
+
+        for _ in range(ppo_epochs):
+            # Shuffle indices
+            idxs = T.randperm(num_samples)
+            for start in range(0, num_samples, batch_size):
+                batch_idx = idxs[start:start + batch_size]
+
+                b_states = states[batch_idx]
+                b_actions = actions[batch_idx]
+                b_old_logp = old_logp[batch_idx]
+                b_returns = returns[batch_idx]
+                b_adv = adv[batch_idx]
+
+                dist = self.policy.next_action(b_states)
+                new_logp = dist.log_prob(b_actions)
+                entropy = dist.entropy().mean()
+                ratio = (new_logp - b_old_logp).exp()
+
+                # --- Clipped surrogate objective ---
+                surr1 = ratio * b_adv
+                surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
+                policy_loss = -T.min(surr1, surr2).mean()
+
+                # --- Critic loss ---
+                value_pred = self.critic.evaluated_state(b_states)
+                value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
+
+                # --- Total loss ---
+                total_loss = (
+                    policy_loss +
+                    self.value_coef * value_loss -
+                    self.entropy_coef * entropy
+                )
+
+                # Debug: track individual loss components
+                self.policy_loss_history.append(policy_loss.item())
+                self.value_loss_history.append(value_loss.item())
+
+                self.opt.zero_grad(set_to_none=True)
+                total_loss.backward()
+                self.opt.step()
+                total_loss_epoch += total_loss.item()
+
+        # Clear memory after full PPO update
+        self.memory.clear()
+
+        return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
+
+    def update_observation_return_norm(self):
+        if len(self.memory.states) == 0:
+            return 0.0
+
+        # Convert memory to tensors
+        states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
+        actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
+        rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
+        dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
+        old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
+        values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
+
+        with T.no_grad():
+            # Compute next values (bootstrap for final step)
+            next_values = T.cat([values[1:], values[-1:].clone()])
+            deltas = rewards + self.gamma * next_values * (1 - dones) - values
+
+            # --- GAE-Lambda ---
+            adv = T.zeros_like(rewards)
+            gae = 0.0
+            for t in reversed(range(len(rewards))):
+                gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
+                adv[t] = gae
+
+            returns = adv + values
+
+            # --- returns normalization ---
+            self.returnNorm.update(returns)
+            returns = self.returnNorm.normalize(returns)
+
+            # --- observation normalization ---
+            self.observeNorm.update(states)
+            states = self.observeNorm.normalize(states)
+
+        # --- PPO Multiple Epochs + Minibatch ---
+        total_loss_epoch = 0.0
+        num_samples = len(states)
+        batch_size = min(64, num_samples)
+        ppo_epochs = 4
+
+        for _ in range(ppo_epochs):
+            # Shuffle indices
+            idxs = T.randperm(num_samples)
+            for start in range(0, num_samples, batch_size):
+                batch_idx = idxs[start:start + batch_size]
+
+                b_states = states[batch_idx]
+                b_actions = actions[batch_idx]
+                b_old_logp = old_logp[batch_idx]
+                b_returns = returns[batch_idx]
+                b_adv = adv[batch_idx]
+
+                dist = self.policy.next_action(b_states)
+                new_logp = dist.log_prob(b_actions)
+                entropy = dist.entropy().mean()
+                ratio = (new_logp - b_old_logp).exp()
+
+                # --- Clipped surrogate objective ---
+                surr1 = ratio * b_adv
+                surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
+                policy_loss = -T.min(surr1, surr2).mean()
+
+                # --- Critic loss ---
+                value_pred = self.critic.evaluated_state(b_states)
+                value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
+
+                # --- Total loss ---
+                total_loss = (
+                    policy_loss +
+                    self.value_coef * value_loss -
+                    self.entropy_coef * entropy
+                )
+
+                # Debug: track individual loss components
+                self.policy_loss_history.append(policy_loss.item())
+                self.value_loss_history.append(value_loss.item())
+
+                self.opt.zero_grad(set_to_none=True)
+                total_loss.backward()
+                self.opt.step()
+                total_loss_epoch += total_loss.item()
+
+        # Clear memory after full PPO update
+        self.memory.clear()
+
+        return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
+
+    def update_advantage_return_norm(self):
+        if len(self.memory.states) == 0:
+            return 0.0
+
+        # Convert memory to tensors
+        states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
+        actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
+        rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
+        dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
+        old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
+        values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
+
+        with T.no_grad():
+            # Compute next values (bootstrap for final step)
+            next_values = T.cat([values[1:], values[-1:].clone()])
+            deltas = rewards + self.gamma * next_values * (1 - dones) - values
+
+            # --- GAE-Lambda ---
+            adv = T.zeros_like(rewards)
+            gae = 0.0
+            for t in reversed(range(len(rewards))):
+                gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
+                adv[t] = gae
+
+            # Advantage normalization
+            self.advantageNorm.update(adv)
+            adv = self.advantageNorm.normalize(adv)
+
+            returns = adv + values
+
+            # --- returns normalization ---
+            self.returnNorm.update(returns)
+            returns = self.returnNorm.normalize(returns)
+
+        # --- PPO Multiple Epochs + Minibatch ---
+        total_loss_epoch = 0.0
+        num_samples = len(states)
+        batch_size = min(64, num_samples)
+        ppo_epochs = 4
+
+        for _ in range(ppo_epochs):
+            # Shuffle indices
+            idxs = T.randperm(num_samples)
+            for start in range(0, num_samples, batch_size):
+                batch_idx = idxs[start:start + batch_size]
+
+                b_states = states[batch_idx]
+                b_actions = actions[batch_idx]
+                b_old_logp = old_logp[batch_idx]
+                b_returns = returns[batch_idx]
+                b_adv = adv[batch_idx]
+
+                dist = self.policy.next_action(b_states)
+                new_logp = dist.log_prob(b_actions)
+                entropy = dist.entropy().mean()
+                ratio = (new_logp - b_old_logp).exp()
+
+                # --- Clipped surrogate objective ---
+                surr1 = ratio * b_adv
+                surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
+                policy_loss = -T.min(surr1, surr2).mean()
+
+                # --- Critic loss ---
+                value_pred = self.critic.evaluated_state(b_states)
+                value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
+
+                # --- Total loss ---
+                total_loss = (
+                    policy_loss +
+                    self.value_coef * value_loss -
+                    self.entropy_coef * entropy
+                )
+
+                # Debug: track individual loss components
+                self.policy_loss_history.append(policy_loss.item())
+                self.value_loss_history.append(value_loss.item())
+
+                self.opt.zero_grad(set_to_none=True)
+                total_loss.backward()
+                self.opt.step()
+                total_loss_epoch += total_loss.item()
+
+        # Clear memory after full PPO update
+        self.memory.clear()
+
+        return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
+    # ------------------------------------------ #
+
+    def update_observation_norm(self):
+        if len(self.memory.states) == 0:
+            return 0.0
+
+        # Convert memory to tensors
+        states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
+        actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
+        rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
+        dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
+        old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
+        values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
+
+        with T.no_grad():
+            # Compute next values (bootstrap for final step)
+            next_values = T.cat([values[1:], values[-1:].clone()])
+            deltas = rewards + self.gamma * next_values * (1 - dones) - values
+
+            # --- GAE-Lambda ---
+            adv = T.zeros_like(rewards)
+            gae = 0.0
+            for t in reversed(range(len(rewards))):
+                gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
+                adv[t] = gae
+
+            returns = adv + values
+
+            # --- observation normalization ---
+            self.observeNorm.update(states)
+            states = self.observeNorm.normalize(states)
+            # Advantage normalization
+            adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
+
+        # --- PPO Multiple Epochs + Minibatch ---
+        total_loss_epoch = 0.0
+        num_samples = len(states)
+        batch_size = min(64, num_samples)
+        ppo_epochs = 4
+
+        for _ in range(ppo_epochs):
+            # Shuffle indices
+            idxs = T.randperm(num_samples)
+            for start in range(0, num_samples, batch_size):
+                batch_idx = idxs[start:start + batch_size]
+
+                b_states = states[batch_idx]
+                b_actions = actions[batch_idx]
+                b_old_logp = old_logp[batch_idx]
+                b_returns = returns[batch_idx]
+                b_adv = adv[batch_idx]
+
+                dist = self.policy.next_action(b_states)
+                new_logp = dist.log_prob(b_actions)
+                entropy = dist.entropy().mean()
+                ratio = (new_logp - b_old_logp).exp()
+
+                # --- Clipped surrogate objective ---
+                surr1 = ratio * b_adv
+                surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
+                policy_loss = -T.min(surr1, surr2).mean()
+
+                # --- Critic loss ---
+                value_pred = self.critic.evaluated_state(b_states)
+                value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
+
+                # --- Total loss ---
+                total_loss = (
+                    policy_loss +
+                    self.value_coef * value_loss -
+                    self.entropy_coef * entropy
+                )
+
+                # Debug: track individual loss components
+                self.policy_loss_history.append(policy_loss.item())
+                self.value_loss_history.append(value_loss.item())
+
+                self.opt.zero_grad(set_to_none=True)
+                total_loss.backward()
+                self.opt.step()
+                total_loss_epoch += total_loss.item()
+
+        # Clear memory after full PPO update
+        self.memory.clear()
+
+        return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
+
+    def update_advantage_norm(self):
+        if len(self.memory.states) == 0:
+            return 0.0
+
+        # Convert memory to tensors
+        states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
+        actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
+        rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
+        dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
+        old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
+        values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
+
+        with T.no_grad():
+            # Compute next values (bootstrap for final step)
+            next_values = T.cat([values[1:], values[-1:].clone()])
+            deltas = rewards + self.gamma * next_values * (1 - dones) - values
+
+            # --- GAE-Lambda ---
+            adv = T.zeros_like(rewards)
+            gae = 0.0
+            for t in reversed(range(len(rewards))):
+                gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
+                adv[t] = gae
+
+            # --- Advantage normalization ---
+            self.advantageNorm.update(adv)
+            adv = self.advantageNorm.normalize(adv)
+
+            returns = adv + values
+
+        # --- PPO Multiple Epochs + Minibatch ---
+        total_loss_epoch = 0.0
+        num_samples = len(states)
+        batch_size = min(64, num_samples)
+        ppo_epochs = 4
+
+        for _ in range(ppo_epochs):
+            # Shuffle indices
+            idxs = T.randperm(num_samples)
+            for start in range(0, num_samples, batch_size):
+                batch_idx = idxs[start:start + batch_size]
+
+                b_states = states[batch_idx]
+                b_actions = actions[batch_idx]
+                b_old_logp = old_logp[batch_idx]
+                b_returns = returns[batch_idx]
+                b_adv = adv[batch_idx]
+
+                dist = self.policy.next_action(b_states)
+                new_logp = dist.log_prob(b_actions)
+                entropy = dist.entropy().mean()
+                ratio = (new_logp - b_old_logp).exp()
+
+                # --- Clipped surrogate objective ---
+                surr1 = ratio * b_adv
+                surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
+                policy_loss = -T.min(surr1, surr2).mean()
+
+                # --- Critic loss ---
+                value_pred = self.critic.evaluated_state(b_states)
+                value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
+
+                # --- Total loss ---
+                total_loss = (
+                    policy_loss +
+                    self.value_coef * value_loss -
+                    self.entropy_coef * entropy
+                )
+
+                # Debug: track individual loss components
+                self.policy_loss_history.append(policy_loss.item())
+                self.value_loss_history.append(value_loss.item())
+
+                self.opt.zero_grad(set_to_none=True)
+                total_loss.backward()
+                self.opt.step()
+                total_loss_epoch += total_loss.item()
+
+        # Clear memory after full PPO update
+        self.memory.clear()
+
+        return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
+
+    def update_return_norm(self):
+        if len(self.memory.states) == 0:
+            return 0.0
+
+        # Convert memory to tensors
+        states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
+        actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
+        rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
+        dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
+        old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
+        values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
+
+        with T.no_grad():
+            # Compute next values (bootstrap for final step)
+            next_values = T.cat([values[1:], values[-1:].clone()])
+            deltas = rewards + self.gamma * next_values * (1 - dones) - values
+
+            # --- GAE-Lambda ---
+            adv = T.zeros_like(rewards)
+            gae = 0.0
+            for t in reversed(range(len(rewards))):
+                gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
+                adv[t] = gae
+
+            returns = adv + values
+
+            # --- returns normalization ---
+            self.returnNorm.update(returns)
+            returns = self.returnNorm.normalize(returns)
+
+            # Advantage normalization
+            adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
+
+        # --- PPO Multiple Epochs + Minibatch ---
+        total_loss_epoch = 0.0
+        num_samples = len(states)
+        batch_size = min(64, num_samples)
+        ppo_epochs = 4
+
+        for _ in range(ppo_epochs):
+            # Shuffle indices
+            idxs = T.randperm(num_samples)
+            for start in range(0, num_samples, batch_size):
+                batch_idx = idxs[start:start + batch_size]
+
+                b_states = states[batch_idx]
+                b_actions = actions[batch_idx]
+                b_old_logp = old_logp[batch_idx]
+                b_returns = returns[batch_idx]
+                b_adv = adv[batch_idx]
+
+                dist = self.policy.next_action(b_states)
+                new_logp = dist.log_prob(b_actions)
+                entropy = dist.entropy().mean()
+                ratio = (new_logp - b_old_logp).exp()
+
+                # --- Clipped surrogate objective ---
+                surr1 = ratio * b_adv
+                surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
+                policy_loss = -T.min(surr1, surr2).mean()
+
+                # --- Critic loss ---
+                value_pred = self.critic.evaluated_state(b_states)
+                value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
+
+                # --- Total loss ---
+                total_loss = (
+                    policy_loss +
+                    self.value_coef * value_loss -
+                    self.entropy_coef * entropy
+                )
+
+                # Debug: track individual loss components
+                self.policy_loss_history.append(policy_loss.item())
+                self.value_loss_history.append(value_loss.item())
+
+                self.opt.zero_grad(set_to_none=True)
+                total_loss.backward()
+                self.opt.step()
+                total_loss_epoch += total_loss.item()
+
+        # Clear memory after full PPO update
+        self.memory.clear()
+
+        return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
+
+    def update_reward_gradient_clipping(self):
+        if len(self.memory.states) == 0:
+            return 0.0
+
+        # Convert memory to tensors
+        states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
+        actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
+        rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
+        dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
+        old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
+        values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
+
+        # Reward clipping
+        rewards = T.clamp(rewards, -1, 1)
+
+        with T.no_grad():
+            # Compute next values (bootstrap for final step)
+            next_values = T.cat([values[1:], values[-1:].clone()])
+            deltas = rewards + self.gamma * next_values * (1 - dones) - values
+
+            # --- GAE-Lambda ---
+            adv = T.zeros_like(rewards)
+            gae = 0.0
+            for t in reversed(range(len(rewards))):
+                gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
+                adv[t] = gae
+
+            returns = adv + values
+            # Advantage normalization
+            adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
+
+        # --- PPO Multiple Epochs + Minibatch ---
+        total_loss_epoch = 0.0
+        num_samples = len(states)
+        batch_size = min(64, num_samples)
+        ppo_epochs = 4
+
+        for _ in range(ppo_epochs):
+            # Shuffle indices
+            idxs = T.randperm(num_samples)
+            for start in range(0, num_samples, batch_size):
+                batch_idx = idxs[start:start + batch_size]
+
+                b_states = states[batch_idx]
+                b_actions = actions[batch_idx]
+                b_old_logp = old_logp[batch_idx]
+                b_returns = returns[batch_idx]
+                b_adv = adv[batch_idx]
+
+                dist = self.policy.next_action(b_states)
+                new_logp = dist.log_prob(b_actions)
+                entropy = dist.entropy().mean()
+                ratio = (new_logp - b_old_logp).exp()
+
+                # --- Clipped surrogate objective ---
+                surr1 = ratio * b_adv
+                surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
+                policy_loss = -T.min(surr1, surr2).mean()
+
+                # --- Critic loss ---
+                value_pred = self.critic.evaluated_state(b_states)
+                value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
+
+                # --- Total loss ---
+                total_loss = (
+                    policy_loss +
+                    self.value_coef * value_loss -
+                    self.entropy_coef * entropy
+                )
+
+                # Debug: track individual loss components
+                self.policy_loss_history.append(policy_loss.item())
+                self.value_loss_history.append(value_loss.item())
+
+                self.opt.zero_grad(set_to_none=True)
+                total_loss.backward()
+                T.nn.utils.clip_grad_norm_(list(self.policy.parameters()) + list(self.critic.parameters()), 0.5)
+                self.opt.step()
+
+                total_loss_epoch += total_loss.item()
+
+        # Clear memory after full PPO update
+        self.memory.clear()
+
+        return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
+
+"""
+# Policy network (simple MLP, flattened observations)
+class Policy(nn.Module):
+    def __init__(self, obs_dim: int, action_dim: int, hidden: int):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(obs_dim, hidden),
+            nn.ReLU(),
+            nn.Linear(hidden, hidden),
+            nn.ReLU(),
+            nn.Linear(hidden, action_dim)
+        )
+
+    def next_action(self, state: T.Tensor) -> Categorical:
+        # Returns the probability distribution over actions
+        if state.dim() == 1:
+            state = state.unsqueeze(0)
+        state = state.view(state.size(0), -1)
+        return Categorical(logits=self.net(state))
+"""
+
+# Policy network (CNN)
+class Policy(nn.Module):
+    def __init__(self, obs_shape: tuple, action_dim: int, hidden: int):
+        super().__init__()
+        c, h, w = obs_shape
+        # Suggested architecture for Atari: https://arxiv.org/pdf/1312.5602
+        self.cnn = nn.Sequential(
+            nn.Conv2d(c, 16, kernel_size=8, stride=4),
+            nn.ReLU(),
+            nn.Conv2d(16, 32, kernel_size=4, stride=2),
+            nn.ReLU(),
+            nn.Flatten()
+        )
+
+        with T.no_grad():
+            cnn_output_dim = self.cnn(T.zeros(1, c, h, w)).shape[1]
+
+        self.net = nn.Sequential(
+            nn.Linear(cnn_output_dim, hidden),
+            nn.ReLU(),
+            nn.Linear(hidden, action_dim)
+        )
+
+    def next_action(self, state: T.Tensor) -> Categorical:
+        # Returns the probability distribution over actions
+        if state.dim() == 3:
+            state = state.unsqueeze(0)
+        cnn_out = self.cnn(state)
+        return Categorical(logits=self.net(cnn_out))
+
+"""
+# Critic network (simple MLP, flattened observations)
+class Critic(nn.Module):
+    def __init__(self, obs_dim: int, hidden: int):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(obs_dim, hidden),
+            nn.ReLU(),
+            nn.Linear(hidden, hidden),
+            nn.ReLU(),
+            nn.Linear(hidden, 1)
+        )
+
+    def evaluated_state(self, x: T.Tensor) -> T.Tensor:
+        if x.dim() == 1:
+            x = x.unsqueeze(0)
+        x = x.view(x.size(0), -1)
+        return self.net(x).squeeze(-1)
+"""
+
+# Critic network (CNN)
+class Critic(nn.Module):
+    def __init__(self, obs_shape: tuple, hidden: int):
+        super().__init__()
+        c, h, w = obs_shape
+        # Suggested architecture for Atari: https://arxiv.org/pdf/1312.5602
+        self.cnn = nn.Sequential(
+            nn.Conv2d(c, 16, kernel_size=8, stride=4),
+            nn.ReLU(),
+            nn.Conv2d(16, 32, kernel_size=4, stride=2),
+            nn.ReLU(),
+            nn.Flatten()
+        )
+
+        with T.no_grad():
+            cnn_output_dim = self.cnn(T.zeros(1, c, h, w)).shape[1]
+
+        self.net = nn.Sequential(
+            nn.Linear(cnn_output_dim, hidden),
+            nn.ReLU(),
+            nn.Linear(hidden, 1)
+        )
+
+    def evaluated_state(self, x: T.Tensor) -> T.Tensor:
+        if x.dim() == 3:
+            x = x.unsqueeze(0)
+        cnn_out = self.cnn(x)
+        return self.net(cnn_out).squeeze(-1)
+
+class Memory():
+    def __init__(self):
+        self.states = []
+        self.actions = []
+        self.rewards = []
+        self.dones = []
+        self.log_probs = []
+        self.values = []
+        self.next_values = []
+
+    def store(self, state, action, reward, done, log_prob, value, next_value):
+        self.states.append(np.asarray(state, dtype=np.float32))
+        self.actions.append(int(action))
+        self.rewards.append(float(reward))
+        self.dones.append(float(done))
+        self.log_probs.append(float(log_prob))
+        self.values.append(float(value))
+        self.next_values.append(float(next_value))
+
+    """
+    # For mini-batch updates? To be implemented
+    def start_batch(self, batch_size: int):
+        n_states = len(self.states)
+        starts = np.arange(0, n_states, batch_size)
+        index = np.arange(n_states, dtype=np.int64)
+        np.random.shuffle(index)
+        return [index[s:s + batch_size] for s in starts]
+    """
+
+    def clear(self):
+        self.states = []
+        self.actions = []
+        self.rewards = []
+        self.dones = []
+        self.log_probs = []
+        self.values = []
+        self.next_values = []
+
+
+class ObservationNorm:
+    def __init__(self):
+        self.main_mean = 0
+        self.main_var = 0
+        self.count = 1e-4
+
+    def update(self, x: T.Tensor):
+        batch_mean = T.mean(x, dim=0)
+        batch_var = T.var(x, dim=0)
+        batch_count = x.shape[0]
1160
+ self._update_from_moments(batch_mean, batch_var, batch_count)
1161
+
1162
+ def _update_from_moments(self, batch_mean, batch_var, batch_count):
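+ # Parallel/Welford-style merge: fold the batch (mean, var, count) into the running
+ # statistics so they reflect every sample seen so far.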
1163
+ delta = batch_mean - self.main_mean
1164
+ tot_count = self.count + batch_count
1165
+ new_mean = self.main_mean + delta * batch_count / tot_count #Update the running mean
1166
+ m_a = self.main_var * self.count
1167
+ m_b = batch_var * batch_count
1168
+ M2 = m_a + m_b + delta.pow(2) * self.count * batch_count / tot_count  # torch op: np.square fails on CUDA tensors
1169
+ new_var = M2 / tot_count # update the running variance
1170
+
1171
+ self.main_mean, self.main_var, self.count = new_mean, new_var, tot_count
1172
+
1173
+ def normalize(self, x):
1174
+
1175
+ return (x - self.main_mean) / (T.sqrt(self.main_var) + 1e-8) # We add epsilon to make sure that we don't
1176
+ # divide by zero.
1177
+
1178
+
1179
+
1180
+
1181
+
1182
+ class AdvantageNorm:
1183
+ '''
1184
+ This class implements the Advantage Normalization. The purpose is to normalize either across batches or
1185
+ only within the same batch.
1186
+
1187
+ '''
1188
+ def __init__(self):
1189
+ self.main_mean = 0
1190
+ self.main_var = 0
1191
+ self.count = 1e-4
1192
+
1193
+ def update(self, x: T.Tensor):
1194
+ batch_mean = T.mean(x, dim=0)
1195
+ batch_var = T.var(x, dim=0)
1196
+ batch_count = x.shape[0]
1197
+ self._update_from_moments(batch_mean, batch_var, batch_count)
1198
+
1199
+ def _update_from_moments(self, batch_mean, batch_var, batch_count):
1200
+ delta = batch_mean - self.main_mean
1201
+ tot_count = self.count + batch_count
1202
+ new_mean = self.main_mean + delta * batch_count / tot_count #Update the running mean
1203
+ m_a = self.main_var * self.count
1204
+ m_b = batch_var * batch_count
1205
+ M2 = m_a + m_b + delta.pow(2) * self.count * batch_count / tot_count
1206
+ new_var = M2 / tot_count # update the running variance
1207
+
1208
+ self.main_mean, self.main_var, self.count = new_mean, new_var, tot_count
1209
+
1210
+ def normalize(self, x):
1211
+
1212
+ return (x - self.main_mean) / (T.sqrt(self.main_var) + 1e-8) # We add epsilon to make sure that we don't
1212
+ # divide by zero.
1214
+
1215
+
1216
+
1217
+
1218
+ class ReturnNorm:
1219
+ '''
1220
+ This class implements Return Normalization. The purpose is to normalize either across batches or
1221
+ only within the same batch.
1222
+
1223
+ '''
1224
+ def __init__(self):
1225
+ self.main_mean = 0
1226
+ self.main_var = 0
1227
+ self.count = 1e-4
1228
+
1229
+ def update(self, x: T.Tensor):
1230
+ batch_mean = T.mean(x, dim=0)
1231
+ batch_var = T.var(x, dim=0)
1232
+ batch_count = x.shape[0]
1233
+ self._update_from_moments(batch_mean, batch_var, batch_count)
1234
+
1235
+ def _update_from_moments(self, batch_mean, batch_var, batch_count):
1236
+ delta = batch_mean - self.main_mean
1237
+ tot_count = self.count + batch_count
1238
+ new_mean = self.main_mean + delta * batch_count / tot_count #Update the running mean
1239
+ m_a = self.main_var * self.count
1240
+ m_b = batch_var * batch_count
1241
+ M2 = m_a + m_b + delta.pow(2) * self.count * batch_count / tot_count
1242
+ new_var = M2 / tot_count # update the running variance
1243
+
1244
+ self.main_mean, self.main_var, self.count = new_mean, new_var, tot_count
1245
+
1246
+ def normalize(self, x):
1247
+
1248
+ return (x - self.main_mean) / (T.sqrt(self.main_var) + 1e-8) # We add epsilon to make sure that we don't
1248
+ # divide by zero.
1250
+
1251
+
1252
+
1253
+
1254
+
Observation_Advantage_Norm_diff_combo/ppo_rew_norm_obs_env_diff_combo.py ADDED
@@ -0,0 +1,201 @@
1
+
2
+ import gymnasium as gym
3
+ import sys
4
+ import matplotlib.pyplot as plt
5
+ import ale_py
6
+ from ppo__rew_norm_obs_diff_combo import *
7
+ from gymnasium.spaces import Box
8
+ import cv2
9
+
10
+
11
+ class PlotCreater:
12
+ def __init__(self):
13
+ self.fig = plt.figure(figsize=(12, 8))
14
+ self.ax2 = plt.subplot(221)
15
+ self.ax3 = plt.subplot(222)
16
+ self.ax4 = plt.subplot(223)
17
+ self.ax5 = plt.subplot(224)
18
+
19
+ """
20
+ # Plot for Return-Based Scaling only
21
+ ax1 = plt.subplot(220)
22
+ ax1.plot(agent.sigma_history, label="Return σ")
23
+ ax1.set_xlabel("PPO Update")
24
+ ax1.set_ylabel("σ (Return Std)")
25
+ """
26
+
27
+
28
+
29
+ def lossHistorySetting(self, loss_history, update_type):
30
+ self.ax2.plot(loss_history, label=update_type)
31
+
32
+
33
+ def rewardSetting(self, reward_history, update_type):
34
+ self.ax3.plot(reward_history, label=update_type)
35
+
36
+ def policyHistorySetting(self, policy_history, update_type):
37
+ self.ax4.plot(policy_history, label=update_type)
38
+
39
+ def valueLossSetting(self, value_loss_history, update_type):
40
+ self.ax5.plot(value_loss_history, label=update_type)
41
+
42
+
43
+
44
+
45
+ def setTitle(self, title):
46
+ self.fig.suptitle(title)
47
+
48
+
49
+ def plotShow(self):
50
+
51
+ self.ax2.set_ylabel("Average PPO Loss")
52
+ self.ax2.set_xlabel("PPO Update")
53
+ self.ax2.legend()
54
+
55
+ self.ax3.set_ylabel("Reward")
56
+ self.ax3.set_xlabel("PPO Update")
57
+ self.ax3.legend()
58
+
59
+ # Details about value loss and policy loss
60
+
61
+ self.ax4.set_ylabel("Policy Loss")
62
+ self.ax4.set_xlabel("Training Step")
63
+ self.ax4.legend()
64
+
65
+ self.ax5.set_ylabel("Value Loss")
66
+ self.ax5.set_xlabel("Training Step")
67
+ self.ax5.legend()
68
+
69
+ self.fig.suptitle("PPO Training Stability of type " +
70
+ "-running_average")
71
+ self.fig.tight_layout()
72
+ self.fig.savefig("Different_combination_running_average_.png")
73
+ plt.show()
74
+ print("Show the graph and store them")
75
+
76
+
77
+ def preprocess(obs):
78
+ # Convert to grayscale
79
+ obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
80
+ # Resize
81
+ obs = cv2.resize(obs, (84, 84), interpolation=cv2.INTER_AREA)
82
+ # Add channel dimension and normalize
83
+ return np.expand_dims(obs, axis=0).astype(np.float32) / 255.0
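+ # preprocess() yields a (1, 84, 84) float32 array in [0, 1], matching the (c, h, w) layout the CNN expects.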
84
+
85
+
86
+ def rl_model(update_type, plotCreater):
87
+ # env = gym.make("ALE/SpaceInvaders-v5", render_mode='human')
88
+ # env = gym.make("ALE/Pacman-v5", render_mode="human")
89
+ env = gym.make("ALE/Pacman-v5")
90
+
91
+ episode = 0
92
+ total_return = 0
93
+ ep_return = 0
94
+ steps = 1000
95
+ batches = 100
96
+
97
+ print("Observation space:", env.observation_space)
98
+ print("Action space:", env.action_space)
99
+ """
100
+ agent = Agent(obs_space=env.observation_space, action_space=env.action_space,
101
+ hidden=64, lr=3e-4, gamma=0.99, clip_coef=0.2,
102
+ entropy_coef=0.01, value_coef=0.5, seed=70,
103
+ batch_size = 64, ppo_epochs = 4, lam = 0.95)
104
+
105
+ """
106
+ # Initialize CNN with a dummy observation (to get correct input shape)
107
+ obs, _ = env.reset()
108
+ dummy_obs_space = Box(low=0.0, high=1.0, shape=preprocess(obs).shape)
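+ # The Box only carries the preprocessed observation shape through to the Agent's CNN.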
109
+
110
+ agent = Agent(obs_space=dummy_obs_space, action_space=env.action_space,
111
+ hidden=64, lr=3e-4, gamma=0.99, clip_coef=0.2,
112
+ entropy_coef=0.01, value_coef=0.5, seed=70,
113
+ batch_size=64, ppo_epochs=4, lam=0.95, update_type=update_type)
114
+ """
115
+ # Stats for Return-Based Scaling only
116
+ # === Return-Based Scaling stats ===
117
+ r_mean, r_var = 0.0, 1e-8
118
+ g2_mean = 1.0
119
+
120
+ agent.r_var = r_var
121
+ agent.g2_mean = g2_mean
122
+ """
123
+
124
+ try:
125
+ obs, info = env.reset(seed=42)
126
+ state = preprocess(obs)
127
+
128
+ loss_history = []
129
+ reward_history = []
130
+
131
+ for update in range(1, batches + 1):
132
+ for t in range(steps):
133
+ action, logp, value = agent.choose_action(state)
134
+ next_obs, reward, terminated, truncated, info = env.step(action)
135
+ done = terminated or truncated
136
+ next_state = preprocess(next_obs)
137
+
138
+ agent.remember(state, action, reward, done, logp, value, next_state)
139
+
140
+ ep_return += reward
141
+ state = next_state
142
+
143
+ if done:
144
+ episode += 1
145
+ total_return += ep_return
146
+ print(f"Episode {episode} return: {ep_return:.2f}")
147
+ ep_return = 0
148
+ obs, info = env.reset()
149
+ state = preprocess(obs)
150
+
151
+ # Using reward gradient clipping
152
+ avg_loss = agent._update()
153
+
154
+ # Vanilla PPO (no normalization)
155
+ # avg_loss = agent.vanilla_ppo_update()
156
+ loss_history.append(avg_loss)
157
+
158
+ avg_ret = (total_return / episode) if episode else 0
159
+ reward_history.append(avg_ret)
160
+ print(f"Update {update}: episodes={episode}, avg_return={avg_ret:.2f}, avg_loss={avg_loss:.4f}")
161
+
162
+ plotCreater.lossHistorySetting(loss_history, update_type)
163
+ plotCreater.rewardSetting(reward_history, update_type)
164
+ plotCreater.policyHistorySetting(agent.policy_loss_history, update_type)
165
+ plotCreater.valueLossSetting(agent.value_loss_history, update_type)
166
+
167
+
168
+
169
+
170
+ except Exception as e:
171
+ print(f"Error: {e}", file=sys.stderr)
172
+ return 1
173
+ finally:
174
+ avg = total_return / episode if episode else 0
175
+ print(f"\nEpisodes: {episode}, Avg return: {avg:.3f}")
176
+ env.close()
177
+
178
+ return 0
179
+
180
+
181
+
182
+
183
+
184
+
185
+
186
+
187
+ def main() -> int:
188
+ combo_type_list = ["update_all_norm", "update_observation_advantage_norm"
189
+ , "update_observation_return_norm", "update_advantage_return_norm"]
190
+ type_list = ["update_observation_norm", "update_advantage_norm", "update_return_norm", "vanilla_ppo_update"]
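+ # type_list (the single-normalization variants) is kept for reference; only combo_type_list is run below.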
191
+
192
+ plotCreater = PlotCreater()
193
+ for update_type in combo_type_list:
194
+ rl_model(update_type, plotCreater)
195
+
196
+ plotCreater.plotShow()
197
+ return 0
198
+
199
+
200
+ if __name__ == "__main__":
201
+ raise SystemExit(main())
Observation_Advantage_Norm_diff_env/ppo__rew_norm_obs_diff_env.py ADDED
@@ -0,0 +1,891 @@
1
+ import numpy as np
2
+ import torch as T
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torch.distributions import Categorical
6
+
7
+
8
+ class Agent:
9
+ def __init__(
10
+ self,
11
+ obs_space,
12
+ action_space,
13
+ hidden,
14
+ gamma,
15
+ clip_coef,
16
+ lr,
17
+ value_coef,
18
+ entropy_coef,
19
+ seed,
20
+ batch_size,
21
+ ppo_epochs,
22
+ lam,
23
+ update_type
24
+
25
+ ):
26
+ # Initialize seed for reproducibility
27
+ if seed is not None:
28
+ np.random.seed(seed)
29
+ T.manual_seed(seed)
30
+ """
31
+ # For flat observations (MLP model)
32
+ # Use GPU if available
33
+ self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
34
+ self.obs_dim = int(np.prod(getattr(obs_space, "shape", (obs_space,))))
35
+ self.action_dim = int(getattr(action_space, "n", action_space))
36
+
37
+ # Initialize the policy and the critic networks
38
+ self.policy = Policy(self.obs_dim, self.action_dim, hidden).to(self.device)
39
+ self.critic = Critic(self.obs_dim, hidden).to(self.device)
40
+ """
41
+ # Use GPU if available
42
+ self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
43
+ self.action_dim = int(getattr(action_space, "n", action_space))
44
+ self.update_type = update_type
45
+
46
+ # Initialize the policy and the critic networks
47
+ # Pass the shape tuple directly, not the flattened dimension.
48
+ self.policy = Policy(obs_space.shape, self.action_dim, hidden).to(self.device)
49
+ self.critic = Critic(obs_space.shape, hidden).to(self.device)
50
+ self.observeNorm = ObservationNorm()
51
+ self.advantageNorm = AdvantageNorm()
52
+ self.returnNorm = ReturnNorm()
53
+
54
+ # Set optimizer for policy and critic networks
55
+ self.opt = optim.Adam(
56
+ list(self.policy.parameters()) + list(self.critic.parameters()),
57
+ lr=lr
58
+ )
59
+
60
+ self.gamma = gamma
61
+ self.clip = clip_coef
62
+ self.value_coef = value_coef
63
+ self.entropy_coef = entropy_coef
64
+ self.sigma_history = []
65
+ self.loss_history = []
66
+ self.policy_loss_history = []
67
+ self.value_loss_history = []
68
+ self.entropy_history = []
69
+ self.lam = lam
70
+ self.ppo_epochs = ppo_epochs
71
+ self.batch_size = batch_size
72
+
73
+ self.memory = Memory()
74
+ """
75
+ # Choose action and remember for flat observations (MLP model)
76
+ def choose_action(self, observation):
77
+ # Returns: action, log probability, value of the state
78
+ state = T.as_tensor(observation, dtype=T.float32, device=self.device).view(-1)
79
+ with T.no_grad():
80
+ # Forward function (defined in Policy class)
81
+ dist = self.policy.next_action(state)
82
+ action = dist.sample()
83
+ logp = dist.log_prob(action)
84
+ value = self.critic.evaluated_state(state)
85
+ return int(action.item()), float(logp.item()), float(value.item())
86
+
87
+ def remember(self, state, action, reward, done, log_prob, value, next_state):
88
+ with T.no_grad():
89
+ # Pass on next state and have it evaluated by the critic network
90
+ ns = T.as_tensor(next_state, dtype=T.float32, device=self.device).view(-1)
91
+ next_value = self.critic.evaluated_state(ns).item()
92
+ self.memory.store(state, action, reward, done, log_prob, value, next_value)
93
+ """
94
+ # For CNN model
95
+ def choose_action(self, observation):
96
+ # Returns: action, log probability, value of the state
97
+ state = T.as_tensor(observation, dtype=T.float32, device=self.device) # Remove .view(-1)
98
+ with T.no_grad():
99
+ # Forward function (defined in Policy class)
100
+ dist = self.policy.next_action(state)
101
+ action = dist.sample()
102
+ logp = dist.log_prob(action)
103
+ value = self.critic.evaluated_state(state)
104
+ return int(action.item()), float(logp.item()), float(value.item())
105
+
106
+ def remember(self, state, action, reward, done, log_prob, value, next_state):
107
+ with T.no_grad():
108
+ # Pass on next state and have it evaluated by the critic network
109
+ ns = T.as_tensor(next_state, dtype=T.float32, device=self.device) # Remove .view(-1)
110
+ next_value = self.critic.evaluated_state(ns).item()
111
+ self.memory.store(state, action, reward, done, log_prob, value, next_value)
112
+
113
+
114
+ def _update(self):
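+ # Dispatch to the normalization variant selected via update_type;
+ # any unrecognised value falls back to vanilla PPO.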
115
+ if self.update_type == "update_observation_norm":
116
+ return self.update_observation_norm()
117
+ elif self.update_type == "update_advantage_norm":
118
+ return self.update_advantage_norm()
119
+ elif self.update_type == "update_return_norm":
120
+ return self.update_return_norm()
121
+ else:
122
+ return self.vanilla_ppo_update()
123
+
124
+ def vanilla_ppo_update(self):
125
+ if len(self.memory.states) == 0:
126
+ return 0.0
127
+
128
+ # Convert memory to tensors
129
+ states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
130
+ actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
131
+ rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
132
+ dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
133
+ old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
134
+ values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
135
+
136
+ with T.no_grad():
137
+ # Compute next values (bootstrap for final step)
138
+ next_values = T.cat([values[1:], values[-1:].clone()])
139
+ deltas = rewards + self.gamma * next_values * (1 - dones) - values
140
+
141
+ # --- GAE-Lambda ---
142
+ adv = T.zeros_like(rewards)
143
+ gae = 0.0
144
+ for t in reversed(range(len(rewards))):
145
+ gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
146
+ adv[t] = gae
147
+
148
+ returns = adv + values
149
+ # Advantage normalization
150
+ adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
151
+
152
+ # --- PPO Multiple Epochs + Minibatch ---
153
+ total_loss_epoch = 0.0
154
+ num_samples = len(states)
155
+ batch_size = min(64, num_samples)
156
+ ppo_epochs = 4
157
+
158
+ for _ in range(ppo_epochs):
159
+ # Shuffle indices
160
+ idxs = T.randperm(num_samples)
161
+ for start in range(0, num_samples, batch_size):
162
+ batch_idx = idxs[start:start + batch_size]
163
+
164
+ b_states = states[batch_idx]
165
+ b_actions = actions[batch_idx]
166
+ b_old_logp = old_logp[batch_idx]
167
+ b_returns = returns[batch_idx]
168
+ b_adv = adv[batch_idx]
169
+
170
+ dist = self.policy.next_action(b_states)
171
+ new_logp = dist.log_prob(b_actions)
172
+ entropy = dist.entropy().mean()
173
+ ratio = (new_logp - b_old_logp).exp()
174
+
175
+ # --- Clipped surrogate objective ---
176
+ surr1 = ratio * b_adv
177
+ surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
178
+ policy_loss = -T.min(surr1, surr2).mean()
179
+
180
+ # --- Critic loss ---
181
+ value_pred = self.critic.evaluated_state(b_states)
182
+ value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
183
+
184
+ # --- Total loss ---
185
+ total_loss = (
186
+ policy_loss +
187
+ self.value_coef * value_loss -
188
+ self.entropy_coef * entropy
189
+ )
190
+
191
+ # Debug: track individual loss components
192
+ self.policy_loss_history.append(policy_loss.item())
193
+ self.value_loss_history.append(value_loss.item())
194
+
195
+ self.opt.zero_grad(set_to_none=True)
196
+ total_loss.backward()
197
+ self.opt.step()
198
+
199
+ total_loss_epoch += total_loss.item()
200
+
201
+ # Clear memory after full PPO update
202
+ self.memory.clear()
203
+
204
+ return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
205
+
206
+
207
+ def update_rbs(self):
208
+ if len(self.memory.states) == 0:
209
+ return 0.0
210
+
211
+ # Convert memory to tensors
212
+ states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
213
+ actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
214
+ rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
215
+ dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
216
+ old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
217
+ values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
218
+
219
+ with T.no_grad():
220
+ # Compute next values (bootstrap for final step)
221
+ next_values = T.cat([values[1:], values[-1:].clone()])
222
+ deltas = rewards + self.gamma * next_values * (1 - dones) - values
223
+
224
+ # --- GAE-Lambda ---
225
+ adv = T.zeros_like(rewards)
226
+ gae = 0.0
227
+ for t in reversed(range(len(rewards))):
228
+ gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
229
+ adv[t] = gae
230
+
231
+ returns = adv + values
232
+
233
+ # --- Return-based normalization (RBS) ---
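+ # Both the critic targets and the advantages are divided by the batch return std,
+ # so value regression and the policy gradient share a common scale.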
234
+ sigma_t = returns.std(unbiased=False) + 1e-8
235
+ returns = returns / sigma_t
236
+ self.sigma_history.append(sigma_t.item())
237
+ adv = adv / sigma_t
238
+ # Advantage normalization
239
+ adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
240
+
241
+ # --- PPO Multiple Epochs + Minibatch ---
242
+ total_loss_epoch = 0.0
243
+ num_samples = len(states)
244
+ batch_size = min(64, num_samples)
245
+ ppo_epochs = 4
246
+
247
+ for _ in range(ppo_epochs):
248
+ # Shuffle indices
249
+ idxs = T.randperm(num_samples)
250
+ for start in range(0, num_samples, batch_size):
251
+ batch_idx = idxs[start:start + batch_size]
252
+
253
+ b_states = states[batch_idx]
254
+ b_actions = actions[batch_idx]
255
+ b_old_logp = old_logp[batch_idx]
256
+ b_returns = returns[batch_idx]
257
+ b_adv = adv[batch_idx]
258
+
259
+ dist = self.policy.next_action(b_states)
260
+ new_logp = dist.log_prob(b_actions)
261
+ entropy = dist.entropy().mean()
262
+ ratio = (new_logp - b_old_logp).exp()
263
+
264
+ # --- Clipped surrogate objective ---
265
+ surr1 = ratio * b_adv
266
+ surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
267
+ policy_loss = -T.min(surr1, surr2).mean()
268
+
269
+ # --- Critic loss ---
270
+ value_pred = self.critic.evaluated_state(b_states)
271
+ value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
272
+
273
+ # --- Total loss ---
274
+ total_loss = (
275
+ policy_loss +
276
+ self.value_coef * value_loss -
277
+ self.entropy_coef * entropy
278
+ )
279
+
280
+ # Debug: track individual loss components
281
+ self.policy_loss_history.append(policy_loss.item())
282
+ self.value_loss_history.append(value_loss.item())
283
+
284
+ self.opt.zero_grad(set_to_none=True)
285
+ total_loss.backward()
286
+ self.opt.step()
287
+ total_loss_epoch += total_loss.item()
288
+
289
+ # Clear memory after full PPO update
290
+ self.memory.clear()
291
+
292
+ return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
293
+
294
+
295
+
296
+
297
+
298
+
299
+ def update_observation_norm(self):
300
+ if len(self.memory.states) == 0:
301
+ return 0.0
302
+
303
+ # Convert memory to tensors
304
+ states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
305
+ actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
306
+ rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
307
+ dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
308
+ old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
309
+ values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
310
+
311
+ with T.no_grad():
312
+ # Compute next values (bootstrap for final step)
313
+ next_values = T.cat([values[1:], values[-1:].clone()])
314
+ deltas = rewards + self.gamma * next_values * (1 - dones) - values
315
+
316
+ # --- GAE-Lambda ---
317
+ adv = T.zeros_like(rewards)
318
+ gae = 0.0
319
+ for t in reversed(range(len(rewards))):
320
+ gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
321
+ adv[t] = gae
322
+
323
+ returns = adv + values
324
+
325
+ # --- observation normalization ---
326
+ self.observeNorm.update(states)
327
+ states = self.observeNorm.normalize(states)
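+ # Note: the running stats are refreshed with this rollout before normalizing; choose_action
+ # still sees raw observations, so acting and training inputs differ slightly here.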
328
+ # Advantage normalization
329
+ adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
330
+
331
+ # --- PPO Multiple Epochs + Minibatch ---
332
+ total_loss_epoch = 0.0
333
+ num_samples = len(states)
334
+ batch_size = min(64, num_samples)
335
+ ppo_epochs = 4
336
+
337
+ for _ in range(ppo_epochs):
338
+ # Shuffle indices
339
+ idxs = T.randperm(num_samples)
340
+ for start in range(0, num_samples, batch_size):
341
+ batch_idx = idxs[start:start + batch_size]
342
+
343
+ b_states = states[batch_idx]
344
+ b_actions = actions[batch_idx]
345
+ b_old_logp = old_logp[batch_idx]
346
+ b_returns = returns[batch_idx]
347
+ b_adv = adv[batch_idx]
348
+
349
+ dist = self.policy.next_action(b_states)
350
+ new_logp = dist.log_prob(b_actions)
351
+ entropy = dist.entropy().mean()
352
+ ratio = (new_logp - b_old_logp).exp()
353
+
354
+ # --- Clipped surrogate objective ---
355
+ surr1 = ratio * b_adv
356
+ surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
357
+ policy_loss = -T.min(surr1, surr2).mean()
358
+
359
+ # --- Critic loss ---
360
+ value_pred = self.critic.evaluated_state(b_states)
361
+ value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
362
+
363
+ # --- Total loss ---
364
+ total_loss = (
365
+ policy_loss +
366
+ self.value_coef * value_loss -
367
+ self.entropy_coef * entropy
368
+ )
369
+
370
+ # Debug: track individual loss components
371
+ self.policy_loss_history.append(policy_loss.item())
372
+ self.value_loss_history.append(value_loss.item())
373
+
374
+ self.opt.zero_grad(set_to_none=True)
375
+ total_loss.backward()
376
+ self.opt.step()
377
+ total_loss_epoch += total_loss.item()
378
+
379
+ # Clear memory after full PPO update
380
+ self.memory.clear()
381
+
382
+ return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
383
+
384
+
385
+
386
+
387
+ def update_advantage_norm(self):
388
+ if len(self.memory.states) == 0:
389
+ return 0.0
390
+
391
+ # Convert memory to tensors
392
+ states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
393
+ actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
394
+ rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
395
+ dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
396
+ old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
397
+ values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
398
+
399
+ with T.no_grad():
400
+ # Compute next values (bootstrap for final step)
401
+ next_values = T.cat([values[1:], values[-1:].clone()])
402
+ deltas = rewards + self.gamma * next_values * (1 - dones) - values
403
+
404
+ # --- GAE-Lambda ---
405
+ adv = T.zeros_like(rewards)
406
+ gae = 0.0
407
+ for t in reversed(range(len(rewards))):
408
+ gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
409
+ adv[t] = gae
410
+
411
+ # --- Advantage normalization ---
412
+ self.advantageNorm.update(adv)
413
+ adv = self.advantageNorm.normalize(adv)  # use the advantage statistics here, not the observation stats
414
+
415
+ returns = adv + values
416
+
417
+
418
+
419
+ # --- PPO Multiple Epochs + Minibatch ---
420
+ total_loss_epoch = 0.0
421
+ num_samples = len(states)
422
+ batch_size = min(64, num_samples)
423
+ ppo_epochs = 4
424
+
425
+ for _ in range(ppo_epochs):
426
+ # Shuffle indices
427
+ idxs = T.randperm(num_samples)
428
+ for start in range(0, num_samples, batch_size):
429
+ batch_idx = idxs[start:start + batch_size]
430
+
431
+ b_states = states[batch_idx]
432
+ b_actions = actions[batch_idx]
433
+ b_old_logp = old_logp[batch_idx]
434
+ b_returns = returns[batch_idx]
435
+ b_adv = adv[batch_idx]
436
+
437
+ dist = self.policy.next_action(b_states)
438
+ new_logp = dist.log_prob(b_actions)
439
+ entropy = dist.entropy().mean()
440
+ ratio = (new_logp - b_old_logp).exp()
441
+
442
+ # --- Clipped surrogate objective ---
443
+ surr1 = ratio * b_adv
444
+ surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
445
+ policy_loss = -T.min(surr1, surr2).mean()
446
+
447
+ # --- Critic loss ---
448
+ value_pred = self.critic.evaluated_state(b_states)
449
+ value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
450
+
451
+ # --- Total loss ---
452
+ total_loss = (
453
+ policy_loss +
454
+ self.value_coef * value_loss -
455
+ self.entropy_coef * entropy
456
+ )
457
+
458
+ # Debug: track individual loss components
459
+ self.policy_loss_history.append(policy_loss.item())
460
+ self.value_loss_history.append(value_loss.item())
461
+
462
+ self.opt.zero_grad(set_to_none=True)
463
+ total_loss.backward()
464
+ self.opt.step()
465
+ total_loss_epoch += total_loss.item()
466
+
467
+ # Clear memory after full PPO update
468
+ self.memory.clear()
469
+
470
+ return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
471
+
472
+ def update_return_norm(self):
473
+ if len(self.memory.states) == 0:
474
+ return 0.0
475
+
476
+ # Convert memory to tensors
477
+ states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
478
+ actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
479
+ rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
480
+ dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
481
+ old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
482
+ values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
483
+
484
+ with T.no_grad():
485
+ # Compute next values (bootstrap for final step)
486
+ next_values = T.cat([values[1:], values[-1:].clone()])
487
+ deltas = rewards + self.gamma * next_values * (1 - dones) - values
488
+
489
+ # --- GAE-Lambda ---
490
+ adv = T.zeros_like(rewards)
491
+ gae = 0.0
492
+ for t in reversed(range(len(rewards))):
493
+ gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
494
+ adv[t] = gae
495
+
496
+
497
+
498
+ returns = adv + values
499
+
500
+ # --- returns normalization ---
501
+ self.returnNorm.update(returns)
502
+ returns = self.returnNorm.normalize(returns)
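+ # Note: the critic below is regressed onto normalized returns, while the advantages
+ # were computed from the unnormalized value estimates.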
503
+
504
+
505
+ # Advantage normalization
506
+ adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
507
+
508
+ # --- PPO Multiple Epochs + Minibatch ---
509
+ total_loss_epoch = 0.0
510
+ num_samples = len(states)
511
+ batch_size = min(64, num_samples)
512
+ ppo_epochs = 4
513
+
514
+ for _ in range(ppo_epochs):
515
+ # Shuffle indices
516
+ idxs = T.randperm(num_samples)
517
+ for start in range(0, num_samples, batch_size):
518
+ batch_idx = idxs[start:start + batch_size]
519
+
520
+ b_states = states[batch_idx]
521
+ b_actions = actions[batch_idx]
522
+ b_old_logp = old_logp[batch_idx]
523
+ b_returns = returns[batch_idx]
524
+ b_adv = adv[batch_idx]
525
+
526
+ dist = self.policy.next_action(b_states)
527
+ new_logp = dist.log_prob(b_actions)
528
+ entropy = dist.entropy().mean()
529
+ ratio = (new_logp - b_old_logp).exp()
530
+
531
+ # --- Clipped surrogate objective ---
532
+ surr1 = ratio * b_adv
533
+ surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
534
+ policy_loss = -T.min(surr1, surr2).mean()
535
+
536
+ # --- Critic loss ---
537
+ value_pred = self.critic.evaluated_state(b_states)
538
+ value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
539
+
540
+ # --- Total loss ---
541
+ total_loss = (
542
+ policy_loss +
543
+ self.value_coef * value_loss -
544
+ self.entropy_coef * entropy
545
+ )
546
+
547
+ # Debug: track individual loss components
548
+ self.policy_loss_history.append(policy_loss.item())
549
+ self.value_loss_history.append(value_loss.item())
550
+
551
+ self.opt.zero_grad(set_to_none=True)
552
+ total_loss.backward()
553
+ self.opt.step()
554
+ total_loss_epoch += total_loss.item()
555
+
556
+ # Clear memory after full PPO update
557
+ self.memory.clear()
558
+
559
+ return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
560
+
561
+ def update_reward_gradient_clipping(self):
562
+ if len(self.memory.states) == 0:
563
+ return 0.0
564
+
565
+ # Convert memory to tensors
566
+ states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
567
+ actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
568
+ rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
569
+ dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
570
+ old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
571
+ values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
572
+
573
+ # Reward clipping
574
+ rewards = T.clamp(rewards, -1, 1)
575
+
576
+ with T.no_grad():
577
+ # Compute next values (bootstrap for final step)
578
+ next_values = T.cat([values[1:], values[-1:].clone()])
579
+ deltas = rewards + self.gamma * next_values * (1 - dones) - values
580
+
581
+ # --- GAE-Lambda ---
582
+ adv = T.zeros_like(rewards)
583
+ gae = 0.0
584
+ for t in reversed(range(len(rewards))):
585
+ gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
586
+ adv[t] = gae
587
+
588
+ returns = adv + values
589
+ # Advantage normalization
590
+ adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
591
+
592
+ # --- PPO Multiple Epochs + Minibatch ---
593
+ total_loss_epoch = 0.0
594
+ num_samples = len(states)
595
+ batch_size = min(64, num_samples)
596
+ ppo_epochs = 4
597
+
598
+ for _ in range(ppo_epochs):
599
+ # Shuffle indices
600
+ idxs = T.randperm(num_samples)
601
+ for start in range(0, num_samples, batch_size):
602
+ batch_idx = idxs[start:start + batch_size]
603
+
604
+ b_states = states[batch_idx]
605
+ b_actions = actions[batch_idx]
606
+ b_old_logp = old_logp[batch_idx]
607
+ b_returns = returns[batch_idx]
608
+ b_adv = adv[batch_idx]
609
+
610
+ dist = self.policy.next_action(b_states)
611
+ new_logp = dist.log_prob(b_actions)
612
+ entropy = dist.entropy().mean()
613
+ ratio = (new_logp - b_old_logp).exp()
614
+
615
+ # --- Clipped surrogate objective ---
616
+ surr1 = ratio * b_adv
617
+ surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
618
+ policy_loss = -T.min(surr1, surr2).mean()
619
+
620
+ # --- Critic loss ---
621
+ value_pred = self.critic.evaluated_state(b_states)
622
+ value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
623
+
624
+ # --- Total loss ---
625
+ total_loss = (
626
+ policy_loss +
627
+ self.value_coef * value_loss -
628
+ self.entropy_coef * entropy
629
+ )
630
+
631
+ # Debug: track individual loss components
632
+ self.policy_loss_history.append(policy_loss.item())
633
+ self.value_loss_history.append(value_loss.item())
634
+
635
+ self.opt.zero_grad(set_to_none=True)
636
+ total_loss.backward()
637
+ T.nn.utils.clip_grad_norm_(list(self.policy.parameters()) + list(self.critic.parameters()), 0.5)
638
+ self.opt.step()
639
+
640
+ total_loss_epoch += total_loss.item()
641
+
642
+ # Clear memory after full PPO update
643
+ self.memory.clear()
644
+
645
+ return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
646
+
647
+ """
648
+ # Policy network (simple MLP, flattened observations)
649
+ class Policy(nn.Module):
650
+ def __init__(self, obs_dim: int, action_dim: int, hidden: int):
651
+ super().__init__()
652
+ self.net = nn.Sequential(
653
+ nn.Linear(obs_dim, hidden),
654
+ nn.ReLU(),
655
+ nn.Linear(hidden, hidden),
656
+ nn.ReLU(),
657
+ nn.Linear(hidden, action_dim)
658
+ )
659
+
660
+ def next_action(self, state: T.Tensor) -> Categorical:
661
+ # Returns the probability distribution over actions
662
+ if state.dim() == 1:
663
+ state = state.unsqueeze(0)
664
+ state = state.view(state.size(0), -1)
665
+ return Categorical(logits=self.net(state))
666
+ """
667
+
668
+ # Policy network (CNN)
669
+ class Policy(nn.Module):
670
+ def __init__(self, obs_shape: tuple, action_dim: int, hidden: int):
671
+ super().__init__()
672
+ c, h, w = obs_shape
673
+ # Suggested architecture for Atari: https://arxiv.org/pdf/1312.5602
674
+ self.cnn = nn.Sequential(
675
+ nn.Conv2d(c, 16, kernel_size=8, stride=4),
676
+ nn.ReLU(),
677
+ nn.Conv2d(16, 32, kernel_size=4, stride=2),
678
+ nn.ReLU(),
679
+ nn.Flatten()
680
+ )
681
+
682
+ with T.no_grad():
683
+ cnn_output_dim = self.cnn(T.zeros(1, c, h, w)).shape[1]
684
+
685
+ self.net = nn.Sequential(
686
+ nn.Linear(cnn_output_dim, hidden),
687
+ nn.ReLU(),
688
+ nn.Linear(hidden, action_dim)
689
+ )
690
+
691
+ def next_action(self, state: T.Tensor) -> Categorical:
692
+ # Returns the probability distribution over actions
693
+ if state.dim() == 3:
694
+ state = state.unsqueeze(0)
695
+ cnn_out = self.cnn(state)
696
+ return Categorical(logits=self.net(cnn_out))
697
+
698
+ """
699
+ # Critic network (simple MLP, flattened observations)
700
+ class Critic(nn.Module):
701
+ def __init__(self, obs_dim: int, hidden: int):
702
+ super().__init__()
703
+ self.net = nn.Sequential(
704
+ nn.Linear(obs_dim, hidden),
705
+ nn.ReLU(),
706
+ nn.Linear(hidden, hidden),
707
+ nn.ReLU(),
708
+ nn.Linear(hidden, 1)
709
+ )
710
+
711
+ def evaluated_state(self, x: T.Tensor) -> T.Tensor:
712
+ if x.dim() == 1:
713
+ x = x.unsqueeze(0)
714
+ x = x.view(x.size(0), -1)
715
+ return self.net(x).squeeze(-1)
716
+ """
717
+
718
+ # Critic network (CNN)
719
+ class Critic(nn.Module):
720
+ def __init__(self, obs_shape: tuple, hidden: int):
721
+ super().__init__()
722
+ c, h, w = obs_shape
723
+ # Suggested architecture for Atari: https://arxiv.org/pdf/1312.5602
724
+ self.cnn = nn.Sequential(
725
+ nn.Conv2d(c, 16, kernel_size=8, stride=4),
726
+ nn.ReLU(),
727
+ nn.Conv2d(16, 32, kernel_size=4, stride=2),
728
+ nn.ReLU(),
729
+ nn.Flatten()
730
+ )
731
+
732
+ with T.no_grad():
733
+ cnn_output_dim = self.cnn(T.zeros(1, c, h, w)).shape[1]
734
+
735
+ self.net = nn.Sequential(
736
+ nn.Linear(cnn_output_dim, hidden),
737
+ nn.ReLU(),
738
+ nn.Linear(hidden, 1)
739
+ )
740
+
741
+ def evaluated_state(self, x: T.Tensor) -> T.Tensor:
742
+ if x.dim() == 3:
743
+ x = x.unsqueeze(0)
744
+ cnn_out = self.cnn(x)
745
+ return self.net(cnn_out).squeeze(-1)
746
+
747
+ class Memory():
748
+ def __init__(self):
749
+ self.states = []
750
+ self.actions = []
751
+ self.rewards = []
752
+ self.dones = []
753
+ self.log_probs = []
754
+ self.values = []
755
+ self.next_values = []
756
+
757
+ def store(self, state, action, reward, done, log_prob, value, next_value):
758
+ self.states.append(np.asarray(state, dtype=np.float32))
759
+ self.actions.append(int(action))
760
+ self.rewards.append(float(reward))
761
+ self.dones.append(float(done))
762
+ self.log_probs.append(float(log_prob))
763
+ self.values.append(float(value))
764
+ self.next_values.append(float(next_value))
765
+
766
+ """
767
+ # For mini-batch updates? To be implemented
768
+ def start_batch(self, batch_size: int):
769
+ n_states = len(self.states)
770
+ starts = np.arange(0, n_states, batch_size)
771
+ index = np.arange(n_states, dtype=np.int64)
772
+ np.random.shuffle(index)
773
+ return [index[s:s + batch_size] for s in starts]
774
+ """
775
+
776
+ def clear(self):
777
+ self.states = []
778
+ self.actions = []
779
+ self.rewards = []
780
+ self.dones = []
781
+ self.log_probs = []
782
+ self.values = []
783
+ self.next_values = []
784
+
785
+
786
+
787
+ class ObservationNorm:
788
+ def __init__(self):
789
+ self.main_mean = 0
790
+ self.main_var = 0
791
+ self.count = 1e-4
792
+
793
+ def update(self, x: T.Tensor):
794
+ batch_mean = T.mean(x, dim=0)
795
+ batch_var = T.var(x, dim=0)
796
+ batch_count = x.shape[0]
797
+ self._update_from_moments(batch_mean, batch_var, batch_count)
798
+
799
+ def _update_from_moments(self, batch_mean, batch_var, batch_count):
800
+ delta = batch_mean - self.main_mean
801
+ tot_count = self.count + batch_count
802
+ new_mean = self.main_mean + delta * batch_count / tot_count #Update the running mean
803
+ m_a = self.main_var * self.count
804
+ m_b = batch_var * batch_count
805
+ M2 = m_a + m_b + delta.pow(2) * self.count * batch_count / tot_count
806
+ new_var = M2 / tot_count # update the running variance
807
+
808
+ self.main_mean, self.main_var, self.count = new_mean, new_var, tot_count
809
+
810
+ def normalize(self, x):
811
+
812
+ return (x - self.main_mean) / (T.sqrt(self.main_var) + 1e-8) # We add epsilon to make sure that we don't
812
+ # divide by zero.
814
+
815
+
816
+
817
+
818
+
819
+ class AdvantageNorm:
820
+ '''
821
+ This class implements the Advantage Normalization. The purpose is to normalize either across batches or
822
+ only within the same batch.
823
+
824
+ '''
825
+ def __init__(self):
826
+ self.main_mean = 0
827
+ self.main_var = 0
828
+ self.count = 1e-4
829
+
830
+ def update(self, x: T.Tensor):
831
+ batch_mean = T.mean(x, dim=0)
832
+ batch_var = T.var(x, dim=0)
833
+ batch_count = x.shape[0]
834
+ self._update_from_moments(batch_mean, batch_var, batch_count)
835
+
836
+ def _update_from_moments(self, batch_mean, batch_var, batch_count):
837
+ delta = batch_mean - self.main_mean
838
+ tot_count = self.count + batch_count
839
+ new_mean = self.main_mean + delta * batch_count / tot_count #Update the running mean
840
+ m_a = self.main_var * self.count
841
+ m_b = batch_var * batch_count
842
+ M2 = m_a + m_b + delta.pow(2) * self.count * batch_count / tot_count
843
+ new_var = M2 / tot_count # update the running variance
844
+
845
+ self.main_mean, self.main_var, self.count = new_mean, new_var, tot_count
846
+
847
+ def normalize(self, x):
848
+
849
+ return (x - self.main_mean) / (T.sqrt(self.main_var) + 1e-8) # We add epsilon to make sure that we don't
849
+ # divide by zero.
851
+
852
+
853
+
854
+
855
+ class ReturnNorm:
856
+ '''
857
+ This class implements Return Normalization. The purpose is to normalize either across batches or
858
+ only within the same batch.
859
+
860
+ '''
861
+ def __init__(self):
862
+ self.main_mean = 0
863
+ self.main_var = 0
864
+ self.count = 1e-4
865
+
866
+ def update(self, x: T.Tensor):
867
+ batch_mean = T.mean(x, dim=0)
868
+ batch_var = T.var(x, dim=0)
869
+ batch_count = x.shape[0]
870
+ self._update_from_moments(batch_mean, batch_var, batch_count)
871
+
872
+ def _update_from_moments(self, batch_mean, batch_var, batch_count):
873
+ delta = batch_mean - self.main_mean
874
+ tot_count = self.count + batch_count
875
+ new_mean = self.main_mean + delta * batch_count / tot_count #Update the running mean
876
+ m_a = self.main_var * self.count
877
+ m_b = batch_var * batch_count
878
+ M2 = m_a + m_b + delta.pow(2) * self.count * batch_count / tot_count
879
+ new_var = M2 / tot_count # update the running variance
880
+
881
+ self.main_mean, self.main_var, self.count = new_mean, new_var, tot_count
882
+
883
+ def normalize(self, x):
884
+
885
+ return (x - self.main_mean) / (T.sqrt(self.main_var) + 1e-8) # We add epsilon to make sure that we don't
885
+ # divide by zero.
887
+
888
+
889
+
890
+
891
+
Observation_Advantage_Norm_diff_env/ppo_rew_norm_obs_env_diff_env.py ADDED
@@ -0,0 +1,191 @@
1
+
2
+ import gymnasium as gym
3
+ import sys
4
+ import matplotlib.pyplot as plt
5
+ import ale_py
6
+ from ppo__rew_norm_obs_diff_env import *
7
+ from gymnasium.spaces import Box
8
+ import cv2
9
+
10
+ def preprocess(obs):
11
+ # Convert to grayscale
12
+ obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
13
+ # Resize
14
+ obs = cv2.resize(obs, (84, 84), interpolation=cv2.INTER_AREA)
15
+ # Add channel dimension and normalize
16
+ return np.expand_dims(obs, axis=0).astype(np.float32) / 255.0
17
+
18
+ class PlotMultiple:
19
+ def __init__(self):
20
+ self.fig = plt.figure(figsize=(12, 8))
21
+
22
+ """
23
+ # Plot for Return-Based Scaling only
24
+ ax1 = plt.subplot(220)
25
+ ax1.plot(agent.sigma_history, label="Return σ")
26
+ ax1.set_xlabel("PPO Update")
27
+ ax1.set_ylabel("σ (Return Std)")
28
+ """
29
+
30
+ self.ax2 = plt.subplot(221)
31
+ self.ax2.set_ylabel("Average PPO Loss")
32
+ self.ax2.set_xlabel("PPO Update")
33
+
34
+ self.ax3 = plt.subplot(222)
35
+ self.ax3.set_ylabel("Reward")
36
+ self.ax3.set_xlabel("PPO Update")
37
+
38
+ # Details about value loss and policy loss
39
+ self.ax4 = plt.subplot(223)
40
+ self.ax4.set_ylabel("Policy Loss")
41
+ self.ax4.set_xlabel("Training Step")
42
+ self.ax4.legend()
43
+
44
+
45
+ self.ax5 = plt.subplot(224)
46
+ self.ax5.set_ylabel("Value Loss")
47
+ self.ax5.set_xlabel("Training Step")
48
+ self.ax5.legend()
49
+
50
+
51
+
52
+
53
+ def setPlot(self, loss_history, reward_history, policy_loss_history
54
+ , value_loss_history, env ):
55
+ self.ax2.plot(loss_history, label=env)  # "title" is not a plot() kwarg; set it on the axes
+ self.ax2.set_title("Loss")
56
+
57
+ self.ax3.plot(reward_history, label=env)
+ self.ax3.set_title("Reward")
58
+
59
+
60
+ self.ax4.plot(policy_loss_history, label=env, alpha=0.7)
+ self.ax4.set_title("policy_loss")
61
+
62
+
63
+ self.ax5.plot(value_loss_history, label=env, alpha=0.7)
+ self.ax5.set_title("value_loss")
64
+
65
+
66
+
67
+
68
+
69
+
70
+
71
+ def store(self, environ):
72
+ self.fig.suptitle("Performance with different Environments")
73
+ for ax in (self.ax2, self.ax3, self.ax4, self.ax5):
+ ax.legend()
+ self.fig.tight_layout()
74
+ # "/" in the ALE env id would be treated as a directory in the filename, so strip it out
+ self.fig.savefig("Performance of " + environ.replace("/", "_") + " with different_environment_.png")
75
+
76
+
77
+ def rl_model(type, plot, environ):
78
+ # env = gym.make("ALE/SpaceInvaders-v5", render_mode='human')
79
+ # env = gym.make("ALE/Pacman-v5", render_mode="human")
80
+ env = gym.make(environ)
81
+
82
+ episode = 0
83
+ total_return = 0
84
+ ep_return = 0
85
+ steps = 1000
86
+ batches = 100
87
+
88
+ print("Observation space:", env.observation_space)
89
+ print("Action space:", env.action_space)
90
+ """
91
+ agent = Agent(obs_space=env.observation_space, action_space=env.action_space,
92
+ hidden=64, lr=3e-4, gamma=0.99, clip_coef=0.2,
93
+ entropy_coef=0.01, value_coef=0.5, seed=70,
94
+ batch_size = 64, ppo_epochs = 4, lam = 0.95)
95
+
96
+ """
97
+ # Initialize CNN with a dummy observation (to get correct input shape)
98
+ obs, _ = env.reset()
99
+ dummy_obs_space = Box(low=0.0, high=1.0, shape=preprocess(obs).shape)
100
+ update_type = type
101
+ agent = Agent(obs_space=dummy_obs_space, action_space=env.action_space,
102
+ hidden=64, lr=3e-4, gamma=0.99, clip_coef=0.2,
103
+ entropy_coef=0.01, value_coef=0.5, seed=70,
104
+ batch_size=64, ppo_epochs=4, lam=0.95, update_type=update_type)
105
+ """
106
+ # Stats for Return-Based Scaling only
107
+ # === Return-Based Scaling stats ===
108
+ r_mean, r_var = 0.0, 1e-8
109
+ g2_mean = 1.0
110
+
111
+ agent.r_var = r_var
112
+ agent.g2_mean = g2_mean
113
+ """
114
+
115
+ try:
116
+ obs, info = env.reset(seed=42)
117
+ state = preprocess(obs)
118
+
119
+ loss_history = []
120
+ reward_history = []
121
+
122
+ for update in range(1, batches + 1):
123
+ for t in range(steps):
124
+ action, logp, value = agent.choose_action(state)
125
+ next_obs, reward, terminated, truncated, info = env.step(action)
126
+ done = terminated or truncated
127
+ next_state = preprocess(next_obs)
128
+
129
+ agent.remember(state, action, reward, done, logp, value, next_state)
130
+
131
+ ep_return += reward
132
+ state = next_state
133
+
134
+ if done:
135
+ episode += 1
136
+ total_return += ep_return
137
+ print(f"Episode {episode} return: {ep_return:.2f}")
138
+ ep_return = 0
139
+ obs, info = env.reset()
140
+ state = preprocess(obs)
141
+
142
+ # Using reward gradient clipping
143
+ avg_loss = agent._update()
144
+
145
+ # Vanilla PPO (no normalization)
146
+ # avg_loss = agent.vanilla_ppo_update()
147
+ loss_history.append(avg_loss)
148
+
149
+ avg_ret = (total_return / episode) if episode else 0
150
+ reward_history.append(avg_ret)
151
+ print(f"Update {update}: episodes={episode}, avg_return={avg_ret:.2f}, avg_loss={avg_loss:.4f}")
152
+
153
+
154
+ plot.setPlot(loss_history, reward_history, agent.policy_loss_history, agent.value_loss_history, environ)
155
+
156
+
157
+
158
+ except Exception as e:
159
+ print(f"Error: {e}", file=sys.stderr)
160
+ return 1
161
+ finally:
162
+ avg = total_return / episode if episode else 0
163
+ print(f"\nEpisodes: {episode}, Avg return: {avg:.3f}")
164
+ env.close()
165
+
166
+ return 0
167
+
168
+
169
+
170
+
171
+
172
+
173
+
174
+
175
+ def main() -> int:
176
+
177
+ list_env = ["ALE/Pacman-v5", "ALE/Gravitar-v5", "ALE/Boxing-v5"]
178
+ type_list = ["update_observation_norm", "update_advantage_norm",
179
+ "update_return_norm", "vanilla_ppo_update"]
180
+ for env in list_env:
181
+ plot = PlotMultiple()
182
+ for type in type_list:
183
+ rl_model(type, plot, env)
184
+
185
+ plot.store(env)
186
+
187
+ return 0
188
+
189
+
190
+ if __name__ == "__main__":
191
+ raise SystemExit(main())
Observation_Advantage_Norm_diff_hypo/Performance config for Learning Rate of update_advantage_norm.png ADDED
Observation_Advantage_Norm_diff_hypo/Performance config for Learning Rate of update_observation_norm.png ADDED
Observation_Advantage_Norm_diff_hypo/Performance config for Learning Rate of update_return_norm.png ADDED
Observation_Advantage_Norm_diff_hypo/Performance config for Learning Rate of vanilla_ppo_update.png ADDED
Observation_Advantage_Norm_diff_hypo/Performance config for entropy coefficient of update_advantage_norm.png ADDED
Observation_Advantage_Norm_diff_hypo/Performance config for entropy coefficient of update_observation_norm.png ADDED
Observation_Advantage_Norm_diff_hypo/Performance config for entropy coefficient of update_return_norm.png ADDED
Observation_Advantage_Norm_diff_hypo/Performance config for entropy coefficient of vanilla_ppo_update.png ADDED
Observation_Advantage_Norm_diff_hypo/Performance config for gamma value of update_advantage_norm.png ADDED
Observation_Advantage_Norm_diff_hypo/Performance config for gamma value of update_observation_norm.png ADDED
Observation_Advantage_Norm_diff_hypo/Performance config for gamma value of update_return_norm.png ADDED
Observation_Advantage_Norm_diff_hypo/Performance config for gamma value of vanilla_ppo_update.png ADDED
Observation_Advantage_Norm_diff_hypo/ppo__rew_norm_obs_diff_hyp.py ADDED
@@ -0,0 +1,890 @@
1
+ import numpy as np
2
+ import torch as T
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torch.distributions import Categorical
6
+
7
+
8
+ class Agent:
9
+ def __init__(
10
+ self,
11
+ obs_space,
12
+ action_space,
13
+ hidden,
14
+ gamma,
15
+ clip_coef,
16
+ lr,
17
+ value_coef,
18
+ entropy_coef,
19
+ seed,
20
+ batch_size,
21
+ ppo_epochs,
22
+ lam,
23
+ update_type
24
+
25
+ ):
26
+ # Initialize seed for reproducibility
27
+ if seed is not None:
28
+ np.random.seed(seed)
29
+ T.manual_seed(seed)
30
+ """
31
+ # For flat observations (MLP model)
32
+ # Use GPU if available
33
+ self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
34
+ self.obs_dim = int(np.prod(getattr(obs_space, "shape", (obs_space,))))
35
+ self.action_dim = int(getattr(action_space, "n", action_space))
36
+
37
+ # Initialize the policy and the critic networks
38
+ self.policy = Policy(self.obs_dim, self.action_dim, hidden).to(self.device)
39
+ self.critic = Critic(self.obs_dim, hidden).to(self.device)
40
+ """
41
+ # Use GPU if available
42
+ self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
43
+ self.action_dim = int(getattr(action_space, "n", action_space))
44
+ self.update_type = update_type
45
+
46
+ # Initialize the policy and the critic networks
47
+ # Pass the shape tuple directly, not the flattened dimension.
48
+ self.policy = Policy(obs_space.shape, self.action_dim, hidden).to(self.device)
49
+ self.critic = Critic(obs_space.shape, hidden).to(self.device)
50
+ self.observeNorm = ObservationNorm()
51
+ self.advantageNorm = AdvantageNorm()
52
+ self.returnNorm = ReturnNorm()
53
+
54
+ # Set optimizer for policy and critic networks
55
+ self.opt = optim.Adam(
56
+ list(self.policy.parameters()) + list(self.critic.parameters()),
57
+ lr=lr
58
+ )
59
+
60
+ self.gamma = gamma
61
+ self.clip = clip_coef
62
+ self.value_coef = value_coef
63
+ self.entropy_coef = entropy_coef
64
+ self.sigma_history = []
65
+ self.loss_history = []
66
+ self.policy_loss_history = []
67
+ self.value_loss_history = []
68
+ self.entropy_history = []
69
+ self.lam = lam
70
+ self.ppo_epochs = ppo_epochs
71
+ self.batch_size = batch_size
72
+
73
+ self.memory = Memory()
74
+ """
75
+ # Choose action and remember for flat observations (MLP model)
76
+ def choose_action(self, observation):
77
+ # Returns: action, log probability, value of the state
78
+ state = T.as_tensor(observation, dtype=T.float32, device=self.device).view(-1)
79
+ with T.no_grad():
80
+ # Forward function (defined in Policy class)
81
+ dist = self.policy.next_action(state)
82
+ action = dist.sample()
83
+ logp = dist.log_prob(action)
84
+ value = self.critic.evaluated_state(state)
85
+ return int(action.item()), float(logp.item()), float(value.item())
86
+
87
+ def remember(self, state, action, reward, done, log_prob, value, next_state):
88
+ with T.no_grad():
89
+ # Pass on next state and have it evaluated by the critic network
90
+ ns = T.as_tensor(next_state, dtype=T.float32, device=self.device).view(-1)
91
+ next_value = self.critic.evaluated_state(ns).item()
92
+ self.memory.store(state, action, reward, done, log_prob, value, next_value)
93
+ """
94
+ # For CNN model
95
+ def choose_action(self, observation):
96
+ # Returns: action, log probability, value of the state
97
+ state = T.as_tensor(observation, dtype=T.float32, device=self.device) # Remove .view(-1)
98
+ with T.no_grad():
99
+ # Forward function (defined in Policy class)
100
+ dist = self.policy.next_action(state)
101
+ action = dist.sample()
102
+ logp = dist.log_prob(action)
103
+ value = self.critic.evaluated_state(state)
104
+ return int(action.item()), float(logp.item()), float(value.item())
105
+
106
+ def remember(self, state, action, reward, done, log_prob, value, next_state):
107
+ with T.no_grad():
108
+ # Pass on next state and have it evaluated by the critic network
109
+ ns = T.as_tensor(next_state, dtype=T.float32, device=self.device) # Remove .view(-1)
110
+ next_value = self.critic.evaluated_state(ns).item()
111
+ self.memory.store(state, action, reward, done, log_prob, value, next_value)
112
+
113
+
114
+ def _update(self):
115
+ if self.update_type == "update_observation_norm":
116
+ return self.update_observation_norm()
117
+ elif self.update_type == "update_advantage_norm":
118
+ return self.update_advantage_norm()
119
+ elif self.update_type == "update_return_norm":
120
+ return self.update_return_norm()
121
+ else:
122
+ return self.vanilla_ppo_update()
123
+
124
+ def vanilla_ppo_update(self):
125
+ if len(self.memory.states) == 0:
126
+ return 0.0
127
+
128
+ # Convert memory to tensors
129
+ states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
130
+ actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
131
+ rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
132
+ dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
133
+ old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
134
+ values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
135
+
136
+ with T.no_grad():
137
+ # Compute next values (bootstrap for final step)
138
+ next_values = T.cat([values[1:], values[-1:].clone()])
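+ # Note: this shifts the stored values one step and reuses the final value as its
+ # own bootstrap; the per-step next_values saved in Memory are not used here.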
139
+ deltas = rewards + self.gamma * next_values * (1 - dones) - values
140
+
141
+ # --- GAE-Lambda ---
142
+ adv = T.zeros_like(rewards)
143
+ gae = 0.0
144
+ for t in reversed(range(len(rewards))):
145
+ gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
146
+ adv[t] = gae
147
+
148
+ returns = adv + values
149
+ # Advantage normalization
150
+ adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
151
+
152
+ # --- PPO Multiple Epochs + Minibatch ---
153
+ total_loss_epoch = 0.0
154
+ num_samples = len(states)
155
+ batch_size = min(64, num_samples)
156
+ ppo_epochs = 4
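+ # These local values take precedence over the batch_size / ppo_epochs passed to __init__.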
157
+
158
+ for _ in range(ppo_epochs):
159
+ # Shuffle indices
160
+ idxs = T.randperm(num_samples)
161
+ for start in range(0, num_samples, batch_size):
162
+ batch_idx = idxs[start:start + batch_size]
163
+
164
+ b_states = states[batch_idx]
165
+ b_actions = actions[batch_idx]
166
+ b_old_logp = old_logp[batch_idx]
167
+ b_returns = returns[batch_idx]
168
+ b_adv = adv[batch_idx]
169
+
170
+ dist = self.policy.next_action(b_states)
171
+ new_logp = dist.log_prob(b_actions)
172
+ entropy = dist.entropy().mean()
173
+ ratio = (new_logp - b_old_logp).exp()
174
+
175
+ # --- Clipped surrogate objective ---
176
+ surr1 = ratio * b_adv
177
+ surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
178
+ policy_loss = -T.min(surr1, surr2).mean()
179
+
180
+ # --- Critic loss ---
181
+ value_pred = self.critic.evaluated_state(b_states)
182
+ value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
183
+
184
+ # --- Total loss ---
185
+ total_loss = (
186
+ policy_loss +
187
+ self.value_coef * value_loss -
188
+ self.entropy_coef * entropy
189
+ )
190
+
191
+ # Debug: track individual loss components
192
+ self.policy_loss_history.append(policy_loss.item())
193
+ self.value_loss_history.append(value_loss.item())
194
+
195
+ self.opt.zero_grad(set_to_none=True)
196
+ total_loss.backward()
197
+ self.opt.step()
198
+
199
+ total_loss_epoch += total_loss.item()
200
+
201
+ # Clear memory after full PPO update
202
+ self.memory.clear()
203
+
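+ # num_samples / batch_size below is fractional when the rollout size is not a
+ # multiple of the batch size, so this is only an approximate per-minibatch average.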
204
+ return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
205
+
206
+
207
+ def update_rbs(self):
208
+ if len(self.memory.states) == 0:
209
+ return 0.0
210
+
211
+ # Convert memory to tensors
212
+ states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
213
+ actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
214
+ rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
215
+ dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
216
+ old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
217
+ values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
218
+
219
+ with T.no_grad():
220
+ # Compute next values (bootstrap for final step)
221
+ next_values = T.cat([values[1:], values[-1:].clone()])
222
+ deltas = rewards + self.gamma * next_values * (1 - dones) - values
223
+
224
+ # --- GAE-Lambda ---
225
+ adv = T.zeros_like(rewards)
226
+ gae = 0.0
227
+ for t in reversed(range(len(rewards))):
228
+ gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
229
+ adv[t] = gae
230
+
231
+ returns = adv + values
232
+
233
+ # --- Return-based normalization (RBS) ---
234
+ sigma_t = returns.std(unbiased=False) + 1e-8
235
+ returns = returns / sigma_t
236
+ self.sigma_history.append(sigma_t.item())
237
+ adv = adv / sigma_t
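+ # Scaling returns and advantages by the same sigma keeps the critic targets and
+ # the policy-gradient signal on a consistent scale across updates.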
238
+ # Advantage normalization
239
+ adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
240
+
241
+ # --- PPO Multiple Epochs + Minibatch ---
242
+ total_loss_epoch = 0.0
243
+ num_samples = len(states)
244
+ batch_size = min(64, num_samples)
245
+ ppo_epochs = 4
246
+
247
+ for _ in range(ppo_epochs):
248
+ # Shuffle indices
249
+ idxs = T.randperm(num_samples)
250
+ for start in range(0, num_samples, batch_size):
251
+ batch_idx = idxs[start:start + batch_size]
252
+
253
+ b_states = states[batch_idx]
254
+ b_actions = actions[batch_idx]
255
+ b_old_logp = old_logp[batch_idx]
256
+ b_returns = returns[batch_idx]
257
+ b_adv = adv[batch_idx]
258
+
259
+ dist = self.policy.next_action(b_states)
260
+ new_logp = dist.log_prob(b_actions)
261
+ entropy = dist.entropy().mean()
262
+ ratio = (new_logp - b_old_logp).exp()
263
+
264
+ # --- Clipped surrogate objective ---
265
+ surr1 = ratio * b_adv
266
+ surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
267
+ policy_loss = -T.min(surr1, surr2).mean()
268
+
269
+ # --- Critic loss ---
270
+ value_pred = self.critic.evaluated_state(b_states)
271
+ value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
272
+
273
+ # --- Total loss ---
274
+ total_loss = (
275
+ policy_loss +
276
+ self.value_coef * value_loss -
277
+ self.entropy_coef * entropy
278
+ )
279
+
280
+ # Debug: track individual loss components
281
+ self.policy_loss_history.append(policy_loss.item())
282
+ self.value_loss_history.append(value_loss.item())
283
+
284
+ self.opt.zero_grad(set_to_none=True)
285
+ total_loss.backward()
286
+ self.opt.step()
287
+ total_loss_epoch += total_loss.item()
288
+
289
+ # Clear memory after full PPO update
290
+ self.memory.clear()
291
+
292
+ return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
293
+
294
+
295
+
296
+
297
+
298
+
299
+ def update_observation_norm(self):
300
+ if len(self.memory.states) == 0:
301
+ return 0.0
302
+
303
+ # Convert memory to tensors
304
+ states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
305
+ actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
306
+ rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
307
+ dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
308
+ old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
309
+ values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
310
+
311
+ with T.no_grad():
312
+ # Compute next values (bootstrap for final step)
313
+ next_values = T.cat([values[1:], values[-1:].clone()])
314
+ deltas = rewards + self.gamma * next_values * (1 - dones) - values
315
+
316
+ # --- GAE-Lambda ---
317
+ adv = T.zeros_like(rewards)
318
+ gae = 0.0
319
+ for t in reversed(range(len(rewards))):
320
+ gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
321
+ adv[t] = gae
322
+
323
+ returns = adv + values
324
+
325
+ # --- observation normalization ---
326
+ self.observeNorm.update(states)
327
+ states = self.observeNorm.normalize(states)
328
+ # Advantage normalization
329
+ adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
330
+
331
+ # --- PPO Multiple Epochs + Minibatch ---
332
+ total_loss_epoch = 0.0
333
+ num_samples = len(states)
334
+ batch_size = min(64, num_samples)
335
+ ppo_epochs = 4
336
+
337
+ for _ in range(ppo_epochs):
338
+ # Shuffle indices
339
+ idxs = T.randperm(num_samples)
340
+ for start in range(0, num_samples, batch_size):
341
+ batch_idx = idxs[start:start + batch_size]
342
+
343
+ b_states = states[batch_idx]
344
+ b_actions = actions[batch_idx]
345
+ b_old_logp = old_logp[batch_idx]
346
+ b_returns = returns[batch_idx]
347
+ b_adv = adv[batch_idx]
348
+
349
+ dist = self.policy.next_action(b_states)
350
+ new_logp = dist.log_prob(b_actions)
351
+ entropy = dist.entropy().mean()
352
+ ratio = (new_logp - b_old_logp).exp()
353
+
354
+ # --- Clipped surrogate objective ---
355
+ surr1 = ratio * b_adv
356
+ surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
357
+ policy_loss = -T.min(surr1, surr2).mean()
358
+
359
+ # --- Critic loss ---
360
+ value_pred = self.critic.evaluated_state(b_states)
361
+ value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
362
+
363
+ # --- Total loss ---
364
+ total_loss = (
365
+ policy_loss +
366
+ self.value_coef * value_loss -
367
+ self.entropy_coef * entropy
368
+ )
369
+
370
+ # Debug: track individual loss components
371
+ self.policy_loss_history.append(policy_loss.item())
372
+ self.value_loss_history.append(value_loss.item())
373
+
374
+ self.opt.zero_grad(set_to_none=True)
375
+ total_loss.backward()
376
+ self.opt.step()
377
+ total_loss_epoch += total_loss.item()
378
+
379
+ # Clear memory after full PPO update
380
+ self.memory.clear()
381
+
382
+ return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
383
+
384
+
385
+
386
+
387
+ def update_advantage_norm(self):
388
+ if len(self.memory.states) == 0:
389
+ return 0.0
390
+
391
+ # Convert memory to tensors
392
+ states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
393
+ actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
394
+ rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
395
+ dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
396
+ old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
397
+ values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
398
+
399
+ with T.no_grad():
400
+ # Compute next values (bootstrap for final step)
401
+ next_values = T.cat([values[1:], values[-1:].clone()])
402
+ deltas = rewards + self.gamma * next_values * (1 - dones) - values
403
+
404
+ # --- GAE-Lambda ---
405
+ adv = T.zeros_like(rewards)
406
+ gae = 0.0
407
+ for t in reversed(range(len(rewards))):
408
+ gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
409
+ adv[t] = gae
410
+
411
+ # --- Advantage normalization ---
412
+ returns = adv + values
413
+ self.advantageNorm.update(adv)
414
+ adv = self.advantageNorm.normalize(adv)
415
+
416
+
417
+
418
+ # --- PPO Multiple Epochs + Minibatch ---
419
+ total_loss_epoch = 0.0
420
+ num_samples = len(states)
421
+ batch_size = min(64, num_samples)
422
+ ppo_epochs = 4
423
+
424
+ for _ in range(ppo_epochs):
425
+ # Shuffle indices
426
+ idxs = T.randperm(num_samples)
427
+ for start in range(0, num_samples, batch_size):
428
+ batch_idx = idxs[start:start + batch_size]
429
+
430
+ b_states = states[batch_idx]
431
+ b_actions = actions[batch_idx]
432
+ b_old_logp = old_logp[batch_idx]
433
+ b_returns = returns[batch_idx]
434
+ b_adv = adv[batch_idx]
435
+
436
+ dist = self.policy.next_action(b_states)
437
+ new_logp = dist.log_prob(b_actions)
438
+ entropy = dist.entropy().mean()
439
+ ratio = (new_logp - b_old_logp).exp()
440
+
441
+ # --- Clipped surrogate objective ---
442
+ surr1 = ratio * b_adv
443
+ surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
444
+ policy_loss = -T.min(surr1, surr2).mean()
445
+
446
+ # --- Critic loss ---
447
+ value_pred = self.critic.evaluated_state(b_states)
448
+ value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
449
+
450
+ # --- Total loss ---
451
+ total_loss = (
452
+ policy_loss +
453
+ self.value_coef * value_loss -
454
+ self.entropy_coef * entropy
455
+ )
456
+
457
+ # Debug: track individual loss components
458
+ self.policy_loss_history.append(policy_loss.item())
459
+ self.value_loss_history.append(value_loss.item())
460
+
461
+ self.opt.zero_grad(set_to_none=True)
462
+ total_loss.backward()
463
+ self.opt.step()
464
+ total_loss_epoch += total_loss.item()
465
+
466
+ # Clear memory after full PPO update
467
+ self.memory.clear()
468
+
469
+ return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
470
+
471
+ def update_return_norm(self):
472
+ if len(self.memory.states) == 0:
473
+ return 0.0
474
+
475
+ # Convert memory to tensors
476
+ states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
477
+ actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
478
+ rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
479
+ dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
480
+ old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
481
+ values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
482
+
483
+ with T.no_grad():
484
+ # Compute next values (bootstrap for final step)
485
+ next_values = T.cat([values[1:], values[-1:].clone()])
486
+ deltas = rewards + self.gamma * next_values * (1 - dones) - values
487
+
488
+ # --- GAE-Lambda ---
489
+ adv = T.zeros_like(rewards)
490
+ gae = 0.0
491
+ for t in reversed(range(len(rewards))):
492
+ gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
493
+ adv[t] = gae
494
+
495
+
496
+
497
+ returns = adv + values
498
+
499
+ # --- returns normalization ---
500
+ self.returnNorm.update(returns)
501
+ returns = self.returnNorm.normalize(returns)
502
+
503
+
504
+ # Advantage normalization
505
+ adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
506
+
507
+ # --- PPO Multiple Epochs + Minibatch ---
508
+ total_loss_epoch = 0.0
509
+ num_samples = len(states)
510
+ batch_size = min(64, num_samples)
511
+ ppo_epochs = 4
512
+
513
+ for _ in range(ppo_epochs):
514
+ # Shuffle indices
515
+ idxs = T.randperm(num_samples)
516
+ for start in range(0, num_samples, batch_size):
517
+ batch_idx = idxs[start:start + batch_size]
518
+
519
+ b_states = states[batch_idx]
520
+ b_actions = actions[batch_idx]
521
+ b_old_logp = old_logp[batch_idx]
522
+ b_returns = returns[batch_idx]
523
+ b_adv = adv[batch_idx]
524
+
525
+ dist = self.policy.next_action(b_states)
526
+ new_logp = dist.log_prob(b_actions)
527
+ entropy = dist.entropy().mean()
528
+ ratio = (new_logp - b_old_logp).exp()
529
+
530
+ # --- Clipped surrogate objective ---
531
+ surr1 = ratio * b_adv
532
+ surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
533
+ policy_loss = -T.min(surr1, surr2).mean()
534
+
535
+ # --- Critic loss ---
536
+ value_pred = self.critic.evaluated_state(b_states)
537
+ value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
538
+
539
+ # --- Total loss ---
540
+ total_loss = (
541
+ policy_loss +
542
+ self.value_coef * value_loss -
543
+ self.entropy_coef * entropy
544
+ )
545
+
546
+ # Debug: track individual loss components
547
+ self.policy_loss_history.append(policy_loss.item())
548
+ self.value_loss_history.append(value_loss.item())
549
+
550
+ self.opt.zero_grad(set_to_none=True)
551
+ total_loss.backward()
552
+ self.opt.step()
553
+ total_loss_epoch += total_loss.item()
554
+
555
+ # Clear memory after full PPO update
556
+ self.memory.clear()
557
+
558
+ return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
559
+
560
+ def update_reward_gradient_clipping(self):
561
+ if len(self.memory.states) == 0:
562
+ return 0.0
563
+
564
+ # Convert memory to tensors
565
+ states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
566
+ actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
567
+ rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
568
+ dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
569
+ old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
570
+ values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
571
+
572
+ # Reward clipping
573
+ rewards = T.clamp(rewards, -1, 1)
574
+
575
+ with T.no_grad():
576
+ # Compute next values (bootstrap for final step)
577
+ next_values = T.cat([values[1:], values[-1:].clone()])
578
+ deltas = rewards + self.gamma * next_values * (1 - dones) - values
579
+
580
+ # --- GAE-Lambda ---
581
+ adv = T.zeros_like(rewards)
582
+ gae = 0.0
583
+ for t in reversed(range(len(rewards))):
584
+ gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
585
+ adv[t] = gae
586
+
587
+ returns = adv + values
588
+ # Advantage normalization
589
+ adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
590
+
591
+ # --- PPO Multiple Epochs + Minibatch ---
592
+ total_loss_epoch = 0.0
593
+ num_samples = len(states)
594
+ batch_size = min(64, num_samples)
595
+ ppo_epochs = 4
596
+
597
+ for _ in range(ppo_epochs):
598
+ # Shuffle indices
599
+ idxs = T.randperm(num_samples)
600
+ for start in range(0, num_samples, batch_size):
601
+ batch_idx = idxs[start:start + batch_size]
602
+
603
+ b_states = states[batch_idx]
604
+ b_actions = actions[batch_idx]
605
+ b_old_logp = old_logp[batch_idx]
606
+ b_returns = returns[batch_idx]
607
+ b_adv = adv[batch_idx]
608
+
609
+ dist = self.policy.next_action(b_states)
610
+ new_logp = dist.log_prob(b_actions)
611
+ entropy = dist.entropy().mean()
612
+ ratio = (new_logp - b_old_logp).exp()
613
+
614
+ # --- Clipped surrogate objective ---
615
+ surr1 = ratio * b_adv
616
+ surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
617
+ policy_loss = -T.min(surr1, surr2).mean()
618
+
619
+ # --- Critic loss ---
620
+ value_pred = self.critic.evaluated_state(b_states)
621
+ value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
622
+
623
+ # --- Total loss ---
624
+ total_loss = (
625
+ policy_loss +
626
+ self.value_coef * value_loss -
627
+ self.entropy_coef * entropy
628
+ )
629
+
630
+ # Debug: track individual loss components
631
+ self.policy_loss_history.append(policy_loss.item())
632
+ self.value_loss_history.append(value_loss.item())
633
+
634
+ self.opt.zero_grad(set_to_none=True)
635
+ total_loss.backward()
636
+ T.nn.utils.clip_grad_norm_(list(self.policy.parameters()) + list(self.critic.parameters()), 0.5)
637
+ self.opt.step()
638
+
639
+ total_loss_epoch += total_loss.item()
640
+
641
+ # Clear memory after full PPO update
642
+ self.memory.clear()
643
+
644
+ return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
645
+
646
+ """
647
+ # Policy network (simple MLP, flattened observations)
648
+ class Policy(nn.Module):
649
+ def __init__(self, obs_dim: int, action_dim: int, hidden: int):
650
+ super().__init__()
651
+ self.net = nn.Sequential(
652
+ nn.Linear(obs_dim, hidden),
653
+ nn.ReLU(),
654
+ nn.Linear(hidden, hidden),
655
+ nn.ReLU(),
656
+ nn.Linear(hidden, action_dim)
657
+ )
658
+
659
+ def next_action(self, state: T.Tensor) -> Categorical:
660
+ # Returns the probability distribution over actions
661
+ if state.dim() == 1:
662
+ state = state.unsqueeze(0)
663
+ state = state.view(state.size(0), -1)
664
+ return Categorical(logits=self.net(state))
665
+ """
666
+
667
+ # Policy network (CNN)
668
+ class Policy(nn.Module):
669
+ def __init__(self, obs_shape: tuple, action_dim: int, hidden: int):
670
+ super().__init__()
671
+ c, h, w = obs_shape
672
+ # Suggested architecture for Atari: https://arxiv.org/pdf/1312.5602
673
+ self.cnn = nn.Sequential(
674
+ nn.Conv2d(c, 16, kernel_size=8, stride=4),
675
+ nn.ReLU(),
676
+ nn.Conv2d(16, 32, kernel_size=4, stride=2),
677
+ nn.ReLU(),
678
+ nn.Flatten()
679
+ )
680
+
681
+ with T.no_grad():
682
+ cnn_output_dim = self.cnn(T.zeros(1, c, h, w)).shape[1]
683
+
684
+ self.net = nn.Sequential(
685
+ nn.Linear(cnn_output_dim, hidden),
686
+ nn.ReLU(),
687
+ nn.Linear(hidden, action_dim)
688
+ )
689
+
690
+ def next_action(self, state: T.Tensor) -> Categorical:
691
+ # Returns the probability distribution over actions
692
+ if state.dim() == 3:
693
+ state = state.unsqueeze(0)
694
+ cnn_out = self.cnn(state)
695
+ return Categorical(logits=self.net(cnn_out))
696
+
697
+ """
698
+ # Critic network (simple MLP, flattened observations)
699
+ class Critic(nn.Module):
700
+ def __init__(self, obs_dim: int, hidden: int):
701
+ super().__init__()
702
+ self.net = nn.Sequential(
703
+ nn.Linear(obs_dim, hidden),
704
+ nn.ReLU(),
705
+ nn.Linear(hidden, hidden),
706
+ nn.ReLU(),
707
+ nn.Linear(hidden, 1)
708
+ )
709
+
710
+ def evaluated_state(self, x: T.Tensor) -> T.Tensor:
711
+ if x.dim() == 1:
712
+ x = x.unsqueeze(0)
713
+ x = x.view(x.size(0), -1)
714
+ return self.net(x).squeeze(-1)
715
+ """
716
+
717
+ # Critic network (CNN)
718
+ class Critic(nn.Module):
719
+ def __init__(self, obs_shape: tuple, hidden: int):
720
+ super().__init__()
721
+ c, h, w = obs_shape
722
+ # Suggested architecture for Atari: https://arxiv.org/pdf/1312.5602
723
+ self.cnn = nn.Sequential(
724
+ nn.Conv2d(c, 16, kernel_size=8, stride=4),
725
+ nn.ReLU(),
726
+ nn.Conv2d(16, 32, kernel_size=4, stride=2),
727
+ nn.ReLU(),
728
+ nn.Flatten()
729
+ )
730
+
731
+ with T.no_grad():
732
+ cnn_output_dim = self.cnn(T.zeros(1, c, h, w)).shape[1]
733
+
734
+ self.net = nn.Sequential(
735
+ nn.Linear(cnn_output_dim, hidden),
736
+ nn.ReLU(),
737
+ nn.Linear(hidden, 1)
738
+ )
739
+
740
+ def evaluated_state(self, x: T.Tensor) -> T.Tensor:
741
+ if x.dim() == 3:
742
+ x = x.unsqueeze(0)
743
+ cnn_out = self.cnn(x)
744
+ return self.net(cnn_out).squeeze(-1)
745
+
746
+ class Memory():
747
+ def __init__(self):
748
+ self.states = []
749
+ self.actions = []
750
+ self.rewards = []
751
+ self.dones = []
752
+ self.log_probs = []
753
+ self.values = []
754
+ self.next_values = []
755
+
756
+ def store(self, state, action, reward, done, log_prob, value, next_value):
757
+ self.states.append(np.asarray(state, dtype=np.float32))
758
+ self.actions.append(int(action))
759
+ self.rewards.append(float(reward))
760
+ self.dones.append(float(done))
761
+ self.log_probs.append(float(log_prob))
762
+ self.values.append(float(value))
763
+ self.next_values.append(float(next_value))
764
+
765
+ """
766
+ # For mini-batch updates? To be implemented
767
+ def start_batch(self, batch_size: int):
768
+ n_states = len(self.states)
769
+ starts = np.arange(0, n_states, batch_size)
770
+ index = np.arange(n_states, dtype=np.int64)
771
+ np.random.shuffle(index)
772
+ return [index[s:s + batch_size] for s in starts]
773
+ """
774
+
775
+ def clear(self):
776
+ self.states = []
777
+ self.actions = []
778
+ self.rewards = []
779
+ self.dones = []
780
+ self.log_probs = []
781
+ self.values = []
782
+ self.next_values = []
783
+
784
+
785
+
786
+ class ObservationNorm:
787
+ def __init__(self):
788
+ self.main_mean = 0
789
+ self.main_var = 0
790
+ self.count = 1e-4
791
+
792
+ def update(self, x: T.Tensor):
793
+ batch_mean = T.mean(x, dim=0)
794
+ batch_var = T.var(x, dim=0)
795
+ batch_count = x.shape[0]
796
+ self._update_from_moments(batch_mean, batch_var, batch_count)
797
+
798
+ def _update_from_moments(self, batch_mean, batch_var, batch_count):
799
+ delta = batch_mean - self.main_mean
800
+ tot_count = self.count + batch_count
801
+ new_mean = self.main_mean + delta * batch_count / tot_count #Update the running mean
802
+ m_a = self.main_var * self.count
803
+ m_b = batch_var * batch_count
804
+ M2 = m_a + m_b + delta ** 2 * self.count * batch_count / tot_count
805
+ new_var = M2 / tot_count # update the running variance
806
+
807
+ self.main_mean, self.main_var, self.count = new_mean, new_var, tot_count
808
+
809
+ def normalize(self, x):
810
+
811
+ return (x - self.main_mean) / (self.main_var ** 0.5 + 1e-8) # epsilon keeps us from
812
+ # dividing by zero; ** 0.5 also stays device-safe for tensors, unlike np.sqrt.
813
+
814
+
815
+
816
+
817
+
818
+ class AdvantageNorm:
819
+ '''
820
+ This class implements the Advantage Normalization. The purpose is to normalize either across batches or
821
+ only within the same batch.
822
+
823
+ '''
824
+ def __init__(self):
825
+ self.main_mean = 0
826
+ self.main_var = 0
827
+ self.count = 1e-4
828
+
829
+ def update(self, x: T.Tensor):
830
+ batch_mean = T.mean(x, dim=0)
831
+ batch_var = T.var(x, dim=0)
832
+ batch_count = x.shape[0]
833
+ self._update_from_moments(batch_mean, batch_var, batch_count)
834
+
835
+ def _update_from_moments(self, batch_mean, batch_var, batch_count):
836
+ delta = batch_mean - self.main_mean
837
+ tot_count = self.count + batch_count
838
+ new_mean = self.main_mean + delta * batch_count / tot_count #Update the running mean
839
+ m_a = self.main_var * self.count
840
+ m_b = batch_var * batch_count
841
+ M2 = m_a + m_b + delta ** 2 * self.count * batch_count / tot_count
842
+ new_var = M2 / tot_count # update the running variance
843
+
844
+ self.main_mean, self.main_var, self.count = new_mean, new_var, tot_count
845
+
846
+ def normalize(self, x):
847
+
848
+ return (x - self.main_mean) / (self.main_var ** 0.5 + 1e-8) # epsilon keeps us from
849
+ # dividing by zero; ** 0.5 also stays device-safe for tensors, unlike np.sqrt.
850
+
851
+
852
+
853
+
854
+ class ReturnNorm:
855
+ '''
856
+ This class implements the Return Normalization. The purpose is to normalize either across batches or
857
+ only within the same batch.
858
+
859
+ '''
860
+ def __init__(self):
861
+ self.main_mean = 0
862
+ self.main_var = 0
863
+ self.count = 1e-4
864
+
865
+ def update(self, x: T.Tensor):
866
+ batch_mean = T.mean(x, dim=0)
867
+ batch_var = T.var(x, dim=0)
868
+ batch_count = x.shape[0]
869
+ self._update_from_moments(batch_mean, batch_var, batch_count)
870
+
871
+ def _update_from_moments(self, batch_mean, batch_var, batch_count):
872
+ delta = batch_mean - self.main_mean
873
+ tot_count = self.count + batch_count
874
+ new_mean = self.main_mean + delta * batch_count / tot_count #Update the running mean
875
+ m_a = self.main_var * self.count
876
+ m_b = batch_var * batch_count
877
+ M2 = m_a + m_b + delta ** 2 * self.count * batch_count / tot_count
878
+ new_var = M2 / tot_count # update the running variance
879
+
880
+ self.main_mean, self.main_var, self.count = new_mean, new_var, tot_count
881
+
882
+ def normalize(self, x):
883
+
884
+ return (x - self.main_mean) / (self.main_var ** 0.5 + 1e-8) # epsilon keeps us from
885
+ # dividing by zero; ** 0.5 also stays device-safe for tensors, unlike np.sqrt.
886
+
887
+
888
+
889
+
890
+
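
A note on the three Norm classes above: they all share the same parallel-moments recursion (combining a running mean/variance with a batch mean/variance, as in Chan et al.'s parallel algorithm), seeded with a near-zero count of 1e-4 so the first batch dominates. Below is a minimal sanity-check sketch, assuming the file above is importable as ppo__rew_norm_obs_diff_hyp; the loose tolerances account for update() feeding the unbiased batch variance into a population-variance recursion.

    import torch as T
    from ppo__rew_norm_obs_diff_hyp import ObservationNorm

    norm = ObservationNorm()
    a = T.randn(100, 4)          # first rollout of flat observations
    b = T.randn(50, 4) + 2.0     # second rollout with a shifted mean
    norm.update(a)
    norm.update(b)

    both = T.cat([a, b])
    # The running moments should approximate the direct moments of all data seen so far.
    print(T.allclose(norm.main_mean, both.mean(dim=0), atol=1e-2))
    print(T.allclose(norm.main_var, both.var(dim=0, unbiased=False), atol=1e-1))

    # normalize() then standardizes with the running statistics:
    z = norm.normalize(both)
    print(z.mean(dim=0))  # close to 0
    print(z.std(dim=0))   # close to 1
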
Observation_Advantage_Norm/PPO_environment.py → Observation_Advantage_Norm_diff_hypo/ppo_rew_norm_obs_env_diff_hypo.py RENAMED
@@ -1,43 +1,75 @@
1
- import ale_py
2
  import gymnasium as gym
3
  import sys
4
- import numpy as np
5
- from PPO_Obser_Adva_Norm import *
6
  import matplotlib.pyplot as plt
7
 
8
 
9
  def preprocess(obs):
10
- # Flatten and normalize uint8 frames to float32 in [0,1]
11
- return (obs.astype(np.float32).ravel() / 255.0)
12
 
13
- def main() -> int:
14
- # Initialize environment
15
- env = gym.make("ALE/Pacman-v5", render_mode="human") # consider removing render_mode for training speed
16
- # Initialize variables
17
  episode = 0
18
  total_return = 0
19
  ep_return = 0
20
- steps = 1000 # Batch of 100, 1000 environment steps per update
21
- batches = 15
22
- mode = "clip"
23
- average_return = []
24
- total_loss = []
25
- updates = []
26
- activate_observation_norm = True
27
- activate_advantage_norm = False
28
-
29
- # Inspect spaces
30
  print("Observation space:", env.observation_space)
31
  print("Action space:", env.action_space)
32
 
33
- # Create PPO Agent (adapted to ppo_helpers_v2.Agent signature)
34
- agent = Agent(obs_space=env.observation_space, action_space=env.action_space, hidden=64,
35
- lr=3e-4, gamma=0.99, clip_coef=0.2, entropy_coef=0, value_coef=0.5, seed=70)
36
 
37
  try:
38
  obs, info = env.reset(seed=42)
39
  state = preprocess(obs)
40
 
41
  for update in range(1, batches + 1):
42
  for t in range(steps):
43
  action, logp, value = agent.choose_action(state)
@@ -58,12 +90,21 @@ def main() -> int:
58
  obs, info = env.reset()
59
  state = preprocess(obs)
60
 
61
- agent._update(mode, activate_observation_norm, activate_advantage_norm)
62
- avg_ret = (total_return / episode) if episode else 0
63
- average_return.append(avg_ret)
64
- total_loss.append(float(agent.total_loss.detach().cpu().item()))
65
- updates.append(update)
66
- print(f"Update {update}: episodes={episode}, avg_return={avg_ret:.2f}")
67
 
68
  except Exception as e:
69
  print(f"Error: {e}", file=sys.stderr)
@@ -72,22 +113,46 @@ def main() -> int:
72
  avg = total_return / episode if episode else 0
73
  print(f"\nEpisodes: {episode}, Avg return: {avg:.3f}")
74
  env.close()
75
- plt.plot(updates, average_return, label="average return")
76
- plt.plot(updates, total_loss, label="total loss")
77
- if activate_advantage_norm:
78
- plt.title("Average return vs. total loss with advantage norm")
79
- elif activate_observation_norm:
80
- plt.title("Average return vs. total loss with observation norm")
81
- elif activate_advantage_norm and activate_observation_norm:
82
- plt.title("Average return vs. total loss with observation norm and advantage norm")
83
- else:
84
- plt.title("Average return vs. total loss with no normalization")
85
-
86
- plt.xlabel("updates")
87
- plt.ylabel("average return/total loss")
88
- plt.legend()
89
- plt.show()
90
  return 0
91
 
 
92
  if __name__ == "__main__":
93
- raise SystemExit(main())
 
1
+
2
  import gymnasium as gym
3
  import sys
4
  import matplotlib.pyplot as plt
5
+ import ale_py
6
+ from ppo__rew_norm_obs_diff_hyp import *
7
+ from gymnasium.spaces import Box
8
+ import cv2
9
+
10
+
11
+
12
+
13
 
14
 
15
  def preprocess(obs):
16
+ # Convert to grayscale
17
+ obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
18
+ # Resize
19
+ obs = cv2.resize(obs, (84, 84), interpolation=cv2.INTER_AREA)
20
+ # Add channel dimension and normalize
21
+ return np.expand_dims(obs, axis=0).astype(np.float32) / 255.0
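+ # Output shape is (1, 84, 84): a single grayscale channel first, matching the CNN's (c, h, w) input.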
22
+
23
+
24
+ def rl_model(update_type, gamma=0.99, clip_coef=0.2,
25
+ lr = 1e-3, ent_coef = 0.01):
26
+ # env = gym.make("ALE/SpaceInvaders-v5", render_mode='human')
27
+ # env = gym.make("ALE/Pacman-v5", render_mode="human")
28
+ env = gym.make("ALE/Pacman-v5")
29
 
30
  episode = 0
31
  total_return = 0
32
  ep_return = 0
33
+ steps = 1000
34
+ batches = 100
35
+
36
  print("Observation space:", env.observation_space)
37
  print("Action space:", env.action_space)
38
+ """
39
+ agent = Agent(obs_space=env.observation_space, action_space=env.action_space,
40
+ hidden=64, lr=3e-4, gamma=0.99, clip_coef=0.2,
41
+ entropy_coef=0.01, value_coef=0.5, seed=70,
42
+ batch_size = 64, ppo_epochs = 4, lam = 0.95)
43
+
44
+ """
45
+ # Initialize CNN with a dummy observation (to get correct input shape)
46
+ obs, _ = env.reset()
47
+ dummy_obs_space = Box(low=0.0, high=1.0, shape=preprocess(obs).shape)
48
+ update_type = type
49
+ agent = Agent(obs_space=dummy_obs_space, action_space=env.action_space,
50
+ hidden=64, lr= lr, gamma= gamma, clip_coef= clip_coef,
51
+ entropy_coef= ent_coef, value_coef=0.5, seed=70,
52
+ batch_size=64, ppo_epochs=4, lam=0.95, update_type=update_type)
53
+ """
54
+ # Stats for Return-Based Scaling only
55
+ # === Return-Based Scaling stats ===
56
+ r_mean, r_var = 0.0, 1e-8
57
+ g2_mean = 1.0
58
 
59
+ agent.r_var = r_var
60
+ agent.g2_mean = g2_mean
61
+ """
62
 
63
  try:
64
  obs, info = env.reset(seed=42)
65
  state = preprocess(obs)
66
 
67
+ loss_history = []
68
+ reward_history = []
69
+
70
+ labels = []
71
+ final_scores = []
72
+
73
  for update in range(1, batches + 1):
74
  for t in range(steps):
75
  action, logp, value = agent.choose_action(state)
 
90
  obs, info = env.reset()
91
  state = preprocess(obs)
92
 
93
+ # Using reward gradient clipping
94
+ avg_loss = agent._update()
95
+
96
+ # Vanilla PPO (no normalization)
97
+ # avg_loss = agent.vanilla_ppo_update()
98
+ loss_history.append(avg_loss)
99
+
100
+ avg_ret = (total_return / episode) if episode else 0
101
+ reward_history.append(avg_ret)
102
+ print(f"Update {update}: episodes={episode}, avg_return={avg_ret:.2f}, avg_loss={avg_loss:.4f}")
103
+
104
+ return reward_history, loss_history
105
+
106
+
107
+
108
 
109
  except Exception as e:
110
  print(f"Error: {e}", file=sys.stderr)
 
113
  avg = total_return / episode if episode else 0
114
  print(f"\nEpisodes: {episode}, Avg return: {avg:.3f}")
115
  env.close()
116
+
117
+ return 0
118
+
119
+ def createHisto(x, final_scores, labels, title):
120
+ plt.figure(figsize=(10, 6)) # ← NEW FIGURE
121
+ plt.bar(x, final_scores)
122
+ plt.xticks(x, labels, rotation=45, ha="right")
123
+ plt.ylabel("Mean Reward")
124
+ plt.title(title)
125
+ plt.tight_layout()
126
+ plt.savefig(title + ".png")
127
+ plt.close()
128
+
129
+
130
+
131
+
132
+
133
+
134
+ def main() -> int:
135
+ type_list = ["update_observation_norm","update_advantage_norm",
136
+ "update_return_norm", "vanilla_ppo_update"]
137
+ learning_rates = [1e-2, 1e-3, 1e-4]
138
+ clip_coefs = [0.01, 0.1, 0.3 ]
139
+ gamma_list = [0.99, 0.97, 0.95]
140
+ entropy_coefs_list = [0.1, 0.01, 0.001]
141
+ final_scores = []
142
+ labels = ["entropy coef. = " + str(entrop_ceof) for entrop_ceof in entropy_coefs_list]
143
+ for update_type in type_list:
144
+ final_scores = []
145
+ for entropy_coef in entropy_coefs_list:
146
+ reward_history, loss_history = rl_model(update_type, ent_coef=entropy_coef)
147
+ final_scores.append(np.mean(reward_history))
148
+
149
+ createHisto(np.arange(len(labels)), final_scores, labels, "Performance config for entropy coefficient of " + update_type)
150
+
151
+
152
+
153
+
154
  return 0
155
 
156
+
157
  if __name__ == "__main__":
158
+ raise SystemExit(main())
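
main() above defines learning_rates, clip_coefs, and gamma_list but only sweeps the entropy coefficients; the learning-rate and gamma histograms committed alongside this file were presumably produced by analogous loops. A sketch of the learning-rate variant under that assumption, reusing rl_model and createHisto from above (np comes in via the star import; the title string mirrors the committed PNG names):

    def sweep_learning_rate(type_list, learning_rates):
        for update_type in type_list:
            labels = ["lr = " + str(lr) for lr in learning_rates]
            final_scores = []
            for lr in learning_rates:
                # One run per learning rate; all other hyperparameters keep their defaults.
                reward_history, _ = rl_model(update_type, lr=lr)
                final_scores.append(np.mean(reward_history))
            createHisto(np.arange(len(labels)), final_scores, labels,
                        "Performance config for Learning Rate of " + update_type)
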
Observation_Advantage_Norm_in_batch/ppo__rew_norm_obs_in_batch.py ADDED
@@ -0,0 +1,829 @@
1
+ import numpy as np
2
+ import torch as T
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torch.distributions import Categorical
6
+
7
+
8
+ class Agent:
9
+ def __init__(
10
+ self,
11
+ obs_space,
12
+ action_space,
13
+ hidden,
14
+ gamma,
15
+ clip_coef,
16
+ lr,
17
+ value_coef,
18
+ entropy_coef,
19
+ seed,
20
+ batch_size,
21
+ ppo_epochs,
22
+ lam,
23
+ update_type
24
+
25
+ ):
26
+ # Initialize seed for reproducibility
27
+ if seed is not None:
28
+ np.random.seed(seed)
29
+ T.manual_seed(seed)
30
+ """
31
+ # For flat observations (MLP model)
32
+ # Use GPU if available
33
+ self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
34
+ self.obs_dim = int(np.prod(getattr(obs_space, "shape", (obs_space,))))
35
+ self.action_dim = int(getattr(action_space, "n", action_space))
36
+
37
+ # Initialize the policy and the critic networks
38
+ self.policy = Policy(self.obs_dim, self.action_dim, hidden).to(self.device)
39
+ self.critic = Critic(self.obs_dim, hidden).to(self.device)
40
+ """
41
+ # Use GPU if available
42
+ self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
43
+ self.action_dim = int(getattr(action_space, "n", action_space))
44
+ self.update_type = update_type
45
+
46
+ # Initialize the policy and the critic networks
47
+ # Pass the shape tuple directly, not the flattened dimension.
48
+ self.policy = Policy(obs_space.shape, self.action_dim, hidden).to(self.device)
49
+ self.critic = Critic(obs_space.shape, hidden).to(self.device)
50
+ self.observeNorm = ObservationNorm()
51
+ self.advantageNorm = AdvantageNorm()
52
+ self.returnNorm = ReturnNorm()
53
+
54
+ # Set optimizer for policy and critic networks
55
+ self.opt = optim.Adam(
56
+ list(self.policy.parameters()) + list(self.critic.parameters()),
57
+ lr=lr
58
+ )
59
+
60
+ self.gamma = gamma
61
+ self.clip = clip_coef
62
+ self.value_coef = value_coef
63
+ self.entropy_coef = entropy_coef
64
+ self.sigma_history = []
65
+ self.loss_history = []
66
+ self.policy_loss_history = []
67
+ self.value_loss_history = []
68
+ self.entropy_history = []
69
+ self.lam = lam
70
+ self.ppo_epochs = ppo_epochs
71
+ self.batch_size = batch_size
72
+
73
+ self.memory = Memory()
74
+ """
75
+ # Choose action and remember for flat observations (MLP model)
76
+ def choose_action(self, observation):
77
+ # Returns: action, log probability, value of the state
78
+ state = T.as_tensor(observation, dtype=T.float32, device=self.device).view(-1)
79
+ with T.no_grad():
80
+ # Forward function (defined in Policy class)
81
+ dist = self.policy.next_action(state)
82
+ action = dist.sample()
83
+ logp = dist.log_prob(action)
84
+ value = self.critic.evaluated_state(state)
85
+ return int(action.item()), float(logp.item()), float(value.item())
86
+
87
+ def remember(self, state, action, reward, done, log_prob, value, next_state):
88
+ with T.no_grad():
89
+ # Pass on next state and have it evaluated by the critic network
90
+ ns = T.as_tensor(next_state, dtype=T.float32, device=self.device).view(-1)
91
+ next_value = self.critic.evaluated_state(ns).item()
92
+ self.memory.store(state, action, reward, done, log_prob, value, next_value)
93
+ """
94
+ # For CNN model
95
+ def choose_action(self, observation):
96
+ # Returns: action, log probability, value of the state
97
+ state = T.as_tensor(observation, dtype=T.float32, device=self.device) # Remove .view(-1)
98
+ with T.no_grad():
99
+ # Forward function (defined in Policy class)
100
+ dist = self.policy.next_action(state)
101
+ action = dist.sample()
102
+ logp = dist.log_prob(action)
103
+ value = self.critic.evaluated_state(state)
104
+ return int(action.item()), float(logp.item()), float(value.item())
105
+
106
+ def remember(self, state, action, reward, done, log_prob, value, next_state):
107
+ with T.no_grad():
108
+ # Pass on next state and have it evaluated by the critic network
109
+ ns = T.as_tensor(next_state, dtype=T.float32, device=self.device) # Remove .view(-1)
110
+ next_value = self.critic.evaluated_state(ns).item()
111
+ self.memory.store(state, action, reward, done, log_prob, value, next_value)
112
+
113
+
114
+ def _update(self):
115
+ if self.update_type == "update_observation_norm":
116
+ return self.update_observation_norm()
117
+ elif self.update_type == "update_advantage_norm":
118
+ return self.update_advantage_norm()
119
+ elif self.update_type == "update_return_norm":
120
+ return self.update_return_norm()
121
+ else:
122
+ return self.vanilla_ppo_update()
123
+
124
+ def vanilla_ppo_update(self):
125
+ if len(self.memory.states) == 0:
126
+ return 0.0
127
+
128
+ # Convert memory to tensors
129
+ states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
130
+ actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
131
+ rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
132
+ dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
133
+ old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
134
+ values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
135
+
136
+ with T.no_grad():
137
+ # Compute next values (bootstrap for final step)
138
+ next_values = T.cat([values[1:], values[-1:].clone()])
139
+ deltas = rewards + self.gamma * next_values * (1 - dones) - values
140
+
141
+ # --- GAE-Lambda ---
142
+ adv = T.zeros_like(rewards)
143
+ gae = 0.0
144
+ for t in reversed(range(len(rewards))):
145
+ gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
146
+ adv[t] = gae
147
+
148
+ returns = adv + values
149
+ # Advantage normalization
150
+ adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
151
+
152
+ # --- PPO Multiple Epochs + Minibatch ---
153
+ total_loss_epoch = 0.0
154
+ num_samples = len(states)
155
+ batch_size = min(64, num_samples)
156
+ ppo_epochs = 4
157
+
158
+ for _ in range(ppo_epochs):
159
+ # Shuffle indices
160
+ idxs = T.randperm(num_samples)
161
+ for start in range(0, num_samples, batch_size):
162
+ batch_idx = idxs[start:start + batch_size]
163
+
164
+ b_states = states[batch_idx]
165
+ b_actions = actions[batch_idx]
166
+ b_old_logp = old_logp[batch_idx]
167
+ b_returns = returns[batch_idx]
168
+ b_adv = adv[batch_idx]
169
+
170
+ dist = self.policy.next_action(b_states)
171
+ new_logp = dist.log_prob(b_actions)
172
+ entropy = dist.entropy().mean()
173
+ ratio = (new_logp - b_old_logp).exp()
174
+
175
+ # --- Clipped surrogate objective ---
176
+ surr1 = ratio * b_adv
177
+ surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
178
+ policy_loss = -T.min(surr1, surr2).mean()
179
+
180
+ # --- Critic loss ---
181
+ value_pred = self.critic.evaluated_state(b_states)
182
+ value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
183
+
184
+ # --- Total loss ---
185
+ total_loss = (
186
+ policy_loss +
187
+ self.value_coef * value_loss -
188
+ self.entropy_coef * entropy
189
+ )
190
+
191
+ # Debug: track individual loss components
192
+ self.policy_loss_history.append(policy_loss.item())
193
+ self.value_loss_history.append(value_loss.item())
194
+
195
+ self.opt.zero_grad(set_to_none=True)
196
+ total_loss.backward()
197
+ self.opt.step()
198
+
199
+ total_loss_epoch += total_loss.item()
200
+
201
+ # Clear memory after full PPO update
202
+ self.memory.clear()
203
+
204
+ return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
205
+
206
+
207
+ def update_rbs(self):
208
+ if len(self.memory.states) == 0:
209
+ return 0.0
210
+
211
+ # Convert memory to tensors
212
+ states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
213
+ actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
214
+ rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
215
+ dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
216
+ old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
217
+ values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
218
+
219
+ with T.no_grad():
220
+ # Compute next values (bootstrap for final step)
221
+ next_values = T.cat([values[1:], values[-1:].clone()])
222
+ deltas = rewards + self.gamma * next_values * (1 - dones) - values
223
+
224
+ # --- GAE-Lambda ---
225
+ adv = T.zeros_like(rewards)
226
+ gae = 0.0
227
+ for t in reversed(range(len(rewards))):
228
+ gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
229
+ adv[t] = gae
230
+
231
+ returns = adv + values
232
+
233
+ # --- Return-based normalization (RBS) ---
234
+ sigma_t = returns.std(unbiased=False) + 1e-8
235
+ returns = returns / sigma_t
236
+ self.sigma_history.append(sigma_t.item())
237
+ adv = adv / sigma_t
238
+ # Advantage normalization
239
+ adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
240
+
241
+ # --- PPO Multiple Epochs + Minibatch ---
242
+ total_loss_epoch = 0.0
243
+ num_samples = len(states)
244
+ batch_size = min(64, num_samples)
245
+ ppo_epochs = 4
246
+
247
+ for _ in range(ppo_epochs):
248
+ # Shuffle indices
249
+ idxs = T.randperm(num_samples)
250
+ for start in range(0, num_samples, batch_size):
251
+ batch_idx = idxs[start:start + batch_size]
252
+
253
+ b_states = states[batch_idx]
254
+ b_actions = actions[batch_idx]
255
+ b_old_logp = old_logp[batch_idx]
256
+ b_returns = returns[batch_idx]
257
+ b_adv = adv[batch_idx]
258
+
259
+ dist = self.policy.next_action(b_states)
260
+ new_logp = dist.log_prob(b_actions)
261
+ entropy = dist.entropy().mean()
262
+ ratio = (new_logp - b_old_logp).exp()
263
+
264
+ # --- Clipped surrogate objective ---
265
+ surr1 = ratio * b_adv
266
+ surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
267
+ policy_loss = -T.min(surr1, surr2).mean()
268
+
269
+ # --- Critic loss ---
270
+ value_pred = self.critic.evaluated_state(b_states)
271
+ value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
272
+
273
+ # --- Total loss ---
274
+ total_loss = (
275
+ policy_loss +
276
+ self.value_coef * value_loss -
277
+ self.entropy_coef * entropy
278
+ )
279
+
280
+ # Debug: track individual loss components
281
+ self.policy_loss_history.append(policy_loss.item())
282
+ self.value_loss_history.append(value_loss.item())
283
+
284
+ self.opt.zero_grad(set_to_none=True)
285
+ total_loss.backward()
286
+ self.opt.step()
287
+ total_loss_epoch += total_loss.item()
288
+
289
+ # Clear memory after full PPO update
290
+ self.memory.clear()
291
+
292
+ return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
293
+
294
+
295
+
296
+
297
+
298
+
299
+ def update_observation_norm(self):
300
+ if len(self.memory.states) == 0:
301
+ return 0.0
302
+
303
+ # Convert memory to tensors
304
+ states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
305
+ actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
306
+ rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
307
+ dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
308
+ old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
309
+ values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
310
+
311
+ with T.no_grad():
312
+ # Compute next values (bootstrap for final step)
313
+ next_values = T.cat([values[1:], values[-1:].clone()])
314
+ deltas = rewards + self.gamma * next_values * (1 - dones) - values
315
+
316
+ # --- GAE-Lambda ---
317
+ adv = T.zeros_like(rewards)
318
+ gae = 0.0
319
+ for t in reversed(range(len(rewards))):
320
+ gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
321
+ adv[t] = gae
322
+
323
+ returns = adv + values
324
+
325
+ # --- observation normalization ---
326
+ states = self.observeNorm.normalize(states)
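+ # Note: unlike the running-average variant, update() is not called here, so the
+ # normalization depends on how the Norm classes are defined at the end of this file.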
327
+ # Advantage normalization
328
+ # adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
329
+
330
+ # --- PPO Multiple Epochs + Minibatch ---
331
+ total_loss_epoch = 0.0
332
+ num_samples = len(states)
333
+ batch_size = min(64, num_samples)
334
+ ppo_epochs = 4
335
+
336
+ for _ in range(ppo_epochs):
337
+ # Shuffle indices
338
+ idxs = T.randperm(num_samples)
339
+ for start in range(0, num_samples, batch_size):
340
+ batch_idx = idxs[start:start + batch_size]
341
+
342
+ b_states = states[batch_idx]
343
+ b_actions = actions[batch_idx]
344
+ b_old_logp = old_logp[batch_idx]
345
+ b_returns = returns[batch_idx]
346
+ b_adv = adv[batch_idx]
347
+
348
+ dist = self.policy.next_action(b_states)
349
+ new_logp = dist.log_prob(b_actions)
350
+ entropy = dist.entropy().mean()
351
+ ratio = (new_logp - b_old_logp).exp()
352
+
353
+ # --- Clipped surrogate objective ---
354
+ surr1 = ratio * b_adv
355
+ surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
356
+ policy_loss = -T.min(surr1, surr2).mean()
357
+
358
+ # --- Critic loss ---
359
+ value_pred = self.critic.evaluated_state(b_states)
360
+ value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
361
+
362
+ # --- Total loss ---
363
+ total_loss = (
364
+ policy_loss +
365
+ self.value_coef * value_loss -
366
+ self.entropy_coef * entropy
367
+ )
368
+
369
+ # Debug: track individual loss components
370
+ self.policy_loss_history.append(policy_loss.item())
371
+ self.value_loss_history.append(value_loss.item())
372
+
373
+ self.opt.zero_grad(set_to_none=True)
374
+ total_loss.backward()
375
+ self.opt.step()
376
+ total_loss_epoch += total_loss.item()
377
+
378
+ # Clear memory after full PPO update
379
+ self.memory.clear()
380
+
381
+ return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
382
+
383
+
384
+
385
+
386
+ def update_advantage_norm(self):
387
+ if len(self.memory.states) == 0:
388
+ return 0.0
389
+
390
+ # Convert memory to tensors
391
+ states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
392
+ actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
393
+ rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
394
+ dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
395
+ old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
396
+ values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
397
+
398
+ with T.no_grad():
399
+ # Compute next values (bootstrap for final step)
400
+ next_values = T.cat([values[1:], values[-1:].clone()])
401
+ deltas = rewards + self.gamma * next_values * (1 - dones) - values
402
+
403
+ # --- GAE-Lambda ---
404
+ adv = T.zeros_like(rewards)
405
+ gae = 0.0
406
+ for t in reversed(range(len(rewards))):
407
+ gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
408
+ adv[t] = gae
409
+
410
+ # --- Advantage normalization ---
411
+
412
+ returns = adv + values
413
+
414
+ adv = self.advantageNorm.normalize(adv)
415
+
416
+
417
+
418
+
419
+ # --- PPO Multiple Epochs + Minibatch ---
420
+ total_loss_epoch = 0.0
421
+ num_samples = len(states)
422
+ batch_size = min(64, num_samples)
423
+ ppo_epochs = 4
424
+
425
+ for _ in range(ppo_epochs):
426
+ # Shuffle indices
427
+ idxs = T.randperm(num_samples)
428
+ for start in range(0, num_samples, batch_size):
429
+ batch_idx = idxs[start:start + batch_size]
430
+
431
+ b_states = states[batch_idx]
432
+ b_actions = actions[batch_idx]
433
+ b_old_logp = old_logp[batch_idx]
434
+ b_returns = returns[batch_idx]
435
+ b_adv = adv[batch_idx]
436
+
437
+ dist = self.policy.next_action(b_states)
438
+ new_logp = dist.log_prob(b_actions)
439
+ entropy = dist.entropy().mean()
440
+ ratio = (new_logp - b_old_logp).exp()
441
+
442
+ # --- Clipped surrogate objective ---
443
+ surr1 = ratio * b_adv
444
+ surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
445
+ policy_loss = -T.min(surr1, surr2).mean()
446
+
447
+ # --- Critic loss ---
448
+ value_pred = self.critic.evaluated_state(b_states)
449
+ value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
450
+
451
+ # --- Total loss ---
452
+ total_loss = (
453
+ policy_loss +
454
+ self.value_coef * value_loss -
455
+ self.entropy_coef * entropy
456
+ )
457
+
458
+ # Debug: track individual loss components
459
+ self.policy_loss_history.append(policy_loss.item())
460
+ self.value_loss_history.append(value_loss.item())
461
+
462
+ self.opt.zero_grad(set_to_none=True)
463
+ total_loss.backward()
464
+ self.opt.step()
465
+ total_loss_epoch += total_loss.item()
466
+
467
+ # Clear memory after full PPO update
468
+ self.memory.clear()
469
+
470
+ return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
471
+
472
+ def update_return_norm(self):
473
+ if len(self.memory.states) == 0:
474
+ return 0.0
475
+
476
+ # Convert memory to tensors
477
+ states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
478
+ actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
479
+ rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
480
+ dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
481
+ old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
482
+ values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
483
+
484
+ with T.no_grad():
485
+ # Compute next values (bootstrap for final step)
486
+ next_values = T.cat([values[1:], values[-1:].clone()])
487
+ deltas = rewards + self.gamma * next_values * (1 - dones) - values
488
+
489
+ # --- GAE-Lambda ---
490
+ adv = T.zeros_like(rewards)
491
+ gae = 0.0
492
+ for t in reversed(range(len(rewards))):
493
+ gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
494
+ adv[t] = gae
495
+
496
+
497
+
498
+ returns = adv + values
499
+
500
+ # --- returns normalization ---
501
+ returns = self.returnNorm.normalize(returns)
502
+
503
+
504
+ # Advantage normalization
505
+ #adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
506
+
507
+ # --- PPO Multiple Epochs + Minibatch ---
508
+ total_loss_epoch = 0.0
509
+ num_samples = len(states)
510
+ batch_size = min(64, num_samples)
511
+ ppo_epochs = 4
512
+
513
+ for _ in range(ppo_epochs):
514
+ # Shuffle indices
515
+ idxs = T.randperm(num_samples)
516
+ for start in range(0, num_samples, batch_size):
517
+ batch_idx = idxs[start:start + batch_size]
518
+
519
+ b_states = states[batch_idx]
520
+ b_actions = actions[batch_idx]
521
+ b_old_logp = old_logp[batch_idx]
522
+ b_returns = returns[batch_idx]
523
+ b_adv = adv[batch_idx]
524
+
525
+ dist = self.policy.next_action(b_states)
526
+ new_logp = dist.log_prob(b_actions)
527
+ entropy = dist.entropy().mean()
528
+ ratio = (new_logp - b_old_logp).exp()
529
+
530
+ # --- Clipped surrogate objective ---
531
+ surr1 = ratio * b_adv
532
+ surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
533
+ policy_loss = -T.min(surr1, surr2).mean()
534
+
535
+ # --- Critic loss ---
536
+ value_pred = self.critic.evaluated_state(b_states)
537
+ value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
538
+
539
+ # --- Total loss ---
540
+ total_loss = (
541
+ policy_loss +
542
+ self.value_coef * value_loss -
543
+ self.entropy_coef * entropy
544
+ )
545
+
546
+ # Debug: track individual loss components
547
+ self.policy_loss_history.append(policy_loss.item())
548
+ self.value_loss_history.append(value_loss.item())
549
+
550
+ self.opt.zero_grad(set_to_none=True)
551
+ total_loss.backward()
552
+ self.opt.step()
553
+ total_loss_epoch += total_loss.item()
554
+
555
+ # Clear memory after full PPO update
556
+ self.memory.clear()
557
+
558
+ return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
559
+
560
+ def update_reward_gradient_clipping(self):
561
+ if len(self.memory.states) == 0:
562
+ return 0.0
563
+
564
+ # Convert memory to tensors
565
+ states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
566
+ actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
567
+ rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
568
+ dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
569
+ old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
570
+ values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
571
+
572
+ # Reward clipping
573
+ rewards = T.clamp(rewards, -1, 1)
574
+
575
+ with T.no_grad():
576
+ # Compute next values (bootstrap for final step)
577
+ next_values = T.cat([values[1:], values[-1:].clone()])
578
+ deltas = rewards + self.gamma * next_values * (1 - dones) - values
579
+
580
+ # --- GAE-Lambda ---
581
+ adv = T.zeros_like(rewards)
582
+ gae = 0.0
583
+ for t in reversed(range(len(rewards))):
584
+ gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
585
+ adv[t] = gae
586
+
587
+ returns = adv + values
588
+ # Advantage normalization
589
+ adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
590
+
591
+ # --- PPO Multiple Epochs + Minibatch ---
592
+ total_loss_epoch = 0.0
593
+ num_samples = len(states)
594
+ batch_size = min(64, num_samples)
595
+ ppo_epochs = 4
596
+
597
+ for _ in range(ppo_epochs):
598
+ # Shuffle indices
599
+ idxs = T.randperm(num_samples)
600
+ for start in range(0, num_samples, batch_size):
601
+ batch_idx = idxs[start:start + batch_size]
602
+
603
+ b_states = states[batch_idx]
604
+ b_actions = actions[batch_idx]
605
+ b_old_logp = old_logp[batch_idx]
606
+ b_returns = returns[batch_idx]
607
+ b_adv = adv[batch_idx]
608
+
609
+ dist = self.policy.next_action(b_states)
610
+ new_logp = dist.log_prob(b_actions)
611
+ entropy = dist.entropy().mean()
612
+ ratio = (new_logp - b_old_logp).exp()
613
+
614
+ # --- Clipped surrogate objective ---
615
+ surr1 = ratio * b_adv
616
+ surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
617
+ policy_loss = -T.min(surr1, surr2).mean()
618
+
619
+ # --- Critic loss ---
620
+ value_pred = self.critic.evaluated_state(b_states)
621
+ value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
622
+
623
+ # --- Total loss ---
624
+ total_loss = (
625
+ policy_loss +
626
+ self.value_coef * value_loss -
627
+ self.entropy_coef * entropy
628
+ )
629
+
630
+ # Debug: track individual loss components
631
+ self.policy_loss_history.append(policy_loss.item())
632
+ self.value_loss_history.append(value_loss.item())
633
+
634
+ self.opt.zero_grad(set_to_none=True)
635
+ total_loss.backward()
636
+ T.nn.utils.clip_grad_norm_(list(self.policy.parameters()) + list(self.critic.parameters()), 0.5)
637
+ self.opt.step()
638
+
639
+ total_loss_epoch += total_loss.item()
640
+
641
+ # Clear memory after full PPO update
642
+ self.memory.clear()
643
+
644
+ return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
645
+
646
+ """
647
+ # Policy network (simple MLP, flattened observations)
648
+ class Policy(nn.Module):
649
+ def __init__(self, obs_dim: int, action_dim: int, hidden: int):
650
+ super().__init__()
651
+ self.net = nn.Sequential(
652
+ nn.Linear(obs_dim, hidden),
653
+ nn.ReLU(),
654
+ nn.Linear(hidden, hidden),
655
+ nn.ReLU(),
656
+ nn.Linear(hidden, action_dim)
657
+ )
658
+
659
+ def next_action(self, state: T.Tensor) -> Categorical:
660
+ # Returns the probability distribution over actions
661
+ if state.dim() == 1:
662
+ state = state.unsqueeze(0)
663
+ state = state.view(state.size(0), -1)
664
+ return Categorical(logits=self.net(state))
665
+ """
666
+
667
+ # Policy network (CNN)
668
+ class Policy(nn.Module):
669
+ def __init__(self, obs_shape: tuple, action_dim: int, hidden: int):
670
+ super().__init__()
671
+ c, h, w = obs_shape
672
+ # Suggested architecture for Atari: https://arxiv.org/pdf/1312.5602
673
+ self.cnn = nn.Sequential(
674
+ nn.Conv2d(c, 16, kernel_size=8, stride=4),
675
+ nn.ReLU(),
676
+ nn.Conv2d(16, 32, kernel_size=4, stride=2),
677
+ nn.ReLU(),
678
+ nn.Flatten()
679
+ )
680
+
681
+ with T.no_grad():
682
+ cnn_output_dim = self.cnn(T.zeros(1, c, h, w)).shape[1]
683
+
684
+ self.net = nn.Sequential(
685
+ nn.Linear(cnn_output_dim, hidden),
686
+ nn.ReLU(),
687
+ nn.Linear(hidden, action_dim)
688
+ )
689
+
690
+ def next_action(self, state: T.Tensor) -> Categorical:
691
+ # Returns the probability distribution over actions
692
+ if state.dim() == 3:
693
+ state = state.unsqueeze(0)
694
+ cnn_out = self.cnn(state)
695
+ return Categorical(logits=self.net(cnn_out))
696
+
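A quick sanity check on the architecture above (an aside, not part of the committed files): for the 1x84x84 frames produced by preprocess in the companion environment script, each conv layer shrinks the spatial size by floor((in - kernel) / stride) + 1, so cnn_output_dim should come out to 2592.

import torch as T
import torch.nn as nn

# Same two conv layers as Policy.cnn, checked on a dummy 1x84x84 frame.
cnn = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=8, stride=4),   # 84 -> (84 - 8) // 4 + 1 = 20
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=4, stride=2),  # 20 -> (20 - 4) // 2 + 1 = 9
    nn.ReLU(),
    nn.Flatten(),
)
print(cnn(T.zeros(1, 1, 84, 84)).shape)  # torch.Size([1, 2592]) = 32 * 9 * 9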
697
+ """
698
+ # Critic network (simple MLP, flattened observations)
699
+ class Critic(nn.Module):
700
+ def __init__(self, obs_dim: int, hidden: int):
701
+ super().__init__()
702
+ self.net = nn.Sequential(
703
+ nn.Linear(obs_dim, hidden),
704
+ nn.ReLU(),
705
+ nn.Linear(hidden, hidden),
706
+ nn.ReLU(),
707
+ nn.Linear(hidden, 1)
708
+ )
709
+
710
+ def evaluated_state(self, x: T.Tensor) -> T.Tensor:
711
+ if x.dim() == 1:
712
+ x = x.unsqueeze(0)
713
+ x = x.view(x.size(0), -1)
714
+ return self.net(x).squeeze(-1)
715
+ """
716
+
717
+ # Critic network (CNN)
718
+ class Critic(nn.Module):
719
+ def __init__(self, obs_shape: tuple, hidden: int):
720
+ super().__init__()
721
+ c, h, w = obs_shape
722
+ # Suggested architecture for Atari: https://arxiv.org/pdf/1312.5602
723
+ self.cnn = nn.Sequential(
724
+ nn.Conv2d(c, 16, kernel_size=8, stride=4),
725
+ nn.ReLU(),
726
+ nn.Conv2d(16, 32, kernel_size=4, stride=2),
727
+ nn.ReLU(),
728
+ nn.Flatten()
729
+ )
730
+
731
+ with T.no_grad():
732
+ cnn_output_dim = self.cnn(T.zeros(1, c, h, w)).shape[1]
733
+
734
+ self.net = nn.Sequential(
735
+ nn.Linear(cnn_output_dim, hidden),
736
+ nn.ReLU(),
737
+ nn.Linear(hidden, 1)
738
+ )
739
+
740
+ def evaluated_state(self, x: T.Tensor) -> T.Tensor:
741
+ if x.dim() == 3:
742
+ x = x.unsqueeze(0)
743
+ cnn_out = self.cnn(x)
744
+ return self.net(cnn_out).squeeze(-1)
745
+
746
+ class Memory():
747
+ def __init__(self):
748
+ self.states = []
749
+ self.actions = []
750
+ self.rewards = []
751
+ self.dones = []
752
+ self.log_probs = []
753
+ self.values = []
754
+ self.next_values = []
755
+
756
+ def store(self, state, action, reward, done, log_prob, value, next_value):
757
+ self.states.append(np.asarray(state, dtype=np.float32))
758
+ self.actions.append(int(action))
759
+ self.rewards.append(float(reward))
760
+ self.dones.append(float(done))
761
+ self.log_probs.append(float(log_prob))
762
+ self.values.append(float(value))
763
+ self.next_values.append(float(next_value))
764
+
765
+ """
766
+ # For mini-batch updates? To be implemented
767
+ def start_batch(self, batch_size: int):
768
+ n_states = len(self.states)
769
+ starts = np.arange(0, n_states, batch_size)
770
+ index = np.arange(n_states, dtype=np.int64)
771
+ np.random.shuffle(index)
772
+ return [index[s:s + batch_size] for s in starts]
773
+ """
774
+
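If the commented-out start_batch above were wired in, the update methods could iterate over its index slices instead of slicing T.randperm themselves. A minimal usage sketch, assuming a filled Memory instance named memory and the tensors (states, actions, ...) built as in the update methods:

# Hypothetical wiring of Memory.start_batch (sketch only; the code currently uses T.randperm).
for _ in range(ppo_epochs):
    for batch_idx in memory.start_batch(batch_size=64):
        b_states = states[batch_idx]    # torch tensors accept NumPy int64 index arrays
        b_actions = actions[batch_idx]
        # ... then the same clipped-surrogate / value / entropy losses as above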
775
+ def clear(self):
776
+ self.states = []
777
+ self.actions = []
778
+ self.rewards = []
779
+ self.dones = []
780
+ self.log_probs = []
781
+ self.values = []
782
+ self.next_values = []
783
+
784
+
785
+
786
+ class ObservationNorm:
787
+ # Stateless per-batch observation normalization: z-scores each batch with its own mean and std.
788
+
789
+ def normalize(self, x):
790
+ return (x - x.mean()) / (x.std(unbiased=False) + 1e-8)  # Epsilon guards against division by zero.
791
+
792
+
793
+
794
+
795
+
796
+
797
+ class AdvantageNorm:
798
+ '''
799
+ This class implements advantage normalization. In this file's variant, advantages are
800
+ z-scored using only the statistics of the current batch (no running averages).
801
+
802
+ '''
803
+
804
+
805
+ def normalize(self, x):
806
+
807
+ return (x - x.mean()) / (x.std(unbiased=False) + 1e-8)  # Epsilon guards against division by zero.
808
+
809
+
810
+
811
+
812
+
813
+ class ReturnNorm:
814
+ '''
815
+ This class implements return normalization. In this file's variant, returns are
816
+ z-scored using only the statistics of the current batch (no running averages).
817
+
818
+ '''
819
+
820
+
821
+ def normalize(self, x):
822
+ return (x - x.mean()) / (x.std(unbiased=False) + 1e-8)
823
+ # Epsilon guards against division by zero.
824
+
825
+
826
+
827
+
828
+
829
+
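All three helper classes in this file apply the same stateless in-batch z-score. A tiny demonstration of the transform (a sketch, independent of the training code):

import torch as T

x = T.tensor([1.0, 2.0, 3.0, 4.0])
z = (x - x.mean()) / (x.std(unbiased=False) + 1e-8)
print(round(z.mean().item(), 6), round(z.std(unbiased=False).item(), 6))  # ~0.0 and ~1.0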
Observation_Advantage_Norm_in_batch/ppo_rew_norm_obs_env_in_batch.py ADDED
@@ -0,0 +1,163 @@
1
+
2
+ import gymnasium as gym
3
+ import sys
4
+ import matplotlib.pyplot as plt
5
+ import ale_py
6
+ from ppo__rew_norm_obs_in_batch import *
7
+ from gymnasium.spaces import Box
8
+ import cv2
9
+
10
+ def preprocess(obs):
11
+ # Convert to grayscale
12
+ obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
13
+ # Resize
14
+ obs = cv2.resize(obs, (84, 84), interpolation=cv2.INTER_AREA)
15
+ # Add channel dimension and normalize
16
+ return np.expand_dims(obs, axis=0).astype(np.float32) / 255.0
17
+
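For reference, preprocess turns a raw RGB ALE frame into a single-channel 84x84 float32 array scaled to [0, 1]. A quick shape check, using a dummy all-zero frame in place of a real observation (sketch only):

import numpy as np

frame = np.zeros((210, 160, 3), dtype=np.uint8)  # stand-in for an ALE RGB observation
out = preprocess(frame)
print(out.shape, out.dtype)  # (1, 84, 84) float32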
18
+
19
+ def rl_model(type):
20
+ # env = gym.make("ALE/SpaceInvaders-v5", render_mode='human')
21
+ # env = gym.make("ALE/Pacman-v5", render_mode="human")
22
+ env = gym.make("ALE/Pacman-v5")
23
+
24
+ episode = 0
25
+ total_return = 0
26
+ ep_return = 0
27
+ steps = 1000
28
+ batches = 100
29
+
30
+ print("Observation space:", env.observation_space)
31
+ print("Action space:", env.action_space)
32
+ """
33
+ agent = Agent(obs_space=env.observation_space, action_space=env.action_space,
34
+ hidden=64, lr=3e-4, gamma=0.99, clip_coef=0.2,
35
+ entropy_coef=0.01, value_coef=0.5, seed=70,
36
+ batch_size = 64, ppo_epochs = 4, lam = 0.95)
37
+
38
+ """
39
+ # Initialize CNN with a dummy observation (to get correct input shape)
40
+ obs, _ = env.reset()
41
+ dummy_obs_space = Box(low=0.0, high=1.0, shape=preprocess(obs).shape)
42
+ update_type = type
43
+ agent = Agent(obs_space=dummy_obs_space, action_space=env.action_space,
44
+ hidden=64, lr=3e-4, gamma=0.99, clip_coef=0.2,
45
+ entropy_coef=0.01, value_coef=0.5, seed=70,
46
+ batch_size=64, ppo_epochs=4, lam=0.95, update_type=update_type)
47
+ """
48
+ # Stats for Return-Based Scaling only
49
+ # === Return-Based Scaling stats ===
50
+ r_mean, r_var = 0.0, 1e-8
51
+ g2_mean = 1.0
52
+
53
+ agent.r_var = r_var
54
+ agent.g2_mean = g2_mean
55
+ """
56
+
57
+ try:
58
+ obs, info = env.reset(seed=42)
59
+ state = preprocess(obs)
60
+
61
+ loss_history = []
62
+ reward_history = []
63
+
64
+ for update in range(1, batches + 1):
65
+ for t in range(steps):
66
+ action, logp, value = agent.choose_action(state)
67
+ next_obs, reward, terminated, truncated, info = env.step(action)
68
+ done = terminated or truncated
69
+ next_state = preprocess(next_obs)
70
+
71
+ agent.remember(state, action, reward, done, logp, value, next_state)
72
+
73
+ ep_return += reward
74
+ state = next_state
75
+
76
+ if done:
77
+ episode += 1
78
+ total_return += ep_return
79
+ print(f"Episode {episode} return: {ep_return:.2f}")
80
+ ep_return = 0
81
+ obs, info = env.reset()
82
+ state = preprocess(obs)
83
+
84
+ # Run the PPO update variant selected by update_type
85
+ avg_loss = agent._update()
86
+
87
+ # Vanilla PPO (no normalization)
88
+ # avg_loss = agent.vanilla_ppo_update()
89
+ loss_history.append(avg_loss)
90
+
91
+ avg_ret = (total_return / episode) if episode else 0
92
+ reward_history.append(avg_ret)
93
+ print(f"Update {update}: episodes={episode}, avg_return={avg_ret:.2f}, avg_loss={avg_loss:.4f}")
94
+
95
+ fig = plt.figure(figsize=(12, 8))
96
+
97
+ """
98
+ # Plot for Return-Based Scaling only
99
+ ax1 = plt.subplot(220)
100
+ ax1.plot(agent.sigma_history, label="Return σ")
101
+ ax1.set_xlabel("PPO Update")
102
+ ax1.set_ylabel("σ (Return Std)")
103
+ """
104
+
105
+ ax2 = plt.subplot(221)
106
+ ax2.plot(loss_history, label="Avg Loss")
107
+ ax2.set_ylabel("Average PPO Loss")
108
+ ax2.set_xlabel("PPO Update")
109
+
110
+ ax3 = plt.subplot(222)
111
+ ax3.plot(reward_history, label="Reward")
112
+ ax3.set_ylabel("Reward")
113
+ ax3.set_xlabel("PPO Update")
114
+
115
+ # Details about value loss and policy loss
116
+ ax4 = plt.subplot(223)
117
+ ax4.plot(agent.policy_loss_history, label="Policy Loss", alpha=0.7)
118
+ ax4.set_ylabel("Policy Loss")
119
+ ax4.set_xlabel("Training Step")
120
+ ax4.legend()
121
+
122
+ ax5 = plt.subplot(224)
123
+ ax5.plot(agent.value_loss_history, label="Value Loss", alpha=0.7)
124
+ ax5.set_ylabel("Value Loss")
125
+ ax5.set_xlabel("Training Step")
126
+ ax5.legend()
127
+
128
+ fig.suptitle("PPO Training Stability of type " + update_type +
129
+ "-in_batch")
130
+ fig.tight_layout()
131
+ plt.savefig(type +"_in_batch.png")
132
+
133
+
134
+
135
+
136
+ except Exception as e:
137
+ print(f"Error: {e}", file=sys.stderr)
138
+ return 1
139
+ finally:
140
+ avg = total_return / episode if episode else 0
141
+ print(f"\nEpisodes: {episode}, Avg return: {avg:.3f}")
142
+ env.close()
143
+
144
+ return 0
145
+
146
+
147
+
148
+
149
+
150
+
151
+
152
+
153
+ def main() -> int:
154
+ type_list = ["update_observation_norm", "update_advantage_norm", "update_return_norm", "vanilla_ppo_update"]
155
+
156
+ for type in type_list:
157
+ rl_model(type)
158
+
159
+ return 0
160
+
161
+
162
+ if __name__ == "__main__":
163
+ raise SystemExit(main())
Observation_Advantage_Norm_in_batch/update_advantage_norm_in_batch.png ADDED
Observation_Advantage_Norm_in_batch/update_observation_norm_in_batch.png ADDED
Observation_Advantage_Norm_in_batch/update_return_norm_in_batch.png ADDED
Observation_Advantage_Norm_in_batch/vanilla_ppo_update_in_batch.png ADDED
Observation_Advantage_Norm_running_averages/ppo__rew_norm_obs_running_average.py ADDED
@@ -0,0 +1,893 @@
1
+ import numpy as np
2
+ import torch as T
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torch.distributions import Categorical
6
+
7
+
8
+ class Agent:
9
+ def __init__(
10
+ self,
11
+ obs_space,
12
+ action_space,
13
+ hidden,
14
+ gamma,
15
+ clip_coef,
16
+ lr,
17
+ value_coef,
18
+ entropy_coef,
19
+ seed,
20
+ batch_size,
21
+ ppo_epochs,
22
+ lam,
23
+ update_type
24
+
25
+ ):
26
+ # Initialize seed for reproducibility
27
+ if seed is not None:
28
+ np.random.seed(seed)
29
+ T.manual_seed(seed)
30
+ """
31
+ # For flat observations (MLP model)
32
+ # Use GPU if available
33
+ self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
34
+ self.obs_dim = int(np.prod(getattr(obs_space, "shape", (obs_space,))))
35
+ self.action_dim = int(getattr(action_space, "n", action_space))
36
+
37
+ # Initialize the policy and the critic networks
38
+ self.policy = Policy(self.obs_dim, self.action_dim, hidden).to(self.device)
39
+ self.critic = Critic(self.obs_dim, hidden).to(self.device)
40
+ """
41
+ # Use GPU if available
42
+ self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
43
+ self.action_dim = int(getattr(action_space, "n", action_space))
44
+ self.update_type = update_type
45
+
46
+ # Initialize the policy and the critic networks
47
+ # Pass the shape tuple directly, not the flattened dimension.
48
+ self.policy = Policy(obs_space.shape, self.action_dim, hidden).to(self.device)
49
+ self.critic = Critic(obs_space.shape, hidden).to(self.device)
50
+ self.observeNorm = ObservationNorm()
51
+ self.advantageNorm = AdvantageNorm()
52
+ self.returnNorm = ReturnNorm()
53
+
54
+ # Set optimizer for policy and critic networks
55
+ self.opt = optim.Adam(
56
+ list(self.policy.parameters()) + list(self.critic.parameters()),
57
+ lr=lr
58
+ )
59
+
60
+ self.gamma = gamma
61
+ self.clip = clip_coef
62
+ self.value_coef = value_coef
63
+ self.entropy_coef = entropy_coef
64
+ self.sigma_history = []
65
+ self.loss_history = []
66
+ self.policy_loss_history = []
67
+ self.value_loss_history = []
68
+ self.entropy_history = []
69
+ self.lam = lam
70
+ self.ppo_epochs = ppo_epochs
71
+ self.batch_size = batch_size
72
+
73
+ self.memory = Memory()
74
+ """
75
+ # Choose action and remember for flat observations (MLP model)
76
+ def choose_action(self, observation):
77
+ # Returns: action, log probabilitiy, value of the state
78
+ state = T.as_tensor(observation, dtype=T.float32, device=self.device).view(-1)
79
+ with T.no_grad():
80
+ # Forward function (defined in Policy class)
81
+ dist = self.policy.next_action(state)
82
+ action = dist.sample()
83
+ logp = dist.log_prob(action)
84
+ value = self.critic.evaluated_state(state)
85
+ return int(action.item()), float(logp.item()), float(value.item())
86
+
87
+ def remember(self, state, action, reward, done, log_prob, value, next_state):
88
+ with T.no_grad():
89
+ # Pass on next state and have it evaluated by the critic network
90
+ ns = T.as_tensor(next_state, dtype=T.float32, device=self.device).view(-1)
91
+ next_value = self.critic.evaluated_state(ns).item()
92
+ self.memory.store(state, action, reward, done, log_prob, value, next_value)
93
+ """
94
+ # For CNN model
95
+ def choose_action(self, observation):
96
+ # Returns: action, log probability, value of the state
97
+ state = T.as_tensor(observation, dtype=T.float32, device=self.device) # Remove .view(-1)
98
+ with T.no_grad():
99
+ # Forward function (defined in Policy class)
100
+ dist = self.policy.next_action(state)
101
+ action = dist.sample()
102
+ logp = dist.log_prob(action)
103
+ value = self.critic.evaluated_state(state)
104
+ return int(action.item()), float(logp.item()), float(value.item())
105
+
106
+ def remember(self, state, action, reward, done, log_prob, value, next_state):
107
+ with T.no_grad():
108
+ # Pass on next state and have it evaluated by the critic network
109
+ ns = T.as_tensor(next_state, dtype=T.float32, device=self.device) # Remove .view(-1)
110
+ next_value = self.critic.evaluated_state(ns).item()
111
+ self.memory.store(state, action, reward, done, log_prob, value, next_value)
112
+
113
+
114
+ def _update(self):
115
+ if self.update_type == "update_observation_norm":
116
+ return self.update_observation_norm()
117
+ elif self.update_type == "update_advantage_norm":
118
+ return self.update_advantage_norm()
119
+ elif self.update_type == "update_return_norm":
120
+ return self.update_return_norm()
121
+ else:
122
+ return self.vanilla_ppo_update()
123
+
124
+ def vanilla_ppo_update(self):
125
+ if len(self.memory.states) == 0:
126
+ return 0.0
127
+
128
+ # Convert memory to tensors
129
+ states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
130
+ actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
131
+ rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
132
+ dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
133
+ old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
134
+ values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
135
+
136
+ with T.no_grad():
137
+ # Compute next values (bootstrap for final step)
138
+ next_values = T.cat([values[1:], values[-1:].clone()])
139
+ deltas = rewards + self.gamma * next_values * (1 - dones) - values
140
+
141
+ # --- GAE-Lambda ---
142
+ adv = T.zeros_like(rewards)
143
+ gae = 0.0
144
+ for t in reversed(range(len(rewards))):
145
+ gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
146
+ adv[t] = gae
147
+
148
+ returns = adv + values
149
+ # Advantage normalization
150
+ adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
151
+
152
+ # --- PPO Multiple Epochs + Minibatch ---
153
+ total_loss_epoch = 0.0
154
+ num_samples = len(states)
155
+ batch_size = min(64, num_samples)
156
+ ppo_epochs = 4
157
+
158
+ for _ in range(ppo_epochs):
159
+ # Shuffle indices
160
+ idxs = T.randperm(num_samples)
161
+ for start in range(0, num_samples, batch_size):
162
+ batch_idx = idxs[start:start + batch_size]
163
+
164
+ b_states = states[batch_idx]
165
+ b_actions = actions[batch_idx]
166
+ b_old_logp = old_logp[batch_idx]
167
+ b_returns = returns[batch_idx]
168
+ b_adv = adv[batch_idx]
169
+
170
+ dist = self.policy.next_action(b_states)
171
+ new_logp = dist.log_prob(b_actions)
172
+ entropy = dist.entropy().mean()
173
+ ratio = (new_logp - b_old_logp).exp()
174
+
175
+ # --- Clipped surrogate objective ---
176
+ surr1 = ratio * b_adv
177
+ surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
178
+ policy_loss = -T.min(surr1, surr2).mean()
179
+
180
+ # --- Critic loss ---
181
+ value_pred = self.critic.evaluated_state(b_states)
182
+ value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
183
+
184
+ # --- Total loss ---
185
+ total_loss = (
186
+ policy_loss +
187
+ self.value_coef * value_loss -
188
+ self.entropy_coef * entropy
189
+ )
190
+
191
+ # Debug: track individual loss components
192
+ self.policy_loss_history.append(policy_loss.item())
193
+ self.value_loss_history.append(value_loss.item())
194
+
195
+ self.opt.zero_grad(set_to_none=True)
196
+ total_loss.backward()
197
+ self.opt.step()
198
+
199
+ total_loss_epoch += total_loss.item()
200
+
201
+ # Clear memory after full PPO update
202
+ self.memory.clear()
203
+
204
+ return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
205
+
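A note on the GAE block shared by every update variant in this file: with the bootstrapped next value V(s_{t+1}), the TD error is \delta_t = r_t + \gamma (1 - d_t) V(s_{t+1}) - V(s_t), and the reversed loop implements the recursion \hat{A}_t = \delta_t + \gamma \lambda (1 - d_t) \hat{A}_{t+1}; returns are then recovered as R_t = \hat{A}_t + V(s_t).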
206
+
207
+ def update_rbs(self):
208
+ if len(self.memory.states) == 0:
209
+ return 0.0
210
+
211
+ # Convert memory to tensors
212
+ states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
213
+ actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
214
+ rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
215
+ dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
216
+ old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
217
+ values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
218
+
219
+ with T.no_grad():
220
+ # Compute next values (bootstrap for final step)
221
+ next_values = T.cat([values[1:], values[-1:].clone()])
222
+ deltas = rewards + self.gamma * next_values * (1 - dones) - values
223
+
224
+ # --- GAE-Lambda ---
225
+ adv = T.zeros_like(rewards)
226
+ gae = 0.0
227
+ for t in reversed(range(len(rewards))):
228
+ gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
229
+ adv[t] = gae
230
+
231
+ returns = adv + values
232
+
233
+ # --- Return-based normalization (RBS) ---
234
+ sigma_t = returns.std(unbiased=False) + 1e-8
235
+ returns = returns / sigma_t
236
+ self.sigma_history.append(sigma_t.item())
237
+ adv = adv / sigma_t
238
+ # Advantage normalization
239
+ adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
240
+
241
+ # --- PPO Multiple Epochs + Minibatch ---
242
+ total_loss_epoch = 0.0
243
+ num_samples = len(states)
244
+ batch_size = min(64, num_samples)
245
+ ppo_epochs = 4
246
+
247
+ for _ in range(ppo_epochs):
248
+ # Shuffle indices
249
+ idxs = T.randperm(num_samples)
250
+ for start in range(0, num_samples, batch_size):
251
+ batch_idx = idxs[start:start + batch_size]
252
+
253
+ b_states = states[batch_idx]
254
+ b_actions = actions[batch_idx]
255
+ b_old_logp = old_logp[batch_idx]
256
+ b_returns = returns[batch_idx]
257
+ b_adv = adv[batch_idx]
258
+
259
+ dist = self.policy.next_action(b_states)
260
+ new_logp = dist.log_prob(b_actions)
261
+ entropy = dist.entropy().mean()
262
+ ratio = (new_logp - b_old_logp).exp()
263
+
264
+ # --- Clipped surrogate objective ---
265
+ surr1 = ratio * b_adv
266
+ surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
267
+ policy_loss = -T.min(surr1, surr2).mean()
268
+
269
+ # --- Critic loss ---
270
+ value_pred = self.critic.evaluated_state(b_states)
271
+ value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
272
+
273
+ # --- Total loss ---
274
+ total_loss = (
275
+ policy_loss +
276
+ self.value_coef * value_loss -
277
+ self.entropy_coef * entropy
278
+ )
279
+
280
+ # Debug: track individual loss components
281
+ self.policy_loss_history.append(policy_loss.item())
282
+ self.value_loss_history.append(value_loss.item())
283
+
284
+ self.opt.zero_grad(set_to_none=True)
285
+ total_loss.backward()
286
+ self.opt.step()
287
+ total_loss_epoch += total_loss.item()
288
+
289
+ # Clear memory after full PPO update
290
+ self.memory.clear()
291
+
292
+ return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
293
+
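update_rbs differs from vanilla_ppo_update only in the return-based scaling step: with \sigma = \mathrm{std}(R) + 10^{-8} computed over the batch, it rescales R_t \leftarrow R_t / \sigma and \hat{A}_t \leftarrow \hat{A}_t / \sigma before the usual advantage z-scoring, and logs \sigma to sigma_history for the (currently commented-out) plot in the environment script.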
294
+
295
+
296
+
297
+
298
+
299
+ def update_observation_norm(self):
300
+ if len(self.memory.states) == 0:
301
+ return 0.0
302
+
303
+ # Convert memory to tensors
304
+ states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
305
+ actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
306
+ rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
307
+ dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
308
+ old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
309
+ values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
310
+
311
+ with T.no_grad():
312
+ # Compute next values (bootstrap for final step)
313
+ next_values = T.cat([values[1:], values[-1:].clone()])
314
+ deltas = rewards + self.gamma * next_values * (1 - dones) - values
315
+
316
+ # --- GAE-Lambda ---
317
+ adv = T.zeros_like(rewards)
318
+ gae = 0.0
319
+ for t in reversed(range(len(rewards))):
320
+ gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
321
+ adv[t] = gae
322
+
323
+ returns = adv + values
324
+
325
+ # --- observation normalization ---
326
+ self.observeNorm.update(states)
327
+ states = self.observeNorm.normalize(states)
328
+ # Advantage normalization
329
+ #adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
330
+
331
+ # --- PPO Multiple Epochs + Minibatch ---
332
+ total_loss_epoch = 0.0
333
+ num_samples = len(states)
334
+ batch_size = min(64, num_samples)
335
+ ppo_epochs = 4
336
+
337
+ for _ in range(ppo_epochs):
338
+ # Shuffle indices
339
+ idxs = T.randperm(num_samples)
340
+ for start in range(0, num_samples, batch_size):
341
+ batch_idx = idxs[start:start + batch_size]
342
+
343
+ b_states = states[batch_idx]
344
+ b_actions = actions[batch_idx]
345
+ b_old_logp = old_logp[batch_idx]
346
+ b_returns = returns[batch_idx]
347
+ b_adv = adv[batch_idx]
348
+
349
+ dist = self.policy.next_action(b_states)
350
+ new_logp = dist.log_prob(b_actions)
351
+ entropy = dist.entropy().mean()
352
+ ratio = (new_logp - b_old_logp).exp()
353
+
354
+ # --- Clipped surrogate objective ---
355
+ surr1 = ratio * b_adv
356
+ surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
357
+ policy_loss = -T.min(surr1, surr2).mean()
358
+
359
+ # --- Critic loss ---
360
+ value_pred = self.critic.evaluated_state(b_states)
361
+ value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
362
+
363
+ # --- Total loss ---
364
+ total_loss = (
365
+ policy_loss +
366
+ self.value_coef * value_loss -
367
+ self.entropy_coef * entropy
368
+ )
369
+
370
+ # Debug: track individual loss components
371
+ self.policy_loss_history.append(policy_loss.item())
372
+ self.value_loss_history.append(value_loss.item())
373
+
374
+ self.opt.zero_grad(set_to_none=True)
375
+ total_loss.backward()
376
+ self.opt.step()
377
+ total_loss_epoch += total_loss.item()
378
+
379
+ # Clear memory after full PPO update
380
+ self.memory.clear()
381
+
382
+ return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
383
+
384
+
385
+
386
+
387
+ def update_advantage_norm(self):
388
+ if len(self.memory.states) == 0:
389
+ return 0.0
390
+
391
+ # Convert memory to tensors
392
+ states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
393
+ actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
394
+ rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
395
+ dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
396
+ old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
397
+ values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
398
+
399
+ with T.no_grad():
400
+ # Compute next values (bootstrap for final step)
401
+ next_values = T.cat([values[1:], values[-1:].clone()])
402
+ deltas = rewards + self.gamma * next_values * (1 - dones) - values
403
+
404
+ # --- GAE-Lambda ---
405
+ adv = T.zeros_like(rewards)
406
+ gae = 0.0
407
+ for t in reversed(range(len(rewards))):
408
+ gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
409
+ adv[t] = gae
410
+
411
+
412
+
413
+ returns = adv + values
414
+
415
+ # --- Advantage normalization ---
416
+ self.advantageNorm.update(adv)
417
+ adv = self.advantageNorm.normalize(adv)
418
+
419
+
420
+
421
+ # --- PPO Multiple Epochs + Minibatch ---
422
+ total_loss_epoch = 0.0
423
+ num_samples = len(states)
424
+ batch_size = min(64, num_samples)
425
+ ppo_epochs = 4
426
+
427
+ for _ in range(ppo_epochs):
428
+ # Shuffle indices
429
+ idxs = T.randperm(num_samples)
430
+ for start in range(0, num_samples, batch_size):
431
+ batch_idx = idxs[start:start + batch_size]
432
+
433
+ b_states = states[batch_idx]
434
+ b_actions = actions[batch_idx]
435
+ b_old_logp = old_logp[batch_idx]
436
+ b_returns = returns[batch_idx]
437
+ b_adv = adv[batch_idx]
438
+
439
+ dist = self.policy.next_action(b_states)
440
+ new_logp = dist.log_prob(b_actions)
441
+ entropy = dist.entropy().mean()
442
+ ratio = (new_logp - b_old_logp).exp()
443
+
444
+ # --- Clipped surrogate objective ---
445
+ surr1 = ratio * b_adv
446
+ surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
447
+ policy_loss = -T.min(surr1, surr2).mean()
448
+
449
+ # --- Critic loss ---
450
+ value_pred = self.critic.evaluated_state(b_states)
451
+ value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
452
+
453
+ # --- Total loss ---
454
+ total_loss = (
455
+ policy_loss +
456
+ self.value_coef * value_loss -
457
+ self.entropy_coef * entropy
458
+ )
459
+
460
+ # Debug: track individual loss components
461
+ self.policy_loss_history.append(policy_loss.item())
462
+ self.value_loss_history.append(value_loss.item())
463
+
464
+ self.opt.zero_grad(set_to_none=True)
465
+ total_loss.backward()
466
+ self.opt.step()
467
+ total_loss_epoch += total_loss.item()
468
+
469
+ # Clear memory after full PPO update
470
+ self.memory.clear()
471
+
472
+ return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
473
+
474
+ def update_return_norm(self):
475
+ if len(self.memory.states) == 0:
476
+ return 0.0
477
+
478
+ # Convert memory to tensors
479
+ states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
480
+ actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
481
+ rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
482
+ dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
483
+ old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
484
+ values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
485
+
486
+ with T.no_grad():
487
+ # Compute next values (bootstrap for final step)
488
+ next_values = T.cat([values[1:], values[-1:].clone()])
489
+ deltas = rewards + self.gamma * next_values * (1 - dones) - values
490
+
491
+ # --- GAE-Lambda ---
492
+ adv = T.zeros_like(rewards)
493
+ gae = 0.0
494
+ for t in reversed(range(len(rewards))):
495
+ gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
496
+ adv[t] = gae
497
+
498
+
499
+
500
+ returns = adv + values
501
+
502
+ # --- returns normalization ---
503
+ self.returnNorm.update(returns)
504
+ returns = self.returnNorm.normalize(returns)
505
+
506
+
507
+ # Advantage normalization
508
+ #adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
509
+
510
+ # --- PPO Multiple Epochs + Minibatch ---
511
+ total_loss_epoch = 0.0
512
+ num_samples = len(states)
513
+ batch_size = min(64, num_samples)
514
+ ppo_epochs = 4
515
+
516
+ for _ in range(ppo_epochs):
517
+ # Shuffle indices
518
+ idxs = T.randperm(num_samples)
519
+ for start in range(0, num_samples, batch_size):
520
+ batch_idx = idxs[start:start + batch_size]
521
+
522
+ b_states = states[batch_idx]
523
+ b_actions = actions[batch_idx]
524
+ b_old_logp = old_logp[batch_idx]
525
+ b_returns = returns[batch_idx]
526
+ b_adv = adv[batch_idx]
527
+
528
+ dist = self.policy.next_action(b_states)
529
+ new_logp = dist.log_prob(b_actions)
530
+ entropy = dist.entropy().mean()
531
+ ratio = (new_logp - b_old_logp).exp()
532
+
533
+ # --- Clipped surrogate objective ---
534
+ surr1 = ratio * b_adv
535
+ surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
536
+ policy_loss = -T.min(surr1, surr2).mean()
537
+
538
+ # --- Critic loss ---
539
+ value_pred = self.critic.evaluated_state(b_states)
540
+ value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
541
+
542
+ # --- Total loss ---
543
+ total_loss = (
544
+ policy_loss +
545
+ self.value_coef * value_loss -
546
+ self.entropy_coef * entropy
547
+ )
548
+
549
+ # Debug: track individual loss components
550
+ self.policy_loss_history.append(policy_loss.item())
551
+ self.value_loss_history.append(value_loss.item())
552
+
553
+ self.opt.zero_grad(set_to_none=True)
554
+ total_loss.backward()
555
+ self.opt.step()
556
+ total_loss_epoch += total_loss.item()
557
+
558
+ # Clear memory after full PPO update
559
+ self.memory.clear()
560
+
561
+ return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
562
+
563
+ def update_reward_gradient_clipping(self):
564
+ if len(self.memory.states) == 0:
565
+ return 0.0
566
+
567
+ # Convert memory to tensors
568
+ states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
569
+ actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
570
+ rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
571
+ dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
572
+ old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
573
+ values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
574
+
575
+ # Reward clipping
576
+ rewards = T.clamp(rewards, -1, 1)
577
+
578
+ with T.no_grad():
579
+ # Compute next values (bootstrap for final step)
580
+ next_values = T.cat([values[1:], values[-1:].clone()])
581
+ deltas = rewards + self.gamma * next_values * (1 - dones) - values
582
+
583
+ # --- GAE-Lambda ---
584
+ adv = T.zeros_like(rewards)
585
+ gae = 0.0
586
+ for t in reversed(range(len(rewards))):
587
+ gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
588
+ adv[t] = gae
589
+
590
+ returns = adv + values
591
+ # Advantage normalization
592
+ #adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
593
+
594
+ # --- PPO Multiple Epochs + Minibatch ---
595
+ total_loss_epoch = 0.0
596
+ num_samples = len(states)
597
+ batch_size = min(64, num_samples)
598
+ ppo_epochs = 4
599
+
600
+ for _ in range(ppo_epochs):
601
+ # Shuffle indices
602
+ idxs = T.randperm(num_samples)
603
+ for start in range(0, num_samples, batch_size):
604
+ batch_idx = idxs[start:start + batch_size]
605
+
606
+ b_states = states[batch_idx]
607
+ b_actions = actions[batch_idx]
608
+ b_old_logp = old_logp[batch_idx]
609
+ b_returns = returns[batch_idx]
610
+ b_adv = adv[batch_idx]
611
+
612
+ dist = self.policy.next_action(b_states)
613
+ new_logp = dist.log_prob(b_actions)
614
+ entropy = dist.entropy().mean()
615
+ ratio = (new_logp - b_old_logp).exp()
616
+
617
+ # --- Clipped surrogate objective ---
618
+ surr1 = ratio * b_adv
619
+ surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
620
+ policy_loss = -T.min(surr1, surr2).mean()
621
+
622
+ # --- Critic loss ---
623
+ value_pred = self.critic.evaluated_state(b_states)
624
+ value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
625
+
626
+ # --- Total loss ---
627
+ total_loss = (
628
+ policy_loss +
629
+ self.value_coef * value_loss -
630
+ self.entropy_coef * entropy
631
+ )
632
+
633
+ # Debug: track individual loss components
634
+ self.policy_loss_history.append(policy_loss.item())
635
+ self.value_loss_history.append(value_loss.item())
636
+
637
+ self.opt.zero_grad(set_to_none=True)
638
+ total_loss.backward()
639
+ T.nn.utils.clip_grad_norm_(list(self.policy.parameters()) + list(self.critic.parameters()), 0.5)
640
+ self.opt.step()
641
+
642
+ total_loss_epoch += total_loss.item()
643
+
644
+ # Clear memory after full PPO update
645
+ self.memory.clear()
646
+
647
+ return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
648
+
649
+ """
650
+ # Policy network (simple MLP, flattened observations)
651
+ class Policy(nn.Module):
652
+ def __init__(self, obs_dim: int, action_dim: int, hidden: int):
653
+ super().__init__()
654
+ self.net = nn.Sequential(
655
+ nn.Linear(obs_dim, hidden),
656
+ nn.ReLU(),
657
+ nn.Linear(hidden, hidden),
658
+ nn.ReLU(),
659
+ nn.Linear(hidden, action_dim)
660
+ )
661
+
662
+ def next_action(self, state: T.Tensor) -> Categorical:
663
+ # Returns the probability distribution over actions
664
+ if state.dim() == 1:
665
+ state = state.unsqueeze(0)
666
+ state = state.view(state.size(0), -1)
667
+ return Categorical(logits=self.net(state))
668
+ """
669
+
670
+ # Policy network (CNN)
671
+ class Policy(nn.Module):
672
+ def __init__(self, obs_shape: tuple, action_dim: int, hidden: int):
673
+ super().__init__()
674
+ c, h, w = obs_shape
675
+ # Suggested architecture for Atari: https://arxiv.org/pdf/1312.5602
676
+ self.cnn = nn.Sequential(
677
+ nn.Conv2d(c, 16, kernel_size=8, stride=4),
678
+ nn.ReLU(),
679
+ nn.Conv2d(16, 32, kernel_size=4, stride=2),
680
+ nn.ReLU(),
681
+ nn.Flatten()
682
+ )
683
+
684
+ with T.no_grad():
685
+ cnn_output_dim = self.cnn(T.zeros(1, c, h, w)).shape[1]
686
+
687
+ self.net = nn.Sequential(
688
+ nn.Linear(cnn_output_dim, hidden),
689
+ nn.ReLU(),
690
+ nn.Linear(hidden, action_dim)
691
+ )
692
+
693
+ def next_action(self, state: T.Tensor) -> Categorical:
694
+ # Returns the probability distribution over actions
695
+ if state.dim() == 3:
696
+ state = state.unsqueeze(0)
697
+ cnn_out = self.cnn(state)
698
+ return Categorical(logits=self.net(cnn_out))
699
+
700
+ """
701
+ # Critic network (simple MLP, flattened observations)
702
+ class Critic(nn.Module):
703
+ def __init__(self, obs_dim: int, hidden: int):
704
+ super().__init__()
705
+ self.net = nn.Sequential(
706
+ nn.Linear(obs_dim, hidden),
707
+ nn.ReLU(),
708
+ nn.Linear(hidden, hidden),
709
+ nn.ReLU(),
710
+ nn.Linear(hidden, 1)
711
+ )
712
+
713
+ def evaluated_state(self, x: T.Tensor) -> T.Tensor:
714
+ if x.dim() == 1:
715
+ x = x.unsqueeze(0)
716
+ x = x.view(x.size(0), -1)
717
+ return self.net(x).squeeze(-1)
718
+ """
719
+
720
+ # Critic network (CNN)
721
+ class Critic(nn.Module):
722
+ def __init__(self, obs_shape: tuple, hidden: int):
723
+ super().__init__()
724
+ c, h, w = obs_shape
725
+ # Suggested architecture for Atari: https://arxiv.org/pdf/1312.5602
726
+ self.cnn = nn.Sequential(
727
+ nn.Conv2d(c, 16, kernel_size=8, stride=4),
728
+ nn.ReLU(),
729
+ nn.Conv2d(16, 32, kernel_size=4, stride=2),
730
+ nn.ReLU(),
731
+ nn.Flatten()
732
+ )
733
+
734
+ with T.no_grad():
735
+ cnn_output_dim = self.cnn(T.zeros(1, c, h, w)).shape[1]
736
+
737
+ self.net = nn.Sequential(
738
+ nn.Linear(cnn_output_dim, hidden),
739
+ nn.ReLU(),
740
+ nn.Linear(hidden, 1)
741
+ )
742
+
743
+ def evaluated_state(self, x: T.Tensor) -> T.Tensor:
744
+ if x.dim() == 3:
745
+ x = x.unsqueeze(0)
746
+ cnn_out = self.cnn(x)
747
+ return self.net(cnn_out).squeeze(-1)
748
+
749
+ class Memory():
750
+ def __init__(self):
751
+ self.states = []
752
+ self.actions = []
753
+ self.rewards = []
754
+ self.dones = []
755
+ self.log_probs = []
756
+ self.values = []
757
+ self.next_values = []
758
+
759
+ def store(self, state, action, reward, done, log_prob, value, next_value):
760
+ self.states.append(np.asarray(state, dtype=np.float32))
761
+ self.actions.append(int(action))
762
+ self.rewards.append(float(reward))
763
+ self.dones.append(float(done))
764
+ self.log_probs.append(float(log_prob))
765
+ self.values.append(float(value))
766
+ self.next_values.append(float(next_value))
767
+
768
+ """
769
+ # For mini-batch updates? To be implemented
770
+ def start_batch(self, batch_size: int):
771
+ n_states = len(self.states)
772
+ starts = np.arange(0, n_states, batch_size)
773
+ index = np.arange(n_states, dtype=np.int64)
774
+ np.random.shuffle(index)
775
+ return [index[s:s + batch_size] for s in starts]
776
+ """
777
+
778
+ def clear(self):
779
+ self.states = []
780
+ self.actions = []
781
+ self.rewards = []
782
+ self.dones = []
783
+ self.log_probs = []
784
+ self.values = []
785
+ self.next_values = []
786
+
787
+
788
+
789
+ class ObservationNorm:
790
+ def __init__(self):
791
+ self.main_mean = 0
792
+ self.main_var = 0
793
+ self.count = 1e-4
794
+
795
+ def update(self, x: T.Tensor):
796
+ batch_mean = T.mean(x, dim=0)
797
+ batch_var = T.var(x, dim=0)
798
+ batch_count = x.shape[0]
799
+ self._update_from_moments(batch_mean, batch_var, batch_count)
800
+
801
+ def _update_from_moments(self, batch_mean, batch_var, batch_count):
802
+ delta = batch_mean - self.main_mean
803
+ tot_count = self.count + batch_count
804
+ new_mean = self.main_mean + delta * batch_count / tot_count #Update the running mean
805
+ m_a = self.main_var * self.count
806
+ m_b = batch_var * batch_count
807
+ M2 = m_a + m_b + delta ** 2 * self.count * batch_count / tot_count  # parallel (Chan) variance combination; delta ** 2 avoids np.square, which fails on CUDA tensors
808
+ new_var = M2 / tot_count # update the running variance
809
+
810
+ self.main_mean, self.main_var, self.count = new_mean, new_var, tot_count
811
+
812
+ def normalize(self, x):
813
+
814
+ return (x - self.main_mean) / (self.main_var ** 0.5 + 1e-8)  # Epsilon guards against division by zero; ** 0.5 stays on-device (np.sqrt fails on CUDA tensors).
815
+
816
+
817
+
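_update_from_moments is the standard parallel ("Chan") combination of a running mean/variance with a batch's moments. A small consistency check against the moments of the concatenated data, assuming ObservationNorm is importable from this module (sketch only):

import torch as T

rms = ObservationNorm()
chunks = [T.randn(50, 4), T.randn(30, 4) + 2.0]
for c in chunks:
    rms.update(c)

full = T.cat(chunks, dim=0)
# Running estimates should land close to the full-data moments; the small gap comes from
# T.var's unbiased default inside update() and the 1e-4 initial count.
print(rms.main_mean, full.mean(dim=0))
print(rms.main_var, full.var(dim=0, unbiased=False))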
818
+
819
+
820
+
821
+ class AdvantageNorm:
822
+ '''
823
+ This class implements advantage normalization. In this file's variant, a running mean and
824
+ variance are maintained across batches and used to normalize each new batch.
825
+
826
+ '''
827
+ def __init__(self):
828
+ self.main_mean = 0
829
+ self.main_var = 0
830
+ self.count = 1e-4
831
+
832
+ def update(self, x: T.Tensor):
833
+ batch_mean = T.mean(x, dim=0)
834
+ batch_var = T.var(x, dim=0)
835
+ batch_count = x.shape[0]
836
+ self._update_from_moments(batch_mean, batch_var, batch_count)
837
+
838
+ def _update_from_moments(self, batch_mean, batch_var, batch_count):
839
+ delta = batch_mean - self.main_mean
840
+ tot_count = self.count + batch_count
841
+ new_mean = self.main_mean + delta * batch_count / tot_count #Update the running mean
842
+ m_a = self.main_var * self.count
843
+ m_b = batch_var * batch_count
844
+ M2 = m_a + m_b + delta ** 2 * self.count * batch_count / tot_count  # parallel (Chan) variance combination; delta ** 2 avoids np.square, which fails on CUDA tensors
845
+ new_var = M2 / tot_count # update the running variance
846
+
847
+ self.main_mean, self.main_var, self.count = new_mean, new_var, tot_count
848
+
849
+ def normalize(self, x):
850
+
851
+ return (x - self.main_mean) / (self.main_var ** 0.5 + 1e-8)  # Epsilon guards against division by zero; ** 0.5 stays on-device (np.sqrt fails on CUDA tensors).
852
+
853
+
854
+
855
+
856
+
857
+ class ReturnNorm:
858
+ '''
859
+ This class implements return normalization. In this file's variant, a running mean and
860
+ variance are maintained across batches and used to normalize each new batch.
861
+
862
+ '''
863
+ def __init__(self):
864
+ self.main_mean = 0
865
+ self.main_var = 0
866
+ self.count = 1e-4
867
+
868
+ def update(self, x: T.Tensor):
869
+ batch_mean = T.mean(x, dim=0)
870
+ batch_var = T.var(x, dim=0)
871
+ batch_count = x.shape[0]
872
+ self._update_from_moments(batch_mean, batch_var, batch_count)
873
+
874
+ def _update_from_moments(self, batch_mean, batch_var, batch_count):
875
+ delta = batch_mean - self.main_mean
876
+ tot_count = self.count + batch_count
877
+ new_mean = self.main_mean + delta * batch_count / tot_count #Update the running mean
878
+ m_a = self.main_var * self.count
879
+ m_b = batch_var * batch_count
880
+ M2 = m_a + m_b + delta ** 2 * self.count * batch_count / tot_count  # parallel (Chan) variance combination; delta ** 2 avoids np.square, which fails on CUDA tensors
881
+ new_var = M2 / tot_count # update the running variance
882
+
883
+ self.main_mean, self.main_var, self.count = new_mean, new_var, tot_count
884
+
885
+ def normalize(self, x):
886
+
887
+ return (x - self.main_mean) / (self.main_var ** 0.5 + 1e-8)  # Epsilon guards against division by zero; ** 0.5 stays on-device (np.sqrt fails on CUDA tensors).
887
+
889
+
890
+
891
+
892
+
893
+
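ObservationNorm, AdvantageNorm and ReturnNorm above share identical running-statistics code; one possible later refactor (a sketch only, not part of this commit) is a single RunningMeanStd helper that all three reuse:

import torch as T

class RunningMeanStd:
    # Running mean/variance via the same parallel-moments update used by the three classes above.
    def __init__(self, eps: float = 1e-4):
        self.mean, self.var, self.count = 0.0, 0.0, eps

    def update(self, x: T.Tensor):
        b_mean = x.mean(dim=0)
        b_var = x.var(dim=0, unbiased=False)
        b_n = x.shape[0]
        delta = b_mean - self.mean
        tot = self.count + b_n
        self.mean = self.mean + delta * b_n / tot
        m2 = self.var * self.count + b_var * b_n + delta ** 2 * self.count * b_n / tot
        self.var, self.count = m2 / tot, tot

    def normalize(self, x: T.Tensor) -> T.Tensor:
        return (x - self.mean) / (self.var ** 0.5 + 1e-8)

# ObservationNorm / AdvantageNorm / ReturnNorm could then be thin aliases or subclasses of this helper.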
Observation_Advantage_Norm_running_averages/ppo_rew_norm_obs_env_running_average.py ADDED
@@ -0,0 +1,163 @@
1
+
2
+ import gymnasium as gym
3
+ import sys
4
+ import matplotlib.pyplot as plt
5
+ import ale_py
6
+ from ppo__rew_norm_obs_running_average import *
7
+ from gymnasium.spaces import Box
8
+ import cv2
9
+
10
+ def preprocess(obs):
11
+ # Convert to grayscale
12
+ obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
13
+ # Resize
14
+ obs = cv2.resize(obs, (84, 84), interpolation=cv2.INTER_AREA)
15
+ # Add channel dimension and normalize
16
+ return np.expand_dims(obs, axis=0).astype(np.float32) / 255.0
17
+
18
+
19
+ def rl_model(type):
20
+ # env = gym.make("ALE/SpaceInvaders-v5", render_mode='human')
21
+ # env = gym.make("ALE/Pacman-v5", render_mode="human")
22
+ env = gym.make("ALE/Pacman-v5")
23
+
24
+ episode = 0
25
+ total_return = 0
26
+ ep_return = 0
27
+ steps = 1000
28
+ batches = 100
29
+
30
+ print("Observation space:", env.observation_space)
31
+ print("Action space:", env.action_space)
32
+ """
33
+ agent = Agent(obs_space=env.observation_space, action_space=env.action_space,
34
+ hidden=64, lr=3e-4, gamma=0.99, clip_coef=0.2,
35
+ entropy_coef=0.01, value_coef=0.5, seed=70,
36
+ batch_size = 64, ppo_epochs = 4, lam = 0.95)
37
+
38
+ """
39
+ # Initialize CNN with a dummy observation (to get correct input shape)
40
+ obs, _ = env.reset()
41
+ dummy_obs_space = Box(low=0.0, high=1.0, shape=preprocess(obs).shape)
42
+ update_type = type
43
+ agent = Agent(obs_space=dummy_obs_space, action_space=env.action_space,
44
+ hidden=64, lr=3e-4, gamma=0.99, clip_coef=0.2,
45
+ entropy_coef=0.01, value_coef=0.5, seed=70,
46
+ batch_size=64, ppo_epochs=4, lam=0.95, update_type=update_type)
47
+ """
48
+ # Stats for Return-Based Scaling only
49
+ # === Return-Based Scaling stats ===
50
+ r_mean, r_var = 0.0, 1e-8
51
+ g2_mean = 1.0
52
+
53
+ agent.r_var = r_var
54
+ agent.g2_mean = g2_mean
55
+ """
56
+
57
+ try:
58
+ obs, info = env.reset(seed=42)
59
+ state = preprocess(obs)
60
+
61
+ loss_history = []
62
+ reward_history = []
63
+
64
+ for update in range(1, batches + 1):
65
+ for t in range(steps):
66
+ action, logp, value = agent.choose_action(state)
67
+ next_obs, reward, terminated, truncated, info = env.step(action)
68
+ done = terminated or truncated
69
+ next_state = preprocess(next_obs)
70
+
71
+ agent.remember(state, action, reward, done, logp, value, next_state)
72
+
73
+ ep_return += reward
74
+ state = next_state
75
+
76
+ if done:
77
+ episode += 1
78
+ total_return += ep_return
79
+ print(f"Episode {episode} return: {ep_return:.2f}")
80
+ ep_return = 0
81
+ obs, info = env.reset()
82
+ state = preprocess(obs)
83
+
84
+ # Run the PPO update variant selected by update_type
85
+ avg_loss = agent._update()
86
+
87
+ # Vanilla PPO (no normalization)
88
+ # avg_loss = agent.vanilla_ppo_update()
89
+ loss_history.append(avg_loss)
90
+
91
+ avg_ret = (total_return / episode) if episode else 0
92
+ reward_history.append(avg_ret)
93
+ print(f"Update {update}: episodes={episode}, avg_return={avg_ret:.2f}, avg_loss={avg_loss:.4f}")
94
+
95
+ fig = plt.figure(figsize=(12, 8))
96
+
97
+ """
98
+ # Plot for Return-Based Scaling only
99
+ ax1 = plt.subplot(220)
100
+ ax1.plot(agent.sigma_history, label="Return σ")
101
+ ax1.set_xlabel("PPO Update")
102
+ ax1.set_ylabel("σ (Return Std)")
103
+ """
104
+
105
+ ax2 = plt.subplot(221)
106
+ ax2.plot(loss_history, label="Avg Loss")
107
+ ax2.set_ylabel("Average PPO Loss")
108
+ ax2.set_xlabel("PPO Update")
109
+
110
+ ax3 = plt.subplot(222)
111
+ ax3.plot(reward_history, label="Reward")
112
+ ax3.set_ylabel("Reward")
113
+ ax3.set_xlabel("PPO Update")
114
+
115
+ # Details about value loss and policy loss
116
+ ax4 = plt.subplot(223)
117
+ ax4.plot(agent.policy_loss_history, label="Policy Loss", alpha=0.7)
118
+ ax4.set_ylabel("Policy Loss")
119
+ ax4.set_xlabel("Training Step")
120
+ ax4.legend()
121
+
122
+ ax5 = plt.subplot(224)
123
+ ax5.plot(agent.value_loss_history, label="Value Loss", alpha=0.7)
124
+ ax5.set_ylabel("Value Loss")
125
+ ax5.set_xlabel("Training Step")
126
+ ax5.legend()
127
+
128
+ fig.suptitle("PPO Training Stability of type " + update_type +
129
+ "-running_average")
130
+ fig.tight_layout()
131
+ plt.savefig(type +"_running_average_.png")
132
+
133
+
134
+
135
+
136
+ except Exception as e:
137
+ print(f"Error: {e}", file=sys.stderr)
138
+ return 1
139
+ finally:
140
+ avg = total_return / episode if episode else 0
141
+ print(f"\nEpisodes: {episode}, Avg return: {avg:.3f}")
142
+ env.close()
143
+
144
+ return 0
145
+
146
+
147
+
148
+
149
+
150
+
151
+
152
+
153
+ def main() -> int:
154
+ type_list = ["update_observation_norm", "update_advantage_norm", "update_return_norm", "vanilla_ppo_update"]
155
+
156
+ for type in type_list:
157
+ rl_model(type)
158
+
159
+ return 0
160
+
161
+
162
+ if __name__ == "__main__":
163
+ raise SystemExit(main())
Observation_Advantage_Norm_running_averages/update_advantage_norm_running_average_.png ADDED
Observation_Advantage_Norm_running_averages/update_observation_norm_running_average_.png ADDED
Observation_Advantage_Norm_running_averages/update_return_norm_running_average_.png ADDED
Observation_Advantage_Norm_running_averages/vanilla_ppo_update_running_average_.png ADDED