Spaces:
Sleeping
Sleeping
Upload 10 files
Browse files- encoder/__init__.py +0 -0
- encoder/audio.py +117 -0
- encoder/config.py +45 -0
- encoder/inference.py +178 -0
- encoder/model.py +135 -0
- encoder/params_data.py +29 -0
- encoder/params_model.py +11 -0
- encoder/preprocess.py +184 -0
- encoder/train.py +125 -0
- encoder/visualizations.py +179 -0
encoder/__init__.py
ADDED
|
File without changes
|
encoder/audio.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from scipy.ndimage.morphology import binary_dilation
|
| 2 |
+
from encoder.params_data import *
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Optional, Union
|
| 5 |
+
from warnings import warn
|
| 6 |
+
import numpy as np
|
| 7 |
+
import librosa
|
| 8 |
+
import struct
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
import webrtcvad
|
| 12 |
+
except:
|
| 13 |
+
warn("Unable to import 'webrtcvad'. This package enables noise removal and is recommended.")
|
| 14 |
+
webrtcvad=None
|
| 15 |
+
|
| 16 |
+
int16_max = (2 ** 15) - 1
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
|
| 20 |
+
source_sr: Optional[int] = None,
|
| 21 |
+
normalize: Optional[bool] = True,
|
| 22 |
+
trim_silence: Optional[bool] = True):
|
| 23 |
+
"""
|
| 24 |
+
Applies the preprocessing operations used in training the Speaker Encoder to a waveform
|
| 25 |
+
either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
|
| 26 |
+
|
| 27 |
+
:param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
|
| 28 |
+
just .wav), either the waveform as a numpy array of floats.
|
| 29 |
+
:param source_sr: if passing an audio waveform, the sampling rate of the waveform before
|
| 30 |
+
preprocessing. After preprocessing, the waveform's sampling rate will match the data
|
| 31 |
+
hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
|
| 32 |
+
this argument will be ignored.
|
| 33 |
+
"""
|
| 34 |
+
# Load the wav from disk if needed
|
| 35 |
+
if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
|
| 36 |
+
wav, source_sr = librosa.load(str(fpath_or_wav), sr=None)
|
| 37 |
+
else:
|
| 38 |
+
wav = fpath_or_wav
|
| 39 |
+
|
| 40 |
+
# Resample the wav if needed
|
| 41 |
+
if source_sr is not None and source_sr != sampling_rate:
|
| 42 |
+
wav = librosa.resample(wav, source_sr, sampling_rate)
|
| 43 |
+
|
| 44 |
+
# Apply the preprocessing: normalize volume and shorten long silences
|
| 45 |
+
if normalize:
|
| 46 |
+
wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
|
| 47 |
+
if webrtcvad and trim_silence:
|
| 48 |
+
wav = trim_long_silences(wav)
|
| 49 |
+
|
| 50 |
+
return wav
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def wav_to_mel_spectrogram(wav):
|
| 54 |
+
"""
|
| 55 |
+
Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
|
| 56 |
+
Note: this not a log-mel spectrogram.
|
| 57 |
+
"""
|
| 58 |
+
frames = librosa.feature.melspectrogram(
|
| 59 |
+
wav,
|
| 60 |
+
sampling_rate,
|
| 61 |
+
n_fft=int(sampling_rate * mel_window_length / 1000),
|
| 62 |
+
hop_length=int(sampling_rate * mel_window_step / 1000),
|
| 63 |
+
n_mels=mel_n_channels
|
| 64 |
+
)
|
| 65 |
+
return frames.astype(np.float32).T
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def trim_long_silences(wav):
|
| 69 |
+
"""
|
| 70 |
+
Ensures that segments without voice in the waveform remain no longer than a
|
| 71 |
+
threshold determined by the VAD parameters in params.py.
|
| 72 |
+
|
| 73 |
+
:param wav: the raw waveform as a numpy array of floats
|
| 74 |
+
:return: the same waveform with silences trimmed away (length <= original wav length)
|
| 75 |
+
"""
|
| 76 |
+
# Compute the voice detection window size
|
| 77 |
+
samples_per_window = (vad_window_length * sampling_rate) // 1000
|
| 78 |
+
|
| 79 |
+
# Trim the end of the audio to have a multiple of the window size
|
| 80 |
+
wav = wav[:len(wav) - (len(wav) % samples_per_window)]
|
| 81 |
+
|
| 82 |
+
# Convert the float waveform to 16-bit mono PCM
|
| 83 |
+
pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
|
| 84 |
+
|
| 85 |
+
# Perform voice activation detection
|
| 86 |
+
voice_flags = []
|
| 87 |
+
vad = webrtcvad.Vad(mode=3)
|
| 88 |
+
for window_start in range(0, len(wav), samples_per_window):
|
| 89 |
+
window_end = window_start + samples_per_window
|
| 90 |
+
voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
|
| 91 |
+
sample_rate=sampling_rate))
|
| 92 |
+
voice_flags = np.array(voice_flags)
|
| 93 |
+
|
| 94 |
+
# Smooth the voice detection with a moving average
|
| 95 |
+
def moving_average(array, width):
|
| 96 |
+
array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
|
| 97 |
+
ret = np.cumsum(array_padded, dtype=float)
|
| 98 |
+
ret[width:] = ret[width:] - ret[:-width]
|
| 99 |
+
return ret[width - 1:] / width
|
| 100 |
+
|
| 101 |
+
audio_mask = moving_average(voice_flags, vad_moving_average_width)
|
| 102 |
+
audio_mask = np.round(audio_mask).astype(np.bool)
|
| 103 |
+
|
| 104 |
+
# Dilate the voiced regions
|
| 105 |
+
audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
|
| 106 |
+
audio_mask = np.repeat(audio_mask, samples_per_window)
|
| 107 |
+
|
| 108 |
+
return wav[audio_mask == True]
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
|
| 112 |
+
if increase_only and decrease_only:
|
| 113 |
+
raise ValueError("Both increase only and decrease only are set")
|
| 114 |
+
dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
|
| 115 |
+
if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
|
| 116 |
+
return wav
|
| 117 |
+
return wav * (10 ** (dBFS_change / 20))
|
encoder/config.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
librispeech_datasets = {
|
| 2 |
+
"train": {
|
| 3 |
+
"clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
|
| 4 |
+
"other": ["LibriSpeech/train-other-500"]
|
| 5 |
+
},
|
| 6 |
+
"test": {
|
| 7 |
+
"clean": ["LibriSpeech/test-clean"],
|
| 8 |
+
"other": ["LibriSpeech/test-other"]
|
| 9 |
+
},
|
| 10 |
+
"dev": {
|
| 11 |
+
"clean": ["LibriSpeech/dev-clean"],
|
| 12 |
+
"other": ["LibriSpeech/dev-other"]
|
| 13 |
+
},
|
| 14 |
+
}
|
| 15 |
+
libritts_datasets = {
|
| 16 |
+
"train": {
|
| 17 |
+
"clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
|
| 18 |
+
"other": ["LibriTTS/train-other-500"]
|
| 19 |
+
},
|
| 20 |
+
"test": {
|
| 21 |
+
"clean": ["LibriTTS/test-clean"],
|
| 22 |
+
"other": ["LibriTTS/test-other"]
|
| 23 |
+
},
|
| 24 |
+
"dev": {
|
| 25 |
+
"clean": ["LibriTTS/dev-clean"],
|
| 26 |
+
"other": ["LibriTTS/dev-other"]
|
| 27 |
+
},
|
| 28 |
+
}
|
| 29 |
+
voxceleb_datasets = {
|
| 30 |
+
"voxceleb1" : {
|
| 31 |
+
"train": ["VoxCeleb1/wav"],
|
| 32 |
+
"test": ["VoxCeleb1/test_wav"]
|
| 33 |
+
},
|
| 34 |
+
"voxceleb2" : {
|
| 35 |
+
"train": ["VoxCeleb2/dev/aac"],
|
| 36 |
+
"test": ["VoxCeleb2/test_wav"]
|
| 37 |
+
}
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
other_datasets = [
|
| 41 |
+
"LJSpeech-1.1",
|
| 42 |
+
"VCTK-Corpus/wav48",
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
|
encoder/inference.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from encoder.params_data import *
|
| 2 |
+
from encoder.model import SpeakerEncoder
|
| 3 |
+
from encoder.audio import preprocess_wav # We want to expose this function from here
|
| 4 |
+
from matplotlib import cm
|
| 5 |
+
from encoder import audio
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
_model = None # type: SpeakerEncoder
|
| 11 |
+
_device = None # type: torch.device
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def load_model(weights_fpath: Path, device=None):
|
| 15 |
+
"""
|
| 16 |
+
Loads the model in memory. If this function is not explicitely called, it will be run on the
|
| 17 |
+
first call to embed_frames() with the default weights file.
|
| 18 |
+
|
| 19 |
+
:param weights_fpath: the path to saved model weights.
|
| 20 |
+
:param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
|
| 21 |
+
model will be loaded and will run on this device. Outputs will however always be on the cpu.
|
| 22 |
+
If None, will default to your GPU if it"s available, otherwise your CPU.
|
| 23 |
+
"""
|
| 24 |
+
# TODO: I think the slow loading of the encoder might have something to do with the device it
|
| 25 |
+
# was saved on. Worth investigating.
|
| 26 |
+
global _model, _device
|
| 27 |
+
if device is None:
|
| 28 |
+
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 29 |
+
elif isinstance(device, str):
|
| 30 |
+
_device = torch.device(device)
|
| 31 |
+
_model = SpeakerEncoder(_device, torch.device("cpu"))
|
| 32 |
+
checkpoint = torch.load(weights_fpath, _device)
|
| 33 |
+
_model.load_state_dict(checkpoint["model_state"])
|
| 34 |
+
_model.eval()
|
| 35 |
+
print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def is_loaded():
|
| 39 |
+
return _model is not None
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def embed_frames_batch(frames_batch):
|
| 43 |
+
"""
|
| 44 |
+
Computes embeddings for a batch of mel spectrogram.
|
| 45 |
+
|
| 46 |
+
:param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape
|
| 47 |
+
(batch_size, n_frames, n_channels)
|
| 48 |
+
:return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
|
| 49 |
+
"""
|
| 50 |
+
if _model is None:
|
| 51 |
+
raise Exception("Model was not loaded. Call load_model() before inference.")
|
| 52 |
+
|
| 53 |
+
frames = torch.from_numpy(frames_batch).to(_device)
|
| 54 |
+
embed = _model.forward(frames).detach().cpu().numpy()
|
| 55 |
+
return embed
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
|
| 59 |
+
min_pad_coverage=0.75, overlap=0.5):
|
| 60 |
+
"""
|
| 61 |
+
Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
|
| 62 |
+
partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
|
| 63 |
+
spectrogram slices are returned, so as to make each partial utterance waveform correspond to
|
| 64 |
+
its spectrogram. This function assumes that the mel spectrogram parameters used are those
|
| 65 |
+
defined in params_data.py.
|
| 66 |
+
|
| 67 |
+
The returned ranges may be indexing further than the length of the waveform. It is
|
| 68 |
+
recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
|
| 69 |
+
|
| 70 |
+
:param n_samples: the number of samples in the waveform
|
| 71 |
+
:param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
|
| 72 |
+
utterance
|
| 73 |
+
:param min_pad_coverage: when reaching the last partial utterance, it may or may not have
|
| 74 |
+
enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
|
| 75 |
+
then the last partial utterance will be considered, as if we padded the audio. Otherwise,
|
| 76 |
+
it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
|
| 77 |
+
utterance, this parameter is ignored so that the function always returns at least 1 slice.
|
| 78 |
+
:param overlap: by how much the partial utterance should overlap. If set to 0, the partial
|
| 79 |
+
utterances are entirely disjoint.
|
| 80 |
+
:return: the waveform slices and mel spectrogram slices as lists of array slices. Index
|
| 81 |
+
respectively the waveform and the mel spectrogram with these slices to obtain the partial
|
| 82 |
+
utterances.
|
| 83 |
+
"""
|
| 84 |
+
assert 0 <= overlap < 1
|
| 85 |
+
assert 0 < min_pad_coverage <= 1
|
| 86 |
+
|
| 87 |
+
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
|
| 88 |
+
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
|
| 89 |
+
frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
|
| 90 |
+
|
| 91 |
+
# Compute the slices
|
| 92 |
+
wav_slices, mel_slices = [], []
|
| 93 |
+
steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
|
| 94 |
+
for i in range(0, steps, frame_step):
|
| 95 |
+
mel_range = np.array([i, i + partial_utterance_n_frames])
|
| 96 |
+
wav_range = mel_range * samples_per_frame
|
| 97 |
+
mel_slices.append(slice(*mel_range))
|
| 98 |
+
wav_slices.append(slice(*wav_range))
|
| 99 |
+
|
| 100 |
+
# Evaluate whether extra padding is warranted or not
|
| 101 |
+
last_wav_range = wav_slices[-1]
|
| 102 |
+
coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
|
| 103 |
+
if coverage < min_pad_coverage and len(mel_slices) > 1:
|
| 104 |
+
mel_slices = mel_slices[:-1]
|
| 105 |
+
wav_slices = wav_slices[:-1]
|
| 106 |
+
|
| 107 |
+
return wav_slices, mel_slices
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
|
| 111 |
+
"""
|
| 112 |
+
Computes an embedding for a single utterance.
|
| 113 |
+
|
| 114 |
+
# TODO: handle multiple wavs to benefit from batching on GPU
|
| 115 |
+
:param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
|
| 116 |
+
:param using_partials: if True, then the utterance is split in partial utterances of
|
| 117 |
+
<partial_utterance_n_frames> frames and the utterance embedding is computed from their
|
| 118 |
+
normalized average. If False, the utterance is instead computed from feeding the entire
|
| 119 |
+
spectogram to the network.
|
| 120 |
+
:param return_partials: if True, the partial embeddings will also be returned along with the
|
| 121 |
+
wav slices that correspond to the partial embeddings.
|
| 122 |
+
:param kwargs: additional arguments to compute_partial_splits()
|
| 123 |
+
:return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
|
| 124 |
+
<return_partials> is True, the partial utterances as a numpy array of float32 of shape
|
| 125 |
+
(n_partials, model_embedding_size) and the wav partials as a list of slices will also be
|
| 126 |
+
returned. If <using_partials> is simultaneously set to False, both these values will be None
|
| 127 |
+
instead.
|
| 128 |
+
"""
|
| 129 |
+
# Process the entire utterance if not using partials
|
| 130 |
+
if not using_partials:
|
| 131 |
+
frames = audio.wav_to_mel_spectrogram(wav)
|
| 132 |
+
embed = embed_frames_batch(frames[None, ...])[0]
|
| 133 |
+
if return_partials:
|
| 134 |
+
return embed, None, None
|
| 135 |
+
return embed
|
| 136 |
+
|
| 137 |
+
# Compute where to split the utterance into partials and pad if necessary
|
| 138 |
+
wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
|
| 139 |
+
max_wave_length = wave_slices[-1].stop
|
| 140 |
+
if max_wave_length >= len(wav):
|
| 141 |
+
wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
|
| 142 |
+
|
| 143 |
+
# Split the utterance into partials
|
| 144 |
+
frames = audio.wav_to_mel_spectrogram(wav)
|
| 145 |
+
frames_batch = np.array([frames[s] for s in mel_slices])
|
| 146 |
+
partial_embeds = embed_frames_batch(frames_batch)
|
| 147 |
+
|
| 148 |
+
# Compute the utterance embedding from the partial embeddings
|
| 149 |
+
raw_embed = np.mean(partial_embeds, axis=0)
|
| 150 |
+
embed = raw_embed / np.linalg.norm(raw_embed, 2)
|
| 151 |
+
|
| 152 |
+
if return_partials:
|
| 153 |
+
return embed, partial_embeds, wave_slices
|
| 154 |
+
return embed
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def embed_speaker(wavs, **kwargs):
|
| 158 |
+
raise NotImplemented()
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
|
| 162 |
+
import matplotlib.pyplot as plt
|
| 163 |
+
if ax is None:
|
| 164 |
+
ax = plt.gca()
|
| 165 |
+
|
| 166 |
+
if shape is None:
|
| 167 |
+
height = int(np.sqrt(len(embed)))
|
| 168 |
+
shape = (height, -1)
|
| 169 |
+
embed = embed.reshape(shape)
|
| 170 |
+
|
| 171 |
+
cmap = cm.get_cmap()
|
| 172 |
+
mappable = ax.imshow(embed, cmap=cmap)
|
| 173 |
+
cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
|
| 174 |
+
sm = cm.ScalarMappable(cmap=cmap)
|
| 175 |
+
sm.set_clim(*color_range)
|
| 176 |
+
|
| 177 |
+
ax.set_xticks([]), ax.set_yticks([])
|
| 178 |
+
ax.set_title(title)
|
encoder/model.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from encoder.params_model import *
|
| 2 |
+
from encoder.params_data import *
|
| 3 |
+
from scipy.interpolate import interp1d
|
| 4 |
+
from sklearn.metrics import roc_curve
|
| 5 |
+
from torch.nn.utils import clip_grad_norm_
|
| 6 |
+
from scipy.optimize import brentq
|
| 7 |
+
from torch import nn
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class SpeakerEncoder(nn.Module):
|
| 13 |
+
def __init__(self, device, loss_device):
|
| 14 |
+
super().__init__()
|
| 15 |
+
self.loss_device = loss_device
|
| 16 |
+
|
| 17 |
+
# Network defition
|
| 18 |
+
self.lstm = nn.LSTM(input_size=mel_n_channels,
|
| 19 |
+
hidden_size=model_hidden_size,
|
| 20 |
+
num_layers=model_num_layers,
|
| 21 |
+
batch_first=True).to(device)
|
| 22 |
+
self.linear = nn.Linear(in_features=model_hidden_size,
|
| 23 |
+
out_features=model_embedding_size).to(device)
|
| 24 |
+
self.relu = torch.nn.ReLU().to(device)
|
| 25 |
+
|
| 26 |
+
# Cosine similarity scaling (with fixed initial parameter values)
|
| 27 |
+
self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
|
| 28 |
+
self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)
|
| 29 |
+
|
| 30 |
+
# Loss
|
| 31 |
+
self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
|
| 32 |
+
|
| 33 |
+
def do_gradient_ops(self):
|
| 34 |
+
# Gradient scale
|
| 35 |
+
self.similarity_weight.grad *= 0.01
|
| 36 |
+
self.similarity_bias.grad *= 0.01
|
| 37 |
+
|
| 38 |
+
# Gradient clipping
|
| 39 |
+
clip_grad_norm_(self.parameters(), 3, norm_type=2)
|
| 40 |
+
|
| 41 |
+
def forward(self, utterances, hidden_init=None):
|
| 42 |
+
"""
|
| 43 |
+
Computes the embeddings of a batch of utterance spectrograms.
|
| 44 |
+
|
| 45 |
+
:param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
|
| 46 |
+
(batch_size, n_frames, n_channels)
|
| 47 |
+
:param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
|
| 48 |
+
batch_size, hidden_size). Will default to a tensor of zeros if None.
|
| 49 |
+
:return: the embeddings as a tensor of shape (batch_size, embedding_size)
|
| 50 |
+
"""
|
| 51 |
+
# Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
|
| 52 |
+
# and the final cell state.
|
| 53 |
+
out, (hidden, cell) = self.lstm(utterances, hidden_init)
|
| 54 |
+
|
| 55 |
+
# We take only the hidden state of the last layer
|
| 56 |
+
embeds_raw = self.relu(self.linear(hidden[-1]))
|
| 57 |
+
|
| 58 |
+
# L2-normalize it
|
| 59 |
+
embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5)
|
| 60 |
+
|
| 61 |
+
return embeds
|
| 62 |
+
|
| 63 |
+
def similarity_matrix(self, embeds):
|
| 64 |
+
"""
|
| 65 |
+
Computes the similarity matrix according the section 2.1 of GE2E.
|
| 66 |
+
|
| 67 |
+
:param embeds: the embeddings as a tensor of shape (speakers_per_batch,
|
| 68 |
+
utterances_per_speaker, embedding_size)
|
| 69 |
+
:return: the similarity matrix as a tensor of shape (speakers_per_batch,
|
| 70 |
+
utterances_per_speaker, speakers_per_batch)
|
| 71 |
+
"""
|
| 72 |
+
speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
|
| 73 |
+
|
| 74 |
+
# Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
|
| 75 |
+
centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
|
| 76 |
+
centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)
|
| 77 |
+
|
| 78 |
+
# Exclusive centroids (1 per utterance)
|
| 79 |
+
centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
|
| 80 |
+
centroids_excl /= (utterances_per_speaker - 1)
|
| 81 |
+
centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5)
|
| 82 |
+
|
| 83 |
+
# Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
|
| 84 |
+
# product of these vectors (which is just an element-wise multiplication reduced by a sum).
|
| 85 |
+
# We vectorize the computation for efficiency.
|
| 86 |
+
sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
|
| 87 |
+
speakers_per_batch).to(self.loss_device)
|
| 88 |
+
mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int)
|
| 89 |
+
for j in range(speakers_per_batch):
|
| 90 |
+
mask = np.where(mask_matrix[j])[0]
|
| 91 |
+
sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
|
| 92 |
+
sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
|
| 93 |
+
|
| 94 |
+
## Even more vectorized version (slower maybe because of transpose)
|
| 95 |
+
# sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
|
| 96 |
+
# ).to(self.loss_device)
|
| 97 |
+
# eye = np.eye(speakers_per_batch, dtype=np.int)
|
| 98 |
+
# mask = np.where(1 - eye)
|
| 99 |
+
# sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
|
| 100 |
+
# mask = np.where(eye)
|
| 101 |
+
# sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
|
| 102 |
+
# sim_matrix2 = sim_matrix2.transpose(1, 2)
|
| 103 |
+
|
| 104 |
+
sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
|
| 105 |
+
return sim_matrix
|
| 106 |
+
|
| 107 |
+
def loss(self, embeds):
|
| 108 |
+
"""
|
| 109 |
+
Computes the softmax loss according the section 2.1 of GE2E.
|
| 110 |
+
|
| 111 |
+
:param embeds: the embeddings as a tensor of shape (speakers_per_batch,
|
| 112 |
+
utterances_per_speaker, embedding_size)
|
| 113 |
+
:return: the loss and the EER for this batch of embeddings.
|
| 114 |
+
"""
|
| 115 |
+
speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
|
| 116 |
+
|
| 117 |
+
# Loss
|
| 118 |
+
sim_matrix = self.similarity_matrix(embeds)
|
| 119 |
+
sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
|
| 120 |
+
speakers_per_batch))
|
| 121 |
+
ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
|
| 122 |
+
target = torch.from_numpy(ground_truth).long().to(self.loss_device)
|
| 123 |
+
loss = self.loss_fn(sim_matrix, target)
|
| 124 |
+
|
| 125 |
+
# EER (not backpropagated)
|
| 126 |
+
with torch.no_grad():
|
| 127 |
+
inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0]
|
| 128 |
+
labels = np.array([inv_argmax(i) for i in ground_truth])
|
| 129 |
+
preds = sim_matrix.detach().cpu().numpy()
|
| 130 |
+
|
| 131 |
+
# Snippet from https://yangcha.github.io/EER-ROC/
|
| 132 |
+
fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
|
| 133 |
+
eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
|
| 134 |
+
|
| 135 |
+
return loss, eer
|
encoder/params_data.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
## Mel-filterbank
|
| 3 |
+
mel_window_length = 25 # In milliseconds
|
| 4 |
+
mel_window_step = 10 # In milliseconds
|
| 5 |
+
mel_n_channels = 40
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
## Audio
|
| 9 |
+
sampling_rate = 16000
|
| 10 |
+
# Number of spectrogram frames in a partial utterance
|
| 11 |
+
partials_n_frames = 160 # 1600 ms
|
| 12 |
+
# Number of spectrogram frames at inference
|
| 13 |
+
inference_n_frames = 80 # 800 ms
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
## Voice Activation Detection
|
| 17 |
+
# Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
|
| 18 |
+
# This sets the granularity of the VAD. Should not need to be changed.
|
| 19 |
+
vad_window_length = 30 # In milliseconds
|
| 20 |
+
# Number of frames to average together when performing the moving average smoothing.
|
| 21 |
+
# The larger this value, the larger the VAD variations must be to not get smoothed out.
|
| 22 |
+
vad_moving_average_width = 8
|
| 23 |
+
# Maximum number of consecutive silent frames a segment can have.
|
| 24 |
+
vad_max_silence_length = 6
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
## Audio volume normalization
|
| 28 |
+
audio_norm_target_dBFS = -30
|
| 29 |
+
|
encoder/params_model.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
## Model parameters
|
| 3 |
+
model_hidden_size = 256
|
| 4 |
+
model_embedding_size = 256
|
| 5 |
+
model_num_layers = 3
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
## Training parameters
|
| 9 |
+
learning_rate_init = 1e-4
|
| 10 |
+
speakers_per_batch = 64
|
| 11 |
+
utterances_per_speaker = 10
|
encoder/preprocess.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime
|
| 2 |
+
from functools import partial
|
| 3 |
+
from multiprocessing import Pool
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
from encoder import audio
|
| 10 |
+
from encoder.config import librispeech_datasets, anglophone_nationalites
|
| 11 |
+
from encoder.params_data import *
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
_AUDIO_EXTENSIONS = ("wav", "flac", "m4a", "mp3")
|
| 15 |
+
|
| 16 |
+
class DatasetLog:
|
| 17 |
+
"""
|
| 18 |
+
Registers metadata about the dataset in a text file.
|
| 19 |
+
"""
|
| 20 |
+
def __init__(self, root, name):
|
| 21 |
+
self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
|
| 22 |
+
self.sample_data = dict()
|
| 23 |
+
|
| 24 |
+
start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
|
| 25 |
+
self.write_line("Creating dataset %s on %s" % (name, start_time))
|
| 26 |
+
self.write_line("-----")
|
| 27 |
+
self._log_params()
|
| 28 |
+
|
| 29 |
+
def _log_params(self):
|
| 30 |
+
from encoder import params_data
|
| 31 |
+
self.write_line("Parameter values:")
|
| 32 |
+
for param_name in (p for p in dir(params_data) if not p.startswith("__")):
|
| 33 |
+
value = getattr(params_data, param_name)
|
| 34 |
+
self.write_line("\t%s: %s" % (param_name, value))
|
| 35 |
+
self.write_line("-----")
|
| 36 |
+
|
| 37 |
+
def write_line(self, line):
|
| 38 |
+
self.text_file.write("%s\n" % line)
|
| 39 |
+
|
| 40 |
+
def add_sample(self, **kwargs):
|
| 41 |
+
for param_name, value in kwargs.items():
|
| 42 |
+
if not param_name in self.sample_data:
|
| 43 |
+
self.sample_data[param_name] = []
|
| 44 |
+
self.sample_data[param_name].append(value)
|
| 45 |
+
|
| 46 |
+
def finalize(self):
|
| 47 |
+
self.write_line("Statistics:")
|
| 48 |
+
for param_name, values in self.sample_data.items():
|
| 49 |
+
self.write_line("\t%s:" % param_name)
|
| 50 |
+
self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
|
| 51 |
+
self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
|
| 52 |
+
self.write_line("-----")
|
| 53 |
+
end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
|
| 54 |
+
self.write_line("Finished on %s" % end_time)
|
| 55 |
+
self.text_file.close()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
|
| 59 |
+
dataset_root = datasets_root.joinpath(dataset_name)
|
| 60 |
+
if not dataset_root.exists():
|
| 61 |
+
print("Couldn\'t find %s, skipping this dataset." % dataset_root)
|
| 62 |
+
return None, None
|
| 63 |
+
return dataset_root, DatasetLog(out_dir, dataset_name)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _preprocess_speaker(speaker_dir: Path, datasets_root: Path, out_dir: Path, skip_existing: bool):
|
| 67 |
+
# Give a name to the speaker that includes its dataset
|
| 68 |
+
speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
|
| 69 |
+
|
| 70 |
+
# Create an output directory with that name, as well as a txt file containing a
|
| 71 |
+
# reference to each source file.
|
| 72 |
+
speaker_out_dir = out_dir.joinpath(speaker_name)
|
| 73 |
+
speaker_out_dir.mkdir(exist_ok=True)
|
| 74 |
+
sources_fpath = speaker_out_dir.joinpath("_sources.txt")
|
| 75 |
+
|
| 76 |
+
# There's a possibility that the preprocessing was interrupted earlier, check if
|
| 77 |
+
# there already is a sources file.
|
| 78 |
+
if sources_fpath.exists():
|
| 79 |
+
try:
|
| 80 |
+
with sources_fpath.open("r") as sources_file:
|
| 81 |
+
existing_fnames = {line.split(",")[0] for line in sources_file}
|
| 82 |
+
except:
|
| 83 |
+
existing_fnames = {}
|
| 84 |
+
else:
|
| 85 |
+
existing_fnames = {}
|
| 86 |
+
|
| 87 |
+
# Gather all audio files for that speaker recursively
|
| 88 |
+
sources_file = sources_fpath.open("a" if skip_existing else "w")
|
| 89 |
+
audio_durs = []
|
| 90 |
+
for extension in _AUDIO_EXTENSIONS:
|
| 91 |
+
for in_fpath in speaker_dir.glob("**/*.%s" % extension):
|
| 92 |
+
# Check if the target output file already exists
|
| 93 |
+
out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
|
| 94 |
+
out_fname = out_fname.replace(".%s" % extension, ".npy")
|
| 95 |
+
if skip_existing and out_fname in existing_fnames:
|
| 96 |
+
continue
|
| 97 |
+
|
| 98 |
+
# Load and preprocess the waveform
|
| 99 |
+
wav = audio.preprocess_wav(in_fpath)
|
| 100 |
+
if len(wav) == 0:
|
| 101 |
+
continue
|
| 102 |
+
|
| 103 |
+
# Create the mel spectrogram, discard those that are too short
|
| 104 |
+
frames = audio.wav_to_mel_spectrogram(wav)
|
| 105 |
+
if len(frames) < partials_n_frames:
|
| 106 |
+
continue
|
| 107 |
+
|
| 108 |
+
out_fpath = speaker_out_dir.joinpath(out_fname)
|
| 109 |
+
np.save(out_fpath, frames)
|
| 110 |
+
sources_file.write("%s,%s\n" % (out_fname, in_fpath))
|
| 111 |
+
audio_durs.append(len(wav) / sampling_rate)
|
| 112 |
+
|
| 113 |
+
sources_file.close()
|
| 114 |
+
|
| 115 |
+
return audio_durs
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger):
|
| 119 |
+
print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
|
| 120 |
+
|
| 121 |
+
# Process the utterances for each speaker
|
| 122 |
+
work_fn = partial(_preprocess_speaker, datasets_root=datasets_root, out_dir=out_dir, skip_existing=skip_existing)
|
| 123 |
+
with Pool(4) as pool:
|
| 124 |
+
tasks = pool.imap(work_fn, speaker_dirs)
|
| 125 |
+
for sample_durs in tqdm(tasks, dataset_name, len(speaker_dirs), unit="speakers"):
|
| 126 |
+
for sample_dur in sample_durs:
|
| 127 |
+
logger.add_sample(duration=sample_dur)
|
| 128 |
+
|
| 129 |
+
logger.finalize()
|
| 130 |
+
print("Done preprocessing %s.\n" % dataset_name)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
|
| 134 |
+
for dataset_name in librispeech_datasets["train"]["other"]:
|
| 135 |
+
# Initialize the preprocessing
|
| 136 |
+
dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
|
| 137 |
+
if not dataset_root:
|
| 138 |
+
return
|
| 139 |
+
|
| 140 |
+
# Preprocess all speakers
|
| 141 |
+
speaker_dirs = list(dataset_root.glob("*"))
|
| 142 |
+
_preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
|
| 146 |
+
# Initialize the preprocessing
|
| 147 |
+
dataset_name = "VoxCeleb1"
|
| 148 |
+
dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
|
| 149 |
+
if not dataset_root:
|
| 150 |
+
return
|
| 151 |
+
|
| 152 |
+
# Get the contents of the meta file
|
| 153 |
+
with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
|
| 154 |
+
metadata = [line.split("\t") for line in metafile][1:]
|
| 155 |
+
|
| 156 |
+
# Select the ID and the nationality, filter out non-anglophone speakers
|
| 157 |
+
nationalities = {line[0]: line[3] for line in metadata}
|
| 158 |
+
keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
|
| 159 |
+
nationality.lower() in anglophone_nationalites]
|
| 160 |
+
print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
|
| 161 |
+
(len(keep_speaker_ids), len(nationalities)))
|
| 162 |
+
|
| 163 |
+
# Get the speaker directories for anglophone speakers only
|
| 164 |
+
speaker_dirs = dataset_root.joinpath("wav").glob("*")
|
| 165 |
+
speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if
|
| 166 |
+
speaker_dir.name in keep_speaker_ids]
|
| 167 |
+
print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." %
|
| 168 |
+
(len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))
|
| 169 |
+
|
| 170 |
+
# Preprocess all speakers
|
| 171 |
+
_preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
|
| 175 |
+
# Initialize the preprocessing
|
| 176 |
+
dataset_name = "VoxCeleb2"
|
| 177 |
+
dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
|
| 178 |
+
if not dataset_root:
|
| 179 |
+
return
|
| 180 |
+
|
| 181 |
+
# Get the speaker directories
|
| 182 |
+
# Preprocess all speakers
|
| 183 |
+
speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
|
| 184 |
+
_preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
|
encoder/train.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
|
| 6 |
+
from encoder.model import SpeakerEncoder
|
| 7 |
+
from encoder.params_model import *
|
| 8 |
+
from encoder.visualizations import Visualizations
|
| 9 |
+
from utils.profiler import Profiler
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def sync(device: torch.device):
|
| 13 |
+
# For correct profiling (cuda operations are async)
|
| 14 |
+
if device.type == "cuda":
|
| 15 |
+
torch.cuda.synchronize(device)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
|
| 19 |
+
backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
|
| 20 |
+
no_visdom: bool):
|
| 21 |
+
# Create a dataset and a dataloader
|
| 22 |
+
dataset = SpeakerVerificationDataset(clean_data_root)
|
| 23 |
+
loader = SpeakerVerificationDataLoader(
|
| 24 |
+
dataset,
|
| 25 |
+
speakers_per_batch,
|
| 26 |
+
utterances_per_speaker,
|
| 27 |
+
num_workers=4,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
# Setup the device on which to run the forward pass and the loss. These can be different,
|
| 31 |
+
# because the forward pass is faster on the GPU whereas the loss is often (depending on your
|
| 32 |
+
# hyperparameters) faster on the CPU.
|
| 33 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 34 |
+
# FIXME: currently, the gradient is None if loss_device is cuda
|
| 35 |
+
loss_device = torch.device("cpu")
|
| 36 |
+
|
| 37 |
+
# Create the model and the optimizer
|
| 38 |
+
model = SpeakerEncoder(device, loss_device)
|
| 39 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
|
| 40 |
+
init_step = 1
|
| 41 |
+
|
| 42 |
+
# Configure file path for the model
|
| 43 |
+
model_dir = models_dir / run_id
|
| 44 |
+
model_dir.mkdir(exist_ok=True, parents=True)
|
| 45 |
+
state_fpath = model_dir / "encoder.pt"
|
| 46 |
+
|
| 47 |
+
# Load any existing model
|
| 48 |
+
if not force_restart:
|
| 49 |
+
if state_fpath.exists():
|
| 50 |
+
print("Found existing model \"%s\", loading it and resuming training." % run_id)
|
| 51 |
+
checkpoint = torch.load(state_fpath)
|
| 52 |
+
init_step = checkpoint["step"]
|
| 53 |
+
model.load_state_dict(checkpoint["model_state"])
|
| 54 |
+
optimizer.load_state_dict(checkpoint["optimizer_state"])
|
| 55 |
+
optimizer.param_groups[0]["lr"] = learning_rate_init
|
| 56 |
+
else:
|
| 57 |
+
print("No model \"%s\" found, starting training from scratch." % run_id)
|
| 58 |
+
else:
|
| 59 |
+
print("Starting the training from scratch.")
|
| 60 |
+
model.train()
|
| 61 |
+
|
| 62 |
+
# Initialize the visualization environment
|
| 63 |
+
vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
|
| 64 |
+
vis.log_dataset(dataset)
|
| 65 |
+
vis.log_params()
|
| 66 |
+
device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
|
| 67 |
+
vis.log_implementation({"Device": device_name})
|
| 68 |
+
|
| 69 |
+
# Training loop
|
| 70 |
+
profiler = Profiler(summarize_every=10, disabled=False)
|
| 71 |
+
for step, speaker_batch in enumerate(loader, init_step):
|
| 72 |
+
profiler.tick("Blocking, waiting for batch (threaded)")
|
| 73 |
+
|
| 74 |
+
# Forward pass
|
| 75 |
+
inputs = torch.from_numpy(speaker_batch.data).to(device)
|
| 76 |
+
sync(device)
|
| 77 |
+
profiler.tick("Data to %s" % device)
|
| 78 |
+
embeds = model(inputs)
|
| 79 |
+
sync(device)
|
| 80 |
+
profiler.tick("Forward pass")
|
| 81 |
+
embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
|
| 82 |
+
loss, eer = model.loss(embeds_loss)
|
| 83 |
+
sync(loss_device)
|
| 84 |
+
profiler.tick("Loss")
|
| 85 |
+
|
| 86 |
+
# Backward pass
|
| 87 |
+
model.zero_grad()
|
| 88 |
+
loss.backward()
|
| 89 |
+
profiler.tick("Backward pass")
|
| 90 |
+
model.do_gradient_ops()
|
| 91 |
+
optimizer.step()
|
| 92 |
+
profiler.tick("Parameter update")
|
| 93 |
+
|
| 94 |
+
# Update visualizations
|
| 95 |
+
# learning_rate = optimizer.param_groups[0]["lr"]
|
| 96 |
+
vis.update(loss.item(), eer, step)
|
| 97 |
+
|
| 98 |
+
# Draw projections and save them to the backup folder
|
| 99 |
+
if umap_every != 0 and step % umap_every == 0:
|
| 100 |
+
print("Drawing and saving projections (step %d)" % step)
|
| 101 |
+
projection_fpath = model_dir / f"umap_{step:06d}.png"
|
| 102 |
+
embeds = embeds.detach().cpu().numpy()
|
| 103 |
+
vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
|
| 104 |
+
vis.save()
|
| 105 |
+
|
| 106 |
+
# Overwrite the latest version of the model
|
| 107 |
+
if save_every != 0 and step % save_every == 0:
|
| 108 |
+
print("Saving the model (step %d)" % step)
|
| 109 |
+
torch.save({
|
| 110 |
+
"step": step + 1,
|
| 111 |
+
"model_state": model.state_dict(),
|
| 112 |
+
"optimizer_state": optimizer.state_dict(),
|
| 113 |
+
}, state_fpath)
|
| 114 |
+
|
| 115 |
+
# Make a backup
|
| 116 |
+
if backup_every != 0 and step % backup_every == 0:
|
| 117 |
+
print("Making a backup (step %d)" % step)
|
| 118 |
+
backup_fpath = model_dir / f"encoder_{step:06d}.bak"
|
| 119 |
+
torch.save({
|
| 120 |
+
"step": step + 1,
|
| 121 |
+
"model_state": model.state_dict(),
|
| 122 |
+
"optimizer_state": optimizer.state_dict(),
|
| 123 |
+
}, backup_fpath)
|
| 124 |
+
|
| 125 |
+
profiler.tick("Extras (visualizations, saving)")
|
encoder/visualizations.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime
|
| 2 |
+
from time import perf_counter as timer
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import umap
|
| 6 |
+
import visdom
|
| 7 |
+
|
| 8 |
+
from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
colormap = np.array([
|
| 12 |
+
[76, 255, 0],
|
| 13 |
+
[0, 127, 70],
|
| 14 |
+
[255, 0, 0],
|
| 15 |
+
[255, 217, 38],
|
| 16 |
+
[0, 135, 255],
|
| 17 |
+
[165, 0, 165],
|
| 18 |
+
[255, 167, 255],
|
| 19 |
+
[0, 255, 255],
|
| 20 |
+
[255, 96, 38],
|
| 21 |
+
[142, 76, 0],
|
| 22 |
+
[33, 0, 127],
|
| 23 |
+
[0, 0, 0],
|
| 24 |
+
[183, 183, 183],
|
| 25 |
+
], dtype=np.float) / 255
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class Visualizations:
|
| 29 |
+
def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
|
| 30 |
+
# Tracking data
|
| 31 |
+
self.last_update_timestamp = timer()
|
| 32 |
+
self.update_every = update_every
|
| 33 |
+
self.step_times = []
|
| 34 |
+
self.losses = []
|
| 35 |
+
self.eers = []
|
| 36 |
+
print("Updating the visualizations every %d steps." % update_every)
|
| 37 |
+
|
| 38 |
+
# If visdom is disabled TODO: use a better paradigm for that
|
| 39 |
+
self.disabled = disabled
|
| 40 |
+
if self.disabled:
|
| 41 |
+
return
|
| 42 |
+
|
| 43 |
+
# Set the environment name
|
| 44 |
+
now = str(datetime.now().strftime("%d-%m %Hh%M"))
|
| 45 |
+
if env_name is None:
|
| 46 |
+
self.env_name = now
|
| 47 |
+
else:
|
| 48 |
+
self.env_name = "%s (%s)" % (env_name, now)
|
| 49 |
+
|
| 50 |
+
# Connect to visdom and open the corresponding window in the browser
|
| 51 |
+
try:
|
| 52 |
+
self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
|
| 53 |
+
except ConnectionError:
|
| 54 |
+
raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
|
| 55 |
+
"start it.")
|
| 56 |
+
# webbrowser.open("http://localhost:8097/env/" + self.env_name)
|
| 57 |
+
|
| 58 |
+
# Create the windows
|
| 59 |
+
self.loss_win = None
|
| 60 |
+
self.eer_win = None
|
| 61 |
+
# self.lr_win = None
|
| 62 |
+
self.implementation_win = None
|
| 63 |
+
self.projection_win = None
|
| 64 |
+
self.implementation_string = ""
|
| 65 |
+
|
| 66 |
+
def log_params(self):
|
| 67 |
+
if self.disabled:
|
| 68 |
+
return
|
| 69 |
+
from encoder import params_data
|
| 70 |
+
from encoder import params_model
|
| 71 |
+
param_string = "<b>Model parameters</b>:<br>"
|
| 72 |
+
for param_name in (p for p in dir(params_model) if not p.startswith("__")):
|
| 73 |
+
value = getattr(params_model, param_name)
|
| 74 |
+
param_string += "\t%s: %s<br>" % (param_name, value)
|
| 75 |
+
param_string += "<b>Data parameters</b>:<br>"
|
| 76 |
+
for param_name in (p for p in dir(params_data) if not p.startswith("__")):
|
| 77 |
+
value = getattr(params_data, param_name)
|
| 78 |
+
param_string += "\t%s: %s<br>" % (param_name, value)
|
| 79 |
+
self.vis.text(param_string, opts={"title": "Parameters"})
|
| 80 |
+
|
| 81 |
+
def log_dataset(self, dataset: SpeakerVerificationDataset):
|
| 82 |
+
if self.disabled:
|
| 83 |
+
return
|
| 84 |
+
dataset_string = ""
|
| 85 |
+
dataset_string += "<b>Speakers</b>: %s\n" % len(dataset.speakers)
|
| 86 |
+
dataset_string += "\n" + dataset.get_logs()
|
| 87 |
+
dataset_string = dataset_string.replace("\n", "<br>")
|
| 88 |
+
self.vis.text(dataset_string, opts={"title": "Dataset"})
|
| 89 |
+
|
| 90 |
+
def log_implementation(self, params):
|
| 91 |
+
if self.disabled:
|
| 92 |
+
return
|
| 93 |
+
implementation_string = ""
|
| 94 |
+
for param, value in params.items():
|
| 95 |
+
implementation_string += "<b>%s</b>: %s\n" % (param, value)
|
| 96 |
+
implementation_string = implementation_string.replace("\n", "<br>")
|
| 97 |
+
self.implementation_string = implementation_string
|
| 98 |
+
self.implementation_win = self.vis.text(
|
| 99 |
+
implementation_string,
|
| 100 |
+
opts={"title": "Training implementation"}
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
def update(self, loss, eer, step):
|
| 104 |
+
# Update the tracking data
|
| 105 |
+
now = timer()
|
| 106 |
+
self.step_times.append(1000 * (now - self.last_update_timestamp))
|
| 107 |
+
self.last_update_timestamp = now
|
| 108 |
+
self.losses.append(loss)
|
| 109 |
+
self.eers.append(eer)
|
| 110 |
+
print(".", end="")
|
| 111 |
+
|
| 112 |
+
# Update the plots every <update_every> steps
|
| 113 |
+
if step % self.update_every != 0:
|
| 114 |
+
return
|
| 115 |
+
time_string = "Step time: mean: %5dms std: %5dms" % \
|
| 116 |
+
(int(np.mean(self.step_times)), int(np.std(self.step_times)))
|
| 117 |
+
print("\nStep %6d Loss: %.4f EER: %.4f %s" %
|
| 118 |
+
(step, np.mean(self.losses), np.mean(self.eers), time_string))
|
| 119 |
+
if not self.disabled:
|
| 120 |
+
self.loss_win = self.vis.line(
|
| 121 |
+
[np.mean(self.losses)],
|
| 122 |
+
[step],
|
| 123 |
+
win=self.loss_win,
|
| 124 |
+
update="append" if self.loss_win else None,
|
| 125 |
+
opts=dict(
|
| 126 |
+
legend=["Avg. loss"],
|
| 127 |
+
xlabel="Step",
|
| 128 |
+
ylabel="Loss",
|
| 129 |
+
title="Loss",
|
| 130 |
+
)
|
| 131 |
+
)
|
| 132 |
+
self.eer_win = self.vis.line(
|
| 133 |
+
[np.mean(self.eers)],
|
| 134 |
+
[step],
|
| 135 |
+
win=self.eer_win,
|
| 136 |
+
update="append" if self.eer_win else None,
|
| 137 |
+
opts=dict(
|
| 138 |
+
legend=["Avg. EER"],
|
| 139 |
+
xlabel="Step",
|
| 140 |
+
ylabel="EER",
|
| 141 |
+
title="Equal error rate"
|
| 142 |
+
)
|
| 143 |
+
)
|
| 144 |
+
if self.implementation_win is not None:
|
| 145 |
+
self.vis.text(
|
| 146 |
+
self.implementation_string + ("<b>%s</b>" % time_string),
|
| 147 |
+
win=self.implementation_win,
|
| 148 |
+
opts={"title": "Training implementation"},
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
# Reset the tracking
|
| 152 |
+
self.losses.clear()
|
| 153 |
+
self.eers.clear()
|
| 154 |
+
self.step_times.clear()
|
| 155 |
+
|
| 156 |
+
def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None, max_speakers=10):
|
| 157 |
+
import matplotlib.pyplot as plt
|
| 158 |
+
|
| 159 |
+
max_speakers = min(max_speakers, len(colormap))
|
| 160 |
+
embeds = embeds[:max_speakers * utterances_per_speaker]
|
| 161 |
+
|
| 162 |
+
n_speakers = len(embeds) // utterances_per_speaker
|
| 163 |
+
ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
|
| 164 |
+
colors = [colormap[i] for i in ground_truth]
|
| 165 |
+
|
| 166 |
+
reducer = umap.UMAP()
|
| 167 |
+
projected = reducer.fit_transform(embeds)
|
| 168 |
+
plt.scatter(projected[:, 0], projected[:, 1], c=colors)
|
| 169 |
+
plt.gca().set_aspect("equal", "datalim")
|
| 170 |
+
plt.title("UMAP projection (step %d)" % step)
|
| 171 |
+
if not self.disabled:
|
| 172 |
+
self.projection_win = self.vis.matplot(plt, win=self.projection_win)
|
| 173 |
+
if out_fpath is not None:
|
| 174 |
+
plt.savefig(out_fpath)
|
| 175 |
+
plt.clf()
|
| 176 |
+
|
| 177 |
+
def save(self):
|
| 178 |
+
if not self.disabled:
|
| 179 |
+
self.vis.save([self.env_name])
|