|
|
import cv2 |
|
|
import requests |
|
|
from io import BytesIO |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import torchvision |
|
|
import torchvision.transforms as transforms |
|
|
import gradio as gr |
|
|
|
|
|
from utils import * |
|
|
|
|
|
|
|
|
|
|
|
# ImageNet channel statistics the pretrained generator was trained with.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Map an RGB tensor in [0, 1] into the normalized space the model expects.
normalize = transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD)

# Exact inverse of `normalize`: divide out the std first, then add the
# mean back (done as two Normalize steps since Normalize computes
# (x - mean) / std).
unnormalize = transforms.Compose([
    transforms.Normalize(mean=[0., 0., 0.],
                         std=[1 / s for s in _IMAGENET_STD]),
    transforms.Normalize(mean=[-m for m in _IMAGENET_MEAN],
                         std=[1., 1., 1.]),
])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def inference(url=None):
    """Colorize an image and return it alongside the original.

    Parameters
    ----------
    url : str or None
        HTTP(S) address of an image to fetch. When falsy, the
        module-level ``pil_img`` is used instead — presumably exported
        by ``utils`` via the star import; TODO confirm.

    Returns
    -------
    tuple
        ``(colored_img, original_img, grayscale_img)`` — all PIL images;
        the last is the original converted to mode ``'L'``.

    Raises
    ------
    requests.RequestException
        On network failure, timeout, or a non-2xx HTTP response.
    """
    if url:
        # Bounded timeout so a dead host cannot hang the worker forever,
        # and fail fast on HTTP errors instead of handing PIL an error
        # page to decode.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        original_img = Image.open(BytesIO(response.content))
    else:
        # NOTE(review): `pil_img` is not defined in this file — assumed
        # to come from `from utils import *`; verify.
        original_img = pil_img

    # Resize to the fixed 512x512 input the generator was traced with,
    # then scale pixel values into [0, 1].
    img = np.array(original_img)
    img = cv2.resize(img, (512, 512))
    img = img / 255.0
    assert np.min(img) >= 0
    assert np.max(img) <= 1

    if img.ndim < 3:
        # Already single-channel: replicate to 3 channels so the
        # 3-channel ImageNet normalization (and the model) get an
        # RGB-shaped tensor.
        x = torch.Tensor(img)
        x = torch.stack([x, x, x], dim=0)
        x = normalize(x)
        x = x.unsqueeze(0)
    else:
        # Color input: normalize, then collapse to grayscale so the
        # generator always receives a gray image to colorize.
        x = torch.Tensor(img).permute(2, 0, 1)
        x = normalize(x)
        x_gs = cv2.cvtColor(x.permute(1, 2, 0).detach().cpu().numpy(),
                            cv2.COLOR_BGR2GRAY)
        x_gs = np.dstack([x_gs, x_gs, x_gs])
        x = torch.Tensor(x_gs).permute(2, 0, 1).unsqueeze(0)

    # Pure inference — disable autograd bookkeeping to save memory/time.
    with torch.no_grad():
        pred = model(x)
    res = unnormalize(pred.squeeze(0))
    res = res.clamp(0, 1)
    res = res.permute(1, 2, 0).detach().cpu().numpy()

    # Map back to the source resolution; PIL's `.size` is
    # (width, height), which is exactly the dsize order cv2.resize
    # expects.
    colored_img = cv2.resize(res, original_img.size)
    colored_img = Image.fromarray((colored_img * 255).astype(np.uint8))

    if img.ndim >= 3:
        # Color sources get the extra cleanup pass from utils.
        colored_img = postprocess_img(original_img, colored_img)

    return colored_img, original_img, original_img.convert('L')
|
|
|
|
|
|
|
|
# Load the pre-traced generator and switch it to evaluation mode
# (disables dropout / batch-norm updates) in one step.
model = torch.jit.load('torchscript/generator_torchscript.pt').eval()
|
|
|
|
|
|
|
|
# Wire the colorizer into a simple Gradio UI: one text box for the image
# URL in, three image panes out (colorized, original, grayscale).
iface = gr.Interface(
    fn=inference,
    inputs='text',
    outputs=['image', 'image', 'image'],
    title='paste a url (right click -> copy image address)',
)
iface.launch()