import logging

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class PolicyNet(torch.nn.Module):
    """Single-hidden-layer deterministic policy with a tanh-bounded, rescaled output."""

    def __init__(self, state_dim, hidden_dim, action_dim, action_bound):
        super().__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)
        self.action_bound = action_bound

    def forward(self, x):
        x = F.relu(self.fc1(x))
        # tanh bounds the raw output to [-1, 1]; scaling maps it to [-action_bound, action_bound].
        return torch.tanh(self.fc2(x)) * self.action_bound


class RMMPolicyNet(torch.nn.Module):
    """Two-head policy: a sigmoid-bounded action a1 in [0, 1], then a tanh-bounded
    action a2 in [-1, 1] conditioned on both the state and a1."""

    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, action_dim),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, action_dim),
        )

    def forward(self, x):
        a1 = torch.sigmoid(self.fc1(x))
        # The second head sees the state concatenated with the first action.
        x = torch.cat([x, a1], dim=1)
        a2 = torch.tanh(self.fc2(x))
        return torch.cat([a1, a2], dim=1)


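# A minimal shape check (hypothetical sizes, not from the original code): with
# action_dim = 1 the forward pass returns [a1, a2] concatenated along dim=1,
# so the full action vector is twice action_dim wide.
#
#     net = RMMPolicyNet(state_dim=4, hidden_dim=64, action_dim=1)
#     assert net(torch.zeros(8, 4)).shape == (8, 2)

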
class QValueNet(torch.nn.Module):
    """State-action value network: Q(s, a) computed from the concatenated state and action."""

    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x, a):
        cat = torch.cat([x, a], dim=1)
        x = F.relu(self.fc1(cat))
        return self.fc2(x)


class TwoLayerFC(torch.nn.Module):
    """Generic two-hidden-layer MLP with configurable activation and output transform."""

    def __init__(
        self, num_in, num_out, hidden_dim, activation=F.relu, out_fn=lambda x: x
    ):
        super().__init__()
        self.fc1 = nn.Linear(num_in, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_out)
        self.activation = activation
        self.out_fn = out_fn

    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        return self.out_fn(self.fc3(x))


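# TwoLayerFC doubles as a bounded continuous actor when out_fn rescales a tanh;
# a hypothetical instantiation, mirroring the out_fn built in DDPG.__init__ below:
#
#     actor = TwoLayerFC(num_in=4, num_out=2, hidden_dim=64,
#                        out_fn=lambda x: torch.tanh(x) * 2.0)

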
class DDPG:
    """Deep Deterministic Policy Gradient (DDPG) agent with target networks,
    soft (Polyak) updates, and Gaussian exploration noise."""

    def __init__(
        self,
        num_in_actor,
        num_out_actor,
        num_in_critic,
        hidden_dim,
        discrete,
        action_bound,
        sigma,
        actor_lr,
        critic_lr,
        tau,
        gamma,
        device,
        use_rmm=True,
    ):
        # For continuous actions, bound the actor output with tanh and rescale.
        out_fn = (lambda x: x) if discrete else (lambda x: torch.tanh(x) * action_bound)

        if use_rmm:
            self.actor = RMMPolicyNet(num_in_actor, hidden_dim, num_out_actor).to(device)
            self.target_actor = RMMPolicyNet(num_in_actor, hidden_dim, num_out_actor).to(device)
        else:
            self.actor = TwoLayerFC(
                num_in_actor, num_out_actor, hidden_dim, activation=F.relu, out_fn=out_fn
            ).to(device)
            self.target_actor = TwoLayerFC(
                num_in_actor, num_out_actor, hidden_dim, activation=F.relu, out_fn=out_fn
            ).to(device)

        self.critic = TwoLayerFC(num_in_critic, 1, hidden_dim).to(device)
        self.target_critic = TwoLayerFC(num_in_critic, 1, hidden_dim).to(device)
        # Start the target networks as exact copies of the online networks.
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.target_actor.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.gamma = gamma
        self.sigma = sigma  # std of the Gaussian exploration noise
        self.action_bound = action_bound
        self.tau = tau  # soft-update rate for the target networks
        self.action_dim = num_out_actor
        self.device = device

    def take_action(self, state):
        state = torch.tensor(np.expand_dims(state, 0), dtype=torch.float).to(self.device)
        action = self.actor(state)[0].detach().cpu().numpy()
        # Add Gaussian exploration noise, then clip each dimension back to the
        # range its activation implies. Note: this clipping assumes a 2-dim
        # action (a sigmoid head in [0, 1] followed by a tanh head in [-1, 1]),
        # as RMMPolicyNet produces with num_out_actor == 1.
        action = action + self.sigma * np.random.randn(self.action_dim)
        action[0] = np.clip(action[0], 0, 1)
        action[1] = np.clip(action[1], -1, 1)
        return action

    def save_state_dict(self, name):
        dicts = {
            "critic": self.critic.state_dict(),
            "target_critic": self.target_critic.state_dict(),
            "actor": self.actor.state_dict(),
            "target_actor": self.target_actor.state_dict(),
        }
        torch.save(dicts, name)

    def load_state_dict(self, name):
        dicts = torch.load(name)
        self.critic.load_state_dict(dicts["critic"])
        self.target_critic.load_state_dict(dicts["target_critic"])
        self.actor.load_state_dict(dicts["actor"])
        self.target_actor.load_state_dict(dicts["target_actor"])

    def soft_update(self, net, target_net):
        # Polyak averaging: theta_target <- (1 - tau) * theta_target + tau * theta
        for param_target, param in zip(target_net.parameters(), net.parameters()):
            param_target.data.copy_(
                param_target.data * (1.0 - self.tau) + param.data * self.tau
            )

    def update(self, transition_dict):
        states = torch.tensor(transition_dict["states"], dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict["actions"], dtype=torch.float).to(self.device)
        rewards = (
            torch.tensor(transition_dict["rewards"], dtype=torch.float)
            .view(-1, 1)
            .to(self.device)
        )
        next_states = torch.tensor(
            transition_dict["next_states"], dtype=torch.float
        ).to(self.device)
        dones = (
            torch.tensor(transition_dict["dones"], dtype=torch.float)
            .view(-1, 1)
            .to(self.device)
        )

        # Critic update: regress Q(s, a) toward the one-step TD target computed
        # with the target networks; (1 - dones) zeroes the bootstrap at episode
        # ends. no_grad avoids building a graph through the target networks.
        with torch.no_grad():
            next_q_values = self.target_critic(
                torch.cat([next_states, self.target_actor(next_states)], dim=1)
            )
            q_targets = rewards + self.gamma * next_q_values * (1 - dones)
        # F.mse_loss already reduces to a mean over the batch.
        critic_loss = F.mse_loss(
            self.critic(torch.cat([states, actions], dim=1)), q_targets
        )
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Actor update: ascend the critic's value of the actor's own actions,
        # i.e. minimize -Q(s, pi(s)).
        actor_loss = -torch.mean(
            self.critic(torch.cat([states, self.actor(states)], dim=1))
        )
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        logging.info(
            f"update DDPG: actor loss {actor_loss.item():.3f}, "
            f"critic loss {critic_loss.item():.3f}"
        )

        self.soft_update(self.actor, self.target_actor)
        self.soft_update(self.critic, self.target_critic)
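

# A minimal usage sketch (hypothetical dimensions and hyperparameters, not part
# of the original training code): builds the RMM agent and samples one action.
# With use_rmm=True and num_out_actor=1 the actor emits a 2-dim action, so the
# critic input width is state_dim + 2.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    agent = DDPG(
        num_in_actor=4,       # state_dim (assumed)
        num_out_actor=1,      # per-head width; the RMM actor outputs 2 dims total
        num_in_critic=4 + 2,  # state_dim + full action width
        hidden_dim=64,
        discrete=False,
        action_bound=1.0,
        sigma=0.1,
        actor_lr=3e-4,
        critic_lr=3e-3,
        tau=0.005,
        gamma=0.99,
        device=device,
        use_rmm=True,
    )
    print(agent.take_action(np.zeros(4, dtype=np.float32)))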