# NOTE(review): "Spaces: Runtime error" — status banner scraped from the hosting
# (HuggingFace Spaces) page, not part of the program; kept here as a comment.
| from typing import Union, List | |
| import gradio as gr | |
| import matplotlib | |
| import torch | |
| from pytorch_lightning.utilities.types import EPOCH_OUTPUT | |
| matplotlib.use('Agg') | |
| import numpy as np | |
| from PIL import Image | |
| import albumentations as A | |
| import albumentations.pytorch as al_pytorch | |
| import torchvision | |
| from pl_bolts.models.gans import Pix2Pix | |
| """ Class """ | |
class OverpoweredPix2Pix(Pix2Pix):
    """Pix2Pix subclass that adds validation: per-step loss logging and a
    per-epoch image grid (input / target / generated) sent to TensorBoard."""

    def validation_step(self, batch, batch_idx):
        """Log discriminator and generator losses for one validation batch.

        Returns the batch tensors so ``validation_epoch_end`` can render them.
        NOTE(review): the first element of the batch is labelled 'sketch' and
        the second 'colour' — confirm this matches the dataset's ordering.
        """
        real, condition = batch
        with torch.no_grad():
            disc_loss = self._disc_step(real, condition)
            self.log("val_PatchGAN_loss", disc_loss)
            gen_loss = self._gen_step(real, condition)
            self.log("val_generator_loss", gen_loss)
        return {'sketch': real, 'colour': condition}

    def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
        """Render one triplet (sketch, colour, generated) from the first
        validation batch and add it to the experiment logger."""
        first_batch = outputs[0]
        sketch = first_batch['sketch']
        colour = first_batch['colour']
        with torch.no_grad():
            generated = self.gen(sketch)
        # Only the first sample of the batch goes into the grid.
        grid = torchvision.utils.make_grid(
            [sketch[0], colour[0], generated[0]],
            normalize=True,
        )
        self.logger.experiment.add_image(
            f'Image Grid {str(self.current_epoch)}',
            grid,
            self.current_epoch,
        )
""" Load the model """
# train_64_val_16_patchgan_1val_plbolts_model_chkpt = "model/lightning_bolts_model/modified_path_gan.ckpt"
train_64_val_16_plbolts_model_chkpt = "model/lightning_bolts_model/epoch=99-step=44600.ckpt"
train_16_val_1_plbolts_model_chkpt = "model/lightning_bolts_model/epoch=99-step=89000.ckpt"
# model_checkpoint_path = "model/pix2pix_lightning_model/version_0/checkpoints/epoch=199-step=355600.ckpt"
# model_checkpoint_path = "model/pix2pix_lightning_model/gen.pth"


def _load_eval_model(checkpoint_path: str) -> OverpoweredPix2Pix:
    """Restore a checkpoint and put the model into inference mode."""
    model = OverpoweredPix2Pix.load_from_checkpoint(checkpoint_path)
    model.eval()
    return model


# Load the models
train_64_val_16_plbolts_model = _load_eval_model(train_64_val_16_plbolts_model_chkpt)
train_16_val_1_plbolts_model = _load_eval_model(train_16_val_1_plbolts_model_chkpt)
# Built once at import time instead of on every call: resize to the 256x256
# the generator expects, scale pixels to [-1, 1] (mean/std 0.5), and convert
# the HWC numpy image to a CHW tensor.
_INFERENCE_TRANSFORM = A.Compose([
    A.Resize(width=256, height=256),
    A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
    al_pytorch.ToTensorV2(),
])


def predict(img: Image.Image, type_of_model: str) -> str:
    """Colourize a sketch with the selected Pix2Pix generator.

    Args:
        img: input sketch as a PIL image.
        type_of_model: radio-button label selecting which checkpoint to use.

    Returns:
        Path of the PNG file the coloured result was written to.

    Raises:
        ValueError: if ``type_of_model`` does not name a loaded model.
    """
    # transform img
    image = np.asarray(img)
    # use on inference; unsqueeze adds the batch dimension the generator expects
    inference_img = _INFERENCE_TRANSFORM(image=image)['image'].unsqueeze(0)
    # Choose model
    if type_of_model == "train batch size 16, val batch size 1":
        model = train_16_val_1_plbolts_model
    elif type_of_model == "train batch size 64, val batch size 16":
        model = train_64_val_16_plbolts_model
    else:
        # The third radio option (1-score PatchGAN) has no loaded checkpoint.
        raise ValueError("NOT YET SUPPORTED")
    with torch.no_grad():
        result = model.gen(inference_img)
    torchvision.utils.save_image(result, "inference_image.png", normalize=True)
    return "inference_image.png"  # 'coloured_image.png',
def predict1(img: Image):
    """Colourize *img* using the checkpoint trained at batch size 16 (val 1)."""
    return predict(img, type_of_model="train batch size 16, val batch size 1")
def predict2(img: Image):
    """Colourize *img* using the checkpoint trained at batch size 64 (val 16)."""
    return predict(img, type_of_model="train batch size 64, val batch size 16")
# NOTE(review): gr.inputs.* is the deprecated pre-Blocks namespace (removed in
# current Gradio); the file already uses the Blocks API (gr.Blocks, gr.Examples),
# so the plain component classes are used here instead. These two components are
# built but never mounted in the layout below — the module-level names are kept
# for backward compatibility; confirm nothing else references them.
model_input = gr.Radio(
    [
        "train batch size 16, val batch size 1",
        "train batch size 64, val batch size 16",
        "train batch size 64, val batch size 16, patch gan has 1 output score instead of 16*16",
    ],
    label="Type of Pix2Pix model to use : ",
)
image_input = gr.Image(type="pil")

# Sketches offered as one-click examples in both tabs.
img_examples = [
    "examples/thesis_test.png",
    "examples/thesis_test2.png",
    "examples/thesis1.png",
    "examples/thesis4.png",
    "examples/thesis5.png",
    "examples/thesis6.png",
]
# Two-tab UI, one tab per loaded checkpoint. Title is passed to the Blocks
# constructor (the supported way to set it) instead of assigning demo.title.
with gr.Blocks(title="Colour your sketches!") as demo:
    gr.Markdown(" # Colour your sketches!")
    gr.Markdown(" ## Description :")
    gr.Markdown(" There are three Pix2Pix models in this example:")
    gr.Markdown(" 1. Training batch size is 16 , validation is 1")
    gr.Markdown(" 2. Training batch size is 64 , validation is 16")
    gr.Markdown(" 3. PatchGAN is changed, 1 value only instead of 16*16 ;"
                "training batch size is 64 , validation is 16")
    with gr.Tabs():
        with gr.TabItem("tr_16_val_1"):
            with gr.Row():
                # gr.Image replaces the removed gr.inputs/gr.outputs classes.
                image_input1 = gr.Image(type="pil")
                image_output1 = gr.Image(type="pil")
            colour_1 = gr.Button("Colour it!")
            gr.Examples(
                examples=img_examples,
                inputs=image_input1,
                outputs=image_output1,
                fn=predict1,
            )
        # Fixed typo: tab was labelled "tr_64_val_14" for the 64/16 model.
        with gr.TabItem("tr_64_val_16"):
            with gr.Row():
                image_input2 = gr.Image(type="pil")
                image_output2 = gr.Image(type="pil")
            colour_2 = gr.Button("Colour it!")
            with gr.Row():
                gr.Examples(
                    examples=img_examples,
                    inputs=image_input2,
                    outputs=image_output2,
                    fn=predict2,
                )
    # Wire each button to its model-specific prediction helper.
    colour_1.click(
        fn=predict1,
        inputs=image_input1,
        outputs=image_output1,
    )
    colour_2.click(
        fn=predict2,
        inputs=image_input2,
        outputs=image_output2,
    )

demo.launch()