| import torch |
| from mmcv.cnn import NonLocal2d |
|
|
| from ..builder import HEADS |
| from .fcn_head import FCNHead |
|
|
|
|
@HEADS.register_module()
class NLHead(FCNHead):
    """Decode head that inserts a non-local block between two convolutions.

    Implementation of the head from `Non-local Neural Networks
    <https://arxiv.org/abs/1711.07971>`_. The parent ``FCNHead`` is always
    built with ``num_convs=2``; a ``NonLocal2d`` block is applied to the
    output of the first conv before it is passed to the second one.

    Args:
        reduction (int): Reduction factor of projection transform.
            Default: 2.
        use_scale (bool): Whether to scale pairwise_weight by
            sqrt(1/inter_channels). Default: True.
        mode (str): The nonlocal mode. Options are 'embedded_gaussian',
            'dot_product'. Default: 'embedded_gaussian'.
        **kwargs: Any remaining keyword arguments, forwarded unchanged to
            ``FCNHead.__init__``.
    """

    def __init__(self,
                 reduction=2,
                 use_scale=True,
                 mode='embedded_gaussian',
                 **kwargs):
        # forward() indexes convs[0] and convs[1], so the parent head is
        # always asked for exactly two convolutions.
        super().__init__(num_convs=2, **kwargs)

        # Keep the non-local hyper-parameters on the instance.
        self.reduction, self.use_scale, self.mode = reduction, use_scale, mode

        # The non-local block works on ``self.channels`` and reuses the
        # conv/norm configuration set up by the parent constructor.
        non_local_kwargs = dict(
            in_channels=self.channels,
            reduction=reduction,
            use_scale=use_scale,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            mode=mode,
        )
        self.nl_block = NonLocal2d(**non_local_kwargs)

    def forward(self, inputs):
        """Forward function.

        Runs ``inputs`` through ``_transform_inputs``, then conv ->
        non-local block -> conv. When ``self.concat_input`` is truthy, the
        transformed input is concatenated with that result along dim 1 and
        passed through ``self.conv_cat``. The outcome is fed to
        ``self.cls_seg``, whose result is returned.
        """
        feats = self._transform_inputs(inputs)

        first_conv, second_conv = self.convs[0], self.convs[1]
        hidden = second_conv(self.nl_block(first_conv(feats)))

        if self.concat_input:
            # Skip connection: fuse the transformed input with the
            # non-local features before classification.
            fused = torch.cat([feats, hidden], dim=1)
            hidden = self.conv_cat(fused)

        return self.cls_seg(hidden)
|
|