| | import math |
| |
|
| | import torch |
| | import torch.distributed as dist |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from annotator.mmpkg.mmcv.cnn import ConvModule |
| |
|
| | from ..builder import HEADS |
| | from .decode_head import BaseDecodeHead |
| |
|
| |
|
def reduce_mean(tensor):
    """Average ``tensor`` over all processes when running distributed.

    If ``torch.distributed`` is unavailable or no process group has been
    initialized, the input object itself is returned untouched. Otherwise a
    copy of the input is divided by the world size and summed across ranks
    with ``all_reduce``; the caller's tensor is never modified.
    """
    distributed = dist.is_available() and dist.is_initialized()
    if not distributed:
        return tensor
    world_size = dist.get_world_size()
    # Divide before the SUM all-reduce so the result is the cross-rank mean.
    averaged = tensor.clone().div_(world_size)
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    return averaged
| |
|
| |
|
class EMAModule(nn.Module):
    """Expectation Maximization Attention Module used in EMANet.

    Holds ``num_bases`` bases in the ``bases`` buffer, shaped
    ``(1, channels, num_bases)`` and L2-normalised along the channel dim at
    construction. Each forward pass does the following:

    * Runs ``num_stages`` EM iterations (with gradient tracking disabled),
      starting from those bases.
    * Returns the input re-expressed as an attention-weighted combination
      of the refined bases.
    * In training mode only, blends the batch-averaged refined bases into
      the buffer with a momentum update.

    Args:
        channels (int): Channels of the whole module.
        num_bases (int): Number of bases.
        num_stages (int): Number of the EM iterations. Must be at least 1.
        momentum (float): Weight given to the newly estimated bases in the
            training-time update
            ``bases = (1 - momentum) * bases + momentum * new_bases``.
    """

    def __init__(self, channels, num_bases, num_stages, momentum):
        super(EMAModule, self).__init__()
        assert num_stages >= 1, 'num_stages must be at least 1!'
        self.num_bases = num_bases
        self.num_stages = num_stages
        self.momentum = momentum

        # Random init with std sqrt(2 / num_bases), then L2-normalise along
        # the channel dim so every basis starts as a unit vector.
        bases = torch.zeros(1, channels, self.num_bases)
        bases.normal_(0, math.sqrt(2. / self.num_bases))
        bases = F.normalize(bases, dim=1, p=2)
        self.register_buffer('bases', bases)

    def forward(self, feats):
        """Forward function.

        Args:
            feats (torch.Tensor): Input of shape
                ``(batch_size, channels, height, width)``. The input does not
                need to be contiguous.

        Returns:
            torch.Tensor: Reconstructed features with the same shape as
            ``feats``.
        """
        batch_size, channels, height, width = feats.size()
        # [batch_size, channels, height*width]. Use reshape rather than
        # view so that non-contiguous inputs (e.g. transposed maps) work.
        feats = feats.reshape(batch_size, channels, height * width)
        # [batch_size, channels, num_bases]
        bases = self.bases.repeat(batch_size, 1, 1)

        # The EM iterations are not back-propagated through.
        with torch.no_grad():
            for _ in range(self.num_stages):
                # E-step: [batch_size, height*width, num_bases], with a
                # softmax over the bases.
                attention = torch.einsum('bcn,bck->bnk', feats, bases)
                attention = F.softmax(attention, dim=2)
                # L1-normalise over the pixel dim, so each basis below is a
                # weighted mean of pixel features.
                attention_normed = F.normalize(attention, dim=1, p=1)
                # M-step: [batch_size, channels, num_bases]
                bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
                # L2-normalise every basis along the channel dim.
                bases = F.normalize(bases, dim=1, p=2)

        feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)
        feats_recon = feats_recon.reshape(batch_size, channels, height, width)

        if self.training:
            # Average the refined bases over the batch and, when running
            # distributed, across processes. Then re-normalise and blend
            # them into the running buffer.
            bases = bases.mean(dim=0, keepdim=True)
            bases = reduce_mean(bases)
            bases = F.normalize(bases, dim=1, p=2)
            self.bases = (1 -
                          self.momentum) * self.bases + self.momentum * bases

        return feats_recon
| |
|
| |
|
@HEADS.register_module()
class EMAHead(BaseDecodeHead):
    """Decode head of EMANet (Expectation Maximization Attention Networks).

    Implements `EMANet <https://arxiv.org/abs/1907.13426>`_ for semantic
    segmentation.

    Args:
        ema_channels (int): Number of channels the EMA module works in.
        num_bases (int): Number of bases.
        num_stages (int): Number of the EM iterations.
        concat_input (bool): Whether to concatenate the head input with the
            bottleneck output before the classification layer.
            Default: True.
        momentum (float): Momentum used to update the bases. Default: 0.1.
    """

    def __init__(self,
                 ema_channels,
                 num_bases,
                 num_stages,
                 concat_input=True,
                 momentum=0.1,
                 **kwargs):
        super().__init__(**kwargs)
        self.ema_channels = ema_channels
        self.num_bases = num_bases
        self.num_stages = num_stages
        self.concat_input = concat_input
        self.momentum = momentum
        self.ema_module = EMAModule(ema_channels, num_bases, num_stages,
                                    momentum)

        conv_cfg = self.conv_cfg
        norm_cfg = self.norm_cfg
        act_cfg = self.act_cfg

        # 3x3 conv (with norm + activation) projecting the input into the
        # EMA channel space.
        self.ema_in_conv = ConvModule(
            self.in_channels,
            ema_channels,
            kernel_size=3,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # Plain 1x1 conv (no norm, no activation) feeding the EMA module.
        # Its parameters are frozen.
        self.ema_mid_conv = ConvModule(
            ema_channels,
            ema_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=None,
            act_cfg=None)
        self.ema_mid_conv.requires_grad_(False)
        # 1x1 conv with norm but no activation, applied to the
        # reconstruction before the residual sum.
        self.ema_out_conv = ConvModule(
            ema_channels,
            ema_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)
        self.bottleneck = ConvModule(
            ema_channels,
            self.channels,
            kernel_size=3,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        if self.concat_input:
            self.conv_cat = ConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=3,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        shortcut = self.ema_in_conv(x)
        recon = self.ema_module(self.ema_mid_conv(shortcut))
        recon = self.ema_out_conv(F.relu(recon, inplace=True))
        # Residual connection around the EMA branch.
        fused = self.bottleneck(F.relu(shortcut + recon, inplace=True))
        if self.concat_input:
            fused = self.conv_cat(torch.cat([x, fused], dim=1))
        return self.cls_seg(fused)
| |
|