import torch
import torch.nn as nn
from torch.distributions import Normal
import numpy as np


class SharedActorCritic(nn.Module):
    """Shared feature trunk with a Gaussian policy head (mean and log-std) and a scalar value head."""

    def __init__(self, state_dim, action_dim):
        super(SharedActorCritic, self).__init__()
        self.feature_extractor = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU()
        )
        # Actor head outputs 2 * action_dim values: the first half is the mean, the second half the log-std.
        self.actor_head = nn.Linear(128, action_dim * 2)
        self.critic_head = nn.Linear(128, 1)

    def forward(self, state):
        features = self.feature_extractor(state)
        action_params = self.actor_head(features)
        mean, log_std = torch.chunk(action_params, 2, dim=-1)
        value = self.critic_head(features)
        return mean, log_std, value


class PGAgent:
    """Actor-critic policy-gradient agent using GAE(lambda) advantages and one update per collected rollout."""

    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.95, gae_lambda=0.95, critic_loss_coef=0.5):
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.critic_loss_coef = critic_loss_coef
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SharedActorCritic(state_dim, action_dim).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        # Rollout buffers: log_probs and values are filled by select_action();
        # rewards and dones must be appended by the training loop after each environment step.
        self.log_probs = []
        self.rewards = []
        self.values = []
        self.dones = []
        # Bounds applied to the predicted log-std to keep the action distribution numerically stable.
        self.log_std_min = -20
        self.log_std_max = 2

    def select_action(self, state):
        """Sample an action from the current policy and record its log-prob and value estimate."""
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        mean, log_std, value = self.model(state_tensor)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        std = torch.exp(log_std)
        dist = Normal(mean, std)
        action = dist.sample()
        log_prob = dist.log_prob(action).sum(dim=-1)
        self.log_probs.append(log_prob)
        self.values.append(value)
        # Note: the log-prob stored above is that of the raw sample; the returned action is clipped to [0, 1].
        return np.clip(action.squeeze(0).cpu().detach().numpy(), 0.0, 1.0)

    def update(self):
        """Run one actor-critic update over the buffered rollout, then clear the buffers."""
        if not self.rewards:
            return
        # The rollout is assumed to end at a terminal state, so the bootstrap value is zero.
        next_value = 0
        values = torch.cat(self.values).squeeze(-1).detach().cpu().numpy()
        advantages, returns = self._calculate_gae_advantages(self.rewards, values, self.dones, next_value)
        log_probs = torch.cat(self.log_probs)
        advantages = torch.tensor(advantages, dtype=torch.float32, device=self.device)
        returns = torch.tensor(returns, dtype=torch.float32, device=self.device)
        # Normalize advantages to zero mean and unit variance for a lower-variance policy gradient.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        actor_loss = -(log_probs * advantages).mean()
        critic_values = torch.cat(self.values).squeeze(-1)
        critic_loss = nn.MSELoss()(critic_values, returns)
        total_loss = actor_loss + self.critic_loss_coef * critic_loss
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
        self.optimizer.step()
        # Clear the rollout buffers for the next episode.
        self.rewards = []
        self.log_probs = []
        self.values = []
        self.dones = []

    def _calculate_gae_advantages(self, rewards, values, dones, next_value):
        """Compute GAE(lambda) advantages and returns by iterating backward over the rollout."""
        advantages = np.zeros_like(rewards, dtype=np.float32)
        last_advantage = 0
        for t in reversed(range(len(rewards))):
            # mask zeroes out bootstrapping across episode boundaries.
            mask = 1.0 - dones[t]
            v_next = values[t + 1] if t < len(rewards) - 1 else next_value
            delta = rewards[t] + self.gamma * v_next * mask - values[t]
            last_advantage = delta + self.gamma * self.gae_lambda * last_advantage * mask
            advantages[t] = last_advantage
        returns = advantages + values
        return advantages, returns

    def save(self, path):
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, path)

    def load(self, path):
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
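

# --- Usage sketch (not part of the agent) -----------------------------------
# A minimal episode loop showing how the buffers the agent does not fill itself
# (rewards, dones) are populated, and when update() is called. This is a sketch
# under assumptions: `make_env()` is a hypothetical factory returning an
# environment with a Gymnasium-style reset()/step() API whose actions lie in
# [0, 1]; the loop length and logging interval are arbitrary.
if __name__ == "__main__":
    env = make_env()  # hypothetical; replace with your own environment
    agent = PGAgent(
        state_dim=env.observation_space.shape[0],
        action_dim=env.action_space.shape[0],
    )
    for episode in range(500):
        state, _ = env.reset()
        done = False
        episode_return = 0.0
        while not done:
            action = agent.select_action(state)      # records log_prob and value internally
            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.rewards.append(reward)             # the loop fills the reward buffer
            agent.dones.append(float(done))          # and the done flags used by the GAE mask
            episode_return += reward
        agent.update()                               # one policy/value update per episode
        if episode % 50 == 0:
            print(f"episode {episode}: return {episode_return:.1f}")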