Spaces:
Sleeping
Sleeping
File size: 4,431 Bytes
0f5deb2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 | import torch.nn as nn
from functools import reduce
from operator import mul
import torch
class Reconstruction3DEncoder(nn.Module):
    """3-D convolutional encoder (layer layout from Dong Gong's paper code).

    Four ``Conv3d(3x3x3, padding 1) -> BatchNorm3d -> LeakyReLU(0.2)`` stages
    with output widths 96, 128, 256, 256. The first stage strides (1, 2, 2)
    and the other three (2, 2, 2); every stride-2 axis becomes
    ``ceil(size / 2)``, so overall the first (D) axis shrinks ~8x and H/W ~16x.

    Args:
        chnum_in: channels of the input volume, a 5-D tensor
            (N, chnum_in, D, H, W) as required by BatchNorm3d.
            NOTE(review): D is presumably the time axis — confirm with callers.
    """

    def __init__(self, chnum_in):
        super().__init__()
        self.chnum_in = chnum_in

        # (output width, stride) for each conv stage, applied in this order.
        plan = (
            (96, (1, 2, 2)),
            (128, (2, 2, 2)),
            (256, (2, 2, 2)),
            (256, (2, 2, 2)),
        )
        modules = []
        width = chnum_in
        for next_width, step in plan:
            modules.append(
                nn.Conv3d(width, next_width, (3, 3, 3), stride=step, padding=(1, 1, 1))
            )
            modules.append(nn.BatchNorm3d(next_width))
            modules.append(nn.LeakyReLU(0.2, inplace=True))
            width = next_width
        # Flat Sequential keeps the original "encoder.<idx>.*" state_dict keys.
        self.encoder = nn.Sequential(*modules)

    def forward(self, x):
        """Return the encoded feature volume for ``x``."""
        return self.encoder(x)
class Reconstruction3DDecoder(nn.Module):
    """3-D transposed-convolution decoder (Dong Gong's paper code + Tanh).

    Three ``ConvTranspose3d(stride 2, output_padding 1) -> BatchNorm3d ->
    LeakyReLU(0.2)`` stages with widths 256 -> 256 -> 128 -> 96. These are
    followed by a final ``ConvTranspose3d`` to ``chnum_in`` channels with
    stride (1, 2, 2) and then ``Tanh``. With kernel 3 / padding 1, a stride-2
    axis exactly doubles and a stride-1 axis is unchanged, so:

        (N, 256, D, H, W) -> (N, chnum_in, 8*D, 16*H, 16*W)

    and every output value lies in (-1, 1).

    Args:
        chnum_in: number of channels in the reconstructed output.
    """

    def __init__(self, chnum_in):
        super().__init__()
        self.chnum_in = chnum_in

        def upsample(c_in, c_out, stride, out_pad):
            # Shared 3x3x3 / padding-1 transposed conv used by every stage.
            return nn.ConvTranspose3d(
                c_in, c_out, (3, 3, 3),
                stride=stride, padding=(1, 1, 1), output_padding=out_pad,
            )

        widths = (256, 256, 128, 96)
        blocks = []
        for c_in, c_out in zip(widths, widths[1:]):
            blocks.extend([
                upsample(c_in, c_out, (2, 2, 2), (1, 1, 1)),
                nn.BatchNorm3d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Last stage keeps the first (D) axis size and maps back to chnum_in channels.
        blocks.append(upsample(widths[-1], chnum_in, (1, 2, 2), (0, 1, 1)))
        blocks.append(nn.Tanh())
        self.decoder = nn.Sequential(*blocks)

    def forward(self, x):
        """Return the decoded (Tanh-bounded) volume for ``x``."""
        return self.decoder(x)
class VST3DDecoder(nn.Module):
    """Upsample a 5-D feature volume into a ``chnum_out``-channel volume.

    Layer pattern follows Dong Gong's paper code with a final Tanh, so every
    output value lies in (-1, 1). Each ConvTranspose3d uses kernel 3 and
    padding 1. A stride-2 axis (output_padding 1) therefore exactly doubles,
    and a stride-1 axis (output_padding 0) is unchanged. Overall:

        (B, chnum_in, D, H, W) -> (B, chnum_out, 4*D, 32*H, 32*W)

    e.g. (4, 768, 4, 8, 8) -> (4, chnum_out, 16, 256, 256).

    NOTE(review): the 768-channel default presumably matches the final feature
    width of a video transformer backbone (the "VST" name suggests Video Swin
    Transformer) — confirm against the caller.

    Args:
        chnum_out: number of channels in the reconstructed output.
        chnum_in: number of channels in the incoming feature volume. Defaults
            to 768, the previously hard-coded value, so existing
            ``VST3DDecoder(chnum_out)`` calls behave exactly as before.
    """

    def __init__(self, chnum_out, chnum_in=768):
        super().__init__()
        self.chnum_out = chnum_out
        self.chnum_in = chnum_in
        feature_num = 128
        feature_num_2 = 96
        feature_num_x2 = 256
        self.transformer_decoder = nn.Sequential(
            # (B, chnum_in, D, H, W) -> (B, 256, 2D, 2H, 2W)
            nn.ConvTranspose3d(chnum_in, feature_num_x2, (3, 3, 3), stride=(2, 2, 2),
                               padding=(1, 1, 1), output_padding=(1, 1, 1)),
            nn.BatchNorm3d(feature_num_x2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (B, 256, 4D, 4H, 4W)
            nn.ConvTranspose3d(feature_num_x2, feature_num_x2, (3, 3, 3), stride=(2, 2, 2),
                               padding=(1, 1, 1), output_padding=(1, 1, 1)),
            nn.BatchNorm3d(feature_num_x2),
            nn.LeakyReLU(0.2, inplace=True),
            # From here on the D axis stays fixed (stride 1, output_padding 0).
            # -> (B, 128, 4D, 8H, 8W)
            nn.ConvTranspose3d(feature_num_x2, feature_num, (3, 3, 3), stride=(1, 2, 2),
                               padding=(1, 1, 1), output_padding=(0, 1, 1)),
            nn.BatchNorm3d(feature_num),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (B, 96, 4D, 16H, 16W)
            nn.ConvTranspose3d(feature_num, feature_num_2, (3, 3, 3), stride=(1, 2, 2),
                               padding=(1, 1, 1), output_padding=(0, 1, 1)),
            nn.BatchNorm3d(feature_num_2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (B, chnum_out, 4D, 32H, 32W), squashed into (-1, 1)
            nn.ConvTranspose3d(feature_num_2, self.chnum_out, (3, 3, 3), stride=(1, 2, 2),
                               padding=(1, 1, 1), output_padding=(0, 1, 1)),
            nn.Tanh(),
        )

    def forward(self, x):
        """Decode feature volume ``x`` (B, chnum_in, D, H, W) into (B, chnum_out, 4D, 32H, 32W)."""
        x = self.transformer_decoder(x)
        return x