Spaces:
Runtime error
Runtime error
| import copy | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from mmcv.cnn import ConvModule | |
| from mmseg.models.decode_heads.decode_head import BaseDecodeHead | |
| from mmseg.models.utils import resize | |
| from opencd.registry import MODELS | |
class BAM(nn.Module):
    """Basic self-attention module.

    The input is average-pooled by ``ds``, self-attention is computed over
    every pooled position, and the attended features are upsampled back to
    the input resolution and added to the input as a residual.

    Args:
        in_dim (int): Number of input (and output) channels. The query and
            key projections use ``in_dim // 8`` channels.
        ds (int): Kernel size (and stride) of the average pooling applied
            before attention.
        activation: Stored as ``self.activation``; ``forward`` does not
            use it.
    """

    def __init__(self, in_dim, ds=8, activation=nn.ReLU):
        super(BAM, self).__init__()
        self.chanel_in = in_dim
        self.key_channel = self.chanel_in // 8
        self.activation = activation
        self.ds = ds
        self.pool = nn.AvgPool2d(self.ds)
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # Not read by forward(); kept so the module's parameter set is unchanged.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input):
        """Apply pooled self-attention and add the result to ``input``.

        Args:
            input (Tensor): Feature map of shape (B, C, W, H).

        Returns:
            Tensor: Attended features plus ``input``, same shape as ``input``.
        """
        x = self.pool(input)
        m_batchsize, C, width, height = x.size()
        num_positions = width * height

        # B x N x C' (N = number of pooled positions, C' = in_dim // 8)
        proj_query = self.query_conv(x).view(m_batchsize, -1, num_positions).permute(0, 2, 1)
        # B x C' x N
        proj_key = self.key_conv(x).view(m_batchsize, -1, num_positions)

        # B x N x N similarity, scaled by 1 / sqrt(C') before the softmax.
        energy = torch.bmm(proj_query, proj_key)
        energy = (self.key_channel ** -.5) * energy
        attention = self.softmax(energy)

        # B x C x N
        proj_value = self.value_conv(x).view(m_batchsize, -1, num_positions)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, width, height)

        # Upsample to the input's own spatial size rather than
        # (width * ds, height * ds): the two only coincide when the input
        # size is a multiple of ds, and the residual add below needs an
        # exact shape match.
        out = F.interpolate(out, size=tuple(input.shape[2:]))
        return out + input
class _PAMBlock(nn.Module):
    """Self-attention / non-local block evaluated on a grid of sub-regions.

    Input/Output:
        N * C * H * (2*W) -- the last axis is split in the middle into two
        halves which are attended over jointly.

    Parameters:
        in_channels    : the dimension of the input feature map
        key_channels   : the dimension after the key/query transform
        value_channels : the dimension after the value transform
        scale          : the feature map is partitioned into scale x scale
                         sub-regions, each attended over independently
        ds             : downsampling scale (average pooling before attention)
    """

    def __init__(self, in_channels, key_channels, value_channels, scale=1, ds=1):
        super(_PAMBlock, self).__init__()
        self.scale = scale
        self.ds = ds
        self.pool = nn.AvgPool2d(self.ds)
        self.in_channels = in_channels
        self.key_channels = key_channels
        self.value_channels = value_channels

        self.f_key = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=key_channels,
                      kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(key_channels),
        )
        self.f_query = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=key_channels,
                      kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(key_channels),
        )
        self.f_value = nn.Conv2d(in_channels=in_channels, out_channels=value_channels,
                                 kernel_size=1, stride=1, padding=0)

    def _grid_windows(self, rows, cols):
        """Return (row_slice, col_slice) pairs for the scale x scale grid, row-major."""
        step_r, step_c = rows // self.scale, cols // self.scale
        last = self.scale - 1
        windows = []
        for i in range(self.scale):
            r0 = i * step_r
            # The last row/column of the grid absorbs any remainder.
            r1 = rows if i == last else min(r0 + step_r, rows)
            for j in range(self.scale):
                c0 = j * step_c
                c1 = cols if j == last else min(c0 + step_c, cols)
                windows.append((slice(r0, r1), slice(c0, c1)))
        return windows

    def _attend(self, value_local, query_local, key_local):
        """Scaled dot-product attention over all positions of each window."""
        n = value_local.size(0)
        h_local, w_local = value_local.size(2), value_local.size(3)

        flat_value = value_local.contiguous().view(n, self.value_channels, -1)
        flat_query = query_local.contiguous().view(n, self.key_channels, -1).permute(0, 2, 1)
        flat_key = key_local.contiguous().view(n, self.key_channels, -1)

        sim_map = torch.bmm(flat_query, flat_key)
        sim_map = (self.key_channels ** -.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)

        context_local = torch.bmm(flat_value, sim_map.permute(0, 2, 1))
        return context_local.view(n, self.value_channels, h_local, w_local, 2)

    def forward(self, input):
        feats = input if self.ds == 1 else self.pool(input)

        # feats: b, c, h, 2w
        batch_size, rows = feats.size(0), feats.size(2)
        cols = feats.size(3) // 2
        windows = self._grid_windows(rows, cols)

        def split_halves(t):
            # b, c, h, 2w -> b, c, h, w, 2 (the two halves on a new last axis)
            return torch.stack([t[:, :, :, :cols], t[:, :, :, cols:]], 4)

        def gather(t):
            # Stack every window along the batch axis so that a single
            # batched matmul handles all windows at once.
            return torch.cat([t[:, :, rs, cs] for rs, cs in windows], dim=0)

        value = split_halves(self.f_value(feats))
        query = split_halves(self.f_query(feats))
        key = split_halves(self.f_key(feats))

        context_locals = self._attend(gather(value), gather(query), gather(key))

        # Undo the batch stacking: stitch windows back along the column
        # axis, then the row axis.
        stitched_rows = []
        for i in range(self.scale):
            row_tiles = []
            for j in range(self.scale):
                begin = batch_size * (i * self.scale + j)
                row_tiles.append(context_locals[begin:begin + batch_size])
            stitched_rows.append(torch.cat(row_tiles, 3))
        context = torch.cat(stitched_rows, 2)

        # Put the two halves side by side again: b, c, h, 2w
        context = torch.cat([context[:, :, :, :, 0], context[:, :, :, :, 1]], 3)

        if self.ds != 1:
            context = F.interpolate(context, [rows * self.ds, 2 * cols * self.ds])
        return context
class PAMBlock(_PAMBlock):
    """``_PAMBlock`` with default channel sizes derived from ``in_channels``.

    Args:
        in_channels (int): Channels of the input feature map.
        key_channels (int | None): Channels of the key/query projections.
            Defaults to ``in_channels // 8`` when ``None``.
        value_channels (int | None): Channels of the value projection.
            Defaults to ``in_channels`` when ``None``.
        scale (int): Forwarded to ``_PAMBlock``.
        ds (int): Forwarded to ``_PAMBlock``.
    """

    def __init__(self, in_channels, key_channels=None, value_channels=None, scale=1, ds=1):
        # Identity comparison: `== None` would invoke an overloaded __eq__.
        if key_channels is None:
            key_channels = in_channels // 8
        if value_channels is None:
            value_channels = in_channels
        super(PAMBlock, self).__init__(in_channels, key_channels, value_channels, scale, ds)
class PAM(nn.Module):
    """PAM module.

    Runs one ``PAMBlock`` per entry of ``sizes`` on the same input,
    concatenates their outputs along the channel axis and fuses them with
    a bias-free 1x1 convolution.

    Args:
        in_channels (int): Channels of the input feature map.
        out_channels (int): Channels of the output; also used as the value
            channels of every stage (key channels are ``out_channels // 8``).
        sizes (Sequence[int]): One stage is built per entry; the entry is
            passed to ``PAMBlock`` as its ``scale``.
        ds (int): Passed to every stage as its ``ds``.
    """

    def __init__(self, in_channels, out_channels, sizes=(1,), ds=1):
        super(PAM, self).__init__()
        self.group = len(sizes)
        self.ds = ds  # output stride
        self.value_channels = out_channels
        self.key_channels = out_channels // 8

        self.stages = nn.ModuleList(
            [self._make_stage(in_channels, self.key_channels, self.value_channels, size, self.ds)
             for size in sizes])
        # Every stage emits `value_channels` channels, so the concatenation
        # has value_channels * group channels. Sizing this conv from
        # in_channels only worked when in_channels == out_channels.
        self.conv_bn = nn.Sequential(
            nn.Conv2d(self.value_channels * self.group, out_channels,
                      kernel_size=1, padding=0, bias=False),
        )

    def _make_stage(self, in_channels, key_channels, value_channels, size, ds):
        """Build one attention stage; ``size`` becomes the block's ``scale``."""
        return PAMBlock(in_channels, key_channels, value_channels, size, ds)

    def forward(self, feats):
        """Return the fused output of all stages applied to ``feats``."""
        priors = [stage(feats) for stage in self.stages]
        return self.conv_bn(torch.cat(priors, 1))
def weights_init(m):
    """Initialise conv / batch-norm parameters; meant for ``Module.apply``.

    Modules are matched by class name:

    * name contains ``'Conv'``: weight ~ N(0, 0.02)
    * name contains ``'BatchNorm'``: weight ~ N(1, 0.02), bias = 0

    Matching modules that have no ``weight`` / ``bias`` tensor (for example
    a wrapper class whose name contains "Conv", or a norm layer built with
    ``affine=False``) are skipped instead of raising ``AttributeError``.

    Args:
        m (nn.Module): The module to initialise in place.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            nn.init.constant_(bias.data, 0)
class CDSA(nn.Module):
    """Self-attention module for change detection.

    The two input feature maps are concatenated along the last axis, passed
    through the selected attention module, and split back into two maps.

    Args:
        in_c (int): Channels of each input feature map.
        ds (int): Downsampling factor forwarded to the attention module.
        mode (str | None): ``'BAM'``, ``'PAM'``, or ``'None'`` / ``None``
            for no attention (identity).

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """

    def __init__(self, in_c, ds=1, mode='BAM'):
        super(CDSA, self).__init__()
        self.in_C = in_c
        self.ds = ds
        self.mode = mode
        if self.mode == 'BAM':
            self.Self_Att = BAM(self.in_C, ds=self.ds)
        elif self.mode == 'PAM':
            self.Self_Att = PAM(in_channels=self.in_C, out_channels=self.in_C,
                                sizes=[1, 2, 4, 8], ds=self.ds)
        elif self.mode is None or self.mode == 'None':
            self.Self_Att = nn.Identity()
        else:
            # Fail here rather than with an AttributeError on the first
            # forward() call, where the cause would be much less obvious.
            raise ValueError(
                f"Unsupported mode {mode!r}; expected 'BAM', 'PAM' or 'None'.")
        self.apply(weights_init)

    def forward(self, x1, x2):
        """Attend over both inputs jointly and return the two refined maps.

        Args:
            x1 (Tensor): First feature map, 4-D.
            x2 (Tensor): Second feature map; concatenated with ``x1`` along
                dim 3.

        Returns:
            tuple[Tensor, Tensor]: The attended tensor split along dim 3 at
            ``x1``'s original extent.
        """
        split = x1.shape[3]
        x = torch.cat((x1, x2), 3)
        x = self.Self_Att(x)
        return x[:, :, :, 0:split], x[:, :, :, split:]
class STAHead(BaseDecodeHead):
    """The Head of STANet.

    Each selected input is split in two along the channel axis; both halves
    go through the same FPN-style fusion, the pair is refined by a ``CDSA``
    attention module, and the per-pixel pairwise distance between the two
    refined feature maps is returned as a one-channel map.

    Args:
        sa_mode (str): Attention mode passed to ``CDSA`` (``'BAM'``,
            ``'PAM'`` or ``'None'``). Default: 'PAM'.
        sa_in_channels (int): Output channels of the fusion bottleneck and
            input channels of the attention module. Default: 256.
        sa_ds (int): Downsampling factor passed to ``CDSA``. Default: 1.
        distance_threshold (float): In ``predict_by_feat``, distances above
            this value are mapped to 100 and all others to -100. Default: 1.
        **kwargs: Forwarded to ``BaseDecodeHead``. ``input_transform`` and
            ``num_classes`` are set by this class and must not be passed.
    """

    def __init__(
            self,
            sa_mode='PAM',
            sa_in_channels=256,
            sa_ds=1,
            distance_threshold=1,
            **kwargs):
        super().__init__(input_transform='multiple_select', num_classes=1, **kwargs)
        num_inputs = len(self.in_channels)
        assert num_inputs == len(self.in_index)
        self.distance_threshold = distance_threshold

        # One 1x1 conv per selected input, projecting to self.channels.
        self.fpn_convs = nn.ModuleList()
        for in_channels in self.in_channels:
            fpn_conv = ConvModule(
                in_channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                inplace=False)
            self.fpn_convs.append(fpn_conv)

        # Fuses the concatenated projections down to sa_in_channels.
        self.fpn_bottleneck = nn.Sequential(
            ConvModule(
                len(self.in_channels) * self.channels,
                sa_in_channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            nn.Dropout(0.5),
            ConvModule(
                sa_in_channels,
                sa_in_channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        )
        self.netA = CDSA(in_c=sa_in_channels, ds=sa_ds, mode=sa_mode)
        self.calc_dist = nn.PairwiseDistance(keepdim=True)
        # This head outputs a distance map directly, so conv_seg is an identity.
        self.conv_seg = nn.Identity()

    def base_forward(self, inputs):
        """Project, align and fuse one set of multi-level features.

        Args:
            inputs (list[Tensor]): One feature map per entry of
                ``self.in_channels``.

        Returns:
            Tensor: Fused features at the spatial size of the first input.
        """
        fpn_outs = [conv(feat) for conv, feat in zip(self.fpn_convs, inputs)]
        target_size = fpn_outs[0].shape[2:]
        fpn_outs = [
            resize(
                out,
                size=target_size,
                mode='bilinear',
                align_corners=self.align_corners)
            for out in fpn_outs
        ]
        fpn_outs = torch.cat(fpn_outs, dim=1)
        feats = self.fpn_bottleneck(fpn_outs)
        return feats

    def forward(self, inputs):
        """Compute the feature-distance map between the two input halves.

        Args:
            inputs: Backbone feature maps; each selected map is split into
                two equal halves along dim 1.

        Returns:
            Tensor: Distance map with one channel, resized to the spatial
            size of the first selected input.
        """
        # Receive 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32
        inputs = self._transform_inputs(inputs)
        inputs1 = []
        inputs2 = []
        for feat in inputs:
            f1, f2 = torch.chunk(feat, 2, dim=1)
            inputs1.append(f1)
            inputs2.append(f2)

        f1 = self.base_forward(inputs1)
        f2 = self.base_forward(inputs2)
        f1, f2 = self.netA(f1, f2)

        # if you use PyTorch<=1.8, there may be some problems.
        # see https://github.com/justchenhao/STANet/issues/85
        # Channels go last so that PairwiseDistance reduces over them.
        f1 = f1.permute(0, 2, 3, 1)
        f2 = f2.permute(0, 2, 3, 1)
        dist = self.calc_dist(f1, f2).permute(0, 3, 1, 2)

        dist = F.interpolate(dist, size=inputs[0].shape[2:], mode='bilinear', align_corners=True)
        return dist

    def predict_by_feat(self, seg_logits, batch_img_metas):
        """Transform a batch of output seg_logits to the input shape.

        Values above ``self.distance_threshold`` become 100 and all other
        (non-NaN) values become -100 before the map is resized. The
        ``seg_logits`` argument itself is left unmodified.

        Args:
            seg_logits (Tensor): The output from decode head forward function.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.

        Returns:
            Tensor: Outputs segmentation logits map.
        """
        above = seg_logits > self.distance_threshold
        at_or_below = seg_logits <= self.distance_threshold
        # Work on a clone: copy.deepcopy fails on tensors that carry
        # autograd history, and writing into seg_logits would silently
        # change the caller's tensor.
        binarized = seg_logits.clone()
        binarized[above] = 100
        binarized[at_or_below] = -100

        binarized = resize(
            input=binarized,
            size=batch_img_metas[0]['img_shape'],
            mode='bilinear',
            align_corners=self.align_corners)
        return binarized