sleepyml commited on
Commit
43a6711
·
1 Parent(s): 09c6f18
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import requests
3
+ from io import BytesIO
4
+ from PIL import Image
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torchvision
10
+ import torchvision.transforms as transforms
11
+ import gradio as gr
12
+
13
+ # torchvision transforms
14
+ normalize = transforms.Normalize(
15
+ mean = [0.485, 0.456, 0.406],
16
+ std = [0.229, 0.224, 0.225]
17
+ )
18
+
19
+ unnormalize = transforms.Compose([
20
+ transforms.Normalize(
21
+ mean = [0., 0., 0.],
22
+ std = [1/0.229, 1/0.224, 1/0.225]
23
+ ),
24
+ transforms.Normalize(
25
+ mean = [-0.485, -0.456, -0.406],
26
+ std = [1., 1., 1.]
27
+ )
28
+ ])
29
+
30
+
31
+ # inference script for huggingface space (assumes model is already loaded)
32
+ def inference(url, postprocess=True):
33
+ response = requests.get(url)
34
+ original_img = Image.open(BytesIO(response.content))
35
+ img = np.array(original_img)
36
+
37
+ img = cv2.resize(img, (512, 512))
38
+ img = img / 255
39
+ assert np.min(img) >= 0
40
+ assert np.max(img) <= 1
41
+
42
+ if len(img.shape) < 3: # grayscale coloring
43
+ x = torch.Tensor(img)
44
+ x = torch.stack([x, x, x], dim=0)
45
+ x = normalize(x)
46
+ x = x.unsqueeze(0)
47
+
48
+ else: # RGB reconstruction
49
+ x = torch.Tensor(img).permute(2, 0, 1)
50
+ x = normalize(x)
51
+ x_gs = cv2.cvtColor(x.permute(1, 2, 0).detach().cpu().numpy(), cv2.COLOR_BGR2GRAY)
52
+ x_gs = np.dstack([x_gs, x_gs, x_gs])
53
+ x = torch.Tensor(x_gs).permute(2, 0, 1).unsqueeze(0)
54
+
55
+ pred = model(x)
56
+ res = unnormalize(pred.squeeze(0))
57
+ res = res.clamp(0, 1)
58
+ res = res.permute(1, 2, 0).detach().cpu().numpy()
59
+
60
+ colored_img = cv2.resize(res, original_img.size)
61
+ colored_img = Image.fromarray((colored_img * 255).astype(np.uint8))
62
+
63
+ if postprocess and len(img.shape) >= 3:
64
+ colored_img = postprocess_img(original_img, colored_img)
65
+
66
+ return colored_img, original_img, original_img.convert('L')
67
+
68
+ # load torchscript model
69
+ model = torch.jit.load('torchscript/generator_torchscript.pt')
70
+ model = model.eval()
71
+
72
+ # gradio interface
73
+ iface = gr.Interface(fn=inference, inputs=["text", "bool"], outputs=['image', 'image', 'image'])
74
+ iface.launch()
utils.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import numpy as np
3
+ import cv2
4
+
5
+
6
+ def postprocess_img(original_img: Image.Image, colored_img: Image.Image) -> Image.Image:
7
+ original_np, colored_np = np.array(original_img), np.array(colored_img)
8
+ original_yuv = cv2.cvtColor(original_np, cv2.COLOR_BGR2YUV)
9
+ predicted_yuv = cv2.cvtColor(colored_np, cv2.COLOR_BGR2YUV)
10
+
11
+ processed_img = original_yuv.copy()
12
+ processed_img[:, :, 1:] = predicted_yuv[:, :, 1:]
13
+ processed_img = cv2.cvtColor(processed_img, cv2.COLOR_YUV2BGR)
14
+ processed_img = Image.fromarray(processed_img)
15
+ return processed_img
weights/.ipynb_checkpoints/README-checkpoint.md ADDED
@@ -0,0 +1 @@
 
 
1
+ # torchscript models
weights/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ # torchscript models