|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import torch.optim as optim |
|
|
|
|
|
class VAE(nn.Module):
    """Variational autoencoder with a style-classification head on the latent code.

    Composes an Encoder, a Decoder, and a StyleClassifier (defined in this file).
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, num_styles=2):
        super(VAE, self).__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim

        self.encode = Encoder(self.input_dim, self.hidden_dim, self.latent_dim)
        self.decode = Decoder(self.latent_dim, self.hidden_dim, self.input_dim)
        self.style_classifier = StyleClassifier(self.latent_dim, num_styles)

    def reparameterize(self, mu, logvar):
        """Draw z ~ N(mu, sigma^2) with the reparameterization trick so the
        sampling step stays differentiable w.r.t. mu and logvar."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x, right=None, left=None, check=False):
        """Encode x, sample a latent code, classify its style, and decode.

        Returns (reconstruction, mu, logvar, style_pred).  The `right`, `left`
        and `check` arguments are accepted for interface compatibility but are
        not used in this forward pass.
        """
        mu, logvar, enc_steps = self.encode(x)
        latent = self.reparameterize(mu, logvar)
        style_pred = self.style_classifier(latent)
        reconstruction = self.decode(latent, enc_steps)
        return reconstruction, mu, logvar, style_pred
|
|
|
|
|
class Encoder(nn.Module):
    """Two-stream GRU encoder for (right-hand, left-hand) piano sequences.

    The input is split in half along the time axis; each half runs through its
    own GRU.  The two final hidden states are fused by a dense layer and
    projected to the mean and log-variance of the latent Gaussian.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        self.hidden_dim = hidden_dim

        self.gru_piano_right = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.gru_piano_left = nn.GRU(input_dim, hidden_dim, batch_first=True)

        # Fuses the two per-hand final hidden states, hence hidden_dim * 2 in.
        self.dense_layer = nn.Linear(hidden_dim * 2, hidden_dim, bias=True)

        self.fc_mu = nn.Linear(hidden_dim, latent_dim, bias=True)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim, bias=True)

    def forward(self, x):
        """Encode a batch of sequences.

        Args:
            x: assumed shape (batch, 2*T, input_dim), first half of the time
               axis being the right hand, second half the left — TODO confirm
               against the data pipeline.

        Returns:
            mu:     (batch, latent_dim) latent mean.
            logvar: (batch, latent_dim) latent log-variance.
            output: (batch, 2*T, hidden_dim) concatenated per-step GRU outputs,
                    consumed by the Decoder as teacher-forcing context.
        """
        right_input, left_input = torch.chunk(x, 2, dim=1)

        h0 = torch.zeros(1, right_input.size(0), self.hidden_dim,
                         device=right_input.device)

        o_r, h_r = self.gru_piano_right(right_input, h0)
        o_l, h_l = self.gru_piano_left(left_input, h0)

        output = torch.cat((o_r, o_l), dim=1)
        # Last-layer hidden state of each (single-layer) GRU.
        h = torch.cat((h_r[-1], h_l[-1]), dim=1)

        h = F.relu(self.dense_layer(h))

        # BUG FIX: the original applied F.relu to mu and F.relu(...) + 1e-4 to
        # logvar.  A ReLU on mu confines the latent mean to the non-negative
        # orthant, and a ReLU on logvar forces the variance to be at least
        # exp(1e-4) ~= 1, so the approximate posterior can never tighten below
        # the unit prior.  Standard VAE practice leaves both heads as
        # unconstrained linear outputs.
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar, output
|
|
|
|
|
|
|
|
class Decoder(nn.Module):
    """Two-stream GRUCell decoder that reconstructs both piano hands.

    The encoder's per-step outputs are projected back to note space and used
    as teacher-forcing inputs; the latent code seeds each hand's recurrent
    state.  Per-step hidden states are projected to sigmoid activations.
    """

    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.latent_dim = latent_dim
        self.output_dim = output_dim

        self.latent_to_hidden = nn.Linear(latent_dim, latent_dim, bias=True)

        # Per-hand projections of the latent code to initial GRU states.
        self.piano_right_layer = nn.Linear(latent_dim, hidden_dim, bias=True)
        self.piano_left_layer = nn.Linear(latent_dim, hidden_dim, bias=True)

        # Project encoder outputs down to note space for teacher forcing.
        self.r_layer = nn.Linear(hidden_dim, output_dim, bias=True)
        self.l_layer = nn.Linear(hidden_dim, output_dim, bias=True)

        self.gru_piano_right_cell = nn.GRUCell(output_dim, hidden_dim)
        self.gru_piano_left_cell = nn.GRUCell(output_dim, hidden_dim)

        # Final per-step projections from hidden state back to note space.
        self.fr_layer = nn.Linear(hidden_dim, output_dim, bias=True)
        self.fl_layer = nn.Linear(hidden_dim, output_dim, bias=True)

    def forward(self, z, output):
        """Decode a latent code into a reconstructed sequence.

        Args:
            z:      (batch, latent_dim) latent code.
            output: (batch, 2*T, hidden_dim) encoder per-step outputs, first
                    half right hand, second half left hand.

        Returns:
            (batch, 2*T, output_dim) reconstruction with values in [0, 1].
        """
        h = F.relu(self.latent_to_hidden(z))

        # FIX: the original hard-coded torch.split(output, 150, dim=1), which
        # crashes for any sequence length other than 300.  chunk(…, 2) splits
        # into equal halves (identical result for the 300 case) and mirrors
        # the Encoder's torch.chunk.
        right, left = torch.chunk(output, 2, dim=1)

        # Time-major layout so the recurrence below steps over time.
        # torch.tanh/torch.sigmoid replace the deprecated F.tanh/F.sigmoid
        # (numerically identical).
        right = torch.tanh(self.r_layer(right.permute(1, 0, 2)))
        left = torch.tanh(self.l_layer(left.permute(1, 0, 2)))

        piano_hidden = self.piano_right_layer(h)
        left_hidden = self.piano_left_layer(h)

        right_outputs = []
        left_outputs = []

        for t in range(right.size(0)):
            piano_hidden = self.gru_piano_right_cell(right[t], piano_hidden)
            left_hidden = self.gru_piano_left_cell(left[t], left_hidden)

            right_outputs.append(piano_hidden.unsqueeze(1))
            left_outputs.append(left_hidden.unsqueeze(1))

        right_outputs = torch.cat(right_outputs, dim=1)
        left_outputs = torch.cat(left_outputs, dim=1)

        # Sigmoid keeps reconstructed note activations in [0, 1].
        right_outputs = torch.sigmoid(self.fr_layer(right_outputs))
        left_outputs = torch.sigmoid(self.fl_layer(left_outputs))

        return torch.cat((right_outputs, left_outputs), dim=1)
|
|
|
|
|
class StyleClassifier(nn.Module):
    """Two-layer MLP mapping a latent code to a probability distribution
    over `num_styles` style classes (softmax output)."""

    def __init__(self, latent_dim, num_styles):
        super(StyleClassifier, self).__init__()
        # Attribute names fc1/fc2 are kept so checkpoint state_dict keys match.
        self.fc1 = nn.Linear(latent_dim, 128)
        self.fc2 = nn.Linear(128, num_styles)

    def forward(self, z):
        """Return per-sample style probabilities, shape (..., num_styles)."""
        activations = F.relu(self.fc1(z))
        return F.softmax(self.fc2(activations), dim=-1)