|
|
import torch |
|
|
import numpy as np |
|
|
import torch.nn as nn |
|
|
|
|
|
class CVAE(nn.Module):
    """Conditional VAE over 2-D spectrogram-like inputs, conditioned on genre.

    The encoder is three stride-2 conv blocks (each block halves both spatial
    dimensions, hence the ``/ 2**3`` bookkeeping below) followed by global
    average pooling.  The decoder mirrors it with three transposed convs.
    The post-activation feature map of each encoder block is kept and added
    back (U-Net-style additive skip connection) before the matching decoder
    layer.

    Args:
        d_model: Base channel width of the conv stacks (doubled per block).
        latent_dim: Dimensionality of the latent vector ``z``.
        n_frames: Time dimension (height) of the input feature map.
        n_mels: Mel dimension (width) of the input feature map.
        n_genres: Length of the flattened genre conditioning vector.

    NOTE(review): the ``output_padding=(1, 0)`` choices in the decoder tie
    the skip-connection shapes to the parity of the downsampled sizes
    (even heights / odd widths after each encoder block) — verify they
    still line up before using new ``n_frames`` / ``n_mels`` values.
    """

    def __init__(self, d_model, latent_dim, n_frames, n_mels, n_genres):
        super().__init__()
        self.d_model = d_model
        self.latent_dim = latent_dim
        # Spatial size after the three stride-2 convs (ceil-halved 3 times);
        # used to reshape the decoder's linear output back to a feature map.
        self.n_frames = int(np.ceil(n_frames / 2**3))
        self.n_mels = int(np.ceil(n_mels / 2**3))
        self.n_genres = n_genres

        # Genre conditioning is concatenated channel-wise onto the input,
        # hence 1 + n_genres input channels.  Dropout rates ramp up with
        # depth (0.05 -> 0.10 -> 0.15).
        self.encoder = nn.Sequential(
            nn.Conv2d(1 + self.n_genres, d_model, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(d_model),
            nn.SiLU(),
            nn.Dropout2d(0.05),

            nn.Conv2d(d_model, d_model * 2, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(d_model * 2),
            nn.SiLU(),
            nn.Dropout2d(0.1),

            nn.Conv2d(d_model * 2, d_model * 4, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(d_model * 4),
            nn.SiLU(),
            nn.Dropout2d(0.15),

            # Collapse to a (B, d_model * 4) vector for the latent heads.
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten()
        )

        # Latent heads: mean and log-variance of q(z | x, genre).
        self.fc_mu = nn.Linear(d_model * 4, latent_dim)
        self.fc_logvar = nn.Linear(d_model * 4, latent_dim)

        # Decoder re-conditions on genre by concatenating it onto z, then
        # projects up to a (d_model * 4, n_frames, n_mels) feature map.
        self.decoder_input = nn.Linear(latent_dim + self.n_genres, d_model * 4 * self.n_frames * self.n_mels)
        self.decoder = nn.Sequential(
            # output_padding=(1, 0): asymmetric height/width upsampling —
            # see the class-level NOTE about shape/parity assumptions.
            nn.ConvTranspose2d(d_model * 4, d_model * 2, kernel_size=3, stride=2, padding=1, output_padding=(1, 0)),
            nn.BatchNorm2d(d_model * 2),
            nn.SiLU(),
            nn.Dropout2d(0.1),

            nn.ConvTranspose2d(d_model * 2, d_model, kernel_size=3, stride=2, padding=1, output_padding=(1, 0)),
            nn.BatchNorm2d(d_model),
            nn.SiLU(),
            nn.Dropout2d(0.05),

            nn.ConvTranspose2d(d_model, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            # Outputs in [0, 1]; pair with a BCE-style reconstruction loss.
            nn.Sigmoid()
        )

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, exp(logvar)) via the reparameterization trick.

        Keeps the sampling differentiable w.r.t. ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, genres_input):
        """Encode ``x`` (conditioned on genre), sample z, and decode.

        Args:
            x: Input batch of shape (B, 1, n_frames, n_mels).
            genres_input: Genre conditioning; flattened to (B, n_genres).

        Returns:
            Tuple ``(recon, mu, logvar)`` where ``recon`` is cropped back
            to the spatial size of ``x``.
        """
        # Broadcast the genre vector over the spatial grid and stack it
        # onto the input as extra channels.
        ori_genres_embed = genres_input.view(genres_input.size(0), -1)
        genres_embed = ori_genres_embed.unsqueeze(-1).unsqueeze(-1)
        genres_embed = genres_embed.expand(-1, -1, x.size(2), x.size(3))
        x_genres = torch.cat((x, genres_embed), dim=1)

        # Run the encoder layer by layer so each block's post-SiLU
        # activation can be stashed as a skip connection.
        h = x_genres
        shortcuts = []
        for block in self.encoder:
            h = block(h)
            if isinstance(block, nn.SiLU):
                shortcuts.append(h)

        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)

        z = self.reparameterize(mu, logvar)
        # Re-attach the genre conditioning to the sampled latent.
        z_genres = torch.cat((z, ori_genres_embed), dim=1)

        h_dec = self.decoder_input(z_genres)
        h_dec = h_dec.view(-1, self.d_model * 4, self.n_frames, self.n_mels)

        # Before each transposed conv, add the deepest unused encoder
        # activation (popped in reverse order, so shapes pair up).
        for block in self.decoder:
            if isinstance(block, nn.ConvTranspose2d) and shortcuts:
                shortcut = shortcuts.pop()
                h_dec = h_dec + shortcut
            h_dec = block(h_dec)

        # The transposed convs can overshoot the original size; crop back.
        recon = h_dec[:, :, :x.size(2), :x.size(3)]
        return recon, mu, logvar