rl-project-7Oct committed on
Commit 9763567 · verified · 1 Parent(s): bcb0c1c

Added vanilla_ppo_update (base case w/o fancy normalizations)

Files changed (1)
  1. CNN_PPO/ppo_helpers_cnn.py +96 -11
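A rough usage sketch (not part of the committed file): how the new base-case update might be selected instead of the RBS variant. `make_env`, `collect_rollout`, the iteration count, and the `Agent` constructor arguments are placeholders; only `vanilla_ppo_update()` and `update_rbs()` come from ppo_helpers_cnn.py.

USE_VANILLA = True                      # toggle between the two update paths

env = make_env()                        # placeholder: preprocessed Atari env with stacked frames
agent = Agent(...)                      # placeholder: constructor arguments are not shown in this diff

for it in range(200):
    collect_rollout(env, agent)         # placeholder: fills agent.memory via memory.store(...)
    if USE_VANILLA:
        loss = agent.vanilla_ppo_update()   # base case: GAE + advantage normalization only
    else:
        loss = agent.update_rbs()           # additionally scales returns/advantages by sigma_t
    print(f"iteration {it}: mean minibatch loss {loss:.4f}")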
CNN_PPO/ppo_helpers_cnn.py CHANGED
@@ -106,6 +106,89 @@ class Agent:
         next_value = self.critic.evaluated_state(ns).item()
         self.memory.store(state, action, reward, done, log_prob, value, next_value)
 
+    def vanilla_ppo_update(self):
+        if len(self.memory.states) == 0:
+            return 0.0
+
+        # Convert memory to tensors
+        states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
+        actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
+        rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
+        dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
+        old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
+        values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
+
+        with T.no_grad():
+            # Compute next values (bootstrap for final step)
+            next_values = T.cat([values[1:], values[-1:].clone()])
+            deltas = rewards + self.gamma * next_values * (1 - dones) - values
+
+            # --- GAE-Lambda ---
+            adv = T.zeros_like(rewards)
+            gae = 0.0
+            for t in reversed(range(len(rewards))):
+                gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
+                adv[t] = gae
+
+            returns = adv + values
+            # Advantage normalization
+            adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
+
+        # --- PPO Multiple Epochs + Minibatch ---
+        total_loss_epoch = 0.0
+        num_samples = len(states)
+        batch_size = min(64, num_samples)
+        ppo_epochs = 4
+
+        for _ in range(ppo_epochs):
+            # Shuffle indices
+            idxs = T.randperm(num_samples)
+            for start in range(0, num_samples, batch_size):
+                batch_idx = idxs[start:start + batch_size]
+
+                b_states = states[batch_idx]
+                b_actions = actions[batch_idx]
+                b_old_logp = old_logp[batch_idx]
+                b_returns = returns[batch_idx]
+                b_adv = adv[batch_idx]
+
+                dist = self.policy.next_action(b_states)
+                new_logp = dist.log_prob(b_actions)
+                entropy = dist.entropy().mean()
+                ratio = (new_logp - b_old_logp).exp()
+
+                # --- Clipped surrogate objective ---
+                surr1 = ratio * b_adv
+                surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
+                policy_loss = -T.min(surr1, surr2).mean()
+
+                # --- Critic loss ---
+                value_pred = self.critic.evaluated_state(b_states)
+                value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
+
+                # --- Total loss ---
+                total_loss = (
+                    policy_loss +
+                    self.value_coef * value_loss -
+                    self.entropy_coef * entropy
+                )
+
+                # Debug: track individual loss components
+                self.policy_loss_history.append(policy_loss.item())
+                self.value_loss_history.append(value_loss.item())
+
+                self.opt.zero_grad(set_to_none=True)
+                total_loss.backward()
+                self.opt.step()
+
+                total_loss_epoch += total_loss.item()
+
+        # Clear memory after full PPO update
+        self.memory.clear()
+
+        return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
+
+
     def update_rbs(self):
         if len(self.memory.states) == 0:
             return 0.0
@@ -135,9 +218,10 @@
         # --- Return-based normalization (RBS) ---
         sigma_t = returns.std(unbiased=False) + 1e-8
         returns = returns / sigma_t
+        self.sigma_history.append(sigma_t.item())
         adv = adv / sigma_t
+        # Advantage normalization
         adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
-        self.sigma_history.append(sigma_t.item())
 
         # --- PPO Multiple Epochs + Minibatch ---
         total_loss_epoch = 0.0
@@ -178,6 +262,10 @@
                     self.entropy_coef * entropy
                 )
 
+                # Debug: track individual loss components
+                self.policy_loss_history.append(policy_loss.item())
+                self.value_loss_history.append(value_loss.item())
+
                 self.opt.zero_grad(set_to_none=True)
                 total_loss.backward()
                 self.opt.step()
@@ -216,13 +304,8 @@
             adv[t] = gae
 
         returns = adv + values
-
-        # --- Return-based normalization (RBS) ---
-        sigma_t = returns.std(unbiased=False) + 1e-8
-        returns = returns / sigma_t
-        adv = adv / sigma_t
+        # Advantage normalization
         adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
-        self.sigma_history.append(sigma_t.item())
 
         # --- PPO Multiple Epochs + Minibatch ---
         total_loss_epoch = 0.0
@@ -262,6 +345,10 @@
                     self.value_coef * value_loss -
                     self.entropy_coef * entropy
                 )
+
+                # Debug: track individual loss components
+                self.policy_loss_history.append(policy_loss.item())
+                self.value_loss_history.append(value_loss.item())
 
                 self.opt.zero_grad(set_to_none=True)
                 total_loss.backward()
@@ -353,11 +440,9 @@ class Critic(nn.Module):
         c, h, w = obs_shape
         # Suggested architecture for Atari: https://arxiv.org/pdf/1312.5602
         self.cnn = nn.Sequential(
-            nn.Conv2d(c, 32, kernel_size=8, stride=4),
-            nn.ReLU(),
-            nn.Conv2d(32, 64, kernel_size=4, stride=2),
+            nn.Conv2d(c, 16, kernel_size=8, stride=4),
             nn.ReLU(),
-            nn.Conv2d(64, 64, kernel_size=3, stride=1),
+            nn.Conv2d(16, 32, kernel_size=4, stride=2),
             nn.ReLU(),
             nn.Flatten()
         )
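For reference, the slimmed-down critic CNN (16 then 32 filters) is the two-layer setup described in the cited arXiv:1312.5602 paper. Below is a quick standalone check of the flattened feature size it feeds into the value head, assuming 84×84 preprocessed frames with 4 stacked channels (an assumption; the preprocessing is not shown in this diff).

import torch as T
import torch.nn as nn

# Two-layer stack from this commit; 84x84 inputs shrink as
# (84 - 8) // 4 + 1 = 20, then (20 - 4) // 2 + 1 = 9, so 32 * 9 * 9 = 2592 features.
cnn = nn.Sequential(
    nn.Conv2d(4, 16, kernel_size=8, stride=4),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=4, stride=2),
    nn.ReLU(),
    nn.Flatten(),
)
with T.no_grad():
    print(cnn(T.zeros(1, 4, 84, 84)).shape[1])  # 2592 (the old 32/64/64 stack gave 3136)

If the linear layer after the Flatten has a hard-coded input size rather than one inferred from a dummy forward pass, it would need to drop from 3136 to 2592 to match.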