import torch
import torch.nn as nn
import torch.nn.functional as F

# Number of feature maps produced by _mixed_fusion_features: the three inputs
# themselves plus four pairwise combinations (sum, product, max, |difference|)
# for each of the three input pairs.
_NUM_FUSION_MAPS = 3 + 4 * 3  # == 15


def _mixed_fusion_features(x1, x2, x3):
    """Concatenate three feature maps and their pairwise combinations on dim 1.

    The channel order of the result is::

        x1, x2, x3,
        x1+x2, x1*x2, max(x1, x2), |x1-x2|,
        x1+x3, x1*x3, max(x1, x3), |x1-x3|,
        x2+x3, x2*x3, max(x2, x3), |x2-x3|

    The order must stay fixed: the downstream ParameterAttention / conv
    weights are per-channel, so reordering would break trained checkpoints.

    Args:
        x1, x2, x3: tensors expected to share one shape (B, C, H, W).

    Returns:
        Tensor of shape (B, 15 * C, H, W).
    """
    features = [x1, x2, x3]
    for a, b in ((x1, x2), (x1, x3), (x2, x3)):
        features.extend((a + b, a * b, torch.max(a, b), torch.abs(a - b)))
    return torch.cat(features, dim=1)


class ParameterAttention(nn.Module):
    """Channel attention with a residual connection.

    Global max- and average-pooled channel descriptors go through a shared
    two-layer MLP (Linear -> ReLU -> Linear). The two results are summed and
    passed through a sigmoid to get one weight per channel. The output is
    ``weights * x + x``.

    Args:
        in_channels: number of input channels C.
        reduction_ratio: MLP hidden width as a fraction of C (minimum 1 unit).
    """

    def __init__(self, in_channels, reduction_ratio=0.5):
        super().__init__()
        reduced_channels = max(1, int(in_channels * reduction_ratio))
        self.dense_one = nn.Linear(in_channels, reduced_channels)
        self.dense_two = nn.Linear(reduced_channels, in_channels)

    def _mlp(self, descriptor):
        """Apply the shared bottleneck MLP to a (B, C) descriptor."""
        return self.dense_two(F.relu(self.dense_one(descriptor)))

    def forward(self, x):
        b, c, _, _ = x.shape
        max_desc = F.adaptive_max_pool2d(x, 1).view(b, c)
        avg_desc = F.adaptive_avg_pool2d(x, 1).view(b, c)
        logits = self._mlp(max_desc).view(b, c, 1, 1) + self._mlp(avg_desc).view(b, c, 1, 1)
        channel_attention = torch.sigmoid(logits)
        return channel_attention * x + x  # residual connection


class MixedFusionBlockD(nn.Module):
    """Fusion block for the down (encoder) side of the fused path.

    1. Builds the 15*C mixed-fusion features of three inputs.
    2. Re-weights them with ParameterAttention.
    3. Projects back to C channels.
    4. Concatenates with ``xx`` and projects to ``out_channels``.

    ``xx`` is expected to carry C channels (conv2 takes 2*C inputs). Every
    conv is followed by InstanceNorm2d(affine=True) and LeakyReLU(0.2).
    The norm layers are named ``bn*`` so existing state_dict keys still load.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        fused_channels = in_channels * _NUM_FUSION_MAPS
        self.param_attention = ParameterAttention(fused_channels)
        self.conv1 = nn.Conv2d(fused_channels, in_channels, kernel_size=3, padding=1)
        self.bn1 = nn.InstanceNorm2d(in_channels, affine=True)
        self.conv2 = nn.Conv2d(in_channels * 2, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.InstanceNorm2d(out_channels, affine=True)

    def forward(self, x1, x2, x3, xx):
        fused = self.param_attention(_mixed_fusion_features(x1, x2, x3))
        fused = F.leaky_relu(self.bn1(self.conv1(fused)), negative_slope=0.2)
        out = torch.cat([fused, xx], dim=1)
        return F.leaky_relu(self.bn2(self.conv2(out)), negative_slope=0.2)


class MixedFusionBlockU(nn.Module):
    """Fusion block for the up (decoder) side of the fused path.

    Same fusion/attention/projection as MixedFusionBlockD. The projected C
    channels are then concatenated with ``xx`` and ``skip``, which together
    are expected to provide 3*C channels (conv2 takes 4*C inputs). This block
    uses ReLU instead of LeakyReLU.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        fused_channels = in_channels * _NUM_FUSION_MAPS
        self.param_attention = ParameterAttention(fused_channels)
        self.conv1 = nn.Conv2d(fused_channels, in_channels, kernel_size=3, padding=1)
        self.bn1 = nn.InstanceNorm2d(in_channels, affine=True)
        self.conv2 = nn.Conv2d(in_channels * 4, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.InstanceNorm2d(out_channels, affine=True)

    def forward(self, x1, x2, x3, xx, skip):
        fused = self.param_attention(_mixed_fusion_features(x1, x2, x3))
        fused = F.relu(self.bn1(self.conv1(fused)))
        out = torch.cat([fused, xx, skip], dim=1)
        return F.relu(self.bn2(self.conv2(out)))


class MixedFusionBlock0(nn.Module):
    """First fusion block of the fused path (no previous fused feature).

    Same fusion/attention/projection as MixedFusionBlockD, followed directly
    by a second conv from C to ``out_channels``. Each conv is followed by
    InstanceNorm2d(affine=True) and LeakyReLU(0.2).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        fused_channels = in_channels * _NUM_FUSION_MAPS
        self.param_attention = ParameterAttention(fused_channels)
        self.conv1 = nn.Conv2d(fused_channels, in_channels, kernel_size=3, padding=1)
        self.bn1 = nn.InstanceNorm2d(in_channels, affine=True)
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.InstanceNorm2d(out_channels, affine=True)

    def forward(self, x1, x2, x3):
        fused = self.param_attention(_mixed_fusion_features(x1, x2, x3))
        fused = F.leaky_relu(self.bn1(self.conv1(fused)), negative_slope=0.2)
        return F.leaky_relu(self.bn2(self.conv2(fused)), negative_slope=0.2)


class EncoderBlock(nn.Module):
    """Encoder block with two convolutions: stride 1, then stride 2.

    Each conv is followed by optional BatchNorm2d and LeakyReLU(0.2). With
    the default kernel_size=3 / padding=1, the output spatial size is
    ceil(H/2) x ceil(W/2).
    """

    def __init__(self, inplanes, outplanes, kernel_size=3, padding=1, norm=True):
        super().__init__()
        self.lrelu = nn.LeakyReLU(0.2, inplace=False)
        # First convolution keeps the resolution.
        self.conv1 = nn.Conv2d(inplanes, outplanes, kernel_size, stride=1, padding=padding)
        self.bn1 = nn.BatchNorm2d(outplanes) if norm else None
        # Second convolution halves the resolution.
        self.conv2 = nn.Conv2d(outplanes, outplanes, kernel_size, stride=2, padding=padding)
        self.bn2 = nn.BatchNorm2d(outplanes) if norm else None

    def forward(self, x):
        fx = self.conv1(x)
        if self.bn1 is not None:
            fx = self.bn1(fx)
        fx = self.lrelu(fx)
        fx = self.conv2(fx)
        if self.bn2 is not None:
            fx = self.bn2(fx)
        fx = self.lrelu(fx)
        return fx


class DecoderBlock(nn.Module):
    """Decoder block: ConvTranspose2d -> BatchNorm2d -> ReLU (-> Dropout2d).

    With the defaults (kernel_size=3, stride=2, padding=1, output_padding=1)
    the spatial size is exactly doubled. Dropout2d(p=0.4) is applied only
    when ``dropout=True``.
    """

    def __init__(self, inplanes, outplanes, kernel_size=3, stride=2, padding=1,
                 dropout=False, output_padding=1):
        super().__init__()
        self.relu = nn.ReLU(inplace=False)
        self.deconv = nn.ConvTranspose2d(inplanes, outplanes, kernel_size, stride, padding,
                                         output_padding=output_padding)
        self.bn = nn.BatchNorm2d(outplanes)
        self.dropout = nn.Dropout2d(p=0.4, inplace=False) if dropout else None

    def forward(self, x):
        fx = self.deconv(x)
        fx = self.bn(fx)
        fx = self.relu(fx)
        if self.dropout is not None:
            fx = self.dropout(fx)
        return fx


class UnetGenerator(nn.Module):
    """Three-input encoder-decoder generator with a fused U-Net style path.

    Each single-channel input has its own encoder-decoder branch (attributes
    suffixed ``_i1``, ``_i2``, ``_i3``). These branches have no skip
    connections. The skip connections live in the fused generation path,
    which combines the three branches level by level through the
    ``MixedFusion_block_*`` modules.

    Attribute names and construction order are part of the checkpoint
    format (state_dict keys) and of seeded initialization. Keep them stable.
    Spatial sizes in the comments assume the reference 192x384 input.
    """

    def __init__(self):
        super().__init__()
        # ---- Branch for input 1 ----
        self.encoder1_i1 = EncoderBlock(1, 64)  # 192x384 -> 96x192
        self.encoder2_i1 = EncoderBlock(64, 128)  # 96x192 -> 48x96
        self.encoder3_i1 = EncoderBlock(128, 256)  # 48x96 -> 24x48
        self.encoder4_i1 = EncoderBlock(256, 512)  # 24x48 -> 12x24
        self.encoder5_i1 = EncoderBlock(512, 512, norm=False)  # 12x24 -> 6x12
        self.decoder4_i1 = DecoderBlock(512, 512, dropout=True, output_padding=1)  # 6x12 -> 12x24
        self.decoder3_i1 = DecoderBlock(512, 256, dropout=True, output_padding=1)  # 12x24 -> 24x48
        self.decoder2_i1 = DecoderBlock(256, 128, output_padding=1)  # 24x48 -> 48x96
        self.decoder1_i1 = DecoderBlock(128, 64, output_padding=1)  # 48x96 -> 96x192
        self.decoder0_i1 = nn.ConvTranspose2d(64, 1, kernel_size=3, stride=2, padding=1,
                                              output_padding=1)  # 96x192 -> 192x384

        # ---- Branch for input 2 ----
        self.encoder1_i2 = EncoderBlock(1, 64)  # 192x384 -> 96x192
        self.encoder2_i2 = EncoderBlock(64, 128)  # 96x192 -> 48x96
        self.encoder3_i2 = EncoderBlock(128, 256)  # 48x96 -> 24x48
        self.encoder4_i2 = EncoderBlock(256, 512)  # 24x48 -> 12x24
        self.encoder5_i2 = EncoderBlock(512, 512, norm=False)  # 12x24 -> 6x12
        self.decoder4_i2 = DecoderBlock(512, 512, dropout=True, output_padding=1)  # 6x12 -> 12x24
        self.decoder3_i2 = DecoderBlock(512, 256, dropout=True, output_padding=1)  # 12x24 -> 24x48
        self.decoder2_i2 = DecoderBlock(256, 128, output_padding=1)  # 24x48 -> 48x96
        self.decoder1_i2 = DecoderBlock(128, 64, output_padding=1)  # 48x96 -> 96x192
        self.decoder0_i2 = nn.ConvTranspose2d(64, 1, kernel_size=3, stride=2, padding=1,
                                              output_padding=1)  # 96x192 -> 192x384

        # ---- Branch for input 3 ----
        self.encoder1_i3 = EncoderBlock(1, 64)  # 192x384 -> 96x192
        self.encoder2_i3 = EncoderBlock(64, 128)  # 96x192 -> 48x96
        self.encoder3_i3 = EncoderBlock(128, 256)  # 48x96 -> 24x48
        self.encoder4_i3 = EncoderBlock(256, 512)  # 24x48 -> 12x24
        self.encoder5_i3 = EncoderBlock(512, 512, norm=False)  # 12x24 -> 6x12
        self.decoder4_i3 = DecoderBlock(512, 512, dropout=True, output_padding=1)  # 6x12 -> 12x24
        self.decoder3_i3 = DecoderBlock(512, 256, dropout=True, output_padding=1)  # 12x24 -> 24x48
        self.decoder2_i3 = DecoderBlock(256, 128, output_padding=1)  # 24x48 -> 48x96
        self.decoder1_i3 = DecoderBlock(128, 64, output_padding=1)  # 48x96 -> 96x192
        self.decoder0_i3 = nn.ConvTranspose2d(64, 1, kernel_size=3, stride=2, padding=1,
                                              output_padding=1)  # 96x192 -> 192x384

        # ---- Fused generation path ----
        self.MixedFusion_block_0 = MixedFusionBlock0(in_channels=64, out_channels=128)
        self.MixedFusion_block_d_e2 = MixedFusionBlockD(in_channels=128, out_channels=256)
        self.MixedFusion_block_d_e3 = MixedFusionBlockD(in_channels=256, out_channels=512)
        self.MixedFusion_block_d_e4 = MixedFusionBlockD(in_channels=512, out_channels=1024)

        self.MixedFusion_block_u_d4 = MixedFusionBlockU(in_channels=512, out_channels=256)
        self.MixedFusion_block_u_d3 = MixedFusionBlockU(in_channels=256, out_channels=128)
        self.MixedFusion_block_u_d2 = MixedFusionBlockU(in_channels=128, out_channels=64)
        self.MixedFusion_block_u_d1 = MixedFusionBlockU(in_channels=64, out_channels=32)

        # Each of these doubles the spatial size, keeping the channel count.
        self.decoder_gd4 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=3,
                                              stride=2, padding=1, output_padding=1)
        self.decoder_gd3 = nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3,
                                              stride=2, padding=1, output_padding=1)
        self.decoder_gd2 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=3,
                                              stride=2, padding=1, output_padding=1)
        self.decoder_gd1 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=3,
                                              stride=2, padding=1, output_padding=1)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Bottleneck on the deepest fused feature (resolution unchanged).
        self.conv1 = nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.final_conv = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, x1, x2, x3):
        """Run the three per-input branches and the fused generation path.

        Args:
            x1, x2, x3: tensors of shape (B, 1, H, W). The layer arithmetic
                (five stride-2 encoder stages) assumes H and W are divisible
                by 32. Other sizes may fail in the fusion concatenations or
                produce outputs whose size differs from the input. The
                reference resolution is 192x384.

        Returns:
            Tuple ``(g_target, out_i1, out_i2, out_i3)``. These are the fused
            output and the three per-input reconstructions, each
            (B, 1, H, W) with values in [-1, 1] (tanh).
        """
        # Branch 1 (shapes for a 192x384 input).
        e1_i1 = self.encoder1_i1(x1)  # (B, 64, 96, 192)
        e2_i1 = self.encoder2_i1(e1_i1)  # (B, 128, 48, 96)
        e3_i1 = self.encoder3_i1(e2_i1)  # (B, 256, 24, 48)
        e4_i1 = self.encoder4_i1(e3_i1)  # (B, 512, 12, 24)
        e5_i1 = self.encoder5_i1(e4_i1)  # (B, 512, 6, 12)
        d4_i1 = self.decoder4_i1(e5_i1)  # (B, 512, 12, 24)
        d3_i1 = self.decoder3_i1(d4_i1)  # (B, 256, 24, 48)
        d2_i1 = self.decoder2_i1(d3_i1)  # (B, 128, 48, 96)
        d1_i1 = self.decoder1_i1(d2_i1)  # (B, 64, 96, 192)
        d0_i1 = self.decoder0_i1(d1_i1)  # (B, 1, 192, 384)
        out_i1 = torch.tanh(d0_i1)

        # Branch 2 (same shapes as branch 1).
        e1_i2 = self.encoder1_i2(x2)
        e2_i2 = self.encoder2_i2(e1_i2)
        e3_i2 = self.encoder3_i2(e2_i2)
        e4_i2 = self.encoder4_i2(e3_i2)
        e5_i2 = self.encoder5_i2(e4_i2)
        d4_i2 = self.decoder4_i2(e5_i2)
        d3_i2 = self.decoder3_i2(d4_i2)
        d2_i2 = self.decoder2_i2(d3_i2)
        d1_i2 = self.decoder1_i2(d2_i2)
        d0_i2 = self.decoder0_i2(d1_i2)
        out_i2 = torch.tanh(d0_i2)

        # Branch 3 (same shapes as branch 1).
        e1_i3 = self.encoder1_i3(x3)
        e2_i3 = self.encoder2_i3(e1_i3)
        e3_i3 = self.encoder3_i3(e2_i3)
        e4_i3 = self.encoder4_i3(e3_i3)
        e5_i3 = self.encoder5_i3(e4_i3)
        d4_i3 = self.decoder4_i3(e5_i3)
        d3_i3 = self.decoder3_i3(d4_i3)
        d2_i3 = self.decoder2_i3(d3_i3)
        d1_i3 = self.decoder1_i3(d2_i3)
        d0_i3 = self.decoder0_i3(d1_i3)
        out_i3 = torch.tanh(d0_i3)

        # Fused path, down side: fuse encoder features level by level.
        g_e1 = self.MixedFusion_block_0(e1_i1, e1_i2, e1_i3)  # (B, 128, 96, 192)
        g_e1_m = self.pool(g_e1)  # (B, 128, 48, 96)
        g_e2 = self.MixedFusion_block_d_e2(e2_i1, e2_i2, e2_i3, g_e1_m)  # (B, 256, 48, 96)
        g_e2_m = self.pool(g_e2)  # (B, 256, 24, 48)
        g_e3 = self.MixedFusion_block_d_e3(e3_i1, e3_i2, e3_i3, g_e2_m)  # (B, 512, 24, 48)
        g_e3_m = self.pool(g_e3)  # (B, 512, 12, 24)
        g_e4 = self.MixedFusion_block_d_e4(e4_i1, e4_i2, e4_i3, g_e3_m)  # (B, 1024, 12, 24)

        # Bottleneck.
        g_s1 = self.relu(self.conv1(g_e4))  # (B, 1024, 12, 24)
        g_s2 = self.relu(self.conv2(g_s1))  # (B, 512, 12, 24)

        # Fused path, up side: fuse decoder features with skips from the down side.
        g_d4 = self.MixedFusion_block_u_d4(d4_i1, d4_i2, d4_i3, g_s2, g_e4)  # (B, 256, 12, 24)
        g_d4 = self.decoder_gd4(g_d4)  # (B, 256, 24, 48)
        g_d3 = self.MixedFusion_block_u_d3(d3_i1, d3_i2, d3_i3, g_d4, g_e3)  # (B, 128, 24, 48)
        g_d3 = self.decoder_gd3(g_d3)  # (B, 128, 48, 96)
        g_d2 = self.MixedFusion_block_u_d2(d2_i1, d2_i2, d2_i3, g_d3, g_e2)  # (B, 64, 48, 96)
        g_d2 = self.decoder_gd2(g_d2)  # (B, 64, 96, 192)
        g_d1 = self.MixedFusion_block_u_d1(d1_i1, d1_i2, d1_i3, g_d2, g_e1)  # (B, 32, 96, 192)
        g_d1 = self.decoder_gd1(g_d1)  # (B, 32, 192, 384)

        g_target = self.tanh(self.final_conv(g_d1))  # (B, 1, 192, 384)
        return g_target, out_i1, out_i2, out_i3