import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

from utils import weight_init


class Twin_Q_net(nn.Module):
    """Two independent MLP heads (A and B) evaluated on the same input.

    Each head maps the concatenation of ``state`` and ``action`` to a single
    scalar output per row.

    Args:
        state_dim: Size of the state vector.
        action_dim: Size of the action vector.
        device: Device used when converting non-tensor inputs to tensors.
        hidden_dims: Widths of the hidden layers; one ``nn.Linear`` is built
            between every pair of consecutive widths.
        activation_fc: Activation applied after every layer except the output.
    """

    def __init__(self, state_dim, action_dim, device, hidden_dims=(256, 256),
                 activation_fc=F.relu):
        super().__init__()
        self.device = device
        self.activation_fc = activation_fc

        # Layer names and creation order are kept stable so that existing
        # checkpoints and parameter ordering remain valid.
        self.input_layer_A = nn.Linear(state_dim + action_dim, hidden_dims[0])
        self.hidden_layers_A = nn.ModuleList(
            nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(hidden_dims[:-1], hidden_dims[1:])
        )
        self.output_layer_A = nn.Linear(hidden_dims[-1], 1)

        self.input_layer_B = nn.Linear(state_dim + action_dim, hidden_dims[0])
        self.hidden_layers_B = nn.ModuleList(
            nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(hidden_dims[:-1], hidden_dims[1:])
        )
        self.output_layer_B = nn.Linear(hidden_dims[-1], 1)

        self.apply(weight_init)

    def _format(self, state, action):
        """Return ``state`` and ``action`` as tensors.

        A non-tensor input is converted to a float32 tensor on ``self.device``
        and given a leading dimension of size 1; tensors pass through
        untouched.
        """
        x, u = state, action
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=self.device, dtype=torch.float32)
            x = x.unsqueeze(0)
        if not isinstance(u, torch.Tensor):
            u = torch.tensor(u, device=self.device, dtype=torch.float32)
            u = u.unsqueeze(0)
        return x, u

    def forward(self, state, action):
        """Evaluate both heads on ``(state, action)``.

        The inputs are concatenated along dim 1, so both must be (at least)
        2-D after formatting.

        Returns:
            Tuple ``(x_A, x_B)`` with the output of head A and head B.
        """
        x, u = self._format(state, action)
        x = torch.cat([x, u], dim=1)

        x_A = self.activation_fc(self.input_layer_A(x))
        for hidden_layer_A in self.hidden_layers_A:
            x_A = self.activation_fc(hidden_layer_A(x_A))
        x_A = self.output_layer_A(x_A)

        x_B = self.activation_fc(self.input_layer_B(x))
        for hidden_layer_B in self.hidden_layers_B:
            x_B = self.activation_fc(hidden_layer_B(x_B))
        x_B = self.output_layer_B(x_B)

        return x_A, x_B


class GaussianPolicy(nn.Module):
    """MLP producing a tanh-squashed Gaussian action distribution.

    The network input has size ``state_dim + delayed_steps * action_dim``
    (presumably the state augmented with ``delayed_steps`` past actions --
    confirm against the caller).

    Args:
        args: Object exposing ``log_std_bound``, an indexable pair
            ``(min, max)`` used to clamp the predicted log-std.
        delayed_steps: Number of ``action_dim``-sized chunks appended to the
            state in the network input.
        state_dim: Size of the state vector.
        action_dim: Size of the action vector.
        action_bound: Indexable pair ``(low, high)``; the squashed action is
            rescaled to this interval. ``low``/``high`` must support
            arithmetic (scalars or arrays).
        hidden_dims: Widths of the hidden layers.
        activation_fc: Activation applied after every layer except the heads.
        device: Device used when converting non-tensor inputs to tensors.
    """

    def __init__(self, args, delayed_steps, state_dim, action_dim, action_bound,
                 hidden_dims=(256, 256), activation_fc=F.relu, device='cuda'):
        super().__init__()
        self.device = device

        self.log_std_min = args.log_std_bound[0]
        self.log_std_max = args.log_std_bound[1]

        self.activation_fc = activation_fc

        self.input_layer = nn.Linear(state_dim + delayed_steps * action_dim,
                                     hidden_dims[0])
        self.hidden_layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(hidden_dims[:-1], hidden_dims[1:])
        )
        self.mean_layer = nn.Linear(hidden_dims[-1], action_dim)
        self.log_std_layer = nn.Linear(hidden_dims[-1], action_dim)

        # Registered as buffers so they follow the module across .to()/.cuda();
        # as plain attributes they stayed on the CPU and non-scalar bounds
        # caused a device mismatch in sample(). persistent=False keeps them
        # out of state_dict(), so existing checkpoints still load.
        self.register_buffer(
            'action_rescale',
            torch.as_tensor((action_bound[1] - action_bound[0]) / 2.,
                            dtype=torch.float32),
            persistent=False,
        )
        self.register_buffer(
            'action_rescale_bias',
            torch.as_tensor((action_bound[1] + action_bound[0]) / 2.,
                            dtype=torch.float32),
            persistent=False,
        )

        self.apply(weight_init)

    def _format(self, state):
        """Return ``state`` as a tensor.

        A non-tensor input is converted to a float32 tensor on ``self.device``
        and given a leading dimension of size 1; tensors pass through
        untouched.
        """
        x = state
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=self.device, dtype=torch.float32)
            x = x.unsqueeze(0)
        return x

    def forward(self, state):
        """Compute the Gaussian parameters for ``state``.

        Returns:
            Tuple ``(mean, log_std)``; ``log_std`` is clamped to
            ``[log_std_min, log_std_max]``.
        """
        x = self._format(state)
        x = self.activation_fc(self.input_layer(x))
        for hidden_layer in self.hidden_layers:
            x = self.activation_fc(hidden_layer(x))

        mean = self.mean_layer(x)
        log_std = self.log_std_layer(x)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        return mean, log_std

    def sample(self, state):
        """Draw a reparameterized action and its log-probability.

        Returns:
            Tuple ``(action, log_prob, mean)``:

            * ``action`` -- ``tanh`` of an ``rsample`` from the Gaussian,
              rescaled with ``action_rescale`` / ``action_rescale_bias``.
            * ``log_prob`` -- Gaussian log-density of the unsquashed sample
              minus the tanh/rescale correction term, summed over dim 1 with
              ``keepdim=True``.
            * ``mean`` -- ``tanh`` of the Gaussian mean, rescaled the same
              way as ``action``.
        """
        mean, log_std = self.forward(state)
        distribution = Normal(mean, log_std.exp())

        unbounded_action = distribution.rsample()
        bounded_action = torch.tanh(unbounded_action)
        action = bounded_action * self.action_rescale + self.action_rescale_bias

        # Change-of-variables correction for the tanh squashing and rescaling;
        # the 1e-6 keeps the log finite when tanh saturates.
        squash_correction = torch.log(
            self.action_rescale * (1 - bounded_action.pow(2).clamp(0, 1)) + 1e-6
        )
        log_prob = distribution.log_prob(unbounded_action) - squash_correction
        log_prob = log_prob.sum(dim=1, keepdim=True)

        mean = torch.tanh(mean) * self.action_rescale + self.action_rescale_bias
        return action, log_prob, mean