ak36 committed · Commit bf65828 · 1 Parent(s): c32984a

second_stage_v1

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .ipynb_checkpoints/train-checkpoint.log +342 -0
  2. .ipynb_checkpoints/train_second-checkpoint.py +879 -0
  3. logs/pod_90h_30k/config_ft_single.yml → Configs/.ipynb_checkpoints/config_ft_single-checkpoint.yml +19 -19
  4. Configs/.ipynb_checkpoints/config_libritts-checkpoint.yml +113 -0
  5. Configs/config_ft_single.yml +16 -16
  6. Demo/.ipynb_checkpoints/Inference_LibriTTS-checkpoint.ipynb +1155 -0
  7. Demo/.ipynb_checkpoints/Inference_pod_90h_30k-checkpoint.ipynb +1155 -0
  8. Demo/Inference_pod_90h_30k.ipynb +1360 -0
  9. Modules/.ipynb_checkpoints/slmadv-checkpoint.py +177 -0
  10. Modules/slmadv.py +126 -144
  11. __pycache__/losses.cpython-310.pyc +0 -0
  12. __pycache__/meldataset.cpython-310.pyc +0 -0
  13. __pycache__/models.cpython-310.pyc +0 -0
  14. __pycache__/optimizers.cpython-310.pyc +0 -0
  15. __pycache__/utils.cpython-310.pyc +0 -0
  16. events.out.tfevents.1749451143.164-152-17-237.47710.0 +0 -3
  17. events.out.tfevents.1749451143.164-152-17-237.47712.0 +0 -3
  18. events.out.tfevents.1749451144.164-152-17-237.47706.0 +0 -3
  19. events.out.tfevents.1749451144.164-152-17-237.47708.0 +0 -3
  20. events.out.tfevents.1749451144.164-152-17-237.47709.0 +0 -3
  21. events.out.tfevents.1749451144.164-152-17-237.47711.0 +0 -3
  22. events.out.tfevents.1749451220.164-152-17-237.48862.0 +0 -3
  23. events.out.tfevents.1749451220.164-152-17-237.48863.0 +0 -3
  24. events.out.tfevents.1749451220.164-152-17-237.48864.0 +0 -3
  25. events.out.tfevents.1749451220.164-152-17-237.48865.0 +0 -3
  26. events.out.tfevents.1749451220.164-152-17-237.48868.0 +0 -3
  27. events.out.tfevents.1749451221.164-152-17-237.48861.0 +0 -3
  28. events.out.tfevents.1749451221.164-152-17-237.48867.0 +0 -3
  29. events.out.tfevents.1749451222.164-152-17-237.48866.0 +0 -3
  30. events.out.tfevents.1749453792.164-152-17-237.51057.0 +0 -3
  31. events.out.tfevents.1749453792.164-152-17-237.51059.0 +0 -3
  32. events.out.tfevents.1749453792.164-152-17-237.51061.0 +0 -3
  33. events.out.tfevents.1749453792.164-152-17-237.51063.0 +0 -3
  34. events.out.tfevents.1749453793.164-152-17-237.51056.0 +0 -3
  35. events.out.tfevents.1749453793.164-152-17-237.51058.0 +0 -3
  36. events.out.tfevents.1749453793.164-152-17-237.51060.0 +0 -3
  37. events.out.tfevents.1749453794.164-152-17-237.51062.0 +0 -3
  38. events.out.tfevents.1749453905.164-152-17-237.52357.0 +0 -3
  39. events.out.tfevents.1749453905.164-152-17-237.52358.0 +0 -3
  40. events.out.tfevents.1749453905.164-152-17-237.52360.0 +0 -3
  41. events.out.tfevents.1749453905.164-152-17-237.52361.0 +0 -3
  42. events.out.tfevents.1749453906.164-152-17-237.52355.0 +0 -3
  43. events.out.tfevents.1749453906.164-152-17-237.52356.0 +0 -3
  44. events.out.tfevents.1749453906.164-152-17-237.52359.0 +0 -3
  45. events.out.tfevents.1749453906.164-152-17-237.52362.0 +0 -3
  46. events.out.tfevents.1749453977.164-152-17-237.53096.0 +0 -3
  47. events.out.tfevents.1749453977.164-152-17-237.53097.0 +0 -3
  48. events.out.tfevents.1749453977.164-152-17-237.53098.0 +0 -3
  49. events.out.tfevents.1749453977.164-152-17-237.53099.0 +0 -3
  50. events.out.tfevents.1749453977.164-152-17-237.53100.0 +0 -3
.ipynb_checkpoints/train-checkpoint.log ADDED
@@ -0,0 +1,342 @@
1
+ INFO:2025-06-09 01:00:12,153: Epoch [3/25], Step [50/3970], Mel Loss: 0.63656, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
2
+ INFO:2025-06-09 01:00:55,689: Epoch [3/25], Step [100/3970], Mel Loss: 0.63674, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
3
+ INFO:2025-06-09 01:01:45,433: Epoch [3/25], Step [150/3970], Mel Loss: 0.63203, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
4
+ INFO:2025-06-09 01:02:37,587: Epoch [3/25], Step [200/3970], Mel Loss: 0.62929, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
5
+ INFO:2025-06-09 01:03:28,825: Epoch [3/25], Step [250/3970], Mel Loss: 0.63209, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
6
+ INFO:2025-06-09 01:04:18,272: Epoch [3/25], Step [300/3970], Mel Loss: 0.62710, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
7
+ INFO:2025-06-09 01:05:09,751: Epoch [3/25], Step [350/3970], Mel Loss: 0.62325, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
8
+ INFO:2025-06-09 01:06:00,396: Epoch [3/25], Step [400/3970], Mel Loss: 0.62540, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
9
+ INFO:2025-06-09 01:06:51,713: Epoch [3/25], Step [450/3970], Mel Loss: 0.61673, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
10
+ INFO:2025-06-09 01:08:07,720: Validation loss: 0.568
11
+
12
+
13
+
14
+
15
+ INFO:2025-06-09 01:09:07,489: Epoch [4/25], Step [50/3970], Mel Loss: 0.62133, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
16
+ INFO:2025-06-09 01:09:59,049: Epoch [4/25], Step [100/3970], Mel Loss: 0.61368, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
17
+ INFO:2025-06-09 01:10:46,786: Epoch [4/25], Step [150/3970], Mel Loss: 0.61887, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
18
+ INFO:2025-06-09 01:11:36,393: Epoch [4/25], Step [200/3970], Mel Loss: 0.61688, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
19
+ INFO:2025-06-09 01:12:21,624: Epoch [4/25], Step [250/3970], Mel Loss: 0.61630, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
20
+ INFO:2025-06-09 01:13:08,125: Epoch [4/25], Step [300/3970], Mel Loss: 0.61238, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
21
+ INFO:2025-06-09 01:13:53,747: Epoch [4/25], Step [350/3970], Mel Loss: 0.61566, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
22
+ INFO:2025-06-09 01:14:42,113: Epoch [4/25], Step [400/3970], Mel Loss: 0.61601, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
23
+ INFO:2025-06-09 01:15:29,167: Epoch [4/25], Step [450/3970], Mel Loss: 0.61588, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
24
+ INFO:2025-06-09 01:16:42,289: Validation loss: 0.559
25
+
26
+
27
+
28
+
29
+ INFO:2025-06-09 01:17:33,371: Epoch [5/25], Step [50/3970], Mel Loss: 0.61289, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
30
+ INFO:2025-06-09 01:18:21,238: Epoch [5/25], Step [100/3970], Mel Loss: 0.61256, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
31
+ INFO:2025-06-09 01:19:08,129: Epoch [5/25], Step [150/3970], Mel Loss: 0.60756, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
32
+ INFO:2025-06-09 01:19:55,306: Epoch [5/25], Step [200/3970], Mel Loss: 0.60886, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
33
+ INFO:2025-06-09 01:20:38,852: Epoch [5/25], Step [250/3970], Mel Loss: 0.61364, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
34
+ INFO:2025-06-09 01:21:23,920: Epoch [5/25], Step [300/3970], Mel Loss: 0.60994, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
35
+ INFO:2025-06-09 01:22:13,541: Epoch [5/25], Step [350/3970], Mel Loss: 0.59860, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
36
+ INFO:2025-06-09 01:22:59,673: Epoch [5/25], Step [400/3970], Mel Loss: 0.61045, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
37
+ INFO:2025-06-09 01:23:48,982: Epoch [5/25], Step [450/3970], Mel Loss: 0.59750, Gen Loss: 0.00000, Disc Loss: 0.00000, Mono Loss: 0.00000, S2S Loss: 0.00000, SLM Loss: 0.00000
38
+ INFO:2025-06-09 01:25:03,693: Validation loss: 0.554
39
+
40
+
41
+
42
+
43
+ INFO:2025-06-09 01:26:38,581: Epoch [6/25], Step [50/3970], Mel Loss: 0.97984, Gen Loss: 23.17753, Disc Loss: 2.41794, Mono Loss: 0.03876, S2S Loss: 2.62273, SLM Loss: 2.94417
44
+ INFO:2025-06-09 01:28:02,840: Epoch [6/25], Step [100/3970], Mel Loss: 1.07962, Gen Loss: 9.85569, Disc Loss: 3.64044, Mono Loss: 0.04134, S2S Loss: 2.85493, SLM Loss: 2.91264
45
+ INFO:2025-06-09 01:29:28,320: Epoch [6/25], Step [150/3970], Mel Loss: 0.86465, Gen Loss: 7.72572, Disc Loss: 3.54385, Mono Loss: 0.03521, S2S Loss: 2.93994, SLM Loss: 2.88020
46
+ INFO:2025-06-09 01:30:52,824: Epoch [6/25], Step [200/3970], Mel Loss: 0.79431, Gen Loss: 6.71339, Disc Loss: 3.60161, Mono Loss: 0.03136, S2S Loss: 2.63649, SLM Loss: 2.84435
47
+ INFO:2025-06-09 01:32:17,827: Epoch [6/25], Step [250/3970], Mel Loss: 0.69364, Gen Loss: 7.94329, Disc Loss: 3.65054, Mono Loss: 0.02880, S2S Loss: 2.69598, SLM Loss: 2.74597
48
+ INFO:2025-06-09 01:33:40,309: Epoch [6/25], Step [300/3970], Mel Loss: 0.64917, Gen Loss: 6.08137, Disc Loss: 3.72214, Mono Loss: 0.02712, S2S Loss: 2.67866, SLM Loss: 2.67792
49
+ INFO:2025-06-09 01:35:00,338: Epoch [6/25], Step [350/3970], Mel Loss: 0.61861, Gen Loss: 5.89171, Disc Loss: 3.76209, Mono Loss: 0.03497, S2S Loss: 2.41264, SLM Loss: 2.63655
50
+ INFO:2025-06-09 01:36:24,903: Epoch [6/25], Step [400/3970], Mel Loss: 0.60979, Gen Loss: 7.66665, Disc Loss: 3.91850, Mono Loss: 0.03728, S2S Loss: 2.28397, SLM Loss: 2.58636
51
+ INFO:2025-06-09 01:37:48,840: Epoch [6/25], Step [450/3970], Mel Loss: 0.58291, Gen Loss: 6.77857, Disc Loss: 3.68198, Mono Loss: 0.03211, S2S Loss: 2.31591, SLM Loss: 2.60305
52
+ INFO:2025-06-09 01:39:34,991: Validation loss: 0.537
53
+
54
+
55
+
56
+
57
+ INFO:2025-06-09 01:41:06,553: Epoch [7/25], Step [50/3970], Mel Loss: 0.56623, Gen Loss: 9.29693, Disc Loss: 3.64386, Mono Loss: 0.03481, S2S Loss: 2.34774, SLM Loss: 2.54822
58
+ INFO:2025-06-09 01:42:30,275: Epoch [7/25], Step [100/3970], Mel Loss: 0.56485, Gen Loss: 7.86536, Disc Loss: 3.62585, Mono Loss: 0.02787, S2S Loss: 2.28482, SLM Loss: 2.35773
59
+ INFO:2025-06-09 01:43:57,072: Epoch [7/25], Step [150/3970], Mel Loss: 0.55288, Gen Loss: 6.11425, Disc Loss: 3.71416, Mono Loss: 0.03559, S2S Loss: 2.12334, SLM Loss: 2.29561
60
+ INFO:2025-06-09 01:45:19,320: Epoch [7/25], Step [200/3970], Mel Loss: 0.54845, Gen Loss: 8.49720, Disc Loss: 3.60499, Mono Loss: 0.03008, S2S Loss: 2.01824, SLM Loss: 2.54976
61
+ INFO:2025-06-09 01:46:43,392: Epoch [7/25], Step [250/3970], Mel Loss: 0.54677, Gen Loss: 8.23377, Disc Loss: 3.75765, Mono Loss: 0.03607, S2S Loss: 1.90671, SLM Loss: 2.42518
62
+ INFO:2025-06-09 01:48:06,193: Epoch [7/25], Step [300/3970], Mel Loss: 0.53016, Gen Loss: 6.69302, Disc Loss: 3.70597, Mono Loss: 0.02783, S2S Loss: 1.66317, SLM Loss: 2.40360
63
+ INFO:2025-06-09 01:49:31,131: Epoch [7/25], Step [350/3970], Mel Loss: 0.53427, Gen Loss: 9.37859, Disc Loss: 3.71469, Mono Loss: 0.03775, S2S Loss: 1.82596, SLM Loss: 2.32477
64
+ INFO:2025-06-09 01:50:55,050: Epoch [7/25], Step [400/3970], Mel Loss: 0.52960, Gen Loss: 7.75816, Disc Loss: 3.68042, Mono Loss: 0.03405, S2S Loss: 1.99194, SLM Loss: 2.36643
65
+ INFO:2025-06-09 01:52:17,195: Epoch [7/25], Step [450/3970], Mel Loss: 0.54899, Gen Loss: 6.99428, Disc Loss: 3.58416, Mono Loss: 0.02682, S2S Loss: 1.87868, SLM Loss: 2.30144
66
+ INFO:2025-06-09 01:54:03,014: Validation loss: 0.483
67
+
68
+
69
+
70
+
71
+ INFO:2025-06-09 01:55:38,965: Epoch [8/25], Step [50/3970], Mel Loss: 0.52470, Gen Loss: 10.28693, Disc Loss: 3.63978, Mono Loss: 0.03907, S2S Loss: 1.83708, SLM Loss: 2.34067
72
+ INFO:2025-06-09 01:57:02,183: Epoch [8/25], Step [100/3970], Mel Loss: 0.53876, Gen Loss: 7.51378, Disc Loss: 3.64213, Mono Loss: 0.03625, S2S Loss: 1.72809, SLM Loss: 2.16873
73
+ INFO:2025-06-09 01:58:25,854: Epoch [8/25], Step [150/3970], Mel Loss: 0.52859, Gen Loss: 7.03971, Disc Loss: 3.77774, Mono Loss: 0.03803, S2S Loss: 2.06151, SLM Loss: 2.42842
74
+ INFO:2025-06-09 01:59:48,325: Epoch [8/25], Step [200/3970], Mel Loss: 0.52193, Gen Loss: 8.06612, Disc Loss: 3.57948, Mono Loss: 0.02800, S2S Loss: 1.74277, SLM Loss: 2.34817
75
+ INFO:2025-06-09 02:01:13,028: Epoch [8/25], Step [250/3970], Mel Loss: 0.51478, Gen Loss: 8.66409, Disc Loss: 3.49825, Mono Loss: 0.04338, S2S Loss: 1.61835, SLM Loss: 2.15467
76
+ INFO:2025-06-09 02:02:38,248: Epoch [8/25], Step [300/3970], Mel Loss: 0.52305, Gen Loss: 6.17309, Disc Loss: 3.58619, Mono Loss: 0.02633, S2S Loss: 1.72876, SLM Loss: 2.30287
77
+ INFO:2025-06-09 02:03:59,711: Epoch [8/25], Step [350/3970], Mel Loss: 0.52308, Gen Loss: 9.16025, Disc Loss: 3.61594, Mono Loss: 0.03729, S2S Loss: 1.62949, SLM Loss: 2.24222
78
+ INFO:2025-06-09 02:05:25,200: Epoch [8/25], Step [400/3970], Mel Loss: 0.52249, Gen Loss: 10.42249, Disc Loss: 3.28414, Mono Loss: 0.03355, S2S Loss: 1.73984, SLM Loss: 2.41117
79
+ INFO:2025-06-09 02:06:52,106: Epoch [8/25], Step [450/3970], Mel Loss: 0.53768, Gen Loss: 9.69683, Disc Loss: 3.94595, Mono Loss: 0.03386, S2S Loss: 1.51221, SLM Loss: 2.22714
80
+ INFO:2025-06-09 02:08:43,461: Validation loss: 0.576
81
+
82
+
83
+
84
+
85
+ INFO:2025-06-09 02:10:15,610: Epoch [9/25], Step [50/3970], Mel Loss: 0.54530, Gen Loss: 8.99131, Disc Loss: 3.51233, Mono Loss: 0.04022, S2S Loss: 1.83184, SLM Loss: 2.27819
86
+ INFO:2025-06-09 02:11:40,932: Epoch [9/25], Step [100/3970], Mel Loss: 0.52112, Gen Loss: 10.46535, Disc Loss: 3.27846, Mono Loss: 0.04066, S2S Loss: 1.40977, SLM Loss: 2.23133
87
+ INFO:2025-06-09 02:13:05,924: Epoch [9/25], Step [150/3970], Mel Loss: 0.52145, Gen Loss: 7.49124, Disc Loss: 3.61879, Mono Loss: 0.03863, S2S Loss: 1.37945, SLM Loss: 1.97726
88
+ INFO:2025-06-09 02:14:33,751: Epoch [9/25], Step [200/3970], Mel Loss: 0.52140, Gen Loss: 9.70458, Disc Loss: 3.41580, Mono Loss: 0.02937, S2S Loss: 1.11212, SLM Loss: 2.06116
89
+ INFO:2025-06-09 02:16:00,853: Epoch [9/25], Step [250/3970], Mel Loss: 0.51334, Gen Loss: 9.93914, Disc Loss: 3.31054, Mono Loss: 0.03239, S2S Loss: 1.65795, SLM Loss: 2.17712
90
+ INFO:2025-06-09 02:17:34,691: Epoch [9/25], Step [300/3970], Mel Loss: 0.51978, Gen Loss: 8.67425, Disc Loss: 3.50655, Mono Loss: 0.03035, S2S Loss: 1.59474, SLM Loss: 2.07509
91
+ INFO:2025-06-09 02:19:00,242: Epoch [9/25], Step [350/3970], Mel Loss: 0.53707, Gen Loss: 9.08727, Disc Loss: 3.54805, Mono Loss: 0.02839, S2S Loss: 1.37795, SLM Loss: 2.33604
92
+ INFO:2025-06-09 02:20:26,229: Epoch [9/25], Step [400/3970], Mel Loss: 0.53049, Gen Loss: 10.96715, Disc Loss: 3.46047, Mono Loss: 0.01226, S2S Loss: 1.65820, SLM Loss: 2.27651
93
+ INFO:2025-06-09 02:21:51,865: Epoch [9/25], Step [450/3970], Mel Loss: 0.54155, Gen Loss: 12.13123, Disc Loss: 3.59190, Mono Loss: 0.03659, S2S Loss: 1.53226, SLM Loss: 2.27701
94
+ INFO:2025-06-09 02:23:43,778: Validation loss: 0.570
95
+
96
+
97
+
98
+
99
+ INFO:2025-06-09 02:25:19,386: Epoch [10/25], Step [50/3970], Mel Loss: 0.51787, Gen Loss: 8.53281, Disc Loss: 3.29783, Mono Loss: 0.03446, S2S Loss: 1.45484, SLM Loss: 2.05988
100
+ INFO:2025-06-09 02:26:47,838: Epoch [10/25], Step [100/3970], Mel Loss: 0.50870, Gen Loss: 10.23766, Disc Loss: 3.65731, Mono Loss: 0.02026, S2S Loss: 1.49496, SLM Loss: 2.13103
101
+ INFO:2025-06-09 02:28:14,726: Epoch [10/25], Step [150/3970], Mel Loss: 0.54335, Gen Loss: 10.79815, Disc Loss: 3.13370, Mono Loss: 0.03967, S2S Loss: 1.45980, SLM Loss: 2.08032
102
+ INFO:2025-06-09 02:29:40,057: Epoch [10/25], Step [200/3970], Mel Loss: 0.54070, Gen Loss: 16.04053, Disc Loss: 2.25289, Mono Loss: 0.03900, S2S Loss: 1.27599, SLM Loss: 1.92492
103
+ INFO:2025-06-09 02:31:03,537: Epoch [10/25], Step [250/3970], Mel Loss: 0.52524, Gen Loss: 10.16288, Disc Loss: 3.41181, Mono Loss: 0.03559, S2S Loss: 1.43502, SLM Loss: 2.11398
104
+ INFO:2025-06-09 02:32:27,638: Epoch [10/25], Step [300/3970], Mel Loss: 0.50069, Gen Loss: 10.37155, Disc Loss: 3.37344, Mono Loss: 0.02960, S2S Loss: 1.31690, SLM Loss: 2.10520
105
+ INFO:2025-06-09 02:33:55,403: Epoch [10/25], Step [350/3970], Mel Loss: 0.51798, Gen Loss: 8.71367, Disc Loss: 3.56202, Mono Loss: 0.02943, S2S Loss: 1.46521, SLM Loss: 1.94630
106
+ INFO:2025-06-09 02:35:20,742: Epoch [10/25], Step [400/3970], Mel Loss: 0.50403, Gen Loss: 10.07516, Disc Loss: 3.33516, Mono Loss: 0.03924, S2S Loss: 1.33150, SLM Loss: 2.00906
107
+ INFO:2025-06-09 02:36:48,130: Epoch [10/25], Step [450/3970], Mel Loss: 0.51103, Gen Loss: 13.04639, Disc Loss: 2.76208, Mono Loss: 0.03206, S2S Loss: 1.34592, SLM Loss: 1.97846
108
+ INFO:2025-06-09 02:38:35,829: Validation loss: 0.494
109
+
110
+
111
+
112
+
113
+ INFO:2025-06-09 02:40:09,941: Epoch [11/25], Step [50/3970], Mel Loss: 0.50742, Gen Loss: 14.54983, Disc Loss: 3.16779, Mono Loss: 0.03301, S2S Loss: 1.24001, SLM Loss: 2.09207
114
+ INFO:2025-06-09 02:41:31,696: Epoch [11/25], Step [100/3970], Mel Loss: 0.51317, Gen Loss: 10.19775, Disc Loss: 3.39758, Mono Loss: 0.03635, S2S Loss: 1.18542, SLM Loss: 1.87807
115
+ INFO:2025-06-09 02:43:00,303: Epoch [11/25], Step [150/3970], Mel Loss: 0.49535, Gen Loss: 10.78111, Disc Loss: 3.42311, Mono Loss: 0.02923, S2S Loss: 1.30470, SLM Loss: 2.03685
116
+ INFO:2025-06-09 02:44:25,379: Epoch [11/25], Step [200/3970], Mel Loss: 0.52437, Gen Loss: 12.13915, Disc Loss: 3.48997, Mono Loss: 0.03481, S2S Loss: 1.40437, SLM Loss: 2.22972
117
+ INFO:2025-06-09 02:45:48,875: Epoch [11/25], Step [250/3970], Mel Loss: 0.52986, Gen Loss: 10.21078, Disc Loss: 3.29919, Mono Loss: 0.03813, S2S Loss: 1.31922, SLM Loss: 2.12596
118
+ INFO:2025-06-09 02:47:12,758: Epoch [11/25], Step [300/3970], Mel Loss: 0.50854, Gen Loss: 11.45082, Disc Loss: 3.70037, Mono Loss: 0.02635, S2S Loss: 1.30175, SLM Loss: 2.10208
119
+ INFO:2025-06-09 02:48:37,380: Epoch [11/25], Step [350/3970], Mel Loss: 0.51107, Gen Loss: 12.95888, Disc Loss: 2.67743, Mono Loss: 0.03419, S2S Loss: 1.20725, SLM Loss: 1.90229
120
+ INFO:2025-06-09 02:50:00,179: Epoch [11/25], Step [400/3970], Mel Loss: 0.50464, Gen Loss: 14.90269, Disc Loss: 2.48575, Mono Loss: 0.02608, S2S Loss: 1.25147, SLM Loss: 2.09807
121
+ INFO:2025-06-09 02:51:23,892: Epoch [11/25], Step [450/3970], Mel Loss: 0.50813, Gen Loss: 11.16796, Disc Loss: 3.33927, Mono Loss: 0.02247, S2S Loss: 1.23355, SLM Loss: 2.03762
122
+ INFO:2025-06-09 02:53:11,784: Validation loss: 0.500
123
+
124
+
125
+
126
+
127
+ INFO:2025-06-09 02:54:42,588: Epoch [12/25], Step [50/3970], Mel Loss: 0.51559, Gen Loss: 12.73433, Disc Loss: 3.10313, Mono Loss: 0.02916, S2S Loss: 1.15304, SLM Loss: 2.04169
128
+ INFO:2025-06-09 02:56:06,203: Epoch [12/25], Step [100/3970], Mel Loss: 0.53166, Gen Loss: 15.15074, Disc Loss: 2.44015, Mono Loss: 0.02971, S2S Loss: 1.16761, SLM Loss: 2.05299
129
+ INFO:2025-06-09 02:57:34,086: Epoch [12/25], Step [150/3970], Mel Loss: 0.52280, Gen Loss: 13.14780, Disc Loss: 2.62888, Mono Loss: 0.02973, S2S Loss: 1.36648, SLM Loss: 2.20558
130
+ INFO:2025-06-09 02:59:02,556: Epoch [12/25], Step [200/3970], Mel Loss: 0.51851, Gen Loss: 11.34044, Disc Loss: 3.37104, Mono Loss: 0.02984, S2S Loss: 1.27682, SLM Loss: 2.09497
131
+ INFO:2025-06-09 03:00:26,363: Epoch [12/25], Step [250/3970], Mel Loss: 0.50483, Gen Loss: 9.85288, Disc Loss: 3.57650, Mono Loss: 0.02718, S2S Loss: 1.23974, SLM Loss: 2.03210
132
+ INFO:2025-06-09 03:01:51,385: Epoch [12/25], Step [300/3970], Mel Loss: 0.49458, Gen Loss: 11.01429, Disc Loss: 3.23517, Mono Loss: 0.03900, S2S Loss: 1.20380, SLM Loss: 2.05801
133
+ INFO:2025-06-09 03:03:16,092: Epoch [12/25], Step [350/3970], Mel Loss: 0.52576, Gen Loss: 9.24854, Disc Loss: 3.31379, Mono Loss: 0.03915, S2S Loss: 1.13775, SLM Loss: 2.09123
134
+ INFO:2025-06-09 03:04:46,606: Epoch [12/25], Step [400/3970], Mel Loss: 0.50221, Gen Loss: 10.17654, Disc Loss: 3.53851, Mono Loss: 0.03372, S2S Loss: 1.20534, SLM Loss: 2.00393
135
+ INFO:2025-06-09 03:06:09,986: Epoch [12/25], Step [450/3970], Mel Loss: 0.52733, Gen Loss: 14.04951, Disc Loss: 2.97020, Mono Loss: 0.02758, S2S Loss: 1.06250, SLM Loss: 1.99839
136
+ INFO:2025-06-09 03:07:59,553: Validation loss: 0.471
137
+
138
+
139
+
140
+
141
+ INFO:2025-06-09 03:09:33,778: Epoch [13/25], Step [50/3970], Mel Loss: 0.49267, Gen Loss: 11.16026, Disc Loss: 3.05907, Mono Loss: 0.02503, S2S Loss: 1.00237, SLM Loss: 1.82247
142
+ INFO:2025-06-09 03:11:01,558: Epoch [13/25], Step [100/3970], Mel Loss: 0.49195, Gen Loss: 9.90175, Disc Loss: 3.69180, Mono Loss: 0.02045, S2S Loss: 1.30512, SLM Loss: 1.89232
143
+ INFO:2025-06-09 03:12:28,265: Epoch [13/25], Step [150/3970], Mel Loss: 0.49494, Gen Loss: 12.47827, Disc Loss: 2.95689, Mono Loss: 0.02328, S2S Loss: 1.05090, SLM Loss: 1.93738
144
+ INFO:2025-06-09 03:13:50,919: Epoch [13/25], Step [200/3970], Mel Loss: 0.51474, Gen Loss: 11.96166, Disc Loss: 3.02671, Mono Loss: 0.03436, S2S Loss: 1.09315, SLM Loss: 1.92749
145
+ INFO:2025-06-09 03:15:17,888: Epoch [13/25], Step [250/3970], Mel Loss: 0.51118, Gen Loss: 17.41326, Disc Loss: 2.00980, Mono Loss: 0.02579, S2S Loss: 1.27793, SLM Loss: 2.11468
146
+ INFO:2025-06-09 03:16:42,783: Epoch [13/25], Step [300/3970], Mel Loss: 0.52890, Gen Loss: 10.31864, Disc Loss: 3.39953, Mono Loss: 0.03986, S2S Loss: 1.07551, SLM Loss: 1.85846
147
+ INFO:2025-06-09 03:18:10,773: Epoch [13/25], Step [350/3970], Mel Loss: 0.49484, Gen Loss: 9.64177, Disc Loss: 3.54298, Mono Loss: 0.03428, S2S Loss: 1.03875, SLM Loss: 1.94741
148
+ INFO:2025-06-09 03:19:34,116: Epoch [13/25], Step [400/3970], Mel Loss: 0.49129, Gen Loss: 15.92753, Disc Loss: 2.42741, Mono Loss: 0.03079, S2S Loss: 1.08931, SLM Loss: 1.93091
149
+ INFO:2025-06-09 03:20:57,906: Epoch [13/25], Step [450/3970], Mel Loss: 0.50837, Gen Loss: 11.88165, Disc Loss: 3.05828, Mono Loss: 0.02700, S2S Loss: 1.02226, SLM Loss: 2.08510
150
+ INFO:2025-06-09 03:22:45,754: Validation loss: 0.481
151
+
152
+
153
+
154
+
155
+ INFO:2025-06-09 03:24:21,819: Epoch [14/25], Step [50/3970], Mel Loss: 0.48985, Gen Loss: 12.91563, Disc Loss: 2.76847, Mono Loss: 0.04083, S2S Loss: 1.14642, SLM Loss: 1.95628
156
+ INFO:2025-06-09 03:25:47,789: Epoch [14/25], Step [100/3970], Mel Loss: 0.49605, Gen Loss: 12.36551, Disc Loss: 2.68910, Mono Loss: 0.02049, S2S Loss: 1.21105, SLM Loss: 1.98918
157
+ INFO:2025-06-09 03:27:15,432: Epoch [14/25], Step [150/3970], Mel Loss: 0.50257, Gen Loss: 11.92622, Disc Loss: 3.19436, Mono Loss: 0.02369, S2S Loss: 1.07385, SLM Loss: 1.89891
158
+ INFO:2025-06-09 03:28:39,067: Epoch [14/25], Step [200/3970], Mel Loss: 0.49241, Gen Loss: 10.28760, Disc Loss: 3.14552, Mono Loss: 0.02857, S2S Loss: 1.18241, SLM Loss: 1.92434
159
+ INFO:2025-06-09 03:30:02,661: Epoch [14/25], Step [250/3970], Mel Loss: 0.50479, Gen Loss: 14.76222, Disc Loss: 2.38417, Mono Loss: 0.03464, S2S Loss: 1.11464, SLM Loss: 1.90056
160
+ INFO:2025-06-09 03:31:25,232: Epoch [14/25], Step [300/3970], Mel Loss: 0.51008, Gen Loss: 10.93404, Disc Loss: 3.09978, Mono Loss: 0.02636, S2S Loss: 1.06983, SLM Loss: 2.01580
161
+ INFO:2025-06-09 03:32:47,763: Epoch [14/25], Step [350/3970], Mel Loss: 0.49603, Gen Loss: 13.14640, Disc Loss: 2.93398, Mono Loss: 0.03515, S2S Loss: 1.13013, SLM Loss: 1.87904
162
+ INFO:2025-06-09 03:34:09,946: Epoch [14/25], Step [400/3970], Mel Loss: 0.49311, Gen Loss: 15.16063, Disc Loss: 2.26509, Mono Loss: 0.02470, S2S Loss: 1.13779, SLM Loss: 1.98486
163
+ INFO:2025-06-09 03:35:35,944: Epoch [14/25], Step [450/3970], Mel Loss: 0.49960, Gen Loss: 14.25554, Disc Loss: 2.94801, Mono Loss: 0.03173, S2S Loss: 1.06616, SLM Loss: 1.87259
164
+ INFO:2025-06-09 03:37:23,594: Validation loss: 0.474
165
+
166
+
167
+
168
+
169
+ INFO:2025-06-09 03:38:59,186: Epoch [15/25], Step [50/3970], Mel Loss: 0.49733, Gen Loss: 15.70522, Disc Loss: 2.52600, Mono Loss: 0.02580, S2S Loss: 1.16987, SLM Loss: 1.96256
170
+ INFO:2025-06-09 03:40:28,804: Epoch [15/25], Step [100/3970], Mel Loss: 0.50295, Gen Loss: 14.09082, Disc Loss: 2.55250, Mono Loss: 0.02887, S2S Loss: 0.98603, SLM Loss: 1.95487
171
+ INFO:2025-06-09 03:41:51,594: Epoch [15/25], Step [150/3970], Mel Loss: 0.50385, Gen Loss: 9.58022, Disc Loss: 3.57700, Mono Loss: 0.02735, S2S Loss: 1.30176, SLM Loss: 1.89019
172
+ INFO:2025-06-09 03:43:21,769: Epoch [15/25], Step [200/3970], Mel Loss: 0.50567, Gen Loss: 13.83009, Disc Loss: 2.59786, Mono Loss: 0.02999, S2S Loss: 0.97563, SLM Loss: 1.98239
173
+ INFO:2025-06-09 03:44:48,877: Epoch [15/25], Step [250/3970], Mel Loss: 0.50670, Gen Loss: 15.52833, Disc Loss: 2.77497, Mono Loss: 0.03401, S2S Loss: 1.08127, SLM Loss: 1.98682
174
+ INFO:2025-06-09 03:46:14,023: Epoch [15/25], Step [300/3970], Mel Loss: 0.48451, Gen Loss: 10.18036, Disc Loss: 3.13374, Mono Loss: 0.03035, S2S Loss: 1.07857, SLM Loss: 1.76348
175
+ INFO:2025-06-09 03:47:40,871: Epoch [15/25], Step [350/3970], Mel Loss: 0.48213, Gen Loss: 14.40100, Disc Loss: 2.57544, Mono Loss: 0.03224, S2S Loss: 1.27353, SLM Loss: 2.20880
176
+ INFO:2025-06-09 03:49:02,418: Epoch [15/25], Step [400/3970], Mel Loss: 0.49212, Gen Loss: 12.04808, Disc Loss: 3.22365, Mono Loss: 0.04031, S2S Loss: 1.16138, SLM Loss: 1.93176
177
+ INFO:2025-06-09 03:50:29,288: Epoch [15/25], Step [450/3970], Mel Loss: 0.49498, Gen Loss: 15.64781, Disc Loss: 2.42535, Mono Loss: 0.02771, S2S Loss: 1.12834, SLM Loss: 1.92330
178
+ INFO:2025-06-09 03:52:16,555: Validation loss: 0.484
179
+
180
+
181
+
182
+
183
+ INFO:2025-06-09 03:53:54,302: Epoch [16/25], Step [50/3970], Mel Loss: 0.51999, Gen Loss: 14.74118, Disc Loss: 3.12985, Mono Loss: 0.03898, S2S Loss: 1.17255, SLM Loss: 2.04765
184
+ INFO:2025-06-09 03:55:21,475: Epoch [16/25], Step [100/3970], Mel Loss: 0.50182, Gen Loss: 12.08678, Disc Loss: 2.86320, Mono Loss: 0.03554, S2S Loss: 1.07750, SLM Loss: 1.93926
185
+ INFO:2025-06-09 03:56:47,475: Epoch [16/25], Step [150/3970], Mel Loss: 0.50050, Gen Loss: 9.98121, Disc Loss: 3.66234, Mono Loss: 0.04147, S2S Loss: 1.05245, SLM Loss: 1.79418
186
+ INFO:2025-06-09 03:58:18,119: Epoch [16/25], Step [200/3970], Mel Loss: 0.48289, Gen Loss: 11.81459, Disc Loss: 3.18255, Mono Loss: 0.02753, S2S Loss: 1.00464, SLM Loss: 2.07043
187
+ INFO:2025-06-09 03:59:52,641: Epoch [16/25], Step [250/3970], Mel Loss: 0.50084, Gen Loss: 12.98343, Disc Loss: 2.73330, Mono Loss: 0.02163, S2S Loss: 1.01344, SLM Loss: 2.09205
188
+ INFO:2025-06-09 04:01:18,443: Epoch [16/25], Step [300/3970], Mel Loss: 0.48698, Gen Loss: 10.43984, Disc Loss: 3.65968, Mono Loss: 0.02921, S2S Loss: 1.09594, SLM Loss: 1.76368
189
+ INFO:2025-06-09 04:02:43,362: Epoch [16/25], Step [350/3970], Mel Loss: 0.49274, Gen Loss: 14.27684, Disc Loss: 2.40941, Mono Loss: 0.02796, S2S Loss: 1.00817, SLM Loss: 1.93564
190
+ INFO:2025-06-09 04:04:11,908: Epoch [16/25], Step [400/3970], Mel Loss: 0.48388, Gen Loss: 12.83471, Disc Loss: 2.84156, Mono Loss: 0.02278, S2S Loss: 1.02146, SLM Loss: 1.84777
191
+ INFO:2025-06-09 04:05:42,332: Epoch [16/25], Step [450/3970], Mel Loss: 0.47624, Gen Loss: 11.16790, Disc Loss: 3.20130, Mono Loss: 0.03327, S2S Loss: 1.30728, SLM Loss: 2.05759
192
+ INFO:2025-06-09 04:07:30,779: Validation loss: 0.441
193
+
194
+
195
+
196
+
197
+ INFO:2025-06-09 04:09:06,393: Epoch [17/25], Step [50/3970], Mel Loss: 0.47322, Gen Loss: 10.21182, Disc Loss: 3.17271, Mono Loss: 0.03857, S2S Loss: 1.05606, SLM Loss: 1.89386
198
+ INFO:2025-06-09 04:10:36,875: Epoch [17/25], Step [100/3970], Mel Loss: 0.49200, Gen Loss: 11.82951, Disc Loss: 2.96425, Mono Loss: 0.03267, S2S Loss: 0.93129, SLM Loss: 1.86663
199
+ INFO:2025-06-09 04:12:04,591: Epoch [17/25], Step [150/3970], Mel Loss: 0.47137, Gen Loss: 8.62027, Disc Loss: 3.83087, Mono Loss: 0.03580, S2S Loss: 0.96079, SLM Loss: 1.77226
200
+ INFO:2025-06-09 04:13:32,190: Epoch [17/25], Step [200/3970], Mel Loss: 0.47893, Gen Loss: 10.65770, Disc Loss: 3.56187, Mono Loss: 0.02644, S2S Loss: 1.01706, SLM Loss: 2.04362
201
+ INFO:2025-06-09 04:14:56,130: Epoch [17/25], Step [250/3970], Mel Loss: 0.48653, Gen Loss: 8.75337, Disc Loss: 3.41021, Mono Loss: 0.03018, S2S Loss: 1.21360, SLM Loss: 1.96406
202
+ INFO:2025-06-09 04:16:22,783: Epoch [17/25], Step [300/3970], Mel Loss: 0.47912, Gen Loss: 13.35522, Disc Loss: 2.50100, Mono Loss: 0.03506, S2S Loss: 0.97357, SLM Loss: 1.86547
203
+ INFO:2025-06-09 04:17:48,667: Epoch [17/25], Step [350/3970], Mel Loss: 0.48305, Gen Loss: 11.61836, Disc Loss: 2.81784, Mono Loss: 0.03363, S2S Loss: 1.08002, SLM Loss: 2.06560
204
+ INFO:2025-06-09 04:19:14,326: Epoch [17/25], Step [400/3970], Mel Loss: 0.47853, Gen Loss: 12.34376, Disc Loss: 2.69610, Mono Loss: 0.02721, S2S Loss: 1.04334, SLM Loss: 1.70590
205
+ INFO:2025-06-09 04:20:35,887: Epoch [17/25], Step [450/3970], Mel Loss: 0.49008, Gen Loss: 13.10175, Disc Loss: 2.95339, Mono Loss: 0.02577, S2S Loss: 0.94490, SLM Loss: 1.85765
206
+ INFO:2025-06-09 04:22:27,589: Validation loss: 0.452
207
+
208
+
209
+
210
+
211
+ INFO:2025-06-09 04:24:06,620: Epoch [18/25], Step [50/3970], Mel Loss: 0.47998, Gen Loss: 11.54100, Disc Loss: 2.96952, Mono Loss: 0.03144, S2S Loss: 1.16552, SLM Loss: 2.03827
212
+ INFO:2025-06-09 04:25:34,179: Epoch [18/25], Step [100/3970], Mel Loss: 0.47805, Gen Loss: 13.14331, Disc Loss: 2.78030, Mono Loss: 0.03480, S2S Loss: 1.11621, SLM Loss: 1.77017
213
+ INFO:2025-06-09 04:26:56,768: Epoch [18/25], Step [150/3970], Mel Loss: 0.47885, Gen Loss: 13.77032, Disc Loss: 2.75735, Mono Loss: 0.03735, S2S Loss: 1.01647, SLM Loss: 1.84030
214
+ INFO:2025-06-09 04:28:22,414: Epoch [18/25], Step [200/3970], Mel Loss: 0.48841, Gen Loss: 9.09012, Disc Loss: 4.03486, Mono Loss: 0.02713, S2S Loss: 0.98568, SLM Loss: 2.00300
215
+ INFO:2025-06-09 04:29:49,349: Epoch [18/25], Step [250/3970], Mel Loss: 0.47842, Gen Loss: 11.78235, Disc Loss: 3.07922, Mono Loss: 0.03915, S2S Loss: 1.13136, SLM Loss: 1.98326
216
+ INFO:2025-06-09 04:31:14,759: Epoch [18/25], Step [300/3970], Mel Loss: 0.47979, Gen Loss: 11.35575, Disc Loss: 3.51251, Mono Loss: 0.02926, S2S Loss: 0.97950, SLM Loss: 1.76771
217
+ INFO:2025-06-09 04:32:36,423: Epoch [18/25], Step [350/3970], Mel Loss: 0.47505, Gen Loss: 11.21230, Disc Loss: 3.30019, Mono Loss: 0.04272, S2S Loss: 0.82646, SLM Loss: 1.81805
218
+ INFO:2025-06-09 04:34:01,703: Epoch [18/25], Step [400/3970], Mel Loss: 0.47216, Gen Loss: 12.97506, Disc Loss: 2.76475, Mono Loss: 0.02976, S2S Loss: 0.98103, SLM Loss: 1.91380
219
+ INFO:2025-06-09 04:35:26,512: Epoch [18/25], Step [450/3970], Mel Loss: 0.47869, Gen Loss: 14.93690, Disc Loss: 2.70676, Mono Loss: 0.02386, S2S Loss: 0.86181, SLM Loss: 1.86082
220
+ INFO:2025-06-09 04:37:13,388: Validation loss: 0.472
221
+
222
+
223
+
224
+
225
+ INFO:2025-06-09 04:38:47,409: Epoch [19/25], Step [50/3970], Mel Loss: 0.50106, Gen Loss: 10.75373, Disc Loss: 3.20791, Mono Loss: 0.03283, S2S Loss: 1.05680, SLM Loss: 2.04265
226
+ INFO:2025-06-09 04:40:11,614: Epoch [19/25], Step [100/3970], Mel Loss: 0.49229, Gen Loss: 12.93675, Disc Loss: 2.80900, Mono Loss: 0.02687, S2S Loss: 1.00646, SLM Loss: 1.95273
227
+ INFO:2025-06-09 04:41:38,257: Epoch [19/25], Step [150/3970], Mel Loss: 0.48437, Gen Loss: 9.62462, Disc Loss: 3.65687, Mono Loss: 0.03190, S2S Loss: 1.10174, SLM Loss: 1.83353
228
+ INFO:2025-06-09 04:43:10,165: Epoch [19/25], Step [200/3970], Mel Loss: 0.47472, Gen Loss: 14.23377, Disc Loss: 2.64775, Mono Loss: 0.03328, S2S Loss: 0.93959, SLM Loss: 1.89750
229
+ INFO:2025-06-09 04:44:38,022: Epoch [19/25], Step [250/3970], Mel Loss: 0.47779, Gen Loss: 14.43170, Disc Loss: 2.78379, Mono Loss: 0.01906, S2S Loss: 0.99975, SLM Loss: 1.75941
230
+ INFO:2025-06-09 04:46:03,339: Epoch [19/25], Step [300/3970], Mel Loss: 0.48885, Gen Loss: 13.73142, Disc Loss: 2.71020, Mono Loss: 0.03351, S2S Loss: 0.97011, SLM Loss: 1.83246
231
+ INFO:2025-06-09 04:47:31,942: Epoch [19/25], Step [350/3970], Mel Loss: 0.49685, Gen Loss: 11.76912, Disc Loss: 3.36635, Mono Loss: 0.03053, S2S Loss: 1.02552, SLM Loss: 2.14654
232
+ INFO:2025-06-09 04:49:03,189: Epoch [19/25], Step [400/3970], Mel Loss: 0.50997, Gen Loss: 12.43883, Disc Loss: 3.06039, Mono Loss: 0.02553, S2S Loss: 1.06563, SLM Loss: 2.07121
233
+ INFO:2025-06-09 04:50:29,105: Epoch [19/25], Step [450/3970], Mel Loss: 0.47820, Gen Loss: 14.01424, Disc Loss: 2.55081, Mono Loss: 0.02975, S2S Loss: 0.90710, SLM Loss: 1.83058
234
+ INFO:2025-06-09 04:52:19,425: Validation loss: 0.458
235
+
236
+
237
+
238
+
239
+ INFO:2025-06-09 04:53:52,052: Epoch [20/25], Step [50/3970], Mel Loss: 0.46500, Gen Loss: 10.54346, Disc Loss: 3.31023, Mono Loss: 0.02908, S2S Loss: 0.90407, SLM Loss: 1.78864
240
+ INFO:2025-06-09 04:55:17,797: Epoch [20/25], Step [100/3970], Mel Loss: 0.47071, Gen Loss: 12.65996, Disc Loss: 2.62974, Mono Loss: 0.03045, S2S Loss: 0.92329, SLM Loss: 1.80409
241
+ INFO:2025-06-09 04:56:43,658: Epoch [20/25], Step [150/3970], Mel Loss: 0.50137, Gen Loss: 10.70967, Disc Loss: 3.02257, Mono Loss: 0.03576, S2S Loss: 1.06377, SLM Loss: 1.94652
242
+ INFO:2025-06-09 04:58:10,885: Epoch [20/25], Step [200/3970], Mel Loss: 0.48577, Gen Loss: 13.93800, Disc Loss: 2.42582, Mono Loss: 0.03042, S2S Loss: 1.00638, SLM Loss: 1.94606
243
+ INFO:2025-06-09 04:59:36,042: Epoch [20/25], Step [250/3970], Mel Loss: 0.47842, Gen Loss: 11.98050, Disc Loss: 2.85478, Mono Loss: 0.03991, S2S Loss: 0.88393, SLM Loss: 1.78679
244
+ INFO:2025-06-09 05:01:05,055: Epoch [20/25], Step [300/3970], Mel Loss: 0.47698, Gen Loss: 14.17083, Disc Loss: 2.71589, Mono Loss: 0.03090, S2S Loss: 0.95709, SLM Loss: 1.76501
245
+ INFO:2025-06-09 05:02:31,554: Epoch [20/25], Step [350/3970], Mel Loss: 0.47476, Gen Loss: 12.08669, Disc Loss: 2.79200, Mono Loss: 0.03368, S2S Loss: 1.02864, SLM Loss: 1.86428
246
+ INFO:2025-06-09 05:03:58,164: Epoch [20/25], Step [400/3970], Mel Loss: 0.47778, Gen Loss: 16.65605, Disc Loss: 1.92339, Mono Loss: 0.02666, S2S Loss: 1.15662, SLM Loss: 1.86039
247
+ INFO:2025-06-09 05:05:20,371: Epoch [20/25], Step [450/3970], Mel Loss: 0.47590, Gen Loss: 10.66579, Disc Loss: 2.86761, Mono Loss: 0.02992, S2S Loss: 0.91479, SLM Loss: 1.80008
248
+ INFO:2025-06-09 05:07:11,411: Validation loss: 0.443
249
+
250
+
251
+
252
+
253
+ INFO:2025-06-09 05:08:44,874: Epoch [21/25], Step [50/3970], Mel Loss: 0.46999, Gen Loss: 13.12624, Disc Loss: 2.57343, Mono Loss: 0.03491, S2S Loss: 1.04637, SLM Loss: 1.88191
254
+ INFO:2025-06-09 05:10:12,520: Epoch [21/25], Step [100/3970], Mel Loss: 0.48069, Gen Loss: 13.96414, Disc Loss: 2.44672, Mono Loss: 0.02962, S2S Loss: 1.05849, SLM Loss: 1.87176
255
+ INFO:2025-06-09 05:11:34,572: Epoch [21/25], Step [150/3970], Mel Loss: 0.47060, Gen Loss: 9.93725, Disc Loss: 3.41884, Mono Loss: 0.03497, S2S Loss: 0.67804, SLM Loss: 1.80349
256
+ INFO:2025-06-09 05:13:00,537: Epoch [21/25], Step [200/3970], Mel Loss: 0.46755, Gen Loss: 13.78485, Disc Loss: 2.85819, Mono Loss: 0.02449, S2S Loss: 1.11917, SLM Loss: 1.90021
257
+ INFO:2025-06-09 05:14:25,136: Epoch [21/25], Step [250/3970], Mel Loss: 0.47410, Gen Loss: 15.36125, Disc Loss: 2.36199, Mono Loss: 0.02162, S2S Loss: 0.80983, SLM Loss: 1.83971
258
+ INFO:2025-06-09 05:15:52,212: Epoch [21/25], Step [300/3970], Mel Loss: 0.48966, Gen Loss: 10.06612, Disc Loss: 3.59549, Mono Loss: 0.02996, S2S Loss: 0.83736, SLM Loss: 1.80890
259
+ INFO:2025-06-09 05:17:15,903: Epoch [21/25], Step [350/3970], Mel Loss: 0.46942, Gen Loss: 13.45973, Disc Loss: 2.75646, Mono Loss: 0.02494, S2S Loss: 0.99167, SLM Loss: 1.77071
260
+ INFO:2025-06-09 05:18:39,464: Epoch [21/25], Step [400/3970], Mel Loss: 0.50879, Gen Loss: 12.60575, Disc Loss: 2.79759, Mono Loss: 0.02696, S2S Loss: 0.86569, SLM Loss: 1.84605
261
+ INFO:2025-06-09 05:20:04,744: Epoch [21/25], Step [450/3970], Mel Loss: 0.46598, Gen Loss: 11.56731, Disc Loss: 3.28558, Mono Loss: 0.03384, S2S Loss: 1.14885, SLM Loss: 1.89640
262
+ INFO:2025-06-09 05:21:56,848: Validation loss: 0.531
263
+
264
+
265
+
266
+
267
+ INFO:2025-06-09 05:23:29,412: Epoch [22/25], Step [50/3970], Mel Loss: 0.48972, Gen Loss: 9.84740, Disc Loss: 3.27302, Mono Loss: 0.03283, S2S Loss: 0.90118, SLM Loss: 1.72501
268
+ INFO:2025-06-09 05:24:56,674: Epoch [22/25], Step [100/3970], Mel Loss: 0.46670, Gen Loss: 12.74649, Disc Loss: 3.23306, Mono Loss: 0.02487, S2S Loss: 0.86857, SLM Loss: 1.82290
269
+ INFO:2025-06-09 05:26:22,931: Epoch [22/25], Step [150/3970], Mel Loss: 0.47883, Gen Loss: 11.48265, Disc Loss: 2.94219, Mono Loss: 0.02303, S2S Loss: 0.74939, SLM Loss: 1.80694
270
+ INFO:2025-06-09 05:27:48,330: Epoch [22/25], Step [200/3970], Mel Loss: 0.47217, Gen Loss: 13.39269, Disc Loss: 2.95101, Mono Loss: 0.02749, S2S Loss: 1.05043, SLM Loss: 1.96215
271
+ INFO:2025-06-09 05:29:13,642: Epoch [22/25], Step [250/3970], Mel Loss: 0.46282, Gen Loss: 11.55498, Disc Loss: 3.11340, Mono Loss: 0.02747, S2S Loss: 1.03113, SLM Loss: 1.80026
272
+ INFO:2025-06-09 05:30:37,418: Epoch [22/25], Step [300/3970], Mel Loss: 0.46544, Gen Loss: 11.89728, Disc Loss: 3.54280, Mono Loss: 0.02800, S2S Loss: 0.78287, SLM Loss: 1.71039
273
+ INFO:2025-06-09 05:32:00,981: Epoch [22/25], Step [350/3970], Mel Loss: 0.47136, Gen Loss: 10.04337, Disc Loss: 3.48481, Mono Loss: 0.02932, S2S Loss: 0.85002, SLM Loss: 1.70120
274
+ INFO:2025-06-09 05:33:23,173: Epoch [22/25], Step [400/3970], Mel Loss: 0.47503, Gen Loss: 13.56164, Disc Loss: 2.41358, Mono Loss: 0.03078, S2S Loss: 1.02305, SLM Loss: 1.93645
275
+ INFO:2025-06-09 05:34:50,016: Epoch [22/25], Step [450/3970], Mel Loss: 0.46342, Gen Loss: 10.59329, Disc Loss: 3.17809, Mono Loss: 0.02502, S2S Loss: 0.97440, SLM Loss: 1.87193
276
+ INFO:2025-06-09 05:36:40,602: Validation loss: 0.433
277
+
278
+
279
+
280
+
281
+ INFO:2025-06-09 05:38:17,194: Epoch [23/25], Step [50/3970], Mel Loss: 0.46734, Gen Loss: 6.01506, Disc Loss: 3.87277, Mono Loss: 0.03631, S2S Loss: 0.86797, SLM Loss: 1.78308
282
+ INFO:2025-06-09 05:39:45,134: Epoch [23/25], Step [100/3970], Mel Loss: 0.45456, Gen Loss: 10.32937, Disc Loss: 3.33413, Mono Loss: 0.03403, S2S Loss: 0.69475, SLM Loss: 1.64199
283
+ INFO:2025-06-09 05:41:10,925: Epoch [23/25], Step [150/3970], Mel Loss: 0.47284, Gen Loss: 10.40523, Disc Loss: 3.22554, Mono Loss: 0.03001, S2S Loss: 1.01316, SLM Loss: 1.79790
284
+ INFO:2025-06-09 05:42:38,603: Epoch [23/25], Step [200/3970], Mel Loss: 0.46535, Gen Loss: 13.34508, Disc Loss: 3.19096, Mono Loss: 0.03172, S2S Loss: 0.90094, SLM Loss: 1.90261
285
+ INFO:2025-06-09 05:44:04,344: Epoch [23/25], Step [250/3970], Mel Loss: 0.46423, Gen Loss: 13.47969, Disc Loss: 2.72158, Mono Loss: 0.02347, S2S Loss: 0.84805, SLM Loss: 1.84319
286
+ INFO:2025-06-09 05:45:29,913: Epoch [23/25], Step [300/3970], Mel Loss: 0.47467, Gen Loss: 13.43347, Disc Loss: 2.60238, Mono Loss: 0.02359, S2S Loss: 0.73610, SLM Loss: 1.73985
287
+ INFO:2025-06-09 05:46:57,724: Epoch [23/25], Step [350/3970], Mel Loss: 0.47841, Gen Loss: 14.26015, Disc Loss: 2.87291, Mono Loss: 0.02988, S2S Loss: 0.85957, SLM Loss: 1.76982
288
+ INFO:2025-06-09 05:48:22,973: Epoch [23/25], Step [400/3970], Mel Loss: 0.46632, Gen Loss: 11.81628, Disc Loss: 2.90987, Mono Loss: 0.02264, S2S Loss: 0.88753, SLM Loss: 1.71842
289
+ INFO:2025-06-09 05:49:48,069: Epoch [23/25], Step [450/3970], Mel Loss: 0.45516, Gen Loss: 12.15128, Disc Loss: 2.99541, Mono Loss: 0.02988, S2S Loss: 0.89353, SLM Loss: 1.90833
290
+ INFO:2025-06-09 05:51:36,052: Validation loss: 0.414
291
+
292
+
293
+
294
+
295
+ INFO:2025-06-09 05:53:10,255: Epoch [24/25], Step [50/3970], Mel Loss: 0.46737, Gen Loss: 10.23576, Disc Loss: 3.82643, Mono Loss: 0.02651, S2S Loss: 1.06037, SLM Loss: 1.99568
296
+ INFO:2025-06-09 05:54:37,468: Epoch [24/25], Step [100/3970], Mel Loss: 0.47177, Gen Loss: 12.41867, Disc Loss: 2.91510, Mono Loss: 0.02961, S2S Loss: 0.85282, SLM Loss: 1.90385
297
+ INFO:2025-06-09 05:56:04,209: Epoch [24/25], Step [150/3970], Mel Loss: 0.46322, Gen Loss: 11.11626, Disc Loss: 3.14571, Mono Loss: 0.02543, S2S Loss: 1.12388, SLM Loss: 1.97046
298
+ INFO:2025-06-09 05:57:29,491: Epoch [24/25], Step [200/3970], Mel Loss: 0.46526, Gen Loss: 15.59772, Disc Loss: 2.67776, Mono Loss: 0.03296, S2S Loss: 0.79784, SLM Loss: 1.69807
299
+ INFO:2025-06-09 05:58:52,805: Epoch [24/25], Step [250/3970], Mel Loss: 0.46915, Gen Loss: 11.35859, Disc Loss: 2.94053, Mono Loss: 0.03207, S2S Loss: 0.66795, SLM Loss: 1.76326
300
+ INFO:2025-06-09 06:00:20,548: Epoch [24/25], Step [300/3970], Mel Loss: 0.46567, Gen Loss: 11.95895, Disc Loss: 3.04587, Mono Loss: 0.03336, S2S Loss: 1.00292, SLM Loss: 1.68971
301
+ INFO:2025-06-09 06:01:46,336: Epoch [24/25], Step [350/3970], Mel Loss: 0.46672, Gen Loss: 12.67813, Disc Loss: 3.09165, Mono Loss: 0.02409, S2S Loss: 1.00270, SLM Loss: 1.76963
302
+ INFO:2025-06-09 06:03:11,223: Epoch [24/25], Step [400/3970], Mel Loss: 0.47140, Gen Loss: 12.71134, Disc Loss: 2.75087, Mono Loss: 0.03305, S2S Loss: 0.78916, SLM Loss: 1.78080
303
+ INFO:2025-06-09 06:04:34,537: Epoch [24/25], Step [450/3970], Mel Loss: 0.47079, Gen Loss: 8.86274, Disc Loss: 3.47894, Mono Loss: 0.03888, S2S Loss: 0.75467, SLM Loss: 1.68494
304
+ INFO:2025-06-09 06:06:21,549: Validation loss: 0.421
305
+
306
+
307
+
308
+
309
+ INFO:2025-06-09 06:07:56,930: Epoch [25/25], Step [50/3970], Mel Loss: 0.46430, Gen Loss: 10.24589, Disc Loss: 3.36352, Mono Loss: 0.03448, S2S Loss: 0.82375, SLM Loss: 1.89295
310
+ INFO:2025-06-09 06:09:24,068: Epoch [25/25], Step [100/3970], Mel Loss: 0.47254, Gen Loss: 9.77706, Disc Loss: 3.08066, Mono Loss: 0.02397, S2S Loss: 0.90991, SLM Loss: 1.88774
311
+ INFO:2025-06-09 06:10:48,092: Epoch [25/25], Step [150/3970], Mel Loss: 0.46583, Gen Loss: 11.92317, Disc Loss: 3.11952, Mono Loss: 0.02979, S2S Loss: 0.85678, SLM Loss: 1.75618
312
+ INFO:2025-06-09 06:12:13,142: Epoch [25/25], Step [200/3970], Mel Loss: 0.47014, Gen Loss: 11.47270, Disc Loss: 3.22507, Mono Loss: 0.02897, S2S Loss: 1.13971, SLM Loss: 1.97860
313
+ INFO:2025-06-09 06:13:37,846: Epoch [25/25], Step [250/3970], Mel Loss: 0.46148, Gen Loss: 10.36515, Disc Loss: 3.29522, Mono Loss: 0.02770, S2S Loss: 0.90695, SLM Loss: 1.83152
314
+ INFO:2025-06-09 06:15:03,448: Epoch [25/25], Step [300/3970], Mel Loss: 0.47940, Gen Loss: 14.23678, Disc Loss: 2.69026, Mono Loss: 0.03347, S2S Loss: 0.90296, SLM Loss: 1.91125
315
+ INFO:2025-06-09 06:16:31,839: Epoch [25/25], Step [350/3970], Mel Loss: 0.46467, Gen Loss: 12.48178, Disc Loss: 2.69238, Mono Loss: 0.03317, S2S Loss: 0.71102, SLM Loss: 1.83116
316
+ INFO:2025-06-09 06:17:56,540: Epoch [25/25], Step [400/3970], Mel Loss: 0.47247, Gen Loss: 11.85046, Disc Loss: 3.40187, Mono Loss: 0.03009, S2S Loss: 0.67559, SLM Loss: 1.67930
317
+ INFO:2025-06-09 06:19:22,839: Epoch [25/25], Step [450/3970], Mel Loss: 0.46736, Gen Loss: 11.12021, Disc Loss: 2.89506, Mono Loss: 0.03046, S2S Loss: 0.86089, SLM Loss: 1.69023
318
+ INFO:2025-06-09 06:21:10,431: Validation loss: 0.427
319
+
320
+
321
+
322
+
323
+ INFO:2025-06-09 07:59:08,543: Epoch [1/15], Step [50/15883], Loss: 0.59463, Disc Loss: 0.00000, Dur Loss: 1.86396, CE Loss: 0.14515, Norm Loss: 1.16517, F0 Loss: 7.99762, LM Loss: 1.75907, Gen Loss: 0.00000, Sty Loss: 0.00000, Diff Loss: 0.00000, DiscLM Loss: 0.00000, GenLM Loss: 0.00000
324
+ INFO:2025-06-09 08:00:16,205: Epoch [1/15], Step [100/15883], Loss: 0.58628, Disc Loss: 0.00000, Dur Loss: 1.41872, CE Loss: 0.10803, Norm Loss: 3.45058, F0 Loss: 8.65071, LM Loss: 2.24374, Gen Loss: 0.00000, Sty Loss: 0.00000, Diff Loss: 0.00000, DiscLM Loss: 0.00000, GenLM Loss: 0.00000
325
+ INFO:2025-06-09 08:03:19,805: Epoch [1/15], Step [50/15883], Loss: 0.62156, Disc Loss: 0.00000, Dur Loss: 1.35592, CE Loss: 0.12068, Norm Loss: 4.53912, F0 Loss: 12.41762, LM Loss: 2.12398, Gen Loss: 0.00000, Sty Loss: 0.00000, Diff Loss: 0.00000, DiscLM Loss: 0.00000, GenLM Loss: 0.00000
326
+ INFO:2025-06-09 08:04:26,944: Epoch [1/15], Step [100/15883], Loss: 0.57647, Disc Loss: 0.00000, Dur Loss: 1.64274, CE Loss: 0.13532, Norm Loss: 2.14458, F0 Loss: 10.61257, LM Loss: 2.36555, Gen Loss: 0.00000, Sty Loss: 0.00000, Diff Loss: 0.00000, DiscLM Loss: 0.00000, GenLM Loss: 0.00000
327
+ INFO:2025-06-09 08:05:34,608: Epoch [1/15], Step [150/15883], Loss: 0.57745, Disc Loss: 0.00000, Dur Loss: 1.28145, CE Loss: 0.09792, Norm Loss: 2.62505, F0 Loss: 7.29611, LM Loss: 1.92034, Gen Loss: 0.00000, Sty Loss: 0.00000, Diff Loss: 0.00000, DiscLM Loss: 0.00000, GenLM Loss: 0.00000
328
+ INFO:2025-06-09 08:06:41,940: Epoch [1/15], Step [200/15883], Loss: 0.56741, Disc Loss: 0.00000, Dur Loss: 1.18127, CE Loss: 0.08711, Norm Loss: 1.55673, F0 Loss: 9.18722, LM Loss: 1.94882, Gen Loss: 0.00000, Sty Loss: 0.00000, Diff Loss: 0.00000, DiscLM Loss: 0.00000, GenLM Loss: 0.00000
329
+ INFO:2025-06-09 08:07:51,378: Epoch [1/15], Step [250/15883], Loss: 0.58637, Disc Loss: 0.00000, Dur Loss: 1.34848, CE Loss: 0.10076, Norm Loss: 2.22754, F0 Loss: 3.84831, LM Loss: 1.77836, Gen Loss: 0.00000, Sty Loss: 0.00000, Diff Loss: 0.00000, DiscLM Loss: 0.00000, GenLM Loss: 0.00000
330
+ INFO:2025-06-09 08:09:00,842: Epoch [1/15], Step [300/15883], Loss: 0.52509, Disc Loss: 0.00000, Dur Loss: 2.02432, CE Loss: 0.15954, Norm Loss: 2.11042, F0 Loss: 6.38905, LM Loss: 2.07231, Gen Loss: 0.00000, Sty Loss: 0.00000, Diff Loss: 0.00000, DiscLM Loss: 0.00000, GenLM Loss: 0.00000
331
+ INFO:2025-06-09 08:10:11,156: Epoch [1/15], Step [350/15883], Loss: 0.52738, Disc Loss: 0.00000, Dur Loss: 1.24347, CE Loss: 0.07718, Norm Loss: 2.49583, F0 Loss: 6.60764, LM Loss: 2.14540, Gen Loss: 0.00000, Sty Loss: 0.00000, Diff Loss: 0.00000, DiscLM Loss: 0.00000, GenLM Loss: 0.00000
332
+ INFO:2025-06-09 08:11:21,010: Epoch [1/15], Step [400/15883], Loss: 0.53797, Disc Loss: 0.00000, Dur Loss: 1.27058, CE Loss: 0.08413, Norm Loss: 2.08075, F0 Loss: 4.33628, LM Loss: 2.04978, Gen Loss: 0.00000, Sty Loss: 0.00000, Diff Loss: 0.00000, DiscLM Loss: 0.00000, GenLM Loss: 0.00000
333
+ INFO:2025-06-09 08:12:31,649: Epoch [1/15], Step [450/15883], Loss: 0.53851, Disc Loss: 0.00000, Dur Loss: 1.21631, CE Loss: 0.08424, Norm Loss: 2.76793, F0 Loss: 5.36540, LM Loss: 1.88584, Gen Loss: 0.00000, Sty Loss: 0.00000, Diff Loss: 0.00000, DiscLM Loss: 0.00000, GenLM Loss: 0.00000
334
+ INFO:2025-06-09 08:13:39,582: Epoch [1/15], Step [500/15883], Loss: 0.53331, Disc Loss: 0.00000, Dur Loss: 1.04547, CE Loss: 0.07928, Norm Loss: 1.58679, F0 Loss: 3.28320, LM Loss: 1.99279, Gen Loss: 0.00000, Sty Loss: 0.00000, Diff Loss: 0.00000, DiscLM Loss: 0.00000, GenLM Loss: 0.00000
335
+ INFO:2025-06-09 08:14:47,552: Epoch [1/15], Step [550/15883], Loss: 0.53619, Disc Loss: 0.00000, Dur Loss: 1.58742, CE Loss: 0.11445, Norm Loss: 3.74290, F0 Loss: 4.74286, LM Loss: 1.91775, Gen Loss: 0.00000, Sty Loss: 0.00000, Diff Loss: 0.00000, DiscLM Loss: 0.00000, GenLM Loss: 0.00000
336
+ INFO:2025-06-09 08:15:55,037: Epoch [1/15], Step [600/15883], Loss: 0.54803, Disc Loss: 0.00000, Dur Loss: 1.44735, CE Loss: 0.09927, Norm Loss: 2.30749, F0 Loss: 6.71221, LM Loss: 1.95221, Gen Loss: 0.00000, Sty Loss: 0.00000, Diff Loss: 0.00000, DiscLM Loss: 0.00000, GenLM Loss: 0.00000
337
+ INFO:2025-06-09 08:17:02,005: Epoch [1/15], Step [650/15883], Loss: 0.53016, Disc Loss: 0.00000, Dur Loss: 1.47494, CE Loss: 0.12974, Norm Loss: 2.32919, F0 Loss: 7.98697, LM Loss: 2.14263, Gen Loss: 0.00000, Sty Loss: 0.00000, Diff Loss: 0.00000, DiscLM Loss: 0.00000, GenLM Loss: 0.00000
338
+ INFO:2025-06-09 08:18:09,942: Epoch [1/15], Step [700/15883], Loss: 0.53085, Disc Loss: 0.00000, Dur Loss: 1.08772, CE Loss: 0.05854, Norm Loss: 1.36569, F0 Loss: 3.09399, LM Loss: 1.99569, Gen Loss: 0.00000, Sty Loss: 0.00000, Diff Loss: 0.00000, DiscLM Loss: 0.00000, GenLM Loss: 0.00000
339
+ INFO:2025-06-09 08:19:15,731: Epoch [1/15], Step [750/15883], Loss: 0.53651, Disc Loss: 0.00000, Dur Loss: 0.83910, CE Loss: 0.05060, Norm Loss: 2.25991, F0 Loss: 11.18594, LM Loss: 1.93452, Gen Loss: 0.00000, Sty Loss: 0.00000, Diff Loss: 0.00000, DiscLM Loss: 0.00000, GenLM Loss: 0.00000
340
+ INFO:2025-06-09 08:20:24,249: Epoch [1/15], Step [800/15883], Loss: 0.52069, Disc Loss: 0.00000, Dur Loss: 1.07455, CE Loss: 0.06327, Norm Loss: 5.22029, F0 Loss: 7.52101, LM Loss: 2.07087, Gen Loss: 0.00000, Sty Loss: 0.00000, Diff Loss: 0.00000, DiscLM Loss: 0.00000, GenLM Loss: 0.00000
341
+ INFO:2025-06-09 08:21:34,678: Epoch [1/15], Step [850/15883], Loss: 0.51518, Disc Loss: 0.00000, Dur Loss: 1.63769, CE Loss: 0.10630, Norm Loss: 2.12592, F0 Loss: 2.67696, LM Loss: 1.75673, Gen Loss: 0.00000, Sty Loss: 0.00000, Diff Loss: 0.00000, DiscLM Loss: 0.00000, GenLM Loss: 0.00000
342
+ INFO:2025-06-09 08:22:43,786: Epoch [1/15], Step [900/15883], Loss: 0.50269, Disc Loss: 0.00000, Dur Loss: 1.29939, CE Loss: 0.07080, Norm Loss: 1.26741, F0 Loss: 2.30947, LM Loss: 1.68987, Gen Loss: 0.00000, Sty Loss: 0.00000, Diff Loss: 0.00000, DiscLM Loss: 0.00000, GenLM Loss: 0.00000
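The training lines above follow a fixed key-value layout (an INFO prefix with a timestamp, then Epoch, Step, and a list of named losses), so they can be pulled into a table for plotting. A minimal parsing sketch, assuming the log is saved as train.log and the field names match the lines shown (parse_train_log is a hypothetical helper, not part of this repo):

import re

# A training line looks like:
#   INFO:2025-06-09 01:00:12,153: Epoch [3/25], Step [50/3970], Mel Loss: 0.63656, Gen Loss: 0.00000, ...
STEP_RE = re.compile(r"Epoch \[(\d+)/\d+\], Step \[(\d+)/\d+\], (.+)")
LOSS_RE = re.compile(r"([A-Za-z][A-Za-z0-9 ]*?): ([0-9]*\.?[0-9]+)")

def parse_train_log(path="train.log"):  # file name is an assumption
    """Return one dict per logged step, e.g. {'epoch': 3, 'step': 50, 'Mel Loss': 0.63656, ...}."""
    rows = []
    with open(path) as f:
        for line in f:
            m = STEP_RE.search(line)
            if not m:
                continue  # skip blank lines and the "Validation loss" summaries
            losses = {k.strip(): float(v) for k, v in LOSS_RE.findall(m.group(3))}
            rows.append({"epoch": int(m.group(1)), "step": int(m.group(2)), **losses})
    return rows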
.ipynb_checkpoints/train_second-checkpoint.py ADDED
@@ -0,0 +1,879 @@
1
+ # load packages
2
+ import random
3
+ import yaml
4
+ import time
5
+ from munch import Munch
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ import torchaudio
11
+ import librosa
12
+ import click
13
+ import shutil
14
+ import traceback
15
+ import warnings
16
+ warnings.simplefilter('ignore')
17
+ from torch.utils.tensorboard import SummaryWriter
18
+
19
+ from meldataset import build_dataloader
20
+
21
+ from Utils.ASR.models import ASRCNN
22
+ from Utils.JDC.model import JDCNet
23
+ from Utils.PLBERT.util import load_plbert
24
+
25
+ from models import *
26
+ from losses import *
27
+ from utils import *
28
+
29
+ from Modules.slmadv import SLMAdversarialLoss
30
+ from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
31
+
32
+ from optimizers import build_optimizer
33
+
34
+ def clip_to_bert(texts, mask, max_len: int = 510):
35
+ """
36
+ Hard-clip batch to ≤ max_len tokens and return
37
+ (texts_clipped, **fresh full-width mask**, new_lengths).
38
+ """
39
+ if texts.size(1) > max_len:
40
+ texts = texts[:, :max_len]
41
+ lengths = (texts != 0).sum(dim=1) # PAD id = 0
42
+ seq_len = texts.size(1)
43
+ mask = torch.arange(seq_len, device=texts.device).unsqueeze(0) >= \
44
+ lengths.unsqueeze(1) # shape [B, seq_len]
45
+ return texts, mask, lengths
46
+
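A quick illustration of what clip_to_bert above does, as a sketch with made-up shapes (it assumes the function exactly as defined in this file): sequences longer than 510 tokens are hard-clipped, and the lengths and padding mask are rebuilt from the surviving non-PAD (id 0) tokens.

import torch

texts = torch.zeros(2, 600, dtype=torch.long)   # batch of 2, PAD id 0, padded to length 600
texts[0, :550] = 1                              # sequence 0: 550 real tokens (over the 510 cap)
texts[1, :120] = 1                              # sequence 1: 120 real tokens
mask = texts == 0                               # original padding mask

texts, mask, lengths = clip_to_bert(texts, mask, max_len=510)
print(texts.shape)       # torch.Size([2, 510]) -- clipped to PL-BERT's 510-token limit
print(lengths.tolist())  # [510, 120]           -- recomputed after clipping
print(mask.shape)        # torch.Size([2, 510]) -- True where position >= new length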
47
+ # simple fix for dataparallel that allows access to class attributes
48
+ class MyDataParallel(torch.nn.DataParallel):
49
+ def __getattr__(self, name):
50
+ try:
51
+ return super().__getattr__(name)
52
+ except AttributeError:
53
+ return getattr(self.module, name)
54
+
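MyDataParallel above exists because plain torch.nn.DataParallel does not forward attribute access to the wrapped model, so custom attributes (for example n_down, read later as model.text_aligner.n_down) would raise AttributeError; the except branch forwards the lookup to .module. A small sketch, assuming the class as defined above:

from torch import nn

class Dummy(nn.Module):
    def __init__(self):
        super().__init__()
        self.n_down = 3   # plain attribute, not a parameter/buffer/submodule

wrapped = MyDataParallel(Dummy())
print(wrapped.n_down)     # 3 -- the lookup falls through to the wrapped module
# nn.DataParallel(Dummy()).n_down would raise AttributeError instead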
55
+ import logging
56
+ from logging import StreamHandler
57
+ logger = logging.getLogger(__name__)
58
+ logger.setLevel(logging.DEBUG)
59
+ handler = StreamHandler()
60
+ handler.setLevel(logging.DEBUG)
61
+ logger.addHandler(handler)
62
+
63
+
64
+ @click.command()
65
+ @click.option('-p', '--config_path', default='Configs/config.yml', type=str)
66
+ def main(config_path):
67
+ config = yaml.safe_load(open(config_path))
68
+
69
+ log_dir = config['log_dir']
70
+ if not osp.exists(log_dir): os.makedirs(log_dir, exist_ok=True)
71
+ shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
72
+ writer = SummaryWriter(log_dir + "/tensorboard")
73
+
74
+ # write logs
75
+ file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
76
+ file_handler.setLevel(logging.DEBUG)
77
+ file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
78
+ logger.addHandler(file_handler)
79
+
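For reference, the '%(levelname)s:%(asctime)s: %(message)s' format configured here is what produces the 'INFO:2025-06-09 01:00:12,153: ...' prefixes seen in the train-checkpoint.log diff above. A standalone reproduction of the formatting:

import logging

logging.basicConfig(format='%(levelname)s:%(asctime)s: %(message)s', level=logging.INFO)
logging.info('Epoch [3/25], Step [50/3970], Mel Loss: 0.63656')
# prints e.g.: INFO:2025-06-09 01:00:12,153: Epoch [3/25], Step [50/3970], Mel Loss: 0.63656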
80
+
81
+ batch_size = config.get('batch_size', 10)
82
+
83
+ epochs = config.get('epochs_2nd', 200)
84
+ save_freq = config.get('save_freq', 2)
85
+ log_interval = config.get('log_interval', 10)
86
+ saving_epoch = config.get('save_freq', 2)
87
+
88
+ data_params = config.get('data_params', None)
89
+ sr = config['preprocess_params'].get('sr', 24000)
90
+ train_path = data_params['train_data']
91
+ val_path = data_params['val_data']
92
+ root_path = data_params['root_path']
93
+ min_length = data_params['min_length']
94
+ OOD_data = data_params['OOD_data']
95
+
96
+ max_len = config.get('max_len', 200)
97
+
98
+ loss_params = Munch(config['loss_params'])
99
+ diff_epoch = loss_params.diff_epoch
100
+ joint_epoch = loss_params.joint_epoch
101
+
102
+ optimizer_params = Munch(config['optimizer_params'])
103
+
104
+ train_list, val_list = get_data_path_list(train_path, val_path)
105
+ device = 'cuda'
106
+
107
+ train_dataloader = build_dataloader(train_list,
108
+ root_path,
109
+ OOD_data=OOD_data,
110
+ min_length=min_length,
111
+ batch_size=batch_size,
112
+ num_workers=2,
113
+ dataset_config={},
114
+ device=device)
115
+
116
+ val_dataloader = build_dataloader(val_list,
117
+ root_path,
118
+ OOD_data=OOD_data,
119
+ min_length=min_length,
120
+ batch_size=batch_size,
121
+ validation=True,
122
+ num_workers=0,
123
+ device=device,
124
+ dataset_config={})
125
+
126
+ # load pretrained ASR model
127
+ ASR_config = config.get('ASR_config', False)
128
+ ASR_path = config.get('ASR_path', False)
129
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
130
+
131
+ # load pretrained F0 model
132
+ F0_path = config.get('F0_path', False)
133
+ pitch_extractor = load_F0_models(F0_path)
134
+
135
+ # load PL-BERT model
136
+ BERT_path = config.get('PLBERT_dir', False)
137
+ plbert = load_plbert(BERT_path)
138
+
139
+ # build model
140
+ model_params = recursive_munch(config['model_params'])
141
+ multispeaker = model_params.multispeaker
142
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
143
+ _ = [model[key].to(device) for key in model]
144
+
145
+ # DP
146
+ for key in model:
147
+ if key != "mpd" and key != "msd" and key != "wd":
148
+ model[key] = MyDataParallel(model[key])
149
+
150
+ start_epoch = 0
151
+ iters = 0
152
+
153
+ load_pretrained = config.get('pretrained_model', '') != '' and config.get('second_stage_load_pretrained', False)
154
+
155
+ if not load_pretrained:
156
+ if config.get('first_stage_path', '') != '':
157
+ first_stage_path = osp.join(log_dir, config.get('first_stage_path', 'first_stage.pth'))
158
+ print('Loading the first stage model at %s ...' % first_stage_path)
159
+ model, _, start_epoch, iters = load_checkpoint(model,
160
+ None,
161
+ first_stage_path,
162
+ load_only_params=True,
163
+ ignore_modules=['bert', 'bert_encoder', 'predictor', 'predictor_encoder', 'msd', 'mpd', 'wd', 'diffusion']) # keep starting epoch for tensorboard log
164
+
165
+ # these epochs should be counted from the start epoch
166
+ diff_epoch += start_epoch
167
+ joint_epoch += start_epoch
168
+ epochs += start_epoch
169
+
170
+ model.predictor_encoder = copy.deepcopy(model.style_encoder)
171
+ else:
172
+ raise ValueError('You need to specify the path to the first stage model.')
173
+
174
+ gl = GeneratorLoss(model.mpd, model.msd).to(device)
175
+ dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
176
+ wl = WavLMLoss(model_params.slm.model,
177
+ model.wd,
178
+ sr,
179
+ model_params.slm.sr).to(device)
180
+
181
+ gl = MyDataParallel(gl)
182
+ dl = MyDataParallel(dl)
183
+ wl = MyDataParallel(wl)
184
+
185
+ sampler = DiffusionSampler(
186
+ model.diffusion.diffusion,
187
+ sampler=ADPM2Sampler(),
188
+ sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
189
+ clamp=False
190
+ )
191
+
192
+ scheduler_params = {
193
+ "max_lr": optimizer_params.lr,
194
+ "pct_start": float(0),
195
+ "epochs": epochs,
196
+ "steps_per_epoch": len(train_dataloader),
197
+ }
198
+ scheduler_params_dict= {key: scheduler_params.copy() for key in model}
199
+ scheduler_params_dict['bert']['max_lr'] = optimizer_params.bert_lr * 2
200
+ scheduler_params_dict['decoder']['max_lr'] = optimizer_params.ft_lr * 2
201
+ scheduler_params_dict['style_encoder']['max_lr'] = optimizer_params.ft_lr * 2
202
+
203
+ optimizer = build_optimizer({key: model[key].parameters() for key in model},
204
+ scheduler_params_dict=scheduler_params_dict, lr=optimizer_params.lr)
205
+
206
+ # adjust BERT learning rate
207
+ for g in optimizer.optimizers['bert'].param_groups:
208
+ g['betas'] = (0.9, 0.99)
209
+ g['lr'] = optimizer_params.bert_lr
210
+ g['initial_lr'] = optimizer_params.bert_lr
211
+ g['min_lr'] = 0
212
+ g['weight_decay'] = 0.01
213
+
214
+ # adjust acoustic module learning rate
215
+ for module in ["decoder", "style_encoder"]:
216
+ for g in optimizer.optimizers[module].param_groups:
217
+ g['betas'] = (0.0, 0.99)
218
+ g['lr'] = optimizer_params.ft_lr
219
+ g['initial_lr'] = optimizer_params.ft_lr
220
+ g['min_lr'] = 0
221
+ g['weight_decay'] = 1e-4
222
+
223
+ # load models if there is a model
224
+ if load_pretrained:
225
+ model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'],
226
+ load_only_params=config.get('load_only_params', True))
227
+
228
+ n_down = model.text_aligner.n_down
229
+
230
+ best_loss = float('inf') # best test loss
231
+ loss_train_record = list([])
232
+ loss_test_record = list([])
233
+ iters = 0
234
+
235
+ criterion = nn.L1Loss() # F0 loss (regression)
236
+ torch.cuda.empty_cache()
237
+
238
+ stft_loss = MultiResolutionSTFTLoss().to(device)
239
+
240
+ print('BERT', optimizer.optimizers['bert'])
241
+ print('decoder', optimizer.optimizers['decoder'])
242
+
243
+ start_ds = False
244
+
245
+ running_std = []
246
+
247
+ slmadv_params = Munch(config['slmadv_params'])
248
+ slmadv = SLMAdversarialLoss(model, wl, sampler,
249
+ slmadv_params.min_len,
250
+ slmadv_params.max_len,
251
+ batch_percentage=slmadv_params.batch_percentage,
252
+ skip_update=slmadv_params.iter,
253
+ sig=slmadv_params.sig
254
+ )
255
+
256
+
257
+ for epoch in range(start_epoch, epochs):
258
+ running_loss = 0
259
+ start_time = time.time()
260
+
261
+ _ = [model[key].eval() for key in model]
262
+
263
+ model.predictor.train()
264
+ model.bert_encoder.train()
265
+ model.bert.train()
266
+ model.msd.train()
267
+ model.mpd.train()
268
+
269
+
270
+ if epoch >= diff_epoch:
271
+ start_ds = True
272
+
273
+ for i, batch in enumerate(train_dataloader):
274
+ waves = batch[0]
275
+ batch = [b.to(device) for b in batch[1:]]
276
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
277
+
278
+ # --------------- CLIP TEXTS *ONCE* -----------------
279
+ text_mask = length_to_mask(input_lengths).to(texts.device)
280
+ texts, text_mask, input_lengths = clip_to_bert(texts, text_mask)
281
+ # ── drop rows that became all-PAD after clipping ───────────
282
+ keep = (input_lengths > 0).nonzero(as_tuple=True)[0]
283
+ if keep.numel() != texts.size(0):
284
+ texts, text_mask, input_lengths = texts[keep], text_mask[keep], input_lengths[keep]
285
+ ref_texts, ref_lengths = ref_texts[keep], ref_lengths[keep]
286
+ mels, mel_input_length, ref_mels = mels[keep], mel_input_length[keep], ref_mels[keep]
287
+ waves = [waves[i] for i in keep.tolist()]
288
+ # ----------------------------------------------------
289
+
290
+ with torch.no_grad():
291
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
292
+ mel_mask = length_to_mask(mel_input_length).to(device)
293
+
294
+ try:
295
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
296
+ s2s_attn = s2s_attn.transpose(-1, -2)
297
+ s2s_attn = s2s_attn[..., 1:]
298
+ s2s_attn = s2s_attn.transpose(-1, -2)
299
+ except:
300
+ continue
301
+
302
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
303
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
304
+
305
+ # encode
306
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
307
+ asr = (t_en @ s2s_attn_mono)
308
+
309
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
310
+
311
+ # compute reference styles
312
+ if multispeaker and epoch >= diff_epoch:
313
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
314
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
315
+ ref = torch.cat([ref_ss, ref_sp], dim=1)
316
+
317
+ # compute the style of the entire utterance
318
+ # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool)
319
+ ss = []
320
+ gs = []
321
+ for bib in range(len(mel_input_length)):
322
+ mel_length = int(mel_input_length[bib].item())
323
+ mel = mels[bib, :, :mel_input_length[bib]]
324
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
325
+ ss.append(s)
326
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
327
+ gs.append(s)
328
+
329
+ s_dur = torch.stack(ss).squeeze() # global prosodic styles
330
+ gs = torch.stack(gs).squeeze() # global acoustic styles
331
+ s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser
332
+
333
+ # texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
334
+
335
+ # # ────── PATCH: keep PL-BERT below 512 tokens ─────────
336
+ # MAX_BERT_LEN = 510 # leave room for [CLS] and [SEP]
337
+ # if texts.size(1) > MAX_BERT_LEN: # truncate batch-wise
338
+ # texts = texts[:, :MAX_BERT_LEN]
339
+ # seq_len = texts.size(1) # current padded width
340
+ # input_lengths = (texts != 0).sum(1) # 0 is PAD
341
+ # arange_row = torch.arange(seq_len, device=texts.device) # shape [L]
342
+ # text_mask = arange_row.unsqueeze(0) >= input_lengths.unsqueeze(1)
343
+ # # shape [B, L]
344
+
345
+ # # keep only rows that still have at least one real token
346
+ # keep = (input_lengths > 0).nonzero(as_tuple=True)[0]
347
+ # if keep.numel() != texts.size(0): # a row was truncated to length 0
348
+ # texts, text_mask, input_lengths = texts[keep], text_mask[keep], input_lengths[keep]
349
+ # ref_texts, ref_lengths = ref_texts[keep], ref_lengths[keep]
350
+ # mels, mel_input_length, ref_mels = mels[keep], mel_input_length[keep], ref_mels[keep]
351
+ # waves = [waves[i] for i in keep.tolist()]
352
+
353
+ # # clip alignments to the *current* width (seq_len)
354
+ # s2s_attn_mono = s2s_attn_mono[:, :seq_len, :]
355
+ # d_gt = d_gt[:, :seq_len]
356
+ # # ─────────────────────────────────────────────────────
357
+
358
+ # -------------------------------------------------------------
359
+ # Now build *everything* that depends on token count
360
+ with torch.no_grad():
361
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
362
+
363
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
364
+ s2s_attn = s2s_attn.transpose(-1, -2)[..., 1:].transpose(-1, -2)
365
+
366
+ mask_ST = mask_from_lens(s2s_attn, input_lengths,
367
+ mel_input_length // 2**n_down)
368
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
369
+
370
+ asr = t_en @ s2s_attn_mono
371
+ d_gt = s2s_attn_mono.sum(dim=-1)
372
+
373
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
374
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
375
+
376
+ # denoiser training
377
+ if epoch >= diff_epoch:
378
+ num_steps = np.random.randint(3, 5)
379
+
380
+ if model_params.diffusion.dist.estimate_sigma_data:
381
+ model.diffusion.module.diffusion.sigma_data = s_trg.std(axis=-1).mean().item() # batch-wise std estimation
382
+ running_std.append(model.diffusion.module.diffusion.sigma_data)
383
+
384
+ if multispeaker:
385
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
386
+ embedding=bert_dur,
387
+ embedding_scale=1,
388
+ features=ref, # reference from the same speaker as the embedding
389
+ embedding_mask_proba=0.1,
390
+ num_steps=num_steps).squeeze(1)
391
+ loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean() # EDM loss
392
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
393
+ else:
394
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
395
+ embedding=bert_dur,
396
+ embedding_scale=1,
397
+ embedding_mask_proba=0.1,
398
+ num_steps=num_steps).squeeze(1)
399
+ loss_diff = model.diffusion.module.diffusion(s_trg.unsqueeze(1), embedding=bert_dur).mean() # EDM loss
400
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
401
+ else:
402
+ loss_sty = 0
403
+ loss_diff = 0
404
+
405
+ d, p = model.predictor(d_en, s_dur,
406
+ input_lengths,
407
+ s2s_attn_mono,
408
+ text_mask)
409
+
410
+ mel_len = min(int(mel_input_length.min().item() / 2 - 1), max_len // 2)
411
+ mel_len_st = int(mel_input_length.min().item() / 2 - 1)
412
+ en = []
413
+ gt = []
414
+ st = []
415
+ p_en = []
416
+ wav = []
417
+
418
+ for bib in range(len(mel_input_length)):
419
+ mel_length = int(mel_input_length[bib].item() / 2)
420
+
421
+ random_start = np.random.randint(0, mel_length - mel_len)
422
+ en.append(asr[bib, :, random_start:random_start+mel_len])
423
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
424
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
425
+
426
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
427
+ wav.append(torch.from_numpy(y).to(device))
428
+
429
+ # style reference (better to be different from the GT)
430
+ random_start = np.random.randint(0, mel_length - mel_len_st)
431
+ st.append(mels[bib, :, (random_start * 2):((random_start+mel_len_st) * 2)])
432
+
433
+ wav = torch.stack(wav).float().detach()
434
+
435
+ en = torch.stack(en)
436
+ p_en = torch.stack(p_en)
437
+ gt = torch.stack(gt).detach()
438
+ st = torch.stack(st).detach()
439
+
440
+ if gt.size(-1) < 80:
441
+ continue
442
+
443
+ s_dur = model.predictor_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
444
+ s = model.style_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
445
+
446
+ with torch.no_grad():
447
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
448
+ F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()
449
+
450
+ asr_real = model.text_aligner.get_feature(gt)
451
+
452
+ N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
453
+
454
+ y_rec_gt = wav.unsqueeze(1)
455
+ y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
456
+
457
+ if epoch >= joint_epoch:
458
+ # ground truth from recording
459
+ wav = y_rec_gt # use recording since decoder is tuned
460
+ else:
461
+ # ground truth from reconstruction
462
+ wav = y_rec_gt_pred # use reconstruction since decoder is fixed
463
+
464
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur)
465
+
466
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
467
+
468
+ loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10
469
+ loss_norm_rec = F.smooth_l1_loss(N_real, N_fake)
470
+
471
+ if start_ds:
472
+ optimizer.zero_grad()
473
+ d_loss = dl(wav.detach(), y_rec.detach()).mean()
474
+ d_loss.backward()
475
+ optimizer.step('msd')
476
+ optimizer.step('mpd')
477
+ else:
478
+ d_loss = 0
479
+
480
+ # generator loss
481
+ optimizer.zero_grad()
482
+
483
+ loss_mel = stft_loss(y_rec, wav)
484
+ if start_ds:
485
+ loss_gen_all = gl(wav, y_rec).mean()
486
+ else:
487
+ loss_gen_all = 0
488
+ loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean()
489
+
490
+ loss_ce = 0
491
+ loss_dur = 0
492
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
493
+ _s2s_pred = _s2s_pred[:_text_length, :]
494
+ _text_input = _text_input[:_text_length].long()
495
+ _s2s_trg = torch.zeros_like(_s2s_pred)
496
+ for p in range(_s2s_trg.shape[0]):
497
+ _s2s_trg[p, :_text_input[p]] = 1
498
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
499
+
500
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
501
+ _text_input[1:_text_length-1])
502
+ loss_ce += F.binary_cross_entropy_with_logits(_s2s_pred.flatten(), _s2s_trg.flatten())
503
+
504
+ loss_ce /= texts.size(0)
505
+ loss_dur /= texts.size(0)
506
+
507
+ g_loss = loss_params.lambda_mel * loss_mel + \
508
+ loss_params.lambda_F0 * loss_F0_rec + \
509
+ loss_params.lambda_ce * loss_ce + \
510
+ loss_params.lambda_norm * loss_norm_rec + \
511
+ loss_params.lambda_dur * loss_dur + \
512
+ loss_params.lambda_gen * loss_gen_all + \
513
+ loss_params.lambda_slm * loss_lm + \
514
+ loss_params.lambda_sty * loss_sty + \
515
+ loss_params.lambda_diff * loss_diff
516
+
517
+ running_loss += loss_mel.item()
518
+ g_loss.backward()
519
+ if torch.isnan(g_loss):
520
+ from IPython.core.debugger import set_trace
521
+ set_trace()
522
+
523
+ optimizer.step('bert_encoder')
524
+ optimizer.step('bert')
525
+ optimizer.step('predictor')
526
+ optimizer.step('predictor_encoder')
527
+
528
+ if epoch >= diff_epoch:
529
+ optimizer.step('diffusion')
530
+
531
+ if epoch >= joint_epoch:
532
+ optimizer.step('style_encoder')
533
+ optimizer.step('decoder')
534
+
535
+ # randomly pick whether to use in-distribution text
536
+ if np.random.rand() < 0.5:
537
+ use_ind = True
538
+ else:
539
+ use_ind = False
540
+
541
+ if use_ind:
542
+ ref_lengths = input_lengths
543
+ ref_texts = texts
544
+
545
+ # ---- clip reference text exactly the same way ----
546
+ ref_mask = length_to_mask(ref_lengths).to(ref_texts.device)
547
+ ref_texts, ref_mask, ref_lengths = clip_to_bert(ref_texts, ref_mask)
548
+
549
+ slm_out = slmadv(i,
550
+ y_rec_gt,
551
+ y_rec_gt_pred,
552
+ waves,
553
+ mel_input_length,
554
+ ref_texts,
555
+ ref_lengths, use_ind, s_trg.detach(), ref if multispeaker else None)
556
+
557
+ if slm_out is None:
558
+ continue
559
+
560
+ d_loss_slm, loss_gen_lm, y_pred = slm_out
561
+
562
+ # SLM generator loss
563
+ optimizer.zero_grad()
564
+ loss_gen_lm.backward()
565
+
566
+ # compute the gradient norm
567
+ total_norm = {}
568
+ for key in model.keys():
569
+ total_norm[key] = 0
570
+ parameters = [p for p in model[key].parameters() if p.grad is not None and p.requires_grad]
571
+ for p in parameters:
572
+ param_norm = p.grad.detach().data.norm(2)
573
+ total_norm[key] += param_norm.item() ** 2
574
+ total_norm[key] = total_norm[key] ** 0.5
575
+
576
+ # gradient scaling
577
+ if total_norm['predictor'] > slmadv_params.thresh:
578
+ for key in model.keys():
579
+ for p in model[key].parameters():
580
+ if p.grad is not None:
581
+ p.grad *= (1 / total_norm['predictor'])
582
+
583
+ for p in model.predictor.duration_proj.parameters():
584
+ if p.grad is not None:
585
+ p.grad *= slmadv_params.scale
586
+
587
+ for p in model.predictor.lstm.parameters():
588
+ if p.grad is not None:
589
+ p.grad *= slmadv_params.scale
590
+
591
+ for p in model.diffusion.parameters():
592
+ if p.grad is not None:
593
+ p.grad *= slmadv_params.scale
594
+
595
+ optimizer.step('bert_encoder')
596
+ optimizer.step('bert')
597
+ optimizer.step('predictor')
598
+ optimizer.step('diffusion')
599
+
600
+ # SLM discriminator loss
601
+ if d_loss_slm != 0:
602
+ optimizer.zero_grad()
603
+ d_loss_slm.backward(retain_graph=True)
604
+ optimizer.step('wd')
605
+
606
+ else:
607
+ d_loss_slm, loss_gen_lm = 0, 0
608
+
609
+ iters = iters + 1
610
+
611
+ if (i+1)%log_interval == 0:
612
+ logger.info ('Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f'
613
+ %(epoch+1, epochs, i+1, len(train_list)//batch_size, running_loss / log_interval, d_loss, loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff, d_loss_slm, loss_gen_lm))
614
+
615
+ writer.add_scalar('train/mel_loss', running_loss / log_interval, iters)
616
+ writer.add_scalar('train/gen_loss', loss_gen_all, iters)
617
+ writer.add_scalar('train/d_loss', d_loss, iters)
618
+ writer.add_scalar('train/ce_loss', loss_ce, iters)
619
+ writer.add_scalar('train/dur_loss', loss_dur, iters)
620
+ writer.add_scalar('train/slm_loss', loss_lm, iters)
621
+ writer.add_scalar('train/norm_loss', loss_norm_rec, iters)
622
+ writer.add_scalar('train/F0_loss', loss_F0_rec, iters)
623
+ writer.add_scalar('train/sty_loss', loss_sty, iters)
624
+ writer.add_scalar('train/diff_loss', loss_diff, iters)
625
+ writer.add_scalar('train/d_loss_slm', d_loss_slm, iters)
626
+ writer.add_scalar('train/gen_loss_slm', loss_gen_lm, iters)
627
+
628
+ running_loss = 0
629
+
630
+ print('Time elapsed:', time.time()-start_time)
631
+
632
+ loss_test = 0
633
+ loss_align = 0
634
+ loss_f = 0
635
+ _ = [model[key].eval() for key in model]
636
+
637
+ with torch.no_grad():
638
+ iters_test = 0
639
+ for batch_idx, batch in enumerate(val_dataloader):
640
+ optimizer.zero_grad()
641
+
642
+ try:
643
+ waves = batch[0]
644
+ batch = [b.to(device) for b in batch[1:]]
645
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
646
+
647
+ text_mask = length_to_mask(input_lengths).to(texts.device) # build the mask here; it is not defined yet in the validation scope
+ texts, text_mask, input_lengths = clip_to_bert(texts, text_mask)
648
+ keep = (input_lengths > 0).nonzero(as_tuple=True)[0]
649
+ if keep.numel() != texts.size(0):
650
+ texts, text_mask, input_lengths = texts[keep], text_mask[keep], input_lengths[keep]
651
+ ref_texts, ref_lengths = ref_texts[keep], ref_lengths[keep]
652
+ mels, mel_input_length, ref_mels = mels[keep], mel_input_length[keep], ref_mels[keep]
653
+ waves = [waves[i] for i in keep.tolist()]
654
+
655
+ with torch.no_grad():
656
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(texts.device)
657
+ # mask = length_to_mask(mel_input_length // (2 ** n_down)).to('cuda')
658
+
659
+ # _, _, s2s_attn = model.text_aligner(mels, mask, texts)
660
+ # s2s_attn = s2s_attn.transpose(-1, -2)
661
+ # s2s_attn = s2s_attn[..., 1:]
662
+ # s2s_attn = s2s_attn.transpose(-1, -2)
663
+
664
+ # mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
665
+ # s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
666
+
667
+ # # encode
668
+ # t_en = model.text_encoder(texts, input_lengths, text_mask)
669
+ # asr = (t_en @ s2s_attn_mono)
670
+
671
+ # d_gt = s2s_attn_mono.sum(axis=-1).detach()
672
+
673
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
674
+ s2s_attn = s2s_attn.transpose(-1, -2)[..., 1:].transpose(-1, -2)
675
+ mask_ST = mask_from_lens(s2s_attn, input_lengths,
676
+ mel_input_length // 2 ** n_down)
677
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
678
+
679
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
680
+ asr = t_en @ s2s_attn_mono
681
+ d_gt = s2s_attn_mono.sum(dim=-1).detach()
682
+
683
+ ss = []
684
+ gs = []
685
+
686
+ for bib in range(len(mel_input_length)):
687
+ mel_length = int(mel_input_length[bib].item())
688
+ mel = mels[bib, :, :mel_input_length[bib]]
689
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
690
+ ss.append(s)
691
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
692
+ gs.append(s)
693
+
694
+ s = torch.stack(ss).squeeze()
695
+ gs = torch.stack(gs).squeeze()
696
+ s_trg = torch.cat([s, gs], dim=-1).detach()
697
+
698
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
699
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
700
+ d, p = model.predictor(d_en, s,
701
+ input_lengths,
702
+ s2s_attn_mono,
703
+ text_mask)
704
+ # get clips
705
+ mel_len = int(mel_input_length.min().item() / 2 - 1)
706
+ en = []
707
+ gt = []
708
+ p_en = []
709
+ wav = []
710
+
711
+ for bib in range(len(mel_input_length)):
712
+ mel_length = int(mel_input_length[bib].item() / 2)
713
+
714
+ random_start = np.random.randint(0, mel_length - mel_len)
715
+ en.append(asr[bib, :, random_start:random_start+mel_len])
716
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
717
+
718
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
719
+
720
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
721
+ wav.append(torch.from_numpy(y).to(device))
722
+
723
+ wav = torch.stack(wav).float().detach()
724
+
725
+ en = torch.stack(en)
726
+ p_en = torch.stack(p_en)
727
+ gt = torch.stack(gt).detach()
728
+
729
+ s = model.predictor_encoder(gt.unsqueeze(1))
730
+
731
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s)
732
+
733
+ loss_dur = 0
734
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
735
+ _s2s_pred = _s2s_pred[:_text_length, :]
736
+ _text_input = _text_input[:_text_length].long()
737
+ _s2s_trg = torch.zeros_like(_s2s_pred)
738
+ for bib in range(_s2s_trg.shape[0]):
739
+ _s2s_trg[bib, :_text_input[bib]] = 1
740
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
741
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
742
+ _text_input[1:_text_length-1])
743
+
744
+ loss_dur /= texts.size(0)
745
+
746
+ s = model.style_encoder(gt.unsqueeze(1))
747
+
748
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
749
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
750
+
751
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
752
+
753
+ loss_F0 = F.l1_loss(F0_real, F0_fake) / 10
754
+
755
+ loss_test += (loss_mel).mean()
756
+ loss_align += (loss_dur).mean()
757
+ loss_f += (loss_F0).mean()
758
+
759
+ iters_test += 1
760
+ except Exception as e:
761
+ print(f"run into exception", e)
762
+ traceback.print_exc()
763
+ continue
764
+
765
+ print('Epochs:', epoch + 1)
766
+ logger.info('Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f' % (loss_test / iters_test, loss_align / iters_test, loss_f / iters_test) + '\n\n\n')
767
+ print('\n\n\n')
768
+ writer.add_scalar('eval/mel_loss', loss_test / iters_test, epoch + 1)
769
+ writer.add_scalar('eval/dur_loss', loss_align / iters_test, epoch + 1)
770
+ writer.add_scalar('eval/F0_loss', loss_f / iters_test, epoch + 1)
771
+
772
+ if epoch < joint_epoch:
773
+ # generating reconstruction examples with GT duration
774
+
775
+ with torch.no_grad():
776
+ for bib in range(len(asr)):
777
+ mel_length = int(mel_input_length[bib].item())
778
+ gt = mels[bib, :, :mel_length].unsqueeze(0)
779
+ en = asr[bib, :, :mel_length // 2].unsqueeze(0)
780
+
781
+ F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
782
+ F0_real = F0_real.unsqueeze(0)
783
+ s = model.style_encoder(gt.unsqueeze(1))
784
+ real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
785
+
786
+ y_rec = model.decoder(en, F0_real, real_norm, s)
787
+
788
+ writer.add_audio('eval/y' + str(bib), y_rec.cpu().numpy().squeeze(), epoch, sample_rate=sr)
789
+
790
+ s_dur = model.predictor_encoder(gt.unsqueeze(1))
791
+ p_en = p[bib, :, :mel_length // 2].unsqueeze(0)
792
+
793
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur)
794
+
795
+ y_pred = model.decoder(en, F0_fake, N_fake, s)
796
+
797
+ writer.add_audio('pred/y' + str(bib), y_pred.cpu().numpy().squeeze(), epoch, sample_rate=sr)
798
+
799
+ if epoch == 0:
800
+ writer.add_audio('gt/y' + str(bib), waves[bib].squeeze(), epoch, sample_rate=sr)
801
+
802
+ if bib >= 5:
803
+ break
804
+ else:
805
+ # generating sampled speech from text directly
806
+ with torch.no_grad():
807
+ # compute reference styles
808
+ if multispeaker and epoch >= diff_epoch:
809
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
810
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
811
+ ref_s = torch.cat([ref_ss, ref_sp], dim=1)
812
+
813
+ for bib in range(len(d_en)):
814
+ if multispeaker:
815
+ s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(texts.device),
816
+ embedding=bert_dur[bib].unsqueeze(0),
817
+ embedding_scale=1,
818
+ features=ref_s[bib].unsqueeze(0), # reference from the same speaker as the embedding
819
+ num_steps=5).squeeze(1)
820
+ else:
821
+ s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(texts.device),
822
+ embedding=bert_dur[bib].unsqueeze(0),
823
+ embedding_scale=1,
824
+ num_steps=5).squeeze(1)
825
+
826
+ s = s_pred[:, 128:]
827
+ ref = s_pred[:, :128]
828
+
829
+ d = model.predictor.text_encoder(d_en[bib, :, :input_lengths[bib]].unsqueeze(0),
830
+ s, input_lengths[bib, ...].unsqueeze(0), text_mask[bib, :input_lengths[bib]].unsqueeze(0))
831
+
832
+ x, _ = model.predictor.lstm(d)
833
+ duration = model.predictor.duration_proj(x)
834
+
835
+ duration = torch.sigmoid(duration).sum(axis=-1)
836
+ pred_dur = torch.round(duration.squeeze()).clamp(min=1)
837
+
838
+ pred_dur[-1] += 5
839
+
840
+ pred_aln_trg = torch.zeros(input_lengths[bib], int(pred_dur.sum().data))
841
+ c_frame = 0
842
+ for i in range(pred_aln_trg.size(0)):
843
+ pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
844
+ c_frame += int(pred_dur[i].data)
845
+
846
+ # encode prosody
847
+ en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(texts.device))
848
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
849
+ out = model.decoder((t_en[bib, :, :input_lengths[bib]].unsqueeze(0) @ pred_aln_trg.unsqueeze(0).to(texts.device)),
850
+ F0_pred, N_pred, ref.squeeze().unsqueeze(0))
851
+
852
+ writer.add_audio('pred/y' + str(bib), out.cpu().numpy().squeeze(), epoch, sample_rate=sr)
853
+
854
+ if bib >= 5:
855
+ break
856
+
857
+ if epoch % saving_epoch == 0:
858
+ if (loss_test / iters_test) < best_loss:
859
+ best_loss = loss_test / iters_test
860
+ print('Saving..')
861
+ state = {
862
+ 'net': {key: model[key].state_dict() for key in model},
863
+ 'optimizer': optimizer.state_dict(),
864
+ 'iters': iters,
865
+ 'val_loss': loss_test / iters_test,
866
+ 'epoch': epoch,
867
+ }
868
+ save_path = osp.join(log_dir, 'epoch_2nd_%05d.pth' % epoch)
869
+ torch.save(state, save_path)
870
+
871
+ # if estimating sigma, save the estimated sigma
872
+ if model_params.diffusion.dist.estimate_sigma_data:
873
+ config['model_params']['diffusion']['dist']['sigma_data'] = float(np.mean(running_std))
874
+
875
+ with open(osp.join(log_dir, osp.basename(config_path)), 'w') as outfile:
876
+ yaml.dump(config, outfile, default_flow_style=True)
877
+
878
+ if __name__=="__main__":
879
+ main()
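Note: `clip_to_bert` is called throughout this script but its definition is not part of this diff. A minimal sketch of what it is assumed to do, based on the commented-out patch above (the 510-token limit, the PAD id of 0, and the exact signature are assumptions taken from that patch):

import torch

MAX_BERT_LEN = 510  # assumed: leaves room for [CLS] and [SEP] in PL-BERT's 512-token window

def clip_to_bert(texts, text_mask):
    # Truncate the padded token batch so PL-BERT never sees more than MAX_BERT_LEN positions.
    if texts.size(1) > MAX_BERT_LEN:
        texts = texts[:, :MAX_BERT_LEN]
    # Recompute per-row lengths after truncation (assumes 0 is the PAD id).
    input_lengths = (texts != 0).sum(dim=1)
    # Rebuild the padding mask: True marks padded positions, matching length_to_mask.
    positions = torch.arange(texts.size(1), device=texts.device)
    text_mask = positions.unsqueeze(0) >= input_lengths.unsqueeze(1)
    return texts, text_mask, input_lengths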
logs/pod_90h_30k/config_ft_single.yml → Configs/.ipynb_checkpoints/config_ft_single-checkpoint.yml RENAMED
@@ -1,18 +1,18 @@
1
  # ─── GLOBAL ──────────────────────────────────────────────────────────
2
- log_dir: logs/pod_90h_30k
3
  device: "cuda"
4
 
5
- batch_size: 8 # 40 GB A100, fp16
6
- max_len: 160 # ≈ 8 s (200 × 40 ms)
7
 
8
- epochs_1st: 13 # first-stage schedule
9
- epochs_2nd: 13 # second-stage schedule (later)
10
- save_freq: 2
11
  log_interval: 50
12
 
13
  # leave blank on first run
14
- pretrained_model: ""
15
- second_stage_load_pretrained: false
16
  load_only_params: false
17
 
18
  # ─── PRE-PROCESS ─────────────────────────────────────────────────────
@@ -25,11 +25,11 @@ preprocess_params:
25
 
26
  # ─── DATA ────────────────────────────────────────────────────────────
27
  data_params:
28
- root_path: /home/ubuntu/styletts2-ft/data/wavs
29
- train_data: /home/ubuntu/styletts2-ft/data/train_list.txt
30
- val_data: /home/ubuntu/styletts2-ft/data/val_list.txt
31
  min_length: 50 # sample until texts with this size are obtained for OOD texts
32
- OOD_data: /home/ubuntu/styletts2-ft/data/OOD_texts.txt
33
 
34
  # ─── LOSS SCHEDULE ──────────────────────────────────────────────────
35
  loss_params:
@@ -39,7 +39,7 @@ loss_params:
39
 
40
  lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
41
  lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
42
- TMA_epoch: 50 # TMA starting epoch (1st stage)
43
 
44
  lambda_F0: 1. # F0 reconstruction loss (2nd stage)
45
  lambda_norm: 1. # norm reconstruction loss (2nd stage)
@@ -48,14 +48,14 @@ loss_params:
48
  lambda_sty: 1. # style reconstruction loss (2nd stage)
49
  lambda_diff: 1. # score matching loss (2nd stage)
50
 
51
- diff_epoch: 20 # style diffusion starting epoch (2nd stage)
52
- joint_epoch: 50 # joint training starting epoch (2nd stage)
53
 
54
  # ─── OPTIMISER ──────────────────────────────────────────────────────
55
  optimizer_params:
56
- lr: 0.0008
57
- bert_lr: 0.00002
58
- ft_lr: 0.0002
59
  grad_accum_steps: 2
60
 
61
  # ─── MODEL (core network & sub-modules) ─────────────────────────────
@@ -105,7 +105,7 @@ F0_path: "Utils/JDC/bst.t7"
105
  ASR_config: "Utils/ASR/config.yml"
106
  ASR_path: "Utils/ASR/epoch_00080.pth"
107
  PLBERT_dir: 'Utils/PLBERT/'
108
- first_stage_path: "" # filled automatically after this run
109
 
110
  # ─── SLM ADVERSARIAL (ignored in stage-1, kept default) ─────────────
111
  slmadv_params:
 
1
  # ─── GLOBAL ──────────────────────────────────────────────────────────
2
+ log_dir: logs/pod_90h_30k_second_lr1
3
  device: "cuda"
4
 
5
+ batch_size: 12 # 40 GB A100, fp16
6
+ max_len: 300 # ≈ 8 s (200 × 40 ms)
7
 
8
+ epochs_1st: 25 # first-stage schedule
9
+ epochs_2nd: 20 # second-stage schedule (later)
10
+ save_freq: 1
11
  log_interval: 50
12
 
13
  # leave blank on first run
14
+ pretrained_model: "" #"/workspace/styletts2/logs/pod_90h_30k/epoch_2nd_00003.pth"
15
+ second_stage_load_pretrained: true
16
  load_only_params: false
17
 
18
  # ─── PRE-PROCESS ─────────────────────────────────────────────────────
 
25
 
26
  # ─── DATA ────────────────────────────────────────────────────────────
27
  data_params:
28
+ root_path: /workspace
29
+ train_data: /workspace/styletts2/data/train_list.txt
30
+ val_data: /workspace/styletts2/data/val_list.txt
31
  min_length: 50 # sample until texts with this size are obtained for OOD texts
32
+ OOD_data: /workspace/styletts2/data/OOD_texts.txt
33
 
34
  # ─── LOSS SCHEDULE ──────────────────────────────────────────────────
35
  loss_params:
 
39
 
40
  lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
41
  lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
42
+ TMA_epoch: 14 # TMA starting epoch (1st stage)
43
 
44
  lambda_F0: 1. # F0 reconstruction loss (2nd stage)
45
  lambda_norm: 1. # norm reconstruction loss (2nd stage)
 
48
  lambda_sty: 1. # style reconstruction loss (2nd stage)
49
  lambda_diff: 1. # score matching loss (2nd stage)
50
 
51
+ diff_epoch: 0 # style diffusion starting epoch (2nd stage)
52
+ joint_epoch: 0 # joint training starting epoch (2nd stage)
53
 
54
  # ─── OPTIMISER ──────────────────────────────────────────────────────
55
  optimizer_params:
56
+ lr: 0.0001
57
+ bert_lr: 0.00001
58
+ ft_lr: 0.0001
59
  grad_accum_steps: 2
60
 
61
  # ─── MODEL (core network & sub-modules) ─────────────────────────────
 
105
  ASR_config: "Utils/ASR/config.yml"
106
  ASR_path: "Utils/ASR/epoch_00080.pth"
107
  PLBERT_dir: 'Utils/PLBERT/'
108
+ first_stage_path: "/workspace/styletts2/stage1_final.pth" # first-stage checkpoint used to initialise this second-stage run
109
 
110
  # ─── SLM ADVERSARIAL (ignored in stage-1, kept default) ─────────────
111
  slmadv_params:
Configs/.ipynb_checkpoints/config_libritts-checkpoint.yml ADDED
@@ -0,0 +1,113 @@
1
+ log_dir: "Models/LibriTTS"
2
+ first_stage_path: "first_stage.pth"
3
+ save_freq: 1
4
+ log_interval: 10
5
+ device: "cuda"
6
+ epochs_1st: 50 # number of epochs for first stage training (pre-training)
7
+ epochs_2nd: 30 # number of epochs for second stage training (joint training)
8
+ batch_size: 16
9
+ max_len: 300 # maximum number of frames
10
+ pretrained_model: ""
11
+ second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
12
+ load_only_params: false # set to true if do not want to load epoch numbers and optimizer parameters
13
+
14
+ F0_path: "Utils/JDC/bst.t7"
15
+ ASR_config: "Utils/ASR/config.yml"
16
+ ASR_path: "Utils/ASR/epoch_00080.pth"
17
+ PLBERT_dir: 'Utils/PLBERT/'
18
+
19
+ data_params:
20
+ train_data: "Data/train_list.txt"
21
+ val_data: "Data/val_list.txt"
22
+ root_path: ""
23
+ OOD_data: "Data/OOD_texts.txt"
24
+ min_length: 50 # sample until texts with this size are obtained for OOD texts
25
+
26
+ preprocess_params:
27
+ sr: 24000
28
+ spect_params:
29
+ n_fft: 2048
30
+ win_length: 1200
31
+ hop_length: 300
32
+
33
+ model_params:
34
+ multispeaker: true
35
+
36
+ dim_in: 64
37
+ hidden_dim: 512
38
+ max_conv_dim: 512
39
+ n_layer: 3
40
+ n_mels: 80
41
+
42
+ n_token: 178 # number of phoneme tokens
43
+ max_dur: 50 # maximum duration of a single phoneme
44
+ style_dim: 128 # style vector size
45
+
46
+ dropout: 0.2
47
+
48
+ # config for decoder
49
+ decoder:
50
+ type: 'hifigan' # either hifigan or istftnet
51
+ resblock_kernel_sizes: [3,7,11]
52
+ upsample_rates : [10,5,3,2]
53
+ upsample_initial_channel: 512
54
+ resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
55
+ upsample_kernel_sizes: [20,10,6,4]
56
+
57
+ # speech language model config
58
+ slm:
59
+ model: 'microsoft/wavlm-base-plus'
60
+ sr: 16000 # sampling rate of SLM
61
+ hidden: 768 # hidden size of SLM
62
+ nlayers: 13 # number of layers of SLM
63
+ initial_channel: 64 # initial channels of SLM discriminator head
64
+
65
+ # style diffusion model config
66
+ diffusion:
67
+ embedding_mask_proba: 0.1
68
+ # transformer config
69
+ transformer:
70
+ num_layers: 3
71
+ num_heads: 8
72
+ head_features: 64
73
+ multiplier: 2
74
+
75
+ # diffusion distribution config
76
+ dist:
77
+ sigma_data: 0.2 # placeholder for estimate_sigma_data set to false
78
+ estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
79
+ mean: -3.0
80
+ std: 1.0
81
+
82
+ loss_params:
83
+ lambda_mel: 5. # mel reconstruction loss
84
+ lambda_gen: 1. # generator loss
85
+ lambda_slm: 1. # slm feature matching loss
86
+
87
+ lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
88
+ lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
89
+ TMA_epoch: 5 # TMA starting epoch (1st stage)
90
+
91
+ lambda_F0: 1. # F0 reconstruction loss (2nd stage)
92
+ lambda_norm: 1. # norm reconstruction loss (2nd stage)
93
+ lambda_dur: 1. # duration loss (2nd stage)
94
+ lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
95
+ lambda_sty: 1. # style reconstruction loss (2nd stage)
96
+ lambda_diff: 1. # score matching loss (2nd stage)
97
+
98
+ diff_epoch: 10 # style diffusion starting epoch (2nd stage)
99
+ joint_epoch: 15 # joint training starting epoch (2nd stage)
100
+
101
+ optimizer_params:
102
+ lr: 0.0001 # general learning rate
103
+ bert_lr: 0.00001 # learning rate for PLBERT
104
+ ft_lr: 0.00001 # learning rate for acoustic modules
105
+
106
+ slmadv_params:
107
+ min_len: 400 # minimum length of samples
108
+ max_len: 500 # maximum length of samples
109
+ batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
110
+ iter: 20 # update the discriminator every this iterations of generator update
111
+ thresh: 5 # gradient norm above which the gradient is scaled
112
+ scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
113
+ sig: 1.5 # sigma for differentiable duration modeling
Configs/config_ft_single.yml CHANGED
@@ -1,18 +1,18 @@
1
  # ─── GLOBAL ──────────────────────────────────────────────────────────
2
- log_dir: logs/pod_90h_30k
3
  device: "cuda"
4
 
5
- batch_size: 8 # 40 GB A100, fp16
6
  max_len: 300 # ≈ 8 s (200 × 40 ms)
7
 
8
  epochs_1st: 25 # first-stage schedule
9
- epochs_2nd: 15 # second-stage schedule (later)
10
- save_freq: 2
11
  log_interval: 50
12
 
13
  # leave blank on first run
14
- pretrained_model: /home/ubuntu/styletts2-ft/logs/pod_90h_30k/epoch_1st_0012.pth
15
- second_stage_load_pretrained: false
16
  load_only_params: false
17
 
18
  # ─── PRE-PROCESS ─────────────────────────────────────────────────────
@@ -25,11 +25,11 @@ preprocess_params:
25
 
26
  # ─── DATA ────────────────────────────────────────────────────────────
27
  data_params:
28
- root_path: /home/ubuntu/styletts2-ft/data/wavs
29
- train_data: /home/ubuntu/styletts2-ft/data/train_list.txt
30
- val_data: /home/ubuntu/styletts2-ft/data/val_list.txt
31
  min_length: 50 # sample until texts with this size are obtained for OOD texts
32
- OOD_data: /home/ubuntu/styletts2-ft/data/OOD_texts.txt
33
 
34
  # ─── LOSS SCHEDULE ──────────────────────────────────────────────────
35
  loss_params:
@@ -48,14 +48,14 @@ loss_params:
48
  lambda_sty: 1. # style reconstruction loss (2nd stage)
49
  lambda_diff: 1. # score matching loss (2nd stage)
50
 
51
- diff_epoch: 20 # style diffusion starting epoch (2nd stage)
52
- joint_epoch: 50 # joint training starting epoch (2nd stage)
53
 
54
  # ─── OPTIMISER ──────────────────────────────────────────────────────
55
  optimizer_params:
56
- lr: 0.0008
57
- bert_lr: 0.00002
58
- ft_lr: 0.0002
59
  grad_accum_steps: 2
60
 
61
  # ─── MODEL (core network & sub-modules) ─────────────────────────────
@@ -105,7 +105,7 @@ F0_path: "Utils/JDC/bst.t7"
105
  ASR_config: "Utils/ASR/config.yml"
106
  ASR_path: "Utils/ASR/epoch_00080.pth"
107
  PLBERT_dir: 'Utils/PLBERT/'
108
- first_stage_path: "" # filled automatically after this run
109
 
110
  # ─── SLM ADVERSARIAL (ignored in stage-1, kept default) ─────────────
111
  slmadv_params:
 
1
  # ─── GLOBAL ──────────────────────────────────────────────────────────
2
+ log_dir: logs/pod_90h_30k_second_lr1
3
  device: "cuda"
4
 
5
+ batch_size: 12 # 40 GB A100, fp16
6
  max_len: 300 # ≈ 8 s (200 × 40 ms)
7
 
8
  epochs_1st: 25 # first-stage schedule
9
+ epochs_2nd: 20 # second-stage schedule (later)
10
+ save_freq: 1
11
  log_interval: 50
12
 
13
  # leave blank on first run
14
+ pretrained_model: "" #"/workspace/styletts2/logs/pod_90h_30k/epoch_2nd_00003.pth"
15
+ second_stage_load_pretrained: true
16
  load_only_params: false
17
 
18
  # ─── PRE-PROCESS ─────────────────────────────────────────────────────
 
25
 
26
  # ─── DATA ────────────────────────────────────────────────────────────
27
  data_params:
28
+ root_path: /workspace
29
+ train_data: /workspace/styletts2/data/train_list.txt
30
+ val_data: /workspace/styletts2/data/val_list.txt
31
  min_length: 50 # sample until texts with this size are obtained for OOD texts
32
+ OOD_data: /workspace/styletts2/data/OOD_texts.txt
33
 
34
  # ─── LOSS SCHEDULE ──────────────────────────────────────────────────
35
  loss_params:
 
48
  lambda_sty: 1. # style reconstruction loss (2nd stage)
49
  lambda_diff: 1. # score matching loss (2nd stage)
50
 
51
+ diff_epoch: 0 # style diffusion starting epoch (2nd stage)
52
+ joint_epoch: 0 # joint training starting epoch (2nd stage)
53
 
54
  # ─── OPTIMISER ──────────────────────────────────────────────────────
55
  optimizer_params:
56
+ lr: 0.0001
57
+ bert_lr: 0.00001
58
+ ft_lr: 0.00001
59
  grad_accum_steps: 2
60
 
61
  # ─── MODEL (core network & sub-modules) ─────────────────────────────
 
105
  ASR_config: "Utils/ASR/config.yml"
106
  ASR_path: "Utils/ASR/epoch_00080.pth"
107
  PLBERT_dir: 'Utils/PLBERT/'
108
+ first_stage_path: "/workspace/styletts2/stage1_final.pth" # first-stage checkpoint used to initialise this second-stage run
109
 
110
  # ─── SLM ADVERSARIAL (ignored in stage-1, kept default) ─────────────
111
  slmadv_params:
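With the second-stage settings above, the click entry point from the training script earlier in this commit can be pointed at this config. A minimal sketch, assuming the script is importable as `train_second` (module name inferred from the checkpoint file in this commit):

# Programmatic equivalent of: python train_second.py -p Configs/config_ft_single.yml
from train_second import main  # assumed module name

main(['--config_path', 'Configs/config_ft_single.yml'], standalone_mode=False)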
Demo/.ipynb_checkpoints/Inference_LibriTTS-checkpoint.ipynb ADDED
@@ -0,0 +1,1155 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "9adb7bd1",
6
+ "metadata": {},
7
+ "source": [
8
+ "# StyleTTS 2 Demo (LibriTTS)\n",
9
+ "\n",
10
+ "Before you run the following cells, please make sure you have downloaded [reference_audio.zip](https://huggingface.co/yl4579/StyleTTS2-LibriTTS/resolve/main/reference_audio.zip) and unzipped it under the `demo` folder."
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "id": "6108384d",
16
+ "metadata": {},
17
+ "source": [
18
+ "### Utils"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "id": "96e173bf",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "import torch\n",
29
+ "torch.manual_seed(0)\n",
30
+ "torch.backends.cudnn.benchmark = False\n",
31
+ "torch.backends.cudnn.deterministic = True\n",
32
+ "\n",
33
+ "import random\n",
34
+ "random.seed(0)\n",
35
+ "\n",
36
+ "import numpy as np\n",
37
+ "np.random.seed(0)"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "id": "da84c60f",
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "%cd .."
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "execution_count": null,
53
+ "id": "5a3ddcc8",
54
+ "metadata": {},
55
+ "outputs": [],
56
+ "source": [
57
+ "# load packages\n",
58
+ "import time\n",
59
+ "import random\n",
60
+ "import yaml\n",
61
+ "from munch import Munch\n",
62
+ "import numpy as np\n",
63
+ "import torch\n",
64
+ "from torch import nn\n",
65
+ "import torch.nn.functional as F\n",
66
+ "import torchaudio\n",
67
+ "import librosa\n",
68
+ "from nltk.tokenize import word_tokenize\n",
69
+ "\n",
70
+ "from models import *\n",
71
+ "from utils import *\n",
72
+ "from text_utils import TextCleaner\n",
73
+ "textclenaer = TextCleaner()\n",
74
+ "\n",
75
+ "%matplotlib inline"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": null,
81
+ "id": "00ee05e1",
82
+ "metadata": {},
83
+ "outputs": [],
84
+ "source": [
85
+ "to_mel = torchaudio.transforms.MelSpectrogram(\n",
86
+ " n_mels=80, n_fft=2048, win_length=1200, hop_length=300)\n",
87
+ "mean, std = -4, 4\n",
88
+ "\n",
89
+ "def length_to_mask(lengths):\n",
90
+ " mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)\n",
91
+ " mask = torch.gt(mask+1, lengths.unsqueeze(1))\n",
92
+ " return mask\n",
93
+ "\n",
94
+ "def preprocess(wave):\n",
95
+ " wave_tensor = torch.from_numpy(wave).float()\n",
96
+ " mel_tensor = to_mel(wave_tensor)\n",
97
+ " mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std\n",
98
+ " return mel_tensor\n",
99
+ "\n",
100
+ "def compute_style(path):\n",
101
+ " wave, sr = librosa.load(path, sr=24000)\n",
102
+ " audio, index = librosa.effects.trim(wave, top_db=30)\n",
103
+ " if sr != 24000:\n",
104
+ " audio = librosa.resample(audio, sr, 24000)\n",
105
+ " mel_tensor = preprocess(audio).to(device)\n",
106
+ "\n",
107
+ " with torch.no_grad():\n",
108
+ " ref_s = model.style_encoder(mel_tensor.unsqueeze(1))\n",
109
+ " ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))\n",
110
+ "\n",
111
+ " return torch.cat([ref_s, ref_p], dim=1)"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": null,
117
+ "id": "bbdc04c0",
118
+ "metadata": {},
119
+ "outputs": [],
120
+ "source": [
121
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "markdown",
126
+ "id": "7b9cecbe",
127
+ "metadata": {},
128
+ "source": [
129
+ "### Load models"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type": "code",
134
+ "execution_count": null,
135
+ "id": "64fc4c0f",
136
+ "metadata": {},
137
+ "outputs": [],
138
+ "source": [
139
+ "# load phonemizer\n",
140
+ "import phonemizer\n",
141
+ "global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)"
142
+ ]
143
+ },
144
+ {
145
+ "cell_type": "code",
146
+ "execution_count": null,
147
+ "id": "48e7b644",
148
+ "metadata": {},
149
+ "outputs": [],
150
+ "source": [
151
+ "config = yaml.safe_load(open(\"Models/LibriTTS/config.yml\"))\n",
152
+ "\n",
153
+ "# load pretrained ASR model\n",
154
+ "ASR_config = config.get('ASR_config', False)\n",
155
+ "ASR_path = config.get('ASR_path', False)\n",
156
+ "text_aligner = load_ASR_models(ASR_path, ASR_config)\n",
157
+ "\n",
158
+ "# load pretrained F0 model\n",
159
+ "F0_path = config.get('F0_path', False)\n",
160
+ "pitch_extractor = load_F0_models(F0_path)\n",
161
+ "\n",
162
+ "# load BERT model\n",
163
+ "from Utils.PLBERT.util import load_plbert\n",
164
+ "BERT_path = config.get('PLBERT_dir', False)\n",
165
+ "plbert = load_plbert(BERT_path)"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "code",
170
+ "execution_count": null,
171
+ "id": "ffc18cf7",
172
+ "metadata": {},
173
+ "outputs": [],
174
+ "source": [
175
+ "model_params = recursive_munch(config['model_params'])\n",
176
+ "model = build_model(model_params, text_aligner, pitch_extractor, plbert)\n",
177
+ "_ = [model[key].eval() for key in model]\n",
178
+ "_ = [model[key].to(device) for key in model]"
179
+ ]
180
+ },
181
+ {
182
+ "cell_type": "code",
183
+ "execution_count": null,
184
+ "id": "64529d5c",
185
+ "metadata": {},
186
+ "outputs": [],
187
+ "source": [
188
+ "params_whole = torch.load(\"Models/LibriTTS/epochs_2nd_00020.pth\", map_location='cpu')\n",
189
+ "params = params_whole['net']"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": null,
195
+ "id": "895d9706",
196
+ "metadata": {},
197
+ "outputs": [],
198
+ "source": [
199
+ "for key in model:\n",
200
+ " if key in params:\n",
201
+ " print('%s loaded' % key)\n",
202
+ " try:\n",
203
+ " model[key].load_state_dict(params[key])\n",
204
+ " except:\n",
205
+ " from collections import OrderedDict\n",
206
+ " state_dict = params[key]\n",
207
+ " new_state_dict = OrderedDict()\n",
208
+ " for k, v in state_dict.items():\n",
209
+ " name = k[7:] # remove `module.`\n",
210
+ " new_state_dict[name] = v\n",
211
+ " # load params\n",
212
+ " model[key].load_state_dict(new_state_dict, strict=False)\n",
213
+ "# except:\n",
214
+ "# _load(params[key], model[key])\n",
215
+ "_ = [model[key].eval() for key in model]"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "code",
220
+ "execution_count": null,
221
+ "id": "c1a59db2",
222
+ "metadata": {},
223
+ "outputs": [],
224
+ "source": [
225
+ "from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "code",
230
+ "execution_count": null,
231
+ "id": "e30985ab",
232
+ "metadata": {},
233
+ "outputs": [],
234
+ "source": [
235
+ "sampler = DiffusionSampler(\n",
236
+ " model.diffusion.diffusion,\n",
237
+ " sampler=ADPM2Sampler(),\n",
238
+ " sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters\n",
239
+ " clamp=False\n",
240
+ ")"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "markdown",
245
+ "id": "b803110e",
246
+ "metadata": {},
247
+ "source": [
248
+ "### Synthesize speech"
249
+ ]
250
+ },
251
+ {
252
+ "cell_type": "code",
253
+ "execution_count": null,
254
+ "id": "ca57469c",
255
+ "metadata": {},
256
+ "outputs": [],
257
+ "source": [
258
+ "def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):\n",
259
+ " text = text.strip()\n",
260
+ " ps = global_phonemizer.phonemize([text])\n",
261
+ " ps = word_tokenize(ps[0])\n",
262
+ " ps = ' '.join(ps)\n",
263
+ " tokens = textclenaer(ps)\n",
264
+ " tokens.insert(0, 0)\n",
265
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
266
+ " \n",
267
+ " with torch.no_grad():\n",
268
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
269
+ " text_mask = length_to_mask(input_lengths).to(device)\n",
270
+ "\n",
271
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
272
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
273
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2) \n",
274
+ "\n",
275
+ " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device), \n",
276
+ " embedding=bert_dur,\n",
277
+ " embedding_scale=embedding_scale,\n",
278
+ " features=ref_s, # reference from the same speaker as the embedding\n",
279
+ " num_steps=diffusion_steps).squeeze(1)\n",
280
+ "\n",
281
+ "\n",
282
+ " s = s_pred[:, 128:]\n",
283
+ " ref = s_pred[:, :128]\n",
284
+ "\n",
285
+ " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
286
+ " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
287
+ "\n",
288
+ " d = model.predictor.text_encoder(d_en, \n",
289
+ " s, input_lengths, text_mask)\n",
290
+ "\n",
291
+ " x, _ = model.predictor.lstm(d)\n",
292
+ " duration = model.predictor.duration_proj(x)\n",
293
+ "\n",
294
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
295
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
296
+ "\n",
297
+ "\n",
298
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
299
+ " c_frame = 0\n",
300
+ " for i in range(pred_aln_trg.size(0)):\n",
301
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
302
+ " c_frame += int(pred_dur[i].data)\n",
303
+ "\n",
304
+ " # encode prosody\n",
305
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
306
+ " if model_params.decoder.type == \"hifigan\":\n",
307
+ " asr_new = torch.zeros_like(en)\n",
308
+ " asr_new[:, :, 0] = en[:, :, 0]\n",
309
+ " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
310
+ " en = asr_new\n",
311
+ "\n",
312
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
313
+ "\n",
314
+ " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
315
+ " if model_params.decoder.type == \"hifigan\":\n",
316
+ " asr_new = torch.zeros_like(asr)\n",
317
+ " asr_new[:, :, 0] = asr[:, :, 0]\n",
318
+ " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
319
+ " asr = asr_new\n",
320
+ "\n",
321
+ " out = model.decoder(asr, \n",
322
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
323
+ " \n",
324
+ " \n",
325
+ " return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "markdown",
330
+ "id": "d438ef4f",
331
+ "metadata": {},
332
+ "source": [
333
+ "#### Basic synthesis (5 diffusion steps, seen speakers)"
334
+ ]
335
+ },
336
+ {
337
+ "cell_type": "code",
338
+ "execution_count": null,
339
+ "id": "cace9787",
340
+ "metadata": {},
341
+ "outputs": [],
342
+ "source": [
343
+ "text = ''' StyleTTS 2 is a text to speech model that leverages style diffusion and adversarial training with large speech language models to achieve human level text to speech synthesis. '''"
344
+ ]
345
+ },
346
+ {
347
+ "cell_type": "code",
348
+ "execution_count": null,
349
+ "id": "7c88f461",
350
+ "metadata": {},
351
+ "outputs": [],
352
+ "source": [
353
+ "reference_dicts = {}\n",
354
+ "reference_dicts['696_92939'] = \"Demo/reference_audio/696_92939_000016_000006.wav\"\n",
355
+ "reference_dicts['1789_142896'] = \"Demo/reference_audio/1789_142896_000022_000005.wav\""
356
+ ]
357
+ },
358
+ {
359
+ "cell_type": "code",
360
+ "execution_count": null,
361
+ "id": "16e8ac60",
362
+ "metadata": {},
363
+ "outputs": [],
364
+ "source": [
365
+ "start = time.time()\n",
366
+ "noise = torch.randn(1,1,256).to(device)\n",
367
+ "for k, path in reference_dicts.items():\n",
368
+ " ref_s = compute_style(path)\n",
369
+ " \n",
370
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=5, embedding_scale=1)\n",
371
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
372
+ " print(f\"RTF = {rtf:5f}\")\n",
373
+ " import IPython.display as ipd\n",
374
+ " print(k + ' Synthesized:')\n",
375
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
376
+ " print('Reference:')\n",
377
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
378
+ ]
379
+ },
380
+ {
381
+ "cell_type": "markdown",
382
+ "id": "14838708",
383
+ "metadata": {},
384
+ "source": [
385
+ "#### With higher diffusion steps (more diverse)\n",
386
+ "\n",
387
+ "Since the sampler is ancestral, the higher the stpes, the more diverse the samples are, with the cost of slower synthesis speed."
388
+ ]
389
+ },
390
+ {
391
+ "cell_type": "code",
392
+ "execution_count": null,
393
+ "id": "6fbff03b",
394
+ "metadata": {},
395
+ "outputs": [],
396
+ "source": [
397
+ "noise = torch.randn(1,1,256).to(device)\n",
398
+ "for k, path in reference_dicts.items():\n",
399
+ " ref_s = compute_style(path)\n",
400
+ " start = time.time()\n",
401
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=10, embedding_scale=1)\n",
402
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
403
+ " print(f\"RTF = {rtf:5f}\")\n",
404
+ " import IPython.display as ipd\n",
405
+ " print(k + ' Synthesized:')\n",
406
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
407
+ " print(k + ' Reference:')\n",
408
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
409
+ ]
410
+ },
411
+ {
412
+ "cell_type": "markdown",
413
+ "id": "7e6867fd",
414
+ "metadata": {},
415
+ "source": [
416
+ "#### Basic synthesis (5 diffusion steps, umseen speakers)\n",
417
+ "The following samples are to reproduce samples in [Section 4](https://styletts2.github.io/#libri) of the demo page. All spsakers are unseen during training. You can compare the generated samples to popular zero-shot TTS models like Vall-E and NaturalSpeech 2."
418
+ ]
419
+ },
420
+ {
421
+ "cell_type": "code",
422
+ "execution_count": null,
423
+ "id": "f4e8faa0",
424
+ "metadata": {},
425
+ "outputs": [],
426
+ "source": [
427
+ "reference_dicts = {}\n",
428
+ "# format: (path, text)\n",
429
+ "reference_dicts['1221-135767'] = (\"Demo/reference_audio/1221-135767-0014.wav\", \"Yea, his honourable worship is within, but he hath a godly minister or two with him, and likewise a leech.\")\n",
430
+ "reference_dicts['5639-40744'] = (\"Demo/reference_audio/5639-40744-0020.wav\", \"Thus did this humane and right minded father comfort his unhappy daughter, and her mother embracing her again, did all she could to soothe her feelings.\")\n",
431
+ "reference_dicts['908-157963'] = (\"Demo/reference_audio/908-157963-0027.wav\", \"And lay me down in my cold bed and leave my shining lot.\")\n",
432
+ "reference_dicts['4077-13754'] = (\"Demo/reference_audio/4077-13754-0000.wav\", \"The army found the people in poverty and left them in comparative wealth.\")"
433
+ ]
434
+ },
435
+ {
436
+ "cell_type": "code",
437
+ "execution_count": null,
438
+ "id": "653f1406",
439
+ "metadata": {},
440
+ "outputs": [],
441
+ "source": [
442
+ "noise = torch.randn(1,1,256).to(device)\n",
443
+ "for k, v in reference_dicts.items():\n",
444
+ " path, text = v\n",
445
+ " ref_s = compute_style(path)\n",
446
+ " start = time.time()\n",
447
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=5, embedding_scale=1)\n",
448
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
449
+ " print(f\"RTF = {rtf:5f}\")\n",
450
+ " import IPython.display as ipd\n",
451
+ " print(k + ' Synthesized: ' + text)\n",
452
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
453
+ " print(k + ' Reference:')\n",
454
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
455
+ ]
456
+ },
457
+ {
458
+ "cell_type": "markdown",
459
+ "id": "141e91b3",
460
+ "metadata": {},
461
+ "source": [
462
+ "### Speech expressiveness\n",
463
+ "\n",
464
+ "The following section recreates the samples shown in [Section 6](https://styletts2.github.io/#emo) of the demo page. The speaker reference used is `1221-135767-0014.wav`, which is unseen during training. \n",
465
+ "\n",
466
+ "#### With `embedding_scale=1`\n",
467
+ "This is the classifier-free guidance scale. The higher the scale, the more the sampled style is conditioned on the input text, and hence the more emotional the speech.\n",
468
+ "\n"
469
+ ]
470
+ },
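As a rough illustration of what `embedding_scale` does: it is the classifier-free guidance weight applied inside the style diffusion sampler. The sketch below is conceptual only; the actual combination happens inside `Modules/diffusion`, and the function name here is made up for illustration.

```python
def cfg_combine(pred_uncond, pred_cond, embedding_scale):
    # Classifier-free guidance: extrapolate from the unconditional style
    # prediction toward the text-conditional one.
    # embedding_scale = 1 keeps the plain conditional prediction;
    # larger values push the style further toward the text condition,
    # which tends to sound more expressive but less like the reference.
    return pred_uncond + embedding_scale * (pred_cond - pred_uncond)
```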
471
+ {
472
+ "cell_type": "code",
473
+ "execution_count": null,
474
+ "id": "81addda4",
475
+ "metadata": {},
476
+ "outputs": [],
477
+ "source": [
478
+ "ref_s = compute_style(\"Demo/reference_audio/1221-135767-0014.wav\")"
479
+ ]
480
+ },
481
+ {
482
+ "cell_type": "code",
483
+ "execution_count": null,
484
+ "id": "be1b2a11",
485
+ "metadata": {},
486
+ "outputs": [],
487
+ "source": [
488
+ "texts = {}\n",
489
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
490
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
491
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
492
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
493
+ "\n",
494
+ "for k,v in texts.items():\n",
495
+ " wav = inference(v, ref_s, diffusion_steps=10, alpha=0.3, beta=0.7, embedding_scale=1)\n",
496
+ " print(k + \": \")\n",
497
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
498
+ ]
499
+ },
500
+ {
501
+ "cell_type": "markdown",
502
+ "id": "96d262b8",
503
+ "metadata": {},
504
+ "source": [
505
+ "#### With `embedding_scale=2`"
506
+ ]
507
+ },
508
+ {
509
+ "cell_type": "code",
510
+ "execution_count": null,
511
+ "id": "3e7d40b4",
512
+ "metadata": {},
513
+ "outputs": [],
514
+ "source": [
515
+ "texts = {}\n",
516
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
517
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
518
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
519
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
520
+ "\n",
521
+ "for k,v in texts.items():\n",
522
+ " noise = torch.randn(1,1,256).to(device)\n",
523
+ " wav = inference(v, ref_s, diffusion_steps=10, alpha=0.3, beta=0.7, embedding_scale=2)\n",
524
+ " print(k + \": \")\n",
525
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
526
+ ]
527
+ },
528
+ {
529
+ "cell_type": "markdown",
530
+ "id": "402b2bd6",
531
+ "metadata": {},
532
+ "source": [
533
+ "#### With `embedding_scale=2, alpha = 0.5, beta = 0.9`\n",
534
+ "`alpha` and `beta` determine how much of the style is sampled from the text rather than taken from the reference. The higher the values of `alpha` and `beta`, the better the style suits the text, but the less similar it is to the reference. Using a higher `beta` makes the synthesized speech more emotional, at the cost of lower similarity to the reference. `alpha` controls the timbre of the speaker, while `beta` controls the prosody. "
535
+ ]
536
+ },
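Concretely, the mixing is just a convex combination of the diffusion-sampled style and the reference style, mirroring the lines already inside the `inference` function (the first 128 dimensions hold the acoustic/timbre style used by the decoder, the last 128 the prosodic style used by the predictor). A minimal sketch, with `mix_styles` being an illustrative wrapper name rather than an actual function in the repo:

```python
import torch

def mix_styles(s_pred: torch.Tensor, ref_s: torch.Tensor, alpha: float, beta: float) -> torch.Tensor:
    # s_pred: style sampled by the diffusion model from the text, shape (1, 256)
    # ref_s:  style extracted from the reference audio,            shape (1, 256)
    ref = alpha * s_pred[:, :128] + (1 - alpha) * ref_s[:, :128]  # timbre (decoder)
    s   = beta  * s_pred[:, 128:] + (1 - beta)  * ref_s[:, 128:]  # prosody (predictor)
    # alpha = beta = 0 reproduces the reference style exactly;
    # alpha = beta = 1 discards the reference and uses only the sampled style.
    return torch.cat([ref, s], dim=-1)
```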
537
+ {
538
+ "cell_type": "code",
539
+ "execution_count": null,
540
+ "id": "599de5d5",
541
+ "metadata": {},
542
+ "outputs": [],
543
+ "source": [
544
+ "texts = {}\n",
545
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
546
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
547
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
548
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
549
+ "\n",
550
+ "for k,v in texts.items():\n",
551
+ " noise = torch.randn(1,1,256).to(device)\n",
552
+ " wav = inference(v, ref_s, diffusion_steps=10, alpha=0.5, beta=0.9, embedding_scale=2)\n",
553
+ " print(k + \": \")\n",
554
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
555
+ ]
556
+ },
557
+ {
558
+ "cell_type": "markdown",
559
+ "id": "48548866",
560
+ "metadata": {},
561
+ "source": [
562
+ "### Zero-shot speaker adaptation\n",
563
+ "This section recreates the \"Acoustic Environment Maintenance\" and \"Speaker’s Emotion Maintenance\" demos in [Section 4](https://styletts2.github.io/#libri) of the demo page. You can compare the generated samples to popular zero-shot TTS models like Vall-E. Note that the model was trained only on LibriTTS, roughly 250 times less data than was used to train Vall-E, yet achieves similar or better results on these maintenance tasks. "
564
+ ]
565
+ },
566
+ {
567
+ "cell_type": "markdown",
568
+ "id": "23e81572",
569
+ "metadata": {},
570
+ "source": [
571
+ "#### Acoustic Environment Maintenance\n",
572
+ "\n",
573
+ "Since we want to maintain the acoustic environment in the speaker (timbre), we set `alpha = 0` to keep the speaker as close to the reference as possible while only changing the prosody according to the text. "
574
+ ]
575
+ },
576
+ {
577
+ "cell_type": "code",
578
+ "execution_count": null,
579
+ "id": "8087bccb",
580
+ "metadata": {},
581
+ "outputs": [],
582
+ "source": [
583
+ "reference_dicts = {}\n",
584
+ "# format: (path, text)\n",
585
+ "reference_dicts['3'] = (\"Demo/reference_audio/3.wav\", \"As friends thing I definitely I've got more male friends.\")\n",
586
+ "reference_dicts['4'] = (\"Demo/reference_audio/4.wav\", \"Everything is run by computer but you got to know how to think before you can do a computer.\")\n",
587
+ "reference_dicts['5'] = (\"Demo/reference_audio/5.wav\", \"Then out in LA you guys got a whole another ball game within California to worry about.\")"
588
+ ]
589
+ },
590
+ {
591
+ "cell_type": "code",
592
+ "execution_count": null,
593
+ "id": "1e99c200",
594
+ "metadata": {},
595
+ "outputs": [],
596
+ "source": [
597
+ "noise = torch.randn(1,1,256).to(device)\n",
598
+ "for k, v in reference_dicts.items():\n",
599
+ " path, text = v\n",
600
+ " ref_s = compute_style(path)\n",
601
+ " start = time.time()\n",
602
+ " wav = inference(text, ref_s, alpha=0.0, beta=0.5, diffusion_steps=5, embedding_scale=1)\n",
603
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
604
+ " print(f\"RTF = {rtf:5f}\")\n",
605
+ " import IPython.display as ipd\n",
606
+ " print('Synthesized: ' + text)\n",
607
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
608
+ " print('Reference:')\n",
609
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
610
+ ]
611
+ },
612
+ {
613
+ "cell_type": "markdown",
614
+ "id": "7d56505d",
615
+ "metadata": {},
616
+ "source": [
617
+ "#### Speaker’s Emotion Maintenance\n",
618
+ "\n",
619
+ "Since we want to maintain the emotion in the speaker (prosody), we set `beta = 0.1` to keep the speaker as close to the reference as possible while allowing some diversity through a slight timbre change."
620
+ ]
621
+ },
622
+ {
623
+ "cell_type": "code",
624
+ "execution_count": null,
625
+ "id": "f90179e7",
626
+ "metadata": {},
627
+ "outputs": [],
628
+ "source": [
629
+ "reference_dicts = {}\n",
630
+ "# format: (path, text)\n",
631
+ "reference_dicts['Anger'] = (\"Demo/reference_audio/anger.wav\", \"We have to reduce the number of plastic bags.\")\n",
632
+ "reference_dicts['Sleepy'] = (\"Demo/reference_audio/sleepy.wav\", \"We have to reduce the number of plastic bags.\")\n",
633
+ "reference_dicts['Amused'] = (\"Demo/reference_audio/amused.wav\", \"We have to reduce the number of plastic bags.\")\n",
634
+ "reference_dicts['Disgusted'] = (\"Demo/reference_audio/disgusted.wav\", \"We have to reduce the number of plastic bags.\")"
635
+ ]
636
+ },
637
+ {
638
+ "cell_type": "code",
639
+ "execution_count": null,
640
+ "id": "2e6bdfed",
641
+ "metadata": {},
642
+ "outputs": [],
643
+ "source": [
644
+ "noise = torch.randn(1,1,256).to(device)\n",
645
+ "for k, v in reference_dicts.items():\n",
646
+ " path, text = v\n",
647
+ " ref_s = compute_style(path)\n",
648
+ " start = time.time()\n",
649
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.1, diffusion_steps=10, embedding_scale=1)\n",
650
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
651
+ " print(f\"RTF = {rtf:5f}\")\n",
652
+ " import IPython.display as ipd\n",
653
+ " print(k + ' Synthesized: ' + text)\n",
654
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
655
+ " print(k + ' Reference:')\n",
656
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
657
+ ]
658
+ },
659
+ {
660
+ "cell_type": "markdown",
661
+ "id": "37ae3963",
662
+ "metadata": {},
663
+ "source": [
664
+ "### Longform Narration\n",
665
+ "\n",
666
+ "This section includes a basic implementation of Algorithm 1 in the paper for consistent longform audio generation. The example passage is taken from [Section 5](https://styletts2.github.io/#long) of the demo page."
667
+ ]
668
+ },
669
+ {
670
+ "cell_type": "code",
671
+ "execution_count": null,
672
+ "id": "f12a716b",
673
+ "metadata": {},
674
+ "outputs": [],
675
+ "source": [
676
+ "passage = '''If the supply of fruit is greater than the family needs, it may be made a source of income by sending the fresh fruit to the market if there is one near enough, or by preserving, canning, and making jelly for sale. To make such an enterprise a success the fruit and work must be first class. There is magic in the word \"Homemade,\" when the product appeals to the eye and the palate; but many careless and incompetent people have found to their sorrow that this word has not magic enough to float inferior goods on the market. As a rule large canning and preserving establishments are clean and have the best appliances, and they employ chemists and skilled labor. The home product must be very good to compete with the attractive goods that are sent out from such establishments. Yet for first class home made products there is a market in all large cities. All first-class grocers have customers who purchase such goods.'''"
677
+ ]
678
+ },
679
+ {
680
+ "cell_type": "code",
681
+ "execution_count": null,
682
+ "id": "a1a38079",
683
+ "metadata": {},
684
+ "outputs": [],
685
+ "source": [
686
+ "def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1):\n",
687
+ " text = text.strip()\n",
688
+ " ps = global_phonemizer.phonemize([text])\n",
689
+ " ps = word_tokenize(ps[0])\n",
690
+ " ps = ' '.join(ps)\n",
691
+ " ps = ps.replace('``', '\"')\n",
692
+ " ps = ps.replace(\"''\", '\"')\n",
693
+ "\n",
694
+ " tokens = textclenaer(ps)\n",
695
+ " tokens.insert(0, 0)\n",
696
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
697
+ " \n",
698
+ " with torch.no_grad():\n",
699
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
700
+ " text_mask = length_to_mask(input_lengths).to(device)\n",
701
+ "\n",
702
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
703
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
704
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2) \n",
705
+ "\n",
706
+ " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device), \n",
707
+ " embedding=bert_dur,\n",
708
+ " embedding_scale=embedding_scale,\n",
709
+ " features=ref_s, # reference from the same speaker as the embedding\n",
710
+ " num_steps=diffusion_steps).squeeze(1)\n",
711
+ " \n",
712
+ " if s_prev is not None:\n",
713
+ " # convex combination of previous and current style\n",
714
+ " s_pred = t * s_prev + (1 - t) * s_pred\n",
715
+ " \n",
716
+ " s = s_pred[:, 128:]\n",
717
+ " ref = s_pred[:, :128]\n",
718
+ " \n",
719
+ " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
720
+ " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
721
+ "\n",
722
+ " s_pred = torch.cat([ref, s], dim=-1)\n",
723
+ "\n",
724
+ " d = model.predictor.text_encoder(d_en, \n",
725
+ " s, input_lengths, text_mask)\n",
726
+ "\n",
727
+ " x, _ = model.predictor.lstm(d)\n",
728
+ " duration = model.predictor.duration_proj(x)\n",
729
+ "\n",
730
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
731
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
732
+ "\n",
733
+ "\n",
734
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
735
+ " c_frame = 0\n",
736
+ " for i in range(pred_aln_trg.size(0)):\n",
737
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
738
+ " c_frame += int(pred_dur[i].data)\n",
739
+ "\n",
740
+ " # encode prosody\n",
741
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
742
+ " if model_params.decoder.type == \"hifigan\":\n",
743
+ " asr_new = torch.zeros_like(en)\n",
744
+ " asr_new[:, :, 0] = en[:, :, 0]\n",
745
+ " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
746
+ " en = asr_new\n",
747
+ "\n",
748
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
749
+ "\n",
750
+ " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
751
+ " if model_params.decoder.type == \"hifigan\":\n",
752
+ " asr_new = torch.zeros_like(asr)\n",
753
+ " asr_new[:, :, 0] = asr[:, :, 0]\n",
754
+ " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
755
+ " asr = asr_new\n",
756
+ "\n",
757
+ " out = model.decoder(asr, \n",
758
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
759
+ " \n",
760
+ " \n",
761
+ " return out.squeeze().cpu().numpy()[..., :-100], s_pred # weird pulse at the end of the model, need to be fixed later"
762
+ ]
763
+ },
764
+ {
765
+ "cell_type": "code",
766
+ "execution_count": null,
767
+ "id": "e9088f7a",
768
+ "metadata": {},
769
+ "outputs": [],
770
+ "source": [
771
+ "# unseen speaker\n",
772
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
773
+ "s_ref = compute_style(path)\n",
774
+ "sentences = passage.split('.') # simple split by period\n",
775
+ "wavs = []\n",
776
+ "s_prev = None\n",
777
+ "for text in sentences:\n",
778
+ " if text.strip() == \"\": continue\n",
779
+ " text += '.' # add it back\n",
780
+ " \n",
781
+ " wav, s_prev = LFinference(text, \n",
782
+ " s_prev, \n",
783
+ " s_ref, \n",
784
+ " alpha = 0.3, \n",
785
+ " beta = 0.9, # make it more suitable for the text\n",
786
+ " t = 0.7, \n",
787
+ " diffusion_steps=10, embedding_scale=1.5)\n",
788
+ " wavs.append(wav)\n",
789
+ "print('Synthesized: ')\n",
790
+ "display(ipd.Audio(np.concatenate(wavs), rate=24000, normalize=False))\n",
791
+ "print('Reference: ')\n",
792
+ "display(ipd.Audio(path, rate=24000, normalize=False))"
793
+ ]
794
+ },
795
+ {
796
+ "cell_type": "markdown",
797
+ "id": "7517b657",
798
+ "metadata": {},
799
+ "source": [
800
+ "### Style Transfer\n",
801
+ "\n",
802
+ "The following section demonstrates the style transfer capability for unseen speakers in [Section 6](https://styletts2.github.io/#emo) of the demo page. For this, we set `alpha=0.5, beta = 0.9` for the most pronounced effects (mostly using the sampled style). "
803
+ ]
804
+ },
805
+ {
806
+ "cell_type": "code",
807
+ "execution_count": null,
808
+ "id": "ed95d0f7",
809
+ "metadata": {},
810
+ "outputs": [],
811
+ "source": [
812
+ "def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):\n",
813
+ " text = text.strip()\n",
814
+ " ps = global_phonemizer.phonemize([text])\n",
815
+ " ps = word_tokenize(ps[0])\n",
816
+ " ps = ' '.join(ps)\n",
817
+ "\n",
818
+ " tokens = textclenaer(ps)\n",
819
+ " tokens.insert(0, 0)\n",
820
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
821
+ " \n",
822
+ " ref_text = ref_text.strip()\n",
823
+ " ps = global_phonemizer.phonemize([ref_text])\n",
824
+ " ps = word_tokenize(ps[0])\n",
825
+ " ps = ' '.join(ps)\n",
826
+ "\n",
827
+ " ref_tokens = textclenaer(ps)\n",
828
+ " ref_tokens.insert(0, 0)\n",
829
+ " ref_tokens = torch.LongTensor(ref_tokens).to(device).unsqueeze(0)\n",
830
+ " \n",
831
+ " \n",
832
+ " with torch.no_grad():\n",
833
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
834
+ " text_mask = length_to_mask(input_lengths).to(device)\n",
835
+ "\n",
836
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
837
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
838
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2) \n",
839
+ " \n",
840
+ " ref_input_lengths = torch.LongTensor([ref_tokens.shape[-1]]).to(device)\n",
841
+ " ref_text_mask = length_to_mask(ref_input_lengths).to(device)\n",
842
+ " ref_bert_dur = model.bert(ref_tokens, attention_mask=(~ref_text_mask).int())\n",
843
+ " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device), \n",
844
+ " embedding=bert_dur,\n",
845
+ " embedding_scale=embedding_scale,\n",
846
+ " features=ref_s, # reference from the same speaker as the embedding\n",
847
+ " num_steps=diffusion_steps).squeeze(1)\n",
848
+ "\n",
849
+ "\n",
850
+ " s = s_pred[:, 128:]\n",
851
+ " ref = s_pred[:, :128]\n",
852
+ "\n",
853
+ " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
854
+ " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
855
+ "\n",
856
+ " d = model.predictor.text_encoder(d_en, \n",
857
+ " s, input_lengths, text_mask)\n",
858
+ "\n",
859
+ " x, _ = model.predictor.lstm(d)\n",
860
+ " duration = model.predictor.duration_proj(x)\n",
861
+ "\n",
862
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
863
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
864
+ "\n",
865
+ "\n",
866
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
867
+ " c_frame = 0\n",
868
+ " for i in range(pred_aln_trg.size(0)):\n",
869
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
870
+ " c_frame += int(pred_dur[i].data)\n",
871
+ "\n",
872
+ " # encode prosody\n",
873
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
874
+ " if model_params.decoder.type == \"hifigan\":\n",
875
+ " asr_new = torch.zeros_like(en)\n",
876
+ " asr_new[:, :, 0] = en[:, :, 0]\n",
877
+ " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
878
+ " en = asr_new\n",
879
+ "\n",
880
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
881
+ "\n",
882
+ " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
883
+ " if model_params.decoder.type == \"hifigan\":\n",
884
+ " asr_new = torch.zeros_like(asr)\n",
885
+ " asr_new[:, :, 0] = asr[:, :, 0]\n",
886
+ " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
887
+ " asr = asr_new\n",
888
+ "\n",
889
+ " out = model.decoder(asr, \n",
890
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
891
+ " \n",
892
+ " \n",
893
+ " return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later"
894
+ ]
895
+ },
896
+ {
897
+ "cell_type": "code",
898
+ "execution_count": null,
899
+ "id": "ec3f0da4",
900
+ "metadata": {},
901
+ "outputs": [],
902
+ "source": [
903
+ "# reference texts to sample styles\n",
904
+ "\n",
905
+ "ref_texts = {}\n",
906
+ "ref_texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
907
+ "ref_texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
908
+ "ref_texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
909
+ "ref_texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\""
910
+ ]
911
+ },
912
+ {
913
+ "cell_type": "code",
914
+ "execution_count": null,
915
+ "id": "6d0a3825",
916
+ "metadata": {
917
+ "scrolled": false
918
+ },
919
+ "outputs": [],
920
+ "source": [
921
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
922
+ "s_ref = compute_style(path)\n",
923
+ "\n",
924
+ "text = \"Yea, his honourable worship is within, but he hath a godly minister or two with him, and likewise a leech.\"\n",
925
+ "for k,v in ref_texts.items():\n",
926
+ " wav = STinference(text, s_ref, v, diffusion_steps=10, alpha=0.5, beta=0.9, embedding_scale=1.5)\n",
927
+ " print(k + \": \")\n",
928
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
929
+ ]
930
+ },
931
+ {
932
+ "cell_type": "markdown",
933
+ "id": "6750aed9",
934
+ "metadata": {},
935
+ "source": [
936
+ "### Speech diversity\n",
937
+ "\n",
938
+ "This section reproduces samples in [Section 7](https://styletts2.github.io/#var) of the demo page. \n",
939
+ "\n",
940
+ "`alpha` and `beta` determine the diversity of the synthesized speech. There are two extreme cases:\n",
941
+ "- If `alpha = 1` and `beta = 1`, the synthesized speech sounds the most dissimilar to the reference speaker, but it is also the most diverse (each time you synthesize, the output will be totally different). \n",
942
+ "- If `alpha = 0` and `beta = 0`, the synthesized speech sounds the most similar to the reference speaker, but it is deterministic (i.e., the sampled style is not used for speech synthesis). \n"
943
+ ]
944
+ },
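To see the two extremes in practice, you could synthesize the same sentence twice at each setting. A minimal sketch reusing the `inference` and `compute_style` helpers defined above (the reference path is the same one used in the cells below):

```python
import numpy as np

ref_s = compute_style("Demo/reference_audio/1221-135767-0014.wav")
text = "How much variation is there?"

# alpha = beta = 0: the sampled style is discarded, so repeated runs
# should sound essentially identical.
a = inference(text, ref_s, diffusion_steps=10, alpha=0, beta=0, embedding_scale=1)
b = inference(text, ref_s, diffusion_steps=10, alpha=0, beta=0, embedding_scale=1)
n = min(len(a), len(b))
print("max |a - b| with alpha=beta=0:", np.abs(a[:n] - b[:n]).max())  # expected ~0

# alpha = beta = 1: a fresh style is sampled on every call, so the two
# waveforms (and even their lengths) will generally differ.
c = inference(text, ref_s, diffusion_steps=10, alpha=1, beta=1, embedding_scale=1)
d = inference(text, ref_s, diffusion_steps=10, alpha=1, beta=1, embedding_scale=1)
print("lengths with alpha=beta=1:", len(c), len(d))
```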
945
+ {
946
+ "cell_type": "markdown",
947
+ "id": "f6ae0aa5",
948
+ "metadata": {},
949
+ "source": [
950
+ "#### Default setting (`alpha = 0.3, beta=0.7`)\n",
951
+ "This setting uses 70% of the reference timbre and 30% of the reference prosody, and uses the diffusion model to sample the rest based on the text. "
952
+ ]
953
+ },
954
+ {
955
+ "cell_type": "code",
956
+ "execution_count": null,
957
+ "id": "36dc0148",
958
+ "metadata": {},
959
+ "outputs": [],
960
+ "source": [
961
+ "# unseen speaker\n",
962
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
963
+ "ref_s = compute_style(path)\n",
964
+ "\n",
965
+ "text = \"How much variation is there?\"\n",
966
+ "for _ in range(5):\n",
967
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0.3, beta=0.7, embedding_scale=1)\n",
968
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
969
+ ]
970
+ },
971
+ {
972
+ "cell_type": "markdown",
973
+ "id": "bf9ef421",
974
+ "metadata": {},
975
+ "source": [
976
+ "#### Less diverse setting (`alpha = 0.1, beta=0.3`)\n",
977
+ "This setting uses 90% of the reference timbre and 70% of the reference prosody. This makes it more similar to the reference speaker at the cost of less diverse samples. "
978
+ ]
979
+ },
980
+ {
981
+ "cell_type": "code",
982
+ "execution_count": null,
983
+ "id": "9ba406bd",
984
+ "metadata": {},
985
+ "outputs": [],
986
+ "source": [
987
+ "# unseen speaker\n",
988
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
989
+ "ref_s = compute_style(path)\n",
990
+ "\n",
991
+ "text = \"How much variation is there?\"\n",
992
+ "for _ in range(5):\n",
993
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0.1, beta=0.3, embedding_scale=1)\n",
994
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
995
+ ]
996
+ },
997
+ {
998
+ "cell_type": "markdown",
999
+ "id": "a38fe464",
1000
+ "metadata": {},
1001
+ "source": [
1002
+ "#### More diverse setting (`alpha = 0.5, beta=0.95`)\n",
1003
+ "This setting uses 50% of the reference timbre and 5% of the reference prosody (so 95% of the prosody is sampled, which makes it more diverse), but this also makes it less similar to the reference speaker. "
1004
+ ]
1005
+ },
1006
+ {
1007
+ "cell_type": "code",
1008
+ "execution_count": null,
1009
+ "id": "5f25bf94",
1010
+ "metadata": {},
1011
+ "outputs": [],
1012
+ "source": [
1013
+ "# unseen speaker\n",
1014
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
1015
+ "ref_s = compute_style(path)\n",
1016
+ "\n",
1017
+ "text = \"How much variation is there?\"\n",
1018
+ "for _ in range(5):\n",
1019
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0.5, beta=0.95, embedding_scale=1)\n",
1020
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
1021
+ ]
1022
+ },
1023
+ {
1024
+ "cell_type": "markdown",
1025
+ "id": "21c3a071",
1026
+ "metadata": {},
1027
+ "source": [
1028
+ "#### Extreme setting (`alpha = 1, beta=1`)\n",
1029
+ "This setting uses 0% of the reference timbre and prosody and relies entirely on the diffusion model to sample the style. This makes the output very dissimilar to the reference speaker. "
1030
+ ]
1031
+ },
1032
+ {
1033
+ "cell_type": "code",
1034
+ "execution_count": null,
1035
+ "id": "fff8bab1",
1036
+ "metadata": {},
1037
+ "outputs": [],
1038
+ "source": [
1039
+ "# unseen speaker\n",
1040
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
1041
+ "ref_s = compute_style(path)\n",
1042
+ "\n",
1043
+ "text = \"How much variation is there?\"\n",
1044
+ "for _ in range(5):\n",
1045
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=1, beta=1, embedding_scale=1)\n",
1046
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
1047
+ ]
1048
+ },
1049
+ {
1050
+ "cell_type": "markdown",
1051
+ "id": "a8741e5a",
1052
+ "metadata": {},
1053
+ "source": [
1054
+ "#### No variation (`alpha = 0, beta=0`)\n",
1055
+ "This setting uses 100% of the reference timbre and prosody; the sampled style is not used at all. This makes the output very similar to the reference speaker, but there is no variation from run to run. "
1056
+ ]
1057
+ },
1058
+ {
1059
+ "cell_type": "code",
1060
+ "execution_count": null,
1061
+ "id": "e55dd281",
1062
+ "metadata": {},
1063
+ "outputs": [],
1064
+ "source": [
1065
+ "# unseen speaker\n",
1066
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
1067
+ "ref_s = compute_style(path)\n",
1068
+ "\n",
1069
+ "text = \"How much variation is there?\"\n",
1070
+ "for _ in range(5):\n",
1071
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0, beta=0, embedding_scale=1)\n",
1072
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
1073
+ ]
1074
+ },
1075
+ {
1076
+ "cell_type": "markdown",
1077
+ "id": "d5e86423",
1078
+ "metadata": {},
1079
+ "source": [
1080
+ "### Extra fun!\n",
1081
+ "\n",
1082
+ "Here we clone the voices of some of the StyleTTS 2 authors from a few seconds of recordings in the wild. None of these voices are in the dataset, and all authors agreed to have their voices cloned here."
1083
+ ]
1084
+ },
1085
+ {
1086
+ "cell_type": "code",
1087
+ "execution_count": null,
1088
+ "id": "6f558314",
1089
+ "metadata": {},
1090
+ "outputs": [],
1091
+ "source": [
1092
+ "text = ''' StyleTTS 2 is a text to speech model that leverages style diffusion and adversarial training with large speech language models to achieve human level text to speech synthesis. '''"
1093
+ ]
1094
+ },
1095
+ {
1096
+ "cell_type": "code",
1097
+ "execution_count": null,
1098
+ "id": "caa5747c",
1099
+ "metadata": {},
1100
+ "outputs": [],
1101
+ "source": [
1102
+ "reference_dicts = {}\n",
1103
+ "reference_dicts['Yinghao'] = \"Demo/reference_audio/Yinghao.wav\"\n",
1104
+ "reference_dicts['Gavin'] = \"Demo/reference_audio/Gavin.wav\"\n",
1105
+ "reference_dicts['Vinay'] = \"Demo/reference_audio/Vinay.wav\"\n",
1106
+ "reference_dicts['Nima'] = \"Demo/reference_audio/Nima.wav\""
1107
+ ]
1108
+ },
1109
+ {
1110
+ "cell_type": "code",
1111
+ "execution_count": null,
1112
+ "id": "44a4cea1",
1113
+ "metadata": {
1114
+ "scrolled": false
1115
+ },
1116
+ "outputs": [],
1117
+ "source": [
1118
+ "start = time.time()\n",
1119
+ "noise = torch.randn(1,1,256).to(device)\n",
1120
+ "for k, path in reference_dicts.items():\n",
1121
+ " ref_s = compute_style(path)\n",
1122
+ " \n",
1123
+ " wav = inference(text, ref_s, alpha=0.1, beta=0.5, diffusion_steps=5, embedding_scale=1)\n",
1124
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
1125
+ " print('Speaker: ' + k)\n",
1126
+ " import IPython.display as ipd\n",
1127
+ " print('Synthesized:')\n",
1128
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
1129
+ " print('Reference:')\n",
1130
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
1131
+ ]
1132
+ }
1133
+ ],
1134
+ "metadata": {
1135
+ "kernelspec": {
1136
+ "display_name": "NLP",
1137
+ "language": "python",
1138
+ "name": "nlp"
1139
+ },
1140
+ "language_info": {
1141
+ "codemirror_mode": {
1142
+ "name": "ipython",
1143
+ "version": 3
1144
+ },
1145
+ "file_extension": ".py",
1146
+ "mimetype": "text/x-python",
1147
+ "name": "python",
1148
+ "nbconvert_exporter": "python",
1149
+ "pygments_lexer": "ipython3",
1150
+ "version": "3.9.7"
1151
+ }
1152
+ },
1153
+ "nbformat": 4,
1154
+ "nbformat_minor": 5
1155
+ }
Demo/.ipynb_checkpoints/Inference_pod_90h_30k-checkpoint.ipynb ADDED
@@ -0,0 +1,1155 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "9adb7bd1",
6
+ "metadata": {},
7
+ "source": [
8
+ "# StyleTTS 2 Demo (LibriTTS)\n",
9
+ "\n",
10
+ "Before you run the following cells, please make sure you have downloaded [reference_audio.zip](https://huggingface.co/yl4579/StyleTTS2-LibriTTS/resolve/main/reference_audio.zip) and unzipped it under the `Demo` folder."
11
+ ]
12
+ },
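If you prefer to fetch the archive from within the notebook, here is a minimal sketch (the URL is the one linked above; it assumes the zip unpacks to a top-level `reference_audio/` folder, matching the `Demo/reference_audio/...` paths used later):

```python
import urllib.request, zipfile

url = "https://huggingface.co/yl4579/StyleTTS2-LibriTTS/resolve/main/reference_audio.zip"
urllib.request.urlretrieve(url, "reference_audio.zip")

# Extract into Demo/ so the files end up at Demo/reference_audio/*.wav
with zipfile.ZipFile("reference_audio.zip") as zf:
    zf.extractall("Demo")
```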
13
+ {
14
+ "cell_type": "markdown",
15
+ "id": "6108384d",
16
+ "metadata": {},
17
+ "source": [
18
+ "### Utils"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "id": "96e173bf",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "import torch\n",
29
+ "torch.manual_seed(0)\n",
30
+ "torch.backends.cudnn.benchmark = False\n",
31
+ "torch.backends.cudnn.deterministic = True\n",
32
+ "\n",
33
+ "import random\n",
34
+ "random.seed(0)\n",
35
+ "\n",
36
+ "import numpy as np\n",
37
+ "np.random.seed(0)"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "id": "da84c60f",
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "%cd .."
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "execution_count": null,
53
+ "id": "5a3ddcc8",
54
+ "metadata": {},
55
+ "outputs": [],
56
+ "source": [
57
+ "# load packages\n",
58
+ "import time\n",
59
+ "import random\n",
60
+ "import yaml\n",
61
+ "from munch import Munch\n",
62
+ "import numpy as np\n",
63
+ "import torch\n",
64
+ "from torch import nn\n",
65
+ "import torch.nn.functional as F\n",
66
+ "import torchaudio\n",
67
+ "import librosa\n",
68
+ "from nltk.tokenize import word_tokenize\n",
69
+ "\n",
70
+ "from models import *\n",
71
+ "from utils import *\n",
72
+ "from text_utils import TextCleaner\n",
73
+ "textclenaer = TextCleaner()\n",
74
+ "\n",
75
+ "%matplotlib inline"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": null,
81
+ "id": "00ee05e1",
82
+ "metadata": {},
83
+ "outputs": [],
84
+ "source": [
85
+ "to_mel = torchaudio.transforms.MelSpectrogram(\n",
86
+ " n_mels=80, n_fft=2048, win_length=1200, hop_length=300)\n",
87
+ "mean, std = -4, 4\n",
88
+ "\n",
89
+ "def length_to_mask(lengths):\n",
90
+ " mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)\n",
91
+ " mask = torch.gt(mask+1, lengths.unsqueeze(1))\n",
92
+ " return mask\n",
93
+ "\n",
94
+ "def preprocess(wave):\n",
95
+ " wave_tensor = torch.from_numpy(wave).float()\n",
96
+ " mel_tensor = to_mel(wave_tensor)\n",
97
+ " mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std\n",
98
+ " return mel_tensor\n",
99
+ "\n",
100
+ "def compute_style(path):\n",
101
+ " wave, sr = librosa.load(path, sr=24000)\n",
102
+ " audio, index = librosa.effects.trim(wave, top_db=30)\n",
103
+ " if sr != 24000:\n",
104
+ " audio = librosa.resample(audio, sr, 24000)\n",
105
+ " mel_tensor = preprocess(audio).to(device)\n",
106
+ "\n",
107
+ " with torch.no_grad():\n",
108
+ " ref_s = model.style_encoder(mel_tensor.unsqueeze(1))\n",
109
+ " ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))\n",
110
+ "\n",
111
+ " return torch.cat([ref_s, ref_p], dim=1)"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": null,
117
+ "id": "bbdc04c0",
118
+ "metadata": {},
119
+ "outputs": [],
120
+ "source": [
121
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "markdown",
126
+ "id": "7b9cecbe",
127
+ "metadata": {},
128
+ "source": [
129
+ "### Load models"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type": "code",
134
+ "execution_count": null,
135
+ "id": "64fc4c0f",
136
+ "metadata": {},
137
+ "outputs": [],
138
+ "source": [
139
+ "# load phonemizer\n",
140
+ "import phonemizer\n",
141
+ "global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)"
142
+ ]
143
+ },
144
+ {
145
+ "cell_type": "code",
146
+ "execution_count": null,
147
+ "id": "48e7b644",
148
+ "metadata": {},
149
+ "outputs": [],
150
+ "source": [
151
+ "config = yaml.safe_load(open(\"Models/LibriTTS/config.yml\"))\n",
152
+ "\n",
153
+ "# load pretrained ASR model\n",
154
+ "ASR_config = config.get('ASR_config', False)\n",
155
+ "ASR_path = config.get('ASR_path', False)\n",
156
+ "text_aligner = load_ASR_models(ASR_path, ASR_config)\n",
157
+ "\n",
158
+ "# load pretrained F0 model\n",
159
+ "F0_path = config.get('F0_path', False)\n",
160
+ "pitch_extractor = load_F0_models(F0_path)\n",
161
+ "\n",
162
+ "# load BERT model\n",
163
+ "from Utils.PLBERT.util import load_plbert\n",
164
+ "BERT_path = config.get('PLBERT_dir', False)\n",
165
+ "plbert = load_plbert(BERT_path)"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "code",
170
+ "execution_count": null,
171
+ "id": "ffc18cf7",
172
+ "metadata": {},
173
+ "outputs": [],
174
+ "source": [
175
+ "model_params = recursive_munch(config['model_params'])\n",
176
+ "model = build_model(model_params, text_aligner, pitch_extractor, plbert)\n",
177
+ "_ = [model[key].eval() for key in model]\n",
178
+ "_ = [model[key].to(device) for key in model]"
179
+ ]
180
+ },
181
+ {
182
+ "cell_type": "code",
183
+ "execution_count": null,
184
+ "id": "64529d5c",
185
+ "metadata": {},
186
+ "outputs": [],
187
+ "source": [
188
+ "params_whole = torch.load(\"Models/LibriTTS/epochs_2nd_00020.pth\", map_location='cpu')\n",
189
+ "params = params_whole['net']"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": null,
195
+ "id": "895d9706",
196
+ "metadata": {},
197
+ "outputs": [],
198
+ "source": [
199
+ "for key in model:\n",
200
+ " if key in params:\n",
201
+ " print('%s loaded' % key)\n",
202
+ " try:\n",
203
+ " model[key].load_state_dict(params[key])\n",
204
+ " except:\n",
205
+ " from collections import OrderedDict\n",
206
+ " state_dict = params[key]\n",
207
+ " new_state_dict = OrderedDict()\n",
208
+ " for k, v in state_dict.items():\n",
209
+ " name = k[7:] # remove `module.`\n",
210
+ " new_state_dict[name] = v\n",
211
+ " # load params\n",
212
+ " model[key].load_state_dict(new_state_dict, strict=False)\n",
213
+ "# except:\n",
214
+ "# _load(params[key], model[key])\n",
215
+ "_ = [model[key].eval() for key in model]"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "code",
220
+ "execution_count": null,
221
+ "id": "c1a59db2",
222
+ "metadata": {},
223
+ "outputs": [],
224
+ "source": [
225
+ "from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "code",
230
+ "execution_count": null,
231
+ "id": "e30985ab",
232
+ "metadata": {},
233
+ "outputs": [],
234
+ "source": [
235
+ "sampler = DiffusionSampler(\n",
236
+ " model.diffusion.diffusion,\n",
237
+ " sampler=ADPM2Sampler(),\n",
238
+ " sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters\n",
239
+ " clamp=False\n",
240
+ ")"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "markdown",
245
+ "id": "b803110e",
246
+ "metadata": {},
247
+ "source": [
248
+ "### Synthesize speech"
249
+ ]
250
+ },
251
+ {
252
+ "cell_type": "code",
253
+ "execution_count": null,
254
+ "id": "ca57469c",
255
+ "metadata": {},
256
+ "outputs": [],
257
+ "source": [
258
+ "def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):\n",
259
+ " text = text.strip()\n",
260
+ " ps = global_phonemizer.phonemize([text])\n",
261
+ " ps = word_tokenize(ps[0])\n",
262
+ " ps = ' '.join(ps)\n",
263
+ " tokens = textclenaer(ps)\n",
264
+ " tokens.insert(0, 0)\n",
265
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
266
+ " \n",
267
+ " with torch.no_grad():\n",
268
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
269
+ " text_mask = length_to_mask(input_lengths).to(device)\n",
270
+ "\n",
271
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
272
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
273
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2) \n",
274
+ "\n",
275
+ " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device), \n",
276
+ " embedding=bert_dur,\n",
277
+ " embedding_scale=embedding_scale,\n",
278
+ " features=ref_s, # reference from the same speaker as the embedding\n",
279
+ " num_steps=diffusion_steps).squeeze(1)\n",
280
+ "\n",
281
+ "\n",
282
+ " s = s_pred[:, 128:]\n",
283
+ " ref = s_pred[:, :128]\n",
284
+ "\n",
285
+ " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
286
+ " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
287
+ "\n",
288
+ " d = model.predictor.text_encoder(d_en, \n",
289
+ " s, input_lengths, text_mask)\n",
290
+ "\n",
291
+ " x, _ = model.predictor.lstm(d)\n",
292
+ " duration = model.predictor.duration_proj(x)\n",
293
+ "\n",
294
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
295
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
296
+ "\n",
297
+ "\n",
298
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
299
+ " c_frame = 0\n",
300
+ " for i in range(pred_aln_trg.size(0)):\n",
301
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
302
+ " c_frame += int(pred_dur[i].data)\n",
303
+ "\n",
304
+ " # encode prosody\n",
305
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
306
+ " if model_params.decoder.type == \"hifigan\":\n",
307
+ " asr_new = torch.zeros_like(en)\n",
308
+ " asr_new[:, :, 0] = en[:, :, 0]\n",
309
+ " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
310
+ " en = asr_new\n",
311
+ "\n",
312
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
313
+ "\n",
314
+ " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
315
+ " if model_params.decoder.type == \"hifigan\":\n",
316
+ " asr_new = torch.zeros_like(asr)\n",
317
+ " asr_new[:, :, 0] = asr[:, :, 0]\n",
318
+ " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
319
+ " asr = asr_new\n",
320
+ "\n",
321
+ " out = model.decoder(asr, \n",
322
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
323
+ " \n",
324
+ " \n",
325
+ " return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "markdown",
330
+ "id": "d438ef4f",
331
+ "metadata": {},
332
+ "source": [
333
+ "#### Basic synthesis (5 diffusion steps, seen speakers)"
334
+ ]
335
+ },
336
+ {
337
+ "cell_type": "code",
338
+ "execution_count": null,
339
+ "id": "cace9787",
340
+ "metadata": {},
341
+ "outputs": [],
342
+ "source": [
343
+ "text = ''' StyleTTS 2 is a text to speech model that leverages style diffusion and adversarial training with large speech language models to achieve human level text to speech synthesis. '''"
344
+ ]
345
+ },
346
+ {
347
+ "cell_type": "code",
348
+ "execution_count": null,
349
+ "id": "7c88f461",
350
+ "metadata": {},
351
+ "outputs": [],
352
+ "source": [
353
+ "reference_dicts = {}\n",
354
+ "reference_dicts['696_92939'] = \"Demo/reference_audio/696_92939_000016_000006.wav\"\n",
355
+ "reference_dicts['1789_142896'] = \"Demo/reference_audio/1789_142896_000022_000005.wav\""
356
+ ]
357
+ },
358
+ {
359
+ "cell_type": "code",
360
+ "execution_count": null,
361
+ "id": "16e8ac60",
362
+ "metadata": {},
363
+ "outputs": [],
364
+ "source": [
365
+ "start = time.time()\n",
366
+ "noise = torch.randn(1,1,256).to(device)\n",
367
+ "for k, path in reference_dicts.items():\n",
368
+ " ref_s = compute_style(path)\n",
369
+ "    start = time.time()\n",
370
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=5, embedding_scale=1)\n",
371
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
372
+ " print(f\"RTF = {rtf:5f}\")\n",
373
+ " import IPython.display as ipd\n",
374
+ " print(k + ' Synthesized:')\n",
375
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
376
+ " print('Reference:')\n",
377
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
378
+ ]
379
+ },
380
+ {
381
+ "cell_type": "markdown",
382
+ "id": "14838708",
383
+ "metadata": {},
384
+ "source": [
385
+ "#### With higher diffusion steps (more diverse)\n",
386
+ "\n",
387
+ "Since the sampler is ancestral, the more steps you use, the more diverse the samples are, at the cost of slower synthesis speed."
388
+ ]
389
+ },
390
+ {
391
+ "cell_type": "code",
392
+ "execution_count": null,
393
+ "id": "6fbff03b",
394
+ "metadata": {},
395
+ "outputs": [],
396
+ "source": [
397
+ "noise = torch.randn(1,1,256).to(device)\n",
398
+ "for k, path in reference_dicts.items():\n",
399
+ " ref_s = compute_style(path)\n",
400
+ " start = time.time()\n",
401
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=10, embedding_scale=1)\n",
402
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
403
+ " print(f\"RTF = {rtf:5f}\")\n",
404
+ " import IPython.display as ipd\n",
405
+ " print(k + ' Synthesized:')\n",
406
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
407
+ " print(k + ' Reference:')\n",
408
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
409
+ ]
410
+ },
411
+ {
412
+ "cell_type": "markdown",
413
+ "id": "7e6867fd",
414
+ "metadata": {},
415
+ "source": [
416
+ "#### Basic synthesis (5 diffusion steps, unseen speakers)\n",
417
+ "The following samples reproduce those in [Section 4](https://styletts2.github.io/#libri) of the demo page. All speakers are unseen during training. You can compare the generated samples to popular zero-shot TTS models like Vall-E and NaturalSpeech 2."
418
+ ]
419
+ },
420
+ {
421
+ "cell_type": "code",
422
+ "execution_count": null,
423
+ "id": "f4e8faa0",
424
+ "metadata": {},
425
+ "outputs": [],
426
+ "source": [
427
+ "reference_dicts = {}\n",
428
+ "# format: (path, text)\n",
429
+ "reference_dicts['1221-135767'] = (\"Demo/reference_audio/1221-135767-0014.wav\", \"Yea, his honourable worship is within, but he hath a godly minister or two with him, and likewise a leech.\")\n",
430
+ "reference_dicts['5639-40744'] = (\"Demo/reference_audio/5639-40744-0020.wav\", \"Thus did this humane and right minded father comfort his unhappy daughter, and her mother embracing her again, did all she could to soothe her feelings.\")\n",
431
+ "reference_dicts['908-157963'] = (\"Demo/reference_audio/908-157963-0027.wav\", \"And lay me down in my cold bed and leave my shining lot.\")\n",
432
+ "reference_dicts['4077-13754'] = (\"Demo/reference_audio/4077-13754-0000.wav\", \"The army found the people in poverty and left them in comparative wealth.\")"
433
+ ]
434
+ },
435
+ {
436
+ "cell_type": "code",
437
+ "execution_count": null,
438
+ "id": "653f1406",
439
+ "metadata": {},
440
+ "outputs": [],
441
+ "source": [
442
+ "noise = torch.randn(1,1,256).to(device)\n",
443
+ "for k, v in reference_dicts.items():\n",
444
+ " path, text = v\n",
445
+ " ref_s = compute_style(path)\n",
446
+ " start = time.time()\n",
447
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=5, embedding_scale=1)\n",
448
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
449
+ " print(f\"RTF = {rtf:5f}\")\n",
450
+ " import IPython.display as ipd\n",
451
+ " print(k + ' Synthesized: ' + text)\n",
452
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
453
+ " print(k + ' Reference:')\n",
454
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
455
+ ]
456
+ },
457
+ {
458
+ "cell_type": "markdown",
459
+ "id": "141e91b3",
460
+ "metadata": {},
461
+ "source": [
462
+ "### Speech expressiveness\n",
463
+ "\n",
464
+ "The following section recreates the samples shown in [Section 6](https://styletts2.github.io/#emo) of the demo page. The speaker reference used is `1221-135767-0014.wav`, which is unseen during training. \n",
465
+ "\n",
466
+ "#### With `embedding_scale=1`\n",
467
+ "This is the classifier-free guidance scale. The higher the scale, the more the sampled style is conditioned on the input text, and hence the more emotional the speech.\n",
468
+ "\n"
469
+ ]
470
+ },
471
+ {
472
+ "cell_type": "code",
473
+ "execution_count": null,
474
+ "id": "81addda4",
475
+ "metadata": {},
476
+ "outputs": [],
477
+ "source": [
478
+ "ref_s = compute_style(\"Demo/reference_audio/1221-135767-0014.wav\")"
479
+ ]
480
+ },
481
+ {
482
+ "cell_type": "code",
483
+ "execution_count": null,
484
+ "id": "be1b2a11",
485
+ "metadata": {},
486
+ "outputs": [],
487
+ "source": [
488
+ "texts = {}\n",
489
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
490
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
491
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
492
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
493
+ "\n",
494
+ "for k,v in texts.items():\n",
495
+ " wav = inference(v, ref_s, diffusion_steps=10, alpha=0.3, beta=0.7, embedding_scale=1)\n",
496
+ " print(k + \": \")\n",
497
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
498
+ ]
499
+ },
500
+ {
501
+ "cell_type": "markdown",
502
+ "id": "96d262b8",
503
+ "metadata": {},
504
+ "source": [
505
+ "#### With `embedding_scale=2`"
506
+ ]
507
+ },
508
+ {
509
+ "cell_type": "code",
510
+ "execution_count": null,
511
+ "id": "3e7d40b4",
512
+ "metadata": {},
513
+ "outputs": [],
514
+ "source": [
515
+ "texts = {}\n",
516
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
517
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
518
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
519
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
520
+ "\n",
521
+ "for k,v in texts.items():\n",
522
+ " noise = torch.randn(1,1,256).to(device)\n",
523
+ " wav = inference(v, ref_s, diffusion_steps=10, alpha=0.3, beta=0.7, embedding_scale=2)\n",
524
+ " print(k + \": \")\n",
525
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
526
+ ]
527
+ },
528
+ {
529
+ "cell_type": "markdown",
530
+ "id": "402b2bd6",
531
+ "metadata": {},
532
+ "source": [
533
+ "#### With `embedding_scale=2, alpha = 0.5, beta = 0.9`\n",
534
+ "`alpha` and `beta` determine how much of the style is sampled from the text rather than taken from the reference. The higher the values of `alpha` and `beta`, the better the style suits the text, but the less similar it is to the reference. Using a higher `beta` makes the synthesized speech more emotional, at the cost of lower similarity to the reference. `alpha` controls the timbre of the speaker, while `beta` controls the prosody. "
535
+ ]
536
+ },
537
+ {
538
+ "cell_type": "code",
539
+ "execution_count": null,
540
+ "id": "599de5d5",
541
+ "metadata": {},
542
+ "outputs": [],
543
+ "source": [
544
+ "texts = {}\n",
545
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
546
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
547
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
548
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
549
+ "\n",
550
+ "for k,v in texts.items():\n",
551
+ " noise = torch.randn(1,1,256).to(device)\n",
552
+ " wav = inference(v, ref_s, diffusion_steps=10, alpha=0.5, beta=0.9, embedding_scale=2)\n",
553
+ " print(k + \": \")\n",
554
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
555
+ ]
556
+ },
557
+ {
558
+ "cell_type": "markdown",
559
+ "id": "48548866",
560
+ "metadata": {},
561
+ "source": [
562
+ "### Zero-shot speaker adaptation\n",
563
+ "This section recreates the \"Acoustic Environment Maintenance\" and \"Speaker’s Emotion Maintenance\" demos in [Section 4](https://styletts2.github.io/#libri) of the demo page. You can compare the generated samples to popular zero-shot TTS models like Vall-E. Note that the model was trained only on LibriTTS, roughly 250 times less data than was used to train Vall-E, yet achieves similar or better results on these maintenance tasks. "
564
+ ]
565
+ },
566
+ {
567
+ "cell_type": "markdown",
568
+ "id": "23e81572",
569
+ "metadata": {},
570
+ "source": [
571
+ "#### Acoustic Environment Maintenance\n",
572
+ "\n",
573
+ "Since we want to maintain the acoustic environment in the speaker (timbre), we set `alpha = 0` to keep the speaker as close to the reference as possible while only changing the prosody according to the text. "
574
+ ]
575
+ },
576
+ {
577
+ "cell_type": "code",
578
+ "execution_count": null,
579
+ "id": "8087bccb",
580
+ "metadata": {},
581
+ "outputs": [],
582
+ "source": [
583
+ "reference_dicts = {}\n",
584
+ "# format: (path, text)\n",
585
+ "reference_dicts['3'] = (\"Demo/reference_audio/3.wav\", \"As friends thing I definitely I've got more male friends.\")\n",
586
+ "reference_dicts['4'] = (\"Demo/reference_audio/4.wav\", \"Everything is run by computer but you got to know how to think before you can do a computer.\")\n",
587
+ "reference_dicts['5'] = (\"Demo/reference_audio/5.wav\", \"Then out in LA you guys got a whole another ball game within California to worry about.\")"
588
+ ]
589
+ },
590
+ {
591
+ "cell_type": "code",
592
+ "execution_count": null,
593
+ "id": "1e99c200",
594
+ "metadata": {},
595
+ "outputs": [],
596
+ "source": [
597
+ "noise = torch.randn(1,1,256).to(device)\n",
598
+ "for k, v in reference_dicts.items():\n",
599
+ " path, text = v\n",
600
+ " ref_s = compute_style(path)\n",
601
+ " start = time.time()\n",
602
+ " wav = inference(text, ref_s, alpha=0.0, beta=0.5, diffusion_steps=5, embedding_scale=1)\n",
603
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
604
+ " print(f\"RTF = {rtf:5f}\")\n",
605
+ " import IPython.display as ipd\n",
606
+ " print('Synthesized: ' + text)\n",
607
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
608
+ " print('Reference:')\n",
609
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
610
+ ]
611
+ },
612
+ {
613
+ "cell_type": "markdown",
614
+ "id": "7d56505d",
615
+ "metadata": {},
616
+ "source": [
617
+ "#### Speaker’s Emotion Maintenance\n",
618
+ "\n",
619
+ "Since we want to maintain the emotion in the speaker (prosody), we set `beta = 0.1` to keep the speaker as close to the reference as possible while allowing some diversity through a slight timbre change."
620
+ ]
621
+ },
622
+ {
623
+ "cell_type": "code",
624
+ "execution_count": null,
625
+ "id": "f90179e7",
626
+ "metadata": {},
627
+ "outputs": [],
628
+ "source": [
629
+ "reference_dicts = {}\n",
630
+ "# format: (path, text)\n",
631
+ "reference_dicts['Anger'] = (\"Demo/reference_audio/anger.wav\", \"We have to reduce the number of plastic bags.\")\n",
632
+ "reference_dicts['Sleepy'] = (\"Demo/reference_audio/sleepy.wav\", \"We have to reduce the number of plastic bags.\")\n",
633
+ "reference_dicts['Amused'] = (\"Demo/reference_audio/amused.wav\", \"We have to reduce the number of plastic bags.\")\n",
634
+ "reference_dicts['Disgusted'] = (\"Demo/reference_audio/disgusted.wav\", \"We have to reduce the number of plastic bags.\")"
635
+ ]
636
+ },
637
+ {
638
+ "cell_type": "code",
639
+ "execution_count": null,
640
+ "id": "2e6bdfed",
641
+ "metadata": {},
642
+ "outputs": [],
643
+ "source": [
644
+ "noise = torch.randn(1,1,256).to(device)\n",
645
+ "for k, v in reference_dicts.items():\n",
646
+ " path, text = v\n",
647
+ " ref_s = compute_style(path)\n",
648
+ " start = time.time()\n",
649
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.1, diffusion_steps=10, embedding_scale=1)\n",
650
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
651
+ " print(f\"RTF = {rtf:5f}\")\n",
652
+ " import IPython.display as ipd\n",
653
+ " print(k + ' Synthesized: ' + text)\n",
654
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
655
+ " print(k + ' Reference:')\n",
656
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
657
+ ]
658
+ },
659
+ {
660
+ "cell_type": "markdown",
661
+ "id": "37ae3963",
662
+ "metadata": {},
663
+ "source": [
664
+ "### Longform Narration\n",
665
+ "\n",
666
+ "This section includes basic implementation of Algorithm 1 in the paper for consistent longform audio generation. The example passage is taken from [Section 5](https://styletts2.github.io/#long) of the demo page."
667
+ ]
668
+ },
669
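As a rough sketch of what `LFinference` below does per sentence (the `sample_style` helper here is hypothetical, standing in for the diffusion sampler call inside `LFinference`): each sentence gets a newly sampled style, which is blended with the previous sentence's style through a convex combination controlled by `t`, so the prosody evolves smoothly across the passage instead of jumping at sentence boundaries.

```python
import torch

def sample_style(sentence: str) -> torch.Tensor:
    """Hypothetical stand-in for the diffusion sampler call inside LFinference."""
    return torch.randn(1, 256)

passage = "First sentence. Second sentence. Third sentence."
t = 0.7            # weight on the previous sentence's style
s_prev = None
for sentence in passage.split('.'):
    if not sentence.strip():
        continue
    s_new = sample_style(sentence + '.')
    # carry the style over: t * previous + (1 - t) * newly sampled
    s_prev = s_new if s_prev is None else t * s_prev + (1 - t) * s_new
```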
+ {
670
+ "cell_type": "code",
671
+ "execution_count": null,
672
+ "id": "f12a716b",
673
+ "metadata": {},
674
+ "outputs": [],
675
+ "source": [
676
+ "passage = '''If the supply of fruit is greater than the family needs, it may be made a source of income by sending the fresh fruit to the market if there is one near enough, or by preserving, canning, and making jelly for sale. To make such an enterprise a success the fruit and work must be first class. There is magic in the word \"Homemade,\" when the product appeals to the eye and the palate; but many careless and incompetent people have found to their sorrow that this word has not magic enough to float inferior goods on the market. As a rule large canning and preserving establishments are clean and have the best appliances, and they employ chemists and skilled labor. The home product must be very good to compete with the attractive goods that are sent out from such establishments. Yet for first class home made products there is a market in all large cities. All first-class grocers have customers who purchase such goods.'''"
677
+ ]
678
+ },
679
+ {
680
+ "cell_type": "code",
681
+ "execution_count": null,
682
+ "id": "a1a38079",
683
+ "metadata": {},
684
+ "outputs": [],
685
+ "source": [
686
+ "def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1):\n",
687
+ " text = text.strip()\n",
688
+ " ps = global_phonemizer.phonemize([text])\n",
689
+ " ps = word_tokenize(ps[0])\n",
690
+ " ps = ' '.join(ps)\n",
691
+ " ps = ps.replace('``', '\"')\n",
692
+ " ps = ps.replace(\"''\", '\"')\n",
693
+ "\n",
694
+ " tokens = textclenaer(ps)\n",
695
+ " tokens.insert(0, 0)\n",
696
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
697
+ " \n",
698
+ " with torch.no_grad():\n",
699
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
700
+ " text_mask = length_to_mask(input_lengths).to(device)\n",
701
+ "\n",
702
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
703
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
704
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2) \n",
705
+ "\n",
706
+ " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device), \n",
707
+ " embedding=bert_dur,\n",
708
+ " embedding_scale=embedding_scale,\n",
709
+ " features=ref_s, # reference from the same speaker as the embedding\n",
710
+ " num_steps=diffusion_steps).squeeze(1)\n",
711
+ " \n",
712
+ " if s_prev is not None:\n",
713
+ " # convex combination of previous and current style\n",
714
+ " s_pred = t * s_prev + (1 - t) * s_pred\n",
715
+ " \n",
716
+ " s = s_pred[:, 128:]\n",
717
+ " ref = s_pred[:, :128]\n",
718
+ " \n",
719
+ " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
720
+ " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
721
+ "\n",
722
+ " s_pred = torch.cat([ref, s], dim=-1)\n",
723
+ "\n",
724
+ " d = model.predictor.text_encoder(d_en, \n",
725
+ " s, input_lengths, text_mask)\n",
726
+ "\n",
727
+ " x, _ = model.predictor.lstm(d)\n",
728
+ " duration = model.predictor.duration_proj(x)\n",
729
+ "\n",
730
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
731
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
732
+ "\n",
733
+ "\n",
734
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
735
+ " c_frame = 0\n",
736
+ " for i in range(pred_aln_trg.size(0)):\n",
737
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
738
+ " c_frame += int(pred_dur[i].data)\n",
739
+ "\n",
740
+ " # encode prosody\n",
741
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
742
+ " if model_params.decoder.type == \"hifigan\":\n",
743
+ " asr_new = torch.zeros_like(en)\n",
744
+ " asr_new[:, :, 0] = en[:, :, 0]\n",
745
+ " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
746
+ " en = asr_new\n",
747
+ "\n",
748
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
749
+ "\n",
750
+ " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
751
+ " if model_params.decoder.type == \"hifigan\":\n",
752
+ " asr_new = torch.zeros_like(asr)\n",
753
+ " asr_new[:, :, 0] = asr[:, :, 0]\n",
754
+ " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
755
+ " asr = asr_new\n",
756
+ "\n",
757
+ " out = model.decoder(asr, \n",
758
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
759
+ " \n",
760
+ " \n",
761
+ " return out.squeeze().cpu().numpy()[..., :-100], s_pred # weird pulse at the end of the model, need to be fixed later"
762
+ ]
763
+ },
764
+ {
765
+ "cell_type": "code",
766
+ "execution_count": null,
767
+ "id": "e9088f7a",
768
+ "metadata": {},
769
+ "outputs": [],
770
+ "source": [
771
+ "# unseen speaker\n",
772
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
773
+ "s_ref = compute_style(path)\n",
774
+ "sentences = passage.split('.') # simple split by comma\n",
775
+ "wavs = []\n",
776
+ "s_prev = None\n",
777
+ "for text in sentences:\n",
778
+ " if text.strip() == \"\": continue\n",
779
+ " text += '.' # add it back\n",
780
+ " \n",
781
+ " wav, s_prev = LFinference(text, \n",
782
+ " s_prev, \n",
783
+ " s_ref, \n",
784
+ " alpha = 0.3, \n",
785
+ " beta = 0.9, # make it more suitable for the text\n",
786
+ " t = 0.7, \n",
787
+ " diffusion_steps=10, embedding_scale=1.5)\n",
788
+ " wavs.append(wav)\n",
789
+ "print('Synthesized: ')\n",
790
+ "display(ipd.Audio(np.concatenate(wavs), rate=24000, normalize=False))\n",
791
+ "print('Reference: ')\n",
792
+ "display(ipd.Audio(path, rate=24000, normalize=False))"
793
+ ]
794
+ },
795
+ {
796
+ "cell_type": "markdown",
797
+ "id": "7517b657",
798
+ "metadata": {},
799
+ "source": [
800
+ "### Style Transfer\n",
801
+ "\n",
802
+ "The following section demostrates the style transfer capacity for unseen speakers in [Section 6](https://styletts2.github.io/#emo) of the demo page. For this, we set `alpha=0.5, beta = 0.9` for the most pronounced effects (mostly using the sampled style). "
803
+ ]
804
+ },
805
+ {
806
+ "cell_type": "code",
807
+ "execution_count": null,
808
+ "id": "ed95d0f7",
809
+ "metadata": {},
810
+ "outputs": [],
811
+ "source": [
812
+ "def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):\n",
813
+ " text = text.strip()\n",
814
+ " ps = global_phonemizer.phonemize([text])\n",
815
+ " ps = word_tokenize(ps[0])\n",
816
+ " ps = ' '.join(ps)\n",
817
+ "\n",
818
+ " tokens = textclenaer(ps)\n",
819
+ " tokens.insert(0, 0)\n",
820
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
821
+ " \n",
822
+ " ref_text = ref_text.strip()\n",
823
+ " ps = global_phonemizer.phonemize([ref_text])\n",
824
+ " ps = word_tokenize(ps[0])\n",
825
+ " ps = ' '.join(ps)\n",
826
+ "\n",
827
+ " ref_tokens = textclenaer(ps)\n",
828
+ " ref_tokens.insert(0, 0)\n",
829
+ " ref_tokens = torch.LongTensor(ref_tokens).to(device).unsqueeze(0)\n",
830
+ " \n",
831
+ " \n",
832
+ " with torch.no_grad():\n",
833
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
834
+ " text_mask = length_to_mask(input_lengths).to(device)\n",
835
+ "\n",
836
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
837
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
838
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2) \n",
839
+ " \n",
840
+ " ref_input_lengths = torch.LongTensor([ref_tokens.shape[-1]]).to(device)\n",
841
+ " ref_text_mask = length_to_mask(ref_input_lengths).to(device)\n",
842
+ " ref_bert_dur = model.bert(ref_tokens, attention_mask=(~ref_text_mask).int())\n",
843
+ " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device), \n",
844
+ " embedding=bert_dur,\n",
845
+ " embedding_scale=embedding_scale,\n",
846
+ " features=ref_s, # reference from the same speaker as the embedding\n",
847
+ " num_steps=diffusion_steps).squeeze(1)\n",
848
+ "\n",
849
+ "\n",
850
+ " s = s_pred[:, 128:]\n",
851
+ " ref = s_pred[:, :128]\n",
852
+ "\n",
853
+ " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
854
+ " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
855
+ "\n",
856
+ " d = model.predictor.text_encoder(d_en, \n",
857
+ " s, input_lengths, text_mask)\n",
858
+ "\n",
859
+ " x, _ = model.predictor.lstm(d)\n",
860
+ " duration = model.predictor.duration_proj(x)\n",
861
+ "\n",
862
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
863
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
864
+ "\n",
865
+ "\n",
866
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
867
+ " c_frame = 0\n",
868
+ " for i in range(pred_aln_trg.size(0)):\n",
869
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
870
+ " c_frame += int(pred_dur[i].data)\n",
871
+ "\n",
872
+ " # encode prosody\n",
873
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
874
+ " if model_params.decoder.type == \"hifigan\":\n",
875
+ " asr_new = torch.zeros_like(en)\n",
876
+ " asr_new[:, :, 0] = en[:, :, 0]\n",
877
+ " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
878
+ " en = asr_new\n",
879
+ "\n",
880
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
881
+ "\n",
882
+ " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
883
+ " if model_params.decoder.type == \"hifigan\":\n",
884
+ " asr_new = torch.zeros_like(asr)\n",
885
+ " asr_new[:, :, 0] = asr[:, :, 0]\n",
886
+ " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
887
+ " asr = asr_new\n",
888
+ "\n",
889
+ " out = model.decoder(asr, \n",
890
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
891
+ " \n",
892
+ " \n",
893
+ " return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later"
894
+ ]
895
+ },
896
+ {
897
+ "cell_type": "code",
898
+ "execution_count": null,
899
+ "id": "ec3f0da4",
900
+ "metadata": {},
901
+ "outputs": [],
902
+ "source": [
903
+ "# reference texts to sample styles\n",
904
+ "\n",
905
+ "ref_texts = {}\n",
906
+ "ref_texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
907
+ "ref_texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
908
+ "ref_texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
909
+ "ref_texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\""
910
+ ]
911
+ },
912
+ {
913
+ "cell_type": "code",
914
+ "execution_count": null,
915
+ "id": "6d0a3825",
916
+ "metadata": {
917
+ "scrolled": false
918
+ },
919
+ "outputs": [],
920
+ "source": [
921
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
922
+ "s_ref = compute_style(path)\n",
923
+ "\n",
924
+ "text = \"Yea, his honourable worship is within, but he hath a godly minister or two with him, and likewise a leech.\"\n",
925
+ "for k,v in ref_texts.items():\n",
926
+ " wav = STinference(text, s_ref, v, diffusion_steps=10, alpha=0.5, beta=0.9, embedding_scale=1.5)\n",
927
+ " print(k + \": \")\n",
928
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
929
+ ]
930
+ },
931
+ {
932
+ "cell_type": "markdown",
933
+ "id": "6750aed9",
934
+ "metadata": {},
935
+ "source": [
936
+ "### Speech diversity\n",
937
+ "\n",
938
+ "This section reproduces samples in [Section 7](https://styletts2.github.io/#var) of the demo page. \n",
939
+ "\n",
940
+ "`alpha` and `beta` determine the diversity of the synthesized speech. There are two extreme cases:\n",
941
+ "- If `alpha = 1` and `beta = 1`, the synthesized speech sounds the most dissimilar to the reference speaker, but it is also the most diverse (each time you synthesize a speech it will be totally different). \n",
942
+ "- If `alpha = 0` and `beta = 0`, the synthesized speech sounds the most siimlar to the reference speaker, but it is deterministic (i.e., the sampled style is not used for speech synthesis). \n"
943
+ ]
944
+ },
945
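For reference, the blending that `alpha` and `beta` control is the convex combination used inside `inference` and `LFinference`, applied separately to the two 128-dimensional halves of the 256-dimensional style vector (a condensed restatement of the lines in the code above):

```python
# s_pred is the style sampled by the diffusion model; ref_s comes from compute_style().
# First 128 dims: acoustic style (timbre); last 128 dims: prosodic style.
ref = alpha * s_pred[:, :128] + (1 - alpha) * ref_s[:, :128]   # timbre half
s   = beta  * s_pred[:, 128:] + (1 - beta)  * ref_s[:, 128:]   # prosody half
```

So `alpha` and `beta` are simply the fractions of the *sampled* style used for the timbre and the prosody, respectively.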
+ {
946
+ "cell_type": "markdown",
947
+ "id": "f6ae0aa5",
948
+ "metadata": {},
949
+ "source": [
950
+ "#### Default setting (`alpha = 0.3, beta=0.7`)\n",
951
+ "This setting uses 70% of the reference timbre and 30% of the reference prosody and use the diffusion model to sample them based on the text. "
952
+ ]
953
+ },
954
+ {
955
+ "cell_type": "code",
956
+ "execution_count": null,
957
+ "id": "36dc0148",
958
+ "metadata": {},
959
+ "outputs": [],
960
+ "source": [
961
+ "# unseen speaker\n",
962
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
963
+ "ref_s = compute_style(path)\n",
964
+ "\n",
965
+ "text = \"How much variation is there?\"\n",
966
+ "for _ in range(5):\n",
967
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0.3, beta=0.7, embedding_scale=1)\n",
968
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
969
+ ]
970
+ },
971
+ {
972
+ "cell_type": "markdown",
973
+ "id": "bf9ef421",
974
+ "metadata": {},
975
+ "source": [
976
+ "#### Less diverse setting (`alpha = 0.1, beta=0.3`)\n",
977
+ "This setting uses 90% of the reference timbre and 70% of the reference prosody. This makes it more similar to the reference speaker at cost of less diverse samples. "
978
+ ]
979
+ },
980
+ {
981
+ "cell_type": "code",
982
+ "execution_count": null,
983
+ "id": "9ba406bd",
984
+ "metadata": {},
985
+ "outputs": [],
986
+ "source": [
987
+ "# unseen speaker\n",
988
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
989
+ "ref_s = compute_style(path)\n",
990
+ "\n",
991
+ "text = \"How much variation is there?\"\n",
992
+ "for _ in range(5):\n",
993
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0.1, beta=0.3, embedding_scale=1)\n",
994
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
995
+ ]
996
+ },
997
+ {
998
+ "cell_type": "markdown",
999
+ "id": "a38fe464",
1000
+ "metadata": {},
1001
+ "source": [
1002
+ "#### More diverse setting (`alpha = 0.5, beta=0.95`)\n",
1003
+ "This setting uses 50% of the reference timbre and 5% of the reference prosody (so it uses 100% of the sampled prosody, which makes it more diverse), but this makes it more dissimilar to the reference speaker. "
1004
+ ]
1005
+ },
1006
+ {
1007
+ "cell_type": "code",
1008
+ "execution_count": null,
1009
+ "id": "5f25bf94",
1010
+ "metadata": {},
1011
+ "outputs": [],
1012
+ "source": [
1013
+ "# unseen speaker\n",
1014
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
1015
+ "ref_s = compute_style(path)\n",
1016
+ "\n",
1017
+ "text = \"How much variation is there?\"\n",
1018
+ "for _ in range(5):\n",
1019
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0.5, beta=0.95, embedding_scale=1)\n",
1020
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
1021
+ ]
1022
+ },
1023
+ {
1024
+ "cell_type": "markdown",
1025
+ "id": "21c3a071",
1026
+ "metadata": {},
1027
+ "source": [
1028
+ "#### Extreme setting (`alpha = 1, beta=1`)\n",
1029
+ "This setting uses 0% of the reference timbre and prosody and use the diffusion model to sample the entire style. This makes the speaker very dissimilar to the reference speaker. "
1030
+ ]
1031
+ },
1032
+ {
1033
+ "cell_type": "code",
1034
+ "execution_count": null,
1035
+ "id": "fff8bab1",
1036
+ "metadata": {},
1037
+ "outputs": [],
1038
+ "source": [
1039
+ "# unseen speaker\n",
1040
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
1041
+ "ref_s = compute_style(path)\n",
1042
+ "\n",
1043
+ "text = \"How much variation is there?\"\n",
1044
+ "for _ in range(5):\n",
1045
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=1, beta=1, embedding_scale=1)\n",
1046
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
1047
+ ]
1048
+ },
1049
+ {
1050
+ "cell_type": "markdown",
1051
+ "id": "a8741e5a",
1052
+ "metadata": {},
1053
+ "source": [
1054
+ "#### No variation (`alpha = 0, beta=0`)\n",
1055
+ "This setting uses 0% of the reference timbre and prosody and use the diffusion model to sample the entire style. This makes the speaker very similar to the reference speaker, but there is no variation. "
1056
+ ]
1057
+ },
1058
+ {
1059
+ "cell_type": "code",
1060
+ "execution_count": null,
1061
+ "id": "e55dd281",
1062
+ "metadata": {},
1063
+ "outputs": [],
1064
+ "source": [
1065
+ "# unseen speaker\n",
1066
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
1067
+ "ref_s = compute_style(path)\n",
1068
+ "\n",
1069
+ "text = \"How much variation is there?\"\n",
1070
+ "for _ in range(5):\n",
1071
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0, beta=0, embedding_scale=1)\n",
1072
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
1073
+ ]
1074
+ },
1075
+ {
1076
+ "cell_type": "markdown",
1077
+ "id": "d5e86423",
1078
+ "metadata": {},
1079
+ "source": [
1080
+ "### Extra fun!\n",
1081
+ "\n",
1082
+ "Here we clone some of the authors' voice of the StyleTTS 2 papers with a few seconds of the recording in the wild. None of the voices is in the dataset and all authors agreed to have their voices cloned here."
1083
+ ]
1084
+ },
1085
+ {
1086
+ "cell_type": "code",
1087
+ "execution_count": null,
1088
+ "id": "6f558314",
1089
+ "metadata": {},
1090
+ "outputs": [],
1091
+ "source": [
1092
+ "text = ''' StyleTTS 2 is a text to speech model that leverages style diffusion and adversarial training with large speech language models to achieve human level text to speech synthesis. '''"
1093
+ ]
1094
+ },
1095
+ {
1096
+ "cell_type": "code",
1097
+ "execution_count": null,
1098
+ "id": "caa5747c",
1099
+ "metadata": {},
1100
+ "outputs": [],
1101
+ "source": [
1102
+ "reference_dicts = {}\n",
1103
+ "reference_dicts['Yinghao'] = \"Demo/reference_audio/Yinghao.wav\"\n",
1104
+ "reference_dicts['Gavin'] = \"Demo/reference_audio/Gavin.wav\"\n",
1105
+ "reference_dicts['Vinay'] = \"Demo/reference_audio/Vinay.wav\"\n",
1106
+ "reference_dicts['Nima'] = \"Demo/reference_audio/Nima.wav\""
1107
+ ]
1108
+ },
1109
+ {
1110
+ "cell_type": "code",
1111
+ "execution_count": null,
1112
+ "id": "44a4cea1",
1113
+ "metadata": {
1114
+ "scrolled": false
1115
+ },
1116
+ "outputs": [],
1117
+ "source": [
1118
+ "start = time.time()\n",
1119
+ "noise = torch.randn(1,1,256).to(device)\n",
1120
+ "for k, path in reference_dicts.items():\n",
1121
+ " ref_s = compute_style(path)\n",
1122
+ " \n",
1123
+ " wav = inference(text, ref_s, alpha=0.1, beta=0.5, diffusion_steps=5, embedding_scale=1)\n",
1124
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
1125
+ " print('Speaker: ' + k)\n",
1126
+ " import IPython.display as ipd\n",
1127
+ " print('Synthesized:')\n",
1128
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
1129
+ " print('Reference:')\n",
1130
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
1131
+ ]
1132
+ }
1133
+ ],
1134
+ "metadata": {
1135
+ "kernelspec": {
1136
+ "display_name": "NLP",
1137
+ "language": "python",
1138
+ "name": "nlp"
1139
+ },
1140
+ "language_info": {
1141
+ "codemirror_mode": {
1142
+ "name": "ipython",
1143
+ "version": 3
1144
+ },
1145
+ "file_extension": ".py",
1146
+ "mimetype": "text/x-python",
1147
+ "name": "python",
1148
+ "nbconvert_exporter": "python",
1149
+ "pygments_lexer": "ipython3",
1150
+ "version": "3.9.7"
1151
+ }
1152
+ },
1153
+ "nbformat": 4,
1154
+ "nbformat_minor": 5
1155
+ }
Demo/Inference_pod_90h_30k.ipynb ADDED
@@ -0,0 +1,1360 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "9adb7bd1",
6
+ "metadata": {},
7
+ "source": [
8
+ "# StyleTTS 2 Demo (LibriTTS)\n",
9
+ "\n",
10
+ "Before you run the following cells, please make sure you have downloaded [reference_audio.zip](https://huggingface.co/yl4579/StyleTTS2-LibriTTS/resolve/main/reference_audio.zip) and unzipped it under the `demo` folder."
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "id": "6108384d",
16
+ "metadata": {},
17
+ "source": [
18
+ "### Utils"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 1,
24
+ "id": "96e173bf",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "import torch\n",
29
+ "torch.manual_seed(0)\n",
30
+ "torch.backends.cudnn.benchmark = False\n",
31
+ "torch.backends.cudnn.deterministic = True\n",
32
+ "\n",
33
+ "import random\n",
34
+ "random.seed(0)\n",
35
+ "\n",
36
+ "import numpy as np\n",
37
+ "np.random.seed(0)"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": 4,
43
+ "id": "2458c639-10a0-4b57-8602-22bc893c5176",
44
+ "metadata": {},
45
+ "outputs": [
46
+ {
47
+ "name": "stdout",
48
+ "output_type": "stream",
49
+ "text": [
50
+ "Collecting git+https://github.com/resemble-ai/monotonic_align.git (from -r requirements.txt (line 17))\n",
51
+ " Cloning https://github.com/resemble-ai/monotonic_align.git to /tmp/pip-req-build-ps9pa2ga\n",
52
+ " Running command git clone --filter=blob:none --quiet https://github.com/resemble-ai/monotonic_align.git /tmp/pip-req-build-ps9pa2ga\n",
53
+ " Resolved https://github.com/resemble-ai/monotonic_align.git to commit c6e5e6cb19882164027eb6e35118e841eed9298e\n",
54
+ " Installing build dependencies ... \u001b[?25ldone\n",
55
+ "\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n",
56
+ "\u001b[?25h Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n",
57
+ "\u001b[?25hCollecting SoundFile (from -r requirements.txt (line 1))\n",
58
+ " Using cached soundfile-0.13.1-py2.py3-none-manylinux_2_28_x86_64.whl.metadata (16 kB)\n",
59
+ "Requirement already satisfied: torchaudio in /venv/main/lib/python3.12/site-packages (from -r requirements.txt (line 2)) (2.6.0+cu126)\n",
60
+ "Collecting munch (from -r requirements.txt (line 3))\n",
61
+ " Using cached munch-4.0.0-py2.py3-none-any.whl.metadata (5.9 kB)\n",
62
+ "Requirement already satisfied: torch in /venv/main/lib/python3.12/site-packages (from -r requirements.txt (line 4)) (2.6.0+cu126)\n",
63
+ "Collecting pydub (from -r requirements.txt (line 5))\n",
64
+ " Using cached pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)\n",
65
+ "Requirement already satisfied: pyyaml in /venv/main/lib/python3.12/site-packages (from -r requirements.txt (line 6)) (6.0.2)\n",
66
+ "Collecting librosa (from -r requirements.txt (line 7))\n",
67
+ " Using cached librosa-0.11.0-py3-none-any.whl.metadata (8.7 kB)\n",
68
+ "Collecting nltk (from -r requirements.txt (line 8))\n",
69
+ " Using cached nltk-3.9.1-py3-none-any.whl.metadata (2.9 kB)\n",
70
+ "Collecting matplotlib (from -r requirements.txt (line 9))\n",
71
+ " Downloading matplotlib-3.10.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)\n",
72
+ "Collecting accelerate (from -r requirements.txt (line 10))\n",
73
+ " Using cached accelerate-1.7.0-py3-none-any.whl.metadata (19 kB)\n",
74
+ "Collecting transformers (from -r requirements.txt (line 11))\n",
75
+ " Using cached transformers-4.52.4-py3-none-any.whl.metadata (38 kB)\n",
76
+ "Collecting einops (from -r requirements.txt (line 12))\n",
77
+ " Using cached einops-0.8.1-py3-none-any.whl.metadata (13 kB)\n",
78
+ "Collecting einops-exts (from -r requirements.txt (line 13))\n",
79
+ " Using cached einops_exts-0.0.4-py3-none-any.whl.metadata (621 bytes)\n",
80
+ "Requirement already satisfied: tqdm in /venv/main/lib/python3.12/site-packages (from -r requirements.txt (line 14)) (4.67.1)\n",
81
+ "Collecting typing (from -r requirements.txt (line 15))\n",
82
+ " Using cached typing-3.7.4.3.tar.gz (78 kB)\n",
83
+ " Preparing metadata (setup.py) ... \u001b[?25ldone\n",
84
+ "\u001b[?25hRequirement already satisfied: typing-extensions in /venv/main/lib/python3.12/site-packages (from -r requirements.txt (line 16)) (4.13.2)\n",
85
+ "Collecting cffi>=1.0 (from SoundFile->-r requirements.txt (line 1))\n",
86
+ " Downloading cffi-1.17.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
87
+ "Requirement already satisfied: numpy in /venv/main/lib/python3.12/site-packages (from SoundFile->-r requirements.txt (line 1)) (2.1.2)\n",
88
+ "Requirement already satisfied: filelock in /venv/main/lib/python3.12/site-packages (from torch->-r requirements.txt (line 4)) (3.18.0)\n",
89
+ "Requirement already satisfied: setuptools in /venv/main/lib/python3.12/site-packages (from torch->-r requirements.txt (line 4)) (70.2.0)\n",
90
+ "Requirement already satisfied: sympy==1.13.1 in /venv/main/lib/python3.12/site-packages (from torch->-r requirements.txt (line 4)) (1.13.1)\n",
91
+ "Requirement already satisfied: networkx in /venv/main/lib/python3.12/site-packages (from torch->-r requirements.txt (line 4)) (3.3)\n",
92
+ "Requirement already satisfied: jinja2 in /venv/main/lib/python3.12/site-packages (from torch->-r requirements.txt (line 4)) (3.1.4)\n",
93
+ "Requirement already satisfied: fsspec in /venv/main/lib/python3.12/site-packages (from torch->-r requirements.txt (line 4)) (2025.3.2)\n",
94
+ "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /venv/main/lib/python3.12/site-packages (from torch->-r requirements.txt (line 4)) (12.6.77)\n",
95
+ "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /venv/main/lib/python3.12/site-packages (from torch->-r requirements.txt (line 4)) (12.6.77)\n",
96
+ "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /venv/main/lib/python3.12/site-packages (from torch->-r requirements.txt (line 4)) (12.6.80)\n",
97
+ "Requirement already satisfied: nvidia-cudnn-cu12==9.5.1.17 in /venv/main/lib/python3.12/site-packages (from torch->-r requirements.txt (line 4)) (9.5.1.17)\n",
98
+ "Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /venv/main/lib/python3.12/site-packages (from torch->-r requirements.txt (line 4)) (12.6.4.1)\n",
99
+ "Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /venv/main/lib/python3.12/site-packages (from torch->-r requirements.txt (line 4)) (11.3.0.4)\n",
100
+ "Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /venv/main/lib/python3.12/site-packages (from torch->-r requirements.txt (line 4)) (10.3.7.77)\n",
101
+ "Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /venv/main/lib/python3.12/site-packages (from torch->-r requirements.txt (line 4)) (11.7.1.2)\n",
102
+ "Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /venv/main/lib/python3.12/site-packages (from torch->-r requirements.txt (line 4)) (12.5.4.2)\n",
103
+ "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.3 in /venv/main/lib/python3.12/site-packages (from torch->-r requirements.txt (line 4)) (0.6.3)\n",
104
+ "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /venv/main/lib/python3.12/site-packages (from torch->-r requirements.txt (line 4)) (2.21.5)\n",
105
+ "Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /venv/main/lib/python3.12/site-packages (from torch->-r requirements.txt (line 4)) (12.6.77)\n",
106
+ "Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /venv/main/lib/python3.12/site-packages (from torch->-r requirements.txt (line 4)) (12.6.85)\n",
107
+ "Requirement already satisfied: triton==3.2.0 in /venv/main/lib/python3.12/site-packages (from torch->-r requirements.txt (line 4)) (3.2.0)\n",
108
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /venv/main/lib/python3.12/site-packages (from sympy==1.13.1->torch->-r requirements.txt (line 4)) (1.3.0)\n",
109
+ "Collecting audioread>=2.1.9 (from librosa->-r requirements.txt (line 7))\n",
110
+ " Using cached audioread-3.0.1-py3-none-any.whl.metadata (8.4 kB)\n",
111
+ "Collecting numba>=0.51.0 (from librosa->-r requirements.txt (line 7))\n",
112
+ " Downloading numba-0.61.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.8 kB)\n",
113
+ "Collecting scipy>=1.6.0 (from librosa->-r requirements.txt (line 7))\n",
114
+ " Downloading scipy-1.15.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)\n",
115
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.0/62.0 kB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
116
+ "\u001b[?25hCollecting scikit-learn>=1.1.0 (from librosa->-r requirements.txt (line 7))\n",
117
+ " Downloading scikit_learn-1.7.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (17 kB)\n",
118
+ "Collecting joblib>=1.0 (from librosa->-r requirements.txt (line 7))\n",
119
+ " Using cached joblib-1.5.1-py3-none-any.whl.metadata (5.6 kB)\n",
120
+ "Requirement already satisfied: decorator>=4.3.0 in /venv/main/lib/python3.12/site-packages (from librosa->-r requirements.txt (line 7)) (5.2.1)\n",
121
+ "Collecting pooch>=1.1 (from librosa->-r requirements.txt (line 7))\n",
122
+ " Using cached pooch-1.8.2-py3-none-any.whl.metadata (10 kB)\n",
123
+ "Collecting soxr>=0.3.2 (from librosa->-r requirements.txt (line 7))\n",
124
+ " Downloading soxr-0.5.0.post1-cp312-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.6 kB)\n",
125
+ "Collecting lazy_loader>=0.1 (from librosa->-r requirements.txt (line 7))\n",
126
+ " Using cached lazy_loader-0.4-py3-none-any.whl.metadata (7.6 kB)\n",
127
+ "Collecting msgpack>=1.0 (from librosa->-r requirements.txt (line 7))\n",
128
+ " Downloading msgpack-1.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.4 kB)\n",
129
+ "Collecting click (from nltk->-r requirements.txt (line 8))\n",
130
+ " Using cached click-8.2.1-py3-none-any.whl.metadata (2.5 kB)\n",
131
+ "Collecting regex>=2021.8.3 (from nltk->-r requirements.txt (line 8))\n",
132
+ " Downloading regex-2024.11.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)\n",
133
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.5/40.5 kB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
134
+ "\u001b[?25hCollecting contourpy>=1.0.1 (from matplotlib->-r requirements.txt (line 9))\n",
135
+ " Downloading contourpy-1.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.5 kB)\n",
136
+ "Collecting cycler>=0.10 (from matplotlib->-r requirements.txt (line 9))\n",
137
+ " Using cached cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)\n",
138
+ "Collecting fonttools>=4.22.0 (from matplotlib->-r requirements.txt (line 9))\n",
139
+ " Downloading fonttools-4.58.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (106 kB)\n",
140
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m106.3/106.3 kB\u001b[0m \u001b[31m2.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
141
+ "\u001b[?25hCollecting kiwisolver>=1.3.1 (from matplotlib->-r requirements.txt (line 9))\n",
142
+ " Downloading kiwisolver-1.4.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.2 kB)\n",
143
+ "Requirement already satisfied: packaging>=20.0 in /venv/main/lib/python3.12/site-packages (from matplotlib->-r requirements.txt (line 9)) (25.0)\n",
144
+ "Requirement already satisfied: pillow>=8 in /venv/main/lib/python3.12/site-packages (from matplotlib->-r requirements.txt (line 9)) (11.0.0)\n",
145
+ "Collecting pyparsing>=2.3.1 (from matplotlib->-r requirements.txt (line 9))\n",
146
+ " Using cached pyparsing-3.2.3-py3-none-any.whl.metadata (5.0 kB)\n",
147
+ "Requirement already satisfied: python-dateutil>=2.7 in /venv/main/lib/python3.12/site-packages (from matplotlib->-r requirements.txt (line 9)) (2.9.0.post0)\n",
148
+ "Requirement already satisfied: psutil in /venv/main/lib/python3.12/site-packages (from accelerate->-r requirements.txt (line 10)) (7.0.0)\n",
149
+ "Requirement already satisfied: huggingface-hub>=0.21.0 in /venv/main/lib/python3.12/site-packages (from accelerate->-r requirements.txt (line 10)) (0.30.2)\n",
150
+ "Collecting safetensors>=0.4.3 (from accelerate->-r requirements.txt (line 10))\n",
151
+ " Using cached safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)\n",
152
+ "Requirement already satisfied: requests in /venv/main/lib/python3.12/site-packages (from transformers->-r requirements.txt (line 11)) (2.32.3)\n",
153
+ "Collecting tokenizers<0.22,>=0.21 (from transformers->-r requirements.txt (line 11))\n",
154
+ " Using cached tokenizers-0.21.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)\n",
155
+ "Collecting pycparser (from cffi>=1.0->SoundFile->-r requirements.txt (line 1))\n",
156
+ " Using cached pycparser-2.22-py3-none-any.whl.metadata (943 bytes)\n",
157
+ "Collecting llvmlite<0.45,>=0.44.0dev0 (from numba>=0.51.0->librosa->-r requirements.txt (line 7))\n",
158
+ " Downloading llvmlite-0.44.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.0 kB)\n",
159
+ "Requirement already satisfied: platformdirs>=2.5.0 in /venv/main/lib/python3.12/site-packages (from pooch>=1.1->librosa->-r requirements.txt (line 7)) (4.3.7)\n",
160
+ "Requirement already satisfied: six>=1.5 in /venv/main/lib/python3.12/site-packages (from python-dateutil>=2.7->matplotlib->-r requirements.txt (line 9)) (1.17.0)\n",
161
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /venv/main/lib/python3.12/site-packages (from requests->transformers->-r requirements.txt (line 11)) (3.4.1)\n",
162
+ "Requirement already satisfied: idna<4,>=2.5 in /venv/main/lib/python3.12/site-packages (from requests->transformers->-r requirements.txt (line 11)) (3.10)\n",
163
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /venv/main/lib/python3.12/site-packages (from requests->transformers->-r requirements.txt (line 11)) (2.4.0)\n",
164
+ "Requirement already satisfied: certifi>=2017.4.17 in /venv/main/lib/python3.12/site-packages (from requests->transformers->-r requirements.txt (line 11)) (2025.4.26)\n",
165
+ "Collecting threadpoolctl>=3.1.0 (from scikit-learn>=1.1.0->librosa->-r requirements.txt (line 7))\n",
166
+ " Using cached threadpoolctl-3.6.0-py3-none-any.whl.metadata (13 kB)\n",
167
+ "Requirement already satisfied: MarkupSafe>=2.0 in /venv/main/lib/python3.12/site-packages (from jinja2->torch->-r requirements.txt (line 4)) (2.1.5)\n",
168
+ "Using cached soundfile-0.13.1-py2.py3-none-manylinux_2_28_x86_64.whl (1.3 MB)\n",
169
+ "Using cached munch-4.0.0-py2.py3-none-any.whl (9.9 kB)\n",
170
+ "Using cached pydub-0.25.1-py2.py3-none-any.whl (32 kB)\n",
171
+ "Using cached librosa-0.11.0-py3-none-any.whl (260 kB)\n",
172
+ "Using cached nltk-3.9.1-py3-none-any.whl (1.5 MB)\n",
173
+ "Downloading matplotlib-3.10.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.6 MB)\n",
174
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.6/8.6 MB\u001b[0m \u001b[31m28.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
175
+ "\u001b[?25hUsing cached accelerate-1.7.0-py3-none-any.whl (362 kB)\n",
176
+ "Using cached transformers-4.52.4-py3-none-any.whl (10.5 MB)\n",
177
+ "Using cached einops-0.8.1-py3-none-any.whl (64 kB)\n",
178
+ "Using cached einops_exts-0.0.4-py3-none-any.whl (3.9 kB)\n",
179
+ "Using cached audioread-3.0.1-py3-none-any.whl (23 kB)\n",
180
+ "Downloading cffi-1.17.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (479 kB)\n",
181
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m479.4/479.4 kB\u001b[0m \u001b[31m169.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
182
+ "\u001b[?25hDownloading contourpy-1.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (323 kB)\n",
183
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m323.7/323.7 kB\u001b[0m \u001b[31m127.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
184
+ "\u001b[?25hUsing cached cycler-0.12.1-py3-none-any.whl (8.3 kB)\n",
185
+ "Downloading fonttools-4.58.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.9 MB)\n",
186
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.9/4.9 MB\u001b[0m \u001b[31m87.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mta \u001b[36m0:00:01\u001b[0m\n",
187
+ "\u001b[?25hUsing cached joblib-1.5.1-py3-none-any.whl (307 kB)\n",
188
+ "Downloading kiwisolver-1.4.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.5 MB)\n",
189
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.5/1.5 MB\u001b[0m \u001b[31m185.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
190
+ "\u001b[?25hUsing cached lazy_loader-0.4-py3-none-any.whl (12 kB)\n",
191
+ "Downloading msgpack-1.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (401 kB)\n",
192
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m401.4/401.4 kB\u001b[0m \u001b[31m192.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
193
+ "\u001b[?25hDownloading numba-0.61.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (3.9 MB)\n",
194
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.9/3.9 MB\u001b[0m \u001b[31m42.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mta \u001b[36m0:00:01\u001b[0m\n",
195
+ "\u001b[?25hUsing cached pooch-1.8.2-py3-none-any.whl (64 kB)\n",
196
+ "Using cached pyparsing-3.2.3-py3-none-any.whl (111 kB)\n",
197
+ "Downloading regex-2024.11.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (796 kB)\n",
198
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m796.9/796.9 kB\u001b[0m \u001b[31m125.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
199
+ "\u001b[?25hUsing cached safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (471 kB)\n",
200
+ "Downloading scikit_learn-1.7.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.5 MB)\n",
201
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.5/12.5 MB\u001b[0m \u001b[31m43.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
202
+ "\u001b[?25hDownloading scipy-1.15.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (37.3 MB)\n",
203
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m37.3/37.3 MB\u001b[0m \u001b[31m26.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
204
+ "\u001b[?25hDownloading soxr-0.5.0.post1-cp312-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (248 kB)\n",
205
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m248.5/248.5 kB\u001b[0m \u001b[31m36.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
206
+ "\u001b[?25hUsing cached tokenizers-0.21.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.0 MB)\n",
207
+ "Using cached click-8.2.1-py3-none-any.whl (102 kB)\n",
208
+ "Downloading llvmlite-0.44.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (42.4 MB)\n",
209
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m42.4/42.4 MB\u001b[0m \u001b[31m14.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
210
+ "\u001b[?25hUsing cached threadpoolctl-3.6.0-py3-none-any.whl (18 kB)\n",
211
+ "Using cached pycparser-2.22-py3-none-any.whl (117 kB)\n",
212
+ "Building wheels for collected packages: typing, monotonic_align\n",
213
+ " Building wheel for typing (setup.py) ... \u001b[?25ldone\n",
214
+ "\u001b[?25h Created wheel for typing: filename=typing-3.7.4.3-py3-none-any.whl size=26304 sha256=7bd8523fe1f7cb4e20da87ee646956891addbdea2d87074f6bbf77fe282e8720\n",
215
+ " Stored in directory: /root/.cache/pip/wheels/12/98/52/2bffe242a9a487f00886e43b8ed8dac46456702e11a0d6abef\n",
216
+ " Building wheel for monotonic_align (pyproject.toml) ... \u001b[?25ldone\n",
217
+ "\u001b[?25h Created wheel for monotonic_align: filename=monotonic_align-1.2-cp312-cp312-linux_x86_64.whl size=1543517 sha256=dc9566d3e5a0656ebf939e760d934e0926d435f336db84e0019c7391576cd4cc\n",
218
+ " Stored in directory: /tmp/pip-ephem-wheel-cache-0gzg26zy/wheels/76/0a/37/00634137cd000799e060087bd1cb49a060ac6a48fc42a15488\n",
219
+ "Successfully built typing monotonic_align\n",
220
+ "Installing collected packages: pydub, typing, threadpoolctl, soxr, scipy, safetensors, regex, pyparsing, pycparser, munch, msgpack, monotonic_align, llvmlite, lazy_loader, kiwisolver, joblib, fonttools, einops, cycler, contourpy, click, audioread, scikit-learn, pooch, numba, nltk, matplotlib, einops-exts, cffi, tokenizers, SoundFile, transformers, librosa, accelerate\n",
221
+ "Successfully installed SoundFile-0.13.1 accelerate-1.7.0 audioread-3.0.1 cffi-1.17.1 click-8.2.1 contourpy-1.3.2 cycler-0.12.1 einops-0.8.1 einops-exts-0.0.4 fonttools-4.58.2 joblib-1.5.1 kiwisolver-1.4.8 lazy_loader-0.4 librosa-0.11.0 llvmlite-0.44.0 matplotlib-3.10.3 monotonic_align-1.2 msgpack-1.1.0 munch-4.0.0 nltk-3.9.1 numba-0.61.2 pooch-1.8.2 pycparser-2.22 pydub-0.25.1 pyparsing-3.2.3 regex-2024.11.6 safetensors-0.5.3 scikit-learn-1.7.0 scipy-1.15.3 soxr-0.5.0.post1 threadpoolctl-3.6.0 tokenizers-0.21.1 transformers-4.52.4 typing-3.7.4.3\n"
222
+ ]
223
+ }
224
+ ],
225
+ "source": [
226
+ "!pip install -r requirements.txt"
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "code",
231
+ "execution_count": 2,
232
+ "id": "da84c60f",
233
+ "metadata": {},
234
+ "outputs": [
235
+ {
236
+ "name": "stdout",
237
+ "output_type": "stream",
238
+ "text": [
239
+ "/workspace/styletts2\n"
240
+ ]
241
+ }
242
+ ],
243
+ "source": [
244
+ "%cd .."
245
+ ]
246
+ },
247
+ {
248
+ "cell_type": "code",
249
+ "execution_count": 3,
250
+ "id": "5a3ddcc8",
251
+ "metadata": {},
252
+ "outputs": [
253
+ {
254
+ "ename": "ModuleNotFoundError",
255
+ "evalue": "No module named 'munch'",
256
+ "output_type": "error",
257
+ "traceback": [
258
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
259
+ "\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)",
260
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 5\u001b[39m\n\u001b[32m 3\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mrandom\u001b[39;00m\n\u001b[32m 4\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01myaml\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m5\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmunch\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Munch\n\u001b[32m 6\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnp\u001b[39;00m\n\u001b[32m 7\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtorch\u001b[39;00m\n",
261
+ "\u001b[31mModuleNotFoundError\u001b[39m: No module named 'munch'"
262
+ ]
263
+ }
264
+ ],
265
+ "source": [
266
+ "# load packages\n",
267
+ "import time\n",
268
+ "import random\n",
269
+ "import yaml\n",
270
+ "from munch import Munch\n",
271
+ "import numpy as np\n",
272
+ "import torch\n",
273
+ "from torch import nn\n",
274
+ "import torch.nn.functional as F\n",
275
+ "import torchaudio\n",
276
+ "import librosa\n",
277
+ "from nltk.tokenize import word_tokenize\n",
278
+ "\n",
279
+ "from models import *\n",
280
+ "from utils import *\n",
281
+ "from text_utils import TextCleaner\n",
282
+ "textclenaer = TextCleaner()\n",
283
+ "\n",
284
+ "%matplotlib inline"
285
+ ]
286
+ },
287
+ {
288
+ "cell_type": "code",
289
+ "execution_count": null,
290
+ "id": "00ee05e1",
291
+ "metadata": {},
292
+ "outputs": [],
293
+ "source": [
294
+ "to_mel = torchaudio.transforms.MelSpectrogram(\n",
295
+ " n_mels=80, n_fft=2048, win_length=1200, hop_length=300)\n",
296
+ "mean, std = -4, 4\n",
297
+ "\n",
298
+ "def length_to_mask(lengths):\n",
299
+ " mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)\n",
300
+ " mask = torch.gt(mask+1, lengths.unsqueeze(1))\n",
301
+ " return mask\n",
302
+ "\n",
303
+ "def preprocess(wave):\n",
304
+ " wave_tensor = torch.from_numpy(wave).float()\n",
305
+ " mel_tensor = to_mel(wave_tensor)\n",
306
+ " mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std\n",
307
+ " return mel_tensor\n",
308
+ "\n",
309
+ "def compute_style(path):\n",
310
+ " wave, sr = librosa.load(path, sr=24000)\n",
311
+ " audio, index = librosa.effects.trim(wave, top_db=30)\n",
312
+ " if sr != 24000:\n",
313
+ " audio = librosa.resample(audio, sr, 24000)\n",
314
+ " mel_tensor = preprocess(audio).to(device)\n",
315
+ "\n",
316
+ " with torch.no_grad():\n",
317
+ " ref_s = model.style_encoder(mel_tensor.unsqueeze(1))\n",
318
+ " ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))\n",
319
+ "\n",
320
+ " return torch.cat([ref_s, ref_p], dim=1)"
321
+ ]
322
+ },
323
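As a usage note: once the models below are loaded, `compute_style` returns a 1×256 style tensor for an utterance; its first 128 dimensions come from `model.style_encoder` (acoustic style / timbre) and its last 128 from `model.predictor_encoder` (prosodic style), which is why `inference` later splits it into `ref_s[:, :128]` and `ref_s[:, 128:]`.

```python
# Example (the path is one of the reference clips used later in this notebook).
ref_s = compute_style("Demo/reference_audio/1221-135767-0014.wav")
print(ref_s.shape)   # expected: torch.Size([1, 256])
```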
+ {
324
+ "cell_type": "code",
325
+ "execution_count": null,
326
+ "id": "bbdc04c0",
327
+ "metadata": {},
328
+ "outputs": [],
329
+ "source": [
330
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "markdown",
335
+ "id": "7b9cecbe",
336
+ "metadata": {},
337
+ "source": [
338
+ "### Load models"
339
+ ]
340
+ },
341
+ {
342
+ "cell_type": "code",
343
+ "execution_count": null,
344
+ "id": "64fc4c0f",
345
+ "metadata": {},
346
+ "outputs": [],
347
+ "source": [
348
+ "# load phonemizer\n",
349
+ "import phonemizer\n",
350
+ "global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)"
351
+ ]
352
+ },
353
+ {
354
+ "cell_type": "code",
355
+ "execution_count": null,
356
+ "id": "48e7b644",
357
+ "metadata": {},
358
+ "outputs": [],
359
+ "source": [
360
+ "config = yaml.safe_load(open(\"Models/LibriTTS/config.yml\"))\n",
361
+ "\n",
362
+ "# load pretrained ASR model\n",
363
+ "ASR_config = config.get('ASR_config', False)\n",
364
+ "ASR_path = config.get('ASR_path', False)\n",
365
+ "text_aligner = load_ASR_models(ASR_path, ASR_config)\n",
366
+ "\n",
367
+ "# load pretrained F0 model\n",
368
+ "F0_path = config.get('F0_path', False)\n",
369
+ "pitch_extractor = load_F0_models(F0_path)\n",
370
+ "\n",
371
+ "# load BERT model\n",
372
+ "from Utils.PLBERT.util import load_plbert\n",
373
+ "BERT_path = config.get('PLBERT_dir', False)\n",
374
+ "plbert = load_plbert(BERT_path)"
375
+ ]
376
+ },
377
+ {
378
+ "cell_type": "code",
379
+ "execution_count": null,
380
+ "id": "ffc18cf7",
381
+ "metadata": {},
382
+ "outputs": [],
383
+ "source": [
384
+ "model_params = recursive_munch(config['model_params'])\n",
385
+ "model = build_model(model_params, text_aligner, pitch_extractor, plbert)\n",
386
+ "_ = [model[key].eval() for key in model]\n",
387
+ "_ = [model[key].to(device) for key in model]"
388
+ ]
389
+ },
390
+ {
391
+ "cell_type": "code",
392
+ "execution_count": null,
393
+ "id": "64529d5c",
394
+ "metadata": {},
395
+ "outputs": [],
396
+ "source": [
397
+ "params_whole = torch.load(\"Models/LibriTTS/epochs_2nd_00020.pth\", map_location='cpu')\n",
398
+ "params = params_whole['net']"
399
+ ]
400
+ },
401
+ {
402
+ "cell_type": "code",
403
+ "execution_count": null,
404
+ "id": "895d9706",
405
+ "metadata": {},
406
+ "outputs": [],
407
+ "source": [
408
+ "for key in model:\n",
409
+ " if key in params:\n",
410
+ " print('%s loaded' % key)\n",
411
+ " try:\n",
412
+ " model[key].load_state_dict(params[key])\n",
413
+ " except:\n",
414
+ " from collections import OrderedDict\n",
415
+ " state_dict = params[key]\n",
416
+ " new_state_dict = OrderedDict()\n",
417
+ " for k, v in state_dict.items():\n",
418
+ " name = k[7:] # remove `module.`\n",
419
+ " new_state_dict[name] = v\n",
420
+ " # load params\n",
421
+ " model[key].load_state_dict(new_state_dict, strict=False)\n",
422
+ "# except:\n",
423
+ "# _load(params[key], model[key])\n",
424
+ "_ = [model[key].eval() for key in model]"
425
+ ]
426
+ },
427
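A brief note on the fallback in the loading loop above: checkpoints saved from a `DataParallel`/DDP-wrapped model store parameter keys with a leading `module.` prefix, so the `except` branch strips those seven characters from each key and retries `load_state_dict` with `strict=False` to tolerate any remaining mismatches.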
+ {
428
+ "cell_type": "code",
429
+ "execution_count": null,
430
+ "id": "c1a59db2",
431
+ "metadata": {},
432
+ "outputs": [],
433
+ "source": [
434
+ "from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule"
435
+ ]
436
+ },
437
+ {
438
+ "cell_type": "code",
439
+ "execution_count": null,
440
+ "id": "e30985ab",
441
+ "metadata": {},
442
+ "outputs": [],
443
+ "source": [
444
+ "sampler = DiffusionSampler(\n",
445
+ " model.diffusion.diffusion,\n",
446
+ " sampler=ADPM2Sampler(),\n",
447
+ " sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters\n",
448
+ " clamp=False\n",
449
+ ")"
450
+ ]
451
+ },
452
+ {
453
+ "cell_type": "markdown",
454
+ "id": "b803110e",
455
+ "metadata": {},
456
+ "source": [
457
+ "### Synthesize speech"
458
+ ]
459
+ },
460
+ {
461
+ "cell_type": "code",
462
+ "execution_count": null,
463
+ "id": "ca57469c",
464
+ "metadata": {},
465
+ "outputs": [],
466
+ "source": [
467
+ "def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):\n",
468
+ " text = text.strip()\n",
469
+ " ps = global_phonemizer.phonemize([text])\n",
470
+ " ps = word_tokenize(ps[0])\n",
471
+ " ps = ' '.join(ps)\n",
472
+ " tokens = textclenaer(ps)\n",
473
+ " tokens.insert(0, 0)\n",
474
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
475
+ " \n",
476
+ " with torch.no_grad():\n",
477
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
478
+ " text_mask = length_to_mask(input_lengths).to(device)\n",
479
+ "\n",
480
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
481
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
482
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2) \n",
483
+ "\n",
484
+ " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device), \n",
485
+ " embedding=bert_dur,\n",
486
+ " embedding_scale=embedding_scale,\n",
487
+ " features=ref_s, # reference from the same speaker as the embedding\n",
488
+ " num_steps=diffusion_steps).squeeze(1)\n",
489
+ "\n",
490
+ "\n",
491
+ " s = s_pred[:, 128:]\n",
492
+ " ref = s_pred[:, :128]\n",
493
+ "\n",
494
+ " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
495
+ " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
496
+ "\n",
497
+ " d = model.predictor.text_encoder(d_en, \n",
498
+ " s, input_lengths, text_mask)\n",
499
+ "\n",
500
+ " x, _ = model.predictor.lstm(d)\n",
501
+ " duration = model.predictor.duration_proj(x)\n",
502
+ "\n",
503
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
504
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
505
+ "\n",
506
+ "\n",
507
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
508
+ " c_frame = 0\n",
509
+ " for i in range(pred_aln_trg.size(0)):\n",
510
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
511
+ " c_frame += int(pred_dur[i].data)\n",
512
+ "\n",
513
+ " # encode prosody\n",
514
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
515
+ " if model_params.decoder.type == \"hifigan\":\n",
516
+ " asr_new = torch.zeros_like(en)\n",
517
+ " asr_new[:, :, 0] = en[:, :, 0]\n",
518
+ " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
519
+ " en = asr_new\n",
520
+ "\n",
521
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
522
+ "\n",
523
+ " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
524
+ " if model_params.decoder.type == \"hifigan\":\n",
525
+ " asr_new = torch.zeros_like(asr)\n",
526
+ " asr_new[:, :, 0] = asr[:, :, 0]\n",
527
+ " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
528
+ " asr = asr_new\n",
529
+ "\n",
530
+ " out = model.decoder(asr, \n",
531
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
532
+ " \n",
533
+ " \n",
534
+ " return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later"
535
+ ]
536
+ },
537
+ {
538
+ "cell_type": "markdown",
539
+ "id": "d438ef4f",
540
+ "metadata": {},
541
+ "source": [
542
+ "#### Basic synthesis (5 diffusion steps, seen speakers)"
543
+ ]
544
+ },
545
+ {
546
+ "cell_type": "code",
547
+ "execution_count": null,
548
+ "id": "cace9787",
549
+ "metadata": {},
550
+ "outputs": [],
551
+ "source": [
552
+ "text = ''' StyleTTS 2 is a text to speech model that leverages style diffusion and adversarial training with large speech language models to achieve human level text to speech synthesis. '''"
553
+ ]
554
+ },
555
+ {
556
+ "cell_type": "code",
557
+ "execution_count": null,
558
+ "id": "7c88f461",
559
+ "metadata": {},
560
+ "outputs": [],
561
+ "source": [
562
+ "reference_dicts = {}\n",
563
+ "reference_dicts['696_92939'] = \"Demo/reference_audio/696_92939_000016_000006.wav\"\n",
564
+ "reference_dicts['1789_142896'] = \"Demo/reference_audio/1789_142896_000022_000005.wav\""
565
+ ]
566
+ },
567
+ {
568
+ "cell_type": "code",
569
+ "execution_count": null,
570
+ "id": "16e8ac60",
571
+ "metadata": {},
572
+ "outputs": [],
573
+ "source": [
574
+ "start = time.time()\n",
575
+ "noise = torch.randn(1,1,256).to(device)\n",
576
+ "for k, path in reference_dicts.items():\n",
577
+ " ref_s = compute_style(path)\n",
578
+ " \n",
579
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=5, embedding_scale=1)\n",
580
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
581
+ " print(f\"RTF = {rtf:5f}\")\n",
582
+ " import IPython.display as ipd\n",
583
+ " print(k + ' Synthesized:')\n",
584
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
585
+ " print('Reference:')\n",
586
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
587
+ ]
588
+ },
589
+ {
590
+ "cell_type": "markdown",
591
+ "id": "14838708",
592
+ "metadata": {},
593
+ "source": [
594
+ "#### With higher diffusion steps (more diverse)\n",
595
+ "\n",
596
+ "Since the sampler is ancestral, the higher the number of steps, the more diverse the samples are, at the cost of slower synthesis speed."
597
+ ]
598
+ },
599
+ {
600
+ "cell_type": "code",
601
+ "execution_count": null,
602
+ "id": "6fbff03b",
603
+ "metadata": {},
604
+ "outputs": [],
605
+ "source": [
606
+ "noise = torch.randn(1,1,256).to(device)\n",
607
+ "for k, path in reference_dicts.items():\n",
608
+ " ref_s = compute_style(path)\n",
609
+ " start = time.time()\n",
610
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=10, embedding_scale=1)\n",
611
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
612
+ " print(f\"RTF = {rtf:5f}\")\n",
613
+ " import IPython.display as ipd\n",
614
+ " print(k + ' Synthesized:')\n",
615
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
616
+ " print(k + ' Reference:')\n",
617
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
618
+ ]
619
+ },
620
+ {
621
+ "cell_type": "markdown",
622
+ "id": "7e6867fd",
623
+ "metadata": {},
624
+ "source": [
625
+ "#### Basic synthesis (5 diffusion steps, unseen speakers)\n",
626
+ "The following samples reproduce those in [Section 4](https://styletts2.github.io/#libri) of the demo page. All speakers are unseen during training. You can compare the generated samples to popular zero-shot TTS models like Vall-E and NaturalSpeech 2."
627
+ ]
628
+ },
629
+ {
630
+ "cell_type": "code",
631
+ "execution_count": null,
632
+ "id": "f4e8faa0",
633
+ "metadata": {},
634
+ "outputs": [],
635
+ "source": [
636
+ "reference_dicts = {}\n",
637
+ "# format: (path, text)\n",
638
+ "reference_dicts['1221-135767'] = (\"Demo/reference_audio/1221-135767-0014.wav\", \"Yea, his honourable worship is within, but he hath a godly minister or two with him, and likewise a leech.\")\n",
639
+ "reference_dicts['5639-40744'] = (\"Demo/reference_audio/5639-40744-0020.wav\", \"Thus did this humane and right minded father comfort his unhappy daughter, and her mother embracing her again, did all she could to soothe her feelings.\")\n",
640
+ "reference_dicts['908-157963'] = (\"Demo/reference_audio/908-157963-0027.wav\", \"And lay me down in my cold bed and leave my shining lot.\")\n",
641
+ "reference_dicts['4077-13754'] = (\"Demo/reference_audio/4077-13754-0000.wav\", \"The army found the people in poverty and left them in comparative wealth.\")"
642
+ ]
643
+ },
644
+ {
645
+ "cell_type": "code",
646
+ "execution_count": null,
647
+ "id": "653f1406",
648
+ "metadata": {},
649
+ "outputs": [],
650
+ "source": [
651
+ "noise = torch.randn(1,1,256).to(device)\n",
652
+ "for k, v in reference_dicts.items():\n",
653
+ " path, text = v\n",
654
+ " ref_s = compute_style(path)\n",
655
+ " start = time.time()\n",
656
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=5, embedding_scale=1)\n",
657
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
658
+ " print(f\"RTF = {rtf:5f}\")\n",
659
+ " import IPython.display as ipd\n",
660
+ " print(k + ' Synthesized: ' + text)\n",
661
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
662
+ " print(k + ' Reference:')\n",
663
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
664
+ ]
665
+ },
666
+ {
667
+ "cell_type": "markdown",
668
+ "id": "141e91b3",
669
+ "metadata": {},
670
+ "source": [
671
+ "### Speech expressiveness\n",
672
+ "\n",
673
+ "The following section recreates the samples shown in [Section 6](https://styletts2.github.io/#emo) of the demo page. The speaker reference used is `1221-135767-0014.wav`, which is unseen during training. \n",
674
+ "\n",
675
+ "#### With `embedding_scale=1`\n",
676
+ "This is the classifier-free guidance scale. The higher the scale, the more the style is conditioned on the input text, and hence the more emotional the speech.\n",
677
+ "\n"
678
+ ]
679
+ },
680
+ {
681
+ "cell_type": "code",
682
+ "execution_count": null,
683
+ "id": "81addda4",
684
+ "metadata": {},
685
+ "outputs": [],
686
+ "source": [
687
+ "ref_s = compute_style(\"Demo/reference_audio/1221-135767-0014.wav\")"
688
+ ]
689
+ },
690
+ {
691
+ "cell_type": "code",
692
+ "execution_count": null,
693
+ "id": "be1b2a11",
694
+ "metadata": {},
695
+ "outputs": [],
696
+ "source": [
697
+ "texts = {}\n",
698
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
699
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
700
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
701
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
702
+ "\n",
703
+ "for k,v in texts.items():\n",
704
+ " wav = inference(v, ref_s, diffusion_steps=10, alpha=0.3, beta=0.7, embedding_scale=1)\n",
705
+ " print(k + \": \")\n",
706
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
707
+ ]
708
+ },
709
+ {
710
+ "cell_type": "markdown",
711
+ "id": "96d262b8",
712
+ "metadata": {},
713
+ "source": [
714
+ "#### With `embedding_scale=2`"
715
+ ]
716
+ },
717
+ {
718
+ "cell_type": "code",
719
+ "execution_count": null,
720
+ "id": "3e7d40b4",
721
+ "metadata": {},
722
+ "outputs": [],
723
+ "source": [
724
+ "texts = {}\n",
725
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
726
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
727
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
728
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
729
+ "\n",
730
+ "for k,v in texts.items():\n",
731
+ " noise = torch.randn(1,1,256).to(device)\n",
732
+ " wav = inference(v, ref_s, diffusion_steps=10, alpha=0.3, beta=0.7, embedding_scale=2)\n",
733
+ " print(k + \": \")\n",
734
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
735
+ ]
736
+ },
737
+ {
738
+ "cell_type": "markdown",
739
+ "id": "402b2bd6",
740
+ "metadata": {},
741
+ "source": [
742
+ "#### With `embedding_scale=2, alpha = 0.5, beta = 0.9`\n",
743
+ "`alpha` and `beta` determine how much of the style sampled from the text is used instead of the reference style. The higher `alpha` and `beta` are, the better the style suits the text, but the less similar it is to the reference. Using a higher `beta` makes the synthesized speech more emotional, at the cost of lower similarity to the reference. `alpha` controls the timbre of the speaker, while `beta` controls the prosody."
744
+ ]
745
+ },
746
+ {
747
+ "cell_type": "code",
748
+ "execution_count": null,
749
+ "id": "599de5d5",
750
+ "metadata": {},
751
+ "outputs": [],
752
+ "source": [
753
+ "texts = {}\n",
754
+ "texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
755
+ "texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
756
+ "texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
757
+ "texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\"\n",
758
+ "\n",
759
+ "for k,v in texts.items():\n",
760
+ " noise = torch.randn(1,1,256).to(device)\n",
761
+ " wav = inference(v, ref_s, diffusion_steps=10, alpha=0.5, beta=0.9, embedding_scale=2)\n",
762
+ " print(k + \": \")\n",
763
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
764
+ ]
765
+ },
766
+ {
767
+ "cell_type": "markdown",
768
+ "id": "48548866",
769
+ "metadata": {},
770
+ "source": [
771
+ "### Zero-shot speaker adaptation\n",
772
+ "This section recreates the \"Acoustic Environment Maintenance\" and \"Speaker’s Emotion Maintenance\" demos in [Section 4](https://styletts2.github.io/#libri) of the demo page. You can compare the generated samples to popular zero-shot TTS models like Vall-E. Note that the model was trained only on LibriTTS, roughly 250 times less data than was used to train Vall-E, yet it maintains these properties similarly well or better."
773
+ ]
774
+ },
775
+ {
776
+ "cell_type": "markdown",
777
+ "id": "23e81572",
778
+ "metadata": {},
779
+ "source": [
780
+ "#### Acoustic Environment Maintenance\n",
781
+ "\n",
782
+ "Since we want to maintain the acoustic environment in the speaker (timbre), we set `alpha = 0` to keep the speaker as close to the reference as possible while only changing the prosody according to the text."
783
+ ]
784
+ },
785
+ {
786
+ "cell_type": "code",
787
+ "execution_count": null,
788
+ "id": "8087bccb",
789
+ "metadata": {},
790
+ "outputs": [],
791
+ "source": [
792
+ "reference_dicts = {}\n",
793
+ "# format: (path, text)\n",
794
+ "reference_dicts['3'] = (\"Demo/reference_audio/3.wav\", \"As friends thing I definitely I've got more male friends.\")\n",
795
+ "reference_dicts['4'] = (\"Demo/reference_audio/4.wav\", \"Everything is run by computer but you got to know how to think before you can do a computer.\")\n",
796
+ "reference_dicts['5'] = (\"Demo/reference_audio/5.wav\", \"Then out in LA you guys got a whole another ball game within California to worry about.\")"
797
+ ]
798
+ },
799
+ {
800
+ "cell_type": "code",
801
+ "execution_count": null,
802
+ "id": "1e99c200",
803
+ "metadata": {},
804
+ "outputs": [],
805
+ "source": [
806
+ "noise = torch.randn(1,1,256).to(device)\n",
807
+ "for k, v in reference_dicts.items():\n",
808
+ " path, text = v\n",
809
+ " ref_s = compute_style(path)\n",
810
+ " start = time.time()\n",
811
+ " wav = inference(text, ref_s, alpha=0.0, beta=0.5, diffusion_steps=5, embedding_scale=1)\n",
812
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
813
+ " print(f\"RTF = {rtf:5f}\")\n",
814
+ " import IPython.display as ipd\n",
815
+ " print('Synthesized: ' + text)\n",
816
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
817
+ " print('Reference:')\n",
818
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
819
+ ]
820
+ },
821
+ {
822
+ "cell_type": "markdown",
823
+ "id": "7d56505d",
824
+ "metadata": {},
825
+ "source": [
826
+ "#### Speaker’s Emotion Maintenance\n",
827
+ "\n",
828
+ "Since we want to maintain the emotion in the speaker (prosody), we set `beta = 0.1` to keep the speaker as close to the reference as possible while allowing some diversity through a slight timbre change."
829
+ ]
830
+ },
831
+ {
832
+ "cell_type": "code",
833
+ "execution_count": null,
834
+ "id": "f90179e7",
835
+ "metadata": {},
836
+ "outputs": [],
837
+ "source": [
838
+ "reference_dicts = {}\n",
839
+ "# format: (path, text)\n",
840
+ "reference_dicts['Anger'] = (\"Demo/reference_audio/anger.wav\", \"We have to reduce the number of plastic bags.\")\n",
841
+ "reference_dicts['Sleepy'] = (\"Demo/reference_audio/sleepy.wav\", \"We have to reduce the number of plastic bags.\")\n",
842
+ "reference_dicts['Amused'] = (\"Demo/reference_audio/amused.wav\", \"We have to reduce the number of plastic bags.\")\n",
843
+ "reference_dicts['Disgusted'] = (\"Demo/reference_audio/disgusted.wav\", \"We have to reduce the number of plastic bags.\")"
844
+ ]
845
+ },
846
+ {
847
+ "cell_type": "code",
848
+ "execution_count": null,
849
+ "id": "2e6bdfed",
850
+ "metadata": {},
851
+ "outputs": [],
852
+ "source": [
853
+ "noise = torch.randn(1,1,256).to(device)\n",
854
+ "for k, v in reference_dicts.items():\n",
855
+ " path, text = v\n",
856
+ " ref_s = compute_style(path)\n",
857
+ " start = time.time()\n",
858
+ " wav = inference(text, ref_s, alpha=0.3, beta=0.1, diffusion_steps=10, embedding_scale=1)\n",
859
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
860
+ " print(f\"RTF = {rtf:5f}\")\n",
861
+ " import IPython.display as ipd\n",
862
+ " print(k + ' Synthesized: ' + text)\n",
863
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
864
+ " print(k + ' Reference:')\n",
865
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
866
+ ]
867
+ },
868
+ {
869
+ "cell_type": "markdown",
870
+ "id": "37ae3963",
871
+ "metadata": {},
872
+ "source": [
873
+ "### Longform Narration\n",
874
+ "\n",
875
+ "This section includes a basic implementation of Algorithm 1 in the paper for consistent longform audio generation. The example passage is taken from [Section 5](https://styletts2.github.io/#long) of the demo page."
876
+ ]
877
+ },
878
+ {
879
+ "cell_type": "code",
880
+ "execution_count": null,
881
+ "id": "f12a716b",
882
+ "metadata": {},
883
+ "outputs": [],
884
+ "source": [
885
+ "passage = '''If the supply of fruit is greater than the family needs, it may be made a source of income by sending the fresh fruit to the market if there is one near enough, or by preserving, canning, and making jelly for sale. To make such an enterprise a success the fruit and work must be first class. There is magic in the word \"Homemade,\" when the product appeals to the eye and the palate; but many careless and incompetent people have found to their sorrow that this word has not magic enough to float inferior goods on the market. As a rule large canning and preserving establishments are clean and have the best appliances, and they employ chemists and skilled labor. The home product must be very good to compete with the attractive goods that are sent out from such establishments. Yet for first class home made products there is a market in all large cities. All first-class grocers have customers who purchase such goods.'''"
886
+ ]
887
+ },
888
+ {
889
+ "cell_type": "code",
890
+ "execution_count": null,
891
+ "id": "a1a38079",
892
+ "metadata": {},
893
+ "outputs": [],
894
+ "source": [
895
+ "def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1):\n",
896
+ " text = text.strip()\n",
897
+ " ps = global_phonemizer.phonemize([text])\n",
898
+ " ps = word_tokenize(ps[0])\n",
899
+ " ps = ' '.join(ps)\n",
900
+ " ps = ps.replace('``', '\"')\n",
901
+ " ps = ps.replace(\"''\", '\"')\n",
902
+ "\n",
903
+ " tokens = textclenaer(ps)\n",
904
+ " tokens.insert(0, 0)\n",
905
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
906
+ " \n",
907
+ " with torch.no_grad():\n",
908
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
909
+ " text_mask = length_to_mask(input_lengths).to(device)\n",
910
+ "\n",
911
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
912
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
913
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2) \n",
914
+ "\n",
915
+ " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device), \n",
916
+ " embedding=bert_dur,\n",
917
+ " embedding_scale=embedding_scale,\n",
918
+ " features=ref_s, # reference from the same speaker as the embedding\n",
919
+ " num_steps=diffusion_steps).squeeze(1)\n",
920
+ " \n",
921
+ " if s_prev is not None:\n",
922
+ " # convex combination of previous and current style\n",
923
+ " s_pred = t * s_prev + (1 - t) * s_pred\n",
924
+ " \n",
925
+ " s = s_pred[:, 128:]\n",
926
+ " ref = s_pred[:, :128]\n",
927
+ " \n",
928
+ " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
929
+ " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
930
+ "\n",
931
+ " s_pred = torch.cat([ref, s], dim=-1)\n",
932
+ "\n",
933
+ " d = model.predictor.text_encoder(d_en, \n",
934
+ " s, input_lengths, text_mask)\n",
935
+ "\n",
936
+ " x, _ = model.predictor.lstm(d)\n",
937
+ " duration = model.predictor.duration_proj(x)\n",
938
+ "\n",
939
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
940
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
941
+ "\n",
942
+ "\n",
943
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
944
+ " c_frame = 0\n",
945
+ " for i in range(pred_aln_trg.size(0)):\n",
946
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
947
+ " c_frame += int(pred_dur[i].data)\n",
948
+ "\n",
949
+ " # encode prosody\n",
950
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
951
+ " if model_params.decoder.type == \"hifigan\":\n",
952
+ " asr_new = torch.zeros_like(en)\n",
953
+ " asr_new[:, :, 0] = en[:, :, 0]\n",
954
+ " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
955
+ " en = asr_new\n",
956
+ "\n",
957
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
958
+ "\n",
959
+ " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
960
+ " if model_params.decoder.type == \"hifigan\":\n",
961
+ " asr_new = torch.zeros_like(asr)\n",
962
+ " asr_new[:, :, 0] = asr[:, :, 0]\n",
963
+ " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
964
+ " asr = asr_new\n",
965
+ "\n",
966
+ " out = model.decoder(asr, \n",
967
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
968
+ " \n",
969
+ " \n",
970
+ " return out.squeeze().cpu().numpy()[..., :-100], s_pred # weird pulse at the end of the model, need to be fixed later"
971
+ ]
972
+ },
973
+ {
974
+ "cell_type": "code",
975
+ "execution_count": null,
976
+ "id": "e9088f7a",
977
+ "metadata": {},
978
+ "outputs": [],
979
+ "source": [
980
+ "# unseen speaker\n",
981
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
982
+ "s_ref = compute_style(path)\n",
983
+ "sentences = passage.split('.') # simple split by period\n",
984
+ "wavs = []\n",
985
+ "s_prev = None\n",
986
+ "for text in sentences:\n",
987
+ " if text.strip() == \"\": continue\n",
988
+ " text += '.' # add it back\n",
989
+ " \n",
990
+ " wav, s_prev = LFinference(text, \n",
991
+ " s_prev, \n",
992
+ " s_ref, \n",
993
+ " alpha = 0.3, \n",
994
+ " beta = 0.9, # make it more suitable for the text\n",
995
+ " t = 0.7, \n",
996
+ " diffusion_steps=10, embedding_scale=1.5)\n",
997
+ " wavs.append(wav)\n",
998
+ "print('Synthesized: ')\n",
999
+ "display(ipd.Audio(np.concatenate(wavs), rate=24000, normalize=False))\n",
1000
+ "print('Reference: ')\n",
1001
+ "display(ipd.Audio(path, rate=24000, normalize=False))"
1002
+ ]
1003
+ },
1004
+ {
1005
+ "cell_type": "markdown",
1006
+ "id": "7517b657",
1007
+ "metadata": {},
1008
+ "source": [
1009
+ "### Style Transfer\n",
1010
+ "\n",
1011
+ "The following section demonstrates the style transfer capability for unseen speakers in [Section 6](https://styletts2.github.io/#emo) of the demo page. For this, we set `alpha=0.5, beta = 0.9` for the most pronounced effects (mostly using the sampled style)."
1012
+ ]
1013
+ },
1014
+ {
1015
+ "cell_type": "code",
1016
+ "execution_count": null,
1017
+ "id": "ed95d0f7",
1018
+ "metadata": {},
1019
+ "outputs": [],
1020
+ "source": [
1021
+ "def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):\n",
1022
+ " text = text.strip()\n",
1023
+ " ps = global_phonemizer.phonemize([text])\n",
1024
+ " ps = word_tokenize(ps[0])\n",
1025
+ " ps = ' '.join(ps)\n",
1026
+ "\n",
1027
+ " tokens = textclenaer(ps)\n",
1028
+ " tokens.insert(0, 0)\n",
1029
+ " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
1030
+ " \n",
1031
+ " ref_text = ref_text.strip()\n",
1032
+ " ps = global_phonemizer.phonemize([ref_text])\n",
1033
+ " ps = word_tokenize(ps[0])\n",
1034
+ " ps = ' '.join(ps)\n",
1035
+ "\n",
1036
+ " ref_tokens = textclenaer(ps)\n",
1037
+ " ref_tokens.insert(0, 0)\n",
1038
+ " ref_tokens = torch.LongTensor(ref_tokens).to(device).unsqueeze(0)\n",
1039
+ " \n",
1040
+ " \n",
1041
+ " with torch.no_grad():\n",
1042
+ " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
1043
+ " text_mask = length_to_mask(input_lengths).to(device)\n",
1044
+ "\n",
1045
+ " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
1046
+ " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
1047
+ " d_en = model.bert_encoder(bert_dur).transpose(-1, -2) \n",
1048
+ " \n",
1049
+ " ref_input_lengths = torch.LongTensor([ref_tokens.shape[-1]]).to(device)\n",
1050
+ " ref_text_mask = length_to_mask(ref_input_lengths).to(device)\n",
1051
+ " ref_bert_dur = model.bert(ref_tokens, attention_mask=(~ref_text_mask).int())\n",
1052
+ " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device), \n",
1053
+ " embedding=bert_dur,\n",
1054
+ " embedding_scale=embedding_scale,\n",
1055
+ " features=ref_s, # reference from the same speaker as the embedding\n",
1056
+ " num_steps=diffusion_steps).squeeze(1)\n",
1057
+ "\n",
1058
+ "\n",
1059
+ " s = s_pred[:, 128:]\n",
1060
+ " ref = s_pred[:, :128]\n",
1061
+ "\n",
1062
+ " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
1063
+ " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
1064
+ "\n",
1065
+ " d = model.predictor.text_encoder(d_en, \n",
1066
+ " s, input_lengths, text_mask)\n",
1067
+ "\n",
1068
+ " x, _ = model.predictor.lstm(d)\n",
1069
+ " duration = model.predictor.duration_proj(x)\n",
1070
+ "\n",
1071
+ " duration = torch.sigmoid(duration).sum(axis=-1)\n",
1072
+ " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
1073
+ "\n",
1074
+ "\n",
1075
+ " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
1076
+ " c_frame = 0\n",
1077
+ " for i in range(pred_aln_trg.size(0)):\n",
1078
+ " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
1079
+ " c_frame += int(pred_dur[i].data)\n",
1080
+ "\n",
1081
+ " # encode prosody\n",
1082
+ " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
1083
+ " if model_params.decoder.type == \"hifigan\":\n",
1084
+ " asr_new = torch.zeros_like(en)\n",
1085
+ " asr_new[:, :, 0] = en[:, :, 0]\n",
1086
+ " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
1087
+ " en = asr_new\n",
1088
+ "\n",
1089
+ " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
1090
+ "\n",
1091
+ " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
1092
+ " if model_params.decoder.type == \"hifigan\":\n",
1093
+ " asr_new = torch.zeros_like(asr)\n",
1094
+ " asr_new[:, :, 0] = asr[:, :, 0]\n",
1095
+ " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
1096
+ " asr = asr_new\n",
1097
+ "\n",
1098
+ " out = model.decoder(asr, \n",
1099
+ " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
1100
+ " \n",
1101
+ " \n",
1102
+ " return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later"
1103
+ ]
1104
+ },
1105
+ {
1106
+ "cell_type": "code",
1107
+ "execution_count": null,
1108
+ "id": "ec3f0da4",
1109
+ "metadata": {},
1110
+ "outputs": [],
1111
+ "source": [
1112
+ "# reference texts to sample styles\n",
1113
+ "\n",
1114
+ "ref_texts = {}\n",
1115
+ "ref_texts['Happy'] = \"We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands.\"\n",
1116
+ "ref_texts['Sad'] = \"I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence.\"\n",
1117
+ "ref_texts['Angry'] = \"The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!\"\n",
1118
+ "ref_texts['Surprised'] = \"I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?\""
1119
+ ]
1120
+ },
1121
+ {
1122
+ "cell_type": "code",
1123
+ "execution_count": null,
1124
+ "id": "6d0a3825",
1125
+ "metadata": {},
1126
+ "outputs": [],
1127
+ "source": [
1128
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
1129
+ "s_ref = compute_style(path)\n",
1130
+ "\n",
1131
+ "text = \"Yea, his honourable worship is within, but he hath a godly minister or two with him, and likewise a leech.\"\n",
1132
+ "for k,v in ref_texts.items():\n",
1133
+ " wav = STinference(text, s_ref, v, diffusion_steps=10, alpha=0.5, beta=0.9, embedding_scale=1.5)\n",
1134
+ " print(k + \": \")\n",
1135
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
1136
+ ]
1137
+ },
1138
+ {
1139
+ "cell_type": "markdown",
1140
+ "id": "6750aed9",
1141
+ "metadata": {},
1142
+ "source": [
1143
+ "### Speech diversity\n",
1144
+ "\n",
1145
+ "This section reproduces samples in [Section 7](https://styletts2.github.io/#var) of the demo page. \n",
1146
+ "\n",
1147
+ "`alpha` and `beta` determine the diversity of the synthesized speech. There are two extreme cases:\n",
1148
+ "- If `alpha = 1` and `beta = 1`, the synthesized speech sounds the most dissimilar to the reference speaker, but it is also the most diverse (each time you synthesize speech it will be totally different). \n",
1149
+ "- If `alpha = 0` and `beta = 0`, the synthesized speech sounds the most similar to the reference speaker, but it is deterministic (i.e., the sampled style is not used for speech synthesis). \n"
1150
+ ]
1151
+ },
1152
+ {
1153
+ "cell_type": "markdown",
1154
+ "id": "f6ae0aa5",
1155
+ "metadata": {},
1156
+ "source": [
1157
+ "#### Default setting (`alpha = 0.3, beta=0.7`)\n",
1158
+ "This setting uses 70% of the reference timbre and 30% of the reference prosody, and uses the diffusion model to sample the rest based on the text."
1159
+ ]
1160
+ },
1161
+ {
1162
+ "cell_type": "code",
1163
+ "execution_count": null,
1164
+ "id": "36dc0148",
1165
+ "metadata": {},
1166
+ "outputs": [],
1167
+ "source": [
1168
+ "# unseen speaker\n",
1169
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
1170
+ "ref_s = compute_style(path)\n",
1171
+ "\n",
1172
+ "text = \"How much variation is there?\"\n",
1173
+ "for _ in range(5):\n",
1174
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0.3, beta=0.7, embedding_scale=1)\n",
1175
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
1176
+ ]
1177
+ },
1178
+ {
1179
+ "cell_type": "markdown",
1180
+ "id": "bf9ef421",
1181
+ "metadata": {},
1182
+ "source": [
1183
+ "#### Less diverse setting (`alpha = 0.1, beta=0.3`)\n",
1184
+ "This setting uses 90% of the reference timbre and 70% of the reference prosody. This makes it more similar to the reference speaker at the cost of less diverse samples."
1185
+ ]
1186
+ },
1187
+ {
1188
+ "cell_type": "code",
1189
+ "execution_count": null,
1190
+ "id": "9ba406bd",
1191
+ "metadata": {},
1192
+ "outputs": [],
1193
+ "source": [
1194
+ "# unseen speaker\n",
1195
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
1196
+ "ref_s = compute_style(path)\n",
1197
+ "\n",
1198
+ "text = \"How much variation is there?\"\n",
1199
+ "for _ in range(5):\n",
1200
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0.1, beta=0.3, embedding_scale=1)\n",
1201
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
1202
+ ]
1203
+ },
1204
+ {
1205
+ "cell_type": "markdown",
1206
+ "id": "a38fe464",
1207
+ "metadata": {},
1208
+ "source": [
1209
+ "#### More diverse setting (`alpha = 0.5, beta=0.95`)\n",
1210
+ "This setting uses 50% of the reference timbre and 5% of the reference prosody (so it uses 95% of the sampled prosody, which makes it more diverse), but this makes it more dissimilar to the reference speaker."
1211
+ ]
1212
+ },
1213
+ {
1214
+ "cell_type": "code",
1215
+ "execution_count": null,
1216
+ "id": "5f25bf94",
1217
+ "metadata": {},
1218
+ "outputs": [],
1219
+ "source": [
1220
+ "# unseen speaker\n",
1221
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
1222
+ "ref_s = compute_style(path)\n",
1223
+ "\n",
1224
+ "text = \"How much variation is there?\"\n",
1225
+ "for _ in range(5):\n",
1226
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0.5, beta=0.95, embedding_scale=1)\n",
1227
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
1228
+ ]
1229
+ },
1230
+ {
1231
+ "cell_type": "markdown",
1232
+ "id": "21c3a071",
1233
+ "metadata": {},
1234
+ "source": [
1235
+ "#### Extreme setting (`alpha = 1, beta=1`)\n",
1236
+ "This setting uses 0% of the reference timbre and prosody and relies entirely on the diffusion model to sample the style. This makes the speaker very dissimilar to the reference speaker."
1237
+ ]
1238
+ },
1239
+ {
1240
+ "cell_type": "code",
1241
+ "execution_count": null,
1242
+ "id": "fff8bab1",
1243
+ "metadata": {},
1244
+ "outputs": [],
1245
+ "source": [
1246
+ "# unseen speaker\n",
1247
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
1248
+ "ref_s = compute_style(path)\n",
1249
+ "\n",
1250
+ "text = \"How much variation is there?\"\n",
1251
+ "for _ in range(5):\n",
1252
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=1, beta=1, embedding_scale=1)\n",
1253
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
1254
+ ]
1255
+ },
1256
+ {
1257
+ "cell_type": "markdown",
1258
+ "id": "a8741e5a",
1259
+ "metadata": {},
1260
+ "source": [
1261
+ "#### No variation (`alpha = 0, beta=0`)\n",
1262
+ "This setting uses 100% of the reference timbre and prosody and does not use the sampled style at all. This makes the speaker very similar to the reference speaker, but there is no variation."
1263
+ ]
1264
+ },
1265
+ {
1266
+ "cell_type": "code",
1267
+ "execution_count": null,
1268
+ "id": "e55dd281",
1269
+ "metadata": {},
1270
+ "outputs": [],
1271
+ "source": [
1272
+ "# unseen speaker\n",
1273
+ "path = \"Demo/reference_audio/1221-135767-0014.wav\"\n",
1274
+ "ref_s = compute_style(path)\n",
1275
+ "\n",
1276
+ "text = \"How much variation is there?\"\n",
1277
+ "for _ in range(5):\n",
1278
+ " wav = inference(text, ref_s, diffusion_steps=10, alpha=0, beta=0, embedding_scale=1)\n",
1279
+ " display(ipd.Audio(wav, rate=24000, normalize=False))"
1280
+ ]
1281
+ },
1282
+ {
1283
+ "cell_type": "markdown",
1284
+ "id": "d5e86423",
1285
+ "metadata": {},
1286
+ "source": [
1287
+ "### Extra fun!\n",
1288
+ "\n",
1289
+ "Here we clone the voices of some of the StyleTTS 2 paper's authors from a few seconds of in-the-wild recordings. None of these voices is in the dataset, and all authors agreed to have their voices cloned here."
1290
+ ]
1291
+ },
1292
+ {
1293
+ "cell_type": "code",
1294
+ "execution_count": null,
1295
+ "id": "6f558314",
1296
+ "metadata": {},
1297
+ "outputs": [],
1298
+ "source": [
1299
+ "text = ''' StyleTTS 2 is a text to speech model that leverages style diffusion and adversarial training with large speech language models to achieve human level text to speech synthesis. '''"
1300
+ ]
1301
+ },
1302
+ {
1303
+ "cell_type": "code",
1304
+ "execution_count": null,
1305
+ "id": "caa5747c",
1306
+ "metadata": {},
1307
+ "outputs": [],
1308
+ "source": [
1309
+ "reference_dicts = {}\n",
1310
+ "reference_dicts['Yinghao'] = \"Demo/reference_audio/Yinghao.wav\"\n",
1311
+ "reference_dicts['Gavin'] = \"Demo/reference_audio/Gavin.wav\"\n",
1312
+ "reference_dicts['Vinay'] = \"Demo/reference_audio/Vinay.wav\"\n",
1313
+ "reference_dicts['Nima'] = \"Demo/reference_audio/Nima.wav\""
1314
+ ]
1315
+ },
1316
+ {
1317
+ "cell_type": "code",
1318
+ "execution_count": null,
1319
+ "id": "44a4cea1",
1320
+ "metadata": {},
1321
+ "outputs": [],
1322
+ "source": [
1323
+ "start = time.time()\n",
1324
+ "noise = torch.randn(1,1,256).to(device)\n",
1325
+ "for k, path in reference_dicts.items():\n",
1326
+ " ref_s = compute_style(path)\n",
1327
+ " \n",
1328
+ " wav = inference(text, ref_s, alpha=0.1, beta=0.5, diffusion_steps=5, embedding_scale=1)\n",
1329
+ " rtf = (time.time() - start) / (len(wav) / 24000)\n",
1330
+ " print('Speaker: ' + k)\n",
1331
+ " import IPython.display as ipd\n",
1332
+ " print('Synthesized:')\n",
1333
+ " display(ipd.Audio(wav, rate=24000, normalize=False))\n",
1334
+ " print('Reference:')\n",
1335
+ " display(ipd.Audio(path, rate=24000, normalize=False))"
1336
+ ]
1337
+ }
1338
+ ],
1339
+ "metadata": {
1340
+ "kernelspec": {
1341
+ "display_name": "Python3 (main venv)",
1342
+ "language": "python",
1343
+ "name": "main"
1344
+ },
1345
+ "language_info": {
1346
+ "codemirror_mode": {
1347
+ "name": "ipython",
1348
+ "version": 3
1349
+ },
1350
+ "file_extension": ".py",
1351
+ "mimetype": "text/x-python",
1352
+ "name": "python",
1353
+ "nbconvert_exporter": "python",
1354
+ "pygments_lexer": "ipython3",
1355
+ "version": "3.12.3"
1356
+ }
1357
+ },
1358
+ "nbformat": 4,
1359
+ "nbformat_minor": 5
1360
+ }
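Note on the notebook above: every inference-style function in it blends the diffusion-sampled style with the reference style using the same `alpha`/`beta` interpolation over the 256-dim style vector (first 128 dims timbre, last 128 dims prosody). The helper below is a minimal standalone sketch of that single step; `blend_style` is a hypothetical name and is not part of the notebook or the repository.

import torch

def blend_style(s_pred: torch.Tensor, ref_s: torch.Tensor,
                alpha: float, beta: float) -> tuple[torch.Tensor, torch.Tensor]:
    # s_pred, ref_s: (batch, 256) style vectors; first 128 dims = timbre, last 128 = prosody
    ref = alpha * s_pred[:, :128] + (1 - alpha) * ref_s[:, :128]  # timbre mix
    s = beta * s_pred[:, 128:] + (1 - beta) * ref_s[:, 128:]      # prosody mix
    return ref, s

# alpha=0, beta=0 keeps the reference style entirely; alpha=1, beta=1 uses only the sampled style.
ref, s = blend_style(torch.randn(1, 256), torch.randn(1, 256), alpha=0.3, beta=0.7)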
Modules/.ipynb_checkpoints/slmadv-checkpoint.py ADDED
@@ -0,0 +1,177 @@
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class SLMAdversarialLoss(torch.nn.Module):
7
+ def __init__(
8
+ self,
9
+ model,
10
+ wl,
11
+ sampler,
12
+ min_len,
13
+ max_len,
14
+ batch_percentage=0.5,
15
+ skip_update=10,
16
+ sig=1.5,
17
+ ):
18
+ super().__init__()
19
+ self.model = model
20
+ self.wl = wl
21
+ self.sampler = sampler
22
+
23
+ self.min_len = min_len
24
+ self.max_len = max_len
25
+ self.batch_percentage = batch_percentage
26
+
27
+ self.sig = sig
28
+ self.skip_update = skip_update
29
+
30
+ # ------------------------------------------------------------------ #
31
+ def forward(
32
+ self,
33
+ iters,
34
+ y_rec_gt,
35
+ y_rec_gt_pred,
36
+ waves,
37
+ mel_input_length,
38
+ ref_text,
39
+ ref_lengths,
40
+ use_ind,
41
+ s_trg,
42
+ ref_s=None,
43
+ ):
44
+ # ---- full-width mask (matches ref_text.size(1)) ----------------
45
+ seq_len = ref_text.size(1)
46
+ text_mask = (
47
+ torch.arange(seq_len, device=ref_text.device)
48
+ .unsqueeze(0)
49
+ >= ref_lengths.unsqueeze(1)
50
+ ) # shape [B, seq_len]
51
+
52
+ bert_dur = self.model.bert(ref_text, attention_mask=(~text_mask).int())
53
+ d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)
54
+
55
+ # ----- style / prosody sampling ---------------------------------
56
+ if use_ind and np.random.rand() < 0.5:
57
+ s_preds = s_trg
58
+ else:
59
+ num_steps = np.random.randint(3, 5)
60
+ noise = torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device)
61
+ sampler_kwargs = dict(
62
+ noise=noise,
63
+ embedding=bert_dur,
64
+ embedding_scale=1,
65
+ embedding_mask_proba=0.1,
66
+ num_steps=num_steps,
67
+ )
68
+ if ref_s is not None:
69
+ sampler_kwargs["features"] = ref_s
70
+ s_preds = self.sampler(**sampler_kwargs).squeeze(1)
71
+
72
+ s_dur, s = s_preds[:, 128:], s_preds[:, :128]
73
+
74
+ # random alignment placeholder must match the *padded* token width
75
+ seq_len = ref_text.size(1)
76
+ rand_align = torch.randn(ref_text.size(0), seq_len, 2, device=ref_text.device)
77
+
78
+ d, _ = self.model.predictor(
79
+ d_en, s_dur, ref_lengths,
80
+ rand_align,
81
+ text_mask,
82
+ )
83
+
84
+ # ----- differentiable duration modelling -----------------------
85
+ attn_preds, output_lengths = [], []
86
+ for _s2s_pred, _len in zip(d, ref_lengths):
87
+ _s2s_pred_org = _s2s_pred[: _len]
88
+ _s2s_pred_sig = torch.sigmoid(_s2s_pred_org)
89
+ _dur_pred = _s2s_pred_sig.sum(dim=-1)
90
+
91
+ l = int(torch.round(_s2s_pred_sig.sum()).item())
92
+ t = torch.arange(l, device=ref_text.device).unsqueeze(0).expand(_len, l)
93
+ loc = torch.cumsum(_dur_pred, dim=0) - _dur_pred / 2
94
+ h = torch.exp(-0.5 * (t - (l - loc.unsqueeze(-1))) ** 2 / (self.sig**2))
95
+
96
+ out = F.conv1d(
97
+ _s2s_pred_org.unsqueeze(0),
98
+ h.unsqueeze(1),
99
+ padding=h.size(-1) - 1,
100
+ groups=int(_len),
101
+ )[..., :l]
102
+ attn_preds.append(F.softmax(out.squeeze(), dim=0))
103
+ output_lengths.append(l)
104
+
105
+ max_len = max(output_lengths)
106
+
107
+ # ----- build full-width alignment matrix -----------------------
108
+ with torch.no_grad():
109
+ t_en = self.model.text_encoder(ref_text, ref_lengths, text_mask)
110
+
111
+ seq_len = ref_text.size(1)
112
+ s2s_attn = torch.zeros(
113
+ len(ref_lengths), seq_len, max_len, device=ref_text.device
114
+ )
115
+ for bib, (attn, L) in enumerate(zip(attn_preds, output_lengths)):
116
+ s2s_attn[bib, : ref_lengths[bib], :L] = attn
117
+
118
+ asr_pred = t_en @ s2s_attn
119
+
120
+ _, p_pred = self.model.predictor(
121
+ d_en, s_dur, ref_lengths, s2s_attn, text_mask
122
+ )
123
+
124
+ # ----- clip extraction -----------------------------------------
125
+ mel_len = max(int(min(output_lengths) / 2 - 1), self.min_len // 2)
126
+ mel_len = min(mel_len, self.max_len // 2)
127
+
128
+ en, p_en, sp, wav = [], [], [], []
129
+ for bib, L_pred in enumerate(output_lengths):
130
+ L_gt = int(mel_input_length[bib].item() / 2)
131
+ if L_gt <= mel_len or L_pred <= mel_len:
132
+ continue
133
+
134
+ sp.append(s_preds[bib])
135
+
136
+ start = np.random.randint(0, L_pred - mel_len)
137
+ en.append(asr_pred[bib, :, start : start + mel_len])
138
+ p_en.append(p_pred[bib, :, start : start + mel_len])
139
+
140
+ start_gt = np.random.randint(0, L_gt - mel_len)
141
+ y = waves[bib][(start_gt * 2) * 300 : ((start_gt + mel_len) * 2) * 300]
142
+ wav.append(torch.from_numpy(y).to(ref_text.device))
143
+
144
+ if len(wav) >= self.batch_percentage * len(waves):
145
+ break
146
+
147
+ if len(sp) <= 1:
148
+ return None
149
+
150
+ sp = torch.stack(sp)
151
+ wav = torch.stack(wav).float()
152
+ en = torch.stack(en)
153
+ p_en = torch.stack(p_en)
154
+
155
+ F0_fake, N_fake = self.model.predictor.F0Ntrain(p_en, sp[:, 128:])
156
+ y_pred = self.model.decoder(en, F0_fake, N_fake, sp[:, :128])
157
+
158
+ # -------------- adversarial losses -----------------------------
159
+ if (iters + 1) % self.skip_update == 0:
160
+ d_loss = self.wl.discriminator(wav.squeeze(), y_pred.detach().squeeze()).mean()
161
+ else:
162
+ d_loss = 0
163
+
164
+ gen_loss = self.wl.generator(y_pred.squeeze()).mean()
165
+ return d_loss, gen_loss, y_pred.detach().cpu().numpy()
166
+
167
+
168
+ # ------------------------------------------------------------------ #
169
+ def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
170
+ """Classic length mask: 1 → PAD, 0 → real token."""
171
+ max_len = lengths.max()
172
+ mask = (
173
+ torch.arange(max_len, device=lengths.device)
174
+ .unsqueeze(0)
175
+ .expand(lengths.size(0), -1)
176
+ )
177
+ return mask >= lengths.unsqueeze(1)
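The main behavioural change in this checkpoint is the text mask: instead of calling `length_to_mask` (width `ref_lengths.max()`), `forward` now builds the mask against the padded token width `ref_text.size(1)`, so it always lines up with `ref_text` even when the batch is padded beyond its longest item. A minimal sketch of the two conventions follows; `mask_from_lengths` is an illustrative name, not repository code.

import torch

def mask_from_lengths(lengths: torch.Tensor, width: int) -> torch.Tensor:
    # True marks padding positions, False marks real tokens
    return torch.arange(width, device=lengths.device).unsqueeze(0) >= lengths.unsqueeze(1)

lengths = torch.tensor([3, 5])
print(mask_from_lengths(lengths, int(lengths.max())))  # classic behaviour: width 5
print(mask_from_lengths(lengths, 7))                   # full padded width: width 7, extra columns are all padding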
Modules/slmadv.py CHANGED
@@ -2,194 +2,176 @@ import torch
2
  import numpy as np
3
  import torch.nn.functional as F
4
 
5
- class SLMAdversarialLoss(torch.nn.Module):
6
 
7
- def __init__(self, model, wl, sampler, min_len, max_len, batch_percentage=0.5, skip_update=10, sig=1.5):
8
- super(SLMAdversarialLoss, self).__init__()
 
 
 
 
 
 
 
 
 
 
 
9
  self.model = model
10
  self.wl = wl
11
  self.sampler = sampler
12
-
13
  self.min_len = min_len
14
  self.max_len = max_len
15
  self.batch_percentage = batch_percentage
16
-
17
  self.sig = sig
18
  self.skip_update = skip_update
19
-
20
- def forward(self, iters, y_rec_gt, y_rec_gt_pred, waves, mel_input_length, ref_text, ref_lengths, use_ind, s_trg, ref_s=None):
21
- text_mask = length_to_mask(ref_lengths).to(ref_text.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  bert_dur = self.model.bert(ref_text, attention_mask=(~text_mask).int())
23
- d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)
24
-
 
25
  if use_ind and np.random.rand() < 0.5:
26
  s_preds = s_trg
27
  else:
28
  num_steps = np.random.randint(3, 5)
 
 
 
 
 
 
 
 
29
  if ref_s is not None:
30
- s_preds = self.sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device),
31
- embedding=bert_dur,
32
- embedding_scale=1,
33
- features=ref_s, # reference from the same speaker as the embedding
34
- embedding_mask_proba=0.1,
35
- num_steps=num_steps).squeeze(1)
36
- else:
37
- s_preds = self.sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device),
38
- embedding=bert_dur,
39
- embedding_scale=1,
40
- embedding_mask_proba=0.1,
41
- num_steps=num_steps).squeeze(1)
42
-
43
- s_dur = s_preds[:, 128:]
44
- s = s_preds[:, :128]
45
-
46
- d, _ = self.model.predictor(d_en, s_dur,
47
- ref_lengths,
48
- torch.randn(ref_lengths.shape[0], ref_lengths.max(), 2).to(ref_text.device),
49
- text_mask)
50
-
51
- bib = 0
52
-
53
- output_lengths = []
54
- attn_preds = []
55
-
56
- # differentiable duration modeling
57
- for _s2s_pred, _text_length in zip(d, ref_lengths):
58
-
59
- _s2s_pred_org = _s2s_pred[:_text_length, :]
60
-
61
- _s2s_pred = torch.sigmoid(_s2s_pred_org)
62
- _dur_pred = _s2s_pred.sum(axis=-1)
63
-
64
- l = int(torch.round(_s2s_pred.sum()).item())
65
- t = torch.arange(0, l).expand(l)
66
-
67
- t = torch.arange(0, l).unsqueeze(0).expand((len(_s2s_pred), l)).to(ref_text.device)
68
  loc = torch.cumsum(_dur_pred, dim=0) - _dur_pred / 2
69
-
70
- h = torch.exp(-0.5 * torch.square(t - (l - loc.unsqueeze(-1))) / (self.sig)**2)
71
-
72
- out = torch.nn.functional.conv1d(_s2s_pred_org.unsqueeze(0),
73
- h.unsqueeze(1),
74
- padding=h.shape[-1] - 1, groups=int(_text_length))[..., :l]
 
 
75
  attn_preds.append(F.softmax(out.squeeze(), dim=0))
76
-
77
  output_lengths.append(l)
78
 
79
  max_len = max(output_lengths)
80
-
 
81
  with torch.no_grad():
82
  t_en = self.model.text_encoder(ref_text, ref_lengths, text_mask)
83
-
84
- s2s_attn = torch.zeros(len(ref_lengths), int(ref_lengths.max()), max_len).to(ref_text.device)
85
- for bib in range(len(output_lengths)):
86
- s2s_attn[bib, :ref_lengths[bib], :output_lengths[bib]] = attn_preds[bib]
 
 
 
87
 
88
  asr_pred = t_en @ s2s_attn
89
 
90
- _, p_pred = self.model.predictor(d_en, s_dur,
91
- ref_lengths,
92
- s2s_attn,
93
- text_mask)
94
-
95
  mel_len = max(int(min(output_lengths) / 2 - 1), self.min_len // 2)
96
  mel_len = min(mel_len, self.max_len // 2)
97
-
98
- # get clips
99
-
100
- en = []
101
- p_en = []
102
- sp = []
103
-
104
- F0_fakes = []
105
- N_fakes = []
106
-
107
- wav = []
108
-
109
- for bib in range(len(output_lengths)):
110
- mel_length_pred = output_lengths[bib]
111
- mel_length_gt = int(mel_input_length[bib].item() / 2)
112
- if mel_length_gt <= mel_len or mel_length_pred <= mel_len:
113
  continue
114
 
115
  sp.append(s_preds[bib])
116
 
117
- random_start = np.random.randint(0, mel_length_pred - mel_len)
118
- en.append(asr_pred[bib, :, random_start:random_start+mel_len])
119
- p_en.append(p_pred[bib, :, random_start:random_start+mel_len])
120
 
121
- # get ground truth clips
122
- random_start = np.random.randint(0, mel_length_gt - mel_len)
123
- y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
124
  wav.append(torch.from_numpy(y).to(ref_text.device))
125
-
126
- if len(wav) >= self.batch_percentage * len(waves): # prevent OOM due to longer lengths
127
  break
128
 
129
  if len(sp) <= 1:
130
  return None
131
-
132
  sp = torch.stack(sp)
133
  wav = torch.stack(wav).float()
134
  en = torch.stack(en)
135
  p_en = torch.stack(p_en)
136
-
137
  F0_fake, N_fake = self.model.predictor.F0Ntrain(p_en, sp[:, 128:])
138
  y_pred = self.model.decoder(en, F0_fake, N_fake, sp[:, :128])
139
-
140
- # discriminator loss
141
  if (iters + 1) % self.skip_update == 0:
142
- if np.random.randint(0, 2) == 0:
143
- wav = y_rec_gt_pred
144
- use_rec = True
145
- else:
146
- use_rec = False
147
-
148
- crop_size = min(wav.size(-1), y_pred.size(-1))
149
- if use_rec: # use reconstructed (shorter lengths), do length invariant regularization
150
- if wav.size(-1) > y_pred.size(-1):
151
- real_GP = wav[:, : , :crop_size]
152
- out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze())
153
- out_org = self.wl.discriminator_forward(wav.detach().squeeze())
154
- loss_reg = F.l1_loss(out_crop, out_org[..., :out_crop.size(-1)])
155
-
156
- if np.random.randint(0, 2) == 0:
157
- d_loss = self.wl.discriminator(real_GP.detach().squeeze(), y_pred.detach().squeeze()).mean()
158
- else:
159
- d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()
160
- else:
161
- real_GP = y_pred[:, : , :crop_size]
162
- out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze())
163
- out_org = self.wl.discriminator_forward(y_pred.detach().squeeze())
164
- loss_reg = F.l1_loss(out_crop, out_org[..., :out_crop.size(-1)])
165
-
166
- if np.random.randint(0, 2) == 0:
167
- d_loss = self.wl.discriminator(wav.detach().squeeze(), real_GP.detach().squeeze()).mean()
168
- else:
169
- d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()
170
-
171
- # regularization (ignore length variation)
172
- d_loss += loss_reg
173
-
174
- out_gt = self.wl.discriminator_forward(y_rec_gt.detach().squeeze())
175
- out_rec = self.wl.discriminator_forward(y_rec_gt_pred.detach().squeeze())
176
-
177
- # regularization (ignore reconstruction artifacts)
178
- d_loss += F.l1_loss(out_gt, out_rec)
179
-
180
- else:
181
- d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()
182
  else:
183
  d_loss = 0
184
-
185
- # generator loss
186
- gen_loss = self.wl.generator(y_pred.squeeze())
187
-
188
- gen_loss = gen_loss.mean()
189
-
190
  return d_loss, gen_loss, y_pred.detach().cpu().numpy()
191
-
192
- def length_to_mask(lengths):
193
- mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
194
- mask = torch.gt(mask+1, lengths.unsqueeze(1))
195
- return mask
 
 
 
 
 
 
 
 
2
  import numpy as np
3
  import torch.nn.functional as F
4
 
 
5
 
6
+ class SLMAdversarialLoss(torch.nn.Module):
7
+ def __init__(
8
+ self,
9
+ model,
10
+ wl,
11
+ sampler,
12
+ min_len,
13
+ max_len,
14
+ batch_percentage=0.5,
15
+ skip_update=10,
16
+ sig=1.5,
17
+ ):
18
+ super().__init__()
19
  self.model = model
20
  self.wl = wl
21
  self.sampler = sampler
22
+
23
  self.min_len = min_len
24
  self.max_len = max_len
25
  self.batch_percentage = batch_percentage
26
+
27
  self.sig = sig
28
  self.skip_update = skip_update
29
+
30
+ # ------------------------------------------------------------------ #
31
+ def forward(
32
+ self,
33
+ iters,
34
+ y_rec_gt,
35
+ y_rec_gt_pred,
36
+ waves,
37
+ mel_input_length,
38
+ ref_text,
39
+ ref_lengths,
40
+ use_ind,
41
+ s_trg,
42
+ ref_s=None,
43
+ ):
44
+ # ---- full-width mask (matches ref_text.size(1)) ----------------
45
+ seq_len = ref_text.size(1)
46
+ text_mask = (
47
+ torch.arange(seq_len, device=ref_text.device)
48
+ .unsqueeze(0)
49
+ >= ref_lengths.unsqueeze(1)
50
+ ) # shape [B, seq_len]
51
+
52
  bert_dur = self.model.bert(ref_text, attention_mask=(~text_mask).int())
53
+ d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)
54
+
55
+ # ----- style / prosody sampling ---------------------------------
56
  if use_ind and np.random.rand() < 0.5:
57
  s_preds = s_trg
58
  else:
59
  num_steps = np.random.randint(3, 5)
60
+ noise = torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device)
61
+ sampler_kwargs = dict(
62
+ noise=noise,
63
+ embedding=bert_dur,
64
+ embedding_scale=1,
65
+ embedding_mask_proba=0.1,
66
+ num_steps=num_steps,
67
+ )
68
  if ref_s is not None:
69
+ sampler_kwargs["features"] = ref_s
70
+ s_preds = self.sampler(**sampler_kwargs).squeeze(1)
71
+
72
+ s_dur, s = s_preds[:, 128:], s_preds[:, :128]
73
+
74
+ # random alignment placeholder must match the *padded* token width
75
+ seq_len = ref_text.size(1)
76
+ rand_align = torch.randn(ref_text.size(0), seq_len, 2, device=ref_text.device)
77
+
78
+ d, _ = self.model.predictor(
79
+ d_en, s_dur, ref_lengths,
80
+ rand_align,
81
+ text_mask,
82
+ )
83
+
84
+ # ----- differentiable duration modelling -----------------------
85
+ attn_preds, output_lengths = [], []
86
+ for _s2s_pred, _len in zip(d, ref_lengths):
87
+ _s2s_pred_org = _s2s_pred[: _len]
88
+ _s2s_pred_sig = torch.sigmoid(_s2s_pred_org)
89
+ _dur_pred = _s2s_pred_sig.sum(dim=-1)
90
+
91
+ l = int(torch.round(_s2s_pred_sig.sum()).item())
92
+ t = torch.arange(l, device=ref_text.device).unsqueeze(0).expand(_len, l)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  loc = torch.cumsum(_dur_pred, dim=0) - _dur_pred / 2
94
+ h = torch.exp(-0.5 * (t - (l - loc.unsqueeze(-1))) ** 2 / (self.sig**2))
95
+
96
+ out = F.conv1d(
97
+ _s2s_pred_org.unsqueeze(0),
98
+ h.unsqueeze(1),
99
+ padding=h.size(-1) - 1,
100
+ groups=int(_len),
101
+ )[..., :l]
102
  attn_preds.append(F.softmax(out.squeeze(), dim=0))
 
103
  output_lengths.append(l)
104
 
105
  max_len = max(output_lengths)
106
+
107
+ # ----- build full-width alignment matrix -----------------------
108
  with torch.no_grad():
109
  t_en = self.model.text_encoder(ref_text, ref_lengths, text_mask)
110
+
111
+ seq_len = ref_text.size(1)
112
+ s2s_attn = torch.zeros(
113
+ len(ref_lengths), seq_len, max_len, device=ref_text.device
114
+ )
115
+ for bib, (attn, L) in enumerate(zip(attn_preds, output_lengths)):
116
+ s2s_attn[bib, : ref_lengths[bib], :L] = attn
117
 
118
  asr_pred = t_en @ s2s_attn
119
 
120
+ _, p_pred = self.model.predictor(
121
+ d_en, s_dur, ref_lengths, s2s_attn, text_mask
122
+ )
123
+
124
+ # ----- clip extraction -----------------------------------------
125
  mel_len = max(int(min(output_lengths) / 2 - 1), self.min_len // 2)
126
  mel_len = min(mel_len, self.max_len // 2)
127
+
128
+ en, p_en, sp, wav = [], [], [], []
129
+ for bib, L_pred in enumerate(output_lengths):
130
+ L_gt = int(mel_input_length[bib].item() / 2)
131
+ if L_gt <= mel_len or L_pred <= mel_len:
 
 
 
 
 
 
 
 
 
 
 
132
  continue
133
 
134
  sp.append(s_preds[bib])
135
 
136
+ start = np.random.randint(0, L_pred - mel_len)
137
+ en.append(asr_pred[bib, :, start : start + mel_len])
138
+ p_en.append(p_pred[bib, :, start : start + mel_len])
139
 
140
+ start_gt = np.random.randint(0, L_gt - mel_len)
141
+ y = waves[bib][(start_gt * 2) * 300 : ((start_gt + mel_len) * 2) * 300]
 
142
  wav.append(torch.from_numpy(y).to(ref_text.device))
143
+
144
+ if len(wav) >= self.batch_percentage * len(waves):
145
  break
146
 
147
  if len(sp) <= 1:
148
  return None
149
+
150
  sp = torch.stack(sp)
151
  wav = torch.stack(wav).float()
152
  en = torch.stack(en)
153
  p_en = torch.stack(p_en)
154
+
155
  F0_fake, N_fake = self.model.predictor.F0Ntrain(p_en, sp[:, 128:])
156
  y_pred = self.model.decoder(en, F0_fake, N_fake, sp[:, :128])
157
+
158
+ # -------------- adversarial losses -----------------------------
159
  if (iters + 1) % self.skip_update == 0:
160
+ d_loss = self.wl.discriminator(wav.squeeze(), y_pred.detach().squeeze()).mean()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  else:
162
  d_loss = 0
163
+
164
+ gen_loss = self.wl.generator(y_pred.squeeze()).mean()
 
 
 
 
165
  return d_loss, gen_loss, y_pred.detach().cpu().numpy()
166
+
167
+
168
+ # ------------------------------------------------------------------ #
169
+ def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
170
+ """Classic length mask: 1 → PAD, 0 → real token."""
171
+ max_len = lengths.max()
172
+ mask = (
173
+ torch.arange(max_len, device=lengths.device)
174
+ .unsqueeze(0)
175
+ .expand(lengths.size(0), -1)
176
+ )
177
+ return mask >= lengths.unsqueeze(1)
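The "differentiable duration modelling" block kept in both versions of slmadv.py turns per-token duration predictions into a soft token-to-frame alignment so the decoder can be trained end-to-end. The sketch below shows the underlying idea in isolation: place a Gaussian window at each token's cumulative-duration midpoint and softmax over tokens. It is a simplified stand-in, not the exact kernel/conv1d formulation used in the file.

import torch

def soft_alignment(dur: torch.Tensor, sigma: float = 1.5) -> torch.Tensor:
    # dur: (num_tokens,) predicted durations in frames -> (num_tokens, num_frames) soft alignment
    num_frames = int(torch.round(dur.sum()).item())
    centres = torch.cumsum(dur, dim=0) - dur / 2         # midpoint of each token's span
    frames = torch.arange(num_frames, dtype=dur.dtype)   # frame index grid
    logits = -0.5 * (frames.unsqueeze(0) - centres.unsqueeze(1)) ** 2 / sigma ** 2
    return torch.softmax(logits, dim=0)                  # each frame column sums to 1 over tokens

# toy usage: three tokens lasting roughly 2, 3 and 1 frames
attn = soft_alignment(torch.tensor([2.0, 3.0, 1.0]))
print(attn.shape)  # torch.Size([3, 6])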
__pycache__/losses.cpython-310.pyc CHANGED
Binary files a/__pycache__/losses.cpython-310.pyc and b/__pycache__/losses.cpython-310.pyc differ
 
__pycache__/meldataset.cpython-310.pyc CHANGED
Binary files a/__pycache__/meldataset.cpython-310.pyc and b/__pycache__/meldataset.cpython-310.pyc differ
 
__pycache__/models.cpython-310.pyc CHANGED
Binary files a/__pycache__/models.cpython-310.pyc and b/__pycache__/models.cpython-310.pyc differ
 
__pycache__/optimizers.cpython-310.pyc CHANGED
Binary files a/__pycache__/optimizers.cpython-310.pyc and b/__pycache__/optimizers.cpython-310.pyc differ
 
__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/__pycache__/utils.cpython-310.pyc and b/__pycache__/utils.cpython-310.pyc differ
 
events.out.tfevents.1749451143.164-152-17-237.47710.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:ca5ac7da0de1cd8b2940a042eddfe0f7ea50cc867411a91d90240fa2186962b0
3
- size 88
 
 
 
 
events.out.tfevents.1749451143.164-152-17-237.47712.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:24f7b8986b9471590fd70ce3705e31a5b5a97854cdc1887585591ba318c1c150
3
- size 88
 
 
 
 
events.out.tfevents.1749451144.164-152-17-237.47706.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a4f87f7a9fa06bc2a39e77d91d3dd4c7d76ee7c9bbbf2f6d6b73f3a9d6836d0a
3
- size 88
 
 
 
 
events.out.tfevents.1749451144.164-152-17-237.47708.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:4f57d72c8bb8f7d68c2a16d4e5eea3151d1cd8aa752be3a879c003aa481c19b3
3
- size 88
 
 
 
 
events.out.tfevents.1749451144.164-152-17-237.47709.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:103d5b2f29512166ac9979033248e0fb344847396ec0ed3dea7e96e5fad84e80
3
- size 88
 
 
 
 
events.out.tfevents.1749451144.164-152-17-237.47711.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:7fa0f5d9031c97fbf708f0a40c4e2950dcbd07c683488a659710ab9fcfd1c224
3
- size 88
 
 
 
 
events.out.tfevents.1749451220.164-152-17-237.48862.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:1f67c96bcdf2b41944f1f6710d3735137dd4254b5d58570d3b304e894de5acc8
3
- size 88
 
 
 
 
events.out.tfevents.1749451220.164-152-17-237.48863.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:2e09c7f592ec6d1f20e8a0b1e0fce4ff9b209f4c6d2e466ee6c2a10c761207a4
3
- size 88
 
 
 
 
events.out.tfevents.1749451220.164-152-17-237.48864.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:6358d1685327b1fc73bcdfc1ba181c900d63e4bc2a679a646aa697446cbcc818
3
- size 88
 
 
 
 
events.out.tfevents.1749451220.164-152-17-237.48865.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:0c2cfe09514438496a286074eff7b5d988953c53eca8c446a57b833aca2cd233
3
- size 88
 
 
 
 
events.out.tfevents.1749451220.164-152-17-237.48868.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a2d99a10e442411fc79a40bba5ba012773c90f5be44254d6b872f3e350d0bb98
3
- size 88
 
 
 
 
events.out.tfevents.1749451221.164-152-17-237.48861.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:72554714ef2293ef47dc12bdc8698f70024abd16bfa736eb3de03d0e8b1c0eee
3
- size 88
 
 
 
 
events.out.tfevents.1749451221.164-152-17-237.48867.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:7a882509295dfd1193368519c9cad538370392117ee7b6c483ee939ee7979769
3
- size 88
 
 
 
 
events.out.tfevents.1749451222.164-152-17-237.48866.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:37a6fd68ad5ef36692d7d5389ba938e318c6287b20c2684a628e5245f186048c
3
- size 88
 
 
 
 
events.out.tfevents.1749453792.164-152-17-237.51057.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:8d42df1e6023c4e593c8479eecb153738ecc7600b94a1c388173708d38fc3688
3
- size 88
 
 
 
 
events.out.tfevents.1749453792.164-152-17-237.51059.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:3be5b37365ae5d97c1f0573f94f70a082f707c6c7d49926d118111ac7e48a818
3
- size 88
 
 
 
 
events.out.tfevents.1749453792.164-152-17-237.51061.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:16300e80b48b67bd14ce00c5751e4a6841df3da3be4873c677dc08a99a6c3aae
3
- size 88
 
 
 
 
events.out.tfevents.1749453792.164-152-17-237.51063.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:84bbbebc6a97e89725078bcf6533475d3286f6af25b23f781558e7cb8d8957e3
3
- size 88
 
 
 
 
events.out.tfevents.1749453793.164-152-17-237.51056.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:9a3560e484d3bd2c79ebe1507599d497f261cb3b62d1644697a5a2953d156c0d
3
- size 88
 
 
 
 
events.out.tfevents.1749453793.164-152-17-237.51058.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c70617864f6b48f9177c3051cf3d0e857374656601e1f3b2130e91dd6d3090ed
3
- size 88
 
 
 
 
events.out.tfevents.1749453793.164-152-17-237.51060.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e2a2cccc5958217207ffb6d11f98e619f4dedf92682121709eb5870fb3db085d
3
- size 88
 
 
 
 
events.out.tfevents.1749453794.164-152-17-237.51062.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b66ebcaedfeb8d0a1d4a801e5e18783d8dfa41eedf362b2a0733a38c8f0a82fa
3
- size 88
 
 
 
 
events.out.tfevents.1749453905.164-152-17-237.52357.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:39098f172013dbdfbddcfb78c6c126b7b67671f89989a344d4262fcd433c3e9f
3
- size 88
 
 
 
 
events.out.tfevents.1749453905.164-152-17-237.52358.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e31a13e3e32ec6f2081c987fd6b1bd6c20bce2ca61312a576eea1aeceea533dc
3
- size 88
 
 
 
 
events.out.tfevents.1749453905.164-152-17-237.52360.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c85710b39b9fb9764094a24aa502a715e94261613d7865f89e037553679ee109
3
- size 88
 
 
 
 
events.out.tfevents.1749453905.164-152-17-237.52361.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:521592225e7079f153bc3a7f18f79d9122eedbe6dffc5d342912d69ba4a9a1e1
3
- size 88
 
 
 
 
events.out.tfevents.1749453906.164-152-17-237.52355.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:914db70df7d7e69fc3396fb49f23b5a6849c9a251533d681c447df73ab81df34
3
- size 88
 
 
 
 
events.out.tfevents.1749453906.164-152-17-237.52356.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e1c4479d3572d16ab62395082c0c4a300c42b739f724966866a1a7b15c08344e
3
- size 88
 
 
 
 
events.out.tfevents.1749453906.164-152-17-237.52359.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:f2daad6b2a7a604a0853f70aee92c15f79557490acd34797d875c670446db7e3
3
- size 88
 
 
 
 
events.out.tfevents.1749453906.164-152-17-237.52362.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:6abbb8d2b926e553f18af6c4c9203b122736ea0c2a508407150076c1b2842dad
3
- size 88
 
 
 
 
events.out.tfevents.1749453977.164-152-17-237.53096.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:7a6eaa7928ccbca758c1c4d170f439c3938a2f300a78a12746749cfca3b997cf
3
- size 88
 
 
 
 
events.out.tfevents.1749453977.164-152-17-237.53097.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:f90c29e2391c038fb896ced73c9099014dfbff51e889bf6075ea8181a59da78d
3
- size 88
 
 
 
 
events.out.tfevents.1749453977.164-152-17-237.53098.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:0d82622b483fc183e71198f54b5b10c60851af9752b27ddf4dbaca39e988e15d
3
- size 88
 
 
 
 
events.out.tfevents.1749453977.164-152-17-237.53099.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:85fb51c72b78ca2c571570edd1a44288737f426a7e314f865c395d3f1d42d764
3
- size 88
 
 
 
 
events.out.tfevents.1749453977.164-152-17-237.53100.0 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:9d3516306f7671fae64a2df06e433f84e451c5d5e6387a9159d6861958371c75
3
- size 88