| | import torch |
| | import torch.nn as nn |
| | from annotator.mmpkg.mmcv.cnn import ConvModule |
| |
|
| | from ..builder import HEADS |
| | from .decode_head import BaseDecodeHead |
| |
|
| |
|
@HEADS.register_module()
class FCNHead(BaseDecodeHead):
    """Fully Convolution Networks for Semantic Segmentation.

    This head is the implementation of
    `FCNNet <https://arxiv.org/abs/1411.4038>`_.

    Args:
        num_convs (int): Number of convs in the head. Default: 2.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        concat_input (bool): Whether to concat the input and output of convs
            before the classification layer. Default: True.
        dilation (int): The dilation rate for convs in the head. Default: 1.
        **kwargs: Forwarded unchanged to ``BaseDecodeHead.__init__``.
    """

    def __init__(self,
                 num_convs=2,
                 kernel_size=3,
                 concat_input=True,
                 dilation=1,
                 **kwargs):
        assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int)
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        # The base class is expected to populate in_channels, channels,
        # conv_cfg, norm_cfg and act_cfg, all of which are read below.
        super(FCNHead, self).__init__(**kwargs)
        if num_convs == 0:
            # With no convs the features pass straight through, so the input
            # must already have ``channels`` channels.
            assert self.in_channels == self.channels

        # Padding that preserves spatial size for odd kernel sizes (at the
        # default stride of 1), scaled by the dilation rate.
        conv_padding = (kernel_size // 2) * dilation
        if num_convs == 0:
            self.convs = nn.Identity()
        else:
            # Build the conv stack only when it is actually used: the first
            # conv maps in_channels -> channels, the rest channels -> channels.
            convs = [
                ConvModule(
                    self.in_channels if idx == 0 else self.channels,
                    self.channels,
                    kernel_size=kernel_size,
                    padding=conv_padding,
                    dilation=dilation,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg) for idx in range(num_convs)
            ]
            self.convs = nn.Sequential(*convs)
        if self.concat_input:
            # Fuses the raw input with the conv-stack output (hence
            # in_channels + channels inputs). No dilation is passed here, so
            # plain kernel_size // 2 padding is used.
            self.conv_cat = ConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs: Features handed to the head; reduced to a single tensor
                by the base class's ``_transform_inputs`` (presumably a list
                of multi-level feature maps -- confirm against
                BaseDecodeHead).

        Returns:
            The result of ``cls_seg`` applied to the head features.
        """
        x = self._transform_inputs(inputs)
        output = self.convs(x)
        if self.concat_input:
            # dim=1 is presumably the channel axis (NCHW), matching the
            # in_channels + channels input width of conv_cat.
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
| |
|