Upload 35 files
Browse files- .dockerignore +6 -0
- .gitignore +10 -0
- Dockerfile +11 -0
- configs/train_config.json +10 -0
- configs/xdog_config.json +8 -0
- dataset/__pycache__/datasets.cpython-39.pyc +0 -0
- dataset/datasets.py +107 -0
- denoising/denoiser.py +113 -0
- denoising/functions.py +101 -0
- denoising/models.py +100 -0
- denoising/models/.gitkeep +0 -0
- denoising/utils.py +66 -0
- drawing.py +165 -0
- environment.yml +43 -0
- inference.py +215 -0
- model/__pycache__/extractor.cpython-39.pyc +0 -0
- model/__pycache__/models.cpython-39.pyc +0 -0
- model/extractor.pth +3 -0
- model/extractor.py +127 -0
- model/models.py +422 -0
- model/vgg16-397923af.pth +3 -0
- readme.md +20 -0
- requirements.txt +41 -0
- run_drawing.sh +1 -0
- static/js/draw.js +120 -0
- static/temp_images/.gitkeep +0 -0
- templates/drawing.html +206 -0
- templates/submit.html +11 -0
- templates/upload.html +20 -0
- train.py +294 -0
- utils/__pycache__/utils.cpython-39.pyc +0 -0
- utils/dataset_utils.py +141 -0
- utils/utils.py +102 -0
- utils/xdog.py +68 -0
- web.py +108 -0
.dockerignore
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.ipynb
|
| 2 |
+
|
| 3 |
+
model/*.pth
|
| 4 |
+
|
| 5 |
+
temp_colorization/
|
| 6 |
+
__pycache__/
|
.gitignore
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.ipynb
|
| 2 |
+
*.pth
|
| 3 |
+
*.zip
|
| 4 |
+
|
| 5 |
+
__pycache__/
|
| 6 |
+
temp_colorization/
|
| 7 |
+
|
| 8 |
+
static/temp_images/*
|
| 9 |
+
|
| 10 |
+
!.gitkeep
|
Dockerfile
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM pytorch/pytorch:1.6.0-cuda10.1-cudnn7-runtime
|
| 2 |
+
|
| 3 |
+
RUN apt-get update && apt-get install -y libglib2.0-0 libsm6 libxext6 libxrender-dev
|
| 4 |
+
|
| 5 |
+
COPY . .
|
| 6 |
+
|
| 7 |
+
RUN pip install --no-cache-dir -r ./requirements.txt
|
| 8 |
+
|
| 9 |
+
EXPOSE 5000
|
| 10 |
+
|
| 11 |
+
CMD gunicorn --timeout 200 -w 3 -b 0.0.0.0:5000 drawing:app
|
configs/train_config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"generator_lr" : 1e-4,
|
| 3 |
+
"discriminator_lr" : 4e-4,
|
| 4 |
+
"epochs" : 15,
|
| 5 |
+
"lr_decrease_epoch" : 10,
|
| 6 |
+
"finetuning_generator_lr" : 1e-6,
|
| 7 |
+
"finetuning_iterations" : 3500,
|
| 8 |
+
"batch_size" : 4,
|
| 9 |
+
"number_of_mults" : 3
|
| 10 |
+
}
|
configs/xdog_config.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"sigma" : 0.5,
|
| 3 |
+
"k" : 8,
|
| 4 |
+
"phi" : 89.25,
|
| 5 |
+
"gamma" : 0.95,
|
| 6 |
+
"eps" : -0.1,
|
| 7 |
+
"mult" : 7
|
| 8 |
+
}
|
dataset/__pycache__/datasets.cpython-39.pyc
ADDED
|
Binary file (3.52 kB). View file
|
|
|
dataset/datasets.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
import torchvision.transforms as transforms
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from utils.utils import generate_mask
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TrainDataset(torch.utils.data.Dataset):
    """Paired training samples: a color image, its line art and its dfm map.

    Color images are listed from ``<data_path>/color``; the matching line-art
    and ``*_dfm.png`` files are read from ``<data_path>/bw``.

    Args:
        data_path: dataset root containing the ``color`` and ``bw`` folders.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries.
        mults_amount: number of line-art variants per color image; when > 1 a
            variant index is drawn at random for every access.
    """

    def __init__(self, data_path, transform = None, mults_amount = 1):
        self.data_path = data_path
        self.transform = transform
        self.mults_amount = mults_amount
        self.ToTensor = transforms.ToTensor()
        self.data = os.listdir(os.path.join(data_path, 'color'))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(bw_img, color_img, hint, dfm_img)`` tensors for sample ``idx``."""
        image_name = self.data[idx]
        color_img = plt.imread(os.path.join(self.data_path, 'color', image_name))

        if self.mults_amount > 1:
            mult_number = np.random.choice(range(self.mults_amount))
            stem = image_name[:image_name.rfind('.')]
            bw_name = f'{stem}_{mult_number}.png'
            dfm_name = f'{stem}_{mult_number}_dfm.png'
        else:
            bw_name = image_name
            # NOTE(review): suffix is '0_dfm.png' (no underscore before the 0),
            # unlike FineTuningDataset's '_dfm.png' -- confirm against the
            # file names the dataset generator actually writes.
            dfm_name = os.path.splitext(image_name)[0] + '0_dfm.png'

        bw_dir = os.path.join(self.data_path, 'bw')
        line_art = np.expand_dims(plt.imread(os.path.join(bw_dir, bw_name)), 2)
        dfm_map = np.expand_dims(plt.imread(os.path.join(bw_dir, dfm_name)), 2)

        # Pack both maps into one array so the transform treats them as a
        # single mask alongside the color image.
        bw_img = np.concatenate([line_art, dfm_map], axis = 2)

        if self.transform:
            result = self.transform(image = color_img, mask = bw_img)
            color_img = result['image']
            bw_img = result['mask']

        # Unpack the two maps again.
        dfm_img = self.ToTensor(bw_img[:, :, 1])
        bw_img = self.ToTensor(bw_img[:, :, 0])

        # Rescale the color tensor with (x - 0.5) / 0.5.
        color_img = (self.ToTensor(color_img) - 0.5) / 0.5

        mask = generate_mask(bw_img.shape[1], bw_img.shape[2])
        hint = torch.cat((color_img * mask, mask), 0)

        return bw_img, color_img, hint, dfm_img
|
| 61 |
+
|
| 62 |
+
class FineTuningDataset(torch.utils.data.Dataset):
    """Unpaired fine-tuning samples: real manga pages plus unrelated color images.

    Pages and their ``*_dfm.png`` companions are read from
    ``<data_path>/real_manga``; color images are read from ``<data_path>/color``
    and shuffled once at construction time.

    Args:
        data_path: dataset root containing ``real_manga`` and ``color``.
        transform: optional callable invoked as ``transform(image=...)`` and
            ``transform(image=..., mask=...)`` that returns a mapping with
            ``'image'`` (and ``'mask'``) entries.
        mult_amount: number of variants per page; when > 1 a variant index is
            drawn at random for every access.
    """

    def __init__(self, data_path, transform = None, mult_amount = 1):
        manga_dir = os.path.join(data_path, 'real_manga')
        # Files containing '_dfm' are companions of a page, not samples.
        self.data = [x for x in os.listdir(manga_dir) if x.find('_dfm') == -1]
        self.color_data = [x for x in os.listdir(os.path.join(data_path, 'color'))]
        self.data_path = data_path
        self.transform = transform
        self.mults_amount = mult_amount

        np.random.shuffle(self.color_data)

        self.ToTensor = transforms.ToTensor()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(bw_img, dfm_img, color_img)`` tensors for sample ``idx``.

        Raises:
            IndexError: if ``idx`` is out of range for the manga pages, or the
                ``color`` folder is empty.
        """
        if not self.color_data:
            raise IndexError('no color images found in ' + os.path.join(self.data_path, 'color'))

        # The dataset length is the number of manga pages, so wrap the index
        # instead of failing when there are fewer color images than pages.
        color_name = self.color_data[idx % len(self.color_data)]
        color_img = plt.imread(os.path.join(self.data_path, 'color', color_name))

        image_name = self.data[idx]
        if self.mults_amount > 1:
            # Use the drawn variant index (0 .. mults_amount - 1), as
            # TrainDataset does. The previous code drew it but then used
            # mults_amount itself, which is never a valid variant index.
            # NOTE(review): assumes variants are named '<stem>_<k>.png' /
            # '<stem>_<k>_dfm.png' like in TrainDataset -- confirm.
            mult_number = np.random.choice(range(self.mults_amount))
            stem = image_name[:image_name.rfind('.')]
            bw_name = stem + '_' + str(mult_number) + '.png'
            dfm_name = stem + '_' + str(mult_number) + '_dfm.png'
        else:
            bw_name = image_name
            dfm_name = os.path.splitext(image_name)[0] + '_dfm.png'

        manga_dir = os.path.join(self.data_path, 'real_manga')
        bw_img = np.expand_dims(plt.imread(os.path.join(manga_dir, bw_name)), 2)
        dfm_img = np.expand_dims(plt.imread(os.path.join(manga_dir, dfm_name)), 2)

        if self.transform:
            # The color image is unrelated to the page, so it is transformed
            # independently of the page/dfm pair.
            result = self.transform(image = color_img)
            color_img = result['image']

            result = self.transform(image = bw_img, mask = dfm_img)
            bw_img = result['image']
            dfm_img = result['mask']

        color_img = self.ToTensor(color_img)
        bw_img = self.ToTensor(bw_img)
        dfm_img = self.ToTensor(dfm_img)

        # Rescale the color tensor with (x - 0.5) / 0.5.
        color_img = (color_img - 0.5) / 0.5

        return bw_img, dfm_img, color_img
|
denoising/denoiser.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Denoise an image with the FFDNet denoising method
|
| 3 |
+
|
| 4 |
+
Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
|
| 5 |
+
|
| 6 |
+
This program is free software: you can use, modify and/or
|
| 7 |
+
redistribute it under the terms of the GNU General Public
|
| 8 |
+
License as published by the Free Software Foundation, either
|
| 9 |
+
version 3 of the License, or (at your option) any later
|
| 10 |
+
version. You should have received a copy of this license along
|
| 11 |
+
this program. If not, see <http://www.gnu.org/licenses/>.
|
| 12 |
+
"""
|
| 13 |
+
import os
|
| 14 |
+
import argparse
|
| 15 |
+
import time
|
| 16 |
+
import numpy as np
|
| 17 |
+
import cv2
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
from torch.autograd import Variable
|
| 21 |
+
from denoising.models import FFDNet
|
| 22 |
+
from denoising.utils import normalize, variable_to_cv2_image, remove_dataparallel_wrapper, is_rgb
|
| 23 |
+
|
| 24 |
+
class FFDNetDenoiser:
    """Wrapper that loads FFDNet weights and denoises numpy images.

    Args:
        _device: ``'cuda'`` selects the GPU code path; any other value runs
            on the CPU.
        _sigma: default noise level on a 0-255 scale (stored divided by 255).
        _weights_dir: directory holding ``net_rgb.pth`` / ``net_gray.pth``.
        _in_ch: number of model input channels (3 selects the RGB weights,
            anything else the grayscale weights).
    """

    def __init__(self, _device, _sigma = 25, _weights_dir = 'denoising/models/', _in_ch = 3):
        self.sigma = _sigma / 255
        self.weights_dir = _weights_dir
        self.channels = _in_ch
        self.device = _device

        self.model = FFDNet(num_input_channels = _in_ch)
        self.load_weights()
        self.model.eval()

    def load_weights(self):
        """Load the checkpoint that matches ``self.channels`` into ``self.model``."""
        weights_name = 'net_rgb.pth' if self.channels == 3 else 'net_gray.pth'
        weights_path = os.path.join(self.weights_dir, weights_name)
        if self.device == 'cuda':
            state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
            device_ids = [0]
            # Wrap the model so its parameter names carry the same 'module.'
            # prefix the CPU branch strips from the checkpoint.
            self.model = nn.DataParallel(self.model, device_ids=device_ids).cuda()
        else:
            state_dict = torch.load(weights_path, map_location='cpu')
            # CPU mode: remove the DataParallel wrapper
            state_dict = remove_dataparallel_wrapper(state_dict)
        self.model.load_state_dict(state_dict)

    def get_denoised_image(self, imorig, sigma = None):
        """Denoise ``imorig`` and return the result of ``variable_to_cv2_image``.

        Args:
            imorig: numpy image shaped (H, W), (H, W, 1) or (H, W, C).
                Single-channel input is replicated to three channels. Values
                are divided by 255 when the maximum exceeds 1.2.
            sigma: noise level on a 0-255 scale; ``None`` uses the instance
                default.

        Returns:
            The denoised image as produced by ``variable_to_cv2_image``.
            Images whose longest side exceeds 1200 px are downscaled first, so
            the output can be smaller than the input.
        """
        cur_sigma = self.sigma if sigma is None else sigma / 255

        # Promote single-channel input to three identical channels. An
        # (H, W, 1) array must be repeated directly: expanding it first
        # produced a 4-D array that broke the transpose below.
        if len(imorig.shape) < 3:
            imorig = np.repeat(np.expand_dims(imorig, 2), 3, 2)
        elif imorig.shape[2] == 1:
            imorig = np.repeat(imorig, 3, 2)

        longest_side = max(imorig.shape[0], imorig.shape[1])
        if longest_side > 1200:
            ratio = longest_side / 1200
            imorig = cv2.resize(imorig, (int(imorig.shape[1] / ratio), int(imorig.shape[0] / ratio)), interpolation = cv2.INTER_AREA)

        # HWC -> CHW
        imorig = imorig.transpose(2, 0, 1)

        if (imorig.max() > 1.2):
            imorig = normalize(imorig)
        imorig = np.expand_dims(imorig, 0)

        # Handle odd sizes: the network halves H and W, so pad an odd dimension
        # by duplicating its last row/column and crop it again afterwards.
        expanded_h = imorig.shape[2] % 2 == 1
        expanded_w = imorig.shape[3] % 2 == 1
        if expanded_h:
            imorig = np.concatenate((imorig, imorig[:, :, -1:, :]), axis=2)
        if expanded_w:
            imorig = np.concatenate((imorig, imorig[:, :, :, -1:]), axis=3)

        # Sets data type according to CPU or GPU modes
        if self.device == 'cuda':
            dtype = torch.cuda.FloatTensor
        else:
            dtype = torch.FloatTensor

        with torch.no_grad():
            imnoisy = torch.Tensor(imorig).type(dtype)
            nsigma = torch.FloatTensor([cur_sigma]).type(dtype)

            # Estimate noise and subtract it to the input image
            im_noise_estim = self.model(imnoisy, nsigma)
            outim = torch.clamp(imnoisy-im_noise_estim, 0., 1.)

        if expanded_h:
            outim = outim[:, :, :-1, :]
        if expanded_w:
            outim = outim[:, :, :, :-1]

        return variable_to_cv2_image(outim)
|
denoising/functions.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Functions implementing custom NN layers
|
| 3 |
+
|
| 4 |
+
Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
|
| 5 |
+
|
| 6 |
+
This program is free software: you can use, modify and/or
|
| 7 |
+
redistribute it under the terms of the GNU General Public
|
| 8 |
+
License as published by the Free Software Foundation, either
|
| 9 |
+
version 3 of the License, or (at your option) any later
|
| 10 |
+
version. You should have received a copy of this license along
|
| 11 |
+
this program. If not, see <http://www.gnu.org/licenses/>.
|
| 12 |
+
"""
|
| 13 |
+
import torch
|
| 14 |
+
from torch.autograd import Function, Variable
|
| 15 |
+
|
| 16 |
+
def concatenate_input_noise_map(input, noise_sigma):
    r"""Implements the first layer of FFDNet.

    Every non-overlapping 2x2 patch of the input is spread over four channels,
    turning each CxHxW image into a 4*CxH/2xW/2 array, and a constant noise
    map with C channels is prepended along the channel dimension.

    Args:
        input: batch containing CxHxW images
        noise_sigma: tensor with one noise value per batch element; it fills
            the CxH/2xW/2 noise map of that element

    Returns:
        Tensor of size N x 5*C x H/2 x W/2 (noise map first, then the
        de-interleaved image).
    """
    N, C, H, W = input.size()
    sca = 2
    Cout = sca * sca * C
    Hout = H // sca
    Wout = W // sca
    offsets = ((0, 0), (0, 1), (1, 0), (1, 1))

    # Gather the four sub-sampled views next to each other on a new axis and
    # fold that axis into the channels: output channel c*4 + k holds offset k
    # of input channel c. The result is always float32, like the buffer the
    # views used to be copied into.
    views = [input[:, :, row::sca, col::sca] for row, col in offsets]
    downsampledfeatures = torch.stack(views, dim=2).reshape(N, Cout, Hout, Wout).float()

    # Build the CxH/2xW/2 noise map
    noise_map = noise_sigma.view(N, 1, 1, 1).repeat(1, C, Hout, Wout)

    # concatenate de-interleaved mosaic with noise map
    return torch.cat((noise_map, downsampledfeatures), 1)
|
| 54 |
+
|
| 55 |
+
class UpSampleFeaturesFunction(Function):
    r"""Custom autograd Function implementing the last layer of FFDNet.

    It is the inverse of the de-interleaving done by
    concatenate_input_noise_map(): each image of size CxH/2xW/2 in the batch
    is rearranged into an image of size C/4xHxW.
    """
    @staticmethod
    def forward(ctx, input):
        N, Cin, Hin, Win = input.size()
        sca = 2
        sca2 = sca * sca

        assert (Cin%sca2 == 0), 'Invalid input dimensions: number of channels should be divisible by 4'

        # Same tensor type as the input, initialised to zero.
        result = torch.zeros((N, Cin // sca2, Hin * sca, Win * sca)).type(input.type())

        # Channels k, k+4, k+8, ... supply pixel position k of every 2x2 patch.
        offsets = ((0, 0), (0, 1), (1, 0), (1, 1))
        for idx, (row, col) in enumerate(offsets):
            result[:, :, row::sca, col::sca] = input[:, idx:Cin:sca2, :, :]

        return result

    @staticmethod
    def backward(ctx, grad_output):
        N, Cg_out, Hg_out, Wg_out = grad_output.size()
        sca = 2
        sca2 = sca * sca
        Cg_in = sca2 * Cg_out

        # Build output
        grad_input = torch.zeros((N, Cg_in, Hg_out // sca, Wg_out // sca)).type(grad_output.data.type())

        # Populate output: route every gradient pixel back to the channel it
        # was taken from in forward().
        offsets = ((0, 0), (0, 1), (1, 0), (1, 1))
        for idx, (row, col) in enumerate(offsets):
            grad_input[:, idx:Cg_in:sca2, :, :] = grad_output.data[:, :, row::sca, col::sca]

        return Variable(grad_input)

# Alias functions
upsamplefeatures = UpSampleFeaturesFunction.apply
|
denoising/models.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Definition of the FFDNet model and its custom layers
|
| 3 |
+
|
| 4 |
+
Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
|
| 5 |
+
|
| 6 |
+
This program is free software: you can use, modify and/or
|
| 7 |
+
redistribute it under the terms of the GNU General Public
|
| 8 |
+
License as published by the Free Software Foundation, either
|
| 9 |
+
version 3 of the License, or (at your option) any later
|
| 10 |
+
version. You should have received a copy of this license along
|
| 11 |
+
this program. If not, see <http://www.gnu.org/licenses/>.
|
| 12 |
+
"""
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from torch.autograd import Variable
|
| 15 |
+
import denoising.functions as functions
|
| 16 |
+
|
| 17 |
+
class UpSampleFeatures(nn.Module):
    r"""Module wrapper around functions.upsamplefeatures, the last layer of
    FFDNet
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        upsampled = functions.upsamplefeatures(x)
        return upsampled
|
| 24 |
+
|
| 25 |
+
class IntermediateDnCNN(nn.Module):
    r"""Implements the middle part of the FFDNet architecture, which
    is basically a DnCNN net: Conv+ReLU, then (num_conv_layers - 2) blocks of
    Conv+BatchNorm+ReLU, then a final Conv. All convolutions are 3x3 with
    padding 1 and no bias.

    Args:
        input_features: channels of the input; 5 (grayscale) or 15 (RGB)
        middle_features: channels of the hidden layers
        num_conv_layers: total number of convolutional layers

    Raises:
        Exception: if input_features is neither 5 nor 15
    """
    def __init__(self, input_features, middle_features, num_conv_layers):
        super().__init__()
        self.kernel_size = 3
        self.padding = 1
        self.input_features = input_features
        self.num_conv_layers = num_conv_layers
        self.middle_features = middle_features

        # input channels -> output channels (5: grayscale image, 15: RGB image)
        outputs_by_input = {5: 4, 15: 12}
        if self.input_features not in outputs_by_input:
            raise Exception('Invalid number of input features')
        self.output_features = outputs_by_input[self.input_features]

        def conv(in_channels, out_channels):
            return nn.Conv2d(in_channels=in_channels,
                             out_channels=out_channels,
                             kernel_size=self.kernel_size,
                             padding=self.padding,
                             bias=False)

        layers = [conv(self.input_features, self.middle_features), nn.ReLU(inplace=True)]
        for _ in range(self.num_conv_layers - 2):
            layers += [conv(self.middle_features, self.middle_features),
                       nn.BatchNorm2d(self.middle_features),
                       nn.ReLU(inplace=True)]
        layers.append(conv(self.middle_features, self.output_features))

        # The misspelled attribute name is part of the state_dict keys, so it
        # must stay as is for saved weights to load.
        self.itermediate_dncnn = nn.Sequential(*layers)

    def forward(self, x):
        return self.itermediate_dncnn(x)
|
| 67 |
+
|
| 68 |
+
class FFDNet(nn.Module):
    r"""Implements the FFDNet architecture: de-interleave the input and attach
    the noise map, run the intermediate DnCNN, then re-interleave its output.

    Args:
        num_input_channels: 1 for grayscale images, 3 for RGB images

    Raises:
        Exception: if num_input_channels is neither 1 nor 3
    """
    def __init__(self, num_input_channels):
        super().__init__()
        self.num_input_channels = num_input_channels

        # channels -> (feature maps, conv layers, downsampled channels, output features)
        presets = {
            1: (64, 15, 5, 4),    # Grayscale image
            3: (96, 12, 15, 12),  # RGB image
        }
        if self.num_input_channels not in presets:
            raise Exception('Invalid number of input features')
        (self.num_feature_maps,
         self.num_conv_layers,
         self.downsampled_channels,
         self.output_features) = presets[self.num_input_channels]

        self.intermediate_dncnn = IntermediateDnCNN(
            input_features=self.downsampled_channels,
            middle_features=self.num_feature_maps,
            num_conv_layers=self.num_conv_layers)
        self.upsamplefeatures = UpSampleFeatures()

    def forward(self, x, noise_sigma):
        """Return the noise predicted for batch ``x`` at level ``noise_sigma``."""
        mosaic = functions.concatenate_input_noise_map(x.data, noise_sigma.data)
        features = self.intermediate_dncnn(Variable(mosaic))
        return self.upsamplefeatures(features)
|
denoising/models/.gitkeep
ADDED
|
File without changes
|
denoising/utils.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Different utilities such as orthogonalization of weights, initialization of
|
| 3 |
+
loggers, etc
|
| 4 |
+
|
| 5 |
+
Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
|
| 6 |
+
|
| 7 |
+
This program is free software: you can use, modify and/or
|
| 8 |
+
redistribute it under the terms of the GNU General Public
|
| 9 |
+
License as published by the Free Software Foundation, either
|
| 10 |
+
version 3 of the License, or (at your option) any later
|
| 11 |
+
version. You should have received a copy of this license along
|
| 12 |
+
this program. If not, see <http://www.gnu.org/licenses/>.
|
| 13 |
+
"""
|
| 14 |
+
import numpy as np
|
| 15 |
+
import cv2
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def variable_to_cv2_image(varim):
    r"""Converts the first image of a batch tensor to a uint8 OpenCV image

    Values are multiplied by 255 and clipped to [0, 255]. A one-channel batch
    yields a 2-D array; a three-channel batch yields an HxWx3 array passed
    through cv2's RGB-to-BGR conversion.

    Args:
        varim: a torch.autograd.Variable (4-D, channels on dimension 1)

    Raises:
        Exception: if the number of channels is neither 1 nor 3
    """
    nchannels = varim.size()[1]
    if nchannels not in (1, 3):
        raise Exception('Number of color channels not supported')

    pixels = varim.data.cpu().numpy()
    if nchannels == 1:
        return (pixels[0, 0, :]*255.).clip(0, 255).astype(np.uint8)

    # CHW -> HWC, then swap the channel order for OpenCV.
    bgr = cv2.cvtColor(pixels[0].transpose(1, 2, 0), cv2.COLOR_RGB2BGR)
    return (bgr*255.).clip(0, 255).astype(np.uint8)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def normalize(data):
    r"""Divides data by 255 and returns the result as float32
    """
    scaled = data / 255.
    return np.float32(scaled)
|
| 38 |
+
|
| 39 |
+
def remove_dataparallel_wrapper(state_dict):
    r"""Converts a DataParallel model to a normal one by removing the "module."
    wrapper in the module dictionary

    Only keys that actually start with "module." are renamed; other keys are
    copied unchanged, so a checkpoint saved without DataParallel is no longer
    corrupted by blindly dropping the first seven characters of every key.

    Args:
        state_dict: a torch.nn.DataParallel state dictionary

    Returns:
        A new OrderedDict with the same values and order, keyed without the
        "module." prefix.
    """
    from collections import OrderedDict

    prefix = 'module.'
    new_state_dict = OrderedDict()
    for k, vl in state_dict.items():
        # remove 'module.' of DataParallel
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = vl

    return new_state_dict
|
| 54 |
+
|
| 55 |
+
def is_rgb(im_path):
    r""" Returns True if the image in im_path has three dimensions and its
    first three channels are not all (approximately) equal. Prints the
    verdict and the image shape.
    """
    from skimage.io import imread

    im = imread(im_path)
    rgb = False
    if len(im.shape) == 3:
        channels_match = np.allclose(im[...,0], im[...,1]) and np.allclose(im[...,2], im[...,1])
        rgb = not channels_match
    print("rgb: {}".format(rgb))
    print("im shape: {}".format(im.shape))
    return rgb
|
drawing.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from datetime import datetime
|
| 3 |
+
import base64
|
| 4 |
+
import random
|
| 5 |
+
import string
|
| 6 |
+
import shutil
|
| 7 |
+
import torch
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import numpy as np
|
| 10 |
+
from flask import Flask, request, jsonify, abort, redirect, url_for, render_template, send_file, Response
|
| 11 |
+
from flask_wtf import FlaskForm
|
| 12 |
+
from wtforms import StringField, FileField, BooleanField, DecimalField
|
| 13 |
+
from wtforms.validators import DataRequired
|
| 14 |
+
from flask import after_this_request
|
| 15 |
+
|
| 16 |
+
from model.models import Colorizer, Generator
|
| 17 |
+
from model.extractor import get_seresnext_extractor
|
| 18 |
+
from utils.xdog import XDoGSketcher
|
| 19 |
+
from utils.utils import open_json
|
| 20 |
+
from denoising.denoiser import FFDNetDenoiser
|
| 21 |
+
from inference import process_image_with_hint
|
| 22 |
+
from utils.utils import resize_pad
|
| 23 |
+
from utils.dataset_utils import get_sketch
|
| 24 |
+
|
| 25 |
+
def generate_id(size=25, chars=string.ascii_letters + string.digits):
    """Return a random string of ``size`` characters drawn from ``chars``.

    Characters are picked with ``random.SystemRandom``.
    """
    rng = random.SystemRandom()
    picked = [rng.choice(chars) for _ in range(size)]
    return ''.join(picked)
|
| 27 |
+
|
| 28 |
+
def generate_unique_id(current_ids = set()):
    """Return a fresh id from ``generate_id`` and record it in ``current_ids``.

    NOTE: the mutable default is relied upon here. When the function is called
    without an argument, the same set object is reused by every call and acts
    as the registry of ids already handed out by this process.
    """
    while True:
        candidate = generate_id()
        if candidate not in current_ids:
            break

    current_ids.add(candidate)

    return candidate
|
| 36 |
+
|
| 37 |
+
app = Flask(__name__)
# Secrets can be supplied through the environment so that deployments do not
# depend on the constants committed with the source. The historic values remain
# as fallbacks so that existing setups keep working.
app.config.update(dict(
    SECRET_KEY=os.environ.get('SECRET_KEY', "lol kek"),
    WTF_CSRF_SECRET_KEY=os.environ.get('WTF_CSRF_SECRET_KEY', "cheburek")
))
|
| 42 |
+
|
| 43 |
+
# Module-level setup: runs once per process at import time.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

colorizer = torch.jit.load('./model/colorizer.zip', map_location=torch.device(device))

# Sketcher with its defaults overridden by the entries of the JSON config
# that match a parameter it already has.
sketcher = XDoGSketcher()
xdog_config = open_json('configs/xdog_config.json')
for key in xdog_config:
    if key in sketcher.params:
        sketcher.params[key] = xdog_config[key]

denoiser = FFDNetDenoiser(device)

# Shared argument bundle passed to process_image_with_hint on every request.
color_args = {
    'colorizer': colorizer,
    'sketcher': sketcher,
    'device': device,
    'dfm': True,
    'auto_hint': False,
    'ignore_gray': False,
    'denoiser': denoiser,
    'denoiser_sigma': 25,
}
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class SubmitForm(FlaskForm):
    """Upload form with a single required file field."""
    file = FileField(validators=[DataRequired()])
|
| 63 |
+
|
| 64 |
+
def preprocess_image(file_id, ext):
    """Prepare the working files for an uploaded image.

    Reads ``static/temp_images/<file_id>/original<ext>``, resizes and denoises
    it, derives the sketch and dfm maps, and writes ``resized_<H>_<W>.png``,
    ``bw.png``, ``dfm.png`` and an all-zero ``hint.png`` into the same
    directory. The original upload is deleted.

    Args:
        file_id: name of the per-upload directory.
        ext: extension of the uploaded file, including the leading dot.
    """
    work_dir = os.path.join('static', 'temp_images', file_id)
    source_path = os.path.join(work_dir, 'original') + ext

    resized, _ = resize_pad(plt.imread(source_path))
    # NOTE(review): the denoiser returns the output of variable_to_cv2_image,
    # which swaps RGB to BGR for 3-channel images, while plt.imsave below
    # treats the array as RGB -- confirm this is intended.
    resized = denoiser.get_denoised_image(resized, 25)
    bw, dfm = get_sketch(resized, sketcher, True)

    height, width = resized.shape[0], resized.shape[1]
    # The size is encoded in the file name; the drawing page parses it back.
    plt.imsave(os.path.join(work_dir, f'resized_{height}_{width}.png'), resized)
    plt.imsave(os.path.join(work_dir, 'bw.png'), bw, cmap = 'gray')
    plt.imsave(os.path.join(work_dir, 'dfm.png'), dfm, cmap = 'gray')
    os.remove(source_path)

    # Start with an empty 4-channel hint layer of the same size.
    blank_hint = np.zeros((height, width, 4), dtype = np.float32)
    plt.imsave(os.path.join(work_dir, 'hint.png'), blank_hint)
|
| 81 |
+
|
| 82 |
+
@app.route('/', methods=['GET', 'POST'])
def upload():
    """Show the upload form and, on a valid submission, prepare the image.

    A submitted ``.jpg``/``.jpeg``/``.png`` file (extension compared
    case-insensitively) is stored in a fresh directory under
    ``static/temp_images`` and preprocessed; the client is then redirected to
    ``/draw/<file_id>``. Unsupported extensions and processing failures answer
    with HTTP 400; on failure the partially created directory is removed.
    """
    form = SubmitForm()
    if form.validate_on_submit():
        input_data = form.file.data

        _, ext = os.path.splitext(input_data.filename)
        # Accept upper-case extensions such as '.JPG' as well.
        ext = ext.lower()

        if ext not in ('.jpg', '.png', '.jpeg'):
            return abort(400)

        file_id = generate_unique_id()
        directory = os.path.join('static', 'temp_images', file_id)
        original_filename = os.path.join(directory, 'original') + ext

        try:
            os.mkdir(directory)
            input_data.save(original_filename)

            preprocess_image(file_id, ext)
        except Exception:
            # Log the traceback instead of hiding the cause, and let
            # KeyboardInterrupt/SystemExit propagate.
            app.logger.exception('Failed to colorize')
            if os.path.exists(directory):
                shutil.rmtree(directory)
            return abort(400)

        return redirect(f'/draw/{file_id}')

    return render_template("upload.html", form = form)
|
| 113 |
+
|
| 114 |
+
@app.route('/img/<file_id>')
def show_image(file_id):
    """Return an ``<img>`` tag pointing at the colorized result of ``file_id``.

    Answers 404 when the upload directory does not exist. A random query
    string is appended to the image URL.
    """
    image_dir = os.path.join('static', 'temp_images', str(file_id))
    if not os.path.exists(image_dir):
        abort(404)

    cache_buster = random.randint(1, 1000000)
    return f'<img src="/static/temp_images/{file_id}/colorized.png?{cache_buster}">'
|
| 119 |
+
|
| 120 |
+
def colorize_image(file_id):
    """Colorize a session from its stored ``bw.png``, ``dfm.png`` and ``hint.png``.

    Only the first channel of ``bw.png`` and ``dfm.png`` is used; the hint
    image is passed on with all of its channels. Returns whatever
    ``process_image_with_hint`` returns.
    """
    session_dir = os.path.join('static', 'temp_images', file_id)

    def read(name):
        return plt.imread(os.path.join(session_dir, name))

    bw_channel = read('bw.png')[..., :1]
    dfm_channel = read('dfm.png')[..., :1]
    hint_image = read('hint.png')

    return process_image_with_hint(bw_channel, dfm_channel, hint_image, color_args)
|
| 128 |
+
|
| 129 |
+
@app.route('/colorize', methods=['POST'])
def colorize():
    """Store the posted hint image, recolorize the session, return the image URL.

    Form fields read:
        save_file_id: session id, optionally prefixed by a path; only the part
            after the last ``/`` is used.
        save_image: base64 payload, optionally prefixed by a ``...,`` header
            (everything up to the first comma is dropped).

    Returns the relative URL of the freshly written ``colorized.png`` with a
    random query string appended. Answers 404 when the id does not name an
    existing session folder.
    """
    file_id = request.form['save_file_id']
    file_id = file_id[file_id.rfind('/') + 1:]

    img_data = request.form['save_image']
    img_data = img_data[img_data.find(',') + 1:]

    directory_path = os.path.join('static', 'temp_images', file_id)

    # The id comes straight from the client: refuse empty / dot names (which
    # would resolve outside the session folder) and unknown sessions instead
    # of writing there or failing later with an unhandled error.
    if file_id in ('', '.', '..') or not os.path.isdir(directory_path):
        abort(404)

    with open(os.path.join(directory_path, 'hint.png'), "wb") as im:
        # base64.decodestring was removed in Python 3.9; decodebytes is the
        # direct replacement.
        im.write(base64.decodebytes(str.encode(img_data)))

    result = colorize_image(file_id)

    plt.imsave(os.path.join(directory_path, 'colorized.png'), result)

    src_path = f'../static/temp_images/{file_id}/colorized.png?{random.randint(1,1000000)}'

    return src_path
|
| 150 |
+
|
| 151 |
+
@app.route('/draw/<file_id>', methods=['GET', 'POST'])
def paintapp(file_id):
    """Render the drawing page for an existing session.

    The canvas size is recovered from the name of the session's
    ``resized_<height>_<width>.png`` file. Answers 404 when the session folder
    or the resized image is missing, and 405 for anything but GET (the view
    has no POST behaviour; previously such requests fell through and returned
    ``None``).
    """
    if request.method != 'GET':
        abort(405)

    directory_path = os.path.join('static', 'temp_images', str(file_id))
    if not os.path.exists(directory_path):
        abort(404)

    resized_candidates = [x for x in os.listdir(directory_path) if x.startswith('resized_')]
    if not resized_candidates:
        # Session folder exists but preprocessing output is gone/incomplete.
        abort(404)
    resized_name = resized_candidates[0]

    # File name layout is "resized_<height>_<width>.png".
    split = os.path.splitext(resized_name)[0].split('_')
    width = int(split[2])
    height = int(split[1])

    return render_template("drawing.html", height = height, width = width, img_path = os.path.join('temp_images', str(file_id), resized_name))
|
environment.yml
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: manga
|
| 2 |
+
channels:
|
| 3 |
+
- defaults
|
| 4 |
+
dependencies:
|
| 5 |
+
- openai
|
| 6 |
+
- tiktoken
|
| 7 |
+
- setuptools
|
| 8 |
+
- numpy
|
| 9 |
+
- scipy
|
| 10 |
+
- matplotlib
|
| 11 |
+
- opencv-python
|
| 12 |
+
- scikit-learn
|
| 13 |
+
- tb-nightly
|
| 14 |
+
- flask
|
| 15 |
+
- gunicorn
|
| 16 |
+
- flask-wtf
|
| 17 |
+
- snowy
|
| 18 |
+
- scikit-image
|
| 19 |
+
- patool
|
| 20 |
+
- albumentations
|
| 21 |
+
- PyYAML
|
| 22 |
+
- qudida
|
| 23 |
+
- joblib
|
| 24 |
+
- threadpoolctl
|
| 25 |
+
- typing-extensions
|
| 26 |
+
- imageio
|
| 27 |
+
- pillow
|
| 28 |
+
- PyWavelets
|
| 29 |
+
- tifffile
|
| 30 |
+
- imutils
|
| 31 |
+
- cycler
|
| 32 |
+
- kiwisolver
|
| 33 |
+
- pyparsing
|
| 34 |
+
- python-dateutil
|
| 35 |
+
- pipdeptree
|
| 36 |
+
- numba
|
| 37 |
+
- llvmlite
|
| 38 |
+
- torch
|
| 39 |
+
- future
|
| 40 |
+
- tqdm
|
| 41 |
+
- colorama
|
| 42 |
+
- wheel
|
| 43 |
+
- torchvision
|
inference.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
from utils.dataset_utils import get_sketch
|
| 5 |
+
from utils.utils import resize_pad, generate_mask, extract_cbr, create_cbz, sorted_alphanumeric, subfolder_image_search, remove_folder
|
| 6 |
+
from torchvision.transforms import ToTensor
|
| 7 |
+
import os
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import argparse
|
| 10 |
+
from model.models import Colorizer, Generator
|
| 11 |
+
from model.extractor import get_seresnext_extractor
|
| 12 |
+
from utils.xdog import XDoGSketcher
|
| 13 |
+
from utils.utils import open_json
|
| 14 |
+
import sys
|
| 15 |
+
from denoising.denoiser import FFDNetDenoiser
|
| 16 |
+
|
| 17 |
+
def colorize_without_hint(inp, color_args):
    """Colorize ``inp`` without user hints, optionally refining with an auto hint.

    A first pass runs the colorizer with an all-zero 4-channel hint. When
    ``color_args['auto_hint']`` is set, a mask from ``generate_mask`` selects
    pixels of that first result to feed back as the hint for a second pass;
    with ``color_args['ignore_gray']`` the mask is additionally restricted to
    pixels whose summed pairwise channel difference exceeds ``60 / 255``.

    Returns the colorizer's first output of the last pass performed.
    """
    def run_colorizer(hint):
        with torch.no_grad():
            colored, _ = color_args['colorizer'](torch.cat([inp, hint], 1))
        return colored

    empty_hint = torch.zeros(1, 4, inp.shape[2], inp.shape[3]).float().to(color_args['device'])
    fake_color = run_colorizer(empty_hint)

    if not color_args['auto_hint']:
        return fake_color

    mask = generate_mask(fake_color.shape[2], fake_color.shape[3], full = False, prob = 1, sigma = color_args['auto_hint_sigma']).unsqueeze(0)
    mask = mask.to(color_args['device'])

    if color_args['ignore_gray']:
        c0, c1, c2 = fake_color[:, 0], fake_color[:, 1], fake_color[:, 2]
        channel_spread = torch.abs(c0 - c1) + torch.abs(c0 - c2) + torch.abs(c1 - c2)
        is_colorful = (channel_spread > 60 / 255).float().unsqueeze(1)
        # Sum equals 2 only where both the mask and the colorfulness test hold.
        mask = ((mask + is_colorful) == 2).float()

    auto_hint = torch.cat([fake_color * mask, mask], 1)
    return run_colorizer(auto_hint)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def process_image(image, color_args, to_tensor = ToTensor()):
    """Run the hint-free colorization pipeline on one image array.

    Steps: ``resize_pad`` -> optional denoising -> ``get_sketch`` ->
    ``colorize_without_hint`` -> rescale with ``* 0.5 + 0.5`` -> crop the
    padding reported by ``resize_pad`` from the bottom/right.

    Returns the colorized image as a NumPy array in height-width-channel order.
    """
    image, pad = resize_pad(image)

    denoiser = color_args['denoiser']
    if denoiser is not None:
        image = denoiser.get_denoised_image(image, color_args['denoiser_sigma'])

    bw, dfm = get_sketch(image, color_args['sketcher'], color_args['dfm'])

    bw_batch = to_tensor(bw).unsqueeze(0).to(color_args['device'])
    dfm_batch = to_tensor(dfm).unsqueeze(0).to(color_args['device'])

    output = colorize_without_hint(torch.cat([bw_batch, dfm_batch], 1), color_args)
    result = output[0].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5

    # Strip the padding again; a zero pad must be skipped because
    # ``[:-0]`` would yield an empty array.
    pad_rows, pad_cols = pad[0], pad[1]
    if pad_rows != 0:
        result = result[:-pad_rows]
    if pad_cols != 0:
        result = result[:, :-pad_cols]

    return result
|
| 63 |
+
|
| 64 |
+
def colorize_with_hint(inp, color_args):
    """Run ``color_args['colorizer']`` on ``inp`` with gradients disabled.

    The colorizer is expected to return a pair; only its first element is
    returned.
    """
    with torch.no_grad():
        outputs = color_args['colorizer'](inp)
        colored, _ = outputs

    return colored
|
| 69 |
+
|
| 70 |
+
def process_image_with_hint(bw, dfm, hint, color_args, to_tensor = ToTensor()):
    """Colorize using a user-supplied hint image.

    ``hint`` is indexed as height x width x channels: its first three channels
    are rescaled with ``(x - 0.5) / 0.5`` and multiplied by the remaining
    channel(s), which serve as the hint mask. The network input is the
    channel-wise concatenation of ``bw``, ``dfm`` and that masked hint.

    Returns the result as a NumPy array in height-width-channel order after
    rescaling the output with ``* 0.5 + 0.5``.
    """
    device = color_args['device']
    bw = to_tensor(bw).unsqueeze(0).to(device)
    dfm = to_tensor(dfm).unsqueeze(0).to(device)

    hint_colors = torch.FloatTensor(hint[..., :3]).permute(2, 0, 1)
    hint_colors = (hint_colors - 0.5) / 0.5
    mask = torch.FloatTensor(hint[..., 3:]).permute(2, 0, 1)
    i_hint = torch.cat([hint_colors * mask, mask], 0).unsqueeze(0).to(device)

    network_input = torch.cat([bw, dfm, i_hint], 1)
    output = colorize_with_hint(network_input, color_args)

    return output[0].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5
|
| 82 |
+
|
| 83 |
+
def colorize_single_image(file_path, save_path, color_args):
    """Colorize the image at ``file_path`` and write the result to ``save_path``.

    Returns:
        True when the image was read, colorized and saved; False when any
        ordinary error occurred (a message naming the file and the error is
        printed).

    A KeyboardInterrupt terminates the program with exit code 0.
    """
    try:
        image = plt.imread(file_path)

        colorization = process_image(image, color_args)

        plt.imsave(save_path, colorization)

        return True
    except KeyboardInterrupt:
        sys.exit(0)
    except Exception as error:
        # Best-effort batch behaviour is kept, but SystemExit is no longer
        # swallowed and the cause of the failure is reported.
        print('Failed to colorize {}'.format(file_path))
        print('Reason: {!r}'.format(error), file = sys.stderr)
        return False
|
| 97 |
+
|
| 98 |
+
def colorize_images(source_path, target_path, color_args):
    """Colorize every file found directly in ``source_path``.

    Each result is written to ``target_path`` under the source file's name
    with a ``.png`` extension. Sub-directories are skipped: the command-line
    entry point creates its output folder inside the source folder, and
    trying to read a directory as an image only produced a failure message.
    """
    for image_name in os.listdir(source_path):
        file_path = os.path.join(source_path, image_name)

        if not os.path.isfile(file_path):
            continue

        name, ext = os.path.splitext(image_name)
        if ext != '.png':
            image_name = name + '.png'

        save_path = os.path.join(target_path, image_name)
        colorize_single_image(file_path, save_path, color_args)
|
| 110 |
+
|
| 111 |
+
def colorize_cbr(file_path, color_args):
    """Colorize every image inside a comic archive and repack it as ``.cbz``.

    The archive is extracted into the working folder ``temp_colorization``,
    each image found there is colorized (results are saved as ``.png`` next to
    the extracted file), and the resulting files are packed into
    ``<archive name>_colorized.cbz`` beside the input archive. Images that
    fail to colorize are packed unchanged.

    Returns the path of the created archive.
    """
    file_name = os.path.splitext(os.path.basename(file_path))[0]
    temp_path = 'temp_colorization'

    if not os.path.exists(temp_path):
        os.makedirs(temp_path)

    try:
        extract_cbr(file_path, temp_path)

        images = subfolder_image_search(temp_path)

        result_images = []
        for image_path in images:
            save_path = image_path

            path, ext = os.path.splitext(save_path)
            if (ext != '.png'):
                save_path = path + '.png'

            res_flag = colorize_single_image(image_path, save_path, color_args)

            result_images.append(save_path if res_flag else image_path)

        result_name = os.path.join(os.path.dirname(file_path), file_name + '_colorized.cbz')

        create_cbz(result_name, result_images)
    finally:
        # Always clear the working folder, even when extraction or packing
        # fails; leftovers would otherwise end up in the next archive.
        remove_folder(temp_path)

    return result_name
|
| 141 |
+
|
| 142 |
+
def parse_args():
    """Parse the command-line options of the inference script.

    Options: ``--path`` (required), ``--generator`` / ``--extractor`` weight
    files, ``--sigma`` (float), the switches ``--gpu``, ``--auto``
    (dest ``autohint``), ``--ignore_grey`` (dest ``ignore``) and
    ``--no_denoise`` (clears ``denoiser``), and ``--denoiser_sigma`` (int).

    Returns the ``argparse.Namespace`` produced from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("-p", "--path", required=True)
    parser.add_argument("-gen", "--generator", default = 'model/generator.pth')
    parser.add_argument("-ext", "--extractor", default = 'model/extractor.pth')
    parser.add_argument("-s", "--sigma", type = float, default = 0.003)

    # (short flag, long flag, destination, action, default)
    switches = (
        ('-g', '--gpu', 'gpu', 'store_true', False),
        ('-ah', '--auto', 'autohint', 'store_true', False),
        ('-ig', '--ignore_grey', 'ignore', 'store_true', False),
        ('-nd', '--no_denoise', 'denoiser', 'store_false', True),
    )
    for short_flag, long_flag, dest, action, default in switches:
        parser.add_argument(short_flag, long_flag, dest = dest, action = action, default = default)

    parser.add_argument("-ds", "--denoiser_sigma", type = int, default = 25)

    return parser.parse_args()
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
if __name__ == "__main__":
    # Command-line entry point: build the colorizer, sketcher and optional
    # denoiser from the parsed options, then colorize a folder of images,
    # a comic archive, or a single image depending on what --path points to.

    args = parse_args()

    device = 'cuda' if args.gpu else 'cpu'

    # map_location lets checkpoints that were saved from a CUDA device be
    # loaded when running on CPU (and vice versa).
    generator = Generator()
    generator.load_state_dict(torch.load(args.generator, map_location = device))

    extractor = get_seresnext_extractor()
    extractor.load_state_dict(torch.load(args.extractor, map_location = device))

    colorizer = Colorizer(generator, extractor)
    colorizer = colorizer.eval().to(device)

    # Override only those sketcher parameters that the config file names.
    sketcher = XDoGSketcher()
    xdog_config = open_json('configs/xdog_config.json')
    for key in xdog_config.keys():
        if key in sketcher.params:
            sketcher.params[key] = xdog_config[key]

    denoiser = None
    if args.denoiser:
        denoiser = FFDNetDenoiser(device, args.denoiser_sigma)

    color_args = {'colorizer':colorizer, 'sketcher':sketcher, 'auto_hint':args.autohint, 'auto_hint_sigma':args.sigma,\
                 'ignore_gray':args.ignore, 'device':device, 'dfm' : True, 'denoiser':denoiser, 'denoiser_sigma' : args.denoiser_sigma}

    if os.path.isdir(args.path):
        # Results for a folder go into a "colorization" sub-folder.
        colorization_path = os.path.join(args.path, 'colorization')
        if not os.path.exists(colorization_path):
            os.makedirs(colorization_path)

        colorize_images(args.path, colorization_path, color_args)

    elif os.path.isfile(args.path):

        split = os.path.splitext(args.path)

        if split[1].lower() in ('.cbr', '.cbz', '.rar', '.zip'):
            colorize_cbr(args.path, color_args)
        elif split[1].lower() in ('.jpg', '.png', '.jpeg'):
            # '.jpeg' was previously misspelled as ',jpeg', so such files were
            # rejected as "Wrong format".
            new_image_path = split[0] + '_colorized' + '.png'

            colorize_single_image(args.path, new_image_path, color_args)
        else:
            print('Wrong format')
    else:
        print('Wrong path')
|
| 215 |
+
|
model/__pycache__/extractor.cpython-39.pyc
ADDED
|
Binary file (3.91 kB). View file
|
|
|
model/__pycache__/models.cpython-39.pyc
ADDED
|
Binary file (13.5 kB). View file
|
|
|
model/extractor.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ee3c59f02ac8c59298fd9b819fa33d2efa168847e15e4be39b35c286f7c18607
|
| 3 |
+
size 6340842
|
model/extractor.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
'''https://github.com/blandocs/Tag2Pix/blob/master/model/pretrained.py'''
|
| 6 |
+
|
| 7 |
+
# Pretrained version
|
| 8 |
+
class Selayer(nn.Module):
    """Channel gate: scales each channel of the input by a learned factor in (0, 1).

    The factor is computed from the globally average-pooled input through two
    1x1 convolutions (``inplanes -> inplanes // 16 -> inplanes``) with a ReLU
    in between and a sigmoid at the end.
    """

    def __init__(self, inplanes):
        super().__init__()
        reduced = inplanes // 16
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(inplanes, reduced, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(reduced, inplanes, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        gate = self.conv1(self.global_avgpool(x))
        gate = self.conv2(self.relu(gate))
        return x * self.sigmoid(gate)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class BottleneckX_Origin(nn.Module):
    """Residual bottleneck with a grouped 3x3 convolution and a ``Selayer`` gate.

    Main path: 1x1 conv -> BN -> ReLU -> grouped 3x3 conv (``groups=cardinality``,
    carries the stride) -> BN -> ReLU -> 1x1 conv -> BN -> Selayer. The input
    (passed through ``downsample`` when one is given) is added to the result
    before the final ReLU. The block outputs ``planes * 4`` channels.
    """

    expansion = 4

    def __init__(self, inplanes, planes, cardinality, stride=1, downsample=None):
        super().__init__()
        mid_planes = planes * 2
        out_planes = planes * 4

        self.conv1 = nn.Conv2d(inplanes, mid_planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_planes)

        self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride,
                               padding=1, groups=cardinality, bias=False)
        self.bn2 = nn.BatchNorm2d(mid_planes)

        self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        self.selayer = Selayer(out_planes)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.selayer(self.bn3(self.conv3(out)))

        # Identity shortcut unless a projection was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
|
| 71 |
+
|
| 72 |
+
class SEResNeXt_extractor(nn.Module):
    """Feature extractor: a convolutional stem followed by two residual stages.

    Stem: 7x7 stride-2 conv -> BN -> ReLU -> 3x3 stride-2 max-pool. The two
    stages are built from ``block`` with 64 and 128 base planes, using
    ``layers[0]`` and ``layers[1]`` blocks respectively; further entries of
    ``layers`` are not used. Convolution weights are initialised from a
    normal distribution with std ``sqrt(2 / (kh * kw * out_channels))``,
    BatchNorm weights to 1 and biases to 0.
    """

    def __init__(self, block, layers, input_channels=3, cardinality=32):
        super().__init__()
        self.cardinality = cardinality
        self.inplanes = 64
        self.input_channels = input_channels

        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)

        self._init_weights()

    def _init_weights(self):
        """Initialise every Conv2d and BatchNorm2d contained in this module."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` blocks; only the first one carries ``stride``."""
        out_planes = planes * block.expansion

        # A projection shortcut is needed whenever the first block changes
        # the spatial size or the channel count.
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, self.cardinality, stride, downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes, self.cardinality))

        return nn.Sequential(*stage)

    def forward(self, x):
        for step in (self.conv1, self.bn1, self.relu, self.maxpool,
                     self.layer1, self.layer2):
            x = step(x)

        return x
|
| 125 |
+
|
| 126 |
+
def get_seresnext_extractor():
    """Build the extractor used by this project: ``BottleneckX_Origin`` blocks,
    layer counts ``[3, 4, 6, 3]`` and a single input channel."""
    layer_counts = [3, 4, 6, 3]
    return SEResNeXt_extractor(BottleneckX_Origin, layer_counts, input_channels=1)
|
model/models.py
ADDED
|
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torchvision.models as M
|
| 5 |
+
import math
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
from torch.nn import Parameter
|
| 8 |
+
|
| 9 |
+
'''https://github.com/orashi/AlacGAN/blob/master/models/standard.py'''
|
| 10 |
+
|
| 11 |
+
def l2normalize(v, eps=1e-12):
    """Return ``v`` divided by its norm; ``eps`` keeps the division finite for a zero tensor."""
    magnitude = v.norm()
    return v / (magnitude + eps)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SpectralNorm(nn.Module):
    """Wrap ``module`` so that its parameter ``name`` is rescaled on every forward.

    The original parameter is re-registered as ``<name>_bar`` together with
    two non-trainable vectors ``<name>_u`` / ``<name>_v``. Each forward call
    runs ``power_iterations`` power-iteration steps on the weight viewed as a
    matrix of shape (first dim, rest), computes ``sigma = u . (W v)`` and sets
    ``module.<name>`` to ``W / sigma`` before delegating to ``module.forward``.
    """

    def __init__(self, module, name='weight', power_iterations=1):
        super().__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        """Refresh ``u`` / ``v`` and install the rescaled weight on the wrapped module."""
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        rows = w.data.shape[0]
        for _ in range(self.power_iterations):
            w_mat = w.view(rows, -1).data
            v.data = l2normalize(torch.mv(torch.t(w_mat), u.data))
            u.data = l2normalize(torch.mv(w_mat, v.data))

        # Computed on the parameters themselves (not ``.data``) so the
        # division below stays part of the autograd graph.
        sigma = u.dot(w.view(rows, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        """Tell whether the wrapped module already carries the ``_u``/``_v``/``_bar`` attributes."""
        return all(hasattr(self.module, self.name + suffix)
                   for suffix in ("_u", "_v", "_bar"))

    def _make_params(self):
        """Replace the module's parameter by ``_bar`` and create normalised ``_u`` / ``_v``."""
        w = getattr(self.module, self.name)
        rows = w.data.shape[0]
        cols = w.view(rows, -1).data.shape[1]

        u = Parameter(w.data.new(rows).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(cols).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        # Drop the original registration so that the plain tensor assigned in
        # _update_u_v can take over the attribute name.
        del self.module._parameters[self.name]

        for suffix, param in (("_u", u), ("_v", v), ("_bar", w_bar)):
            self.module.register_parameter(self.name + suffix, param)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)
|
| 69 |
+
|
| 70 |
+
class Selayer(nn.Module):
    """Channel gate: scales each channel of the input by a learned factor in (0, 1).

    The factor is computed from the globally average-pooled input through two
    1x1 convolutions (``inplanes -> inplanes // 16 -> inplanes``) with a ReLU
    in between and a sigmoid at the end.
    """

    def __init__(self, inplanes):
        super().__init__()
        reduced = inplanes // 16
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(inplanes, reduced, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(reduced, inplanes, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        gate = self.conv1(self.global_avgpool(x))
        gate = self.conv2(self.relu(gate))
        return x * self.sigmoid(gate)
|
| 87 |
+
|
| 88 |
+
class SelayerSpectr(nn.Module):
    """Same channel gate as ``Selayer``, with both 1x1 convolutions wrapped in ``SpectralNorm``."""

    def __init__(self, inplanes):
        super().__init__()
        reduced = inplanes // 16
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = SpectralNorm(nn.Conv2d(inplanes, reduced, kernel_size=1, stride=1))
        self.conv2 = SpectralNorm(nn.Conv2d(reduced, inplanes, kernel_size=1, stride=1))
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        gate = self.conv1(self.global_avgpool(x))
        gate = self.conv2(self.relu(gate))
        return x * self.sigmoid(gate)
|
| 105 |
+
|
| 106 |
+
class ResNeXtBottleneck(nn.Module):
    """Residual block: 1x1 reduce -> grouped (optionally dilated) conv -> 1x1 expand -> Selayer.

    The inner width is ``out_channels // 2``. LeakyReLU(0.2) follows the first
    two convolutions. The shortcut is the identity, or a 2x2 average pool
    when ``stride != 1``; its output is added to the gated main path.
    """

    def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
        super().__init__()
        width = out_channels // 2
        self.out_channels = out_channels
        self.conv_reduce = nn.Conv2d(in_channels, width, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_conv = nn.Conv2d(width, width, kernel_size=2 + stride, stride=stride,
                                   padding=dilate, dilation=dilate,
                                   groups=cardinality,
                                   bias=False)
        self.conv_expand = nn.Conv2d(width, out_channels, kernel_size=1, stride=1, padding=0, bias=False)

        self.shortcut = nn.Sequential()
        if stride != 1:
            self.shortcut.add_module('shortcut', nn.AvgPool2d(2, stride=2))

        self.selayer = Selayer(out_channels)

    def forward(self, x):
        branch = F.leaky_relu(self.conv_reduce.forward(x), 0.2, True)
        branch = F.leaky_relu(self.conv_conv.forward(branch), 0.2, True)
        branch = self.selayer(self.conv_expand.forward(branch))

        x = self.shortcut.forward(x)
        return x + branch
|
| 133 |
+
|
| 134 |
+
class SpectrResNeXtBottleneck(nn.Module):
    """Variant of ``ResNeXtBottleneck`` whose convolutions are wrapped in ``SpectralNorm``.

    Layout: 1x1 reduce -> grouped (optionally dilated) conv -> 1x1 expand ->
    ``SelayerSpectr``, with LeakyReLU(0.2) after the first two convolutions.
    The shortcut is the identity, or a 2x2 average pool when ``stride != 1``.
    """

    def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
        super().__init__()
        width = out_channels // 2
        self.out_channels = out_channels
        self.conv_reduce = SpectralNorm(
            nn.Conv2d(in_channels, width, kernel_size=1, stride=1, padding=0, bias=False))
        self.conv_conv = SpectralNorm(
            nn.Conv2d(width, width, kernel_size=2 + stride, stride=stride,
                      padding=dilate, dilation=dilate,
                      groups=cardinality,
                      bias=False))
        self.conv_expand = SpectralNorm(
            nn.Conv2d(width, out_channels, kernel_size=1, stride=1, padding=0, bias=False))

        self.shortcut = nn.Sequential()
        if stride != 1:
            self.shortcut.add_module('shortcut', nn.AvgPool2d(2, stride=2))

        self.selayer = SelayerSpectr(out_channels)

    def forward(self, x):
        branch = F.leaky_relu(self.conv_reduce.forward(x), 0.2, True)
        branch = F.leaky_relu(self.conv_conv.forward(branch), 0.2, True)
        branch = self.selayer(self.conv_expand.forward(branch))

        x = self.shortcut.forward(x)
        return x + branch
|
| 161 |
+
|
| 162 |
+
class FeatureConv(nn.Module):
    """Three bias-free 3x3 convolutions, each followed by ReLU; the middle one has stride 2.

    Batch normalisation is switched off by the hard-coded ``no_bn`` flag, so
    the resulting ``network`` holds six modules (conv/ReLU alternating).
    """

    def __init__(self, input_dim=512, output_dim=512):
        super().__init__()

        no_bn = True

        def conv3x3(in_dim, stride):
            return nn.Conv2d(in_dim, output_dim, kernel_size=3, stride=stride, padding=1, bias=False)

        stages = [conv3x3(input_dim, 1)]
        if not no_bn:
            stages.append(nn.BatchNorm2d(output_dim))
        stages += [nn.ReLU(inplace=True), conv3x3(output_dim, 2)]
        if not no_bn:
            stages.append(nn.BatchNorm2d(output_dim))
        stages += [nn.ReLU(inplace=True), conv3x3(output_dim, 1), nn.ReLU(inplace=True)]

        self.network = nn.Sequential(*stages)

    def forward(self, x):
        return self.network(x)
|
| 182 |
+
|
| 183 |
+
class Generator(nn.Module):
    """Encoder/decoder generator with skip connections and an auxiliary decoder head.

    The encoder (``to0``..``to4``) halves the resolution four times. Each
    ``tunnelN`` stage concatenates the previous output with the matching
    encoder feature map, runs a stack of ``ResNeXtBottleneck`` blocks and
    upsamples by 2 with ``PixelShuffle``. ``deconv_for_decoder`` is a second,
    transposed-convolution head fed from the ``tunnel4`` output.

    NOTE(review): the encoder widths (32..512) and the first transposed
    convolution (256 input channels) are literals, so the channel counts only
    line up when ``ngf == 64``; other values would fail at the concatenations.
    """

    def __init__(self, ngf=64):
        super(Generator, self).__init__()

        # Processes the externally supplied feature map (512 -> 512 channels, stride 2).
        self.feature_conv = FeatureConv()

        # Encoder: to0 keeps the resolution, to1..to4 each halve it.
        self.to0 = self._make_encoder_block_first(6, 32)
        self.to1 = self._make_encoder_block(32, 64)
        self.to2 = self._make_encoder_block(64, 128)
        self.to3 = self._make_encoder_block(128, 256)
        self.to4 = self._make_encoder_block(256, 512)

        # Auxiliary head: three stride-2 transposed convolutions (x8 upsampling)
        # followed by a stride-1 one that maps to 3 channels in [-1, 1] (Tanh).
        # The size notes below are the original author's; they presumably
        # assume a 256x256 network input -- TODO confirm.
        self.deconv_for_decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1), # output is 64 * 64
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1), # output is 128 * 128
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1), # output is 256 * 256
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(32, 3, 3, stride=1, padding=1, output_padding=0), # output is 256 * 256
            nn.Tanh(),
        )

        # Deepest stage: 20 residual blocks without dilation.
        tunnel4 = nn.Sequential(*[ResNeXtBottleneck(ngf * 8, ngf * 8, cardinality=32, dilate=1) for _ in range(20)])

        # Input is the deepest encoder map concatenated with the 512-channel
        # processed feature map, hence ``ngf * 8 + 512`` input channels.
        self.tunnel4 = nn.Sequential(nn.Conv2d(ngf * 8 + 512, ngf * 8, kernel_size=3, stride=1, padding=1),
                                     nn.LeakyReLU(0.2, True),
                                     tunnel4,
                                     nn.Conv2d(ngf * 8, ngf * 4 * 4, kernel_size=3, stride=1, padding=1),
                                     nn.PixelShuffle(2),
                                     nn.LeakyReLU(0.2, True)
                                     )  # 64

        # Stages 3 and 2 use a dilation schedule 1,1,2,2,4,4,2,1.
        depth = 2
        tunnel = [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=1) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=2) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=4) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=2),
                   ResNeXtBottleneck(ngf * 4, ngf * 4, cardinality=32, dilate=1)]
        tunnel3 = nn.Sequential(*tunnel)

        self.tunnel3 = nn.Sequential(nn.Conv2d(ngf * 8, ngf * 4, kernel_size=3, stride=1, padding=1),
                                     nn.LeakyReLU(0.2, True),
                                     tunnel3,
                                     nn.Conv2d(ngf * 4, ngf * 2 * 4, kernel_size=3, stride=1, padding=1),
                                     nn.PixelShuffle(2),
                                     nn.LeakyReLU(0.2, True)
                                     )  # 128

        tunnel = [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=1) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=2) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=4) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=2),
                   ResNeXtBottleneck(ngf * 2, ngf * 2, cardinality=32, dilate=1)]
        tunnel2 = nn.Sequential(*tunnel)

        self.tunnel2 = nn.Sequential(nn.Conv2d(ngf * 4, ngf * 2, kernel_size=3, stride=1, padding=1),
                                     nn.LeakyReLU(0.2, True),
                                     tunnel2,
                                     nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=1, padding=1),
                                     nn.PixelShuffle(2),
                                     nn.LeakyReLU(0.2, True)
                                     )

        # Shallowest stage: one block per dilation (1,2,4,2,1) with cardinality 16.
        tunnel = [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=1)]
        tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=2)]
        tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=4)]
        tunnel += [ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=2),
                   ResNeXtBottleneck(ngf, ngf, cardinality=16, dilate=1)]
        tunnel1 = nn.Sequential(*tunnel)

        self.tunnel1 = nn.Sequential(nn.Conv2d(ngf * 2, ngf, kernel_size=3, stride=1, padding=1),
                                     nn.LeakyReLU(0.2, True),
                                     tunnel1,
                                     nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=1, padding=1),
                                     nn.PixelShuffle(2),
                                     nn.LeakyReLU(0.2, True)
                                     )

        # Final projection to 3 channels; its input is tunnel1's output
        # concatenated with the first encoder feature map.
        self.exit = nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)


    def _make_encoder_block(self, inplanes, planes):
        """Return a downsampling encoder block: stride-2 3x3 conv, then a stride-1 3x3 conv, each with LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Conv2d(inplanes, planes, 3, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(planes, planes, 3, 1, 1),
            nn.LeakyReLU(0.2),
        )

    def _make_encoder_block_first(self, inplanes, planes):
        """Return the first encoder block: two stride-1 3x3 convs with LeakyReLU(0.2), so resolution is preserved."""
        return nn.Sequential(
            nn.Conv2d(inplanes, planes, 3, 1, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(planes, planes, 3, 1, 1),
            nn.LeakyReLU(0.2),
        )

    def forward(self, sketch, sketch_feat):
        """Produce the main output and the auxiliary decoder output.

        Args:
            sketch: 6-channel input tensor (``to0`` expects 6 input channels).
            sketch_feat: 512-channel feature map; after ``feature_conv``
                (stride 2) it must match the spatial size of the deepest
                encoder map so the two can be concatenated.

        Returns:
            Tuple ``(x, decoder_output)``: the tanh-activated 3-channel main
            output and the 3-channel output of ``deconv_for_decoder``.
        """

        # Encoder pass; every intermediate map is kept for the skip connections.
        x0 = self.to0(sketch)
        x1 = self.to1(x0)
        x2 = self.to2(x1)
        x3 = self.to3(x2)
        x4 = self.to4(x3)

        sketch_feat = self.feature_conv(sketch_feat)

        out = self.tunnel4(torch.cat([x4, sketch_feat], 1))

        # Decoder pass: concatenate along the channel axis with the encoder
        # map of matching resolution before each stage.
        x = self.tunnel3(torch.cat([out, x3], 1))
        x = self.tunnel2(torch.cat([x, x2], 1))
        x = self.tunnel1(torch.cat([x, x1], 1))
        x = torch.tanh(self.exit(torch.cat([x, x0], 1)))

        # The auxiliary head branches off the tunnel4 output, not the final features.
        decoder_output = self.deconv_for_decoder(out)

        return x, decoder_output
|
| 304 |
+
'''
|
| 305 |
+
class Colorizer(nn.Module):
|
| 306 |
+
def __init__(self, extractor_path = 'model/model.pth'):
|
| 307 |
+
super(Colorizer, self).__init__()
|
| 308 |
+
|
| 309 |
+
self.generator = Generator()
|
| 310 |
+
self.extractor = se_resnext_half(dump_path=extractor_path, num_classes=370, input_channels=1)
|
| 311 |
+
|
| 312 |
+
def extractor_eval(self):
|
| 313 |
+
for param in self.extractor.parameters():
|
| 314 |
+
param.requires_grad = False
|
| 315 |
+
|
| 316 |
+
def extractor_train(self):
|
| 317 |
+
for param in extractor.parameters():
|
| 318 |
+
param.requires_grad = True
|
| 319 |
+
|
| 320 |
+
def forward(self, x, extractor_grad = False):
|
| 321 |
+
|
| 322 |
+
if extractor_grad:
|
| 323 |
+
features = self.extractor(x[:, 0:1])
|
| 324 |
+
else:
|
| 325 |
+
with torch.no_grad():
|
| 326 |
+
features = self.extractor(x[:, 0:1]).detach()
|
| 327 |
+
|
| 328 |
+
fake, guide = self.generator(x, features)
|
| 329 |
+
|
| 330 |
+
return fake, guide
|
| 331 |
+
'''
|
| 332 |
+
|
| 333 |
+
class Colorizer(nn.Module):
    """Wrap a feature extractor and a generator into one forward pass.

    The extractor receives only the first channel of the input; the generator
    receives the full input together with the extracted features.

    Args:
        generator_model: callable/module invoked as ``generator(x, features)``
            and returning a ``(fake, guide)`` pair.
        extractor_model: module invoked as ``extractor(x[:, 0:1])``.
    """

    def __init__(self, generator_model, extractor_model):
        super(Colorizer, self).__init__()

        self.generator = generator_model
        self.extractor = extractor_model

    def load_generator_weights(self, gen_weights):
        """Load a state dict into the generator."""
        self.generator.load_state_dict(gen_weights)

    def load_extractor_weights(self, ext_weights):
        """Load a state dict into the extractor."""
        self.extractor.load_state_dict(ext_weights)

    def extractor_eval(self):
        """Freeze the extractor: disable gradients and switch it to eval mode."""
        for param in self.extractor.parameters():
            param.requires_grad = False
        self.extractor.eval()

    def extractor_train(self):
        """Unfreeze the extractor: enable gradients and switch it to train mode."""
        # Fixed: this used the bare name ``extractor`` (undefined here), which
        # raised NameError on every call.
        for param in self.extractor.parameters():
            param.requires_grad = True
        self.extractor.train()

    def forward(self, x, extractor_grad = False):
        """Return the generator's ``(fake, guide)`` outputs for ``x``.

        Args:
            x: input batch; channel 0 is what the extractor sees.
            extractor_grad: when False (default) the extractor runs under
                ``torch.no_grad()`` and its output is detached, so no
                gradient flows into it.
        """
        if extractor_grad:
            features = self.extractor(x[:, 0:1])
        else:
            with torch.no_grad():
                features = self.extractor(x[:, 0:1]).detach()

        fake, guide = self.generator(x, features)

        return fake, guide
|
| 367 |
+
|
| 368 |
+
class Discriminator(nn.Module):
    """Spectrally-normalised ResNeXt-style discriminator producing one score per image.

    Args:
        ndf: base channel width. The stem and the final linear layer now
            follow ``ndf`` (previously hard-coded to 64 and 512, which only
            matched the rest of the network for the default ``ndf=64``).
            With the default value the architecture is unchanged.
    """

    def __init__(self, ndf=64):
        super(Discriminator, self).__init__()

        # The trailing "# 128", "# 64", ... notes are the original author's
        # spatial-size annotations; they presumably assume a 512x512 input --
        # TODO confirm.
        self.feed = nn.Sequential(SpectralNorm(nn.Conv2d(3, ndf, 3, 1, 1)),
                                  nn.LeakyReLU(0.2, True),
                                  SpectralNorm(nn.Conv2d(ndf, ndf, 3, 2, 0)),
                                  nn.LeakyReLU(0.2, True),

                                  SpectrResNeXtBottleneck(ndf, ndf, cardinality=8, dilate=1),
                                  SpectrResNeXtBottleneck(ndf, ndf, cardinality=8, dilate=1, stride=2),  # 128
                                  SpectralNorm(nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=False)),
                                  nn.LeakyReLU(0.2, True),

                                  SpectrResNeXtBottleneck(ndf * 2, ndf * 2, cardinality=8, dilate=1),
                                  SpectrResNeXtBottleneck(ndf * 2, ndf * 2, cardinality=8, dilate=1, stride=2),  # 64
                                  SpectralNorm(nn.Conv2d(ndf * 2, ndf * 4, kernel_size=1, stride=1, padding=0, bias=False)),
                                  nn.LeakyReLU(0.2, True),

                                  SpectrResNeXtBottleneck(ndf * 4, ndf * 4, cardinality=8, dilate=1),
                                  SpectrResNeXtBottleneck(ndf * 4, ndf * 4, cardinality=8, dilate=1, stride=2),  # 32
                                  # NOTE(review): padding=1 on a 1x1 convolution grows the map by a
                                  # zero border; the other 1x1 convs use padding=0. Kept as-is so the
                                  # outputs stay identical -- confirm whether it is intentional.
                                  SpectralNorm(nn.Conv2d(ndf * 4, ndf * 8, kernel_size=1, stride=1, padding=1, bias=False)),
                                  nn.LeakyReLU(0.2, True),
                                  SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
                                  SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1, stride=2),  # 16
                                  SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
                                  SpectrResNeXtBottleneck(ndf * 8, ndf * 8, cardinality=8, dilate=1),
                                  # Global pooling makes the head independent of the input resolution.
                                  nn.AdaptiveAvgPool2d((1, 1))
                                  )

        # One scalar score per image from the pooled ``ndf * 8`` features.
        self.out = nn.Linear(ndf * 8, 1)

    def forward(self, color):
        """Return a ``(batch, 1)`` tensor of scores for the 3-channel batch ``color``."""
        x = self.feed(color)

        # Flatten the pooled (batch, C, 1, 1) map to (batch, C) for the linear layer.
        out = self.out(x.view(color.size(0), -1))
        return out
|
| 408 |
+
|
| 409 |
+
class Content(nn.Module):
    """Feature extractor built from the first nine layers of a VGG16 ``features`` stack.

    Args:
        path: location of a VGG16 state dict that ``torch.load`` can read.
    """

    def __init__(self, path):
        super().__init__()

        backbone = M.vgg16()
        backbone.load_state_dict(torch.load(path))

        # Keep only the leading nine modules of the convolutional stack.
        leading_layers = list(backbone.features.children())[:9]
        backbone.features = nn.Sequential(*leading_layers)
        self.model = backbone.features

        # Per-channel statistics, stored as buffers so they follow the module
        # across devices. The mean is pre-shifted by 0.5 to pair with the
        # 0.5 scaling applied in ``forward``.
        shifted_mean = [0.485 - 0.5, 0.456 - 0.5, 0.406 - 0.5]
        channel_std = [0.229, 0.224, 0.225]
        self.register_buffer('mean', torch.FloatTensor(shifted_mean).view(1, 3, 1, 1))
        self.register_buffer('std', torch.FloatTensor(channel_std).view(1, 3, 1, 1))

    def forward(self, images):
        """Normalise ``images`` (scale by 0.5, subtract mean, divide by std) and return the truncated-VGG features."""
        centred = images.mul(0.5) - self.mean
        return self.model(centred / self.std)
|
model/vgg16-397923af.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:397923af8e79cdbb6a7127f12361acd7a2f83e06b05044ddf496e83de57a5bf0
|
| 3 |
+
size 553433881
|
readme.md
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Automatic colorization
|
| 2 |
+
|
| 3 |
+
1. Download the [generator](https://drive.google.com/file/d/1Oo6ycphJ3sUOpDCDoG29NA5pbhQVCevY/view?usp=sharing), [extractor](https://drive.google.com/file/d/12cbNyJcCa1zI2EBz6nea3BXl21Fm73Bt/view?usp=sharing) and [denoiser](https://drive.google.com/file/d/161oyQcYpdkVdw8gKz_MA8RD-Wtg9XDp3/view?usp=sharing) weights. Put the generator and extractor weights in `model` and the denoiser weights in `denoising/models`.
|
| 4 |
+
2. To colorize an image, a folder of images, or a `.cbz`/`.cbr` file, use the following command:
|
| 5 |
+
```
|
| 6 |
+
$ python inference.py -p "path to file or folder"
|
| 7 |
+
```
|
| 8 |
+
|
| 9 |
+
# Manual colorization with color hints
|
| 10 |
+
|
| 11 |
+
1. Download the [colorizer](https://drive.google.com/file/d/1BERrMl9e7cKsk9m2L0q1yO4k7blNhEWC/view?usp=sharing) and [denoiser](https://drive.google.com/file/d/161oyQcYpdkVdw8gKz_MA8RD-Wtg9XDp3/view?usp=sharing) weights. Put the colorizer weights in `model` and the denoiser weights in `denoising/models`.
|
| 12 |
+
2. Run gunicorn server with:
|
| 13 |
+
```
|
| 14 |
+
$ ./run_drawing.sh
|
| 15 |
+
```
|
| 16 |
+
3. Open `localhost:5000` with a browser.
|
| 17 |
+
|
| 18 |
+
# References
|
| 19 |
+
1. Extractor weights are taken from https://github.com/blandocs/Tag2Pix/releases/download/release/model.pth
|
| 20 |
+
2. Denoiser weights are taken from http://www.ipol.im/pub/art/2019/231.
|
requirements.txt
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pip
|
| 2 |
+
cohere
|
| 3 |
+
openai
|
| 4 |
+
tiktoken
|
| 5 |
+
setuptools==68.2.2
|
| 6 |
+
numpy==1.23.5
|
| 7 |
+
scipy==1.11.3
|
| 8 |
+
matplotlib==3.8.0
|
| 9 |
+
opencv-python==4.8.1.78
|
| 10 |
+
scikit-learn==1.0.2
|
| 11 |
+
tb-nightly
|
| 12 |
+
flask==1.1.1
|
| 13 |
+
gunicorn==21.2.0
|
| 14 |
+
flask-wtf==0.14.3
|
| 15 |
+
snowy==0.0.9
|
| 16 |
+
scikit-image==0.19.3
|
| 17 |
+
patool==1.12
|
| 18 |
+
albumentations==1.3.0
|
| 19 |
+
PyYAML==6.0
|
| 20 |
+
qudida==0.0.4
|
| 21 |
+
joblib==1.2.0
|
| 22 |
+
threadpoolctl==3.1.0
|
| 23 |
+
typing-extensions==4.6.2
|
| 24 |
+
imageio==2.9.0
|
| 25 |
+
pillow==9.5.0
|
| 26 |
+
PyWavelets==1.3.0
|
| 27 |
+
tifffile==2021.11.2
|
| 28 |
+
imutils==0.5.4
|
| 29 |
+
cycler==0.11.0
|
| 30 |
+
kiwisolver==1.4.4
|
| 31 |
+
pyparsing==3.0.9
|
| 32 |
+
python-dateutil==2.8.2
|
| 33 |
+
pipdeptree==2.7.1
|
| 34 |
+
numba==0.56.4
|
| 35 |
+
llvmlite==0.39.1
|
| 36 |
+
torch==2.1.0
|
| 37 |
+
future==0.18.3
|
| 38 |
+
tqdm==4.65.0
|
| 39 |
+
colorama==0.4.6
|
| 40 |
+
wheel==0.40.0
|
| 41 |
+
torchvision
|
run_drawing.sh
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Serve the `app` object from the `drawing` module with gunicorn:
# one gevent worker, 150 s request timeout, listening on all interfaces, port 5000.
# NOTE(review): the gevent worker class needs the `gevent` package, which is not
# listed in requirements.txt -- confirm it is installed in the target environment.
gunicorn --worker-class gevent --timeout 150 -w 1 -b 0.0.0.0:5000 drawing:app
|
static/js/draw.js
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Canvas the user paints colour hints on, plus its 2D context and pixel size.
var canvas = document.getElementById('draw_canvas');
var ctx = canvas.getContext('2d');
var canvasWidth = canvas.width;
var canvasHeight = canvas.height;
// Corner where a right-button drag started; written on mousedown, read on mouseup.
var prevX, prevY;

// Canvas that displays the colorized result; given the same pixel size as the drawing canvas.
var result_canvas = document.getElementById('result');
var result_ctx = result_canvas.getContext('2d');
result_canvas.width = canvas.width;
result_canvas.height = canvas.height;

// The colour input mirrors the active brush colour; both start as black.
var color_indicator = document.getElementById('color');
ctx.fillStyle = 'black';
color_indicator.value = '#000000';

// Last segment of the page path; used below to build the hint image URL.
var cur_id = window.location.pathname.substring(window.location.pathname.lastIndexOf('/') + 1);
|
| 17 |
+
|
| 18 |
+
/**
 * Return a pseudo-random integer in the range [0, floor(max)).
 * @param {number} max exclusive upper bound (floored before use)
 */
function getRandomInt(max) {
    var upperBound = Math.floor(max);
    var scaled = Math.random() * upperBound;
    return Math.floor(scaled);
}
|
| 21 |
+
|
| 22 |
+
// Draw the stored hint image for this page onto the drawing canvas once it has loaded.
var init_hint = new Image();
init_hint.addEventListener('load', function() {
    ctx.drawImage(init_hint, 0, 0);
});
// A random number is appended as the query string -- presumably to bypass the browser cache.
init_hint.src = '../static/temp_images/' + cur_id + '/hint.png?' + getRandomInt(100000).toString();
|
| 27 |
+
|
| 28 |
+
// NOTE(review): both handlers below create an Image and attach a load callback
// but never assign `img.src`, so the inner callbacks cannot fire from this code.
// It is also unclear whether a canvas element ever dispatches 'load' -- these
// look like leftovers; confirm before relying on them.
result_canvas.addEventListener('load', function(e) {
    var img = new Image();
    img.addEventListener('load', function() {
        ctx.drawImage(img, 0, 0);
    }, false);
    console.log(window.location.pathname);
})


canvas.onload = function (e) {
    var img = new Image();
    img.addEventListener('load', function() {
        ctx.drawImage(img, 0, 0);
    }, false);
    console.log(window.location.pathname);
    //img.src = ;
}
|
| 45 |
+
|
| 46 |
+
/** Erase every hint from the drawing canvas (bound to the "Clear" button). */
function reset() {
    var width = canvasWidth;
    var height = canvasHeight;
    ctx.clearRect(0, 0, width, height);
}
|
| 49 |
+
|
| 50 |
+
/**
 * Convert a mouse event's client coordinates into canvas pixel coordinates,
 * compensating for any difference between the canvas' on-screen box and its
 * pixel dimensions.
 * @param canvas element providing getBoundingClientRect(), width and height
 * @param evt    event providing clientX and clientY
 * @returns {{x: number, y: number}}
 */
function getMousePos(canvas, evt) {
    var bounds = canvas.getBoundingClientRect();
    // Fraction of the displayed box covered by the pointer, per axis.
    var fracX = (evt.clientX - bounds.left) / (bounds.right - bounds.left);
    var fracY = (evt.clientY - bounds.top) / (bounds.bottom - bounds.top);
    return {
        x: fracX * canvas.width,
        y: fracY * canvas.height
    };
}
|
| 57 |
+
|
| 58 |
+
/**
 * Post the current hint canvas to the server's /colorize endpoint and draw the
 * image referenced by the response onto the result canvas.
 */
function colorize() {
    var payload = {
        save_file_id: document.location.pathname,
        save_image: canvas.toDataURL()
    };

    $.post("/colorize", payload).done(function(data) {
        // The response body is used directly as the image source.
        var resultImage = new Image();
        resultImage.addEventListener('load', function() {
            result_ctx.drawImage(resultImage, 0, 0);
        }, false);
        resultImage.src = data;
    });
}
|
| 73 |
+
|
| 74 |
+
// Left button paints a single hint pixel; right button records the first
// corner of an erase rectangle.
canvas.addEventListener('mousedown', function(event) {
    var pos = getMousePos(canvas, event);
    switch (event.button) {
        case 0:
            ctx.fillRect(pos.x, pos.y, 1, 1);
            break;
        case 2:
            prevX = pos.x;
            prevY = pos.y;
            break;
    }
});

// Releasing the right button clears the rectangle spanned since mousedown.
canvas.addEventListener('mouseup', function(event) {
    if (event.button != 2) {
        return;
    }
    var pos = getMousePos(canvas, event);
    ctx.clearRect(prevX, prevY, pos.x - prevX, pos.y - prevY);
});

// Suppress the browser context menu so the right button can be used for erasing.
canvas.addEventListener('contextmenu', function(event) {
    event.preventDefault();
});
|
| 101 |
+
|
| 102 |
+
/**
 * Make `color_value` the active brush colour and show it in the colour input.
 * @param {string} color_value CSS colour string, e.g. '#ff8000'
 */
function color(color_value){
    color_indicator.value = color_value;
    ctx.fillStyle = color_value;
}

// Picking a colour in the input selects it as the brush colour.
color_indicator.oninput = function() {
    var picked = this.value;
    color(picked);
};
|
| 110 |
+
|
| 111 |
+
/**
 * Convert an [r, g, b, ...] array of 0-255 components into a '#rrggbb' string.
 * Only the first three entries are read, so RGBA pixel data can be passed as-is.
 *
 * Fixed: the hex digits are now left-padded to six characters. Previously a
 * colour such as [0, 0, 255] produced '#ff', which is not a valid value for
 * canvas fillStyle or for an <input type="color">.
 */
function rgbToHex(rgb){
    var packed = (rgb[0] << 16) | (rgb[1] << 8) | rgb[2];
    var hex = packed.toString(16);
    return '#' + ('000000' + hex).slice(-6);
}
|
| 114 |
+
|
| 115 |
+
// Left-clicking the result picks the colour under the pointer as the brush colour.
result_canvas.addEventListener('click', function(event) {
    if (event.button != 0) {
        return;
    }
    var pixel = result_ctx.getImageData(event.offsetX, event.offsetY, 1, 1).data;
    color(rgbToHex(pixel));
});
|
static/temp_images/.gitkeep
ADDED
|
File without changes
|
templates/drawing.html
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<title>Colorization app</title>
|
| 6 |
+
<style>
|
| 7 |
+
.back{
|
| 8 |
+
height:100%;
|
| 9 |
+
width:100%;
|
| 10 |
+
position: absolute;
|
| 11 |
+
background-color: yellow;
|
| 12 |
+
top:0px;
|
| 13 |
+
padding: 10px;
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
#draw_canvas {
|
| 17 |
+
height: 200%;
|
| 18 |
+
border: 3px solid black;
|
| 19 |
+
background-image: linear-gradient(rgba(60,60,60,.85), rgba(60,60,60,.85)), url(../static/{{img_path}});
|
| 20 |
+
background-color: #c7b39b;
|
| 21 |
+
background-size: 100%;
|
| 22 |
+
}
|
| 23 |
+
</style>
|
| 24 |
+
|
| 25 |
+
</head>
|
| 26 |
+
<body>
|
| 27 |
+
<div align="left" style ="margin : 5px; margin-left : 0px">
|
| 28 |
+
<input type="button" onclick="location.href='/';" value="Home" />
|
| 29 |
+
</div>
|
| 30 |
+
<p style="margin: 0; padding: 0">
|
| 31 |
+
Left click - colorize, right hold - remove with rectangle, left click on result - use corresponding color.
|
| 32 |
+
|
| 33 |
+
</p>
|
| 34 |
+
<hr style ="margin: 0; padding: 0">
|
| 35 |
+
|
| 36 |
+
<p><table>
|
| 37 |
+
<tr>
|
| 38 |
+
<td>
|
| 39 |
+
<table>
|
| 40 |
+
<tr>
|
| 41 |
+
<td><button style="background-color: #000000; height: 20px; width: 20px;" onclick="color('#000000')"></button>
|
| 42 |
+
<td><button style="background-color: #B0171F; height: 20px; width: 20px;" onclick="color('#B0171F')"></button>
|
| 43 |
+
</tr>
|
| 44 |
+
<tr>
|
| 45 |
+
<td><button style="background-color: #DA70D6; height: 20px; width: 20px;" onclick="color('#DA70D6')"></button>
|
| 46 |
+
<td><button style="background-color: #8A2BE2; height: 20px; width: 20px;" onclick="color('#8A2BE2')"></button>
|
| 47 |
+
</tr>
|
| 48 |
+
<tr>
|
| 49 |
+
<td><button style="background-color: #0000FF; height: 20px; width: 20px;" onclick="color('#0000FF')"></button>
|
| 50 |
+
<td><button style="background-color: #4876FF; height: 20px; width: 20px;" onclick="color('#4876FF')"></button>
|
| 51 |
+
</tr>
|
| 52 |
+
<tr>
|
| 53 |
+
<td><button style="background-color: #CAE1FF; height: 20px; width: 20px;" onclick="color('#CAE1FF')"></button>
|
| 54 |
+
<td><button style="background-color: #6E7B8B; height: 20px; width: 20px;" onclick="color('#6E7B8B')"></button>
|
| 55 |
+
</tr>
|
| 56 |
+
<tr>
|
| 57 |
+
<td><button style="background-color: #00C78C; height: 20px; width: 20px;" onclick="color('#00C78C')"></button>
|
| 58 |
+
<td><button style="background-color: #00FA9A; height: 20px; width: 20px;" onclick="color('#00FA9A')"></button>
|
| 59 |
+
</tr>
|
| 60 |
+
<tr>
|
| 61 |
+
<td><button style="background-color: #00FF7F; height: 20px; width: 20px;" onclick="color('#00FF7F')"></button>
|
| 62 |
+
<td><button style="background-color: #00C957; height: 20px; width: 20px;" onclick="color('#00C957')"></button>
|
| 63 |
+
</tr>
|
| 64 |
+
<tr>
|
| 65 |
+
<td><button style="background-color: #3D9140; height: 20px; width: 20px;" onclick="color('#3D9140')"></button>
|
| 66 |
+
<td><button style="background-color: #32CD32; height: 20px; width: 20px;" onclick="color('#32CD32')"></button>
|
| 67 |
+
</tr>
|
| 68 |
+
<tr>
|
| 69 |
+
<td><button style="background-color: #00EE00; height: 20px; width: 20px;" onclick="color('#00EE00')"></button>
|
| 70 |
+
|
| 71 |
+
<td><button style="background-color: #008B00; height: 20px; width: 20px;" onclick="color('#008B00')"></button>
|
| 72 |
+
</tr>
|
| 73 |
+
<tr>
|
| 74 |
+
<td><button style="background-color: #76EE00; height: 20px; width: 20px;" onclick="color('#76EE00')"></button>
|
| 75 |
+
|
| 76 |
+
<td><button style="background-color: #CAFF70; height: 20px; width: 20px;" onclick="color('#CAFF70')"></button>
|
| 77 |
+
</tr>
|
| 78 |
+
<tr>
|
| 79 |
+
<td><button style="background-color: #FFFF00; height: 20px; width: 20px;" onclick="color('#FFFF00')"></button>
|
| 80 |
+
|
| 81 |
+
<td><button style="background-color: #CDCD00; height: 20px; width: 20px;" onclick="color('#CDCD00')"></button>
|
| 82 |
+
</tr>
|
| 83 |
+
<tr>
|
| 84 |
+
<td><button style="background-color: #FFF68F; height: 20px; width: 20px;" onclick="color('#FFF68F')"></button>
|
| 85 |
+
|
| 86 |
+
<td><button style="background-color: #FFFACD; height: 20px; width: 20px;" onclick="color('#FFFACD')"></button>
|
| 87 |
+
</tr>
|
| 88 |
+
<tr>
|
| 89 |
+
<td><button style="background-color: #FFEC8B; height: 20px; width: 20px;" onclick="color('#FFEC8B')"></button>
|
| 90 |
+
|
| 91 |
+
<td><button style="background-color: #FFD700; height: 20px; width: 20px;" onclick="color('#FFD700')"></button>
|
| 92 |
+
</tr>
|
| 93 |
+
<tr>
|
| 94 |
+
<td><button style="background-color: #F5DEB3; height: 20px; width: 20px;" onclick="color('#F5DEB3')"></button>
|
| 95 |
+
|
| 96 |
+
<td><button style="background-color: #FFE4B5; height: 20px; width: 20px;" onclick="color('#FFE4B5')"></button>
|
| 97 |
+
</tr>
|
| 98 |
+
<tr>
|
| 99 |
+
<td><button style="background-color: #EECFA1; height: 20px; width: 20px;" onclick="color('#EECFA1')"></button>
|
| 100 |
+
|
| 101 |
+
<td><button style="background-color: #FF9912; height: 20px; width: 20px;" onclick="color('#FF9912')"></button>
|
| 102 |
+
</tr>
|
| 103 |
+
<tr>
|
| 104 |
+
<td><button style="background-color: #8E388E; height: 20px; width: 20px;" onclick="color('#8E388E')"></button>
|
| 105 |
+
|
| 106 |
+
<td><button style="background-color: #7171C6; height: 20px; width: 20px;" onclick="color('#7171C6')"></button>
|
| 107 |
+
</tr>
|
| 108 |
+
|
| 109 |
+
<tr>
|
| 110 |
+
<td><button style="background-color: #7D9EC0; height: 20px; width: 20px;" onclick="color('#7D9EC0')"></button>
|
| 111 |
+
|
| 112 |
+
<td><button style="background-color: #388E8E; height: 20px; width: 20px;" onclick="color('#388E8E')"></button>
|
| 113 |
+
|
| 114 |
+
</tr>
|
| 115 |
+
|
| 116 |
+
<tr>
|
| 117 |
+
<td><button style="background-color: #71C671; height: 20px; width: 20px;" onclick="color('#71C671')"></button>
|
| 118 |
+
|
| 119 |
+
<td><button style="background-color: #8E8E38; height: 20px; width: 20px;" onclick="color('#8E8E38')"></button>
|
| 120 |
+
</tr>
|
| 121 |
+
<tr>
|
| 122 |
+
<td><button style="background-color: #C5C1AA; height: 20px; width: 20px;" onclick="color('#C5C1AA')"></button>
|
| 123 |
+
|
| 124 |
+
<td><button style="background-color: #C67171; height: 20px; width: 20px;" onclick="color('#C67171')"></button>
|
| 125 |
+
</tr>
|
| 126 |
+
<tr>
|
| 127 |
+
<td><button style="background-color: #555555; height: 20px; width: 20px;" onclick="color('#555555')"></button>
|
| 128 |
+
<td><button style="background-color: #848484; height: 20px; width: 20px;" onclick="color('#848484')"></button>
|
| 129 |
+
</tr>
|
| 130 |
+
<tr>
|
| 131 |
+
<td><button style="background-color: #FFFFFF; height: 20px; width: 20px;" onclick="color('#FFFFFF')"></button>
|
| 132 |
+
<td><button style="background-color: #EE0000; height: 20px; width: 20px;" onclick="color('#EE0000')"></button>
|
| 133 |
+
</tr>
|
| 134 |
+
<tr>
|
| 135 |
+
<td><button style="background-color: #FF4040; height: 20px; width: 20px;" onclick="color('#FF4040')"></button>
|
| 136 |
+
<td><button style="background-color: #EE6363; height: 20px; width: 20px;" onclick="color('#EE6363')"></button>
|
| 137 |
+
</tr>
|
| 138 |
+
<tr>
|
| 139 |
+
<td><button style="background-color: #FFC1C1; height: 20px; width: 20px;" onclick="color('#FFC1C1')"></button>
|
| 140 |
+
<td><button style="background-color: #FF7256; height: 20px; width: 20px;" onclick="color('#FF7256')"></button>
|
| 141 |
+
</tr>
|
| 142 |
+
<tr>
|
| 143 |
+
<td><button style="background-color: #FF4500; height: 20px; width: 20px;" onclick="color('#FF4500')"></button>
|
| 144 |
+
<td><button style="background-color: #F4A460; height: 20px; width: 20px;" onclick="color('#F4A460')"></button>
|
| 145 |
+
</tr>
|
| 146 |
+
<tr>
|
| 147 |
+
<td><button style="background-color: #FF8000; height: 20px; width: 20px;" onclick="color('#FF8000')"></button>
|
| 148 |
+
<td><button style="background-color: #FFD700; height: 20px; width: 20px;" onclick="color('#FFD700')"></button>
|
| 149 |
+
</tr>
|
| 150 |
+
<tr>
|
| 151 |
+
<td><button style="background-color: #8B864E; height: 20px; width: 20px;" onclick="color('#8B864E')"></button>
|
| 152 |
+
<td><button style="background-color: #9ACD32; height: 20px; width: 20px;" onclick="color('#9ACD32')"></button>
|
| 153 |
+
</tr>
|
| 154 |
+
<tr>
|
| 155 |
+
<td><button style="background-color: #66CD00; height: 20px; width: 20px;" onclick="color('#66CD00')"></button>
|
| 156 |
+
<td><button style="background-color: #BDFCC9; height: 20px; width: 20px;" onclick="color('#BDFCC9')"></button>
|
| 157 |
+
</tr>
|
| 158 |
+
<tr>
|
| 159 |
+
<td><button style="background-color: #76EEC6; height: 20px; width: 20px;" onclick="color('#76EEC6')"></button>
|
| 160 |
+
<td><button style="background-color: #40E0D0; height: 20px; width: 20px;" onclick="color('#40E0D0')"></button>
|
| 161 |
+
</tr>
|
| 162 |
+
<tr>
|
| 163 |
+
<td><button style="background-color: #E0EEEE; height: 20px; width: 20px;" onclick="color('#E0EEEE')"></button>
|
| 164 |
+
<td><button style="background-color: #98F5FF; height: 20px; width: 20px;" onclick="color('#98F5FF')"></button>
|
| 165 |
+
</tr>
|
| 166 |
+
<tr>
|
| 167 |
+
<td><button style="background-color: #33A1C9; height: 20px; width: 20px;" onclick="color('#33A1C9')"></button>
|
| 168 |
+
<td><button style="background-color: #F0F8FF; height: 20px; width: 20px;" onclick="color('#F0F8FF')"></button>
|
| 169 |
+
</tr>
|
| 170 |
+
<tr>
|
| 171 |
+
<td><button style="background-color: #4682B4; height: 20px; width: 20px;" onclick="color('#4682B4')"></button>
|
| 172 |
+
<td><button style="background-color: #C6E2FF; height: 20px; width: 20px;" onclick="color('#C6E2FF')"></button>
|
| 173 |
+
</tr>
|
| 174 |
+
<tr>
|
| 175 |
+
<td><button style="background-color: #9B30FF; height: 20px; width: 20px;" onclick="color('#9B30FF')"></button>
|
| 176 |
+
<td><button style="background-color: #EE82EE; height: 20px; width: 20px;" onclick="color('#EE82EE')"></button>
|
| 177 |
+
</tr>
|
| 178 |
+
<tr>
|
| 179 |
+
<td><button style="background-color: #FFC0CB; height: 20px; width: 20px;" onclick="color('#FFC0CB')"></button>
|
| 180 |
+
<td><button style="background-color: #7CFC00; height: 20px; width: 20px;" onclick="color('#7CFC00')"></button>
|
| 181 |
+
</tr>
|
| 182 |
+
<tr>
|
| 183 |
+
<input type="color" id="color">
|
| 184 |
+
</tr>
|
| 185 |
+
</table>
|
| 186 |
+
</td>
|
| 187 |
+
<td>
|
| 188 |
+
<div style="width: 1150px; height: 800px; overflow: auto"><canvas align = "center" id="draw_canvas" width="{{width}}" height="{{height}}"></canvas></div>
|
| 189 |
+
</td>
|
| 190 |
+
<td>
|
| 191 |
+
<canvas id='result'></canvas>
|
| 192 |
+
</td>
|
| 193 |
+
</tr>
|
| 194 |
+
</table></p>
|
| 195 |
+
|
| 196 |
+
<button style="height: 20px; width: 80px" onclick="reset()">Clear</button>
|
| 197 |
+
|
| 198 |
+
<button style="height: 20px; width: 80px" onclick="colorize()" >Colorize</button>
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
<script src="http://code.jquery.com/jquery-1.8.3.js"></script>
|
| 203 |
+
<script src="/static/js/draw.js">
|
| 204 |
+
</script>
|
| 205 |
+
</body>
|
| 206 |
+
</html>
|
templates/submit.html
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<form method="POST" action="/" enctype="multipart/form-data" target="_blank">
|
| 2 |
+
{{ form.hidden_tag() }}
|
| 3 |
+
{{ form.file.label }} {{ form.file(size=20) }}
|
| 4 |
+
{{ form.denoise.label }} {{ form.denoise(size=5) }}
|
| 5 |
+
{{ form.denoise_sigma.label }} {{ form.denoise_sigma(size=5) }}
|
| 6 |
+
{{ form.autohint.label }} {{ form.autohint(size=5) }}
|
| 7 |
+
{{ form.autohint_sigma.label }} {{ form.autohint_sigma(size=5) }}
|
| 8 |
+
{{ form.ignore_gray.label }} {{ form.ignore_gray(size=5) }}
|
| 9 |
+
<input type="submit" value="Colorize">
|
| 10 |
+
|
| 11 |
+
</form>
|
templates/upload.html
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<style>
|
| 2 |
+
form{
|
| 3 |
+
position:fixed;
|
| 4 |
+
top:50%;
|
| 5 |
+
left:45%;
|
| 6 |
+
width:1250px;
|
| 7 |
+
}
|
| 8 |
+
</style>
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
<form id="form" method="POST" action="/" enctype="multipart/form-data">
|
| 12 |
+
{{ form.hidden_tag() }}
|
| 13 |
+
{{ form.file(size=20) }}
|
| 14 |
+
</form>
|
| 15 |
+
|
| 16 |
+
<script>
|
| 17 |
+
document.getElementById("file").onchange = function() {
|
| 18 |
+
document.getElementById("form").submit();
|
| 19 |
+
};
|
| 20 |
+
</script>
|
train.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.optim as optim
|
| 4 |
+
import numpy as np
|
| 5 |
+
import albumentations as albu
|
| 6 |
+
import argparse
|
| 7 |
+
import datetime
|
| 8 |
+
|
| 9 |
+
from utils.utils import open_json, weights_init, weights_init_spectr, generate_mask
|
| 10 |
+
from model.models import Colorizer, Generator, Content, Discriminator
|
| 11 |
+
from model.extractor import get_seresnext_extractor
|
| 12 |
+
from dataset.datasets import TrainDataset, FineTuningDataset
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace with ``path`` (required dataset path) and the
        boolean switches ``fine_tuning`` and ``gpu`` (both off unless given).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("-p", "--path", required=True, help="dataset path")
    cli.add_argument('-ft', '--fine_tuning', dest='fine_tuning', action='store_true')
    cli.add_argument('-g', '--gpu', dest='gpu', action='store_true')
    cli.set_defaults(fine_tuning=False, gpu=False)
    return cli.parse_args()
|
| 26 |
+
|
| 27 |
+
def get_transforms():
    """Build the training augmentation pipeline.

    The pipeline always takes a random 512x512 crop and then flips the
    sample horizontally with probability 0.5.
    """
    steps = [
        albu.RandomCrop(512, 512, always_apply=True),
        albu.HorizontalFlip(p=0.5),
    ]
    return albu.Compose(steps, p=1.)
|
| 29 |
+
|
| 30 |
+
def get_dataloaders(data_path, transforms, batch_size, fine_tuning, mult_number):
    """Create the shuffled data loaders used for training.

    Args:
        data_path: dataset root handed to the dataset classes.
        transforms: augmentation pipeline applied by the datasets.
        batch_size: batch size for both loaders.
        fine_tuning: when true, a loader over ``FineTuningDataset`` is built too.
        mult_number: forwarded to ``TrainDataset``.

    Returns:
        ``(train_dataloader, finetuning_dataloader)``; the second element is
        ``None`` when ``fine_tuning`` is false.
    """
    train_dataset = TrainDataset(data_path, transforms, mult_number)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Must be bound on every path: previously the return statement raised
    # UnboundLocalError whenever fine tuning was disabled.
    finetuning_dataloader = None

    if fine_tuning:
        finetuning_dataset = FineTuningDataset(data_path, transforms)
        finetuning_dataloader = torch.utils.data.DataLoader(finetuning_dataset, batch_size=batch_size, shuffle=True)

    return train_dataloader, finetuning_dataloader
|
| 39 |
+
|
| 40 |
+
def get_models(device):
    """Instantiate the networks used for training and move them to ``device``.

    Returns:
        ``(colorizer, discriminator, content)``. The content network is put in
        eval mode and all of its parameters are frozen.
    """
    colorizer = Colorizer(Generator(), get_seresnext_extractor())
    colorizer.extractor_eval()
    colorizer = colorizer.to(device)

    discriminator = Discriminator().to(device)

    content = Content('model/vgg16-397923af.pth').eval().to(device)
    # The content network only provides features for the loss; never train it.
    for frozen_param in content.parameters():
        frozen_param.requires_grad = False

    return colorizer, discriminator, content
|
| 55 |
+
|
| 56 |
+
def set_weights(colorizer, discriminator):
    """Initialise the trainable networks in place.

    The generator is initialised with ``weights_init``, the extractor weights
    are loaded from ``model/extractor.pth`` and the discriminator is
    initialised with ``weights_init_spectr``.
    """
    colorizer.generator.apply(weights_init)

    extractor_state = torch.load('model/extractor.pth')
    colorizer.load_extractor_weights(extractor_state)

    discriminator.apply(weights_init_spectr)
|
| 61 |
+
|
| 62 |
+
def generator_loss(disc_output, true_labels, main_output, guide_output, real_image, content_gen, content_true, dist_loss = nn.L1Loss(), content_dist_loss = nn.MSELoss(), class_loss = nn.BCEWithLogitsLoss()):
    """Combine the generator objectives into one scalar.

    The result is
    ``10 * (dist(main, real) + 0.9 * dist(guide, real)) + adversarial + content``
    where ``dist`` is ``dist_loss``, the adversarial term is ``class_loss`` on
    the discriminator logits and the content term is ``content_dist_loss`` on
    the two content feature tensors.
    """
    reconstruction = dist_loss(main_output, real_image)
    guide_reconstruction = dist_loss(guide_output, real_image)

    adversarial = class_loss(disc_output, true_labels)

    perceptual = content_dist_loss(content_gen, content_true)

    return 10 * (reconstruction + 0.9 * guide_reconstruction) + adversarial + perceptual
|
| 73 |
+
|
| 74 |
+
def get_optimizers(colorizer, discriminator, generator_lr, discriminator_lr):
    """Create Adam optimizers for the generator and the discriminator.

    Only ``colorizer.generator`` parameters are optimised on the generator
    side. Both optimizers use betas ``(0.5, 0.9)``.

    Returns:
        ``(optimizerG, optimizerD)``.
    """
    adam_betas = (0.5, 0.9)
    gen_params = colorizer.generator.parameters()
    disc_params = discriminator.parameters()

    optimizerG = optim.Adam(gen_params, lr=generator_lr, betas=adam_betas)
    optimizerD = optim.Adam(disc_params, lr=discriminator_lr, betas=adam_betas)

    return optimizerG, optimizerD
|
| 79 |
+
|
| 80 |
+
def generator_step(inputs, colorizer, discriminator, content, loss_function, optimizer, device, white_penalty = True):
    """Run one optimisation step of the generator.

    Args:
        inputs: batch tuple ``(bw, color, hint, dfm)``.
        colorizer: model called on ``cat([bw, dfm, hint], dim=1)``; returns
            ``(fake, guide)``.
        discriminator: scores the generated image; its parameters are frozen
            for the duration of this step.
        content: feature network applied to both the fake and the real image.
        loss_function: called as ``loss_function(logits_fake, y_real, fake,
            guide, color, content_fake, content_true)``.
        optimizer: optimizer stepping the generator parameters.
        device: device the batch is moved to.
        white_penalty: when true, an extra brightness penalty term is added.

    Returns:
        The total generator loss of this step as a Python float.
    """
    # Freeze the discriminator, unfreeze the generator.
    for disc_param in discriminator.parameters():
        disc_param.requires_grad = False
    for gen_param in colorizer.generator.parameters():
        gen_param.requires_grad = True

    colorizer.generator.zero_grad()

    bw, color, hint, dfm = (tensor.to(device) for tensor in inputs)

    fake, guide = colorizer(torch.cat([bw, dfm, hint], 1))

    logits_fake = discriminator(fake)
    # The generator wants the discriminator to answer "real" (label 1).
    y_real = torch.ones((bw.size(0), 1), device=device)

    content_fake = content(fake)
    with torch.no_grad():
        content_true = content(color)

    total_loss = loss_function(logits_fake, y_real, fake, guide, color, content_fake, content_true)

    if white_penalty:
        # A pixel counts as white when all three channels of the target exceed 0.85.
        is_white = (color > 0.85).float().sum(dim=1) == 3
        # mask is 1 on every pixel that is NOT white in the target image.
        mask = (~is_white.unsqueeze(1).repeat((1, 3, 1, 1))).float()
        # (fake + 1) / 2 rescales the output; presumably from [-1, 1] to [0, 1] - TODO confirm.
        white_zones = mask * (fake + 1) / 2
        squared_brightness = torch.pow(white_zones.sum(dim=1), 2).sum(dim=(1, 2))
        penalty = (squared_brightness / (mask.sum(dim=(1, 2, 3)) + 1)).mean()

        total_loss += penalty

    total_loss.backward()

    optimizer.step()

    return total_loss.item()
|
| 114 |
+
|
| 115 |
+
def discriminator_step(inputs, colorizer, discriminator, optimizer, device, loss_function = nn.BCEWithLogitsLoss()):
    """Run one optimisation step of the discriminator.

    Args:
        inputs: batch tuple ``(bw, color, hint, dfm)``.
        colorizer: produces the fake image from ``cat([bw, dfm, hint], dim=1)``;
            its generator parameters are frozen for this step.
        discriminator: network being trained.
        optimizer: optimizer stepping the discriminator parameters.
        device: device the batch is moved to.
        loss_function: criterion applied to the logits of real and fake images.

    Returns:
        ``real_loss + fake_loss`` of this step as a Python float.
    """
    # Unfreeze the discriminator, freeze the generator.
    for disc_param in discriminator.parameters():
        disc_param.requires_grad = True
    for gen_param in colorizer.generator.parameters():
        gen_param.requires_grad = False

    discriminator.zero_grad()

    bw, color, hint, dfm = (tensor.to(device) for tensor in inputs)

    batch = bw.size(0)
    # Real targets are 0.9 instead of 1.0 (smoothed labels); fakes are 0.
    y_real = torch.full((batch, 1), 0.9, device=device)
    y_fake = torch.zeros((batch, 1), device=device)

    # The fake image is produced without tracking gradients.
    with torch.no_grad():
        fake_color, _ = colorizer(torch.cat([bw, dfm, hint], 1))

    fake_loss = loss_function(discriminator(fake_color), y_fake)
    real_loss = loss_function(discriminator(color), y_real)

    discriminator_loss = real_loss + fake_loss

    discriminator_loss.backward()
    optimizer.step()

    return discriminator_loss.item()
|
| 147 |
+
|
| 148 |
+
def decrease_lr(optimizer, rate):
    """Divide the learning rate of every parameter group by ``rate`` in place."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] / rate
|
| 151 |
+
|
| 152 |
+
def set_lr(optimizer, value):
    """Overwrite the learning rate of every parameter group with ``value``."""
    for param_group in optimizer.param_groups:
        param_group.update(lr=value)
|
| 155 |
+
|
| 156 |
+
def train(colorizer, discriminator, content, dataloader, epochs, colorizer_optimizer, discriminator_optimizer, lr_decay_epoch = -1, device = 'cpu'):
    """Adversarial training loop.

    Batches alternate between a discriminator step and a generator step. The
    alternation state carries over from one epoch to the next. When ``epoch``
    equals ``lr_decay_epoch`` both learning rates are divided by 10. Running
    averages of both losses are printed every 5 batches and after each epoch.
    """
    colorizer.generator.train()
    discriminator.train()

    run_discriminator = True

    for epoch in range(epochs):
        if epoch == lr_decay_epoch:
            for opt in (colorizer_optimizer, discriminator_optimizer):
                decrease_lr(opt, 10)

        disc_total = 0
        gen_total = 0

        for n, inputs in enumerate(dataloader):
            if n % 5 == 0:
                print(datetime.datetime.now().time())
                # Each network is updated on every other batch, hence n // 2.
                seen = n // 2 + 1
                print('Step : %d Discr loss: %.4f Gen loss : %.4f \n' % (n, disc_total / seen, gen_total / seen))

            if run_discriminator:
                disc_total += discriminator_step(inputs, colorizer, discriminator, discriminator_optimizer, device)
            else:
                gen_total += generator_step(inputs, colorizer, discriminator, content, generator_loss, colorizer_optimizer, device)

            run_discriminator = not run_discriminator

        print(datetime.datetime.now().time())
        print('Epoch : %d Discr loss: %.4f Gen loss : %.4f \n' % (epoch, disc_total / (n // 2 + 1), gen_total / (n // 2 + 1)))
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def fine_tuning_step(data_iter, colorizer, discriminator, gen_optimizer, disc_optimizer, device, loss_function = nn.BCEWithLogitsLoss()):
    """Run one fine-tuning round: five discriminator updates, then one generator update.

    Every update pulls a fresh ``(bw, dfm, color)`` batch from ``data_iter``
    and colorizes it with an all-zero 4-channel hint, so the generator is
    trained purely adversarially here.

    Args:
        data_iter: iterator yielding ``(bw, dfm, color)`` batches; six batches
            are consumed per call.
        colorizer: model whose ``generator`` is fine-tuned.
        discriminator: discriminator network.
        gen_optimizer: optimizer of the generator parameters.
        disc_optimizer: optimizer of the discriminator parameters.
        device: device the batches are moved to.
        loss_function: criterion applied to the discriminator logits.

    Raises:
        StopIteration: if ``data_iter`` runs out of batches.
    """
    # Phase 1: train the discriminator with the generator frozen.
    for p in discriminator.parameters():
        p.requires_grad = True
    for p in colorizer.generator.parameters():
        p.requires_grad = False

    for _ in range(5):
        discriminator.zero_grad()

        # next() is the iterator protocol; the old `.next()` method does not
        # exist on plain Python 3 iterators or on current DataLoader iterators.
        bw, dfm, color_for_real = next(data_iter)
        bw, dfm, color_for_real = bw.to(device), dfm.to(device), color_for_real.to(device)

        # Real targets are 0.9 instead of 1.0 (smoothed labels); fakes are 0.
        y_real = torch.full((bw.size(0), 1), 0.9, device=device)
        y_fake = torch.zeros((bw.size(0), 1), device=device)

        empty_hint = torch.zeros(bw.shape[0], 4, bw.shape[2], bw.shape[3]).float().to(device)

        # Produced under no_grad, so the fake image carries no autograd graph.
        with torch.no_grad():
            fake_color_manga, _ = colorizer(torch.cat([bw, dfm, empty_hint], 1))

        logits_fake = discriminator(fake_color_manga)
        logits_real = discriminator(color_for_real)

        fake_loss = loss_function(logits_fake, y_fake)
        real_loss = loss_function(logits_real, y_real)
        discriminator_loss = real_loss + fake_loss

        discriminator_loss.backward()
        disc_optimizer.step()

    # Phase 2: train the generator against the frozen discriminator.
    for p in discriminator.parameters():
        p.requires_grad = False
    for p in colorizer.generator.parameters():
        p.requires_grad = True

    colorizer.generator.zero_grad()

    bw, dfm, _ = next(data_iter)
    bw, dfm = bw.to(device), dfm.to(device)

    y_real = torch.ones((bw.size(0), 1), device=device)

    empty_hint = torch.zeros(bw.shape[0], 4, bw.shape[2], bw.shape[3]).float().to(device)

    fake_manga, _ = colorizer(torch.cat([bw, dfm, empty_hint], 1))

    logits_fake = discriminator(fake_manga)
    adv_loss = loss_function(logits_fake, y_real)

    adv_loss.backward()
    gen_optimizer.step()
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def fine_tuning(colorizer, discriminator, content, dataloader, iterations, colorizer_optimizer, discriminator_optimizer, data_iter, device = 'cpu'):
    """Fine-tuning loop that mixes regular steps with fine-tuning rounds.

    Batches of ``dataloader`` alternate between a discriminator step and a
    generator step, exactly as in regular training. On every batch whose
    index satisfies ``n % 10 == 5`` an additional ``fine_tuning_step`` is run
    on batches taken from ``data_iter``. The loop stops after ``iterations``
    batches, or earlier if ``dataloader`` is exhausted.
    """
    colorizer.generator.train()
    discriminator.train()

    run_discriminator = True

    for n, inputs in enumerate(dataloader):
        if n == iterations:
            return

        if run_discriminator:
            discriminator_step(inputs, colorizer, discriminator, discriminator_optimizer, device)
        else:
            generator_step(inputs, colorizer, discriminator, content, generator_loss, colorizer_optimizer, device)

        run_discriminator = not run_discriminator

        if n % 10 == 5:
            fine_tuning_step(data_iter, colorizer, discriminator, colorizer_optimizer, discriminator_optimizer, device)
|
| 269 |
+
|
| 270 |
+
if __name__ == '__main__':
    # Script entry point: regular training, optional fine tuning, then the
    # generator weights are saved under a name derived from the current time.
    args = parse_args()
    config = open_json('configs/train_config.json')

    device = 'cuda' if args.gpu else 'cpu'

    augmentations = get_transforms()

    train_dataloader, ft_dataloader = get_dataloaders(
        args.path,
        augmentations,
        config['batch_size'],
        args.fine_tuning,
        config['number_of_mults'],
    )

    colorizer, discriminator, content = get_models(device)
    set_weights(colorizer, discriminator)

    gen_optimizer, disc_optimizer = get_optimizers(
        colorizer,
        discriminator,
        config['generator_lr'],
        config['discriminator_lr'],
    )

    train(
        colorizer,
        discriminator,
        content,
        train_dataloader,
        config['epochs'],
        gen_optimizer,
        disc_optimizer,
        config['lr_decrease_epoch'],
        device,
    )

    if args.fine_tuning:
        # Only the generator learning rate is reset for fine tuning.
        set_lr(gen_optimizer, config["finetuning_generator_lr"])
        fine_tuning(
            colorizer,
            discriminator,
            content,
            train_dataloader,
            config['finetuning_iterations'],
            gen_optimizer,
            disc_optimizer,
            iter(ft_dataloader),
            device,
        )

    output_name = str(datetime.datetime.now().time())
    torch.save(colorizer.generator.state_dict(), output_name)
|
utils/__pycache__/utils.cpython-39.pyc
ADDED
|
Binary file (3.77 kB). View file
|
|
|
utils/dataset_utils.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import cv2
|
| 4 |
+
import snowy
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_resized_image(img, size):
    """Resize ``img`` so that its shorter side equals ``size``.

    The aspect ratio is kept; the longer side is rounded up. A 2-D image is
    first expanded to three identical channels. ``float32`` results are
    clipped to ``[0, 1]`` in place.
    """
    if len(img.shape) == 2:
        img = np.repeat(np.expand_dims(img, 2), 3, 2)

    rows, cols = img.shape[0], img.shape[1]

    # cv2.resize takes the target size as (width, height).
    if rows < cols:
        scale = rows / size
        target = (int(np.ceil(cols / scale)), size)
    else:
        scale = cols / size
        target = (size, int(np.ceil(rows / scale)))

    img = cv2.resize(img, target, interpolation=cv2.INTER_AREA)

    if img.dtype == 'float32':
        np.clip(img, 0, 1, out=img)

    return img
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_sketch_image(img, sketcher, mult_val):
    """Return ``sketcher.get_sketch_with_resize(img)``.

    ``mult`` is passed through only when ``mult_val`` is truthy; otherwise
    the sketcher is called without it.
    """
    extra = {'mult': mult_val} if mult_val else {}
    return sketcher.get_sketch_with_resize(img, **extra)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_dfm_image(sketch):
    """Compute a unitized distance field from ``sketch`` using snowy.

    The boolean input handed to ``snowy.generate_sdf`` is true wherever
    ``1 - sketch`` is non-zero; the result is passed through
    ``snowy.unitize`` and squeezed.
    """
    nonzero = np.expand_dims(1 - sketch, 2) != 0
    distance_field = snowy.generate_sdf(nonzero)
    return snowy.unitize(distance_field).squeeze()
|
| 42 |
+
|
| 43 |
+
def get_sketch(image, sketcher, dfm, mult = None):
    """Produce a ``uint8`` sketch of ``image`` and, optionally, its distance field.

    Returns:
        ``(sketch_image, dfm_image)``, both scaled by 255 and cast to
        ``uint8``; ``dfm_image`` is ``None`` when ``dfm`` is falsy.
    """
    raw_sketch = get_sketch_image(image, sketcher, mult)

    # The distance field is derived from the unscaled sketch.
    raw_dfm = get_dfm_image(raw_sketch) if dfm else None

    sketch_image = (raw_sketch * 255).astype('uint8')

    if not dfm:
        return sketch_image, None

    return sketch_image, (raw_dfm * 255).astype('uint8')
|
| 57 |
+
|
| 58 |
+
def get_sketches(image, sketcher, mult_list, dfm):
    """Lazily yield ``get_sketch(image, sketcher, dfm, mult)`` for each ``mult`` in ``mult_list``."""
    yield from (get_sketch(image, sketcher, dfm, factor) for factor in mult_list)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def create_resized_dataset(source_path, target_path, side_size):
    """Write a resized PNG copy of every image in ``source_path`` to ``target_path``.

    Images whose target file already exists are skipped. A file that cannot
    be processed is reported on stdout and does not stop the run.

    Args:
        source_path: directory with the original images.
        target_path: output directory (created if missing).
        side_size: target length of the shorter image side.
    """
    # cv2.imwrite fails silently when the directory does not exist.
    os.makedirs(target_path, exist_ok=True)

    for image_name in os.listdir(source_path):
        # splitext keeps the full name of files that have no extension,
        # where slicing at rfind('.') == -1 would drop the last character.
        stem = os.path.splitext(image_name)[0]
        new_path = os.path.join(target_path, stem + '.png')

        if os.path.exists(new_path):
            continue

        try:
            image = cv2.imread(os.path.join(source_path, image_name))

            if image is None:
                raise ValueError('cv2.imread could not read the file')

            cv2.imwrite(new_path, get_resized_image(image, side_size))
        except Exception as error:
            # Best effort per file; unlike a bare `except:` this still lets
            # KeyboardInterrupt / SystemExit stop the run.
            print('Failed to process {}: {}'.format(image_name, error))
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def create_sketches_dataset(source_path, target_path, sketcher, mult_list, dfm = False):
    """Write sketch images for every image in ``source_path``.

    For each source image and each entry of ``mult_list`` a file
    ``<stem>_<index>.png`` is written to ``target_path``; with ``dfm`` true a
    matching ``<stem>_<index>_dfm.png`` is written as well. A file that
    cannot be processed is reported on stdout and does not stop the run.

    Args:
        source_path: directory with the source images.
        target_path: output directory (created if missing).
        sketcher: object passed on to ``get_sketches``.
        mult_list: iterable of ``mult`` values, one sketch per entry.
        dfm: whether the distance-field images are written too.
    """
    # cv2.imwrite fails silently when the directory does not exist.
    os.makedirs(target_path, exist_ok=True)

    for image_name in os.listdir(source_path):
        # splitext keeps the full name of files that have no extension.
        stem = os.path.splitext(image_name)[0]

        try:
            image = cv2.imread(os.path.join(source_path, image_name))

            if image is None:
                raise ValueError('cv2.imread could not read the file')

            for number, (sketch_image, dfm_image) in enumerate(get_sketches(image, sketcher, mult_list, dfm)):
                new_sketch_name = stem + '_' + str(number) + '.png'
                cv2.imwrite(os.path.join(target_path, new_sketch_name), sketch_image)

                if dfm:
                    dfm_name = stem + '_' + str(number) + '_dfm.png'
                    cv2.imwrite(os.path.join(target_path, dfm_name), dfm_image)
        except Exception as error:
            # Best effort per file; KeyboardInterrupt / SystemExit still propagate.
            print('Failed to process {}: {}'.format(image_name, error))
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def create_dataset(source_path, target_path, sketcher, mult_list, side_size, dfm = False):
    """Build a training dataset: resized colour images plus their sketches.

    Resized images go to ``<target_path>/color`` as ``<stem>.png``; sketches
    (and, with ``dfm`` true, ``*_dfm.png`` images) go to ``<target_path>/bw``
    as ``<stem>_<index>.png``, one per entry of ``mult_list``. A file that
    cannot be processed is reported on stdout and does not stop the run.

    Args:
        source_path: directory with the original images.
        target_path: dataset root; both sub-directories are created if missing.
        sketcher: object passed on to ``get_sketches``.
        mult_list: iterable of ``mult`` values, one sketch per entry.
        side_size: target length of the shorter image side.
        dfm: whether the distance-field images are written too.
    """
    color_path = os.path.join(target_path, 'color')
    sketch_path = os.path.join(target_path, 'bw')

    os.makedirs(color_path, exist_ok=True)
    os.makedirs(sketch_path, exist_ok=True)

    for image_name in os.listdir(source_path):
        # splitext keeps the full name of files that have no extension,
        # where slicing at rfind('.') == -1 would drop the last character.
        stem = os.path.splitext(image_name)[0]

        try:
            image = cv2.imread(os.path.join(source_path, image_name))

            if image is None:
                raise ValueError('cv2.imread could not read the file')

            resized_image = get_resized_image(image, side_size)
            cv2.imwrite(os.path.join(color_path, stem + '.png'), resized_image)

            for number, (sketch_image, dfm_image) in enumerate(get_sketches(resized_image, sketcher, mult_list, dfm)):
                new_sketch_name = stem + '_' + str(number) + '.png'
                cv2.imwrite(os.path.join(sketch_path, new_sketch_name), sketch_image)

                if dfm:
                    dfm_name = stem + '_' + str(number) + '_dfm.png'
                    cv2.imwrite(os.path.join(sketch_path, dfm_name), dfm_image)
        except Exception as error:
            # Best effort per file; KeyboardInterrupt / SystemExit still propagate.
            print('Failed to process {}: {}'.format(image_name, error))
|
| 141 |
+
|
utils/utils.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
import scipy.stats as stats
|
| 5 |
+
import cv2
|
| 6 |
+
import json
|
| 7 |
+
import patoolib
|
| 8 |
+
import re
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from shutil import rmtree
|
| 11 |
+
|
| 12 |
+
def weights_init(m):
    """Xavier-uniform initialise ``m.weight`` when the class name of ``m`` contains ``Conv2d``.

    Meant to be passed to ``Module.apply``; other modules are left untouched.
    """
    if 'Conv2d' in m.__class__.__name__:
        nn.init.xavier_uniform_(m.weight.data)
|
| 16 |
+
|
| 17 |
+
def weights_init_spectr(m):
    """Xavier-uniform initialise ``m.weight_bar`` when the class name of ``m`` contains ``Conv2d``.

    Counterpart of ``weights_init`` for modules that keep their weight under
    the ``weight_bar`` attribute.
    """
    if 'Conv2d' in m.__class__.__name__:
        nn.init.xavier_uniform_(m.weight_bar.data)
|
| 21 |
+
|
| 22 |
+
def generate_mask(height, width, mu = 1, sigma = 0.0005, prob = 0.5, full = True, full_prob = 0.01):
    """Sample a random binary mask of shape ``(1, height, width)``.

    With ``full`` set, an all-ones mask is returned with probability
    ``full_prob``. Otherwise, with probability ``prob`` each pixel is set
    independently by comparing uniform noise against one threshold drawn
    from a normal distribution truncated to ``[0, 1]`` (mean ``mu``, scale
    ``sigma``); in the remaining cases the mask is all zeros.
    """
    lower, upper = (0 - mu) / sigma, (1 - mu) / sigma
    threshold_dist = stats.truncnorm(lower, upper, loc=mu, scale=sigma)

    if full and np.random.binomial(1, p=full_prob) == 1:
        return torch.ones(1, height, width).float()

    if np.random.binomial(1, p=prob) != 1:
        return torch.zeros(1, height, width).float()

    threshold = threshold_dist.rvs(1)[0]
    return torch.rand(1, height, width).ge(threshold).float()
|
| 35 |
+
|
| 36 |
+
def resize_pad(img, size = 512):
    """Resize ``img`` to a shorter side of ``size`` and pad the longer side to a multiple of 32.

    The input is normalised to three channels first (2-D and single-channel
    images are replicated, a fourth channel is dropped). Padding is appended
    at the end of the longer axis using ``np.pad`` mode ``'maximum'``.
    ``float32`` results are clipped to ``[0, 1]`` in place.

    Returns:
        ``(img, pad)`` where ``pad`` is ``(extra_rows, extra_cols)``; only the
        entry of the longer axis can be non-zero.
    """
    if len(img.shape) == 2:
        img = np.expand_dims(img, 2)

    if img.shape[2] == 1:
        img = np.repeat(img, 3, 2)

    if img.shape[2] == 4:
        img = img[:, :, :3]

    rows, cols = img.shape[0], img.shape[1]

    if rows < cols:
        # Height is the shorter side: fix it to `size`, scale the width.
        scale = rows / size
        width = int(np.ceil(cols / scale))
        img = cv2.resize(img, (width, size), interpolation=cv2.INTER_AREA)

        # Distance from `width` up to the next multiple of 32 (0 if aligned).
        pad = (0, -width % 32)
        img = np.pad(img, ((0, 0), (0, pad[1]), (0, 0)), 'maximum')
    else:
        # Width is the shorter (or equal) side: fix it, scale the height.
        scale = cols / size
        height = int(np.ceil(rows / scale))
        img = cv2.resize(img, (size, height), interpolation=cv2.INTER_AREA)

        pad = (-height % 32, 0)
        img = np.pad(img, ((0, pad[0]), (0, 0), (0, 0)), 'maximum')

    if img.dtype == 'float32':
        np.clip(img, 0, 1, out=img)

    return img, pad
|
| 80 |
+
|
| 81 |
+
def open_json(file):
    """Read the JSON document stored at path ``file`` and return the parsed value."""
    with open(file) as handle:
        return json.load(handle)
|
| 86 |
+
|
| 87 |
+
def extract_cbr(file, out_dir):
    """Unpack the archive ``file`` into ``out_dir`` with patool, non-interactively."""
    patoolib.extract_archive(
        file,
        outdir=out_dir,
        verbosity=1,
        interactive=False,
    )
|
| 89 |
+
|
| 90 |
+
def create_cbz(file_path, files):
    """Pack ``files`` into a new archive at ``file_path`` with patool, non-interactively."""
    patoolib.create_archive(
        file_path,
        files,
        verbosity=1,
        interactive=False,
    )
|
| 92 |
+
|
| 93 |
+
def subfolder_image_search(start_folder):
    """Recursively collect image paths below ``start_folder`` as POSIX strings.

    Files are selected with the glob ``*.[pPjJ][nNpP][gG]``, i.e. a
    three-letter extension such as ``png`` or ``jpg`` in any letter case.
    """
    found = []
    for candidate in Path(start_folder).rglob("*.[pPjJ][nNpP][gG]"):
        found.append(candidate.as_posix())
    return found
|
| 95 |
+
|
| 96 |
+
def remove_folder(folder_path):
    """Delete ``folder_path`` together with everything inside it."""
    rmtree(folder_path)
|
| 98 |
+
|
| 99 |
+
def sorted_alphanumeric(data):
    """Return ``data`` sorted in natural order.

    Each string is split into digit and non-digit runs; digit runs compare as
    integers and the other runs compare case-insensitively, so ``img2``
    sorts before ``img10``.
    """
    def natural_key(item):
        pieces = re.split('([0-9]+)', item)
        return [int(piece) if piece.isdigit() else piece.lower() for piece in pieces]

    return sorted(data, key=natural_key)
|
utils/xdog.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from cv2 import resize, INTER_LANCZOS4, INTER_AREA
|
| 2 |
+
from skimage.color import rgb2gray
|
| 3 |
+
import numpy as np
|
| 4 |
+
from scipy.ndimage.filters import gaussian_filter
|
| 5 |
+
from skimage.filters import threshold_otsu
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
|
| 8 |
+
class XDoGSketcher:
    """Turns an image into a binary sketch with the XDoG filter.

    Default filter parameters are stored in ``self.params`` and can be
    overridden per call through keyword arguments.
    """

    def __init__(self, gamma = 0.95, phi = 89.25, eps = -0.1, k = 8, sigma = 0.5, mult = 1):
        # `mult` is the upscaling factor used by get_sketch_with_resize; the
        # other entries are the XDoG parameters read by _xdog.
        self.params = {
            'gamma': gamma,
            'phi': phi,
            'eps': eps,
            'k': k,
            'sigma': sigma,
            'mult': mult,
        }

    def _xdog(self, im, **transform_params):
        """Apply XDoG to ``im`` and return a ``float32`` array of 0s and 1s."""
        # Source : https://github.com/CemalUnal/XDoG-Filter
        # Reference : XDoG: An eXtended difference-of-Gaussians compendium including advanced image stylization
        # Link : http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.365.151&rep=rep1&type=pdf

        if im.shape[2] == 3:
            im = rgb2gray(im)

        sigma = transform_params['sigma']
        narrow_blur = gaussian_filter(im, sigma)
        wide_blur = gaussian_filter(im, sigma * transform_params['k'])
        response = narrow_blur - transform_params['gamma'] * wide_blur

        eps = transform_params['eps']
        # Below eps the response is set to 1; from eps upwards a tanh ramp is used.
        ramp = 1.0 + np.tanh(transform_params['phi'] * response)
        response = (response < eps) * 1.0 + (response >= eps) * ramp

        # Normalise to [0, 1] before thresholding.
        response -= response.min()
        response /= response.max()

        binary = response >= threshold_otsu(response)

        return binary.astype('float32')

    def get_sketch(self, image, **kwargs):
        """Run the filter on ``image``; known keys in ``kwargs`` override the stored parameters."""
        effective = self.params.copy()
        effective.update({name: value for name, value in kwargs.items() if name in effective})

        return self._xdog(image, **effective)

    def get_sketch_with_resize(self, image, **kwargs):
        """Upscale ``image`` by ``mult``, sketch it, and scale the result back to the original size."""
        mult = kwargs.get('mult', self.params['mult'])

        height, width = image.shape[0], image.shape[1]

        # cv2.resize takes the target size as (width, height).
        enlarged = resize(image, (width * mult, height * mult), interpolation = INTER_LANCZOS4)
        sketch = self.get_sketch(enlarged, **kwargs)

        return resize(sketch, (width, height), interpolation = INTER_AREA)
|
| 67 |
+
|
| 68 |
+
|
web.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from flask import Flask, request, jsonify, abort, redirect, url_for, render_template, send_file
|
| 2 |
+
from flask_wtf import FlaskForm
|
| 3 |
+
from wtforms import StringField, FileField, BooleanField, DecimalField
|
| 4 |
+
from wtforms.validators import DataRequired
|
| 5 |
+
from flask import after_this_request
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
from model.models import Colorizer, Generator
|
| 11 |
+
from model.extractor import get_seresnext_extractor
|
| 12 |
+
from utils.xdog import XDoGSketcher
|
| 13 |
+
from utils.utils import open_json
|
| 14 |
+
from denoising.denoiser import FFDNetDenoiser
|
| 15 |
+
from datetime import datetime
|
| 16 |
+
|
| 17 |
+
from inference import colorize_single_image, colorize_images, colorize_cbr
|
| 18 |
+
|
| 19 |
+
# Use the GPU when one is available, otherwise run inference on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# map_location='cpu' lets checkpoints that were saved from a GPU be loaded on
# CPU-only hosts; the assembled colorizer is moved to `device` afterwards.
generator = Generator()
generator.load_state_dict(torch.load('model/generator.pth', map_location='cpu'))

extractor = get_seresnext_extractor()
extractor.load_state_dict(torch.load('model/extractor.pth', map_location='cpu'))

colorizer = Colorizer(generator, extractor)
colorizer = colorizer.eval().to(device)

# Override the sketcher defaults with the values from the config file;
# keys the sketcher does not know are ignored.
sketcher = XDoGSketcher()
xdog_config = open_json('configs/xdog_config.json')
for key in xdog_config.keys():
    if key in sketcher.params:
        sketcher.params[key] = xdog_config[key]

denoiser = FFDNetDenoiser(device)


app = Flask(__name__)
# NOTE(review): both secrets are hard-coded in source control; they should be
# loaded from the environment or a secrets store before any real deployment.
app.config.update(dict(
    SECRET_KEY="lol kek",
    WTF_CSRF_SECRET_KEY="cheburek"
))

# Base arguments handed to the colorize_* helpers; request handlers add the
# per-request options to this dict.
color_args = {'colorizer':colorizer, 'sketcher':sketcher, 'device':device, 'dfm' : True}
|
| 49 |
+
|
| 50 |
+
class SubmitForm(FlaskForm):
    """Upload form: the file to colorize plus denoising and auto-hint options."""

    # File to colorize (required).
    file = FileField(validators=[DataRequired()])

    # Denoising switch (on by default) and its sigma.
    denoise = BooleanField(default='checked')
    denoise_sigma = DecimalField(
        label='Denoise sigma',
        validators=[DataRequired()],
        default=25,
        places=None,
    )

    # Automatic hint switch (off by default) and its sigma.
    autohint = BooleanField(default=None)
    autohint_sigma = DecimalField(
        label='Autohint sigma',
        validators=[DataRequired()],
        default=0.0003,
        places=None,
    )

    ignore_gray = BooleanField(label='Ignore gray autohint', default=None)
|
| 57 |
+
|
| 58 |
+
@app.route('/img/<path>')
def show_image(path):
    """Return a minimal HTML page that displays ``/static/<path>``.

    ``path`` comes straight from the request URL, so it is HTML-escaped
    before being placed into the ``src`` attribute; otherwise a crafted URL
    could break out of the attribute and inject markup.
    """
    from html import escape

    return f'<img src="/static/{escape(path, quote=True)}">'
|
| 61 |
+
|
| 62 |
+
@app.route('/', methods=('GET', 'POST'))
def submit_data():
    """Render the upload form and colorize the submitted file.

    Archives (``.cbr``, ``.cbz``, ``.rar``, ``.zip``) are colorized with
    ``colorize_cbr`` and sent back as a download that is deleted after the
    response. Single images (``.jpg``, ``.jpeg``, ``.png``) are colorized
    into ``static/`` and the client is redirected to ``/img/<name>``. Any
    other extension yields the text ``'Wrong format'``.
    """
    form = SubmitForm()
    if not form.validate_on_submit():
        return render_template('submit.html', form=form)

    input_data = form.file.data

    _, ext = os.path.splitext(input_data.filename)
    # The upload is stored under a timestamp-based name, not the client's name.
    filename = str(datetime.now()) + ext

    input_data.save(filename)

    try:
        color_args['auto_hint'] = form.autohint.data
        color_args['auto_hint_sigma'] = float(form.autohint_sigma.data)
        color_args['ignore_gray'] = form.ignore_gray.data
        color_args['denoiser'] = None

        if form.denoise.data:
            color_args['denoiser'] = denoiser
            color_args['denoiser_sigma'] = float(form.denoise_sigma.data)

        extension = ext.lower()

        if extension in ('.cbr', '.cbz', '.rar', '.zip'):
            result_name = colorize_cbr(filename, color_args)

            @after_this_request
            def remove_file(response):
                try:
                    os.remove(result_name)
                except Exception as error:
                    # A %s placeholder is required; passing `error` without
                    # one makes the logging call itself fail to format.
                    app.logger.error("Error removing or closing downloaded file handle: %s", error)
                return response

            return send_file(result_name, mimetype='application/vnd.comicbook-rar', attachment_filename=result_name, as_attachment=True)

        # '.jpeg' was previously misspelled as ',jpeg', so JPEG files with the
        # long extension were rejected as 'Wrong format'.
        if extension in ('.jpg', '.png', '.jpeg'):
            random_name = str(datetime.now()) + '.png'
            new_image_path = os.path.join('static', random_name)

            colorize_single_image(filename, new_image_path, color_args)

            return redirect(f'/img/{random_name}')

        return 'Wrong format'
    finally:
        # Always delete the temporary upload, including on the 'Wrong format'
        # path and when colorization raises, where it used to be left on disk.
        os.remove(filename)
|