import torch
from torch import nn
from torch.cuda.amp import autocast
from rscd.models.decoderheads.vision_lstm import ViLBlock, SequenceTraversal
from torch.nn import functional as F
from functools import partial
from rscd.models.backbones.lib_mamba.vmambanew import SS2D
import pywt


class PA(nn.Module):
    """Pixel attention: gate ``x`` element-wise with a sigmoid of a 1x1-conv bottleneck.

    Args:
        dim: number of input/output channels.
        norm_layer: normalization layer constructor, called with ``dim * 4``.
        act_layer: activation layer constructor, called with no arguments.
    """

    def __init__(self, dim, norm_layer, act_layer):
        super().__init__()
        # 1x1 expand (x4) -> norm -> act -> 1x1 project back to ``dim`` channels.
        self.p_conv = nn.Sequential(
            nn.Conv2d(dim, dim * 4, 1, bias=False),
            norm_layer(dim * 4),
            act_layer(),
            nn.Conv2d(dim * 4, dim, 1, bias=False),
        )
        self.gate_fn = nn.Sigmoid()

    def forward(self, x):
        att = self.p_conv(x)
        x = x * self.gate_fn(att)
        return x


class Mish(nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x * torch.tanh(F.softplus(x))


class _ScaleModule(nn.Module):
    """Multiply the input by a learnable tensor of shape ``dims`` (broadcast element-wise).

    Args:
        dims: shape of the learnable ``weight`` (e.g. ``[1, C, 1, 1]``).
        init_scale: initial value of every entry of ``weight``.
    """

    def __init__(self, dims, init_scale=1.0):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(*dims) * init_scale)

    def forward(self, x):
        return torch.mul(self.weight, x)


def create_wavelet_filter(wave, in_size, out_size, dtype=torch.float):
    """Build 2-D separable wavelet decomposition and reconstruction filter banks.

    Args:
        wave: wavelet name accepted by ``pywt.Wavelet`` (e.g. ``'db1'``).
        in_size: number of channels the decomposition filters are tiled for.
        out_size: number of channels the reconstruction filters are tiled for.
        dtype: dtype of the returned tensors.

    Returns:
        ``(dec_filters, rec_filters)`` of shapes ``(4 * in_size, 1, k, k)`` and
        ``(4 * out_size, 1, k, k)`` where ``k`` is the 1-D filter length. Each
        consecutive group of 4 filters is (lo*lo, lo/hi, hi/lo, hi*hi) outer products,
        i.e. the low-pass band comes first.
    """
    w = pywt.Wavelet(wave)
    dec_hi = torch.tensor(w.dec_hi[::-1], dtype=dtype)
    dec_lo = torch.tensor(w.dec_lo[::-1], dtype=dtype)
    dec_filters = torch.stack([dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1),
                               dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1),
                               dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1),
                               dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)], dim=0)
    # (4, k, k) -> (4, 1, k, k) -> tiled once per channel for grouped convolution.
    dec_filters = dec_filters[:, None].repeat(in_size, 1, 1, 1)

    rec_hi = torch.tensor(w.rec_hi[::-1], dtype=dtype).flip(dims=[0])
    rec_lo = torch.tensor(w.rec_lo[::-1], dtype=dtype).flip(dims=[0])
    rec_filters = torch.stack([rec_lo.unsqueeze(0) * rec_lo.unsqueeze(1),
                               rec_lo.unsqueeze(0) * rec_hi.unsqueeze(1),
                               rec_hi.unsqueeze(0) * rec_lo.unsqueeze(1),
                               rec_hi.unsqueeze(0) * rec_hi.unsqueeze(1)], dim=0)
    rec_filters = rec_filters[:, None].repeat(out_size, 1, 1, 1)

    return dec_filters, rec_filters


def wavelet_transform(x, filters):
    """Single-level 2-D wavelet decomposition as a stride-2 grouped convolution.

    Args:
        x: tensor of shape ``(b, c, h, w)``. The reshape below assumes the conv
            output is ``h // 2 x w // 2``; ``MBWTConv2d`` pads odd sizes to even first.
        filters: decomposition filters from :func:`create_wavelet_filter` (``4 * c`` of them).

    Returns:
        Tensor of shape ``(b, c, 4, h // 2, w // 2)``: four sub-bands per channel.
    """
    b, c, h, w = x.shape
    pad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)
    x = F.conv2d(x, filters, stride=2, groups=c, padding=pad)
    x = x.reshape(b, c, 4, h // 2, w // 2)
    return x


def inverse_wavelet_transform(x, filters):
    """Inverse of :func:`wavelet_transform` via a stride-2 grouped transposed convolution.

    Args:
        x: tensor of shape ``(b, c, 4, h_half, w_half)``.
        filters: reconstruction filters from :func:`create_wavelet_filter`.

    Returns:
        Tensor of shape ``(b, c, H, W)`` with ``H``/``W`` determined by the transposed conv.
    """
    b, c, _, h_half, w_half = x.shape
    pad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)
    x = x.reshape(b, c * 4, h_half, w_half)
    x = F.conv_transpose2d(x, filters, stride=2, groups=c, padding=pad)
    return x


class MBWTConv2d(nn.Module):
    """Wavelet-domain depthwise convolution plus a scaled ``SS2D`` global branch.

    The output is ``base_scale(SS2D(x)) + WT^-1(scaled depthwise conv of WT(x))``,
    using ``wt_levels`` decomposition levels. Input and output have ``in_channels`` channels.

    Args:
        in_channels: number of channels.
        kernel_size: kernel size of the depthwise convs applied to the sub-bands.
        wt_levels: number of wavelet decomposition levels.
        wt_type: wavelet name passed to ``pywt``.
        ssm_ratio: forwarded to ``SS2D``.
        forward_type: forwarded to ``SS2D``.
    """

    def __init__(self, in_channels, kernel_size=5, wt_levels=1, wt_type='db1', ssm_ratio=1, forward_type="v05"):
        super().__init__()
        self.wt_levels = wt_levels
        self.wt_filter, self.iwt_filter = create_wavelet_filter(wt_type, in_channels, in_channels)
        # Fixed (non-trainable) filters, registered as parameters so they follow .to()/state_dict.
        self.wt_filter = nn.Parameter(self.wt_filter, requires_grad=False)
        self.iwt_filter = nn.Parameter(self.iwt_filter, requires_grad=False)

        self.wt_function = partial(wavelet_transform, filters=self.wt_filter)
        self.iwt_function = partial(inverse_wavelet_transform, filters=self.iwt_filter)

        self.global_atten = SS2D(d_model=in_channels, d_state=1, ssm_ratio=ssm_ratio, initialize="v2",
                                 forward_type=forward_type, channel_first=True, k_group=2)
        self.base_scale = _ScaleModule([1, in_channels, 1, 1])

        # One depthwise conv over all 4 sub-bands (4 * in_channels groups) per level.
        self.wavelet_convs = nn.ModuleList([
            nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size, padding='same', groups=in_channels * 4)
            for _ in range(wt_levels)
        ])
        self.wavelet_scale = nn.ModuleList([
            _ScaleModule([1, in_channels * 4, 1, 1], init_scale=0.1)
            for _ in range(wt_levels)
        ])

    def forward(self, x):
        x_ll_in_levels, x_h_in_levels, shapes_in_levels = [], [], []

        # Analysis: repeatedly decompose the (unprocessed) low-pass band.
        curr_x_ll = x
        for i in range(self.wt_levels):
            curr_shape = curr_x_ll.shape
            shapes_in_levels.append(curr_shape)
            if (curr_shape[2] % 2 > 0) or (curr_shape[3] % 2 > 0):
                # Pad right/bottom by one so the stride-2 transform sees even sizes.
                curr_x_ll = F.pad(curr_x_ll, (0, curr_shape[3] % 2, 0, curr_shape[2] % 2))

            curr_x = self.wt_function(curr_x_ll)
            curr_x_ll = curr_x[:, :, 0, :, :]

            shape_x = curr_x.shape
            curr_x_tag = curr_x.reshape(shape_x[0], shape_x[1] * 4, shape_x[3], shape_x[4])
            curr_x_tag = self.wavelet_scale[i](self.wavelet_convs[i](curr_x_tag)).reshape(shape_x)

            x_ll_in_levels.append(curr_x_tag[:, :, 0, :, :])
            x_h_in_levels.append(curr_x_tag[:, :, 1:4, :, :])

        # Synthesis: walk levels coarse -> fine, adding the reconstructed low band back in.
        next_x_ll = 0
        for i in range(self.wt_levels - 1, -1, -1):
            curr_x_ll = x_ll_in_levels.pop() + next_x_ll
            curr_x = torch.cat([curr_x_ll.unsqueeze(2), x_h_in_levels.pop()], dim=2)
            next_x_ll = self.iwt_function(curr_x)
            # Crop away any padding added during analysis.
            next_x_ll = next_x_ll[:, :, :shapes_in_levels[i][2], :shapes_in_levels[i][3]]

        x_tag = next_x_ll
        x = self.base_scale(self.global_atten(x)) + x_tag
        return x


class ChannelAttention(nn.Module):
    """CBAM channel attention; returns a sigmoid map of shape ``(B, C, 1, 1)``.

    Average- and max-pooled descriptors pass through a shared 1x1-conv bottleneck
    of width ``in_planes // ratio`` and are summed before the sigmoid.
    NOTE(review): ``in_planes < ratio`` gives a zero-width bottleneck.
    """

    def __init__(self, in_planes, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    """CBAM spatial attention; returns a sigmoid map of shape ``(B, 1, H, W)``.

    Channel-wise mean and max maps are concatenated and convolved to one channel.

    Args:
        kernel_size: 3 or 7; padding keeps the spatial size unchanged.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)


class CBAM(nn.Module):
    """CBAM attention: channel attention followed by spatial attention, applied multiplicatively."""

    def __init__(self, in_planes):
        super().__init__()
        self.ca = ChannelAttention(in_planes)
        self.sa = SpatialAttention()

    def forward(self, x):
        x = self.ca(x) * x
        x = self.sa(x) * x
        return x


class DynamicConv2d(nn.Module):
    """Mixture of ``num_experts`` convolutions weighted by an input-dependent softmax gate.

    The gate is computed from globally average-pooled input (one weight per sample
    and expert). The constructor arguments match ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, bias=False,
                 num_experts=4):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.bias = bias
        self.num_experts = num_experts

        self.experts = nn.ModuleList([
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=bias)
            for _ in range(num_experts)
        ])
        self.gating = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, num_experts, 1, bias=False),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        # (B, E, 1, 1) -> (B, E, 1, 1, 1) to broadcast over each expert's (C, H, W) output.
        gates = self.gating(x).view(x.size(0), self.num_experts, 1, 1, 1)
        outputs = torch.stack([expert(x) for expert in self.experts], dim=1)  # (B, E, C, H, W)
        return (gates * outputs).sum(dim=1)


class DWConv2d_BN_ReLU(nn.Sequential):
    """Dynamic depthwise conv -> BN -> Mish -> grouped 1x1 conv -> BN.

    Despite the name, the activation is ``Mish``. The 1x1 conv uses
    ``groups=in_channels``, so ``out_channels`` must be a multiple of ``in_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.add_module('dwconv3x3', DynamicConv2d(in_channels, in_channels, kernel_size=kernel_size, stride=1,
                                                   padding=kernel_size // 2, groups=in_channels, bias=False))
        self.add_module('bn1', nn.BatchNorm2d(in_channels))
        self.add_module('relu', Mish())
        self.add_module('dwconv1x1', nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0,
                                               groups=in_channels, bias=False))
        self.add_module('bn2', nn.BatchNorm2d(out_channels))


class Conv2d_BN(nn.Sequential):
    """Bias-free ``Conv2d(a, b, ks, stride, pad, groups)`` followed by ``BatchNorm2d(b)``."""

    def __init__(self, a, b, ks=1, stride=1, pad=0, groups=1):
        super().__init__()
        self.add_module('c', nn.Conv2d(a, b, ks, stride, pad, groups=groups, bias=False))
        self.add_module('bn', nn.BatchNorm2d(b))


class FFN(nn.Module):
    """Pointwise feed-forward: 1x1 Conv+BN (ed -> h) -> Mish -> 1x1 Conv+BN (h -> ed)."""

    def __init__(self, ed, h):
        super().__init__()
        self.pw1 = Conv2d_BN(ed, h)
        self.act = Mish()
        self.pw2 = Conv2d_BN(h, ed)

    def forward(self, x):
        return self.pw2(self.act(self.pw1(x)))


class StochasticDepth(nn.Module):
    """Drop whole samples with probability ``1 - survival_prob`` during training.

    Kept samples are rescaled by ``1 / survival_prob``; identity in eval mode.
    The mask shape ``(B, 1, 1, 1)`` assumes a 4-D input.
    """

    def __init__(self, survival_prob=0.8):
        super().__init__()
        self.survival_prob = survival_prob

    def forward(self, x):
        if not self.training:
            return x
        batch_size = x.shape[0]
        # floor(p + U[0, 1)) is 1 with probability p, else 0.
        random_tensor = self.survival_prob + torch.rand([batch_size, 1, 1, 1], dtype=x.dtype, device=x.device)
        binary_tensor = torch.floor(random_tensor)
        return x * binary_tensor / self.survival_prob


class Residual(nn.Module):
    """Residual wrapper ``x + StochasticDepth(m(x))``."""

    def __init__(self, m, survival_prob=0.8):
        super().__init__()
        self.m = m
        self.stochastic_depth = StochasticDepth(survival_prob)

    def forward(self, x):
        return x + self.stochastic_depth(self.m(x))


class GLP_block(nn.Module):
    """Split channels into global / local / identity / pixel-attention groups and re-mix.

    Channels are split (in this order) into ``int(global_ratio * dim)`` channels for
    ``MBWTConv2d``, ``int(local_ratio * dim)`` channels for three multi-kernel
    ``DWConv2d_BN_ReLU`` branches (k = 3, 5, 7), the remaining identity channels,
    and ``int(pa_ratio * dim)`` channels for ``PA``. Empty groups use ``nn.Identity``.
    The results are concatenated and passed through ``proj`` (Mish -> 1x1 Conv+BN -> CBAM).

    ``self.cbam`` is not used in ``forward``; it is kept so the state_dict keys are unchanged.
    """

    def __init__(self, dim, global_ratio=0.25, local_ratio=0.25, pa_ratio=0.1, kernels=3, ssm_ratio=1,
                 forward_type="v052d"):
        super().__init__()
        self.dim = dim
        self.global_channels = int(global_ratio * dim)
        self.local_channels = int(local_ratio * dim)
        self.pa_channels = int(pa_ratio * dim)
        self.identity_channels = dim - self.global_channels - self.local_channels - self.pa_channels

        self.local_op = nn.ModuleList([
            DWConv2d_BN_ReLU(self.local_channels, self.local_channels, k) for k in [3, 5, 7]
        ]) if self.local_channels > 0 else nn.Identity()

        self.global_op = MBWTConv2d(self.global_channels, kernel_size=kernels, ssm_ratio=ssm_ratio,
                                    forward_type=forward_type) \
            if self.global_channels > 0 else nn.Identity()

        self.cbam = CBAM(dim)
        self.proj = nn.Sequential(
            Mish(),
            Conv2d_BN(dim, dim),
            CBAM(dim),
        )
        self.pa_op = PA(self.pa_channels, norm_layer=nn.BatchNorm2d, act_layer=nn.GELU) \
            if self.pa_channels > 0 else nn.Identity()

    def forward(self, x):
        x1, x2, x3, x4 = torch.split(
            x, [self.global_channels, self.local_channels, self.identity_channels, self.pa_channels], dim=1)

        if isinstance(self.local_op, nn.ModuleList):
            local_features = torch.cat([op(x2) for op in self.local_op], dim=1)
            # NOTE(review): this averages over all 3 * local_channels channels, collapsing
            # them into one map that is then broadcast back to local_channels. If the intent
            # was to average the three kernel branches per channel, that would be
            # torch.stack(branches).mean(0). Left unchanged to preserve trained behaviour.
            local_features = torch.mean(local_features, dim=1, keepdim=True)
            local_features = local_features.expand(-1, self.local_channels, -1, -1)
        else:
            local_features = self.local_op(x2)

        out = torch.cat([self.global_op(x1), local_features, x3, self.pa_op(x4)], dim=1)
        return self.proj(out)


class SASF(GLP_block):
    """Same architecture as :class:`GLP_block`; a distinct name used for the difference-stream mixer.

    The implementation was a verbatim copy of ``GLP_block``, so it is inherited unchanged
    (same constructor signature, defaults, submodule names and state_dict keys).
    """


class ViLLayer(nn.Module):
    """Apply a ``ViLBlock`` over the flattened spatial tokens of a ``(B, C, *spatial)`` tensor.

    Runs with autocast disabled and upcasts float16 input to float32. ``d_state``,
    ``d_conv`` and ``expand`` are accepted but unused; ``self.norm`` is not used in ``forward``.
    """

    def __init__(self, dim, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.dim = dim
        self.norm = nn.LayerNorm(dim)
        self.vil = ViLBlock(
            dim=self.dim,
            direction=SequenceTraversal.ROWWISE_FROM_TOP_LEFT,
        )

    @autocast(enabled=False)
    def forward(self, x):
        if x.dtype == torch.float16:
            x = x.type(torch.float32)
        B, C = x.shape[:2]
        assert C == self.dim
        n_tokens = x.shape[2:].numel()
        img_dims = x.shape[2:]
        x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2)  # (B, N, C)
        x_vil = self.vil(x_flat)
        out = x_vil.transpose(-1, -2).reshape(B, C, *img_dims)
        return out


def dsconv_3x3(in_channel, out_channel):
    """Depthwise-separable 3x3 conv: depthwise 3x3 -> pointwise 1x1 -> BN -> ReLU."""
    return nn.Sequential(
        nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1, groups=in_channel),
        nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, groups=1),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(inplace=True),
    )


def conv_1x1(in_channel, out_channel):
    """Bias-free 1x1 conv -> BN -> ReLU."""
    return nn.Sequential(
        nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(inplace=True),
    )


class SqueezeAxialPositionalEmbedding(nn.Module):
    """Add a learnable 1-D positional embedding, linearly resized to the input length.

    Input shape ``(B, C, N)``; ``pos_embed`` has shape ``(1, dim, shape)``.
    """

    def __init__(self, dim, shape):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.randn([1, dim, shape]))

    def forward(self, x):
        B, C, N = x.shape
        x = x + F.interpolate(self.pos_embed, size=N, mode='linear', align_corners=False)
        return x


class SEBlock(nn.Module):
    """Squeeze-and-excitation: rescale channels by a sigmoid gate from pooled features."""

    def __init__(self, channels, r=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // r, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // r, channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        w = self.fc(x)  # (B, C, 1, 1)
        return x * w


class CTTF1(nn.Module):
    """Fuse a bi-temporal feature pair into a single change feature with ``in_channel`` channels.

    Both inputs pass through the same (weight-shared) ``GLP_block`` mixer. The signed
    difference, absolute difference and sum each go through a shared ``SASF`` mixer,
    are concatenated to ``3 * in_channel`` channels, re-weighted by CBAM and reduced
    with a 1x1 conv + ReLU.

    ``catconvA``, ``catconvB``, ``catconv``, ``convA``, ``convB``, ``sigmoid`` and ``act``
    are not used in ``forward`` (``out_channel`` only sizes ``catconv``); they are kept so
    the state_dict keys are unchanged.
    """

    def __init__(self, in_channel, out_channel, global_ratio=0.2, local_ratio=0.2, pa_ratio=0.2, kernels=5,
                 ssm_ratio=2.0, forward_type="v052d"):
        super().__init__()
        self.catconvA = dsconv_3x3(in_channel * 2, in_channel)
        self.catconvB = dsconv_3x3(in_channel * 2, in_channel)
        self.catconv = dsconv_3x3(in_channel * 2, out_channel)
        self.convA = nn.Conv2d(in_channel, 1, 1)
        self.convB = nn.Conv2d(in_channel, 1, 1)
        self.sigmoid = nn.Sigmoid()
        self.mixer = Residual(
            GLP_block(in_channel, global_ratio, local_ratio, pa_ratio, kernels, ssm_ratio, forward_type))
        self.mixer2 = Residual(
            SASF(in_channel, global_ratio=0, local_ratio=0.1, pa_ratio=0, kernels=5, ssm_ratio=1,
                 forward_type="v052d"))
        self.fuse = nn.Sequential(
            nn.Conv2d(in_channel * 3, in_channel, kernel_size=1),
            nn.ReLU(inplace=True),
        )
        self.cbam = CBAM(in_channel * 3)
        self.act = nn.SiLU()

    def forward(self, xA, xB):
        f1 = self.mixer(xA)
        f2 = self.mixer(xB)

        diff_signed = f1 - f2
        diff_abs = torch.abs(diff_signed)
        sum_feat = f1 + f2

        diff_signed = self.mixer2(diff_signed)
        diff_abs = self.mixer2(diff_abs)
        sum_feat = self.mixer2(sum_feat)

        # Concatenate the three branches along the channel dimension.
        f_fuse = torch.cat([diff_signed, diff_abs, sum_feat], dim=1)  # (B, 3C, H, W)
        # CBAM re-weighting, then a 1x1 conv back to C channels.
        f_fuse = self.cbam(f_fuse)
        x_diff = self.fuse(f_fuse)
        return x_diff


class CTTF2(CTTF1):
    """:class:`CTTF1` with different default mixer hyper-parameters (used for deeper stages).

    The architecture and ``forward`` are identical to ``CTTF1``; only the defaults differ.
    """

    def __init__(self, in_channel, out_channel, global_ratio=0.25, local_ratio=0.25, pa_ratio=0, kernels=7,
                 ssm_ratio=2.0, forward_type="v052d"):
        super().__init__(in_channel, out_channel, global_ratio, local_ratio, pa_ratio, kernels, ssm_ratio,
                         forward_type)


class Mlp(nn.Module):
    """Two-layer MLP (fc -> act -> dropout -> fc -> dropout).

    With ``channels_first=True`` the layers are 1x1 ``Conv2d`` (input ``(B, C, H, W)``);
    otherwise ``nn.Linear`` over the last dimension.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,
                 channels_first=True):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        linear_layer = partial(nn.Conv2d, kernel_size=1, padding=0) if channels_first else nn.Linear
        self.fc1 = linear_layer(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = linear_layer(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class LHBlock(nn.Module):
    """Inject a low-resolution feature ``x_l`` into a high-resolution feature ``x_h``.

    ``x_l`` is projected and upsampled into ``x_h``. ``x_l`` is also max-pooled to a
    ``cross_size x cross_size`` grid to produce per-sample keys/values. Every ``x_h``
    position then attends over those ``cross_size ** 2`` slots, implemented with grouped
    1x1 convolutions (``groups=B``).
    """

    def __init__(self, channels_l, channels_h):
        super().__init__()
        self.channels_l = channels_l
        self.channels_h = channels_h
        self.cross_size = 12
        self.cross_kv = nn.Sequential(
            nn.BatchNorm2d(channels_l),
            nn.AdaptiveMaxPool2d(output_size=(self.cross_size, self.cross_size)),
            nn.Conv2d(channels_l, 2 * channels_h, 1, 1, 0),
        )

        self.conv = conv_1x1(channels_l, channels_h)
        self.norm = nn.BatchNorm2d(channels_h)

        self.mlp_l = Mlp(in_features=channels_l, out_features=channels_l)
        self.mlp_h = Mlp(in_features=channels_h, out_features=channels_h)

    def _act_sn(self, x):
        """Scaled softmax over the ``cross_size ** 2`` slots of each sample; returns ``(1, B * S*S, H, W)``."""
        _, _, H, W = x.shape
        inner_channel = self.cross_size * self.cross_size
        x = x.reshape([-1, inner_channel, H, W]) * (inner_channel ** -0.5)
        x = F.softmax(x, dim=1)
        x = x.reshape([1, -1, H, W])
        return x

    def attn_h(self, x_h, cross_k, cross_v):
        """Per-sample cross attention from ``x_h`` positions to the pooled key/value slots."""
        B, _, H, W = x_h.shape
        x_h = self.norm(x_h)
        x_h = x_h.reshape([1, -1, H, W])  # (B, C_h, H, W) -> (1, B*C_h, H, W)
        x_h = F.conv2d(x_h, cross_k, bias=None, stride=1, padding=0, groups=B)  # -> (1, B*S*S, H, W)
        x_h = self._act_sn(x_h)
        x_h = F.conv2d(x_h, cross_v, bias=None, stride=1, padding=0, groups=B)  # -> (1, B*C_h, H, W)
        x_h = x_h.reshape([-1, self.channels_h, H, W])  # -> (B, C_h, H, W)
        return x_h

    def forward(self, x_l, x_h):
        x_l = x_l + self.mlp_l(x_l)
        x_l_conv = self.conv(x_l)
        x_h = x_h + F.interpolate(x_l_conv, size=x_h.shape[2:], mode='bilinear', align_corners=False)

        cross_kv = self.cross_kv(x_l)
        cross_k, cross_v = cross_kv.split(self.channels_h, 1)
        cross_k = cross_k.permute(0, 2, 3, 1).reshape([-1, self.channels_h, 1, 1])  # (B*S*S, C_h, 1, 1)
        cross_v = cross_v.reshape([-1, self.cross_size * self.cross_size, 1, 1])  # (B*C_h, S*S, 1, 1)

        x_h = x_h + self.attn_h(x_h, cross_k, cross_v)
        x_h = x_h + self.mlp_h(x_h)
        return x_h


class CTTF(nn.Module):
    """Change-detection decoder over four bi-temporal feature stages.

    Each stage pair is fused (``CTTF1`` for stages 0-1, ``CTTF2`` for stages 2-3).
    Deeper fused features are injected into the stage-0 feature through three
    ``LHBlock``s. A 2-channel map is predicted and bilinearly upsampled x4.

    Args:
        channels: channel counts of the four stages (any length-4 sequence).
    """

    def __init__(self, channels=(40, 80, 192, 384)):
        super().__init__()
        # Tuple default avoids a shared mutable default; stored as a list for callers.
        self.channels = list(channels)
        c = self.channels
        self.fusion0 = CTTF1(c[0], c[0])
        self.fusion1 = CTTF1(c[1], c[1])
        self.fusion2 = CTTF2(c[2], c[2])
        self.fusion3 = CTTF2(c[3], c[3])

        self.LHBlock1 = LHBlock(c[1], c[0])
        self.LHBlock2 = LHBlock(c[2], c[0])
        self.LHBlock3 = LHBlock(c[3], c[0])

        self.mlp1 = Mlp(in_features=c[0], out_features=c[0])
        self.mlp2 = Mlp(in_features=c[0], out_features=2)
        self.dwc = dsconv_3x3(c[0], c[0])

    def forward(self, inputs):
        """``inputs`` is ``(featuresA, featuresB)``, each an indexable sequence of 4 stage tensors."""
        featuresA, featuresB = inputs

        x_diff_0 = self.fusion0(featuresA[0], featuresB[0])  # e.g. [4, 40, 128, 128]
        x_diff_1 = self.fusion1(featuresA[1], featuresB[1])  # e.g. [4, 80, 64, 64]
        x_diff_2 = self.fusion2(featuresA[2], featuresB[2])  # e.g. [4, 192, 32, 32]
        x_diff_3 = self.fusion3(featuresA[3], featuresB[3])  # e.g. [4, 384, 16, 16]

        x_h = x_diff_0
        x_h = self.LHBlock1(x_diff_1, x_h)  # keeps stage-0 shape, e.g. [4, 40, 128, 128]
        x_h = self.LHBlock2(x_diff_2, x_h)
        x_h = self.LHBlock3(x_diff_3, x_h)

        out = self.mlp2(self.dwc(x_h) + self.mlp1(x_h))
        out = F.interpolate(
            out,
            scale_factor=(4, 4),
            mode="bilinear",
            align_corners=False,
        )
        return out