# Source: ipad-vad-training/IPAD/model/memae_3dconv.py
# Last commit 7844753 (verified) by MSherbinii: "Fix import in memae_3dconv.py"
from __future__ import absolute_import, print_function
import torch
from torch import nn
from .memory_module import MemModule
class AutoEncoderCov3DMem(nn.Module):
    """3D convolutional autoencoder with a memory module at the bottleneck.

    The encoder is four ``Conv3d -> BatchNorm3d -> LeakyReLU(0.2)`` stages
    with channel widths 96, 128, 256 and 256.  The first stage uses stride
    ``(1, 2, 2)``; the remaining three use stride ``(2, 2, 2)``.  The
    decoder mirrors this with four ``ConvTranspose3d`` layers (256, 128, 96
    and finally ``chnum_in`` channels); the last layer has no normalisation
    or activation after it.

    Between the two, the encoded features go through ``MemModule``, which
    is expected to return a mapping with the keys ``'output'`` and
    ``'att'`` (those are the only keys read here).

    Args:
        chnum_in: Number of channels of the input, also used as the number
            of channels produced by the final decoder layer.
        mem_dim: Forwarded to ``MemModule`` as ``mem_dim``.
        shrink_thres: Forwarded to ``MemModule`` as ``shrink_thres``.
    """

    def __init__(self, chnum_in, mem_dim, shrink_thres=0.0025):
        super(AutoEncoderCov3DMem, self).__init__()
        print('AutoEncoderCov3DMem')
        self.chnum_in = chnum_in

        # Channel widths used throughout the network.
        width_narrow = 96
        width_mid = 128
        width_wide = 256

        # Every (transposed) convolution shares the same kernel and padding.
        kernel = (3, 3, 3)
        pad = (1, 1, 1)
        stride_all = (2, 2, 2)
        # Stride 1 along the first convolved axis, 2 along the other two.
        stride_partial = (1, 2, 2)

        # Encoder: (in_channels, out_channels, stride) for each stage.
        encoder_stages = (
            (self.chnum_in, width_narrow, stride_partial),
            (width_narrow, width_mid, stride_all),
            (width_mid, width_wide, stride_all),
            (width_wide, width_wide, stride_all),
        )
        encoder_layers = []
        for c_in, c_out, step in encoder_stages:
            encoder_layers.append(
                nn.Conv3d(c_in, c_out, kernel, stride=step, padding=pad)
            )
            encoder_layers.append(nn.BatchNorm3d(c_out))
            encoder_layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.encoder = nn.Sequential(*encoder_layers)

        self.mem_rep = MemModule(
            mem_dim=mem_dim, fea_dim=width_wide, shrink_thres=shrink_thres
        )

        # Decoder: three normalised/activated up-sampling stages ...
        decoder_stages = (
            (width_wide, width_wide),
            (width_wide, width_mid),
            (width_mid, width_narrow),
        )
        decoder_layers = []
        for c_in, c_out in decoder_stages:
            decoder_layers.append(
                nn.ConvTranspose3d(
                    c_in, c_out, kernel,
                    stride=stride_all, padding=pad, output_padding=(1, 1, 1),
                )
            )
            decoder_layers.append(nn.BatchNorm3d(c_out))
            decoder_layers.append(nn.LeakyReLU(0.2, inplace=True))
        # ... followed by a bare transposed convolution back to chnum_in
        # channels, mirroring the encoder's first (1, 2, 2)-stride stage.
        decoder_layers.append(
            nn.ConvTranspose3d(
                width_narrow, self.chnum_in, kernel,
                stride=stride_partial, padding=pad, output_padding=(0, 1, 1),
            )
        )
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x``, pass it through the memory module, and decode.

        Args:
            x: Input tensor fed directly to the ``Conv3d`` encoder
                (presumably laid out as (N, C, D, H, W) -- confirm against
                the caller).

        Returns:
            dict: ``'output'`` holds the decoder result and ``'att'`` holds
            the ``'att'`` entry returned by the memory module, unchanged.
        """
        encoded = self.encoder(x)
        mem_result = self.mem_rep(encoded)
        mem_features = mem_result['output']
        attention = mem_result['att']
        reconstruction = self.decoder(mem_features)
        return {'output': reconstruction, 'att': attention}