import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule, Scale
from torch import nn

from mmseg.core import add_prefix
from ..builder import HEADS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead


class PAM(_SelfAttentionBlock):
    """Position Attention Module (PAM).

    Args:
        in_channels (int): Input channels of key/query feature.
        channels (int): Output channels of key/query transform.
    """

    def __init__(self, in_channels, channels):
        super(PAM, self).__init__(
            key_in_channels=in_channels,
            query_in_channels=in_channels,
            channels=channels,
            out_channels=in_channels,
            share_key_query=False,
            query_downsample=None,
            key_downsample=None,
            key_query_num_convs=1,
            key_query_norm=False,
            value_out_num_convs=1,
            value_out_norm=False,
            matmul_norm=False,
            with_out=False,
            conv_cfg=None,
            norm_cfg=None,
            act_cfg=None)

        # Learnable residual weight, initialized to 0 so the module starts as
        # an identity mapping and gradually learns to mix in attention.
        self.gamma = Scale(0)

    def forward(self, x):
        """Forward function."""
        # Self-attention: the same feature map serves as both query and key.
        out = super(PAM, self).forward(x, x)

        out = self.gamma(out) + x
        return out

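# Example usage (a minimal sketch, assuming mmseg/mmcv are installed; the
# channel counts and spatial sizes below are illustrative, not taken from any
# config). Kept as a comment so importing this module has no side effects:
#
#     pam = PAM(in_channels=64, channels=16)
#     x = torch.rand(2, 64, 32, 32)
#     out = pam(x)
#     assert out.shape == x.shape  # residual connection keeps the shape
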
class CAM(nn.Module):
    """Channel Attention Module (CAM)."""

    def __init__(self):
        super(CAM, self).__init__()
        # Learnable residual weight, initialized to 0 (see PAM above).
        self.gamma = Scale(0)

    def forward(self, x):
        """Forward function."""
        batch_size, channels, height, width = x.size()
        proj_query = x.view(batch_size, channels, -1)
        proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1)
        # (C x HW) @ (HW x C) -> C x C channel affinity matrix.
        energy = torch.bmm(proj_query, proj_key)
        # Following the official DANet implementation, softmax is applied to
        # (row max - energy), which equals softmax(-energy) with the row-wise
        # max subtracted for numerical stability.
        energy_new = torch.max(
            energy, -1, keepdim=True)[0].expand_as(energy) - energy
        attention = F.softmax(energy_new, dim=-1)
        proj_value = x.view(batch_size, channels, -1)

        out = torch.bmm(attention, proj_value)
        out = out.view(batch_size, channels, height, width)

        out = self.gamma(out) + x
        return out

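# Example usage (a minimal sketch; shapes are illustrative assumptions, kept
# as a comment so importing this module has no side effects):
#
#     cam = CAM()
#     x = torch.rand(2, 64, 32, 32)
#     out = cam(x)  # same shape as x; attention mixes the 64 channels
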
@HEADS.register_module()
class DAHead(BaseDecodeHead):
    """Dual Attention Network for Scene Segmentation.

    This head is the implementation of `DANet
    <https://arxiv.org/abs/1809.02983>`_.

    Args:
        pam_channels (int): The channels of Position Attention Module (PAM).
    """

    def __init__(self, pam_channels, **kwargs):
        super(DAHead, self).__init__(**kwargs)
        self.pam_channels = pam_channels
        self.pam_in_conv = ConvModule(
            self.in_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.pam = PAM(self.channels, pam_channels)
        self.pam_out_conv = ConvModule(
            self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.pam_conv_seg = nn.Conv2d(
            self.channels, self.num_classes, kernel_size=1)

        self.cam_in_conv = ConvModule(
            self.in_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.cam = CAM()
        self.cam_out_conv = ConvModule(
            self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.cam_conv_seg = nn.Conv2d(
            self.channels, self.num_classes, kernel_size=1)

    def pam_cls_seg(self, feat):
        """PAM feature classification."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.pam_conv_seg(feat)
        return output

    def cam_cls_seg(self, feat):
        """CAM feature classification."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.cam_conv_seg(feat)
        return output

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        pam_feat = self.pam_in_conv(x)
        pam_feat = self.pam(pam_feat)
        pam_feat = self.pam_out_conv(pam_feat)
        pam_out = self.pam_cls_seg(pam_feat)

        cam_feat = self.cam_in_conv(x)
        cam_feat = self.cam(cam_feat)
        cam_feat = self.cam_out_conv(cam_feat)
        cam_out = self.cam_cls_seg(cam_feat)

        # Fuse the two branches by element-wise sum, then classify; the
        # per-branch logits are kept as auxiliary outputs for training.
        feat_sum = pam_feat + cam_feat
        pam_cam_out = self.cls_seg(feat_sum)

        return pam_cam_out, pam_out, cam_out

    def forward_test(self, inputs, img_metas, test_cfg):
        """Forward function for testing, only ``pam_cam`` is used."""
        return self.forward(inputs)[0]

    def losses(self, seg_logit, seg_label):
        """Compute ``pam_cam``, ``pam`` and ``cam`` losses."""
        pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit
        loss = dict()
        loss.update(
            add_prefix(
                super(DAHead, self).losses(pam_cam_seg_logit, seg_label),
                'pam_cam'))
        loss.update(
            add_prefix(
                super(DAHead, self).losses(pam_seg_logit, seg_label), 'pam'))
        loss.update(
            add_prefix(
                super(DAHead, self).losses(cam_seg_logit, seg_label), 'cam'))
        return loss
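
# Example config snippet (a minimal sketch of how this head is typically
# wired into an mmseg model config; the backbone channels, class count and
# norm setting below are illustrative assumptions, not taken from this file):
#
#     decode_head = dict(
#         type='DAHead',
#         in_channels=2048,
#         in_index=3,
#         channels=512,
#         pam_channels=64,
#         dropout_ratio=0.1,
#         num_classes=19,
#         norm_cfg=dict(type='SyncBN', requires_grad=True),
#         align_corners=False)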