Raid41 committed on
Commit
0654ee4
·
1 Parent(s): a3edef0

Upload 33 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ figures/color2.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ *.ipynb
2
+ *.pth
3
+ *.zip
4
+
5
+ __pycache__/
6
+ temp_colorization/
7
+
8
+ static/temp_images/
__pycache__/colorizator.cpython-39.pyc ADDED
Binary file (2.7 kB). View file
 
colorizator.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
from torchvision.transforms import ToTensor
import numpy as np

from networks.models import Colorizer
from denoising.denoiser import FFDNetDenoiser
from utils.utils import resize_pad

10
class MangaColorizator:
    """High-level manga colorization pipeline: denoise -> resize/pad -> colorize.

    Usage: call set_image(), optionally update_hint() for user color hints,
    then colorize() to obtain the colorized result.
    """

    def __init__(self, device, generator_path = 'networks/generator.zip', extractor_path = 'networks/extractor.pth'):
        # Generator weights are loaded directly onto the target device and the
        # network is switched to eval mode for inference.
        # NOTE(review): extractor_path is accepted but never used in this
        # class -- presumably the extractor is loaded inside Colorizer; confirm.
        self.colorizer = Colorizer().to(device)
        self.colorizer.generator.load_state_dict(torch.load(generator_path, map_location = device))
        self.colorizer = self.colorizer.eval()

        self.denoiser = FFDNetDenoiser(device)

        # Per-image state: populated by set_image(), read by colorize().
        self.current_image = None
        self.current_hint = None
        self.current_pad = None

        self.device = device

    def set_image(self, image, size = 576, apply_denoise = True, denoise_sigma = 25, transform = ToTensor()):
        """Prepare an input image for colorization.

        Args:
            image: input image array (format consumed by the denoiser and resize_pad).
            size: target size handed to resize_pad; must be divisible by 32.
            apply_denoise: run FFDNet denoising before resizing.
            denoise_sigma: noise level for the denoiser (0-255 scale).
            transform: callable converting the array to a CHW float tensor.

        Raises:
            RuntimeError: if size is not a multiple of 32.
        """
        if (size % 32 != 0):
            raise RuntimeError("size is not divisible by 32")

        if apply_denoise:
            image = self.denoiser.get_denoised_image(image, sigma = denoise_sigma)

        # resize_pad returns the resized image plus the padding amounts, which
        # colorize() crops back off (bottom rows / right columns).
        image, self.current_pad = resize_pad(image, size)
        self.current_image = transform(image).unsqueeze(0).to(self.device)
        # Hint tensor has 4 channels (3 color + 1 mask); zeros mean "no hints".
        self.current_hint = torch.zeros(1, 4, self.current_image.shape[2], self.current_image.shape[3]).float().to(self.device)

    def update_hint(self, hint, mask):
        '''
        Set user color hints for the currently loaded image.

        Args:
            hint: numpy.ndarray with shape (self.current_image.shape[2], self.current_image.shape[3], 3)
            mask: numpy.ndarray with shape (self.current_image.shape[2], self.current_image.shape[3])
        '''

        # Integer hints are treated as 8-bit and brought into [0, 1] first.
        if issubclass(hint.dtype.type, np.integer):
            hint = hint.astype('float32') / 255

        # Map [0, 1] -> [-1, 1]; colorize() applies the inverse (x*0.5 + 0.5).
        hint = (hint - 0.5) / 0.5
        hint = torch.FloatTensor(hint).permute(2, 0, 1)
        mask = torch.FloatTensor(np.expand_dims(mask, 0))

        # Zero hint colors outside the mask and append the mask as 4th channel.
        self.current_hint = torch.cat([hint * mask, mask], 0).unsqueeze(0).to(self.device)

    def colorize(self):
        """Run the generator on the prepared image.

        Returns:
            HxWxC float numpy array in [0, 1] with the padding added by
            set_image() cropped away.
        """
        with torch.no_grad():
            fake_color, _ = self.colorizer(torch.cat([self.current_image, self.current_hint], 1))
            fake_color = fake_color.detach()

        # Map the network's [-1, 1] output back to [0, 1].
        result = fake_color[0].detach().cpu().permute(1, 2, 0) * 0.5 + 0.5

        # Crop the bottom/right padding recorded by resize_pad (0 = no pad).
        if self.current_pad[0] != 0:
            result = result[:-self.current_pad[0]]
        if self.current_pad[1] != 0:
            result = result[:, :-self.current_pad[1]]

        return result.numpy()
denoising/__pycache__/denoiser.cpython-39.pyc ADDED
Binary file (3.48 kB). View file
 
denoising/__pycache__/functions.cpython-39.pyc ADDED
Binary file (3.6 kB). View file
 
denoising/__pycache__/models.cpython-39.pyc ADDED
Binary file (3.49 kB). View file
 
denoising/__pycache__/utils.cpython-39.pyc ADDED
Binary file (2.39 kB). View file
 
denoising/denoiser.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Denoise an image with the FFDNet denoising method
3
+
4
+ Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
5
+
6
+ This program is free software: you can use, modify and/or
7
+ redistribute it under the terms of the GNU General Public
8
+ License as published by the Free Software Foundation, either
9
+ version 3 of the License, or (at your option) any later
10
+ version. You should have received a copy of this license along
11
+ this program. If not, see <http://www.gnu.org/licenses/>.
12
+ """
13
+ import os
14
+ import argparse
15
+ import time
16
+
17
+
18
+ import numpy as np
19
+ import cv2
20
+ import torch
21
+ import torch.nn as nn
22
+ from torch.autograd import Variable
23
+ from .models import FFDNet
24
+ from .utils import normalize, variable_to_cv2_image, remove_dataparallel_wrapper, is_rgb
25
+
26
class FFDNetDenoiser:
    """Inference wrapper around the FFDNet denoising network.

    Loads pretrained RGB or grayscale weights and exposes
    get_denoised_image() for single-image denoising.
    """

    def __init__(self, _device, _sigma = 25, _weights_dir = 'denoising/models/', _in_ch = 3):
        # Noise level is stored normalized to [0, 1]; callers pass 0-255 scale.
        self.sigma = _sigma / 255
        self.weights_dir = _weights_dir
        self.channels = _in_ch
        self.device = _device

        self.model = FFDNet(num_input_channels = _in_ch)
        self.load_weights()
        self.model.eval()


    def load_weights(self):
        """Load pretrained weights for the configured channel count.

        The checkpoints apparently come from a DataParallel model (keys carry
        a 'module.' prefix): on CUDA the model is wrapped in nn.DataParallel
        so the keys match; on CPU the prefix is stripped instead.
        """
        weights_name = 'net_rgb.pth' if self.channels == 3 else 'net_gray.pth'
        weights_path = os.path.join(self.weights_dir, weights_name)
        if self.device == 'cuda':
            state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
            device_ids = [0]
            self.model = nn.DataParallel(self.model, device_ids=device_ids).cuda()
        else:
            state_dict = torch.load(weights_path, map_location='cpu')
            # CPU mode: remove the DataParallel wrapper
            state_dict = remove_dataparallel_wrapper(state_dict)
        self.model.load_state_dict(state_dict)

    def get_denoised_image(self, imorig, sigma = None):
        """Denoise one image with FFDNet.

        Args:
            imorig: numpy image, HxW, HxWx1, HxWx3 or HxWx4 (alpha dropped);
                values in [0, 1] or [0, 255] (scale is auto-detected).
            sigma: optional noise level on the 0-255 scale; defaults to the
                value given at construction.

        Returns:
            8-bit image array as produced by variable_to_cv2_image.
        """

        # Sigma arrives on the 0-255 scale; the network expects [0, 1].
        if sigma is not None:
            cur_sigma = sigma / 255
        else:
            cur_sigma = self.sigma

        # Promote grayscale (HxW or HxWx1) to three channels.
        if len(imorig.shape) < 3 or imorig.shape[2] == 1:
            imorig = np.repeat(np.expand_dims(imorig, 2), 3, 2)

        # Drop a possible alpha channel.
        imorig = imorig[..., :3]

        # Cap the longest side at 1200 px to bound memory and run time.
        if (max(imorig.shape[0], imorig.shape[1]) > 1200):
            ratio = max(imorig.shape[0], imorig.shape[1]) / 1200
            imorig = cv2.resize(imorig, (int(imorig.shape[1] / ratio), int(imorig.shape[0] / ratio)), interpolation = cv2.INTER_AREA)

        # HWC -> CHW for torch.
        imorig = imorig.transpose(2, 0, 1)

        # Heuristic: values above 1.2 mean the image is on the 0-255 scale.
        if (imorig.max() > 1.2):
            imorig = normalize(imorig)
        imorig = np.expand_dims(imorig, 0)

        # Handle odd sizes: FFDNet downsamples 2x2, so spatial dims must be
        # even; duplicate the last row/column and crop it back afterwards.
        expanded_h = False
        expanded_w = False
        sh_im = imorig.shape
        if sh_im[2]%2 == 1:
            expanded_h = True
            imorig = np.concatenate((imorig, imorig[:, :, -1, :][:, :, np.newaxis, :]), axis=2)

        if sh_im[3]%2 == 1:
            expanded_w = True
            imorig = np.concatenate((imorig, imorig[:, :, :, -1][:, :, :, np.newaxis]), axis=3)


        imorig = torch.Tensor(imorig)


        # Sets data type according to CPU or GPU modes
        if self.device == 'cuda':
            dtype = torch.cuda.FloatTensor
        else:
            dtype = torch.FloatTensor

        imnoisy = imorig.clone()


        with torch.no_grad():
            imorig, imnoisy = imorig.type(dtype), imnoisy.type(dtype)
            nsigma = torch.FloatTensor([cur_sigma]).type(dtype)


            # Estimate noise and subtract it to the input image
            im_noise_estim = self.model(imnoisy, nsigma)
            outim = torch.clamp(imnoisy-im_noise_estim, 0., 1.)

        # Undo the even-size padding added above.
        if expanded_h:
            imorig = imorig[:, :, :-1, :]
            outim = outim[:, :, :-1, :]
            imnoisy = imnoisy[:, :, :-1, :]

        if expanded_w:
            imorig = imorig[:, :, :, :-1]
            outim = outim[:, :, :, :-1]
            imnoisy = imnoisy[:, :, :, :-1]

        return variable_to_cv2_image(outim)
denoising/functions.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Functions implementing custom NN layers
3
+
4
+ Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
5
+
6
+ This program is free software: you can use, modify and/or
7
+ redistribute it under the terms of the GNU General Public
8
+ License as published by the Free Software Foundation, either
9
+ version 3 of the License, or (at your option) any later
10
+ version. You should have received a copy of this license along
11
+ this program. If not, see <http://www.gnu.org/licenses/>.
12
+ """
13
+ import torch
14
+ from torch.autograd import Function, Variable
15
+
16
def concatenate_input_noise_map(input, noise_sigma):
    r"""First FFDNet layer: 2x2 space-to-depth plus a per-sample noise map.

    Each CxHxW image in the batch becomes a (4*C + C) x H/2 x W/2 tensor:
    the first C channels hold a constant noise-level map and the remaining
    4*C channels hold the de-interleaved 2x2 patches of the input.

    Args:
        input: batch of CxHxW images
        noise_sigma: per-sample noise levels (length-N tensor)
    """
    # noise_sigma is a list of length batch_size
    batch, channels, height, width = input.size()
    scale = 2
    patch_size = scale * scale              # 4 pixels per 2x2 patch
    out_channels = patch_size * channels
    out_h, out_w = height // scale, width // scale
    offsets = ((0, 0), (0, 1), (1, 0), (1, 1))

    # Allocate the space-to-depth buffer on the same device as the input.
    if 'cuda' in input.type():
        mosaic = torch.cuda.FloatTensor(batch, out_channels, out_h, out_w).fill_(0)
    else:
        mosaic = torch.FloatTensor(batch, out_channels, out_h, out_w).fill_(0)

    # One constant C x H/2 x W/2 noise level map per sample.
    noise_map = noise_sigma.view(batch, 1, 1, 1).repeat(1, channels, out_h, out_w)

    # De-interleave: pixel (dy, dx) of every 2x2 patch gets its own channel slot.
    for slot, (dy, dx) in enumerate(offsets):
        mosaic[:, slot:out_channels:patch_size, :, :] = input[:, :, dy::scale, dx::scale]

    return torch.cat((noise_map, mosaic), 1)
54
+
55
class UpSampleFeaturesFunction(Function):
    r"""Extends PyTorch's modules by implementing a torch.autograd.Function.
    This class implements the forward and backward methods of the last layer
    of FFDNet. It basically performs the inverse of
    concatenate_input_noise_map(): it converts each of the images of a
    batch of size CxH/2xW/2 to images of size C/4xHxW
    """
    @staticmethod
    def forward(ctx, input):
        N, Cin, Hin, Win = input.size()
        dtype = input.type()
        sca = 2
        sca2 = sca*sca
        Cout = Cin//sca2    # every 4 input channels fold into 1 output channel
        Hout = Hin*sca
        Wout = Win*sca
        # Offsets of the 4 pixels inside each 2x2 output patch.
        idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]

        assert (Cin%sca2 == 0), 'Invalid input dimensions: number of channels should be divisible by 4'

        # Depth-to-space: channel group idx fills patch position idxL[idx].
        result = torch.zeros((N, Cout, Hout, Wout)).type(dtype)
        for idx in range(sca2):
            result[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca] = input[:, idx:Cin:sca2, :, :]

        return result

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient of depth-to-space is the inverse space-to-depth gather.
        # NOTE(review): uses the legacy .data/Variable autograd API; harmless
        # for inference but worth modernizing if training is ever needed.
        N, Cg_out, Hg_out, Wg_out = grad_output.size()
        dtype = grad_output.data.type()
        sca = 2
        sca2 = sca*sca
        Cg_in = sca2*Cg_out
        Hg_in = Hg_out//sca
        Wg_in = Wg_out//sca
        idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]

        # Build output
        grad_input = torch.zeros((N, Cg_in, Hg_in, Wg_in)).type(dtype)
        # Populate output
        for idx in range(sca2):
            grad_input[:, idx:Cg_in:sca2, :, :] = grad_output.data[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca]

        return Variable(grad_input)

# Alias functions
upsamplefeatures = UpSampleFeaturesFunction.apply
denoising/models.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Definition of the FFDNet model and its custom layers
3
+
4
+ Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
5
+
6
+ This program is free software: you can use, modify and/or
7
+ redistribute it under the terms of the GNU General Public
8
+ License as published by the Free Software Foundation, either
9
+ version 3 of the License, or (at your option) any later
10
+ version. You should have received a copy of this license along
11
+ this program. If not, see <http://www.gnu.org/licenses/>.
12
+ """
13
+ import torch.nn as nn
14
+ from torch.autograd import Variable
15
+ import denoising.functions as functions
16
+
17
class UpSampleFeatures(nn.Module):
    r"""Final FFDNet layer: depth-to-space upsampling of the DnCNN output."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return functions.upsamplefeatures(x)
24
+
25
class IntermediateDnCNN(nn.Module):
    r"""DnCNN-style trunk forming the middle part of FFDNet.

    Stacks Conv-ReLU, then (num_conv_layers - 2) Conv-BN-ReLU groups, then a
    final Conv. Only two input widths are accepted: 5 (grayscale: 4
    space-to-depth channels + 1 noise map) and 15 (RGB: 12 + 3).
    """

    def __init__(self, input_features, middle_features, num_conv_layers):
        super(IntermediateDnCNN, self).__init__()
        self.kernel_size = 3
        self.padding = 1
        self.input_features = input_features
        self.num_conv_layers = num_conv_layers
        self.middle_features = middle_features
        if self.input_features == 5:
            self.output_features = 4 #Grayscale image
        elif self.input_features == 15:
            self.output_features = 12 #RGB image
        else:
            raise Exception('Invalid number of input features')

        conv_kwargs = dict(kernel_size=self.kernel_size, padding=self.padding, bias=False)
        layers = [
            nn.Conv2d(self.input_features, self.middle_features, **conv_kwargs),
            nn.ReLU(inplace=True),
        ]
        for _ in range(self.num_conv_layers - 2):
            layers += [
                nn.Conv2d(self.middle_features, self.middle_features, **conv_kwargs),
                nn.BatchNorm2d(self.middle_features),
                nn.ReLU(inplace=True),
            ]
        layers.append(nn.Conv2d(self.middle_features, self.output_features, **conv_kwargs))
        # NOTE: the 'itermediate' typo is kept on purpose -- pretrained
        # state_dict keys are prefixed with this attribute name.
        self.itermediate_dncnn = nn.Sequential(*layers)

    def forward(self, x):
        return self.itermediate_dncnn(x)
67
+
68
class FFDNet(nn.Module):
    r"""Implements the FFDNet architecture.

    Pipeline: 2x2 space-to-depth + noise map -> DnCNN trunk -> depth-to-space.
    The network outputs a noise estimate at the input resolution; callers
    subtract it from the noisy image.
    """
    def __init__(self, num_input_channels):
        super(FFDNet, self).__init__()
        self.num_input_channels = num_input_channels
        if self.num_input_channels == 1:
            # Grayscale image
            self.num_feature_maps = 64
            self.num_conv_layers = 15
            self.downsampled_channels = 5    # 4 space-to-depth + 1 noise map
            self.output_features = 4
        elif self.num_input_channels == 3:
            # RGB image
            self.num_feature_maps = 96
            self.num_conv_layers = 12
            self.downsampled_channels = 15   # 12 space-to-depth + 3 noise map
            self.output_features = 12
        else:
            raise Exception('Invalid number of input features')

        self.intermediate_dncnn = IntermediateDnCNN(\
                input_features=self.downsampled_channels,\
                middle_features=self.num_feature_maps,\
                num_conv_layers=self.num_conv_layers)
        self.upsamplefeatures = UpSampleFeatures()

    def forward(self, x, noise_sigma):
        # Downsample 2x and append the per-sample noise level map.
        concat_noise_x = functions.concatenate_input_noise_map(x.data, noise_sigma.data)
        concat_noise_x = Variable(concat_noise_x)
        h_dncnn = self.intermediate_dncnn(concat_noise_x)
        # Back to full resolution: the predicted noise residual.
        pred_noise = self.upsamplefeatures(h_dncnn)
        return pred_noise
denoising/models/net_rgb.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0fe98bfd2ac870b15f360661b1c4789eecefc6dc2e4462842a0dd15e149a0433
3
+ size 3435567
denoising/utils.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Different utilities such as orthogonalization of weights, initialization of
3
+ loggers, etc
4
+
5
+ Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
6
+
7
+ This program is free software: you can use, modify and/or
8
+ redistribute it under the terms of the GNU General Public
9
+ License as published by the Free Software Foundation, either
10
+ version 3 of the License, or (at your option) any later
11
+ version. You should have received a copy of this license along
12
+ this program. If not, see <http://www.gnu.org/licenses/>.
13
+ """
14
+ import numpy as np
15
+ import cv2
16
+
17
+
18
def variable_to_cv2_image(varim):
    r"""Convert a 4-D torch Variable/tensor into an 8-bit OpenCV-style image.

    Args:
        varim: tensor of shape (1, C, H, W) with values expected in [0, 1];
            only the first element of the batch is converted.

    Returns:
        numpy uint8 array: HxW for C == 1, HxWx3 (BGR) for C == 3.

    Raises:
        Exception: for any other channel count.
    """
    nchannels = varim.size()[1]
    frame = varim.data.cpu().numpy()[0]
    if nchannels == 1:
        res = (frame[0, :] * 255.).clip(0, 255).astype(np.uint8)
    elif nchannels == 3:
        # Network output is RGB; OpenCV convention is BGR.
        bgr = cv2.cvtColor(frame.transpose(1, 2, 0), cv2.COLOR_RGB2BGR)
        res = (bgr * 255.).clip(0, 255).astype(np.uint8)
    else:
        raise Exception('Number of color channels not supported')
    return res
34
+
35
+
36
def normalize(data):
    """Scale 8-bit pixel data from [0, 255] down to float32 in [0, 1]."""
    scaled = data / 255.
    return np.float32(scaled)
38
+
39
def remove_dataparallel_wrapper(state_dict):
    r"""Converts a DataParallel model to a normal one by removing the "module."
    wrapper in the module dictionary

    Args:
        state_dict: a torch.nn.DataParallel state dictionary

    Returns:
        An OrderedDict with the same values whose keys have a leading
        'module.' stripped. Keys without the prefix are kept unchanged
        (the original sliced off the first 7 characters unconditionally,
        which corrupted already-unwrapped keys).
    """
    from collections import OrderedDict

    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        # Only strip when the DataParallel prefix is actually present.
        name = key[7:] if key.startswith('module.') else key
        new_state_dict[name] = value

    return new_state_dict
54
+
55
def is_rgb(im_path):
    r"""Return True if the image at *im_path* has genuinely distinct color channels.

    A 3-channel image whose channels are numerically (almost) identical is
    treated as grayscale. Diagnostic information is printed as a side effect.
    """
    from skimage.io import imread
    rgb = False
    im = imread(im_path)
    if (len(im.shape) == 3):
        # Channels that are elementwise close mean a grayscale image saved as RGB.
        if not(np.allclose(im[...,0], im[...,1]) and np.allclose(im[...,2], im[...,1])):
            rgb = True
    print("rgb: {}".format(rgb))
    print("im shape: {}".format(im.shape))
    return rgb
figures/bw1.jpg ADDED
figures/bw2.jpg ADDED
figures/bw3.jpg ADDED
figures/bw4.jpg ADDED
figures/bw5.jpg ADDED
figures/bw6.jpg ADDED
figures/color1.png ADDED
figures/color2.png ADDED

Git LFS Details

  • SHA256: 6900649446215ef2f21e0f36cdb0f9001907e7934e552c920b5bf0fa38e8dd10
  • Pointer size: 132 Bytes
  • Size of remote file: 1.4 MB
figures/color3.png ADDED
figures/color4.png ADDED
figures/color5.png ADDED
figures/color6.png ADDED
inference.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import sys
4
+
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+
8
+ from colorizator import MangaColorizator
9
+
10
def process_image(image, colorizator, args):
    """Colorize one image array using the CLI options carried by *args*."""
    colorizator.set_image(image, args.size, args.denoiser, args.denoiser_sigma)
    return colorizator.colorize()
14
+
15
def colorize_single_image(image_path, save_path, colorizator, args):
    """Read an image from disk, colorize it and write the result.

    Returns True on completion; exceptions propagate to the caller.
    """
    image = plt.imread(image_path)
    colorization = process_image(image, colorizator, args)
    plt.imsave(save_path, colorization)
    return True
24
+
25
+
26
def colorize_images(target_path, colorizator, args):
    """Colorize every file directly inside args.path, saving PNGs to target_path.

    Subdirectories are skipped; output files always get a .png extension.
    """
    for image_name in os.listdir(args.path):
        file_path = os.path.join(args.path, image_name)
        if os.path.isdir(file_path):
            continue

        name, ext = os.path.splitext(image_name)
        if ext != '.png':
            image_name = name + '.png'

        print(file_path)
        save_path = os.path.join(target_path, image_name)
        colorize_single_image(file_path, save_path, colorizator, args)
43
+
44
def parse_args():
    """Build and parse the command-line options for the inference script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("-p", "--path", required=True)
    parser.add_argument("-gen", "--generator", default="networks/generator.zip")
    parser.add_argument("-ext", "--extractor", default="networks/extractor.pth")
    # Boolean switches: -g enables the GPU, -nd disables denoising.
    parser.add_argument("-g", "--gpu", dest="gpu", action="store_true")
    parser.add_argument("-nd", "--no_denoise", dest="denoiser", action="store_false")
    parser.add_argument("-ds", "--denoiser_sigma", type=int, default=25)
    parser.add_argument("-s", "--size", type=int, default=576)
    parser.set_defaults(gpu=False, denoiser=True)
    return parser.parse_args()
58
+
59
+
60
if __name__ == "__main__":

    args = parse_args()

    # Select the torch device from the CLI flag.
    device = 'cuda' if args.gpu else 'cpu'

    colorizer = MangaColorizator(device, args.generator, args.extractor)

    if os.path.isdir(args.path):
        # Batch mode: results go into a 'colorization' subfolder of the input.
        colorization_path = os.path.join(args.path, 'colorization')
        if not os.path.exists(colorization_path):
            os.makedirs(colorization_path)

        colorize_images(colorization_path, colorizer, args)

    elif os.path.isfile(args.path):

        split = os.path.splitext(args.path)

        # BUG FIX: the jpeg entry was misspelled ',jpeg', so '.jpeg' files
        # were rejected with 'Wrong format'.
        if split[1].lower() in ('.jpg', '.png', '.jpeg'):
            new_image_path = split[0] + '_colorized' + '.png'

            colorize_single_image(args.path, new_image_path, colorizer, args)
        else:
            print('Wrong format')
    else:
        print('Wrong path')
networks/__pycache__/extractor.cpython-39.pyc ADDED
Binary file (3.82 kB). View file
 
networks/__pycache__/models.cpython-39.pyc ADDED
Binary file (10.7 kB). View file
 
networks/extractor.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
+ '''https://github.com/blandocs/Tag2Pix/blob/master/model/pretrained.py'''
6
+
7
+ # Pretrained version
8
class Selayer(nn.Module):
    """Squeeze-and-Excitation gate: rescales channels by a learned attention vector.

    Channel descriptors come from global average pooling and pass through a
    16x bottleneck (two 1x1 convs) followed by a sigmoid.
    """

    def __init__(self, inplanes):
        super(Selayer, self).__init__()
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Squeeze to (N, C, 1, 1), excite through the bottleneck, then gate x.
        gate = self.global_avgpool(x)
        gate = self.relu(self.conv1(gate))
        gate = self.sigmoid(self.conv2(gate))
        return x * gate
25
+
26
+
27
class BottleneckX_Origin(nn.Module):
    """SE-ResNeXt bottleneck: 1x1 reduce -> grouped 3x3 -> 1x1 expand -> SE gate.

    Channel plan: inplanes -> planes*2 -> planes*2 (grouped by cardinality)
    -> planes*4, followed by a Selayer and a residual connection that goes
    through `downsample` when the shapes differ.
    """
    expansion = 4

    def __init__(self, inplanes, planes, cardinality, stride=1, downsample=None):
        super(BottleneckX_Origin, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes * 2)

        # The grouped conv carries the stride; this is where downsampling happens.
        self.conv2 = nn.Conv2d(planes * 2, planes * 2, kernel_size=3, stride=stride,
                               padding=1, groups=cardinality, bias=False)
        self.bn2 = nn.BatchNorm2d(planes * 2)

        self.conv3 = nn.Conv2d(planes * 2, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)

        self.selayer = Selayer(planes * 4)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        # No ReLU after bn3: SE gating and the residual addition come first.
        out = self.conv3(out)
        out = self.bn3(out)

        out = self.selayer(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
71
+
72
class SEResNeXt_Origin(nn.Module):
    """Truncated SE-ResNeXt encoder returning four feature scales.

    Unlike the classification network, only the stem and the first three
    residual stages are kept (no maxpool/avgpool/fc head); forward() returns
    all intermediate activations for use as a feature extractor.

    NOTE(review): num_classes is accepted but never used in this truncated
    variant.
    """
    def __init__(self, block, layers, input_channels=3, cardinality=32, num_classes=1000):
        super(SEResNeXt_Origin, self).__init__()
        self.cardinality = cardinality
        self.inplanes = 64      # running channel count, advanced by _make_layer
        self.input_channels = input_channels

        # Stem: 7x7 stride-2 conv (without the usual ResNet maxpool).
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)

        # He-style init for convs, constant init for batch norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack `blocks` bottlenecks; only the first may downsample."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # 1x1 projection to match the residual branch's shape.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, self.cardinality, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, self.cardinality))

        return nn.Sequential(*layers)

    def forward(self, x):
        # Returns (stem, stage1, stage2, stage3) activations, finest first.
        x = self.conv1(x)
        x = self.bn1(x)
        x1 = self.relu(x)

        x2 = self.layer1(x1)

        x3 = self.layer2(x2)

        x4 = self.layer3(x3)

        return x1, x2, x3, x4
networks/generator.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae9bc204753267a38eeb43d262fee9c96fb1b5035fd89bbdf567fc69e5d3ebd1
3
+ size 202088636
networks/models.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.models as M
5
+ import math
6
+ from torch import Tensor
7
+ from torch.nn import Parameter
8
+
9
+ from .extractor import SEResNeXt_Origin, BottleneckX_Origin
10
+
11
+ '''https://github.com/orashi/AlacGAN/blob/master/models/standard.py'''
12
+
13
def l2normalize(v, eps=1e-12):
    """Scale *v* to unit L2 norm; eps guards against division by zero."""
    denom = v.norm() + eps
    return v / denom
15
+
16
+
17
class SpectralNorm(nn.Module):
    """Spectral normalization wrapper for a module's weight parameter.

    On every forward pass the wrapped module's `weight` is replaced by
    w_bar / sigma, where sigma estimates the weight matrix's largest singular
    value via power iteration. Three parameters are registered on the wrapped
    module: <name>_u and <name>_v (singular-vector estimates, not trained)
    and <name>_bar (the raw weight).
    """
    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        # Run power iteration and set module.<name> to the normalized weight.
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        # The weight is viewed as a (height x -1) matrix for the iteration.
        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))

        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        # EAFP check: True if the u/v/bar parameters were already registered.
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False


    def _make_params(self):
        # Replace module.<name> by the three spectral-norm parameters.
        w = getattr(self.module, self.name)
        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        # The raw weight moves to <name>_bar; module.<name> is recomputed in
        # _update_u_v() before every forward.
        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)


    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)
71
+
72
class Selayer(nn.Module):
    """Squeeze-and-Excitation channel attention (16x bottleneck, sigmoid gate)."""

    def __init__(self, inplanes):
        super(Selayer, self).__init__()
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Squeeze: per-channel global statistics, shape (N, C, 1, 1).
        attention = self.global_avgpool(x)
        # Excite: bottleneck MLP implemented with 1x1 convolutions.
        attention = self.relu(self.conv1(attention))
        attention = self.sigmoid(self.conv2(attention))
        # Broadcast multiply rescales each channel of x.
        return x * attention
89
+
90
class SelayerSpectr(nn.Module):
    """Squeeze-and-Excitation gate with spectral normalization on both 1x1 convs.

    Identical structure to Selayer, but each conv is wrapped in SpectralNorm
    (used in the spectrally-normalized blocks below).
    """
    def __init__(self, inplanes):
        super(SelayerSpectr, self).__init__()
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = SpectralNorm(nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1))
        self.conv2 = SpectralNorm(nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1))
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Squeeze -> excite -> sigmoid gate, then rescale x per channel.
        out = self.global_avgpool(x)
        out = self.conv1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.sigmoid(out)

        return x * out
107
+
108
class ResNeXtBottleneck(nn.Module):
    """ResNeXt bottleneck with SE gating, dilation support and avg-pool shortcut.

    The grouped conv uses kernel_size = 2 + stride (3 at stride 1, 4 at
    stride 2) with padding = dilation = dilate. When stride != 1 the identity
    path is average-pooled instead of projected with a 1x1 conv.
    """
    def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
        super(ResNeXtBottleneck, self).__init__()
        D = out_channels // 2   # bottleneck width
        self.out_channels = out_channels
        self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_conv = nn.Conv2d(D, D, kernel_size=2 + stride, stride=stride, padding=dilate, dilation=dilate,
                                   groups=cardinality,
                                   bias=False)
        self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.shortcut = nn.Sequential()
        if stride != 1:
            # Parameter-free downsampling on the identity path.
            self.shortcut.add_module('shortcut',
                                     nn.AvgPool2d(2, stride=2))

        self.selayer = Selayer(out_channels)

    def forward(self, x):
        bottleneck = self.conv_reduce.forward(x)
        bottleneck = F.leaky_relu(bottleneck, 0.2, True)
        bottleneck = self.conv_conv.forward(bottleneck)
        bottleneck = F.leaky_relu(bottleneck, 0.2, True)
        # No activation after the expand conv; the SE gate follows directly.
        bottleneck = self.conv_expand.forward(bottleneck)
        bottleneck = self.selayer(bottleneck)

        x = self.shortcut.forward(x)
        return x + bottleneck
135
+
136
class SpectrResNeXtBottleneck(nn.Module):
    """Spectrally-normalized variant of ResNeXtBottleneck.

    Same topology (1x1 reduce -> grouped conv -> 1x1 expand -> SE gate ->
    avg-pool shortcut), with every conv wrapped in SpectralNorm and the SE
    layer replaced by SelayerSpectr.
    """
    def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
        super(SpectrResNeXtBottleneck, self).__init__()
        D = out_channels // 2   # bottleneck width
        self.out_channels = out_channels
        self.conv_reduce = SpectralNorm(nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False))
        self.conv_conv = SpectralNorm(nn.Conv2d(D, D, kernel_size=2 + stride, stride=stride, padding=dilate, dilation=dilate,
                                                groups=cardinality,
                                                bias=False))
        self.conv_expand = SpectralNorm(nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False))
        self.shortcut = nn.Sequential()
        if stride != 1:
            # Parameter-free downsampling on the identity path.
            self.shortcut.add_module('shortcut',
                                     nn.AvgPool2d(2, stride=2))

        self.selayer = SelayerSpectr(out_channels)

    def forward(self, x):
        bottleneck = self.conv_reduce.forward(x)
        bottleneck = F.leaky_relu(bottleneck, 0.2, True)
        bottleneck = self.conv_conv.forward(bottleneck)
        bottleneck = F.leaky_relu(bottleneck, 0.2, True)
        # No activation after the expand conv; the SE gate follows directly.
        bottleneck = self.conv_expand.forward(bottleneck)
        bottleneck = self.selayer(bottleneck)

        x = self.shortcut.forward(x)
        return x + bottleneck
163
+
164
class FeatureConv(nn.Module):
    """Feature-projection head: three 3x3 convolutions (the middle one
    strided) mapping ``input_dim`` channels to ``output_dim`` channels and
    halving the spatial resolution.

    Args:
        input_dim: number of input channels.
        output_dim: number of output channels.
        no_bn: when True (the default, matching the previously hard-coded
            local ``no_bn = True``) no BatchNorm layers are inserted; pass
            False to add BatchNorm after the first two convolutions.
            Note: ``no_bn=False`` shifts the ``nn.Sequential`` indices, so
            checkpoints saved with the default layout will not load into it.
    """

    def __init__(self, input_dim=512, output_dim=512, no_bn=True):
        super(FeatureConv, self).__init__()

        seq = []
        seq.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=1, padding=1, bias=False))
        if not no_bn:
            seq.append(nn.BatchNorm2d(output_dim))
        seq.append(nn.ReLU(inplace=True))
        # The stride-2 conv performs the single spatial downsampling step.
        seq.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
        if not no_bn:
            seq.append(nn.BatchNorm2d(output_dim))
        seq.append(nn.ReLU(inplace=True))
        seq.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=1, padding=1, bias=False))
        seq.append(nn.ReLU(inplace=True))

        self.network = nn.Sequential(*seq)

    def forward(self, x):
        """Apply the conv stack; output shape is (N, output_dim, H/2, W/2)."""
        return self.network(x)
184
+
185
class Generator(nn.Module):
    """Colorization generator: a SE-ResNeXt sketch encoder plus a shallow
    multi-scale hint encoder (to0..to3), fused through dilated-ResNeXt
    "tunnel" stages with PixelShuffle upsampling, producing a color image
    and an auxiliary decoder output.

    NOTE(review): ``self.to4`` and ``self.tunnel1`` are constructed but
    never used in ``forward``; they are presumably kept so pretrained
    checkpoints with those keys still load — confirm before removing.
    """

    def __init__(self, ngf=64):
        # ngf is accepted but unused in this implementation.
        super(Generator, self).__init__()

        # Single-channel (grayscale) SE-ResNeXt backbone returning four feature scales.
        self.encoder = SEResNeXt_Origin(BottleneckX_Origin, [3, 4, 6, 3], num_classes= 370, input_channels=1)

        # Shallow encoder over the full input.
        # NOTE(review): to0 takes 5 channels — presumably grayscale + 4 hint/mask
        # channels; confirm against the caller.
        self.to0 = self._make_encoder_block_first(5, 32)
        self.to1 = self._make_encoder_block(32, 64)
        self.to2 = self._make_encoder_block(64, 92)
        self.to3 = self._make_encoder_block(92, 128)
        self.to4 = self._make_encoder_block(128, 256)

        # Auxiliary decoder applied to the tunnel4 output (second return value of forward).
        self.deconv_for_decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1), # output is 64 * 64
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1), # output is 128 * 128
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(64, 32, 3, stride=1, padding=1, output_padding=0), # output is 256 * 256
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(32, 3, 3, stride=1, padding=1, output_padding=0), # output is 256 * 256
            nn.Tanh(),
        )

        # Deepest fusion stage: 20 non-dilated ResNeXt blocks at 512 channels.
        tunnel4 = nn.Sequential(*[ResNeXtBottleneck(512, 512, cardinality=32, dilate=1) for _ in range(20)])


        # Fuses encoder x4 (1024 ch) with the hint features (128 ch), then upsamples 2x.
        self.tunnel4 = nn.Sequential(nn.Conv2d(1024 + 128, 512, kernel_size=3, stride=1, padding=1),
                                     nn.LeakyReLU(0.2, True),
                                     tunnel4,
                                     nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1),
                                     nn.PixelShuffle(2),
                                     nn.LeakyReLU(0.2, True)
                                     ) # 64

        # Dilation pyramid (1, 2, 4, 2, 1) used by tunnel3 and tunnel2 to grow
        # the receptive field without further downsampling.
        depth = 2
        tunnel = [ResNeXtBottleneck(256, 256, cardinality=32, dilate=1) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(256, 256, cardinality=32, dilate=2) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(256, 256, cardinality=32, dilate=4) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(256, 256, cardinality=32, dilate=2),
                   ResNeXtBottleneck(256, 256, cardinality=32, dilate=1)]
        tunnel3 = nn.Sequential(*tunnel)

        # Fuses tunnel4 output (512 ch) with encoder x3 (256 ch), upsamples 2x.
        self.tunnel3 = nn.Sequential(nn.Conv2d(512 + 256, 256, kernel_size=3, stride=1, padding=1),
                                     nn.LeakyReLU(0.2, True),
                                     tunnel3,
                                     nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
                                     nn.PixelShuffle(2),
                                     nn.LeakyReLU(0.2, True)
                                     ) # 128

        tunnel = [ResNeXtBottleneck(128, 128, cardinality=32, dilate=1) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(128, 128, cardinality=32, dilate=2) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(128, 128, cardinality=32, dilate=4) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(128, 128, cardinality=32, dilate=2),
                   ResNeXtBottleneck(128, 128, cardinality=32, dilate=1)]
        tunnel2 = nn.Sequential(*tunnel)

        # Fuses tunnel3 output with encoder x2 (256 ch) and x1 (64 ch), upsamples 2x.
        self.tunnel2 = nn.Sequential(nn.Conv2d(128 + 256 + 64, 128, kernel_size=3, stride=1, padding=1),
                                     nn.LeakyReLU(0.2, True),
                                     tunnel2,
                                     nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
                                     nn.PixelShuffle(2),
                                     nn.LeakyReLU(0.2, True)
                                     )

        tunnel = [ResNeXtBottleneck(64, 64, cardinality=16, dilate=1)]
        tunnel += [ResNeXtBottleneck(64, 64, cardinality=16, dilate=2)]
        tunnel += [ResNeXtBottleneck(64, 64, cardinality=16, dilate=4)]
        tunnel += [ResNeXtBottleneck(64, 64, cardinality=16, dilate=2),
                   ResNeXtBottleneck(64, 64, cardinality=16, dilate=1)]
        tunnel1 = nn.Sequential(*tunnel)

        # NOTE(review): unused in forward (see class docstring).
        self.tunnel1 = nn.Sequential(nn.Conv2d(64 + 32, 64, kernel_size=3, stride=1, padding=1),
                                     nn.LeakyReLU(0.2, True),
                                     tunnel1,
                                     nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                                     nn.PixelShuffle(2),
                                     nn.LeakyReLU(0.2, True)
                                     )

        # Final head: fuse with the full-resolution hint features x0 and emit 3 channels.
        self.exit = nn.Sequential(nn.Conv2d(64 + 32, 32, kernel_size=3, stride=1, padding=1),
                                  nn.LeakyReLU(0.2, True),
                                  nn.Conv2d(32, 3, kernel_size= 1, stride = 1, padding = 0))


    def _make_encoder_block(self, inplanes, planes):
        # Downsampling encoder stage: stride-2 conv followed by a stride-1 conv.
        return nn.Sequential(
            nn.Conv2d(inplanes, planes, 3, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(planes, planes, 3, 1, 1),
            nn.LeakyReLU(0.2),
        )

    def _make_encoder_block_first(self, inplanes, planes):
        # First encoder stage: keeps full resolution (both convs stride 1).
        return nn.Sequential(
            nn.Conv2d(inplanes, planes, 3, 1, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(planes, planes, 3, 1, 1),
            nn.LeakyReLU(0.2),
        )

    def forward(self, sketch):
        """Return (colorized image in [-1, 1], auxiliary decoder output).

        ``sketch`` feeds the shallow hint encoder whole; only its first
        channel goes to the SE-ResNeXt backbone.
        """

        # Shallow multi-scale features from the full input.
        x0 = self.to0(sketch)
        aux_out = self.to1(x0)
        aux_out = self.to2(aux_out)
        aux_out = self.to3(aux_out)

        # Backbone features from the grayscale channel only.
        x1, x2, x3, x4 = self.encoder(sketch[:, 0:1])

        out = self.tunnel4(torch.cat([x4, aux_out], 1))



        x = self.tunnel3(torch.cat([out, x3], 1))

        x = self.tunnel2(torch.cat([x, x2, x1], 1))


        x = torch.tanh(self.exit(torch.cat([x, x0], 1)))

        # Auxiliary reconstruction branch driven by the deepest fused features.
        decoder_output = self.deconv_for_decoder(out)

        return x, decoder_output
309
+
310
+
311
class Colorizer(nn.Module):
    """Thin inference wrapper that owns the Generator and forwards to it."""

    def __init__(self):
        super(Colorizer, self).__init__()
        self.generator = Generator()

    def forward(self, x, extractor_grad = False):
        """Return the generator's (fake, guide) pair for input x.

        ``extractor_grad`` is accepted for interface compatibility but is
        not used by this implementation.
        """
        return self.generator(x)
readme.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Automatic colorization
2
+
3
+ 1. Download the [generator](https://drive.google.com/file/d/1qmxUEKADkEM4iYLp1fpPLLKnfZ6tcF-t/view?usp=sharing) and [denoiser](https://drive.google.com/file/d/161oyQcYpdkVdw8gKz_MA8RD-Wtg9XDp3/view?usp=sharing) weights. Put the generator (and, if available, extractor) weights in `networks`, and the denoiser weights in `denoising/models`.
4
+ 2. To colorize image or folder of images, use the following command:
5
+ ```
6
+ $ python inference.py -p "path to file or folder"
7
+ ```
8
+
9
+ | Input (black & white) | Colorized result |
10
+ |------------|-------------|
11
+ | <img src="figures/bw1.jpg" width="512"> | <img src="figures/color1.png" width="512"> |
12
+ | <img src="figures/bw2.jpg" width="512"> | <img src="figures/color2.png" width="512"> |
13
+ | <img src="figures/bw3.jpg" width="512"> | <img src="figures/color3.png" width="512"> |
14
+ | <img src="figures/bw4.jpg" width="512"> | <img src="figures/color4.png" width="512"> |
15
+ | <img src="figures/bw5.jpg" width="512"> | <img src="figures/color5.png" width="512"> |
16
+ | <img src="figures/bw6.jpg" width="512"> | <img src="figures/color6.png" width="512"> |
utils/__pycache__/utils.cpython-39.pyc ADDED
Binary file (1.06 kB). View file
 
utils/utils.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+
4
def resize_pad(img, size = 256):
    """Resize an image for the colorizer and pad its long side to a multiple of 32.

    Landscape images are resized to a fixed height of ``int(size * 1.5)`` and
    padded on the right; portrait/square images are resized to width ``size``
    and padded on the bottom. Padding uses the 'maximum' mode (bright border,
    which suits white manga page backgrounds).

    Args:
        img: H x W, H x W x 1, H x W x 3 or H x W x 4 ndarray.
        size: target short-ish side; should be divisible by 32.

    Returns:
        (image, pad): single-channel H x W x 1 ndarray and a
        (bottom_pad, right_pad)-style tuple — (0, right) for landscape,
        (bottom, 0) for portrait — so callers can crop the padding off.
    """
    # Normalize channel layout: promote grayscale to 3 channels, drop alpha.
    if len(img.shape) == 2:
        img = np.expand_dims(img, 2)

    if img.shape[2] == 1:
        img = np.repeat(img, 3, 2)

    if img.shape[2] == 4:
        img = img[:, :, :3]

    if (img.shape[0] < img.shape[1]):
        # Landscape: fix the height, scale the width proportionally.
        height = img.shape[0]
        ratio = height / (size * 1.5)
        width = int(np.ceil(img.shape[1] / ratio))
        img = cv2.resize(img, (width, int(size * 1.5)), interpolation = cv2.INTER_AREA)

        # BUGFIX: the previous `32 - width % 32` added a full 32-pixel band
        # even when width was already divisible by 32; -width % 32 pads 0 then.
        pad_w = -width % 32
        pad = (0, pad_w)

        img = np.pad(img, ((0, 0), (0, pad_w), (0, 0)), 'maximum')
    else:
        # Portrait or square: fix the width, scale the height proportionally.
        width = img.shape[1]
        ratio = width / size
        height = int(np.ceil(img.shape[0] / ratio))
        img = cv2.resize(img, (size, height), interpolation = cv2.INTER_AREA)

        pad_h = -height % 32
        pad = (pad_h, 0)

        img = np.pad(img, ((0, pad_h), (0, 0), (0, 0)), 'maximum')

    # Float images are clipped in place to the valid [0, 1] range.
    if (img.dtype == 'float32'):
        np.clip(img, 0, 1, out = img)

    return img[:, :, :1], pad