| from collections import OrderedDict
|
|
|
| import torch.nn as nn
|
|
|
| from .bn import ABN, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE
|
| import torch.nn.functional as functional
|
|
|
|
|
class ResidualBlock(nn.Module):
    """Configurable residual block.

    The block runs its input through a stack of convolution / normalization
    layers (``self.convs``), adds a shortcut branch and finally applies the
    activation configured on the first normalization layer to the sum.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    channels : list of int
        Number of channels in the internal feature maps. Must have two or
        three elements: with two elements a basic block made of two `3 x 3`
        convolutions is built, with three elements a bottleneck block made of
        `1 x 1`, then `3 x 3`, then `1 x 1` convolutions is built.
    stride : int
        Stride of the first `3 x 3` convolution. The same stride is used by
        the `1 x 1` projection on the shortcut branch, when one is created.
    dilation : int
        Dilation to apply to the `3 x 3` convolutions. It is also used as
        their padding.
    groups : int
        Number of convolution groups. This is used to create ResNeXt-style
        blocks and is only compatible with bottleneck blocks.
    norm_act : callable
        Function to create normalization / activation Module. It is called
        with the number of channels; the returned module must expose an
        ``activation`` attribute (and ``slope`` when leaky ReLU is used).
    dropout : callable
        Function to create Dropout Module, called without arguments. When
        given, the dropout module is placed right before the last
        convolution of the block.

    Raises
    ------
    ValueError
        If ``channels`` does not contain two or three values, or if
        ``groups != 1`` is requested for a two-element ``channels``.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 stride=1,
                 dilation=1,
                 groups=1,
                 norm_act=ABN,
                 dropout=None):
        super(ResidualBlock, self).__init__()

        depth = len(channels)
        if depth not in (2, 3):
            raise ValueError("channels must contain either two or three values")
        if depth == 2 and groups != 1:
            raise ValueError("groups > 1 are only valid if len(channels) == 3")

        # The closing normalization has its activation switched off: the
        # non-linearity is applied in forward(), after the shortcut is added.
        closing_norm = norm_act(channels[-1])
        closing_norm.activation = ACT_NONE

        if depth == 3:
            # Bottleneck: 1x1 -> 3x3 (strided, grouped, dilated) -> 1x1.
            named_layers = [
                ("conv1", nn.Conv2d(in_channels, channels[0], kernel_size=1, stride=1, padding=0,
                                    bias=False)),
                ("bn1", norm_act(channels[0])),
                ("conv2", nn.Conv2d(channels[0], channels[1], kernel_size=3, stride=stride,
                                    padding=dilation, dilation=dilation, groups=groups, bias=False)),
                ("bn2", norm_act(channels[1])),
                ("conv3", nn.Conv2d(channels[1], channels[2], kernel_size=1, stride=1, padding=0,
                                    bias=False)),
                ("bn3", closing_norm),
            ]
            dropout_index = 4  # between "bn2" and "conv3"
        else:
            # Basic block: 3x3 (strided, dilated) -> 3x3 (dilated).
            named_layers = [
                ("conv1", nn.Conv2d(in_channels, channels[0], kernel_size=3, stride=stride,
                                    padding=dilation, dilation=dilation, bias=False)),
                ("bn1", norm_act(channels[0])),
                ("conv2", nn.Conv2d(channels[0], channels[1], kernel_size=3, stride=1,
                                    padding=dilation, dilation=dilation, bias=False)),
                ("bn2", closing_norm),
            ]
            dropout_index = 2  # between "bn1" and "conv2"

        if dropout is not None:
            named_layers.insert(dropout_index, ("dropout", dropout()))
        self.convs = nn.Sequential(OrderedDict(named_layers))

        # A projection is only registered when the shortcut cannot be the
        # identity; forward() detects it through the attribute's presence.
        shortcut_is_identity = stride == 1 and in_channels == channels[-1]
        if not shortcut_is_identity:
            self.proj_conv = nn.Conv2d(in_channels, channels[-1], kernel_size=1, stride=stride,
                                       padding=0, bias=False)
            self.proj_bn = norm_act(channels[-1])
            self.proj_bn.activation = ACT_NONE

    def forward(self, x):
        """Return ``activation(convs(x) + shortcut(x))``.

        The shortcut is ``proj_bn(proj_conv(x))`` when a projection was
        created and ``x`` itself otherwise. The activation is chosen from
        ``self.convs.bn1.activation``: leaky ReLU (using ``bn1.slope``) or
        ELU, both applied in place on the sum; for any other value the sum
        is returned as is.
        """
        if hasattr(self, "proj_conv"):
            shortcut = self.proj_bn(self.proj_conv(x))
        else:
            shortcut = x
        summed = self.convs(x) + shortcut

        first_norm = self.convs.bn1
        activation = first_norm.activation
        if activation == ACT_LEAKY_RELU:
            return functional.leaky_relu(summed, negative_slope=first_norm.slope, inplace=True)
        if activation == ACT_ELU:
            return functional.elu(summed, inplace=True)
        return summed
|
|
|
|
|
class IdentityResidualBlock(nn.Module):
    """Residual block whose normalization / activation precedes each convolution."""

    def __init__(self,
                 in_channels,
                 channels,
                 stride=1,
                 dilation=1,
                 groups=1,
                 norm_act=ABN,
                 dropout=None):
        """Configurable identity-mapping residual block

        Parameters
        ----------
        in_channels : int
            Number of input channels.
        channels : list of int
            Number of channels in the internal feature maps. Must have two or
            three elements: with two elements a basic block made of two
            `3 x 3` convolutions is built, with three elements a bottleneck
            block made of `1 x 1`, then `3 x 3`, then `1 x 1` convolutions is
            built.
        stride : int
            Stride of the first convolution of the block: the first `3 x 3`
            convolution in a basic block, the initial `1 x 1` convolution in
            a bottleneck block. The same stride is used by the `1 x 1`
            projection on the shortcut branch, when one is created.
        dilation : int
            Dilation to apply to the `3 x 3` convolutions. It is also used as
            their padding.
        groups : int
            Number of convolution groups. This is used to create
            ResNeXt-style blocks and is only compatible with bottleneck
            blocks.
        norm_act : callable
            Function to create normalization / activation Module. It is
            called with the number of channels.
        dropout : callable
            Function to create Dropout Module, called without arguments. When
            given, the dropout module is placed right before the last
            convolution of the block.

        Raises
        ------
        ValueError
            If ``channels`` does not contain two or three values, or if
            ``groups != 1`` is requested for a two-element ``channels``.
        """
        super(IdentityResidualBlock, self).__init__()

        depth = len(channels)
        if depth not in (2, 3):
            raise ValueError("channels must contain either two or three values")
        if depth == 2 and groups != 1:
            raise ValueError("groups > 1 are only valid if len(channels) == 3")

        # Normalization / activation applied to the block input, ahead of
        # both the convolution stack and (if present) the projection.
        self.bn1 = norm_act(in_channels)

        if depth == 3:
            # Bottleneck: 1x1 (strided) -> 3x3 (grouped, dilated) -> 1x1.
            named_layers = [
                ("conv1", nn.Conv2d(in_channels, channels[0], kernel_size=1, stride=stride,
                                    padding=0, bias=False)),
                ("bn2", norm_act(channels[0])),
                ("conv2", nn.Conv2d(channels[0], channels[1], kernel_size=3, stride=1,
                                    padding=dilation, dilation=dilation, groups=groups, bias=False)),
                ("bn3", norm_act(channels[1])),
                ("conv3", nn.Conv2d(channels[1], channels[2], kernel_size=1, stride=1, padding=0,
                                    bias=False)),
            ]
            dropout_index = 4  # between "bn3" and "conv3"
        else:
            # Basic block: 3x3 (strided, dilated) -> 3x3 (dilated).
            named_layers = [
                ("conv1", nn.Conv2d(in_channels, channels[0], kernel_size=3, stride=stride,
                                    padding=dilation, dilation=dilation, bias=False)),
                ("bn2", norm_act(channels[0])),
                ("conv2", nn.Conv2d(channels[0], channels[1], kernel_size=3, stride=1,
                                    padding=dilation, dilation=dilation, bias=False)),
            ]
            dropout_index = 2  # between "bn2" and "conv2"

        if dropout is not None:
            named_layers.insert(dropout_index, ("dropout", dropout()))
        self.convs = nn.Sequential(OrderedDict(named_layers))

        # A projection is only registered when the shortcut cannot be the
        # identity; forward() detects it through the attribute's presence.
        shortcut_is_identity = stride == 1 and in_channels == channels[-1]
        if not shortcut_is_identity:
            self.proj_conv = nn.Conv2d(in_channels, channels[-1], kernel_size=1, stride=stride,
                                       padding=0, bias=False)

    def forward(self, x):
        """Return ``convs(bn1(x)) + shortcut``.

        With a projection the shortcut is ``proj_conv(bn1(x))``; without one
        it is a copy of the raw input ``x``. The shortcut is added in place
        to the output of the convolution stack.
        """
        if not hasattr(self, "proj_conv"):
            # The copy is taken before bn1 runs, as in the original ordering.
            # NOTE(review): presumably because norm_act may modify its input
            # in place - confirm against the norm_act implementations in use.
            shortcut = x.clone()
            pre_activated = self.bn1(x)
        else:
            pre_activated = self.bn1(x)
            shortcut = self.proj_conv(pre_activated)

        result = self.convs(pre_activated)
        result.add_(shortcut)
        return result
|
|
|