# Author: Nikhil Mudhalwadkar
# Commit b308e39: "Recreate the demo" (file size: 5.54 kB)
from typing import Union, List
import gradio as gr
import matplotlib
import torch
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
matplotlib.use('Agg')
import numpy as np
from PIL import Image
import albumentations as A
import albumentations.pytorch as al_pytorch
import torchvision
from pl_bolts.models.gans import Pix2Pix
""" Class """
class OverpoweredPix2Pix(Pix2Pix):
    """Pix2Pix variant that logs validation losses and an image grid.

    Adds a validation step that logs the discriminator and generator
    losses, and an end-of-validation hook that logs a three-image grid
    (input, target, generated) to the experiment logger.
    """

    def validation_step(self, batch, batch_idx):
        """Log both validation losses for one batch and return its images.

        Returns:
            A dict holding the first batch element under 'sketch' and the
            second under 'colour'.
        """
        real, condition = batch
        with torch.no_grad():
            patchgan_loss = self._disc_step(real, condition)
            self.log("val_PatchGAN_loss", patchgan_loss)
            generator_loss = self._gen_step(real, condition)
            self.log("val_generator_loss", generator_loss)
        # NOTE(review): the first batch element is stored as 'sketch' and is
        # what validation_epoch_end feeds to self.gen -- confirm this matches
        # the dataset's (real, condition) ordering.
        return dict(sketch=real, colour=condition)

    def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
        """Log a grid built from the first validation batch of the epoch.

        The grid shows the first sample of the batch three times: the
        'sketch' image, the 'colour' image, and the generator output for
        the 'sketch' image.
        """
        first_batch = outputs[0]
        sketch_batch = first_batch['sketch']
        colour_batch = first_batch['colour']
        with torch.no_grad():
            generated_batch = self.gen(sketch_batch)
        # Only the first sample of each tensor is shown.
        panels = [sketch_batch[0], colour_batch[0], generated_batch[0]]
        grid_image = torchvision.utils.make_grid(panels, normalize=True)
        epoch = self.current_epoch
        self.logger.experiment.add_image(f'Image Grid {epoch}', grid_image, epoch)
""" Load the model """
# Alternative checkpoints kept for reference; none of these are loaded.
# train_64_val_16_patchgan_1val_plbolts_model_chkpt = "model/lightning_bolts_model/modified_path_gan.ckpt"
# model_checkpoint_path = "model/pix2pix_lightning_model/version_0/checkpoints/epoch=199-step=355600.ckpt"
# model_checkpoint_path = "model/pix2pix_lightning_model/gen.pth"

# Checkpoint files for the two models served by the demo.
train_64_val_16_plbolts_model_chkpt = "model/lightning_bolts_model/epoch=99-step=44600.ckpt"
train_16_val_1_plbolts_model_chkpt = "model/lightning_bolts_model/epoch=99-step=89000.ckpt"

# Load each checkpoint and switch it to evaluation mode straight away;
# eval() returns the module itself, so the call can be chained.
train_64_val_16_plbolts_model = OverpoweredPix2Pix.load_from_checkpoint(
    train_64_val_16_plbolts_model_chkpt
).eval()

train_16_val_1_plbolts_model = OverpoweredPix2Pix.load_from_checkpoint(
    train_16_val_1_plbolts_model_chkpt
).eval()
def predict(img: Image.Image, type_of_model: str):
    """Colour a sketch with the selected Pix2Pix generator.

    Args:
        img: Input sketch as a PIL image.
        type_of_model: Label of the model to run. Supported labels are
            "train batch size 16, val batch size 1" and
            "train batch size 64, val batch size 16".

    Returns:
        The path of the PNG file the generated image was saved to
        ("inference_image.png", relative to the working directory).

    Raises:
        ValueError: If ``type_of_model`` is not a supported label or
            ``img`` is None.
    """
    # Resolve the model first so an unsupported choice fails before any
    # image processing is done.
    if type_of_model == "train batch size 16, val batch size 1":
        model = train_16_val_1_plbolts_model
    elif type_of_model == "train batch size 64, val batch size 16":
        model = train_64_val_16_plbolts_model
    else:
        raise ValueError("NOT YET SUPPORTED")

    # The button can be clicked without an uploaded image; fail clearly
    # instead of deep inside the transform pipeline.
    if img is None:
        raise ValueError("No input image was provided")

    # Normalize below is configured with three channel means/stds, so make
    # sure PIL inputs in other modes (e.g. RGBA, L) are converted first.
    if isinstance(img, Image.Image):
        img = img.convert("RGB")
    image = np.asarray(img)

    # Inference-time preprocessing: resize to 256x256, scale pixel values
    # from [0, 255] to [-1, 1], and convert to a CHW tensor.
    inference_transform = A.Compose([
        A.Resize(width=256, height=256),
        A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
        al_pytorch.ToTensorV2(),
    ])
    # unsqueeze(0) adds the batch dimension the generator is called with.
    inference_img = inference_transform(image=image)['image'].unsqueeze(0)

    with torch.no_grad():
        result = model.gen(inference_img)

    # NOTE: every call writes to the same file name, so concurrent requests
    # overwrite each other's output.
    output_path = "inference_image.png"
    torchvision.utils.save_image(result, output_path, normalize=True)
    return output_path
def predict1(img: Image):
    """Run ``predict`` with the "train batch size 16, val batch size 1" model."""
    model_label = "train batch size 16, val batch size 1"
    return predict(img, model_label)
def predict2(img: Image):
    """Run ``predict`` with the "train batch size 64, val batch size 16" model."""
    model_label = "train batch size 64, val batch size 16"
    return predict(img, model_label)
# Example sketches offered in the UI.
img_examples = [
    "examples/thesis_test.png",
    "examples/thesis_test2.png",
    "examples/thesis1.png",
    "examples/thesis4.png",
    "examples/thesis5.png",
    "examples/thesis6.png",
]

# Radio selector listing the model variants by label.
model_input = gr.inputs.Radio(
    choices=[
        "train batch size 16, val batch size 1",
        "train batch size 64, val batch size 16",
        "train batch size 64, val batch size 16, patch gan has 1 output score instead of 16*16",
    ],
    label="Type of Pix2Pix model to use : ",
)

# Image input that hands the upload over as a PIL image.
image_input = gr.inputs.Image(type="pil")
# Build the demo page: a short description followed by one tab per model,
# each with an input image, an output image, a button and example sketches.
with gr.Blocks() as demo:
    gr.Markdown(" # Colour your sketches!")
    gr.Markdown(" ## Description :")
    # NOTE: three models are described here, but only the first two have a
    # tab below.
    gr.Markdown(" There are three Pix2Pix models in this example:")
    gr.Markdown(" 1. Training batch size is 16 , validation is 1")
    gr.Markdown(" 2. Training batch size is 64 , validation is 16")
    # The two literals are concatenated, so the first one must end with a
    # space to keep the words apart.
    gr.Markdown(" 3. PatchGAN is changed, 1 value only instead of 16*16 ; "
                "training batch size is 64 , validation is 16")
    with gr.Tabs():
        # Tab for the model selected by predict1.
        with gr.TabItem("tr_16_val_1"):
            with gr.Row():
                image_input1 = gr.inputs.Image(type="pil")
                image_output1 = gr.outputs.Image(type="pil", )
            colour_1 = gr.Button("Colour it!")
            gr.Examples(
                examples=img_examples,
                inputs=image_input1,
                outputs=image_output1,
                fn=predict1,
            )
        # Tab for the model selected by predict2 (train 64 / val 16).
        with gr.TabItem("tr_64_val_16"):
            with gr.Row():
                image_input2 = gr.inputs.Image(type="pil")
                image_output2 = gr.outputs.Image(type="pil", )
            colour_2 = gr.Button("Colour it!")
            with gr.Row():
                gr.Examples(
                    examples=img_examples,
                    inputs=image_input2,
                    outputs=image_output2,
                    fn=predict2,
                )
    # Wire each button to the prediction function of its tab.
    colour_1.click(
        fn=predict1,
        inputs=image_input1,
        outputs=image_output1,
    )
    colour_2.click(
        fn=predict2,
        inputs=image_input2,
        outputs=image_output2,
    )
    demo.title = "Colour your sketches!"
demo.launch()