File size: 2,512 Bytes
43a6711
 
 
 
 
 
 
 
 
 
 
 
8aaebb0
 
 
43a6711
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8aaebb0
e8fadb7
 
7ef6899
 
 
 
 
 
43a6711
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ef6899
43a6711
 
 
 
 
 
 
 
7ef6899
 
 
e8fadb7
 
7ef6899
0f91e66
7ef6899
43a6711
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import cv2
import requests
from io import BytesIO
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import gradio as gr

from utils import *


# Per-channel mean / std (RGB order; the standard ImageNet statistics) shared
# by the forward and inverse normalization below.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# torchvision transforms: (x - mean) / std, applied per channel.
normalize = transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD)

# Inverse of `normalize`, done in two steps because Normalize only computes
# (x - mean) / std: first multiply by std (divide by 1/std with zero mean),
# then add the mean back (subtract -mean with unit std).
unnormalize = transforms.Compose([
    transforms.Normalize(mean=[0.0] * 3, std=[1 / s for s in _CHANNEL_STD]),
    transforms.Normalize(mean=[-m for m in _CHANNEL_MEAN], std=[1.0] * 3),
])


# Inference entry point for the Hugging Face Space (expects the module-level `model` to be loaded before it is called).
# TODO: let the interface accept either a drag-and-drop image or a URL.
def _load_image(url, pil_img):
    """Return the input image as a PIL image in mode 'L' or 'RGB'.

    Downloads ``url`` when it is non-empty, otherwise falls back to
    ``pil_img``. Any other mode (e.g. RGBA, palette 'P', CMYK, 1-bit) is
    converted to 'RGB' so the array built from it is either 2-D or has
    exactly three channels.

    Raises:
        ValueError: If neither ``url`` nor ``pil_img`` is provided.
        requests.HTTPError: If the download returns an HTTP error status.
    """
    if url:
        # Bounded timeout so a stalled host cannot hang the request handler.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
    elif pil_img is not None:
        img = pil_img
    else:
        raise ValueError('Provide an image URL or an image.')

    if img.mode not in ('L', 'RGB'):
        img = img.convert('RGB')
    return img


def inference(url=None, pil_img=None):
    """Colorize an image given by URL or as an in-memory PIL image.

    A grayscale input is colorized directly. A color input is reduced to
    grayscale, recolored by the model ("reconstruction"), and then passed
    through ``postprocess_img`` (from ``utils``) together with the original.

    Args:
        url: Address of the image to download; used when non-empty.
        pil_img: Optional PIL image used when ``url`` is empty.

    Returns:
        ``(colored_img, original_img, original_img.convert('L'))`` as PIL
        images. Before any post-processing, ``colored_img`` is resized back
        to ``original_img``'s width and height.

    Raises:
        ValueError: If neither ``url`` nor ``pil_img`` is provided.
        requests.HTTPError: If downloading ``url`` fails.
    """
    original_img = _load_image(url, pil_img)

    img = np.array(original_img)
    # Fixed 512x512 network input (presumably the training resolution).
    img = cv2.resize(img, (512, 512))
    img = img / 255
    assert np.min(img) >= 0
    assert np.max(img) <= 1

    is_gray = img.ndim == 2
    if is_gray:  # grayscale coloring
        x = torch.Tensor(img)
        x = torch.stack([x, x, x], dim=0)
        x = normalize(x)
        x = x.unsqueeze(0)
    else:  # RGB reconstruction: feed the model a grayscale copy of the image
        x = torch.Tensor(img).permute(2, 0, 1)
        # NOTE(review): this branch normalizes before converting to gray,
        # while the grayscale branch normalizes after stacking; the two paths
        # give the model differently-distributed inputs. Confirm which one
        # matches training.
        x = normalize(x)
        # NOTE(review): PIL arrays are RGB, but COLOR_BGR2GRAY swaps the R/B
        # weights. Kept unchanged in case training used the same conversion.
        x_gs = cv2.cvtColor(x.permute(1, 2, 0).detach().cpu().numpy(), cv2.COLOR_BGR2GRAY)
        x_gs = np.dstack([x_gs, x_gs, x_gs])
        x = torch.Tensor(x_gs).permute(2, 0, 1).unsqueeze(0)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        pred = model(x)
    res = unnormalize(pred.squeeze(0))
    res = res.clamp(0, 1)
    res = res.permute(1, 2, 0).detach().cpu().numpy()

    # PIL's .size is (width, height), the same dsize order cv2.resize takes.
    colored_img = cv2.resize(res, original_img.size)
    colored_img = Image.fromarray((colored_img * 255).astype(np.uint8))

    if not is_gray:
        colored_img = postprocess_img(original_img, colored_img)

    return colored_img, original_img, original_img.convert('L')

# load torchscript model
# Load the TorchScript generator onto the CPU. inference() builds its input
# tensors on the CPU (torch.Tensor defaults), and map_location also keeps a
# checkpoint that was saved from a GPU loadable on CPU-only hosts.
model = torch.jit.load('torchscript/generator_torchscript.pt', map_location='cpu')
model = model.eval()

# Gradio UI: one URL textbox in, three images out, in the order inference()
# returns them (colorized result, original, grayscale original).
# TODO: also accept drag-and-drop uploads, e.g.
#   inputs=["text", gr.Image(type='pil')]
iface = gr.Interface(
    fn=inference,
    inputs='text',
    outputs=['image', 'image', 'image'],
    title='paste a url (right click -> copy image address)',
)
iface.launch()