Added vanilla_ppo_update (base case w/o fancy normalizations)

CNN_PPO/ppo_helpers_cnn.py  (+96 -11)  CHANGED
@@ -106,6 +106,89 @@ class Agent:
         next_value = self.critic.evaluated_state(ns).item()
         self.memory.store(state, action, reward, done, log_prob, value, next_value)
 
+    def vanilla_ppo_update(self):
+        if len(self.memory.states) == 0:
+            return 0.0
+
+        # Convert memory to tensors
+        states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
+        actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
+        rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
+        dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
+        old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
+        values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
+
+        with T.no_grad():
+            # Compute next values (bootstrap for final step)
+            next_values = T.cat([values[1:], values[-1:].clone()])
+            deltas = rewards + self.gamma * next_values * (1 - dones) - values
+
+            # --- GAE-Lambda ---
+            adv = T.zeros_like(rewards)
+            gae = 0.0
+            for t in reversed(range(len(rewards))):
+                gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
+                adv[t] = gae
+
+            returns = adv + values
+            # Advantage normalization
+            adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
+
+        # --- PPO Multiple Epochs + Minibatch ---
+        total_loss_epoch = 0.0
+        num_samples = len(states)
+        batch_size = min(64, num_samples)
+        ppo_epochs = 4
+
+        for _ in range(ppo_epochs):
+            # Shuffle indices
+            idxs = T.randperm(num_samples)
+            for start in range(0, num_samples, batch_size):
+                batch_idx = idxs[start:start + batch_size]
+
+                b_states = states[batch_idx]
+                b_actions = actions[batch_idx]
+                b_old_logp = old_logp[batch_idx]
+                b_returns = returns[batch_idx]
+                b_adv = adv[batch_idx]
+
+                dist = self.policy.next_action(b_states)
+                new_logp = dist.log_prob(b_actions)
+                entropy = dist.entropy().mean()
+                ratio = (new_logp - b_old_logp).exp()
+
+                # --- Clipped surrogate objective ---
+                surr1 = ratio * b_adv
+                surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
+                policy_loss = -T.min(surr1, surr2).mean()
+
+                # --- Critic loss ---
+                value_pred = self.critic.evaluated_state(b_states)
+                value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
+
+                # --- Total loss ---
+                total_loss = (
+                    policy_loss +
+                    self.value_coef * value_loss -
+                    self.entropy_coef * entropy
+                )
+
+                # Debug: track individual loss components
+                self.policy_loss_history.append(policy_loss.item())
+                self.value_loss_history.append(value_loss.item())
+
+                self.opt.zero_grad(set_to_none=True)
+                total_loss.backward()
+                self.opt.step()
+
+                total_loss_epoch += total_loss.item()
+
+        # Clear memory after full PPO update
+        self.memory.clear()
+
+        return total_loss_epoch / (ppo_epochs * (num_samples / batch_size))
+
+
     def update_rbs(self):
         if len(self.memory.states) == 0:
             return 0.0
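Note: for context, here is a minimal sketch of how the new base-case update could be driven from a rollout loop. Only `vanilla_ppo_update` comes from this commit; the env id, rollout length, and the `choose_action`/`remember` helper names are illustrative assumptions about the surrounding Agent API, which this diff does not show.

```python
# Hypothetical driver loop -- every name except vanilla_ppo_update is assumed.
import gymnasium as gym

env = gym.make("ALE/Breakout-v5")   # assumed preprocessed Atari env
agent = Agent(env.observation_space.shape, env.action_space.n)  # assumed ctor

obs, _ = env.reset()
for step in range(1, 100_001):
    action = agent.choose_action(obs)   # assumed: samples action, records log_prob/value
    next_obs, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated
    agent.remember(obs, action, reward, done, next_obs)  # assumed: fills agent.memory
    obs = env.reset()[0] if done else next_obs

    if step % 2048 == 0:    # fixed-size rollout, then one full PPO update
        avg_loss = agent.vanilla_ppo_update()
```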
@@ -135,9 +218,10 @@ class Agent:
             # --- Return-based normalization (RBS) ---
             sigma_t = returns.std(unbiased=False) + 1e-8
             returns = returns / sigma_t
+            self.sigma_history.append(sigma_t.item())
             adv = adv / sigma_t
+            # Advantage normalization
             adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
-            self.sigma_history.append(sigma_t.item())
 
         # --- PPO Multiple Epochs + Minibatch ---
         total_loss_epoch = 0.0
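Note: after this commit the two update paths differ only in how they normalize. `update_rbs` keeps return-based scaling (dividing both returns and advantages by the biased std of the returns, logged via `sigma_history`), while `vanilla_ppo_update` standardizes advantages only. A side-by-side sketch; the helper names are my own, the arithmetic is copied from the hunks:

```python
import torch as T

def rbs_normalize(returns: T.Tensor, adv: T.Tensor, eps: float = 1e-8):
    # Return-based scaling: one sigma_t rescales both critic targets and advantages.
    sigma_t = returns.std(unbiased=False) + eps
    returns = returns / sigma_t
    adv = adv / sigma_t
    adv = (adv - adv.mean()) / (adv.std(unbiased=False) + eps)
    return returns, adv, sigma_t

def vanilla_normalize(adv: T.Tensor, eps: float = 1e-8):
    # Base case: standardize advantages; returns keep their natural scale.
    return (adv - adv.mean()) / (adv.std(unbiased=False) + eps)
```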
@@ -178,6 +262,10 @@ class Agent:
                     self.entropy_coef * entropy
                 )
 
+                # Debug: track individual loss components
+                self.policy_loss_history.append(policy_loss.item())
+                self.value_loss_history.append(value_loss.item())
+
                 self.opt.zero_grad(set_to_none=True)
                 total_loss.backward()
                 self.opt.step()
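Note: the `policy_loss_history` and `value_loss_history` attributes appended to here are initialized outside this diff, presumably as plain Python lists. A quick, assumed way to inspect them after training:

```python
# Hypothetical debugging helper; assumes the history attributes are plain lists
# and that matplotlib is installed.
import matplotlib.pyplot as plt

def plot_loss_components(agent):
    fig, ax = plt.subplots()
    ax.plot(agent.policy_loss_history, label="policy loss")
    ax.plot(agent.value_loss_history, label="value loss")
    ax.set_xlabel("minibatch update")
    ax.set_ylabel("loss")
    ax.legend()
    fig.savefig("ppo_loss_components.png")
```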
@@ -216,13 +304,8 @@ class Agent:
                 adv[t] = gae
 
             returns = adv + values
-
-            # --- Return-based normalization (RBS) ---
-            sigma_t = returns.std(unbiased=False) + 1e-8
-            returns = returns / sigma_t
-            adv = adv / sigma_t
+            # Advantage normalization
             adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
-            self.sigma_history.append(sigma_t.item())
 
         # --- PPO Multiple Epochs + Minibatch ---
         total_loss_epoch = 0.0
@@ -262,6 +345,10 @@ class Agent:
                     self.value_coef * value_loss -
                     self.entropy_coef * entropy
                 )
+
+                # Debug: track individual loss components
+                self.policy_loss_history.append(policy_loss.item())
+                self.value_loss_history.append(value_loss.item())
 
                 self.opt.zero_grad(set_to_none=True)
                 total_loss.backward()
@@ -353,11 +440,9 @@ class Critic(nn.Module):
         c, h, w = obs_shape
         # Suggested architecture for Atari: https://arxiv.org/pdf/1312.5602
         self.cnn = nn.Sequential(
-            nn.Conv2d(c, 32, kernel_size=8, stride=4),
-            nn.ReLU(),
-            nn.Conv2d(32, 64, kernel_size=4, stride=2),
+            nn.Conv2d(c, 16, kernel_size=8, stride=4),
             nn.ReLU(),
-            nn.Conv2d(64, 64, kernel_size=3, stride=1),
+            nn.Conv2d(16, 32, kernel_size=4, stride=2),
             nn.ReLU(),
             nn.Flatten()
         )
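Note: the slimmed-down head matches the two-conv configuration of the cited paper (arXiv:1312.5602). For standard 84x84 Atari frames (an assumption; the frame size is not pinned down in this diff), the usual conv arithmetic gives (84-8)/4+1 = 20 after the first layer and (20-4)/2+1 = 9 after the second, so `Flatten` emits 32*9*9 = 2592 features:

```python
import torch as T
import torch.nn as nn

# Sanity check of the new critic head's output width; 4 stacked 84x84 frames
# are assumed (standard Atari preprocessing), not specified by this diff.
cnn = nn.Sequential(
    nn.Conv2d(4, 16, kernel_size=8, stride=4),   # (84-8)/4+1 = 20 -> 16x20x20
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=4, stride=2),  # (20-4)/2+1 = 9  -> 32x9x9
    nn.ReLU(),
    nn.Flatten(),
)
x = T.zeros(1, 4, 84, 84)
print(cnn(x).shape)   # torch.Size([1, 2592])
```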