import torch
import torch.nn as nn
from .blocks.complexblock import ComplexLinearLayer, ComplexConv2d
from .cvvae import CVDiagonalGaussianDistribution


class CVBottleNeck(nn.Module):
    """Variational bottleneck with separate heads for the real and imaginary parts.

    Each part is passed through its own pair of ``nn.Linear`` heads that predict
    a mean and a log-variance; one latent sample per part is then drawn with the
    reparameterization trick.

    Args:
        input_dim: Feature size (last dimension) of each real / imaginary input.
        latent_dim: Size of each latent sample (``z_real`` and ``z_imag``).
    """

    def __init__(self, input_dim, latent_dim):
        super(CVBottleNeck, self).__init__()
        # NOTE: the ``*_var`` heads predict the *log*-variance (it is fed to
        # ``reparameterize`` as ``log_var``). The attribute names are kept
        # unchanged so existing checkpoints / state_dict keys still load.
        self.real_mu = nn.Linear(input_dim, latent_dim)
        self.real_var = nn.Linear(input_dim, latent_dim)
        self.imag_mu = nn.Linear(input_dim, latent_dim)
        self.imag_var = nn.Linear(input_dim, latent_dim)

    def reparameterize(self, mu, log_var):
        """
        Reparameterization trick: sample from N(mu, var) using noise from N(0, 1).

        A fresh sample is drawn on every call (regardless of train/eval mode).

        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param log_var: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) mu + eps * exp(0.5 * log_var), with eps ~ N(0, 1) [B x D]
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, real, imag):
        """Encode real/imaginary features and sample a latent for each part.

        :param real: (Tensor) Real-part features, last dimension ``input_dim``
        :param imag: (Tensor) Imaginary-part features, last dimension ``input_dim``
        :return: tuple ``(z_real, z_imag, mu, log_var)`` where ``mu`` and
            ``log_var`` are the real and imaginary statistics concatenated along
            dim 1 (used for the KL divergence term).
        """
        # Real part
        real_mu = self.real_mu(real)
        real_log_var = self.real_var(real)
        # Imaginary part
        imag_mu = self.imag_mu(imag)
        imag_log_var = self.imag_var(imag)

        # Apply the reparameterization trick so gradients flow through sampling.
        z_real = self.reparameterize(real_mu, real_log_var)
        z_imag = self.reparameterize(imag_mu, imag_log_var)

        # Combine real and imaginary statistics for the KL divergence.
        mu = torch.cat([real_mu, imag_mu], 1)
        log_var = torch.cat([real_log_var, imag_log_var], 1)
        return z_real, z_imag, mu, log_var


class ComplexLatentSample(nn.Module):
    """Project real/imaginary latent samples back to a square complex feature map.

    Each latent part is mapped by its own ``nn.Linear`` to ``C_half * H * H``
    values, reshaped to ``(B, C_half, H, H)``, and the two parts are
    concatenated along ``complex_axis`` before a ``ComplexLinearLayer``.

    Args:
        latent_dim: Size of the incoming latent samples.
        output_dim: ``output_dim // 4`` is the channel count given to the final
            ``ComplexLinearLayer``; each part contributes half of those channels.
        last_encoder_dim: Spatial side length ``H`` of the square output map.
        complex_axis: Axis along which the real and imaginary maps are
            concatenated (default 1).
    """

    def __init__(self, latent_dim, output_dim, last_encoder_dim, complex_axis=1):
        super(ComplexLatentSample, self).__init__()
        self.complex_axis = complex_axis
        self.output_chan_axis = output_dim // 4
        self.last_encoder_dim = last_encoder_dim

        # Channels contributed by each of the real and imaginary parts.
        # NOTE(review): if output_chan_axis is odd, 2 * C_half != output_chan_axis;
        # whether ComplexLinearLayer tolerates that is not visible here — TODO confirm.
        self.C_half = self.output_chan_axis // 2
        self.H = last_encoder_dim

        self.sample_real = nn.Linear(latent_dim, self.C_half * self.H * self.H)
        self.sample_imag = nn.Linear(latent_dim, self.C_half * self.H * self.H)
        self.linear_conv = ComplexLinearLayer(self.output_chan_axis, self.output_chan_axis)

    def forward(self, z_real, z_imag):
        """Map latent samples ``z_real`` / ``z_imag`` (batch first) to a feature map.

        :return: output of ``linear_conv`` applied to the concatenated
            ``(B, C_half, H, H)`` real and imaginary maps.
        """
        # Map each latent part to a flat C_half * H * H vector.
        samp_real = self.sample_real(z_real)
        samp_imag = self.sample_imag(z_imag)

        # Re-arrange each part to (B, C_half, H, H).
        batch = z_real.size(0)
        real = samp_real.view(batch, self.C_half, self.H, self.H)
        imag = samp_imag.view(batch, self.C_half, self.H, self.H)

        # Concatenate real and imaginary maps and mix them with the complex layer.
        sample = torch.cat([real, imag], self.complex_axis)
        return self.linear_conv(sample)


class CVVarBottleNeck(nn.Module):
    """Fully-connected complex variational bottleneck between encoder and decoder.

    Pipeline: ``ComplexLinearLayer`` -> flatten -> split into two halves ->
    ``CVBottleNeck`` (sampling) -> ``ComplexLatentSample`` (back to a feature map).

    Args:
        encoder_hidden_dims: Encoder channel sizes; its length sets how many times
            ``feature_size`` is halved and its last entry is the bottleneck width.
            Defaults to ``[64, 128, 256, 512, 512, 512, 512]``.
        feature_size: ``(H, W)`` of the encoder input; must be square.
        latent_dim: Size of each (real / imaginary) latent sample.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``feature_size`` is not square, or is too small for the
            number of encoder stages (bottleneck spatial size would be 0).
    """

    def __init__(self,
                 encoder_hidden_dims: list = None,
                 feature_size: tuple = (256, 256),
                 latent_dim: int = 512,
                 **kwargs):
        super().__init__()
        if encoder_hidden_dims is None:
            encoder_hidden_dims = [64, 128, 256, 512, 512, 512, 512]

        H, W = feature_size
        # Explicit checks instead of ``assert`` so validation survives ``python -O``.
        if H != W:
            raise ValueError("Currently only square feature maps are supported")

        # Bottleneck spatial size, assuming every encoder stage halves the map.
        last_enc_dim = H // (2 ** len(encoder_hidden_dims))
        if last_enc_dim < 1:
            raise ValueError(
                f"feature_size {tuple(feature_size)} is too small for "
                f"{len(encoder_hidden_dims)} encoder stages "
                f"(bottleneck spatial size would be {last_enc_dim})"
            )

        last_hidden = encoder_hidden_dims[-1]
        self.complex_linear_map = ComplexLinearLayer(last_hidden, last_hidden)
        # The flattened features are split into real and imaginary halves, so each
        # half has last_hidden * H' * H' // 2 features (factor 2 for complex).
        self.complex_bottleneck = CVBottleNeck(
            last_hidden * last_enc_dim * last_enc_dim // 2, latent_dim
        )
        self.complex_sampling = ComplexLatentSample(latent_dim, last_hidden * 4, last_enc_dim)

    def forward(self, x):
        """Run the bottleneck on encoder features ``x``.

        :return: tuple ``(s, mu, log_var)``: the decoded-side feature map and the
            concatenated real/imaginary latent statistics for the KL term.
        """
        x = self.complex_linear_map(x)
        x = torch.flatten(x, start_dim=1)
        # Split into real and imaginary halves; assumes the real channels come
        # first in the channel layout (matching ComplexLatentSample's concat order).
        real, imag = torch.chunk(x, 2, dim=1)
        z_real, z_imag, mu, log_var = self.complex_bottleneck(real, imag)
        # Project the latent samples back to a feature map.
        s = self.complex_sampling(z_real, z_imag)
        return s, mu, log_var


class CVBottleNeckKL(nn.Module):
    """Convolutional KL bottleneck for complex feature maps.

    ``quant_conv`` maps encoder features to distribution moments, a latent is drawn
    from ``CVDiagonalGaussianDistribution``, and ``post_quant_conv`` maps it back
    to ``latest_hidden_dims`` channels for the decoder.

    Args:
        latent_channels: Latent channel count.
        double_z: If True, ``quant_conv`` outputs ``2 * latent_channels`` channels
            (presumably mean and log-variance stacked — TODO confirm against
            ``CVDiagonalGaussianDistribution``).
        latest_hidden_dims: Channel count of the encoder output / decoder input.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 latent_channels: int = 8,
                 double_z: bool = True,
                 latest_hidden_dims: int = 512,
                 **kwargs):
        super().__init__()
        enc_out_channels = 2 * latent_channels if double_z else latent_channels
        self.quant_conv = ComplexConv2d(
            latest_hidden_dims,
            enc_out_channels,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
        )

        # Channel count of the sampled latent fed to post_quant_conv; presumably
        # half of the moment channels — depends on how CVDiagonalGaussianDistribution
        # splits ``moments`` (TODO confirm).
        sample_out_channels = latent_channels if double_z else latent_channels // 2
        self.post_quant_conv = ComplexConv2d(
            in_channels=sample_out_channels,
            out_channels=latest_hidden_dims,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
        )

    def forward(self, x):
        """Encode ``x`` to a posterior, sample it and project back for the decoder.

        :return: tuple ``(o, posterior)``: decoder input features and the
            ``CVDiagonalGaussianDistribution`` (e.g. for the KL loss).
        """
        # Encode to get moments
        moments = self.quant_conv(x)
        # Get posterior distribution and draw a latent sample
        posterior = CVDiagonalGaussianDistribution(moments)
        z = posterior.sample()
        # Get output for the decoder
        o = self.post_quant_conv(z)
        return o, posterior