import torch
import torch.nn as nn


class Autoencoder(nn.Module):
    """MLP autoencoder whose latent code and reconstruction are L2-normalized.

    The encoder is a stack of ``Linear`` layers; every layer after the first
    is preceded by ``BatchNorm1d`` and ``ReLU``.  The decoder is a stack of
    ``Linear`` layers separated by ``ReLU`` (no batch norm).  Both the encoder
    output and the decoder output are scaled to unit L2 norm along the last
    dimension.

    Args:
        encoder_hidden_dims: Output width of each encoder ``Linear`` layer, in
            order.  The last entry is the latent width.  May be empty, in
            which case the encoder only normalizes its input.
        decoder_hidden_dims: Output width of each decoder ``Linear`` layer, in
            order.  May be empty, in which case the decoder only normalizes
            its input.
        input_dim: Size of the last dimension of the encoder input.  Defaults
            to 512, the width that was previously hard-coded.

    NOTE(review): because of ``BatchNorm1d``, an encoder with more than one
    layer presumably expects input shaped ``(batch, input_dim)``, and in
    training mode a batch of more than one sample -- confirm against callers.
    """

    def __init__(self, encoder_hidden_dims, decoder_hidden_dims, input_dim=512):
        super().__init__()

        # Layer order and the ModuleList containers are kept exactly as
        # before so existing checkpoints (state_dict keys such as
        # "encoder.0.weight") keep loading.
        encoder_layers = []
        prev_dim = input_dim
        for i, dim in enumerate(encoder_hidden_dims):
            if i > 0:
                encoder_layers.append(nn.BatchNorm1d(prev_dim))
                encoder_layers.append(nn.ReLU())
            encoder_layers.append(nn.Linear(prev_dim, dim))
            prev_dim = dim
        self.encoder = nn.ModuleList(encoder_layers)

        # prev_dim now holds the latent width: the last encoder width, or
        # input_dim when the encoder has no layers (this combination used to
        # raise IndexError if the decoder was non-empty).
        decoder_layers = []
        for i, dim in enumerate(decoder_hidden_dims):
            if i > 0:
                decoder_layers.append(nn.ReLU())
            decoder_layers.append(nn.Linear(prev_dim, dim))
            prev_dim = dim
        self.decoder = nn.ModuleList(decoder_layers)

        # Kept from the original implementation: dumps the built layers to
        # stdout on construction.
        print(self.encoder, self.decoder)

    @staticmethod
    def _l2_normalize(x):
        """Scale ``x`` to unit L2 norm along its last dimension.

        ``normalize`` clamps the norm to a tiny epsilon before dividing, so an
        all-zero row yields zeros instead of the NaNs a plain ``x / x.norm()``
        would produce; rows with a non-zero norm are unaffected.
        """
        return nn.functional.normalize(x, p=2, dim=-1)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the encoder layers and L2-normalize the result.

        Args:
            x: Tensor whose last dimension has size ``input_dim``.

        Returns:
            The latent code, unit-normalized along the last dimension.
        """
        for layer in self.encoder:
            x = layer(x)
        return self._l2_normalize(x)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Run a latent code through the decoder layers and L2-normalize it.

        Args:
            x: Tensor whose last dimension has the latent width.

        Returns:
            The reconstruction, unit-normalized along the last dimension.
        """
        for layer in self.decoder:
            x = layer(x)
        return self._l2_normalize(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode then decode ``x``; equivalent to ``decode(encode(x))``."""
        return self.decode(self.encode(x))