# Uploaded by InPeerReview ("Upload 2 files", commit 2317bc0, verified) —
# stray Hugging Face page text converted to a comment so the module parses.
import torch
from torch import nn
from torch.cuda.amp import autocast
from rscd.models.decoderheads.vision_lstm import ViLBlock, SequenceTraversal
from torch.nn import functional as F
from functools import partial
from rscd.models.backbones.lib_mamba.vmambanew import SS2D
import pywt
class PA(nn.Module):
    """Pixel attention: gates the input elementwise by a sigmoid of a
    1x1 expand -> norm -> act -> 1x1 reduce bottleneck projection."""

    def __init__(self, dim, norm_layer, act_layer):
        super().__init__()
        # Bottleneck that produces a per-pixel, per-channel gate logit.
        self.p_conv = nn.Sequential(
            nn.Conv2d(dim, dim * 4, 1, bias=False),
            norm_layer(dim * 4),
            act_layer(),
            nn.Conv2d(dim * 4, dim, 1, bias=False),
        )
        self.gate_fn = nn.Sigmoid()

    def forward(self, x):
        gate = self.gate_fn(self.p_conv(x))
        return x * gate
class Mish(nn.Module):
    """Mish activation: x * tanh(softplus(x))."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        soft = F.softplus(x)
        return torch.tanh(soft).mul(x)
class _ScaleModule(nn.Module):
def __init__(self, dims, init_scale=1.0):
super().__init__()
self.weight = nn.Parameter(torch.ones(*dims) * init_scale)
def forward(self, x):
return torch.mul(self.weight, x)
def create_wavelet_filter(wave, in_size, out_size, dtype=torch.float):
    """Build grouped 2-D DWT / inverse-DWT filter banks for a pywt wavelet.

    Returns (dec_filters, rec_filters):
      dec_filters: (in_size * 4, 1, k, k) — LL/LH/HL/HH analysis kernels,
                   tiled once per input channel (for grouped conv2d).
      rec_filters: (out_size * 4, 1, k, k) — matching synthesis kernels
                   (for grouped conv_transpose2d).
    """
    wavelet = pywt.Wavelet(wave)
    # Analysis taps are time-reversed for use as convolution kernels.
    dec_lo = torch.tensor(wavelet.dec_lo[::-1], dtype=dtype)
    dec_hi = torch.tensor(wavelet.dec_hi[::-1], dtype=dtype)
    # The four subband kernels are outer products of the 1-D taps,
    # ordered LL, LH, HL, HH (rows vary fastest to match the original order).
    dec_filters = torch.stack(
        [torch.outer(rows, cols)
         for cols in (dec_lo, dec_hi)
         for rows in (dec_lo, dec_hi)],
        dim=0,
    )
    dec_filters = dec_filters.unsqueeze(1).repeat(in_size, 1, 1, 1)
    # Reversing the taps twice (slice then flip) restores the original
    # synthesis tap order, so these are simply the forward rec_* taps.
    rec_lo = torch.tensor(wavelet.rec_lo, dtype=dtype)
    rec_hi = torch.tensor(wavelet.rec_hi, dtype=dtype)
    rec_filters = torch.stack(
        [torch.outer(rows, cols)
         for cols in (rec_lo, rec_hi)
         for rows in (rec_lo, rec_hi)],
        dim=0,
    )
    rec_filters = rec_filters.unsqueeze(1).repeat(out_size, 1, 1, 1)
    return dec_filters, rec_filters
def wavelet_transform(x, filters):
    """One DWT level via a grouped stride-2 conv.

    x: (b, c, h, w); filters: (c * 4, 1, k, k).
    Returns (b, c, 4, h // 2, w // 2) with subbands along dim 2.
    """
    b, c, h, w = x.shape
    # Half-kernel padding (minus one) keeps the stride-2 output at h//2 x w//2.
    padding = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)
    out = F.conv2d(x, filters, stride=2, groups=c, padding=padding)
    return out.reshape(b, c, 4, h // 2, w // 2)
def inverse_wavelet_transform(x, filters):
    """One inverse-DWT level via a grouped stride-2 transposed conv.

    x: (b, c, 4, h_half, w_half); filters: (c * 4, 1, k, k).
    Returns the reconstructed (b, c, 2 * h_half, 2 * w_half) tensor.
    """
    b, c, _, h_half, w_half = x.shape
    padding = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)
    # Fold the subband axis back into channels for the grouped transposed conv.
    flat = x.reshape(b, c * 4, h_half, w_half)
    return F.conv_transpose2d(flat, filters, stride=2, groups=c, padding=padding)
class MBWTConv2d(nn.Module):
    """Multi-level wavelet convolution fused with an SS2D global-attention branch.

    The input is recursively decomposed with a fixed (non-trainable) DWT; at
    each level all four subbands pass through a depthwise conv and a learnable
    scale, then the pyramid is reconstructed with the inverse DWT. The result
    is added to a scaled SS2D pass over the raw input.

    Fix vs. original: removed the tautological `assert in_channels == in_channels`,
    which could never fail (dead code — presumably a leftover in/out-channel check).
    """

    def __init__(self, in_channels, kernel_size=5, wt_levels=1, wt_type='db1', ssm_ratio=1, forward_type="v05"):
        super().__init__()
        self.wt_levels = wt_levels
        # Fixed analysis/synthesis filter banks. Stored as frozen Parameters so
        # they follow .to()/.cuda() and appear in the state_dict.
        self.wt_filter, self.iwt_filter = create_wavelet_filter(wt_type, in_channels, in_channels)
        self.wt_filter = nn.Parameter(self.wt_filter, requires_grad=False)
        self.iwt_filter = nn.Parameter(self.iwt_filter, requires_grad=False)
        self.wt_function = partial(wavelet_transform, filters=self.wt_filter)
        self.iwt_function = partial(inverse_wavelet_transform, filters=self.iwt_filter)
        # Global branch: 2-D selective-scan attention on the full-resolution input.
        self.global_atten = SS2D(d_model=in_channels, d_state=1, ssm_ratio=ssm_ratio, initialize="v2",
                                 forward_type=forward_type, channel_first=True, k_group=2)
        self.base_scale = _ScaleModule([1, in_channels, 1, 1])
        # One depthwise conv per decomposition level, acting on all 4 subbands.
        self.wavelet_convs = nn.ModuleList([
            nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size, padding='same', groups=in_channels * 4)
            for _ in range(wt_levels)
        ])
        # Small initial scale keeps the wavelet branch a gentle perturbation.
        self.wavelet_scale = nn.ModuleList([
            _ScaleModule([1, in_channels * 4, 1, 1], init_scale=0.1)
            for _ in range(wt_levels)
        ])

    def forward(self, x):
        """Return base_scale(SS2D(x)) + the multi-level wavelet-filtered reconstruction."""
        x_ll_in_levels, x_h_in_levels, shapes_in_levels = [], [], []
        curr_x_ll = x
        # Analysis: recursively decompose the LL band, filtering each level's subbands.
        for i in range(self.wt_levels):
            curr_shape = curr_x_ll.shape
            shapes_in_levels.append(curr_shape)
            # Pad odd spatial dims so the stride-2 DWT halves them cleanly.
            if (curr_shape[2] % 2 > 0) or (curr_shape[3] % 2 > 0):
                curr_x_ll = F.pad(curr_x_ll, (0, curr_shape[3] % 2, 0, curr_shape[2] % 2))
            curr_x = self.wt_function(curr_x_ll)  # (b, c, 4, h/2, w/2)
            curr_x_ll = curr_x[:, :, 0, :, :]     # LL feeds the next level
            shape_x = curr_x.shape
            curr_x_tag = curr_x.reshape(shape_x[0], shape_x[1] * 4, shape_x[3], shape_x[4])
            curr_x_tag = self.wavelet_scale[i](self.wavelet_convs[i](curr_x_tag)).reshape(shape_x)
            x_ll_in_levels.append(curr_x_tag[:, :, 0, :, :])
            x_h_in_levels.append(curr_x_tag[:, :, 1:4, :, :])
        # Synthesis: rebuild from the deepest level up, cropping any analysis padding.
        next_x_ll = 0
        for i in range(self.wt_levels - 1, -1, -1):
            curr_x_ll = x_ll_in_levels.pop() + next_x_ll
            curr_x = torch.cat([curr_x_ll.unsqueeze(2), x_h_in_levels.pop()], dim=2)
            next_x_ll = self.iwt_function(curr_x)
            next_x_ll = next_x_ll[:, :, :shapes_in_levels[i][2], :shapes_in_levels[i][3]]
        x_tag = next_x_ll
        return self.base_scale(self.global_atten(x)) + x_tag
class ChannelAttention(nn.Module):
    """CBAM channel attention: sigmoid(MLP(avg_pool(x)) + MLP(max_pool(x)))."""

    def __init__(self, in_planes, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared bottleneck MLP (implemented with 1x1 convs).
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def _mlp(self, pooled):
        # The same weights process both pooled descriptors.
        return self.fc2(self.relu1(self.fc1(pooled)))

    def forward(self, x):
        attn = self._mlp(self.avg_pool(x)) + self._mlp(self.max_pool(x))
        return self.sigmoid(attn)
# Spatial attention module
class SpatialAttention(nn.Module):
    """CBAM spatial attention: a conv over channel-wise mean and max maps."""

    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # k // 2 gives "same" padding for both allowed kernel sizes (1 or 3).
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.amax(dim=1, keepdim=True)
        attn = self.conv1(torch.cat((mean_map, max_map), dim=1))
        return self.sigmoid(attn)
# CBAM attention module
class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel gate, then spatial gate."""

    def __init__(self, in_planes):
        super().__init__()
        self.ca = ChannelAttention(in_planes)
        self.sa = SpatialAttention()

    def forward(self, x):
        out = x * self.ca(x)
        out = out * self.sa(out)
        return out
class DynamicConv2d(nn.Module):
    """Mixture-of-experts convolution: a softmax-gated sum of parallel Conv2d
    experts, with gates predicted per sample from a pooled descriptor."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, bias=False, num_experts=4):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.bias = bias
        self.num_experts = num_experts
        # Identical expert convolutions; only their learned weights differ.
        self.experts = nn.ModuleList(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=bias)
            for _ in range(num_experts)
        )
        # Per-sample softmax gate over the experts.
        self.gating = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, num_experts, 1, bias=False),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        gates = self.gating(x).view(x.size(0), self.num_experts, 1, 1, 1)
        # (B, E, C_out, H, W): every expert's output, stacked along dim 1.
        stacked = torch.stack([expert(x) for expert in self.experts], dim=1)
        return (gates * stacked).sum(dim=1)
class DWConv2d_BN_ReLU(nn.Sequential):
    """Depthwise dynamic conv + BN + Mish, then a grouped 1x1 conv + BN.

    Note: the 1x1 stage uses groups=in_channels, so out_channels must be a
    multiple of in_channels for the conv to be valid.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        # Depthwise spatial mixing with a mixture-of-experts kernel.
        self.add_module('dwconv3x3', DynamicConv2d(
            in_channels, in_channels, kernel_size=kernel_size,
            stride=1, padding=kernel_size // 2, groups=in_channels, bias=False))
        self.add_module('bn1', nn.BatchNorm2d(in_channels))
        # Module name kept as 'relu' for state_dict compatibility; it is Mish.
        self.add_module('relu', Mish())
        # Grouped pointwise projection to out_channels.
        self.add_module('dwconv1x1', nn.Conv2d(
            in_channels, out_channels, kernel_size=1,
            stride=1, padding=0, groups=in_channels, bias=False))
        self.add_module('bn2', nn.BatchNorm2d(out_channels))
class Conv2d_BN(nn.Sequential):
    """Bias-free Conv2d immediately followed by BatchNorm2d."""

    def __init__(self, a, b, ks=1, stride=1, pad=0, groups=1):
        super().__init__()
        # Conv bias is redundant before BatchNorm, hence bias=False.
        self.add_module('c', nn.Conv2d(a, b, ks, stride, pad, groups=groups, bias=False))
        self.add_module('bn', nn.BatchNorm2d(b))
class FFN(nn.Module):
    """Pointwise feed-forward block: 1x1 Conv+BN -> Mish -> 1x1 Conv+BN."""

    def __init__(self, ed, h):
        super().__init__()
        self.pw1 = Conv2d_BN(ed, h)
        self.act = Mish()
        self.pw2 = Conv2d_BN(h, ed)

    def forward(self, x):
        hidden = self.act(self.pw1(x))
        return self.pw2(hidden)
class StochasticDepth(nn.Module):
    """Drops whole samples at random during training, rescaling survivors by
    1 / survival_prob; identity in eval mode."""

    def __init__(self, survival_prob=0.8):
        super().__init__()
        self.survival_prob = survival_prob

    def forward(self, x):
        if not self.training:
            return x
        # One Bernoulli(survival_prob) draw per sample, broadcast over C/H/W:
        # floor(p + U[0,1)) is 1 with probability p, else 0.
        noise = torch.rand([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
        keep_mask = torch.floor(self.survival_prob + noise)
        return x * keep_mask / self.survival_prob
class Residual(nn.Module):
    """Residual wrapper: returns x + stochastic_depth(m(x))."""

    def __init__(self, m, survival_prob=0.8):
        super().__init__()
        self.m = m
        self.stochastic_depth = StochasticDepth(survival_prob)

    def forward(self, x):
        branch = self.m(x)
        return x + self.stochastic_depth(branch)
class GLP_block(nn.Module):
    """Global/Local/PA mixer.

    Splits the channels into four groups — global (wavelet + SS2D), local
    (multi-kernel depthwise convs), identity pass-through, and pixel attention —
    processes each branch, re-concatenates, and projects with Mish + 1x1 + CBAM.
    """

    def __init__(self, dim, global_ratio=0.25, local_ratio=0.25, pa_ratio=0.1, kernels=3, ssm_ratio=1, forward_type="v052d"):
        super().__init__()
        self.dim = dim
        self.global_channels = int(global_ratio * dim)
        self.local_channels = int(local_ratio * dim)
        self.pa_channels = int(pa_ratio * dim)
        # Whatever is left over passes through untouched.
        self.identity_channels = dim - self.global_channels - self.local_channels - self.pa_channels
        if self.local_channels > 0:
            # Three parallel depthwise branches with 3/5/7 kernels.
            self.local_op = nn.ModuleList(
                DWConv2d_BN_ReLU(self.local_channels, self.local_channels, k)
                for k in (3, 5, 7)
            )
        else:
            self.local_op = nn.Identity()
        if self.global_channels > 0:
            self.global_op = MBWTConv2d(self.global_channels, kernel_size=kernels,
                                        ssm_ratio=ssm_ratio, forward_type=forward_type)
        else:
            self.global_op = nn.Identity()
        # NOTE(review): self.cbam is registered but never used in forward;
        # kept so existing checkpoints keep loading.
        self.cbam = CBAM(dim)
        self.proj = nn.Sequential(
            Mish(),
            Conv2d_BN(dim, dim),
            CBAM(dim)
        )
        if self.pa_channels > 0:
            self.pa_op = PA(self.pa_channels, norm_layer=nn.BatchNorm2d, act_layer=nn.GELU)
        else:
            self.pa_op = nn.Identity()

    def forward(self, x):
        sizes = [self.global_channels, self.local_channels, self.identity_channels, self.pa_channels]
        x_glob, x_loc, x_id, x_pa = torch.split(x, sizes, dim=1)
        if isinstance(self.local_op, nn.ModuleList):
            branches = torch.cat([op(x_loc) for op in self.local_op], dim=1)
            # Collapse all branch channels to one mean map, broadcast back out.
            local_out = branches.mean(dim=1, keepdim=True).expand(-1, self.local_channels, -1, -1)
        else:
            local_out = self.local_op(x_loc)
        fused = torch.cat([self.global_op(x_glob), local_out, x_id, self.pa_op(x_pa)], dim=1)
        return self.proj(fused)
class SASF(nn.Module):
    """Global/Local/PA mixer — structurally identical to GLP_block in this file,
    kept as a separate class (it is instantiated with different ratios).

    Splits channels into global (wavelet + SS2D), local (multi-kernel depthwise
    convs), identity, and pixel-attention branches, then fuses via Mish + 1x1 + CBAM.
    """

    def __init__(self, dim, global_ratio=0.25, local_ratio=0.25, pa_ratio=0.1, kernels=3, ssm_ratio=1, forward_type="v052d"):
        super().__init__()
        self.dim = dim
        self.global_channels = int(global_ratio * dim)
        self.local_channels = int(local_ratio * dim)
        self.pa_channels = int(pa_ratio * dim)
        # Remaining channels pass through untouched.
        self.identity_channels = dim - self.global_channels - self.local_channels - self.pa_channels
        if self.local_channels > 0:
            self.local_op = nn.ModuleList(
                DWConv2d_BN_ReLU(self.local_channels, self.local_channels, k)
                for k in (3, 5, 7)
            )
        else:
            self.local_op = nn.Identity()
        if self.global_channels > 0:
            self.global_op = MBWTConv2d(self.global_channels, kernel_size=kernels,
                                        ssm_ratio=ssm_ratio, forward_type=forward_type)
        else:
            self.global_op = nn.Identity()
        # NOTE(review): self.cbam is registered but never used in forward;
        # kept for checkpoint compatibility.
        self.cbam = CBAM(dim)
        self.proj = nn.Sequential(
            Mish(),
            Conv2d_BN(dim, dim),
            CBAM(dim)
        )
        if self.pa_channels > 0:
            self.pa_op = PA(self.pa_channels, norm_layer=nn.BatchNorm2d, act_layer=nn.GELU)
        else:
            self.pa_op = nn.Identity()

    def forward(self, x):
        sizes = [self.global_channels, self.local_channels, self.identity_channels, self.pa_channels]
        x_glob, x_loc, x_id, x_pa = torch.split(x, sizes, dim=1)
        if isinstance(self.local_op, nn.ModuleList):
            branches = torch.cat([op(x_loc) for op in self.local_op], dim=1)
            # Collapse branch channels to one mean map, broadcast back out.
            local_out = branches.mean(dim=1, keepdim=True).expand(-1, self.local_channels, -1, -1)
        else:
            local_out = self.local_op(x_loc)
        fused = torch.cat([self.global_op(x_glob), local_out, x_id, self.pa_op(x_pa)], dim=1)
        return self.proj(fused)
class ViLLayer(nn.Module):
    """Runs a ViLBlock over the flattened spatial tokens of a 2-D feature map."""

    def __init__(self, dim, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.dim = dim
        # NOTE(review): self.norm is registered but not used in forward.
        self.norm = nn.LayerNorm(dim)
        self.vil = ViLBlock(
            dim=self.dim,
            direction=SequenceTraversal.ROWWISE_FROM_TOP_LEFT
        )

    @autocast(enabled=False)
    def forward(self, x):
        # The block is forced to fp32; upcast half-precision inputs.
        if x.dtype == torch.float16:
            x = x.type(torch.float32)
        batch, channels = x.shape[:2]
        assert channels == self.dim
        spatial = x.shape[2:]
        # (B, C, H, W) -> (B, H*W, C) token sequence and back.
        tokens = x.reshape(batch, channels, spatial.numel()).transpose(-1, -2)
        tokens = self.vil(tokens)
        return tokens.transpose(-1, -2).reshape(batch, channels, *spatial)
def dsconv_3x3(in_channel, out_channel):
    """Depthwise-separable 3x3 block: DW conv -> 1x1 conv -> BN -> ReLU."""
    depthwise = nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1, groups=in_channel)
    pointwise = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, groups=1)
    return nn.Sequential(
        depthwise,
        pointwise,
        nn.BatchNorm2d(out_channel),
        nn.ReLU(inplace=True),
    )
def conv_1x1(in_channel, out_channel):
    """Pointwise projection block: bias-free 1x1 conv -> BN -> ReLU."""
    return nn.Sequential(
        nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(inplace=True),
    )
class SqueezeAxialPositionalEmbedding(nn.Module):
    """Learnable 1-D positional embedding, linearly resized to the input length
    before being added to the (B, C, N) input."""

    def __init__(self, dim, shape):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.randn([1, dim, shape]))

    def forward(self, x):
        length = x.shape[-1]
        # Resample the embedding so it matches however many positions x has.
        pos = F.interpolate(self.pos_embed, size=length, mode='linear', align_corners=False)
        return x + pos
class SEBlock(nn.Module):
    """Squeeze-and-Excitation: global-pooled bottleneck predicts a per-channel
    sigmoid gate that rescales the input."""

    def __init__(self, channels, r=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // r, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // r, channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        weights = self.fc(x)  # (B, C, 1, 1), each entry in (0, 1)
        return weights * x
class CTTF1(nn.Module):
    """Bi-temporal fusion head (variant 1).

    Each input is refined by a shared GLP mixer; the signed difference, absolute
    difference and sum are each passed through a shared SASF mixer, concatenated,
    refined by CBAM and reduced back to in_channel with a 1x1 conv + ReLU.
    """

    def __init__(self, in_channel, out_channel, global_ratio=0.2, local_ratio=0.2, pa_ratio=0.2, kernels=5, ssm_ratio=2.0, forward_type="v052d"):
        super().__init__()
        # NOTE(review): catconvA/catconvB/catconv/convA/convB/sigmoid/act are
        # registered but unused in forward; kept so checkpoints keep loading.
        self.catconvA = dsconv_3x3(in_channel * 2, in_channel)
        self.catconvB = dsconv_3x3(in_channel * 2, in_channel)
        self.catconv = dsconv_3x3(in_channel * 2, out_channel)
        self.convA = nn.Conv2d(in_channel, 1, 1)
        self.convB = nn.Conv2d(in_channel, 1, 1)
        self.sigmoid = nn.Sigmoid()
        self.mixer = Residual(GLP_block(in_channel, global_ratio, local_ratio, pa_ratio, kernels, ssm_ratio, forward_type))
        self.mixer2 = Residual(
            SASF(in_channel, global_ratio=0, local_ratio=0.1, pa_ratio=0, kernels=5, ssm_ratio=1, forward_type="v052d"))
        self.fuse = nn.Sequential(
            nn.Conv2d(in_channel * 3, in_channel, kernel_size=1),
            nn.ReLU(inplace=True)
        )
        self.cbam = CBAM(in_channel * 3)
        self.act = nn.SiLU()

    def forward(self, xA, xB):
        mixed_a = self.mixer(xA)
        mixed_b = self.mixer(xB)
        diff = mixed_a - mixed_b
        # Shared mixer over the three temporal cues (order preserved).
        signed_feat = self.mixer2(diff)
        abs_feat = self.mixer2(diff.abs())
        sum_feat = self.mixer2(mixed_a + mixed_b)
        # Concatenate the cues along channels, then refine and reduce.
        stacked = torch.cat([signed_feat, abs_feat, sum_feat], dim=1)  # (B, 3C, H, W)
        return self.fuse(self.cbam(stacked))
class CTTF2(nn.Module):
    """Bi-temporal fusion head (variant 2) — same structure as CTTF1 but with
    deeper-scale defaults (larger kernel, no pixel-attention channels)."""

    def __init__(self, in_channel, out_channel, global_ratio=0.25, local_ratio=0.25, pa_ratio=0, kernels=7,
                 ssm_ratio=2.0, forward_type="v052d"):
        super().__init__()
        # NOTE(review): catconvA/catconvB/catconv/convA/convB/sigmoid/act are
        # registered but unused in forward; kept for checkpoint compatibility.
        self.catconvA = dsconv_3x3(in_channel * 2, in_channel)
        self.catconvB = dsconv_3x3(in_channel * 2, in_channel)
        self.catconv = dsconv_3x3(in_channel * 2, out_channel)
        self.convA = nn.Conv2d(in_channel, 1, 1)
        self.convB = nn.Conv2d(in_channel, 1, 1)
        self.sigmoid = nn.Sigmoid()
        self.mixer = Residual(
            GLP_block(in_channel, global_ratio, local_ratio, pa_ratio, kernels, ssm_ratio, forward_type))
        self.mixer2 = Residual(
            SASF(in_channel, global_ratio=0, local_ratio=0.1, pa_ratio=0, kernels=5, ssm_ratio=1,
                 forward_type="v052d"))
        self.fuse = nn.Sequential(
            nn.Conv2d(in_channel * 3, in_channel, kernel_size=1),
            nn.ReLU(inplace=True)
        )
        self.cbam = CBAM(in_channel * 3)
        self.act = nn.SiLU()

    def forward(self, xA, xB):
        mixed_a = self.mixer(xA)
        mixed_b = self.mixer(xB)
        diff = mixed_a - mixed_b
        # Shared mixer over the three temporal cues (order preserved).
        signed_feat = self.mixer2(diff)
        abs_feat = self.mixer2(diff.abs())
        sum_feat = self.mixer2(mixed_a + mixed_b)
        stacked = torch.cat([signed_feat, abs_feat, sum_feat], dim=1)  # (B, 3C, H, W)
        return self.fuse(self.cbam(stacked))
class Mlp(nn.Module):
    """Two-layer MLP with dropout after each layer.

    Uses 1x1 convs when channels_first (NCHW inputs), plain nn.Linear otherwise.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., channels_first=True):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        if channels_first:
            linear_cls = partial(nn.Conv2d, kernel_size=1, padding=0)
        else:
            linear_cls = nn.Linear
        self.fc1 = linear_cls(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = linear_cls(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class LHBlock(nn.Module):
    """Low/high-resolution feature exchange block.

    The low-res stream x_l conditions the high-res stream x_h in two ways:
    (1) a 1x1-conv projection of x_l is bilinearly upsampled and added to x_h;
    (2) a lightweight cross-attention whose keys/values come from a 12x12
        max-pooled summary of x_l (see attn_h). Batch samples are folded into
        the channel dimension and separated via grouped 1x1 convs.
    """
    def __init__(self, channels_l, channels_h):
        super().__init__()
        self.channels_l = channels_l
        self.channels_h = channels_h
        # Side length of the pooled low-res summary => 12*12 = 144 k/v tokens.
        self.cross_size = 12
        # Produces concatenated cross keys and values (2 * channels_h maps).
        self.cross_kv = nn.Sequential(
            nn.BatchNorm2d(channels_l),
            nn.AdaptiveMaxPool2d(output_size=(self.cross_size, self.cross_size)),
            nn.Conv2d(channels_l, 2 * channels_h, 1, 1, 0)
        )
        self.conv = conv_1x1(channels_l, channels_h)
        self.norm = nn.BatchNorm2d(channels_h)
        self.mlp_l = Mlp(in_features=channels_l, out_features=channels_l)
        self.mlp_h = Mlp(in_features=channels_h, out_features=channels_h)
    def _act_sn(self, x):
        # Scaled softmax over each sample's 144 token channels. The batch is
        # folded into dim 0 for the softmax, then back into dim 1 afterwards.
        _, _, H, W = x.shape
        inner_channel = self.cross_size * self.cross_size
        x = x.reshape([-1, inner_channel, H, W]) * (inner_channel**-0.5)
        x = F.softmax(x, dim=1)
        x = x.reshape([1, -1, H, W])
        return x
    def attn_h(self, x_h, cross_k, cross_v):
        # Cross-attention via grouped 1x1 convs: each batch sample is one conv
        # group, so its own keys/values act on it independently.
        B, _, H, W = x_h.shape
        x_h = self.norm(x_h)
        x_h = x_h.reshape([1, -1, H, W])  # n,c_in,h,w -> 1,n*c_in,h,w
        x_h = F.conv2d(x_h, cross_k, bias=None, stride=1, padding=0,
                       groups=B)  # 1,n*c_in,h,w -> 1,n*144,h,w (group=B)
        x_h = self._act_sn(x_h)
        x_h = F.conv2d(x_h, cross_v, bias=None, stride=1, padding=0,
                       groups=B)  # 1,n*144,h,w -> 1, n*c_in,h,w (group=B)
        x_h = x_h.reshape([-1, self.channels_h, H,
                           W])  # 1, n*c_in,h,w -> n,c_in,h,w (c_in = c_out)
        return x_h
    def forward(self, x_l, x_h):
        x_l = x_l + self.mlp_l(x_l)
        x_l_conv = self.conv(x_l)
        # Dense path: upsampled low-res projection added to the high-res stream.
        x_h = x_h + F.interpolate(x_l_conv, size=x_h.shape[2:], mode='bilinear')
        cross_kv = self.cross_kv(x_l)
        cross_k, cross_v = cross_kv.split(self.channels_h, 1)
        cross_k = cross_k.permute(0, 2, 3, 1).reshape([-1, self.channels_h, 1, 1])  # n*144,channels_h,1,1
        cross_v = cross_v.reshape([-1, self.cross_size * self.cross_size, 1, 1])  # n*channels_h,144,1,1
        # Sparse path: cross-attention from the pooled low-res summary.
        x_h = x_h + self.attn_h(x_h, cross_k, cross_v)
        x_h = x_h + self.mlp_h(x_h)
        return x_h
class CTTF(nn.Module):
    """Change-detection decoder head.

    Fuses the two temporal feature pyramids scale by scale (CTTF1/CTTF2),
    aggregates all scales into the highest-resolution stream via LHBlocks,
    and predicts a 2-class map upsampled 4x.
    """

    def __init__(self, channels=[40, 80, 192, 384]):
        super().__init__()
        self.channels = channels
        # Per-scale bi-temporal fusion (lighter variant for the two shallow scales).
        self.fusion0 = CTTF1(channels[0], channels[0])
        self.fusion1 = CTTF1(channels[1], channels[1])
        self.fusion2 = CTTF2(channels[2], channels[2])
        self.fusion3 = CTTF2(channels[3], channels[3])
        # Each LHBlock injects one lower-resolution scale into the high-res stream.
        self.LHBlock1 = LHBlock(channels[1], channels[0])
        self.LHBlock2 = LHBlock(channels[2], channels[0])
        self.LHBlock3 = LHBlock(channels[3], channels[0])
        self.mlp1 = Mlp(in_features=channels[0], out_features=channels[0])
        self.mlp2 = Mlp(in_features=channels[0], out_features=2)
        self.dwc = dsconv_3x3(channels[0], channels[0])

    def forward(self, inputs):
        feats_a, feats_b = inputs
        # Scale-wise bi-temporal fusion.
        diff0 = self.fusion0(feats_a[0], feats_b[0])
        diff1 = self.fusion1(feats_a[1], feats_b[1])
        diff2 = self.fusion2(feats_a[2], feats_b[2])
        diff3 = self.fusion3(feats_a[3], feats_b[3])
        # Fold every coarser scale into the highest-resolution stream.
        high = diff0
        high = self.LHBlock1(diff1, high)
        high = self.LHBlock2(diff2, high)
        high = self.LHBlock3(diff3, high)
        logits = self.mlp2(self.dwc(high) + self.mlp1(high))
        return F.interpolate(
            logits,
            scale_factor=(4, 4),
            mode="bilinear",
            align_corners=False,
        )