| | from .vae import VAE, Flatten, Stack |
| | import torch.nn as nn |
| | import pytorch_lightning as pl |
| | import torch |
| | import os |
| | import random |
| | from typing import Optional |
| | import torchvision.transforms as transforms |
| | from torchvision.datasets import MNIST, FashionMNIST, CelebA |
| | import torchvision.transforms as transforms |
| | from torch.utils.data import DataLoader |
| | from torchvision.utils import save_image |
| | from torch.optim import Adam |
| | from torch.optim.lr_scheduler import ReduceLROnPlateau |
| |
|
class PrintShape(nn.Module):
    """Debugging pass-through layer: returns its input unchanged.

    Kept as a pure identity so it can stay wired into ``nn.Sequential``
    pipelines (a shape ``print`` would normally live in ``forward``)
    without affecting results.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Identity: hand the tensor straight back to the caller.
        return x
| |
|
class UnFlatten(nn.Module):
    """Inverse of ``Flatten``: reshape a ``(B, N)`` tensor to ``(B, N, 1, 1)``.

    The original hard-coded ``size=4096``, which only worked when the
    flattened width happened to be 4096.  The default now infers the channel
    count from the input itself, so the module works for any hidden size;
    passing an explicit ``size`` is still honoured for backward
    compatibility.
    """

    def forward(self, input, size=None):
        if size is None:
            # Generalization: every per-sample element becomes a channel of
            # a 1x1 feature map, whatever the flattened width is.
            return input.view(input.size(0), -1, 1, 1)
        return input.view(input.size(0), size, 1, 1)
| |
|
| |
|
class Flatten(nn.Module):
    """Collapse all non-batch dimensions: ``(B, *dims) -> (B, prod(dims))``.

    NOTE(review): this local definition shadows the ``Flatten`` imported
    from ``.vae`` at the top of the file — presumably intentional, but
    worth confirming.
    """

    def forward(self, input):
        batch = input.size(0)
        return input.view(batch, -1)
| |
|
class Conv_VAE(pl.LightningModule):
    """Convolutional VAE as a LightningModule for 64x64 images.

    Encoder: four stride-2 conv blocks (32 -> 64 -> 128 -> 256 channels)
    followed by a flatten, so a 64x64 input yields a ``hidden_size``
    (= 256 * 4 * 4 = 4096) vector.  ``fc1`` maps hidden -> mu, ``fc3`` maps
    hidden -> log-variance, and ``fc2`` maps a latent sample back to the
    hidden size consumed by the decoder.

    Args:
        channels: number of image channels (e.g. 1 for MNIST, 3 for CelebA).
        height: image height; must be divisible by 4.
        width: image width; must be divisible by 4.
        lr: learning rate for the Adam optimizer.
        latent_size: dimensionality of the latent code z.
        hidden_size: flattened encoder output size (4096 for 64x64 inputs).
        alpha: weight applied to the reconstruction loss in the total loss.
        batch_size: dataloader batch size.
        dataset: one of "mnist", "fashion-mnist", "celeba".
        save_images: if truthy, save sample reconstructions each val epoch
            (requires ``save_path`` and ``kwargs["model_type"]``).
        save_path: root directory under which images are saved.
    """

    def __init__(self, channels: int, height: int, width: int, lr: float,
                 latent_size: int, hidden_size: int, alpha: float,
                 batch_size: int,
                 dataset: Optional[str] = None,
                 save_images: Optional[bool] = None,
                 save_path: Optional[str] = None, **kwargs):
        super().__init__()
        assert not height % 4 and not width % 4, "Choose height and width to "\
            "be divisible by 4"
        self.latent_size = latent_size
        self.hidden_size = hidden_size
        if save_images:
            self.save_path = f'{save_path}/{kwargs["model_type"]}_images/'
        self.save_images = save_images
        self.lr = lr
        self.batch_size = batch_size
        self.alpha = alpha
        self.dataset = dataset
        self.channels = channels
        self.height = height
        self.width = width
        # Single call (the original called it twice and also re-assigned
        # self.latent_size twice).
        self.save_hyperparameters()

        # NOTE(review): every dataset is resized/center-cropped to 64x64
        # regardless of the ``height``/``width`` arguments — confirm this
        # is intended.
        self.data_transform = transforms.Compose([
            transforms.Resize(64),
            transforms.CenterCrop((64, 64)),
            transforms.ToTensor()
        ])

        self.encoder = nn.Sequential(
            PrintShape(),
            nn.Conv2d(self.channels, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
            PrintShape(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            PrintShape(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            PrintShape(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            PrintShape(),
            Flatten(),
            PrintShape(),
        )

        # Separate heads for the posterior mean and log-variance.
        self.fc1 = nn.Linear(self.hidden_size, self.latent_size)   # mu head
        self.fc3 = nn.Linear(self.hidden_size, self.latent_size)   # log-var head
        self.fc2 = nn.Linear(self.latent_size, self.hidden_size)   # z -> hidden

        self.decoder = nn.Sequential(
            PrintShape(),
            UnFlatten(),
            PrintShape(),
            # (B, hidden, 1, 1) -> (B, 256, 4, 4)
            nn.ConvTranspose2d(self.hidden_size, 256, kernel_size=6, stride=2, padding=1),
            PrintShape(),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            PrintShape(),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            PrintShape(),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            PrintShape(),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(32, self.channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(self.channels),
            PrintShape(),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Map an image batch to the mean and log-variance of q(z|x)."""
        hidden = self.encoder(x)
        # BUG FIX: the original used ``self.fc1`` for BOTH heads, forcing
        # mu == log_var; the log-variance now has its own linear head.
        mu = self.fc1(hidden)
        log_var = self.fc3(hidden)
        return mu, log_var

    def decode(self, z):
        """Map a latent sample z back to image space."""
        z = self.fc2(z)
        x = self.decoder(z)
        return x

    def reparametrize(self, mu, log_var):
        """Sample z ~ q(z|x) via the reparametrization trick."""
        sigma = torch.exp(0.5 * log_var)
        z = torch.randn_like(sigma)
        return mu + sigma * z

    def _losses(self, x):
        """Return ``(kl_loss, recon_loss, x_out)`` for one image batch.

        Shared by training and validation steps (the original duplicated
        this computation in both).
        """
        mu, log_var, x_out = self.forward(x)
        # Analytic KL( q(z|x) || N(0, I) ): summed over latent dims,
        # averaged over the batch.
        kl_loss = (-0.5 * (1 + log_var - mu ** 2 -
                           torch.exp(log_var)).sum(dim=1)).mean(dim=0)
        recon_loss_criterion = nn.MSELoss()
        recon_loss = recon_loss_criterion(x, x_out)
        return kl_loss, recon_loss, x_out

    def training_step(self, batch, batch_idx):
        x, _ = batch
        kl_loss, recon_loss, _ = self._losses(x)
        loss = recon_loss * self.alpha + kl_loss
        self.log('train_loss', loss, on_step=False,
                 on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, _ = batch
        kl_loss, recon_loss, x_out = self._losses(x)
        loss = recon_loss * self.alpha + kl_loss
        self.log('val_kl_loss', kl_loss, on_step=False, on_epoch=True)
        self.log('val_recon_loss', recon_loss, on_step=False, on_epoch=True)
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        return x_out, loss

    def validation_epoch_end(self, outputs):
        """Save one random batch of reconstructions as a PNG grid."""
        if not self.save_images:
            return
        if not os.path.exists(self.save_path):
            os.makedirs(self.save_path)
        choice = random.choice(outputs)
        output_sample = choice[0]
        # BUG FIX: the original reshaped to a single channel with
        # width/height swapped, which breaks multi-channel datasets
        # (CelebA); use the model's channel count and (C, H, W) order.
        output_sample = output_sample.reshape(
            -1, self.channels, self.height, self.width)
        save_image(
            output_sample,
            f"{self.save_path}/epoch_{self.current_epoch+1}.png",
        )

    def configure_optimizers(self):
        """Adam + ReduceLROnPlateau driven by the validation loss."""
        optimizer = Adam(self.parameters(), lr=(self.lr or self.learning_rate))
        lr_scheduler = ReduceLROnPlateau(optimizer,)
        return {
            "optimizer": optimizer, "lr_scheduler": lr_scheduler,
            "monitor": "val_loss"
        }

    def forward(self, x):
        """Full VAE pass: returns ``(mu, log_var, reconstruction)``."""
        mu, log_var = self.encode(x)
        hidden = self.reparametrize(mu, log_var)
        output = self.decode(hidden)
        return mu, log_var, output

    def _make_dataset(self, train: bool, celeba_split: str):
        """Instantiate the configured dataset for the requested split.

        Raises:
            ValueError: if ``self.dataset`` is not a recognised name
                (the original fell through and raised UnboundLocalError).
        """
        if self.dataset == "mnist":
            return MNIST('data/', download=True, train=train,
                         transform=self.data_transform)
        if self.dataset == "fashion-mnist":
            return FashionMNIST('data/', download=True, train=train,
                                transform=self.data_transform)
        if self.dataset == "celeba":
            return CelebA('data/', download=False, split=celeba_split,
                          transform=self.data_transform)
        raise ValueError(f"Unknown dataset: {self.dataset!r}")

    def train_dataloader(self):
        train_set = self._make_dataset(train=True, celeba_split="train")
        return DataLoader(train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        val_set = self._make_dataset(train=False, celeba_split="valid")
        return DataLoader(val_set, batch_size=self.batch_size)

    def test_dataloader(self):
        test_set = self._make_dataset(train=False, celeba_split="test")
        return DataLoader(test_set, batch_size=self.batch_size)
| |
|
| |
|
| |
|
| |
|