| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from rscd.models.backbones.seaformer_vmanba import SeaFormer_L | |
| from rscd.models.backbones.cdloma import SS2D | |
class cdlamba(nn.Module):
    """Siamese change-detection model: a shared SeaFormer_L backbone feeds
    per-stage stacks of SS2D blocks whose fused output yields a per-stage
    change feature (A minus B), upsampled and concatenated across stages.

    Args:
        channels: per-stage channel counts of the backbone features;
            indexable with at least NUM_STAGES entries
            (e.g. a list/tuple of 4 ints).
    """

    # 4 backbone stages, each refined by 5 SS2D blocks (20 total).
    NUM_STAGES = 4
    DEPTHS_PER_STAGE = 5

    def __init__(self, channels):
        super().__init__()
        self.backbone = SeaFormer_L(pretrained=True)
        self.channels = channels

        # One SS2D per (stage, depth) pair. NOTE: the original hard-coded
        # .cuda() here, which crashes on CPU-only hosts and bypasses the
        # caller's device choice — device placement is now left to the
        # standard model.to(device) call.
        total = self.NUM_STAGES * self.DEPTHS_PER_STAGE
        self.css = nn.ModuleList(
            SS2D(
                d_model=self.channels[i // self.DEPTHS_PER_STAGE],
                channel_first=True,
                stage_num=i // self.DEPTHS_PER_STAGE,
                depth_num=i % self.DEPTHS_PER_STAGE,
            )
            for i in range(total)
        )

        # Per-stage 1x1 projection + GroupNorm over the concatenated
        # (A, B) feature pair (hence in_channels * 2).
        self.input_proj = nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(self.channels[i] * 2, self.channels[i] * 2, kernel_size=1),
                nn.GroupNorm(32, self.channels[i] * 2),
            )
            for i in range(self.NUM_STAGES)
        )
        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)

    def forward(self, xA, xB):
        """Run both temporal images through the backbone stage by stage.

        Args:
            xA, xB: the two co-registered input images (bi-temporal pair).

        Returns:
            Concatenation over stages of the upsampled per-stage change
            features (cdaA - cdaB), all at the stage-0 spatial resolution.
            Assumes each backbone stage halves spatial size — TODO confirm.
        """
        inA, inB = xA, xB
        css_out = []
        for i in range(self.NUM_STAGES):
            fA = self.backbone(inA, i)
            fB = self.backbone(inB, i)
            f = torch.concat([fA, fB], 1)

            # Sum the outputs of this stage's SS2D stack, then project.
            base = i * self.DEPTHS_PER_STAGE
            fused = sum(self.css[base + d](f) for d in range(self.DEPTHS_PER_STAGE))
            f = self.input_proj[i](fused)

            # Split back into the A/B halves; their difference is the
            # stage's change feature, and each half residually refines
            # the input to the next backbone stage.
            cdaA, cdaB = torch.split(f, self.channels[i], 1)
            css_out.append(cdaA - cdaB)
            inA, inB = fA + cdaA, fB + cdaB

        # Upsample deeper (smaller) stage outputs to the stage-0 size
        # before channel-wise concatenation.
        for i in range(1, self.NUM_STAGES):
            css_out[i] = F.interpolate(
                css_out[i],
                scale_factor=(2 ** i, 2 ** i),
                mode="bilinear",
                align_corners=False,
            )
        return torch.concat(css_out, dim=1)