Spaces:
Runtime error
Runtime error
File size: 5,790 Bytes
c6d5483 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 | import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchvision
class Pix2PixLitModule(pl.LightningModule):
    """Lightning module that trains a pix2pix-style conditional GAN.

    Wraps a generator and a discriminator and optimises them with two Adam
    optimizers (index 0: discriminator, index 1: generator). The generator
    objective is an adversarial BCE-with-logits term plus an L1
    reconstruction term scaled by ``lambda_recon``.

    Batches are unpacked as ``(sketch, coloured_image)`` pairs: the first
    element is fed to the generator, the second is the reconstruction target.
    """

    @staticmethod
    def _weights_init(m):
        """Initialise a single sub-module in place (for ``nn.Module.apply``).

        * ``Conv2d`` / ``ConvTranspose2d``: weights drawn from N(0, 0.02).
        * ``BatchNorm2d``: scale weights drawn from N(1, 0.02), bias set to 0.

        The BatchNorm scale is centred on 1 (not 0) so that a freshly
        initialised layer is close to the identity instead of suppressing
        every activation. Non-affine BatchNorm layers have no parameters
        and are skipped.
        """
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            torch.nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
            torch.nn.init.normal_(m.weight, 1.0, 0.02)
            torch.nn.init.constant_(m.bias, 0)

    def __init__(
        self,
        generator,
        discriminator,
        use_gpu: bool,
        lambda_recon=100
    ):
        """Store the two networks, initialise their weights and build losses.

        Args:
            generator: module called as ``generator(sketch)``.
            discriminator: module called as ``discriminator(sketch, image)``
                and returning logits.
            use_gpu: not read anywhere in this class; kept so existing
                callers keep working.
            lambda_recon: multiplier applied to the L1 reconstruction loss.
        """
        super().__init__()
        # NOTE(review): called without arguments, so the generator and
        # discriminator are captured too; confirm how checkpoints are
        # loaded before excluding them.
        self.save_hyperparameters()

        # ``apply`` walks every sub-module, initialises it in place and
        # returns the module it was called on.
        self.gen = generator.apply(self._weights_init)
        self.disc = discriminator.apply(self._weights_init)

        self.adversarial_criterion = nn.BCEWithLogitsLoss()
        self.recon_criterion = nn.L1Loss()
        self.lambda_l1 = lambda_recon

    def _gen_step(self, sketch, coloured_sketches):
        """Return the generator loss: adversarial term + weighted L1 term."""
        gen_coloured_sketches = self.gen(sketch)

        # Adversarial term: the generator wants the discriminator to label
        # its output as real (target = 1).
        disc_logits = self.disc(sketch, gen_coloured_sketches)
        adversarial_loss = self.adversarial_criterion(
            disc_logits, torch.ones_like(disc_logits)
        )

        # Reconstruction term: L1 distance to the ground truth, scaled.
        recon_loss = (
            self.recon_criterion(gen_coloured_sketches, coloured_sketches)
            * self.lambda_l1
        )

        self.log("Gen recon_loss", recon_loss)
        self.log("Gen adversarial_loss", adversarial_loss)

        return adversarial_loss + recon_loss

    def _disc_step(self, sketch, coloured_sketches):
        """Return the discriminator loss: mean of the real and fake BCE terms."""
        # Detach so discriminator gradients do not flow into the generator.
        gen_coloured_sketches = self.gen(sketch).detach()

        fake_logits = self.disc(sketch, gen_coloured_sketches)
        real_logits = self.disc(sketch, coloured_sketches)

        fake_loss = self.adversarial_criterion(
            fake_logits, torch.zeros_like(fake_logits)
        )
        real_loss = self.adversarial_criterion(
            real_logits, torch.ones_like(real_logits)
        )

        self.log("PatchGAN fake_loss", fake_loss)
        self.log("PatchGAN real_loss", real_loss)
        return (real_loss + fake_loss) / 2

    def forward(self, x):
        """Run the generator on ``x``."""
        return self.gen(x)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute the loss for the optimizer selected by ``optimizer_idx``.

        Index 0 trains the discriminator and index 1 the generator, matching
        the order returned by ``configure_optimizers``. Any other index
        yields ``None``.
        """
        sketch, coloured = batch
        loss = None
        if optimizer_idx == 0:
            loss = self._disc_step(sketch, coloured)
            self.log("TRAIN_PatchGAN Loss", loss)
        elif optimizer_idx == 1:
            loss = self._gen_step(sketch, coloured)
            self.log("TRAIN_Generator Loss", loss)
        return loss

    def validation_epoch_end(self, outputs) -> None:
        """Log a grid of (sketch, ground truth, generated) for the first batch.

        Does nothing when no validation outputs were collected, or when the
        attached logger does not expose ``experiment.add_image``.
        """
        if not outputs:
            return

        # ``self.logger`` may be None, and not every logger backend offers
        # ``add_image``; skip image logging instead of crashing the run.
        experiment = getattr(self.logger, "experiment", None)
        add_image = getattr(experiment, "add_image", None)
        if add_image is None:
            return

        first_batch = outputs[0]
        sketch = first_batch['sketch']
        colour = first_batch['colour']
        gen_coloured = self.gen(sketch)

        grid_image = torchvision.utils.make_grid(
            [sketch[0], colour[0], gen_coloured[0]],
            normalize=True
        )
        add_image(f'Image Grid {self.current_epoch}', grid_image, self.current_epoch)
        # Debug aid: plt.imshow(grid_image.permute(1, 2, 0))

    def validation_step(self, batch, batch_idx):
        """Pass the batch through unchanged for ``validation_epoch_end``."""
        sketch, coloured = batch
        return {
            'sketch': sketch,
            'colour': coloured
        }

    def configure_optimizers(self, lr=2e-4):
        """Return ``(discriminator_optimizer, generator_optimizer)``.

        Both are Adam with ``betas=(0.5, 0.999)`` and learning rate ``lr``.
        The order defines the ``optimizer_idx`` seen by ``training_step``.
        """
        gen_opt = torch.optim.Adam(self.gen.parameters(), lr=lr, betas=(0.5, 0.999))
        disc_opt = torch.optim.Adam(self.disc.parameters(), lr=lr, betas=(0.5, 0.999))
        return disc_opt, gen_opt
# class EpochInference(pl.Callback):
# """
# Callback run at the end of each training epoch.
# The callback runs inference on one batch of the given dataloader using the module being trained.
# The results are saved as a single image grid containing:
# 1 - Input image, e.g. grayscale edge input
# 2 - Ground truth
# 3 - Single inference
# 4 - Mean of several accumulated inferences (disabled in the code below)
# Note that inference has a noise factor, so each execution generates a different output.
# """
#
# def __init__(self, dataloader, use_gpu: bool, *args, **kwargs):
# super().__init__(*args, **kwargs)
# self.dataloader = dataloader
# self.use_gpu = use_gpu
#
# def on_train_epoch_end(self, trainer, pl_module):
# super().on_train_epoch_end(trainer, pl_module)
# data = next(iter(self.dataloader))
# image, target = data
# if self.use_gpu:
# image = image.cuda()
# target = target.cuda()
# with torch.no_grad():
# # Take average of multiple inference as there is a random noise
# # Single
# reconstruction_init = pl_module(image)
# reconstruction_init = torch.clip(reconstruction_init, 0, 1)
# # # Mean
# # reconstruction_mean = torch.stack([pl_module(image) for _ in range(10)])
# # reconstruction_mean = torch.clip(reconstruction_mean, 0, 1)
# # reconstruction_mean = torch.mean(reconstruction_mean, dim=0)
# # Grayscale 1-D to 3-D
# # image = torch.stack([image for _ in range(3)], dim=1)
# # image = torch.squeeze(image)
# grid_image = torchvision.utils.make_grid([image[0], target[0], reconstruction_init[0]])
# torchvision.utils.save_image(grid_image, fp=f'{trainer.default_root_dir}/epoch-{trainer.current_epoch:04}.png')
|