Spaces:
Paused
Paused
| import json | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| from torch.utils.tensorboard import SummaryWriter | |
| from tools import create_key | |
class TimbreEncoder(nn.Module):
    """LSTM encoder producing a timbre feature plus four classification heads.

    The forward pass flattens the two middle axes of a 4-D input, projects
    every time step with a linear layer, runs an LSTM over the sequence and
    keeps the LSTM output at the last time step as the feature.  Four linear
    heads on top of that feature predict instrument, instrument family,
    velocity and a multi-label "qualities" vector.

    Args:
        input_dim: Size of the flattened middle axes of the input.
        feature_dim: Output size of the per-time-step input projection.
        hidden_dim: LSTM hidden size; also the size of the returned feature.
        num_instrument_classes: Output size of the instrument head.
        num_instrument_family_classes: Output size of the instrument-family head.
        num_velocity_classes: Output size of the velocity head.
        num_qualities: Output size of the (sigmoid) qualities head.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_dim, feature_dim, hidden_dim, num_instrument_classes, num_instrument_family_classes, num_velocity_classes, num_qualities, num_layers=1):
        super().__init__()

        # Input layer: per-time-step projection applied before the LSTM.
        self.input_layer = nn.Linear(input_dim, feature_dim)

        # LSTM layer (batch_first: tensors are [batch, seq_len, features]).
        self.lstm = nn.LSTM(feature_dim, hidden_dim, num_layers=num_layers, batch_first=True)

        # Fully connected layers for classification.
        self.instrument_classifier_layer = nn.Linear(hidden_dim, num_instrument_classes)
        self.instrument_family_classifier_layer = nn.Linear(hidden_dim, num_instrument_family_classes)
        self.velocity_classifier_layer = nn.Linear(hidden_dim, num_velocity_classes)
        self.qualities_classifier_layer = nn.Linear(hidden_dim, num_qualities)

        # LogSoftmax over the class axis: the heads emit log-probabilities,
        # which is the input format expected by NLLLoss.
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Encode ``x`` and classify the resulting feature.

        Args:
            x: Tensor with four axes ``[batch, a, b, seq_len]``; ``a * b``
                must equal ``input_dim``.

        Returns:
            Tuple ``(feature, instrument, instrument_family, velocity,
            qualities)``.  ``feature`` is ``[batch, hidden_dim]``; the three
            class outputs are log-probabilities (despite the ``*_logits``
            variable names); ``qualities`` holds sigmoid probabilities.
        """
        # Merge the two middle dimensions.  ``reshape`` instead of ``view`` so
        # that non-contiguous inputs (e.g. transposed tensors) are accepted.
        batch_size, _, _, seq_len = x.shape
        x = x.reshape(batch_size, -1, seq_len)  # [batch_size, input_dim, seq_len]

        # Put time on axis 1 for the batch_first LSTM, then project each step.
        x = x.permute(0, 2, 1)
        x = self.input_layer(x)

        # Forward propagate the LSTM and keep the output of the last time step.
        feature, _ = self.lstm(x)
        feature = feature[:, -1, :]

        # Apply classification layers.
        instrument_logits = self.instrument_classifier_layer(feature)
        instrument_family_logits = self.instrument_family_classifier_layer(feature)
        velocity_logits = self.velocity_classifier_layer(feature)
        qualities = self.qualities_classifier_layer(feature)

        # Convert raw scores to log-probabilities / probabilities.
        instrument_logits = self.softmax(instrument_logits)
        instrument_family_logits = self.softmax(instrument_family_logits)
        velocity_logits = self.softmax(velocity_logits)
        qualities = torch.sigmoid(qualities)

        return feature, instrument_logits, instrument_family_logits, velocity_logits, qualities
def get_multiclass_acc(outputs, ground_truth):
    """Return the top-1 accuracy, in percent, of ``outputs`` against ``ground_truth``.

    Args:
        outputs: Per-class scores; the predicted class is the index of the
            maximum along axis 1.
        ground_truth: Tensor of target class indices; its first dimension is
            used as the sample count.

    Returns:
        Percentage (0-100) of samples whose predicted class equals the target.
    """
    # torch.max along an axis yields (values, indices); only the indices matter.
    predicted = torch.max(outputs.data, 1)[1]
    n_correct = (predicted == ground_truth).sum().item()
    n_total = ground_truth.size(0)
    return 100 * n_correct / n_total
def get_binary_accuracy(y_pred, y_true):
    """Return the element-wise binary accuracy, in percent.

    Args:
        y_pred: Tensor of predicted scores; values strictly above 0.5 count
            as a positive prediction.
        y_true: Tensor of 0/1 targets with the same shape as ``y_pred``.

    Returns:
        Percentage (0-100) of elements whose thresholded prediction equals
        the target.
    """
    hard_labels = (y_pred > 0.5).int()
    hit_ratio = (hard_labels == y_true).float().mean()
    return hit_ratio.item() * 100.0
def get_timbre_encoder(model_Config, load_pretrain=False, model_name=None, device="cpu"):
    """Build a ``TimbreEncoder`` in eval mode, optionally restoring saved weights.

    Args:
        model_Config: Keyword arguments forwarded to ``TimbreEncoder``.
        load_pretrain: When true, load ``model_state_dict`` from
            ``models/{model_name}_timbre_encoder.pth``.
        model_name: Name used to build the checkpoint path.
        device: Device the model is moved to and the checkpoint is mapped to.

    Returns:
        The encoder, on ``device`` and switched to evaluation mode.
    """
    timbreEncoder = TimbreEncoder(**model_Config)
    trainable_params = sum(p.numel() for p in timbreEncoder.parameters() if p.requires_grad)
    print(f"Model intialized, size: {trainable_params}")
    timbreEncoder.to(device)

    if load_pretrain:
        print(f"Loading weights from models/{model_name}_timbre_encoder.pth")
        state = torch.load(f'models/{model_name}_timbre_encoder.pth', map_location=device)
        timbreEncoder.load_state_dict(state['model_state_dict'])

    timbreEncoder.eval()
    return timbreEncoder
def evaluate_timbre_encoder(device, model, iterator, nll_Loss, bce_Loss, n_sample=100):
    """Return the mean summed loss of ``model`` over ``n_sample`` batches.

    The model is moved to ``device`` and left in evaluation mode.  For every
    batch the loss is the sum of the instrument, instrument-family and
    velocity losses (``nll_Loss``) and the qualities loss (``bce_Loss``).

    Args:
        device: Device on which the evaluation is run.
        model: Model returning ``(feature, instrument, instrument_family,
            velocity, qualities)`` for a batch of representations.
        iterator: Source of ``(representation, attributes)`` batches, where
            ``attributes`` is a sequence of dicts with the keys
            ``"instrument"``, ``"instrument_family"`` and ``"velocity"``.
        nll_Loss: Loss applied to the three class outputs.
        bce_Loss: Loss applied to the qualities output.
        n_sample: Number of batches to evaluate.

    Returns:
        Mean of the per-batch summed losses (``np.mean`` of Python floats).
    """
    model.to(device)
    model.eval()

    eva_loss = []
    # Evaluation never back-propagates, so skip autograd bookkeeping entirely.
    with torch.no_grad():
        for _ in range(n_sample):
            # NOTE(review): ``iter(iterator)`` is re-created for every batch.
            # For a true iterator this just advances it; for a re-iterable
            # such as a DataLoader it restarts iteration each time - confirm
            # that this is intended.
            representation, attributes = next(iter(iterator))

            instrument = torch.tensor([s["instrument"] for s in attributes], dtype=torch.long).to(device)
            instrument_family = torch.tensor([s["instrument_family"] for s in attributes], dtype=torch.long).to(device)
            velocity = torch.tensor([s["velocity"] for s in attributes], dtype=torch.long).to(device)
            # The last 10 characters of the sample key are parsed as digits;
            # presumably one 0/1 flag per quality - verify against create_key.
            qualities = torch.tensor([[int(char) for char in create_key(attribute)[-10:]] for attribute in attributes], dtype=torch.float32).to(device)

            _, instrument_logits, instrument_family_logits, velocity_logits, qualities_pred = model(representation.to(device))

            # Compute the combined loss.
            instrument_loss = nll_Loss(instrument_logits, instrument)
            instrument_family_loss = nll_Loss(instrument_family_logits, instrument_family)
            velocity_loss = nll_Loss(velocity_logits, velocity)
            qualities_loss = bce_Loss(qualities_pred, qualities)

            loss = instrument_loss + instrument_family_loss + velocity_loss + qualities_loss
            eva_loss.append(loss.item())

    eva_loss = np.mean(eva_loss)
    return eva_loss
def train_timbre_encoder(device, model_name, timbre_encoder_Config, BATCH_SIZE, lr, max_iter, training_iterator, load_pretrain):
    """Train a ``TimbreEncoder`` and checkpoint it whenever the loss improves.

    Every 500 iterations the mean training loss and accuracies of the last
    500 iterations are printed and written to TensorBoard
    (``runs/{model_name}_timbre_encoder``).  When that mean loss is lower
    than the best seen so far, model and optimizer state are saved to
    ``models/{model_name}_timbre_encoder.pth`` and the hyperparameters to
    ``models/hyperparameters/{model_name}_timbre_encoder.json``.

    Args:
        device: Device used for training.
        model_name: Name used in checkpoint, hyperparameter and log paths.
        timbre_encoder_Config: Keyword arguments for ``TimbreEncoder``.  The
            dict is not modified.
        BATCH_SIZE: Only recorded in the hyperparameter file.
        lr: Learning rate for the Adam optimizer.
        max_iter: Number of training iterations; ``0`` returns the model
            without training.
        training_iterator: Source of ``(representation, attributes)``
            batches, where ``attributes`` is a sequence of dicts with the
            keys ``"instrument"``, ``"instrument_family"`` and ``"velocity"``.
        load_pretrain: When true, resume model and optimizer from the
            checkpoint file.

    Returns:
        Tuple ``(model, current_best_model)``: the model after the last
        iteration and a snapshot of the model at the lowest logged loss (the
        model itself when no improvement was recorded).
    """
    import copy  # local import: only needed to snapshot the best model

    def save_model_hyperparameter(model_name, timbre_encoder_Config, BATCH_SIZE, lr, model_size, current_iter,
                                  current_loss):
        """Write the model config plus training metadata as JSON."""
        # Work on a copy: mutating the caller's config would add keys that
        # ``TimbreEncoder(**config)`` does not accept.
        model_hyperparameter = dict(timbre_encoder_Config)
        model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
        model_hyperparameter["lr"] = lr
        model_hyperparameter["model_size"] = model_size
        model_hyperparameter["current_iter"] = current_iter
        model_hyperparameter["current_loss"] = current_loss
        # ensure_ascii=False may emit non-ASCII text, so pin the encoding.
        with open(f"models/hyperparameters/{model_name}_timbre_encoder.json", "w", encoding="utf-8") as json_file:
            json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)

    model = TimbreEncoder(**timbre_encoder_Config)
    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model size: {model_size}")
    model.to(device)
    nll_Loss = torch.nn.NLLLoss()
    bce_Loss = torch.nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, amsgrad=False)

    if load_pretrain:
        checkpoint_path = f'models/{model_name}_timbre_encoder.pth'
        print(f"Loading weights from {checkpoint_path}")
        # map_location lets a checkpoint saved on another device be resumed here.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print("Model initialized.")

    if max_iter == 0:
        print("Return model directly.")
        return model, model

    train_loss, training_instrument_acc, training_instrument_family_acc, training_velocity_acc, training_qualities_acc = [], [], [], [], []
    writer = SummaryWriter(f'runs/{model_name}_timbre_encoder')
    current_best_model = model
    previous_lowest_loss = 100.0
    print(f"initial__loss: {previous_lowest_loss}")

    save_steps = 500
    try:
        for i in range(max_iter):
            model.train()

            # NOTE(review): ``iter(training_iterator)`` is re-created for
            # every batch.  For a true iterator this just advances it; for a
            # re-iterable such as a DataLoader it restarts iteration each
            # time - confirm that this is intended.
            representation, attributes = next(iter(training_iterator))

            instrument = torch.tensor([s["instrument"] for s in attributes], dtype=torch.long).to(device)
            instrument_family = torch.tensor([s["instrument_family"] for s in attributes], dtype=torch.long).to(device)
            velocity = torch.tensor([s["velocity"] for s in attributes], dtype=torch.long).to(device)
            # The last 10 characters of the sample key are parsed as digits;
            # presumably one 0/1 flag per quality - verify against create_key.
            qualities = torch.tensor([[int(char) for char in create_key(attribute)[-10:]] for attribute in attributes], dtype=torch.float32).to(device)

            optimizer.zero_grad()

            _, instrument_logits, instrument_family_logits, velocity_logits, qualities_pred = model(representation.to(device))

            # Compute the combined loss.
            instrument_loss = nll_Loss(instrument_logits, instrument)
            instrument_family_loss = nll_Loss(instrument_family_logits, instrument_family)
            velocity_loss = nll_Loss(velocity_logits, velocity)
            qualities_loss = bce_Loss(qualities_pred, qualities)

            loss = instrument_loss + instrument_family_loss + velocity_loss + qualities_loss

            loss.backward()
            optimizer.step()

            instrument_acc = get_multiclass_acc(instrument_logits, instrument)
            instrument_family_acc = get_multiclass_acc(instrument_family_logits, instrument_family)
            velocity_acc = get_multiclass_acc(velocity_logits, velocity)
            qualities_acc = get_binary_accuracy(qualities_pred, qualities)

            train_loss.append(loss.item())
            training_instrument_acc.append(instrument_acc)
            training_instrument_family_acc.append(instrument_family_acc)
            training_velocity_acc.append(velocity_acc)
            training_qualities_acc.append(qualities_acc)

            # Global step as tracked by the optimizer, so it keeps counting
            # across resumed runs.  ``int()`` accepts both a plain number and
            # a 0-dim tensor on any device, unlike ``.numpy()``.
            step = int(next(iter(optimizer.state.values()))['step'])

            if (i + 1) % 100 == 0:
                print('%d step' % (step))

            if (i + 1) % save_steps == 0:
                current_loss = np.mean(train_loss[-save_steps:])
                current_instrument_acc = np.mean(training_instrument_acc[-save_steps:])
                current_instrument_family_acc = np.mean(training_instrument_family_acc[-save_steps:])
                current_velocity_acc = np.mean(training_velocity_acc[-save_steps:])
                current_qualities_acc = np.mean(training_qualities_acc[-save_steps:])
                print('train_loss: %.5f' % current_loss)
                print('current_instrument_acc: %.5f' % current_instrument_acc)
                print('current_instrument_family_acc: %.5f' % current_instrument_family_acc)
                print('current_velocity_acc: %.5f' % current_velocity_acc)
                print('current_qualities_acc: %.5f' % current_qualities_acc)
                writer.add_scalar(f"train_loss", current_loss, step)
                writer.add_scalar(f"current_instrument_acc", current_instrument_acc, step)
                writer.add_scalar(f"current_instrument_family_acc", current_instrument_family_acc, step)
                writer.add_scalar(f"current_velocity_acc", current_velocity_acc, step)
                writer.add_scalar(f"current_qualities_acc", current_qualities_acc, step)

                if current_loss < previous_lowest_loss:
                    previous_lowest_loss = current_loss
                    # Snapshot the weights: keeping a bare reference would
                    # make the "best" model silently follow later updates.
                    current_best_model = copy.deepcopy(model)
                    torch.save({
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                    }, f'models/{model_name}_timbre_encoder.pth')
                    save_model_hyperparameter(model_name, timbre_encoder_Config, BATCH_SIZE, lr, model_size, step,
                                              current_loss)
    finally:
        # Flush pending events and release the log file even if training fails.
        writer.close()

    return model, current_best_model