# U-Past / modules /bottleneck.py
# lycaoduong's picture
# Initial space
# e8160b2 verified
import torch
import torch.nn as nn
from .blocks.complexblock import ComplexLinearLayer, ComplexConv2d
from .cvvae import CVDiagonalGaussianDistribution
class CVBottleNeck(nn.Module):
    """Variational bottleneck with separate linear heads for two input streams.

    The ``real`` and ``imag`` inputs are each projected twice: once to a latent
    mean and once to a value that ``reparameterize`` interprets as a
    log-variance. Both streams are sampled independently.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        # Heads for the real stream.
        self.real_mu = nn.Linear(input_dim, latent_dim)
        self.real_var = nn.Linear(input_dim, latent_dim)
        # Heads for the imaginary stream.
        self.imag_mu = nn.Linear(input_dim, latent_dim)
        self.imag_var = nn.Linear(input_dim, latent_dim)

    def reparameterize(self, mu, log_var):
        """
        Sample from N(mu, exp(log_var)) using noise drawn from N(0, 1).

        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param log_var: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        # exp(log_var / 2) turns the log-variance into a standard deviation.
        sigma = (log_var * 0.5).exp()
        noise = torch.randn_like(sigma)
        return mu + sigma * noise

    def forward(self, real, imag):
        """
        Encode both streams and draw one latent sample from each.

        :param real: (Tensor) input fed to the real heads
        :param imag: (Tensor) input fed to the imaginary heads
        :return: tuple ``(z_real, z_imag, mu, var)`` where ``mu`` and ``var``
            are the per-stream head outputs concatenated along dim 1 (real
            first). ``var`` holds the raw ``*_var`` head outputs, i.e. the
            values used as log-variance when sampling.
        """
        mean_re, logvar_re = self.real_mu(real), self.real_var(real)
        mean_im, logvar_im = self.imag_mu(imag), self.imag_var(imag)

        # Sampling through the reparameterization keeps the draw differentiable
        # with respect to the head outputs. Real is sampled before imaginary.
        z_real = self.reparameterize(mean_re, logvar_re)
        z_imag = self.reparameterize(mean_im, logvar_im)

        # Joint statistics over both streams, e.g. for a KL term.
        mu = torch.cat((mean_re, mean_im), dim=1)
        var = torch.cat((logvar_re, logvar_im), dim=1)
        return z_real, z_imag, mu, var
class ComplexLatentSample(nn.Module):
    """Expand a pair of latent vectors into a 4-D feature map.

    Each latent vector goes through its own linear layer and is reshaped to
    ``(B, C_half, H, H)``. The two maps are concatenated along
    ``complex_axis`` and passed through a ``ComplexLinearLayer``.

    :param latent_dim: size of each incoming latent vector
    :param output_dim: channel budget; ``output_dim // 4`` channels are given
        to the final ``ComplexLinearLayer``
    :param last_encoder_dim: spatial side length ``H`` of the produced map
    :param complex_axis: axis along which the two maps are concatenated
    """

    def __init__(self, latent_dim, output_dim, last_encoder_dim, complex_axis=1):
        super().__init__()
        self.complex_axis = complex_axis
        self.last_encoder_dim = last_encoder_dim
        self.H = last_encoder_dim
        self.output_chan_axis = output_dim // 4
        # Each of the two maps carries half of the channels.
        self.C_half = self.output_chan_axis // 2

        # Number of features needed to fill one (C_half, H, H) map.
        flat_features = self.C_half * self.H * self.H
        self.sample_real = nn.Linear(latent_dim, flat_features)
        self.sample_imag = nn.Linear(latent_dim, flat_features)
        self.linear_conv = ComplexLinearLayer(self.output_chan_axis, self.output_chan_axis)

    def forward(self, z_real, z_imag):
        """
        Map the latent pair to a feature map.

        :param z_real: (Tensor) latent for the real branch, batch on dim 0
        :param z_imag: (Tensor) latent for the imaginary branch
        :return: output of ``linear_conv`` applied to the concatenated maps
        """
        batch = z_real.size(0)
        map_shape = (batch, self.C_half, self.H, self.H)

        # Project each latent and fold the flat features into a square map.
        real = self.sample_real(z_real).view(map_shape)
        imag = self.sample_imag(z_imag).view(map_shape)

        # Join both parts before the final complex layer.
        joined = torch.cat((real, imag), dim=self.complex_axis)
        return self.linear_conv(joined)
class CVVarBottleNeck(nn.Module):
    """Variational bottleneck placed after the last encoder stage.

    Pipeline: ``ComplexLinearLayer`` -> flatten -> split in two halves ->
    ``CVBottleNeck`` (sampling) -> ``ComplexLatentSample`` (back to a map).

    :param encoder_hidden_dims: channel widths of the encoder stages; the last
        entry sizes the bottleneck. Defaults to
        ``[64, 128, 256, 512, 512, 512, 512]``.
    :param feature_size: ``(H, W)`` of the encoder input; must be square
    :param latent_dim: size of each latent vector
    :param kwargs: accepted and ignored
    """

    def __init__(self, encoder_hidden_dims: list = None, feature_size: tuple = (256, 256), latent_dim: int = 512, **kwargs):
        super().__init__()
        if encoder_hidden_dims is None:
            encoder_hidden_dims = [64, 128, 256, 512, 512, 512, 512]

        height, width = feature_size
        assert height == width, "Currently only square feature maps are supported"

        top_channels = encoder_hidden_dims[-1]
        # The side length is halved once per entry in encoder_hidden_dims.
        last_enc_dim = height // (2 ** len(encoder_hidden_dims))

        self.complex_linear_map = ComplexLinearLayer(top_channels, top_channels)
        # forward() splits the flattened features in two, so each half of the
        # bottleneck sees half of them.
        self.complex_bottleneck = CVBottleNeck(top_channels * last_enc_dim * last_enc_dim // 2, latent_dim)
        self.complex_sampling = ComplexLatentSample(latent_dim, top_channels * 4, last_enc_dim)

    def forward(self, x):
        """
        Run the bottleneck.

        :param x: (Tensor) encoder output
        :return: tuple ``(s, mu, log_var)``: the resampled feature map and the
            two statistics returned by the inner bottleneck
        """
        flat = torch.flatten(self.complex_linear_map(x), start_dim=1)

        # Two equal halves along the feature axis; the first goes to the
        # bottleneck's ``real`` argument and the second to ``imag``.
        first_half, second_half = torch.chunk(flat, 2, dim=1)
        z_first, z_second, mu, log_var = self.complex_bottleneck(first_half, second_half)

        # Expand the sampled latents back into a feature map.
        return self.complex_sampling(z_first, z_second), mu, log_var
class CVBottleNeckKL(nn.Module):
    """Convolutional KL bottleneck built from two ``ComplexConv2d`` layers.

    ``quant_conv`` maps encoder features to distribution moments, a
    ``CVDiagonalGaussianDistribution`` is sampled, and ``post_quant_conv``
    maps the sample back to ``latest_hidden_dims`` channels.

    :param latent_channels: base number of latent channels
    :param double_z: when true, ``quant_conv`` emits ``2 * latent_channels``
        channels and the sample is expected to have ``latent_channels``;
        otherwise the counts are ``latent_channels`` and
        ``latent_channels // 2``
    :param latest_hidden_dims: channel count of the encoder output / decoder input
    :param kwargs: accepted and ignored
    """

    def __init__(self, latent_channels: int = 8, double_z: bool = True, latest_hidden_dims: int = 512, **kwargs):
        super().__init__()
        if double_z:
            enc_out_channels = 2 * latent_channels
            sample_out_channels = latent_channels
        else:
            enc_out_channels = latent_channels
            sample_out_channels = latent_channels // 2

        # Both convolutions share the same 3x3, stride-1, padding-1 geometry.
        conv_geometry = dict(kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

        self.quant_conv = ComplexConv2d(
            latest_hidden_dims,
            enc_out_channels,
            **conv_geometry
        )
        self.post_quant_conv = ComplexConv2d(
            in_channels=sample_out_channels,
            out_channels=latest_hidden_dims,
            **conv_geometry
        )

    def forward(self, x):
        """
        Encode ``x`` to a posterior, sample it, and project the sample back.

        :param x: (Tensor) encoder features
        :return: tuple ``(o, posterior)``: decoder input and the
            ``CVDiagonalGaussianDistribution`` it was sampled from
        """
        # Moments parameterising the posterior.
        posterior = CVDiagonalGaussianDistribution(self.quant_conv(x))
        latent = posterior.sample()
        # Lift the sample back to the decoder's channel width.
        return self.post_quant_conv(latent), posterior