| |
| from typing import Tuple |
|
|
| import torch.nn.functional as F |
| from mmcv.cnn import ConvModule |
| from mmcv.cnn.bricks import NonLocal2d |
| from mmengine.model import BaseModule |
| from torch import Tensor |
|
|
| from mmdet.registry import MODELS |
| from mmdet.utils import OptConfigType, OptMultiConfig |
|
|
|
|
@MODELS.register_module()
class BFP(BaseModule):
    """Balanced Feature Pyramid (BFP) neck.

    The module fuses a multi-level feature pyramid in three steps:

    1. *Gather*: every level is resized to the spatial size of the level
       selected by ``refine_level``. Levels whose index is smaller than
       ``refine_level`` are resized with adaptive max pooling, all other
       levels with nearest-neighbour interpolation.
    2. *Integrate and refine*: the resized maps are averaged into a single
       balanced feature, which is optionally passed through a refine layer.
    3. *Scatter*: the balanced feature is resized back to the size of each
       level and added to that level's input as a residual.

    It is the neck used in Libra R-CNN (CVPR 2019); see `Libra R-CNN:
    Towards Balanced Learning for Object Detection
    <https://arxiv.org/abs/1904.02701>`_ for details.

    Args:
        in_channels (int): Channel count of the input feature maps. All
            levels are expected to share the same number of channels.
        num_levels (int): Number of feature levels fed to the module.
        refine_level (int): Index (counted from bottom to top) of the level
            whose spatial size is used for integration and refinement.
            Must satisfy ``0 <= refine_level < num_levels``.
            Defaults to 2.
        refine_type (str, optional): Refine operation applied to the
            balanced feature; one of ``None``, ``'conv'`` or
            ``'non_local'``. ``None`` skips refinement. Defaults to None.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            the convolution layers of the refine op. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            the normalization layers of the refine op. Defaults to None.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or
            dict], optional): Initialization config dict. Defaults to
            uniform Xavier initialization of ``Conv2d`` layers.
    """

    def __init__(
        self,
        in_channels: int,
        num_levels: int,
        refine_level: int = 2,
        refine_type: str = None,
        conv_cfg: OptConfigType = None,
        norm_cfg: OptConfigType = None,
        init_cfg: OptMultiConfig = dict(
            type='Xavier', layer='Conv2d', distribution='uniform')
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        # Validate the configuration before any layer is built.
        assert refine_type in [None, 'conv', 'non_local']
        assert 0 <= refine_level < num_levels

        self.in_channels = in_channels
        self.num_levels = num_levels
        self.refine_level = refine_level
        self.refine_type = refine_type
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        # Both refine variants take the same conv/norm configuration.
        layer_cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg)
        if refine_type == 'non_local':
            self.refine = NonLocal2d(
                in_channels, reduction=1, use_scale=False, **layer_cfg)
        elif refine_type == 'conv':
            self.refine = ConvModule(
                in_channels, in_channels, 3, padding=1, **layer_cfg)
        # With refine_type None no ``refine`` attribute is created; forward
        # checks ``refine_type`` before touching it.

    def forward(self, inputs: Tuple[Tensor]) -> Tuple[Tensor]:
        """Balance the input pyramid and return the enhanced levels.

        Args:
            inputs (tuple[Tensor]): One feature map per level; its length
                must equal ``num_levels``.

        Returns:
            tuple[Tensor]: One tensor per level, each being the input of
            that level plus the balanced feature resized to its spatial
            size.
        """
        assert len(inputs) == self.num_levels

        level_ids = range(self.num_levels)
        target_size = inputs[self.refine_level].size()[2:]

        # Step 1: gather -- bring every level to the refine level's size
        # and average them into one balanced semantic feature.
        resized = [
            F.adaptive_max_pool2d(inputs[lvl], output_size=target_size)
            if lvl < self.refine_level else
            F.interpolate(inputs[lvl], size=target_size, mode='nearest')
            for lvl in level_ids
        ]
        balanced = sum(resized) / len(resized)

        # Step 2: refine -- optional conv / non-local refinement.
        if self.refine_type is not None:
            balanced = self.refine(balanced)

        # Step 3: scatter -- resize the balanced feature back to each level
        # (the inverse of the gather resizing) and add it as a residual.
        scattered = []
        for lvl in level_ids:
            level_feat = inputs[lvl]
            level_size = level_feat.size()[2:]
            if lvl >= self.refine_level:
                delta = F.adaptive_max_pool2d(
                    balanced, output_size=level_size)
            else:
                delta = F.interpolate(
                    balanced, size=level_size, mode='nearest')
            scattered.append(delta + level_feat)

        return tuple(scattered)
|
|