| import torch |
| import torch.nn as nn |
| from annotator.mmpkg.mmcv.cnn import ConvModule, build_upsample_layer |
|
|
|
|
class UpConvBlock(nn.Module):
    """Decoder stage of a UNet: upsample, concatenate with skip, convolve.

    The block owns two sub-modules. ``upsample`` maps the high-level
    low-resolution feature map from ``in_channels`` to ``skip_channels``
    channels. ``conv_block`` then fuses the result with the low-level
    high-resolution feature map coming from the encoder; the two maps are
    concatenated along dimension 1, so ``conv_block`` is built with
    ``2 * skip_channels`` input channels.

    Args:
        conv_block (callable): Factory for the fusing convolution block. It
            is called with the keyword arguments ``in_channels``,
            ``out_channels``, ``num_convs``, ``stride``, ``dilation``,
            ``with_cp``, ``conv_cfg``, ``norm_cfg``, ``act_cfg``, ``dcn``
            and ``plugins`` and must return the module to use.
        in_channels (int): Number of input channels of the high-level
            low-resolution feature map (the input of the upsample module).
        skip_channels (int): Number of input channels of the low-level
            high-resolution feature map from encoder. Also the number of
            output channels of the upsample module.
        out_channels (int): Number of output channels.
        num_convs (int): Number of convolutional layers in the conv_block.
            Default: 2.
        stride (int): Stride of convolutional layer in conv_block.
            Default: 1.
        dilation (int): Dilation rate of convolutional layer in conv_block.
            Default: 1.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed. Forwarded to
            ``conv_block`` and to the upsample layer built from
            ``upsample_cfg``. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in
            ConvModule. Default: dict(type='ReLU').
        upsample_cfg (dict | None): Config of the upsample module, passed to
            ``build_upsample_layer``. Default: dict(type='InterpConv').
            When None (the high-level feature map already has the size of
            the skip feature map, so no upsampling is needed), a 1x1
            ``ConvModule`` with stride 1 and no padding is used instead to
            map ``in_channels`` to ``skip_channels``.
        dcn (None): Deformable convolution setting. Not implemented yet;
            must be None. Default: None.
        plugins (None): Plugins for convolutional layers. Not implemented
            yet; must be None. Default: None.
    """

    def __init__(self,
                 conv_block,
                 in_channels,
                 skip_channels,
                 out_channels,
                 num_convs=2,
                 stride=1,
                 dilation=1,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 upsample_cfg=dict(type='InterpConv'),
                 dcn=None,
                 plugins=None):
        super().__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'

        # The fusing block is created before the upsample module so that the
        # sub-modules are registered in the order: conv_block, upsample.
        fuse_kwargs = dict(
            in_channels=2 * skip_channels,  # skip map + upsampled map
            out_channels=out_channels,
            num_convs=num_convs,
            stride=stride,
            dilation=dilation,
            with_cp=with_cp,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            dcn=None,
            plugins=None,
        )
        self.conv_block = conv_block(**fuse_kwargs)

        if upsample_cfg is None:
            # No upsampling requested: only adapt the channel count with a
            # 1x1 convolution.
            self.upsample = ConvModule(
                in_channels,
                skip_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        else:
            self.upsample = build_upsample_layer(
                cfg=upsample_cfg,
                in_channels=in_channels,
                out_channels=skip_channels,
                with_cp=with_cp,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

    def forward(self, skip, x):
        """Forward function.

        Args:
            skip (torch.Tensor): Low-level feature map from the encoder.
            x (torch.Tensor): High-level feature map to be passed through
                the upsample module.

        Returns:
            torch.Tensor: Output of ``conv_block`` applied to ``skip`` and
            the upsampled ``x`` concatenated (in that order) along
            dimension 1.
        """
        upsampled = self.upsample(x)
        fused = torch.cat((skip, upsampled), dim=1)
        return self.conv_block(fused)
|
|