Spaces:
Sleeping
Sleeping
| """vae.py module.""" | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import List | |
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder is a stack of ``Linear -> ReLU`` layers whose widths follow
    ``hidden_dims``. Two linear heads map the encoder output to the latent
    mean and the latent log-variance. The decoder mirrors the encoder
    (``hidden_dims`` in reverse order) and ends with a plain ``Linear`` layer
    back to ``input_dim`` (no output activation is applied).
    """

    def __init__(self, input_dim: int, latent_dim: int, hidden_dims: List[int]):
        """Build the encoder, the latent heads and the decoder.

        Args:
            input_dim: Number of features of each input sample.
            latent_dim: Number of features of the latent code.
            hidden_dims: Widths of the encoder's hidden layers, in order.
                The decoder uses the same widths reversed. The sequence is
                only read, never modified. An empty sequence yields an
                encoder with no layers and a single-layer decoder.
        """
        super().__init__()

        # Work on private copies: reversing ``hidden_dims`` in place would
        # silently change the caller's list (and flip the architecture of any
        # model built from the same list afterwards).
        encoder_dims = list(hidden_dims)
        decoder_dims = encoder_dims[::-1]

        # Encoder
        layers = []
        in_features = input_dim
        for h_dim in encoder_dims:
            layers.append(nn.Linear(in_features, h_dim))
            layers.append(nn.ReLU())
            in_features = h_dim
        self.encoder = nn.Sequential(*layers)

        # Latent space. ``in_features`` is now the encoder's output width
        # (``input_dim`` itself when there are no hidden layers).
        self.fc_mu = nn.Linear(in_features, latent_dim)
        self.fc_var = nn.Linear(in_features, latent_dim)

        # Decoder
        layers = []
        in_features = latent_dim
        for h_dim in decoder_dims:
            layers.append(nn.Linear(in_features, h_dim))
            layers.append(nn.ReLU())
            in_features = h_dim
        layers.append(nn.Linear(in_features, input_dim))
        self.decoder = nn.Sequential(*layers)

    def encode(self, x: torch.Tensor):
        """Return the ``(fc_mu, fc_var)`` head outputs for ``x``.

        ``forward`` treats the second value as a log-variance.
        """
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_var(h)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Map a latent code ``z`` through the decoder."""
        return self.decoder(z)

    def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """Sample ``mu + eps * std`` with ``eps`` drawn from a standard normal.

        ``std`` is computed as ``exp(0.5 * log_var)``. Writing the sample this
        way keeps it differentiable with respect to ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x: torch.Tensor):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            A tuple ``(reconstruction, mu, log_var)``.
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var