# Spaces:
# Runtime error
# Runtime error
### 1. Imports and class names setup ###
import gradio as gr
import os
import torch
import utils
from typing import Tuple, Dict
from model import TransformerNet
from torchvision import transforms
from PIL import Image
from pathlib import Path

# Pretrained model checkpoints, sorted so their order is stable and
# lines up index-for-index with the style images below.
models = sorted(Path('models/').glob("*/*.pth.tar"))

# Style reference images, sorted to match the checkpoint ordering.
style_list = sorted(Path('style_images/').glob("*"))

# Example inputs for the Gradio gallery, one file path per example row.
example_list = [["examples/" + example] for example in os.listdir("examples")]
def transfer(image, model):
    """Apply a pretrained style-transfer network to a PIL image.

    Args:
        image: Input PIL image to stylize.
        model: Integer index (the Dropdown uses type='index') into the
            module-level sorted ``models`` / ``style_list`` lists.

    Returns:
        Tuple of (stylized PIL image, path of the matching style image).
    """
    device = 'cpu'
    width, height = image.size
    # Downscale large inputs to keep CPU inference fast.
    # BUG FIX: Image.thumbnail resizes in place and returns None; the
    # original assigned its result to a misspelled variable ("iamge").
    if width > 750 or height > 500:
        image.thumbnail((712, 474))

    # Load the selected checkpoint into a fresh TransformerNet on CPU.
    style_model = TransformerNet()
    state_dict = torch.load(models[int(model)], map_location=torch.device('cpu'))
    style_model.load_state_dict(state_dict["state_dict"])

    # Convert to a float tensor scaled to [0, 255], as the network expects.
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    content_image = content_transform(image).unsqueeze(0).to(device)

    # Run inference without tracking gradients.
    style_model.eval()
    with torch.no_grad():
        style_model.to(device)
        output = style_model(content_image).cpu()

    # Convert the network output back into a displayable PIL image.
    img = Image.fromarray(utils.deprocess(output[0]))
    return img, style_list[int(model)]
# Create title, description and article strings for the UI.
# (User-facing copy fixed: the originals had grammar/typo errors such as
# "Choose a image", "The app will be transfer your image", and "Tranfer".)
title = "Image Style Transfer"
description = ("Choose an image that you want to transfer and the corresponding style. "
               "The app will transfer the style onto your image and return the new image.")
article = ("The model was created based on the paper "
           "[Perceptual Losses for Real-Time Style Transfer and Super-Resolution]"
           "(https://arxiv.org/pdf/1603.08155v1.pdf).")

image_output_1 = gr.Image(label='Transfer')     # stylized output (typo fixed: was 'Tranfer')
image_output_2 = gr.Image(label='Style Image')  # preview of the chosen style

# Create the Gradio demo. The Dropdown uses type='index', so `transfer`
# receives an integer index into the sorted `models` list.
demo = gr.Interface(fn=transfer,  # mapping function from input to output
                    inputs=[gr.Image(type="pil", label='Input'),
                            gr.Dropdown(choices=[i.parent.name for i in models],
                                        value='rain_princess',
                                        type='index',
                                        label="Style",
                                        info="Choose the style image")],
                    outputs=[image_output_1, image_output_2],  # fn returns two values
                    examples=example_list,
                    cache_examples=False,
                    title=title,
                    description=description,
                    article=article)

# Launch the demo!
demo.launch()