akshatjain1004 commited on
Commit
1753ad9
·
1 Parent(s): 6907eb4

uploads app.py

Browse files
Files changed (1) hide show
  1. app.py +308 -0
app.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import sys
from typing import Dict
# Prepend a locally patched copy of gradio (used to allow uploading a
# colorful mask in the sketch tool — see the comment in colorize()).
sys.path.insert(0, 'gradio-modified')

import gradio as gr
import numpy as np
import torch.nn as nn
from PIL import Image

import torch

# Pick the inference device: use CUDA only when enough memory is available.
if torch.cuda.is_available():
    t = torch.cuda.get_device_properties(0).total_memory
    r = torch.cuda.memory_reserved(0)
    a = torch.cuda.memory_allocated(0)
    # NOTE(review): this is total - allocated; the original comment said
    # "free inside reserved", which would be r - a (`r` is unused) — confirm intent.
    f = t-a
    # Fall back to CPU when less than 4 GiB (2**32 bytes) appears free.
    if f < 2**32:
        device = 'cpu'
    else:
        device = 'cuda'
else:
    device = 'cpu'
# Private TorchScript API: disable bailout/profiling re-compilation depth.
torch._C._jit_set_bailout_depth(0)

print('Use device:', device)


# Scripted hint-based colorization model; weights are exported per device.
net = torch.jit.load(f'weights/pkp-v1.{device}.jit.pt')
30
class BaseColor(nn.Module):
    """Holds CIE Lab normalization constants shared by colorization nets.

    The L (lightness) channel is centered at ``l_cent`` and scaled by
    ``l_norm``; the ab (chroma) channels are scaled by ``ab_norm``.
    """

    def __init__(self):
        super().__init__()
        # CIE Lab value ranges: L in [0, 100], ab roughly in [-110, 110].
        self.l_cent = 50.
        self.l_norm = 100.
        self.ab_norm = 110.

    def normalize_l(self, in_l):
        """Map L from [0, 100] to approximately [-0.5, 0.5]."""
        return (in_l - self.l_cent) / self.l_norm

    def normalize_ab(self, in_ab):
        """Map ab from [-110, 110] to approximately [-1, 1]."""
        return in_ab / self.ab_norm

    def unnormalize_l(self, in_l):
        """Inverse of :meth:`normalize_l`."""
        return in_l * self.l_norm + self.l_cent

    def unnormalize_ab(self, in_ab):
        """Inverse of :meth:`normalize_ab`."""
        return in_ab * self.ab_norm
52
class ECCVGenerator(BaseColor):
    """ECCV16-style colorization generator.

    Takes a single-channel lightness tensor of shape (B, 1, H, W) and
    predicts chroma. Eight conv stages (model1..model8) downsample by 8x
    overall, a 1x1 conv maps 313 color-bin logits to 2 ab channels, and a
    4x bilinear upsample restores half... (net stride 8, up 4 — output is
    upsampled from the 1/4-resolution prediction).
    """

    def __init__(self, norm_layer=nn.BatchNorm2d):
        super(ECCVGenerator, self).__init__()

        # Stage 1: 1 -> 64, downsample x2.
        model1=[nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=True),]
        model1+=[nn.ReLU(True),]
        model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=True),]
        model1+=[nn.ReLU(True),]
        model1+=[norm_layer(64),]

        # Stage 2: 64 -> 128, downsample x2.
        model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
        model2+=[nn.ReLU(True),]
        model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=True),]
        model2+=[nn.ReLU(True),]
        model2+=[norm_layer(128),]

        # Stage 3: 128 -> 256, downsample x2.
        model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model3+=[nn.ReLU(True),]
        model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model3+=[nn.ReLU(True),]
        model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=True),]
        model3+=[nn.ReLU(True),]
        model3+=[norm_layer(256),]

        # Stage 4: 256 -> 512, stride 1.
        model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model4+=[nn.ReLU(True),]
        model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model4+=[nn.ReLU(True),]
        model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model4+=[nn.ReLU(True),]
        model4+=[norm_layer(512),]

        # Stages 5-6: dilated convs (dilation=2) widen the receptive field
        # without further downsampling.
        model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model5+=[nn.ReLU(True),]
        model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model5+=[nn.ReLU(True),]
        model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model5+=[nn.ReLU(True),]
        model5+=[norm_layer(512),]

        model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model6+=[nn.ReLU(True),]
        model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model6+=[nn.ReLU(True),]
        model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model6+=[nn.ReLU(True),]
        model6+=[norm_layer(512),]

        # Stage 7: plain 512-channel convs.
        model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model7+=[nn.ReLU(True),]
        model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model7+=[nn.ReLU(True),]
        model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model7+=[nn.ReLU(True),]
        model7+=[norm_layer(512),]

        # Stage 8: upsample x2 (transposed conv), then predict logits over
        # 313 quantized ab color bins.
        model8=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True),]
        model8+=[nn.ReLU(True),]
        model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model8+=[nn.ReLU(True),]
        model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model8+=[nn.ReLU(True),]

        model8+=[nn.Conv2d(256, 313, kernel_size=1, stride=1, padding=0, bias=True),]

        # Attribute names model1..model8 must stay exactly as-is: they are
        # the keys of the checkpoint loaded via load_state_dict().
        self.model1 = nn.Sequential(*model1)
        self.model2 = nn.Sequential(*model2)
        self.model3 = nn.Sequential(*model3)
        self.model4 = nn.Sequential(*model4)
        self.model5 = nn.Sequential(*model5)
        self.model6 = nn.Sequential(*model6)
        self.model7 = nn.Sequential(*model7)
        self.model8 = nn.Sequential(*model8)

        self.softmax = nn.Softmax(dim=1)
        # Soft-argmax readout: softmax over the 313 bins, then a fixed 1x1
        # conv projects bin probabilities to 2 ab channels.
        self.model_out = nn.Conv2d(313, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=False)
        self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear')

    def forward(self, input_l):
        """Predict chroma for a lightness tensor of shape (B, 1, H, W)."""
        conv1_2 = self.model1(self.normalize_l(input_l))
        conv2_2 = self.model2(conv1_2)
        conv3_3 = self.model3(conv2_2)
        conv4_3 = self.model4(conv3_3)
        conv5_3 = self.model5(conv4_3)
        conv6_3 = self.model6(conv5_3)
        conv7_3 = self.model7(conv6_3)
        conv8_3 = self.model8(conv7_3)
        out_reg = self.model_out(self.softmax(conv8_3))

        x= self.unnormalize_ab(self.upsample4(out_reg))
        # Append an all-zero third channel so the output has 3 channels.
        # NOTE(review): result is (a, b, 0), not a full Lab/RGB image — the
        # caller treats it as an RGB array directly; confirm this is intended.
        zeros = torch.zeros_like(x[:, :1, :, :])
        x = torch.cat([x, zeros], dim=1)
        return x
147
# model_net = torch.load(f'weights/colorizer.pt')
# Local ECCV16 colorizer used by the second UI row (colorize2).
model_net = ECCVGenerator()
# NOTE(review): the checkpoint filename genuinely contains " (1)". The model
# is neither moved to `device` nor switched to eval() — BatchNorm therefore
# runs in training mode at inference time; confirm both are intended.
model_net.load_state_dict(torch.load(f'weights/colorizer (1).pt', map_location=torch.device('cpu')))
152
def resize_original(img: Image.Image):
    """Turn an upload into a grayscale guide image padded for the model.

    The shorter side is scaled to 256, then both dimensions are padded up to
    the next multiple of 64 (the scripted net's stride) with white pixels.

    Returns a pair: an update for the sketch canvas and the guide image for
    the hidden store component.
    """
    if img is None:
        # BUG FIX: this handler is wired to two outputs ([inp, inp_store]),
        # so the early exit must also return two values (the original
        # returned a single value here).
        return None, None
    if isinstance(img, dict):
        img = img["image"]

    guide_img = img.convert('L')
    # Scale so that min(width, height) == 256, preserving aspect ratio.
    scale = 256 / min(guide_img.size)
    guide_img = guide_img.resize([int(round(s * scale)) for s in guide_img.size], Image.Resampling.LANCZOS)

    guide = np.asarray(guide_img)
    h, w = guide.shape[-2:]
    # Round each dimension up to a multiple of 64 and split the padding
    # evenly between the two sides; pad value 255 = white background.
    rows = int(np.ceil(h / 64)) * 64
    cols = int(np.ceil(w / 64)) * 64
    ph_1 = (rows - h) // 2
    ph_2 = rows - h - ph_1
    pw_1 = (cols - w) // 2
    pw_2 = cols - w - pw_1
    guide = np.pad(guide, ((ph_1, ph_2), (pw_1, pw_2)), mode='constant', constant_values=255)
    guide_img = Image.fromarray(guide)

    # RGBA so the (patched) sketch tool can draw colored hints on top.
    return gr.Image.update(value=guide_img.convert('RGBA')), guide_img.convert('RGBA')
177
def resize_original2(img: Image.Image):
    """Resize an upload for the second (plain colorization) input to 256x256."""
    if img is None:
        return img
    if isinstance(img, dict):
        img = img["image"]

    # BUG FIX: PIL's Image.resize takes the size as a single (width, height)
    # tuple; the original `img.resize(256, 256)` raised a TypeError.
    img = img.resize((256, 256))

    return img
187
+
188
def colorize(img: Dict[str, Image.Image], guide_img: Image.Image, seed: int, hint_mode: str):
    """Run the scripted hint-based colorizer.

    Parameters:
        img: dict from the sketch component, with "image" and "mask" keys.
        guide_img: the stored, already-resized grayscale guide.
        seed: RNG seed for the latent noise tensors.
        hint_mode: "Roughly Hint" or "Precisely Hint".

    Returns an RGB PIL image, or a gradio update when nothing was uploaded.
    """
    if not isinstance(img, dict):
        # Nothing drawn/uploaded yet: the sketch tool delivers dicts.
        return gr.update(visible=True)

    if hint_mode == "Roughly Hint":
        hint_mode_int = 0
    elif hint_mode == "Precisely Hint":
        # NOTE(review): same value as "Roughly Hint", so the radio currently
        # has no effect — confirm whether the scripted net expects 1 here.
        hint_mode_int = 0
    else:
        # BUG FIX: the original left hint_mode_int undefined for any other
        # value, which would raise NameError below.
        hint_mode_int = 0

    guide_img = guide_img.convert('L')
    hint_img = img["mask"].convert('RGBA')  # patched gradio allows uploading a colorful mask

    # Scale pixel values from [0, 255] to [-1, 1].
    guide = torch.from_numpy(np.asarray(guide_img))[None, None].float().to(device) / 255.0 * 2 - 1
    hint = torch.from_numpy(np.asarray(hint_img)).permute(2, 0, 1)[None].float().to(device) / 255.0 * 2 - 1
    # Keep RGB only where the mask alpha is (nearly) opaque; elsewhere mark
    # the pixel as "no hint" with the sentinel value -2.
    hint_alpha = (hint[:, -1:] > 0.99).float()
    hint = hint[:, :3] * hint_alpha - 2 * (1 - hint_alpha)

    np.random.seed(int(seed))
    b, c, h, w = hint.shape
    # Latent noise is generated at 1/8 resolution (the net's stride).
    h //= 8
    w //= 8
    noises = [torch.from_numpy(np.random.randn(b, c, h, w)).float().to(device) for _ in range(16 + 1)]

    with torch.inference_mode():
        sample = net(noises, guide, hint, hint_mode_int)
        # (C, H, W) -> (H, W, C), then map [-1, 1] back to [0, 255].
        out = sample[0].cpu().numpy().transpose([1, 2, 0])
        out = np.uint8(((out + 1) / 2 * 255).clip(0, 255))

    return Image.fromarray(out).convert('RGB')
219
def colorize2(img: Image.Image, model_option: str):
    """Colorize an upload with the local ECCV16 generator (second UI row).

    Returns an RGB PIL image, or a gradio update when nothing was uploaded.
    """
    # BUG FIX: inp2 has no sketch tool, so a valid upload arrives as a plain
    # PIL image; the original `if not isinstance(img, dict)` guard therefore
    # rejected every real upload. Only bail out when nothing was uploaded,
    # and unwrap a dict if one ever arrives.
    if img is None:
        return gr.update(visible=True)
    if isinstance(img, dict):
        img = img["image"]

    if model_option == "Model 1":
        model_int = 0
    elif model_option == "Model 2":
        # NOTE(review): same value as "Model 1" and `model_int` is never
        # used — confirm whether a second model was meant to be wired in.
        model_int = 0
    else:
        # Defensive fallback; the original left model_int undefined here.
        model_int = 0

    # The generator's first conv expects a single channel, so reduce to
    # luminance before building the (1, 1, H, W) input tensor.
    img = img.convert('L')
    # NOTE(review): this scales L to [-1, 1], but ECCVGenerator.normalize_l
    # assumes L in [0, 100] — confirm the training-time input convention.
    input_l = torch.from_numpy(np.asarray(img))[None, None].float().to(device) / 255.0 * 2 - 1
    with torch.inference_mode():
        out2 = model_net(input_l).squeeze()
        print(out2.shape)
        # BUG FIX: the original referenced undefined names `sample` and
        # `out` here (NameError); operate on `out2` throughout.
        out2 = out2.cpu().numpy().transpose([1, 2, 0])
        out2 = np.uint8(((out2 + 1) / 2 * 255).clip(0, 255))

    return Image.fromarray(out2).convert('RGB')
237
with gr.Blocks() as demo:
    gr.Markdown('''<center><h1>Image Colorization With Hint</h1></center>
<h2>Colorize your images/sketches with hint points.</h2>
<br />
''')
    # --- Row 1: hint-based colorization with the scripted pkp-v1 model ---
    with gr.Row():
        with gr.Column():
            inp = gr.Image(
                source="upload",
                tool="sketch",  # tool="color-sketch" would blend the hint into the image
                type="pil",
                label="Sketch",
                interactive=True,
                elem_id="sketch-canvas"
            )
            # Hidden copy of the resized guide image, filled by resize_original
            # and passed to colorize alongside the user's hint mask.
            inp_store = gr.Image(
                type="pil",
                interactive=False
            )
            inp_store.visible = False
        with gr.Column():
            seed = gr.Slider(1, 2**32, step=1, label="Seed", interactive=True, randomize=True)
            hint_mode = gr.Radio(["Roughly Hint", "Precisely Hint"], value="Roughly Hint", label="Hint Mode")
            btn = gr.Button("Run")
        with gr.Column():
            output = gr.Image(type="pil", label="Output", interactive=False)
    # --- Row 2: plain colorization with the local ECCV16 generator ---
    with gr.Row():
        with gr.Column():
            inp2 = gr.Image(
                source="upload",
                type="pil",
                label="Sketch",
                interactive=True
            )
            inp_store2 = gr.Image(
                type="pil",
                interactive=False
            )
            inp_store2.visible = False
        with gr.Column():
            # NOTE(review): label "Model 2" looks like a copy-paste slip for
            # something like "Model Option" — confirm before changing UI text.
            model_option = gr.Radio(["Model 1", "Model 2"], value="Model 1", label="Model 2")
            btn2 = gr.Button("Run Colorization")
        with gr.Column():
            output2 = gr.Image(type="pil", label="Output2", interactive=False)
    gr.Markdown('''
Upon uploading an image, kindly give color hints at specific points, and then run the model. Average inference time is about 52 seconds.<br />
''')
    gr.Markdown('''Authors: <a href=\"https://www.linkedin.com/in/chakshu-dhannawat/">Chakshu Dhannawat</a>, <a href=\"https://www.linkedin.com/in/navlika-singh-963120204/">Navlika Singh</a>,<a href=\"https://www.linkedin.com/in/akshat-jain-103550201/"> Akshat Jain</a>''')
    inp.upload(
        resize_original,
        inp,
        [inp, inp_store],
    )
    # BUG FIX: this handler previously read from and wrote back to `inp`
    # (the row-1 component) instead of `inp2`, so row-2 uploads were never
    # resized and the row-1 canvas could be clobbered.
    inp2.upload(
        resize_original2,
        inp2,
        inp2
    )
    btn.click(
        colorize,
        [inp, inp_store, seed, hint_mode],
        output
    )
    btn2.click(
        colorize2,
        [inp2, model_option],
        output2
    )

if __name__ == "__main__":
    demo.launch()