import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class CVAE(nn.Module):
    """Conditional VAE over mel spectrograms, conditioned on a genre vector."""

    def __init__(self, d_model, latent_dim, n_frames, n_mels, n_genres):
        super().__init__()
        self.d_model = d_model
        self.latent_dim = latent_dim
        # Spatial size of the bottleneck feature map after three stride-2 convs.
        self.n_frames = int(np.ceil(n_frames / 2**3))
        self.n_mels = int(np.ceil(n_mels / 2**3))
        self.n_genres = n_genres

        # Encoder: the genre vector is broadcast over the spectrogram's spatial
        # dims and concatenated as extra input channels.
        self.encoder = nn.Sequential(
            nn.Conv2d(1 + self.n_genres, d_model, kernel_size=3, stride=2, padding=1),  # [B, d, ceil(n_mels/2), ceil(n_frames/2)]
            nn.BatchNorm2d(d_model),
            nn.SiLU(),
            nn.Dropout2d(0.05),
            nn.Conv2d(d_model, d_model * 2, kernel_size=3, stride=2, padding=1),  # [B, 2*d, ceil(n_mels/2**2), ceil(n_frames/2**2)]
            nn.BatchNorm2d(d_model * 2),
            nn.SiLU(),
            nn.Dropout2d(0.1),
            nn.Conv2d(d_model * 2, d_model * 4, kernel_size=3, stride=2, padding=1),  # [B, 4*d, ceil(n_mels/2**3), ceil(n_frames/2**3)]
            nn.BatchNorm2d(d_model * 4),
            nn.SiLU(),
            nn.Dropout2d(0.15),
            nn.AdaptiveAvgPool2d((1, 1)),  # [B, 4*d, 1, 1]
            nn.Flatten(),  # [B, 4*d]
        )

        # Latent space
        self.fc_mu = nn.Linear(d_model * 4, latent_dim)
        self.fc_logvar = nn.Linear(d_model * 4, latent_dim)

        # Decoder: z is concatenated with the genre vector, projected back to the
        # bottleneck feature-map shape, then upsampled by transposed convolutions.
        self.decoder_input = nn.Linear(
            latent_dim + self.n_genres,
            d_model * 4 * self.n_mels * self.n_frames,
        )  # reshaped to [B, 4*d, ceil(n_mels/2**3), ceil(n_frames/2**3)]
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(d_model * 4, d_model * 2, kernel_size=3, stride=2, padding=1, output_padding=(1, 0)),  # [B, 2*d, ~ceil(n_mels/2**2), ~ceil(n_frames/2**2)]
            nn.BatchNorm2d(d_model * 2),
            nn.SiLU(),
            nn.Dropout2d(0.1),
            nn.ConvTranspose2d(d_model * 2, d_model, kernel_size=3, stride=2, padding=1, output_padding=(1, 0)),  # [B, d, ~ceil(n_mels/2), ~ceil(n_frames/2)]
            nn.BatchNorm2d(d_model),
            nn.SiLU(),
            nn.Dropout2d(0.05),
            nn.ConvTranspose2d(d_model, 1, kernel_size=3, stride=2, padding=1, output_padding=1),  # [B, 1, >=n_mels, >=n_frames]; cropped in forward()
            nn.Sigmoid(),
        )

    @staticmethod
    def _match_size(t, h, w):
        # Crop and/or zero-pad t's spatial dims to exactly (h, w) so the
        # skip-connection add cannot fail on off-by-one shape mismatches
        # (a no-op when the shapes already agree).
        t = t[:, :, :h, :w]
        return F.pad(t, (0, w - t.size(3), 0, h - t.size(2)))

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps with eps ~ N(0, I) keeps sampling differentiable.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, genres_input):
        # Broadcast the genre vector [B, n_genres] over the spectrogram's
        # spatial dims and concatenate it as extra channels.
        ori_genres_embed = genres_input.view(genres_input.size(0), -1)
        genres_embed = ori_genres_embed.unsqueeze(-1).unsqueeze(-1)
        genres_embed = genres_embed.expand(-1, -1, x.size(2), x.size(3))
        x_genres = torch.cat((x, genres_embed), dim=1)

        # Encode, saving each post-activation feature map as a skip connection.
        h = x_genres
        shortcuts = []
        for block in self.encoder:
            h = block(h)
            if isinstance(block, nn.SiLU):
                shortcuts.append(h)

        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        z = self.reparameterize(mu, logvar)

        # Decode, conditioning on the genre vector again. Note the
        # (n_mels, n_frames) order, matching the encoder's [B, C, H, W] layout.
        z_genres = torch.cat((z, ori_genres_embed), dim=1)
        h_dec = self.decoder_input(z_genres)
        h_dec = h_dec.view(-1, self.d_model * 4, self.n_mels, self.n_frames)

        for block in self.decoder:
            if isinstance(block, nn.ConvTranspose2d) and shortcuts:
                # Skip connection: align h_dec to the encoder feature map's
                # spatial size before the residual add.
                shortcut = shortcuts.pop()
                h_dec = self._match_size(h_dec, shortcut.size(2), shortcut.size(3)) + shortcut
            h_dec = block(h_dec)

        # Crop any transposed-conv overshoot back to the input's size.
        recon = h_dec[:, :, :x.size(2), :x.size(3)]
        return recon, mu, logvar
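

# --- Usage sketch (illustrative, not part of the model) ----------------------
# A minimal smoke test plus the standard CVAE training objective (reconstruction
# loss + KL divergence). The hyperparameters below (d_model=32, latent_dim=128,
# n_mels=128, n_frames=431, n_genres=10) are assumptions for illustration, not
# values prescribed by this model.
if __name__ == "__main__":
    model = CVAE(d_model=32, latent_dim=128, n_frames=431, n_mels=128, n_genres=10)

    x = torch.rand(4, 1, 128, 431)  # batch of mel spectrograms scaled to [0, 1]
    genres = torch.zeros(4, 10)
    genres[torch.arange(4), torch.randint(0, 10, (4,))] = 1.0  # one-hot genre labels

    recon, mu, logvar = model(x, genres)
    assert recon.shape == x.shape

    # Standard VAE objective: BCE reconstruction (matching the Sigmoid output)
    # plus the closed-form KL term, both averaged over the batch. A beta weight
    # on the KL term is a common tuning knob.
    beta = 1.0
    bce = nn.functional.binary_cross_entropy(recon, x, reduction="sum") / x.size(0)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
    loss = bce + beta * kld
    loss.backward()
    print(f"loss: {loss.item():.4f}")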