import copy

import torch
import torch.nn.functional as F

from SGVLB import SGVLB
from network import Net, Critic


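# Agent with a TD3-style twin-critic update and a delayed actor update in which the
# student policy (Net) is trained against a Q term plus an SGVLB regularizer; see train().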
class BPDAgent(object):
    def __init__(
        self,
        env,
        args,
        env_info,
        thresholds,
        datasize,
        device,
        discount,
        tau,
        noise_clip,
        policy_freq,
        h,
        num_teacher_param,
    ):
        self.args = args
        self.env = env
        self.env_info = env_info

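        # Student actor (Net) with ALPHA/THETA sparsity thresholds; the SGVLB objective built
        # over it is added to the actor loss in train() (apparently an l2 fit to the dataset
        # actions plus a KL term scaled by self.kl_weight, judging from loss_type='l2' and
        # the arguments it receives).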
        self.actor = Net(env_info['state_dim'], env_info['action_dim'], env_info['action_bound'],
                         args.student_hidden_dims, thresholds['ALPHA_THRESHOLD'], thresholds['THETA_THRESHOLD'],
                         device=device).to(device)
        self.actor_target = copy.deepcopy(self.actor)
        self.sgvlb = SGVLB(self.actor, datasize, loss_type='l2', device=device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

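        # Twin critics (Q1, Q2) and their target copy, as in TD3; both heads share one optimizer.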
        self.critic = Critic(env_info['state_dim'], env_info['action_dim']).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

        self.discount = discount
        self.tau = tau
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq
        self.datasize = datasize
        self.h = h

        self.total_it = 0
        self.kl_weight = 0

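    # Setter for the KL annealing weight used inside the SGVLB actor loss; the schedule
    # itself is expected to live in the outer training loop.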
    def set_kl_weight(self, kl_weight):
        self.kl_weight = kl_weight
        return

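    # Deterministic evaluation: rolls out the actor for args.num_test_epi - 1 episodes
    # (note the range starts at 1) and reports mean/max/min undiscounted return.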
    def test(self):
        self.actor.eval()
        with torch.no_grad():
            return_list = []
            for epi_cnt in range(1, self.args.num_test_epi):
                episode_return = 0
                done = False
                state, _ = self.env.reset()
                while not done:
                    action = self.actor(state)
                    action = action.cpu().numpy()[0]
                    next_state, reward, terminated, truncated, _ = self.env.step(action)
                    done = terminated or truncated
                    episode_return += reward
                    state = next_state
                return_list.append(episode_return)

            avg_return = sum(return_list) / len(return_list)
            max_return = max(return_list)
            min_return = min(return_list)

        return avg_return, max_return, min_return

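    # One update step on a batch of (state, action, reward, next_state, done) tensors.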
    def train(self, transition):
        self.actor.train()

        self.total_it += 1

        states, actions, rewards, next_states, dones = transition

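        # Bellman target with clipped double-Q: the target actor's next action is clamped to
        # the action bounds (no target-policy smoothing noise is added here, so
        # self.noise_clip is currently unused).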
        with torch.no_grad():
            next_actions = (
                self.actor_target(next_states)
            ).clamp(self.env_info['action_bound'][0], self.env_info['action_bound'][1])

            target_Q1, target_Q2 = self.critic_target(next_states, next_actions)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = rewards + (1 - dones) * self.discount * target_Q

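        # Critic update: regress both Q heads onto the shared target with an MSE loss.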
        current_Q1, current_Q2 = self.critic(states, actions)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

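        # Delayed actor update, TD3+BC style: maximize Q1 on dataset states, rescaled by
        # lmbda = h * datasize / E|Q|, plus the SGVLB term over (pi, actions) whose KL
        # component is weighted by self.kl_weight.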
        if self.total_it % self.policy_freq == 0:
            pi = self.actor(states)
            Q = self.critic.Q1(states, pi)
            lmbda = (self.h * self.datasize) / Q.abs().mean().detach()

            actor_loss = -lmbda * Q.mean() + self.sgvlb(pi, actions, self.kl_weight)

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

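            # Polyak-averaged soft update of the critic and actor targets (placed on the
            # delayed-update schedule, matching the reference TD3 layout).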
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def __del__(self):
        del self.actor
        del self.actor_target
        del self.critic
        del self.critic_target
        return