# voiceblock/voicebox/scripts/experiments/train_phoneme_predictor.py
# Source: ALeLacheur — Voiceblock demo: Attempt 8 (commit 957e2dc)
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from src.models.phoneme import PPGEncoder
from src.constants import LIBRISPEECH_NUM_PHONEMES, LIBRISPEECH_PHONEME_DICT
from src.data import LibriSpeechDataset
from src.utils.writer import Writer
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt
################################################################################
# Train a simple model to produce phonetic posteriorgrams (PPGs)
################################################################################
def main():
    """Train a simple frame-level phoneme classifier producing PPGs.

    Trains a ``PPGEncoder`` plus a linear classification head on LibriSpeech
    phoneme targets, logs train/validation cross-entropy and validation
    accuracy, checkpoints the encoder weights whenever validation loss
    improves, and finally writes a confusion matrix over the validation
    split to ``phoneme_cm.png``.
    """
    # training hyperparameters
    lr = 1e-3
    epochs = 60
    batch_size = 250
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # phoneme encoder hyperparameters
    lstm_depth = 2
    hidden_size = 128  # 512
    win_length = 256
    hop_length = 128
    n_mels = 32
    n_mfcc = 19
    lookahead_frames = 0  # 1

    # datasets and loaders
    train_data = LibriSpeechDataset(
        split='train-clean-100',
        target='phoneme',
        features=None,
        hop_length=hop_length
    )
    val_data = LibriSpeechDataset(
        split='test-clean',
        target='phoneme',
        features=None,
        hop_length=hop_length
    )
    train_loader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True)
    val_loader = DataLoader(
        val_data,
        batch_size=batch_size)

    # initialize phoneme encoder
    encoder = PPGEncoder(
        win_length=win_length,
        hop_length=hop_length,
        win_func=torch.hann_window,
        n_mels=n_mels,
        n_mfcc=n_mfcc,
        lstm_depth=lstm_depth,
        hidden_size=hidden_size,
    )

    # classification head on top of the encoder, wrapped as a single module;
    # classifier[0] is the encoder (checkpointed separately below)
    classifier = nn.Sequential(
        encoder,
        nn.Linear(hidden_size, LIBRISPEECH_NUM_PHONEMES)
    ).to(device)

    # log training progress
    writer = Writer(
        name=f"phoneme_lookahead_{lookahead_frames}",
        use_tb=True,
        log_iter=len(train_loader)
    )
    # count trainable encoder parameters (head excluded, as logged before);
    # FIX: the builtin `sum` is not shadowed here, so the previous
    # `import builtins; builtins.sum(...)` detour was unnecessary, and
    # `p.numel()` replaces the roundabout `p.shape.numel()`
    parameter_count = sum(
        p.numel() for p in classifier[0].parameters() if p.requires_grad
    )
    writer.log_info(f'Training PPG model with lookahead {lookahead_frames}'
                    f' ({parameter_count} parameters)')

    # initialize optimizer and loss function
    optimizer = torch.optim.Adam(classifier.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    iter_id = 0
    min_val_loss = float('inf')
    for epoch in range(epochs):
        print(f'beginning epoch {epoch}')
        classifier.train()
        for batch in train_loader:
            optimizer.zero_grad(set_to_none=True)
            x, y = batch['x'].to(device), batch['y'].to(device)
            preds = classifier(x)
            # align labels/predictions for the configured lookahead:
            # prediction at frame t + lookahead corresponds to label at t
            # (with lookahead 0 the slices are no-ops)
            y = y[:, :-lookahead_frames if lookahead_frames else None]
            preds = preds[:, lookahead_frames:]
            # cross-entropy over all frames in the batch
            loss = loss_fn(
                preds.reshape(-1, LIBRISPEECH_NUM_PHONEMES), y.reshape(-1)
            )
            loss.backward()
            optimizer.step()
            writer.log_scalar(loss, tag="CrossEntropyLoss-Train", global_step=iter_id)
            iter_id += 1

        # validation pass
        val_loss, val_acc, n = 0.0, 0.0, 0
        classifier.eval()
        with torch.no_grad():
            for batch in val_loader:
                x, y = batch['x'].to(device), batch['y'].to(device)
                preds = classifier(x)
                # same lookahead alignment as in training
                y = y[:, :-lookahead_frames if lookahead_frames else None]
                preds = preds[:, lookahead_frames:]
                n += len(x)
                # batch-size-weighted running sums, normalized below
                val_loss += loss_fn(
                    preds.reshape(-1, LIBRISPEECH_NUM_PHONEMES), y.reshape(-1)
                ) * len(x)
                val_acc += len(x) * (torch.argmax(preds, dim=2) == y).flatten().float().mean()
        val_loss /= n
        val_acc /= n
        writer.log_scalar(val_loss, tag="CrossEntropyLoss-Val", global_step=iter_id)
        # FIX: pass global_step so validation accuracy is plotted against the
        # same x-axis as the loss curves (it was previously omitted here only)
        writer.log_scalar(val_acc, tag="Accuracy-Val", global_step=iter_id)

        # save encoder weights whenever validation loss improves
        if val_loss < min_val_loss:
            min_val_loss = val_loss
            print(f'new best val loss {val_loss}; saving weights')
            writer.checkpoint(classifier[0].state_dict(), 'phoneme_classifier')

    # generate confusion matrix over the validation split
    classifier.eval()
    all_preds = []
    all_true = []
    with torch.no_grad():
        for batch in val_loader:
            x, y = batch['x'].to(device), batch['y'].to(device)
            preds = classifier(x)
            # same lookahead alignment as in training
            y = y[:, :-lookahead_frames if lookahead_frames else None]
            preds = preds[:, lookahead_frames:]
            all_preds.append(preds.argmax(dim=2).reshape(-1))
            all_true.append(y.reshape(-1))

    # compile frame-level predictions and targets
    all_preds = torch.cat(all_preds, dim=0).cpu().numpy()
    all_true = torch.cat(all_true, dim=0).cpu().numpy()

    # map class indices back to phoneme names; index 0 is silence
    reverse_dict = {v: k for (k, v) in LIBRISPEECH_PHONEME_DICT.items() if v != 0}
    reverse_dict[0] = 'sil'
    class_report = classification_report(all_true, all_preds)
    writer.log_info(class_report)

    cm = confusion_matrix(all_true, all_preds, labels=list(range(len(reverse_dict))))
    # FIX: label the confusion-matrix axes with phoneme names instead of the
    # raw integer indices — reverse_dict was built for this and otherwise unused
    label_names = [reverse_dict[i] for i in sorted(reverse_dict)]
    df_cm = pd.DataFrame(cm, index=label_names, columns=label_names)
    plt.figure(figsize=(40, 28))
    sn.set(font_scale=1.0)  # for label size
    sn.heatmap(df_cm, annot=True, annot_kws={"size": 35 / np.sqrt(len(cm))}, fmt='g')
    plt.savefig("phoneme_cm.png", dpi=200)
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()