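"""Shared-backbone actor-critic policy-gradient agent with GAE(lambda).

A single MLP trunk feeds an actor head (per-dimension Gaussian mean and
log-std) and a critic head (scalar state value); updates take one
policy-gradient step per buffered episode.
"""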
import torch
import torch.nn as nn
from torch.distributions import Normal
import numpy as np


class SharedActorCritic(nn.Module):
    """MLP trunk shared by the actor and critic heads."""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
        )
        # Actor emits a mean and a log-std for each action dimension.
        self.actor_head = nn.Linear(128, action_dim * 2)
        # Critic emits a scalar state-value estimate.
        self.critic_head = nn.Linear(128, 1)

    def forward(self, state):
        features = self.feature_extractor(state)
        action_params = self.actor_head(features)
        mean, log_std = torch.chunk(action_params, 2, dim=-1)
        value = self.critic_head(features)
        return mean, log_std, value
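
# Illustrative shape check (not in the original file): with state_dim=4 and
# action_dim=2, a batch of 8 states produces mean and log_std of shape [8, 2]
# and value of shape [8, 1].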


class PGAgent:
    """Policy-gradient agent with a shared actor-critic and GAE(lambda)."""

    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.95,
                 gae_lambda=0.95, critic_loss_coef=0.5):
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.critic_loss_coef = critic_loss_coef
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SharedActorCritic(state_dim, action_dim).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        # Per-episode rollout buffers: select_action() appends log-probs and
        # values, the caller appends rewards and dones, update() clears all four.
        self.log_probs = []
        self.rewards = []
        self.values = []
        self.dones = []
        # Bounds that keep log-std (and hence exp(log_std)) numerically sane.
        self.log_std_min = -20
        self.log_std_max = 2

    def select_action(self, state):
        state_tensor = torch.as_tensor(state, dtype=torch.float32,
                                       device=self.device).unsqueeze(0)
        mean, log_std, value = self.model(state_tensor)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        std = torch.exp(log_std)
        dist = Normal(mean, std)
        action = dist.sample()
        # Joint log-prob of the action vector (sum over independent dimensions).
        log_prob = dist.log_prob(action).sum(dim=-1)
        self.log_probs.append(log_prob)
        self.values.append(value)
        # The log-prob above is taken for the unclipped sample; clipping only
        # bounds what is sent to the environment.
        return np.clip(action.squeeze(0).detach().cpu().numpy(), 0.0, 1.0)

    def update(self):
        """Run one actor-critic gradient step on the buffered episode."""
        if not self.rewards:
            return
        # A bootstrap value of 0 assumes the buffered trajectory ends in a
        # terminal state; a truncated rollout would need V(s_T) here instead.
        next_value = 0
        values = torch.cat(self.values).view(-1).detach().cpu().numpy()
        advantages, returns = self._calculate_gae_advantages(
            self.rewards, values, self.dones, next_value)
        log_probs = torch.cat(self.log_probs)
        advantages = torch.tensor(advantages, dtype=torch.float32, device=self.device)
        returns = torch.tensor(returns, dtype=torch.float32, device=self.device)
        # Normalize advantages for a better-conditioned policy gradient
        # (the std of a single sample is NaN, hence the length guard).
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        actor_loss = -(log_probs * advantages).mean()
        # Re-concatenate with the graph attached so the critic receives gradients.
        critic_values = torch.cat(self.values).view(-1)
        critic_loss = nn.MSELoss()(critic_values, returns)
        total_loss = actor_loss + self.critic_loss_coef * critic_loss
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
        self.optimizer.step()
        # Clear the rollout buffers for the next episode.
        self.rewards = []
        self.log_probs = []
        self.values = []
        self.dones = []

    def _calculate_gae_advantages(self, rewards, values, dones, next_value):
        """Compute GAE(lambda) advantages and value targets, backward in time:

            delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
            A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
        """
        advantages = np.zeros(len(rewards), dtype=np.float32)
        last_advantage = 0.0
        for t in reversed(range(len(rewards))):
            mask = 1.0 - float(dones[t])
            v_next = values[t + 1] if t < len(rewards) - 1 else next_value
            delta = rewards[t] + self.gamma * v_next * mask - values[t]
            last_advantage = delta + self.gamma * self.gae_lambda * last_advantage * mask
            advantages[t] = last_advantage
        returns = advantages + values
        return advantages, returns

    def save(self, path):
        """Checkpoint the model and optimizer state."""
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, path)

    def load(self, path):
        """Restore a checkpoint saved by save(), mapped onto the current device."""
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
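

# Usage sketch (not part of the original file): the caller is responsible for
# appending rewards and done flags after every environment step. The random
# transitions below are a hypothetical stand-in for a real environment so the
# snippet runs on its own.
if __name__ == "__main__":
    state_dim, action_dim = 4, 2
    agent = PGAgent(state_dim, action_dim)
    for episode in range(3):
        state = np.random.randn(state_dim).astype(np.float32)  # dummy reset
        for t in range(50):
            action = agent.select_action(state)  # also buffers log-prob/value
            state = np.random.randn(state_dim).astype(np.float32)  # dummy step
            reward = -float(np.sum(action ** 2))  # toy reward signal
            done = t == 49
            agent.rewards.append(reward)
            agent.dones.append(done)
            if done:
                break
        agent.update()  # one gradient step per buffered episode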