BurhaanZargar commited on
Commit
592300c
·
verified ·
1 Parent(s): 1618227

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. matcha/VERSION +1 -0
  3. matcha/__init__.py +0 -0
  4. matcha/__pycache__/__init__.cpython-311.pyc +0 -0
  5. matcha/app.py +357 -0
  6. matcha/cli.py +419 -0
  7. matcha/data/__init__.py +0 -0
  8. matcha/data/__pycache__/__init__.cpython-311.pyc +0 -0
  9. matcha/data/__pycache__/text_mel_datamodule.cpython-311.pyc +0 -0
  10. matcha/data/components/__init__.py +0 -0
  11. matcha/data/text_mel_datamodule.py +274 -0
  12. matcha/hifigan/LICENSE +21 -0
  13. matcha/hifigan/README.md +101 -0
  14. matcha/hifigan/__init__.py +0 -0
  15. matcha/hifigan/__pycache__/__init__.cpython-311.pyc +0 -0
  16. matcha/hifigan/__pycache__/config.cpython-311.pyc +0 -0
  17. matcha/hifigan/__pycache__/env.cpython-311.pyc +0 -0
  18. matcha/hifigan/__pycache__/models.cpython-311.pyc +0 -0
  19. matcha/hifigan/__pycache__/xutils.cpython-311.pyc +0 -0
  20. matcha/hifigan/config.py +28 -0
  21. matcha/hifigan/denoiser.py +68 -0
  22. matcha/hifigan/env.py +17 -0
  23. matcha/hifigan/meldataset.py +217 -0
  24. matcha/hifigan/models.py +368 -0
  25. matcha/hifigan/xutils.py +60 -0
  26. matcha/models/__init__.py +0 -0
  27. matcha/models/__pycache__/__init__.cpython-311.pyc +0 -0
  28. matcha/models/__pycache__/baselightningmodule.cpython-311.pyc +0 -0
  29. matcha/models/__pycache__/matcha_tts.cpython-311.pyc +0 -0
  30. matcha/models/baselightningmodule.py +210 -0
  31. matcha/models/components/__init__.py +0 -0
  32. matcha/models/components/__pycache__/__init__.cpython-311.pyc +0 -0
  33. matcha/models/components/__pycache__/decoder.cpython-311.pyc +0 -0
  34. matcha/models/components/__pycache__/flow_matching.cpython-311.pyc +0 -0
  35. matcha/models/components/__pycache__/text_encoder.cpython-311.pyc +0 -0
  36. matcha/models/components/__pycache__/transformer.cpython-311.pyc +0 -0
  37. matcha/models/components/decoder.py +443 -0
  38. matcha/models/components/flow_matching.py +132 -0
  39. matcha/models/components/text_encoder.py +410 -0
  40. matcha/models/components/transformer.py +316 -0
  41. matcha/models/matcha_tts.py +245 -0
  42. matcha/onnx/__init__.py +0 -0
  43. matcha/onnx/export.py +181 -0
  44. matcha/onnx/infer.py +168 -0
  45. matcha/text/__init__.py +57 -0
  46. matcha/text/__pycache__/__init__.cpython-311.pyc +0 -0
  47. matcha/text/__pycache__/cleaners.cpython-311.pyc +0 -0
  48. matcha/text/__pycache__/symbols.cpython-311.pyc +0 -0
  49. matcha/text/cleaners.py +145 -0
  50. matcha/text/numbers.py +71 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ matcha/utils/monotonic_align/core.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
matcha/VERSION ADDED
@@ -0,0 +1 @@
 
 
1
+ 0.0.7.2
matcha/__init__.py ADDED
File without changes
matcha/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (171 Bytes). View file
 
matcha/app.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tempfile
2
+ from argparse import Namespace
3
+ from pathlib import Path
4
+
5
+ import gradio as gr
6
+ import soundfile as sf
7
+ import torch
8
+
9
+ from matcha.cli import (
10
+ MATCHA_URLS,
11
+ VOCODER_URLS,
12
+ assert_model_downloaded,
13
+ get_device,
14
+ load_matcha,
15
+ load_vocoder,
16
+ process_text,
17
+ to_waveform,
18
+ )
19
+ from matcha.utils.utils import get_user_data_dir, plot_tensor
20
+
21
# Directory where downloaded model/vocoder checkpoints are cached.
LOCATION = Path(get_user_data_dir())

# Default demo configuration: start with the multi-speaker VCTK model and
# its universal HiFi-GAN vocoder. `cpu=False` lets get_device pick CUDA if
# available.
args = Namespace(
    cpu=False,
    model="matcha_vctk",
    vocoder="hifigan_univ_v1",
    spk=0,
)

# Name of the acoustic model currently held in the global `model` variable;
# used to avoid reloading when the UI radio selection does not change.
CURRENTLY_LOADED_MODEL = args.model
31
+
32
+
33
def MATCHA_TTS_LOC(x):
    """Return the cached checkpoint path for the Matcha-TTS model named *x*."""
    return LOCATION / (x + ".ckpt")
35
+
36
+
37
def VOCODER_LOC(x):
    """Return the cached path for the vocoder checkpoint named *x*."""
    return LOCATION / str(x)
39
+
40
+
41
# Logo shown in the Gradio page header.
LOGO_URL = "https://shivammehta25.github.io/Matcha-TTS/images/logo.png"
# Maps each UI radio-button label to the (acoustic model, vocoder) pair it selects.
RADIO_OPTIONS = {
    "Multi Speaker (VCTK)": {
        "model": "matcha_vctk",
        "vocoder": "hifigan_univ_v1",
    },
    "Single Speaker (LJ Speech)": {
        "model": "matcha_ljspeech",
        "vocoder": "hifigan_T2_v1",
    },
}

# Ensure all the required models are downloaded
# NOTE(review): these run at import time and may download large checkpoints
# on first launch.
assert_model_downloaded(MATCHA_TTS_LOC("matcha_ljspeech"), MATCHA_URLS["matcha_ljspeech"])
assert_model_downloaded(VOCODER_LOC("hifigan_T2_v1"), VOCODER_URLS["hifigan_T2_v1"])
assert_model_downloaded(MATCHA_TTS_LOC("matcha_vctk"), MATCHA_URLS["matcha_vctk"])
assert_model_downloaded(VOCODER_LOC("hifigan_univ_v1"), VOCODER_URLS["hifigan_univ_v1"])

device = get_device(args)

# Load default model
model = load_matcha(args.model, MATCHA_TTS_LOC(args.model), device)
vocoder, denoiser = load_vocoder(args.vocoder, VOCODER_LOC(args.vocoder), device)
64
+
65
+
66
def load_model(model_name, vocoder_name):
    """Load an acoustic model and its vocoder/denoiser onto the global device.

    Returns the tuple (model, vocoder, denoiser).
    """
    acoustic = load_matcha(model_name, MATCHA_TTS_LOC(model_name), device)
    voc, den = load_vocoder(vocoder_name, VOCODER_LOC(vocoder_name), device)
    return acoustic, voc, den
70
+
71
+
72
def load_model_ui(model_type, textbox):
    """Swap the globally loaded model when the radio selection changes.

    Returns the Gradio component updates (text box, synth button, speaker
    slider, the two example rows, and the length-scale slider) appropriate
    for the newly selected model.
    """
    selection = RADIO_OPTIONS[model_type]
    model_name = selection["model"]
    vocoder_name = selection["vocoder"]

    global model, vocoder, denoiser, CURRENTLY_LOADED_MODEL  # pylint: disable=global-statement
    if CURRENTLY_LOADED_MODEL != model_name:
        model, vocoder, denoiser = load_model(model_name, vocoder_name)
        CURRENTLY_LOADED_MODEL = model_name

    if model_name == "matcha_ljspeech":
        # Single-speaker model: no speaker embedding, so hide the slider and
        # show only the LJ Speech cached examples.
        spk_slider = gr.update(visible=False, value=-1)
        single_speaker_examples = gr.update(visible=True)
        multi_speaker_examples = gr.update(visible=False)
        length_scale = gr.update(value=0.95)
    else:
        spk_slider = gr.update(visible=True, value=0)
        single_speaker_examples = gr.update(visible=False)
        multi_speaker_examples = gr.update(visible=True)
        length_scale = gr.update(value=0.85)

    return (
        textbox,
        gr.update(interactive=True),
        spk_slider,
        single_speaker_examples,
        multi_speaker_examples,
        length_scale,
    )
99
+
100
+
101
@torch.inference_mode()
def process_text_gradio(text):
    """Phonetise *text* for the UI.

    Returns (display phonemes, token tensor, token-length tensor).
    """
    processed = process_text(1, text, device)
    # x_phones interleaves blank symbols; every second character is a real phone.
    return processed["x_phones"][1::2], processed["x"], processed["x_lengths"]
105
+
106
+
107
@torch.inference_mode()
def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale, spk):
    """Synthesise audio from pre-processed tokens.

    Returns (path to a temporary wav file, rendered mel-spectrogram image).
    """
    # A negative speaker ID marks the single-speaker model: pass no speaker tensor.
    spk_tensor = None if spk < 0 else torch.tensor([spk], device=device, dtype=torch.long)
    output = model.synthesise(
        text,
        text_length,
        n_timesteps=n_timesteps,
        temperature=temperature,
        spks=spk_tensor,
        length_scale=length_scale,
    )
    output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
    # delete=False so Gradio can read the file after the handle is closed.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        sf.write(fp.name, output["waveform"], 22050, "PCM_24")

    return fp.name, plot_tensor(output["mel"].squeeze().cpu().numpy())
123
+
124
+
125
def multispeaker_example_cacher(text, n_timesteps, mel_temp, length_scale, spk):
    """Cache helper for VCTK examples: ensure the multi-speaker model is
    loaded, then run the normal phonetise + synthesise pipeline."""
    global CURRENTLY_LOADED_MODEL  # pylint: disable=global-statement
    if CURRENTLY_LOADED_MODEL != "matcha_vctk":
        global model, vocoder, denoiser  # pylint: disable=global-statement
        model, vocoder, denoiser = load_model("matcha_vctk", "hifigan_univ_v1")
        CURRENTLY_LOADED_MODEL = "matcha_vctk"

    phones, tokens, token_lengths = process_text_gradio(text)
    audio, mel = synthesise_mel(tokens, token_lengths, n_timesteps, mel_temp, length_scale, spk)
    return phones, audio, mel
135
+
136
+
137
def ljspeech_example_cacher(text, n_timesteps, mel_temp, length_scale, spk=-1):
    """Cache helper for LJ Speech examples: ensure the single-speaker model is
    loaded, then run the normal phonetise + synthesise pipeline."""
    global CURRENTLY_LOADED_MODEL  # pylint: disable=global-statement
    if CURRENTLY_LOADED_MODEL != "matcha_ljspeech":
        global model, vocoder, denoiser  # pylint: disable=global-statement
        model, vocoder, denoiser = load_model("matcha_ljspeech", "hifigan_T2_v1")
        CURRENTLY_LOADED_MODEL = "matcha_ljspeech"

    phones, tokens, token_lengths = process_text_gradio(text)
    audio, mel = synthesise_mel(tokens, token_lengths, n_timesteps, mel_temp, length_scale, spk)
    return phones, audio, mel
147
+
148
+
149
def main():
    """Build and launch the Gradio demo UI for Matcha-TTS.

    Lays out the description, model selector, hyper-parameter sliders, and
    cached example rows; wires the "Synthesise" button to a two-step chain
    (phonetise, then synthesise) and launches the app with a public share link.
    """
    description = """# 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching
### [Shivam Mehta](https://www.kth.se/profile/smehta), [Ruibo Tu](https://www.kth.se/profile/ruibo), [Jonas Beskow](https://www.kth.se/profile/beskow), [Éva Székely](https://www.kth.se/profile/szekely), and [Gustav Eje Henter](https://people.kth.se/~ghe/)
We propose 🍵 Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up ODE-based speech synthesis. Our method:


* Is probabilistic
* Has compact memory footprint
* Sounds highly natural
* Is very fast to synthesise from


Check out our [demo page](https://shivammehta25.github.io/Matcha-TTS). Read our [arXiv preprint for more details](https://arxiv.org/abs/2309.03199).
Code is available in our [GitHub repository](https://github.com/shivammehta25/Matcha-TTS), along with pre-trained models.

Cached examples are available at the bottom of the page.
"""

    with gr.Blocks(title="🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching") as demo:
        # Hidden state carrying the tokenised text between the two chained callbacks.
        processed_text = gr.State(value=None)
        processed_text_len = gr.State(value=None)

        # Header: description markdown, logo, and embedded video.
        with gr.Box():
            with gr.Row():
                gr.Markdown(description, scale=3)
                with gr.Column():
                    gr.Image(LOGO_URL, label="Matcha-TTS logo", height=50, width=50, scale=1, show_label=False)
                    html = '<br><iframe width="560" height="315" src="https://www.youtube.com/embed/xmvJkz3bqw0?si=jN7ILyDsbPwJCGoa" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" allowfullscreen></iframe>'
                    gr.HTML(html)

        # Model selection, text input, and synthesis hyper-parameters.
        with gr.Box():
            radio_options = list(RADIO_OPTIONS.keys())
            model_type = gr.Radio(
                radio_options, value=radio_options[0], label="Choose a Model", interactive=True, container=False
            )

            with gr.Row():
                gr.Markdown("# Text Input")
            with gr.Row():
                text = gr.Textbox(value="", lines=2, label="Text to synthesise", scale=3)
                # Speaker IDs 0-107 correspond to VCTK speakers; hidden for LJ Speech.
                spk_slider = gr.Slider(
                    minimum=0, maximum=107, step=1, value=args.spk, label="Speaker ID", interactive=True, scale=1
                )

            with gr.Row():
                gr.Markdown("### Hyper parameters")
            with gr.Row():
                n_timesteps = gr.Slider(
                    label="Number of ODE steps",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=10,
                    interactive=True,
                )
                length_scale = gr.Slider(
                    label="Length scale (Speaking rate)",
                    minimum=0.5,
                    maximum=1.5,
                    step=0.05,
                    value=1.0,
                    interactive=True,
                )
                mel_temp = gr.Slider(
                    label="Sampling temperature",
                    minimum=0.00,
                    maximum=2.001,
                    step=0.16675,
                    value=0.667,
                    interactive=True,
                )

                synth_btn = gr.Button("Synthesise")

        with gr.Box():
            with gr.Row():
                gr.Markdown("### Phonetised text")
                phonetised_text = gr.Textbox(interactive=False, scale=10, label="Phonetised text")

        with gr.Box():
            with gr.Row():
                mel_spectrogram = gr.Image(interactive=False, label="mel spectrogram")

                # with gr.Row():
                audio = gr.Audio(interactive=False, label="Audio")

        # Cached single-speaker examples; hidden until LJ Speech is selected.
        with gr.Row(visible=False) as example_row_lj_speech:
            examples = gr.Examples(  # pylint: disable=unused-variable
                examples=[
                    [
                        "We propose Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up O D E-based speech synthesis.",
                        50,
                        0.677,
                        0.95,
                    ],
                    [
                        "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
                        2,
                        0.677,
                        0.95,
                    ],
                    [
                        "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
                        4,
                        0.677,
                        0.95,
                    ],
                    [
                        "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
                        10,
                        0.677,
                        0.95,
                    ],
                    [
                        "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
                        50,
                        0.677,
                        0.95,
                    ],
                    [
                        "The narrative of these events is based largely on the recollections of the participants.",
                        10,
                        0.677,
                        0.95,
                    ],
                    [
                        "The jury did not believe him, and the verdict was for the defendants.",
                        10,
                        0.677,
                        0.95,
                    ],
                ],
                fn=ljspeech_example_cacher,
                inputs=[text, n_timesteps, mel_temp, length_scale],
                outputs=[phonetised_text, audio, mel_spectrogram],
                cache_examples=True,
            )

        # Cached multi-speaker examples; visible for the default VCTK model.
        with gr.Row() as example_row_multispeaker:
            multi_speaker_examples = gr.Examples(  # pylint: disable=unused-variable
                examples=[
                    [
                        "Hello everyone! I am speaker 0 and I am here to tell you that Matcha-TTS is amazing!",
                        10,
                        0.677,
                        0.85,
                        0,
                    ],
                    [
                        "Hello everyone! I am speaker 16 and I am here to tell you that Matcha-TTS is amazing!",
                        10,
                        0.677,
                        0.85,
                        16,
                    ],
                    [
                        "Hello everyone! I am speaker 44 and I am here to tell you that Matcha-TTS is amazing!",
                        50,
                        0.677,
                        0.85,
                        44,
                    ],
                    [
                        "Hello everyone! I am speaker 45 and I am here to tell you that Matcha-TTS is amazing!",
                        50,
                        0.677,
                        0.85,
                        45,
                    ],
                    [
                        "Hello everyone! I am speaker 58 and I am here to tell you that Matcha-TTS is amazing!",
                        4,
                        0.677,
                        0.85,
                        58,
                    ],
                ],
                fn=multispeaker_example_cacher,
                inputs=[text, n_timesteps, mel_temp, length_scale, spk_slider],
                outputs=[phonetised_text, audio, mel_spectrogram],
                cache_examples=True,
                label="Multi Speaker Examples",
            )

        # Disable the button while the (possibly slow) model swap happens,
        # then re-enable it via load_model_ui's returned updates.
        model_type.change(lambda x: gr.update(interactive=False), inputs=[synth_btn], outputs=[synth_btn]).then(
            load_model_ui,
            inputs=[model_type, text],
            outputs=[text, synth_btn, spk_slider, example_row_lj_speech, example_row_multispeaker, length_scale],
        )

        # Two-step chain: phonetise first (also exposed as API "matcha_tts"),
        # then synthesise from the stored tokens.
        synth_btn.click(
            fn=process_text_gradio,
            inputs=[
                text,
            ],
            outputs=[phonetised_text, processed_text, processed_text_len],
            api_name="matcha_tts",
            queue=True,
        ).then(
            fn=synthesise_mel,
            inputs=[processed_text, processed_text_len, n_timesteps, mel_temp, length_scale, spk_slider],
            outputs=[audio, mel_spectrogram],
        )

        # NOTE(review): share=True opens a public tunnel to this app — confirm
        # this is intended outside of the hosted Space.
        demo.queue().launch(share=True)


if __name__ == "__main__":
    main()
matcha/cli.py ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import datetime as dt
3
+ import os
4
+ import warnings
5
+ from pathlib import Path
6
+
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import soundfile as sf
10
+ import torch
11
+
12
+ from matcha.hifigan.config import v1
13
+ from matcha.hifigan.denoiser import Denoiser
14
+ from matcha.hifigan.env import AttrDict
15
+ from matcha.hifigan.models import Generator as HiFiGAN
16
+ from matcha.models.matcha_tts import MatchaTTS
17
+ from matcha.text import sequence_to_text, text_to_sequence
18
+ from matcha.utils.utils import assert_model_downloaded, get_user_data_dir, intersperse
19
+
20
# Download URLs for the pretrained Matcha-TTS acoustic-model checkpoints.
MATCHA_URLS = {
    "matcha_ljspeech": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/matcha_ljspeech.ckpt",
    "matcha_vctk": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/matcha_vctk.ckpt",
}

# Download URLs for the pretrained HiFi-GAN vocoder checkpoints.
VOCODER_URLS = {
    "hifigan_T2_v1": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/generator_v1",  # Old url: https://drive.google.com/file/d/14NENd4equCBLyyCSke114Mv6YR_j_uFs/view?usp=drive_link
    "hifigan_univ_v1": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/g_02500000",  # Old url: https://drive.google.com/file/d/1qpgI41wNXFcH-iKq1Y42JlBC9j0je8PW/view?usp=drive_link
}

# Per-model defaults (vocoder, speaking rate, speaker id/range) for the
# pretrained multi-speaker model(s).
MULTISPEAKER_MODEL = {
    "matcha_vctk": {"vocoder": "hifigan_univ_v1", "speaking_rate": 0.85, "spk": 0, "spk_range": (0, 107)}
}

# Per-model defaults for the pretrained single-speaker model(s); spk=None
# means no speaker embedding is used.
SINGLESPEAKER_MODEL = {"matcha_ljspeech": {"vocoder": "hifigan_T2_v1", "speaking_rate": 0.95, "spk": None}}
35
+
36
+
37
def plot_spectrogram_to_numpy(spectrogram, filename):
    """Render *spectrogram* as an image and save it to *filename*.

    Args:
        spectrogram: 2-D array-like (channels x frames) mel spectrogram.
        filename: Path of the image file to write.
    """
    fig, ax = plt.subplots(figsize=(12, 3))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.title("Synthesised Mel-Spectrogram")
    fig.canvas.draw()
    plt.savefig(filename)
    # Close the figure explicitly: matplotlib keeps figures alive otherwise,
    # so long batched runs would leak memory (one open figure per utterance).
    plt.close(fig)
46
+
47
+
48
def process_text(i: int, text: str, device: torch.device):
    """Convert raw *text* into token tensors ready for synthesis.

    Returns a dict with the original text ("x_orig"), token ids "x" (1 x T),
    their lengths "x_lengths", and the phonetised string "x_phones".
    """
    print(f"[{i}] - Input text: {text}")
    symbol_ids = text_to_sequence(text, ["basic_cleaners"])[0]
    # Intersperse the blank token (id 0) between symbols, matching training.
    tokens = torch.tensor(intersperse(symbol_ids, 0), dtype=torch.long, device=device)[None]
    token_lengths = torch.tensor([tokens.shape[-1]], dtype=torch.long, device=device)
    phones = sequence_to_text(tokens.squeeze(0).tolist())
    print(f"[{i}] - Phonetised text: {phones[1::2]}")

    return {"x_orig": text, "x": tokens, "x_lengths": token_lengths, "x_phones": phones}
60
+
61
+
62
def get_texts(args):
    """Return the texts to synthesise: the single --text argument if given,
    otherwise every line of the --file argument (newlines preserved)."""
    if args.text:
        return [args.text]
    with open(args.file, encoding="utf-8") as f:
        return f.readlines()
69
+
70
+
71
def assert_required_models_available(args):
    """Resolve (and download if necessary) the model and vocoder checkpoints.

    Returns:
        dict with keys "matcha" and "vocoder" mapping to local checkpoint paths.
    """
    save_dir = get_user_data_dir()
    # BUG FIX: the original condition `not hasattr(args, "checkpoint_path")
    # and args.checkpoint_path is None` could never select the custom
    # checkpoint (argparse always sets the attribute, making the first clause
    # False) and would raise AttributeError if the attribute were truly
    # absent. Use the custom checkpoint when one was supplied; otherwise fall
    # back to the pretrained model and ensure it is downloaded.
    if getattr(args, "checkpoint_path", None) is not None:
        model_path = args.checkpoint_path
    else:
        model_path = save_dir / f"{args.model}.ckpt"
        assert_model_downloaded(model_path, MATCHA_URLS[args.model])

    vocoder_path = save_dir / f"{args.vocoder}"
    assert_model_downloaded(vocoder_path, VOCODER_URLS[args.vocoder])
    return {"matcha": model_path, "vocoder": vocoder_path}
82
+
83
+
84
def load_hifigan(checkpoint_path, device):
    """Build a HiFi-GAN generator, load its weights, and prepare it for inference."""
    generator = HiFiGAN(AttrDict(v1)).to(device)
    state = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(state["generator"])
    generator.eval()
    # Weight norm is only needed during training; removing it speeds up inference.
    generator.remove_weight_norm()
    return generator
91
+
92
+
93
def load_vocoder(vocoder_name, checkpoint_path, device):
    """Load the requested vocoder plus a matching bias denoiser.

    Raises:
        NotImplementedError: if *vocoder_name* is not a known HiFi-GAN variant.
    """
    print(f"[!] Loading {vocoder_name}!")
    if vocoder_name not in ("hifigan_T2_v1", "hifigan_univ_v1"):
        raise NotImplementedError(
            f"Vocoder {vocoder_name} not implemented! define a load_<<vocoder_name>> method for it"
        )
    vocoder = load_hifigan(checkpoint_path, device)

    denoiser = Denoiser(vocoder, mode="zeros")
    print(f"[+] {vocoder_name} loaded!")
    return vocoder, denoiser
106
+
107
+
108
def load_matcha(model_name, checkpoint_path, device):
    """Load a Matcha-TTS checkpoint onto *device* and switch it to eval mode."""
    print(f"[!] Loading {model_name}!")
    model = MatchaTTS.load_from_checkpoint(checkpoint_path, map_location=device)
    model.eval()
    print(f"[+] {model_name} loaded!")
    return model
115
+
116
+
117
def to_waveform(mel, vocoder, denoiser=None, denoiser_strength=0.00025):
    """Vocode a mel spectrogram into a CPU audio waveform.

    Args:
        mel: Mel-spectrogram tensor produced by the acoustic model.
        vocoder: Callable mapping mel -> raw audio tensor.
        denoiser: Optional vocoder-bias denoiser applied to the raw audio.
        denoiser_strength: Denoiser strength (default 0.00025).

    Returns:
        Squeezed CPU tensor with samples clamped to [-1, 1].
    """
    audio = vocoder(mel).clamp(-1, 1)
    if denoiser is not None:
        # The original moved the audio to CPU and squeezed twice on this
        # path; denoise first and do the transfer once on the way out.
        audio = denoiser(audio.squeeze(), strength=denoiser_strength)

    return audio.cpu().squeeze()
123
+
124
+
125
def save_to_folder(filename: str, output: dict, folder: str):
    """Save a synthesis result (mel plot, mel array, waveform) for one utterance.

    Args:
        filename: Base name (without extension) for the output files.
        output: Dict with "mel" (tensor) and "waveform" (audio samples).
        folder: Output directory, created if missing.

    Returns:
        Absolute path of the written wav file.
    """
    folder = Path(folder)
    folder.mkdir(exist_ok=True, parents=True)
    # BUG FIX: the f-strings had lost the `filename` placeholder (the
    # parameter was entirely unused), so every utterance overwrote the same
    # files. NOTE(review): the spectrogram plot is written to the current
    # working directory, not `folder` — confirm whether that is intended.
    plot_spectrogram_to_numpy(np.array(output["mel"].squeeze().float().cpu()), f"{filename}.png")
    np.save(folder / f"{filename}", output["mel"].cpu().numpy())
    sf.write(folder / f"{filename}.wav", output["waveform"], 22050, "PCM_24")
    return folder.resolve() / f"{filename}.wav"
132
+
133
+
134
def validate_args(args):
    """Validate CLI arguments and fill model-dependent defaults.

    Delegates per-model defaulting to the single-/multi-speaker validators
    when a pretrained model is used; otherwise applies custom-checkpoint
    defaults. Returns the (possibly modified) args namespace.
    """
    # BUG FIX (message only): the original assertion text read
    # "...provided Matcha-T(ea)TTS need sometext..." — garbled wording.
    assert (
        args.text or args.file
    ), "Either text or file must be provided. Matcha-T(ea)TTS needs some text to whisk the waveforms."
    assert args.temperature >= 0, "Sampling temperature cannot be negative"
    assert args.steps > 0, "Number of ODE steps must be greater than 0"

    if args.checkpoint_path is None:
        # When using pretrained models
        if args.model in SINGLESPEAKER_MODEL:
            args = validate_args_for_single_speaker_model(args)

        if args.model in MULTISPEAKER_MODEL:
            args = validate_args_for_multispeaker_model(args)
    else:
        # When using a custom model
        if args.vocoder != "hifigan_univ_v1":
            warn_ = "[-] Using custom model checkpoint! I would suggest passing --vocoder hifigan_univ_v1, unless the custom model is trained on LJ Speech."
            warnings.warn(warn_, UserWarning)
        if args.speaking_rate is None:
            args.speaking_rate = 1.0

    if args.batched:
        assert args.batch_size > 0, "Batch size must be greater than 0"
    assert args.speaking_rate > 0, "Speaking rate must be greater than 0"

    return args
161
+
162
+
163
def validate_args_for_multispeaker_model(args):
    """Fill defaults and sanity-check args for a pretrained multi-speaker model."""
    defaults = MULTISPEAKER_MODEL[args.model]

    if args.vocoder is None:
        args.vocoder = defaults["vocoder"]
    elif args.vocoder != defaults["vocoder"]:
        # A mismatched vocoder still runs, but quality suffers — warn only.
        warn_ = f"[-] Using {args.model} model! I would suggest passing --vocoder {MULTISPEAKER_MODEL[args.model]['vocoder']}"
        warnings.warn(warn_, UserWarning)

    if args.speaking_rate is None:
        args.speaking_rate = defaults["speaking_rate"]

    spk_range = defaults["spk_range"]
    if args.spk is None:
        available_spk_id = defaults["spk"]
        warn_ = f"[!] Speaker ID not provided! Using speaker ID {available_spk_id}"
        warnings.warn(warn_, UserWarning)
        args.spk = available_spk_id
    else:
        assert (
            args.spk >= spk_range[0] and args.spk <= spk_range[-1]
        ), f"Speaker ID must be between {spk_range} for this model."

    return args
186
+
187
+
188
def validate_args_for_single_speaker_model(args):
    """Fill defaults and sanity-check args for a pretrained single-speaker model."""
    defaults = SINGLESPEAKER_MODEL[args.model]

    if args.vocoder is None:
        args.vocoder = defaults["vocoder"]
    elif args.vocoder != defaults["vocoder"]:
        warn_ = f"[-] Using {args.model} model! I would suggest passing --vocoder {SINGLESPEAKER_MODEL[args.model]['vocoder']}"
        warnings.warn(warn_, UserWarning)

    if args.speaking_rate is None:
        args.speaking_rate = defaults["speaking_rate"]

    # Single-speaker models have no speaker embedding: force spk to the
    # model's expected value (None) and tell the user about it.
    if args.spk != defaults["spk"]:
        warn_ = f"[-] Ignoring speaker id {args.spk} for {args.model}"
        warnings.warn(warn_, UserWarning)
        args.spk = defaults["spk"]

    return args
205
+
206
+
207
@torch.inference_mode()
def cli():
    """Command-line entry point: parse arguments, load models, and synthesise.

    Dispatches to batched or unbatched synthesis depending on --batched and
    the number of input texts.
    """
    parser = argparse.ArgumentParser(
        description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="matcha_ljspeech",
        help="Model to use",
        choices=MATCHA_URLS.keys(),
    )

    parser.add_argument(
        "--checkpoint_path",
        type=str,
        default=None,
        help="Path to the custom model checkpoint",
    )

    parser.add_argument(
        "--vocoder",
        type=str,
        default=None,
        help="Vocoder to use (default: will use the one suggested with the pretrained model))",
        choices=VOCODER_URLS.keys(),
    )
    parser.add_argument("--text", type=str, default=None, help="Text to synthesize")
    parser.add_argument("--file", type=str, default=None, help="Text file to synthesize")
    parser.add_argument("--spk", type=int, default=None, help="Speaker ID")
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.667,
        help="Variance of the x0 noise (default: 0.667)",
    )
    parser.add_argument(
        "--speaking_rate",
        type=float,
        default=None,
        help="change the speaking rate, a higher value means slower speaking rate (default: 1.0)",
    )
    parser.add_argument("--steps", type=int, default=10, help="Number of ODE steps (default: 10)")
    parser.add_argument("--cpu", action="store_true", help="Use CPU for inference (default: use GPU if available)")
    parser.add_argument(
        "--denoiser_strength",
        type=float,
        default=0.00025,
        help="Strength of the vocoder bias denoiser (default: 0.00025)",
    )
    parser.add_argument(
        "--output_folder",
        type=str,
        default=os.getcwd(),
        help="Output folder to save results (default: current dir)",
    )
    parser.add_argument("--batched", action="store_true", help="Batched inference (default: False)")
    parser.add_argument(
        "--batch_size", type=int, default=32, help="Batch size only useful when --batched (default: 32)"
    )

    args = parser.parse_args()

    # Fill model-dependent defaults and fail fast on inconsistent options.
    args = validate_args(args)
    device = get_device(args)
    print_config(args)
    paths = assert_required_models_available(args)

    # A custom checkpoint overrides the pretrained model path.
    if args.checkpoint_path is not None:
        print(f"[🍵] Loading custom model from {args.checkpoint_path}")
        paths["matcha"] = args.checkpoint_path
        args.model = "custom_model"

    model = load_matcha(args.model, paths["matcha"], device)
    vocoder, denoiser = load_vocoder(args.vocoder, paths["vocoder"], device)

    texts = get_texts(args)

    spk = torch.tensor([args.spk], device=device, dtype=torch.long) if args.spk is not None else None
    # Batching only pays off with multiple texts.
    if len(texts) == 1 or not args.batched:
        unbatched_synthesis(args, device, model, vocoder, denoiser, texts, spk)
    else:
        batched_synthesis(args, device, model, vocoder, denoiser, texts, spk)
290
+
291
+
292
class BatchedSynthesisDataset(torch.utils.data.Dataset):
    """Thin Dataset wrapper over a list of pre-processed text dicts so a
    DataLoader can batch them for synthesis."""

    def __init__(self, processed_texts):
        # List of dicts as produced by process_text().
        self.processed_texts = processed_texts

    def __len__(self):
        return len(self.processed_texts)

    def __getitem__(self, idx):
        return self.processed_texts[idx]
301
+
302
+
303
def batched_collate_fn(batch):
    """Collate processed-text dicts into one padded batch.

    Pads token sequences to the longest in the batch (zero-padded, batch
    first) and concatenates the per-item length tensors.
    """
    tokens = [item["x"].squeeze(0) for item in batch]
    lengths = [item["x_lengths"] for item in batch]

    return {
        "x": torch.nn.utils.rnn.pad_sequence(tokens, batch_first=True),
        "x_lengths": torch.concat(lengths, dim=0),
    }
314
+
315
+
316
def batched_synthesis(args, device, model, vocoder, denoiser, texts, spk):
    """Synthesise *texts* in batches, saving each utterance to the output folder.

    Prints per-batch and average real-time factors (RTF) with and without the
    vocoder.
    """
    total_rtf = []
    total_rtf_w = []
    processed_text = [process_text(i, text, "cpu") for i, text in enumerate(texts)]
    dataloader = torch.utils.data.DataLoader(
        BatchedSynthesisDataset(processed_text),
        batch_size=args.batch_size,
        collate_fn=batched_collate_fn,
        num_workers=8,
    )
    # BUG FIX: the original named output files utterance_{j} with j restarting
    # at 0 in every batch, so each batch silently overwrote the previous
    # batch's files. Keep one global utterance counter instead.
    utterance_idx = 0
    for i, batch in enumerate(dataloader):
        i = i + 1  # 1-based batch numbering for logs
        start_t = dt.datetime.now()
        b = batch["x"].shape[0]
        output = model.synthesise(
            batch["x"].to(device),
            batch["x_lengths"].to(device),
            n_timesteps=args.steps,
            temperature=args.temperature,
            spks=spk.expand(b) if spk is not None else spk,
            length_scale=args.speaking_rate,
        )

        output["waveform"] = to_waveform(output["mel"], vocoder, denoiser, args.denoiser_strength)
        t = (dt.datetime.now() - start_t).total_seconds()
        # 22050 is the models' output sample rate.
        rtf_w = t * 22050 / (output["waveform"].shape[-1])
        print(f"[🍵-Batch: {i}] Matcha-TTS RTF: {output['rtf']:.4f}")
        print(f"[🍵-Batch: {i}] Matcha-TTS + VOCODER RTF: {rtf_w:.4f}")
        total_rtf.append(output["rtf"])
        total_rtf_w.append(rtf_w)
        for j in range(output["mel"].shape[0]):
            base_name = (
                f"utterance_{utterance_idx:03d}_speaker_{args.spk:03d}"
                if args.spk is not None
                else f"utterance_{utterance_idx:03d}"
            )
            utterance_idx += 1
            # Trim padding: mel_lengths gives the true frame count; the
            # waveform uses 256 samples per mel frame (hop length).
            length = output["mel_lengths"][j]
            new_dict = {"mel": output["mel"][j][:, :length], "waveform": output["waveform"][j][: length * 256]}
            location = save_to_folder(base_name, new_dict, args.output_folder)
            print(f"[🍵-{j}] Waveform saved: {location}")

    print("".join(["="] * 100))
    print(f"[🍵] Average Matcha-TTS RTF: {np.mean(total_rtf):.4f} ± {np.std(total_rtf)}")
    print(f"[🍵] Average Matcha-TTS + VOCODER RTF: {np.mean(total_rtf_w):.4f} ± {np.std(total_rtf_w)}")
    print("[🍵] Enjoy the freshly whisked 🍵 Matcha-TTS!")
357
+
358
+
359
def unbatched_synthesis(args, device, model, vocoder, denoiser, texts, spk):
    """Synthesise each text one at a time, saving outputs and reporting RTF.

    Args:
        args: Validated CLI namespace (steps, temperature, speaking_rate, ...).
        device: torch device for synthesis.
        model: Loaded Matcha-TTS acoustic model.
        vocoder: Loaded vocoder; denoiser: its bias denoiser.
        texts: List of raw input strings.
        spk: Optional speaker-id tensor, or None for single-speaker models.
    """
    total_rtf = []
    total_rtf_w = []
    for i, text in enumerate(texts):
        i = i + 1  # 1-based utterance numbering for filenames and logs
        base_name = f"utterance_{i:03d}_speaker_{args.spk:03d}" if args.spk is not None else f"utterance_{i:03d}"

        print("".join(["="] * 100))
        text = text.strip()
        text_processed = process_text(i, text, device)

        print(f"[🍵] Whisking Matcha-T(ea)TS for: {i}")
        start_t = dt.datetime.now()
        output = model.synthesise(
            text_processed["x"],
            text_processed["x_lengths"],
            n_timesteps=args.steps,
            temperature=args.temperature,
            spks=spk,
            length_scale=args.speaking_rate,
        )
        output["waveform"] = to_waveform(output["mel"], vocoder, denoiser, args.denoiser_strength)
        # RTF with HiFiGAN
        t = (dt.datetime.now() - start_t).total_seconds()
        # 22050 Hz output sample rate; rtf_w includes vocoding time.
        rtf_w = t * 22050 / (output["waveform"].shape[-1])
        print(f"[🍵-{i}] Matcha-TTS RTF: {output['rtf']:.4f}")
        print(f"[🍵-{i}] Matcha-TTS + VOCODER RTF: {rtf_w:.4f}")
        total_rtf.append(output["rtf"])
        total_rtf_w.append(rtf_w)

        location = save_to_folder(base_name, output, args.output_folder)
        print(f"[+] Waveform saved: {location}")

    print("".join(["="] * 100))
    print(f"[🍵] Average Matcha-TTS RTF: {np.mean(total_rtf):.4f} ± {np.std(total_rtf)}")
    print(f"[🍵] Average Matcha-TTS + VOCODER RTF: {np.mean(total_rtf_w):.4f} ± {np.std(total_rtf_w)}")
    print("[🍵] Enjoy the freshly whisked 🍵 Matcha-TTS!")
396
+
397
+
398
def print_config(args):
    """Print the synthesis configuration in a human-readable form."""
    print("[!] Configurations: ")
    for label, value in [
        ("Model", args.model),
        ("Vocoder", args.vocoder),
        ("Temperature", args.temperature),
        ("Speaking rate", args.speaking_rate),
        ("Number of ODE steps", args.steps),
        ("Speaker", args.spk),
    ]:
        print(f"\t- {label}: {value}")
406
+
407
+
408
def get_device(args):
    """Pick CUDA when available unless --cpu was passed; report the choice."""
    if args.cpu or not torch.cuda.is_available():
        print("[-] GPU not available or forced CPU run! Using CPU")
        return torch.device("cpu")
    print("[+] GPU Available! Using GPU")
    return torch.device("cuda")
416
+
417
+
418
+ if __name__ == "__main__":
419
+ cli()
matcha/data/__init__.py ADDED
File without changes
matcha/data/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (176 Bytes). View file
 
matcha/data/__pycache__/text_mel_datamodule.cpython-311.pyc ADDED
Binary file (13.7 kB). View file
 
matcha/data/components/__init__.py ADDED
File without changes
matcha/data/text_mel_datamodule.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from pathlib import Path
3
+ from typing import Any, Dict, Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torchaudio as ta
8
+ from lightning import LightningDataModule
9
+ from torch.utils.data.dataloader import DataLoader
10
+
11
+ from matcha.text import text_to_sequence
12
+ from matcha.utils.audio import mel_spectrogram
13
+ from matcha.utils.model import fix_len_compatibility, normalize
14
+ from matcha.utils.utils import intersperse
15
+
16
+
17
def parse_filelist(filelist_path, split_char="|"):
    """Read *filelist_path* and return one ``split_char``-separated record per line."""
    with open(filelist_path, encoding="utf-8") as f:
        records = [line.strip().split(split_char) for line in f]
    return records
21
+
22
+
23
class TextMelDataModule(LightningDataModule):
    """LightningDataModule serving text + mel-spectrogram batches.

    All constructor arguments are captured via ``save_hyperparameters`` and
    read back through ``self.hparams``. The train/valid dataset construction
    was previously copy-pasted; it is factored into ``_make_dataset`` /
    ``_make_dataloader`` with identical behavior.
    """

    def __init__(  # pylint: disable=unused-argument
        self,
        name,
        train_filelist_path,
        valid_filelist_path,
        batch_size,
        num_workers,
        pin_memory,
        cleaners,
        add_blank,
        n_spks,
        n_fft,
        n_feats,
        sample_rate,
        hop_length,
        win_length,
        f_min,
        f_max,
        data_statistics,
        seed,
        load_durations,
    ):
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

    def _make_dataset(self, filelist_path):
        """Build a TextMelDataset for *filelist_path* from the stored hyperparameters."""
        return TextMelDataset(
            filelist_path,
            self.hparams.n_spks,
            self.hparams.cleaners,
            self.hparams.add_blank,
            self.hparams.n_fft,
            self.hparams.n_feats,
            self.hparams.sample_rate,
            self.hparams.hop_length,
            self.hparams.win_length,
            self.hparams.f_min,
            self.hparams.f_max,
            self.hparams.data_statistics,
            self.hparams.seed,
            self.hparams.load_durations,
        )

    def _make_dataloader(self, dataset, shuffle):
        """Shared DataLoader construction for both splits."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=shuffle,
            collate_fn=TextMelBatchCollate(self.hparams.n_spks),
        )

    def setup(self, stage: Optional[str] = None):  # pylint: disable=unused-argument
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
        careful not to execute things like random split twice!
        """
        self.trainset = self._make_dataset(self.hparams.train_filelist_path)  # pylint: disable=attribute-defined-outside-init
        self.validset = self._make_dataset(self.hparams.valid_filelist_path)  # pylint: disable=attribute-defined-outside-init

    def train_dataloader(self):
        """DataLoader over the training split (shuffled every epoch)."""
        return self._make_dataloader(self.trainset, shuffle=True)

    def val_dataloader(self):
        """DataLoader over the validation split (fixed order)."""
        return self._make_dataloader(self.validset, shuffle=False)

    def teardown(self, stage: Optional[str] = None):
        """Clean up after fit or test."""
        pass  # pylint: disable=unnecessary-pass

    def state_dict(self):
        """Extra things to save to checkpoint."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Things to do when loading checkpoint."""
        pass  # pylint: disable=unnecessary-pass
124
+
125
+
126
class TextMelDataset(torch.utils.data.Dataset):
    """Dataset yielding phoneme-id tensors and normalized mel spectrograms.

    Fixes over the original: ``get_text`` now honors its ``add_blank``
    parameter (it previously read ``self.add_blank`` and silently ignored the
    argument — behavior is unchanged for existing callers, which pass
    ``add_blank=self.add_blank``), and ``dur_loc`` is computed before the
    ``try`` so the error message can never reference an unbound name.
    """

    def __init__(
        self,
        filelist_path,
        n_spks,
        cleaners,
        add_blank=True,
        n_fft=1024,
        n_mels=80,
        sample_rate=22050,
        hop_length=256,
        win_length=1024,
        f_min=0.0,
        f_max=8000,
        data_parameters=None,
        seed=None,
        load_durations=False,
    ):
        self.filepaths_and_text = parse_filelist(filelist_path)
        self.n_spks = n_spks
        self.cleaners = cleaners
        self.add_blank = add_blank
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.win_length = win_length
        self.f_min = f_min
        self.f_max = f_max
        self.load_durations = load_durations

        # Mel normalization statistics; identity transform when not provided.
        if data_parameters is not None:
            self.data_parameters = data_parameters
        else:
            self.data_parameters = {"mel_mean": 0, "mel_std": 1}
        random.seed(seed)
        random.shuffle(self.filepaths_and_text)

    def get_datapoint(self, filepath_and_text):
        """Assemble one training example from a filelist row."""
        if self.n_spks > 1:
            # Multi-speaker rows are "<path>|<speaker-id>|<text>".
            filepath, spk, text = (
                filepath_and_text[0],
                int(filepath_and_text[1]),
                filepath_and_text[2],
            )
        else:
            filepath, text = filepath_and_text[0], filepath_and_text[1]
            spk = None

        text, cleaned_text = self.get_text(text, add_blank=self.add_blank)
        mel = self.get_mel(filepath)

        durations = self.get_durations(filepath, text) if self.load_durations else None

        return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text, "durations": durations}

    def get_durations(self, filepath, text):
        """Load precomputed per-phoneme durations from <data_dir>/durations/<stem>.npy."""
        filepath = Path(filepath)
        data_dir, name = filepath.parent.parent, filepath.stem

        dur_loc = data_dir / "durations" / f"{name}.npy"
        try:
            durs = torch.from_numpy(np.load(dur_loc).astype(int))
        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"Tried loading the durations but durations didn't exist at {dur_loc}, make sure you've generate the durations first using: python matcha/utils/get_durations_from_trained_model.py \n"
            ) from e

        assert len(durs) == len(text), f"Length of durations {len(durs)} and text {len(text)} do not match"

        return durs

    def get_mel(self, filepath):
        """Load audio and convert it to a normalized mel spectrogram."""
        audio, sr = ta.load(filepath)
        assert sr == self.sample_rate
        mel = mel_spectrogram(
            audio,
            self.n_fft,
            self.n_mels,
            self.sample_rate,
            self.hop_length,
            self.win_length,
            self.f_min,
            self.f_max,
            center=False,
        ).squeeze()
        mel = normalize(mel, self.data_parameters["mel_mean"], self.data_parameters["mel_std"])
        return mel

    def get_text(self, text, add_blank=True):
        """Convert raw text to an id tensor; optionally intersperse blank (0) tokens."""
        text_norm, cleaned_text = text_to_sequence(text, self.cleaners)
        if add_blank:
            text_norm = intersperse(text_norm, 0)
        text_norm = torch.IntTensor(text_norm)
        return text_norm, cleaned_text

    def __getitem__(self, index):
        return self.get_datapoint(self.filepaths_and_text[index])

    def __len__(self):
        return len(self.filepaths_and_text)
229
+
230
+
231
class TextMelBatchCollate:
    """Pads a list of datapoints into fixed-size batch tensors."""

    def __init__(self, n_spks):
        self.n_spks = n_spks

    def __call__(self, batch):
        batch_size = len(batch)
        n_feats = batch[0]["y"].shape[-2]
        # Mel length rounded up for the model's downsampling factor.
        y_max_length = fix_len_compatibility(max(item["y"].shape[-1] for item in batch))
        x_max_length = max(item["x"].shape[-1] for item in batch)

        y = torch.zeros((batch_size, n_feats, y_max_length), dtype=torch.float32)
        x = torch.zeros((batch_size, x_max_length), dtype=torch.long)
        durations = torch.zeros((batch_size, x_max_length), dtype=torch.long)

        y_lengths, x_lengths = [], []
        spks, filepaths, x_texts = [], [], []
        for i, item in enumerate(batch):
            y_item, x_item = item["y"], item["x"]
            y_lengths.append(y_item.shape[-1])
            x_lengths.append(x_item.shape[-1])
            y[i, :, : y_item.shape[-1]] = y_item
            x[i, : x_item.shape[-1]] = x_item
            spks.append(item["spk"])
            filepaths.append(item["filepath"])
            x_texts.append(item["x_text"])
            if item["durations"] is not None:
                durations[i, : item["durations"].shape[-1]] = item["durations"]

        return {
            "x": x,
            "x_lengths": torch.tensor(x_lengths, dtype=torch.long),
            "y": y,
            "y_lengths": torch.tensor(y_lengths, dtype=torch.long),
            "spks": torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None,
            "filepaths": filepaths,
            "x_texts": x_texts,
            # Durations tensor is all zeros when no item carried durations.
            "durations": durations if not torch.eq(durations, 0).all() else None,
        }
matcha/hifigan/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2020 Jungil Kong
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
matcha/hifigan/README.md ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis
2
+
3
+ ### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae
4
+
5
+ In our [paper](https://arxiv.org/abs/2010.05646),
6
+ we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.<br/>
7
+ We provide our implementation and pretrained models as open source in this repository.
8
+
9
+ **Abstract :**
10
+ Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms.
11
+ Although such methods improve the sampling efficiency and memory usage,
12
+ their sample quality has not yet reached that of autoregressive and flow-based generative models.
13
+ In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis.
14
+ As speech audio consists of sinusoidal signals with various periods,
15
+ we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality.
16
+ A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method
17
+ demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than
18
+ real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen
19
+ speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times
20
+ faster than real-time on CPU with comparable quality to an autoregressive counterpart.
21
+
22
+ Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples.
23
+
24
+ ## Pre-requisites
25
+
26
+ 1. Python >= 3.6
27
+ 2. Clone this repository.
28
+ 3. Install python requirements. Please refer [requirements.txt](requirements.txt)
29
+ 4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/).
30
+ And move all wav files to `LJSpeech-1.1/wavs`
31
+
32
+ ## Training
33
+
34
+ ```
35
+ python train.py --config config_v1.json
36
+ ```
37
+
38
+ To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.<br>
39
+ Checkpoints and copy of the configuration file are saved in `cp_hifigan` directory by default.<br>
40
+ You can change the path by adding `--checkpoint_path` option.
41
+
42
+ Validation loss during training with V1 generator.<br>
43
+ ![validation loss](./validation_loss.png)
44
+
45
+ ## Pretrained Model
46
+
47
+ You can also use pretrained models we provide.<br/>
48
+ [Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)<br/>
49
+ Details of each folder are as in follows:
50
+
51
+ | Folder Name | Generator | Dataset | Fine-Tuned |
52
+ | ------------ | --------- | --------- | ------------------------------------------------------ |
53
+ | LJ_V1 | V1 | LJSpeech | No |
54
+ | LJ_V2 | V2 | LJSpeech | No |
55
+ | LJ_V3 | V3 | LJSpeech | No |
56
+ | LJ_FT_T2_V1 | V1 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
57
+ | LJ_FT_T2_V2 | V2 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
58
+ | LJ_FT_T2_V3 | V3 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
59
+ | VCTK_V1 | V1 | VCTK | No |
60
+ | VCTK_V2 | V2 | VCTK | No |
61
+ | VCTK_V3 | V3 | VCTK | No |
62
+ | UNIVERSAL_V1 | V1 | Universal | No |
63
+
64
+ We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets.
65
+
66
+ ## Fine-Tuning
67
+
68
+ 1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.<br/>
69
+ The file name of the generated mel-spectrogram should match the audio file and the extension should be `.npy`.<br/>
70
+ Example:
71
+ ` Audio File : LJ001-0001.wav
72
+ Mel-Spectrogram File : LJ001-0001.npy`
73
+ 2. Create `ft_dataset` folder and copy the generated mel-spectrogram files into it.<br/>
74
+ 3. Run the following command.
75
+ ```
76
+ python train.py --fine_tuning True --config config_v1.json
77
+ ```
78
+ For other command line options, please refer to the training section.
79
+
80
+ ## Inference from wav file
81
+
82
+ 1. Make `test_files` directory and copy wav files into the directory.
83
+ 2. Run the following command.
84
+ ` python inference.py --checkpoint_file [generator checkpoint file path]`
85
+ Generated wav files are saved in `generated_files` by default.<br>
86
+ You can change the path by adding `--output_dir` option.
87
+
88
+ ## Inference for end-to-end speech synthesis
89
+
90
+ 1. Make `test_mel_files` directory and copy generated mel-spectrogram files into the directory.<br>
91
+ You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2),
92
+ [Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth.
93
+ 2. Run the following command.
94
+ ` python inference_e2e.py --checkpoint_file [generator checkpoint file path]`
95
+ Generated wav files are saved in `generated_files_from_mel` by default.<br>
96
+ You can change the path by adding `--output_dir` option.
97
+
98
+ ## Acknowledgements
99
+
100
+ We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips)
101
+ and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this.
matcha/hifigan/__init__.py ADDED
File without changes
matcha/hifigan/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (179 Bytes). View file
 
matcha/hifigan/__pycache__/config.cpython-311.pyc ADDED
Binary file (1.23 kB). View file
 
matcha/hifigan/__pycache__/env.cpython-311.pyc ADDED
Binary file (1.38 kB). View file
 
matcha/hifigan/__pycache__/models.cpython-311.pyc ADDED
Binary file (19.1 kB). View file
 
matcha/hifigan/__pycache__/xutils.cpython-311.pyc ADDED
Binary file (3.45 kB). View file
 
matcha/hifigan/config.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Hyperparameters for the HiFi-GAN "V1" generator/training configuration.
v1 = {
    # Generator architecture
    "resblock": "1",
    "num_gpus": 0,
    "batch_size": 16,
    # Optimizer / schedule
    "learning_rate": 0.0004,
    "adam_b1": 0.8,
    "adam_b2": 0.99,
    "lr_decay": 0.999,
    "seed": 1234,
    # Upsampling stack: 8*8*2*2 = 256x, matching hop_size below.
    "upsample_rates": [8, 8, 2, 2],
    "upsample_kernel_sizes": [16, 16, 4, 4],
    "upsample_initial_channel": 512,
    "resblock_kernel_sizes": [3, 7, 11],
    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    "resblock_initial_channel": 256,
    # Audio / feature extraction
    "segment_size": 8192,
    "num_mels": 80,
    "num_freq": 1025,
    "n_fft": 1024,
    "hop_size": 256,
    "win_size": 1024,
    "sampling_rate": 22050,
    "fmin": 0,
    "fmax": 8000,
    "fmax_loss": None,
    # Runtime
    "num_workers": 4,
    "dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1},
}
matcha/hifigan/denoiser.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py
2
+
3
+ """Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio."""
4
+ import torch
5
+
6
+
7
class ModeException(Exception):
    """Raised when an unknown initialisation mode is passed to Denoiser."""
9
+
10
+
11
class Denoiser(torch.nn.Module):
    """Removes model bias from audio produced with waveglow.

    Runs the vocoder once on an all-zero (or random) mel to estimate the
    vocoder's bias spectrum, then subtracts a scaled copy of that spectrum
    from the magnitude of any audio passed through ``forward``.
    """

    def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"):
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = int(filter_length / n_overlap)
        self.win_length = win_length

        first_param = next(vocoder.parameters())
        dtype, device = first_param.dtype, first_param.device
        self.device = device
        if mode == "zeros":
            mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device)
        elif mode == "normal":
            mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device)
        else:
            raise ModeException(f"Mode {mode} if not supported")

        window = torch.hann_window(self.win_length, device=device)

        def _stft(audio):
            # Magnitude/phase STFT matching the original lambda-based helpers.
            spec = torch.view_as_real(
                torch.stft(
                    audio,
                    n_fft=self.filter_length,
                    hop_length=self.hop_length,
                    win_length=self.win_length,
                    window=window,
                    return_complex=True,
                )
            )
            magnitude = torch.sqrt(spec.pow(2).sum(-1))
            phase = torch.atan2(spec[..., -1], spec[..., 0])
            return magnitude, phase

        def _istft(magnitude, phase):
            return torch.istft(
                torch.complex(magnitude * torch.cos(phase), magnitude * torch.sin(phase)),
                n_fft=self.filter_length,
                hop_length=self.hop_length,
                win_length=self.win_length,
                window=window,
            )

        self.stft = _stft
        self.istft = _istft

        with torch.no_grad():
            bias_audio = vocoder(mel_input).float().squeeze(0)
            bias_spec, _ = self.stft(bias_audio)

        # Keep only the first frame of the bias magnitude spectrum.
        self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None])

    @torch.inference_mode()
    def forward(self, audio, strength=0.0005):
        """Subtract *strength* x bias magnitude from *audio*'s spectrum and resynthesize."""
        audio_spec, audio_angles = self.stft(audio)
        denoised_spec = torch.clamp(audio_spec - self.bias_spec.to(audio.device) * strength, 0.0)
        return self.istft(denoised_spec, audio_angles)
matcha/hifigan/env.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ from https://github.com/jik876/hifi-gan """
2
+
3
+ import os
4
+ import shutil
5
+
6
+
7
class AttrDict(dict):
    """Dictionary whose keys are also accessible as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Share storage between the dict and the attribute namespace.
        self.__dict__ = self
11
+
12
+
13
def build_env(config, config_name, path):
    """Copy *config* into ``path/config_name`` unless it already is that file."""
    target = os.path.join(path, config_name)
    if config != target:
        os.makedirs(path, exist_ok=True)
        shutil.copyfile(config, target)
matcha/hifigan/meldataset.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ from https://github.com/jik876/hifi-gan """
2
+
3
+ import math
4
+ import os
5
+ import random
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.utils.data
10
+ from librosa.filters import mel as librosa_mel_fn
11
+ from librosa.util import normalize
12
+ from scipy.io.wavfile import read
13
+
14
+ MAX_WAV_VALUE = 32768.0
15
+
16
+
17
def load_wav(full_path):
    """Read a wav file; returns (samples, sampling_rate)."""
    sampling_rate, data = read(full_path)
    return data, sampling_rate
20
+
21
+
22
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Log-compress *x*: log(clip(x, clip_val, inf) * C)."""
    clipped = np.clip(x, a_min=clip_val, a_max=None)
    return np.log(clipped * C)
24
+
25
+
26
def dynamic_range_decompression(x, C=1):
    """Invert ``dynamic_range_compression`` (up to clipping): exp(x) / C."""
    return np.exp(x) / C
28
+
29
+
30
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress tensor *x*: log(clamp(x, min=clip_val) * C)."""
    clamped = torch.clamp(x, min=clip_val)
    return torch.log(clamped * C)
32
+
33
+
34
def dynamic_range_decompression_torch(x, C=1):
    """Invert ``dynamic_range_compression_torch``: exp(x) / C."""
    return torch.exp(x) / C
36
+
37
+
38
def spectral_normalize_torch(magnitudes):
    """Dynamic-range-compress (log-scale) spectrogram magnitudes."""
    return dynamic_range_compression_torch(magnitudes)
41
+
42
+
43
def spectral_de_normalize_torch(magnitudes):
    """Invert ``spectral_normalize_torch`` (exponentiate the log magnitudes)."""
    return dynamic_range_decompression_torch(magnitudes)
46
+
47
+
48
+ mel_basis = {}
49
+ hann_window = {}
50
+
51
+
52
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Compute the log-mel spectrogram of waveform batch *y*.

    Args:
        y: float waveform tensor, expected in [-1, 1] (warned otherwise),
           shape (batch, samples).
        n_fft / hop_size / win_size: STFT parameters.
        num_mels, sampling_rate, fmin, fmax: mel filterbank parameters.
        center: passed through to ``torch.stft``.

    Returns:
        Log-compressed mel spectrogram, shape (batch, num_mels, frames).
    """
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window  # pylint: disable=global-statement,global-variable-not-assigned
    # Bug fix: the cache was probed with the bare `fmax` while entries are
    # stored under "<fmax>_<device>", so the filterbank (and window) were
    # rebuilt on every single call. Probe with the actual key instead.
    key = str(fmax) + "_" + str(y.device)
    if key not in mel_basis:
        # Keyword arguments keep this compatible with librosa >= 0.10, where
        # `librosa.filters.mel` became keyword-only.
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[key] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    # Manual reflection padding so frame count matches the center=False STFT.
    y = torch.nn.functional.pad(
        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
    )
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[str(y.device)],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    # Magnitude with a small epsilon for numerical stability.
    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[key], spec)
    spec = spectral_normalize_torch(spec)

    return spec
90
+
91
+
92
def get_dataset_filelist(a):
    """Build absolute wav paths for the training and validation filelists."""

    def _read(list_path):
        with open(list_path, encoding="utf-8") as fi:
            return [
                os.path.join(a.input_wavs_dir, line.split("|")[0] + ".wav")
                for line in fi.read().split("\n")
                if len(line) > 0
            ]

    return _read(a.input_training_file), _read(a.input_validation_file)
103
+
104
+
105
class MelDataset(torch.utils.data.Dataset):
    """Yields (mel, audio, filename, mel_for_loss) tuples for HiFi-GAN training."""

    def __init__(
        self,
        training_files,
        segment_size,
        n_fft,
        num_mels,
        hop_size,
        win_size,
        sampling_rate,
        fmin,
        fmax,
        split=True,
        shuffle=True,
        n_cache_reuse=1,
        device=None,
        fmax_loss=None,
        fine_tuning=False,
        base_mels_path=None,
    ):
        self.audio_files = training_files
        random.seed(1234)
        if shuffle:
            random.shuffle(self.audio_files)
        self.segment_size = segment_size
        self.sampling_rate = sampling_rate
        self.split = split
        self.n_fft = n_fft
        self.num_mels = num_mels
        self.hop_size = hop_size
        self.win_size = win_size
        self.fmin = fmin
        self.fmax = fmax
        self.fmax_loss = fmax_loss
        # One-slot cache: the same decoded wav can be reused n_cache_reuse times.
        self.cached_wav = None
        self.n_cache_reuse = n_cache_reuse
        self._cache_ref_count = 0
        self.device = device
        self.fine_tuning = fine_tuning
        self.base_mels_path = base_mels_path

    def _load_audio(self, filename):
        # Serve from the one-slot cache when possible, otherwise hit the disk.
        if self._cache_ref_count == 0:
            audio, sampling_rate = load_wav(filename)
            audio = audio / MAX_WAV_VALUE
            if not self.fine_tuning:
                audio = normalize(audio) * 0.95
            self.cached_wav = audio
            if sampling_rate != self.sampling_rate:
                raise ValueError(f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR")
            self._cache_ref_count = self.n_cache_reuse
        else:
            audio = self.cached_wav
            self._cache_ref_count -= 1
        return audio

    def __getitem__(self, index):
        filename = self.audio_files[index]
        audio = torch.FloatTensor(self._load_audio(filename)).unsqueeze(0)

        if not self.fine_tuning:
            # Standard training: crop/pad audio, then compute the mel on the fly.
            if self.split:
                if audio.size(1) >= self.segment_size:
                    max_audio_start = audio.size(1) - self.segment_size
                    audio_start = random.randint(0, max_audio_start)
                    audio = audio[:, audio_start : audio_start + self.segment_size]
                else:
                    audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")

            mel = mel_spectrogram(
                audio,
                self.n_fft,
                self.num_mels,
                self.sampling_rate,
                self.hop_size,
                self.win_size,
                self.fmin,
                self.fmax,
                center=False,
            )
        else:
            # Fine-tuning consumes precomputed mels stored as .npy files.
            mel = np.load(os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + ".npy"))
            mel = torch.from_numpy(mel)

            if len(mel.shape) < 3:
                mel = mel.unsqueeze(0)

            if self.split:
                frames_per_seg = math.ceil(self.segment_size / self.hop_size)

                if audio.size(1) >= self.segment_size:
                    # Pick a mel window and slice the matching audio span.
                    mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
                    mel = mel[:, :, mel_start : mel_start + frames_per_seg]
                    audio = audio[:, mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size]
                else:
                    mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
                    audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")

        # Second mel with fmax_loss, used as the reconstruction-loss target.
        mel_loss = mel_spectrogram(
            audio,
            self.n_fft,
            self.num_mels,
            self.sampling_rate,
            self.hop_size,
            self.win_size,
            self.fmin,
            self.fmax_loss,
            center=False,
        )

        return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())

    def __len__(self):
        return len(self.audio_files)
matcha/hifigan/models.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ from https://github.com/jik876/hifi-gan """
2
+
3
+ import torch
4
+ import torch.nn as nn # pylint: disable=consider-using-from-import
5
+ import torch.nn.functional as F
6
+ from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
7
+ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
8
+
9
+ from .xutils import get_padding, init_weights
10
+
11
+ LRELU_SLOPE = 0.1
12
+
13
+
14
class ResBlock1(torch.nn.Module):
    """Three-layer dilated residual block (the 'large' HiFi-GAN variant).

    Each stage applies a dilated conv (dilations from *dilation*) followed by
    a non-dilated conv, with leaky-ReLU activations and a residual add.
    """

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()
        self.h = h
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[i],
                        padding=get_padding(kernel_size, dilation[i]),
                    )
                )
                for i in range(3)
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                )
                for _ in range(3)
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            residual = x
            x = c1(F.leaky_relu(x, LRELU_SLOPE))
            x = c2(F.leaky_relu(x, LRELU_SLOPE))
            x = x + residual
        return x

    def remove_weight_norm(self):
        for layer in self.convs1:
            remove_weight_norm(layer)
        for layer in self.convs2:
            remove_weight_norm(layer)
+ remove_weight_norm(l)
104
+
105
+
106
class ResBlock2(torch.nn.Module):
    """Two-layer dilated residual block (the 'small' HiFi-GAN variant)."""

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super().__init__()
        self.h = h
        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[i],
                        padding=get_padding(kernel_size, dilation[i]),
                    )
                )
                for i in range(2)
            ]
        )
        self.convs.apply(init_weights)

    def forward(self, x):
        for conv in self.convs:
            residual = x
            x = conv(F.leaky_relu(x, LRELU_SLOPE)) + residual
        return x

    def remove_weight_norm(self):
        for conv in self.convs:
            remove_weight_norm(conv)
+ remove_weight_norm(l)
146
+
147
+
148
class Generator(torch.nn.Module):
    """HiFi-GAN generator: upsamples an 80-bin mel spectrogram to a waveform."""

    def __init__(self, h):
        super().__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
        resblock_cls = ResBlock1 if h.resblock == "1" else ResBlock2

        # Transposed-conv upsampling stack; channels halve at every stage.
        self.ups = nn.ModuleList()
        for i, (rate, kernel) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        h.upsample_initial_channel // (2**i),
                        h.upsample_initial_channel // (2 ** (i + 1)),
                        kernel,
                        rate,
                        padding=(kernel - rate) // 2,
                    )
                )
            )

        # One multi-receptive-field fusion group (num_kernels blocks) per stage.
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes):
                self.resblocks.append(resblock_cls(h, ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            # Average the outputs of this stage's resblock group.
            acc = None
            for j in range(self.num_kernels):
                out = self.resblocks[i * self.num_kernels + j](x)
                acc = out if acc is None else acc + out
            x = acc / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        return torch.tanh(x)

    def remove_weight_norm(self):
        print("Removing weight norm...")
        for layer in self.ups:
            remove_weight_norm(layer)
        for block in self.resblocks:
            block.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
+
208
+
209
class DiscriminatorP(torch.nn.Module):
    """HiFi-GAN period discriminator: folds the 1-D waveform into a 2-D grid of width `period`
    and applies strided 2-D convolutions along the time axis."""

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super().__init__()
        self.period = period
        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        # NOTE(review): padding uses get_padding(5, 1) regardless of `kernel_size`
        # (matches upstream HiFi-GAN); only "same"-like for the default kernel_size=5.
        self.convs = nn.ModuleList(
            [
                norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
                norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
                norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
                norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
                norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
            ]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """x: (batch, 1, samples). Returns (flattened score map, list of intermediate feature maps)."""
        fmap = []

        # 1d to 2d: reflect-pad so the length is a multiple of `period`, then fold.
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap
245
+
246
+
247
class MultiPeriodDiscriminator(torch.nn.Module):
    """Ensemble of period discriminators over the co-prime periods used by HiFi-GAN."""

    def __init__(self):
        super().__init__()
        # One DiscriminatorP per period; primes avoid overlapping periodic structure.
        self.discriminators = nn.ModuleList([DiscriminatorP(p) for p in (2, 3, 5, 7, 11)])

    def forward(self, y, y_hat):
        """Score real (y) and generated (y_hat) waveforms with every sub-discriminator.

        Returns:
            (real_scores, fake_scores, real_feature_maps, fake_feature_maps),
            each a list with one entry per period discriminator.
        """
        real_scores, fake_scores = [], []
        real_fmaps, fake_fmaps = [], []
        for disc in self.discriminators:
            score_real, fmap_real = disc(y)
            score_fake, fmap_fake = disc(y_hat)
            real_scores.append(score_real)
            real_fmaps.append(fmap_real)
            fake_scores.append(score_fake)
            fake_fmaps.append(fmap_fake)

        return real_scores, fake_scores, real_fmaps, fake_fmaps
274
+
275
+
276
class DiscriminatorS(torch.nn.Module):
    """HiFi-GAN scale discriminator: strided, grouped 1-D convolutions over raw audio."""

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        # Spectral norm is used for the full-resolution copy inside MultiScaleDiscriminator.
        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 128, 15, 1, padding=7)),
                norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
                norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
                norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """x: (batch, 1, samples). Returns (flattened score, list of intermediate feature maps)."""
        fmap = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap
304
+
305
+
306
class MultiScaleDiscriminator(torch.nn.Module):
    """Three scale discriminators operating on x1, x2-pooled and x4-pooled audio."""

    def __init__(self):
        super().__init__()
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorS(use_spectral_norm=True),
                DiscriminatorS(),
                DiscriminatorS(),
            ]
        )
        # Each average pool halves the temporal resolution before the next scale.
        self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)])

    def forward(self, y, y_hat):
        """Score real (y) and generated (y_hat) audio at every scale.

        Returns (real_scores, fake_scores, real_feature_maps, fake_feature_maps).
        """
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                # Progressively downsample both signals for the coarser scales.
                y = self.meanpools[i - 1](y)
                y_hat = self.meanpools[i - 1](y_hat)
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
335
+
336
+
337
def feature_loss(fmap_r, fmap_g):
    """L1 feature-matching loss between real and generated discriminator feature maps.

    Args:
        fmap_r: list (per discriminator) of lists of real-audio feature tensors.
        fmap_g: matching structure for generated audio.

    Returns:
        Scalar tensor: 2 * sum of mean absolute differences over all layers.
    """
    total = 0
    for disc_real, disc_gen in zip(fmap_r, fmap_g):
        for real_feat, gen_feat in zip(disc_real, disc_gen):
            total = total + (real_feat - gen_feat).abs().mean()

    return total * 2
344
+
345
+
346
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """LSGAN discriminator loss: real outputs pushed toward 1, generated toward 0.

    Returns:
        (total_loss, per_discriminator_real_losses, per_discriminator_gen_losses);
        the two lists contain plain floats for logging.
    """
    total = 0
    real_losses = []
    gen_losses = []
    for real_out, gen_out in zip(disc_real_outputs, disc_generated_outputs):
        real_term = torch.mean((1 - real_out) ** 2)
        gen_term = torch.mean(gen_out**2)
        total = total + real_term + gen_term
        real_losses.append(real_term.item())
        gen_losses.append(gen_term.item())

    return total, real_losses, gen_losses
358
+
359
+
360
def generator_loss(disc_outputs):
    """LSGAN generator adversarial loss: push discriminator outputs on generated audio toward 1.

    Returns:
        (total_loss, per_discriminator_losses) — list entries are kept as tensors
        (not floats), matching how callers log them.
    """
    per_disc = [torch.mean((1 - scores) ** 2) for scores in disc_outputs]
    return sum(per_disc), per_disc
matcha/hifigan/xutils.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ from https://github.com/jik876/hifi-gan """
2
+
3
+ import glob
4
+ import os
5
+
6
+ import matplotlib
7
+ import torch
8
+ from torch.nn.utils import weight_norm
9
+
10
+ matplotlib.use("Agg")
11
+ import matplotlib.pylab as plt
12
+
13
+
14
def plot_spectrogram(spectrogram):
    """Render a spectrogram as a matplotlib figure (Agg backend, never shown on screen).

    Args:
        spectrogram: 2-D array-like (frequency bins x frames).

    Returns:
        The matplotlib Figure; it is drawn and closed (for pyplot bookkeeping)
        before being returned, ready to be logged e.g. to TensorBoard.
    """
    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)

    fig.canvas.draw()
    plt.close()

    return fig
23
+
24
+
25
def init_weights(m, mean=0.0, std=0.01):
    """Initialize conv-layer weights in place from N(mean, std); other module types are untouched."""
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
29
+
30
+
31
def apply_weight_norm(m):
    """Wrap conv-layer weights with weight normalization in place; other module types are untouched."""
    if "Conv" in m.__class__.__name__:
        weight_norm(m)
35
+
36
+
37
def get_padding(kernel_size, dilation=1):
    """Padding that keeps a stride-1 dilated conv output the same length ("same" padding)."""
    return (kernel_size - 1) * dilation // 2
39
+
40
+
41
def load_checkpoint(filepath, device):
    """Load a torch checkpoint dict from `filepath` onto `device`.

    Args:
        filepath: path to a file produced by `torch.save`.
        device: map_location passed to `torch.load` (e.g. "cpu", "cuda:0").

    Returns:
        The deserialized checkpoint object (typically a dict).

    Raises:
        FileNotFoundError: if `filepath` is not an existing file. (This used
        to be an `assert`, which silently disappears under `python -O`.)
    """
    if not os.path.isfile(filepath):
        raise FileNotFoundError(f"Checkpoint not found: {filepath}")
    print(f"Loading '{filepath}'")
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict
47
+
48
+
49
def save_checkpoint(filepath, obj):
    """Serialize `obj` to `filepath` via torch.save, logging progress to stdout."""
    print(f"Saving checkpoint to {filepath}")
    torch.save(obj, filepath)
    print("Complete.")
53
+
54
+
55
def scan_checkpoint(cp_dir, prefix):
    """Return the lexicographically newest checkpoint path matching `prefix` + 8 chars, or None.

    Checkpoints are assumed to be named like `g_00000123`, so lexicographic
    order matches training-step order.
    """
    matches = sorted(glob.glob(os.path.join(cp_dir, prefix + "????????")))
    return matches[-1] if matches else None
matcha/models/__init__.py ADDED
File without changes
matcha/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (178 Bytes). View file
 
matcha/models/__pycache__/baselightningmodule.cpython-311.pyc ADDED
Binary file (9.75 kB). View file
 
matcha/models/__pycache__/matcha_tts.cpython-311.pyc ADDED
Binary file (13.8 kB). View file
 
matcha/models/baselightningmodule.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This is a base lightning module that can be used to train a model.
3
+ The benefit of this abstraction is that all the logic outside of model definition can be reused for different models.
4
+ """
5
+ import inspect
6
+ from abc import ABC
7
+ from typing import Any, Dict
8
+
9
+ import torch
10
+ from lightning import LightningModule
11
+ from lightning.pytorch.utilities import grad_norm
12
+
13
+ from matcha import utils
14
+ from matcha.utils.utils import plot_tensor
15
+
16
+ log = utils.get_pylogger(__name__)
17
+
18
+
19
class BaseLightningClass(LightningModule, ABC):
    """Shared Lightning scaffolding for Matcha-style TTS models.

    Subclasses provide the model itself: a forward returning
    ``(dur_loss, prior_loss, diff_loss, ...)`` plus a ``synthesise`` method.
    This base class supplies optimizer/scheduler wiring, loss logging and
    TensorBoard image logging during validation.
    """

    def update_data_statistics(self, data_statistics):
        """Register mel normalization stats as buffers so they travel with checkpoints."""
        if data_statistics is None:
            data_statistics = {
                "mel_mean": 0.0,
                "mel_std": 1.0,
            }

        self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"]))
        self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"]))

    def configure_optimizers(self) -> Any:
        """Instantiate the optimizer (and optional LR scheduler) from hparams."""
        optimizer = self.hparams.optimizer(params=self.parameters())
        if self.hparams.scheduler not in (None, {}):
            scheduler_args = {}
            # Resume-aware epoch for schedulers that accept `last_epoch`
            # (e.g. exponential schedulers). Computed unconditionally so it is
            # always defined below.
            if hasattr(self, "ckpt_loaded_epoch"):
                current_epoch = self.ckpt_loaded_epoch - 1
            else:
                current_epoch = -1

            scheduler_args.update({"optimizer": optimizer})
            scheduler = self.hparams.scheduler.scheduler(**scheduler_args)
            # Bugfix: `current_epoch` used to be assigned only inside this
            # signature check while `scheduler.last_epoch = current_epoch` ran
            # unconditionally, raising NameError for schedulers without a
            # `last_epoch` parameter. The assignment is now guarded instead.
            if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters:
                scheduler.last_epoch = current_epoch
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "interval": self.hparams.scheduler.lightning_args.interval,
                    "frequency": self.hparams.scheduler.lightning_args.frequency,
                    "name": "learning_rate",
                },
            }

        return {"optimizer": optimizer}

    def get_losses(self, batch):
        """Run the model on a batch and return its three loss components as a dict."""
        x, x_lengths = batch["x"], batch["x_lengths"]
        y, y_lengths = batch["y"], batch["y_lengths"]
        spks = batch["spks"]

        dur_loss, prior_loss, diff_loss, *_ = self(
            x=x,
            x_lengths=x_lengths,
            y=y,
            y_lengths=y_lengths,
            spks=spks,
            out_size=self.out_size,
            durations=batch["durations"],
        )
        return {
            "dur_loss": dur_loss,
            "prior_loss": prior_loss,
            "diff_loss": diff_loss,
        }

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Remember the checkpoint epoch so configure_optimizers can resume schedulers.
        self.ckpt_loaded_epoch = checkpoint["epoch"]  # pylint: disable=attribute-defined-outside-init

    def training_step(self, batch: Any, batch_idx: int):
        """Log per-component and total training losses; return the total loss for backprop."""
        loss_dict = self.get_losses(batch)
        self.log(
            "step",
            float(self.global_step),
            on_step=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        self.log(
            "sub_loss/train_dur_loss",
            loss_dict["dur_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "sub_loss/train_prior_loss",
            loss_dict["prior_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "sub_loss/train_diff_loss",
            loss_dict["diff_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )

        total_loss = sum(loss_dict.values())
        self.log(
            "loss/train",
            total_loss,
            on_step=True,
            on_epoch=True,
            logger=True,
            prog_bar=True,
            sync_dist=True,
        )

        return {"loss": total_loss, "log": loss_dict}

    def validation_step(self, batch: Any, batch_idx: int):
        """Log per-component and total validation losses; return the total loss."""
        loss_dict = self.get_losses(batch)
        self.log(
            "sub_loss/val_dur_loss",
            loss_dict["dur_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "sub_loss/val_prior_loss",
            loss_dict["prior_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "sub_loss/val_diff_loss",
            loss_dict["diff_loss"],
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
        )

        total_loss = sum(loss_dict.values())
        self.log(
            "loss/val",
            total_loss,
            on_step=True,
            on_epoch=True,
            logger=True,
            prog_bar=True,
            sync_dist=True,
        )

        return total_loss

    def on_validation_end(self) -> None:
        """On rank zero, log original / synthesised spectrograms and alignments as images."""
        if self.trainer.is_global_zero:
            one_batch = next(iter(self.trainer.val_dataloaders))
            if self.current_epoch == 0:
                log.debug("Plotting original samples")
                for i in range(2):
                    y = one_batch["y"][i].unsqueeze(0).to(self.device)
                    self.logger.experiment.add_image(
                        f"original/{i}",
                        plot_tensor(y.squeeze().cpu()),
                        self.current_epoch,
                        dataformats="HWC",
                    )

            log.debug("Synthesising...")
            for i in range(2):
                x = one_batch["x"][i].unsqueeze(0).to(self.device)
                x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device)
                spks = one_batch["spks"][i].unsqueeze(0).to(self.device) if one_batch["spks"] is not None else None
                output = self.synthesise(x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks)
                y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"]
                attn = output["attn"]
                self.logger.experiment.add_image(
                    f"generated_enc/{i}",
                    plot_tensor(y_enc.squeeze().cpu()),
                    self.current_epoch,
                    dataformats="HWC",
                )
                self.logger.experiment.add_image(
                    f"generated_dec/{i}",
                    plot_tensor(y_dec.squeeze().cpu()),
                    self.current_epoch,
                    dataformats="HWC",
                )
                self.logger.experiment.add_image(
                    f"alignment/{i}",
                    plot_tensor(attn.squeeze().cpu()),
                    self.current_epoch,
                    dataformats="HWC",
                )

    def on_before_optimizer_step(self, optimizer):
        # Log L2 gradient norms per parameter for debugging training stability.
        self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()})
matcha/models/components/__init__.py ADDED
File without changes
matcha/models/components/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (189 Bytes). View file
 
matcha/models/components/__pycache__/decoder.cpython-311.pyc ADDED
Binary file (21.3 kB). View file
 
matcha/models/components/__pycache__/flow_matching.cpython-311.pyc ADDED
Binary file (6.6 kB). View file
 
matcha/models/components/__pycache__/text_encoder.cpython-311.pyc ADDED
Binary file (25 kB). View file
 
matcha/models/components/__pycache__/transformer.cpython-311.pyc ADDED
Binary file (14.8 kB). View file
 
matcha/models/components/decoder.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn # pylint: disable=consider-using-from-import
6
+ import torch.nn.functional as F
7
+ from conformer import ConformerBlock
8
+ from diffusers.models.activations import get_activation
9
+ from einops import pack, rearrange, repeat
10
+
11
+ from matcha.models.components.transformer import BasicTransformerBlock
12
+
13
+
14
class SinusoidalPosEmb(torch.nn.Module):
    """Classic transformer-style sinusoidal embedding for (diffusion) timesteps."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"

    def forward(self, x, scale=1000):
        """Embed scalar positions `x` (shape (batch,)) into `dim`-dim sin/cos features."""
        if x.ndim < 1:
            x = x.unsqueeze(0)
        half = self.dim // 2
        # Geometric frequency ladder from 1 down to 1/10000.
        log_step = math.log(10000) / (half - 1)
        freqs = torch.exp(-log_step * torch.arange(half, device=x.device).float())
        angles = scale * x.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
30
+
31
+
32
class Block1D(torch.nn.Module):
    """Conv1d -> GroupNorm -> Mish, with the sequence mask applied before and after."""

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv1d(dim, dim_out, 3, padding=1),
            torch.nn.GroupNorm(groups, dim_out),
            nn.Mish(),
        )

    def forward(self, x, mask):
        """Mask the input, run the conv stack, and re-mask the result."""
        return self.block(x * mask) * mask
44
+
45
+
46
class ResnetBlock1D(torch.nn.Module):
    """Two masked conv blocks with a timestep-embedding injection and a 1x1 residual branch."""

    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super().__init__()
        # Projects the time embedding into the block's output channel space.
        self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))

        self.block1 = Block1D(dim, dim_out, groups=groups)
        self.block2 = Block1D(dim_out, dim_out, groups=groups)

        # 1x1 conv so the skip connection matches the output channel count.
        self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)

    def forward(self, x, mask, time_emb):
        """x: (B, dim, T); mask: (B, 1, T); time_emb: (B, time_emb_dim) -> (B, dim_out, T)."""
        hidden = self.block1(x, mask)
        hidden = hidden + self.mlp(time_emb).unsqueeze(-1)
        hidden = self.block2(hidden, mask)
        return hidden + self.res_conv(x * mask)
62
+
63
+
64
class Downsample1D(nn.Module):
    """Halves the temporal resolution with a stride-2, kernel-3, padding-1 Conv1d."""

    def __init__(self, dim):
        super().__init__()
        self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)
71
+
72
+
73
class TimestepEmbedding(nn.Module):
    """MLP that projects a sinusoidal timestep embedding into the model width.

    Mirrors diffusers' TimestepEmbedding: linear -> activation -> linear, with
    an optional conditioning projection added to the input and an optional
    post-activation on the output.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim)

        # Optional projection for an extra conditioning signal added to `sample`.
        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        """sample: (batch, in_channels) -> (batch, out_dim or time_embed_dim).

        NOTE(review): passing `condition` when `cond_proj_dim` was None would
        raise (cond_proj is None); callers are expected to keep them in sync.
        """
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample
118
+
119
+
120
class Upsample1D(nn.Module):
    """A 1D upsampling layer (x2) with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution after nearest-neighbor upsampling.
        use_conv_transpose (`bool`, default `True` here):
            option to use a transposed convolution instead of interpolation.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        if use_conv_transpose:
            # kernel 4, stride 2, padding 1 exactly doubles the length.
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
        else:
            self.conv = None

    def forward(self, inputs):
        assert inputs.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(inputs)

        upsampled = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.use_conv else upsampled
159
+
160
+
161
class ConformerWrapper(ConformerBlock):
    """Adapter giving ConformerBlock the same call signature as BasicTransformerBlock.

    The encoder_*/timestep arguments are accepted but ignored — the conformer
    block is unconditional and only consumes hidden states and a boolean mask.
    """

    def __init__(  # pylint: disable=useless-super-delegation
        self,
        *,
        dim,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0,
        ff_dropout=0,
        conv_dropout=0,
        conv_causal=False,
    ):
        super().__init__(
            dim=dim,
            dim_head=dim_head,
            heads=heads,
            ff_mult=ff_mult,
            conv_expansion_factor=conv_expansion_factor,
            conv_kernel_size=conv_kernel_size,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            conv_dropout=conv_dropout,
            conv_causal=conv_causal,
        )

    def forward(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        timestep=None,
    ):
        # Only hidden states and the mask are forwarded; conditioning inputs are dropped.
        return super().forward(x=hidden_states, mask=attention_mask.bool())
198
+
199
+
200
class Decoder(nn.Module):
    """1-D U-Net flow estimator: down / mid / up stages with skip connections,
    each stage a ResnetBlock1D followed by transformer (or conformer) blocks,
    conditioned on a diffusion timestep embedding."""

    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        down_block_type="transformer",
        mid_block_type="transformer",
        up_block_type="transformer",
    ):
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Timestep conditioning: sinusoidal embedding -> MLP into channels[0] * 4 dims.
        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )

        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        down_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            # Last down stage keeps resolution (plain conv) instead of downsampling.
            downsample = (
                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )

            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        for i in range(num_mid_blocks):
            input_channel = channels[-1]
            # NOTE(review): this rebinds the local `out_channels` (harmless — the
            # attribute was stored above) and the resnet below uses
            # `output_channel`, which equals channels[-1] after the down loop.
            out_channels = channels[-1]

            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        mid_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        # Mirror the channel sequence for the decoder half; each up-stage resnet
        # takes 2x channels because the skip connection is concatenated in.
        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i]
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2

            resnet = ResnetBlock1D(
                dim=2 * input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        up_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )

            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))

        self.final_block = Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)

        self.initialize_weights()
        # nn.init.normal_(self.final_proj.weight)

    @staticmethod
    def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn):
        """Build one attention block: a ConformerWrapper or a BasicTransformerBlock."""
        if block_type == "conformer":
            block = ConformerWrapper(
                dim=dim,
                dim_head=attention_head_dim,
                heads=num_heads,
                ff_mult=1,
                conv_expansion_factor=2,
                ff_dropout=dropout,
                attn_dropout=dropout,
                conv_dropout=dropout,
                conv_kernel_size=31,
            )
        elif block_type == "transformer":
            block = BasicTransformerBlock(
                dim=dim,
                num_attention_heads=num_heads,
                attention_head_dim=attention_head_dim,
                dropout=dropout,
                activation_fn=act_fn,
            )
        else:
            raise ValueError(f"Unknown block type {block_type}")

        return block

    def initialize_weights(self):
        """Kaiming-normal init for conv/linear weights, unit/zero init for group norms."""
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, mask, mu, t, spks=None, cond=None):
        """Predict the flow vector field for noisy input `x` at time `t`.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (torch.Tensor): shape (batch_size, 1, time)
            mu (torch.Tensor): encoder output, shape (batch_size, in_channels, time)
            t (torch.Tensor): shape (batch_size)
            spks (torch.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
            cond (optional): placeholder for future use. Defaults to None.

        Returns:
            torch.Tensor: estimated vector field, shape (batch_size, out_channels, time),
            zeroed at masked positions.
        """

        t = self.time_embeddings(t)
        t = self.time_mlp(t)

        # Concatenate the noisy sample with the encoder output along channels.
        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            # Broadcast the speaker embedding across time and concatenate it too.
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c")
            mask_down = rearrange(mask_down, "b 1 t -> b t")
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=mask_down,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t")
            mask_down = rearrange(mask_down, "b t -> b 1 t")
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            # Subsample the mask to track the halved resolution.
            masks.append(mask_down[:, :, ::2])

        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c")
            mask_mid = rearrange(mask_mid, "b 1 t -> b t")
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=mask_mid,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t")
            mask_mid = rearrange(mask_mid, "b t -> b 1 t")

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            # Concatenate the matching down-path hidden state (U-Net skip).
            x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t)
            x = rearrange(x, "b c t -> b t c")
            mask_up = rearrange(mask_up, "b 1 t -> b t")
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=mask_up,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t")
            mask_up = rearrange(mask_up, "b t -> b 1 t")
            x = upsample(x * mask_up)

        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)

        return output * mask
matcha/models/components/flow_matching.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from matcha.models.components.decoder import Decoder
7
+ from matcha.utils.pylogger import get_pylogger
8
+
9
+ log = get_pylogger(__name__)
10
+
11
+
12
class BASECFM(torch.nn.Module, ABC):
    """Base conditional flow matching (OT-CFM) module.

    Subclasses must assign `self.estimator`, the network predicting the flow
    vector field; this class provides Euler-solver sampling and the CFM loss.
    """

    def __init__(
        self,
        n_feats,
        cfm_params,
        n_spks=1,
        spk_emb_dim=128,
    ):
        super().__init__()
        self.n_feats = n_feats
        self.n_spks = n_spks
        self.spk_emb_dim = spk_emb_dim
        self.solver = cfm_params.solver
        if hasattr(cfm_params, "sigma_min"):
            self.sigma_min = cfm_params.sigma_min
        else:
            self.sigma_min = 1e-4

        # Set by subclasses: the vector-field estimator network.
        self.estimator = None

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = torch.randn_like(mu) * temperature
        # Uniform time grid from noise (t=0) to data (t=1).
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)

    def solve_euler(self, x, t_span, mu, mask, spks, cond):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []

        for step in range(1, len(t_span)):
            # One explicit Euler step: x_{k+1} = x_k + dt * v(x_k, t_k).
            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)

            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return sol[-1]

    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = mu.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        # OT-CFM: linear interpolation between noise z and data x1 ...
        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        # ... with constant target vector field u.
        u = x1 - (1 - self.sigma_min) * z

        # Sum-reduced MSE normalized by the number of unmasked frames x features.
        loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
            torch.sum(mask) * u.shape[1]
        )
        return loss, y
119
+
120
+
121
class CFM(BASECFM):
    """Concrete CFM using the U-Net `Decoder` as the vector-field estimator."""

    def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64):
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )

        # Multi-speaker models concatenate the speaker embedding onto the input channels.
        in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
        # Just change the architecture of the estimator here
        self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params)
matcha/models/components/text_encoder.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ from https://github.com/jaywalnut310/glow-tts """
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn # pylint: disable=consider-using-from-import
7
+ from einops import rearrange
8
+
9
+ import matcha.utils as utils # pylint: disable=consider-using-from-import
10
+ from matcha.utils.model import sequence_mask
11
+
12
+ log = utils.get_pylogger(__name__)
13
+
14
+
15
+ class LayerNorm(nn.Module):
16
+ def __init__(self, channels, eps=1e-4):
17
+ super().__init__()
18
+ self.channels = channels
19
+ self.eps = eps
20
+
21
+ self.gamma = torch.nn.Parameter(torch.ones(channels))
22
+ self.beta = torch.nn.Parameter(torch.zeros(channels))
23
+
24
+ def forward(self, x):
25
+ n_dims = len(x.shape)
26
+ mean = torch.mean(x, 1, keepdim=True)
27
+ variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
28
+
29
+ x = (x - mean) * torch.rsqrt(variance + self.eps)
30
+
31
+ shape = [1, -1] + [1] * (n_dims - 2)
32
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
33
+ return x
34
+
35
+
36
+ class ConvReluNorm(nn.Module):
37
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
38
+ super().__init__()
39
+ self.in_channels = in_channels
40
+ self.hidden_channels = hidden_channels
41
+ self.out_channels = out_channels
42
+ self.kernel_size = kernel_size
43
+ self.n_layers = n_layers
44
+ self.p_dropout = p_dropout
45
+
46
+ self.conv_layers = torch.nn.ModuleList()
47
+ self.norm_layers = torch.nn.ModuleList()
48
+ self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
49
+ self.norm_layers.append(LayerNorm(hidden_channels))
50
+ self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
51
+ for _ in range(n_layers - 1):
52
+ self.conv_layers.append(
53
+ torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
54
+ )
55
+ self.norm_layers.append(LayerNorm(hidden_channels))
56
+ self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
57
+ self.proj.weight.data.zero_()
58
+ self.proj.bias.data.zero_()
59
+
60
+ def forward(self, x, x_mask):
61
+ x_org = x
62
+ for i in range(self.n_layers):
63
+ x = self.conv_layers[i](x * x_mask)
64
+ x = self.norm_layers[i](x)
65
+ x = self.relu_drop(x)
66
+ x = x_org + self.proj(x)
67
+ return x * x_mask
68
+
69
+
70
+ class DurationPredictor(nn.Module):
71
+ def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
72
+ super().__init__()
73
+ self.in_channels = in_channels
74
+ self.filter_channels = filter_channels
75
+ self.p_dropout = p_dropout
76
+
77
+ self.drop = torch.nn.Dropout(p_dropout)
78
+ self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
79
+ self.norm_1 = LayerNorm(filter_channels)
80
+ self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
81
+ self.norm_2 = LayerNorm(filter_channels)
82
+ self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
83
+
84
+ def forward(self, x, x_mask):
85
+ x = self.conv_1(x * x_mask)
86
+ x = torch.relu(x)
87
+ x = self.norm_1(x)
88
+ x = self.drop(x)
89
+ x = self.conv_2(x * x_mask)
90
+ x = torch.relu(x)
91
+ x = self.norm_2(x)
92
+ x = self.drop(x)
93
+ x = self.proj(x * x_mask)
94
+ return x * x_mask
95
+
96
+
97
+ class RotaryPositionalEmbeddings(nn.Module):
98
+ """
99
+ ## RoPE module
100
+
101
+ Rotary encoding transforms pairs of features by rotating in the 2D plane.
102
+ That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
103
+ Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
104
+ by an angle depending on the position of the token.
105
+ """
106
+
107
+ def __init__(self, d: int, base: int = 10_000):
108
+ r"""
109
+ * `d` is the number of features $d$
110
+ * `base` is the constant used for calculating $\Theta$
111
+ """
112
+ super().__init__()
113
+
114
+ self.base = base
115
+ self.d = int(d)
116
+ self.cos_cached = None
117
+ self.sin_cached = None
118
+
119
+ def _build_cache(self, x: torch.Tensor):
120
+ r"""
121
+ Cache $\cos$ and $\sin$ values
122
+ """
123
+ # Return if cache is already built
124
+ if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
125
+ return
126
+
127
+ # Get sequence length
128
+ seq_len = x.shape[0]
129
+
130
+ # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
131
+ theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
132
+
133
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
134
+ seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
135
+
136
+ # Calculate the product of position index and $\theta_i$
137
+ idx_theta = torch.einsum("n,d->nd", seq_idx, theta)
138
+
139
+ # Concatenate so that for row $m$ we have
140
+ # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
141
+ idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
142
+
143
+ # Cache them
144
+ self.cos_cached = idx_theta2.cos()[:, None, None, :]
145
+ self.sin_cached = idx_theta2.sin()[:, None, None, :]
146
+
147
+ def _neg_half(self, x: torch.Tensor):
148
+ # $\frac{d}{2}$
149
+ d_2 = self.d // 2
150
+
151
+ # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
152
+ return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
153
+
154
+ def forward(self, x: torch.Tensor):
155
+ """
156
+ * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
157
+ """
158
+ # Cache $\cos$ and $\sin$ values
159
+ x = rearrange(x, "b h t d -> t b h d")
160
+
161
+ self._build_cache(x)
162
+
163
+ # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
164
+ x_rope, x_pass = x[..., : self.d], x[..., self.d :]
165
+
166
+ # Calculate
167
+ # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
168
+ neg_half_x = self._neg_half(x_rope)
169
+
170
+ x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])
171
+
172
+ return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d")
173
+
174
+
175
+ class MultiHeadAttention(nn.Module):
176
+ def __init__(
177
+ self,
178
+ channels,
179
+ out_channels,
180
+ n_heads,
181
+ heads_share=True,
182
+ p_dropout=0.0,
183
+ proximal_bias=False,
184
+ proximal_init=False,
185
+ ):
186
+ super().__init__()
187
+ assert channels % n_heads == 0
188
+
189
+ self.channels = channels
190
+ self.out_channels = out_channels
191
+ self.n_heads = n_heads
192
+ self.heads_share = heads_share
193
+ self.proximal_bias = proximal_bias
194
+ self.p_dropout = p_dropout
195
+ self.attn = None
196
+
197
+ self.k_channels = channels // n_heads
198
+ self.conv_q = torch.nn.Conv1d(channels, channels, 1)
199
+ self.conv_k = torch.nn.Conv1d(channels, channels, 1)
200
+ self.conv_v = torch.nn.Conv1d(channels, channels, 1)
201
+
202
+ # from https://nn.labml.ai/transformers/rope/index.html
203
+ self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
204
+ self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
205
+
206
+ self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
207
+ self.drop = torch.nn.Dropout(p_dropout)
208
+
209
+ torch.nn.init.xavier_uniform_(self.conv_q.weight)
210
+ torch.nn.init.xavier_uniform_(self.conv_k.weight)
211
+ if proximal_init:
212
+ self.conv_k.weight.data.copy_(self.conv_q.weight.data)
213
+ self.conv_k.bias.data.copy_(self.conv_q.bias.data)
214
+ torch.nn.init.xavier_uniform_(self.conv_v.weight)
215
+
216
+ def forward(self, x, c, attn_mask=None):
217
+ q = self.conv_q(x)
218
+ k = self.conv_k(c)
219
+ v = self.conv_v(c)
220
+
221
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
222
+
223
+ x = self.conv_o(x)
224
+ return x
225
+
226
+ def attention(self, query, key, value, mask=None):
227
+ b, d, t_s, t_t = (*key.size(), query.size(2))
228
+ query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads)
229
+ key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads)
230
+ value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads)
231
+
232
+ query = self.query_rotary_pe(query)
233
+ key = self.key_rotary_pe(key)
234
+
235
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
236
+
237
+ if self.proximal_bias:
238
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
239
+ scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
240
+ if mask is not None:
241
+ scores = scores.masked_fill(mask == 0, -1e4)
242
+ p_attn = torch.nn.functional.softmax(scores, dim=-1)
243
+ p_attn = self.drop(p_attn)
244
+ output = torch.matmul(p_attn, value)
245
+ output = output.transpose(2, 3).contiguous().view(b, d, t_t)
246
+ return output, p_attn
247
+
248
+ @staticmethod
249
+ def _attention_bias_proximal(length):
250
+ r = torch.arange(length, dtype=torch.float32)
251
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
252
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
253
+
254
+
255
+ class FFN(nn.Module):
256
+ def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
257
+ super().__init__()
258
+ self.in_channels = in_channels
259
+ self.out_channels = out_channels
260
+ self.filter_channels = filter_channels
261
+ self.kernel_size = kernel_size
262
+ self.p_dropout = p_dropout
263
+
264
+ self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
265
+ self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
266
+ self.drop = torch.nn.Dropout(p_dropout)
267
+
268
+ def forward(self, x, x_mask):
269
+ x = self.conv_1(x * x_mask)
270
+ x = torch.relu(x)
271
+ x = self.drop(x)
272
+ x = self.conv_2(x * x_mask)
273
+ return x * x_mask
274
+
275
+
276
+ class Encoder(nn.Module):
277
+ def __init__(
278
+ self,
279
+ hidden_channels,
280
+ filter_channels,
281
+ n_heads,
282
+ n_layers,
283
+ kernel_size=1,
284
+ p_dropout=0.0,
285
+ **kwargs,
286
+ ):
287
+ super().__init__()
288
+ self.hidden_channels = hidden_channels
289
+ self.filter_channels = filter_channels
290
+ self.n_heads = n_heads
291
+ self.n_layers = n_layers
292
+ self.kernel_size = kernel_size
293
+ self.p_dropout = p_dropout
294
+
295
+ self.drop = torch.nn.Dropout(p_dropout)
296
+ self.attn_layers = torch.nn.ModuleList()
297
+ self.norm_layers_1 = torch.nn.ModuleList()
298
+ self.ffn_layers = torch.nn.ModuleList()
299
+ self.norm_layers_2 = torch.nn.ModuleList()
300
+ for _ in range(self.n_layers):
301
+ self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
302
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
303
+ self.ffn_layers.append(
304
+ FFN(
305
+ hidden_channels,
306
+ hidden_channels,
307
+ filter_channels,
308
+ kernel_size,
309
+ p_dropout=p_dropout,
310
+ )
311
+ )
312
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
313
+
314
+ def forward(self, x, x_mask):
315
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
316
+ for i in range(self.n_layers):
317
+ x = x * x_mask
318
+ y = self.attn_layers[i](x, x, attn_mask)
319
+ y = self.drop(y)
320
+ x = self.norm_layers_1[i](x + y)
321
+ y = self.ffn_layers[i](x, x_mask)
322
+ y = self.drop(y)
323
+ x = self.norm_layers_2[i](x + y)
324
+ x = x * x_mask
325
+ return x
326
+
327
+
328
+ class TextEncoder(nn.Module):
329
+ def __init__(
330
+ self,
331
+ encoder_type,
332
+ encoder_params,
333
+ duration_predictor_params,
334
+ n_vocab,
335
+ n_spks=1,
336
+ spk_emb_dim=128,
337
+ ):
338
+ super().__init__()
339
+ self.encoder_type = encoder_type
340
+ self.n_vocab = n_vocab
341
+ self.n_feats = encoder_params.n_feats
342
+ self.n_channels = encoder_params.n_channels
343
+ self.spk_emb_dim = spk_emb_dim
344
+ self.n_spks = n_spks
345
+
346
+ self.emb = torch.nn.Embedding(n_vocab, self.n_channels)
347
+ torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5)
348
+
349
+ if encoder_params.prenet:
350
+ self.prenet = ConvReluNorm(
351
+ self.n_channels,
352
+ self.n_channels,
353
+ self.n_channels,
354
+ kernel_size=5,
355
+ n_layers=3,
356
+ p_dropout=0.5,
357
+ )
358
+ else:
359
+ self.prenet = lambda x, x_mask: x
360
+
361
+ self.encoder = Encoder(
362
+ encoder_params.n_channels + (spk_emb_dim if n_spks > 1 else 0),
363
+ encoder_params.filter_channels,
364
+ encoder_params.n_heads,
365
+ encoder_params.n_layers,
366
+ encoder_params.kernel_size,
367
+ encoder_params.p_dropout,
368
+ )
369
+
370
+ self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1)
371
+ self.proj_w = DurationPredictor(
372
+ self.n_channels + (spk_emb_dim if n_spks > 1 else 0),
373
+ duration_predictor_params.filter_channels_dp,
374
+ duration_predictor_params.kernel_size,
375
+ duration_predictor_params.p_dropout,
376
+ )
377
+
378
+ def forward(self, x, x_lengths, spks=None):
379
+ """Run forward pass to the transformer based encoder and duration predictor
380
+
381
+ Args:
382
+ x (torch.Tensor): text input
383
+ shape: (batch_size, max_text_length)
384
+ x_lengths (torch.Tensor): text input lengths
385
+ shape: (batch_size,)
386
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
387
+ shape: (batch_size,)
388
+
389
+ Returns:
390
+ mu (torch.Tensor): average output of the encoder
391
+ shape: (batch_size, n_feats, max_text_length)
392
+ logw (torch.Tensor): log duration predicted by the duration predictor
393
+ shape: (batch_size, 1, max_text_length)
394
+ x_mask (torch.Tensor): mask for the text input
395
+ shape: (batch_size, 1, max_text_length)
396
+ """
397
+ x = self.emb(x) * math.sqrt(self.n_channels)
398
+ x = torch.transpose(x, 1, -1)
399
+ x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
400
+
401
+ x = self.prenet(x, x_mask)
402
+ if self.n_spks > 1:
403
+ x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
404
+ x = self.encoder(x, x_mask)
405
+ mu = self.proj_m(x) * x_mask
406
+
407
+ x_dp = torch.detach(x)
408
+ logw = self.proj_w(x_dp, x_mask)
409
+
410
+ return mu, logw, x_mask
matcha/models/components/transformer.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn # pylint: disable=consider-using-from-import
5
+ from diffusers.models.attention import (
6
+ GEGLU,
7
+ GELU,
8
+ AdaLayerNorm,
9
+ AdaLayerNormZero,
10
+ ApproximateGELU,
11
+ )
12
+ from diffusers.models.attention_processor import Attention
13
+ from diffusers.models.lora import LoRACompatibleLinear
14
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
15
+
16
+
17
+ class SnakeBeta(nn.Module):
18
+ """
19
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
20
+ Shape:
21
+ - Input: (B, C, T)
22
+ - Output: (B, C, T), same shape as the input
23
+ Parameters:
24
+ - alpha - trainable parameter that controls frequency
25
+ - beta - trainable parameter that controls magnitude
26
+ References:
27
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
28
+ https://arxiv.org/abs/2006.08195
29
+ Examples:
30
+ >>> a1 = snakebeta(256)
31
+ >>> x = torch.randn(256)
32
+ >>> x = a1(x)
33
+ """
34
+
35
+ def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
36
+ """
37
+ Initialization.
38
+ INPUT:
39
+ - in_features: shape of the input
40
+ - alpha - trainable parameter that controls frequency
41
+ - beta - trainable parameter that controls magnitude
42
+ alpha is initialized to 1 by default, higher values = higher-frequency.
43
+ beta is initialized to 1 by default, higher values = higher-magnitude.
44
+ alpha will be trained along with the rest of your model.
45
+ """
46
+ super().__init__()
47
+ self.in_features = out_features if isinstance(out_features, list) else [out_features]
48
+ self.proj = LoRACompatibleLinear(in_features, out_features)
49
+
50
+ # initialize alpha
51
+ self.alpha_logscale = alpha_logscale
52
+ if self.alpha_logscale: # log scale alphas initialized to zeros
53
+ self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
54
+ self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
55
+ else: # linear scale alphas initialized to ones
56
+ self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
57
+ self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
58
+
59
+ self.alpha.requires_grad = alpha_trainable
60
+ self.beta.requires_grad = alpha_trainable
61
+
62
+ self.no_div_by_zero = 0.000000001
63
+
64
+ def forward(self, x):
65
+ """
66
+ Forward pass of the function.
67
+ Applies the function to the input elementwise.
68
+ SnakeBeta ∶= x + 1/b * sin^2 (xa)
69
+ """
70
+ x = self.proj(x)
71
+ if self.alpha_logscale:
72
+ alpha = torch.exp(self.alpha)
73
+ beta = torch.exp(self.beta)
74
+ else:
75
+ alpha = self.alpha
76
+ beta = self.beta
77
+
78
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
79
+
80
+ return x
81
+
82
+
83
+ class FeedForward(nn.Module):
84
+ r"""
85
+ A feed-forward layer.
86
+
87
+ Parameters:
88
+ dim (`int`): The number of channels in the input.
89
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
90
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
91
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
92
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
93
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
94
+ """
95
+
96
+ def __init__(
97
+ self,
98
+ dim: int,
99
+ dim_out: Optional[int] = None,
100
+ mult: int = 4,
101
+ dropout: float = 0.0,
102
+ activation_fn: str = "geglu",
103
+ final_dropout: bool = False,
104
+ ):
105
+ super().__init__()
106
+ inner_dim = int(dim * mult)
107
+ dim_out = dim_out if dim_out is not None else dim
108
+
109
+ if activation_fn == "gelu":
110
+ act_fn = GELU(dim, inner_dim)
111
+ if activation_fn == "gelu-approximate":
112
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
113
+ elif activation_fn == "geglu":
114
+ act_fn = GEGLU(dim, inner_dim)
115
+ elif activation_fn == "geglu-approximate":
116
+ act_fn = ApproximateGELU(dim, inner_dim)
117
+ elif activation_fn == "snakebeta":
118
+ act_fn = SnakeBeta(dim, inner_dim)
119
+
120
+ self.net = nn.ModuleList([])
121
+ # project in
122
+ self.net.append(act_fn)
123
+ # project dropout
124
+ self.net.append(nn.Dropout(dropout))
125
+ # project out
126
+ self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
127
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
128
+ if final_dropout:
129
+ self.net.append(nn.Dropout(dropout))
130
+
131
+ def forward(self, hidden_states):
132
+ for module in self.net:
133
+ hidden_states = module(hidden_states)
134
+ return hidden_states
135
+
136
+
137
+ @maybe_allow_in_graph
138
+ class BasicTransformerBlock(nn.Module):
139
+ r"""
140
+ A basic Transformer block.
141
+
142
+ Parameters:
143
+ dim (`int`): The number of channels in the input and output.
144
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
145
+ attention_head_dim (`int`): The number of channels in each head.
146
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
147
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
148
+ only_cross_attention (`bool`, *optional*):
149
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
150
+ double_self_attention (`bool`, *optional*):
151
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
152
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
153
+ num_embeds_ada_norm (:
154
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
155
+ attention_bias (:
156
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
157
+ """
158
+
159
+ def __init__(
160
+ self,
161
+ dim: int,
162
+ num_attention_heads: int,
163
+ attention_head_dim: int,
164
+ dropout=0.0,
165
+ cross_attention_dim: Optional[int] = None,
166
+ activation_fn: str = "geglu",
167
+ num_embeds_ada_norm: Optional[int] = None,
168
+ attention_bias: bool = False,
169
+ only_cross_attention: bool = False,
170
+ double_self_attention: bool = False,
171
+ upcast_attention: bool = False,
172
+ norm_elementwise_affine: bool = True,
173
+ norm_type: str = "layer_norm",
174
+ final_dropout: bool = False,
175
+ ):
176
+ super().__init__()
177
+ self.only_cross_attention = only_cross_attention
178
+
179
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
180
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
181
+
182
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
183
+ raise ValueError(
184
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
185
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
186
+ )
187
+
188
+ # Define 3 blocks. Each block has its own normalization layer.
189
+ # 1. Self-Attn
190
+ if self.use_ada_layer_norm:
191
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
192
+ elif self.use_ada_layer_norm_zero:
193
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
194
+ else:
195
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
196
+ self.attn1 = Attention(
197
+ query_dim=dim,
198
+ heads=num_attention_heads,
199
+ dim_head=attention_head_dim,
200
+ dropout=dropout,
201
+ bias=attention_bias,
202
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
203
+ upcast_attention=upcast_attention,
204
+ )
205
+
206
+ # 2. Cross-Attn
207
+ if cross_attention_dim is not None or double_self_attention:
208
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
209
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
210
+ # the second cross attention block.
211
+ self.norm2 = (
212
+ AdaLayerNorm(dim, num_embeds_ada_norm)
213
+ if self.use_ada_layer_norm
214
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
215
+ )
216
+ self.attn2 = Attention(
217
+ query_dim=dim,
218
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
219
+ heads=num_attention_heads,
220
+ dim_head=attention_head_dim,
221
+ dropout=dropout,
222
+ bias=attention_bias,
223
+ upcast_attention=upcast_attention,
224
+ # scale_qk=False, # uncomment this to not to use flash attention
225
+ ) # is self-attn if encoder_hidden_states is none
226
+ else:
227
+ self.norm2 = None
228
+ self.attn2 = None
229
+
230
+ # 3. Feed-forward
231
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
232
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
233
+
234
+ # let chunk size default to None
235
+ self._chunk_size = None
236
+ self._chunk_dim = 0
237
+
238
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
239
+ # Sets chunk feed-forward
240
+ self._chunk_size = chunk_size
241
+ self._chunk_dim = dim
242
+
243
+ def forward(
244
+ self,
245
+ hidden_states: torch.FloatTensor,
246
+ attention_mask: Optional[torch.FloatTensor] = None,
247
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
248
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
249
+ timestep: Optional[torch.LongTensor] = None,
250
+ cross_attention_kwargs: Dict[str, Any] = None,
251
+ class_labels: Optional[torch.LongTensor] = None,
252
+ ):
253
+ # Notice that normalization is always applied before the real computation in the following blocks.
254
+ # 1. Self-Attention
255
+ if self.use_ada_layer_norm:
256
+ norm_hidden_states = self.norm1(hidden_states, timestep)
257
+ elif self.use_ada_layer_norm_zero:
258
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
259
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
260
+ )
261
+ else:
262
+ norm_hidden_states = self.norm1(hidden_states)
263
+
264
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
265
+
266
+ attn_output = self.attn1(
267
+ norm_hidden_states,
268
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
269
+ attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
270
+ **cross_attention_kwargs,
271
+ )
272
+ if self.use_ada_layer_norm_zero:
273
+ attn_output = gate_msa.unsqueeze(1) * attn_output
274
+ hidden_states = attn_output + hidden_states
275
+
276
+ # 2. Cross-Attention
277
+ if self.attn2 is not None:
278
+ norm_hidden_states = (
279
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
280
+ )
281
+
282
+ attn_output = self.attn2(
283
+ norm_hidden_states,
284
+ encoder_hidden_states=encoder_hidden_states,
285
+ attention_mask=encoder_attention_mask,
286
+ **cross_attention_kwargs,
287
+ )
288
+ hidden_states = attn_output + hidden_states
289
+
290
+ # 3. Feed-forward
291
+ norm_hidden_states = self.norm3(hidden_states)
292
+
293
+ if self.use_ada_layer_norm_zero:
294
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
295
+
296
+ if self._chunk_size is not None:
297
+ # "feed_forward_chunk_size" can be used to save memory
298
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
299
+ raise ValueError(
300
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
301
+ )
302
+
303
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
304
+ ff_output = torch.cat(
305
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
306
+ dim=self._chunk_dim,
307
+ )
308
+ else:
309
+ ff_output = self.ff(norm_hidden_states)
310
+
311
+ if self.use_ada_layer_norm_zero:
312
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
313
+
314
+ hidden_states = ff_output + hidden_states
315
+
316
+ return hidden_states
matcha/models/matcha_tts.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime as dt
2
+ import math
3
+ import random
4
+
5
+ import torch
6
+
7
+ import matcha.utils.monotonic_align as monotonic_align # pylint: disable=consider-using-from-import
8
+ from matcha import utils
9
+ from matcha.models.baselightningmodule import BaseLightningClass
10
+ from matcha.models.components.flow_matching import CFM
11
+ from matcha.models.components.text_encoder import TextEncoder
12
+ from matcha.utils.model import (
13
+ denormalize,
14
+ duration_loss,
15
+ fix_len_compatibility,
16
+ generate_path,
17
+ sequence_mask,
18
+ )
19
+
20
+ log = utils.get_pylogger(__name__)
21
+
22
+
23
+ class MatchaTTS(BaseLightningClass): # 🍵
24
+ def __init__(
25
+ self,
26
+ n_vocab,
27
+ n_spks,
28
+ spk_emb_dim,
29
+ n_feats,
30
+ encoder,
31
+ decoder,
32
+ cfm,
33
+ data_statistics,
34
+ out_size,
35
+ optimizer=None,
36
+ scheduler=None,
37
+ prior_loss=True,
38
+ use_precomputed_durations=False,
39
+ ):
40
+ super().__init__()
41
+
42
+ self.save_hyperparameters(logger=False)
43
+
44
+ self.n_vocab = n_vocab
45
+ self.n_spks = n_spks
46
+ self.spk_emb_dim = spk_emb_dim
47
+ self.n_feats = n_feats
48
+ self.out_size = out_size
49
+ self.prior_loss = prior_loss
50
+ self.use_precomputed_durations = use_precomputed_durations
51
+
52
+ if n_spks > 1:
53
+ self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
54
+
55
+ self.encoder = TextEncoder(
56
+ encoder.encoder_type,
57
+ encoder.encoder_params,
58
+ encoder.duration_predictor_params,
59
+ n_vocab,
60
+ n_spks,
61
+ spk_emb_dim,
62
+ )
63
+
64
+ self.decoder = CFM(
65
+ in_channels=2 * encoder.encoder_params.n_feats,
66
+ out_channel=encoder.encoder_params.n_feats,
67
+ cfm_params=cfm,
68
+ decoder_params=decoder,
69
+ n_spks=n_spks,
70
+ spk_emb_dim=spk_emb_dim,
71
+ )
72
+
73
+ self.update_data_statistics(data_statistics)
74
+
75
@torch.inference_mode()
def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0):
    """Generate a mel-spectrogram from a batch of phoneme-ID sequences.

    Args:
        x (torch.Tensor): phoneme-ID tensor, shape (batch, max_text_len).
        x_lengths (torch.Tensor): per-item text lengths, shape (batch,).
        n_timesteps (int): number of reverse-diffusion/ODE steps for the decoder.
        temperature (float): variance scale of the terminal distribution.
        spks (torch.Tensor, optional): speaker ids, shape (batch,); used only
            when the model is multi-speaker (n_spks > 1).
        length_scale (float): >1 slows generated speech down, <1 speeds it up.

    Returns:
        dict with keys "encoder_outputs", "decoder_outputs", "attn",
        "mel" (denormalized), "mel_lengths" and "rtf" (real-time factor —
        the RTF formula hard-codes 22.05 kHz sample rate and hop length 256).
    """
    started_at = dt.datetime.now()  # wall-clock start, for RTF reporting

    if self.n_spks > 1:
        # Replace raw speaker ids with their embedding vectors.
        spks = self.spk_emb(spks.long())

    # Encoder: token-level prior means `mu_x` and log-scaled durations `logw`.
    mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)

    durations = torch.exp(logw) * x_mask
    # NOTE(review): length_scale is applied *after* ceil here, so the scaled
    # durations fed to generate_path may be fractional — kept as in original.
    durations_ceil = torch.ceil(durations) * length_scale
    y_lengths = torch.clamp_min(torch.sum(durations_ceil, [1, 2]), 1).long()
    y_max_length = y_lengths.max()
    y_max_length_ = fix_len_compatibility(y_max_length)

    # Build the hard text->mel alignment map from the predicted durations.
    y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
    attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
    attn = generate_path(durations_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)

    # Expand token-level means to frame level via the alignment.
    mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
    mu_y = mu_y.transpose(1, 2)
    encoder_outputs = mu_y[:, :, :y_max_length]

    # Refine the averaged spectrogram with the flow-matching decoder.
    decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, spks)
    decoder_outputs = decoder_outputs[:, :, :y_max_length]

    elapsed = (dt.datetime.now() - started_at).total_seconds()
    rtf = elapsed * 22050 / (decoder_outputs.shape[-1] * 256)

    return {
        "encoder_outputs": encoder_outputs,
        "decoder_outputs": decoder_outputs,
        "attn": attn[:, :, :y_max_length],
        "mel": denormalize(decoder_outputs, self.mel_mean, self.mel_std),
        "mel_lengths": y_lengths,
        "rtf": rtf,
    }
152
+
153
def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None, durations=None):
    """Training pass computing the three Matcha-TTS losses.

    1. duration loss — predicted log-durations vs. MAS-extracted targets.
    2. prior loss — mel-spectrogram vs. aligned encoder outputs.
    3. flow-matching loss — mel-spectrogram vs. CFM decoder outputs.

    Args:
        x (torch.Tensor): phoneme-ID tensor, shape (batch, max_text_len).
        x_lengths (torch.Tensor): text lengths, shape (batch,).
        y (torch.Tensor): target mels, shape (batch, n_feats, max_mel_len).
        y_lengths (torch.Tensor): mel lengths, shape (batch,).
        spks (torch.Tensor, optional): speaker ids, shape (batch,).
        out_size (int, optional): mel-segment length to randomly crop to;
            memory-saving trick inherited from Grad-TTS.
        cond: optional conditioning forwarded to the CFM decoder.
        durations (torch.Tensor, optional): precomputed per-token durations,
            used only when `self.use_precomputed_durations` is set.

    Returns:
        tuple: (dur_loss, prior_loss, diff_loss, attn)
    """
    if self.n_spks > 1:
        spks = self.spk_emb(spks)  # look up speaker embeddings

    # Encoder: token-level prior means `mu_x` and predicted log-durations.
    mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)
    y_max_length = y.shape[-1]

    y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
    attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)

    if self.use_precomputed_durations:
        attn = generate_path(durations.squeeze(1), attn_mask.squeeze(1))
    else:
        # Monotonic Alignment Search: most likely text<->mel alignment under
        # a unit-variance Gaussian around the encoder means.
        with torch.no_grad():
            const = -0.5 * math.log(2 * math.pi) * self.n_feats
            factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
            y_square = torch.matmul(factor.transpose(1, 2), y**2)
            y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
            mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
            log_prior = y_square - y_mu_double + mu_square + const

            attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
            attn = attn.detach()  # (batch, t_text, t_mel)

    # Duration loss: predicted log-durations vs. those implied by `attn`.
    logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
    dur_loss = duration_loss(logw, logw_, x_lengths)

    # Optionally crop a random `out_size` segment of each mel (and alignment)
    # to bound memory — "hack" taken from Grad-TTS; not required for Matcha.
    if out_size is not None:
        max_offset = (y_lengths - out_size).clamp(0)
        offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy()))
        out_offset = torch.LongTensor(
            [torch.tensor(random.choice(range(start, end)) if end > start else 0) for start, end in offset_ranges]
        ).to(y_lengths)
        attn_cut = torch.zeros(attn.shape[0], attn.shape[1], out_size, dtype=attn.dtype, device=attn.device)
        y_cut = torch.zeros(y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device)

        y_cut_lengths = []
        for i, (y_, out_offset_) in enumerate(zip(y, out_offset)):
            y_cut_length = out_size + (y_lengths[i] - out_size).clamp(None, 0)
            y_cut_lengths.append(y_cut_length)
            cut_lower, cut_upper = out_offset_, out_offset_ + y_cut_length
            y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper]
            attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper]

        y_cut_lengths = torch.LongTensor(y_cut_lengths)
        y_cut_mask = sequence_mask(y_cut_lengths).unsqueeze(1).to(y_mask)

        attn = attn_cut
        y = y_cut
        y_mask = y_cut_mask

    # Frame-level prior means from the (possibly cropped) alignment.
    mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
    mu_y = mu_y.transpose(1, 2)

    # Flow-matching loss from the CFM decoder.
    diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond)

    if self.prior_loss:
        prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
        prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
    else:
        prior_loss = 0

    return dur_loss, prior_loss, diff_loss, attn
matcha/onnx/__init__.py ADDED
File without changes
matcha/onnx/export.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import argparse
import random
from pathlib import Path

import numpy as np
import torch
from lightning import LightningModule

from matcha.cli import VOCODER_URLS, load_matcha, load_vocoder

DEFAULT_OPSET = 15

# Seed every RNG so the traced ONNX export is reproducible run to run.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
20
+
21
+
22
class MatchaWithVocoder(LightningModule):
    """Wraps a Matcha acoustic model and a vocoder into one exportable graph."""

    def __init__(self, matcha, vocoder):
        super().__init__()
        self.matcha = matcha
        self.vocoder = vocoder

    def forward(self, x, x_lengths, scales, spks=None):
        """Text -> waveform. Returns (wavs, lengths_in_samples)."""
        mel, mel_lengths = self.matcha(x, x_lengths, scales, spks)
        waveforms = self.vocoder(mel).clamp(-1, 1)
        # Mel frames -> audio samples (hop length 256).
        sample_lengths = mel_lengths * 256
        return waveforms.squeeze(1), sample_lengths
33
+
34
+
35
def get_exportable_module(matcha, vocoder, n_timesteps):
    """Return an exportable `LightningModule` and its output-node names.

    When `vocoder` is None the returned module emits mels; otherwise the
    vocoder is embedded and the module emits waveforms.
    """

    def onnx_forward_func(x, x_lengths, scales, spks=None):
        """Forward wrapper that unpacks scalar parameters from a tensor."""
        # scales = [temperature, length_scale]
        temperature = scales[0]
        length_scale = scales[1]
        output = matcha.synthesise(x, x_lengths, n_timesteps, temperature, spks, length_scale)
        return output["mel"], output["mel_lengths"]

    # Monkey-patch Matcha's forward so ONNX tracing uses the wrapper above.
    matcha.forward = onnx_forward_func

    if vocoder is None:
        model = matcha
        output_names = ["mel", "mel_lengths"]
    else:
        model = MatchaWithVocoder(matcha, vocoder)
        output_names = ["wav", "wav_lengths"]
    return model, output_names
61
+
62
+
63
def get_inputs(is_multi_speaker):
    """Build dummy tracing inputs and their ONNX input names."""
    seq_len = 50
    x = torch.randint(low=0, high=20, size=(1, seq_len), dtype=torch.long)
    x_lengths = torch.LongTensor([seq_len])

    # Scalar knobs packed into one tensor: [temperature, length_scale].
    scales = torch.Tensor([0.667, 1.0])

    model_inputs = [x, x_lengths, scales]
    input_names = ["x", "x_lengths", "scales"]

    if is_multi_speaker:
        model_inputs.append(torch.LongTensor([1]))
        input_names.append("spks")

    return tuple(model_inputs), input_names
89
+
90
+
91
def main():
    """CLI entry point: export a Matcha-TTS checkpoint (optionally with an
    embedded vocoder) to an ONNX graph."""
    parser = argparse.ArgumentParser(description="Export 🍵 Matcha-TTS to ONNX")

    parser.add_argument(
        "checkpoint_path",
        type=str,
        help="Path to the model checkpoint",
    )
    parser.add_argument("output", type=str, help="Path to output `.onnx` file")
    parser.add_argument(
        "--n-timesteps", type=int, default=5, help="Number of steps to use for reverse diffusion in decoder (default 5)"
    )
    parser.add_argument(
        "--vocoder-name",
        type=str,
        choices=list(VOCODER_URLS.keys()),
        default=None,
        help="Name of the vocoder to embed in the ONNX graph",
    )
    parser.add_argument(
        "--vocoder-checkpoint-path",
        type=str,
        default=None,
        help="Vocoder checkpoint to embed in the ONNX graph for an `e2e` like experience",
    )
    # Fixed: help text previously read "(default 15" with an unbalanced paren.
    parser.add_argument("--opset", type=int, default=DEFAULT_OPSET, help="ONNX opset version to use (default: 15)")

    args = parser.parse_args()

    print(f"[🍵] Loading Matcha checkpoint from {args.checkpoint_path}")
    print(f"Setting n_timesteps to {args.n_timesteps}")

    checkpoint_path = Path(args.checkpoint_path)
    matcha = load_matcha(checkpoint_path.stem, checkpoint_path, "cpu")

    if args.vocoder_name or args.vocoder_checkpoint_path:
        # Embedding a vocoder needs both its name and its checkpoint.
        assert (
            args.vocoder_name and args.vocoder_checkpoint_path
        ), "Both vocoder_name and vocoder-checkpoint are required when embedding the vocoder in the ONNX graph."
        vocoder, _ = load_vocoder(args.vocoder_name, args.vocoder_checkpoint_path, "cpu")
    else:
        vocoder = None

    is_multi_speaker = matcha.n_spks > 1

    dummy_input, input_names = get_inputs(is_multi_speaker)
    model, output_names = get_exportable_module(matcha, vocoder, args.n_timesteps)

    # Dynamic axes so batch size and sequence length are not baked into the graph.
    dynamic_axes = {
        "x": {0: "batch_size", 1: "time"},
        "x_lengths": {0: "batch_size"},
    }

    if vocoder is None:
        dynamic_axes.update(
            {
                "mel": {0: "batch_size", 2: "time"},
                "mel_lengths": {0: "batch_size"},
            }
        )
    else:
        print("Embedding the vocoder in the ONNX graph")
        dynamic_axes.update(
            {
                "wav": {0: "batch_size", 1: "time"},
                "wav_lengths": {0: "batch_size"},
            }
        )

    if is_multi_speaker:
        dynamic_axes["spks"] = {0: "batch_size"}

    # Create the output directory (if it does not exist).
    Path(args.output).parent.mkdir(parents=True, exist_ok=True)

    model.to_onnx(
        args.output,
        dummy_input,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=args.opset,
        export_params=True,
        do_constant_folding=True,
    )
    print(f"[🍵] ONNX model exported to {args.output}")


if __name__ == "__main__":
    main()
matcha/onnx/infer.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import warnings
4
+ from pathlib import Path
5
+ from time import perf_counter
6
+
7
+ import numpy as np
8
+ import onnxruntime as ort
9
+ import soundfile as sf
10
+ import torch
11
+
12
+ from matcha.cli import plot_spectrogram_to_numpy, process_text
13
+
14
+
15
def validate_args(args):
    """Sanity-check parsed CLI arguments and return them unchanged.

    Raises:
        AssertionError: if no input text source is given or a numeric
            argument is negative. (AssertionError is kept deliberately so
            existing callers catching it still work.)
    """
    # Fixed garbled message text ("sometext"); exactly one of --text/--file
    # is not required — either being present is enough.
    assert (
        args.text or args.file
    ), "Either text or file must be provided. Matcha-T(ea)TTS needs some text to whisk the waveforms."
    assert args.temperature >= 0, "Sampling temperature cannot be negative"
    # Message now matches the check (>= 0); the old text claimed "> 0".
    assert args.speaking_rate >= 0, "Speaking rate cannot be negative"
    return args
22
+
23
+
24
def write_wavs(model, inputs, output_dir, external_vocoder=None):
    """Run ONNX inference and write one WAV per batch item, reporting RTF.

    If `external_vocoder` is None the graph is assumed to emit waveforms
    directly; otherwise `model` emits mels that are vocoded separately.
    """
    if external_vocoder is None:
        print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly")
        start = perf_counter()
        wavs, wav_lengths = model.run(None, inputs)
        infer_secs = perf_counter() - start
        mel_infer_secs = vocoder_infer_secs = None
    else:
        print("[🍵] Generating mel using Matcha")
        start = perf_counter()
        mels, mel_lengths = model.run(None, inputs)
        mel_infer_secs = perf_counter() - start
        print("Generating waveform from mel using external vocoder")
        vocoder_inputs = {external_vocoder.get_inputs()[0].name: mels}
        start = perf_counter()
        wavs = external_vocoder.run(None, vocoder_inputs)[0]
        vocoder_infer_secs = perf_counter() - start
        wavs = wavs.squeeze(1)
        wav_lengths = mel_lengths * 256  # mel frames -> samples (hop 256)
        infer_secs = mel_infer_secs + vocoder_infer_secs

    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    for i, (wav, wav_length) in enumerate(zip(wavs, wav_lengths)):
        output_filename = out_path.joinpath(f"output_{i + 1}.wav")
        audio = wav[:wav_length]
        print(f"Writing audio to {output_filename}")
        sf.write(output_filename, audio, 22050, "PCM_24")

    wav_secs = wav_lengths.sum() / 22050
    print(f"Inference seconds: {infer_secs}")
    print(f"Generated wav seconds: {wav_secs}")
    rtf = infer_secs / wav_secs
    if mel_infer_secs is not None:
        mel_rtf = mel_infer_secs / wav_secs
        print(f"Matcha RTF: {mel_rtf}")
    if vocoder_infer_secs is not None:
        vocoder_rtf = vocoder_infer_secs / wav_secs
        print(f"Vocoder RTF: {vocoder_rtf}")
    print(f"Overall RTF: {rtf}")
+
66
def write_mels(model, inputs, output_dir):
    """Run the ONNX model and write each mel as a PNG plot plus a .npy file."""
    t0 = perf_counter()
    mels, mel_lengths = model.run(None, inputs)
    infer_secs = perf_counter() - t0

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    for i, mel in enumerate(mels):
        output_stem = output_dir.joinpath(f"output_{i + 1}")
        plot_spectrogram_to_numpy(mel.squeeze(), output_stem.with_suffix(".png"))
        # Fixed: use the ".npy" suffix np.save expects. With ".numpy" as
        # before, np.save appended ".npy" again and produced files like
        # "output_1.numpy.npy", breaking the "*.npy" promise in main().
        np.save(output_stem.with_suffix(".npy"), mel)

    # Estimate audio duration from mel frames: hop length 256 at 22.05 kHz.
    wav_secs = (mel_lengths * 256).sum() / 22050
    print(f"Inference seconds: {infer_secs}")
    print(f"Generated wav seconds: {wav_secs}")
    rtf = infer_secs / wav_secs
    print(f"RTF: {rtf}")
+
84
+
85
def main():
    """CLI entry point: synthesize speech (or mels) from an exported ONNX model."""
    parser = argparse.ArgumentParser(
        description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching"
    )
    parser.add_argument(
        "model",
        type=str,
        help="ONNX model to use",
    )
    parser.add_argument("--vocoder", type=str, default=None, help="Vocoder to use (defaults to None)")
    parser.add_argument("--text", type=str, default=None, help="Text to synthesize")
    parser.add_argument("--file", type=str, default=None, help="Text file to synthesize")
    parser.add_argument("--spk", type=int, default=None, help="Speaker ID")
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.667,
        help="Variance of the x0 noise (default: 0.667)",
    )
    parser.add_argument(
        "--speaking-rate",
        type=float,
        default=1.0,
        help="change the speaking rate, a higher value means slower speaking rate (default: 1.0)",
    )
    # Fixed: the old help text wrongly described this flag as "Use CPU for inference".
    parser.add_argument("--gpu", action="store_true", help="Use GPU for inference (default: CPU)")
    parser.add_argument(
        "--output-dir",
        type=str,
        default=os.getcwd(),
        help="Output folder to save results (default: current dir)",
    )

    args = parser.parse_args()
    args = validate_args(args)

    if args.gpu:
        # Fixed: "GPUExecutionProvider" is not a valid ONNX Runtime provider
        # name; CUDA inference is requested via "CUDAExecutionProvider".
        providers = ["CUDAExecutionProvider"]
    else:
        providers = ["CPUExecutionProvider"]
    model = ort.InferenceSession(args.model, providers=providers)

    model_inputs = model.get_inputs()
    model_outputs = list(model.get_outputs())

    if args.text:
        text_lines = args.text.splitlines()
    else:
        with open(args.file, encoding="utf-8") as file:
            text_lines = file.read().splitlines()

    processed_lines = [process_text(0, line, "cpu") for line in text_lines]
    x = [line["x"].squeeze() for line in processed_lines]
    # Pad to a rectangular batch.
    x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True)
    x = x.detach().cpu().numpy()
    x_lengths = np.array([line["x_lengths"].item() for line in processed_lines], dtype=np.int64)
    inputs = {
        "x": x,
        "x_lengths": x_lengths,
        "scales": np.array([args.temperature, args.speaking_rate], dtype=np.float32),
    }
    # A 4th graph input ("spks") means the model was exported multi-speaker.
    is_multi_speaker = len(model_inputs) == 4
    if is_multi_speaker:
        if args.spk is None:
            args.spk = 0
            warn = "[!] Speaker ID not provided! Using speaker ID 0"
            warnings.warn(warn, UserWarning)
        inputs["spks"] = np.repeat(args.spk, x.shape[0]).astype(np.int64)

    has_vocoder_embedded = model_outputs[0].name == "wav"
    if has_vocoder_embedded:
        write_wavs(model, inputs, args.output_dir)
    elif args.vocoder:
        external_vocoder = ort.InferenceSession(args.vocoder, providers=providers)
        write_wavs(model, inputs, args.output_dir, external_vocoder=external_vocoder)
    else:
        warn = "[!] A vocoder is not embedded in the graph nor an external vocoder is provided. The mel output will be written as numpy arrays to `*.npy` files in the output directory"
        warnings.warn(warn, UserWarning)
        write_mels(model, inputs, args.output_dir)


if __name__ == "__main__":
    main()
matcha/text/__init__.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
""" from https://github.com/keithito/tacotron """
from matcha.text import cleaners
from matcha.text.symbols import symbols

# Bidirectional mappings between phoneme symbols and their integer IDs.
_symbol_to_id = {symbol: index for index, symbol in enumerate(symbols)}
_id_to_symbol = dict(enumerate(symbols))
8
+
9
+
10
class UnknownCleanerException(Exception):
    """Raised when a requested cleaner name cannot be resolved to a function."""
12
+
13
+
14
def text_to_sequence(text, cleaner_names):
    """Clean `text` and convert it to a list of symbol IDs.

    Args:
        text: string to convert.
        cleaner_names: names of cleaner functions to run the text through.

    Returns:
        (sequence, clean_text): the list of symbol IDs and the cleaned string.
    """
    clean_text = _clean_text(text, cleaner_names)
    # A symbol missing from the table raises KeyError, as before.
    sequence = [_symbol_to_id[symbol] for symbol in clean_text]
    return sequence, clean_text
+
30
+
31
def cleaned_text_to_sequence(cleaned_text):
    """Convert an already-cleaned string into a list of symbol IDs."""
    return [_symbol_to_id[symbol] for symbol in cleaned_text]
40
+
41
+
42
def sequence_to_text(sequence):
    """Convert a sequence of symbol IDs back into the corresponding string."""
    return "".join(_id_to_symbol[symbol_id] for symbol_id in sequence)
49
+
50
+
51
def _clean_text(text, cleaner_names):
    """Run `text` through each named cleaner from `matcha.text.cleaners` in order.

    Raises:
        UnknownCleanerException: if a name does not resolve to a cleaner.
    """
    for name in cleaner_names:
        # Fixed: a plain getattr() raised AttributeError before the explicit
        # `if not cleaner` check could fire, so UnknownCleanerException was
        # unreachable. Passing a None default makes the intended error raise.
        cleaner = getattr(cleaners, name, None)
        if cleaner is None:
            raise UnknownCleanerException(f"Unknown cleaner: {name}")
        text = cleaner(text)
    return text
matcha/text/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (3 kB). View file
 
matcha/text/__pycache__/cleaners.cpython-311.pyc ADDED
Binary file (4.74 kB). View file
 
matcha/text/__pycache__/symbols.cpython-311.pyc ADDED
Binary file (1.54 kB). View file
 
matcha/text/cleaners.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
""" from https://github.com/keithito/tacotron

Cleaners are transformations that run over the input text at both training and eval time.

Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
  1. "english_cleaners" for English text
  2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
     the Unidecode library (https://pypi.python.org/pypi/Unidecode)
  3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
     the symbols in symbols.py to match your data).
"""

import logging
import re

import phonemizer
from unidecode import unidecode

# Silence phonemizer's verbose logging.
critical_logger = logging.getLogger("phonemizer")
critical_logger.setLevel(logging.CRITICAL)

# A single module-level phonemizer backend is reused across calls, since
# constructing one per call is very slow. It is deliberately left unset here;
# the configuration that was previously used at import time was:
# phonemizer.backend.EspeakBackend(
#     language="en-us",
#     preserve_punctuation=True,
#     with_stress=True,
#     language_switch="remove-flags",
#     logger=critical_logger,
# )
global_phonemizer = None


# Matches one or more whitespace characters.
_whitespace_re = re.compile(r"\s+")

# Matches any bracket character (espeak sometimes leaves these in its output).
_brackets_re = re.compile(r"[\[\]\(\)\{\}]")

# (compiled regex, replacement) pairs for common English abbreviations.
_abbreviations = [
    (re.compile(f"\\b{x[0]}\\.", re.IGNORECASE), x[1])
    for x in [
        ("mrs", "misess"),
        ("mr", "mister"),
        ("dr", "doctor"),
        ("st", "saint"),
        ("co", "company"),
        ("jr", "junior"),
        ("maj", "major"),
        ("gen", "general"),
        ("drs", "doctors"),
        ("rev", "reverend"),
        ("lt", "lieutenant"),
        ("hon", "honorable"),
        ("sgt", "sergeant"),
        ("capt", "captain"),
        ("esq", "esquire"),
        ("ltd", "limited"),
        ("col", "colonel"),
        ("ft", "fort"),
    ]
]
67
+
68
+
69
def expand_abbreviations(text):
    """Replace known abbreviations ("mr.", "dr.", ...) with their expansions."""
    for pattern, replacement in _abbreviations:
        text = re.sub(pattern, replacement, text)
    return text
73
+
74
+
75
def lowercase(text):
    """Return `text` lowercased."""
    return text.lower()
77
+
78
+
79
def remove_brackets(text):
    """Strip every (), [], {} bracket character from `text`."""
    return re.sub(r"[\[\]\(\)\{\}]", "", text)
81
+
82
+
83
def collapse_whitespace(text):
    """Replace each run of whitespace with a single space (no stripping)."""
    return re.sub(r"\s+", " ", text)
85
+
86
+
87
def convert_to_ascii(text):
    """Transliterate arbitrary Unicode text to its closest ASCII form."""
    return unidecode(text)
89
+
90
+
91
def basic_cleaners(text):
    """Basic pipeline: lowercase and collapse whitespace, no transliteration."""
    return collapse_whitespace(lowercase(text))
96
+
97
+
98
def transliteration_cleaners(text):
    """Pipeline for non-English text: ASCII transliteration, lowercase, collapse whitespace."""
    return collapse_whitespace(lowercase(convert_to_ascii(text)))
104
+
105
+
106
def english_cleaners2(text):
    """Pipeline for English text: ASCII transliteration, lowercasing,
    abbreviation expansion and IPA phonemization (punctuation + stress kept).

    Fixed: the module sets `global_phonemizer = None` (the import-time
    initialization was commented out), so calling this function crashed with
    `AttributeError: 'NoneType' object has no attribute 'phonemize'`. The
    backend is now created lazily, once, on first use — which also keeps
    module import cheap for callers that never phonemize.
    """
    global global_phonemizer
    if global_phonemizer is None:
        global_phonemizer = phonemizer.backend.EspeakBackend(
            language="en-us",
            preserve_punctuation=True,
            with_stress=True,
            language_switch="remove-flags",
            logger=critical_logger,
        )
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = expand_abbreviations(text)
    phonemes = global_phonemizer.phonemize([text], strip=True, njobs=1)[0]
    # In some cases espeak does not remove brackets — strip them explicitly.
    phonemes = remove_brackets(phonemes)
    phonemes = collapse_whitespace(phonemes)
    return phonemes
116
+
117
+
118
def ipa_simplifier(text):
    """Map a few IPA variants to canonical symbols and collapse whitespace."""
    # Order matters: "ɐ" -> "ə" must run before the "ˈə" -> "ə" de-stressing.
    for old, new in (
        ("ɐ", "ə"),
        ("ˈə", "ə"),
        ("ʤ", "dʒ"),
        ("ʧ", "tʃ"),
        ("ᵻ", "ɪ"),
    ):
        text = text.replace(old, new)
    return re.sub(r"\s+", " ", text)
130
+
131
+
132
+ # I am removing this due to incompatibility with several version of python
133
+ # However, if you want to use it, you can uncomment it
134
+ # and install piper-phonemize with the following command:
135
+ # pip install piper-phonemize
136
+
137
+ # import piper_phonemize
138
+ # def english_cleaners_piper(text):
139
+ # """Pipeline for English text, including abbreviation expansion. + punctuation + stress"""
140
+ # text = convert_to_ascii(text)
141
+ # text = lowercase(text)
142
+ # text = expand_abbreviations(text)
143
+ # phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0])
144
+ # phonemes = collapse_whitespace(phonemes)
145
+ # return phonemes
matcha/text/numbers.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
""" from https://github.com/keithito/tacotron """

import re

import inflect

# Shared inflect engine for spelling out numbers.
_inflect = inflect.engine()

# Patterns for each numeric form handled by normalize_numbers().
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
_number_re = re.compile(r"[0-9]+")
14
+
15
+
16
+ def _remove_commas(m):
17
+ return m.group(1).replace(",", "")
18
+
19
+
20
+ def _expand_decimal_point(m):
21
+ return m.group(1).replace(".", " point ")
22
+
23
+
24
+ def _expand_dollars(m):
25
+ match = m.group(1)
26
+ parts = match.split(".")
27
+ if len(parts) > 2:
28
+ return match + " dollars"
29
+ dollars = int(parts[0]) if parts[0] else 0
30
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
31
+ if dollars and cents:
32
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
33
+ cent_unit = "cent" if cents == 1 else "cents"
34
+ return f"{dollars} {dollar_unit}, {cents} {cent_unit}"
35
+ elif dollars:
36
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
37
+ return f"{dollars} {dollar_unit}"
38
+ elif cents:
39
+ cent_unit = "cent" if cents == 1 else "cents"
40
+ return f"{cents} {cent_unit}"
41
+ else:
42
+ return "zero dollars"
43
+
44
+
45
def _expand_ordinal(m):
    """Spell out a matched ordinal: "2nd" -> "second"."""
    return _inflect.number_to_words(m.group(0))
47
+
48
+
49
def _expand_number(m):
    """Spell out a matched integer, with year-style reading for 1001-2999."""
    num = int(m.group(0))
    if not 1000 < num < 3000:
        return _inflect.number_to_words(num, andword="")
    if num == 2000:
        return "two thousand"
    if 2000 < num < 2010:
        return "two thousand " + _inflect.number_to_words(num % 100)
    if num % 100 == 0:
        return _inflect.number_to_words(num // 100) + " hundred"
    # Year-style pairing, e.g. 1984 -> "nineteen eighty-four".
    return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
62
+
63
+
64
def normalize_numbers(text):
    """Expand numeric expressions (currency, decimals, ordinals, integers) into words."""
    # Order matters: strip commas first, then currency, decimals, ordinals,
    # and finally bare integers.
    for pattern, repl in (
        (_comma_number_re, _remove_commas),
        (_pounds_re, r"\1 pounds"),
        (_dollars_re, _expand_dollars),
        (_decimal_number_re, _expand_decimal_point),
        (_ordinal_re, _expand_ordinal),
        (_number_re, _expand_number),
    ):
        text = re.sub(pattern, repl, text)
    return text