lord-reso committed
Commit b37f199 · verified · 1 Parent(s): 1dbe792

Upload 10 files
speaker/__init__.py ADDED
File without changes
speaker/data.py ADDED
@@ -0,0 +1,109 @@
+ import torch
+ import torchaudio.datasets as datasets
+ import torchaudio.transforms as transforms
+ from collections import defaultdict
+ import random
+ import layers
+
+ import warnings
+
+ class SpeakerMelLoader(torch.utils.data.Dataset):
+     """
+     Computes mel-spectrograms from audio files and pulls the speaker ID from the
+     dataset.
+     """
+
+     def __init__(self, dataset, format='speaker', speaker_utterances=4, mel_length=128, mel_type='Tacotron'):
+         self.dataset = dataset
+         self.set_format(format)
+         self.speaker_utterances = speaker_utterances
+         self.mel_length = mel_length
+         self.mel_type = mel_type
+         self.mel_generators = dict()
+
+     def set_format(self, format):
+         self.format = format
+
+         if format == 'speaker':
+             self.create_speaker_index()
+
+     def create_speaker_index(self):
+         vals = [x.split('-', 1) for x in self.dataset._walker]
+         speaker_map = defaultdict(list)
+
+         for i, v in enumerate(vals):
+             speaker_map[v[0]].append(i)
+
+         self.speaker_map = speaker_map
+         self.speaker_keys = list(speaker_map.keys())
+
+     def apply_mel_gen(self, waveform, sampling_rate, channels=80):
+         if (sampling_rate, channels) not in self.mel_generators:
+             if self.mel_type == 'MFCC':
+                 mel_gen = transforms.MFCC(sample_rate=sampling_rate, n_mfcc=channels)
+             elif self.mel_type == 'Mel':
+                 mel_gen = transforms.MelSpectrogram(sample_rate=sampling_rate, n_mels=channels)
+             elif self.mel_type == 'Tacotron':
+                 mel_gen = layers.TacotronSTFT(sampling_rate=sampling_rate, n_mel_channels=channels)
+             else:
+                 raise NotImplementedError('Unsupported mel_type in SpeakerMelLoader: ' + self.mel_type)
+             self.mel_generators[(sampling_rate, channels)] = mel_gen
+         else:
+             mel_gen = self.mel_generators[(sampling_rate, channels)]
+
+         if self.mel_type == 'Tacotron':
+             # Replicating the Tacotron2 data loader
+             max_wav_value = 32768.0
+             # Skip normalization from Tacotron2; LibriSpeech data looks pre-normalized (all values between 0-1)
+             audio_norm = waveform  # / max_wav_value
+             audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
+             melspec = mel_gen.mel_spectrogram(audio_norm)
+         else:
+             audio = waveform.unsqueeze(0)
+             audio = torch.autograd.Variable(audio, requires_grad=False)
+             melspec = mel_gen(audio)
+
+         return melspec
+
+     def get_mel(self, waveform, sampling_rate, channels=80):
+         # We previously identified that these warnings were OK.
+         with warnings.catch_warnings():
+             warnings.filterwarnings('ignore', message=r'At least one mel filterbank has all zero values.*', module=r'torchaudio.*')
+             melspec = self.apply_mel_gen(waveform, sampling_rate, channels)
+         # melspec is (1, 1, channels, time) by default
+         # return (time, channels)
+         melspec = torch.squeeze(melspec).T
+         return melspec
+
+     def __getitem__(self, index):
+         if self.format == 'utterance':
+             (waveform, sample_rate, _, speaker_id, _, _) = self.dataset[index]
+             mel = self.get_mel(waveform, sample_rate)
+             return (speaker_id, mel)
+         elif self.format == 'speaker':
+             speaker_id = self.speaker_keys[index]
+             utter_indexes = random.sample(self.speaker_map[speaker_id], self.speaker_utterances)
+             mels = []
+             for i in utter_indexes:
+                 (waveform, sample_rate, _, speaker_id, _, _) = self.dataset[i]
+                 mel = self.get_mel(waveform, sample_rate)
+                 if mel.shape[0] < self.mel_length:
+                     # Zero-pad mel on the right to mel_length
+                     # pad_tuple is (dn start, dn end, dn-1 start, dn-1 end, ..., d1 start, d1 end)
+                     pad_tuple = (0, 0, 0, self.mel_length - mel.shape[0])
+                     mel = torch.nn.functional.pad(mel, pad_tuple)
+                     mel_frame = 0
+                 else:
+                     mel_frame = random.randint(0, mel.shape[0] - self.mel_length)
+                 mels.append(mel[mel_frame:mel_frame + self.mel_length, :])
+             return (speaker_id, torch.stack(mels, 0))
+         else:
+             raise NotImplementedError()
+
+     def __len__(self):
+         if self.format == 'utterance':
+             return len(self.dataset)
+         elif self.format == 'speaker':
+             return len(self.speaker_keys)
+         else:
+             raise NotImplementedError()
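
A minimal usage sketch (not part of the upload) of what the 'speaker' format returns; it assumes LibriSpeech train-clean-100 can be downloaded into the working directory and that NVIDIA's Tacotron2 `layers` module is importable:

    import torchaudio.datasets as datasets
    from speaker.data import SpeakerMelLoader

    # Index by speaker: each item is a fixed-length mel batch for one speaker
    dataset = SpeakerMelLoader(datasets.LIBRISPEECH(".", download=True),
                               format='speaker', speaker_utterances=4, mel_type='Tacotron')

    speaker_id, mels = dataset[0]
    # mels: (speaker_utterances, mel_length, 80), e.g. torch.Size([4, 128, 80])
    print(speaker_id, mels.shape)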
speaker/model.py ADDED
@@ -0,0 +1,191 @@
+ from torch import nn
+ import numpy as np
+ import torch
+ from torch.nn.utils import clip_grad_norm_
+
+ class SpeakerEncoder(nn.Module):
+     """ Learn speaker representations from speech utterances of arbitrary length.
+     """
+     def __init__(self, device, loss_device):
+         super().__init__()
+         self.loss_device = loss_device
+
+         # LSTM block consisting of 3 layers
+         # takes 80-channel log-mel spectrograms as input, projected to 256 dimensions
+         self.lstm = nn.LSTM(
+             input_size=80,
+             hidden_size=256,
+             num_layers=3,
+             batch_first=True,
+             dropout=0,
+             bidirectional=False
+         ).to(device)
+
+         self.linear = nn.Linear(in_features=256, out_features=256).to(device)
+         self.relu = nn.ReLU().to(device)
+         # epsilon term for numerical stability (i.e. avoiding division by 0)
+         self.epsilon = 1e-5
+
+         # Cosine similarity weights
+         self.sim_weight = nn.Parameter(torch.tensor([5.])).to(loss_device)
+         self.sim_bias = nn.Parameter(torch.tensor([-1.])).to(loss_device)
+
+     def forward(self, utterances, h_init=None, c_init=None):
+         # implements section 2.1 from https://arxiv.org/pdf/1806.04558.pdf
+         if h_init is None or c_init is None:
+             out, (hidden, cell) = self.lstm(utterances)
+         else:
+             out, (hidden, cell) = self.lstm(utterances, (h_init, c_init))
+
+         # compute speaker embedding from the hidden state of the final layer
+         final_hidden = hidden[-1]
+         speaker_embedding = self.relu(self.linear(final_hidden))
+
+         # L2-normalize the speaker embedding
+         speaker_embedding = speaker_embedding / (torch.norm(speaker_embedding, dim=1, keepdim=True) + self.epsilon)
+         return speaker_embedding
+
+     def gradient_clipping(self):
+         self.sim_weight.grad *= 0.01
+         self.sim_bias.grad *= 0.01
+
+         # PyTorch utility clips gradients whose norm is greater than max_norm
+         clip_grad_norm_(self.parameters(), max_norm=3, norm_type=2)
+
+     def similarity_matrix(self, embeds, debug=False):
+         # calculate s_ji,k from section 2.1 of the GE2E paper
+         # output matrix is the cosine similarity between each utterance and the centroid of each speaker
+         # embeds input size: (speakers, utterances, embedding size)
+
+         # Speaker centroids
+         # Equal to the average of the utterance embeddings for the speaker
+         # Used for negative examples (utterance compared to a false speaker)
+         # Equation 1 in the paper
+         # size: (speakers, 1, embedding size)
+         speaker_centroid = torch.mean(embeds, dim=1, keepdim=True)
+
+         # Utterance-exclusive centroids
+         # Equal to the average of the utterance embeddings for the speaker, excluding the ith utterance
+         # Used for positive examples (utterance compared to its true speaker; the centroid excludes the utterance)
+         # Equation 8 in the paper
+         # size: (speakers, utterances, embedding size)
+         num_utterance = embeds.shape[1]
+         utter_ex_centroid = (torch.sum(embeds, dim=1, keepdim=True) - embeds) / (num_utterance - 1)
+
+         if debug:
+             print("e", embeds.shape)
+             print(embeds)
+             print("sc", speaker_centroid.shape)
+             print(speaker_centroid)
+             print("uc", utter_ex_centroid.shape)
+             print(utter_ex_centroid)
+
+         # Create positive and negative masks
+         num_speaker = embeds.shape[0]
+         i = torch.eye(num_speaker, dtype=torch.int)
+         pos_mask = torch.where(i)
+         neg_mask = torch.where(1 - i)
+
+         if debug:
+             print("pm", len(pos_mask), len(pos_mask[0]))
+             print(pos_mask)
+             print("nm", len(neg_mask), len(neg_mask[0]))
+             print(neg_mask)
+
+         # Compile the similarity matrix
+         # size: (speakers, utterances, speakers)
+         # initial size is (speakers, speakers, utterances) for easier vectorization
+         sim_matrix = torch.zeros(num_speaker, num_speaker, num_utterance).to(self.loss_device)
+         sim_matrix[pos_mask] = nn.functional.cosine_similarity(embeds, utter_ex_centroid, dim=2)
+         sim_matrix[neg_mask] = nn.functional.cosine_similarity(embeds[neg_mask[0]], speaker_centroid[neg_mask[1]], dim=2)
+         if debug:
+             print("sm", sim_matrix.shape)
+             print("pos vals", sim_matrix[pos_mask])
+             print("neg vals", sim_matrix[neg_mask])
+             print(sim_matrix)
+
+         sim_matrix = sim_matrix.permute(0, 2, 1)
+
+         if debug:
+             print("sm", sim_matrix.shape)
+             print(sim_matrix)
+             print("cos sim weight", self.sim_weight)
+             print("cos sim bias", self.sim_bias)
+
+         # Apply weight / bias
+         sim_matrix = sim_matrix * self.sim_weight + self.sim_bias
+         return sim_matrix
+
+     def softmax_loss(self, embeds):
+         """
+         Computes the softmax loss as defined by eq. 6 in the GE2E paper.
+         :param embeds: shape (speakers, utterances, embedding size)
+         :return: computed softmax loss
+         """
+         # Per the GE2E paper, the softmax loss defined by eq. 6
+         # performs slightly better on Text-Independent Speaker
+         # Verification tasks.
+         # ref section 2.1 of the GE2E paper
+         speaker_count = embeds.shape[0]
+
+         # (speaker, utterance, speaker)
+         similarities = self.similarity_matrix(embeds)
+
+         # eq. 6
+         loss_matrix = -similarities[torch.arange(0, speaker_count), :, torch.arange(0, speaker_count)] + \
+             torch.log(torch.sum(torch.exp(similarities), dim=2))
+
+         # eq. 10
+         return torch.sum(loss_matrix)
+
+     def contrast_loss(self, embeds):
+         """
+         Computes the contrast loss as defined by eq. 7 in the GE2E paper.
+         :param embeds: shape (speakers, utterances, embedding size)
+         :return: computed contrast loss
+         """
+         # Per the GE2E paper, the contrast loss defined by eq. 7
+         # performs slightly better on Text-Dependent Speaker
+         # Verification tasks.
+         # ref section 2.1 of the GE2E paper
+         speaker_count, utterance_count = embeds.shape[0:2]
+
+         # (speaker, utterance, speaker)
+         similarities = self.similarity_matrix(embeds)
+
+         # Janky indexing to enforce k != j
+         mask = torch.ones(similarities.shape, dtype=torch.bool)
+         mask[torch.arange(speaker_count), :, torch.arange(speaker_count)] = False
+         closest_neighbors, _ = torch.max(similarities[mask].reshape(speaker_count, utterance_count, speaker_count - 1), dim=2)
+
+         # Positive influence over matching embeddings
+         matching_embedding = similarities[torch.arange(0, speaker_count), :, torch.arange(0, speaker_count)]
+
+         # eq. 7
+         loss_matrix = 1 - torch.sigmoid(matching_embedding) + torch.sigmoid(closest_neighbors)
+
+         # eq. 10
+         return torch.sum(loss_matrix)
+
+     def accuracy(self, embeds):
+         """
+         Computes argmax accuracy.
+         :param embeds: shape (speakers, utterances, embedding size)
+         :return: accuracy
+         """
+         num_speaker, num_utter = embeds.shape[:2]
+
+         similarities = self.similarity_matrix(embeds)
+         preds = torch.argmax(similarities, dim=2)
+         preds_one_hot = torch.nn.functional.one_hot(preds, num_classes=num_speaker)
+
+         actual = torch.arange(num_speaker).unsqueeze(1).repeat(1, num_utter)
+         actual_one_hot = torch.nn.functional.one_hot(actual, num_classes=num_speaker)
+
+         return torch.sum(preds_one_hot * actual_one_hot) / (num_speaker * num_utter)
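
A small shape check (not part of the upload) for the GE2E pieces above; it feeds random, unit-normalized embeddings through similarity_matrix, softmax_loss, and accuracy on the CPU to illustrate the expected (speakers, utterances, speakers) layout. The batch sizes below are illustrative only:

    import torch
    from speaker.model import SpeakerEncoder

    device = torch.device("cpu")
    model = SpeakerEncoder(device, loss_device=device)

    # 4 speakers x 4 utterances x 256-dim embeddings, L2-normalized like forward() output
    embeds = torch.randn(4, 4, 256)
    embeds = embeds / embeds.norm(dim=2, keepdim=True)

    sim = model.similarity_matrix(embeds)   # (4, 4, 4): utterances vs. speaker centroids
    loss = model.softmax_loss(embeds)       # scalar, eq. 6 / eq. 10 of the GE2E paper
    acc = model.accuracy(embeds)            # around chance level (1/4) for random embeddings
    print(sim.shape, loss.item(), acc.item())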
speaker/preprocess.py ADDED
@@ -0,0 +1 @@
+ # Reference https://github.com/CorentinJ/Real-Time-Voice-Cloning/blob/0713f860a3dd41afb56e83cff84dbdf589d5e11a/encoder/preprocess.py#L16
speaker/saved_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6ccc0abcd0fb77104be73e6675454a06e7797bf1d4a1177181c32b648e9d75a9
+ size 5697243
speaker/saved_model_e175.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:52ba80266b9f45fc3d825942aae40858eeaaa73994ba86e9ed017a533dc13323
+ size 5861083
speaker/speakers.txt ADDED
The diff for this file is too large to render. See raw diff
 
speaker/tacotron_mel_e10.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9799bc6035aa1e555968c1fb2f1ca8b8bb0cdb10f11875cb4cbc1411d811a59b
+ size 5861083
speaker/train.py ADDED
@@ -0,0 +1,329 @@
+ import torch
+ import torchaudio.datasets as datasets
+ import torchaudio.transforms as transforms
+ from speaker.data import SpeakerMelLoader
+ from speaker.model import SpeakerEncoder
+ from speaker.utils import get_mapping_array
+
+ from sklearn.manifold import TSNE
+ from sklearn.decomposition import PCA
+ from sklearn.metrics import silhouette_score
+
+ from matplotlib import pyplot as plt
+
+ import os
+ from os import path
+
+ import numpy as np
+
+ diagram_path = 'diagrams'
+ accuracy_path = 'accuracy'
+ loss_path = 'loss'
+ silhouette_path = 'silhouette'
+ tsne_path = 'tsne'
+
+
+ def load_data(directory=".", batch_size=4, format='speaker', utter_per_speaker=4, mel_type='Tacotron'):
+     dataset = SpeakerMelLoader(datasets.LIBRISPEECH(directory, download=True), format, utter_per_speaker, mel_type=mel_type)
+     return torch.utils.data.DataLoader(
+         dataset,
+         batch_size,
+         num_workers=4,
+         shuffle=True
+     )
+
+
+ def load_validation(directory=".", batch_size=4, format='speaker', utter_per_speaker=4, mel_type='Tacotron'):
+     dataset = SpeakerMelLoader(datasets.LIBRISPEECH(directory, "dev-clean", download=True), format, utter_per_speaker, mel_type=mel_type)
+     return torch.utils.data.DataLoader(
+         dataset,
+         batch_size,
+         num_workers=4,
+         shuffle=True
+     )
+
+
+ def train(speaker_per_batch=4, utter_per_speaker=4, epochs=2, learning_rate=1e-4, mel_type='Tacotron'):
+     # Init data loaders
+     train_loader = load_data(".", speaker_per_batch, 'speaker', utter_per_speaker, mel_type=mel_type)
+     valid_loader = load_validation(".", speaker_per_batch, 'speaker', utter_per_speaker, mel_type=mel_type)
+
+     # Devices
+     # Loss calculation may run faster on the CPU
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     loss_device = torch.device("cpu")
+
+     # Init model
+     model = SpeakerEncoder(device, loss_device)
+     optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+
+     sil_scores = np.zeros(0)
+     gender_scores = np.zeros(0)
+     val_losses = np.zeros(0)
+     val_accuracy = np.zeros(0)
+
+     gender_mapper = get_mapping_array()
+
+     # Train loop
+     for e in range(epochs):
+         print('epoch:', e + 1, 'of', epochs)
+
+         model.train()
+         # train_ids = np.zeros(0)
+         # train_embeds = np.zeros((0, 256))
+         for step, batch in enumerate(train_loader):
+             # Forward
+             # inputs: (speaker, utter, mel_len, mel_channel)
+             speaker_id, inputs = batch
+             # embed_inputs: (speaker*utter, mel_len, mel_channel)
+             embed_inputs = inputs.reshape(-1, *(inputs.shape[2:])).to(device)
+             # embeds: (speaker*utter, embed_dim)
+             embeds = model(embed_inputs)
+             # loss_embeds: (speaker, utter, embed_dim)
+             loss_embeds = embeds.view((speaker_per_batch, utter_per_speaker, -1)).to(loss_device)
+             loss = model.softmax_loss(loss_embeds)
+
+             if step % 10 == 0:
+                 print('train e{}-s{}:'.format(e + 1, step + 1), 'loss', loss)
+
+             # Backward
+             model.zero_grad()
+             loss.backward()
+             model.gradient_clipping()
+             optimizer.step()
+
+             # train_ids = np.concatenate((train_ids, np.repeat(speaker_id, inputs.shape[1])))
+             # train_embeds = np.concatenate((train_embeds, embeds))
+
+         model.eval()
+         loss = 0
+         acc = 0
+
+         valid_ids = np.zeros(0)
+         valid_embeds = np.zeros((0, 256))
+
+         for step, batch in enumerate(valid_loader):
+             with torch.no_grad():
+                 speaker_id, inputs = batch
+                 embed_inputs = inputs.reshape(-1, *(inputs.shape[2:])).to(device)
+                 embeds = model(embed_inputs)
+                 loss_embeds = embeds.view((speaker_per_batch, utter_per_speaker, -1)).to(loss_device)
+                 loss += model.softmax_loss(loss_embeds)
+                 acc += model.accuracy(loss_embeds)
+                 valid_ids = np.concatenate((valid_ids, np.repeat(speaker_id, inputs.shape[1])))
+                 valid_embeds = np.concatenate((valid_embeds, embeds.to(loss_device).detach()))
+
+         val_losses = np.concatenate((val_losses, [loss.to(loss_device).detach() / (step + 1)]))
+         val_accuracy = np.concatenate((val_accuracy, [acc.to(loss_device).detach() / (step + 1)]))
+         sil_scores = np.concatenate((sil_scores, [silhouette_score(valid_embeds, valid_ids)]))
+         gender_scores = np.concatenate((gender_scores, [silhouette_score(valid_embeds, gender_mapper[valid_ids.astype('int')])]))
+         print('valid e{}'.format(e + 1), 'loss', val_losses[-1])
+         print('valid e{}'.format(e + 1), 'accuracy', val_accuracy[-1])
+         print('silhouette score', sil_scores[-1])
+         print('gender silhouette score', gender_scores[-1])
+
+         plot_speaker_embeddings(valid_embeds, valid_ids, f'tsne_e{e + 1}_speaker.png', f'T-SNE Plot: Epoch {e + 1}')
+         plot_random_embeddings(valid_embeds, valid_ids, f'tsne_e{e + 1}_random.png', title=f'T-SNE Plot: Epoch {e + 1}')
+         plot_gender_embeddings(valid_embeds, valid_ids, f'tsne_e{e + 1}_gender.png', f'T-SNE Plot: Epoch {e + 1}')
+
+         save_model(model, path.join('speaker', f'saved_model_e{e + 1}.pt'))
+
+         plt.figure()
+         plt.title('Silhouette Scores')
+         plt.xlabel('Epoch')
+         plt.ylabel('Silhouette Score')
+         plt.plot(np.arange(e + 1) + 1, sil_scores)
+         # plt.show()
+         plt.savefig(path.join(diagram_path, silhouette_path, f'sil_scores_{e + 1}.png'))
+         plt.close()
+
+         plt.figure()
+         plt.title('Silhouette Scores over Gender')
+         plt.xlabel('Epoch')
+         plt.ylabel('Silhouette Score')
+         plt.plot(np.arange(e + 1) + 1, gender_scores)
+         # plt.show()
+         plt.savefig(path.join(diagram_path, silhouette_path, f'gender_scores_{e + 1}.png'))
+         plt.close()
+
+         plt.figure()
+         plt.title('Validation Loss')
+         plt.xlabel('Epoch')
+         plt.ylabel('Loss')
+         plt.plot(np.arange(e + 1) + 1, val_losses)
+         # plt.show()
+         plt.savefig(path.join(diagram_path, loss_path, f'val_losses_{e + 1}.png'))
+         plt.close()
+
+         plt.figure()
+         plt.title('Validation Accuracy')
+         plt.xlabel('Epoch')
+         plt.ylabel('Accuracy')
+         plt.plot(np.arange(e + 1) + 1, val_accuracy)
+         # plt.show()
+         plt.savefig(path.join(diagram_path, accuracy_path, f'val_accuracy_{e + 1}.png'))
+         plt.close()
+
+     return model
+
+
+ def save_model(model, path):
+     # Save model state to path
+     torch.save(model.state_dict(), path)
+
+
+ def load_model(path, device=None):
+     # Instantiate model
+     if device is None:
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     loss_device = torch.device("cpu")
+     model = SpeakerEncoder(device, loss_device)
+
+     # Load model state
+     model.load_state_dict(torch.load(path))
+     # Try this if running on a multi-GPU setup or running the model on the CPU
+     # https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-across-devices
+     # model.load_state_dict(torch.load(PATH, map_location=device))
+     return model
+
+
+ def check_model(path):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     loss_device = torch.device("cpu")
+
+     print('**loading model')
+     model = load_model(path)
+
+     print('**loading data')
+     # data = load_data()
+     data = load_validation()
+
+     print('**running model')
+     loss_total = 0
+     acc_total = 0
+     all_ids = np.zeros(0)
+     all_embeds = np.zeros((0, 256))
+
+     for step, batch in enumerate(data):
+         speaker_id, inputs = batch
+
+         print('batch:', step)
+         embed_inputs = inputs.reshape(-1, *(inputs.shape[2:])).to(device)
+         embeds = model(embed_inputs)
+         loss_embeds = embeds.view(*(inputs.shape[:2]), -1).to(loss_device)
+         loss = model.softmax_loss(loss_embeds)
+         accuracy = model.accuracy(loss_embeds)
+
+         all_ids = np.concatenate((all_ids, np.repeat(speaker_id, inputs.shape[1])))
+         all_embeds = np.concatenate((all_embeds, embeds.to(loss_device).detach()))
+
+         loss_total += loss
+         acc_total += accuracy
+
+         # print('inputs.shape', inputs.shape)
+         # print('embed_inputs.shape', embed_inputs.shape)
+         # print('embeds.shape', embeds.shape)
+         # print('loss_embeds.shape', loss_embeds.shape)
+         # print('loss.shape', loss.shape)
+         # print('loss', loss)
+         # print('accuracy', accuracy)
+
+     print('average loss', loss_total / (step + 1))
+     print('average accuracy', acc_total / (step + 1))
+     print('silhouette score', silhouette_score(all_embeds, all_ids))
+     plot_speaker_embeddings(all_embeds, all_ids, 'tsne_saved_speaker.png', 'T-SNE Plot')
+     plot_random_embeddings(all_embeds, all_ids, 'tsne_saved_random.png', title='T-SNE Plot')
+     plot_gender_embeddings(all_embeds, all_ids, 'tsne_saved_gender.png', 'T-SNE Plot')
+
+
+ def plot_gender_embeddings(embeddings, ids, filename, title='T-SNE Plot'):
+     # Per https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html
+     # reduce dimensionality with PCA before running t-SNE
+     pca = PCA(50)
+     reduction = pca.fit_transform(embeddings)
+     tsne = TSNE(init='pca', learning_rate='auto')
+     transformed = tsne.fit_transform(reduction)
+
+     gender_mapper = get_mapping_array()
+     genders = gender_mapper[ids.astype('int')]
+     females = genders == 1
+     males = genders == 2
+
+     plt.figure()
+     plt.title(title)
+
+     plt.scatter(transformed[females, 0], transformed[females, 1], label='Female')
+     plt.scatter(transformed[males, 0], transformed[males, 1], label='Male')
+     plt.legend()
+     plt.grid()
+     # plt.show()
+     plt.savefig(path.join(diagram_path, tsne_path, filename))
+     plt.close()
+
+
+ def plot_speaker_embeddings(embeddings, ids, filename, title='T-SNE Plot'):
+     # Per https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html
+     # reduce dimensionality with PCA before running t-SNE
+     pca = PCA(50)
+     reduction = pca.fit_transform(embeddings)
+     tsne = TSNE(init='pca', learning_rate='auto')
+     transformed = tsne.fit_transform(reduction)
+
+     ids = ids.astype('int')
+     unique_ids = np.unique(ids)
+
+     plt.figure()
+     plt.title(f'{title} Speakers')
+
+     for speaker_id in unique_ids:
+         speaker_idx = ids == speaker_id
+         plt.scatter(transformed[speaker_idx, 0], transformed[speaker_idx, 1], label=f'Speaker {speaker_id}')
+
+     # plt.legend()
+     plt.grid()
+     # plt.show()
+     plt.savefig(path.join(diagram_path, tsne_path, filename))
+     plt.close()
+
+
+ def plot_random_embeddings(embeddings, ids, filename, size=15, title='T-SNE Plot Random'):
+     # Per https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html
+     # reduce dimensionality with PCA before running t-SNE
+     pca = PCA(50)
+     reduction = pca.fit_transform(embeddings)
+     tsne = TSNE(init='pca', learning_rate='auto')
+     transformed = tsne.fit_transform(reduction)
+
+     ids = ids.astype('int')
+     unique_ids = np.unique(ids)
+     random_unique_ids = np.random.choice(unique_ids, size=min(size, unique_ids.size), replace=False)
+
+     plt.figure()
+
+     plt.title(f'{title} - {random_unique_ids.size} Speakers')
+
+     for speaker_id in random_unique_ids:
+         speaker_idx = ids == speaker_id
+         plt.scatter(transformed[speaker_idx, 0], transformed[speaker_idx, 1], label=f'Speaker {speaker_id}')
+
+     # plt.legend()
+     plt.grid()
+     # plt.show()
+     plt.savefig(path.join(diagram_path, tsne_path, filename))
+     plt.close()
+
+
+ if __name__ == '__main__':
+     os.makedirs(diagram_path, exist_ok=True)
+     os.makedirs(path.join(diagram_path, loss_path), exist_ok=True)
+     os.makedirs(path.join(diagram_path, accuracy_path), exist_ok=True)
+     os.makedirs(path.join(diagram_path, tsne_path), exist_ok=True)
+     os.makedirs(path.join(diagram_path, silhouette_path), exist_ok=True)
+     # for speaker_id, mel in load_data():
+     #     print(speaker_id, mel.shape)
+
+     # It might make sense to adjust speakers / utterances per batch, e.g. 64/10
+     m = train(epochs=300)
+
+     # save_model(m, 'speaker/saved_model.pt')
+     check_model('speaker/saved_model_e175.pt')
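
For inspection outside the full training loop, a short sketch (not part of the upload, assuming the checkpoint below exists and dev-clean can be downloaded) that reloads a checkpoint and embeds one validation batch:

    import torch
    from speaker.train import load_model, load_validation

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # On a CPU-only machine a checkpoint saved from GPU may need the map_location
    # variant noted in load_model above.
    model = load_model('speaker/saved_model_e175.pt', device)
    model.eval()

    loader = load_validation(".")  # batches of (speaker, utter, mel_len, mel_channel)
    speaker_id, inputs = next(iter(loader))

    with torch.no_grad():
        embeds = model(inputs.reshape(-1, *inputs.shape[2:]).to(device))
    print(speaker_id, embeds.shape)  # (speakers * utterances, 256)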
speaker/utils.py ADDED
@@ -0,0 +1,28 @@
+
+ import numpy as np
+
+ __mapping_array = None
+
+ with open('speaker/speakers.txt') as speakers:
+     lines = []
+     for line in speakers.readlines():
+         if line[0] == ';':
+             continue
+         lines.append(line)
+
+ rows = [line.split('|') for line in lines]
+
+ __mapping_list = [(int(row[0].strip()), row[1].strip()) for row in rows]
+
+ max_id = max([speaker_id for (speaker_id, _) in __mapping_list])
+
+ __mapping_array = np.zeros(max_id + 1)
+ for speaker_id, gender in __mapping_list:
+     if gender == 'F':
+         __mapping_array[speaker_id] = 1
+     else:
+         __mapping_array[speaker_id] = 2
+
+
+ def get_mapping_array():
+     return np.copy(__mapping_array)
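
get_mapping_array returns an array indexed by LibriSpeech speaker ID (1 for female, 2 otherwise, per speakers.txt), so gender labels for a batch of IDs can be looked up by fancy indexing. A small sketch, with placeholder IDs:

    import numpy as np
    from speaker.utils import get_mapping_array

    gender_mapper = get_mapping_array()
    speaker_ids = np.array([1272, 1462, 2035])  # placeholder LibriSpeech speaker IDs
    print(gender_mapper[speaker_ids])           # 1.0 = female, 2.0 = male, per speakers.txt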