# speaker/train.py — training / evaluation script for the speaker encoder.
import torch
import torchaudio.datasets as datasets
import torchaudio.transforms as transforms
from speaker.data import SpeakerMelLoader
from speaker.model import SpeakerEncoder
from speaker.utils import get_mapping_array
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score
from matplotlib import pyplot as plt
import os
from os import path
import numpy as np
# Output layout for training diagnostics: all figures are written under
# `diagrams/`, grouped into one subdirectory per metric type.
diagram_path = 'diagrams'
accuracy_path = 'accuracy'
loss_path = 'loss'
silhouette_path = 'silhouette'
tsne_path = 'tsne'
def load_data(directory=".", batch_size=4, format='speaker', utter_per_speaker=4, mel_type='Tacotron'):
    """Return a DataLoader over the LibriSpeech training set, wrapped in SpeakerMelLoader.

    `batch_size` counts speakers per batch when format == 'speaker';
    each speaker contributes `utter_per_speaker` utterances.
    """
    librispeech = datasets.LIBRISPEECH(directory, download=True)
    mel_dataset = SpeakerMelLoader(librispeech, format, utter_per_speaker, mel_type=mel_type)
    return torch.utils.data.DataLoader(mel_dataset, batch_size, num_workers=4, shuffle=True)
def load_validation(directory=".", batch_size=4, format='speaker', utter_per_speaker=4, mel_type='Tacotron'):
    """Return a DataLoader over the LibriSpeech 'dev-clean' split, wrapped in SpeakerMelLoader.

    Mirrors load_data() but reads the validation subset.
    """
    librispeech = datasets.LIBRISPEECH(directory, "dev-clean", download=True)
    mel_dataset = SpeakerMelLoader(librispeech, format, utter_per_speaker, mel_type=mel_type)
    return torch.utils.data.DataLoader(mel_dataset, batch_size, num_workers=4, shuffle=True)
def _plot_metric(values, epoch, title, ylabel, subdir, prefix):
    """Plot a per-epoch metric curve and save it under diagrams/<subdir>/<prefix>_<epoch>.png."""
    plt.figure()
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.plot(np.arange(epoch) + 1, values)
    plt.savefig(path.join(diagram_path, subdir, f'{prefix}_{epoch}.png'))
    plt.close()

def train(speaker_per_batch=4, utter_per_speaker=4, epochs=2, learning_rate=1e-4, mel_type='Tacotron'):
    """Train a SpeakerEncoder on LibriSpeech and validate it each epoch.

    Per epoch: one pass over the training set optimizing the GE2E-style
    softmax loss, then a validation pass collecting loss, accuracy and
    silhouette scores (per speaker and per gender), T-SNE plots, a model
    checkpoint, and the metric curves.

    Returns the trained SpeakerEncoder.
    """
    # Init data loaders
    train_loader = load_data(".", speaker_per_batch, 'speaker', utter_per_speaker, mel_type=mel_type)
    valid_loader = load_validation(".", speaker_per_batch, 'speaker', utter_per_speaker, mel_type=mel_type)

    # Device; loss calc may run faster on cpu
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loss_device = torch.device("cpu")

    # Init model and optimizer
    model = SpeakerEncoder(device, loss_device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # One entry appended per epoch
    sil_scores = np.zeros(0)
    gender_scores = np.zeros(0)
    val_losses = np.zeros(0)
    val_accuracy = np.zeros(0)
    gender_mapper = get_mapping_array()

    for e in range(epochs):
        print('epoch:', e+1, 'of', epochs)

        # --- training pass ---
        model.train()
        for step, batch in enumerate(train_loader):
            # inputs: (speaker, utter, mel_len, mel_channel)
            speaker_id, inputs = batch
            # embed_inputs: (speaker*utter, mel_len, mel_channel)
            embed_inputs = inputs.reshape(-1, *(inputs.shape[2:])).to(device)
            # embeds: (speaker*utter, embed_dim)
            embeds = model(embed_inputs)
            # loss_embeds: (speaker, utter, embed_dim)
            loss_embeds = embeds.view((speaker_per_batch, utter_per_speaker, -1)).to(loss_device)
            loss = model.softmax_loss(loss_embeds)
            if step % 10 == 0:
                print('train e{}-s{}:'.format(e + 1, step + 1), 'loss', loss)
            model.zero_grad()
            loss.backward()
            model.gradient_clipping()
            optimizer.step()

        # --- validation pass ---
        model.eval()
        loss = 0
        acc = 0
        valid_ids = np.zeros(0)
        valid_embeds = np.zeros((0, 256))
        for step, batch in enumerate(valid_loader):
            with torch.no_grad():
                speaker_id, inputs = batch
                embed_inputs = inputs.reshape(-1, *(inputs.shape[2:])).to(device)
                embeds = model(embed_inputs)
                loss_embeds = embeds.view((speaker_per_batch, utter_per_speaker, -1)).to(loss_device)
                loss += model.softmax_loss(loss_embeds)
                acc += model.accuracy(loss_embeds)
                valid_ids = np.concatenate((valid_ids, np.repeat(speaker_id, inputs.shape[1])))
                valid_embeds = np.concatenate((valid_embeds, embeds.to(loss_device).detach()))

        # Average over validation batches; silhouette on all collected embeddings
        val_losses = np.concatenate((val_losses, [loss.to(loss_device).detach() / (step + 1)]))
        val_accuracy = np.concatenate((val_accuracy, [acc.to(loss_device).detach() / (step + 1)]))
        sil_scores = np.concatenate((sil_scores, [silhouette_score(valid_embeds, valid_ids)]))
        gender_scores = np.concatenate((gender_scores, [silhouette_score(valid_embeds, gender_mapper[valid_ids.astype('int')])]))
        print('valid e{}'.format(e + 1), 'loss', val_losses[-1])
        print('valid e{}'.format(e + 1), 'accuracy', val_accuracy[-1])
        print('silhouette score', sil_scores[-1])
        print('gender silhouette score', gender_scores[-1])

        # Per-epoch artifacts: T-SNE plots, checkpoint, metric curves
        plot_speaker_embeddings(valid_embeds, valid_ids, f'tsne_e{e + 1}_speaker.png', f'T-SNE Plot: Epoch {e + 1}')
        plot_random_embeddings(valid_embeds, valid_ids, f'tsne_e{e + 1}_random.png', title=f'T-SNE Plot: Epoch {e + 1}')
        plot_gender_embeddings(valid_embeds, valid_ids, f'tsne_e{e + 1}_gender.png', f'T-SNE Plot: Epoch {e + 1}')
        save_model(model, path.join('speaker', f'saved_model_e{e + 1}.pt'))
        _plot_metric(sil_scores, e + 1, 'Silhouette Scores', 'Silhouette Score', silhouette_path, 'sil_scores')
        _plot_metric(gender_scores, e + 1, 'Silhouette Scores over Gender', 'Silhouette Score', silhouette_path, 'gender_scores')
        _plot_metric(val_losses, e + 1, 'Validation Loss', 'Loss', loss_path, 'val_losses')
        _plot_metric(val_accuracy, e + 1, 'Validation Accuracy', 'Accuracy', accuracy_path, 'val_accuracy')

    return model
def save_model(model, path):
    """Persist the model's parameter state dict to *path*."""
    state = model.state_dict()
    torch.save(state, path)
def load_model(path, device=None):
    """Instantiate a SpeakerEncoder and restore its weights from *path*.

    When *device* is None, picks CUDA if available, else CPU. The loss
    device is always CPU (matches training).
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loss_device = torch.device("cpu")
    model = SpeakerEncoder(device, loss_device)
    model.load_state_dict(torch.load(path))
    # For multi-gpu setups or loading a GPU checkpoint on cpu, see:
    # https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-across-devices
    # model.load_state_dict(torch.load(PATH, map_location=device))
    return model
def check_model(path):
    """Load a saved SpeakerEncoder and report validation metrics.

    Runs the dev-clean split through the model and prints average loss,
    average accuracy and the speaker silhouette score, then writes the
    three T-SNE diagnostic plots.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loss_device = torch.device("cpu")
    print('**loading model')
    model = load_model(path)
    print('**loading data')
    data = load_validation()
    print('**running model')
    # Fix: evaluate in eval mode with gradients disabled — the original ran
    # the forward passes with autograd tracking and in training mode.
    model.eval()
    loss_total = 0
    acc_total = 0
    all_ids = np.zeros(0)
    all_embeds = np.zeros((0, 256))
    with torch.no_grad():
        for step, batch in enumerate(data):
            speaker_id, inputs = batch
            print('batch:', step)
            # (speaker, utter, mel_len, mel_channel) -> (speaker*utter, mel_len, mel_channel)
            embed_inputs = inputs.reshape(-1, *(inputs.shape[2:])).to(device)
            embeds = model(embed_inputs)
            # Regroup per speaker for the loss/accuracy calculation
            loss_embeds = embeds.view(*(inputs.shape[:2]), -1).to(loss_device)
            loss = model.softmax_loss(loss_embeds)
            accuracy = model.accuracy(loss_embeds)
            all_ids = np.concatenate((all_ids, np.repeat(speaker_id, inputs.shape[1])))
            all_embeds = np.concatenate((all_embeds, embeds.to(loss_device).detach()))
            loss_total += loss
            acc_total += accuracy
    print('average loss', loss_total / (step + 1))
    print('average accuracy', acc_total / (step + 1))
    print('silhouette score', silhouette_score(all_embeds, all_ids))
    plot_speaker_embeddings(all_embeds, all_ids, f'tsne_saved_speaker.png', f'T-SNE Plot')
    plot_random_embeddings(all_embeds, all_ids, f'tsne_saved_random.png', title=f'T-SNE Plot')
    plot_gender_embeddings(all_embeds, all_ids, f'tsne_saved_gender.png', f'T-SNE Plot')
def plot_gender_embeddings(embeddings, ids, filename, title='T-SNE Plot'):
    """Scatter a 2-D T-SNE projection of the embeddings, coloured by speaker gender.

    Per https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html
    dimensionality is first reduced with PCA before running T-SNE.
    """
    projected = TSNE(init='pca', learning_rate='auto').fit_transform(
        PCA(50).fit_transform(embeddings))
    genders = get_mapping_array()[ids.astype('int')]
    plt.figure()
    plt.title(title)
    # Gender codes from get_mapping_array(): 1 = Female, 2 = Male
    for code, label in ((1, 'Female'), (2, 'Male')):
        mask = genders == code
        plt.scatter(projected[mask, 0], projected[mask, 1], label=label)
    plt.legend()
    plt.grid()
    plt.savefig(path.join(diagram_path, tsne_path, filename))
    plt.close()
def plot_speaker_embeddings(embeddings, ids, filename, title='T-SNE Plot'):
    """Scatter a 2-D T-SNE projection with one colour per speaker.

    Per https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html
    dimensionality is first reduced with PCA before running T-SNE.
    """
    reduced = PCA(50).fit_transform(embeddings)
    points = TSNE(init='pca', learning_rate='auto').fit_transform(reduced)
    int_ids = ids.astype('int')
    plt.figure()
    plt.title(f'{title} Speakers')
    for speaker_id in np.unique(int_ids):
        mask = int_ids == speaker_id
        plt.scatter(points[mask, 0], points[mask, 1], label=f'Speaker {speaker_id}')
    # Legend omitted: too many speakers to label usefully
    plt.grid()
    plt.savefig(path.join(diagram_path, tsne_path, filename))
    plt.close()
def plot_random_embeddings(embeddings, ids, filename, size=15, title='T-SNE Plot Random'):
    """Scatter a 2-D T-SNE projection for a random subset of at most *size* speakers.

    Per https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html
    dimensionality is first reduced with PCA before running T-SNE.
    """
    pca = PCA(50)
    reduction = pca.fit_transform(embeddings)
    tsne = TSNE(init='pca', learning_rate='auto')
    transformed = tsne.fit_transform(reduction)
    ids = ids.astype('int')
    unique_ids = np.unique(ids)
    # Fix: sample from the *unique* speaker ids. The original sampled from
    # `ids` (which repeats each speaker once per utterance), so replace=False
    # only prevented reusing an array position — the same speaker could still
    # be drawn twice, under-counting the speakers actually plotted.
    random_unique_ids = np.random.choice(unique_ids, size=min(size, unique_ids.size), replace=False)
    plt.figure()
    plt.title(f'{title} - {random_unique_ids.size} Speakers')
    for speaker_id in random_unique_ids:
        speaker_idx = ids == speaker_id
        plt.scatter(transformed[speaker_idx, 0], transformed[speaker_idx, 1], label=f'Speaker {speaker_id}')
    plt.grid()
    plt.savefig(path.join(diagram_path, tsne_path, filename))
    plt.close()
if __name__ == '__main__':
    # Ensure every diagram output directory exists before training starts.
    os.makedirs(diagram_path, exist_ok=True)
    for subdir in (loss_path, accuracy_path, tsne_path, silhouette_path):
        os.makedirs(path.join(diagram_path, subdir), exist_ok=True)
    # Might make sense to adjust speaker / utterance per batch, e.g. 64/10
    m = train(epochs=300)
    check_model('speaker/saved_model_e175.pt')