| import torch |
| import torch.nn as nn |
| from mmcv.cnn import ConvModule |
|
|
| from ..builder import HEADS |
| from .decode_head import BaseDecodeHead |
|
|
|
|
@HEADS.register_module()
class FCNHead(BaseDecodeHead):
    """Fully Convolution Networks for Semantic Segmentation.

    This head is an implementation of
    `FCNNet <https://arxiv.org/abs/1411.4038>`_.

    Args:
        num_convs (int): Number of convs in the head. Default: 2.
            When 0, the conv stack is an identity mapping and
            ``in_channels`` must equal ``channels``.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        concat_input (bool): Whether to concat the input and output of convs
            before the classification layer. Default: True.
        dilation (int): The dilation rate for the stacked convs in the head
            (not applied to the ``conv_cat`` fusion conv). Default: 1.
    """

    def __init__(self,
                 num_convs=2,
                 kernel_size=3,
                 concat_input=True,
                 dilation=1,
                 **kwargs):
        assert num_convs >= 0
        assert isinstance(dilation, int) and dilation > 0
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        self.dilation = dilation
        super(FCNHead, self).__init__(**kwargs)

        if num_convs == 0:
            # With no convs the stack is an identity, so the feature fed to
            # conv_cat / cls_seg must already have ``channels`` channels.
            assert self.in_channels == self.channels
            # Build nothing here: previously a first ConvModule was created
            # and then discarded, wasting construction/initialization work.
            self.convs = nn.Identity()
        else:
            # A dilated k x k conv spans dilation * (k - 1) + 1 pixels, so a
            # padding of (k // 2) * dilation keeps H and W unchanged for odd
            # k. NOTE: for even k each conv grows the output by `dilation`
            # pixels, which makes torch.cat in forward() fail when
            # concat_input is True.
            conv_padding = (kernel_size // 2) * dilation
            # The first conv maps in_channels -> channels; the remaining
            # num_convs - 1 convs map channels -> channels.
            convs = [
                ConvModule(
                    self.in_channels if idx == 0 else self.channels,
                    self.channels,
                    kernel_size=kernel_size,
                    padding=conv_padding,
                    dilation=dilation,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg) for idx in range(num_convs)
            ]
            self.convs = nn.Sequential(*convs)

        if self.concat_input:
            # Fuses the raw input with the conv-stack output (concatenated
            # along dim 1) back down to ``channels`` channels.
            self.conv_cat = ConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs: Input features; ``self._transform_inputs`` (inherited from
                ``BaseDecodeHead``) turns them into the single tensor that is
                processed here.

        Returns:
            The output of ``self.cls_seg``, the classification layer
            inherited from ``BaseDecodeHead``.
        """
        x = self._transform_inputs(inputs)
        output = self.convs(x)
        if self.concat_input:
            # Concatenates along dim 1 — presumably the channel axis of an
            # (N, C, H, W) feature map; TODO confirm against BaseDecodeHead.
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
|
|