import torch
import torch.nn as nn
import torch.nn.functional as F
from annotator.mmpkg.mmcv.cnn import ConvModule, build_norm_layer

from annotator.mmpkg.mmseg.ops import Encoding, resize
from ..builder import HEADS, build_loss
from .decode_head import BaseDecodeHead


class EncModule(nn.Module):
    """Encoding Module used in EncNet.

    Args:
        in_channels (int): Input channels.
        num_codes (int): Number of code words.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
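
    Example (an illustrative shape sketch; the configs and sizes below are
    arbitrary assumptions, not values from any particular model):

        >>> import torch
        >>> enc = EncModule(64, num_codes=32, conv_cfg=None,
        ...                 norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'))
        >>> x = torch.rand(2, 64, 16, 16)
        >>> encoding_feat, output = enc(x)
        >>> encoding_feat.shape  # one global descriptor value per channel
        torch.Size([2, 64])
        >>> output.shape  # the input, re-weighted channel-wise
        torch.Size([2, 64, 16, 16])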
| | """ |

    def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg):
        super(EncModule, self).__init__()
        self.encoding_project = ConvModule(
            in_channels,
            in_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # The Encoding layer outputs a 3-D tensor of shape
        # (batch, num_codes, channels), so any 2-D norm in ``norm_cfg`` must
        # be swapped for its 1-D counterpart.
        if norm_cfg is not None:
            encoding_norm_cfg = norm_cfg.copy()
            if encoding_norm_cfg['type'] in ['BN', 'IN']:
                encoding_norm_cfg['type'] += '1d'
            else:
                encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace(
                    '2d', '1d')
        else:
            # Fall back to BN1d when no norm config is provided.
            encoding_norm_cfg = dict(type='BN1d')
        self.encoding = nn.Sequential(
            Encoding(channels=in_channels, num_codes=num_codes),
            build_norm_layer(encoding_norm_cfg, num_codes)[1],
            nn.ReLU(inplace=True))
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels), nn.Sigmoid())

    def forward(self, x):
        """Forward function."""
        encoding_projection = self.encoding_project(x)
        # Average over the codewords to obtain one descriptor per channel.
        encoding_feat = self.encoding(encoding_projection).mean(dim=1)
        batch_size, channels, _, _ = x.size()
        # Predict channel-wise attention weights in (0, 1) and rescale the
        # input with them, keeping an identity shortcut.
        gamma = self.fc(encoding_feat)
        y = gamma.view(batch_size, channels, 1, 1)
        output = F.relu_(x + x * y)
        return encoding_feat, output


@HEADS.register_module()
class EncHead(BaseDecodeHead):
    """Context Encoding for Semantic Segmentation.

    This head is the implementation of `EncNet
    <https://arxiv.org/abs/1803.08904>`_.

    Args:
        num_codes (int): Number of code words. Default: 32.
        use_se_loss (bool): Whether to use the Semantic Encoding loss
            (SE-loss) to regularize training. Default: True.
        add_lateral (bool): Whether to use lateral connections to fuse
            lower-level features. Default: False.
        loss_se_decode (dict): Config of the SE-loss.
            Default: dict(type='CrossEntropyLoss', use_sigmoid=True,
            loss_weight=0.2).
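
    Example (a hypothetical usage sketch; the channel sizes and class count
    are assumptions for illustration only):

        >>> import torch
        >>> head = EncHead(
        ...     in_channels=[256, 512],
        ...     in_index=(0, 1),
        ...     channels=128,
        ...     num_classes=19,
        ...     add_lateral=True,
        ...     norm_cfg=dict(type='BN'))
        >>> feats = [torch.rand(2, 256, 64, 64), torch.rand(2, 512, 32, 32)]
        >>> seg_logit, se_logit = head(feats)
        >>> seg_logit.shape  # dense per-pixel class scores
        torch.Size([2, 19, 32, 32])
        >>> se_logit.shape  # image-level class-presence logits
        torch.Size([2, 19])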
| | """ |

    def __init__(self,
                 num_codes=32,
                 use_se_loss=True,
                 add_lateral=False,
                 loss_se_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=0.2),
                 **kwargs):
        super(EncHead, self).__init__(
            input_transform='multiple_select', **kwargs)
        self.use_se_loss = use_se_loss
        self.add_lateral = add_lateral
        self.num_codes = num_codes
        self.bottleneck = ConvModule(
            self.in_channels[-1],
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if add_lateral:
            self.lateral_convs = nn.ModuleList()
            for in_channels in self.in_channels[:-1]:
                self.lateral_convs.append(
                    ConvModule(
                        in_channels,
                        self.channels,
                        1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
            self.fusion = ConvModule(
                len(self.in_channels) * self.channels,
                self.channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        self.enc_module = EncModule(
            self.channels,
            num_codes=num_codes,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if self.use_se_loss:
            self.loss_se_decode = build_loss(loss_se_decode)
            self.se_layer = nn.Linear(self.channels, self.num_classes)

    def forward(self, inputs):
        """Forward function."""
        inputs = self._transform_inputs(inputs)
        feat = self.bottleneck(inputs[-1])
        if self.add_lateral:
            # Upsample the lower-level features to the size of the deepest
            # feature map, then fuse by concatenation + 3x3 conv.
            laterals = [
                resize(
                    lateral_conv(inputs[i]),
                    size=feat.shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners)
                for i, lateral_conv in enumerate(self.lateral_convs)
            ]
            feat = self.fusion(torch.cat([feat, *laterals], 1))
        encode_feat, output = self.enc_module(feat)
        output = self.cls_seg(output)
        if self.use_se_loss:
            # The SE branch predicts image-level class-presence logits from
            # the global encoding descriptor.
            se_output = self.se_layer(encode_feat)
            return output, se_output
        else:
            return output

    def forward_test(self, inputs, img_metas, test_cfg):
        """Forward function for testing; the SE branch output is discarded."""
        if self.use_se_loss:
            return self.forward(inputs)[0]
        else:
            return self.forward(inputs)

    @staticmethod
    def _convert_to_onehot_labels(seg_label, num_classes):
        """Convert a segmentation label map to one-hot image-level labels.

        Args:
            seg_label (Tensor): Segmentation label of shape (N, H, W).
            num_classes (int): Number of classes.

        Returns:
            Tensor: One-hot labels of shape (N, num_classes).
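
        Example (illustrative):

            >>> import torch
            >>> seg_label = torch.tensor([[[0, 1], [1, 1]]])  # classes {0, 1}
            >>> EncHead._convert_to_onehot_labels(seg_label, num_classes=3)
            tensor([[1, 1, 0]])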
| | """ |
| |
|
| | batch_size = seg_label.size(0) |
| | onehot_labels = seg_label.new_zeros((batch_size, num_classes)) |
| | for i in range(batch_size): |
| | hist = seg_label[i].float().histc( |
| | bins=num_classes, min=0, max=num_classes - 1) |
| | onehot_labels[i] = hist > 0 |
| | return onehot_labels |

    def losses(self, seg_logit, seg_label):
        """Compute segmentation and semantic encoding loss."""
        seg_logit, se_seg_logit = seg_logit
        loss = dict()
        loss.update(super(EncHead, self).losses(seg_logit, seg_label))
        # The SE-loss supervises the image-level predictions with multi-label
        # targets derived from the segmentation label map.
        se_loss = self.loss_se_decode(
            se_seg_logit,
            self._convert_to_onehot_labels(seg_label, self.num_classes))
        loss['loss_se'] = se_loss
        return loss