# with <3 to...
# https://github.com/AntixK/PyTorch-VAE/blob/master/models/lvae.py
# https://mbernste.github.io/posts/vae/
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
import torch.optim as optim
import torch.nn.functional as F
class VAE(nn.Module):
def __init__(self, z_dim=50):
super(VAE, self).__init__()
# Define autoencoding layers
self.enc_conv1 = nn.Conv2d(
in_channels=3, out_channels=32, kernel_size=4, stride=2, padding=1
) # 32x32x32
self.enc_conv2 = nn.Conv2d(32, 64, 4, 2, 1) # 64x16x16
self.enc_conv3 = nn.Conv2d(64, 128, 4, 2, 1) # 128x8x8
self.enc_conv4 = nn.Conv2d(128, 256, 4, 2, 1) # 256x4x4
# Define autoencoding layers
self.enc_fc_mu = nn.Linear(256 * 4 * 4, z_dim)
self.enc_fc_logvar = nn.Linear(256 * 4 * 4, z_dim)
# Decoder: Fully connected layer to expand latent vector
self.dec_fc = nn.Linear(z_dim, 256 * 4 * 4)
self.flatten = nn.Flatten()
# Decoder: Transposed convolutional layers
self.dec_conv1 = nn.ConvTranspose2d(256, 128, 4, 2, 1) # 128x8x8
self.dec_conv2 = nn.ConvTranspose2d(128, 64, 4, 2, 1) # 64x16x16
self.dec_conv3 = nn.ConvTranspose2d(64, 32, 4, 2, 1) # 32x32x32
self.dec_conv4 = nn.ConvTranspose2d(32, 3, 4, 2, 1) # 3x64x64
def encoder(self, x):
x = F.relu(self.enc_conv1(x))
x = F.relu(self.enc_conv2(x))
x = F.relu(self.enc_conv3(x))
x = F.relu(self.enc_conv4(x))
x = self.flatten(x)
mu = self.enc_fc_mu(x)
logvar = self.enc_fc_logvar(x)
return mu, logvar
def reparameterize(self, mu, logvar):
std = torch.exp(logvar / 2)
eps = torch.randn_like(std)
z = mu + std * eps
return z
def decoder(self, z):
x = F.relu(self.dec_fc(z))
x = x.view(-1, 256, 4, 4) # Reshape to (batch_size, 256, 4, 4)
x = F.relu(self.dec_conv1(x))
x = F.relu(self.dec_conv2(x))
x = F.relu(self.dec_conv3(x))
x = torch.tanh(self.dec_conv4(x)) # Sigmoid for output between 0 and 1
return x
def forward(self, x):
mu, logvar = self.encoder(x)
z = self.reparameterize(mu, logvar)
output = self.decoder(z)
return output, z, mu, logvar
@staticmethod
def get(weights=None):
return VAE()
def vae_loss_function(output, x, mu, logvar):
    """Scaled negative ELBO for the VAE.

    Combines a batch-averaged summed-squared reconstruction error with a
    KL divergence term KL(q(z|x) || N(0, I)).  The KL term is weighted by
    0.002 and the whole objective is scaled by 1e-3; note the KL term is
    summed over the batch (not averaged), matching the original training
    hyperparameters.
    """
    batch = x.size(0)
    # Summed squared error, averaged over the batch only.
    reconstruction = F.mse_loss(output, x, reduction="sum") / batch
    # Closed-form KL divergence to the standard normal prior.
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return 0.001 * (reconstruction + 0.002 * kl_divergence)