import torch
import torch.nn as nn
import torch.nn.functional as F

from rscd.models.backbones.seaformer_vmanba import SeaFormer_L
from rscd.models.backbones.cdloma import SS2D


class cdlamba(nn.Module):
    """Two-input feature extractor built on a SeaFormer_L backbone and SS2D blocks.

    Both inputs are pushed through the same backbone one stage at a time.
    At every stage the two feature maps are concatenated along the channel
    axis, passed through five parallel SS2D blocks whose outputs are summed,
    projected by a 1x1 convolution + GroupNorm, and split back into two
    halves. The difference of the halves is collected as that stage's
    output, and each half is added to its own backbone feature to form the
    input of the next backbone stage.

    Args:
        channels: Indexable of per-stage channel counts; entries 0..3 are
            read. ``2 * channels[i]`` must be divisible by 32 because the
            projection uses ``nn.GroupNorm(32, 2 * channels[i])``.
        pretrained: Forwarded to ``SeaFormer_L(pretrained=...)``. Defaults
            to ``True``, which matches the previous hard-coded behaviour.
    """

    # Number of backbone stages processed in ``forward``.
    NUM_STAGES = 4
    # Number of parallel SS2D blocks applied at every stage.
    BLOCKS_PER_STAGE = 5

    def __init__(self, channels, pretrained=True):
        super().__init__()
        self.backbone = SeaFormer_L(pretrained=pretrained)
        self.channels = channels

        # Flat list of SS2D blocks; block ``stage * BLOCKS_PER_STAGE + depth``
        # belongs to ``stage`` and is its ``depth``-th parallel branch.
        self.css = nn.ModuleList()
        cuda_available = torch.cuda.is_available()
        for flat_idx in range(self.NUM_STAGES * self.BLOCKS_PER_STAGE):
            stage, depth = divmod(flat_idx, self.BLOCKS_PER_STAGE)
            block = SS2D(
                d_model=self.channels[stage],
                channel_first=True,
                stage_num=stage,
                depth_num=depth,
            )
            # The blocks were previously moved to CUDA unconditionally, which
            # made the constructor raise on hosts without a GPU. Keep the
            # eager CUDA placement when a device exists and otherwise leave
            # the block where it was created.
            if cuda_available:
                block = block.cuda()
            self.css.append(block)

        # One projection per stage, operating on the fused tensor that
        # carries both halves (hence ``2 * channels[stage]`` channels).
        input_proj_list = []
        for stage in range(self.NUM_STAGES):
            fused_channels = self.channels[stage] * 2
            input_proj_list.append(
                nn.Sequential(
                    nn.Conv2d(fused_channels, fused_channels, kernel_size=1),
                    nn.GroupNorm(32, fused_channels),
                )
            )
        self.input_proj = nn.ModuleList(input_proj_list)

        # Re-initialise the 1x1 convolutions in a second pass (after every
        # projection has been constructed) so the order in which random
        # numbers are drawn stays the same as before.
        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)

    def forward(self, xA, xB):
        """Return the concatenated per-stage difference features of two inputs.

        Args:
            xA: First input tensor, passed to backbone stage 0.
            xB: Second input tensor, passed to backbone stage 0.

        Returns:
            The four per-stage difference maps concatenated along dim 1.
            Stages 1..3 are bilinearly upsampled by 2, 4 and 8 before the
            concatenation.
            NOTE(review): this presumes every backbone stage halves the
            spatial size of the previous one, so all maps line up with
            stage 0 -- confirm against SeaFormer_L.
        """
        inA, inB = xA, xB
        stage_diffs = []
        for stage in range(self.NUM_STAGES):
            fA = self.backbone(inA, stage)
            fB = self.backbone(inB, stage)
            # ``torch.cat`` is the canonical name of ``torch.concat`` and is
            # available on every PyTorch version.
            fused = torch.cat([fA, fB], dim=1)

            # Run all parallel branches on the same fused tensor first, then
            # add them left to right (same order as f1 + f2 + ... + f5).
            first = stage * self.BLOCKS_PER_STAGE
            branch_outs = [
                self.css[first + depth](fused)
                for depth in range(self.BLOCKS_PER_STAGE)
            ]
            summed = branch_outs[0]
            for branch in branch_outs[1:]:
                summed = summed + branch

            projected = self.input_proj[stage](summed)
            # The projection output has 2 * channels[stage] channels, so
            # splitting in chunks of channels[stage] yields exactly two parts.
            cdaA, cdaB = torch.split(projected, self.channels[stage], 1)
            stage_diffs.append(cdaA - cdaB)

            # Residual update: the next backbone stage consumes the current
            # stage's feature plus its half of the projected tensor.
            inA, inB = fA + cdaA, fB + cdaB

        # Stage 0 is kept as is; later stages are upsampled by 2 ** stage.
        aligned = [stage_diffs[0]]
        for stage in range(1, self.NUM_STAGES):
            factor = 2 ** stage
            aligned.append(
                F.interpolate(
                    stage_diffs[stage],
                    scale_factor=(factor, factor),
                    mode="bilinear",
                    align_corners=False,
                )
            )

        return torch.cat(aligned, dim=1)