File size: 5,536 Bytes
025bf23
 
c6d5483
 
b308e39
025bf23
 
c6d5483
 
 
 
 
 
025bf23
c6d5483
025bf23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b308e39
 
 
 
025bf23
 
 
c6d5483
b308e39
 
 
025bf23
 
 
b308e39
 
 
c6d5483
b308e39
025bf23
b308e39
 
 
025bf23
b308e39
025bf23
c6d5483
b308e39
 
025bf23
 
c6d5483
 
 
 
 
 
 
025bf23
c6d5483
b308e39
 
 
 
 
 
 
 
 
025bf23
 
 
 
c6d5483
67c17bf
b308e39
 
 
 
 
 
 
 
 
 
 
 
 
025bf23
b308e39
67c17bf
b308e39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
from typing import Union, List

import gradio as gr
import matplotlib
import torch
from pytorch_lightning.utilities.types import EPOCH_OUTPUT

matplotlib.use('Agg')
import numpy as np
from PIL import Image
import albumentations as A
import albumentations.pytorch as al_pytorch
import torchvision
from pl_bolts.models.gans import Pix2Pix

""" Class """


class OverpoweredPix2Pix(Pix2Pix):
    """Pix2Pix subclass that logs validation losses and a sample image grid.

    Training behaviour is inherited unchanged from ``pl_bolts``' Pix2Pix;
    only the validation hooks are added here.
    """

    def validation_step(self, batch, batch_idx):
        """Log PatchGAN and generator losses for one validation batch.

        Returns the batch tensors so ``validation_epoch_end`` can render a
        sample grid from them.
        """
        # NOTE(review): the tuple is unpacked as (real, condition) but returned
        # under the keys 'sketch'/'colour' — assumes the dataset yields
        # (sketch, colour) pairs; confirm against the DataModule.
        real, condition = batch
        with torch.no_grad():
            self.log("val_PatchGAN_loss", self._disc_step(real, condition))
            self.log("val_generator_loss", self._gen_step(real, condition))
        return {
            'sketch': real,
            'colour': condition
        }

    def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
        """Render sketch / ground truth / generated triplet from the first batch."""
        first_batch = outputs[0]
        sketch = first_batch['sketch']
        colour = first_batch['colour']
        with torch.no_grad():
            gen_coloured = self.gen(sketch)
        # One example per column: input sketch, target colour, model output.
        grid_image = torchvision.utils.make_grid(
            [sketch[0], colour[0], gen_coloured[0]],
            normalize=True
        )
        self.logger.experiment.add_image(
            f'Image Grid {str(self.current_epoch)}',
            grid_image,
            self.current_epoch
        )


""" Load the model """
# train_64_val_16_patchgan_1val_plbolts_model_chkpt = "model/lightning_bolts_model/modified_path_gan.ckpt"
train_64_val_16_plbolts_model_chkpt = "model/lightning_bolts_model/epoch=99-step=44600.ckpt"
train_16_val_1_plbolts_model_chkpt = "model/lightning_bolts_model/epoch=99-step=89000.ckpt"
# model_checkpoint_path = "model/pix2pix_lightning_model/version_0/checkpoints/epoch=199-step=355600.ckpt"
# model_checkpoint_path = "model/pix2pix_lightning_model/gen.pth"

# Load the models
train_64_val_16_plbolts_model = OverpoweredPix2Pix.load_from_checkpoint(
    train_64_val_16_plbolts_model_chkpt
)
train_64_val_16_plbolts_model.eval()

#
train_16_val_1_plbolts_model = OverpoweredPix2Pix.load_from_checkpoint(
    train_16_val_1_plbolts_model_chkpt
)
train_16_val_1_plbolts_model.eval()


def predict(img: Image, type_of_model: str):
    """Colour a sketch with the selected Pix2Pix generator.

    Args:
        img: Input sketch as a PIL image.
        type_of_model: Radio-button label naming which checkpoint to use.

    Returns:
        Path of the PNG file the coloured result was written to.

    Raises:
        ValueError: if ``type_of_model`` does not match a loaded model.
    """
    # Resolve the model FIRST so an unsupported selection fails fast,
    # before any image preprocessing work is done (the original raised
    # only after transforming the image).
    if type_of_model == "train batch size 16, val batch size 1":
        model = train_16_val_1_plbolts_model
    elif type_of_model == "train batch size 64, val batch size 16":
        model = train_64_val_16_plbolts_model
    else:
        # Third UI option (1-score PatchGAN) has no checkpoint loaded yet.
        # ValueError is a subclass of Exception, so existing handlers still match.
        raise ValueError("NOT YET SUPPORTED")

    # Inference-time preprocessing: resize to the training resolution,
    # scale pixels to [-1, 1], convert HWC uint8 -> CHW float tensor.
    image = np.asarray(img)
    inference_transform = A.Compose([
        A.Resize(width=256, height=256),
        A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
        al_pytorch.ToTensorV2(),
    ])
    inference_img = inference_transform(
        image=image
    )['image'].unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        result = model.gen(inference_img)
        # normalize=True rescales the generator's [-1, 1] output for saving.
        torchvision.utils.save_image(result, "inference_image.png", normalize=True)
    return "inference_image.png"  # 'coloured_image.png',


def predict1(img: Image):
    """Colour *img* with the train-16/val-1 model; returns the output path."""
    return predict(
        img=img,
        type_of_model="train batch size 16, val batch size 1",
    )


def predict2(img: Image):
    """Colour *img* with the train-64/val-16 model; returns the output path."""
    return predict(
        img=img,
        type_of_model="train batch size 64, val batch size 16",
    )


# NOTE(review): model_input and image_input are never referenced by the
# gr.Blocks UI below (each tab builds its own inputs) — they look like
# leftovers from an earlier gr.Interface version; confirm before removing.
model_input = gr.inputs.Radio(
    [
        "train batch size 16, val batch size 1",
        "train batch size 64, val batch size 16",
        # This third option has no loaded checkpoint; predict() rejects it.
        "train batch size 64, val batch size 16, patch gan has 1 output score instead of 16*16",
    ],
    label="Type of Pix2Pix model to use : "
)
image_input = gr.inputs.Image(type="pil")
# Sample sketches shown in both tabs' Examples widgets.
img_examples = [
    "examples/thesis_test.png",
    "examples/thesis_test2.png",
    "examples/thesis1.png",
    "examples/thesis4.png",
    "examples/thesis5.png",
    "examples/thesis6.png",
]


# Build the two-tab demo UI. Each tab pairs an input sketch with an output
# image and a "Colour it!" button wired to the matching model's predict fn.
with gr.Blocks() as demo:
    gr.Markdown(" # Colour your sketches!")
    gr.Markdown(" ## Description :")
    gr.Markdown(" There are three Pix2Pix models in this example:")
    gr.Markdown(" 1. Training batch size is 16 , validation is 1")
    gr.Markdown(" 2. Training batch size is 64 , validation is 16")
    gr.Markdown(" 3. PatchGAN is changed, 1 value only instead of 16*16 ;"
                "training batch size is 64 , validation is 16")
    with gr.Tabs():
        # Tab 1: model trained with batch size 16, validated with batch size 1.
        with gr.TabItem("tr_16_val_1"):
            with gr.Row():
                image_input1 = gr.inputs.Image(type="pil")
                image_output1 = gr.outputs.Image(type="pil", )
            colour_1 = gr.Button("Colour it!")
            gr.Examples(
                examples=img_examples,
                inputs=image_input1,
                outputs=image_output1,
                fn=predict1,
            )
        # Tab 2: model trained with batch size 64, validated with batch size 16.
        # (Label fixed: was "tr_64_val_14", but this tab hosts the val-16 model.)
        with gr.TabItem("tr_64_val_16"):
            with gr.Row():
                image_input2 = gr.inputs.Image(type="pil")
                image_output2 = gr.outputs.Image(type="pil", )
            colour_2 = gr.Button("Colour it!")
            with gr.Row():
                gr.Examples(
                    examples=img_examples,
                    inputs=image_input2,
                    outputs=image_output2,
                    fn=predict2,
                )

    # Wire each button to its model-specific wrapper.
    colour_1.click(
        fn=predict1,
        inputs=image_input1,
        outputs=image_output1,
    )
    colour_2.click(
        fn=predict2,
        inputs=image_input2,
        outputs=image_output2,
    )

demo.title = "Colour your sketches!"
demo.launch()