# Synthetic_Model / generator.py
# (Hugging Face page header: uploaded by zhang0319, "Update generator.py",
#  commit f8e4a15 verified, raw file size 16.1 kB)
import torch
from torch import nn
from torch.nn import functional as F
import torch
import torch.nn as nn
import torch.nn.functional as F
class ParameterAttention(nn.Module):
    """CBAM-style channel attention with a residual connection.

    Max- and average-pooled channel descriptors are passed through a shared
    two-layer bottleneck MLP; the summed logits are squashed with a sigmoid
    and used to re-weight the input channels, after which the input itself
    is added back (residual connection).
    """

    def __init__(self, in_channels, reduction_ratio=0.5):
        super(ParameterAttention, self).__init__()
        hidden = max(1, int(in_channels * reduction_ratio))
        self.dense_one = nn.Linear(in_channels, hidden)
        self.dense_two = nn.Linear(hidden, in_channels)

    def _mlp(self, pooled):
        # Shared bottleneck MLP applied to a (B, C) channel descriptor.
        return self.dense_two(F.relu(self.dense_one(pooled)))

    def forward(self, x):
        batch, channels = x.shape[:2]
        descriptors = (
            F.adaptive_max_pool2d(x, 1).view(batch, channels),
            F.adaptive_avg_pool2d(x, 1).view(batch, channels),
        )
        logits = self._mlp(descriptors[0]) + self._mlp(descriptors[1])
        gate = torch.sigmoid(logits).view(batch, channels, 1, 1)
        return gate * x + x  # residual connection
class MixedFusionBlockD(nn.Module):
    """Downsampling-path fusion block.

    Stacks three same-shaped feature maps through 15 element-wise fusion
    variants, gates the stack with channel attention, compresses it back to
    ``in_channels`` (3x3 conv + InstanceNorm + LeakyReLU(0.2)), then
    concatenates the carried feature map ``xx`` before projecting to
    ``out_channels`` with a second conv/norm/activation stage.
    """

    def __init__(self, in_channels, out_channels):
        super(MixedFusionBlockD, self).__init__()
        # 15 fusion variants -> 15x channel stack feeds the attention gate.
        self.param_attention = ParameterAttention(in_channels * 15)
        self.conv1 = nn.Conv2d(in_channels * 15, in_channels, kernel_size=3, padding=1)
        self.bn1 = nn.InstanceNorm2d(in_channels, affine=True)
        # cat([fused, xx]) totals 2 * in_channels.
        self.conv2 = nn.Conv2d(in_channels * 2, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.InstanceNorm2d(out_channels, affine=True)

    @staticmethod
    def _fusion_stack(x1, x2, x3):
        # Identity terms plus sum / product / max / |difference| per pair,
        # concatenated along the channel axis (order matters for conv1).
        pieces = [x1, x2, x3]
        for a, b in ((x1, x2), (x1, x3), (x2, x3)):
            pieces.extend((a + b, a * b, torch.max(a, b), torch.abs(a - b)))
        return torch.cat(pieces, dim=1)

    def forward(self, x1, x2, x3, xx):
        fused = self.param_attention(self._fusion_stack(x1, x2, x3))
        fused = F.leaky_relu(self.bn1(self.conv1(fused)), negative_slope=0.2)
        merged = torch.cat([fused, xx], dim=1)
        return F.leaky_relu(self.bn2(self.conv2(merged)), negative_slope=0.2)
class MixedFusionBlockU(nn.Module):
    """Upsampling-path fusion block.

    Builds the 15-way fusion stack of three same-shaped decoder feature
    maps, gates it with channel attention, compresses it back to
    ``in_channels`` (3x3 conv + InstanceNorm + ReLU), then concatenates the
    carried feature map ``xx`` and the encoder skip connection ``skip``
    before projecting to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super(MixedFusionBlockU, self).__init__()
        # 15 fusion variants -> 15x channel stack feeds the attention gate.
        self.param_attention = ParameterAttention(in_channels * 15)
        self.conv1 = nn.Conv2d(in_channels * 15, in_channels, kernel_size=3, padding=1)
        self.bn1 = nn.InstanceNorm2d(in_channels, affine=True)
        # cat([fused, xx, skip]) is expected to total 4 * in_channels.
        self.conv2 = nn.Conv2d(in_channels * 4, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.InstanceNorm2d(out_channels, affine=True)

    @staticmethod
    def _fusion_stack(x1, x2, x3):
        # Identity terms plus sum / product / max / |difference| per pair,
        # concatenated along the channel axis (order matters for conv1).
        pieces = [x1, x2, x3]
        for a, b in ((x1, x2), (x1, x3), (x2, x3)):
            pieces.extend((a + b, a * b, torch.max(a, b), torch.abs(a - b)))
        return torch.cat(pieces, dim=1)

    def forward(self, x1, x2, x3, xx, skip):
        fused = self.param_attention(self._fusion_stack(x1, x2, x3))
        fused = F.relu(self.bn1(self.conv1(fused)))
        merged = torch.cat([fused, xx, skip], dim=1)
        return F.relu(self.bn2(self.conv2(merged)))
class MixedFusionBlock0(nn.Module):
    """Entry fusion block (no carried feature map).

    Same 15-way fusion stack and channel-attention gate as the other fusion
    blocks, followed by two 3x3 conv + InstanceNorm + LeakyReLU(0.2) stages
    that map ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super(MixedFusionBlock0, self).__init__()
        # 15 fusion variants -> 15x channel stack feeds the attention gate.
        self.param_attention = ParameterAttention(in_channels * 15)
        self.conv1 = nn.Conv2d(in_channels * 15, in_channels, kernel_size=3, padding=1)
        self.bn1 = nn.InstanceNorm2d(in_channels, affine=True)
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.InstanceNorm2d(out_channels, affine=True)

    @staticmethod
    def _fusion_stack(x1, x2, x3):
        # Identity terms plus sum / product / max / |difference| per pair,
        # concatenated along the channel axis (order matters for conv1).
        pieces = [x1, x2, x3]
        for a, b in ((x1, x2), (x1, x3), (x2, x3)):
            pieces.extend((a + b, a * b, torch.max(a, b), torch.abs(a - b)))
        return torch.cat(pieces, dim=1)

    def forward(self, x1, x2, x3):
        fused = self.param_attention(self._fusion_stack(x1, x2, x3))
        fused = F.leaky_relu(self.bn1(self.conv1(fused)), negative_slope=0.2)
        return F.leaky_relu(self.bn2(self.conv2(fused)), negative_slope=0.2)
# class EncoderBlock(nn.Module):
# """Encoder block"""
#
# def __init__(self, inplanes, outplanes, kernel_size=3, stride=2, padding=1, norm=True):
# super().__init__()
# self.lrelu = nn.LeakyReLU(0.2, inplace=True)
# self.conv = nn.Conv2d(inplanes, outplanes, kernel_size, stride, padding)
# self.bn = nn.BatchNorm2d(outplanes) if norm else None
#
# def forward(self, x):
# fx = self.lrelu(x)
# fx = self.conv(fx)
# if self.bn is not None:
# fx = self.bn(fx)
# return fx
class EncoderBlock(nn.Module):
    """Two-convolution encoder stage that halves the spatial resolution.

    Stage 1: stride-1 conv (+ optional BatchNorm) + LeakyReLU(0.2).
    Stage 2: stride-2 conv (+ optional BatchNorm) + LeakyReLU(0.2).
    """

    def __init__(self, inplanes, outplanes, kernel_size=3, padding=1, norm=True):
        super().__init__()
        self.lrelu = nn.LeakyReLU(0.2, inplace=False)
        # Resolution-preserving convolution.
        self.conv1 = nn.Conv2d(inplanes, outplanes, kernel_size, stride=1, padding=padding)
        self.bn1 = nn.BatchNorm2d(outplanes) if norm else None
        # Stride-2 convolution performs the downsampling.
        self.conv2 = nn.Conv2d(outplanes, outplanes, kernel_size, stride=2, padding=padding)
        self.bn2 = nn.BatchNorm2d(outplanes) if norm else None

    def forward(self, x):
        out = x
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = conv(out)
            if bn is not None:
                out = bn(out)
            out = self.lrelu(out)
        return out
class DecoderBlock(nn.Module):
    """Decoder stage: transposed conv (upsampling), BatchNorm, ReLU, and an
    optional 2-D dropout (p=0.4) after the activation."""

    def __init__(self, inplanes, outplanes, kernel_size=3, stride=2, padding=1, dropout=False, output_padding=1):
        super().__init__()
        self.relu = nn.ReLU(inplace=False)
        self.deconv = nn.ConvTranspose2d(inplanes, outplanes, kernel_size, stride, padding, output_padding=output_padding)
        self.bn = nn.BatchNorm2d(outplanes)
        self.dropout = nn.Dropout2d(p=0.4, inplace=False) if dropout else None

    def forward(self, x):
        out = self.relu(self.bn(self.deconv(x)))
        return out if self.dropout is None else self.dropout(out)
class UnetGenerator(nn.Module):
    """U-Net like Encoder-Decoder model.

    Three independent single-channel U-Net branches (one per input) each
    produce their own tanh-bounded reconstruction.  In parallel, a fused
    generation path combines the three branches' encoder and decoder
    features through the MixedFusion blocks to produce the main output.

    NOTE(review): the spatial sizes in the comments below assume 192x384
    inputs, as stated in ``forward``'s docstring -- confirm against the
    data pipeline before relying on them.
    """
    def __init__(self):
        super().__init__()
        # ---- branch for input 1 ----
        # Encoder: each EncoderBlock halves the spatial resolution.
        self.encoder1_i1 = EncoderBlock(1, 64)  # 192x384 -> 96x192
        self.encoder2_i1 = EncoderBlock(64, 128)  # 96x192 -> 48x96
        self.encoder3_i1 = EncoderBlock(128, 256)  # 48x96 -> 24x48
        self.encoder4_i1 = EncoderBlock(256, 512)  # 24x48 -> 12x24
        self.encoder5_i1 = EncoderBlock(512, 512, norm=False)  # 12x24 -> 6x12 (no norm at bottleneck)
        # Decoder: each DecoderBlock doubles the spatial resolution.
        self.decoder4_i1 = DecoderBlock(512, 512, dropout=True, output_padding=1)  # 6x12 -> 12x24
        self.decoder3_i1 = DecoderBlock(512, 256, dropout=True, output_padding=1)  # 12x24 -> 24x48
        self.decoder2_i1 = DecoderBlock(256, 128, output_padding=1)  # 24x48 -> 48x96
        self.decoder1_i1 = DecoderBlock(128, 64, output_padding=1)  # 48x96 -> 96x192
        self.decoder0_i1 = nn.ConvTranspose2d(64, 1, kernel_size=3, stride=2, padding=1, output_padding=1)  # 96x192 -> 192x384
        # ---- branch for input 2 (same layout as branch 1) ----
        self.encoder1_i2 = EncoderBlock(1, 64)  # 192x384 -> 96x192
        self.encoder2_i2 = EncoderBlock(64, 128)  # 96x192 -> 48x96
        self.encoder3_i2 = EncoderBlock(128, 256)  # 48x96 -> 24x48
        self.encoder4_i2 = EncoderBlock(256, 512)  # 24x48 -> 12x24
        self.encoder5_i2 = EncoderBlock(512, 512, norm=False)  # 12x24 -> 6x12
        self.decoder4_i2 = DecoderBlock(512, 512, dropout=True, output_padding=1)  # 6x12 -> 12x24
        self.decoder3_i2 = DecoderBlock(512, 256, dropout=True, output_padding=1)  # 12x24 -> 24x48
        self.decoder2_i2 = DecoderBlock(256, 128, output_padding=1)  # 24x48 -> 48x96
        self.decoder1_i2 = DecoderBlock(128, 64, output_padding=1)  # 48x96 -> 96x192
        self.decoder0_i2 = nn.ConvTranspose2d(64, 1, kernel_size=3, stride=2, padding=1, output_padding=1)  # 96x192 -> 192x384
        # ---- branch for input 3 (same layout as branch 1) ----
        self.encoder1_i3 = EncoderBlock(1, 64)  # 192x384 -> 96x192
        self.encoder2_i3 = EncoderBlock(64, 128)  # 96x192 -> 48x96
        self.encoder3_i3 = EncoderBlock(128, 256)  # 48x96 -> 24x48
        self.encoder4_i3 = EncoderBlock(256, 512)  # 24x48 -> 12x24
        self.encoder5_i3 = EncoderBlock(512, 512, norm=False)  # 12x24 -> 6x12
        self.decoder4_i3 = DecoderBlock(512, 512, dropout=True, output_padding=1)  # 6x12 -> 12x24
        self.decoder3_i3 = DecoderBlock(512, 256, dropout=True, output_padding=1)  # 12x24 -> 24x48
        self.decoder2_i3 = DecoderBlock(256, 128, output_padding=1)  # 24x48 -> 48x96
        self.decoder1_i3 = DecoderBlock(128, 64, output_padding=1)  # 48x96 -> 96x192
        self.decoder0_i3 = nn.ConvTranspose2d(64, 1, kernel_size=3, stride=2, padding=1,output_padding=1)  # 96x192 -> 192x384
        # ---- fused generation path ----
        # Encoder-side fusion: channel counts double at each stage.
        self.MixedFusion_block_0 = MixedFusionBlock0(in_channels=64, out_channels=128)
        self.MixedFusion_block_d_e2 = MixedFusionBlockD(in_channels=128, out_channels=256)
        self.MixedFusion_block_d_e3 = MixedFusionBlockD(in_channels=256, out_channels=512)
        self.MixedFusion_block_d_e4 = MixedFusionBlockD(in_channels=512, out_channels=1024)
        # Decoder-side fusion: channel counts halve at each stage.
        self.MixedFusion_block_u_d4 = MixedFusionBlockU(in_channels=512, out_channels=256)
        self.MixedFusion_block_u_d3 = MixedFusionBlockU(in_channels=256, out_channels=128)
        self.MixedFusion_block_u_d2 = MixedFusionBlockU(in_channels=128, out_channels=64)
        self.MixedFusion_block_u_d1 = MixedFusionBlockU(in_channels=64, out_channels=32)
        # Channel-preserving 2x upsamplers between fusion-decoder stages.
        self.decoder_gd4 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.decoder_gd3 = nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.decoder_gd2 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.decoder_gd1 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Bottleneck of the fused path: 1024 -> 1024 -> 512 channels.
        self.conv1 = nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        # Final 32 -> 1 projection, bounded to (-1, 1) by tanh.
        self.final_conv = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()
    def forward(self, x1, x2, x3):
        """
        Forward pass for the U-Net Generator.
        Inputs:
            x1, x2, x3: Tensors of shape (B, 1, 192, 384), one per branch.
        Outputs:
            g_target: fused synthetic output, shape (B, 1, 192, 384)
            out_i1, out_i2, out_i3: per-branch reconstructions, same shape.
        """
        # Branch 1: plain encoder-decoder pass; intermediate features are
        # kept for the fused path below.
        e1_i1 = self.encoder1_i1(x1)  # (B, 64, 96, 192)
        e2_i1 = self.encoder2_i1(e1_i1)  # (B, 128, 48, 96)
        e3_i1 = self.encoder3_i1(e2_i1)  # (B, 256, 24, 48)
        e4_i1 = self.encoder4_i1(e3_i1)  # (B, 512, 12, 24)
        e5_i1 = self.encoder5_i1(e4_i1)  # (B, 512, 6, 12)
        d4_i1 = self.decoder4_i1(e5_i1)  # (B, 512, 12, 24)
        d3_i1 = self.decoder3_i1(d4_i1)  # (B, 256, 24, 48)
        d2_i1 = self.decoder2_i1(d3_i1)  # (B, 128, 48, 96)
        d1_i1 = self.decoder1_i1(d2_i1)  # (B, 64, 96, 192)
        d0_i1 = self.decoder0_i1(d1_i1)  # (B, 1, 192, 384)
        out_i1 = torch.tanh(d0_i1)  # branch-1 reconstruction
        # Branch 2: identical structure to branch 1.
        e1_i2 = self.encoder1_i2(x2)
        e2_i2 = self.encoder2_i2(e1_i2)
        e3_i2 = self.encoder3_i2(e2_i2)
        e4_i2 = self.encoder4_i2(e3_i2)
        e5_i2 = self.encoder5_i2(e4_i2)
        d4_i2 = self.decoder4_i2(e5_i2)
        d3_i2 = self.decoder3_i2(d4_i2)
        d2_i2 = self.decoder2_i2(d3_i2)
        d1_i2 = self.decoder1_i2(d2_i2)
        d0_i2 = self.decoder0_i2(d1_i2)
        out_i2 = torch.tanh(d0_i2)
        # Branch 3: identical structure to branch 1.
        e1_i3 = self.encoder1_i3(x3)
        e2_i3 = self.encoder2_i3(e1_i3)
        e3_i3 = self.encoder3_i3(e2_i3)
        e4_i3 = self.encoder4_i3(e3_i3)
        e5_i3 = self.encoder5_i3(e4_i3)
        d4_i3 = self.decoder4_i3(e5_i3)
        d3_i3 = self.decoder3_i3(d4_i3)
        d2_i3 = self.decoder2_i3(d3_i3)
        d1_i3 = self.decoder1_i3(d2_i3)
        d0_i3 = self.decoder0_i3(d1_i3)
        out_i3 = torch.tanh(d0_i3)
        # Fused generation path: encoder side (fuse, then pool to match the
        # next stage's spatial size).
        g_e1 = self.MixedFusion_block_0(e1_i1, e1_i2, e1_i3)  # (B, 128, 96, 192)
        g_e1_m = self.pool(g_e1)  # (B, 128, 48, 96)
        g_e2 = self.MixedFusion_block_d_e2(e2_i1, e2_i2, e2_i3, g_e1_m)  # (B, 256, 48, 96)
        g_e2_m = self.pool(g_e2)  # (B, 256, 24, 48)
        g_e3 = self.MixedFusion_block_d_e3(e3_i1, e3_i2, e3_i3, g_e2_m)  # (B, 512, 24, 48)
        g_e3_m = self.pool(g_e3)  # (B, 512, 12, 24)
        g_e4 = self.MixedFusion_block_d_e4(e4_i1, e4_i2, e4_i3, g_e3_m)  # (B, 1024, 12, 24)
        # Bottleneck convolutions (resolution preserved).
        g_s1 = self.relu(self.conv1(g_e4))  # (B, 1024, 12, 24)
        g_s2 = self.relu(self.conv2(g_s1))  # (B, 512, 12, 24)
        # Fused decoder side: fuse the three branch decoder features with the
        # previous fused stage (xx) and the encoder-side skip, then upsample.
        g_d4 = self.MixedFusion_block_u_d4(d4_i1, d4_i2, d4_i3, g_s2, g_e4)  # (B, 256, 12, 24)
        g_d4 = self.decoder_gd4(g_d4)  # (B, 256, 24, 48)
        g_d3 = self.MixedFusion_block_u_d3(d3_i1, d3_i2, d3_i3, g_d4, g_e3)  # (B, 128, 24, 48)
        g_d3 = self.decoder_gd3(g_d3)  # (B, 128, 48, 96)
        g_d2 = self.MixedFusion_block_u_d2(d2_i1, d2_i2, d2_i3, g_d3, g_e2)  # (B, 64, 48, 96)
        g_d2 = self.decoder_gd2(g_d2)  # (B, 64, 96, 192)
        g_d1 = self.MixedFusion_block_u_d1(d1_i1, d1_i2, d1_i3, g_d2, g_e1)  # (B, 32, 96, 192)
        g_d1 = self.decoder_gd1(g_d1)  # (B, 32, 192, 384)
        g_target = self.tanh(self.final_conv(g_d1))  # (B, 1, 192, 384)
        return g_target, out_i1, out_i2, out_i3