ak36 commited on
Commit
07b5cfc
·
verified ·
1 Parent(s): d394762

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. .gitignore +1 -0
  2. .ipynb_checkpoints/bad_wavs-checkpoint.txt +0 -0
  3. .ipynb_checkpoints/inference_first-checkpoint.py +90 -0
  4. .ipynb_checkpoints/meldataset-checkpoint.py +282 -0
  5. .ipynb_checkpoints/models-checkpoint.py +713 -0
  6. Colab/StyleTTS2_Demo_LJSpeech.ipynb +486 -0
  7. Colab/StyleTTS2_Demo_LibriTTS.ipynb +1218 -0
  8. Colab/StyleTTS2_Finetune_Demo.ipynb +480 -0
  9. Configs/config.yml +116 -0
  10. Configs/config_ft.yml +111 -0
  11. Configs/config_ft_single.yml +118 -0
  12. Configs/config_libritts.yml +113 -0
  13. Data/train_list.txt +0 -0
  14. Data/val_list.txt +100 -0
  15. Demo/Inference_LJSpeech.ipynb +554 -0
  16. Demo/Inference_LibriTTS.ipynb +1155 -0
  17. LICENSE +21 -0
  18. Modules/__init__.py +1 -0
  19. Modules/discriminators.py +188 -0
  20. Modules/hifigan.py +477 -0
  21. Modules/istftnet.py +530 -0
  22. Modules/slmadv.py +195 -0
  23. Modules/utils.py +14 -0
  24. README.md +125 -0
  25. Utils/__init__.py +1 -0
  26. __pycache__/losses.cpython-310.pyc +0 -0
  27. __pycache__/meldataset.cpython-310.pyc +0 -0
  28. __pycache__/models.cpython-310.pyc +0 -0
  29. __pycache__/optimizers.cpython-310.pyc +0 -0
  30. __pycache__/utils.cpython-310.pyc +0 -0
  31. bad_wavs.txt +0 -0
  32. data/OOD_dummy.txt +2 -0
  33. data/add_phones.py +42 -0
  34. data/val_list.txt +0 -0
  35. inference_first.py +90 -0
  36. logs/pod_90h_30k/config_ft_single.yml +118 -0
  37. logs/pod_90h_30k/tensorboard/events.out.tfevents.1749338343.104-171-203-10.11888.0 +3 -0
  38. logs/pod_90h_30k/train.log +0 -0
  39. losses.py +253 -0
  40. meldataset.py +282 -0
  41. models.py +713 -0
  42. optimizers.py +73 -0
  43. preview.wav +0 -0
  44. requirements.txt +17 -0
  45. text_utils.py +26 -0
  46. train_finetune.py +707 -0
  47. train_finetune_accelerate.py +714 -0
  48. train_first.py +445 -0
  49. train_second.py +792 -0
  50. utils.py +74 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
.ipynb_checkpoints/bad_wavs-checkpoint.txt ADDED
File without changes
.ipynb_checkpoints/inference_first-checkpoint.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ inference_first.py — quick Stage-1 sanity-check for StyleTTS-2
4
+
5
+ Example
6
+ -------
7
+ python inference_first.py \
8
+ --ckpt logs/pod_90h_30k/epoch_1st_0004.pth \
9
+ --ref data/wavs/123_abcd_part042_00.wav \
10
+ --text "<evt_gasp> ðɪs ɪz ɐ tɛst ˈsɛntəns"
11
+ It writes preview.wav in the current directory.
12
+ """
13
+ import argparse, yaml, torch, torchaudio
14
+ from models import build_model, load_ASR_models, load_F0_models
15
+ from Utils.PLBERT.util import load_plbert
16
+ from utils import recursive_munch, log_norm, length_to_mask
17
+ from meldataset import TextCleaner, preprocess
18
+
19
+ # ────────────────────────── helpers ────────────────────────────
20
+ def _restore_batch(x):
21
+ """(T,) ▸ (1,T) or (C,T) ▸ (1,C,T) (handles squeeze in JDCNet)."""
22
+ return x.unsqueeze(0) if x.dim() == 1 else x
23
+
24
+ def _match_len(x, target_len):
25
+ """Crop or zero-pad last axis to target_len."""
26
+ cur = x.shape[-1]
27
+ if cur > target_len:
28
+ return x[..., :target_len]
29
+ if cur < target_len:
30
+ pad = target_len - cur
31
+ return torch.nn.functional.pad(x, (0, pad))
32
+ return x
33
+
34
+ # ────────────────────────── CLI ────────────────────────────────
35
+ p = argparse.ArgumentParser()
36
+ p.add_argument("--ckpt", required=True, help="epoch_1st_*.pth")
37
+ p.add_argument("--ref", required=True, help="reference wav (24 kHz mono)")
38
+ p.add_argument("--text", required=True, help="IPA / phoneme string")
39
+ p.add_argument("--cfg", default="Configs/config_ft_single.yml")
40
+ args = p.parse_args()
41
+
42
+ # ───────────────── net & cfg ───────────────────────────────────
43
+ cfg = yaml.safe_load(open(args.cfg))
44
+ sr = cfg["preprocess_params"]["sr"]
45
+ device = "cuda"
46
+
47
+ asr = load_ASR_models(cfg["ASR_path"], cfg["ASR_config"])
48
+ f0 = load_F0_models(cfg["F0_path"])
49
+ bert = load_plbert(cfg["PLBERT_dir"])
50
+ model = build_model(recursive_munch(cfg["model_params"]), asr, f0, bert)
51
+
52
+ state = torch.load(args.ckpt, map_location="cpu")["net"]
53
+ for k in model:
54
+ model[k].load_state_dict(state[k], strict=False)
55
+ model[k].eval().to(device)
56
+
57
+ # ───────────────── prepare inputs ──────────────────────────────
58
+ cleaner = TextCleaner()
59
+ text_ids = torch.LongTensor(cleaner(args.text)).unsqueeze(0).to(device)
60
+ input_lengths = torch.LongTensor([text_ids.shape[1]]).to(device)
61
+ text_mask = length_to_mask(input_lengths).to(device)
62
+
63
+ wav, _ = torchaudio.load(args.ref) # (1,N)
64
+ mel_ref = preprocess(wav.squeeze().numpy()).to(device) # (1,80,T)
65
+
66
+ style = model.style_encoder(mel_ref.unsqueeze(1)) # (1,128)
67
+
68
+ F0_real, _, _ = model.pitch_extractor(mel_ref.unsqueeze(1))
69
+ F0_real = _restore_batch(F0_real) # (1,T')
70
+
71
+ real_norm = log_norm(mel_ref.unsqueeze(1)).squeeze(1) # (1,T")
72
+ real_norm = _restore_batch(real_norm) # (1,T")
73
+
74
+ # ───────────────── align lengths ───────────────────────────────
75
+ enc = model.text_encoder(text_ids, input_lengths, text_mask) # (1,512,L)
76
+ enc_len = enc.shape[-1] # L
77
+ target = enc_len * 2 # decoder expects 2×L
78
+
79
+ F0_real = _match_len(F0_real, target) # (1,2L)
80
+ real_norm = _match_len(real_norm, target) # (1,2L)
81
+
82
+ # ───────────────── decode & save ───────────────────────────────
83
+ with torch.no_grad():
84
+ y = model.decoder(enc, F0_real, real_norm, style)
85
+
86
+ # ─── make it (channels, samples) = (1, T) ────────────────────────────
87
+ y = y.squeeze(0) # (1, T)
88
+
89
+ torchaudio.save("preview.wav", y.cpu(), sr)
90
+ print("✅ wrote preview.wav")
.ipynb_checkpoints/meldataset-checkpoint.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #coding: utf-8
2
+ import os
3
+ import os.path as osp
4
+ import time
5
+ import random
6
+ import numpy as np
7
+ import random
8
+ import soundfile as sf
9
+ import librosa
10
+ import re, unicodedata
11
+
12
+ import torch
13
+ from torch import nn
14
+ import torch.nn.functional as F
15
+ import torchaudio
16
+ from torch.utils.data import DataLoader
17
+
18
+ import logging
19
+ logger = logging.getLogger(__name__)
20
+ logger.setLevel(logging.DEBUG)
21
+
22
+ import pandas as pd
23
+
24
+ _pad = "$"
25
+ _punctuation = ';:,.!?¡¿—…"«»“” '
26
+ _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
27
+ _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
28
+
29
+ # Export all symbols:
30
+ symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
31
+
32
+ dicts = {}
33
+ for i in range(len((symbols))):
34
+ dicts[symbols[i]] = i
35
+
36
+ # class TextCleaner:
37
+ # def __init__(self, dummy=None):
38
+ # self.word_index_dictionary = dicts
39
+ # def __call__(self, text):
40
+ # indexes = []
41
+ # for char in text:
42
+ # try:
43
+ # indexes.append(self.word_index_dictionary[char])
44
+ # except KeyError:
45
+ # print(text)
46
+ # return indexes
47
+
48
+ class TextCleaner:
49
+ """
50
+ • Normalises text to NFC so pre-composed IPA glyphs match `symbols`.
51
+ • Splits on event tokens first (e.g. <evt_gasp>), then per-character.
52
+ • Unknown chars map to the <unk> symbol instead of printing.
53
+ """
54
+ _EVENT_RE = re.compile(r"<[^>]+>|.") # match <evt_xxx> or single char
55
+
56
+ def __init__(self):
57
+ # `dicts` must already include EVENT_TOKENS and "<unk>"
58
+ self.lookup = dicts
59
+ self.unk_id = 0
60
+
61
+ def __call__(self, text: str):
62
+ text = unicodedata.normalize("NFC", text)
63
+ ids = []
64
+ for tok in self._EVENT_RE.findall(text):
65
+ ids.append(self.lookup.get(tok, self.unk_id))
66
+ return ids
67
+
68
+
69
+ np.random.seed(1)
70
+ random.seed(1)
71
+ SPECT_PARAMS = {
72
+ "n_fft": 2048,
73
+ "win_length": 1200,
74
+ "hop_length": 300
75
+ }
76
+ MEL_PARAMS = {
77
+ "n_mels": 80,
78
+ }
79
+
80
+ to_mel = torchaudio.transforms.MelSpectrogram(
81
+ n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
82
+ mean, std = -4, 4
83
+
84
+ def preprocess(wave):
85
+ wave_tensor = torch.from_numpy(wave).float()
86
+ mel_tensor = to_mel(wave_tensor)
87
+ mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
88
+ return mel_tensor
89
+
90
+ class FilePathDataset(torch.utils.data.Dataset):
91
+ def __init__(self,
92
+ data_list,
93
+ root_path,
94
+ sr=24000,
95
+ data_augmentation=False,
96
+ validation=False,
97
+ OOD_data="Data/OOD_texts.txt",
98
+ min_length=50,
99
+ ):
100
+
101
+ spect_params = SPECT_PARAMS
102
+ mel_params = MEL_PARAMS
103
+
104
+ _data_list = [l.strip().split('|') for l in data_list]
105
+ self.data_list = [data if len(data) == 3 else (*data, 0) for data in _data_list]
106
+ self.text_cleaner = TextCleaner()
107
+ self.sr = sr
108
+
109
+ self.df = pd.DataFrame(self.data_list)
110
+
111
+ self.to_melspec = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS)
112
+
113
+ self.mean, self.std = -4, 4
114
+ self.data_augmentation = data_augmentation and (not validation)
115
+ self.max_mel_length = 192
116
+
117
+ self.min_length = min_length
118
+ with open(OOD_data, 'r', encoding='utf-8') as f:
119
+ tl = f.readlines()
120
+ idx = 1 if '.wav' in tl[0].split('|')[0] else 0
121
+ self.ptexts = [t.split('|')[idx] for t in tl]
122
+
123
+ self.root_path = root_path
124
+
125
+ def __len__(self):
126
+ return len(self.data_list)
127
+
128
+ def __getitem__(self, idx):
129
+ data = self.data_list[idx]
130
+ path = data[0]
131
+
132
+ wave, text_tensor, speaker_id = self._load_tensor(data)
133
+
134
+ mel_tensor = preprocess(wave).squeeze()
135
+
136
+ acoustic_feature = mel_tensor.squeeze()
137
+ length_feature = acoustic_feature.size(1)
138
+ acoustic_feature = acoustic_feature[:, :(length_feature - length_feature % 2)]
139
+
140
+ # get reference sample
141
+ ref_data = (self.df[self.df[2] == str(speaker_id)]).sample(n=1).iloc[0].tolist()
142
+ ref_mel_tensor, ref_label = self._load_data(ref_data[:3])
143
+
144
+ # get OOD text
145
+
146
+ ps = ""
147
+
148
+ while len(ps) < self.min_length:
149
+ rand_idx = np.random.randint(0, len(self.ptexts) - 1)
150
+ ps = self.ptexts[rand_idx]
151
+
152
+ text = self.text_cleaner(ps)
153
+ text.insert(0, 0)
154
+ text.append(0)
155
+
156
+ ref_text = torch.LongTensor(text)
157
+
158
+ return speaker_id, acoustic_feature, text_tensor, ref_text, ref_mel_tensor, ref_label, path, wave
159
+
160
+ def _load_tensor(self, data):
161
+ wave_path, text, speaker_id = data
162
+ speaker_id = int(speaker_id)
163
+ full_path = osp.join(self.root_path, wave_path)
164
+ try:
165
+ wave, sr = sf.read(full_path, dtype="float32")
166
+ except Exception as e:
167
+ print(f"[BAD] {full_path} -> {e}", flush=True)
168
+ raise
169
+ if wave.shape[-1] == 2:
170
+ wave = wave[:, 0].squeeze()
171
+ if sr != 24000:
172
+ wave = librosa.resample(wave, orig_sr=sr, target_sr=24000)
173
+ print(wave_path, sr)
174
+
175
+ wave = np.concatenate([np.zeros([5000]), wave, np.zeros([5000])], axis=0)
176
+
177
+ text = self.text_cleaner(text)
178
+
179
+ text.insert(0, 0)
180
+ text.append(0)
181
+
182
+ text = torch.LongTensor(text)
183
+
184
+ return wave, text, speaker_id
185
+
186
+ def _load_data(self, data):
187
+ wave, text_tensor, speaker_id = self._load_tensor(data)
188
+ mel_tensor = preprocess(wave).squeeze()
189
+
190
+ mel_length = mel_tensor.size(1)
191
+ if mel_length > self.max_mel_length:
192
+ random_start = np.random.randint(0, mel_length - self.max_mel_length)
193
+ mel_tensor = mel_tensor[:, random_start:random_start + self.max_mel_length]
194
+
195
+ return mel_tensor, speaker_id
196
+
197
+
198
+ class Collater(object):
199
+ """
200
+ Args:
201
+ adaptive_batch_size (bool): if true, decrease batch size when long data comes.
202
+ """
203
+
204
+ def __init__(self, return_wave=False):
205
+ self.text_pad_index = 0
206
+ self.min_mel_length = 192
207
+ self.max_mel_length = 192
208
+ self.return_wave = return_wave
209
+
210
+
211
+ def __call__(self, batch):
212
+ # batch[0] = wave, mel, text, f0, speakerid
213
+ batch_size = len(batch)
214
+
215
+ # sort by mel length
216
+ lengths = [b[1].shape[1] for b in batch]
217
+ batch_indexes = np.argsort(lengths)[::-1]
218
+ batch = [batch[bid] for bid in batch_indexes]
219
+
220
+ nmels = batch[0][1].size(0)
221
+ max_mel_length = max([b[1].shape[1] for b in batch])
222
+ max_text_length = max([b[2].shape[0] for b in batch])
223
+ max_rtext_length = max([b[3].shape[0] for b in batch])
224
+
225
+ labels = torch.zeros((batch_size)).long()
226
+ mels = torch.zeros((batch_size, nmels, max_mel_length)).float()
227
+ texts = torch.zeros((batch_size, max_text_length)).long()
228
+ ref_texts = torch.zeros((batch_size, max_rtext_length)).long()
229
+
230
+ input_lengths = torch.zeros(batch_size).long()
231
+ ref_lengths = torch.zeros(batch_size).long()
232
+ output_lengths = torch.zeros(batch_size).long()
233
+ ref_mels = torch.zeros((batch_size, nmels, self.max_mel_length)).float()
234
+ ref_labels = torch.zeros((batch_size)).long()
235
+ paths = ['' for _ in range(batch_size)]
236
+ waves = [None for _ in range(batch_size)]
237
+
238
+ for bid, (label, mel, text, ref_text, ref_mel, ref_label, path, wave) in enumerate(batch):
239
+ mel_size = mel.size(1)
240
+ text_size = text.size(0)
241
+ rtext_size = ref_text.size(0)
242
+ labels[bid] = label
243
+ mels[bid, :, :mel_size] = mel
244
+ texts[bid, :text_size] = text
245
+ ref_texts[bid, :rtext_size] = ref_text
246
+ input_lengths[bid] = text_size
247
+ ref_lengths[bid] = rtext_size
248
+ output_lengths[bid] = mel_size
249
+ paths[bid] = path
250
+ ref_mel_size = ref_mel.size(1)
251
+ ref_mels[bid, :, :ref_mel_size] = ref_mel
252
+
253
+ ref_labels[bid] = ref_label
254
+ waves[bid] = wave
255
+
256
+ return waves, texts, input_lengths, ref_texts, ref_lengths, mels, output_lengths, ref_mels
257
+
258
+
259
+
260
+ def build_dataloader(path_list,
261
+ root_path,
262
+ validation=False,
263
+ OOD_data="Data/OOD_texts.txt",
264
+ min_length=50,
265
+ batch_size=4,
266
+ num_workers=1,
267
+ device='cpu',
268
+ collate_config={},
269
+ dataset_config={}):
270
+
271
+ dataset = FilePathDataset(path_list, root_path, OOD_data=OOD_data, min_length=min_length, validation=validation, **dataset_config)
272
+ collate_fn = Collater(**collate_config)
273
+ data_loader = DataLoader(dataset,
274
+ batch_size=batch_size,
275
+ shuffle=(not validation),
276
+ num_workers=num_workers,
277
+ drop_last=(not validation),
278
+ collate_fn=collate_fn,
279
+ pin_memory=(device != 'cpu'))
280
+
281
+ return data_loader
282
+
.ipynb_checkpoints/models-checkpoint.py ADDED
@@ -0,0 +1,713 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #coding:utf-8
2
+
3
+ import os
4
+ import os.path as osp
5
+
6
+ import copy
7
+ import math
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
14
+
15
+ from Utils.ASR.models import ASRCNN
16
+ from Utils.JDC.model import JDCNet
17
+
18
+ from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution
19
+ from Modules.diffusion.modules import Transformer1d, StyleTransformer1d
20
+ from Modules.diffusion.diffusion import AudioDiffusionConditional
21
+
22
+ from Modules.discriminators import MultiPeriodDiscriminator, MultiResSpecDiscriminator, WavLMDiscriminator
23
+
24
+ from munch import Munch
25
+ import yaml
26
+
27
+ class LearnedDownSample(nn.Module):
28
+ def __init__(self, layer_type, dim_in):
29
+ super().__init__()
30
+ self.layer_type = layer_type
31
+
32
+ if self.layer_type == 'none':
33
+ self.conv = nn.Identity()
34
+ elif self.layer_type == 'timepreserve':
35
+ self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, padding=(1, 0)))
36
+ elif self.layer_type == 'half':
37
+ self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, padding=1))
38
+ else:
39
+ raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
40
+
41
+ def forward(self, x):
42
+ return self.conv(x)
43
+
44
+ class LearnedUpSample(nn.Module):
45
+ def __init__(self, layer_type, dim_in):
46
+ super().__init__()
47
+ self.layer_type = layer_type
48
+
49
+ if self.layer_type == 'none':
50
+ self.conv = nn.Identity()
51
+ elif self.layer_type == 'timepreserve':
52
+ self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, output_padding=(1, 0), padding=(1, 0))
53
+ elif self.layer_type == 'half':
54
+ self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, output_padding=1, padding=1)
55
+ else:
56
+ raise RuntimeError('Got unexpected upsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
57
+
58
+
59
+ def forward(self, x):
60
+ return self.conv(x)
61
+
62
+ class DownSample(nn.Module):
63
+ def __init__(self, layer_type):
64
+ super().__init__()
65
+ self.layer_type = layer_type
66
+
67
+ def forward(self, x):
68
+ if self.layer_type == 'none':
69
+ return x
70
+ elif self.layer_type == 'timepreserve':
71
+ return F.avg_pool2d(x, (2, 1))
72
+ elif self.layer_type == 'half':
73
+ if x.shape[-1] % 2 != 0:
74
+ x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
75
+ return F.avg_pool2d(x, 2)
76
+ else:
77
+ raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
78
+
79
+
80
+ class UpSample(nn.Module):
81
+ def __init__(self, layer_type):
82
+ super().__init__()
83
+ self.layer_type = layer_type
84
+
85
+ def forward(self, x):
86
+ if self.layer_type == 'none':
87
+ return x
88
+ elif self.layer_type == 'timepreserve':
89
+ return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
90
+ elif self.layer_type == 'half':
91
+ return F.interpolate(x, scale_factor=2, mode='nearest')
92
+ else:
93
+ raise RuntimeError('Got unexpected upsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
94
+
95
+
96
+ class ResBlk(nn.Module):
97
+ def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
98
+ normalize=False, downsample='none'):
99
+ super().__init__()
100
+ self.actv = actv
101
+ self.normalize = normalize
102
+ self.downsample = DownSample(downsample)
103
+ self.downsample_res = LearnedDownSample(downsample, dim_in)
104
+ self.learned_sc = dim_in != dim_out
105
+ self._build_weights(dim_in, dim_out)
106
+
107
+ def _build_weights(self, dim_in, dim_out):
108
+ self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
109
+ self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
110
+ if self.normalize:
111
+ self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
112
+ self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
113
+ if self.learned_sc:
114
+ self.conv1x1 = spectral_norm(nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False))
115
+
116
+ def _shortcut(self, x):
117
+ if self.learned_sc:
118
+ x = self.conv1x1(x)
119
+ if self.downsample:
120
+ x = self.downsample(x)
121
+ return x
122
+
123
+ def _residual(self, x):
124
+ if self.normalize:
125
+ x = self.norm1(x)
126
+ x = self.actv(x)
127
+ x = self.conv1(x)
128
+ x = self.downsample_res(x)
129
+ if self.normalize:
130
+ x = self.norm2(x)
131
+ x = self.actv(x)
132
+ x = self.conv2(x)
133
+ return x
134
+
135
+ def forward(self, x):
136
+ x = self._shortcut(x) + self._residual(x)
137
+ return x / math.sqrt(2) # unit variance
138
+
139
+ class StyleEncoder(nn.Module):
140
+ def __init__(self, dim_in=48, style_dim=48, max_conv_dim=384):
141
+ super().__init__()
142
+ blocks = []
143
+ blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
144
+
145
+ repeat_num = 4
146
+ for _ in range(repeat_num):
147
+ dim_out = min(dim_in*2, max_conv_dim)
148
+ blocks += [ResBlk(dim_in, dim_out, downsample='half')]
149
+ dim_in = dim_out
150
+
151
+ blocks += [nn.LeakyReLU(0.2)]
152
+ blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
153
+ blocks += [nn.AdaptiveAvgPool2d(1)]
154
+ blocks += [nn.LeakyReLU(0.2)]
155
+ self.shared = nn.Sequential(*blocks)
156
+
157
+ self.unshared = nn.Linear(dim_out, style_dim)
158
+
159
+ def forward(self, x):
160
+ h = self.shared(x)
161
+ h = h.view(h.size(0), -1)
162
+ s = self.unshared(h)
163
+
164
+ return s
165
+
166
+ class LinearNorm(torch.nn.Module):
167
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
168
+ super(LinearNorm, self).__init__()
169
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
170
+
171
+ torch.nn.init.xavier_uniform_(
172
+ self.linear_layer.weight,
173
+ gain=torch.nn.init.calculate_gain(w_init_gain))
174
+
175
+ def forward(self, x):
176
+ return self.linear_layer(x)
177
+
178
+ class Discriminator2d(nn.Module):
179
+ def __init__(self, dim_in=48, num_domains=1, max_conv_dim=384, repeat_num=4):
180
+ super().__init__()
181
+ blocks = []
182
+ blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
183
+
184
+ for lid in range(repeat_num):
185
+ dim_out = min(dim_in*2, max_conv_dim)
186
+ blocks += [ResBlk(dim_in, dim_out, downsample='half')]
187
+ dim_in = dim_out
188
+
189
+ blocks += [nn.LeakyReLU(0.2)]
190
+ blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
191
+ blocks += [nn.LeakyReLU(0.2)]
192
+ blocks += [nn.AdaptiveAvgPool2d(1)]
193
+ blocks += [spectral_norm(nn.Conv2d(dim_out, num_domains, 1, 1, 0))]
194
+ self.main = nn.Sequential(*blocks)
195
+
196
+ def get_feature(self, x):
197
+ features = []
198
+ for l in self.main:
199
+ x = l(x)
200
+ features.append(x)
201
+ out = features[-1]
202
+ out = out.view(out.size(0), -1) # (batch, num_domains)
203
+ return out, features
204
+
205
+ def forward(self, x):
206
+ out, features = self.get_feature(x)
207
+ out = out.squeeze() # (batch)
208
+ return out, features
209
+
210
+ class ResBlk1d(nn.Module):
211
+ def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
212
+ normalize=False, downsample='none', dropout_p=0.2):
213
+ super().__init__()
214
+ self.actv = actv
215
+ self.normalize = normalize
216
+ self.downsample_type = downsample
217
+ self.learned_sc = dim_in != dim_out
218
+ self._build_weights(dim_in, dim_out)
219
+ self.dropout_p = dropout_p
220
+
221
+ if self.downsample_type == 'none':
222
+ self.pool = nn.Identity()
223
+ else:
224
+ self.pool = weight_norm(nn.Conv1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1))
225
+
226
+ def _build_weights(self, dim_in, dim_out):
227
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_in, 3, 1, 1))
228
+ self.conv2 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
229
+ if self.normalize:
230
+ self.norm1 = nn.InstanceNorm1d(dim_in, affine=True)
231
+ self.norm2 = nn.InstanceNorm1d(dim_in, affine=True)
232
+ if self.learned_sc:
233
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
234
+
235
+ def downsample(self, x):
236
+ if self.downsample_type == 'none':
237
+ return x
238
+ else:
239
+ if x.shape[-1] % 2 != 0:
240
+ x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
241
+ return F.avg_pool1d(x, 2)
242
+
243
+ def _shortcut(self, x):
244
+ if self.learned_sc:
245
+ x = self.conv1x1(x)
246
+ x = self.downsample(x)
247
+ return x
248
+
249
+ def _residual(self, x):
250
+ if self.normalize:
251
+ x = self.norm1(x)
252
+ x = self.actv(x)
253
+ x = F.dropout(x, p=self.dropout_p, training=self.training)
254
+
255
+ x = self.conv1(x)
256
+ x = self.pool(x)
257
+ if self.normalize:
258
+ x = self.norm2(x)
259
+
260
+ x = self.actv(x)
261
+ x = F.dropout(x, p=self.dropout_p, training=self.training)
262
+
263
+ x = self.conv2(x)
264
+ return x
265
+
266
+ def forward(self, x):
267
+ x = self._shortcut(x) + self._residual(x)
268
+ return x / math.sqrt(2) # unit variance
269
+
270
+ class LayerNorm(nn.Module):
271
+ def __init__(self, channels, eps=1e-5):
272
+ super().__init__()
273
+ self.channels = channels
274
+ self.eps = eps
275
+
276
+ self.gamma = nn.Parameter(torch.ones(channels))
277
+ self.beta = nn.Parameter(torch.zeros(channels))
278
+
279
+ def forward(self, x):
280
+ x = x.transpose(1, -1)
281
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
282
+ return x.transpose(1, -1)
283
+
284
+ class TextEncoder(nn.Module):
285
+ def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
286
+ super().__init__()
287
+ self.embedding = nn.Embedding(n_symbols, channels)
288
+
289
+ padding = (kernel_size - 1) // 2
290
+ self.cnn = nn.ModuleList()
291
+ for _ in range(depth):
292
+ self.cnn.append(nn.Sequential(
293
+ weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
294
+ LayerNorm(channels),
295
+ actv,
296
+ nn.Dropout(0.2),
297
+ ))
298
+ # self.cnn = nn.Sequential(*self.cnn)
299
+
300
+ self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
301
+
302
+ def forward(self, x, input_lengths, m):
303
+ x = self.embedding(x) # [B, T, emb]
304
+ x = x.transpose(1, 2) # [B, emb, T]
305
+ m = m.to(input_lengths.device).unsqueeze(1)
306
+ x.masked_fill_(m, 0.0)
307
+
308
+ for c in self.cnn:
309
+ x = c(x)
310
+ x.masked_fill_(m, 0.0)
311
+
312
+ x = x.transpose(1, 2) # [B, T, chn]
313
+
314
+ input_lengths = input_lengths.cpu().numpy()
315
+ x = nn.utils.rnn.pack_padded_sequence(
316
+ x, input_lengths, batch_first=True, enforce_sorted=False)
317
+
318
+ self.lstm.flatten_parameters()
319
+ x, _ = self.lstm(x)
320
+ x, _ = nn.utils.rnn.pad_packed_sequence(
321
+ x, batch_first=True)
322
+
323
+ x = x.transpose(-1, -2)
324
+ x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
325
+
326
+ x_pad[:, :, :x.shape[-1]] = x
327
+ x = x_pad.to(x.device)
328
+
329
+ x.masked_fill_(m, 0.0)
330
+
331
+ return x
332
+
333
+ def inference(self, x):
334
+ x = self.embedding(x)
335
+ x = x.transpose(1, 2)
336
+ x = self.cnn(x)
337
+ x = x.transpose(1, 2)
338
+ self.lstm.flatten_parameters()
339
+ x, _ = self.lstm(x)
340
+ return x
341
+
342
+ def length_to_mask(self, lengths):
343
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
344
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
345
+ return mask
346
+
347
+
348
+
349
+ class AdaIN1d(nn.Module):
350
+ def __init__(self, style_dim, num_features):
351
+ super().__init__()
352
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
353
+ self.fc = nn.Linear(style_dim, num_features*2)
354
+
355
+ def forward(self, x, s):
356
+ h = self.fc(s)
357
+ h = h.view(h.size(0), h.size(1), 1)
358
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
359
+ return (1 + gamma) * self.norm(x) + beta
360
+
361
+ class UpSample1d(nn.Module):
362
+ def __init__(self, layer_type):
363
+ super().__init__()
364
+ self.layer_type = layer_type
365
+
366
+ def forward(self, x):
367
+ if self.layer_type == 'none':
368
+ return x
369
+ else:
370
+ return F.interpolate(x, scale_factor=2, mode='nearest')
371
+
372
+ class AdainResBlk1d(nn.Module):
373
+ def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
374
+ upsample='none', dropout_p=0.0):
375
+ super().__init__()
376
+ self.actv = actv
377
+ self.upsample_type = upsample
378
+ self.upsample = UpSample1d(upsample)
379
+ self.learned_sc = dim_in != dim_out
380
+ self._build_weights(dim_in, dim_out, style_dim)
381
+ self.dropout = nn.Dropout(dropout_p)
382
+
383
+ if upsample == 'none':
384
+ self.pool = nn.Identity()
385
+ else:
386
+ self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
387
+
388
+
389
+ def _build_weights(self, dim_in, dim_out, style_dim):
390
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
391
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
392
+ self.norm1 = AdaIN1d(style_dim, dim_in)
393
+ self.norm2 = AdaIN1d(style_dim, dim_out)
394
+ if self.learned_sc:
395
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
396
+
397
+ def _shortcut(self, x):
398
+ x = self.upsample(x)
399
+ if self.learned_sc:
400
+ x = self.conv1x1(x)
401
+ return x
402
+
403
+ def _residual(self, x, s):
404
+ x = self.norm1(x, s)
405
+ x = self.actv(x)
406
+ x = self.pool(x)
407
+ x = self.conv1(self.dropout(x))
408
+ x = self.norm2(x, s)
409
+ x = self.actv(x)
410
+ x = self.conv2(self.dropout(x))
411
+ return x
412
+
413
+ def forward(self, x, s):
414
+ out = self._residual(x, s)
415
+ out = (out + self._shortcut(x)) / math.sqrt(2)
416
+ return out
417
+
418
+ class AdaLayerNorm(nn.Module):
419
+ def __init__(self, style_dim, channels, eps=1e-5):
420
+ super().__init__()
421
+ self.channels = channels
422
+ self.eps = eps
423
+
424
+ self.fc = nn.Linear(style_dim, channels*2)
425
+
426
+ def forward(self, x, s):
427
+ x = x.transpose(-1, -2)
428
+ x = x.transpose(1, -1)
429
+
430
+ h = self.fc(s)
431
+ h = h.view(h.size(0), h.size(1), 1)
432
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
433
+ gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
434
+
435
+
436
+ x = F.layer_norm(x, (self.channels,), eps=self.eps)
437
+ x = (1 + gamma) * x + beta
438
+ return x.transpose(1, -1).transpose(-1, -2)
439
+
440
+ class ProsodyPredictor(nn.Module):
441
+
442
+ def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
443
+ super().__init__()
444
+
445
+ self.text_encoder = DurationEncoder(sty_dim=style_dim,
446
+ d_model=d_hid,
447
+ nlayers=nlayers,
448
+ dropout=dropout)
449
+
450
+ self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
451
+ self.duration_proj = LinearNorm(d_hid, max_dur)
452
+
453
+ self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
454
+ self.F0 = nn.ModuleList()
455
+ self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
456
+ self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
457
+ self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
458
+
459
+ self.N = nn.ModuleList()
460
+ self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
461
+ self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
462
+ self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
463
+
464
+ self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
465
+ self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
466
+
467
+
468
+ def forward(self, texts, style, text_lengths, alignment, m):
469
+ d = self.text_encoder(texts, style, text_lengths, m)
470
+
471
+ batch_size = d.shape[0]
472
+ text_size = d.shape[1]
473
+
474
+ # predict duration
475
+ input_lengths = text_lengths.cpu().numpy()
476
+ x = nn.utils.rnn.pack_padded_sequence(
477
+ d, input_lengths, batch_first=True, enforce_sorted=False)
478
+
479
+ m = m.to(text_lengths.device).unsqueeze(1)
480
+
481
+ self.lstm.flatten_parameters()
482
+ x, _ = self.lstm(x)
483
+ x, _ = nn.utils.rnn.pad_packed_sequence(
484
+ x, batch_first=True)
485
+
486
+ x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
487
+
488
+ x_pad[:, :x.shape[1], :] = x
489
+ x = x_pad.to(x.device)
490
+
491
+ duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
492
+
493
+ en = (d.transpose(-1, -2) @ alignment)
494
+
495
+ return duration.squeeze(-1), en
496
+
497
+ def F0Ntrain(self, x, s):
498
+ x, _ = self.shared(x.transpose(-1, -2))
499
+
500
+ F0 = x.transpose(-1, -2)
501
+ for block in self.F0:
502
+ F0 = block(F0, s)
503
+ F0 = self.F0_proj(F0)
504
+
505
+ N = x.transpose(-1, -2)
506
+ for block in self.N:
507
+ N = block(N, s)
508
+ N = self.N_proj(N)
509
+
510
+ return F0.squeeze(1), N.squeeze(1)
511
+
512
+ def length_to_mask(self, lengths):
513
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
514
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
515
+ return mask
516
+
517
+ class DurationEncoder(nn.Module):
518
+
519
+ def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
520
+ super().__init__()
521
+ self.lstms = nn.ModuleList()
522
+ for _ in range(nlayers):
523
+ self.lstms.append(nn.LSTM(d_model + sty_dim,
524
+ d_model // 2,
525
+ num_layers=1,
526
+ batch_first=True,
527
+ bidirectional=True,
528
+ dropout=dropout))
529
+ self.lstms.append(AdaLayerNorm(sty_dim, d_model))
530
+
531
+
532
+ self.dropout = dropout
533
+ self.d_model = d_model
534
+ self.sty_dim = sty_dim
535
+
536
+ def forward(self, x, style, text_lengths, m):
537
+ masks = m.to(text_lengths.device)
538
+
539
+ x = x.permute(2, 0, 1)
540
+ s = style.expand(x.shape[0], x.shape[1], -1)
541
+ x = torch.cat([x, s], axis=-1)
542
+ x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
543
+
544
+ x = x.transpose(0, 1)
545
+ input_lengths = text_lengths.cpu().numpy()
546
+ x = x.transpose(-1, -2)
547
+
548
+ for block in self.lstms:
549
+ if isinstance(block, AdaLayerNorm):
550
+ x = block(x.transpose(-1, -2), style).transpose(-1, -2)
551
+ x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
552
+ x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
553
+ else:
554
+ x = x.transpose(-1, -2)
555
+ x = nn.utils.rnn.pack_padded_sequence(
556
+ x, input_lengths, batch_first=True, enforce_sorted=False)
557
+ block.flatten_parameters()
558
+ x, _ = block(x)
559
+ x, _ = nn.utils.rnn.pad_packed_sequence(
560
+ x, batch_first=True)
561
+ x = F.dropout(x, p=self.dropout, training=self.training)
562
+ x = x.transpose(-1, -2)
563
+
564
+ x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
565
+
566
+ x_pad[:, :, :x.shape[-1]] = x
567
+ x = x_pad.to(x.device)
568
+
569
+ return x.transpose(-1, -2)
570
+
571
+ def inference(self, x, style):
572
+ x = self.embedding(x.transpose(-1, -2)) * math.sqrt(self.d_model)
573
+ style = style.expand(x.shape[0], x.shape[1], -1)
574
+ x = torch.cat([x, style], axis=-1)
575
+ src = self.pos_encoder(x)
576
+ output = self.transformer_encoder(src).transpose(0, 1)
577
+ return output
578
+
579
+ def length_to_mask(self, lengths):
580
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
581
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
582
+ return mask
583
+
584
+ def load_F0_models(path):
585
+ # load F0 model
586
+
587
+ F0_model = JDCNet(num_class=1, seq_len=192)
588
+ params = torch.load(path, map_location='cpu')['net']
589
+ F0_model.load_state_dict(params)
590
+ _ = F0_model.train()
591
+
592
+ return F0_model
593
+
594
+ def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
595
+ # load ASR model
596
+ def _load_config(path):
597
+ with open(path) as f:
598
+ config = yaml.safe_load(f)
599
+ model_config = config['model_params']
600
+ return model_config
601
+
602
+ def _load_model(model_config, model_path):
603
+ model = ASRCNN(**model_config)
604
+ params = torch.load(model_path, map_location='cpu', weights_only=False)['model']
605
+ model.load_state_dict(params)
606
+ return model
607
+
608
+ asr_model_config = _load_config(ASR_MODEL_CONFIG)
609
+ asr_model = _load_model(asr_model_config, ASR_MODEL_PATH)
610
+ _ = asr_model.train()
611
+
612
+ return asr_model
613
+
614
+ def build_model(args, text_aligner, pitch_extractor, bert):
615
+ assert args.decoder.type in ['istftnet', 'hifigan'], 'Decoder type unknown'
616
+
617
+ if args.decoder.type == "istftnet":
618
+ from Modules.istftnet import Decoder
619
+ decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
620
+ resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
621
+ upsample_rates = args.decoder.upsample_rates,
622
+ upsample_initial_channel=args.decoder.upsample_initial_channel,
623
+ resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
624
+ upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
625
+ gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
626
+ else:
627
+ from Modules.hifigan import Decoder
628
+ decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
629
+ resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
630
+ upsample_rates = args.decoder.upsample_rates,
631
+ upsample_initial_channel=args.decoder.upsample_initial_channel,
632
+ resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
633
+ upsample_kernel_sizes=args.decoder.upsample_kernel_sizes)
634
+
635
+ text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
636
+
637
+ predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
638
+
639
+ style_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim) # acoustic style encoder
640
+ predictor_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim) # prosodic style encoder
641
+
642
+ # define diffusion model
643
+ if args.multispeaker:
644
+ transformer = StyleTransformer1d(channels=args.style_dim*2,
645
+ context_embedding_features=bert.config.hidden_size,
646
+ context_features=args.style_dim*2,
647
+ **args.diffusion.transformer)
648
+ else:
649
+ transformer = Transformer1d(channels=args.style_dim*2,
650
+ context_embedding_features=bert.config.hidden_size,
651
+ **args.diffusion.transformer)
652
+
653
+ diffusion = AudioDiffusionConditional(
654
+ in_channels=1,
655
+ embedding_max_length=bert.config.max_position_embeddings,
656
+ embedding_features=bert.config.hidden_size,
657
+ embedding_mask_proba=args.diffusion.embedding_mask_proba, # Conditional dropout of batch elements,
658
+ channels=args.style_dim*2,
659
+ context_features=args.style_dim*2,
660
+ )
661
+
662
+ diffusion.diffusion = KDiffusion(
663
+ net=diffusion.unet,
664
+ sigma_distribution=LogNormalDistribution(mean = args.diffusion.dist.mean, std = args.diffusion.dist.std),
665
+ sigma_data=args.diffusion.dist.sigma_data, # a placeholder, will be changed dynamically when start training diffusion model
666
+ dynamic_threshold=0.0
667
+ )
668
+ diffusion.diffusion.net = transformer
669
+ diffusion.unet = transformer
670
+
671
+
672
+ nets = Munch(
673
+ bert=bert,
674
+ bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim),
675
+
676
+ predictor=predictor,
677
+ decoder=decoder,
678
+ text_encoder=text_encoder,
679
+
680
+ predictor_encoder=predictor_encoder,
681
+ style_encoder=style_encoder,
682
+ diffusion=diffusion,
683
+
684
+ text_aligner = text_aligner,
685
+ pitch_extractor=pitch_extractor,
686
+
687
+ mpd = MultiPeriodDiscriminator(),
688
+ msd = MultiResSpecDiscriminator(),
689
+
690
+ # slm discriminator head
691
+ wd = WavLMDiscriminator(args.slm.hidden, args.slm.nlayers, args.slm.initial_channel),
692
+ )
693
+
694
+ return nets
695
+
696
+ def load_checkpoint(model, optimizer, path, load_only_params=True, ignore_modules=[]):
697
+ state = torch.load(path, map_location='cpu')
698
+ params = state['net']
699
+ for key in model:
700
+ if key in params and key not in ignore_modules:
701
+ print('%s loaded' % key)
702
+ model[key].load_state_dict(params[key], strict=False)
703
+ _ = [model[key].eval() for key in model]
704
+
705
+ if not load_only_params:
706
+ epoch = state["epoch"]
707
+ iters = state["iters"]
708
+ optimizer.load_state_dict(state["optimizer"])
709
+ else:
710
+ epoch = 0
711
+ iters = 0
712
+
713
+ return model, optimizer, epoch, iters
Colab/StyleTTS2_Demo_LJSpeech.ipynb ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4",
8
+ "authorship_tag": "ABX9TyM1x2mx2VnkYNFVlD+DFzmy",
9
+ "include_colab_link": true
10
+ },
11
+ "kernelspec": {
12
+ "name": "python3",
13
+ "display_name": "Python 3"
14
+ },
15
+ "language_info": {
16
+ "name": "python"
17
+ },
18
+ "accelerator": "GPU"
19
+ },
20
+ "cells": [
21
+ {
22
+ "cell_type": "markdown",
23
+ "metadata": {
24
+ "id": "view-in-github",
25
+ "colab_type": "text"
26
+ },
27
+ "source": [
28
+ "<a href=\"https://colab.research.google.com/github/yl4579/StyleTTS2/blob/main/Colab/StyleTTS2_Demo_LJSpeech.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "markdown",
33
+ "source": [
34
+ "### Install packages and download models"
35
+ ],
36
+ "metadata": {
37
+ "id": "nm653VK4CG9F"
38
+ }
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "source": [
43
+ "%%shell\n",
44
+ "git clone https://github.com/yl4579/StyleTTS2.git\n",
45
+ "cd StyleTTS2\n",
46
+ "pip install SoundFile torchaudio munch torch pydub pyyaml librosa nltk matplotlib accelerate transformers phonemizer einops einops-exts tqdm typing-extensions git+https://github.com/resemble-ai/monotonic_align.git\n",
47
+ "sudo apt-get install espeak-ng\n",
48
+ "git-lfs clone https://huggingface.co/yl4579/StyleTTS2-LJSpeech\n",
49
+ "mv StyleTTS2-LJSpeech/Models ."
50
+ ],
51
+ "metadata": {
52
+ "id": "gciBKMqCCLvT"
53
+ },
54
+ "execution_count": null,
55
+ "outputs": []
56
+ },
57
+ {
58
+ "cell_type": "markdown",
59
+ "source": [
60
+ "### Load models"
61
+ ],
62
+ "metadata": {
63
+ "id": "OAA8lx-XCQnM"
64
+ }
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "source": [
69
+ "%cd StyleTTS2\n",
70
+ "\n",
71
+ "import torch\n",
72
+ "torch.manual_seed(0)\n",
73
+ "torch.backends.cudnn.benchmark = False\n",
74
+ "torch.backends.cudnn.deterministic = True\n",
75
+ "\n",
76
+ "import random\n",
77
+ "random.seed(0)\n",
78
+ "\n",
79
+ "import numpy as np\n",
80
+ "np.random.seed(0)\n",
81
+ "\n",
82
+ "import nltk\n",
83
+ "nltk.download('punkt')\n",
84
+ "\n",
85
+ "# load packages\n",
86
+ "import time\n",
87
+ "import random\n",
88
+ "import yaml\n",
89
+ "from munch import Munch\n",
90
+ "import numpy as np\n",
91
+ "import torch\n",
92
+ "from torch import nn\n",
93
+ "import torch.nn.functional as F\n",
94
+ "import torchaudio\n",
95
+ "import librosa\n",
96
+ "from nltk.tokenize import word_tokenize\n",
97
+ "\n",
98
+ "from models import *\n",
99
+ "from utils import *\n",
100
+ "from text_utils import TextCleaner\n",
101
+ "textclenaer = TextCleaner()\n",
102
+ "\n",
103
+ "%matplotlib inline\n",
104
+ "\n",
105
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
106
+ "\n",
107
+ "to_mel = torchaudio.transforms.MelSpectrogram(\n",
108
+ " n_mels=80, n_fft=2048, win_length=1200, hop_length=300)\n",
109
+ "mean, std = -4, 4\n",
110
+ "\n",
111
+ "def length_to_mask(lengths):\n",
112
+ " mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)\n",
113
+ " mask = torch.gt(mask+1, lengths.unsqueeze(1))\n",
114
+ " return mask\n",
115
+ "\n",
116
+ "def preprocess(wave):\n",
117
+ " wave_tensor = torch.from_numpy(wave).float()\n",
118
+ " mel_tensor = to_mel(wave_tensor)\n",
119
+ " mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std\n",
120
+ " return mel_tensor\n",
121
+ "\n",
122
+ "def compute_style(ref_dicts):\n",
123
+ " reference_embeddings = {}\n",
124
+ " for key, path in ref_dicts.items():\n",
125
+ " wave, sr = librosa.load(path, sr=24000)\n",
126
+ " audio, index = librosa.effects.trim(wave, top_db=30)\n",
127
+ " if sr != 24000:\n",
128
+ " audio = librosa.resample(audio, sr, 24000)\n",
129
+ " mel_tensor = preprocess(audio).to(device)\n",
130
+ "\n",
131
+ " with torch.no_grad():\n",
132
+ " ref = model.style_encoder(mel_tensor.unsqueeze(1))\n",
133
+ " reference_embeddings[key] = (ref.squeeze(1), audio)\n",
134
+ "\n",
135
+ " return reference_embeddings\n",
136
+ "\n",
137
+ "# load phonemizer\n",
138
+ "import phonemizer\n",
139
+ "global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True, words_mismatch='ignore')\n",
140
+ "\n",
141
+ "config = yaml.safe_load(open(\"Models/LJSpeech/config.yml\"))\n",
142
+ "\n",
143
+ "# load pretrained ASR model\n",
144
+ "ASR_config = config.get('ASR_config', False)\n",
145
+ "ASR_path = config.get('ASR_path', False)\n",
146
+ "text_aligner = load_ASR_models(ASR_path, ASR_config)\n",
147
+ "\n",
148
+ "# load pretrained F0 model\n",
149
+ "F0_path = config.get('F0_path', False)\n",
150
+ "pitch_extractor = load_F0_models(F0_path)\n",
151
+ "\n",
152
+ "# load BERT model\n",
153
+ "from Utils.PLBERT.util import load_plbert\n",
154
+ "BERT_path = config.get('PLBERT_dir', False)\n",
155
+ "plbert = load_plbert(BERT_path)\n",
156
+ "\n",
157
+ "model = build_model(recursive_munch(config['model_params']), text_aligner, pitch_extractor, plbert)\n",
158
+ "_ = [model[key].eval() for key in model]\n",
159
+ "_ = [model[key].to(device) for key in model]\n",
160
+ "\n",
161
+ "params_whole = torch.load(\"Models/LJSpeech/epoch_2nd_00100.pth\", map_location='cpu')\n",
162
+ "params = params_whole['net']\n",
163
+ "\n",
164
+ "for key in model:\n",
165
+ " if key in params:\n",
166
+ " print('%s loaded' % key)\n",
167
+ " try:\n",
168
+ " model[key].load_state_dict(params[key])\n",
169
+ " except:\n",
170
+ " from collections import OrderedDict\n",
171
+ " state_dict = params[key]\n",
172
+ " new_state_dict = OrderedDict()\n",
173
+ " for k, v in state_dict.items():\n",
174
+ " name = k[7:] # remove `module.`\n",
175
+ " new_state_dict[name] = v\n",
176
+ " # load params\n",
177
+ " model[key].load_state_dict(new_state_dict, strict=False)\n",
178
+ "# except:\n",
179
+ "# _load(params[key], model[key])\n",
180
+ "_ = [model[key].eval() for key in model]\n",
181
+ "\n",
182
+ "from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule\n",
183
+ "\n",
184
+ "sampler = DiffusionSampler(\n",
185
+ " model.diffusion.diffusion,\n",
186
+ " sampler=ADPM2Sampler(),\n",
187
+ " sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters\n",
188
+ " clamp=False\n",
189
+ ")\n",
190
+ "\n",
191
+ "def inference(text, noise, diffusion_steps=5, embedding_scale=1):\n",
192
+ " text = text.strip()\n",
193
+ " text = text.replace('\"', '')\n",
194
+ " ps = global_phonemizer.phonemize([text])\n",
195
+ " ps = word_tokenize(ps[0])\n",
196
+ " ps = ' '.join(ps)\n",
197
+ "\n",
198
+ " tokens = textclenaer(ps)\n",
199
+ " tokens.insert(0, 0)\n",
200
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
201
+ "\n",
202
+ " with torch.no_grad():\n",
203
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)\n",
204
+ " text_mask = length_to_mask(input_lengths).to(tokens.device)\n",
205
+ "\n",
206
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
207
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
208
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2)\n",
209
+ "\n",
210
+ " s_pred = sampler(noise,\n",
211
+ " embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps,\n",
212
+ " embedding_scale=embedding_scale).squeeze(0)\n",
213
+ "\n",
214
+ " s = s_pred[:, 128:]\n",
215
+ " ref = s_pred[:, :128]\n",
216
+ "\n",
217
+ " d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)\n",
218
+ "\n",
219
+ " x, _ = model.predictor.lstm(d)\n",
220
+ " duration = model.predictor.duration_proj(x)\n",
221
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
222
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
223
+ "\n",
224
+ " pred_dur[-1] += 5\n",
225
+ "\n",
226
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
227
+ " c_frame = 0\n",
228
+ " for i in range(pred_aln_trg.size(0)):\n",
229
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
230
+ " c_frame += int(pred_dur[i].data)\n",
231
+ "\n",
232
+ " # encode prosody\n",
233
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
234
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
235
+ " out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)),\n",
236
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
237
+ "\n",
238
+ " return out.squeeze().cpu().numpy()\n",
239
+ "\n",
240
+ "def LFinference(text, s_prev, noise, alpha=0.7, diffusion_steps=5, embedding_scale=1):\n",
241
+ " text = text.strip()\n",
242
+ " text = text.replace('\"', '')\n",
243
+ " ps = global_phonemizer.phonemize([text])\n",
244
+ " ps = word_tokenize(ps[0])\n",
245
+ " ps = ' '.join(ps)\n",
246
+ "\n",
247
+ " tokens = textclenaer(ps)\n",
248
+ " tokens.insert(0, 0)\n",
249
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
250
+ "\n",
251
+ " with torch.no_grad():\n",
252
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)\n",
253
+ " text_mask = length_to_mask(input_lengths).to(tokens.device)\n",
254
+ "\n",
255
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
256
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
257
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2)\n",
258
+ "\n",
259
+ " s_pred = sampler(noise,\n",
260
+ " embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps,\n",
261
+ " embedding_scale=embedding_scale).squeeze(0)\n",
262
+ "\n",
263
+ " if s_prev is not None:\n",
264
+ " # convex combination of previous and current style\n",
265
+ " s_pred = alpha * s_prev + (1 - alpha) * s_pred\n",
266
+ "\n",
267
+ " s = s_pred[:, 128:]\n",
268
+ " ref = s_pred[:, :128]\n",
269
+ "\n",
270
+ " d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)\n",
271
+ "\n",
272
+ " x, _ = model.predictor.lstm(d)\n",
273
+ " duration = model.predictor.duration_proj(x)\n",
274
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
275
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
276
+ "\n",
277
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
278
+ " c_frame = 0\n",
279
+ " for i in range(pred_aln_trg.size(0)):\n",
280
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
281
+ " c_frame += int(pred_dur[i].data)\n",
282
+ "\n",
283
+ " # encode prosody\n",
284
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
285
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
286
+ " out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)),\n",
287
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
288
+ "\n",
289
+ " return out.squeeze().cpu().numpy(), s_pred"
290
+ ],
291
+ "metadata": {
292
+ "id": "m0XRpbxSCSix"
293
+ },
294
+ "execution_count": null,
295
+ "outputs": []
296
+ },
297
+ {
298
+ "cell_type": "markdown",
299
+ "source": [
300
+ "### Synthesize speech"
301
+ ],
302
+ "metadata": {
303
+ "id": "vuCbS0gdArgJ"
304
+ }
305
+ },
306
+ {
307
+ "cell_type": "code",
308
+ "source": [
309
+ "# @title Input Text { display-mode: \"form\" }\n",
310
+ "# synthesize a text\n",
311
+ "text = \"StyleTTS 2 is a text-to-speech model that leverages style diffusion and adversarial training with large speech language models to achieve human-level text-to-speech synthesis.\" # @param {type:\"string\"}\n"
312
+ ],
313
+ "metadata": {
314
+ "id": "7Ud1Y-kbBPTw"
315
+ },
316
+ "execution_count": 3,
317
+ "outputs": []
318
+ },
319
+ {
320
+ "cell_type": "markdown",
321
+ "source": [
322
+ "#### Basic synthesis (5 diffusion steps)"
323
+ ],
324
+ "metadata": {
325
+ "id": "TM2NjuM7B6sz"
326
+ }
327
+ },
328
+ {
329
+ "cell_type": "code",
330
+ "source": [
331
+ "start = time.time()\n",
332
+ "noise = torch.randn(1,1,256).to(device)\n",
333
+ "wav = inference(text, noise, diffusion_steps=5, embedding_scale=1)\n",
334
+ "rtf = (time.time() - start) / (len(wav) / 24000)\n",
335
+ "print(f\"RTF = {rtf:5f}\")\n",
336
+ "import IPython.display as ipd\n",
337
+ "display(ipd.Audio(wav, rate=24000))"
338
+ ],
339
+ "metadata": {
340
+ "id": "KILqC-V-Ay5e"
341
+ },
342
+ "execution_count": null,
343
+ "outputs": []
344
+ },
345
+ {
346
+ "cell_type": "markdown",
347
+ "source": [
348
+ "#### With higher diffusion steps (more diverse)\n",
349
+ "Since the sampler is ancestral, the higher the stpes, the more diverse the samples are, with the cost of slower synthesis speed."
350
+ ],
351
+ "metadata": {
352
+ "id": "oZk9o-EzCBVx"
353
+ }
354
+ },
355
+ {
356
+ "cell_type": "code",
357
+ "source": [
358
+ "start = time.time()\n",
359
+ "noise = torch.randn(1,1,256).to(device)\n",
360
+ "wav = inference(text, noise, diffusion_steps=10, embedding_scale=1)\n",
361
+ "rtf = (time.time() - start) / (len(wav) / 24000)\n",
362
+ "print(f\"RTF = {rtf:5f}\")\n",
363
+ "import IPython.display as ipd\n",
364
+ "display(ipd.Audio(wav, rate=24000))"
365
+ ],
366
+ "metadata": {
367
+ "id": "9_OHtzMbB9gL"
368
+ },
369
+ "execution_count": null,
370
+ "outputs": []
371
+ },
372
+ {
373
+ "cell_type": "markdown",
374
+ "source": [
375
+ "### Speech expressiveness\n",
376
+ "The following section recreates the samples shown in [Section 6](https://styletts2.github.io/#emo) of the demo page."
377
+ ],
378
+ "metadata": {
379
+ "id": "NyDACd-0CaqL"
380
+ }
381
+ },
382
+ {
383
+ "cell_type": "markdown",
384
+ "source": [
385
+ "#### With embedding_scale=1\n",
386
+ "This is the classifier-free guidance scale. The higher the scale, the more conditional the style is to the input text and hence more emotional."
387
+ ],
388
+ "metadata": {
389
+ "id": "cRkS5VWxCck4"
390
+ }
391
+ },
392
+ {
393
+ "cell_type": "code",
394
+ "source": [
395
+ "texts = {}\n",
396
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
397
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
398
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
399
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
400
+ "\n",
401
+ "for k,v in texts.items():\n",
402
+ " noise = torch.randn(1,1,256).to(device)\n",
403
+ " wav = inference(v, noise, diffusion_steps=10, embedding_scale=1)\n",
404
+ " print(k + \": \")\n",
405
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
406
+ ],
407
+ "metadata": {
408
+ "id": "H5g5RO-mCbZB"
409
+ },
410
+ "execution_count": null,
411
+ "outputs": []
412
+ },
413
+ {
414
+ "cell_type": "markdown",
415
+ "source": [
416
+ "#### With embedding_scale=2"
417
+ ],
418
+ "metadata": {
419
+ "id": "f4S8TXSpCgpA"
420
+ }
421
+ },
422
+ {
423
+ "cell_type": "code",
424
+ "source": [
425
+ "texts = {}\n",
426
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
427
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
428
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
429
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
430
+ "\n",
431
+ "for k,v in texts.items():\n",
432
+ " noise = torch.randn(1,1,256).to(device)\n",
433
+ " wav = inference(v, noise, diffusion_steps=10, embedding_scale=2) # embedding_scale=2 for more pronounced emotion\n",
434
+ " print(k + \": \")\n",
435
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
436
+ ],
437
+ "metadata": {
438
+ "id": "xHHIdeNrCezC"
439
+ },
440
+ "execution_count": null,
441
+ "outputs": []
442
+ },
443
+ {
444
+ "cell_type": "markdown",
445
+ "source": [
446
+ "### Long-form generation\n",
447
+ "This section includes basic implementation of Algorithm 1 in the paper for consistent longform audio generation. The example passage is taken from [Section 5](https://styletts2.github.io/#long) of the demo page."
448
+ ],
449
+ "metadata": {
450
+ "id": "nAh7Tov4CkuH"
451
+ }
452
+ },
453
+ {
454
+ "cell_type": "code",
455
+ "source": [
456
+ "passage = '''If the supply of fruit is greater than the family needs, it may be made a source of income by sending the fresh fruit to the market if there is one near enough, or by preserving, canning, and making jelly for sale. To make such an enterprise a success the fruit and work must be first class. There is magic in the word \"Homemade,\" when the product appeals to the eye and the palate; but many careless and incompetent people have found to their sorrow that this word has not magic enough to float inferior goods on the market. As a rule large canning and preserving establishments are clean and have the best appliances, and they employ chemists and skilled labor. The home product must be very good to compete with the attractive goods that are sent out from such establishments. Yet for first-class homemade products there is a market in all large cities. All first-class grocers have customers who purchase such goods.''' # @param {type:\"string\"}"
457
+ ],
458
+ "metadata": {
459
+ "cellView": "form",
460
+ "id": "IJwUbgvACoDu"
461
+ },
462
+ "execution_count": 8,
463
+ "outputs": []
464
+ },
465
+ {
466
+ "cell_type": "code",
467
+ "source": [
468
+ "sentences = passage.split('.') # simple split by comma\n",
469
+ "wavs = []\n",
470
+ "s_prev = None\n",
471
+ "for text in sentences:\n",
472
+ " if text.strip() == \"\": continue\n",
473
+ " text += '.' # add it back\n",
474
+ " noise = torch.randn(1,1,256).to(device)\n",
475
+ " wav, s_prev = LFinference(text, s_prev, noise, alpha=0.7, diffusion_steps=10, embedding_scale=1.5)\n",
476
+ " wavs.append(wav)\n",
477
+ "display(ipd.Audio(np.concatenate(wavs), rate=24000, normalize=False))"
478
+ ],
479
+ "metadata": {
480
+ "id": "nP-7i2QAC0JT"
481
+ },
482
+ "execution_count": null,
483
+ "outputs": []
484
+ }
485
+ ]
486
+ }
Colab/StyleTTS2_Demo_LibriTTS.ipynb ADDED
@@ -0,0 +1,1218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "view-in-github",
7
+ "colab_type": "text"
8
+ },
9
+ "source": [
10
+ "<a href=\"https://colab.research.google.com/github/yl4579/StyleTTS2/blob/main/Colab/StyleTTS2_Demo_LibriTTS.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "metadata": {
16
+ "id": "aAGQPfgYIR23"
17
+ },
18
+ "source": [
19
+ "### Install packages and download models"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": null,
25
+ "metadata": {
26
+ "colab": {
27
+ "base_uri": "https://localhost:8080/"
28
+ },
29
+ "id": "zDPW5uSpISd2",
30
+ "outputId": "6463ff79-18d5-4071-c6ad-01947beeb368"
31
+ },
32
+ "outputs": [
33
+ {
34
+ "output_type": "stream",
35
+ "name": "stdout",
36
+ "text": [
37
+
38
+ ]
39
+ }
40
+ ],
41
+ "source": [
42
+ "%%shell\n",
43
+ "git clone https://github.com/yl4579/StyleTTS2.git\n",
44
+ "cd StyleTTS2\n",
45
+ "pip install SoundFile torchaudio munch torch pydub pyyaml librosa nltk matplotlib accelerate transformers phonemizer einops einops-exts tqdm typing-extensions git+https://github.com/resemble-ai/monotonic_align.git\n",
46
+ "sudo apt-get install espeak-ng\n",
47
+ "git-lfs clone https://huggingface.co/yl4579/StyleTTS2-LibriTTS\n",
48
+ "mv StyleTTS2-LibriTTS/Models .\n",
49
+ "mv StyleTTS2-LibriTTS/reference_audio.zip .\n",
50
+ "unzip reference_audio.zip\n",
51
+ "mv reference_audio Demo/reference_audio"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "markdown",
56
+ "metadata": {
57
+ "id": "eJdB_nCOIVIN"
58
+ },
59
+ "source": [
60
+ "### Load models"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": null,
66
+ "metadata": {
67
+ "id": "cha8Tr2uJwN0"
68
+ },
69
+ "outputs": [],
70
+ "source": [
71
+ "import nltk\n",
72
+ "nltk.download('punkt')"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": null,
78
+ "metadata": {
79
+ "id": "Qoow8Wd8ITtm"
80
+ },
81
+ "outputs": [],
82
+ "source": [
83
+ "%cd StyleTTS2\n",
84
+ "\n",
85
+ "import torch\n",
86
+ "torch.manual_seed(0)\n",
87
+ "torch.backends.cudnn.benchmark = False\n",
88
+ "torch.backends.cudnn.deterministic = True\n",
89
+ "\n",
90
+ "import random\n",
91
+ "random.seed(0)\n",
92
+ "\n",
93
+ "import numpy as np\n",
94
+ "np.random.seed(0)\n",
95
+ "\n",
96
+ "# load packages\n",
97
+ "import time\n",
98
+ "import random\n",
99
+ "import yaml\n",
100
+ "from munch import Munch\n",
101
+ "import numpy as np\n",
102
+ "import torch\n",
103
+ "from torch import nn\n",
104
+ "import torch.nn.functional as F\n",
105
+ "import torchaudio\n",
106
+ "import librosa\n",
107
+ "from nltk.tokenize import word_tokenize\n",
108
+ "\n",
109
+ "from models import *\n",
110
+ "from utils import *\n",
111
+ "from text_utils import TextCleaner\n",
112
+ "textclenaer = TextCleaner()\n",
113
+ "\n",
114
+ "%matplotlib inline\n",
115
+ "\n",
116
+ "to_mel = torchaudio.transforms.MelSpectrogram(\n",
117
+ " n_mels=80, n_fft=2048, win_length=1200, hop_length=300)\n",
118
+ "mean, std = -4, 4\n",
119
+ "\n",
120
+ "def length_to_mask(lengths):\n",
121
+ " mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)\n",
122
+ " mask = torch.gt(mask+1, lengths.unsqueeze(1))\n",
123
+ " return mask\n",
124
+ "\n",
125
+ "def preprocess(wave):\n",
126
+ " wave_tensor = torch.from_numpy(wave).float()\n",
127
+ " mel_tensor = to_mel(wave_tensor)\n",
128
+ " mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std\n",
129
+ " return mel_tensor\n",
130
+ "\n",
131
+ "def compute_style(path):\n",
132
+ " wave, sr = librosa.load(path, sr=24000)\n",
133
+ " audio, index = librosa.effects.trim(wave, top_db=30)\n",
134
+ " if sr != 24000:\n",
135
+ " audio = librosa.resample(audio, sr, 24000)\n",
136
+ " mel_tensor = preprocess(audio).to(device)\n",
137
+ "\n",
138
+ " with torch.no_grad():\n",
139
+ " ref_s = model.style_encoder(mel_tensor.unsqueeze(1))\n",
140
+ " ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))\n",
141
+ "\n",
142
+ " return torch.cat([ref_s, ref_p], dim=1)\n",
143
+ "\n",
144
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
145
+ "\n",
146
+ "# load phonemizer\n",
147
+ "import phonemizer\n",
148
+ "global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)\n",
149
+ "\n",
150
+ "config = yaml.safe_load(open(\"Models/LibriTTS/config.yml\"))\n",
151
+ "\n",
152
+ "# load pretrained ASR model\n",
153
+ "ASR_config = config.get('ASR_config', False)\n",
154
+ "ASR_path = config.get('ASR_path', False)\n",
155
+ "text_aligner = load_ASR_models(ASR_path, ASR_config)\n",
156
+ "\n",
157
+ "# load pretrained F0 model\n",
158
+ "F0_path = config.get('F0_path', False)\n",
159
+ "pitch_extractor = load_F0_models(F0_path)\n",
160
+ "\n",
161
+ "# load BERT model\n",
162
+ "from Utils.PLBERT.util import load_plbert\n",
163
+ "BERT_path = config.get('PLBERT_dir', False)\n",
164
+ "plbert = load_plbert(BERT_path)\n",
165
+ "\n",
166
+ "model_params = recursive_munch(config['model_params'])\n",
167
+ "model = build_model(model_params, text_aligner, pitch_extractor, plbert)\n",
168
+ "_ = [model[key].eval() for key in model]\n",
169
+ "_ = [model[key].to(device) for key in model]\n",
170
+ "\n",
171
+ "params_whole = torch.load(\"Models/LibriTTS/epochs_2nd_00020.pth\", map_location='cpu')\n",
172
+ "params = params_whole['net']\n",
173
+ "\n",
174
+ "for key in model:\n",
175
+ " if key in params:\n",
176
+ " print('%s loaded' % key)\n",
177
+ " try:\n",
178
+ " model[key].load_state_dict(params[key])\n",
179
+ " except:\n",
180
+ " from collections import OrderedDict\n",
181
+ " state_dict = params[key]\n",
182
+ " new_state_dict = OrderedDict()\n",
183
+ " for k, v in state_dict.items():\n",
184
+ " name = k[7:] # remove `module.`\n",
185
+ " new_state_dict[name] = v\n",
186
+ " # load params\n",
187
+ " model[key].load_state_dict(new_state_dict, strict=False)\n",
188
+ "# except:\n",
189
+ "# _load(params[key], model[key])\n",
190
+ "_ = [model[key].eval() for key in model]\n",
191
+ "\n",
192
+ "from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule\n",
193
+ "\n",
194
+ "sampler = DiffusionSampler(\n",
195
+ " model.diffusion.diffusion,\n",
196
+ " sampler=ADPM2Sampler(),\n",
197
+ " sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters\n",
198
+ " clamp=False\n",
199
+ ")\n",
200
+ "\n",
201
+ "def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):\n",
202
+ " text = text.strip()\n",
203
+ " ps = global_phonemizer.phonemize([text])\n",
204
+ " ps = word_tokenize(ps[0])\n",
205
+ " ps = ' '.join(ps)\n",
206
+ " tokens = textclenaer(ps)\n",
207
+ " tokens.insert(0, 0)\n",
208
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
209
+ "\n",
210
+ " with torch.no_grad():\n",
211
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
212
+ " text_mask = length_to_mask(input_lengths).to(device)\n",
213
+ "\n",
214
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
215
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
216
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2)\n",
217
+ "\n",
218
+ " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),\n",
219
+ " embedding=bert_dur,\n",
220
+ " embedding_scale=embedding_scale,\n",
221
+ " features=ref_s, # reference from the same speaker as the embedding\n",
222
+ " num_steps=diffusion_steps).squeeze(1)\n",
223
+ "\n",
224
+ "\n",
225
+ " s = s_pred[:, 128:]\n",
226
+ " ref = s_pred[:, :128]\n",
227
+ "\n",
228
+ " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
229
+ " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
230
+ "\n",
231
+ " d = model.predictor.text_encoder(d_en,\n",
232
+ " s, input_lengths, text_mask)\n",
233
+ "\n",
234
+ " x, _ = model.predictor.lstm(d)\n",
235
+ " duration = model.predictor.duration_proj(x)\n",
236
+ "\n",
237
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
238
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
239
+ "\n",
240
+ "\n",
241
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
242
+ " c_frame = 0\n",
243
+ " for i in range(pred_aln_trg.size(0)):\n",
244
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
245
+ " c_frame += int(pred_dur[i].data)\n",
246
+ "\n",
247
+ " # encode prosody\n",
248
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
249
+ " if model_params.decoder.type == \"hifigan\":\n",
250
+ " asr_new = torch.zeros_like(en)\n",
251
+ " asr_new[:, :, 0] = en[:, :, 0]\n",
252
+ " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
253
+ " en = asr_new\n",
254
+ "\n",
255
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
256
+ "\n",
257
+ " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
258
+ " if model_params.decoder.type == \"hifigan\":\n",
259
+ " asr_new = torch.zeros_like(asr)\n",
260
+ " asr_new[:, :, 0] = asr[:, :, 0]\n",
261
+ " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
262
+ " asr = asr_new\n",
263
+ "\n",
264
+ " out = model.decoder(asr,\n",
265
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
266
+ "\n",
267
+ "\n",
268
+ " return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later\n",
269
+ "\n",
270
+ "def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1):\n",
271
+ " text = text.strip()\n",
272
+ " ps = global_phonemizer.phonemize([text])\n",
273
+ " ps = word_tokenize(ps[0])\n",
274
+ " ps = ' '.join(ps)\n",
275
+ " ps = ps.replace('``', '\"')\n",
276
+ " ps = ps.replace(\"''\", '\"')\n",
277
+ "\n",
278
+ " tokens = textclenaer(ps)\n",
279
+ " tokens.insert(0, 0)\n",
280
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
281
+ "\n",
282
+ " with torch.no_grad():\n",
283
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
284
+ " text_mask = length_to_mask(input_lengths).to(device)\n",
285
+ "\n",
286
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
287
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
288
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2)\n",
289
+ "\n",
290
+ " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),\n",
291
+ " embedding=bert_dur,\n",
292
+ " embedding_scale=embedding_scale,\n",
293
+ " features=ref_s, # reference from the same speaker as the embedding\n",
294
+ " num_steps=diffusion_steps).squeeze(1)\n",
295
+ "\n",
296
+ " if s_prev is not None:\n",
297
+ " # convex combination of previous and current style\n",
298
+ " s_pred = t * s_prev + (1 - t) * s_pred\n",
299
+ "\n",
300
+ " s = s_pred[:, 128:]\n",
301
+ " ref = s_pred[:, :128]\n",
302
+ "\n",
303
+ " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
304
+ " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
305
+ "\n",
306
+ " s_pred = torch.cat([ref, s], dim=-1)\n",
307
+ "\n",
308
+ " d = model.predictor.text_encoder(d_en,\n",
309
+ " s, input_lengths, text_mask)\n",
310
+ "\n",
311
+ " x, _ = model.predictor.lstm(d)\n",
312
+ " duration = model.predictor.duration_proj(x)\n",
313
+ "\n",
314
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
315
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
316
+ "\n",
317
+ "\n",
318
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
319
+ " c_frame = 0\n",
320
+ " for i in range(pred_aln_trg.size(0)):\n",
321
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
322
+ " c_frame += int(pred_dur[i].data)\n",
323
+ "\n",
324
+ " # encode prosody\n",
325
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
326
+ " if model_params.decoder.type == \"hifigan\":\n",
327
+ " asr_new = torch.zeros_like(en)\n",
328
+ " asr_new[:, :, 0] = en[:, :, 0]\n",
329
+ " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
330
+ " en = asr_new\n",
331
+ "\n",
332
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
333
+ "\n",
334
+ " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
335
+ " if model_params.decoder.type == \"hifigan\":\n",
336
+ " asr_new = torch.zeros_like(asr)\n",
337
+ " asr_new[:, :, 0] = asr[:, :, 0]\n",
338
+ " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
339
+ " asr = asr_new\n",
340
+ "\n",
341
+ " out = model.decoder(asr,\n",
342
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
343
+ "\n",
344
+ "\n",
345
+ " return out.squeeze().cpu().numpy()[..., :-100], s_pred # weird pulse at the end of the model, need to be fixed later\n",
346
+ "\n",
347
+ "def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):\n",
348
+ " text = text.strip()\n",
349
+ " ps = global_phonemizer.phonemize([text])\n",
350
+ " ps = word_tokenize(ps[0])\n",
351
+ " ps = ' '.join(ps)\n",
352
+ "\n",
353
+ " tokens = textclenaer(ps)\n",
354
+ " tokens.insert(0, 0)\n",
355
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
356
+ "\n",
357
+ " ref_text = ref_text.strip()\n",
358
+ " ps = global_phonemizer.phonemize([ref_text])\n",
359
+ " ps = word_tokenize(ps[0])\n",
360
+ " ps = ' '.join(ps)\n",
361
+ "\n",
362
+ " ref_tokens = textclenaer(ps)\n",
363
+ " ref_tokens.insert(0, 0)\n",
364
+ " ref_tokens = torch.LongTensor(ref_tokens).to(device).unsqueeze(0)\n",
365
+ "\n",
366
+ "\n",
367
+ " with torch.no_grad():\n",
368
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
369
+ " text_mask = length_to_mask(input_lengths).to(device)\n",
370
+ "\n",
371
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
372
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
373
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2)\n",
374
+ "\n",
375
+ " ref_input_lengths = torch.LongTensor([ref_tokens.shape[-1]]).to(device)\n",
376
+ " ref_text_mask = length_to_mask(ref_input_lengths).to(device)\n",
377
+ " ref_bert_dur = model.bert(ref_tokens, attention_mask=(~ref_text_mask).int())\n",
378
+ " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),\n",
379
+ " embedding=bert_dur,\n",
380
+ " embedding_scale=embedding_scale,\n",
381
+ " features=ref_s, # reference from the same speaker as the embedding\n",
382
+ " num_steps=diffusion_steps).squeeze(1)\n",
383
+ "\n",
384
+ "\n",
385
+ " s = s_pred[:, 128:]\n",
386
+ " ref = s_pred[:, :128]\n",
387
+ "\n",
388
+ " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
389
+ " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
390
+ "\n",
391
+ " d = model.predictor.text_encoder(d_en,\n",
392
+ " s, input_lengths, text_mask)\n",
393
+ "\n",
394
+ " x, _ = model.predictor.lstm(d)\n",
395
+ " duration = model.predictor.duration_proj(x)\n",
396
+ "\n",
397
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
398
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
399
+ "\n",
400
+ "\n",
401
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
402
+ " c_frame = 0\n",
403
+ " for i in range(pred_aln_trg.size(0)):\n",
404
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
405
+ " c_frame += int(pred_dur[i].data)\n",
406
+ "\n",
407
+ " # encode prosody\n",
408
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
409
+ " if model_params.decoder.type == \"hifigan\":\n",
410
+ " asr_new = torch.zeros_like(en)\n",
411
+ " asr_new[:, :, 0] = en[:, :, 0]\n",
412
+ " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
413
+ " en = asr_new\n",
414
+ "\n",
415
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
416
+ "\n",
417
+ " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
418
+ " if model_params.decoder.type == \"hifigan\":\n",
419
+ " asr_new = torch.zeros_like(asr)\n",
420
+ " asr_new[:, :, 0] = asr[:, :, 0]\n",
421
+ " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
422
+ " asr = asr_new\n",
423
+ "\n",
424
+ " out = model.decoder(asr,\n",
425
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
426
+ "\n",
427
+ "\n",
428
+ " return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later\n"
429
+ ]
430
+ },
431
+ {
432
+ "cell_type": "markdown",
433
+ "metadata": {
434
+ "id": "32S6U0LyJbCA"
435
+ },
436
+ "source": [
437
+ "### Synthesize speech"
438
+ ]
439
+ },
440
+ {
441
+ "cell_type": "markdown",
442
+ "metadata": {
443
+ "id": "ehK_0daMJdk_"
444
+ },
445
+ "source": [
446
+ "#### Basic synthesis (5 diffusion steps, seen speakers)"
447
+ ]
448
+ },
449
+ {
450
+ "cell_type": "code",
451
+ "execution_count": null,
452
+ "metadata": {
453
+ "id": "SJs2x41MJhM-"
454
+ },
455
+ "outputs": [],
456
+ "source": [
457
+ "text = ''' StyleTTS 2 is a text to speech model that leverages style diffusion and adversarial training with large speech language models to achieve human level text to speech synthesis. ''' # @param {type:\"string\"}\n"
458
+ ]
459
+ },
460
+ {
461
+ "cell_type": "code",
462
+ "execution_count": null,
463
+ "metadata": {
464
+ "id": "xuqIJe-IJb7A"
465
+ },
466
+ "outputs": [],
467
+ "source": [
468
+ "reference_dicts = {}\n",
469
+ "reference_dicts['696_92939'] = \"Demo/reference_audio/696_92939_000016_000006.wav\"\n",
470
+ "reference_dicts['1789_142896'] = \"Demo/reference_audio/1789_142896_000022_000005.wav\""
471
+ ]
472
+ },
473
+ {
474
+ "cell_type": "code",
475
+ "execution_count": null,
476
+ "metadata": {
477
+ "id": "H3ra3IxJJmF0"
478
+ },
479
+ "outputs": [],
480
+ "source": [
481
+ "noise = torch.randn(1,1,256).to(device)\n",
482
+ "for k, path in reference_dicts.items():\n",
483
+ " ref_s = compute_style(path)\n",
484
+ " start = time.time()\n",
485
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=5, embedding_scale=1)\n",
486
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
487
+ " print(f\"RTF = {rtf:5f}\")\n",
488
+ " import IPython.display as ipd\n",
489
+ " print(k + ' Synthesized:')\n",
490
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
491
+ " print('Reference:')\n",
492
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
493
+ ]
494
+ },
495
+ {
496
+ "cell_type": "markdown",
497
+ "metadata": {
498
+ "id": "aB3wUz6yJ-P_"
499
+ },
500
+ "source": [
501
+ "#### With higher diffusion steps (more diverse)\n",
502
+ "\n",
503
+ "Since the sampler is ancestral, the higher the stpes, the more diverse the samples are, with the cost of slower synthesis speed."
504
+ ]
505
+ },
506
+ {
507
+ "cell_type": "code",
508
+ "execution_count": null,
509
+ "metadata": {
510
+ "id": "lF27XUo4JrKk"
511
+ },
512
+ "outputs": [],
513
+ "source": [
514
+ "noise = torch.randn(1,1,256).to(device)\n",
515
+ "for k, path in reference_dicts.items():\n",
516
+ " ref_s = compute_style(path)\n",
517
+ " start = time.time()\n",
518
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=10, embedding_scale=1)\n",
519
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
520
+ " print(f\"RTF = {rtf:5f}\")\n",
521
+ " import IPython.display as ipd\n",
522
+ " print(k + ' Synthesized:')\n",
523
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
524
+ " print(k + ' Reference:')\n",
525
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
526
+ ]
527
+ },
528
+ {
529
+ "cell_type": "markdown",
530
+ "metadata": {
531
+ "id": "pFT_vmJcKDs1"
532
+ },
533
+ "source": [
534
+ "#### Basic synthesis (5 diffusion steps, unseen speakers)\n",
535
+ "The following samples are to reproduce samples in [Section 4](https://styletts2.github.io/#libri) of the demo page. All spsakers are unseen during training. You can compare the generated samples to popular zero-shot TTS models like Vall-E and NaturalSpeech 2."
536
+ ]
537
+ },
538
+ {
539
+ "cell_type": "code",
540
+ "execution_count": null,
541
+ "metadata": {
542
+ "id": "HvNAeGPEKAWN"
543
+ },
544
+ "outputs": [],
545
+ "source": [
546
+ "reference_dicts = {}\n",
547
+ "# format: (path, text)\n",
548
+ "reference_dicts['1221-135767'] = (\"Demo/reference_audio/1221-135767-0014.wav\", \"Yea, his honourable worship is within, but he hath a godly minister or two with him, and likewise a leech.\")\n",
549
+ "reference_dicts['5639-40744'] = (\"Demo/reference_audio/5639-40744-0020.wav\", \"Thus did this humane and right minded father comfort his unhappy daughter, and her mother embracing her again, did all she could to soothe her feelings.\")\n",
550
+ "reference_dicts['908-157963'] = (\"Demo/reference_audio/908-157963-0027.wav\", \"And lay me down in my cold bed and leave my shining lot.\")\n",
551
+ "reference_dicts['4077-13754'] = (\"Demo/reference_audio/4077-13754-0000.wav\", \"The army found the people in poverty and left them in comparative wealth.\")"
552
+ ]
553
+ },
554
+ {
555
+ "cell_type": "code",
556
+ "execution_count": null,
557
+ "metadata": {
558
+ "id": "mFnyvYp5KAYN"
559
+ },
560
+ "outputs": [],
561
+ "source": [
562
+ "noise = torch.randn(1,1,256).to(device)\n",
563
+ "for k, v in reference_dicts.items():\n",
564
+ " path, text = v\n",
565
+ " ref_s = compute_style(path)\n",
566
+ " start = time.time()\n",
567
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=5, embedding_scale=1)\n",
568
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
569
+ " print(f\"RTF = {rtf:5f}\")\n",
570
+ " import IPython.display as ipd\n",
571
+ " print(k + ' Synthesized: ' + text)\n",
572
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
573
+ " print(k + ' Reference:')\n",
574
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
575
+ ]
576
+ },
577
+ {
578
+ "cell_type": "markdown",
579
+ "metadata": {
580
+ "id": "QBZ53BQtKNQ6"
581
+ },
582
+ "source": [
583
+ "### Speech expressiveness\n",
584
+ "\n",
585
+ "The following section recreates the samples shown in [Section 6](https://styletts2.github.io/#emo) of the demo page. The speaker reference used is `1221-135767-0014.wav`, which is unseen during training.\n",
586
+ "\n",
587
+ "#### With `embedding_scale=1`\n",
588
+ "This is the classifier-free guidance scale. The higher the scale, the more conditional the style is to the input text and hence more emotional."
589
+ ]
590
+ },
591
+ {
592
+ "cell_type": "code",
593
+ "execution_count": null,
594
+ "metadata": {
595
+ "id": "5FwE9CefKQk6"
596
+ },
597
+ "outputs": [],
598
+ "source": [
599
+ "ref_s = compute_style(\"Demo/reference_audio/1221-135767-0014.wav\")"
600
+ ]
601
+ },
602
+ {
603
+ "cell_type": "code",
604
+ "execution_count": null,
605
+ "metadata": {
606
+ "id": "0CKMI0ZsKUDh"
607
+ },
608
+ "outputs": [],
609
+ "source": [
610
+ "texts = {}\n",
611
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
612
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
613
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
614
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
615
+ "\n",
616
+ "for k,v in texts.items():\n",
617
+ " wav = inference(v, ref_s, diffusion_steps=10, alpha=0.3, beta=0.7, embedding_scale=1)\n",
618
+ " print(k + \": \")\n",
619
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
620
+ ]
621
+ },
622
+ {
623
+ "cell_type": "markdown",
624
+ "metadata": {
625
+ "id": "reemQKVEKWAZ"
626
+ },
627
+ "source": [
628
+ "#### With `embedding_scale=2`"
629
+ ]
630
+ },
631
+ {
632
+ "cell_type": "code",
633
+ "execution_count": null,
634
+ "metadata": {
635
+ "id": "npIAiAUvKYGv"
636
+ },
637
+ "outputs": [],
638
+ "source": [
639
+ "texts = {}\n",
640
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
641
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
642
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
643
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
644
+ "\n",
645
+ "for k,v in texts.items():\n",
646
+ " noise = torch.randn(1,1,256).to(device)\n",
647
+ " wav = inference(v, ref_s, diffusion_steps=10, alpha=0.3, beta=0.7, embedding_scale=2)\n",
648
+ " print(k + \": \")\n",
649
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
650
+ ]
651
+ },
652
+ {
653
+ "cell_type": "markdown",
654
+ "metadata": {
655
+ "id": "lqKZaXeYKbrH"
656
+ },
657
+ "source": [
658
+ "#### With `embedding_scale=2, alpha = 0.5, beta = 0.9`\n",
659
+ "`alpha` and `beta` is the factor to determine much we use the style sampled based on the text instead of the reference. The higher the value of `alpha` and `beta`, the more suitable the style it is to the text but less similar to the reference. Using higher beta makes the synthesized speech more emotional, at the cost of lower similarity to the reference. `alpha` determines the timbre of the speaker while `beta` determines the prosody."
660
+ ]
661
+ },
662
+ {
663
+ "cell_type": "code",
664
+ "execution_count": null,
665
+ "metadata": {
666
+ "id": "VjXuRCCWKcdN"
667
+ },
668
+ "outputs": [],
669
+ "source": [
670
+ "texts = {}\n",
671
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
672
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
673
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
674
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
675
+ "\n",
676
+ "for k,v in texts.items():\n",
677
+ " noise = torch.randn(1,1,256).to(device)\n",
678
+ " wav = inference(v, ref_s, diffusion_steps=10, alpha=0.5, beta=0.9, embedding_scale=2)\n",
679
+ " print(k + \": \")\n",
680
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
681
+ ]
682
+ },
683
+ {
684
+ "cell_type": "markdown",
685
+ "metadata": {
686
+ "id": "xrwYXGh0KiIW"
687
+ },
688
+ "source": [
689
+ "### Zero-shot speaker adaptation\n",
690
+ "This section recreates the \"Acoustic Environment Maintenance\" and \"Speaker’s Emotion Maintenance\" demo in [Section 4](https://styletts2.github.io/#libri) of the demo page. You can compare the generated samples to popular zero-shot TTS models like Vall-E. Note that the model was trained only on LibriTTS, which is about 250 times fewer data compared to those used to trian Vall-E with similar or better effect for these maintainance."
691
+ ]
692
+ },
693
+ {
694
+ "cell_type": "markdown",
695
+ "metadata": {
696
+ "id": "ETUywHHmKimE"
697
+ },
698
+ "source": [
699
+ "#### Acoustic Environment Maintenance\n",
700
+ "\n",
701
+ "Since we want to maintain the acoustic environment in the speaker (timbre), we set `alpha = 0` to make the speaker as close to the reference as possible while only changing the prosody according to the text. "
702
+ ]
703
+ },
704
+ {
705
+ "cell_type": "code",
706
+ "execution_count": null,
707
+ "metadata": {
708
+ "id": "yvjBK3syKnZL"
709
+ },
710
+ "outputs": [],
711
+ "source": [
712
+ "reference_dicts = {}\n",
713
+ "# format: (path, text)\n",
714
+ "reference_dicts['3'] = (\"Demo/reference_audio/3.wav\", \"As friends thing I definitely I've got more male friends.\")\n",
715
+ "reference_dicts['4'] = (\"Demo/reference_audio/4.wav\", \"Everything is run by computer but you got to know how to think before you can do a computer.\")\n",
716
+ "reference_dicts['5'] = (\"Demo/reference_audio/5.wav\", \"Then out in LA you guys got a whole another ball game within California to worry about.\")"
717
+ ]
718
+ },
719
+ {
720
+ "cell_type": "code",
721
+ "execution_count": null,
722
+ "metadata": {
723
+ "id": "jclowWp4KomJ"
724
+ },
725
+ "outputs": [],
726
+ "source": [
727
+ "noise = torch.randn(1,1,256).to(device)\n",
728
+ "for k, v in reference_dicts.items():\n",
729
+ " path, text = v\n",
730
+ " ref_s = compute_style(path)\n",
731
+ " start = time.time()\n",
732
+ " wav = inference(text, ref_s, alpha=0.0, beta=0.5, diffusion_steps=5, embedding_scale=1)\n",
733
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
734
+ " print(f\"RTF = {rtf:5f}\")\n",
735
+ " import IPython.display as ipd\n",
736
+ " print('Synthesized: ' + text)\n",
737
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
738
+ " print('Reference:')\n",
739
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
740
+ ]
741
+ },
742
+ {
743
+ "cell_type": "markdown",
744
+ "metadata": {
745
+ "id": "LgIm7M93KqVZ"
746
+ },
747
+ "source": [
748
+ "#### Speaker’s Emotion Maintenance\n",
749
+ "\n",
750
+ "Since we want to maintain the emotion in the speaker (prosody), we set `beta = 0.1` to make the speaker as closer to the reference as possible while having some diversity thruogh the slight timbre change."
751
+ ]
752
+ },
753
+ {
754
+ "cell_type": "code",
755
+ "execution_count": null,
756
+ "metadata": {
757
+ "id": "yzsNoP6oKulL"
758
+ },
759
+ "outputs": [],
760
+ "source": [
761
+ "reference_dicts = {}\n",
762
+ "# format: (path, text)\n",
763
+ "reference_dicts['Anger'] = (\"Demo/reference_audio/anger.wav\", \"We have to reduce the number of plastic bags.\")\n",
764
+ "reference_dicts['Sleepy'] = (\"Demo/reference_audio/sleepy.wav\", \"We have to reduce the number of plastic bags.\")\n",
765
+ "reference_dicts['Amused'] = (\"Demo/reference_audio/amused.wav\", \"We have to reduce the number of plastic bags.\")\n",
766
+ "reference_dicts['Disgusted'] = (\"Demo/reference_audio/disgusted.wav\", \"We have to reduce the number of plastic bags.\")"
767
+ ]
768
+ },
769
+ {
770
+ "cell_type": "code",
771
+ "execution_count": null,
772
+ "metadata": {
773
+ "id": "7h2-9cpfKwr4"
774
+ },
775
+ "outputs": [],
776
+ "source": [
777
+ "noise = torch.randn(1,1,256).to(device)\n",
778
+ "for k, v in reference_dicts.items():\n",
779
+ " path, text = v\n",
780
+ " ref_s = compute_style(path)\n",
781
+ " start = time.time()\n",
782
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.1, diffusion_steps=10, embedding_scale=1)\n",
783
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
784
+ " print(f\"RTF = {rtf:5f}\")\n",
785
+ " import IPython.display as ipd\n",
786
+ " print(k + ' Synthesized: ' + text)\n",
787
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
788
+ " print(k + ' Reference:')\n",
789
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
790
+ ]
791
+ },
792
+ {
793
+ "cell_type": "markdown",
794
+ "metadata": {
795
+ "id": "aNS82PGwKzgg"
796
+ },
797
+ "source": [
798
+ "### Longform Narration\n",
799
+ "\n",
800
+ "This section includes basic implementation of Algorithm 1 in the paper for consistent longform audio generation. The example passage is taken from [Section 5](https://styletts2.github.io/#long) of the demo page."
801
+ ]
802
+ },
803
+ {
804
+ "cell_type": "code",
805
+ "execution_count": null,
806
+ "metadata": {
807
+ "cellView": "form",
808
+ "id": "qs97nL5HK5DH"
809
+ },
810
+ "outputs": [],
811
+ "source": [
812
+ "passage = passage = '''If the supply of fruit is greater than the family needs, it may be made a source of income by sending the fresh fruit to the market if there is one near enough, or by preserving, canning, and making jelly for sale. To make such an enterprise a success the fruit and work must be first class. There is magic in the word \"Homemade,\" when the product appeals to the eye and the palate; but many careless and incompetent people have found to their sorrow that this word has not magic enough to float inferior goods on the market. As a rule large canning and preserving establishments are clean and have the best appliances, and they employ chemists and skilled labor. The home product must be very good to compete with the attractive goods that are sent out from such establishments. Yet for first class home made products there is a market in all large cities. All first-class grocers have customers who purchase such goods.''' # @param {type:\"string\"}"
813
+ ]
814
+ },
815
+ {
816
+ "cell_type": "code",
817
+ "execution_count": null,
818
+ "metadata": {
819
+ "colab": {
820
+ "background_save": true
821
+ },
822
+ "id": "8Mu9whHYK_1b"
823
+ },
824
+ "outputs": [],
825
+ "source": [
826
+ "# seen speaker\n",
827
+ "path = \"Demo/reference_audio/696_92939_000016_000006.wav\"\n",
828
+ "s_ref = compute_style(path)\n",
829
+ "sentences = passage.split('.') # simple split by comma\n",
830
+ "wavs = []\n",
831
+ "s_prev = None\n",
832
+ "for text in sentences:\n",
833
+ " if text.strip() == \"\": continue\n",
834
+ " text += '.' # add it back\n",
835
+ "\n",
836
+ " wav, s_prev = LFinference(text,\n",
837
+ " s_prev,\n",
838
+ " s_ref,\n",
839
+ " alpha = 0.3,\n",
840
+ " beta = 0.9, # make it more suitable for the text\n",
841
+ " t = 0.7,\n",
842
+ " diffusion_steps=10, embedding_scale=1.5)\n",
843
+ " wavs.append(wav)\n",
844
+ "print('Synthesized: ')\n",
845
+ "display(ipd.Audio(np.concatenate(wavs), rate=24000, normalize=False))\n",
846
+ "print('Reference: ')\n",
847
+ "display(ipd.Audio(path, rate=24000, normalize=False))"
848
+ ]
849
+ },
850
+ {
851
+ "cell_type": "markdown",
852
+ "metadata": {
853
+ "id": "81Rh-lgWLB2i"
854
+ },
855
+ "source": [
856
+ "### Style Transfer\n",
857
+ "\n",
858
+ "The following section demostrates the style transfer capacity for unseen speakers in [Section 6](https://styletts2.github.io/#emo) of the demo page. For this, we set `alpha=0.5, beta = 0.9` for the most pronounced effects (mostly using the sampled style)."
859
+ ]
860
+ },
861
+ {
862
+ "cell_type": "code",
863
+ "execution_count": null,
864
+ "metadata": {
865
+ "id": "CtIgr5kOLE9a"
866
+ },
867
+ "outputs": [],
868
+ "source": [
869
+ "# reference texts to sample styles\n",
870
+ "\n",
871
+ "ref_texts = {}\n",
872
+ "ref_texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
873
+ "ref_texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
874
+ "ref_texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
875
+ "ref_texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\""
876
+ ]
877
+ },
878
+ {
879
+ "cell_type": "code",
880
+ "execution_count": null,
881
+ "metadata": {
882
+ "id": "MlA1CbhzLIoI"
883
+ },
884
+ "outputs": [],
885
+ "source": [
886
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
887
+ "s_ref = compute_style(path)\n",
888
+ "\n",
889
+ "text = \"Yea, his honourable worship is within, but he hath a godly minister or two with him, and likewise a leech.\"\n",
890
+ "for k,v in ref_texts.items():\n",
891
+ " wav = STinference(text, s_ref, v, diffusion_steps=10, alpha=0.5, beta=0.9, embedding_scale=1.5)\n",
892
+ " print(k + \": \")\n",
893
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
894
+ ]
895
+ },
896
+ {
897
+ "cell_type": "markdown",
898
+ "metadata": {
899
+ "id": "2M0iaXlkLJUQ"
900
+ },
901
+ "source": [
902
+ "### Speech diversity\n",
903
+ "\n",
904
+ "This section reproduces samples in [Section 7](https://styletts2.github.io/#var) of the demo page.\n",
905
+ "\n",
906
+ "`alpha` and `beta` determine the diversity of the synthesized speech. There are two extreme cases:\n",
907
+ "- If `alpha = 1` and `beta = 1`, the synthesized speech sounds the most dissimilar to the reference speaker, but it is also the most diverse (each time you synthesize a speech it will be totally different).\n",
908
+ "- If `alpha = 0` and `beta = 0`, the synthesized speech sounds the most siimlar to the reference speaker, but it is deterministic (i.e., the sampled style is not used for speech synthesis).\n"
909
+ ]
910
+ },
911
+ {
912
+ "cell_type": "markdown",
913
+ "metadata": {
914
+ "id": "tSxZDvF2LNu4"
915
+ },
916
+ "source": [
917
+ "#### Default setting (`alpha = 0.3, beta=0.7`)\n",
918
+ "This setting uses 70% of the reference timbre and 30% of the reference prosody and use the diffusion model to sample them based on the text."
919
+ ]
920
+ },
921
+ {
922
+ "cell_type": "code",
923
+ "execution_count": null,
924
+ "metadata": {
925
+ "id": "AAomGCDZLIt5"
926
+ },
927
+ "outputs": [],
928
+ "source": [
929
+ "# unseen speaker\n",
930
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
931
+ "ref_s = compute_style(path)\n",
932
+ "\n",
933
+ "text = \"How much variation is there?\"\n",
934
+ "for _ in range(5):\n",
935
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0.3, beta=0.7, embedding_scale=1)\n",
936
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
937
+ ]
938
+ },
939
+ {
940
+ "cell_type": "markdown",
941
+ "metadata": {
942
+ "id": "BKrSMdgcLQRP"
943
+ },
944
+ "source": [
945
+ "#### Less diverse setting (`alpha = 0.1, beta=0.3`)\n",
946
+ "This setting uses 90% of the reference timbre and 70% of the reference prosody. This makes it more similar to the reference speaker at cost of less diverse samples."
947
+ ]
948
+ },
949
+ {
950
+ "cell_type": "code",
951
+ "execution_count": null,
952
+ "metadata": {
953
+ "id": "Uo7gVmFoLRfm"
954
+ },
955
+ "outputs": [],
956
+ "source": [
957
+ "# unseen speaker\n",
958
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
959
+ "ref_s = compute_style(path)\n",
960
+ "\n",
961
+ "text = \"How much variation is there?\"\n",
962
+ "for _ in range(5):\n",
963
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0.1, beta=0.3, embedding_scale=1)\n",
964
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
965
+ ]
966
+ },
967
+ {
968
+ "cell_type": "markdown",
969
+ "metadata": {
970
+ "id": "nfQ0Xrg9LStd"
971
+ },
972
+ "source": [
973
+ "#### More diverse setting (`alpha = 0.5, beta=0.95`)\n",
974
+ "This setting uses 50% of the reference timbre and 5% of the reference prosody (so it uses 100% of the sampled prosody, which makes it more diverse), but this makes it more dissimilar to the reference speaker. "
975
+ ]
976
+ },
977
+ {
978
+ "cell_type": "code",
979
+ "execution_count": null,
980
+ "metadata": {
981
+ "id": "cPHz4BzVLT_u"
982
+ },
983
+ "outputs": [],
984
+ "source": [
985
+ "# unseen speaker\n",
986
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
987
+ "ref_s = compute_style(path)\n",
988
+ "\n",
989
+ "text = \"How much variation is there?\"\n",
990
+ "for _ in range(5):\n",
991
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0.5, beta=0.95, embedding_scale=1)\n",
992
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
993
+ ]
994
+ },
995
+ {
996
+ "cell_type": "markdown",
997
+ "source": [
998
+ "#### Extreme setting (`alpha = 1, beta=1`)\n",
999
+ "This setting uses 0% of the reference timbre and prosody and use the diffusion model to sample the entire style. This makes the speaker very dissimilar to the reference speaker."
1000
+ ],
1001
+ "metadata": {
1002
+ "id": "hPKg9eYpL00f"
1003
+ }
1004
+ },
1005
+ {
1006
+ "cell_type": "code",
1007
+ "source": [
1008
+ "# unseen speaker\n",
1009
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
1010
+ "ref_s = compute_style(path)\n",
1011
+ "\n",
1012
+ "text = \"How much variation is there?\"\n",
1013
+ "for _ in range(5):\n",
1014
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=1, beta=1, embedding_scale=1)\n",
1015
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
1016
+ ],
1017
+ "metadata": {
1018
+ "id": "Ei-7JOccL0bF"
1019
+ },
1020
+ "execution_count": null,
1021
+ "outputs": []
1022
+ },
1023
+ {
1024
+ "cell_type": "markdown",
1025
+ "source": [
1026
+ "#### No variation (`alpha = 0, beta=0`)\n",
1027
+ "This setting uses 100% of the reference timbre and prosody and do not use the diffusion model at all. This makes the speaker very similar to the reference speaker, but there is no variation."
1028
+ ],
1029
+ "metadata": {
1030
+ "id": "FVMPc3bhL3eL"
1031
+ }
1032
+ },
1033
+ {
1034
+ "cell_type": "code",
1035
+ "source": [
1036
+ "# unseen speaker\n",
1037
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
1038
+ "ref_s = compute_style(path)\n",
1039
+ "\n",
1040
+ "text = \"How much variation is there?\"\n",
1041
+ "for _ in range(5):\n",
1042
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0, beta=0, embedding_scale=1)\n",
1043
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
1044
+ ],
1045
+ "metadata": {
1046
+ "id": "yh1QZ7uhL4wM"
1047
+ },
1048
+ "execution_count": null,
1049
+ "outputs": []
1050
+ },
1051
+ {
1052
+ "cell_type": "markdown",
1053
+ "source": [
1054
+ "### Extra fun!\n",
1055
+ "\n",
1056
+ "You can record your own voice and clone it using pre-trained StyleTTS 2 model here."
1057
+ ],
1058
+ "metadata": {
1059
+ "id": "T0EvkWrAMBDB"
1060
+ }
1061
+ },
1062
+ {
1063
+ "cell_type": "markdown",
1064
+ "source": [
1065
+ "#### Run the following cell to record your voice for 5 seconds. Please keep speaking to have the best effect."
1066
+ ],
1067
+ "metadata": {
1068
+ "id": "R985j5QONY8I"
1069
+ }
1070
+ },
1071
+ {
1072
+ "cell_type": "code",
1073
+ "source": [
1074
+ "# all imports\n",
1075
+ "from IPython.display import Javascript\n",
1076
+ "from google.colab import output\n",
1077
+ "from base64 import b64decode\n",
1078
+ "\n",
1079
+ "RECORD = \"\"\"\n",
1080
+ "const sleep = time => new Promise(resolve => setTimeout(resolve, time))\n",
1081
+ "const b2text = blob => new Promise(resolve => {\n",
1082
+ " const reader = new FileReader()\n",
1083
+ " reader.onloadend = e => resolve(e.srcElement.result)\n",
1084
+ " reader.readAsDataURL(blob)\n",
1085
+ "})\n",
1086
+ "var record = time => new Promise(async resolve => {\n",
1087
+ " stream = await navigator.mediaDevices.getUserMedia({ audio: true })\n",
1088
+ " recorder = new MediaRecorder(stream)\n",
1089
+ " chunks = []\n",
1090
+ " recorder.ondataavailable = e => chunks.push(e.data)\n",
1091
+ " recorder.start()\n",
1092
+ " await sleep(time)\n",
1093
+ " recorder.onstop = async ()=>{\n",
1094
+ " blob = new Blob(chunks)\n",
1095
+ " text = await b2text(blob)\n",
1096
+ " resolve(text)\n",
1097
+ " }\n",
1098
+ " recorder.stop()\n",
1099
+ "})\n",
1100
+ "\"\"\"\n",
1101
+ "\n",
1102
+ "def record(sec=3):\n",
1103
+ " display(Javascript(RECORD))\n",
1104
+ " s = output.eval_js('record(%d)' % (sec*1000))\n",
1105
+ " b = b64decode(s.split(',')[1])\n",
1106
+ " with open('audio.wav','wb') as f:\n",
1107
+ " f.write(b)\n",
1108
+ " return 'audio.wav' # or webm ?"
1109
+ ],
1110
+ "metadata": {
1111
+ "id": "MWrFs0KWMBpz"
1112
+ },
1113
+ "execution_count": null,
1114
+ "outputs": []
1115
+ },
1116
+ {
1117
+ "cell_type": "markdown",
1118
+ "source": [
1119
+ "#### Please run this cell and speak:"
1120
+ ],
1121
+ "metadata": {
1122
+ "id": "z35qXwM0Nhx1"
1123
+ }
1124
+ },
1125
+ {
1126
+ "cell_type": "code",
1127
+ "source": [
1128
+ "print('Speak now for 5 seconds.')\n",
1129
+ "audio = record(sec=5)\n",
1130
+ "import IPython.display as ipd\n",
1131
+ "display(ipd.Audio(audio, rate=24000, normalize=False))"
1132
+ ],
1133
+ "metadata": {
1134
+ "id": "KUEoFyQBMR-8"
1135
+ },
1136
+ "execution_count": null,
1137
+ "outputs": []
1138
+ },
1139
+ {
1140
+ "cell_type": "markdown",
1141
+ "source": [
1142
+ "#### Synthesize in your own voice"
1143
+ ],
1144
+ "metadata": {
1145
+ "id": "OQS_7IBpNmM1"
1146
+ }
1147
+ },
1148
+ {
1149
+ "cell_type": "code",
1150
+ "source": [
1151
+ "text = ''' StyleTTS 2 is a text to speech model that leverages style diffusion and adversarial training with large speech language models to achieve human level text to speech synthesis. ''' # @param {type:\"string\"}\n"
1152
+ ],
1153
+ "metadata": {
1154
+ "cellView": "form",
1155
+ "id": "c0I3LY7vM8Ta"
1156
+ },
1157
+ "execution_count": null,
1158
+ "outputs": []
1159
+ },
1160
+ {
1161
+ "cell_type": "code",
1162
+ "source": [
1163
+ "reference_dicts = {}\n",
1164
+ "reference_dicts['You'] = audio"
1165
+ ],
1166
+ "metadata": {
1167
+ "id": "80eW-pwxNCxu"
1168
+ },
1169
+ "execution_count": null,
1170
+ "outputs": []
1171
+ },
1172
+ {
1173
+ "cell_type": "code",
1174
+ "source": [
1175
+ "start = time.time()\n",
1176
+ "noise = torch.randn(1,1,256).to(device)\n",
1177
+ "for k, path in reference_dicts.items():\n",
1178
+ " ref_s = compute_style(path)\n",
1179
+ "\n",
1180
+ " wav = inference(text, ref_s, alpha=0.1, beta=0.5, diffusion_steps=5, embedding_scale=1)\n",
1181
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
1182
+ " print('Speaker: ' + k)\n",
1183
+ " import IPython.display as ipd\n",
1184
+ " print('Synthesized:')\n",
1185
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
1186
+ " print('Reference:')\n",
1187
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
1188
+ ],
1189
+ "metadata": {
1190
+ "id": "yIga6MTuNJaN"
1191
+ },
1192
+ "execution_count": null,
1193
+ "outputs": []
1194
+ }
1195
+ ],
1196
+ "metadata": {
1197
+ "accelerator": "GPU",
1198
+ "colab": {
1199
+ "provenance": [],
1200
+ "collapsed_sections": [
1201
+ "aAGQPfgYIR23",
1202
+ "eJdB_nCOIVIN",
1203
+ "R985j5QONY8I"
1204
+ ],
1205
+ "authorship_tag": "ABX9TyPQdFTqqVEknEG/ma/HMfU+",
1206
+ "include_colab_link": true
1207
+ },
1208
+ "kernelspec": {
1209
+ "display_name": "Python 3",
1210
+ "name": "python3"
1211
+ },
1212
+ "language_info": {
1213
+ "name": "python"
1214
+ }
1215
+ },
1216
+ "nbformat": 4,
1217
+ "nbformat_minor": 0
1218
+ }
Colab/StyleTTS2_Finetune_Demo.ipynb ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4",
8
+ "authorship_tag": "ABX9TyNiDU9ykIeYxO86Lmuid+ph",
9
+ "include_colab_link": true
10
+ },
11
+ "kernelspec": {
12
+ "name": "python3",
13
+ "display_name": "Python 3"
14
+ },
15
+ "language_info": {
16
+ "name": "python"
17
+ },
18
+ "accelerator": "GPU"
19
+ },
20
+ "cells": [
21
+ {
22
+ "cell_type": "markdown",
23
+ "metadata": {
24
+ "id": "view-in-github",
25
+ "colab_type": "text"
26
+ },
27
+ "source": [
28
+ "<a href=\"https://colab.research.google.com/github/yl4579/StyleTTS2/blob/main/Colab/StyleTTS2_Finetune_Demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "markdown",
33
+ "source": [
34
+ "### Install packages and download models"
35
+ ],
36
+ "metadata": {
37
+ "id": "yLqBa4uYPrqE"
38
+ }
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "source": [
43
+ "%%shell\n",
44
+ "git clone https://github.com/yl4579/StyleTTS2.git\n",
45
+ "cd StyleTTS2\n",
46
+ "pip install SoundFile torchaudio munch torch pydub pyyaml librosa nltk matplotlib accelerate transformers phonemizer einops einops-exts tqdm typing-extensions git+https://github.com/resemble-ai/monotonic_align.git\n",
47
+ "sudo apt-get install espeak-ng\n",
48
+ "git-lfs clone https://huggingface.co/yl4579/StyleTTS2-LibriTTS\n",
49
+ "mv StyleTTS2-LibriTTS/Models ."
50
+ ],
51
+ "metadata": {
52
+ "id": "H72WF06ZPrTF"
53
+ },
54
+ "execution_count": null,
55
+ "outputs": []
56
+ },
57
+ {
58
+ "cell_type": "markdown",
59
+ "source": [
60
+ "### Download dataset (LJSpeech, 200 samples, ~15 minutes of data)\n",
61
+ "\n",
62
+ "You can definitely do it with fewer samples. This is just a proof of concept with 200 smaples."
63
+ ],
64
+ "metadata": {
65
+ "id": "G398sL8wPzTB"
66
+ }
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "source": [
71
+ "%cd StyleTTS2\n",
72
+ "!rm -rf Data"
73
+ ],
74
+ "metadata": {
75
+ "id": "kJuQUBrEPy5C"
76
+ },
77
+ "execution_count": null,
78
+ "outputs": []
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "source": [
83
+ "!gdown --id 1vqz26D3yn7OXS2vbfYxfSnpLS6m6tOFP\n",
84
+ "!unzip Data.zip"
85
+ ],
86
+ "metadata": {
87
+ "id": "mDXW8ZZePuSb"
88
+ },
89
+ "execution_count": null,
90
+ "outputs": []
91
+ },
92
+ {
93
+ "cell_type": "markdown",
94
+ "source": [
95
+ "### Change the finetuning config\n",
96
+ "\n",
97
+ "Depending on the GPU you got, you may want to change the bacth size, max audio length, epiochs and so on."
98
+ ],
99
+ "metadata": {
100
+ "id": "_AlBQREWU8ud"
101
+ }
102
+ },
103
+ {
104
+ "cell_type": "code",
105
+ "source": [
106
+ "config_path = \"Configs/config_ft.yml\"\n",
107
+ "\n",
108
+ "import yaml\n",
109
+ "config = yaml.safe_load(open(config_path))"
110
+ ],
111
+ "metadata": {
112
+ "id": "7uEITi0hU4I2"
113
+ },
114
+ "execution_count": null,
115
+ "outputs": []
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "source": [
120
+ "config['data_params']['root_path'] = \"Data/wavs\"\n",
121
+ "\n",
122
+ "config['batch_size'] = 2 # not enough RAM\n",
123
+ "config['max_len'] = 100 # not enough RAM\n",
124
+ "config['loss_params']['joint_epoch'] = 110 # we do not do SLM adversarial training due to not enough RAM\n",
125
+ "\n",
126
+ "with open(config_path, 'w') as outfile:\n",
127
+ " yaml.dump(config, outfile, default_flow_style=True)"
128
+ ],
129
+ "metadata": {
130
+ "id": "TPTRgOKSVT4K"
131
+ },
132
+ "execution_count": null,
133
+ "outputs": []
134
+ },
135
+ {
136
+ "cell_type": "markdown",
137
+ "source": [
138
+ "### Start finetuning\n"
139
+ ],
140
+ "metadata": {
141
+ "id": "uUuB_19NWj2Y"
142
+ }
143
+ },
144
+ {
145
+ "cell_type": "code",
146
+ "source": [
147
+ "!python train_finetune.py --config_path ./Configs/config_ft.yml"
148
+ ],
149
+ "metadata": {
150
+ "id": "HZVAD5GKWm-O"
151
+ },
152
+ "execution_count": null,
153
+ "outputs": []
154
+ },
155
+ {
156
+ "cell_type": "markdown",
157
+ "source": [
158
+ "### Test the model quality\n",
159
+ "\n",
160
+ "Note that this mainly serves as a proof of concept due to RAM limitation of free Colab instances. A lot of settings are suboptimal. In the future when DDP works for train_second.py, we will also add mixed precision finetuning to save time and RAM. You can also add SLM adversarial training run if you have paid Colab services (such as A100 with 40G of RAM)."
161
+ ],
162
+ "metadata": {
163
+ "id": "I0_7wsGkXGfc"
164
+ }
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "source": [
169
+ "import nltk\n",
170
+ "nltk.download('punkt')"
171
+ ],
172
+ "metadata": {
173
+ "id": "OPLphjbncE7p"
174
+ },
175
+ "execution_count": null,
176
+ "outputs": []
177
+ },
178
+ {
179
+ "cell_type": "code",
180
+ "source": [
181
+ "import torch\n",
182
+ "torch.manual_seed(0)\n",
183
+ "torch.backends.cudnn.benchmark = False\n",
184
+ "torch.backends.cudnn.deterministic = True\n",
185
+ "\n",
186
+ "import random\n",
187
+ "random.seed(0)\n",
188
+ "\n",
189
+ "import numpy as np\n",
190
+ "np.random.seed(0)\n",
191
+ "\n",
192
+ "# load packages\n",
193
+ "import time\n",
194
+ "import random\n",
195
+ "import yaml\n",
196
+ "from munch import Munch\n",
197
+ "import numpy as np\n",
198
+ "import torch\n",
199
+ "from torch import nn\n",
200
+ "import torch.nn.functional as F\n",
201
+ "import torchaudio\n",
202
+ "import librosa\n",
203
+ "from nltk.tokenize import word_tokenize\n",
204
+ "\n",
205
+ "from models import *\n",
206
+ "from utils import *\n",
207
+ "from text_utils import TextCleaner\n",
208
+ "textclenaer = TextCleaner()\n",
209
+ "\n",
210
+ "%matplotlib inline\n",
211
+ "\n",
212
+ "to_mel = torchaudio.transforms.MelSpectrogram(\n",
213
+ " n_mels=80, n_fft=2048, win_length=1200, hop_length=300)\n",
214
+ "mean, std = -4, 4\n",
215
+ "\n",
216
+ "def length_to_mask(lengths):\n",
217
+ " mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)\n",
218
+ " mask = torch.gt(mask+1, lengths.unsqueeze(1))\n",
219
+ " return mask\n",
220
+ "\n",
221
+ "def preprocess(wave):\n",
222
+ " wave_tensor = torch.from_numpy(wave).float()\n",
223
+ " mel_tensor = to_mel(wave_tensor)\n",
224
+ " mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std\n",
225
+ " return mel_tensor\n",
226
+ "\n",
227
+ "def compute_style(path):\n",
228
+ " wave, sr = librosa.load(path, sr=24000)\n",
229
+ " audio, index = librosa.effects.trim(wave, top_db=30)\n",
230
+ " if sr != 24000:\n",
231
+ " audio = librosa.resample(audio, sr, 24000)\n",
232
+ " mel_tensor = preprocess(audio).to(device)\n",
233
+ "\n",
234
+ " with torch.no_grad():\n",
235
+ " ref_s = model.style_encoder(mel_tensor.unsqueeze(1))\n",
236
+ " ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))\n",
237
+ "\n",
238
+ " return torch.cat([ref_s, ref_p], dim=1)\n",
239
+ "\n",
240
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
241
+ "\n",
242
+ "# load phonemizer\n",
243
+ "import phonemizer\n",
244
+ "global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)\n",
245
+ "\n",
246
+ "config = yaml.safe_load(open(\"Models/LJSpeech/config_ft.yml\"))\n",
247
+ "\n",
248
+ "# load pretrained ASR model\n",
249
+ "ASR_config = config.get('ASR_config', False)\n",
250
+ "ASR_path = config.get('ASR_path', False)\n",
251
+ "text_aligner = load_ASR_models(ASR_path, ASR_config)\n",
252
+ "\n",
253
+ "# load pretrained F0 model\n",
254
+ "F0_path = config.get('F0_path', False)\n",
255
+ "pitch_extractor = load_F0_models(F0_path)\n",
256
+ "\n",
257
+ "# load BERT model\n",
258
+ "from Utils.PLBERT.util import load_plbert\n",
259
+ "BERT_path = config.get('PLBERT_dir', False)\n",
260
+ "plbert = load_plbert(BERT_path)\n",
261
+ "\n",
262
+ "model_params = recursive_munch(config['model_params'])\n",
263
+ "model = build_model(model_params, text_aligner, pitch_extractor, plbert)\n",
264
+ "_ = [model[key].eval() for key in model]\n",
265
+ "_ = [model[key].to(device) for key in model]"
266
+ ],
267
+ "metadata": {
268
+ "id": "jIIAoDACXJL0"
269
+ },
270
+ "execution_count": null,
271
+ "outputs": []
272
+ },
273
+ {
274
+ "cell_type": "code",
275
+ "source": [
276
+ "files = [f for f in os.listdir(\"Models/LJSpeech/\") if f.endswith('.pth')]\n",
277
+ "sorted_files = sorted(files, key=lambda x: int(x.split('_')[-1].split('.')[0]))"
278
+ ],
279
+ "metadata": {
280
+ "id": "eKXRAyyzcMpQ"
281
+ },
282
+ "execution_count": null,
283
+ "outputs": []
284
+ },
285
+ {
286
+ "cell_type": "code",
287
+ "source": [
288
+ "params_whole = torch.load(\"Models/LJSpeech/\" + sorted_files[-1], map_location='cpu')\n",
289
+ "params = params_whole['net']"
290
+ ],
291
+ "metadata": {
292
+ "id": "ULuU9-VDb9Pk"
293
+ },
294
+ "execution_count": null,
295
+ "outputs": []
296
+ },
297
+ {
298
+ "cell_type": "code",
299
+ "source": [
300
+ "for key in model:\n",
301
+ " if key in params:\n",
302
+ " print('%s loaded' % key)\n",
303
+ " try:\n",
304
+ " model[key].load_state_dict(params[key])\n",
305
+ " except:\n",
306
+ " from collections import OrderedDict\n",
307
+ " state_dict = params[key]\n",
308
+ " new_state_dict = OrderedDict()\n",
309
+ " for k, v in state_dict.items():\n",
310
+ " name = k[7:] # remove `module.`\n",
311
+ " new_state_dict[name] = v\n",
312
+ " # load params\n",
313
+ " model[key].load_state_dict(new_state_dict, strict=False)\n",
314
+ "# except:\n",
315
+ "# _load(params[key], model[key])\n",
316
+ "_ = [model[key].eval() for key in model]"
317
+ ],
318
+ "metadata": {
319
+ "id": "J-U29yIYc2ea"
320
+ },
321
+ "execution_count": null,
322
+ "outputs": []
323
+ },
324
+ {
325
+ "cell_type": "code",
326
+ "source": [
327
+ "from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule"
328
+ ],
329
+ "metadata": {
330
+ "id": "jrPQ_Yrwc3n6"
331
+ },
332
+ "execution_count": null,
333
+ "outputs": []
334
+ },
335
+ {
336
+ "cell_type": "code",
337
+ "source": [
338
+ "sampler = DiffusionSampler(\n",
339
+ " model.diffusion.diffusion,\n",
340
+ " sampler=ADPM2Sampler(),\n",
341
+ " sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters\n",
342
+ " clamp=False\n",
343
+ ")"
344
+ ],
345
+ "metadata": {
346
+ "id": "n2CWYNoqc455"
347
+ },
348
+ "execution_count": null,
349
+ "outputs": []
350
+ },
351
+ {
352
+ "cell_type": "code",
353
+ "source": [
354
+ "def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):\n",
355
+ " text = text.strip()\n",
356
+ " ps = global_phonemizer.phonemize([text])\n",
357
+ " ps = word_tokenize(ps[0])\n",
358
+ " ps = ' '.join(ps)\n",
359
+ " tokens = textclenaer(ps)\n",
360
+ " tokens.insert(0, 0)\n",
361
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
362
+ "\n",
363
+ " with torch.no_grad():\n",
364
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
365
+ " text_mask = length_to_mask(input_lengths).to(device)\n",
366
+ "\n",
367
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
368
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
369
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2)\n",
370
+ "\n",
371
+ " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),\n",
372
+ " embedding=bert_dur,\n",
373
+ " embedding_scale=embedding_scale,\n",
374
+ " features=ref_s, # reference from the same speaker as the embedding\n",
375
+ " num_steps=diffusion_steps).squeeze(1)\n",
376
+ "\n",
377
+ "\n",
378
+ " s = s_pred[:, 128:]\n",
379
+ " ref = s_pred[:, :128]\n",
380
+ "\n",
381
+ " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
382
+ " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
383
+ "\n",
384
+ " d = model.predictor.text_encoder(d_en,\n",
385
+ " s, input_lengths, text_mask)\n",
386
+ "\n",
387
+ " x, _ = model.predictor.lstm(d)\n",
388
+ " duration = model.predictor.duration_proj(x)\n",
389
+ "\n",
390
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
391
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
392
+ "\n",
393
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
394
+ " c_frame = 0\n",
395
+ " for i in range(pred_aln_trg.size(0)):\n",
396
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
397
+ " c_frame += int(pred_dur[i].data)\n",
398
+ "\n",
399
+ " # encode prosody\n",
400
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
401
+ " if model_params.decoder.type == \"hifigan\":\n",
402
+ " asr_new = torch.zeros_like(en)\n",
403
+ " asr_new[:, :, 0] = en[:, :, 0]\n",
404
+ " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
405
+ " en = asr_new\n",
406
+ "\n",
407
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
408
+ "\n",
409
+ " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
410
+ " if model_params.decoder.type == \"hifigan\":\n",
411
+ " asr_new = torch.zeros_like(asr)\n",
412
+ " asr_new[:, :, 0] = asr[:, :, 0]\n",
413
+ " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
414
+ " asr = asr_new\n",
415
+ "\n",
416
+ " out = model.decoder(asr,\n",
417
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
418
+ "\n",
419
+ "\n",
420
+ " return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later"
421
+ ],
422
+ "metadata": {
423
+ "id": "2x5kVb3nc_eY"
424
+ },
425
+ "execution_count": null,
426
+ "outputs": []
427
+ },
428
+ {
429
+ "cell_type": "markdown",
430
+ "source": [
431
+ "### Synthesize speech"
432
+ ],
433
+ "metadata": {
434
+ "id": "O159JnwCc6CC"
435
+ }
436
+ },
437
+ {
438
+ "cell_type": "code",
439
+ "source": [
440
+ "text = '''Maltby and Company would issue warrants on them deliverable to the importer, and the goods were then passed to be stored in neighboring warehouses.\n",
441
+ "'''"
442
+ ],
443
+ "metadata": {
444
+ "id": "ThciXQ6rc9Eq"
445
+ },
446
+ "execution_count": null,
447
+ "outputs": []
448
+ },
449
+ {
450
+ "cell_type": "code",
451
+ "source": [
452
+ "# get a random reference in the training set, note that it doesn't matter which one you use\n",
453
+ "path = \"Data/wavs/LJ001-0110.wav\"\n",
454
+ "# this style vector ref_s can be saved as a parameter together with the model weights\n",
455
+ "ref_s = compute_style(path)"
456
+ ],
457
+ "metadata": {
458
+ "id": "jldPkJyCc83a"
459
+ },
460
+ "execution_count": null,
461
+ "outputs": []
462
+ },
463
+ {
464
+ "cell_type": "code",
465
+ "source": [
466
+ "start = time.time()\n",
467
+ "wav = inference(text, ref_s, alpha=0.9, beta=0.9, diffusion_steps=10, embedding_scale=1)\n",
468
+ "rtf = (time.time() - start) / (len(wav) / 24000)\n",
469
+ "print(f\"RTF = {rtf:5f}\")\n",
470
+ "import IPython.display as ipd\n",
471
+ "display(ipd.Audio(wav, rate=24000, normalize=False))"
472
+ ],
473
+ "metadata": {
474
+ "id": "_mIU0jqDdQ-c"
475
+ },
476
+ "execution_count": null,
477
+ "outputs": []
478
+ }
479
+ ]
480
+ }
Configs/config.yml ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ log_dir: "Models/LJSpeech"
2
+ first_stage_path: "first_stage.pth"
3
+ save_freq: 2
4
+ log_interval: 10
5
+ device: "cuda"
6
+ epochs_1st: 200 # number of epochs for first stage training (pre-training)
7
+ epochs_2nd: 100 # number of peochs for second stage training (joint training)
8
+ batch_size: 16
9
+ max_len: 400 # maximum number of frames
10
+ pretrained_model: ""
11
+ second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
12
+ load_only_params: false # set to true if do not want to load epoch numbers and optimizer parameters
13
+
14
+ F0_path: "Utils/JDC/bst.t7"
15
+ ASR_config: "Utils/ASR/config.yml"
16
+ ASR_path: "Utils/ASR/epoch_00080.pth"
17
+ PLBERT_dir: 'Utils/PLBERT/'
18
+
19
+ data_params:
20
+ train_data: "Data/train_list.txt"
21
+ val_data: "Data/val_list.txt"
22
+ root_path: "/local/LJSpeech-1.1/wavs"
23
+ OOD_data: "Data/OOD_texts.txt"
24
+ min_length: 50 # sample until texts with this size are obtained for OOD texts
25
+
26
+ preprocess_params:
27
+ sr: 24000
28
+ spect_params:
29
+ n_fft: 2048
30
+ win_length: 1200
31
+ hop_length: 300
32
+
33
+ model_params:
34
+ multispeaker: false
35
+
36
+ dim_in: 64
37
+ hidden_dim: 512
38
+ max_conv_dim: 512
39
+ n_layer: 3
40
+ n_mels: 80
41
+
42
+ n_token: 178 # number of phoneme tokens
43
+ max_dur: 50 # maximum duration of a single phoneme
44
+ style_dim: 128 # style vector size
45
+
46
+ dropout: 0.2
47
+
48
+ # config for decoder
49
+ decoder:
50
+ type: 'istftnet' # either hifigan or istftnet
51
+ resblock_kernel_sizes: [3,7,11]
52
+ upsample_rates : [10, 6]
53
+ upsample_initial_channel: 512
54
+ resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
55
+ upsample_kernel_sizes: [20, 12]
56
+ gen_istft_n_fft: 20
57
+ gen_istft_hop_size: 5
58
+
59
+ # speech language model config
60
+ slm:
61
+ model: 'microsoft/wavlm-base-plus'
62
+ sr: 16000 # sampling rate of SLM
63
+ hidden: 768 # hidden size of SLM
64
+ nlayers: 13 # number of layers of SLM
65
+ initial_channel: 64 # initial channels of SLM discriminator head
66
+
67
+ # style diffusion model config
68
+ diffusion:
69
+ embedding_mask_proba: 0.1
70
+ # transformer config
71
+ transformer:
72
+ num_layers: 3
73
+ num_heads: 8
74
+ head_features: 64
75
+ multiplier: 2
76
+
77
+ # diffusion distribution config
78
+ dist:
79
+ sigma_data: 0.2 # placeholder for estimate_sigma_data set to false
80
+ estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
81
+ mean: -3.0
82
+ std: 1.0
83
+
84
+ loss_params:
85
+ lambda_mel: 5. # mel reconstruction loss
86
+ lambda_gen: 1. # generator loss
87
+ lambda_slm: 1. # slm feature matching loss
88
+
89
+ lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
90
+ lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
91
+ TMA_epoch: 50 # TMA starting epoch (1st stage)
92
+
93
+ lambda_F0: 1. # F0 reconstruction loss (2nd stage)
94
+ lambda_norm: 1. # norm reconstruction loss (2nd stage)
95
+ lambda_dur: 1. # duration loss (2nd stage)
96
+ lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
97
+ lambda_sty: 1. # style reconstruction loss (2nd stage)
98
+ lambda_diff: 1. # score matching loss (2nd stage)
99
+
100
+ diff_epoch: 20 # style diffusion starting epoch (2nd stage)
101
+ joint_epoch: 50 # joint training starting epoch (2nd stage)
102
+
103
+ optimizer_params:
104
+ lr: 0.0001 # general learning rate
105
+ bert_lr: 0.00001 # learning rate for PLBERT
106
+ ft_lr: 0.00001 # learning rate for acoustic modules
107
+
108
+ slmadv_params:
109
+ min_len: 400 # minimum length of samples
110
+ max_len: 500 # maximum length of samples
111
+ batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
112
+ iter: 10 # update the discriminator every this iterations of generator update
113
+ thresh: 5 # gradient norm above which the gradient is scaled
114
+ scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
115
+ sig: 1.5 # sigma for differentiable duration modeling
116
+
Configs/config_ft.yml ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ log_dir: "Models/LJSpeech"
2
+ save_freq: 5
3
+ log_interval: 10
4
+ device: "cuda"
5
+ epochs: 50 # number of finetuning epoch (1 hour of data)
6
+ batch_size: 8
7
+ max_len: 400 # maximum number of frames
8
+ pretrained_model: "Models/LibriTTS/epochs_2nd_00020.pth"
9
+ second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
10
+ load_only_params: true # set to true if do not want to load epoch numbers and optimizer parameters
11
+
12
+ F0_path: "Utils/JDC/bst.t7"
13
+ ASR_config: "Utils/ASR/config.yml"
14
+ ASR_path: "Utils/ASR/epoch_00080.pth"
15
+ PLBERT_dir: 'Utils/PLBERT/'
16
+
17
+ data_params:
18
+ train_data: "Data/train_list.txt"
19
+ val_data: "Data/val_list.txt"
20
+ root_path: "/local/LJSpeech-1.1/wavs"
21
+ OOD_data: "Data/OOD_texts.txt"
22
+ min_length: 50 # sample until texts with this size are obtained for OOD texts
23
+
24
+ preprocess_params:
25
+ sr: 24000
26
+ spect_params:
27
+ n_fft: 2048
28
+ win_length: 1200
29
+ hop_length: 300
30
+
31
+ model_params:
32
+ multispeaker: true
33
+
34
+ dim_in: 64
35
+ hidden_dim: 512
36
+ max_conv_dim: 512
37
+ n_layer: 3
38
+ n_mels: 80
39
+
40
+ n_token: 178 # number of phoneme tokens
41
+ max_dur: 50 # maximum duration of a single phoneme
42
+ style_dim: 128 # style vector size
43
+
44
+ dropout: 0.2
45
+
46
+ # config for decoder
47
+ decoder:
48
+ type: 'hifigan' # either hifigan or istftnet
49
+ resblock_kernel_sizes: [3,7,11]
50
+ upsample_rates : [10,5,3,2]
51
+ upsample_initial_channel: 512
52
+ resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
53
+ upsample_kernel_sizes: [20,10,6,4]
54
+
55
+ # speech language model config
56
+ slm:
57
+ model: 'microsoft/wavlm-base-plus'
58
+ sr: 16000 # sampling rate of SLM
59
+ hidden: 768 # hidden size of SLM
60
+ nlayers: 13 # number of layers of SLM
61
+ initial_channel: 64 # initial channels of SLM discriminator head
62
+
63
+ # style diffusion model config
64
+ diffusion:
65
+ embedding_mask_proba: 0.1
66
+ # transformer config
67
+ transformer:
68
+ num_layers: 3
69
+ num_heads: 8
70
+ head_features: 64
71
+ multiplier: 2
72
+
73
+ # diffusion distribution config
74
+ dist:
75
+ sigma_data: 0.2 # placeholder for estimate_sigma_data set to false
76
+ estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
77
+ mean: -3.0
78
+ std: 1.0
79
+
80
+ loss_params:
81
+ lambda_mel: 5. # mel reconstruction loss
82
+ lambda_gen: 1. # generator loss
83
+ lambda_slm: 1. # slm feature matching loss
84
+
85
+ lambda_mono: 1. # monotonic alignment loss (TMA)
86
+ lambda_s2s: 1. # sequence-to-sequence loss (TMA)
87
+
88
+ lambda_F0: 1. # F0 reconstruction loss
89
+ lambda_norm: 1. # norm reconstruction loss
90
+ lambda_dur: 1. # duration loss
91
+ lambda_ce: 20. # duration predictor probability output CE loss
92
+ lambda_sty: 1. # style reconstruction loss
93
+ lambda_diff: 1. # score matching loss
94
+
95
+ diff_epoch: 10 # style diffusion starting epoch
96
+ joint_epoch: 30 # joint training starting epoch
97
+
98
+ optimizer_params:
99
+ lr: 0.0001 # general learning rate
100
+ bert_lr: 0.00001 # learning rate for PLBERT
101
+ ft_lr: 0.0001 # learning rate for acoustic modules
102
+
103
+ slmadv_params:
104
+ min_len: 400 # minimum length of samples
105
+ max_len: 500 # maximum length of samples
106
+ batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
107
+ iter: 10 # update the discriminator every this iterations of generator update
108
+ thresh: 5 # gradient norm above which the gradient is scaled
109
+ scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
110
+ sig: 1.5 # sigma for differentiable duration modeling
111
+
Configs/config_ft_single.yml ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ─── GLOBAL ──────────────────────────────────────────────────────────
2
+ log_dir: logs/pod_90h_30k
3
+ device: "cuda"
4
+
5
+ batch_size: 8 # 40 GB A100, fp16
6
+ max_len: 300 # ≈ 8 s (200 × 40 ms)
7
+
8
+ epochs_1st: 25 # first-stage schedule
9
+ epochs_2nd: 15 # second-stage schedule (later)
10
+ save_freq: 2
11
+ log_interval: 50
12
+
13
+ # leave blank on first run
14
+ pretrained_model: /home/ubuntu/styletts2-ft/logs/pod_90h_30k/epoch_1st_0012.pth
15
+ second_stage_load_pretrained: false
16
+ load_only_params: false
17
+
18
+ # ─── PRE-PROCESS ─────────────────────────────────────────────────────
19
+ preprocess_params:
20
+ sr: 24000
21
+ spect_params: # required by Mel extractor
22
+ n_fft: 2048
23
+ win_length: 1200
24
+ hop_length: 300
25
+
26
+ # ─── DATA ────────────────────────────────────────────────────────────
27
+ data_params:
28
+ root_path: /home/ubuntu/styletts2-ft/data/wavs
29
+ train_data: /home/ubuntu/styletts2-ft/data/train_list.txt
30
+ val_data: /home/ubuntu/styletts2-ft/data/val_list.txt
31
+ min_length: 50 # sample until texts with this size are obtained for OOD texts
32
+ OOD_data: /home/ubuntu/styletts2-ft/data/OOD_texts.txt
33
+
34
+ # ─── LOSS SCHEDULE ──────────────────────────────────────────────────
35
+ loss_params:
36
+ lambda_mel: 5. # mel reconstruction loss
37
+ lambda_gen: 1. # generator loss
38
+ lambda_slm: 1. # slm feature matching loss
39
+
40
+ lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
41
+ lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
42
+ TMA_epoch: 14 # TMA starting epoch (1st stage)
43
+
44
+ lambda_F0: 1. # F0 reconstruction loss (2nd stage)
45
+ lambda_norm: 1. # norm reconstruction loss (2nd stage)
46
+ lambda_dur: 1. # duration loss (2nd stage)
47
+ lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
48
+ lambda_sty: 1. # style reconstruction loss (2nd stage)
49
+ lambda_diff: 1. # score matching loss (2nd stage)
50
+
51
+ diff_epoch: 20 # style diffusion starting epoch (2nd stage)
52
+ joint_epoch: 50 # joint training starting epoch (2nd stage)
53
+
54
+ # ─── OPTIMISER ──────────────────────────────────────────────────────
55
+ optimizer_params:
56
+ lr: 0.0008
57
+ bert_lr: 0.00002
58
+ ft_lr: 0.0002
59
+ grad_accum_steps: 2
60
+
61
+ # ─── MODEL (core network & sub-modules) ─────────────────────────────
62
+ model_params:
63
+ multispeaker: true # speaker-ID column present
64
+ dim_in: 64
65
+ hidden_dim: 512
66
+ max_conv_dim: 512
67
+ n_layer: 3
68
+ n_mels: 80
69
+
70
+ n_token: 178 # 178 phonemes
71
+ max_dur: 50
72
+ style_dim: 128
73
+ dropout: 0.2
74
+
75
+ decoder:
76
+ type: hifigan
77
+ resblock_kernel_sizes: [3, 7, 11]
78
+ upsample_rates: [10, 5, 3, 2]
79
+ upsample_initial_channel: 512
80
+ resblock_dilation_sizes: [[1,3,5],[1,3,5],[1,3,5]]
81
+ upsample_kernel_sizes: [20, 10, 6, 4]
82
+
83
+ slm:
84
+ model: microsoft/wavlm-base-plus
85
+ sr: 16000
86
+ hidden: 768
87
+ nlayers: 13
88
+ initial_channel: 64
89
+
90
+ diffusion:
91
+ embedding_mask_proba: 0.1
92
+ transformer:
93
+ num_layers: 3
94
+ num_heads: 8
95
+ head_features: 64
96
+ multiplier: 2
97
+ dist:
98
+ sigma_data: 0.2 # ← placeholder; code will overwrite if
99
+ estimate_sigma_data: true
100
+ mean: -3.0
101
+ std: 1.0
102
+
103
+ # ─── EXTERNAL CHECKPOINTS ───────────────────────────────────────────
104
+ F0_path: "Utils/JDC/bst.t7"
105
+ ASR_config: "Utils/ASR/config.yml"
106
+ ASR_path: "Utils/ASR/epoch_00080.pth"
107
+ PLBERT_dir: 'Utils/PLBERT/'
108
+ first_stage_path: "" # filled automatically after this run
109
+
110
+ # ─── SLM ADVERSARIAL (ignored in stage-1, kept default) ─────────────
111
+ slmadv_params:
112
+ min_len: 400
113
+ max_len: 500
114
+ batch_percentage: 0.5
115
+ iter: 20
116
+ thresh: 5
117
+ scale: 0.01
118
+ sig: 1.5
Configs/config_libritts.yml ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ log_dir: "Models/LibriTTS"
2
+ first_stage_path: "first_stage.pth"
3
+ save_freq: 1
4
+ log_interval: 10
5
+ device: "cuda"
6
+ epochs_1st: 50 # number of epochs for first stage training (pre-training)
7
+ epochs_2nd: 30 # number of peochs for second stage training (joint training)
8
+ batch_size: 16
9
+ max_len: 300 # maximum number of frames
10
+ pretrained_model: ""
11
+ second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
12
+ load_only_params: false # set to true if do not want to load epoch numbers and optimizer parameters
13
+
14
+ F0_path: "Utils/JDC/bst.t7"
15
+ ASR_config: "Utils/ASR/config.yml"
16
+ ASR_path: "Utils/ASR/epoch_00080.pth"
17
+ PLBERT_dir: 'Utils/PLBERT/'
18
+
19
+ data_params:
20
+ train_data: "Data/train_list.txt"
21
+ val_data: "Data/val_list.txt"
22
+ root_path: ""
23
+ OOD_data: "Data/OOD_texts.txt"
24
+ min_length: 50 # sample until texts with this size are obtained for OOD texts
25
+
26
+ preprocess_params:
27
+ sr: 24000
28
+ spect_params:
29
+ n_fft: 2048
30
+ win_length: 1200
31
+ hop_length: 300
32
+
33
+ model_params:
34
+ multispeaker: true
35
+
36
+ dim_in: 64
37
+ hidden_dim: 512
38
+ max_conv_dim: 512
39
+ n_layer: 3
40
+ n_mels: 80
41
+
42
+ n_token: 178 # number of phoneme tokens
43
+ max_dur: 50 # maximum duration of a single phoneme
44
+ style_dim: 128 # style vector size
45
+
46
+ dropout: 0.2
47
+
48
+ # config for decoder
49
+ decoder:
50
+ type: 'hifigan' # either hifigan or istftnet
51
+ resblock_kernel_sizes: [3,7,11]
52
+ upsample_rates : [10,5,3,2]
53
+ upsample_initial_channel: 512
54
+ resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
55
+ upsample_kernel_sizes: [20,10,6,4]
56
+
57
+ # speech language model config
58
+ slm:
59
+ model: 'microsoft/wavlm-base-plus'
60
+ sr: 16000 # sampling rate of SLM
61
+ hidden: 768 # hidden size of SLM
62
+ nlayers: 13 # number of layers of SLM
63
+ initial_channel: 64 # initial channels of SLM discriminator head
64
+
65
+ # style diffusion model config
66
+ diffusion:
67
+ embedding_mask_proba: 0.1
68
+ # transformer config
69
+ transformer:
70
+ num_layers: 3
71
+ num_heads: 8
72
+ head_features: 64
73
+ multiplier: 2
74
+
75
+ # diffusion distribution config
76
+ dist:
77
+ sigma_data: 0.2 # placeholder for estimate_sigma_data set to false
78
+ estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
79
+ mean: -3.0
80
+ std: 1.0
81
+
82
+ loss_params:
83
+ lambda_mel: 5. # mel reconstruction loss
84
+ lambda_gen: 1. # generator loss
85
+ lambda_slm: 1. # slm feature matching loss
86
+
87
+ lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
88
+ lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
89
+ TMA_epoch: 5 # TMA starting epoch (1st stage)
90
+
91
+ lambda_F0: 1. # F0 reconstruction loss (2nd stage)
92
+ lambda_norm: 1. # norm reconstruction loss (2nd stage)
93
+ lambda_dur: 1. # duration loss (2nd stage)
94
+ lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
95
+ lambda_sty: 1. # style reconstruction loss (2nd stage)
96
+ lambda_diff: 1. # score matching loss (2nd stage)
97
+
98
+ diff_epoch: 10 # style diffusion starting epoch (2nd stage)
99
+ joint_epoch: 15 # joint training starting epoch (2nd stage)
100
+
101
+ optimizer_params:
102
+ lr: 0.0001 # general learning rate
103
+ bert_lr: 0.00001 # learning rate for PLBERT
104
+ ft_lr: 0.00001 # learning rate for acoustic modules
105
+
106
+ slmadv_params:
107
+ min_len: 400 # minimum length of samples
108
+ max_len: 500 # maximum length of samples
109
+ batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
110
+ iter: 20 # update the discriminator every this iterations of generator update
111
+ thresh: 5 # gradient norm above which the gradient is scaled
112
+ scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
113
+ sig: 1.5 # sigma for differentiable duration modeling
Data/train_list.txt ADDED
The diff for this file is too large to render. See raw diff
 
Data/val_list.txt ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LJ022-0023.wav|ðɪ ˌoʊvɚwˈɛlmɪŋ mədʒˈɔːɹᵻɾi ʌv pˈiːpəl ɪn ðɪs kˈʌntɹi nˈoʊ hˌaʊ tə sˈɪft ðə wˈiːt fɹʌmðə tʃˈæf ɪn wʌt ðeɪ hˈɪɹ ænd wʌt ðeɪ ɹˈiːd .|0
2
+ LJ043-0030.wav|ɪf sˈʌmbɑːdi dˈɪd ðˈæt tə mˌiː , ɐ lˈaʊsi tɹˈɪk lˈaɪk ðˈæt , tə tˈeɪk maɪ wˈaɪf ɐwˈeɪ , ænd ˈɔːl ðə fˈɜːnɪtʃɚ , aɪ wʊd biː mˈæd æz hˈɛl , tˈuː .|0
3
+ LJ005-0201.wav|ˌæzˌɪz ʃˈoʊn baɪ ðə ɹᵻpˈoːɹt ʌvðə kəmˈɪʃənɚz tʊ ɪŋkwˈaɪɚɹ ˌɪntʊ ðə stˈeɪt ʌvðə mjuːnˈɪsɪpəl kˌɔːɹpɚɹˈeɪʃənz ɪn ˈeɪtiːn θˈɜːɾi fˈaɪv .|0
4
+ LJ001-0110.wav|ˈiːvən ðə kˈæslɑːn tˈaɪp wɛn ɛnlˈɑːɹdʒd ʃˈoʊz ɡɹˈeɪt ʃˈɔːɹtkʌmɪŋz ɪn ðɪs ɹᵻspˈɛkt :|0
5
+ LJ003-0345.wav|ˈɔːl ðə kəmˈɪɾi kʊd dˈuː ɪn ðɪs ɹᵻspˈɛkt wʌz tə θɹˈoʊ ðə ɹᵻspˌɑːnsəbˈɪlɪɾi ˌɔn ˈʌðɚz .|0
6
+ LJ007-0154.wav|ðiːz pˈʌndʒənt ænd wˈɛl ɡɹˈaʊndᵻd stɹˈɪktʃɚz ɐplˈaɪd wɪð stˈɪl ɡɹˈeɪɾɚ fˈoːɹs tə ðɪ ʌŋkənvˈɪktᵻd pɹˈɪzənɚ , ðə mˈæn hˌuː kˈeɪm tə ðə pɹˈɪzən ˈɪnəsənt , ænd stˈɪl ʌŋkəntˈæmᵻnˌeɪɾᵻd ,|0
7
+ LJ018-0098.wav|ænd ɹˈɛkəɡnˌaɪzd æz wˈʌn ʌvðə fɹˈiːkwɛntɚz ʌvðə bˈoʊɡəs lˈɔː stˈeɪʃənɚz . hɪz ɚɹˈɛst lˈɛd tə ðæt ʌv ˈʌðɚz .|0
8
+ LJ047-0044.wav|ˈɑːswəld wʌz , haʊˈɛvɚ , wˈɪlɪŋ tə dɪskˈʌs hɪz kˈɑːntækts wɪð sˈoʊviət ɐθˈɔːɹɪɾiz . hiː dᵻnˈaɪd hˌævɪŋ ˌɛni ɪnvˈɑːlvmənt wɪð sˈoʊviət ɪntˈɛlɪdʒəns ˈeɪdʒənsiz|0
9
+ LJ031-0038.wav|ðə fˈɜːst fɪzˈɪʃən tə sˈiː ðə pɹˈɛzɪdənt æt pˈɑːɹklənd hˈɑːspɪɾəl wʌz dˈɑːktɚ . tʃˈɑːɹlz dʒˈeɪ . kˈæɹɪkˌoʊ , ɐ ɹˈɛzᵻdənt ɪn dʒˈɛnɚɹəl sˈɜːdʒɚɹi .|0
10
+ LJ048-0194.wav|dˈʊɹɹɪŋ ðə mˈɔːɹnɪŋ ʌv noʊvˈɛmbɚ twˈɛnti tˈuː pɹˈaɪɚ tə ðə mˈoʊɾɚkˌeɪd .|0
11
+ LJ049-0026.wav|ˌɔn əkˈeɪʒən ðə sˈiːkɹᵻt sˈɜːvɪs hɐzbɪn pɚmˈɪɾᵻd tə hæv ɐn ˈeɪdʒənt ɹˈaɪdɪŋ ɪnðə pˈæsɪndʒɚ kəmpˈɑːɹtmənt wɪððə pɹˈɛzɪdənt .|0
12
+ LJ004-0152.wav|ɔːlðˈoʊ æt mˈɪstɚ . bˈʌkstənz vˈɪzɪt ɐ nˈuː dʒˈeɪl wʌz ɪn pɹˈɑːsɛs ʌv ɪɹˈɛkʃən , ðə fˈɜːst stˈɛp təwˈɔːɹdz ɹᵻfˈɔːɹm sˈɪns hˈaʊɚdz vˌɪzɪtˈeɪʃən ɪn sˈɛvəntˌiːn sˈɛvənti fˈoːɹ .|0
13
+ LJ008-0278.wav|ɔːɹ ðˈɛɹz mˌaɪt biː wˈʌn ʌv mˈɛni , ænd ɪt mˌaɪt biː kənsˈɪdɚd nˈɛsᵻsɚɹi tə dˈɑːlɚ mˌeɪk ɐn ɛɡzˈæmpəl.dˈɑːlɚ|0
14
+ LJ043-0002.wav|ðə wˈɔːɹəŋ kəmˈɪʃən ɹᵻpˈoːɹt . baɪ ðə pɹˈɛzɪdənts kəmˈɪʃən ɔnðɪ ɐsˌæsᵻnˈeɪʃən ʌv pɹˈɛzɪdənt kˈɛnədi . tʃˈæptɚ sˈɛvən . lˈiː hˈɑːɹvi ˈɑːswəld :|0
15
+ LJ009-0114.wav|mˈɪstɚ . wˈeɪkfiːld wˈaɪndz ˈʌp hɪz ɡɹˈæfɪk bˌʌt sˈʌmwʌt sɛnsˈeɪʃənəl ɐkˈaʊnt baɪ dᵻskɹˈaɪbɪŋ ɐnˈʌðɚ ɹᵻlˈɪdʒəs sˈɜːvɪs , wˌɪtʃ mˈeɪ ɐpɹˈoʊpɹɪˌeɪtli biː ɪnsˈɜːɾᵻd hˈɪɹ .|0
16
+ LJ028-0506.wav|ɐ mˈɑːdɚn ˈɑːɹɾɪst wʊdhɐv dˈɪfɪkˌʌlti ɪn dˌuːɪŋ sˈʌtʃ ˈækjʊɹət wˈɜːk .|0
17
+ LJ050-0168.wav|wɪððə pɚtˈɪkjʊlɚ pˈɜːpəsᵻz ʌvðɪ ˈeɪdʒənsi ɪnvˈɑːlvd . ðə kəmˈɪʃən ɹˈɛkəɡnˌaɪzᵻz ðæt ðɪs ɪz ɐ kˌɑːntɹəvˈɜːʃəl ˈɛɹiə|0
18
+ LJ039-0223.wav|ˈɑːswəldz mɚɹˈiːn tɹˈeɪnɪŋ ɪn mˈɑːɹksmənʃˌɪp , hɪz ˈʌðɚ ɹˈaɪfəl ɛkspˈiəɹɪəns ænd hɪz ɪstˈæblɪʃt fəmˌɪliˈæɹɪɾi wɪð ðɪs pɚtˈɪkjʊlɚ wˈɛpən|0
19
+ LJ029-0032.wav|ɐkˈoːɹdɪŋ tʊ oʊdˈɑːnəl , kwˈoʊt , wiː hæd ɐ mˈoʊɾɚkˌeɪd wɛɹˈɛvɚ kplˈʌsplʌs wˌɪtʃ hɐdbɪn bˌɪn hˈeɪstili sˈʌmənd fɚðə ðə pˈɜːpəs wiː wˈɛnt , ˈɛnd kwˈoʊt .|0
20
+ LJ031-0070.wav|dˈɑːktɚ . klˈɑːɹk , hˌuː mˈoʊst klˈoʊsli əbzˈɜːvd ðə hˈɛd wˈuːnd ,|0
21
+ LJ034-0198.wav|jˈuːɪnz , hˌuː wʌz ɔnðə saʊθwˈɛst kˈɔːɹnɚɹ ʌv ˈɛlm ænd hjˈuːstən stɹˈiːts tˈɛstᵻfˌaɪd ðæt hiː kʊd nˌɑːt dᵻskɹˈaɪb ðə mˈæn hiː sˈɔː ɪnðə wˈɪndoʊ .|0
22
+ LJ026-0068.wav|ˈɛnɚdʒi ˈɛntɚz ðə plˈænt , tʊ ɐ smˈɔːl ɛkstˈɛnt ,|0
23
+ LJ039-0075.wav|wˈʌns juː nˈoʊ ðæt juː mˈʌst pˌʊt ðə kɹˈɔshɛɹz ɔnðə tˈɑːɹɡɪt ænd ðæt ɪz ˈɔːl ðæt ɪz nˈɛsᵻsɚɹi .|0
24
+ LJ004-0096.wav|ðə fˈeɪɾəl kˈɑːnsɪkwənsᵻz wˈɛɹɑːf mˌaɪt biː pɹɪvˈɛntᵻd ɪf ðə dʒˈʌstɪsᵻz ʌvðə pˈiːs wɜː djˈuːli ˈɔːθɚɹˌaɪzd|0
25
+ LJ005-0014.wav|spˈiːkɪŋ ˌɔn ɐ dᵻbˈeɪt ˌɔn pɹˈɪzən mˈæɾɚz , hiː dᵻklˈɛɹd ðˈæt|0
26
+ LJ012-0161.wav|hiː wʌz ɹᵻpˈoːɹɾᵻd tə hæv fˈɔːlən ɐwˈeɪ tʊ ɐ ʃˈædoʊ .|0
27
+ LJ018-0239.wav|hɪz dˌɪsɐpˈɪɹəns ɡˈeɪv kˈʌlɚ ænd sˈʌbstəns tʊ ˈiːvəl ɹᵻpˈoːɹts ɔːlɹˌɛdi ɪn sˌɜːkjʊlˈeɪʃən ðætðə wɪl ænd kənvˈeɪəns əbˌʌv ɹᵻfˈɜːd tuː|0
28
+ LJ019-0257.wav|hˈɪɹ ðə tɹˈɛd wˈiːl wʌz ɪn jˈuːs , ðɛɹ sˈɛljʊlɚ kɹˈæŋks , ɔːɹ hˈɑːɹd lˈeɪbɚ məʃˈiːnz .|0
29
+ LJ028-0008.wav|juː tˈæp dʒˈɛntli wɪð jʊɹ hˈiːl əpˌɑːn ðə ʃˈoʊldɚɹ ʌvðə dɹˈoʊmdɚɹi tʊ ˈɜːdʒ hɜːɹ ˈɔn .|0
30
+ LJ024-0083.wav|ðɪs plˈæn ʌv mˈaɪn ɪz nˈoʊ ɐtˈæk ɔnðə kˈoːɹt ;|0
31
+ LJ042-0129.wav|nˈoʊ nˈaɪt klˈʌbz ɔːɹ bˈoʊlɪŋ ˈælɪz , nˈoʊ plˈeɪsᵻz ʌv ɹˌɛkɹiːˈeɪʃən ɛksˈɛpt ðə tɹˈeɪd jˈuːniən dˈænsᵻz . aɪ hæv hæd ɪnˈʌf .|0
32
+ LJ036-0103.wav|ðə pəlˈiːs ˈæskt hˌɪm wˈɛðɚ hiː kʊd pˈɪk ˈaʊt hɪz pˈæsɪndʒɚ fɹʌmðə lˈaɪnʌp .|0
33
+ LJ046-0058.wav|dˈʊɹɹɪŋ hɪz pɹˈɛzɪdənsi , fɹˈæŋklɪn dˈiː . ɹˈoʊzəvˌɛlt mˌeɪd ˈɔːlmoʊst fˈoːɹ hˈʌndɹɪd dʒˈɜːniz ænd tɹˈævəld mˈoːɹ ðɐn θɹˈiː hˈʌndɹɪd fˈɪfti θˈaʊzənd mˈaɪlz .|0
34
+ LJ014-0076.wav|hiː wʌz sˈiːn ˈæftɚwɚdz smˈoʊkɪŋ ænd tˈɔːkɪŋ wɪð hɪz hˈoʊsts ɪn ðɛɹ bˈæk pˈɑːɹlɚ , ænd nˈɛvɚ sˈiːn ɐɡˈɛn ɐlˈaɪv .|0
35
+ LJ002-0043.wav|lˈɔŋ nˈæɹoʊ ɹˈuːmz wˈʌn θˈɜːɾi sˈɪks fˈiːt , sˈɪks twˈɛnti θɹˈiː fˈiːt , ænd ðɪ ˈeɪtθ ˈeɪtiːn ,|0
36
+ LJ009-0076.wav|wiː kˈʌm tə ðə sˈɜːmən .|0
37
+ LJ017-0131.wav|ˈiːvən wɛn ðə hˈaɪ ʃˈɛɹɪf hæd tˈoʊld hˌɪm ðɛɹwˌʌz nˈoʊ pˌɑːsəbˈɪlɪɾi əvɚ ɹᵻpɹˈiːv , ænd wɪðˌɪn ɐ fjˈuː ˈaʊɚz ʌv ˌɛksɪkjˈuːʃən .|0
38
+ LJ046-0184.wav|bˌʌt ðɛɹ ɪz ɐ sˈɪstəm fɚðɪ ɪmˈiːdɪət nˌoʊɾɪfɪkˈeɪʃən ʌvðə sˈiːkɹᵻt sˈɜːvɪs baɪ ðə kənfˈaɪnɪŋ ˌɪnstɪtˈuːʃən wɛn ɐ sˈʌbdʒɛkt ɪz ɹᵻlˈiːst ɔːɹ ɛskˈeɪps .|0
39
+ LJ014-0263.wav|wˌɛn ˈʌðɚ plˈɛʒɚz pˈɔːld hiː tˈʊk ɐ θˈiəɾɚ , ænd pˈoʊzd æz ɐ mjuːnˈɪfɪsənt pˈeɪtɹən ʌvðə dɹəmˈæɾɪk ˈɑːɹt .|0
40
+ LJ042-0096.wav|ˈoʊld ɛkstʃˈeɪndʒ ɹˈeɪt ɪn ɐdˈɪʃən tə hɪz fˈæktɚɹi sˈælɚɹi ʌv ɐpɹˈɑːksɪmətli ˈiːkwəl ɐmˈaʊnt|0
41
+ LJ049-0050.wav|hˈɪl hæd bˈoʊθ fˈiːt ɔnðə kˈɑːɹ ænd wʌz klˈaɪmɪŋ ɐbˈoːɹd tʊ ɐsˈɪst pɹˈɛzɪdənt ænd mˈɪsɪz . kˈɛnədi .|0
42
+ LJ019-0186.wav|sˈiːɪŋ ðæt sˈɪns ðɪ ɪstˈæblɪʃmənt ʌvðə sˈɛntɹəl kɹˈɪmɪnəl kˈoːɹt , nˈuːɡeɪt ɹᵻsˈiːvd pɹˈɪzənɚz fɔːɹ tɹˈaɪəl fɹʌm sˈɛvɹəl kˈaʊntiz ,|0
43
+ LJ028-0307.wav|ðˈɛn lˈɛt twˈɛnti dˈeɪz pˈæs , ænd æt ðɪ ˈɛnd ʌv ðæt tˈaɪm stˈeɪʃən nˌɪɹ ðə tʃˈældæsəŋ ɡˈeɪts ɐ bˈɑːdi ʌv fˈoːɹ θˈaʊzənd .|0
44
+ LJ012-0235.wav|wˌaɪl ðeɪ wɜːɹ ɪn ɐ stˈeɪt ʌv ɪnsˌɛnsəbˈɪlɪɾi ðə mˈɜːdɚ wʌz kəmˈɪɾᵻd .|0
45
+ LJ034-0053.wav|ɹˈiːtʃt ðə sˈeɪm kəŋklˈuːʒən æz lætˈoʊnə ðætðə pɹˈɪnts fˈaʊnd ɔnðə kˈɑːɹtənz wɜː ðoʊz ʌv lˈiː hˈɑːɹvi ˈɑːswəld .|0
46
+ LJ014-0030.wav|ðiːz wɜː dˈæmnətˌoːɹi fˈækts wˌɪtʃ wˈɛl səpˈoːɹɾᵻd ðə pɹˌɑːsɪkjˈuːʃən .|0
47
+ LJ015-0203.wav|bˌʌt wɜː ðə pɹɪkˈɔːʃənz tˈuː mˈɪnɪt , ðə vˈɪdʒɪləns tˈuː klˈoʊs təbi ᵻlˈuːdᵻd ɔːɹ ˌoʊvɚkˈʌm ?|0
48
+ LJ028-0093.wav|bˌʌt hɪz skɹˈaɪb ɹˈoʊt ɪɾ ɪnðə mˈænɚ kˈʌstəmˌɛɹi fɚðə skɹˈaɪbz ʌv ðoʊz dˈeɪz tə ɹˈaɪt ʌv ðɛɹ ɹˈɔɪəl mˈæstɚz .|0
49
+ LJ002-0018.wav|ðɪ ɪnˈædɪkwəsi ʌvðə dʒˈeɪl wʌz nˈoʊɾɪst ænd ɹᵻpˈoːɹɾᵻd əpˌɑːn ɐɡˈɛn ænd ɐɡˈɛn baɪ ðə ɡɹˈænd dʒˈʊɹɹiz ʌvðə sˈɪɾi ʌv lˈʌndən ,|0
50
+ LJ028-0275.wav|æt lˈæst , ɪnðə twˈɛntiəθ mˈʌnθ ,|0
51
+ LJ012-0042.wav|wˌɪtʃ hiː kˈɛpt kənsˈiːld ɪn ɐ hˈaɪdɪŋ plˈeɪs wɪð ɐ tɹˈæp dˈoːɹ dʒˈʌst ˌʌndɚ hɪz bˈɛd .|0
52
+ LJ011-0096.wav|hiː mˈæɹid ɐ lˈeɪdi ˈɔːlsoʊ bᵻlˈɔŋɪŋ tə ðə səsˈaɪəɾi ʌv fɹˈɛndz , hˌuː bɹˈɔːt hˌɪm ɐ lˈɑːɹdʒ fˈɔːɹtʃʊn , wˈɪtʃ , ænd hɪz ˈoʊn mˈʌni , hiː pˌʊt ˌɪntʊ ɐ sˈɪɾi fˈɜːm ,|0
53
+ LJ036-0077.wav|ɹˈɑːdʒɚ dˈiː . kɹˈeɪɡ , ɐ dˈɛpjuːɾi ʃˈɛɹɪf ʌv dˈæləs kˈaʊnti ,|0
54
+ LJ016-0318.wav|ˈʌðɚɹ əfˈɪʃəlz , ɡɹˈeɪt lˈɔɪɚz , ɡˈʌvɚnɚz ʌv pɹˈɪzənz , ænd tʃˈæplɪnz səpˈoːɹɾᵻd ðɪs vjˈuː .|0
55
+ LJ013-0164.wav|hˌuː kˈeɪm fɹʌm hɪz ɹˈuːm ɹˈɛdi dɹˈɛst , ɐ səspˈɪʃəs sˈɜːkəmstˌæns , æz hiː wʌz ˈɔːlweɪz lˈeɪt ɪnðə mˈɔːɹnɪŋ .|0
56
+ LJ027-0141.wav|ɪz klˈoʊsli ɹᵻpɹədˈuːst ɪnðə lˈaɪf hˈɪstɚɹi ʌv ɛɡzˈɪstɪŋ dˈɪɹ . ɔːɹ , ɪn ˈʌðɚ wˈɜːdz ,|0
57
+ LJ028-0335.wav|ɐkˈoːɹdɪŋli ðeɪ kəmˈɪɾᵻd tə hˌɪm ðə kəmˈænd ʌv ðɛɹ hˈoʊl ˈɑːɹmi , ænd pˌʊt ðə kˈiːz ʌv ðɛɹ sˈɪɾi ˌɪntʊ hɪz hˈændz .|0
58
+ LJ031-0202.wav|mˈɪsɪz . kˈɛnədi tʃˈoʊz ðə hˈɑːspɪɾəl ɪn bəθˈɛzdə fɚðɪ ˈɔːtɑːpsi bɪkˈʌz ðə pɹˈɛzɪdənt hæd sˈɜːvd ɪnðə nˈeɪvi .|0
59
+ LJ021-0145.wav|fɹʌm ðoʊz wˈɪlɪŋ tə dʒˈɔɪn ɪn ɪstˈæblɪʃɪŋ ðɪs hˈo��pt fɔːɹ pˈiəɹɪəd ʌv pˈiːs ,|0
60
+ LJ016-0288.wav|dˈɑːlɚ mˈuːlɚ , mˈuːlɚ , hiːz ðə mˈæn , dˈɑːlɚ tˈɪl ɐ daɪvˈɜːʒən wʌz kɹiːˈeɪɾᵻd baɪ ðɪ ɐpˈɪɹəns ʌvðə ɡˈæloʊz , wˌɪtʃ wʌz ɹᵻsˈiːvd wɪð kəntˈɪnjuːəs jˈɛlz .|0
61
+ LJ028-0081.wav|jˈɪɹz lˈeɪɾɚ , wˌɛn ðɪ ˌɑːɹkiːˈɑːlədʒˌɪsts kʊd ɹˈɛdili dɪstˈɪŋɡwɪʃ ðə fˈɔls fɹʌmðə tɹˈuː ,|0
62
+ LJ018-0081.wav|hɪz dᵻfˈɛns bˌiːɪŋ ðæt hiː hæd ɪntˈɛndᵻd tə kəmˈɪt sˈuːɪsˌaɪd , bˌʌt ðˈæt , ɔnðɪ ɐpˈɪɹəns ʌv ðɪs ˈɑːfɪsɚ hˌuː hæd ɹˈɔŋd hˌɪm ,|0
63
+ LJ021-0066.wav|təɡˌɛðɚ wɪð ɐ ɡɹˈeɪt ˈɪŋkɹiːs ɪnðə pˈeɪɹoʊlz , ðɛɹ hɐz kˈʌm ɐ səbstˈænʃəl ɹˈaɪz ɪnðə tˈoʊɾəl ʌv ɪndˈʌstɹɪəl pɹˈɑːfɪts|0
64
+ LJ009-0238.wav|ˈæftɚ ðɪs ðə ʃˈɛɹɪfs sˈɛnt fɔːɹ ɐnˈʌðɚ ɹˈoʊp , bˌʌt ðə spɛktˈeɪɾɚz ˌɪntəfˈɪɹd , ænd ðə mˈæn wʌz kˈæɹid bˈæk tə dʒˈeɪl .|0
65
+ LJ005-0079.wav|ænd ɪmpɹˈuːv ðə mˈɔːɹəlz ʌvðə pɹˈɪzənɚz , ænd ʃˌæl ɪnʃˈʊɹ ðə pɹˈɑːpɚ mˈɛʒɚɹ ʌv pˈʌnɪʃmənt tə kənvˈɪktᵻd əfˈɛndɚz .|0
66
+ LJ035-0019.wav|dɹˈoʊv tə ðə nɔːɹθwˈɛst kˈɔːɹnɚɹ ʌv ˈɛlm ænd hjˈuːstən , ænd pˈɑːɹkt ɐpɹˈɑːksɪmətli tˈɛn fˈiːt fɹʌmðə tɹˈæfɪk sˈɪɡnəl .|0
67
+ LJ036-0174.wav|ðɪs ɪz ðɪ ɐpɹˈɑːksɪmət tˈaɪm hiː ˈɛntɚd ðə ɹˈuːmɪŋhˌaʊs , ɐkˈoːɹdɪŋ tʊ ˈɜːliːn ɹˈɑːbɚts , ðə hˈaʊskiːpɚ ðˈɛɹ .|0
68
+ LJ046-0146.wav|ðə kɹaɪtˈiəɹɪə ɪn ɪfˈɛkt pɹˈaɪɚ tə noʊvˈɛmbɚ twˈɛnti tˈuː , nˈaɪntiːn sˈɪksti θɹˈiː , fɔːɹ dɪtˈɜːmɪnɪŋ wˈɛðɚ tʊ ɐksˈɛpt mətˈɪɹiəl fɚðə pˌiːˌɑːɹɹˈɛs dʒˈɛnɚɹəl fˈaɪlz|0
69
+ LJ017-0044.wav|ænd ðə dˈiːpɪst æŋzˈaɪəɾi wʌz fˈɛlt ðætðə kɹˈaɪm , ɪf kɹˈaɪm ðˈɛɹ hɐdbɪn , ʃˌʊd biː bɹˈɔːt hˈoʊm tʊ ɪts pˈɜːpɪtɹˌeɪɾɚ .|0
70
+ LJ017-0070.wav|bˌʌt hɪz spˈoːɹɾɪŋ ˌɑːpɚɹˈeɪʃənz dɪdnˌɑːt pɹˈɑːspɚ , ænd hiː bɪkˌeɪm ɐ nˈiːdi mˈæn , ˈɔːlweɪz dɹˈɪvən tə dˈɛspɚɹət stɹˈeɪts fɔːɹ kˈæʃ .|0
71
+ LJ014-0020.wav|hiː wʌz sˈuːn ˈæftɚwɚdz ɚɹˈɛstᵻd ˌɔn səspˈɪʃən , ænd ɐ sˈɜːtʃ ʌv hɪz lˈɑːdʒɪŋz bɹˈɔːt tə lˈaɪt sˈɛvɹəl ɡˈɑːɹmənts sˈætʃɚɹˌeɪɾᵻd wɪð blˈʌd ;|0
72
+ LJ016-0020.wav|hiː nˈɛvɚ ɹˈiːtʃt ðə sˈɪstɚn , bˌʌt fˈɛl bˈæk ˌɪntʊ ðə jˈɑːɹd , ˈɪndʒɚɹɪŋ hɪz lˈɛɡz sᵻvˈɪɹli .|0
73
+ LJ045-0230.wav|wˌɛn hiː wʌz fˈaɪnəli ˌæpɹihˈɛndᵻd ɪnðə tˈɛksəs θˈiəɾɚ . ɔːlðˈoʊ ɪɾ ɪz nˌɑːt fˈʊli kɚɹˈɑːbɚɹˌeɪɾᵻd baɪ ˈʌðɚz hˌuː wɜː pɹˈɛzənt ,|0
74
+ LJ035-0129.wav|ænd ʃiː mˈʌstɐv ɹˈʌn dˌaʊn ðə stˈɛɹz ɐhˈɛd ʌv ˈɑːswəld ænd wʊd pɹˈɑːbəbli hæv sˈiːn ɔːɹ hˈɜːd hˌɪm .|0
75
+ LJ008-0307.wav|ˈæftɚwɚdz ɛkspɹˈɛs ɐ wˈɪʃ tə mˈɜːdɚ ðə ɹᵻkˈoːɹdɚ fɔːɹ hˌævɪŋ kˈɛpt ðˌɛm sˌoʊ lˈɔŋ ɪn səspˈɛns .|0
76
+ LJ008-0294.wav|nˌɪɹli ɪndˈɛfɪnətli dᵻfˈɜːd .|0
77
+ LJ047-0148.wav|ˌɔn ɑːktˈoʊbɚ twˈɛnti fˈaɪv ,|0
78
+ LJ008-0111.wav|ðeɪ ˈɛntɚd ɐ dˈɑːlɚ stˈoʊŋ kˈoʊld ɹˈuːm , dˈɑːlɚɹ ænd wɜː pɹˈɛzəntli dʒˈɔɪnd baɪ ðə pɹˈɪzənɚ .|0
79
+ LJ034-0042.wav|ðæt hiː kʊd ˈoʊnli tˈɛstᵻfˌaɪ wɪð sˈɜːtənti ðætðə pɹˈɪnt wʌz lˈɛs ðɐn θɹˈiː dˈeɪz ˈoʊld .|0
80
+ LJ037-0234.wav|mˈɪsɪz . mˈɛɹi bɹˈɑːk , ðə wˈaɪf əvə mɪkˈænɪk hˌuː wˈɜːkt æt ðə stˈeɪʃən , wʌz ðɛɹ æt ðə tˈaɪm ænd ʃiː sˈɔː ɐ wˈaɪt mˈeɪl ,|0
81
+ LJ040-0002.wav|tʃˈæptɚ sˈɛvən . lˈiː hˈɑːɹvi ˈɑːswəld : bˈækɡɹaʊnd ænd pˈɑːsᵻbəl mˈoʊɾɪvz , pˈɑːɹt wˌʌn .|0
82
+ LJ045-0140.wav|ðɪ ˈɑːɹɡjuːmənts hiː jˈuːzd tə dʒˈʌstᵻfˌaɪ hɪz jˈuːs ʌvðɪ ˈeɪliəs sədʒˈɛst ðæt ˈɑːswəld mˌeɪhɐv kˈʌm tə θˈɪŋk ðætðə hˈoʊl wˈɜːld wʌz bᵻkˈʌmɪŋ ɪnvˈɑːlvd|0
83
+ LJ012-0035.wav|ðə nˈʌmbɚ ænd nˈeɪmz ˌɔn wˈɑːtʃᵻz , wɜː kˈɛɹfəli ɹᵻmˈuːvd ɔːɹ əblˈɪɾɚɹˌeɪɾᵻd ˈæftɚ ðə ɡˈʊdz pˈæst ˌaʊɾəv hɪz hˈændz .|0
84
+ LJ012-0250.wav|ɔnðə sˈɛvənθ dʒuːlˈaɪ , ˈeɪtiːn θˈɜːɾi sˈɛvən ,|0
85
+ LJ016-0179.wav|kəntɹˈæktᵻd wɪð ʃˈɛɹɪfs ænd kənvˈiːnɚz tə wˈɜːk baɪ ðə dʒˈɑːb .|0
86
+ LJ016-0138.wav|æɾə dˈɪstəns fɹʌmðə pɹˈɪzən .|0
87
+ LJ027-0052.wav|ðiːz pɹˈɪnsɪpəlz ʌv həmˈɑːlədʒi ɑːɹ ᵻsˈɛnʃəl tʊ ɐ kɚɹˈɛkt ɪntˌɜːpɹɪtˈeɪʃən ʌvðə fˈækts ʌv mɔːɹfˈɑːlədʒi .|0
88
+ LJ031-0134.wav|ˌɔn wˈʌn əkˈeɪʒən mˈɪsɪz . dʒˈɑːnsən , ɐkˈʌmpənid baɪ tˈuː sˈiːkɹᵻt sˈɜːvɪs ˈeɪdʒənts , lˈɛft ðə ɹˈuːm tə sˈiː mˈɪsɪz . kˈɛnədi ænd mˈɪsɪz . kˈɑːnæli .|0
89
+ LJ019-0273.wav|wˌɪtʃ sˌɜː dʒˈɑːʃjuːə dʒˈɛb tˈoʊld ðə kəmˈɪɾi hiː kənsˈɪdɚd ðə pɹˈɑːpɚɹ ˈɛlɪmənts ʌv pˈiːnəl dˈɪsɪplˌɪn .|0
90
+ LJ014-0110.wav|æt ðə fˈɜːst ðə bˈɑːksᵻz wɜːɹ ɪmpˈaʊndᵻd , ˈoʊpənd , ænd fˈaʊnd tə kəntˈeɪn mˈɛnɪəv oʊkˈɑːnɚz ɪfˈɛkts .|0
91
+ LJ034-0160.wav|ˌɔn bɹˈɛnənz sˈʌbsᵻkwənt sˈɜːʔn̩ aɪdˈɛntɪfɪkˈeɪʃən ʌv lˈiː hˈɑːɹvi ˈɑːswəld æz ðə mˈæn hiː sˈɔː fˈaɪɚ ðə ɹˈaɪfəl .|0
92
+ LJ038-0199.wav|ᵻlˈɛvən . ɪf aɪɐm ɐlˈaɪv ænd tˈeɪkən pɹˈɪzənɚ ,|0
93
+ LJ014-0010.wav|jˈɛt hiː kʊd nˌɑːt ˌoʊvɚkˈʌm ðə stɹˈeɪndʒ fˌæsᵻnˈeɪʃən ɪt hˈæd fɔːɹ hˌɪm , ænd ɹᵻmˈeɪnd baɪ ðə sˈaɪd ʌvðə kˈɔːɹps tˈɪl ðə stɹˈɛtʃɚ kˈeɪm .|0
94
+ LJ033-0047.wav|aɪ nˈoʊɾɪst wɛn aɪ wɛnt ˈaʊt ðætðə lˈaɪt wʌz ˈɔn , ˈɛnd kwˈoʊt ,|0
95
+ LJ040-0027.wav|hiː wʌz nˈɛvɚ sˈæɾɪsfˌaɪd wɪð ˈɛnɪθˌɪŋ .|0
96
+ LJ048-0228.wav|ænd ˈʌðɚz hˌuː wɜː pɹˈɛzənt sˈeɪ ðæt nˈoʊ ˈeɪdʒənt wʌz ɪnˈiːbɹɪˌeɪɾᵻd ɔːɹ ˈæktᵻd ɪmpɹˈɑːpɚli .|0
97
+ LJ003-0111.wav|hiː wʌz ɪŋ kˈɑːnsɪkwəns pˌʊt ˌaʊɾəv ðə pɹətˈɛkʃən ʌv ðɛɹ ɪntˈɜːnəl lˈɔː , ˈɛnd kwˈoʊt . ðɛɹ kˈoʊd wʌzɐ sˈʌbdʒɛkt ʌv sˌʌm kjˌʊɹɹɪˈɔsɪɾi .|0
98
+ LJ008-0258.wav|lˈɛt mˌiː ɹᵻtɹˈeɪs maɪ stˈɛps , ænd spˈiːk mˈoːɹ ɪn diːtˈeɪl ʌvðə tɹˈiːtmənt ʌvðə kəndˈɛmd ɪn ðoʊz blˈʌdθɜːsti ænd bɹˈuːɾəli ɪndˈɪfɹənt dˈeɪz ,|0
99
+ LJ029-0022.wav|ðɪ ɚɹˈɪdʒɪnəl plˈæŋ kˈɔːld fɚðə pɹˈɛzɪdənt tə spˈɛnd ˈoʊnli wˈʌn dˈeɪ ɪnðə stˈeɪt , mˌeɪkɪŋ wˈɜːlwɪnd vˈɪzɪts tə dˈæləs , fˈɔːɹt wˈɜːθ , sˌæn æntˈoʊnɪˌoʊ , ænd hjˈuːstən .|0
100
+ LJ004-0045.wav|mˈɪstɚ . stˈɜːdʒᵻz bˈoːɹn , sˌɜː dʒˈeɪmz mˈækɪntˌɑːʃ , sˌɜː dʒˈeɪmz skˈɑːɹlɪt , ænd wˈɪljəm wˈɪlbɚfˌoːɹs .|0
Demo/Inference_LJSpeech.ipynb ADDED
@@ -0,0 +1,554 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "9adb7bd1",
6
+ "metadata": {},
7
+ "source": [
8
+ "# StyleTTS 2 Demo (LJSpeech)\n"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "6108384d",
14
+ "metadata": {},
15
+ "source": [
16
+ "### Utils"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "id": "96e173bf",
23
+ "metadata": {},
24
+ "outputs": [],
25
+ "source": [
26
+ "import torch\n",
27
+ "torch.manual_seed(0)\n",
28
+ "torch.backends.cudnn.benchmark = False\n",
29
+ "torch.backends.cudnn.deterministic = True\n",
30
+ "\n",
31
+ "import random\n",
32
+ "random.seed(0)\n",
33
+ "\n",
34
+ "import numpy as np\n",
35
+ "np.random.seed(0)"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "id": "da84c60f",
42
+ "metadata": {},
43
+ "outputs": [],
44
+ "source": [
45
+ "%cd .."
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": null,
51
+ "id": "5a3ddcc8",
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "# load packages\n",
56
+ "import time\n",
57
+ "import random\n",
58
+ "import yaml\n",
59
+ "from munch import Munch\n",
60
+ "import numpy as np\n",
61
+ "import torch\n",
62
+ "from torch import nn\n",
63
+ "import torch.nn.functional as F\n",
64
+ "import torchaudio\n",
65
+ "import librosa\n",
66
+ "from nltk.tokenize import word_tokenize\n",
67
+ "\n",
68
+ "from models import *\n",
69
+ "from utils import *\n",
70
+ "from text_utils import TextCleaner\n",
71
+ "textclenaer = TextCleaner()\n",
72
+ "\n",
73
+ "%matplotlib inline"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": null,
79
+ "id": "bbdc04c0",
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": null,
89
+ "id": "00ee05e1",
90
+ "metadata": {},
91
+ "outputs": [],
92
+ "source": [
93
+ "to_mel = torchaudio.transforms.MelSpectrogram(\n",
94
+ " n_mels=80, n_fft=2048, win_length=1200, hop_length=300)\n",
95
+ "mean, std = -4, 4\n",
96
+ "\n",
97
+ "def length_to_mask(lengths):\n",
98
+ " mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)\n",
99
+ " mask = torch.gt(mask+1, lengths.unsqueeze(1))\n",
100
+ " return mask\n",
101
+ "\n",
102
+ "def preprocess(wave):\n",
103
+ " wave_tensor = torch.from_numpy(wave).float()\n",
104
+ " mel_tensor = to_mel(wave_tensor)\n",
105
+ " mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std\n",
106
+ " return mel_tensor\n",
107
+ "\n",
108
+ "def compute_style(ref_dicts):\n",
109
+ " reference_embeddings = {}\n",
110
+ " for key, path in ref_dicts.items():\n",
111
+ " wave, sr = librosa.load(path, sr=24000)\n",
112
+ " audio, index = librosa.effects.trim(wave, top_db=30)\n",
113
+ " if sr != 24000:\n",
114
+ " audio = librosa.resample(audio, sr, 24000)\n",
115
+ " mel_tensor = preprocess(audio).to(device)\n",
116
+ "\n",
117
+ " with torch.no_grad():\n",
118
+ " ref = model.style_encoder(mel_tensor.unsqueeze(1))\n",
119
+ " reference_embeddings[key] = (ref.squeeze(1), audio)\n",
120
+ " \n",
121
+ " return reference_embeddings"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "markdown",
126
+ "id": "7b9cecbe",
127
+ "metadata": {},
128
+ "source": [
129
+ "### Load models"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type": "code",
134
+ "execution_count": null,
135
+ "id": "64fc4c0f",
136
+ "metadata": {},
137
+ "outputs": [],
138
+ "source": [
139
+ "# load phonemizer\n",
140
+ "import phonemizer\n",
141
+ "global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)"
142
+ ]
143
+ },
144
+ {
145
+ "cell_type": "code",
146
+ "execution_count": null,
147
+ "id": "48e7b644",
148
+ "metadata": {},
149
+ "outputs": [],
150
+ "source": [
151
+ "config = yaml.safe_load(open(\"Models/LJSpeech/config.yml\"))\n",
152
+ "\n",
153
+ "# load pretrained ASR model\n",
154
+ "ASR_config = config.get('ASR_config', False)\n",
155
+ "ASR_path = config.get('ASR_path', False)\n",
156
+ "text_aligner = load_ASR_models(ASR_path, ASR_config)\n",
157
+ "\n",
158
+ "# load pretrained F0 model\n",
159
+ "F0_path = config.get('F0_path', False)\n",
160
+ "pitch_extractor = load_F0_models(F0_path)\n",
161
+ "\n",
162
+ "# load BERT model\n",
163
+ "from Utils.PLBERT.util import load_plbert\n",
164
+ "BERT_path = config.get('PLBERT_dir', False)\n",
165
+ "plbert = load_plbert(BERT_path)"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "code",
170
+ "execution_count": null,
171
+ "id": "ffc18cf7",
172
+ "metadata": {},
173
+ "outputs": [],
174
+ "source": [
175
+ "model = build_model(recursive_munch(config['model_params']), text_aligner, pitch_extractor, plbert)\n",
176
+ "_ = [model[key].eval() for key in model]\n",
177
+ "_ = [model[key].to(device) for key in model]"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": null,
183
+ "id": "64529d5c",
184
+ "metadata": {},
185
+ "outputs": [],
186
+ "source": [
187
+ "params_whole = torch.load(\"Models/LJSpeech/epoch_2nd_00100.pth\", map_location='cpu')\n",
188
+ "params = params_whole['net']"
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "execution_count": null,
194
+ "id": "895d9706",
195
+ "metadata": {},
196
+ "outputs": [],
197
+ "source": [
198
+ "for key in model:\n",
199
+ " if key in params:\n",
200
+ " print('%s loaded' % key)\n",
201
+ " try:\n",
202
+ " model[key].load_state_dict(params[key])\n",
203
+ " except:\n",
204
+ " from collections import OrderedDict\n",
205
+ " state_dict = params[key]\n",
206
+ " new_state_dict = OrderedDict()\n",
207
+ " for k, v in state_dict.items():\n",
208
+ " name = k[7:] # remove `module.`\n",
209
+ " new_state_dict[name] = v\n",
210
+ " # load params\n",
211
+ " model[key].load_state_dict(new_state_dict, strict=False)\n",
212
+ "# except:\n",
213
+ "# _load(params[key], model[key])\n",
214
+ "_ = [model[key].eval() for key in model]"
215
+ ]
216
+ },
217
+ {
218
+ "cell_type": "code",
219
+ "execution_count": null,
220
+ "id": "c1a59db2",
221
+ "metadata": {},
222
+ "outputs": [],
223
+ "source": [
224
+ "from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule"
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "code",
229
+ "execution_count": null,
230
+ "id": "e30985ab",
231
+ "metadata": {},
232
+ "outputs": [],
233
+ "source": [
234
+ "sampler = DiffusionSampler(\n",
235
+ " model.diffusion.diffusion,\n",
236
+ " sampler=ADPM2Sampler(),\n",
237
+ " sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters\n",
238
+ " clamp=False\n",
239
+ ")"
240
+ ]
241
+ },
242
+ {
243
+ "cell_type": "markdown",
244
+ "id": "b803110e",
245
+ "metadata": {},
246
+ "source": [
247
+ "### Synthesize speech"
248
+ ]
249
+ },
250
+ {
251
+ "cell_type": "code",
252
+ "execution_count": null,
253
+ "id": "24655f46",
254
+ "metadata": {},
255
+ "outputs": [],
256
+ "source": [
257
+ "# synthesize a text\n",
258
+ "text = ''' StyleTTS 2 is a text-to-speech model that leverages style diffusion and adversarial training with large speech language models to achieve human-level text-to-speech synthesis. '''"
259
+ ]
260
+ },
261
+ {
262
+ "cell_type": "code",
263
+ "execution_count": null,
264
+ "id": "ca57469c",
265
+ "metadata": {},
266
+ "outputs": [],
267
+ "source": [
268
+ "def inference(text, noise, diffusion_steps=5, embedding_scale=1):\n",
269
+ " text = text.strip()\n",
270
+ " text = text.replace('\"', '')\n",
271
+ " ps = global_phonemizer.phonemize([text])\n",
272
+ " ps = word_tokenize(ps[0])\n",
273
+ " ps = ' '.join(ps)\n",
274
+ "\n",
275
+ " tokens = textclenaer(ps)\n",
276
+ " tokens.insert(0, 0)\n",
277
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
278
+ " \n",
279
+ " with torch.no_grad():\n",
280
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)\n",
281
+ " text_mask = length_to_mask(input_lengths).to(tokens.device)\n",
282
+ "\n",
283
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
284
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
285
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2) \n",
286
+ "\n",
287
+ " s_pred = sampler(noise, \n",
288
+ " embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps,\n",
289
+ " embedding_scale=embedding_scale).squeeze(0)\n",
290
+ "\n",
291
+ " s = s_pred[:, 128:]\n",
292
+ " ref = s_pred[:, :128]\n",
293
+ "\n",
294
+ " d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)\n",
295
+ "\n",
296
+ " x, _ = model.predictor.lstm(d)\n",
297
+ " duration = model.predictor.duration_proj(x)\n",
298
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
299
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
300
+ "\n",
301
+ " pred_dur[-1] += 5\n",
302
+ "\n",
303
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
304
+ " c_frame = 0\n",
305
+ " for i in range(pred_aln_trg.size(0)):\n",
306
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
307
+ " c_frame += int(pred_dur[i].data)\n",
308
+ "\n",
309
+ " # encode prosody\n",
310
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
311
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
312
+ " out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)), \n",
313
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
314
+ " \n",
315
+ " return out.squeeze().cpu().numpy()"
316
+ ]
317
+ },
318
+ {
319
+ "cell_type": "markdown",
320
+ "id": "d438ef4f",
321
+ "metadata": {},
322
+ "source": [
323
+ "#### Basic synthesis (5 diffusion steps)"
324
+ ]
325
+ },
326
+ {
327
+ "cell_type": "code",
328
+ "execution_count": null,
329
+ "id": "d3d7f7d5",
330
+ "metadata": {
331
+ "scrolled": true
332
+ },
333
+ "outputs": [],
334
+ "source": [
335
+ "start = time.time()\n",
336
+ "noise = torch.randn(1,1,256).to(device)\n",
337
+ "wav = inference(text, noise, diffusion_steps=5, embedding_scale=1)\n",
338
+ "rtf = (time.time() - start) / (len(wav) / 24000)\n",
339
+ "print(f\"RTF = {rtf:5f}\")\n",
340
+ "import IPython.display as ipd\n",
341
+ "display(ipd.Audio(wav, rate=24000))"
342
+ ]
343
+ },
344
+ {
345
+ "cell_type": "markdown",
346
+ "id": "2d5d9df0",
347
+ "metadata": {},
348
+ "source": [
349
+ "#### With higher diffusion steps (more diverse)\n",
350
+ "Since the sampler is ancestral, the higher the stpes, the more diverse the samples are, with the cost of slower synthesis speed."
351
+ ]
352
+ },
353
+ {
354
+ "cell_type": "code",
355
+ "execution_count": null,
356
+ "id": "a10129fd",
357
+ "metadata": {},
358
+ "outputs": [],
359
+ "source": [
360
+ "start = time.time()\n",
361
+ "noise = torch.randn(1,1,256).to(device)\n",
362
+ "wav = inference(text, noise, diffusion_steps=10, embedding_scale=1)\n",
363
+ "rtf = (time.time() - start) / (len(wav) / 24000)\n",
364
+ "print(f\"RTF = {rtf:5f}\")\n",
365
+ "import IPython.display as ipd\n",
366
+ "display(ipd.Audio(wav, rate=24000))"
367
+ ]
368
+ },
369
+ {
370
+ "cell_type": "markdown",
371
+ "id": "1877ea15",
372
+ "metadata": {},
373
+ "source": [
374
+ "### Speech expressiveness\n",
375
+ "The following section recreates the samples shown in [Section 6](https://styletts2.github.io/#emo) of the demo page."
376
+ ]
377
+ },
378
+ {
379
+ "cell_type": "markdown",
380
+ "id": "4c4777b7",
381
+ "metadata": {},
382
+ "source": [
383
+ "#### With embedding_scale=1\n",
384
+ "This is the classifier-free guidance scale. The higher the scale, the more conditional the style is to the input text and hence more emotional. "
385
+ ]
386
+ },
387
+ {
388
+ "cell_type": "code",
389
+ "execution_count": null,
390
+ "id": "c29ea2f0",
391
+ "metadata": {},
392
+ "outputs": [],
393
+ "source": [
394
+ "texts = {}\n",
395
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
396
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
397
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
398
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
399
+ "\n",
400
+ "for k,v in texts.items():\n",
401
+ " noise = torch.randn(1,1,256).to(device)\n",
402
+ " wav = inference(v, noise, diffusion_steps=10, embedding_scale=1)\n",
403
+ " print(k + \": \")\n",
404
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
405
+ ]
406
+ },
407
+ {
408
+ "cell_type": "markdown",
409
+ "id": "3c89499f",
410
+ "metadata": {},
411
+ "source": [
412
+ "#### With embedding_scale=2"
413
+ ]
414
+ },
415
+ {
416
+ "cell_type": "code",
417
+ "execution_count": null,
418
+ "id": "f73be3aa",
419
+ "metadata": {},
420
+ "outputs": [],
421
+ "source": [
422
+ "texts = {}\n",
423
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
424
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
425
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
426
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
427
+ "\n",
428
+ "for k,v in texts.items():\n",
429
+ " noise = torch.randn(1,1,256).to(device)\n",
430
+ " wav = inference(v, noise, diffusion_steps=10, embedding_scale=2) # embedding_scale=2 for more pronounced emotion\n",
431
+ " print(k + \": \")\n",
432
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
433
+ ]
434
+ },
435
+ {
436
+ "cell_type": "markdown",
437
+ "id": "9320da63",
438
+ "metadata": {},
439
+ "source": [
440
+ "### Long-form generation\n",
441
+ "This section includes basic implementation of Algorithm 1 in the paper for consistent longform audio generation. The example passage is taken from [Section 5](https://styletts2.github.io/#long) of the demo page. "
442
+ ]
443
+ },
444
+ {
445
+ "cell_type": "code",
446
+ "execution_count": null,
447
+ "id": "cdd4db51",
448
+ "metadata": {},
449
+ "outputs": [],
450
+ "source": [
451
+ "passage = '''If the supply of fruit is greater than the family needs, it may be made a source of income by sending the fresh fruit to the market if there is one near enough, or by preserving, canning, and making jelly for sale. To make such an enterprise a success the fruit and work must be first class. There is magic in the word \"Homemade,\" when the product appeals to the eye and the palate; but many careless and incompetent people have found to their sorrow that this word has not magic enough to float inferior goods on the market. As a rule large canning and preserving establishments are clean and have the best appliances, and they employ chemists and skilled labor. The home product must be very good to compete with the attractive goods that are sent out from such establishments. Yet for first-class homemade products there is a market in all large cities. All first-class grocers have customers who purchase such goods.'''"
452
+ ]
453
+ },
454
+ {
455
+ "cell_type": "code",
456
+ "execution_count": null,
457
+ "id": "ebb941c8",
458
+ "metadata": {},
459
+ "outputs": [],
460
+ "source": [
461
+ "def LFinference(text, s_prev, noise, alpha=0.7, diffusion_steps=5, embedding_scale=1):\n",
462
+ " text = text.strip()\n",
463
+ " text = text.replace('\"', '')\n",
464
+ " ps = global_phonemizer.phonemize([text])\n",
465
+ " ps = word_tokenize(ps[0])\n",
466
+ " ps = ' '.join(ps)\n",
467
+ "\n",
468
+ " tokens = textclenaer(ps)\n",
469
+ " tokens.insert(0, 0)\n",
470
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
471
+ " \n",
472
+ " with torch.no_grad():\n",
473
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)\n",
474
+ " text_mask = length_to_mask(input_lengths).to(tokens.device)\n",
475
+ "\n",
476
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
477
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
478
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2) \n",
479
+ "\n",
480
+ " s_pred = sampler(noise, \n",
481
+ " embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps,\n",
482
+ " embedding_scale=embedding_scale).squeeze(0)\n",
483
+ " \n",
484
+ " if s_prev is not None:\n",
485
+ " # convex combination of previous and current style\n",
486
+ " s_pred = alpha * s_prev + (1 - alpha) * s_pred\n",
487
+ " \n",
488
+ " s = s_pred[:, 128:]\n",
489
+ " ref = s_pred[:, :128]\n",
490
+ "\n",
491
+ " d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)\n",
492
+ "\n",
493
+ " x, _ = model.predictor.lstm(d)\n",
494
+ " duration = model.predictor.duration_proj(x)\n",
495
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
496
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
497
+ "\n",
498
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
499
+ " c_frame = 0\n",
500
+ " for i in range(pred_aln_trg.size(0)):\n",
501
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
502
+ " c_frame += int(pred_dur[i].data)\n",
503
+ "\n",
504
+ " # encode prosody\n",
505
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
506
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
507
+ " out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)), \n",
508
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
509
+ " \n",
510
+ " return out.squeeze().cpu().numpy(), s_pred"
511
+ ]
512
+ },
513
+ {
514
+ "cell_type": "code",
515
+ "execution_count": null,
516
+ "id": "7ca0ef2e",
517
+ "metadata": {},
518
+ "outputs": [],
519
+ "source": [
520
+ "sentences = passage.split('.') # simple split by comma\n",
521
+ "wavs = []\n",
522
+ "s_prev = None\n",
523
+ "for text in sentences:\n",
524
+ " if text.strip() == \"\": continue\n",
525
+ " text += '.' # add it back\n",
526
+ " noise = torch.randn(1,1,256).to(device)\n",
527
+ " wav, s_prev = LFinference(text, s_prev, noise, alpha=0.7, diffusion_steps=10, embedding_scale=1.5)\n",
528
+ " wavs.append(wav)\n",
529
+ "display(ipd.Audio(np.concatenate(wavs), rate=24000, normalize=False))"
530
+ ]
531
+ }
532
+ ],
533
+ "metadata": {
534
+ "kernelspec": {
535
+ "display_name": "NLP",
536
+ "language": "python",
537
+ "name": "nlp"
538
+ },
539
+ "language_info": {
540
+ "codemirror_mode": {
541
+ "name": "ipython",
542
+ "version": 3
543
+ },
544
+ "file_extension": ".py",
545
+ "mimetype": "text/x-python",
546
+ "name": "python",
547
+ "nbconvert_exporter": "python",
548
+ "pygments_lexer": "ipython3",
549
+ "version": "3.9.7"
550
+ }
551
+ },
552
+ "nbformat": 4,
553
+ "nbformat_minor": 5
554
+ }
Demo/Inference_LibriTTS.ipynb ADDED
@@ -0,0 +1,1155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "9adb7bd1",
6
+ "metadata": {},
7
+ "source": [
8
+ "# StyleTTS 2 Demo (LibriTTS)\n",
9
+ "\n",
10
+ "Before you run the following cells, please make sure you have downloaded [reference_audio.zip](https://huggingface.co/yl4579/StyleTTS2-LibriTTS/resolve/main/reference_audio.zip) and unzipped it under the `demo` folder."
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "id": "6108384d",
16
+ "metadata": {},
17
+ "source": [
18
+ "### Utils"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "id": "96e173bf",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "import torch\n",
29
+ "torch.manual_seed(0)\n",
30
+ "torch.backends.cudnn.benchmark = False\n",
31
+ "torch.backends.cudnn.deterministic = True\n",
32
+ "\n",
33
+ "import random\n",
34
+ "random.seed(0)\n",
35
+ "\n",
36
+ "import numpy as np\n",
37
+ "np.random.seed(0)"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "id": "da84c60f",
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "%cd .."
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "execution_count": null,
53
+ "id": "5a3ddcc8",
54
+ "metadata": {},
55
+ "outputs": [],
56
+ "source": [
57
+ "# load packages\n",
58
+ "import time\n",
59
+ "import random\n",
60
+ "import yaml\n",
61
+ "from munch import Munch\n",
62
+ "import numpy as np\n",
63
+ "import torch\n",
64
+ "from torch import nn\n",
65
+ "import torch.nn.functional as F\n",
66
+ "import torchaudio\n",
67
+ "import librosa\n",
68
+ "from nltk.tokenize import word_tokenize\n",
69
+ "\n",
70
+ "from models import *\n",
71
+ "from utils import *\n",
72
+ "from text_utils import TextCleaner\n",
73
+ "textclenaer = TextCleaner()\n",
74
+ "\n",
75
+ "%matplotlib inline"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": null,
81
+ "id": "00ee05e1",
82
+ "metadata": {},
83
+ "outputs": [],
84
+ "source": [
85
+ "to_mel = torchaudio.transforms.MelSpectrogram(\n",
86
+ " n_mels=80, n_fft=2048, win_length=1200, hop_length=300)\n",
87
+ "mean, std = -4, 4\n",
88
+ "\n",
89
+ "def length_to_mask(lengths):\n",
90
+ " mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)\n",
91
+ " mask = torch.gt(mask+1, lengths.unsqueeze(1))\n",
92
+ " return mask\n",
93
+ "\n",
94
+ "def preprocess(wave):\n",
95
+ " wave_tensor = torch.from_numpy(wave).float()\n",
96
+ " mel_tensor = to_mel(wave_tensor)\n",
97
+ " mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std\n",
98
+ " return mel_tensor\n",
99
+ "\n",
100
+ "def compute_style(path):\n",
101
+ " wave, sr = librosa.load(path, sr=24000)\n",
102
+ " audio, index = librosa.effects.trim(wave, top_db=30)\n",
103
+ " if sr != 24000:\n",
104
+ " audio = librosa.resample(audio, sr, 24000)\n",
105
+ " mel_tensor = preprocess(audio).to(device)\n",
106
+ "\n",
107
+ " with torch.no_grad():\n",
108
+ " ref_s = model.style_encoder(mel_tensor.unsqueeze(1))\n",
109
+ " ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))\n",
110
+ "\n",
111
+ " return torch.cat([ref_s, ref_p], dim=1)"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": null,
117
+ "id": "bbdc04c0",
118
+ "metadata": {},
119
+ "outputs": [],
120
+ "source": [
121
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "markdown",
126
+ "id": "7b9cecbe",
127
+ "metadata": {},
128
+ "source": [
129
+ "### Load models"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type": "code",
134
+ "execution_count": null,
135
+ "id": "64fc4c0f",
136
+ "metadata": {},
137
+ "outputs": [],
138
+ "source": [
139
+ "# load phonemizer\n",
140
+ "import phonemizer\n",
141
+ "global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)"
142
+ ]
143
+ },
144
+ {
145
+ "cell_type": "code",
146
+ "execution_count": null,
147
+ "id": "48e7b644",
148
+ "metadata": {},
149
+ "outputs": [],
150
+ "source": [
151
+ "config = yaml.safe_load(open(\"Models/LibriTTS/config.yml\"))\n",
152
+ "\n",
153
+ "# load pretrained ASR model\n",
154
+ "ASR_config = config.get('ASR_config', False)\n",
155
+ "ASR_path = config.get('ASR_path', False)\n",
156
+ "text_aligner = load_ASR_models(ASR_path, ASR_config)\n",
157
+ "\n",
158
+ "# load pretrained F0 model\n",
159
+ "F0_path = config.get('F0_path', False)\n",
160
+ "pitch_extractor = load_F0_models(F0_path)\n",
161
+ "\n",
162
+ "# load BERT model\n",
163
+ "from Utils.PLBERT.util import load_plbert\n",
164
+ "BERT_path = config.get('PLBERT_dir', False)\n",
165
+ "plbert = load_plbert(BERT_path)"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "code",
170
+ "execution_count": null,
171
+ "id": "ffc18cf7",
172
+ "metadata": {},
173
+ "outputs": [],
174
+ "source": [
175
+ "model_params = recursive_munch(config['model_params'])\n",
176
+ "model = build_model(model_params, text_aligner, pitch_extractor, plbert)\n",
177
+ "_ = [model[key].eval() for key in model]\n",
178
+ "_ = [model[key].to(device) for key in model]"
179
+ ]
180
+ },
181
+ {
182
+ "cell_type": "code",
183
+ "execution_count": null,
184
+ "id": "64529d5c",
185
+ "metadata": {},
186
+ "outputs": [],
187
+ "source": [
188
+ "params_whole = torch.load(\"Models/LibriTTS/epochs_2nd_00020.pth\", map_location='cpu')\n",
189
+ "params = params_whole['net']"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": null,
195
+ "id": "895d9706",
196
+ "metadata": {},
197
+ "outputs": [],
198
+ "source": [
199
+ "for key in model:\n",
200
+ " if key in params:\n",
201
+ " print('%s loaded' % key)\n",
202
+ " try:\n",
203
+ " model[key].load_state_dict(params[key])\n",
204
+ " except:\n",
205
+ " from collections import OrderedDict\n",
206
+ " state_dict = params[key]\n",
207
+ " new_state_dict = OrderedDict()\n",
208
+ " for k, v in state_dict.items():\n",
209
+ " name = k[7:] # remove `module.`\n",
210
+ " new_state_dict[name] = v\n",
211
+ " # load params\n",
212
+ " model[key].load_state_dict(new_state_dict, strict=False)\n",
213
+ "# except:\n",
214
+ "# _load(params[key], model[key])\n",
215
+ "_ = [model[key].eval() for key in model]"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "code",
220
+ "execution_count": null,
221
+ "id": "c1a59db2",
222
+ "metadata": {},
223
+ "outputs": [],
224
+ "source": [
225
+ "from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "code",
230
+ "execution_count": null,
231
+ "id": "e30985ab",
232
+ "metadata": {},
233
+ "outputs": [],
234
+ "source": [
235
+ "sampler = DiffusionSampler(\n",
236
+ " model.diffusion.diffusion,\n",
237
+ " sampler=ADPM2Sampler(),\n",
238
+ " sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters\n",
239
+ " clamp=False\n",
240
+ ")"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "markdown",
245
+ "id": "b803110e",
246
+ "metadata": {},
247
+ "source": [
248
+ "### Synthesize speech"
249
+ ]
250
+ },
251
+ {
252
+ "cell_type": "code",
253
+ "execution_count": null,
254
+ "id": "ca57469c",
255
+ "metadata": {},
256
+ "outputs": [],
257
+ "source": [
258
+ "def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):\n",
259
+ " text = text.strip()\n",
260
+ " ps = global_phonemizer.phonemize([text])\n",
261
+ " ps = word_tokenize(ps[0])\n",
262
+ " ps = ' '.join(ps)\n",
263
+ " tokens = textclenaer(ps)\n",
264
+ " tokens.insert(0, 0)\n",
265
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
266
+ " \n",
267
+ " with torch.no_grad():\n",
268
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
269
+ " text_mask = length_to_mask(input_lengths).to(device)\n",
270
+ "\n",
271
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
272
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
273
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2) \n",
274
+ "\n",
275
+ " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device), \n",
276
+ " embedding=bert_dur,\n",
277
+ " embedding_scale=embedding_scale,\n",
278
+ " features=ref_s, # reference from the same speaker as the embedding\n",
279
+ " num_steps=diffusion_steps).squeeze(1)\n",
280
+ "\n",
281
+ "\n",
282
+ " s = s_pred[:, 128:]\n",
283
+ " ref = s_pred[:, :128]\n",
284
+ "\n",
285
+ " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
286
+ " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
287
+ "\n",
288
+ " d = model.predictor.text_encoder(d_en, \n",
289
+ " s, input_lengths, text_mask)\n",
290
+ "\n",
291
+ " x, _ = model.predictor.lstm(d)\n",
292
+ " duration = model.predictor.duration_proj(x)\n",
293
+ "\n",
294
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
295
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
296
+ "\n",
297
+ "\n",
298
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
299
+ " c_frame = 0\n",
300
+ " for i in range(pred_aln_trg.size(0)):\n",
301
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
302
+ " c_frame += int(pred_dur[i].data)\n",
303
+ "\n",
304
+ " # encode prosody\n",
305
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
306
+ " if model_params.decoder.type == \"hifigan\":\n",
307
+ " asr_new = torch.zeros_like(en)\n",
308
+ " asr_new[:, :, 0] = en[:, :, 0]\n",
309
+ " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
310
+ " en = asr_new\n",
311
+ "\n",
312
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
313
+ "\n",
314
+ " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
315
+ " if model_params.decoder.type == \"hifigan\":\n",
316
+ " asr_new = torch.zeros_like(asr)\n",
317
+ " asr_new[:, :, 0] = asr[:, :, 0]\n",
318
+ " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
319
+ " asr = asr_new\n",
320
+ "\n",
321
+ " out = model.decoder(asr, \n",
322
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
323
+ " \n",
324
+ " \n",
325
+ " return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "markdown",
330
+ "id": "d438ef4f",
331
+ "metadata": {},
332
+ "source": [
333
+ "#### Basic synthesis (5 diffusion steps, seen speakers)"
334
+ ]
335
+ },
336
+ {
337
+ "cell_type": "code",
338
+ "execution_count": null,
339
+ "id": "cace9787",
340
+ "metadata": {},
341
+ "outputs": [],
342
+ "source": [
343
+ "text = ''' StyleTTS 2 is a text to speech model that leverages style diffusion and adversarial training with large speech language models to achieve human level text to speech synthesis. '''"
344
+ ]
345
+ },
346
+ {
347
+ "cell_type": "code",
348
+ "execution_count": null,
349
+ "id": "7c88f461",
350
+ "metadata": {},
351
+ "outputs": [],
352
+ "source": [
353
+ "reference_dicts = {}\n",
354
+ "reference_dicts['696_92939'] = \"Demo/reference_audio/696_92939_000016_000006.wav\"\n",
355
+ "reference_dicts['1789_142896'] = \"Demo/reference_audio/1789_142896_000022_000005.wav\""
356
+ ]
357
+ },
358
+ {
359
+ "cell_type": "code",
360
+ "execution_count": null,
361
+ "id": "16e8ac60",
362
+ "metadata": {},
363
+ "outputs": [],
364
+ "source": [
365
+ "start = time.time()\n",
366
+ "noise = torch.randn(1,1,256).to(device)\n",
367
+ "for k, path in reference_dicts.items():\n",
368
+ " ref_s = compute_style(path)\n",
369
+ " \n",
370
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=5, embedding_scale=1)\n",
371
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
372
+ " print(f\"RTF = {rtf:5f}\")\n",
373
+ " import IPython.display as ipd\n",
374
+ " print(k + ' Synthesized:')\n",
375
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
376
+ " print('Reference:')\n",
377
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
378
+ ]
379
+ },
380
+ {
381
+ "cell_type": "markdown",
382
+ "id": "14838708",
383
+ "metadata": {},
384
+ "source": [
385
+ "#### With higher diffusion steps (more diverse)\n",
386
+ "\n",
387
+ "Since the sampler is ancestral, the higher the stpes, the more diverse the samples are, with the cost of slower synthesis speed."
388
+ ]
389
+ },
390
+ {
391
+ "cell_type": "code",
392
+ "execution_count": null,
393
+ "id": "6fbff03b",
394
+ "metadata": {},
395
+ "outputs": [],
396
+ "source": [
397
+ "noise = torch.randn(1,1,256).to(device)\n",
398
+ "for k, path in reference_dicts.items():\n",
399
+ " ref_s = compute_style(path)\n",
400
+ " start = time.time()\n",
401
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=10, embedding_scale=1)\n",
402
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
403
+ " print(f\"RTF = {rtf:5f}\")\n",
404
+ " import IPython.display as ipd\n",
405
+ " print(k + ' Synthesized:')\n",
406
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
407
+ " print(k + ' Reference:')\n",
408
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
409
+ ]
410
+ },
411
+ {
412
+ "cell_type": "markdown",
413
+ "id": "7e6867fd",
414
+ "metadata": {},
415
+ "source": [
416
+ "#### Basic synthesis (5 diffusion steps, umseen speakers)\n",
417
+ "The following samples are to reproduce samples in [Section 4](https://styletts2.github.io/#libri) of the demo page. All spsakers are unseen during training. You can compare the generated samples to popular zero-shot TTS models like Vall-E and NaturalSpeech 2."
418
+ ]
419
+ },
420
+ {
421
+ "cell_type": "code",
422
+ "execution_count": null,
423
+ "id": "f4e8faa0",
424
+ "metadata": {},
425
+ "outputs": [],
426
+ "source": [
427
+ "reference_dicts = {}\n",
428
+ "# format: (path, text)\n",
429
+ "reference_dicts['1221-135767'] = (\"Demo/reference_audio/1221-135767-0014.wav\", \"Yea, his honourable worship is within, but he hath a godly minister or two with him, and likewise a leech.\")\n",
430
+ "reference_dicts['5639-40744'] = (\"Demo/reference_audio/5639-40744-0020.wav\", \"Thus did this humane and right minded father comfort his unhappy daughter, and her mother embracing her again, did all she could to soothe her feelings.\")\n",
431
+ "reference_dicts['908-157963'] = (\"Demo/reference_audio/908-157963-0027.wav\", \"And lay me down in my cold bed and leave my shining lot.\")\n",
432
+ "reference_dicts['4077-13754'] = (\"Demo/reference_audio/4077-13754-0000.wav\", \"The army found the people in poverty and left them in comparative wealth.\")"
433
+ ]
434
+ },
435
+ {
436
+ "cell_type": "code",
437
+ "execution_count": null,
438
+ "id": "653f1406",
439
+ "metadata": {},
440
+ "outputs": [],
441
+ "source": [
442
+ "noise = torch.randn(1,1,256).to(device)\n",
443
+ "for k, v in reference_dicts.items():\n",
444
+ " path, text = v\n",
445
+ " ref_s = compute_style(path)\n",
446
+ " start = time.time()\n",
447
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=5, embedding_scale=1)\n",
448
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
449
+ " print(f\"RTF = {rtf:5f}\")\n",
450
+ " import IPython.display as ipd\n",
451
+ " print(k + ' Synthesized: ' + text)\n",
452
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
453
+ " print(k + ' Reference:')\n",
454
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
455
+ ]
456
+ },
457
+ {
458
+ "cell_type": "markdown",
459
+ "id": "141e91b3",
460
+ "metadata": {},
461
+ "source": [
462
+ "### Speech expressiveness\n",
463
+ "\n",
464
+ "The following section recreates the samples shown in [Section 6](https://styletts2.github.io/#emo) of the demo page. The speaker reference used is `1221-135767-0014.wav`, which is unseen during training. \n",
465
+ "\n",
466
+ "#### With `embedding_scale=1`\n",
467
+ "This is the classifier-free guidance scale. The higher the scale, the more conditional the style is to the input text and hence more emotional.\n",
468
+ "\n"
469
+ ]
470
+ },
471
+ {
472
+ "cell_type": "code",
473
+ "execution_count": null,
474
+ "id": "81addda4",
475
+ "metadata": {},
476
+ "outputs": [],
477
+ "source": [
478
+ "ref_s = compute_style(\"Demo/reference_audio/1221-135767-0014.wav\")"
479
+ ]
480
+ },
481
+ {
482
+ "cell_type": "code",
483
+ "execution_count": null,
484
+ "id": "be1b2a11",
485
+ "metadata": {},
486
+ "outputs": [],
487
+ "source": [
488
+ "texts = {}\n",
489
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
490
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
491
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
492
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
493
+ "\n",
494
+ "for k,v in texts.items():\n",
495
+ " wav = inference(v, ref_s, diffusion_steps=10, alpha=0.3, beta=0.7, embedding_scale=1)\n",
496
+ " print(k + \": \")\n",
497
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
498
+ ]
499
+ },
500
+ {
501
+ "cell_type": "markdown",
502
+ "id": "96d262b8",
503
+ "metadata": {},
504
+ "source": [
505
+ "#### With `embedding_scale=2`"
506
+ ]
507
+ },
508
+ {
509
+ "cell_type": "code",
510
+ "execution_count": null,
511
+ "id": "3e7d40b4",
512
+ "metadata": {},
513
+ "outputs": [],
514
+ "source": [
515
+ "texts = {}\n",
516
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
517
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
518
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
519
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
520
+ "\n",
521
+ "for k,v in texts.items():\n",
522
+ " noise = torch.randn(1,1,256).to(device)\n",
523
+ " wav = inference(v, ref_s, diffusion_steps=10, alpha=0.3, beta=0.7, embedding_scale=2)\n",
524
+ " print(k + \": \")\n",
525
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
526
+ ]
527
+ },
528
+ {
529
+ "cell_type": "markdown",
530
+ "id": "402b2bd6",
531
+ "metadata": {},
532
+ "source": [
533
+ "#### With `embedding_scale=2, alpha = 0.5, beta = 0.9`\n",
534
+ "`alpha` and `beta` is the factor to determine much we use the style sampled based on the text instead of the reference. The higher the value of `alpha` and `beta`, the more suitable the style it is to the text but less similar to the reference. Using higher beta makes the synthesized speech more emotional, at the cost of lower similarity to the reference. `alpha` determines the timbre of the speaker while `beta` determines the prosody. "
535
+ ]
536
+ },
537
+ {
538
+ "cell_type": "code",
539
+ "execution_count": null,
540
+ "id": "599de5d5",
541
+ "metadata": {},
542
+ "outputs": [],
543
+ "source": [
544
+ "texts = {}\n",
545
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
546
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
547
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
548
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
549
+ "\n",
550
+ "for k,v in texts.items():\n",
551
+ " noise = torch.randn(1,1,256).to(device)\n",
552
+ " wav = inference(v, ref_s, diffusion_steps=10, alpha=0.5, beta=0.9, embedding_scale=2)\n",
553
+ " print(k + \": \")\n",
554
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
555
+ ]
556
+ },
557
+ {
558
+ "cell_type": "markdown",
559
+ "id": "48548866",
560
+ "metadata": {},
561
+ "source": [
562
+ "### Zero-shot speaker adaptation\n",
563
+ "This section recreates the \"Acoustic Environment Maintenance\" and \"Speaker’s Emotion Maintenance\" demo in [Section 4](https://styletts2.github.io/#libri) of the demo page. You can compare the generated samples to popular zero-shot TTS models like Vall-E. Note that the model was trained only on LibriTTS, which is about 250 times fewer data compared to those used to trian Vall-E with similar or better effect for these maintainance. "
564
+ ]
565
+ },
566
+ {
567
+ "cell_type": "markdown",
568
+ "id": "23e81572",
569
+ "metadata": {},
570
+ "source": [
571
+ "#### Acoustic Environment Maintenance\n",
572
+ "\n",
573
+ "Since we want to maintain the acoustic environment in the speaker (timbre), we set `alpha = 0` to make the speaker as closer to the reference as possible while only changing the prosody according to the text. "
574
+ ]
575
+ },
576
+ {
577
+ "cell_type": "code",
578
+ "execution_count": null,
579
+ "id": "8087bccb",
580
+ "metadata": {},
581
+ "outputs": [],
582
+ "source": [
583
+ "reference_dicts = {}\n",
584
+ "# format: (path, text)\n",
585
+ "reference_dicts['3'] = (\"Demo/reference_audio/3.wav\", \"As friends thing I definitely I've got more male friends.\")\n",
586
+ "reference_dicts['4'] = (\"Demo/reference_audio/4.wav\", \"Everything is run by computer but you got to know how to think before you can do a computer.\")\n",
587
+ "reference_dicts['5'] = (\"Demo/reference_audio/5.wav\", \"Then out in LA you guys got a whole another ball game within California to worry about.\")"
588
+ ]
589
+ },
590
+ {
591
+ "cell_type": "code",
592
+ "execution_count": null,
593
+ "id": "1e99c200",
594
+ "metadata": {},
595
+ "outputs": [],
596
+ "source": [
597
+ "noise = torch.randn(1,1,256).to(device)\n",
598
+ "for k, v in reference_dicts.items():\n",
599
+ " path, text = v\n",
600
+ " ref_s = compute_style(path)\n",
601
+ " start = time.time()\n",
602
+ " wav = inference(text, ref_s, alpha=0.0, beta=0.5, diffusion_steps=5, embedding_scale=1)\n",
603
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
604
+ " print(f\"RTF = {rtf:5f}\")\n",
605
+ " import IPython.display as ipd\n",
606
+ " print('Synthesized: ' + text)\n",
607
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
608
+ " print('Reference:')\n",
609
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
610
+ ]
611
+ },
612
+ {
613
+ "cell_type": "markdown",
614
+ "id": "7d56505d",
615
+ "metadata": {},
616
+ "source": [
617
+ "#### Speaker’s Emotion Maintenance\n",
618
+ "\n",
619
+ "Since we want to maintain the emotion in the speaker (prosody), we set `beta = 0.1` to make the speaker as closer to the reference as possible while having some diversity thruogh the slight timbre change."
620
+ ]
621
+ },
622
+ {
623
+ "cell_type": "code",
624
+ "execution_count": null,
625
+ "id": "f90179e7",
626
+ "metadata": {},
627
+ "outputs": [],
628
+ "source": [
629
+ "reference_dicts = {}\n",
630
+ "# format: (path, text)\n",
631
+ "reference_dicts['Anger'] = (\"Demo/reference_audio/anger.wav\", \"We have to reduce the number of plastic bags.\")\n",
632
+ "reference_dicts['Sleepy'] = (\"Demo/reference_audio/sleepy.wav\", \"We have to reduce the number of plastic bags.\")\n",
633
+ "reference_dicts['Amused'] = (\"Demo/reference_audio/amused.wav\", \"We have to reduce the number of plastic bags.\")\n",
634
+ "reference_dicts['Disgusted'] = (\"Demo/reference_audio/disgusted.wav\", \"We have to reduce the number of plastic bags.\")"
635
+ ]
636
+ },
637
+ {
638
+ "cell_type": "code",
639
+ "execution_count": null,
640
+ "id": "2e6bdfed",
641
+ "metadata": {},
642
+ "outputs": [],
643
+ "source": [
644
+ "noise = torch.randn(1,1,256).to(device)\n",
645
+ "for k, v in reference_dicts.items():\n",
646
+ " path, text = v\n",
647
+ " ref_s = compute_style(path)\n",
648
+ " start = time.time()\n",
649
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.1, diffusion_steps=10, embedding_scale=1)\n",
650
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
651
+ " print(f\"RTF = {rtf:5f}\")\n",
652
+ " import IPython.display as ipd\n",
653
+ " print(k + ' Synthesized: ' + text)\n",
654
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
655
+ " print(k + ' Reference:')\n",
656
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
657
+ ]
658
+ },
659
+ {
660
+ "cell_type": "markdown",
661
+ "id": "37ae3963",
662
+ "metadata": {},
663
+ "source": [
664
+ "### Longform Narration\n",
665
+ "\n",
666
+ "This section includes basic implementation of Algorithm 1 in the paper for consistent longform audio generation. The example passage is taken from [Section 5](https://styletts2.github.io/#long) of the demo page."
667
+ ]
668
+ },
669
+ {
670
+ "cell_type": "code",
671
+ "execution_count": null,
672
+ "id": "f12a716b",
673
+ "metadata": {},
674
+ "outputs": [],
675
+ "source": [
676
+ "passage = '''If the supply of fruit is greater than the family needs, it may be made a source of income by sending the fresh fruit to the market if there is one near enough, or by preserving, canning, and making jelly for sale. To make such an enterprise a success the fruit and work must be first class. There is magic in the word \"Homemade,\" when the product appeals to the eye and the palate; but many careless and incompetent people have found to their sorrow that this word has not magic enough to float inferior goods on the market. As a rule large canning and preserving establishments are clean and have the best appliances, and they employ chemists and skilled labor. The home product must be very good to compete with the attractive goods that are sent out from such establishments. Yet for first class home made products there is a market in all large cities. All first-class grocers have customers who purchase such goods.'''"
677
+ ]
678
+ },
679
+ {
680
+ "cell_type": "code",
681
+ "execution_count": null,
682
+ "id": "a1a38079",
683
+ "metadata": {},
684
+ "outputs": [],
685
+ "source": [
686
+ "def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1):\n",
687
+ " text = text.strip()\n",
688
+ " ps = global_phonemizer.phonemize([text])\n",
689
+ " ps = word_tokenize(ps[0])\n",
690
+ " ps = ' '.join(ps)\n",
691
+ " ps = ps.replace('``', '\"')\n",
692
+ " ps = ps.replace(\"''\", '\"')\n",
693
+ "\n",
694
+ " tokens = textclenaer(ps)\n",
695
+ " tokens.insert(0, 0)\n",
696
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
697
+ " \n",
698
+ " with torch.no_grad():\n",
699
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
700
+ " text_mask = length_to_mask(input_lengths).to(device)\n",
701
+ "\n",
702
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
703
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
704
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2) \n",
705
+ "\n",
706
+ " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device), \n",
707
+ " embedding=bert_dur,\n",
708
+ " embedding_scale=embedding_scale,\n",
709
+ " features=ref_s, # reference from the same speaker as the embedding\n",
710
+ " num_steps=diffusion_steps).squeeze(1)\n",
711
+ " \n",
712
+ " if s_prev is not None:\n",
713
+ " # convex combination of previous and current style\n",
714
+ " s_pred = t * s_prev + (1 - t) * s_pred\n",
715
+ " \n",
716
+ " s = s_pred[:, 128:]\n",
717
+ " ref = s_pred[:, :128]\n",
718
+ " \n",
719
+ " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
720
+ " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
721
+ "\n",
722
+ " s_pred = torch.cat([ref, s], dim=-1)\n",
723
+ "\n",
724
+ " d = model.predictor.text_encoder(d_en, \n",
725
+ " s, input_lengths, text_mask)\n",
726
+ "\n",
727
+ " x, _ = model.predictor.lstm(d)\n",
728
+ " duration = model.predictor.duration_proj(x)\n",
729
+ "\n",
730
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
731
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
732
+ "\n",
733
+ "\n",
734
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
735
+ " c_frame = 0\n",
736
+ " for i in range(pred_aln_trg.size(0)):\n",
737
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
738
+ " c_frame += int(pred_dur[i].data)\n",
739
+ "\n",
740
+ " # encode prosody\n",
741
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
742
+ " if model_params.decoder.type == \"hifigan\":\n",
743
+ " asr_new = torch.zeros_like(en)\n",
744
+ " asr_new[:, :, 0] = en[:, :, 0]\n",
745
+ " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
746
+ " en = asr_new\n",
747
+ "\n",
748
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
749
+ "\n",
750
+ " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
751
+ " if model_params.decoder.type == \"hifigan\":\n",
752
+ " asr_new = torch.zeros_like(asr)\n",
753
+ " asr_new[:, :, 0] = asr[:, :, 0]\n",
754
+ " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
755
+ " asr = asr_new\n",
756
+ "\n",
757
+ " out = model.decoder(asr, \n",
758
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
759
+ " \n",
760
+ " \n",
761
+ " return out.squeeze().cpu().numpy()[..., :-100], s_pred # weird pulse at the end of the model, need to be fixed later"
762
+ ]
763
+ },
764
+ {
765
+ "cell_type": "code",
766
+ "execution_count": null,
767
+ "id": "e9088f7a",
768
+ "metadata": {},
769
+ "outputs": [],
770
+ "source": [
771
+ "# unseen speaker\n",
772
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
773
+ "s_ref = compute_style(path)\n",
774
+ "sentences = passage.split('.') # simple split by comma\n",
775
+ "wavs = []\n",
776
+ "s_prev = None\n",
777
+ "for text in sentences:\n",
778
+ " if text.strip() == \"\": continue\n",
779
+ " text += '.' # add it back\n",
780
+ " \n",
781
+ " wav, s_prev = LFinference(text, \n",
782
+ " s_prev, \n",
783
+ " s_ref, \n",
784
+ " alpha = 0.3, \n",
785
+ " beta = 0.9, # make it more suitable for the text\n",
786
+ " t = 0.7, \n",
787
+ " diffusion_steps=10, embedding_scale=1.5)\n",
788
+ " wavs.append(wav)\n",
789
+ "print('Synthesized: ')\n",
790
+ "display(ipd.Audio(np.concatenate(wavs), rate=24000, normalize=False))\n",
791
+ "print('Reference: ')\n",
792
+ "display(ipd.Audio(path, rate=24000, normalize=False))"
793
+ ]
794
+ },
795
+ {
796
+ "cell_type": "markdown",
797
+ "id": "7517b657",
798
+ "metadata": {},
799
+ "source": [
800
+ "### Style Transfer\n",
801
+ "\n",
802
+ "The following section demostrates the style transfer capacity for unseen speakers in [Section 6](https://styletts2.github.io/#emo) of the demo page. For this, we set `alpha=0.5, beta = 0.9` for the most pronounced effects (mostly using the sampled style). "
803
+ ]
804
+ },
805
+ {
806
+ "cell_type": "code",
807
+ "execution_count": null,
808
+ "id": "ed95d0f7",
809
+ "metadata": {},
810
+ "outputs": [],
811
+ "source": [
812
+ "def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):\n",
813
+ " text = text.strip()\n",
814
+ " ps = global_phonemizer.phonemize([text])\n",
815
+ " ps = word_tokenize(ps[0])\n",
816
+ " ps = ' '.join(ps)\n",
817
+ "\n",
818
+ " tokens = textclenaer(ps)\n",
819
+ " tokens.insert(0, 0)\n",
820
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
821
+ " \n",
822
+ " ref_text = ref_text.strip()\n",
823
+ " ps = global_phonemizer.phonemize([ref_text])\n",
824
+ " ps = word_tokenize(ps[0])\n",
825
+ " ps = ' '.join(ps)\n",
826
+ "\n",
827
+ " ref_tokens = textclenaer(ps)\n",
828
+ " ref_tokens.insert(0, 0)\n",
829
+ " ref_tokens = torch.LongTensor(ref_tokens).to(device).unsqueeze(0)\n",
830
+ " \n",
831
+ " \n",
832
+ " with torch.no_grad():\n",
833
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
834
+ " text_mask = length_to_mask(input_lengths).to(device)\n",
835
+ "\n",
836
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
837
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
838
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2) \n",
839
+ " \n",
840
+ " ref_input_lengths = torch.LongTensor([ref_tokens.shape[-1]]).to(device)\n",
841
+ " ref_text_mask = length_to_mask(ref_input_lengths).to(device)\n",
842
+ " ref_bert_dur = model.bert(ref_tokens, attention_mask=(~ref_text_mask).int())\n",
843
+ " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device), \n",
844
+ " embedding=bert_dur,\n",
845
+ " embedding_scale=embedding_scale,\n",
846
+ " features=ref_s, # reference from the same speaker as the embedding\n",
847
+ " num_steps=diffusion_steps).squeeze(1)\n",
848
+ "\n",
849
+ "\n",
850
+ " s = s_pred[:, 128:]\n",
851
+ " ref = s_pred[:, :128]\n",
852
+ "\n",
853
+ " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
854
+ " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
855
+ "\n",
856
+ " d = model.predictor.text_encoder(d_en, \n",
857
+ " s, input_lengths, text_mask)\n",
858
+ "\n",
859
+ " x, _ = model.predictor.lstm(d)\n",
860
+ " duration = model.predictor.duration_proj(x)\n",
861
+ "\n",
862
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
863
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
864
+ "\n",
865
+ "\n",
866
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
867
+ " c_frame = 0\n",
868
+ " for i in range(pred_aln_trg.size(0)):\n",
869
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
870
+ " c_frame += int(pred_dur[i].data)\n",
871
+ "\n",
872
+ " # encode prosody\n",
873
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
874
+ " if model_params.decoder.type == \"hifigan\":\n",
875
+ " asr_new = torch.zeros_like(en)\n",
876
+ " asr_new[:, :, 0] = en[:, :, 0]\n",
877
+ " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
878
+ " en = asr_new\n",
879
+ "\n",
880
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
881
+ "\n",
882
+ " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
883
+ " if model_params.decoder.type == \"hifigan\":\n",
884
+ " asr_new = torch.zeros_like(asr)\n",
885
+ " asr_new[:, :, 0] = asr[:, :, 0]\n",
886
+ " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
887
+ " asr = asr_new\n",
888
+ "\n",
889
+ " out = model.decoder(asr, \n",
890
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
891
+ " \n",
892
+ " \n",
893
+ " return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later"
894
+ ]
895
+ },
896
+ {
897
+ "cell_type": "code",
898
+ "execution_count": null,
899
+ "id": "ec3f0da4",
900
+ "metadata": {},
901
+ "outputs": [],
902
+ "source": [
903
+ "# reference texts to sample styles\n",
904
+ "\n",
905
+ "ref_texts = {}\n",
906
+ "ref_texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
907
+ "ref_texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
908
+ "ref_texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
909
+ "ref_texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\""
910
+ ]
911
+ },
912
+ {
913
+ "cell_type": "code",
914
+ "execution_count": null,
915
+ "id": "6d0a3825",
916
+ "metadata": {
917
+ "scrolled": false
918
+ },
919
+ "outputs": [],
920
+ "source": [
921
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
922
+ "s_ref = compute_style(path)\n",
923
+ "\n",
924
+ "text = \"Yea, his honourable worship is within, but he hath a godly minister or two with him, and likewise a leech.\"\n",
925
+ "for k,v in ref_texts.items():\n",
926
+ " wav = STinference(text, s_ref, v, diffusion_steps=10, alpha=0.5, beta=0.9, embedding_scale=1.5)\n",
927
+ " print(k + \": \")\n",
928
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
929
+ ]
930
+ },
931
+ {
932
+ "cell_type": "markdown",
933
+ "id": "6750aed9",
934
+ "metadata": {},
935
+ "source": [
936
+ "### Speech diversity\n",
937
+ "\n",
938
+ "This section reproduces samples in [Section 7](https://styletts2.github.io/#var) of the demo page. \n",
939
+ "\n",
940
+ "`alpha` and `beta` determine the diversity of the synthesized speech. There are two extreme cases:\n",
941
+ "- If `alpha = 1` and `beta = 1`, the synthesized speech sounds the most dissimilar to the reference speaker, but it is also the most diverse (each time you synthesize a speech it will be totally different). \n",
942
+ "- If `alpha = 0` and `beta = 0`, the synthesized speech sounds the most siimlar to the reference speaker, but it is deterministic (i.e., the sampled style is not used for speech synthesis). \n"
943
+ ]
944
+ },
945
+ {
946
+ "cell_type": "markdown",
947
+ "id": "f6ae0aa5",
948
+ "metadata": {},
949
+ "source": [
950
+ "#### Default setting (`alpha = 0.3, beta=0.7`)\n",
951
+ "This setting uses 70% of the reference timbre and 30% of the reference prosody and use the diffusion model to sample them based on the text. "
952
+ ]
953
+ },
954
+ {
955
+ "cell_type": "code",
956
+ "execution_count": null,
957
+ "id": "36dc0148",
958
+ "metadata": {},
959
+ "outputs": [],
960
+ "source": [
961
+ "# unseen speaker\n",
962
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
963
+ "ref_s = compute_style(path)\n",
964
+ "\n",
965
+ "text = \"How much variation is there?\"\n",
966
+ "for _ in range(5):\n",
967
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0.3, beta=0.7, embedding_scale=1)\n",
968
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
969
+ ]
970
+ },
971
+ {
972
+ "cell_type": "markdown",
973
+ "id": "bf9ef421",
974
+ "metadata": {},
975
+ "source": [
976
+ "#### Less diverse setting (`alpha = 0.1, beta=0.3`)\n",
977
+ "This setting uses 90% of the reference timbre and 70% of the reference prosody. This makes it more similar to the reference speaker at cost of less diverse samples. "
978
+ ]
979
+ },
980
+ {
981
+ "cell_type": "code",
982
+ "execution_count": null,
983
+ "id": "9ba406bd",
984
+ "metadata": {},
985
+ "outputs": [],
986
+ "source": [
987
+ "# unseen speaker\n",
988
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
989
+ "ref_s = compute_style(path)\n",
990
+ "\n",
991
+ "text = \"How much variation is there?\"\n",
992
+ "for _ in range(5):\n",
993
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0.1, beta=0.3, embedding_scale=1)\n",
994
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
995
+ ]
996
+ },
997
+ {
998
+ "cell_type": "markdown",
999
+ "id": "a38fe464",
1000
+ "metadata": {},
1001
+ "source": [
1002
+ "#### More diverse setting (`alpha = 0.5, beta=0.95`)\n",
1003
+ "This setting uses 50% of the reference timbre and 5% of the reference prosody (so it uses 100% of the sampled prosody, which makes it more diverse), but this makes it more dissimilar to the reference speaker. "
1004
+ ]
1005
+ },
1006
+ {
1007
+ "cell_type": "code",
1008
+ "execution_count": null,
1009
+ "id": "5f25bf94",
1010
+ "metadata": {},
1011
+ "outputs": [],
1012
+ "source": [
1013
+ "# unseen speaker\n",
1014
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
1015
+ "ref_s = compute_style(path)\n",
1016
+ "\n",
1017
+ "text = \"How much variation is there?\"\n",
1018
+ "for _ in range(5):\n",
1019
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0.5, beta=0.95, embedding_scale=1)\n",
1020
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
1021
+ ]
1022
+ },
1023
+ {
1024
+ "cell_type": "markdown",
1025
+ "id": "21c3a071",
1026
+ "metadata": {},
1027
+ "source": [
1028
+ "#### Extreme setting (`alpha = 1, beta=1`)\n",
1029
+ "This setting uses 0% of the reference timbre and prosody and use the diffusion model to sample the entire style. This makes the speaker very dissimilar to the reference speaker. "
1030
+ ]
1031
+ },
1032
+ {
1033
+ "cell_type": "code",
1034
+ "execution_count": null,
1035
+ "id": "fff8bab1",
1036
+ "metadata": {},
1037
+ "outputs": [],
1038
+ "source": [
1039
+ "# unseen speaker\n",
1040
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
1041
+ "ref_s = compute_style(path)\n",
1042
+ "\n",
1043
+ "text = \"How much variation is there?\"\n",
1044
+ "for _ in range(5):\n",
1045
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=1, beta=1, embedding_scale=1)\n",
1046
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
1047
+ ]
1048
+ },
1049
+ {
1050
+ "cell_type": "markdown",
1051
+ "id": "a8741e5a",
1052
+ "metadata": {},
1053
+ "source": [
1054
+ "#### No variation (`alpha = 0, beta=0`)\n",
1055
+ "This setting uses 0% of the reference timbre and prosody and use the diffusion model to sample the entire style. This makes the speaker very similar to the reference speaker, but there is no variation. "
1056
+ ]
1057
+ },
1058
+ {
1059
+ "cell_type": "code",
1060
+ "execution_count": null,
1061
+ "id": "e55dd281",
1062
+ "metadata": {},
1063
+ "outputs": [],
1064
+ "source": [
1065
+ "# unseen speaker\n",
1066
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
1067
+ "ref_s = compute_style(path)\n",
1068
+ "\n",
1069
+ "text = \"How much variation is there?\"\n",
1070
+ "for _ in range(5):\n",
1071
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0, beta=0, embedding_scale=1)\n",
1072
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
1073
+ ]
1074
+ },
1075
+ {
1076
+ "cell_type": "markdown",
1077
+ "id": "d5e86423",
1078
+ "metadata": {},
1079
+ "source": [
1080
+ "### Extra fun!\n",
1081
+ "\n",
1082
+ "Here we clone some of the authors' voice of the StyleTTS 2 papers with a few seconds of the recording in the wild. None of the voices is in the dataset and all authors agreed to have their voices cloned here."
1083
+ ]
1084
+ },
1085
+ {
1086
+ "cell_type": "code",
1087
+ "execution_count": null,
1088
+ "id": "6f558314",
1089
+ "metadata": {},
1090
+ "outputs": [],
1091
+ "source": [
1092
+ "text = ''' StyleTTS 2 is a text to speech model that leverages style diffusion and adversarial training with large speech language models to achieve human level text to speech synthesis. '''"
1093
+ ]
1094
+ },
1095
+ {
1096
+ "cell_type": "code",
1097
+ "execution_count": null,
1098
+ "id": "caa5747c",
1099
+ "metadata": {},
1100
+ "outputs": [],
1101
+ "source": [
1102
+ "reference_dicts = {}\n",
1103
+ "reference_dicts['Yinghao'] = \"Demo/reference_audio/Yinghao.wav\"\n",
1104
+ "reference_dicts['Gavin'] = \"Demo/reference_audio/Gavin.wav\"\n",
1105
+ "reference_dicts['Vinay'] = \"Demo/reference_audio/Vinay.wav\"\n",
1106
+ "reference_dicts['Nima'] = \"Demo/reference_audio/Nima.wav\""
1107
+ ]
1108
+ },
1109
+ {
1110
+ "cell_type": "code",
1111
+ "execution_count": null,
1112
+ "id": "44a4cea1",
1113
+ "metadata": {
1114
+ "scrolled": false
1115
+ },
1116
+ "outputs": [],
1117
+ "source": [
1118
+ "start = time.time()\n",
1119
+ "noise = torch.randn(1,1,256).to(device)\n",
1120
+ "for k, path in reference_dicts.items():\n",
1121
+ " ref_s = compute_style(path)\n",
1122
+ " \n",
1123
+ " wav = inference(text, ref_s, alpha=0.1, beta=0.5, diffusion_steps=5, embedding_scale=1)\n",
1124
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
1125
+ " print('Speaker: ' + k)\n",
1126
+ " import IPython.display as ipd\n",
1127
+ " print('Synthesized:')\n",
1128
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
1129
+ " print('Reference:')\n",
1130
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
1131
+ ]
1132
+ }
1133
+ ],
1134
+ "metadata": {
1135
+ "kernelspec": {
1136
+ "display_name": "NLP",
1137
+ "language": "python",
1138
+ "name": "nlp"
1139
+ },
1140
+ "language_info": {
1141
+ "codemirror_mode": {
1142
+ "name": "ipython",
1143
+ "version": 3
1144
+ },
1145
+ "file_extension": ".py",
1146
+ "mimetype": "text/x-python",
1147
+ "name": "python",
1148
+ "nbconvert_exporter": "python",
1149
+ "pygments_lexer": "ipython3",
1150
+ "version": "3.9.7"
1151
+ }
1152
+ },
1153
+ "nbformat": 4,
1154
+ "nbformat_minor": 5
1155
+ }
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Aaron (Yinghao) Li
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Modules/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
Modules/discriminators.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, AvgPool1d, Conv2d
5
+ from torch.nn.utils import weight_norm, spectral_norm
6
+
7
+ from .utils import get_padding
8
+
9
+ LRELU_SLOPE = 0.1
10
+
11
+ def stft(x, fft_size, hop_size, win_length, window):
12
+ """Perform STFT and convert to magnitude spectrogram.
13
+ Args:
14
+ x (Tensor): Input signal tensor (B, T).
15
+ fft_size (int): FFT size.
16
+ hop_size (int): Hop size.
17
+ win_length (int): Window length.
18
+ window (str): Window function type.
19
+ Returns:
20
+ Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
21
+ """
22
+ x_stft = torch.stft(x, fft_size, hop_size, win_length, window,
23
+ return_complex=True)
24
+ real = x_stft[..., 0]
25
+ imag = x_stft[..., 1]
26
+
27
+ return torch.abs(x_stft).transpose(2, 1)
28
+
29
+ class SpecDiscriminator(nn.Module):
30
+ """docstring for Discriminator."""
31
+
32
+ def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
33
+ super(SpecDiscriminator, self).__init__()
34
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
35
+ self.fft_size = fft_size
36
+ self.shift_size = shift_size
37
+ self.win_length = win_length
38
+ self.window = getattr(torch, window)(win_length)
39
+ self.discriminators = nn.ModuleList([
40
+ norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
41
+ norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
42
+ norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
43
+ norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
44
+ norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1,1), padding=(1, 1))),
45
+ ])
46
+
47
+ self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
48
+
49
+ def forward(self, y):
50
+
51
+ fmap = []
52
+ y = y.squeeze(1)
53
+ y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.get_device()))
54
+ y = y.unsqueeze(1)
55
+ for i, d in enumerate(self.discriminators):
56
+ y = d(y)
57
+ y = F.leaky_relu(y, LRELU_SLOPE)
58
+ fmap.append(y)
59
+
60
+ y = self.out(y)
61
+ fmap.append(y)
62
+
63
+ return torch.flatten(y, 1, -1), fmap
64
+
65
+ class MultiResSpecDiscriminator(torch.nn.Module):
66
+
67
+ def __init__(self,
68
+ fft_sizes=[1024, 2048, 512],
69
+ hop_sizes=[120, 240, 50],
70
+ win_lengths=[600, 1200, 240],
71
+ window="hann_window"):
72
+
73
+ super(MultiResSpecDiscriminator, self).__init__()
74
+ self.discriminators = nn.ModuleList([
75
+ SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
76
+ SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
77
+ SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)
78
+ ])
79
+
80
+ def forward(self, y, y_hat):
81
+ y_d_rs = []
82
+ y_d_gs = []
83
+ fmap_rs = []
84
+ fmap_gs = []
85
+ for i, d in enumerate(self.discriminators):
86
+ y_d_r, fmap_r = d(y)
87
+ y_d_g, fmap_g = d(y_hat)
88
+ y_d_rs.append(y_d_r)
89
+ fmap_rs.append(fmap_r)
90
+ y_d_gs.append(y_d_g)
91
+ fmap_gs.append(fmap_g)
92
+
93
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
94
+
95
+
96
+ class DiscriminatorP(torch.nn.Module):
97
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
98
+ super(DiscriminatorP, self).__init__()
99
+ self.period = period
100
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
101
+ self.convs = nn.ModuleList([
102
+ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
103
+ norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
104
+ norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
105
+ norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
106
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
107
+ ])
108
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
109
+
110
+ def forward(self, x):
111
+ fmap = []
112
+
113
+ # 1d to 2d
114
+ b, c, t = x.shape
115
+ if t % self.period != 0: # pad first
116
+ n_pad = self.period - (t % self.period)
117
+ x = F.pad(x, (0, n_pad), "reflect")
118
+ t = t + n_pad
119
+ x = x.view(b, c, t // self.period, self.period)
120
+
121
+ for l in self.convs:
122
+ x = l(x)
123
+ x = F.leaky_relu(x, LRELU_SLOPE)
124
+ fmap.append(x)
125
+ x = self.conv_post(x)
126
+ fmap.append(x)
127
+ x = torch.flatten(x, 1, -1)
128
+
129
+ return x, fmap
130
+
131
+
132
+ class MultiPeriodDiscriminator(torch.nn.Module):
133
+ def __init__(self):
134
+ super(MultiPeriodDiscriminator, self).__init__()
135
+ self.discriminators = nn.ModuleList([
136
+ DiscriminatorP(2),
137
+ DiscriminatorP(3),
138
+ DiscriminatorP(5),
139
+ DiscriminatorP(7),
140
+ DiscriminatorP(11),
141
+ ])
142
+
143
+ def forward(self, y, y_hat):
144
+ y_d_rs = []
145
+ y_d_gs = []
146
+ fmap_rs = []
147
+ fmap_gs = []
148
+ for i, d in enumerate(self.discriminators):
149
+ y_d_r, fmap_r = d(y)
150
+ y_d_g, fmap_g = d(y_hat)
151
+ y_d_rs.append(y_d_r)
152
+ fmap_rs.append(fmap_r)
153
+ y_d_gs.append(y_d_g)
154
+ fmap_gs.append(fmap_g)
155
+
156
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
157
+
158
+ class WavLMDiscriminator(nn.Module):
159
+ """docstring for Discriminator."""
160
+
161
+ def __init__(self, slm_hidden=768,
162
+ slm_layers=13,
163
+ initial_channel=64,
164
+ use_spectral_norm=False):
165
+ super(WavLMDiscriminator, self).__init__()
166
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
167
+ self.pre = norm_f(Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0))
168
+
169
+ self.convs = nn.ModuleList([
170
+ norm_f(nn.Conv1d(initial_channel, initial_channel * 2, kernel_size=5, padding=2)),
171
+ norm_f(nn.Conv1d(initial_channel * 2, initial_channel * 4, kernel_size=5, padding=2)),
172
+ norm_f(nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)),
173
+ ])
174
+
175
+ self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))
176
+
177
+ def forward(self, x):
178
+ x = self.pre(x)
179
+
180
+ fmap = []
181
+ for l in self.convs:
182
+ x = l(x)
183
+ x = F.leaky_relu(x, LRELU_SLOPE)
184
+ fmap.append(x)
185
+ x = self.conv_post(x)
186
+ x = torch.flatten(x, 1, -1)
187
+
188
+ return x
Modules/hifigan.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6
+ from .utils import init_weights, get_padding
7
+
8
+ import math
9
+ import random
10
+ import numpy as np
11
+
12
+ LRELU_SLOPE = 0.1
13
+
14
+ class AdaIN1d(nn.Module):
15
+ def __init__(self, style_dim, num_features):
16
+ super().__init__()
17
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
18
+ self.fc = nn.Linear(style_dim, num_features*2)
19
+
20
+ def forward(self, x, s):
21
+ h = self.fc(s)
22
+ h = h.view(h.size(0), h.size(1), 1)
23
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
24
+ return (1 + gamma) * self.norm(x) + beta
25
+
26
+ class AdaINResBlock1(torch.nn.Module):
27
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
28
+ super(AdaINResBlock1, self).__init__()
29
+ self.convs1 = nn.ModuleList([
30
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
31
+ padding=get_padding(kernel_size, dilation[0]))),
32
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
33
+ padding=get_padding(kernel_size, dilation[1]))),
34
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
35
+ padding=get_padding(kernel_size, dilation[2])))
36
+ ])
37
+ self.convs1.apply(init_weights)
38
+
39
+ self.convs2 = nn.ModuleList([
40
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
41
+ padding=get_padding(kernel_size, 1))),
42
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
43
+ padding=get_padding(kernel_size, 1))),
44
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
45
+ padding=get_padding(kernel_size, 1)))
46
+ ])
47
+ self.convs2.apply(init_weights)
48
+
49
+ self.adain1 = nn.ModuleList([
50
+ AdaIN1d(style_dim, channels),
51
+ AdaIN1d(style_dim, channels),
52
+ AdaIN1d(style_dim, channels),
53
+ ])
54
+
55
+ self.adain2 = nn.ModuleList([
56
+ AdaIN1d(style_dim, channels),
57
+ AdaIN1d(style_dim, channels),
58
+ AdaIN1d(style_dim, channels),
59
+ ])
60
+
61
+ self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
62
+ self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
63
+
64
+
65
+ def forward(self, x, s):
66
+ for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
67
+ xt = n1(x, s)
68
+ xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
69
+ xt = c1(xt)
70
+ xt = n2(xt, s)
71
+ xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
72
+ xt = c2(xt)
73
+ x = xt + x
74
+ return x
75
+
76
+ def remove_weight_norm(self):
77
+ for l in self.convs1:
78
+ remove_weight_norm(l)
79
+ for l in self.convs2:
80
+ remove_weight_norm(l)
81
+
82
+ class SineGen(torch.nn.Module):
83
+ """ Definition of sine generator
84
+ SineGen(samp_rate, harmonic_num = 0,
85
+ sine_amp = 0.1, noise_std = 0.003,
86
+ voiced_threshold = 0,
87
+ flag_for_pulse=False)
88
+ samp_rate: sampling rate in Hz
89
+ harmonic_num: number of harmonic overtones (default 0)
90
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
91
+ noise_std: std of Gaussian noise (default 0.003)
92
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
93
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
94
+ Note: when flag_for_pulse is True, the first time step of a voiced
95
+ segment is always sin(np.pi) or cos(0)
96
+ """
97
+
98
+ def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
99
+ sine_amp=0.1, noise_std=0.003,
100
+ voiced_threshold=0,
101
+ flag_for_pulse=False):
102
+ super(SineGen, self).__init__()
103
+ self.sine_amp = sine_amp
104
+ self.noise_std = noise_std
105
+ self.harmonic_num = harmonic_num
106
+ self.dim = self.harmonic_num + 1
107
+ self.sampling_rate = samp_rate
108
+ self.voiced_threshold = voiced_threshold
109
+ self.flag_for_pulse = flag_for_pulse
110
+ self.upsample_scale = upsample_scale
111
+
112
+ def _f02uv(self, f0):
113
+ # generate uv signal
114
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
115
+ return uv
116
+
117
+ def _f02sine(self, f0_values):
118
+ """ f0_values: (batchsize, length, dim)
119
+ where dim indicates fundamental tone and overtones
120
+ """
121
+ # convert to F0 in rad. The interger part n can be ignored
122
+ # because 2 * np.pi * n doesn't affect phase
123
+ rad_values = (f0_values / self.sampling_rate) % 1
124
+
125
+ # initial phase noise (no noise for fundamental component)
126
+ rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
127
+ device=f0_values.device)
128
+ rand_ini[:, 0] = 0
129
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
130
+
131
+ # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
132
+ if not self.flag_for_pulse:
133
+ # # for normal case
134
+
135
+ # # To prevent torch.cumsum numerical overflow,
136
+ # # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
137
+ # # Buffer tmp_over_one_idx indicates the time step to add -1.
138
+ # # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
139
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
140
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
141
+ # cumsum_shift = torch.zeros_like(rad_values)
142
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
143
+
144
+ # phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
145
+ rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
146
+ scale_factor=1/self.upsample_scale,
147
+ mode="linear").transpose(1, 2)
148
+
149
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
150
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
151
+ # cumsum_shift = torch.zeros_like(rad_values)
152
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
153
+
154
+ phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
155
+ phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
156
+ scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
157
+ sines = torch.sin(phase)
158
+
159
+ else:
160
+ # If necessary, make sure that the first time step of every
161
+ # voiced segments is sin(pi) or cos(0)
162
+ # This is used for pulse-train generation
163
+
164
+ # identify the last time step in unvoiced segments
165
+ uv = self._f02uv(f0_values)
166
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
167
+ uv_1[:, -1, :] = 1
168
+ u_loc = (uv < 1) * (uv_1 > 0)
169
+
170
+ # get the instantanouse phase
171
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
172
+ # different batch needs to be processed differently
173
+ for idx in range(f0_values.shape[0]):
174
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
175
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
176
+ # stores the accumulation of i.phase within
177
+ # each voiced segments
178
+ tmp_cumsum[idx, :, :] = 0
179
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
180
+
181
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
182
+ # within the previous voiced segment.
183
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
184
+
185
+ # get the sines
186
+ sines = torch.cos(i_phase * 2 * np.pi)
187
+ return sines
188
+
189
+ def forward(self, f0):
190
+ """ sine_tensor, uv = forward(f0)
191
+ input F0: tensor(batchsize=1, length, dim=1)
192
+ f0 for unvoiced steps should be 0
193
+ output sine_tensor: tensor(batchsize=1, length, dim)
194
+ output uv: tensor(batchsize=1, length, 1)
195
+ """
196
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
197
+ device=f0.device)
198
+ # fundamental component
199
+ fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
200
+
201
+ # generate sine waveforms
202
+ sine_waves = self._f02sine(fn) * self.sine_amp
203
+
204
+ # generate uv signal
205
+ # uv = torch.ones(f0.shape)
206
+ # uv = uv * (f0 > self.voiced_threshold)
207
+ uv = self._f02uv(f0)
208
+
209
+ # noise: for unvoiced should be similar to sine_amp
210
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
211
+ # . for voiced regions is self.noise_std
212
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
213
+ noise = noise_amp * torch.randn_like(sine_waves)
214
+
215
+ # first: set the unvoiced part to 0 by uv
216
+ # then: additive noise
217
+ sine_waves = sine_waves * uv + noise
218
+ return sine_waves, uv, noise
219
+
220
+
221
+ class SourceModuleHnNSF(torch.nn.Module):
222
+ """ SourceModule for hn-nsf
223
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
224
+ add_noise_std=0.003, voiced_threshod=0)
225
+ sampling_rate: sampling_rate in Hz
226
+ harmonic_num: number of harmonic above F0 (default: 0)
227
+ sine_amp: amplitude of sine source signal (default: 0.1)
228
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
229
+ note that amplitude of noise in unvoiced is decided
230
+ by sine_amp
231
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
232
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
233
+ F0_sampled (batchsize, length, 1)
234
+ Sine_source (batchsize, length, 1)
235
+ noise_source (batchsize, length 1)
236
+ uv (batchsize, length, 1)
237
+ """
238
+
239
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
240
+ add_noise_std=0.003, voiced_threshod=0):
241
+ super(SourceModuleHnNSF, self).__init__()
242
+
243
+ self.sine_amp = sine_amp
244
+ self.noise_std = add_noise_std
245
+
246
+ # to produce sine waveforms
247
+ self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
248
+ sine_amp, add_noise_std, voiced_threshod)
249
+
250
+ # to merge source harmonics into a single excitation
251
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
252
+ self.l_tanh = torch.nn.Tanh()
253
+
254
+ def forward(self, x):
255
+ """
256
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
257
+ F0_sampled (batchsize, length, 1)
258
+ Sine_source (batchsize, length, 1)
259
+ noise_source (batchsize, length 1)
260
+ """
261
+ # source for harmonic branch
262
+ with torch.no_grad():
263
+ sine_wavs, uv, _ = self.l_sin_gen(x)
264
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
265
+
266
+ # source for noise branch, in the same shape as uv
267
+ noise = torch.randn_like(uv) * self.sine_amp / 3
268
+ return sine_merge, noise, uv
269
+ def padDiff(x):
270
+ return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)
271
+
272
+ class Generator(torch.nn.Module):
273
+ def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes):
274
+ super(Generator, self).__init__()
275
+ self.num_kernels = len(resblock_kernel_sizes)
276
+ self.num_upsamples = len(upsample_rates)
277
+ resblock = AdaINResBlock1
278
+
279
+ self.m_source = SourceModuleHnNSF(
280
+ sampling_rate=24000,
281
+ upsample_scale=np.prod(upsample_rates),
282
+ harmonic_num=8, voiced_threshod=10)
283
+
284
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
285
+ self.noise_convs = nn.ModuleList()
286
+ self.ups = nn.ModuleList()
287
+ self.noise_res = nn.ModuleList()
288
+
289
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
290
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
291
+
292
+ self.ups.append(weight_norm(ConvTranspose1d(upsample_initial_channel//(2**i),
293
+ upsample_initial_channel//(2**(i+1)),
294
+ k, u, padding=(u//2 + u%2), output_padding=u%2)))
295
+
296
+ if i + 1 < len(upsample_rates): #
297
+ stride_f0 = np.prod(upsample_rates[i + 1:])
298
+ self.noise_convs.append(Conv1d(
299
+ 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
300
+ self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim))
301
+ else:
302
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
303
+ self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim))
304
+
305
+ self.resblocks = nn.ModuleList()
306
+
307
+ self.alphas = nn.ParameterList()
308
+ self.alphas.append(nn.Parameter(torch.ones(1, upsample_initial_channel, 1)))
309
+
310
+ for i in range(len(self.ups)):
311
+ ch = upsample_initial_channel//(2**(i+1))
312
+ self.alphas.append(nn.Parameter(torch.ones(1, ch, 1)))
313
+
314
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
315
+ self.resblocks.append(resblock(ch, k, d, style_dim))
316
+
317
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
318
+ self.ups.apply(init_weights)
319
+ self.conv_post.apply(init_weights)
320
+
321
+ def forward(self, x, s, f0):
322
+
323
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
324
+
325
+ har_source, noi_source, uv = self.m_source(f0)
326
+ har_source = har_source.transpose(1, 2)
327
+
328
+ for i in range(self.num_upsamples):
329
+ x = x + (1 / self.alphas[i]) * (torch.sin(self.alphas[i] * x) ** 2)
330
+ x_source = self.noise_convs[i](har_source)
331
+ x_source = self.noise_res[i](x_source, s)
332
+
333
+ x = self.ups[i](x)
334
+ x = x + x_source
335
+
336
+ xs = None
337
+ for j in range(self.num_kernels):
338
+ if xs is None:
339
+ xs = self.resblocks[i*self.num_kernels+j](x, s)
340
+ else:
341
+ xs += self.resblocks[i*self.num_kernels+j](x, s)
342
+ x = xs / self.num_kernels
343
+ x = x + (1 / self.alphas[i+1]) * (torch.sin(self.alphas[i+1] * x) ** 2)
344
+ x = self.conv_post(x)
345
+ x = torch.tanh(x)
346
+
347
+ return x
348
+
349
+ def remove_weight_norm(self):
350
+ print('Removing weight norm...')
351
+ for l in self.ups:
352
+ remove_weight_norm(l)
353
+ for l in self.resblocks:
354
+ l.remove_weight_norm()
355
+ remove_weight_norm(self.conv_pre)
356
+ remove_weight_norm(self.conv_post)
357
+
358
+
359
+ class AdainResBlk1d(nn.Module):
360
+ def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
361
+ upsample='none', dropout_p=0.0):
362
+ super().__init__()
363
+ self.actv = actv
364
+ self.upsample_type = upsample
365
+ self.upsample = UpSample1d(upsample)
366
+ self.learned_sc = dim_in != dim_out
367
+ self._build_weights(dim_in, dim_out, style_dim)
368
+ self.dropout = nn.Dropout(dropout_p)
369
+
370
+ if upsample == 'none':
371
+ self.pool = nn.Identity()
372
+ else:
373
+ self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
374
+
375
+
376
+ def _build_weights(self, dim_in, dim_out, style_dim):
377
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
378
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
379
+ self.norm1 = AdaIN1d(style_dim, dim_in)
380
+ self.norm2 = AdaIN1d(style_dim, dim_out)
381
+ if self.learned_sc:
382
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
383
+
384
+ def _shortcut(self, x):
385
+ x = self.upsample(x)
386
+ if self.learned_sc:
387
+ x = self.conv1x1(x)
388
+ return x
389
+
390
+ def _residual(self, x, s):
391
+ x = self.norm1(x, s)
392
+ x = self.actv(x)
393
+ x = self.pool(x)
394
+ x = self.conv1(self.dropout(x))
395
+ x = self.norm2(x, s)
396
+ x = self.actv(x)
397
+ x = self.conv2(self.dropout(x))
398
+ return x
399
+
400
+ def forward(self, x, s):
401
+ out = self._residual(x, s)
402
+ out = (out + self._shortcut(x)) / math.sqrt(2)
403
+ return out
404
+
405
+ class UpSample1d(nn.Module):
406
+ def __init__(self, layer_type):
407
+ super().__init__()
408
+ self.layer_type = layer_type
409
+
410
+ def forward(self, x):
411
+ if self.layer_type == 'none':
412
+ return x
413
+ else:
414
+ return F.interpolate(x, scale_factor=2, mode='nearest')
415
+
416
+ class Decoder(nn.Module):
417
+ def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
418
+ resblock_kernel_sizes = [3,7,11],
419
+ upsample_rates = [10,5,3,2],
420
+ upsample_initial_channel=512,
421
+ resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
422
+ upsample_kernel_sizes=[20,10,6,4]):
423
+ super().__init__()
424
+
425
+ self.decode = nn.ModuleList()
426
+
427
+ self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
428
+
429
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
430
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
431
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
432
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
433
+
434
+ self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
435
+
436
+ self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
437
+
438
+ self.asr_res = nn.Sequential(
439
+ weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
440
+ )
441
+
442
+
443
+ self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes)
444
+
445
+
446
+ def forward(self, asr, F0_curve, N, s):
447
+ if self.training:
448
+ downlist = [0, 3, 7]
449
+ F0_down = downlist[random.randint(0, 2)]
450
+ downlist = [0, 3, 7, 15]
451
+ N_down = downlist[random.randint(0, 3)]
452
+ if F0_down:
453
+ F0_curve = nn.functional.conv1d(F0_curve.unsqueeze(1), torch.ones(1, 1, F0_down).to('cuda'), padding=F0_down//2).squeeze(1) / F0_down
454
+ if N_down:
455
+ N = nn.functional.conv1d(N.unsqueeze(1), torch.ones(1, 1, N_down).to('cuda'), padding=N_down//2).squeeze(1) / N_down
456
+
457
+
458
+ F0 = self.F0_conv(F0_curve.unsqueeze(1))
459
+ N = self.N_conv(N.unsqueeze(1))
460
+
461
+ x = torch.cat([asr, F0, N], axis=1)
462
+ x = self.encode(x, s)
463
+
464
+ asr_res = self.asr_res(asr)
465
+
466
+ res = True
467
+ for block in self.decode:
468
+ if res:
469
+ x = torch.cat([x, asr_res, F0, N], axis=1)
470
+ x = block(x, s)
471
+ if block.upsample_type != "none":
472
+ res = False
473
+
474
+ x = self.generator(x, s, F0_curve)
475
+ return x
476
+
477
+
Modules/istftnet.py ADDED
@@ -0,0 +1,530 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6
+ from .utils import init_weights, get_padding
7
+
8
+ import math
9
+ import random
10
+ import numpy as np
11
+ from scipy.signal import get_window
12
+
13
+ LRELU_SLOPE = 0.1
14
+
15
+ class AdaIN1d(nn.Module):
16
+ def __init__(self, style_dim, num_features):
17
+ super().__init__()
18
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
19
+ self.fc = nn.Linear(style_dim, num_features*2)
20
+
21
+ def forward(self, x, s):
22
+ h = self.fc(s)
23
+ h = h.view(h.size(0), h.size(1), 1)
24
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
25
+ return (1 + gamma) * self.norm(x) + beta
26
+
27
+ class AdaINResBlock1(torch.nn.Module):
28
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
29
+ super(AdaINResBlock1, self).__init__()
30
+ self.convs1 = nn.ModuleList([
31
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
32
+ padding=get_padding(kernel_size, dilation[0]))),
33
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
34
+ padding=get_padding(kernel_size, dilation[1]))),
35
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
36
+ padding=get_padding(kernel_size, dilation[2])))
37
+ ])
38
+ self.convs1.apply(init_weights)
39
+
40
+ self.convs2 = nn.ModuleList([
41
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
42
+ padding=get_padding(kernel_size, 1))),
43
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
44
+ padding=get_padding(kernel_size, 1))),
45
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
46
+ padding=get_padding(kernel_size, 1)))
47
+ ])
48
+ self.convs2.apply(init_weights)
49
+
50
+ self.adain1 = nn.ModuleList([
51
+ AdaIN1d(style_dim, channels),
52
+ AdaIN1d(style_dim, channels),
53
+ AdaIN1d(style_dim, channels),
54
+ ])
55
+
56
+ self.adain2 = nn.ModuleList([
57
+ AdaIN1d(style_dim, channels),
58
+ AdaIN1d(style_dim, channels),
59
+ AdaIN1d(style_dim, channels),
60
+ ])
61
+
62
+ self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
63
+ self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
64
+
65
+
66
+ def forward(self, x, s):
67
+ for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
68
+ xt = n1(x, s)
69
+ xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
70
+ xt = c1(xt)
71
+ xt = n2(xt, s)
72
+ xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
73
+ xt = c2(xt)
74
+ x = xt + x
75
+ return x
76
+
77
+ def remove_weight_norm(self):
78
+ for l in self.convs1:
79
+ remove_weight_norm(l)
80
+ for l in self.convs2:
81
+ remove_weight_norm(l)
82
+
83
+ class TorchSTFT(torch.nn.Module):
84
+ def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
85
+ super().__init__()
86
+ self.filter_length = filter_length
87
+ self.hop_length = hop_length
88
+ self.win_length = win_length
89
+ self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32))
90
+
91
+ def transform(self, input_data):
92
+ forward_transform = torch.stft(
93
+ input_data,
94
+ self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
95
+ return_complex=True)
96
+
97
+ return torch.abs(forward_transform), torch.angle(forward_transform)
98
+
99
+ def inverse(self, magnitude, phase):
100
+ inverse_transform = torch.istft(
101
+ magnitude * torch.exp(phase * 1j),
102
+ self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
103
+
104
+ return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation
105
+
106
+ def forward(self, input_data):
107
+ self.magnitude, self.phase = self.transform(input_data)
108
+ reconstruction = self.inverse(self.magnitude, self.phase)
109
+ return reconstruction
110
+
111
+ class SineGen(torch.nn.Module):
112
+ """ Definition of sine generator
113
+ SineGen(samp_rate, harmonic_num = 0,
114
+ sine_amp = 0.1, noise_std = 0.003,
115
+ voiced_threshold = 0,
116
+ flag_for_pulse=False)
117
+ samp_rate: sampling rate in Hz
118
+ harmonic_num: number of harmonic overtones (default 0)
119
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
120
+ noise_std: std of Gaussian noise (default 0.003)
121
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
122
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
123
+ Note: when flag_for_pulse is True, the first time step of a voiced
124
+ segment is always sin(np.pi) or cos(0)
125
+ """
126
+
127
+ def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
128
+ sine_amp=0.1, noise_std=0.003,
129
+ voiced_threshold=0,
130
+ flag_for_pulse=False):
131
+ super(SineGen, self).__init__()
132
+ self.sine_amp = sine_amp
133
+ self.noise_std = noise_std
134
+ self.harmonic_num = harmonic_num
135
+ self.dim = self.harmonic_num + 1
136
+ self.sampling_rate = samp_rate
137
+ self.voiced_threshold = voiced_threshold
138
+ self.flag_for_pulse = flag_for_pulse
139
+ self.upsample_scale = upsample_scale
140
+
141
+ def _f02uv(self, f0):
142
+ # generate uv signal
143
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
144
+ return uv
145
+
146
+ def _f02sine(self, f0_values):
147
+ """ f0_values: (batchsize, length, dim)
148
+ where dim indicates fundamental tone and overtones
149
+ """
150
+ # convert to F0 in rad. The interger part n can be ignored
151
+ # because 2 * np.pi * n doesn't affect phase
152
+ rad_values = (f0_values / self.sampling_rate) % 1
153
+
154
+ # initial phase noise (no noise for fundamental component)
155
+ rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
156
+ device=f0_values.device)
157
+ rand_ini[:, 0] = 0
158
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
159
+
160
+ # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
161
+ if not self.flag_for_pulse:
162
+ # # for normal case
163
+
164
+ # # To prevent torch.cumsum numerical overflow,
165
+ # # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
166
+ # # Buffer tmp_over_one_idx indicates the time step to add -1.
167
+ # # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
168
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
169
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
170
+ # cumsum_shift = torch.zeros_like(rad_values)
171
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
172
+
173
+ # phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
174
+ rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
175
+ scale_factor=1/self.upsample_scale,
176
+ mode="linear").transpose(1, 2)
177
+
178
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
179
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
180
+ # cumsum_shift = torch.zeros_like(rad_values)
181
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
182
+
183
+ phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
184
+ phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
185
+ scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
186
+ sines = torch.sin(phase)
187
+
188
+ else:
189
+ # If necessary, make sure that the first time step of every
190
+ # voiced segments is sin(pi) or cos(0)
191
+ # This is used for pulse-train generation
192
+
193
+ # identify the last time step in unvoiced segments
194
+ uv = self._f02uv(f0_values)
195
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
196
+ uv_1[:, -1, :] = 1
197
+ u_loc = (uv < 1) * (uv_1 > 0)
198
+
199
+ # get the instantanouse phase
200
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
201
+ # different batch needs to be processed differently
202
+ for idx in range(f0_values.shape[0]):
203
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
204
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
205
+ # stores the accumulation of i.phase within
206
+ # each voiced segments
207
+ tmp_cumsum[idx, :, :] = 0
208
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
209
+
210
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
211
+ # within the previous voiced segment.
212
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
213
+
214
+ # get the sines
215
+ sines = torch.cos(i_phase * 2 * np.pi)
216
+ return sines
217
+
218
+ def forward(self, f0):
219
+ """ sine_tensor, uv = forward(f0)
220
+ input F0: tensor(batchsize=1, length, dim=1)
221
+ f0 for unvoiced steps should be 0
222
+ output sine_tensor: tensor(batchsize=1, length, dim)
223
+ output uv: tensor(batchsize=1, length, 1)
224
+ """
225
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
226
+ device=f0.device)
227
+ # fundamental component
228
+ fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
229
+
230
+ # generate sine waveforms
231
+ sine_waves = self._f02sine(fn) * self.sine_amp
232
+
233
+ # generate uv signal
234
+ # uv = torch.ones(f0.shape)
235
+ # uv = uv * (f0 > self.voiced_threshold)
236
+ uv = self._f02uv(f0)
237
+
238
+ # noise: for unvoiced should be similar to sine_amp
239
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
240
+ # . for voiced regions is self.noise_std
241
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
242
+ noise = noise_amp * torch.randn_like(sine_waves)
243
+
244
+ # first: set the unvoiced part to 0 by uv
245
+ # then: additive noise
246
+ sine_waves = sine_waves * uv + noise
247
+ return sine_waves, uv, noise
248
+
249
+
250
+ class SourceModuleHnNSF(torch.nn.Module):
251
+ """ SourceModule for hn-nsf
252
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
253
+ add_noise_std=0.003, voiced_threshod=0)
254
+ sampling_rate: sampling_rate in Hz
255
+ harmonic_num: number of harmonic above F0 (default: 0)
256
+ sine_amp: amplitude of sine source signal (default: 0.1)
257
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
258
+ note that amplitude of noise in unvoiced is decided
259
+ by sine_amp
260
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
261
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
262
+ F0_sampled (batchsize, length, 1)
263
+ Sine_source (batchsize, length, 1)
264
+ noise_source (batchsize, length 1)
265
+ uv (batchsize, length, 1)
266
+ """
267
+
268
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
269
+ add_noise_std=0.003, voiced_threshod=0):
270
+ super(SourceModuleHnNSF, self).__init__()
271
+
272
+ self.sine_amp = sine_amp
273
+ self.noise_std = add_noise_std
274
+
275
+ # to produce sine waveforms
276
+ self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
277
+ sine_amp, add_noise_std, voiced_threshod)
278
+
279
+ # to merge source harmonics into a single excitation
280
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
281
+ self.l_tanh = torch.nn.Tanh()
282
+
283
+ def forward(self, x):
284
+ """
285
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
286
+ F0_sampled (batchsize, length, 1)
287
+ Sine_source (batchsize, length, 1)
288
+ noise_source (batchsize, length 1)
289
+ """
290
+ # source for harmonic branch
291
+ with torch.no_grad():
292
+ sine_wavs, uv, _ = self.l_sin_gen(x)
293
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
294
+
295
+ # source for noise branch, in the same shape as uv
296
+ noise = torch.randn_like(uv) * self.sine_amp / 3
297
+ return sine_merge, noise, uv
298
+ def padDiff(x):
299
+ return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)
300
+
301
+
302
+ class Generator(torch.nn.Module):
303
+ def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size):
304
+ super(Generator, self).__init__()
305
+
306
+ self.num_kernels = len(resblock_kernel_sizes)
307
+ self.num_upsamples = len(upsample_rates)
308
+ resblock = AdaINResBlock1
309
+
310
+ self.m_source = SourceModuleHnNSF(
311
+ sampling_rate=24000,
312
+ upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size,
313
+ harmonic_num=8, voiced_threshod=10)
314
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * gen_istft_hop_size)
315
+ self.noise_convs = nn.ModuleList()
316
+ self.noise_res = nn.ModuleList()
317
+
318
+ self.ups = nn.ModuleList()
319
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
320
+ self.ups.append(weight_norm(
321
+ ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
322
+ k, u, padding=(k-u)//2)))
323
+
324
+ self.resblocks = nn.ModuleList()
325
+ for i in range(len(self.ups)):
326
+ ch = upsample_initial_channel//(2**(i+1))
327
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes,resblock_dilation_sizes)):
328
+ self.resblocks.append(resblock(ch, k, d, style_dim))
329
+
330
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
331
+
332
+ if i + 1 < len(upsample_rates): #
333
+ stride_f0 = np.prod(upsample_rates[i + 1:])
334
+ self.noise_convs.append(Conv1d(
335
+ gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
336
+ self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim))
337
+ else:
338
+ self.noise_convs.append(Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1))
339
+ self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim))
340
+
341
+
342
+ self.post_n_fft = gen_istft_n_fft
343
+ self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
344
+ self.ups.apply(init_weights)
345
+ self.conv_post.apply(init_weights)
346
+ self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
347
+ self.stft = TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
348
+
349
+
350
+ def forward(self, x, s, f0):
351
+ with torch.no_grad():
352
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
353
+
354
+ har_source, noi_source, uv = self.m_source(f0)
355
+ har_source = har_source.transpose(1, 2).squeeze(1)
356
+ har_spec, har_phase = self.stft.transform(har_source)
357
+ har = torch.cat([har_spec, har_phase], dim=1)
358
+
359
+ for i in range(self.num_upsamples):
360
+ x = F.leaky_relu(x, LRELU_SLOPE)
361
+ x_source = self.noise_convs[i](har)
362
+ x_source = self.noise_res[i](x_source, s)
363
+
364
+ x = self.ups[i](x)
365
+ if i == self.num_upsamples - 1:
366
+ x = self.reflection_pad(x)
367
+
368
+ x = x + x_source
369
+ xs = None
370
+ for j in range(self.num_kernels):
371
+ if xs is None:
372
+ xs = self.resblocks[i*self.num_kernels+j](x, s)
373
+ else:
374
+ xs += self.resblocks[i*self.num_kernels+j](x, s)
375
+ x = xs / self.num_kernels
376
+ x = F.leaky_relu(x)
377
+ x = self.conv_post(x)
378
+ spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
379
+ phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
380
+ return self.stft.inverse(spec, phase)
381
+
382
+ def fw_phase(self, x, s):
383
+ for i in range(self.num_upsamples):
384
+ x = F.leaky_relu(x, LRELU_SLOPE)
385
+ x = self.ups[i](x)
386
+ xs = None
387
+ for j in range(self.num_kernels):
388
+ if xs is None:
389
+ xs = self.resblocks[i*self.num_kernels+j](x, s)
390
+ else:
391
+ xs += self.resblocks[i*self.num_kernels+j](x, s)
392
+ x = xs / self.num_kernels
393
+ x = F.leaky_relu(x)
394
+ x = self.reflection_pad(x)
395
+ x = self.conv_post(x)
396
+ spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
397
+ phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
398
+ return spec, phase
399
+
400
+ def remove_weight_norm(self):
401
+ print('Removing weight norm...')
402
+ for l in self.ups:
403
+ remove_weight_norm(l)
404
+ for l in self.resblocks:
405
+ l.remove_weight_norm()
406
+ remove_weight_norm(self.conv_pre)
407
+ remove_weight_norm(self.conv_post)
408
+
409
+
410
+ class AdainResBlk1d(nn.Module):
411
+ def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
412
+ upsample='none', dropout_p=0.0):
413
+ super().__init__()
414
+ self.actv = actv
415
+ self.upsample_type = upsample
416
+ self.upsample = UpSample1d(upsample)
417
+ self.learned_sc = dim_in != dim_out
418
+ self._build_weights(dim_in, dim_out, style_dim)
419
+ self.dropout = nn.Dropout(dropout_p)
420
+
421
+ if upsample == 'none':
422
+ self.pool = nn.Identity()
423
+ else:
424
+ self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
425
+
426
+
427
+ def _build_weights(self, dim_in, dim_out, style_dim):
428
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
429
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
430
+ self.norm1 = AdaIN1d(style_dim, dim_in)
431
+ self.norm2 = AdaIN1d(style_dim, dim_out)
432
+ if self.learned_sc:
433
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
434
+
435
+ def _shortcut(self, x):
436
+ x = self.upsample(x)
437
+ if self.learned_sc:
438
+ x = self.conv1x1(x)
439
+ return x
440
+
441
+ def _residual(self, x, s):
442
+ x = self.norm1(x, s)
443
+ x = self.actv(x)
444
+ x = self.pool(x)
445
+ x = self.conv1(self.dropout(x))
446
+ x = self.norm2(x, s)
447
+ x = self.actv(x)
448
+ x = self.conv2(self.dropout(x))
449
+ return x
450
+
451
+ def forward(self, x, s):
452
+ out = self._residual(x, s)
453
+ out = (out + self._shortcut(x)) / math.sqrt(2)
454
+ return out
455
+
456
+ class UpSample1d(nn.Module):
457
+ def __init__(self, layer_type):
458
+ super().__init__()
459
+ self.layer_type = layer_type
460
+
461
+ def forward(self, x):
462
+ if self.layer_type == 'none':
463
+ return x
464
+ else:
465
+ return F.interpolate(x, scale_factor=2, mode='nearest')
466
+
467
+ class Decoder(nn.Module):
468
+ def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
469
+ resblock_kernel_sizes = [3,7,11],
470
+ upsample_rates = [10, 6],
471
+ upsample_initial_channel=512,
472
+ resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
473
+ upsample_kernel_sizes=[20, 12],
474
+ gen_istft_n_fft=20, gen_istft_hop_size=5):
475
+ super().__init__()
476
+
477
+ self.decode = nn.ModuleList()
478
+
479
+ self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
480
+
481
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
482
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
483
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
484
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
485
+
486
+ self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
487
+
488
+ self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
489
+
490
+ self.asr_res = nn.Sequential(
491
+ weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
492
+ )
493
+
494
+
495
+ self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
496
+ upsample_initial_channel, resblock_dilation_sizes,
497
+ upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size)
498
+
499
+ def forward(self, asr, F0_curve, N, s):
500
+ if self.training:
501
+ downlist = [0, 3, 7]
502
+ F0_down = downlist[random.randint(0, 2)]
503
+ downlist = [0, 3, 7, 15]
504
+ N_down = downlist[random.randint(0, 3)]
505
+ if F0_down:
506
+ F0_curve = nn.functional.conv1d(F0_curve.unsqueeze(1), torch.ones(1, 1, F0_down).to('cuda'), padding=F0_down//2).squeeze(1) / F0_down
507
+ if N_down:
508
+ N = nn.functional.conv1d(N.unsqueeze(1), torch.ones(1, 1, N_down).to('cuda'), padding=N_down//2).squeeze(1) / N_down
509
+
510
+
511
+ F0 = self.F0_conv(F0_curve.unsqueeze(1))
512
+ N = self.N_conv(N.unsqueeze(1))
513
+
514
+ x = torch.cat([asr, F0, N], axis=1)
515
+ x = self.encode(x, s)
516
+
517
+ asr_res = self.asr_res(asr)
518
+
519
+ res = True
520
+ for block in self.decode:
521
+ if res:
522
+ x = torch.cat([x, asr_res, F0, N], axis=1)
523
+ x = block(x, s)
524
+ if block.upsample_type != "none":
525
+ res = False
526
+
527
+ x = self.generator(x, s, F0_curve)
528
+ return x
529
+
530
+
Modules/slmadv.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+
5
+ class SLMAdversarialLoss(torch.nn.Module):
6
+
7
+ def __init__(self, model, wl, sampler, min_len, max_len, batch_percentage=0.5, skip_update=10, sig=1.5):
8
+ super(SLMAdversarialLoss, self).__init__()
9
+ self.model = model
10
+ self.wl = wl
11
+ self.sampler = sampler
12
+
13
+ self.min_len = min_len
14
+ self.max_len = max_len
15
+ self.batch_percentage = batch_percentage
16
+
17
+ self.sig = sig
18
+ self.skip_update = skip_update
19
+
20
+ def forward(self, iters, y_rec_gt, y_rec_gt_pred, waves, mel_input_length, ref_text, ref_lengths, use_ind, s_trg, ref_s=None):
21
+ text_mask = length_to_mask(ref_lengths).to(ref_text.device)
22
+ bert_dur = self.model.bert(ref_text, attention_mask=(~text_mask).int())
23
+ d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)
24
+
25
+ if use_ind and np.random.rand() < 0.5:
26
+ s_preds = s_trg
27
+ else:
28
+ num_steps = np.random.randint(3, 5)
29
+ if ref_s is not None:
30
+ s_preds = self.sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device),
31
+ embedding=bert_dur,
32
+ embedding_scale=1,
33
+ features=ref_s, # reference from the same speaker as the embedding
34
+ embedding_mask_proba=0.1,
35
+ num_steps=num_steps).squeeze(1)
36
+ else:
37
+ s_preds = self.sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device),
38
+ embedding=bert_dur,
39
+ embedding_scale=1,
40
+ embedding_mask_proba=0.1,
41
+ num_steps=num_steps).squeeze(1)
42
+
43
+ s_dur = s_preds[:, 128:]
44
+ s = s_preds[:, :128]
45
+
46
+ d, _ = self.model.predictor(d_en, s_dur,
47
+ ref_lengths,
48
+ torch.randn(ref_lengths.shape[0], ref_lengths.max(), 2).to(ref_text.device),
49
+ text_mask)
50
+
51
+ bib = 0
52
+
53
+ output_lengths = []
54
+ attn_preds = []
55
+
56
+ # differentiable duration modeling
57
+ for _s2s_pred, _text_length in zip(d, ref_lengths):
58
+
59
+ _s2s_pred_org = _s2s_pred[:_text_length, :]
60
+
61
+ _s2s_pred = torch.sigmoid(_s2s_pred_org)
62
+ _dur_pred = _s2s_pred.sum(axis=-1)
63
+
64
+ l = int(torch.round(_s2s_pred.sum()).item())
65
+ t = torch.arange(0, l).expand(l)
66
+
67
+ t = torch.arange(0, l).unsqueeze(0).expand((len(_s2s_pred), l)).to(ref_text.device)
68
+ loc = torch.cumsum(_dur_pred, dim=0) - _dur_pred / 2
69
+
70
+ h = torch.exp(-0.5 * torch.square(t - (l - loc.unsqueeze(-1))) / (self.sig)**2)
71
+
72
+ out = torch.nn.functional.conv1d(_s2s_pred_org.unsqueeze(0),
73
+ h.unsqueeze(1),
74
+ padding=h.shape[-1] - 1, groups=int(_text_length))[..., :l]
75
+ attn_preds.append(F.softmax(out.squeeze(), dim=0))
76
+
77
+ output_lengths.append(l)
78
+
79
+ max_len = max(output_lengths)
80
+
81
+ with torch.no_grad():
82
+ t_en = self.model.text_encoder(ref_text, ref_lengths, text_mask)
83
+
84
+ s2s_attn = torch.zeros(len(ref_lengths), int(ref_lengths.max()), max_len).to(ref_text.device)
85
+ for bib in range(len(output_lengths)):
86
+ s2s_attn[bib, :ref_lengths[bib], :output_lengths[bib]] = attn_preds[bib]
87
+
88
+ asr_pred = t_en @ s2s_attn
89
+
90
+ _, p_pred = self.model.predictor(d_en, s_dur,
91
+ ref_lengths,
92
+ s2s_attn,
93
+ text_mask)
94
+
95
+ mel_len = max(int(min(output_lengths) / 2 - 1), self.min_len // 2)
96
+ mel_len = min(mel_len, self.max_len // 2)
97
+
98
+ # get clips
99
+
100
+ en = []
101
+ p_en = []
102
+ sp = []
103
+
104
+ F0_fakes = []
105
+ N_fakes = []
106
+
107
+ wav = []
108
+
109
+ for bib in range(len(output_lengths)):
110
+ mel_length_pred = output_lengths[bib]
111
+ mel_length_gt = int(mel_input_length[bib].item() / 2)
112
+ if mel_length_gt <= mel_len or mel_length_pred <= mel_len:
113
+ continue
114
+
115
+ sp.append(s_preds[bib])
116
+
117
+ random_start = np.random.randint(0, mel_length_pred - mel_len)
118
+ en.append(asr_pred[bib, :, random_start:random_start+mel_len])
119
+ p_en.append(p_pred[bib, :, random_start:random_start+mel_len])
120
+
121
+ # get ground truth clips
122
+ random_start = np.random.randint(0, mel_length_gt - mel_len)
123
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
124
+ wav.append(torch.from_numpy(y).to(ref_text.device))
125
+
126
+ if len(wav) >= self.batch_percentage * len(waves): # prevent OOM due to longer lengths
127
+ break
128
+
129
+ if len(sp) <= 1:
130
+ return None
131
+
132
+ sp = torch.stack(sp)
133
+ wav = torch.stack(wav).float()
134
+ en = torch.stack(en)
135
+ p_en = torch.stack(p_en)
136
+
137
+ F0_fake, N_fake = self.model.predictor.F0Ntrain(p_en, sp[:, 128:])
138
+ y_pred = self.model.decoder(en, F0_fake, N_fake, sp[:, :128])
139
+
140
+ # discriminator loss
141
+ if (iters + 1) % self.skip_update == 0:
142
+ if np.random.randint(0, 2) == 0:
143
+ wav = y_rec_gt_pred
144
+ use_rec = True
145
+ else:
146
+ use_rec = False
147
+
148
+ crop_size = min(wav.size(-1), y_pred.size(-1))
149
+ if use_rec: # use reconstructed (shorter lengths), do length invariant regularization
150
+ if wav.size(-1) > y_pred.size(-1):
151
+ real_GP = wav[:, : , :crop_size]
152
+ out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze())
153
+ out_org = self.wl.discriminator_forward(wav.detach().squeeze())
154
+ loss_reg = F.l1_loss(out_crop, out_org[..., :out_crop.size(-1)])
155
+
156
+ if np.random.randint(0, 2) == 0:
157
+ d_loss = self.wl.discriminator(real_GP.detach().squeeze(), y_pred.detach().squeeze()).mean()
158
+ else:
159
+ d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()
160
+ else:
161
+ real_GP = y_pred[:, : , :crop_size]
162
+ out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze())
163
+ out_org = self.wl.discriminator_forward(y_pred.detach().squeeze())
164
+ loss_reg = F.l1_loss(out_crop, out_org[..., :out_crop.size(-1)])
165
+
166
+ if np.random.randint(0, 2) == 0:
167
+ d_loss = self.wl.discriminator(wav.detach().squeeze(), real_GP.detach().squeeze()).mean()
168
+ else:
169
+ d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()
170
+
171
+ # regularization (ignore length variation)
172
+ d_loss += loss_reg
173
+
174
+ out_gt = self.wl.discriminator_forward(y_rec_gt.detach().squeeze())
175
+ out_rec = self.wl.discriminator_forward(y_rec_gt_pred.detach().squeeze())
176
+
177
+ # regularization (ignore reconstruction artifacts)
178
+ d_loss += F.l1_loss(out_gt, out_rec)
179
+
180
+ else:
181
+ d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()
182
+ else:
183
+ d_loss = 0
184
+
185
+ # generator loss
186
+ gen_loss = self.wl.generator(y_pred.squeeze())
187
+
188
+ gen_loss = gen_loss.mean()
189
+
190
+ return d_loss, gen_loss, y_pred.detach().cpu().numpy()
191
+
192
+ def length_to_mask(lengths):
193
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
194
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
195
+ return mask
Modules/utils.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def init_weights(m, mean=0.0, std=0.01):
2
+ classname = m.__class__.__name__
3
+ if classname.find("Conv") != -1:
4
+ m.weight.data.normal_(mean, std)
5
+
6
+
7
+ def apply_weight_norm(m):
8
+ classname = m.__class__.__name__
9
+ if classname.find("Conv") != -1:
10
+ weight_norm(m)
11
+
12
+
13
+ def get_padding(kernel_size, dilation=1):
14
+ return int((kernel_size*dilation - dilation)/2)
README.md ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # StyleTTS 2: Towards Human-Level Text-to-Speech through Style Diffusion and Adversarial Training with Large Speech Language Models
2
+
3
+ ### Yinghao Aaron Li, Cong Han, Vinay S. Raghavan, Gavin Mischler, Nima Mesgarani
4
+
5
+ > In this paper, we present StyleTTS 2, a text-to-speech (TTS) model that leverages style diffusion and adversarial training with large speech language models (SLMs) to achieve human-level TTS synthesis. StyleTTS 2 differs from its predecessor by modeling styles as a latent random variable through diffusion models to generate the most suitable style for the text without requiring reference speech, achieving efficient latent diffusion while benefiting from the diverse speech synthesis offered by diffusion models. Furthermore, we employ large pre-trained SLMs, such as WavLM, as discriminators with our novel differentiable duration modeling for end-to-end training, resulting in improved speech naturalness. StyleTTS 2 surpasses human recordings on the single-speaker LJSpeech dataset and matches it on the multispeaker VCTK dataset as judged by native English speakers. Moreover, when trained on the LibriTTS dataset, our model outperforms previous publicly available models for zero-shot speaker adaptation. This work achieves the first human-level TTS synthesis on both single and multispeaker datasets, showcasing the potential of style diffusion and adversarial training with large SLMs.
6
+
7
+ Paper: [https://arxiv.org/abs/2306.07691](https://arxiv.org/abs/2306.07691)
8
+
9
+ Audio samples: [https://styletts2.github.io/](https://styletts2.github.io/)
10
+
11
+ Online demo: [Hugging Face](https://huggingface.co/spaces/styletts2/styletts2) (thank [@fakerybakery](https://github.com/fakerybakery) for the wonderful online demo)
12
+
13
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yl4579/StyleTTS2/blob/main/) [![Discord](https://img.shields.io/discord/1197679063150637117?logo=discord&logoColor=white&label=Join%20our%20Community)](https://discord.gg/ha8sxdG2K4)
14
+
15
+ ## TODO
16
+ - [x] Training and inference demo code for single-speaker models (LJSpeech)
17
+ - [x] Test training code for multi-speaker models (VCTK and LibriTTS)
18
+ - [x] Finish demo code for multispeaker model and upload pre-trained models
19
+ - [x] Add a finetuning script for new speakers with base pre-trained multispeaker models
20
+ - [ ] Fix DDP (accelerator) for `train_second.py` **(I have tried everything I could to fix this but had no success, so if you are willing to help, please see [#7](https://github.com/yl4579/StyleTTS2/issues/7))**
21
+
22
+ ## Pre-requisites
23
+ 1. Python >= 3.7
24
+ 2. Clone this repository:
25
+ ```bash
26
+ git clone https://github.com/yl4579/StyleTTS2.git
27
+ cd StyleTTS2
28
+ ```
29
+ 3. Install python requirements:
30
+ ```bash
31
+ pip install -r requirements.txt
32
+ ```
33
+ On Windows add:
34
+ ```bash
35
+ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 -U
36
+ ```
37
+ Also install phonemizer and espeak if you want to run the demo:
38
+ ```bash
39
+ pip install phonemizer
40
+ sudo apt-get install espeak-ng
41
+ ```
42
+ 4. Download and extract the [LJSpeech dataset](https://keithito.com/LJ-Speech-Dataset/), unzip to the data folder and upsample the data to 24 kHz. The text aligner and pitch extractor are pre-trained on 24 kHz data, but you can easily change the preprocessing and re-train them using your own preprocessing.
43
+ For LibriTTS, you will need to combine train-clean-360 with train-clean-100 and rename the folder train-clean-460 (see [val_list_libritts.txt](https://github.com/yl4579/StyleTTS/blob/main/Data/val_list_libritts.txt) as an example).
44
+
45
+ ## Training
46
+ First stage training:
47
+ ```bash
48
+ accelerate launch train_first.py --config_path ./Configs/config.yml
49
+ ```
50
+ Second stage training **(DDP version not working, so the current version uses DP, again see [#7](https://github.com/yl4579/StyleTTS2/issues/7) if you want to help)**:
51
+ ```bash
52
+ python train_second.py --config_path ./Configs/config.yml
53
+ ```
54
+ You can run both consecutively and it will train both the first and second stages. The model will be saved in the format "epoch_1st_%05d.pth" and "epoch_2nd_%05d.pth". Checkpoints and Tensorboard logs will be saved at `log_dir`.
55
+
56
+ The data list format needs to be `filename.wav|transcription|speaker`, see [val_list.txt](https://github.com/yl4579/StyleTTS2/blob/main/Data/val_list.txt) as an example. The speaker labels are needed for multi-speaker models because we need to sample reference audio for style diffusion model training.
57
+
58
+ ### Important Configurations
59
+ In [config.yml](https://github.com/yl4579/StyleTTS2/blob/main/Configs/config.yml), there are a few important configurations to take care of:
60
+ - `OOD_data`: The path for out-of-distribution texts for SLM adversarial training. The format should be `text|anything`.
61
+ - `min_length`: Minimum length of OOD texts for training. This is to make sure the synthesized speech has a minimum length.
62
+ - `max_len`: Maximum length of audio for training. The unit is frame. Since the default hop size is 300, one frame is approximately `300 / 24000` (0.0125) second. Lowering this if you encounter the out-of-memory issue.
63
+ - `multispeaker`: Set to true if you want to train a multispeaker model. This is needed because the architecture of the denoiser is different for single and multispeaker models.
64
+ - `batch_percentage`: This is to make sure during SLM adversarial training there are no out-of-memory (OOM) issues. If you encounter OOM problem, please set a lower number for this.
65
+
66
+ ### Pre-trained modules
67
+ In [Utils](https://github.com/yl4579/StyleTTS2/tree/main/Utils) folder, there are three pre-trained models:
68
+ - **[ASR](https://github.com/yl4579/StyleTTS2/tree/main/Utils/ASR) folder**: It contains the pre-trained text aligner, which was pre-trained on English (LibriTTS), Japanese (JVS), and Chinese (AiShell) corpus. It works well for most other languages without fine-tuning, but you can always train your own text aligner with the code here: [yl4579/AuxiliaryASR](https://github.com/yl4579/AuxiliaryASR).
69
+ - **[JDC](https://github.com/yl4579/StyleTTS2/tree/main/Utils/JDC) folder**: It contains the pre-trained pitch extractor, which was pre-trained on English (LibriTTS) corpus only. However, it works well for other languages too because F0 is independent of language. If you want to train on singing corpus, it is recommended to train a new pitch extractor with the code here: [yl4579/PitchExtractor](https://github.com/yl4579/PitchExtractor).
70
+ - **[PLBERT](https://github.com/yl4579/StyleTTS2/tree/main/Utils/PLBERT) folder**: It contains the pre-trained [PL-BERT](https://arxiv.org/abs/2301.08810) model, which was pre-trained on English (Wikipedia) corpus only. It probably does not work very well on other languages, so you will need to train a different PL-BERT for different languages using the repo here: [yl4579/PL-BERT](https://github.com/yl4579/PL-BERT). You can also use the [multilingual PL-BERT](https://huggingface.co/papercup-ai/multilingual-pl-bert) which supports 14 languages.
71
+
72
+ ### Common Issues
73
+ - **Loss becomes NaN**: If it is the first stage, please make sure you do not use mixed precision, as it can cause loss becoming NaN for some particular datasets when the batch size is not set properly (need to be more than 16 to work well). For the second stage, please also experiment with different batch sizes, with higher batch sizes being more likely to cause NaN loss values. We recommend the batch size to be 16. You can refer to issues [#10](https://github.com/yl4579/StyleTTS2/issues/10) and [#11](https://github.com/yl4579/StyleTTS2/issues/11) for more details.
74
+ - **Out of memory**: Please either use lower `batch_size` or `max_len`. You may refer to issue [#10](https://github.com/yl4579/StyleTTS2/issues/10) for more information.
75
+ - **Non-English dataset**: You can train on any language you want, but you will need to use a pre-trained PL-BERT model for that language. We have a pre-trained [multilingual PL-BERT](https://huggingface.co/papercup-ai/multilingual-pl-bert) that supports 14 languages. You may refer to [yl4579/StyleTTS#10](https://github.com/yl4579/StyleTTS/issues/10) and [#70](https://github.com/yl4579/StyleTTS2/issues/70) for some examples to train on Chinese datasets.
76
+
77
+ ## Finetuning
78
+ The script is modified from `train_second.py` which uses DP, as DDP does not work for `train_second.py`. Please see the bold section above if you are willing to help with this problem.
79
+ ```bash
80
+ python train_finetune.py --config_path ./Configs/config_ft.yml
81
+ ```
82
+ Please make sure you have the LibriTTS checkpoint downloaded and unzipped under the folder. The default configuration `config_ft.yml` finetunes on LJSpeech with 1 hour of speech data (around 1k samples) for 50 epochs. This took about 4 hours to finish on four NVidia A100. The quality is slightly worse (similar to NaturalSpeech on LJSpeech) than LJSpeech model trained from scratch with 24 hours of speech data, which took around 2.5 days to finish on four A100. The samples can be found at [#65 (comment)](https://github.com/yl4579/StyleTTS2/discussions/65#discussioncomment-7668393).
83
+
84
+ If you are using a **single GPU** (because the script doesn't work with DDP) and want to save training speed and VRAM, you can do (thank [@korakoe](https://github.com/korakoe) for making the script at [#100](https://github.com/yl4579/StyleTTS2/pull/100)):
85
+ ```bash
86
+ accelerate launch --mixed_precision=fp16 --num_processes=1 train_finetune_accelerate.py --config_path ./Configs/config_ft.yml
87
+ ```
88
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yl4579/StyleTTS2/blob/main/Colab/StyleTTS2_Finetune_Demo.ipynb)
89
+
90
+ ### Common Issues
91
+ [@Kreevoz](https://github.com/Kreevoz) has made detailed notes on common issues in finetuning, with suggestions in maximizing audio quality: [#81](https://github.com/yl4579/StyleTTS2/discussions/81). Some of these also apply to training from scratch. [@IIEleven11](https://github.com/IIEleven11) has also made a guideline for fine-tuning: [#128](https://github.com/yl4579/StyleTTS2/discussions/128).
92
+
93
+ - **Out of memory after `joint_epoch`**: This is likely because your GPU RAM is not big enough for SLM adversarial training run. You may skip that but the quality could be worse. Setting `joint_epoch` a larger number than `epochs` could skip the SLM advesariral training.
94
+
95
+ ## Inference
96
+ Please refer to [Inference_LJSpeech.ipynb](https://github.com/yl4579/StyleTTS2/blob/main/Demo/Inference_LJSpeech.ipynb) (single-speaker) and [Inference_LibriTTS.ipynb](https://github.com/yl4579/StyleTTS2/blob/main/Demo/Inference_LibriTTS.ipynb) (multi-speaker) for details. For LibriTTS, you will also need to download [reference_audio.zip](https://huggingface.co/yl4579/StyleTTS2-LibriTTS/resolve/main/reference_audio.zip) and unzip it under the `demo` before running the demo.
97
+
98
+ - The pretrained StyleTTS 2 on LJSpeech corpus in 24 kHz can be downloaded at [https://huggingface.co/yl4579/StyleTTS2-LJSpeech/tree/main](https://huggingface.co/yl4579/StyleTTS2-LJSpeech/tree/main).
99
+
100
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yl4579/StyleTTS2/blob/main/Colab/StyleTTS2_Demo_LJSpeech.ipynb)
101
+
102
+ - The pretrained StyleTTS 2 model on LibriTTS can be downloaded at [https://huggingface.co/yl4579/StyleTTS2-LibriTTS/tree/main](https://huggingface.co/yl4579/StyleTTS2-LibriTTS/tree/main).
103
+
104
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yl4579/StyleTTS2/blob/main/Colab/StyleTTS2_Demo_LibriTTS.ipynb)
105
+
106
+
107
+ You can import StyleTTS 2 and run it in your own code. However, the inference depends on a GPL-licensed package, so it is not included directly in this repository. A [GPL-licensed fork](https://github.com/NeuralVox/StyleTTS2) has an importable script, as well as an experimental streaming API, etc. A [fully MIT-licensed package](https://pypi.org/project/styletts2/) that uses gruut (albeit lower quality due to mismatch between phonemizer and gruut) is also available.
108
+
109
+ ***Before using these pre-trained models, you agree to inform the listeners that the speech samples are synthesized by the pre-trained models, unless you have the permission to use the voice you synthesize. That is, you agree to only use voices whose speakers grant the permission to have their voice cloned, either directly or by license before making synthesized voices public, or you have to publicly announce that these voices are synthesized if you do not have the permission to use these voices.***
110
+
111
+ ### Common Issues
112
+ - **High-pitched background noise**: This is caused by numerical float differences in older GPUs. For more details, please refer to issue [#13](https://github.com/yl4579/StyleTTS2/issues/13). Basically, you will need to use more modern GPUs or do inference on CPUs.
113
+ - **Pre-trained model license**: You only need to abide by the above rules if you use **the pre-trained models** and the voices are **NOT** in the training set, i.e., your reference speakers are not from any open access dataset. For more details of rules to use the pre-trained models, please see [#37](https://github.com/yl4579/StyleTTS2/issues/37).
114
+
115
+ ## References
116
+ - [archinetai/audio-diffusion-pytorch](https://github.com/archinetai/audio-diffusion-pytorch)
117
+ - [jik876/hifi-gan](https://github.com/jik876/hifi-gan)
118
+ - [rishikksh20/iSTFTNet-pytorch](https://github.com/rishikksh20/iSTFTNet-pytorch)
119
+ - [nii-yamagishilab/project-NN-Pytorch-scripts/project/01-nsf](https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts/tree/master/project/01-nsf)
120
+
121
+ ## License
122
+
123
+ Code: MIT License
124
+
125
+ Pre-Trained Models: Before using these pre-trained models, you agree to inform the listeners that the speech samples are synthesized by the pre-trained models, unless you have the permission to use the voice you synthesize. That is, you agree to only use voices whose speakers grant the permission to have their voice cloned, either directly or by license before making synthesized voices public, or you have to publicly announce that these voices are synthesized if you do not have the permission to use these voices.
Utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
__pycache__/losses.cpython-310.pyc ADDED
Binary file (8.95 kB). View file
 
__pycache__/meldataset.cpython-310.pyc ADDED
Binary file (8.54 kB). View file
 
__pycache__/models.cpython-310.pyc ADDED
Binary file (22.4 kB). View file
 
__pycache__/optimizers.cpython-310.pyc ADDED
Binary file (3.97 kB). View file
 
__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.92 kB). View file
 
bad_wavs.txt ADDED
File without changes
data/OOD_dummy.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ This is an out-of-domain sentence.|
2
+ This is an out-of-domain sentence.|
data/add_phones.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re, unicodedata
2
+
3
+ _pad = "$"
4
+ _punctuation = ';:,.!?¡¿—…"«»“” '
5
+ _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
6
+ _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
7
+
8
+ # Export all symbols:
9
+ symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
10
+
11
+ dicts = {}
12
+ for i in range(len((symbols))):
13
+ dicts[symbols[i]] = i
14
+
15
+ class TextCleaner:
16
+ """
17
+ • Normalises text to NFC so pre-composed IPA glyphs match `symbols`.
18
+ • Splits on event tokens first (e.g. <evt_gasp>), then per-character.
19
+ • Unknown chars map to the <unk> symbol instead of printing.
20
+ """
21
+ _EVENT_RE = re.compile(r"<[^>]+>|.") # match <evt_xxx> or single char
22
+
23
+ def __init__(self):
24
+ # `dicts` must already include EVENT_TOKENS and "<unk>"
25
+ self.lookup = dicts
26
+ self.unk_id = 0
27
+
28
+ def __call__(self, text: str):
29
+ text = unicodedata.normalize("NFC", text)
30
+ ids = []
31
+ for tok in self._EVENT_RE.findall(text):
32
+ ids.append(self.lookup.get(tok, self.unk_id))
33
+ return ids
34
+
35
+ tc = TextCleaner()
36
+ miss = {}
37
+
38
+ with open("/home/ubuntu/styletts2-ft/data/train_list.txt", encoding="utf-8") as f:
39
+ for line in f:
40
+ for i in tc(line.split("|")[1]): # convert once
41
+ pass # if it got an ID, it's known
42
+ print("Unknown chars left:", [k for k,v in miss.items()])
data/val_list.txt ADDED
The diff for this file is too large to render. See raw diff
 
inference_first.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ inference_first.py — quick Stage-1 sanity-check for StyleTTS-2
4
+
5
+ Example
6
+ -------
7
+ python inference_first.py \
8
+ --ckpt logs/pod_90h_30k/epoch_1st_0004.pth \
9
+ --ref data/wavs/123_abcd_part042_00.wav \
10
+ --text "<evt_gasp> ðɪs ɪz ɐ tɛst ˈsɛntəns"
11
+ It writes preview.wav in the current directory.
12
+ """
13
+ import argparse, yaml, torch, torchaudio
14
+ from models import build_model, load_ASR_models, load_F0_models
15
+ from Utils.PLBERT.util import load_plbert
16
+ from utils import recursive_munch, log_norm, length_to_mask
17
+ from meldataset import TextCleaner, preprocess
18
+
19
+ # ────────────────────────── helpers ────────────────────────────
20
+ def _restore_batch(x):
21
+ """(T,) ▸ (1,T) or (C,T) ▸ (1,C,T) (handles squeeze in JDCNet)."""
22
+ return x.unsqueeze(0) if x.dim() == 1 else x
23
+
24
+ def _match_len(x, target_len):
25
+ """Crop or zero-pad last axis to target_len."""
26
+ cur = x.shape[-1]
27
+ if cur > target_len:
28
+ return x[..., :target_len]
29
+ if cur < target_len:
30
+ pad = target_len - cur
31
+ return torch.nn.functional.pad(x, (0, pad))
32
+ return x
33
+
34
+ # ────────────────────────── CLI ────────────────────────────────
35
+ p = argparse.ArgumentParser()
36
+ p.add_argument("--ckpt", required=True, help="epoch_1st_*.pth")
37
+ p.add_argument("--ref", required=True, help="reference wav (24 kHz mono)")
38
+ p.add_argument("--text", required=True, help="IPA / phoneme string")
39
+ p.add_argument("--cfg", default="Configs/config_ft_single.yml")
40
+ args = p.parse_args()
41
+
42
+ # ───────────────── net & cfg ───────────────────────────────────
43
+ cfg = yaml.safe_load(open(args.cfg))
44
+ sr = cfg["preprocess_params"]["sr"]
45
+ device = "cuda"
46
+
47
+ asr = load_ASR_models(cfg["ASR_path"], cfg["ASR_config"])
48
+ f0 = load_F0_models(cfg["F0_path"])
49
+ bert = load_plbert(cfg["PLBERT_dir"])
50
+ model = build_model(recursive_munch(cfg["model_params"]), asr, f0, bert)
51
+
52
+ state = torch.load(args.ckpt, map_location="cpu")["net"]
53
+ for k in model:
54
+ model[k].load_state_dict(state[k], strict=False)
55
+ model[k].eval().to(device)
56
+
57
+ # ───────────────── prepare inputs ──────────────────────────────
58
+ cleaner = TextCleaner()
59
+ text_ids = torch.LongTensor(cleaner(args.text)).unsqueeze(0).to(device)
60
+ input_lengths = torch.LongTensor([text_ids.shape[1]]).to(device)
61
+ text_mask = length_to_mask(input_lengths).to(device)
62
+
63
+ wav, _ = torchaudio.load(args.ref) # (1,N)
64
+ mel_ref = preprocess(wav.squeeze().numpy()).to(device) # (1,80,T)
65
+
66
+ style = model.style_encoder(mel_ref.unsqueeze(1)) # (1,128)
67
+
68
+ F0_real, _, _ = model.pitch_extractor(mel_ref.unsqueeze(1))
69
+ F0_real = _restore_batch(F0_real) # (1,T')
70
+
71
+ real_norm = log_norm(mel_ref.unsqueeze(1)).squeeze(1) # (1,T")
72
+ real_norm = _restore_batch(real_norm) # (1,T")
73
+
74
+ # ───────────────── align lengths ───────────────────────────────
75
+ enc = model.text_encoder(text_ids, input_lengths, text_mask) # (1,512,L)
76
+ enc_len = enc.shape[-1] # L
77
+ target = enc_len * 2 # decoder expects 2×L
78
+
79
+ F0_real = _match_len(F0_real, target) # (1,2L)
80
+ real_norm = _match_len(real_norm, target) # (1,2L)
81
+
82
+ # ───────────────── decode & save ───────────────────────────────
83
+ with torch.no_grad():
84
+ y = model.decoder(enc, F0_real, real_norm, style)
85
+
86
+ # ─── make it (channels, samples) = (1, T) ────────────────────────────
87
+ y = y.squeeze(0) # (1, T)
88
+
89
+ torchaudio.save("preview.wav", y.cpu(), sr)
90
+ print("✅ wrote preview.wav")
logs/pod_90h_30k/config_ft_single.yml ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ─── GLOBAL ──────────────────────────────────────────────────────────
2
+ log_dir: logs/pod_90h_30k
3
+ device: "cuda"
4
+
5
+ batch_size: 8 # 40 GB A100, fp16
6
+ max_len: 160 # ≈ 8 s (200 × 40 ms)
7
+
8
+ epochs_1st: 13 # first-stage schedule
9
+ epochs_2nd: 13 # second-stage schedule (later)
10
+ save_freq: 2
11
+ log_interval: 50
12
+
13
+ # leave blank on first run
14
+ pretrained_model: ""
15
+ second_stage_load_pretrained: false
16
+ load_only_params: false
17
+
18
+ # ─── PRE-PROCESS ─────────────────────────────────────────────────────
19
+ preprocess_params:
20
+ sr: 24000
21
+ spect_params: # required by Mel extractor
22
+ n_fft: 2048
23
+ win_length: 1200
24
+ hop_length: 300
25
+
26
+ # ─── DATA ────────────────────────────────────────────────────────────
27
+ data_params:
28
+ root_path: /home/ubuntu/styletts2-ft/data/wavs
29
+ train_data: /home/ubuntu/styletts2-ft/data/train_list.txt
30
+ val_data: /home/ubuntu/styletts2-ft/data/val_list.txt
31
+ min_length: 50 # sample until texts with this size are obtained for OOD texts
32
+ OOD_data: /home/ubuntu/styletts2-ft/data/OOD_texts.txt
33
+
34
+ # ─── LOSS SCHEDULE ──────────────────────────────────────────────────
35
+ loss_params:
36
+ lambda_mel: 5. # mel reconstruction loss
37
+ lambda_gen: 1. # generator loss
38
+ lambda_slm: 1. # slm feature matching loss
39
+
40
+ lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
41
+ lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
42
+ TMA_epoch: 50 # TMA starting epoch (1st stage)
43
+
44
+ lambda_F0: 1. # F0 reconstruction loss (2nd stage)
45
+ lambda_norm: 1. # norm reconstruction loss (2nd stage)
46
+ lambda_dur: 1. # duration loss (2nd stage)
47
+ lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
48
+ lambda_sty: 1. # style reconstruction loss (2nd stage)
49
+ lambda_diff: 1. # score matching loss (2nd stage)
50
+
51
+ diff_epoch: 20 # style diffusion starting epoch (2nd stage)
52
+ joint_epoch: 50 # joint training starting epoch (2nd stage)
53
+
54
+ # ─── OPTIMISER ──────────────────────────────────────────────────────
55
+ optimizer_params:
56
+ lr: 0.0008
57
+ bert_lr: 0.00002
58
+ ft_lr: 0.0002
59
+ grad_accum_steps: 2
60
+
61
+ # ─── MODEL (core network & sub-modules) ─────────────────────────────
62
+ model_params:
63
+ multispeaker: true # speaker-ID column present
64
+ dim_in: 64
65
+ hidden_dim: 512
66
+ max_conv_dim: 512
67
+ n_layer: 3
68
+ n_mels: 80
69
+
70
+ n_token: 178 # 178 phonemes
71
+ max_dur: 50
72
+ style_dim: 128
73
+ dropout: 0.2
74
+
75
+ decoder:
76
+ type: hifigan
77
+ resblock_kernel_sizes: [3, 7, 11]
78
+ upsample_rates: [10, 5, 3, 2]
79
+ upsample_initial_channel: 512
80
+ resblock_dilation_sizes: [[1,3,5],[1,3,5],[1,3,5]]
81
+ upsample_kernel_sizes: [20, 10, 6, 4]
82
+
83
+ slm:
84
+ model: microsoft/wavlm-base-plus
85
+ sr: 16000
86
+ hidden: 768
87
+ nlayers: 13
88
+ initial_channel: 64
89
+
90
+ diffusion:
91
+ embedding_mask_proba: 0.1
92
+ transformer:
93
+ num_layers: 3
94
+ num_heads: 8
95
+ head_features: 64
96
+ multiplier: 2
97
+ dist:
98
+ sigma_data: 0.2 # ← placeholder; code will overwrite if
99
+ estimate_sigma_data: true
100
+ mean: -3.0
101
+ std: 1.0
102
+
103
+ # ─── EXTERNAL CHECKPOINTS ───────────────────────────────────────────
104
+ F0_path: "Utils/JDC/bst.t7"
105
+ ASR_config: "Utils/ASR/config.yml"
106
+ ASR_path: "Utils/ASR/epoch_00080.pth"
107
+ PLBERT_dir: 'Utils/PLBERT/'
108
+ first_stage_path: "" # filled automatically after this run
109
+
110
+ # ─── SLM ADVERSARIAL (ignored in stage-1, kept default) ─────────────
111
+ slmadv_params:
112
+ min_len: 400
113
+ max_len: 500
114
+ batch_percentage: 0.5
115
+ iter: 20
116
+ thresh: 5
117
+ scale: 0.01
118
+ sig: 1.5
logs/pod_90h_30k/tensorboard/events.out.tfevents.1749338343.104-171-203-10.11888.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f0b61e4b8751a174bd3e23296859ac8dc0de4f2e2a513995e1d587498a328bc
3
+ size 88
logs/pod_90h_30k/train.log ADDED
The diff for this file is too large to render. See raw diff
 
losses.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ import torchaudio
5
+ from transformers import AutoModel
6
+
7
+ class SpectralConvergengeLoss(torch.nn.Module):
8
+ """Spectral convergence loss module."""
9
+
10
+ def __init__(self):
11
+ """Initilize spectral convergence loss module."""
12
+ super(SpectralConvergengeLoss, self).__init__()
13
+
14
+ def forward(self, x_mag, y_mag):
15
+ """Calculate forward propagation.
16
+ Args:
17
+ x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
18
+ y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
19
+ Returns:
20
+ Tensor: Spectral convergence loss value.
21
+ """
22
+ return torch.norm(y_mag - x_mag, p=1) / torch.norm(y_mag, p=1)
23
+
24
+ class STFTLoss(torch.nn.Module):
25
+ """STFT loss module."""
26
+
27
+ def __init__(self, fft_size=1024, shift_size=120, win_length=600, window=torch.hann_window):
28
+ """Initialize STFT loss module."""
29
+ super(STFTLoss, self).__init__()
30
+ self.fft_size = fft_size
31
+ self.shift_size = shift_size
32
+ self.win_length = win_length
33
+ self.to_mel = torchaudio.transforms.MelSpectrogram(sample_rate=24000, n_fft=fft_size, win_length=win_length, hop_length=shift_size, window_fn=window)
34
+
35
+ self.spectral_convergenge_loss = SpectralConvergengeLoss()
36
+
37
+ def forward(self, x, y):
38
+ """Calculate forward propagation.
39
+ Args:
40
+ x (Tensor): Predicted signal (B, T).
41
+ y (Tensor): Groundtruth signal (B, T).
42
+ Returns:
43
+ Tensor: Spectral convergence loss value.
44
+ Tensor: Log STFT magnitude loss value.
45
+ """
46
+ x_mag = self.to_mel(x)
47
+ mean, std = -4, 4
48
+ x_mag = (torch.log(1e-5 + x_mag) - mean) / std
49
+
50
+ y_mag = self.to_mel(y)
51
+ mean, std = -4, 4
52
+ y_mag = (torch.log(1e-5 + y_mag) - mean) / std
53
+
54
+ sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
55
+ return sc_loss
56
+
57
+
58
+ class MultiResolutionSTFTLoss(torch.nn.Module):
59
+ """Multi resolution STFT loss module."""
60
+
61
+ def __init__(self,
62
+ fft_sizes=[1024, 2048, 512],
63
+ hop_sizes=[120, 240, 50],
64
+ win_lengths=[600, 1200, 240],
65
+ window=torch.hann_window):
66
+ """Initialize Multi resolution STFT loss module.
67
+ Args:
68
+ fft_sizes (list): List of FFT sizes.
69
+ hop_sizes (list): List of hop sizes.
70
+ win_lengths (list): List of window lengths.
71
+ window (str): Window function type.
72
+ """
73
+ super(MultiResolutionSTFTLoss, self).__init__()
74
+ assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
75
+ self.stft_losses = torch.nn.ModuleList()
76
+ for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
77
+ self.stft_losses += [STFTLoss(fs, ss, wl, window)]
78
+
79
+ def forward(self, x, y):
80
+ """Calculate forward propagation.
81
+ Args:
82
+ x (Tensor): Predicted signal (B, T).
83
+ y (Tensor): Groundtruth signal (B, T).
84
+ Returns:
85
+ Tensor: Multi resolution spectral convergence loss value.
86
+ Tensor: Multi resolution log STFT magnitude loss value.
87
+ """
88
+ sc_loss = 0.0
89
+ for f in self.stft_losses:
90
+ sc_l = f(x, y)
91
+ sc_loss += sc_l
92
+ sc_loss /= len(self.stft_losses)
93
+
94
+ return sc_loss
95
+
96
+
97
+ def feature_loss(fmap_r, fmap_g):
98
+ loss = 0
99
+ for dr, dg in zip(fmap_r, fmap_g):
100
+ for rl, gl in zip(dr, dg):
101
+ loss += torch.mean(torch.abs(rl - gl))
102
+
103
+ return loss*2
104
+
105
+
106
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
107
+ loss = 0
108
+ r_losses = []
109
+ g_losses = []
110
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
111
+ r_loss = torch.mean((1-dr)**2)
112
+ g_loss = torch.mean(dg**2)
113
+ loss += (r_loss + g_loss)
114
+ r_losses.append(r_loss.item())
115
+ g_losses.append(g_loss.item())
116
+
117
+ return loss, r_losses, g_losses
118
+
119
+
120
+ def generator_loss(disc_outputs):
121
+ loss = 0
122
+ gen_losses = []
123
+ for dg in disc_outputs:
124
+ l = torch.mean((1-dg)**2)
125
+ gen_losses.append(l)
126
+ loss += l
127
+
128
+ return loss, gen_losses
129
+
130
+ """ https://dl.acm.org/doi/abs/10.1145/3573834.3574506 """
131
+ def discriminator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
132
+ loss = 0
133
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
134
+ tau = 0.04
135
+ m_DG = torch.median((dr-dg))
136
+ L_rel = torch.mean((((dr - dg) - m_DG)**2)[dr < dg + m_DG])
137
+ loss += tau - F.relu(tau - L_rel)
138
+ return loss
139
+
140
+ def generator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
141
+ loss = 0
142
+ for dg, dr in zip(disc_real_outputs, disc_generated_outputs):
143
+ tau = 0.04
144
+ m_DG = torch.median((dr-dg))
145
+ L_rel = torch.mean((((dr - dg) - m_DG)**2)[dr < dg + m_DG])
146
+ loss += tau - F.relu(tau - L_rel)
147
+ return loss
148
+
149
+ class GeneratorLoss(torch.nn.Module):
150
+
151
+ def __init__(self, mpd, msd):
152
+ super(GeneratorLoss, self).__init__()
153
+ self.mpd = mpd
154
+ self.msd = msd
155
+
156
+ def forward(self, y, y_hat):
157
+ y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, y_hat)
158
+ y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, y_hat)
159
+ loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
160
+ loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
161
+ loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
162
+ loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
163
+
164
+ loss_rel = generator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + generator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
165
+
166
+ loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_rel
167
+
168
+ return loss_gen_all.mean()
169
+
170
+ class DiscriminatorLoss(torch.nn.Module):
171
+
172
+ def __init__(self, mpd, msd):
173
+ super(DiscriminatorLoss, self).__init__()
174
+ self.mpd = mpd
175
+ self.msd = msd
176
+
177
+ def forward(self, y, y_hat):
178
+ # MPD
179
+ y_df_hat_r, y_df_hat_g, _, _ = self.mpd(y, y_hat)
180
+ loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
181
+ # MSD
182
+ y_ds_hat_r, y_ds_hat_g, _, _ = self.msd(y, y_hat)
183
+ loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
184
+
185
+ loss_rel = discriminator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + discriminator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
186
+
187
+
188
+ d_loss = loss_disc_s + loss_disc_f + loss_rel
189
+
190
+ return d_loss.mean()
191
+
192
+
193
+ class WavLMLoss(torch.nn.Module):
194
+
195
+ def __init__(self, model, wd, model_sr, slm_sr=16000):
196
+ super(WavLMLoss, self).__init__()
197
+ self.wavlm = AutoModel.from_pretrained(model)
198
+ self.wd = wd
199
+ self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)
200
+
201
+ def forward(self, wav, y_rec):
202
+ with torch.no_grad():
203
+ wav_16 = self.resample(wav)
204
+ wav_embeddings = self.wavlm(input_values=wav_16, output_hidden_states=True).hidden_states
205
+ y_rec_16 = self.resample(y_rec)
206
+ y_rec_embeddings = self.wavlm(input_values=y_rec_16.squeeze(), output_hidden_states=True).hidden_states
207
+
208
+ floss = 0
209
+ for er, eg in zip(wav_embeddings, y_rec_embeddings):
210
+ floss += torch.mean(torch.abs(er - eg))
211
+
212
+ return floss.mean()
213
+
214
+ def generator(self, y_rec):
215
+ y_rec_16 = self.resample(y_rec)
216
+ y_rec_embeddings = self.wavlm(input_values=y_rec_16, output_hidden_states=True).hidden_states
217
+ y_rec_embeddings = torch.stack(y_rec_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
218
+ y_df_hat_g = self.wd(y_rec_embeddings)
219
+ loss_gen = torch.mean((1-y_df_hat_g)**2)
220
+
221
+ return loss_gen
222
+
223
+ def discriminator(self, wav, y_rec):
224
+ with torch.no_grad():
225
+ wav_16 = self.resample(wav)
226
+ wav_embeddings = self.wavlm(input_values=wav_16, output_hidden_states=True).hidden_states
227
+ y_rec_16 = self.resample(y_rec)
228
+ y_rec_embeddings = self.wavlm(input_values=y_rec_16, output_hidden_states=True).hidden_states
229
+
230
+ y_embeddings = torch.stack(wav_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
231
+ y_rec_embeddings = torch.stack(y_rec_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
232
+
233
+ y_d_rs = self.wd(y_embeddings)
234
+ y_d_gs = self.wd(y_rec_embeddings)
235
+
236
+ y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs
237
+
238
+ r_loss = torch.mean((1-y_df_hat_r)**2)
239
+ g_loss = torch.mean((y_df_hat_g)**2)
240
+
241
+ loss_disc_f = r_loss + g_loss
242
+
243
+ return loss_disc_f.mean()
244
+
245
+ def discriminator_forward(self, wav):
246
+ with torch.no_grad():
247
+ wav_16 = self.resample(wav)
248
+ wav_embeddings = self.wavlm(input_values=wav_16, output_hidden_states=True).hidden_states
249
+ y_embeddings = torch.stack(wav_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
250
+
251
+ y_d_rs = self.wd(y_embeddings)
252
+
253
+ return y_d_rs
meldataset.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #coding: utf-8
2
+ import os
3
+ import os.path as osp
4
+ import time
5
+ import random
6
+ import numpy as np
7
+ import random
8
+ import soundfile as sf
9
+ import librosa
10
+ import re, unicodedata
11
+
12
+ import torch
13
+ from torch import nn
14
+ import torch.nn.functional as F
15
+ import torchaudio
16
+ from torch.utils.data import DataLoader
17
+
18
+ import logging
19
+ logger = logging.getLogger(__name__)
20
+ logger.setLevel(logging.DEBUG)
21
+
22
+ import pandas as pd
23
+
24
+ _pad = "$"
25
+ _punctuation = ';:,.!?¡¿—…"«»“” '
26
+ _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
27
+ _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
28
+
29
+ # Export all symbols:
30
+ symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
31
+
32
+ dicts = {}
33
+ for i in range(len((symbols))):
34
+ dicts[symbols[i]] = i
35
+
36
+ # class TextCleaner:
37
+ # def __init__(self, dummy=None):
38
+ # self.word_index_dictionary = dicts
39
+ # def __call__(self, text):
40
+ # indexes = []
41
+ # for char in text:
42
+ # try:
43
+ # indexes.append(self.word_index_dictionary[char])
44
+ # except KeyError:
45
+ # print(text)
46
+ # return indexes
47
+
48
+ class TextCleaner:
49
+ """
50
+ • Normalises text to NFC so pre-composed IPA glyphs match `symbols`.
51
+ • Splits on event tokens first (e.g. <evt_gasp>), then per-character.
52
+ • Unknown chars map to the <unk> symbol instead of printing.
53
+ """
54
+ _EVENT_RE = re.compile(r"<[^>]+>|.") # match <evt_xxx> or single char
55
+
56
+ def __init__(self):
57
+ # `dicts` must already include EVENT_TOKENS and "<unk>"
58
+ self.lookup = dicts
59
+ self.unk_id = 0
60
+
61
+ def __call__(self, text: str):
62
+ text = unicodedata.normalize("NFC", text)
63
+ ids = []
64
+ for tok in self._EVENT_RE.findall(text):
65
+ ids.append(self.lookup.get(tok, self.unk_id))
66
+ return ids
67
+
68
+
69
+ np.random.seed(1)
70
+ random.seed(1)
71
+ SPECT_PARAMS = {
72
+ "n_fft": 2048,
73
+ "win_length": 1200,
74
+ "hop_length": 300
75
+ }
76
+ MEL_PARAMS = {
77
+ "n_mels": 80,
78
+ }
79
+
80
+ to_mel = torchaudio.transforms.MelSpectrogram(
81
+ n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
82
+ mean, std = -4, 4
83
+
84
+ def preprocess(wave):
85
+ wave_tensor = torch.from_numpy(wave).float()
86
+ mel_tensor = to_mel(wave_tensor)
87
+ mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
88
+ return mel_tensor
89
+
90
+ class FilePathDataset(torch.utils.data.Dataset):
91
+ def __init__(self,
92
+ data_list,
93
+ root_path,
94
+ sr=24000,
95
+ data_augmentation=False,
96
+ validation=False,
97
+ OOD_data="Data/OOD_texts.txt",
98
+ min_length=50,
99
+ ):
100
+
101
+ spect_params = SPECT_PARAMS
102
+ mel_params = MEL_PARAMS
103
+
104
+ _data_list = [l.strip().split('|') for l in data_list]
105
+ self.data_list = [data if len(data) == 3 else (*data, 0) for data in _data_list]
106
+ self.text_cleaner = TextCleaner()
107
+ self.sr = sr
108
+
109
+ self.df = pd.DataFrame(self.data_list)
110
+
111
+ self.to_melspec = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS)
112
+
113
+ self.mean, self.std = -4, 4
114
+ self.data_augmentation = data_augmentation and (not validation)
115
+ self.max_mel_length = 192
116
+
117
+ self.min_length = min_length
118
+ with open(OOD_data, 'r', encoding='utf-8') as f:
119
+ tl = f.readlines()
120
+ idx = 1 if '.wav' in tl[0].split('|')[0] else 0
121
+ self.ptexts = [t.split('|')[idx] for t in tl]
122
+
123
+ self.root_path = root_path
124
+
125
+ def __len__(self):
126
+ return len(self.data_list)
127
+
128
+ def __getitem__(self, idx):
129
+ data = self.data_list[idx]
130
+ path = data[0]
131
+
132
+ wave, text_tensor, speaker_id = self._load_tensor(data)
133
+
134
+ mel_tensor = preprocess(wave).squeeze()
135
+
136
+ acoustic_feature = mel_tensor.squeeze()
137
+ length_feature = acoustic_feature.size(1)
138
+ acoustic_feature = acoustic_feature[:, :(length_feature - length_feature % 2)]
139
+
140
+ # get reference sample
141
+ ref_data = (self.df[self.df[2] == str(speaker_id)]).sample(n=1).iloc[0].tolist()
142
+ ref_mel_tensor, ref_label = self._load_data(ref_data[:3])
143
+
144
+ # get OOD text
145
+
146
+ ps = ""
147
+
148
+ while len(ps) < self.min_length:
149
+ rand_idx = np.random.randint(0, len(self.ptexts) - 1)
150
+ ps = self.ptexts[rand_idx]
151
+
152
+ text = self.text_cleaner(ps)
153
+ text.insert(0, 0)
154
+ text.append(0)
155
+
156
+ ref_text = torch.LongTensor(text)
157
+
158
+ return speaker_id, acoustic_feature, text_tensor, ref_text, ref_mel_tensor, ref_label, path, wave
159
+
160
+ def _load_tensor(self, data):
161
+ wave_path, text, speaker_id = data
162
+ speaker_id = int(speaker_id)
163
+ full_path = osp.join(self.root_path, wave_path)
164
+ try:
165
+ wave, sr = sf.read(full_path, dtype="float32")
166
+ except Exception as e:
167
+ print(f"[BAD] {full_path} -> {e}", flush=True)
168
+ raise
169
+ if wave.shape[-1] == 2:
170
+ wave = wave[:, 0].squeeze()
171
+ if sr != 24000:
172
+ wave = librosa.resample(wave, orig_sr=sr, target_sr=24000)
173
+ print(wave_path, sr)
174
+
175
+ wave = np.concatenate([np.zeros([5000]), wave, np.zeros([5000])], axis=0)
176
+
177
+ text = self.text_cleaner(text)
178
+
179
+ text.insert(0, 0)
180
+ text.append(0)
181
+
182
+ text = torch.LongTensor(text)
183
+
184
+ return wave, text, speaker_id
185
+
186
+ def _load_data(self, data):
187
+ wave, text_tensor, speaker_id = self._load_tensor(data)
188
+ mel_tensor = preprocess(wave).squeeze()
189
+
190
+ mel_length = mel_tensor.size(1)
191
+ if mel_length > self.max_mel_length:
192
+ random_start = np.random.randint(0, mel_length - self.max_mel_length)
193
+ mel_tensor = mel_tensor[:, random_start:random_start + self.max_mel_length]
194
+
195
+ return mel_tensor, speaker_id
196
+
197
+
198
+ class Collater(object):
199
+ """
200
+ Args:
201
+ adaptive_batch_size (bool): if true, decrease batch size when long data comes.
202
+ """
203
+
204
+ def __init__(self, return_wave=False):
205
+ self.text_pad_index = 0
206
+ self.min_mel_length = 192
207
+ self.max_mel_length = 192
208
+ self.return_wave = return_wave
209
+
210
+
211
+ def __call__(self, batch):
212
+ # batch[0] = wave, mel, text, f0, speakerid
213
+ batch_size = len(batch)
214
+
215
+ # sort by mel length
216
+ lengths = [b[1].shape[1] for b in batch]
217
+ batch_indexes = np.argsort(lengths)[::-1]
218
+ batch = [batch[bid] for bid in batch_indexes]
219
+
220
+ nmels = batch[0][1].size(0)
221
+ max_mel_length = max([b[1].shape[1] for b in batch])
222
+ max_text_length = max([b[2].shape[0] for b in batch])
223
+ max_rtext_length = max([b[3].shape[0] for b in batch])
224
+
225
+ labels = torch.zeros((batch_size)).long()
226
+ mels = torch.zeros((batch_size, nmels, max_mel_length)).float()
227
+ texts = torch.zeros((batch_size, max_text_length)).long()
228
+ ref_texts = torch.zeros((batch_size, max_rtext_length)).long()
229
+
230
+ input_lengths = torch.zeros(batch_size).long()
231
+ ref_lengths = torch.zeros(batch_size).long()
232
+ output_lengths = torch.zeros(batch_size).long()
233
+ ref_mels = torch.zeros((batch_size, nmels, self.max_mel_length)).float()
234
+ ref_labels = torch.zeros((batch_size)).long()
235
+ paths = ['' for _ in range(batch_size)]
236
+ waves = [None for _ in range(batch_size)]
237
+
238
+ for bid, (label, mel, text, ref_text, ref_mel, ref_label, path, wave) in enumerate(batch):
239
+ mel_size = mel.size(1)
240
+ text_size = text.size(0)
241
+ rtext_size = ref_text.size(0)
242
+ labels[bid] = label
243
+ mels[bid, :, :mel_size] = mel
244
+ texts[bid, :text_size] = text
245
+ ref_texts[bid, :rtext_size] = ref_text
246
+ input_lengths[bid] = text_size
247
+ ref_lengths[bid] = rtext_size
248
+ output_lengths[bid] = mel_size
249
+ paths[bid] = path
250
+ ref_mel_size = ref_mel.size(1)
251
+ ref_mels[bid, :, :ref_mel_size] = ref_mel
252
+
253
+ ref_labels[bid] = ref_label
254
+ waves[bid] = wave
255
+
256
+ return waves, texts, input_lengths, ref_texts, ref_lengths, mels, output_lengths, ref_mels
257
+
258
+
259
+
260
+ def build_dataloader(path_list,
261
+ root_path,
262
+ validation=False,
263
+ OOD_data="Data/OOD_texts.txt",
264
+ min_length=50,
265
+ batch_size=4,
266
+ num_workers=1,
267
+ device='cpu',
268
+ collate_config={},
269
+ dataset_config={}):
270
+
271
+ dataset = FilePathDataset(path_list, root_path, OOD_data=OOD_data, min_length=min_length, validation=validation, **dataset_config)
272
+ collate_fn = Collater(**collate_config)
273
+ data_loader = DataLoader(dataset,
274
+ batch_size=batch_size,
275
+ shuffle=(not validation),
276
+ num_workers=num_workers,
277
+ drop_last=(not validation),
278
+ collate_fn=collate_fn,
279
+ pin_memory=(device != 'cpu'))
280
+
281
+ return data_loader
282
+
models.py ADDED
@@ -0,0 +1,713 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #coding:utf-8
2
+
3
+ import os
4
+ import os.path as osp
5
+
6
+ import copy
7
+ import math
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
14
+
15
+ from Utils.ASR.models import ASRCNN
16
+ from Utils.JDC.model import JDCNet
17
+
18
+ from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution
19
+ from Modules.diffusion.modules import Transformer1d, StyleTransformer1d
20
+ from Modules.diffusion.diffusion import AudioDiffusionConditional
21
+
22
+ from Modules.discriminators import MultiPeriodDiscriminator, MultiResSpecDiscriminator, WavLMDiscriminator
23
+
24
+ from munch import Munch
25
+ import yaml
26
+
27
+ class LearnedDownSample(nn.Module):
28
+ def __init__(self, layer_type, dim_in):
29
+ super().__init__()
30
+ self.layer_type = layer_type
31
+
32
+ if self.layer_type == 'none':
33
+ self.conv = nn.Identity()
34
+ elif self.layer_type == 'timepreserve':
35
+ self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, padding=(1, 0)))
36
+ elif self.layer_type == 'half':
37
+ self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, padding=1))
38
+ else:
39
+ raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
40
+
41
+ def forward(self, x):
42
+ return self.conv(x)
43
+
44
+ class LearnedUpSample(nn.Module):
45
+ def __init__(self, layer_type, dim_in):
46
+ super().__init__()
47
+ self.layer_type = layer_type
48
+
49
+ if self.layer_type == 'none':
50
+ self.conv = nn.Identity()
51
+ elif self.layer_type == 'timepreserve':
52
+ self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, output_padding=(1, 0), padding=(1, 0))
53
+ elif self.layer_type == 'half':
54
+ self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, output_padding=1, padding=1)
55
+ else:
56
+ raise RuntimeError('Got unexpected upsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
57
+
58
+
59
+ def forward(self, x):
60
+ return self.conv(x)
61
+
62
+ class DownSample(nn.Module):
63
+ def __init__(self, layer_type):
64
+ super().__init__()
65
+ self.layer_type = layer_type
66
+
67
+ def forward(self, x):
68
+ if self.layer_type == 'none':
69
+ return x
70
+ elif self.layer_type == 'timepreserve':
71
+ return F.avg_pool2d(x, (2, 1))
72
+ elif self.layer_type == 'half':
73
+ if x.shape[-1] % 2 != 0:
74
+ x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
75
+ return F.avg_pool2d(x, 2)
76
+ else:
77
+ raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
78
+
79
+
80
+ class UpSample(nn.Module):
81
+ def __init__(self, layer_type):
82
+ super().__init__()
83
+ self.layer_type = layer_type
84
+
85
+ def forward(self, x):
86
+ if self.layer_type == 'none':
87
+ return x
88
+ elif self.layer_type == 'timepreserve':
89
+ return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
90
+ elif self.layer_type == 'half':
91
+ return F.interpolate(x, scale_factor=2, mode='nearest')
92
+ else:
93
+ raise RuntimeError('Got unexpected upsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
94
+
95
+
96
+ class ResBlk(nn.Module):
97
+ def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
98
+ normalize=False, downsample='none'):
99
+ super().__init__()
100
+ self.actv = actv
101
+ self.normalize = normalize
102
+ self.downsample = DownSample(downsample)
103
+ self.downsample_res = LearnedDownSample(downsample, dim_in)
104
+ self.learned_sc = dim_in != dim_out
105
+ self._build_weights(dim_in, dim_out)
106
+
107
+ def _build_weights(self, dim_in, dim_out):
108
+ self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
109
+ self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
110
+ if self.normalize:
111
+ self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
112
+ self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
113
+ if self.learned_sc:
114
+ self.conv1x1 = spectral_norm(nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False))
115
+
116
+ def _shortcut(self, x):
117
+ if self.learned_sc:
118
+ x = self.conv1x1(x)
119
+ if self.downsample:
120
+ x = self.downsample(x)
121
+ return x
122
+
123
+ def _residual(self, x):
124
+ if self.normalize:
125
+ x = self.norm1(x)
126
+ x = self.actv(x)
127
+ x = self.conv1(x)
128
+ x = self.downsample_res(x)
129
+ if self.normalize:
130
+ x = self.norm2(x)
131
+ x = self.actv(x)
132
+ x = self.conv2(x)
133
+ return x
134
+
135
+ def forward(self, x):
136
+ x = self._shortcut(x) + self._residual(x)
137
+ return x / math.sqrt(2) # unit variance
138
+
139
+ class StyleEncoder(nn.Module):
140
+ def __init__(self, dim_in=48, style_dim=48, max_conv_dim=384):
141
+ super().__init__()
142
+ blocks = []
143
+ blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
144
+
145
+ repeat_num = 4
146
+ for _ in range(repeat_num):
147
+ dim_out = min(dim_in*2, max_conv_dim)
148
+ blocks += [ResBlk(dim_in, dim_out, downsample='half')]
149
+ dim_in = dim_out
150
+
151
+ blocks += [nn.LeakyReLU(0.2)]
152
+ blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
153
+ blocks += [nn.AdaptiveAvgPool2d(1)]
154
+ blocks += [nn.LeakyReLU(0.2)]
155
+ self.shared = nn.Sequential(*blocks)
156
+
157
+ self.unshared = nn.Linear(dim_out, style_dim)
158
+
159
+ def forward(self, x):
160
+ h = self.shared(x)
161
+ h = h.view(h.size(0), -1)
162
+ s = self.unshared(h)
163
+
164
+ return s
165
+
166
+ class LinearNorm(torch.nn.Module):
167
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
168
+ super(LinearNorm, self).__init__()
169
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
170
+
171
+ torch.nn.init.xavier_uniform_(
172
+ self.linear_layer.weight,
173
+ gain=torch.nn.init.calculate_gain(w_init_gain))
174
+
175
+ def forward(self, x):
176
+ return self.linear_layer(x)
177
+
178
+ class Discriminator2d(nn.Module):
179
+ def __init__(self, dim_in=48, num_domains=1, max_conv_dim=384, repeat_num=4):
180
+ super().__init__()
181
+ blocks = []
182
+ blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
183
+
184
+ for lid in range(repeat_num):
185
+ dim_out = min(dim_in*2, max_conv_dim)
186
+ blocks += [ResBlk(dim_in, dim_out, downsample='half')]
187
+ dim_in = dim_out
188
+
189
+ blocks += [nn.LeakyReLU(0.2)]
190
+ blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
191
+ blocks += [nn.LeakyReLU(0.2)]
192
+ blocks += [nn.AdaptiveAvgPool2d(1)]
193
+ blocks += [spectral_norm(nn.Conv2d(dim_out, num_domains, 1, 1, 0))]
194
+ self.main = nn.Sequential(*blocks)
195
+
196
+ def get_feature(self, x):
197
+ features = []
198
+ for l in self.main:
199
+ x = l(x)
200
+ features.append(x)
201
+ out = features[-1]
202
+ out = out.view(out.size(0), -1) # (batch, num_domains)
203
+ return out, features
204
+
205
+ def forward(self, x):
206
+ out, features = self.get_feature(x)
207
+ out = out.squeeze() # (batch)
208
+ return out, features
209
+
210
+ class ResBlk1d(nn.Module):
211
+ def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
212
+ normalize=False, downsample='none', dropout_p=0.2):
213
+ super().__init__()
214
+ self.actv = actv
215
+ self.normalize = normalize
216
+ self.downsample_type = downsample
217
+ self.learned_sc = dim_in != dim_out
218
+ self._build_weights(dim_in, dim_out)
219
+ self.dropout_p = dropout_p
220
+
221
+ if self.downsample_type == 'none':
222
+ self.pool = nn.Identity()
223
+ else:
224
+ self.pool = weight_norm(nn.Conv1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1))
225
+
226
+ def _build_weights(self, dim_in, dim_out):
227
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_in, 3, 1, 1))
228
+ self.conv2 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
229
+ if self.normalize:
230
+ self.norm1 = nn.InstanceNorm1d(dim_in, affine=True)
231
+ self.norm2 = nn.InstanceNorm1d(dim_in, affine=True)
232
+ if self.learned_sc:
233
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
234
+
235
+ def downsample(self, x):
236
+ if self.downsample_type == 'none':
237
+ return x
238
+ else:
239
+ if x.shape[-1] % 2 != 0:
240
+ x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
241
+ return F.avg_pool1d(x, 2)
242
+
243
+ def _shortcut(self, x):
244
+ if self.learned_sc:
245
+ x = self.conv1x1(x)
246
+ x = self.downsample(x)
247
+ return x
248
+
249
+ def _residual(self, x):
250
+ if self.normalize:
251
+ x = self.norm1(x)
252
+ x = self.actv(x)
253
+ x = F.dropout(x, p=self.dropout_p, training=self.training)
254
+
255
+ x = self.conv1(x)
256
+ x = self.pool(x)
257
+ if self.normalize:
258
+ x = self.norm2(x)
259
+
260
+ x = self.actv(x)
261
+ x = F.dropout(x, p=self.dropout_p, training=self.training)
262
+
263
+ x = self.conv2(x)
264
+ return x
265
+
266
+ def forward(self, x):
267
+ x = self._shortcut(x) + self._residual(x)
268
+ return x / math.sqrt(2) # unit variance
269
+
270
+ class LayerNorm(nn.Module):
271
+ def __init__(self, channels, eps=1e-5):
272
+ super().__init__()
273
+ self.channels = channels
274
+ self.eps = eps
275
+
276
+ self.gamma = nn.Parameter(torch.ones(channels))
277
+ self.beta = nn.Parameter(torch.zeros(channels))
278
+
279
+ def forward(self, x):
280
+ x = x.transpose(1, -1)
281
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
282
+ return x.transpose(1, -1)
283
+
284
+ class TextEncoder(nn.Module):
285
+ def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
286
+ super().__init__()
287
+ self.embedding = nn.Embedding(n_symbols, channels)
288
+
289
+ padding = (kernel_size - 1) // 2
290
+ self.cnn = nn.ModuleList()
291
+ for _ in range(depth):
292
+ self.cnn.append(nn.Sequential(
293
+ weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
294
+ LayerNorm(channels),
295
+ actv,
296
+ nn.Dropout(0.2),
297
+ ))
298
+ # self.cnn = nn.Sequential(*self.cnn)
299
+
300
+ self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
301
+
302
+ def forward(self, x, input_lengths, m):
303
+ x = self.embedding(x) # [B, T, emb]
304
+ x = x.transpose(1, 2) # [B, emb, T]
305
+ m = m.to(input_lengths.device).unsqueeze(1)
306
+ x.masked_fill_(m, 0.0)
307
+
308
+ for c in self.cnn:
309
+ x = c(x)
310
+ x.masked_fill_(m, 0.0)
311
+
312
+ x = x.transpose(1, 2) # [B, T, chn]
313
+
314
+ input_lengths = input_lengths.cpu().numpy()
315
+ x = nn.utils.rnn.pack_padded_sequence(
316
+ x, input_lengths, batch_first=True, enforce_sorted=False)
317
+
318
+ self.lstm.flatten_parameters()
319
+ x, _ = self.lstm(x)
320
+ x, _ = nn.utils.rnn.pad_packed_sequence(
321
+ x, batch_first=True)
322
+
323
+ x = x.transpose(-1, -2)
324
+ x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
325
+
326
+ x_pad[:, :, :x.shape[-1]] = x
327
+ x = x_pad.to(x.device)
328
+
329
+ x.masked_fill_(m, 0.0)
330
+
331
+ return x
332
+
333
+ def inference(self, x):
334
+ x = self.embedding(x)
335
+ x = x.transpose(1, 2)
336
+ x = self.cnn(x)
337
+ x = x.transpose(1, 2)
338
+ self.lstm.flatten_parameters()
339
+ x, _ = self.lstm(x)
340
+ return x
341
+
342
+ def length_to_mask(self, lengths):
343
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
344
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
345
+ return mask
346
+
347
+
348
+
349
+ class AdaIN1d(nn.Module):
350
+ def __init__(self, style_dim, num_features):
351
+ super().__init__()
352
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
353
+ self.fc = nn.Linear(style_dim, num_features*2)
354
+
355
+ def forward(self, x, s):
356
+ h = self.fc(s)
357
+ h = h.view(h.size(0), h.size(1), 1)
358
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
359
+ return (1 + gamma) * self.norm(x) + beta
360
+
361
+ class UpSample1d(nn.Module):
362
+ def __init__(self, layer_type):
363
+ super().__init__()
364
+ self.layer_type = layer_type
365
+
366
+ def forward(self, x):
367
+ if self.layer_type == 'none':
368
+ return x
369
+ else:
370
+ return F.interpolate(x, scale_factor=2, mode='nearest')
371
+
372
+ class AdainResBlk1d(nn.Module):
373
+ def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
374
+ upsample='none', dropout_p=0.0):
375
+ super().__init__()
376
+ self.actv = actv
377
+ self.upsample_type = upsample
378
+ self.upsample = UpSample1d(upsample)
379
+ self.learned_sc = dim_in != dim_out
380
+ self._build_weights(dim_in, dim_out, style_dim)
381
+ self.dropout = nn.Dropout(dropout_p)
382
+
383
+ if upsample == 'none':
384
+ self.pool = nn.Identity()
385
+ else:
386
+ self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
387
+
388
+
389
+ def _build_weights(self, dim_in, dim_out, style_dim):
390
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
391
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
392
+ self.norm1 = AdaIN1d(style_dim, dim_in)
393
+ self.norm2 = AdaIN1d(style_dim, dim_out)
394
+ if self.learned_sc:
395
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
396
+
397
+ def _shortcut(self, x):
398
+ x = self.upsample(x)
399
+ if self.learned_sc:
400
+ x = self.conv1x1(x)
401
+ return x
402
+
403
+ def _residual(self, x, s):
404
+ x = self.norm1(x, s)
405
+ x = self.actv(x)
406
+ x = self.pool(x)
407
+ x = self.conv1(self.dropout(x))
408
+ x = self.norm2(x, s)
409
+ x = self.actv(x)
410
+ x = self.conv2(self.dropout(x))
411
+ return x
412
+
413
+ def forward(self, x, s):
414
+ out = self._residual(x, s)
415
+ out = (out + self._shortcut(x)) / math.sqrt(2)
416
+ return out
417
+
418
+ class AdaLayerNorm(nn.Module):
419
+ def __init__(self, style_dim, channels, eps=1e-5):
420
+ super().__init__()
421
+ self.channels = channels
422
+ self.eps = eps
423
+
424
+ self.fc = nn.Linear(style_dim, channels*2)
425
+
426
+ def forward(self, x, s):
427
+ x = x.transpose(-1, -2)
428
+ x = x.transpose(1, -1)
429
+
430
+ h = self.fc(s)
431
+ h = h.view(h.size(0), h.size(1), 1)
432
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
433
+ gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
434
+
435
+
436
+ x = F.layer_norm(x, (self.channels,), eps=self.eps)
437
+ x = (1 + gamma) * x + beta
438
+ return x.transpose(1, -1).transpose(-1, -2)
439
+
440
+ class ProsodyPredictor(nn.Module):
441
+
442
+ def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
443
+ super().__init__()
444
+
445
+ self.text_encoder = DurationEncoder(sty_dim=style_dim,
446
+ d_model=d_hid,
447
+ nlayers=nlayers,
448
+ dropout=dropout)
449
+
450
+ self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
451
+ self.duration_proj = LinearNorm(d_hid, max_dur)
452
+
453
+ self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
454
+ self.F0 = nn.ModuleList()
455
+ self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
456
+ self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
457
+ self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
458
+
459
+ self.N = nn.ModuleList()
460
+ self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
461
+ self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
462
+ self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
463
+
464
+ self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
465
+ self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
466
+
467
+
468
+ def forward(self, texts, style, text_lengths, alignment, m):
469
+ d = self.text_encoder(texts, style, text_lengths, m)
470
+
471
+ batch_size = d.shape[0]
472
+ text_size = d.shape[1]
473
+
474
+ # predict duration
475
+ input_lengths = text_lengths.cpu().numpy()
476
+ x = nn.utils.rnn.pack_padded_sequence(
477
+ d, input_lengths, batch_first=True, enforce_sorted=False)
478
+
479
+ m = m.to(text_lengths.device).unsqueeze(1)
480
+
481
+ self.lstm.flatten_parameters()
482
+ x, _ = self.lstm(x)
483
+ x, _ = nn.utils.rnn.pad_packed_sequence(
484
+ x, batch_first=True)
485
+
486
+ x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
487
+
488
+ x_pad[:, :x.shape[1], :] = x
489
+ x = x_pad.to(x.device)
490
+
491
+ duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
492
+
493
+ en = (d.transpose(-1, -2) @ alignment)
494
+
495
+ return duration.squeeze(-1), en
496
+
497
+ def F0Ntrain(self, x, s):
498
+ x, _ = self.shared(x.transpose(-1, -2))
499
+
500
+ F0 = x.transpose(-1, -2)
501
+ for block in self.F0:
502
+ F0 = block(F0, s)
503
+ F0 = self.F0_proj(F0)
504
+
505
+ N = x.transpose(-1, -2)
506
+ for block in self.N:
507
+ N = block(N, s)
508
+ N = self.N_proj(N)
509
+
510
+ return F0.squeeze(1), N.squeeze(1)
511
+
512
+ def length_to_mask(self, lengths):
513
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
514
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
515
+ return mask
516
+
517
+ class DurationEncoder(nn.Module):
518
+
519
+ def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
520
+ super().__init__()
521
+ self.lstms = nn.ModuleList()
522
+ for _ in range(nlayers):
523
+ self.lstms.append(nn.LSTM(d_model + sty_dim,
524
+ d_model // 2,
525
+ num_layers=1,
526
+ batch_first=True,
527
+ bidirectional=True,
528
+ dropout=dropout))
529
+ self.lstms.append(AdaLayerNorm(sty_dim, d_model))
530
+
531
+
532
+ self.dropout = dropout
533
+ self.d_model = d_model
534
+ self.sty_dim = sty_dim
535
+
536
+ def forward(self, x, style, text_lengths, m):
537
+ masks = m.to(text_lengths.device)
538
+
539
+ x = x.permute(2, 0, 1)
540
+ s = style.expand(x.shape[0], x.shape[1], -1)
541
+ x = torch.cat([x, s], axis=-1)
542
+ x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
543
+
544
+ x = x.transpose(0, 1)
545
+ input_lengths = text_lengths.cpu().numpy()
546
+ x = x.transpose(-1, -2)
547
+
548
+ for block in self.lstms:
549
+ if isinstance(block, AdaLayerNorm):
550
+ x = block(x.transpose(-1, -2), style).transpose(-1, -2)
551
+ x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
552
+ x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
553
+ else:
554
+ x = x.transpose(-1, -2)
555
+ x = nn.utils.rnn.pack_padded_sequence(
556
+ x, input_lengths, batch_first=True, enforce_sorted=False)
557
+ block.flatten_parameters()
558
+ x, _ = block(x)
559
+ x, _ = nn.utils.rnn.pad_packed_sequence(
560
+ x, batch_first=True)
561
+ x = F.dropout(x, p=self.dropout, training=self.training)
562
+ x = x.transpose(-1, -2)
563
+
564
+ x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
565
+
566
+ x_pad[:, :, :x.shape[-1]] = x
567
+ x = x_pad.to(x.device)
568
+
569
+ return x.transpose(-1, -2)
570
+
571
+ def inference(self, x, style):
572
+ x = self.embedding(x.transpose(-1, -2)) * math.sqrt(self.d_model)
573
+ style = style.expand(x.shape[0], x.shape[1], -1)
574
+ x = torch.cat([x, style], axis=-1)
575
+ src = self.pos_encoder(x)
576
+ output = self.transformer_encoder(src).transpose(0, 1)
577
+ return output
578
+
579
+ def length_to_mask(self, lengths):
580
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
581
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
582
+ return mask
583
+
584
+ def load_F0_models(path):
585
+ # load F0 model
586
+
587
+ F0_model = JDCNet(num_class=1, seq_len=192)
588
+ params = torch.load(path, map_location='cpu')['net']
589
+ F0_model.load_state_dict(params)
590
+ _ = F0_model.train()
591
+
592
+ return F0_model
593
+
594
+ def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
595
+ # load ASR model
596
+ def _load_config(path):
597
+ with open(path) as f:
598
+ config = yaml.safe_load(f)
599
+ model_config = config['model_params']
600
+ return model_config
601
+
602
+ def _load_model(model_config, model_path):
603
+ model = ASRCNN(**model_config)
604
+ params = torch.load(model_path, map_location='cpu', weights_only=False)['model']
605
+ model.load_state_dict(params)
606
+ return model
607
+
608
+ asr_model_config = _load_config(ASR_MODEL_CONFIG)
609
+ asr_model = _load_model(asr_model_config, ASR_MODEL_PATH)
610
+ _ = asr_model.train()
611
+
612
+ return asr_model
613
+
614
+ def build_model(args, text_aligner, pitch_extractor, bert):
615
+ assert args.decoder.type in ['istftnet', 'hifigan'], 'Decoder type unknown'
616
+
617
+ if args.decoder.type == "istftnet":
618
+ from Modules.istftnet import Decoder
619
+ decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
620
+ resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
621
+ upsample_rates = args.decoder.upsample_rates,
622
+ upsample_initial_channel=args.decoder.upsample_initial_channel,
623
+ resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
624
+ upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
625
+ gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
626
+ else:
627
+ from Modules.hifigan import Decoder
628
+ decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
629
+ resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
630
+ upsample_rates = args.decoder.upsample_rates,
631
+ upsample_initial_channel=args.decoder.upsample_initial_channel,
632
+ resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
633
+ upsample_kernel_sizes=args.decoder.upsample_kernel_sizes)
634
+
635
+ text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
636
+
637
+ predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
638
+
639
+ style_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim) # acoustic style encoder
640
+ predictor_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim) # prosodic style encoder
641
+
642
+ # define diffusion model
643
+ if args.multispeaker:
644
+ transformer = StyleTransformer1d(channels=args.style_dim*2,
645
+ context_embedding_features=bert.config.hidden_size,
646
+ context_features=args.style_dim*2,
647
+ **args.diffusion.transformer)
648
+ else:
649
+ transformer = Transformer1d(channels=args.style_dim*2,
650
+ context_embedding_features=bert.config.hidden_size,
651
+ **args.diffusion.transformer)
652
+
653
+ diffusion = AudioDiffusionConditional(
654
+ in_channels=1,
655
+ embedding_max_length=bert.config.max_position_embeddings,
656
+ embedding_features=bert.config.hidden_size,
657
+ embedding_mask_proba=args.diffusion.embedding_mask_proba, # Conditional dropout of batch elements,
658
+ channels=args.style_dim*2,
659
+ context_features=args.style_dim*2,
660
+ )
661
+
662
+ diffusion.diffusion = KDiffusion(
663
+ net=diffusion.unet,
664
+ sigma_distribution=LogNormalDistribution(mean = args.diffusion.dist.mean, std = args.diffusion.dist.std),
665
+ sigma_data=args.diffusion.dist.sigma_data, # a placeholder, will be changed dynamically when start training diffusion model
666
+ dynamic_threshold=0.0
667
+ )
668
+ diffusion.diffusion.net = transformer
669
+ diffusion.unet = transformer
670
+
671
+
672
+ nets = Munch(
673
+ bert=bert,
674
+ bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim),
675
+
676
+ predictor=predictor,
677
+ decoder=decoder,
678
+ text_encoder=text_encoder,
679
+
680
+ predictor_encoder=predictor_encoder,
681
+ style_encoder=style_encoder,
682
+ diffusion=diffusion,
683
+
684
+ text_aligner = text_aligner,
685
+ pitch_extractor=pitch_extractor,
686
+
687
+ mpd = MultiPeriodDiscriminator(),
688
+ msd = MultiResSpecDiscriminator(),
689
+
690
+ # slm discriminator head
691
+ wd = WavLMDiscriminator(args.slm.hidden, args.slm.nlayers, args.slm.initial_channel),
692
+ )
693
+
694
+ return nets
695
+
696
+ def load_checkpoint(model, optimizer, path, load_only_params=True, ignore_modules=[]):
697
+ state = torch.load(path, map_location='cpu')
698
+ params = state['net']
699
+ for key in model:
700
+ if key in params and key not in ignore_modules:
701
+ print('%s loaded' % key)
702
+ model[key].load_state_dict(params[key], strict=False)
703
+ _ = [model[key].eval() for key in model]
704
+
705
+ if not load_only_params:
706
+ epoch = state["epoch"]
707
+ iters = state["iters"]
708
+ optimizer.load_state_dict(state["optimizer"])
709
+ else:
710
+ epoch = 0
711
+ iters = 0
712
+
713
+ return model, optimizer, epoch, iters
optimizers.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #coding:utf-8
2
+ import os, sys
3
+ import os.path as osp
4
+ import numpy as np
5
+ import torch
6
+ from torch import nn
7
+ from torch.optim import Optimizer
8
+ from functools import reduce
9
+ from torch.optim import AdamW
10
+
11
+ class MultiOptimizer:
12
+ def __init__(self, optimizers={}, schedulers={}):
13
+ self.optimizers = optimizers
14
+ self.schedulers = schedulers
15
+ self.keys = list(optimizers.keys())
16
+ self.param_groups = reduce(lambda x,y: x+y, [v.param_groups for v in self.optimizers.values()])
17
+
18
+ def state_dict(self):
19
+ state_dicts = [(key, self.optimizers[key].state_dict())\
20
+ for key in self.keys]
21
+ return state_dicts
22
+
23
+ def load_state_dict(self, state_dict):
24
+ for key, val in state_dict:
25
+ try:
26
+ self.optimizers[key].load_state_dict(val)
27
+ except:
28
+ print("Unloaded %s" % key)
29
+
30
+ def step(self, key=None, scaler=None):
31
+ keys = [key] if key is not None else self.keys
32
+ _ = [self._step(key, scaler) for key in keys]
33
+
34
+ def _step(self, key, scaler=None):
35
+ if scaler is not None:
36
+ scaler.step(self.optimizers[key])
37
+ scaler.update()
38
+ else:
39
+ self.optimizers[key].step()
40
+
41
+ def zero_grad(self, key=None):
42
+ if key is not None:
43
+ self.optimizers[key].zero_grad()
44
+ else:
45
+ _ = [self.optimizers[key].zero_grad() for key in self.keys]
46
+
47
+ def scheduler(self, *args, key=None):
48
+ if key is not None:
49
+ self.schedulers[key].step(*args)
50
+ else:
51
+ _ = [self.schedulers[key].step(*args) for key in self.keys]
52
+
53
+ def define_scheduler(optimizer, params):
54
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(
55
+ optimizer,
56
+ max_lr=params.get('max_lr', 2e-4),
57
+ epochs=params.get('epochs', 200),
58
+ steps_per_epoch=params.get('steps_per_epoch', 1000),
59
+ pct_start=params.get('pct_start', 0.0),
60
+ div_factor=1,
61
+ final_div_factor=1)
62
+
63
+ return scheduler
64
+
65
+ def build_optimizer(parameters_dict, scheduler_params_dict, lr):
66
+ optim = dict([(key, AdamW(params, lr=lr, weight_decay=1e-4, betas=(0.0, 0.99), eps=1e-9))
67
+ for key, params in parameters_dict.items()])
68
+
69
+ schedulers = dict([(key, define_scheduler(opt, scheduler_params_dict[key])) \
70
+ for key, opt in optim.items()])
71
+
72
+ multi_optim = MultiOptimizer(optim, schedulers)
73
+ return multi_optim
preview.wav ADDED
Binary file (28.9 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ SoundFile
2
+ torchaudio
3
+ munch
4
+ torch
5
+ pydub
6
+ pyyaml
7
+ librosa
8
+ nltk
9
+ matplotlib
10
+ accelerate
11
+ transformers
12
+ einops
13
+ einops-exts
14
+ tqdm
15
+ typing
16
+ typing-extensions
17
+ git+https://github.com/resemble-ai/monotonic_align.git
text_utils.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # IPA Phonemizer: https://github.com/bootphon/phonemizer
2
+
3
+ _pad = "$"
4
+ _punctuation = ';:,.!?¡¿—…"«»“” '
5
+ _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
6
+ _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
7
+
8
+ # Export all symbols:
9
+ symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
10
+
11
+ dicts = {}
12
+ for i in range(len((symbols))):
13
+ dicts[symbols[i]] = i
14
+
15
+ class TextCleaner:
16
+ def __init__(self, dummy=None):
17
+ self.word_index_dictionary = dicts
18
+ print(len(dicts))
19
+ def __call__(self, text):
20
+ indexes = []
21
+ for char in text:
22
+ try:
23
+ indexes.append(self.word_index_dictionary[char])
24
+ except KeyError:
25
+ print(text)
26
+ return indexes
train_finetune.py ADDED
@@ -0,0 +1,707 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # load packages
2
+ import random
3
+ import yaml
4
+ import time
5
+ from munch import Munch
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ import torchaudio
11
+ import librosa
12
+ import click
13
+ import shutil
14
+ import warnings
15
+ warnings.simplefilter('ignore')
16
+ from torch.utils.tensorboard import SummaryWriter
17
+
18
+ from meldataset import build_dataloader
19
+
20
+ from Utils.ASR.models import ASRCNN
21
+ from Utils.JDC.model import JDCNet
22
+ from Utils.PLBERT.util import load_plbert
23
+
24
+ from models import *
25
+ from losses import *
26
+ from utils import *
27
+
28
+ from Modules.slmadv import SLMAdversarialLoss
29
+ from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
30
+
31
+ from optimizers import build_optimizer
32
+
33
+ # simple fix for dataparallel that allows access to class attributes
34
+ class MyDataParallel(torch.nn.DataParallel):
35
+ def __getattr__(self, name):
36
+ try:
37
+ return super().__getattr__(name)
38
+ except AttributeError:
39
+ return getattr(self.module, name)
40
+
41
+ import logging
42
+ from logging import StreamHandler
43
+ logger = logging.getLogger(__name__)
44
+ logger.setLevel(logging.DEBUG)
45
+ handler = StreamHandler()
46
+ handler.setLevel(logging.DEBUG)
47
+ logger.addHandler(handler)
48
+
49
+
50
+ @click.command()
51
+ @click.option('-p', '--config_path', default='Configs/config_ft.yml', type=str)
52
+ def main(config_path):
53
+ config = yaml.safe_load(open(config_path))
54
+
55
+ log_dir = config['log_dir']
56
+ if not osp.exists(log_dir): os.makedirs(log_dir, exist_ok=True)
57
+ shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
58
+ writer = SummaryWriter(log_dir + "/tensorboard")
59
+
60
+ # write logs
61
+ file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
62
+ file_handler.setLevel(logging.DEBUG)
63
+ file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
64
+ logger.addHandler(file_handler)
65
+
66
+
67
+ batch_size = config.get('batch_size', 10)
68
+
69
+ epochs = config.get('epochs', 200)
70
+ save_freq = config.get('save_freq', 2)
71
+ log_interval = config.get('log_interval', 10)
72
+ saving_epoch = config.get('save_freq', 2)
73
+
74
+ data_params = config.get('data_params', None)
75
+ sr = config['preprocess_params'].get('sr', 24000)
76
+ train_path = data_params['train_data']
77
+ val_path = data_params['val_data']
78
+ root_path = data_params['root_path']
79
+ min_length = data_params['min_length']
80
+ OOD_data = data_params['OOD_data']
81
+
82
+ max_len = config.get('max_len', 200)
83
+
84
+ loss_params = Munch(config['loss_params'])
85
+ diff_epoch = loss_params.diff_epoch
86
+ joint_epoch = loss_params.joint_epoch
87
+
88
+ optimizer_params = Munch(config['optimizer_params'])
89
+
90
+ train_list, val_list = get_data_path_list(train_path, val_path)
91
+ device = 'cuda'
92
+
93
+ train_dataloader = build_dataloader(train_list,
94
+ root_path,
95
+ OOD_data=OOD_data,
96
+ min_length=min_length,
97
+ batch_size=batch_size,
98
+ num_workers=2,
99
+ dataset_config={},
100
+ device=device)
101
+
102
+ val_dataloader = build_dataloader(val_list,
103
+ root_path,
104
+ OOD_data=OOD_data,
105
+ min_length=min_length,
106
+ batch_size=batch_size,
107
+ validation=True,
108
+ num_workers=0,
109
+ device=device,
110
+ dataset_config={})
111
+
112
+ # load pretrained ASR model
113
+ ASR_config = config.get('ASR_config', False)
114
+ ASR_path = config.get('ASR_path', False)
115
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
116
+
117
+ # load pretrained F0 model
118
+ F0_path = config.get('F0_path', False)
119
+ pitch_extractor = load_F0_models(F0_path)
120
+
121
+ # load PL-BERT model
122
+ BERT_path = config.get('PLBERT_dir', False)
123
+ plbert = load_plbert(BERT_path)
124
+
125
+ # build model
126
+ model_params = recursive_munch(config['model_params'])
127
+ multispeaker = model_params.multispeaker
128
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
129
+ _ = [model[key].to(device) for key in model]
130
+
131
+ # DP
132
+ for key in model:
133
+ if key != "mpd" and key != "msd" and key != "wd":
134
+ model[key] = MyDataParallel(model[key])
135
+
136
+ start_epoch = 0
137
+ iters = 0
138
+
139
+ load_pretrained = config.get('pretrained_model', '') != '' and config.get('second_stage_load_pretrained', False)
140
+
141
+ if not load_pretrained:
142
+ if config.get('first_stage_path', '') != '':
143
+ first_stage_path = osp.join(log_dir, config.get('first_stage_path', 'first_stage.pth'))
144
+ print('Loading the first stage model at %s ...' % first_stage_path)
145
+ model, _, start_epoch, iters = load_checkpoint(model,
146
+ None,
147
+ first_stage_path,
148
+ load_only_params=True,
149
+ ignore_modules=['bert', 'bert_encoder', 'predictor', 'predictor_encoder', 'msd', 'mpd', 'wd', 'diffusion']) # keep starting epoch for tensorboard log
150
+
151
+ # these epochs should be counted from the start epoch
152
+ diff_epoch += start_epoch
153
+ joint_epoch += start_epoch
154
+ epochs += start_epoch
155
+
156
+ model.predictor_encoder = copy.deepcopy(model.style_encoder)
157
+ else:
158
+ raise ValueError('You need to specify the path to the first stage model.')
159
+
160
+ gl = GeneratorLoss(model.mpd, model.msd).to(device)
161
+ dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
162
+ wl = WavLMLoss(model_params.slm.model,
163
+ model.wd,
164
+ sr,
165
+ model_params.slm.sr).to(device)
166
+
167
+ gl = MyDataParallel(gl)
168
+ dl = MyDataParallel(dl)
169
+ wl = MyDataParallel(wl)
170
+
171
+ sampler = DiffusionSampler(
172
+ model.diffusion.diffusion,
173
+ sampler=ADPM2Sampler(),
174
+ sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
175
+ clamp=False
176
+ )
177
+
178
+ scheduler_params = {
179
+ "max_lr": optimizer_params.lr,
180
+ "pct_start": float(0),
181
+ "epochs": epochs,
182
+ "steps_per_epoch": len(train_dataloader),
183
+ }
184
+ scheduler_params_dict= {key: scheduler_params.copy() for key in model}
185
+ scheduler_params_dict['bert']['max_lr'] = optimizer_params.bert_lr * 2
186
+ scheduler_params_dict['decoder']['max_lr'] = optimizer_params.ft_lr * 2
187
+ scheduler_params_dict['style_encoder']['max_lr'] = optimizer_params.ft_lr * 2
188
+
189
+ optimizer = build_optimizer({key: model[key].parameters() for key in model},
190
+ scheduler_params_dict=scheduler_params_dict, lr=optimizer_params.lr)
191
+
192
+ # adjust BERT learning rate
193
+ for g in optimizer.optimizers['bert'].param_groups:
194
+ g['betas'] = (0.9, 0.99)
195
+ g['lr'] = optimizer_params.bert_lr
196
+ g['initial_lr'] = optimizer_params.bert_lr
197
+ g['min_lr'] = 0
198
+ g['weight_decay'] = 0.01
199
+
200
+ # adjust acoustic module learning rate
201
+ for module in ["decoder", "style_encoder"]:
202
+ for g in optimizer.optimizers[module].param_groups:
203
+ g['betas'] = (0.0, 0.99)
204
+ g['lr'] = optimizer_params.ft_lr
205
+ g['initial_lr'] = optimizer_params.ft_lr
206
+ g['min_lr'] = 0
207
+ g['weight_decay'] = 1e-4
208
+
209
+ # load models if there is a model
210
+ if load_pretrained:
211
+ model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'],
212
+ load_only_params=config.get('load_only_params', True))
213
+
214
+ n_down = model.text_aligner.n_down
215
+
216
+ best_loss = float('inf') # best test loss
217
+ loss_train_record = list([])
218
+ loss_test_record = list([])
219
+ iters = 0
220
+
221
+ criterion = nn.L1Loss() # F0 loss (regression)
222
+ torch.cuda.empty_cache()
223
+
224
+ stft_loss = MultiResolutionSTFTLoss().to(device)
225
+
226
+ print('BERT', optimizer.optimizers['bert'])
227
+ print('decoder', optimizer.optimizers['decoder'])
228
+
229
+ start_ds = False
230
+
231
+ running_std = []
232
+
233
+ slmadv_params = Munch(config['slmadv_params'])
234
+ slmadv = SLMAdversarialLoss(model, wl, sampler,
235
+ slmadv_params.min_len,
236
+ slmadv_params.max_len,
237
+ batch_percentage=slmadv_params.batch_percentage,
238
+ skip_update=slmadv_params.iter,
239
+ sig=slmadv_params.sig
240
+ )
241
+
242
+
243
+ for epoch in range(start_epoch, epochs):
244
+ running_loss = 0
245
+ start_time = time.time()
246
+
247
+ _ = [model[key].eval() for key in model]
248
+
249
+ model.text_aligner.train()
250
+ model.text_encoder.train()
251
+
252
+ model.predictor.train()
253
+ model.bert_encoder.train()
254
+ model.bert.train()
255
+ model.msd.train()
256
+ model.mpd.train()
257
+
258
+ for i, batch in enumerate(train_dataloader):
259
+ waves = batch[0]
260
+ batch = [b.to(device) for b in batch[1:]]
261
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
262
+ with torch.no_grad():
263
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
264
+ mel_mask = length_to_mask(mel_input_length).to(device)
265
+ text_mask = length_to_mask(input_lengths).to(texts.device)
266
+
267
+ # compute reference styles
268
+ if multispeaker and epoch >= diff_epoch:
269
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
270
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
271
+ ref = torch.cat([ref_ss, ref_sp], dim=1)
272
+
273
+ try:
274
+ ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
275
+ s2s_attn = s2s_attn.transpose(-1, -2)
276
+ s2s_attn = s2s_attn[..., 1:]
277
+ s2s_attn = s2s_attn.transpose(-1, -2)
278
+ except:
279
+ continue
280
+
281
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
282
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
283
+
284
+ # encode
285
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
286
+
287
+ # 50% of chance of using monotonic version
288
+ if bool(random.getrandbits(1)):
289
+ asr = (t_en @ s2s_attn)
290
+ else:
291
+ asr = (t_en @ s2s_attn_mono)
292
+
293
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
294
+
295
+ # compute the style of the entire utterance
296
+ # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool)
297
+ ss = []
298
+ gs = []
299
+ for bib in range(len(mel_input_length)):
300
+ mel_length = int(mel_input_length[bib].item())
301
+ mel = mels[bib, :, :mel_input_length[bib]]
302
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
303
+ ss.append(s)
304
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
305
+ gs.append(s)
306
+
307
+ s_dur = torch.stack(ss).squeeze() # global prosodic styles
308
+ gs = torch.stack(gs).squeeze() # global acoustic styles
309
+ s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser
310
+
311
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
312
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
313
+
314
+ # denoiser training
315
+ if epoch >= diff_epoch:
316
+ num_steps = np.random.randint(3, 5)
317
+
318
+ if model_params.diffusion.dist.estimate_sigma_data:
319
+ model.diffusion.module.diffusion.sigma_data = s_trg.std(axis=-1).mean().item() # batch-wise std estimation
320
+ running_std.append(model.diffusion.module.diffusion.sigma_data)
321
+
322
+ if multispeaker:
323
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
324
+ embedding=bert_dur,
325
+ embedding_scale=1,
326
+ features=ref, # reference from the same speaker as the embedding
327
+ embedding_mask_proba=0.1,
328
+ num_steps=num_steps).squeeze(1)
329
+ loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean() # EDM loss
330
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
331
+ else:
332
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
333
+ embedding=bert_dur,
334
+ embedding_scale=1,
335
+ embedding_mask_proba=0.1,
336
+ num_steps=num_steps).squeeze(1)
337
+ loss_diff = model.diffusion.module.diffusion(s_trg.unsqueeze(1), embedding=bert_dur).mean() # EDM loss
338
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
339
+ else:
340
+ loss_sty = 0
341
+ loss_diff = 0
342
+
343
+
344
+ s_loss = 0
345
+
346
+
347
+ d, p = model.predictor(d_en, s_dur,
348
+ input_lengths,
349
+ s2s_attn_mono,
350
+ text_mask)
351
+
352
+ mel_len_st = int(mel_input_length.min().item() / 2 - 1)
353
+ mel_len = min(int(mel_input_length.min().item() / 2 - 1), max_len // 2)
354
+ en = []
355
+ gt = []
356
+ p_en = []
357
+ wav = []
358
+ st = []
359
+
360
+ for bib in range(len(mel_input_length)):
361
+ mel_length = int(mel_input_length[bib].item() / 2)
362
+
363
+ random_start = np.random.randint(0, mel_length - mel_len)
364
+ en.append(asr[bib, :, random_start:random_start+mel_len])
365
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
366
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
367
+
368
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
369
+ wav.append(torch.from_numpy(y).to(device))
370
+
371
+ # style reference (better to be different from the GT)
372
+ random_start = np.random.randint(0, mel_length - mel_len_st)
373
+ st.append(mels[bib, :, (random_start * 2):((random_start+mel_len_st) * 2)])
374
+
375
+ wav = torch.stack(wav).float().detach()
376
+
377
+ en = torch.stack(en)
378
+ p_en = torch.stack(p_en)
379
+ gt = torch.stack(gt).detach()
380
+ st = torch.stack(st).detach()
381
+
382
+
383
+ if gt.size(-1) < 80:
384
+ continue
385
+
386
+ s = model.style_encoder(gt.unsqueeze(1))
387
+ s_dur = model.predictor_encoder(gt.unsqueeze(1))
388
+
389
+ with torch.no_grad():
390
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
391
+ F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()
392
+
393
+ N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
394
+
395
+ y_rec_gt = wav.unsqueeze(1)
396
+ y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
397
+
398
+ wav = y_rec_gt
399
+
400
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur)
401
+
402
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
403
+
404
+ loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10
405
+ loss_norm_rec = F.smooth_l1_loss(N_real, N_fake)
406
+
407
+ optimizer.zero_grad()
408
+ d_loss = dl(wav.detach(), y_rec.detach()).mean()
409
+ d_loss.backward()
410
+ optimizer.step('msd')
411
+ optimizer.step('mpd')
412
+
413
+ # generator loss
414
+ optimizer.zero_grad()
415
+
416
+ loss_mel = stft_loss(y_rec, wav)
417
+ loss_gen_all = gl(wav, y_rec).mean()
418
+ loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean()
419
+
420
+ loss_ce = 0
421
+ loss_dur = 0
422
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
423
+ _s2s_pred = _s2s_pred[:_text_length, :]
424
+ _text_input = _text_input[:_text_length].long()
425
+ _s2s_trg = torch.zeros_like(_s2s_pred)
426
+ for p in range(_s2s_trg.shape[0]):
427
+ _s2s_trg[p, :_text_input[p]] = 1
428
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
429
+
430
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
431
+ _text_input[1:_text_length-1])
432
+ loss_ce += F.binary_cross_entropy_with_logits(_s2s_pred.flatten(), _s2s_trg.flatten())
433
+
434
+ loss_ce /= texts.size(0)
435
+ loss_dur /= texts.size(0)
436
+
437
+ loss_s2s = 0
438
+ for _s2s_pred, _text_input, _text_length in zip(s2s_pred, texts, input_lengths):
439
+ loss_s2s += F.cross_entropy(_s2s_pred[:_text_length], _text_input[:_text_length])
440
+ loss_s2s /= texts.size(0)
441
+
442
+ loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10
443
+
444
+ g_loss = loss_params.lambda_mel * loss_mel + \
445
+ loss_params.lambda_F0 * loss_F0_rec + \
446
+ loss_params.lambda_ce * loss_ce + \
447
+ loss_params.lambda_norm * loss_norm_rec + \
448
+ loss_params.lambda_dur * loss_dur + \
449
+ loss_params.lambda_gen * loss_gen_all + \
450
+ loss_params.lambda_slm * loss_lm + \
451
+ loss_params.lambda_sty * loss_sty + \
452
+ loss_params.lambda_diff * loss_diff + \
453
+ loss_params.lambda_mono * loss_mono + \
454
+ loss_params.lambda_s2s * loss_s2s
455
+
456
+ running_loss += loss_mel.item()
457
+ g_loss.backward()
458
+ if torch.isnan(g_loss):
459
+ from IPython.core.debugger import set_trace
460
+ set_trace()
461
+
462
+ optimizer.step('bert_encoder')
463
+ optimizer.step('bert')
464
+ optimizer.step('predictor')
465
+ optimizer.step('predictor_encoder')
466
+ optimizer.step('style_encoder')
467
+ optimizer.step('decoder')
468
+
469
+ optimizer.step('text_encoder')
470
+ optimizer.step('text_aligner')
471
+
472
+ if epoch >= diff_epoch:
473
+ optimizer.step('diffusion')
474
+
475
+ d_loss_slm, loss_gen_lm = 0, 0
476
+ if epoch >= joint_epoch:
477
+ # randomly pick whether to use in-distribution text
478
+ if np.random.rand() < 0.5:
479
+ use_ind = True
480
+ else:
481
+ use_ind = False
482
+
483
+ if use_ind:
484
+ ref_lengths = input_lengths
485
+ ref_texts = texts
486
+
487
+ slm_out = slmadv(i,
488
+ y_rec_gt,
489
+ y_rec_gt_pred,
490
+ waves,
491
+ mel_input_length,
492
+ ref_texts,
493
+ ref_lengths, use_ind, s_trg.detach(), ref if multispeaker else None)
494
+
495
+ if slm_out is not None:
496
+ d_loss_slm, loss_gen_lm, y_pred = slm_out
497
+
498
+ # SLM generator loss
499
+ optimizer.zero_grad()
500
+ loss_gen_lm.backward()
501
+
502
+ # compute the gradient norm
503
+ total_norm = {}
504
+ for key in model.keys():
505
+ total_norm[key] = 0
506
+ parameters = [p for p in model[key].parameters() if p.grad is not None and p.requires_grad]
507
+ for p in parameters:
508
+ param_norm = p.grad.detach().data.norm(2)
509
+ total_norm[key] += param_norm.item() ** 2
510
+ total_norm[key] = total_norm[key] ** 0.5
511
+
512
+ # gradient scaling
513
+ if total_norm['predictor'] > slmadv_params.thresh:
514
+ for key in model.keys():
515
+ for p in model[key].parameters():
516
+ if p.grad is not None:
517
+ p.grad *= (1 / total_norm['predictor'])
518
+
519
+ for p in model.predictor.duration_proj.parameters():
520
+ if p.grad is not None:
521
+ p.grad *= slmadv_params.scale
522
+
523
+ for p in model.predictor.lstm.parameters():
524
+ if p.grad is not None:
525
+ p.grad *= slmadv_params.scale
526
+
527
+ for p in model.diffusion.parameters():
528
+ if p.grad is not None:
529
+ p.grad *= slmadv_params.scale
530
+
531
+ optimizer.step('bert_encoder')
532
+ optimizer.step('bert')
533
+ optimizer.step('predictor')
534
+ optimizer.step('diffusion')
535
+
536
+ # SLM discriminator loss
537
+ if d_loss_slm != 0:
538
+ optimizer.zero_grad()
539
+ d_loss_slm.backward(retain_graph=True)
540
+ optimizer.step('wd')
541
+
542
+ iters = iters + 1
543
+
544
+ if (i+1)%log_interval == 0:
545
+ logger.info ('Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f, SLoss: %.5f, S2S Loss: %.5f, Mono Loss: %.5f'
546
+ %(epoch+1, epochs, i+1, len(train_list)//batch_size, running_loss / log_interval, d_loss, loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff, d_loss_slm, loss_gen_lm, s_loss, loss_s2s, loss_mono))
547
+
548
+ writer.add_scalar('train/mel_loss', running_loss / log_interval, iters)
549
+ writer.add_scalar('train/gen_loss', loss_gen_all, iters)
550
+ writer.add_scalar('train/d_loss', d_loss, iters)
551
+ writer.add_scalar('train/ce_loss', loss_ce, iters)
552
+ writer.add_scalar('train/dur_loss', loss_dur, iters)
553
+ writer.add_scalar('train/slm_loss', loss_lm, iters)
554
+ writer.add_scalar('train/norm_loss', loss_norm_rec, iters)
555
+ writer.add_scalar('train/F0_loss', loss_F0_rec, iters)
556
+ writer.add_scalar('train/sty_loss', loss_sty, iters)
557
+ writer.add_scalar('train/diff_loss', loss_diff, iters)
558
+ writer.add_scalar('train/d_loss_slm', d_loss_slm, iters)
559
+ writer.add_scalar('train/gen_loss_slm', loss_gen_lm, iters)
560
+
561
+ running_loss = 0
562
+
563
+ print('Time elasped:', time.time()-start_time)
564
+
565
+ loss_test = 0
566
+ loss_align = 0
567
+ loss_f = 0
568
+ _ = [model[key].eval() for key in model]
569
+
570
+ with torch.no_grad():
571
+ iters_test = 0
572
+ for batch_idx, batch in enumerate(val_dataloader):
573
+ optimizer.zero_grad()
574
+
575
+ try:
576
+ waves = batch[0]
577
+ batch = [b.to(device) for b in batch[1:]]
578
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
579
+ with torch.no_grad():
580
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to('cuda')
581
+ text_mask = length_to_mask(input_lengths).to(texts.device)
582
+
583
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
584
+ s2s_attn = s2s_attn.transpose(-1, -2)
585
+ s2s_attn = s2s_attn[..., 1:]
586
+ s2s_attn = s2s_attn.transpose(-1, -2)
587
+
588
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
589
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
590
+
591
+ # encode
592
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
593
+ asr = (t_en @ s2s_attn_mono)
594
+
595
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
596
+
597
+ ss = []
598
+ gs = []
599
+
600
+ for bib in range(len(mel_input_length)):
601
+ mel_length = int(mel_input_length[bib].item())
602
+ mel = mels[bib, :, :mel_input_length[bib]]
603
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
604
+ ss.append(s)
605
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
606
+ gs.append(s)
607
+
608
+ s = torch.stack(ss).squeeze()
609
+ gs = torch.stack(gs).squeeze()
610
+ s_trg = torch.cat([s, gs], dim=-1).detach()
611
+
612
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
613
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
614
+ d, p = model.predictor(d_en, s,
615
+ input_lengths,
616
+ s2s_attn_mono,
617
+ text_mask)
618
+ # get clips
619
+ mel_len = int(mel_input_length.min().item() / 2 - 1)
620
+ en = []
621
+ gt = []
622
+
623
+ p_en = []
624
+ wav = []
625
+
626
+ for bib in range(len(mel_input_length)):
627
+ mel_length = int(mel_input_length[bib].item() / 2)
628
+
629
+ random_start = np.random.randint(0, mel_length - mel_len)
630
+ en.append(asr[bib, :, random_start:random_start+mel_len])
631
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
632
+
633
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
634
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
635
+ wav.append(torch.from_numpy(y).to(device))
636
+
637
+ wav = torch.stack(wav).float().detach()
638
+
639
+ en = torch.stack(en)
640
+ p_en = torch.stack(p_en)
641
+ gt = torch.stack(gt).detach()
642
+ s = model.predictor_encoder(gt.unsqueeze(1))
643
+
644
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s)
645
+
646
+ loss_dur = 0
647
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
648
+ _s2s_pred = _s2s_pred[:_text_length, :]
649
+ _text_input = _text_input[:_text_length].long()
650
+ _s2s_trg = torch.zeros_like(_s2s_pred)
651
+ for bib in range(_s2s_trg.shape[0]):
652
+ _s2s_trg[bib, :_text_input[bib]] = 1
653
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
654
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
655
+ _text_input[1:_text_length-1])
656
+
657
+ loss_dur /= texts.size(0)
658
+
659
+ s = model.style_encoder(gt.unsqueeze(1))
660
+
661
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
662
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
663
+
664
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
665
+
666
+ loss_F0 = F.l1_loss(F0_real, F0_fake) / 10
667
+
668
+ loss_test += (loss_mel).mean()
669
+ loss_align += (loss_dur).mean()
670
+ loss_f += (loss_F0).mean()
671
+
672
+ iters_test += 1
673
+ except:
674
+ continue
675
+
676
+ print('Epochs:', epoch + 1)
677
+ logger.info('Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f' % (loss_test / iters_test, loss_align / iters_test, loss_f / iters_test) + '\n\n\n')
678
+ print('\n\n\n')
679
+ writer.add_scalar('eval/mel_loss', loss_test / iters_test, epoch + 1)
680
+ writer.add_scalar('eval/dur_loss', loss_test / iters_test, epoch + 1)
681
+ writer.add_scalar('eval/F0_loss', loss_f / iters_test, epoch + 1)
682
+
683
+
684
+ if (epoch + 1) % save_freq == 0 :
685
+ if (loss_test / iters_test) < best_loss:
686
+ best_loss = loss_test / iters_test
687
+ print('Saving..')
688
+ state = {
689
+ 'net': {key: model[key].state_dict() for key in model},
690
+ 'optimizer': optimizer.state_dict(),
691
+ 'iters': iters,
692
+ 'val_loss': loss_test / iters_test,
693
+ 'epoch': epoch,
694
+ }
695
+ save_path = osp.join(log_dir, 'epoch_2nd_%05d.pth' % epoch)
696
+ torch.save(state, save_path)
697
+
698
+ # if estimate sigma, save the estimated simga
699
+ if model_params.diffusion.dist.estimate_sigma_data:
700
+ config['model_params']['diffusion']['dist']['sigma_data'] = float(np.mean(running_std))
701
+
702
+ with open(osp.join(log_dir, osp.basename(config_path)), 'w') as outfile:
703
+ yaml.dump(config, outfile, default_flow_style=True)
704
+
705
+
706
+ if __name__=="__main__":
707
+ main()
train_finetune_accelerate.py ADDED
@@ -0,0 +1,714 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # load packages
2
+ import random
3
+ import yaml
4
+ import time
5
+ from munch import Munch
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ import torchaudio
11
+ import librosa
12
+ import click
13
+ import shutil
14
+ import warnings
15
+ warnings.simplefilter('ignore')
16
+ from torch.utils.tensorboard import SummaryWriter
17
+
18
+ from meldataset import build_dataloader
19
+
20
+ from Utils.ASR.models import ASRCNN
21
+ from Utils.JDC.model import JDCNet
22
+ from Utils.PLBERT.util import load_plbert
23
+
24
+ from models import *
25
+ from losses import *
26
+ from utils import *
27
+
28
+ from Modules.slmadv import SLMAdversarialLoss
29
+ from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
30
+
31
+ from optimizers import build_optimizer
32
+
33
+ from accelerate import Accelerator
34
+
35
+ accelerator = Accelerator()
36
+
37
+ # simple fix for dataparallel that allows access to class attributes
38
+ class MyDataParallel(torch.nn.DataParallel):
39
+ def __getattr__(self, name):
40
+ try:
41
+ return super().__getattr__(name)
42
+ except AttributeError:
43
+ return getattr(self.module, name)
44
+
45
+ import logging
46
+ from logging import StreamHandler
47
+ logger = logging.getLogger(__name__)
48
+ logger.setLevel(logging.DEBUG)
49
+ handler = StreamHandler()
50
+ handler.setLevel(logging.DEBUG)
51
+ logger.addHandler(handler)
52
+
53
+
54
+ @click.command()
55
+ @click.option('-p', '--config_path', default='Configs/config_ft.yml', type=str)
56
+ def main(config_path):
57
+ config = yaml.safe_load(open(config_path))
58
+
59
+ log_dir = config['log_dir']
60
+ if not osp.exists(log_dir): os.makedirs(log_dir, exist_ok=True)
61
+ shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
62
+ writer = SummaryWriter(log_dir + "/tensorboard")
63
+
64
+ # write logs
65
+ file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
66
+ file_handler.setLevel(logging.DEBUG)
67
+ file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
68
+ logger.addHandler(file_handler)
69
+
70
+
71
+ batch_size = config.get('batch_size', 10)
72
+
73
+ epochs = config.get('epochs', 200)
74
+ save_freq = config.get('save_freq', 2)
75
+ log_interval = config.get('log_interval', 10)
76
+ saving_epoch = config.get('save_freq', 2)
77
+
78
+ data_params = config.get('data_params', None)
79
+ sr = config['preprocess_params'].get('sr', 24000)
80
+ train_path = data_params['train_data']
81
+ val_path = data_params['val_data']
82
+ root_path = data_params['root_path']
83
+ min_length = data_params['min_length']
84
+ OOD_data = data_params['OOD_data']
85
+
86
+ max_len = config.get('max_len', 200)
87
+
88
+ loss_params = Munch(config['loss_params'])
89
+ diff_epoch = loss_params.diff_epoch
90
+ joint_epoch = loss_params.joint_epoch
91
+
92
+ optimizer_params = Munch(config['optimizer_params'])
93
+
94
+ train_list, val_list = get_data_path_list(train_path, val_path)
95
+ device = accelerator.device
96
+
97
+ train_dataloader = build_dataloader(train_list,
98
+ root_path,
99
+ OOD_data=OOD_data,
100
+ min_length=min_length,
101
+ batch_size=batch_size,
102
+ num_workers=2,
103
+ dataset_config={},
104
+ device=device)
105
+
106
+ val_dataloader = build_dataloader(val_list,
107
+ root_path,
108
+ OOD_data=OOD_data,
109
+ min_length=min_length,
110
+ batch_size=batch_size,
111
+ validation=True,
112
+ num_workers=0,
113
+ device=device,
114
+ dataset_config={})
115
+
116
+ # load pretrained ASR model
117
+ ASR_config = config.get('ASR_config', False)
118
+ ASR_path = config.get('ASR_path', False)
119
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
120
+
121
+ # load pretrained F0 model
122
+ F0_path = config.get('F0_path', False)
123
+ pitch_extractor = load_F0_models(F0_path)
124
+
125
+ # load PL-BERT model
126
+ BERT_path = config.get('PLBERT_dir', False)
127
+ plbert = load_plbert(BERT_path)
128
+
129
+ # build model
130
+ model_params = recursive_munch(config['model_params'])
131
+ multispeaker = model_params.multispeaker
132
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
133
+ _ = [model[key].to(device) for key in model]
134
+
135
+ # DP
136
+ for key in model:
137
+ if key != "mpd" and key != "msd" and key != "wd":
138
+ model[key] = MyDataParallel(model[key])
139
+
140
+ start_epoch = 0
141
+ iters = 0
142
+
143
+ load_pretrained = config.get('pretrained_model', '') != '' and config.get('second_stage_load_pretrained', False)
144
+
145
+ if not load_pretrained:
146
+ if config.get('first_stage_path', '') != '':
147
+ first_stage_path = osp.join(log_dir, config.get('first_stage_path', 'first_stage.pth'))
148
+ print('Loading the first stage model at %s ...' % first_stage_path)
149
+ model, _, start_epoch, iters = load_checkpoint(model,
150
+ None,
151
+ first_stage_path,
152
+ load_only_params=True,
153
+ ignore_modules=['bert', 'bert_encoder', 'predictor', 'predictor_encoder', 'msd', 'mpd', 'wd', 'diffusion']) # keep starting epoch for tensorboard log
154
+
155
+ # these epochs should be counted from the start epoch
156
+ diff_epoch += start_epoch
157
+ joint_epoch += start_epoch
158
+ epochs += start_epoch
159
+
160
+ model.predictor_encoder = copy.deepcopy(model.style_encoder)
161
+ else:
162
+ raise ValueError('You need to specify the path to the first stage model.')
163
+
164
+ gl = GeneratorLoss(model.mpd, model.msd).to(device)
165
+ dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
166
+ wl = WavLMLoss(model_params.slm.model,
167
+ model.wd,
168
+ sr,
169
+ model_params.slm.sr).to(device)
170
+
171
+ gl = MyDataParallel(gl)
172
+ dl = MyDataParallel(dl)
173
+ wl = MyDataParallel(wl)
174
+
175
+ sampler = DiffusionSampler(
176
+ model.diffusion.diffusion,
177
+ sampler=ADPM2Sampler(),
178
+ sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
179
+ clamp=False
180
+ )
181
+
182
+ scheduler_params = {
183
+ "max_lr": optimizer_params.lr,
184
+ "pct_start": float(0),
185
+ "epochs": epochs,
186
+ "steps_per_epoch": len(train_dataloader),
187
+ }
188
+ scheduler_params_dict= {key: scheduler_params.copy() for key in model}
189
+ scheduler_params_dict['bert']['max_lr'] = optimizer_params.bert_lr * 2
190
+ scheduler_params_dict['decoder']['max_lr'] = optimizer_params.ft_lr * 2
191
+ scheduler_params_dict['style_encoder']['max_lr'] = optimizer_params.ft_lr * 2
192
+
193
+ optimizer = build_optimizer({key: model[key].parameters() for key in model},
194
+ scheduler_params_dict=scheduler_params_dict, lr=optimizer_params.lr)
195
+
196
+ # adjust BERT learning rate
197
+ for g in optimizer.optimizers['bert'].param_groups:
198
+ g['betas'] = (0.9, 0.99)
199
+ g['lr'] = optimizer_params.bert_lr
200
+ g['initial_lr'] = optimizer_params.bert_lr
201
+ g['min_lr'] = 0
202
+ g['weight_decay'] = 0.01
203
+
204
+ # adjust acoustic module learning rate
205
+ for module in ["decoder", "style_encoder"]:
206
+ for g in optimizer.optimizers[module].param_groups:
207
+ g['betas'] = (0.0, 0.99)
208
+ g['lr'] = optimizer_params.ft_lr
209
+ g['initial_lr'] = optimizer_params.ft_lr
210
+ g['min_lr'] = 0
211
+ g['weight_decay'] = 1e-4
212
+
213
+ # load models if there is a model
214
+ if load_pretrained:
215
+ model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'],
216
+ load_only_params=config.get('load_only_params', True))
217
+
218
+ n_down = model.text_aligner.n_down
219
+
220
+ best_loss = float('inf') # best test loss
221
+ loss_train_record = list([])
222
+ loss_test_record = list([])
223
+ iters = 0
224
+
225
+ criterion = nn.L1Loss() # F0 loss (regression)
226
+ torch.cuda.empty_cache()
227
+
228
+ stft_loss = MultiResolutionSTFTLoss().to(device)
229
+
230
+ print('BERT', optimizer.optimizers['bert'])
231
+ print('decoder', optimizer.optimizers['decoder'])
232
+
233
+ start_ds = False
234
+
235
+ running_std = []
236
+
237
+ slmadv_params = Munch(config['slmadv_params'])
238
+ slmadv = SLMAdversarialLoss(model, wl, sampler,
239
+ slmadv_params.min_len,
240
+ slmadv_params.max_len,
241
+ batch_percentage=slmadv_params.batch_percentage,
242
+ skip_update=slmadv_params.iter,
243
+ sig=slmadv_params.sig
244
+ )
245
+
246
+ model, optimizer, train_dataloader = accelerator.prepare(
247
+ model, optimizer, train_dataloader
248
+ )
249
+
250
+ for epoch in range(start_epoch, epochs):
251
+ running_loss = 0
252
+ start_time = time.time()
253
+
254
+ _ = [model[key].eval() for key in model]
255
+
256
+ model.text_aligner.train()
257
+ model.text_encoder.train()
258
+
259
+ model.predictor.train()
260
+ model.bert_encoder.train()
261
+ model.bert.train()
262
+ model.msd.train()
263
+ model.mpd.train()
264
+
265
+ for i, batch in enumerate(train_dataloader):
266
+ waves = batch[0]
267
+ batch = [b.to(device) for b in batch[1:]]
268
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
269
+ with torch.no_grad():
270
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
271
+ mel_mask = length_to_mask(mel_input_length).to(device)
272
+ text_mask = length_to_mask(input_lengths).to(texts.device)
273
+
274
+ # compute reference styles
275
+ if multispeaker and epoch >= diff_epoch:
276
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
277
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
278
+ ref = torch.cat([ref_ss, ref_sp], dim=1)
279
+
280
+ try:
281
+ ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
282
+ s2s_attn = s2s_attn.transpose(-1, -2)
283
+ s2s_attn = s2s_attn[..., 1:]
284
+ s2s_attn = s2s_attn.transpose(-1, -2)
285
+ except:
286
+ continue
287
+
288
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
289
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
290
+
291
+ # encode
292
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
293
+
294
+ # 50% of chance of using monotonic version
295
+ if bool(random.getrandbits(1)):
296
+ asr = (t_en @ s2s_attn)
297
+ else:
298
+ asr = (t_en @ s2s_attn_mono)
299
+
300
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
301
+
302
+ # compute the style of the entire utterance
303
+ # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool)
304
+ ss = []
305
+ gs = []
306
+ for bib in range(len(mel_input_length)):
307
+ mel_length = int(mel_input_length[bib].item())
308
+ mel = mels[bib, :, :mel_input_length[bib]]
309
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
310
+ ss.append(s)
311
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
312
+ gs.append(s)
313
+
314
+ s_dur = torch.stack(ss).squeeze() # global prosodic styles
315
+ gs = torch.stack(gs).squeeze() # global acoustic styles
316
+ s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser
317
+
318
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
319
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
320
+
321
+ # denoiser training
322
+ if epoch >= diff_epoch:
323
+ num_steps = np.random.randint(3, 5)
324
+
325
+ if model_params.diffusion.dist.estimate_sigma_data:
326
+ model.diffusion.module.diffusion.sigma_data = s_trg.std(axis=-1).mean().item() # batch-wise std estimation
327
+ running_std.append(model.diffusion.module.diffusion.sigma_data)
328
+
329
+ if multispeaker:
330
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
331
+ embedding=bert_dur,
332
+ embedding_scale=1,
333
+ features=ref, # reference from the same speaker as the embedding
334
+ embedding_mask_proba=0.1,
335
+ num_steps=num_steps).squeeze(1)
336
+ loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean() # EDM loss
337
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
338
+ else:
339
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
340
+ embedding=bert_dur,
341
+ embedding_scale=1,
342
+ embedding_mask_proba=0.1,
343
+ num_steps=num_steps).squeeze(1)
344
+ loss_diff = model.diffusion.module.diffusion(s_trg.unsqueeze(1), embedding=bert_dur).mean() # EDM loss
345
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
346
+ else:
347
+ loss_sty = 0
348
+ loss_diff = 0
349
+
350
+
351
+ s_loss = 0
352
+
353
+
354
+ d, p = model.predictor(d_en, s_dur,
355
+ input_lengths,
356
+ s2s_attn_mono,
357
+ text_mask)
358
+
359
+ mel_len_st = int(mel_input_length.min().item() / 2 - 1)
360
+ mel_len = min(int(mel_input_length.min().item() / 2 - 1), max_len // 2)
361
+ en = []
362
+ gt = []
363
+ p_en = []
364
+ wav = []
365
+ st = []
366
+
367
+ for bib in range(len(mel_input_length)):
368
+ mel_length = int(mel_input_length[bib].item() / 2)
369
+
370
+ random_start = np.random.randint(0, mel_length - mel_len)
371
+ en.append(asr[bib, :, random_start:random_start+mel_len])
372
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
373
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
374
+
375
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
376
+ wav.append(torch.from_numpy(y).to(device))
377
+
378
+ # style reference (better to be different from the GT)
379
+ random_start = np.random.randint(0, mel_length - mel_len_st)
380
+ st.append(mels[bib, :, (random_start * 2):((random_start+mel_len_st) * 2)])
381
+
382
+ wav = torch.stack(wav).float().detach()
383
+
384
+ en = torch.stack(en)
385
+ p_en = torch.stack(p_en)
386
+ gt = torch.stack(gt).detach()
387
+ st = torch.stack(st).detach()
388
+
389
+
390
+ if gt.size(-1) < 80:
391
+ continue
392
+
393
+ s = model.style_encoder(gt.unsqueeze(1))
394
+ s_dur = model.predictor_encoder(gt.unsqueeze(1))
395
+
396
+ with torch.no_grad():
397
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
398
+ F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()
399
+
400
+ N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
401
+
402
+ y_rec_gt = wav.unsqueeze(1)
403
+ y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
404
+
405
+ wav = y_rec_gt
406
+
407
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur)
408
+
409
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
410
+
411
+ loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10
412
+ loss_norm_rec = F.smooth_l1_loss(N_real, N_fake)
413
+
414
+ optimizer.zero_grad()
415
+ d_loss = dl(wav.detach(), y_rec.detach()).mean()
416
+ accelerator.backward(d_loss)
417
+ optimizer.step('msd')
418
+ optimizer.step('mpd')
419
+
420
+ # generator loss
421
+ optimizer.zero_grad()
422
+
423
+ loss_mel = stft_loss(y_rec, wav)
424
+ loss_gen_all = gl(wav, y_rec).mean()
425
+ loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean()
426
+
427
+ loss_ce = 0
428
+ loss_dur = 0
429
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
430
+ _s2s_pred = _s2s_pred[:_text_length, :]
431
+ _text_input = _text_input[:_text_length].long()
432
+ _s2s_trg = torch.zeros_like(_s2s_pred)
433
+ for p in range(_s2s_trg.shape[0]):
434
+ _s2s_trg[p, :_text_input[p]] = 1
435
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
436
+
437
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
438
+ _text_input[1:_text_length-1])
439
+ loss_ce += F.binary_cross_entropy_with_logits(_s2s_pred.flatten(), _s2s_trg.flatten())
440
+
441
+ loss_ce /= texts.size(0)
442
+ loss_dur /= texts.size(0)
443
+
444
+ loss_s2s = 0
445
+ for _s2s_pred, _text_input, _text_length in zip(s2s_pred, texts, input_lengths):
446
+ loss_s2s += F.cross_entropy(_s2s_pred[:_text_length], _text_input[:_text_length])
447
+ loss_s2s /= texts.size(0)
448
+
449
+ loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10
450
+
451
+ g_loss = loss_params.lambda_mel * loss_mel + \
452
+ loss_params.lambda_F0 * loss_F0_rec + \
453
+ loss_params.lambda_ce * loss_ce + \
454
+ loss_params.lambda_norm * loss_norm_rec + \
455
+ loss_params.lambda_dur * loss_dur + \
456
+ loss_params.lambda_gen * loss_gen_all + \
457
+ loss_params.lambda_slm * loss_lm + \
458
+ loss_params.lambda_sty * loss_sty + \
459
+ loss_params.lambda_diff * loss_diff + \
460
+ loss_params.lambda_mono * loss_mono + \
461
+ loss_params.lambda_s2s * loss_s2s
462
+
463
+ running_loss += loss_mel.item()
464
+ accelerator.backward(g_loss)
465
+ if torch.isnan(g_loss):
466
+ from IPython.core.debugger import set_trace
467
+ set_trace()
468
+
469
+ optimizer.step('bert_encoder')
470
+ optimizer.step('bert')
471
+ optimizer.step('predictor')
472
+ optimizer.step('predictor_encoder')
473
+ optimizer.step('style_encoder')
474
+ optimizer.step('decoder')
475
+
476
+ optimizer.step('text_encoder')
477
+ optimizer.step('text_aligner')
478
+
479
+ if epoch >= diff_epoch:
480
+ optimizer.step('diffusion')
481
+
482
+ d_loss_slm, loss_gen_lm = 0, 0
483
+ if epoch >= joint_epoch:
484
+ # randomly pick whether to use in-distribution text
485
+ if np.random.rand() < 0.5:
486
+ use_ind = True
487
+ else:
488
+ use_ind = False
489
+
490
+ if use_ind:
491
+ ref_lengths = input_lengths
492
+ ref_texts = texts
493
+
494
+ slm_out = slmadv(i,
495
+ y_rec_gt,
496
+ y_rec_gt_pred,
497
+ waves,
498
+ mel_input_length,
499
+ ref_texts,
500
+ ref_lengths, use_ind, s_trg.detach(), ref if multispeaker else None)
501
+
502
+ if slm_out is not None:
503
+ d_loss_slm, loss_gen_lm, y_pred = slm_out
504
+
505
+ # SLM generator loss
506
+ optimizer.zero_grad()
507
+ accelerator.backward(loss_gen_lm)
508
+
509
+ # compute the gradient norm
510
+ total_norm = {}
511
+ for key in model.keys():
512
+ total_norm[key] = 0
513
+ parameters = [p for p in model[key].parameters() if p.grad is not None and p.requires_grad]
514
+ for p in parameters:
515
+ param_norm = p.grad.detach().data.norm(2)
516
+ total_norm[key] += param_norm.item() ** 2
517
+ total_norm[key] = total_norm[key] ** 0.5
518
+
519
+ # gradient scaling
520
+ if total_norm['predictor'] > slmadv_params.thresh:
521
+ for key in model.keys():
522
+ for p in model[key].parameters():
523
+ if p.grad is not None:
524
+ p.grad *= (1 / total_norm['predictor'])
525
+
526
+ for p in model.predictor.duration_proj.parameters():
527
+ if p.grad is not None:
528
+ p.grad *= slmadv_params.scale
529
+
530
+ for p in model.predictor.lstm.parameters():
531
+ if p.grad is not None:
532
+ p.grad *= slmadv_params.scale
533
+
534
+ for p in model.diffusion.parameters():
535
+ if p.grad is not None:
536
+ p.grad *= slmadv_params.scale
537
+
538
+ optimizer.step('bert_encoder')
539
+ optimizer.step('bert')
540
+ optimizer.step('predictor')
541
+ optimizer.step('diffusion')
542
+
543
+ # SLM discriminator loss
544
+ if d_loss_slm != 0:
545
+ optimizer.zero_grad()
546
+ accelerator.backward(d_loss_slm)
547
+ optimizer.step('wd')
548
+
549
+ iters = iters + 1
550
+
551
+ if (i+1)%log_interval == 0:
552
+ logger.info ('Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f, SLoss: %.5f, S2S Loss: %.5f, Mono Loss: %.5f'
553
+ %(epoch+1, epochs, i+1, len(train_list)//batch_size, running_loss / log_interval, d_loss, loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff, d_loss_slm, loss_gen_lm, s_loss, loss_s2s, loss_mono))
554
+
555
+ writer.add_scalar('train/mel_loss', running_loss / log_interval, iters)
556
+ writer.add_scalar('train/gen_loss', loss_gen_all, iters)
557
+ writer.add_scalar('train/d_loss', d_loss, iters)
558
+ writer.add_scalar('train/ce_loss', loss_ce, iters)
559
+ writer.add_scalar('train/dur_loss', loss_dur, iters)
560
+ writer.add_scalar('train/slm_loss', loss_lm, iters)
561
+ writer.add_scalar('train/norm_loss', loss_norm_rec, iters)
562
+ writer.add_scalar('train/F0_loss', loss_F0_rec, iters)
563
+ writer.add_scalar('train/sty_loss', loss_sty, iters)
564
+ writer.add_scalar('train/diff_loss', loss_diff, iters)
565
+ writer.add_scalar('train/d_loss_slm', d_loss_slm, iters)
566
+ writer.add_scalar('train/gen_loss_slm', loss_gen_lm, iters)
567
+
568
+ running_loss = 0
569
+
570
+ print('Time elasped:', time.time()-start_time)
571
+
572
+ loss_test = 0
573
+ loss_align = 0
574
+ loss_f = 0
575
+ _ = [model[key].eval() for key in model]
576
+
577
+ with torch.no_grad():
578
+ iters_test = 0
579
+ for batch_idx, batch in enumerate(val_dataloader):
580
+ optimizer.zero_grad()
581
+
582
+ try:
583
+ waves = batch[0]
584
+ batch = [b.to(device) for b in batch[1:]]
585
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
586
+ with torch.no_grad():
587
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to('cuda')
588
+ text_mask = length_to_mask(input_lengths).to(texts.device)
589
+
590
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
591
+ s2s_attn = s2s_attn.transpose(-1, -2)
592
+ s2s_attn = s2s_attn[..., 1:]
593
+ s2s_attn = s2s_attn.transpose(-1, -2)
594
+
595
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
596
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
597
+
598
+ # encode
599
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
600
+ asr = (t_en @ s2s_attn_mono)
601
+
602
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
603
+
604
+ ss = []
605
+ gs = []
606
+
607
+ for bib in range(len(mel_input_length)):
608
+ mel_length = int(mel_input_length[bib].item())
609
+ mel = mels[bib, :, :mel_input_length[bib]]
610
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
611
+ ss.append(s)
612
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
613
+ gs.append(s)
614
+
615
+ s = torch.stack(ss).squeeze()
616
+ gs = torch.stack(gs).squeeze()
617
+ s_trg = torch.cat([s, gs], dim=-1).detach()
618
+
619
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
620
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
621
+ d, p = model.predictor(d_en, s,
622
+ input_lengths,
623
+ s2s_attn_mono,
624
+ text_mask)
625
+ # get clips
626
+ mel_len = int(mel_input_length.min().item() / 2 - 1)
627
+ en = []
628
+ gt = []
629
+
630
+ p_en = []
631
+ wav = []
632
+
633
+ for bib in range(len(mel_input_length)):
634
+ mel_length = int(mel_input_length[bib].item() / 2)
635
+
636
+ random_start = np.random.randint(0, mel_length - mel_len)
637
+ en.append(asr[bib, :, random_start:random_start+mel_len])
638
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
639
+
640
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
641
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
642
+ wav.append(torch.from_numpy(y).to(device))
643
+
644
+ wav = torch.stack(wav).float().detach()
645
+
646
+ en = torch.stack(en)
647
+ p_en = torch.stack(p_en)
648
+ gt = torch.stack(gt).detach()
649
+ s = model.predictor_encoder(gt.unsqueeze(1))
650
+
651
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s)
652
+
653
+ loss_dur = 0
654
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
655
+ _s2s_pred = _s2s_pred[:_text_length, :]
656
+ _text_input = _text_input[:_text_length].long()
657
+ _s2s_trg = torch.zeros_like(_s2s_pred)
658
+ for bib in range(_s2s_trg.shape[0]):
659
+ _s2s_trg[bib, :_text_input[bib]] = 1
660
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
661
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
662
+ _text_input[1:_text_length-1])
663
+
664
+ loss_dur /= texts.size(0)
665
+
666
+ s = model.style_encoder(gt.unsqueeze(1))
667
+
668
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
669
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
670
+
671
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
672
+
673
+ loss_F0 = F.l1_loss(F0_real, F0_fake) / 10
674
+
675
+ loss_test += (loss_mel).mean()
676
+ loss_align += (loss_dur).mean()
677
+ loss_f += (loss_F0).mean()
678
+
679
+ iters_test += 1
680
+ except:
681
+ continue
682
+
683
+ print('Epochs:', epoch + 1)
684
+ logger.info('Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f' % (loss_test / iters_test, loss_align / iters_test, loss_f / iters_test) + '\n\n\n')
685
+ print('\n\n\n')
686
+ writer.add_scalar('eval/mel_loss', loss_test / iters_test, epoch + 1)
687
+ writer.add_scalar('eval/dur_loss', loss_test / iters_test, epoch + 1)
688
+ writer.add_scalar('eval/F0_loss', loss_f / iters_test, epoch + 1)
689
+
690
+
691
+ if (epoch + 1) % save_freq == 0 :
692
+ if (loss_test / iters_test) < best_loss:
693
+ best_loss = loss_test / iters_test
694
+ print('Saving..')
695
+ state = {
696
+ 'net': {key: model[key].state_dict() for key in model},
697
+ 'optimizer': optimizer.state_dict(),
698
+ 'iters': iters,
699
+ 'val_loss': loss_test / iters_test,
700
+ 'epoch': epoch,
701
+ }
702
+ save_path = osp.join(log_dir, 'epoch_2nd_%05d.pth' % epoch)
703
+ torch.save(state, save_path)
704
+
705
+ # if estimate sigma, save the estimated simga
706
+ if model_params.diffusion.dist.estimate_sigma_data:
707
+ config['model_params']['diffusion']['dist']['sigma_data'] = float(np.mean(running_std))
708
+
709
+ with open(osp.join(log_dir, osp.basename(config_path)), 'w') as outfile:
710
+ yaml.dump(config, outfile, default_flow_style=True)
711
+
712
+
713
+ if __name__=="__main__":
714
+ main()
train_first.py ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+ import re
4
+ import sys
5
+ import yaml
6
+ import shutil
7
+ import numpy as np
8
+ import torch
9
+ import click
10
+ import warnings
11
+ warnings.simplefilter('ignore')
12
+
13
+ # load packages
14
+ import random
15
+ import yaml
16
+ from munch import Munch
17
+ import numpy as np
18
+ import torch
19
+ from torch import nn
20
+ import torch.nn.functional as F
21
+ import torchaudio
22
+ import librosa
23
+
24
+ from models import *
25
+ from meldataset import build_dataloader
26
+ from utils import *
27
+ from losses import *
28
+ from optimizers import build_optimizer
29
+ import time
30
+
31
+ from accelerate import Accelerator
32
+ from accelerate.utils import LoggerType
33
+ from accelerate import DistributedDataParallelKwargs
34
+
35
+ from torch.utils.tensorboard import SummaryWriter
36
+
37
+ import logging
38
+ from accelerate.logging import get_logger
39
+ logger = get_logger(__name__, log_level="DEBUG")
40
+
41
+ @click.command()
42
+ @click.option('-p', '--config_path', default='Configs/config.yml', type=str)
43
+ def main(config_path):
44
+ config = yaml.safe_load(open(config_path))
45
+
46
+ log_dir = config['log_dir']
47
+ if not osp.exists(log_dir): os.makedirs(log_dir, exist_ok=True)
48
+ shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
49
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
50
+ accelerator = Accelerator(project_dir=log_dir, split_batches=True, kwargs_handlers=[ddp_kwargs])
51
+ if accelerator.is_main_process:
52
+ writer = SummaryWriter(log_dir + "/tensorboard")
53
+
54
+ # write logs
55
+ file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
56
+ file_handler.setLevel(logging.DEBUG)
57
+ file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
58
+ logger.logger.addHandler(file_handler)
59
+
60
+ batch_size = config.get('batch_size', 10)
61
+ device = accelerator.device
62
+
63
+ epochs = config.get('epochs_1st', 200)
64
+ save_freq = config.get('save_freq', 2)
65
+ log_interval = config.get('log_interval', 10)
66
+ saving_epoch = config.get('save_freq', 2)
67
+
68
+ data_params = config.get('data_params', None)
69
+ sr = config['preprocess_params'].get('sr', 24000)
70
+ train_path = data_params['train_data']
71
+ val_path = data_params['val_data']
72
+ root_path = data_params['root_path']
73
+ min_length = data_params['min_length']
74
+ OOD_data = data_params['OOD_data']
75
+
76
+ max_len = config.get('max_len', 200)
77
+
78
+ # load data
79
+ train_list, val_list = get_data_path_list(train_path, val_path)
80
+
81
+ train_dataloader = build_dataloader(train_list,
82
+ root_path,
83
+ OOD_data=OOD_data,
84
+ min_length=min_length,
85
+ batch_size=batch_size,
86
+ num_workers=2,
87
+ dataset_config={},
88
+ device=device)
89
+
90
+ val_dataloader = build_dataloader(val_list,
91
+ root_path,
92
+ OOD_data=OOD_data,
93
+ min_length=min_length,
94
+ batch_size=batch_size,
95
+ validation=True,
96
+ num_workers=0,
97
+ device=device,
98
+ dataset_config={})
99
+
100
+ with accelerator.main_process_first():
101
+ # load pretrained ASR model
102
+ ASR_config = config.get('ASR_config', False)
103
+ ASR_path = config.get('ASR_path', False)
104
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
105
+
106
+ # load pretrained F0 model
107
+ F0_path = config.get('F0_path', False)
108
+ pitch_extractor = load_F0_models(F0_path)
109
+
110
+ # load BERT model
111
+ from Utils.PLBERT.util import load_plbert
112
+ BERT_path = config.get('PLBERT_dir', False)
113
+ plbert = load_plbert(BERT_path)
114
+
115
+ scheduler_params = {
116
+ "max_lr": float(config['optimizer_params'].get('lr', 1e-4)),
117
+ "pct_start": float(config['optimizer_params'].get('pct_start', 0.0)),
118
+ "epochs": epochs,
119
+ "steps_per_epoch": len(train_dataloader),
120
+ }
121
+
122
+ model_params = recursive_munch(config['model_params'])
123
+ multispeaker = model_params.multispeaker
124
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
125
+
126
+ best_loss = float('inf') # best test loss
127
+ loss_train_record = list([])
128
+ loss_test_record = list([])
129
+
130
+ loss_params = Munch(config['loss_params'])
131
+ TMA_epoch = loss_params.TMA_epoch
132
+
133
+ for k in model:
134
+ model[k] = accelerator.prepare(model[k])
135
+
136
+ train_dataloader, val_dataloader = accelerator.prepare(
137
+ train_dataloader, val_dataloader
138
+ )
139
+
140
+ _ = [model[key].to(device) for key in model]
141
+
142
+ # initialize optimizers after preparing models for compatibility with FSDP
143
+ optimizer = build_optimizer({key: model[key].parameters() for key in model},
144
+ scheduler_params_dict= {key: scheduler_params.copy() for key in model},
145
+ lr=float(config['optimizer_params'].get('lr', 1e-4)))
146
+
147
+ for k, v in optimizer.optimizers.items():
148
+ optimizer.optimizers[k] = accelerator.prepare(optimizer.optimizers[k])
149
+ optimizer.schedulers[k] = accelerator.prepare(optimizer.schedulers[k])
150
+
151
+ with accelerator.main_process_first():
152
+ if config.get('pretrained_model', '') != '':
153
+ model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'],
154
+ load_only_params=config.get('load_only_params', True))
155
+ else:
156
+ start_epoch = 0
157
+ iters = 0
158
+
159
+ # in case not distributed
160
+ try:
161
+ n_down = model.text_aligner.module.n_down
162
+ except:
163
+ n_down = model.text_aligner.n_down
164
+
165
+ # wrapped losses for compatibility with mixed precision
166
+ stft_loss = MultiResolutionSTFTLoss().to(device)
167
+ gl = GeneratorLoss(model.mpd, model.msd).to(device)
168
+ dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
169
+ wl = WavLMLoss(model_params.slm.model,
170
+ model.wd,
171
+ sr,
172
+ model_params.slm.sr).to(device)
173
+
174
+ for epoch in range(start_epoch, epochs):
175
+ running_loss = 0
176
+ start_time = time.time()
177
+
178
+ _ = [model[key].train() for key in model]
179
+
180
+ for i, batch in enumerate(train_dataloader):
181
+ waves = batch[0]
182
+ batch = [b.to(device) for b in batch[1:]]
183
+ texts, input_lengths, _, _, mels, mel_input_length, _ = batch
184
+
185
+ with torch.no_grad():
186
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to('cuda')
187
+ text_mask = length_to_mask(input_lengths).to(texts.device)
188
+
189
+ ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
190
+
191
+ s2s_attn = s2s_attn.transpose(-1, -2)
192
+ s2s_attn = s2s_attn[..., 1:]
193
+ s2s_attn = s2s_attn.transpose(-1, -2)
194
+
195
+ with torch.no_grad():
196
+ attn_mask = (~mask).unsqueeze(-1).expand(mask.shape[0], mask.shape[1], text_mask.shape[-1]).float().transpose(-1, -2)
197
+ attn_mask = attn_mask.float() * (~text_mask).unsqueeze(-1).expand(text_mask.shape[0], text_mask.shape[1], mask.shape[-1]).float()
198
+ attn_mask = (attn_mask < 1)
199
+
200
+ s2s_attn.masked_fill_(attn_mask, 0.0)
201
+
202
+ with torch.no_grad():
203
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
204
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
205
+
206
+ # encode
207
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
208
+
209
+ # 50% of chance of using monotonic version
210
+ if bool(random.getrandbits(1)):
211
+ asr = (t_en @ s2s_attn)
212
+ else:
213
+ asr = (t_en @ s2s_attn_mono)
214
+
215
+ # get clips
216
+ mel_input_length_all = accelerator.gather(mel_input_length) # for balanced load
217
+ mel_len = min([int(mel_input_length_all.min().item() / 2 - 1), max_len // 2])
218
+ mel_len_st = int(mel_input_length.min().item() / 2 - 1)
219
+
220
+ en = []
221
+ gt = []
222
+ wav = []
223
+ st = []
224
+
225
+ for bib in range(len(mel_input_length)):
226
+ mel_length = int(mel_input_length[bib].item() / 2)
227
+
228
+ random_start = np.random.randint(0, mel_length - mel_len)
229
+ en.append(asr[bib, :, random_start:random_start+mel_len])
230
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
231
+
232
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
233
+ wav.append(torch.from_numpy(y).to(device))
234
+
235
+ # style reference (better to be different from the GT)
236
+ random_start = np.random.randint(0, mel_length - mel_len_st)
237
+ st.append(mels[bib, :, (random_start * 2):((random_start+mel_len_st) * 2)])
238
+
239
+ en = torch.stack(en)
240
+ gt = torch.stack(gt).detach()
241
+ st = torch.stack(st).detach()
242
+
243
+ wav = torch.stack(wav).float().detach()
244
+
245
+ # clip too short to be used by the style encoder
246
+ if gt.shape[-1] < 80:
247
+ continue
248
+
249
+ with torch.no_grad():
250
+ real_norm = log_norm(gt.unsqueeze(1)).squeeze(1).detach()
251
+ F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
252
+
253
+ s = model.style_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
254
+
255
+ y_rec = model.decoder(en, F0_real, real_norm, s)
256
+
257
+ # discriminator loss
258
+
259
+ if epoch >= TMA_epoch:
260
+ optimizer.zero_grad()
261
+ d_loss = dl(wav.detach().unsqueeze(1).float(), y_rec.detach()).mean()
262
+ accelerator.backward(d_loss)
263
+ optimizer.step('msd')
264
+ optimizer.step('mpd')
265
+ else:
266
+ d_loss = 0
267
+
268
+ # generator loss
269
+ optimizer.zero_grad()
270
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
271
+
272
+ if epoch >= TMA_epoch: # start TMA training
273
+ loss_s2s = 0
274
+ for _s2s_pred, _text_input, _text_length in zip(s2s_pred, texts, input_lengths):
275
+ loss_s2s += F.cross_entropy(_s2s_pred[:_text_length], _text_input[:_text_length])
276
+ loss_s2s /= texts.size(0)
277
+
278
+ loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10
279
+
280
+ loss_gen_all = gl(wav.detach().unsqueeze(1).float(), y_rec).mean()
281
+ loss_slm = wl(wav.detach(), y_rec).mean()
282
+
283
+ g_loss = loss_params.lambda_mel * loss_mel + \
284
+ loss_params.lambda_mono * loss_mono + \
285
+ loss_params.lambda_s2s * loss_s2s + \
286
+ loss_params.lambda_gen * loss_gen_all + \
287
+ loss_params.lambda_slm * loss_slm
288
+
289
+ else:
290
+ loss_s2s = 0
291
+ loss_mono = 0
292
+ loss_gen_all = 0
293
+ loss_slm = 0
294
+ g_loss = loss_mel
295
+
296
+ running_loss += accelerator.gather(loss_mel).mean().item()
297
+
298
+ accelerator.backward(g_loss)
299
+
300
+ optimizer.step('text_encoder')
301
+ optimizer.step('style_encoder')
302
+ optimizer.step('decoder')
303
+
304
+ if epoch >= TMA_epoch:
305
+ optimizer.step('text_aligner')
306
+ optimizer.step('pitch_extractor')
307
+
308
+ iters = iters + 1
309
+
310
+ if (i+1)%log_interval == 0 and accelerator.is_main_process:
311
+ log_print ('Epoch [%d/%d], Step [%d/%d], Mel Loss: %.5f, Gen Loss: %.5f, Disc Loss: %.5f, Mono Loss: %.5f, S2S Loss: %.5f, SLM Loss: %.5f'
312
+ %(epoch+1, epochs, i+1, len(train_list)//batch_size, running_loss / log_interval, loss_gen_all, d_loss, loss_mono, loss_s2s, loss_slm), logger)
313
+
314
+ writer.add_scalar('train/mel_loss', running_loss / log_interval, iters)
315
+ writer.add_scalar('train/gen_loss', loss_gen_all, iters)
316
+ writer.add_scalar('train/d_loss', d_loss, iters)
317
+ writer.add_scalar('train/mono_loss', loss_mono, iters)
318
+ writer.add_scalar('train/s2s_loss', loss_s2s, iters)
319
+ writer.add_scalar('train/slm_loss', loss_slm, iters)
320
+
321
+ running_loss = 0
322
+
323
+ print('Time elasped:', time.time()-start_time)
324
+
325
+ loss_test = 0
326
+
327
+ _ = [model[key].eval() for key in model]
328
+
329
+ with torch.no_grad():
330
+ iters_test = 0
331
+ for batch_idx, batch in enumerate(val_dataloader):
332
+ optimizer.zero_grad()
333
+
334
+ waves = batch[0]
335
+ batch = [b.to(device) for b in batch[1:]]
336
+ texts, input_lengths, _, _, mels, mel_input_length, _ = batch
337
+
338
+ with torch.no_grad():
339
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to('cuda')
340
+ ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
341
+
342
+ s2s_attn = s2s_attn.transpose(-1, -2)
343
+ s2s_attn = s2s_attn[..., 1:]
344
+ s2s_attn = s2s_attn.transpose(-1, -2)
345
+
346
+ text_mask = length_to_mask(input_lengths).to(texts.device)
347
+ attn_mask = (~mask).unsqueeze(-1).expand(mask.shape[0], mask.shape[1], text_mask.shape[-1]).float().transpose(-1, -2)
348
+ attn_mask = attn_mask.float() * (~text_mask).unsqueeze(-1).expand(text_mask.shape[0], text_mask.shape[1], mask.shape[-1]).float()
349
+ attn_mask = (attn_mask < 1)
350
+ s2s_attn.masked_fill_(attn_mask, 0.0)
351
+
352
+ # encode
353
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
354
+
355
+ asr = (t_en @ s2s_attn)
356
+
357
+ # get clips
358
+ mel_input_length_all = accelerator.gather(mel_input_length) # for balanced load
359
+ mel_len = min([int(mel_input_length.min().item() / 2 - 1), max_len // 2])
360
+
361
+ en = []
362
+ gt = []
363
+ wav = []
364
+ for bib in range(len(mel_input_length)):
365
+ mel_length = int(mel_input_length[bib].item() / 2)
366
+
367
+ random_start = np.random.randint(0, mel_length - mel_len)
368
+ en.append(asr[bib, :, random_start:random_start+mel_len])
369
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
370
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
371
+ wav.append(torch.from_numpy(y).to('cuda'))
372
+
373
+ wav = torch.stack(wav).float().detach()
374
+
375
+ en = torch.stack(en)
376
+ gt = torch.stack(gt).detach()
377
+
378
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
379
+ s = model.style_encoder(gt.unsqueeze(1))
380
+ real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
381
+ y_rec = model.decoder(en, F0_real, real_norm, s)
382
+
383
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
384
+
385
+ loss_test += accelerator.gather(loss_mel).mean().item()
386
+ iters_test += 1
387
+
388
+ if accelerator.is_main_process:
389
+ print('Epochs:', epoch + 1)
390
+ log_print('Validation loss: %.3f' % (loss_test / iters_test) + '\n\n\n\n', logger)
391
+ print('\n\n\n')
392
+ writer.add_scalar('eval/mel_loss', loss_test / iters_test, epoch + 1)
393
+ attn_image = get_image(s2s_attn[0].cpu().numpy().squeeze())
394
+ writer.add_figure('eval/attn', attn_image, epoch)
395
+
396
+ with torch.no_grad():
397
+ for bib in range(len(asr)):
398
+ mel_length = int(mel_input_length[bib].item())
399
+ gt = mels[bib, :, :mel_length].unsqueeze(0)
400
+ en = asr[bib, :, :mel_length // 2].unsqueeze(0)
401
+
402
+ F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
403
+ F0_real = F0_real.unsqueeze(0)
404
+ s = model.style_encoder(gt.unsqueeze(1))
405
+ real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
406
+
407
+ y_rec = model.decoder(en, F0_real, real_norm, s)
408
+
409
+ writer.add_audio('eval/y' + str(bib), y_rec.cpu().numpy().squeeze(), epoch, sample_rate=sr)
410
+ if epoch == 0:
411
+ writer.add_audio('gt/y' + str(bib), waves[bib].squeeze(), epoch, sample_rate=sr)
412
+
413
+ if bib >= 6:
414
+ break
415
+
416
+ if epoch % saving_epoch == 0:
417
+ if (loss_test / iters_test) < best_loss:
418
+ best_loss = loss_test / iters_test
419
+ print('Saving..')
420
+ state = {
421
+ 'net': {key: model[key].state_dict() for key in model},
422
+ 'optimizer': optimizer.state_dict(),
423
+ 'iters': iters,
424
+ 'val_loss': loss_test / iters_test,
425
+ 'epoch': epoch,
426
+ }
427
+ save_path = osp.join(log_dir, 'epoch_1st_%05d.pth' % epoch)
428
+ torch.save(state, save_path)
429
+
430
+ if accelerator.is_main_process:
431
+ print('Saving..')
432
+ state = {
433
+ 'net': {key: model[key].state_dict() for key in model},
434
+ 'optimizer': optimizer.state_dict(),
435
+ 'iters': iters,
436
+ 'val_loss': loss_test / iters_test,
437
+ 'epoch': epoch,
438
+ }
439
+ save_path = osp.join(log_dir, config.get('first_stage_path', 'first_stage.pth'))
440
+ torch.save(state, save_path)
441
+
442
+
443
+
444
+ if __name__=="__main__":
445
+ main()
train_second.py ADDED
@@ -0,0 +1,792 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # load packages
2
+ import random
3
+ import yaml
4
+ import time
5
+ from munch import Munch
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ import torchaudio
11
+ import librosa
12
+ import click
13
+ import shutil
14
+ import traceback
15
+ import warnings
16
+ warnings.simplefilter('ignore')
17
+ from torch.utils.tensorboard import SummaryWriter
18
+
19
+ from meldataset import build_dataloader
20
+
21
+ from Utils.ASR.models import ASRCNN
22
+ from Utils.JDC.model import JDCNet
23
+ from Utils.PLBERT.util import load_plbert
24
+
25
+ from models import *
26
+ from losses import *
27
+ from utils import *
28
+
29
+ from Modules.slmadv import SLMAdversarialLoss
30
+ from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
31
+
32
+ from optimizers import build_optimizer
33
+
34
+ # simple fix for dataparallel that allows access to class attributes
35
+ class MyDataParallel(torch.nn.DataParallel):
36
+ def __getattr__(self, name):
37
+ try:
38
+ return super().__getattr__(name)
39
+ except AttributeError:
40
+ return getattr(self.module, name)
41
+
42
+ import logging
43
+ from logging import StreamHandler
44
+ logger = logging.getLogger(__name__)
45
+ logger.setLevel(logging.DEBUG)
46
+ handler = StreamHandler()
47
+ handler.setLevel(logging.DEBUG)
48
+ logger.addHandler(handler)
49
+
50
+
51
+ @click.command()
52
+ @click.option('-p', '--config_path', default='Configs/config.yml', type=str)
53
+ def main(config_path):
54
+ config = yaml.safe_load(open(config_path))
55
+
56
+ log_dir = config['log_dir']
57
+ if not osp.exists(log_dir): os.makedirs(log_dir, exist_ok=True)
58
+ shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
59
+ writer = SummaryWriter(log_dir + "/tensorboard")
60
+
61
+ # write logs
62
+ file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
63
+ file_handler.setLevel(logging.DEBUG)
64
+ file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
65
+ logger.addHandler(file_handler)
66
+
67
+
68
+ batch_size = config.get('batch_size', 10)
69
+
70
+ epochs = config.get('epochs_2nd', 200)
71
+ save_freq = config.get('save_freq', 2)
72
+ log_interval = config.get('log_interval', 10)
73
+ saving_epoch = config.get('save_freq', 2)
74
+
75
+ data_params = config.get('data_params', None)
76
+ sr = config['preprocess_params'].get('sr', 24000)
77
+ train_path = data_params['train_data']
78
+ val_path = data_params['val_data']
79
+ root_path = data_params['root_path']
80
+ min_length = data_params['min_length']
81
+ OOD_data = data_params['OOD_data']
82
+
83
+ max_len = config.get('max_len', 200)
84
+
85
+ loss_params = Munch(config['loss_params'])
86
+ diff_epoch = loss_params.diff_epoch
87
+ joint_epoch = loss_params.joint_epoch
88
+
89
+ optimizer_params = Munch(config['optimizer_params'])
90
+
91
+ train_list, val_list = get_data_path_list(train_path, val_path)
92
+ device = 'cuda'
93
+
94
+ train_dataloader = build_dataloader(train_list,
95
+ root_path,
96
+ OOD_data=OOD_data,
97
+ min_length=min_length,
98
+ batch_size=batch_size,
99
+ num_workers=2,
100
+ dataset_config={},
101
+ device=device)
102
+
103
+ val_dataloader = build_dataloader(val_list,
104
+ root_path,
105
+ OOD_data=OOD_data,
106
+ min_length=min_length,
107
+ batch_size=batch_size,
108
+ validation=True,
109
+ num_workers=0,
110
+ device=device,
111
+ dataset_config={})
112
+
113
+ # load pretrained ASR model
114
+ ASR_config = config.get('ASR_config', False)
115
+ ASR_path = config.get('ASR_path', False)
116
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
117
+
118
+ # load pretrained F0 model
119
+ F0_path = config.get('F0_path', False)
120
+ pitch_extractor = load_F0_models(F0_path)
121
+
122
+ # load PL-BERT model
123
+ BERT_path = config.get('PLBERT_dir', False)
124
+ plbert = load_plbert(BERT_path)
125
+
126
+ # build model
127
+ model_params = recursive_munch(config['model_params'])
128
+ multispeaker = model_params.multispeaker
129
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
130
+ _ = [model[key].to(device) for key in model]
131
+
132
+ # DP
133
+ for key in model:
134
+ if key != "mpd" and key != "msd" and key != "wd":
135
+ model[key] = MyDataParallel(model[key])
136
+
137
+ start_epoch = 0
138
+ iters = 0
139
+
140
+ load_pretrained = config.get('pretrained_model', '') != '' and config.get('second_stage_load_pretrained', False)
141
+
142
+ if not load_pretrained:
143
+ if config.get('first_stage_path', '') != '':
144
+ first_stage_path = osp.join(log_dir, config.get('first_stage_path', 'first_stage.pth'))
145
+ print('Loading the first stage model at %s ...' % first_stage_path)
146
+ model, _, start_epoch, iters = load_checkpoint(model,
147
+ None,
148
+ first_stage_path,
149
+ load_only_params=True,
150
+ ignore_modules=['bert', 'bert_encoder', 'predictor', 'predictor_encoder', 'msd', 'mpd', 'wd', 'diffusion']) # keep starting epoch for tensorboard log
151
+
152
+ # these epochs should be counted from the start epoch
153
+ diff_epoch += start_epoch
154
+ joint_epoch += start_epoch
155
+ epochs += start_epoch
156
+
157
+ model.predictor_encoder = copy.deepcopy(model.style_encoder)
158
+ else:
159
+ raise ValueError('You need to specify the path to the first stage model.')
160
+
161
+ gl = GeneratorLoss(model.mpd, model.msd).to(device)
162
+ dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
163
+ wl = WavLMLoss(model_params.slm.model,
164
+ model.wd,
165
+ sr,
166
+ model_params.slm.sr).to(device)
167
+
168
+ gl = MyDataParallel(gl)
169
+ dl = MyDataParallel(dl)
170
+ wl = MyDataParallel(wl)
171
+
172
+ sampler = DiffusionSampler(
173
+ model.diffusion.diffusion,
174
+ sampler=ADPM2Sampler(),
175
+ sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
176
+ clamp=False
177
+ )
178
+
179
+ scheduler_params = {
180
+ "max_lr": optimizer_params.lr,
181
+ "pct_start": float(0),
182
+ "epochs": epochs,
183
+ "steps_per_epoch": len(train_dataloader),
184
+ }
185
+ scheduler_params_dict= {key: scheduler_params.copy() for key in model}
186
+ scheduler_params_dict['bert']['max_lr'] = optimizer_params.bert_lr * 2
187
+ scheduler_params_dict['decoder']['max_lr'] = optimizer_params.ft_lr * 2
188
+ scheduler_params_dict['style_encoder']['max_lr'] = optimizer_params.ft_lr * 2
189
+
190
+ optimizer = build_optimizer({key: model[key].parameters() for key in model},
191
+ scheduler_params_dict=scheduler_params_dict, lr=optimizer_params.lr)
192
+
193
+ # adjust BERT learning rate
194
+ for g in optimizer.optimizers['bert'].param_groups:
195
+ g['betas'] = (0.9, 0.99)
196
+ g['lr'] = optimizer_params.bert_lr
197
+ g['initial_lr'] = optimizer_params.bert_lr
198
+ g['min_lr'] = 0
199
+ g['weight_decay'] = 0.01
200
+
201
+ # adjust acoustic module learning rate
202
+ for module in ["decoder", "style_encoder"]:
203
+ for g in optimizer.optimizers[module].param_groups:
204
+ g['betas'] = (0.0, 0.99)
205
+ g['lr'] = optimizer_params.ft_lr
206
+ g['initial_lr'] = optimizer_params.ft_lr
207
+ g['min_lr'] = 0
208
+ g['weight_decay'] = 1e-4
209
+
210
+ # load models if there is a model
211
+ if load_pretrained:
212
+ model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'],
213
+ load_only_params=config.get('load_only_params', True))
214
+
215
+ n_down = model.text_aligner.n_down
216
+
217
+ best_loss = float('inf') # best test loss
218
+ loss_train_record = list([])
219
+ loss_test_record = list([])
220
+ iters = 0
221
+
222
+ criterion = nn.L1Loss() # F0 loss (regression)
223
+ torch.cuda.empty_cache()
224
+
225
+ stft_loss = MultiResolutionSTFTLoss().to(device)
226
+
227
+ print('BERT', optimizer.optimizers['bert'])
228
+ print('decoder', optimizer.optimizers['decoder'])
229
+
230
+ start_ds = False
231
+
232
+ running_std = []
233
+
234
+ slmadv_params = Munch(config['slmadv_params'])
235
+ slmadv = SLMAdversarialLoss(model, wl, sampler,
236
+ slmadv_params.min_len,
237
+ slmadv_params.max_len,
238
+ batch_percentage=slmadv_params.batch_percentage,
239
+ skip_update=slmadv_params.iter,
240
+ sig=slmadv_params.sig
241
+ )
242
+
243
+
244
+ for epoch in range(start_epoch, epochs):
245
+ running_loss = 0
246
+ start_time = time.time()
247
+
248
+ _ = [model[key].eval() for key in model]
249
+
250
+ model.predictor.train()
251
+ model.bert_encoder.train()
252
+ model.bert.train()
253
+ model.msd.train()
254
+ model.mpd.train()
255
+
256
+
257
+ if epoch >= diff_epoch:
258
+ start_ds = True
259
+
260
+ for i, batch in enumerate(train_dataloader):
261
+ waves = batch[0]
262
+ batch = [b.to(device) for b in batch[1:]]
263
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
264
+
265
+ with torch.no_grad():
266
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
267
+ mel_mask = length_to_mask(mel_input_length).to(device)
268
+ text_mask = length_to_mask(input_lengths).to(texts.device)
269
+
270
+ try:
271
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
272
+ s2s_attn = s2s_attn.transpose(-1, -2)
273
+ s2s_attn = s2s_attn[..., 1:]
274
+ s2s_attn = s2s_attn.transpose(-1, -2)
275
+ except:
276
+ continue
277
+
278
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
279
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
280
+
281
+ # encode
282
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
283
+ asr = (t_en @ s2s_attn_mono)
284
+
285
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
286
+
287
+ # compute reference styles
288
+ if multispeaker and epoch >= diff_epoch:
289
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
290
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
291
+ ref = torch.cat([ref_ss, ref_sp], dim=1)
292
+
293
+ # compute the style of the entire utterance
294
+ # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool)
295
+ ss = []
296
+ gs = []
297
+ for bib in range(len(mel_input_length)):
298
+ mel_length = int(mel_input_length[bib].item())
299
+ mel = mels[bib, :, :mel_input_length[bib]]
300
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
301
+ ss.append(s)
302
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
303
+ gs.append(s)
304
+
305
+ s_dur = torch.stack(ss).squeeze() # global prosodic styles
306
+ gs = torch.stack(gs).squeeze() # global acoustic styles
307
+ s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser
308
+
309
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
310
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
311
+
312
+ # denoiser training
313
+ if epoch >= diff_epoch:
314
+ num_steps = np.random.randint(3, 5)
315
+
316
+ if model_params.diffusion.dist.estimate_sigma_data:
317
+ model.diffusion.module.diffusion.sigma_data = s_trg.std(axis=-1).mean().item() # batch-wise std estimation
318
+ running_std.append(model.diffusion.module.diffusion.sigma_data)
319
+
320
+ if multispeaker:
321
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
322
+ embedding=bert_dur,
323
+ embedding_scale=1,
324
+ features=ref, # reference from the same speaker as the embedding
325
+ embedding_mask_proba=0.1,
326
+ num_steps=num_steps).squeeze(1)
327
+ loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean() # EDM loss
328
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
329
+ else:
330
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
331
+ embedding=bert_dur,
332
+ embedding_scale=1,
333
+ embedding_mask_proba=0.1,
334
+ num_steps=num_steps).squeeze(1)
335
+ loss_diff = model.diffusion.module.diffusion(s_trg.unsqueeze(1), embedding=bert_dur).mean() # EDM loss
336
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
337
+ else:
338
+ loss_sty = 0
339
+ loss_diff = 0
340
+
341
+ d, p = model.predictor(d_en, s_dur,
342
+ input_lengths,
343
+ s2s_attn_mono,
344
+ text_mask)
345
+
346
+ mel_len = min(int(mel_input_length.min().item() / 2 - 1), max_len // 2)
347
+ mel_len_st = int(mel_input_length.min().item() / 2 - 1)
348
+ en = []
349
+ gt = []
350
+ st = []
351
+ p_en = []
352
+ wav = []
353
+
354
+ for bib in range(len(mel_input_length)):
355
+ mel_length = int(mel_input_length[bib].item() / 2)
356
+
357
+ random_start = np.random.randint(0, mel_length - mel_len)
358
+ en.append(asr[bib, :, random_start:random_start+mel_len])
359
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
360
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
361
+
362
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
363
+ wav.append(torch.from_numpy(y).to(device))
364
+
365
+ # style reference (better to be different from the GT)
366
+ random_start = np.random.randint(0, mel_length - mel_len_st)
367
+ st.append(mels[bib, :, (random_start * 2):((random_start+mel_len_st) * 2)])
368
+
369
+ wav = torch.stack(wav).float().detach()
370
+
371
+ en = torch.stack(en)
372
+ p_en = torch.stack(p_en)
373
+ gt = torch.stack(gt).detach()
374
+ st = torch.stack(st).detach()
375
+
376
+ if gt.size(-1) < 80:
377
+ continue
378
+
379
+ s_dur = model.predictor_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
380
+ s = model.style_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
381
+
382
+ with torch.no_grad():
383
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
384
+ F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()
385
+
386
+ asr_real = model.text_aligner.get_feature(gt)
387
+
388
+ N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
389
+
390
+ y_rec_gt = wav.unsqueeze(1)
391
+ y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
392
+
393
+ if epoch >= joint_epoch:
394
+ # ground truth from recording
395
+ wav = y_rec_gt # use recording since decoder is tuned
396
+ else:
397
+ # ground truth from reconstruction
398
+ wav = y_rec_gt_pred # use reconstruction since decoder is fixed
399
+
400
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur)
401
+
402
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
403
+
404
+ loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10
405
+ loss_norm_rec = F.smooth_l1_loss(N_real, N_fake)
406
+
407
+ if start_ds:
408
+ optimizer.zero_grad()
409
+ d_loss = dl(wav.detach(), y_rec.detach()).mean()
410
+ d_loss.backward()
411
+ optimizer.step('msd')
412
+ optimizer.step('mpd')
413
+ else:
414
+ d_loss = 0
415
+
416
+ # generator loss
417
+ optimizer.zero_grad()
418
+
419
+ loss_mel = stft_loss(y_rec, wav)
420
+ if start_ds:
421
+ loss_gen_all = gl(wav, y_rec).mean()
422
+ else:
423
+ loss_gen_all = 0
424
+ loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean()
425
+
426
+ loss_ce = 0
427
+ loss_dur = 0
428
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
429
+ _s2s_pred = _s2s_pred[:_text_length, :]
430
+ _text_input = _text_input[:_text_length].long()
431
+ _s2s_trg = torch.zeros_like(_s2s_pred)
432
+ for p in range(_s2s_trg.shape[0]):
433
+ _s2s_trg[p, :_text_input[p]] = 1
434
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
435
+
436
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
437
+ _text_input[1:_text_length-1])
438
+ loss_ce += F.binary_cross_entropy_with_logits(_s2s_pred.flatten(), _s2s_trg.flatten())
439
+
440
+ loss_ce /= texts.size(0)
441
+ loss_dur /= texts.size(0)
442
+
443
+ g_loss = loss_params.lambda_mel * loss_mel + \
444
+ loss_params.lambda_F0 * loss_F0_rec + \
445
+ loss_params.lambda_ce * loss_ce + \
446
+ loss_params.lambda_norm * loss_norm_rec + \
447
+ loss_params.lambda_dur * loss_dur + \
448
+ loss_params.lambda_gen * loss_gen_all + \
449
+ loss_params.lambda_slm * loss_lm + \
450
+ loss_params.lambda_sty * loss_sty + \
451
+ loss_params.lambda_diff * loss_diff
452
+
453
+ running_loss += loss_mel.item()
454
+ g_loss.backward()
455
+ if torch.isnan(g_loss):
456
+ from IPython.core.debugger import set_trace
457
+ set_trace()
458
+
459
+ optimizer.step('bert_encoder')
460
+ optimizer.step('bert')
461
+ optimizer.step('predictor')
462
+ optimizer.step('predictor_encoder')
463
+
464
+ if epoch >= diff_epoch:
465
+ optimizer.step('diffusion')
466
+
467
+ if epoch >= joint_epoch:
468
+ optimizer.step('style_encoder')
469
+ optimizer.step('decoder')
470
+
471
+ # randomly pick whether to use in-distribution text
472
+ if np.random.rand() < 0.5:
473
+ use_ind = True
474
+ else:
475
+ use_ind = False
476
+
477
+ if use_ind:
478
+ ref_lengths = input_lengths
479
+ ref_texts = texts
480
+
481
+ slm_out = slmadv(i,
482
+ y_rec_gt,
483
+ y_rec_gt_pred,
484
+ waves,
485
+ mel_input_length,
486
+ ref_texts,
487
+ ref_lengths, use_ind, s_trg.detach(), ref if multispeaker else None)
488
+
489
+ if slm_out is None:
490
+ continue
491
+
492
+ d_loss_slm, loss_gen_lm, y_pred = slm_out
493
+
494
+ # SLM generator loss
495
+ optimizer.zero_grad()
496
+ loss_gen_lm.backward()
497
+
498
+ # compute the gradient norm
499
+ total_norm = {}
500
+ for key in model.keys():
501
+ total_norm[key] = 0
502
+ parameters = [p for p in model[key].parameters() if p.grad is not None and p.requires_grad]
503
+ for p in parameters:
504
+ param_norm = p.grad.detach().data.norm(2)
505
+ total_norm[key] += param_norm.item() ** 2
506
+ total_norm[key] = total_norm[key] ** 0.5
507
+
508
+ # gradient scaling
509
+ if total_norm['predictor'] > slmadv_params.thresh:
510
+ for key in model.keys():
511
+ for p in model[key].parameters():
512
+ if p.grad is not None:
513
+ p.grad *= (1 / total_norm['predictor'])
514
+
515
+ for p in model.predictor.duration_proj.parameters():
516
+ if p.grad is not None:
517
+ p.grad *= slmadv_params.scale
518
+
519
+ for p in model.predictor.lstm.parameters():
520
+ if p.grad is not None:
521
+ p.grad *= slmadv_params.scale
522
+
523
+ for p in model.diffusion.parameters():
524
+ if p.grad is not None:
525
+ p.grad *= slmadv_params.scale
526
+
527
+ optimizer.step('bert_encoder')
528
+ optimizer.step('bert')
529
+ optimizer.step('predictor')
530
+ optimizer.step('diffusion')
531
+
532
+ # SLM discriminator loss
533
+ if d_loss_slm != 0:
534
+ optimizer.zero_grad()
535
+ d_loss_slm.backward(retain_graph=True)
536
+ optimizer.step('wd')
537
+
538
+ else:
539
+ d_loss_slm, loss_gen_lm = 0, 0
540
+
541
+ iters = iters + 1
542
+
543
+ if (i+1)%log_interval == 0:
544
+ logger.info ('Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f'
545
+ %(epoch+1, epochs, i+1, len(train_list)//batch_size, running_loss / log_interval, d_loss, loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff, d_loss_slm, loss_gen_lm))
546
+
547
+ writer.add_scalar('train/mel_loss', running_loss / log_interval, iters)
548
+ writer.add_scalar('train/gen_loss', loss_gen_all, iters)
549
+ writer.add_scalar('train/d_loss', d_loss, iters)
550
+ writer.add_scalar('train/ce_loss', loss_ce, iters)
551
+ writer.add_scalar('train/dur_loss', loss_dur, iters)
552
+ writer.add_scalar('train/slm_loss', loss_lm, iters)
553
+ writer.add_scalar('train/norm_loss', loss_norm_rec, iters)
554
+ writer.add_scalar('train/F0_loss', loss_F0_rec, iters)
555
+ writer.add_scalar('train/sty_loss', loss_sty, iters)
556
+ writer.add_scalar('train/diff_loss', loss_diff, iters)
557
+ writer.add_scalar('train/d_loss_slm', d_loss_slm, iters)
558
+ writer.add_scalar('train/gen_loss_slm', loss_gen_lm, iters)
559
+
560
+ running_loss = 0
561
+
562
+ print('Time elasped:', time.time()-start_time)
563
+
564
+ loss_test = 0
565
+ loss_align = 0
566
+ loss_f = 0
567
+ _ = [model[key].eval() for key in model]
568
+
569
+ with torch.no_grad():
570
+ iters_test = 0
571
+ for batch_idx, batch in enumerate(val_dataloader):
572
+ optimizer.zero_grad()
573
+
574
+ try:
575
+ waves = batch[0]
576
+ batch = [b.to(device) for b in batch[1:]]
577
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
578
+ with torch.no_grad():
579
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to('cuda')
580
+ text_mask = length_to_mask(input_lengths).to(texts.device)
581
+
582
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
583
+ s2s_attn = s2s_attn.transpose(-1, -2)
584
+ s2s_attn = s2s_attn[..., 1:]
585
+ s2s_attn = s2s_attn.transpose(-1, -2)
586
+
587
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
588
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
589
+
590
+ # encode
591
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
592
+ asr = (t_en @ s2s_attn_mono)
593
+
594
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
595
+
596
+ ss = []
597
+ gs = []
598
+
599
+ for bib in range(len(mel_input_length)):
600
+ mel_length = int(mel_input_length[bib].item())
601
+ mel = mels[bib, :, :mel_input_length[bib]]
602
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
603
+ ss.append(s)
604
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
605
+ gs.append(s)
606
+
607
+ s = torch.stack(ss).squeeze()
608
+ gs = torch.stack(gs).squeeze()
609
+ s_trg = torch.cat([s, gs], dim=-1).detach()
610
+
611
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
612
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
613
+ d, p = model.predictor(d_en, s,
614
+ input_lengths,
615
+ s2s_attn_mono,
616
+ text_mask)
617
+ # get clips
618
+ mel_len = int(mel_input_length.min().item() / 2 - 1)
619
+ en = []
620
+ gt = []
621
+ p_en = []
622
+ wav = []
623
+
624
+ for bib in range(len(mel_input_length)):
625
+ mel_length = int(mel_input_length[bib].item() / 2)
626
+
627
+ random_start = np.random.randint(0, mel_length - mel_len)
628
+ en.append(asr[bib, :, random_start:random_start+mel_len])
629
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
630
+
631
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
632
+
633
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
634
+ wav.append(torch.from_numpy(y).to(device))
635
+
636
+ wav = torch.stack(wav).float().detach()
637
+
638
+ en = torch.stack(en)
639
+ p_en = torch.stack(p_en)
640
+ gt = torch.stack(gt).detach()
641
+
642
+ s = model.predictor_encoder(gt.unsqueeze(1))
643
+
644
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s)
645
+
646
+ loss_dur = 0
647
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
648
+ _s2s_pred = _s2s_pred[:_text_length, :]
649
+ _text_input = _text_input[:_text_length].long()
650
+ _s2s_trg = torch.zeros_like(_s2s_pred)
651
+ for bib in range(_s2s_trg.shape[0]):
652
+ _s2s_trg[bib, :_text_input[bib]] = 1
653
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
654
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
655
+ _text_input[1:_text_length-1])
656
+
657
+ loss_dur /= texts.size(0)
658
+
659
+ s = model.style_encoder(gt.unsqueeze(1))
660
+
661
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
662
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
663
+
664
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
665
+
666
+ loss_F0 = F.l1_loss(F0_real, F0_fake) / 10
667
+
668
+ loss_test += (loss_mel).mean()
669
+ loss_align += (loss_dur).mean()
670
+ loss_f += (loss_F0).mean()
671
+
672
+ iters_test += 1
673
+ except Exception as e:
674
+ print(f"run into exception", e)
675
+ traceback.print_exc()
676
+ continue
677
+
678
+ print('Epochs:', epoch + 1)
679
+ logger.info('Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f' % (loss_test / iters_test, loss_align / iters_test, loss_f / iters_test) + '\n\n\n')
680
+ print('\n\n\n')
681
+ writer.add_scalar('eval/mel_loss', loss_test / iters_test, epoch + 1)
682
+ writer.add_scalar('eval/dur_loss', loss_align / iters_test, epoch + 1)
683
+ writer.add_scalar('eval/F0_loss', loss_f / iters_test, epoch + 1)
684
+
685
+ if epoch < joint_epoch:
686
+ # generating reconstruction examples with GT duration
687
+
688
+ with torch.no_grad():
689
+ for bib in range(len(asr)):
690
+ mel_length = int(mel_input_length[bib].item())
691
+ gt = mels[bib, :, :mel_length].unsqueeze(0)
692
+ en = asr[bib, :, :mel_length // 2].unsqueeze(0)
693
+
694
+ F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
695
+ F0_real = F0_real.unsqueeze(0)
696
+ s = model.style_encoder(gt.unsqueeze(1))
697
+ real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
698
+
699
+ y_rec = model.decoder(en, F0_real, real_norm, s)
700
+
701
+ writer.add_audio('eval/y' + str(bib), y_rec.cpu().numpy().squeeze(), epoch, sample_rate=sr)
702
+
703
+ s_dur = model.predictor_encoder(gt.unsqueeze(1))
704
+ p_en = p[bib, :, :mel_length // 2].unsqueeze(0)
705
+
706
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur)
707
+
708
+ y_pred = model.decoder(en, F0_fake, N_fake, s)
709
+
710
+ writer.add_audio('pred/y' + str(bib), y_pred.cpu().numpy().squeeze(), epoch, sample_rate=sr)
711
+
712
+ if epoch == 0:
713
+ writer.add_audio('gt/y' + str(bib), waves[bib].squeeze(), epoch, sample_rate=sr)
714
+
715
+ if bib >= 5:
716
+ break
717
+ else:
718
+ # generating sampled speech from text directly
719
+ with torch.no_grad():
720
+ # compute reference styles
721
+ if multispeaker and epoch >= diff_epoch:
722
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
723
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
724
+ ref_s = torch.cat([ref_ss, ref_sp], dim=1)
725
+
726
+ for bib in range(len(d_en)):
727
+ if multispeaker:
728
+ s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(texts.device),
729
+ embedding=bert_dur[bib].unsqueeze(0),
730
+ embedding_scale=1,
731
+ features=ref_s[bib].unsqueeze(0), # reference from the same speaker as the embedding
732
+ num_steps=5).squeeze(1)
733
+ else:
734
+ s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(texts.device),
735
+ embedding=bert_dur[bib].unsqueeze(0),
736
+ embedding_scale=1,
737
+ num_steps=5).squeeze(1)
738
+
739
+ s = s_pred[:, 128:]
740
+ ref = s_pred[:, :128]
741
+
742
+ d = model.predictor.text_encoder(d_en[bib, :, :input_lengths[bib]].unsqueeze(0),
743
+ s, input_lengths[bib, ...].unsqueeze(0), text_mask[bib, :input_lengths[bib]].unsqueeze(0))
744
+
745
+ x, _ = model.predictor.lstm(d)
746
+ duration = model.predictor.duration_proj(x)
747
+
748
+ duration = torch.sigmoid(duration).sum(axis=-1)
749
+ pred_dur = torch.round(duration.squeeze()).clamp(min=1)
750
+
751
+ pred_dur[-1] += 5
752
+
753
+ pred_aln_trg = torch.zeros(input_lengths[bib], int(pred_dur.sum().data))
754
+ c_frame = 0
755
+ for i in range(pred_aln_trg.size(0)):
756
+ pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
757
+ c_frame += int(pred_dur[i].data)
758
+
759
+ # encode prosody
760
+ en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(texts.device))
761
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
762
+ out = model.decoder((t_en[bib, :, :input_lengths[bib]].unsqueeze(0) @ pred_aln_trg.unsqueeze(0).to(texts.device)),
763
+ F0_pred, N_pred, ref.squeeze().unsqueeze(0))
764
+
765
+ writer.add_audio('pred/y' + str(bib), out.cpu().numpy().squeeze(), epoch, sample_rate=sr)
766
+
767
+ if bib >= 5:
768
+ break
769
+
770
+ if epoch % saving_epoch == 0:
771
+ if (loss_test / iters_test) < best_loss:
772
+ best_loss = loss_test / iters_test
773
+ print('Saving..')
774
+ state = {
775
+ 'net': {key: model[key].state_dict() for key in model},
776
+ 'optimizer': optimizer.state_dict(),
777
+ 'iters': iters,
778
+ 'val_loss': loss_test / iters_test,
779
+ 'epoch': epoch,
780
+ }
781
+ save_path = osp.join(log_dir, 'epoch_2nd_%05d.pth' % epoch)
782
+ torch.save(state, save_path)
783
+
784
+ # if estimate sigma, save the estimated simga
785
+ if model_params.diffusion.dist.estimate_sigma_data:
786
+ config['model_params']['diffusion']['dist']['sigma_data'] = float(np.mean(running_std))
787
+
788
+ with open(osp.join(log_dir, osp.basename(config_path)), 'w') as outfile:
789
+ yaml.dump(config, outfile, default_flow_style=True)
790
+
791
+ if __name__=="__main__":
792
+ main()
utils.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from monotonic_align import maximum_path
2
+ from monotonic_align import mask_from_lens
3
+ from monotonic_align.core import maximum_path_c
4
+ import numpy as np
5
+ import torch
6
+ import copy
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+ import torchaudio
10
+ import librosa
11
+ import matplotlib.pyplot as plt
12
+ from munch import Munch
13
+
14
+ def maximum_path(neg_cent, mask):
15
+ """ Cython optimized version.
16
+ neg_cent: [b, t_t, t_s]
17
+ mask: [b, t_t, t_s]
18
+ """
19
+ device = neg_cent.device
20
+ dtype = neg_cent.dtype
21
+ neg_cent = np.ascontiguousarray(neg_cent.data.cpu().numpy().astype(np.float32))
22
+ path = np.ascontiguousarray(np.zeros(neg_cent.shape, dtype=np.int32))
23
+
24
+ t_t_max = np.ascontiguousarray(mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32))
25
+ t_s_max = np.ascontiguousarray(mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32))
26
+ maximum_path_c(path, neg_cent, t_t_max, t_s_max)
27
+ return torch.from_numpy(path).to(device=device, dtype=dtype)
28
+
29
+ def get_data_path_list(train_path=None, val_path=None):
30
+ if train_path is None:
31
+ train_path = "Data/train_list.txt"
32
+ if val_path is None:
33
+ val_path = "Data/val_list.txt"
34
+
35
+ with open(train_path, 'r', encoding='utf-8', errors='ignore') as f:
36
+ train_list = f.readlines()
37
+ with open(val_path, 'r', encoding='utf-8', errors='ignore') as f:
38
+ val_list = f.readlines()
39
+
40
+ return train_list, val_list
41
+
42
+ def length_to_mask(lengths):
43
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
44
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
45
+ return mask
46
+
47
+ # for norm consistency loss
48
+ def log_norm(x, mean=-4, std=4, dim=2):
49
+ """
50
+ normalized log mel -> mel -> norm -> log(norm)
51
+ """
52
+ x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
53
+ return x
54
+
55
+ def get_image(arrs):
56
+ plt.switch_backend('agg')
57
+ fig = plt.figure()
58
+ ax = plt.gca()
59
+ ax.imshow(arrs)
60
+
61
+ return fig
62
+
63
+ def recursive_munch(d):
64
+ if isinstance(d, dict):
65
+ return Munch((k, recursive_munch(v)) for k, v in d.items())
66
+ elif isinstance(d, list):
67
+ return [recursive_munch(v) for v in d]
68
+ else:
69
+ return d
70
+
71
+ def log_print(message, logger):
72
+ logger.info(message)
73
+ print(message)
74
+