|
|
import torch
|
|
|
from torch import nn
|
|
|
from torch.cuda.amp import autocast
|
|
|
from rscd.models.decoderheads.vision_lstm import ViLBlock, SequenceTraversal
|
|
|
from torch.nn import functional as F
|
|
|
from functools import partial
|
|
|
from rscd.models.backbones.lib_mamba.vmambanew import SS2D
|
|
|
import pywt
|
|
|
|
|
|
class PA(nn.Module):
    """Pixel attention: gate every element of ``x`` with a learned sigmoid map.

    A 1x1 bottleneck (expand to ``4 * dim`` channels, normalize, activate,
    project back to ``dim``) produces per-element logits; the input is
    multiplied by their sigmoid.

    Args:
        dim: number of input/output channels.
        norm_layer: callable building a normalization layer from a channel
            count (e.g. ``nn.BatchNorm2d``).
        act_layer: zero-argument callable building the activation module.
    """

    def __init__(self, dim, norm_layer, act_layer):
        super().__init__()
        hidden = dim * 4
        # Submodule names and order form the state-dict keys; keep them stable.
        self.p_conv = nn.Sequential(
            nn.Conv2d(dim, hidden, 1, bias=False),
            norm_layer(hidden),
            act_layer(),
            nn.Conv2d(hidden, dim, 1, bias=False),
        )
        self.gate_fn = nn.Sigmoid()

    def forward(self, x):
        """Return ``x * sigmoid(p_conv(x))``, same shape as ``x``."""
        gate = self.gate_fn(self.p_conv(x))
        return gate * x
|
|
|
|
|
|
class Mish(nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``, applied element-wise."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return Mish(x); the output has the same shape as ``x``."""
        return x * F.softplus(x).tanh()
|
|
|
|
|
|
|
|
|
class _ScaleModule(nn.Module):
    """Learnable element-wise scale, broadcast against the input.

    Args:
        dims: shape of the weight tensor, e.g. ``[1, C, 1, 1]`` for a
            per-channel scale on NCHW input.
        init_scale: initial value of every weight element.
    """

    def __init__(self, dims, init_scale=1.0):
        super().__init__()
        initial = torch.ones(*dims).mul(init_scale)
        self.weight = nn.Parameter(initial)

    def forward(self, x):
        """Return ``weight * x`` (broadcasting)."""
        return self.weight * x
|
|
|
|
|
|
def create_wavelet_filter(wave, in_size, out_size, dtype=torch.float):
    """Build depthwise 2-D analysis and synthesis filter banks for a wavelet.

    For each 1-D (low, high) pair, four separable 2-D kernels are formed as
    outer products ``a.unsqueeze(0) * b.unsqueeze(1)`` for
    ``(a, b) in [(lo, lo), (lo, hi), (hi, lo), (hi, hi)]``.

    Args:
        wave: wavelet name understood by ``pywt.Wavelet`` (e.g. ``'db1'``).
        in_size: number of channels the decomposition bank is repeated for.
        out_size: number of channels the reconstruction bank is repeated for.
        dtype: dtype of the returned tensors.

    Returns:
        ``(dec_filters, rec_filters)`` of shapes ``(4 * in_size, 1, k, k)`` and
        ``(4 * out_size, 1, k, k)``, where ``k`` is the wavelet filter length.
    """
    wavelet = pywt.Wavelet(wave)

    def _bank(lo, hi, repeats):
        # Stack the four outer-product kernels and tile them once per channel.
        pairs = ((lo, lo), (lo, hi), (hi, lo), (hi, hi))
        kernels = torch.stack([a.unsqueeze(0) * b.unsqueeze(1) for a, b in pairs], dim=0)
        return kernels[:, None].repeat(repeats, 1, 1, 1)

    # Decomposition filters are time-reversed (presumably so that conv2d's
    # cross-correlation implements a true convolution).
    dec_lo = torch.tensor(wavelet.dec_lo[::-1], dtype=dtype)
    dec_hi = torch.tensor(wavelet.dec_hi[::-1], dtype=dtype)
    # Reconstruction filters are used in their natural (un-reversed) order.
    rec_lo = torch.tensor(wavelet.rec_lo, dtype=dtype)
    rec_hi = torch.tensor(wavelet.rec_hi, dtype=dtype)

    return _bank(dec_lo, dec_hi, in_size), _bank(rec_lo, rec_hi, out_size)
|
|
|
|
|
|
def wavelet_transform(x, filters):
    """Apply one level of a depthwise 2-D discrete wavelet transform.

    Args:
        x: input of shape ``(B, C, H, W)``; the final reshape assumes ``H`` and
            ``W`` are even (callers pad odd sizes first).
        filters: decomposition bank of shape ``(4 * C, 1, k, k)``.

    Returns:
        Tensor of shape ``(B, C, 4, H // 2, W // 2)`` with the four sub-bands
        of each channel.
    """
    batch, channels, height, width = x.shape
    pad_h = filters.shape[2] // 2 - 1
    pad_w = filters.shape[3] // 2 - 1
    coeffs = F.conv2d(x, filters, stride=2, groups=channels, padding=(pad_h, pad_w))
    return coeffs.reshape(batch, channels, 4, height // 2, width // 2)
|
|
|
|
|
|
def inverse_wavelet_transform(x, filters):
    """Invert one level of the depthwise wavelet transform.

    Args:
        x: sub-band tensor of shape ``(B, C, 4, h, w)``.
        filters: reconstruction bank of shape ``(4 * C, 1, k, k)``.

    Returns:
        Tensor of shape ``(B, C, 2 * h, 2 * w)`` for even-length filters.
    """
    batch, channels, _, half_h, half_w = x.shape
    padding = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)
    # Fold the sub-band axis into channels so each group sees its 4 bands.
    stacked = x.reshape(batch, channels * 4, half_h, half_w)
    return F.conv_transpose2d(stacked, filters, stride=2, groups=channels, padding=padding)
|
|
|
|
|
|
class MBWTConv2d(nn.Module):
    """Wavelet-domain depthwise convolution plus an SS2D global branch.

    The input is decomposed by ``wt_levels`` levels of a 2-D wavelet transform.
    At each level, the four sub-bands pass through a depthwise
    ``kernel_size`` convolution and a learnable scale. The result is then
    reconstructed coarse-to-fine and added to ``base_scale(global_atten(x))``.

    Args:
        in_channels: channel count of the input (and of the wavelet branch
            output).
        kernel_size: depthwise kernel size used in the wavelet domain.
        wt_levels: number of wavelet decomposition levels.
        wt_type: wavelet name passed to ``create_wavelet_filter``.
        ssm_ratio: forwarded to ``SS2D``.
        forward_type: forwarded to ``SS2D``.
    """

    def __init__(self, in_channels, kernel_size=5, wt_levels=1, wt_type='db1', ssm_ratio=1, forward_type="v05"):
        super().__init__()
        self.wt_levels = wt_levels

        wt_filter, iwt_filter = create_wavelet_filter(wt_type, in_channels, in_channels)
        # Fixed filter banks: registered as frozen parameters so they move with
        # the module and are saved in the state dict.
        self.wt_filter = nn.Parameter(wt_filter, requires_grad=False)
        self.iwt_filter = nn.Parameter(iwt_filter, requires_grad=False)
        self.wt_function = partial(wavelet_transform, filters=self.wt_filter)
        self.iwt_function = partial(inverse_wavelet_transform, filters=self.iwt_filter)

        self.global_atten = SS2D(d_model=in_channels, d_state=1, ssm_ratio=ssm_ratio, initialize="v2",
                                 forward_type=forward_type, channel_first=True, k_group=2)
        self.base_scale = _ScaleModule([1, in_channels, 1, 1])

        self.wavelet_convs = nn.ModuleList([
            nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size, padding='same', groups=in_channels * 4)
            for _ in range(wt_levels)
        ])
        self.wavelet_scale = nn.ModuleList([
            _ScaleModule([1, in_channels * 4, 1, 1], init_scale=0.1)
            for _ in range(wt_levels)
        ])

    def forward(self, x):
        """Return the wavelet-branch output plus the scaled SS2D output.

        Args:
            x: tensor of shape ``(B, in_channels, H, W)``.
        """
        x_ll_in_levels, x_h_in_levels, shapes_in_levels = [], [], []
        curr_x_ll = x

        for i in range(self.wt_levels):
            curr_shape = curr_x_ll.shape
            shapes_in_levels.append(curr_shape)
            # The transform halves H and W; pad odd sizes on the bottom/right.
            if (curr_shape[2] % 2 > 0) or (curr_shape[3] % 2 > 0):
                curr_x_ll = F.pad(curr_x_ll, (0, curr_shape[3] % 2, 0, curr_shape[2] % 2))
            curr_x = self.wt_function(curr_x_ll)
            # The LL band (before the depthwise conv) feeds the next level.
            curr_x_ll = curr_x[:, :, 0, :, :]

            shape_x = curr_x.shape
            # Fold the 4 sub-bands into channels for the depthwise conv.
            curr_x_tag = curr_x.reshape(shape_x[0], shape_x[1] * 4, shape_x[3], shape_x[4])
            curr_x_tag = self.wavelet_scale[i](self.wavelet_convs[i](curr_x_tag)).reshape(shape_x)
            x_ll_in_levels.append(curr_x_tag[:, :, 0, :, :])
            x_h_in_levels.append(curr_x_tag[:, :, 1:4, :, :])

        # Reconstruct coarse-to-fine, adding each coarser result into the LL band.
        next_x_ll = 0
        for i in range(self.wt_levels - 1, -1, -1):
            curr_x_ll = x_ll_in_levels.pop() + next_x_ll
            curr_x = torch.cat([curr_x_ll.unsqueeze(2), x_h_in_levels.pop()], dim=2)
            next_x_ll = self.iwt_function(curr_x)
            # Crop away the odd-size padding added during decomposition.
            next_x_ll = next_x_ll[:, :, :shapes_in_levels[i][2], :shapes_in_levels[i][3]]

        x_tag = next_x_ll
        return self.base_scale(self.global_atten(x)) + x_tag
|
|
|
|
|
|
class ChannelAttention(nn.Module):
    """CBAM channel attention: ``sigmoid(MLP(avgpool(x)) + MLP(maxpool(x)))``.

    Args:
        in_planes: number of input channels.
        ratio: reduction ratio of the shared 1x1-conv MLP. The hidden width is
            ``in_planes // ratio`` clamped to at least 1, so channel counts
            below ``ratio`` still get a non-empty bottleneck.

    The forward pass returns weights of shape ``(B, in_planes, 1, 1)``.
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # in_planes < ratio would otherwise produce a zero-width hidden layer.
        hidden = max(1, in_planes // ratio)
        self.fc1 = nn.Conv2d(in_planes, hidden, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        return self.sigmoid(avg_out + max_out)
|
|
|
|
|
|
|
|
|
class SpatialAttention(nn.Module):
    """CBAM spatial attention from the channel-wise mean and max maps.

    Args:
        kernel_size: size of the 2->1 channel convolution; must be 3 or 7.

    The forward pass returns a sigmoid map of shape ``(B, 1, H, W)``.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # "Same" padding for the two permitted odd kernel sizes.
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        pooled = torch.cat(
            [x.mean(dim=1, keepdim=True), x.max(dim=1, keepdim=True)[0]],
            dim=1,
        )
        return self.sigmoid(self.conv1(pooled))
|
|
|
|
|
|
|
|
|
class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel gating, then spatial gating.

    Args:
        in_planes: number of input channels.
    """

    def __init__(self, in_planes):
        super(CBAM, self).__init__()
        self.ca = ChannelAttention(in_planes)
        self.sa = SpatialAttention()

    def forward(self, x):
        """Return ``x`` re-weighted by channel attention, then by spatial attention."""
        channel_refined = x * self.ca(x)
        return channel_refined * self.sa(channel_refined)
|
|
|
|
|
|
class DynamicConv2d(nn.Module):
    """Mixture-of-experts convolution with input-dependent soft gating.

    ``num_experts`` parallel ``nn.Conv2d`` layers share one configuration. The
    gate (global average pool -> 1x1 conv -> softmax over experts) gives
    per-sample expert weights, and the output is the weighted sum of all
    expert outputs.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, groups, bias:
            forwarded to every expert ``nn.Conv2d``.
        num_experts: number of expert convolutions.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, bias=False, num_experts=4):
        super(DynamicConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.bias = bias
        self.num_experts = num_experts

        self.experts = nn.ModuleList(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=bias)
            for _ in range(num_experts)
        )
        self.gating = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, num_experts, 1, bias=False),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        # (B, E, 1, 1, 1): one softmax weight per sample and expert.
        weights = self.gating(x).view(x.size(0), self.num_experts, 1, 1, 1)
        # (B, E, C_out, H_out, W_out)
        stacked = torch.stack([expert(x) for expert in self.experts], dim=1)
        return (weights * stacked).sum(dim=1)
|
|
|
|
|
|
class DWConv2d_BN_ReLU(nn.Sequential):
    """Dynamic depthwise conv -> BN -> Mish -> grouped 1x1 conv -> BN.

    Despite the name, the activation is ``Mish``. The 1x1 conv uses
    ``groups=in_channels``, so ``out_channels`` must be a multiple of
    ``in_channels``.

    Args:
        in_channels: input channels.
        out_channels: output channels.
        kernel_size: kernel size of the dynamic depthwise conv (padded by
            ``kernel_size // 2``).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        # Layer names become state-dict keys; keep them stable.
        layers = (
            ('dwconv3x3', DynamicConv2d(in_channels, in_channels, kernel_size=kernel_size,
                                        stride=1, padding=kernel_size // 2,
                                        groups=in_channels, bias=False)),
            ('bn1', nn.BatchNorm2d(in_channels)),
            ('relu', Mish()),
            ('dwconv1x1', nn.Conv2d(in_channels, out_channels, kernel_size=1,
                                    stride=1, padding=0, groups=in_channels, bias=False)),
            ('bn2', nn.BatchNorm2d(out_channels)),
        )
        for name, layer in layers:
            self.add_module(name, layer)
|
|
|
|
|
|
class Conv2d_BN(nn.Sequential):
    """Bias-free ``nn.Conv2d`` followed by ``nn.BatchNorm2d``.

    Args:
        a: input channels.
        b: output channels.
        ks, stride, pad, groups: convolution kernel size, stride, padding and
            groups.
    """

    def __init__(self, a, b, ks=1, stride=1, pad=0, groups=1):
        super().__init__()
        conv = nn.Conv2d(a, b, ks, stride, pad, groups=groups, bias=False)
        norm = nn.BatchNorm2d(b)
        for name, layer in (('c', conv), ('bn', norm)):
            self.add_module(name, layer)
|
|
|
|
|
|
class FFN(nn.Module):
    """Pointwise feed-forward: ``Conv2d_BN(ed->h) -> Mish -> Conv2d_BN(h->ed)``.

    Args:
        ed: input/output channels.
        h: hidden channels.
    """

    def __init__(self, ed, h):
        super().__init__()
        self.pw1 = Conv2d_BN(ed, h)
        self.act = Mish()
        self.pw2 = Conv2d_BN(h, ed)

    def forward(self, x):
        hidden = self.act(self.pw1(x))
        return self.pw2(hidden)
|
|
|
|
|
|
class StochasticDepth(nn.Module):
    """Per-sample stochastic depth (drop-path) for residual branches.

    In training mode, each sample's input is kept with probability
    ``survival_prob`` and rescaled by ``1 / survival_prob``; otherwise it is
    zeroed. In eval mode the input is returned unchanged.

    Args:
        survival_prob: probability of keeping a sample. For
            ``survival_prob <= 0`` every sample is dropped (zeros are returned)
            rather than dividing by zero.
    """

    def __init__(self, survival_prob=0.8):
        super().__init__()
        self.survival_prob = survival_prob

    def forward(self, x):
        if not self.training:
            return x
        if self.survival_prob <= 0:
            # Always dropped; avoids 0 / 0 = NaN in the rescaling below.
            return torch.zeros_like(x)
        # One mask value per sample, broadcast over the remaining dims (any rank).
        mask_shape = [x.shape[0]] + [1] * (x.dim() - 1)
        # floor(p + U[0, 1)) is 1 with probability p and 0 otherwise.
        random_tensor = self.survival_prob + torch.rand(mask_shape, dtype=x.dtype, device=x.device)
        binary_tensor = torch.floor(random_tensor)
        return x * binary_tensor / self.survival_prob
|
|
|
|
|
|
class Residual(nn.Module):
    """Residual wrapper computing ``x + StochasticDepth(m(x))``.

    Args:
        m: branch module; its output must be addable to ``x``.
        survival_prob: keep probability for the stochastic-depth mask.
    """

    def __init__(self, m, survival_prob=0.8):
        super().__init__()
        self.m = m
        self.stochastic_depth = StochasticDepth(survival_prob)

    def forward(self, x):
        branch = self.stochastic_depth(self.m(x))
        return x + branch
|
|
|
|
|
|
class GLP_block(nn.Module):
    """Channel-split mixer with global, local, identity and pixel-attention groups.

    The global, local and PA group sizes are ``int(ratio * dim)``; the identity
    group takes the remaining channels. Each group is processed by its own
    operator:

    * global: ``MBWTConv2d``
    * local: three ``DWConv2d_BN_ReLU`` branches with kernels 3/5/7
    * identity: passed through unchanged
    * PA: ``PA``

    The groups are concatenated back in split order and projected by
    ``Mish -> Conv2d_BN -> CBAM``. A group whose size is 0 uses ``nn.Identity``.

    Args:
        dim: total number of input/output channels.
        global_ratio, local_ratio, pa_ratio: channel fractions of the global,
            local and pixel-attention groups.
        kernels: kernel size for ``MBWTConv2d``.
        ssm_ratio, forward_type: forwarded to ``MBWTConv2d``.
    """

    def __init__(self, dim, global_ratio=0.25, local_ratio=0.25, pa_ratio=0.1, kernels=3, ssm_ratio=1, forward_type="v052d"):
        super().__init__()
        self.dim = dim
        self.global_channels = int(global_ratio * dim)
        self.local_channels = int(local_ratio * dim)
        self.pa_channels = int(pa_ratio * dim)
        self.identity_channels = dim - (self.global_channels + self.local_channels + self.pa_channels)

        if self.local_channels > 0:
            self.local_op = nn.ModuleList(
                DWConv2d_BN_ReLU(self.local_channels, self.local_channels, k) for k in (3, 5, 7)
            )
        else:
            self.local_op = nn.Identity()

        if self.global_channels > 0:
            self.global_op = MBWTConv2d(self.global_channels, kernel_size=kernels,
                                        ssm_ratio=ssm_ratio, forward_type=forward_type)
        else:
            self.global_op = nn.Identity()

        # NOTE(review): ``cbam`` is never called in forward (``proj`` holds its
        # own CBAM); presumably kept for state-dict compatibility.
        self.cbam = CBAM(dim)
        self.proj = nn.Sequential(Mish(), Conv2d_BN(dim, dim), CBAM(dim))

        if self.pa_channels > 0:
            self.pa_op = PA(self.pa_channels, norm_layer=nn.BatchNorm2d, act_layer=nn.GELU)
        else:
            self.pa_op = nn.Identity()

    def forward(self, x):
        x_global, x_local, x_identity, x_pa = torch.split(
            x,
            [self.global_channels, self.local_channels, self.identity_channels, self.pa_channels],
            dim=1,
        )
        if isinstance(self.local_op, nn.ModuleList):
            branches = torch.cat([op(x_local) for op in self.local_op], dim=1)
            # NOTE(review): this averages over all 3 * local_channels channels
            # into a single map and repeats it local_channels times, so every
            # local output channel is identical. If a per-channel average of the
            # three branches was intended, torch.stack(...).mean(0) would do
            # that; confirm before changing, as it alters trained-model output.
            local_features = branches.mean(dim=1, keepdim=True).expand(-1, self.local_channels, -1, -1)
        else:
            local_features = self.local_op(x_local)

        out = torch.cat([self.global_op(x_global), local_features, x_identity, self.pa_op(x_pa)], dim=1)
        return self.proj(out)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SASF(GLP_block):
    """Channel-split mixer applied to the difference/sum features.

    Its constructor signature and defaults (``dim, global_ratio=0.25,
    local_ratio=0.25, pa_ratio=0.1, kernels=3, ssm_ratio=1,
    forward_type="v052d"``) are identical to ``GLP_block``, and so are its
    submodules and forward pass. It therefore inherits everything instead of
    duplicating the code. The separate class name is kept so existing
    references to ``SASF`` keep working.
    """
|
|
|
|
|
|
|
|
|
class ViLLayer(nn.Module):
    """Apply a ``ViLBlock`` to a channel-first feature map.

    The ``(B, C, *spatial)`` input is flattened to a ``(B, N, C)`` token
    sequence, processed by ``ViLBlock`` with row-wise traversal, and reshaped
    back to the input layout.

    Args:
        dim: channel count ``C`` of the input.
        d_state, d_conv, expand: accepted but not used here.
    """

    def __init__(self, dim, d_state = 16, d_conv = 4, expand = 2):
        super().__init__()
        self.dim = dim
        # NOTE(review): ``norm`` is never applied in forward; presumably kept
        # for state-dict compatibility.
        self.norm = nn.LayerNorm(dim)
        self.vil = ViLBlock(
            dim=self.dim,
            direction=SequenceTraversal.ROWWISE_FROM_TOP_LEFT,
        )

    @autocast(enabled=False)
    def forward(self, x):
        """Return the ViL output with the same shape as ``x``.

        CUDA autocast is disabled here, so half-precision inputs (float16 and
        bfloat16) are upcast to float32 before the block runs.
        """
        if x.dtype in (torch.float16, torch.bfloat16):
            x = x.float()
        B, C = x.shape[:2]
        assert C == self.dim
        n_tokens = x.shape[2:].numel()
        img_dims = x.shape[2:]
        # (B, C, *spatial) -> (B, N, C)
        x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2)
        x_vil = self.vil(x_flat)
        out = x_vil.transpose(-1, -2).reshape(B, C, *img_dims)
        return out
|
|
|
|
|
|
def dsconv_3x3(in_channel, out_channel):
    """Depthwise-separable 3x3 block: DW 3x3 -> PW 1x1 -> BatchNorm -> ReLU.

    Both convolutions keep their biases. Stride 1 and padding 1 preserve the
    spatial size.
    """
    depthwise = nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1, groups=in_channel)
    pointwise = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, groups=1)
    return nn.Sequential(depthwise, pointwise, nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True))
|
|
|
|
|
|
def conv_1x1(in_channel, out_channel):
    """Bias-free 1x1 conv -> BatchNorm -> in-place ReLU."""
    layers = [
        nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
|
|
|
|
|
|
class SqueezeAxialPositionalEmbedding(nn.Module):
    """Learnable 1-D positional embedding, linearly resampled to the input length.

    Args:
        dim: channel count ``C``.
        shape: stored embedding length; it is resampled to ``N`` at call time.

    ``forward`` expects ``x`` of shape ``(B, C, N)`` and returns that shape.
    """

    def __init__(self, dim, shape):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.randn([1, dim, shape]))

    def forward(self, x):
        _, _, length = x.shape
        resampled = F.interpolate(self.pos_embed, size=length, mode='linear', align_corners=False)
        return x + resampled
|
|
|
|
|
|
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel re-weighting.

    Args:
        channels: number of input channels.
        r: reduction ratio. The hidden width is ``channels // r`` clamped to at
            least 1, so channel counts below ``r`` still get a non-empty
            bottleneck.
    """

    def __init__(self, channels, r=16):
        super().__init__()
        hidden = max(1, channels // r)
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled by per-channel weights of shape ``(B, C, 1, 1)``."""
        w = self.fc(x)
        return x * w
|
|
|
class CTTF1(nn.Module):
    """Bi-temporal fusion block (configuration used for levels 0 and 1).

    Both inputs pass through a shared residual ``GLP_block`` mixer. Their signed
    difference, absolute difference and sum are each refined by a shared
    residual ``SASF`` mixer. The three results are concatenated, re-weighted by
    CBAM, and projected by a 1x1 conv + ReLU.

    Args:
        in_channel: channel count of each input.
        out_channel: channel count of the fused output.
        global_ratio, local_ratio, pa_ratio, kernels, ssm_ratio, forward_type:
            forwarded to the ``GLP_block`` mixer.
    """

    def __init__(self, in_channel, out_channel, global_ratio=0.2, local_ratio=0.2, pa_ratio=0.2, kernels=5, ssm_ratio=2.0, forward_type="v052d"):
        super().__init__()
        # NOTE(review): catconvA/catconvB/catconv, convA/convB, sigmoid and act
        # are never used in forward; presumably kept for state-dict compatibility.
        self.catconvA = dsconv_3x3(in_channel * 2, in_channel)
        self.catconvB = dsconv_3x3(in_channel * 2, in_channel)
        self.catconv = dsconv_3x3(in_channel * 2, out_channel)
        self.convA = nn.Conv2d(in_channel, 1, 1)
        self.convB = nn.Conv2d(in_channel, 1, 1)
        self.sigmoid = nn.Sigmoid()

        # Siamese mixer shared by both temporal inputs.
        self.mixer = Residual(GLP_block(in_channel, global_ratio, local_ratio, pa_ratio, kernels, ssm_ratio, forward_type))
        # Shared mixer for the difference/sum cues (global and PA ratios are 0).
        self.mixer2 = Residual(
            SASF(in_channel, global_ratio=0, local_ratio=0.1, pa_ratio=0, kernels=5, ssm_ratio=1, forward_type="v052d"))

        # Project the three concatenated cues to ``out_channel`` (parameter
        # shapes are unchanged whenever out_channel == in_channel).
        self.fuse = nn.Sequential(
            nn.Conv2d(in_channel * 3, out_channel, kernel_size=1),
            nn.ReLU(inplace=True),
        )
        self.cbam = CBAM(in_channel * 3)
        self.act = nn.SiLU()

    def forward(self, xA, xB):
        """Fuse two ``(B, in_channel, H, W)`` maps into an ``out_channel`` map."""
        f1 = self.mixer(xA)
        f2 = self.mixer(xB)

        diff_signed = f1 - f2
        diff_abs = torch.abs(diff_signed)
        sum_feat = f1 + f2

        refined = [self.mixer2(t) for t in (diff_signed, diff_abs, sum_feat)]
        f_fuse = self.cbam(torch.cat(refined, dim=1))
        return self.fuse(f_fuse)
|
|
|
|
|
|
class CTTF2(nn.Module):
    """Bi-temporal fusion block (configuration used for levels 2 and 3).

    Both inputs pass through a shared residual ``GLP_block`` mixer. Their signed
    difference, absolute difference and sum are each refined by a shared
    residual ``SASF`` mixer. The three results are concatenated, re-weighted by
    CBAM, and projected by a 1x1 conv + ReLU.

    Args:
        in_channel: channel count of each input.
        out_channel: channel count of the fused output.
        global_ratio, local_ratio, pa_ratio, kernels, ssm_ratio, forward_type:
            forwarded to the ``GLP_block`` mixer.
    """

    def __init__(self, in_channel, out_channel, global_ratio=0.25, local_ratio=0.25, pa_ratio=0, kernels=7,
                 ssm_ratio=2.0, forward_type="v052d"):
        super().__init__()
        # NOTE(review): catconvA/catconvB/catconv, convA/convB, sigmoid and act
        # are never used in forward; presumably kept for state-dict compatibility.
        self.catconvA = dsconv_3x3(in_channel * 2, in_channel)
        self.catconvB = dsconv_3x3(in_channel * 2, in_channel)
        self.catconv = dsconv_3x3(in_channel * 2, out_channel)
        self.convA = nn.Conv2d(in_channel, 1, 1)
        self.convB = nn.Conv2d(in_channel, 1, 1)
        self.sigmoid = nn.Sigmoid()

        # Siamese mixer shared by both temporal inputs.
        self.mixer = Residual(
            GLP_block(in_channel, global_ratio, local_ratio, pa_ratio, kernels, ssm_ratio, forward_type))
        # Shared mixer for the difference/sum cues (global and PA ratios are 0).
        self.mixer2 = Residual(
            SASF(in_channel, global_ratio=0, local_ratio=0.1, pa_ratio=0, kernels=5, ssm_ratio=1,
                 forward_type="v052d"))

        # Project the three concatenated cues to ``out_channel`` (parameter
        # shapes are unchanged whenever out_channel == in_channel).
        self.fuse = nn.Sequential(
            nn.Conv2d(in_channel * 3, out_channel, kernel_size=1),
            nn.ReLU(inplace=True),
        )
        self.cbam = CBAM(in_channel * 3)
        self.act = nn.SiLU()

    def forward(self, xA, xB):
        """Fuse two ``(B, in_channel, H, W)`` maps into an ``out_channel`` map."""
        f1 = self.mixer(xA)
        f2 = self.mixer(xB)

        diff_signed = f1 - f2
        diff_abs = torch.abs(diff_signed)
        sum_feat = f1 + f2

        refined = [self.mixer2(t) for t in (diff_signed, diff_abs, sum_feat)]
        f_fuse = self.cbam(torch.cat(refined, dim=1))
        return self.fuse(f_fuse)
|
|
|
|
|
|
class Mlp(nn.Module):
    """Two-layer MLP: fc1 -> act -> dropout -> fc2 -> dropout.

    Args:
        in_features: input width.
        hidden_features: hidden width (falls back to ``in_features`` if falsy).
        out_features: output width (falls back to ``in_features`` if falsy).
        act_layer: activation class, instantiated with no arguments.
        drop: dropout probability (the same ``nn.Dropout`` is used twice).
        channels_first: if True, the layers are 1x1 ``nn.Conv2d`` on
            ``(B, C, H, W)``; otherwise ``nn.Linear`` on ``(..., C)``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., channels_first=True):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features

        if channels_first:
            make_linear = partial(nn.Conv2d, kernel_size=1, padding=0)
        else:
            make_linear = nn.Linear

        self.fc1 = make_linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = make_linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        for layer in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = layer(x)
        return x
|
|
|
|
|
|
class LHBlock(nn.Module):
    """Inject features from ``x_l`` into ``x_h`` by upsample-add and cross-attention.

    ``x_l`` is refined by a residual MLP, projected to ``channels_h`` and
    bilinearly resized onto ``x_h``. ``x_l`` is also max-pooled to a
    ``cross_size x cross_size`` grid and turned into per-sample key/value
    kernels. ``x_h`` then attends over those ``cross_size**2`` positions via
    grouped 1x1 convolutions, followed by a residual MLP.

    Args:
        channels_l: channels of ``x_l``.
        channels_h: channels of ``x_h`` and of the output.
        cross_size: side of the pooled key/value grid (default 12).
    """

    def __init__(self, channels_l, channels_h, cross_size=12):
        super().__init__()
        self.channels_l = channels_l
        self.channels_h = channels_h
        self.cross_size = cross_size
        self.cross_kv = nn.Sequential(
            nn.BatchNorm2d(channels_l),
            nn.AdaptiveMaxPool2d(output_size=(self.cross_size, self.cross_size)),
            nn.Conv2d(channels_l, 2 * channels_h, 1, 1, 0)
        )

        self.conv = conv_1x1(channels_l, channels_h)
        self.norm = nn.BatchNorm2d(channels_h)

        self.mlp_l = Mlp(in_features=channels_l, out_features=channels_l)
        self.mlp_h = Mlp(in_features=channels_h, out_features=channels_h)

    def _act_sn(self, x):
        """Scaled softmax over each sample's ``cross_size**2`` attention logits.

        ``x`` has shape ``(1, B * cross_size**2, H, W)``. Logits are scaled by
        ``cross_size**-1`` (i.e. ``1 / sqrt(cross_size**2)``) and the input
        layout is restored after the softmax.
        """
        _, _, H, W = x.shape
        inner_channel = self.cross_size * self.cross_size
        x = x.reshape([-1, inner_channel, H, W]) * (inner_channel**-0.5)
        x = F.softmax(x, dim=1)
        x = x.reshape([1, -1, H, W])
        return x

    def attn_h(self, x_h, cross_k, cross_v):
        """Cross-attention of ``x_h`` over per-sample pooled keys/values.

        Args:
            x_h: ``(B, channels_h, H, W)``.
            cross_k: ``(B * cross_size**2, channels_h, 1, 1)`` key kernels.
            cross_v: ``(B * channels_h, cross_size**2, 1, 1)`` value kernels.

        Returns:
            ``(B, channels_h, H, W)``.
        """
        B, _, H, W = x_h.shape
        x_h = self.norm(x_h)
        # Fold the batch into channels so groups=B applies each sample's own kernels.
        x_h = x_h.reshape([1, -1, H, W])
        x_h = F.conv2d(x_h, cross_k, bias=None, stride=1, padding=0, groups=B)
        x_h = self._act_sn(x_h)
        x_h = F.conv2d(x_h, cross_v, bias=None, stride=1, padding=0, groups=B)
        x_h = x_h.reshape([-1, self.channels_h, H, W])
        return x_h

    def forward(self, x_l, x_h):
        """Return a tensor shaped like ``x_h`` with ``x_l`` information fused in."""
        x_l = x_l + self.mlp_l(x_l)
        x_l_conv = self.conv(x_l)
        # align_corners=False is the interpolate default; stated explicitly.
        x_h = x_h + F.interpolate(x_l_conv, size=x_h.shape[2:], mode='bilinear', align_corners=False)

        # (B, 2 * channels_h, cs, cs) split into key and value halves.
        cross_kv = self.cross_kv(x_l)
        cross_k, cross_v = cross_kv.split(self.channels_h, 1)
        # Keys -> B * cs**2 kernels over channels_h; values -> B * channels_h
        # kernels over the cs**2 pooled positions.
        cross_k = cross_k.permute(0, 2, 3, 1).reshape([-1, self.channels_h, 1, 1])
        cross_v = cross_v.reshape([-1, self.cross_size * self.cross_size, 1, 1])

        x_h = x_h + self.attn_h(x_h, cross_k, cross_v)
        x_h = x_h + self.mlp_h(x_h)
        return x_h
|
|
|
|
|
|
|
|
|
class CTTF(nn.Module):
    """Multi-level bi-temporal change-detection head.

    Each of the four feature levels is fused by a ``CTTF1`` (levels 0-1) or
    ``CTTF2`` (levels 2-3) block. Levels 1-3 are injected into level 0 by
    successive ``LHBlock`` modules. A DW-separable conv plus MLP then produces
    ``num_classes`` maps, which are upsampled 4x bilinearly.

    Args:
        channels: channel counts of the four input levels; defaults to
            ``[40, 80, 192, 384]``. The sequence is copied, so the caller's
            object is never shared or mutated.
        num_classes: number of output channels of the prediction (default 2).
    """

    def __init__(self, channels=None, num_classes=2):
        super().__init__()
        # A fresh list per instance avoids the shared-mutable-default pitfall.
        self.channels = [40, 80, 192, 384] if channels is None else list(channels)
        channels = self.channels

        self.fusion0 = CTTF1(channels[0], channels[0])
        self.fusion1 = CTTF1(channels[1], channels[1])
        self.fusion2 = CTTF2(channels[2], channels[2])
        self.fusion3 = CTTF2(channels[3], channels[3])

        self.LHBlock1 = LHBlock(channels[1], channels[0])
        self.LHBlock2 = LHBlock(channels[2], channels[0])
        self.LHBlock3 = LHBlock(channels[3], channels[0])

        self.mlp1 = Mlp(in_features=channels[0], out_features=channels[0])
        self.mlp2 = Mlp(in_features=channels[0], out_features=num_classes)
        self.dwc = dsconv_3x3(channels[0], channels[0])

    def forward(self, inputs):
        """Predict a change map from a pair of four-level feature lists.

        Args:
            inputs: pair ``(featuresA, featuresB)``. Each is indexable by level
                0..3 with NCHW maps whose channel counts match ``channels``.

        Returns:
            A ``num_classes``-channel map at 4x the spatial size of the fused
            level-0 features.
        """
        featuresA, featuresB = inputs

        x_diff_0 = self.fusion0(featuresA[0], featuresB[0])
        x_diff_1 = self.fusion1(featuresA[1], featuresB[1])
        x_diff_2 = self.fusion2(featuresA[2], featuresB[2])
        x_diff_3 = self.fusion3(featuresA[3], featuresB[3])

        # Progressively inject deeper levels into the level-0 map.
        x_h = self.LHBlock1(x_diff_1, x_diff_0)
        x_h = self.LHBlock2(x_diff_2, x_h)
        x_h = self.LHBlock3(x_diff_3, x_h)

        out = self.mlp2(self.dwc(x_h) + self.mlp1(x_h))
        return F.interpolate(out, scale_factor=(4, 4), mode="bilinear", align_corners=False)
|
|
|
|
|
|
|