Dionyssos committed
Commit b2b0a60 · 1 Parent(s): 8099af0
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. README.md +10 -6
  2. app.py +467 -0
  3. audionar.py +623 -0
  4. requirements.txt +14 -0
  5. textual.py +536 -0
  6. tts.py +847 -0
  7. wav/af_ZA_google-nwu_0184.wav +0 -0
  8. wav/af_ZA_google-nwu_1919.wav +0 -0
  9. wav/af_ZA_google-nwu_2418.wav +0 -0
  10. wav/af_ZA_google-nwu_6590.wav +0 -0
  11. wav/af_ZA_google-nwu_7130.wav +0 -0
  12. wav/af_ZA_google-nwu_7214.wav +0 -0
  13. wav/af_ZA_google-nwu_8148.wav +0 -0
  14. wav/af_ZA_google-nwu_8924.wav +0 -0
  15. wav/af_ZA_google-nwu_8963.wav +0 -0
  16. wav/bn_multi_00737.wav +0 -0
  17. wav/bn_multi_00779.wav +0 -0
  18. wav/bn_multi_01232.wav +0 -0
  19. wav/bn_multi_01701.wav +0 -0
  20. wav/bn_multi_03042.wav +0 -0
  21. wav/bn_multi_0834.wav +0 -0
  22. wav/bn_multi_1010.wav +0 -0
  23. wav/bn_multi_3108.wav +0 -0
  24. wav/bn_multi_3713.wav +0 -0
  25. wav/bn_multi_3958.wav +0 -0
  26. wav/bn_multi_4046.wav +0 -0
  27. wav/bn_multi_4811.wav +0 -0
  28. wav/bn_multi_5958.wav +0 -0
  29. wav/bn_multi_9169.wav +0 -0
  30. wav/bn_multi_rm.wav +0 -0
  31. wav/de_DE_m-ailabs_angela_merkel.wav +0 -0
  32. wav/de_DE_m-ailabs_eva_k.wav +0 -0
  33. wav/de_DE_m-ailabs_karlsson.wav +0 -0
  34. wav/de_DE_m-ailabs_ramona_deininger.wav +0 -0
  35. wav/de_DE_m-ailabs_rebecca_braunert_plunkett.wav +0 -0
  36. wav/de_DE_thorsten-emotion_amused.wav +0 -0
  37. wav/el_GR_rapunzelina.wav +0 -0
  38. wav/en_UK_apope.wav +0 -0
  39. wav/en_US_cmu_arctic_aew.wav +0 -0
  40. wav/en_US_cmu_arctic_aup.wav +0 -0
  41. wav/en_US_cmu_arctic_awb.wav +0 -0
  42. wav/en_US_cmu_arctic_awbrms.wav +0 -0
  43. wav/en_US_cmu_arctic_axb.wav +0 -0
  44. wav/en_US_cmu_arctic_bdl.wav +0 -0
  45. wav/en_US_cmu_arctic_clb.wav +0 -0
  46. wav/en_US_cmu_arctic_eey.wav +0 -0
  47. wav/en_US_cmu_arctic_fem.wav +0 -0
  48. wav/en_US_cmu_arctic_gka.wav +0 -0
  49. wav/en_US_cmu_arctic_jmk.wav +0 -0
  50. wav/en_US_cmu_arctic_ksp.wav +0 -0
README.md CHANGED
@@ -1,14 +1,18 @@
  ---
- title: SHIFT
- emoji: 🐨
- colorFrom: indigo
+ title: Speech analysis
+ emoji: 💤
+ colorFrom: gray
  colorTo: gray
  sdk: gradio
- sdk_version: 5.45.0
+ sdk_version: 5.41.1
  app_file: app.py
- pinned: false
+ short_description: TTS for CPU
  license: cc-by-nc-4.0
- short_description: https://shift-europe.eu/
+ tags:
+ - non-AR
+ - affective
+ - shift
+ - tts
  ---
 
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,467 @@
+ # -*- coding: utf-8 -*-
+ import typing
+ import gradio as gr
+ import numpy as np
+ import os
+ import torch
+ import torch.nn as nn
+ import audiofile
+ from tts import StyleTTS2
+ from textual import only_greek_or_only_latin, transliterate_number, fix_vocals
+ import textwrap
+ import nltk
+ from audionar import VitsModel, VitsTokenizer
+
+
+ nltk.download('punkt', download_dir='./')
+ nltk.download('punkt_tab', download_dir='./')
+ nltk.data.path.append('.')
+
+
+ language_names = ['Ancient greek',
+                   'English',
+                   'Deutsch',
+                   'French',
+                   'Hungarian',
+                   'Romanian',
+                   'Serbian (Approx.)']
+
+
+ def audionar_tts(text=None,
+                  lang='Romanian'):
+
+     # https://huggingface.co/dkounadis/artificial-styletts2/blob/main/msinference.py
+
+     lang_map = {
+         'ancient greek': 'grc',
+         'english': 'eng',
+         'deutsch': 'deu',
+         'french': 'fra',
+         'hungarian': 'hun',
+         'romanian': 'ron',
+         'serbian (approx.)': 'rmc-script_latin',
+     }
+
+     if text is None or text.strip() == '':
+         text = 'No Audio or Txt Input'
+
+     if lang not in language_names:  # StyleTTS2 (lang holds a voice name, not a language)
+
+         text = only_greek_or_only_latin(text, lang='eng')
+
+         x = _tts.inference(text,
+                            ref_s='wav/' + lang + '.wav')[0, 0, :].numpy()  # 24 kHz
+
+     else:  # VITS
+
+         lang_code = lang_map.get(lang.lower(), lang.lower().split()[0].strip())
+
+         global cached_lang_code, cached_net_g, cached_tokenizer
+
+         # keep one MMS model/tokenizer pair in memory, reloaded only on language switch
+         if 'cached_lang_code' not in globals() or cached_lang_code != lang_code:
+             cached_lang_code = lang_code
+             cached_net_g = VitsModel.from_pretrained(f'facebook/mms-tts-{lang_code}').eval()
+             cached_tokenizer = VitsTokenizer.from_pretrained(f'facebook/mms-tts-{lang_code}')
+
+         text = only_greek_or_only_latin(text, lang=lang_code)
+         text = transliterate_number(text, lang=lang_code)
+         text = fix_vocals(text, lang=lang_code)
+
+         sentences = textwrap.wrap(text, width=439)
+
+         total_audio_parts = []
+         for sentence in sentences:
+             inputs = cached_tokenizer(sentence, return_tensors="pt")
+             with torch.no_grad():
+                 audio_part = cached_net_g(
+                     input_ids=inputs.input_ids,
+                     attention_mask=inputs.attention_mask,
+                     lang_code=lang_code,
+                 )[0, :]
+             total_audio_parts.append(audio_part)
+
+         x = torch.cat(total_audio_parts).cpu().numpy()
+
+     x = x[None, :]
+     x = np.concatenate([0.49 * x, 0.51 * x], 0)
+
+     wavfile = '_vits_.wav'
+     audiofile.write(wavfile, x, 16000)
+     return wavfile  # same file serves the audio output & the state passed to the emotion recognition tab
+
+
+ # TTS
+ # VOICES = [f'wav/{vox}' for vox in os.listdir('wav')]
+ # add unidecode to parse non-Roman characters for StyleTTS2
+ # (for VITS it is better to skip unknown letters - don't use unidecode())
+ # at generation, fill the state of "last tts"
+ # at record, fill the state of "last record" and place it in the list of voices/langs for TTS
+ VOICES = ['jv_ID_google-gmu_04982.wav',
+           # 'it_IT_mls_1595.wav',
+           'en_US_vctk_p303.wav',
+           'en_US_vctk_p306.wav',
+           'it_IT_mls_8842.wav',
+           'en_US_cmu_arctic_ksp.wav',
+           'jv_ID_google-gmu_05970.wav',
+           'en_US_vctk_p318.wav',
+           'ha_NE_openbible.wav',
+           'ne_NP_ne-google_0883.wav',
+           'en_US_vctk_p280.wav',
+           'bn_multi_1010.wav',
+           'en_US_vctk_p259.wav',
+           'it_IT_mls_844.wav',
+           'en_US_vctk_p269.wav',
+           'en_US_vctk_p285.wav',
+           'de_DE_m-ailabs_angela_merkel.wav',
+           'en_US_vctk_p316.wav',
+           'en_US_vctk_p362.wav',
+           'jv_ID_google-gmu_06207.wav',
+           'tn_ZA_google-nwu_9061.wav',
+           'fr_FR_tom.wav',
+           'en_US_vctk_p233.wav',
+           'it_IT_mls_4975.wav',
+           'en_US_vctk_p236.wav',
+           'bn_multi_01232.wav',
+           'bn_multi_5958.wav',
+           'it_IT_mls_9185.wav',
+           'en_US_vctk_p248.wav',
+           'en_US_vctk_p287.wav',
+           'it_IT_mls_9772.wav',
+           'te_IN_cmu-indic_sk.wav',
+           'tn_ZA_google-nwu_8333.wav',
+           'en_US_vctk_p260.wav',
+           'en_US_vctk_p247.wav',
+           'en_US_vctk_p329.wav',
+           'en_US_cmu_arctic_fem.wav',
+           'en_US_cmu_arctic_rms.wav',
+           'en_US_vctk_p308.wav',
+           'jv_ID_google-gmu_08736.wav',
+           'en_US_vctk_p245.wav',
+           'fr_FR_m-ailabs_nadine_eckert_boulet.wav',
+           'jv_ID_google-gmu_03314.wav',
+           'en_US_vctk_p239.wav',
+           'jv_ID_google-gmu_05540.wav',
+           'it_IT_mls_7440.wav',
+           'en_US_vctk_p310.wav',
+           'en_US_vctk_p237.wav',
+           'en_US_hifi-tts_92.wav',
+           'en_US_cmu_arctic_aew.wav',
+           'ne_NP_ne-google_2099.wav',
+           'en_US_vctk_p226.wav',
+           'af_ZA_google-nwu_1919.wav',
+           'jv_ID_google-gmu_03727.wav',
+           'en_US_vctk_p317.wav',
+           'tn_ZA_google-nwu_0378.wav',
+           'nl_pmk.wav',
+           'en_US_vctk_p286.wav',
+           'tn_ZA_google-nwu_3342.wav',
+           # 'en_US_vctk_p343.wav',
+           'de_DE_m-ailabs_ramona_deininger.wav',
+           'jv_ID_google-gmu_03424.wav',
+           'en_US_vctk_p341.wav',
+           'jv_ID_google-gmu_03187.wav',
+           'ne_NP_ne-google_3960.wav',
+           'jv_ID_google-gmu_06080.wav',
+           'ne_NP_ne-google_3997.wav',
+           # 'en_US_vctk_p267.wav',
+           'en_US_vctk_p240.wav',
+           'ne_NP_ne-google_5687.wav',
+           'ne_NP_ne-google_9407.wav',
+           'jv_ID_google-gmu_05667.wav',
+           'jv_ID_google-gmu_01519.wav',
+           'ne_NP_ne-google_7957.wav',
+           'it_IT_mls_4705.wav',
+           'ne_NP_ne-google_6329.wav',
+           'it_IT_mls_1725.wav',
+           'tn_ZA_google-nwu_8914.wav',
+           'en_US_ljspeech.wav',
+           'tn_ZA_google-nwu_4850.wav',
+           'en_US_vctk_p238.wav',
+           'en_US_vctk_p302.wav',
+           'jv_ID_google-gmu_08178.wav',
+           'en_US_vctk_p313.wav',
+           'af_ZA_google-nwu_2418.wav',
+           'bn_multi_00737.wav',
+           'en_US_vctk_p275.wav',  # y
+           'af_ZA_google-nwu_0184.wav',
+           'jv_ID_google-gmu_07638.wav',
+           'ne_NP_ne-google_6587.wav',
+           'ne_NP_ne-google_0258.wav',
+           'en_US_vctk_p232.wav',
+           'en_US_vctk_p336.wav',
+           'jv_ID_google-gmu_09039.wav',
+           'en_US_vctk_p312.wav',
+           'af_ZA_google-nwu_8148.wav',
+           'en_US_vctk_p326.wav',
+           'en_US_vctk_p264.wav',
+           'en_US_vctk_p295.wav',
+           # 'en_US_vctk_p298.wav',
+           'es_ES_m-ailabs_victor_villarraza.wav',
+           'pl_PL_m-ailabs_nina_brown.wav',
+           'tn_ZA_google-nwu_9365.wav',
+           'en_US_vctk_p294.wav',
+           'jv_ID_google-gmu_00658.wav',
+           'jv_ID_google-gmu_08305.wav',
+           'en_US_vctk_p330.wav',
+           'gu_IN_cmu-indic_cmu_indic_guj_dp.wav',
+           'jv_ID_google-gmu_05219.wav',
+           'en_US_vctk_p284.wav',
+           'de_DE_m-ailabs_eva_k.wav',
+           # 'bn_multi_00779.wav',
+           'en_UK_apope.wav',
+           'en_US_vctk_p345.wav',
+           'it_IT_mls_6744.wav',
+           'en_US_vctk_p347.wav',
+           'en_US_m-ailabs_mary_ann.wav',
+           'en_US_m-ailabs_elliot_miller.wav',
+           'en_US_vctk_p279.wav',
+           'ru_RU_multi_nikolaev.wav',
+           'bn_multi_4811.wav',
+           'tn_ZA_google-nwu_7693.wav',
+           'bn_multi_01701.wav',
+           'en_US_vctk_p262.wav',
+           # 'en_US_vctk_p266.wav',
+           'en_US_vctk_p243.wav',
+           'en_US_vctk_p297.wav',
+           'en_US_vctk_p278.wav',
+           'jv_ID_google-gmu_02059.wav',
+           'en_US_vctk_p231.wav',
+           'te_IN_cmu-indic_kpn.wav',
+           'en_US_vctk_p250.wav',
+           'it_IT_mls_4974.wav',
+           'en_US_cmu_arctic_awbrms.wav',
+           # 'en_US_vctk_p263.wav',
+           'nl_femal.wav',
+           'tn_ZA_google-nwu_6116.wav',
+           'jv_ID_google-gmu_06383.wav',
+           'en_US_vctk_p225.wav',
+           'en_US_vctk_p228.wav',
+           'it_IT_mls_277.wav',
+           'tn_ZA_google-nwu_7866.wav',
+           'en_US_vctk_p300.wav',
+           'ne_NP_ne-google_0649.wav',
+           'es_ES_carlfm.wav',
+           'jv_ID_google-gmu_06510.wav',
+           'de_DE_m-ailabs_rebecca_braunert_plunkett.wav',
+           'en_US_vctk_p340.wav',
+           'en_US_cmu_arctic_gka.wav',
+           'ne_NP_ne-google_2027.wav',
+           'jv_ID_google-gmu_09724.wav',
+           'en_US_vctk_p361.wav',
+           'ne_NP_ne-google_6834.wav',
+           'jv_ID_google-gmu_02326.wav',
+           'fr_FR_m-ailabs_zeckou.wav',
+           'tn_ZA_google-nwu_1932.wav',
+           # 'female-20-happy.wav',
+           'tn_ZA_google-nwu_1483.wav',
+           'de_DE_thorsten-emotion_amused.wav',
+           'ru_RU_multi_minaev.wav',
+           'sw_lanfrica.wav',
+           'en_US_vctk_p271.wav',
+           'tn_ZA_google-nwu_0441.wav',
+           'it_IT_mls_6001.wav',
+           'en_US_vctk_p305.wav',
+           'it_IT_mls_8828.wav',
+           'jv_ID_google-gmu_08002.wav',
+           'it_IT_mls_2033.wav',
+           'tn_ZA_google-nwu_3629.wav',
+           'it_IT_mls_6348.wav',
+           'en_US_cmu_arctic_axb.wav',
+           'it_IT_mls_8181.wav',
+           'en_US_vctk_p230.wav',
+           'af_ZA_google-nwu_7214.wav',
+           'nl_nathalie.wav',
+           'it_IT_mls_8207.wav',
+           'ko_KO_kss.wav',
+           'af_ZA_google-nwu_6590.wav',
+           'jv_ID_google-gmu_00264.wav',
+           'tn_ZA_google-nwu_6234.wav',
+           'jv_ID_google-gmu_05522.wav',
+           'en_US_cmu_arctic_lnh.wav',
+           'en_US_vctk_p272.wav',
+           'en_US_cmu_arctic_slp.wav',
+           'en_US_vctk_p299.wav',
+           'en_US_hifi-tts_9017.wav',
+           'it_IT_mls_4998.wav',
+           'it_IT_mls_6299.wav',
+           'en_US_cmu_arctic_rxr.wav',
+           # 'female-46-neutral.wav',
+           'jv_ID_google-gmu_01392.wav',
+           'tn_ZA_google-nwu_8512.wav',
+           'en_US_vctk_p244.wav',
+           # 'bn_multi_3108.wav',
+           # 'it_IT_mls_7405.wav',
+           # 'bn_multi_3713.wav',
+           # 'yo_openbible.wav',
+           # 'jv_ID_google-gmu_01932.wav',
+           'en_US_vctk_p270.wav',
+           'tn_ZA_google-nwu_6459.wav',
+           'bn_multi_4046.wav',
+           'en_US_vctk_p288.wav',
+           'en_US_vctk_p251.wav',
+           'es_ES_m-ailabs_tux.wav',
+           'tn_ZA_google-nwu_6206.wav',
+           'bn_multi_9169.wav',
+           # 'en_US_vctk_p293.wav',
+           # 'en_US_vctk_p255.wav',
+           'af_ZA_google-nwu_8963.wav',
+           # 'en_US_vctk_p265.wav',
+           'gu_IN_cmu-indic_cmu_indic_guj_ad.wav',
+           'jv_ID_google-gmu_07335.wav',
+           'en_US_vctk_p323.wav',
+           'en_US_vctk_p281.wav',
+           'en_US_cmu_arctic_bdl.wav',
+           'en_US_m-ailabs_judy_bieber.wav',
+           'it_IT_mls_10446.wav',
+           'en_US_vctk_p261.wav',
+           'en_US_vctk_p292.wav',
+           'te_IN_cmu-indic_ss.wav',
+           'en_US_vctk_p311.wav',
+           'it_IT_mls_12428.wav',
+           'en_US_cmu_arctic_aup.wav',
+           'jv_ID_google-gmu_04679.wav',
+           'it_IT_mls_4971.wav',
+           'en_US_cmu_arctic_ljm.wav',
+           'fa_haaniye.wav',
+           'en_US_vctk_p339.wav',
+           'tn_ZA_google-nwu_7896.wav',
+           'en_US_vctk_p253.wav',
+           'it_IT_mls_5421.wav',
+           # 'ne_NP_ne-google_0546.wav',
+           'vi_VN_vais1000.wav',
+           'en_US_vctk_p229.wav',
+           'en_US_vctk_p254.wav',
+           'en_US_vctk_p258.wav',
+           'it_IT_mls_7936.wav',
+           'en_US_vctk_p301.wav',
+           'tn_ZA_google-nwu_0045.wav',
+           'it_IT_mls_659.wav',
+           'tn_ZA_google-nwu_7674.wav',
+           'it_IT_mls_12804.wav',
+           'el_GR_rapunzelina.wav',
+           'en_US_hifi-tts_6097.wav',
+           'en_US_vctk_p257.wav',
+           'jv_ID_google-gmu_07875.wav',
+           'it_IT_mls_1157.wav',
+           'it_IT_mls_643.wav',
+           'en_US_vctk_p304.wav',
+           'ru_RU_multi_hajdurova.wav',
+           'it_IT_mls_8461.wav',
+           'bn_multi_3958.wav',
+           'it_IT_mls_1989.wav',
+           'en_US_vctk_p249.wav',
+           # 'bn_multi_0834.wav',
+           'en_US_vctk_p307.wav',
+           'es_ES_m-ailabs_karen_savage.wav',
+           'fr_FR_m-ailabs_bernard.wav',
+           'en_US_vctk_p252.wav',
+           'en_US_cmu_arctic_jmk.wav',
+           'en_US_vctk_p333.wav',
+           'tn_ZA_google-nwu_4506.wav',
+           'ne_NP_ne-google_0283.wav',
+           'de_DE_m-ailabs_karlsson.wav',
+           'en_US_cmu_arctic_awb.wav',
+           'en_US_vctk_p246.wav',
+           'en_US_cmu_arctic_clb.wav',
+           'en_US_vctk_p364.wav',
+           'nl_flemishguy.wav',
+           'en_US_vctk_p276.wav',  # y
+           # 'en_US_vctk_p274.wav',
+           'fr_FR_m-ailabs_gilles_g_le_blanc.wav',
+           'it_IT_mls_7444.wav',
+           'style_o22050.wav',
+           'en_US_vctk_s5.wav',
+           'en_US_vctk_p268.wav',
+           'it_IT_mls_6807.wav',
+           'it_IT_mls_2019.wav',
+           # 'male-60-angry.wav',
+           'af_ZA_google-nwu_8924.wav',
+           'en_US_vctk_p374.wav',
+           'en_US_vctk_p363.wav',
+           'it_IT_mls_644.wav',
+           'ne_NP_ne-google_3614.wav',
+           'en_US_vctk_p241.wav',
+           'ne_NP_ne-google_3154.wav',
+           'en_US_vctk_p234.wav',
+           'it_IT_mls_8384.wav',
+           'fr_FR_m-ailabs_ezwa.wav',
+           'it_IT_mls_5010.wav',
+           'en_US_vctk_p351.wav',
+           'en_US_cmu_arctic_eey.wav',
+           'jv_ID_google-gmu_04285.wav',
+           'jv_ID_google-gmu_06941.wav',
+           'hu_HU_diana-majlinger.wav',
+           'tn_ZA_google-nwu_2839.wav',
+           'bn_multi_03042.wav',
+           'tn_ZA_google-nwu_5628.wav',
+           'it_IT_mls_4649.wav',
+           'af_ZA_google-nwu_7130.wav',
+           'en_US_cmu_arctic_slt.wav',
+           'jv_ID_google-gmu_04175.wav',
+           'gu_IN_cmu-indic_cmu_indic_guj_kt.wav',
+           'jv_ID_google-gmu_00027.wav',
+           'jv_ID_google-gmu_02884.wav',
+           'en_US_vctk_p360.wav',
+           'en_US_vctk_p334.wav',
+           # 'male-27-sad.wav',
+           'tn_ZA_google-nwu_1498.wav',
+           'fi_FI_harri-tapani-ylilammi.wav',
+           'bn_multi_rm.wav',
+           'ne_NP_ne-google_2139.wav',
+           'pl_PL_m-ailabs_piotr_nater.wav',
+           'fr_FR_siwis.wav',
+           'nl_bart-de-leeuw.wav',
+           'jv_ID_google-gmu_04715.wav',
+           'en_US_vctk_p283.wav',
+           'en_US_vctk_p314.wav',
+           'en_US_vctk_p335.wav',
+           'jv_ID_google-gmu_07765.wav',
+           'en_US_vctk_p273.wav'
+           ]
+ VOICES = [t[:-4] for t in VOICES]  # crop .wav for display in gr.Dropdown
+
+ _tts = StyleTTS2().to('cpu')
+
+
+ with gr.Blocks(theme='huggingface') as demo:
+     with gr.Row():
+         text_input = gr.Textbox(
+             label="Type text for TTS:",
+             placeholder="Type Text for TTS",
+             lines=4,
+             value='Η γρηγορη καφετι αλεπου πηδαει πανω απο τον τεμπελη σκυλο.',
+         )
+         choice_dropdown = gr.Dropdown(
+             choices=language_names + VOICES,
+             label="Vox",
+             value=language_names[0]
+         )
+         generate_button = gr.Button("Generate Audio", variant="primary")
+
+     output_audio = gr.Audio(label="TTS Output")
+
+     generate_button.click(
+         fn=audionar_tts,
+         inputs=[text_input, choice_dropdown],
+         outputs=[output_audio]
+     )
+ demo.launch(debug=True)
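
The VITS branch above caches one MMS checkpoint per language and synthesizes 439-character chunks. A minimal sketch of that path outside Gradio (an assumption for illustration, not part of the commit; it needs the audionar module from this diff and network access to the facebook/mms-tts-ron checkpoint):

    import torch
    from audionar import VitsModel, VitsTokenizer

    net_g = VitsModel.from_pretrained('facebook/mms-tts-ron').eval()
    tokenizer = VitsTokenizer.from_pretrained('facebook/mms-tts-ron')
    inputs = tokenizer('Buna ziua', return_tensors='pt')
    with torch.no_grad():
        wav = net_g(input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    lang_code='ron')  # 'ron' selects the OSCILLATION duration pattern
    print(wav.shape)  # (1, num_samples); app.py writes this out at 16 kHz
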
audionar.py ADDED
@@ -0,0 +1,623 @@
+ import math
+ import numpy as np
+ import torch
+ from torch import nn
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.configuration_utils import PretrainedConfig
+ import json
+ import os
+ import re
+ from transformers.tokenization_utils import PreTrainedTokenizer
+ from transformers.utils import is_phonemizer_available  # was missing; used in prepare_for_tokenization()
+ import phonemizer
+ import torch.nn.functional as F
+
+
+ OSCILLATION = {
+     'deu': [1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1],
+     'rmc-script_latin': [2, 2, 1, 2, 2],
+     'hun': [1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1],
+     'fra': [1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1],
+     'eng': [1, 2, 2, 1, 2, 2],
+     'grc': [1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1],
+     'ron': [1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2],
+ }
+
+
+ def has_non_roman_characters(input_string):
+     # Find any character outside the ASCII range
+     non_roman_pattern = re.compile(r"[^\x00-\x7F]")
+
+     # Search the input string for non-Roman characters
+     match = non_roman_pattern.search(input_string)
+     has_non_roman = match is not None
+     return has_non_roman
+
+
+ class VitsConfig(PretrainedConfig):
+
+     model_type = "vits"
+
+     def __init__(
+         self,
+         vocab_size=38,
+         hidden_size=192,
+         num_hidden_layers=6,
+         num_attention_heads=2,
+         use_bias=True,
+         ffn_dim=768,
+         ffn_kernel_size=3,
+         flow_size=192,
+         upsample_initial_channel=512,
+         upsample_rates=[8, 8, 2, 2],
+         upsample_kernel_sizes=[16, 16, 4, 4],
+         resblock_kernel_sizes=[3, 7, 11],
+         resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+         prior_encoder_num_flows=4,
+         prior_encoder_num_wavenet_layers=4,
+         wavenet_kernel_size=5,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.use_bias = use_bias
+         self.ffn_dim = ffn_dim
+         self.ffn_kernel_size = ffn_kernel_size
+         self.flow_size = flow_size
+         self.upsample_initial_channel = upsample_initial_channel
+         self.upsample_rates = upsample_rates
+         self.upsample_kernel_sizes = upsample_kernel_sizes
+         self.resblock_kernel_sizes = resblock_kernel_sizes
+         self.resblock_dilation_sizes = resblock_dilation_sizes
+         self.prior_encoder_num_flows = prior_encoder_num_flows
+         self.prior_encoder_num_wavenet_layers = prior_encoder_num_wavenet_layers
+         self.wavenet_kernel_size = wavenet_kernel_size
+         super().__init__(**kwargs)  # forward kwargs so e.g. pad_token_id from the pretrained config is kept
+
+
+ class VitsWaveNet(torch.nn.Module):
+     def __init__(self, config, num_layers):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.num_layers = num_layers
+         self.in_layers = torch.nn.ModuleList()
+         self.res_skip_layers = torch.nn.ModuleList()
+         weight_norm = nn.utils.parametrizations.weight_norm
+         for i in range(num_layers):
+             in_layer = torch.nn.Conv1d(
+                 in_channels=config.hidden_size,
+                 out_channels=2 * config.hidden_size,
+                 kernel_size=config.wavenet_kernel_size,
+                 dilation=1,
+                 padding=2,
+             )
+             in_layer = weight_norm(in_layer, name="weight")
+             self.in_layers.append(in_layer)
+
+             # last one is not necessary
+             if i < num_layers - 1:
+                 res_skip_channels = 2 * config.hidden_size
+             else:
+                 res_skip_channels = config.hidden_size
+             res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1)
+             res_skip_layer = weight_norm(res_skip_layer, name="weight")
+             self.res_skip_layers.append(res_skip_layer)
+
+     def forward(self, inputs):
+         outputs = torch.zeros_like(inputs)
+         num_channels = torch.IntTensor([self.hidden_size])[0]
+         for i in range(self.num_layers):
+             in_act = self.in_layers[i](inputs)
+             # gated tanh/sigmoid activation as in WaveNet;
+             # the global (style) conditioning of the original model is omitted here
+             t_act = torch.tanh(in_act[:, :num_channels, :])
+             s_act = torch.sigmoid(in_act[:, num_channels:, :])
+             acts = t_act * s_act
+             res_skip_acts = self.res_skip_layers[i](acts)
+             if i < self.num_layers - 1:
+                 res_acts = res_skip_acts[:, : self.hidden_size, :]
+                 inputs = inputs + res_acts
+                 outputs = outputs + res_skip_acts[:, self.hidden_size:, :]
+             else:
+                 outputs = outputs + res_skip_acts
+         return outputs
+
+
+ # Copied from transformers.models.speecht5.modeling_speecht5.HifiGanResidualBlock
+ class HifiGanResidualBlock(nn.Module):
+     def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1):
+         super().__init__()
+         self.leaky_relu_slope = leaky_relu_slope
+
+         self.convs1 = nn.ModuleList(
+             [
+                 nn.Conv1d(
+                     channels,
+                     channels,
+                     kernel_size,
+                     stride=1,
+                     dilation=dilation[i],
+                     padding=self.get_padding(kernel_size, dilation[i]),
+                 )
+                 for i in range(len(dilation))
+             ]
+         )
+         self.convs2 = nn.ModuleList(
+             [
+                 nn.Conv1d(
+                     channels,
+                     channels,
+                     kernel_size,
+                     stride=1,
+                     dilation=1,
+                     padding=self.get_padding(kernel_size, 1),
+                 )
+                 for _ in range(len(dilation))
+             ]
+         )
+
+     def get_padding(self, kernel_size, dilation=1):
+         # 1, 3, 5, 15
+         return (kernel_size * dilation - dilation) // 2
+
+     def forward(self, hidden_states):
+         for conv1, conv2 in zip(self.convs1, self.convs2):
+             residual = hidden_states
+             hidden_states = nn.functional.leaky_relu(hidden_states, negative_slope=self.leaky_relu_slope)
+             hidden_states = conv1(hidden_states)
+             hidden_states = nn.functional.leaky_relu(hidden_states, negative_slope=self.leaky_relu_slope)
+             hidden_states = conv2(hidden_states)
+             hidden_states = hidden_states + residual
+         return hidden_states
+
+
+ class VitsHifiGan(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.num_kernels = len(config.resblock_kernel_sizes)
+         self.num_upsamples = len(config.upsample_rates)
+         self.conv_pre = nn.Conv1d(
+             config.flow_size,
+             config.upsample_initial_channel,
+             kernel_size=7,
+             stride=1,
+             padding=3,
+         )
+
+         self.upsampler = nn.ModuleList()
+         for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
+             self.upsampler.append(
+                 nn.ConvTranspose1d(
+                     config.upsample_initial_channel // (2**i),
+                     config.upsample_initial_channel // (2 ** (i + 1)),
+                     kernel_size=kernel_size,
+                     stride=upsample_rate,
+                     padding=(kernel_size - upsample_rate) // 2,
+                 )
+             )
+
+         self.resblocks = nn.ModuleList()
+         for i in range(len(self.upsampler)):
+             channels = config.upsample_initial_channel // (2 ** (i + 1))
+             for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
+                 self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation))
+         self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3, bias=False)
+
+     def forward(self, spectrogram):
+         hidden_states = self.conv_pre(spectrogram)
+         for i in range(self.num_upsamples):
+             hidden_states = F.leaky_relu(hidden_states, negative_slope=.1, inplace=True)
+             hidden_states = self.upsampler[i](hidden_states)
+             res_state = self.resblocks[i * self.num_kernels](hidden_states)
+             for j in range(1, self.num_kernels):
+                 res_state += self.resblocks[i * self.num_kernels + j](hidden_states)
+             hidden_states = res_state / self.num_kernels
+         hidden_states = F.leaky_relu(hidden_states, negative_slope=.01, inplace=True)
+         hidden_states = self.conv_post(hidden_states)
+         waveform = torch.tanh(hidden_states)
+         return waveform
+
+
+ class VitsResidualCouplingLayer(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.half_channels = config.flow_size // 2
+         self.conv_pre = nn.Conv1d(self.half_channels, config.hidden_size, 1)
+         self.wavenet = VitsWaveNet(config, num_layers=config.prior_encoder_num_wavenet_layers)
+         self.conv_post = nn.Conv1d(config.hidden_size, self.half_channels, 1)
+
+     def forward(self, x, reverse=False):
+         # only the reverse (inference) direction of the coupling is implemented
+         first_half, second_half = torch.split(x, [self.half_channels] * 2, dim=1)
+         hidden_states = self.conv_pre(first_half)
+         hidden_states = self.wavenet(hidden_states)
+         mean = self.conv_post(hidden_states)
+         second_half = (second_half - mean)
+         outputs = torch.cat([first_half, second_half], dim=1)
+         return outputs
+
+
+ class VitsResidualCouplingBlock(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.flows = nn.ModuleList()
+         for _ in range(config.prior_encoder_num_flows):
+             self.flows.append(VitsResidualCouplingLayer(config))
+
+     def forward(self, x, reverse=False):
+         # x e.g. [1, 192, 481]
+         for flow in reversed(self.flows):
+             x = torch.flip(x, [1])  # flip the channel dimension
+             x = flow(x, reverse=True)
+         return x
+
+
+ class VitsAttention(nn.Module):
+     """has no positional info"""
+
+     def __init__(self, config):
+         super().__init__()
+         self.embed_dim = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = self.embed_dim // self.num_heads
+         self.scaling = self.head_dim**-0.5
+         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
+         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
+         self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
+         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
+
+     def _shape(self, tensor, seq_len, bsz):
+         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+     def forward(
+         self,
+         hidden_states,
+         layer_head_mask=None,
+         output_attentions=False,
+     ):
+         bsz, tgt_len, _ = hidden_states.size()
+
+         # Q
+         query_states = self.q_proj(hidden_states) * self.scaling
+
+         # K/V
+         hidden_states = hidden_states[:, :40, :]  # drop time-frames from k/v [bs*2, time, 96=ch]
+         key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+         value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+         key_states = key_states.view(*proj_shape)
+         value_states = value_states.view(*proj_shape)
+
+         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+         attn_output = torch.bmm(attn_weights, value_states)
+         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+         attn_output = attn_output.transpose(1, 2)
+
+         # Use the `embed_dim` from the config (stored in the class) rather than `hidden_states`,
+         # because `attn_output` can be partitioned across GPUs when using tensor parallelism.
+         attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+         attn_output = self.out_proj(attn_output)
+
+         return attn_output
+
+
+ class VitsFeedForward(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.conv_1 = nn.Conv1d(config.hidden_size, config.ffn_dim, config.ffn_kernel_size, padding=1)
+         self.conv_2 = nn.Conv1d(config.ffn_dim, config.hidden_size, config.ffn_kernel_size, padding=1)
+
+     def forward(self, hidden_states):
+         hidden_states = hidden_states.permute(0, 2, 1)
+         hidden_states = F.relu(self.conv_1(hidden_states))  # an in-place ReLU changes the sound
+         hidden_states = self.conv_2(hidden_states)
+         hidden_states = hidden_states.permute(0, 2, 1)
+         return hidden_states
+
+
+ class VitsEncoderLayer(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.attention = VitsAttention(config)
+         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-5)
+         self.feed_forward = VitsFeedForward(config)
+         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-5)
+
+     def forward(
+         self,
+         hidden_states,
+         output_attentions=False,
+     ):
+         residual = hidden_states
+         hidden_states = self.attention(
+             hidden_states=hidden_states,
+             output_attentions=output_attentions,
+         )
+
+         hidden_states = self.layer_norm(residual + hidden_states)
+
+         residual = hidden_states
+         hidden_states = self.feed_forward(hidden_states)
+
+         hidden_states = self.final_layer_norm(residual + hidden_states)
+
+         outputs = (hidden_states,)
+
+         return outputs
+
+
+ class VitsEncoder(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.layers = nn.ModuleList([VitsEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+
+     def forward(self, hidden_states):
+         for _layer in self.layers:
+             layer_outputs = _layer(hidden_states)
+             hidden_states = layer_outputs[0]
+         return hidden_states
+
+
+ class VitsTextEncoder(nn.Module):
+     """
+     Wraps a VitsEncoder.
+     """
+
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
+         self.encoder = VitsEncoder(config)  # 6 layers of VitsAttention
+         self.project = nn.Conv1d(config.hidden_size, config.flow_size * 2, kernel_size=1)
+
+     def forward(self, input_ids):
+         hidden_states = self.embed_tokens(input_ids) * 4  # actually 4 (or 4.856406460551018 at 845 input ids, deu)
+         stats = self.project(self.encoder(hidden_states=hidden_states).transpose(1, 2)).transpose(1, 2)
+         return stats[:, :, :self.config.flow_size]  # prior_means
+
+
+ class VitsPreTrainedModel(PreTrainedModel):
+     config_class = VitsConfig
+     base_model_prefix = "vits"
+     main_input_name = "input_ids"
+     supports_gradient_checkpointing = True
+
+
+ class VitsModel(VitsPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+         self.text_encoder = VitsTextEncoder(config)  # holds a VitsEncoder with 6 layers of VitsAttention
+         self.flow = VitsResidualCouplingBlock(config)
+         self.decoder = VitsHifiGan(config)
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         speaker_id=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+         labels=None,
+         speed=None,
+         lang_code='deu',  # selects the speed-oscillation pattern per voice/lang
+     ):
+         mask_dtype = self.text_encoder.embed_tokens.weight.dtype
+         if attention_mask is not None:
+             input_padding_mask = attention_mask.unsqueeze(-1).to(mask_dtype)
+         else:
+             raise ValueError("attention_mask is required")
+         prior_means = self.text_encoder(input_ids=input_ids)
+
+         input_padding_mask = input_padding_mask.transpose(1, 2)
+
+         bs, in_len, _ = prior_means.shape
+         # VITS duration oscillation: a fixed per-language pattern instead of a duration predictor
+         pattern = OSCILLATION.get(lang_code, [1, 2, 1])
+
+         duration = torch.tensor(pattern,
+                                 device=prior_means.device).repeat(int(in_len / len(pattern)) + 2)[None, None, :in_len]  # perhaps define [1, 2, 1] per voice or language
+         duration[:, :, 0] = 4
+         duration[:, :, -1] = 3
+         # ATTN
+         predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()
+         indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
+         output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
+         output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)
+         attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
+         batch_size, _, output_length, input_length = attn_mask.shape
+         cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
+         indices = torch.arange(output_length, dtype=duration.dtype, device=duration.device)
+         valid_indices = indices.unsqueeze(0) < cum_duration
+         valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
+         padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
+         attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask
+         attn = attn[:, 0, :, :]
+
+         attn = attn + 1e-4 * torch.rand_like(attn)
+         attn /= attn.sum(2, keepdims=True)
+         # each prior mean is now replicated x its duration;
+         # attn could hold .5/.5 instead of 1/0 to smoothly interpolate repeated prior_means
+         prior_means = torch.matmul(attn, prior_means)
+
+         # prior_means = F.interpolate(prior_means.transpose(1, 2), int(1.74 * prior_means.shape[1]), mode='linear').transpose(1, 2)  # extend for slow speed
+
+         latents = self.flow(prior_means.transpose(1, 2),  # + torch.randn_like(prior_means) * .94,
+                             reverse=True)
+
+         waveform = self.decoder(latents)  # [bs, 1, 16000]
+
+         return waveform[:, 0, :]
+
+
+ class VitsTokenizer(PreTrainedTokenizer):
+     vocab_files_names = {"vocab_file": "vocab.json"}
+     model_input_names = ["input_ids", "attention_mask"]
+
+     def __init__(
+         self,
+         vocab_file,
+         pad_token="<pad>",
+         unk_token="<unk>",
+         language=None,
+         add_blank=True,
+         normalize=True,
+         phonemize=True,
+         is_uroman=False,
+         **kwargs,
+     ):
+         with open(vocab_file, encoding="utf-8") as vocab_handle:
+             self.encoder = json.load(vocab_handle)
+
+         self.decoder = {v: k for k, v in self.encoder.items()}
+         self.language = language
+         self.add_blank = add_blank
+         self.normalize = normalize
+         self.phonemize = phonemize
+
+         self.is_uroman = is_uroman
+
+         super().__init__(
+             pad_token=pad_token,
+             unk_token=unk_token,
+             language=language,
+             add_blank=add_blank,
+             normalize=normalize,
+             phonemize=phonemize,
+             is_uroman=is_uroman,
+             **kwargs,
+         )
+
+     @property
+     def vocab_size(self):
+         return len(self.encoder)
+
+     def get_vocab(self):
+         vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+         vocab.update(self.added_tokens_encoder)
+         return vocab
+
+     def normalize_text(self, input_string):
+         """Lowercase the input string, respecting any special token ids that may be partly or entirely upper-cased."""
+         all_vocabulary = list(self.encoder.keys()) + list(self.added_tokens_encoder.keys())
+         filtered_text = ""
+
+         i = 0
+         while i < len(input_string):
+             found_match = False
+             for word in all_vocabulary:
+                 if input_string[i: i + len(word)] == word:
+                     filtered_text += word
+                     i += len(word)
+                     found_match = True
+                     break
+
+             if not found_match:
+                 filtered_text += input_string[i].lower()
+                 i += 1
+
+         return filtered_text
+
+     def _preprocess_char(self, text):
+         """Special treatment of characters in certain languages"""
+         if self.language == "ron":
+             text = text.replace("ț", "ţ")
+         return text
+
+     def prepare_for_tokenization(
+             self, text: str, is_split_into_words: bool = False, normalize=None, **kwargs):
+
+         normalize = normalize if normalize is not None else self.normalize
+
+         if normalize:
+             # normalise for casing
+             text = self.normalize_text(text)
+
+         filtered_text = self._preprocess_char(text)
+
+         if has_non_roman_characters(filtered_text) and self.is_uroman:
+             # 7 langs - for now all characters are romanized in app.py
+             raise ValueError
+
+         if self.phonemize:
+             if not is_phonemizer_available():
+                 raise ImportError("Please install the `phonemizer` Python package to use this tokenizer.")
+
+             filtered_text = phonemizer.phonemize(
+                 filtered_text,
+                 language="en-us",
+                 backend="espeak",
+                 strip=True,
+                 preserve_punctuation=True,
+                 with_stress=True,
+             )
+             filtered_text = re.sub(r"\s+", " ", filtered_text)
+         elif normalize:
+             # strip any chars outside of the vocab (punctuation)
+             filtered_text = "".join(list(filter(lambda char: char in self.encoder, filtered_text))).strip()
+
+         return filtered_text, kwargs
+
+     def _tokenize(self, text):
+         """Tokenize a string by inserting the `<pad>` token at the boundary between adjacent characters."""
+         tokens = list(text)
+
+         if self.add_blank:
+             # sounds dyslexic with no space between letters,
+             # disconnected with more than 2 spaces between letters
+             interspersed = [self._convert_id_to_token(0)] * (len(tokens) * 2)  # len * 2 + 1 raises a slice-index error if tokens is odd
+             interspersed[::2] = tokens
+             tokens = interspersed + [self._convert_id_to_token(0)]  # append one last blank
+
+         return tokens
+
+     def _convert_token_to_id(self, token):
+         """Converts a token (str) into an id using the vocab."""
+         return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+     def _convert_id_to_token(self, index):
+         """Converts an index (integer) into a token (str) using the vocab."""
+         return self.decoder.get(index)
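
VitsModel.forward replaces the usual learned duration predictor with the fixed per-language OSCILLATION pattern: it is tiled to the input length, and the first and last tokens are held longer (4 and 3 frames) to pad the utterance boundaries. A small sketch of just that tiling step, mirroring the lines above with standalone numbers:

    import torch

    pattern = [2, 2, 1, 2, 2]  # OSCILLATION['rmc-script_latin']
    in_len = 8
    duration = torch.tensor(pattern).repeat(in_len // len(pattern) + 2)[None, None, :in_len]
    duration[:, :, 0] = 4   # hold the first token longer
    duration[:, :, -1] = 3  # and the last one
    print(duration)             # tensor([[[4, 2, 1, 2, 2, 2, 2, 3]]])
    print(int(duration.sum()))  # 18 output frames feed the flow and the HiFi-GAN decoder
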
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ omegaconf
+ nltk
+ librosa
+ phonemizer
+ audiofile
+ num2words
+ numpy<2.0.0
+ gradio==5.27.0
+ Numbers2Words-Greek
+ einops
+ torch
+ pydantic==2.10.6
+ transformers==4.49.0
+ sentencepiece
textual.py ADDED
@@ -0,0 +1,536 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import unicodedata
3
+ from num2words import num2words
4
+ from num2word_greek.numbers2words import convert_numbers
5
+
6
+ def only_greek_or_only_latin(text, lang='grc'):
7
+ '''
8
+ str: The converted string in the specified target script.
9
+ Characters not found in any mapping are preserved as is.
10
+ Latin accented characters in the input (e.g., 'É', 'ü') will
11
+ be preserved in their lowercase form (e.g., 'é', 'ü') if
12
+ converting to Latin.
13
+ '''
14
+
15
+ # --- Mapping Dictionaries ---
16
+ # Keys are in lowercase as input text is case-folded.
17
+ # If the output needs to maintain original casing, additional logic is required.
18
+
19
+ latin_to_greek_map = {
20
+ 'a': 'α', 'b': 'β', 'g': 'γ', 'd': 'δ', 'e': 'ε',
21
+ 'ch': 'τσο', # Example of a multi-character Latin sequence
22
+ 'z': 'ζ', 'h': 'χ', 'i': 'ι', 'k': 'κ', 'l': 'λ',
23
+ 'm': 'μ', 'n': 'ν', 'x': 'ξ', 'o': 'ο', 'p': 'π',
24
+ 'v': 'β', 'sc': 'σκ', 'r': 'ρ', 's': 'σ', 't': 'τ',
25
+ 'u': 'ου', 'f': 'φ', 'c': 'σ', 'w': 'β', 'y': 'γ',
26
+ }
27
+
28
+ greek_to_latin_map = {
29
+ 'ου': 'ou', # Prioritize common diphthongs/digraphs
30
+ 'α': 'a', 'β': 'v', 'γ': 'g', 'δ': 'd', 'ε': 'e',
31
+ 'ζ': 'z', 'η': 'i', 'θ': 'th', 'ι': 'i', 'κ': 'k',
32
+ 'λ': 'l', 'μ': 'm', 'ν': 'n', 'ξ': 'x', 'ο': 'o',
33
+ 'π': 'p', 'ρ': 'r', 'σ': 's', 'τ': 't', 'υ': 'y', # 'y' is a common transliteration for upsilon
34
+ 'φ': 'f', 'χ': 'ch', 'ψ': 'ps', 'ω': 'o',
35
+ 'ς': 's', # Final sigma
36
+ }
37
+
38
+ cyrillic_to_latin_map = {
39
+ 'а': 'a', 'б': 'b', 'в': 'v', 'г': 'g', 'д': 'd', 'е': 'e', 'ё': 'yo', 'ж': 'zh',
40
+ 'з': 'z', 'и': 'i', 'й': 'y', 'к': 'k', 'л': 'l', 'м': 'm', 'н': 'n', 'о': 'o',
41
+ 'п': 'p', 'р': 'r', 'с': 's', 'т': 't', 'у': 'u', 'ф': 'f', 'х': 'kh', 'ц': 'ts',
42
+ 'ч': 'ch', 'ш': 'sh', 'щ': 'shch', 'ъ': '', 'ы': 'y', 'ь': '', 'э': 'e', 'ю': 'yu',
43
+ 'я': 'ya',
44
+ }
45
+
46
+ # Direct Cyrillic to Greek mapping based on phonetic similarity.
47
+ # These are approximations and may not be universally accepted transliterations.
48
+ cyrillic_to_greek_map = {
49
+ 'а': 'α', 'б': 'β', 'в': 'β', 'г': 'γ', 'д': 'δ', 'е': 'ε', 'ё': 'ιο', 'ж': 'ζ',
50
+ 'з': 'ζ', 'и': 'ι', 'й': 'ι', 'κ': 'κ', 'λ': 'λ', 'м': 'μ', 'н': 'ν', 'о': 'ο',
51
+ 'π': 'π', 'ρ': 'ρ', 'σ': 'σ', 'τ': 'τ', 'у': 'ου', 'ф': 'φ', 'х': 'χ', 'ц': 'τσ',
52
+ 'ч': 'τσ', # or τζ depending on desired sound
53
+ 'ш': 'σ', 'щ': 'σ', # approximations
54
+ 'ъ': '', 'ы': 'ι', 'ь': '', 'э': 'ε', 'ю': 'ιου',
55
+ 'я': 'ια',
56
+ }
57
+
58
+ # Convert the input text to lowercase, preserving accents for Latin characters.
59
+ # casefold() is used for more robust caseless matching across Unicode characters.
60
+ lowercased_text = text.lower() #casefold()
61
+ output_chars = []
62
+ current_index = 0
63
+
64
+ if lang == 'grc':
65
+ # Combine all relevant maps for direct lookup to Greek
66
+ conversion_map = {**latin_to_greek_map, **cyrillic_to_greek_map}
67
+
68
+ # Sort keys by length in reverse order to handle multi-character sequences first
69
+ sorted_source_keys = sorted(
70
+ list(latin_to_greek_map.keys()) + list(cyrillic_to_greek_map.keys()),
71
+ key=len,
72
+ reverse=True
73
+ )
74
+
75
+ while current_index < len(lowercased_text):
76
+ found_conversion = False
77
+ for key in sorted_source_keys:
78
+ if lowercased_text.startswith(key, current_index):
79
+ output_chars.append(conversion_map[key])
80
+ current_index += len(key)
81
+ found_conversion = True
82
+ break
83
+ if not found_conversion:
84
+ # If no specific mapping found, append the character as is.
85
+ # This handles unmapped characters and already Greek characters.
86
+ output_chars.append(lowercased_text[current_index])
87
+ current_index += 1
88
+ return ''.join(output_chars)
89
+
90
+ else: # Default to 'lat' conversion
91
+ # Combine Greek to Latin and Cyrillic to Latin maps.
92
+ # Cyrillic map keys will take precedence in case of overlap if defined after Greek.
93
+ combined_to_latin_map = {**greek_to_latin_map, **cyrillic_to_latin_map}
94
+
95
+ # Sort all relevant source keys by length in reverse for replacement
96
+ sorted_source_keys = sorted(
97
+ list(greek_to_latin_map.keys()) + list(cyrillic_to_latin_map.keys()),
98
+ key=len,
99
+ reverse=True
100
+ )
101
+
102
+ while current_index < len(lowercased_text):
103
+ found_conversion = False
104
+ for key in sorted_source_keys:
105
+ if lowercased_text.startswith(key, current_index):
106
+ latin_equivalent = combined_to_latin_map[key]
107
+
108
+ # Strip accents ONLY if the source character was from the Greek map.
109
+ # This preserves accents on original Latin characters (like 'é')
110
+ # and allows for intentional accent stripping from Greek transliterations.
111
+ if key in greek_to_latin_map:
112
+ normalized_latin = unicodedata.normalize('NFD', latin_equivalent)
113
+ stripped_latin = ''.join(c for c in normalized_latin if not unicodedata.combining(c))
114
+ output_chars.append(stripped_latin)
115
+ else:
116
+ output_chars.append(latin_equivalent)
117
+
118
+ current_index += len(key)
119
+ found_conversion = True
120
+ break
121
+
122
+ if not found_conversion:
123
+ # If no conversion happened from Greek or Cyrillic, append the character as is.
124
+ # This preserves existing Latin characters (including accented ones from input),
125
+ # numbers, punctuation, and other symbols.
126
+ output_chars.append(lowercased_text[current_index])
127
+ current_index += 1
128
+
129
+ return ''.join(output_chars)
130
+
131
+
132
+ # =====================================================
133
+ #
134
+
135
+ def fix_vocals(text, lang='ron'):
136
+
137
+ # Longer phrases should come before shorter ones to prevent partial matches.
138
+
139
+ ron_replacements = {
140
+ 'ţ': 'ț',
141
+ 'ț': 'ts',
142
+ 'î': 'u',
143
+ 'â': 'a',
144
+ 'ş': 's',
145
+ 'w': 'oui',
146
+ 'k': 'c',
147
+ 'l': 'll',
148
+ # Math symbols
149
+ 'sqrt': ' rădăcina pătrată din ',
150
+ '^': ' la puterea ',
151
+ '+': ' plus ',
152
+ ' - ': ' minus ', # only replace if standalone so to not say minus if is a-b-c
153
+ '*': ' ori ', # times
154
+ '/': ' împărțit la ', # divided by
155
+ '=': ' egal cu ', # equals
156
+ 'pi': ' pi ',
157
+ '<': ' mai mic decât ',
158
+ '>': ' mai mare decât',
159
+ '%': ' la sută ', # percent (from previous)
160
+ '(': ' paranteză deschisă ',
161
+ ')': ' paranteză închisă ',
162
+ '[': ' paranteză pătrată deschisă ',
163
+ ']': ' paranteză pătrată închisă ',
164
+ '{': ' acoladă deschisă ',
165
+ '}': ' acoladă închisă ',
166
+ '≠': ' nu este egal cu ',
167
+ '≤': ' mai mic sau egal cu ',
168
+ '≥': ' mai mare sau egal cu ',
169
+ '≈': ' aproximativ ',
170
+ '∞': ' infinit ',
171
+ '€': ' euro ',
172
+ '$': ' dolar ',
173
+ '£': ' liră ',
174
+ '&': ' și ', # and
175
+ '@': ' la ', # at
176
+ '#': ' diez ', # hash
177
+ '∑': ' sumă ',
178
+ '∫': ' integrală ',
179
+ '√': ' rădăcina pătrată a ', # more generic square root
180
+ }
181
+
182
+ eng_replacements = {
183
+ 'wik': 'weaky',
184
+ 'sh': 'ss',
185
+ 'ch': 'ttss',
186
+ 'oo': 'oeo',
187
+ # Math symbols for English
188
+ 'sqrt': ' square root of ',
189
+ '^': ' to the power of ',
190
+ '+': ' plus ',
191
+ ' - ': ' minus ',
192
+ '*': ' times ',
193
+ ' / ': ' divided by ',
194
+ '=': ' equals ',
195
+ 'pi': ' pi ',
196
+ '<': ' less than ',
197
+ '>': ' greater than ',
198
+ # Additional common math symbols from previous list
199
+ '%': ' percent ',
200
+ '(': ' open parenthesis ',
201
+ ')': ' close parenthesis ',
202
+ '[': ' open bracket ',
203
+ ']': ' close bracket ',
204
+ '{': ' open curly brace ',
205
+ '}': ' close curly brace ',
206
+ '∑': ' sum ',
207
+ '∫': ' integral ',
208
+ '√': ' square root of ',
209
+ '≠': ' not equals ',
210
+ '≤': ' less than or equals ',
211
+ '≥': ' greater than or equals ',
212
+ '≈': ' approximately ',
213
+ '∞': ' infinity ',
214
+ '€': ' euro ',
215
+ '$': ' dollar ',
216
+ '£': ' pound ',
217
+ '&': ' and ',
218
+ '@': ' at ',
219
+ '#': ' hash ',
220
+ }
221
+
222
+ serbian_replacements = {
223
+ 'rn': 'rrn',
224
+ 'ć': 'č',
225
+ 'c': 'č',
226
+ 'đ': 'd',
227
+ 'j': 'i',
228
+ 'l': 'lll',
229
+ 'w': 'v',
230
+ # https://huggingface.co/facebook/mms-tts-rmc-script_latin
231
+ 'sqrt': 'kvadratni koren iz',
232
+ '^': ' na stepen ',
233
+ '+': ' plus ',
234
+ ' - ': ' minus ',
235
+ '*': ' puta ',
236
+ ' / ': ' podeljeno sa ',
237
+ '=': ' jednako ',
238
+ 'pi': ' pi ',
239
+ '<': ' manje od ',
240
+ '>': ' veće od ',
241
+ '%': ' procenat ',
242
+ '(': ' otvorena zagrada ',
243
+ ')': ' zatvorena zagrada ',
244
+ '[': ' otvorena uglasta zagrada ',
245
+ ']': ' zatvorena uglasta zagrada ',
246
+ '{': ' otvorena vitičasta zagrada ',
247
+ '}': ' zatvorena vitičasta zagrada ',
248
+ '∑': ' suma ',
249
+ '∫': ' integral ',
250
+ '√': ' kvadratni koren ',
251
+ '≠': ' nije jednako ',
252
+ '≤': ' manje ili jednako od ',
253
+ '≥': ' veće ili jednako od ',
254
+ '≈': ' približno ',
255
+ '∞': ' beskonačnost ',
256
+ '€': ' evro ',
257
+ '$': ' dolar ',
258
+ '£': ' funta ',
259
+ '&': ' i ',
260
+ '@': ' et ',
261
+ '#': ' taraba ',
262
+ # Others
263
+ # 'rn': 'rrn',
264
+ # 'ć': 'č',
265
+ # 'c': 'č',
266
+ # 'đ': 'd',
267
+ # 'l': 'le',
268
+ # 'ij': 'i',
269
+ # 'ji': 'i',
270
+ # 'j': 'i',
271
+ # 'služ': 'sloooozz', # 'službeno'
272
+ # 'suver': 'siuveeerra', # 'suverena'
273
+ # 'država': 'dirrezav', # 'država'
274
+ # 'iči': 'ici', # 'Graniči'
275
+ # 's ': 'se', # a s with space
276
+ # 'q': 'ku',
277
+ # 'w': 'aou',
278
+ # 'z': 's',
279
+ # "š": "s",
280
+ # 'th': 'ta',
281
+ # 'v': 'vv',
282
+ # "ć": "č",
283
+ # "đ": "ď",
284
+ # "lj": "ľ",
285
+ # "nj": "ň",
286
+ # "ž": "z",
287
+ # "c": "č"
288
+ }
289
+
290
+ deu_replacements = {
291
+ 'sch': 'sh',
292
+ 'ch': 'kh',
293
+ 'ie': 'ee',
294
+ 'ei': 'ai',
295
+ 'ä': 'ae',
296
+ 'ö': 'oe',
297
+ 'ü': 'ue',
298
+ 'ß': 'ss',
299
+ # Math symbols for German
300
+ 'sqrt': ' Quadratwurzel aus ',
301
+ '^': ' hoch ',
302
+ '+': ' plus ',
303
+ ' - ': ' minus ',
304
+ '*': ' mal ',
305
+ ' / ': ' geteilt durch ',
306
+ '=': ' gleich ',
307
+ 'pi': ' pi ',
308
+ '<': ' kleiner als ',
309
+ '>': ' größer als',
310
+ # Additional common math symbols from previous list
311
+ '%': ' prozent ',
312
+        '(': ' Klammer auf ',
+        ')': ' Klammer zu ',
+        '[': ' eckige Klammer auf ',
+        ']': ' eckige Klammer zu ',
+        '{': ' geschweifte Klammer auf ',
+        '}': ' geschweifte Klammer zu ',
+        '∑': ' Summe ',
+        '∫': ' Integral ',
+        '√': ' Quadratwurzel ',
+        '≠': ' ungleich ',
+        '≤': ' kleiner oder gleich ',
+        '≥': ' größer oder gleich ',
+        '≈': ' ungefähr ',
+        '∞': ' unendlich ',
+        '€': ' euro ',
+        '$': ' dollar ',
+        '£': ' pfund ',
+        '&': ' und ',
+        '@': ' at ',  # 'Klammeraffe' is also common, but 'at' is simpler
+        '#': ' raute ',
+    }
+
+    fra_replacements = {
+        # French-specific phonetic replacements (add as needed),
+        # e.g. 'ç': 's', 'é': 'e', etc.
+        'w': 'v',
+        # Math symbols for French
+        'sqrt': ' racine carrée de ',
+        '^': ' à la puissance ',
+        '+': ' plus ',
+        ' - ': ' moins ',  # 'tiret'
+        '*': ' fois ',
+        ' / ': ' divisé par ',
+        '=': ' égale ',
+        'pi': ' pi ',
+        '<': ' inférieur à ',
+        '>': ' supérieur à ',
+        # Add more common math symbols as needed for French
+        '%': ' pour cent ',
+        '(': ' parenthèse ouverte ',
+        ')': ' parenthèse fermée ',
+        '[': ' crochet ouvert ',
+        ']': ' crochet fermé ',
+        '{': ' accolade ouverte ',
+        '}': ' accolade fermée ',
+        '∑': ' somme ',
+        '∫': ' intégrale ',
+        '√': ' racine carrée ',
+        '≠': " n'égale pas ",
+        '≤': ' inférieur ou égal à ',
+        '≥': ' supérieur ou égal à ',
+        '≈': ' approximativement ',
+        '∞': ' infini ',
+        '€': ' euro ',
+        '$': ' dollar ',
+        '£': ' livre ',
+        '&': ' et ',
+        '@': ' arobase ',
+        '#': ' dièse ',
+    }
+
+    hun_replacements = {
+        # Hungarian-specific phonetic replacements (add as needed),
+        # e.g. 'á': 'a', 'é': 'e', etc.
+        'ch': 'ts',
+        'cs': 'tz',
+        'g': 'gk',
+        'w': 'v',
+        'z': 'zz',
+        # Math symbols for Hungarian
+        'sqrt': ' négyzetgyök ',
+        '^': ' hatvány ',
+        '+': ' plusz ',
+        ' - ': ' mínusz ',
+        '*': ' szorozva ',
+        ' / ': ' osztva ',
+        '=': ' egyenlő ',
+        'pi': ' pi ',
+        '<': ' kisebb mint ',
+        '>': ' nagyobb mint ',
+        # Add more common math symbols as needed for Hungarian
+        '%': ' százalék ',
+        '(': ' nyitó zárójel ',
+        ')': ' záró zárójel ',
+        '[': ' nyitó szögletes zárójel ',
+        ']': ' záró szögletes zárójel ',
+        '{': ' nyitó kapcsos zárójel ',
+        '}': ' záró kapcsos zárójel ',
+        '∑': ' szumma ',
+        '∫': ' integrál ',
+        '√': ' négyzetgyök ',
+        '≠': ' nem egyenlő ',
+        '≤': ' kisebb vagy egyenlő ',
+        '≥': ' nagyobb vagy egyenlő ',
+        '≈': ' körülbelül ',
+        '∞': ' végtelen ',
+        '€': ' euró ',
+        '$': ' dollár ',
+        '£': ' font ',
+        '&': ' és ',
+        '@': ' kukac ',
+        '#': ' kettőskereszt ',
+    }
+
+    grc_replacements = {
+        # Ancient Greek specific phonetic replacements (add as needed).
+        # These are more about transliterating Greek letters if they are in the input text.
+        # Math symbols for Ancient Greek (literal translations)
+        'sqrt': ' τετραγωνικὴ ῥίζα ',
+        '^': ' εἰς τὴν δύναμιν ',
+        '+': ' σὺν ',
+        ' - ': ' χωρὶς ',
+        '*': ' πολλάκις ',
+        ' / ': ' διαιρέω ',
+        '=': ' ἴσον ',
+        'pi': ' πῖ ',
+        '<': ' ἔλαττον ',
+        '>': ' μεῖζον ',
+        # Add more common math symbols as needed for Ancient Greek
+        '%': ' τοῖς ἑκατόν ',  # tois hekaton - 'of the hundred'
+        '(': ' ἀνοικτὴ παρένθεσις ',
+        ')': ' κλειστὴ παρένθεσις ',
+        '[': ' ἀνοικτὴ ἀγκύλη ',
+        ']': ' κλειστὴ ἀγκύλη ',
+        '{': ' ἀνοικτὴ σγουρὴ ἀγκύλη ',
+        '}': ' κλειστὴ σγουρὴ ἀγκύλη ',
+        '∑': ' ἄθροισμα ',
+        '∫': ' ὁλοκλήρωμα ',
+        '√': ' τετραγωνικὴ ῥίζα ',
+        '≠': ' οὐκ ἴσον ',
+        '≤': ' ἔλαττον ἢ ἴσον ',
+        '≥': ' μεῖζον ἢ ἴσον ',
+        '≈': ' περίπου ',
+        '∞': ' ἄπειρον ',
+        '€': ' εὐρώ ',
+        '$': ' δολάριον ',
+        '£': ' λίρα ',
+        '&': ' καὶ ',
+        '@': ' ἀτ ',  # at
+        '#': ' δίεση ',  # hash
+    }
+
+
+    # Select the appropriate replacement dictionary based on the language
+    replacements_map = {
+        'grc': grc_replacements,
+        'ron': ron_replacements,
+        'eng': eng_replacements,
+        'deu': deu_replacements,
+        'fra': fra_replacements,
+        'hun': hun_replacements,
+        'rmc-script_latin': serbian_replacements,
+    }
+
+    current_replacements = replacements_map.get(lang)
+    if current_replacements:
+        # Sort replacements by key length in descending order. This is crucial for
+        # replacing multi-character strings (like 'sqrt' or 'sch') before their
+        # shorter substrings ('s', 'ch', 'q', 'r', 't') - see the sketch below.
+        sorted_replacements = sorted(current_replacements.items(), key=lambda item: len(item[0]), reverse=True)
+        for old, new in sorted_replacements:
+            text = text.replace(old, new)
+        return text
+    else:
+        # If the language is not supported, return the original text
+        print(f"Warning: Language '{lang}' not supported for text replacement. Returning original text.")
+        return text
+
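A minimal sketch of why the longest-first ordering matters, using an illustrative two-entry table (not part of the committed code):

demo = {'sqrt': ' Quadratwurzel ', 's': 'ss'}
text = 'sqrt 4'
for old, new in sorted(demo.items(), key=lambda kv: len(kv[0]), reverse=True):
    text = text.replace(old, new)
print(text)  # ' Quadratwurzel  4'; replacing 's' first would turn 'sqrt' into 'ssqrt' and leave a stray 's'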
+def _num2words(text='01234', lang=None):
+    if lang == 'grc':
+        return convert_numbers(text)
+    return num2words(text, lang=lang)  # HAS TO BE kwarg lang=lang
+
+
+def transliterate_number(number_string,
+                         lang=None):
+    if lang == 'rmc-script_latin':
+        lang = 'sr'
+        exponential_pronoun = ' puta deset na stepen od '
+        comma = ' tačka '
+    elif lang == 'ron':
+        lang = 'ro'
+        exponential_pronoun = ' ori zece la puterea '
+        comma = ' virgulă '
+    elif lang == 'hun':
+        lang = 'hu'
+        exponential_pronoun = ' tízszer a erejéig '
+        comma = ' virgula '
+    elif lang == 'deu':
+        lang = 'de'  # num2words expects the ISO 639-1 code
+        exponential_pronoun = ' mal zehn hoch '
+        comma = ' komma '
+    elif lang == 'fra':
+        lang = 'fr'
+        exponential_pronoun = ' puissance '
+        comma = ' virgule '
+    elif lang == 'grc':
+        exponential_pronoun = ' εις την δυναμην του '
+        comma = ' κόμμα '
+    else:
+        lang = lang[:2]
+        exponential_pronoun = ' times ten to the power of '
+        comma = ' point '
+
+    def replace_number(match):
+        prefix = match.group(1) or ""
+        number_part = match.group(2)
+        suffix = match.group(5) or ""
+
+        try:
+            if 'e' in number_part.lower():
+                base, exponent = number_part.lower().split('e')
+                words = _num2words(base, lang=lang) + exponential_pronoun + _num2words(exponent, lang=lang)
+            elif '.' in number_part:
+                integer_part, decimal_part = number_part.split('.')
+                words = _num2words(integer_part, lang=lang) + comma + " ".join(
+                    [_num2words(digit, lang=lang) for digit in decimal_part])
+            else:
+                words = _num2words(number_part, lang=lang)
+            return prefix + words + suffix
+        except ValueError:
+            return match.group(0)  # Return original if conversion fails
+
+    pattern = r'([^\d]*)(\d+(\.\d+)?([Ee][+-]?\d+)?)([^\d]*)'
+    return re.sub(pattern, replace_number, number_string)
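A quick behavioural check (illustrative; the exact wording comes from whichever num2words version is installed):

print(transliterate_number('seite 3.5', lang='deu'))
# seite drei komma fünf
print(transliterate_number('room 3.5e2', lang='eng'))
# room three point five times ten to the power of two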
tts.py ADDED
@@ -0,0 +1,847 @@
+import torch
+import nltk
+nltk.download('punkt', download_dir='./')      # COMMENT IF DOWNLOADED
+nltk.download('punkt_tab', download_dir='./')  # COMMENT IF DOWNLOADED
+nltk.data.path.append('.')
+import librosa
+import audiofile
+import torch.nn.functional as F
+import math
+import numpy as np
+import torch.nn as nn
+import string
+import textwrap
+import phonemizer
+from espeak_util import set_espeak_library
+from transformers import AlbertConfig, AlbertModel
+from huggingface_hub import hf_hub_download
+from nltk.tokenize import word_tokenize
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils import spectral_norm
+
+_pad = "$"
+_punctuation = ';:,.!?¡¿—…"«»“” '
+_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
+_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
+MAX_PHONEMES = 424  # to avoid OOM: max length of a single (non-split) sentence for StyleTTS2 inference
+
+symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
+
+dicts = {}
+for i in range(len(symbols)):
+    dicts[symbols[i]] = i
+
+
+class TextCleaner:
+    def __init__(self, dummy=None):
+        self.word_index_dictionary = dicts
+        print(len(dicts))
+
+    def __call__(self, text):
+        indexes = []
+        for char in text:
+            try:
+                indexes.append(self.word_index_dictionary[char])
+            except KeyError:
+                # char is not in `symbols` (control chars, digits, most ASCII punctuation) - drop it
+                # print(f'NonVOCAL {char}', end='\r')
+                pass
+        return indexes
+
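A small illustration of the cleaner's behaviour (index values follow from the `symbols` list above; the constructor prints the dictionary size):

cleaner = TextCleaner()
print(cleaner('ha!'))   # [50, 43, 5]  ('$'=0, punctuation starts at 1, so '!'=5; 'a'=43, 'h'=50)
print(cleaner('h1a'))   # [50, 43]     the digit is silently dropped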
+set_espeak_library()
+
+textclenaer = TextCleaner()
+
+global_phonemizer = phonemizer.backend.EspeakBackend(language="en-us", preserve_punctuation=True, with_stress=True)
+
+
+def _del_prefix(d):
+    # drop the leading 'module.' prefix (7 chars) from checkpoint keys
+    out = {}
+    for k, v in d.items():
+        out[k[7:]] = v
+    return out
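For example (illustrative key):

print(_del_prefix({'module.fc.weight': 0}))   # {'fc.weight': 0}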
+
+
+
+
+class StyleTTS2(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        albert_base_configuration = AlbertConfig(vocab_size=178,
+                                                 hidden_size=768,
+                                                 num_attention_heads=12,
+                                                 intermediate_size=2048,
+                                                 max_position_embeddings=512,
+                                                 num_hidden_layers=12,
+                                                 dropout=0.1)
+        self.bert = AlbertModel(albert_base_configuration)
+        state_dict = torch.load(hf_hub_download(repo_id='dkounadis/artificial-styletts2',
+                                                filename='Utils/PLBERT/step_1000000.pth'),
+                                map_location='cpu')['net']
+        new_state_dict = {}
+        for k, v in state_dict.items():
+            name = k[7:]  # remove `module.`
+            if name.startswith('encoder.'):
+                name = name[8:]  # remove `encoder.`
+            new_state_dict[name] = v
+        del new_state_dict["embeddings.position_ids"]
+        self.bert.load_state_dict(new_state_dict, strict=True)
+        self.decoder = Decoder(dim_in=512,
+                               style_dim=128,
+                               dim_out=80,  # n_mels
+                               resblock_kernel_sizes=[3, 7, 11],
+                               upsample_rates=[10, 5, 3, 2],
+                               upsample_initial_channel=512,
+                               resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+                               upsample_kernel_sizes=[20, 10, 6, 4])
+        self.text_encoder = TextEncoder(channels=512,
+                                        kernel_size=5,
+                                        depth=3,  # args['model_params']['n_layer'],
+                                        n_symbols=178,  # args['model_params']['n_token']
+                                        )
+        self.predictor = ProsodyPredictor(style_dim=128,
+                                          d_hid=512,
+                                          nlayers=3,  # OFFICIAL config.nlayers=5;
+                                          max_dur=50)
+        self.style_encoder = StyleEncoder()
+        self.predictor_encoder = StyleEncoder()
+        self.bert_encoder = torch.nn.Linear(self.bert.config.hidden_size, 512)
+        self.mel_spec = MelSpec()
+        params = torch.load(hf_hub_download(repo_id='yl4579/StyleTTS2-LibriTTS',
+                                            filename='Models/LibriTTS/epochs_2nd_00020.pth'),
+                            map_location='cpu')['net']
+        self.bert.load_state_dict(_del_prefix(params['bert']), strict=True)
+        self.bert_encoder.load_state_dict(_del_prefix(params['bert_encoder']), strict=True)
+        self.predictor.load_state_dict(_del_prefix(params['predictor']), strict=True)
+        self.decoder.load_state_dict(_del_prefix(params['decoder']), strict=True)
+        self.text_encoder.load_state_dict(_del_prefix(params['text_encoder']), strict=True)
+        self.predictor_encoder.load_state_dict(_del_prefix(params['predictor_encoder']), strict=True)
+        self.style_encoder.load_state_dict(_del_prefix(params['style_encoder']), strict=True)
+
+        # FOR LSTM: freeze all parameters - inference only
+        for n, p in self.named_parameters():
+            p.requires_grad = False
+        self.eval()
+
+
+    def device(self):
+        return self.style_encoder.unshared.weight.device
+
+    def compute_style(self, wav_file=None):
+
+        x, sr = librosa.load(wav_file, sr=24000)
+        x, _ = librosa.effects.trim(x, top_db=30)
+        if sr != 24000:
+            x = librosa.resample(x, orig_sr=sr, target_sr=24000)
+        # LOGMEL - MelSpec has a ~16 kHz default basis but is called on the 24 kHz .wav
+        x = torch.from_numpy(x[None, :]).to(device=self.device(),
+                                            dtype=torch.float)
+        mel_tensor = (torch.log(1e-5 + self.mel_spec(x)) + 4) / 4
+        # mel_tensor = preprocess(audio).to(device)
+        ref_s = self.style_encoder(mel_tensor)      # acoustic style
+        ref_p = self.predictor_encoder(mel_tensor)  # prosodic style  [bs, 1, 1, 128]
+        s = torch.cat([ref_s, ref_p], dim=3)        # [bs, 1, 1, 256]
+        s = s[:, :, 0, :].transpose(1, 2)
+        return s  # [1, 256, 1] (time axis > 1 if the mean in StyleEncoder.forward is removed)
+
+    def inference(self,
+                  text,
+                  ref_s=None):
+        '''text may become too long when phonemized'''
+
+        if isinstance(ref_s, str):
+            ref_s = self.compute_style(ref_s)
+        else:
+            pass  # assume ref_s is a precomputed style vector
+
+        # text = transliterate_number(text, lang='en').strip()
+        # As we are in English, transliteration is already done by the text cleaner?
+        # Somehow we have phonemes in text that get re-phonemized;
+        # the dataset text should be ASCII only.
+
+        if isinstance(text, str):
+
+            _translator = str.maketrans('', '', string.punctuation)
+
+            text = [sub_sent.translate(_translator) + '.' for sub_sent in textwrap.wrap(text, 74)]
+
+            # text = nltk.sent_tokenize(text)
+            # text = [i for sent in sentences for i in textwrap.wrap(sent, width=120)]
+            # text = textwrap.wrap(text, width=MAX_PHONEMES)  # phonemes, thus sent_tokenize() can't split them into sentences
+
+        device = ref_s.device
+        total = []
+        for _t in text:
+
+            _t = global_phonemizer.phonemize([_t])
+            _t = word_tokenize(_t[0])
+            _t = ' '.join(_t)
+
+            # Truncate phonemes if len(phonemes) > MAX_PHONEMES (OOM during training), then
+            # append '.' to assure proper sound termination (pulse issue); textclenaer('.;?!') = [4, 1, 6, 5]
+            tokens = textclenaer(_t)[:MAX_PHONEMES] + [4]
+            tokens.insert(0, 0)
+            tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
+            with torch.no_grad():
+                hidden_states = self.text_encoder(tokens)
+                bert_dur = self.bert(tokens, attention_mask=torch.ones_like(tokens)
+                                     ).last_hidden_state
+                d_en = self.bert_encoder(bert_dur).transpose(-1, -2)
+                aln_trg, F0_pred, N_pred = self.predictor(d_en=d_en, s=ref_s[:, 128:, :])
+                asr = torch.bmm(aln_trg, hidden_states)
+                asr = asr.transpose(1, 2)
+                asr_new = torch.zeros_like(asr)
+                asr_new[:, :, 0] = asr[:, :, 0]
+                asr_new[:, :, 1:] = asr[:, :, 0:-1]
+                asr = asr_new
+                x = self.decoder(asr=asr,
+                                 F0_curve=F0_pred,
+                                 N=N_pred,
+                                 s=ref_s[:, :128, :])  # different part of ref_s
+            if x.shape[2] < 100:
+                x = torch.zeros(1, 1, 1000, device=self.device())  # silence if this sentence was empty
+
+            # Normalise / crop the scratch at the end (the ending-scratch sound is not
+            # solved even with nltk sentence split & punctuation)
+            x = x[..., 40:-4000]
+            # x /= x.abs().max() + 1e-7  # preserve as torch
+            if x.shape[2] == 0:
+                # nothing to vocode
+                x = torch.zeros(1, 1, 1000, device=self.device())
+            total.append(x)
+
+        # --
+        total = 1.94 * torch.cat(total, 2)  # 1.94x - perhaps exceeding [-1, 1] affects MIMI encode
+        total /= 1.02 * total.abs().max() + 1e-7
+        # --
+        return total
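A minimal end-to-end sketch of how the two entry points combine, assuming the checkpoints download successfully (the reference wav is one of the files added in this commit):

model = StyleTTS2()
style = model.compute_style('wav/en_US_cmu_arctic_aew.wav')     # reference voice
wav = model.inference('Hello world, this is a test.', ref_s=style)
audiofile.write('out.wav', wav[0].cpu().numpy(), 24000)         # decoder output is 24 kHz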
+
+
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+def _tile(x,
+          length=None):
+    x = x.repeat(1, 1, int(length / x.shape[2]) + 1)[:, :, :length]
+    return x
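For instance, tiling a 1-step style code along time (illustrative shapes):

s = torch.ones(1, 256, 1)
print(_tile(s, length=7).shape)   # torch.Size([1, 256, 7]) - the style is repeated to cover all frames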
+
+
+class AdaIN1d(nn.Module):
+
+    # used by HiFiGan & ProsodyPredictor
+
+    def __init__(self, style_dim, num_features):
+        super().__init__()
+        self.norm = nn.InstanceNorm1d(num_features, affine=False)
+        self.fc = nn.Linear(style_dim, num_features * 2)
+
+    def forward(self, x, s):
+
+        # x = torch.Size([1, 512, 248]) - same as output
+        # s = torch.Size([1, 7, 1, 128])
+
+        s = self.fc(s.transpose(1, 2)).transpose(1, 2)
+
+        s = _tile(s, length=x.shape[2])
+
+        gamma, beta = torch.chunk(s, chunks=2, dim=1)
+        return (1 + gamma) * self.norm(x) + beta
+
+
+class AdaINResBlock1(torch.nn.Module):
+    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
+        super(AdaINResBlock1, self).__init__()
+        self.convs1 = nn.ModuleList([
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+                               padding=get_padding(kernel_size, dilation[0]))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+                               padding=get_padding(kernel_size, dilation[1]))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
+                               padding=get_padding(kernel_size, dilation[2])))
+        ])
+        # self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList([
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                               padding=get_padding(kernel_size, 1))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                               padding=get_padding(kernel_size, 1))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                               padding=get_padding(kernel_size, 1)))
+        ])
+        # self.convs2.apply(init_weights)
+
+        self.adain1 = nn.ModuleList([
+            AdaIN1d(style_dim, channels),
+            AdaIN1d(style_dim, channels),
+            AdaIN1d(style_dim, channels),
+        ])
+
+        self.adain2 = nn.ModuleList([
+            AdaIN1d(style_dim, channels),
+            AdaIN1d(style_dim, channels),
+            AdaIN1d(style_dim, channels),
+        ])
+
+        self.alpha1 = nn.ParameterList(
+            [nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
+        self.alpha2 = nn.ParameterList(
+            [nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
+
+    def forward(self, x, s):
+        for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
+            xt = n1(x, s)  # THIS IS ADAIN - EXPECTS conv1d dims
+            xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2)  # Snake1D
+            xt = c1(xt)
+            xt = n2(xt, s)  # THIS IS ADAIN - EXPECTS conv1d dims
+            xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2)  # Snake1D
+            xt = c2(xt)
+            x = xt + x
+        return x
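The `Snake1D` lines implement the snake activation, x + (1/α)·sin²(α·x), a periodic activation proposed for modelling periodic signals such as waveforms; a quick standalone check:

alpha = torch.ones(1, 4, 1)
x = torch.randn(1, 4, 10)
snake = x + (1 / alpha) * torch.sin(alpha * x) ** 2
print(snake.shape)   # torch.Size([1, 4, 10]) - elementwise, shape preserved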
+
+
+class SourceModuleHnNSF(torch.nn.Module):
+
+    def __init__(self):
+
+        super().__init__()
+        self.harmonic_num = 8
+        self.l_linear = torch.nn.Linear(self.harmonic_num + 1, 1)
+        self.upsample_scale = 300
+
+    def forward(self, x):
+        # --
+        x = torch.multiply(x, torch.FloatTensor(
+            [[range(1, self.harmonic_num + 2)]]).to(x.device))  # [1, 145200, 9]
+
+        # modulo of negative f0 values => -21 % 10 = 9, since -3*10 + 9 = -21; NOTICE THAT f0_values IS SIGNED
+        rad_values = x / 25647  # ).clamp(0, 1)
+        # rad_values = torch.where(torch.logical_or(rad_values < 0, rad_values > 1), 0.5, rad_values)
+        rad_values = rad_values % 1  # % also wraps negative values
+        rad_values = F.interpolate(rad_values.transpose(1, 2),
+                                   scale_factor=1 / self.upsample_scale,
+                                   mode='linear').transpose(1, 2)
+
+        # 1.89 also sounds nice, has woofer at punctuation
+        phase = torch.cumsum(rad_values, dim=1) * 1.84 * np.pi
+        phase = F.interpolate(phase.transpose(1, 2) * self.upsample_scale,
+                              scale_factor=self.upsample_scale, mode='linear').transpose(1, 2)
+        x = .009 * phase.sin()
+        # --
+        x = self.l_linear(x).tanh()
+        return x
+
+
+class Generator(torch.nn.Module):
+    def __init__(self,
+                 style_dim,
+                 resblock_kernel_sizes,
+                 upsample_rates,
+                 upsample_initial_channel,
+                 resblock_dilation_sizes,
+                 upsample_kernel_sizes):
+        super(Generator, self).__init__()
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_rates)
+        self.m_source = SourceModuleHnNSF()
+        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
+        self.noise_convs = nn.ModuleList()
+        self.ups = nn.ModuleList()
+        self.noise_res = nn.ModuleList()
+
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            c_cur = upsample_initial_channel // (2 ** (i + 1))
+
+            self.ups.append(weight_norm(ConvTranspose1d(upsample_initial_channel // (2 ** i),
+                                                        upsample_initial_channel // (2 ** (i + 1)),
+                                                        k, u, padding=(u // 2 + u % 2), output_padding=u % 2)))
+
+            if i + 1 < len(upsample_rates):
+                stride_f0 = np.prod(upsample_rates[i + 1:])
+                self.noise_convs.append(Conv1d(
+                    1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0 + 1) // 2))
+                self.noise_res.append(AdaINResBlock1(
+                    c_cur, 7, [1, 3, 5], style_dim))
+            else:
+                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
+                self.noise_res.append(AdaINResBlock1(
+                    c_cur, 11, [1, 3, 5], style_dim))
+
+        self.resblocks = nn.ModuleList()
+
+        self.alphas = nn.ParameterList()
+        self.alphas.append(nn.Parameter(
+            torch.ones(1, upsample_initial_channel, 1)))
+
+        for i in range(len(self.ups)):
+            ch = upsample_initial_channel // (2 ** (i + 1))
+            self.alphas.append(nn.Parameter(torch.ones(1, ch, 1)))
+
+            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
+                self.resblocks.append(AdaINResBlock1(ch, k, d, style_dim))
+
+        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+
+    def forward(self, x, s, f0):
+
+        # x.shape=[1, 512, 484]  s.shape=[1, 1, 1, 128]  f0.shape before upsampling is one value per mel frame
+        f0 = self.f0_upsamp(f0).transpose(1, 2)
+
+        # f0 enters the harmonic source already upsampled to the full 24 kHz wav length, e.g. [1, 145200, 1]
+        har_source = self.m_source(f0)
+
+        har_source = har_source.transpose(1, 2)
+
+        for i in range(self.num_upsamples):
+
+            x = x + (1 / self.alphas[i]) * (torch.sin(self.alphas[i] * x) ** 2)
+            x_source = self.noise_convs[i](har_source)
+            x_source = self.noise_res[i](x_source, s)
+
+            x = self.ups[i](x)
+
+            x = x + x_source
+
+            xs = None
+            for j in range(self.num_kernels):
+
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x, s)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x, s)
+            x = xs / self.num_kernels
+        # x = x + (1 / self.alphas[i+1]) * (torch.sin(self.alphas[i+1] * x) ** 2)  # noisy
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
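Note how the rates line up across the model: the generator upsamples by the product of upsample_rates [10, 5, 3, 2], which equals the MelSpec hop_length of 300 samples, so each mel frame becomes 300 waveform samples:

import numpy as np
print(int(np.prod([10, 5, 3, 2])))   # 300 - waveform samples generated per mel frame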
+
+
+class AdainResBlk1d(nn.Module):
+
+    # also used in ProsodyPredictor()
+
+    def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
+                 upsample='none', dropout_p=0.0):
+        super().__init__()
+        self.actv = actv
+        self.upsample_type = upsample
+        self.upsample = UpSample1d(upsample)
+        self.learned_sc = dim_in != dim_out
+        self._build_weights(dim_in, dim_out, style_dim)
+        if upsample == 'none':
+            self.pool = nn.Identity()
+        else:
+            self.pool = weight_norm(nn.ConvTranspose1d(
+                dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
+
+    def _build_weights(self, dim_in, dim_out, style_dim):
+        self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
+        self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
+        self.norm1 = AdaIN1d(style_dim, dim_in)
+        self.norm2 = AdaIN1d(style_dim, dim_out)
+        if self.learned_sc:
+            self.conv1x1 = weight_norm(
+                nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
+
+    def _shortcut(self, x):
+        x = self.upsample(x)
+        if self.learned_sc:
+            x = self.conv1x1(x)
+        return x
+
+    def _residual(self, x, s):
+        x = self.norm1(x, s)
+        x = self.actv(x)
+        x = self.pool(x)
+        x = self.conv1(x)
+        x = self.norm2(x, s)
+        x = self.actv(x)
+        x = self.conv2(x)
+        return x
+
+    def forward(self, x, s):
+        out = self._residual(x, s)
+        out = (out + self._shortcut(x)) / math.sqrt(2)
+        return out
+
+
+class UpSample1d(nn.Module):
+    def __init__(self, layer_type):
+        super().__init__()
+        self.layer_type = layer_type
+
+    def forward(self, x):
+        if self.layer_type == 'none':
+            return x
+        else:
+            return F.interpolate(x, scale_factor=2, mode='nearest-exact')
+
+
+class Decoder(nn.Module):
+    def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
+                 resblock_kernel_sizes=[3, 7, 11],
+                 upsample_rates=[10, 5, 3, 2],
+                 upsample_initial_channel=512,
+                 resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+                 upsample_kernel_sizes=[20, 10, 6, 4]):
+        super().__init__()
+
+        self.decode = nn.ModuleList()
+
+        self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
+
+        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
+        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
+        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
+        self.decode.append(AdainResBlk1d(
+            1024 + 2 + 64, 512, style_dim, upsample=True))
+
+        self.F0_conv = weight_norm(
+            nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))  # smooth
+
+        self.N_conv = weight_norm(
+            nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
+
+        self.asr_res = nn.Sequential(
+            weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
+        )
+
+        self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
+                                   upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes)
+
+    def forward(self, asr=None, F0_curve=None, N=None, s=None):
+
+        F0 = self.F0_conv(F0_curve)
+        N = self.N_conv(N)
+
+        x = torch.cat([asr, F0, N], axis=1)
+
+        x = self.encode(x, s)
+
+        asr_res = self.asr_res(asr)
+
+        res = True
+        for block in self.decode:
+            if res:
+                x = torch.cat([x, asr_res, F0, N], axis=1)
+            x = block(x, s)
+            if block.upsample_type != "none":
+                res = False
+
+        x = self.generator(x, s, F0_curve)
+        return x
+
+
+class MelSpec(torch.nn.Module):
+
+    def __init__(self,
+                 sample_rate=17402,  # https://github.com/fakerybakery/styletts2-cli/blob/main/msinference.py = default 16000; however 17400 vocalises better, also for "en_US/vctk_p274"
+                 n_fft=2048,
+                 win_length=1200,
+                 hop_length=300,
+                 n_mels=80
+                 ):
+        '''avoids dependency on torchaudio'''
+        super().__init__()
+        self.n_fft = n_fft
+        self.win_length = win_length if win_length is not None else n_fft
+        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
+        # -- build an HTK-style mel filterbank
+        f_min = 0.0
+        f_max = float(sample_rate // 2)
+        all_freqs = torch.linspace(0, sample_rate // 2, n_fft // 2 + 1)
+        m_min = 2595.0 * math.log10(1.0 + (f_min / 700.0))
+        m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0))
+        m_pts = torch.linspace(m_min, m_max, n_mels + 2)
+        f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0)
+        f_diff = f_pts[1:] - f_pts[:-1]  # (n_mels + 1)
+        slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)
+        zero = torch.zeros(1)
+        down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1]  # (n_freqs, n_mels)
+        up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_freqs, n_mels)
+        fb = torch.max(zero, torch.min(down_slopes, up_slopes))
+        # --
+        self.register_buffer('fb', fb, persistent=False)
+        window = torch.hann_window(self.win_length)
+        self.register_buffer('window', window, persistent=False)
+
+    def forward(self, x):
+        spec_f = torch.stft(x,
+                            self.n_fft,
+                            self.hop_length,
+                            self.win_length,
+                            self.window,
+                            center=True,
+                            pad_mode="reflect",
+                            normalized=False,
+                            onesided=True,
+                            return_complex=True)  # [bs, 1025, 56]
+        mel_specgram = torch.matmul(spec_f.abs().pow(2).transpose(1, 2), self.fb).transpose(1, 2)
+        return mel_specgram[:, None, :, :]  # [bs, 1, 80, time]
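The filterbank above uses the standard HTK mel mapping, m = 2595·log10(1 + f/700), and its inverse f = 700·(10^(m/2595) − 1); a quick round-trip check:

f = 1000.0
m = 2595.0 * math.log10(1.0 + f / 700.0)   # ~1000 mel - 1 kHz maps to roughly 1000 mel by construction
print(700.0 * (10 ** (m / 2595.0) - 1.0))  # 1000.0 (round trip)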
+
+
+class LearnedDownSample(nn.Module):
+    def __init__(self, dim_in):
+        super().__init__()
+        self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(
+            3, 3), stride=(2, 2), groups=dim_in, padding=1))
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+class ResBlk(nn.Module):
+    def __init__(self,
+                 dim_in, dim_out):
+        super().__init__()
+        self.actv = nn.LeakyReLU(0.2)  # .07 also nice
+        self.downsample_res = LearnedDownSample(dim_in)
+        self.learned_sc = dim_in != dim_out
+        self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
+        self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
+        if self.learned_sc:
+            self.conv1x1 = spectral_norm(
+                nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False))
+
+    def _shortcut(self, x):
+        if self.learned_sc:
+            x = self.conv1x1(x)
+        if x.shape[3] % 2 != 0:  # [bs, 128, Freq, Time]
+            x = torch.cat([x, x[:, :, :, -1:]], dim=3)
+        return F.interpolate(x, scale_factor=.5, mode='nearest-exact')  # F.avg_pool2d(x, 2)
+
+    def _residual(self, x):
+        x = self.actv(x)
+        x = self.conv1(x)
+        x = self.downsample_res(x)
+        x = self.actv(x)
+        x = self.conv2(x)
+        return x
+
+    def forward(self, x):
+        x = self._shortcut(x) + self._residual(x)
+        return x / math.sqrt(2)  # unit variance
+
+
+class StyleEncoder(nn.Module):
+
+    # for both acoustic & prosodic ref_s/p
+
+    def __init__(self,
+                 dim_in=64,
+                 style_dim=128,
+                 max_conv_dim=512):
+        super().__init__()
+        blocks = [spectral_norm(nn.Conv2d(1, dim_in, 3, stride=1, padding=1))]
+        for _ in range(4):
+            dim_out = min(dim_in * 2,
+                          max_conv_dim)
+            blocks += [ResBlk(dim_in, dim_out)]
+            dim_in = dim_out
+        blocks += [nn.LeakyReLU(0.24),  # w/o this activation - produces no speech
+                   spectral_norm(nn.Conv2d(dim_out, dim_out, 5, stride=1, padding=0)),
+                   nn.LeakyReLU(0.2)  # 0.3 sounds nice
+                   ]
+        self.shared = nn.Sequential(*blocks)
+        self.unshared = nn.Linear(dim_out, style_dim)
+
+    def forward(self, x):
+        x = self.shared(x)
+        x = x.mean(3, keepdims=True)  # comment this line for a time-varying style vector
+        x = x.transpose(1, 3)
+        s = self.unshared(x)
+        return s
+
+
+class LinearNorm(torch.nn.Module):
+    def __init__(self, in_dim, out_dim, bias=True):
+        super().__init__()
+        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
+
+    def forward(self, x):
+        return self.linear_layer(x)
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, channels, eps=1e-5):
+        super().__init__()
+        self.channels = channels
+        self.eps = eps
+
+        self.gamma = nn.Parameter(torch.ones(channels))
+        self.beta = nn.Parameter(torch.zeros(channels))
+
+    def forward(self, x):
+        x = x.transpose(1, -1)
+        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+        return x.transpose(1, -1)
+
+
+class TextEncoder(nn.Module):
+    def __init__(self, channels, kernel_size, depth, n_symbols):
+        super().__init__()
+        self.embedding = nn.Embedding(n_symbols, channels)
+        padding = (kernel_size - 1) // 2
+        self.cnn = nn.ModuleList()
+        for _ in range(depth):
+            self.cnn.append(nn.Sequential(
+                weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
+                LayerNorm(channels),
+                nn.LeakyReLU(0.24))
+            )
+        self.lstm = nn.LSTM(channels, channels // 2, 1,
+                            batch_first=True, bidirectional=True)
+
+    def forward(self, x):
+        x = self.embedding(x)  # [B, T, emb]
+        x = x.transpose(1, 2)
+        for c in self.cnn:
+            x = c(x)
+        x = x.transpose(1, 2)
+        x, _ = self.lstm(x)
+        return x
+
+
+class AdaLayerNorm(nn.Module):
+
+    def __init__(self, style_dim, channels=None, eps=1e-5):
+        super().__init__()
+        self.eps = eps
+        self.fc = nn.Linear(style_dim, 1024)  # hard-coded to 512 channels: 1024 = gamma(512) + beta(512)
+
+    def forward(self, x, s):
+        h = self.fc(s)
+        gamma = h[:, :, :512]
+        beta = h[:, :, 512:1024]
+        x = F.layer_norm(x, (512,), eps=self.eps)
+        x = (1 + gamma) * x + beta
+        return x  # [1, 75, 512]
+
+
+class ProsodyPredictor(nn.Module):
+
+    def __init__(self, style_dim, d_hid, nlayers, max_dur=50):
+        super().__init__()
+
+        self.text_encoder = DurationEncoder(sty_dim=style_dim,
+                                            d_model=d_hid,
+                                            nlayers=nlayers)  # called outside forward
+        self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2,
+                            1, batch_first=True, bidirectional=True)
+        self.duration_proj = LinearNorm(d_hid, max_dur)
+        self.shared = nn.LSTM(d_hid + style_dim, d_hid //
+                              2, 1, batch_first=True, bidirectional=True)
+        self.F0 = nn.ModuleList([
+            AdainResBlk1d(d_hid, d_hid, style_dim),
+            AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True),
+            AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim),
+        ])
+        self.N = nn.ModuleList([
+            AdainResBlk1d(d_hid, d_hid, style_dim),
+            AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True),
+            AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim)
+        ])
+        self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
+        self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
+
+    def F0Ntrain(self, x, s):
+
+        x, _ = self.shared(x)  # [bs, time, ch] LSTM
+
+        x = x.transpose(1, 2)  # [bs, ch, time]
+
+        F0 = x
+
+        for block in self.F0:
+            # AdainResBlk1d expects conv1d dimensions, e.g. F0=[1, 512, 147], s=[1, 128]
+            F0 = block(F0, s)
+        F0 = self.F0_proj(F0)
+
+        N = x
+
+        for block in self.N:
+            N = block(N, s)
+        N = self.N_proj(N)
+
+        return F0, N
+
+    def forward(self, d_en=None, s=None):
+        blend = self.text_encoder(d_en, s)
+        x, _ = self.lstm(blend)
+        dur = self.duration_proj(x)  # [bs, 150, 50]
+
+        _, input_length, classifier_50 = dur.shape
+
+        dur = dur[0, :, :]
+        dur = torch.sigmoid(dur).sum(1)
+        dur = dur.round().clamp(min=1).to(torch.int64)
+        aln_trg = torch.zeros(1,
+                              dur.sum(),
+                              input_length,
+                              device=s.device)
+        c_frame = 0
+        for i in range(input_length):
+            aln_trg[:, c_frame:c_frame + dur[i], i] = 1
+            c_frame += dur[i]
+        en = torch.bmm(aln_trg, blend)
+        F0_pred, N_pred = self.F0Ntrain(en, s)
+        return aln_trg, F0_pred, N_pred
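The alignment loop above expands per-token durations into a monotonic one-hot alignment; in miniature, assuming durations [2, 1, 3] for three tokens:

dur = torch.tensor([2, 1, 3])
aln = torch.zeros(1, int(dur.sum()), 3)
c = 0
for i in range(3):
    aln[:, c:c + dur[i], i] = 1
    c += dur[i]
print(aln[0])
# tensor([[1., 0., 0.],
#         [1., 0., 0.],
#         [0., 1., 0.],
#         [0., 0., 1.],
#         [0., 0., 1.],
#         [0., 0., 1.]])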
+
+
+class DurationEncoder(nn.Module):
+
+    def __init__(self, sty_dim=128, d_model=512, nlayers=3):
+        super().__init__()
+        self.lstms = nn.ModuleList()
+        for _ in range(nlayers):
+            self.lstms.append(nn.LSTM(d_model + sty_dim,
+                                      d_model // 2,
+                                      num_layers=1,
+                                      batch_first=True,
+                                      bidirectional=True
+                                      ))
+            self.lstms.append(AdaLayerNorm(sty_dim, d_model))
+
+    def forward(self, x, style):
+
+        _, _, input_lengths = x.shape  # [bs, 512, time]
+
+        style = _tile(style, length=x.shape[2]).transpose(1, 2)
+        x = x.transpose(1, 2)
+
+        for block in self.lstms:
+            if isinstance(block, AdaLayerNorm):
+
+                x = block(x, style)  # LSTM has transposed x
+
+            else:
+                x = torch.cat([x, style], axis=2)
+                # LSTM: expects [bs, time, chan], outputs [bs, time, 2*chan] (2x from bidirectional)
+                x, _ = block(x)
+
+        return torch.cat([x, style], axis=2)  # consumed by predictor.lstm()
wav/af_ZA_google-nwu_0184.wav ADDED
Binary file (92 kB). View file
 
wav/af_ZA_google-nwu_1919.wav ADDED
Binary file (92 kB). View file
 
wav/af_ZA_google-nwu_2418.wav ADDED
Binary file (92 kB). View file
 
wav/af_ZA_google-nwu_6590.wav ADDED
Binary file (92 kB). View file
 
wav/af_ZA_google-nwu_7130.wav ADDED
Binary file (92 kB). View file
 
wav/af_ZA_google-nwu_7214.wav ADDED
Binary file (92 kB). View file
 
wav/af_ZA_google-nwu_8148.wav ADDED
Binary file (92 kB). View file
 
wav/af_ZA_google-nwu_8924.wav ADDED
Binary file (92 kB). View file
 
wav/af_ZA_google-nwu_8963.wav ADDED
Binary file (92 kB). View file
 
wav/bn_multi_00737.wav ADDED
Binary file (92 kB). View file
 
wav/bn_multi_00779.wav ADDED
Binary file (92 kB). View file
 
wav/bn_multi_01232.wav ADDED
Binary file (92 kB). View file
 
wav/bn_multi_01701.wav ADDED
Binary file (92 kB). View file
 
wav/bn_multi_03042.wav ADDED
Binary file (92 kB). View file
 
wav/bn_multi_0834.wav ADDED
Binary file (92 kB). View file
 
wav/bn_multi_1010.wav ADDED
Binary file (92 kB). View file
 
wav/bn_multi_3108.wav ADDED
Binary file (92 kB). View file
 
wav/bn_multi_3713.wav ADDED
Binary file (92 kB). View file
 
wav/bn_multi_3958.wav ADDED
Binary file (92 kB). View file
 
wav/bn_multi_4046.wav ADDED
Binary file (92 kB). View file
 
wav/bn_multi_4811.wav ADDED
Binary file (92 kB). View file
 
wav/bn_multi_5958.wav ADDED
Binary file (92 kB). View file
 
wav/bn_multi_9169.wav ADDED
Binary file (92 kB). View file
 
wav/bn_multi_rm.wav ADDED
Binary file (92 kB). View file
 
wav/de_DE_m-ailabs_angela_merkel.wav ADDED
Binary file (90.7 kB). View file
 
wav/de_DE_m-ailabs_eva_k.wav ADDED
Binary file (92.7 kB). View file
 
wav/de_DE_m-ailabs_karlsson.wav ADDED
Binary file (92.7 kB). View file
 
wav/de_DE_m-ailabs_ramona_deininger.wav ADDED
Binary file (91.2 kB). View file
 
wav/de_DE_m-ailabs_rebecca_braunert_plunkett.wav ADDED
Binary file (91.2 kB). View file
 
wav/de_DE_thorsten-emotion_amused.wav ADDED
Binary file (92 kB). View file
 
wav/el_GR_rapunzelina.wav ADDED
Binary file (92 kB). View file
 
wav/en_UK_apope.wav ADDED
Binary file (92 kB). View file
 
wav/en_US_cmu_arctic_aew.wav ADDED
Binary file (92 kB). View file
 
wav/en_US_cmu_arctic_aup.wav ADDED
Binary file (94.3 kB). View file
 
wav/en_US_cmu_arctic_awb.wav ADDED
Binary file (92 kB). View file
 
wav/en_US_cmu_arctic_awbrms.wav ADDED
Binary file (92.7 kB). View file
 
wav/en_US_cmu_arctic_axb.wav ADDED
Binary file (92 kB). View file
 
wav/en_US_cmu_arctic_bdl.wav ADDED
Binary file (94.8 kB). View file
 
wav/en_US_cmu_arctic_clb.wav ADDED
Binary file (92 kB). View file
 
wav/en_US_cmu_arctic_eey.wav ADDED
Binary file (95.3 kB). View file
 
wav/en_US_cmu_arctic_fem.wav ADDED
Binary file (94.8 kB). View file
 
wav/en_US_cmu_arctic_gka.wav ADDED
Binary file (95.3 kB). View file
 
wav/en_US_cmu_arctic_jmk.wav ADDED
Binary file (93.2 kB). View file
 
wav/en_US_cmu_arctic_ksp.wav ADDED
Binary file (92 kB). View file