| | import torch |
| | import torch.nn as nn |
| | from annotator.mmpkg.mmcv.cnn import ConvModule |
| |
|
| | from annotator.mmpkg.mmseg.ops import resize |
| | from ..builder import HEADS |
| | from .decode_head import BaseDecodeHead |
| |
|
| |
|
class ASPPModule(nn.ModuleList):
    """Atrous Spatial Pyramid Pooling (ASPP) Module.

    Holds one ``ConvModule`` branch per dilation rate:

    - A rate of 1 gives a 1x1 (pointwise) convolution with no padding.
    - Any other rate gives a 3x3 atrous convolution padded by that rate.
      Assuming ConvModule's default stride of 1 (not passed here; confirm
      in mmcv), every branch then preserves the input's spatial size.

    Args:
        dilations (tuple[int]): Dilation rate of each layer.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg):
        super().__init__()
        # Remember the constructor arguments as attributes.
        self.dilations = dilations
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        for rate in dilations:
            pointwise = rate == 1
            branch = ConvModule(
                in_channels,
                channels,
                1 if pointwise else 3,
                dilation=rate,
                padding=0 if pointwise else rate,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
            self.append(branch)

    def forward(self, x):
        """Apply every branch to the same input ``x``.

        Returns:
            list: One output per branch, in the same order as
            ``dilations``.
        """
        return [branch(x) for branch in self]
| |
|
| |
|
@HEADS.register_module()
class ASPPHead(BaseDecodeHead):
    """Rethinking Atrous Convolution for Semantic Image Segmentation.

    This head is the implementation of `DeepLabV3
    <https://arxiv.org/abs/1706.05587>`_.

    It builds the fused features in three steps:

    1. An image-level branch: global average pooling to 1x1 followed by a
       1x1 ConvModule, resized back to the feature map's spatial size.
    2. The parallel atrous branches of :class:`ASPPModule`.
    3. A 3x3 bottleneck ConvModule over the channel-wise concatenation of
       all branches.

    The fused result is then passed to ``cls_seg``.

    Args:
        dilations (tuple[int]): Dilation rates for ASPP module.
            Default: (1, 6, 12, 18).
    """

    def __init__(self, dilations=(1, 6, 12, 18), **kwargs):
        super().__init__(**kwargs)
        assert isinstance(dilations, (list, tuple))
        self.dilations = dilations
        # in_channels, channels, *_cfg and align_corners are not assigned in
        # this class; they are read from the BaseDecodeHead base after its
        # __init__ has run.
        layer_cfgs = dict(
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # Submodules are created in the order image_pool -> aspp_modules ->
        # bottleneck, matching the original registration/initialisation
        # order.
        self.image_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvModule(self.in_channels, self.channels, 1, **layer_cfgs))
        self.aspp_modules = ASPPModule(
            dilations, self.in_channels, self.channels, **layer_cfgs)
        # One image-pool branch plus one branch per dilation rate, each with
        # `self.channels` channels, are concatenated before the bottleneck.
        num_branches = len(dilations) + 1
        self.bottleneck = ConvModule(
            num_branches * self.channels,
            self.channels,
            3,
            padding=1,
            **layer_cfgs)

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs: Passed to ``self._transform_inputs`` (inherited from
                BaseDecodeHead) to obtain the feature map ``x``. ``x`` is
                presumably (N, C, H, W); confirm against the base class.

        Returns:
            The output of ``self.cls_seg`` applied to the bottleneck
            features.
        """
        x = self._transform_inputs(inputs)
        # Bring the 1x1 pooled branch back to x's trailing (spatial) size.
        pooled = resize(
            self.image_pool(x),
            size=x.size()[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        # Image-pool branch first, then the ASPP branches in dilation order.
        branches = [pooled, *self.aspp_modules(x)]
        fused = self.bottleneck(torch.cat(branches, dim=1))
        return self.cls_seg(fused)
| |
|