Sketch2ColourDemo / app /model /lit_model.py
Nikhil Mudhalwadkar
added other files
c6d5483
raw
history blame
5.79 kB
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchvision
class Pix2PixLitModule(pl.LightningModule):
    """Lightning module that trains a pix2pix-style GAN to colour sketches.

    The module wraps a generator (sketch -> coloured image) and a conditional
    discriminator that is called as ``disc(sketch, image)`` and returns logits
    (the adversarial loss is ``BCEWithLogitsLoss``).

    Batches are expected to be ``(sketch, colour)`` pairs: the first element is
    the generator input and the second is the reconstruction target.

    NOTE(review): ``training_step(..., optimizer_idx)`` and
    ``validation_epoch_end`` are the PyTorch Lightning 1.x hook signatures;
    confirm the pinned Lightning version before upgrading.
    """

    @staticmethod
    def _weights_init(m):
        """Initialise one sub-module in place; meant for ``nn.Module.apply``.

        * ``Conv2d`` / ``ConvTranspose2d``: weights drawn from N(0, 0.02).
        * ``BatchNorm2d``: scale drawn from N(1, 0.02), shift set to 0.

        Any other module type is left untouched.
        """
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            torch.nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d):
            # BatchNorm2d(affine=False) has weight/bias set to None, so there
            # is nothing to initialise in that case.
            if m.weight is not None:
                # The scale must be centred on 1, not 0: a scale close to zero
                # would wipe out almost every normalised activation at start.
                torch.nn.init.normal_(m.weight, 1.0, 0.02)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)

    def __init__(
        self,
        generator,
        discriminator,
        use_gpu: bool,
        lambda_recon=100
    ):
        """Store both networks, initialise their weights and set up the losses.

        Args:
            generator: network mapping a sketch batch to a coloured batch.
            discriminator: network called as ``discriminator(sketch, image)``
                that returns logits.
            use_gpu: not read by this class; it is only recorded through
                ``save_hyperparameters``.
            lambda_recon: multiplier applied to the L1 reconstruction loss.
        """
        super().__init__()
        # Records every constructor argument (including the two networks) in
        # ``self.hparams``.
        self.save_hyperparameters()

        # ``apply`` walks every sub-module, runs the initialiser on it and
        # returns the module itself.
        self.gen = generator.apply(self._weights_init)
        self.disc = discriminator.apply(self._weights_init)

        self.adversarial_criterion = nn.BCEWithLogitsLoss()
        self.recon_criterion = nn.L1Loss()
        self.lambda_l1 = lambda_recon

    def _gen_step(self, sketch, coloured_sketches):
        """Return the generator loss: adversarial term + weighted L1 term.

        Both terms are also logged individually.
        """
        gen_coloured_sketches = self.gen(sketch)

        # Adversarial term: the generator wants the discriminator to label
        # its output as real (target of all ones).
        disc_logits = self.disc(sketch, gen_coloured_sketches)
        real_targets = torch.ones_like(disc_logits)
        adversarial_loss = self.adversarial_criterion(disc_logits, real_targets)

        # Reconstruction term: L1 distance to the ground truth, scaled.
        recon_loss = self.lambda_l1 * self.recon_criterion(
            gen_coloured_sketches, coloured_sketches
        )

        self.log("Gen recon_loss", recon_loss)
        self.log("Gen adversarial_loss", adversarial_loss)

        return adversarial_loss + recon_loss

    def _disc_step(self, sketch, coloured_sketches):
        """Return the discriminator loss: mean of the real and fake BCE terms.

        Both terms are also logged individually.
        """
        # Detach so that this loss does not back-propagate into the generator.
        gen_coloured_sketches = self.gen(sketch).detach()

        fake_logits = self.disc(sketch, gen_coloured_sketches)
        real_logits = self.disc(sketch, coloured_sketches)

        fake_loss = self.adversarial_criterion(
            fake_logits, torch.zeros_like(fake_logits)
        )
        real_loss = self.adversarial_criterion(
            real_logits, torch.ones_like(real_logits)
        )

        self.log("PatchGAN fake_loss", fake_loss)
        self.log("PatchGAN real_loss", real_loss)
        return (real_loss + fake_loss) / 2

    def forward(self, x):
        """Run the generator on ``x`` and return its output."""
        return self.gen(x)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute the loss for the optimizer selected by ``optimizer_idx``.

        The indices follow the order returned by ``configure_optimizers``:
        0 is the discriminator, 1 is the generator. Any other index yields
        ``None``.
        """
        sketch, colour = batch
        loss = None
        if optimizer_idx == 0:
            loss = self._disc_step(sketch, colour)
            self.log("TRAIN_PatchGAN Loss", loss)
        elif optimizer_idx == 1:
            loss = self._gen_step(sketch, colour)
            self.log("TRAIN_Generator Loss", loss)
        return loss

    def validation_epoch_end(self, outputs) -> None:
        """Log a (sketch, ground truth, generated) image grid for this epoch.

        Only the first sample of the first validation batch is shown. Nothing
        is logged when there are no validation outputs or no logger attached.
        """
        if not outputs or self.logger is None:
            return

        first_batch = outputs[0]
        sketch = first_batch['sketch']
        colour = first_batch['colour']

        # Pure inference for visualisation; no autograd graph is needed.
        with torch.no_grad():
            gen_coloured = self.gen(sketch)

        grid_image = torchvision.utils.make_grid(
            [sketch[0], colour[0], gen_coloured[0]],
            normalize=True
        )
        # NOTE(review): ``add_image`` is the TensorBoard SummaryWriter API;
        # other logger backends may not provide it.
        self.logger.experiment.add_image(
            f'Image Grid {self.current_epoch}', grid_image, self.current_epoch
        )

    def validation_step(self, batch, batch_idx):
        """Pass the batch through unchanged for ``validation_epoch_end``.

        NOTE(review): every validation batch is returned although only the
        first one is used at epoch end; consider trimming if memory is tight.
        """
        sketch, colour = batch
        return {
            'sketch': sketch,
            'colour': colour
        }

    def configure_optimizers(self, lr=2e-4):
        """Create one Adam optimizer per network.

        Returns:
            ``(disc_opt, gen_opt)``; this order defines the ``optimizer_idx``
            values seen by ``training_step``.
        """
        betas = (0.5, 0.999)
        gen_opt = torch.optim.Adam(self.gen.parameters(), lr=lr, betas=betas)
        disc_opt = torch.optim.Adam(self.disc.parameters(), lr=lr, betas=betas)
        return disc_opt, gen_opt
# class EpochInference(pl.Callback):
# """
# Callback run at the end of each training epoch.
# The callback runs inference on one batch from the supplied dataloader using the current state of the module.
# The results are saved as a single image grid containing:
# 1 - Input image, e.g. grayscale edged input
# 2 - Ground truth
# 3 - Single inference
# 4 - Mean of several accumulated inferences (currently disabled in the code below)
# Note that the inference has a noise factor, so each execution generates a different output.
# """
#
# def __init__(self, dataloader, use_gpu: bool, *args, **kwargs):
# super().__init__(*args, **kwargs)
# self.dataloader = dataloader
# self.use_gpu = use_gpu
#
# def on_train_epoch_end(self, trainer, pl_module):
# super().on_train_epoch_end(trainer, pl_module)
# data = next(iter(self.dataloader))
# image, target = data
# if self.use_gpu:
# image = image.cuda()
# target = target.cuda()
# with torch.no_grad():
# # Take average of multiple inference as there is a random noise
# # Single
# reconstruction_init = pl_module(image)
# reconstruction_init = torch.clip(reconstruction_init, 0, 1)
# # # Mean
# # reconstruction_mean = torch.stack([pl_module(image) for _ in range(10)])
# # reconstruction_mean = torch.clip(reconstruction_mean, 0, 1)
# # reconstruction_mean = torch.mean(reconstruction_mean, dim=0)
# # Grayscale 1-D to 3-D
# # image = torch.stack([image for _ in range(3)], dim=1)
# # image = torch.squeeze(image)
# grid_image = torchvision.utils.make_grid([image[0], target[0], reconstruction_init[0]])
# torchvision.utils.save_image(grid_image, fp=f'{trainer.default_root_dir}/epoch-{trainer.current_epoch:04}.png')