import torch
import torch.nn as nn
from torch.distributions import Normal
import numpy as np


class SharedActorCritic(nn.Module):
    """Actor-critic network with a shared feature extractor.

    The actor head emits 2 * action_dim values (a per-dimension Gaussian mean
    and log standard deviation); the critic head emits a single state value.
    """

    def __init__(self, state_dim, action_dim):
        super(SharedActorCritic, self).__init__()
        self.feature_extractor = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
        )
        self.actor_head = nn.Linear(128, action_dim * 2)
        self.critic_head = nn.Linear(128, 1)

    def forward(self, state):
        features = self.feature_extractor(state)
        action_params = self.actor_head(features)
        # Split the actor output into the Gaussian mean and log-std.
        mean, log_std = torch.chunk(action_params, 2, dim=-1)
        value = self.critic_head(features)
        return mean, log_std, value
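
# Illustrative shape check (example values only, not part of the module):
#
#   net = SharedActorCritic(state_dim=4, action_dim=2)
#   mean, log_std, value = net(torch.zeros(1, 4))
#   # mean: (1, 2), log_std: (1, 2), value: (1, 1)
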
class PGAgent:
    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.95, gae_lambda=0.95, critic_loss_coef=0.5):
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.critic_loss_coef = critic_loss_coef
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SharedActorCritic(state_dim, action_dim).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        # Rollout buffers: log_probs and values are filled by select_action;
        # rewards and dones must be appended by the training loop after each env step.
        self.log_probs = []
        self.rewards = []
        self.values = []
        self.dones = []
        # Clamp bounds for the policy's log standard deviation.
        self.log_std_min = -20
        self.log_std_max = 2

    def select_action(self, state):
        state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        mean, log_std, value = self.model(state_tensor)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        std = torch.exp(log_std)
        dist = Normal(mean, std)
        action = dist.sample()
        log_prob = dist.log_prob(action).sum(dim=-1)
        # Keep the graph-carrying tensors so update() can backpropagate through them.
        self.log_probs.append(log_prob)
        self.values.append(value)
        return np.clip(action.squeeze(0).detach().cpu().numpy(), 0.0, 1.0)
    def update(self):
        if not self.rewards:
            return
        # Updates are assumed to run at episode boundaries, so the bootstrap
        # value for the state following the last transition is zero.
        next_value = 0.0
        values = torch.cat(self.values).view(-1).detach().cpu().numpy()
        advantages, returns = self._calculate_gae_advantages(self.rewards, values, self.dones, next_value)
        log_probs = torch.cat(self.log_probs)
        advantages = torch.tensor(advantages, dtype=torch.float32, device=self.device)
        returns = torch.tensor(returns, dtype=torch.float32, device=self.device)
        # Normalize advantages to stabilize the policy-gradient step.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        actor_loss = -(log_probs * advantages).mean()
        critic_values = torch.cat(self.values).view(-1)
        critic_loss = nn.MSELoss()(critic_values, returns)
        total_loss = actor_loss + self.critic_loss_coef * critic_loss
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
        self.optimizer.step()
        # Clear rollout buffers for the next episode.
        self.rewards = []
        self.log_probs = []
        self.values = []
        self.dones = []
    def _calculate_gae_advantages(self, rewards, values, dones, next_value):
        advantages = np.zeros(len(rewards), dtype=np.float32)
        last_advantage = 0.0
        for t in reversed(range(len(rewards))):
            # mask zeroes out bootstrapping across episode boundaries.
            mask = 1.0 - float(dones[t])
            v_next = values[t + 1] if t < len(rewards) - 1 else next_value
            # TD residual: delta_t = r_t + gamma * V(s_{t+1}) * mask - V(s_t)
            delta = rewards[t] + self.gamma * v_next * mask - values[t]
            # GAE recursion: A_t = delta_t + gamma * lambda * mask * A_{t+1}
            last_advantage = delta + self.gamma * self.gae_lambda * last_advantage * mask
            advantages[t] = last_advantage
        returns = advantages + values
        return advantages, returns
    def save(self, path):
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, path)

    def load(self, path):
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
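
# Minimal usage sketch (not part of the agent): assumes a hypothetical
# environment object `env` with Gym-style reset()/step() methods, a 1-D
# observation of size STATE_DIM, and actions normalized to [0, 1].
# The training loop is responsible for appending rewards and done flags
# to the agent's buffers, since select_action() only records log-probs
# and value estimates.
#
#   agent = PGAgent(state_dim=STATE_DIM, action_dim=ACTION_DIM)
#   for episode in range(num_episodes):
#       state = env.reset()
#       done = False
#       while not done:
#           action = agent.select_action(state)
#           state, reward, done, info = env.step(action)
#           agent.rewards.append(reward)
#           agent.dones.append(float(done))
#       agent.update()  # one policy/value update per completed episode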