| | |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from mmengine.model import BaseModule, ModuleList, constant_init, xavier_init |
| |
|
| | from mmdet.registry import MODELS |
| | from .fpn import FPN |
| |
|
| |
|
class ASPP(BaseModule):
    """ASPP (Atrous Spatial Pyramid Pooling).

    This is an implementation of the ASPP module used in DetectoRS
    (https://arxiv.org/pdf/2006.02334.pdf).

    One convolution branch is built per entry of ``dilations``:

    - dilation > 1: a 3x3 conv whose padding equals the dilation;
    - dilation == 1: a 1x1 conv without padding.

    Every branch has stride 1, so each branch preserves the spatial size of
    its input. The last branch is applied to a globally average-pooled copy
    of the input, and its 1x1 output is broadcast back to the input's spatial
    size. The ReLU-activated branch outputs are concatenated along dim 1,
    giving ``out_channels * len(dilations)`` output channels.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of channels produced by each branch.
        dilations (tuple[int]): Dilations of the branches. Must be non-empty
            and end with 1, because the last one is the pooled branch.
            Default: (1, 3, 6, 1)
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: dict(type='Kaiming', layer='Conv2d')
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 dilations=(1, 3, 6, 1),
                 init_cfg=dict(type='Kaiming', layer='Conv2d')):
        super().__init__(init_cfg)
        # The last branch consumes the 1x1 pooled map, so it must be a plain
        # 1x1 conv (dilation 1). An empty tuple would leave no branch at all.
        assert len(dilations) > 0 and dilations[-1] == 1, \
            f'dilations must be non-empty and end with 1, got {dilations}'
        self.aspp = nn.ModuleList()
        for dilation in dilations:
            # padding == dilation keeps H and W unchanged for a 3x3 atrous
            # conv; a 1x1 conv needs no padding to do the same.
            kernel_size = 3 if dilation > 1 else 1
            padding = dilation if dilation > 1 else 0
            self.aspp.append(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=1,
                    dilation=dilation,
                    padding=padding,
                    bias=True))
        self.gap = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Run every branch and concatenate the results along dim 1.

        Args:
            x (Tensor): Input feature map, expected to be batched
                (N, C, H, W) since branch outputs are concatenated on dim 1.

        Returns:
            Tensor: ReLU-activated branch outputs concatenated along dim 1.
        """
        avg_x = self.gap(x)
        *spatial_convs, pooled_conv = self.aspp
        out = [F.relu_(conv(x)) for conv in spatial_convs]
        pooled = F.relu_(pooled_conv(avg_x))
        # Broadcast the (…, 1, 1) pooled result to the input's H, W. All other
        # branches preserve H, W, so this matches their shape, and it also
        # works when the pooled branch is the only one.
        out.append(pooled.expand(*pooled.shape[:-2], *x.shape[-2:]))
        return torch.cat(out, dim=1)
| |
|
| |
|
@MODELS.register_module()
class RFP(FPN):
    """RFP (Recursive Feature Pyramid).

    This is an implementation of RFP in `DetectoRS
    <https://arxiv.org/pdf/2006.02334.pdf>`_. Unlike a standard FPN, the
    input of RFP is the multi-level backbone features together with the
    original input image of the backbone.

    After the first FPN pass, ``rfp_steps - 1`` further steps are unrolled.
    Each step runs one backbone built from ``rfp_backbone`` via its
    ``rfp_forward(img, rfp_feats)``, passes the result through the FPN
    again, and blends it with the previous pyramid using a per-pixel gate
    ``sigmoid(rfp_weight(new))``.

    Args:
        rfp_steps (int): Number of unrolled steps of RFP.
        rfp_backbone (dict): Configuration of the backbone for RFP. The built
            module must provide ``rfp_forward`` and ``init_weights``.
        aspp_out_channels (int): Number of output channels of each ASPP
            branch.
        aspp_dilations (tuple[int]): Dilation rates of the ASPP branches.
            Default: (1, 3, 6, 1)
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Must be None. Default: None
        **kwargs: Forwarded to :class:`FPN`.
    """

    def __init__(self,
                 rfp_steps,
                 rfp_backbone,
                 aspp_out_channels,
                 aspp_dilations=(1, 3, 6, 1),
                 init_cfg=None,
                 **kwargs):
        assert init_cfg is None, 'To prevent abnormal initialization ' \
            'behavior, init_cfg is not allowed to be set'
        super().__init__(init_cfg=init_cfg, **kwargs)
        self.rfp_steps = rfp_steps
        # One independently built backbone per extra unrolled step. The first
        # step uses the features passed to forward(), so only
        # ``rfp_steps - 1`` backbones are created.
        self.rfp_modules = ModuleList()
        for _ in range(1, rfp_steps):
            self.rfp_modules.append(MODELS.build(rfp_backbone))
        # A single ASPP module, shared across all steps and pyramid levels.
        self.rfp_aspp = ASPP(self.out_channels, aspp_out_channels,
                             aspp_dilations)
        # 1x1 conv producing the one-channel gate map used in forward() to
        # blend the new pyramid with the previous one.
        self.rfp_weight = nn.Conv2d(
            self.out_channels,
            1,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)

    def init_weights(self):
        """Initialize the FPN convs, the RFP backbones and the gate conv.

        ``super().init_weights()`` is deliberately not called. NOTE(review):
        presumably this avoids overriding the RFP backbones' own
        initialization (e.g. weights from a pretrained checkpoint); confirm.
        """
        for convs in (self.lateral_convs, self.fpn_convs):
            for m in convs.modules():
                if isinstance(m, nn.Conv2d):
                    xavier_init(m, distribution='uniform')
        for rfp_module in self.rfp_modules:
            rfp_module.init_weights()
        # Zeroing the gate conv (assuming constant_init's default bias of 0)
        # makes the gate start at sigmoid(0) = 0.5, i.e. an even blend.
        constant_init(self.rfp_weight, 0)

    def forward(self, inputs):
        """Run the FPN, then ``rfp_steps - 1`` recursive refinement steps.

        Args:
            inputs (Sequence[Tensor]): The backbone input image followed by
                one feature map per entry of ``self.in_channels``.

        Returns:
            Sequence[Tensor]: Multi-level features. This is the result of the
            plain FPN pass when ``rfp_steps <= 1``; otherwise it is a list
            with one gated blend per pyramid level.
        """
        inputs = list(inputs)
        # +1 for the input image, which precedes the multi-level features.
        assert len(inputs) == len(self.in_channels) + 1, (
            f'RFP expects the input image plus {len(self.in_channels)} '
            f'feature maps, got {len(inputs)} inputs')
        img, *feats = inputs
        x = super().forward(tuple(feats))
        for rfp_module in self.rfp_modules:
            # Level 0 is fed back unchanged; higher levels go through ASPP.
            rfp_feats = [x[0]] + [self.rfp_aspp(feat) for feat in x[1:]]
            x_idx = rfp_module.rfp_forward(img, rfp_feats)
            x_idx = super().forward(x_idx)
            x_new = []
            for new_feat, old_feat in zip(x_idx, x):
                add_weight = torch.sigmoid(self.rfp_weight(new_feat))
                x_new.append(add_weight * new_feat +
                             (1 - add_weight) * old_feat)
            x = x_new
        return x
| |
|