from typing import List

import torch
import torch.nn as nn
from mmcv.cnn.bricks import Swish
from mmengine.model import BaseModule

from mmdet.registry import MODELS
from mmdet.utils import MultiConfig, OptConfigType
from .utils import DepthWiseConvBlock, DownChannelBlock, MaxPool2dSamePadding


class BiFPNStage(nn.Module):
| """ |
| in_channels: List[int], input dim for P3, P4, P5 |
| out_channels: int, output dim for P2 - P7 |
| first_time: int, whether is the first bifpnstage |
| conv_bn_act_pattern: bool, whether use conv_bn_act_pattern |
| norm_cfg: (:obj:`ConfigDict` or dict, optional): Config dict for |
| normalization layer. |
| epsilon: float, hyperparameter in fusion features |
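
    Example:
        >>> # A minimal smoke test; the input channels and sizes below are
        >>> # illustrative, not tied to a particular backbone.
        >>> import torch
        >>> stage = BiFPNStage([40, 112, 320], 64, first_time=True)
        >>> feats = [
        ...     torch.rand(1, c, s, s)
        ...     for c, s in zip([40, 112, 320], [64, 32, 16])
        ... ]
        >>> outs = stage(feats)
        >>> [o.shape[1] for o in outs]
        [64, 64, 64, 64, 64]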
| """ |

    def __init__(self,
                 in_channels: List[int],
                 out_channels: int,
                 first_time: bool = False,
                 apply_bn_for_resampling: bool = True,
                 conv_bn_act_pattern: bool = False,
                 norm_cfg: OptConfigType = dict(
                     type='BN', momentum=1e-2, eps=1e-3),
                 epsilon: float = 1e-4) -> None:
        super().__init__()
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.first_time = first_time
        self.apply_bn_for_resampling = apply_bn_for_resampling
        self.conv_bn_act_pattern = conv_bn_act_pattern
        self.norm_cfg = norm_cfg
        self.epsilon = epsilon

        if self.first_time:
            # The first stage down-channels C3-C5 from the backbone and
            # builds P6/P7 by striding down from P5.
            self.p5_down_channel = DownChannelBlock(
                self.in_channels[-1],
                self.out_channels,
                apply_norm=self.apply_bn_for_resampling,
                conv_bn_act_pattern=self.conv_bn_act_pattern,
                norm_cfg=norm_cfg)
            self.p4_down_channel = DownChannelBlock(
                self.in_channels[-2],
                self.out_channels,
                apply_norm=self.apply_bn_for_resampling,
                conv_bn_act_pattern=self.conv_bn_act_pattern,
                norm_cfg=norm_cfg)
            self.p3_down_channel = DownChannelBlock(
                self.in_channels[-3],
                self.out_channels,
                apply_norm=self.apply_bn_for_resampling,
                conv_bn_act_pattern=self.conv_bn_act_pattern,
                norm_cfg=norm_cfg)
            self.p5_to_p6 = nn.Sequential(
                DownChannelBlock(
                    self.in_channels[-1],
                    self.out_channels,
                    apply_norm=self.apply_bn_for_resampling,
                    conv_bn_act_pattern=self.conv_bn_act_pattern,
                    norm_cfg=norm_cfg), MaxPool2dSamePadding(3, 2))
            self.p6_to_p7 = MaxPool2dSamePadding(3, 2)
            # Extra lateral connections feeding the bottom-up path.
            self.p4_level_connection = DownChannelBlock(
                self.in_channels[-2],
                self.out_channels,
                apply_norm=self.apply_bn_for_resampling,
                conv_bn_act_pattern=self.conv_bn_act_pattern,
                norm_cfg=norm_cfg)
            self.p5_level_connection = DownChannelBlock(
                self.in_channels[-1],
                self.out_channels,
                apply_norm=self.apply_bn_for_resampling,
                conv_bn_act_pattern=self.conv_bn_act_pattern,
                norm_cfg=norm_cfg)

        # Top-down pathway: the coarser level is upsampled before fusion.
        self.p6_upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.p5_upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.p4_upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.p3_upsample = nn.Upsample(scale_factor=2, mode='nearest')

        # Bottom-up pathway: the finer level is downsampled before fusion.
        self.p4_down_sample = MaxPool2dSamePadding(3, 2)
        self.p5_down_sample = MaxPool2dSamePadding(3, 2)
        self.p6_down_sample = MaxPool2dSamePadding(3, 2)
        self.p7_down_sample = MaxPool2dSamePadding(3, 2)

        # Depthwise separable convs applied after each fusion node.
        self.conv6_up = DepthWiseConvBlock(
            out_channels,
            out_channels,
            apply_norm=self.apply_bn_for_resampling,
            conv_bn_act_pattern=self.conv_bn_act_pattern,
            norm_cfg=norm_cfg)
        self.conv5_up = DepthWiseConvBlock(
            out_channels,
            out_channels,
            apply_norm=self.apply_bn_for_resampling,
            conv_bn_act_pattern=self.conv_bn_act_pattern,
            norm_cfg=norm_cfg)
        self.conv4_up = DepthWiseConvBlock(
            out_channels,
            out_channels,
            apply_norm=self.apply_bn_for_resampling,
            conv_bn_act_pattern=self.conv_bn_act_pattern,
            norm_cfg=norm_cfg)
        self.conv3_up = DepthWiseConvBlock(
            out_channels,
            out_channels,
            apply_norm=self.apply_bn_for_resampling,
            conv_bn_act_pattern=self.conv_bn_act_pattern,
            norm_cfg=norm_cfg)
        self.conv4_down = DepthWiseConvBlock(
            out_channels,
            out_channels,
            apply_norm=self.apply_bn_for_resampling,
            conv_bn_act_pattern=self.conv_bn_act_pattern,
            norm_cfg=norm_cfg)
        self.conv5_down = DepthWiseConvBlock(
            out_channels,
            out_channels,
            apply_norm=self.apply_bn_for_resampling,
            conv_bn_act_pattern=self.conv_bn_act_pattern,
            norm_cfg=norm_cfg)
        self.conv6_down = DepthWiseConvBlock(
            out_channels,
            out_channels,
            apply_norm=self.apply_bn_for_resampling,
            conv_bn_act_pattern=self.conv_bn_act_pattern,
            norm_cfg=norm_cfg)
        self.conv7_down = DepthWiseConvBlock(
            out_channels,
            out_channels,
            apply_norm=self.apply_bn_for_resampling,
            conv_bn_act_pattern=self.conv_bn_act_pattern,
            norm_cfg=norm_cfg)

        # Learnable fusion weights for the top-down path; each node fuses
        # two inputs.
        self.p6_w1 = nn.Parameter(
            torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.p6_w1_relu = nn.ReLU()
        self.p5_w1 = nn.Parameter(
            torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.p5_w1_relu = nn.ReLU()
        self.p4_w1 = nn.Parameter(
            torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.p4_w1_relu = nn.ReLU()
        self.p3_w1 = nn.Parameter(
            torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.p3_w1_relu = nn.ReLU()

        # Learnable fusion weights for the bottom-up path; P4-P6 fuse
        # three inputs each, P7 fuses two.
        self.p4_w2 = nn.Parameter(
            torch.ones(3, dtype=torch.float32), requires_grad=True)
        self.p4_w2_relu = nn.ReLU()
        self.p5_w2 = nn.Parameter(
            torch.ones(3, dtype=torch.float32), requires_grad=True)
        self.p5_w2_relu = nn.ReLU()
        self.p6_w2 = nn.Parameter(
            torch.ones(3, dtype=torch.float32), requires_grad=True)
        self.p6_w2_relu = nn.ReLU()
        self.p7_w2 = nn.Parameter(
            torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.p7_w2_relu = nn.ReLU()

        self.swish = Swish()

    def combine(self, x):
        # Apply Swish to the fused features unless the conv blocks already
        # contain the activation (conv_bn_act_pattern).
        if not self.conv_bn_act_pattern:
            x = self.swish(x)

        return x

    def forward(self, x):
        if self.first_time:
            p3, p4, p5 = x
            # Build P6 and then P7 from P5 on the first stage.
            p6_in = self.p5_to_p6(p5)
            p7_in = self.p6_to_p7(p6_in)

            p3_in = self.p3_down_channel(p3)
            p4_in = self.p4_down_channel(p4)
            p5_in = self.p5_down_channel(p5)
        else:
            p3_in, p4_in, p5_in, p6_in, p7_in = x

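        # Every fusion node below uses EfficientDet's fast normalized
        # fusion: out = sum_i(relu(w_i) * in_i) / (sum_j(relu(w_j)) +
        # epsilon); the ReLU keeps the weights non-negative and the plain
        # division is a cheap substitute for softmax normalization.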
        # Top-down path.
        # Fuse P6_0 with upsampled P7_0 into P6_1.
        p6_w1 = self.p6_w1_relu(self.p6_w1)
        weight = p6_w1 / (torch.sum(p6_w1, dim=0) + self.epsilon)
        p6_up = self.conv6_up(
            self.combine(weight[0] * p6_in +
                         weight[1] * self.p6_upsample(p7_in)))

        # Fuse P5_0 with upsampled P6_1 into P5_1.
        p5_w1 = self.p5_w1_relu(self.p5_w1)
        weight = p5_w1 / (torch.sum(p5_w1, dim=0) + self.epsilon)
        p5_up = self.conv5_up(
            self.combine(weight[0] * p5_in +
                         weight[1] * self.p5_upsample(p6_up)))

        # Fuse P4_0 with upsampled P5_1 into P4_1.
        p4_w1 = self.p4_w1_relu(self.p4_w1)
        weight = p4_w1 / (torch.sum(p4_w1, dim=0) + self.epsilon)
        p4_up = self.conv4_up(
            self.combine(weight[0] * p4_in +
                         weight[1] * self.p4_upsample(p5_up)))

        # Fuse P3_0 with upsampled P4_1 into the P3 output.
        p3_w1 = self.p3_w1_relu(self.p3_w1)
        weight = p3_w1 / (torch.sum(p3_w1, dim=0) + self.epsilon)
        p3_out = self.conv3_up(
            self.combine(weight[0] * p3_in +
                         weight[1] * self.p3_upsample(p4_up)))

        if self.first_time:
            # Refresh P4/P5 through the extra lateral connections for the
            # bottom-up path.
            p4_in = self.p4_level_connection(p4)
            p5_in = self.p5_level_connection(p5)

        # Bottom-up path.
        # Fuse P4_0, P4_1 and the downsampled P3 output into the P4 output.
        p4_w2 = self.p4_w2_relu(self.p4_w2)
        weight = p4_w2 / (torch.sum(p4_w2, dim=0) + self.epsilon)
        p4_out = self.conv4_down(
            self.combine(weight[0] * p4_in + weight[1] * p4_up +
                         weight[2] * self.p4_down_sample(p3_out)))

        # Fuse P5_0, P5_1 and the downsampled P4 output into the P5 output.
        p5_w2 = self.p5_w2_relu(self.p5_w2)
        weight = p5_w2 / (torch.sum(p5_w2, dim=0) + self.epsilon)
        p5_out = self.conv5_down(
            self.combine(weight[0] * p5_in + weight[1] * p5_up +
                         weight[2] * self.p5_down_sample(p4_out)))

        # Fuse P6_0, P6_1 and the downsampled P5 output into the P6 output.
        p6_w2 = self.p6_w2_relu(self.p6_w2)
        weight = p6_w2 / (torch.sum(p6_w2, dim=0) + self.epsilon)
        p6_out = self.conv6_down(
            self.combine(weight[0] * p6_in + weight[1] * p6_up +
                         weight[2] * self.p6_down_sample(p5_out)))

        # Fuse P7_0 with the downsampled P6 output into the P7 output.
        p7_w2 = self.p7_w2_relu(self.p7_w2)
        weight = p7_w2 / (torch.sum(p7_w2, dim=0) + self.epsilon)
        p7_out = self.conv7_down(
            self.combine(weight[0] * p7_in +
                         weight[1] * self.p7_down_sample(p6_out)))

        return p3_out, p4_out, p5_out, p6_out, p7_out


@MODELS.register_module()
class BiFPN(BaseModule):
| """ |
| num_stages: int, bifpn number of repeats |
| in_channels: List[int], input dim for P3, P4, P5 |
| out_channels: int, output dim for P2 - P7 |
| start_level: int, Index of input features in backbone |
| epsilon: float, hyperparameter in fusion features |
| apply_bn_for_resampling: bool, whether use bn after resampling |
| conv_bn_act_pattern: bool, whether use conv_bn_act_pattern |
| norm_cfg: (:obj:`ConfigDict` or dict, optional): Config dict for |
| normalization layer. |
| init_cfg: MultiConfig: init method |
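
    Example:
        >>> # Minimal usage sketch; in practice this neck is built from a
        >>> # config through the MODELS registry. The channel widths and
        >>> # feature sizes below are illustrative.
        >>> import torch
        >>> neck = BiFPN(
        ...     num_stages=3, in_channels=[40, 112, 320], out_channels=64)
        >>> feats = [
        ...     torch.rand(1, c, s, s)
        ...     for c, s in zip([40, 112, 320], [64, 32, 16])
        ... ]
        >>> outs = neck(feats)
        >>> len(outs)
        5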
| """ |

    def __init__(self,
                 num_stages: int,
                 in_channels: List[int],
                 out_channels: int,
                 start_level: int = 0,
                 epsilon: float = 1e-4,
                 apply_bn_for_resampling: bool = True,
                 conv_bn_act_pattern: bool = False,
                 norm_cfg: OptConfigType = dict(
                     type='BN', momentum=1e-2, eps=1e-3),
                 init_cfg: MultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.start_level = start_level
        # Only the first stage sees raw backbone features and needs the
        # down-channel blocks; later stages consume the 5-level outputs.
        self.bifpn = nn.Sequential(*[
            BiFPNStage(
                in_channels=in_channels,
                out_channels=out_channels,
                first_time=idx == 0,
                apply_bn_for_resampling=apply_bn_for_resampling,
                conv_bn_act_pattern=conv_bn_act_pattern,
                norm_cfg=norm_cfg,
                epsilon=epsilon) for idx in range(num_stages)
        ])

    def forward(self, x):
        # Drop feature maps below start_level, then run the stacked
        # stages; the first stage expects the three backbone feature maps.
        x = x[self.start_level:]
        x = self.bifpn(x)

        return x