import gradio as gr
import torch
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
from src.model import Model
import os
# Run inference on GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"

def denorm_img(img: torch.Tensor):
    # Invert the ImageNet normalization applied in main() and clamp to [0, 1].
    # Expects a (C, H, W) tensor on the CPU.
    std = torch.tensor([0.229, 0.224, 0.225]).reshape(-1, 1, 1)
    mean = torch.tensor([0.485, 0.456, 0.406]).reshape(-1, 1, 1)
    return torch.clip(img * std + mean, min=0, max=1)
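
# Hypothetical round-trip check (names for illustration only): normalizing and
# then denormalizing should recover an image tensor up to clamping.
#   x = torch.rand(3, 8, 8)
#   norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
#   assert torch.allclose(denorm_img(norm(x)), x, atol=1e-6)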

def main(inp1, inp2, alph, out_size=256):
    # The size dropdown allows custom values, which Gradio delivers as strings.
    out_size = int(out_size)
    model = Model()
    model.load_state_dict(torch.load("./models/model_puddle.pt", map_location=torch.device(device)))
    model.to(device)  # load_state_dict does not move the module; place it on the device explicitly.
    model.eval()
    model.alpha = alph
    # The ImageEditor returns a dict; "composite" holds the edited/cropped image.
    style = TF.to_tensor(inp1["composite"])
    content = TF.to_tensor(inp2["composite"])
    norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([
        # Resize the shorter edge to the requested output size.
        transforms.Resize(out_size, antialias=True)
    ])
    style, content = norm(style), norm(content)
    style, content = transform(style), transform(content)
    style, content = style.unsqueeze(0).to(device), content.unsqueeze(0).to(device)
    with torch.no_grad():
        out = model(content, style)
    # Move to CPU before denormalizing so the result converts cleanly to numpy.
    return denorm_img(out[0].cpu()).permute(1, 2, 0).numpy()
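
# For reference: the core AdaIN operation that src.model.Model presumably
# implements (per the app title; Huang & Belongie, 2017). A minimal sketch
# under that assumption; the real Model additionally wraps this with an
# encoder, a decoder, and the content/style blend controlled by model.alpha.
def adain_reference(content_feat: torch.Tensor, style_feat: torch.Tensor,
                    eps: float = 1e-5) -> torch.Tensor:
    # Channel-wise statistics over the spatial dimensions of (N, C, H, W) maps.
    c_mean = content_feat.mean(dim=(2, 3), keepdim=True)
    c_std = content_feat.std(dim=(2, 3), keepdim=True) + eps
    s_mean = style_feat.mean(dim=(2, 3), keepdim=True)
    s_std = style_feat.std(dim=(2, 3), keepdim=True) + eps
    # Re-normalize the content features to match the style statistics.
    return s_std * (content_feat - c_mean) / c_std + s_mean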

def update_crop_size(crop_size):
    # Keep an editor's crop box in sync with the selected output size.
    return gr.update(crop_size=(crop_size, crop_size))

with gr.Blocks() as demo:
    gr.Markdown("# Style Transfer with AdaIN")
    with gr.Row(variant="compact", equal_height=False):
        inp1 = gr.ImageEditor(
            type="pil",
            sources=["upload", "clipboard"],
            crop_size=(256, 256),
            eraser=False,
            brush=False,
            layers=False,
            label="Style",
            image_mode="RGB",
            transforms="crop",
            canvas_size=(512, 512),
        )
        inp2 = gr.ImageEditor(
            type="pil",
            sources=["upload", "clipboard"],
            crop_size=(256, 256),
            eraser=False,
            brush=False,
            layers=False,
            label="Content",
            image_mode="RGB",
            transforms="crop",
            canvas_size=(512, 512),
        )
        out = gr.Image(type="pil", label="Output")
    with gr.Row():
        out_size = gr.Dropdown(
            choices=[256, 512],
            value=256,
            multiselect=False,
            interactive=True,
            allow_custom_value=True,
            label="Output size",
            info="Size of the output image",
        )
        # Changing the output size also updates both editors' crop boxes.
        out_size.change(fn=update_crop_size, inputs=out_size, outputs=inp1)
        out_size.change(fn=update_crop_size, inputs=out_size, outputs=inp2)
        alph = gr.Slider(
            0, 1, value=1, label="Alpha",
            info="Stylization strength: 0 leaves the content nearly unchanged, 1 applies the full style",
            interactive=True, scale=3,
        )
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Style Examples")
            gr.Examples(
                examples=[
                    os.path.join(os.path.dirname(__file__), "data/styles/25.jpg"),
                    os.path.join(os.path.dirname(__file__), "data/styles/2272.jpg"),
                    os.path.join(os.path.dirname(__file__), "data/styles/2314.jpg"),
                ],
                inputs=inp1,
            )
        with gr.Column():
            gr.Markdown("## Content Examples")
            gr.Examples(
                examples=[
                    # os.path.join(os.path.dirname(__file__), "data/content/bear.jpg"),
                    os.path.join(os.path.dirname(__file__), "data/content/cat.jpg"),
                    os.path.join(os.path.dirname(__file__), "data/content/cow.jpg"),
                    os.path.join(os.path.dirname(__file__), "data/content/ducks.jpg"),
                ],
                inputs=inp2,
            )
    btn = gr.Button("Run")
    btn.click(fn=main, inputs=[inp1, inp2, alph, out_size], outputs=out)

demo.launch()
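
# To run locally (layout inferred from the paths referenced above):
#   python app.py
# The app expects ./models/model_puddle.pt, ./src/model.py, and the example
# images under ./data/styles and ./data/content relative to this file.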