import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class VAE(nn.Module):
    """Variational autoencoder with a two-stream GRU encoder/decoder and a style head.

    The encoder halves the input along the sequence axis and feeds each half
    to its own GRU (presumably the right-hand and left-hand piano parts,
    judging by the attribute names -- confirm against the data loader).
    The latent sample is fed both to the decoder and to a small classifier
    that predicts one of ``num_styles`` classes.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, num_styles=2):
        """Build the encoder, decoder and style classifier.

        Args:
            input_dim: Feature size of every time step.
            hidden_dim: Hidden size of the GRUs.
            latent_dim: Size of the latent vector ``z``.
            num_styles: Number of classes predicted by the style head.
        """
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim

        # Attribute names are part of the state_dict keys; keep them stable.
        self.encode = Encoder(self.input_dim, self.hidden_dim, self.latent_dim)
        self.decode = Decoder(self.latent_dim, self.hidden_dim, self.input_dim)
        self.style_classifier = StyleClassifier(self.latent_dim, num_styles)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, right=None, left=None, check=False):
        """Encode ``x``, sample a latent vector, classify its style and decode.

        Args:
            x: Tensor of shape (batch, seq_len, input_dim); the GRUs are
                built with ``batch_first=True``. ``seq_len`` must be even
                because it is split into two equal halves.
            right: Unused; kept so existing call sites keep working.
            left: Unused; kept so existing call sites keep working.
            check: Unused; kept so existing call sites keep working.

        Returns:
            Tuple ``(decod, mu, logvar, style_pred)`` where ``decod`` has the
            same shape as ``x``, ``mu`` and ``logvar`` are
            (batch, latent_dim) and ``style_pred`` is (batch, num_styles).
        """
        mu, logvar, output = self.encode(x)
        z = self.reparameterize(mu, logvar)
        style_pred = self.style_classifier(z)
        decod = self.decode(z, output)
        return decod, mu, logvar, style_pred


class Encoder(nn.Module):
    """Two parallel GRUs whose final hidden states are fused into (mu, logvar)."""

    def __init__(self, input_dim, hidden_dim, latent_dim):
        """Create one GRU per input half and the projection layers.

        Args:
            input_dim: Feature size of every time step.
            hidden_dim: Hidden size of both GRUs.
            latent_dim: Size of the ``mu`` / ``logvar`` vectors.
        """
        super().__init__()
        self.hidden_dim = hidden_dim

        self.gru_piano_right = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.gru_piano_left = nn.GRU(input_dim, hidden_dim, batch_first=True)

        self.dense_layer = nn.Linear(hidden_dim * 2, hidden_dim, bias=True)

        self.fc_mu = nn.Linear(hidden_dim, latent_dim, bias=True)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim, bias=True)

    def forward(self, x):
        """Encode ``x`` into latent statistics plus the per-step GRU outputs.

        Args:
            x: Tensor of shape (batch, seq_len, input_dim). It is split into
                two chunks along ``dim=1``; the first goes to
                ``gru_piano_right`` and the second to ``gru_piano_left``.

        Returns:
            Tuple ``(mu, logvar, output)``. ``mu`` and ``logvar`` are
            (batch, latent_dim); ``output`` is the two GRU output sequences
            concatenated along the sequence axis,
            (batch, seq_len, hidden_dim).
        """
        right_input, left_input = torch.chunk(x, 2, dim=1)

        # Zero initial hidden state. dtype follows the input so the module
        # also works after ``.double()`` / ``.half()`` instead of failing
        # with a dtype mismatch inside the GRU.
        h0 = torch.zeros(
            1,
            right_input.size(0),
            self.hidden_dim,
            device=right_input.device,
            dtype=right_input.dtype,
        )

        # Forward pass through the GRU of each stream.
        o_r, h_r = self.gru_piano_right(right_input, h0)
        o_l, h_l = self.gru_piano_left(left_input, h0)

        output = torch.cat((o_r, o_l), dim=1)

        # Fuse the final hidden state of both streams.
        h = torch.cat((h_r[-1], h_l[-1]), dim=1)
        h = F.relu(self.dense_layer(h))

        # NOTE(review): the ReLUs below force mu >= 0 and logvar >= 1e-4
        # (i.e. std > 1), which is unusual for a VAE posterior. Left
        # unchanged to keep existing checkpoints numerically identical --
        # confirm this is intended.
        mu = F.relu(self.fc_mu(h))
        logvar = F.relu(self.fc_logvar(h)) + 1e-4

        return mu, logvar, output


class Decoder(nn.Module):
    """Two GRU cells that reconstruct both streams from ``z`` and encoder outputs."""

    def __init__(self, latent_dim, hidden_dim, output_dim):
        """Create the projection layers and one GRU cell per stream.

        Args:
            latent_dim: Size of the latent vector ``z``.
            hidden_dim: Hidden size of the GRU cells (and feature size of
                the encoder outputs fed to ``forward``).
            output_dim: Feature size of every reconstructed time step.
        """
        super().__init__()
        self.latent_dim = latent_dim
        self.output_dim = output_dim

        self.latent_to_hidden = nn.Linear(latent_dim, latent_dim, bias=True)

        self.piano_right_layer = nn.Linear(latent_dim, hidden_dim, bias=True)
        self.piano_left_layer = nn.Linear(latent_dim, hidden_dim, bias=True)

        self.r_layer = nn.Linear(hidden_dim, output_dim, bias=True)
        self.l_layer = nn.Linear(hidden_dim, output_dim, bias=True)

        self.gru_piano_right_cell = nn.GRUCell(output_dim, hidden_dim)
        self.gru_piano_left_cell = nn.GRUCell(output_dim, hidden_dim)

        self.fr_layer = nn.Linear(hidden_dim, output_dim, bias=True)
        self.fl_layer = nn.Linear(hidden_dim, output_dim, bias=True)

    def forward(self, z, output):
        """Decode ``z`` into a sequence, conditioned on the encoder outputs.

        Args:
            z: Latent sample of shape (batch, latent_dim).
            output: Encoder GRU outputs of shape
                (batch, seq_len, hidden_dim), first half from the right
                stream and second half from the left stream.

        Returns:
            Tensor of shape (batch, seq_len, output_dim) with values in
            (0, 1): the right-stream reconstruction followed by the
            left-stream reconstruction along the sequence axis.

        Raises:
            ValueError: If ``seq_len`` is zero or odd, so it cannot be split
                into two equal halves.
        """
        total_steps = output.size(1)
        if total_steps == 0 or total_steps % 2:
            raise ValueError(
                "Decoder expects an even, non-zero sequence length, "
                f"got {total_steps}"
            )

        h = F.relu(self.latent_to_hidden(z))

        # Split at the midpoint instead of a hard-coded 150 steps so any
        # even sequence length works (identical for the former 300 steps).
        right, left = torch.chunk(output, 2, dim=1)

        # Time-major layout so right[t] / left[t] is one step for the batch.
        right = torch.tanh(self.r_layer(right.permute(1, 0, 2)))
        left = torch.tanh(self.l_layer(left.permute(1, 0, 2)))

        # Initial hidden state of each cell is derived from the latent code.
        piano_hidden = self.piano_right_layer(h)
        left_hidden = self.piano_left_layer(h)

        right_outputs = []
        left_outputs = []
        for right_step, left_step in zip(right, left):
            piano_hidden = self.gru_piano_right_cell(right_step, piano_hidden)
            left_hidden = self.gru_piano_left_cell(left_step, left_hidden)
            right_outputs.append(piano_hidden)
            left_outputs.append(left_hidden)

        # Back to batch-first: (batch, steps, hidden_dim).
        right_outputs = torch.stack(right_outputs, dim=1)
        left_outputs = torch.stack(left_outputs, dim=1)

        right_outputs = torch.sigmoid(self.fr_layer(right_outputs))
        left_outputs = torch.sigmoid(self.fl_layer(left_outputs))

        return torch.cat((right_outputs, left_outputs), dim=1)


class StyleClassifier(nn.Module):
    """Two-layer MLP mapping a latent vector to style class probabilities."""

    def __init__(self, latent_dim, num_styles):
        """Create the hidden (128 units) and output layers.

        Args:
            latent_dim: Size of the latent vector ``z``.
            num_styles: Number of style classes.
        """
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, 128)
        self.fc2 = nn.Linear(128, num_styles)

    def forward(self, z):
        """Return softmax probabilities over styles, shape (..., num_styles).

        NOTE(review): the output is already normalised; if the training
        loss is ``nn.CrossEntropyLoss`` (which expects raw logits) softmax
        would be applied twice -- verify against the training script.
        """
        x = F.relu(self.fc1(z))
        x = self.fc2(x)
        return F.softmax(x, dim=-1)