| import torch |
| import torch.nn as nn |
| from annotator.mmpkg.mmcv.cnn import ConvModule |
|
|
| from annotator.mmpkg.mmseg.ops import resize |
| from ..builder import HEADS |
| from .decode_head import BaseDecodeHead |
| from .psp_head import PPM |
|
|
|
|
@HEADS.register_module()
class UPerHead(BaseDecodeHead):
    """Unified Perceptual Parsing for Scene Understanding.

    This head is the implementation of `UPerNet
    <https://arxiv.org/abs/1807.10221>`_.

    The head runs a Pooling Pyramid Module (PPM) on the last selected input
    feature, builds a top-down pathway over the remaining inputs via lateral
    1x1 convolutions, and fuses all levels with a final 3x3 bottleneck before
    classification.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module applied on the last feature. Default: (1, 2, 3, 6).
        **kwargs: Forwarded unchanged to ``BaseDecodeHead.__init__``.
            ``input_transform`` is always set to ``'multiple_select'`` by
            this class and must not be passed by the caller.
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(UPerHead, self).__init__(
            input_transform='multiple_select', **kwargs)

        # PSP module: pyramid pooling over the last input feature map.
        self.psp_modules = PPM(
            pool_scales,
            self.in_channels[-1],
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)
        # Input to the bottleneck is the last feature concatenated with one
        # `self.channels`-wide branch per pooling scale (see `psp_forward`).
        self.bottleneck = ConvModule(
            self.in_channels[-1] + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        # FPN module: one lateral (1x1) and one output (3x3) conv for every
        # input level except the last, which is handled by the PSP branch.
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for in_channels in self.in_channels[:-1]:
            l_conv = ConvModule(
                in_channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                inplace=False)
            fpn_conv = ConvModule(
                self.channels,
                self.channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                inplace=False)
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        # Fuses the channel-wise concatenation of every level's output
        # (len(in_channels) levels, each `self.channels` wide).
        self.fpn_bottleneck = ConvModule(
            len(self.in_channels) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def psp_forward(self, inputs):
        """Forward function of PSP module.

        Args:
            inputs (Sequence[Tensor]): Selected input features; only the last
                element is used.

        Returns:
            Tensor: Output of ``self.bottleneck`` applied to the last feature
            concatenated (along dim 1) with the PPM branch outputs.
        """
        x = inputs[-1]
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        output = self.bottleneck(psp_outs)

        return output

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs: Backbone features, passed through
                ``self._transform_inputs`` before use.

        Returns:
            The result of ``self.cls_seg`` applied to the fused feature map.
        """
        inputs = self._transform_inputs(inputs)

        # Build laterals: 1x1 conv on every level except the last ...
        laterals = [
            lateral_conv(inputs[i])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]
        # ... and the PSP branch output for the last level.
        laterals.append(self.psp_forward(inputs))

        # Build top-down path: add each coarser level, resized, into the
        # next finer one.
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            # Out-of-place add on purpose: an in-place `+=` would overwrite
            # the lateral conv output, which autograd may have saved for the
            # backward pass. The numerical result is identical.
            laterals[i - 1] = laterals[i - 1] + resize(
                laterals[i],
                size=prev_shape,
                mode='bilinear',
                align_corners=self.align_corners)

        # Build outputs: 3x3 conv on every level except the last.
        fpn_outs = [
            self.fpn_convs[i](laterals[i])
            for i in range(used_backbone_levels - 1)
        ]
        # Append the PSP feature as-is (it has no fpn_conv).
        fpn_outs.append(laterals[-1])

        # Bring every level to the size of the first level before fusing.
        # fpn_outs[0] itself is never resized, so its size is loop-invariant.
        target_size = fpn_outs[0].shape[2:]
        for i in range(used_backbone_levels - 1, 0, -1):
            fpn_outs[i] = resize(
                fpn_outs[i],
                size=target_size,
                mode='bilinear',
                align_corners=self.align_corners)
        fpn_outs = torch.cat(fpn_outs, dim=1)
        output = self.fpn_bottleneck(fpn_outs)
        output = self.cls_seg(output)
        return output
|
|