import copy
import torch
import torch.nn.functional as F
from SGVLB import SGVLB
from network import Net, Critic


class BPDAgent:
    """TD3-style twin-critic agent whose actor is trained with a stochastic
    gradient variational lower bound (SGVLB) loss: a behavior-cloning term
    plus a KL penalty weighted by ``kl_weight``."""

    def __init__(
            self,
            env,
            args,
            env_info,
            thresholds,
            datasize,
            device,
            discount,
            tau,
            noise_clip,
            policy_freq,
            h,
            num_teacher_param,
    ):
        self.args = args
        self.env = env
        self.env_info = env_info

        # Variational (student) actor and its frozen target copy.
        self.actor = Net(env_info['state_dim'], env_info['action_dim'], env_info['action_bound'],
                         args.student_hidden_dims, thresholds['ALPHA_THRESHOLD'], thresholds['THETA_THRESHOLD'],
                         device=device).to(device)
        self.actor_target = copy.deepcopy(self.actor)
        self.sgvlb = SGVLB(self.actor, datasize, loss_type='l2', device=device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

        # Twin Q-networks (clipped double-Q) and their target copies.
        self.critic = Critic(env_info['state_dim'], env_info['action_dim']).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

        self.discount = discount        # reward discount factor (gamma)
        self.tau = tau                  # Polyak coefficient for soft target updates
        self.noise_clip = noise_clip    # kept for TD3 compatibility; no target noise is applied in train()
        self.policy_freq = policy_freq  # delayed policy updates: one actor step per `policy_freq` critic steps
        self.datasize = datasize
        self.h = h                      # scaling factor in lambda = h * |D| / avg(|Q|)

        self.total_it = 0
        self.kl_weight = 0  # KL annealing weight; set externally via set_kl_weight()

    def set_kl_weight(self, kl_weight):
        """Set the weight of the KL term in the SGVLB loss (e.g. for annealing)."""
        self.kl_weight = kl_weight

    def test(self):
        """Run evaluation episodes with the current actor and return
        (average, max, min) episode returns."""
        self.actor.eval()
        with torch.no_grad():
            return_list = []
            for _ in range(self.args.num_test_epi):
                episode_return = 0
                done = False
                state, _ = self.env.reset()
                while not done:
                    action = self.actor(state)
                    action = action.cpu().numpy()[0]
                    next_state, reward, terminated, truncated, _ = self.env.step(action)
                    done = terminated or truncated
                    episode_return += reward
                    state = next_state
                return_list.append(episode_return)

        avg_return = sum(return_list) / len(return_list)
        max_return = max(return_list)
        min_return = min(return_list)

        return avg_return, max_return, min_return

    def train(self, transition):
        """One TD3-style update: always update the critic; update the actor
        and the target networks every `policy_freq` iterations."""
        self.actor.train()

        self.total_it += 1

        states, actions, rewards, next_states, dones = transition

        with torch.no_grad():
            # Target action from the target actor, clamped to the valid range.
            # Unlike vanilla TD3, no clipped smoothing noise is added here.
            next_actions = (
                    self.actor_target(next_states)
            ).clamp(self.env_info['action_bound'][0], self.env_info['action_bound'][1])

            # Clipped double-Q target: element-wise minimum of the twin critics.
            target_Q1, target_Q2 = self.critic_target(next_states, next_actions)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = rewards + (1 - dones) * self.discount * target_Q

        # Regress both critics onto the shared target.
        current_Q1, current_Q2 = self.critic(states, actions)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Delayed policy update (TD3).
        if self.total_it % self.policy_freq == 0:
            pi = self.actor(states)
            Q = self.critic.Q1(states, pi)
            # Adaptive trade-off coefficient, lambda = h * |D| / avg(|Q|),
            # which normalizes the Q-maximization term against the SGVLB term.
            lmbda = (self.h * self.datasize) / Q.abs().mean().detach()

            # Maximize Q while matching dataset actions under the variational bound.
            actor_loss = -lmbda * Q.mean() + self.sgvlb(pi, actions, self.kl_weight)

            # Optimize the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Polyak-average the frozen target networks toward the online networks.
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def __del__(self):
        del self.actor
        del self.actor_target
        del self.critic
        del self.critic_target
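

# --- Usage sketch (illustrative only, not part of the agent). It assumes a
# gymnasium-style env and a user-supplied replay buffer whose sample() returns
# (states, actions, rewards, next_states, dones) tensors. The env id, hidden
# dims, threshold values, h, and the linear KL schedule below are all
# placeholder assumptions, not values prescribed by this repository.
if __name__ == '__main__':
    import argparse
    import gymnasium as gym

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    env = gym.make('Pendulum-v1')  # assumed continuous-control env
    env_info = {
        'state_dim': env.observation_space.shape[0],
        'action_dim': env.action_space.shape[0],
        'action_bound': (float(env.action_space.low[0]), float(env.action_space.high[0])),
    }
    args = argparse.Namespace(student_hidden_dims=[64, 64], num_test_epi=10)
    thresholds = {'ALPHA_THRESHOLD': 3.0, 'THETA_THRESHOLD': 1e-3}  # placeholder values
    datasize = 100_000  # assumed number of transitions in the offline dataset

    agent = BPDAgent(env, args, env_info, thresholds, datasize, device,
                     discount=0.99, tau=0.005, noise_clip=0.5, policy_freq=2,
                     h=2.5, num_teacher_param=None)  # num_teacher_param is unused by __init__

    # Hypothetical training loop with linear KL annealing from 0 to 1.
    num_updates = 10_000
    for it in range(num_updates):
        agent.set_kl_weight(min(1.0, it / num_updates))
        # transition = replay_buffer.sample(batch_size)  # user-supplied buffer
        # agent.train(transition)
    # avg_ret, max_ret, min_ret = agent.test()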