Spaces:
Runtime error
Runtime error
| import torch; torch.manual_seed(0) | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.utils | |
| import torch.distributions | |
| import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200 | |
# Select the compute device once, at import time: CUDA when torch reports a
# usable GPU, the CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def get_activation(activation):
    """Return the element-wise activation function registered under a name.

    Args:
        activation: one of ``'tanh'``, ``'relu'``, ``'mish'``, ``'sigmoid'``,
            ``'leakyrelu'`` or ``'exp'``.

    Returns:
        The matching callable (``torch.nn.functional`` function or
        ``torch.exp``), to be applied to a tensor.

    Raises:
        ValueError: if ``activation`` is not a supported name.
    """
    activations = {
        'tanh': F.tanh,
        'relu': F.relu,
        'mish': F.mish,
        'sigmoid': F.sigmoid,
        'leakyrelu': F.leaky_relu,
        'exp': torch.exp,
    }
    try:
        return activations[activation]
    except (KeyError, TypeError):  # TypeError: unhashable argument
        raise ValueError(
            f"Unknown activation {activation!r}; expected one of {sorted(activations)}"
        ) from None
class IngredientEncoder(nn.Module):
    """Per-ingredient MLP (the element-wise network of a deep set).

    Maps the last dimension of its input from ``input_dim`` to
    ``deepset_latent_dim`` through ``len(hidden_dims) + 1`` linear layers.
    Every layer but the last is followed by dropout and the activation.
    """

    def __init__(self, input_dim, deepset_latent_dim, hidden_dims, activation, dropout):
        """Build the layer stack.

        Args:
            input_dim: size of one ingredient's input features.
            deepset_latent_dim: size of the produced encoding.
            hidden_dims: hidden layer sizes; any sequence of ints (list or
                tuple), possibly empty.
            activation: name understood by ``get_activation``.
            dropout: dropout probability used between layers.
        """
        super().__init__()
        self.linears = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        # Unpacking accepts tuples as well as lists (list + tuple raises TypeError).
        dims = [input_dim, *hidden_dims, deepset_latent_dim]
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            self.linears.append(nn.Linear(d_in, d_out))
            self.dropouts.append(nn.Dropout(dropout))
        self.activation = get_activation(activation)
        self.n_layers = len(self.linears)
        self.layer_range = range(self.n_layers)

    def forward(self, x):
        """Encode ``x`` (last dimension ``input_dim``) to ``deepset_latent_dim`` features."""
        last = self.n_layers - 1
        for i_layer, (layer, dropout) in enumerate(zip(self.linears, self.dropouts)):
            x = layer(x)
            # The output layer stays linear; its dropout module is never used.
            if i_layer != last:
                # NOTE(review): dropout runs *before* the activation, so in
                # training mode a dropped unit becomes activation(0) -- 0.5 for
                # sigmoid, 1 for exp -- instead of 0. Kept to preserve behaviour.
                x = self.activation(dropout(x))
        return x
class DeepsetCocktailEncoder(nn.Module):
    """Deep-set VAE encoder for cocktails with a variable number of ingredients.

    Pipeline:
      1. ``IngredientEncoder`` encodes each ingredient on its own.
      2. The encodings of each cocktail are pooled with a sum or a mean.
      3. A post-aggregation MLP processes the pooled vector.
      4. Two linear heads produce the posterior mean and log-variance.
    """

    def __init__(self, input_dim, deepset_latent_dim, hidden_dims_ing, activation,
                 hidden_dims_cocktail, latent_dim, aggregation, dropout):
        """Build the ingredient encoder, the post-aggregation MLP and the heads.

        Args:
            input_dim: width of one ingredient representation (ingredient + quantity).
            deepset_latent_dim: width of each ingredient encoding and of the pooled vector.
            hidden_dims_ing: hidden sizes of the per-ingredient encoder.
            activation: name understood by ``get_activation``.
            hidden_dims_cocktail: hidden sizes of the post-aggregation MLP; may be
                empty, in which case the heads read the pooled vector directly.
            latent_dim: size of the latent space.
            aggregation: ``'sum'`` or ``'mean'`` (checked when ``forward`` runs).
            dropout: dropout probability.
        """
        super().__init__()
        self.input_dim = input_dim  # width of one ingredient representation + quantity
        # Encodes each ingredient separately.
        self.ingredient_encoder = IngredientEncoder(input_dim, deepset_latent_dim, hidden_dims_ing, activation, dropout)
        self.deepset_latent_dim = deepset_latent_dim  # width of the pooled deep-set vector
        self.aggregation = aggregation
        self.latent_dim = latent_dim
        # Post-aggregation network.
        self.linears = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        dims = [deepset_latent_dim, *hidden_dims_cocktail]
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            self.linears.append(nn.Linear(d_in, d_out))
            self.dropouts.append(nn.Dropout(dropout))
        # dims[-1] is the last hidden width, or the pooled width when there are
        # no hidden layers (indexing hidden_dims_cocktail[-1] crashed then).
        self.FC_mean = nn.Linear(dims[-1], latent_dim)
        self.FC_logvar = nn.Linear(dims[-1], latent_dim)
        self.softplus = nn.Softplus()  # not used by forward; kept for attribute compatibility
        self.activation = get_activation(activation)
        self.n_layers = len(self.linears)
        self.layer_range = range(self.n_layers)

    def forward(self, nb_ingredients, x):
        """Encode a batch of cocktails into posterior parameters.

        Args:
            nb_ingredients: per-cocktail ingredient counts, one per row of ``x``;
                an integer tensor or any sequence of ints.
            x: 2-D tensor whose row ``i`` starts with the ``nb_ingredients[i]``
                ingredient representations of cocktail ``i`` laid out back to
                back, each ``input_dim`` wide. Columns after them are ignored.

        Returns:
            ``(mean, logvar)``, each of shape ``(batch_size, latent_dim)``.

        Raises:
            ValueError: if ``self.aggregation`` is neither ``'sum'`` nor ``'mean'``,
                if the number of counts differs from the number of rows, or if a
                count is negative.
        """
        # Validate the configuration before doing any encoding work.
        if self.aggregation == 'sum':
            pool = torch.sum
        elif self.aggregation == 'mean':
            pool = torch.mean
        else:
            raise ValueError(f"Unknown aggregation {self.aggregation!r}; expected 'sum' or 'mean'")

        batch_size = x.shape[0]
        counts = [int(count) for count in nb_ingredients]
        if len(counts) != batch_size:
            raise ValueError(f'nb_ingredients has {len(counts)} entries but x has {batch_size} rows')
        if any(count < 0 for count in counts):
            raise ValueError('nb_ingredients must be non-negative')

        # Flatten to (total ingredients, input_dim): row i contributes its first
        # counts[i] * input_dim columns, one ingredient per output row.
        ingredients = torch.cat(
            [row[:count * self.input_dim].reshape(count, self.input_dim)
             for row, count in zip(x, counts)],
            dim=0,
        )
        # Encode all ingredients of the batch in one pass.
        ingredients_encodings = self.ingredient_encoder(ingredients)
        assert ingredients_encodings.shape == (sum(counts), self.deepset_latent_dim)

        # Each cocktail owns a contiguous run of rows; pool each run into one
        # vector. A cocktail with 0 ingredients pools to zeros (sum) or NaN (mean).
        pooled = torch.stack(
            [pool(chunk, dim=0) for chunk in torch.split(ingredients_encodings, counts)]
        )

        h = pooled
        for layer, dropout in zip(self.linears, self.dropouts):
            h = self.activation(dropout(layer(h)))
        return self.FC_mean(h), self.FC_logvar(h)
class Decoder(nn.Module):
    """MLP mapping latent vectors to ``num_ingredients`` output values.

    Every layer but the last is followed by dropout and the activation. When
    requested, the optional ``filter_output`` callable is applied to the output.
    """

    def __init__(self, latent_dim, hidden_dims, num_ingredients, activation, dropout, filter_output=None):
        """Build the layer stack.

        Args:
            latent_dim: input size.
            hidden_dims: hidden layer sizes; any sequence of ints, possibly empty.
            num_ingredients: output size.
            activation: name understood by ``get_activation``.
            dropout: dropout probability used between layers.
            filter_output: optional callable applied by ``forward(..., to_filter=True)``.
        """
        super().__init__()
        self.linears = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        dims = [latent_dim, *hidden_dims, num_ingredients]
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            self.linears.append(nn.Linear(d_in, d_out))
            self.dropouts.append(nn.Dropout(dropout))
        self.activation = get_activation(activation)
        self.n_layers = len(self.linears)
        self.layer_range = range(self.n_layers)
        self.filter = filter_output

    def forward(self, x, to_filter=False):
        """Decode ``x``; apply ``filter_output`` to the result if ``to_filter`` is true.

        Raises:
            ValueError: if ``to_filter`` is true but no ``filter_output`` was given.
        """
        last = self.n_layers - 1
        for i_layer, (layer, dropout) in enumerate(zip(self.linears, self.dropouts)):
            x = layer(x)
            # The output layer stays linear.
            if i_layer != last:
                x = self.activation(dropout(x))
        if to_filter:
            if self.filter is None:
                raise ValueError('to_filter=True requires the Decoder to be built with filter_output')
            x = self.filter(x)
        return x
class PredictorHead(nn.Module):
    """A single linear layer, optionally followed by a named activation."""

    def __init__(self, latent_dim, dim_output, final_activ):
        """Build the head.

        Args:
            latent_dim: input size.
            dim_output: output size.
            final_activ: name understood by ``get_activation``, or ``None`` to
                keep the output linear.
        """
        super().__init__()
        self.linear = nn.Linear(latent_dim, dim_output)
        self.use_final_activ = final_activ is not None
        # Always define the attribute (None when unused) so it can be inspected safely.
        self.final_activ = get_activation(final_activ) if self.use_final_activ else None

    def forward(self, x):
        """Apply the linear layer, then the final activation if one was configured."""
        x = self.linear(x)
        if self.use_final_activ:
            x = self.final_activ(x)
        return x
class VAEModel(nn.Module):
    """VAE combining an encoder, a decoder and auxiliary predictor heads.

    Every key of ``auxiliaries_dict`` gets a ``PredictorHead`` that reads the
    latent sample ``z``. The ``'taste_reps'`` head is stored on its own as
    ``taste_reps_decoder``. All other heads live in ``auxiliaries``, in the same
    order as their names in ``auxiliaries_str`` (sorted keys).
    """

    def __init__(self, encoder, decoder, auxiliaries_dict):
        """Assemble the model.

        Args:
            encoder: module exposing ``latent_dim`` and returning ``(mean, logvar)``.
            decoder: module called as ``decoder(z)`` / ``decoder(z, to_filter=...)``.
            auxiliaries_dict: maps an auxiliary name to a dict with keys
                ``'dim_output'`` and ``'final_activ'`` (see ``PredictorHead``).
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.latent_dim = self.encoder.latent_dim
        self.auxiliaries_str = []
        self.auxiliaries = nn.ModuleList()
        # Sorted so head order (and seeded parameter initialisation) is deterministic.
        for aux_str in sorted(auxiliaries_dict):
            aux_cfg = auxiliaries_dict[aux_str]
            head = PredictorHead(self.latent_dim, aux_cfg['dim_output'], aux_cfg['final_activ'])
            if aux_str == 'taste_reps':
                self.taste_reps_decoder = head
            else:
                self.auxiliaries_str.append(aux_str)
                self.auxiliaries.append(head)

    def reparameterization(self, mean, logvar):
        """Return ``mean + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        # randn_like already matches std's device and dtype. The former
        # `.to(device)` (module-level device) caused a cross-device error for
        # CPU-resident models on machines where CUDA is available.
        epsilon = torch.randn_like(std)
        return mean + std * epsilon

    def sample(self, n=1):
        """Decode ``n`` latent vectors drawn from the standard normal prior."""
        param = next(self.parameters(), None)
        if param is None:
            z = torch.randn(size=(n, self.latent_dim))
        else:
            # Draw on the model's device/dtype so GPU or non-float32 models work.
            z = torch.randn(size=(n, self.latent_dim), device=param.device, dtype=param.dtype)
        return self.decoder(z)

    def get_all_auxiliaries(self, x):
        """Run every non-taste auxiliary head on ``x`` (order of ``auxiliaries_str``)."""
        return [aux(x) for aux in self.auxiliaries]

    def get_auxiliary(self, z, aux_str):
        """Run the single auxiliary head named ``aux_str`` on ``z``."""
        if aux_str == 'taste_reps':
            return self.taste_reps_decoder(z)
        index = self.auxiliaries_str.index(aux_str)
        return self.auxiliaries[index](z)

    def forward_direct(self, x, aux_str=None, to_filter=False):
        """Encode, sample and decode ``x``.

        Returns:
            ``(x_hat, z, mean, logvar, aux_out, aux_names)``. When ``aux_str`` is
            given, ``aux_out`` is that head's output and ``aux_names == [aux_str]``.
            Otherwise ``aux_out`` is the list of all non-taste head outputs and
            ``aux_names`` is ``auxiliaries_str``.
        """
        mean, logvar = self.encoder(x)
        z = self.reparameterization(mean, logvar)
        # NOTE(review): the reconstruction decodes `mean`, not the sample `z`;
        # only the auxiliary heads see `z`. Confirm this is intended.
        x_hat = self.decoder(mean, to_filter=to_filter)
        if aux_str is not None:
            return x_hat, z, mean, logvar, self.get_auxiliary(z, aux_str), [aux_str]
        return x_hat, z, mean, logvar, self.get_all_auxiliaries(z), self.auxiliaries_str

    def forward(self, nb_ingredients, x, aux_str=None, to_filter=False):
        """Deep-set entry point (encoder called with ingredient counts); disabled.

        This path was previously switched off with ``assert False``, which is
        stripped under ``python -O`` and then ran an incompatible encoder call.

        Raises:
            NotImplementedError: always; use ``forward_direct`` instead.
        """
        raise NotImplementedError(
            'VAEModel.forward (deep-set path) is disabled; call forward_direct(x, ...) instead'
        )
class SimpleEncoder(nn.Module):
    """MLP VAE encoder: hidden layers, then linear heads for mean and log-variance."""

    def __init__(self, input_dim, hidden_dims, latent_dim, activation, dropout):
        """Build the encoder.

        Args:
            input_dim: size of the input vectors.
            hidden_dims: hidden layer sizes; any sequence of ints, possibly empty
                (then the heads read the input directly).
            latent_dim: size of the latent space.
            activation: name understood by ``get_activation``.
            dropout: dropout probability applied after each hidden linear layer.
        """
        super().__init__()
        self.latent_dim = latent_dim
        self.linears = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        dims = [input_dim, *hidden_dims]
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            self.linears.append(nn.Linear(d_in, d_out))
            self.dropouts.append(nn.Dropout(dropout))
        # dims[-1] is the last hidden width, or input_dim when hidden_dims is
        # empty (indexing hidden_dims[-1] raised IndexError in that case).
        self.FC_mean = nn.Linear(dims[-1], latent_dim)
        self.FC_logvar = nn.Linear(dims[-1], latent_dim)
        self.activation = get_activation(activation)
        self.n_layers = len(self.linears)
        self.layer_range = range(self.n_layers)

    def forward(self, x):
        """Return ``(mean, logvar)``, each with last dimension ``latent_dim``."""
        for layer, dropout in zip(self.linears, self.dropouts):
            x = self.activation(dropout(layer(x)))
        return self.FC_mean(x), self.FC_logvar(x)
def get_vae_model(input_dim, deepset_latent_dim, hidden_dims_ing, activation,
                  hidden_dims_cocktail, hidden_dims_decoder, num_ingredients, latent_dim, aggregation, dropout,
                  auxiliaries_dict, filter_decoder_output):
    """Assemble the cocktail VAE.

    Components:
      - a ``SimpleEncoder`` over ``num_ingredients``-wide inputs, with hidden
        sizes ``hidden_dims_cocktail``;
      - a ``Decoder`` back to ``num_ingredients`` outputs, with hidden sizes
        ``hidden_dims_decoder``;
      - the auxiliary heads described by ``auxiliaries_dict``.

    ``input_dim``, ``deepset_latent_dim``, ``hidden_dims_ing`` and ``aggregation``
    only apply to ``DeepsetCocktailEncoder``. They are accepted for signature
    compatibility but are not used here.
    """
    # Keyword arguments are evaluated left to right, so the encoder is still
    # built (and seeded-initialised) before the decoder.
    return VAEModel(
        encoder=SimpleEncoder(
            input_dim=num_ingredients,
            hidden_dims=hidden_dims_cocktail,
            latent_dim=latent_dim,
            activation=activation,
            dropout=dropout,
        ),
        decoder=Decoder(
            latent_dim=latent_dim,
            hidden_dims=hidden_dims_decoder,
            num_ingredients=num_ingredients,
            activation=activation,
            dropout=dropout,
            filter_output=filter_decoder_output,
        ),
        auxiliaries_dict=auxiliaries_dict,
    )