Raid41 commited on
Commit
5f240b7
·
1 Parent(s): ba264a5

Upload 34 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,19 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ dataset/manga/train/bw/002-0000-0000.png filter=lfs diff=lfs merge=lfs -text
37
+ dataset/manga/train/bw/x2-0000-0000.png filter=lfs diff=lfs merge=lfs -text
38
+ dataset/manga/train/bw/x3-0000-0000.png filter=lfs diff=lfs merge=lfs -text
39
+ dataset/manga/train/bw/x5-0000-0000.png filter=lfs diff=lfs merge=lfs -text
40
+ dataset/manga/train/color/002-0000-0000.png filter=lfs diff=lfs merge=lfs -text
41
+ dataset/manga/train/color/x1-0000-0000.png filter=lfs diff=lfs merge=lfs -text
42
+ dataset/manga/train/color/x2-0000-0000.png filter=lfs diff=lfs merge=lfs -text
43
+ dataset/manga/train/color/x3-0000-0000.png filter=lfs diff=lfs merge=lfs -text
44
+ dataset/manga/train/color/x4-0000-0000.png filter=lfs diff=lfs merge=lfs -text
45
+ dataset/manga/train/color/x5-0000-0000.png filter=lfs diff=lfs merge=lfs -text
46
+ dataset/manga/train/real_manga/002-0000-0000.png filter=lfs diff=lfs merge=lfs -text
47
+ dataset/manga/train/real_manga/x1-0000-0000.png filter=lfs diff=lfs merge=lfs -text
48
+ dataset/manga/train/real_manga/x2-0000-0000.png filter=lfs diff=lfs merge=lfs -text
49
+ dataset/manga/train/real_manga/x3-0000-0000.png filter=lfs diff=lfs merge=lfs -text
50
+ dataset/manga/train/real_manga/x4-0000-0000.png filter=lfs diff=lfs merge=lfs -text
51
+ dataset/manga/train/real_manga/x5-0000-0000.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ *.ipynb
2
+ *.pth
3
+
4
+ __pycache__/
5
+ temp_colorization/
configs/train_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "generator_lr" : 1e-4,
3
+ "discriminator_lr" : 4e-4,
4
+ "epochs" : 15,
5
+ "lr_decrease_epoch" : 10,
6
+ "finetuning_generator_lr" : 1e-6,
7
+ "finetuning_iterations" : 3500,
8
+ "batch_size" : 4,
9
+ "number_of_mults" : 3
10
+ }
configs/xdog_config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "sigma" : 0.5,
3
+ "k" : 8,
4
+ "phi" : 89.25,
5
+ "gamma" : 0.95,
6
+ "eps" : -0.1,
7
+ "mult" : 7
8
+ }
dataset/datasets.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import torchvision.transforms as transforms
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+
7
+ from utils.utils import generate_mask
8
+
9
+
10
+ class TrainDataset(torch.utils.data.Dataset):
11
+ def __init__(self, data_path, transform = None, mults_amount = 1):
12
+ self.data = os.listdir(os.path.join(data_path, 'color'))
13
+ self.data_path = data_path
14
+ self.transform = transform
15
+ self.mults_amount = mults_amount
16
+
17
+ self.ToTensor = transforms.ToTensor()
18
+ def __len__(self):
19
+ return len(self.data)
20
+
21
+ def __getitem__(self, idx):
22
+ image_name = self.data[idx]
23
+
24
+ color_img = plt.imread(os.path.join(self.data_path, 'color', image_name))
25
+
26
+
27
+ if self.mults_amount > 1:
28
+ mult_number = np.random.choice(range(self.mults_amount))
29
+
30
+ bw_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '.png'
31
+ dfm_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '_dfm.png'
32
+ else:
33
+ bw_name = self.data[idx]
34
+ dfm_name = os.path.splitext(self.data[idx])[0] + '0_dfm.png'
35
+
36
+
37
+ bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', bw_name)), 2)
38
+ dfm_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', dfm_name)), 2)
39
+
40
+ bw_img = np.concatenate([bw_img, dfm_img], axis = 2)
41
+
42
+ if self.transform:
43
+ result = self.transform(image = color_img, mask = bw_img)
44
+ color_img = result['image']
45
+ bw_img = result['mask']
46
+
47
+ dfm_img = bw_img[:, :, 1]
48
+ bw_img = bw_img[:, :, 0]
49
+
50
+ color_img = self.ToTensor(color_img)
51
+ bw_img = self.ToTensor(bw_img)
52
+
53
+ dfm_img = self.ToTensor(dfm_img)
54
+
55
+ color_img = (color_img - 0.5) / 0.5
56
+
57
+ mask = generate_mask(bw_img.shape[1], bw_img.shape[2])
58
+ hint = torch.cat((color_img * mask, mask), 0)
59
+
60
+ return bw_img, color_img, hint, dfm_img
61
+
62
+ class FineTuningDataset(torch.utils.data.Dataset):
63
+ def __init__(self, data_path, transform = None, mult_amount = 1):
64
+ self.data = [x for x in os.listdir(os.path.join(data_path, 'real_manga')) if x.find('_dfm') == -1]
65
+ self.color_data = [x for x in os.listdir(os.path.join(data_path, 'color'))]
66
+ self.data_path = data_path
67
+ self.transform = transform
68
+ self.mults_amount = mult_amount
69
+
70
+ np.random.shuffle(self.color_data)
71
+
72
+ self.ToTensor = transforms.ToTensor()
73
+ def __len__(self):
74
+ return len(self.data)
75
+
76
+ def __getitem__(self, idx):
77
+ color_img = plt.imread(os.path.join(self.data_path, 'color', self.color_data[idx]))
78
+
79
+ image_name = self.data[idx]
80
+ if self.mults_amount > 1:
81
+ mult_number = np.random.choice(range(self.mults_amount))
82
+
83
+ bw_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '.png'
84
+ dfm_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '_dfm.png'
85
+ else:
86
+ bw_name = self.data[idx]
87
+ dfm_name = os.path.splitext(self.data[idx])[0] + '_dfm.png'
88
+
89
+
90
+ bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'real_manga', image_name)), 2)
91
+ dfm_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'real_manga', dfm_name)), 2)
92
+
93
+ if self.transform:
94
+ result = self.transform(image = color_img)
95
+ color_img = result['image']
96
+
97
+ result = self.transform(image = bw_img, mask = dfm_img)
98
+ bw_img = result['image']
99
+ dfm_img = result['mask']
100
+
101
+ color_img = self.ToTensor(color_img)
102
+ bw_img = self.ToTensor(bw_img)
103
+ dfm_img = self.ToTensor(dfm_img)
104
+
105
+ color_img = (color_img - 0.5) / 0.5
106
+
107
+ return bw_img, dfm_img, color_img
dataset/manga/train/bw/002-0000-0000.png ADDED

Git LFS Details

  • SHA256: ab9ffbbdc32766ce64f758d0402194d5f78b444d6614e4a5088256ace4dad327
  • Pointer size: 132 Bytes
  • Size of remote file: 3.87 MB
dataset/manga/train/bw/003-0000-0000.png ADDED
dataset/manga/train/bw/x2-0000-0000.png ADDED

Git LFS Details

  • SHA256: b605bc2accdddb7659463cbdf891a4ac5ac12c12dddd406de001d1737cfc2818
  • Pointer size: 132 Bytes
  • Size of remote file: 6.7 MB
dataset/manga/train/bw/x3-0000-0000.png ADDED

Git LFS Details

  • SHA256: ebfd588a0e54d0380b3dfa522a3b8b8d16a40180ea66ba8d3da6b0abdac2f195
  • Pointer size: 132 Bytes
  • Size of remote file: 3.98 MB
dataset/manga/train/bw/x5-0000-0000.png ADDED

Git LFS Details

  • SHA256: b1c940b720dc52aa4f8f5372fb8071b639f5ae8b75470983b5301aabd3d9bd2b
  • Pointer size: 132 Bytes
  • Size of remote file: 4.25 MB
dataset/manga/train/color/002-0000-0000.png ADDED

Git LFS Details

  • SHA256: 1a289ac6f4c1b477e987eef553232e70cefd536b62cf7c0662e27d84c55b3f02
  • Pointer size: 132 Bytes
  • Size of remote file: 6.26 MB
dataset/manga/train/color/003-0000-0000.png ADDED
dataset/manga/train/color/004-0000-0000.png ADDED
dataset/manga/train/color/x1-0000-0000.png ADDED

Git LFS Details

  • SHA256: 4fee65ac5d942ccbcf46250042077897a1a1d0288dbcbd4d4e01a95667b4321f
  • Pointer size: 132 Bytes
  • Size of remote file: 5.74 MB
dataset/manga/train/color/x2-0000-0000.png ADDED

Git LFS Details

  • SHA256: 3934c58bf1876be001bd806f637d7722dbd6ba34ba9cc2aed98288019771a248
  • Pointer size: 133 Bytes
  • Size of remote file: 11.3 MB
dataset/manga/train/color/x3-0000-0000.png ADDED

Git LFS Details

  • SHA256: e8690c5617dbb1215cfdc0f70a2e31156794b0abf21dde7f9c50bbe2675de71e
  • Pointer size: 132 Bytes
  • Size of remote file: 5.25 MB
dataset/manga/train/color/x4-0000-0000.png ADDED

Git LFS Details

  • SHA256: c5d0f633082871aa456f2d7864d0d9f4fa1ca9ae2b1cc229df4927b38a4df40c
  • Pointer size: 132 Bytes
  • Size of remote file: 4.98 MB
dataset/manga/train/color/x5-0000-0000.png ADDED

Git LFS Details

  • SHA256: acba9510555cebb18fbdb220ef3f22724617893f35d8a3fe1c509ed31f257dbb
  • Pointer size: 132 Bytes
  • Size of remote file: 5.63 MB
dataset/manga/train/real_manga/002-0000-0000.png ADDED

Git LFS Details

  • SHA256: 1a289ac6f4c1b477e987eef553232e70cefd536b62cf7c0662e27d84c55b3f02
  • Pointer size: 132 Bytes
  • Size of remote file: 6.26 MB
dataset/manga/train/real_manga/003-0000-0000.png ADDED
dataset/manga/train/real_manga/004-0000-0000.png ADDED
dataset/manga/train/real_manga/x1-0000-0000.png ADDED

Git LFS Details

  • SHA256: 4fee65ac5d942ccbcf46250042077897a1a1d0288dbcbd4d4e01a95667b4321f
  • Pointer size: 132 Bytes
  • Size of remote file: 5.74 MB
dataset/manga/train/real_manga/x2-0000-0000.png ADDED

Git LFS Details

  • SHA256: 3934c58bf1876be001bd806f637d7722dbd6ba34ba9cc2aed98288019771a248
  • Pointer size: 133 Bytes
  • Size of remote file: 11.3 MB
dataset/manga/train/real_manga/x3-0000-0000.png ADDED

Git LFS Details

  • SHA256: e8690c5617dbb1215cfdc0f70a2e31156794b0abf21dde7f9c50bbe2675de71e
  • Pointer size: 132 Bytes
  • Size of remote file: 5.25 MB
dataset/manga/train/real_manga/x4-0000-0000.png ADDED

Git LFS Details

  • SHA256: c5d0f633082871aa456f2d7864d0d9f4fa1ca9ae2b1cc229df4927b38a4df40c
  • Pointer size: 132 Bytes
  • Size of remote file: 4.98 MB
dataset/manga/train/real_manga/x5-0000-0000.png ADDED

Git LFS Details

  • SHA256: acba9510555cebb18fbdb220ef3f22724617893f35d8a3fe1c509ed31f257dbb
  • Pointer size: 132 Bytes
  • Size of remote file: 5.63 MB
inference.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from utils.dataset_utils import get_sketch
5
+ from utils.utils import resize_pad, generate_mask, extract_cbr, create_cbz, sorted_alphanumeric, subfolder_image_search, remove_folder
6
+ from torchvision.transforms import ToTensor
7
+ import os
8
+ import matplotlib.pyplot as plt
9
+ import argparse
10
+ from model.models import Colorizer, Generator
11
+ from model.extractor import get_seresnext_extractor
12
+ from utils.xdog import XDoGSketcher
13
+ from utils.utils import open_json
14
+ import sys
15
+
16
+ def colorize_without_hint(inp, colorizer, device = 'cpu', auto_hint = False, auto_hint_sigma = 0.003):
17
+ i_hint = torch.zeros(1, 4, inp.shape[2], inp.shape[3]).float().to(device)
18
+
19
+ with torch.no_grad():
20
+ fake_color, _ = colorizer(torch.cat([inp, i_hint], 1))
21
+
22
+ if auto_hint:
23
+ mask = generate_mask(fake_color.shape[2], fake_color.shape[3], full = False, prob = 1, sigma = auto_hint_sigma).unsqueeze(0)
24
+ mask = mask.to(device)
25
+ i_hint = torch.cat([fake_color * mask, mask], 1)
26
+
27
+ with torch.no_grad():
28
+ fake_color, _ = colorizer(torch.cat([inp, i_hint], 1))
29
+
30
+ return fake_color
31
+
32
+
33
+ def process_image(image, sketcher, colorizer, auto_hint, auto_hint_sigma = 0.003, dfm = True, device = 'cpu', to_tensor = ToTensor()):
34
+ image, pad = resize_pad(image)
35
+ bw, dfm = get_sketch(image, sketcher, dfm)
36
+
37
+ bw = to_tensor(bw).unsqueeze(0).to(device)
38
+ dfm = to_tensor(dfm).unsqueeze(0).to(device)
39
+
40
+ output = colorize_without_hint(torch.cat([bw, dfm], 1), colorizer, device = device, auto_hint = auto_hint)
41
+ result = output[0].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5
42
+
43
+ if pad[0] != 0:
44
+ result = result[:-pad[0]]
45
+ if pad[1] != 0:
46
+ result = result[:, :-pad[1]]
47
+
48
+ return result
49
+
50
+ def colorize_single_image(file_path, save_path, sketcher, colorizer, auto_hint, auto_hint_sigma = 0.003, dfm = True, device = 'cpu'):
51
+ try:
52
+ image = plt.imread(file_path)
53
+
54
+ colorization = process_image(image, sketcher, colorizer, auto_hint, auto_hint_sigma, dfm, device)
55
+
56
+ plt.imsave(save_path, colorization)
57
+ except KeyboardInterrupt:
58
+ sys.exit(0)
59
+ except:
60
+ print('Failed to colorize {}'.format(file_path))
61
+
62
+ def colorize_images(source_path, target_path, sketcher, colorizer, auto_hint, auto_hint_sigma = 0.003, dfm = True, device = 'cpu'):
63
+ images = os.listdir(source_path)
64
+
65
+ for image_name in images:
66
+ file_path = os.path.join(source_path, image_name)
67
+ save_path = os.path.join(target_path, image_name)
68
+ colorize_single_image(file_path, save_path, sketcher, colorizer, auto_hint, auto_hint_sigma, dfm, device)
69
+
70
+ def colorize_cbr(file_path, sketcher, colorizer, auto_hint, auto_hint_sigma = 0.003, dfm = True, device = 'cpu'):
71
+ file_name = os.path.splitext(os.path.basename(file_path))[0]
72
+ temp_path = 'temp_colorization'
73
+
74
+ if not os.path.exists(temp_path):
75
+ os.makedirs(temp_path)
76
+ extract_cbr(file_path, temp_path)
77
+
78
+ images = subfolder_image_search(temp_path)
79
+ for image_path in images:
80
+ try:
81
+ image = plt.imread(image_path)
82
+
83
+ colorization = process_image(image, sketcher, colorizer, auto_hint, auto_hint_sigma, dfm, device)
84
+
85
+ plt.imsave(image_path, colorization)
86
+ except KeyboardInterrupt:
87
+ sys.exit(0)
88
+ except:
89
+ print('Failed to colorize {}'.format(image_path))
90
+
91
+ result_name = os.path.join(os.path.dirname(file_path), file_name + '_colorized.cbz')
92
+
93
+ create_cbz(result_name, images)
94
+
95
+ remove_folder(temp_path)
96
+
97
+ def parse_args():
98
+ parser = argparse.ArgumentParser()
99
+ parser.add_argument("-p", "--path", required=True)
100
+ parser.add_argument("-gen", "--generator", default = 'model/biggan.pth')
101
+ parser.add_argument("-ext", "--extractor", default = 'model/extractor.pth')
102
+ parser.add_argument("-s", "--sigma", type = float, default = 0.003)
103
+ parser.add_argument('-g', '--gpu', dest = 'gpu', action = 'store_true')
104
+ parser.add_argument('-ah', '--auto', dest = 'autohint', action = 'store_true')
105
+ parser.set_defaults(gpu = False)
106
+ parser.set_defaults(autohint = False)
107
+ args = parser.parse_args()
108
+
109
+ return args
110
+
111
+
112
+ if __name__ == "__main__":
113
+
114
+ args = parse_args()
115
+
116
+ if args.gpu:
117
+ device = 'cuda'
118
+ else:
119
+ device = 'cpu'
120
+
121
+ generator = Generator()
122
+ generator.load_state_dict(torch.load(args.generator))
123
+
124
+ extractor = get_seresnext_extractor()
125
+ extractor.load_state_dict(torch.load(args.extractor))
126
+
127
+ colorizer = Colorizer(generator, extractor)
128
+ colorizer = colorizer.eval().to(device)
129
+
130
+ sketcher = XDoGSketcher()
131
+ xdog_config = open_json('configs/xdog_config.json')
132
+ for key in xdog_config.keys():
133
+ if key in sketcher.params:
134
+ sketcher.params[key] = xdog_config[key]
135
+
136
+ if os.path.isdir(args.path):
137
+ colorization_path = os.path.join(args.path, 'colorization')
138
+ if not os.path.exists(colorization_path):
139
+ os.makedirs(colorization_path)
140
+
141
+ colorize_images(args.path, colorization_path, sketcher, colorizer, args.autohint, args.sigma, device = device)
142
+ elif os.path.isfile(args.path):
143
+ split = os.path.splitext(args.path)
144
+ if split[1].lower() in ('.cbr', '.cbz', '.rar', '.zip'):
145
+ colorize_cbr(args.path, sketcher, colorizer, args.autohint, args.sigma, device = device)
146
+ elif split[1].lower() in ('.jpg', '.png'):
147
+ new_image_path = split[0] + '_colorized' + split[1]
148
+
149
+ colorize_single_image(args.path, new_image_path, sketcher, colorizer, args.autohint, args.sigma, device = device)
150
+ else:
151
+ print('Wrong format')
152
+ else:
153
+ print('Wrong path')
154
+
model/extractor.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee3c59f02ac8c59298fd9b819fa33d2efa168847e15e4be39b35c286f7c18607
3
+ size 6340842
model/extractor.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
+ '''https://github.com/blandocs/Tag2Pix/blob/master/model/pretrained.py'''
6
+
7
+ # Pretrained version
8
+ class Selayer(nn.Module):
9
+ def __init__(self, inplanes):
10
+ super(Selayer, self).__init__()
11
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
12
+ self.conv1 = nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1)
13
+ self.conv2 = nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1)
14
+ self.relu = nn.ReLU(inplace=True)
15
+ self.sigmoid = nn.Sigmoid()
16
+
17
+ def forward(self, x):
18
+ out = self.global_avgpool(x)
19
+ out = self.conv1(out)
20
+ out = self.relu(out)
21
+ out = self.conv2(out)
22
+ out = self.sigmoid(out)
23
+
24
+ return x * out
25
+
26
+
27
+ class BottleneckX_Origin(nn.Module):
28
+ expansion = 4
29
+
30
+ def __init__(self, inplanes, planes, cardinality, stride=1, downsample=None):
31
+ super(BottleneckX_Origin, self).__init__()
32
+ self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)
33
+ self.bn1 = nn.BatchNorm2d(planes * 2)
34
+
35
+ self.conv2 = nn.Conv2d(planes * 2, planes * 2, kernel_size=3, stride=stride,
36
+ padding=1, groups=cardinality, bias=False)
37
+ self.bn2 = nn.BatchNorm2d(planes * 2)
38
+
39
+ self.conv3 = nn.Conv2d(planes * 2, planes * 4, kernel_size=1, bias=False)
40
+ self.bn3 = nn.BatchNorm2d(planes * 4)
41
+
42
+ self.selayer = Selayer(planes * 4)
43
+
44
+ self.relu = nn.ReLU(inplace=True)
45
+ self.downsample = downsample
46
+ self.stride = stride
47
+
48
+ def forward(self, x):
49
+ residual = x
50
+
51
+ out = self.conv1(x)
52
+ out = self.bn1(out)
53
+ out = self.relu(out)
54
+
55
+ out = self.conv2(out)
56
+ out = self.bn2(out)
57
+ out = self.relu(out)
58
+
59
+ out = self.conv3(out)
60
+ out = self.bn3(out)
61
+
62
+ out = self.selayer(out)
63
+
64
+ if self.downsample is not None:
65
+ residual = self.downsample(x)
66
+
67
+ out += residual
68
+ out = self.relu(out)
69
+
70
+ return out
71
+
72
+ class SEResNeXt_extractor(nn.Module):
73
+ def __init__(self, block, layers, input_channels=3, cardinality=32):
74
+ super(SEResNeXt_extractor, self).__init__()
75
+ self.cardinality = cardinality
76
+ self.inplanes = 64
77
+ self.input_channels = input_channels
78
+
79
+ self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3,
80
+ bias=False)
81
+ self.bn1 = nn.BatchNorm2d(64)
82
+ self.relu = nn.ReLU(inplace=True)
83
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
84
+
85
+ self.layer1 = self._make_layer(block, 64, layers[0])
86
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
87
+
88
+ for m in self.modules():
89
+ if isinstance(m, nn.Conv2d):
90
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
91
+ m.weight.data.normal_(0, math.sqrt(2. / n))
92
+ if m.bias is not None:
93
+ m.bias.data.zero_()
94
+ elif isinstance(m, nn.BatchNorm2d):
95
+ m.weight.data.fill_(1)
96
+ m.bias.data.zero_()
97
+
98
+ def _make_layer(self, block, planes, blocks, stride=1):
99
+ downsample = None
100
+ if stride != 1 or self.inplanes != planes * block.expansion:
101
+ downsample = nn.Sequential(
102
+ nn.Conv2d(self.inplanes, planes * block.expansion,
103
+ kernel_size=1, stride=stride, bias=False),
104
+ nn.BatchNorm2d(planes * block.expansion),
105
+ )
106
+
107
+ layers = []
108
+ layers.append(block(self.inplanes, planes, self.cardinality, stride, downsample))
109
+ self.inplanes = planes * block.expansion
110
+ for i in range(1, blocks):
111
+ layers.append(block(self.inplanes, planes, self.cardinality))
112
+
113
+ return nn.Sequential(*layers)
114
+
115
+ def forward(self, x):
116
+ x = self.conv1(x)
117
+ x = self.bn1(x)
118
+ x = self.relu(x)
119
+ x = self.maxpool(x)
120
+
121
+ x = self.layer1(x)
122
+ x = self.layer2(x)
123
+
124
+ return x
125
+
126
+ def get_seresnext_extractor():
127
+ return SEResNeXt_extractor(BottleneckX_Origin, [3, 4, 6, 3], 1)
model/models.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.models as M
5
+ import math
6
+ from torch import Tensor
7
+ from torch.nn import Parameter
8
+
9
+ '''https://github.com/orashi/AlacGAN/blob/master/models/standard.py'''
10
+
11
+ def l2normalize(v, eps=1e-12):
12
+ return v / (v.norm() + eps)
13
+
14
+
15
+ class SpectralNorm(nn.Module):
16
+ def __init__(self, module, name='weight', power_iterations=1):
17
+ super(SpectralNorm, self).__init__()
18
+ self.module = module
19
+ self.name = name
20
+ self.power_iterations = power_iterations
21
+ if not self._made_params():
22
+ self._make_params()
23
+
24
+ def _update_u_v(self):
25
+ u = getattr(self.module, self.name + "_u")
26
+ v = getattr(self.module, self.name + "_v")
27
+ w = getattr(self.module, self.name + "_bar")
28
+
29
+ height = w.data.shape[0]
30
+ for _ in range(self.power_iterations):
31
+ v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
32
+ u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))
33
+
34
+ # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
35
+ sigma = u.dot(w.view(height, -1).mv(v))
36
+ setattr(self.module, self.name, w / sigma.expand_as(w))
37
+
38
+ def _made_params(self):
39
+ try:
40
+ u = getattr(self.module, self.name + "_u")
41
+ v = getattr(self.module, self.name + "_v")
42
+ w = getattr(self.module, self.name + "_bar")
43
+ return True
44
+ except AttributeError:
45
+ return False
46
+
47
+
48
+ def _make_params(self):
49
+ w = getattr(self.module, self.name)
50
+ height = w.data.shape[0]
51
+ width = w.view(height, -1).data.shape[1]
52
+
53
+ u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
54
+ v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
55
+ u.data = l2normalize(u.data)
56
+ v.data = l2normalize(v.data)
57
+ w_bar = Parameter(w.data)
58
+
59
+ del self.module._parameters[self.name]
60
+
61
+ self.module.register_parameter(self.name + "_u", u)
62
+ self.module.register_parameter(self.name + "_v", v)
63
+ self.module.register_parameter(self.name + "_bar", w_bar)
64
+
65
+
66
+ def forward(self, *args):
67
+ self._update_u_v()
68
+ return self.module.forward(*args)
69
+
70
+ class Selayer(nn.Module):
71
+ def __init__(self, inplanes):
72
+ super(Selayer, self).__init__()
73
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
74
+ self.conv1 = nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1)
75
+ self.conv2 = nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1)
76
+ self.relu = nn.ReLU(inplace=True)
77
+ self.sigmoid = nn.Sigmoid()
78
+
79
+ def forward(self, x):
80
+ out = self.global_avgpool(x)
81
+ out = self.conv1(out)
82
+ out = self.relu(out)
83
+ out = self.conv2(out)
84
+ out = self.sigmoid(out)
85
+
86
+ return x * out
87
+
88
+ class SelayerSpectr(nn.Module):
89
+ def __init__(self, inplanes):
90
+ super(SelayerSpectr, self).__init__()
91
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
92
+ self.conv1 = SpectralNorm(nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1))
93
+ self.conv2 = SpectralNorm(nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1))
94
+ self.relu = nn.ReLU(inplace=True)
95
+ self.sigmoid = nn.Sigmoid()
96
+
97
+ def forward(self, x):
98
+ out = self.global_avgpool(x)
99
+ out = self.conv1(out)
100
+ out = self.relu(out)
101
+ out = self.conv2(out)
102
+ out = self.sigmoid(out)
103
+
104
+ return x * out
105
+
106
+ class ResNeXtBottleneck(nn.Module):
107
+ def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
108
+ super(ResNeXtBottleneck, self).__init__()
109
+ D = out_channels // 2
110
+ self.out_channels = out_channels
111
+ self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False)
112
+ self.conv_conv = nn.Conv2d(D, D, kernel_size=2 + stride, stride=stride, padding=dilate, dilation=dilate,
113
+ groups=cardinality,
114
+ bias=False)
115
+ self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
116
+ self.shortcut = nn.Sequential()
117
+ if stride != 1:
118
+ self.shortcut.add_module('shortcut',
119
+ nn.AvgPool2d(2, stride=2))
120
+
121
+ self.selayer = Selayer(out_channels)
122
+
123
+ def forward(self, x):
124
+ bottleneck = self.conv_reduce.forward(x)
125
+ bottleneck = F.leaky_relu(bottleneck, 0.2, True)
126
+ bottleneck = self.conv_conv.forward(bottleneck)
127
+ bottleneck = F.leaky_relu(bottleneck, 0.2, True)
128
+ bottleneck = self.conv_expand.forward(bottleneck)
129
+ bottleneck = self.selayer(bottleneck)
130
+
131
+ x = self.shortcut.forward(x)
132
+ return x + bottleneck
133
+
134
+ class SpectrResNeXtBottleneck(nn.Module):
135
+ def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
136
+ super(SpectrResNeXtBottleneck, self).__init__()
137
+ D = out_channels // 2
138
+ self.out_channels = out_channels
139
+ self.conv_reduce = SpectralNorm(nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False))
140
+ self.conv_conv = SpectralNorm(nn.Conv2d(D, D, kernel_size=2 + stride, stride=stride, padding=dilate, dilation=dilate,
141
+ groups=cardinality,
142
+ bias=False))
143
+ self.conv_expand = SpectralNorm(nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False))
144
+ self.shortcut = nn.Sequential()
145
+ if stride != 1:
146
+ self.shortcut.add_module('shortcut',
147
+ nn.AvgPool2d(2, stride=2))
148
+
149
+ self.selayer = SelayerSpectr(out_channels)
150
+
151
+ def forward(self, x):
152
+ bottleneck = self.conv_reduce.forward(x)
153
+ bottleneck = F.leaky_relu(bottleneck, 0.2, True)
154
+ bottleneck = self.conv_conv.forward(bottleneck)
155
+ bottleneck = F.leaky_relu(bottleneck, 0.2, True)
156
+ bottleneck = self.conv_expand.forward(bottleneck)
157
+ bottleneck = self.selayer(bottleneck)
158
+
159
+ x = self.shortcut.forward(x)
160
+ return x + bottleneck
161
+
162
+ class FeatureConv(nn.Module):
163
+ def __init__(self, input_dim=512, output_dim=512):
164
+ super(FeatureConv, self).__init__()
165
+
166
+ no_bn = True
167
+
168
+ seq = []
169
+ seq.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=1, padding=1, bias=False))
170
+ if not no_bn: seq.append(nn.BatchNorm2d(output_dim))
171
+ seq.append(nn.ReLU(inplace=True))
172
+ seq.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
173
+ if not no_bn: seq.append(nn.BatchNorm2d(output_dim))
174
+ seq.append(nn.ReLU(inplace=True))
175
+ seq.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=1, padding=1, bias=False))
176
+ seq.append(nn.ReLU(inplace=True))
177
+
178
+ self.network = nn.Sequential(*seq)
179
+
180
+ def forward(self, x):
181
+ return self.network(x)
182
+
183
+ class Generator(nn.Module):
184
+ def __init__(self, ngf=64):
185
+ super(Generator, self).__init__()
186
+
187
+ self.feature_conv = FeatureConv()
188
+
189
+ self.to0 = self._make_encoder_block_first(6, 32)
190
+ self.to1 = self._make_encoder_block(32, 64)
191
+ self.to2 = self._make_encoder_block(64, 128)
192
+ self.to3 = self._make_encoder_block(128, 256)
193
+ self.to4 = self._make_encoder_block(256, 512)
194
+
195
+ self.deconv_for_decoder = nn.Sequential(
196
+ nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1), # output is 64 * 64
197
+ nn.LeakyReLU(0.2),
198
+ nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1), # output is 128 * 128
199
+ nn.LeakyReLU(0.2),
200
+ nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1), # output is 256 * 256
201
+ nn.LeakyReLU(0.2),
202
+ nn.ConvTranspose2d(32, 3, 3, stride=1, padding=1, output_padding=0), # output is 256 * 256
203
+ nn.Tanh(),
204
+ )
205
+
206
+ tunnel4 = nn.Sequential(*[ResNeXtBottleneck(ngf * 8, ngf * 8, cardinality=32, dilate=1) for _ in range(20)])
207
+
208
+ self.tunnel4 = nn.Sequential(nn.Conv2d(ngf * 8 + 512, ngf * 8, kernel_size=3, stride=1, padding=1),
209
+ nn.LeakyReLU(0.2, True),
210
+ tunnel4,
211
+ nn.Conv2d(ngf * 8, ngf * 4 * 4, kernel_size=3, stride=1, padding=1),
212
+ nn.PixelShuffle(2),
213
+ nn.LeakyReLU(0.2, True)
214
+ ) # 64
215
+
216
+ depth = 2
217
+ tunnel = [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=1) for _ in range(depth)]
218
+ tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=2) for _ in range(depth)]
219
+ tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=4) for _ in range(depth)]
220
+ tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=2),
221
+ ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=1)]
222
+ tunnel3 = nn.Sequential(*tunnel)
223
+
224
+ self.tunnel3 = nn.Sequential(nn.Conv2d(ngf * 8, ngf * 4, kernel_size=3, stride=1, padding=1),
225
+ nn.LeakyReLU(0.2, True),
226
+ tunnel3,
227
+ nn.Conv2d(ngf * 4, ngf * 2 * 4, kernel_size=3, stride=1, padding=1),
228
+ nn.PixelShuffle(2),
229
+ nn.LeakyReLU(0.2, True)
230
+ ) # 128
231
+
232
+ tunnel = [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=1) for _ in range(depth)]
233
+ tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=2) for _ in range(depth)]
234
+ tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=4) for _ in range(depth)]
235
+ tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=2),
236
+ ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=1)]
237
+ tunnel2 = nn.Sequential(*tunnel)
238
+
239
+ self.tunnel2 = nn.Sequential(nn.Conv2d(ngf * 4, ngf * 2, kernel_size=3, stride=1, padding=1),
240
+ nn.LeakyReLU(0.2, True),
241
+ tunnel2,
242
+ nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=1, padding=1),
243
+ nn.PixelShuffle(2),
244
+ nn.LeakyReLU(0.2, True)
245
+ )
246
+
247
+ tunnel = [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=1)]
248
+ tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=2)]
249
+ tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=4)]
250
+ tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=2),
251
+ ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=1)]
252
+ tunnel1 = nn.Sequential(*tunnel)
253
+
254
+ self.tunnel1 = nn.Sequential(nn.Conv2d(ngf * 2, ngf, kernel_size=3, stride=1, padding=1),
255
+ nn.LeakyReLU(0.2, True),
256
+ tunnel1,
257
+ nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=1, padding=1),
258
+ nn.PixelShuffle(2),
259
+ nn.LeakyReLU(0.2, True)
260
+ )
261
+
262
+ self.exit = nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)
263
+
264
+
265
+ def _make_encoder_block(self, inplanes, planes):
266
+ return nn.Sequential(
267
+ nn.Conv2d(inplanes, planes, 3, 2, 1),
268
+ nn.LeakyReLU(0.2),
269
+ nn.Conv2d(planes, planes, 3, 1, 1),
270
+ nn.LeakyReLU(0.2),
271
+ )
272
+
273
+ def _make_encoder_block_first(self, inplanes, planes):
274
+ return nn.Sequential(
275
+ nn.Conv2d(inplanes, planes, 3, 1, 1),
276
+ nn.LeakyReLU(0.2),
277
+ nn.Conv2d(planes, planes, 3, 1, 1),
278
+ nn.LeakyReLU(0.2),
279
+ )
280
+
281
+ def forward(self, sketch, sketch_feat):
282
+
283
+ x0 = self.to0(sketch)
284
+ x1 = self.to1(x0)
285
+ x2 = self.to2(x1)
286
+ x3 = self.to3(x2) # !
287
+ x4 = self.to4(x3)
288
+
289
+ sketch_feat = self.feature_conv(sketch_feat)
290
+
291
+ out = self.tunnel4(torch.cat([x4, sketch_feat], 1))
292
+
293
+
294
+
295
+
296
+ x = self.tunnel3(torch.cat([out, x3], 1))
297
+ x = self.tunnel2(torch.cat([x, x2], 1))
298
+ x = self.tunnel1(torch.cat([x, x1], 1))
299
+ x = torch.tanh(self.exit(torch.cat([x, x0], 1)))
300
+
301
+ decoder_output = self.deconv_for_decoder(out)
302
+
303
+ return x, decoder_output
304
+ '''
305
+ class Colorizer(nn.Module):
306
+ def __init__(self, extractor_path = 'model/model.pth'):
307
+ super(Colorizer, self).__init__()
308
+
309
+ self.generator = Generator()
310
+ self.extractor = se_resnext_half(dump_path=extractor_path, num_classes=370, input_channels=1)
311
+
312
+ def extractor_eval(self):
313
+ for param in self.extractor.parameters():
314
+ param.requires_grad = False
315
+
316
+ def extractor_train(self):
317
+ for param in extractor.parameters():
318
+ param.requires_grad = True
319
+
320
+ def forward(self, x, extractor_grad = False):
321
+
322
+ if extractor_grad:
323
+ features = self.extractor(x[:, 0:1])
324
+ else:
325
+ with torch.no_grad():
326
+ features = self.extractor(x[:, 0:1]).detach()
327
+
328
+ fake, guide = self.generator(x, features)
329
+
330
+ return fake, guide
331
+ '''
332
+
333
+ class Colorizer(nn.Module):
334
+ def __init__(self, generator_model, extractor_model):
335
+ super(Colorizer, self).__init__()
336
+
337
+ self.generator = generator_model
338
+ self.extractor = extractor_model
339
+
340
+ def load_generator_weights(self, gen_weights):
341
+ self.generator.load_state_dict(gen_weights)
342
+
343
+ def load_extractor_weights(self, ext_weights):
344
+ self.extractor.load_state_dict(ext_weights)
345
+
346
+ def extractor_eval(self):
347
+ for param in self.extractor.parameters():
348
+ param.requires_grad = False
349
+ self.extractor.eval()
350
+
351
+ def extractor_train(self):
352
+ for param in extractor.parameters():
353
+ param.requires_grad = True
354
+ self.extractor.train()
355
+
356
+ def forward(self, x, extractor_grad = False):
357
+
358
+ if extractor_grad:
359
+ features = self.extractor(x[:, 0:1])
360
+ else:
361
+ with torch.no_grad():
362
+ features = self.extractor(x[:, 0:1]).detach()
363
+
364
+ fake, guide = self.generator(x, features)
365
+
366
+ return fake, guide
367
+
368
+ class Discriminator(nn.Module):
369
+ def __init__(self, ndf=64):
370
+ super(Discriminator, self).__init__()
371
+
372
+ self.feed = nn.Sequential(SpectralNorm(nn.Conv2d(3, 64, 3, 1, 1)),
373
+ nn.LeakyReLU(0.2, True),
374
+ SpectralNorm(nn.Conv2d(64, 64, 3, 2, 0)),
375
+ nn.LeakyReLU(0.2, True),
376
+
377
+
378
+
379
+
380
+ SpectrResNeXtBottleneck(ndf, ndf, cardinality=8, dilate=1),
381
+ SpectrResNeXtBottleneck(ndf, ndf, cardinality=8, dilate=1, stride=2), # 128
382
+ SpectralNorm(nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=False)),
383
+ nn.LeakyReLU(0.2, True),
384
+
385
+ SpectrResNeXtBottleneck(ndf * 2, ndf * 2, cardinality=8, dilate=1),
386
+ SpectrResNeXtBottleneck(ndf * 2, ndf * 2, cardinality=8, dilate=1, stride=2), # 64
387
+ SpectralNorm(nn.Conv2d(ndf * 2, ndf * 4, kernel_size=1, stride=1, padding=0, bias=False)),
388
+ nn.LeakyReLU(0.2, True),
389
+
390
+ SpectrResNeXtBottleneck(ndf * 4, ndf * 4, cardinality=8, dilate=1),
391
+ SpectrResNeXtBottleneck(ndf * 4, ndf * 4, cardinality=8, dilate=1, stride=2), # 32,
392
+ SpectralNorm(nn.Conv2d(ndf * 4, ndf * 8, kernel_size=1, stride=1, padding=1, bias=False)),
393
+ nn.LeakyReLU(0.2, True),
394
+ SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
395
+ SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1, stride=2), # 16
396
+ SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
397
+ SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
398
+ nn.AdaptiveAvgPool2d((1, 1))
399
+ )
400
+
401
+ self.out = nn.Linear(512, 1)
402
+
403
+ def forward(self, color):
404
+ x = self.feed(color)
405
+
406
+ out = self.out(x.view(color.size(0), -1))
407
+ return out
408
+
409
+ class Content(nn.Module):
410
+ def __init__(self, path):
411
+ super(Content, self).__init__()
412
+ vgg16 = M.vgg16()
413
+ vgg16.load_state_dict(torch.load(path))
414
+ vgg16.features = nn.Sequential(
415
+ *list(vgg16.features.children())[:9]
416
+ )
417
+ self.model = vgg16.features
418
+ self.register_buffer('mean', torch.FloatTensor([0.485 - 0.5, 0.456 - 0.5, 0.406 - 0.5]).view(1, 3, 1, 1))
419
+ self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
420
+
421
+ def forward(self, images):
422
+ return self.model((images.mul(0.5) - self.mean) / self.std)
model/vgg16-397923af.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:397923af8e79cdbb6a7127f12361acd7a2f83e06b05044ddf496e83de57a5bf0
3
+ size 553433881
train.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ import numpy as np
5
+ import albumentations as albu
6
+ import argparse
7
+ import datetime
8
+
9
+ from utils.utils import open_json, weights_init, weights_init_spectr, generate_mask
10
+ from model.models import Colorizer, Generator, Content, Discriminator
11
+ from model.extractor import get_seresnext_extractor
12
+ from dataset.datasets import TrainDataset, FineTuningDataset
13
+
14
+
15
+
16
+ def parse_args():
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument("-p", "--path", required=True, help = "dataset path")
19
+ parser.add_argument('-ft', '--fine_tuning', dest = 'fine_tuning', action = 'store_true')
20
+ parser.add_argument('-g', '--gpu', dest = 'gpu', action = 'store_true')
21
+ parser.set_defaults(fine_tuning = False)
22
+ parser.set_defaults(gpu = False)
23
+ args = parser.parse_args()
24
+
25
+ return args
26
+
27
+ def get_transforms():
28
+ return albu.Compose([albu.RandomCrop(512, 512, always_apply = True), albu.HorizontalFlip(p = 0.5)], p = 1.)
29
+
30
+ def get_dataloaders(data_path, transforms, batch_size, fine_tuning, mult_number):
31
+ train_dataset = TrainDataset(data_path, transforms, mult_number)
32
+ train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
33
+
34
+ if fine_tuning:
35
+ finetuning_dataset = FineTuningDataset(data_path, transforms)
36
+ finetuning_dataloader = torch.utils.data.DataLoader(finetuning_dataset, batch_size = batch_size, shuffle = True)
37
+
38
+ return train_dataloader, finetuning_dataloader
39
+
40
+ def get_models(device):
41
+ generator = Generator()
42
+ extractor = get_seresnext_extractor()
43
+ colorizer = Colorizer(generator, extractor)
44
+
45
+ colorizer.extractor_eval()
46
+ colorizer = colorizer.to(device)
47
+
48
+ discriminator = Discriminator().to(device)
49
+
50
+ content = Content('model/vgg16-397923af.pth').eval().to(device)
51
+ for param in content.parameters():
52
+ param.requires_grad = False
53
+
54
+ return colorizer, discriminator, content
55
+
56
+ def set_weights(colorizer, discriminator):
57
+ colorizer.generator.apply(weights_init)
58
+ colorizer.load_extractor_weights(torch.load('model/extractor.pth'))
59
+
60
+ discriminator.apply(weights_init_spectr)
61
+
62
+ def generator_loss(disc_output, true_labels, main_output, guide_output, real_image, content_gen, content_true, dist_loss = nn.L1Loss(), content_dist_loss = nn.MSELoss(), class_loss = nn.BCEWithLogitsLoss()):
63
+ sim_loss_full = dist_loss(main_output, real_image)
64
+ sim_loss_guide = dist_loss(guide_output, real_image)
65
+
66
+ adv_loss = class_loss(disc_output, true_labels)
67
+
68
+ content_loss = content_dist_loss(content_gen, content_true)
69
+
70
+ sum_loss = 10 * (sim_loss_full + 0.9 * sim_loss_guide) + adv_loss + content_loss
71
+
72
+ return sum_loss
73
+
74
+ def get_optimizers(colorizer, discriminator, generator_lr, discriminator_lr):
75
+ optimizerG = optim.Adam(colorizer.generator.parameters(), lr = generator_lr, betas=(0.5, 0.9))
76
+ optimizerD = optim.Adam(discriminator.parameters(), lr = discriminator_lr, betas=(0.5, 0.9))
77
+
78
+ return optimizerG, optimizerD
79
+
80
+ def generator_step(inputs, colorizer, discriminator, content, loss_function, optimizer, device, white_penalty = True):
81
+ for p in discriminator.parameters():
82
+ p.requires_grad = False
83
+ for p in colorizer.generator.parameters():
84
+ p.requires_grad = True
85
+
86
+ colorizer.generator.zero_grad()
87
+
88
+ bw, color, hint, dfm = inputs
89
+ bw, color, hint, dfm = bw.to(device), color.to(device), hint.to(device), dfm.to(device)
90
+
91
+ fake, guide = colorizer(torch.cat([bw, dfm, hint], 1))
92
+
93
+ logits_fake = discriminator(fake)
94
+ y_real = torch.ones((bw.size(0), 1), device = device)
95
+
96
+ content_fake = content(fake)
97
+ with torch.no_grad():
98
+ content_true = content(color)
99
+
100
+ generator_loss = loss_function(logits_fake, y_real, fake, guide, color, content_fake, content_true)
101
+
102
+ if white_penalty:
103
+ mask = (~((color > 0.85).float().sum(dim = 1) == 3).unsqueeze(1).repeat((1, 3, 1, 1 ))).float()
104
+ white_zones = mask * (fake + 1) / 2
105
+ white_penalty = (torch.pow(white_zones.sum(dim = 1), 2).sum(dim = (1, 2)) / (mask.sum(dim = (1, 2, 3)) + 1)).mean()
106
+
107
+ generator_loss += white_penalty
108
+
109
+ generator_loss.backward()
110
+
111
+ optimizer.step()
112
+
113
+ return generator_loss.item()
114
+
115
+ def discriminator_step(inputs, colorizer, discriminator, optimizer, device, loss_function = nn.BCEWithLogitsLoss()):
116
+
117
+ for p in discriminator.parameters():
118
+ p.requires_grad = True
119
+ for p in colorizer.generator.parameters():
120
+ p.requires_grad = False
121
+
122
+ discriminator.zero_grad()
123
+
124
+ bw, color, hint, dfm = inputs
125
+ bw, color, hint, dfm = bw.to(device), color.to(device), hint.to(device), dfm.to(device)
126
+
127
+ y_real = torch.full((bw.size(0), 1), 0.9, device = device)
128
+
129
+ y_fake = torch.zeros((bw.size(0), 1), device = device)
130
+
131
+ with torch.no_grad():
132
+ fake_color, _ = colorizer(torch.cat([bw, dfm, hint], 1))
133
+ fake_color.detach()
134
+
135
+ logits_fake = discriminator(fake_color)
136
+ logits_real = discriminator(color)
137
+
138
+ fake_loss = loss_function(logits_fake, y_fake)
139
+ real_loss = loss_function(logits_real, y_real)
140
+
141
+ discriminator_loss = real_loss + fake_loss
142
+
143
+ discriminator_loss.backward()
144
+ optimizer.step()
145
+
146
+ return discriminator_loss.item()
147
+
148
+ def decrease_lr(optimizer, rate):
149
+ for group in optimizer.param_groups:
150
+ group['lr'] /= rate
151
+
152
+ def set_lr(optimizer, value):
153
+ for group in optimizer.param_groups:
154
+ group['lr'] = value
155
+
156
+ def train(colorizer, discriminator, content, dataloader, epochs, colorizer_optimizer, discriminator_optimizer, lr_decay_epoch = -1, device = 'cpu'):
157
+ colorizer.generator.train()
158
+ discriminator.train()
159
+
160
+ disc_step = True
161
+
162
+ for epoch in range(epochs):
163
+ if (epoch == lr_decay_epoch):
164
+ decrease_lr(colorizer_optimizer, 10)
165
+ decrease_lr(discriminator_optimizer, 10)
166
+
167
+ sum_disc_loss = 0
168
+ sum_gen_loss = 0
169
+
170
+ for n, inputs in enumerate(dataloader):
171
+ if n % 5 == 0:
172
+ print(datetime.datetime.now().time())
173
+ print('Step : %d Discr loss: %.4f Gen loss : %.4f \n'%(n, sum_disc_loss / (n // 2 + 1), sum_gen_loss / (n // 2 + 1)))
174
+
175
+
176
+ if disc_step:
177
+ step_loss = discriminator_step(inputs, colorizer, discriminator, discriminator_optimizer, device)
178
+ sum_disc_loss += step_loss
179
+ else:
180
+ step_loss = generator_step(inputs, colorizer, discriminator, content, generator_loss, colorizer_optimizer, device)
181
+ sum_gen_loss += step_loss
182
+
183
+ disc_step = disc_step ^ True
184
+
185
+
186
+ print(datetime.datetime.now().time())
187
+ print('Epoch : %d Discr loss: %.4f Gen loss : %.4f \n'%(epoch, sum_disc_loss / (n // 2 + 1), sum_gen_loss / (n // 2 + 1)))
188
+
189
+
190
+ def fine_tuning_step(data_iter, colorizer, discriminator, gen_optimizer, disc_optimizer, device, loss_function = nn.BCEWithLogitsLoss()):
191
+
192
+ for p in discriminator.parameters():
193
+ p.requires_grad = True
194
+ for p in colorizer.generator.parameters():
195
+ p.requires_grad = False
196
+
197
+ for cur_disc_step in range(5):
198
+ discriminator.zero_grad()
199
+
200
+ bw, dfm, color_for_real = data_iter.next()
201
+ bw, dfm, color_for_real = bw.to(device), dfm.to(device), color_for_real.to(device)
202
+
203
+ y_real = torch.full((bw.size(0), 1), 0.9, device = device)
204
+ y_fake = torch.zeros((bw.size(0), 1), device = device)
205
+
206
+ empty_hint = torch.zeros(bw.shape[0], 4, bw.shape[2] , bw.shape[3] ).float().to(device)
207
+
208
+ with torch.no_grad():
209
+ fake_color_manga, _ = colorizer(torch.cat([bw, dfm, empty_hint ], 1))
210
+ fake_color_manga.detach()
211
+
212
+ logits_fake = discriminator(fake_color_manga)
213
+ logits_real = discriminator(color_for_real)
214
+
215
+ fake_loss = loss_function(logits_fake, y_fake)
216
+ real_loss = loss_function(logits_real, y_real)
217
+ discriminator_loss = real_loss + fake_loss
218
+
219
+ discriminator_loss.backward()
220
+ disc_optimizer.step()
221
+
222
+
223
+ for p in discriminator.parameters():
224
+ p.requires_grad = False
225
+ for p in colorizer.generator.parameters():
226
+ p.requires_grad = True
227
+
228
+ colorizer.generator.zero_grad()
229
+
230
+ bw, dfm, _ = data_iter.next()
231
+ bw, dfm = bw.to(device), dfm.to(device)
232
+
233
+ y_real = torch.ones((bw.size(0), 1), device = device)
234
+
235
+ empty_hint = torch.zeros(bw.shape[0], 4, bw.shape[2] , bw.shape[3]).float().to(device)
236
+
237
+ fake_manga, _ = colorizer(torch.cat([bw, dfm, empty_hint], 1))
238
+
239
+ logits_fake = discriminator(fake_manga)
240
+ adv_loss = loss_function(logits_fake, y_real)
241
+
242
+ generator_loss = adv_loss
243
+
244
+ generator_loss.backward()
245
+ gen_optimizer.step()
246
+
247
+
248
+
249
+ def fine_tuning(colorizer, discriminator, content, dataloader, iterations, colorizer_optimizer, discriminator_optimizer, data_iter, device = 'cpu'):
250
+ colorizer.generator.train()
251
+ discriminator.train()
252
+
253
+ disc_step = True
254
+
255
+ for n, inputs in enumerate(dataloader):
256
+
257
+ if n == iterations:
258
+ return
259
+
260
+ if disc_step:
261
+ discriminator_step(inputs, colorizer, discriminator, discriminator_optimizer, device)
262
+ else:
263
+ generator_step(inputs, colorizer, discriminator, content, generator_loss, colorizer_optimizer, device)
264
+
265
+ disc_step = disc_step ^ True
266
+
267
+ if n % 10 == 5:
268
+ fine_tuning_step(data_iter, colorizer, discriminator, colorizer_optimizer, discriminator_optimizer, device)
269
+
270
+ if __name__ == '__main__':
271
+ args = parse_args()
272
+ config = open_json('configs/train_config.json')
273
+
274
+ if args.gpu:
275
+ device = 'cuda'
276
+ else:
277
+ device = 'cpu'
278
+
279
+ augmentations = get_transforms()
280
+
281
+ train_dataloader, ft_dataloader = get_dataloaders(args.path, augmentations, config['batch_size'], args.fine_tuning, config['number_of_mults'])
282
+
283
+ colorizer, discriminator, content = get_models(device)
284
+ set_weights(colorizer, discriminator)
285
+
286
+ gen_optimizer, disc_optimizer = get_optimizers(colorizer, discriminator, config['generator_lr'], config['discriminator_lr'])
287
+
288
+ train(colorizer, discriminator, content, train_dataloader, config['epochs'], gen_optimizer, disc_optimizer, config['lr_decrease_epoch'], device)
289
+
290
+ if args.fine_tuning:
291
+ set_lr(gen_optimizer, config["finetuning_generator_lr"])
292
+ fine_tuning(colorizer, discriminator, content, train_dataloader, config['finetuning_iterations'], gen_optimizer, disc_optimizer, iter(ft_dataloader), device)
293
+
294
+ torch.save(colorizer.generator.state_dict(), str(datetime.datetime.now().time()))
utils/dataset_utils.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ import cv2
4
+ import snowy
5
+ import os
6
+
7
+
8
+ def get_resized_image(img, size):
9
+ if len(img.shape) == 2:
10
+ img = np.repeat(np.expand_dims(img, 2), 3, 2)
11
+
12
+ if (img.shape[0] < img.shape[1]):
13
+ height = img.shape[0]
14
+ ratio = height / size
15
+ width = int(np.ceil(img.shape[1] / ratio))
16
+ img = cv2.resize(img, (width, size), interpolation = cv2.INTER_AREA)
17
+ else:
18
+ width = img.shape[1]
19
+ ratio = width / size
20
+ height = int(np.ceil(img.shape[0] / ratio))
21
+ img = cv2.resize(img, (size, height), interpolation = cv2.INTER_AREA)
22
+
23
+ if (img.dtype == 'float32'):
24
+ np.clip(img, 0, 1, out = img)
25
+
26
+ return img
27
+
28
+
29
+ def get_sketch_image(img, sketcher, mult_val):
30
+
31
+ if mult_val:
32
+ sketch_image = sketcher.get_sketch_with_resize(img, mult = mult_val)
33
+ else:
34
+ sketch_image = sketcher.get_sketch_with_resize(img)
35
+
36
+ return sketch_image
37
+
38
+
39
+ def get_dfm_image(sketch):
40
+ dfm_image = snowy.unitize(snowy.generate_sdf(np.expand_dims(1 - sketch, 2) != 0)).squeeze()
41
+ return dfm_image
42
+
43
+ def get_sketch(image, sketcher, dfm, mult = None):
44
+ sketch_image = get_sketch_image(image, sketcher, mult)
45
+
46
+ dfm_image = None
47
+
48
+ if dfm:
49
+ dfm_image = get_dfm_image(sketch_image)
50
+
51
+ sketch_image = (sketch_image * 255).astype('uint8')
52
+
53
+ if dfm:
54
+ dfm_image = (dfm_image * 255).astype('uint8')
55
+
56
+ return sketch_image, dfm_image
57
+
58
+ def get_sketches(image, sketcher, mult_list, dfm):
59
+ for mult in mult_list:
60
+ yield get_sketch(image, sketcher, dfm, mult)
61
+
62
+
63
+ def create_resized_dataset(source_path, target_path, side_size):
64
+ images = os.listdir(source_path)
65
+
66
+ for image_name in images:
67
+
68
+ new_image_name = image_name[:image_name.rfind('.')] + '.png'
69
+ new_path = os.path.join(target_path, new_image_name)
70
+
71
+ if not os.path.exists(new_path):
72
+ try:
73
+ image = cv2.imread(os.path.join(source_path, image_name))
74
+
75
+ if image is None:
76
+ raise Exception()
77
+
78
+ image = get_resized_image(image, side_size)
79
+
80
+ cv2.imwrite(new_path, image)
81
+ except:
82
+ print('Failed to process {}'.format(image_name))
83
+
84
+
85
+ def create_sketches_dataset(source_path, target_path, sketcher, mult_list, dfm = False):
86
+
87
+ images = os.listdir(source_path)
88
+ for image_name in images:
89
+ try:
90
+ image = cv2.imread(os.path.join(source_path, image_name))
91
+
92
+ if image is None:
93
+ raise Exception()
94
+
95
+ for number, (sketch_image, dfm_image) in enumerate(get_sketches(image, sketcher, mult_list, dfm)):
96
+ new_sketch_name = image_name[:image_name.rfind('.')] + '_' + str(number) + '.png'
97
+ cv2.imwrite(os.path.join(target_path, new_sketch_name), sketch_image)
98
+
99
+ if dfm:
100
+ dfm_name = image_name[:image_name.rfind('.')] + '_' + str(number) + '_dfm.png'
101
+ cv2.imwrite(os.path.join(target_path, dfm_name), dfm_image)
102
+
103
+ except:
104
+ print('Failed to process {}'.format(image_name))
105
+
106
+
107
+ def create_dataset(source_path, target_path, sketcher, mult_list, side_size, dfm = False):
108
+ images = os.listdir(source_path)
109
+
110
+ color_path = os.path.join(target_path, 'color')
111
+ sketch_path = os.path.join(target_path, 'bw')
112
+
113
+ if not os.path.exists(color_path):
114
+ os.makedirs(color_path)
115
+
116
+ if not os.path.exists(sketch_path):
117
+ os.makedirs(sketch_path)
118
+
119
+ for image_name in images:
120
+ new_image_name = image_name[:image_name.rfind('.')] + '.png'
121
+
122
+ try:
123
+ image = cv2.imread(os.path.join(source_path, image_name))
124
+
125
+ if image is None:
126
+ raise Exception()
127
+
128
+ resized_image = get_resized_image(image, side_size)
129
+ cv2.imwrite(os.path.join(color_path, new_image_name), resized_image)
130
+
131
+ for number, (sketch_image, dfm_image) in enumerate(get_sketches(resized_image, sketcher, mult_list, dfm)):
132
+ new_sketch_name = image_name[:image_name.rfind('.')] + '_' + str(number) + '.png'
133
+ cv2.imwrite(os.path.join(sketch_path, new_sketch_name), sketch_image)
134
+
135
+ if dfm:
136
+ dfm_name = image_name[:image_name.rfind('.')] + '_' + str(number) + '_dfm.png'
137
+ cv2.imwrite(os.path.join(sketch_path, dfm_name), dfm_image)
138
+
139
+ except:
140
+ print('Failed to process {}'.format(image_name))
141
+
utils/utils.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import scipy.stats as stats
5
+ import cv2
6
+ import json
7
+ import patoolib
8
+ import re
9
+ from pathlib import Path
10
+ from shutil import rmtree
11
+
12
+ def weights_init(m):
13
+ classname = m.__class__.__name__
14
+ if classname.find('Conv2d') != -1:
15
+ nn.init.xavier_uniform_(m.weight.data)
16
+
17
+ def weights_init_spectr(m):
18
+ classname = m.__class__.__name__
19
+ if classname.find('Conv2d') != -1:
20
+ nn.init.xavier_uniform_(m.weight_bar.data)
21
+
22
+ def generate_mask(height, width, mu = 1, sigma = 0.0005, prob = 0.5, full = True, full_prob = 0.01):
23
+ X = stats.truncnorm((0 - mu) / sigma, (1 - mu) / sigma, loc=mu, scale=sigma)
24
+
25
+ if full:
26
+ if (np.random.binomial(1, p = full_prob) == 1):
27
+ return torch.ones(1, height, width).float()
28
+
29
+ if np.random.binomial(1, p = prob) == 1:
30
+ mask = torch.rand(1, height, width).ge(X.rvs(1)[0]).float()
31
+ else:
32
+ mask = torch.zeros(1, height, width).float()
33
+
34
+ return mask
35
+
36
+ def resize_pad(img, size = 512):
37
+
38
+ if len(img.shape) == 2:
39
+ img = np.expand_dims(img, 2)
40
+
41
+ if img.shape[2] == 1:
42
+ img = np.repeat(img, 3, 2)
43
+
44
+ if img.shape[2] == 4:
45
+ img = img[:, :, :3]
46
+
47
+ pad = None
48
+
49
+ if (img.shape[0] < img.shape[1]):
50
+ height = img.shape[0]
51
+ ratio = height / size
52
+ width = int(np.ceil(img.shape[1] / ratio))
53
+ img = cv2.resize(img, (width, size), interpolation = cv2.INTER_AREA)
54
+
55
+ new_width = width
56
+ while (new_width % 32 != 0):
57
+ new_width += 1
58
+
59
+ pad = (0, new_width - width)
60
+
61
+ img = np.pad(img, ((0, 0), (0, pad[1]), (0, 0)), 'maximum')
62
+ else:
63
+ width = img.shape[1]
64
+ ratio = width / size
65
+ height = int(np.ceil(img.shape[0] / ratio))
66
+ img = cv2.resize(img, (size, height), interpolation = cv2.INTER_AREA)
67
+
68
+ new_height = height
69
+ while (new_height % 32 != 0):
70
+ new_height += 1
71
+
72
+ pad = (new_height - height, 0)
73
+
74
+ img = np.pad(img, ((0, pad[0]), (0, 0), (0, 0)), 'maximum')
75
+
76
+ if (img.dtype == 'float32'):
77
+ np.clip(img, 0, 1, out = img)
78
+
79
+ return img, pad
80
+
81
+ def open_json(file):
82
+ with open(file) as json_file:
83
+ data = json.load(json_file)
84
+
85
+ return data
86
+
87
+ def extract_cbr(file, out_dir):
88
+ patoolib.extract_archive(file, outdir = out_dir, verbosity = 1)
89
+
90
+ def create_cbz(file_path, files):
91
+ patoolib.create_archive(file_path, files, verbosity = 1)
92
+
93
+ def subfolder_image_search(start_folder):
94
+ return [x.as_posix() for x in Path(".").rglob("*.[pPjJ][nNpP][gG]")]
95
+
96
+ def remove_folder(folder_path):
97
+ rmtree(folder_path)
98
+
99
+ def sorted_alphanumeric(data):
100
+ convert = lambda text: int(text) if text.isdigit() else text.lower()
101
+ alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ]
102
+ return sorted(data, key=alphanum_key)
utils/xdog.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from cv2 import resize, INTER_LANCZOS4, INTER_AREA
2
+ from skimage.color import rgb2gray
3
+ import numpy as np
4
+ from scipy.ndimage.filters import gaussian_filter
5
+ from skimage.filters import threshold_otsu
6
+ import matplotlib.pyplot as plt
7
+
8
+ class XDoGSketcher:
9
+
10
+ def __init__(self, gamma = 0.95, phi = 89.25, eps = -0.1, k = 8, sigma = 0.5, mult = 1):
11
+ self.params = {}
12
+ self.params['gamma'] = gamma
13
+ self.params['phi'] = phi
14
+ self.params['eps'] = eps
15
+ self.params['k'] = k
16
+ self.params['sigma'] = sigma
17
+
18
+ self.params['mult'] = mult
19
+
20
+ def _xdog(self, im, **transform_params):
21
+ # Source : https://github.com/CemalUnal/XDoG-Filter
22
+ # Reference : XDoG: An eXtended difference-of-Gaussians compendium including advanced image stylization
23
+ # Link : http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.365.151&rep=rep1&type=pdf
24
+
25
+ if im.shape[2] == 3:
26
+ im = rgb2gray(im)
27
+
28
+ imf1 = gaussian_filter(im, transform_params['sigma'])
29
+ imf2 = gaussian_filter(im, transform_params['sigma'] * transform_params['k'])
30
+ imdiff = imf1 - transform_params['gamma'] * imf2
31
+ imdiff = (imdiff < transform_params['eps']) * 1.0 \
32
+ + (imdiff >= transform_params['eps']) * (1.0 + np.tanh(transform_params['phi'] * imdiff))
33
+ imdiff -= imdiff.min()
34
+ imdiff /= imdiff.max()
35
+
36
+
37
+ th = threshold_otsu(imdiff)
38
+ imdiff = imdiff >= th
39
+
40
+ imdiff = imdiff.astype('float32')
41
+
42
+ return imdiff
43
+
44
+
45
+ def get_sketch(self, image, **kwargs):
46
+ current_params = self.params.copy()
47
+
48
+ for key in kwargs.keys():
49
+ if key in current_params.keys():
50
+ current_params[key] = kwargs[key]
51
+
52
+ result_image = self._xdog(image, **current_params)
53
+
54
+ return result_image
55
+
56
+ def get_sketch_with_resize(self, image, **kwargs):
57
+ if 'mult' in kwargs.keys():
58
+ mult = kwargs['mult']
59
+ else:
60
+ mult = self.params['mult']
61
+
62
+ temp_image = resize(image, (image.shape[1] * mult, image.shape[0] * mult), interpolation = INTER_LANCZOS4)
63
+ temp_image = self.get_sketch(temp_image, **kwargs)
64
+ image = resize(temp_image, (image.shape[1], image.shape[0]), interpolation = INTER_AREA)
65
+
66
+ return image
67
+
68
+