# JukeBox / VAE.py
# hjimjim
# model upload: reconstruct
# 081442b
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class VAE(nn.Module):
    """Variational autoencoder with an auxiliary style classifier on z.

    Args:
        input_dim: feature size of each input time step; also used as the
            decoder's output feature size.
        hidden_dim: hidden size shared by the encoder and decoder.
        latent_dim: size of the latent code z.
        num_styles: number of classes the style classifier predicts.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, num_styles=2):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        # Sub-module construction order (encoder, decoder, classifier) fixes
        # parameter registration and initialisation order; keep it stable.
        self.encode = Encoder(input_dim, hidden_dim, latent_dim)
        self.decode = Decoder(latent_dim, hidden_dim, input_dim)
        self.style_classifier = StyleClassifier(latent_dim, num_styles)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I)."""
        sigma = (0.5 * logvar).exp()
        return mu + torch.randn_like(sigma) * sigma

    def forward(self, x, right=None, left=None, check=False):
        """Encode x, sample a latent code, then classify and decode it.

        ``right``, ``left`` and ``check`` are accepted for call compatibility
        but are not used.

        Returns:
            Tuple ``(reconstruction, mu, logvar, style_probs)``.
        """
        mu, logvar, enc_seq = self.encode(x)
        latent = self.reparameterize(mu, logvar)
        style_probs = self.style_classifier(latent)
        reconstruction = self.decode(latent, enc_seq)
        return reconstruction, mu, logvar, style_probs
class Encoder(nn.Module):
    """Encode a two-part sequence with one GRU per part.

    The input is split in two along dim 1: the first chunk goes to
    ``gru_piano_right`` and the second to ``gru_piano_left``. The final GRU
    states are merged by a dense layer and mapped to ``mu`` / ``logvar``.

    Args:
        input_dim: feature size of each time step.
        hidden_dim: GRU and dense-layer hidden size.
        latent_dim: size of the latent code.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.gru_piano_right = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.gru_piano_left = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.dense_layer = nn.Linear(hidden_dim * 2, hidden_dim, bias=True)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim, bias=True)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim, bias=True)

    def forward(self, x):
        """Return ``(mu, logvar, sequence_outputs)`` for input ``x``.

        ``x`` is fed to batch_first GRUs, so it is presumably
        (batch, 2 * seq, input_dim) with the right part first — TODO confirm
        against the data loader. ``sequence_outputs`` is the right-GRU output
        followed by the left-GRU output along dim 1.
        """
        halves = torch.chunk(x, 2, dim=1)
        right_seq, left_seq = halves[0], halves[1]

        # Both GRUs start from the same all-zero state.
        init_state = torch.zeros(
            1, right_seq.size(0), self.hidden_dim, device=right_seq.device
        )
        right_out, right_state = self.gru_piano_right(right_seq, init_state)
        left_out, left_state = self.gru_piano_left(left_seq, init_state)

        merged = torch.cat((right_state[-1], left_state[-1]), dim=1)
        summary = F.relu(self.dense_layer(merged))

        # NOTE(review): relu keeps mu >= 0 and clamps logvar to >= 1e-4, so
        # std = exp(0.5 * logvar) is always > 1 and the posterior can never be
        # narrower than a unit Gaussian. If logvar is meant to be an
        # unconstrained log-variance this looks wrong — confirm intent before
        # changing it, since trained weights depend on it.
        mu = F.relu(self.fc_mu(summary))
        logvar = F.relu(self.fc_logvar(summary)) + 1e-4

        return mu, logvar, torch.cat((right_out, left_out), dim=1)
class Decoder(nn.Module):
    """Decode a latent code into two sequences with one GRU cell per part.

    The encoder's sequence output is split in half along dim 1: the first
    half drives ``gru_piano_right_cell`` and the second half drives
    ``gru_piano_left_cell``. The initial hidden states of both cells are
    derived from z.

    Args:
        latent_dim: size of the latent code z.
        hidden_dim: GRU-cell hidden size (must match the encoder's).
        output_dim: feature size of each decoded time step.
    """

    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.output_dim = output_dim
        # Maps latent -> latent despite its name; the two layers below lift
        # the result to hidden_dim as the initial cell states.
        self.latent_to_hidden = nn.Linear(latent_dim, latent_dim, bias=True)
        self.piano_right_layer = nn.Linear(latent_dim, hidden_dim, bias=True)
        self.piano_left_layer = nn.Linear(latent_dim, hidden_dim, bias=True)
        self.r_layer = nn.Linear(hidden_dim, output_dim, bias=True)
        self.l_layer = nn.Linear(hidden_dim, output_dim, bias=True)
        self.gru_piano_right_cell = nn.GRUCell(output_dim, hidden_dim)
        self.gru_piano_left_cell = nn.GRUCell(output_dim, hidden_dim)
        self.fr_layer = nn.Linear(hidden_dim, output_dim, bias=True)
        self.fl_layer = nn.Linear(hidden_dim, output_dim, bias=True)

    def forward(self, z, output):
        """Decode ``z`` guided by the encoder sequence output ``output``.

        Args:
            z: latent codes, (batch, latent_dim).
            output: encoder sequence output, (batch, 2 * steps, hidden_dim),
                with the right part first.

        Returns:
            Tensor of shape (batch, 2 * steps, output_dim) with values in
            (0, 1): the right-part steps followed by the left-part steps.

        Raises:
            ValueError: if ``output`` cannot be split into two equal,
                non-empty halves along dim 1.
        """
        h = F.relu(self.latent_to_hidden(z))

        # Split into halves (the same rule Encoder applies to its input)
        # rather than a hard-coded 150-step chunk, so any length works.
        halves = torch.chunk(output, 2, dim=1)
        if len(halves) != 2 or halves[0].size(1) != halves[1].size(1) or halves[0].size(1) == 0:
            raise ValueError(
                "Decoder expects an encoder output with an even, non-zero "
                f"length along dim 1; got shape {tuple(output.shape)}"
            )
        right, left = halves

        # Per-step cell inputs, (batch, steps, output_dim).
        right_in = torch.tanh(self.r_layer(right))
        left_in = torch.tanh(self.l_layer(left))

        right_hidden = self.piano_right_layer(h)
        left_hidden = self.piano_left_layer(h)
        right_steps = []
        left_steps = []
        for t in range(right_in.size(1)):
            right_hidden = self.gru_piano_right_cell(right_in[:, t], right_hidden)
            left_hidden = self.gru_piano_left_cell(left_in[:, t], left_hidden)
            right_steps.append(right_hidden)
            left_steps.append(left_hidden)

        right_out = torch.sigmoid(self.fr_layer(torch.stack(right_steps, dim=1)))
        left_out = torch.sigmoid(self.fl_layer(torch.stack(left_steps, dim=1)))
        return torch.cat((right_out, left_out), dim=1)
class StyleClassifier(nn.Module):
    """Two-layer MLP that predicts style-class probabilities from z.

    Args:
        latent_dim: size of the latent code z.
        num_styles: number of output classes.
        hidden_dim: width of the hidden layer (default 128, the previously
            hard-coded value, so existing checkpoints still load).
    """

    def __init__(self, latent_dim, num_styles, *, hidden_dim=128):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_styles)

    def forward(self, z):
        """Return softmax probabilities over ``num_styles`` along the last dim.

        NOTE(review): this returns probabilities, not logits. If the training
        loss is ``nn.CrossEntropyLoss`` (which applies log-softmax itself),
        the softmax is applied twice — confirm against the training code.
        """
        hidden = F.relu(self.fc1(z))
        return F.softmax(self.fc2(hidden), dim=-1)