| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | from copy import deepcopy |
| | from nnunet.network_architecture.custom_modules.helperModules import Identity |
| | from torch import nn |
| |
|
| |
|
class ConvDropoutNormReLU(nn.Module):
    """One convolution followed by optional dropout, optional normalization and a nonlinearity."""

    def __init__(self, input_channels, output_channels, kernel_size, network_props):
        """
        Assemble conv -> dropout -> norm -> nonlin from the entries of network_props.

        if network_props['dropout_op'] is None then no dropout
        if network_props['norm_op'] is None then no norm
        :param input_channels: channel count handed to the convolution as its input size
        :param output_channels: channel count produced by the convolution (also passed to the norm op)
        :param kernel_size: iterable with one kernel extent per spatial axis; each axis is padded by (k - 1) // 2
        :param network_props: dict with the keys 'conv_op', 'conv_op_kwargs', 'dropout_op', 'dropout_op_kwargs',
            'norm_op', 'norm_op_kwargs', 'nonlin' and 'nonlin_kwargs'
        """
        super().__init__()

        # private copy: everything below reads from cfg, never from the caller's dict
        cfg = deepcopy(network_props)

        per_axis_padding = [(k - 1) // 2 for k in kernel_size]
        self.conv = cfg['conv_op'](input_channels, output_channels, kernel_size,
                                   padding=per_axis_padding,
                                   **cfg['conv_op_kwargs'])

        # a missing op is replaced by a pass-through module so the chain always has four stages
        dropout_cls = cfg['dropout_op']
        self.do = Identity() if dropout_cls is None else dropout_cls(**cfg['dropout_op_kwargs'])

        norm_cls = cfg['norm_op']
        self.norm = Identity() if norm_cls is None else norm_cls(output_channels, **cfg['norm_op_kwargs'])

        self.nonlin = cfg['nonlin'](**cfg['nonlin_kwargs'])

        stages = (self.conv, self.do, self.norm, self.nonlin)
        self.all = nn.Sequential(*stages)

    def forward(self, x):
        """Run x through convolution, dropout, normalization and nonlinearity, in that order."""
        return self.all(x)
| |
|
| |
|
class StackedConvLayers(nn.Module):
    """Chain of ConvDropoutNormReLU units in which only the first unit may receive a custom stride."""

    def __init__(self, input_channels, output_channels, kernel_size, network_props, num_convs, first_stride=None):
        """
        Build one leading unit (input_channels -> output_channels) plus num_convs - 1 follow-up units
        (output_channels -> output_channels).

        if network_props['dropout_op'] is None then no dropout
        if network_props['norm_op'] is None then no norm
        :param input_channels: channels entering the first unit
        :param output_channels: channels produced by every unit
        :param kernel_size: kernel size forwarded unchanged to every unit
        :param network_props: property dict forwarded to ConvDropoutNormReLU (copied first, never modified)
        :param num_convs: total number of units; the first one is always created
        :param first_stride: if not None, stored as conv_op_kwargs['stride'] for the first unit only
        """
        super().__init__()

        # two independent copies: one shared by the follow-up units, one for the leading unit
        shared_props = deepcopy(network_props)
        leading_props = deepcopy(shared_props)
        if first_stride is not None:
            leading_props['conv_op_kwargs']['stride'] = first_stride

        units = [ConvDropoutNormReLU(input_channels, output_channels, kernel_size, leading_props)]
        for _ in range(num_convs - 1):
            units.append(ConvDropoutNormReLU(output_channels, output_channels, kernel_size, shared_props))

        self.convs = nn.Sequential(*units)

    def forward(self, x):
        """Apply all units to x one after another."""
        return self.convs(x)
| |
|
| |
|
class BasicResidualBlock(nn.Module):
    """Residual block made of two convolutions with an additive skip connection."""

    def __init__(self, in_planes, out_planes, kernel_size, props, stride=None):
        """
        This is the conv bn nonlin conv bn nonlin kind of block

        Main path: conv1 -> dropout -> norm1 -> nonlin1 -> conv2 -> norm2. The (possibly projected) input is
        added afterwards and nonlin2 is applied to the sum.

        :param in_planes: channels of the incoming tensor
        :param out_planes: channels produced by both convolutions
        :param kernel_size: iterable with one kernel extent per spatial axis; each axis is padded by (k - 1) // 2
        :param props: dict with the keys 'conv_op', 'conv_op_kwargs', 'norm_op', 'norm_op_kwargs', 'dropout_op',
            'dropout_op_kwargs', 'nonlin' and 'nonlin_kwargs'. The dict is copied; the caller's object is left
            untouched.
        :param stride: stride of conv1 and of the skip projection (int or iterable). None keeps stride 1.
        """
        super().__init__()

        # Copy before forcing stride=1 below, otherwise the caller's dict would be silently modified.
        props = deepcopy(props)

        self.kernel_size = kernel_size
        props['conv_op_kwargs']['stride'] = 1

        self.stride = stride
        self.props = props
        self.out_planes = out_planes
        self.in_planes = in_planes

        # only conv1 uses the requested stride; conv2 always runs with stride 1
        kwargs_conv1 = dict(props['conv_op_kwargs'])
        if stride is not None:
            kwargs_conv1['stride'] = stride

        self.conv1 = props['conv_op'](in_planes, out_planes, kernel_size,
                                      padding=[(i - 1) // 2 for i in kernel_size],
                                      **kwargs_conv1)
        self.norm1 = props['norm_op'](out_planes, **props['norm_op_kwargs'])
        self.nonlin1 = props['nonlin'](**props['nonlin_kwargs'])

        # dropout_op None means "no dropout", same convention as ConvDropoutNormReLU in this file
        if props['dropout_op'] is not None and props['dropout_op_kwargs']['p'] != 0:
            self.dropout = props['dropout_op'](**props['dropout_op_kwargs'])
        else:
            self.dropout = Identity()

        self.conv2 = props['conv_op'](out_planes, out_planes, kernel_size,
                                      padding=[(i - 1) // 2 for i in kernel_size],
                                      **props['conv_op_kwargs'])
        self.norm2 = props['norm_op'](out_planes, **props['norm_op_kwargs'])
        self.nonlin2 = props['nonlin'](**props['nonlin_kwargs'])

        # accept a plain int stride as well as a per-axis iterable
        if stride is None:
            changes_resolution = False
        elif isinstance(stride, int):
            changes_resolution = stride != 1
        else:
            changes_resolution = any(s != 1 for s in stride)

        if changes_resolution or in_planes != out_planes:
            # the skip path must produce the same shape as the main path: 1x1 conv (no bias) + norm
            stride_here = stride if stride is not None else 1
            self.downsample_skip = nn.Sequential(props['conv_op'](in_planes, out_planes, 1, stride_here, bias=False),
                                                 props['norm_op'](out_planes, **props['norm_op_kwargs']))
        else:
            # a module rather than a lambda, so the block can be pickled and the skip shows up as a submodule
            self.downsample_skip = Identity()

    def forward(self, x):
        """Return nonlin2(main_path(x) + skip(x))."""
        residual = x

        out = self.dropout(self.conv1(x))
        out = self.nonlin1(self.norm1(out))

        out = self.norm2(self.conv2(out))

        residual = self.downsample_skip(residual)

        out += residual

        return self.nonlin2(out)
| |
|
| |
|
class ResidualBottleneckBlock(nn.Module):
    """Residual bottleneck block: 1x1 conv down to out_planes // 4, k x k conv, 1x1 conv up, plus skip."""

    def __init__(self, in_planes, out_planes, kernel_size, props, stride=None):
        """
        This is the conv bn nonlin, conv bn nonlin, conv bn kind of block; the (possibly projected) input is
        added to the result and a final nonlinearity is applied to the sum.

        :param in_planes: channels of the incoming tensor
        :param out_planes: channels of the block output; the two inner convolutions use out_planes // 4
        :param kernel_size: iterable with one kernel extent per spatial axis (used by the middle convolution;
            the first and last convolution use kernel size 1 on every axis)
        :param props: dict with the keys 'conv_op', 'conv_op_kwargs', 'norm_op', 'norm_op_kwargs', 'nonlin' and
            'nonlin_kwargs'; 'dropout_op' / 'dropout_op_kwargs' are only inspected to reject dropout. The dict is
            copied; the caller's object is left untouched.
        :param stride: stride of the first convolution and of the skip projection (int or iterable).
            None keeps stride 1.
        :raises NotImplementedError: if a dropout op with p > 0 is configured
        """
        super().__init__()

        # Dropout counts as requested only when an op is configured AND its probability is positive.
        dropout_kwargs = props.get('dropout_op_kwargs')
        if props.get('dropout_op') is not None and dropout_kwargs is not None \
                and dropout_kwargs.get('p', 0) > 0:
            raise NotImplementedError("ResidualBottleneckBlock does not yet support dropout!")

        # Copy before forcing stride=1 below, otherwise the caller's dict would be silently modified.
        props = deepcopy(props)

        self.kernel_size = kernel_size
        props['conv_op_kwargs']['stride'] = 1

        self.stride = stride
        self.props = props
        self.out_planes = out_planes
        self.in_planes = in_planes
        self.bottleneck_planes = out_planes // 4

        # only conv1 uses the requested stride; conv2 and conv3 always run with stride 1
        kwargs_conv1 = dict(props['conv_op_kwargs'])
        if stride is not None:
            kwargs_conv1['stride'] = stride

        self.conv1 = props['conv_op'](in_planes, self.bottleneck_planes, [1 for _ in kernel_size],
                                      padding=[0 for _ in kernel_size],
                                      **kwargs_conv1)
        self.norm1 = props['norm_op'](self.bottleneck_planes, **props['norm_op_kwargs'])
        self.nonlin1 = props['nonlin'](**props['nonlin_kwargs'])

        self.conv2 = props['conv_op'](self.bottleneck_planes, self.bottleneck_planes, kernel_size,
                                      padding=[(i - 1) // 2 for i in kernel_size],
                                      **props['conv_op_kwargs'])
        self.norm2 = props['norm_op'](self.bottleneck_planes, **props['norm_op_kwargs'])
        self.nonlin2 = props['nonlin'](**props['nonlin_kwargs'])

        self.conv3 = props['conv_op'](self.bottleneck_planes, out_planes, [1 for _ in kernel_size],
                                      padding=[0 for _ in kernel_size],
                                      **props['conv_op_kwargs'])
        self.norm3 = props['norm_op'](out_planes, **props['norm_op_kwargs'])
        self.nonlin3 = props['nonlin'](**props['nonlin_kwargs'])

        # accept a plain int stride as well as a per-axis iterable
        if stride is None:
            changes_resolution = False
        elif isinstance(stride, int):
            changes_resolution = stride != 1
        else:
            changes_resolution = any(s != 1 for s in stride)

        if changes_resolution or in_planes != out_planes:
            # the skip path must produce the same shape as the main path: 1x1 conv (no bias) + norm
            stride_here = stride if stride is not None else 1
            self.downsample_skip = nn.Sequential(props['conv_op'](in_planes, out_planes, 1, stride_here, bias=False),
                                                 props['norm_op'](out_planes, **props['norm_op_kwargs']))
        else:
            # a module rather than a lambda, so the block can be pickled and the skip shows up as a submodule
            self.downsample_skip = Identity()

    def forward(self, x):
        """Return nonlin3(main_path(x) + skip(x))."""
        residual = x

        out = self.nonlin1(self.norm1(self.conv1(x)))
        out = self.nonlin2(self.norm2(self.conv2(out)))

        out = self.norm3(self.conv3(out))

        residual = self.downsample_skip(residual)

        out += residual

        return self.nonlin3(out)
| |
|
| |
|
class ResidualLayer(nn.Module):
    """Sequence of residual blocks in which only the first block receives first_stride."""

    def __init__(self, input_channels, output_channels, kernel_size, network_props, num_blocks, first_stride=None, block=BasicResidualBlock):
        """
        Build one leading block (input_channels -> output_channels, called with first_stride) followed by
        num_blocks - 1 blocks (output_channels -> output_channels, called without a stride argument).

        :param input_channels: channels entering the first block
        :param output_channels: channels produced by every block
        :param kernel_size: kernel size forwarded unchanged to every block
        :param network_props: property dict; a deep copy of it is handed to all blocks
        :param num_blocks: total number of blocks; the first one is always created
        :param first_stride: passed as the stride argument of the first block only
        :param block: callable used to construct each block
        """
        super().__init__()

        # one copy shared by all blocks of this layer; the caller's dict is not passed on
        layer_props = deepcopy(network_props)

        stages = [block(input_channels, output_channels, kernel_size, layer_props, first_stride)]
        for _ in range(num_blocks - 1):
            stages.append(block(output_channels, output_channels, kernel_size, layer_props))

        self.convs = nn.Sequential(*stages)

    def forward(self, x):
        """Apply all blocks to x one after another."""
        return self.convs(x)
| |
|
| |
|