import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmengine.logging import MMLogger
from mmengine.model import Sequential, constant_init, kaiming_init
from mmengine.runner.checkpoint import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm

from mmdet.registry import MODELS
from .resnet import BasicBlock
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNet


class Bottleneck(_Bottleneck):
    r"""Bottleneck for the ResNet backbone in `DetectoRS
    <https://arxiv.org/pdf/2006.02334.pdf>`_.

    This bottleneck allows users to specify whether to use
    SAC (Switchable Atrous Convolution) and RFP (Recursive Feature Pyramid).

    Args:
        inplanes (int): The number of input channels.
        planes (int): The number of output channels before expansion.
        rfp_inplanes (int, optional): The number of channels from RFP.
            Default: None. If specified, an additional conv layer will be
            added for ``rfp_feat``. Otherwise, the structure is the same as
            the base class.
        sac (dict, optional): Dictionary to construct SAC. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 rfp_inplanes=None,
                 sac=None,
                 init_cfg=None,
                 **kwargs):
        super(Bottleneck, self).__init__(
            inplanes, planes, init_cfg=init_cfg, **kwargs)

        assert sac is None or isinstance(sac, dict)
        self.sac = sac
        self.with_sac = sac is not None
        if self.with_sac:
            self.conv2 = build_conv_layer(
                self.sac,
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=self.dilation,
                dilation=self.dilation,
                bias=False)

        self.rfp_inplanes = rfp_inplanes
        if self.rfp_inplanes:
            self.rfp_conv = build_conv_layer(
                None,
                self.rfp_inplanes,
                planes * self.expansion,
                1,
                stride=1,
                bias=True)
            if init_cfg is None:
                self.init_cfg = dict(
                    type='Constant', val=0, override=dict(name='rfp_conv'))

    def rfp_forward(self, x, rfp_feat):
        """The forward function that also takes the RFP features as input."""

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv1_plugin_names)

            out = self.conv2(out)
            out = self.norm2(out)
            out = self.relu(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv2_plugin_names)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv3_plugin_names)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        if self.rfp_inplanes:
            rfp_feat = self.rfp_conv(rfp_feat)
            out = out + rfp_feat

        out = self.relu(out)

        return out
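

# Illustrative sketch of how ``sac`` and ``rfp_inplanes`` are typically
# combined in this block. It assumes the 'SAC' conv layer provided by mmcv
# is available and that the RFP feedback feature has 256 channels; adjust
# the values to your own setup.
#
#   block = Bottleneck(
#       inplanes=256,
#       planes=64,
#       sac=dict(type='SAC', use_deform=True),
#       rfp_inplanes=256)
#   # x: (N, 256, H, W); rfp_feat: (N, 256, H, W). ``rfp_conv`` projects
#   # rfp_feat to planes * expansion = 256 channels and adds it to the
#   # residual output before the final ReLU.
#   out = block.rfp_forward(x, rfp_feat)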


class ResLayer(Sequential):
    """ResLayer to build ResNet style backbone for RFP in DetectoRS.

    The difference between this module and the base class is that we pass
    ``rfp_inplanes`` to the first block.

    Args:
        block (nn.Module): block used to build ResLayer.
        inplanes (int): inplanes of block.
        planes (int): planes of block.
        num_blocks (int): number of blocks.
        stride (int): stride of the first block. Default: 1
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        downsample_first (bool): Downsample at the first block or last block.
            False for Hourglass, True for ResNet. Default: True
        rfp_inplanes (int, optional): The number of channels from RFP.
            Default: None. If specified, an additional conv layer will be
            added for ``rfp_feat``. Otherwise, the structure is the same as
            the base class.
    """

    def __init__(self,
                 block,
                 inplanes,
                 planes,
                 num_blocks,
                 stride=1,
                 avg_down=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 downsample_first=True,
                 rfp_inplanes=None,
                 **kwargs):
        self.block = block
        assert downsample_first, f'downsample_first={downsample_first} is ' \
                                 'not supported in DetectoRS'

        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = []
            conv_stride = stride
            if avg_down and stride != 1:
                conv_stride = 1
                downsample.append(
                    nn.AvgPool2d(
                        kernel_size=stride,
                        stride=stride,
                        ceil_mode=True,
                        count_include_pad=False))
            downsample.extend([
                build_conv_layer(
                    conv_cfg,
                    inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=conv_stride,
                    bias=False),
                build_norm_layer(norm_cfg, planes * block.expansion)[1]
            ])
            downsample = nn.Sequential(*downsample)

        layers = []
        layers.append(
            block(
                inplanes=inplanes,
                planes=planes,
                stride=stride,
                downsample=downsample,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                rfp_inplanes=rfp_inplanes,
                **kwargs))
        inplanes = planes * block.expansion
        for _ in range(1, num_blocks):
            layers.append(
                block(
                    inplanes=inplanes,
                    planes=planes,
                    stride=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    **kwargs))

        super(ResLayer, self).__init__(*layers)
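

# A hedged sketch of how one stage could be built with this layer (values
# are illustrative, roughly matching stage 2 of a ResNet-50). Note that
# ``rfp_inplanes`` reaches only the first block, while ``sac`` (forwarded
# through ``**kwargs``) reaches every block.
#
#   layer = ResLayer(
#       block=Bottleneck,
#       inplanes=256,
#       planes=128,
#       num_blocks=4,
#       stride=2,
#       rfp_inplanes=256,
#       sac=dict(type='SAC', use_deform=True))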


@MODELS.register_module()
class DetectoRS_ResNet(ResNet):
    """ResNet backbone for DetectoRS.

    Args:
        sac (dict, optional): Dictionary to construct SAC (Switchable Atrous
            Convolution). Default: None.
        stage_with_sac (list): Which stages use SAC. Default: (False, False,
            False, False).
        rfp_inplanes (int, optional): The number of channels from RFP.
            Default: None. If specified, an additional conv layer will be
            added for ``rfp_feat``. Otherwise, the structure is the same as
            the base class.
        output_img (bool): If ``True``, the input image will be inserted at
            the beginning of the output tuple. Default: False.
    """

    arch_settings = {
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 sac=None,
                 stage_with_sac=(False, False, False, False),
                 rfp_inplanes=None,
                 output_img=False,
                 pretrained=None,
                 init_cfg=None,
                 **kwargs):
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be specified at the same time'
        self.pretrained = pretrained
        if init_cfg is not None:
            assert isinstance(init_cfg, dict), \
                f'init_cfg must be a dict, but got {type(init_cfg)}'
            if 'type' in init_cfg:
                assert init_cfg.get('type') == 'Pretrained', \
                    'init_cfg only supports initializing the module by ' \
                    'loading a pretrained model'
            else:
                raise KeyError('`init_cfg` must contain the key "type"')
            self.pretrained = init_cfg.get('checkpoint')
        self.sac = sac
        self.stage_with_sac = stage_with_sac
        self.rfp_inplanes = rfp_inplanes
        self.output_img = output_img
        super(DetectoRS_ResNet, self).__init__(**kwargs)

        # Rebuild the res layers so that ``sac`` and ``rfp_inplanes`` take
        # effect; the layers built by the base class do not receive them.
        self.inplanes = self.stem_channels
        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = self.strides[i]
            dilation = self.dilations[i]
            dcn = self.dcn if self.stage_with_dcn[i] else None
            sac = self.sac if self.stage_with_sac[i] else None
            if self.plugins is not None:
                stage_plugins = self.make_stage_plugins(self.plugins, i)
            else:
                stage_plugins = None
            planes = self.base_channels * 2**i
            res_layer = self.make_res_layer(
                block=self.block,
                inplanes=self.inplanes,
                planes=planes,
                num_blocks=num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                avg_down=self.avg_down,
                with_cp=self.with_cp,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                dcn=dcn,
                sac=sac,
                rfp_inplanes=rfp_inplanes if i > 0 else None,
                plugins=stage_plugins)
            self.inplanes = planes * self.block.expansion
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

    def init_weights(self):
        """Initialize the weights."""
        if isinstance(self.pretrained, str):
            logger = MMLogger.get_current_instance()
            load_checkpoint(
                self, self.pretrained, strict=False, logger=logger)
        elif self.pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.dcn is not None:
                for m in self.modules():
                    if isinstance(m, Bottleneck) and hasattr(
                            m.conv2, 'conv_offset'):
                        constant_init(m.conv2.conv_offset, 0)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def make_res_layer(self, **kwargs):
        """Pack all blocks in a stage into a ``ResLayer`` for DetectoRS."""
        return ResLayer(**kwargs)

    def forward(self, x):
        """Forward function."""
        outs = list(super(DetectoRS_ResNet, self).forward(x))
        if self.output_img:
            outs.insert(0, x)
        return tuple(outs)

    def rfp_forward(self, x, rfp_feats):
        """Forward function for RFP."""
        if self.deep_stem:
            x = self.stem(x)
        else:
            x = self.conv1(x)
            x = self.norm1(x)
            x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            rfp_feat = rfp_feats[i] if i > 0 else None
            for layer in res_layer:
                x = layer.rfp_forward(x, rfp_feat)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)
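

if __name__ == '__main__':
    # Hedged usage sketch: build the backbone with its usual SAC settings and
    # run both forward passes. This assumes an mmcv build that provides the
    # 'ConvAWS' and 'SAC' conv layers (and deformable conv ops for
    # ``use_deform=True``); shapes follow a 64x64 input with the default
    # out_indices=(0, 1, 2, 3).
    import torch

    backbone = DetectoRS_ResNet(
        depth=50,
        conv_cfg=dict(type='ConvAWS'),
        sac=dict(type='SAC', use_deform=True),
        stage_with_sac=(False, True, True, True),
        rfp_inplanes=256)
    img = torch.rand(1, 3, 64, 64)

    # Plain forward: one feature map per stage, from (1, 256, 16, 16) for
    # the first stage down to (1, 2048, 2, 2) for the last.
    for out in backbone(img):
        print(tuple(out.shape))

    # RFP forward: one 256-channel feedback feature per stage; the entry for
    # the first stage is unused because its blocks have no ``rfp_conv``.
    rfp_feats = [
        None,
        torch.rand(1, 256, 8, 8),
        torch.rand(1, 256, 4, 4),
        torch.rand(1, 256, 2, 2),
    ]
    for out in backbone.rfp_forward(img, rfp_feats):
        print(tuple(out.shape))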