import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from mmengine.model import ModuleList, Sequential
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.utils import Upsample

from opencd.registry import MODELS


class CrossAttention(nn.Module):

    def __init__(self,
                 in_dims,
                 embed_dims,
                 num_heads,
                 dropout_rate=0.,
                 apply_softmax=True):
        super().__init__()
        self.num_heads = num_heads
        self.scale = in_dims**-0.5
        self.apply_softmax = apply_softmax

        self.to_q = nn.Linear(in_dims, embed_dims, bias=False)
        self.to_k = nn.Linear(in_dims, embed_dims, bias=False)
        self.to_v = nn.Linear(in_dims, embed_dims, bias=False)

        self.fc_out = nn.Sequential(
            nn.Linear(embed_dims, in_dims),
            nn.Dropout(dropout_rate))

    def forward(self, x, ref):
        b, n = x.shape[:2]
        h = self.num_heads

        q = self.to_q(x)
        k = self.to_k(ref)
        v = self.to_v(ref)

        q = q.reshape((b, n, h, -1)).permute((0, 2, 1, 3))
        k = k.reshape((b, ref.shape[1], h, -1)).permute((0, 2, 1, 3))
        v = v.reshape((b, ref.shape[1], h, -1)).permute((0, 2, 1, 3))

        mult = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        if self.apply_softmax:
            mult = F.softmax(mult, dim=-1)

        out = torch.matmul(mult, v)
        out = out.permute((0, 2, 1, 3)).flatten(2)
        return self.fc_out(out)
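

# Shape sketch for CrossAttention (an illustrative annotation, not part of the
# original source): with x of shape (B, N, C=in_dims) and ref of shape
# (B, M, C), q becomes (B, h, N, embed_dims // h) and k, v become
# (B, h, M, embed_dims // h); the attention map q @ k^T is (B, h, N, M) and
# fc_out projects the merged heads back to (B, N, C). Note the scale used
# above is in_dims ** -0.5 rather than the conventional per-head
# (embed_dims // h) ** -0.5; the code is kept as written.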


class FeedForward(nn.Sequential):

    # TODO: rebuild with mmcv components (e.g. build_activation_layer) for
    # consistency with the rest of the mm-series codebase.
    def __init__(self, dim, hidden_dim, dropout_rate=0.):
        super().__init__(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout_rate))


class TransformerEncoder(nn.Module):

    def __init__(self,
                 in_dims,
                 embed_dims,
                 num_heads,
                 drop_rate,
                 norm_cfg,
                 apply_softmax=True):
        super().__init__()
        self.attn = CrossAttention(
            in_dims,
            embed_dims,
            num_heads,
            dropout_rate=drop_rate,
            apply_softmax=apply_softmax)
        self.ff = FeedForward(in_dims, embed_dims, drop_rate)
        self.norm1 = build_norm_layer(norm_cfg, in_dims)[1]
        self.norm2 = build_norm_layer(norm_cfg, in_dims)[1]

    def forward(self, x):
        x_n = self.norm1(x)
        x_ = self.attn(x_n, x_n) + x
        y = self.ff(self.norm2(x_)) + x_
        return y


class TransformerDecoder(nn.Module):

    def __init__(self,
                 in_dims,
                 embed_dims,
                 num_heads,
                 drop_rate,
                 norm_cfg,
                 apply_softmax=True):
        super().__init__()
        self.attn = CrossAttention(
            in_dims,
            embed_dims,
            num_heads,
            dropout_rate=drop_rate,
            apply_softmax=apply_softmax)
        self.ff = FeedForward(in_dims, embed_dims, drop_rate)
        self.norm1 = build_norm_layer(norm_cfg, in_dims)[1]
        self.norm1_ = build_norm_layer(norm_cfg, in_dims)[1]
        self.norm2 = build_norm_layer(norm_cfg, in_dims)[1]

    def forward(self, x, ref):
        x_ = self.attn(self.norm1(x), self.norm1_(ref)) + x
        y = self.ff(self.norm2(x_)) + x_
        return y
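

# Both blocks above use pre-norm residual wiring: TransformerEncoder runs
# self-attention over the token set, while TransformerDecoder lets pixel
# features attend to a compact token set as keys/values. A minimal standalone
# check, with all shapes illustrative assumptions rather than values from the
# original source:
#
#   dec = TransformerDecoder(32, 64, 8, 0., dict(type='LN'))
#   out = dec(torch.randn(2, 4096, 32),   # flattened 64x64 feature map
#             torch.randn(2, 4, 32))      # 4 semantic tokens
#   assert out.shape == (2, 4096, 32)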


@MODELS.register_module()
class BITHead(BaseDecodeHead):
    """BIT Head.

    This head is an improved implementation of `Remote Sensing Image
    Change Detection With Transformers
    <https://github.com/justchenhao/BIT_CD>`_.

    Args:
        in_channels (int): Number of input feature channels per temporal
            image; the head expects the bitemporal features concatenated
            along the channel axis (2 * in_channels in total). Default: 256.
        channels (int): Number of output channels of pre_process. Default: 32.
        embed_dims (int): Number of expanded channels of the attention block.
            Default: 64.
        enc_depth (int): Number of transformer encoder blocks. Default: 1.
        enc_with_pos (bool): Whether to add a learned position embedding to
            the encoder tokens. Default: True.
        dec_depth (int): Number of transformer decoder blocks. Default: 8.
        num_heads (int): Number of heads of multi-head cross-attention.
            Default: 8.
        drop_rate (float): Dropout rate of attention and feed-forward layers.
            Default: 0.
        pool_size (int): Output size of the pooling tokenizer, used when
            ``use_tokenizer`` is False. Default: 2.
        pool_mode (str): Pooling mode of the pooling tokenizer, 'max' or
            'avg'. Default: 'max'.
        use_tokenizer (bool): Whether to use the semantic tokenizer.
            Default: True.
        token_len (int): Number of semantic tokens per temporal image.
            Default: 4.
        pre_upsample (int): Scale factor of the upsampling in pre_process
            (upsamples to 64x64 with the defaults). Default: 2.
        upsample_size (int): Scale factor of the final upsampling. Default: 4.
    """

    def __init__(self,
                 in_channels=256,
                 channels=32,
                 embed_dims=64,
                 enc_depth=1,
                 enc_with_pos=True,
                 dec_depth=8,
                 num_heads=8,
                 drop_rate=0.,
                 pool_size=2,
                 pool_mode='max',
                 use_tokenizer=True,
                 token_len=4,
                 pre_upsample=2,
                 upsample_size=4,
                 norm_cfg=dict(type='LN'),
                 act_cfg=dict(type='ReLU', inplace=True),
                 **kwargs):
        super().__init__(in_channels, channels, **kwargs)

        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.embed_dims = embed_dims
        self.use_tokenizer = use_tokenizer
        self.num_heads = num_heads

        if not use_tokenizer:
            # If the tokenizer is not used, pool the feature maps down to a
            # fixed grid instead.
            self.pool_size = pool_size
            self.pool_mode = pool_mode
            self.token_len = pool_size * pool_size
        else:
            self.token_len = token_len
            self.conv_att = ConvModule(
                self.channels,
                self.token_len,
                1,
                conv_cfg=self.conv_cfg)

        self.enc_with_pos = enc_with_pos
        if enc_with_pos:
            self.enc_pos_embedding = nn.Parameter(
                torch.randn(1, self.token_len * 2, self.channels))

        # Pre-process the backbone features: upsample, then project to
        # `channels` dimensions.
        self.pre_process = Sequential(
            Upsample(
                scale_factor=pre_upsample,
                mode='bilinear',
                align_corners=self.align_corners),
            ConvModule(
                self.in_channels,
                self.channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg))

        # Transformer encoder
        self.encoder = ModuleList()
        for _ in range(enc_depth):
            block = TransformerEncoder(
                self.channels,
                self.embed_dims,
                self.num_heads,
                drop_rate=drop_rate,
                norm_cfg=self.norm_cfg)
            self.encoder.append(block)

        # Transformer decoder
        self.decoder = ModuleList()
        for _ in range(dec_depth):
            block = TransformerDecoder(
                self.channels,
                self.embed_dims,
                self.num_heads,
                drop_rate=drop_rate,
                norm_cfg=self.norm_cfg)
            self.decoder.append(block)

        self.upsample = Upsample(
            scale_factor=upsample_size,
            mode='bilinear',
            align_corners=self.align_corners)

    # Tokenization
    def _forward_semantic_tokens(self, x):
        b, c = x.shape[:2]
        att_map = self.conv_att(x)
        att_map = att_map.reshape((b, self.token_len, 1, -1))
        att_map = F.softmax(att_map, dim=-1)
        x = x.reshape((b, 1, c, -1))
        tokens = (x * att_map).sum(-1)
        return tokens
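
    # Illustrative walk-through of the semantic tokenizer above (an
    # annotation, not original code): for x of shape (B, C, H, W), conv_att
    # produces token_len spatial attention maps; att_map becomes
    # (B, token_len, 1, H*W), softmaxed over the pixel axis, while x is
    # viewed as (B, 1, C, H*W). The broadcast product summed over pixels
    # yields tokens of shape (B, token_len, C): each token is an
    # attention-weighted pooling of the feature map.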

    def _forward_reshaped_tokens(self, x):
        if self.pool_mode == 'max':
            x = F.adaptive_max_pool2d(x, (self.pool_size, self.pool_size))
        elif self.pool_mode == 'avg':
            x = F.adaptive_avg_pool2d(x, (self.pool_size, self.pool_size))
        # Any other pool_mode leaves x at its original resolution.
        tokens = x.permute((0, 2, 3, 1)).flatten(1, 2)
        return tokens
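
    # For the pooled variant (an annotation, not original code): adaptive
    # pooling reduces x to (B, C, pool_size, pool_size), and the
    # permute/flatten yields (B, pool_size * pool_size, C) tokens, matching
    # the token_len set in __init__ when use_tokenizer is False.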

    def _forward_feature(self, inputs):
        """Forward function for feature maps before classifying each pixel
        with ``self.cls_seg``.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is the feature map of the last layer of the
                decode head.
        """
        inputs = self._transform_inputs(inputs)
        # Bitemporal features are concatenated along channels; split them.
        x1, x2 = torch.chunk(inputs, 2, dim=1)
        x1 = self.pre_process(x1)
        x2 = self.pre_process(x2)

        # Tokenization
        if self.use_tokenizer:
            token1 = self._forward_semantic_tokens(x1)
            token2 = self._forward_semantic_tokens(x2)
        else:
            token1 = self._forward_reshaped_tokens(x1)
            token2 = self._forward_reshaped_tokens(x2)

        # Transformer encoder forward
        token = torch.cat([token1, token2], dim=1)
        if self.enc_with_pos:
            token = token + self.enc_pos_embedding
        for _encoder in self.encoder:
            token = _encoder(token)
        token1, token2 = torch.chunk(token, 2, dim=1)

        # Transformer decoder forward
        b, c, h, w = x1.shape
        for _decoder in self.decoder:
            x1 = x1.permute((0, 2, 3, 1)).flatten(1, 2)
            x2 = x2.permute((0, 2, 3, 1)).flatten(1, 2)

            x1 = _decoder(x1, token1)
            x2 = _decoder(x2, token2)

            x1 = x1.transpose(1, 2).reshape((b, c, h, w))
            x2 = x2.transpose(1, 2).reshape((b, c, h, w))

        # Feature differencing
        y = torch.abs(x1 - x2)
        y = self.upsample(y)
        return y

    def forward(self, inputs):
        """Forward function."""
        output = self._forward_feature(inputs)
        output = self.cls_seg(output)
        return output
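

# A minimal smoke test, assuming an environment with mmseg and opencd
# installed and a backbone that concatenates the two temporal feature maps
# along the channel axis (2 * in_channels = 512 here). Batch size, feature
# resolution, and num_classes are illustrative assumptions, not values from
# the original source.
if __name__ == '__main__':
    head = BITHead(in_channels=256, channels=32, num_classes=2)
    # Bitemporal features: t1 and t2 concatenated -> 512 channels at 32x32.
    feats = [torch.randn(2, 512, 32, 32)]
    logits = head(feats)
    # pre_process upsamples x2 (32 -> 64), the head upsamples x4 (64 -> 256).
    print(logits.shape)  # torch.Size([2, 2, 256, 256])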