| import torch | |
| import torch.nn as nn | |
| from mmcv.cnn import ConvModule, build_upsample_layer | |
class UpConvBlock(nn.Module):
    """Decoder-side upsample + convolution block for UNet.

    The block first passes the high-level low-resolution feature map
    through an upsample module so that it has ``skip_channels`` channels,
    then concatenates it with the low-level high-resolution skip feature
    map along dim 1 and fuses the result with a convolution block.

    Args:
        conv_block (callable): Factory (e.g. an ``nn.Module`` class) used to
            build the fusing convolution block. It is called with the
            keyword arguments ``in_channels``, ``out_channels``,
            ``num_convs``, ``stride``, ``dilation``, ``with_cp``,
            ``conv_cfg``, ``norm_cfg``, ``act_cfg``, ``dcn`` and
            ``plugins``.
        in_channels (int): Number of channels of the high-level
            low-resolution feature map that is fed to the upsample module.
        skip_channels (int): Number of channels of the low-level
            high-resolution feature map from encoder. The upsample module
            also outputs this many channels, so the convolution block
            receives ``2 * skip_channels`` input channels.
        out_channels (int): Number of output channels.
        num_convs (int): Number of convolutional layers in the conv_block.
            Default: 2.
        stride (int): Stride of convolutional layer in conv_block.
            Default: 1.
        dilation (int): Dilation rate of convolutional layer in conv_block.
            Default: 1.
        with_cp (bool): Use checkpoint or not. Forwarded to the conv_block
            and, when ``upsample_cfg`` is given, to the upsample layer.
            Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer.
            Default: dict(type='ReLU').
        upsample_cfg (dict | None): Config of the upsample module, passed
            to ``build_upsample_layer``. Default: dict(type='InterpConv').
            When it is None no upsampling layer is built; a 1x1
            ``ConvModule`` mapping ``in_channels`` to ``skip_channels`` is
            used instead (for the case where both feature maps already
            have the same size).
        dcn (bool | None): Use deformable convolution in convolutional
            layer or not. Not implemented yet, must be None. Default: None.
        plugins (dict | None): Plugins for convolutional layers. Not
            implemented yet, must be None. Default: None.
    """

    def __init__(self,
                 conv_block,
                 in_channels,
                 skip_channels,
                 out_channels,
                 num_convs=2,
                 stride=1,
                 dilation=1,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 upsample_cfg=dict(type='InterpConv'),
                 dcn=None,
                 plugins=None):
        super().__init__()

        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'

        # Normalization / activation settings shared by every submodule.
        norm_act = dict(norm_cfg=norm_cfg, act_cfg=act_cfg)

        # The fusing block sees the skip feature map concatenated with the
        # upsampled feature map, each carrying ``skip_channels`` channels.
        # It is registered before ``upsample`` to keep the submodule order.
        fused_channels = 2 * skip_channels
        self.conv_block = conv_block(
            in_channels=fused_channels,
            out_channels=out_channels,
            num_convs=num_convs,
            stride=stride,
            dilation=dilation,
            with_cp=with_cp,
            conv_cfg=conv_cfg,
            dcn=None,
            plugins=None,
            **norm_act)

        if upsample_cfg is None:
            # No spatial upsampling requested: only project the channels
            # from ``in_channels`` to ``skip_channels`` with a 1x1 conv.
            self.upsample = ConvModule(
                in_channels,
                skip_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                **norm_act)
        else:
            self.upsample = build_upsample_layer(
                cfg=upsample_cfg,
                in_channels=in_channels,
                out_channels=skip_channels,
                with_cp=with_cp,
                **norm_act)

    def forward(self, skip, x):
        """Forward function.

        Args:
            skip (Tensor): Skip feature map; placed first in the
                concatenation.
            x (Tensor): Feature map that is passed through the upsample
                module before being concatenated after ``skip`` on dim 1.

        Returns:
            Tensor: Output of the convolution block applied to the
                concatenated feature maps.
        """
        upsampled = self.upsample(x)
        fused = torch.cat([skip, upsampled], dim=1)
        return self.conv_block(fused)