| | import torch |
| | import torch.nn as nn |
| | from annotator.mmpkg.mmcv import is_tuple_of |
| | from annotator.mmpkg.mmcv.cnn import ConvModule |
| |
|
| | from annotator.mmpkg.mmseg.ops import resize |
| | from ..builder import HEADS |
| | from .decode_head import BaseDecodeHead |
| |
|
| |
|
@HEADS.register_module()
class LRASPPHead(BaseDecodeHead):
    """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3.

    This head is the improved implementation of `Searching for MobileNetV3
    <https://ieeexplore.ieee.org/document/9008835>`_.

    The head consumes several feature maps (``input_transform`` must be
    ``'multiple_select'``). The last selected feature map is gated by a
    pooled, sigmoid-activated copy of itself; the result is then fused with
    the remaining feature maps, from the second-to-last down to the first.

    Args:
        branch_channels (tuple[int]): The number of output channels of each
            branch. Its length must be ``len(in_channels) - 1``.
            Default: (32, 64).
        **kwargs: Forwarded unchanged to ``BaseDecodeHead.__init__``.

    Raises:
        ValueError: If ``input_transform`` is not ``'multiple_select'``.
        AssertionError: If ``branch_channels`` is not a tuple of ints or its
            length differs from ``len(in_channels) - 1``.
    """

    def __init__(self, branch_channels=(32, 64), **kwargs):
        super(LRASPPHead, self).__init__(**kwargs)
        if self.input_transform != 'multiple_select':
            raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform '
                             f'must be \'multiple_select\'. But received '
                             f'\'{self.input_transform}\'')
        assert is_tuple_of(branch_channels, int), \
            'branch_channels must be a tuple of int'
        assert len(branch_channels) == len(self.in_channels) - 1, \
            'branch_channels must have exactly len(in_channels) - 1 entries'
        self.branch_channels = branch_channels

        # Sequential containers are used as named, indexable module lists;
        # the submodule names ('conv{i}', 'conv_up{i}') are kept as-is so
        # that state_dict keys do not change.
        self.convs = nn.Sequential()
        self.conv_ups = nn.Sequential()
        for i, branch_channel in enumerate(branch_channels):
            # 1x1 projection of the i-th input to its branch width.
            self.convs.add_module(
                f'conv{i}',
                nn.Conv2d(
                    self.in_channels[i], branch_channel, 1, bias=False))
            # 1x1 fusion of the running feature concatenated with the
            # projected branch feature, back to ``self.channels``.
            self.conv_ups.add_module(
                f'conv_up{i}',
                ConvModule(
                    self.channels + branch_channel,
                    self.channels,
                    1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    bias=False))

        self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1)

        self.aspp_conv = ConvModule(
            self.in_channels[-1],
            self.channels,
            1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            bias=False)
        # NOTE(review): the pooling window (49) and stride (16, 20) are
        # fixed; presumably inputs smaller than the window are not
        # supported -- confirm against the intended input resolutions.
        self.image_pool = nn.Sequential(
            nn.AvgPool2d(kernel_size=49, stride=(16, 20)),
            ConvModule(
                # ``image_pool`` is applied to ``inputs[-1]`` in ``forward``,
                # so its input width is that of the last input. Indexing
                # with -1 (instead of a hard-coded 2) keeps this correct for
                # any number of branches and is identical for the default
                # three-input configuration.
                self.in_channels[-1],
                self.channels,
                1,
                act_cfg=dict(type='Sigmoid'),
                bias=False))

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs: Multi-level features; passed through
                ``self._transform_inputs`` before use.

        Returns:
            The output of ``self.cls_seg`` applied to the fused feature map.
        """
        inputs = self._transform_inputs(inputs)

        x = inputs[-1]

        # Gate the projected last-level feature with the sigmoid-activated
        # pooled branch, resized back to the feature's spatial size.
        x = self.aspp_conv(x) * resize(
            self.image_pool(x),
            size=x.size()[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        x = self.conv_up_input(x)

        # Walk the branches from the last one to the first: resize the
        # running feature to the branch's spatial size, concatenate along
        # dim 1 with the projected branch feature, then fuse.
        for i in reversed(range(len(self.branch_channels))):
            x = resize(
                x,
                size=inputs[i].size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            x = torch.cat([x, self.convs[i](inputs[i])], 1)
            x = self.conv_ups[i](x)

        return self.cls_seg(x)
| |
|