SpyroSigma commited on
Commit
6e40561
·
verified ·
1 Parent(s): b42bedb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +408 -393
app.py CHANGED
@@ -1,393 +1,408 @@
1
- import matplotlib.pyplot as plt
2
- import streamlit as st
3
-
4
- # -------------------- base color ------------------
5
-
6
- import torch
7
- from torch import nn
8
-
9
- class BaseColor(nn.Module):
10
- def __init__(self):
11
- super(BaseColor, self).__init__()
12
-
13
- self.l_cent = 50.
14
- self.l_norm = 100.
15
- self.ab_norm = 110.
16
-
17
- def normalize_l(self, in_l):
18
- return (in_l-self.l_cent)/self.l_norm
19
-
20
- def unnormalize_l(self, in_l):
21
- return in_l*self.l_norm + self.l_cent
22
-
23
- def normalize_ab(self, in_ab):
24
- return in_ab/self.ab_norm
25
-
26
- def unnormalize_ab(self, in_ab):
27
- return in_ab*self.ab_norm
28
-
29
- # ------------------ eccv16 ---------------------
30
-
31
- import numpy as np
32
-
33
-
34
- class ECCVGenerator(BaseColor):
35
- def __init__(self, norm_layer=nn.BatchNorm2d):
36
- super(ECCVGenerator, self).__init__()
37
-
38
- model1=[nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=True),]
39
- model1+=[nn.ReLU(True),]
40
- model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=True),]
41
- model1+=[nn.ReLU(True),]
42
- model1+=[norm_layer(64),]
43
-
44
- model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
45
- model2+=[nn.ReLU(True),]
46
- model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=True),]
47
- model2+=[nn.ReLU(True),]
48
- model2+=[norm_layer(128),]
49
-
50
- model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),]
51
- model3+=[nn.ReLU(True),]
52
- model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
53
- model3+=[nn.ReLU(True),]
54
- model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=True),]
55
- model3+=[nn.ReLU(True),]
56
- model3+=[norm_layer(256),]
57
-
58
- model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),]
59
- model4+=[nn.ReLU(True),]
60
- model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
61
- model4+=[nn.ReLU(True),]
62
- model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
63
- model4+=[nn.ReLU(True),]
64
- model4+=[norm_layer(512),]
65
-
66
- model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
67
- model5+=[nn.ReLU(True),]
68
- model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
69
- model5+=[nn.ReLU(True),]
70
- model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
71
- model5+=[nn.ReLU(True),]
72
- model5+=[norm_layer(512),]
73
-
74
- model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
75
- model6+=[nn.ReLU(True),]
76
- model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
77
- model6+=[nn.ReLU(True),]
78
- model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
79
- model6+=[nn.ReLU(True),]
80
- model6+=[norm_layer(512),]
81
-
82
- model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
83
- model7+=[nn.ReLU(True),]
84
- model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
85
- model7+=[nn.ReLU(True),]
86
- model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
87
- model7+=[nn.ReLU(True),]
88
- model7+=[norm_layer(512),]
89
-
90
- model8=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True),]
91
- model8+=[nn.ReLU(True),]
92
- model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
93
- model8+=[nn.ReLU(True),]
94
- model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
95
- model8+=[nn.ReLU(True),]
96
-
97
- model8+=[nn.Conv2d(256, 313, kernel_size=1, stride=1, padding=0, bias=True),]
98
-
99
- self.model1 = nn.Sequential(*model1)
100
- self.model2 = nn.Sequential(*model2)
101
- self.model3 = nn.Sequential(*model3)
102
- self.model4 = nn.Sequential(*model4)
103
- self.model5 = nn.Sequential(*model5)
104
- self.model6 = nn.Sequential(*model6)
105
- self.model7 = nn.Sequential(*model7)
106
- self.model8 = nn.Sequential(*model8)
107
-
108
- self.softmax = nn.Softmax(dim=1)
109
- self.model_out = nn.Conv2d(313, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=False)
110
- self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear')
111
-
112
- def forward(self, input_l):
113
- conv1_2 = self.model1(self.normalize_l(input_l))
114
- conv2_2 = self.model2(conv1_2)
115
- conv3_3 = self.model3(conv2_2)
116
- conv4_3 = self.model4(conv3_3)
117
- conv5_3 = self.model5(conv4_3)
118
- conv6_3 = self.model6(conv5_3)
119
- conv7_3 = self.model7(conv6_3)
120
- conv8_3 = self.model8(conv7_3)
121
- out_reg = self.model_out(self.softmax(conv8_3))
122
-
123
- return self.unnormalize_ab(self.upsample4(out_reg))
124
-
125
- def eccv16(pretrained=True):
126
- model = ECCVGenerator()
127
- if(pretrained):
128
- import torch.utils.model_zoo as model_zoo
129
- model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/colorization_release_v2-9b330a0b.pth',map_location='cpu',check_hash=True))
130
- return model
131
-
132
- # ------------------ siggraph17 ---------------------
133
-
134
-
135
- class SIGGRAPHGenerator(BaseColor):
136
- def __init__(self, norm_layer=nn.BatchNorm2d, classes=529):
137
- super(SIGGRAPHGenerator, self).__init__()
138
-
139
- # Conv1
140
- model1=[nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=True),]
141
- model1+=[nn.ReLU(True),]
142
- model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True),]
143
- model1+=[nn.ReLU(True),]
144
- model1+=[norm_layer(64),]
145
- # add a subsampling operation
146
-
147
- # Conv2
148
- model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
149
- model2+=[nn.ReLU(True),]
150
- model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
151
- model2+=[nn.ReLU(True),]
152
- model2+=[norm_layer(128),]
153
- # add a subsampling layer operation
154
-
155
- # Conv3
156
- model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),]
157
- model3+=[nn.ReLU(True),]
158
- model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
159
- model3+=[nn.ReLU(True),]
160
- model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
161
- model3+=[nn.ReLU(True),]
162
- model3+=[norm_layer(256),]
163
- # add a subsampling layer operation
164
-
165
- # Conv4
166
- model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),]
167
- model4+=[nn.ReLU(True),]
168
- model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
169
- model4+=[nn.ReLU(True),]
170
- model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
171
- model4+=[nn.ReLU(True),]
172
- model4+=[norm_layer(512),]
173
-
174
- # Conv5
175
- model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
176
- model5+=[nn.ReLU(True),]
177
- model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
178
- model5+=[nn.ReLU(True),]
179
- model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
180
- model5+=[nn.ReLU(True),]
181
- model5+=[norm_layer(512),]
182
-
183
- # Conv6
184
- model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
185
- model6+=[nn.ReLU(True),]
186
- model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
187
- model6+=[nn.ReLU(True),]
188
- model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
189
- model6+=[nn.ReLU(True),]
190
- model6+=[norm_layer(512),]
191
-
192
- # Conv7
193
- model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
194
- model7+=[nn.ReLU(True),]
195
- model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
196
- model7+=[nn.ReLU(True),]
197
- model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
198
- model7+=[nn.ReLU(True),]
199
- model7+=[norm_layer(512),]
200
-
201
- # Conv7
202
- model8up=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True)]
203
- model3short8=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
204
-
205
- model8=[nn.ReLU(True),]
206
- model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
207
- model8+=[nn.ReLU(True),]
208
- model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
209
- model8+=[nn.ReLU(True),]
210
- model8+=[norm_layer(256),]
211
-
212
- # Conv9
213
- model9up=[nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True),]
214
- model2short9=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
215
- # add the two feature maps above
216
-
217
- model9=[nn.ReLU(True),]
218
- model9+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
219
- model9+=[nn.ReLU(True),]
220
- model9+=[norm_layer(128),]
221
-
222
- # Conv10
223
- model10up=[nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True),]
224
- model1short10=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
225
- # add the two feature maps above
226
-
227
- model10=[nn.ReLU(True),]
228
- model10+=[nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1, bias=True),]
229
- model10+=[nn.LeakyReLU(negative_slope=.2),]
230
-
231
- # classification output
232
- model_class=[nn.Conv2d(256, classes, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),]
233
-
234
- # regression output
235
- model_out=[nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),]
236
- model_out+=[nn.Tanh()]
237
-
238
- self.model1 = nn.Sequential(*model1)
239
- self.model2 = nn.Sequential(*model2)
240
- self.model3 = nn.Sequential(*model3)
241
- self.model4 = nn.Sequential(*model4)
242
- self.model5 = nn.Sequential(*model5)
243
- self.model6 = nn.Sequential(*model6)
244
- self.model7 = nn.Sequential(*model7)
245
- self.model8up = nn.Sequential(*model8up)
246
- self.model8 = nn.Sequential(*model8)
247
- self.model9up = nn.Sequential(*model9up)
248
- self.model9 = nn.Sequential(*model9)
249
- self.model10up = nn.Sequential(*model10up)
250
- self.model10 = nn.Sequential(*model10)
251
- self.model3short8 = nn.Sequential(*model3short8)
252
- self.model2short9 = nn.Sequential(*model2short9)
253
- self.model1short10 = nn.Sequential(*model1short10)
254
-
255
- self.model_class = nn.Sequential(*model_class)
256
- self.model_out = nn.Sequential(*model_out)
257
-
258
- self.upsample4 = nn.Sequential(*[nn.Upsample(scale_factor=4, mode='bilinear'),])
259
- self.softmax = nn.Sequential(*[nn.Softmax(dim=1),])
260
-
261
- def forward(self, input_A, input_B=None, mask_B=None):
262
- if(input_B is None):
263
- input_B = torch.cat((input_A*0, input_A*0), dim=1)
264
- if(mask_B is None):
265
- mask_B = input_A*0
266
-
267
- conv1_2 = self.model1(torch.cat((self.normalize_l(input_A),self.normalize_ab(input_B),mask_B),dim=1))
268
- conv2_2 = self.model2(conv1_2[:,:,::2,::2])
269
- conv3_3 = self.model3(conv2_2[:,:,::2,::2])
270
- conv4_3 = self.model4(conv3_3[:,:,::2,::2])
271
- conv5_3 = self.model5(conv4_3)
272
- conv6_3 = self.model6(conv5_3)
273
- conv7_3 = self.model7(conv6_3)
274
-
275
- conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3)
276
- conv8_3 = self.model8(conv8_up)
277
- conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
278
- conv9_3 = self.model9(conv9_up)
279
- conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
280
- conv10_2 = self.model10(conv10_up)
281
- out_reg = self.model_out(conv10_2)
282
-
283
- conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
284
- conv9_3 = self.model9(conv9_up)
285
- conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
286
- conv10_2 = self.model10(conv10_up)
287
- out_reg = self.model_out(conv10_2)
288
-
289
- return self.unnormalize_ab(out_reg)
290
-
291
- def siggraph17(pretrained=True):
292
- model = SIGGRAPHGenerator()
293
- if(pretrained):
294
- import torch.utils.model_zoo as model_zoo
295
- model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/siggraph17-df00044c.pth',map_location='cpu',check_hash=True))
296
- return model
297
-
298
- # ------------------ utils ---------------------
299
-
300
-
301
- from PIL import Image
302
- import numpy as np
303
- from skimage import color
304
- import torch.nn.functional as F
305
-
306
- def load_img(img_path):
307
- out_np = np.asarray(Image.open(img_path))
308
- if(out_np.ndim==2):
309
- out_np = np.tile(out_np[:,:,None],3)
310
- return out_np
311
-
312
- def resize_img(img, HW=(256,256), resample=3):
313
- return np.asarray(Image.fromarray(img).resize((HW[1],HW[0]), resample=resample))
314
-
315
- def preprocess_img(img_rgb_orig, HW=(256,256), resample=3):
316
- # return original size L and resized L as torch Tensors
317
- img_rgb_rs = resize_img(img_rgb_orig, HW=HW, resample=resample)
318
-
319
- img_lab_orig = color.rgb2lab(img_rgb_orig)
320
- img_lab_rs = color.rgb2lab(img_rgb_rs)
321
-
322
- img_l_orig = img_lab_orig[:,:,0]
323
- img_l_rs = img_lab_rs[:,:,0]
324
-
325
- tens_orig_l = torch.Tensor(img_l_orig)[None,None,:,:]
326
- tens_rs_l = torch.Tensor(img_l_rs)[None,None,:,:]
327
-
328
- return (tens_orig_l, tens_rs_l)
329
-
330
- def postprocess_tens(tens_orig_l, out_ab, mode='bilinear'):
331
- # tens_orig_l 1 x 1 x H_orig x W_orig
332
- # out_ab 1 x 2 x H x W
333
-
334
- HW_orig = tens_orig_l.shape[2:]
335
- HW = out_ab.shape[2:]
336
-
337
- # call resize function if needed
338
- if(HW_orig[0]!=HW[0] or HW_orig[1]!=HW[1]):
339
- out_ab_orig = F.interpolate(out_ab, size=HW_orig, mode='bilinear')
340
- else:
341
- out_ab_orig = out_ab
342
-
343
- out_lab_orig = torch.cat((tens_orig_l, out_ab_orig), dim=1)
344
- return color.lab2rgb(out_lab_orig.data.cpu().numpy()[0,...].transpose((1,2,0)))
345
-
346
-
347
- # parser = argparse.ArgumentParser()
348
- # parser.add_argument('-i','--img_path', type=str, default='imgs/test.jpg')
349
- # # parser.add_argument('--use_gpu', action='store_true', help='whether to use GPU')
350
- # parser.add_argument('-o','--save_prefix', type=str, default='saved', help='will save into this file with {eccv16.png, siggraph17.png} suffixes')
351
- # opt = parser.parse_args()
352
-
353
- colorizer_eccv16 = eccv16(pretrained=True).eval()
354
- colorizer_siggraph17 = siggraph17(pretrained=True).eval()
355
-
356
- # if(opt.use_gpu):
357
- # colorizer_eccv16.cuda()
358
- # colorizer_siggraph17.cuda()
359
-
360
- input_image = st.file_uploader("Upload Image : ", type=["jpg", "jpeg", "png"])
361
-
362
- if input_image is not None:
363
- img = load_img(input_image)
364
- (tens_l_orig, tens_l_rs) = preprocess_img(img, HW=(256,256))
365
-
366
- img_bw = postprocess_tens(tens_l_orig, torch.cat((0*tens_l_orig,0*tens_l_orig),dim=1))
367
- out_img_eccv16 = postprocess_tens(tens_l_orig, colorizer_eccv16(tens_l_rs).cpu())
368
- out_img_siggraph17 = postprocess_tens(tens_l_orig, colorizer_siggraph17(tens_l_rs).cpu())
369
-
370
- plt.imsave(f'eccv16.png{input_image.name}', out_img_eccv16)
371
- plt.imsave(f'siggraph17.png{input_image.name}', out_img_siggraph17)
372
-
373
- plt.figure(figsize=(12,8))
374
- plt.subplot(2,2,1)
375
- plt.imshow(img)
376
- plt.title('Original')
377
- plt.axis('off')
378
-
379
- plt.subplot(2,2,2)
380
- plt.imshow(img_bw)
381
- plt.title('Input')
382
- plt.axis('off')
383
-
384
- plt.subplot(2,2,3)
385
- plt.imshow(out_img_eccv16)
386
- plt.title('Output (ECCV 16)')
387
- plt.axis('off')
388
-
389
- plt.subplot(2,2,4)
390
- plt.imshow(out_img_siggraph17)
391
- plt.title('Output (SIGGRAPH 17)')
392
- plt.axis('off')
393
- plt.show()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import streamlit as st
3
+
4
+ # -------------------- base color ------------------
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+ class BaseColor(nn.Module):
10
+ def __init__(self):
11
+ super(BaseColor, self).__init__()
12
+
13
+ self.l_cent = 50.
14
+ self.l_norm = 100.
15
+ self.ab_norm = 110.
16
+
17
+ def normalize_l(self, in_l):
18
+ return (in_l-self.l_cent)/self.l_norm
19
+
20
+ def unnormalize_l(self, in_l):
21
+ return in_l*self.l_norm + self.l_cent
22
+
23
+ def normalize_ab(self, in_ab):
24
+ return in_ab/self.ab_norm
25
+
26
+ def unnormalize_ab(self, in_ab):
27
+ return in_ab*self.ab_norm
28
+
29
+ # ------------------ eccv16 ---------------------
30
+
31
+ import numpy as np
32
+
33
+
34
+ class ECCVGenerator(BaseColor):
35
+ def __init__(self, norm_layer=nn.BatchNorm2d):
36
+ super(ECCVGenerator, self).__init__()
37
+
38
+ model1=[nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=True),]
39
+ model1+=[nn.ReLU(True),]
40
+ model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=True),]
41
+ model1+=[nn.ReLU(True),]
42
+ model1+=[norm_layer(64),]
43
+
44
+ model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
45
+ model2+=[nn.ReLU(True),]
46
+ model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=True),]
47
+ model2+=[nn.ReLU(True),]
48
+ model2+=[norm_layer(128),]
49
+
50
+ model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),]
51
+ model3+=[nn.ReLU(True),]
52
+ model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
53
+ model3+=[nn.ReLU(True),]
54
+ model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=True),]
55
+ model3+=[nn.ReLU(True),]
56
+ model3+=[norm_layer(256),]
57
+
58
+ model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),]
59
+ model4+=[nn.ReLU(True),]
60
+ model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
61
+ model4+=[nn.ReLU(True),]
62
+ model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
63
+ model4+=[nn.ReLU(True),]
64
+ model4+=[norm_layer(512),]
65
+
66
+ model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
67
+ model5+=[nn.ReLU(True),]
68
+ model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
69
+ model5+=[nn.ReLU(True),]
70
+ model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
71
+ model5+=[nn.ReLU(True),]
72
+ model5+=[norm_layer(512),]
73
+
74
+ model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
75
+ model6+=[nn.ReLU(True),]
76
+ model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
77
+ model6+=[nn.ReLU(True),]
78
+ model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
79
+ model6+=[nn.ReLU(True),]
80
+ model6+=[norm_layer(512),]
81
+
82
+ model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
83
+ model7+=[nn.ReLU(True),]
84
+ model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
85
+ model7+=[nn.ReLU(True),]
86
+ model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
87
+ model7+=[nn.ReLU(True),]
88
+ model7+=[norm_layer(512),]
89
+
90
+ model8=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True),]
91
+ model8+=[nn.ReLU(True),]
92
+ model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
93
+ model8+=[nn.ReLU(True),]
94
+ model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
95
+ model8+=[nn.ReLU(True),]
96
+
97
+ model8+=[nn.Conv2d(256, 313, kernel_size=1, stride=1, padding=0, bias=True),]
98
+
99
+ self.model1 = nn.Sequential(*model1)
100
+ self.model2 = nn.Sequential(*model2)
101
+ self.model3 = nn.Sequential(*model3)
102
+ self.model4 = nn.Sequential(*model4)
103
+ self.model5 = nn.Sequential(*model5)
104
+ self.model6 = nn.Sequential(*model6)
105
+ self.model7 = nn.Sequential(*model7)
106
+ self.model8 = nn.Sequential(*model8)
107
+
108
+ self.softmax = nn.Softmax(dim=1)
109
+ self.model_out = nn.Conv2d(313, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=False)
110
+ self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear')
111
+
112
+ def forward(self, input_l):
113
+ conv1_2 = self.model1(self.normalize_l(input_l))
114
+ conv2_2 = self.model2(conv1_2)
115
+ conv3_3 = self.model3(conv2_2)
116
+ conv4_3 = self.model4(conv3_3)
117
+ conv5_3 = self.model5(conv4_3)
118
+ conv6_3 = self.model6(conv5_3)
119
+ conv7_3 = self.model7(conv6_3)
120
+ conv8_3 = self.model8(conv7_3)
121
+ out_reg = self.model_out(self.softmax(conv8_3))
122
+
123
+ return self.unnormalize_ab(self.upsample4(out_reg))
124
+
125
+ def eccv16(pretrained=True):
126
+ model = ECCVGenerator()
127
+ if(pretrained):
128
+ import torch.utils.model_zoo as model_zoo
129
+ model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/colorization_release_v2-9b330a0b.pth',map_location='cpu',check_hash=True))
130
+ return model
131
+
132
+ # ------------------ siggraph17 ---------------------
133
+
134
+
135
+ class SIGGRAPHGenerator(BaseColor):
136
+ def __init__(self, norm_layer=nn.BatchNorm2d, classes=529):
137
+ super(SIGGRAPHGenerator, self).__init__()
138
+
139
+ # Conv1
140
+ model1=[nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=True),]
141
+ model1+=[nn.ReLU(True),]
142
+ model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True),]
143
+ model1+=[nn.ReLU(True),]
144
+ model1+=[norm_layer(64),]
145
+ # add a subsampling operation
146
+
147
+ # Conv2
148
+ model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
149
+ model2+=[nn.ReLU(True),]
150
+ model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
151
+ model2+=[nn.ReLU(True),]
152
+ model2+=[norm_layer(128),]
153
+ # add a subsampling layer operation
154
+
155
+ # Conv3
156
+ model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),]
157
+ model3+=[nn.ReLU(True),]
158
+ model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
159
+ model3+=[nn.ReLU(True),]
160
+ model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
161
+ model3+=[nn.ReLU(True),]
162
+ model3+=[norm_layer(256),]
163
+ # add a subsampling layer operation
164
+
165
+ # Conv4
166
+ model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),]
167
+ model4+=[nn.ReLU(True),]
168
+ model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
169
+ model4+=[nn.ReLU(True),]
170
+ model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
171
+ model4+=[nn.ReLU(True),]
172
+ model4+=[norm_layer(512),]
173
+
174
+ # Conv5
175
+ model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
176
+ model5+=[nn.ReLU(True),]
177
+ model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
178
+ model5+=[nn.ReLU(True),]
179
+ model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
180
+ model5+=[nn.ReLU(True),]
181
+ model5+=[norm_layer(512),]
182
+
183
+ # Conv6
184
+ model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
185
+ model6+=[nn.ReLU(True),]
186
+ model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
187
+ model6+=[nn.ReLU(True),]
188
+ model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
189
+ model6+=[nn.ReLU(True),]
190
+ model6+=[norm_layer(512),]
191
+
192
+ # Conv7
193
+ model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
194
+ model7+=[nn.ReLU(True),]
195
+ model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
196
+ model7+=[nn.ReLU(True),]
197
+ model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
198
+ model7+=[nn.ReLU(True),]
199
+ model7+=[norm_layer(512),]
200
+
201
+ # Conv7
202
+ model8up=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True)]
203
+ model3short8=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
204
+
205
+ model8=[nn.ReLU(True),]
206
+ model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
207
+ model8+=[nn.ReLU(True),]
208
+ model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
209
+ model8+=[nn.ReLU(True),]
210
+ model8+=[norm_layer(256),]
211
+
212
+ # Conv9
213
+ model9up=[nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True),]
214
+ model2short9=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
215
+ # add the two feature maps above
216
+
217
+ model9=[nn.ReLU(True),]
218
+ model9+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
219
+ model9+=[nn.ReLU(True),]
220
+ model9+=[norm_layer(128),]
221
+
222
+ # Conv10
223
+ model10up=[nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True),]
224
+ model1short10=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
225
+ # add the two feature maps above
226
+
227
+ model10=[nn.ReLU(True),]
228
+ model10+=[nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1, bias=True),]
229
+ model10+=[nn.LeakyReLU(negative_slope=.2),]
230
+
231
+ # classification output
232
+ model_class=[nn.Conv2d(256, classes, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),]
233
+
234
+ # regression output
235
+ model_out=[nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),]
236
+ model_out+=[nn.Tanh()]
237
+
238
+ self.model1 = nn.Sequential(*model1)
239
+ self.model2 = nn.Sequential(*model2)
240
+ self.model3 = nn.Sequential(*model3)
241
+ self.model4 = nn.Sequential(*model4)
242
+ self.model5 = nn.Sequential(*model5)
243
+ self.model6 = nn.Sequential(*model6)
244
+ self.model7 = nn.Sequential(*model7)
245
+ self.model8up = nn.Sequential(*model8up)
246
+ self.model8 = nn.Sequential(*model8)
247
+ self.model9up = nn.Sequential(*model9up)
248
+ self.model9 = nn.Sequential(*model9)
249
+ self.model10up = nn.Sequential(*model10up)
250
+ self.model10 = nn.Sequential(*model10)
251
+ self.model3short8 = nn.Sequential(*model3short8)
252
+ self.model2short9 = nn.Sequential(*model2short9)
253
+ self.model1short10 = nn.Sequential(*model1short10)
254
+
255
+ self.model_class = nn.Sequential(*model_class)
256
+ self.model_out = nn.Sequential(*model_out)
257
+
258
+ self.upsample4 = nn.Sequential(*[nn.Upsample(scale_factor=4, mode='bilinear'),])
259
+ self.softmax = nn.Sequential(*[nn.Softmax(dim=1),])
260
+
261
+ def forward(self, input_A, input_B=None, mask_B=None):
262
+ if(input_B is None):
263
+ input_B = torch.cat((input_A*0, input_A*0), dim=1)
264
+ if(mask_B is None):
265
+ mask_B = input_A*0
266
+
267
+ conv1_2 = self.model1(torch.cat((self.normalize_l(input_A),self.normalize_ab(input_B),mask_B),dim=1))
268
+ conv2_2 = self.model2(conv1_2[:,:,::2,::2])
269
+ conv3_3 = self.model3(conv2_2[:,:,::2,::2])
270
+ conv4_3 = self.model4(conv3_3[:,:,::2,::2])
271
+ conv5_3 = self.model5(conv4_3)
272
+ conv6_3 = self.model6(conv5_3)
273
+ conv7_3 = self.model7(conv6_3)
274
+
275
+ conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3)
276
+ conv8_3 = self.model8(conv8_up)
277
+ conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
278
+ conv9_3 = self.model9(conv9_up)
279
+ conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
280
+ conv10_2 = self.model10(conv10_up)
281
+ out_reg = self.model_out(conv10_2)
282
+
283
+ conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
284
+ conv9_3 = self.model9(conv9_up)
285
+ conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
286
+ conv10_2 = self.model10(conv10_up)
287
+ out_reg = self.model_out(conv10_2)
288
+
289
+ return self.unnormalize_ab(out_reg)
290
+
291
+ def siggraph17(pretrained=True):
292
+ model = SIGGRAPHGenerator()
293
+ if(pretrained):
294
+ import torch.utils.model_zoo as model_zoo
295
+ model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/siggraph17-df00044c.pth',map_location='cpu',check_hash=True))
296
+ return model
297
+
298
+ # ------------------ utils ---------------------
299
+
300
+
301
+ from PIL import Image
302
+ import numpy as np
303
+ from skimage import color
304
+ import torch.nn.functional as F
305
+
306
+ def load_img(img_path):
307
+ out_np = np.asarray(Image.open(img_path))
308
+ if(out_np.ndim==2):
309
+ out_np = np.tile(out_np[:,:,None],3)
310
+ return out_np
311
+
312
+ def resize_img(img, HW=(256,256), resample=3):
313
+ return np.asarray(Image.fromarray(img).resize((HW[1],HW[0]), resample=resample))
314
+
315
+ def preprocess_img(img_rgb_orig, HW=(256,256), resample=3):
316
+ # return original size L and resized L as torch Tensors
317
+ img_rgb_rs = resize_img(img_rgb_orig, HW=HW, resample=resample)
318
+
319
+ img_lab_orig = color.rgb2lab(img_rgb_orig)
320
+ img_lab_rs = color.rgb2lab(img_rgb_rs)
321
+
322
+ img_l_orig = img_lab_orig[:,:,0]
323
+ img_l_rs = img_lab_rs[:,:,0]
324
+
325
+ tens_orig_l = torch.Tensor(img_l_orig)[None,None,:,:]
326
+ tens_rs_l = torch.Tensor(img_l_rs)[None,None,:,:]
327
+
328
+ return (tens_orig_l, tens_rs_l)
329
+
330
+ def postprocess_tens(tens_orig_l, out_ab, mode='bilinear'):
331
+ # tens_orig_l 1 x 1 x H_orig x W_orig
332
+ # out_ab 1 x 2 x H x W
333
+
334
+ HW_orig = tens_orig_l.shape[2:]
335
+ HW = out_ab.shape[2:]
336
+
337
+ # call resize function if needed
338
+ if(HW_orig[0]!=HW[0] or HW_orig[1]!=HW[1]):
339
+ out_ab_orig = F.interpolate(out_ab, size=HW_orig, mode='bilinear')
340
+ else:
341
+ out_ab_orig = out_ab
342
+
343
+ out_lab_orig = torch.cat((tens_orig_l, out_ab_orig), dim=1)
344
+ return color.lab2rgb(out_lab_orig.data.cpu().numpy()[0,...].transpose((1,2,0)))
345
+
346
+
347
+ # parser = argparse.ArgumentParser()
348
+ # parser.add_argument('-i','--img_path', type=str, default='imgs/test.jpg')
349
+ # # parser.add_argument('--use_gpu', action='store_true', help='whether to use GPU')
350
+ # parser.add_argument('-o','--save_prefix', type=str, default='saved', help='will save into this file with {eccv16.png, siggraph17.png} suffixes')
351
+ # opt = parser.parse_args()
352
+
353
+ colorizer_eccv16 = eccv16(pretrained=True).eval()
354
+ colorizer_siggraph17 = siggraph17(pretrained=True).eval()
355
+
356
+ # if(opt.use_gpu):
357
+ # colorizer_eccv16.cuda()
358
+ # colorizer_siggraph17.cuda()
359
+
360
+ input_image = st.file_uploader("Upload Image : ", type=["jpg", "jpeg", "png"])
361
+
362
+ if input_image is not None:
363
+ img = load_img(input_image)
364
+ (tens_l_orig, tens_l_rs) = preprocess_img(img, HW=(256,256))
365
+
366
+ img_bw = postprocess_tens(tens_l_orig, torch.cat((0*tens_l_orig,0*tens_l_orig),dim=1))
367
+ out_img_eccv16 = postprocess_tens(tens_l_orig, colorizer_eccv16(tens_l_rs).cpu())
368
+ out_img_siggraph17 = postprocess_tens(tens_l_orig, colorizer_siggraph17(tens_l_rs).cpu())
369
+
370
+ plt.imsave(f'eccv16.png{input_image.name}', out_img_eccv16)
371
+ plt.imsave(f'siggraph17.png{input_image.name}', out_img_siggraph17)
372
+
373
+ eccv16_path = f'eccv16_{input_image.name}'
374
+ siggraph17_path = f'siggraph17_{input_image.name}'
375
+
376
+ plt.imsave(eccv16_path, out_img_eccv16)
377
+ plt.imsave(siggraph17_path, out_img_siggraph17)
378
+
379
+ # Display images using Streamlit
380
+ st.image([img, img_bw, out_img_eccv16, out_img_siggraph17], caption=['Original', 'Input', 'Output (ECCV 16)', 'Output (SIGGRAPH 17)'],
381
+ width=256)
382
+
383
+ # Optionally, you can also display the saved images
384
+ st.markdown("### Saved Images:")
385
+ st.image([eccv16_path, siggraph17_path], width=256)
386
+
387
+ # plt.figure(figsize=(12,8))
388
+ # plt.subplot(2,2,1)
389
+ # plt.imshow(img)
390
+ # plt.title('Original')
391
+ # plt.axis('off')
392
+
393
+
394
+ # plt.subplot(2,2,2)
395
+ # plt.imshow(img_bw)
396
+ # plt.title('Input')
397
+ # plt.axis('off')
398
+
399
+ # plt.subplot(2,2,3)
400
+ # plt.imshow(out_img_eccv16)
401
+ # plt.title('Output (ECCV 16)')
402
+ # plt.axis('off')
403
+
404
+ # plt.subplot(2,2,4)
405
+ # plt.imshow(out_img_siggraph17)
406
+ # plt.title('Output (SIGGRAPH 17)')
407
+ # plt.axis('off')
408
+ # plt.show()