# RL_Project20 / Observation_norm_SAC / sac_helpers_cnn.py
# SAC helpers (CNN policy/critics) used in the observation, advantage and return
# normalization experiments for SAC and PPO.
import numpy as np
import torch as T
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
class Agent:
def __init__(self, obs_space, action_space, hidden, gamma, lr, alpha, seed, batch_size, tau=0.005):
if seed is not None:
np.random.seed(seed)
T.manual_seed(seed)
# Use GPU if available
self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.action_dim = int(action_space.n)  # Discrete action space: .n is the number of actions
self.obs_shape = obs_space.shape
self.gamma, self.tau, self.batch_size = gamma, tau, batch_size
        # Learnable temperature alpha (automatic entropy tuning towards a target entropy)
self.target_entropy = -float(self.action_dim)
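        # NOTE: -|A| mirrors the continuous-SAC convention; for discrete actions a common
        # alternative target is a fraction of the maximum entropy, e.g. 0.98 * log(|A|).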
self.log_alpha = T.tensor(np.log(alpha), requires_grad=True, device=self.device)
self.alpha = np.exp(self.log_alpha.item())
self.alpha_opt = optim.Adam([self.log_alpha], lr=lr)
self.policy = CategoricalActor(self.obs_shape, self.action_dim, hidden).to(self.device)
self.q1 = QNetwork(self.obs_shape, self.action_dim, hidden).to(self.device)
self.q2 = QNetwork(self.obs_shape, self.action_dim, hidden).to(self.device)
self.q1_target = QNetwork(self.obs_shape, self.action_dim, hidden).to(self.device)
self.q2_target = QNetwork(self.obs_shape, self.action_dim, hidden).to(self.device)
self.q1_target.load_state_dict(self.q1.state_dict())
self.q2_target.load_state_dict(self.q2.state_dict())
self.policy_opt = optim.Adam(self.policy.parameters(), lr=lr)
self.q1_opt = optim.Adam(self.q1.parameters(), lr=lr)
self.q2_opt = optim.Adam(self.q2.parameters(), lr=lr)
self.memory = Memory()
def choose_action(self, observation, eval_mode=False):
state = T.as_tensor(observation, dtype=T.float32, device=self.device)
with T.no_grad():
logits = self.policy(state.unsqueeze(0))
            if eval_mode:
                # Greedy action for evaluation
                action = logits.argmax(dim=-1)
            else:
                # Sample from the categorical policy during training
                action = Categorical(logits=logits).sample()
return int(action.item())
def remember(self, state, action, reward, done, next_state):
self.memory.store(state, action, reward, done, next_state)
def vanilla_sac_update(self):
if len(self.memory.states) < self.batch_size:
return 0.0
# Mini-batch sampling
idxs = np.random.choice(len(self.memory.states), self.batch_size, replace=False)
states = T.as_tensor(np.array([self.memory.states[i] for i in idxs]), dtype=T.float32, device=self.device)
actions = T.as_tensor(np.array([self.memory.actions[i] for i in idxs]), dtype=T.int64, device=self.device).unsqueeze(-1)
rewards = T.as_tensor(np.array([self.memory.rewards[i] for i in idxs]), dtype=T.float32, device=self.device).unsqueeze(-1)
dones = T.as_tensor(np.array([self.memory.dones[i] for i in idxs]), dtype=T.float32, device=self.device).unsqueeze(-1)
next_states = T.as_tensor(np.array([self.memory.next_states[i] for i in idxs]), dtype=T.float32, device=self.device)
        # Critic update (soft Q-learning objective); the entropy term encourages exploration via high-entropy policies
with T.no_grad():
next_logits = self.policy(next_states)
next_dist = Categorical(logits=next_logits)
next_probs = next_dist.probs
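            # Explicit log-softmax; Categorical's logits are already normalized, so this equals next_dist.logits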
next_log_probs = next_dist.logits - T.logsumexp(next_dist.logits, dim=-1, keepdim=True)
q1_next = self.q1_target(next_states)
q2_next = self.q2_target(next_states)
# Soft Policy Evaluation
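            # Soft state value: V(s') = E_{a~pi(.|s')}[ min(Q1, Q2)(s', a) - alpha * log pi(a|s') ]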
min_q_next = T.min(q1_next, q2_next)
next_value = (next_probs * (min_q_next - self.alpha * next_log_probs)).sum(dim=-1, keepdim=True)
target = rewards + self.gamma * (1 - dones) * next_value
q1 = self.q1(states).gather(1, actions)
q2 = self.q2(states).gather(1, actions)
# Losses of both Q-functions
q1_loss = nn.MSELoss()(q1, target)
q2_loss = nn.MSELoss()(q2, target)
self.q1_opt.zero_grad()
q1_loss.backward()
self.q1_opt.step()
self.q2_opt.zero_grad()
q2_loss.backward()
self.q2_opt.step()
# Policy/Actor Objective
logits = self.policy(states)
dist = Categorical(logits=logits)
probs = dist.probs
log_probs = dist.logits - T.logsumexp(dist.logits, dim=-1, keepdim=True)
q1_policy = self.q1(states)
q2_policy = self.q2(states)
min_q_policy = T.min(q1_policy, q2_policy)
        # Discrete-action policy loss: a full expectation over actions replaces the reparameterization trick used in continuous SAC
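        # J_pi = E_s[ sum_a pi(a|s) * (alpha * log pi(a|s) - min(Q1, Q2)(s, a)) ]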
policy_loss = (probs * (self.alpha * log_probs - min_q_policy)).sum(dim=-1).mean()
        # Temperature (alpha) update: drives the policy entropy towards the target entropy
        alpha_loss = -(probs.detach() * self.log_alpha * (log_probs + self.target_entropy).detach()).sum(dim=-1).mean()
self.alpha_opt.zero_grad()
alpha_loss.backward()
self.alpha_opt.step()
self.alpha = self.log_alpha.exp().item()
self.policy_opt.zero_grad()
policy_loss.backward()
self.policy_opt.step()
# Target network update
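        # theta_target <- tau * theta + (1 - tau) * theta_target (Polyak averaging)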
for target_param, param in zip(self.q1_target.parameters(), self.q1.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
for target_param, param in zip(self.q2_target.parameters(), self.q2.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
return policy_loss.item()
def update_reward_gradient_clipping(self):
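        """Variant of vanilla_sac_update with gradient clipping on the temperature and
        policy updates, plus an optional (currently disabled) min-max reward
        normalization block."""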
if len(self.memory.states) < self.batch_size:
return 0.0
# Mini-batch sampling
idxs = np.random.choice(len(self.memory.states), self.batch_size, replace=False)
states = T.as_tensor(np.array([self.memory.states[i] for i in idxs]), dtype=T.float32, device=self.device)
actions = T.as_tensor(np.array([self.memory.actions[i] for i in idxs]), dtype=T.int64, device=self.device).unsqueeze(-1)
rewards = T.as_tensor(np.array([self.memory.rewards[i] for i in idxs]), dtype=T.float32, device=self.device).unsqueeze(-1)
dones = T.as_tensor(np.array([self.memory.dones[i] for i in idxs]), dtype=T.float32, device=self.device).unsqueeze(-1)
next_states = T.as_tensor(np.array([self.memory.next_states[i] for i in idxs]), dtype=T.float32, device=self.device)
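        # NOTE: the triple-quoted block below is intentionally disabled. When enabled, it
        # min-max rescales the sampled rewards to [-1, 1] and squashes them with tanh
        # before the critic update.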
"""
# Min-max normalization and tanh scaling to [-1, 1]
rewards_np = np.array([self.memory.rewards[i] for i in idxs])
r_min = rewards_np.min()
r_max = rewards_np.max()
# Avoid division by zero
r_scaled = 2 * (rewards_np - r_min) / (r_max - r_min + 1e-8) - 1
normalized_rewards = np.tanh(r_scaled)
rewards = T.as_tensor(normalized_rewards, dtype=T.float32, device=self.device).unsqueeze(-1)
"""
        # Critic update (soft Q-learning objective); the entropy term encourages exploration via high-entropy policies
with T.no_grad():
next_logits = self.policy(next_states)
next_dist = Categorical(logits=next_logits)
next_probs = next_dist.probs
next_log_probs = next_dist.logits - T.logsumexp(next_dist.logits, dim=-1, keepdim=True)
q1_next = self.q1_target(next_states)
q2_next = self.q2_target(next_states)
# Soft Policy Evaluation
min_q_next = T.min(q1_next, q2_next)
next_value = (next_probs * (min_q_next - self.alpha * next_log_probs)).sum(dim=-1, keepdim=True)
target = rewards + self.gamma * (1 - dones) * next_value
q1 = self.q1(states).gather(1, actions)
q2 = self.q2(states).gather(1, actions)
# Losses of both Q-functions
q1_loss = nn.MSELoss()(q1, target)
q2_loss = nn.MSELoss()(q2, target)
self.q1_opt.zero_grad()
q1_loss.backward()
self.q1_opt.step()
self.q2_opt.zero_grad()
q2_loss.backward()
self.q2_opt.step()
# Policy/Actor Objective
logits = self.policy(states)
dist = Categorical(logits=logits)
probs = dist.probs
log_probs = dist.logits - T.logsumexp(dist.logits, dim=-1, keepdim=True)
q1_policy = self.q1(states)
q2_policy = self.q2(states)
min_q_policy = T.min(q1_policy, q2_policy)
        # Discrete-action policy loss: a full expectation over actions replaces the reparameterization trick used in continuous SAC
policy_loss = (probs * (self.alpha * log_probs - min_q_policy)).sum(dim=-1).mean()
        # Temperature (alpha) update: drives the policy entropy towards the target entropy
        alpha_loss = -(probs.detach() * self.log_alpha * (log_probs + self.target_entropy).detach()).sum(dim=-1).mean()
self.alpha_opt.zero_grad()
alpha_loss.backward()
T.nn.utils.clip_grad_norm_([self.log_alpha], max_norm=1.0) # Gradient clipping
self.alpha_opt.step()
self.alpha = self.log_alpha.exp().item()
self.policy_opt.zero_grad()
policy_loss.backward()
        T.nn.utils.clip_grad_norm_(self.policy.parameters(), max_norm=1.0) # Gradient clipping on the actor parameters
self.policy_opt.step()
# Target network update
for target_param, param in zip(self.q1_target.parameters(), self.q1.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
for target_param, param in zip(self.q2_target.parameters(), self.q2.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
return policy_loss.item()
# Actor/Policy network
# A typical SAC actor outputs a Gaussian distribution over continuous actions for a given state.
# Here it is adapted to discrete actions with a Categorical distribution, since the Atari environments are discrete:
# the policy outputs one logit per discrete action.
# From: https://ch.mathworks.com/help/reinforcement-learning/ug/soft-actor-critic-agents.html
# The actor takes the current observation and generates a categorical distribution in which each possible action is associated with a probability.
class CategoricalActor(nn.Module):
def __init__(self, obs_shape: tuple, action_dim: int, hidden: int):
super().__init__()
c, h, w = obs_shape
self.cnn = nn.Sequential(
nn.Conv2d(c, 16, kernel_size=8, stride=4),
nn.ReLU(),
nn.Conv2d(16, 32, kernel_size=4, stride=2),
nn.ReLU(),
nn.Flatten()
)
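        # Infer the flattened CNN feature size with a dummy forward pass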
with T.no_grad():
cnn_output_dim = self.cnn(T.zeros(1, c, h, w)).shape[1]
self.fc = nn.Sequential(
nn.Linear(cnn_output_dim, hidden),
nn.ReLU(),
nn.Linear(hidden, action_dim)
)
def forward(self, state: T.Tensor):
if state.dim() == 3:
state = state.unsqueeze(0)
cnn_out = self.cnn(state)
logits = self.fc(cnn_out)
return logits
# Q-network for discrete actions: a single forward pass returns Q(s, a) for every action
class QNetwork(nn.Module):
def __init__(self, obs_shape: tuple, action_dim: int, hidden: int):
super().__init__()
c, h, w = obs_shape
self.cnn = nn.Sequential(
nn.Conv2d(c, 16, kernel_size=8, stride=4),
nn.ReLU(),
nn.Conv2d(16, 32, kernel_size=4, stride=2),
nn.ReLU(),
nn.Flatten()
)
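        # Infer the flattened CNN feature size with a dummy forward pass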
with T.no_grad():
cnn_output_dim = self.cnn(T.zeros(1, c, h, w)).shape[1]
self.net = nn.Sequential(
nn.Linear(cnn_output_dim, hidden),
nn.ReLU(),
nn.Linear(hidden, action_dim)
)
def forward(self, state: T.Tensor):
if state.dim() == 3:
state = state.unsqueeze(0)
cnn_out = self.cnn(state)
return self.net(cnn_out)
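# Simple list-based replay buffer.
# NOTE: it grows without bound; a fixed-capacity buffer (e.g. collections.deque with maxlen) is the usual choice for SAC.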
class Memory:
def __init__(self):
self.states, self.actions, self.rewards, self.dones, self.next_states = [], [], [], [], []
def store(self, s, a, r, d, ns):
self.states.append(np.asarray(s, dtype=np.float32))
        self.actions.append(np.asarray(a, dtype=np.int64))
self.rewards.append(float(r))
self.dones.append(float(d))
self.next_states.append(np.asarray(ns, dtype=np.float32))
def clear(self):
self.states, self.actions, self.rewards, self.dones, self.next_states = [], [], [], [], []
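
# Rough usage sketch (not part of the original training scripts): wiring the Agent into a
# standard off-policy loop. `make_env` is a hypothetical helper assumed to return a
# Gymnasium-style environment with channel-first image observations (C, H, W) and a
# Discrete action space, e.g. preprocessed, frame-stacked Atari frames.
if __name__ == "__main__":
    env = make_env()  # hypothetical helper; substitute your own preprocessed Atari env
    agent = Agent(obs_space=env.observation_space, action_space=env.action_space,
                  hidden=256, gamma=0.99, lr=3e-4, alpha=0.2, seed=0, batch_size=64)
    obs, _ = env.reset(seed=0)
    for step in range(100_000):
        action = agent.choose_action(obs)
        next_obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        agent.remember(obs, action, reward, done, next_obs)
        agent.vanilla_sac_update()  # or agent.update_reward_gradient_clipping()
        obs = env.reset()[0] if done else next_obs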