# Source: uploaded by InPeerReview ("Upload 161 files", commit 226675b, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from rscd.models.backbones.seaformer_vmanba import SeaFormer_L
from rscd.models.backbones.cdloma import SS2D
class cdlamba(nn.Module):
    """Multi-scale difference-feature extractor for an image pair (xA, xB).

    For each of ``NUM_STAGES`` stages:

    * run backbone stage ``stage`` on both inputs;
    * concatenate the two feature maps along channels;
    * pass the result through ``DEPTHS_PER_STAGE`` parallel SS2D branches and sum them;
    * project with a 1x1 conv + GroupNorm and split back into A/B halves;
    * record ``A - B`` as that stage's output.

    The projected halves are also added back to the backbone features as the
    residual input for the next stage. All per-stage differences are resized
    to stage 0's spatial size and concatenated along channels.

    Args:
        channels: per-stage channel counts; at least ``NUM_STAGES`` entries.
            Each ``2 * channels[i]`` must be divisible by 32 (GroupNorm groups).
        pretrained: forwarded to ``SeaFormer_L``. Defaults to ``True``, the
            previously hard-coded value.

    Device placement is left to the caller (``model.to(device)`` /
    ``model.cuda()``). All sub-modules are registered, so they move together.
    """

    NUM_STAGES = 4
    DEPTHS_PER_STAGE = 5

    def __init__(self, channels, pretrained=True):
        super().__init__()
        self.backbone = SeaFormer_L(pretrained=pretrained)
        self.channels = channels

        # Flattened stage-major: index = stage * DEPTHS_PER_STAGE + depth.
        # Modules are no longer pinned to CUDA here; moving the parent module
        # moves them, and construction now also works on CPU-only hosts.
        # NOTE(review): SS2D gets d_model=channels[stage] but is fed the
        # 2*channels[stage]-channel A/B concat; presumably SS2D handles the
        # pair internally -- confirm against cdloma.SS2D.
        self.css = nn.ModuleList(
            SS2D(
                d_model=self.channels[stage],
                channel_first=True,
                stage_num=stage,
                depth_num=depth,
            )
            for stage in range(self.NUM_STAGES)
            for depth in range(self.DEPTHS_PER_STAGE)
        )

        # Build every projection before re-initialising any of them, so the
        # RNG draws (and thus seeded weights) match the original ordering.
        self.input_proj = nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(self.channels[stage] * 2, self.channels[stage] * 2, kernel_size=1),
                nn.GroupNorm(32, self.channels[stage] * 2),
            )
            for stage in range(self.NUM_STAGES)
        )
        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)

    @staticmethod
    def _fuse_scales(feats):
        """Resize every map to ``feats[0]``'s spatial size and concat on dim 1.

        Uses an explicit target ``size`` rather than a fixed power-of-two
        scale factor, so stage maps whose sizes are not exact multiples of
        each other (odd input sizes) still line up.
        """
        target = tuple(feats[0].shape[-2:])
        resized = [feats[0]]
        for feat in feats[1:]:
            resized.append(
                F.interpolate(feat, size=target, mode="bilinear", align_corners=False)
            )
        return torch.cat(resized, dim=1)

    def forward(self, xA, xB):
        """Return the concatenated multi-scale A-minus-B features.

        The output has ``sum(channels[:NUM_STAGES])`` channels at the spatial
        size of stage 0's backbone output.
        """
        inA, inB = xA, xB
        css_out = []
        for stage in range(self.NUM_STAGES):
            # Presumably runs backbone stage `stage` on the given input -- see SeaFormer_L.
            fA = self.backbone(inA, stage)
            fB = self.backbone(inB, stage)
            f = torch.cat([fA, fB], dim=1)

            # Sum the stage's SS2D branches, left to right as before.
            first = stage * self.DEPTHS_PER_STAGE
            agg = self.css[first](f)
            for depth in range(1, self.DEPTHS_PER_STAGE):
                agg = agg + self.css[first + depth](f)

            f = self.input_proj[stage](agg)
            cdaA, cdaB = torch.split(f, self.channels[stage], dim=1)
            css_out.append(cdaA - cdaB)
            # Residual update: the next stage consumes backbone features plus
            # this stage's projected half.
            inA, inB = fA + cdaA, fB + cdaB
        return self._fuse_scales(css_out)