# NOTE(review): the three lines below are HuggingFace-Space status text left over
# from scraping ("Spaces: Sleeping") — commented out so the file parses as Python.
# Spaces: Sleeping
| import torch.nn as nn | |
| from functools import reduce | |
| from operator import mul | |
| import torch | |
class Reconstruction3DEncoder(nn.Module):
    """3D convolutional encoder (architecture from Dong Gong's paper).

    Four Conv3d->BatchNorm3d->LeakyReLU stages map a video clip of
    ``chnum_in`` channels to a 256-channel feature volume.  The first
    stage downsamples only spatially (stride ``(1, 2, 2)``); the three
    following stages halve the temporal and spatial extents
    (stride ``(2, 2, 2)``).
    """

    def __init__(self, chnum_in):
        super().__init__()
        # Number of channels in the input clip (e.g. 1 for grayscale).
        self.chnum_in = chnum_in
        ch_mid, ch_base, ch_top = 96, 128, 256

        def _stage(c_in, c_out, stride):
            # One Conv -> BN -> LeakyReLU(0.2) triple; 3x3x3 kernel, padding 1.
            return [
                nn.Conv3d(c_in, c_out, (3, 3, 3), stride=stride, padding=(1, 1, 1)),
                nn.BatchNorm3d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # Same module order/indices as a flat Sequential, so state_dict
        # keys are identical to the original layout.
        stages = (
            _stage(self.chnum_in, ch_mid, (1, 2, 2))
            + _stage(ch_mid, ch_base, (2, 2, 2))
            + _stage(ch_base, ch_top, (2, 2, 2))
            + _stage(ch_top, ch_top, (2, 2, 2))
        )
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Encode ``x`` (N, chnum_in, D, H, W) into a feature volume."""
        return self.encoder(x)
class Reconstruction3DDecoder(nn.Module):
    """3D transposed-convolutional decoder (Dong Gong's paper + Tanh).

    Mirrors :class:`Reconstruction3DEncoder`: three stages upsample both
    temporally and spatially (stride ``(2, 2, 2)``), the last stage
    upsamples spatially only (stride ``(1, 2, 2)``) and squashes the
    output to ``[-1, 1]`` with Tanh.
    """

    def __init__(self, chnum_in):
        super().__init__()
        # Number of channels in the reconstructed output clip.
        self.chnum_in = chnum_in
        ch_mid, ch_base, ch_top = 96, 128, 256

        def _up(c_in, c_out, stride, out_pad):
            # ConvTranspose -> BN -> LeakyReLU(0.2); output_padding matches
            # the stride so each dimension is exactly doubled (or kept).
            return [
                nn.ConvTranspose3d(c_in, c_out, (3, 3, 3), stride=stride,
                                   padding=(1, 1, 1), output_padding=out_pad),
                nn.BatchNorm3d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # Module order/indices match the original flat Sequential, so
        # state_dict keys are unchanged.
        layers = (
            _up(ch_top, ch_top, (2, 2, 2), (1, 1, 1))
            + _up(ch_top, ch_base, (2, 2, 2), (1, 1, 1))
            + _up(ch_base, ch_mid, (2, 2, 2), (1, 1, 1))
        )
        layers.append(
            nn.ConvTranspose3d(ch_mid, self.chnum_in, (3, 3, 3), stride=(1, 2, 2),
                               padding=(1, 1, 1), output_padding=(0, 1, 1))
        )
        layers.append(nn.Tanh())
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Decode a feature volume back to a clip in ``[-1, 1]``."""
        return self.decoder(x)
class VST3DDecoder(nn.Module):
    """Decoder for 768-channel transformer features (Dong Gong style + Tanh).

    Five ConvTranspose3d stages: the first two upsample time and space
    (stride ``(2, 2, 2)``), the last three upsample space only
    (stride ``(1, 2, 2)``); a final Tanh bounds the output to ``[-1, 1]``.
    The comments note the expected shapes for a (4, 768, 4, 8, 8) input.
    """

    def __init__(self, chnum_out):
        super().__init__()
        # Number of channels in the decoded output clip.
        self.chnum_out = chnum_out
        ch_mid, ch_base, ch_top = 96, 128, 256
        ch_in = 768  # transformer feature width

        def _up(c_in, c_out, stride, out_pad):
            # ConvTranspose -> BN -> LeakyReLU(0.2) upsampling triple.
            return [
                nn.ConvTranspose3d(c_in, c_out, (3, 3, 3), stride=stride,
                                   padding=(1, 1, 1), output_padding=out_pad),
                nn.BatchNorm3d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # Module order/indices match the original flat Sequential, so
        # state_dict keys are unchanged.
        layers = (
            # (4, 768, 4, 8, 8) -> (4, 256, 8, 16, 16)
            _up(ch_in, ch_top, (2, 2, 2), (1, 1, 1))
            # (4, 256, 8, 16, 16) -> (4, 256, 16, 32, 32)
            + _up(ch_top, ch_top, (2, 2, 2), (1, 1, 1))
            + _up(ch_top, ch_base, (1, 2, 2), (0, 1, 1))
            + _up(ch_base, ch_mid, (1, 2, 2), (0, 1, 1))
        )
        layers.append(
            nn.ConvTranspose3d(ch_mid, self.chnum_out, (3, 3, 3), stride=(1, 2, 2),
                               padding=(1, 1, 1), output_padding=(0, 1, 1))
        )
        layers.append(nn.Tanh())
        self.transformer_decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Decode transformer features ``x`` into a clip in ``[-1, 1]``."""
        return self.transformer_decoder(x)