| import torch |
| import torch.nn as nn |
| from annotator.mmpkg.mmcv.cnn import ConvModule |
|
|
| from annotator.mmpkg.mmseg.ops import resize |
| from ..builder import HEADS |
| from .decode_head import BaseDecodeHead |
|
|
|
|
class PPM(nn.ModuleList):
    """Pooling Pyramid Module used in PSPNet.

    The module is a list with one branch per entry of ``pool_scales``.
    Each branch is an ``nn.Sequential`` made of
    ``nn.AdaptiveAvgPool2d(pool_scale)`` followed by a ``ConvModule`` with
    kernel size 1 mapping ``in_channels`` to ``channels``.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
        align_corners (bool): align_corners argument of F.interpolate.
    """

    def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg, align_corners):
        super().__init__()
        # Keep every constructor argument available on the instance.
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        # One pooling + 1x1 conv branch per requested scale, in order.
        for scale in pool_scales:
            pooling = nn.AdaptiveAvgPool2d(scale)
            projection = ConvModule(
                self.in_channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            self.append(nn.Sequential(pooling, projection))

    def forward(self, x):
        """Forward function.

        Runs ``x`` through every branch and resizes each branch output with
        bilinear interpolation to ``x.size()[2:]``.

        Returns:
            list: One resized output per branch, in branch order.
        """
        # The target size is the same for every branch, so read it once.
        target_size = x.size()[2:]
        return [
            resize(
                branch(x),
                size=target_size,
                mode='bilinear',
                align_corners=self.align_corners)
            for branch in self
        ]
|
|
|
|
@HEADS.register_module()
class PSPHead(BaseDecodeHead):
    """Pyramid Scene Parsing Network.

    This head is the implementation of
    `PSPNet <https://arxiv.org/abs/1612.01105>`_.

    It owns a :class:`PPM` (``self.psp_modules``) and a 3x3 ``ConvModule``
    (``self.bottleneck``) that fuses the input features with all pyramid
    outputs.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module. Default: (1, 2, 3, 6).
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super().__init__(**kwargs)
        assert isinstance(pool_scales, (list, tuple))
        self.pool_scales = pool_scales

        # Layer configs shared by the pyramid branches and the bottleneck.
        layer_cfg = dict(
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.psp_modules = PPM(
            self.pool_scales,
            self.in_channels,
            self.channels,
            align_corners=self.align_corners,
            **layer_cfg)

        # The bottleneck sees the input concatenated with one
        # ``channels``-wide map per pooling scale.
        fused_channels = self.in_channels + len(pool_scales) * self.channels
        self.bottleneck = ConvModule(
            fused_channels,
            self.channels,
            3,
            padding=1,
            **layer_cfg)

    def forward(self, inputs):
        """Forward function.

        Transforms ``inputs`` with ``self._transform_inputs``, concatenates
        the result with every pyramid output along dim 1, applies the
        bottleneck, and returns the output of ``self.cls_seg``.
        """
        feats = self._transform_inputs(inputs)
        # Input features come first, followed by the pyramid outputs.
        fused = torch.cat([feats, *self.psp_modules(feats)], dim=1)
        return self.cls_seg(self.bottleneck(fused))
|
|