from .utils import init_weights,init_weights_orthogonal_normal, l2_regularisation import torch.nn.functional as F from torch.distributions import Normal, Independent, kl import torch import torch.nn as nn device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') class Encoder(nn.Module): """ A convolutional neural network, consisting of len(num_filters) times a block of no_convs_per_block convolutional layers, after each block a pooling operation is performed. And after each convolutional layer a non-linear (ReLU) activation function is applied. """ def __init__(self, input_channels, num_filters, no_convs_per_block, initializers, padding=True, posterior=False): super(Encoder, self).__init__() self.contracting_path = nn.ModuleList() self.input_channels = input_channels self.num_filters = num_filters if posterior: #To accomodate for the mask that is concatenated at the channel axis, we increase the input_channels. self.input_channels += 1 layers = [] for i in range(len(self.num_filters)): """ Determine input_dim and output_dim of conv layers in this block. The first layer is input x output, All the subsequent layers are output x output. """ input_dim = self.input_channels if i == 0 else output_dim output_dim = num_filters[i] if i != 0: layers.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)) layers.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=int(padding))) layers.append(nn.ReLU(inplace=True)) for _ in range(no_convs_per_block-1): layers.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=int(padding))) layers.append(nn.ReLU(inplace=True)) self.layers = nn.Sequential(*layers) self.layers.apply(init_weights) def forward(self, input): output = self.layers(input) return output class AxisAlignedConvGaussian(nn.Module): """ A convolutional net that parametrizes a Gaussian distribution with axis aligned covariance matrix. """ def __init__(self, input_channels, num_filters, no_convs_per_block, latent_dim, initializers, posterior=False): super(AxisAlignedConvGaussian, self).__init__() self.input_channels = input_channels self.channel_axis = 1 self.num_filters = num_filters self.no_convs_per_block = no_convs_per_block self.latent_dim = latent_dim self.posterior = posterior if self.posterior: self.name = 'Posterior' else: self.name = 'Prior' self.encoder = Encoder(self.input_channels, self.num_filters, self.no_convs_per_block, initializers, posterior=self.posterior) self.conv_layer = nn.Conv2d(num_filters[-1], 2 * self.latent_dim, (1,1), stride=1) self.show_img = 0 self.show_seg = 0 self.show_concat = 0 self.show_enc = 0 self.sum_input = 0 nn.init.kaiming_normal_(self.conv_layer.weight, mode='fan_in', nonlinearity='relu') nn.init.normal_(self.conv_layer.bias) def forward(self, input, segm=None): #If segmentation is not none, concatenate the mask to the channel axis of the input if segm is not None: self.show_img = input self.show_seg = segm input = torch.cat((input, segm), dim=1) self.show_concat = input self.sum_input = torch.sum(input) encoding = self.encoder(input) self.show_enc = encoding #We only want the mean of the resulting hxw image encoding = torch.mean(encoding, dim=2, keepdim=True) encoding = torch.mean(encoding, dim=3, keepdim=True) #Convert encoding to 2 x latent dim and split up for mu and log_sigma mu_log_sigma = self.conv_layer(encoding) #We squeeze the second dimension twice, since otherwise it won't work when batch size is equal to 1 mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2) mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2) mu = mu_log_sigma[:,:self.latent_dim] log_sigma = mu_log_sigma[:,self.latent_dim:] #This is a multivariate normal with diagonal covariance matrix sigma #https://github.com/pytorch/pytorch/pull/11178 dist = Independent(Normal(loc=mu, scale=torch.exp(log_sigma)),1) return dist