| | import torch.nn as nn |
| | import torch.utils.checkpoint as cp |
| | from annotator.mmpkg.mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer, |
| | build_norm_layer, constant_init, kaiming_init) |
| | from annotator.mmpkg.mmcv.runner import load_checkpoint |
| | from annotator.mmpkg.mmcv.utils.parrots_wrapper import _BatchNorm |
| |
|
| | from annotator.mmpkg.mmseg.utils import get_root_logger |
| | from ..builder import BACKBONES |
| | from ..utils import UpConvBlock |
| |
|
| |
|
class BasicConvBlock(nn.Module):
    """Plain stack of 3x3 convolutions used as the UNet building block.

    Every layer is a ``ConvModule`` (convolution + normalization +
    activation). Only the first layer may change the channel count or apply
    a stride; every following layer maps ``out_channels`` to
    ``out_channels`` with stride 1.

    Args:
        in_channels (int): Channels of the input feature map.
        out_channels (int): Channels produced by every layer of the block.
        num_convs (int): How many convolutional layers to stack. Default: 2.
        stride (int): Stride of the first layer only (2 downsamples the
            input); later layers always use stride 1. Options are 1 or 2.
            Default: 1.
        dilation (int): Dilation, and matching padding, of every layer except
            the first, which always uses dilation 1 and padding 1.
            Default: 1.
        with_cp (bool): Run the layers through ``torch.utils.checkpoint``
            when the input requires grad, trading extra compute for lower
            memory use. Default: False.
        conv_cfg (dict | None): Convolution config forwarded to ConvModule.
            Default: None.
        norm_cfg (dict | None): Normalization config forwarded to ConvModule.
            Default: dict(type='BN').
        act_cfg (dict | None): Activation config forwarded to ConvModule.
            Default: dict(type='ReLU').
        dcn (None): Deformable convolution config; only None is supported.
            Default: None.
        plugins (None): Plugin config; only None is supported.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_convs=2,
                 stride=1,
                 dilation=1,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 dcn=None,
                 plugins=None):
        super(BasicConvBlock, self).__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'

        self.with_cp = with_cp

        def make_layer(index):
            # The first layer adapts channels/stride; the rest are dilated
            # same-size convolutions (padding == dilation keeps H and W).
            is_first = index == 0
            return ConvModule(
                in_channels=in_channels if is_first else out_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=stride if is_first else 1,
                dilation=1 if is_first else dilation,
                padding=1 if is_first else dilation,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

        self.convs = nn.Sequential(*[make_layer(k) for k in range(num_convs)])

    def forward(self, x):
        """Apply the convolution stack, optionally with checkpointing."""
        use_checkpoint = self.with_cp and x.requires_grad
        if not use_checkpoint:
            return self.convs(x)
        return cp.checkpoint(self.convs, x)
| |
|
| |
|
@UPSAMPLE_LAYERS.register_module()
class DeconvModule(nn.Module):
    """Deconvolution upsample module in decoder for UNet.

    This module upsamples the feature map in the decoder of UNet with a
    transposed convolution followed by a normalization layer and an
    activation layer. With the default arguments it upsamples by 2X.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer.
            Default: dict(type='ReLU').
        kernel_size (int): Kernel size of the transposed convolution layer.
            Must be >= ``scale_factor``, and ``kernel_size - scale_factor``
            must be even. Default: 4.
        scale_factor (int): Upsampling factor. It is used as the stride of
            the transposed convolution. Default: 2.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 *,
                 kernel_size=4,
                 scale_factor=2):
        super(DeconvModule, self).__init__()

        assert (kernel_size - scale_factor >= 0
                and (kernel_size - scale_factor) % 2 == 0), (
                    'kernel_size should be greater than or equal to '
                    'scale_factor '
                    'and (kernel_size - scale_factor) should be even numbers, '
                    f'while the kernel size is {kernel_size} and '
                    f'scale_factor is '
                    f'{scale_factor}.')

        stride = scale_factor
        # With padding = (k - s) / 2, ConvTranspose2d yields
        # (H - 1) * s - (k - s) + k = H * s, i.e. exactly scale_factor times
        # the input size (the assert above guarantees an integer padding).
        padding = (kernel_size - scale_factor) // 2
        self.with_cp = with_cp
        deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding)

        # build_norm_layer returns (name, layer); only the layer is needed.
        _, norm = build_norm_layer(norm_cfg, out_channels)
        activate = build_activation_layer(act_cfg)
        # NOTE: the misspelled attribute name is kept on purpose; renaming it
        # would change the state_dict keys and break existing checkpoints.
        self.deconv_upsamping = nn.Sequential(deconv, norm, activate)

    def forward(self, x):
        """Upsample ``x``, optionally with gradient checkpointing."""
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(self.deconv_upsamping, x)
        else:
            out = self.deconv_upsamping(x)
        return out
| |
|
| |
|
@UPSAMPLE_LAYERS.register_module()
class InterpConv(nn.Module):
    """Interpolation upsample module in decoder for UNet.

    This module uses interpolation to upsample the feature map in the decoder
    of UNet. It consists of one interpolation upsample layer and one
    convolutional layer. The order is either interpolation followed by the
    convolution (conv_first=False) or the convolution followed by
    interpolation (conv_first=True).

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        conv_first (bool): Whether the convolutional layer comes before the
            interpolation upsample layer. Default: False, which means the
            interpolation upsample layer is followed by the convolutional
            layer.
        kernel_size (int): Kernel size of the convolutional layer. Default: 1.
        stride (int): Stride of the convolutional layer. Default: 1.
        padding (int): Padding of the convolutional layer. Default: 0.
        upsample_cfg (dict): Keyword arguments of the ``nn.Upsample`` layer.
            Default: dict(
            scale_factor=2, mode='bilinear', align_corners=False).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 *,
                 conv_cfg=None,
                 conv_first=False,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 upsample_cfg=dict(
                     scale_factor=2, mode='bilinear', align_corners=False)):
        super(InterpConv, self).__init__()

        self.with_cp = with_cp
        conv = ConvModule(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # upsample_cfg is only unpacked here, never mutated, so the shared
        # default dict is safe.
        upsample = nn.Upsample(**upsample_cfg)
        if conv_first:
            self.interp_upsample = nn.Sequential(conv, upsample)
        else:
            self.interp_upsample = nn.Sequential(upsample, conv)

    def forward(self, x):
        """Upsample ``x``, optionally with gradient checkpointing."""
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(self.interp_upsample, x)
        else:
            out = self.interp_upsample(x)
        return out
| |
|
| |
|
@BACKBONES.register_module()
class UNet(nn.Module):
    """UNet backbone.

    U-Net: Convolutional Networks for Biomedical Image Segmentation.
    https://arxiv.org/pdf/1505.04597.pdf

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        base_channels (int): Number of base channels of each stage.
            The output channels of the first stage; encoder stage i outputs
            base_channels * 2**i channels. Default: 64.
        num_stages (int): Number of stages in encoder, normally 5. Default: 5.
        strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
            len(strides) is equal to num_stages. Normally the stride of the
            first stage in encoder is 1. If strides[i]=2, it uses stride
            convolution to downsample in the corresponding encoder stage.
            Default: (1, 1, 1, 1, 1).
        enc_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of the corresponding encoder stage.
            Default: (2, 2, 2, 2, 2).
        dec_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of the corresponding decoder stage.
            Default: (2, 2, 2, 2).
        downsamples (Sequence[bool]): Whether to use MaxPool to downsample the
            feature map after the first stage of encoder
            (stages: [1, num_stages)). If the corresponding encoder stage uses
            stride convolution (strides[i]=2), it will never use MaxPool to
            downsample, even if downsamples[i-1]=True.
            Default: (True, True, True, True).
        enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
            Default: (1, 1, 1, 1, 1).
        dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
            Default: (1, 1, 1, 1).
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        upsample_cfg (dict): The upsample config of the upsample module in
            decoder. Default: dict(type='InterpConv').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        dcn (None): Deformable convolution config; only None is supported.
            Default: None.
        plugins (None): Plugin config; only None is supported.
            Default: None.

    Notice:
        The input image size should be divisible by the whole downsample rate
        of the encoder. More detail of the whole downsample rate can be found
        in UNet._check_input_divisible.
    """

    def __init__(self,
                 in_channels=3,
                 base_channels=64,
                 num_stages=5,
                 strides=(1, 1, 1, 1, 1),
                 enc_num_convs=(2, 2, 2, 2, 2),
                 dec_num_convs=(2, 2, 2, 2),
                 downsamples=(True, True, True, True),
                 enc_dilations=(1, 1, 1, 1, 1),
                 dec_dilations=(1, 1, 1, 1),
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 upsample_cfg=dict(type='InterpConv'),
                 norm_eval=False,
                 dcn=None,
                 plugins=None):
        super(UNet, self).__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'
        assert len(strides) == num_stages, \
            'The length of strides should be equal to num_stages, '\
            f'while the strides is {strides}, the length of '\
            f'strides is {len(strides)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(enc_num_convs) == num_stages, \
            'The length of enc_num_convs should be equal to num_stages, '\
            f'while the enc_num_convs is {enc_num_convs}, the length of '\
            f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(dec_num_convs) == (num_stages-1), \
            'The length of dec_num_convs should be equal to (num_stages-1), '\
            f'while the dec_num_convs is {dec_num_convs}, the length of '\
            f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(downsamples) == (num_stages-1), \
            'The length of downsamples should be equal to (num_stages-1), '\
            f'while the downsamples is {downsamples}, the length of '\
            f'downsamples is {len(downsamples)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(enc_dilations) == num_stages, \
            'The length of enc_dilations should be equal to num_stages, '\
            f'while the enc_dilations is {enc_dilations}, the length of '\
            f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(dec_dilations) == (num_stages-1), \
            'The length of dec_dilations should be equal to (num_stages-1), '\
            f'while the dec_dilations is {dec_dilations}, the length of '\
            f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
            f'{num_stages}.'
        self.num_stages = num_stages
        self.strides = strides
        self.downsamples = downsamples
        self.norm_eval = norm_eval
        self.base_channels = base_channels

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        for i in range(num_stages):
            enc_conv_block = []
            if i != 0:
                # MaxPool only when this stage does not already downsample
                # with a strided convolution.
                if strides[i] == 1 and downsamples[i - 1]:
                    enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
                # The matching decoder stage must upsample whenever this
                # encoder stage downsampled (by stride or by MaxPool).
                upsample = (strides[i] != 1 or downsamples[i - 1])
                self.decoder.append(
                    UpConvBlock(
                        conv_block=BasicConvBlock,
                        in_channels=base_channels * 2**i,
                        skip_channels=base_channels * 2**(i - 1),
                        out_channels=base_channels * 2**(i - 1),
                        num_convs=dec_num_convs[i - 1],
                        stride=1,
                        dilation=dec_dilations[i - 1],
                        with_cp=with_cp,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg,
                        upsample_cfg=upsample_cfg if upsample else None,
                        dcn=None,
                        plugins=None))

            enc_conv_block.append(
                BasicConvBlock(
                    in_channels=in_channels,
                    out_channels=base_channels * 2**i,
                    num_convs=enc_num_convs[i],
                    stride=strides[i],
                    dilation=enc_dilations[i],
                    with_cp=with_cp,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    dcn=None,
                    plugins=None))
            self.encoder.append(nn.Sequential(*enc_conv_block))
            # The next stage consumes this stage's output channels.
            in_channels = base_channels * 2**i

    def forward(self, x):
        """Run the encoder, then the decoder.

        Returns:
            list: The last encoder output first, followed by each decoder
            stage's output, from the deepest decoder stage to the shallowest
            one (``num_stages`` entries in total).
        """
        self._check_input_divisible(x)
        enc_outs = []
        for enc in self.encoder:
            x = enc(x)
            enc_outs.append(x)
        dec_outs = [x]
        # decoder[i] fuses the encoder feature of stage i with the running
        # decoder feature coming from the deeper stage.
        for i in reversed(range(len(self.decoder))):
            x = self.decoder[i](enc_outs[i], x)
            dec_outs.append(x)

        return dec_outs

    def train(self, mode=True):
        """Set training mode, keeping BatchNorm layers in eval mode when
        ``norm_eval`` is True.

        Returns:
            UNet: ``self``, matching the ``nn.Module.train`` contract.
        """
        super(UNet, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self

    def _check_input_divisible(self, x):
        """Assert that the spatial size of ``x`` is divisible by the total
        encoder downsample rate.

        Each stage after the first that uses a stride-2 convolution or
        MaxPool halves the resolution, so the rate is 2 to the power of the
        number of such stages.
        """
        h, w = x.shape[-2:]
        whole_downsample_rate = 1
        for i in range(1, self.num_stages):
            if self.strides[i] == 2 or self.downsamples[i - 1]:
                whole_downsample_rate *= 2
        assert (h % whole_downsample_rate == 0) \
            and (w % whole_downsample_rate == 0),\
            f'The input image size {(h, w)} should be divisible by the whole '\
            f'downsample rate {whole_downsample_rate}, when num_stages is '\
            f'{self.num_stages}, strides is {self.strides}, and downsamples '\
            f'is {self.downsamples}.'

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            # strict=False: tolerate missing/unexpected keys in the file.
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')
| |
|