| | import numpy as np |
| | import torch.nn as nn |
| | from annotator.mmpkg.mmcv.cnn import ConvModule |
| |
|
| | from annotator.mmpkg.mmseg.ops import resize |
| | from ..builder import HEADS |
| | from .decode_head import BaseDecodeHead |
| |
|
| |
|
@HEADS.register_module()
class FPNHead(BaseDecodeHead):
    """Panoptic Feature Pyramid Networks.

    This head is the implementation of `Semantic FPN
    <https://arxiv.org/abs/1901.02446>`_.

    Each selected input level gets its own "scale head": a stack of 3x3
    ``ConvModule`` layers, where every conv is followed by a 2x bilinear
    ``nn.Upsample`` for levels whose stride differs from the first one.
    In ``forward`` the per-level outputs are resized to the first level's
    spatial size, summed, and passed to ``cls_seg``.

    Args:
        feature_strides (tuple[int]): The strides of the input feature
            maps, one per entry of ``in_channels``. All strides are
            supposed to be powers of 2, and the first one must be the
            smallest (i.e. the level with the largest resolution).
    """

    def __init__(self, feature_strides, **kwargs):
        super(FPNHead, self).__init__(
            input_transform='multiple_select', **kwargs)
        assert len(feature_strides) == len(self.in_channels), (
            'feature_strides and in_channels must have the same length')
        assert min(feature_strides) == feature_strides[0], (
            'the first feature stride must be the smallest')
        self.feature_strides = feature_strides

        base_stride = feature_strides[0]
        self.scale_heads = nn.ModuleList()
        for level_in_channels, stride in zip(self.in_channels,
                                             feature_strides):
            # Number of conv stages; for stride ratio 2**k there are k
            # stages, each upsampling 2x, so the level ends up at the first
            # level's resolution. At least one conv is always applied.
            head_length = self._head_length(stride, base_stride)
            needs_upsample = stride != base_stride
            scale_head = []
            for k in range(head_length):
                scale_head.append(
                    ConvModule(
                        level_in_channels if k == 0 else self.channels,
                        self.channels,
                        3,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
                if needs_upsample:
                    scale_head.append(
                        nn.Upsample(
                            scale_factor=2,
                            mode='bilinear',
                            align_corners=self.align_corners))
            self.scale_heads.append(nn.Sequential(*scale_head))

    @staticmethod
    def _head_length(stride, base_stride):
        """Return ``max(1, floor(log2(stride / base_stride)))``.

        Computed with integer arithmetic instead of
        ``int(np.log2(stride) - np.log2(base_stride))``: the float
        difference can round to just below an integer when the strides are
        not exact powers of 2, and ``int()`` would then truncate one stage
        short. For power-of-2 strides both forms give the same result.
        """
        # floor(log2(x)) == floor(log2(floor(x))) for x >= 1, and for a
        # positive int n, floor(log2(n)) == n.bit_length() - 1.
        ratio = int(stride // base_stride)
        return max(1, ratio.bit_length() - 1)

    def forward(self, inputs):
        """Fuse the multi-level features and predict segmentation logits.

        Args:
            inputs: Backbone outputs. ``self._transform_inputs`` (defined in
                ``BaseDecodeHead``, not shown here) selects the levels this
                head consumes. Presumably a list of (N, C, H, W) tensors
                ordered from the highest resolution down; confirm against
                the caller.

        Returns:
            The result of ``self.cls_seg`` applied to the summed features.
        """
        x = self._transform_inputs(inputs)

        output = self.scale_heads[0](x[0])
        for i in range(1, len(self.feature_strides)):
            # Deliberately out-of-place (not ``+=``), presumably so a tensor
            # autograd may have saved for backward is never mutated.
            output = output + resize(
                self.scale_heads[i](x[i]),
                size=output.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)

        output = self.cls_seg(output)
        return output
| |
|