Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| from .blocks.complexblock import ComplexLinearLayer, ComplexConv2d | |
| from .cvvae import CVDiagonalGaussianDistribution | |
class CVBottleNeck(nn.Module):
    """Variational bottleneck with separate Gaussian heads for real and imaginary inputs.

    Four independent linear layers project the real and the imaginary feature
    vectors to a mean and a log-variance each. A latent sample is then drawn
    for each part with the reparameterization trick.

    :param input_dim: size of the last dimension of each input vector
    :param latent_dim: size of the latent vector produced per part
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        # The ``*_var`` heads are consumed as log-variances by
        # ``reparameterize`` (they go through ``exp(0.5 * x)``).
        # Creation order is kept so that seeded initialisation is unchanged.
        self.real_mu = nn.Linear(input_dim, latent_dim)
        self.real_var = nn.Linear(input_dim, latent_dim)
        self.imag_mu = nn.Linear(input_dim, latent_dim)
        self.imag_var = nn.Linear(input_dim, latent_dim)

    def reparameterize(self, mu, log_var):
        """
        Draw a sample from N(mu, exp(log_var)) using noise from N(0, 1).

        :param mu: (Tensor) Mean of the latent Gaussian
        :param log_var: (Tensor) Log-variance of the latent Gaussian
        :return: (Tensor) ``mu + exp(0.5 * log_var) * eps`` with ``eps ~ N(0, 1)``
        """
        sigma = (log_var * 0.5).exp()
        noise = torch.randn_like(sigma)
        return mu + sigma * noise

    def forward(self, real, imag):
        """
        Encode both parts and sample one latent vector for each.

        :param real: (Tensor) input fed to the real-part heads
        :param imag: (Tensor) input fed to the imaginary-part heads
        :return: tuple ``(z_real, z_imag, mu, log_var)`` where ``mu`` and
            ``log_var`` are the real and imaginary statistics concatenated
            along dimension 1
        """
        mean_re, logvar_re = self.real_mu(real), self.real_var(real)
        mean_im, logvar_im = self.imag_mu(imag), self.imag_var(imag)

        # Sample the real part first, then the imaginary part, so the random
        # stream is consumed in a fixed order.
        z_real = self.reparameterize(mean_re, logvar_re)
        z_imag = self.reparameterize(mean_im, logvar_im)

        # Joint statistics (the original comment says these feed the KL term).
        mu = torch.cat((mean_re, mean_im), dim=1)
        log_var = torch.cat((logvar_re, logvar_im), dim=1)
        return z_real, z_imag, mu, log_var
class ComplexLatentSample(nn.Module):
    """Expand a pair of latent vectors back into a 4-D feature map.

    Each latent vector goes through its own linear layer and is reshaped to
    ``(batch, C_half, H, H)``. The two maps are concatenated along
    ``complex_axis`` and passed through a ``ComplexLinearLayer``.

    :param latent_dim: size of each incoming latent vector
    :param output_dim: ``output_dim // 4`` is used as the channel count of the
        concatenated map
    :param last_encoder_dim: spatial side length ``H`` of the produced map
    :param complex_axis: axis along which the two parts are concatenated
    """

    def __init__(self, latent_dim, output_dim, last_encoder_dim, complex_axis=1):
        super().__init__()
        self.complex_axis = complex_axis
        self.last_encoder_dim = last_encoder_dim
        self.H = last_encoder_dim

        # Channels of the concatenated map; each part carries half of them.
        self.output_chan_axis = output_dim // 4
        self.C_half = self.output_chan_axis // 2

        # Both projections emit one flattened (C_half, H, H) volume.
        flat_features = self.C_half * self.H * self.H
        self.sample_real = nn.Linear(latent_dim, flat_features)
        self.sample_imag = nn.Linear(latent_dim, flat_features)
        self.linear_conv = ComplexLinearLayer(self.output_chan_axis, self.output_chan_axis)

    def forward(self, z_real, z_imag):
        """
        Project, reshape and merge the two latent vectors.

        :param z_real: (Tensor) latent vector for the real part, batch first
        :param z_imag: (Tensor) latent vector for the imaginary part
        :return: output of ``linear_conv`` applied to the concatenated maps
        """
        flat_real = self.sample_real(z_real)
        flat_imag = self.sample_imag(z_imag)

        # The batch size is taken from z_real and reused for both parts.
        map_shape = (z_real.size(0), self.C_half, self.H, self.H)
        real = flat_real.view(map_shape)
        imag = flat_imag.view(map_shape)

        # Join the parts along the complex axis before the final layer.
        merged = torch.cat((real, imag), dim=self.complex_axis)
        return self.linear_conv(merged)
class CVVarBottleNeck(nn.Module):
    """Variational bottleneck: linear map, Gaussian bottleneck, then re-expansion.

    The input goes through a ``ComplexLinearLayer``, is flattened and split in
    two halves, encoded and sampled by ``CVBottleNeck``, and expanded back to a
    feature map by ``ComplexLatentSample``.

    :param encoder_hidden_dims: channel sizes of the encoder stages; defaults
        to ``[64, 128, 256, 512, 512, 512, 512]`` when ``None``
    :param feature_size: ``(H, W)`` of the encoder input; must be square
    :param latent_dim: size of each latent vector
    :param kwargs: accepted and ignored
    """

    def __init__(self, encoder_hidden_dims: list = None, feature_size: tuple = (256, 256), latent_dim: int = 512, **kwargs):
        super().__init__()
        if encoder_hidden_dims is None:
            encoder_hidden_dims = [64, 128, 256, 512, 512, 512, 512]

        H, W = feature_size
        assert H == W, "Currently only square feature maps are supported"

        # The side length is divided by two once per entry in
        # encoder_hidden_dims.
        total_stride = 2 ** (len(encoder_hidden_dims))
        last_enc_dim = feature_size[0] // total_stride
        bottleneck_channels = encoder_hidden_dims[-1]

        # Each half of the flattened map feeds one branch, hence the // 2.
        per_part_features = bottleneck_channels * last_enc_dim * last_enc_dim // 2

        self.complex_linear_map = ComplexLinearLayer(bottleneck_channels, bottleneck_channels)
        self.complex_bottleneck = CVBottleNeck(per_part_features, latent_dim)
        self.complex_sampling = ComplexLatentSample(latent_dim, bottleneck_channels * 4, last_enc_dim)

    def forward(self, x):
        """
        Run the bottleneck on an encoder feature map.

        :param x: (Tensor) encoder output, batch first
        :return: tuple ``(s, mu, log_var)``: the re-expanded sample and the two
            statistics tensors returned by the bottleneck
        """
        flat = torch.flatten(self.complex_linear_map(x), start_dim=1)

        # First half goes to the real branch, second half to the imaginary one.
        real_part, imag_part = torch.chunk(flat, 2, dim=1)
        z_real, z_imag, mu, log_var = self.complex_bottleneck(real_part, imag_part)

        return self.complex_sampling(z_real, z_imag), mu, log_var
class CVBottleNeckKL(nn.Module):
    """Convolutional KL bottleneck built from two ``ComplexConv2d`` layers.

    ``quant_conv`` turns encoder features into distribution moments, a sample
    is drawn from ``CVDiagonalGaussianDistribution``, and ``post_quant_conv``
    maps the sample back to ``latest_hidden_dims`` channels.

    :param latent_channels: base number of latent channels
    :param double_z: when true, ``quant_conv`` emits ``2 * latent_channels``
        channels and the sample has ``latent_channels``; otherwise it emits
        ``latent_channels`` and the sample has ``latent_channels // 2``
    :param latest_hidden_dims: channel count of the incoming and outgoing maps
    :param kwargs: accepted and ignored
    """

    def __init__(self, latent_channels: int = 8, double_z: bool = True, latest_hidden_dims: int = 512, **kwargs):
        super().__init__()

        if double_z:
            enc_out_channels = 2 * latent_channels
            sample_out_channels = latent_channels
        else:
            enc_out_channels = latent_channels
            sample_out_channels = latent_channels // 2

        # Both convolutions share the same 3x3, stride-1, padding-1 geometry.
        conv_geometry = dict(kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

        self.quant_conv = ComplexConv2d(
            latest_hidden_dims,
            enc_out_channels,
            **conv_geometry
        )
        self.post_quant_conv = ComplexConv2d(
            in_channels=sample_out_channels,
            out_channels=latest_hidden_dims,
            **conv_geometry
        )

    def forward(self, x):
        """
        Encode to moments, sample the posterior and project back.

        :param x: encoder feature map accepted by ``quant_conv``
        :return: tuple ``(o, posterior)``: the decoder input and the
            ``CVDiagonalGaussianDistribution`` built from the moments
        """
        # NOTE(review): the moments layout is defined in .cvvae; presumably
        # mean and log-variance stacked along channels. Confirm there.
        posterior = CVDiagonalGaussianDistribution(self.quant_conv(x))
        latent = posterior.sample()
        return self.post_quant_conv(latent), posterior