| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from utils import weight_init, AvgL1Norm | |
class EnsembleQNet(nn.Module):
    """Ensemble of independent Q-networks evaluated on the same inputs.

    Every member projects the concatenated ``(state, action)`` with a linear
    layer and passes it through ``AvgL1Norm`` (imported from ``utils``). It
    then concatenates the ``zsa`` and ``zs`` embeddings and runs the result
    through an MLP that ends in a single output unit.

    Args:
        num_critics: Number of ensemble members.
        state_dim: Length of a state vector.
        action_dim: Length of an action vector.
        device: Device on which non-tensor inputs are materialised.
        zs_dim: Width of each of the ``zs`` and ``zsa`` embeddings.
        hidden_dims: Layer widths of each member's MLP; ``hidden_dims[0]`` is
            also the width of the state-action projection.
        activation_fc: Activation applied after every hidden layer.
    """

    def __init__(self, num_critics, state_dim, action_dim, device, zs_dim=256,
                 hidden_dims=(256, 256), activation_fc=F.elu):
        super().__init__()
        self.device = device
        self.activation_fc = activation_fc
        self.num_critics = num_critics
        self.q_nets = nn.ModuleList(
            self._build_q_net(state_dim, action_dim, zs_dim, hidden_dims)
            for _ in range(num_critics)
        )
        # Called last: nn.Module.apply only visits submodules that already exist.
        self.apply(weight_init)

    def _build_q_net(self, state_dim, action_dim, zs_dim, hidden_dims):
        """Create the layers of a single ensemble member as a ``ModuleDict``."""
        return nn.ModuleDict({
            's_input_layer': nn.Linear(state_dim + action_dim, hidden_dims[0]),
            # Input = projected (state, action) + zsa + zs embeddings.
            'emb_input_layer': nn.Linear(2 * zs_dim + hidden_dims[0], hidden_dims[0]),
            'emb_hidden_layers': nn.ModuleList([
                nn.Linear(hidden_dims[i], hidden_dims[i + 1])
                for i in range(len(hidden_dims) - 1)
            ]),
            'output_layer': nn.Linear(hidden_dims[-1], 1),
        })

    def _to_batch(self, x):
        """Return ``x`` as a float32 tensor with a leading batch dimension.

        Tensors are returned unchanged. Any other input is converted on
        ``self.device``. If the result is 0-d or 1-d it is treated as a single
        sample and reshaped to ``(1, -1)``. Higher-rank inputs are assumed to
        be batched already and keep their shape.
        """
        if isinstance(x, torch.Tensor):
            return x
        x = torch.tensor(x, device=self.device, dtype=torch.float32)
        if x.dim() < 2:
            x = x.reshape(1, -1)
        return x

    def _format(self, state, action):
        """Convert ``state`` and ``action`` into batched tensors (see ``_to_batch``)."""
        return self._to_batch(state), self._to_batch(action)

    def forward(self, state, action, zsa, zs):
        """Evaluate every ensemble member.

        Args:
            state, action: Tensors or array-likes (see ``_format``).
            zsa, zs: Embedding tensors. They are concatenated with the
                projected state-action along dim 1, so they must share its
                batch dimension.

        Returns:
            The members' outputs concatenated along dim 1. Column ``j`` holds
            member ``j``'s value, i.e. ``(batch, num_critics)`` for 2-d inputs.
        """
        s, a = self._format(state, action)
        sa = torch.cat([s, a], dim=1)
        embeddings = torch.cat([zsa, zs], dim=1)
        q_values = []
        for q_net in self.q_nets:
            q = AvgL1Norm(q_net['s_input_layer'](sa))
            q = torch.cat([q, embeddings], dim=1)
            q = self.activation_fc(q_net['emb_input_layer'](q))
            for hidden_layer in q_net['emb_hidden_layers']:
                q = self.activation_fc(hidden_layer(q))
            q_values.append(q_net['output_layer'](q))
        return torch.cat(q_values, dim=1)
class Policy(nn.Module):
    """Deterministic policy conditioned on the state and a state embedding ``zs``.

    The state is projected with a linear layer and passed through
    ``AvgL1Norm`` (imported from ``utils``). It is then concatenated with
    ``zs`` and run through an MLP. The final output is squashed by ``tanh``,
    so every action component lies in ``(-1, 1)``.

    Args:
        state_dim: Length of a state vector.
        action_dim: Length of the produced action vector.
        device: Device on which non-tensor inputs are materialised.
        zs_dim: Width of the ``zs`` embedding.
        hidden_dims: Layer widths of the MLP; ``hidden_dims[0]`` is also the
            width of the state projection.
        activation_fc: Activation applied after every hidden layer.
    """

    def __init__(self, state_dim, action_dim, device, zs_dim=256,
                 hidden_dims=(256, 256), activation_fc=F.relu):
        super().__init__()
        self.device = device
        self.activation_fc = activation_fc
        self.s_input_layer = nn.Linear(state_dim, hidden_dims[0])
        # Input = projected state + zs embedding.
        self.zss_input_layer = nn.Linear(zs_dim + hidden_dims[0], hidden_dims[0])
        self.zss_hidden_layers = nn.ModuleList(
            nn.Linear(hidden_dims[i], hidden_dims[i + 1])
            for i in range(len(hidden_dims) - 1)
        )
        self.zss_output_layer = nn.Linear(hidden_dims[-1], action_dim)
        # Must run after the layers are created: nn.Module.apply only visits
        # submodules that already exist, so calling it any earlier would
        # leave every layer with its default initialisation.
        self.apply(weight_init)

    def _format(self, state):
        """Return ``state`` as a float32 tensor with a leading batch dimension.

        Tensors are returned unchanged. Any other input is converted on
        ``self.device``. A 0-d or 1-d result is treated as a single sample and
        reshaped to ``(1, -1)``. Higher-rank inputs keep their shape.
        """
        x = state
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=self.device, dtype=torch.float32)
            if x.dim() < 2:
                x = x.reshape(1, -1)
        return x

    def forward(self, state, zs):
        """Compute actions for ``state`` given its embedding ``zs``.

        ``zs`` is concatenated with the projected state along dim 1, so it
        must share the batch dimension of the formatted ``state``.

        Returns:
            ``tanh`` of the output layer, i.e. values in ``(-1, 1)``.
        """
        state = self._format(state)
        state = AvgL1Norm(self.s_input_layer(state))
        zss = torch.cat([state, zs], 1)
        zss = self.activation_fc(self.zss_input_layer(zss))
        for hidden_layer in self.zss_hidden_layers:
            zss = self.activation_fc(hidden_layer(zss))
        zss = self.zss_output_layer(zss)
        return torch.tanh(zss)
class Encoder(nn.Module):
    """Pair of MLP encoders producing a state embedding and a state-action embedding.

    * ``zs(state)`` maps a state to an embedding of width ``zs_dim``. Its
      output is passed through ``AvgL1Norm`` (imported from ``utils``).
    * ``zsa(zs, action)`` maps a state embedding plus an action to another
      embedding of width ``zs_dim``. Its output is left un-normalised.

    NOTE(review): unlike ``EnsembleQNet``, this class never calls
    ``self.apply(weight_init)``, so its layers keep PyTorch's default
    ``nn.Linear`` initialisation — confirm this is intended.

    Args:
        state_dim: Length of a state vector.
        action_dim: Length of an action vector.
        device: Device on which non-tensor inputs are materialised.
        zs_dim: Width of both produced embeddings.
        hidden_dims: Layer widths shared by both encoders' MLPs.
        activation_fc: Activation applied after every hidden layer.
    """

    def __init__(self, state_dim, action_dim, device, zs_dim=256,
                 hidden_dims=(256, 256), activation_fc=F.elu):
        super().__init__()
        self.device = device
        self.activation_fc = activation_fc
        # State encoder: state -> zs.
        self.s_encoder_input_layer = nn.Linear(state_dim, hidden_dims[0])
        self.s_encoder_hidden_layers = self._make_hidden_layers(hidden_dims)
        self.s_encoder_output_layer = nn.Linear(hidden_dims[-1], zs_dim)
        # State-action encoder: (zs, action) -> zsa.
        self.zsa_encoder_input_layer = nn.Linear(zs_dim + action_dim, hidden_dims[0])
        self.zsa_encoder_hidden_layers = self._make_hidden_layers(hidden_dims)
        self.zsa_encoder_output_layer = nn.Linear(hidden_dims[-1], zs_dim)

    @staticmethod
    def _make_hidden_layers(hidden_dims):
        """Build the Linear layers connecting consecutive widths in ``hidden_dims``."""
        return nn.ModuleList(
            nn.Linear(hidden_dims[i], hidden_dims[i + 1])
            for i in range(len(hidden_dims) - 1)
        )

    def _format(self, state):
        """Return the input as a float32 tensor with a leading batch dimension.

        Tensors are returned unchanged. Any other input is converted on
        ``self.device``. A 0-d or 1-d result is treated as a single sample and
        reshaped to ``(1, -1)``. Higher-rank inputs keep their shape.
        """
        x = state
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=self.device, dtype=torch.float32)
            if x.dim() < 2:
                x = x.reshape(1, -1)
        return x

    def zs(self, state):
        """Encode ``state`` into the ``AvgL1Norm``-normalised embedding ``zs``."""
        state = self._format(state)
        zs = self.activation_fc(self.s_encoder_input_layer(state))
        for hidden_layer in self.s_encoder_hidden_layers:
            zs = self.activation_fc(hidden_layer(zs))
        return AvgL1Norm(self.s_encoder_output_layer(zs))

    def zsa(self, zs, action):
        """Encode the pair ``(zs, action)`` into the embedding ``zsa``.

        ``zs`` is concatenated with the formatted ``action`` along dim 1, so
        both must share the batch dimension.
        """
        action = self._format(action)
        zsa = torch.cat([zs, action], 1)
        zsa = self.activation_fc(self.zsa_encoder_input_layer(zsa))
        for hidden_layer in self.zsa_encoder_hidden_layers:
            zsa = self.activation_fc(hidden_layer(zsa))
        return self.zsa_encoder_output_layer(zsa)