Spaces:
Runtime error
Runtime error
| # Copyright (c) Open-CD. All rights reserved. | |
| import torch | |
| import torch.nn as nn | |
| from mmcv.cnn import ConvModule | |
| from mmseg.models.decode_heads.decode_head import BaseDecodeHead | |
| from mmseg.models.utils import resize | |
| from opencd.registry import MODELS | |
# NOTE(review): `MODELS` is imported by this file but this class is not
# decorated with `@MODELS.register_module()` in the visible code -- confirm
# that registration happens elsewhere before relying on config lookup.
class TinyHead(BaseDecodeHead):
    """Decode head of `TinyCDv2 <https://arxiv.org/abs/>`_.

    Every selected input feature map is projected to ``self.channels`` with
    a 1x1 ``ConvModule``. The projections are bilinearly resized to the
    resolution of the first (largest) map and summed. Optionally the sum is
    gated by an attention map derived from an extra, earlier feature map
    before classification.

    Args:
        feature_strides (tuple[int]): The strides of the input feature maps,
            one per entry of ``in_channels``. The first entry must be the
            smallest stride (i.e. the largest resolution). All strides are
            expected to be powers of 2. Only the length and ordering are
            used by this head.
        priori_attn (bool): Whether to use the Priori Guiding Connection.
            When True, the first input is not fused with the others.
            Instead it is split into two halves along the channel
            dimension, and the absolute difference of the halves is
            projected by a 1x1 conv and used as a sigmoid gate on the fused
            features. Its channel count must therefore be even, and at
            least one further input is required. Defaults to False.
        **kwargs: Forwarded to ``BaseDecodeHead``. ``input_transform`` is
            fixed to ``'multiple_select'`` and must not be passed.
    """

    def __init__(self, feature_strides, priori_attn=False, **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)
        assert len(feature_strides) == len(self.in_channels), (
            'feature_strides and in_channels must have the same length, '
            f'got {len(feature_strides)} and {len(self.in_channels)}')
        assert min(feature_strides) == feature_strides[0], (
            'the first feature stride must be the smallest one')

        if priori_attn:
            # The first input only feeds the attention branch, so at least
            # one more level must remain for the fusion branch.
            assert len(feature_strides) > 1, (
                'priori_attn=True needs at least two input feature maps')
            attn_channels = self.in_channels[0]
            # forward() splits this input into two equal channel halves.
            assert attn_channels % 2 == 0, (
                'priori_attn=True needs an even number of channels in the '
                f'first input, got {attn_channels}')
            self.in_channels = self.in_channels[1:]
            feature_strides = feature_strides[1:]

        self.feature_strides = feature_strides
        self.priori_attn = priori_attn

        # One 1x1 projection per fused level. The nn.Sequential wrapper is
        # kept so parameter names (``scale_heads.<i>.0.*``) stay compatible
        # with existing checkpoints.
        self.scale_heads = nn.ModuleList([
            nn.Sequential(
                ConvModule(
                    in_channels=level_channels,
                    out_channels=self.channels,
                    kernel_size=1,
                    stride=1,
                    groups=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
            for level_channels in self.in_channels
        ])

        if self.priori_attn:
            # Plain 1x1 conv (no norm, no activation); the sigmoid is
            # applied in forward().
            self.gen_diff_attn = ConvModule(
                in_channels=attn_channels // 2,
                out_channels=self.channels,
                kernel_size=1,
                stride=1,
                groups=1,
                norm_cfg=None,
                act_cfg=None)

    def forward(self, inputs):
        """Fuse the selected feature maps and predict the segmentation map.

        Args:
            inputs: Multi-level features, passed to
                ``self._transform_inputs`` for selection.

        Returns:
            The output of ``self.cls_seg`` applied to the fused features.
        """
        feats = self._transform_inputs(inputs)
        if self.priori_attn:
            early_feat = feats[0]
            feats = feats[1:]

        output = self.scale_heads[0](feats[0])
        for i in range(1, len(self.feature_strides)):
            # Out-of-place add (not ``+=``), as in the original code.
            output = output + resize(
                self.scale_heads[i](feats[i]),
                size=output.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)

        if self.priori_attn:
            # The early feature is treated as two halves stacked along the
            # channel dimension; their absolute difference drives the gate.
            first_half, second_half = torch.chunk(early_feat, 2, dim=1)
            diff_attn = self.gen_diff_attn(torch.abs(first_half - second_half))
            # resize() only changes the spatial size, so compare just that.
            if diff_attn.shape[2:] != output.shape[2:]:
                output = resize(
                    output,
                    size=diff_attn.shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners)
            # Residual gating: attended features plus the identity path.
            output = output * torch.sigmoid(diff_attn) + output

        output = self.cls_seg(output)
        return output