# NOTE(review): the original paste was headed "Spaces:" / "Runtime error" — artifacts of the
# hosting page, not part of the module; preserved here as a comment for traceability.
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchvision
class Pix2PixLitModule(pl.LightningModule):
    """Lightning module implementing the pix2pix conditional GAN.

    Trains a generator against a PatchGAN discriminator using the standard
    pix2pix objective: a BCE-with-logits adversarial loss plus an L1
    reconstruction loss scaled by ``lambda_recon``.

    Args:
        generator: image-to-image generator network (``nn.Module``).
        discriminator: PatchGAN discriminator; called as
            ``disc(input_image, candidate_image)`` and returning patch logits.
        use_gpu: kept for interface compatibility; not read by this module
            (device placement is handled by Lightning).
        lambda_recon: weight of the L1 reconstruction term (100 in the
            pix2pix paper).
    """

    @staticmethod
    def _weights_init(m):
        """DCGAN-style weight initialisation, applied via ``nn.Module.apply``.

        Must be a ``staticmethod``: ``apply`` invokes the function with a
        single module argument. As a plain instance method the bound call
        ``self._weights_init(module)`` received two positional arguments
        (``self`` and ``module``) against a one-parameter signature and
        raised ``TypeError`` during ``__init__`` — the runtime error this
        fixes.
        """
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            torch.nn.init.normal_(m.weight, 0.0, 0.02)
        if isinstance(m, nn.BatchNorm2d):
            # NOTE(review): DCGAN init usually draws BatchNorm gamma from
            # N(1.0, 0.02); this code uses N(0.0, 0.02). Kept as written —
            # confirm against the reference implementation.
            torch.nn.init.normal_(m.weight, 0.0, 0.02)
            torch.nn.init.constant_(m.bias, 0)

    def __init__(
        self,
        generator,
        discriminator,
        use_gpu: bool,
        lambda_recon=100
    ):
        super().__init__()
        # The networks are nn.Modules and cannot be meaningfully serialised
        # into hparams.yaml; record only the scalar/flag hyperparameters.
        self.save_hyperparameters(ignore=['generator', 'discriminator'])
        self.gen = generator
        self.disc = discriminator
        # Re-initialise all conv / batch-norm weights in both networks.
        self.gen = self.gen.apply(self._weights_init)
        self.disc = self.disc.apply(self._weights_init)
        # Adversarial loss on raw logits; L1 for reconstruction.
        self.adversarial_criterion = nn.BCEWithLogitsLoss()
        self.recon_criterion = nn.L1Loss()
        self.lambda_l1 = lambda_recon

    def _gen_step(self, sketch, coloured_sketches):
        """Generator loss: adversarial term + weighted L1 reconstruction.

        The generator is rewarded when the discriminator labels its output
        as real (target = ones), and penalised for pixel-wise deviation
        from the ground-truth coloured sketch.
        """
        gen_coloured_sketches = self.gen(sketch)
        # Condition the discriminator on the input sketch, per pix2pix.
        disc_logits = self.disc(sketch, gen_coloured_sketches)
        adversarial_loss = self.adversarial_criterion(disc_logits, torch.ones_like(disc_logits))
        # L1 reconstruction against ground truth, scaled by lambda.
        recon_loss = self.recon_criterion(gen_coloured_sketches, coloured_sketches) * self.lambda_l1
        self.log("Gen recon_loss", recon_loss)
        self.log("Gen adversarial_loss", adversarial_loss)
        return adversarial_loss + recon_loss

    def _disc_step(self, sketch, coloured_sketches):
        """Discriminator loss: mean of fake-vs-zeros and real-vs-ones BCE.

        Generator output is detached so this step only updates the
        discriminator.
        """
        gen_coloured_sketches = self.gen(sketch).detach()
        fake_logits = self.disc(sketch, gen_coloured_sketches)
        real_logits = self.disc(sketch, coloured_sketches)
        fake_loss = self.adversarial_criterion(fake_logits, torch.zeros_like(fake_logits))
        real_loss = self.adversarial_criterion(real_logits, torch.ones_like(real_logits))
        self.log("PatchGAN fake_loss", fake_loss)
        self.log("PatchGAN real_loss", real_loss)
        # Halve so discriminator and generator learn at comparable rates.
        return (real_loss + fake_loss) / 2

    def forward(self, x):
        """Run the generator on ``x`` (inference entry point)."""
        return self.gen(x)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Alternate optimisation: optimizer 0 = discriminator, 1 = generator.

        Order matches ``configure_optimizers`` returning ``(disc_opt, gen_opt)``.
        NOTE(review): the batch is unpacked as ``real, condition`` but passed
        to the step helpers as ``(sketch, coloured_sketches)`` — so ``real``
        is presumably the sketch and ``condition`` the coloured target;
        verify against the DataLoader.
        """
        real, condition = batch
        loss = None
        if optimizer_idx == 0:
            loss = self._disc_step(real, condition)
            self.log("TRAIN_PatchGAN Loss", loss)
        elif optimizer_idx == 1:
            loss = self._gen_step(real, condition)
            self.log("TRAIN_Generator Loss", loss)
        return loss

    def validation_epoch_end(self, outputs) -> None:
        """Log a sketch / ground-truth / generated triple for the first
        validation batch to the experiment logger (e.g. TensorBoard)."""
        sketch = outputs[0]['sketch']
        colour = outputs[0]['colour']
        gen_coloured = self.gen(sketch)
        grid_image = torchvision.utils.make_grid(
            [sketch[0], colour[0], gen_coloured[0]],
            normalize=True
        )
        self.logger.experiment.add_image(f'Image Grid {str(self.current_epoch)}', grid_image, self.current_epoch)

    def validation_step(self, batch, batch_idx):
        """Pass the batch through so ``validation_epoch_end`` can log images."""
        real, condition = batch
        return {
            'sketch': real,
            'colour': condition
        }

    def configure_optimizers(self, lr=2e-4):
        """Adam optimisers for both networks with pix2pix betas (0.5, 0.999).

        Returned as ``(disc_opt, gen_opt)`` — the order determines
        ``optimizer_idx`` in ``training_step``.
        """
        gen_opt = torch.optim.Adam(self.gen.parameters(), lr=lr, betas=(0.5, 0.999))
        disc_opt = torch.optim.Adam(self.disc.parameters(), lr=lr, betas=(0.5, 0.999))
        return disc_opt, gen_opt
# class EpochInference(pl.Callback):
#     """
#     Callback on each end of training epoch
#     The callback will do inference on test dataloader based on corresponding checkpoints
#     The results will be saved as an image with 4-rows:
#     1 - Input image e.g. grayscale edged input
#     2 - Ground-truth
#     3 - Single inference
#     4 - Mean of hundred accumulated inference
#     Note that the inference have a noise factor that will generate different output on each execution
#     """
#
#     def __init__(self, dataloader, use_gpu: bool, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.dataloader = dataloader
#         self.use_gpu = use_gpu
#
#     def on_train_epoch_end(self, trainer, pl_module):
#         super().on_train_epoch_end(trainer, pl_module)
#         data = next(iter(self.dataloader))
#         image, target = data
#         if self.use_gpu:
#             image = image.cuda()
#             target = target.cuda()
#         with torch.no_grad():
#             # Take average of multiple inference as there is a random noise
#             # Single
#             reconstruction_init = pl_module(image)
#             reconstruction_init = torch.clip(reconstruction_init, 0, 1)
#             # # Mean
#             # reconstruction_mean = torch.stack([pl_module(image) for _ in range(10)])
#             # reconstruction_mean = torch.clip(reconstruction_mean, 0, 1)
#             # reconstruction_mean = torch.mean(reconstruction_mean, dim=0)
#             # Grayscale 1-D to 3-D
#             # image = torch.stack([image for _ in range(3)], dim=1)
#             # image = torch.squeeze(image)
#             grid_image = torchvision.utils.make_grid([image[0], target[0], reconstruction_init[0]])
#             torchvision.utils.save_image(grid_image, fp=f'{trainer.default_root_dir}/epoch-{trainer.current_epoch:04}.png')