# BPQL / network.py
# Uploaded by jangwon-kim-cocel ("Upload 14 files", revision 1eefeba).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from utils import weight_init
class Twin_Q_net(nn.Module):
    """Two independent MLP critics ("A" and "B") over the concatenated (state, action).

    Both heads share the same architecture but have separate parameters; ``forward``
    returns both estimates and leaves it to the caller to combine them.
    """

    def __init__(self, state_dim, action_dim, device, hidden_dims=(256, 256), activation_fc=F.relu):
        """Build both critic heads and apply the project's ``weight_init`` to every submodule.

        Args:
            state_dim: Number of state features.
            action_dim: Number of action features.
            device: Device used by ``_format`` when converting non-tensor inputs.
            hidden_dims: Hidden-layer widths; ``hidden_dims[0]`` is the input layer's
                output width and ``hidden_dims[-1]`` feeds the scalar output layer.
            activation_fc: Activation applied after the input layer and every hidden layer.
        """
        super(Twin_Q_net, self).__init__()
        self.device = device
        self.activation_fc = activation_fc
        # Consecutive (in, out) widths for the hidden layers; a list so both heads can reuse it.
        layer_pairs = list(zip(hidden_dims[:-1], hidden_dims[1:]))

        # Head A. Layers are created in the same order as before (A then B,
        # input -> hidden -> output) so parameter registration order is unchanged.
        self.input_layer_A = nn.Linear(state_dim + action_dim, hidden_dims[0])
        self.hidden_layers_A = nn.ModuleList([nn.Linear(d_in, d_out) for d_in, d_out in layer_pairs])
        self.output_layer_A = nn.Linear(hidden_dims[-1], 1)

        # Head B: same shapes as head A, independent parameters.
        self.input_layer_B = nn.Linear(state_dim + action_dim, hidden_dims[0])
        self.hidden_layers_B = nn.ModuleList([nn.Linear(d_in, d_out) for d_in, d_out in layer_pairs])
        self.output_layer_B = nn.Linear(hidden_dims[-1], 1)

        self.apply(weight_init)

    def _as_batch(self, value):
        """Convert a non-tensor input to a float32 tensor on ``self.device`` with a batch axis.

        Only inputs with fewer than two dimensions (a single sample) get a leading
        batch axis; already-batched inputs keep their shape. Tensor inputs are
        returned unchanged (not moved to ``self.device`` and not cast).
        """
        if isinstance(value, torch.Tensor):
            return value
        tensor = torch.tensor(value, device=self.device, dtype=torch.float32)
        # Unconditionally unsqueezing would turn a (B, d) minibatch into (1, B, d),
        # and the dim=1 concatenation in forward would then run over the batch axis.
        if tensor.dim() < 2:
            tensor = tensor.unsqueeze(0)
        return tensor

    def _format(self, state, action):
        """Return ``(state, action)`` converted by ``_as_batch``."""
        return self._as_batch(state), self._as_batch(action)

    def _run_head(self, x, input_layer, hidden_layers, output_layer):
        """Apply one critic head: activated input and hidden layers, then a linear output."""
        x = self.activation_fc(input_layer(x))
        for layer in hidden_layers:
            x = self.activation_fc(layer(x))
        return output_layer(x)

    def forward(self, state, action):
        """Evaluate both critics on the concatenated (state, action) input.

        Args:
            state: State(s); tensors are used as given, other array-likes go through ``_format``.
            action: Action(s), handled the same way as ``state``.

        Returns:
            Tuple ``(q_A, q_B)``. For 2-D inputs each has shape ``(batch, 1)``.
        """
        x, u = self._format(state, action)
        xu = torch.cat([x, u], dim=1)
        q_a = self._run_head(xu, self.input_layer_A, self.hidden_layers_A, self.output_layer_A)
        q_b = self._run_head(xu, self.input_layer_B, self.hidden_layers_B, self.output_layer_B)
        return q_a, q_b
class GaussianPolicy(nn.Module):
    """Tanh-squashed diagonal Gaussian policy.

    The input layer takes ``state_dim + delayed_steps * action_dim`` features;
    presumably the state concatenated with the ``delayed_steps`` most recent
    actions (TODO confirm against the caller).
    """

    def __init__(self, args, delayed_steps, state_dim, action_dim, action_bound,
                 hidden_dims=(256, 256), activation_fc=F.relu, device='cuda'):
        """Build the policy MLP and the action-rescaling constants.

        Args:
            args: Config object; ``args.log_std_bound[0]`` / ``[1]`` are the clamp
                bounds for the predicted log standard deviation.
            delayed_steps: Number of ``action_dim``-sized blocks appended to the state input.
            state_dim: Number of state features.
            action_dim: Number of action dimensions.
            action_bound: Pair ``(low, high)`` of the action range (scalars or per-dimension arrays).
            hidden_dims: Hidden-layer widths.
            activation_fc: Activation applied after the input layer and every hidden layer.
            device: Device used by ``_format`` when converting non-tensor inputs.
        """
        super(GaussianPolicy, self).__init__()
        self.device = device
        self.log_std_min = args.log_std_bound[0]
        self.log_std_max = args.log_std_bound[1]
        self.activation_fc = activation_fc

        self.input_layer = nn.Linear(state_dim + delayed_steps * action_dim, hidden_dims[0])
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(hidden_dims[:-1], hidden_dims[1:])]
        )
        self.mean_layer = nn.Linear(hidden_dims[-1], action_dim)
        self.log_std_layer = nn.Linear(hidden_dims[-1], action_dim)

        # Affine map from tanh's (-1, 1) onto [low, high]. These are non-persistent
        # buffers: Module.to()/cuda() moves them with the parameters (a plain tensor
        # attribute would stay on the CPU and cause a device mismatch in `sample`),
        # and they stay out of state_dict, so existing checkpoints still load.
        self.register_buffer(
            'action_rescale',
            torch.as_tensor((action_bound[1] - action_bound[0]) / 2., dtype=torch.float32),
            persistent=False,
        )
        self.register_buffer(
            'action_rescale_bias',
            torch.as_tensor((action_bound[1] + action_bound[0]) / 2., dtype=torch.float32),
            persistent=False,
        )
        self.apply(weight_init)

    def _format(self, state):
        """Convert a non-tensor input to a float32 tensor on ``self.device`` with a batch axis.

        Only inputs with fewer than two dimensions (a single sample) get a leading
        batch axis; already-batched inputs keep their shape. Tensor inputs are
        returned unchanged (not moved to ``self.device`` and not cast).
        """
        x = state
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=self.device, dtype=torch.float32)
            # Unconditionally unsqueezing would turn a (B, d) minibatch into (1, B, d),
            # and `sample`'s sum over dim=1 would then reduce the batch axis.
            if x.dim() < 2:
                x = x.unsqueeze(0)
        return x

    def forward(self, state):
        """Return ``(mean, log_std)`` of the pre-squash Gaussian.

        ``log_std`` is clamped to ``[log_std_min, log_std_max]``.
        """
        x = self._format(state)
        x = self.activation_fc(self.input_layer(x))
        for layer in self.hidden_layers:
            x = self.activation_fc(layer(x))
        mean = self.mean_layer(x)
        log_std = torch.clamp(self.log_std_layer(x), self.log_std_min, self.log_std_max)
        return mean, log_std

    def sample(self, state):
        """Draw a reparameterized action and its log-probability.

        The pre-squash sample ``u`` is drawn from ``Normal(mean, exp(log_std))`` with
        ``rsample`` (so gradients flow through it) and mapped to
        ``action = tanh(u) * action_rescale + action_rescale_bias``.

        Returns:
            action: The sampled action in the rescaled range.
            log_prob: Log-density of ``action``, summed over dim 1 with ``keepdim=True``,
                including the change-of-variables correction for tanh and the rescale.
            mean: Deterministic action ``tanh(mean) * action_rescale + action_rescale_bias``.
        """
        mean, log_std = self.forward(state)
        distribution = Normal(mean, log_std.exp())
        unbounded_action = distribution.rsample()
        bounded_action = torch.tanh(unbounded_action)
        action = bounded_action * self.action_rescale + self.action_rescale_bias
        # log|da/du| = log(rescale * (1 - tanh(u)^2)). The clamp is defensive against
        # rounding, and the 1e-6 keeps the log finite when tanh saturates.
        log_prob = distribution.log_prob(unbounded_action) - torch.log(
            self.action_rescale * (1 - bounded_action.pow(2).clamp(0, 1)) + 1e-6
        )
        log_prob = log_prob.sum(dim=1, keepdim=True)
        mean = torch.tanh(mean) * self.action_rescale + self.action_rescale_bias
        return action, log_prob, mean