| import torch | |
| import torch.nn as nn | |
class Autoencoder(nn.Module):
    """Fully connected autoencoder with L2-normalized latent and output.

    The encoder maps ``input_dim``-dimensional vectors through a stack of
    ``Linear`` layers (with ``BatchNorm1d`` + ``ReLU`` between consecutive
    layers); the decoder is a plain ``Linear``/``ReLU`` stack.  Both
    :meth:`encode` and :meth:`decode` normalize their result to unit L2 norm
    along the last dimension, so the model operates on (and produces)
    unit vectors — e.g. for cosine-similarity embedding spaces.

    Args:
        encoder_hidden_dims: Output sizes of the encoder ``Linear`` layers;
            the last entry is the latent dimension.
        decoder_hidden_dims: Output sizes of the decoder ``Linear`` layers;
            the last entry is the reconstruction dimension.
        input_dim: Dimensionality of the input vectors. Defaults to 512
            (the value previously hard-coded).

    Raises:
        ValueError: If either dimension list is empty.
    """

    def __init__(self, encoder_hidden_dims, decoder_hidden_dims, input_dim=512):
        super().__init__()
        if not encoder_hidden_dims or not decoder_hidden_dims:
            raise ValueError(
                "encoder_hidden_dims and decoder_hidden_dims must be non-empty"
            )

        # Encoder: Linear, then [BatchNorm1d -> ReLU -> Linear] per extra layer.
        encoder_layers = []
        prev_dim = input_dim
        for i, dim in enumerate(encoder_hidden_dims):
            if i > 0:
                # Normalize/activate the previous layer's output before the
                # next projection (no activation after the final layer; the
                # output is L2-normalized in encode() instead).
                encoder_layers.append(nn.BatchNorm1d(prev_dim))
                encoder_layers.append(nn.ReLU())
            encoder_layers.append(nn.Linear(prev_dim, dim))
            prev_dim = dim
        # ModuleList (not Sequential) keeps the original state-dict keys
        # (encoder.0.weight, ...) so existing checkpoints still load.
        self.encoder = nn.ModuleList(encoder_layers)

        # Decoder: Linear, then [ReLU -> Linear] per extra layer (no batchnorm,
        # matching the original architecture).
        decoder_layers = []
        prev_dim = encoder_hidden_dims[-1]
        for i, dim in enumerate(decoder_hidden_dims):
            if i > 0:
                decoder_layers.append(nn.ReLU())
            decoder_layers.append(nn.Linear(prev_dim, dim))
            prev_dim = dim
        self.decoder = nn.ModuleList(decoder_layers)

    def forward(self, x):
        """Encode then decode ``x``; the result has unit L2 norm."""
        return self.decode(self.encode(x))

    def encode(self, x):
        """Run the encoder stack and L2-normalize along the last dim."""
        for layer in self.encoder:
            x = layer(x)
        return x / x.norm(dim=-1, keepdim=True)

    def decode(self, x):
        """Run the decoder stack and L2-normalize along the last dim."""
        for layer in self.decoder:
            x = layer(x)
        return x / x.norm(dim=-1, keepdim=True)