koesan commited on
Commit
b9e6a83
·
verified ·
1 Parent(s): 131e4f3

Upload 4 files

Browse files
denoising/denoiser.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Denoise an image with the FFDNet denoising method
3
+
4
+ Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
5
+
6
+ This program is free software: you can use, modify and/or
7
+ redistribute it under the terms of the GNU General Public
8
+ License as published by the Free Software Foundation, either
9
+ version 3 of the License, or (at your option) any later
10
+ version. You should have received a copy of this license along
11
+ this program. If not, see <http://www.gnu.org/licenses/>.
12
+ """
13
+ import os
14
+ import argparse
15
+ import time
16
+
17
+
18
+ import numpy as np
19
+ import cv2
20
+ import torch
21
+ import torch.nn as nn
22
+ from torch.autograd import Variable
23
+ from .models import FFDNet
24
+ from .utils import normalize, variable_to_cv2_image, remove_dataparallel_wrapper, is_rgb
25
+
26
class FFDNetDenoiser:
    """Convenience wrapper around a pretrained FFDNet network for image denoising.

    Builds an FFDNet model, loads pretrained weights from ``_weights_dir``
    ('net_rgb.pth' for 3 channels, 'net_gray.pth' otherwise) and switches the
    model to eval mode.

    Args:
        _device: 'cuda' to run on GPU (model gets wrapped in DataParallel);
            any other value runs on CPU.
        _sigma: noise standard deviation on the 0-255 scale; stored
            internally divided by 255.
        _weights_dir: directory containing the pretrained weight files.
        _in_ch: number of input channels (3 for RGB, 1 for grayscale).
    """

    def __init__(self, _device, _sigma = 25, _weights_dir = 'denoising/models/', _in_ch = 3):
        # Noise level is kept on the [0, 1] scale the network consumes
        self.sigma = _sigma / 255
        self.weights_dir = _weights_dir
        self.channels = _in_ch
        self.device = _device

        self.model = FFDNet(num_input_channels = _in_ch)
        self.load_weights()
        self.model.eval()


    def load_weights(self):
        """Load the pretrained state dict matching the channel count and device."""
        weights_name = 'net_rgb.pth' if self.channels == 3 else 'net_gray.pth'
        weights_path = os.path.join(self.weights_dir, weights_name)
        if self.device == 'cuda':
            state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
            # NOTE(review): the checkpoint is presumably saved from a
            # DataParallel model (keys prefixed 'module.'), so the model is
            # wrapped before loading; the CPU branch instead strips the prefix.
            device_ids = [0]
            self.model = nn.DataParallel(self.model, device_ids=device_ids).cuda()
        else:
            state_dict = torch.load(weights_path, map_location='cpu')
            # CPU mode: remove the DataParallel wrapper
            state_dict = remove_dataparallel_wrapper(state_dict)
        self.model.load_state_dict(state_dict)

    def get_denoised_image(self, imorig, sigma = None):
        """Denoise one image and return it as an OpenCV-style uint8 image.

        Args:
            imorig: HxWxC (or HxW) numpy image; values may be on the [0, 1]
                or [0, 255] scale (detected heuristically below). Assumes
                channel order is RGB — TODO confirm with callers, since the
                output is converted RGB->BGR by variable_to_cv2_image.
            sigma: optional noise level on the 0-255 scale, overriding the
                value given at construction time.

        Returns:
            The denoised image as returned by variable_to_cv2_image (uint8).
        """

        if sigma is not None:
            cur_sigma = sigma / 255
        else:
            cur_sigma = self.sigma

        # Promote grayscale (HxW or HxWx1) input to three channels
        if len(imorig.shape) < 3 or imorig.shape[2] == 1:
            imorig = np.repeat(np.expand_dims(imorig, 2), 3, 2)

        # Drop any alpha channel
        imorig = imorig[..., :3]

        # Cap the longest side at 1200 px to bound memory and compute
        if (max(imorig.shape[0], imorig.shape[1]) > 1200):
            ratio = max(imorig.shape[0], imorig.shape[1]) / 1200
            imorig = cv2.resize(imorig, (int(imorig.shape[1] / ratio), int(imorig.shape[0] / ratio)), interpolation = cv2.INTER_AREA)

        # HWC -> CHW
        imorig = imorig.transpose(2, 0, 1)

        # Heuristic: a max above 1.2 means the data is on the 0-255 scale
        if (imorig.max() > 1.2):
            imorig = normalize(imorig)
        imorig = np.expand_dims(imorig, 0)

        # Handle odd sizes: FFDNet downsamples by 2, so pad H and W to even
        # by replicating the last row/column (cropped again after inference)
        expanded_h = False
        expanded_w = False
        sh_im = imorig.shape
        if sh_im[2]%2 == 1:
            expanded_h = True
            imorig = np.concatenate((imorig, imorig[:, :, -1, :][:, :, np.newaxis, :]), axis=2)

        if sh_im[3]%2 == 1:
            expanded_w = True
            imorig = np.concatenate((imorig, imorig[:, :, :, -1][:, :, :, np.newaxis]), axis=3)


        imorig = torch.Tensor(imorig)


        # Sets data type according to CPU or GPU modes
        if self.device == 'cuda':
            dtype = torch.cuda.FloatTensor
        else:
            dtype = torch.FloatTensor

        imnoisy = imorig.clone()


        with torch.no_grad():
            imorig, imnoisy = imorig.type(dtype), imnoisy.type(dtype)
            nsigma = torch.FloatTensor([cur_sigma]).type(dtype)


            # Estimate noise and subtract it to the input image
            im_noise_estim = self.model(imnoisy, nsigma)
            outim = torch.clamp(imnoisy-im_noise_estim, 0., 1.)

        # Undo the even-size padding applied above
        if expanded_h:
            imorig = imorig[:, :, :-1, :]
            outim = outim[:, :, :-1, :]
            imnoisy = imnoisy[:, :, :-1, :]

        if expanded_w:
            imorig = imorig[:, :, :, :-1]
            outim = outim[:, :, :, :-1]
            imnoisy = imnoisy[:, :, :, :-1]

        return variable_to_cv2_image(outim)
denoising/functions.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Functions implementing custom NN layers
3
+
4
+ Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
5
+
6
+ This program is free software: you can use, modify and/or
7
+ redistribute it under the terms of the GNU General Public
8
+ License as published by the Free Software Foundation, either
9
+ version 3 of the License, or (at your option) any later
10
+ version. You should have received a copy of this license along
11
+ this program. If not, see <http://www.gnu.org/licenses/>.
12
+ """
13
+ import torch
14
+ from torch.autograd import Function, Variable
15
+
16
def concatenate_input_noise_map(input, noise_sigma):
    r"""Implements the first layer of FFDNet. This function returns a
    tensor composed of the concatenation of the downsampled input image
    and the noise map. Each image of the batch of size CxHxW gets
    converted to an array of size 4*CxH/2xW/2. Each of the pixels of the
    non-overlapped 2x2 patches of the input image are placed in the new array
    along the first dimension.

    Args:
        input: batch containing CxHxW images
        noise_sigma: the value of the pixels of the CxH/2xW/2 noise map
            (one value per batch element)

    Returns:
        A tensor of shape Nx(C + 4*C)xH/2xW/2: the noise map channels
        followed by the de-interleaved 2x2 mosaic of the input.
    """
    N, C, H, W = input.size()
    sca = 2
    sca2 = sca * sca
    Cout = sca2 * C
    Hout = H // sca
    Wout = W // sca
    # Offsets of the four pixels of each non-overlapping 2x2 patch
    idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]

    # Allocate on the same device/dtype as the input. This replaces the
    # deprecated torch.cuda.FloatTensor / torch.FloatTensor string-based
    # branching and additionally supports non-float32 inputs (e.g. half).
    downsampledfeatures = torch.zeros(N, Cout, Hout, Wout,
                                      dtype=input.dtype, device=input.device)

    # Build the CxH/2xW/2 noise map (constant per batch element)
    noise_map = noise_sigma.view(N, 1, 1, 1).repeat(1, C, Hout, Wout)

    # De-interleave the 2x2 patches into the channel dimension
    for idx in range(sca2):
        downsampledfeatures[:, idx:Cout:sca2, :, :] = \
            input[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca]

    # Concatenate de-interleaved mosaic with noise map
    return torch.cat((noise_map, downsampledfeatures), 1)
54
+
55
class UpSampleFeaturesFunction(Function):
    r"""Autograd function implementing the last layer of FFDNet.

    Performs the inverse of concatenate_input_noise_map(): each image of a
    batch of size CxH/2xW/2 is re-interleaved into an image of size C/4xHxW,
    with consecutive channel groups filling the four offsets of every 2x2
    output patch.
    """

    @staticmethod
    def forward(ctx, input):
        N, Cin, Hin, Win = input.size()
        dtype = input.type()
        sca = 2
        sca2 = sca * sca
        # Row/column offsets of the four interleaved sub-grids
        idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]

        assert (Cin % sca2 == 0), 'Invalid input dimensions: number of channels should be divisible by 4'

        result = torch.zeros((N, Cin // sca2, Hin * sca, Win * sca)).type(dtype)
        for sub, (row, col) in enumerate(idxL):
            result[:, :, row::sca, col::sca] = input[:, sub:Cin:sca2, :, :]

        return result

    @staticmethod
    def backward(ctx, grad_output):
        N, Cg_out, Hg_out, Wg_out = grad_output.size()
        dtype = grad_output.data.type()
        sca = 2
        sca2 = sca * sca
        Cg_in = sca2 * Cg_out
        idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]

        # Gradient of the interleave: gather each sub-grid back into its
        # channel group
        grad_input = torch.zeros((N, Cg_in, Hg_out // sca, Wg_out // sca)).type(dtype)
        for sub, (row, col) in enumerate(idxL):
            grad_input[:, sub:Cg_in:sca2, :, :] = grad_output.data[:, :, row::sca, col::sca]

        return Variable(grad_input)

# Alias functions
upsamplefeatures = UpSampleFeaturesFunction.apply
denoising/models.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Definition of the FFDNet model and its custom layers
3
+
4
+ Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
5
+
6
+ This program is free software: you can use, modify and/or
7
+ redistribute it under the terms of the GNU General Public
8
+ License as published by the Free Software Foundation, either
9
+ version 3 of the License, or (at your option) any later
10
+ version. You should have received a copy of this license along
11
+ this program. If not, see <http://www.gnu.org/licenses/>.
12
+ """
13
+ import torch.nn as nn
14
+ from torch.autograd import Variable
15
+ import denoising.functions as functions
16
+
17
class UpSampleFeatures(nn.Module):
    r"""Final FFDNet layer as an nn.Module: re-interleaves the downsampled
    feature maps back to full resolution (inverse of the input layer's 2x2
    downsampling).
    """

    def __init__(self):
        super(UpSampleFeatures, self).__init__()

    def forward(self, x):
        # Delegate to the autograd Function alias defined in functions.py
        return functions.upsamplefeatures(x)
24
+
25
class IntermediateDnCNN(nn.Module):
    r"""The DnCNN-style trunk of the FFDNet architecture: an initial
    Conv+ReLU, a stack of Conv+BatchNorm+ReLU blocks, and a final Conv
    producing the noise-estimate channels.
    """

    def __init__(self, input_features, middle_features, num_conv_layers):
        super(IntermediateDnCNN, self).__init__()
        self.kernel_size = 3
        self.padding = 1
        self.input_features = input_features
        self.num_conv_layers = num_conv_layers
        self.middle_features = middle_features
        if self.input_features == 5:
            self.output_features = 4  # Grayscale image
        elif self.input_features == 15:
            self.output_features = 12  # RGB image
        else:
            raise Exception('Invalid number of input features')

        conv_kwargs = dict(kernel_size=self.kernel_size,
                           padding=self.padding,
                           bias=False)
        layers = [
            nn.Conv2d(self.input_features, self.middle_features, **conv_kwargs),
            nn.ReLU(inplace=True),
        ]
        for _ in range(self.num_conv_layers - 2):
            layers += [
                nn.Conv2d(self.middle_features, self.middle_features, **conv_kwargs),
                nn.BatchNorm2d(self.middle_features),
                nn.ReLU(inplace=True),
            ]
        layers.append(nn.Conv2d(self.middle_features, self.output_features, **conv_kwargs))
        # NOTE: 'itermediate' (sic) must stay spelled this way — pretrained
        # checkpoints key their weights on this attribute name.
        self.itermediate_dncnn = nn.Sequential(*layers)

    def forward(self, x):
        return self.itermediate_dncnn(x)
67
+
68
class FFDNet(nn.Module):
    r"""Full FFDNet model: a downsample+noise-map input layer, a DnCNN
    trunk, and an upsampling output layer predicting the noise.
    """

    # num_input_channels -> (feature maps, conv layers,
    #                        downsampled channels, output features)
    _CONFIGS = {
        1: (64, 15, 5, 4),    # grayscale
        3: (96, 12, 15, 12),  # RGB
    }

    def __init__(self, num_input_channels):
        super(FFDNet, self).__init__()
        self.num_input_channels = num_input_channels
        if num_input_channels not in self._CONFIGS:
            raise Exception('Invalid number of input features')
        (self.num_feature_maps,
         self.num_conv_layers,
         self.downsampled_channels,
         self.output_features) = self._CONFIGS[num_input_channels]

        self.intermediate_dncnn = IntermediateDnCNN(
            input_features=self.downsampled_channels,
            middle_features=self.num_feature_maps,
            num_conv_layers=self.num_conv_layers)
        self.upsamplefeatures = UpSampleFeatures()

    def forward(self, x, noise_sigma):
        # Build the downsampled input + noise-map tensor, run the DnCNN
        # trunk, then re-interleave back to the input resolution.
        concat_noise_x = functions.concatenate_input_noise_map(x.data, noise_sigma.data)
        concat_noise_x = Variable(concat_noise_x)
        h_dncnn = self.intermediate_dncnn(concat_noise_x)
        pred_noise = self.upsamplefeatures(h_dncnn)
        return pred_noise
denoising/utils.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Different utilities such as orthogonalization of weights, initialization of
3
+ loggers, etc
4
+
5
+ Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
6
+
7
+ This program is free software: you can use, modify and/or
8
+ redistribute it under the terms of the GNU General Public
9
+ License as published by the Free Software Foundation, either
10
+ version 3 of the License, or (at your option) any later
11
+ version. You should have received a copy of this license along
12
+ this program. If not, see <http://www.gnu.org/licenses/>.
13
+ """
14
+ import numpy as np
15
+ import cv2
16
+
17
+
18
def variable_to_cv2_image(varim):
    r"""Convert a (1, C, H, W) torch Variable/tensor into an OpenCV uint8 image.

    Args:
        varim: a torch.autograd.Variable (or tensor) holding one image whose
            values are on the [0, 1] scale
    """
    nchannels = varim.size()[1]
    if nchannels == 1:
        # Drop the batch and channel dims, rescale to [0, 255]
        img = varim.data.cpu().numpy()[0, 0, :] * 255.
        return img.clip(0, 255).astype(np.uint8)
    if nchannels == 3:
        # HWC layout, RGB -> BGR for OpenCV, then rescale to [0, 255]
        img = varim.data.cpu().numpy()[0].transpose(1, 2, 0)
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        return (img * 255.).clip(0, 255).astype(np.uint8)
    raise Exception('Number of color channels not supported')
34
+
35
+
36
def normalize(data):
    r"""Rescale image data from the [0, 255] range to [0, 1] as float32."""
    scaled = data / 255.
    return np.float32(scaled)
38
+
39
def remove_dataparallel_wrapper(state_dict):
    r"""Converts a DataParallel model to a normal one by removing the "module."
    wrapper in the module dictionary

    Args:
        state_dict: a torch.nn.DataParallel state dictionary

    Returns:
        An OrderedDict with the same values whose keys have any leading
        'module.' prefix removed; unprefixed keys are preserved unchanged.
    """
    from collections import OrderedDict

    prefix = 'module.'
    new_state_dict = OrderedDict()
    for k, vl in state_dict.items():
        # Fix: the original unconditionally dropped the first 7 characters,
        # which mangled keys that were not actually wrapped by DataParallel.
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = vl

    return new_state_dict
54
+
55
def is_rgb(im_path):
    r"""Return True if the image stored at im_path has genuinely distinct
    colour channels (a grayscale image saved as three identical channels
    does not count as RGB).
    """
    from skimage.io import imread

    im = imread(im_path)
    if len(im.shape) == 3:
        channels_match = np.allclose(im[..., 0], im[..., 1]) and np.allclose(im[..., 2], im[..., 1])
        rgb = not channels_match
    else:
        rgb = False
    print("rgb: {}".format(rgb))
    print("im shape: {}".format(im.shape))
    return rgb