# Multi-stream U-Net generator with parameter-attention mixed-fusion blocks.
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class ParameterAttention(nn.Module):
    """Channel attention with a residual connection (CBAM-style).

    Max- and average-pooled channel descriptors pass through a shared
    two-layer MLP; the sigmoid of their sum gates the input channels,
    and the input is added back as a residual.
    """

    def __init__(self, in_channels, reduction_ratio=0.5):
        super(ParameterAttention, self).__init__()
        # Bottleneck width of the shared MLP (never below one unit).
        hidden = max(1, int(in_channels * reduction_ratio))
        self.dense_one = nn.Linear(in_channels, hidden)
        self.dense_two = nn.Linear(hidden, in_channels)

    def forward(self, x):
        batch, channels = x.shape[0], x.shape[1]
        # Global (B, C) descriptors from max- and avg-pooling over H x W.
        pooled_max = F.adaptive_max_pool2d(x, 1).view(batch, channels)
        pooled_avg = F.adaptive_avg_pool2d(x, 1).view(batch, channels)
        # Shared MLP applied to both descriptors, then summed.
        gate = self.dense_two(F.relu(self.dense_one(pooled_max)))
        gate = gate + self.dense_two(F.relu(self.dense_one(pooled_avg)))
        attention = torch.sigmoid(gate).view(batch, channels, 1, 1)
        # Gate the channels and keep the residual path.
        return attention * x + x
class MixedFusionBlockD(nn.Module):
    """Downsampling-path fusion block.

    Combines three same-shaped feature maps through 15 fusion variants
    (the 3 originals plus sum / product / max / abs-difference for each
    pair), applies channel attention, reduces back to ``in_channels``
    (InstanceNorm + LeakyReLU), then concatenates an extra map ``xx`` and
    projects to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super(MixedFusionBlockD, self).__init__()
        # 15 fusion variants are concatenated along the channel axis.
        self.param_attention = ParameterAttention(in_channels * 15)
        self.conv1 = nn.Conv2d(in_channels * 15, in_channels, kernel_size=3, padding=1)
        self.bn1 = nn.InstanceNorm2d(in_channels, affine=True)
        # conv2 sees the fused map concatenated with ``xx`` (2x channels).
        self.conv2 = nn.Conv2d(in_channels * 2, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.InstanceNorm2d(out_channels, affine=True)

    def forward(self, x1, x2, x3, xx):
        variants = [x1, x2, x3]
        for a, b in ((x1, x2), (x1, x3), (x2, x3)):
            variants.extend((a + b, a * b, torch.max(a, b), torch.abs(a - b)))
        fused = self.param_attention(torch.cat(variants, dim=1))
        fused = F.leaky_relu(self.bn1(self.conv1(fused)), negative_slope=0.2)
        merged = torch.cat([fused, xx], dim=1)
        return F.leaky_relu(self.bn2(self.conv2(merged)), negative_slope=0.2)
class MixedFusionBlockU(nn.Module):
    """Upsampling-path fusion block.

    Same 15-variant fusion + channel attention as the down path, but with
    ReLU activations and two concatenated extras in the second stage:
    ``xx`` (``in_channels``) and ``skip`` (``2 * in_channels``), so conv2
    consumes ``in_channels * 4``.
    """

    def __init__(self, in_channels, out_channels):
        super(MixedFusionBlockU, self).__init__()
        # 15 fusion variants are concatenated along the channel axis.
        self.param_attention = ParameterAttention(in_channels * 15)
        self.conv1 = nn.Conv2d(in_channels * 15, in_channels, kernel_size=3, padding=1)
        self.bn1 = nn.InstanceNorm2d(in_channels, affine=True)
        # fused (1x) + xx (1x) + skip (2x) channels feed conv2.
        self.conv2 = nn.Conv2d(in_channels * 4, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.InstanceNorm2d(out_channels, affine=True)

    def forward(self, x1, x2, x3, xx, skip):
        variants = [x1, x2, x3]
        for a, b in ((x1, x2), (x1, x3), (x2, x3)):
            variants.extend((a + b, a * b, torch.max(a, b), torch.abs(a - b)))
        fused = self.param_attention(torch.cat(variants, dim=1))
        fused = F.relu(self.bn1(self.conv1(fused)))
        merged = torch.cat([fused, xx, skip], dim=1)
        return F.relu(self.bn2(self.conv2(merged)))
class MixedFusionBlock0(nn.Module):
    """First-level fusion block (no auxiliary input).

    Applies the 15-variant fusion + channel attention, reduces back to
    ``in_channels``, then projects straight to ``out_channels``; each
    convolution is followed by InstanceNorm and LeakyReLU(0.2).
    """

    def __init__(self, in_channels, out_channels):
        super(MixedFusionBlock0, self).__init__()
        # 15 fusion variants are concatenated along the channel axis.
        self.param_attention = ParameterAttention(in_channels * 15)
        self.conv1 = nn.Conv2d(in_channels * 15, in_channels, kernel_size=3, padding=1)
        self.bn1 = nn.InstanceNorm2d(in_channels, affine=True)
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.InstanceNorm2d(out_channels, affine=True)

    def forward(self, x1, x2, x3):
        variants = [x1, x2, x3]
        for a, b in ((x1, x2), (x1, x3), (x2, x3)):
            variants.extend((a + b, a * b, torch.max(a, b), torch.abs(a - b)))
        fused = self.param_attention(torch.cat(variants, dim=1))
        fused = F.leaky_relu(self.bn1(self.conv1(fused)), negative_slope=0.2)
        return F.leaky_relu(self.bn2(self.conv2(fused)), negative_slope=0.2)
| # class EncoderBlock(nn.Module): | |
| # """Encoder block""" | |
| # | |
| # def __init__(self, inplanes, outplanes, kernel_size=3, stride=2, padding=1, norm=True): | |
| # super().__init__() | |
| # self.lrelu = nn.LeakyReLU(0.2, inplace=True) | |
| # self.conv = nn.Conv2d(inplanes, outplanes, kernel_size, stride, padding) | |
| # self.bn = nn.BatchNorm2d(outplanes) if norm else None | |
| # | |
| # def forward(self, x): | |
| # fx = self.lrelu(x) | |
| # fx = self.conv(fx) | |
| # if self.bn is not None: | |
| # fx = self.bn(fx) | |
| # return fx | |
class EncoderBlock(nn.Module):
    """Two-convolution encoder stage.

    The first convolution keeps the spatial size (stride 1); the second
    halves it (stride 2). Each convolution is optionally followed by
    BatchNorm (``norm=True``), then LeakyReLU(0.2).
    """

    def __init__(self, inplanes, outplanes, kernel_size=3, padding=1, norm=True):
        super().__init__()
        self.lrelu = nn.LeakyReLU(0.2, inplace=False)
        # Stride-1 convolution: changes channels only.
        self.conv1 = nn.Conv2d(inplanes, outplanes, kernel_size, stride=1, padding=padding)
        self.bn1 = nn.BatchNorm2d(outplanes) if norm else None
        # Stride-2 convolution: downsamples by 2.
        self.conv2 = nn.Conv2d(outplanes, outplanes, kernel_size, stride=2, padding=padding)
        self.bn2 = nn.BatchNorm2d(outplanes) if norm else None

    def forward(self, x):
        out = self.conv1(x)
        if self.bn1 is not None:
            out = self.bn1(out)
        out = self.lrelu(out)
        out = self.conv2(out)
        if self.bn2 is not None:
            out = self.bn2(out)
        return self.lrelu(out)
class DecoderBlock(nn.Module):
    """Decoder stage: transposed convolution (2x upsample at the default
    stride/padding), BatchNorm, ReLU, and optional Dropout2d(p=0.4)."""

    def __init__(self, inplanes, outplanes, kernel_size=3, stride=2, padding=1, dropout=False, output_padding=1):
        super().__init__()
        self.relu = nn.ReLU(inplace=False)
        self.deconv = nn.ConvTranspose2d(inplanes, outplanes, kernel_size, stride, padding, output_padding=output_padding)
        self.bn = nn.BatchNorm2d(outplanes)
        self.dropout = nn.Dropout2d(p=0.4, inplace=False) if dropout else None

    def forward(self, x):
        out = self.relu(self.bn(self.deconv(x)))
        return out if self.dropout is None else self.dropout(out)
class UnetGenerator(nn.Module):
    """U-Net like Encoder-Decoder model with three input streams.

    Each input runs through its own encoder-decoder stream, producing a
    per-stream tanh reconstruction; a fourth "generation" path fuses the
    streams' encoder/decoder features through MixedFusion blocks to
    synthesize the target image. Spatial sizes in the comments below assume
    (B, 1, 192, 384) inputs.
    """
    def __init__(self):
        super().__init__()
        # ---- Stream 1 ----
        # Encoder (each block halves the spatial size)
        self.encoder1_i1 = EncoderBlock(1, 64)  # 192x384 -> 96x192
        self.encoder2_i1 = EncoderBlock(64, 128)  # 96x192 -> 48x96
        self.encoder3_i1 = EncoderBlock(128, 256)  # 48x96 -> 24x48
        self.encoder4_i1 = EncoderBlock(256, 512)  # 24x48 -> 12x24
        self.encoder5_i1 = EncoderBlock(512, 512, norm=False)  # 12x24 -> 6x12
        # Decoder (each block doubles the spatial size)
        self.decoder4_i1 = DecoderBlock(512, 512, dropout=True, output_padding=1)  # 6x12 -> 12x24
        self.decoder3_i1 = DecoderBlock(512, 256, dropout=True, output_padding=1)  # 12x24 -> 24x48
        self.decoder2_i1 = DecoderBlock(256, 128, output_padding=1)  # 24x48 -> 48x96
        self.decoder1_i1 = DecoderBlock(128, 64, output_padding=1)  # 48x96 -> 96x192
        self.decoder0_i1 = nn.ConvTranspose2d(64, 1, kernel_size=3, stride=2, padding=1, output_padding=1)  # 96x192 -> 192x384
        # ---- Stream 2 (same layout as stream 1) ----
        # Encoder (Downsampling)
        self.encoder1_i2 = EncoderBlock(1, 64)  # 192x384 -> 96x192
        self.encoder2_i2 = EncoderBlock(64, 128)  # 96x192 -> 48x96
        self.encoder3_i2 = EncoderBlock(128, 256)  # 48x96 -> 24x48
        self.encoder4_i2 = EncoderBlock(256, 512)  # 24x48 -> 12x24
        self.encoder5_i2 = EncoderBlock(512, 512, norm=False)  # 12x24 -> 6x12
        # Decoder (Upsampling)
        self.decoder4_i2 = DecoderBlock(512, 512, dropout=True, output_padding=1)  # 6x12 -> 12x24
        self.decoder3_i2 = DecoderBlock(512, 256, dropout=True, output_padding=1)  # 12x24 -> 24x48
        self.decoder2_i2 = DecoderBlock(256, 128, output_padding=1)  # 24x48 -> 48x96
        self.decoder1_i2 = DecoderBlock(128, 64, output_padding=1)  # 48x96 -> 96x192
        self.decoder0_i2 = nn.ConvTranspose2d(64, 1, kernel_size=3, stride=2, padding=1, output_padding=1)  # 96x192 -> 192x384
        # ---- Stream 3 (same layout as stream 1) ----
        # Encoder (Downsampling)
        self.encoder1_i3 = EncoderBlock(1, 64)  # 192x384 -> 96x192
        self.encoder2_i3 = EncoderBlock(64, 128)  # 96x192 -> 48x96
        self.encoder3_i3 = EncoderBlock(128, 256)  # 48x96 -> 24x48
        self.encoder4_i3 = EncoderBlock(256, 512)  # 24x48 -> 12x24
        self.encoder5_i3 = EncoderBlock(512, 512, norm=False)  # 12x24 -> 6x12
        # Decoder (Upsampling)
        self.decoder4_i3 = DecoderBlock(512, 512, dropout=True, output_padding=1)  # 6x12 -> 12x24
        self.decoder3_i3 = DecoderBlock(512, 256, dropout=True, output_padding=1)  # 12x24 -> 24x48
        self.decoder2_i3 = DecoderBlock(256, 128, output_padding=1)  # 24x48 -> 48x96
        self.decoder1_i3 = DecoderBlock(128, 64, output_padding=1)  # 48x96 -> 96x192
        self.decoder0_i3 = nn.ConvTranspose2d(64, 1, kernel_size=3, stride=2, padding=1,output_padding=1)  # 96x192 -> 192x384
        # ---- Generation path: fuse the three streams level by level ----
        # Down path (channel counts: in -> out of each fusion block)
        self.MixedFusion_block_0 = MixedFusionBlock0(in_channels=64, out_channels=128)
        self.MixedFusion_block_d_e2 = MixedFusionBlockD(in_channels=128, out_channels=256)
        self.MixedFusion_block_d_e3 = MixedFusionBlockD(in_channels=256, out_channels=512)
        self.MixedFusion_block_d_e4 = MixedFusionBlockD(in_channels=512, out_channels=1024)
        # Up path
        self.MixedFusion_block_u_d4 = MixedFusionBlockU(in_channels=512, out_channels=256)
        self.MixedFusion_block_u_d3 = MixedFusionBlockU(in_channels=256, out_channels=128)
        self.MixedFusion_block_u_d2 = MixedFusionBlockU(in_channels=128, out_channels=64)
        self.MixedFusion_block_u_d1 = MixedFusionBlockU(in_channels=64, out_channels=32)
        # Channel-preserving 2x upsamplers between fusion levels
        self.decoder_gd4 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.decoder_gd3 = nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.decoder_gd2 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.decoder_gd1 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Bottleneck of the generation path: 1024 -> 1024 -> 512 channels
        self.conv1 = nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.final_conv = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()
    def forward(self, x1, x2, x3):
        """
        Forward pass for the U-Net Generator.
        Inputs:
            x1, x2, x3: tensors of shape (B, 1, 192, 384)
        Returns:
            (g_target, out_i1, out_i2, out_i3): the fused generation result
            and each stream's reconstruction, all of shape (B, 1, 192, 384).
        """
        # Stream 1: encode then decode; intermediates are reused for fusion.
        e1_i1 = self.encoder1_i1(x1)  # (B, 64, 96, 192)
        e2_i1 = self.encoder2_i1(e1_i1)  # (B, 128, 48, 96)
        e3_i1 = self.encoder3_i1(e2_i1)  # (B, 256, 24, 48)
        e4_i1 = self.encoder4_i1(e3_i1)  # (B, 512, 12, 24)
        e5_i1 = self.encoder5_i1(e4_i1)  # (B, 512, 6, 12)
        d4_i1 = self.decoder4_i1(e5_i1)  # (B, 512, 12, 24)
        d3_i1 = self.decoder3_i1(d4_i1)  # (B, 256, 24, 48)
        d2_i1 = self.decoder2_i1(d3_i1)  # (B, 128, 48, 96)
        d1_i1 = self.decoder1_i1(d2_i1)  # (B, 64, 96, 192)
        d0_i1 = self.decoder0_i1(d1_i1)
        out_i1 = torch.tanh(d0_i1)  # (B, 1, 192, 384)
        # Stream 2: same pipeline as stream 1.
        e1_i2 = self.encoder1_i2(x2)  # 96x192
        e2_i2 = self.encoder2_i2(e1_i2)  # 48x96
        e3_i2 = self.encoder3_i2(e2_i2)  # 24x48
        e4_i2 = self.encoder4_i2(e3_i2)  # 12x24
        e5_i2 = self.encoder5_i2(e4_i2)  # 6x12
        d4_i2 = self.decoder4_i2(e5_i2)
        d3_i2 = self.decoder3_i2(d4_i2)
        d2_i2 = self.decoder2_i2(d3_i2)
        d1_i2 = self.decoder1_i2(d2_i2)
        d0_i2 = self.decoder0_i2(d1_i2)
        out_i2 = torch.tanh(d0_i2)
        # Stream 3: same pipeline as stream 1.
        e1_i3 = self.encoder1_i3(x3)  # 96x192
        e2_i3 = self.encoder2_i3(e1_i3)  # 48x96
        e3_i3 = self.encoder3_i3(e2_i3)  # 24x48
        e4_i3 = self.encoder4_i3(e3_i3)  # 12x24
        e5_i3 = self.encoder5_i3(e4_i3)  # 6x12
        d4_i3 = self.decoder4_i3(e5_i3)
        d3_i3 = self.decoder3_i3(d4_i3)
        d2_i3 = self.decoder2_i3(d3_i3)
        d1_i3 = self.decoder1_i3(d2_i3)
        d0_i3 = self.decoder0_i3(d1_i3)
        out_i3 = torch.tanh(d0_i3)
        # Generation path: fuse encoder features down, decoder features up.
        g_e1 = self.MixedFusion_block_0(e1_i1, e1_i2, e1_i3)  # (B, 128, 96, 192)
        g_e1_m = self.pool(g_e1)  # (B, 128, 48, 96)
        g_e2 = self.MixedFusion_block_d_e2(e2_i1, e2_i2, e2_i3, g_e1_m)  # (B, 256, 48, 96)
        g_e2_m = self.pool(g_e2)  # (B, 256, 24, 48)
        g_e3 = self.MixedFusion_block_d_e3(e3_i1, e3_i2, e3_i3, g_e2_m)  # (B, 512, 24, 48)
        g_e3_m = self.pool(g_e3)  # (B, 512, 12, 24)
        g_e4 = self.MixedFusion_block_d_e4(e4_i1, e4_i2, e4_i3, g_e3_m)  # (B, 1024, 12, 24)
        # Bottleneck at 12x24: 1024 -> 1024 -> 512 channels.
        g_s1 = self.relu(self.conv1(g_e4))  # (B, 1024, 12, 24)
        g_s2 = self.relu(self.conv2(g_s1))  # (B, 512, 12, 24)
        # Up path: each fusion takes the streams' decoder features, the
        # previous level's output (xx), and a down-path skip (2x channels).
        g_d4 = self.MixedFusion_block_u_d4(d4_i1, d4_i2, d4_i3, g_s2, g_e4)  # (B, 256, 12, 24)
        g_d4 = self.decoder_gd4(g_d4)  # (B, 256, 24, 48)
        g_d3 = self.MixedFusion_block_u_d3(d3_i1, d3_i2, d3_i3, g_d4, g_e3)  # (B, 128, 24, 48)
        g_d3 = self.decoder_gd3(g_d3)  # (B, 128, 48, 96)
        g_d2 = self.MixedFusion_block_u_d2(d2_i1, d2_i2, d2_i3, g_d3, g_e2)  # (B, 64, 48, 96)
        g_d2 = self.decoder_gd2(g_d2)  # (B, 64, 96, 192)
        g_d1 = self.MixedFusion_block_u_d1(d1_i1, d1_i2, d1_i3, g_d2, g_e1)  # (B, 32, 96, 192)
        g_d1 = self.decoder_gd1(g_d1)  # (B, 32, 192, 384)
        g_target = self.tanh(self.final_conv(g_d1))  # (B, 1, 192, 384)
        return g_target, out_i1, out_i2, out_i3