| | import torch |
| |
|
| | from ..builder import HEADS |
| | from .fcn_head import FCNHead |
| |
|
| | try: |
| | from mmcv.ops import CrissCrossAttention |
| | except ModuleNotFoundError: |
| | CrissCrossAttention = None |
| |
|
| |
|
@HEADS.register_module()
class CCHead(FCNHead):
    """Decode head for CCNet: Criss-Cross Attention for Semantic Segmentation.

    Implementation of `CCNet <https://arxiv.org/abs/1811.11721>`_. The head
    is an :class:`FCNHead` built with exactly two convolutions. A single
    Criss-Cross Attention module is applied ``recurrence`` times between
    the first and the second convolution.

    Args:
        recurrence (int): How many times the Criss-Cross Attention module
            is applied to the features. Default: 2.
        **kwargs: Forwarded unchanged to :class:`FCNHead`. ``num_convs`` is
            always passed as 2 by this head, so supplying it here raises a
            ``TypeError`` (duplicate keyword argument).

    Raises:
        RuntimeError: If ``CrissCrossAttention`` could not be imported from
            ``mmcv.ops`` when this module was loaded.
    """

    def __init__(self, recurrence=2, **kwargs):
        # Fail fast, before any layers are built, if the op is unavailable.
        if CrissCrossAttention is None:
            raise RuntimeError(
                'Please install mmcv-full for CrissCrossAttention ops')
        super().__init__(num_convs=2, **kwargs)
        self.recurrence = recurrence
        # ``self.channels`` is not assigned in this class; it is set by the
        # parent constructor called above.
        self.cca = CrissCrossAttention(self.channels)

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs: Input features, selected/merged by
                ``self._transform_inputs`` (inherited from a parent class).

        Returns:
            The output of ``self.cls_seg`` (presumably per-pixel class
            logits; defined in a parent class).
        """
        feats = self._transform_inputs(inputs)
        attended = self.convs[0](feats)
        # The same attention module (shared weights) is applied repeatedly.
        for _step in range(self.recurrence):
            attended = self.cca(attended)
        attended = self.convs[1](attended)
        if self.concat_input:
            # Fuse the head's input with the attended features along dim 1
            # (presumably the channel dimension).
            attended = self.conv_cat(torch.cat([feats, attended], dim=1))
        return self.cls_seg(attended)
| |
|