# RL_Project20 / ppo_helpers_v2.py

import numpy as np
import torch as T
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
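
# Module overview (added note): an Agent that collects transitions, computes
# GAE(lambda) advantages with return-based normalization, and performs
# clipped-PPO minibatch updates; a categorical Policy network; a Critic value
# network; and a simple Memory buffer.
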
class Agent:
    def __init__(
        self,
        obs_space,
        action_space,
        hidden,
        gamma,
        clip_coef,
        lr,
        value_coef,
        entropy_coef,
        seed,
        batch_size,
        ppo_epochs,
        lam
    ):
        # Seed NumPy and PyTorch for reproducibility
        if seed is not None:
            np.random.seed(seed)
            T.manual_seed(seed)

        # Use GPU if available
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')

        # Flatten the observation shape and read the discrete action count
        self.obs_dim = int(np.prod(getattr(obs_space, "shape", (obs_space,))))
        self.action_dim = int(getattr(action_space, "n", action_space))

        # Initialize the policy and critic networks
        self.policy = Policy(self.obs_dim, self.action_dim, hidden).to(self.device)
        self.critic = Critic(self.obs_dim, hidden).to(self.device)

        # One optimizer over both the policy and critic parameters
        self.opt = optim.Adam(
            list(self.policy.parameters()) + list(self.critic.parameters()),
            lr=lr
        )

        self.gamma = gamma
        self.clip = clip_coef
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.sigma_history = []
        self.loss_history = []
        self.policy_loss_history = []
        self.value_loss_history = []
        self.entropy_history = []
        self.lam = lam
        self.ppo_epochs = ppo_epochs
        self.batch_size = batch_size
        self.memory = Memory()

    def choose_action(self, observation):
        # Returns: action, log probability, value of the state
        state = T.as_tensor(observation, dtype=T.float32, device=self.device).view(-1)
        with T.no_grad():
            # next_action (defined in the Policy class) returns a Categorical over actions
            dist = self.policy.next_action(state)
            action = dist.sample()
            logp = dist.log_prob(action)
            value = self.critic.evaluated_state(state)
        return int(action.item()), float(logp.item()), float(value.item())

    def remember(self, state, action, reward, done, log_prob, value, next_state):
        with T.no_grad():
            # Pass on the next state and have it evaluated by the critic network
            ns = T.as_tensor(next_state, dtype=T.float32, device=self.device).view(-1)
            next_value = self.critic.evaluated_state(ns).item()
        self.memory.store(state, action, reward, done, log_prob, value, next_value)
"""
def run_episode(self, env, max_steps: int, render: bool = False):
# Runs one episode, updates the policy once at the end
self.memory.clear()
out = env.reset()
state = out[0] if isinstance(out, tuple) else out
ep_return, ep_len = 0, 0
steps_limit = max_steps if max_steps is not None else float("inf")
while ep_len < steps_limit:
if render and hasattr(env, "render"):
env.render()
action, logp, value = self.choose_action(state)
step_out = env.step(action)
if len(step_out) == 5:
next_state, reward, terminated, truncated, _ = step_out
done = terminated or truncated
else:
next_state, reward, done, _ = step_out
self.remember(state, action, reward, done, logp, value, next_state)
ep_return += float(reward)
ep_len += 1
state = next_state
if done:
break
self._update()
return ep_return, ep_len
def run_episodes(self, env, n_episodes: int, max_steps: int, render: bool = False):
returns = []
for _ in range(n_episodes):
ep_ret, _ = self.run_episode(env, max_steps=max_steps, render=render)
returns.append(ep_ret)
return returns
"""

    def update_rbs(self):
        if len(self.memory.states) == 0:
            return 0.0

        # Convert memory to tensors
        states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
        actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
        rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
        dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
        old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
        values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
        # Critic values of the stored next states (bootstrap for the final step)
        next_values = T.as_tensor(self.memory.next_values, dtype=T.float32, device=self.device)

        with T.no_grad():
            # --- GAE-Lambda ---
            # delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
            deltas = rewards + self.gamma * next_values * (1 - dones) - values
            adv = T.zeros_like(rewards)
            gae = 0.0
            for t in reversed(range(len(rewards))):
                gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
                adv[t] = gae
            returns = adv + values

            # --- Return-based normalization (RBS) ---
            sigma_t = returns.std(unbiased=False) + 1e-8
            returns = returns / sigma_t
            adv = adv / sigma_t
            adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)

        self.sigma_history.append(sigma_t.item())

        # --- PPO: multiple epochs over shuffled minibatches ---
        total_loss_epoch = 0.0
        num_updates = 0
        num_samples = len(states)
        batch_size = min(self.batch_size, num_samples)
        for _ in range(self.ppo_epochs):
            # Shuffle indices
            idxs = T.randperm(num_samples)
            for start in range(0, num_samples, batch_size):
                batch_idx = idxs[start:start + batch_size]
                b_states = states[batch_idx]
                b_actions = actions[batch_idx]
                b_old_logp = old_logp[batch_idx]
                b_returns = returns[batch_idx]
                b_adv = adv[batch_idx]

                dist = self.policy.next_action(b_states)
                new_logp = dist.log_prob(b_actions)
                entropy = dist.entropy().mean()
                ratio = (new_logp - b_old_logp).exp()

                # --- Clipped surrogate objective ---
                surr1 = ratio * b_adv
                surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
                policy_loss = -T.min(surr1, surr2).mean()

                # --- Critic loss ---
                value_pred = self.critic.evaluated_state(b_states)
                value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()

                # --- Total loss ---
                total_loss = (
                    policy_loss +
                    self.value_coef * value_loss -
                    self.entropy_coef * entropy
                )

                self.opt.zero_grad(set_to_none=True)
                total_loss.backward()
                T.nn.utils.clip_grad_norm_(
                    list(self.policy.parameters()) + list(self.critic.parameters()), 0.5
                )
                self.opt.step()

                total_loss_epoch += total_loss.item()
                num_updates += 1

        # Clear memory after the full PPO update
        self.memory.clear()
        return total_loss_epoch / max(num_updates, 1)


class Policy(nn.Module):
    def __init__(self, obs_dim: int, action_dim: int, hidden: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_dim)
        )

    def next_action(self, state: T.Tensor) -> Categorical:
        # Returns the probability distribution over actions
        if state.dim() == 1:
            state = state.unsqueeze(0)
        state = state.view(state.size(0), -1)
        return Categorical(logits=self.net(state))


class Critic(nn.Module):
    def __init__(self, obs_dim: int, hidden: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1)
        )

    def evaluated_state(self, x: T.Tensor) -> T.Tensor:
        # Returns the scalar state-value estimate V(s), one per batch element
        if x.dim() == 1:
            x = x.unsqueeze(0)
        x = x.view(x.size(0), -1)
        return self.net(x).squeeze(-1)


class Memory:
    def __init__(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.log_probs = []
        self.values = []
        self.next_values = []

    def store(self, state, action, reward, done, log_prob, value, next_value):
        self.states.append(np.asarray(state, dtype=np.float32))
        self.actions.append(int(action))
        self.rewards.append(float(reward))
        self.dones.append(float(done))
        self.log_probs.append(float(log_prob))
        self.values.append(float(value))
        self.next_values.append(float(next_value))

    """
    # For mini-batch updates? To be implemented
    def start_batch(self, batch_size: int):
        n_states = len(self.states)
        starts = np.arange(0, n_states, batch_size)
        index = np.arange(n_states, dtype=np.int64)
        np.random.shuffle(index)
        return [index[s:s + batch_size] for s in starts]
    """

    def clear(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.log_probs = []
        self.values = []
        self.next_values = []
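

# A minimal usage sketch (added, not part of the original module): it assumes
# the `gymnasium` package and the "CartPole-v1" environment, and the
# hyperparameter values below are illustrative placeholders, not tuned
# settings. It collects one episode at a time and calls update_rbs() once per
# episode.
if __name__ == "__main__":
    import gymnasium as gym

    env = gym.make("CartPole-v1")
    agent = Agent(
        obs_space=env.observation_space,
        action_space=env.action_space,
        hidden=64,
        gamma=0.99,
        clip_coef=0.2,
        lr=3e-4,
        value_coef=0.5,
        entropy_coef=0.01,
        seed=0,
        batch_size=64,
        ppo_epochs=4,
        lam=0.95,
    )

    for episode in range(10):
        state, _ = env.reset()
        ep_return, done = 0.0, False
        while not done:
            action, logp, value = agent.choose_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.remember(state, action, reward, done, logp, value, next_state)
            ep_return += float(reward)
            state = next_state
        loss = agent.update_rbs()  # one PPO update per collected episode
        print(f"episode={episode} return={ep_return:.1f} loss={loss:.4f}")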