| | """ |
| | Definition of the FFDNet model and its custom layers |
| | |
| | Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr> |
| | |
| | This program is free software: you can use, modify and/or |
| | redistribute it under the terms of the GNU General Public |
| | License as published by the Free Software Foundation, either |
| | version 3 of the License, or (at your option) any later |
| | version. You should have received a copy of this license along |
| | this program. If not, see <http://www.gnu.org/licenses/>. |
| | """ |
| | import torch.nn as nn |
| | from torch.autograd import Variable |
| | import denoising.functions as functions |
| | |
class UpSampleFeatures(nn.Module):
    r"""Final stage of FFDNet.

    This module owns no learnable parameters and no sub-modules; its forward
    pass simply hands the feature maps to ``functions.upsamplefeatures``
    (from ``denoising.functions``) and returns whatever that produces.
    """

    def __init__(self):
        super(UpSampleFeatures, self).__init__()

    def forward(self, x):
        """Return ``functions.upsamplefeatures(x)``."""
        upsampled = functions.upsamplefeatures(x)
        return upsampled
| |
|
class IntermediateDnCNN(nn.Module):
    r"""Implements the middle part of the FFDNet architecture, which is
    basically a DnCNN network.

    Layer stack (held in ``self.itermediate_dncnn``, an ``nn.Sequential``):

    * Conv(input_features -> middle_features) + ReLU
    * ``num_conv_layers - 2`` times:
      Conv(middle_features -> middle_features) + BatchNorm2d + ReLU
    * Conv(middle_features -> output_features)

    Every convolution is 3x3 with padding 1 (default stride 1) and no bias,
    so the spatial size of the input is preserved.

    Args:
        input_features: number of input channels; only 5 and 15 are
            supported, mapping to 4 and 12 output channels respectively.
        middle_features: number of channels of the hidden layers.
        num_conv_layers: total number of convolutional layers (>= 2).

    Raises:
        ValueError: if ``input_features`` is not 5 or 15, or if
            ``num_conv_layers`` is smaller than 2.
    """

    def __init__(self, input_features, middle_features, num_conv_layers):
        super(IntermediateDnCNN, self).__init__()
        self.kernel_size = 3
        self.padding = 1
        self.input_features = input_features
        self.num_conv_layers = num_conv_layers
        self.middle_features = middle_features
        if self.input_features == 5:
            self.output_features = 4
        elif self.input_features == 15:
            self.output_features = 12
        else:
            # ValueError subclasses Exception, so existing handlers still work.
            raise ValueError('Invalid number of input features')
        # The stack always contains a first and a last convolution; fewer
        # than 2 layers was previously ignored silently (range() of a
        # negative number is empty), which hid a misconfiguration.
        if self.num_conv_layers < 2:
            raise ValueError('num_conv_layers must be at least 2')

        layers = [self._conv(self.input_features, self.middle_features),
                  nn.ReLU(inplace=True)]
        for _ in range(self.num_conv_layers - 2):
            layers.append(self._conv(self.middle_features, self.middle_features))
            layers.append(nn.BatchNorm2d(self.middle_features))
            layers.append(nn.ReLU(inplace=True))
        layers.append(self._conv(self.middle_features, self.output_features))
        # NOTE: the attribute name (including its historical typo) and the
        # layer order define the state_dict keys; keep them unchanged so
        # previously saved weights still load.
        self.itermediate_dncnn = nn.Sequential(*layers)

    def _conv(self, in_channels, out_channels):
        """Build one bias-free conv layer with this network's kernel/padding."""
        return nn.Conv2d(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=self.kernel_size,
                         padding=self.padding,
                         bias=False)

    def forward(self, x):
        """Run ``x`` through the convolutional stack and return the result."""
        out = self.itermediate_dncnn(x)
        return out
| |
|
class FFDNet(nn.Module):
    r"""Implements the FFDNet architecture.

    The configuration is fixed by the number of image channels:

    * 1 channel:  64 feature maps, 15 conv layers, 5 -> 4 channels in the
      DnCNN body.
    * 3 channels: 96 feature maps, 12 conv layers, 15 -> 12 channels in the
      DnCNN body.

    Args:
        num_input_channels: number of channels of the input image (1 or 3).

    Raises:
        ValueError: if ``num_input_channels`` is neither 1 nor 3.
    """

    def __init__(self, num_input_channels):
        super(FFDNet, self).__init__()
        self.num_input_channels = num_input_channels
        if self.num_input_channels == 1:
            # Single-channel configuration
            self.num_feature_maps = 64
            self.num_conv_layers = 15
            self.downsampled_channels = 5
            self.output_features = 4
        elif self.num_input_channels == 3:
            # Three-channel configuration
            self.num_feature_maps = 96
            self.num_conv_layers = 12
            self.downsampled_channels = 15
            self.output_features = 12
        else:
            # ValueError subclasses Exception, so existing handlers still work;
            # the check is on image channels, so say so in the message.
            raise ValueError('Invalid number of input channels: {}'.format(
                num_input_channels))

        self.intermediate_dncnn = IntermediateDnCNN(
            input_features=self.downsampled_channels,
            middle_features=self.num_feature_maps,
            num_conv_layers=self.num_conv_layers)
        self.upsamplefeatures = UpSampleFeatures()

    def forward(self, x, noise_sigma):
        """Run the full FFDNet pipeline.

        Args:
            x: input image batch (assumed (N, C, H, W) with
                C == num_input_channels -- TODO confirm against callers).
            noise_sigma: noise level(s) used to build the noise map.

        Returns:
            The upsampled output of the DnCNN body (named ``pred_noise``;
            the name suggests a noise estimate -- confirm with the caller).
        """
        # NOTE(review): ``.data`` yields tensors without autograd history, so
        # no gradient flows back into ``x`` or ``noise_sigma`` through this
        # step. ``Variable`` is a legacy wrapper kept for compatibility with
        # the helper in denoising.functions, which operates on raw tensors.
        concat_noise_x = functions.concatenate_input_noise_map(
            x.data, noise_sigma.data)
        concat_noise_x = Variable(concat_noise_x)
        h_dncnn = self.intermediate_dncnn(concat_noise_x)
        pred_noise = self.upsamplefeatures(h_dncnn)
        return pred_noise
| |
|