from functools import reduce  # noqa: F401 -- not used in this module
from operator import mul  # noqa: F401 -- not used in this module

import torch
import torch.nn as nn

# Channel widths shared by the encoder/decoder pair (Dong Gong's paper code).
_FEATURE_NUM_2 = 96
_FEATURE_NUM = 128
_FEATURE_NUM_X2 = 256
# Channel width of the features consumed by VST3DDecoder.
_VST_FEATURE_NUM_IN = 768

_KERNEL = (3, 3, 3)
_PADDING = (1, 1, 1)
_LEAKY_SLOPE = 0.2

# Strides over the (D, H, W) axes of a 5-D (N, C, D, H, W) tensor.
_STRIDE_ALL = (2, 2, 2)  # halve / double every axis
_STRIDE_HW = (1, 2, 2)   # leave D unchanged, halve / double H and W


def _down_block(in_ch, out_ch, stride):
    """Return [Conv3d, BatchNorm3d, LeakyReLU] for one encoder stage.

    With a 3x3x3 kernel and padding 1, an axis of size ``L`` with stride
    ``s`` becomes ``floor((L - 1) / s) + 1 == ceil(L / s)``.
    """
    return [
        nn.Conv3d(in_ch, out_ch, _KERNEL, stride=stride, padding=_PADDING),
        nn.BatchNorm3d(out_ch),
        nn.LeakyReLU(_LEAKY_SLOPE, inplace=True),
    ]


def _up_conv(in_ch, out_ch, stride):
    """Return a ConvTranspose3d that scales each axis exactly by its stride.

    For kernel 3 and padding 1 the output length is
    ``(L - 1) * s - 2 + 3 + output_padding``; choosing
    ``output_padding = s - 1`` makes it exactly ``L * s``.
    """
    output_padding = tuple(s - 1 for s in stride)
    return nn.ConvTranspose3d(
        in_ch,
        out_ch,
        _KERNEL,
        stride=stride,
        padding=_PADDING,
        output_padding=output_padding,
    )


def _up_block(in_ch, out_ch, stride):
    """Return [ConvTranspose3d, BatchNorm3d, LeakyReLU] for one decoder stage."""
    return [
        _up_conv(in_ch, out_ch, stride),
        nn.BatchNorm3d(out_ch),
        nn.LeakyReLU(_LEAKY_SLOPE, inplace=True),
    ]


class Reconstruction3DEncoder(nn.Module):
    """3D convolutional encoder (layout follows Dong Gong's paper code).

    Four Conv3d -> BatchNorm3d -> LeakyReLU(0.2) stages map ``chnum_in``
    channels to 96 -> 128 -> 256 -> 256. The first stage uses stride
    (1, 2, 2) and the other three use stride (2, 2, 2). Overall, D is
    reduced 8x and H, W are reduced 16x (ceil division when the sizes are
    not evenly divisible).

    Args:
        chnum_in: number of input channels.

    The input is a 5-D tensor ``(N, C, D, H, W)``, as ``nn.Conv3d`` requires.
    D is presumably the frame/time axis -- confirm against callers.

    ``self.encoder`` keeps the original index layout (0-11), so existing
    state_dict checkpoints still load.
    """

    def __init__(self, chnum_in: int) -> None:
        super().__init__()
        self.chnum_in = chnum_in
        self.encoder = nn.Sequential(
            *_down_block(chnum_in, _FEATURE_NUM_2, _STRIDE_HW),
            *_down_block(_FEATURE_NUM_2, _FEATURE_NUM, _STRIDE_ALL),
            *_down_block(_FEATURE_NUM, _FEATURE_NUM_X2, _STRIDE_ALL),
            *_down_block(_FEATURE_NUM_X2, _FEATURE_NUM_X2, _STRIDE_ALL),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)


class Reconstruction3DDecoder(nn.Module):
    """Mirror of Reconstruction3DEncoder, followed by Tanh (Dong Gong's paper code + Tanh).

    Three ConvTranspose3d -> BatchNorm3d -> LeakyReLU(0.2) stages take
    256 -> 256 -> 128 -> 96 channels with stride (2, 2, 2). A final
    ConvTranspose3d with stride (1, 2, 2) then maps to ``chnum_in`` channels,
    followed by Tanh, so outputs lie in [-1, 1].

    Every axis grows by exactly its stride: D x8 and H, W x16 in total.
    For evenly divisible sizes this inverts the encoder's downsampling.

    Args:
        chnum_in: number of output channels (named to match the encoder's input).

    ``self.decoder`` keeps the original index layout (0-10).
    """

    def __init__(self, chnum_in: int) -> None:
        super().__init__()
        self.chnum_in = chnum_in
        self.decoder = nn.Sequential(
            *_up_block(_FEATURE_NUM_X2, _FEATURE_NUM_X2, _STRIDE_ALL),
            *_up_block(_FEATURE_NUM_X2, _FEATURE_NUM, _STRIDE_ALL),
            *_up_block(_FEATURE_NUM, _FEATURE_NUM_2, _STRIDE_ALL),
            _up_conv(_FEATURE_NUM_2, chnum_in, _STRIDE_HW),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decoder(x)


class VST3DDecoder(nn.Module):
    """Decode 768-channel 3D features to a ``chnum_out``-channel volume in [-1, 1].

    "VST" presumably refers to Video Swin Transformer features -- confirm.
    Each stage scales every axis exactly by its stride:

        (N, 768, D, H, W)
        -> (N, 256, 2D, 2H, 2W)           stride (2, 2, 2)
        -> (N, 256, 4D, 4H, 4W)           stride (2, 2, 2)
        -> (N, 128, 4D, 8H, 8W)           stride (1, 2, 2)
        -> (N, 96, 4D, 16H, 16W)          stride (1, 2, 2)
        -> (N, chnum_out, 4D, 32H, 32W)   stride (1, 2, 2), then Tanh

    For example, an input of shape (4, 768, 4, 8, 8) becomes
    (4, 256, 8, 16, 16) after the first stage and
    (4, chnum_out, 16, 256, 256) at the output.

    Args:
        chnum_out: number of output channels.

    ``self.transformer_decoder`` keeps the original index layout (0-13).
    """

    def __init__(self, chnum_out: int) -> None:
        super().__init__()
        self.chnum_out = chnum_out
        self.transformer_decoder = nn.Sequential(
            *_up_block(_VST_FEATURE_NUM_IN, _FEATURE_NUM_X2, _STRIDE_ALL),
            *_up_block(_FEATURE_NUM_X2, _FEATURE_NUM_X2, _STRIDE_ALL),
            *_up_block(_FEATURE_NUM_X2, _FEATURE_NUM, _STRIDE_HW),
            *_up_block(_FEATURE_NUM, _FEATURE_NUM_2, _STRIDE_HW),
            _up_conv(_FEATURE_NUM_2, chnum_out, _STRIDE_HW),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.transformer_decoder(x)