# ipad-vad-training / IPAD / model / reconstruction_model.py
# IPAD model implementation: 3D convolutional encoder/decoder modules.
import torch.nn as nn
from functools import reduce
from operator import mul
import torch
class Reconstruction3DEncoder(nn.Module):
    """3D convolutional encoder (architecture from Dong Gong's paper code).

    Four Conv3d -> BatchNorm3d -> LeakyReLU stages compress a video clip
    of shape (N, chnum_in, D, H, W): the first stage downsamples space
    only (stride 1 in time), the remaining three halve time and space
    together, yielding a (N, 256, D/8, H/16, W/16) feature volume.
    """

    def __init__(self, chnum_in):
        """chnum_in: number of channels in the input clip (e.g. 1 or 3)."""
        super(Reconstruction3DEncoder, self).__init__()
        self.chnum_in = chnum_in
        ch_96, ch_128, ch_256 = 96, 128, 256
        # One tuple per stage: (in_channels, out_channels, stride).
        stages = [
            (self.chnum_in, ch_96, (1, 2, 2)),  # spatial downsample only
            (ch_96, ch_128, (2, 2, 2)),
            (ch_128, ch_256, (2, 2, 2)),
            (ch_256, ch_256, (2, 2, 2)),
        ]
        layers = []
        for c_in, c_out, stride in stages:
            layers.extend([
                nn.Conv3d(c_in, c_out, (3, 3, 3), stride=stride, padding=(1, 1, 1)),
                nn.BatchNorm3d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Encode a clip tensor into the bottleneck feature volume."""
        return self.encoder(x)
class Reconstruction3DDecoder(nn.Module):
    """3D transposed-convolution decoder (Dong Gong's paper code + Tanh).

    Mirror image of Reconstruction3DEncoder: three stride-(2,2,2)
    upsampling stages followed by one spatial-only stride-(1,2,2) stage,
    finishing with Tanh so the reconstruction lies in [-1, 1].
    """

    def __init__(self, chnum_in):
        """chnum_in: number of channels in the reconstructed output clip."""
        super(Reconstruction3DDecoder, self).__init__()
        self.chnum_in = chnum_in
        ch_96, ch_128, ch_256 = 96, 128, 256
        # One tuple per stage: (in_ch, out_ch, stride, output_padding).
        stages = [
            (ch_256, ch_256, (2, 2, 2), (1, 1, 1)),
            (ch_256, ch_128, (2, 2, 2), (1, 1, 1)),
            (ch_128, ch_96, (2, 2, 2), (1, 1, 1)),
        ]
        layers = []
        for c_in, c_out, stride, out_pad in stages:
            layers.extend([
                nn.ConvTranspose3d(c_in, c_out, (3, 3, 3), stride=stride,
                                   padding=(1, 1, 1), output_padding=out_pad),
                nn.BatchNorm3d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Final stage: spatial-only upsampling straight to the output channels.
        layers.extend([
            nn.ConvTranspose3d(ch_96, self.chnum_in, (3, 3, 3), stride=(1, 2, 2),
                               padding=(1, 1, 1), output_padding=(0, 1, 1)),
            nn.Tanh(),
        ])
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Decode a bottleneck feature volume back to a clip in [-1, 1]."""
        return self.decoder(x)
class VST3DDecoder(nn.Module):
    """Decoder for 768-channel transformer features (paper code + Tanh).

    Takes a (N, 768, d, h, w) feature volume (comment in the original
    code gives (4, 768, 4, 8, 8) as the expected shape) and upsamples it
    with five ConvTranspose3d stages: the first two double time and
    space, the last three double space only; Tanh bounds the output to
    [-1, 1].
    """

    def __init__(self, chnum_out):
        """chnum_out: number of channels in the decoded output clip."""
        super(VST3DDecoder, self).__init__()
        self.chnum_out = chnum_out
        ch_96, ch_128, ch_256 = 96, 128, 256
        ch_in = 768
        # One tuple per stage: (in_ch, out_ch, stride, output_padding).
        stages = [
            (ch_in, ch_256, (2, 2, 2), (1, 1, 1)),
            (ch_256, ch_256, (2, 2, 2), (1, 1, 1)),
            (ch_256, ch_128, (1, 2, 2), (0, 1, 1)),
            (ch_128, ch_96, (1, 2, 2), (0, 1, 1)),
        ]
        layers = []
        for c_in, c_out, stride, out_pad in stages:
            layers.extend([
                nn.ConvTranspose3d(c_in, c_out, (3, 3, 3), stride=stride,
                                   padding=(1, 1, 1), output_padding=out_pad),
                nn.BatchNorm3d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Final stage: spatial-only upsampling straight to the output channels.
        layers.extend([
            nn.ConvTranspose3d(ch_96, self.chnum_out, (3, 3, 3), stride=(1, 2, 2),
                               padding=(1, 1, 1), output_padding=(0, 1, 1)),
            nn.Tanh(),
        ])
        self.transformer_decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Decode transformer features into a clip bounded to [-1, 1]."""
        return self.transformer_decoder(x)