Spaces:
Sleeping
Sleeping
File size: 5,442 Bytes
65d7391 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Copyright (c) 2023 Image Processing Research Group of University Federico II of Naples ('GRIP-UNINA').
#
# All rights reserved.
# This work should only be used for nonprofit purposes.
#
# By downloading and/or using any of these files, you implicitly agree to all the
# terms of the license, as specified in the document LICENSE.txt
# (included in this package) and online at
# http://www.grip.unina.it/download/LICENSE_OPEN.txt
"""
Created in September 2020
@author: davide.cozzolino
"""
import math
import torch.nn as nn
def conv_with_padding(in_planes, out_planes, kernelsize, stride=1, dilation=1, bias=False, padding = None):
    r"""
    Build an ``nn.Conv2d`` whose padding defaults to "same" padding.

    :param in_planes: number of input channels
    :param out_planes: number of output channels
    :param kernelsize: (int) square kernel size
    :param stride: convolution stride
    :param dilation: dilation factor
    :param bias: whether the convolution has a learnable bias
    :param padding: explicit padding; None selects ``dilation * (kernelsize // 2)``,
        which preserves the spatial size (at stride 1) for odd kernel sizes
    :return: the configured ``nn.Conv2d``
    """
    if padding is None:
        # A dilated kernel spans dilation * (kernelsize - 1) + 1 pixels, so it needs
        # `dilation` times the undilated padding. For dilation == 1 this is exactly
        # the previous default of kernelsize // 2.
        padding = dilation * (kernelsize // 2)
    return nn.Conv2d(in_planes, out_planes, kernel_size=kernelsize, stride=stride, dilation=dilation, padding=padding, bias=bias)
def conv_init(conv, act='linear'):
    r"""
    Reproduces conv initialization from DnCNN.

    Draws the conv weights in place from N(0, sqrt(2 / fan)), where fan is
    kernel_height * kernel_width * out_channels. The bias is left untouched.

    :param conv: convolution module whose ``weight`` is re-initialized
    :param act: accepted for interface compatibility; not used
    """
    kernel_h, kernel_w = conv.kernel_size[0], conv.kernel_size[1]
    fan = kernel_h * kernel_w * conv.out_channels
    std = math.sqrt(2. / fan)
    conv.weight.data.normal_(0, std)
def batchnorm_init(m, kernelsize=3):
    r"""
    Reproduces batchnorm initialization from DnCNN.

    The scale (``weight``) is drawn in place from
    N(0, sqrt(2 / (kernelsize**2 * num_features))) and the shift (``bias``) is zeroed.

    :param m: batchnorm module to initialize
    :param kernelsize: kernel size of the preceding convolution, used in the fan term
    """
    fan = m.num_features * kernelsize**2
    m.bias.data.zero_()
    m.weight.data.normal_(0, math.sqrt(2. / fan))
def make_activation(act):
    r"""
    Map an activation name to a fresh ``nn.Module`` (or None for no activation).

    :param act: one of None, 'linear', 'relu', 'tanh', 'leaky_relu', 'softmax'
    :return: the activation module, or None for None / 'linear'
    :raises ValueError: if ``act`` is not a recognised name
    """
    if act is None or act == 'linear':
        return None
    if act == 'relu':
        return nn.ReLU(inplace=True)
    if act == 'tanh':
        return nn.Tanh()
    if act == 'leaky_relu':
        return nn.LeakyReLU(inplace=True)
    if act == 'softmax':
        # Softmax over the channel axis of (N, C, H, W) feature maps. For 4-D input this
        # is the axis PyTorch's deprecated implicit-dim rule selects; stating it
        # explicitly avoids the deprecation warning.
        return nn.Softmax(dim=1)
    # A real exception instead of `assert(False)`: asserts are stripped under
    # `python -O`, which would make an unknown name silently act as 'linear'.
    raise ValueError("unknown activation: {!r}".format(act))
def make_net(nplanes_in, kernels, features, bns, acts, dilats, bn_momentum = 0.1, padding=None):
    r"""
    Build a plain convolutional stack as an ``nn.Sequential``.

    Layer ``i`` is a convolution (``features[i]`` output channels, ``kernels[i]`` kernel,
    ``dilats[i]`` dilation). It is optionally followed by batchnorm (``bns[i]``) and by
    an activation (``acts[i]``, see ``make_activation``). A convolution followed by
    batchnorm is created without bias.

    :param nplanes_in: number of input feature channels
    :param kernels: list of kernel sizes for convolution layers (one per layer)
    :param features: list of output feature channels of each layer; its length is the depth
    :param bns: list of whether to add batchnorm layers
    :param acts: list of activations
    :param dilats: list of dilation factors
    :param bn_momentum: momentum of batchnorm
    :param padding: integer for padding (None for same padding with odd kernel sizes)
    :return: ``nn.Sequential`` holding the layers in order
    :raises ValueError: if ``kernels`` does not have exactly one entry per layer, or if
        ``bns``, ``acts`` or ``dilats`` have fewer entries than layers
        (extra trailing entries are ignored)
    """
    depth = len(features)
    # Explicit checks rather than `assert`, so validation survives `python -O`.
    if len(kernels) != depth:
        raise ValueError("make_net: expected {} kernel sizes, got {}".format(depth, len(kernels)))
    for label, values in (("bns", bns), ("acts", acts), ("dilats", dilats)):
        if len(values) < depth:
            raise ValueError("make_net: expected at least {} entries in {}, got {}".format(depth, label, len(values)))

    layers = []
    in_feats = nplanes_in
    for out_feats, ksize, use_bn, act, dilation in zip(features, kernels, bns, acts, dilats):
        # No conv bias when batchnorm follows: the batchnorm shift takes its role.
        conv = conv_with_padding(in_feats, out_feats, kernelsize=ksize, dilation=dilation, padding=padding, bias=not use_bn)
        conv_init(conv, act=act)
        layers.append(conv)
        if use_bn:
            bnorm = nn.BatchNorm2d(out_feats, momentum=bn_momentum)
            batchnorm_init(bnorm, kernelsize=ksize)
            layers.append(bnorm)
        activation = make_activation(act)
        if activation is not None:
            layers.append(activation)
        in_feats = out_feats  # the next layer consumes this layer's output channels
    return nn.Sequential(*layers)
class DnCNN(nn.Module):
    r"""
    Implements a DnCNN network: a stack of ``depth`` convolution layers built by
    ``make_net``, with optional batchnorm on the hidden layers and an optional additive
    residual connection from the input to the output.
    """

    def __init__(self, nplanes_in, nplanes_out, features, kernel, depth, activation, residual, bn, lastact=None, bn_momentum = 0.10, padding=None):
        r"""
        :param nplanes_in: number of input feature channels
        :param nplanes_out: number of output feature channels
        :param features: number of hidden layer feature channels
        :param kernel: kernel size of convolution layers
        :param depth: number of convolution layers (minimum 2)
        :param activation: activation name applied after every layer except the last
            (see ``make_activation``)
        :param residual: whether to add a residual connection from input to output
        :param bn: whether to add batchnorm layers (never on the first or last layer)
        :param lastact: activation name applied after the last layer (None for linear)
        :param bn_momentum: momentum of batchnorm
        :param padding: integer for padding (None for same padding with odd kernel sizes)
        """
        super(DnCNN, self).__init__()
        self.residual = residual
        self.nplanes_out = nplanes_out
        self.nplanes_in = nplanes_in
        kernels = [kernel, ] * depth
        features = [features, ] * (depth - 1) + [nplanes_out, ]
        # batchnorm only on the hidden layers, never on the first or last convolution
        bns = [False, ] + [bn, ] * (depth - 2) + [False, ]
        dilats = [1, ] * depth
        acts = [activation, ] * (depth - 1) + [lastact, ]
        self.layers = make_net(nplanes_in, kernels, features, bns, acts, dilats=dilats, bn_momentum = bn_momentum, padding=padding)

    def forward(self, x):
        r"""
        Run the network. When ``residual`` is set, the first
        ``min(nplanes_in, nplanes_out)`` input channels are added to the matching
        output channels.
        """
        shortcut = x
        x = self.layers(x)
        if self.residual:
            nshortcut = min(self.nplanes_in, self.nplanes_out)
            if nshortcut == self.nplanes_out:
                # Every output channel receives the shortcut: plain out-of-place add.
                x = x + shortcut[:, :nshortcut, :, :]
            else:
                # Only the leading channels receive the shortcut. Modify a copy: the
                # network output may have been saved for backward by a final activation
                # (tanh, relu, softmax, ...), and editing it in place breaks autograd.
                x = x.clone()
                x[:, :nshortcut, :, :] = x[:, :nshortcut, :, :] + shortcut[:, :nshortcut, :, :]
        return x
def add_commandline_networkparams(parser, name, features, depth, kernel, activation, bn):
    r"""
    Register the hyper-parameters of network ``name`` on an argparse parser.

    Adds ``--<name>.features``, ``--<name>.depth`` and ``--<name>.kernel`` (int) and
    ``--<name>.activation`` (str) with the given defaults. It also adds the paired
    flags ``--<name>.bn`` / ``--<name>.no-bn``, which write to the ``<name>.bn``
    destination; that destination defaults to ``bn``.
    """
    valued_options = (
        ("features", int, features),
        ("depth", int, depth),
        ("kernel", int, kernel),
        ("activation", str, activation),
    )
    for field, field_type, field_default in valued_options:
        parser.add_argument("--{}.{}".format(name, field), type=field_type, default=field_default)

    bn_dest = "{}.{}".format(name, "bn")
    parser.add_argument("--" + bn_dest, action="store_true", dest=bn_dest)
    parser.add_argument("--{}.{}".format(name, "no-bn"), action="store_false", dest=bn_dest)
    parser.set_defaults(**{bn_dest: bn})
|