import torch
import torch.nn as nn
import torch.nn.functional as F
from annotator.mmpkg.mmcv.cnn import ConvModule

from annotator.mmpkg.mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead

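# PSAMask is a compiled op. Prefer an installed mmcv; fall back to the
# vendored copy, and leave PSAMask as None if neither is importable so that
# importing this module never fails on its own (PSAHead raises at
# construction time instead).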
try:
    try:
        from mmcv.ops import PSAMask
    except ImportError:
        from annotator.mmpkg.mmcv.ops import PSAMask
except ModuleNotFoundError:
    PSAMask = None


@HEADS.register_module()
class PSAHead(BaseDecodeHead):
| """Point-wise Spatial Attention Network for Scene Parsing. |
| |
| This head is the implementation of `PSANet |
| <https://hszhao.github.io/papers/eccv18_psanet.pdf>`_. |
| |
| Args: |
| mask_size (tuple[int]): The PSA mask size. It usually equals input |
| size. |
| psa_type (str): The type of psa module. Options are 'collect', |
| 'distribute', 'bi-direction'. Default: 'bi-direction' |
| compact (bool): Whether use compact map for 'collect' mode. |
| Default: True. |
| shrink_factor (int): The downsample factors of psa mask. Default: 2. |
| normalization_factor (float): The normalize factor of attention. |
| psa_softmax (bool): Whether use softmax for attention. |
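
    Example:
        >>> # A sketch with a hypothetical config; requires mmcv's compiled
        >>> # PSAMask op. ``mask_size`` must equal the feature size after
        >>> # shrinking: 32 // shrink_factor 2 = 16.
        >>> import torch
        >>> head = PSAHead(in_channels=64, channels=32, num_classes=19,
        ...                mask_size=(16, 16))
        >>> inputs = [torch.randn(1, 64, 32, 32)]
        >>> head(inputs).shape
        torch.Size([1, 19, 32, 32])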
| """ |
|
|
    def __init__(self,
                 mask_size,
                 psa_type='bi-direction',
                 compact=False,
                 shrink_factor=2,
                 normalization_factor=1.0,
                 psa_softmax=True,
                 **kwargs):
        if PSAMask is None:
            raise RuntimeError('Please install mmcv-full for PSAMask ops')
        super(PSAHead, self).__init__(**kwargs)
        assert psa_type in ['collect', 'distribute', 'bi-direction']
        self.psa_type = psa_type
        self.compact = compact
        self.shrink_factor = shrink_factor
        self.mask_size = mask_size
        mask_h, mask_w = mask_size
        self.psa_softmax = psa_softmax
        if normalization_factor is None:
            normalization_factor = mask_h * mask_w
        self.normalization_factor = normalization_factor

        self.reduce = ConvModule(
            self.in_channels,
            self.channels,
            kernel_size=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
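        # The attention branch predicts mask_h * mask_w logits per spatial
        # position, i.e. one weight for every position of the (possibly
        # shrunken) feature map.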
        self.attention = nn.Sequential(
            ConvModule(
                self.channels,
                self.channels,
                kernel_size=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            nn.Conv2d(
                self.channels, mask_h * mask_w, kernel_size=1, bias=False))
        if psa_type == 'bi-direction':
            self.reduce_p = ConvModule(
                self.in_channels,
                self.channels,
                kernel_size=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            self.attention_p = nn.Sequential(
                ConvModule(
                    self.channels,
                    self.channels,
                    kernel_size=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg),
                nn.Conv2d(
                    self.channels, mask_h * mask_w, kernel_size=1,
                    bias=False))
            self.psamask_collect = PSAMask('collect', mask_size)
            self.psamask_distribute = PSAMask('distribute', mask_size)
        else:
            self.psamask = PSAMask(psa_type, mask_size)
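        # Note: proj uses kernel_size=1 with padding=1, which zero-pads the
        # border and grows the map by 2 in each spatial dim; the resize in
        # forward() restores the expected resolution afterwards.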
        self.proj = ConvModule(
            self.channels * (2 if psa_type == 'bi-direction' else 1),
            self.in_channels,
            kernel_size=1,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.bottleneck = ConvModule(
            self.in_channels * 2,
            self.channels,
            kernel_size=3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        identity = x
        align_corners = self.align_corners
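        # Single-direction path: one reduce/attention branch; the direction
        # ('collect' vs 'distribute') is baked into the PSAMask op.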
        if self.psa_type in ['collect', 'distribute']:
            out = self.reduce(x)
            n, c, h, w = out.size()
            if self.shrink_factor != 1:
                if h % self.shrink_factor and w % self.shrink_factor:
                    h = (h - 1) // self.shrink_factor + 1
                    w = (w - 1) // self.shrink_factor + 1
                    align_corners = True
                else:
                    h = h // self.shrink_factor
                    w = w // self.shrink_factor
                    align_corners = False
                out = resize(
                    out,
                    size=(h, w),
                    mode='bilinear',
                    align_corners=align_corners)
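            # Predict an (n, h*w, h, w) attention map over the shrunken
            # grid; mask_size must equal (h, w) for the shapes to line up.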
            y = self.attention(out)
            if self.compact:
                if self.psa_type == 'collect':
                    y = y.view(n, h * w,
                               h * w).transpose(1, 2).view(n, h * w, h, w)
            else:
                y = self.psamask(y)
            if self.psa_softmax:
                y = F.softmax(y, dim=1)
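            # Aggregate: every output position is a weighted sum over all
            # h*w positions, scaled by 1 / normalization_factor.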
            out = torch.bmm(
                out.view(n, c, h * w), y.view(n, h * w, h * w)).view(
                    n, c, h, w) * (1.0 / self.normalization_factor)
        else:
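            # Bi-direction path: run the 'collect' and 'distribute' branches
            # in parallel on separately reduced features.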
            x_col = self.reduce(x)
            x_dis = self.reduce_p(x)
            n, c, h, w = x_col.size()
            if self.shrink_factor != 1:
                if h % self.shrink_factor and w % self.shrink_factor:
                    h = (h - 1) // self.shrink_factor + 1
                    w = (w - 1) // self.shrink_factor + 1
                    align_corners = True
                else:
                    h = h // self.shrink_factor
                    w = w // self.shrink_factor
                    align_corners = False
                x_col = resize(
                    x_col,
                    size=(h, w),
                    mode='bilinear',
                    align_corners=align_corners)
                x_dis = resize(
                    x_dis,
                    size=(h, w),
                    mode='bilinear',
                    align_corners=align_corners)
            y_col = self.attention(x_col)
            y_dis = self.attention_p(x_dis)
            if self.compact:
                y_dis = y_dis.view(n, h * w,
                                   h * w).transpose(1, 2).view(n, h * w, h, w)
            else:
                y_col = self.psamask_collect(y_col)
                y_dis = self.psamask_distribute(y_dis)
            if self.psa_softmax:
                y_col = F.softmax(y_col, dim=1)
                y_dis = F.softmax(y_dis, dim=1)
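            # Aggregate each branch separately, then fuse by concatenation
            # along the channel dim.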
            x_col = torch.bmm(
                x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view(
                    n, c, h, w) * (1.0 / self.normalization_factor)
            x_dis = torch.bmm(
                x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view(
                    n, c, h, w) * (1.0 / self.normalization_factor)
            out = torch.cat([x_col, x_dis], 1)
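        # Project back to in_channels, upsample to the input resolution,
        # fuse with the identity path and classify.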
        out = self.proj(out)
        out = resize(
            out,
            size=identity.shape[2:],
            mode='bilinear',
            align_corners=align_corners)
        out = self.bottleneck(torch.cat((identity, out), dim=1))
        out = self.cls_seg(out)
        return out