# with <3 to...
# https://github.com/AntixK/PyTorch-VAE/blob/master/models/lvae.py
# https://mbernste.github.io/posts/vae/
| import torch | |
| import torch.nn as nn | |
| import torchvision.datasets as datasets | |
| import torchvision.transforms as transforms | |
| import numpy as np | |
| import torch.optim as optim | |
| import torch.nn.functional as F | |
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3x64x64 images.

    The encoder maps an image to the mean and log-variance of a
    ``z_dim``-dimensional diagonal Gaussian; the decoder maps a latent
    sample back to an image. The final activation is ``tanh``, so
    decoder output lies in [-1, 1].
    """

    def __init__(self, z_dim=50):
        super(VAE, self).__init__()
        # Encoder: four stride-2 convolutions, each halving spatial size.
        # (Registration order matches parameter/state-dict layout.)
        self.enc_conv1 = nn.Conv2d(
            in_channels=3, out_channels=32, kernel_size=4, stride=2, padding=1
        )  # -> 32 x 32 x 32
        self.enc_conv2 = nn.Conv2d(32, 64, 4, 2, 1)    # -> 64 x 16 x 16
        self.enc_conv3 = nn.Conv2d(64, 128, 4, 2, 1)   # -> 128 x 8 x 8
        self.enc_conv4 = nn.Conv2d(128, 256, 4, 2, 1)  # -> 256 x 4 x 4
        # Heads producing the posterior parameters q(z|x).
        self.enc_fc_mu = nn.Linear(256 * 4 * 4, z_dim)
        self.enc_fc_logvar = nn.Linear(256 * 4 * 4, z_dim)
        # Decoder: expand the latent vector, then upsample with
        # transposed convolutions back to 3 x 64 x 64.
        self.dec_fc = nn.Linear(z_dim, 256 * 4 * 4)
        self.flatten = nn.Flatten()
        self.dec_conv1 = nn.ConvTranspose2d(256, 128, 4, 2, 1)  # -> 128 x 8 x 8
        self.dec_conv2 = nn.ConvTranspose2d(128, 64, 4, 2, 1)   # -> 64 x 16 x 16
        self.dec_conv3 = nn.ConvTranspose2d(64, 32, 4, 2, 1)    # -> 32 x 32 x 32
        self.dec_conv4 = nn.ConvTranspose2d(32, 3, 4, 2, 1)     # -> 3 x 64 x 64

    def encoder(self, x):
        """Encode a batch of images into (mu, logvar) of q(z|x)."""
        h = x
        for conv in (self.enc_conv1, self.enc_conv2, self.enc_conv3, self.enc_conv4):
            h = F.relu(conv(h))
        h = self.flatten(h)
        return self.enc_fc_mu(h), self.enc_fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + sigma * eps with eps ~ N(0, I) (reparameterization trick)."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + sigma * noise

    def decoder(self, z):
        """Decode a latent batch into images in [-1, 1]."""
        h = F.relu(self.dec_fc(z))
        h = h.view(-1, 256, 4, 4)  # reshape to (batch, 256, 4, 4)
        for deconv in (self.dec_conv1, self.dec_conv2, self.dec_conv3):
            h = F.relu(deconv(h))
        # tanh squashes pixels into [-1, 1]
        return torch.tanh(self.dec_conv4(h))

    def forward(self, x):
        """Run the full VAE pass; returns (reconstruction, z, mu, logvar)."""
        mu, logvar = self.encoder(x)
        latent = self.reparameterize(mu, logvar)
        return self.decoder(latent), latent, mu, logvar
def get(weights=None):
    """Factory hook: build and return a fresh ``VAE``.

    The ``weights`` argument is accepted for interface compatibility but
    is currently unused by this implementation.
    """
    model = VAE()
    return model
# Define the loss function
def vae_loss_function(output, x, mu, logvar):
    """Per-sample VAE loss: MSE reconstruction + weighted KL divergence.

    Args:
        output: decoder reconstruction, same shape as ``x``.
        x: target batch of images.
        mu: posterior means, shape (batch, z_dim).
        logvar: posterior log-variances, shape (batch, z_dim).

    Returns:
        Scalar tensor: ``(recon + 0.002 * kl) * 0.001``, both terms
        averaged per sample.
    """
    batch_size = x.size(0)
    # Reconstruction: summed squared error, averaged over the batch.
    recon_loss = F.mse_loss(output, x, reduction="sum") / batch_size
    # KL(q(z|x) || N(0, I)) in closed form, also averaged per sample.
    # Bug fix: the original summed the KL over the whole batch while the
    # reconstruction term was per-sample, so the effective KL weight
    # silently scaled with batch size.
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_size
    return (recon_loss + 0.002 * kl_loss) * 0.001