Spaces:
Sleeping
Sleeping
File size: 4,335 Bytes
3743728 0f3cd84 5b29d56 764782e 3984452 41bee7b 0b8d9f0 0f3cd84 0b8d9f0 aef0baa 379e978 bd4b429 3a27073 764782e aef0baa 3a27073 0b8d9f0 aef0baa 0b8d9f0 3464fc9 3a27073 3743728 764782e 03c67b1 0b8d9f0 764782e 3743728 3464fc9 0b8d9f0 3464fc9 3743728 0f3cd84 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 | import numpy as np
import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
import torchvision.transforms as transforms
# Normalization layer used by every block of the networks below.
norm_layer = nn.InstanceNorm2d


class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 convolutions wrapped in a skip connection.

    Both convolutions map ``in_features`` channels to ``in_features`` channels,
    and a reflection pad of 1 in front of each 3x3 convolution keeps the
    spatial size unchanged. The branch output therefore has the input's shape
    and can be added to it.
    """

    def __init__(self, in_features):
        super().__init__()

        def _pad_conv_norm(channels):
            # Pad -> 3x3 conv -> norm, the unit that appears twice in the branch.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(channels, channels, 3),
                norm_layer(channels),
            ]

        # Layer order (and hence state_dict indices 0..6) is significant for
        # loading saved checkpoints.
        branch = [
            *_pad_conv_norm(in_features),
            nn.ReLU(inplace=True),
            *_pad_conv_norm(in_features),
        ]
        self.conv_block = nn.Sequential(*branch)

    def forward(self, x):
        """Return ``x`` plus the convolutional branch applied to ``x``."""
        residual = self.conv_block(x)
        return x + residual
class Generator(nn.Module):
    """Encoder / residual-trunk / decoder image-to-image generator.

    Stages (stored as ``model0`` .. ``model4``):
      0. 7x7 convolution after reflection padding of 3, ``input_nc`` -> ``ngf`` channels.
      1. Two stride-2 3x3 convolutions, each doubling the channel count.
      2. ``n_residual_blocks`` ResidualBlocks at width ``ngf * 4``.
      3. Two stride-2 transposed convolutions, each halving the channel count.
      4. 7x7 convolution to ``output_nc`` channels, plus ``nn.Sigmoid`` when
         ``sigmoid`` is True.

    For H and W divisible by 4 the output has the same spatial size as the
    input. The attribute names and layer order are what saved ``state_dict``
    checkpoints refer to, so they must not change.

    Args:
        input_nc: channel count of the input tensor.
        output_nc: channel count of the output tensor.
        n_residual_blocks: number of residual blocks in the trunk.
        sigmoid: append a sigmoid so outputs lie in (0, 1).
        ngf: channel width after the first convolution (default 64, the
            width the original code hard-coded).
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True, ngf=64):
        super(Generator, self).__init__()

        # Initial convolution block (padding of 3 offsets the 7x7 kernel).
        model0 = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, 7),
            norm_layer(ngf),
            nn.ReLU(inplace=True),
        ]
        self.model0 = nn.Sequential(*model0)

        # Downsampling: stride 2 halves H and W, channels double each step.
        model1 = []
        in_features = ngf
        out_features = in_features * 2
        for _ in range(2):
            model1 += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                norm_layer(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
            out_features = in_features * 2
        self.model1 = nn.Sequential(*model1)

        # Residual blocks at the bottleneck width.
        model2 = [ResidualBlock(in_features) for _ in range(n_residual_blocks)]
        self.model2 = nn.Sequential(*model2)

        # Upsampling: output_padding=1 makes each step exactly double H and W.
        model3 = []
        out_features = in_features // 2
        for _ in range(2):
            model3 += [
                nn.ConvTranspose2d(in_features, out_features, 3, stride=2,
                                   padding=1, output_padding=1),
                norm_layer(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
            out_features = in_features // 2
        self.model3 = nn.Sequential(*model3)

        # Output layer. After two halvings in_features == ngf again; deriving the
        # width here (instead of a literal 64) keeps head and stem consistent.
        model4 = [nn.ReflectionPad2d(3), nn.Conv2d(in_features, output_nc, 7)]
        if sigmoid:
            model4 += [nn.Sigmoid()]
        self.model4 = nn.Sequential(*model4)

    def forward(self, x, cond=None):
        """Run ``x`` through stages 0-4 in order.

        ``cond`` is accepted for call-site compatibility but is not used.
        """
        out = self.model0(x)
        out = self.model1(out)
        out = self.model2(out)
        out = self.model3(out)
        out = self.model4(out)
        return out
def _load_generator(weights_path):
    """Build a Generator(3, 1, 3), load its weights onto the CPU and set eval mode."""
    model = Generator(3, 1, 3)
    # NOTE(review): torch.load unpickles the file. These checkpoints ship with
    # the app, but never point this at untrusted files; on torch >= 1.13
    # consider passing weights_only=True.
    state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model


# One pretrained model per drawing style offered in the UI.
anime = _load_generator('anime.pth')
contour = _load_generator('contour.pth')
opensketch = _load_generator('opensketch.pth')
def predict(input_img, ver):
    """Generate a line drawing from the image file at ``input_img``.

    Args:
        input_img: path to an image file (the Gradio input uses type='filepath').
        ver: 'anime' or 'contour' selects that model; any other value falls
            back to the opensketch model.

    Returns:
        A PIL image built from the model's first (and only) output sample.
    """
    # Unknown versions fall through to opensketch, as before.
    models = {'anime': anime, 'contour': contour}
    model = models.get(ver, opensketch)

    # Close the file handle once decoded. The generators are built with
    # input_nc=3, so force RGB: RGBA PNGs or grayscale uploads would otherwise
    # yield 4- or 1-channel tensors and fail in the first convolution.
    with Image.open(input_img) as img:
        rgb = img.convert('RGB')

    # transform = transforms.Compose([transforms.Resize(256, Image.BICUBIC), transforms.ToTensor()])
    transform = transforms.Compose([transforms.ToTensor()])
    batch = torch.unsqueeze(transform(rgb), 0)  # add batch dimension

    with torch.no_grad():
        drawing = model(batch)[0]

    return transforms.ToPILImage()(drawing)
# Text and examples shown in the Gradio UI.
title = "informative-drawings"
description = "Gradio Demo for line drawing generation. "
# article = "<p style='text-align: center'><a href='TODO' target='_blank'>Project Page</a> | <a href='codelink' target='_blank'>Github</a></p>"
examples = [['cat.png', 'anime'], ['bridge.png', 'contour'], ['lizard.png', 'opensketch'],]

iface = gr.Interface(
    predict,
    [
        gr.Image(type='filepath'),
        gr.Radio(['anime', 'opensketch', 'contour'], type="value", value='contour', label='version'),
    ],
    gr.Image(type="pil"),
    title=title,
    description=description,
    examples=examples,
)

# Launch only when run as a script (e.g. `python app.py`), so importing this
# module does not start, and block on, a web server.
if __name__ == "__main__":
    iface.launch()