import numpy as np
import torch as T
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical


class Agent:
    def __init__(
        self,
        obs_space,
        action_space,
        hidden,
        gamma,
        clip_coef,
        lr,
        value_coef,
        entropy_coef,
        seed,
        batch_size,
        ppo_epochs,
        lam
    ):
        # Seed NumPy and PyTorch for reproducibility
        if seed is not None:
            np.random.seed(seed)
            T.manual_seed(seed)
        # Use the GPU if available
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        # Accept either Gym spaces or plain integers for the dimensions
        self.obs_dim = int(np.prod(getattr(obs_space, "shape", (obs_space,))))
        self.action_dim = int(getattr(action_space, "n", action_space))
        # Initialize the policy and critic networks
        self.policy = Policy(self.obs_dim, self.action_dim, hidden).to(self.device)
        self.critic = Critic(self.obs_dim, hidden).to(self.device)
        # A single Adam optimizer over both policy and critic parameters
        self.opt = optim.Adam(
            list(self.policy.parameters()) + list(self.critic.parameters()),
            lr=lr
        )
        self.gamma = gamma
        self.clip = clip_coef
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.sigma_history = []
        self.loss_history = []
        self.policy_loss_history = []
        self.value_loss_history = []
        self.entropy_history = []
        self.lam = lam
        self.ppo_epochs = ppo_epochs
        self.batch_size = batch_size
        self.memory = Memory()

    def choose_action(self, observation):
        # Returns: action, log probability, and value of the state
        state = T.as_tensor(observation, dtype=T.float32, device=self.device).view(-1)
        with T.no_grad():
            # Forward pass through the policy (defined in the Policy class)
            dist = self.policy.next_action(state)
            action = dist.sample()
            logp = dist.log_prob(action)
            value = self.critic.evaluated_state(state)
        return int(action.item()), float(logp.item()), float(value.item())

    def remember(self, state, action, reward, done, log_prob, value, next_state):
        with T.no_grad():
            # Evaluate the next state with the critic; used as the bootstrap value
            ns = T.as_tensor(next_state, dtype=T.float32, device=self.device).view(-1)
            next_value = self.critic.evaluated_state(ns).item()
        self.memory.store(state, action, reward, done, log_prob, value, next_value)
| """ | |
| def run_episode(self, env, max_steps: int, render: bool = False): | |
| # Runs one episode, updates the policy once at the end | |
| self.memory.clear() | |
| out = env.reset() | |
| state = out[0] if isinstance(out, tuple) else out | |
| ep_return, ep_len = 0, 0 | |
| steps_limit = max_steps if max_steps is not None else float("inf") | |
| while ep_len < steps_limit: | |
| if render and hasattr(env, "render"): | |
| env.render() | |
| action, logp, value = self.choose_action(state) | |
| step_out = env.step(action) | |
| if len(step_out) == 5: | |
| next_state, reward, terminated, truncated, _ = step_out | |
| done = terminated or truncated | |
| else: | |
| next_state, reward, done, _ = step_out | |
| self.remember(state, action, reward, done, logp, value, next_state) | |
| ep_return += float(reward) | |
| ep_len += 1 | |
| state = next_state | |
| if done: | |
| break | |
| self._update() | |
| return ep_return, ep_len | |
| def run_episodes(self, env, n_episodes: int, max_steps: int, render: bool = False): | |
| returns = [] | |
| for _ in range(n_episodes): | |
| ep_ret, _ = self.run_episode(env, max_steps=max_steps, render=render) | |
| returns.append(ep_ret) | |
| return returns | |
| """ | |

    def update_rbs(self):
        if len(self.memory.states) == 0:
            return 0.0
        # Convert memory to tensors
        states = T.as_tensor(np.array(self.memory.states), dtype=T.float32, device=self.device)
        actions = T.as_tensor(self.memory.actions, dtype=T.long, device=self.device)
        rewards = T.as_tensor(self.memory.rewards, dtype=T.float32, device=self.device)
        dones = T.as_tensor(self.memory.dones, dtype=T.float32, device=self.device)
        old_logp = T.as_tensor(self.memory.log_probs, dtype=T.float32, device=self.device)
        values = T.as_tensor(self.memory.values, dtype=T.float32, device=self.device)
        # Critic estimates of the next states, stored at collection time (bootstrap values)
        next_values = T.as_tensor(self.memory.next_values, dtype=T.float32, device=self.device)
        with T.no_grad():
            # --- GAE-Lambda advantages ---
            deltas = rewards + self.gamma * next_values * (1 - dones) - values
            adv = T.zeros_like(rewards)
            gae = 0.0
            for t in reversed(range(len(rewards))):
                gae = deltas[t] + self.gamma * self.lam * (1 - dones[t]) * gae
                adv[t] = gae
            returns = adv + values
            # --- Return-based scaling (RBS): divide returns and advantages
            # by the standard deviation of the returns ---
            sigma_t = returns.std(unbiased=False) + 1e-8
            returns = returns / sigma_t
            adv = adv / sigma_t
            adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
            self.sigma_history.append(sigma_t.item())
        # --- PPO update: multiple epochs over shuffled minibatches ---
        total_loss_epoch = 0.0
        num_samples = len(states)
        batch_size = min(self.batch_size, num_samples)
        n_batches = (num_samples + batch_size - 1) // batch_size
        for _ in range(self.ppo_epochs):
            # Reshuffle sample indices every epoch
            idxs = T.randperm(num_samples, device=self.device)
            for start in range(0, num_samples, batch_size):
                batch_idx = idxs[start:start + batch_size]
                b_states = states[batch_idx]
                b_actions = actions[batch_idx]
                b_old_logp = old_logp[batch_idx]
                b_returns = returns[batch_idx]
                b_adv = adv[batch_idx]
                dist = self.policy.next_action(b_states)
                new_logp = dist.log_prob(b_actions)
                entropy = dist.entropy().mean()
                ratio = (new_logp - b_old_logp).exp()
                # --- Clipped surrogate objective ---
                surr1 = ratio * b_adv
                surr2 = T.clamp(ratio, 1 - self.clip, 1 + self.clip) * b_adv
                policy_loss = -T.min(surr1, surr2).mean()
                # --- Critic loss ---
                value_pred = self.critic.evaluated_state(b_states)
                value_loss = 0.5 * (b_returns - value_pred).pow(2).mean()
                # --- Total loss ---
                total_loss = (
                    policy_loss +
                    self.value_coef * value_loss -
                    self.entropy_coef * entropy
                )
                self.opt.zero_grad(set_to_none=True)
                total_loss.backward()
                T.nn.utils.clip_grad_norm_(
                    list(self.policy.parameters()) + list(self.critic.parameters()), 0.5
                )
                self.opt.step()
                total_loss_epoch += total_loss.item()
        # Clear memory after the full PPO update
        self.memory.clear()
        return total_loss_epoch / (self.ppo_epochs * n_batches)


class Policy(nn.Module):
    def __init__(self, obs_dim: int, action_dim: int, hidden: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_dim)
        )

    def next_action(self, state: T.Tensor) -> Categorical:
        # Returns the probability distribution over actions
        if state.dim() == 1:
            state = state.unsqueeze(0)
        state = state.view(state.size(0), -1)
        return Categorical(logits=self.net(state))


class Critic(nn.Module):
    def __init__(self, obs_dim: int, hidden: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1)
        )

    def evaluated_state(self, x: T.Tensor) -> T.Tensor:
        if x.dim() == 1:
            x = x.unsqueeze(0)
        x = x.view(x.size(0), -1)
        return self.net(x).squeeze(-1)


class Memory:
    def __init__(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.log_probs = []
        self.values = []
        self.next_values = []

    def store(self, state, action, reward, done, log_prob, value, next_value):
        self.states.append(np.asarray(state, dtype=np.float32))
        self.actions.append(int(action))
        self.rewards.append(float(reward))
        self.dones.append(float(done))
        self.log_probs.append(float(log_prob))
        self.values.append(float(value))
        self.next_values.append(float(next_value))

    """
    # For mini-batch updates? To be implemented
    def start_batch(self, batch_size: int):
        n_states = len(self.states)
        starts = np.arange(0, n_states, batch_size)
        index = np.arange(n_states, dtype=np.int64)
        np.random.shuffle(index)
        return [index[s:s + batch_size] for s in starts]
    """

    def clear(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.log_probs = []
        self.values = []
        self.next_values = []
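

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): collects one episode
# at a time and calls update_rbs() after each. The environment (gymnasium's
# CartPole-v1), the hyperparameter values, and the per-episode update schedule
# below are illustrative assumptions, not a prescribed training recipe.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import gymnasium as gym

    env = gym.make("CartPole-v1")
    agent = Agent(
        obs_space=env.observation_space,
        action_space=env.action_space,
        hidden=64,
        gamma=0.99,
        clip_coef=0.2,
        lr=3e-4,
        value_coef=0.5,
        entropy_coef=0.01,
        seed=0,
        batch_size=64,
        ppo_epochs=4,
        lam=0.95,
    )

    for episode in range(200):
        state, _ = env.reset()
        ep_return, done = 0.0, False
        while not done:
            action, logp, value = agent.choose_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.remember(state, action, reward, done, logp, value, next_state)
            ep_return += float(reward)
            state = next_state
        loss = agent.update_rbs()  # one PPO update per collected episode
        print(f"episode {episode:3d}  return {ep_return:6.1f}  loss {loss:.4f}")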