Upload 33 files
Browse files- .gitattributes +1 -0
- .gitignore +8 -0
- __pycache__/colorizator.cpython-39.pyc +0 -0
- colorizator.py +63 -0
- denoising/__pycache__/denoiser.cpython-39.pyc +0 -0
- denoising/__pycache__/functions.cpython-39.pyc +0 -0
- denoising/__pycache__/models.cpython-39.pyc +0 -0
- denoising/__pycache__/utils.cpython-39.pyc +0 -0
- denoising/denoiser.py +117 -0
- denoising/functions.py +101 -0
- denoising/models.py +100 -0
- denoising/models/net_rgb.pth +3 -0
- denoising/utils.py +66 -0
- figures/bw1.jpg +0 -0
- figures/bw2.jpg +0 -0
- figures/bw3.jpg +0 -0
- figures/bw4.jpg +0 -0
- figures/bw5.jpg +0 -0
- figures/bw6.jpg +0 -0
- figures/color1.png +0 -0
- figures/color2.png +3 -0
- figures/color3.png +0 -0
- figures/color4.png +0 -0
- figures/color5.png +0 -0
- figures/color6.png +0 -0
- inference.py +90 -0
- networks/__pycache__/extractor.cpython-39.pyc +0 -0
- networks/__pycache__/models.cpython-39.pyc +0 -0
- networks/extractor.py +127 -0
- networks/generator.zip +3 -0
- networks/models.py +319 -0
- readme.md +16 -0
- utils/__pycache__/utils.cpython-39.pyc +0 -0
- utils/utils.py +44 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
figures/color2.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.ipynb
|
| 2 |
+
*.pth
|
| 3 |
+
*.zip
|
| 4 |
+
|
| 5 |
+
__pycache__/
|
| 6 |
+
temp_colorization/
|
| 7 |
+
|
| 8 |
+
static/temp_images/
|
__pycache__/colorizator.cpython-39.pyc
ADDED
|
Binary file (2.7 kB). View file
|
|
|
colorizator.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
from torchvision.transforms import ToTensor
import numpy as np

from networks.models import Colorizer
from denoising.denoiser import FFDNetDenoiser
from utils.utils import resize_pad

# Fix: removed a stray `print(torch.__version__)` debug statement that ran as
# a side effect every time this module was imported.
| 10 |
+
class MangaColorizator:
    # End-to-end wrapper: optional FFDNet denoising -> resize/pad -> GAN colorization.

    def __init__(self, device, generator_path = 'networks/generator.zip', extractor_path = 'networks/extractor.pth'):
        """Load the generator weights and build the denoiser on *device*.

        NOTE(review): extractor_path is accepted but never used in this class;
        confirm whether it is consumed elsewhere before removing it.
        """
        self.colorizer = Colorizer().to(device)
        self.colorizer.generator.load_state_dict(torch.load(generator_path, map_location = device))
        self.colorizer = self.colorizer.eval()

        self.denoiser = FFDNetDenoiser(device)

        # State written by set_image() / update_hint() and read by colorize().
        self.current_image = None
        self.current_hint = None
        self.current_pad = None

        self.device = device

    def set_image(self, image, size = 576, apply_denoise = True, denoise_sigma = 25, transform = ToTensor()):
        """Prepare *image* for colorization: optional denoise, resize/pad to
        a multiple of 32, convert to a batched tensor, and reset the hint.

        Raises:
            RuntimeError: if *size* is not divisible by 32.
        """
        if (size % 32 != 0):
            raise RuntimeError("size is not divisible by 32")

        if apply_denoise:
            image = self.denoiser.get_denoised_image(image, sigma = denoise_sigma)

        image, self.current_pad = resize_pad(image, size)
        self.current_image = transform(image).unsqueeze(0).to(self.device)
        # Hint tensor has 4 channels (3 color + 1 mask); all zeros = no hints.
        self.current_hint = torch.zeros(1, 4, self.current_image.shape[2], self.current_image.shape[3]).float().to(self.device)

    def update_hint(self, hint, mask):
        '''
        Args:
            hint: numpy.ndarray with shape (self.current_image.shape[2], self.current_image.shape[3], 3)
            mask: numpy.ndarray with shape (self.current_image.shape[2], self.current_image.shape[3])
        '''
        # Integer hints are assumed to be 0-255 and are rescaled to [0, 1].
        if issubclass(hint.dtype.type, np.integer):
            hint = hint.astype('float32') / 255

        # Map [0, 1] -> [-1, 1] to match the generator's input range.
        hint = (hint - 0.5) / 0.5
        hint = torch.FloatTensor(hint).permute(2, 0, 1)
        mask = torch.FloatTensor(np.expand_dims(mask, 0))

        # Zero hint colors outside the mask; append the mask as channel 4.
        self.current_hint = torch.cat([hint * mask, mask], 0).unsqueeze(0).to(self.device)

    def colorize(self):
        """Run the generator on the current image+hint and return an
        HxWx3 float numpy array in [0, 1], with padding cropped away."""
        with torch.no_grad():
            fake_color, _ = self.colorizer(torch.cat([self.current_image, self.current_hint], 1))
            fake_color = fake_color.detach()

        # [-1, 1] -> [0, 1] and CHW -> HWC.
        result = fake_color[0].detach().cpu().permute(1, 2, 0) * 0.5 + 0.5

        # Crop the bottom/right padding recorded by resize_pad().
        if self.current_pad[0] != 0:
            result = result[:-self.current_pad[0]]
        if self.current_pad[1] != 0:
            result = result[:, :-self.current_pad[1]]

        return result.numpy()
|
denoising/__pycache__/denoiser.cpython-39.pyc
ADDED
|
Binary file (3.48 kB). View file
|
|
|
denoising/__pycache__/functions.cpython-39.pyc
ADDED
|
Binary file (3.6 kB). View file
|
|
|
denoising/__pycache__/models.cpython-39.pyc
ADDED
|
Binary file (3.49 kB). View file
|
|
|
denoising/__pycache__/utils.cpython-39.pyc
ADDED
|
Binary file (2.39 kB). View file
|
|
|
denoising/denoiser.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Denoise an image with the FFDNet denoising method
|
| 3 |
+
|
| 4 |
+
Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
|
| 5 |
+
|
| 6 |
+
This program is free software: you can use, modify and/or
|
| 7 |
+
redistribute it under the terms of the GNU General Public
|
| 8 |
+
License as published by the Free Software Foundation, either
|
| 9 |
+
version 3 of the License, or (at your option) any later
|
| 10 |
+
version. You should have received a copy of this license along
|
| 11 |
+
this program. If not, see <http://www.gnu.org/licenses/>.
|
| 12 |
+
"""
|
| 13 |
+
import os
|
| 14 |
+
import argparse
|
| 15 |
+
import time
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import cv2
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
from torch.autograd import Variable
|
| 23 |
+
from .models import FFDNet
|
| 24 |
+
from .utils import normalize, variable_to_cv2_image, remove_dataparallel_wrapper, is_rgb
|
| 25 |
+
|
| 26 |
+
class FFDNetDenoiser:
    """Convenience wrapper around a pretrained FFDNet denoising network."""

    def __init__(self, _device, _sigma = 25, _weights_dir = 'denoising/models/', _in_ch = 3):
        """Build the FFDNet model and load its pretrained weights.

        Args:
            _device: 'cuda' or 'cpu'.
            _sigma: default noise level on the 0-255 scale.
            _weights_dir: directory containing net_rgb.pth / net_gray.pth.
            _in_ch: image channel count (3 = RGB, otherwise grayscale weights).
        """
        self.sigma = _sigma / 255  # noise level rescaled to [0, 1]
        self.weights_dir = _weights_dir
        self.channels = _in_ch
        self.device = _device

        self.model = FFDNet(num_input_channels = _in_ch)
        self.load_weights()
        self.model.eval()

    def load_weights(self):
        """Load the checkpoint matching the channel count. On CUDA the model
        is wrapped in DataParallel; on CPU the checkpoint's 'module.' prefix
        is stripped instead."""
        weights_name = 'net_rgb.pth' if self.channels == 3 else 'net_gray.pth'
        weights_path = os.path.join(self.weights_dir, weights_name)
        if self.device == 'cuda':
            state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
            device_ids = [0]
            self.model = nn.DataParallel(self.model, device_ids=device_ids).cuda()
        else:
            state_dict = torch.load(weights_path, map_location='cpu')
            # CPU mode: remove the DataParallel wrapper
            state_dict = remove_dataparallel_wrapper(state_dict)
        self.model.load_state_dict(state_dict)

    def get_denoised_image(self, imorig, sigma = None):
        """Denoise a single image.

        Args:
            imorig: numpy array, HxW, HxWx1 or HxWxC (channels beyond 3,
                e.g. alpha, are dropped); uint8 or float values.
            sigma: optional noise level on the 0-255 scale; defaults to the
                value given at construction time.

        Returns:
            uint8 image as produced by variable_to_cv2_image (BGR layout
            for color output).
        """
        if sigma is not None:
            cur_sigma = sigma / 255
        else:
            cur_sigma = self.sigma

        # Promote grayscale input to 3 channels.
        # Fix: drop a trailing singleton channel first -- the old code ran
        # expand_dims on an (H, W, 1) array, producing a 4-D (H, W, 3, 1)
        # array that crashed on the transpose below.
        if len(imorig.shape) == 3 and imorig.shape[2] == 1:
            imorig = imorig[:, :, 0]
        if len(imorig.shape) < 3:
            imorig = np.repeat(np.expand_dims(imorig, 2), 3, 2)

        imorig = imorig[..., :3]  # drop alpha / extra channels

        # Cap the longest side at 1200 px to bound memory and compute.
        if (max(imorig.shape[0], imorig.shape[1]) > 1200):
            ratio = max(imorig.shape[0], imorig.shape[1]) / 1200
            imorig = cv2.resize(imorig, (int(imorig.shape[1] / ratio), int(imorig.shape[0] / ratio)), interpolation = cv2.INTER_AREA)

        imorig = imorig.transpose(2, 0, 1)  # HWC -> CHW

        # Heuristic: values above 1.2 mean the image is still 0-255 scaled.
        if (imorig.max() > 1.2):
            imorig = normalize(imorig)
        imorig = np.expand_dims(imorig, 0)  # add batch dimension

        # Handle odd sizes: the network downsamples by 2, so pad H/W to even.
        expanded_h = False
        expanded_w = False
        sh_im = imorig.shape
        if sh_im[2]%2 == 1:
            expanded_h = True
            imorig = np.concatenate((imorig, imorig[:, :, -1, :][:, :, np.newaxis, :]), axis=2)

        if sh_im[3]%2 == 1:
            expanded_w = True
            imorig = np.concatenate((imorig, imorig[:, :, :, -1][:, :, :, np.newaxis]), axis=3)

        imorig = torch.Tensor(imorig)

        # Sets data type according to CPU or GPU modes
        if self.device == 'cuda':
            dtype = torch.cuda.FloatTensor
        else:
            dtype = torch.FloatTensor

        imnoisy = imorig.clone()

        with torch.no_grad():
            imorig, imnoisy = imorig.type(dtype), imnoisy.type(dtype)
            nsigma = torch.FloatTensor([cur_sigma]).type(dtype)

            # Estimate noise and subtract it from the input image
            im_noise_estim = self.model(imnoisy, nsigma)
            outim = torch.clamp(imnoisy-im_noise_estim, 0., 1.)

        # Undo the even-size padding added above.
        if expanded_h:
            imorig = imorig[:, :, :-1, :]
            outim = outim[:, :, :-1, :]
            imnoisy = imnoisy[:, :, :-1, :]

        if expanded_w:
            imorig = imorig[:, :, :, :-1]
            outim = outim[:, :, :, :-1]
            imnoisy = imnoisy[:, :, :, :-1]

        return variable_to_cv2_image(outim)
|
denoising/functions.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Functions implementing custom NN layers
|
| 3 |
+
|
| 4 |
+
Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
|
| 5 |
+
|
| 6 |
+
This program is free software: you can use, modify and/or
|
| 7 |
+
redistribute it under the terms of the GNU General Public
|
| 8 |
+
License as published by the Free Software Foundation, either
|
| 9 |
+
version 3 of the License, or (at your option) any later
|
| 10 |
+
version. You should have received a copy of this license along
|
| 11 |
+
this program. If not, see <http://www.gnu.org/licenses/>.
|
| 12 |
+
"""
|
| 13 |
+
import torch
|
| 14 |
+
from torch.autograd import Function, Variable
|
| 15 |
+
|
| 16 |
+
def concatenate_input_noise_map(input, noise_sigma):
    r"""First layer of FFDNet: space-to-depth plus a noise map.

    Each CxHxW image in the batch is rearranged into a 4*C x H/2 x W/2
    array (the four pixels of every non-overlapping 2x2 patch are spread
    along the channel axis), and a constant per-sample noise map of shape
    C x H/2 x W/2 is prepended.

    Args:
        input: batch containing CxHxW images
        noise_sigma: per-sample noise level (one value per batch element)
    """
    N, C, H, W = input.size()
    dtype = input.type()
    scale = 2
    patch = scale * scale          # 4 sub-pixels per 2x2 cell
    out_channels = patch * C
    out_h, out_w = H // scale, W // scale
    # (row, col) offsets of the four sub-pixels inside each 2x2 patch
    offsets = ((0, 0), (0, 1), (1, 0), (1, 1))

    # Zero-initialised container for the space-to-depth result.
    if 'cuda' in dtype:
        mosaic = torch.cuda.FloatTensor(N, out_channels, out_h, out_w).fill_(0)
    else:
        mosaic = torch.FloatTensor(N, out_channels, out_h, out_w).fill_(0)

    # Broadcast each sample's sigma over C x H/2 x W/2.
    noise_map = noise_sigma.view(N, 1, 1, 1).repeat(1, C, out_h, out_w)

    # De-interleave: sub-pixel k of every patch goes to channels k, k+4, ...
    for k, (dy, dx) in enumerate(offsets):
        mosaic[:, k:out_channels:patch, :, :] = input[:, :, dy::scale, dx::scale]

    # Noise map first, then the de-interleaved mosaic.
    return torch.cat((noise_map, mosaic), 1)
|
| 54 |
+
|
| 55 |
+
class UpSampleFeaturesFunction(Function):
    r"""Extends PyTorch's modules by implementing a torch.autograd.Function.
    This class implements the forward and backward methods of the last layer
    of FFDNet. It basically performs the inverse of
    concatenate_input_noise_map(): it converts each of the images of a
    batch of size CxH/2xW/2 to images of size C/4xHxW
    """
    @staticmethod
    def forward(ctx, input):
        N, Cin, Hin, Win = input.size()
        dtype = input.type()
        sca = 2
        sca2 = sca*sca
        Cout = Cin//sca2  # each output channel gathers 4 input channels
        Hout = Hin*sca
        Wout = Win*sca
        # (row, col) offsets of the 4 sub-pixels inside each 2x2 output cell
        idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]

        assert (Cin%sca2 == 0), 'Invalid input dimensions: number of channels should be divisible by 4'

        result = torch.zeros((N, Cout, Hout, Wout)).type(dtype)
        # Depth-to-space: interleave channel groups back into 2x2 cells.
        for idx in range(sca2):
            result[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca] = input[:, idx:Cin:sca2, :, :]

        return result

    @staticmethod
    def backward(ctx, grad_output):
        # The gradient of a pure permutation is the inverse permutation:
        # run the space-to-depth rearrangement on the incoming gradient.
        N, Cg_out, Hg_out, Wg_out = grad_output.size()
        dtype = grad_output.data.type()
        sca = 2
        sca2 = sca*sca
        Cg_in = sca2*Cg_out
        Hg_in = Hg_out//sca
        Wg_in = Wg_out//sca
        idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]

        # Build output
        grad_input = torch.zeros((N, Cg_in, Hg_in, Wg_in)).type(dtype)
        # Populate output
        for idx in range(sca2):
            grad_input[:, idx:Cg_in:sca2, :, :] = grad_output.data[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca]

        # NOTE(review): legacy Variable/.data autograd style (pre-0.4 API);
        # kept as-is for behavioral compatibility.
        return Variable(grad_input)

# Alias functions
upsamplefeatures = UpSampleFeaturesFunction.apply
|
denoising/models.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Definition of the FFDNet model and its custom layers
|
| 3 |
+
|
| 4 |
+
Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
|
| 5 |
+
|
| 6 |
+
This program is free software: you can use, modify and/or
|
| 7 |
+
redistribute it under the terms of the GNU General Public
|
| 8 |
+
License as published by the Free Software Foundation, either
|
| 9 |
+
version 3 of the License, or (at your option) any later
|
| 10 |
+
version. You should have received a copy of this license along
|
| 11 |
+
this program. If not, see <http://www.gnu.org/licenses/>.
|
| 12 |
+
"""
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from torch.autograd import Variable
|
| 15 |
+
import denoising.functions as functions
|
| 16 |
+
|
| 17 |
+
class UpSampleFeatures(nn.Module):
    r"""Implements the last layer of FFDNet
    (depth-to-space via functions.upsamplefeatures: 4*C x H/2 x W/2 -> C x H x W).
    """
    def __init__(self):
        super(UpSampleFeatures, self).__init__()
    def forward(self, x):
        return functions.upsamplefeatures(x)
|
| 24 |
+
|
| 25 |
+
class IntermediateDnCNN(nn.Module):
    r"""The DnCNN-style trunk of FFDNet: a Conv+ReLU head,
    (num_conv_layers - 2) Conv+BN+ReLU groups, and a final Conv;
    every convolution is 3x3, padding 1, without bias.
    """
    def __init__(self, input_features, middle_features, num_conv_layers):
        super(IntermediateDnCNN, self).__init__()
        self.kernel_size = 3
        self.padding = 1
        self.input_features = input_features
        self.num_conv_layers = num_conv_layers
        self.middle_features = middle_features
        # 5 = 4 downsampled grayscale channels + 1 noise channel;
        # 15 = 12 downsampled RGB channels + 3 noise channels.
        if self.input_features == 5:
            self.output_features = 4 #Grayscale image
        elif self.input_features == 15:
            self.output_features = 12 #RGB image
        else:
            raise Exception('Invalid number of input features')

        def conv3x3(cin, cout):
            # Shared convolution shape used throughout the trunk.
            return nn.Conv2d(in_channels=cin, out_channels=cout,
                             kernel_size=self.kernel_size,
                             padding=self.padding, bias=False)

        layers = [conv3x3(self.input_features, self.middle_features),
                  nn.ReLU(inplace=True)]
        for _ in range(self.num_conv_layers - 2):
            layers += [conv3x3(self.middle_features, self.middle_features),
                       nn.BatchNorm2d(self.middle_features),
                       nn.ReLU(inplace=True)]
        layers.append(conv3x3(self.middle_features, self.output_features))
        # NOTE: the attribute keeps the original ("itermediate") spelling --
        # it is baked into the pretrained checkpoints' state_dict keys.
        self.itermediate_dncnn = nn.Sequential(*layers)

    def forward(self, x):
        return self.itermediate_dncnn(x)
|
| 67 |
+
|
| 68 |
+
class FFDNet(nn.Module):
    r"""Implements the FFDNet architecture:
    space-to-depth + noise map -> DnCNN trunk -> depth-to-space,
    predicting the noise residual of the input image.
    """
    def __init__(self, num_input_channels):
        super(FFDNet, self).__init__()
        self.num_input_channels = num_input_channels
        if self.num_input_channels == 1:
            # Grayscale image
            self.num_feature_maps = 64
            self.num_conv_layers = 15
            self.downsampled_channels = 5
            self.output_features = 4
        elif self.num_input_channels == 3:
            # RGB image
            self.num_feature_maps = 96
            self.num_conv_layers = 12
            self.downsampled_channels = 15
            self.output_features = 12
        else:
            raise Exception('Invalid number of input features')

        self.intermediate_dncnn = IntermediateDnCNN(\
                input_features=self.downsampled_channels,\
                middle_features=self.num_feature_maps,\
                num_conv_layers=self.num_conv_layers)
        self.upsamplefeatures = UpSampleFeatures()

    def forward(self, x, noise_sigma):
        # NOTE(review): legacy Variable/.data usage (pre-0.4 autograd API);
        # Variable is a no-op wrapper on modern PyTorch. Kept as-is.
        concat_noise_x = functions.concatenate_input_noise_map(x.data, noise_sigma.data)
        concat_noise_x = Variable(concat_noise_x)
        h_dncnn = self.intermediate_dncnn(concat_noise_x)
        # Predicted noise, same spatial size as x.
        pred_noise = self.upsamplefeatures(h_dncnn)
        return pred_noise
|
denoising/models/net_rgb.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0fe98bfd2ac870b15f360661b1c4789eecefc6dc2e4462842a0dd15e149a0433
|
| 3 |
+
size 3435567
|
denoising/utils.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Different utilities such as orthogonalization of weights, initialization of
|
| 3 |
+
loggers, etc
|
| 4 |
+
|
| 5 |
+
Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
|
| 6 |
+
|
| 7 |
+
This program is free software: you can use, modify and/or
|
| 8 |
+
redistribute it under the terms of the GNU General Public
|
| 9 |
+
License as published by the Free Software Foundation, either
|
| 10 |
+
version 3 of the License, or (at your option) any later
|
| 11 |
+
version. You should have received a copy of this license along
|
| 12 |
+
this program. If not, see <http://www.gnu.org/licenses/>.
|
| 13 |
+
"""
|
| 14 |
+
import numpy as np
|
| 15 |
+
import cv2
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def variable_to_cv2_image(varim):
    r"""Converts a torch.autograd.Variable to an OpenCV image

    Args:
        varim: a torch.autograd.Variable

    Returns:
        uint8 numpy image; single-channel input stays grayscale, 3-channel
        input is converted RGB -> BGR (OpenCV convention). Only the first
        image of the batch is converted.
    """
    nchannels = varim.size()[1]
    if nchannels == 1:
        res = (varim.data.cpu().numpy()[0, 0, :]*255.).clip(0, 255).astype(np.uint8)
    elif nchannels == 3:
        res = varim.data.cpu().numpy()[0]
        # CHW -> HWC, then RGB -> BGR for OpenCV.
        res = cv2.cvtColor(res.transpose(1, 2, 0), cv2.COLOR_RGB2BGR)
        res = (res*255.).clip(0, 255).astype(np.uint8)
    else:
        raise Exception('Number of color channels not supported')
    return res
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def normalize(data):
    """Map 8-bit intensities in [0, 255] to float32 values in [0, 1]."""
    scaled = data / 255.
    return np.float32(scaled)
|
| 38 |
+
|
| 39 |
+
def remove_dataparallel_wrapper(state_dict):
    r"""Converts a DataParallel model to a normal one by removing the "module."
    wrapper in the module dictionary

    Args:
        state_dict: a torch.nn.DataParallel state dictionary

    Returns:
        An OrderedDict with the same values, whose keys have any leading
        'module.' prefix removed. Keys without the prefix pass through
        unchanged.
    """
    from collections import OrderedDict

    new_state_dict = OrderedDict()
    for k, vl in state_dict.items():
        # Fix: only strip when the prefix is actually present -- the old code
        # blindly dropped the first 7 characters, corrupting the keys of
        # checkpoints that were saved without DataParallel.
        name = k[len('module.'):] if k.startswith('module.') else k
        new_state_dict[name] = vl

    return new_state_dict
|
| 54 |
+
|
| 55 |
+
def is_rgb(im_path):
    r""" Returns True if the image in im_path is an RGB image
    (a 3-channel image whose channels are not all numerically identical;
    grayscale-saved-as-RGB counts as not-RGB).
    """
    from skimage.io import imread
    rgb = False
    im = imread(im_path)
    if (len(im.shape) == 3):
        # If all three channels match, the image is effectively grayscale.
        if not(np.allclose(im[...,0], im[...,1]) and np.allclose(im[...,2], im[...,1])):
            rgb = True
    print("rgb: {}".format(rgb))
    print("im shape: {}".format(im.shape))
    return rgb
|
figures/bw1.jpg
ADDED
|
figures/bw2.jpg
ADDED
|
figures/bw3.jpg
ADDED
|
figures/bw4.jpg
ADDED
|
figures/bw5.jpg
ADDED
|
figures/bw6.jpg
ADDED
|
figures/color1.png
ADDED
|
figures/color2.png
ADDED
|
Git LFS Details
|
figures/color3.png
ADDED
|
figures/color4.png
ADDED
|
figures/color5.png
ADDED
|
figures/color6.png
ADDED
|
inference.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
|
| 8 |
+
from colorizator import MangaColorizator
|
| 9 |
+
|
| 10 |
+
def process_image(image, colorizator, args):
    """Colorize a single in-memory image using the CLI settings in *args*
    (size, denoiser toggle, denoiser sigma); returns the colorized array."""
    colorizator.set_image(image, args.size, args.denoiser, args.denoiser_sigma)

    return colorizator.colorize()
|
| 14 |
+
|
| 15 |
+
def colorize_single_image(image_path, save_path, colorizator, args):
    """Read an image from disk, colorize it, and write the result to
    *save_path*. Always returns True (errors propagate as exceptions)."""
    image = plt.imread(image_path)

    colorization = process_image(image, colorizator, args)

    plt.imsave(save_path, colorization)

    return True
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def colorize_images(target_path, colorizator, args):
    """Colorize every file directly inside args.path, saving PNGs into
    *target_path*.

    NOTE(review): the source directory is taken from args.path, not from
    target_path -- target_path is only the output folder; confirm intent.
    Subdirectories are skipped; non-image files are not filtered and will
    fail inside plt.imread.
    """
    images = os.listdir(args.path)

    for image_name in images:
        file_path = os.path.join(args.path, image_name)

        if os.path.isdir(file_path):
            continue

        # Results are always written with a .png extension.
        name, ext = os.path.splitext(image_name)
        if (ext != '.png'):
            image_name = name + '.png'

        print(file_path)

        save_path = os.path.join(target_path, image_name)
        colorize_single_image(file_path, save_path, colorizator, args)
|
| 43 |
+
|
| 44 |
+
def parse_args():
    """Build and parse the command-line options for the colorization CLI.

    Defaults: denoising enabled, GPU disabled, sigma 25, target size 576.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("-p", "--path", required=True)
    cli.add_argument("-gen", "--generator", default = 'networks/generator.zip')
    cli.add_argument("-ext", "--extractor", default = 'networks/extractor.pth')
    # Boolean toggles: -g turns the GPU on, -nd turns denoising off.
    cli.add_argument('-g', '--gpu', dest = 'gpu', action = 'store_true')
    cli.add_argument('-nd', '--no_denoise', dest = 'denoiser', action = 'store_false')
    cli.add_argument("-ds", "--denoiser_sigma", type = int, default = 25)
    cli.add_argument("-s", "--size", type = int, default = 576)
    cli.set_defaults(gpu = False, denoiser = True)
    return cli.parse_args()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
if __name__ == "__main__":

    args = parse_args()

    device = 'cuda' if args.gpu else 'cpu'

    colorizer = MangaColorizator(device, args.generator, args.extractor)

    if os.path.isdir(args.path):
        # Directory mode: results go into a 'colorization' subfolder.
        colorization_path = os.path.join(args.path, 'colorization')
        if not os.path.exists(colorization_path):
            os.makedirs(colorization_path)

        colorize_images(colorization_path, colorizer, args)

    elif os.path.isfile(args.path):

        split = os.path.splitext(args.path)

        # Fix: '.jpeg' was misspelled as ',jpeg' (comma instead of dot),
        # so .jpeg files were rejected as "Wrong format".
        if split[1].lower() in ('.jpg', '.png', '.jpeg'):
            new_image_path = split[0] + '_colorized' + '.png'

            colorize_single_image(args.path, new_image_path, colorizer, args)
        else:
            print('Wrong format')
    else:
        print('Wrong path')
|
| 90 |
+
|
networks/__pycache__/extractor.cpython-39.pyc
ADDED
|
Binary file (3.82 kB). View file
|
|
|
networks/__pycache__/models.cpython-39.pyc
ADDED
|
Binary file (10.7 kB). View file
|
|
|
networks/extractor.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
'''https://github.com/blandocs/Tag2Pix/blob/master/model/pretrained.py'''
|
| 6 |
+
|
| 7 |
+
# Pretrained version
|
| 8 |
+
class Selayer(nn.Module):
    """Squeeze-and-Excitation gate: rescales each channel of the input by a
    learned factor in (0, 1) derived from globally pooled statistics
    (channel reduction ratio 16).
    """
    def __init__(self, inplanes):
        super(Selayer, self).__init__()
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # squeeze (global pool) -> bottleneck -> excite, then gate channels
        squeezed = self.global_avgpool(x)
        gate = self.relu(self.conv1(squeezed))
        gate = self.sigmoid(self.conv2(gate))
        return x * gate
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class BottleneckX_Origin(nn.Module):
    # ResNeXt bottleneck block with a squeeze-and-excitation gate:
    # 1x1 expand (2x) -> 3x3 grouped conv -> 1x1 project (4x) -> SE,
    # then the residual add and a final ReLU.
    expansion = 4

    def __init__(self, inplanes, planes, cardinality, stride=1, downsample=None):
        super(BottleneckX_Origin, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes * 2)

        # Grouped 3x3 convolution; `cardinality` is the ResNeXt group count.
        self.conv2 = nn.Conv2d(planes * 2, planes * 2, kernel_size=3, stride=stride,
                               padding=1, groups=cardinality, bias=False)
        self.bn2 = nn.BatchNorm2d(planes * 2)

        self.conv3 = nn.Conv2d(planes * 2, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)

        self.selayer = Selayer(planes * 4)

        self.relu = nn.ReLU(inplace=True)
        # Optional projection for the residual path when shape/stride differ.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out = self.selayer(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
|
| 71 |
+
|
| 72 |
+
class SEResNeXt_Origin(nn.Module):
    """SE-ResNeXt feature extractor truncated after three stages.

    Unlike a full classifier it keeps no head; forward() returns the four
    intermediate feature maps (stem output plus three residual stages) for
    use as an encoder. Attribute names match the original for checkpoint
    compatibility.
    """

    def __init__(self, block, layers, input_channels=3, cardinality=32, num_classes=1000):
        # NOTE(review): `num_classes` is accepted but unused — there is no
        # classification head in this truncated backbone.
        super(SEResNeXt_Origin, self).__init__()
        self.cardinality = cardinality
        self.inplanes = 64
        self.input_channels = input_channels

        # Stem: 7x7 stride-2 conv, so the first feature map is at 1/2 scale.
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)

        # He-style initialization for convs, unit-gain for batch norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack `blocks` residual units; the first one may downsample/project."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut whenever the spatial size or width changes.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        units = [block(self.inplanes, planes, self.cardinality, stride, downsample)]
        self.inplanes = planes * block.expansion
        units.extend(block(self.inplanes, planes, self.cardinality)
                     for _ in range(blocks - 1))
        return nn.Sequential(*units)

    def forward(self, x):
        """Return (stem, stage1, stage2, stage3) feature maps."""
        x1 = self.relu(self.bn1(self.conv1(x)))
        x2 = self.layer1(x1)
        x3 = self.layer2(x2)
        x4 = self.layer3(x3)
        return x1, x2, x3, x4
|
networks/generator.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ae9bc204753267a38eeb43d262fee9c96fb1b5035fd89bbdf567fc69e5d3ebd1
|
| 3 |
+
size 202088636
|
networks/models.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torchvision.models as M
|
| 5 |
+
import math
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
from torch.nn import Parameter
|
| 8 |
+
|
| 9 |
+
from .extractor import SEResNeXt_Origin, BottleneckX_Origin
|
| 10 |
+
|
| 11 |
+
'''https://github.com/orashi/AlacGAN/blob/master/models/standard.py'''
|
| 12 |
+
|
| 13 |
+
def l2normalize(v, eps=1e-12):
    """Return `v` scaled to unit L2 norm; `eps` guards against division by zero."""
    norm = v.norm()
    return v / (norm + eps)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class SpectralNorm(nn.Module):
    """Spectral normalization wrapper (Miyato et al., 2018) for a module's weight.

    On construction the wrapped module's `weight` parameter is replaced by three
    registered parameters: `weight_bar` (the unconstrained weight) and the
    singular vectors `weight_u` / `weight_v`. On every forward pass the top
    singular value is estimated by power iteration and `weight` is re-set to
    `weight_bar / sigma` before delegating to the wrapped module.
    """

    def __init__(self, module, name='weight', power_iterations=1):
        """Wrap `module`, normalizing its parameter called `name`.

        power_iterations: number of power-iteration steps per forward pass.
        """
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        # Only split the weight into (u, v, bar) once; re-wrapping is a no-op.
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        """Run power iteration and overwrite `module.<name>` with the normalized weight."""
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        # Flatten the weight to a (height, -1) matrix for the power iteration.
        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            # Updates happen on .data so they are excluded from autograd.
            v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))

        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        # Estimated top singular value; this expression keeps the graph through w.
        sigma = u.dot(w.view(height, -1).mv(v))
        # Re-install the normalized weight as a plain attribute (not a Parameter).
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        """Return True if the (u, v, bar) parameters already exist on the module."""
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False


    def _make_params(self):
        """Replace `module.<name>` with `<name>_u`, `<name>_v`, `<name>_bar` parameters."""
        w = getattr(self.module, self.name)
        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        # Random unit-norm singular-vector estimates; excluded from gradients.
        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        # Remove the original parameter so `weight` can be set as a tensor each forward.
        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)


    def forward(self, *args):
        """Normalize the weight, then delegate to the wrapped module."""
        self._update_u_v()
        return self.module.forward(*args)
|
| 71 |
+
|
| 72 |
+
class Selayer(nn.Module):
    """Squeeze-and-Excitation channel gate with a fixed reduction ratio of 16."""

    def __init__(self, inplanes):
        super(Selayer, self).__init__()
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Squeeze to 1x1, excite through the bottleneck, gate into (0, 1),
        # then rescale the input channel-wise.
        scale = self.global_avgpool(x)
        scale = self.sigmoid(self.conv2(self.relu(self.conv1(scale))))
        return x * scale
|
| 89 |
+
|
| 90 |
+
class SelayerSpectr(nn.Module):
    """Squeeze-and-Excitation gate whose 1x1 convs are wrapped in spectral norm."""

    def __init__(self, inplanes):
        super(SelayerSpectr, self).__init__()
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = SpectralNorm(nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1))
        self.conv2 = SpectralNorm(nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1))
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Identical structure to Selayer, but each conv is spectrally normalized.
        scale = self.global_avgpool(x)
        scale = self.sigmoid(self.conv2(self.relu(self.conv1(scale))))
        return x * scale
|
| 107 |
+
|
| 108 |
+
class ResNeXtBottleneck(nn.Module):
    """Dilated ResNeXt bottleneck with SE gating, used inside the tunnels.

    The shortcut is the identity unless stride != 1, in which case it is
    2x2 average-pooled to match the main branch's spatial size.
    Attribute names are preserved for state-dict compatibility.
    """

    def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
        super(ResNeXtBottleneck, self).__init__()
        D = out_channels // 2
        self.out_channels = out_channels
        self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False)
        # Kernel size 2 + stride keeps the output aligned with the pooled shortcut.
        self.conv_conv = nn.Conv2d(D, D, kernel_size=2 + stride, stride=stride, padding=dilate, dilation=dilate,
                                   groups=cardinality,
                                   bias=False)
        self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.shortcut = nn.Sequential()
        if stride != 1:
            self.shortcut.add_module('shortcut',
                                     nn.AvgPool2d(2, stride=2))

        self.selayer = Selayer(out_channels)

    def forward(self, x):
        y = F.leaky_relu(self.conv_reduce(x), 0.2, True)
        y = F.leaky_relu(self.conv_conv(y), 0.2, True)
        y = self.selayer(self.conv_expand(y))
        return self.shortcut(x) + y
|
| 135 |
+
|
| 136 |
+
class SpectrResNeXtBottleneck(nn.Module):
    """Spectrally-normalized variant of ResNeXtBottleneck (same topology,
    every conv wrapped in SpectralNorm and the SE gate replaced by SelayerSpectr).
    """

    def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
        super(SpectrResNeXtBottleneck, self).__init__()
        D = out_channels // 2
        self.out_channels = out_channels
        self.conv_reduce = SpectralNorm(nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False))
        # Kernel size 2 + stride keeps the output aligned with the pooled shortcut.
        self.conv_conv = SpectralNorm(nn.Conv2d(D, D, kernel_size=2 + stride, stride=stride, padding=dilate, dilation=dilate,
                                                groups=cardinality,
                                                bias=False))
        self.conv_expand = SpectralNorm(nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False))
        self.shortcut = nn.Sequential()
        if stride != 1:
            self.shortcut.add_module('shortcut',
                                     nn.AvgPool2d(2, stride=2))

        self.selayer = SelayerSpectr(out_channels)

    def forward(self, x):
        y = F.leaky_relu(self.conv_reduce(x), 0.2, True)
        y = F.leaky_relu(self.conv_conv(y), 0.2, True)
        y = self.selayer(self.conv_expand(y))
        return self.shortcut(x) + y
|
| 163 |
+
|
| 164 |
+
class FeatureConv(nn.Module):
    """Three 3x3 convs with ReLU; the middle conv downsamples by stride 2.

    The original code carried a `no_bn` flag that was hard-coded to True, so
    BatchNorm layers were never instantiated; this version simply omits the
    dead branch. Module indices inside `network` are unchanged.
    """

    def __init__(self, input_dim=512, output_dim=512):
        super(FeatureConv, self).__init__()
        self.network = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.network(x)
|
| 184 |
+
|
| 185 |
+
class Generator(nn.Module):
    """Colorization generator.

    Fuses two encoders — a SE-ResNeXt backbone over the grayscale channel and a
    lightweight conv encoder (`to0`..`to3`) over the full input — then decodes
    through dilated-ResNeXt "tunnel" stages with PixelShuffle upsampling.
    forward() returns both the colorized image and an auxiliary decoder output.
    """

    def __init__(self, ngf=64):
        # NOTE(review): `ngf` is accepted but unused; all widths are hard-coded.
        super(Generator, self).__init__()

        # Grayscale backbone; returns features at 1/2, 1/2, 1/4 and 1/8 scale
        # with 64 / 256 / 512 / 1024 channels respectively.
        self.encoder = SEResNeXt_Origin(BottleneckX_Origin, [3, 4, 6, 3], num_classes= 370, input_channels=1)

        # Second encoder over the full 5-channel input (assumes grayscale plus
        # extra guidance channels — TODO confirm against the caller).
        self.to0 = self._make_encoder_block_first(5, 32)
        self.to1 = self._make_encoder_block(32, 64)
        self.to2 = self._make_encoder_block(64, 92)
        self.to3 = self._make_encoder_block(92, 128)
        # NOTE(review): `to4` and `tunnel1` are constructed but never used in
        # forward(); kept so pretrained checkpoints still load.
        self.to4 = self._make_encoder_block(128, 256)

        # Auxiliary decoder fed from the deepest fused features (`out` below).
        self.deconv_for_decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1), # output is 64 * 64
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1), # output is 128 * 128
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(64, 32, 3, stride=1, padding=1, output_padding=0), # output is 256 * 256
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(32, 3, 3, stride=1, padding=1, output_padding=0), # output is 256 * 256
            nn.Tanh(),
        )

        # Deepest tunnel: 20 non-dilated ResNeXt blocks at 512 channels.
        tunnel4 = nn.Sequential(*[ResNeXtBottleneck(512, 512, cardinality=32, dilate=1) for _ in range(20)])


        # Input is cat(x4 [1024 ch], aux_out [128 ch]); PixelShuffle doubles resolution.
        self.tunnel4 = nn.Sequential(nn.Conv2d(1024 + 128, 512, kernel_size=3, stride=1, padding=1),
                                     nn.LeakyReLU(0.2, True),
                                     tunnel4,
                                     nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1),
                                     nn.PixelShuffle(2),
                                     nn.LeakyReLU(0.2, True)
                                     ) # 64

        # Dilation pyramid 1-1-2-2-4-4-2-1 at 256 channels.
        depth = 2
        tunnel = [ResNeXtBottleneck(256, 256, cardinality=32, dilate=1) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(256, 256, cardinality=32, dilate=2) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(256, 256, cardinality=32, dilate=4) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(256, 256, cardinality=32, dilate=2),
                   ResNeXtBottleneck(256, 256, cardinality=32, dilate=1)]
        tunnel3 = nn.Sequential(*tunnel)

        # Input is cat(tunnel4 output [256 ch], x3 [512 ch]).
        self.tunnel3 = nn.Sequential(nn.Conv2d(512 + 256, 256, kernel_size=3, stride=1, padding=1),
                                     nn.LeakyReLU(0.2, True),
                                     tunnel3,
                                     nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
                                     nn.PixelShuffle(2),
                                     nn.LeakyReLU(0.2, True)
                                     ) # 128

        # Same dilation pyramid at 128 channels.
        tunnel = [ResNeXtBottleneck(128, 128, cardinality=32, dilate=1) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(128, 128, cardinality=32, dilate=2) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(128, 128, cardinality=32, dilate=4) for _ in range(depth)]
        tunnel += [ResNeXtBottleneck(128, 128, cardinality=32, dilate=2),
                   ResNeXtBottleneck(128, 128, cardinality=32, dilate=1)]
        tunnel2 = nn.Sequential(*tunnel)

        # Input is cat(tunnel3 output [128 ch], x2 [256 ch], x1 [64 ch]).
        self.tunnel2 = nn.Sequential(nn.Conv2d(128 + 256 + 64, 128, kernel_size=3, stride=1, padding=1),
                                     nn.LeakyReLU(0.2, True),
                                     tunnel2,
                                     nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
                                     nn.PixelShuffle(2),
                                     nn.LeakyReLU(0.2, True)
                                     )

        # Shallow dilation pyramid at 64 channels (unused in forward — see NOTE above).
        tunnel = [ResNeXtBottleneck(64, 64, cardinality=16, dilate=1)]
        tunnel += [ResNeXtBottleneck(64, 64, cardinality=16, dilate=2)]
        tunnel += [ResNeXtBottleneck(64, 64, cardinality=16, dilate=4)]
        tunnel += [ResNeXtBottleneck(64, 64, cardinality=16, dilate=2),
                   ResNeXtBottleneck(64, 64, cardinality=16, dilate=1)]
        tunnel1 = nn.Sequential(*tunnel)

        self.tunnel1 = nn.Sequential(nn.Conv2d(64 + 32, 64, kernel_size=3, stride=1, padding=1),
                                     nn.LeakyReLU(0.2, True),
                                     tunnel1,
                                     nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                                     nn.PixelShuffle(2),
                                     nn.LeakyReLU(0.2, True)
                                     )

        # Final head: cat(tunnel2 output [64 ch], x0 [32 ch]) -> 3-channel image.
        self.exit = nn.Sequential(nn.Conv2d(64 + 32, 32, kernel_size=3, stride=1, padding=1),
                                  nn.LeakyReLU(0.2, True),
                                  nn.Conv2d(32, 3, kernel_size= 1, stride = 1, padding = 0))


    def _make_encoder_block(self, inplanes, planes):
        # Downsampling conv pair (stride 2 then stride 1), LeakyReLU activations.
        return nn.Sequential(
            nn.Conv2d(inplanes, planes, 3, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(planes, planes, 3, 1, 1),
            nn.LeakyReLU(0.2),
        )

    def _make_encoder_block_first(self, inplanes, planes):
        # Same as _make_encoder_block but stride 1 throughout (keeps full resolution).
        return nn.Sequential(
            nn.Conv2d(inplanes, planes, 3, 1, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(planes, planes, 3, 1, 1),
            nn.LeakyReLU(0.2),
        )

    def forward(self, sketch):
        # `sketch` must have 5 channels (to0's input width); channel 0 is the
        # grayscale image fed to the SE-ResNeXt encoder.

        x0 = self.to0(sketch)
        aux_out = self.to1(x0)
        aux_out = self.to2(aux_out)
        aux_out = self.to3(aux_out)

        x1, x2, x3, x4 = self.encoder(sketch[:, 0:1])

        # Fuse the deepest features from both encoders at 1/8 scale.
        out = self.tunnel4(torch.cat([x4, aux_out], 1))



        # Progressive upsampling with skip connections from the backbone.
        x = self.tunnel3(torch.cat([out, x3], 1))

        x = self.tunnel2(torch.cat([x, x2, x1], 1))


        # Final color image in [-1, 1].
        x = torch.tanh(self.exit(torch.cat([x, x0], 1)))

        # Auxiliary reconstruction decoded from the fused deep features.
        decoder_output = self.deconv_for_decoder(out)

        return x, decoder_output
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
class Colorizer(nn.Module):
    """Top-level wrapper exposing the Generator's (colorized, guide) outputs."""

    def __init__(self):
        super(Colorizer, self).__init__()
        self.generator = Generator()

    def forward(self, x, extractor_grad = False):
        # `extractor_grad` is accepted for caller compatibility but unused here.
        return self.generator(x)
|
readme.md
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Automatic colorization
|
| 2 |
+
|
| 3 |
+
1. Download [generator](https://drive.google.com/file/d/1qmxUEKADkEM4iYLp1fpPLLKnfZ6tcF-t/view?usp=sharing) and [denoiser](https://drive.google.com/file/d/161oyQcYpdkVdw8gKz_MA8RD-Wtg9XDp3/view?usp=sharing) weights. Put generator and extractor weights in `networks` and denoiser weights in `denoising/models`.
|
| 4 |
+
2. To colorize image or folder of images, use the following command:
|
| 5 |
+
```
|
| 6 |
+
$ python inference.py -p "path to file or folder"
|
| 7 |
+
```
|
| 8 |
+
|
| 9 |
+
| Input (black & white) | Output (colorized) |
|
| 10 |
+
|------------|-------------|
|
| 11 |
+
| <img src="figures/bw1.jpg" width="512"> | <img src="figures/color1.png" width="512"> |
|
| 12 |
+
| <img src="figures/bw2.jpg" width="512"> | <img src="figures/color2.png" width="512"> |
|
| 13 |
+
| <img src="figures/bw3.jpg" width="512"> | <img src="figures/color3.png" width="512"> |
|
| 14 |
+
| <img src="figures/bw4.jpg" width="512"> | <img src="figures/color4.png" width="512"> |
|
| 15 |
+
| <img src="figures/bw5.jpg" width="512"> | <img src="figures/color5.png" width="512"> |
|
| 16 |
+
| <img src="figures/bw6.jpg" width="512"> | <img src="figures/color6.png" width="512"> |
|
utils/__pycache__/utils.cpython-39.pyc
ADDED
|
Binary file (1.06 kB). View file
|
|
|
utils/utils.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import cv2
|
| 3 |
+
|
| 4 |
+
def resize_pad(img, size = 256):
    """Resize an image for the network and pad its dimensions to multiples of 32.

    The target side depends on orientation: landscape images (H < W) are resized
    to height `size * 1.5`; all others to width `size`.

    Parameters:
        img: H x W, H x W x 1, H x W x 3 or H x W x 4 numpy image.
        size: base target size in pixels (default 256).

    Returns:
        (gray, pad): `gray` is the padded single-channel H x W x 1 image;
        `pad` is (pad_bottom, pad_right) — how much to crop after inference.
    """

    # Normalize to H x W x 3: add a channel axis to grayscale, repeat a single
    # channel, and drop the alpha channel of RGBA input.
    if len(img.shape) == 2:
        img = np.expand_dims(img, 2)

    if img.shape[2] == 1:
        img = np.repeat(img, 3, 2)

    if img.shape[2] == 4:
        img = img[:, :, :3]

    if (img.shape[0] < img.shape[1]):
        # Landscape: fix height at size * 1.5, scale width to preserve aspect.
        height = img.shape[0]
        ratio = height / (size * 1.5)
        width = int(np.ceil(img.shape[1] / ratio))
        img = cv2.resize(img, (width, int(size * 1.5)), interpolation = cv2.INTER_AREA)

        # Pad the right edge up to the next multiple of 32. `(-width) % 32` is 0
        # when already aligned; the previous `32 - width % 32` form padded a
        # needless full 32-pixel band in that case.
        pad = (0, (-width) % 32)
        img = np.pad(img, ((0, 0), (0, pad[1]), (0, 0)), 'maximum')
    else:
        # Portrait/square: fix width at size, scale height to preserve aspect.
        width = img.shape[1]
        ratio = width / size
        height = int(np.ceil(img.shape[0] / ratio))
        img = cv2.resize(img, (size, height), interpolation = cv2.INTER_AREA)

        # Pad the bottom edge up to the next multiple of 32.
        pad = ((-height) % 32, 0)
        img = np.pad(img, ((0, pad[0]), (0, 0), (0, 0)), 'maximum')

    if (img.dtype == 'float32'):
        # INTER_AREA resampling can overshoot slightly for float input; clamp in place.
        np.clip(img, 0, 1, out = img)

    # The network consumes a single channel; keep only the first.
    return img[:, :, :1], pad
|