import torch
import torch.nn as nn
from torch.distributions import Normal
import numpy as np

class SharedActorCritic(nn.Module):
    """Actor-critic network with a shared feature trunk.

    The actor head outputs the mean and log-std of a diagonal Gaussian
    policy; the critic head outputs a scalar state-value estimate.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU()
        )
        # One output per action dimension for the mean, one for the log-std.
        self.actor_head = nn.Linear(128, action_dim * 2)
        self.critic_head = nn.Linear(128, 1)

    def forward(self, state):
        features = self.feature_extractor(state)
        action_params = self.actor_head(features)
        mean, log_std = torch.chunk(action_params, 2, dim=-1)
        value = self.critic_head(features)
        return mean, log_std, value

class PGAgent:
    """On-policy policy-gradient agent with GAE advantages.

    `log_probs` and `values` are recorded inside select_action; `rewards`
    and `dones` must be appended by the training loop after each env step.
    """

    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.95, gae_lambda=0.95, critic_loss_coef=0.5):
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.critic_loss_coef = critic_loss_coef
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SharedActorCritic(state_dim, action_dim).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        # Rollout buffers, cleared after every update.
        self.log_probs = []
        self.rewards = []
        self.values = []
        self.dones = []
        # Clamp range for the policy log-std to keep exploration bounded.
        self.log_std_min = -20
        self.log_std_max = 2

    def select_action(self, state):
        state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        mean, log_std, value = self.model(state_tensor)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        std = torch.exp(log_std)
        dist = Normal(mean, std)
        action = dist.sample()
        # Sum per-dimension log-probs: the policy factorizes over action dims.
        log_prob = dist.log_prob(action).sum(dim=-1)
        self.log_probs.append(log_prob)
        self.values.append(value)
        # Note: the stored log-prob is computed before clipping, so it does
        # not account for the clip to [0, 1].
        return np.clip(action.squeeze(0).cpu().numpy(), 0.0, 1.0)

    def update(self):
        if not self.rewards:
            return
        # Bootstrap value of 0: assumes the rollout ends at a terminal state.
        next_value = 0
        # squeeze(-1) rather than squeeze() so a single-step rollout keeps
        # shape [1] instead of collapsing to a 0-dim tensor.
        values = torch.cat(self.values).squeeze(-1).detach().cpu().numpy()
        advantages, returns = self._calculate_gae_advantages(self.rewards, values, self.dones, next_value)
        log_probs = torch.cat(self.log_probs)
        advantages = torch.tensor(advantages, dtype=torch.float32, device=self.device)
        returns = torch.tensor(returns, dtype=torch.float32, device=self.device)
        # Normalize advantages to reduce gradient variance.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        actor_loss = -(log_probs * advantages).mean()
        critic_values = torch.cat(self.values).squeeze(-1)
        critic_loss = nn.MSELoss()(critic_values, returns)
        total_loss = actor_loss + self.critic_loss_coef * critic_loss
        self.optimizer.zero_grad()
        total_loss.backward()
        # Clip the gradient norm to stabilize training.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
        self.optimizer.step()
        # Clear rollout buffers for the next batch of experience.
        self.rewards = []
        self.log_probs = []
        self.values = []
        self.dones = []

    def _calculate_gae_advantages(self, rewards, values, dones, next_value):
        """Generalized Advantage Estimation, computed backwards in time:

            delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
            A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
        """
        advantages = np.zeros(len(rewards), dtype=np.float32)
        last_advantage = 0.0
        for t in reversed(range(len(rewards))):
            mask = 1.0 - dones[t]  # zero out bootstrapping across episode ends
            v_next = values[t + 1] if t < len(rewards) - 1 else next_value
            delta = rewards[t] + self.gamma * v_next * mask - values[t]
            last_advantage = delta + self.gamma * self.gae_lambda * last_advantage * mask
            advantages[t] = last_advantage
        # Value targets for the critic: advantage plus the baseline.
        returns = advantages + values
        return advantages, returns

    def save(self, path):
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, path)

    def load(self, path):
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
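
# --- Usage sketch ----------------------------------------------------------
# A minimal training loop, assuming a Gymnasium-style environment whose
# action space lies in [0, 1] (matching the clip in select_action). The env
# id "MyEnv-v0" and the episode count are placeholders, not part of the
# agent above. Note that rewards and dones are appended by the caller,
# since select_action only records log-probs and values.

if __name__ == "__main__":
    import gymnasium as gym

    env = gym.make("MyEnv-v0")  # placeholder id; substitute a real env
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    agent = PGAgent(state_dim, action_dim)

    for episode in range(10):
        state, _ = env.reset()
        done = False
        while not done:
            action = agent.select_action(state)
            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.rewards.append(reward)
            agent.dones.append(float(done))
        agent.update()  # one policy-gradient step per episode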