# 3D convolutional auto-encoder with a memory module between encoder and decoder.
from __future__ import absolute_import, print_function

import torch
from torch import nn

from .memory_module import MemModule

# Kernel size, padding and activation slope shared by every stage of the model.
_KERNEL = (3, 3, 3)
_PAD = (1, 1, 1)
_LEAKY_SLOPE = 0.2


def _down_stage(in_ch, out_ch, stride):
    """Return the [Conv3d, BatchNorm3d, LeakyReLU] layers of one encoder stage."""
    return [
        nn.Conv3d(in_ch, out_ch, _KERNEL, stride=stride, padding=_PAD),
        nn.BatchNorm3d(out_ch),
        nn.LeakyReLU(_LEAKY_SLOPE, inplace=True),
    ]


def _up_stage(in_ch, out_ch, stride, output_padding):
    """Return the [ConvTranspose3d, BatchNorm3d, LeakyReLU] layers of one decoder stage."""
    return [
        nn.ConvTranspose3d(in_ch, out_ch, _KERNEL, stride=stride,
                           padding=_PAD, output_padding=output_padding),
        nn.BatchNorm3d(out_ch),
        nn.LeakyReLU(_LEAKY_SLOPE, inplace=True),
    ]


class AutoEncoderCov3DMem(nn.Module):
    """3D convolutional auto-encoder whose bottleneck passes through a MemModule.

    The encoder stacks four Conv3d/BatchNorm3d/LeakyReLU stages
    (chnum_in -> 96 -> 128 -> 256 -> 256 channels). The decoder mirrors it
    with transposed convolutions; its last layer is a bare ConvTranspose3d
    that maps back to ``chnum_in`` channels (no normalisation or activation).

    Args:
        chnum_in: number of channels of the input (and of the reconstruction).
        mem_dim: forwarded to ``MemModule`` as ``mem_dim``.
        shrink_thres: forwarded to ``MemModule`` as ``shrink_thres``.
    """

    def __init__(self, chnum_in, mem_dim, shrink_thres=0.0025):
        super(AutoEncoderCov3DMem, self).__init__()
        print('AutoEncoderCov3DMem')
        self.chnum_in = chnum_in

        # Channel widths used along the network.
        narrow, mid, wide = 96, 128, 256

        # The first stage keeps the first convolved axis at stride 1 and halves
        # the other two; the remaining stages use stride 2 on all three axes.
        encoder_specs = (
            (self.chnum_in, narrow, (1, 2, 2)),
            (narrow, mid, (2, 2, 2)),
            (mid, wide, (2, 2, 2)),
            (wide, wide, (2, 2, 2)),
        )
        encoder_layers = []
        for in_ch, out_ch, stride in encoder_specs:
            encoder_layers.extend(_down_stage(in_ch, out_ch, stride))
        self.encoder = nn.Sequential(*encoder_layers)

        self.mem_rep = MemModule(mem_dim=mem_dim, fea_dim=wide,
                                 shrink_thres=shrink_thres)

        # Decoder stages mirror the three stride-2 encoder stages in reverse.
        decoder_specs = (
            (wide, wide),
            (wide, mid),
            (mid, narrow),
        )
        decoder_layers = []
        for in_ch, out_ch in decoder_specs:
            decoder_layers.extend(
                _up_stage(in_ch, out_ch, (2, 2, 2), (1, 1, 1)))
        # Final projection mirrors the first encoder convolution.
        decoder_layers.append(
            nn.ConvTranspose3d(narrow, self.chnum_in, _KERNEL,
                               stride=(1, 2, 2), padding=_PAD,
                               output_padding=(0, 1, 1)))
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x``, pass the features through the memory module, decode.

        Args:
            x: tensor fed to the encoder; its channel dimension must match
                ``chnum_in`` for the first Conv3d.

        Returns:
            dict with two entries:
                'output': the decoder result computed from the memory
                    module's 'output' entry.
                'att': the memory module's 'att' entry, passed through
                    unchanged.
        """
        encoded = self.encoder(x)
        mem_result = self.mem_rep(encoded)
        recalled, att = mem_result['output'], mem_result['att']
        reconstruction = self.decoder(recalled)
        return {'output': reconstruction, 'att': att}