File size: 4,803 Bytes
1eefeba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
import torch.nn.functional as F
from replay_memory import ReplayMemory
from network import Twin_Q_net, GaussianPolicy
from temporary_buffer import TemporaryBuffer
from utils import hard_update, soft_update


class BPQLAgent:  # SAC for the base learning algorithm
    """BPQL agent that uses Soft Actor-Critic (SAC) as its base learner.

    The policy is conditioned on *augmented* states, and the actor network,
    replay memory and temporary buffers are all sized by
    ``obs_delayed_steps + act_delayed_steps``. The twin critic (the "beta"
    Q-values) is evaluated on plain, non-augmented states.
    """

    def __init__(self, args, state_dim, action_dim, action_bound, action_space, device):
        """Build networks, optimizers, buffers and the entropy temperature.

        Args:
            args: Config namespace. Fields read here include the delay lengths,
                buffer/batch sizes, ``gamma``, ``tau``, ``hidden_dims``,
                learning rates, and ``automating_temperature`` /
                ``temperature`` / ``temperature_lr``.
            state_dim: Dimension of a single (non-augmented) state.
            action_dim: Dimension of an action.
            action_bound: Passed through to ``GaussianPolicy``.
            action_space: Only its ``.shape`` is used (for the target entropy).
            device: Torch device on which networks and tensors are placed.
        """
        self.args = args

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.action_bound = action_bound

        self.device = device
        # Total delay length; it sizes the replay memory, both temporary
        # buffers and the policy input.
        total_delay = args.obs_delayed_steps + args.act_delayed_steps
        self.replay_memory = ReplayMemory(total_delay, state_dim, action_dim, device, args.buffer_size)
        self.temporary_buffer = TemporaryBuffer(total_delay)
        self.eval_temporary_buffer = TemporaryBuffer(total_delay)
        self.batch_size = args.batch_size

        self.gamma = args.gamma
        self.tau = args.tau

        self.actor = GaussianPolicy(args, total_delay, state_dim, action_dim, action_bound, args.hidden_dims, F.relu, device).to(device)
        self.critic = Twin_Q_net(state_dim, action_dim, device, args.hidden_dims).to(device)  # Network for the beta Q-values.
        self.target_critic = Twin_Q_net(state_dim, action_dim, device, args.hidden_dims).to(device)

        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=args.actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=args.critic_lr)

        # Automated entropy adjustment for maximum-entropy RL.
        if args.automating_temperature is True:
            # Target entropy = -(product of the action-space shape).
            self.target_entropy = -torch.prod(torch.Tensor(action_space.shape)).to(device)
            self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
            self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=args.temperature_lr)
        else:
            # Fixed temperature: log_alpha is a constant and no alpha optimizer exists.
            self.log_alpha = torch.log(torch.tensor(args.temperature, device=device, dtype=torch.float32))

        # NOTE(review): presumably copies critic weights into the target
        # network; verify the argument order of utils.hard_update.
        hard_update(self.critic, self.target_critic)

    def get_action(self, state, evaluation=True):
        """Return the action for the first element of ``state`` as a NumPy array.

        With ``evaluation=True`` the third output of ``actor.sample`` is used
        (presumably the deterministic/mean action; confirm in GaussianPolicy).
        Otherwise the first output is used (presumably a sampled action).
        No gradients are recorded.
        """
        with torch.no_grad():
            if evaluation:
                _, _, action = self.actor.sample(state)
            else:
                action, _, _ = self.actor.sample(state)
        return action.cpu().numpy()[0]

    def train_actor(self, augmented_states, states, train_alpha=True):
        """Take one policy step and, if possible, one temperature step.

        The policy samples actions from ``augmented_states``; the critic scores
        them on the matching plain ``states`` (min of the twin heads).

        Args:
            augmented_states: Batch fed to the actor.
            states: Batch fed to the critic.
            train_alpha: Request a temperature update. It is applied only when
                automatic temperature tuning was enabled in ``__init__``;
                otherwise it is skipped and the returned alpha loss is 0.

        Returns:
            ``(actor_loss, alpha_loss)`` as Python floats.
        """
        self.actor_optimizer.zero_grad()
        actions, log_pis, _ = self.actor.sample(augmented_states)
        q_values_A, q_values_B = self.critic(states, actions)
        q_values = torch.min(q_values_A, q_values_B)

        # The temperature is a constant w.r.t. the policy update.
        actor_loss = (self.log_alpha.exp().detach() * log_pis - q_values).mean()
        # This also accumulates .grad into critic parameters; train_critic
        # zeroes them before its own optimizer step.
        actor_loss.backward()
        self.actor_optimizer.step()

        # alpha_optimizer/target_entropy exist only with automated temperature;
        # skip instead of raising AttributeError when they are absent.
        alpha_optimizer = getattr(self, "alpha_optimizer", None)
        if train_alpha and alpha_optimizer is not None:
            alpha_optimizer.zero_grad()
            alpha_loss = -(self.log_alpha.exp() * (log_pis + self.target_entropy).detach()).mean()
            alpha_loss.backward()
            alpha_optimizer.step()
        else:
            alpha_loss = torch.tensor(0.)

        return actor_loss.item(), alpha_loss.item()

    def train_critic(self, actions, rewards, next_augmented_states, dones,  states, next_states):
        """Take one critic step toward the soft Bellman target.

        Next actions come from the actor on ``next_augmented_states``. They are
        scored by the target critic on ``next_states`` using the minimum of the
        twin heads minus ``alpha * log_pi``. The target is
        ``rewards + (1 - dones) * gamma * next_q``.

        Returns:
            The sum of both heads' mean squared TD errors, as a Python float.
        """
        self.critic_optimizer.zero_grad()
        with torch.no_grad():
            next_actions, next_log_pis, _ = self.actor.sample(next_augmented_states)
            next_q_values_A, next_q_values_B = self.target_critic(next_states, next_actions)
            next_q_values = torch.min(next_q_values_A, next_q_values_B) - self.log_alpha.exp() * next_log_pis
            target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        q_values_A, q_values_B = self.critic(states, actions)
        critic_loss = ((q_values_A - target_q_values)**2).mean() + ((q_values_B - target_q_values)**2).mean()

        critic_loss.backward()
        self.critic_optimizer.step()

        return critic_loss.item()  # MSE(head A) + MSE(head B)

    def train(self):
        """Run one full update: critic, actor (+ temperature), then target sync.

        Returns:
            ``(critic_loss, actor_loss, alpha_loss)`` as Python floats;
            ``alpha_loss`` is 0 when the temperature is fixed.
        """
        (augmented_states, actions, rewards, next_augmented_states,
         dones, states, next_states) = self.replay_memory.sample(self.batch_size)

        critic_loss = self.train_critic(actions, rewards, next_augmented_states, dones, states, next_states)
        actor_loss, log_alpha_loss = self.train_actor(
            augmented_states, states, train_alpha=self.args.automating_temperature is True)

        soft_update(self.critic, self.target_critic, self.tau)

        return critic_loss, actor_loss, log_alpha_loss