Upload 34 files
Browse files- .gitattributes +16 -0
- .gitignore +5 -0
- configs/train_config.json +10 -0
- configs/xdog_config.json +8 -0
- dataset/datasets.py +107 -0
- dataset/manga/train/bw/002-0000-0000.png +3 -0
- dataset/manga/train/bw/003-0000-0000.png +0 -0
- dataset/manga/train/bw/x2-0000-0000.png +3 -0
- dataset/manga/train/bw/x3-0000-0000.png +3 -0
- dataset/manga/train/bw/x5-0000-0000.png +3 -0
- dataset/manga/train/color/002-0000-0000.png +3 -0
- dataset/manga/train/color/003-0000-0000.png +0 -0
- dataset/manga/train/color/004-0000-0000.png +0 -0
- dataset/manga/train/color/x1-0000-0000.png +3 -0
- dataset/manga/train/color/x2-0000-0000.png +3 -0
- dataset/manga/train/color/x3-0000-0000.png +3 -0
- dataset/manga/train/color/x4-0000-0000.png +3 -0
- dataset/manga/train/color/x5-0000-0000.png +3 -0
- dataset/manga/train/real_manga/002-0000-0000.png +3 -0
- dataset/manga/train/real_manga/003-0000-0000.png +0 -0
- dataset/manga/train/real_manga/004-0000-0000.png +0 -0
- dataset/manga/train/real_manga/x1-0000-0000.png +3 -0
- dataset/manga/train/real_manga/x2-0000-0000.png +3 -0
- dataset/manga/train/real_manga/x3-0000-0000.png +3 -0
- dataset/manga/train/real_manga/x4-0000-0000.png +3 -0
- dataset/manga/train/real_manga/x5-0000-0000.png +3 -0
- inference.py +154 -0
- model/extractor.pth +3 -0
- model/extractor.py +127 -0
- model/models.py +422 -0
- model/vgg16-397923af.pth +3 -0
- train.py +294 -0
- utils/dataset_utils.py +141 -0
- utils/utils.py +102 -0
- utils/xdog.py +68 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,19 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
dataset/manga/train/bw/002-0000-0000.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
dataset/manga/train/bw/x2-0000-0000.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
dataset/manga/train/bw/x3-0000-0000.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
dataset/manga/train/bw/x5-0000-0000.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
dataset/manga/train/color/002-0000-0000.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
dataset/manga/train/color/x1-0000-0000.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
dataset/manga/train/color/x2-0000-0000.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
dataset/manga/train/color/x3-0000-0000.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
dataset/manga/train/color/x4-0000-0000.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
dataset/manga/train/color/x5-0000-0000.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
dataset/manga/train/real_manga/002-0000-0000.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
dataset/manga/train/real_manga/x1-0000-0000.png filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
dataset/manga/train/real_manga/x2-0000-0000.png filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
dataset/manga/train/real_manga/x3-0000-0000.png filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
dataset/manga/train/real_manga/x4-0000-0000.png filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
dataset/manga/train/real_manga/x5-0000-0000.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.ipynb
|
| 2 |
+
*.pth
|
| 3 |
+
|
| 4 |
+
__pycache__/
|
| 5 |
+
temp_colorization/
|
configs/train_config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"generator_lr" : 1e-4,
|
| 3 |
+
"discriminator_lr" : 4e-4,
|
| 4 |
+
"epochs" : 15,
|
| 5 |
+
"lr_decrease_epoch" : 10,
|
| 6 |
+
"finetuning_generator_lr" : 1e-6,
|
| 7 |
+
"finetuning_iterations" : 3500,
|
| 8 |
+
"batch_size" : 4,
|
| 9 |
+
"number_of_mults" : 3
|
| 10 |
+
}
|
configs/xdog_config.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"sigma" : 0.5,
|
| 3 |
+
"k" : 8,
|
| 4 |
+
"phi" : 89.25,
|
| 5 |
+
"gamma" : 0.95,
|
| 6 |
+
"eps" : -0.1,
|
| 7 |
+
"mult" : 7
|
| 8 |
+
}
|
dataset/datasets.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
import torchvision.transforms as transforms
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from utils.utils import generate_mask
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TrainDataset(torch.utils.data.Dataset):
|
| 11 |
+
def __init__(self, data_path, transform = None, mults_amount = 1):
|
| 12 |
+
self.data = os.listdir(os.path.join(data_path, 'color'))
|
| 13 |
+
self.data_path = data_path
|
| 14 |
+
self.transform = transform
|
| 15 |
+
self.mults_amount = mults_amount
|
| 16 |
+
|
| 17 |
+
self.ToTensor = transforms.ToTensor()
|
| 18 |
+
def __len__(self):
|
| 19 |
+
return len(self.data)
|
| 20 |
+
|
| 21 |
+
def __getitem__(self, idx):
|
| 22 |
+
image_name = self.data[idx]
|
| 23 |
+
|
| 24 |
+
color_img = plt.imread(os.path.join(self.data_path, 'color', image_name))
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
if self.mults_amount > 1:
|
| 28 |
+
mult_number = np.random.choice(range(self.mults_amount))
|
| 29 |
+
|
| 30 |
+
bw_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '.png'
|
| 31 |
+
dfm_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '_dfm.png'
|
| 32 |
+
else:
|
| 33 |
+
bw_name = self.data[idx]
|
| 34 |
+
dfm_name = os.path.splitext(self.data[idx])[0] + '0_dfm.png'
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', bw_name)), 2)
|
| 38 |
+
dfm_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', dfm_name)), 2)
|
| 39 |
+
|
| 40 |
+
bw_img = np.concatenate([bw_img, dfm_img], axis = 2)
|
| 41 |
+
|
| 42 |
+
if self.transform:
|
| 43 |
+
result = self.transform(image = color_img, mask = bw_img)
|
| 44 |
+
color_img = result['image']
|
| 45 |
+
bw_img = result['mask']
|
| 46 |
+
|
| 47 |
+
dfm_img = bw_img[:, :, 1]
|
| 48 |
+
bw_img = bw_img[:, :, 0]
|
| 49 |
+
|
| 50 |
+
color_img = self.ToTensor(color_img)
|
| 51 |
+
bw_img = self.ToTensor(bw_img)
|
| 52 |
+
|
| 53 |
+
dfm_img = self.ToTensor(dfm_img)
|
| 54 |
+
|
| 55 |
+
color_img = (color_img - 0.5) / 0.5
|
| 56 |
+
|
| 57 |
+
mask = generate_mask(bw_img.shape[1], bw_img.shape[2])
|
| 58 |
+
hint = torch.cat((color_img * mask, mask), 0)
|
| 59 |
+
|
| 60 |
+
return bw_img, color_img, hint, dfm_img
|
| 61 |
+
|
| 62 |
+
class FineTuningDataset(torch.utils.data.Dataset):
|
| 63 |
+
def __init__(self, data_path, transform = None, mult_amount = 1):
|
| 64 |
+
self.data = [x for x in os.listdir(os.path.join(data_path, 'real_manga')) if x.find('_dfm') == -1]
|
| 65 |
+
self.color_data = [x for x in os.listdir(os.path.join(data_path, 'color'))]
|
| 66 |
+
self.data_path = data_path
|
| 67 |
+
self.transform = transform
|
| 68 |
+
self.mults_amount = mult_amount
|
| 69 |
+
|
| 70 |
+
np.random.shuffle(self.color_data)
|
| 71 |
+
|
| 72 |
+
self.ToTensor = transforms.ToTensor()
|
| 73 |
+
def __len__(self):
|
| 74 |
+
return len(self.data)
|
| 75 |
+
|
| 76 |
+
def __getitem__(self, idx):
|
| 77 |
+
color_img = plt.imread(os.path.join(self.data_path, 'color', self.color_data[idx]))
|
| 78 |
+
|
| 79 |
+
image_name = self.data[idx]
|
| 80 |
+
if self.mults_amount > 1:
|
| 81 |
+
mult_number = np.random.choice(range(self.mults_amount))
|
| 82 |
+
|
| 83 |
+
bw_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '.png'
|
| 84 |
+
dfm_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '_dfm.png'
|
| 85 |
+
else:
|
| 86 |
+
bw_name = self.data[idx]
|
| 87 |
+
dfm_name = os.path.splitext(self.data[idx])[0] + '_dfm.png'
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'real_manga', image_name)), 2)
|
| 91 |
+
dfm_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'real_manga', dfm_name)), 2)
|
| 92 |
+
|
| 93 |
+
if self.transform:
|
| 94 |
+
result = self.transform(image = color_img)
|
| 95 |
+
color_img = result['image']
|
| 96 |
+
|
| 97 |
+
result = self.transform(image = bw_img, mask = dfm_img)
|
| 98 |
+
bw_img = result['image']
|
| 99 |
+
dfm_img = result['mask']
|
| 100 |
+
|
| 101 |
+
color_img = self.ToTensor(color_img)
|
| 102 |
+
bw_img = self.ToTensor(bw_img)
|
| 103 |
+
dfm_img = self.ToTensor(dfm_img)
|
| 104 |
+
|
| 105 |
+
color_img = (color_img - 0.5) / 0.5
|
| 106 |
+
|
| 107 |
+
return bw_img, dfm_img, color_img
|
dataset/manga/train/bw/002-0000-0000.png
ADDED
|
Git LFS Details
|
dataset/manga/train/bw/003-0000-0000.png
ADDED
|
dataset/manga/train/bw/x2-0000-0000.png
ADDED
|
Git LFS Details
|
dataset/manga/train/bw/x3-0000-0000.png
ADDED
|
Git LFS Details
|
dataset/manga/train/bw/x5-0000-0000.png
ADDED
|
Git LFS Details
|
dataset/manga/train/color/002-0000-0000.png
ADDED
|
Git LFS Details
|
dataset/manga/train/color/003-0000-0000.png
ADDED
|
dataset/manga/train/color/004-0000-0000.png
ADDED
|
dataset/manga/train/color/x1-0000-0000.png
ADDED
|
Git LFS Details
|
dataset/manga/train/color/x2-0000-0000.png
ADDED
|
Git LFS Details
|
dataset/manga/train/color/x3-0000-0000.png
ADDED
|
Git LFS Details
|
dataset/manga/train/color/x4-0000-0000.png
ADDED
|
Git LFS Details
|
dataset/manga/train/color/x5-0000-0000.png
ADDED
|
Git LFS Details
|
dataset/manga/train/real_manga/002-0000-0000.png
ADDED
|
Git LFS Details
|
dataset/manga/train/real_manga/003-0000-0000.png
ADDED
|
dataset/manga/train/real_manga/004-0000-0000.png
ADDED
|
dataset/manga/train/real_manga/x1-0000-0000.png
ADDED
|
Git LFS Details
|
dataset/manga/train/real_manga/x2-0000-0000.png
ADDED
|
Git LFS Details
|
dataset/manga/train/real_manga/x3-0000-0000.png
ADDED
|
Git LFS Details
|
dataset/manga/train/real_manga/x4-0000-0000.png
ADDED
|
Git LFS Details
|
dataset/manga/train/real_manga/x5-0000-0000.png
ADDED
|
Git LFS Details
|
inference.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
from utils.dataset_utils import get_sketch
|
| 5 |
+
from utils.utils import resize_pad, generate_mask, extract_cbr, create_cbz, sorted_alphanumeric, subfolder_image_search, remove_folder
|
| 6 |
+
from torchvision.transforms import ToTensor
|
| 7 |
+
import os
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import argparse
|
| 10 |
+
from model.models import Colorizer, Generator
|
| 11 |
+
from model.extractor import get_seresnext_extractor
|
| 12 |
+
from utils.xdog import XDoGSketcher
|
| 13 |
+
from utils.utils import open_json
|
| 14 |
+
import sys
|
| 15 |
+
|
| 16 |
+
def colorize_without_hint(inp, colorizer, device = 'cpu', auto_hint = False, auto_hint_sigma = 0.003):
|
| 17 |
+
i_hint = torch.zeros(1, 4, inp.shape[2], inp.shape[3]).float().to(device)
|
| 18 |
+
|
| 19 |
+
with torch.no_grad():
|
| 20 |
+
fake_color, _ = colorizer(torch.cat([inp, i_hint], 1))
|
| 21 |
+
|
| 22 |
+
if auto_hint:
|
| 23 |
+
mask = generate_mask(fake_color.shape[2], fake_color.shape[3], full = False, prob = 1, sigma = auto_hint_sigma).unsqueeze(0)
|
| 24 |
+
mask = mask.to(device)
|
| 25 |
+
i_hint = torch.cat([fake_color * mask, mask], 1)
|
| 26 |
+
|
| 27 |
+
with torch.no_grad():
|
| 28 |
+
fake_color, _ = colorizer(torch.cat([inp, i_hint], 1))
|
| 29 |
+
|
| 30 |
+
return fake_color
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def process_image(image, sketcher, colorizer, auto_hint, auto_hint_sigma = 0.003, dfm = True, device = 'cpu', to_tensor = ToTensor()):
|
| 34 |
+
image, pad = resize_pad(image)
|
| 35 |
+
bw, dfm = get_sketch(image, sketcher, dfm)
|
| 36 |
+
|
| 37 |
+
bw = to_tensor(bw).unsqueeze(0).to(device)
|
| 38 |
+
dfm = to_tensor(dfm).unsqueeze(0).to(device)
|
| 39 |
+
|
| 40 |
+
output = colorize_without_hint(torch.cat([bw, dfm], 1), colorizer, device = device, auto_hint = auto_hint)
|
| 41 |
+
result = output[0].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5
|
| 42 |
+
|
| 43 |
+
if pad[0] != 0:
|
| 44 |
+
result = result[:-pad[0]]
|
| 45 |
+
if pad[1] != 0:
|
| 46 |
+
result = result[:, :-pad[1]]
|
| 47 |
+
|
| 48 |
+
return result
|
| 49 |
+
|
| 50 |
+
def colorize_single_image(file_path, save_path, sketcher, colorizer, auto_hint, auto_hint_sigma = 0.003, dfm = True, device = 'cpu'):
|
| 51 |
+
try:
|
| 52 |
+
image = plt.imread(file_path)
|
| 53 |
+
|
| 54 |
+
colorization = process_image(image, sketcher, colorizer, auto_hint, auto_hint_sigma, dfm, device)
|
| 55 |
+
|
| 56 |
+
plt.imsave(save_path, colorization)
|
| 57 |
+
except KeyboardInterrupt:
|
| 58 |
+
sys.exit(0)
|
| 59 |
+
except:
|
| 60 |
+
print('Failed to colorize {}'.format(file_path))
|
| 61 |
+
|
| 62 |
+
def colorize_images(source_path, target_path, sketcher, colorizer, auto_hint, auto_hint_sigma = 0.003, dfm = True, device = 'cpu'):
|
| 63 |
+
images = os.listdir(source_path)
|
| 64 |
+
|
| 65 |
+
for image_name in images:
|
| 66 |
+
file_path = os.path.join(source_path, image_name)
|
| 67 |
+
save_path = os.path.join(target_path, image_name)
|
| 68 |
+
colorize_single_image(file_path, save_path, sketcher, colorizer, auto_hint, auto_hint_sigma, dfm, device)
|
| 69 |
+
|
| 70 |
+
def colorize_cbr(file_path, sketcher, colorizer, auto_hint, auto_hint_sigma = 0.003, dfm = True, device = 'cpu'):
|
| 71 |
+
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
| 72 |
+
temp_path = 'temp_colorization'
|
| 73 |
+
|
| 74 |
+
if not os.path.exists(temp_path):
|
| 75 |
+
os.makedirs(temp_path)
|
| 76 |
+
extract_cbr(file_path, temp_path)
|
| 77 |
+
|
| 78 |
+
images = subfolder_image_search(temp_path)
|
| 79 |
+
for image_path in images:
|
| 80 |
+
try:
|
| 81 |
+
image = plt.imread(image_path)
|
| 82 |
+
|
| 83 |
+
colorization = process_image(image, sketcher, colorizer, auto_hint, auto_hint_sigma, dfm, device)
|
| 84 |
+
|
| 85 |
+
plt.imsave(image_path, colorization)
|
| 86 |
+
except KeyboardInterrupt:
|
| 87 |
+
sys.exit(0)
|
| 88 |
+
except:
|
| 89 |
+
print('Failed to colorize {}'.format(image_path))
|
| 90 |
+
|
| 91 |
+
result_name = os.path.join(os.path.dirname(file_path), file_name + '_colorized.cbz')
|
| 92 |
+
|
| 93 |
+
create_cbz(result_name, images)
|
| 94 |
+
|
| 95 |
+
remove_folder(temp_path)
|
| 96 |
+
|
| 97 |
+
def parse_args():
|
| 98 |
+
parser = argparse.ArgumentParser()
|
| 99 |
+
parser.add_argument("-p", "--path", required=True)
|
| 100 |
+
parser.add_argument("-gen", "--generator", default = 'model/biggan.pth')
|
| 101 |
+
parser.add_argument("-ext", "--extractor", default = 'model/extractor.pth')
|
| 102 |
+
parser.add_argument("-s", "--sigma", type = float, default = 0.003)
|
| 103 |
+
parser.add_argument('-g', '--gpu', dest = 'gpu', action = 'store_true')
|
| 104 |
+
parser.add_argument('-ah', '--auto', dest = 'autohint', action = 'store_true')
|
| 105 |
+
parser.set_defaults(gpu = False)
|
| 106 |
+
parser.set_defaults(autohint = False)
|
| 107 |
+
args = parser.parse_args()
|
| 108 |
+
|
| 109 |
+
return args
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
if __name__ == "__main__":
|
| 113 |
+
|
| 114 |
+
args = parse_args()
|
| 115 |
+
|
| 116 |
+
if args.gpu:
|
| 117 |
+
device = 'cuda'
|
| 118 |
+
else:
|
| 119 |
+
device = 'cpu'
|
| 120 |
+
|
| 121 |
+
generator = Generator()
|
| 122 |
+
generator.load_state_dict(torch.load(args.generator))
|
| 123 |
+
|
| 124 |
+
extractor = get_seresnext_extractor()
|
| 125 |
+
extractor.load_state_dict(torch.load(args.extractor))
|
| 126 |
+
|
| 127 |
+
colorizer = Colorizer(generator, extractor)
|
| 128 |
+
colorizer = colorizer.eval().to(device)
|
| 129 |
+
|
| 130 |
+
sketcher = XDoGSketcher()
|
| 131 |
+
xdog_config = open_json('configs/xdog_config.json')
|
| 132 |
+
for key in xdog_config.keys():
|
| 133 |
+
if key in sketcher.params:
|
| 134 |
+
sketcher.params[key] = xdog_config[key]
|
| 135 |
+
|
| 136 |
+
if os.path.isdir(args.path):
|
| 137 |
+
colorization_path = os.path.join(args.path, 'colorization')
|
| 138 |
+
if not os.path.exists(colorization_path):
|
| 139 |
+
os.makedirs(colorization_path)
|
| 140 |
+
|
| 141 |
+
colorize_images(args.path, colorization_path, sketcher, colorizer, args.autohint, args.sigma, device = device)
|
| 142 |
+
elif os.path.isfile(args.path):
|
| 143 |
+
split = os.path.splitext(args.path)
|
| 144 |
+
if split[1].lower() in ('.cbr', '.cbz', '.rar', '.zip'):
|
| 145 |
+
colorize_cbr(args.path, sketcher, colorizer, args.autohint, args.sigma, device = device)
|
| 146 |
+
elif split[1].lower() in ('.jpg', '.png'):
|
| 147 |
+
new_image_path = split[0] + '_colorized' + split[1]
|
| 148 |
+
|
| 149 |
+
colorize_single_image(args.path, new_image_path, sketcher, colorizer, args.autohint, args.sigma, device = device)
|
| 150 |
+
else:
|
| 151 |
+
print('Wrong format')
|
| 152 |
+
else:
|
| 153 |
+
print('Wrong path')
|
| 154 |
+
|
model/extractor.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ee3c59f02ac8c59298fd9b819fa33d2efa168847e15e4be39b35c286f7c18607
|
| 3 |
+
size 6340842
|
model/extractor.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
'''https://github.com/blandocs/Tag2Pix/blob/master/model/pretrained.py'''
|
| 6 |
+
|
| 7 |
+
# Pretrained version
|
| 8 |
+
class Selayer(nn.Module):
|
| 9 |
+
def __init__(self, inplanes):
|
| 10 |
+
super(Selayer, self).__init__()
|
| 11 |
+
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
| 12 |
+
self.conv1 = nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1)
|
| 13 |
+
self.conv2 = nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1)
|
| 14 |
+
self.relu = nn.ReLU(inplace=True)
|
| 15 |
+
self.sigmoid = nn.Sigmoid()
|
| 16 |
+
|
| 17 |
+
def forward(self, x):
|
| 18 |
+
out = self.global_avgpool(x)
|
| 19 |
+
out = self.conv1(out)
|
| 20 |
+
out = self.relu(out)
|
| 21 |
+
out = self.conv2(out)
|
| 22 |
+
out = self.sigmoid(out)
|
| 23 |
+
|
| 24 |
+
return x * out
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class BottleneckX_Origin(nn.Module):
|
| 28 |
+
expansion = 4
|
| 29 |
+
|
| 30 |
+
def __init__(self, inplanes, planes, cardinality, stride=1, downsample=None):
|
| 31 |
+
super(BottleneckX_Origin, self).__init__()
|
| 32 |
+
self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)
|
| 33 |
+
self.bn1 = nn.BatchNorm2d(planes * 2)
|
| 34 |
+
|
| 35 |
+
self.conv2 = nn.Conv2d(planes * 2, planes * 2, kernel_size=3, stride=stride,
|
| 36 |
+
padding=1, groups=cardinality, bias=False)
|
| 37 |
+
self.bn2 = nn.BatchNorm2d(planes * 2)
|
| 38 |
+
|
| 39 |
+
self.conv3 = nn.Conv2d(planes * 2, planes * 4, kernel_size=1, bias=False)
|
| 40 |
+
self.bn3 = nn.BatchNorm2d(planes * 4)
|
| 41 |
+
|
| 42 |
+
self.selayer = Selayer(planes * 4)
|
| 43 |
+
|
| 44 |
+
self.relu = nn.ReLU(inplace=True)
|
| 45 |
+
self.downsample = downsample
|
| 46 |
+
self.stride = stride
|
| 47 |
+
|
| 48 |
+
def forward(self, x):
|
| 49 |
+
residual = x
|
| 50 |
+
|
| 51 |
+
out = self.conv1(x)
|
| 52 |
+
out = self.bn1(out)
|
| 53 |
+
out = self.relu(out)
|
| 54 |
+
|
| 55 |
+
out = self.conv2(out)
|
| 56 |
+
out = self.bn2(out)
|
| 57 |
+
out = self.relu(out)
|
| 58 |
+
|
| 59 |
+
out = self.conv3(out)
|
| 60 |
+
out = self.bn3(out)
|
| 61 |
+
|
| 62 |
+
out = self.selayer(out)
|
| 63 |
+
|
| 64 |
+
if self.downsample is not None:
|
| 65 |
+
residual = self.downsample(x)
|
| 66 |
+
|
| 67 |
+
out += residual
|
| 68 |
+
out = self.relu(out)
|
| 69 |
+
|
| 70 |
+
return out
|
| 71 |
+
|
| 72 |
+
class SEResNeXt_extractor(nn.Module):
|
| 73 |
+
def __init__(self, block, layers, input_channels=3, cardinality=32):
|
| 74 |
+
super(SEResNeXt_extractor, self).__init__()
|
| 75 |
+
self.cardinality = cardinality
|
| 76 |
+
self.inplanes = 64
|
| 77 |
+
self.input_channels = input_channels
|
| 78 |
+
|
| 79 |
+
self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3,
|
| 80 |
+
bias=False)
|
| 81 |
+
self.bn1 = nn.BatchNorm2d(64)
|
| 82 |
+
self.relu = nn.ReLU(inplace=True)
|
| 83 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
| 84 |
+
|
| 85 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
| 86 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
| 87 |
+
|
| 88 |
+
for m in self.modules():
|
| 89 |
+
if isinstance(m, nn.Conv2d):
|
| 90 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 91 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
| 92 |
+
if m.bias is not None:
|
| 93 |
+
m.bias.data.zero_()
|
| 94 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 95 |
+
m.weight.data.fill_(1)
|
| 96 |
+
m.bias.data.zero_()
|
| 97 |
+
|
| 98 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
| 99 |
+
downsample = None
|
| 100 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
| 101 |
+
downsample = nn.Sequential(
|
| 102 |
+
nn.Conv2d(self.inplanes, planes * block.expansion,
|
| 103 |
+
kernel_size=1, stride=stride, bias=False),
|
| 104 |
+
nn.BatchNorm2d(planes * block.expansion),
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
layers = []
|
| 108 |
+
layers.append(block(self.inplanes, planes, self.cardinality, stride, downsample))
|
| 109 |
+
self.inplanes = planes * block.expansion
|
| 110 |
+
for i in range(1, blocks):
|
| 111 |
+
layers.append(block(self.inplanes, planes, self.cardinality))
|
| 112 |
+
|
| 113 |
+
return nn.Sequential(*layers)
|
| 114 |
+
|
| 115 |
+
def forward(self, x):
|
| 116 |
+
x = self.conv1(x)
|
| 117 |
+
x = self.bn1(x)
|
| 118 |
+
x = self.relu(x)
|
| 119 |
+
x = self.maxpool(x)
|
| 120 |
+
|
| 121 |
+
x = self.layer1(x)
|
| 122 |
+
x = self.layer2(x)
|
| 123 |
+
|
| 124 |
+
return x
|
| 125 |
+
|
| 126 |
+
def get_seresnext_extractor():
|
| 127 |
+
return SEResNeXt_extractor(BottleneckX_Origin, [3, 4, 6, 3], 1)
|
model/models.py
ADDED
|
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torchvision.models as M
|
| 5 |
+
import math
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
from torch.nn import Parameter
|
| 8 |
+
|
| 9 |
+
'''https://github.com/orashi/AlacGAN/blob/master/models/standard.py'''
|
| 10 |
+
|
| 11 |
+
def l2normalize(v, eps=1e-12):
|
| 12 |
+
return v / (v.norm() + eps)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SpectralNorm(nn.Module):
|
| 16 |
+
def __init__(self, module, name='weight', power_iterations=1):
|
| 17 |
+
super(SpectralNorm, self).__init__()
|
| 18 |
+
self.module = module
|
| 19 |
+
self.name = name
|
| 20 |
+
self.power_iterations = power_iterations
|
| 21 |
+
if not self._made_params():
|
| 22 |
+
self._make_params()
|
| 23 |
+
|
| 24 |
+
def _update_u_v(self):
|
| 25 |
+
u = getattr(self.module, self.name + "_u")
|
| 26 |
+
v = getattr(self.module, self.name + "_v")
|
| 27 |
+
w = getattr(self.module, self.name + "_bar")
|
| 28 |
+
|
| 29 |
+
height = w.data.shape[0]
|
| 30 |
+
for _ in range(self.power_iterations):
|
| 31 |
+
v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
|
| 32 |
+
u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))
|
| 33 |
+
|
| 34 |
+
# sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
|
| 35 |
+
sigma = u.dot(w.view(height, -1).mv(v))
|
| 36 |
+
setattr(self.module, self.name, w / sigma.expand_as(w))
|
| 37 |
+
|
| 38 |
+
def _made_params(self):
|
| 39 |
+
try:
|
| 40 |
+
u = getattr(self.module, self.name + "_u")
|
| 41 |
+
v = getattr(self.module, self.name + "_v")
|
| 42 |
+
w = getattr(self.module, self.name + "_bar")
|
| 43 |
+
return True
|
| 44 |
+
except AttributeError:
|
| 45 |
+
return False
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _make_params(self):
|
| 49 |
+
w = getattr(self.module, self.name)
|
| 50 |
+
height = w.data.shape[0]
|
| 51 |
+
width = w.view(height, -1).data.shape[1]
|
| 52 |
+
|
| 53 |
+
u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
|
| 54 |
+
v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
|
| 55 |
+
u.data = l2normalize(u.data)
|
| 56 |
+
v.data = l2normalize(v.data)
|
| 57 |
+
w_bar = Parameter(w.data)
|
| 58 |
+
|
| 59 |
+
del self.module._parameters[self.name]
|
| 60 |
+
|
| 61 |
+
self.module.register_parameter(self.name + "_u", u)
|
| 62 |
+
self.module.register_parameter(self.name + "_v", v)
|
| 63 |
+
self.module.register_parameter(self.name + "_bar", w_bar)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def forward(self, *args):
|
| 67 |
+
self._update_u_v()
|
| 68 |
+
return self.module.forward(*args)
|
| 69 |
+
|
| 70 |
+
class Selayer(nn.Module):
|
| 71 |
+
def __init__(self, inplanes):
|
| 72 |
+
super(Selayer, self).__init__()
|
| 73 |
+
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
| 74 |
+
self.conv1 = nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1)
|
| 75 |
+
self.conv2 = nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1)
|
| 76 |
+
self.relu = nn.ReLU(inplace=True)
|
| 77 |
+
self.sigmoid = nn.Sigmoid()
|
| 78 |
+
|
| 79 |
+
def forward(self, x):
|
| 80 |
+
out = self.global_avgpool(x)
|
| 81 |
+
out = self.conv1(out)
|
| 82 |
+
out = self.relu(out)
|
| 83 |
+
out = self.conv2(out)
|
| 84 |
+
out = self.sigmoid(out)
|
| 85 |
+
|
| 86 |
+
return x * out
|
| 87 |
+
|
| 88 |
+
class SelayerSpectr(nn.Module):
|
| 89 |
+
def __init__(self, inplanes):
|
| 90 |
+
super(SelayerSpectr, self).__init__()
|
| 91 |
+
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
| 92 |
+
self.conv1 = SpectralNorm(nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1))
|
| 93 |
+
self.conv2 = SpectralNorm(nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1))
|
| 94 |
+
self.relu = nn.ReLU(inplace=True)
|
| 95 |
+
self.sigmoid = nn.Sigmoid()
|
| 96 |
+
|
| 97 |
+
def forward(self, x):
|
| 98 |
+
out = self.global_avgpool(x)
|
| 99 |
+
out = self.conv1(out)
|
| 100 |
+
out = self.relu(out)
|
| 101 |
+
out = self.conv2(out)
|
| 102 |
+
out = self.sigmoid(out)
|
| 103 |
+
|
| 104 |
+
return x * out
|
| 105 |
+
|
| 106 |
+
class ResNeXtBottleneck(nn.Module):
|
| 107 |
+
def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
|
| 108 |
+
super(ResNeXtBottleneck, self).__init__()
|
| 109 |
+
D = out_channels // 2
|
| 110 |
+
self.out_channels = out_channels
|
| 111 |
+
self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False)
|
| 112 |
+
self.conv_conv = nn.Conv2d(D, D, kernel_size=2 + stride, stride=stride, padding=dilate, dilation=dilate,
|
| 113 |
+
groups=cardinality,
|
| 114 |
+
bias=False)
|
| 115 |
+
self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
|
| 116 |
+
self.shortcut = nn.Sequential()
|
| 117 |
+
if stride != 1:
|
| 118 |
+
self.shortcut.add_module('shortcut',
|
| 119 |
+
nn.AvgPool2d(2, stride=2))
|
| 120 |
+
|
| 121 |
+
self.selayer = Selayer(out_channels)
|
| 122 |
+
|
| 123 |
+
def forward(self, x):
|
| 124 |
+
bottleneck = self.conv_reduce.forward(x)
|
| 125 |
+
bottleneck = F.leaky_relu(bottleneck, 0.2, True)
|
| 126 |
+
bottleneck = self.conv_conv.forward(bottleneck)
|
| 127 |
+
bottleneck = F.leaky_relu(bottleneck, 0.2, True)
|
| 128 |
+
bottleneck = self.conv_expand.forward(bottleneck)
|
| 129 |
+
bottleneck = self.selayer(bottleneck)
|
| 130 |
+
|
| 131 |
+
x = self.shortcut.forward(x)
|
| 132 |
+
return x + bottleneck
|
| 133 |
+
|
| 134 |
+
class SpectrResNeXtBottleneck(nn.Module):
|
| 135 |
+
def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
|
| 136 |
+
super(SpectrResNeXtBottleneck, self).__init__()
|
| 137 |
+
D = out_channels // 2
|
| 138 |
+
self.out_channels = out_channels
|
| 139 |
+
self.conv_reduce = SpectralNorm(nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False))
|
| 140 |
+
self.conv_conv = SpectralNorm(nn.Conv2d(D, D, kernel_size=2 + stride, stride=stride, padding=dilate, dilation=dilate,
|
| 141 |
+
groups=cardinality,
|
| 142 |
+
bias=False))
|
| 143 |
+
self.conv_expand = SpectralNorm(nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False))
|
| 144 |
+
self.shortcut = nn.Sequential()
|
| 145 |
+
if stride != 1:
|
| 146 |
+
self.shortcut.add_module('shortcut',
|
| 147 |
+
nn.AvgPool2d(2, stride=2))
|
| 148 |
+
|
| 149 |
+
self.selayer = SelayerSpectr(out_channels)
|
| 150 |
+
|
| 151 |
+
def forward(self, x):
|
| 152 |
+
bottleneck = self.conv_reduce.forward(x)
|
| 153 |
+
bottleneck = F.leaky_relu(bottleneck, 0.2, True)
|
| 154 |
+
bottleneck = self.conv_conv.forward(bottleneck)
|
| 155 |
+
bottleneck = F.leaky_relu(bottleneck, 0.2, True)
|
| 156 |
+
bottleneck = self.conv_expand.forward(bottleneck)
|
| 157 |
+
bottleneck = self.selayer(bottleneck)
|
| 158 |
+
|
| 159 |
+
x = self.shortcut.forward(x)
|
| 160 |
+
return x + bottleneck
|
| 161 |
+
|
| 162 |
+
class FeatureConv(nn.Module):
|
| 163 |
+
def __init__(self, input_dim=512, output_dim=512):
|
| 164 |
+
super(FeatureConv, self).__init__()
|
| 165 |
+
|
| 166 |
+
no_bn = True
|
| 167 |
+
|
| 168 |
+
seq = []
|
| 169 |
+
seq.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=1, padding=1, bias=False))
|
| 170 |
+
if not no_bn: seq.append(nn.BatchNorm2d(output_dim))
|
| 171 |
+
seq.append(nn.ReLU(inplace=True))
|
| 172 |
+
seq.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
|
| 173 |
+
if not no_bn: seq.append(nn.BatchNorm2d(output_dim))
|
| 174 |
+
seq.append(nn.ReLU(inplace=True))
|
| 175 |
+
seq.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=1, padding=1, bias=False))
|
| 176 |
+
seq.append(nn.ReLU(inplace=True))
|
| 177 |
+
|
| 178 |
+
self.network = nn.Sequential(*seq)
|
| 179 |
+
|
| 180 |
+
def forward(self, x):
|
| 181 |
+
return self.network(x)
|
| 182 |
+
|
| 183 |
+
class Generator(nn.Module):
|
| 184 |
+
def __init__(self, ngf=64):
|
| 185 |
+
super(Generator, self).__init__()
|
| 186 |
+
|
| 187 |
+
self.feature_conv = FeatureConv()
|
| 188 |
+
|
| 189 |
+
self.to0 = self._make_encoder_block_first(6, 32)
|
| 190 |
+
self.to1 = self._make_encoder_block(32, 64)
|
| 191 |
+
self.to2 = self._make_encoder_block(64, 128)
|
| 192 |
+
self.to3 = self._make_encoder_block(128, 256)
|
| 193 |
+
self.to4 = self._make_encoder_block(256, 512)
|
| 194 |
+
|
| 195 |
+
self.deconv_for_decoder = nn.Sequential(
|
| 196 |
+
nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1), # output is 64 * 64
|
| 197 |
+
nn.LeakyReLU(0.2),
|
| 198 |
+
nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1), # output is 128 * 128
|
| 199 |
+
nn.LeakyReLU(0.2),
|
| 200 |
+
nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1), # output is 256 * 256
|
| 201 |
+
nn.LeakyReLU(0.2),
|
| 202 |
+
nn.ConvTranspose2d(32, 3, 3, stride=1, padding=1, output_padding=0), # output is 256 * 256
|
| 203 |
+
nn.Tanh(),
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
tunnel4 = nn.Sequential(*[ResNeXtBottleneck(ngf * 8, ngf * 8, cardinality=32, dilate=1) for _ in range(20)])
|
| 207 |
+
|
| 208 |
+
self.tunnel4 = nn.Sequential(nn.Conv2d(ngf * 8 + 512, ngf * 8, kernel_size=3, stride=1, padding=1),
|
| 209 |
+
nn.LeakyReLU(0.2, True),
|
| 210 |
+
tunnel4,
|
| 211 |
+
nn.Conv2d(ngf * 8, ngf * 4 * 4, kernel_size=3, stride=1, padding=1),
|
| 212 |
+
nn.PixelShuffle(2),
|
| 213 |
+
nn.LeakyReLU(0.2, True)
|
| 214 |
+
) # 64
|
| 215 |
+
|
| 216 |
+
depth = 2
|
| 217 |
+
tunnel = [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=1) for _ in range(depth)]
|
| 218 |
+
tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=2) for _ in range(depth)]
|
| 219 |
+
tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=4) for _ in range(depth)]
|
| 220 |
+
tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=2),
|
| 221 |
+
ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=1)]
|
| 222 |
+
tunnel3 = nn.Sequential(*tunnel)
|
| 223 |
+
|
| 224 |
+
self.tunnel3 = nn.Sequential(nn.Conv2d(ngf * 8, ngf * 4, kernel_size=3, stride=1, padding=1),
|
| 225 |
+
nn.LeakyReLU(0.2, True),
|
| 226 |
+
tunnel3,
|
| 227 |
+
nn.Conv2d(ngf * 4, ngf * 2 * 4, kernel_size=3, stride=1, padding=1),
|
| 228 |
+
nn.PixelShuffle(2),
|
| 229 |
+
nn.LeakyReLU(0.2, True)
|
| 230 |
+
) # 128
|
| 231 |
+
|
| 232 |
+
tunnel = [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=1) for _ in range(depth)]
|
| 233 |
+
tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=2) for _ in range(depth)]
|
| 234 |
+
tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=4) for _ in range(depth)]
|
| 235 |
+
tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=2),
|
| 236 |
+
ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=1)]
|
| 237 |
+
tunnel2 = nn.Sequential(*tunnel)
|
| 238 |
+
|
| 239 |
+
self.tunnel2 = nn.Sequential(nn.Conv2d(ngf * 4, ngf * 2, kernel_size=3, stride=1, padding=1),
|
| 240 |
+
nn.LeakyReLU(0.2, True),
|
| 241 |
+
tunnel2,
|
| 242 |
+
nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=1, padding=1),
|
| 243 |
+
nn.PixelShuffle(2),
|
| 244 |
+
nn.LeakyReLU(0.2, True)
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
tunnel = [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=1)]
|
| 248 |
+
tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=2)]
|
| 249 |
+
tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=4)]
|
| 250 |
+
tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=2),
|
| 251 |
+
ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=1)]
|
| 252 |
+
tunnel1 = nn.Sequential(*tunnel)
|
| 253 |
+
|
| 254 |
+
self.tunnel1 = nn.Sequential(nn.Conv2d(ngf * 2, ngf, kernel_size=3, stride=1, padding=1),
|
| 255 |
+
nn.LeakyReLU(0.2, True),
|
| 256 |
+
tunnel1,
|
| 257 |
+
nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=1, padding=1),
|
| 258 |
+
nn.PixelShuffle(2),
|
| 259 |
+
nn.LeakyReLU(0.2, True)
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
self.exit = nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def _make_encoder_block(self, inplanes, planes):
|
| 266 |
+
return nn.Sequential(
|
| 267 |
+
nn.Conv2d(inplanes, planes, 3, 2, 1),
|
| 268 |
+
nn.LeakyReLU(0.2),
|
| 269 |
+
nn.Conv2d(planes, planes, 3, 1, 1),
|
| 270 |
+
nn.LeakyReLU(0.2),
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
def _make_encoder_block_first(self, inplanes, planes):
|
| 274 |
+
return nn.Sequential(
|
| 275 |
+
nn.Conv2d(inplanes, planes, 3, 1, 1),
|
| 276 |
+
nn.LeakyReLU(0.2),
|
| 277 |
+
nn.Conv2d(planes, planes, 3, 1, 1),
|
| 278 |
+
nn.LeakyReLU(0.2),
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
def forward(self, sketch, sketch_feat):
|
| 282 |
+
|
| 283 |
+
x0 = self.to0(sketch)
|
| 284 |
+
x1 = self.to1(x0)
|
| 285 |
+
x2 = self.to2(x1)
|
| 286 |
+
x3 = self.to3(x2) # !
|
| 287 |
+
x4 = self.to4(x3)
|
| 288 |
+
|
| 289 |
+
sketch_feat = self.feature_conv(sketch_feat)
|
| 290 |
+
|
| 291 |
+
out = self.tunnel4(torch.cat([x4, sketch_feat], 1))
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
x = self.tunnel3(torch.cat([out, x3], 1))
|
| 297 |
+
x = self.tunnel2(torch.cat([x, x2], 1))
|
| 298 |
+
x = self.tunnel1(torch.cat([x, x1], 1))
|
| 299 |
+
x = torch.tanh(self.exit(torch.cat([x, x0], 1)))
|
| 300 |
+
|
| 301 |
+
decoder_output = self.deconv_for_decoder(out)
|
| 302 |
+
|
| 303 |
+
return x, decoder_output
|
| 304 |
+
'''
|
| 305 |
+
class Colorizer(nn.Module):
|
| 306 |
+
def __init__(self, extractor_path = 'model/model.pth'):
|
| 307 |
+
super(Colorizer, self).__init__()
|
| 308 |
+
|
| 309 |
+
self.generator = Generator()
|
| 310 |
+
self.extractor = se_resnext_half(dump_path=extractor_path, num_classes=370, input_channels=1)
|
| 311 |
+
|
| 312 |
+
def extractor_eval(self):
|
| 313 |
+
for param in self.extractor.parameters():
|
| 314 |
+
param.requires_grad = False
|
| 315 |
+
|
| 316 |
+
def extractor_train(self):
|
| 317 |
+
for param in extractor.parameters():
|
| 318 |
+
param.requires_grad = True
|
| 319 |
+
|
| 320 |
+
def forward(self, x, extractor_grad = False):
|
| 321 |
+
|
| 322 |
+
if extractor_grad:
|
| 323 |
+
features = self.extractor(x[:, 0:1])
|
| 324 |
+
else:
|
| 325 |
+
with torch.no_grad():
|
| 326 |
+
features = self.extractor(x[:, 0:1]).detach()
|
| 327 |
+
|
| 328 |
+
fake, guide = self.generator(x, features)
|
| 329 |
+
|
| 330 |
+
return fake, guide
|
| 331 |
+
'''
|
| 332 |
+
|
| 333 |
+
class Colorizer(nn.Module):
|
| 334 |
+
def __init__(self, generator_model, extractor_model):
|
| 335 |
+
super(Colorizer, self).__init__()
|
| 336 |
+
|
| 337 |
+
self.generator = generator_model
|
| 338 |
+
self.extractor = extractor_model
|
| 339 |
+
|
| 340 |
+
def load_generator_weights(self, gen_weights):
|
| 341 |
+
self.generator.load_state_dict(gen_weights)
|
| 342 |
+
|
| 343 |
+
def load_extractor_weights(self, ext_weights):
|
| 344 |
+
self.extractor.load_state_dict(ext_weights)
|
| 345 |
+
|
| 346 |
+
def extractor_eval(self):
|
| 347 |
+
for param in self.extractor.parameters():
|
| 348 |
+
param.requires_grad = False
|
| 349 |
+
self.extractor.eval()
|
| 350 |
+
|
| 351 |
+
def extractor_train(self):
|
| 352 |
+
for param in extractor.parameters():
|
| 353 |
+
param.requires_grad = True
|
| 354 |
+
self.extractor.train()
|
| 355 |
+
|
| 356 |
+
def forward(self, x, extractor_grad = False):
|
| 357 |
+
|
| 358 |
+
if extractor_grad:
|
| 359 |
+
features = self.extractor(x[:, 0:1])
|
| 360 |
+
else:
|
| 361 |
+
with torch.no_grad():
|
| 362 |
+
features = self.extractor(x[:, 0:1]).detach()
|
| 363 |
+
|
| 364 |
+
fake, guide = self.generator(x, features)
|
| 365 |
+
|
| 366 |
+
return fake, guide
|
| 367 |
+
|
| 368 |
+
class Discriminator(nn.Module):
|
| 369 |
+
def __init__(self, ndf=64):
|
| 370 |
+
super(Discriminator, self).__init__()
|
| 371 |
+
|
| 372 |
+
self.feed = nn.Sequential(SpectralNorm(nn.Conv2d(3, 64, 3, 1, 1)),
|
| 373 |
+
nn.LeakyReLU(0.2, True),
|
| 374 |
+
SpectralNorm(nn.Conv2d(64, 64, 3, 2, 0)),
|
| 375 |
+
nn.LeakyReLU(0.2, True),
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
SpectrResNeXtBottleneck(ndf, ndf, cardinality=8, dilate=1),
|
| 381 |
+
SpectrResNeXtBottleneck(ndf, ndf, cardinality=8, dilate=1, stride=2), # 128
|
| 382 |
+
SpectralNorm(nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=False)),
|
| 383 |
+
nn.LeakyReLU(0.2, True),
|
| 384 |
+
|
| 385 |
+
SpectrResNeXtBottleneck(ndf * 2, ndf * 2, cardinality=8, dilate=1),
|
| 386 |
+
SpectrResNeXtBottleneck(ndf * 2, ndf * 2, cardinality=8, dilate=1, stride=2), # 64
|
| 387 |
+
SpectralNorm(nn.Conv2d(ndf * 2, ndf * 4, kernel_size=1, stride=1, padding=0, bias=False)),
|
| 388 |
+
nn.LeakyReLU(0.2, True),
|
| 389 |
+
|
| 390 |
+
SpectrResNeXtBottleneck(ndf * 4, ndf * 4, cardinality=8, dilate=1),
|
| 391 |
+
SpectrResNeXtBottleneck(ndf * 4, ndf * 4, cardinality=8, dilate=1, stride=2), # 32,
|
| 392 |
+
SpectralNorm(nn.Conv2d(ndf * 4, ndf * 8, kernel_size=1, stride=1, padding=1, bias=False)),
|
| 393 |
+
nn.LeakyReLU(0.2, True),
|
| 394 |
+
SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
|
| 395 |
+
SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1, stride=2), # 16
|
| 396 |
+
SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
|
| 397 |
+
SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
|
| 398 |
+
nn.AdaptiveAvgPool2d((1, 1))
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
self.out = nn.Linear(512, 1)
|
| 402 |
+
|
| 403 |
+
def forward(self, color):
|
| 404 |
+
x = self.feed(color)
|
| 405 |
+
|
| 406 |
+
out = self.out(x.view(color.size(0), -1))
|
| 407 |
+
return out
|
| 408 |
+
|
| 409 |
+
class Content(nn.Module):
|
| 410 |
+
def __init__(self, path):
|
| 411 |
+
super(Content, self).__init__()
|
| 412 |
+
vgg16 = M.vgg16()
|
| 413 |
+
vgg16.load_state_dict(torch.load(path))
|
| 414 |
+
vgg16.features = nn.Sequential(
|
| 415 |
+
*list(vgg16.features.children())[:9]
|
| 416 |
+
)
|
| 417 |
+
self.model = vgg16.features
|
| 418 |
+
self.register_buffer('mean', torch.FloatTensor([0.485 - 0.5, 0.456 - 0.5, 0.406 - 0.5]).view(1, 3, 1, 1))
|
| 419 |
+
self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
| 420 |
+
|
| 421 |
+
def forward(self, images):
|
| 422 |
+
return self.model((images.mul(0.5) - self.mean) / self.std)
|
model/vgg16-397923af.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:397923af8e79cdbb6a7127f12361acd7a2f83e06b05044ddf496e83de57a5bf0
|
| 3 |
+
size 553433881
|
train.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.optim as optim
|
| 4 |
+
import numpy as np
|
| 5 |
+
import albumentations as albu
|
| 6 |
+
import argparse
|
| 7 |
+
import datetime
|
| 8 |
+
|
| 9 |
+
from utils.utils import open_json, weights_init, weights_init_spectr, generate_mask
|
| 10 |
+
from model.models import Colorizer, Generator, Content, Discriminator
|
| 11 |
+
from model.extractor import get_seresnext_extractor
|
| 12 |
+
from dataset.datasets import TrainDataset, FineTuningDataset
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def parse_args():
|
| 17 |
+
parser = argparse.ArgumentParser()
|
| 18 |
+
parser.add_argument("-p", "--path", required=True, help = "dataset path")
|
| 19 |
+
parser.add_argument('-ft', '--fine_tuning', dest = 'fine_tuning', action = 'store_true')
|
| 20 |
+
parser.add_argument('-g', '--gpu', dest = 'gpu', action = 'store_true')
|
| 21 |
+
parser.set_defaults(fine_tuning = False)
|
| 22 |
+
parser.set_defaults(gpu = False)
|
| 23 |
+
args = parser.parse_args()
|
| 24 |
+
|
| 25 |
+
return args
|
| 26 |
+
|
| 27 |
+
def get_transforms():
|
| 28 |
+
return albu.Compose([albu.RandomCrop(512, 512, always_apply = True), albu.HorizontalFlip(p = 0.5)], p = 1.)
|
| 29 |
+
|
| 30 |
+
def get_dataloaders(data_path, transforms, batch_size, fine_tuning, mult_number):
|
| 31 |
+
train_dataset = TrainDataset(data_path, transforms, mult_number)
|
| 32 |
+
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
|
| 33 |
+
|
| 34 |
+
if fine_tuning:
|
| 35 |
+
finetuning_dataset = FineTuningDataset(data_path, transforms)
|
| 36 |
+
finetuning_dataloader = torch.utils.data.DataLoader(finetuning_dataset, batch_size = batch_size, shuffle = True)
|
| 37 |
+
|
| 38 |
+
return train_dataloader, finetuning_dataloader
|
| 39 |
+
|
| 40 |
+
def get_models(device):
|
| 41 |
+
generator = Generator()
|
| 42 |
+
extractor = get_seresnext_extractor()
|
| 43 |
+
colorizer = Colorizer(generator, extractor)
|
| 44 |
+
|
| 45 |
+
colorizer.extractor_eval()
|
| 46 |
+
colorizer = colorizer.to(device)
|
| 47 |
+
|
| 48 |
+
discriminator = Discriminator().to(device)
|
| 49 |
+
|
| 50 |
+
content = Content('model/vgg16-397923af.pth').eval().to(device)
|
| 51 |
+
for param in content.parameters():
|
| 52 |
+
param.requires_grad = False
|
| 53 |
+
|
| 54 |
+
return colorizer, discriminator, content
|
| 55 |
+
|
| 56 |
+
def set_weights(colorizer, discriminator):
|
| 57 |
+
colorizer.generator.apply(weights_init)
|
| 58 |
+
colorizer.load_extractor_weights(torch.load('model/extractor.pth'))
|
| 59 |
+
|
| 60 |
+
discriminator.apply(weights_init_spectr)
|
| 61 |
+
|
| 62 |
+
def generator_loss(disc_output, true_labels, main_output, guide_output, real_image, content_gen, content_true, dist_loss = nn.L1Loss(), content_dist_loss = nn.MSELoss(), class_loss = nn.BCEWithLogitsLoss()):
|
| 63 |
+
sim_loss_full = dist_loss(main_output, real_image)
|
| 64 |
+
sim_loss_guide = dist_loss(guide_output, real_image)
|
| 65 |
+
|
| 66 |
+
adv_loss = class_loss(disc_output, true_labels)
|
| 67 |
+
|
| 68 |
+
content_loss = content_dist_loss(content_gen, content_true)
|
| 69 |
+
|
| 70 |
+
sum_loss = 10 * (sim_loss_full + 0.9 * sim_loss_guide) + adv_loss + content_loss
|
| 71 |
+
|
| 72 |
+
return sum_loss
|
| 73 |
+
|
| 74 |
+
def get_optimizers(colorizer, discriminator, generator_lr, discriminator_lr):
|
| 75 |
+
optimizerG = optim.Adam(colorizer.generator.parameters(), lr = generator_lr, betas=(0.5, 0.9))
|
| 76 |
+
optimizerD = optim.Adam(discriminator.parameters(), lr = discriminator_lr, betas=(0.5, 0.9))
|
| 77 |
+
|
| 78 |
+
return optimizerG, optimizerD
|
| 79 |
+
|
| 80 |
+
def generator_step(inputs, colorizer, discriminator, content, loss_function, optimizer, device, white_penalty = True):
|
| 81 |
+
for p in discriminator.parameters():
|
| 82 |
+
p.requires_grad = False
|
| 83 |
+
for p in colorizer.generator.parameters():
|
| 84 |
+
p.requires_grad = True
|
| 85 |
+
|
| 86 |
+
colorizer.generator.zero_grad()
|
| 87 |
+
|
| 88 |
+
bw, color, hint, dfm = inputs
|
| 89 |
+
bw, color, hint, dfm = bw.to(device), color.to(device), hint.to(device), dfm.to(device)
|
| 90 |
+
|
| 91 |
+
fake, guide = colorizer(torch.cat([bw, dfm, hint], 1))
|
| 92 |
+
|
| 93 |
+
logits_fake = discriminator(fake)
|
| 94 |
+
y_real = torch.ones((bw.size(0), 1), device = device)
|
| 95 |
+
|
| 96 |
+
content_fake = content(fake)
|
| 97 |
+
with torch.no_grad():
|
| 98 |
+
content_true = content(color)
|
| 99 |
+
|
| 100 |
+
generator_loss = loss_function(logits_fake, y_real, fake, guide, color, content_fake, content_true)
|
| 101 |
+
|
| 102 |
+
if white_penalty:
|
| 103 |
+
mask = (~((color > 0.85).float().sum(dim = 1) == 3).unsqueeze(1).repeat((1, 3, 1, 1 ))).float()
|
| 104 |
+
white_zones = mask * (fake + 1) / 2
|
| 105 |
+
white_penalty = (torch.pow(white_zones.sum(dim = 1), 2).sum(dim = (1, 2)) / (mask.sum(dim = (1, 2, 3)) + 1)).mean()
|
| 106 |
+
|
| 107 |
+
generator_loss += white_penalty
|
| 108 |
+
|
| 109 |
+
generator_loss.backward()
|
| 110 |
+
|
| 111 |
+
optimizer.step()
|
| 112 |
+
|
| 113 |
+
return generator_loss.item()
|
| 114 |
+
|
| 115 |
+
def discriminator_step(inputs, colorizer, discriminator, optimizer, device, loss_function = nn.BCEWithLogitsLoss()):
|
| 116 |
+
|
| 117 |
+
for p in discriminator.parameters():
|
| 118 |
+
p.requires_grad = True
|
| 119 |
+
for p in colorizer.generator.parameters():
|
| 120 |
+
p.requires_grad = False
|
| 121 |
+
|
| 122 |
+
discriminator.zero_grad()
|
| 123 |
+
|
| 124 |
+
bw, color, hint, dfm = inputs
|
| 125 |
+
bw, color, hint, dfm = bw.to(device), color.to(device), hint.to(device), dfm.to(device)
|
| 126 |
+
|
| 127 |
+
y_real = torch.full((bw.size(0), 1), 0.9, device = device)
|
| 128 |
+
|
| 129 |
+
y_fake = torch.zeros((bw.size(0), 1), device = device)
|
| 130 |
+
|
| 131 |
+
with torch.no_grad():
|
| 132 |
+
fake_color, _ = colorizer(torch.cat([bw, dfm, hint], 1))
|
| 133 |
+
fake_color.detach()
|
| 134 |
+
|
| 135 |
+
logits_fake = discriminator(fake_color)
|
| 136 |
+
logits_real = discriminator(color)
|
| 137 |
+
|
| 138 |
+
fake_loss = loss_function(logits_fake, y_fake)
|
| 139 |
+
real_loss = loss_function(logits_real, y_real)
|
| 140 |
+
|
| 141 |
+
discriminator_loss = real_loss + fake_loss
|
| 142 |
+
|
| 143 |
+
discriminator_loss.backward()
|
| 144 |
+
optimizer.step()
|
| 145 |
+
|
| 146 |
+
return discriminator_loss.item()
|
| 147 |
+
|
| 148 |
+
def decrease_lr(optimizer, rate):
|
| 149 |
+
for group in optimizer.param_groups:
|
| 150 |
+
group['lr'] /= rate
|
| 151 |
+
|
| 152 |
+
def set_lr(optimizer, value):
|
| 153 |
+
for group in optimizer.param_groups:
|
| 154 |
+
group['lr'] = value
|
| 155 |
+
|
| 156 |
+
def train(colorizer, discriminator, content, dataloader, epochs, colorizer_optimizer, discriminator_optimizer, lr_decay_epoch = -1, device = 'cpu'):
|
| 157 |
+
colorizer.generator.train()
|
| 158 |
+
discriminator.train()
|
| 159 |
+
|
| 160 |
+
disc_step = True
|
| 161 |
+
|
| 162 |
+
for epoch in range(epochs):
|
| 163 |
+
if (epoch == lr_decay_epoch):
|
| 164 |
+
decrease_lr(colorizer_optimizer, 10)
|
| 165 |
+
decrease_lr(discriminator_optimizer, 10)
|
| 166 |
+
|
| 167 |
+
sum_disc_loss = 0
|
| 168 |
+
sum_gen_loss = 0
|
| 169 |
+
|
| 170 |
+
for n, inputs in enumerate(dataloader):
|
| 171 |
+
if n % 5 == 0:
|
| 172 |
+
print(datetime.datetime.now().time())
|
| 173 |
+
print('Step : %d Discr loss: %.4f Gen loss : %.4f \n'%(n, sum_disc_loss / (n // 2 + 1), sum_gen_loss / (n // 2 + 1)))
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
if disc_step:
|
| 177 |
+
step_loss = discriminator_step(inputs, colorizer, discriminator, discriminator_optimizer, device)
|
| 178 |
+
sum_disc_loss += step_loss
|
| 179 |
+
else:
|
| 180 |
+
step_loss = generator_step(inputs, colorizer, discriminator, content, generator_loss, colorizer_optimizer, device)
|
| 181 |
+
sum_gen_loss += step_loss
|
| 182 |
+
|
| 183 |
+
disc_step = disc_step ^ True
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
print(datetime.datetime.now().time())
|
| 187 |
+
print('Epoch : %d Discr loss: %.4f Gen loss : %.4f \n'%(epoch, sum_disc_loss / (n // 2 + 1), sum_gen_loss / (n // 2 + 1)))
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def fine_tuning_step(data_iter, colorizer, discriminator, gen_optimizer, disc_optimizer, device, loss_function = nn.BCEWithLogitsLoss()):
|
| 191 |
+
|
| 192 |
+
for p in discriminator.parameters():
|
| 193 |
+
p.requires_grad = True
|
| 194 |
+
for p in colorizer.generator.parameters():
|
| 195 |
+
p.requires_grad = False
|
| 196 |
+
|
| 197 |
+
for cur_disc_step in range(5):
|
| 198 |
+
discriminator.zero_grad()
|
| 199 |
+
|
| 200 |
+
bw, dfm, color_for_real = data_iter.next()
|
| 201 |
+
bw, dfm, color_for_real = bw.to(device), dfm.to(device), color_for_real.to(device)
|
| 202 |
+
|
| 203 |
+
y_real = torch.full((bw.size(0), 1), 0.9, device = device)
|
| 204 |
+
y_fake = torch.zeros((bw.size(0), 1), device = device)
|
| 205 |
+
|
| 206 |
+
empty_hint = torch.zeros(bw.shape[0], 4, bw.shape[2] , bw.shape[3] ).float().to(device)
|
| 207 |
+
|
| 208 |
+
with torch.no_grad():
|
| 209 |
+
fake_color_manga, _ = colorizer(torch.cat([bw, dfm, empty_hint ], 1))
|
| 210 |
+
fake_color_manga.detach()
|
| 211 |
+
|
| 212 |
+
logits_fake = discriminator(fake_color_manga)
|
| 213 |
+
logits_real = discriminator(color_for_real)
|
| 214 |
+
|
| 215 |
+
fake_loss = loss_function(logits_fake, y_fake)
|
| 216 |
+
real_loss = loss_function(logits_real, y_real)
|
| 217 |
+
discriminator_loss = real_loss + fake_loss
|
| 218 |
+
|
| 219 |
+
discriminator_loss.backward()
|
| 220 |
+
disc_optimizer.step()
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
for p in discriminator.parameters():
|
| 224 |
+
p.requires_grad = False
|
| 225 |
+
for p in colorizer.generator.parameters():
|
| 226 |
+
p.requires_grad = True
|
| 227 |
+
|
| 228 |
+
colorizer.generator.zero_grad()
|
| 229 |
+
|
| 230 |
+
bw, dfm, _ = data_iter.next()
|
| 231 |
+
bw, dfm = bw.to(device), dfm.to(device)
|
| 232 |
+
|
| 233 |
+
y_real = torch.ones((bw.size(0), 1), device = device)
|
| 234 |
+
|
| 235 |
+
empty_hint = torch.zeros(bw.shape[0], 4, bw.shape[2] , bw.shape[3]).float().to(device)
|
| 236 |
+
|
| 237 |
+
fake_manga, _ = colorizer(torch.cat([bw, dfm, empty_hint], 1))
|
| 238 |
+
|
| 239 |
+
logits_fake = discriminator(fake_manga)
|
| 240 |
+
adv_loss = loss_function(logits_fake, y_real)
|
| 241 |
+
|
| 242 |
+
generator_loss = adv_loss
|
| 243 |
+
|
| 244 |
+
generator_loss.backward()
|
| 245 |
+
gen_optimizer.step()
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def fine_tuning(colorizer, discriminator, content, dataloader, iterations, colorizer_optimizer, discriminator_optimizer, data_iter, device = 'cpu'):
|
| 250 |
+
colorizer.generator.train()
|
| 251 |
+
discriminator.train()
|
| 252 |
+
|
| 253 |
+
disc_step = True
|
| 254 |
+
|
| 255 |
+
for n, inputs in enumerate(dataloader):
|
| 256 |
+
|
| 257 |
+
if n == iterations:
|
| 258 |
+
return
|
| 259 |
+
|
| 260 |
+
if disc_step:
|
| 261 |
+
discriminator_step(inputs, colorizer, discriminator, discriminator_optimizer, device)
|
| 262 |
+
else:
|
| 263 |
+
generator_step(inputs, colorizer, discriminator, content, generator_loss, colorizer_optimizer, device)
|
| 264 |
+
|
| 265 |
+
disc_step = disc_step ^ True
|
| 266 |
+
|
| 267 |
+
if n % 10 == 5:
|
| 268 |
+
fine_tuning_step(data_iter, colorizer, discriminator, colorizer_optimizer, discriminator_optimizer, device)
|
| 269 |
+
|
| 270 |
+
if __name__ == '__main__':
|
| 271 |
+
args = parse_args()
|
| 272 |
+
config = open_json('configs/train_config.json')
|
| 273 |
+
|
| 274 |
+
if args.gpu:
|
| 275 |
+
device = 'cuda'
|
| 276 |
+
else:
|
| 277 |
+
device = 'cpu'
|
| 278 |
+
|
| 279 |
+
augmentations = get_transforms()
|
| 280 |
+
|
| 281 |
+
train_dataloader, ft_dataloader = get_dataloaders(args.path, augmentations, config['batch_size'], args.fine_tuning, config['number_of_mults'])
|
| 282 |
+
|
| 283 |
+
colorizer, discriminator, content = get_models(device)
|
| 284 |
+
set_weights(colorizer, discriminator)
|
| 285 |
+
|
| 286 |
+
gen_optimizer, disc_optimizer = get_optimizers(colorizer, discriminator, config['generator_lr'], config['discriminator_lr'])
|
| 287 |
+
|
| 288 |
+
train(colorizer, discriminator, content, train_dataloader, config['epochs'], gen_optimizer, disc_optimizer, config['lr_decrease_epoch'], device)
|
| 289 |
+
|
| 290 |
+
if args.fine_tuning:
|
| 291 |
+
set_lr(gen_optimizer, config["finetuning_generator_lr"])
|
| 292 |
+
fine_tuning(colorizer, discriminator, content, train_dataloader, config['finetuning_iterations'], gen_optimizer, disc_optimizer, iter(ft_dataloader), device)
|
| 293 |
+
|
| 294 |
+
torch.save(colorizer.generator.state_dict(), str(datetime.datetime.now().time()))
|
utils/dataset_utils.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import cv2
|
| 4 |
+
import snowy
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_resized_image(img, size):
|
| 9 |
+
if len(img.shape) == 2:
|
| 10 |
+
img = np.repeat(np.expand_dims(img, 2), 3, 2)
|
| 11 |
+
|
| 12 |
+
if (img.shape[0] < img.shape[1]):
|
| 13 |
+
height = img.shape[0]
|
| 14 |
+
ratio = height / size
|
| 15 |
+
width = int(np.ceil(img.shape[1] / ratio))
|
| 16 |
+
img = cv2.resize(img, (width, size), interpolation = cv2.INTER_AREA)
|
| 17 |
+
else:
|
| 18 |
+
width = img.shape[1]
|
| 19 |
+
ratio = width / size
|
| 20 |
+
height = int(np.ceil(img.shape[0] / ratio))
|
| 21 |
+
img = cv2.resize(img, (size, height), interpolation = cv2.INTER_AREA)
|
| 22 |
+
|
| 23 |
+
if (img.dtype == 'float32'):
|
| 24 |
+
np.clip(img, 0, 1, out = img)
|
| 25 |
+
|
| 26 |
+
return img
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_sketch_image(img, sketcher, mult_val):
|
| 30 |
+
|
| 31 |
+
if mult_val:
|
| 32 |
+
sketch_image = sketcher.get_sketch_with_resize(img, mult = mult_val)
|
| 33 |
+
else:
|
| 34 |
+
sketch_image = sketcher.get_sketch_with_resize(img)
|
| 35 |
+
|
| 36 |
+
return sketch_image
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_dfm_image(sketch):
|
| 40 |
+
dfm_image = snowy.unitize(snowy.generate_sdf(np.expand_dims(1 - sketch, 2) != 0)).squeeze()
|
| 41 |
+
return dfm_image
|
| 42 |
+
|
| 43 |
+
def get_sketch(image, sketcher, dfm, mult = None):
|
| 44 |
+
sketch_image = get_sketch_image(image, sketcher, mult)
|
| 45 |
+
|
| 46 |
+
dfm_image = None
|
| 47 |
+
|
| 48 |
+
if dfm:
|
| 49 |
+
dfm_image = get_dfm_image(sketch_image)
|
| 50 |
+
|
| 51 |
+
sketch_image = (sketch_image * 255).astype('uint8')
|
| 52 |
+
|
| 53 |
+
if dfm:
|
| 54 |
+
dfm_image = (dfm_image * 255).astype('uint8')
|
| 55 |
+
|
| 56 |
+
return sketch_image, dfm_image
|
| 57 |
+
|
| 58 |
+
def get_sketches(image, sketcher, mult_list, dfm):
|
| 59 |
+
for mult in mult_list:
|
| 60 |
+
yield get_sketch(image, sketcher, dfm, mult)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def create_resized_dataset(source_path, target_path, side_size):
|
| 64 |
+
images = os.listdir(source_path)
|
| 65 |
+
|
| 66 |
+
for image_name in images:
|
| 67 |
+
|
| 68 |
+
new_image_name = image_name[:image_name.rfind('.')] + '.png'
|
| 69 |
+
new_path = os.path.join(target_path, new_image_name)
|
| 70 |
+
|
| 71 |
+
if not os.path.exists(new_path):
|
| 72 |
+
try:
|
| 73 |
+
image = cv2.imread(os.path.join(source_path, image_name))
|
| 74 |
+
|
| 75 |
+
if image is None:
|
| 76 |
+
raise Exception()
|
| 77 |
+
|
| 78 |
+
image = get_resized_image(image, side_size)
|
| 79 |
+
|
| 80 |
+
cv2.imwrite(new_path, image)
|
| 81 |
+
except:
|
| 82 |
+
print('Failed to process {}'.format(image_name))
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def create_sketches_dataset(source_path, target_path, sketcher, mult_list, dfm = False):
|
| 86 |
+
|
| 87 |
+
images = os.listdir(source_path)
|
| 88 |
+
for image_name in images:
|
| 89 |
+
try:
|
| 90 |
+
image = cv2.imread(os.path.join(source_path, image_name))
|
| 91 |
+
|
| 92 |
+
if image is None:
|
| 93 |
+
raise Exception()
|
| 94 |
+
|
| 95 |
+
for number, (sketch_image, dfm_image) in enumerate(get_sketches(image, sketcher, mult_list, dfm)):
|
| 96 |
+
new_sketch_name = image_name[:image_name.rfind('.')] + '_' + str(number) + '.png'
|
| 97 |
+
cv2.imwrite(os.path.join(target_path, new_sketch_name), sketch_image)
|
| 98 |
+
|
| 99 |
+
if dfm:
|
| 100 |
+
dfm_name = image_name[:image_name.rfind('.')] + '_' + str(number) + '_dfm.png'
|
| 101 |
+
cv2.imwrite(os.path.join(target_path, dfm_name), dfm_image)
|
| 102 |
+
|
| 103 |
+
except:
|
| 104 |
+
print('Failed to process {}'.format(image_name))
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def create_dataset(source_path, target_path, sketcher, mult_list, side_size, dfm = False):
|
| 108 |
+
images = os.listdir(source_path)
|
| 109 |
+
|
| 110 |
+
color_path = os.path.join(target_path, 'color')
|
| 111 |
+
sketch_path = os.path.join(target_path, 'bw')
|
| 112 |
+
|
| 113 |
+
if not os.path.exists(color_path):
|
| 114 |
+
os.makedirs(color_path)
|
| 115 |
+
|
| 116 |
+
if not os.path.exists(sketch_path):
|
| 117 |
+
os.makedirs(sketch_path)
|
| 118 |
+
|
| 119 |
+
for image_name in images:
|
| 120 |
+
new_image_name = image_name[:image_name.rfind('.')] + '.png'
|
| 121 |
+
|
| 122 |
+
try:
|
| 123 |
+
image = cv2.imread(os.path.join(source_path, image_name))
|
| 124 |
+
|
| 125 |
+
if image is None:
|
| 126 |
+
raise Exception()
|
| 127 |
+
|
| 128 |
+
resized_image = get_resized_image(image, side_size)
|
| 129 |
+
cv2.imwrite(os.path.join(color_path, new_image_name), resized_image)
|
| 130 |
+
|
| 131 |
+
for number, (sketch_image, dfm_image) in enumerate(get_sketches(resized_image, sketcher, mult_list, dfm)):
|
| 132 |
+
new_sketch_name = image_name[:image_name.rfind('.')] + '_' + str(number) + '.png'
|
| 133 |
+
cv2.imwrite(os.path.join(sketch_path, new_sketch_name), sketch_image)
|
| 134 |
+
|
| 135 |
+
if dfm:
|
| 136 |
+
dfm_name = image_name[:image_name.rfind('.')] + '_' + str(number) + '_dfm.png'
|
| 137 |
+
cv2.imwrite(os.path.join(sketch_path, dfm_name), dfm_image)
|
| 138 |
+
|
| 139 |
+
except:
|
| 140 |
+
print('Failed to process {}'.format(image_name))
|
| 141 |
+
|
utils/utils.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
import scipy.stats as stats
|
| 5 |
+
import cv2
|
| 6 |
+
import json
|
| 7 |
+
import patoolib
|
| 8 |
+
import re
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from shutil import rmtree
|
| 11 |
+
|
| 12 |
+
def weights_init(m):
|
| 13 |
+
classname = m.__class__.__name__
|
| 14 |
+
if classname.find('Conv2d') != -1:
|
| 15 |
+
nn.init.xavier_uniform_(m.weight.data)
|
| 16 |
+
|
| 17 |
+
def weights_init_spectr(m):
|
| 18 |
+
classname = m.__class__.__name__
|
| 19 |
+
if classname.find('Conv2d') != -1:
|
| 20 |
+
nn.init.xavier_uniform_(m.weight_bar.data)
|
| 21 |
+
|
| 22 |
+
def generate_mask(height, width, mu = 1, sigma = 0.0005, prob = 0.5, full = True, full_prob = 0.01):
|
| 23 |
+
X = stats.truncnorm((0 - mu) / sigma, (1 - mu) / sigma, loc=mu, scale=sigma)
|
| 24 |
+
|
| 25 |
+
if full:
|
| 26 |
+
if (np.random.binomial(1, p = full_prob) == 1):
|
| 27 |
+
return torch.ones(1, height, width).float()
|
| 28 |
+
|
| 29 |
+
if np.random.binomial(1, p = prob) == 1:
|
| 30 |
+
mask = torch.rand(1, height, width).ge(X.rvs(1)[0]).float()
|
| 31 |
+
else:
|
| 32 |
+
mask = torch.zeros(1, height, width).float()
|
| 33 |
+
|
| 34 |
+
return mask
|
| 35 |
+
|
| 36 |
+
def resize_pad(img, size = 512):
|
| 37 |
+
|
| 38 |
+
if len(img.shape) == 2:
|
| 39 |
+
img = np.expand_dims(img, 2)
|
| 40 |
+
|
| 41 |
+
if img.shape[2] == 1:
|
| 42 |
+
img = np.repeat(img, 3, 2)
|
| 43 |
+
|
| 44 |
+
if img.shape[2] == 4:
|
| 45 |
+
img = img[:, :, :3]
|
| 46 |
+
|
| 47 |
+
pad = None
|
| 48 |
+
|
| 49 |
+
if (img.shape[0] < img.shape[1]):
|
| 50 |
+
height = img.shape[0]
|
| 51 |
+
ratio = height / size
|
| 52 |
+
width = int(np.ceil(img.shape[1] / ratio))
|
| 53 |
+
img = cv2.resize(img, (width, size), interpolation = cv2.INTER_AREA)
|
| 54 |
+
|
| 55 |
+
new_width = width
|
| 56 |
+
while (new_width % 32 != 0):
|
| 57 |
+
new_width += 1
|
| 58 |
+
|
| 59 |
+
pad = (0, new_width - width)
|
| 60 |
+
|
| 61 |
+
img = np.pad(img, ((0, 0), (0, pad[1]), (0, 0)), 'maximum')
|
| 62 |
+
else:
|
| 63 |
+
width = img.shape[1]
|
| 64 |
+
ratio = width / size
|
| 65 |
+
height = int(np.ceil(img.shape[0] / ratio))
|
| 66 |
+
img = cv2.resize(img, (size, height), interpolation = cv2.INTER_AREA)
|
| 67 |
+
|
| 68 |
+
new_height = height
|
| 69 |
+
while (new_height % 32 != 0):
|
| 70 |
+
new_height += 1
|
| 71 |
+
|
| 72 |
+
pad = (new_height - height, 0)
|
| 73 |
+
|
| 74 |
+
img = np.pad(img, ((0, pad[0]), (0, 0), (0, 0)), 'maximum')
|
| 75 |
+
|
| 76 |
+
if (img.dtype == 'float32'):
|
| 77 |
+
np.clip(img, 0, 1, out = img)
|
| 78 |
+
|
| 79 |
+
return img, pad
|
| 80 |
+
|
| 81 |
+
def open_json(file):
|
| 82 |
+
with open(file) as json_file:
|
| 83 |
+
data = json.load(json_file)
|
| 84 |
+
|
| 85 |
+
return data
|
| 86 |
+
|
| 87 |
+
def extract_cbr(file, out_dir):
|
| 88 |
+
patoolib.extract_archive(file, outdir = out_dir, verbosity = 1)
|
| 89 |
+
|
| 90 |
+
def create_cbz(file_path, files):
|
| 91 |
+
patoolib.create_archive(file_path, files, verbosity = 1)
|
| 92 |
+
|
| 93 |
+
def subfolder_image_search(start_folder):
|
| 94 |
+
return [x.as_posix() for x in Path(".").rglob("*.[pPjJ][nNpP][gG]")]
|
| 95 |
+
|
| 96 |
+
def remove_folder(folder_path):
|
| 97 |
+
rmtree(folder_path)
|
| 98 |
+
|
| 99 |
+
def sorted_alphanumeric(data):
|
| 100 |
+
convert = lambda text: int(text) if text.isdigit() else text.lower()
|
| 101 |
+
alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ]
|
| 102 |
+
return sorted(data, key=alphanum_key)
|
utils/xdog.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from cv2 import resize, INTER_LANCZOS4, INTER_AREA
|
| 2 |
+
from skimage.color import rgb2gray
|
| 3 |
+
import numpy as np
|
| 4 |
+
from scipy.ndimage.filters import gaussian_filter
|
| 5 |
+
from skimage.filters import threshold_otsu
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
|
| 8 |
+
class XDoGSketcher:
|
| 9 |
+
|
| 10 |
+
def __init__(self, gamma = 0.95, phi = 89.25, eps = -0.1, k = 8, sigma = 0.5, mult = 1):
|
| 11 |
+
self.params = {}
|
| 12 |
+
self.params['gamma'] = gamma
|
| 13 |
+
self.params['phi'] = phi
|
| 14 |
+
self.params['eps'] = eps
|
| 15 |
+
self.params['k'] = k
|
| 16 |
+
self.params['sigma'] = sigma
|
| 17 |
+
|
| 18 |
+
self.params['mult'] = mult
|
| 19 |
+
|
| 20 |
+
def _xdog(self, im, **transform_params):
|
| 21 |
+
# Source : https://github.com/CemalUnal/XDoG-Filter
|
| 22 |
+
# Reference : XDoG: An eXtended difference-of-Gaussians compendium including advanced image stylization
|
| 23 |
+
# Link : http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.365.151&rep=rep1&type=pdf
|
| 24 |
+
|
| 25 |
+
if im.shape[2] == 3:
|
| 26 |
+
im = rgb2gray(im)
|
| 27 |
+
|
| 28 |
+
imf1 = gaussian_filter(im, transform_params['sigma'])
|
| 29 |
+
imf2 = gaussian_filter(im, transform_params['sigma'] * transform_params['k'])
|
| 30 |
+
imdiff = imf1 - transform_params['gamma'] * imf2
|
| 31 |
+
imdiff = (imdiff < transform_params['eps']) * 1.0 \
|
| 32 |
+
+ (imdiff >= transform_params['eps']) * (1.0 + np.tanh(transform_params['phi'] * imdiff))
|
| 33 |
+
imdiff -= imdiff.min()
|
| 34 |
+
imdiff /= imdiff.max()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
th = threshold_otsu(imdiff)
|
| 38 |
+
imdiff = imdiff >= th
|
| 39 |
+
|
| 40 |
+
imdiff = imdiff.astype('float32')
|
| 41 |
+
|
| 42 |
+
return imdiff
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_sketch(self, image, **kwargs):
|
| 46 |
+
current_params = self.params.copy()
|
| 47 |
+
|
| 48 |
+
for key in kwargs.keys():
|
| 49 |
+
if key in current_params.keys():
|
| 50 |
+
current_params[key] = kwargs[key]
|
| 51 |
+
|
| 52 |
+
result_image = self._xdog(image, **current_params)
|
| 53 |
+
|
| 54 |
+
return result_image
|
| 55 |
+
|
| 56 |
+
def get_sketch_with_resize(self, image, **kwargs):
|
| 57 |
+
if 'mult' in kwargs.keys():
|
| 58 |
+
mult = kwargs['mult']
|
| 59 |
+
else:
|
| 60 |
+
mult = self.params['mult']
|
| 61 |
+
|
| 62 |
+
temp_image = resize(image, (image.shape[1] * mult, image.shape[0] * mult), interpolation = INTER_LANCZOS4)
|
| 63 |
+
temp_image = self.get_sketch(temp_image, **kwargs)
|
| 64 |
+
image = resize(temp_image, (image.shape[1], image.shape[0]), interpolation = INTER_AREA)
|
| 65 |
+
|
| 66 |
+
return image
|
| 67 |
+
|
| 68 |
+
|