# Ripefog's picture
# Upload 11 files
# 4195b51 verified
import torch
import numpy as np
import torch.nn as nn
class CVAE(nn.Module):
    """Convolutional conditional VAE with encoder-to-decoder skip connections.

    The encoder downsamples the conditioned input three times with stride-2
    convolutions and pools the result into a vector that parameterises a
    diagonal Gaussian (``mu``, ``logvar``).  The decoder rebuilds the input
    from a latent sample concatenated with the genre vector, adding the
    encoder activations back in before every transposed convolution.

    Layer order and attribute names are part of the checkpoint format
    (``state_dict`` keys), so they must not be reordered.

    Args:
        d_model: Channel width of the first encoder stage; later stages use
            ``2 * d_model`` and ``4 * d_model``.
        latent_dim: Size of the latent vector ``z``.
        n_frames: Number of time frames of the input (before downsampling).
        n_mels: Number of mel bins of the input (before downsampling).
        n_genres: Width of the flattened genre conditioning vector.

    Note:
        After construction ``self.n_frames`` and ``self.n_mels`` hold the
        *bottleneck* sizes ``ceil(n / 8)``, not the values passed in.
    """

    def __init__(self, d_model: int, latent_dim: int, n_frames: int, n_mels: int, n_genres: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.latent_dim = latent_dim
        # Three stride-2 convolutions => spatial sizes shrink by 2**3 (rounded up).
        self.n_frames = int(np.ceil(n_frames / 2**3))
        self.n_mels = int(np.ceil(n_mels / 2**3))
        self.n_genres = n_genres

        # Encoder: the genre vector is broadcast over the spatial grid and
        # stacked onto the single input channel, hence 1 + n_genres channels.
        self.encoder = nn.Sequential(
            nn.Conv2d(1 + self.n_genres, d_model, kernel_size=3, stride=2, padding=1),  # [B, d, ceil(n_mels/2), ceil(n_frames/2)]
            nn.BatchNorm2d(d_model),
            nn.SiLU(),
            nn.Dropout2d(0.05),
            nn.Conv2d(d_model, d_model * 2, kernel_size=3, stride=2, padding=1),  # [B, 2*d, ceil(n_mels/2**2), ceil(n_frames/2**2)]
            nn.BatchNorm2d(d_model * 2),
            nn.SiLU(),
            nn.Dropout2d(0.1),
            nn.Conv2d(d_model * 2, d_model * 4, kernel_size=3, stride=2, padding=1),  # [B, 4*d, ceil(n_mels/2**3), ceil(n_frames/2**3)]
            nn.BatchNorm2d(d_model * 4),
            nn.SiLU(),
            nn.Dropout2d(0.15),
            nn.AdaptiveAvgPool2d((1, 1)),  # [B, 4*d, 1, 1]
            nn.Flatten(),
        )

        # Latent space: two linear heads over the pooled encoder features.
        self.fc_mu = nn.Linear(d_model * 4, latent_dim)
        self.fc_logvar = nn.Linear(d_model * 4, latent_dim)

        # Decoder: project (z, genres) to a flat bottleneck feature map that
        # forward() reshapes to [B, 4*d, H/8, W/8].
        self.decoder_input = nn.Linear(
            latent_dim + self.n_genres,
            d_model * 4 * self.n_frames * self.n_mels,
        )
        # With kernel 3 / stride 2 / padding 1, output_padding=(1, 0) doubles
        # the height and maps a width w to 2*w - 1.  The skip additions in
        # forward() therefore only line up for inputs whose intermediate
        # encoder heights are even and widths are odd.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(d_model * 4, d_model * 2, kernel_size=3, stride=2, padding=1, output_padding=(1, 0)),  # [B, 2*d, H/4, W/4]
            nn.BatchNorm2d(d_model * 2),
            nn.SiLU(),
            nn.Dropout2d(0.1),
            nn.ConvTranspose2d(d_model * 2, d_model, kernel_size=3, stride=2, padding=1, output_padding=(1, 0)),  # [B, d, H/2, W/2]
            nn.BatchNorm2d(d_model),
            nn.SiLU(),
            nn.Dropout2d(0.05),
            nn.ConvTranspose2d(d_model, 1, kernel_size=3, stride=2, padding=1, output_padding=1),  # [B, 1, >=H, >=W]; cropped in forward()
            nn.Sigmoid(),
        )

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Args:
            mu: Mean of the approximate posterior.
            logvar: Log-variance of the approximate posterior, same shape as ``mu``.

        Returns:
            A random sample with the same shape as ``mu``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x: torch.Tensor, genres_input: torch.Tensor):
        """Encode ``x`` conditioned on genres, sample a latent, and reconstruct.

        Args:
            x: 4-D input whose channel count plus the flattened genre width
                equals ``1 + n_genres``.  The layer comments assume
                ``[B, 1, n_mels, n_frames]``.
            genres_input: Genre conditioning; flattened to ``[B, -1]`` and
                expected to have ``n_genres`` values per sample.

        Returns:
            Tuple ``(recon, mu, logvar)`` where ``recon`` is the sigmoid
            output cropped to at most the spatial size of ``x`` and
            ``mu`` / ``logvar`` are ``[B, latent_dim]``.

        Raises:
            RuntimeError: If the input's spatial size is inconsistent with
                the sizes given at construction, or a skip connection does
                not match the decoder feature map it is added to.
        """
        batch_size = x.size(0)

        # Tile the genre vector over every spatial position and append it as
        # extra input channels.
        ori_genres_embed = genres_input.view(genres_input.size(0), -1)
        genres_embed = ori_genres_embed.unsqueeze(-1).unsqueeze(-1)
        genres_embed = genres_embed.expand(-1, -1, x.size(2), x.size(3))
        h = torch.cat((x, genres_embed), dim=1)

        # Run the encoder layer by layer, keeping each post-activation
        # feature map (taken before dropout) for the decoder skip connections.
        shortcuts = []
        for block in self.encoder:
            h = block(h)
            if isinstance(block, nn.SiLU):
                shortcuts.append(h)

        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        z = self.reparameterize(mu, logvar)

        z_genres = torch.cat((z, ori_genres_embed), dim=1)
        h_dec = self.decoder_input(z_genres)

        # Reshape to the spatial layout the encoder actually produced at the
        # bottleneck.  Hard-coding (n_frames, n_mels) transposes the map
        # relative to a [B, 1, n_mels, n_frames] input and breaks the first
        # skip addition whenever the two downsampled sizes differ.  Using an
        # explicit batch size makes an element-count mismatch fail loudly
        # instead of being absorbed into the batch dimension.
        bottleneck_h, bottleneck_w = shortcuts[-1].shape[-2:]
        h_dec = h_dec.view(batch_size, self.d_model * 4, bottleneck_h, bottleneck_w)

        # Deepest encoder activation is consumed first (LIFO).
        for block in self.decoder:
            if isinstance(block, nn.ConvTranspose2d) and shortcuts:
                h_dec = h_dec + shortcuts.pop()
            h_dec = block(h_dec)

        # The last upsampling may overshoot odd input sizes by one; crop back.
        recon = h_dec[:, :, :x.size(2), :x.size(3)]
        return recon, mu, logvar