import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchvision


class Pix2PixLitModule(pl.LightningModule):
    """Pix2Pix GAN wrapped as a PyTorch-Lightning module.

    Trains a generator/discriminator pair with the alternating scheme from
    Isola et al. (2017): the generator minimises a BCE adversarial loss plus
    a lambda-weighted L1 reconstruction term; the PatchGAN discriminator
    minimises the mean of its real/fake BCE losses.
    """

    @staticmethod
    def _weights_init(m):
        """Re-initialise conv and batch-norm layers as N(0, 0.02), per the paper."""
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            torch.nn.init.normal_(m.weight, 0.0, 0.02)
        if isinstance(m, nn.BatchNorm2d):
            torch.nn.init.normal_(m.weight, 0.0, 0.02)
            torch.nn.init.constant_(m.bias, 0)

    def __init__(
            self,
            generator,
            discriminator,
            use_gpu: bool,
            lambda_recon=100
    ):
        """Store the two networks, re-initialise their weights and set up losses.

        Args:
            generator: image-to-image generator network.
            discriminator: PatchGAN discriminator; called as disc(input, target).
            use_gpu: recorded in the hyperparameters (not read by this module).
            lambda_recon: weight of the L1 reconstruction term (default 100).
        """
        super().__init__()
        self.save_hyperparameters()

        # ``apply`` recurses over submodules and returns the module itself.
        self.gen = generator.apply(self._weights_init)
        self.disc = discriminator.apply(self._weights_init)

        self.adversarial_criterion = nn.BCEWithLogitsLoss()
        self.recon_criterion = nn.L1Loss()
        self.lambda_l1 = lambda_recon

    def _gen_step(self, sketch, coloured_sketches):
        """Generator loss: fool the discriminator and stay close to ground truth."""
        fakes = self.gen(sketch)
        # Adversarial term: the discriminator should score the fakes as real.
        logits = self.disc(sketch, fakes)
        adv_loss = self.adversarial_criterion(logits, torch.ones_like(logits))
        # Weighted L1 term pulling the fakes towards the ground-truth images.
        l1_loss = self.lambda_l1 * self.recon_criterion(fakes, coloured_sketches)
        self.log("Gen recon_loss", l1_loss)
        self.log("Gen adversarial_loss", adv_loss)
        return adv_loss + l1_loss

    def _disc_step(self, sketch, coloured_sketches):
        """Discriminator loss: score generated images as fake, real ones as real."""
        # Detach so no gradients flow back into the generator on this step.
        fakes = self.gen(sketch).detach()
        fake_logits = self.disc(sketch, fakes)
        real_logits = self.disc(sketch, coloured_sketches)
        fake_loss = self.adversarial_criterion(fake_logits, torch.zeros_like(fake_logits))
        real_loss = self.adversarial_criterion(real_logits, torch.ones_like(real_logits))
        self.log("PatchGAN fake_loss", fake_loss)
        self.log("PatchGAN real_loss", real_loss)
        return (real_loss + fake_loss) / 2

    def forward(self, x):
        """Inference path: run the generator only."""
        return self.gen(x)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Alternate optimisation: idx 0 trains the discriminator, idx 1 the generator."""
        inputs, targets = batch
        if optimizer_idx == 0:
            loss = self._disc_step(inputs, targets)
            self.log("TRAIN_PatchGAN Loss", loss)
            return loss
        if optimizer_idx == 1:
            loss = self._gen_step(inputs, targets)
            self.log("TRAIN_Generator Loss", loss)
            return loss
        return None

    def validation_step(self, batch, batch_idx):
        """Pass the batch through untouched; images are rendered at epoch end."""
        inputs, targets = batch
        return {
            'sketch': inputs,
            'colour': targets
        }

    def validation_epoch_end(self, outputs) -> None:
        """Log a sketch / ground-truth / generated image grid to TensorBoard."""
        first = outputs[0]
        generated = self.gen(first['sketch'])
        grid = torchvision.utils.make_grid(
            [first['sketch'][0], first['colour'][0], generated[0]],
            normalize=True
        )
        self.logger.experiment.add_image(f'Image Grid {str(self.current_epoch)}', grid, self.current_epoch)

    def configure_optimizers(self, lr=2e-4):
        """Adam pair with the paper's betas; discriminator first, generator second."""
        make_adam = torch.optim.Adam
        disc_opt = make_adam(self.disc.parameters(), lr=lr, betas=(0.5, 0.999))
        gen_opt = make_adam(self.gen.parameters(), lr=lr, betas=(0.5, 0.999))
        return disc_opt, gen_opt

# class EpochInference(pl.Callback):
#     """
#     Callback on each end of training epoch
#     The callback will do inference on test dataloader based on corresponding checkpoints
#     The results will be saved as an image with 4-rows:
#         1 - Input image e.g. grayscale edged input
#         2 - Ground-truth
#         3 - Single inference
#         4 - Mean of hundred accumulated inference
#     Note that the inference have a noise factor that will generate different output on each execution
#     """
#
#     def __init__(self, dataloader, use_gpu: bool, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.dataloader = dataloader
#         self.use_gpu = use_gpu
#
#     def on_train_epoch_end(self, trainer, pl_module):
#         super().on_train_epoch_end(trainer, pl_module)
#         data = next(iter(self.dataloader))
#         image, target = data
#         if self.use_gpu:
#             image = image.cuda()
#             target = target.cuda()
#         with torch.no_grad():
#             # Take average of multiple inference as there is a random noise
#             # Single
#             reconstruction_init = pl_module(image)
#             reconstruction_init = torch.clip(reconstruction_init, 0, 1)
#             # # Mean
#             # reconstruction_mean = torch.stack([pl_module(image) for _ in range(10)])
#             # reconstruction_mean = torch.clip(reconstruction_mean, 0, 1)
#             # reconstruction_mean = torch.mean(reconstruction_mean, dim=0)
#         # Grayscale 1-D to 3-D
#         # image = torch.stack([image for _ in range(3)], dim=1)
#         # image = torch.squeeze(image)
#         grid_image = torchvision.utils.make_grid([image[0], target[0], reconstruction_init[0]])
#         torchvision.utils.save_image(grid_image, fp=f'{trainer.default_root_dir}/epoch-{trainer.current_epoch:04}.png')