Spaces:
Sleeping
Sleeping
| """ | |
| Module containing the main VAE class. | |
| """ | |
| import torch | |
| from torch import nn, optim | |
| from torch.nn import functional as F | |
| from disvae.utils.initialization import weights_init | |
| from .encoders import get_encoder | |
| from .decoders import get_decoder | |
MODELS = ["Burgess"]


def init_specific_model(model_type, img_size, latent_dim):
    """Return an instance of a VAE with encoder and decoder from `model_type`.

    Parameters
    ----------
    model_type : str
        Name of the architecture; matched case-insensitively against `MODELS`.
    img_size : tuple of ints
        Size of images, e.g. (1, 32, 32) or (3, 64, 64).
    latent_dim : int
        Dimensionality of the latent space.

    Raises
    ------
    ValueError
        If `model_type` is not one of `MODELS`.
    """
    # Normalize to the canonical capitalization used in MODELS
    # (e.g. "burgess" or "BURGESS" -> "Burgess").
    model_type = model_type.lower().capitalize()
    if model_type not in MODELS:
        # Fixed typo in the error message: "Unkown" -> "Unknown".
        err = "Unknown model_type={}. Possible values: {}"
        raise ValueError(err.format(model_type, MODELS))

    encoder = get_encoder(model_type)
    decoder = get_decoder(model_type)
    model = VAE(img_size, encoder, decoder, latent_dim)
    model.model_type = model_type  # store to help reloading
    return model
class VAE(nn.Module):
    """Variational autoencoder combining a given encoder and decoder.

    Parameters
    ----------
    img_size : tuple of ints
        Size of images. E.g. (1, 32, 32) or (3, 64, 64).
    encoder : callable
        Constructor called as ``encoder(img_size, latent_dim)``; returns the
        encoder module.
    decoder : callable
        Constructor called as ``decoder(img_size, latent_dim)``; returns the
        decoder module.
    latent_dim : int
        Dimensionality of the latent space.
    """

    def __init__(self, img_size, encoder, decoder, latent_dim):
        super(VAE, self).__init__()

        # Only two spatial resolutions are supported by the bundled
        # encoder/decoder architectures.
        if list(img_size[1:]) not in ([32, 32], [64, 64]):
            raise RuntimeError("{} sized images not supported. Only (None, 32, 32) and (None, 64, 64) supported. Build your own architecture or reshape images!".format(img_size))

        self.latent_dim = latent_dim
        self.img_size = img_size
        self.num_pixels = self.img_size[1] * self.img_size[2]
        self.encoder = encoder(img_size, self.latent_dim)
        self.decoder = decoder(img_size, self.latent_dim)

        self.reset_parameters()

    def reparameterize(self, mean, logvar):
        """Sample from a diagonal Gaussian via the reparameterization trick.

        Parameters
        ----------
        mean : torch.Tensor
            Mean of the normal distribution. Shape (batch_size, latent_dim)
        logvar : torch.Tensor
            Diagonal log variance of the normal distribution. Shape
            (batch_size, latent_dim)
        """
        if not self.training:
            # Evaluation / reconstruction mode: deterministic, use the mean.
            return mean
        std = (0.5 * logvar).exp()
        noise = torch.randn_like(std)
        return noise * std + mean

    def forward(self, x):
        """Encode `x`, sample a latent, and decode the reconstruction.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (batch_size, n_chan, height, width)
        """
        latent_dist = self.encoder(x)
        sample = self.reparameterize(*latent_dist)
        return self.decoder(sample), latent_dist, sample

    def reset_parameters(self):
        """Re-initialize every submodule's weights with `weights_init`."""
        self.apply(weights_init)

    def sample_latent(self, x):
        """Return one latent sample per datapoint in `x`.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (batch_size, n_chan, height, width)
        """
        dist_params = self.encoder(x)
        return self.reparameterize(*dist_params)