Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import DataLoader | |
| from src.models.phoneme import PPGEncoder | |
| from src.constants import LIBRISPEECH_NUM_PHONEMES, LIBRISPEECH_PHONEME_DICT | |
| from src.data import LibriSpeechDataset | |
| from src.utils.writer import Writer | |
| import numpy as np | |
| from sklearn.metrics import confusion_matrix, classification_report | |
| import seaborn as sn | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| ################################################################################ | |
| # Train a simple model to produce phonetic posteriorgrams (PPGs) | |
| ################################################################################ | |
def _predict_aligned(classifier, batch, device, lookahead_frames):
    """Run one batch through the classifier and align outputs with labels.

    Args:
        classifier: module called on the batch's ``'x'`` entry.
        batch: mapping with ``'x'`` (inputs) and ``'y'`` (labels) tensors.
        device: device both tensors are moved to before the forward pass.
        lookahead_frames: number of frames the predictions lag the labels;
            ``0`` leaves both tensors untouched.

    Returns:
        Tuple ``(preds, y, batch_size)``. ``preds`` and ``y`` are sliced
        along dim 1 so they stay the same length; ``batch_size`` is
        ``len(x)``.
    """
    # NOTE(review): dim 1 is presumably the frame axis, i.e. preds is
    # (batch, frames, classes) and y is (batch, frames) -- confirm against
    # PPGEncoder / LibriSpeechDataset.
    x, y = batch['x'].to(device), batch['y'].to(device)
    preds = classifier(x)
    if lookahead_frames:
        # drop the trailing labels and the leading predictions so that the
        # output at index t + lookahead_frames is scored against label t
        y = y[:, :-lookahead_frames]
        preds = preds[:, lookahead_frames:]
    return preds, y, len(x)


def main():
    """Train a simple model to produce phonetic posteriorgrams (PPGs).

    Trains a ``PPGEncoder`` followed by a linear classification layer on the
    ``train-clean-100`` split, validates on ``test-clean`` after every epoch,
    and checkpoints the encoder weights whenever the validation loss
    improves. After training, logs a classification report and saves a
    confusion-matrix heatmap to ``phoneme_cm.png``.
    """
    # training hyperparameters
    lr = .001
    epochs = 60
    batch_size = 250
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # phoneme encoder hyperparameters
    lstm_depth = 2
    hidden_size = 128  # 512
    win_length = 256
    hop_length = 128
    n_mels = 32
    n_mfcc = 19
    lookahead_frames = 0  # 1

    # datasets and loaders
    train_data = LibriSpeechDataset(
        split='train-clean-100',
        target='phoneme',
        features=None,
        hop_length=hop_length
    )
    val_data = LibriSpeechDataset(
        split='test-clean',
        target='phoneme',
        features=None,
        hop_length=hop_length
    )
    train_loader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True)
    val_loader = DataLoader(
        val_data,
        batch_size=batch_size)

    # initialize phoneme encoder
    encoder = PPGEncoder(
        win_length=win_length,
        hop_length=hop_length,
        win_func=torch.hann_window,
        n_mels=n_mels,
        n_mfcc=n_mfcc,
        lstm_depth=lstm_depth,
        hidden_size=hidden_size,
    )

    # initialize classification layer and wrap as single module
    classifier = nn.Sequential(
        encoder,
        nn.Linear(hidden_size, LIBRISPEECH_NUM_PHONEMES)
    ).to(device)

    # log training progress
    writer = Writer(
        name=f"phoneme_lookahead_{lookahead_frames}",
        use_tb=True,
        log_iter=len(train_loader)
    )

    # count trainable parameters of the encoder only (classifier[0]); the
    # linear classification layer is deliberately excluded from the total
    parameter_count = sum(
        p.numel()
        for p in classifier[0].parameters()
        if p.requires_grad
    )
    writer.log_info(f'Training PPG model with lookahead {lookahead_frames}'
                    f' ({parameter_count} parameters)')

    # initialize optimizer and loss function
    optimizer = torch.optim.Adam(classifier.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    iter_id = 0
    min_val_loss = float('inf')

    for epoch in range(epochs):
        print(f'beginning epoch {epoch}')

        classifier.train()
        for batch in train_loader:
            optimizer.zero_grad(set_to_none=True)
            preds, y, _ = _predict_aligned(
                classifier, batch, device, lookahead_frames
            )

            # compute cross-entropy loss over all frames in the batch
            loss = loss_fn(
                preds.reshape(-1, LIBRISPEECH_NUM_PHONEMES), y.reshape(-1)
            )
            loss.backward()
            optimizer.step()

            writer.log_scalar(loss, tag="CrossEntropyLoss-Train", global_step=iter_id)
            iter_id += 1

        # validation: average per-batch loss/accuracy weighted by batch size
        val_loss, val_acc, n = 0.0, 0.0, 0
        classifier.eval()
        with torch.no_grad():
            for batch in val_loader:
                preds, y, batch_len = _predict_aligned(
                    classifier, batch, device, lookahead_frames
                )
                n += batch_len
                val_loss += loss_fn(
                    preds.reshape(-1, LIBRISPEECH_NUM_PHONEMES), y.reshape(-1)
                ) * batch_len
                val_acc += batch_len * (
                    torch.argmax(preds, dim=2) == y
                ).flatten().float().mean()
        val_loss /= n
        val_acc /= n

        # log both validation metrics at the same step so their curves align
        writer.log_scalar(val_loss, tag="CrossEntropyLoss-Val", global_step=iter_id)
        writer.log_scalar(val_acc, tag="Accuracy-Val", global_step=iter_id)

        # save encoder weights whenever validation loss improves
        if val_loss < min_val_loss:
            min_val_loss = val_loss
            print(f'new best val loss {val_loss}; saving weights')
            writer.checkpoint(classifier[0].state_dict(), 'phoneme_classifier')

    # generate confusion matrix
    # NOTE: this evaluates the weights left by the final epoch; the best
    # checkpoint saved above is not reloaded here.
    classifier.eval()

    # collect frame-level predictions and targets on validation data
    all_preds = []
    all_true = []
    with torch.no_grad():
        for batch in val_loader:
            preds, y, _ = _predict_aligned(
                classifier, batch, device, lookahead_frames
            )
            all_preds.append(preds.argmax(dim=2).reshape(-1))
            all_true.append(y.reshape(-1))

    # compile predictions and targets
    all_preds = torch.cat(all_preds, dim=0).cpu().numpy()
    all_true = torch.cat(all_true, dim=0).cpu().numpy()

    # map class index -> phoneme symbol, with index 0 labelled as silence
    # NOTE(review): only the keys and the size of this mapping are used
    # below; the heatmap axes show integer class indices, not symbols.
    reverse_dict = {v: k for (k, v) in LIBRISPEECH_PHONEME_DICT.items() if v != 0}
    reverse_dict[0] = 'sil'

    class_report = classification_report(all_true, all_preds)
    writer.log_info(class_report)

    cm = confusion_matrix(all_true, all_preds, labels=list(range(len(reverse_dict))))
    class_ids = sorted(reverse_dict)
    df_cm = pd.DataFrame(cm, index=class_ids, columns=class_ids)

    fig = plt.figure(figsize=(40, 28))
    sn.set(font_scale=1.0)  # for label size
    # shrink cell annotations as the number of classes grows
    sn.heatmap(df_cm, annot=True, annot_kws={"size": 35 / np.sqrt(len(cm))}, fmt='g')
    plt.savefig("phoneme_cm.png", dpi=200)
    # release the large figure once it has been written to disk
    plt.close(fig)
# script entry point: start PPG training only when run directly, not on import
if __name__ == "__main__":
    main()