Spaces:
Runtime error
Runtime error
| import torch.nn as nn | |
| from torch.distributions.normal import Normal | |
| from .constants import * | |
class Encoder(nn.Module):
    '''
    Encoder Class
    Maps an image to the parameters (mean, standard deviation) of a diagonal
    Normal distribution over the latent space.
    Values:
        im_chan: the number of channels of the input image, a scalar
        output_chan: the dimension of the latent vector z, a scalar; the final
            convolution emits 2 * output_chan channels (means, then log-stddevs)
        hidden_dim: the base number of channels, doubled after each block, a scalar
    '''
    def __init__(self, im_chan=3, output_chan=Z_DIM, hidden_dim=ENC_HIDDEN_DIM):
        super(Encoder, self).__init__()
        self.z_dim = output_chan
        self.disc = nn.Sequential(
            self.make_disc_block(im_chan, hidden_dim),
            self.make_disc_block(hidden_dim, hidden_dim * 2),
            self.make_disc_block(hidden_dim * 2, hidden_dim * 4),
            self.make_disc_block(hidden_dim * 4, hidden_dim * 8),
            self.make_disc_block(hidden_dim * 8, output_chan * 2, final_layer=True),
        )

    def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        '''
        Function to return a sequence of operations corresponding to an encoder block of the VAE:
        an unpadded convolution followed, except in the final layer, by a batchnorm and a
        LeakyReLU(0.2) activation. The final layer is a bare convolution so its outputs
        can take any real value.
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: whether we're on the final layer (drops the batchnorm and activation)
        '''
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
            )

    def forward(self, image):
        '''
        Function for completing a forward pass of the Encoder: given an image tensor,
        returns the mean and standard deviation of the latent distribution.
        Parameters:
            image: an image tensor with dimensions (batch_size, im_chan, im_height, im_width);
                its spatial size must be one the conv stack reduces to exactly 1x1
        Returns:
            mean: a tensor with dimensions (batch_size, z_dim)
            stddev: a positive tensor with dimensions (batch_size, z_dim)
        Raises:
            ValueError: if the conv stack's output is not 1x1 spatially
        '''
        disc_pred = self.disc(image)
        # Only a 1x1 spatial output yields exactly 2 * z_dim features per sample;
        # otherwise the split below would silently mix spatial positions into the
        # mean/stddev halves.
        if disc_pred.shape[-2:] != (1, 1):
            raise ValueError(
                "Encoder expected a 1x1 spatial output from its conv stack, got "
                f"{tuple(disc_pred.shape[-2:])}; adjust the input image size"
            )
        encoding = disc_pred.view(len(disc_pred), -1)
        # The second half of the channels is interpreted as the log of the standard
        # deviation: exponentiating it (exp(x), not exp(x / 2), so this is NOT a
        # log-variance) keeps the scale positive and numerically stable.
        return encoding[:, :self.z_dim], encoding[:, self.z_dim:].exp()
class Decoder(nn.Module):
    '''
    Decoder Class
    Upsamples a latent vector into an image through a stack of transposed
    convolutions; the output passes through a sigmoid, so pixel values lie in (0, 1).
    Values:
        z_dim: the dimension of the latent (noise) vector, a scalar
        im_chan: the number of channels of the output image, a scalar
        hidden_dim: the base number of channels (the widest block uses 16x this), a scalar
    '''
    def __init__(self, z_dim=Z_DIM, im_chan=3, hidden_dim=DEC_HIDDEN_DIM):
        super().__init__()
        self.z_dim = z_dim
        width = hidden_dim
        # (in_channels, out_channels, kernel_size, stride) for each intermediate block,
        # listed in the order they are applied.
        block_specs = [
            (z_dim, width * 16, 3, 2),
            (width * 16, width * 8, 4, 1),
            (width * 8, width * 4, 3, 2),
            (width * 4, width * 2, 4, 2),
            (width * 2, width, 4, 2),
        ]
        stages = [
            self.make_gen_block(c_in, c_out, kernel_size=k, stride=s)
            for c_in, c_out, k, s in block_specs
        ]
        # Last stage maps to image channels and squashes with a sigmoid.
        stages.append(self.make_gen_block(width, im_chan, kernel_size=4, final_layer=True))
        self.gen = nn.Sequential(*stages)

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        '''
        Function to return a sequence of operations corresponding to a Decoder block of the VAE:
        an unpadded transposed convolution, then either batchnorm + ReLU (intermediate
        blocks) or a sigmoid (final block).
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the transposed convolution
            final_layer: whether we're on the final layer (sigmoid instead of batchnorm + ReLU)
        '''
        layers = [nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride)]
        if final_layer:
            layers.append(nn.Sigmoid())
        else:
            layers.extend([nn.BatchNorm2d(output_channels), nn.ReLU(inplace=True)])
        return nn.Sequential(*layers)

    def forward(self, noise):
        '''
        Function for completing a forward pass of the Decoder: given a latent vector,
        returns a generated image.
        Parameters:
            noise: a noise tensor with dimensions (batch_size, z_dim)
        Returns:
            an image tensor with dimensions (batch_size, im_chan, height, width)
        '''
        # Treat each latent vector as a z_dim-channel 1x1 "image".
        latent_map = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(latent_map)
class VAE(nn.Module):
    '''
    VAE Class
    Variational autoencoder combining an Encoder (image -> latent Normal distribution)
    and a Decoder (latent sample -> image).
    Values:
        z_dim: the dimension of the latent (noise) vector, a scalar
        im_chan: the number of channels of the input and reconstructed images, a scalar
            (3 by default, e.g. RGB)
    '''
    def __init__(self, z_dim=Z_DIM, im_chan=3):
        super(VAE, self).__init__()
        self.z_dim = z_dim
        # Keyword arguments make the mapping explicit: the Encoder's second
        # positional parameter is output_chan (the latent size), not hidden_dim.
        self.encode = Encoder(im_chan=im_chan, output_chan=z_dim)
        self.decode = Decoder(z_dim=z_dim, im_chan=im_chan)

    def forward(self, images):
        '''
        Function for completing a forward pass of the VAE: encodes the images into a
        diagonal Normal distribution over the latent space, draws one sample per image,
        and decodes that sample back into an image.
        Parameters:
            images: an image tensor with dimensions (batch_size, im_chan, im_height, im_width)
        Returns:
            decoding: the autoencoded image
            q_dist: the Normal distribution over z produced by the encoder
        '''
        q_mean, q_stddev = self.encode(images)
        q_dist = Normal(q_mean, q_stddev)
        # rsample (reparameterized sampling) keeps the sample differentiable with
        # respect to q_mean and q_stddev, so gradients flow back into the encoder.
        z_sample = q_dist.rsample()
        decoding = self.decode(z_sample)
        return decoding, q_dist