"""Gradio demo that colours sketches with Pix2Pix generators (pl_bolts ``Pix2Pix`` subclass)."""
import os
import tempfile
from typing import Union, List

import gradio as gr
import matplotlib
import torch
from pytorch_lightning.utilities.types import EPOCH_OUTPUT

matplotlib.use('Agg')

import numpy as np
from PIL import Image
import albumentations as A
import albumentations.pytorch as al_pytorch
import torchvision
from pl_bolts.models.gans import Pix2Pix


class OverpoweredPix2Pix(Pix2Pix):
    """pl_bolts ``Pix2Pix`` that also logs validation losses and a sample image grid."""

    def validation_step(self, batch, batch_idx):
        """Log discriminator and generator losses for one validation batch.

        ``batch`` is unpacked as ``(real, condition)``, the same order the parent
        ``Pix2Pix`` uses. The tensors are returned so ``validation_epoch_end`` can
        render a sample.
        """
        real, condition = batch
        with torch.no_grad():
            loss = self._disc_step(real, condition)
            self.log("val_PatchGAN_loss", loss)

            loss = self._gen_step(real, condition)
            self.log("val_generator_loss", loss)

        # NOTE(review): key names are kept for compatibility, but 'sketch' holds
        # `real` (the generator's target) and 'colour' holds `condition` (the
        # generator's input) -- confirm against the dataset's tuple order.
        return {
            'sketch': real,
            'colour': condition
        }

    def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
        """Log a [generator input, target, generated] grid for the first validation sample.

        Does nothing when there are no validation outputs or no logger is attached.
        """
        if not outputs or self.logger is None:
            return

        first = outputs[0]
        real = first['sketch']
        condition = first['colour']
        with torch.no_grad():
            # The parent trains gen(condition) -> real (see Pix2Pix._gen_step),
            # so the sample must be generated from `condition`, not `real`.
            generated = self.gen(condition)

        grid_image = torchvision.utils.make_grid(
            [
                condition[0],
                real[0],
                generated[0],
            ],
            normalize=True,
        )
        self.logger.experiment.add_image(
            f'Image Grid {self.current_epoch}',
            grid_image,
            self.current_epoch,
        )


# --- Load the models ---------------------------------------------------------
# NOTE: a third variant (PatchGAN with a single output score,
# "model/lightning_bolts_model/modified_path_gan.ckpt") is described in the UI
# text but is not loaded here.
train_64_val_16_plbolts_model_chkpt = "model/lightning_bolts_model/epoch=99-step=44600.ckpt"
train_16_val_1_plbolts_model_chkpt = "model/lightning_bolts_model/epoch=99-step=89000.ckpt"

train_64_val_16_plbolts_model = OverpoweredPix2Pix.load_from_checkpoint(
    train_64_val_16_plbolts_model_chkpt
)
train_64_val_16_plbolts_model.eval()

train_16_val_1_plbolts_model = OverpoweredPix2Pix.load_from_checkpoint(
    train_16_val_1_plbolts_model_chkpt
)
train_16_val_1_plbolts_model.eval()

# Model selector strings accepted by `predict`.
_MODELS = {
    "train batch size 16, val batch size 1": train_16_val_1_plbolts_model,
    "train batch size 64, val batch size 16": train_64_val_16_plbolts_model,
}

# Built once: resize to 256x256, map pixels from [0, 255] to [-1, 1]
# ((x / 255 - 0.5) / 0.5), then convert the HWC array to a CHW tensor.
INFERENCE_TRANSFORM = A.Compose([
    A.Resize(width=256, height=256),
    A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
    al_pytorch.ToTensorV2(),
])


def predict(img: Image.Image, type_of_model: str) -> str:
    """Colour ``img`` with the selected Pix2Pix generator.

    Args:
        img: Input sketch. PIL images are converted to RGB so greyscale or RGBA
            inputs match the 3-channel normalisation.
        type_of_model: One of the keys of ``_MODELS``.

    Returns:
        Path to a newly created PNG file holding the generated image. A unique
        file is used per call so concurrent requests cannot overwrite each other.

    Raises:
        ValueError: If ``type_of_model`` is not one of the loaded models.
    """
    if isinstance(img, Image.Image):
        img = img.convert("RGB")
    image = np.asarray(img)
    inference_img = INFERENCE_TRANSFORM(image=image)['image'].unsqueeze(0)

    try:
        model = _MODELS[type_of_model]
    except KeyError:
        raise ValueError(f"NOT YET SUPPORTED: {type_of_model!r}") from None

    with torch.no_grad():
        result = model.gen(inference_img)

    fd, output_path = tempfile.mkstemp(prefix="inference_image_", suffix=".png")
    os.close(fd)  # save_image reopens the path itself
    torchvision.utils.save_image(result, output_path, normalize=True)
    return output_path


def predict1(img: Image.Image):
    """Colour ``img`` with the model trained at batch size 16 (validation 1)."""
    return predict(img=img, type_of_model="train batch size 16, val batch size 1")


def predict2(img: Image.Image):
    """Colour ``img`` with the model trained at batch size 64 (validation 16)."""
    return predict(img=img, type_of_model="train batch size 64, val batch size 16")


img_examples = [
    "examples/thesis_test.png",
    "examples/thesis_test2.png",
    "examples/thesis1.png",
    "examples/thesis4.png",
    "examples/thesis5.png",
    "examples/thesis6.png",
]

with gr.Blocks(title="Colour your sketches!") as demo:
    gr.Markdown(" # Colour your sketches!")
    gr.Markdown(" ## Description :")
    gr.Markdown(" There are three Pix2Pix models in this example:")
    gr.Markdown(" 1. Training batch size is 16 , validation is 1")
    gr.Markdown(" 2. Training batch size is 64 , validation is 16")
    gr.Markdown(" 3. PatchGAN is changed, 1 value only instead of 16*16 ; "
                "training batch size is 64 , validation is 16 "
                "(not yet available in this demo)")
    with gr.Tabs():
        with gr.TabItem("tr_16_val_1"):
            with gr.Row():
                image_input1 = gr.Image(type="pil")
                image_output1 = gr.Image(type="pil")
            colour_1 = gr.Button("Colour it!")
            gr.Examples(
                examples=img_examples,
                inputs=image_input1,
                outputs=image_output1,
                fn=predict1,
            )
        with gr.TabItem("tr_64_val_16"):
            with gr.Row():
                image_input2 = gr.Image(type="pil")
                image_output2 = gr.Image(type="pil")
            colour_2 = gr.Button("Colour it!")
            with gr.Row():
                gr.Examples(
                    examples=img_examples,
                    inputs=image_input2,
                    outputs=image_output2,
                    fn=predict2,
                )

    # Event listeners are registered inside the Blocks context.
    colour_1.click(
        fn=predict1,
        inputs=image_input1,
        outputs=image_output1,
    )
    colour_2.click(
        fn=predict2,
        inputs=image_input2,
        outputs=image_output2,
    )

if __name__ == "__main__":
    demo.launch()