# Copyright (c) Open-CD. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import Conv2d, ConvModule, build_activation_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmengine.model import BaseModule, Sequential
from torch.nn import functional as F

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.utils import resize

from opencd.registry import MODELS
from ..necks.feature_fusion import FeatureFusionNeck


class FDAF(BaseModule):
    """Flow Dual-Alignment Fusion Module.

    Args:
        in_channels (int): Input channels of features.
        conv_cfg (dict | None): Config of conv layers.
            Default: None
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='IN')
        act_cfg (dict): Config of activation layers.
            Default: dict(type='GELU')
    """

    def __init__(self,
                 in_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='IN'),
                 act_cfg=dict(type='GELU')):
        super(FDAF, self).__init__()
        self.in_channels = in_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        # TODO: make the conv/norm/act of `flow_make` configurable;
        # the block below currently hardcodes InstanceNorm + GELU.
        kernel_size = 5
        self.flow_make = Sequential(
            nn.Conv2d(
                in_channels * 2,
                in_channels * 2,
                kernel_size=kernel_size,
                padding=(kernel_size - 1) // 2,
                bias=True,
                groups=in_channels * 2),
            nn.InstanceNorm2d(in_channels * 2),
            nn.GELU(),
            nn.Conv2d(in_channels * 2, 4, kernel_size=1, padding=0, bias=False),
        )

    def forward(self, x1, x2, fusion_policy=None):
        """Forward function."""
        output = torch.cat([x1, x2], dim=1)
        flow = self.flow_make(output)
        f1, f2 = torch.chunk(flow, 2, dim=1)
        x1_feat = self.warp(x1, f1) - x2
        x2_feat = self.warp(x2, f2) - x1

        if fusion_policy is None:
            return x1_feat, x2_feat

        output = FeatureFusionNeck.fusion(x1_feat, x2_feat, fusion_policy)
        return output

    def warp(self, x, flow):
        """Warp `x` with the predicted 2-channel flow field."""
        n, c, h, w = x.size()

        norm = torch.tensor([[[[w, h]]]]).type_as(x).to(x.device)
        col = torch.linspace(-1.0, 1.0, h).view(-1, 1).repeat(1, w)
        row = torch.linspace(-1.0, 1.0, w).repeat(h, 1)
        grid = torch.cat((row.unsqueeze(2), col.unsqueeze(2)), 2)
        grid = grid.repeat(n, 1, 1, 1).type_as(x).to(x.device)
        # Offset the identity grid by the flow, normalized to [-1, 1] coords.
        grid = grid + flow.permute(0, 2, 3, 1) / norm
        output = F.grid_sample(x, grid, align_corners=True)
        return output
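

# A minimal usage sketch of FDAF (illustration only, not part of the original
# file). The shapes are made up; it assumes `FeatureFusionNeck.fusion` with
# the 'concat' policy concatenates the two aligned difference maps along the
# channel dimension:
#
#   >>> fdaf = FDAF(in_channels=64)
#   >>> t1 = torch.rand(2, 64, 32, 32)
#   >>> t2 = torch.rand(2, 64, 32, 32)
#   >>> d1, d2 = fdaf(t1, t2)            # fusion_policy=None -> two diff maps
#   >>> fdaf(t1, t2, 'concat').shape     # torch.Size([2, 128, 32, 32])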


class MixFFN(BaseModule):
    """An implementation of MixFFN of Segformer. \
        Here MixFFN is used as projection head of Changer.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`. Defaults: 256.
        feedforward_channels (int): The hidden dimension of FFNs.
            Defaults: 1024.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='ReLU')
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 act_cfg=dict(type='GELU'),
                 ffn_drop=0.,
                 dropout_layer=None,
                 init_cfg=None):
        super(MixFFN, self).__init__(init_cfg)

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.act_cfg = act_cfg
        self.activate = build_activation_layer(act_cfg)

        in_channels = embed_dims
        fc1 = Conv2d(
            in_channels=in_channels,
            out_channels=feedforward_channels,
            kernel_size=1,
            stride=1,
            bias=True)
        # 3x3 depth wise conv to provide positional encode information
        pe_conv = Conv2d(
            in_channels=feedforward_channels,
            out_channels=feedforward_channels,
            kernel_size=3,
            stride=1,
            padding=(3 - 1) // 2,
            bias=True,
            groups=feedforward_channels)
        fc2 = Conv2d(
            in_channels=feedforward_channels,
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            bias=True)
        drop = nn.Dropout(ffn_drop)
        layers = [fc1, pe_conv, self.activate, drop, fc2, drop]
        self.layers = Sequential(*layers)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else torch.nn.Identity()

    def forward(self, x, identity=None):
        out = self.layers(x)
        if identity is None:
            identity = x
        return identity + self.dropout_layer(out)
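

# A minimal usage sketch of MixFFN as a convolutional FFN over (N, C, H, W)
# feature maps (illustration only; the dimensions below are assumptions, not
# values from a reference config):
#
#   >>> ffn = MixFFN(embed_dims=128, feedforward_channels=256)
#   >>> x = torch.rand(2, 128, 64, 64)
#   >>> ffn(x).shape                     # torch.Size([2, 128, 64, 64])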


@MODELS.register_module()
class Changer(BaseDecodeHead):
    """The Head of Changer.

    This head is the implementation of
    `Changer <https://arxiv.org/abs/2209.08290>`_.

    Args:
        interpolate_mode: The interpolate mode of MLP head upsample operation.
            Default: 'bilinear'.
    """

    def __init__(self, interpolate_mode='bilinear', **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)

        self.interpolate_mode = interpolate_mode
        num_inputs = len(self.in_channels)
        assert num_inputs == len(self.in_index)

        self.convs = nn.ModuleList()
        for i in range(num_inputs):
            self.convs.append(
                ConvModule(
                    in_channels=self.in_channels[i],
                    out_channels=self.channels,
                    kernel_size=1,
                    stride=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))

        self.fusion_conv = ConvModule(
            in_channels=self.channels * num_inputs,
            out_channels=self.channels // 2,
            kernel_size=1,
            norm_cfg=self.norm_cfg)

        self.neck_layer = FDAF(in_channels=self.channels // 2)

        # projection head
        self.discriminator = MixFFN(
            embed_dims=self.channels,
            feedforward_channels=self.channels,
            ffn_drop=0.,
            dropout_layer=dict(type='DropPath', drop_prob=0.),
            act_cfg=dict(type='GELU'))

    def base_forward(self, inputs):
        outs = []
        for idx in range(len(inputs)):
            x = inputs[idx]
            conv = self.convs[idx]
            outs.append(
                resize(
                    input=conv(x),
                    size=inputs[0].shape[2:],
                    mode=self.interpolate_mode,
                    align_corners=self.align_corners))

        out = self.fusion_conv(torch.cat(outs, dim=1))
        return out

    def forward(self, inputs):
        # Receive 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32
        inputs = self._transform_inputs(inputs)
        inputs1 = []
        inputs2 = []
        for input in inputs:
            f1, f2 = torch.chunk(input, 2, dim=1)
            inputs1.append(f1)
            inputs2.append(f2)

        out1 = self.base_forward(inputs1)
        out2 = self.base_forward(inputs2)
        out = self.neck_layer(out1, out2, 'concat')

        out = self.discriminator(out)
        out = self.cls_seg(out)

        return out
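

# A minimal usage sketch of the Changer head (illustration only; the channel
# counts, image size, and config values are assumptions, not a reference
# config). Each input tensor carries both temporal features concatenated
# along the channel axis, so its channel count is twice the corresponding
# `in_channels` entry; the head chunks them apart in `forward`:
#
#   >>> head = Changer(
#   ...     in_channels=[64, 128, 256, 512],
#   ...     in_index=[0, 1, 2, 3],
#   ...     channels=128,
#   ...     num_classes=2,
#   ...     norm_cfg=dict(type='BN', requires_grad=True),
#   ...     align_corners=False)
#   >>> feats = [torch.rand(2, 2 * c, 256 // s, 256 // s)
#   ...          for c, s in zip([64, 128, 256, 512], [4, 8, 16, 32])]
#   >>> head(feats).shape                # torch.Size([2, 2, 64, 64]), 1/4 res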