| import torch |
| import torch.nn as nn |
| from annotator.mmpkg.mmcv.cnn import ConvModule, DepthwiseSeparableConvModule |
|
|
| from annotator.mmpkg.mmseg.ops import resize |
| from ..builder import HEADS |
| from .aspp_head import ASPPHead, ASPPModule |
|
|
|
|
class DepthwiseSeparableASPPModule(ASPPModule):
    """ASPP module variant that uses depthwise separable convolutions.

    After the parent constructor has run, every position ``i`` whose
    ``self.dilations[i]`` is greater than 1 is overwritten with a 3x3
    ``DepthwiseSeparableConvModule``. Positions whose dilation is not
    greater than 1 keep whatever the parent constructor placed there.
    """

    def __init__(self, **kwargs):
        """Build the parent module, then swap in the separable branches.

        Args:
            **kwargs: Forwarded unchanged to ``ASPPModule.__init__``.
        """
        super().__init__(**kwargs)

        # Settings shared by every replacement branch.
        shared_cfg = dict(norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)

        for branch_idx, rate in enumerate(self.dilations):
            if not rate > 1:
                # Non-dilated branch: leave the parent's module in place.
                continue
            # Padding equals the dilation rate, which keeps the spatial size
            # for a 3x3 kernel (assuming the module's default stride of 1).
            separable_branch = DepthwiseSeparableConvModule(
                self.in_channels,
                self.channels,
                3,
                dilation=rate,
                padding=rate,
                **shared_cfg)
            self[branch_idx] = separable_branch
|
|
|
|
@HEADS.register_module()
class DepthwiseSeparableASPPHead(ASPPHead):
    """Encoder-Decoder with Atrous Separable Convolution for Semantic Image
    Segmentation.

    This head is the implementation of `DeepLabV3+
    <https://arxiv.org/abs/1802.02611>`_.

    Args:
        c1_in_channels (int): The input channels of the c1 decoder. If it is
            0, no c1 decoder branch is built and ``c1_channels`` is ignored.
        c1_channels (int): The intermediate channels of the c1 decoder.
        **kwargs: Forwarded unchanged to ``ASPPHead.__init__``.
    """

    def __init__(self, c1_in_channels, c1_channels, **kwargs):
        super(DepthwiseSeparableASPPHead, self).__init__(**kwargs)
        assert c1_in_channels >= 0

        # Replace the ASPP branches with the depthwise separable variant.
        self.aspp_modules = DepthwiseSeparableASPPModule(
            dilations=self.dilations,
            in_channels=self.in_channels,
            channels=self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        if c1_in_channels > 0:
            # 1x1 conv applied to the low-level (c1) feature map.
            self.c1_bottleneck = ConvModule(
                c1_in_channels,
                c1_channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            skip_channels = c1_channels
        else:
            self.c1_bottleneck = None
            # Without a c1 branch, ``forward`` concatenates nothing, so the
            # separable bottleneck must only expect ``self.channels`` inputs.
            # Counting ``c1_channels`` here would make ``forward`` fail with
            # a channel mismatch whenever ``c1_channels`` is non-zero.
            skip_channels = 0

        self.sep_bottleneck = nn.Sequential(
            DepthwiseSeparableConvModule(
                self.channels + skip_channels,
                self.channels,
                3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            DepthwiseSeparableConvModule(
                self.channels,
                self.channels,
                3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs: Indexable collection of feature maps. ``inputs[0]`` is
                used as the c1 feature when the c1 decoder is enabled; the
                ASPP input is selected by ``self._transform_inputs``.
                (Assumes NCHW tensors; confirm against the backbone.)

        Returns:
            The result of ``self.cls_seg`` applied to the decoded features.
        """
        x = self._transform_inputs(inputs)

        # Image-level pooling branch, resized back to the size of ``x``
        # (all dimensions from index 2 onwards).
        aspp_outs = [
            resize(
                self.image_pool(x),
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        ]
        aspp_outs.extend(self.aspp_modules(x))
        # Fuse every ASPP branch along the channel dimension.
        aspp_outs = torch.cat(aspp_outs, dim=1)
        output = self.bottleneck(aspp_outs)

        if self.c1_bottleneck is not None:
            c1_output = self.c1_bottleneck(inputs[0])
            # Resize the ASPP output to the c1 resolution before fusing.
            output = resize(
                input=output,
                size=c1_output.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            output = torch.cat([output, c1_output], dim=1)

        output = self.sep_bottleneck(output)
        output = self.cls_seg(output)
        return output
|
|