Spaces:
Running
Running
Upload 4 files
Browse files- denoising/denoiser.py +117 -0
- denoising/functions.py +101 -0
- denoising/models.py +100 -0
- denoising/utils.py +66 -0
denoising/denoiser.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Denoise an image with the FFDNet denoising method
|
| 3 |
+
|
| 4 |
+
Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
|
| 5 |
+
|
| 6 |
+
This program is free software: you can use, modify and/or
|
| 7 |
+
redistribute it under the terms of the GNU General Public
|
| 8 |
+
License as published by the Free Software Foundation, either
|
| 9 |
+
version 3 of the License, or (at your option) any later
|
| 10 |
+
version. You should have received a copy of this license along
|
| 11 |
+
this program. If not, see <http://www.gnu.org/licenses/>.
|
| 12 |
+
"""
|
| 13 |
+
import os
|
| 14 |
+
import argparse
|
| 15 |
+
import time
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import cv2
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
from torch.autograd import Variable
|
| 23 |
+
from .models import FFDNet
|
| 24 |
+
from .utils import normalize, variable_to_cv2_image, remove_dataparallel_wrapper, is_rgb
|
| 25 |
+
|
| 26 |
+
class FFDNetDenoiser:
    """Wraps a pretrained FFDNet model to denoise single images.

    Loads weights from `_weights_dir` at construction time and keeps the
    model in eval mode; `get_denoised_image` runs the full pre/post
    processing pipeline around a single forward pass.
    """

    def __init__(self, _device, _sigma = 25, _weights_dir = 'denoising/models/', _in_ch = 3):
        # Sigma is stored normalized to [0, 1] (weights were trained on
        # noise levels expressed over 8-bit intensity range).
        self.sigma = _sigma / 255
        self.weights_dir = _weights_dir
        self.channels = _in_ch
        self.device = _device

        self.model = FFDNet(num_input_channels = _in_ch)
        self.load_weights()
        # Inference only: disables batch-norm updates.
        self.model.eval()


    def load_weights(self):
        """Load the pretrained checkpoint matching the channel count.

        In CUDA mode the model is wrapped in nn.DataParallel, so the
        'module.'-prefixed checkpoint keys match directly; in CPU mode the
        prefix is stripped instead.
        """
        weights_name = 'net_rgb.pth' if self.channels == 3 else 'net_gray.pth'
        weights_path = os.path.join(self.weights_dir, weights_name)
        if self.device == 'cuda':
            # Load on CPU first, then move the wrapped model to GPU 0.
            state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
            device_ids = [0]
            self.model = nn.DataParallel(self.model, device_ids=device_ids).cuda()
        else:
            state_dict = torch.load(weights_path, map_location='cpu')
            # CPU mode: remove the DataParallel wrapper
            state_dict = remove_dataparallel_wrapper(state_dict)
        self.model.load_state_dict(state_dict)

    def get_denoised_image(self, imorig, sigma = None):
        """Denoise a single HxWxC (or HxW) image array and return it as an
        OpenCV-style uint8 BGR image.

        Args:
            imorig: numpy image, HxW, HxWx1, HxWx3 or HxWx4 (alpha dropped).
            sigma: optional noise level on the 0-255 scale; falls back to
                the value given at construction time.
        """

        if sigma is not None:
            cur_sigma = sigma / 255
        else:
            cur_sigma = self.sigma

        # Grayscale input: replicate to 3 channels to match the RGB model.
        if len(imorig.shape) < 3 or imorig.shape[2] == 1:
            imorig = np.repeat(np.expand_dims(imorig, 2), 3, 2)

        # Drop any alpha channel.
        imorig = imorig[..., :3]

        # Cap the longest side at 1200 px to bound memory/runtime.
        if (max(imorig.shape[0], imorig.shape[1]) > 1200):
            ratio = max(imorig.shape[0], imorig.shape[1]) / 1200
            imorig = cv2.resize(imorig, (int(imorig.shape[1] / ratio), int(imorig.shape[0] / ratio)), interpolation = cv2.INTER_AREA)

        # HWC -> CHW for torch.
        imorig = imorig.transpose(2, 0, 1)

        # Heuristic: values above 1.2 mean the image is 0-255 scaled, so
        # normalize to [0, 1]. NOTE(review): a float image with HDR-ish
        # values slightly above 1 would be rescaled too — confirm callers
        # only pass uint8-range or already-normalized data.
        if (imorig.max() > 1.2):
            imorig = normalize(imorig)
        imorig = np.expand_dims(imorig, 0)

        # Handle odd sizes
        # FFDNet downsamples by 2, so pad H/W to even by duplicating the
        # last row/column; the padding is cropped off after inference.
        expanded_h = False
        expanded_w = False
        sh_im = imorig.shape
        if sh_im[2]%2 == 1:
            expanded_h = True
            imorig = np.concatenate((imorig, imorig[:, :, -1, :][:, :, np.newaxis, :]), axis=2)

        if sh_im[3]%2 == 1:
            expanded_w = True
            imorig = np.concatenate((imorig, imorig[:, :, :, -1][:, :, :, np.newaxis]), axis=3)


        imorig = torch.Tensor(imorig)


        # Sets data type according to CPU or GPU modes
        if self.device == 'cuda':
            dtype = torch.cuda.FloatTensor
        else:
            dtype = torch.FloatTensor

        imnoisy = imorig.clone()


        with torch.no_grad():
            imorig, imnoisy = imorig.type(dtype), imnoisy.type(dtype)
            nsigma = torch.FloatTensor([cur_sigma]).type(dtype)


            # Estimate noise and subtract it to the input image
            im_noise_estim = self.model(imnoisy, nsigma)
            outim = torch.clamp(imnoisy-im_noise_estim, 0., 1.)

        # Crop the row/column padding added above, if any.
        if expanded_h:
            imorig = imorig[:, :, :-1, :]
            outim = outim[:, :, :-1, :]
            imnoisy = imnoisy[:, :, :-1, :]

        if expanded_w:
            imorig = imorig[:, :, :, :-1]
            outim = outim[:, :, :, :-1]
            imnoisy = imnoisy[:, :, :, :-1]

        return variable_to_cv2_image(outim)
|
denoising/functions.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Functions implementing custom NN layers
|
| 3 |
+
|
| 4 |
+
Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
|
| 5 |
+
|
| 6 |
+
This program is free software: you can use, modify and/or
|
| 7 |
+
redistribute it under the terms of the GNU General Public
|
| 8 |
+
License as published by the Free Software Foundation, either
|
| 9 |
+
version 3 of the License, or (at your option) any later
|
| 10 |
+
version. You should have received a copy of this license along
|
| 11 |
+
this program. If not, see <http://www.gnu.org/licenses/>.
|
| 12 |
+
"""
|
| 13 |
+
import torch
|
| 14 |
+
from torch.autograd import Function, Variable
|
| 15 |
+
|
| 16 |
+
def concatenate_input_noise_map(input, noise_sigma):
    r"""Implements the first layer of FFDNet.

    Each image of the batch of size CxHxW is de-interleaved into a
    4*CxH/2xW/2 mosaic (the four pixels of every non-overlapping 2x2 patch
    become separate channels), and a constant noise map of size CxH/2xW/2
    is prepended to it along the channel dimension.

    Args:
        input: batch containing CxHxW images
        noise_sigma: the value of the pixels of the CxH/2xW/2 noise map
    """
    batch, channels, height, width = input.size()
    dtype = input.type()
    scale = 2
    scale_sq = scale * scale
    out_channels = scale_sq * channels
    out_h = height // scale
    out_w = width // scale
    # Sub-pixel offsets of the four positions in each 2x2 patch.
    offsets = ((0, 0), (0, 1), (1, 0), (1, 1))

    # Allocate the de-interleaved tensor as float32, on GPU when the
    # input lives there.
    if 'cuda' in dtype:
        mosaic = torch.cuda.FloatTensor(batch, out_channels, out_h, out_w).fill_(0)
    else:
        mosaic = torch.FloatTensor(batch, out_channels, out_h, out_w).fill_(0)

    # One constant noise plane per input channel.
    noise_map = noise_sigma.view(batch, 1, 1, 1).repeat(1, channels, out_h, out_w)

    # Scatter each 2x2 sub-grid of the input into its own channel slot.
    for k, (row, col) in enumerate(offsets):
        mosaic[:, k:out_channels:scale_sq, :, :] = input[:, :, row::scale, col::scale]

    # Noise map first, then the de-interleaved image channels.
    return torch.cat((noise_map, mosaic), 1)
|
| 54 |
+
|
| 55 |
+
class UpSampleFeaturesFunction(Function):
    r"""Custom torch.autograd.Function implementing the last FFDNet layer.

    Performs the inverse of concatenate_input_noise_map(): every batch
    image of size CxH/2xW/2 is re-interleaved into a C/4xHxW image, with
    each group of 4 channels filling the four positions of the 2x2
    output patches.
    """
    @staticmethod
    def forward(ctx, input):
        batch, c_in, h_in, w_in = input.size()
        scale = 2
        scale_sq = scale * scale
        offsets = ((0, 0), (0, 1), (1, 0), (1, 1))

        assert (c_in % scale_sq == 0), 'Invalid input dimensions: number of channels should be divisible by 4'

        out = torch.zeros((batch, c_in // scale_sq, h_in * scale, w_in * scale)).type(input.type())
        # Each channel group k lands on sub-pixel position offsets[k].
        for k, (row, col) in enumerate(offsets):
            out[:, :, row::scale, col::scale] = input[:, k:c_in:scale_sq, :, :]

        return out

    @staticmethod
    def backward(ctx, grad_output):
        batch, cg_out, hg_out, wg_out = grad_output.size()
        scale = 2
        scale_sq = scale * scale
        cg_in = scale_sq * cg_out
        offsets = ((0, 0), (0, 1), (1, 0), (1, 1))

        # Gradient of the re-interleave is the de-interleave: gather each
        # 2x2 sub-grid back into its own channel slot.
        grad_input = torch.zeros((batch, cg_in, hg_out // scale, wg_out // scale)).type(grad_output.data.type())
        for k, (row, col) in enumerate(offsets):
            grad_input[:, k:cg_in:scale_sq, :, :] = grad_output.data[:, :, row::scale, col::scale]

        # Legacy (pre-0.4) autograd API: wrap the gradient in a Variable.
        return Variable(grad_input)
|
| 99 |
+
|
| 100 |
+
# Alias functions
# Expose the custom autograd op as a plain callable, mirroring the
# torch.nn.functional naming style used by the models module.
upsamplefeatures = UpSampleFeaturesFunction.apply
|
denoising/models.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Definition of the FFDNet model and its custom layers
|
| 3 |
+
|
| 4 |
+
Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
|
| 5 |
+
|
| 6 |
+
This program is free software: you can use, modify and/or
|
| 7 |
+
redistribute it under the terms of the GNU General Public
|
| 8 |
+
License as published by the Free Software Foundation, either
|
| 9 |
+
version 3 of the License, or (at your option) any later
|
| 10 |
+
version. You should have received a copy of this license along
|
| 11 |
+
this program. If not, see <http://www.gnu.org/licenses/>.
|
| 12 |
+
"""
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from torch.autograd import Variable
|
| 15 |
+
import denoising.functions as functions
|
| 16 |
+
|
| 17 |
+
class UpSampleFeatures(nn.Module):
    r"""Implements the last layer of FFDNet as an nn.Module wrapper around
    the custom autograd op that re-interleaves downsampled feature
    channels back into a full-resolution tensor.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Delegate to the autograd function that inverts the input
        # mosaicking performed by the first layer.
        return functions.upsamplefeatures(x)
|
| 24 |
+
|
| 25 |
+
class IntermediateDnCNN(nn.Module):
    r"""Middle part of the FFDNet architecture — essentially a DnCNN:
    Conv+ReLU, then (num_conv_layers - 2) Conv+BN+ReLU blocks, then a
    final Conv producing the downsampled output features.
    """
    def __init__(self, input_features, middle_features, num_conv_layers):
        super(IntermediateDnCNN, self).__init__()
        self.kernel_size = 3
        self.padding = 1
        self.input_features = input_features
        self.num_conv_layers = num_conv_layers
        self.middle_features = middle_features
        if self.input_features == 5:
            self.output_features = 4  # grayscale: 4 downsampled sub-images
        elif self.input_features == 15:
            self.output_features = 12  # RGB: 3 channels x 4 sub-images
        else:
            raise Exception('Invalid number of input features')

        # Shared conv settings: 3x3 kernels, "same" padding, no bias
        # (bias is redundant before BatchNorm).
        conv_opts = dict(kernel_size=self.kernel_size, padding=self.padding, bias=False)

        layers = [
            nn.Conv2d(self.input_features, self.middle_features, **conv_opts),
            nn.ReLU(inplace=True),
        ]
        for _ in range(self.num_conv_layers - 2):
            layers += [
                nn.Conv2d(self.middle_features, self.middle_features, **conv_opts),
                nn.BatchNorm2d(self.middle_features),
                nn.ReLU(inplace=True),
            ]
        layers.append(nn.Conv2d(self.middle_features, self.output_features, **conv_opts))

        # NOTE: attribute keeps the historical misspelling ("itermediate")
        # because pretrained checkpoints key their weights on this name.
        self.itermediate_dncnn = nn.Sequential(*layers)

    def forward(self, x):
        return self.itermediate_dncnn(x)
|
| 67 |
+
|
| 68 |
+
class FFDNet(nn.Module):
    r"""Implements the FFDNet architecture: a noise-map-conditioned DnCNN
    operating on 2x-downsampled mosaics, followed by re-interleaving back
    to full resolution.
    """
    def __init__(self, num_input_channels):
        super(FFDNet, self).__init__()
        self.num_input_channels = num_input_channels

        # Per-mode hyper-parameters:
        # channels -> (feature maps, conv layers, downsampled channels, output features)
        presets = {
            1: (64, 15, 5, 4),    # grayscale
            3: (96, 12, 15, 12),  # RGB
        }
        if self.num_input_channels not in presets:
            raise Exception('Invalid number of input features')
        (self.num_feature_maps, self.num_conv_layers,
         self.downsampled_channels, self.output_features) = presets[self.num_input_channels]

        self.intermediate_dncnn = IntermediateDnCNN(
            input_features=self.downsampled_channels,
            middle_features=self.num_feature_maps,
            num_conv_layers=self.num_conv_layers)
        self.upsamplefeatures = UpSampleFeatures()

    def forward(self, x, noise_sigma):
        # Build the (noise map ++ downsampled mosaic) tensor the CNN expects.
        # The .data / Variable round-trip reflects the legacy (pre-0.4)
        # autograd API this code targets.
        concat_noise_x = Variable(functions.concatenate_input_noise_map(x.data, noise_sigma.data))
        dncnn_out = self.intermediate_dncnn(concat_noise_x)
        return self.upsamplefeatures(dncnn_out)
|
denoising/utils.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Different utilities such as orthogonalization of weights, initialization of
|
| 3 |
+
loggers, etc
|
| 4 |
+
|
| 5 |
+
Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
|
| 6 |
+
|
| 7 |
+
This program is free software: you can use, modify and/or
|
| 8 |
+
redistribute it under the terms of the GNU General Public
|
| 9 |
+
License as published by the Free Software Foundation, either
|
| 10 |
+
version 3 of the License, or (at your option) any later
|
| 11 |
+
version. You should have received a copy of this license along
|
| 12 |
+
this program. If not, see <http://www.gnu.org/licenses/>.
|
| 13 |
+
"""
|
| 14 |
+
import numpy as np
|
| 15 |
+
import cv2
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def variable_to_cv2_image(varim):
    r"""Converts a torch.autograd.Variable to an OpenCV image

    Args:
        varim: a torch.autograd.Variable holding a 1xCxHxW image with
            values in [0, 1]
    """
    nchannels = varim.size()[1]
    if nchannels == 1:
        # Single channel: drop batch/channel dims, rescale to 8-bit.
        return (varim.data.cpu().numpy()[0, 0, :]*255.).clip(0, 255).astype(np.uint8)
    if nchannels == 3:
        # Tensor is assumed RGB; OpenCV expects BGR channel order.
        chw = varim.data.cpu().numpy()[0]
        bgr = cv2.cvtColor(chw.transpose(1, 2, 0), cv2.COLOR_RGB2BGR)
        return (bgr*255.).clip(0, 255).astype(np.uint8)
    raise Exception('Number of color channels not supported')
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def normalize(data):
    r"""Rescale 8-bit pixel data from [0, 255] to float32 values in [0, 1]."""
    scaled = data / 255.
    return np.float32(scaled)
|
| 38 |
+
|
| 39 |
+
def remove_dataparallel_wrapper(state_dict):
    r"""Converts a DataParallel model state dict to a normal one by removing
    the "module." prefix that nn.DataParallel adds to every parameter key.

    Args:
        state_dict: a torch.nn.DataParallel state dictionary

    Returns:
        An OrderedDict with the same values whose keys have any leading
        "module." stripped. Keys without the prefix are kept unchanged —
        the previous implementation unconditionally chopped the first 7
        characters, silently corrupting already-unwrapped dictionaries.
    """
    from collections import OrderedDict

    prefix = 'module.'
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        # Strip the DataParallel prefix only when it is actually present.
        name = key[len(prefix):] if key.startswith(prefix) else key
        new_state_dict[name] = value

    return new_state_dict
|
| 54 |
+
|
| 55 |
+
def is_rgb(im_path):
    r""" Returns True if the image in im_path is an RGB image
    """
    from skimage.io import imread

    im = imread(im_path)
    rgb = False
    # A 2-D array is grayscale by construction; a 3-D array might still be
    # a grayscale image saved with three identical channels, so compare
    # the channels before declaring it RGB.
    if (len(im.shape) == 3):
        channels_differ = not(np.allclose(im[...,0], im[...,1]) and np.allclose(im[...,2], im[...,1]))
        if channels_differ:
            rgb = True
    print("rgb: {}".format(rgb))
    print("im shape: {}".format(im.shape))
    return rgb
|