| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.distributions import Normal | |
| from utils import weight_init | |
class Twin_Q_net(nn.Module):
    """Two independent Q-value MLPs ("A" and "B") evaluated on the same input.

    Both critics take the concatenation of a state and an action, share the
    same architecture (but not parameters) and each emit one scalar per row.

    Args:
        state_dim: Size of the state vector.
        action_dim: Size of the action vector.
        device: Device used only when converting non-tensor inputs to
            tensors; tensor inputs are used as given.
        hidden_dims: Widths of the hidden layers; must contain at least one
            entry.
        activation_fc: Activation applied after the input layer and after
            every hidden layer (not after the output layer).
    """

    def __init__(self, state_dim, action_dim, device, hidden_dims=(256, 256), activation_fc=F.relu):
        super().__init__()
        self.device = device
        self.activation_fc = activation_fc

        # Critic A is created and registered entirely before critic B so the
        # parameter creation order and submodule order stay unchanged.
        (self.input_layer_A,
         self.hidden_layers_A,
         self.output_layer_A) = self._build_critic(state_dim + action_dim, hidden_dims)
        (self.input_layer_B,
         self.hidden_layers_B,
         self.output_layer_B) = self._build_critic(state_dim + action_dim, hidden_dims)

        # weight_init comes from the project's utils module; it is applied to
        # every submodule (and to this module itself) by nn.Module.apply.
        self.apply(weight_init)

    @staticmethod
    def _build_critic(input_dim, hidden_dims):
        """Create the input layer, hidden layers and scalar output layer of one critic."""
        input_layer = nn.Linear(input_dim, hidden_dims[0])
        hidden_layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(hidden_dims[:-1], hidden_dims[1:])]
        )
        output_layer = nn.Linear(hidden_dims[-1], 1)
        return input_layer, hidden_layers, output_layer

    def _as_batch(self, value):
        """Return ``value`` as a tensor carrying a leading batch axis."""
        if not isinstance(value, torch.Tensor):
            # Non-tensor input is always treated as a single sample.
            value = torch.tensor(value, device=self.device, dtype=torch.float32)
            return value.unsqueeze(0)
        if value.dim() == 1:
            # An unbatched tensor gets the same treatment as unbatched
            # non-tensor input instead of failing in torch.cat(dim=1).
            return value.unsqueeze(0)
        return value

    def _format(self, state, action):
        """Convert ``state`` and ``action`` to batched tensors.

        Non-tensor inputs become float32 tensors on ``self.device`` with a
        batch axis prepended; 1-D tensors get a batch axis prepended; any
        other tensor is returned unchanged.
        """
        return self._as_batch(state), self._as_batch(action)

    def _run_critic(self, x, input_layer, hidden_layers, output_layer):
        """Run one critic on the already concatenated state-action input."""
        x = self.activation_fc(input_layer(x))
        for hidden_layer in hidden_layers:
            x = self.activation_fc(hidden_layer(x))
        return output_layer(x)

    def forward(self, state, action):
        """Evaluate both critics on ``(state, action)``.

        Returns:
            Tuple ``(q_A, q_B)``; each has a trailing dimension of size 1.
        """
        x, u = self._format(state, action)
        x = torch.cat([x, u], dim=1)
        x_A = self._run_critic(x, self.input_layer_A, self.hidden_layers_A, self.output_layer_A)
        x_B = self._run_critic(x, self.input_layer_B, self.hidden_layers_B, self.output_layer_B)
        return x_A, x_B
class GaussianPolicy(nn.Module):
    """MLP policy producing a tanh-squashed diagonal Gaussian over actions.

    Args:
        args: Object exposing ``log_std_bound`` as a ``(min, max)`` pair used
            to clamp the predicted log standard deviation.
        delayed_steps: Multiplier for ``action_dim`` in the input width; the
            network input has ``state_dim + delayed_steps * action_dim``
            features (presumably the state augmented with the pending
            actions - confirm against the caller).
        state_dim: Size of the state part of the input.
        action_dim: Size of the action vector.
        action_bound: ``(low, high)`` bounds of the action space; used to
            rescale the tanh output into ``[low, high]``.
        hidden_dims: Widths of the hidden layers; must contain at least one
            entry.
        activation_fc: Activation applied after the input and hidden layers.
        device: Device used only when converting non-tensor inputs to
            tensors.
    """

    def __init__(self, args, delayed_steps, state_dim, action_dim, action_bound,
                 hidden_dims=(256, 256), activation_fc=F.relu, device='cuda'):
        super().__init__()
        self.device = device

        self.log_std_min = args.log_std_bound[0]
        self.log_std_max = args.log_std_bound[1]

        self.activation_fc = activation_fc

        self.input_layer = nn.Linear(state_dim + delayed_steps * action_dim, hidden_dims[0])
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(hidden_dims[:-1], hidden_dims[1:])]
        )
        self.mean_layer = nn.Linear(hidden_dims[-1], action_dim)
        self.log_std_layer = nn.Linear(hidden_dims[-1], action_dim)

        # Registered as buffers so they follow the module through .to()/.cuda()
        # (a plain tensor attribute stays on the CPU, which breaks for
        # per-dimension bounds on a GPU). persistent=False keeps them out of
        # state_dict, so existing checkpoints still load.
        self.register_buffer(
            "action_rescale",
            torch.as_tensor((action_bound[1] - action_bound[0]) / 2., dtype=torch.float32),
            persistent=False,
        )
        self.register_buffer(
            "action_rescale_bias",
            torch.as_tensor((action_bound[1] + action_bound[0]) / 2., dtype=torch.float32),
            persistent=False,
        )

        # weight_init comes from the project's utils module; it is applied to
        # every submodule (and to this module itself) by nn.Module.apply.
        self.apply(weight_init)

    def _format(self, state):
        """Return ``state`` as a tensor.

        Non-tensor input is converted to a float32 tensor on ``self.device``
        and given a leading batch axis; tensors are returned unchanged.
        """
        x = state
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=self.device, dtype=torch.float32)
            x = x.unsqueeze(0)
        return x

    def forward(self, state):
        """Compute the Gaussian parameters for ``state``.

        Returns:
            Tuple ``(mean, log_std)`` where ``log_std`` is clamped to
            ``[log_std_min, log_std_max]``.
        """
        x = self._format(state)
        x = self.activation_fc(self.input_layer(x))
        for hidden_layer in self.hidden_layers:
            x = self.activation_fc(hidden_layer(x))
        mean = self.mean_layer(x)
        log_std = self.log_std_layer(x)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        return mean, log_std

    def sample(self, state):
        """Draw a reparameterized action and its log-probability.

        Returns:
            Tuple ``(action, log_prob, mean)``:
            ``action`` is the tanh-squashed sample rescaled to the action
            bounds, ``log_prob`` is its log-density summed over the action
            (last) axis with that axis kept as size 1, and ``mean`` is the
            squashed and rescaled distribution mean.
        """
        mean, log_std = self.forward(state)
        distribution = Normal(mean, log_std.exp())

        # rsample keeps the draw differentiable w.r.t. mean and log_std.
        unbounded_action = distribution.rsample()
        bounded_action = torch.tanh(unbounded_action)
        action = bounded_action * self.action_rescale + self.action_rescale_bias

        # Change-of-variables correction for the tanh squashing and the
        # rescaling; 1e-6 guards against log(0) when tanh saturates.
        squash_scale = self.action_rescale * (1 - bounded_action.pow(2).clamp(0, 1))
        log_prob = distribution.log_prob(unbounded_action) - torch.log(squash_scale + 1e-6)

        # The action components live on the last axis; summing there (rather
        # than on a hard-coded axis 1) also supports unbatched states.
        log_prob = log_prob.sum(dim=-1, keepdim=True)

        mean = torch.tanh(mean) * self.action_rescale + self.action_rescale_bias
        return action, log_prob, mean