erfanasgari21 commited on
Commit
152eac6
·
verified ·
1 Parent(s): b9fabbf

Initial upload

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +3 -0
  2. .gitignore +30 -0
  3. LICENSE +14 -0
  4. README.md +97 -0
  5. __init__.py +0 -0
  6. encoder/__init__.py +0 -0
  7. encoder/audio.py +117 -0
  8. encoder/config.py +53 -0
  9. encoder/data_objects/__init__.py +2 -0
  10. encoder/data_objects/random_cycler.py +37 -0
  11. encoder/data_objects/speaker.py +40 -0
  12. encoder/data_objects/speaker_batch.py +13 -0
  13. encoder/data_objects/speaker_verification_dataset.py +56 -0
  14. encoder/data_objects/utterance.py +26 -0
  15. encoder/inference.py +179 -0
  16. encoder/model.py +135 -0
  17. encoder/params_data.py +29 -0
  18. encoder/params_model.py +11 -0
  19. encoder/preprocess.py +196 -0
  20. encoder/train.py +125 -0
  21. encoder/visualizations.py +179 -0
  22. encoder_preprocess.py +69 -0
  23. encoder_train.py +45 -0
  24. inference.py +94 -0
  25. persian_numbers.py +295 -0
  26. pipe.py +67 -0
  27. prepare_data.py +96 -0
  28. requirements.txt +0 -0
  29. resources/model.JPG +0 -0
  30. sample.wav +3 -0
  31. saved_models/default/encoder.pt +3 -0
  32. saved_models/default/vocoder_WavRNN.pt +3 -0
  33. saved_models/final_models/config.yml +191 -0
  34. saved_models/final_models/encoder.pt +3 -0
  35. saved_models/final_models/synthesizer.pt +3 -0
  36. saved_models/final_models/vocoder_HiFiGAN.pkl +3 -0
  37. sentence_splitter.py +123 -0
  38. spaces.py +165 -0
  39. synthesizer/LICENSE.txt +24 -0
  40. synthesizer/__init__.py +1 -0
  41. synthesizer/audio.py +206 -0
  42. synthesizer/audio_v2(support_hifigan).py +154 -0
  43. synthesizer/english utils/__init__.py +45 -0
  44. synthesizer/english utils/_cmudict.py +62 -0
  45. synthesizer/english utils/cleaners.py +88 -0
  46. synthesizer/english utils/numbers.py +69 -0
  47. synthesizer/english utils/plot.py +82 -0
  48. synthesizer/english utils/symbols.py +17 -0
  49. synthesizer/english utils/text.py +75 -0
  50. synthesizer/hparams.py +108 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ results/test_output.wav filter=lfs diff=lfs merge=lfs -text
37
+ results/test_output2.wav filter=lfs diff=lfs merge=lfs -text
38
+ sample.wav filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.pyc
2
+ *.aux
3
+ *.log
4
+ *.out
5
+ *.synctex.gz
6
+ *.suo
7
+ __pycache__/
8
+ *.idea
9
+ *.ipynb_checkpoints
10
+ *.pickle
11
+ *.npy
12
+ *.blg
13
+ *.bbl
14
+ *.bcf
15
+ *.toc
16
+ *.sh
17
+
18
+ encoder/saved_models/*
19
+ synthesizer/saved_models/*
20
+ vocoder/saved_models/*
21
+ saved_models/my_run
22
+ saved_models/train_encoder
23
+ dataset/*
24
+ results/best_result
25
+ vocoder/hifigan/*.pkl
26
+ vocoder/hifigan2
27
+ evaluate_vocoder
28
+ features_check
29
+ auto_inference.py
30
+ start_instruction.txt
LICENSE ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Modified & original work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ)
4
+ Original work Copyright (c) 2018 Rayhane Mama (https://github.com/Rayhane-mamah)
5
+ Original work Copyright (c) 2019 fatchord (https://github.com/fatchord)
6
+ Original work Copyright (c) 2015 braindead (https://github.com/braindead)
7
+ Modified work Copyright (c) 2025 Majid Adibian (https://github.com/Adibian)
8
+ Modified work Copyright (c) 2025 Mahta Fetrat (https://github.com/MahtaFetrat)
9
+
10
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
13
+
14
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
README.md ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MultiSpeaker Tacotron2 in Persian Language
2
+ This repository implements [Transfer Learning from Speaker Verification to
3
+ Multispeaker Text-To-Speech Synthesis](https://arxiv.org/pdf/1806.04558.pdf) (SV2TTS) for the Persian language. The core codebase is derived from [this repository](https://github.com/Adibian/Persian-MultiSpeaker-Tacotron2), which has been updated to address deprecated features and complete setup for Persian language compatibility. The original codebase, sourced from [this repository](https://github.com/CorentinJ/Real-Time-Voice-Cloning/tree/master), has been modified to support Persian language requirements.
4
+
5
+ <img src="https://github.com/majidAdibian77/persian-SV2TTS/blob/master/resources/model.JPG" width="800">
6
+
7
+ ---
8
+
9
+ ## Training
10
+ **1. Character-set definition:**
11
+
12
+ Open the `synthesizer/persian_utils/symbols.py` file and update the `_characters` variable to include all the characters that exist in your text files. Most of Persian characters and symbols are already included in this variable as follows:
13
+ ```
14
+ _characters = "ءابتثجحخدذرزسشصضطظعغفقلمنهويِپچژکگیآۀأؤإئًَُّ!(),-.:;? ̠،…؛؟‌٪#ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_–@+/\u200c"
15
+ ```
16
+
17
+ **2. Data structures:**
18
+ ```
19
+ dataset/persian_data/
20
+ train_data/
21
+ speaker1/book-1/
22
+ sample1.txt
23
+ sample1.wav
24
+ ...
25
+ ...
26
+ test_data/
27
+ ...
28
+ ```
29
+
30
+ **3. Preprocessing:**
31
+ ```
32
+ python3 synthesizer_preprocess_audio.py dataset --datasets_name persian_data --subfolders train_data --no_alignments --skip_existing --n_processes 4 --out_dir dataset/train/SV2TTS/synthesizer
33
+ python3 synthesizer_preprocess_audio.py dataset --datasets_name persian_data --subfolders test_data --no_alignments --skip_existing --n_processes 4 --out_dir dataset/test/SV2TTS/synthesizer
34
+ ```
35
+ **Embedding preprocessing:**
36
+ ```
37
+ python3 synthesizer_preprocess_embeds.py dataset/train/SV2TTS/synthesizer
38
+ python3 synthesizer_preprocess_embeds.py dataset/test/SV2TTS/synthesizer
39
+ ```
40
+
41
+ **4. Train synthesizer:**
42
+ ```
43
+ python3 synthesizer_train.py my_run dataset/train/SV2TTS/synthesizer
44
+ ```
45
+
46
+ ## Inference
47
+
48
+ To generate a wav file, place all trained models in the `saved_models/final_models` directory. If you haven’t trained the speaker encoder or vocoder models, you can use pretrained models from `saved_models/default`. These models include `encoder.pt`, your latest synthesizer checkpoint like `synthesizer_000300.pt`, and a vocoder as follows.
49
+
50
+ ### Using WavRNN as Vocoder
51
+
52
+ ```
53
+ python3 inference.py --vocoder "WavRNN" --text "یک نمونه از خروجی" --ref_wav_path "/path/to/sample/reference.wav" --test_name "test1"
54
+ ```
55
+
56
+ ### Using HiFiGAN as Vocoder (Recommended)
57
+ WavRNN is an old vocoder and if you want to use HiFiGAN you must first download a pretrained model in English.
58
+ 1. **Install Parallel WaveGAN**
59
+ ```
60
+ pip install parallel_wavegan
61
+ ```
62
+ 2. **Download Pretrained HiFiGAN Model**
63
+ ```
64
+ from parallel_wavegan.utils import download_pretrained_model
65
+ download_pretrained_model("vctk_hifigan.v1", "saved_models/final_models/vocoder_HiFiGAN")
66
+ ```
67
+ 3. **Run Inference with HiFiGAN**
68
+ ```
69
+ python3 inference.py --vocoder "HiFiGAN" --text "یک نمونه از خروجی" --ref_wav_path "/path/to/sample/reference.wav" --test_name "test1"
70
+ ```
71
+
72
+ ## ManaTTS-Trained Model
73
+
74
+ This architecture has been used to train a Persian Text-to-Speech (TTS) model on the [**ManaTTS dataset**](https://huggingface.co/datasets/MahtaFetrat/Mana-TTS), the largest publicly available single-speaker Persian corpus. The trained model weights and detailed inference instructions can be found in the following repositories:
75
+
76
+ - [Hugging Face Repository](https://huggingface.co/MahtaFetrat/Persian-Tacotron2-on-ManaTTS)
77
+ - [GitHub Repository](https://github.com/MahtaFetrat/ManaTTS-Persian-Tacotron2-Model)
78
+
79
+ ## References:
80
+ - [Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis](https://arxiv.org/pdf/1806.04558.pdf) Ye Jia, *et al*.,
81
+ - [Real-Time-Voice-Cloning repository](https://github.com/CorentinJ/Real-Time-Voice-Cloning/tree/master),
82
+ - [ParallelWaveGAN repository](https://github.com/kan-bayashi/ParallelWaveGAN)
83
+ - [Persian-MultiSpeaker-Tacotron2](https://github.com/Adibian/Persian-MultiSpeaker-Tacotron2)
84
+
85
+ ## License
86
+ This project is based on [Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning),
87
+ which is licensed under the MIT License.
88
+ ```
89
+ Modified & original work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ)
90
+ Original work Copyright (c) 2018 Rayhane Mama (https://github.com/Rayhane-mamah)
91
+ Original work Copyright (c) 2019 fatchord (https://github.com/fatchord)
92
+ Original work Copyright (c) 2015 braindead (https://github.com/braindead)
93
+ Modified work Copyright (c) 2025 Majid Adibian (https://github.com/Adibian)
94
+ Modified work Copyright (c) 2025 Mahta Fetrat (https://github.com/MahtaFetrat)
95
+ ```
96
+
97
+
__init__.py ADDED
File without changes
encoder/__init__.py ADDED
File without changes
encoder/audio.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Fix: scipy.ndimage.morphology was deprecated and then removed; the public
# location of binary_dilation is scipy.ndimage.
from scipy.ndimage import binary_dilation
from encoder.params_data import *
from pathlib import Path
from typing import Optional, Union
from warnings import warn
import numpy as np
import librosa
import struct

# webrtcvad is optional: it enables VAD-based silence removal in preprocess_wav.
# Fix: a bare `except:` also caught SystemExit/KeyboardInterrupt; narrow it.
try:
    import webrtcvad
except Exception:
    warn("Unable to import 'webrtcvad'. This package enables noise removal and is recommended.")
    webrtcvad = None

# Maximum value of a signed 16-bit PCM sample; used to scale float wavs to int16.
int16_max = (2 ** 15) - 1
17
+
18
+
def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
                   source_sr: Optional[int] = None,
                   normalize: Optional[bool] = True,
                   trim_silence: Optional[bool] = True):
    """
    Applies the preprocessing operations used in training the Speaker Encoder to a waveform
    either on disk or in memory. The waveform will be resampled to match the data hyperparameters.

    :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
    just .wav), either the waveform as a numpy array of floats.
    :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
    preprocessing. After preprocessing, the waveform's sampling rate will match the data
    hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
    this argument will be ignored.
    :param normalize: if True, the volume is raised to audio_norm_target_dBFS (increase only).
    :param trim_silence: if True and webrtcvad is available, long silences are removed.
    :return: the preprocessed waveform as a numpy array of floats.
    """
    # Load the wav from disk if needed
    if isinstance(fpath_or_wav, (str, Path)):
        wav, source_sr = librosa.load(str(fpath_or_wav), sr=None)
    else:
        wav = fpath_or_wav

    # Resample the wav if needed.
    # Fix: librosa >= 0.10 removed the positional (y, orig_sr, target_sr) signature;
    # the sample rates must be passed as keyword arguments.
    if source_sr is not None and source_sr != sampling_rate:
        wav = librosa.resample(wav, orig_sr=source_sr, target_sr=sampling_rate)

    # Apply the preprocessing: normalize volume and shorten long silences
    if normalize:
        wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
    if webrtcvad and trim_silence:
        wav = trim_long_silences(wav)

    return wav
51
+
52
+
def wav_to_mel_spectrogram(wav):
    """
    Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
    Note: this is not a log-mel spectrogram.
    """
    # Window and hop lengths are specified in milliseconds by the data hyperparameters.
    window_samples = int(sampling_rate * mel_window_length / 1000)
    hop_samples = int(sampling_rate * mel_window_step / 1000)

    mel = librosa.feature.melspectrogram(
        y=wav,
        sr=sampling_rate,
        n_fft=window_samples,
        hop_length=hop_samples,
        n_mels=mel_n_channels,
    )
    # Transpose so the result is (n_frames, n_channels), as the encoder expects.
    return mel.astype(np.float32).T
66
+
67
+
def trim_long_silences(wav):
    """
    Ensures that segments without voice in the waveform remain no longer than a
    threshold determined by the VAD parameters in params.py.

    :param wav: the raw waveform as a numpy array of floats
    :return: the same waveform with silences trimmed away (length <= original wav length)
    """
    # Samples covered by one VAD window (vad_window_length is in milliseconds)
    samples_per_window = (vad_window_length * sampling_rate) // 1000

    # Keep only a whole number of windows
    wav = wav[:len(wav) - (len(wav) % samples_per_window)]

    # webrtcvad consumes 16-bit mono PCM bytes, so scale and pack the float wav
    pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))

    # Run voice activity detection window by window (2 bytes per int16 sample)
    vad = webrtcvad.Vad(mode=3)
    voice_flags = []
    for start in range(0, len(wav), samples_per_window):
        stop = start + samples_per_window
        voice_flags.append(vad.is_speech(pcm_wave[start * 2:stop * 2],
                                         sample_rate=sampling_rate))
    voice_flags = np.array(voice_flags)

    def moving_average(array, width):
        # Centered moving average, zero-padded at both ends
        padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
        csum = np.cumsum(padded, dtype=float)
        csum[width:] = csum[width:] - csum[:-width]
        return csum[width - 1:] / width

    # Smooth the binary voice flags, then binarize again
    audio_mask = np.round(moving_average(voice_flags, vad_moving_average_width)).astype(bool)

    # Dilate the voiced regions so short pauses inside speech are preserved,
    # then expand the per-window mask back to per-sample resolution
    audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
    audio_mask = np.repeat(audio_mask, samples_per_window)

    return wav[audio_mask]
109
+
110
+
def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
    """
    Rescales the waveform so that its mean power matches target_dBFS.

    :param wav: waveform as a numpy array of floats
    :param target_dBFS: desired level in dB relative to full scale
    :param increase_only: if True, the signal is never attenuated
    :param decrease_only: if True, the signal is never amplified
    :return: the rescaled waveform; the input is returned unchanged when the
        required change goes in a disallowed direction
    """
    if increase_only and decrease_only:
        raise ValueError("Both increase only and decrease only are set")
    # Current level in dBFS derived from the mean power of the signal
    current_dBFS = 10 * np.log10(np.mean(wav ** 2))
    dBFS_change = target_dBFS - current_dBFS
    blocked = (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only)
    if blocked:
        return wav
    return wav * (10 ** (dBFS_change / 20))
encoder/config.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Dataset layout definitions. Each entry maps a split name to lists of
# directories, relative to the datasets root, that contain the audio data.

persian_datasets = {
    "train": {"data": ["persian_data/train_data"]},
    "test": {"data": ["persian_data/test_data"]},
}

librispeech_datasets = {
    "train": {
        "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
        "other": ["LibriSpeech/train-other-500"],
    },
    "test": {
        "clean": ["LibriSpeech/test-clean"],
        "other": ["LibriSpeech/test-other"],
    },
    "dev": {
        "clean": ["LibriSpeech/dev-clean"],
        "other": ["LibriSpeech/dev-other"],
    },
}

libritts_datasets = {
    "train": {
        "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
        "other": ["LibriTTS/train-other-500"],
    },
    "test": {
        "clean": ["LibriTTS/test-clean"],
        "other": ["LibriTTS/test-other"],
    },
    "dev": {
        "clean": ["LibriTTS/dev-clean"],
        "other": ["LibriTTS/dev-other"],
    },
}

voxceleb_datasets = {
    "voxceleb1": {
        "train": ["VoxCeleb1/wav"],
        "test": ["VoxCeleb1/test_wav"],
    },
    "voxceleb2": {
        "train": ["VoxCeleb2/dev/aac"],
        "test": ["VoxCeleb2/test_wav"],
    },
}

other_datasets = [
    "LJSpeech-1.1",
    "VCTK-Corpus/wav48",
]

# Nationalities whose VoxCeleb speakers are treated as anglophone
anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
encoder/data_objects/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
2
+ from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
encoder/data_objects/random_cycler.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import random


class RandomCycler:
    """
    Creates an internal copy of a sequence and allows access to its items in a constrained random
    order. For a source sequence of n items and one or several consecutive queries of a total
    of m items, the following guarantees hold (one implies the other):
        - Each item will be returned between m // n and ((m - 1) // n) + 1 times.
        - Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
    """

    def __init__(self, source):
        if len(source) == 0:
            raise Exception("Can't create RandomCycler from an empty collection")
        # Full pool of items, plus a shuffled queue of items not yet served this cycle
        self.all_items = list(source)
        self.next_items = []

    def sample(self, count: int):
        """Returns <count> items, subject to the class-level ordering guarantees."""
        def shuffled(seq):
            return random.sample(seq, len(seq))

        out = []
        remaining = count
        while remaining > 0:
            # When a whole cycle (or more) is still needed, serve a full shuffle at once
            if remaining >= len(self.all_items):
                out.extend(shuffled(list(self.all_items)))
                remaining -= len(self.all_items)
                continue
            # Otherwise drain the pending queue, refilling it whenever it runs dry
            take = min(remaining, len(self.next_items))
            out.extend(self.next_items[:take])
            self.next_items = self.next_items[take:]
            remaining -= take
            if len(self.next_items) == 0:
                self.next_items = shuffled(list(self.all_items))
        return out

    def __next__(self):
        return self.sample(1)[0]
37
+
encoder/data_objects/speaker.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from encoder.data_objects.random_cycler import RandomCycler
from encoder.data_objects.utterance import Utterance
from pathlib import Path


class Speaker:
    """Holds the set of preprocessed utterances of a single speaker, loaded lazily from disk."""

    def __init__(self, root: Path):
        self.root = root            # directory containing this speaker's data
        self.name = root.name       # speaker id is the directory name
        self.utterances = None      # filled on first use by _load_utterances()
        self.utterance_cycler = None

    def _load_utterances(self):
        # _sources.txt maps each frames file name to the wav it was computed from
        with self.root.joinpath("_sources.txt").open("r") as sources_file:
            entries = [line.split(",") for line in sources_file]
        mapping = {frames_fname: wave_fpath for frames_fname, wave_fpath in entries}
        self.utterances = [Utterance(self.root.joinpath(frames_fname), wave_fpath)
                           for frames_fname, wave_fpath in mapping.items()]
        self.utterance_cycler = RandomCycler(self.utterances)

    def random_partial(self, count, n_frames):
        """
        Samples a batch of <count> unique partial utterances from the disk in a way that all
        utterances come up at least once every two cycles and in a random order every time.

        :param count: The number of partial utterances to sample from the set of utterances from
        that speaker. Utterances are guaranteed not to be repeated if <count> is not larger than
        the number of utterances available.
        :param n_frames: The number of frames in the partial utterance.
        :return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
        frames are the frames of the partial utterances and range is the range of the partial
        utterance with regard to the complete utterance.
        """
        if self.utterances is None:
            self._load_utterances()

        sampled = self.utterance_cycler.sample(count)
        return [(utterance,) + utterance.random_partial(n_frames) for utterance in sampled]
encoder/data_objects/speaker_batch.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import numpy as np
from typing import List
from encoder.data_objects.speaker import Speaker


class SpeakerBatch:
    """One training batch: <utterances_per_speaker> partial utterances for each speaker."""

    def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
        self.speakers = speakers
        # One list of (utterance, frames, range) tuples per speaker
        self.partials = {speaker: speaker.random_partial(utterances_per_speaker, n_frames)
                         for speaker in speakers}

        # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
        # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
        all_frames = [frames
                      for speaker in speakers
                      for _, frames, _ in self.partials[speaker]]
        self.data = np.array(all_frames)
encoder/data_objects/speaker_verification_dataset.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from encoder.data_objects.random_cycler import RandomCycler
from encoder.data_objects.speaker_batch import SpeakerBatch
from encoder.data_objects.speaker import Speaker
from encoder.params_data import partials_n_frames
from torch.utils.data import Dataset, DataLoader
from pathlib import Path

# TODO: improve with a pool of speakers for data efficiency


class SpeakerVerificationDataset(Dataset):
    """Effectively endless dataset that yields speakers in a constrained random order."""

    def __init__(self, datasets_root: Path):
        self.root = datasets_root
        speaker_dirs = [entry for entry in self.root.glob("*") if entry.is_dir()]
        if not speaker_dirs:
            raise Exception("No speakers found. Make sure you are pointing to the directory "
                            "containing all preprocessed speaker directories.")
        self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
        self.speaker_cycler = RandomCycler(self.speakers)

    def __len__(self):
        # Effectively infinite: sampling is cyclic rather than index-based
        return int(1e10)

    def __getitem__(self, index):
        # The index is ignored; the cycler guarantees balanced speaker coverage
        return next(self.speaker_cycler)

    def get_logs(self):
        """Concatenates the contents of all preprocessing log (.txt) files under the root."""
        log_string = ""
        for log_fpath in self.root.glob("*.txt"):
            with log_fpath.open("r") as log_file:
                log_string += "".join(log_file.readlines())
        return log_string


class SpeakerVerificationDataLoader(DataLoader):
    """DataLoader whose batches are SpeakerBatch objects (speakers x partial utterances)."""

    def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
                 batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
                 worker_init_fn=None):
        self.utterances_per_speaker = utterances_per_speaker

        super().__init__(dataset=dataset,
                         batch_size=speakers_per_batch,
                         shuffle=False,
                         sampler=sampler,
                         batch_sampler=batch_sampler,
                         num_workers=num_workers,
                         collate_fn=self.collate,
                         pin_memory=pin_memory,
                         drop_last=False,
                         timeout=timeout,
                         worker_init_fn=worker_init_fn)

    def collate(self, speakers):
        # Turns the list of Speaker objects yielded by the dataset into one SpeakerBatch
        return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
56
+
encoder/data_objects/utterance.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import numpy as np


class Utterance:
    """A single preprocessed utterance: a saved frames (.npy) file plus its source wav path."""

    def __init__(self, frames_fpath, wave_fpath):
        self.frames_fpath = frames_fpath  # path to the saved mel-frames array
        self.wave_fpath = wave_fpath      # path to the original waveform

    def get_frames(self):
        """Loads the mel frames from disk as a numpy array."""
        return np.load(self.frames_fpath)

    def random_partial(self, n_frames):
        """
        Crops the frames into a partial utterance of n_frames

        :param n_frames: The number of frames of the partial utterance
        :return: the partial utterance frames and a tuple indicating the start and end of the
        partial utterance in the complete utterance.
        """
        frames = self.get_frames()
        total = frames.shape[0]
        # When the utterance is exactly n_frames long there is only one valid crop
        start = 0 if total == n_frames else np.random.randint(0, total - n_frames)
        end = start + n_frames
        return frames[start:end], (start, end)
encoder/inference.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from encoder.params_data import *
from encoder.model import SpeakerEncoder
from encoder.audio import preprocess_wav  # We want to expose this function from here
from encoder import audio
from pathlib import Path
import numpy as np
import torch

# Module-level singletons: the loaded encoder and the device it runs on.
_model = None  # type: SpeakerEncoder
_device = None  # type: torch.device


def load_model(weights_fpath: Path, device=None):
    """
    Loads the model in memory. If this function is not explicitely called, it will be run on the
    first call to embed_frames() with the default weights file.

    :param weights_fpath: the path to saved model weights.
    :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
    model will be loaded and will run on this device. Outputs will however always be on the cpu.
    If None, will default to your GPU if it's available, otherwise your CPU.
    """
    # TODO: I think the slow loading of the encoder might have something to do with the device it
    # was saved on. Worth investigating.
    global _model, _device
    if device is None:
        _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        _device = torch.device(device)
    else:
        # Bug fix: a torch.device instance was previously ignored, leaving _device as None
        _device = device
    _model = SpeakerEncoder(_device, torch.device("cpu"))
    checkpoint = torch.load(weights_fpath, _device)
    _model.load_state_dict(checkpoint["model_state"])
    _model.eval()
    print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath, checkpoint["step"]))


def is_loaded():
    """Returns True once load_model() has been called successfully."""
    return _model is not None


def embed_frames_batch(frames_batch):
    """
    Computes embeddings for a batch of mel spectrogram.

    :param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape
    (batch_size, n_frames, n_channels)
    :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
    """
    if _model is None:
        raise Exception("Model was not loaded. Call load_model() before inference.")

    frames = torch.from_numpy(frames_batch).to(_device)
    embed = _model.forward(frames).detach().cpu().numpy()
    return embed


def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
                           min_pad_coverage=0.75, overlap=0.5):
    """
    Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
    partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
    spectrogram slices are returned, so as to make each partial utterance waveform correspond to
    its spectrogram. This function assumes that the mel spectrogram parameters used are those
    defined in params_data.py.

    The returned ranges may be indexing further than the length of the waveform. It is
    recommended that you pad the waveform with zeros up to wave_slices[-1].stop.

    :param n_samples: the number of samples in the waveform
    :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
    utterance
    :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
    enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
    then the last partial utterance will be considered, as if we padded the audio. Otherwise,
    it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
    utterance, this parameter is ignored so that the function always returns at least 1 slice.
    :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
    utterances are entirely disjoint.
    :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
    respectively the waveform and the mel spectrogram with these slices to obtain the partial
    utterances.
    """
    assert 0 <= overlap < 1
    assert 0 < min_pad_coverage <= 1

    samples_per_frame = int((sampling_rate * mel_window_step / 1000))
    n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
    frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)

    # Compute the slices
    wav_slices, mel_slices = [], []
    steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
    for i in range(0, steps, frame_step):
        mel_range = np.array([i, i + partial_utterance_n_frames])
        wav_range = mel_range * samples_per_frame
        mel_slices.append(slice(*mel_range))
        wav_slices.append(slice(*wav_range))

    # Evaluate whether extra padding is warranted or not
    last_wav_range = wav_slices[-1]
    coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
    if coverage < min_pad_coverage and len(mel_slices) > 1:
        mel_slices = mel_slices[:-1]
        wav_slices = wav_slices[:-1]

    return wav_slices, mel_slices


def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
    """
    Computes an embedding for a single utterance.

    # TODO: handle multiple wavs to benefit from batching on GPU
    :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
    :param using_partials: if True, then the utterance is split in partial utterances of
    <partial_utterance_n_frames> frames and the utterance embedding is computed from their
    normalized average. If False, the utterance is instead computed from feeding the entire
    spectogram to the network.
    :param return_partials: if True, the partial embeddings will also be returned along with the
    wav slices that correspond to the partial embeddings.
    :param kwargs: additional arguments to compute_partial_splits()
    :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
    <return_partials> is True, the partial utterances as a numpy array of float32 of shape
    (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
    returned. If <using_partials> is simultaneously set to False, both these values will be None
    instead.
    """
    # Process the entire utterance if not using partials
    if not using_partials:
        frames = audio.wav_to_mel_spectrogram(wav)
        embed = embed_frames_batch(frames[None, ...])[0]
        if return_partials:
            return embed, None, None
        return embed

    # Compute where to split the utterance into partials and pad if necessary
    wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
    max_wave_length = wave_slices[-1].stop
    if max_wave_length >= len(wav):
        wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")

    # Split the utterance into partials
    frames = audio.wav_to_mel_spectrogram(wav)
    frames_batch = np.array([frames[s] for s in mel_slices])
    partial_embeds = embed_frames_batch(frames_batch)

    # Compute the utterance embedding from the partial embeddings
    raw_embed = np.mean(partial_embeds, axis=0)
    embed = raw_embed / np.linalg.norm(raw_embed, 2)

    if return_partials:
        return embed, partial_embeds, wave_slices
    return embed


def embed_speaker(wavs, **kwargs):
    # Bug fix: `raise NotImplemented()` raised a TypeError because NotImplemented
    # is a sentinel value, not an exception class.
    raise NotImplementedError()


def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
    """Draws <embed>, reshaped to <shape>, as a heatmap on <ax> (or the current axes)."""
    # matplotlib is imported lazily so that plain inference does not require it
    import matplotlib.pyplot as plt
    from matplotlib import cm
    if ax is None:
        ax = plt.gca()

    if shape is None:
        # Default to a roughly square layout
        height = int(np.sqrt(len(embed)))
        shape = (height, -1)
    embed = embed.reshape(shape)

    cmap = cm.get_cmap()
    mappable = ax.imshow(embed, cmap=cmap)
    cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
    sm = cm.ScalarMappable(cmap=cmap)
    sm.set_clim(*color_range)

    ax.set_xticks([]), ax.set_yticks([])
    ax.set_title(title)
encoder/model.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from encoder.params_model import *
2
+ from encoder.params_data import *
3
+ from scipy.interpolate import interp1d
4
+ from sklearn.metrics import roc_curve
5
+ from torch.nn.utils import clip_grad_norm_
6
+ from scipy.optimize import brentq
7
+ from torch import nn
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
class SpeakerEncoder(nn.Module):
    """GE2E speaker encoder: maps batches of utterance mel spectrograms to
    fixed-size, L2-normalized speaker embeddings, and implements the GE2E
    softmax loss of section 2.1 of the paper.
    """

    def __init__(self, device, loss_device):
        """
        :param device: device on which the forward pass runs (typically cuda)
        :param loss_device: device on which the similarity matrix and loss are
        computed (typically cpu; see the FIXME in encoder/train.py about
        gradients being None when loss_device is cuda)
        """
        super().__init__()
        self.loss_device = loss_device

        # Network definition
        self.lstm = nn.LSTM(input_size=mel_n_channels,
                            hidden_size=model_hidden_size,
                            num_layers=model_num_layers,
                            batch_first=True).to(device)
        self.linear = nn.Linear(in_features=model_hidden_size,
                                out_features=model_embedding_size).to(device)
        self.relu = torch.nn.ReLU().to(device)

        # Cosine similarity scaling (with fixed initial parameter values).
        # NOTE(review): .to(loss_device) on an nn.Parameter returns a copy;
        # kept as-is to preserve checkpoint compatibility (see train.py FIXME).
        self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
        self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)

        # Loss
        self.loss_fn = nn.CrossEntropyLoss().to(loss_device)

    def do_gradient_ops(self):
        """Scale the similarity parameters' gradients down and clip all
        gradients. Call after backward() and before the optimizer step."""
        # Gradient scale
        self.similarity_weight.grad *= 0.01
        self.similarity_bias.grad *= 0.01

        # Gradient clipping
        clip_grad_norm_(self.parameters(), 3, norm_type=2)

    def forward(self, utterances, hidden_init=None):
        """
        Computes the embeddings of a batch of utterance spectrograms.

        :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
        (batch_size, n_frames, n_channels)
        :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
        batch_size, hidden_size). Will default to a tensor of zeros if None.
        :return: the embeddings as a tensor of shape (batch_size, embedding_size)
        """
        # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
        # and the final cell state.
        out, (hidden, cell) = self.lstm(utterances, hidden_init)

        # We take only the hidden state of the last layer
        embeds_raw = self.relu(self.linear(hidden[-1]))

        # L2-normalize it (epsilon guards against division by zero)
        embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5)

        return embeds

    def similarity_matrix(self, embeds):
        """
        Computes the similarity matrix according the section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the similarity matrix as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, speakers_per_batch)
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
        centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
        centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)

        # Exclusive centroids (1 per utterance): the speaker centroid computed
        # without the utterance itself, to avoid trivial self-similarity.
        centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
        centroids_excl /= (utterances_per_speaker - 1)
        centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5)

        # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
        # product of these vectors (which is just an element-wise multiplication reduced by a sum).
        # We vectorize the computation for efficiency.
        sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
                                 speakers_per_batch).to(self.loss_device)
        # np.int was removed in NumPy 1.20+/1.24; the builtin int is the replacement.
        mask_matrix = 1 - np.eye(speakers_per_batch, dtype=int)
        for j in range(speakers_per_batch):
            mask = np.where(mask_matrix[j])[0]
            sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
            sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)

        # Learned affine scaling of the cosine similarities (GE2E eq. 5)
        sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
        return sim_matrix

    def loss(self, embeds):
        """
        Computes the softmax loss according the section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the loss and the EER for this batch of embeddings.
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Loss: each utterance should be classified to its own speaker column
        sim_matrix = self.similarity_matrix(embeds)
        sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
                                         speakers_per_batch))
        ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
        target = torch.from_numpy(ground_truth).long().to(self.loss_device)
        loss = self.loss_fn(sim_matrix, target)

        # EER (not backpropagated)
        with torch.no_grad():
            # One-hot row for speaker i; np.int removed in NumPy 1.20+, use int.
            inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=int)[0]
            labels = np.array([inv_argmax(i) for i in ground_truth])
            preds = sim_matrix.detach().cpu().numpy()

            # Snippet from https://yangcha.github.io/EER-ROC/
            fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
            eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)

        return loss, eer
encoder/params_data.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
## Mel-filterbank
mel_window_length = 25  # In milliseconds
mel_window_step = 10    # In milliseconds
mel_n_channels = 40     # Number of mel frequency channels per frame


## Audio
sampling_rate = 16000   # In Hz
# Number of spectrogram frames in a partial utterance
partials_n_frames = 160     # 1600 ms
# Number of spectrogram frames at inference
inference_n_frames = 80     #  800 ms


## Voice Activation Detection
# Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
# This sets the granularity of the VAD. Should not need to be changed.
vad_window_length = 30  # In milliseconds
# Number of frames to average together when performing the moving average smoothing.
# The larger this value, the larger the VAD variations must be to not get smoothed out.
vad_moving_average_width = 8
# Maximum number of consecutive silent frames a segment can have.
vad_max_silence_length = 6


## Audio volume normalization
audio_norm_target_dBFS = -30    # Target loudness (dBFS) for waveform normalization
29
+
encoder/params_model.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
## Model parameters
model_hidden_size = 256     # LSTM hidden state size
model_embedding_size = 256  # Dimensionality of the output speaker embedding
model_num_layers = 3        # Number of stacked LSTM layers


## Training parameters
learning_rate_init = 1e-4       # Initial Adam learning rate
speakers_per_batch = 64         # Distinct speakers sampled per training batch
utterances_per_speaker = 10     # Utterances sampled per speaker in a batch
encoder/preprocess.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from functools import partial
3
+ from multiprocessing import Pool
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+
9
+ from encoder import audio
10
+ from encoder.config import librispeech_datasets, anglophone_nationalites, persian_datasets
11
+ from encoder.params_data import *
12
+
13
+
14
# Audio file extensions searched for when gathering a speaker's utterances.
_AUDIO_EXTENSIONS = ("wav", "flac", "m4a", "mp3")
15
+
16
class DatasetLog:
    """Writes metadata and per-sample statistics about a preprocessed dataset
    to a text log file under the output directory.
    """

    def __init__(self, root, name):
        # One log file per dataset; slashes in the name are flattened.
        self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
        self.sample_data = dict()

        start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
        self.write_line("Creating dataset %s on %s" % (name, start_time))
        self.write_line("-----")
        self._log_params()

    def _log_params(self):
        # Record every public parameter of encoder.params_data for reproducibility.
        from encoder import params_data
        self.write_line("Parameter values:")
        public_params = (p for p in dir(params_data) if not p.startswith("__"))
        for param_name in public_params:
            self.write_line("\t%s: %s" % (param_name, getattr(params_data, param_name)))
        self.write_line("-----")

    def write_line(self, line):
        """Append one line of text to the log file."""
        self.text_file.write("%s\n" % line)

    def add_sample(self, **kwargs):
        """Accumulate one value per named statistic, e.g. add_sample(duration=1.2)."""
        for param_name, value in kwargs.items():
            self.sample_data.setdefault(param_name, []).append(value)

    def finalize(self):
        """Write summary statistics for every tracked quantity and close the log."""
        self.write_line("Statistics:")
        for param_name, values in self.sample_data.items():
            self.write_line("\t%s:" % param_name)
            self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
            self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
        self.write_line("-----")
        end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
        self.write_line("Finished on %s" % end_time)
        self.text_file.close()
56
+
57
+
58
def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
    """Locate a dataset under <datasets_root> and open a log file for it.

    :return: (dataset_root, logger), or (None, None) when the dataset
    directory does not exist on disk.
    """
    dataset_root = datasets_root.joinpath(dataset_name)
    if dataset_root.exists():
        return dataset_root, DatasetLog(out_dir, dataset_name)
    print("Couldn't find %s, skipping this dataset." % dataset_root)
    return None, None
64
+
65
+
66
def _preprocess_speaker(speaker_dir: Path, datasets_root: Path, out_dir: Path, skip_existing: bool):
    """Preprocess every audio file of one speaker into mel spectrogram .npy files.

    :param speaker_dir: directory containing the speaker's audio files
    :param datasets_root: root of all datasets, used to derive the speaker name
    :param out_dir: root output directory of the preprocessed data
    :param skip_existing: if True, utterances already listed in the speaker's
    _sources.txt are not processed again
    :return: list of durations (in seconds) of the utterances that were processed
    """
    # Give a name to the speaker that includes its dataset
    speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)

    # Create an output directory with that name, as well as a txt file containing a
    # reference to each source file.
    speaker_out_dir = out_dir.joinpath(speaker_name)
    speaker_out_dir.mkdir(exist_ok=True)
    sources_fpath = speaker_out_dir.joinpath("_sources.txt")

    # The preprocessing may have been interrupted earlier: collect the names of
    # the utterances that were already processed so they can be skipped.
    # (Use a set, not a dict literal, and catch only I/O errors instead of a
    # bare except that would also swallow KeyboardInterrupt.)
    existing_fnames = set()
    if sources_fpath.exists():
        try:
            with sources_fpath.open("r") as sources_file:
                existing_fnames = {line.split(",")[0] for line in sources_file}
        except OSError:
            existing_fnames = set()

    # Gather all audio files for that speaker recursively; the with-statement
    # guarantees the sources file is closed even if preprocessing raises.
    audio_durs = []
    with sources_fpath.open("a" if skip_existing else "w") as sources_file:
        for extension in _AUDIO_EXTENSIONS:
            for in_fpath in speaker_dir.glob("**/*.%s" % extension):
                # Check if the target output file already exists
                out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
                out_fname = out_fname.replace(".%s" % extension, ".npy")
                if skip_existing and out_fname in existing_fnames:
                    continue

                # Load and preprocess the waveform
                wav = audio.preprocess_wav(in_fpath)
                if len(wav) == 0:
                    continue

                # Create the mel spectrogram, discard those that are too short
                frames = audio.wav_to_mel_spectrogram(wav)
                if len(frames) < partials_n_frames:
                    continue

                out_fpath = speaker_out_dir.joinpath(out_fname)
                np.save(out_fpath, frames)
                sources_file.write("%s,%s\n" % (out_fname, in_fpath))
                audio_durs.append(len(wav) / sampling_rate)

    return audio_durs
116
+
117
+
118
def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger):
    """Run _preprocess_speaker over every speaker directory using a worker pool.

    The per-utterance durations returned by the workers are forwarded to
    <logger>, which is finalized once all speakers are done.
    """
    print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))

    # Fix the common arguments so the pool only passes each speaker directory.
    job = partial(_preprocess_speaker, datasets_root=datasets_root, out_dir=out_dir,
                  skip_existing=skip_existing)
    with Pool(4) as pool:
        results = pool.imap(job, speaker_dirs)
        for durations in tqdm(results, dataset_name, len(speaker_dirs), unit="speakers"):
            for duration in durations:
                logger.add_sample(duration=duration)

    logger.finalize()
    print("Done preprocessing %s.\n" % dataset_name)
131
+
132
+
133
def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
    """Preprocess the LibriSpeech train-other subsets for the speaker encoder.

    :param datasets_root: directory containing the LibriSpeech dataset
    :param out_dir: output directory for the preprocessed data
    :param skip_existing: skip utterances that were already preprocessed
    """
    for dataset_name in librispeech_datasets["train"]["other"]:
        # Initialize the preprocessing
        dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
        if not dataset_root:
            # A missing dataset is skipped, as the message printed by
            # _init_preprocess_dataset says. (Previously this was `return`,
            # which silently dropped all remaining datasets.)
            continue

        # Preprocess all speakers
        speaker_dirs = list(dataset_root.glob("*"))
        _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
143
+
144
+
145
+
146
def preprocess_persian(datasets_root: Path, out_dir: Path, skip_existing=False):
    """Preprocess the Persian datasets listed in encoder.config for the speaker encoder.

    :param datasets_root: directory containing the Persian dataset(s)
    :param out_dir: output directory for the preprocessed data
    :param skip_existing: skip utterances that were already preprocessed
    """
    for dataset_name in persian_datasets["train"]["data"]:
        # Initialize the preprocessing
        dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
        if not dataset_root:
            # Skip only the missing dataset instead of aborting the whole run
            # (the original `return` contradicted the printed "skipping" message).
            continue
        # Preprocess all speakers
        speaker_dirs = list(dataset_root.glob("*"))
        _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
155
+
156
+
157
def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
    """Preprocess the (presumed anglophone) speakers of VoxCeleb1.

    Speakers are filtered by the nationality column of vox1_meta.csv against
    the anglophone nationality list from encoder.config.
    """
    # Initialize the preprocessing
    dataset_name = "VoxCeleb1"
    dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
    if not dataset_root:
        return

    # Read the metadata file, dropping the header row
    with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
        metadata = [line.split("\t") for line in metafile][1:]

    # Map speaker ID -> nationality and keep only the anglophone speakers
    nationalities = {fields[0]: fields[3] for fields in metadata}
    keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items()
                        if nationality.lower() in anglophone_nationalites]
    print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
          (len(keep_speaker_ids), len(nationalities)))

    # Keep only the on-disk speaker directories whose ID passed the filter
    speaker_dirs = [speaker_dir for speaker_dir in dataset_root.joinpath("wav").glob("*")
                    if speaker_dir.name in keep_speaker_ids]
    print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." %
          (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))

    # Preprocess all speakers
    _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
184
+
185
+
186
def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
    """Preprocess the VoxCeleb2 dev set (dev/aac) for the speaker encoder.

    :param datasets_root: directory containing the VoxCeleb2 dataset
    :param out_dir: output directory for the preprocessed data
    :param skip_existing: skip utterances that were already preprocessed
    """
    # Initialize the preprocessing
    dataset_name = "VoxCeleb2"
    dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
    if not dataset_root:
        return

    # Get the speaker directories
    # Preprocess all speakers
    speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
    _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
encoder/train.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import torch
4
+
5
+ from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
6
+ from encoder.model import SpeakerEncoder
7
+ from encoder.params_model import *
8
+ from encoder.visualizations import Visualizations
9
+ from utils.profiler import Profiler
10
+
11
+
12
def sync(device: torch.device):
    """Block until all queued work on <device> has completed.

    CUDA operations run asynchronously, so this is required for accurate
    profiling. It is a no-op for CPU devices.
    """
    if device.type != "cuda":
        return
    torch.cuda.synchronize(device)
16
+
17
+
18
def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
          backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
          no_visdom: bool):
    """Train the speaker encoder, resuming from an existing checkpoint if one
    exists for this run ID (unless force_restart is set).

    :param run_id: name of the run; checkpoints go to <models_dir>/<run_id>/
    :param clean_data_root: output directory of encoder_preprocess.py
    :param models_dir: root directory for all saved models
    :param umap_every: steps between UMAP projection snapshots (0 disables)
    :param save_every: steps between checkpoint overwrites (0 disables)
    :param backup_every: steps between numbered checkpoint backups (0 disables)
    :param vis_every: steps between visdom plot updates
    :param force_restart: ignore any existing checkpoint and start from scratch
    :param visdom_server: URL of the visdom server
    :param no_visdom: disable visdom entirely
    """
    # Create a dataset and a dataloader
    dataset = SpeakerVerificationDataset(clean_data_root)
    loader = SpeakerVerificationDataLoader(
        dataset,
        speakers_per_batch,
        utterances_per_speaker,
        num_workers=4,
    )

    # Setup the device on which to run the forward pass and the loss. These can be different,
    # because the forward pass is faster on the GPU whereas the loss is often (depending on your
    # hyperparameters) faster on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # FIXME: currently, the gradient is None if loss_device is cuda
    loss_device = torch.device("cpu")

    # Create the model and the optimizer
    model = SpeakerEncoder(device, loss_device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
    init_step = 1

    # Configure file path for the model
    model_dir = models_dir / run_id
    model_dir.mkdir(exist_ok=True, parents=True)
    state_fpath = model_dir / "encoder.pt"

    # Load any existing model; the checkpoint stores the step counter, the
    # model weights and the optimizer state.
    if not force_restart:
        if state_fpath.exists():
            print("Found existing model \"%s\", loading it and resuming training." % run_id)
            checkpoint = torch.load(state_fpath)
            init_step = checkpoint["step"]
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            # The learning rate is reset to the configured initial value on resume.
            optimizer.param_groups[0]["lr"] = learning_rate_init
        else:
            print("No model \"%s\" found, starting training from scratch." % run_id)
    else:
        print("Starting the training from scratch.")
    model.train()

    # Initialize the visualization environment
    vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
    vis.log_dataset(dataset)
    vis.log_params()
    device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
    vis.log_implementation({"Device": device_name})

    # Training loop (runs indefinitely; the dataloader yields batches forever)
    profiler = Profiler(summarize_every=10, disabled=False)
    for step, speaker_batch in enumerate(loader, init_step):
        profiler.tick("Blocking, waiting for batch (threaded)")

        # Forward pass
        inputs = torch.from_numpy(speaker_batch.data).to(device)
        sync(device)
        profiler.tick("Data to %s" % device)
        embeds = model(inputs)
        sync(device)
        profiler.tick("Forward pass")
        # Reshape flat embeddings to (speakers, utterances, embed) for the GE2E loss
        embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
        loss, eer = model.loss(embeds_loss)
        sync(loss_device)
        profiler.tick("Loss")

        # Backward pass
        model.zero_grad()
        loss.backward()
        profiler.tick("Backward pass")
        # Scales the similarity parameters' gradients and clips all gradients
        model.do_gradient_ops()
        optimizer.step()
        profiler.tick("Parameter update")

        # Update visualizations
        # learning_rate = optimizer.param_groups[0]["lr"]
        vis.update(loss.item(), eer, step)

        # Draw projections and save them to the backup folder
        if umap_every != 0 and step % umap_every == 0:
            print("Drawing and saving projections (step %d)" % step)
            projection_fpath = model_dir / f"umap_{step:06d}.png"
            embeds = embeds.detach().cpu().numpy()
            vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
            vis.save()

        # Overwrite the latest version of the model
        if save_every != 0 and step % save_every == 0:
            print("Saving the model (step %d)" % step)
            torch.save({
                "step": step + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
            }, state_fpath)

        # Make a backup
        if backup_every != 0 and step % backup_every == 0:
            print("Making a backup (step %d)" % step)
            backup_fpath = model_dir / f"encoder_{step:06d}.bak"
            torch.save({
                "step": step + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
            }, backup_fpath)

        profiler.tick("Extras (visualizations, saving)")
encoder/visualizations.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from time import perf_counter as timer
3
+
4
+ import numpy as np
5
+ import umap
6
+ import visdom
7
+
8
+ from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
9
+
10
+
11
# Distinct RGB colors (normalized to [0, 1]) used to tell speakers apart in the
# UMAP projections. np.float was removed in NumPy 1.20+/1.24; the builtin float
# is the documented replacement (same float64 dtype).
colormap = np.array([
    [76, 255, 0],
    [0, 127, 70],
    [255, 0, 0],
    [255, 217, 38],
    [0, 135, 255],
    [165, 0, 165],
    [255, 167, 255],
    [0, 255, 255],
    [255, 96, 38],
    [142, 76, 0],
    [33, 0, 127],
    [0, 0, 0],
    [183, 183, 183],
], dtype=float) / 255
26
+
27
+
28
class Visualizations:
    """Tracks training metrics (loss, EER, step time) and pushes plots and
    projections to a visdom server. When disabled, metrics are still tracked
    and printed to stdout, but nothing is sent to visdom.
    """

    def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
        """
        :param env_name: visdom environment name; defaults to a timestamp
        :param update_every: number of steps between plot refreshes
        :param server: URL of the visdom server
        :param disabled: if True, skip connecting to visdom entirely
        """
        # Tracking data
        self.last_update_timestamp = timer()
        self.update_every = update_every
        self.step_times = []
        self.losses = []
        self.eers = []
        print("Updating the visualizations every %d steps." % update_every)

        # If visdom is disabled TODO: use a better paradigm for that
        self.disabled = disabled
        if self.disabled:
            return

        # Set the environment name
        now = str(datetime.now().strftime("%d-%m %Hh%M"))
        if env_name is None:
            self.env_name = now
        else:
            self.env_name = "%s (%s)" % (env_name, now)

        # Connect to visdom and open the corresponding window in the browser
        try:
            self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
        except ConnectionError:
            raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
                            "start it.")
        # webbrowser.open("http://localhost:8097/env/" + self.env_name)

        # Create the windows (window handles are filled in lazily on first plot)
        self.loss_win = None
        self.eer_win = None
        # self.lr_win = None
        self.implementation_win = None
        self.projection_win = None
        self.implementation_string = ""

    def log_params(self):
        """Publish all model and data parameters as a text panel in visdom."""
        if self.disabled:
            return
        from encoder import params_data
        from encoder import params_model
        param_string = "<b>Model parameters</b>:<br>"
        for param_name in (p for p in dir(params_model) if not p.startswith("__")):
            value = getattr(params_model, param_name)
            param_string += "\t%s: %s<br>" % (param_name, value)
        param_string += "<b>Data parameters</b>:<br>"
        for param_name in (p for p in dir(params_data) if not p.startswith("__")):
            value = getattr(params_data, param_name)
            param_string += "\t%s: %s<br>" % (param_name, value)
        self.vis.text(param_string, opts={"title": "Parameters"})

    def log_dataset(self, dataset: SpeakerVerificationDataset):
        """Publish the dataset's speaker count and preprocessing logs in visdom."""
        if self.disabled:
            return
        dataset_string = ""
        dataset_string += "<b>Speakers</b>: %s\n" % len(dataset.speakers)
        dataset_string += "\n" + dataset.get_logs()
        dataset_string = dataset_string.replace("\n", "<br>")
        self.vis.text(dataset_string, opts={"title": "Dataset"})

    def log_implementation(self, params):
        """Publish implementation details (e.g. device name) as a text panel.

        :param params: dict of name -> value pairs to display
        """
        if self.disabled:
            return
        implementation_string = ""
        for param, value in params.items():
            implementation_string += "<b>%s</b>: %s\n" % (param, value)
        implementation_string = implementation_string.replace("\n", "<br>")
        self.implementation_string = implementation_string
        self.implementation_win = self.vis.text(
            implementation_string,
            opts={"title": "Training implementation"}
        )

    def update(self, loss, eer, step):
        """Record one training step; refresh the plots every <update_every> steps.

        :param loss: scalar loss of this step
        :param eer: equal error rate of this step
        :param step: global step counter
        """
        # Update the tracking data
        now = timer()
        self.step_times.append(1000 * (now - self.last_update_timestamp))
        self.last_update_timestamp = now
        self.losses.append(loss)
        self.eers.append(eer)
        print(".", end="")

        # Update the plots every <update_every> steps
        if step % self.update_every != 0:
            return
        time_string = "Step time: mean: %5dms std: %5dms" % \
                      (int(np.mean(self.step_times)), int(np.std(self.step_times)))
        print("\nStep %6d Loss: %.4f EER: %.4f %s" %
              (step, np.mean(self.losses), np.mean(self.eers), time_string))
        if not self.disabled:
            # First call creates the window (update=None); later calls append.
            self.loss_win = self.vis.line(
                [np.mean(self.losses)],
                [step],
                win=self.loss_win,
                update="append" if self.loss_win else None,
                opts=dict(
                    legend=["Avg. loss"],
                    xlabel="Step",
                    ylabel="Loss",
                    title="Loss",
                )
            )
            self.eer_win = self.vis.line(
                [np.mean(self.eers)],
                [step],
                win=self.eer_win,
                update="append" if self.eer_win else None,
                opts=dict(
                    legend=["Avg. EER"],
                    xlabel="Step",
                    ylabel="EER",
                    title="Equal error rate"
                )
            )
            if self.implementation_win is not None:
                self.vis.text(
                    self.implementation_string + ("<b>%s</b>" % time_string),
                    win=self.implementation_win,
                    opts={"title": "Training implementation"},
                )

        # Reset the tracking so each refresh shows averages of the last window
        self.losses.clear()
        self.eers.clear()
        self.step_times.clear()

    def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None, max_speakers=10):
        """Project embeddings to 2D with UMAP, colored per speaker, and show/save the plot.

        :param embeds: embeddings array ordered speaker-by-speaker
        :param utterances_per_speaker: utterances per speaker in <embeds>
        :param step: global step, shown in the plot title
        :param out_fpath: if given, the figure is also saved to this path
        :param max_speakers: cap on how many speakers to draw (limited by colormap size)
        """
        import matplotlib.pyplot as plt

        max_speakers = min(max_speakers, len(colormap))
        embeds = embeds[:max_speakers * utterances_per_speaker]

        n_speakers = len(embeds) // utterances_per_speaker
        ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
        colors = [colormap[i] for i in ground_truth]

        reducer = umap.UMAP()
        projected = reducer.fit_transform(embeds)
        plt.scatter(projected[:, 0], projected[:, 1], c=colors)
        plt.gca().set_aspect("equal", "datalim")
        plt.title("UMAP projection (step %d)" % step)
        if not self.disabled:
            self.projection_win = self.vis.matplot(plt, win=self.projection_win)
        if out_fpath is not None:
            plt.savefig(out_fpath)
        plt.clf()

    def save(self):
        """Persist the visdom environment server-side (no-op when disabled)."""
        if not self.disabled:
            self.vis.save([self.env_name])
encoder_preprocess.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from encoder.preprocess import preprocess_persian
2
+ from utils.argutils import print_args
3
+ from pathlib import Path
4
+ import argparse
5
+
6
+
7
if __name__ == "__main__":
    class MyFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
        # Shows argument defaults while keeping the description formatting verbatim.
        pass

    parser = argparse.ArgumentParser(
        description="Preprocesses audio files from datasets, encodes them as mel spectrograms and "
                    "writes them to the disk. This will allow you to train the encoder. The "
                    "datasets required are at least one of VoxCeleb1, VoxCeleb2 and LibriSpeech. "
                    "Ideally, you should have all three. You should extract them as they are "
                    "after having downloaded them and put them in a same directory, e.g.:\n"
                    "-[datasets_root]\n"
                    "  -LibriSpeech\n"
                    "    -train-other-500\n"
                    "  -VoxCeleb1\n"
                    "    -wav\n"
                    "    -vox1_meta.csv\n"
                    "  -VoxCeleb2\n"
                    "    -dev",
        formatter_class=MyFormatter
    )
    parser.add_argument("datasets_root", type=Path, help=\
        "Path to the directory containing your LibriSpeech/TTS and VoxCeleb datasets.")
    parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
        "Path to the output directory that will contain the mel spectrograms. If left out, "
        "defaults to <datasets_root>/SV2TTS/encoder/")
    # The previous default ("librispeech_other,voxceleb1,voxceleb2") named datasets
    # that have no entry in preprocess_func below, so running with defaults raised
    # a KeyError. Default to the one dataset this script actually supports.
    parser.add_argument("-d", "--datasets", type=str,
                        default="persian_data", help=\
        "Comma-separated list of the name of the datasets you want to preprocess. Only the train "
        "set of these datasets will be used. Possible names: persian_data.")
    parser.add_argument("-s", "--skip_existing", action="store_true", help=\
        "Whether to skip existing output files with the same name. Useful if this script was "
        "interrupted.")
    parser.add_argument("--no_trim", action="store_true", help=\
        "Preprocess audio without trimming silences (not recommended).")
    args = parser.parse_args()

    # Verify webrtcvad is available (only ImportError is expected here; a bare
    # except would also have swallowed KeyboardInterrupt).
    if not args.no_trim:
        try:
            import webrtcvad
        except ImportError:
            raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables "
                "noise removal and is recommended. Please install and try again. If installation fails, "
                "use --no_trim to disable this error message.")
    del args.no_trim

    # Process the arguments
    args.datasets = args.datasets.split(",")
    if not hasattr(args, "out_dir"):
        args.out_dir = args.datasets_root.joinpath("SV2TTS", "encoder")
    assert args.datasets_root.exists()
    args.out_dir.mkdir(exist_ok=True, parents=True)

    # Preprocess the datasets
    print_args(args, parser)
    preprocess_func = {
        "persian_data": preprocess_persian
    }
    args = vars(args)
    for dataset in args.pop("datasets"):
        # Skip unknown dataset names with a clear message instead of a KeyError.
        if dataset not in preprocess_func:
            print("Unknown dataset %s, expected one of: %s. Skipping." %
                  (dataset, ", ".join(preprocess_func)))
            continue
        print("Preprocessing %s" % dataset)
        preprocess_func[dataset](**args)
encoder_train.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils.argutils import print_args
2
+ from encoder.train import train
3
+ from pathlib import Path
4
+ import argparse
5
+
6
+
7
if __name__ == "__main__":
    # Command-line entry point: parse the training options and hand them to
    # encoder.train.train as keyword arguments (names must match its signature).
    parser = argparse.ArgumentParser(
        description="Trains the speaker encoder. You must have run encoder_preprocess.py first.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument("run_id", type=str, help= \
        "Name for this model. By default, training outputs will be stored to saved_models/<run_id>/. If a model state "
        "from the same run ID was previously saved, the training will restart from there. Pass -f to overwrite saved "
        "states and restart from scratch.")
    parser.add_argument("clean_data_root", type=Path, help= \
        "Path to the output directory of encoder_preprocess.py. If you left the default "
        "output directory when preprocessing, it should be <datasets_root>/SV2TTS/encoder/.")
    parser.add_argument("-m", "--models_dir", type=Path, default="saved_models", help=\
        "Path to the root directory that contains all models. A directory <run_name> will be created under this root."
        "It will contain the saved model weights, as well as backups of those weights and plots generated during "
        "training.")
    parser.add_argument("-v", "--vis_every", type=int, default=10, help= \
        "Number of steps between updates of the loss and the plots.")
    parser.add_argument("-u", "--umap_every", type=int, default=500, help= \
        "Number of steps between updates of the umap projection. Set to 0 to never update the "
        "projections.")
    parser.add_argument("-s", "--save_every", type=int, default=500, help= \
        "Number of steps between updates of the model on the disk. Set to 0 to never save the "
        "model.")
    parser.add_argument("-b", "--backup_every", type=int, default=10000, help= \
        "Number of steps between backups of the model. Set to 0 to never make backups of the "
        "model.")
    parser.add_argument("-f", "--force_restart", action="store_true", help= \
        "Do not load any saved model.")
    parser.add_argument("--visdom_server", type=str, default="http://localhost")
    parser.add_argument("--no_visdom", action="store_true", help= \
        "Disable visdom.")
    args = parser.parse_args()

    # Run the training
    print_args(args, parser)
    train(**vars(args))
45
+
inference.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import sys
4
+
5
+ from encoder import inference as encoder
6
+ from synthesizer.inference import Synthesizer
7
+ from vocoder import inference as vocoder_wavrnn
8
+ from parallel_wavegan.utils import load_model as vocoder_hifigan
9
+
10
+ import soundfile as sf
11
+ import os
12
+ import argparse
13
+
14
+
15
+ main_path = os.getcwd()
16
+ models_path = os.path.join(main_path, 'saved_models/final_models/')
17
+
18
def wavRNN_infer(text, ref_wav_path, test_name):
    """Synthesize `text` in the reference speaker's voice with the WavRNN
    vocoder and write the result to results/<test_name>.wav.

    Args:
        text: input text (phoneme string) to synthesize.
        ref_wav_path: reference wav filename, resolved relative to a
            hard-coded dataset folder (see NOTE below).
        test_name: basename for the output wav file.
    """
    # Load the three pretrained stages from saved_models/final_models/.
    encoder.load_model(os.path.join(models_path, 'encoder.pt'))
    synthesizer = Synthesizer(os.path.join(models_path, 'synthesizer.pt'))
    vocoder_wavrnn.load_model(os.path.join(models_path, 'vocoder_WavRNN.pt'))

    # NOTE(review): unlike hifigan_infer, the reference wav is looked up
    # inside the training dataset folder — confirm this is intentional.
    ref_wav_path = os.path.join(main_path, 'dataset/persian_data/train_data/book-1/', ref_wav_path)  # reference wav
    wav = Synthesizer.load_preprocess_wav(ref_wav_path)

    # Speaker embedding of the reference utterance.
    encoder_wav = encoder.preprocess_wav(wav)
    embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)

    # Mel spectrograms for each text (a single one here), concatenated
    # along the time axis; `breaks` records each spectrogram's length.
    texts = [text]
    embeds = [embed] * len(texts)
    specs = synthesizer.synthesize_spectrograms(texts, embeds)
    breaks = [spec.shape[1] for spec in specs]
    spec = np.concatenate(specs, axis=1)

    # Vocode the whole spectrogram, then cut the waveform back at the
    # spectrogram boundaries and insert 150 ms of silence between parts.
    wav = vocoder_wavrnn.infer_waveform(spec)
    b_ends = np.cumsum(np.array(breaks) * Synthesizer.hparams.hop_size)
    b_starts = np.concatenate(([0], b_ends[:-1]))
    wavs = [wav[start:end] for start, end, in zip(b_starts, b_ends)]
    breaks = [np.zeros(int(0.15 * Synthesizer.sample_rate))] * len(breaks)
    wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)])
    # Peak-normalize to 0.97 of full scale.
    wav = wav / np.abs(wav).max() * 0.97

    res_path = os.path.join(main_path, 'results/', test_name+".wav")
    sf.write(res_path, wav, Synthesizer.sample_rate)
    print('\nwav file is saved.')
46
+
47
+
48
def hifigan_infer(text, ref_wav_path, test_name):
    """Synthesize `text` in the reference speaker's voice with the HiFiGAN
    vocoder and write the result to results/<test_name>.wav.

    Args:
        text: input text (phoneme string) to synthesize.
        ref_wav_path: path to the reference wav (used as-is, unlike
            wavRNN_infer which resolves it inside the dataset folder).
        test_name: basename for the output wav file.
    """
    # Load encoder, synthesizer and HiFiGAN vocoder; inference runs on CPU.
    encoder.load_model(os.path.join(models_path, 'encoder.pt'))
    synthesizer = Synthesizer(os.path.join(models_path, 'synthesizer.pt'))
    vocoder = vocoder_hifigan(os.path.join(models_path, 'vocoder_HiFiGAN.pkl'))
    vocoder.remove_weight_norm()
    vocoder = vocoder.eval().to('cpu')

    wav = Synthesizer.load_preprocess_wav(ref_wav_path)

    # Speaker embedding of the reference utterance.
    encoder_wav = encoder.preprocess_wav(wav)
    embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)

    # Synthesize the mel spectrogram(s) and feed them to the vocoder as
    # a (frames, mels) tensor.
    texts = [text]
    embeds = [embed] * len(texts)
    specs = synthesizer.synthesize_spectrograms(texts, embeds)
    spec = np.concatenate(specs, axis=1)
    x = torch.from_numpy(spec.T).to('cpu')

    with torch.no_grad():
        wav = vocoder.inference(x)
    # Peak-normalize to 0.97 of full scale.
    wav = wav / np.abs(wav).max() * 0.97

    res_path = os.path.join(main_path, 'results/', test_name+".wav")
    sf.write(res_path, wav, Synthesizer.sample_rate)
    print('\nwav file is saved.')
73
+
74
+
75
def main(args):
    """Dispatch to the inference routine matching args.vocoder.

    Prints an error message (and does nothing else) when the vocoder
    name is not one of the two supported choices.
    """
    vocoder_name = str(args.vocoder).lower()
    handlers = {"wavrnn": wavRNN_infer, "hifigan": hifigan_infer}
    handler = handlers.get(vocoder_name)
    if handler is None:
        print("--vocoder must be one of HiFiGAN or WavRNN")
        return
    handler(args.text, args.ref_wav_path, args.test_name)
82
+
83
+
84
if __name__ == "__main__":
    # CLI entry point: pick a vocoder and clone the reference speaker.
    parser = argparse.ArgumentParser()
    parser.add_argument("--vocoder", type=str, help= "vocoder name: HiFiGAN or WavRNN")
    parser.add_argument("--text", type=str, help="input text")
    parser.add_argument("--ref_wav_path", type=str, help="path to refrence wav to create speaker from that")
    parser.add_argument("--test_name", type=str, default="test1", help="name of current test to save the result wav")
    args = parser.parse_args()

    main(args)
93
+
94
+
persian_numbers.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+
4
+ DIGITS_MAP = {
5
+ '0': 'صِفر', '1': 'یک', '2': 'دو', '3': 'سه', '4': 'چهار',
6
+ '5': 'پنج', '6': 'شِش', '7': 'هفت', '8': 'هشت', '9': 'نُه'
7
+ }
8
+
9
+ TENS = {
10
+ 10: 'دَه', 11: 'یازده', 12: 'دوازده', 13: 'سیزده', 14: 'چهارده',
11
+ 15: 'پانزده', 16: 'شانزده', 17: 'هفده', 18: 'هجده', 19: 'نوزده',
12
+ 20: 'بیست', 30: 'سی', 40: 'چهل', 50: 'پنجاه',
13
+ 60: 'شصت', 70: 'هفتاد', 80: 'هشتاد', 90: 'نود'
14
+ }
15
+
16
+ HUNDREDS = {
17
+ 100: 'صَد', 200: 'دویست', 300: 'سیصد', 400: 'چهارصد', 500: 'پانصد',
18
+ 600: 'ششصد', 700: 'هفتصد', 800: 'هشتصد', 900: 'نهصد'
19
+ }
20
+
21
+
22
+ def _convert_three_digit(num: int) -> str:
23
+ if num == 0:
24
+ return ''
25
+
26
+ if num < 10:
27
+ return DIGITS_MAP[str(num)]
28
+ elif num < 20:
29
+ return TENS[num]
30
+ elif num < 100:
31
+ tens_part = (num // 10) * 10
32
+ ones_part = num % 10
33
+ if ones_part == 0:
34
+ return TENS[tens_part]
35
+ return f"{TENS[tens_part]} و {DIGITS_MAP[str(ones_part)]}"
36
+ else:
37
+ hundreds_part = (num // 100) * 100
38
+ rem = num % 100
39
+ if rem == 0:
40
+ return HUNDREDS[hundreds_part]
41
+ return f"{HUNDREDS[hundreds_part]} و {_convert_three_digit(rem)}"
42
+
43
+
44
+ def num_to_text(num: int) -> str:
45
+ if num == 0:
46
+ return 'صِفر'
47
+
48
+ if num < 0:
49
+ return f"مَنفی {num_to_text(abs(num))}"
50
+
51
+ if num < 1000:
52
+ return _convert_three_digit(num)
53
+
54
+ parts = []
55
+
56
+ if num >= 1_000_000_000:
57
+ billions = num // 1_000_000_000
58
+ parts.append(f"{_convert_three_digit(billions)} میلیارد")
59
+ num %= 1_000_000_000
60
+
61
+ if num >= 1_000_000:
62
+ millions = num // 1_000_000
63
+ parts.append(f"{_convert_three_digit(millions)} میلیون")
64
+ num %= 1_000_000
65
+
66
+ if num >= 1000:
67
+ thousands = num // 1000
68
+ parts.append(f"{_convert_three_digit(thousands)} هزار")
69
+ num %= 1000
70
+
71
+ if num > 0:
72
+ parts.append(_convert_three_digit(num))
73
+
74
+ return ' و '.join(parts)
75
+
76
+
77
def _read_phone_chunk(chunk: str) -> str:
    """Spell out one chunk of a phone number in Persian.

    An all-zero chunk is read as a count of zeros; otherwise each leading
    zero is spoken digit-by-digit and the remaining digits are read as a
    single number via num_to_text().
    """
    if not chunk:
        return ""

    # Chunk is nothing but zeros: say "<count> zeros".
    if all(c == '0' for c in chunk):
        count = len(chunk)
        if count == 2:
            return "دو صِفر"
        elif count == 3:
            return "سِِتا صفر"
        elif count == 4:
            return "چهارتا صفر"
        else:
            return f"{num_to_text(count)} تا صِفر"

    result_parts = []
    temp_chunk = chunk

    # Speak each leading zero separately, then the rest as one number.
    while temp_chunk.startswith('0'):
        result_parts.append("صِفر")
        temp_chunk = temp_chunk[1:]

    if temp_chunk:
        val = int(temp_chunk)
        result_parts.append(num_to_text(val))

    return " ".join(result_parts)
104
+
105
+
106
def _smart_split_phone(phone_str: str, has_plus: bool = False) -> list:
    """Split a digit string into speakable chunks of a phone number.

    Heuristics cover +98 international numbers, +1 US numbers, Iranian
    mobiles (09xxxxxxxxx), landlines (0xx + 8 digits) and short service
    codes. Chunks are returned in reading order; a chunk may carry a
    leading '+' to mark the international prefix.
    """
    length = len(phone_str)
    chunks = []

    if has_plus:
        # '+98 9...' is re-read as a domestic mobile ('09...');
        # other +98 numbers keep the remainder as one chunk.
        if phone_str.startswith('98') and len(phone_str) > 5:
            chunks.append("+" + phone_str[:2])
            rest = phone_str[2:]
            if rest.startswith('9'):

                inner_chunks = _smart_split_phone("0" + rest)
                chunks.extend(inner_chunks)
                return chunks
            else:
                chunks.append(rest)
                return chunks

        # '+1' followed by 10 digits: North-American 3-3-4 grouping.
        elif phone_str.startswith('1') and length == 11:
            chunks.append("+" + phone_str[:1])
            chunks.append(phone_str[1:4])
            chunks.append(phone_str[4:7])
            chunks.append(phone_str[7:])
            return chunks

    # Iranian mobile: prefix (0912) + 3+4 for "round" tails, else 3+2+2.
    if phone_str.startswith('09') and length == 11:
        chunks.append(phone_str[:4])
        rest = phone_str[4:]

        part_mid = rest[:3]
        part_end = rest[3:]

        is_end_round = False
        if part_end == '0000':
            is_end_round = True
        elif part_end.endswith('00'):
            is_end_round = True
        elif part_end[1] == '0' and part_end[2] == '0':
            is_end_round = True
        # NOTE(review): nesting of this `if` is ambiguous in the source
        # diff; read here as a separate check, so a '000' middle forces
        # the 3+4 grouping regardless of the tail — confirm intent.
        if part_mid == '000':
            is_end_round = True

        if is_end_round:
            chunks.append(part_mid)
            chunks.append(part_end)
        else:
            chunks.append(rest[:3])
            chunks.append(rest[3:5])
            chunks.append(rest[5:])
        return chunks

    # Landline with area code: 0xx + 8 local digits.
    if phone_str.startswith('0') and length == 11:
        chunks.append(phone_str[:3])
        rest = phone_str[3:]

        part1 = rest[:4]
        part2 = rest[4:]

        # Round-sounding halves are read 4+4.
        if (part1.endswith('00') and part2.endswith('00')) or (part2 == '0000'):
            chunks.append(part1)
            chunks.append(part2)
            return chunks

        # Round triplets are read 3+3+2.
        p3_1 = rest[:3]
        p3_2 = rest[3:6]
        if p3_1.endswith('0') and p3_2.endswith('0'):
            chunks.append(p3_1)
            chunks.append(p3_2)
            chunks.append(rest[6:])
            return chunks

        # Default: read the 8 local digits in pairs.
        chunks.append(rest[:2])
        chunks.append(rest[2:4])
        chunks.append(rest[4:6])
        chunks.append(rest[6:])
        return chunks

    # Short codes / local numbers without a leading zero.
    if not phone_str.startswith('0'):
        if length == 8:
            chunks.append(phone_str[:2])
            chunks.append(phone_str[2:4])
            chunks.append(phone_str[4:6])
            chunks.append(phone_str[6:])
            return chunks
        elif length == 4:
            chunks.append(phone_str)
            return chunks
        elif length == 5:
            chunks.append(phone_str)
            return chunks

    # Mobile without the leading zero: 9xx xxx xx xx.
    if length == 10 and phone_str.startswith('9'):
        chunks.append(phone_str[:3])
        chunks.append(phone_str[3:6])
        chunks.append(phone_str[6:8])
        chunks.append(phone_str[8:])
        return chunks

    # Fallback: read the whole string as a single chunk.
    return [phone_str]
204
+
205
+
206
def phone_to_text(raw_input: str) -> str:
    """Verbalize a phone number in Persian.

    Separators (spaces, dashes, parentheses) are stripped and Persian
    digits normalized to ASCII; the number is then split into speakable
    chunks and each chunk is read aloud, joined with the Persian comma.
    Input that is not purely digits after cleaning is returned unchanged.
    """
    clean_input = raw_input.replace(' ', '').replace(
        '-', '').replace('(', '').replace(')', '')

    # Map Persian digits to ASCII before validating.
    persian_digits = '۰۱۲۳۴۵۶۷۸۹'
    english_digits = '0123456789'
    trans_table = str.maketrans(persian_digits, english_digits)
    clean_input = clean_input.translate(trans_table)

    # Remember and drop an international '+' prefix.
    has_plus = False
    if clean_input.startswith('+'):
        has_plus = True
        clean_input = clean_input[1:]

    if not clean_input.isdigit():
        return raw_input

    chunks = _smart_split_phone(clean_input, has_plus)

    text_parts = []
    for ch in chunks:
        if ch.startswith('+'):
            # Country-code chunk: read as "plus <number>".
            val = int(ch[1:])
            text_parts.append(f"مثبت {num_to_text(val)}")
        else:
            text_parts.append(_read_phone_chunk(ch))

    return "، ".join(text_parts)
234
+
235
+
236
+ def _is_likely_phone(num_str: str) -> bool:
237
+ if num_str.startswith('+'):
238
+ return True
239
+
240
+ if num_str.startswith('09') and len(num_str) == 11:
241
+ return True
242
+
243
+ if num_str.startswith('0') and len(num_str) >= 7:
244
+ return True
245
+
246
+ return False
247
+
248
+
249
def find_and_normalize_numbers(text: str) -> str:
    """Replace every digit sequence in `text` with its Persian reading.

    Arabic-Indic and Persian digits are first mapped to ASCII. Matches
    that look like phone numbers are verbalized chunk-by-chunk with
    phone_to_text(); all other integers are spelled out with num_to_text().
    """
    # Normalize both the Arabic-Indic and the Persian digit sets to ASCII.
    text = text.translate(str.maketrans('٠١٢٣٤٥٦٧٨٩', '0123456789'))\
        .translate(str.maketrans('۰۱۲۳۴۵۶۷۸۹', '0123456789'))

    # Optional sign, digits, optionally continued by ','- or '-'-joined
    # groups (thousands separators or dashed phone numbers).
    pattern = r'(?:\+|-)?\d+(?:[,\-]\d+)*'

    def replace_match(match):
        original_str = match.group()
        clean_str = original_str.replace(',', '')

        if _is_likely_phone(clean_str):
            return phone_to_text(clean_str)
        else:
            try:
                val = int(clean_str)
                return num_to_text(val)
            except ValueError:
                # e.g. hyphen-joined groups that are not a phone number;
                # leave the original text untouched.
                return original_str

    return re.sub(pattern, replace_match, text)
269
+
270
+
271
+ if __name__ == "__main__":
272
+ examples = [
273
+
274
+ "شماره من ۰۹۱۲۳۴۵۶۷۸۹ است",
275
+ "تلفن شرکت ۰۲۱۸۸۰۵۶۰۷۰ می باشد",
276
+ "کد تایید: ۸۸۹۹۱۱۰۰",
277
+ "تماس بین المللی: +۹۸۹۱۵۱۰۰۲۰۳۰",
278
+ "شارژ مستقیم ۰۹۳۵۲۰۰۳۰۴۰",
279
+ "کد پستی ۱۱۱۱۱۰۰۰۰۰",
280
+ "و با تلفن ۰۲۱-۸۸۸۰۳۳۵۴ تماس بگیرید",
281
+
282
+
283
+ "قیمت این کالا ۵,۴۰۰ تومان است",
284
+ "جمعیت ایران ۸۵۰۰۰۰۰۰ نفر است",
285
+ "دمای هوا منفی ۵ درجه است: -5",
286
+ "تعداد ۱۰۰۱ شب",
287
+ "عدد صفر 0"
288
+ ]
289
+
290
+ print("--- بررسی عملکرد کد ادغام شده ---\n")
291
+ for ex in examples:
292
+ converted = find_and_normalize_numbers(ex)
293
+ print(f"Original: {ex}")
294
+ print(f"Converted: {converted}")
295
+ print("-" * 30)
pipe.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import sys
4
+
5
+ from encoder import inference as encoder
6
+ from synthesizer.inference import Synthesizer
7
+ from vocoder import inference as vocoder_wavrnn
8
+ from parallel_wavegan.utils import load_model as vocoder_hifigan
9
+
10
+ import soundfile as sf
11
+ import os
12
+ import argparse
13
+
14
+ import time
15
+
16
+
17
class PersianMultiSpeakerTacotron2:
    """Persian multi-speaker TTS pipeline: speaker encoder + Tacotron2
    synthesizer + HiFiGAN vocoder.

    A reference wav supplies the target speaker's voice. It may be given
    once at construction time (embedding cached on the instance) or per
    call; a per-call reference overrides the cached one.
    """

    def __init__(self, main_path=None, ref_wav_path=None):
        """Load the pretrained models from <main_path>/saved_models/final_models/.

        Args:
            main_path: project root; defaults to the current working
                directory at call time (the old default evaluated
                os.getcwd() once at class-definition time).
            ref_wav_path: optional reference wav whose speaker embedding
                is computed once and reused by __call__.
        """
        if main_path is None:
            main_path = os.getcwd()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print("device", self.device)
        self.models_path = os.path.join(main_path, 'saved_models/final_models/')
        encoder.load_model(os.path.join(self.models_path, 'encoder.pt'))
        self.synthesizer = Synthesizer(os.path.join(self.models_path, 'synthesizer.pt'))
        self.vocoder = vocoder_hifigan(os.path.join(self.models_path, 'vocoder_HiFiGAN.pkl'))
        self.vocoder.remove_weight_norm()
        self.vocoder = self.vocoder.eval().to(self.device)

        # Always define the attribute so __call__ can test it safely
        # (previously it was only set when ref_wav_path was given, which
        # made the no-reference path raise AttributeError).
        self.embed = None
        if ref_wav_path is not None:
            self.embed = self._compute_embed(ref_wav_path)

    @staticmethod
    def _compute_embed(ref_wav_path):
        # Speaker embedding of the reference utterance.
        wav = Synthesizer.load_preprocess_wav(ref_wav_path)
        encoder_wav = encoder.preprocess_wav(wav)
        embed, _partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
        return embed

    def __call__(self, text, ref_wav_path=None, output_wav_path=None):
        """Synthesize `text` in the reference speaker's voice.

        Args:
            text: input text (phoneme string) to synthesize.
            ref_wav_path: optional per-call reference wav; falls back to
                the embedding cached at construction time.
            output_wav_path: if given, the waveform is written there and
                nothing is returned; otherwise the waveform is returned.

        Raises:
            ValueError: if no reference wav was supplied here or at
                construction time.
        """
        if ref_wav_path is not None:
            embed = self._compute_embed(ref_wav_path)
        elif self.embed is not None:
            embed = self.embed
        else:
            # The original `raise "..."` raised a TypeError (strings are
            # not exceptions in Python 3); raise a real exception instead.
            raise ValueError("ref_wav_path must be specified either here or in the constructor")

        # Time the spectrogram + vocoder stages.
        start_time = time.time()

        texts = [text]
        embeds = [embed] * len(texts)
        specs = self.synthesizer.synthesize_spectrograms(texts, embeds)
        spec = np.concatenate(specs, axis=1)
        x = torch.from_numpy(spec.T).to(self.device)

        with torch.no_grad():
            wav = self.vocoder.inference(x).cpu().numpy()
        # Peak-normalize by max(|wav|); the previous |max(wav)| clipped
        # signals whose negative peak exceeded the positive one.
        wav = wav / np.abs(wav).max() * 0.97

        total_time = time.time() - start_time
        print(f"⏱ Wave generation took: {total_time:.3f} seconds")

        if output_wav_path is None:
            return wav
        sf.write(output_wav_path, wav, Synthesizer.sample_rate)
prepare_data.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ import csv
4
+ import shutil
5
+ import os
6
+ import argparse
7
+
8
+ main_path = os.getcwd()
9
+
10
def get_duration(row):
    """Total duration in seconds of a phoneme string.

    `row` looks like "a[120] b[80] | c[40]": each token is
    "<phone>[<milliseconds>]" and '|' tokens are word separators that
    contribute nothing. Returns the summed duration in seconds.
    """
    def _token_seconds(token):
        # "<phone>[<ms>]" -> seconds
        return float(token.split('[')[1][:-1]) / 1000

    return sum(_token_seconds(t) for t in row.split() if t != '|')
21
+
22
def prepare_data_for_model(path, duration_lim):
    """Read an utterance-info CSV and return cleaned training rows.

    Rows whose total phoneme duration exceeds `duration_lim` seconds are
    skipped. Per-phone duration tags ("[123]") and " | " word separators
    are stripped from the phoneme string.

    Args:
        path: CSV file with at least 'phenome', 'seg_id' and 'speaker_id'
            columns (the 'phenome' spelling matches the dataset files).
        duration_lim: maximum utterance duration in seconds.

    Returns:
        List of [phoneme, utterance_name, speaker_id] lists.
    """
    data_lines = []
    # Context manager closes the file even on a malformed row (the
    # original leaked the handle on exceptions); newline='' is the
    # csv-module-recommended mode for reading.
    with open(path, 'r', newline='') as f:
        for row in csv.DictReader(f):
            if get_duration(row['phenome']) > duration_lim:
                continue
            phoneme = row['phenome']
            # Raw strings avoid the invalid "\[" escape (a SyntaxWarning
            # on modern Python).
            phoneme = re.sub(r"\[([0-9]+)\]", '', phoneme)
            phoneme = re.sub(r"\s+\|\s+", ' ', phoneme)
            data_lines.append([phoneme, row['seg_id'], row['speaker_id']])
    return data_lines
38
+
39
+
40
def save_files(train_data, test_data, data_path):
    """Copy utterance wavs into the speaker/book folder layout and write a
    matching .txt transcript next to each one.

    Args:
        train_data: [phoneme, seg_id, speaker_id] rows for the train split.
        test_data: same shape for the test split.
        data_path: source directory containing train_wav/ and test_wav/.

    Returns:
        True on success; False as soon as one wav copy fails (the error
        is printed and remaining files are not processed).
    """
    for line in train_data:
        try:
            original = os.path.join(data_path, 'train_wav/{}.wav'.format(line[1]))
            target = os.path.join(main_path, 'dataset/persian_data/train_data/speaker-{0}/book-1/utterance-{1}.wav'.format(line[2], line[1]))
            os.makedirs(os.path.dirname(target), exist_ok=True)
            shutil.copyfile(original, target)
        except Exception as e:
            print(e)
            return False

        # Phoneme transcript beside the copied wav.
        path = os.path.join(main_path, 'dataset/persian_data/train_data/speaker-{0}/book-1/utterance-{1}.txt'.format(line[2], line[1]))
        with open(path, 'w') as fp:
            fp.write(line[0])

    for line in test_data:
        try:
            original = os.path.join(data_path, 'test_wav/{}.wav'.format(line[1]))
            target = os.path.join(main_path, 'dataset/persian_data/test_data/speaker-{0}/book-1/utterance-{1}.wav'.format(line[2], line[1]))
            os.makedirs(os.path.dirname(target), exist_ok=True)
            shutil.copyfile(original, target)
        except Exception as e:
            print(e)
            return False

        path = os.path.join(main_path, 'dataset/persian_data/test_data/speaker-{0}/book-1/utterance-{1}.txt'.format(line[2], line[1]))
        with open(path, 'w') as fp:
            fp.write(line[0])
    return True
69
+
70
def main():
    """Entry point: locate train_info.csv / test_info.csv under --data_path,
    filter utterances by duration and lay the files out for training."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', required=True)
    args = parser.parse_args()
    data_path = args.data_path

    if os.path.isfile(os.path.join(data_path, 'train_info.csv')):
        train_data_path = os.path.join(data_path, 'train_info.csv')
    else:
        print('data_path is not correct!')
        return -1
    if os.path.isfile(os.path.join(data_path, 'test_info.csv')):
        test_data_path = os.path.join(data_path, 'test_info.csv')
    else:
        print('data_path is not correct!')
        return -1
    # Duration caps (seconds) differ between the train and test splits.
    train_data = prepare_data_for_model(train_data_path, 12)
    test_data = prepare_data_for_model(test_data_path, 15)
    print('number of train data: ' + str(len(train_data)))
    print('number of test data: ' + str(len(test_data)))

    res = save_files(train_data, test_data, data_path)
    if res:
        print('Data is created.')

if __name__ == "__main__":
    main()
requirements.txt ADDED
Binary file (562 Bytes). View file
 
resources/model.JPG ADDED
sample.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a29fd952720cb604c7c5276065792b41434800c8ae00c1326a35dbbcee1aa9d
3
+ size 1107294
saved_models/default/encoder.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39373b86598fa3da9fcddee6142382efe09777e8d37dc9c0561f41f0070f134e
3
+ size 17090379
saved_models/default/vocoder_WavRNN.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d7a6861589e927e0fbdaa5849ca022258fe2b58a20cc7bfb8fb598ccf936169
3
+ size 53845290
saved_models/final_models/config.yml ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ allow_cache: true
2
+ batch_max_steps: 8400
3
+ batch_size: 16
4
+ config: conf/hifigan.v1.yaml
5
+ dev_dumpdir: dump/dev_all/norm
6
+ dev_feats_scp: null
7
+ dev_segments: null
8
+ dev_wav_scp: null
9
+ discriminator_adv_loss_params:
10
+ average_by_discriminators: false
11
+ discriminator_grad_norm: -1
12
+ discriminator_optimizer_params:
13
+ betas:
14
+ - 0.5
15
+ - 0.9
16
+ lr: 0.0002
17
+ weight_decay: 0.0
18
+ discriminator_optimizer_type: Adam
19
+ discriminator_params:
20
+ follow_official_norm: true
21
+ period_discriminator_params:
22
+ bias: true
23
+ channels: 32
24
+ downsample_scales:
25
+ - 3
26
+ - 3
27
+ - 3
28
+ - 3
29
+ - 1
30
+ in_channels: 1
31
+ kernel_sizes:
32
+ - 5
33
+ - 3
34
+ max_downsample_channels: 1024
35
+ nonlinear_activation: LeakyReLU
36
+ nonlinear_activation_params:
37
+ negative_slope: 0.1
38
+ out_channels: 1
39
+ use_spectral_norm: false
40
+ use_weight_norm: true
41
+ periods:
42
+ - 2
43
+ - 3
44
+ - 5
45
+ - 7
46
+ - 11
47
+ scale_discriminator_params:
48
+ bias: true
49
+ channels: 128
50
+ downsample_scales:
51
+ - 4
52
+ - 4
53
+ - 4
54
+ - 4
55
+ - 1
56
+ in_channels: 1
57
+ kernel_sizes:
58
+ - 15
59
+ - 41
60
+ - 5
61
+ - 3
62
+ max_downsample_channels: 1024
63
+ max_groups: 16
64
+ nonlinear_activation: LeakyReLU
65
+ nonlinear_activation_params:
66
+ negative_slope: 0.1
67
+ out_channels: 1
68
+ scale_downsample_pooling: AvgPool1d
69
+ scale_downsample_pooling_params:
70
+ kernel_size: 4
71
+ padding: 2
72
+ stride: 2
73
+ scales: 3
74
+ discriminator_scheduler_params:
75
+ gamma: 0.5
76
+ milestones:
77
+ - 200000
78
+ - 400000
79
+ - 600000
80
+ - 800000
81
+ discriminator_scheduler_type: MultiStepLR
82
+ discriminator_train_start_steps: 0
83
+ discriminator_type: HiFiGANMultiScaleMultiPeriodDiscriminator
84
+ distributed: false
85
+ eval_interval_steps: 1000
86
+ feat_match_loss_params:
87
+ average_by_discriminators: false
88
+ average_by_layers: false
89
+ include_final_outputs: false
90
+ fft_size: 2048
91
+ fmax: 7600
92
+ fmin: 80
93
+ format: hdf5
94
+ generator_adv_loss_params:
95
+ average_by_discriminators: false
96
+ generator_grad_norm: -1
97
+ generator_optimizer_params:
98
+ betas:
99
+ - 0.5
100
+ - 0.9
101
+ lr: 0.0002
102
+ weight_decay: 0.0
103
+ generator_optimizer_type: Adam
104
+ generator_params:
105
+ bias: true
106
+ channels: 512
107
+ in_channels: 80
108
+ kernel_size: 7
109
+ nonlinear_activation: LeakyReLU
110
+ nonlinear_activation_params:
111
+ negative_slope: 0.1
112
+ out_channels: 1
113
+ resblock_dilations:
114
+ - - 1
115
+ - 3
116
+ - 5
117
+ - - 1
118
+ - 3
119
+ - 5
120
+ - - 1
121
+ - 3
122
+ - 5
123
+ resblock_kernel_sizes:
124
+ - 3
125
+ - 7
126
+ - 11
127
+ upsample_kernal_sizes:
128
+ - 10
129
+ - 10
130
+ - 8
131
+ - 6
132
+ upsample_scales:
133
+ - 5
134
+ - 5
135
+ - 4
136
+ - 3
137
+ use_additional_convs: true
138
+ use_weight_norm: true
139
+ generator_scheduler_params:
140
+ gamma: 0.5
141
+ milestones:
142
+ - 200000
143
+ - 400000
144
+ - 600000
145
+ - 800000
146
+ generator_scheduler_type: MultiStepLR
147
+ generator_train_start_steps: 1
148
+ generator_type: HiFiGANGenerator
149
+ global_gain_scale: 1.0
150
+ hop_size: 300
151
+ lambda_adv: 1.0
152
+ lambda_aux: 45.0
153
+ lambda_feat_match: 2.0
154
+ log_interval_steps: 100
155
+ mel_loss_params:
156
+ fft_size: 2048
157
+ fmax: 12000
158
+ fmin: 0
159
+ fs: 24000
160
+ hop_size: 300
161
+ log_base: null
162
+ num_mels: 80
163
+ win_length: 1200
164
+ window: hann
165
+ num_mels: 80
166
+ num_save_intermediate_results: 4
167
+ num_workers: 2
168
+ outdir: exp/train_nodev_all_vctk_hifigan.v1
169
+ pin_memory: true
170
+ pretrain: ''
171
+ rank: 0
172
+ remove_short_samples: false
173
+ resume: exp/train_nodev_all_vctk_hifigan.v1/checkpoint-2310000steps.pkl
174
+ sampling_rate: 24000
175
+ save_interval_steps: 10000
176
+ train_dumpdir: dump/train_nodev_all/norm
177
+ train_feats_scp: null
178
+ train_max_steps: 2500000
179
+ train_segments: null
180
+ train_wav_scp: null
181
+ trim_frame_size: 1024
182
+ trim_hop_size: 256
183
+ trim_silence: false
184
+ trim_threshold_in_db: 20
185
+ use_feat_match_loss: true
186
+ use_mel_loss: true
187
+ use_stft_loss: false
188
+ verbose: 1
189
+ version: 0.5.1
190
+ win_length: 1200
191
+ window: hann
saved_models/final_models/encoder.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39373b86598fa3da9fcddee6142382efe09777e8d37dc9c0561f41f0070f134e
3
+ size 17090379
saved_models/final_models/synthesizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:433588014a7c0622e9f3b00faa75d8fa35344f0d61f64fd715f27a096978a21d
3
+ size 371019913
saved_models/final_models/vocoder_HiFiGAN.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22a34e2308bdaf5566de0a326a45b9be16b524a3f97c566480fce72370604f49
3
+ size 1004606861
sentence_splitter.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import List
3
+
4
class PersianSentenceSplitter:
    """Split Persian text into synthesizer-sized segments.

    Text is first normalized (whitespace, letterforms, digits), cut at
    sentence-ending punctuation, and any piece still longer than
    `max_chars` is divided further at weak boundaries (commas etc.) or,
    as a last resort, at word boundaries.
    """

    # Persian and Arabic-Indic digits -> ASCII, applied in one pass.
    _DIGIT_MAP = str.maketrans('۰۱۲۳۴۵۶۷۸۹٠١٢٣٤٥٦٧٨٩', '01234567890123456789')

    def __init__(self, max_chars: int = 200, min_chars: int = 50):
        self.max_chars = max_chars
        self.min_chars = min_chars

        # Strong sentence terminators.
        self.sentence_endings = r'[.!?؟۔]'

        # Weak (clause-level) boundaries used for secondary splitting.
        self.weak_boundaries = r'[،,;؛]'

    def clean_text(self, text: str) -> str:
        """Normalize whitespace, letterforms and digits; strip the result."""
        text = re.sub(r'\s+', ' ', text)
        # Underscore is used in the input as a stand-in for ZWNJ.
        text = text.replace('_', '\u200c')
        # Unify Arabic letterforms with their Persian equivalents.
        text = text.replace('ك', 'ک').replace('ي', 'ی')
        # Digit normalization in a single translate call.
        text = text.translate(self._DIGIT_MAP)
        return text.strip()

    def split_by_punctuation(self, text: str) -> List[str]:
        """Cut at sentence terminators, keeping each terminator attached."""
        pieces = re.split(f'({self.sentence_endings})', text)
        # With one capture group the list alternates body, terminator, ...
        sentences = []
        for body, mark in zip(pieces[0::2], pieces[1::2]):
            joined = (body + mark).strip()
            if joined:
                sentences.append(joined)
        # A trailing piece with no terminator becomes its own sentence.
        if len(pieces) % 2 == 1 and pieces[-1].strip():
            sentences.append(pieces[-1].strip())
        return sentences

    def split_long_sentence(self, sentence: str) -> List[str]:
        """Break an over-long sentence at weak boundaries, then by words."""
        if len(sentence) <= self.max_chars:
            return [sentence]

        pieces = re.split(f'({self.weak_boundaries})', sentence)
        chunks = []
        buffer = ""
        for piece in pieces:
            # Flush when appending would exceed the limit (never flush empty).
            if buffer and len(buffer + piece) > self.max_chars:
                chunks.append(buffer.strip())
                buffer = piece
            else:
                buffer += piece
        if buffer.strip():
            chunks.append(buffer.strip())

        # Anything still too long gets a hard word-level split.
        result = []
        for chunk in chunks:
            if len(chunk) > self.max_chars:
                result.extend(self.force_split_by_words(chunk))
            else:
                result.append(chunk)
        return result

    def force_split_by_words(self, text: str) -> List[str]:
        """Greedy word packing: fill each chunk up to max_chars."""
        chunks = []
        current = []
        used = 0
        for word in text.split():
            cost = len(word) + 1  # +1 for the joining space
            if current and used + cost > self.max_chars:
                chunks.append(' '.join(current))
                current = [word]
                used = cost
            else:
                current.append(word)
                used += cost
        if current:
            chunks.append(' '.join(current))
        return chunks

    def split(self, text: str) -> List[str]:
        """Return the list of speakable segments for `text`."""
        text = self.clean_text(text)
        if not text:
            return []
        if len(text) <= self.max_chars:
            return [text]

        segments = []
        for sentence in self.split_by_punctuation(text):
            if len(sentence) > self.max_chars:
                segments.extend(self.split_long_sentence(sentence))
            else:
                segments.append(sentence)
        return [seg.strip() for seg in segments if seg.strip()]
spaces.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import re
4
+ import numpy as np
5
+ import torch
6
+ import soundfile as sf
7
+
8
+ # from config import models_path, results_path, sample_path, BASE_DIR
9
+
10
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
11
+ models_path = os.path.join(BASE_DIR, 'saved_models', 'final_models')
12
+ results_path = os.path.join(BASE_DIR, 'results')
13
+ sample_path = os.path.join(BASE_DIR, 'sample.wav')
14
+
15
+ from sentence_splitter import PersianSentenceSplitter
16
+ from persian_numbers import find_and_normalize_numbers
17
+
18
+ encoder = None
19
+ synthesizer = None
20
+ vocoder = None
21
+ sentence_splitter = None
22
+
23
def load_models():
    """Load the speaker encoder, synthesizer, HiFiGAN vocoder and the
    sentence splitter into the module-level globals.

    Returns:
        True on success, False otherwise (the full traceback is printed).
    """
    global encoder, synthesizer, vocoder, sentence_splitter

    try:
        # NOTE(review): model packages are imported from a 'pmt2' folder
        # next to this file — confirm that layout exists on deployment.
        sys.path.append(os.path.join(BASE_DIR, 'pmt2'))

        from encoder import inference as encoder_module
        from synthesizer.inference import Synthesizer
        from parallel_wavegan.utils import load_model as vocoder_hifigan

        # Duplicate of the declaration above; harmless, kept as-is.
        global encoder
        encoder = encoder_module

        print("Loading encoder model...")
        encoder.load_model(os.path.join(models_path, 'encoder.pt'))

        print("Loading synthesizer model...")
        synthesizer = Synthesizer(os.path.join(models_path, 'synthesizer.pt'))

        print("Loading HiFiGAN vocoder...")
        vocoder = vocoder_hifigan(os.path.join(models_path, 'vocoder_HiFiGAN.pkl'))
        vocoder.remove_weight_norm()
        vocoder = vocoder.eval().to('cuda' if torch.cuda.is_available() else 'cpu')

        # Segment limits tuned below the splitter's 200-char default.
        sentence_splitter = PersianSentenceSplitter(max_chars=150, min_chars=30)

        print("Models loaded successfully!")
        return True
    except Exception as e:
        import traceback
        print(f"Error loading models: {traceback.format_exc()}")
        return False
55
+
56
+
57
def normalize_text_for_synthesis(text: str) -> str:
    """Prepare raw user text for the synthesizer.

    Unifies Arabic letterforms with Persian ones, maps '_' to ZWNJ,
    collapses whitespace and spells out all digit sequences in Persian.
    """
    unified = text.replace('ك', 'ک').replace('ي', 'ی').replace('_', '\u200c')
    collapsed = re.sub(r'\s+', ' ', unified).strip()
    return find_and_normalize_numbers(collapsed)
68
+
69
+
70
def synthesize_segment(text_segment: str, embed: np.ndarray) -> np.ndarray:
    """Synthesize one text segment with the given speaker embedding.

    Uses the module-level `synthesizer` and `vocoder` (load_models() must
    have run first). Returns a 1-D float waveform, or None on failure
    (the traceback is printed).
    """
    try:
        text_segment = normalize_text_for_synthesis(text_segment)

        specs = synthesizer.synthesize_spectrograms([text_segment], [embed])
        spec = specs[0]

        # Vocoder input is (frames, mels) on its own device.
        x = torch.from_numpy(spec.T).to('cuda' if torch.cuda.is_available() else 'cpu')

        with torch.no_grad():
            wav = vocoder.inference(x)

        wav = wav.cpu().numpy()

        # Collapse any extra channel dimension to 1-D.
        if wav.ndim > 1:
            wav = wav.squeeze()

        return wav

    except Exception as e:
        import traceback
        print(f"Error synthesizing segment '{text_segment[:50]}...': {traceback.format_exc()}")
        return None
93
+
94
+
95
def add_silence(duration_ms: int = 300) -> np.ndarray:
    """Return `duration_ms` milliseconds of silence at the synthesizer's
    sample rate (load_models() must have run first)."""
    sample_rate = synthesizer.sample_rate
    num_samples = int(sample_rate * duration_ms / 1000)
    return np.zeros(num_samples, dtype=np.float32)
99
+
100
+
101
def generate_speech(text, reference_audio=None, add_pauses: bool = True):
    """Synthesize `text` as a wav file in the reference speaker's voice.

    Args:
        text: input text; split into segments before synthesis.
        reference_audio: optional (sample_rate, samples) tuple; when None
            the bundled sample.wav is used as the voice reference.
        add_pauses: insert 300 ms of silence between segments.

    Returns:
        Path of the written wav file, or None on failure / empty input.
    """
    if not text or text.strip() == "":
        return None

    try:
        # Pick (or persist) the reference audio used for the speaker embedding.
        if reference_audio is None:
            ref_wav_path = sample_path
        else:
            ref_wav_path = os.path.join(results_path, "reference_audio.wav")
            sf.write(ref_wav_path, reference_audio[1], reference_audio[0])

        print(f"Using reference audio: {ref_wav_path}")

        # NOTE(review): called on the instance here, vs. Synthesizer (the
        # class) elsewhere in this repo — presumably a classmethod; confirm.
        wav = synthesizer.load_preprocess_wav(ref_wav_path)

        encoder_wav = encoder.preprocess_wav(wav)
        embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)

        text_segments = sentence_splitter.split(text)

        print(f"Split text into {len(text_segments)} segments:")
        for i, segment in enumerate(text_segments, 1):
            print(f" Segment {i}: {segment[:60]}{'...' if len(segment) > 60 else ''}")

        audio_segments = []
        silence = add_silence(300) if add_pauses else None # 300ms pause

        # Synthesize each segment; failed segments are skipped with a warning.
        for i, segment in enumerate(text_segments):
            print(f"Processing segment {i+1}/{len(text_segments)}...")

            segment_wav = synthesize_segment(segment, embed)

            if segment_wav is not None:
                segment_wav = segment_wav.flatten() if segment_wav.ndim > 1 else segment_wav
                audio_segments.append(segment_wav)

                # Pause between segments, but not after the last one.
                if add_pauses and i < len(text_segments) - 1:
                    audio_segments.append(silence)
            else:
                print(f"Warning: Failed to synthesize segment {i+1}")

        if not audio_segments:
            print("Error: No audio segments were generated successfully")
            return None

        audio_segments = [seg.flatten() if seg.ndim > 1 else seg for seg in audio_segments]

        final_wav = np.concatenate(audio_segments)

        # Peak-normalize to 0.97 of full scale.
        final_wav = final_wav / np.abs(final_wav).max() * 0.97

        # Deterministic-per-text output name under results/.
        output_filename = f"generated_{abs(hash(text)) % 100000}.wav"
        output_path = os.path.join(results_path, output_filename)
        sf.write(output_path, final_wav, synthesizer.sample_rate)

        print(f"✓ Successfully generated speech: {output_path}")
        print(f" Total duration: {len(final_wav) / synthesizer.sample_rate:.2f} seconds")

        return output_path

    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error generating speech: {error_details}")
        return None
synthesizer/LICENSE.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Original work Copyright (c) 2018 Rayhane Mama (https://github.com/Rayhane-mamah)
4
+ Original work Copyright (c) 2019 fatchord (https://github.com/fatchord)
5
+ Modified work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ)
6
+ Modified work Copyright (c) 2020 blue-fish (https://github.com/blue-fish)
7
+
8
+ Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ of this software and associated documentation files (the "Software"), to deal
10
+ in the Software without restriction, including without limitation the rights
11
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ copies of the Software, and to permit persons to whom the Software is
13
+ furnished to do so, subject to the following conditions:
14
+
15
+ The above copyright notice and this permission notice shall be included in all
16
+ copies or substantial portions of the Software.
17
+
18
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ SOFTWARE.
synthesizer/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ #
synthesizer/audio.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import librosa.filters
3
+ import numpy as np
4
+ from scipy import signal
5
+ from scipy.io import wavfile
6
+ import soundfile as sf
7
+
8
+
9
def load_wav(path, sr):
    """Load an audio file with librosa, resampled to `sr`, returning only the waveform."""
    wav, _ = librosa.core.load(path, sr=sr)
    return wav
11
+
12
def save_wav(wav, path, sr):
    """Write `wav` to a 16-bit PCM WAV file at `path` with sample rate `sr`.

    Fix: the original scaled the waveform in place (`wav *= ...`), silently
    mutating the caller's array. Scale into a new array instead.
    """
    # Normalize the peak to the int16 range; the 0.01 floor avoids
    # amplifying near-silent signals to full scale.
    #proposed by @dsmiller
    scaled = wav * (32767 / max(0.01, np.max(np.abs(wav))))
    wavfile.write(path, sr, scaled.astype(np.int16))
16
+
17
def save_wavenet_wav(wav, path, sr):
    """Write a waveform to `path` as float32 audio using soundfile."""
    data = wav.astype(np.float32)
    sf.write(path, data, sr)
19
+
20
def preemphasis(wav, k, preemphasize=True):
    """High-frequency pre-emphasis filter y[n] = x[n] - k*x[n-1]; pass-through when disabled."""
    if not preemphasize:
        return wav
    return signal.lfilter([1, -k], [1], wav)
24
+
25
def inv_preemphasis(wav, k, inv_preemphasize=True):
    """Inverse of `preemphasis`: y[n] = x[n] + k*y[n-1]; pass-through when disabled."""
    if not inv_preemphasize:
        return wav
    return signal.lfilter([1], [1, -k], wav)
29
+
30
+ #From https://github.com/r9y9/wavenet_vocoder/blob/master/audio.py
31
def start_and_end_indices(quantized, silence_threshold=2):
    """Locate the first and last non-silent samples of a mu-law quantized
    signal (values centered at 127) and return their indices.

    From https://github.com/r9y9/wavenet_vocoder/blob/master/audio.py
    """
    def is_voiced(idx):
        return abs(quantized[idx] - 127) > silence_threshold

    for start in range(quantized.size):
        if is_voiced(start):
            break
    for end in range(quantized.size - 1, 1, -1):
        if is_voiced(end):
            break

    # Guard against all-silent input.
    assert is_voiced(start)
    assert is_voiced(end)

    return start, end
43
+
44
def get_hop_size(hparams):
    """Return the STFT hop size in samples, deriving it from frame_shift_ms when unset."""
    if hparams.hop_size is not None:
        return hparams.hop_size
    assert hparams.frame_shift_ms is not None
    return int(hparams.frame_shift_ms / 1000 * hparams.sample_rate)
50
+
51
def linearspectrogram(wav, hparams):
    """Compute a linear-frequency spectrogram in dB, optionally normalized."""
    stft_out = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
    spec_db = _amp_to_db(np.abs(stft_out), hparams) - hparams.ref_level_db

    if not hparams.signal_normalization:
        return spec_db
    return _normalize(spec_db, hparams)
58
+
59
def melspectrogram(wav, hparams):
    """Compute a mel-scale spectrogram in dB, optionally normalized."""
    stft_out = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
    spec_db = _amp_to_db(_linear_to_mel(np.abs(stft_out), hparams), hparams) - hparams.ref_level_db

    if not hparams.signal_normalization:
        return spec_db
    return _normalize(spec_db, hparams)
66
+
67
def inv_linear_spectrogram(linear_spectrogram, hparams):
    """Converts linear spectrogram to waveform using librosa"""
    # Undo the training-time normalization, if it was applied.
    if hparams.signal_normalization:
        D = _denormalize(linear_spectrogram, hparams)
    else:
        D = linear_spectrogram

    S = _db_to_amp(D + hparams.ref_level_db) #Convert back to linear

    # Phase reconstruction: LWS when configured, Griffin-Lim otherwise.
    # `hparams.power` sharpens the magnitude spectrogram before inversion.
    if hparams.use_lws:
        processor = _lws_processor(hparams)
        D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
        y = processor.istft(D).astype(np.float32)
        return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
    else:
        return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)
83
+
84
def inv_mel_spectrogram(mel_spectrogram, hparams):
    """Converts mel spectrogram to waveform using librosa"""
    # Undo the training-time normalization, if it was applied.
    if hparams.signal_normalization:
        D = _denormalize(mel_spectrogram, hparams)
    else:
        D = mel_spectrogram

    S = _mel_to_linear(_db_to_amp(D + hparams.ref_level_db), hparams)  # Convert back to linear

    # Phase reconstruction: LWS when configured, Griffin-Lim otherwise.
    # `hparams.power` sharpens the magnitude spectrogram before inversion.
    if hparams.use_lws:
        processor = _lws_processor(hparams)
        D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
        y = processor.istft(D).astype(np.float32)
        return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
    else:
        return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)
100
+
101
def _lws_processor(hparams):
    """Build an LWS processor for phase-aware STFT (lws imported lazily)."""
    import lws
    hop = get_hop_size(hparams)
    return lws.lws(hparams.n_fft, hop, fftsize=hparams.win_size, mode="speech")
104
+
105
def _griffin_lim(S, hparams):
    """librosa implementation of Griffin-Lim
    Based on https://github.com/librosa/librosa/issues/434

    Iteratively estimates phase for a magnitude spectrogram `S` and returns
    the reconstructed waveform.

    Fix: `np.complex_` was removed in NumPy 2.0; use the explicit
    `np.complex128` dtype instead.
    """
    # Start from random phase.
    angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
    S_complex = np.abs(S).astype(np.complex128)
    y = _istft(S_complex * angles, hparams)
    for i in range(hparams.griffin_lim_iters):
        # Re-estimate phase from the current waveform, keep the magnitudes.
        angles = np.exp(1j * np.angle(_stft(y, hparams)))
        y = _istft(S_complex * angles, hparams)
    return y
116
+
117
def _stft(y, hparams):
    """Short-time Fourier transform via lws or librosa, per hparams.use_lws."""
    if hparams.use_lws:
        return _lws_processor(hparams).stft(y).T
    return librosa.stft(y=y, n_fft=hparams.n_fft, hop_length=get_hop_size(hparams),
                        win_length=hparams.win_size)
122
+
123
def _istft(y, hparams):
    """Inverse STFT (librosa), matching the hop/window used by `_stft`."""
    hop = get_hop_size(hparams)
    return librosa.istft(y, hop_length=hop, win_length=hparams.win_size)
125
+
126
+ ##########################################################
127
+ #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
128
def num_frames(length, fsize, fshift):
    """Compute the number of spectrogram time frames for a signal of
    `length` samples (only correct when using lws padding)."""
    pad = fsize - fshift
    extra = 1 if length % fshift == 0 else 2
    return (length + pad * 2 - fsize) // fshift + extra
137
+
138
+
139
def pad_lr(x, fsize, fshift):
    """Compute (left, right) padding so the padded signal covers a whole
    number of frames (lws framing)."""
    frames = num_frames(len(x), fsize, fshift)
    pad = fsize - fshift
    total = len(x) + 2 * pad
    remainder = (frames - 1) * fshift + fsize - total
    return pad, pad + remainder
147
+ ##########################################################
148
+ #Librosa correct padding
149
def librosa_pad_lr(x, fsize, fshift):
    """Right-only padding up to the next whole hop (librosa framing)."""
    target = (x.shape[0] // fshift + 1) * fshift
    return 0, target - x.shape[0]
151
+
152
+ # Conversions
153
+ _mel_basis = None
154
+ _inv_mel_basis = None
155
+
156
def _linear_to_mel(spectogram, hparams):
    # Project a linear-frequency magnitude spectrogram onto the mel scale.
    # The filterbank is built once and cached in the module-level global.
    global _mel_basis
    if _mel_basis is None:
        _mel_basis = _build_mel_basis(hparams)
    return np.dot(_mel_basis, spectogram)
161
+
162
def _mel_to_linear(mel_spectrogram, hparams):
    # Approximate inverse of `_linear_to_mel` via the pseudo-inverse of the
    # mel filterbank, cached in the module-level global. Values are floored
    # at 1e-10 so later dB conversion never sees zeros/negatives.
    global _inv_mel_basis
    if _inv_mel_basis is None:
        _inv_mel_basis = np.linalg.pinv(_build_mel_basis(hparams))
    return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram))
167
+
168
def _build_mel_basis(hparams):
    """Create the mel filterbank matrix used for mel projection."""
    # fmax cannot exceed the Nyquist frequency.
    assert hparams.fmax <= hparams.sample_rate // 2
    return librosa.filters.mel(
        sr=hparams.sample_rate,
        n_fft=hparams.n_fft,
        n_mels=hparams.num_mels,
        fmin=hparams.fmin,
        fmax=hparams.fmax,
    )
172
+
173
+ def _amp_to_db(x, hparams):
174
+ min_level = np.exp(hparams.min_level_db / 20 * np.log(10))
175
+ return 20 * np.log10(np.maximum(min_level, x))
176
+
177
+ def _db_to_amp(x):
178
+ return np.power(10.0, (x) * 0.05)
179
+
180
def _normalize(S, hparams):
    # Map a dB spectrogram (expected range [min_level_db, 0]) onto the model's
    # target range: [-max_abs_value, max_abs_value] when symmetric_mels is set,
    # otherwise [0, max_abs_value].
    if hparams.allow_clipping_in_normalization:
        if hparams.symmetric_mels:
            return np.clip((2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value,
                           -hparams.max_abs_value, hparams.max_abs_value)
        else:
            return np.clip(hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db)), 0, hparams.max_abs_value)

    # Without clipping, the input must already lie in the expected dB range.
    assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0
    if hparams.symmetric_mels:
        return (2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value
    else:
        return hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db))
193
+
194
def _denormalize(D, hparams):
    # Inverse of `_normalize`: map a normalized spectrogram back to dB in
    # [min_level_db, 0]. Clips to the model range first when clipping was
    # allowed at normalization time.
    if hparams.allow_clipping_in_normalization:
        if hparams.symmetric_mels:
            return (((np.clip(D, -hparams.max_abs_value,
                              hparams.max_abs_value) + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value))
                    + hparams.min_level_db)
        else:
            return ((np.clip(D, 0, hparams.max_abs_value) * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)

    if hparams.symmetric_mels:
        return (((D + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) + hparams.min_level_db)
    else:
        return ((D * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
synthesizer/audio_v2(support_hifigan).py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # raccoonML audio tools.
2
+ # MIT License
3
+ # Copyright (c) 2021 raccoonML (https://patreon.com/raccoonML)
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software") to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in
13
+ # all copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR ANY OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21
+ # THE SOFTWARE.
22
+
23
+ import librosa
24
+ import numpy as np
25
+ import soundfile as sf
26
+ import torch
27
+ from scipy import signal
28
+
29
+ _mel_basis = None
30
+
31
+
32
def load_wav(path, sr):
    # Load an audio file resampled to `sr` and return just the waveform.
    return librosa.load(str(path), sr=sr)[0]
36
+
37
+
38
def save_wav(wav, path, sr):
    # Persist waveform data to an audio file via soundfile.
    sf.write(path, wav, sr)
41
+
42
+
43
def melspectrogram(wav, hparams):
    # Converts a waveform to a mel-scale spectrogram.
    # Output shape = (num_mels, frames)

    # Apply preemphasis
    if hparams.preemphasize:
        wav = preemphasis(wav, hparams.preemphasis)

    # Short-time Fourier Transform (STFT)
    D = librosa.stft(
        y=wav,
        n_fft=hparams.n_fft,
        hop_length=hparams.hop_size,
        win_length=hparams.win_size,
    )

    # Convert complex-valued output of STFT to absolute value (real)
    S = np.abs(D)

    # Build and cache mel basis
    # This improves speed when calculating thousands of mel spectrograms.
    global _mel_basis
    if _mel_basis is None:
        _mel_basis = _build_mel_basis(hparams)

    # Transform to mel scale
    S = np.dot(_mel_basis, S)

    # Dynamic range compression: natural-log scale, floored at 1e-5
    # (this matches the HiFi-GAN training convention).
    S = np.log(np.clip(S, a_min=1e-5, a_max=None))

    return S.astype(np.float32)
75
+
76
+
77
def inv_mel_spectrogram(S, hparams):
    # Converts a mel spectrogram to waveform using Griffin-Lim
    # Input shape = (num_mels, frames)

    # Denormalize: inverse of the natural-log compression in `melspectrogram`
    S = np.exp(S)

    # Build and cache mel basis
    # This improves speed when calculating thousands of mel spectrograms.
    global _mel_basis
    if _mel_basis is None:
        _mel_basis = _build_mel_basis(hparams)

    # Inverse mel basis (approximate, diagonally-scaled transpose).
    # NOTE(review): this inverse is recomputed on every call — only the forward
    # basis is cached. Consider caching it if inversion is called frequently.
    p = np.matmul(_mel_basis, _mel_basis.T)
    d = [1.0 / x if np.abs(x) > 1.0e-8 else x for x in np.sum(p, axis=0)]
    _inv_mel_basis = np.matmul(_mel_basis.T, np.diag(d))

    # Invert mel basis to recover linear spectrogram
    S = np.dot(_inv_mel_basis, S)

    # Use Griffin-Lim to recover waveform
    wav = _griffin_lim(S ** hparams.power, hparams)

    # Invert preemphasis
    if hparams.preemphasize:
        wav = inv_preemphasis(wav, hparams.preemphasis)

    return wav
106
+
107
+
108
def preemphasis(wav, k, preemphasize=True):
    # First-difference filter boosting high frequencies: y[n] = x[n] - k*x[n-1].
    if not preemphasize:
        return wav
    return signal.lfilter([1, -k], [1], wav)
113
+
114
+
115
def inv_preemphasis(wav, k, inv_preemphasize=True):
    # Undo the preemphasis filter: y[n] = x[n] + k*y[n-1].
    if not inv_preemphasize:
        return wav
    return signal.lfilter([1], [1, -k], wav)
120
+
121
+
122
def _build_mel_basis(hparams):
    # Mel filterbank used to project linear spectrograms onto the mel scale.
    return librosa.filters.mel(sr=hparams.sample_rate, n_fft=hparams.n_fft,
                               n_mels=hparams.num_mels, fmin=hparams.fmin,
                               fmax=hparams.fmax)
130
+
131
+
132
def _griffin_lim(S, hparams):
    # Iterative Griffin-Lim phase reconstruction: keep the given magnitudes,
    # repeatedly re-estimate phase from the resynthesized waveform.
    # Fix: `np.complex` (a deprecated alias of the builtin `complex`) was
    # removed in NumPy 1.24; use the explicit `np.complex128` dtype.
    angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
    S = np.abs(S).astype(np.complex128)
    wav = librosa.istft(
        S * angles, hop_length=hparams.hop_size, win_length=hparams.win_size
    )
    for i in range(hparams.griffin_lim_iters):
        angles = np.exp(
            1j
            * np.angle(
                librosa.stft(
                    wav,
                    n_fft=hparams.n_fft,
                    hop_length=hparams.hop_size,
                    win_length=hparams.win_size,
                )
            )
        )
        wav = librosa.istft(
            S * angles, hop_length=hparams.hop_size, win_length=hparams.win_size
        )

    return wav
synthesizer/english utils/__init__.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ _output_ref = None
5
+ _replicas_ref = None
6
+
7
def data_parallel_workaround(model, *input):
    # Replicate `model` across all visible GPUs, scatter `input`, run the
    # replicas in parallel, and gather the outputs on GPU 0.
    global _output_ref
    global _replicas_ref
    device_ids = list(range(torch.cuda.device_count()))
    output_device = device_ids[0]
    replicas = torch.nn.parallel.replicate(model, device_ids)
    # input.shape = (num_args, batch, ...)
    inputs = torch.nn.parallel.scatter(input, device_ids)
    # inputs.shape = (num_gpus, num_args, batch/num_gpus, ...)
    replicas = replicas[:len(inputs)]
    outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
    y_hat = torch.nn.parallel.gather(outputs, output_device)
    # Keep module-level references to the outputs/replicas — presumably the
    # "workaround" in the name: prevents premature deallocation between calls.
    # TODO(review): confirm this is still needed on current PyTorch.
    _output_ref = outputs
    _replicas_ref = replicas
    return y_hat
22
+
23
+
24
class ValueWindow():
    """Running window over the most recent `window_size` appended values."""

    def __init__(self, window_size=100):
        self._window_size = window_size
        self._values = []

    def append(self, x):
        # Keep only the newest (window_size - 1) values, then add x.
        kept = self._values[-(self._window_size - 1):]
        self._values = kept + [x]

    @property
    def sum(self):
        return sum(self._values)

    @property
    def count(self):
        return len(self._values)

    @property
    def average(self):
        # The max(1, ...) guard avoids dividing by zero on an empty window.
        return self.sum / max(1, self.count)

    def reset(self):
        self._values = []
synthesizer/english utils/_cmudict.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ valid_symbols = [
4
+ "AA", "AA0", "AA1", "AA2", "AE", "AE0", "AE1", "AE2", "AH", "AH0", "AH1", "AH2",
5
+ "AO", "AO0", "AO1", "AO2", "AW", "AW0", "AW1", "AW2", "AY", "AY0", "AY1", "AY2",
6
+ "B", "CH", "D", "DH", "EH", "EH0", "EH1", "EH2", "ER", "ER0", "ER1", "ER2", "EY",
7
+ "EY0", "EY1", "EY2", "F", "G", "HH", "IH", "IH0", "IH1", "IH2", "IY", "IY0", "IY1",
8
+ "IY2", "JH", "K", "L", "M", "N", "NG", "OW", "OW0", "OW1", "OW2", "OY", "OY0",
9
+ "OY1", "OY2", "P", "R", "S", "SH", "T", "TH", "UH", "UH0", "UH1", "UH2", "UW",
10
+ "UW0", "UW1", "UW2", "V", "W", "Y", "Z", "ZH"
11
+ ]
12
+
13
+ _valid_symbol_set = set(valid_symbols)
14
+
15
+
16
class CMUDict:
    """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""

    def __init__(self, file_or_path, keep_ambiguous=True):
        # Accept either a path or an already-open file object.
        if isinstance(file_or_path, str):
            with open(file_or_path, encoding="latin-1") as f:
                entries = _parse_cmudict(f)
        else:
            entries = _parse_cmudict(file_or_path)
        if not keep_ambiguous:
            # Drop words that have more than one pronunciation.
            entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
        self._entries = entries

    def __len__(self):
        return len(self._entries)

    def lookup(self, word):
        """Returns list of ARPAbet pronunciations of the given word."""
        return self._entries.get(word.upper())
36
+
37
+
38
+
39
+ _alt_re = re.compile(r"\([0-9]+\)")
40
+
41
+
42
def _parse_cmudict(file):
    # Build a {WORD: [pronunciation, ...]} mapping from a cmudict-format file.
    cmudict = {}
    for line in file:
        # Dictionary entries start with a capital letter or an apostrophe;
        # header/comment lines do not and are skipped.
        if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
            # NOTE(review): upstream cmudict separates word and pronunciation
            # with two spaces — confirm this separator literal is intact.
            parts = line.split(" ")
            # Strip the "(n)" alternate-pronunciation suffix from the word.
            word = re.sub(_alt_re, "", parts[0])
            pronunciation = _get_pronunciation(parts[1])
            if pronunciation:
                if word in cmudict:
                    cmudict[word].append(pronunciation)
                else:
                    cmudict[word] = [pronunciation]
    return cmudict
55
+
56
+
57
def _get_pronunciation(s):
    # Validate a pronunciation string: return it re-joined with single spaces,
    # or None if any token is not a recognized ARPAbet symbol.
    parts = s.strip().split(" ")
    for part in parts:
        if part not in _valid_symbol_set:
            return None
    return " ".join(parts)
synthesizer/english utils/cleaners.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cleaners are transformations that run over the input text at both training and eval time.
3
+
4
+ Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
5
+ hyperparameter. Some cleaners are English-specific. You"ll typically want to use:
6
+ 1. "english_cleaners" for English text
7
+ 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
8
+ the Unidecode library (https://pypi.python.org/pypi/Unidecode)
9
+ 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
10
+ the symbols in symbols.py to match your data).
11
+ """
12
+ import re
13
+ from unidecode import unidecode
14
+ from synthesizer.utils.numbers import normalize_numbers
15
+
16
+
17
+ # Regular expression matching whitespace:
18
+ _whitespace_re = re.compile(r"\s+")
19
+
20
+ # List of (regular expression, replacement) pairs for abbreviations:
21
+ _abbreviations = [(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) for x in [
22
+ ("mrs", "misess"),
23
+ ("mr", "mister"),
24
+ ("dr", "doctor"),
25
+ ("st", "saint"),
26
+ ("co", "company"),
27
+ ("jr", "junior"),
28
+ ("maj", "major"),
29
+ ("gen", "general"),
30
+ ("drs", "doctors"),
31
+ ("rev", "reverend"),
32
+ ("lt", "lieutenant"),
33
+ ("hon", "honorable"),
34
+ ("sgt", "sergeant"),
35
+ ("capt", "captain"),
36
+ ("esq", "esquire"),
37
+ ("ltd", "limited"),
38
+ ("col", "colonel"),
39
+ ("ft", "fort"),
40
+ ]]
41
+
42
+
43
def expand_abbreviations(text):
    """Replace known abbreviations (e.g. "mr.") with their spoken form."""
    for pattern, spoken in _abbreviations:
        text = pattern.sub(spoken, text)
    return text
47
+
48
+
49
def expand_numbers(text):
    """Spell out numeric expressions via the numbers module."""
    return normalize_numbers(text)
51
+
52
+
53
def lowercase(text):
    """Lowercase all characters of the input."""
    return text.lower()
56
+
57
+
58
def collapse_whitespace(text):
    """Collapse any run of whitespace into a single space."""
    return re.sub(r"\s+", " ", text)
60
+
61
+
62
def convert_to_ascii(text):
    """Transliterate non-ASCII text to an ASCII approximation via unidecode."""
    return unidecode(text)
64
+
65
+
66
def basic_cleaners(text):
    """Basic pipeline that lowercases and collapses whitespace without transliteration."""
    return collapse_whitespace(lowercase(text))
71
+
72
+
73
def transliteration_cleaners(text):
    """Pipeline for non-English text that transliterates to ASCII."""
    return collapse_whitespace(lowercase(convert_to_ascii(text)))
79
+
80
+
81
def english_cleaners(text):
    """Pipeline for English text, including number and abbreviation expansion."""
    text = convert_to_ascii(text)
    # Lowercasing is left commented out below, so letter case is preserved.
    # text = lowercase(text)
    text = expand_numbers(text)
    text = expand_abbreviations(text)
    text = collapse_whitespace(text)
    return text
synthesizer/english utils/numbers.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import inflect
3
+
4
+
5
+ _inflect = inflect.engine()
6
+ _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
7
+ _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
8
+ _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
9
+ _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
10
+ _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
11
+ _number_re = re.compile(r"[0-9]+")
12
+
13
+
14
+ def _remove_commas(m):
15
+ return m.group(1).replace(",", "")
16
+
17
+
18
+ def _expand_decimal_point(m):
19
+ return m.group(1).replace(".", " point ")
20
+
21
+
22
+ def _expand_dollars(m):
23
+ match = m.group(1)
24
+ parts = match.split(".")
25
+ if len(parts) > 2:
26
+ return match + " dollars" # Unexpected format
27
+ dollars = int(parts[0]) if parts[0] else 0
28
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
29
+ if dollars and cents:
30
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
31
+ cent_unit = "cent" if cents == 1 else "cents"
32
+ return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
33
+ elif dollars:
34
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
35
+ return "%s %s" % (dollars, dollar_unit)
36
+ elif cents:
37
+ cent_unit = "cent" if cents == 1 else "cents"
38
+ return "%s %s" % (cents, cent_unit)
39
+ else:
40
+ return "zero dollars"
41
+
42
+
43
def _expand_ordinal(m):
    """Spell out a matched ordinal like "2nd" using inflect."""
    return _inflect.number_to_words(m.group(0))
45
+
46
+
47
def _expand_number(m):
    # Spell out a plain integer. Values strictly between 1000 and 3000 are
    # read year-style rather than as cardinals.
    num = int(m.group(0))
    if num > 1000 and num < 3000:
        if num == 2000:
            return "two thousand"
        elif num > 2000 and num < 2010:
            # 2001..2009 -> "two thousand <n>".
            return "two thousand " + _inflect.number_to_words(num % 100)
        elif num % 100 == 0:
            # Round hundreds, e.g. "nineteen hundred".
            return _inflect.number_to_words(num // 100) + " hundred"
        else:
            # Two-digit grouping, e.g. year-style reading.
            return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
    else:
        return _inflect.number_to_words(num, andword="")
60
+
61
+
62
def normalize_numbers(text):
    # Substitution order matters: strip commas first so later patterns see
    # plain digits, and expand currency before bare decimals/numbers would
    # consume those digits.
    text = re.sub(_comma_number_re, _remove_commas, text)
    text = re.sub(_pounds_re, r"\1 pounds", text)
    text = re.sub(_dollars_re, _expand_dollars, text)
    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
    text = re.sub(_ordinal_re, _expand_ordinal, text)
    text = re.sub(_number_re, _expand_number, text)
    return text
synthesizer/english utils/plot.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
def split_title_line(title_text, max_words=5):
    """Wrap `title_text` onto multiple lines with at most `max_words` words per line."""
    words = title_text.split()
    chunks = [" ".join(words[i:i + max_words]) for i in range(0, len(words), max_words)]
    return "\n".join(chunks)
11
+
12
+
13
def plot_alignment(alignment, path, title=None, split_title=False, max_len=None):
    """Save an encoder/decoder attention alignment heatmap as a PNG at `path`."""
    # Matplotlib is imported lazily with the Agg backend so this works headless.
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    # Optionally truncate the decoder axis for readability.
    if max_len is not None:
        alignment = alignment[:, :max_len]

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111)

    im = ax.imshow(
        alignment,
        aspect="auto",
        origin="lower",
        interpolation="none")
    fig.colorbar(im, ax=ax)
    xlabel = "Decoder timestep"

    if split_title:
        title = split_title_line(title)

    plt.xlabel(xlabel)
    plt.title(title)
    plt.ylabel("Encoder timestep")
    plt.tight_layout()
    plt.savefig(path, format="png")
    plt.close()
41
+
42
+
43
def plot_spectrogram(pred_spectrogram, path, title=None, split_title=False, target_spectrogram=None, max_len=None, auto_aspect=False):
    """Save predicted (and optionally target) mel-spectrogram plots as a PNG at `path`."""
    # Lazy import with the Agg backend so plotting works headless.
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    # Optionally truncate both spectrograms along the time axis.
    if max_len is not None:
        target_spectrogram = target_spectrogram[:max_len]
        pred_spectrogram = pred_spectrogram[:max_len]

    if split_title:
        title = split_title_line(title)

    fig = plt.figure(figsize=(10, 8))
    # Set common labels
    fig.text(0.5, 0.18, title, horizontalalignment="center", fontsize=16)

    #target spectrogram subplot
    if target_spectrogram is not None:
        ax1 = fig.add_subplot(311)
        ax2 = fig.add_subplot(312)

        # Spectrograms are rotated so time runs along the x axis.
        if auto_aspect:
            im = ax1.imshow(np.rot90(target_spectrogram), aspect="auto", interpolation="none")
        else:
            im = ax1.imshow(np.rot90(target_spectrogram), interpolation="none")
        ax1.set_title("Target Mel-Spectrogram")
        fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1)
        ax2.set_title("Predicted Mel-Spectrogram")
    else:
        ax2 = fig.add_subplot(211)

    if auto_aspect:
        im = ax2.imshow(np.rot90(pred_spectrogram), aspect="auto", interpolation="none")
    else:
        im = ax2.imshow(np.rot90(pred_spectrogram), interpolation="none")
    fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2)

    plt.tight_layout()
    plt.savefig(path, format="png")
    plt.close()
synthesizer/english utils/symbols.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Defines the set of symbols used in text input to the model.
3
+
4
+ The default is a set of ASCII characters that works well for English or text that has been run
5
+ through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
6
+ """
7
+ # from . import cmudict
8
+
9
+ _pad = "_"
10
+ _eos = "~"
11
+ _characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'\"(),-.:;? "
12
+
13
+ # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
14
+ #_arpabet = ["@' + s for s in cmudict.valid_symbols]
15
+
16
+ # Export all symbols:
17
+ symbols = [_pad, _eos] + list(_characters) #+ _arpabet
synthesizer/english utils/text.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from synthesizer.utils.symbols import symbols
2
+ from synthesizer.utils import cleaners
3
+ import re
4
+
5
+
6
+ # Mappings from symbol to numeric ID and vice versa:
7
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
8
+ _id_to_symbol = {i: s for i, s in enumerate(symbols)}
9
+
10
+ # Regular expression matching text enclosed in curly braces:
11
+ _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
12
+
13
+
14
def text_to_sequence(text, cleaner_names):
    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

    The text can optionally have ARPAbet sequences enclosed in curly braces embedded
    in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."

    Args:
        text: string to convert to a sequence
        cleaner_names: names of the cleaner functions to run the text through

    Returns:
        List of integers corresponding to the symbols in the text
    """
    sequence = []

    # Check for curly braces and treat their contents as ARPAbet:
    while len(text):
        m = _curly_re.match(text)
        if not m:
            # No (more) braces: clean and encode whatever text remains.
            sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
            break
        # Encode the text before the braces, then the ARPAbet inside them,
        # then continue the loop with the text that follows.
        sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
        sequence += _arpabet_to_sequence(m.group(2))
        text = m.group(3)

    # Append EOS token
    sequence.append(_symbol_to_id["~"])
    return sequence
42
+
43
+
44
def sequence_to_text(sequence):
    """Converts a sequence of IDs back to a string"""
    pieces = []
    for symbol_id in sequence:
        if symbol_id not in _id_to_symbol:
            continue
        s = _id_to_symbol[symbol_id]
        # Re-wrap ARPAbet symbols ("@XX") in curly braces.
        if len(s) > 1 and s[0] == "@":
            s = "{%s}" % s[1:]
        pieces.append(s)
    return "".join(pieces).replace("}{", " ")
55
+
56
+
57
def _clean_text(text, cleaner_names):
    """Run `text` through each named cleaner function, in order.

    Fix: `getattr` without a default raises AttributeError for an unknown
    name, so the explicit "Unknown cleaner" check could never fire. Pass a
    default of None so the intended error message is actually raised.
    """
    for name in cleaner_names:
        cleaner = getattr(cleaners, name, None)
        if not cleaner:
            raise Exception("Unknown cleaner: %s" % name)
        text = cleaner(text)
    return text
64
+
65
+
66
def _symbols_to_sequence(symbols):
    """Map each keepable symbol to its numeric ID, dropping unknown/control symbols."""
    ids = []
    for s in symbols:
        if _should_keep_symbol(s):
            ids.append(_symbol_to_id[s])
    return ids
68
+
69
+
70
def _arpabet_to_sequence(text):
    """Encode a whitespace-separated ARPAbet string, prefixing each symbol with "@"."""
    prefixed = ["@" + s for s in text.split()]
    return _symbols_to_sequence(prefixed)
72
+
73
+
74
def _should_keep_symbol(s):
    """Keep `s` only if it is a known symbol and not the pad ("_") or EOS ("~") marker."""
    return s in _symbol_to_id and s not in ("_", "~")
synthesizer/hparams.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import pprint
3
+
4
class HParams(object):
    """A lightweight hyperparameter container with dict-style access.

    Hyperparameters are supplied as keyword arguments and may be read or
    written either as attributes (``hp.sample_rate``) or via subscripting
    (``hp["sample_rate"]``).
    """

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __setitem__(self, key, value):
        setattr(self, key, value)

    def __getitem__(self, key):
        return getattr(self, key)

    def __repr__(self):
        return pprint.pformat(self.__dict__)

    def parse(self, string):
        """Overrides hparams from a comma-separated string of name=value pairs.

        Values are parsed with ``ast.literal_eval``, so string values must be
        quoted (e.g. ``"tts_dropout=0.4,tts_cleaner_names=['basic_cleaners']"``).

        Returns:
          self, so calls can be chained.
        """
        if len(string) > 0:
            for pair in string.split(","):
                # maxsplit=1 keeps values that themselves contain "=" intact
                # (the old unbounded split crashed on them).
                key, value = pair.split("=", 1)
                # Assign each value to its own key directly. The previous
                # values[keys.index(k)] lookup was O(n^2) and resolved
                # duplicate keys to the FIRST value instead of the last,
                # defeating the override semantics.
                self.__dict__[key.strip()] = ast.literal_eval(value.strip())
        return self
20
+
21
# Default hyperparameters shared by the synthesizer (and partly the vocoder).
# Individual values can be overridden at runtime via hparams.parse("name=value,...").
hparams = HParams(
    ### Signal Processing (used in both synthesizer and vocoder)

    # Alternative 22.05 kHz configuration (kept for reference):
    # sample_rate = 22050,
    # n_fft = 1024,
    # num_mels = 80,
    # hop_size = 256,   # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
    # win_size = 1024,  # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
    # fmin = 0,
    # fmax = 11025,

    sample_rate = 24000,
    n_fft = 2048,
    num_mels = 80,
    hop_size = 300,                          # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
    win_size = 1200,                         # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
    fmin = 80,

    # Alternative 16 kHz configuration (kept for reference):
    # sample_rate = 16000,
    # n_fft = 800,
    # num_mels = 80,
    # hop_size = 200,   # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
    # win_size = 800,   # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
    # fmin = 55,

    min_level_db = -100,
    ref_level_db = 20,
    max_abs_value = 4.,                      # Gradient explodes if too big, premature convergence if too small.
    preemphasis = 0.97,                      # Filter coefficient to use if preemphasize is True
    preemphasize = True,

    ### Tacotron Text-to-Speech (TTS)
    tts_embed_dims = 512,                    # Embedding dimension for the graphemes/phoneme inputs
    tts_encoder_dims = 256,
    tts_decoder_dims = 128,
    tts_postnet_dims = 512,
    tts_encoder_K = 5,
    tts_lstm_dims = 1024,
    tts_postnet_K = 5,
    tts_num_highways = 4,
    tts_dropout = 0.5,
    tts_cleaner_names = ["persian_cleaners"],
    tts_stop_threshold = -3.4,               # Value below which audio generation ends.
                                             # For example, for a range of [-4, 4], this
                                             # will terminate the sequence at the first
                                             # frame that has all values < -3.4

    ### Tacotron Training
    tts_schedule = [(2,  1e-3,  10_000, 16),   # Progressive training schedule
                    (2,  5e-4,  20_000, 16),   # (r, lr, step, batch_size)
                    (2,  2e-4,  40_000, 16),   #
                    (2,  1e-4,  80_000, 16),   # r = reduction factor (# of mel frames
                    (2,  3e-5, 160_000, 16),   #     synthesized for each decoder iteration)
                    (2,  1e-5, 320_000, 16)],  # lr = learning rate

    tts_clip_grad_norm = 1.0,                # clips the gradient norm to prevent explosion - set to None if not needed
    tts_eval_interval = 5000,                # Number of steps between model evaluation (sample generation)
                                             # Set to -1 to generate after completing epoch, or 0 to disable

    tts_eval_num_samples = 1,                # Makes this number of samples

    ### Data Preprocessing
    max_mel_frames = 900,
    rescale = True,
    rescaling_max = 0.9,
    synthesis_batch_size = 16,               # For vocoder preprocessing and inference.

    ### Mel Visualization and Griffin-Lim
    signal_normalization = True,
    power = 1.5,
    griffin_lim_iters = 60,

    ### Audio processing options
    fmax = 7600,                             # Should not exceed (sample_rate // 2)
    allow_clipping_in_normalization = True,  # Used when signal_normalization = True
    clip_mels_length = True,                 # If true, discards samples exceeding max_mel_frames
    use_lws = False,                         # "Fast spectrogram phase recovery using local weighted sums"
    symmetric_mels = True,                   # Sets mel range to [-max_abs_value, max_abs_value] if True,
                                             # and [0, max_abs_value] if False
    trim_silence = True,                     # Use with sample_rate of 16000 for best results

    ### SV2TTS
    speaker_embedding_size = 256,            # Dimension for the speaker embedding
    silence_min_duration_split = 0.4,        # Duration in seconds of a silence for an utterance to be split
    utterance_min_duration = 0.8,            # Duration in seconds below which utterances are discarded
    )
106
+
107
def hparams_debug_string():
    """Return the global hparams rendered as a string, for debug logging."""
    # %s formatting falls back to __repr__ (pprint.pformat of the dict).
    return "%s" % (hparams,)