mitchsayre committed · Commit f71bc95 · 1 Parent(s): ce40172
.gitignore ADDED
@@ -0,0 +1,7 @@
+ .venv
+ out.wav
+ __pycache__/
+ *.pyc
+ *.egg-info/
+ dist/
+ build/
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2026 Wfloat
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -3,4 +3,123 @@ license: mit
  language:
  - en
  pipeline_tag: text-to-speech
- ---
+ ---
+
+ # wfloat-tts
+
+ `wfloat-tts` is a lightweight multi-speaker English VITS text-to-speech model with explicit speaker, emotion, and intensity control.
+
+ This repo includes:
+
+ - `model.safetensors`: inference weights
+ - `config.json`: model config and token mapping
+ - `src/wfloat_tts/`: a small Python inference helper
+
+ The repo is set up for standalone inference from the released model files. You do not need the original training codebase to synthesize speech with it.
+
+ ## Inputs
+
+ The intended inference inputs are:
+
+ - `text`: the utterance to synthesize
+ - `sid`: numeric speaker id
+ - `emotion`: emotion label
+ - `intensity`: value from `0.0` to `1.0`
+
+ You do not need to pass raw control symbols. The Python helper converts `emotion` and `intensity` into the control tokens the model was trained on.
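+
+ For example, a minimal sketch of that conversion using the shipped `prepare_input` helper (assumes `config.json` sits in the working directory and that `piper-phonemize` from the Install section is available):
+
+ ```python
+ import json
+
+ from wfloat_tts import prepare_input
+
+ # Load the released config so the phoneme/control token mapping is available.
+ with open("config.json", encoding="utf-8") as f:
+     config = json.load(f)
+
+ prepared = prepare_input(
+     text="That was amazing!",
+     config=config,
+     emotion="joy",
+     intensity=0.8,
+ )
+
+ # The helper appends exactly one emotion symbol and one intensity symbol.
+ print(prepared.emotion_symbol, prepared.intensity_symbol)  # 😄 ⑧
+ print(prepared.token_ids[-2:])  # ids of those two control tokens
+ ```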
+
+ ## Install
+
+ ```bash
+ pip install -e .
+ pip install "piper-phonemize==1.3.0" -f https://k2-fsa.github.io/icefall/piper_phonemize
+ ```
+
+ Runtime dependencies:
+
+ - `torch`
+ - `numpy`
+ - `safetensors`
+ - `piper-phonemize`
+
+ `piper-phonemize` is installed separately because the recommended wheels are currently hosted at:
+
+ - https://k2-fsa.github.io/icefall/piper_phonemize
+
+ ## Python Example
+
+ ```python
+ from wfloat_tts import load_generator, write_wave
+
+ generator = load_generator(
+     checkpoint_path="model.safetensors",
+     config_path="config.json",
+ )
+
+ audio = generator.generate(
+     text="Hey there, how are you today?",
+     sid=11,
+     emotion="neutral",
+     intensity=0.5,
+ )
+
+ write_wave("out.wav", audio.samples, audio.sample_rate)
+ ```
+
+ ## How It Is Conditioned
+
+ This model was trained to condition on:
+
+ - speaker id
+ - one emotion control token
+ - one intensity control token
+
+ The reference inference path processes a full utterance, appends one emotion token and one intensity token for the whole utterance, and runs synthesis over that full sequence.
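+
+ A small self-contained sketch of that behavior (same assumptions as the Inputs sketch above): even a multi-sentence input ends in exactly one emotion symbol and one intensity symbol.
+
+ ```python
+ import json
+
+ from wfloat_tts import prepare_input
+
+ with open("config.json", encoding="utf-8") as f:
+     config = json.load(f)
+
+ prepared = prepare_input(
+     text="First sentence. Second sentence!",
+     config=config,
+     emotion="anger",
+     intensity=0.3,
+ )
+
+ # A single control-token pair covers the whole utterance.
+ assert prepared.phonemes[-2:] == [prepared.emotion_symbol, prepared.intensity_symbol]
+ ```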
+
+ ## Speaker IDs
+
+ Use numeric `sid` values:
+
+ | Speaker | SID |
+ | --- | ---: |
+ | `skilled_hero_man` | 0 |
+ | `skilled_hero_woman` | 1 |
+ | `fun_hero_man` | 2 |
+ | `fun_hero_woman` | 3 |
+ | `strong_hero_man` | 4 |
+ | `strong_hero_woman` | 5 |
+ | `mad_scientist_man` | 6 |
+ | `mad_scientist_woman` | 7 |
+ | `clever_villain_man` | 8 |
+ | `clever_villain_woman` | 9 |
+ | `narrator_man` | 10 |
+ | `narrator_woman` | 11 |
+ | `wise_elder_man` | 12 |
+ | `wise_elder_woman` | 13 |
+ | `outgoing_anime_man` | 14 |
+ | `outgoing_anime_woman` | 15 |
+ | `scary_villain_man` | 16 |
+ | `scary_villain_woman` | 17 |
+ | `news_reporter_man` | 18 |
+ | `news_reporter_woman` | 19 |
+
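+ The same name-to-id mapping ships in the package as `SPEAKER_IDS`, so ids can be looked up by name (a small usage sketch):
+
+ ```python
+ from wfloat_tts import SPEAKER_IDS
+
+ # Resolve a speaker name to the numeric sid expected by generate().
+ sid = SPEAKER_IDS["narrator_woman"]  # 11
+ ```
+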
+ ## Emotions
+
+ Supported emotion labels:
+
+ - `neutral`
+ - `joy`
+ - `sadness`
+ - `anger`
+ - `fear`
+ - `surprise`
+ - `dismissive`
+ - `confusion`
+
+ `intensity` is clamped to the range `[0.0, 1.0]` and mapped to one of ten discrete intensity levels.
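+
+ A minimal sketch of that bucketing, using the helper the processor itself calls (the ten levels correspond to the circled-digit control symbols `⓪` through `⑨`):
+
+ ```python
+ from wfloat_tts.processor import intensity_to_symbol
+
+ # index = int(clamp(intensity, 0.0, 1.0) * 10), capped at 9
+ print(intensity_to_symbol(0.0))   # ⓪
+ print(intensity_to_symbol(0.55))  # ⑤
+ print(intensity_to_symbol(1.0))   # ⑨
+ ```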
+
+ ## Notes
+
+ - `model.safetensors` is the main inference artifact in this repo.
+ - `config.json` includes the token mapping needed by the processor.
+ - The current release uses a multi-speaker model with 20 speakers.
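+
+ The install also registers a `wfloat-tts` console script (declared in `pyproject.toml`, implemented in `src/wfloat_tts/cli.py`). A usage sketch with the same settings as the Python example:
+
+ ```bash
+ wfloat-tts \
+   --model model.safetensors \
+   --config config.json \
+   --text "Hey there, how are you today?" \
+   --sid 11 \
+   --emotion neutral \
+   --intensity 0.5 \
+   --output out.wav
+ ```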
config.json ADDED
@@ -0,0 +1,650 @@
1
+ {
2
+ "dataset": "wumbospeech0",
3
+ "audio": {
4
+ "sample_rate": 22050,
5
+ "quality": "wumbospeech0"
6
+ },
7
+ "espeak": {
8
+ "voice": "en-us"
9
+ },
10
+ "language": {
11
+ "code": "en-us"
12
+ },
13
+ "inference": {
14
+ "noise_scale": 0.667,
15
+ "length_scale": 1,
16
+ "noise_w": 0.8
17
+ },
18
+ "phoneme_type": "espeak",
19
+ "phoneme_map": {},
20
+ "phoneme_id_map": {
21
+ " ": [
22
+ 3
23
+ ],
24
+ "!": [
25
+ 4
26
+ ],
27
+ "\"": [
28
+ 150
29
+ ],
30
+ "#": [
31
+ 149
32
+ ],
33
+ "$": [
34
+ 2
35
+ ],
36
+ "'": [
37
+ 5
38
+ ],
39
+ "(": [
40
+ 6
41
+ ],
42
+ ")": [
43
+ 7
44
+ ],
45
+ ",": [
46
+ 8
47
+ ],
48
+ "-": [
49
+ 9
50
+ ],
51
+ ".": [
52
+ 10
53
+ ],
54
+ "0": [
55
+ 130
56
+ ],
57
+ "1": [
58
+ 131
59
+ ],
60
+ "2": [
61
+ 132
62
+ ],
63
+ "3": [
64
+ 133
65
+ ],
66
+ "4": [
67
+ 134
68
+ ],
69
+ "5": [
70
+ 135
71
+ ],
72
+ "6": [
73
+ 136
74
+ ],
75
+ "7": [
76
+ 137
77
+ ],
78
+ "8": [
79
+ 138
80
+ ],
81
+ "9": [
82
+ 139
83
+ ],
84
+ ":": [
85
+ 11
86
+ ],
87
+ ";": [
88
+ 12
89
+ ],
90
+ "?": [
91
+ 13
92
+ ],
93
+ "X": [
94
+ 156
95
+ ],
96
+ "^": [
97
+ 1
98
+ ],
99
+ "_": [
100
+ 0
101
+ ],
102
+ "a": [
103
+ 14
104
+ ],
105
+ "b": [
106
+ 15
107
+ ],
108
+ "c": [
109
+ 16
110
+ ],
111
+ "d": [
112
+ 17
113
+ ],
114
+ "e": [
115
+ 18
116
+ ],
117
+ "f": [
118
+ 19
119
+ ],
120
+ "g": [
121
+ 154
122
+ ],
123
+ "h": [
124
+ 20
125
+ ],
126
+ "i": [
127
+ 21
128
+ ],
129
+ "j": [
130
+ 22
131
+ ],
132
+ "k": [
133
+ 23
134
+ ],
135
+ "l": [
136
+ 24
137
+ ],
138
+ "m": [
139
+ 25
140
+ ],
141
+ "n": [
142
+ 26
143
+ ],
144
+ "o": [
145
+ 27
146
+ ],
147
+ "p": [
148
+ 28
149
+ ],
150
+ "q": [
151
+ 29
152
+ ],
153
+ "r": [
154
+ 30
155
+ ],
156
+ "s": [
157
+ 31
158
+ ],
159
+ "t": [
160
+ 32
161
+ ],
162
+ "u": [
163
+ 33
164
+ ],
165
+ "v": [
166
+ 34
167
+ ],
168
+ "w": [
169
+ 35
170
+ ],
171
+ "x": [
172
+ 36
173
+ ],
174
+ "y": [
175
+ 37
176
+ ],
177
+ "z": [
178
+ 38
179
+ ],
180
+ "æ": [
181
+ 39
182
+ ],
183
+ "ç": [
184
+ 40
185
+ ],
186
+ "ð": [
187
+ 41
188
+ ],
189
+ "ø": [
190
+ 42
191
+ ],
192
+ "ħ": [
193
+ 43
194
+ ],
195
+ "ŋ": [
196
+ 44
197
+ ],
198
+ "œ": [
199
+ 45
200
+ ],
201
+ "ǀ": [
202
+ 46
203
+ ],
204
+ "ǁ": [
205
+ 47
206
+ ],
207
+ "ǂ": [
208
+ 48
209
+ ],
210
+ "ǃ": [
211
+ 49
212
+ ],
213
+ "ɐ": [
214
+ 50
215
+ ],
216
+ "ɑ": [
217
+ 51
218
+ ],
219
+ "ɒ": [
220
+ 52
221
+ ],
222
+ "ɓ": [
223
+ 53
224
+ ],
225
+ "ɔ": [
226
+ 54
227
+ ],
228
+ "ɕ": [
229
+ 55
230
+ ],
231
+ "ɖ": [
232
+ 56
233
+ ],
234
+ "ɗ": [
235
+ 57
236
+ ],
237
+ "ɘ": [
238
+ 58
239
+ ],
240
+ "ə": [
241
+ 59
242
+ ],
243
+ "ɚ": [
244
+ 60
245
+ ],
246
+ "ɛ": [
247
+ 61
248
+ ],
249
+ "ɜ": [
250
+ 62
251
+ ],
252
+ "ɞ": [
253
+ 63
254
+ ],
255
+ "ɟ": [
256
+ 64
257
+ ],
258
+ "ɠ": [
259
+ 65
260
+ ],
261
+ "ɡ": [
262
+ 66
263
+ ],
264
+ "ɢ": [
265
+ 67
266
+ ],
267
+ "ɣ": [
268
+ 68
269
+ ],
270
+ "ɤ": [
271
+ 69
272
+ ],
273
+ "ɥ": [
274
+ 70
275
+ ],
276
+ "ɦ": [
277
+ 71
278
+ ],
279
+ "ɧ": [
280
+ 72
281
+ ],
282
+ "ɨ": [
283
+ 73
284
+ ],
285
+ "ɪ": [
286
+ 74
287
+ ],
288
+ "ɫ": [
289
+ 75
290
+ ],
291
+ "ɬ": [
292
+ 76
293
+ ],
294
+ "ɭ": [
295
+ 77
296
+ ],
297
+ "ɮ": [
298
+ 78
299
+ ],
300
+ "ɯ": [
301
+ 79
302
+ ],
303
+ "ɰ": [
304
+ 80
305
+ ],
306
+ "ɱ": [
307
+ 81
308
+ ],
309
+ "ɲ": [
310
+ 82
311
+ ],
312
+ "ɳ": [
313
+ 83
314
+ ],
315
+ "ɴ": [
316
+ 84
317
+ ],
318
+ "ɵ": [
319
+ 85
320
+ ],
321
+ "ɶ": [
322
+ 86
323
+ ],
324
+ "ɸ": [
325
+ 87
326
+ ],
327
+ "ɹ": [
328
+ 88
329
+ ],
330
+ "ɺ": [
331
+ 89
332
+ ],
333
+ "ɻ": [
334
+ 90
335
+ ],
336
+ "ɽ": [
337
+ 91
338
+ ],
339
+ "ɾ": [
340
+ 92
341
+ ],
342
+ "ʀ": [
343
+ 93
344
+ ],
345
+ "ʁ": [
346
+ 94
347
+ ],
348
+ "ʂ": [
349
+ 95
350
+ ],
351
+ "ʃ": [
352
+ 96
353
+ ],
354
+ "ʄ": [
355
+ 97
356
+ ],
357
+ "ʈ": [
358
+ 98
359
+ ],
360
+ "ʉ": [
361
+ 99
362
+ ],
363
+ "ʊ": [
364
+ 100
365
+ ],
366
+ "ʋ": [
367
+ 101
368
+ ],
369
+ "ʌ": [
370
+ 102
371
+ ],
372
+ "ʍ": [
373
+ 103
374
+ ],
375
+ "ʎ": [
376
+ 104
377
+ ],
378
+ "ʏ": [
379
+ 105
380
+ ],
381
+ "ʐ": [
382
+ 106
383
+ ],
384
+ "ʑ": [
385
+ 107
386
+ ],
387
+ "ʒ": [
388
+ 108
389
+ ],
390
+ "ʔ": [
391
+ 109
392
+ ],
393
+ "ʕ": [
394
+ 110
395
+ ],
396
+ "ʘ": [
397
+ 111
398
+ ],
399
+ "ʙ": [
400
+ 112
401
+ ],
402
+ "ʛ": [
403
+ 113
404
+ ],
405
+ "ʜ": [
406
+ 114
407
+ ],
408
+ "ʝ": [
409
+ 115
410
+ ],
411
+ "ʟ": [
412
+ 116
413
+ ],
414
+ "ʡ": [
415
+ 117
416
+ ],
417
+ "ʢ": [
418
+ 118
419
+ ],
420
+ "ʦ": [
421
+ 155
422
+ ],
423
+ "ʰ": [
424
+ 145
425
+ ],
426
+ "ʲ": [
427
+ 119
428
+ ],
429
+ "ˈ": [
430
+ 120
431
+ ],
432
+ "ˌ": [
433
+ 121
434
+ ],
435
+ "ː": [
436
+ 122
437
+ ],
438
+ "ˑ": [
439
+ 123
440
+ ],
441
+ "˞": [
442
+ 124
443
+ ],
444
+ "ˤ": [
445
+ 146
446
+ ],
447
+ "̃": [
448
+ 141
449
+ ],
450
+ "̊": [
451
+ 158
452
+ ],
453
+ "̝": [
454
+ 157
455
+ ],
456
+ "̧": [
457
+ 140
458
+ ],
459
+ "̩": [
460
+ 144
461
+ ],
462
+ "̪": [
463
+ 142
464
+ ],
465
+ "̯": [
466
+ 143
467
+ ],
468
+ "̺": [
469
+ 152
470
+ ],
471
+ "̻": [
472
+ 153
473
+ ],
474
+ "β": [
475
+ 125
476
+ ],
477
+ "ε": [
478
+ 147
479
+ ],
480
+ "θ": [
481
+ 126
482
+ ],
483
+ "χ": [
484
+ 127
485
+ ],
486
+ "ᵻ": [
487
+ 128
488
+ ],
489
+ "↑": [
490
+ 151
491
+ ],
492
+ "↓": [
493
+ 148
494
+ ],
495
+ "ⱱ": [
496
+ 129
497
+ ],
498
+ "😐": [
499
+ 159
500
+ ],
501
+ "😄": [
502
+ 160
503
+ ],
504
+ "😢": [
505
+ 161
506
+ ],
507
+ "😡": [
508
+ 162
509
+ ],
510
+ "😱": [
511
+ 163
512
+ ],
513
+ "😲": [
514
+ 164
515
+ ],
516
+ "🙄": [
517
+ 165
518
+ ],
519
+ "🤔": [
520
+ 166
521
+ ],
522
+ "🙂": [
523
+ 167
524
+ ],
525
+ "😏": [
526
+ 168
527
+ ],
528
+ "😜": [
529
+ 169
530
+ ],
531
+ "😌": [
532
+ 170
533
+ ],
534
+ "🎭": [
535
+ 171
536
+ ],
537
+ "🧐": [
538
+ 172
539
+ ],
540
+ "⓪": [
541
+ 173
542
+ ],
543
+ "①": [
544
+ 174
545
+ ],
546
+ "②": [
547
+ 175
548
+ ],
549
+ "③": [
550
+ 176
551
+ ],
552
+ "④": [
553
+ 177
554
+ ],
555
+ "⑤": [
556
+ 178
557
+ ],
558
+ "⑥": [
559
+ 179
560
+ ],
561
+ "⑦": [
562
+ 180
563
+ ],
564
+ "⑧": [
565
+ 181
566
+ ],
567
+ "⑨": [
568
+ 182
569
+ ]
570
+ },
571
+ "num_symbols": 256,
572
+ "num_speakers": 20,
573
+ "model": {
574
+ "resblock": "2",
575
+ "resblock_kernel_sizes": [
576
+ 3,
577
+ 5,
578
+ 7
579
+ ],
580
+ "resblock_dilation_sizes": [
581
+ [
582
+ 1,
583
+ 2
584
+ ],
585
+ [
586
+ 2,
587
+ 6
588
+ ],
589
+ [
590
+ 3,
591
+ 12
592
+ ]
593
+ ],
594
+ "upsample_rates": [
595
+ 8,
596
+ 8,
597
+ 4
598
+ ],
599
+ "upsample_initial_channel": 256,
600
+ "upsample_kernel_sizes": [
601
+ 16,
602
+ 16,
603
+ 8
604
+ ],
605
+ "filter_length": 1024,
606
+ "hop_length": 256,
607
+ "win_length": 1024,
608
+ "mel_channels": 80,
609
+ "sample_rate": 22050,
610
+ "sample_bytes": 2,
611
+ "channels": 1,
612
+ "mel_fmin": 0.0,
613
+ "mel_fmax": null,
614
+ "inter_channels": 192,
615
+ "hidden_channels": 192,
616
+ "filter_channels": 768,
617
+ "n_heads": 2,
618
+ "n_layers": 6,
619
+ "kernel_size": 3,
620
+ "p_dropout": 0.1,
621
+ "n_layers_q": 3,
622
+ "use_spectral_norm": false,
623
+ "gin_channels": 512,
624
+ "use_sdp": true,
625
+ "segment_size": 8192
626
+ },
627
+ "speaker_id_map": {
628
+ "skilled_hero_man": 0,
629
+ "skilled_hero_woman": 1,
630
+ "fun_hero_man": 2,
631
+ "fun_hero_woman": 3,
632
+ "strong_hero_man": 4,
633
+ "strong_hero_woman": 5,
634
+ "mad_scientist_man": 6,
635
+ "mad_scientist_woman": 7,
636
+ "clever_villain_man": 8,
637
+ "clever_villain_woman": 9,
638
+ "narrator_man": 10,
639
+ "narrator_woman": 11,
640
+ "wise_elder_man": 12,
641
+ "wise_elder_woman": 13,
642
+ "outgoing_anime_man": 14,
643
+ "outgoing_anime_woman": 15,
644
+ "scary_villain_man": 16,
645
+ "scary_villain_woman": 17,
646
+ "news_reporter_man": 18,
647
+ "news_reporter_woman": 19
648
+ },
649
+ "piper_version": "1.0.0"
650
+ }
examples/basic_infer.py ADDED
@@ -0,0 +1,21 @@
+ from wfloat_tts import load_generator, write_wave
+
+
+ def main() -> None:
+     generator = load_generator(
+         checkpoint_path="model.safetensors",
+         config_path="config.json",
+     )
+     audio = generator.generate(
+         text="Hey there, how are you today?",
+         sid=11,
+         emotion="neutral",
+         intensity=0.5,
+     )
+     out_path = "out.wav"
+     write_wave(out_path, audio.samples, audio.sample_rate)
+     print(out_path)
+
+
+ if __name__ == "__main__":
+     main()
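
With the install steps from the README done, the example above can be run from the repo root with `python examples/basic_infer.py`; it writes `out.wav` to the working directory and prints the path.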
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d1468266ccb48d73aa044c5799a2d3e660399418c237fd447f4019919f28a4e1
+ size 120950832
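
Note: the three lines above are a Git LFS pointer, not the weights themselves. When cloning the repo directly, the actual `model.safetensors` payload is typically fetched via Git LFS, for example:

```bash
git lfs install
git lfs pull
```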
pyproject.toml ADDED
@@ -0,0 +1,25 @@
+ [build-system]
+ requires = ["setuptools>=68"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "wfloat-tts"
+ version = "0.1.0"
+ description = "Reference inference helpers for the Wfloat TTS checkpoint release."
+ readme = "README.md"
+ requires-python = ">=3.10"
+ dependencies = [
+     "numpy>=1.24",
+     "packaging>=23",
+     "safetensors>=0.4",
+     "torch>=2.1",
+ ]
+
+ [project.scripts]
+ wfloat-tts = "wfloat_tts.cli:main"
+
+ [tool.setuptools]
+ package-dir = {"" = "src"}
+
+ [tool.setuptools.packages.find]
+ where = ["src"]
src/wfloat_tts/__init__.py ADDED
@@ -0,0 +1,16 @@
+ from .constants import EMOTION_TO_SYMBOL, INTENSITY_SYMBOLS, SPEAKER_IDS, VALID_EMOTIONS
+ from .infer import GeneratedAudio, WfloatGenerator, load_generator, write_wave
+ from .processor import PreparedInput, prepare_input
+
+ __all__ = [
+     "EMOTION_TO_SYMBOL",
+     "GeneratedAudio",
+     "INTENSITY_SYMBOLS",
+     "PreparedInput",
+     "SPEAKER_IDS",
+     "VALID_EMOTIONS",
+     "WfloatGenerator",
+     "load_generator",
+     "prepare_input",
+     "write_wave",
+ ]
src/wfloat_tts/cli.py ADDED
@@ -0,0 +1,46 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+
5
+ from .infer import load_generator, write_wave
6
+
7
+
8
+ def build_parser() -> argparse.ArgumentParser:
9
+ parser = argparse.ArgumentParser(prog="wfloat-tts")
10
+ parser.add_argument("--model", "--checkpoint", dest="model", default="model.safetensors")
11
+ parser.add_argument("--config", default="config.json")
12
+ parser.add_argument("--text", required=True)
13
+ parser.add_argument("--sid", type=int, default=0)
14
+ parser.add_argument("--emotion", default="neutral")
15
+ parser.add_argument("--intensity", type=float, default=0.5)
16
+ parser.add_argument("--noise-scale", type=float, default=None)
17
+ parser.add_argument("--length-scale", type=float, default=None)
18
+ parser.add_argument("--noise-w", type=float, default=None)
19
+ parser.add_argument("--device", default="cpu")
20
+ parser.add_argument("--output", required=True)
21
+ return parser
22
+
23
+
24
+ def main() -> None:
25
+ parser = build_parser()
26
+ args = parser.parse_args()
27
+
28
+ generator = load_generator(
29
+ checkpoint_path=args.model,
30
+ config_path=args.config,
31
+ device=args.device,
32
+ )
33
+ audio = generator.generate(
34
+ text=args.text,
35
+ sid=args.sid,
36
+ emotion=args.emotion,
37
+ intensity=args.intensity,
38
+ noise_scale=args.noise_scale,
39
+ length_scale=args.length_scale,
40
+ noise_w=args.noise_w,
41
+ )
42
+ write_wave(args.output, audio.samples, audio.sample_rate)
43
+
44
+
45
+ if __name__ == "__main__":
46
+ main()
src/wfloat_tts/constants.py ADDED
@@ -0,0 +1,51 @@
1
+ EMOTION_TO_SYMBOL = {
2
+ "neutral": "😐",
3
+ "joy": "😄",
4
+ "sadness": "😢",
5
+ "anger": "😡",
6
+ "fear": "😱",
7
+ "surprise": "😲",
8
+ "dismissive": "🙄",
9
+ "confusion": "🤔",
10
+ }
11
+
12
+ VALID_EMOTIONS = tuple(EMOTION_TO_SYMBOL.keys())
13
+
14
+ INTENSITY_SYMBOLS = (
15
+ "⓪",
16
+ "①",
17
+ "②",
18
+ "③",
19
+ "④",
20
+ "⑤",
21
+ "⑥",
22
+ "⑦",
23
+ "⑧",
24
+ "⑨",
25
+ )
26
+
27
+ SPEAKER_IDS = {
28
+ "skilled_hero_man": 0,
29
+ "skilled_hero_woman": 1,
30
+ "fun_hero_man": 2,
31
+ "fun_hero_woman": 3,
32
+ "strong_hero_man": 4,
33
+ "strong_hero_woman": 5,
34
+ "mad_scientist_man": 6,
35
+ "mad_scientist_woman": 7,
36
+ "clever_villain_man": 8,
37
+ "clever_villain_woman": 9,
38
+ "narrator_man": 10,
39
+ "narrator_woman": 11,
40
+ "wise_elder_man": 12,
41
+ "wise_elder_woman": 13,
42
+ "outgoing_anime_man": 14,
43
+ "outgoing_anime_woman": 15,
44
+ "scary_villain_man": 16,
45
+ "scary_villain_woman": 17,
46
+ "news_reporter_man": 18,
47
+ "news_reporter_woman": 19,
48
+ }
49
+
50
+ DEFAULT_ESPEAK_VOICE = "en-us"
51
+ DEFAULT_SAMPLE_RATE = 22050
src/wfloat_tts/infer.py ADDED
@@ -0,0 +1,225 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import wave
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ import numpy as np
10
+ from safetensors.torch import load_file as load_safetensors_file
11
+
12
+ from .constants import DEFAULT_ESPEAK_VOICE, DEFAULT_SAMPLE_RATE
13
+ from .processor import PreparedInput, prepare_input
14
+ from .vits import SynthesizerTrn
15
+
16
+
17
+ def _repo_root() -> Path:
18
+ return Path(__file__).resolve().parents[2]
19
+
20
+
21
+ def _default_model_path() -> Path:
22
+ safetensors_path = _repo_root() / "model.safetensors"
23
+ if safetensors_path.exists():
24
+ return safetensors_path
25
+
26
+ return _repo_root() / "model.ckpt"
27
+
28
+
29
+ def _default_config_path() -> Path:
30
+ return _repo_root() / "config.json"
31
+
32
+
33
+ def _import_torch() -> Any:
34
+ try:
35
+ import torch
36
+ except ImportError as exc:
37
+ raise ImportError("torch is required for checkpoint inference") from exc
38
+
39
+ return torch
40
+
41
+
42
+ def load_release_config(config_path: str | Path) -> dict[str, Any]:
43
+ with Path(config_path).open("r", encoding="utf-8") as config_file:
44
+ return json.load(config_file)
45
+
46
+
47
+ def audio_float_to_int16(audio: np.ndarray, max_wav_value: float = 32767.0) -> np.ndarray:
48
+ audio = np.asarray(audio, dtype=np.float32)
49
+ scale = max(0.01, float(np.max(np.abs(audio)))) if audio.size else 1.0
50
+ audio_norm = audio * (max_wav_value / scale)
51
+ audio_norm = np.clip(audio_norm, -max_wav_value, max_wav_value)
52
+ return audio_norm.astype(np.int16)
53
+
54
+
55
+ def write_wave(path: str | Path, samples: np.ndarray, sample_rate: int) -> Path:
56
+ path = Path(path)
57
+ pcm = audio_float_to_int16(samples)
58
+
59
+ with wave.open(str(path), "wb") as wav_file:
60
+ wav_file.setnchannels(1)
61
+ wav_file.setsampwidth(2)
62
+ wav_file.setframerate(sample_rate)
63
+ wav_file.writeframes(pcm.tobytes())
64
+
65
+ return path
66
+
67
+
68
+ def _generator_kwargs_from_config(config: dict[str, Any]) -> dict[str, Any]:
69
+ model = config.get("model", {})
70
+
71
+ return {
72
+ "n_vocab": int(config["num_symbols"]),
73
+ "spec_channels": int(model["filter_length"]) // 2 + 1,
74
+ "segment_size": int(model["segment_size"]) // int(model["hop_length"]),
75
+ "inter_channels": int(model["inter_channels"]),
76
+ "hidden_channels": int(model["hidden_channels"]),
77
+ "filter_channels": int(model["filter_channels"]),
78
+ "n_heads": int(model["n_heads"]),
79
+ "n_layers": int(model["n_layers"]),
80
+ "kernel_size": int(model["kernel_size"]),
81
+ "p_dropout": float(model["p_dropout"]),
82
+ "resblock": model["resblock"],
83
+ "resblock_kernel_sizes": tuple(model["resblock_kernel_sizes"]),
84
+ "resblock_dilation_sizes": tuple(tuple(x) for x in model["resblock_dilation_sizes"]),
85
+ "upsample_rates": tuple(model["upsample_rates"]),
86
+ "upsample_initial_channel": int(model["upsample_initial_channel"]),
87
+ "upsample_kernel_sizes": tuple(model["upsample_kernel_sizes"]),
88
+ "n_speakers": int(config["num_speakers"]),
89
+ "gin_channels": int(model["gin_channels"]),
90
+ "use_sdp": bool(model.get("use_sdp", True)),
91
+ }
92
+
93
+
94
+ def _load_generator_state(model_path: Path, torch_module: Any) -> dict[str, Any]:
95
+ if model_path.suffix == ".safetensors":
96
+ return load_safetensors_file(str(model_path), device="cpu")
97
+
98
+ checkpoint = torch_module.load(model_path, map_location="cpu", weights_only=False)
99
+ state_dict = checkpoint["state_dict"]
100
+ return {
101
+ key[len("model_g.") :]: value
102
+ for key, value in state_dict.items()
103
+ if key.startswith("model_g.")
104
+ }
105
+
106
+
107
+ @dataclass(frozen=True)
108
+ class GeneratedAudio:
109
+ samples: np.ndarray
110
+ sample_rate: int
111
+ prepared_input: PreparedInput
112
+
113
+
114
+ class WfloatGenerator:
115
+ def __init__(
116
+ self,
117
+ checkpoint_path: str | Path | None = None,
118
+ config_path: str | Path | None = None,
119
+ device: str = "cpu",
120
+ ) -> None:
121
+ self.checkpoint_path = Path(checkpoint_path or _default_model_path())
122
+ self.config_path = Path(config_path or _default_config_path())
123
+ self.device = device
124
+
125
+ if not self.checkpoint_path.exists():
126
+ raise FileNotFoundError(
127
+ f"Checkpoint not found at {self.checkpoint_path}. "
128
+ "Place a compatible multi-speaker checkpoint there or pass --checkpoint."
129
+ )
130
+
131
+ if not self.config_path.exists():
132
+ raise FileNotFoundError(f"Config not found at {self.config_path}")
133
+
134
+ self.config = load_release_config(self.config_path)
135
+ self.sample_rate = int(self.config.get("audio", {}).get("sample_rate", DEFAULT_SAMPLE_RATE))
136
+ self.espeak_voice = self.config.get("espeak", {}).get("voice", DEFAULT_ESPEAK_VOICE)
137
+ self.num_speakers = int(self.config.get("num_speakers", 1))
138
+
139
+ torch = _import_torch()
140
+ self._torch = torch
141
+ self._model = SynthesizerTrn(**_generator_kwargs_from_config(self.config))
142
+ state_dict = _load_generator_state(self.checkpoint_path, torch)
143
+ self._model.load_state_dict(state_dict, strict=True)
144
+ self._model.eval()
145
+
146
+ with torch.no_grad():
147
+ self._model.dec.remove_weight_norm()
148
+
149
+ self._model.to(self.device)
150
+ self.num_speakers = int(getattr(self._model, "n_speakers", self.num_speakers))
151
+
152
+ configured_num_speakers = int(self.config.get("num_speakers", self.num_speakers))
153
+ if configured_num_speakers != self.num_speakers:
154
+ raise ValueError(
155
+ "Checkpoint/config mismatch: "
156
+ f"config.json declares num_speakers={configured_num_speakers}, "
157
+ f"but checkpoint reports num_speakers={self.num_speakers}."
158
+ )
159
+
160
+ def generate(
161
+ self,
162
+ text: str,
163
+ sid: int = 0,
164
+ emotion: str = "neutral",
165
+ intensity: float = 0.5,
166
+ noise_scale: float | None = None,
167
+ length_scale: float | None = None,
168
+ noise_w: float | None = None,
169
+ ) -> GeneratedAudio:
170
+ if self.num_speakers <= 1:
171
+ if sid not in (0, None):
172
+ raise ValueError(
173
+ f"Loaded checkpoint is single-speaker but sid={sid} was provided"
174
+ )
175
+ sid_tensor = None
176
+ else:
177
+ sid_tensor = self._torch.LongTensor([int(sid)]).to(self.device)
178
+
179
+ prepared = prepare_input(
180
+ text=text,
181
+ config=self.config,
182
+ emotion=emotion,
183
+ intensity=intensity,
184
+ espeak_voice=self.espeak_voice,
185
+ )
186
+
187
+ text_tensor = self._torch.LongTensor(prepared.token_ids).unsqueeze(0).to(self.device)
188
+ text_lengths = self._torch.LongTensor([len(prepared.token_ids)]).to(self.device)
189
+
190
+ inference = self.config.get("inference", {})
191
+ scales = [
192
+ float(inference.get("noise_scale", 0.667) if noise_scale is None else noise_scale),
193
+ float(inference.get("length_scale", 1.0) if length_scale is None else length_scale),
194
+ float(inference.get("noise_w", 0.8) if noise_w is None else noise_w),
195
+ ]
196
+
197
+ with self._torch.no_grad():
198
+ audio, *_ = self._model.infer(
199
+ text_tensor,
200
+ text_lengths,
201
+ sid=sid_tensor,
202
+ noise_scale=scales[0],
203
+ length_scale=scales[1],
204
+ noise_scale_w=scales[2],
205
+ )
206
+
207
+ samples = audio.detach().cpu().numpy().squeeze().astype(np.float32)
208
+
209
+ return GeneratedAudio(
210
+ samples=samples,
211
+ sample_rate=self.sample_rate,
212
+ prepared_input=prepared,
213
+ )
214
+
215
+
216
+ def load_generator(
217
+ checkpoint_path: str | Path | None = None,
218
+ config_path: str | Path | None = None,
219
+ device: str = "cpu",
220
+ ) -> WfloatGenerator:
221
+ return WfloatGenerator(
222
+ checkpoint_path=checkpoint_path,
223
+ config_path=config_path,
224
+ device=device,
225
+ )
src/wfloat_tts/processor.py ADDED
@@ -0,0 +1,130 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List
5
+
6
+ from .constants import DEFAULT_ESPEAK_VOICE, EMOTION_TO_SYMBOL, INTENSITY_SYMBOLS
7
+
8
+
9
+ @dataclass(frozen=True)
10
+ class PreparedInput:
11
+ text: str
12
+ phonemes: List[str]
13
+ token_ids: List[int]
14
+ emotion: str
15
+ intensity: float
16
+ emotion_symbol: str
17
+ intensity_symbol: str
18
+
19
+
20
+ def clamp_unit(value: float) -> float:
21
+ if value != value:
22
+ return 0.0
23
+
24
+ if value < 0.0:
25
+ return 0.0
26
+
27
+ if value > 1.0:
28
+ return 1.0
29
+
30
+ return float(value)
31
+
32
+
33
+ def load_token_map(config: dict[str, Any]) -> Dict[str, int]:
34
+ phoneme_id_map = config.get("phoneme_id_map")
35
+ if not isinstance(phoneme_id_map, dict):
36
+ raise KeyError("config.json is missing phoneme_id_map")
37
+
38
+ token_map: Dict[str, int] = {}
39
+
40
+ for symbol, raw_value in phoneme_id_map.items():
41
+ if isinstance(raw_value, int):
42
+ token_map[symbol] = raw_value
43
+ continue
44
+
45
+ if isinstance(raw_value, list) and len(raw_value) == 1:
46
+ token_map[symbol] = int(raw_value[0])
47
+ continue
48
+
49
+ raise ValueError(
50
+ f"Unsupported token mapping for symbol {symbol!r}: expected int or single-item list"
51
+ )
52
+
53
+ return token_map
54
+
55
+
56
+ def intensity_to_symbol(intensity: float) -> str:
57
+ value = clamp_unit(intensity)
58
+ idx = int(value * len(INTENSITY_SYMBOLS))
59
+ idx = max(0, min(idx, len(INTENSITY_SYMBOLS) - 1))
60
+ return INTENSITY_SYMBOLS[idx]
61
+
62
+
63
+ def normalize_emotion(emotion: str | None) -> str:
64
+ value = (emotion or "neutral").strip().lower()
65
+ if value not in EMOTION_TO_SYMBOL:
66
+ raise ValueError(
67
+ f"Unsupported emotion {emotion!r}. Expected one of: {', '.join(EMOTION_TO_SYMBOL)}"
68
+ )
69
+
70
+ return value
71
+
72
+
73
+ def phonemize_full_utterance(text: str, espeak_voice: str = DEFAULT_ESPEAK_VOICE) -> List[str]:
74
+ try:
75
+ from piper_phonemize import phonemize_espeak
76
+ except ImportError as exc:
77
+ raise ImportError(
78
+ "wfloat-tts requires piper-phonemize for phonemization. "
79
+ "Install it with: pip install \"piper-phonemize==1.3.0\" "
80
+ "-f https://k2-fsa.github.io/icefall/piper_phonemize"
81
+ ) from exc
82
+
83
+ sentence_groups = phonemize_espeak(text, espeak_voice)
84
+ phonemes: List[str] = []
85
+
86
+ for group in sentence_groups:
87
+ if not group:
88
+ continue
89
+
90
+ if phonemes:
91
+ phonemes.append(" ")
92
+
93
+ phonemes.extend(group)
94
+
95
+ return phonemes
96
+
97
+
98
+ def prepare_input(
99
+ text: str,
100
+ config: dict[str, Any],
101
+ emotion: str = "neutral",
102
+ intensity: float = 0.5,
103
+ espeak_voice: str = DEFAULT_ESPEAK_VOICE,
104
+ ) -> PreparedInput:
105
+ normalized_emotion = normalize_emotion(emotion)
106
+ normalized_intensity = clamp_unit(intensity)
107
+
108
+ phonemes = phonemize_full_utterance(text, espeak_voice=espeak_voice)
109
+ emotion_symbol = EMOTION_TO_SYMBOL[normalized_emotion]
110
+ intensity_symbol = intensity_to_symbol(normalized_intensity)
111
+ phonemes.extend([emotion_symbol, intensity_symbol])
112
+
113
+ token_map = load_token_map(config)
114
+
115
+ missing = [symbol for symbol in phonemes if symbol not in token_map]
116
+ if missing:
117
+ joined = ", ".join(sorted(set(missing)))
118
+ raise KeyError(f"Missing symbol(s) in config.json phoneme_id_map: {joined}")
119
+
120
+ token_ids = [token_map[symbol] for symbol in phonemes]
121
+
122
+ return PreparedInput(
123
+ text=text,
124
+ phonemes=phonemes,
125
+ token_ids=token_ids,
126
+ emotion=normalized_emotion,
127
+ intensity=normalized_intensity,
128
+ emotion_symbol=emotion_symbol,
129
+ intensity_symbol=intensity_symbol,
130
+ )
src/wfloat_tts/vits/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .models import SynthesizerTrn
+
+ __all__ = ["SynthesizerTrn"]
src/wfloat_tts/vits/attentions.py ADDED
@@ -0,0 +1,427 @@
1
+ import math
2
+ import typing
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from .commons import subsequent_mask
9
+ from .modules import LayerNorm
10
+
11
+
12
+ class Encoder(nn.Module):
13
+ def __init__(
14
+ self,
15
+ hidden_channels: int,
16
+ filter_channels: int,
17
+ n_heads: int,
18
+ n_layers: int,
19
+ kernel_size: int = 1,
20
+ p_dropout: float = 0.0,
21
+ window_size: int = 4,
22
+ **kwargs
23
+ ):
24
+ super().__init__()
25
+ self.hidden_channels = hidden_channels
26
+ self.filter_channels = filter_channels
27
+ self.n_heads = n_heads
28
+ self.n_layers = n_layers
29
+ self.kernel_size = kernel_size
30
+ self.p_dropout = p_dropout
31
+ self.window_size = window_size
32
+
33
+ self.drop = nn.Dropout(p_dropout)
34
+ self.attn_layers = nn.ModuleList()
35
+ self.norm_layers_1 = nn.ModuleList()
36
+ self.ffn_layers = nn.ModuleList()
37
+ self.norm_layers_2 = nn.ModuleList()
38
+ for i in range(self.n_layers):
39
+ self.attn_layers.append(
40
+ MultiHeadAttention(
41
+ hidden_channels,
42
+ hidden_channels,
43
+ n_heads,
44
+ p_dropout=p_dropout,
45
+ window_size=window_size,
46
+ )
47
+ )
48
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
49
+ self.ffn_layers.append(
50
+ FFN(
51
+ hidden_channels,
52
+ hidden_channels,
53
+ filter_channels,
54
+ kernel_size,
55
+ p_dropout=p_dropout,
56
+ )
57
+ )
58
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
59
+
60
+ def forward(self, x, x_mask):
61
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
62
+ x = x * x_mask
63
+ for attn_layer, norm_layer_1, ffn_layer, norm_layer_2 in zip(
64
+ self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2
65
+ ):
66
+ y = attn_layer(x, x, attn_mask)
67
+ y = self.drop(y)
68
+ x = norm_layer_1(x + y)
69
+
70
+ y = ffn_layer(x, x_mask)
71
+ y = self.drop(y)
72
+ x = norm_layer_2(x + y)
73
+ x = x * x_mask
74
+ return x
75
+
76
+
77
+ class Decoder(nn.Module):
78
+ def __init__(
79
+ self,
80
+ hidden_channels: int,
81
+ filter_channels: int,
82
+ n_heads: int,
83
+ n_layers: int,
84
+ kernel_size: int = 1,
85
+ p_dropout: float = 0.0,
86
+ proximal_bias: bool = False,
87
+ proximal_init: bool = True,
88
+ **kwargs
89
+ ):
90
+ super().__init__()
91
+ self.hidden_channels = hidden_channels
92
+ self.filter_channels = filter_channels
93
+ self.n_heads = n_heads
94
+ self.n_layers = n_layers
95
+ self.kernel_size = kernel_size
96
+ self.p_dropout = p_dropout
97
+ self.proximal_bias = proximal_bias
98
+ self.proximal_init = proximal_init
99
+
100
+ self.drop = nn.Dropout(p_dropout)
101
+ self.self_attn_layers = nn.ModuleList()
102
+ self.norm_layers_0 = nn.ModuleList()
103
+ self.encdec_attn_layers = nn.ModuleList()
104
+ self.norm_layers_1 = nn.ModuleList()
105
+ self.ffn_layers = nn.ModuleList()
106
+ self.norm_layers_2 = nn.ModuleList()
107
+ for i in range(self.n_layers):
108
+ self.self_attn_layers.append(
109
+ MultiHeadAttention(
110
+ hidden_channels,
111
+ hidden_channels,
112
+ n_heads,
113
+ p_dropout=p_dropout,
114
+ proximal_bias=proximal_bias,
115
+ proximal_init=proximal_init,
116
+ )
117
+ )
118
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
119
+ self.encdec_attn_layers.append(
120
+ MultiHeadAttention(
121
+ hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
122
+ )
123
+ )
124
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
125
+ self.ffn_layers.append(
126
+ FFN(
127
+ hidden_channels,
128
+ hidden_channels,
129
+ filter_channels,
130
+ kernel_size,
131
+ p_dropout=p_dropout,
132
+ causal=True,
133
+ )
134
+ )
135
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
136
+
137
+ def forward(self, x, x_mask, h, h_mask):
138
+ """
139
+ x: decoder input
140
+ h: encoder output
141
+ """
142
+ self_attn_mask = subsequent_mask(x_mask.size(2)).type_as(x)
143
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
144
+ x = x * x_mask
145
+ for i in range(self.n_layers):
146
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
147
+ y = self.drop(y)
148
+ x = self.norm_layers_0[i](x + y)
149
+
150
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
151
+ y = self.drop(y)
152
+ x = self.norm_layers_1[i](x + y)
153
+
154
+ y = self.ffn_layers[i](x, x_mask)
155
+ y = self.drop(y)
156
+ x = self.norm_layers_2[i](x + y)
157
+ x = x * x_mask
158
+ return x
159
+
160
+
161
+ class MultiHeadAttention(nn.Module):
162
+ def __init__(
163
+ self,
164
+ channels: int,
165
+ out_channels: int,
166
+ n_heads: int,
167
+ p_dropout: float = 0.0,
168
+ window_size: typing.Optional[int] = None,
169
+ heads_share: bool = True,
170
+ block_length: typing.Optional[int] = None,
171
+ proximal_bias: bool = False,
172
+ proximal_init: bool = False,
173
+ ):
174
+ super().__init__()
175
+ assert channels % n_heads == 0
176
+
177
+ self.channels = channels
178
+ self.out_channels = out_channels
179
+ self.n_heads = n_heads
180
+ self.p_dropout = p_dropout
181
+ self.window_size = window_size
182
+ self.heads_share = heads_share
183
+ self.block_length = block_length
184
+ self.proximal_bias = proximal_bias
185
+ self.proximal_init = proximal_init
186
+ self.attn = torch.zeros(1)
187
+
188
+ self.k_channels = channels // n_heads
189
+ self.conv_q = nn.Conv1d(channels, channels, 1)
190
+ self.conv_k = nn.Conv1d(channels, channels, 1)
191
+ self.conv_v = nn.Conv1d(channels, channels, 1)
192
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
193
+ self.drop = nn.Dropout(p_dropout)
194
+
195
+ if window_size is not None:
196
+ n_heads_rel = 1 if heads_share else n_heads
197
+ rel_stddev = self.k_channels**-0.5
198
+ self.emb_rel_k = nn.Parameter(
199
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
200
+ * rel_stddev
201
+ )
202
+ self.emb_rel_v = nn.Parameter(
203
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
204
+ * rel_stddev
205
+ )
206
+
207
+ nn.init.xavier_uniform_(self.conv_q.weight)
208
+ nn.init.xavier_uniform_(self.conv_k.weight)
209
+ nn.init.xavier_uniform_(self.conv_v.weight)
210
+ if proximal_init:
211
+ with torch.no_grad():
212
+ self.conv_k.weight.copy_(self.conv_q.weight)
213
+ self.conv_k.bias.copy_(self.conv_q.bias)
214
+
215
+ def forward(self, x, c, attn_mask=None):
216
+ q = self.conv_q(x)
217
+ k = self.conv_k(c)
218
+ v = self.conv_v(c)
219
+
220
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
221
+
222
+ x = self.conv_o(x)
223
+ return x
224
+
225
+ def attention(self, query, key, value, mask=None):
226
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
227
+ b, d, t_s, t_t = (key.size(0), key.size(1), key.size(2), query.size(2))
228
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
229
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
230
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
231
+
232
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
233
+ if self.window_size is not None:
234
+ assert (
235
+ t_s == t_t
236
+ ), "Relative attention is only available for self-attention."
237
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
238
+ rel_logits = self._matmul_with_relative_keys(
239
+ query / math.sqrt(self.k_channels), key_relative_embeddings
240
+ )
241
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
242
+ scores = scores + scores_local
243
+ if self.proximal_bias:
244
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
245
+ scores = scores + self._attention_bias_proximal(t_s).type_as(scores)
246
+ if mask is not None:
247
+ scores = scores.masked_fill(mask == 0, -1e4)
248
+ if self.block_length is not None:
249
+ assert (
250
+ t_s == t_t
251
+ ), "Local attention is only available for self-attention."
252
+ block_mask = (
253
+ torch.ones_like(scores)
254
+ .triu(-self.block_length)
255
+ .tril(self.block_length)
256
+ )
257
+ scores = scores.masked_fill(block_mask == 0, -1e4)
258
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
259
+ p_attn = self.drop(p_attn)
260
+ output = torch.matmul(p_attn, value)
261
+ if self.window_size is not None:
262
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
263
+ value_relative_embeddings = self._get_relative_embeddings(
264
+ self.emb_rel_v, t_s
265
+ )
266
+ output = output + self._matmul_with_relative_values(
267
+ relative_weights, value_relative_embeddings
268
+ )
269
+ output = (
270
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
271
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
272
+ return output, p_attn
273
+
274
+ def _matmul_with_relative_values(self, x, y):
275
+ """
276
+ x: [b, h, l, m]
277
+ y: [h or 1, m, d]
278
+ ret: [b, h, l, d]
279
+ """
280
+ ret = torch.matmul(x, y.unsqueeze(0))
281
+ return ret
282
+
283
+ def _matmul_with_relative_keys(self, x, y):
284
+ """
285
+ x: [b, h, l, d]
286
+ y: [h or 1, m, d]
287
+ ret: [b, h, l, m]
288
+ """
289
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
290
+ return ret
291
+
292
+ def _get_relative_embeddings(self, relative_embeddings, length: int):
293
+ # max_relative_position = 2 * self.window_size + 1
294
+ # Pad first before slice to avoid using cond ops.
295
+ pad_length = max(length - (self.window_size + 1), 0)
296
+ slice_start_position = max((self.window_size + 1) - length, 0)
297
+ slice_end_position = slice_start_position + 2 * length - 1
298
+ if pad_length > 0:
299
+ padded_relative_embeddings = F.pad(
300
+ relative_embeddings,
301
+ # convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
302
+ (0, 0, pad_length, pad_length, 0, 0),
303
+ )
304
+ else:
305
+ padded_relative_embeddings = relative_embeddings
306
+ used_relative_embeddings = padded_relative_embeddings[
307
+ :, slice_start_position:slice_end_position
308
+ ]
309
+ return used_relative_embeddings
310
+
311
+ def _relative_position_to_absolute_position(self, x):
312
+ """
313
+ x: [b, h, l, 2*l-1]
314
+ ret: [b, h, l, l]
315
+ """
316
+ batch, heads, length, _ = x.size()
317
+
318
+ # Concat columns of pad to shift from relative to absolute indexing.
319
+ # x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
320
+ x = F.pad(x, (0, 1, 0, 0, 0, 0, 0, 0))
321
+
322
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
323
+ x_flat = x.view([batch, heads, length * 2 * length])
324
+ # x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
325
+ x_flat = F.pad(x_flat, (0, length - 1, 0, 0, 0, 0))
326
+
327
+ # Reshape and slice out the padded elements.
328
+ x_final = x_flat.view([batch, heads, length + 1, (2 * length) - 1])[
329
+ :, :, :length, length - 1 :
330
+ ]
331
+ return x_final
332
+
333
+ def _absolute_position_to_relative_position(self, x):
334
+ """
335
+ x: [b, h, l, l]
336
+ ret: [b, h, l, 2*l-1]
337
+ """
338
+ batch, heads, length, _ = x.size()
339
+
340
+ # padd along column
341
+ # x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
342
+ x = F.pad(x, (0, length - 1, 0, 0, 0, 0, 0, 0))
343
+ x_flat = x.view([batch, heads, (length * length) + (length * (length - 1))])
344
+ # add 0's in the beginning that will skew the elements after reshape
345
+ # x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
346
+ x_flat = F.pad(x_flat, (length, 0, 0, 0, 0, 0))
347
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
348
+ return x_final
349
+
350
+ def _attention_bias_proximal(self, length: int):
351
+ """Bias for self-attention to encourage attention to close positions.
352
+ Args:
353
+ length: an integer scalar.
354
+ Returns:
355
+ a Tensor with shape [1, 1, length, length]
356
+ """
357
+ r = torch.arange(length, dtype=torch.float32)
358
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
359
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
360
+
361
+
362
+ class FFN(nn.Module):
363
+ def __init__(
364
+ self,
365
+ in_channels: int,
366
+ out_channels: int,
367
+ filter_channels: int,
368
+ kernel_size: int,
369
+ p_dropout: float = 0.0,
370
+ activation: str = "",
371
+ causal: bool = False,
372
+ ):
373
+ super().__init__()
374
+ self.in_channels = in_channels
375
+ self.out_channels = out_channels
376
+ self.filter_channels = filter_channels
377
+ self.kernel_size = kernel_size
378
+ self.p_dropout = p_dropout
379
+ self.activation = activation
380
+ self.causal = causal
381
+
382
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
383
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
384
+ self.drop = nn.Dropout(p_dropout)
385
+
386
+ def forward(self, x, x_mask):
387
+ if self.causal:
388
+ padding1 = self._causal_padding(x * x_mask)
389
+ else:
390
+ padding1 = self._same_padding(x * x_mask)
391
+
392
+ x = self.conv_1(padding1)
393
+
394
+ if self.activation == "gelu":
395
+ x = x * torch.sigmoid(1.702 * x)
396
+ else:
397
+ x = torch.relu(x)
398
+ x = self.drop(x)
399
+
400
+ if self.causal:
401
+ padding2 = self._causal_padding(x * x_mask)
402
+ else:
403
+ padding2 = self._same_padding(x * x_mask)
404
+
405
+ x = self.conv_2(padding2)
406
+
407
+ return x * x_mask
408
+
409
+ def _causal_padding(self, x):
410
+ if self.kernel_size == 1:
411
+ return x
412
+ pad_l = self.kernel_size - 1
413
+ pad_r = 0
414
+ # padding = [[0, 0], [0, 0], [pad_l, pad_r]]
415
+ # x = F.pad(x, convert_pad_shape(padding))
416
+ x = F.pad(x, (pad_l, pad_r, 0, 0, 0, 0))
417
+ return x
418
+
419
+ def _same_padding(self, x):
420
+ if self.kernel_size == 1:
421
+ return x
422
+ pad_l = (self.kernel_size - 1) // 2
423
+ pad_r = self.kernel_size // 2
424
+ # padding = [[0, 0], [0, 0], [pad_l, pad_r]]
425
+ # x = F.pad(x, convert_pad_shape(padding))
426
+ x = F.pad(x, (pad_l, pad_r, 0, 0, 0, 0))
427
+ return x
src/wfloat_tts/vits/commons.py ADDED
@@ -0,0 +1,147 @@
1
+ import logging
2
+ import math
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from torch.nn import functional as F
7
+
8
+ _LOGGER = logging.getLogger("vits.commons")
9
+
10
+
11
+ def init_weights(m, mean=0.0, std=0.01):
12
+ classname = m.__class__.__name__
13
+ if classname.find("Conv") != -1:
14
+ m.weight.data.normal_(mean, std)
15
+
16
+
17
+ def get_padding(kernel_size, dilation=1):
18
+ return int((kernel_size * dilation - dilation) / 2)
19
+
20
+
21
+ def intersperse(lst, item):
22
+ result = [item] * (len(lst) * 2 + 1)
23
+ result[1::2] = lst
24
+ return result
25
+
26
+
27
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
28
+ """KL(P||Q)"""
29
+ kl = (logs_q - logs_p) - 0.5
30
+ kl += (
31
+ 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
32
+ )
33
+ return kl
34
+
35
+
36
+ def rand_gumbel(shape):
37
+ """Sample from the Gumbel distribution, protect from overflows."""
38
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
39
+ return -torch.log(-torch.log(uniform_samples))
40
+
41
+
42
+ def rand_gumbel_like(x):
43
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
44
+ return g
45
+
46
+
47
+ def slice_segments(x, ids_str, segment_size=4):
48
+ ret = torch.zeros_like(x[:, :, :segment_size])
49
+ for i in range(x.size(0)):
50
+ idx_str = max(0, ids_str[i])
51
+ idx_end = idx_str + segment_size
52
+ ret[i] = x[i, :, idx_str:idx_end]
53
+ return ret
54
+
55
+
56
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
57
+ b, d, t = x.size()
58
+ if x_lengths is None:
59
+ x_lengths = t
60
+ ids_str_max = x_lengths - segment_size + 1
61
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
62
+ ret = slice_segments(x, ids_str, segment_size)
63
+ return ret, ids_str
64
+
65
+
66
+ def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
67
+ position = torch.arange(length, dtype=torch.float)
68
+ num_timescales = channels // 2
69
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
70
+ num_timescales - 1
71
+ )
72
+ inv_timescales = min_timescale * torch.exp(
73
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
74
+ )
75
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
76
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
77
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
78
+ signal = signal.view(1, channels, length)
79
+ return signal
80
+
81
+
82
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
83
+ b, channels, length = x.size()
84
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
85
+ return x + signal.to(dtype=x.dtype, device=x.device)
86
+
87
+
88
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
89
+ b, channels, length = x.size()
90
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
91
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
92
+
93
+
94
+ def subsequent_mask(length: int):
95
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
96
+ return mask
97
+
98
+
99
+ @torch.jit.script
100
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
101
+ n_channels_int = n_channels[0]
102
+ in_act = input_a + input_b
103
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
104
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
105
+ acts = t_act * s_act
106
+ return acts
107
+
108
+
109
+ def sequence_mask(length, max_length: Optional[int] = None):
110
+ if max_length is None:
111
+ max_length = length.max()
112
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
113
+ return x.unsqueeze(0) < length.unsqueeze(1)
114
+
115
+
116
+ def generate_path(duration, mask):
117
+ """
118
+ duration: [b, 1, t_x]
119
+ mask: [b, 1, t_y, t_x]
120
+ """
121
+ b, _, t_y, t_x = mask.shape
122
+ cum_duration = torch.cumsum(duration, -1)
123
+
124
+ cum_duration_flat = cum_duration.view(b * t_x)
125
+ path = sequence_mask(cum_duration_flat, t_y).type_as(mask)
126
+ path = path.view(b, t_x, t_y)
127
+ path = path - F.pad(path, (0, 0, 1, 0, 0, 0))[:, :-1]
128
+ path = path.unsqueeze(1).transpose(2, 3) * mask
129
+ return path
130
+
131
+
132
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
133
+ if isinstance(parameters, torch.Tensor):
134
+ parameters = [parameters]
135
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
136
+ norm_type = float(norm_type)
137
+ if clip_value is not None:
138
+ clip_value = float(clip_value)
139
+
140
+ total_norm = 0
141
+ for p in parameters:
142
+ param_norm = p.grad.data.norm(norm_type)
143
+ total_norm += param_norm.item() ** norm_type
144
+ if clip_value is not None:
145
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
146
+ total_norm = total_norm ** (1.0 / norm_type)
147
+ return total_norm
src/wfloat_tts/vits/models.py ADDED
@@ -0,0 +1,670 @@
1
+ import math
2
+ import typing
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import Conv1d, Conv2d, ConvTranspose1d
7
+ from torch.nn import functional as F
8
+ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
9
+
10
+ from . import attentions, commons, modules
11
+ from .commons import get_padding, init_weights
12
+
13
+
14
+ class StochasticDurationPredictor(nn.Module):
15
+ def __init__(
16
+ self,
17
+ in_channels: int,
18
+ filter_channels: int,
19
+ kernel_size: int,
20
+ p_dropout: float,
21
+ n_flows: int = 4,
22
+ gin_channels: int = 0,
23
+ ):
24
+ super().__init__()
25
+ filter_channels = in_channels # it needs to be removed from future version.
26
+ self.in_channels = in_channels
27
+ self.filter_channels = filter_channels
28
+ self.kernel_size = kernel_size
29
+ self.p_dropout = p_dropout
30
+ self.n_flows = n_flows
31
+ self.gin_channels = gin_channels
32
+
33
+ self.log_flow = modules.Log()
34
+ self.flows = nn.ModuleList()
35
+ self.flows.append(modules.ElementwiseAffine(2))
36
+ for i in range(n_flows):
37
+ self.flows.append(
38
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
39
+ )
40
+ self.flows.append(modules.Flip())
41
+
42
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
43
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
44
+ self.post_convs = modules.DDSConv(
45
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
46
+ )
47
+ self.post_flows = nn.ModuleList()
48
+ self.post_flows.append(modules.ElementwiseAffine(2))
49
+ for i in range(4):
50
+ self.post_flows.append(
51
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
52
+ )
53
+ self.post_flows.append(modules.Flip())
54
+
55
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
56
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
57
+ self.convs = modules.DDSConv(
58
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
59
+ )
60
+ if gin_channels != 0:
61
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
62
+
63
+ def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
64
+ x = torch.detach(x)
65
+ x = self.pre(x)
66
+ if g is not None:
67
+ g = torch.detach(g)
68
+ x = x + self.cond(g)
69
+ x = self.convs(x, x_mask)
70
+ x = self.proj(x) * x_mask
71
+
72
+ if not reverse:
73
+ flows = self.flows
74
+ assert w is not None
75
+
76
+ logdet_tot_q = 0
77
+ h_w = self.post_pre(w)
78
+ h_w = self.post_convs(h_w, x_mask)
79
+ h_w = self.post_proj(h_w) * x_mask
80
+ e_q = torch.randn(w.size(0), 2, w.size(2)).type_as(x) * x_mask
81
+ z_q = e_q
82
+ for flow in self.post_flows:
83
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
84
+ logdet_tot_q += logdet_q
85
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
86
+ u = torch.sigmoid(z_u) * x_mask
87
+ z0 = (w - u) * x_mask
88
+ logdet_tot_q += torch.sum(
89
+ (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
90
+ )
91
+ logq = (
92
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
93
+ - logdet_tot_q
94
+ )
95
+
96
+ logdet_tot = 0
97
+ z0, logdet = self.log_flow(z0, x_mask)
98
+ logdet_tot += logdet
99
+ z = torch.cat([z0, z1], 1)
100
+ for flow in flows:
101
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
102
+ logdet_tot = logdet_tot + logdet
103
+ nll = (
104
+ torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
105
+ - logdet_tot
106
+ )
107
+ return nll + logq # [b]
108
+ else:
109
+ flows = list(reversed(self.flows))
110
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
111
+ z = torch.randn(x.size(0), 2, x.size(2)).type_as(x) * noise_scale
112
+
113
+ for flow in flows:
114
+ z = flow(z, x_mask, g=x, reverse=reverse)
115
+ z0, z1 = torch.split(z, [1, 1], 1)
116
+ logw = z0
117
+ return logw
118
+
119
+
120
+ class DurationPredictor(nn.Module):
121
+ def __init__(
122
+ self,
123
+ in_channels: int,
124
+ filter_channels: int,
125
+ kernel_size: int,
126
+ p_dropout: float,
127
+ gin_channels: int = 0,
128
+ ):
129
+ super().__init__()
130
+
131
+ self.in_channels = in_channels
132
+ self.filter_channels = filter_channels
133
+ self.kernel_size = kernel_size
134
+ self.p_dropout = p_dropout
135
+ self.gin_channels = gin_channels
136
+
137
+ self.drop = nn.Dropout(p_dropout)
138
+ self.conv_1 = nn.Conv1d(
139
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
140
+ )
141
+ self.norm_1 = modules.LayerNorm(filter_channels)
142
+ self.conv_2 = nn.Conv1d(
143
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
144
+ )
145
+ self.norm_2 = modules.LayerNorm(filter_channels)
146
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
147
+
148
+ if gin_channels != 0:
149
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
150
+
151
+ def forward(self, x, x_mask, g=None):
152
+ x = torch.detach(x)
153
+ if g is not None:
154
+ g = torch.detach(g)
155
+ x = x + self.cond(g)
156
+ x = self.conv_1(x * x_mask)
157
+ x = torch.relu(x)
158
+ x = self.norm_1(x)
159
+ x = self.drop(x)
160
+ x = self.conv_2(x * x_mask)
161
+ x = torch.relu(x)
162
+ x = self.norm_2(x)
163
+ x = self.drop(x)
164
+ x = self.proj(x * x_mask)
165
+ return x * x_mask
166
+
167
+
168
+ class TextEncoder(nn.Module):
169
+ def __init__(
170
+ self,
171
+ n_vocab: int,
172
+ out_channels: int,
173
+ hidden_channels: int,
174
+ filter_channels: int,
175
+ n_heads: int,
176
+ n_layers: int,
177
+ kernel_size: int,
178
+ p_dropout: float,
179
+ ):
180
+ super().__init__()
181
+ self.n_vocab = n_vocab
182
+ self.out_channels = out_channels
183
+ self.hidden_channels = hidden_channels
184
+ self.filter_channels = filter_channels
185
+ self.n_heads = n_heads
186
+ self.n_layers = n_layers
187
+ self.kernel_size = kernel_size
188
+ self.p_dropout = p_dropout
189
+
190
+ self.emb = nn.Embedding(n_vocab, hidden_channels)
191
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
192
+
193
+ self.encoder = attentions.Encoder(
194
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
195
+ )
196
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
197
+
198
+ def forward(self, x, x_lengths):
199
+ x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
200
+ x = torch.transpose(x, 1, -1) # [b, h, t]
201
+ x_mask = torch.unsqueeze(
202
+ commons.sequence_mask(x_lengths, x.size(2)), 1
203
+ ).type_as(x)
204
+
205
+ x = self.encoder(x * x_mask, x_mask)
206
+ stats = self.proj(x) * x_mask
207
+
208
+ m, logs = torch.split(stats, self.out_channels, dim=1)
209
+ return x, m, logs, x_mask
210
+
211
+
212
+ class ResidualCouplingBlock(nn.Module):
213
+ def __init__(
214
+ self,
215
+ channels: int,
216
+ hidden_channels: int,
217
+ kernel_size: int,
218
+ dilation_rate: int,
219
+ n_layers: int,
220
+ n_flows: int = 4,
221
+ gin_channels: int = 0,
222
+ ):
223
+ super().__init__()
224
+ self.channels = channels
225
+ self.hidden_channels = hidden_channels
226
+ self.kernel_size = kernel_size
227
+ self.dilation_rate = dilation_rate
228
+ self.n_layers = n_layers
229
+ self.n_flows = n_flows
230
+ self.gin_channels = gin_channels
231
+
232
+ self.flows = nn.ModuleList()
233
+ for i in range(n_flows):
234
+ self.flows.append(
235
+ modules.ResidualCouplingLayer(
236
+ channels,
237
+ hidden_channels,
238
+ kernel_size,
239
+ dilation_rate,
240
+ n_layers,
241
+ gin_channels=gin_channels,
242
+ mean_only=True,
243
+ )
244
+ )
245
+ self.flows.append(modules.Flip())
246
+
247
+ def forward(self, x, x_mask, g=None, reverse=False):
248
+ if not reverse:
249
+ for flow in self.flows:
250
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
251
+ else:
252
+ for flow in reversed(self.flows):
253
+ x = flow(x, x_mask, g=g, reverse=reverse)
254
+ return x
255
+
256
+
257
+ class PosteriorEncoder(nn.Module):
258
+ def __init__(
259
+ self,
260
+ in_channels: int,
261
+ out_channels: int,
262
+ hidden_channels: int,
263
+ kernel_size: int,
264
+ dilation_rate: int,
265
+ n_layers: int,
266
+ gin_channels: int = 0,
267
+ ):
268
+ super().__init__()
269
+ self.in_channels = in_channels
270
+ self.out_channels = out_channels
271
+ self.hidden_channels = hidden_channels
272
+ self.kernel_size = kernel_size
273
+ self.dilation_rate = dilation_rate
274
+ self.n_layers = n_layers
275
+ self.gin_channels = gin_channels
276
+
277
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
278
+ self.enc = modules.WN(
279
+ hidden_channels,
280
+ kernel_size,
281
+ dilation_rate,
282
+ n_layers,
283
+ gin_channels=gin_channels,
284
+ )
285
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
286
+
287
+ def forward(self, x, x_lengths, g=None):
288
+ x_mask = torch.unsqueeze(
289
+ commons.sequence_mask(x_lengths, x.size(2)), 1
290
+ ).type_as(x)
291
+ x = self.pre(x) * x_mask
292
+ x = self.enc(x, x_mask, g=g)
293
+ stats = self.proj(x) * x_mask
294
+ m, logs = torch.split(stats, self.out_channels, dim=1)
295
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
296
+ return z, m, logs, x_mask
297
+
298
+
299
+ class Generator(torch.nn.Module):
300
+ def __init__(
301
+ self,
302
+ initial_channel: int,
303
+ resblock: typing.Optional[str],
304
+ resblock_kernel_sizes: typing.Tuple[int, ...],
305
+ resblock_dilation_sizes: typing.Tuple[typing.Tuple[int, ...], ...],
306
+ upsample_rates: typing.Tuple[int, ...],
307
+ upsample_initial_channel: int,
308
+ upsample_kernel_sizes: typing.Tuple[int, ...],
309
+ gin_channels: int = 0,
310
+ ):
311
+ super(Generator, self).__init__()
312
+ self.LRELU_SLOPE = 0.1
313
+ self.num_kernels = len(resblock_kernel_sizes)
314
+ self.num_upsamples = len(upsample_rates)
315
+ self.conv_pre = Conv1d(
316
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
317
+ )
318
+ resblock_module = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
319
+
320
+ self.ups = nn.ModuleList()
321
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
322
+ self.ups.append(
323
+ weight_norm(
324
+ ConvTranspose1d(
325
+ upsample_initial_channel // (2**i),
326
+ upsample_initial_channel // (2 ** (i + 1)),
327
+ k,
328
+ u,
329
+ padding=(k - u) // 2,
330
+ )
331
+ )
332
+ )
333
+
334
+ self.resblocks = nn.ModuleList()
335
+ for i in range(len(self.ups)):
336
+ ch = upsample_initial_channel // (2 ** (i + 1))
337
+ for j, (k, d) in enumerate(
338
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
339
+ ):
340
+ self.resblocks.append(resblock_module(ch, k, d))
341
+
342
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
343
+ self.ups.apply(init_weights)
344
+
345
+ if gin_channels != 0:
346
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
347
+
348
+ def forward(self, x, g=None):
349
+ x = self.conv_pre(x)
350
+ if g is not None:
351
+ x = x + self.cond(g)
352
+
353
+ for i, up in enumerate(self.ups):
354
+ x = F.leaky_relu(x, self.LRELU_SLOPE)
355
+ x = up(x)
356
+ xs = torch.zeros(1)  # placeholder; replaced by the first resblock output for this stage
357
+ for j, resblock in enumerate(self.resblocks):
358
+ index = j - (i * self.num_kernels)
359
+ if index == 0:
360
+ xs = resblock(x)
361
+ elif (index > 0) and (index < self.num_kernels):
362
+ xs += resblock(x)
363
+ x = xs / self.num_kernels
364
+ x = F.leaky_relu(x)
365
+ x = self.conv_post(x)
366
+ x = torch.tanh(x)
367
+
368
+ return x
369
+
370
+ def remove_weight_norm(self):
371
+ print("Removing weight norm...")
372
+ for l in self.ups:
373
+ remove_weight_norm(l)
374
+ for l in self.resblocks:
375
+ l.remove_weight_norm()
376
+
377
+
378
+ class DiscriminatorP(torch.nn.Module):
379
+ def __init__(
380
+ self,
381
+ period: int,
382
+ kernel_size: int = 5,
383
+ stride: int = 3,
384
+ use_spectral_norm: bool = False,
385
+ ):
386
+ super(DiscriminatorP, self).__init__()
387
+ self.LRELU_SLOPE = 0.1
388
+ self.period = period
389
+ self.use_spectral_norm = use_spectral_norm
390
+ norm_f = weight_norm if not use_spectral_norm else spectral_norm
391
+ self.convs = nn.ModuleList(
392
+ [
393
+ norm_f(
394
+ Conv2d(
395
+ 1,
396
+ 32,
397
+ (kernel_size, 1),
398
+ (stride, 1),
399
+ padding=(get_padding(kernel_size, 1), 0),
400
+ )
401
+ ),
402
+ norm_f(
403
+ Conv2d(
404
+ 32,
405
+ 128,
406
+ (kernel_size, 1),
407
+ (stride, 1),
408
+ padding=(get_padding(kernel_size, 1), 0),
409
+ )
410
+ ),
411
+ norm_f(
412
+ Conv2d(
413
+ 128,
414
+ 512,
415
+ (kernel_size, 1),
416
+ (stride, 1),
417
+ padding=(get_padding(kernel_size, 1), 0),
418
+ )
419
+ ),
420
+ norm_f(
421
+ Conv2d(
422
+ 512,
423
+ 1024,
424
+ (kernel_size, 1),
425
+ (stride, 1),
426
+ padding=(get_padding(kernel_size, 1), 0),
427
+ )
428
+ ),
429
+ norm_f(
430
+ Conv2d(
431
+ 1024,
432
+ 1024,
433
+ (kernel_size, 1),
434
+ 1,
435
+ padding=(get_padding(kernel_size, 1), 0),
436
+ )
437
+ ),
438
+ ]
439
+ )
440
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
441
+
442
+ def forward(self, x):
443
+ fmap = []
444
+
445
+ # 1d to 2d
446
+ b, c, t = x.shape
447
+ if t % self.period != 0: # pad first
448
+ n_pad = self.period - (t % self.period)
449
+ x = F.pad(x, (0, n_pad), "reflect")
450
+ t = t + n_pad
451
+ x = x.view(b, c, t // self.period, self.period)
452
+
453
+ for l in self.convs:
454
+ x = l(x)
455
+ x = F.leaky_relu(x, self.LRELU_SLOPE)
456
+ fmap.append(x)
457
+ x = self.conv_post(x)
458
+ fmap.append(x)
459
+ x = torch.flatten(x, 1, -1)
460
+
461
+ return x, fmap
462
+
463
+
464
+ class DiscriminatorS(torch.nn.Module):
465
+ def __init__(self, use_spectral_norm=False):
466
+ super(DiscriminatorS, self).__init__()
467
+ self.LRELU_SLOPE = 0.1
468
+ norm_f = spectral_norm if use_spectral_norm else weight_norm
469
+ self.convs = nn.ModuleList(
470
+ [
471
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
472
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
473
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
474
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
475
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
476
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
477
+ ]
478
+ )
479
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
480
+
481
+ def forward(self, x):
482
+ fmap = []
483
+
484
+ for l in self.convs:
485
+ x = l(x)
486
+ x = F.leaky_relu(x, self.LRELU_SLOPE)
487
+ fmap.append(x)
488
+ x = self.conv_post(x)
489
+ fmap.append(x)
490
+ x = torch.flatten(x, 1, -1)
491
+
492
+ return x, fmap
493
+
494
+
495
+ class MultiPeriodDiscriminator(torch.nn.Module):
496
+ def __init__(self, use_spectral_norm=False):
497
+ super(MultiPeriodDiscriminator, self).__init__()
498
+ periods = [2, 3, 5, 7, 11]
499
+
500
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
501
+ discs = discs + [
502
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
503
+ ]
504
+ self.discriminators = nn.ModuleList(discs)
505
+
506
+ def forward(self, y, y_hat):
507
+ y_d_rs = []
508
+ y_d_gs = []
509
+ fmap_rs = []
510
+ fmap_gs = []
511
+ for i, d in enumerate(self.discriminators):
512
+ y_d_r, fmap_r = d(y)
513
+ y_d_g, fmap_g = d(y_hat)
514
+ y_d_rs.append(y_d_r)
515
+ y_d_gs.append(y_d_g)
516
+ fmap_rs.append(fmap_r)
517
+ fmap_gs.append(fmap_g)
518
+
519
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
520
+
521
+
522
+ class SynthesizerTrn(nn.Module):
523
+ """
524
+ VITS synthesizer. This vendored copy exposes inference (infer) only.
525
+ """
526
+
527
+ def __init__(
528
+ self,
529
+ n_vocab: int,
530
+ spec_channels: int,
531
+ segment_size: int,
532
+ inter_channels: int,
533
+ hidden_channels: int,
534
+ filter_channels: int,
535
+ n_heads: int,
536
+ n_layers: int,
537
+ kernel_size: int,
538
+ p_dropout: float,
539
+ resblock: str,
540
+ resblock_kernel_sizes: typing.Tuple[int, ...],
541
+ resblock_dilation_sizes: typing.Tuple[typing.Tuple[int, ...], ...],
542
+ upsample_rates: typing.Tuple[int, ...],
543
+ upsample_initial_channel: int,
544
+ upsample_kernel_sizes: typing.Tuple[int, ...],
545
+ n_speakers: int = 1,
546
+ gin_channels: int = 0,
547
+ use_sdp: bool = True,
548
+ ):
549
+
550
+ super().__init__()
551
+ self.n_vocab = n_vocab
552
+ self.spec_channels = spec_channels
553
+ self.inter_channels = inter_channels
554
+ self.hidden_channels = hidden_channels
555
+ self.filter_channels = filter_channels
556
+ self.n_heads = n_heads
557
+ self.n_layers = n_layers
558
+ self.kernel_size = kernel_size
559
+ self.p_dropout = p_dropout
560
+ self.resblock = resblock
561
+ self.resblock_kernel_sizes = resblock_kernel_sizes
562
+ self.resblock_dilation_sizes = resblock_dilation_sizes
563
+ self.upsample_rates = upsample_rates
564
+ self.upsample_initial_channel = upsample_initial_channel
565
+ self.upsample_kernel_sizes = upsample_kernel_sizes
566
+ self.segment_size = segment_size
567
+ self.n_speakers = n_speakers
568
+ self.gin_channels = gin_channels
569
+
570
+ self.use_sdp = use_sdp
571
+
572
+ self.enc_p = TextEncoder(
573
+ n_vocab,
574
+ inter_channels,
575
+ hidden_channels,
576
+ filter_channels,
577
+ n_heads,
578
+ n_layers,
579
+ kernel_size,
580
+ p_dropout,
581
+ )
582
+ self.dec = Generator(
583
+ inter_channels,
584
+ resblock,
585
+ resblock_kernel_sizes,
586
+ resblock_dilation_sizes,
587
+ upsample_rates,
588
+ upsample_initial_channel,
589
+ upsample_kernel_sizes,
590
+ gin_channels=gin_channels,
591
+ )
592
+ self.enc_q = PosteriorEncoder(
593
+ spec_channels,
594
+ inter_channels,
595
+ hidden_channels,
596
+ 5,
597
+ 1,
598
+ 16,
599
+ gin_channels=gin_channels,
600
+ )
601
+ self.flow = ResidualCouplingBlock(
602
+ inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
603
+ )
604
+
605
+ if use_sdp:
606
+ self.dp = StochasticDurationPredictor(
607
+ hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
608
+ )
609
+ else:
610
+ self.dp = DurationPredictor(
611
+ hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
612
+ )
613
+
614
+ if n_speakers > 1:
615
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
616
+
617
+ def forward(self, x, x_lengths, y, y_lengths, sid=None):
618
+ raise NotImplementedError(
619
+ "wfloat-tts vendors an inference-only VITS runtime. "
620
+ "Training forward() is intentionally not included."
621
+ )
622
+
623
+ def infer(
624
+ self,
625
+ x,
626
+ x_lengths,
627
+ sid=None,
628
+ noise_scale=0.667,
629
+ length_scale=1,
630
+ noise_scale_w=0.8,
631
+ max_len=None,
632
+ ):
633
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
634
+ if self.n_speakers > 1:
635
+ assert sid is not None, "Missing speaker id"
636
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
637
+ else:
638
+ g = None
639
+
640
+ if self.use_sdp:
641
+ logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
642
+ else:
643
+ logw = self.dp(x, x_mask, g=g)
644
+ w = torch.exp(logw) * x_mask * length_scale
645
+ w_ceil = torch.ceil(w)
646
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
647
+ y_mask = torch.unsqueeze(
648
+ commons.sequence_mask(y_lengths, y_lengths.max()), 1
649
+ ).type_as(x_mask)
650
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
651
+ attn = commons.generate_path(w_ceil, attn_mask)
652
+
653
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
654
+ 1, 2
655
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
656
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
657
+ 1, 2
658
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
659
+
660
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
661
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
662
+ o = self.dec((z * y_mask)[:, :, :max_len], g=g)
663
+
664
+ return o, attn, y_mask, (z, z_p, m_p, logs_p)
665
+
666
+ def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
667
+ raise NotImplementedError(
668
+ "wfloat-tts ships text-to-speech inference only. "
669
+ "Voice conversion is not part of this runtime."
670
+ )
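
For orientation, here is a minimal sketch of calling `SynthesizerTrn.infer` directly. Every hyperparameter below is an illustrative placeholder rather than the shipped configuration (the real values live in `config.json`), the phoneme ids are dummies rather than `piper-phonemize` output, and the import path `wfloat_tts.vits.models` is assumed from the package layout; the supported entry point remains `load_generator` from the README.

```python
import torch

from wfloat_tts.vits.models import SynthesizerTrn  # assumed module path

# Placeholder hyperparameters for illustration only; real values come from config.json.
net_g = SynthesizerTrn(
    n_vocab=256,
    spec_channels=513,
    segment_size=32,
    inter_channels=192,
    hidden_channels=192,
    filter_channels=768,
    n_heads=2,
    n_layers=6,
    kernel_size=3,
    p_dropout=0.1,
    resblock="1",
    resblock_kernel_sizes=(3, 7, 11),
    resblock_dilation_sizes=((1, 3, 5), (1, 3, 5), (1, 3, 5)),
    upsample_rates=(8, 8, 2, 2),
    upsample_initial_channel=512,
    upsample_kernel_sizes=(16, 16, 4, 4),
    n_speakers=1,
    gin_channels=0,
).eval()  # weights are random here; real inference loads model.safetensors first

phoneme_ids = torch.tensor([[1, 5, 9, 3, 2]], dtype=torch.long)  # dummy ids
lengths = torch.tensor([phoneme_ids.size(1)], dtype=torch.long)

with torch.no_grad():
    audio, attn, y_mask, _ = net_g.infer(
        phoneme_ids, lengths, noise_scale=0.667, length_scale=1.0, noise_scale_w=0.8
    )

print(audio.shape)  # [1, 1, T] raw waveform
```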
src/wfloat_tts/vits/modules.py ADDED
@@ -0,0 +1,527 @@
1
+ import math
2
+ import typing
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import Conv1d
7
+ from torch.nn import functional as F
8
+ from torch.nn.utils import remove_weight_norm, weight_norm
9
+
10
+ from .commons import fused_add_tanh_sigmoid_multiply, get_padding, init_weights
11
+ from .transforms import piecewise_rational_quadratic_transform
12
+
13
+
14
+ class LayerNorm(nn.Module):
15
+ def __init__(self, channels: int, eps: float = 1e-5):
16
+ super().__init__()
17
+ self.channels = channels
18
+ self.eps = eps
19
+
20
+ self.gamma = nn.Parameter(torch.ones(channels))
21
+ self.beta = nn.Parameter(torch.zeros(channels))
22
+
23
+ def forward(self, x):
24
+ x = x.transpose(1, -1)
25
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
26
+ return x.transpose(1, -1)
27
+
28
+
29
+ class ConvReluNorm(nn.Module):
30
+ def __init__(
31
+ self,
32
+ in_channels: int,
33
+ hidden_channels: int,
34
+ out_channels: int,
35
+ kernel_size: int,
36
+ n_layers: int,
37
+ p_dropout: float,
38
+ ):
39
+ super().__init__()
40
+ self.in_channels = in_channels
41
+ self.hidden_channels = hidden_channels
42
+ self.out_channels = out_channels
43
+ self.kernel_size = kernel_size
44
+ self.n_layers = n_layers
45
+ self.p_dropout = p_dropout
46
+ assert n_layers > 1, "Number of layers should be larger than 1."
47
+
48
+ self.conv_layers = nn.ModuleList()
49
+ self.norm_layers = nn.ModuleList()
50
+ self.conv_layers.append(
51
+ nn.Conv1d(
52
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
53
+ )
54
+ )
55
+ self.norm_layers.append(LayerNorm(hidden_channels))
56
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
57
+ for _ in range(n_layers - 1):
58
+ self.conv_layers.append(
59
+ nn.Conv1d(
60
+ hidden_channels,
61
+ hidden_channels,
62
+ kernel_size,
63
+ padding=kernel_size // 2,
64
+ )
65
+ )
66
+ self.norm_layers.append(LayerNorm(hidden_channels))
67
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
68
+ self.proj.weight.data.zero_()
69
+ self.proj.bias.data.zero_()
70
+
71
+ def forward(self, x, x_mask):
72
+ x_org = x
73
+ for i in range(self.n_layers):
74
+ x = self.conv_layers[i](x * x_mask)
75
+ x = self.norm_layers[i](x)
76
+ x = self.relu_drop(x)
77
+ x = x_org + self.proj(x)
78
+ return x * x_mask
79
+
80
+
81
+ class DDSConv(nn.Module):
82
+ """
83
+ Dilated and Depth-Separable Convolution
84
+ """
85
+
86
+ def __init__(
87
+ self, channels: int, kernel_size: int, n_layers: int, p_dropout: float = 0.0
88
+ ):
89
+ super().__init__()
90
+ self.channels = channels
91
+ self.kernel_size = kernel_size
92
+ self.n_layers = n_layers
93
+ self.p_dropout = p_dropout
94
+
95
+ self.drop = nn.Dropout(p_dropout)
96
+ self.convs_sep = nn.ModuleList()
97
+ self.convs_1x1 = nn.ModuleList()
98
+ self.norms_1 = nn.ModuleList()
99
+ self.norms_2 = nn.ModuleList()
100
+ for i in range(n_layers):
101
+ dilation = kernel_size**i
102
+ padding = (kernel_size * dilation - dilation) // 2
103
+ self.convs_sep.append(
104
+ nn.Conv1d(
105
+ channels,
106
+ channels,
107
+ kernel_size,
108
+ groups=channels,
109
+ dilation=dilation,
110
+ padding=padding,
111
+ )
112
+ )
113
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
114
+ self.norms_1.append(LayerNorm(channels))
115
+ self.norms_2.append(LayerNorm(channels))
116
+
117
+ def forward(self, x, x_mask, g=None):
118
+ if g is not None:
119
+ x = x + g
120
+ for i in range(self.n_layers):
121
+ y = self.convs_sep[i](x * x_mask)
122
+ y = self.norms_1[i](y)
123
+ y = F.gelu(y)
124
+ y = self.convs_1x1[i](y)
125
+ y = self.norms_2[i](y)
126
+ y = F.gelu(y)
127
+ y = self.drop(y)
128
+ x = x + y
129
+ return x * x_mask
130
+
131
+
132
+ class WN(torch.nn.Module):
133
+ def __init__(
134
+ self,
135
+ hidden_channels: int,
136
+ kernel_size: int,
137
+ dilation_rate: int,
138
+ n_layers: int,
139
+ gin_channels: int = 0,
140
+ p_dropout: float = 0,
141
+ ):
142
+ super().__init__()
143
+ assert kernel_size % 2 == 1
144
+ self.hidden_channels = hidden_channels
145
+ self.kernel_size = kernel_size
146
+ self.dilation_rate = dilation_rate
147
+ self.n_layers = n_layers
148
+ self.gin_channels = gin_channels
149
+ self.p_dropout = p_dropout
150
+
151
+ self.in_layers = torch.nn.ModuleList()
152
+ self.res_skip_layers = torch.nn.ModuleList()
153
+ self.drop = nn.Dropout(p_dropout)
154
+
155
+ if gin_channels != 0:
156
+ cond_layer = torch.nn.Conv1d(
157
+ gin_channels, 2 * hidden_channels * n_layers, 1
158
+ )
159
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
160
+
161
+ for i in range(n_layers):
162
+ dilation = dilation_rate**i
163
+ padding = int((kernel_size * dilation - dilation) / 2)
164
+ in_layer = torch.nn.Conv1d(
165
+ hidden_channels,
166
+ 2 * hidden_channels,
167
+ kernel_size,
168
+ dilation=dilation,
169
+ padding=padding,
170
+ )
171
+ in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
172
+ self.in_layers.append(in_layer)
173
+
174
+ # last one is not necessary
175
+ if i < n_layers - 1:
176
+ res_skip_channels = 2 * hidden_channels
177
+ else:
178
+ res_skip_channels = hidden_channels
179
+
180
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
181
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
182
+ self.res_skip_layers.append(res_skip_layer)
183
+
184
+ def forward(self, x, x_mask, g=None, **kwargs):
185
+ output = torch.zeros_like(x)
186
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
187
+
188
+ if g is not None:
189
+ g = self.cond_layer(g)
190
+
191
+ for i in range(self.n_layers):
192
+ x_in = self.in_layers[i](x)
193
+ if g is not None:
194
+ cond_offset = i * 2 * self.hidden_channels
195
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
196
+ else:
197
+ g_l = torch.zeros_like(x_in)
198
+
199
+ acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
200
+ acts = self.drop(acts)
201
+
202
+ res_skip_acts = self.res_skip_layers[i](acts)
203
+ if i < self.n_layers - 1:
204
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
205
+ x = (x + res_acts) * x_mask
206
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
207
+ else:
208
+ output = output + res_skip_acts
209
+ return output * x_mask
210
+
211
+ def remove_weight_norm(self):
212
+ if self.gin_channels != 0:
213
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
214
+ for l in self.in_layers:
215
+ torch.nn.utils.remove_weight_norm(l)
216
+ for l in self.res_skip_layers:
217
+ torch.nn.utils.remove_weight_norm(l)
218
+
219
+
220
+ class ResBlock1(torch.nn.Module):
221
+ def __init__(
222
+ self,
223
+ channels: int,
224
+ kernel_size: int = 3,
225
+ dilation: typing.Tuple[int, ...] = (1, 3, 5),
226
+ ):
227
+ super(ResBlock1, self).__init__()
228
+ self.LRELU_SLOPE = 0.1
229
+ self.convs1 = nn.ModuleList(
230
+ [
231
+ weight_norm(
232
+ Conv1d(
233
+ channels,
234
+ channels,
235
+ kernel_size,
236
+ 1,
237
+ dilation=dilation[0],
238
+ padding=get_padding(kernel_size, dilation[0]),
239
+ )
240
+ ),
241
+ weight_norm(
242
+ Conv1d(
243
+ channels,
244
+ channels,
245
+ kernel_size,
246
+ 1,
247
+ dilation=dilation[1],
248
+ padding=get_padding(kernel_size, dilation[1]),
249
+ )
250
+ ),
251
+ weight_norm(
252
+ Conv1d(
253
+ channels,
254
+ channels,
255
+ kernel_size,
256
+ 1,
257
+ dilation=dilation[2],
258
+ padding=get_padding(kernel_size, dilation[2]),
259
+ )
260
+ ),
261
+ ]
262
+ )
263
+ self.convs1.apply(init_weights)
264
+
265
+ self.convs2 = nn.ModuleList(
266
+ [
267
+ weight_norm(
268
+ Conv1d(
269
+ channels,
270
+ channels,
271
+ kernel_size,
272
+ 1,
273
+ dilation=1,
274
+ padding=get_padding(kernel_size, 1),
275
+ )
276
+ ),
277
+ weight_norm(
278
+ Conv1d(
279
+ channels,
280
+ channels,
281
+ kernel_size,
282
+ 1,
283
+ dilation=1,
284
+ padding=get_padding(kernel_size, 1),
285
+ )
286
+ ),
287
+ weight_norm(
288
+ Conv1d(
289
+ channels,
290
+ channels,
291
+ kernel_size,
292
+ 1,
293
+ dilation=1,
294
+ padding=get_padding(kernel_size, 1),
295
+ )
296
+ ),
297
+ ]
298
+ )
299
+ self.convs2.apply(init_weights)
300
+
301
+ def forward(self, x, x_mask=None):
302
+ for c1, c2 in zip(self.convs1, self.convs2):
303
+ xt = F.leaky_relu(x, self.LRELU_SLOPE)
304
+ if x_mask is not None:
305
+ xt = xt * x_mask
306
+ xt = c1(xt)
307
+ xt = F.leaky_relu(xt, self.LRELU_SLOPE)
308
+ if x_mask is not None:
309
+ xt = xt * x_mask
310
+ xt = c2(xt)
311
+ x = xt + x
312
+ if x_mask is not None:
313
+ x = x * x_mask
314
+ return x
315
+
316
+ def remove_weight_norm(self):
317
+ for l in self.convs1:
318
+ remove_weight_norm(l)
319
+ for l in self.convs2:
320
+ remove_weight_norm(l)
321
+
322
+
323
+ class ResBlock2(torch.nn.Module):
324
+ def __init__(
325
+ self, channels: int, kernel_size: int = 3, dilation: typing.Tuple[int, ...] = (1, 3)
326
+ ):
327
+ super(ResBlock2, self).__init__()
328
+ self.LRELU_SLOPE = 0.1
329
+ self.convs = nn.ModuleList(
330
+ [
331
+ weight_norm(
332
+ Conv1d(
333
+ channels,
334
+ channels,
335
+ kernel_size,
336
+ 1,
337
+ dilation=dilation[0],
338
+ padding=get_padding(kernel_size, dilation[0]),
339
+ )
340
+ ),
341
+ weight_norm(
342
+ Conv1d(
343
+ channels,
344
+ channels,
345
+ kernel_size,
346
+ 1,
347
+ dilation=dilation[1],
348
+ padding=get_padding(kernel_size, dilation[1]),
349
+ )
350
+ ),
351
+ ]
352
+ )
353
+ self.convs.apply(init_weights)
354
+
355
+ def forward(self, x, x_mask=None):
356
+ for c in self.convs:
357
+ xt = F.leaky_relu(x, self.LRELU_SLOPE)
358
+ if x_mask is not None:
359
+ xt = xt * x_mask
360
+ xt = c(xt)
361
+ x = xt + x
362
+ if x_mask is not None:
363
+ x = x * x_mask
364
+ return x
365
+
366
+ def remove_weight_norm(self):
367
+ for l in self.convs:
368
+ remove_weight_norm(l)
369
+
370
+
371
+ class Log(nn.Module):
372
+ def forward(
373
+ self, x: torch.Tensor, x_mask: torch.Tensor, reverse: bool = False, **kwargs
374
+ ):
375
+ if not reverse:
376
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
377
+ logdet = torch.sum(-y, [1, 2])
378
+ return y, logdet
379
+ else:
380
+ x = torch.exp(x) * x_mask
381
+ return x
382
+
383
+
384
+ class Flip(nn.Module):
385
+ def forward(self, x: torch.Tensor, *args, reverse: bool = False, **kwargs):
386
+ x = torch.flip(x, [1])
387
+ if not reverse:
388
+ logdet = torch.zeros(x.size(0)).type_as(x)
389
+ return x, logdet
390
+ else:
391
+ return x
392
+
393
+
394
+ class ElementwiseAffine(nn.Module):
395
+ def __init__(self, channels: int):
396
+ super().__init__()
397
+ self.channels = channels
398
+ self.m = nn.Parameter(torch.zeros(channels, 1))
399
+ self.logs = nn.Parameter(torch.zeros(channels, 1))
400
+
401
+ def forward(self, x, x_mask, reverse=False, **kwargs):
402
+ if not reverse:
403
+ y = self.m + torch.exp(self.logs) * x
404
+ y = y * x_mask
405
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
406
+ return y, logdet
407
+ else:
408
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
409
+ return x
410
+
411
+
412
+ class ResidualCouplingLayer(nn.Module):
413
+ def __init__(
414
+ self,
415
+ channels: int,
416
+ hidden_channels: int,
417
+ kernel_size: int,
418
+ dilation_rate: int,
419
+ n_layers: int,
420
+ p_dropout: float = 0,
421
+ gin_channels: int = 0,
422
+ mean_only: bool = False,
423
+ ):
424
+ assert channels % 2 == 0, "channels should be divisible by 2"
425
+ super().__init__()
426
+ self.channels = channels
427
+ self.hidden_channels = hidden_channels
428
+ self.kernel_size = kernel_size
429
+ self.dilation_rate = dilation_rate
430
+ self.n_layers = n_layers
431
+ self.half_channels = channels // 2
432
+ self.mean_only = mean_only
433
+
434
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
435
+ self.enc = WN(
436
+ hidden_channels,
437
+ kernel_size,
438
+ dilation_rate,
439
+ n_layers,
440
+ p_dropout=p_dropout,
441
+ gin_channels=gin_channels,
442
+ )
443
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
444
+ self.post.weight.data.zero_()
445
+ self.post.bias.data.zero_()
446
+
447
+ def forward(self, x, x_mask, g=None, reverse=False):
448
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
449
+ h = self.pre(x0) * x_mask
450
+ h = self.enc(h, x_mask, g=g)
451
+ stats = self.post(h) * x_mask
452
+ if not self.mean_only:
453
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
454
+ else:
455
+ m = stats
456
+ logs = torch.zeros_like(m)
457
+
458
+ if not reverse:
459
+ x1 = m + x1 * torch.exp(logs) * x_mask
460
+ x = torch.cat([x0, x1], 1)
461
+ logdet = torch.sum(logs, [1, 2])
462
+ return x, logdet
463
+ else:
464
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
465
+ x = torch.cat([x0, x1], 1)
466
+ return x
467
+
468
+
469
+ class ConvFlow(nn.Module):
470
+ def __init__(
471
+ self,
472
+ in_channels: int,
473
+ filter_channels: int,
474
+ kernel_size: int,
475
+ n_layers: int,
476
+ num_bins: int = 10,
477
+ tail_bound: float = 5.0,
478
+ ):
479
+ super().__init__()
480
+ self.in_channels = in_channels
481
+ self.filter_channels = filter_channels
482
+ self.kernel_size = kernel_size
483
+ self.n_layers = n_layers
484
+ self.num_bins = num_bins
485
+ self.tail_bound = tail_bound
486
+ self.half_channels = in_channels // 2
487
+
488
+ self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
489
+ self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
490
+ self.proj = nn.Conv1d(
491
+ filter_channels, self.half_channels * (num_bins * 3 - 1), 1
492
+ )
493
+ self.proj.weight.data.zero_()
494
+ self.proj.bias.data.zero_()
495
+
496
+ def forward(self, x, x_mask, g=None, reverse=False):
497
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
498
+ h = self.pre(x0)
499
+ h = self.convs(h, x_mask, g=g)
500
+ h = self.proj(h) * x_mask
501
+
502
+ b, c, t = x0.shape
503
+ h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, c * n_params, t] -> [b, c, t, n_params]
504
+
505
+ unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
506
+ unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
507
+ self.filter_channels
508
+ )
509
+ unnormalized_derivatives = h[..., 2 * self.num_bins :]
510
+
511
+ x1, logabsdet = piecewise_rational_quadratic_transform(
512
+ x1,
513
+ unnormalized_widths,
514
+ unnormalized_heights,
515
+ unnormalized_derivatives,
516
+ inverse=reverse,
517
+ tails="linear",
518
+ tail_bound=self.tail_bound,
519
+ )
520
+
521
+ x = torch.cat([x0, x1], 1) * x_mask
522
+
523
+ logdet = torch.sum(logabsdet * x_mask, [1, 2])
524
+ if not reverse:
525
+ return x, logdet
526
+ else:
527
+ return x
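
The flow modules above (`Log`, `Flip`, `ElementwiseAffine`, `ResidualCouplingLayer`, `ConvFlow`) share one calling convention: with `reverse=False` they return `(y, logdet)`, with `reverse=True` they return only the reconstructed input. A minimal round-trip sketch, assuming the module is importable as `wfloat_tts.vits.modules` and using arbitrary test shapes:

```python
import torch

from wfloat_tts.vits.modules import ResidualCouplingLayer  # assumed module path

layer = ResidualCouplingLayer(
    channels=4, hidden_channels=8, kernel_size=5, dilation_rate=1, n_layers=2, mean_only=True
)
x = torch.randn(1, 4, 10)      # [batch, channels, time]
x_mask = torch.ones(1, 1, 10)  # no padding

y, logdet = layer(x, x_mask, reverse=False)  # forward: returns output and log-determinant
x_rec = layer(y, x_mask, reverse=True)       # inverse: returns the reconstructed input only

print(torch.allclose(x, x_rec, atol=1e-5))   # True: the coupling layer is invertible
```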
src/wfloat_tts/vits/transforms.py ADDED
@@ -0,0 +1,212 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch.nn import functional as F
4
+
5
+ DEFAULT_MIN_BIN_WIDTH = 1e-3
6
+ DEFAULT_MIN_BIN_HEIGHT = 1e-3
7
+ DEFAULT_MIN_DERIVATIVE = 1e-3
8
+
9
+
10
+ def piecewise_rational_quadratic_transform(
11
+ inputs,
12
+ unnormalized_widths,
13
+ unnormalized_heights,
14
+ unnormalized_derivatives,
15
+ inverse=False,
16
+ tails=None,
17
+ tail_bound=1.0,
18
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
19
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
20
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
21
+ ):
22
+
23
+ if tails is None:
24
+ spline_fn = rational_quadratic_spline
25
+ spline_kwargs = {}
26
+ else:
27
+ spline_fn = unconstrained_rational_quadratic_spline
28
+ spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
29
+
30
+ outputs, logabsdet = spline_fn(
31
+ inputs=inputs,
32
+ unnormalized_widths=unnormalized_widths,
33
+ unnormalized_heights=unnormalized_heights,
34
+ unnormalized_derivatives=unnormalized_derivatives,
35
+ inverse=inverse,
36
+ min_bin_width=min_bin_width,
37
+ min_bin_height=min_bin_height,
38
+ min_derivative=min_derivative,
39
+ **spline_kwargs
40
+ )
41
+ return outputs, logabsdet
42
+
43
+
44
+ def searchsorted(bin_locations, inputs, eps=1e-6):
45
+ # bin_locations[..., -1] += eps
46
+ bin_locations[..., bin_locations.size(-1) - 1] += eps
47
+ return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
48
+
49
+
50
+ def unconstrained_rational_quadratic_spline(
51
+ inputs,
52
+ unnormalized_widths,
53
+ unnormalized_heights,
54
+ unnormalized_derivatives,
55
+ inverse=False,
56
+ tails="linear",
57
+ tail_bound=1.0,
58
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
59
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
60
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
61
+ ):
62
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
63
+ outside_interval_mask = ~inside_interval_mask
64
+
65
+ outputs = torch.zeros_like(inputs)
66
+ logabsdet = torch.zeros_like(inputs)
67
+
68
+ if tails == "linear":
69
+ unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
70
+ constant = np.log(np.exp(1 - min_derivative) - 1)
71
+ unnormalized_derivatives[..., 0] = constant
72
+ # unnormalized_derivatives[..., -1] = constant
73
+ unnormalized_derivatives[..., unnormalized_derivatives.size(-1) - 1] = constant
74
+
75
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
76
+ logabsdet[outside_interval_mask] = 0
77
+ else:
78
+ raise RuntimeError("{} tails are not implemented.".format(tails))
79
+
80
+ (
81
+ outputs[inside_interval_mask],
82
+ logabsdet[inside_interval_mask],
83
+ ) = rational_quadratic_spline(
84
+ inputs=inputs[inside_interval_mask],
85
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
86
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
87
+ unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
88
+ inverse=inverse,
89
+ left=-tail_bound,
90
+ right=tail_bound,
91
+ bottom=-tail_bound,
92
+ top=tail_bound,
93
+ min_bin_width=min_bin_width,
94
+ min_bin_height=min_bin_height,
95
+ min_derivative=min_derivative,
96
+ )
97
+
98
+ return outputs, logabsdet
99
+
100
+
101
+ def rational_quadratic_spline(
102
+ inputs,
103
+ unnormalized_widths,
104
+ unnormalized_heights,
105
+ unnormalized_derivatives,
106
+ inverse=False,
107
+ left=0.0,
108
+ right=1.0,
109
+ bottom=0.0,
110
+ top=1.0,
111
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
112
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
113
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
114
+ ):
115
+ # if torch.min(inputs) < left or torch.max(inputs) > right:
116
+ # raise ValueError("Input to a transform is not within its domain")
117
+
118
+ num_bins = unnormalized_widths.shape[-1]
119
+
120
+ # if min_bin_width * num_bins > 1.0:
121
+ # raise ValueError("Minimal bin width too large for the number of bins")
122
+ # if min_bin_height * num_bins > 1.0:
123
+ # raise ValueError("Minimal bin height too large for the number of bins")
124
+
125
+ widths = F.softmax(unnormalized_widths, dim=-1)
126
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
127
+ cumwidths = torch.cumsum(widths, dim=-1)
128
+ cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
129
+ cumwidths = (right - left) * cumwidths + left
130
+ cumwidths[..., 0] = left
131
+ # cumwidths[..., -1] = right
132
+ cumwidths[..., cumwidths.size(-1) - 1] = right
133
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
134
+
135
+ derivatives = min_derivative + F.softplus(unnormalized_derivatives)
136
+
137
+ heights = F.softmax(unnormalized_heights, dim=-1)
138
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
139
+ cumheights = torch.cumsum(heights, dim=-1)
140
+ cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
141
+ cumheights = (top - bottom) * cumheights + bottom
142
+ cumheights[..., 0] = bottom
143
+ # cumheights[..., -1] = top
144
+ cumheights[..., cumheights.size(-1) - 1] = top
145
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
146
+
147
+ if inverse:
148
+ bin_idx = searchsorted(cumheights, inputs)[..., None]
149
+ else:
150
+ bin_idx = searchsorted(cumwidths, inputs)[..., None]
151
+
152
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
153
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
154
+
155
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
156
+ delta = heights / widths
157
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
158
+
159
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
160
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
161
+
162
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
163
+
164
+ if inverse:
165
+ a = (inputs - input_cumheights) * (
166
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
167
+ ) + input_heights * (input_delta - input_derivatives)
168
+ b = input_heights * input_derivatives - (inputs - input_cumheights) * (
169
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
170
+ )
171
+ c = -input_delta * (inputs - input_cumheights)
172
+
173
+ discriminant = b.pow(2) - 4 * a * c
174
+ assert (discriminant >= 0).all(), discriminant
175
+
176
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
177
+ outputs = root * input_bin_widths + input_cumwidths
178
+
179
+ theta_one_minus_theta = root * (1 - root)
180
+ denominator = input_delta + (
181
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
182
+ * theta_one_minus_theta
183
+ )
184
+ derivative_numerator = input_delta.pow(2) * (
185
+ input_derivatives_plus_one * root.pow(2)
186
+ + 2 * input_delta * theta_one_minus_theta
187
+ + input_derivatives * (1 - root).pow(2)
188
+ )
189
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
190
+
191
+ return outputs, -logabsdet
192
+
193
+ theta = (inputs - input_cumwidths) / input_bin_widths
194
+ theta_one_minus_theta = theta * (1 - theta)
195
+
196
+ numerator = input_heights * (
197
+ input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
198
+ )
199
+ denominator = input_delta + (
200
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
201
+ * theta_one_minus_theta
202
+ )
203
+ outputs = input_cumheights + numerator / denominator
204
+
205
+ derivative_numerator = input_delta.pow(2) * (
206
+ input_derivatives_plus_one * theta.pow(2)
207
+ + 2 * input_delta * theta_one_minus_theta
208
+ + input_derivatives * (1 - theta).pow(2)
209
+ )
210
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
211
+
212
+ return outputs, logabsdet
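
As a quick sanity check of the spline above: applying the transform and then its inverse should round-trip the inputs, and the elementwise log-determinants should cancel. A minimal sketch, assuming the module is importable as `wfloat_tts.vits.transforms`; the shapes mirror how `ConvFlow` calls it and the values are random test data:

```python
import torch

from wfloat_tts.vits.transforms import piecewise_rational_quadratic_transform  # assumed path

num_bins = 10
x = torch.randn(2, 1, 6)                     # arbitrary inputs
widths = torch.randn(2, 1, 6, num_bins)      # unnormalized bin widths
heights = torch.randn(2, 1, 6, num_bins)     # unnormalized bin heights
derivs = torch.randn(2, 1, 6, num_bins - 1)  # unnormalized knot derivatives

y, logdet = piecewise_rational_quadratic_transform(
    x, widths, heights, derivs, inverse=False, tails="linear", tail_bound=5.0
)
x_rec, inv_logdet = piecewise_rational_quadratic_transform(
    y, widths, heights, derivs, inverse=True, tails="linear", tail_bound=5.0
)

print(torch.allclose(x, x_rec, atol=1e-4))             # True: the spline is invertible
print(torch.allclose(logdet, -inv_logdet, atol=1e-4))  # log-determinants cancel
```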