import numpy as np
import torch as T
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

class Agent:
    def __init__(self, obs_space, action_space, hidden, gamma, lr, alpha, seed, batch_size, tau=0.005):
        if seed is not None:
            np.random.seed(seed)
            T.manual_seed(seed)
        # Use GPU if available
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.action_dim = int(action_space.n)  # size of the Discrete action space
        self.obs_shape = obs_space.shape
        self.gamma, self.tau, self.batch_size = gamma, tau, batch_size
        # Learnable temperature: alpha trades reward against policy entropy,
        # driving the entropy toward a target instead of using a fixed coefficient.
        # For discrete actions the target is a fraction of the maximum entropy
        # log|A|; the continuous-SAC convention -|A| is negative and would drive
        # alpha toward zero here.
        self.target_entropy = 0.98 * float(np.log(self.action_dim))
        self.log_alpha = T.tensor(np.log(alpha), requires_grad=True, device=self.device)
        self.alpha = self.log_alpha.exp().item()
        self.alpha_opt = optim.Adam([self.log_alpha], lr=lr)
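        # Temperature objective, with log-alpha as the free parameter:
        #   J(alpha) = -log_alpha * E_s[ H_target - H(pi(.|s)) ]
        # Minimizing it raises alpha when policy entropy falls below the target
        # and lowers alpha once the policy is exploratory enough.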
        # Actor plus twin critics; the target critics are slow-moving copies
        # used to stabilize the bootstrap target (clipped double-Q)
        self.policy = CategoricalActor(self.obs_shape, self.action_dim, hidden).to(self.device)
        self.q1 = QNetwork(self.obs_shape, self.action_dim, hidden).to(self.device)
        self.q2 = QNetwork(self.obs_shape, self.action_dim, hidden).to(self.device)
        self.q1_target = QNetwork(self.obs_shape, self.action_dim, hidden).to(self.device)
        self.q2_target = QNetwork(self.obs_shape, self.action_dim, hidden).to(self.device)
        self.q1_target.load_state_dict(self.q1.state_dict())
        self.q2_target.load_state_dict(self.q2.state_dict())
        self.policy_opt = optim.Adam(self.policy.parameters(), lr=lr)
        self.q1_opt = optim.Adam(self.q1.parameters(), lr=lr)
        self.q2_opt = optim.Adam(self.q2.parameters(), lr=lr)
        self.memory = Memory()
    def choose_action(self, observation, eval_mode=False):
        state = T.as_tensor(observation, dtype=T.float32, device=self.device)
        with T.no_grad():
            logits = self.policy(state.unsqueeze(0))
            if eval_mode:
                action = logits.argmax(dim=-1)  # greedy action for evaluation
            else:
                action = Categorical(logits=logits).sample()  # stochastic for exploration
        return int(action.item())
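    # Usage sketch (assumption: `obs` is a channel-first frame stack shaped
    # like obs_space.shape):
    #   a = agent.choose_action(obs)                  # training: sample from pi
    #   a = agent.choose_action(obs, eval_mode=True)  # evaluation: greedy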
    def remember(self, state, action, reward, done, next_state):
        self.memory.store(state, action, reward, done, next_state)
    def vanilla_sac_update(self):
        if len(self.memory.states) < self.batch_size:
            return 0.0
        # Mini-batch sampling
        idxs = np.random.choice(len(self.memory.states), self.batch_size, replace=False)
        states = T.as_tensor(np.array([self.memory.states[i] for i in idxs]), dtype=T.float32, device=self.device)
        actions = T.as_tensor(np.array([self.memory.actions[i] for i in idxs]), dtype=T.int64, device=self.device).unsqueeze(-1)
        rewards = T.as_tensor(np.array([self.memory.rewards[i] for i in idxs]), dtype=T.float32, device=self.device).unsqueeze(-1)
        dones = T.as_tensor(np.array([self.memory.dones[i] for i in idxs]), dtype=T.float32, device=self.device).unsqueeze(-1)
        next_states = T.as_tensor(np.array([self.memory.next_states[i] for i in idxs]), dtype=T.float32, device=self.device)
        # Critic update. Soft Q-learning objective: the entropy bonus keeps
        # high-entropy actions valuable, which encourages exploration
        with T.no_grad():
            next_logits = self.policy(next_states)
            next_dist = Categorical(logits=next_logits)
            next_probs = next_dist.probs
            next_log_probs = next_dist.logits  # Categorical already log-normalizes its logits
            q1_next = self.q1_target(next_states)
            q2_next = self.q2_target(next_states)
            # Soft policy evaluation with the clipped double-Q minimum
            min_q_next = T.min(q1_next, q2_next)
            next_value = (next_probs * (min_q_next - self.alpha * next_log_probs)).sum(dim=-1, keepdim=True)
            target = rewards + self.gamma * (1 - dones) * next_value
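            # i.e. the soft Bellman target
            #   y = r + gamma * (1 - d) * E_{a'~pi}[ min_i Q_target_i(s', a') - alpha * log pi(a'|s') ]
            # with the expectation computed exactly as a sum over the discrete actions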
        q1 = self.q1(states).gather(1, actions)
        q2 = self.q2(states).gather(1, actions)
        # Losses of both Q-functions
        q1_loss = nn.MSELoss()(q1, target)
        q2_loss = nn.MSELoss()(q2, target)
        self.q1_opt.zero_grad()
        q1_loss.backward()
        self.q1_opt.step()
        self.q2_opt.zero_grad()
        q2_loss.backward()
        self.q2_opt.step()
        # Policy/actor objective
        logits = self.policy(states)
        dist = Categorical(logits=logits)
        probs = dist.probs
        log_probs = dist.logits  # already log-normalized by Categorical
        q1_policy = self.q1(states)
        q2_policy = self.q2(states)
        min_q_policy = T.min(q1_policy, q2_policy)
        # Discrete-action policy loss: an exact expectation over actions instead
        # of the reparameterized sample used in continuous SAC
        policy_loss = (probs * (self.alpha * log_probs - min_q_policy)).sum(dim=-1).mean()
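        # i.e.  J_pi = E_s[ sum_a pi(a|s) * (alpha * log pi(a|s) - min_i Q_i(s, a)) ],
        # the exact discrete counterpart of the sampled continuous-SAC actor loss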
        # Temperature (alpha) update: the expected log-prob under pi is the
        # negative per-state entropy, so this drives H(pi) toward the target
        entropy_term = (probs * log_probs).sum(dim=-1)
        alpha_loss = -(self.log_alpha * (entropy_term + self.target_entropy).detach()).mean()
        self.alpha_opt.zero_grad()
        alpha_loss.backward()
        self.alpha_opt.step()
        self.alpha = self.log_alpha.exp().item()
        self.policy_opt.zero_grad()
        policy_loss.backward()
        self.policy_opt.step()
        # Target network update: Polyak averaging,
        #   theta_target <- tau * theta + (1 - tau) * theta_target
        for target_param, param in zip(self.q1_target.parameters(), self.q1.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        for target_param, param in zip(self.q2_target.parameters(), self.q2.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        return policy_loss.item()
    def update_reward_gradient_clipping(self):
        # Same SAC update as vanilla_sac_update, but with gradient clipping on
        # the temperature and policy steps; a reward-normalization variant is
        # kept below, disabled
        if len(self.memory.states) < self.batch_size:
            return 0.0
        # Mini-batch sampling
        idxs = np.random.choice(len(self.memory.states), self.batch_size, replace=False)
        states = T.as_tensor(np.array([self.memory.states[i] for i in idxs]), dtype=T.float32, device=self.device)
        actions = T.as_tensor(np.array([self.memory.actions[i] for i in idxs]), dtype=T.int64, device=self.device).unsqueeze(-1)
        rewards = T.as_tensor(np.array([self.memory.rewards[i] for i in idxs]), dtype=T.float32, device=self.device).unsqueeze(-1)
        dones = T.as_tensor(np.array([self.memory.dones[i] for i in idxs]), dtype=T.float32, device=self.device).unsqueeze(-1)
        next_states = T.as_tensor(np.array([self.memory.next_states[i] for i in idxs]), dtype=T.float32, device=self.device)
| """ | |
| # Min-max normalization and tanh scaling to [-1, 1] | |
| rewards_np = np.array([self.memory.rewards[i] for i in idxs]) | |
| r_min = rewards_np.min() | |
| r_max = rewards_np.max() | |
| # Avoid division by zero | |
| r_scaled = 2 * (rewards_np - r_min) / (r_max - r_min + 1e-8) - 1 | |
| normalized_rewards = np.tanh(r_scaled) | |
| rewards = T.as_tensor(normalized_rewards, dtype=T.float32, device=self.device).unsqueeze(-1) | |
| """ | |
        # Critic update. Soft Q-learning objective, as in vanilla_sac_update:
        # the entropy bonus keeps high-entropy actions valuable for exploration
        with T.no_grad():
            next_logits = self.policy(next_states)
            next_dist = Categorical(logits=next_logits)
            next_probs = next_dist.probs
            next_log_probs = next_dist.logits  # already log-normalized by Categorical
            q1_next = self.q1_target(next_states)
            q2_next = self.q2_target(next_states)
            # Soft policy evaluation with the clipped double-Q minimum
            min_q_next = T.min(q1_next, q2_next)
            next_value = (next_probs * (min_q_next - self.alpha * next_log_probs)).sum(dim=-1, keepdim=True)
            target = rewards + self.gamma * (1 - dones) * next_value
        q1 = self.q1(states).gather(1, actions)
        q2 = self.q2(states).gather(1, actions)
        # Losses of both Q-functions
        q1_loss = nn.MSELoss()(q1, target)
        q2_loss = nn.MSELoss()(q2, target)
        self.q1_opt.zero_grad()
        q1_loss.backward()
        self.q1_opt.step()
        self.q2_opt.zero_grad()
        q2_loss.backward()
        self.q2_opt.step()
        # Policy/actor objective
        logits = self.policy(states)
        dist = Categorical(logits=logits)
        probs = dist.probs
        log_probs = dist.logits  # already log-normalized by Categorical
        q1_policy = self.q1(states)
        q2_policy = self.q2(states)
        min_q_policy = T.min(q1_policy, q2_policy)
        # Discrete-action policy loss: exact expectation over actions
        policy_loss = (probs * (self.alpha * log_probs - min_q_policy)).sum(dim=-1).mean()
        # Temperature (alpha) update, as above but with gradient clipping
        entropy_term = (probs * log_probs).sum(dim=-1)
        alpha_loss = -(self.log_alpha * (entropy_term + self.target_entropy).detach()).mean()
        self.alpha_opt.zero_grad()
        alpha_loss.backward()
        T.nn.utils.clip_grad_norm_([self.log_alpha], max_norm=1.0)  # Gradient clipping on alpha
        self.alpha_opt.step()
        self.alpha = self.log_alpha.exp().item()
        self.policy_opt.zero_grad()
        policy_loss.backward()
        T.nn.utils.clip_grad_norm_(self.policy.parameters(), max_norm=1.0)  # Gradient clipping on the policy parameters
        self.policy_opt.step()
        # Target network update
        for target_param, param in zip(self.q1_target.parameters(), self.q1.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        for target_param, param in zip(self.q2_target.parameters(), self.q2.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        return policy_loss.item()

# Actor/policy network.
# A typical SAC actor outputs a Gaussian distribution over continuous actions;
# here it is adapted for discrete actions with a Categorical distribution, since
# the Atari action spaces are discrete. The policy outputs one logit per action.
# From https://ch.mathworks.com/help/reinforcement-learning/ug/soft-actor-critic-agents.html:
# "The actor takes the current observation and generates a categorical distribution,
# in which each possible action is associated with a probability."
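# For example (illustrative shapes and values): logits of shape [batch, n_actions]
# give a distribution whose .sample() returns integer action indices and whose
# .probs rows sum to 1:
#   Categorical(logits=T.randn(2, 6)).sample()  # -> e.g. tensor([4, 1])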
class CategoricalActor(nn.Module):
    def __init__(self, obs_shape: tuple, action_dim: int, hidden: int):
        super().__init__()
        c, h, w = obs_shape
        self.cnn = nn.Sequential(
            nn.Conv2d(c, 16, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Flatten()
        )
        # Infer the flattened CNN feature size with a dummy forward pass
        with T.no_grad():
            cnn_output_dim = self.cnn(T.zeros(1, c, h, w)).shape[1]
        self.fc = nn.Sequential(
            nn.Linear(cnn_output_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_dim)
        )
    def forward(self, state: T.Tensor):
        if state.dim() == 3:
            state = state.unsqueeze(0)  # add a batch dimension for single frames
        cnn_out = self.cnn(state)
        logits = self.fc(cnn_out)
        return logits

# Q-network for discrete actions: a single forward pass returns Q(s, a) for
# every action, i.e. a [batch, n_actions] value table
class QNetwork(nn.Module):
    def __init__(self, obs_shape: tuple, action_dim: int, hidden: int):
        super().__init__()
        c, h, w = obs_shape
        self.cnn = nn.Sequential(
            nn.Conv2d(c, 16, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Flatten()
        )
        # Infer the flattened CNN feature size with a dummy forward pass
        with T.no_grad():
            cnn_output_dim = self.cnn(T.zeros(1, c, h, w)).shape[1]
        self.net = nn.Sequential(
            nn.Linear(cnn_output_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_dim)
        )
    def forward(self, state: T.Tensor):
        if state.dim() == 3:
            state = state.unsqueeze(0)
        cnn_out = self.cnn(state)
        return self.net(cnn_out)

class Memory:
    # Simple unbounded replay buffer backed by Python lists; a production
    # buffer would usually cap its size and evict the oldest transitions
    def __init__(self):
        self.states, self.actions, self.rewards, self.dones, self.next_states = [], [], [], [], []

    def store(self, s, a, r, d, ns):
        self.states.append(np.asarray(s, dtype=np.float32))
        self.actions.append(np.asarray(a, dtype=np.int64))  # discrete action index
        self.rewards.append(float(r))
        self.dones.append(float(d))
        self.next_states.append(np.asarray(ns, dtype=np.float32))

    def clear(self):
        self.states, self.actions, self.rewards, self.dones, self.next_states = [], [], [], [], []
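
# Minimal smoke-test sketch (an illustration, not part of the original training
# setup): fake Gym-style spaces and random transitions exercise one full update
# step without needing a real Atari environment.
if __name__ == "__main__":
    from types import SimpleNamespace

    obs_space = SimpleNamespace(shape=(4, 84, 84))  # 4 stacked 84x84 frames
    action_space = SimpleNamespace(n=6)             # e.g. a 6-action Atari game

    agent = Agent(obs_space, action_space, hidden=256, gamma=0.99, lr=3e-4,
                  alpha=0.2, seed=0, batch_size=32)

    state = np.random.rand(4, 84, 84).astype(np.float32)
    for _ in range(64):  # collect enough transitions to exceed batch_size
        action = agent.choose_action(state)
        next_state = np.random.rand(4, 84, 84).astype(np.float32)
        agent.remember(state, action, float(np.random.randn()), False, next_state)
        state = next_state
    print("policy loss after one update:", agent.vanilla_sac_update())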