| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from mmcv.cnn import ConvModule |
| |
|
| | from mmseg.ops import resize |
| | from ..builder import HEADS |
| | from ..utils import SelfAttentionBlock as _SelfAttentionBlock |
| | from .cascade_decode_head import BaseCascadeDecodeHead |
| |
|
| |
|
class SpatialGatherModule(nn.Module):
    """Pool pixel features into one context vector per class.

    For every class, the class score map is multiplied by ``scale`` and
    normalized with a softmax over all spatial positions. The resulting
    weights are used to average the pixel features, giving one
    ``C``-dimensional context vector per class.

    Args:
        scale (float | int): Multiplier applied to the class scores before
            the spatial softmax; larger values give sharper weights.
    """

    def __init__(self, scale):
        super(SpatialGatherModule, self).__init__()
        self.scale = scale

    def forward(self, feats, probs):
        """Forward function.

        Args:
            feats (Tensor): Pixel features whose first two dims are
                (N, C); the remaining dims are flattened and must hold as
                many elements as H * W of ``probs`` (typically
                (N, C, H, W) — TODO confirm for all callers).
            probs (Tensor): Class score maps of shape (N, K, H, W).

        Returns:
            Tensor: Per-class context of shape (N, C, K, 1).
        """
        # Unpacking four values keeps the original "probs must be 4-D" check.
        n, k, _, _ = probs.size()
        c = feats.size(1)

        # (N, K, HW): for each class, weights over all positions sum to 1.
        weights = F.softmax(self.scale * probs.view(n, k, -1), dim=2)
        # (N, C, HW) -> (N, HW, C) so the product contracts over positions.
        pixel_feats = feats.view(n, c, -1).permute(0, 2, 1)

        # (N, K, HW) @ (N, HW, C) -> (N, K, C)
        class_ctx = torch.matmul(weights, pixel_feats)
        # (N, K, C) -> (N, C, K, 1)
        return class_ctx.permute(0, 2, 1).contiguous().unsqueeze(3)
| |
|
| |
|
class ObjectAttentionBlock(_SelfAttentionBlock):
    """SelfAttentionBlock configured for OCR (object-contextual attention).

    The attention context (``in_channels`` channels, since the parent is
    built with ``out_channels=in_channels`` and ``with_out=True``) is
    concatenated with the query features and fused back to ``in_channels``
    channels by a 1x1 ``ConvModule``.

    Args:
        in_channels (int): Channels of the query/key inputs and of the
            output.
        channels (int): Intermediate channels passed to
            ``_SelfAttentionBlock``.
        scale (int): If > 1, queries are max-pooled with a kernel of this
            size (``query_downsample``); otherwise no query downsampling.
        conv_cfg (dict | None): Config of conv layers.
        norm_cfg (dict | None): Config of norm layers.
        act_cfg (dict | None): Config of activation layers.
    """

    def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg,
                 act_cfg):
        if scale > 1:
            query_downsample = nn.MaxPool2d(kernel_size=scale)
        else:
            query_downsample = None
        super(ObjectAttentionBlock, self).__init__(
            key_in_channels=in_channels,
            query_in_channels=in_channels,
            channels=channels,
            out_channels=in_channels,
            share_key_query=False,
            query_downsample=query_downsample,
            key_downsample=None,
            key_query_num_convs=2,
            key_query_norm=True,
            value_out_num_convs=1,
            value_out_norm=True,
            matmul_norm=True,
            with_out=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # Fuses [context, query_feats] (2 * in_channels) back to in_channels.
        self.bottleneck = ConvModule(
            in_channels * 2,
            in_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, query_feats, key_feats):
        """Forward function.

        Args:
            query_feats (Tensor): Query features (the pixel features in
                ``OCRHead``).
            key_feats (Tensor): Key features (the per-class context from
                ``SpatialGatherModule`` in ``OCRHead``).

        Returns:
            Tensor: Fused features with ``in_channels`` channels at the
            spatial size of ``query_feats``.
        """
        context = super(ObjectAttentionBlock,
                        self).forward(query_feats, key_feats)
        output = self.bottleneck(torch.cat([context, query_feats], dim=1))
        if self.query_downsample is not None:
            # Keep the fused result and bring it to the query resolution;
            # the resize needs an explicit target size.
            # NOTE(review): the concatenation above already requires
            # ``context`` to share ``query_feats``' spatial size — confirm
            # ``_SelfAttentionBlock`` restores the query resolution when
            # ``query_downsample`` is set, otherwise ``scale > 1`` fails.
            output = resize(
                output,
                size=query_feats.shape[2:],
                mode='bilinear',
                align_corners=False)
        return output
| |
|
| |
|
@HEADS.register_module()
class OCRHead(BaseCascadeDecodeHead):
    """Object-Contextual Representations for Semantic Segmentation.

    This head is the implementation of `OCRNet
    <https://arxiv.org/abs/1909.11065>`_. Backbone features are projected
    by a 3x3 conv, pooled into per-class context vectors using
    ``prev_output`` (the output of the preceding head in the cascade,
    presumably per-class logits — verify against the cascade config), then
    refined by object attention before the final classifier.

    Args:
        ocr_channels (int): The intermediate channels of OCR block.
        scale (int): Scale applied to ``prev_output`` inside
            SpatialGatherModule; also forwarded to ObjectAttentionBlock,
            which max-pools its queries when ``scale > 1``. Default: 1.
    """

    def __init__(self, ocr_channels, scale=1, **kwargs):
        super(OCRHead, self).__init__(**kwargs)
        self.ocr_channels = ocr_channels
        self.scale = scale
        # Shared conv/norm/activation configuration for the sub-layers.
        layer_cfgs = dict(
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # Submodule creation order is significant: it determines the order
        # of random weight initialisation and of state_dict keys.
        self.object_context_block = ObjectAttentionBlock(
            self.channels, self.ocr_channels, self.scale, **layer_cfgs)
        self.spatial_gather_module = SpatialGatherModule(self.scale)
        # 3x3 conv (padding 1) mapping in_channels -> channels.
        self.bottleneck = ConvModule(
            self.in_channels, self.channels, 3, padding=1, **layer_cfgs)

    def forward(self, inputs, prev_output):
        """Forward function."""
        pixel_feats = self.bottleneck(self._transform_inputs(inputs))
        class_context = self.spatial_gather_module(pixel_feats, prev_output)
        refined = self.object_context_block(pixel_feats, class_context)
        return self.cls_seg(refined)
| |
|