3v324v23 committed
Commit af11ce4 Β· 0 parent(s)
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +36 -0
  2. LICENSE +201 -0
  3. README.md +9 -0
  4. T_English.wav +3 -0
  5. T_English_output.wav +3 -0
  6. ToαΊ‘i.wav +3 -0
  7. ToαΊ‘i_output.wav +3 -0
  8. Trung.wav +3 -0
  9. Trung_output.wav +3 -0
  10. app.py +346 -0
  11. assets/silence.wav +3 -0
  12. download_models.py +38 -0
  13. egs/zipvoice/README.md +15 -0
  14. egs/zipvoice/conf/zipvoice_base.json +26 -0
  15. egs/zipvoice/local/pinyin.txt +1550 -0
  16. egs/zipvoice/local/prepare_emilia.sh +149 -0
  17. egs/zipvoice/local/prepare_libritts.sh +100 -0
  18. egs/zipvoice/local/prepare_token_file_char.py +67 -0
  19. egs/zipvoice/local/prepare_token_file_emilia.py +91 -0
  20. egs/zipvoice/local/prepare_tokens_emilia.py +88 -0
  21. egs/zipvoice/local/preprocess_emilia.py +210 -0
  22. egs/zipvoice/run_custom.sh +138 -0
  23. egs/zipvoice/run_emilia.sh +178 -0
  24. egs/zipvoice/run_eval.sh +142 -0
  25. egs/zipvoice/run_finetune.sh +175 -0
  26. egs/zipvoice/run_libritts.sh +148 -0
  27. egs/zipvoice/utils/parse_options.sh +97 -0
  28. egs/zipvoice/utils/validate_manifest.py +70 -0
  29. egs/zipvoice_dialog/README.md +12 -0
  30. egs/zipvoice_dialog/local/prepare_opendialog.py +262 -0
  31. egs/zipvoice_dialog/run_custom.sh +145 -0
  32. egs/zipvoice_dialog/run_eval.sh +120 -0
  33. egs/zipvoice_dialog/run_finetune.sh +135 -0
  34. egs/zipvoice_dialog/run_opendialog.sh +122 -0
  35. infer.py +578 -0
  36. proccess_wav.py +364 -0
  37. pyproject.toml +5 -0
  38. requirements.txt +23 -0
  39. requirements_eval.txt +19 -0
  40. setup.py +55 -0
  41. zipvoice/__init__.py +7 -0
  42. zipvoice/bin/compute_fbank.py +272 -0
  43. zipvoice/bin/generate_averaged_model.py +229 -0
  44. zipvoice/bin/infer_zipvoice.py +614 -0
  45. zipvoice/bin/infer_zipvoice_dialog.py +756 -0
  46. zipvoice/bin/infer_zipvoice_onnx.py +712 -0
  47. zipvoice/bin/onnx_export.py +410 -0
  48. zipvoice/bin/prepare_dataset.py +274 -0
  49. zipvoice/bin/prepare_tokens.py +102 -0
  50. zipvoice/bin/train_zipvoice.py +1136 -0
.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.wav filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
+                                  Apache License
+                            Version 2.0, January 2004
+                         http://www.apache.org/licenses/
+
+    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+    1. Definitions.
+
+       "License" shall mean the terms and conditions for use, reproduction,
+       and distribution as defined by Sections 1 through 9 of this document.
+
+       "Licensor" shall mean the copyright owner or entity authorized by
+       the copyright owner that is granting the License.
+
+       "Legal Entity" shall mean the union of the acting entity and all
+       other entities that control, are controlled by, or are under common
+       control with that entity. For the purposes of this definition,
+       "control" means (i) the power, direct or indirect, to cause the
+       direction or management of such entity, whether by contract or
+       otherwise, or (ii) ownership of fifty percent (50%) or more of the
+       outstanding shares, or (iii) beneficial ownership of such entity.
+
+       "You" (or "Your") shall mean an individual or Legal Entity
+       exercising permissions granted by this License.
+
+       "Source" form shall mean the preferred form for making modifications,
+       including but not limited to software source code, documentation
+       source, and configuration files.
+
+       "Object" form shall mean any form resulting from mechanical
+       transformation or translation of a Source form, including but
+       not limited to compiled object code, generated documentation,
+       and conversions to other media types.
+
+       "Work" shall mean the work of authorship, whether in Source or
+       Object form, made available under the License, as indicated by a
+       copyright notice that is included in or attached to the work
+       (an example is provided in the Appendix below).
+
+       "Derivative Works" shall mean any work, whether in Source or Object
+       form, that is based on (or derived from) the Work and for which the
+       editorial revisions, annotations, elaborations, or other modifications
+       represent, as a whole, an original work of authorship. For the purposes
+       of this License, Derivative Works shall not include works that remain
+       separable from, or merely link (or bind by name) to the interfaces of,
+       the Work and Derivative Works thereof.
+
+       "Contribution" shall mean any work of authorship, including
+       the original version of the Work and any modifications or additions
+       to that Work or Derivative Works thereof, that is intentionally
+       submitted to Licensor for inclusion in the Work by the copyright owner
+       or by an individual or Legal Entity authorized to submit on behalf of
+       the copyright owner. For the purposes of this definition, "submitted"
+       means any form of electronic, verbal, or written communication sent
+       to the Licensor or its representatives, including but not limited to
+       communication on electronic mailing lists, source code control systems,
+       and issue tracking systems that are managed by, or on behalf of, the
+       Licensor for the purpose of discussing and improving the Work, but
+       excluding communication that is conspicuously marked or otherwise
+       designated in writing by the copyright owner as "Not a Contribution."
+
+       "Contributor" shall mean Licensor and any individual or Legal Entity
+       on behalf of whom a Contribution has been received by Licensor and
+       subsequently incorporated within the Work.
+
+    2. Grant of Copyright License. Subject to the terms and conditions of
+       this License, each Contributor hereby grants to You a perpetual,
+       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+       copyright license to reproduce, prepare Derivative Works of,
+       publicly display, publicly perform, sublicense, and distribute the
+       Work and such Derivative Works in Source or Object form.
+
+    3. Grant of Patent License. Subject to the terms and conditions of
+       this License, each Contributor hereby grants to You a perpetual,
+       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+       (except as stated in this section) patent license to make, have made,
+       use, offer to sell, sell, import, and otherwise transfer the Work,
+       where such license applies only to those patent claims licensable
+       by such Contributor that are necessarily infringed by their
+       Contribution(s) alone or by combination of their Contribution(s)
+       with the Work to which such Contribution(s) was submitted. If You
+       institute patent litigation against any entity (including a
+       cross-claim or counterclaim in a lawsuit) alleging that the Work
+       or a Contribution incorporated within the Work constitutes direct
+       or contributory patent infringement, then any patent licenses
+       granted to You under this License for that Work shall terminate
+       as of the date such litigation is filed.
+
+    4. Redistribution. You may reproduce and distribute copies of the
+       Work or Derivative Works thereof in any medium, with or without
+       modifications, and in Source or Object form, provided that You
+       meet the following conditions:
+
+       (a) You must give any other recipients of the Work or
+           Derivative Works a copy of this License; and
+
+       (b) You must cause any modified files to carry prominent notices
+           stating that You changed the files; and
+
+       (c) You must retain, in the Source form of any Derivative Works
+           that You distribute, all copyright, patent, trademark, and
+           attribution notices from the Source form of the Work,
+           excluding those notices that do not pertain to any part of
+           the Derivative Works; and
+
+       (d) If the Work includes a "NOTICE" text file as part of its
+           distribution, then any Derivative Works that You distribute must
+           include a readable copy of the attribution notices contained
+           within such NOTICE file, excluding those notices that do not
+           pertain to any part of the Derivative Works, in at least one
+           of the following places: within a NOTICE text file distributed
+           as part of the Derivative Works; within the Source form or
+           documentation, if provided along with the Derivative Works; or,
+           within a display generated by the Derivative Works, if and
+           wherever such third-party notices normally appear. The contents
+           of the NOTICE file are for informational purposes only and
+           do not modify the License. You may add Your own attribution
+           notices within Derivative Works that You distribute, alongside
+           or as an addendum to the NOTICE text from the Work, provided
+           that such additional attribution notices cannot be construed
+           as modifying the License.
+
+       You may add Your own copyright statement to Your modifications and
+       may provide additional or different license terms and conditions
+       for use, reproduction, or distribution of Your modifications, or
+       for any such Derivative Works as a whole, provided Your use,
+       reproduction, and distribution of the Work otherwise complies with
+       the conditions stated in this License.
+
+    5. Submission of Contributions. Unless You explicitly state otherwise,
+       any Contribution intentionally submitted for inclusion in the Work
+       by You to the Licensor shall be under the terms and conditions of
+       this License, without any additional terms or conditions.
+       Notwithstanding the above, nothing herein shall supersede or modify
+       the terms of any separate license agreement you may have executed
+       with Licensor regarding such Contributions.
+
+    6. Trademarks. This License does not grant permission to use the trade
+       names, trademarks, service marks, or product names of the Licensor,
+       except as required for reasonable and customary use in describing the
+       origin of the Work and reproducing the content of the NOTICE file.
+
+    7. Disclaimer of Warranty. Unless required by applicable law or
+       agreed to in writing, Licensor provides the Work (and each
+       Contributor provides its Contributions) on an "AS IS" BASIS,
+       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+       implied, including, without limitation, any warranties or conditions
+       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+       PARTICULAR PURPOSE. You are solely responsible for determining the
+       appropriateness of using or redistributing the Work and assume any
+       risks associated with Your exercise of permissions under this License.
+
+    8. Limitation of Liability. In no event and under no legal theory,
+       whether in tort (including negligence), contract, or otherwise,
+       unless required by applicable law (such as deliberate and grossly
+       negligent acts) or agreed to in writing, shall any Contributor be
+       liable to You for damages, including any direct, indirect, special,
+       incidental, or consequential damages of any character arising as a
+       result of this License or out of the use or inability to use the
+       Work (including but not limited to damages for loss of goodwill,
+       work stoppage, computer failure or malfunction, or any and all
+       other commercial damages or losses), even if such Contributor
+       has been advised of the possibility of such damages.
+
+    9. Accepting Warranty or Additional Liability. While redistributing
+       the Work or Derivative Works thereof, You may choose to offer,
+       and charge a fee for, acceptance of support, warranty, indemnity,
+       or other liability obligations and/or rights consistent with this
+       License. However, in accepting such obligations, You may act only
+       on Your own behalf and on Your sole responsibility, not on behalf
+       of any other Contributor, and only if You agree to indemnify,
+       defend, and hold each Contributor harmless for any liability
+       incurred by, or claims asserted against, such Contributor by reason
+       of your accepting any such warranty or additional liability.
+
+    END OF TERMS AND CONDITIONS
+
+    APPENDIX: How to apply the Apache License to your work.
+
+       To apply the Apache License to your work, attach the following
+       boilerplate notice, with the fields enclosed by brackets "[]"
+       replaced with your own identifying information. (Don't include
+       the brackets!) The text should be enclosed in the appropriate
+       comment syntax for the file format. We also recommend that a
+       file or class name and description of purpose be included on the
+       same "printed page" as the copyright notice for easier
+       identification within third-party archives.
+
+    Copyright [yyyy] [name of copyright owner]
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
README.md ADDED
@@ -0,0 +1,9 @@
+ ---
+ license: cc-by-sa-4.0
+ title: Text_to_speech_Vietnamese
+ sdk: gradio
+ emoji: πŸš€
+ colorFrom: red
+ colorTo: yellow
+ pinned: false
+ ---
T_English.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d6ffd499fdf637243bdd630cb52660635d5c7cb580b87f52aa7efca90a33311f
+ size 328364
T_English_output.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:756daab105a8f4508e6d4c237f1e72a25ec326ad1f6665dce974f96e9b86db7a
+ size 954742
ToαΊ‘i.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:798c60882fa0e9a758fd72cf506e6b71293c5ccb5ed27b92569d042a23624bdc
+ size 200782
ToαΊ‘i_output.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3916769643ff756700e08ae355a0f7e30fd7a0e5299b06a866facad5ff31afd1
+ size 2154910
Trung.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0bf7087aa7452978b6ec6c25eb8c078eb4ca337660c9d3e3661b8017da9238e9
+ size 199376
Trung_output.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:068c6cfd893846473d177b9636b4689119a0c0ee7e19f4079cd6c98e27bb94a3
+ size 745196
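Because the final `.gitattributes` rule tracks `*.wav` with Git LFS, the audio samples above are stored as pointer stubs; the real media only appears after `git lfs pull`. Below is a minimal sketch (not part of the commit) that checks a pulled file against its pointer's `oid`/`size`; the pointer text and local path are illustrative.

```python
import hashlib
import os

def verify_lfs_pointer(pointer_text: str, media_path: str) -> bool:
    """Compare a Git LFS pointer's sha256 oid and size with an actual file."""
    fields = dict(line.split(" ", 1) for line in pointer_text.strip().splitlines())
    expected_oid = fields["oid"].split(":", 1)[1]
    expected_size = int(fields["size"])
    sha = hashlib.sha256()
    with open(media_path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha.update(chunk)
    return sha.hexdigest() == expected_oid and os.path.getsize(media_path) == expected_size

pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:0bf7087aa7452978b6ec6c25eb8c078eb4ca337660c9d3e3661b8017da9238e9
size 199376"""
print(verify_lfs_pointer(pointer, "Trung.wav"))  # True once the real file is pulled
```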
app.py ADDED
@@ -0,0 +1,346 @@
+ import spaces
+ import os
+ from download_models import download_all_models
+ from huggingface_hub import login
+
+ import gradio as gr
+
+ # ======================= HF LOGIN & DOWNLOAD MODEL =======================
+ hf_token = os.getenv("HF_TOKEN")
+ if hf_token:
+     login(token=hf_token)
+
+ # Download the models while the Space is ACTIVE
+ download_all_models()
+
+ from infer import run_zipvoice
+
+ # NEW: ASR + DENOISE
+ from chunkformer import ChunkFormerModel
+ from clearvoice import ClearVoice
+ from proccess_wav import enhance_ref_audio, transcribe_ref_audio
+
+ # (These two warm-up/test lines can be removed if they are not needed)
+ enhanced = enhance_ref_audio("ToαΊ‘i.wav")
+ text = transcribe_ref_audio(enhanced)
+
+
+ def infer_ref_text_ui(ref_audio_path: str) -> str:
+     """
+     Used by the 'Infer Text' button:
+     - Enhance the WAV (ClearVoice + silence handling + trim to 5-10s)
+     - Run ASR on the silence-delimited segments
+     - Fill the result into the Reference Text box
+     """
+     if not ref_audio_path:
+         raise gr.Error("Vui lΓ²ng upload file giọng mαΊ«u trΖ°α»›c khi infer text.")
+
+     try:
+         enhanced = enhance_ref_audio(ref_audio_path)
+         text = transcribe_ref_audio(enhanced)
+     except Exception as e:
+         raise gr.Error(f"Lα»—i khi nhαΊ­n dαΊ‘ng tα»« audio tham chiαΊΏu: {e}")
+
+     if not text:
+         raise gr.Error("KhΓ΄ng nhαΊ­n dαΊ‘ng được nα»™i dung tα»« audio tham chiαΊΏu.")
+     return text
+
+
+ # ======================= PRESET DEMO SAMPLES =======================
+ SAMPLE_CONFIGS = [
+     {
+         "name": "Sample 1 – Kể chuyện",
+         "ref_audio": "ToαΊ‘i.wav",
+         "ref_text": "Trong bΓ³ng tα»‘i, ToαΊ‘i nΓ³i cΓ‘i gΓ¬ Δ‘Γ³ mΓ  Thoan khΓ΄ng nghe thαΊ₯y.",
+         "gen_text": "ĐΓͺm nay trời nhiều mΓ’y, Γ‘nh trΔƒng bα»‹ che khuαΊ₯t, chỉ cΓ²n lαΊ‘i mα»™t dαΊ£i sΓ‘ng yαΊΏu α»›t rΖ‘i xuα»‘ng con đường Δ‘αΊ₯t trαΊ£i dΓ i giα»―a cΓ‘nh Δ‘α»“ng. CαΊ­u bΓ© tΓͺn TΓ­n Δ‘ang dαΊ―t chiαΊΏc xe Δ‘αΊ‘p cΕ© Δ‘i về nhΓ , bΓ‘nh xe bα»‹ cΓ‘n Δ‘inh nΓͺn lΔƒn nαΊ·ng vΓ  chαΊ­m nhΖ° con trΓ’u mệt nhọc sau vα»₯ mΓΉa. GiΓ³ thα»•i lαΊ‘nh buα»‘t, mΓΉi bΓΉn Δ‘αΊ₯t ngai ngΓ‘i quαΊ₯n lαΊ₯y chΓ’n cαΊ­u. Tα»›i Δ‘oαΊ‘n rαΊ½ dαΊ«n vΓ o xΓ³m, TΓ­n nghe tiαΊΏng nΖ°α»›c chαΊ£y khe khαΊ½ tα»« con mΖ°Ζ‘ng bΓͺn đường. TiαΊΏng αΊ₯y vαΊ«n quen thuα»™c, nhΖ°ng tα»‘i nay lαΊ‘i vang khΓ‘c lαΊ‘, nhΖ° cΓ³ giọng người Δ‘ang hΓ²a vΓ o nhα»‹p nΖ°α»›c, lΓΊc trαΊ§m lΓΊc cao, nghe mΖ‘ hα»“ mΓ  lαΊ‘nh sα»‘ng lΖ°ng. CαΊ­u dα»«ng lαΊ‘i, nghiΓͺng tai lαΊ―ng nghe, tim Δ‘αΊ­p nhanh nhΖ° muα»‘n vượt khỏi lα»“ng ngα»±c.",
+         "out_audio": "ToαΊ‘i_output.wav",
+     },
+     {
+         "name": "Sample 2 – Nα»―",
+         "ref_audio": "Trung.wav",
+         "ref_text": "MΓΉa hΓ¨ khΓ΄ng chỉ lΓ  khoαΊ£ng thời gian nghỉ ngΖ‘i, mΓ  cΓ²n lΓ  khoαΊ£ng thời gian tuyệt vời.",
+         "gen_text": "Tα»« cΓ‘c kαΊΏt quαΊ£ nΓ y, chΓΊng tΓ΄i đề xuαΊ₯t rαΊ±ng sα»± kαΊΏt hợp nhuαΊ§n nhuyα»…n giα»―a adaptive optimization, robust training pipelines vΓ  interpretable model design sαΊ½ lΓ  chΓ¬a khΓ³a để phΓ‘t triển cΓ‘c hệ thα»‘ng Γ’y ai vα»«a mαΊ‘nh mαΊ½ vα»«a Δ‘Γ‘ng tin cαΊ­y trong mΓ΄i trường thα»±c tαΊΏ.",
+         "out_audio": "Trung_output.wav",
+     },
+     {
+         "name": "Sample 3 – English",
+         "ref_audio": "T_English.wav",
+         "ref_text": "And turning to the pole which he had dragged, He drew it close beneath the widowed bough, And what was of it unto it left bound.",
+         "gen_text": "Recent experiments indicate that the current model architecture still exhibits significant overfitting, especially when evaluated on out of distribution samples. Although the training accuracy remains consistently high, the performance drops sharply when the model is exposed to noise perturbed inputs, suggesting limited robustness.",
+         "out_audio": "T_English_output.wav",
+     },
+ ]
+
+ # Handler used when a "Use this sample" button is clicked
+ def make_sample_loader(sample):
+     def _load_sample():
+         return (
+             sample["ref_audio"],  # ref_audio -> input Audio
+             sample["ref_text"],  # ref_text -> Textbox
+             sample["gen_text"],  # gen_text -> Textbox
+             sample["out_audio"],  # output_audio -> Audio
+         )
+     return _load_sample
+
+
+ # ======================= CUSTOM STYLE (BRIGHTER LOOK) =======================
+ custom_css = """
+ #app-container {
+     max-width: 1000px;
+     margin: 0 auto;
+ }
+ .gradio-container {
+     background: radial-gradient(circle at top, #ffffff 0, #f9fafb 55%);
+     color: #111827;
+ }
+
+ /* Large title */
+ #title-block h1 {
+     font-size: 2.4rem !important;
+     font-weight: 800 !important;
+     background: linear-gradient(120deg, #f97316, #eab308, #22c55e);
+     -webkit-background-clip: text;
+     color: transparent;
+     text-align: center;
+ }
+ #title-block p {
+     text-align: center;
+     font-size: 0.95rem;
+     color: #6b7280;
+ }
+
+ /* Brighter cards */
+ .sample-card {
+     border-radius: 16px;
+     padding: 16px;
+     background: rgba(255, 255, 255, 0.96);
+     border: 1px solid rgba(148, 163, 184, 0.6);
+     box-shadow: 0 18px 28px rgba(148, 163, 184, 0.35);
+ }
+
+ /* Buttons */
+ button.primary {
+     border-radius: 999px !important;
+     font-weight: 600 !important;
+ }
+
+ /* Tabs */
+ .svelte-1ipelgc, .tabitem {
+     font-weight: 600;
+ }
+ """
+
+ # ======================= TEXT POST-PROCESSING (IF NEEDED) =======================
+ def post_process(text: str) -> str:
+     text = " " + text + " "
+     text = text.replace(" . . ", " . ")
+     text = " " + text + " "
+     text = text.replace(" .. ", " . ")
+     text = " " + text + " "
+     text = text.replace(" , , ", " , ")
+     text = " " + text + " "
+     text = text.replace(" ,, ", " , ")
+     text = " " + text + " "
+     text = text.replace('"', "")
+     return " ".join(text.split())
+
+
+ @spaces.GPU
+ def infer_tts(ref_audio_path, ref_text, gen_text, steps, request: gr.Request = None):
+     if not ref_audio_path:
+         raise gr.Error("Please upload a sample audio file.")
+
+     if not gen_text.strip():
+         raise gr.Error("Please enter the text content to generate voice.")
+
+     # Limit the input length (4000 words)
+     if len(gen_text.split()) > 4000:
+         raise gr.Error("Please enter text content with less than 4000 words.")
+
+     # 1) Enhance the reference audio: ClearVoice + silence handling + trim to 5-10s
+     try:
+         enhanced_ref_audio = enhance_ref_audio(ref_audio_path)
+     except Exception as e:
+         raise gr.Error(f"Lα»—i khi xα»­ lΓ½ audio tham chiαΊΏu: {e}")
+
+     # 2) If no ref_text is given, run ASR on the silence-delimited segments
+     if not ref_text or not ref_text.strip():
+         try:
+             inferred = transcribe_ref_audio(enhanced_ref_audio)
+             if not inferred:
+                 raise gr.Error(
+                     "KhΓ΄ng nhαΊ­n dαΊ‘ng được nα»™i dung tα»« audio tham chiαΊΏu. "
+                     "Vui lΓ²ng nhαΊ­p Reference Text thα»§ cΓ΄ng."
+                 )
+             ref_text = inferred
+             print(f"[ASR] Inferred ref_text: {ref_text}")
+         except gr.Error:
+             raise
+         except Exception as e:
+             raise gr.Error(f"Lα»—i khi tα»± Δ‘α»™ng nhαΊ­n dαΊ‘ng Reference Text: {e}")
+
+     try:
+         out_path = "result.wav"
+
+         run_zipvoice(
+             model_name="zipvoice",
+             prompt_wav=enhanced_ref_audio,  # use the already-processed file
+             prompt_text=ref_text.strip() if ref_text else "xin chΓ o cΓ‘c bαΊ‘n",
+             text=gen_text,
+             res_wav_path=out_path,
+             lang="vi",
+             tokenizer_name="espeak",
+             num_step=steps,
+             seed=123456,
+             speed=1.0,
+         )
+
+         return out_path
+
+     except Exception as e:
+         raise gr.Error(f"Error generating voice: {e}")
+
+
+ # ======================= UI =======================
+ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
+     with gr.Column(elem_id="app-container"):
+         # --------- TITLE ----------
+         gr.Markdown(
+             """
+             <div id="title-block">
+             <h1>🎀 ZipVoice – Zero-shot Vietnamese TTS</h1>
+             <p>Upload mα»™t mαΊ«u giọng + nhαΊ­p nα»™i dung &rarr; hệ thα»‘ng sαΊ½ bαΊ―t chΖ°α»›c giọng nΓ³i vΓ  đọc Δ‘oαΊ‘n text cα»§a bαΊ‘n.</p>
+             </div>
+             """,
+             elem_id="title-block",
+         )
+
+         with gr.Tabs():
+             # A single main tab; the canned demos live inside it as well
+             with gr.TabItem("🎯 Tự tẑo giọng nói"):
+                 # --------- MAIN INPUT / OUTPUT BLOCK ----------
+                 with gr.Row():
+                     with gr.Column(elem_classes=["sample-card"]):
+                         gr.Markdown("#### 1️⃣ TαΊ£i giọng mαΊ«u & nhαΊ­p text")
+
+                         ref_audio = gr.Audio(
+                             label="πŸ”Š Sample Voice (upload hoαΊ·c kΓ©o thαΊ£)",
+                             type="filepath",
+                         )
+
+                         ref_text = gr.Textbox(
+                             label="πŸ“ Reference Text (optional)",
+                             placeholder="Nα»™i dung Δ‘ang được nΓ³i trong file giọng mαΊ«u (nΓͺn tα»± viαΊΏt cho chΓ­nh xΓ‘c)",
+                             lines=3,
+                         )
+
+                         # Button to infer text from the reference audio (ASR + denoising)
+                         btn_infer_text = gr.Button(
+                             "✨ Infer Text từ audio tham chiếu"
+                         )
+
+                         gen_text = gr.Textbox(
+                             label="πŸ“ Text to Generate",
+                             placeholder="NhαΊ­p nα»™i dung tiαΊΏng Việt bαΊ‘n muα»‘n tα»•ng hợp...",
+                             lines=6,
+                         )
+
+                         steps = gr.Slider(
+                             8,
+                             64,
+                             value=25,
+                             step=1,
+                             label="⚑ Step (cΓ ng lα»›n, cΓ ng tα»‘t, cΓ ng lΓ’u)",
+                         )
+
+                         btn_synthesize = gr.Button(
+                             "πŸ”₯ Generate Voice",
+                             variant="primary",
+                         )
+
+                     with gr.Column(elem_classes=["sample-card"]):
+                         gr.Markdown("#### 2️⃣ KαΊΏt quαΊ£ tα»•ng hợp")
+                         output_audio = gr.Audio(
+                             label="🎧 Generated Audio",
+                             type="filepath",
+                         )
+                         gr.Markdown(
+                             """
+                             - BαΊ‘n cΓ³ thể tαΊ£i file `.wav` về sau khi tαΊ‘o.
+                             - NαΊΏu nghe chΖ°a α»•n, hΓ£y thα»­:
+                               - DΓΉng **ref audio ngαΊ―n 3-8s**, phΓ‘t Γ’m chuαΊ©n hΖ‘n.
+                             """
+                         )
+
+                 # wire the Generate button -> infer_tts
+                 btn_synthesize.click(
+                     infer_tts,
+                     inputs=[ref_audio, ref_text, gen_text, steps],
+                     outputs=[output_audio],
+                 )
+
+                 # wire the Infer Text button -> fill ref_text (denoised first)
+                 btn_infer_text.click(
+                     infer_ref_text_ui,
+                     inputs=[ref_audio],
+                     outputs=[ref_text],
+                 )
+
+                 # --------- DEMO BLOCK, DIRECTLY INSIDE THE MAIN TAB ----------
+                 gr.Markdown(
+                     """
+                     ### 🎧 Demo có sạn
+                     Click vΓ o mα»™t sample bΓͺn dΖ°α»›i để tα»± Δ‘α»™ng nαΊ‘p:
+                     - πŸ”Š Giọng mαΊ«u (ref voice)
+                     - πŸ“ Reference text
+                     - πŸ“ Text to generate
+                     - 🎧 Output audio mẫu
+                     """
+                 )
+
+                 for sample in SAMPLE_CONFIGS:
+                     with gr.Column(elem_classes=["sample-card"]):
+                         gr.Markdown(f"### {sample['name']}")
+                         with gr.Row():
+                             gr.Audio(
+                                 value=sample["ref_audio"],
+                                 label="πŸ”Š Reference Voice",
+                                 interactive=False,
+                             )
+                             gr.Textbox(
+                                 value=sample["ref_text"],
+                                 label="πŸ“ Reference Text",
+                                 interactive=False,
+                                 lines=3,
+                             )
+
+                         gr.Audio(
+                             value=sample["out_audio"],
+                             label="🎧 Generated Sample (TTS)",
+                             interactive=False,
+                         )
+
+                         if sample.get("gen_text"):
+                             gr.Markdown(
+                                 f"**Text dΓΉng để synth:** {sample['gen_text']}"
+                             )
+
+                         # This button fills ref_audio, ref_text, gen_text, and output_audio at once
+                         use_btn = gr.Button(f"➑️ Dùng {sample['name']}")
+
+                         use_btn.click(
+                             make_sample_loader(sample),
+                             inputs=[],
+                             outputs=[ref_audio, ref_text, gen_text, output_audio],
+                         )
+
+                 gr.Markdown(
+                     """
+                     ### ⚠️ Model Limitations
+                     1. CΓ³ thể xα»­ lΓ½ chΖ°a tα»‘t vα»›i sα»‘, ngΓ y thΓ‘ng, kΓ½ tα»± Δ‘αΊ·c biệt.
+                     2. Nhα»‹p Δ‘iệu Δ‘Γ΄i khi chΖ°a tα»± nhiΓͺn.
+                     3. ChαΊ₯t lượng phα»₯ thuα»™c khΓ‘ nhiều vΓ o chαΊ₯t lượng ref audio.
+                     """
+                 )
+
+ demo.queue().launch()
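For scripted use outside the Gradio UI, the same synthesis path can be called directly. A minimal sketch, assuming the repo root as working directory, that `download_all_models()` has run, and that the commit's sample files are LFS-pulled; the Vietnamese `text` string is an illustrative input, not from the commit.

```python
# Headless run of the pipeline app.py wires to the Generate button.
from download_models import download_all_models
from infer import run_zipvoice
from proccess_wav import enhance_ref_audio

download_all_models()
prompt = enhance_ref_audio("Trung.wav")  # denoise + trim, as in infer_tts()
run_zipvoice(
    model_name="zipvoice",
    prompt_wav=prompt,
    prompt_text="MΓΉa hΓ¨ khΓ΄ng chỉ lΓ  khoαΊ£ng thời gian nghỉ ngΖ‘i, mΓ  cΓ²n lΓ  khoαΊ£ng thời gian tuyệt vời.",
    text="Xin chΓ o, Δ‘Γ’y lΓ  giọng nΓ³i được tα»•ng hợp bởi ZipVoice.",  # sample input (assumed)
    res_wav_path="result.wav",
    lang="vi",
    tokenizer_name="espeak",
    num_step=25,
    seed=123456,
    speed=1.0,
)
```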
assets/silence.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ca5a251f2d1439929f1c6b44d98299e53d402da45306af79cbfab5005501fed9
+ size 4800044
download_models.py ADDED
@@ -0,0 +1,38 @@
+ import os
+ import requests
+
+ MODEL_DIR = "zipvoice_finetune"
+ os.makedirs(MODEL_DIR, exist_ok=True)
+
+ files = {
+     "iter-525000-avg-2.pt": "https://huggingface.co/datasets/kjanh/demo_zip/resolve/main/epoch-46-all-speak-600h-en-norm.pt",
+     "model.json": "https://huggingface.co/datasets/kjanh/demo_zip/resolve/main/model.json",
+     "tokens.txt": "https://huggingface.co/datasets/kjanh/demo_zip/resolve/main/tokens.txt",
+ }
+
+ HF_TOKEN = os.getenv("HF_TOKEN")
+
+ def download_with_token(url, dest_path):
+     if os.path.exists(dest_path):
+         print(f"βœ” File tα»“n tαΊ‘i: {dest_path}")
+         return
+
+     if HF_TOKEN is None:
+         raise RuntimeError("❌ Missing HF_TOKEN in Secrets!")
+
+     print(f"⬇ Downloading {dest_path} ...")
+
+     headers = {"Authorization": f"Bearer {HF_TOKEN}"}
+     r = requests.get(url, headers=headers, stream=True)
+     r.raise_for_status()
+
+     with open(dest_path, "wb") as f:
+         for chunk in r.iter_content(1024 * 1024):
+             f.write(chunk)
+
+     print(f"βœ… Downloaded {dest_path}")
+
+ def download_all_models():
+     for filename, url in files.items():
+         dest = os.path.join(MODEL_DIR, filename)
+         download_with_token(url, dest)
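Note that `HF_TOKEN` is read at import time, so the environment variable must be set before the module is imported. A small usage sketch; the token value and the extra `vocab.json` entry/URL are hypothetical placeholders, not files in this commit.

```python
import os
os.environ.setdefault("HF_TOKEN", "hf_xxx")  # placeholder token; set before import

import download_models as dm

# Optionally register an extra (hypothetical) file, then fetch everything;
# files already present in zipvoice_finetune/ are skipped.
dm.files["vocab.json"] = "https://huggingface.co/datasets/kjanh/demo_zip/resolve/main/vocab.json"
dm.download_all_models()
```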
egs/zipvoice/README.md ADDED
@@ -0,0 +1,15 @@
+ # ZipVoice Recipe
+
+ This recipe contains the following examples:
+
+ - Training ZipVoice on Emilia from scratch, see [run_emilia.sh](run_emilia.sh).
+ - Training ZipVoice on LibriTTS from scratch, see [run_libritts.sh](run_libritts.sh).
+ - Training ZipVoice on custom datasets (any language) from scratch, see [run_custom.sh](run_custom.sh).
+ - Fine-tuning pre-trained ZipVoice on custom datasets (any language), see [run_finetune.sh](run_finetune.sh).
+ - Evaluating TTS models with the objective metrics reported in the ZipVoice paper, see [run_eval.sh](run_eval.sh).
+
+ > **NOTE:** [run_emilia.sh](run_emilia.sh) is the most complete example; it covers data preparation, ZipVoice training, ZipVoice-Distill training, ONNX export, and inference with both the PyTorch and ONNX models.
+
+ > **NOTE:** For evaluation, first install the packages from [../../requirements_eval.txt](../../requirements_eval.txt):
+ >
+ > `pip install -r ../../requirements_eval.txt`
egs/zipvoice/conf/zipvoice_base.json ADDED
@@ -0,0 +1,26 @@
+ {
+     "model": {
+         "fm_decoder_downsampling_factor": [1, 2, 4, 2, 1],
+         "fm_decoder_num_layers": [2, 2, 4, 4, 4],
+         "fm_decoder_cnn_module_kernel": [31, 15, 7, 15, 31],
+         "fm_decoder_feedforward_dim": 1536,
+         "fm_decoder_num_heads": 4,
+         "fm_decoder_dim": 512,
+         "text_encoder_num_layers": 4,
+         "text_encoder_feedforward_dim": 512,
+         "text_encoder_cnn_module_kernel": 9,
+         "text_encoder_num_heads": 4,
+         "text_encoder_dim": 192,
+         "query_head_dim": 32,
+         "value_head_dim": 12,
+         "pos_head_dim": 4,
+         "pos_dim": 48,
+         "time_embed_dim": 192,
+         "text_embed_dim": 192,
+         "feat_dim": 100
+     },
+     "feature": {
+         "sampling_rate": 24000,
+         "type": "vocos"
+     }
+ }
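The config splits into a `model` block (architecture hyperparameters) and a `feature` block (vocos features at 24 kHz). A minimal sketch of reading it with the standard `json` module; how the training scripts actually consume it is not shown in this diff, so the access pattern below is an assumption.

```python
import json

# Load the base architecture/feature config shipped with the recipe.
with open("egs/zipvoice/conf/zipvoice_base.json") as f:
    cfg = json.load(f)

model_cfg, feat_cfg = cfg["model"], cfg["feature"]
print(feat_cfg["sampling_rate"])           # 24000
print(model_cfg["fm_decoder_num_layers"])  # [2, 2, 4, 4, 4]
```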
egs/zipvoice/local/pinyin.txt ADDED
@@ -0,0 +1,1550 @@
+ a
+ a1
+ a2
+ a3
+ a4
+ ai1
+ ai2
+ ai3
+ ai4
+ an1
+ an2
+ an3
+ an4
+ ang1
+ ang2
+ ang3
+ ang4
+ ao1
+ ao2
+ ao3
+ ao4
+ ba
+ ba1
+ ba2
+ ba3
+ ba4
+ bai
+ bai1
+ bai2
+ bai3
+ bai4
+ ban
+ ban1
+ ban3
+ ban4
+ bang1
+ bang3
+ bang4
+ bao1
+ bao2
+ bao3
+ bao4
+ bei
+ bei1
+ bei3
+ bei4
+ ben1
+ ben3
+ ben4
+ beng
+ beng1
+ beng2
+ beng3
+ beng4
+ bi1
+ bi2
+ bi3
+ bi4
+ bian
+ bian1
+ bian3
+ bian4
+ biang2
+ biao1
+ biao3
+ biao4
+ bie1
+ bie2
+ bie3
+ bie4
+ bin
+ bin1
+ bin3
+ bin4
+ bing1
+ bing3
+ bing4
+ bo
+ bo1
+ bo2
+ bo3
+ bo4
+ bu1
+ bu2
+ bu3
+ bu4
+ ca1
+ ca3
+ ca4
+ cai1
+ cai2
+ cai3
+ cai4
+ can1
+ can2
+ can3
+ can4
+ cang1
+ cang2
+ cang3
+ cang4
+ cao1
+ cao2
+ cao3
+ cao4
+ ce4
+ cei4
+ cen1
+ cen2
+ ceng1
+ ceng2
+ ceng4
+ cha1
+ cha2
+ cha3
+ cha4
+ chai1
+ chai2
+ chai3
+ chai4
+ chan1
+ chan2
+ chan3
+ chan4
+ chang
+ chang1
+ chang2
+ chang3
+ chang4
+ chao1
+ chao2
+ chao3
+ chao4
+ che1
+ che2
+ che3
+ che4
+ chen
+ chen1
+ chen2
+ chen3
+ chen4
+ cheng1
+ cheng2
+ cheng3
+ cheng4
+ chi
+ chi1
+ chi2
+ chi3
+ chi4
+ chong1
+ chong2
+ chong3
+ chong4
+ chou1
+ chou2
+ chou3
+ chou4
+ chu
+ chu1
+ chu2
+ chu3
+ chu4
+ chua1
+ chua3
+ chua4
+ chuai1
+ chuai2
+ chuai3
+ chuai4
+ chuan1
+ chuan2
+ chuan3
+ chuan4
+ chuang1
+ chuang2
+ chuang3
+ chuang4
+ chui1
+ chui2
+ chui3
+ chui4
+ chun1
+ chun2
+ chun3
+ chuo1
+ chuo4
+ ci1
+ ci2
+ ci3
+ ci4
+ cong1
+ cong2
+ cong3
+ cong4
+ cou1
+ cou2
+ cou3
+ cou4
+ cu1
+ cu2
+ cu3
+ cu4
+ cuan1
+ cuan2
+ cuan4
+ cui
+ cui1
+ cui3
+ cui4
+ cun1
+ cun2
+ cun3
+ cun4
+ cuo1
+ cuo2
+ cuo3
+ cuo4
+ da
+ da1
+ da2
+ da3
+ da4
+ dai
+ dai1
+ dai3
+ dai4
+ dan1
+ dan3
+ dan4
+ dang
+ dang1
+ dang3
+ dang4
+ dao1
+ dao2
+ dao3
+ dao4
+ de
+ de1
+ de2
+ dei1
+ dei3
+ den4
+ deng1
+ deng3
+ deng4
+ di1
+ di2
+ di3
+ di4
+ dia3
+ dian1
+ dian2
+ dian3
+ dian4
+ diao1
+ diao3
+ diao4
+ die1
+ die2
+ die3
+ die4
+ din4
+ ding1
+ ding3
+ ding4
+ diu1
+ dong1
+ dong3
+ dong4
+ dou1
+ dou3
+ dou4
+ du1
+ du2
+ du3
+ du4
+ duan1
+ duan3
+ duan4
+ dui1
+ dui3
+ dui4
+ dun1
+ dun3
+ dun4
+ duo
+ duo1
+ duo2
+ duo3
+ duo4
+ e
+ e1
+ e2
+ e3
+ e4
+ ei1
+ ei2
+ ei3
+ ei4
+ en1
+ en3
+ en4
+ eng1
+ er
+ er2
+ er3
+ er4
+ fa
+ fa1
+ fa2
+ fa3
+ fa4
+ fan1
+ fan2
+ fan3
+ fan4
+ fang
+ fang1
+ fang2
+ fang3
+ fang4
+ fei1
+ fei2
+ fei3
+ fei4
+ fen1
+ fen2
+ fen3
+ fen4
+ feng1
+ feng2
+ feng3
+ feng4
+ fiao4
+ fo2
+ fou1
+ fou2
+ fou3
+ fu
+ fu1
+ fu2
+ fu3
+ fu4
+ ga1
+ ga2
+ ga3
+ ga4
+ gai1
+ gai3
+ gai4
+ gan1
+ gan3
+ gan4
+ gang1
+ gang3
+ gang4
+ gao1
+ gao3
+ gao4
+ ge1
+ ge2
+ ge3
+ ge4
+ gei3
+ gen1
+ gen2
+ gen3
+ gen4
+ geng1
+ geng3
+ geng4
+ gong
+ gong1
+ gong3
+ gong4
+ gou1
+ gou3
+ gou4
+ gu
+ gu1
+ gu2
+ gu3
+ gu4
+ gua1
+ gua2
+ gua3
+ gua4
+ guai1
+ guai3
+ guai4
+ guan1
+ guan3
+ guan4
+ guang
+ guang1
+ guang3
+ guang4
+ gui1
+ gui3
+ gui4
+ gun3
+ gun4
+ guo
+ guo1
+ guo2
+ guo3
+ guo4
+ ha1
+ ha2
+ ha3
+ ha4
+ hai
+ hai1
+ hai2
+ hai3
+ hai4
+ han
+ han1
+ han2
+ han3
+ han4
+ hang1
+ hang2
+ hang3
+ hang4
+ hao1
+ hao2
+ hao3
+ hao4
+ he1
+ he2
+ he3
+ he4
+ hei1
+ hen1
+ hen2
+ hen3
+ hen4
+ heng1
+ heng2
+ heng4
+ hm
+ hng
+ hong1
+ hong2
+ hong3
+ hong4
+ hou1
+ hou2
+ hou3
+ hou4
+ hu
+ hu1
+ hu2
+ hu3
+ hu4
+ hua1
+ hua2
+ hua4
+ huai
+ huai2
+ huai4
+ huan1
+ huan2
+ huan3
+ huan4
+ huang
+ huang1
+ huang2
+ huang3
+ huang4
+ hui
+ hui1
+ hui2
+ hui3
+ hui4
+ hun1
+ hun2
+ hun3
+ hun4
+ huo
+ huo1
+ huo2
+ huo3
+ huo4
+ ji1
+ ji2
+ ji3
+ ji4
+ jia
+ jia1
+ jia2
+ jia3
+ jia4
+ jian
+ jian1
+ jian3
+ jian4
+ jiang
+ jiang1
+ jiang3
+ jiang4
+ jiao
+ jiao1
+ jiao2
+ jiao3
+ jiao4
+ jie
+ jie1
+ jie2
+ jie3
+ jie4
+ jin1
+ jin3
+ jin4
+ jing
+ jing1
+ jing3
+ jing4
+ jiong1
+ jiong3
+ jiong4
+ jiu
+ jiu1
+ jiu2
+ jiu3
+ jiu4
+ ju
+ ju1
+ ju2
+ ju3
+ ju4
+ juan1
+ juan3
+ juan4
+ jue1
+ jue2
+ jue3
+ jue4
+ jun1
+ jun3
+ jun4
+ ka1
+ ka3
+ kai1
+ kai3
+ kai4
+ kan1
+ kan3
+ kan4
+ kang1
+ kang2
+ kang3
+ kang4
+ kao1
+ kao3
+ kao4
+ ke
+ ke1
+ ke2
+ ke3
+ ke4
+ kei1
+ ken1
+ ken3
+ ken4
+ keng1
+ keng3
+ kong1
+ kong3
+ kong4
+ kou1
+ kou3
+ kou4
+ ku1
+ ku2
+ ku3
+ ku4
+ kua1
+ kua3
+ kua4
+ kuai3
+ kuai4
+ kuan1
+ kuan3
+ kuang1
+ kuang2
+ kuang3
+ kuang4
+ kui1
+ kui2
+ kui3
+ kui4
+ kun
+ kun1
+ kun3
+ kun4
+ kuo4
+ la
+ la1
+ la2
+ la3
+ la4
+ lai2
+ lai3
+ lai4
+ lan2
+ lan3
+ lan4
+ lang
+ lang1
+ lang2
+ lang3
+ lang4
+ lao
+ lao1
+ lao2
+ lao3
+ lao4
+ le
+ le1
+ le4
+ lei
+ lei1
+ lei2
+ lei3
+ lei4
+ len4
+ leng1
+ leng2
+ leng3
+ leng4
+ li
+ li1
+ li2
+ li3
+ li4
+ lia3
+ lian2
+ lian3
+ lian4
+ liang
+ liang2
+ liang3
+ liang4
+ liao1
+ liao2
+ liao3
+ liao4
+ lie
+ lie1
+ lie2
+ lie3
+ lie4
+ lin1
+ lin2
+ lin3
+ lin4
+ ling
+ ling1
+ ling2
+ ling3
+ ling4
+ liu1
+ liu2
+ liu3
+ liu4
+ lo
+ long1
+ long2
+ long3
+ long4
+ lou
+ lou1
+ lou2
+ lou3
+ lou4
+ lu
+ lu1
+ lu2
+ lu3
+ lu4
+ luan2
+ luan3
+ luan4
+ lun1
+ lun2
+ lun3
+ lun4
+ luo
+ luo1
+ luo2
+ luo3
+ luo4
+ lv2
+ lv3
+ lv4
+ lve3
+ lve4
+ m1
+ m2
+ m4
+ ma
+ ma1
+ ma2
+ ma3
+ ma4
+ mai2
+ mai3
+ mai4
+ man1
+ man2
+ man3
+ man4
+ mang1
+ mang2
+ mang3
+ mang4
+ mao1
+ mao2
+ mao3
+ mao4
+ me
+ me1
+ mei2
+ mei3
+ mei4
+ men
+ men1
+ men2
+ men4
+ meng
+ meng1
+ meng2
+ meng3
+ meng4
+ mi1
+ mi2
+ mi3
+ mi4
+ mian2
+ mian3
+ mian4
+ miao1
+ miao2
+ miao3
+ miao4
+ mie
+ mie1
+ mie2
+ mie4
+ min
+ min2
+ min3
+ ming
+ ming2
+ ming3
+ ming4
+ miu3
+ miu4
+ mo
+ mo1
+ mo2
+ mo3
+ mo4
+ mou1
+ mou2
+ mou3
+ mou4
+ mu2
+ mu3
+ mu4
+ n
+ n2
+ n3
+ n4
+ na
+ na1
+ na2
+ na3
+ na4
+ nai2
+ nai3
+ nai4
+ nan1
+ nan2
+ nan3
+ nan4
+ nang
+ nang1
+ nang2
+ nang3
+ nang4
+ nao1
+ nao2
+ nao3
+ nao4
+ ne
+ ne2
+ ne4
+ nei2
+ nei3
+ nei4
+ nen4
+ neng2
+ neng3
+ neng4
+ ng
+ ng2
+ ng3
+ ng4
+ ni1
+ ni2
+ ni3
+ ni4
+ nia1
+ nian1
+ nian2
+ nian3
+ nian4
+ niang2
+ niang3
+ niang4
+ niao3
+ niao4
+ nie1
+ nie2
+ nie3
+ nie4
+ nin
+ nin2
+ nin3
+ ning2
+ ning3
+ ning4
+ niu1
+ niu2
+ niu3
+ niu4
+ nong2
+ nong3
+ nong4
+ nou2
+ nou3
+ nou4
+ nu2
+ nu3
+ nu4
+ nuan2
+ nuan3
+ nuan4
+ nun2
+ nun4
+ nuo2
+ nuo3
+ nuo4
+ nv2
+ nv3
+ nv4
+ nve4
+ o
+ o1
+ o2
+ o3
+ o4
+ ou
+ ou1
+ ou2
+ ou3
+ ou4
+ pa1
+ pa2
+ pa3
+ pa4
+ pai1
+ pai2
+ pai3
+ pai4
+ pan1
+ pan2
+ pan3
+ pan4
+ pang1
+ pang2
+ pang3
+ pang4
+ pao1
+ pao2
+ pao3
+ pao4
+ pei1
+ pei2
+ pei3
+ pei4
+ pen1
+ pen2
+ pen3
+ pen4
+ peng1
+ peng2
+ peng3
+ peng4
+ pi1
+ pi2
+ pi3
+ pi4
+ pian1
+ pian2
+ pian3
+ pian4
+ piao1
+ piao2
+ piao3
+ piao4
+ pie1
+ pie3
+ pie4
+ pin1
+ pin2
+ pin3
+ pin4
+ ping1
+ ping2
+ ping3
+ ping4
+ po
+ po1
+ po2
+ po3
+ po4
+ pou1
+ pou2
+ pou3
+ pou4
+ pu
+ pu1
+ pu2
+ pu3
+ pu4
+ qi
+ qi1
+ qi2
+ qi3
+ qi4
+ qia1
+ qia2
+ qia3
+ qia4
+ qian
+ qian1
+ qian2
+ qian3
+ qian4
+ qiang1
+ qiang2
+ qiang3
+ qiang4
+ qiao1
+ qiao2
+ qiao3
+ qiao4
+ qie1
+ qie2
+ qie3
+ qie4
+ qin1
+ qin2
+ qin3
+ qin4
+ qing
+ qing1
+ qing2
+ qing3
+ qing4
+ qiong1
+ qiong2
+ qiong4
+ qiu1
+ qiu2
+ qiu3
+ qiu4
+ qu
+ qu1
+ qu2
+ qu3
+ qu4
+ quan
+ quan1
+ quan2
+ quan3
+ quan4
+ que1
+ que2
+ que4
+ qun1
+ qun2
+ qun3
+ ran2
+ ran3
+ ran4
+ rang1
+ rang2
+ rang3
+ rang4
+ rao2
+ rao3
+ rao4
+ re2
+ re3
+ re4
+ ren2
+ ren3
+ ren4
+ reng1
+ reng2
+ reng4
+ ri4
+ rong
+ rong1
+ rong2
+ rong3
+ rong4
+ rou2
+ rou3
+ rou4
+ ru
+ ru2
+ ru3
+ ru4
+ rua2
+ ruan2
+ ruan3
+ ruan4
+ rui2
+ rui3
+ rui4
+ run2
+ run3
+ run4
+ ruo2
+ ruo4
+ sa
+ sa1
+ sa3
+ sa4
+ sai1
+ sai3
+ sai4
+ san
+ san1
+ san3
+ san4
+ sang1
+ sang3
+ sang4
+ sao1
+ sao3
+ sao4
+ se1
+ se4
+ sen1
+ sen3
+ seng1
+ seng4
+ sha
+ sha1
+ sha2
+ sha3
+ sha4
+ shai1
+ shai3
+ shai4
+ shan1
+ shan2
+ shan3
+ shan4
+ shang
+ shang1
+ shang3
+ shang4
+ shao1
+ shao2
+ shao3
+ shao4
+ she1
+ she2
+ she3
+ she4
+ shei2
+ shen1
+ shen2
+ shen3
+ shen4
+ sheng1
+ sheng2
+ sheng3
+ sheng4
+ shi
+ shi1
+ shi2
+ shi3
+ shi4
+ shou
+ shou1
+ shou2
+ shou3
+ shou4
+ shu1
+ shu2
+ shu3
+ shu4
+ shua1
+ shua3
+ shua4
+ shuai1
+ shuai3
+ shuai4
+ shuan1
+ shuan4
+ shuang1
+ shuang3
+ shuang4
+ shui
+ shui2
+ shui3
+ shui4
+ shun3
+ shun4
+ shuo1
+ shuo2
+ shuo4
+ si
+ si1
+ si2
+ si3
+ si4
+ song1
+ song2
+ song3
+ song4
+ sou1
+ sou3
+ sou4
+ su1
+ su2
+ su3
+ su4
+ suan1
+ suan3
+ suan4
+ sui1
+ sui2
+ sui3
+ sui4
+ sun1
+ sun3
+ sun4
+ suo
+ suo1
+ suo2
+ suo3
+ suo4
+ ta
+ ta1
+ ta2
+ ta3
+ ta4
+ tai
+ tai1
+ tai2
+ tai3
+ tai4
+ tan1
+ tan2
+ tan3
+ tan4
+ tang1
+ tang2
+ tang3
+ tang4
+ tao1
+ tao2
+ tao3
+ tao4
+ te
+ te4
+ tei1
+ teng1
+ teng2
+ teng4
+ ti
+ ti1
+ ti2
+ ti3
+ ti4
+ tian1
+ tian2
+ tian3
+ tian4
+ tiao
+ tiao1
+ tiao2
+ tiao3
+ tiao4
+ tie1
+ tie2
+ tie3
+ tie4
+ ting1
+ ting2
+ ting3
+ ting4
+ tong1
+ tong2
+ tong3
+ tong4
+ tou
+ tou1
+ tou2
+ tou3
+ tou4
+ tu
+ tu1
+ tu2
+ tu3
+ tu4
+ tuan1
+ tuan2
+ tuan3
+ tuan4
+ tui1
+ tui2
+ tui3
+ tui4
+ tun1
+ tun2
+ tun3
+ tun4
+ tuo1
+ tuo2
+ tuo3
+ tuo4
+ wa
+ wa1
+ wa2
+ wa3
+ wa4
+ wai
+ wai1
+ wai3
+ wai4
+ wan1
+ wan2
+ wan3
+ wan4
+ wang1
+ wang2
+ wang3
+ wang4
+ wei
+ wei1
+ wei2
+ wei3
+ wei4
+ wen
+ wen1
+ wen2
+ wen3
+ wen4
+ weng1
+ weng3
+ weng4
+ wo1
+ wo3
+ wo4
+ wong4
+ wu
+ wu1
+ wu2
+ wu3
+ wu4
+ xi1
+ xi2
+ xi3
+ xi4
+ xia1
+ xia2
+ xia3
+ xia4
+ xian
+ xian1
+ xian2
+ xian3
+ xian4
+ xiang1
+ xiang2
+ xiang3
+ xiang4
+ xiao
+ xiao1
+ xiao2
+ xiao3
+ xiao4
+ xie1
+ xie2
+ xie3
+ xie4
+ xin
+ xin1
+ xin2
+ xin3
+ xin4
+ xing
+ xing1
+ xing2
+ xing3
+ xing4
+ xiong1
+ xiong2
+ xiong3
+ xiong4
+ xiu1
+ xiu2
+ xiu3
+ xiu4
+ xu
+ xu1
+ xu2
+ xu3
+ xu4
+ xuan1
+ xuan2
+ xuan3
+ xuan4
+ xue1
+ xue2
+ xue3
+ xue4
+ xun1
+ xun2
+ xun4
+ ya
+ ya1
+ ya2
+ ya3
+ ya4
+ yan1
+ yan2
+ yan3
+ yan4
+ yang
+ yang1
+ yang2
+ yang3
+ yang4
+ yao1
+ yao2
+ yao3
+ yao4
+ ye
+ ye1
+ ye2
+ ye3
+ ye4
+ yi
+ yi1
+ yi2
+ yi3
+ yi4
+ yin
+ yin1
+ yin2
+ yin3
+ yin4
+ ying1
+ ying2
+ ying3
+ ying4
+ yo
+ yo1
+ yong1
+ yong2
+ yong3
+ yong4
+ you
+ you1
+ you2
+ you3
+ you4
+ yu
+ yu1
+ yu2
+ yu3
+ yu4
+ yuan1
+ yuan2
+ yuan3
+ yuan4
+ yue1
+ yue2
+ yue3
+ yue4
+ yun
+ yun1
+ yun2
+ yun3
+ yun4
+ za1
+ za2
+ za3
+ za4
+ zai1
+ zai3
+ zai4
+ zan
+ zan1
+ zan2
+ zan3
+ zan4
+ zang1
+ zang3
+ zang4
+ zao1
+ zao2
+ zao3
+ zao4
+ ze
+ ze2
+ ze4
+ zei2
+ zen
+ zen1
+ zen3
+ zen4
+ zeng1
+ zeng3
+ zeng4
+ zha
+ zha1
+ zha2
+ zha3
+ zha4
+ zhai1
+ zhai2
+ zhai3
+ zhai4
+ zhan1
+ zhan2
+ zhan3
+ zhan4
+ zhang
+ zhang1
+ zhang3
+ zhang4
+ zhao
+ zhao1
+ zhao2
+ zhao3
+ zhao4
+ zhe
+ zhe1
+ zhe2
+ zhe3
+ zhe4
+ zhei4
+ zhen1
+ zhen2
+ zhen3
+ zhen4
+ zheng1
+ zheng3
+ zheng4
+ zhi
+ zhi1
+ zhi2
+ zhi3
+ zhi4
+ zhong1
+ zhong3
+ zhong4
+ zhou1
+ zhou2
+ zhou3
+ zhou4
+ zhu1
+ zhu2
+ zhu3
+ zhu4
+ zhua1
+ zhua3
+ zhuai1
+ zhuai3
+ zhuai4
+ zhuan1
+ zhuan2
+ zhuan3
+ zhuan4
+ zhuang1
+ zhuang3
+ zhuang4
+ zhui1
+ zhui3
+ zhui4
+ zhun1
+ zhun3
+ zhun4
+ zhuo
+ zhuo1
+ zhuo2
+ zhuo4
+ zi
+ zi1
+ zi2
+ zi3
+ zi4
+ zong
+ zong1
+ zong3
+ zong4
+ zou1
+ zou3
+ zou4
+ zu1
+ zu2
+ zu3
+ zu4
+ zuan1
+ zuan3
+ zuan4
+ zui
+ zui1
+ zui2
+ zui3
+ zui4
+ zun1
+ zun2
+ zun3
+ zun4
+ zuo
+ zuo1
+ zuo2
+ zuo3
+ zuo4
+ Γͺ1
+ Γͺ2
+ Γͺ3
+ Γͺ4
egs/zipvoice/local/prepare_emilia.sh ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
4
+ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
5
+ export PYTHONPATH=../../:$PYTHONPATH
6
+
7
+ set -eou pipefail
8
+
9
+ stage=0
10
+ stop_stage=5
11
+ sampling_rate=24000
12
+ nj=32
13
+
14
+ dl_dir=$PWD/download
15
+
16
+ . scripts/parse_options.sh || exit 1
17
+
18
+ # All files generated by this script are saved in "data".
19
+ # You can safely remove "data" and rerun this script to regenerate it.
20
+ mkdir -p data
21
+
22
+ log() {
23
+ # This function is from espnet
24
+ local fname=${BASH_SOURCE[1]##*/}
25
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
26
+ }
27
+
28
+ log "dl_dir: $dl_dir"
29
+
30
+ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
31
+ log "Stage 0: Download data"
32
+
33
+ # Your download directory should look like this:
34
+ #
35
+ # download/Amphion___Emilia
36
+ # β”œβ”€β”€ metafile.yaml
37
+ # β”œβ”€β”€ raw
38
+ # β”‚ β”œβ”€β”€ DE
39
+ # β”‚ β”œβ”€β”€ EN
40
+ # β”‚ β”œβ”€β”€ FR
41
+ # β”‚ β”œβ”€β”€ JA
42
+ # β”‚ β”œβ”€β”€ KO
43
+ # β”‚ β”œβ”€β”€ openemilia_45batches.tar.gz
44
+ # β”‚ β”œβ”€β”€ openemilia_all.tar.gz
45
+ # β”‚ └── ZH
46
+ # └── README.md
47
+
48
+ if [ ! -d $dl_dir/Amphion___Emilia/raw ]; then
49
+ log "Please refer https://openxlab.org.cn/datasets/Amphion/Emilia to download the dataset."
50
+ exit(-1)
51
+ fi
52
+
53
+ fi
54
+
55
+ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
56
+ log "Stage 1: Prepare emilia manifests (EN and ZH only)"
57
+ # We assume that you have downloaded the Emilia corpus
58
+ # to $dl_dir/Amphion___Emilia
59
+ # see stage 0 for the directory structure
60
+ mkdir -p data/manifests
61
+ if [ ! -e data/manifests/.emilia.done ]; then
62
+ lhotse prepare emilia --lang en --num-jobs ${nj} $dl_dir/Amphion___Emilia data/manifests
63
+ lhotse prepare emilia --lang zh --num-jobs ${nj} $dl_dir/Amphion___Emilia data/manifests
64
+ touch data/manifests/.emilia.done
65
+ fi
66
+ fi
67
+
68
+ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
69
+ log "Stage 2: Preprocess Emilia dataset, mainly for cleaning"
70
+ mkdir -p data/manifests/splits_raw
71
+ if [ ! -e data/manifests/split_raw/.emilia.split.done ]; then
72
+ lhotse split-lazy data/manifests/emilia_cuts_EN.jsonl.gz data/manifests/splits_raw 10000
73
+ lhotse split-lazy data/manifests/emilia_cuts_ZH.jsonl.gz data/manifests/splits_raw 10000
74
+ touch data/manifests/splits_raw/.emilia.split.done
75
+ fi
76
+
77
+ mkdir -p data/manifests/splits
78
+
79
+ if [ ! -e data/manifests/splits/.emilia.preprocess.done ]; then
80
+ python3 local/preprocess_emilia.py --subset EN
81
+ python3 local/preprocess_emilia.py --subset ZH
82
+ touch data/manifests/splits/.emilia.preprocess.done
83
+ fi
84
+
85
+ fi
86
+
87
+
88
+ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
89
+ log "Stage 3: Add tokens to manifests"
90
+
91
+ mkdir -p data/manifests/tokenized_splits
92
+
93
+ if [ ! -e data/manifests/tokenized_splits/.emilia.preprocess.done ]; then
94
+ for subset in EN ZH; do
95
+ log "Tokenizing Emilia ${subset}"
96
+ python3 local/prepare_tokens_emilia.py \
97
+ --subset ${subset} \
98
+ --jobs ${nj} \
99
+ --source-dir data/manifests/splits/ \
100
+ --dest-dir data/manifests/tokenized_splits/
101
+ done
102
+ touch data/manifests/tokenized_splits/.emilia.preprocess.done
103
+ fi
104
+
105
+ fi
106
+
107
+ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
108
+ log "Stage 4: Extract Fbank for Emilia"
109
+ mkdir -p data/fbank/emilia_splits
110
+ if [ ! -e data/fbank/emilia_splits/.emilia.fbank.done ]; then
111
+ # You can speed up the extraction by distributing splits to multiple machines.
112
+ for subset in EN ZH; do
113
+ python3 -m zipvoice.bin.compute_fbank \
114
+ --source-dir data/manifests/tokenized_splits \
115
+ --dest-dir data/fbank/emilia_splits \
116
+ --dataset emilia \
117
+ --subset ${subset} \
118
+ --splits-cuts 1 \
119
+ --split-begin 0 \
120
+ --split-end 2000 \
121
+ --num-jobs ${nj}
122
+ done
123
+ touch data/fbank/emilia_splits/.emilia.fbank.done
124
+ fi
125
+
126
+ if [ ! -e data/fbank/emilia_cuts_EN.jsonl.gz ]; then
127
+ log "Combining EN fbank cuts and spliting EN dev set"
128
+ gunzip -c data/fbank/emilia_splits/emilia_cuts_EN.*.jsonl.gz > data/fbank/emilia_cuts_EN.jsonl
129
+ head -n 1500 data/fbank/emilia_cuts_EN.jsonl | gzip -c > data/fbank/emilia_cuts_EN_dev.jsonl.gz
130
+ sed -i '1,1500d' data/fbank/emilia_cuts_EN.jsonl
131
+ gzip data/fbank/emilia_cuts_EN.jsonl
132
+ fi
133
+
134
+ if [ ! -e data/fbank/emilia_cuts_ZH.jsonl.gz ]; then
135
+ log "Combining ZH fbank cuts and spliting ZH dev set"
136
+ gunzip -c data/fbank/emilia_splits/emilia_cuts_ZH.*.jsonl.gz > data/fbank/emilia_cuts_ZH.jsonl
137
+ head -n 1500 data/fbank/emilia_cuts_ZH.jsonl | gzip -c > data/fbank/emilia_cuts_ZH_dev.jsonl.gz
138
+ sed -i '1,1500d' data/fbank/emilia_cuts_ZH.jsonl
139
+ gzip data/fbank/emilia_cuts_ZH.jsonl
140
+ fi
141
+
142
+ fi
143
+
144
+ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
145
+ log "Stage 5: Generate token file"
146
+ if [ ! -e data/tokens_emilia.txt ]; then
147
+ python3 ./local/prepare_token_file_emilia.py --pinyin local/pinyin.txt --tokens data/tokens_emilia.txt
148
+ fi
149
+ fi
egs/zipvoice/local/prepare_libritts.sh ADDED
@@ -0,0 +1,100 @@
1
+ #!/usr/bin/env bash
2
+
3
+ # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
4
+ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
5
+ export PYTHONPATH=../../:$PYTHONPATH
6
+
7
+ set -eou pipefail
8
+
9
+ stage=0
10
+ stop_stage=5
11
+ sampling_rate=24000
12
+ nj=20
13
+
14
+ dl_dir=$PWD/download
15
+
16
+ . utils/parse_options.sh || exit 1
17
+
18
+ # All files generated by this script are saved in "data".
19
+ # You can safely remove "data" and rerun this script to regenerate it.
20
+ mkdir -p data
21
+
22
+ log() {
23
+ # This function is from espnet
24
+ local fname=${BASH_SOURCE[1]##*/}
25
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
26
+ }
27
+
28
+ log "dl_dir: $dl_dir"
29
+
30
+ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
31
+ log "Stage 0: Download data"
32
+
33
+ # If you have pre-downloaded it to /path/to/LibriTTS,
34
+ # you can create a symlink
35
+ #
36
+ # ln -sfv /path/to/LibriTTS $dl_dir/LibriTTS
37
+ #
38
+ if [ ! -d $dl_dir/LibriTTS ]; then
39
+ lhotse download libritts $dl_dir
40
+ fi
41
+
42
+ fi
43
+
44
+ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
45
+ log "Stage 1: Prepare LibriTTS manifest"
46
+ # We assume that you have downloaded the LibriTTS corpus
47
+ # to $dl_dir/LibriTTS
48
+
49
+ # We did not add tokens to this manifest, as on-the-fly
50
+ # tokenization with LibriTTSTokenizer is not slow.
51
+ mkdir -p data/manifests
52
+ if [ ! -e data/manifests/.libritts.done ]; then
53
+ lhotse prepare libritts --num-jobs ${nj} $dl_dir/LibriTTS data/manifests
54
+ touch data/manifests/.libritts.done
55
+ fi
56
+ fi
57
+
58
+ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
59
+ log "Stage 2: Compute Fbank for LibriTTS"
60
+ mkdir -p data/fbank
61
+
62
+ if [ ! -e data/fbank/.libritts.done ]; then
63
+ for subset in train-clean-100 train-clean-360 train-other-500 dev-clean test-clean; do
64
+ python3 -m zipvoice.bin.compute_fbank \
65
+ --source-dir data/manifests \
66
+ --dest-dir data/fbank \
67
+ --dataset libritts \
68
+ --subset ${subset} \
69
+ --sampling-rate $sampling_rate \
70
+ --num-jobs ${nj}
71
+ done
72
+ touch data/fbank/.libritts.done
73
+ fi
74
+
75
+ # Here we shuffle and combine the train-clean-100, train-clean-360 and
76
+ # train-other-500 together to form the training set.
77
+ if [ ! -f data/fbank/libritts_cuts_train-all-shuf.jsonl.gz ]; then
78
+ cat <(gunzip -c data/fbank/libritts_cuts_train-clean-100.jsonl.gz) \
79
+ <(gunzip -c data/fbank/libritts_cuts_train-clean-360.jsonl.gz) \
80
+ <(gunzip -c data/fbank/libritts_cuts_train-other-500.jsonl.gz) | \
81
+ shuf | gzip -c > data/fbank/libritts_cuts_train-all-shuf.jsonl.gz
82
+ fi
83
+
84
+
85
+ if [ ! -e data/fbank/.libritts-validated.done ]; then
86
+ log "Validating data/fbank for LibriTTS"
87
+ python3 ./utils/validate_manifest.py \
88
+ data/fbank/libritts_cuts_train-all-shuf.jsonl.gz
89
+ touch data/fbank/.libritts-validated.done
90
+ fi
91
+ fi
92
+
93
+ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
94
+ log "Stage 3: Generate token file"
95
+ if [ ! -e data/tokens_libritts.txt ]; then
96
+ python3 ./local/prepare_token_file_char.py \
97
+ --manifest data/fbank/libritts_cuts_train-all-shuf.jsonl.gz \
98
+ --tokens data/tokens_libritts.txt
99
+ fi
100
+ fi
egs/zipvoice/local/prepare_token_file_char.py ADDED
@@ -0,0 +1,67 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2024-2025 Xiaomi Corp. (authors: Wei Kang)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import argparse
19
+ import re
20
+ from collections import Counter
21
+ from pathlib import Path
22
+
23
+ from lhotse import load_manifest_lazy
24
+
25
+
26
+ def get_args():
27
+ parser = argparse.ArgumentParser()
28
+
29
+ parser.add_argument(
30
+ "--tokens",
31
+ type=Path,
32
+ help="Path to the dict that maps the text tokens to IDs",
33
+ )
34
+
35
+ parser.add_argument(
36
+ "--manifest",
37
+ type=Path,
38
+ help="Path to the manifest file",
39
+ )
40
+
41
+ return parser.parse_args()
42
+
43
+
44
+ def prepare_tokens(manifest_file, token_file):
45
+ counter = Counter()
46
+ manifest = load_manifest_lazy(manifest_file)
47
+ for cut in manifest:
48
+ line = re.sub(r"\s+", " ", cut.supervisions[0].text)
49
+ counter.update(line)
50
+
51
+ unique_chars = set(counter.keys())
52
+
53
+ if "_" in unique_chars:
54
+ unique_chars.remove("_")
55
+
56
+ sorted_chars = sorted(unique_chars, key=lambda char: counter[char], reverse=True)
57
+
58
+ result = ["_"] + sorted_chars
59
+
60
+ with open(token_file, "w", encoding="utf-8") as file:
61
+ for index, char in enumerate(result):
62
+ file.write(f"{char}\t{index}\n")
63
+
64
+
65
+ if __name__ == "__main__":
66
+ args = get_args()
67
+ prepare_tokens(args.manifest, args.tokens)
egs/zipvoice/local/prepare_token_file_emilia.py ADDED
@@ -0,0 +1,91 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2024 Xiaomi Corp. (authors: Zengwei Yao,
3
+ # Wei Kang)
4
+ #
5
+ # See ../../../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+
20
+ """
21
+ This file generates the file that maps tokens to IDs.
22
+ """
23
+
24
+ import argparse
25
+ import logging
26
+ from pathlib import Path
27
+ from typing import List
28
+
29
+ from piper_phonemize import get_espeak_map
30
+ from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials
31
+
32
+
33
+ def get_args():
34
+ parser = argparse.ArgumentParser()
35
+
36
+ parser.add_argument(
37
+ "--tokens",
38
+ type=Path,
39
+ default=Path("data/tokens_emilia.txt"),
40
+ help="Path to the dict that maps the text tokens to IDs",
41
+ )
42
+
43
+ parser.add_argument(
44
+ "--pinyin",
45
+ type=Path,
46
+ default=Path("resources/pinyin.txt"),
47
+ help="Path to the all unique pinyin",
48
+ )
49
+
50
+ return parser.parse_args()
51
+
52
+
53
+ def get_pinyin_tokens(pinyin: Path) -> List[str]:
54
+ phones = set()
55
+ with open(pinyin, "r") as f:
56
+ for line in f:
57
+ x = line.strip()
58
+ initial = to_initials(x, strict=False)
59
+ # don't want to share tokens with espeak tokens, so use tone3 style
60
+ finals = to_finals_tone3(x, strict=False, neutral_tone_with_five=True)
61
+ if initial != "":
62
+ # don't want to share tokens with espeak tokens,
63
+ # so add a '0' after each initial
64
+ phones.add(initial + "0")
65
+ if finals != "":
66
+ phones.add(finals)
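+ # e.g., assuming standard pypinyin behavior: "zhuan4" yields the initial
+ # token "zh0" and the final token "uan4", while the neutral-tone syllable
+ # "zhuo" yields "zh0" and "uo5" (neutral_tone_with_five=True marks the
+ # neutral tone with a trailing "5")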
67
+ return sorted(phones)
68
+
69
+
70
+ def get_token2id(args):
71
+ """Get a dict that maps token to IDs, and save it to the given filename."""
72
+ all_tokens = get_espeak_map() # token: [token_id]
73
+ all_tokens = {token: token_id[0] for token, token_id in all_tokens.items()}
74
+ # sort by token_id
75
+ all_tokens = sorted(all_tokens.items(), key=lambda x: x[1])
76
+
77
+ all_pinyin = get_pinyin_tokens(args.pinyin)
78
+ with open(args.tokens, "w", encoding="utf-8") as f:
79
+ for token, token_id in all_tokens:
80
+ f.write(f"{token}\t{token_id}\n")
81
+ num_espeak_tokens = len(all_tokens)
82
+ for i, pinyin in enumerate(all_pinyin):
83
+ f.write(f"{pinyin}\t{num_espeak_tokens + i}\n")
84
+
85
+
86
+ if __name__ == "__main__":
87
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
88
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
89
+
90
+ args = get_args()
91
+ get_token2id(args)
egs/zipvoice/local/prepare_tokens_emilia.py ADDED
@@ -0,0 +1,88 @@
1
+ """
2
+ This file reads the texts in the given manifest and saves the new cuts with phoneme tokens.
3
+ """
4
+
5
+ import argparse
6
+ import glob
7
+ import logging
8
+ from concurrent.futures import ProcessPoolExecutor as Pool
9
+ from pathlib import Path
10
+
11
+ from lhotse import load_manifest_lazy
12
+
13
+ from zipvoice.tokenizer.tokenizer import add_tokens
14
+
15
+
16
+ def get_args():
17
+ parser = argparse.ArgumentParser()
18
+
19
+ parser.add_argument(
20
+ "--subset",
21
+ type=str,
22
+ help="Subset of emilia, (ZH, EN, etc.)",
23
+ )
24
+
25
+ parser.add_argument(
26
+ "--jobs",
27
+ type=int,
28
+ default=50,
29
+ help="Number of jobs to processing.",
30
+ )
31
+
32
+ parser.add_argument(
33
+ "--source-dir",
34
+ type=str,
35
+ default="data/manifests/splits",
36
+ help="The source directory of manifest files.",
37
+ )
38
+
39
+ parser.add_argument(
40
+ "--dest-dir",
41
+ type=str,
42
+ help="The destination directory of manifest files.",
43
+ )
44
+
45
+ return parser.parse_args()
46
+
47
+
48
+ def prepare_tokens_emilia(file_name: str, input_dir: Path, output_dir: Path):
49
+ logging.info(f"Processing {file_name}")
50
+ if (output_dir / file_name).is_file():
51
+ logging.info(f"{file_name} exists, skipping.")
52
+ return
53
+
54
+ try:
55
+ cut_set = load_manifest_lazy(input_dir / file_name)
56
+ cut_set = add_tokens(cut_set=cut_set, tokenizer="emilia")
57
+ cut_set.to_file(output_dir / file_name)
58
+ except Exception as e:
59
+ logging.error(f"Manifest {file_name} failed with error: {e}")
60
+ raise
61
+
62
+
63
+ if __name__ == "__main__":
64
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
65
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
66
+
67
+ args = get_args()
68
+
69
+ input_dir = Path(args.source_dir)
70
+ output_dir = Path(args.dest_dir)
71
+ output_dir.mkdir(parents=True, exist_ok=True)
72
+
73
+ cut_files = glob.glob(f"{args.source_dir}/emilia_cuts_{args.subset}.*.jsonl.gz")
74
+
75
+ with Pool(max_workers=args.jobs) as pool:
76
+ futures = [
77
+ pool.submit(
78
+ prepare_tokens_emilia, filename.split("/")[-1], input_dir, output_dir
79
+ )
80
+ for filename in cut_files
81
+ ]
82
+ for f in futures:
83
+ try:
84
+ f.result()
86
+ except Exception as e:
87
+ logging.error(f"Future failed with error: {e}")
88
+ logging.info("Processing done.")
egs/zipvoice/local/preprocess_emilia.py ADDED
@@ -0,0 +1,210 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2024-2025 Xiaomi Corp. (authors: Wei Kang)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+
19
+ """
20
+ This file reads the texts in the given manifest and saves the cleaned cuts.
21
+ """
22
+
23
+ import argparse
24
+ import glob
25
+ import logging
26
+ import os
27
+ import re
28
+ import unicodedata
29
+ from concurrent.futures import ProcessPoolExecutor as Pool
30
+ from pathlib import Path
31
+
32
+ from lhotse import load_manifest_lazy
33
+
34
+
35
+ def get_args():
36
+ parser = argparse.ArgumentParser()
37
+
38
+ parser.add_argument(
39
+ "--subset",
40
+ type=str,
41
+ help="Subset of emilia, (ZH, EN, etc.)",
42
+ )
43
+
44
+ parser.add_argument(
45
+ "--jobs",
46
+ type=int,
47
+ default=20,
48
+ help="Number of jobs to processing.",
49
+ )
50
+
51
+ parser.add_argument(
52
+ "--source-dir",
53
+ type=str,
54
+ default="data/manifests/splits_raw",
55
+ help="The source directory of manifest files.",
56
+ )
57
+
58
+ parser.add_argument(
59
+ "--dest-dir",
60
+ type=str,
61
+ default="data/manifests/splits",
62
+ help="The destination directory of manifest files.",
63
+ )
64
+
65
+ return parser.parse_args()
66
+
67
+
68
+ def tokenize_by_CJK_char(text: str) -> list:
69
+ """
70
+ Tokenize a line of text with CJK char.
71
+
72
+ Example:
73
+ input = "δ½ ε₯½δΈ–η•Œζ˜― hello world ηš„δΈ­ζ–‡"
74
+ output = ["δ½ ", "ε₯½", "δΈ–", "η•Œ", "是", "hello", "world", "ηš„", "δΈ­", "ζ–‡"]
75
+ """
76
+ pattern = re.compile(
77
+ r"([\u1100-\u11ff"
78
+ r"\u2e80-\ua4cf"
79
+ r"\ua840-\uD7AF"
80
+ r"\uF900-\uFAFF"
81
+ r"\uFE30-\uFE4F"
82
+ r"\uFF65-\uFFDC"
83
+ r"\U00020000-\U0002FFFF])"
84
+ )
85
+ chars = pattern.split(text.strip())
86
+ merged = " ".join([w.strip() for w in chars if w.strip()])
87
+ return merged.split()
88
+
89
+
90
+ def is_hangul(char):
91
+ letters = unicodedata.normalize("NFD", char)
92
+ return all(
93
+ ["\u1100" <= c <= "\u11ff" or "\u3131" <= c <= "\u318e" for c in letters]
94
+ )
95
+
96
+
97
+ def is_japanese(char):
98
+ return any(
99
+ [
100
+ start <= char <= end
101
+ for start, end in [
102
+ ("\u3041", "\u3096"),
103
+ ("\u30a0", "\u30ff"),
104
+ ("\uff5f", "\uff9f"),
105
+ ("\u31f0", "\u31ff"),
106
+ ("\u3220", "\u3243"),
107
+ ("\u3280", "\u337f"),
108
+ ]
109
+ ]
110
+ )
111
+
112
+
113
+ def is_chinese(char):
114
+ if char >= "\u4e00" and char <= "\u9fa5":
115
+ return True
116
+ else:
117
+ return False
118
+
119
+
120
+ def is_alphabet(char):
121
+ if (char >= "\u0041" and char <= "\u005a") or (
122
+ char >= "\u0061" and char <= "\u007a"
123
+ ):
124
+ return True
125
+ else:
126
+ return False
127
+
128
+
129
+ def preprocess_emilia(file_name: str, input_dir: Path, output_dir: Path):
130
+ logging.info(f"Processing {file_name}")
131
+ if (output_dir / file_name).is_file():
132
+ logging.info(f"{file_name} exists, skipping.")
133
+ return
134
+
135
+ def _filter_cut(cut):
136
+ text = cut.supervisions[0].text
137
+ duration = cut.supervisions[0].duration
138
+ chinese = []
139
+ english = []
140
+
141
+ # only contains chinese and space and alphabets
142
+ clean_chars = []
143
+ for x in text:
144
+ if is_hangul(x):
145
+ logging.warning(f"Delete cut with text containing Korean : {text}")
146
+ return False
147
+ if is_japanese(x):
148
+ logging.warning(f"Delete cut with text containing Japanese : {text}")
149
+ return False
150
+ if is_chinese(x):
151
+ chinese.append(x)
152
+ clean_chars.append(x)
153
+ if is_alphabet(x):
154
+ english.append(x)
155
+ clean_chars.append(x)
156
+ if x == " ":
157
+ clean_chars.append(x)
158
+ if len(english) + len(chinese) == 0:
159
+ logging.warning(f"Delete cut with text has no valid chars : {text}")
160
+ return False
161
+
162
+ words = tokenize_by_CJK_char("".join(clean_chars))
163
+ for i in range(len(words) - 10):
164
+ if words[i : i + 10].count(words[i]) == 10:
165
+ logging.warning(f"Delete cut with text with too much repeats : {text}")
166
+ return False
167
+ # word speed, 20 - 600 / minute
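+ # e.g. a 30-word cut must last between 30 / 600 * 60 = 3 s and
+ # 30 / 20 * 60 = 90 s; anything outside that range is treated as a
+ # text-audio mismatch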
168
+ if duration < len(words) / 600 * 60 or duration > len(words) / 20 * 60:
169
+ logging.warning(
170
+ f"Delete cut with audio text mismatch, duration : {duration}s, "
171
+ f"words : {len(words)}, text : {text}"
172
+ )
173
+ return False
174
+ return True
175
+
176
+ try:
177
+ cut_set = load_manifest_lazy(input_dir / file_name)
178
+ cut_set = cut_set.filter(_filter_cut)
179
+ cut_set.to_file(output_dir / file_name)
180
+ except Exception as e:
181
+ logging.error(f"Manifest {file_name} failed with error: {e}")
182
+ # Remove the possibly half-written output so that a rerun redoes this split.
+ if (output_dir / file_name).is_file():
+ os.remove(str(output_dir / file_name))
183
+
184
+
185
+ if __name__ == "__main__":
186
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
187
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
188
+
189
+ args = get_args()
190
+
191
+ input_dir = Path(args.source_dir)
192
+ output_dir = Path(args.dest_dir)
193
+ output_dir.mkdir(parents=True, exist_ok=True)
194
+
195
+ cut_files = glob.glob(f"{args.source_dir}/emilia_cuts_{args.subset}.*.jsonl.gz")
196
+
197
+ with Pool(max_workers=args.jobs) as pool:
198
+ futures = [
199
+ pool.submit(
200
+ preprocess_emilia,
201
+ filename.split("/")[-1],
202
+ input_dir,
203
+ output_dir,
204
+ )
205
+ for filename in cut_files
206
+ ]
207
+ for f in futures:
208
+ f.result()
210
+ logging.info("Processing done.")
egs/zipvoice/run_custom.sh ADDED
@@ -0,0 +1,138 @@
1
+ #!/bin/bash
2
+
3
+ # This script is an example of training ZipVoice on your custom datasets from scratch.
4
+
5
+ # Add project root to PYTHONPATH
6
+ export PYTHONPATH=../../:$PYTHONPATH
7
+
8
+ # Set bash to 'debug' mode, it will exit on:
9
+ # -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
10
+ set -e
11
+ set -u
12
+ set -o pipefail
13
+
14
+ stage=1
15
+ stop_stage=6
16
+
17
+ # Number of jobs for data preparation
18
+ nj=20
19
+
20
+ # You can set `train_hours` and `max_len` according to statistics from
21
+ # the command `lhotse cut describe data/fbank/custom_cuts_train.jsonl.gz`.
22
+ # Set `train_hours` to "Total speech duration", and set `max_len` to 99% duration.
23
+
24
+ # Number of hours in training set, will affect the learning rate schedule
25
+ train_hours=500
26
+ # Maximum length (seconds) of the training utterance, will filter out longer utterances
27
+ max_len=20
28
+
29
+ # We suppose you have two TSV files: "data/raw/custom_train.tsv" and
30
+ # "data/raw/custom_dev.tsv", where "custom" is your dataset name,
31
+ # "train"/"dev" are used for training and validation respectively.
32
+
33
+ # Each line of the TSV files should be in one of the following formats:
34
+ # (1) `{uniq_id}\t{text}\t{wav_path}` if the text corresponds to the full wav,
35
+ # (2) `{uniq_id}\t{text}\t{wav_path}\t{start_time}\t{end_time}` if text corresponds
36
+ # to part of the wav. The start_time and end_time specify the start and end
37
+ # times of the text within the wav, which should be in seconds.
38
+ # > Note: {uniq_id} must be unique for each line.
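+ # A hypothetical example (with \t denoting the TAB separator; paths are
+ # placeholders):
+ #   utt_0001\tfirst sentence text\t/path/to/utt_0001.wav
+ #   utt_0002\tsecond sentence text\t/path/to/long_recording.wav\t3.25\t7.80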
39
+ for subset in train dev;do
40
+ file_path=data/raw/custom_${subset}.tsv
41
+ [ -f "$file_path" ] || { echo "Error: expect $file_path !" >&2; exit 1; }
42
+ done
43
+
44
+ ### Prepare the training data (1 - 3)
45
+
46
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
47
+ echo "Stage 1: Prepare manifests for custom dataset from tsv files"
48
+
49
+ for subset in train dev;do
50
+ python3 -m zipvoice.bin.prepare_dataset \
51
+ --tsv-path data/raw/custom_${subset}.tsv \
52
+ --prefix custom \
53
+ --subset ${subset} \
54
+ --num-jobs ${nj} \
55
+ --output-dir data/manifests
56
+ done
57
+ # The output manifest files are "data/manifests/custom_cuts_train.jsonl.gz".
58
+ # and "data/manifests/custom_cuts_dev.jsonl.gz".
59
+
60
+ # We did not add tokens to the manifests, as on-the-fly tokenization
61
+ # with the simple tokenizer used in this example is not slow.
62
+ # If you change to a complex tokenizer, e.g., with g2p and heavy text normalization,
63
+ # you may need to add tokens to the manifests to speed up the training.
64
+ # Refer to the fine-tuning example for adding tokens to the manifests.
65
+ fi
66
+
67
+ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
68
+ echo "Stage 2: Compute Fbank for custom dataset"
69
+ # You can skip this step and use `--on-the-fly-feats 1` in training stage
70
+ for subset in train dev; do
71
+ python3 -m zipvoice.bin.compute_fbank \
72
+ --source-dir data/manifests \
73
+ --dest-dir data/fbank \
74
+ --dataset custom \
75
+ --subset ${subset} \
76
+ --num-jobs ${nj}
77
+ done
78
+ fi
79
+
80
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
81
+ echo "Stage 3: Prepare tokens file for custom dataset"
82
+ # In this example, we use the simplest tokenizer that
83
+ # treat every character as a token.
84
+ python3 ./local/prepare_token_file_char.py \
85
+ --manifest data/manifests/custom_cuts_train.jsonl.gz \
86
+ --tokens data/tokens_custom.txt
87
+ fi
88
+
89
+
90
+ ### Training (4 - 5)
91
+
92
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
93
+ echo "Stage 4: Train the ZipVoice model"
94
+
95
+ [ -z "$train_hours" ] && { echo "Error: train_hours is not set!" >&2; exit 1; }
96
+ [ -z "$max_len" ] && { echo "Error: max_len is not set!" >&2; exit 1; }
97
+
98
+ # lr-hours will be set according to the `train_hours`,
99
+ # i.e., lr_hours = 1000 * (train_hours ** 0.3).
100
+ lr_hours=$(python3 -c "print(round(1000 * ($train_hours ** 0.3)))" )
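+ # e.g. with the default train_hours=500 above: round(1000 * 500 ** 0.3) = 6452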
101
+ python3 -m zipvoice.bin.train_zipvoice \
102
+ --world-size 4 \
103
+ --use-fp16 1 \
104
+ --num-iters 60000 \
105
+ --max-duration 500 \
106
+ --lr-hours ${lr_hours} \
107
+ --max-len ${max_len} \
108
+ --model-config conf/zipvoice_base.json \
109
+ --tokenizer simple \
110
+ --token-file data/tokens_custom.txt \
111
+ --dataset custom \
112
+ --train-manifest data/fbank/custom_cuts_train.jsonl.gz \
113
+ --dev-manifest data/fbank/custom_cuts_dev.jsonl.gz \
114
+ --exp-dir exp/zipvoice_custom
115
+ fi
116
+
117
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
118
+ echo "Stage 5: Average the checkpoints for ZipVoice"
119
+ python3 -m zipvoice.bin.generate_averaged_model \
120
+ --iter 60000 \
121
+ --avg 2 \
122
+ --model-name zipvoice \
123
+ --exp-dir exp/zipvoice_custom
124
+ # The generated model is exp/zipvoice_custom/iter-60000-avg-2.pt
125
+ fi
126
+
127
+ ### Inference with PyTorch models (6)
128
+
129
+ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
130
+ echo "Stage 6: Inference of the ZipVoice model"
131
+ python3 -m zipvoice.bin.infer_zipvoice \
132
+ --model-name zipvoice \
133
+ --model-dir exp/zipvoice_custom \
134
+ --checkpoint-name iter-60000-avg-2.pt \
135
+ --tokenizer simple \
136
+ --test-list test.tsv \
137
+ --res-dir results/test_custom
138
+ fi
egs/zipvoice/run_emilia.sh ADDED
@@ -0,0 +1,178 @@
1
+ #!/bin/bash
2
+
3
+ # This is an example script for training ZipVoice on the Emilia dataset.
4
+
5
+ # This script covers data preparation, ZipVoice training,
6
+ # ZipVoice-Distill training, ONNX export, and
7
+ # inference with all PyTorch and ONNX models.
8
+
9
+
10
+ # Add project root to PYTHONPATH
11
+ export PYTHONPATH=../../:$PYTHONPATH
12
+
13
+ # Set bash to 'debug' mode, it will exit on :
14
+ # -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
15
+ set -e
16
+ set -u
17
+ set -o pipefail
18
+
19
+ stage=1
20
+ stop_stage=12
21
+
22
+ #### Prepare datasets (1)
23
+
24
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
25
+ echo "Stage 1: Data Preparation for Emilia dataset"
26
+ bash local/prepare_emilia.sh
27
+ fi
28
+
29
+ ### Training ZipVoice (2 - 3)
30
+
31
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
32
+ echo "Stage 2: Train the ZipVoice model"
33
+ python3 -m zipvoice.bin.train_zipvoice \
34
+ --world-size 8 \
35
+ --use-fp16 1 \
36
+ --num-epochs 11 \
37
+ --max-duration 500 \
38
+ --lr-hours 30000 \
39
+ --model-config conf/zipvoice_base.json \
40
+ --tokenizer emilia \
41
+ --token-file data/tokens_emilia.txt \
42
+ --dataset emilia \
43
+ --manifest-dir data/fbank \
44
+ --exp-dir exp/zipvoice
45
+ fi
46
+
47
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
48
+ echo "Stage 3: Average the checkpoints for ZipVoice"
49
+ python3 -m zipvoice.bin.generate_averaged_model \
50
+ --epoch 11 \
51
+ --avg 4 \
52
+ --model-name zipvoice \
53
+ --exp-dir exp/zipvoice
54
+ # The generated model is exp/zipvoice/epoch-11-avg-4.pt
55
+ fi
56
+
57
+ #### (Optional) Training ZipVoice-Distill model (4 - 6)
58
+
59
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
60
+ echo "Stage 4: Train the ZipVoice-Distill model (first stage)"
61
+ python3 -m zipvoice.bin.train_zipvoice_distill \
62
+ --world-size 8 \
63
+ --use-fp16 1 \
64
+ --num-iters 60000 \
65
+ --max-duration 500 \
66
+ --base-lr 0.0005 \
67
+ --tokenizer emilia \
68
+ --token-file data/tokens_emilia.txt \
69
+ --dataset emilia \
70
+ --manifest-dir data/fbank \
71
+ --teacher-model exp/zipvoice/epoch-11-avg-4.pt \
72
+ --distill-stage first \
73
+ --exp-dir exp/zipvoice_distill_1stage
74
+ fi
75
+
76
+
77
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
78
+ echo "Stage 5: Average the checkpoints for ZipVoice-Distill (first stage)"
79
+ python3 -m zipvoice.bin.generate_averaged_model \
80
+ --iter 60000 \
81
+ --avg 7 \
82
+ --model-name zipvoice_distill \
83
+ --exp-dir exp/zipvoice_distill_1stage
84
+ # The generated model is exp/zipvoice_distill_1stage/iter-60000-avg-7.pt
85
+ fi
86
+
87
+ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
88
+ echo "Stage 6: Train the ZipVoice-Distill model (second stage)"
89
+
90
+ python3 -m zipvoice.bin.train_zipvoice_distill \
91
+ --world-size 8 \
92
+ --use-fp16 1 \
93
+ --num-iters 2000 \
94
+ --save-every-n 1000 \
95
+ --max-duration 500 \
96
+ --base-lr 0.0001 \
97
+ --model-config conf/zipvoice_base.json \
98
+ --tokenizer emilia \
99
+ --token-file data/tokens_emilia.txt \
100
+ --dataset emilia \
101
+ --manifest-dir data/fbank \
102
+ --teacher-model exp/zipvoice_distill_1stage/iter-60000-avg-7.pt \
103
+ --distill-stage second \
104
+ --exp-dir exp/zipvoice_distill
105
+ fi
106
+
107
+ ### Export ONNX model (7 - 8)
108
+
109
+ if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
110
+ echo "Stage 7: Export ZipVoice ONNX model"
111
+ python3 -m zipvoice.bin.onnx_export \
112
+ --model-name zipvoice \
113
+ --model-dir exp/zipvoice/ \
114
+ --checkpoint-name epoch-11-avg-4.pt \
115
+ --onnx-model-dir exp/zipvoice/
116
+ fi
117
+
118
+ if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
119
+ echo "Stage 8: Export ZipVoice-Distill ONNX model"
120
+ python3 -m zipvoice.bin.onnx_export \
121
+ --model-name zipvoice_distill \
122
+ --model-dir exp/zipvoice_distill/ \
123
+ --checkpoint-name checkpoint-2000.pt \
124
+ --onnx-model-dir exp/zipvoice_distill/
125
+ fi
126
+
127
+
128
+ ### Inference with PyTorch and ONNX models (9 - 12)
129
+
130
+ if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then
131
+ echo "Stage 9: Inference of the ZipVoice model"
132
+ python3 -m zipvoice.bin.infer_zipvoice \
133
+ --model-name zipvoice \
134
+ --model-dir exp/zipvoice/ \
135
+ --checkpoint-name epoch-11-avg-4.pt \
136
+ --tokenizer emilia \
137
+ --test-list test.tsv \
138
+ --res-dir results/test \
139
+ --num-step 16 \
140
+ --guidance-scale 1
141
+ fi
142
+
143
+
144
+ if [ ${stage} -le 10 ] && [ ${stop_stage} -ge 10 ]; then
145
+ echo "Stage 10: Inference of the ZipVoice-Distill model"
146
+ python3 -m zipvoice.bin.infer_zipvoice \
147
+ --model-name zipvoice_distill \
148
+ --model-dir exp/zipvoice_distill/ \
149
+ --checkpoint-name checkpoint-2000.pt \
150
+ --tokenizer emilia \
151
+ --test-list test.tsv \
152
+ --res-dir results/test_distill \
153
+ --num-step 8 \
154
+ --guidance-scale 3
155
+ fi
156
+
157
+
158
+ if [ ${stage} -le 11 ] && [ ${stop_stage} -ge 11 ]; then
159
+ echo "Stage 11: Inference with ZipVoice ONNX model"
160
+ python3 -m zipvoice.bin.infer_zipvoice_onnx \
161
+ --model-name zipvoice \
162
+ --onnx-int8 False \
163
+ --model-dir exp/zipvoice \
164
+ --tokenizer emilia \
165
+ --test-list test.tsv \
166
+ --res-dir results/test_onnx
167
+ fi
168
+
169
+ if [ ${stage} -le 12 ] && [ ${stop_stage} -ge 12 ]; then
170
+ echo "Stage 12: Inference with ZipVoic-Distill ONNX model"
171
+ python3 -m zipvoice.bin.infer_zipvoice_onnx \
172
+ --model-name zipvoice_distill \
173
+ --onnx-int8 False \
174
+ --model-dir exp/zipvoice_distill \
175
+ --tokenizer emilia \
176
+ --test-list test.tsv \
177
+ --res-dir results/test_distill_onnx
178
+ fi
egs/zipvoice/run_eval.sh ADDED
@@ -0,0 +1,142 @@
1
+ #!/bin/bash
2
+
3
+ # This script is an example of evaluating TTS models with the objective metrics reported in the ZipVoice paper.
4
+
5
+ # Add project root to PYTHONPATH
6
+ export PYTHONPATH=../../:$PYTHONPATH
7
+
8
+ # Set bash to 'debug' mode, it will exit on:
9
+ # -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
10
+ set -e
11
+ set -u
12
+ set -o pipefail
13
+
14
+ stage=1
15
+ stop_stage=7
16
+
17
+ download_dir=download/
18
+
19
+ # Uncomment this line to use HF mirror
20
+ # export HF_ENDPOINT=https://hf-mirror.com
21
+
22
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
23
+ echo "Stage 1: Download test sets (LibriSpeech-PC and Seed-TTS)"
24
+
25
+ hf_repo=k2-fsa/TTS_eval_datasets
26
+ mkdir -p ${download_dir}/
27
+ for file in librispeech_pc_testset.tar.gz seedtts_testset.tar.gz; do
28
+ echo "Downloading ${file}..."
29
+ huggingface-cli download \
30
+ --repo-type dataset \
31
+ --local-dir ${download_dir}/ \
32
+ ${hf_repo} \
33
+ ${file}
34
+ echo "Extracting ${file}..."
35
+ tar -xzf ${download_dir}/${file} -C ${download_dir}/
36
+ done
37
+ fi
38
+
39
+
40
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
41
+ echo "Stage 2: Download all required evaluation models"
42
+ hf_repo=k2-fsa/TTS_eval_models
43
+ mkdir -p ${download_dir}/tts_eval_models
44
+ huggingface-cli download \
45
+ --local-dir ${download_dir}/tts_eval_models \
46
+ ${hf_repo}
47
+ fi
48
+
49
+
50
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
51
+ echo "Stage 3: Inference with the pre-trained ZipVoice model from huggingface"
52
+
53
+ for testset in librispeech_pc seedtts_en seedtts_zh; do
54
+
55
+ if [ "$testset" = "librispeech_pc" ]; then
56
+ test_tsv=${download_dir}/librispeech_pc_testset/test.tsv
57
+
58
+ elif [ "$testset" = "seedtts_en" ]; then
59
+ test_tsv=${download_dir}/seedtts_testset/en/test.tsv
60
+ elif [ "$testset" = "seedtts_zh" ]; then
61
+ test_tsv=${download_dir}/seedtts_testset/zh/test.tsv
62
+ else
63
+ echo "Error: unknown testset ${testset}" >&2
64
+ exit 1
65
+ fi
66
+ echo "Inference on tetset ${testset}..."
67
+ python3 -m zipvoice.bin.infer_zipvoice \
68
+ --model-name zipvoice \
69
+ --test-list ${test_tsv} \
70
+ --res-dir results/${testset}
71
+ done
72
+ fi
73
+
74
+
75
+
76
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
77
+ echo "Stage 4: Evaluation on LibriSpeech-PC"
78
+ model_path=${download_dir}/tts_eval_models
79
+ wav_path=results/librispeech_pc
80
+ test_tsv=${download_dir}/librispeech_pc_testset/test.tsv
81
+ # Use LibriSpeech style transcripts for WER evaluation
82
+ transcript_tsv=${download_dir}/librispeech_pc_testset/transcript.tsv
83
+
84
+ python3 -m zipvoice.eval.speaker_similarity.sim \
85
+ --wav-path ${wav_path} \
86
+ --test-list ${test_tsv} \
87
+ --model-dir ${model_path}
88
+
89
+ python3 -m zipvoice.eval.wer.hubert \
90
+ --wav-path ${wav_path} \
91
+ --test-list ${transcript_tsv} \
92
+ --model-dir ${model_path}
93
+
94
+ python3 -m zipvoice.eval.mos.utmos \
95
+ --wav-path ${wav_path} \
96
+ --model-dir ${model_path}
97
+ fi
98
+
99
+
100
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
101
+ echo "Stage 5: Evaluation on Seed-TTS test en"
102
+ model_path=${download_dir}/tts_eval_models
103
+ wav_path=results/seedtts_en
104
+ test_tsv=${download_dir}/seedtts_testset/en/test.tsv
105
+
106
+ python3 -m zipvoice.eval.speaker_similarity.sim \
107
+ --wav-path ${wav_path} \
108
+ --test-list ${test_tsv} \
109
+ --model-dir ${model_path}
110
+
111
+ python3 -m zipvoice.eval.wer.seedtts \
112
+ --wav-path ${wav_path} \
113
+ --test-list ${test_tsv} \
114
+ --model-dir ${model_path} \
115
+ --lang en
116
+
117
+ python3 -m zipvoice.eval.mos.utmos \
118
+ --wav-path ${wav_path} \
119
+ --model-dir ${model_path}
120
+ fi
121
+
122
+ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
123
+ echo "Stage 6: Evaluation on Seed-TTS test en"
124
+ model_path=${download_dir}/tts_eval_models
125
+ wav_path=results/seedtts_zh
126
+ test_tsv=${download_dir}/seedtts_testset/zh/test.tsv
127
+
128
+ python3 -m zipvoice.eval.speaker_similarity.sim \
129
+ --wav-path ${wav_path} \
130
+ --test-list ${test_tsv} \
131
+ --model-dir ${model_path}
132
+
133
+ python3 -m zipvoice.eval.wer.seedtts \
134
+ --wav-path ${wav_path} \
135
+ --test-list ${test_tsv} \
136
+ --model-dir ${model_path} \
137
+ --lang zh
138
+
139
+ python3 -m zipvoice.eval.mos.utmos \
140
+ --wav-path ${wav_path} \
141
+ --model-dir ${model_path}
142
+ fi
egs/zipvoice/run_finetune.sh ADDED
@@ -0,0 +1,175 @@
1
+ #!/bin/bash
2
+
3
+ # This script is an example of fine-tuning ZipVoice on your custom datasets.
4
+
5
+ # Add project root to PYTHONPATH
6
+ # export PYTHONPATH=../../:$PYTHONPATH
7
+
8
+ # Set bash to 'debug' mode, it will exit on:
9
+ # -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
10
+ set -e
11
+ set -u
12
+ set -o pipefail
13
+
14
+ stage=1
15
+ stop_stage=6
16
+
17
+ # Number of jobs for data preparation
18
+ nj=4
19
+
20
+ # Whether the language of the training data is Chinese or English
21
+ is_zh_en=0
22
+
23
+ # Language identifier, used when language is not Chinese or English
24
+ # see https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md
25
+ # Example of French: lang=fr
26
+ lang=vi
27
+
28
+ if [ $is_zh_en -eq 1 ]; then
29
+ tokenizer=emilia
30
+ else
31
+ tokenizer=espeak
32
+ [ "$lang" = "default" ] && { echo "Error: lang is not set!" >&2; exit 1; }
33
+ fi
34
+
35
+ # You can set `max_len` according to statistics from the command
36
+ # `lhotse cut describe data/fbank/custom_cuts_train.jsonl.gz`.
37
+ # Set `max_len` to the 99th-percentile (99%) duration.
38
+
39
+ # Maximum length (seconds) of the training utterance, will filter out longer utterances
40
+ max_len=25
41
+
42
+ # Download directory for pre-trained models
43
+ download_dir=download
44
+
45
+ # We suppose you have two TSV files: "data/raw/custom_train.tsv" and
46
+ # "data/raw/custom_dev.tsv", where "custom" is your dataset name,
47
+ # "train"/"dev" are used for training and validation respectively.
48
+
49
+ # Each line of the TSV files should be in one of the following formats:
50
+ # (1) `{uniq_id}\t{text}\t{wav_path}` if the text corresponds to the full wav,
51
+ # (2) `{uniq_id}\t{text}\t{wav_path}\t{start_time}\t{end_time}` if text corresponds
52
+ # to part of the wav. The start_time and end_time specify the start and end
53
+ # times of the text within the wav, which should be in seconds.
54
+ # > Note: {uniq_id} must be unique for each line.
55
+ # for subset in train dev;do
56
+ # file_path=data/raw/custom_${subset}.tsv
57
+ # [ -f "$file_path" ] || { echo "Error: expect $file_path !" >&2; exit 1; }
58
+ # done
59
+
60
+ # ### Prepare the training data (1 - 4)
61
+
62
+ # if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
63
+ # echo "Stage 1: Prepare manifests for custom dataset from tsv files"
64
+
65
+ # for subset in train dev;do
66
+ # python3 -m zipvoice.bin.prepare_dataset \
67
+ # --tsv-path data/raw/custom_${subset}.tsv \
68
+ # --prefix custom-finetune \
69
+ # --subset raw_${subset} \
70
+ # --num-jobs ${nj} \
71
+ # --output-dir data/manifests
72
+ # done
73
+ # # The output manifest files are "data/manifests/custom-finetune_cuts_raw_train.jsonl.gz".
74
+ # # and "data/manifests/custom-finetune_cuts_raw_dev.jsonl.gz".
75
+ # fi
76
+
77
+
78
+ # if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
79
+ # echo "Stage 2: Add tokens to manifests"
80
+ # # For "emilia" and "espeak" tokenizers, it's better to prepare the tokens
81
+ # # before training. Otherwise, the on-the-fly tokenization can significantly
82
+ # # slow down the training.
83
+ # for subset in train dev;do
84
+ # python3 -m zipvoice.bin.prepare_tokens \
85
+ # --input-file data/manifests/custom-finetune_cuts_raw_${subset}.jsonl.gz \
86
+ # --output-file data/manifests/custom-finetune_cuts_${subset}.jsonl.gz \
87
+ # --tokenizer ${tokenizer} \
88
+ # --lang ${lang}
89
+ # done
90
+ # # The output manifest files are "data/manifests/custom-finetune_cuts_train.jsonl.gz".
91
+ # # and "data/manifests/custom-finetune_cuts_dev.jsonl.gz".
92
+ # fi
93
+
94
+ # if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
95
+ # echo "Stage 3: Compute Fbank for custom dataset"
96
+ # # You can skip this step and use `--on-the-fly-feats 1` in training stage
97
+ # for subset in train dev; do
98
+ # python3 -m zipvoice.bin.compute_fbank \
99
+ # --source-dir data/manifests \
100
+ # --dest-dir data/fbank \
101
+ # --dataset custom-finetune \
102
+ # --subset ${subset} \
103
+ # --num-jobs ${nj}
104
+ # done
105
+ # fi
106
+
107
+ # # if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
108
+ # # echo "Stage 4: Download pre-trained model, tokens file, and model config"
109
+ # # # Uncomment this line to use HF mirror
110
+ # # # export HF_ENDPOINT=https://hf-mirror.com
111
+ # # hf_repo=k2-fsa/ZipVoice
112
+ # # mkdir -p ${download_dir}
113
+ # # for file in model.pt tokens.txt model.json; do
114
+ # # huggingface-cli download \
115
+ # # --local-dir ${download_dir} \
116
+ # # ${hf_repo} \
117
+ # # zipvoice/${file}
118
+ # # done
119
+ # # fi
120
+
121
+ # # ### Training ZipVoice (5 - 6)
122
+
123
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
124
+ echo "Stage 5: Fine-tune the ZipVoice model"
125
+
126
+ [ -z "$max_len" ] && { echo "Error: max_len is not set!" >&2; exit 1; }
127
+
128
+ python3 -m zipvoice.bin.train_zipvoice \
129
+ --world-size 1 \
130
+ --use-fp16 1 \
131
+ --finetune 1 \
132
+ --base-lr 0.00006 \
133
+ --num-epochs 2 \
134
+ --save-every-n 1000 \
135
+ --keep-last-k 4 \
136
+ --max-duration 650 \
137
+ --max-len ${max_len} \
138
+ --min-len 1 \
139
+ --model-config ${download_dir}/zipvoice/model.json \
140
+ --checkpoint ${download_dir}/zipvoice/model.pt \
141
+ --tokenizer ${tokenizer} \
142
+ --lang ${lang} \
143
+ --token-file ${download_dir}/zipvoice/tokens.txt \
144
+ --dataset custom \
145
+ --train-manifest data/fbank/train_all.jsonl.gz \
146
+ --dev-manifest data/fbank/dev_all.jsonl.gz \
147
+ --exp-dir exp/zipvoice_finetune
148
+
149
+ fi
150
+
151
+ # if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
152
+ # echo "Stage 6: Average the checkpoints for ZipVoice"
153
+ # python3 -m zipvoice.bin.generate_averaged_model \
154
+ # --iter 10000 \
155
+ # --avg 2 \
156
+ # --model-name zipvoice \
157
+ # --exp-dir exp/zipvoice_finetune
158
+ # # The generated model is exp/zipvoice_finetune/iter-10000-avg-2.pt
159
+ # fi
160
+
161
+ # ### Inference with PyTorch models (7)
162
+
163
+ # if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
164
+ # echo "Stage 7: Inference of the ZipVoice model"
165
+
166
+ # python3 -m zipvoice.bin.infer_zipvoice \
167
+ # --model-name zipvoice \
168
+ # --model-dir exp/zipvoice_finetune/ \
169
+ # --checkpoint-name iter-10000-avg-2.pt \
170
+ # --tokenizer ${tokenizer} \
171
+ # --lang ${lang} \
172
+ # --test-list test.tsv \
173
+ # --res-dir results/test_finetune\
174
+ # --num-step 16
175
+ # fi
egs/zipvoice/run_libritts.sh ADDED
@@ -0,0 +1,148 @@
1
+ #!/bin/bash
2
+
3
+ # This is an example script for training ZipVoice on the LibriTTS dataset.
4
+
5
+ # Add project root to PYTHONPATH
6
+ export PYTHONPATH=../../:$PYTHONPATH
7
+
8
+ # Set bash to 'debug' mode, it will exit on :
9
+ # -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
10
+ set -e
11
+ set -u
12
+ set -o pipefail
13
+
14
+ stage=1
15
+ stop_stage=9
16
+
17
+ #### Prepare datasets (1)
18
+
19
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
20
+ echo "Stage 1: Data Preparation for LibriTTS dataset"
21
+ bash local/prepare_libritts.sh
22
+ fi
23
+
24
+ ### Training ZipVoice (2 - 3)
25
+
26
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
27
+ echo "Stage 2: Train the ZipVoice model"
28
+ python3 -m zipvoice.bin.train_zipvoice \
29
+ --world-size 8 \
30
+ --use-fp16 0 \
31
+ --num-epochs 60 \
32
+ --max-duration 250 \
33
+ --lr-epochs 10 \
34
+ --max-len 20 \
35
+ --valid-by-epoch 1 \
36
+ --model-config conf/zipvoice_base.json \
37
+ --tokenizer libritts \
38
+ --token-file data/tokens_libritts.txt \
39
+ --dataset libritts \
40
+ --manifest-dir data/fbank \
41
+ --exp-dir exp/zipvoice_libritts
42
+ fi
43
+
44
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
45
+ echo "Stage 3: Average the checkpoints for ZipVoice"
46
+ python3 -m zipvoice.bin.generate_averaged_model \
47
+ --epoch 60 \
48
+ --avg 10 \
49
+ --model-name zipvoice \
50
+ --exp-dir exp/zipvoice_libritts
51
+ # The generated model is exp/zipvoice_libritts/epoch-60-avg-10.pt
52
+ fi
53
+
54
+ #### (Optional) Training ZipVoice-Distill model (4 - 7)
55
+
56
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
57
+ echo "Stage 4: Train the ZipVoice-Distill model (first stage)"
58
+ python3 -m zipvoice.bin.train_zipvoice_distill \
59
+ --world-size 8 \
60
+ --use-fp16 0 \
61
+ --num-epochs 6 \
62
+ --max-duration 250 \
63
+ --base-lr 0.001 \
64
+ --max-len 20 \
65
+ --valid-by-epoch 1 \
66
+ --model-config conf/zipvoice_base.json \
67
+ --tokenizer libritts \
68
+ --token-file data/tokens_libritts.txt \
69
+ --dataset "libritts" \
70
+ --manifest-dir "data/fbank" \
71
+ --teacher-model exp/zipvoice_libritts/epoch-60-avg-10.pt \
72
+ --distill-stage "first" \
73
+ --exp-dir exp/zipvoice_distill_1stage_libritts
74
+ fi
75
+
76
+
77
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
78
+ echo "Stage 5: Average the checkpoints for ZipVoice-Distill (first stage)"
79
+ python3 -m zipvoice.bin.generate_averaged_model \
80
+ --epoch 6 \
81
+ --avg 3 \
82
+ --model-name zipvoice_distill \
83
+ --exp-dir exp/zipvoice_distill_1stage_libritts
84
+ # The generated model is exp/zipvoice_distill_1stage_libritts/epoch-6-avg-3.pt
85
+ fi
86
+
87
+ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
88
+ echo "Stage 6: Train the ZipVoice-Distill model (second stage)"
89
+
90
+ python3 -m zipvoice.bin.train_zipvoice_distill \
91
+ --world-size 8 \
92
+ --use-fp16 1 \
93
+ --num-epochs 6 \
94
+ --max-duration 250 \
95
+ --base-lr 0.001 \
96
+ --max-len 20 \
97
+ --valid-by-epoch 1 \
98
+ --model-config conf/zipvoice_base.json \
99
+ --tokenizer libritts \
100
+ --token-file data/tokens_libritts.txt \
101
+ --dataset libritts \
102
+ --manifest-dir data/fbank \
103
+ --teacher-model exp/zipvoice_distill_1stage_libritts/epoch-6-avg-3.pt \
104
+ --distill-stage second \
105
+ --exp-dir exp/zipvoice_distill_libritts
106
+ fi
107
+
108
+
109
+ if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
110
+ echo "Stage 7: Average the checkpoints for ZipVoice-Distill (second stage)"
111
+ python3 -m zipvoice.bin.generate_averaged_model \
112
+ --epoch 6 \
113
+ --avg 3 \
114
+ --model-name zipvoice_distill \
115
+ --exp-dir exp/zipvoice_distill_libritts
116
+ # The generated model is exp/zipvoice_distill_libritts/epoch-6-avg-3.pt
117
+ fi
118
+
119
+ ### Inference with PyTorch models (8 - 9)
120
+
121
+ if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
122
+ echo "Stage 8: Inference of the ZipVoice model"
123
+ python3 -m zipvoice.bin.infer_zipvoice \
124
+ --model-name zipvoice \
125
+ --model-dir exp/zipvoice_libritts \
126
+ --checkpoint-name epoch-60-avg-10.pt \
127
+ --tokenizer libritts \
128
+ --test-list test.tsv \
129
+ --res-dir results/test_libritts \
130
+ --num-step 8 \
131
+ --guidance-scale 1 \
132
+ --t-shift 0.7
133
+ fi
134
+
135
+
136
+ if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then
137
+ echo "Stage 9: Inference of the ZipVoice-Distill model"
138
+ python3 -m zipvoice.bin.infer_zipvoice \
139
+ --model-name zipvoice_distill \
140
+ --model-dir exp/zipvoice_distill_libritts \
141
+ --checkpoint-name epoch-6-avg-3.pt \
142
+ --tokenizer libritts \
143
+ --test-list test.tsv \
144
+ --res-dir results/test_distill_libritts \
145
+ --num-step 4 \
146
+ --guidance-scale 3 \
147
+ --t-shift 0.7
148
+ fi
egs/zipvoice/utils/parse_options.sh ADDED
@@ -0,0 +1,97 @@
1
+ #!/usr/bin/env bash
2
+
3
+ # Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
4
+ # Arnab Ghoshal, Karel Vesely
5
+
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13
+ # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
14
+ # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
15
+ # MERCHANTABLITY OR NON-INFRINGEMENT.
16
+ # See the Apache 2 License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+
20
+ # Parse command-line options.
21
+ # To be sourced by another script (as in ". parse_options.sh").
22
+ # Option format is: --option-name arg
23
+ # and shell variable "option_name" gets set to value "arg."
24
+ # The exception is --help, which takes no arguments, but prints the
25
+ # $help_message variable (if defined).
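+ # For example, if the sourcing script defines "stage=0" and "nj=20",
+ # invoking it as "./run.sh --stage 2 --nj 8" sets stage=2 and nj=8
+ # (variables must already be defined before this file is sourced).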
26
+
27
+
28
+ ###
29
+ ### The --config file options have lower priority to command line
30
+ ### options, so we need to import them first...
31
+ ###
32
+
33
+ # Now import all the configs specified by command-line, in left-to-right order
34
+ for ((argpos=1; argpos<$#; argpos++)); do
35
+ if [ "${!argpos}" == "--config" ]; then
36
+ argpos_plus1=$((argpos+1))
37
+ config=${!argpos_plus1}
38
+ [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
39
+ . $config # source the config file.
40
+ fi
41
+ done
42
+
43
+
44
+ ###
45
+ ### Now we process the command line options
46
+ ###
47
+ while true; do
48
+ [ -z "${1:-}" ] && break; # break if there are no arguments
49
+ case "$1" in
50
+ # If the enclosing script is called with --help option, print the help
51
+ # message and exit. Scripts should put help messages in $help_message
52
+ --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
53
+ else printf "$help_message\n" 1>&2 ; fi;
54
+ exit 0 ;;
55
+ --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
56
+ exit 1 ;;
57
+ # If the first command-line argument begins with "--" (e.g. --foo-bar),
58
+ # then work out the variable name as $name, which will equal "foo_bar".
59
+ --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
60
+ # Next we test whether the variable in question is undefined -- if so it's
61
+ # an invalid option and we die. Note: $0 evaluates to the name of the
62
+ # enclosing script.
63
+ # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
64
+ # is undefined. We then have to wrap this test inside "eval" because
65
+ # foo_bar is itself inside a variable ($name).
66
+ eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
67
+
68
+ oldval="`eval echo \\$$name`";
69
+ # Work out whether we seem to be expecting a Boolean argument.
70
+ if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
71
+ was_bool=true;
72
+ else
73
+ was_bool=false;
74
+ fi
75
+
76
+ # Set the variable to the right value-- the escaped quotes make it work if
77
+ # the option had spaces, like --cmd "queue.pl -sync y"
78
+ eval $name=\"$2\";
79
+
80
+ # Check that Boolean-valued arguments are really Boolean.
81
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
82
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
83
+ exit 1;
84
+ fi
85
+ shift 2;
86
+ ;;
87
+ *) break;
88
+ esac
89
+ done
90
+
91
+
92
+ # Check for an empty argument to the --cmd option, which can easily occur as a
93
+ # result of scripting errors.
94
+ [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
95
+
96
+
97
+ true; # so this script returns exit code 0.
egs/zipvoice/utils/validate_manifest.py ADDED
@@ -0,0 +1,70 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
3
+ # Zengwei Yao)
4
+ #
5
+ # See ../../../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+ """
19
+ This script checks the following assumptions of the generated manifest:
20
+
21
+ - Single supervision per cut
22
+
23
+ We will add more checks later if needed.
24
+
25
+ Usage example:
26
+
27
+ python3 ./utils/validate_manifest.py \
28
+ ./data/spectrogram/ljspeech_cuts_all.jsonl.gz
29
+
30
+ """
31
+
32
+ import argparse
33
+ import logging
34
+ from pathlib import Path
35
+
36
+ from lhotse import CutSet, load_manifest_lazy
37
+ from lhotse.dataset.speech_synthesis import validate_for_tts
38
+
39
+
40
+ def get_args():
41
+ parser = argparse.ArgumentParser()
42
+
43
+ parser.add_argument(
44
+ "manifest",
45
+ type=Path,
46
+ help="Path to the manifest file",
47
+ )
48
+
49
+ return parser.parse_args()
50
+
51
+
52
+ def main():
53
+ args = get_args()
54
+
55
+ manifest = args.manifest
56
+ logging.info(f"Validating {manifest}")
57
+
58
+ assert manifest.is_file(), f"{manifest} does not exist"
59
+ cut_set = load_manifest_lazy(manifest)
60
+ assert isinstance(cut_set, CutSet), type(cut_set)
61
+
62
+ validate_for_tts(cut_set)
63
+
64
+
65
+ if __name__ == "__main__":
66
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
67
+
68
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
69
+
70
+ main()
egs/zipvoice_dialog/README.md ADDED
@@ -0,0 +1,12 @@
1
+ # ZipVoice-Dialog Recipe
2
+
3
+ This recipe contains the following examples:
4
+
5
+ - Training ZipVoice-Dialog on the OpenDialog dataset, see [run_opendialog.sh](run_opendialog.sh)
6
+ - Training ZipVoice-Dialog on custom datasets (Chinese/English), see [run_custom.sh](run_custom.sh).
7
+ - Fine-tuning pre-trained ZipVoice-Dialog on custom datasets (Chinese/English), see [run_finetune.sh](run_finetune.sh).
8
+ - Evaluating models with the objective metrics reported in the ZipVoice-Dialog paper, see [run_eval.sh](run_eval.sh).
9
+
10
+ > **NOTE:** For evaluation, first install packages from [../../requirements_eval.txt](../../requirements_eval.txt)
11
+ >
12
+ > `pip install -r ../../requirements_eval.txt`
egs/zipvoice_dialog/local/prepare_opendialog.py ADDED
@@ -0,0 +1,262 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ This script prepares lhotse manifest files from the raw OpenDialog datasets.
20
+
21
+ We assume that you have downloaded the OpenDialog dataset and untarred the
22
+ tar files in audio/en and audio/zh so that the mp3 files are placed under
23
+ these two directories.
24
+
25
+ Download OpenDialog at https://huggingface.co/datasets/k2-fsa/OpenDialog
26
+ or https://www.modelscope.cn/datasets/k2-fsa/OpenDialog
27
+
28
+ """
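+
+ # A minimal example invocation; the paths below are assumptions that match
+ # stage 1 of ../run_opendialog.sh, not requirements:
+ #
+ # python3 local/prepare_opendialog.py \
+ # --dataset-path download/OpenDialog \
+ # --num-jobs 20 \
+ # --output-dir data/manifests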
29
+
30
+ import argparse
31
+ import json
32
+ import logging
33
+ import math
34
+ import re
35
+ from concurrent.futures import ThreadPoolExecutor
36
+ from functools import partial
37
+ from pathlib import Path
38
+ from typing import List, Optional, Tuple
39
+
40
+ from lhotse import CutSet, validate_recordings_and_supervisions
41
+ from lhotse.audio import Recording, RecordingSet
42
+ from lhotse.cut import Cut
43
+ from lhotse.qa import fix_manifests
44
+ from lhotse.supervision import SupervisionSegment, SupervisionSet
45
+ from lhotse.utils import Pathlike
46
+ from tqdm.auto import tqdm
47
+
48
+
49
+ def get_args():
50
+ parser = argparse.ArgumentParser()
51
+
52
+ parser.add_argument(
53
+ "--dataset-path",
54
+ type=str,
55
+ help="The path of OpenDialog dataset.",
56
+ )
57
+
58
+ parser.add_argument(
59
+ "--num-jobs",
60
+ type=int,
61
+ default=20,
62
+ help="Number of jobs for parallel processing.",
63
+ )
64
+
65
+ parser.add_argument(
66
+ "--output-dir",
67
+ type=str,
68
+ default="data/manifests",
69
+ help="The destination directory of manifest files.",
70
+ )
71
+ parser.add_argument(
72
+ "--sampling-rate",
73
+ type=int,
74
+ default=24000,
75
+ help="The target sampling rate.",
76
+ )
77
+ return parser.parse_args()
78
+
79
+
80
+ def _parse_recording(
81
+ wav_path: str,
82
+ ) -> Tuple[Recording, str]:
83
+ """
84
+ :param wav_path: Path to the audio file
85
+ :return: a tuple of "recording" and "recording_id"
86
+ """
87
+
88
+ recording_id = Path(wav_path).stem
89
+ recording = Recording.from_file(path=wav_path, recording_id=recording_id)
90
+
91
+ return recording, recording_id
92
+
93
+
94
+ def _parse_supervision(
95
+ supervision: List, recording_dict: dict
96
+ ) -> Optional[SupervisionSegment]:
97
+ """
98
+ :param supervision: A (uniq_id, text, wav_path, start, end) tuple from the manifest
99
+ :param recording_dict: Dictionary mapping recording IDs to Recording objects
100
+ :return: A SupervisionSegment object
101
+ """
102
+
103
+ def _round_down(num, ndigits=0):
104
+ factor = 10**ndigits
105
+ return math.floor(num * factor) / factor
106
+
107
+ uniq_id, text, wav_path, start, end = supervision
108
+ try:
109
+ recording_id = Path(wav_path).stem
110
+
111
+ recording = recording_dict[recording_id]
112
+ duration = (
113
+ _round_down(end - start, ndigits=8)
114
+ if end is not None
115
+ else _round_down(recording.duration, ndigits=8)
116
+ )
117
+ assert duration <= recording.duration, f"Duration {duration} is greater than " \
118
+ f"recording duration {recording.duration}"
119
+
120
+ text = re.sub("_", " ", text) # "_" is treated as padding symbol
121
+ text = re.sub(r"\s+", " ", text) # remove extra whitespace
122
+
123
+ return SupervisionSegment(
124
+ id=f"{uniq_id}",
125
+ recording_id=recording.id,
126
+ start=start,
127
+ duration=duration,
128
+ channel=recording.channel_ids,
129
+ text=text.strip(),
130
+ )
131
+ except Exception as e:
132
+ logging.warning(f"Error processing supervision: {e}")
133
+ return None
134
+
135
+
136
+ def prepare_subset(
137
+ jsonl_path: Pathlike,
138
+ lang: str,
139
+ sampling_rate: int,
140
+ num_jobs: int,
141
+ output_dir: Pathlike,
142
+ ):
143
+ """
144
+ Returns the manifests which consist of the Recordings and Supervisions
145
+
146
+ :param jsonl_path: Path to the jsonl file
147
+ :param lang: Language of the subset
148
+ :param sampling_rate: Target sampling rate of the audio
149
+ :param num_jobs: Number of processes for parallel processing
150
+ :param output_dir: Path where to write the manifests
151
+ """
152
+ logging.info(f"Preparing {lang} subset")
153
+
154
+ # Step 1: Read all unique recording paths
155
+ logging.info(f"Reading {jsonl_path}")
156
+ recordings_path_set = set()
157
+ supervision_list = list()
158
+ with open(jsonl_path, "r") as fr:
159
+ for line in fr:
160
+ try:
161
+ items = json.loads(line)
162
+ uniq_id, text, wav_path = items["id"], items["text"], items["path"]
163
+ start, end = 0, None
164
+ recordings_path_set.add(jsonl_path.parent / wav_path)
165
+ supervision_list.append((uniq_id, text, wav_path, start, end))
166
+ except Exception as e:
167
+ logging.warning(f"Error {e} when decoding JSON line: {line}")
168
+ continue
169
+ logging.info("Starting to process recordings...")
170
+ # Step 2: Process recordings
171
+ futures = []
172
+ recording_dict = {}
173
+ with ThreadPoolExecutor(max_workers=num_jobs) as ex:
174
+ for wav_path in tqdm(recordings_path_set, desc="Submitting jobs"):
175
+ futures.append(ex.submit(_parse_recording, wav_path))
176
+
177
+ for future in tqdm(futures, desc="Processing recordings"):
178
+ try:
179
+ recording, recording_id = future.result()
180
+ recording_dict[recording_id] = recording
181
+ except Exception as e:
182
+ logging.warning(
183
+ f"Error processing a recording: {e}"
184
+ )
185
+
186
+ recording_set = RecordingSet.from_recordings(recording_dict.values())
187
+
188
+ logging.info("Starting to process supervisions...")
189
+ # Step 3: Process supervisions
190
+ supervisions = []
191
+ for supervision in tqdm(supervision_list, desc="Processing supervisions"):
192
+ seg = _parse_supervision(supervision, recording_dict)
193
+ if seg is not None:
194
+ supervisions.append(seg)
195
+
196
+ logging.info("Processing Cuts...")
197
+
198
+ # Step 4: Create and validate manifests
199
+ supervision_set = SupervisionSet.from_segments(supervisions)
200
+
201
+ recording_set, supervision_set = fix_manifests(recording_set, supervision_set)
202
+ validate_recordings_and_supervisions(recording_set, supervision_set)
203
+
204
+ cut_set = CutSet.from_manifests(
205
+ recordings=recording_set, supervisions=supervision_set
206
+ )
207
+ cut_set = cut_set.sort_by_recording_id()
208
+ if sampling_rate != 24000:
209
+ # All OpenDialog audios are 24kHz
210
+ cut_set = cut_set.resample(sampling_rate)
211
+ cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
212
+
213
+ logging.info("Saving cuts to disk...")
214
+ # Step 5: Write manifests to disk
215
+ cut_set.to_file(output_dir / f"opendialog_cuts_raw_{lang.upper()}-all.jsonl.gz")
216
+ dev_cut_set = cut_set.subset(first=1000)
217
+ dev_cut_set.to_file(output_dir / f"opendialog_cuts_raw_{lang.upper()}-dev.jsonl.gz")
218
+
219
+ def remove_dev(c: Cut, dev_ids: set) -> bool:
220
+ # keep only cuts whose ids are not in the dev subset
221
+ return c.id not in dev_ids
223
+
224
+ _remove_dev = partial(remove_dev, dev_ids=set(dev_cut_set.ids))
225
+ train_cut_set = cut_set.filter(_remove_dev)
226
+ train_cut_set.to_file(
227
+ output_dir / f"opendialog_cuts_raw_{lang.upper()}-train.jsonl.gz"
228
+ )
229
+
230
+
231
+ def prepare_dataset(
232
+ dataset_path: Pathlike,
233
+ sampling_rate: int,
234
+ num_jobs: int,
235
+ output_dir: Pathlike,
236
+ ):
237
+ for lang in ["en", "zh"]:
238
+ jsonl_path = dataset_path / f"manifest.{lang}.jsonl"
239
+ prepare_subset(
240
+ jsonl_path=jsonl_path,
241
+ lang=lang,
242
+ sampling_rate=sampling_rate,
243
+ num_jobs=num_jobs,
244
+ output_dir=output_dir,
245
+ )
246
+
247
+
248
+ if __name__ == "__main__":
249
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
250
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
251
+
252
+ args = get_args()
253
+ dataset_path = Path(args.dataset_path)
254
+ output_dir = Path(args.output_dir)
255
+ output_dir.mkdir(parents=True, exist_ok=True)
256
+
257
+ prepare_dataset(
258
+ dataset_path=dataset_path,
259
+ sampling_rate=args.sampling_rate,
260
+ num_jobs=args.num_jobs,
261
+ output_dir=output_dir,
262
+ )
egs/zipvoice_dialog/run_custom.sh ADDED
@@ -0,0 +1,145 @@
1
+ #!/bin/bash
2
+
3
+ # This script is an example of training ZipVoice-Dialog on your custom datasets.
4
+ # Only English and Chinese are supported for now.
5
+
6
+ # Add project root to PYTHONPATH
7
+ export PYTHONPATH=../../:$PYTHONPATH
8
+
9
+ # Make bash strict: exit on error (-e), on undefined variables (-u),
10
+ # and on failures anywhere in a pipeline (-o pipefail).
11
+ set -e
12
+ set -u
13
+ set -o pipefail
14
+
15
+ stage=1
16
+ stop_stage=6
17
+
18
+ # Number of jobs for data preparation
19
+ nj=20
20
+ download_dir=download/
21
+
22
+ # Maximum length (seconds) of a training utterance; longer utterances will be filtered out
23
+ max_len=60
24
+
25
+ # We assume you have two TSV files: "data/raw/custom_train.tsv" and
26
+ # "data/raw/custom_dev.tsv", where "custom" is your dataset name,
27
+ # "train"/"dev" are used for training and validation respectively.
28
+
29
+ # Each line of the TSV files should be in one of the following formats:
30
+ # (1) `{uniq_id}\t{text}\t{wav_path}` if the text corresponds to the full wav,
31
+ # (2) `{uniq_id}\t{text}\t{wav_path}\t{start_time}\t{end_time}` if text corresponds
32
+ # to part of the wav. The start_time and end_time specify the start and end
33
+ # times of the text within the wav, which should be in seconds.
34
+ # > Note: {uniq_id} must be unique for each line.
35
+ # > Note: {text} uses [S1] and [S2] tags to distinguish speakers, and must begin with [S1].
36
+ # > eg: "[S1] Hello. [S2] How are you? [S1] I'm fine. [S2] What's your name?"
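+ # An example line in format (1), with a hypothetical id and wav path:
+ # > "dlg_0001\t[S1] Hello. [S2] Hi, nice to meet you.\twavs/dlg_0001.wav"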
37
+ for subset in train dev;do
38
+ file_path=data/raw/custom_${subset}.tsv
39
+ [ -f "$file_path" ] || { echo "Error: expect $file_path !" >&2; exit 1; }
40
+ done
41
+
42
+
43
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
44
+ echo "Stage 1: Prepare manifests for custom dataset from tsv files"
45
+
46
+ for subset in train dev;do
47
+ python3 -m zipvoice.bin.prepare_dataset \
48
+ --tsv-path data/raw/custom_${subset}.tsv \
49
+ --prefix custom \
50
+ --subset raw_${subset} \
51
+ --num-jobs ${nj} \
52
+ --output-dir data/manifests
53
+ done
54
+ # The output manifest files are "data/manifests/custom_cuts_raw_train.jsonl.gz".
55
+ # and "data/manifests/custom_cuts_raw_dev.jsonl.gz".
56
+ fi
57
+
58
+
59
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
60
+ echo "Stage 2: Add tokens to manifests"
61
+ for subset in train dev;do
62
+ python3 -m zipvoice.bin.prepare_tokens \
63
+ --input-file data/manifests/custom_cuts_raw_${subset}.jsonl.gz \
64
+ --output-file data/manifests/custom_cuts_${subset}.jsonl.gz \
65
+ --tokenizer dialog
66
+ done
67
+ # The output manifest files are "data/manifests/custom_cuts_train.jsonl.gz".
68
+ # and "data/manifests/custom_cuts_dev.jsonl.gz".
69
+ fi
70
+
71
+ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
72
+ echo "Stage 3: Compute Fbank for custom dataset"
73
+ # You can skip this step and use `--on-the-fly-feats 1` in training stage
74
+ for subset in train dev; do
75
+ python3 -m zipvoice.bin.compute_fbank \
76
+ --source-dir data/manifests \
77
+ --dest-dir data/fbank \
78
+ --dataset custom \
79
+ --subset ${subset} \
80
+ --num-jobs ${nj}
81
+ done
82
+ fi
83
+
84
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
85
+ echo "Stage 4: Download tokens file, pretrained models"
86
+ # Uncomment this line to use HF mirror
87
+ # export HF_ENDPOINT=https://hf-mirror.com
88
+
89
+ # The token file is obtained by extending some tokens
90
+ # on the basis of the Emilia token file.
91
+ mkdir -p ${download_dir}
92
+ hf_repo=k2-fsa/ZipVoice
93
+ huggingface-cli download \
94
+ --local-dir ${download_dir} \
95
+ ${hf_repo} \
96
+ zipvoice_dialog/tokens.txt
97
+
98
+ # Pre-trained ZipVoice model is required as
99
+ # the initialization model.
100
+ for file in model.pt tokens.txt model.json; do
101
+ huggingface-cli download \
102
+ --local-dir ${download_dir} \
103
+ ${hf_repo} \
104
+ zipvoice/${file}
105
+ done
106
+ fi
107
+
108
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
109
+ echo "Stage 5: Train the ZipVoice-Dialog model"
110
+ python3 -m zipvoice.bin.train_zipvoice_dialog \
111
+ --world-size 4 \
112
+ --use-fp16 1 \
113
+ --base-lr 0.0001 \
114
+ --num-iters 60000 \
115
+ --max-duration 500 \
116
+ --max-len ${max_len} \
117
+ --checkpoint ${download_dir}/zipvoice/model.pt \
118
+ --model-config ${download_dir}/zipvoice/model.json \
119
+ --token-file ${download_dir}/zipvoice_dialog/tokens.txt \
120
+ --dataset custom \
121
+ --train-manifest data/fbank/custom_cuts_train.jsonl.gz \
122
+ --dev-manifest data/fbank/custom_cuts_dev.jsonl.gz \
123
+ --exp-dir exp/zipvoice_dialog_custom
124
+ fi
125
+
126
+ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
127
+ echo "Stage 6: Average the checkpoints for ZipVoice-Dialog"
128
+ python3 -m zipvoice.bin.generate_averaged_model \
129
+ --iter 60000 \
130
+ --avg 2 \
131
+ --model-name zipvoice_dialog \
132
+ --exp-dir exp/zipvoice_dialog_custom
133
+ # The generated model is exp/zipvoice_dialog_custom/iter-60000-avg-2.pt
134
+ fi
135
+
136
+
137
+ if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
138
+ echo "Stage 7: Inference of the ZipVoice-Dialog model"
139
+ python3 -m zipvoice.bin.infer_zipvoice_dialog \
140
+ --model-name zipvoice_dialog \
141
+ --model-dir exp/zipvoice_dialog_custom \
142
+ --checkpoint-name iter-60000-avg-2.pt \
143
+ --test-list test.tsv \
144
+ --res-dir results/test_dialog_custom
145
+ fi
egs/zipvoice_dialog/run_eval.sh ADDED
@@ -0,0 +1,120 @@
1
+ #!/bin/bash
2
+
3
+ # This script is an example of evaluating TTS models with the objective metrics reported in the ZipVoice-Dialog paper.
4
+
5
+ # Add project root to PYTHONPATH
6
+ export PYTHONPATH=../../:$PYTHONPATH
7
+
8
+ # Make bash strict: exit on error (-e), on undefined variables (-u),
9
+ # and on failures anywhere in a pipeline (-o pipefail).
10
+ set -e
11
+ set -u
12
+ set -o pipefail
13
+
14
+ stage=1
15
+ stop_stage=6
16
+
17
+ download_dir=download/
18
+
19
+ # Uncomment this line to use HF mirror
20
+ # export HF_ENDPOINT=https://hf-mirror.com
21
+
22
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
23
+ echo "Stage 1: Download test sets (test-dialog)"
24
+ hf_repo=k2-fsa/TTS_eval_datasets
25
+ mkdir -p ${download_dir}/
26
+ file=dialog_testset.tar.gz
27
+ echo "Downloading ${file}..."
28
+ huggingface-cli download \
29
+ --repo-type dataset \
30
+ --local-dir ${download_dir}/ \
31
+ ${hf_repo} \
32
+ ${file}
33
+ echo "Extracting ${file}..."
34
+ tar -xzf ${download_dir}/${file} -C ${download_dir}/
35
+ fi
36
+
37
+
38
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
39
+ echo "Stage 2: Download all required evaluation models"
40
+ mkdir -p ${download_dir}/tts_eval_models
41
+ hf_repo=k2-fsa/TTS_eval_models  # the evaluation models live in a separate repo
42
+ huggingface-cli download \
43
+ --local-dir ${download_dir}/tts_eval_models \
44
+ ${hf_repo}
45
+ fi
46
+
47
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
48
+ echo "Stage 3: Inference with the pre-trained ZipVoice-Dialog model from HuggingFace"
49
+
50
+ for testset in test_dialog_en test_dialog_zh; do
51
+ if [ "$testset" = "test_dialog_en" ]; then
52
+ test_tsv=${download_dir}/dialog_testset/en/test.tsv
53
+ elif [ "$testset" = "test_dialog_zh" ]; then
54
+ test_tsv=${download_dir}/dialog_testset/zh/test.tsv
55
+ else
56
+ echo "Error: unknown testset ${testset}" >&2
57
+ exit 1
58
+ fi
59
+ echo "Inference on testset ${testset}..."
60
+ python3 -m zipvoice.bin.infer_zipvoice_dialog \
61
+ --model-name zipvoice_dialog \
62
+ --test-list ${test_tsv} \
63
+ --res-dir results/${testset}
64
+ done
65
+ fi
66
+
67
+
68
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
69
+ echo "Stage 4: Evaluation on test-dialog-en"
70
+ model_path=${download_dir}/tts_eval_models
71
+ wav_path=results/test_dialog_en
72
+ test_tsv=${download_dir}/dialog_testset/en/test.tsv
73
+
74
+ python3 -m zipvoice.eval.speaker_similarity.cpsim \
75
+ --wav-path ${wav_path} \
76
+ --test-list ${test_tsv} \
77
+ --model-dir ${model_path}
78
+
79
+ python3 -m zipvoice.eval.wer.dialog \
80
+ --wav-path ${wav_path} \
81
+ --test-list ${test_tsv} \
82
+ --model-dir ${model_path} \
83
+ --lang en
84
+
85
+ # cpWER mode: only computes WER and cpWER
86
+ # for speech shorter than 30 s
87
+ python3 -m zipvoice.eval.wer.dialog \
88
+ --wav-path ${wav_path} \
89
+ --test-list ${test_tsv} \
90
+ --model-dir ${model_path} \
91
+ --lang en \
92
+ --cpwer
93
+
94
+ python3 -m zipvoice.eval.mos.utmos \
95
+ --wav-path ${wav_path} \
96
+ --model-dir ${model_path}
97
+ fi
98
+
99
+
100
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
101
+ echo "Stage 5: Evaluation on test-dialog-zh"
102
+ model_path=${download_dir}/tts_eval_models
103
+ wav_path=results/test_dialog_zh
104
+ test_tsv=${download_dir}/dialog_testset/zh/test.tsv
105
+
106
+ python3 -m zipvoice.eval.speaker_similarity.cpsim \
107
+ --wav-path ${wav_path} \
108
+ --test-list ${test_tsv} \
109
+ --model-dir ${model_path}
110
+
111
+ python3 -m zipvoice.eval.wer.dialog \
112
+ --wav-path ${wav_path} \
113
+ --test-list ${test_tsv} \
114
+ --model-dir ${model_path} \
115
+ --lang zh
116
+
117
+ python3 -m zipvoice.eval.mos.utmos \
118
+ --wav-path ${wav_path} \
119
+ --model-dir ${model_path}
120
+ fi
egs/zipvoice_dialog/run_finetune.sh ADDED
@@ -0,0 +1,135 @@
1
+ #!/bin/bash
2
+
3
+ # This script is an example of fine-tuning the pre-trained ZipVoice-Dialog model on your custom datasets.
4
+ # Only English and Chinese are supported for now.
5
+
6
+ # Add project root to PYTHONPATH
7
+ export PYTHONPATH=../../:$PYTHONPATH
8
+
9
+ # Make bash strict: exit on error (-e), on undefined variables (-u),
10
+ # and on failures anywhere in a pipeline (-o pipefail).
11
+ set -e
12
+ set -u
13
+ set -o pipefail
14
+
15
+ stage=1
16
+ stop_stage=6
17
+
18
+ # Number of jobs for data preparation
19
+ nj=20
20
+ # Maximum length (seconds) of a training utterance; longer utterances will be filtered out
21
+ max_len=60
22
+ download_dir=download/
23
+
24
+ # We assume you have two TSV files: "data/raw/custom_train.tsv" and
25
+ # "data/raw/custom_dev.tsv", where "custom" is your dataset name,
26
+ # "train"/"dev" are used for training and validation respectively.
27
+
28
+ # Each line of the TSV files should be in one of the following formats:
29
+ # (1) `{uniq_id}\t{text}\t{wav_path}` if the text corresponds to the full wav,
30
+ # (2) `{uniq_id}\t{text}\t{wav_path}\t{start_time}\t{end_time}` if text corresponds
31
+ # to part of the wav. The start_time and end_time specify the start and end
32
+ # times of the text within the wav, which should be in seconds.
33
+ # > Note: {uniq_id} must be unique for each line.
34
+ # > Note: {text} uses [S1] and [S2] tags to distinguish speakers, and must begin with [S1].
35
+ # > eg: "[S1] Hello. [S2] How are you? [S1] I'm fine. [S2] What's your name?"
36
+ for subset in train dev;do
37
+ file_path=data/raw/custom_${subset}.tsv
38
+ [ -f "$file_path" ] || { echo "Error: expect $file_path !" >&2; exit 1; }
39
+ done
40
+
41
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
42
+ echo "Stage 1: Prepare manifests for custom dataset from tsv files"
43
+
44
+ for subset in train dev;do
45
+ python3 -m zipvoice.bin.prepare_dataset \
46
+ --tsv-path data/raw/custom_${subset}.tsv \
47
+ --prefix custom-finetune \
48
+ --subset raw_${subset} \
49
+ --num-jobs ${nj} \
50
+ --output-dir data/manifests
51
+ done
52
+ # The output manifest files are "data/manifests/custom-finetune_cuts_raw_train.jsonl.gz".
53
+ # and "data/manifests/custom-finetune_cuts_raw_dev.jsonl.gz".
54
+ fi
55
+
56
+
57
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
58
+ echo "Stage 2: Add tokens to manifests"
59
+ for subset in train dev;do
60
+ python3 -m zipvoice.bin.prepare_tokens \
61
+ --input-file data/manifests/custom-finetune_cuts_raw_${subset}.jsonl.gz \
62
+ --output-file data/manifests/custom-finetune_cuts_${subset}.jsonl.gz \
63
+ --tokenizer dialog
64
+ done
65
+ # The output manifest files are "data/manifests/custom-finetune_cuts_train.jsonl.gz".
66
+ # and "data/manifests/custom-finetune_cuts_dev.jsonl.gz".
67
+ fi
68
+
69
+ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
70
+ echo "Stage 3: Compute Fbank for custom dataset"
71
+ # You can skip this step and use `--on-the-fly-feats 1` in training stage
72
+ for subset in train dev; do
73
+ python3 -m zipvoice.bin.compute_fbank \
74
+ --source-dir data/manifests \
75
+ --dest-dir data/fbank \
76
+ --dataset custom-finetune \
77
+ --subset ${subset} \
78
+ --num-jobs ${nj}
79
+ done
80
+ fi
81
+
82
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
83
+ echo "Stage 4: Download pre-trained model, tokens file, and model config"
84
+ # Uncomment this line to use HF mirror
85
+ # export HF_ENDPOINT=https://hf-mirror.com
86
+
87
+ mkdir -p ${download_dir}
88
+ hf_repo=k2-fsa/ZipVoice
89
+ for file in model.pt tokens.txt model.json; do
90
+ huggingface-cli download \
91
+ --local-dir ${download_dir} \
92
+ ${hf_repo} \
93
+ zipvoice_dialog/${file}
94
+ done
95
+ fi
96
+
97
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
98
+ echo "Stage 5: Fine-tune the ZipVoice-Dialog model"
99
+ python3 -m zipvoice.bin.train_zipvoice_dialog \
100
+ --world-size 4 \
101
+ --use-fp16 1 \
102
+ --finetune 1 \
103
+ --base-lr 0.0001 \
104
+ --num-iters 10000 \
105
+ --save-every-n 1000 \
106
+ --max-duration 500 \
107
+ --max-len ${max_len} \
108
+ --checkpoint ${download_dir}/zipvoice_dialog/model.pt \
109
+ --model-config ${download_dir}/zipvoice_dialog/model.json \
110
+ --token-file ${download_dir}/zipvoice_dialog/tokens.txt \
111
+ --dataset custom \
112
+ --train-manifest data/fbank/custom-finetune_cuts_train.jsonl.gz \
113
+ --dev-manifest data/fbank/custom-finetune_cuts_dev.jsonl.gz \
114
+ --exp-dir exp/zipvoice_dialog_finetune
115
+ fi
116
+
117
+ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
118
+ echo "Stage 6: Average the checkpoints for ZipVoice-Dialog"
119
+ python3 -m zipvoice.bin.generate_averaged_model \
120
+ --iter 10000 \
121
+ --avg 2 \
122
+ --model-name zipvoice_dialog \
123
+ --exp-dir exp/zipvoice_dialog_finetune
124
+ # The generated model is exp/zipvoice_dialog_finetune/iter-10000-avg-2.pt
125
+ fi
126
+
127
+ if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
128
+ echo "Stage 7: Inference of the ZipVoice-Dialog model"
129
+ python3 -m zipvoice.bin.infer_zipvoice_dialog \
130
+ --model-name zipvoice_dialog \
131
+ --model-dir exp/zipvoice_dialog_finetune \
132
+ --checkpoint-name iter-10000-avg-2.pt \
133
+ --test-list test.tsv \
134
+ --res-dir results/test_dialog_finetune
135
+ fi
egs/zipvoice_dialog/run_opendialog.sh ADDED
@@ -0,0 +1,122 @@
1
+ #!/bin/bash
2
+
3
+ # This script is an example of training ZipVoice-Dialog on the OpenDialog dataset.
4
+
5
+ # Add project root to PYTHONPATH
6
+ export PYTHONPATH=../../:$PYTHONPATH
7
+
8
+ # Make bash strict: exit on error (-e), on undefined variables (-u),
9
+ # and on failures anywhere in a pipeline (-o pipefail).
10
+ set -e
11
+ set -u
12
+ set -o pipefail
13
+
14
+ stage=1
15
+ stop_stage=6
16
+
17
+ # Number of jobs for data preparation
18
+ nj=20
19
+
20
+ # We assume that you have downloaded the OpenDialog dataset
21
+ # to download/OpenDialog and untarred the tar files in audio/en
22
+ # and audio/zh so that the mp3 files are placed under these two directories.
23
+
24
+ # Download OpenDialog at https://huggingface.co/datasets/k2-fsa/OpenDialog
25
+ # or https://www.modelscope.cn/datasets/k2-fsa/OpenDialog
26
+ data_dir=download/OpenDialog
27
+ download_dir=download/
28
+
29
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
30
+ echo "Stage 1: Prepare manifests for OpenDialog dataset"
31
+
32
+ python3 local/prepare_opendialog.py \
33
+ --dataset-path ${data_dir} \
34
+ --num-jobs ${nj} \
35
+ --output-dir data/manifests
36
+ fi
37
+
38
+
39
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
40
+ echo "Stage 2: Add tokens to manifests"
41
+ for subset in ZH-dev ZH-train EN-dev EN-train;do
42
+ python3 -m zipvoice.bin.prepare_tokens \
43
+ --input-file data/manifests/opendialog_cuts_raw_${subset}.jsonl.gz \
44
+ --output-file data/manifests/opendialog_cuts_${subset}.jsonl.gz \
45
+ --tokenizer dialog
46
+ done
47
+ fi
48
+
49
+
50
+ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
51
+ echo "Stage 3: Compute Fbank for the OpenDialog dataset"
52
+ # You can skip this step and use `--on-the-fly-feats 1` in training stage
53
+ for subset in ZH-dev ZH-train EN-dev EN-train;do
54
+ python3 -m zipvoice.bin.compute_fbank \
55
+ --source-dir data/manifests \
56
+ --dest-dir data/fbank \
57
+ --dataset opendialog \
58
+ --subset ${subset} \
59
+ --num-jobs ${nj}
60
+ done
61
+ fi
62
+
63
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
64
+ echo "Stage 4: Download tokens file, pretrained models"
65
+ # Uncomment this line to use HF mirror
66
+ # export HF_ENDPOINT=https://hf-mirror.com
67
+
68
+ # The token file is obtained by extending some tokens
69
+ # on the basis of the Emilia token file.
70
+ mkdir -p ${download_dir}
71
+ hf_repo=k2-fsa/ZipVoice
72
+ huggingface-cli download \
73
+ --local-dir ${download_dir} \
74
+ ${hf_repo} \
75
+ zipvoice_dialog/tokens.txt
76
+
77
+ # Pre-trained ZipVoice model is required as
78
+ # the initialization model.
79
+ for file in model.pt tokens.txt model.json; do
80
+ huggingface-cli download \
81
+ --local-dir ${download_dir} \
82
+ ${hf_repo} \
83
+ zipvoice/${file}
84
+ done
85
+ fi
86
+
87
+
88
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
89
+ echo "Stage 5: Train the ZipVoice-Dialog model"
90
+ python3 -m zipvoice.bin.train_zipvoice_dialog \
91
+ --world-size 8 \
92
+ --use-fp16 1 \
93
+ --base-lr 0.0001 \
94
+ --max-duration 500 \
95
+ --checkpoint ${download_dir}/zipvoice/model.pt \
96
+ --model-config ${download_dir}/zipvoice/model.json \
97
+ --token-file ${download_dir}/zipvoice_dialog/tokens.txt \
98
+ --dataset opendialog \
99
+ --manifest-dir data/fbank \
100
+ --exp-dir exp/zipvoice_dialog_opendialog
101
+ fi
102
+
103
+ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
104
+ echo "Stage 6: Average the checkpoints for ZipVoice-Dialog"
105
+ python3 -m zipvoice.bin.generate_averaged_model \
106
+ --iter 60000 \
107
+ --avg 2 \
108
+ --model-name zipvoice_dialog \
109
+ --exp-dir exp/zipvoice_dialog_opendialog
110
+ # The generated model is exp/zipvoice_dialog_opendialog/iter-60000-avg-2.pt
111
+ fi
112
+
113
+ if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
114
+ echo "Stage 7: Inference of the ZipVoice-Dialog model"
115
+
116
+ python3 -m zipvoice.bin.infer_zipvoice_dialog \
117
+ --model-name zipvoice_dialog \
118
+ --model-dir exp/zipvoice_dialog_opendialog \
119
+ --checkpoint-name iter-60000-avg-2.pt \
120
+ --test-list test.tsv \
121
+ --res-dir results/test_dialog
122
+ fi
infer.py ADDED
@@ -0,0 +1,578 @@
1
+ from typing import List, Dict, Tuple
2
+ import torch
3
+ from transformers import (
4
+ AutoTokenizer, AutoModelForTokenClassification,
5
+ )
7
+ LABEL_LIST = ["O", "B-EN", "I-EN"]
8
+ LABEL2ID = {l:i for i,l in enumerate(LABEL_LIST)}
9
+ ID2LABEL = {i:l for l,i in LABEL2ID.items()}
10
+
11
+ model_name = "meandyou200175/detect_english"
12
+ model_detect = AutoModelForTokenClassification.from_pretrained(
13
+ model_name, num_labels=len(LABEL_LIST),
14
+ id2label=ID2LABEL, label2id=LABEL2ID
15
+ )
16
+ tokenizer_detect = AutoTokenizer.from_pretrained(model_name, use_fast=True)
17
+
18
+ def tokens_to_pred_spans(offsets: List[Tuple[int,int]], pred_ids: List[int]) -> List[Tuple[int,int]]:
19
+ spans=[]; cur=None
20
+ for (start,end), lid in zip(offsets, pred_ids):
21
+ if start==end: continue
22
+ lab = ID2LABEL.get(lid,"O")
23
+ if lab=="B-EN":
24
+ if cur: spans.append(cur)
25
+ cur=[start,end]
26
+ elif lab=="I-EN":
27
+ if cur: cur[1]=end
28
+ else: cur=[start,end]
29
+ else:
30
+ if cur: spans.append(cur); cur=None
31
+ if cur: spans.append(cur)
32
+ return [tuple(x) for x in spans]
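+ # Illustrative example: offsets [(0, 5), (6, 8), (9, 11)] with predicted
+ # labels [B-EN, I-EN, O] yield the single character span (0, 8).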
33
+
34
+ def merge_close_spans(spans: List[Dict], max_gap: int = 2) -> List[Dict]:
35
+ if not spans:
36
+ return []
37
+ merged = [spans[0]]
38
+ for cur in spans[1:]:
39
+ prev = merged[-1]
40
+ if cur["start"] - prev["end"] <= max_gap:
41
+ # merge with the previous span
42
+ prev["end"] = cur["end"]
43
+ else:
44
+ merged.append(cur)
45
+ return merged
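+ # Illustrative example: [{"start": 0, "end": 5}, {"start": 6, "end": 9}]
+ # with max_gap=2 merges into [{"start": 0, "end": 9}].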
46
+
47
+
48
+ def infer_spans(text: str, tokenizer, model, max_length: int = 256) -> List[Dict]:
49
+ text = text.lower()
50
+ enc = tokenizer(text, return_offsets_mapping=True, truncation=True,
51
+ max_length=max_length, return_tensors="pt")
52
+ offsets = enc["offset_mapping"][0].tolist()
53
+ with torch.no_grad():
54
+ out = model(**{k: v for k, v in enc.items() if k != "offset_mapping"})
55
+ pred_ids = out.logits.argmax(-1)[0].tolist()
56
+ spans = tokens_to_pred_spans(offsets, pred_ids)
57
+ spans = [{"start": s, "end": e} for (s, e) in spans]
58
+ spans = merge_close_spans(spans, max_gap=2)
59
+ # print(spans)
60
+ return spans
61
+
62
+ import unicodedata
63
+
64
+ def is_letter(ch: str) -> bool:
65
+ if not ch:
66
+ return False
67
+ # If the caller accidentally passes a combining sequence (e.g. "e" + a combining accent), normalize to NFC:
68
+ ch = unicodedata.normalize("NFC", ch)
69
+ # Accept exactly one character after normalization
70
+ if len(ch) != 1:
71
+ return False
72
+ # Unicode letter categories 'L*': Lu, Ll, Lt, Lm, Lo
73
+ return unicodedata.category(ch).startswith('L')
74
+
75
+ import re
76
+ from itertools import chain
77
+ from typing import List, Dict, Optional
78
+ import logging
79
+ from functools import reduce
80
+ from piper_phonemize import phonemize_espeak
81
+
82
+ class EspeakTokenizer():
83
+ """A tokenizer with an Espeak g2p function, supporting English + Vietnamese."""
84
+
85
+ def __init__(self, token_file: Optional[str] = None, lang: str = "vi",
86
+ tokenizer=None, model=None):
87
+ self.has_tokens = False
88
+ self.lang = lang
89
+ self.detector_tokenizer = tokenizer
90
+ self.detector_model = model
91
+
92
+ if token_file is None:
93
+ logging.debug("Initialize Tokenizer without tokens file, "
94
+ "will fail when map to ids.")
95
+ return
96
+
97
+ self.token2id: Dict[str, int] = {}
98
+ with open(token_file, "r", encoding="utf-8") as f:
99
+ for line in f.readlines():
100
+ info = line.rstrip().split("\t")
101
+ token, id = info[0], int(info[1])
102
+ assert token not in self.token2id, token
103
+ self.token2id[token] = id
104
+ self.pad_id = self.token2id["_"]
105
+ self.vocab_size = len(self.token2id)
106
+ self.has_tokens = True
107
+
108
+ @staticmethod
109
+ def _flatten(phs):
110
+ """Flatten a list-of-lists (or return the list unchanged if it is already flat)."""
111
+ if not phs:
112
+ return []
113
+ if isinstance(phs[0], (list, tuple)):
114
+ return list(chain.from_iterable(phs))
115
+ return list(phs)
116
+
117
+ def g2p_chunk(self, text: str, lang: str):
118
+ tokens = []
119
+ start = 0
120
+ for t in text:
121
+ if is_letter(t):
122
+ break
123
+ start = start + 1
124
+
125
+ # Keep as-is: whitespace (\s+), words (\w+), other characters ([^\w\s])
126
+ if start > 0 :
127
+ tokens.extend(self._flatten(text[0:start]))
128
+ phs = phonemize_espeak(text[start:], lang) # may return a list-of-lists
129
+ tokens.extend(self._flatten(phs))
130
+ return tokens
131
+
132
+ def g2p(self, text: str) -> List[str]:
133
+ """Split text into EN/VI spans and phonemize each accordingly, preserving whitespace/punctuation."""
134
+ try:
135
+ # Fallback: no detector => phonemize the whole string with self.lang,
136
+ # but via g2p_chunk so whitespace/punctuation are not lost.
137
+ if self.detector_tokenizer is None or self.detector_model is None:
138
+ return self.g2p_chunk(text, self.lang)
139
+
140
+ spans = infer_spans(text, self.detector_tokenizer, self.detector_model)
141
+ spans = sorted(spans, key=lambda x: x["start"])
142
+
143
+ tokens_all = []
144
+ last = 0
145
+ for sp in spans:
146
+ s, e = sp["start"], sp["end"]
147
+ # the part before the EN span -> VI
148
+ if s > last:
149
+ vi_chunk = text[last:s]
150
+ if vi_chunk:
151
+ tokens_all.extend(self.g2p_chunk(vi_chunk, "vi"))
152
+ # the EN span
153
+ en_chunk = text[s:e]
154
+ if en_chunk:
155
+ tokens_all.extend([" "])
156
+ tokens_all.extend(self.g2p_chunk(en_chunk, "en"))
157
+ last = e
158
+
159
+ # the remaining part after the last EN span -> VI
160
+ if last < len(text):
161
+ vi_chunk = text[last:]
162
+ if vi_chunk:
163
+ tokens_all.extend(self.g2p_chunk(vi_chunk, "vi"))
164
+
165
+ return tokens_all
166
+
167
+ except Exception as ex:
168
+ logging.warning(f"Tokenization of mixed {self.lang} texts failed: {ex}")
169
+ return []
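+ # Illustrative example: for "tΓ΄i thΓ­ch machine learning", the detector is
+ # expected to tag "machine learning" as EN; the Vietnamese parts are
+ # phonemized with espeak "vi", the EN span with espeak "en", joined by a
+ # space token.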
170
+ def texts_to_token_ids(
171
+ self,
172
+ texts: List[str],
173
+ ) -> List[List[int]]:
174
+ return self.tokens_to_token_ids(self.texts_to_tokens(texts))
175
+
176
+ def texts_to_tokens(
177
+ self,
178
+ texts: List[str],
179
+ ) -> List[List[str]]:
180
+ tokens_list = [self.g2p(texts[i]) for i in range(len(texts))]
181
+ return tokens_list
182
+
183
+ def tokens_to_token_ids(
184
+ self,
185
+ tokens_list: List[List[str]],
186
+ ) -> List[List[int]]:
187
+ assert self.has_tokens, "Please initialize Tokenizer with a tokens file."
188
+
189
+ token_ids_list = []
190
+
191
+ for tokens in tokens_list:
192
+ token_ids = []
193
+ for t in tokens:
194
+ if t not in self.token2id:
195
+ logging.debug(f"Skip OOV {t}")
196
+ continue
197
+ token_ids.append(self.token2id[t])
198
+
199
+ token_ids_list.append(token_ids)
200
+
201
+ return token_ids_list
202
+ import re # <-- added
203
+ import random
204
+ import datetime as dt
205
+ import json
206
+ import logging
207
+ import os
208
+ from pathlib import Path
209
+ from typing import Optional
210
+
211
+ import numpy as np
212
+ import safetensors.torch
213
+ import torch
214
+ import torchaudio
215
+ from huggingface_hub import hf_hub_download
216
+ from lhotse.utils import fix_random_seed
217
+ from vocos import Vocos
218
+
219
+ from zipvoice.models.zipvoice import ZipVoice
220
+ from zipvoice.models.zipvoice_distill import ZipVoiceDistill
221
+ # from zipvoice.tokenizer.tokenizer import EmiliaTokenizer, EspeakTokenizer, LibriTTSTokenizer, SimpleTokenizer, SimpleTokenizer2
222
+ from zipvoice.utils.checkpoint import load_checkpoint
223
+ from zipvoice.utils.common import AttributeDict
224
+ from zipvoice.utils.feature import VocosFbank
225
+ def load_vocab(file_path):
226
+ """Read a vocab file with lines "char<TAB>id" -> return a dict {id: char}."""
227
+ id2char = {}
228
+ with open(file_path, "r", encoding="utf-8") as f:
229
+ for line in f:
230
+ if not line.strip():
231
+ continue
232
+ # strip the trailing \n but keep leading spaces
233
+ line = line.rstrip("\n")
234
+ parts = line.split("\t")
235
+ if len(parts) != 2:
236
+ continue # skip malformed lines
237
+ char, idx = parts
238
+ id2char[int(idx)] = char
239
+ return id2char
240
+
241
+
242
+ def tokens_to_text(tokens, id2char):
243
+ """Convert a list of token ids back to a string."""
244
+ return "".join(id2char.get(t, "<unk>") for t in tokens)
245
+
246
+ def get_vocoder(vocos_local_path: Optional[str] = None):
247
+ if vocos_local_path:
248
+ vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
249
+ state_dict = torch.load(
250
+ f"{vocos_local_path}/pytorch_model.bin",
251
+ weights_only=True,
252
+ map_location="cpu",
253
+ )
254
+ vocoder.load_state_dict(state_dict)
255
+ else:
256
+ vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
257
+ return vocoder
258
+
259
+
260
+ HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
261
+ MODEL_DIR = {
262
+ "zipvoice": "zipvoice",
263
+ "zipvoice_distill": "zipvoice_distill",
264
+ }
265
+
266
+ model_dir="zipvoice_finetune/"
267
+ checkpoint_name="iter-525000-avg-2.pt"
268
+ # checkpoint_name="model.pt"
269
+ model_dir = Path(model_dir)
270
+ model_ckpt = model_dir / checkpoint_name
271
+ model_config_path = model_dir / "model.json"
272
+ token_file = model_dir / "tokens.txt"
273
+
274
+
275
+ tokenizer = EspeakTokenizer(token_file=token_file, tokenizer=tokenizer_detect, model=model_detect)
276
+
277
+
278
+ tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
279
+
280
+ with open(model_config_path, "r") as f:
281
+ model_config = json.load(f)
282
+
283
+ # --- Init model ---
284
+
285
+ model = ZipVoice(**model_config["model"], **tokenizer_config)
286
+
287
+ if str(model_ckpt).endswith(".safetensors"):
288
+ safetensors.torch.load_model(model, model_ckpt)
289
+ else:
290
+ load_checkpoint(filename=model_ckpt, model=model, strict=True)
291
+
292
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
293
+ model = model.to(device).eval()
294
+
295
+ # --- Vocoder & features ---
296
+ vocoder = get_vocoder(None).to(device).eval()
297
+ feature_extractor = VocosFbank()
298
+ sampling_rate = model_config["feature"]["sampling_rate"]
304
+ def score_tokens(A):
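+ # Undocumented heuristic, reconstructed from the code (the intent of B is an
+ # assumption): B appears to be a fixed set of "special" token ids; the
+ # sequence is split into segments at separator id 3, and each segment scores
+ # 1 for its first B -> non-B transition plus 0.5 for every additional one
+ # (ending a segment on a B token also counts as a transition).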
305
+ B = [9, 14, 18, 21, 27, 33, 37, 39, 42, 45, 50, 51, 52, 54, 58, 59, 61, 62, 63, 69, 73, 74, 79, 85, 99, 100, 102, 105, 119, 120, 121, 122, 123, 124, 141, 143, 144, 145, 146, 157, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 349, 350, 353, 356, 357, 358, 359]
306
+
307
+ total_score = 0
308
+ # Add separator token 3 at the start and end
309
+ tokens = [3] + A + [3]
310
+
311
+ # Split the sequence at token 3
312
+ segment = []
313
+ for t in tokens:
314
+ if t == 3:
315
+ if segment: # process one segment
316
+ count = 0
317
+ for i in range(len(segment) - 1):
318
+ if (segment[i] in B and segment[i+1] not in B):
319
+ # print(f"{segment[i]} in B and {segment[i+1]} not in B)")
320
+ count += 1
321
+ if segment[-1] in B:
322
+ # print(f"{segment[-1]} in B")
323
+ count += 1
324
+ if count > 0:
325
+ total_score += 1 + (count - 1) * 0.5
326
+ segment = []
327
+ else:
328
+ segment.append(t)
329
+
330
+ return total_score
331
+
332
+
333
+ def trim_leading_silence_torch(
334
+ wav: torch.Tensor,
335
+ sample_rate: int,
336
+ silence_thresh: float = 0.05,
337
+ chunk_ms: int = 10,
338
+ extend_ms: int = 20,
339
+ ratio: float = 0.95, # fraction of samples that must be below the threshold to count as silence
340
+ ):
341
+ wav_np = wav.squeeze(0).cpu().numpy().astype(np.float32)
342
+ norm_wav = wav_np / (np.max(np.abs(wav_np)) + 1e-8)
343
+
344
+ chunk_size = int(sample_rate * chunk_ms / 1000)
345
+ total_chunks = int(len(norm_wav) / chunk_size)
346
+
347
+ start_idx = 0
348
+ for i in range(total_chunks):
349
+ chunk = norm_wav[i * chunk_size : (i + 1) * chunk_size]
350
+ # Fraction of samples below the threshold
351
+ silent_ratio = np.mean(np.abs(chunk) < silence_thresh)
352
+ if silent_ratio < ratio: # fewer than 95% of samples are silent -> treat as speech
353
+ start_idx = max(0, i * chunk_size - int(sample_rate * extend_ms / 1000))
354
+ break
355
+
356
+ return wav[:, start_idx:]
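+ # Illustrative example: with the defaults, audio is trimmed up to the first
+ # 10 ms chunk in which fewer than 95% of samples fall below the amplitude
+ # threshold, and the cut point is then moved 20 ms earlier as a safety margin.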
357
+
358
+
359
+
360
+
361
+ @torch.inference_mode()
362
+ def run_zipvoice(
363
+ model_name="zipvoice",
364
+ model_dir="zipvoice_finetune",
365
+ checkpoint_name="model.pt",
366
+ vocoder_path=None,
367
+ tokenizer_name="emilia",
368
+ lang="en-us",
369
+ test_list=None, # path to tsv file
370
+ prompt_wav=None,
371
+ prompt_text=None,
372
+ text=None,
373
+ res_dir="results",
374
+ res_wav_path="result.wav",
375
+ guidance_scale=None,
376
+ num_step=None,
377
+ feat_scale=0.1,
378
+ speed=1.0,
379
+ t_shift=0.5,
380
+ target_rms=0.1,
381
+ seed=666,
382
+ ):
383
+ text = text.lower() if text is not None else text # text is None when test_list is used
384
+ # --- Default settings per model ---
385
+ model_defaults = {
386
+ "zipvoice": {"num_step": 16, "guidance_scale": 1.0},
387
+ "zipvoice_distill": {"num_step": 8, "guidance_scale": 3.0},
388
+ }
389
+ # assign defaults explicitly (no longer via locals())
390
+ if guidance_scale is None:
391
+ guidance_scale = model_defaults.get(model_name, {}).get("guidance_scale", 1.0)
392
+ if num_step is None:
393
+ num_step = model_defaults.get(model_name, {}).get("num_step", 16)
394
+
395
+ # --- Check inputs ---
396
+ assert (test_list is not None) ^ (prompt_wav is not None and prompt_text is not None and text is not None), \
397
+ "Need either test_list or (prompt_wav + prompt_text + text)"
398
+
399
+ fix_random_seed(seed)
400
+
401
+ # --- Load tokenizer, model, vocoder, features ... (kept unchanged) ---
402
+ # [the tokenizer/model/vocoder/feature_extractor/sampling_rate loading above is reused as-is]
403
+
404
+ # ---------------------------
405
+ # NEW: helper that splits the input text into chunks
406
+ # ---------------------------
407
+ def split_text_into_chunks(s: str, min_chars: int = 15, max_chars: int = 30):
408
+ """
409
+ Split on ',' or '.', then merge/split so each chunk's length lies in [min_chars, max_chars].
410
+ Never cuts in the middle of a word.
411
+ """
412
+ # normalize whitespace
413
+ s = re.sub(r"\s+", " ", (s or "").strip())
414
+ if not s:
415
+ return []
416
+
417
+ # split on ',' or '.'
418
+ raw_segs = [seg.strip() for seg in re.split(r"\s*[.,]\s*", s) if seg.strip()]
419
+
420
+ chunks = []
421
+ i = 0
422
+ while i < len(raw_segs):
423
+ cur = raw_segs[i]
424
+ i += 1
425
+
426
+ # keep merging following segments while cur is too short
427
+ while len(cur) < min_chars and i < len(raw_segs):
428
+ cur = (cur + ", " + raw_segs[i]).strip()
429
+ i += 1
430
+
431
+ # if cur is too long, split at word boundaries to stay <= max_chars
432
+ if len(cur) > max_chars:
433
+ words = cur.split()
434
+ buf = []
435
+ cur_len = 0
436
+ for w in words:
437
+ # +1 for the separating space if needed
438
+ add_len = len(w) if cur_len == 0 else len(w) + 1
439
+ if cur_len + add_len <= max_chars:
440
+ buf.append(w)
441
+ cur_len += add_len
442
+ else:
443
+ # close the current chunk
444
+ part = ", ".join(buf).strip()
445
+ if part:
446
+ chunks.append(part)
447
+ # start a new chunk
448
+ buf = [w]
449
+ cur_len = len(w)
450
+ # the remainder
451
+ last = " ".join(buf).strip()
452
+ if last:
453
+ # if the tail is still < min_chars and can be merged with the previous chunk
454
+ if len(last) < min_chars and chunks:
455
+ merged = (chunks[-1] + " " + last).strip()
456
+ if len(merged) <= max_chars:
457
+ chunks[-1] = merged
458
+ else:
459
+ chunks.append(last) # accept it as-is (rare in practice)
460
+ else:
461
+ chunks.append(last)
462
+ else:
463
+ chunks.append(cur)
464
+
465
+ # final pass: if the last chunk is too short, merge it into the previous one
466
+ if len(chunks) >= 2 and len(chunks[-1]) < min_chars:
467
+ merged = (chunks[-2] + ", " + chunks[-1]).strip()
468
+ if len(merged) <= max_chars:
469
+ chunks[-2] = merged
470
+ chunks.pop()
471
+ # print(chunks)
472
+ final_chunk = []
473
+ for chunk in chunks:
474
+ chunk = ", " + chunk + ","
475
+ final_chunk.append(chunk)
476
+ return final_chunk
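+ # Note: each returned chunk is wrapped as ", <chunk>," above; the surrounding
+ # commas presumably act as pause cues for the synthesizer.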
477
+
478
+ # ---------------------------
479
+ # MODIFIED: generate_sentence synthesizes chunk by chunk, then concatenates
480
+ # ---------------------------
481
+ def generate_sentence(save_path, prompt_text, prompt_wav, text):
482
+ # normalize & split the text into chunks
483
+ segments = split_text_into_chunks(text, min_chars=50, max_chars=200)
484
+ if not segments:
485
+ # nothing to say: write a 0.2 s silent file
486
+ silence = torch.zeros((1, int(0.2 * sampling_rate)))
487
+ torchaudio.save(save_path, silence, sample_rate=sampling_rate)
488
+ return
489
+
490
+ # prepare the prompt (done once)
491
+ prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
492
+ prompt_wav_tensor, sr = torchaudio.load(prompt_wav)
493
+ if sr != sampling_rate:
494
+ prompt_wav_tensor = torchaudio.transforms.Resample(sr, sampling_rate)(prompt_wav_tensor)
495
+ prompt_rms_val = torch.sqrt(torch.mean(prompt_wav_tensor**2))
496
+ if prompt_rms_val < target_rms:
497
+ prompt_wav_tensor *= target_rms / prompt_rms_val
498
+
499
+ prompt_features = feature_extractor.extract(
500
+ prompt_wav_tensor, sampling_rate=sampling_rate
501
+ ).to(device)
502
+ prompt_features = prompt_features.unsqueeze(0) * feat_scale
503
+ prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
504
+ # print(prompt_features_lens)
505
+
506
+ num_space_prompt = prompt_text.count(" ")
507
+
508
+ # silence gap between chunks (~0.2 s)
509
+
510
+
511
+ gap_duration = random.uniform(0.17, 0.2) # random value between 0.17 and 0.2 s
512
+ gap = torch.zeros((1, int(gap_duration * sampling_rate)))
513
+
514
+ wav_parts = []
515
+ print("segments",segments)
516
+ for idx, seg in enumerate(segments):
517
+ # print(seg)
518
+ num_space_text = seg.count(" ")
519
+ tokens = tokenizer.texts_to_token_ids([seg])
520
+ # print(tokens)
521
+ score = score_tokens(tokens[0])
522
+ # print(score)
523
+ # print(prompt_tokens)
524
+ score_prompt = score_tokens(prompt_tokens[0])
525
+ # print(score_prompt)
526
+ vocab_file = "zipvoice_finetune/tokens.txt" # txt vocab file in the char<TAB>id format read by load_vocab
527
+
528
+ id2char = load_vocab(vocab_file)
529
+ decoded_text = tokens_to_text(tokens[0], id2char)
530
+
531
+ print(decoded_text)
532
+
533
+ pred_features, _, _, _ = model.sample(
534
+ num_space_text=[num_space_text],
535
+ num_space_prompt=[num_space_prompt],
536
+ tokens=tokens,
537
+ prompt_tokens=prompt_tokens,
538
+ prompt_features=prompt_features,
539
+ prompt_features_lens=prompt_features_lens,
540
+ speed= speed,
541
+ t_shift= t_shift,
542
+ duration="predict",
543
+ num_step= num_step,
544
+ guidance_scale= guidance_scale,
545
+ )
546
+ pred_features = pred_features.permute(0, 2, 1) / feat_scale
547
+ wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
548
+
549
+ # restore the loudness level relative to the prompt
550
+ if prompt_rms_val < target_rms:
551
+ wav *= prompt_rms_val / target_rms
552
+ wav = trim_leading_silence_torch(
553
+ wav, sample_rate=sampling_rate, silence_thresh=0.086, chunk_ms=10, extend_ms=20
554
+ )
555
+ wav_parts.append(wav.cpu())
556
+ if idx < len(segments) - 1:
557
+ wav_parts.append(gap) # insert the silence gap
558
+
559
+ final_wav = torch.cat(wav_parts, dim=-1) # [1, T_total]
560
+ torchaudio.save(save_path, final_wav, sample_rate=sampling_rate)
561
+
562
+ # --- generate_list unchanged: it calls generate_sentence, so chunking applies automatically ---
563
+ def generate_list(res_dir, test_list):
564
+ os.makedirs(res_dir, exist_ok=True)
565
+ with open(test_list, "r", encoding="utf-8") as fr:
566
+ for i, line in enumerate(fr):
567
+ wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
568
+ save_path = f"{res_dir}/{wav_name}.wav"
569
+ generate_sentence(save_path, prompt_text, prompt_wav, text)
570
+
571
+ # --- Run ---
572
+ if test_list:
573
+ generate_list(res_dir, test_list)
574
+ else:
575
+ generate_sentence(res_wav_path, prompt_text, prompt_wav, text)
576
+
577
+ print("βœ… Done!")
578
+ return text
proccess_wav.py ADDED
@@ -0,0 +1,364 @@
1
+ from typing import List, Tuple
2
+ import numpy as np
3
+ from pydub import AudioSegment
4
+ import os
5
+ from chunkformer import ChunkFormerModel
6
+ from clearvoice import ClearVoice
7
+ # ======================= ASR + CLEARVOICE + AUDIO PROCESSING =======================
8
+
9
+ ASR_MODEL = None
10
+ CLEARVOICE_MODEL = None
11
+ REF_AUDIO_CACHE = {} # cache: input path -> processed output path
12
+
13
+
14
+ def get_asr_model() -> ChunkFormerModel:
15
+ """Lazy-load ChunkFormer (ASR, runs on CPU)."""
16
+ global ASR_MODEL
17
+ if ASR_MODEL is None:
18
+ ASR_MODEL = ChunkFormerModel.from_pretrained("khanhld/chunkformer-ctc-large-vie")
19
+ return ASR_MODEL
20
+
21
+
22
+ def get_clearvoice_model() -> ClearVoice:
23
+ """Lazy-load ClearVoice for denoising the reference audio."""
24
+ global CLEARVOICE_MODEL
25
+ if CLEARVOICE_MODEL is None:
26
+ CLEARVOICE_MODEL = ClearVoice(
27
+ task="speech_enhancement",
28
+ model_names=["MossFormer2_SE_48K"],
29
+ )
30
+ return CLEARVOICE_MODEL
31
+
32
+
33
+ def find_silent_regions(
34
+ audio: AudioSegment,
35
+ silence_thresh: float = 0.05, # amplitude after normalization to [-1, 1]
36
+ chunk_ms: int = 10,
37
+ min_silence_len: int = 200,
38
+ ) -> List[Tuple[int, int]]:
39
+ """
40
+ Find silent regions (start_ms, end_ms) in an AudioSegment based on amplitude.
41
+ """
42
+ samples = np.array(audio.get_array_of_samples(), dtype=np.float32)
43
+ if audio.channels > 1:
44
+ samples = samples.reshape((-1, audio.channels)).mean(axis=1)
45
+
46
+ norm = samples / (2 ** (audio.sample_width * 8 - 1))
47
+ sr = audio.frame_rate
48
+
49
+ chunk_size = max(1, int(sr * chunk_ms / 1000))
50
+ total_chunks = len(norm) // chunk_size
51
+
52
+ silent_regions: List[Tuple[int, int]] = []
53
+ start = None
54
+ for i in range(total_chunks):
55
+ chunk = norm[i * chunk_size: (i + 1) * chunk_size]
56
+ if chunk.size == 0:
57
+ continue
58
+
59
+ if np.all((chunk > -silence_thresh) & (chunk < silence_thresh)):
60
+ if start is None:
61
+ start = i
62
+ else:
63
+ if start is not None:
64
+ dur = (i - start) * chunk_ms
65
+ if dur >= min_silence_len:
66
+ silent_regions.append((start * chunk_ms, i * chunk_ms))
67
+ start = None
68
+
69
+ if start is not None:
70
+ dur = (total_chunks - start) * chunk_ms
71
+ if dur >= min_silence_len:
72
+ silent_regions.append((start * chunk_ms, total_chunks * chunk_ms))
73
+
74
+ return silent_regions
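+ # Illustrative example: with chunk_ms=10 and min_silence_len=200, a region
+ # only counts as silence once at least 20 consecutive 10 ms chunks stay
+ # below the threshold.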
75
+
76
+
77
+ def trim_leading_trailing_silence(
78
+ audio: AudioSegment,
79
+ silence_thresh: float = 0.05,
80
+ chunk_ms: int = 10,
81
+ min_silence_len: int = 200,
82
+ ) -> AudioSegment:
83
+ """
84
+ Bỏ khoαΊ£ng lαΊ·ng Δ‘αΊ§u/cuα»‘i file.
85
+ """
86
+ duration = len(audio)
87
+ silent_regions = find_silent_regions(
88
+ audio,
89
+ silence_thresh=silence_thresh,
90
+ chunk_ms=chunk_ms,
91
+ min_silence_len=min_silence_len,
92
+ )
93
+
94
+ if not silent_regions:
95
+ return audio
96
+
97
+ start_trim = 0
98
+ end_trim = duration
99
+
100
+ # silence at the start of the file
101
+ first_start, first_end = silent_regions[0]
102
+ if first_start <= 0:
103
+ start_trim = max(start_trim, first_end)
104
+
105
+ # silence at the end of the file
106
+ last_start, last_end = silent_regions[-1]
107
+ if last_end >= duration:
108
+ end_trim = min(end_trim, last_start)
109
+
110
+ return audio[start_trim:end_trim]
111
+
112
+
113
+ def compress_internal_silence(
114
+ audio: AudioSegment,
115
+ max_silence_ms: int = 300,
116
+ silence_thresh: float = 0.05,
117
+ chunk_ms: int = 10,
118
+ min_silence_len: int = 50,
119
+ ) -> AudioSegment:
120
+ """
121
+ Shorten internal silences:
122
+ - silences <= max_silence_ms: kept as-is
123
+ - silences > max_silence_ms: truncated to max_silence_ms
124
+ """
125
+ duration = len(audio)
126
+ silent_regions = find_silent_regions(
127
+ audio,
128
+ silence_thresh=silence_thresh,
129
+ chunk_ms=chunk_ms,
130
+ min_silence_len=min_silence_len,
131
+ )
132
+ if not silent_regions:
133
+ return audio
134
+
135
+ new_audio = AudioSegment.silent(duration=0, frame_rate=audio.frame_rate)
136
+ cursor = 0
137
+
138
+ for s_start, s_end in silent_regions:
139
+ # the speech part before the silence
140
+ if s_start > cursor:
141
+ new_audio += audio[cursor:s_start]
142
+
143
+ silence_len = s_end - s_start
144
+ if silence_len <= max_silence_ms:
145
+ new_audio += audio[s_start:s_end]
146
+ else:
147
+ new_audio += audio[s_start: s_start + max_silence_ms]
148
+
149
+ cursor = s_end
150
+
151
+ # the remaining part after the last silence
152
+ if cursor < duration:
153
+ new_audio += audio[cursor:]
154
+
155
+ return new_audio
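+ # Illustrative example: with max_silence_ms=300, a 1.2 s silent gap is cut
+ # down to 0.3 s, while gaps of 0.3 s or shorter are left untouched.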
156
+
157
+
158
+ def select_subsegment_by_silence(
159
+ audio: AudioSegment,
160
+ min_len_ms: int = 5000,
161
+ max_len_ms: int = 10000,
162
+ silence_thresh: float = 0.05,
163
+ chunk_ms: int = 10,
164
+ min_silence_len: int = 200,
165
+ ) -> AudioSegment:
166
+ """
167
+ If the audio is longer than max_len_ms, pick one segment whose length lies in [min_len_ms, max_len_ms],
168
+ cutting at points inside silent regions to avoid cutting through speech.
169
+ """
170
+ duration = len(audio)
171
+ if duration <= max_len_ms:
172
+ return audio
173
+
174
+ silent_regions = find_silent_regions(
175
+ audio,
176
+ silence_thresh=silence_thresh,
177
+ chunk_ms=chunk_ms,
178
+ min_silence_len=min_silence_len,
179
+ )
180
+
181
+ if not silent_regions:
182
+ # no silent regions found -> take the middle segment
183
+ target_len = min(max_len_ms, duration)
184
+ start = max(0, (duration - target_len) // 2)
185
+ end = start + target_len
186
+ return audio[start:end]
187
+
188
+ # each boundary is the midpoint of a silent region (guaranteed to lie inside silence)
189
+ boundaries = [0]
190
+ for s_start, s_end in silent_regions:
191
+ mid = (s_start + s_end) // 2
192
+ if 0 < mid < duration:
193
+ boundaries.append(mid)
194
+ boundaries.append(duration)
195
+ boundaries = sorted(set(boundaries))
196
+
197
+ # prefer the first segment that satisfies the 5-10 s window
198
+ for i in range(len(boundaries)):
199
+ for j in range(i + 1, len(boundaries)):
200
+ seg_len = boundaries[j] - boundaries[i]
201
+ if min_len_ms <= seg_len <= max_len_ms:
202
+ return audio[boundaries[i]:boundaries[j]]
203
+
204
+ # if no segment fits entirely within [min, max], choose the one closest to max_len
205
+ best_i, best_j, best_diff = 0, None, None
206
+ for i in range(len(boundaries)):
207
+ for j in range(i + 1, len(boundaries)):
208
+ seg_len = boundaries[j] - boundaries[i]
209
+ if seg_len >= min_len_ms:
210
+ diff = abs(seg_len - max_len_ms)
211
+ if best_diff is None or diff < best_diff:
212
+ best_diff = diff
213
+ best_i, best_j = i, j
214
+
215
+ if best_j is not None:
216
+ return audio[boundaries[best_i]:boundaries[best_j]]
217
+
218
+ # final fallback
219
+ target_len = min(max_len_ms, duration)
220
+ start = max(0, (duration - target_len) // 2)
221
+ end = start + target_len
222
+ return audio[start:end]
223
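+ # Sketch of the selection above (hypothetical values): for a 25 s clip with
+ # silence midpoints at 4 s, 9 s and 18 s, boundaries = [0, 4000, 9000,
+ # 18000, 25000]; the first pair whose gap lies in [5000, 10000] is
+ # (0, 9000), so audio[0:9000] is returned.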
+
224
+
225
+ def enhance_ref_audio(input_path: str) -> str:
226
+ """
227
+ WAV preprocessing pipeline for TTS:
228
+ - ClearVoice denoising
229
+ - trim leading/trailing silence
230
+ - shorten internal silences longer than 0.3 s to 0.3 s
231
+ - if the audio is longer than 10 s: pick one 5-10 s segment, cut at silences
232
+ Returns the path of the processed wav file.
233
+ """
234
+ if not input_path:
235
+ raise ValueError("No input audio path for enhancement.")
236
+
237
+ # cache so the same file is not processed more than once
238
+ if input_path in REF_AUDIO_CACHE:
239
+ return REF_AUDIO_CACHE[input_path]
240
+
241
+ cv = get_clearvoice_model()
242
+
243
+ base = os.path.basename(input_path)
244
+ name, ext = os.path.splitext(base)
245
+ if not ext:
246
+ ext = ".wav"
247
+ # 1) denoise (name/ext computed above so the fallback path keeps them defined)
248
+ try:
249
+ cv_out = cv(input_path=input_path, online_write=False)
250
+ denoised_path = os.path.join(os.path.dirname(input_path), f"{name}_denoised{ext}")
251
+ cv.write(cv_out, output_path=denoised_path)
252
+ except Exception as e:
253
+ print(f"[ClearVoice] Error during denoising, falling back to the original: {e}")
254
+ denoised_path = input_path
255
+
256
+ # 2) pydub handles silence trimming and length selection
257
+ audio = AudioSegment.from_file(denoised_path)
258
+
259
+ # trim leading/trailing silence
260
+ audio = trim_leading_trailing_silence(audio)
261
+
262
+ # shorten internal silences
263
+ audio = compress_internal_silence(audio, max_silence_ms=300)
264
+
265
+ # if longer than 10 s, pick a segment in the 5-10 s range
266
+ audio = select_subsegment_by_silence(audio, min_len_ms=5000, max_len_ms=10000)
267
+
268
+ # 3) write out a new file
269
+ enhanced_path = os.path.join(os.path.dirname(denoised_path), f"{name}_enhanced.wav")
270
+ audio.export(enhanced_path, format="wav")
271
+
272
+ REF_AUDIO_CACHE[input_path] = enhanced_path
273
+ return enhanced_path
274
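+ # End-to-end sketch (assumes REF_AUDIO_CACHE and get_clearvoice_model() are
+ # defined earlier in this module):
+ #
+ #   ref = enhance_ref_audio("speaker.wav")   # hypothetical input path
+ #   ref2 = enhance_ref_audio("speaker.wav")  # same path: served from the cache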
+
275
+ def split_audio_by_silence(
276
+ audio: AudioSegment,
277
+ silence_thresh: float = 0.05,
278
+ chunk_ms: int = 10,
279
+ min_silence_len: int = 200,
280
+ min_segment_len: int = 200,
281
+ ) -> List[Tuple[int, int]]:
282
+ """
283
+ Tα»« AudioSegment, trαΊ£ về cΓ‘c Δ‘oαΊ‘n cΓ³ tiαΊΏng nΓ³i (non-silent)
284
+ được tΓ‘ch bαΊ±ng khoαΊ£ng lαΊ·ng.
285
+ """
286
+ duration = len(audio)
287
+ silent_regions = find_silent_regions(
288
+ audio,
289
+ silence_thresh=silence_thresh,
290
+ chunk_ms=chunk_ms,
291
+ min_silence_len=min_silence_len,
292
+ )
293
+
294
+ segments: List[Tuple[int, int]] = []
295
+ cur_start = 0
296
+
297
+ for s_start, s_end in silent_regions:
298
+ if cur_start < s_start:
299
+ if s_start - cur_start >= min_segment_len:
300
+ segments.append((cur_start, s_start))
301
+ cur_start = s_end
302
+
303
+ if cur_start < duration and duration - cur_start >= min_segment_len:
304
+ segments.append((cur_start, duration))
305
+
306
+ # if no segments were found, use the whole file
307
+ if not segments:
308
+ segments.append((0, duration))
309
+
310
+ return segments
311
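+ # Illustrative: with silent regions [(1000, 1400), (3000, 3300)] in a 5 s
+ # clip, the speech segments are [(0, 1000), (1400, 3000), (3300, 5000)],
+ # keeping only those at least min_segment_len long.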
+
312
+
313
+ def transcribe_ref_audio(audio_path: str) -> str:
314
+ """
315
+ ASR theo yΓͺu cαΊ§u:
316
+ - CαΊ―t Γ’m thanh theo khoαΊ£ng lαΊ·ng
317
+ - ASR tα»«ng Δ‘oαΊ‘n
318
+ - Nα»‘i text bαΊ±ng dαΊ₯u phαΊ©y
319
+ """
320
+ if not audio_path:
321
+ raise ValueError("No audio path for ASR.")
322
+
323
+ model = get_asr_model()
324
+ audio = AudioSegment.from_file(audio_path)
325
+ segments = split_audio_by_silence(audio)
326
+
327
+ texts = []
328
+ base, _ = os.path.splitext(audio_path)
329
+
330
+ for idx, (start_ms, end_ms) in enumerate(segments):
331
+ seg_audio = audio[start_ms:end_ms]
332
+ seg_path = f"{base}_seg_{idx}.wav"
333
+ seg_audio.export(seg_path, format="wav")
334
+
335
+ try:
336
+ transcription = model.endless_decode(
337
+ audio_path=seg_path,
338
+ chunk_size=32,
339
+ left_context_size=0,
340
+ right_context_size=0,
341
+ total_batch_duration=400,
342
+ return_timestamps=False,
343
+ )
344
+ except TypeError:
345
+ transcription = model.endless_decode(
346
+ audio_path=seg_path,
347
+ chunk_size=32,
348
+ left_context_size=0,
349
+ right_context_size=0,
350
+ total_batch_duration=400,
351
+ )
352
+
353
+ if isinstance(transcription, str):
354
+ text = transcription
355
+ else:
356
+ text = str(transcription)
357
+
358
+ text = text.strip()
359
+ if text:
360
+ texts.append(text)
361
+
362
+ return ", ".join(texts)
363
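+ # Usage sketch (assumes get_asr_model() returns a model exposing the
+ # endless_decode method used above): each speech segment is exported to
+ # "<base>_seg_<idx>.wav", transcribed, and the pieces joined with commas:
+ #
+ #   text = transcribe_ref_audio("speaker.wav")  # e.g. "hello there, how are you"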
+
364
+
pyproject.toml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [tool.isort]
2
+ profile = "black"
3
+
4
+ [tool.black]
5
+ line-length = 88
requirements.txt ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --find-links https://k2-fsa.github.io/icefall/piper_phonemize.html
2
+ transformers==4.57.1
3
+ torch
4
+ torchaudio
5
+ torchcodec
6
+ numpy
7
+ lhotse
8
+ huggingface_hub
9
+ safetensors
10
+ tensorboard
11
+ vocos
12
+
13
+ # Normalization
14
+ cn2an
15
+ inflect
16
+
17
+ # Tokenization
18
+ jieba
19
+ piper_phonemize
20
+ pypinyin
21
+ setuptools<81
22
+ chunkformer
23
+ clearvoice
requirements_eval.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ numpy
3
+
4
+ # Audio processing
5
+ librosa
6
+ soundfile
7
+
8
+ # Model
9
+ s3prl
10
+ pyannote.audio
11
+ funasr
12
+ transformers
13
+
14
+ # WER
15
+ jiwer==3.1.0
16
+
17
+ # Normalization
18
+ zhconv
19
+ zhon
setup.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import requests
4
+ from dotenv import load_dotenv
5
+
6
+ def run_cmd(cmd):
7
+ print(f"🔹 Running command: {cmd}")
8
+ result = subprocess.run(cmd, shell=True)
9
+ if result.returncode != 0:
10
+ raise RuntimeError(f"Command failed: {cmd}")
11
+
12
+ def download_with_token(url, dest_path, token):
13
+ headers = {"Authorization": f"Bearer {token}"}
14
+ with requests.get(url, headers=headers, stream=True) as r:
15
+ r.raise_for_status()
16
+ with open(dest_path, "wb") as f:
17
+ for chunk in r.iter_content(chunk_size=8192):
18
+ f.write(chunk)
19
+ print(f"✅ Downloaded: {dest_path}")
20
+
21
+ def main():
22
+ # Load environment variables from .env
23
+ load_dotenv()
24
+ token = os.getenv("HF_TOKEN")
25
+
26
+ if not token:
27
+ raise EnvironmentError("❌ Missing HF_TOKEN environment variable. Create a .env file with the line:\nHF_TOKEN=hf_your_token_here")
28
+
29
+ # Log in to the Hugging Face CLI
30
+ run_cmd(f"huggingface-cli login --token {token}")
31
+
32
+ # Create the directory that holds the model
33
+ os.makedirs("zipvoice_finetune", exist_ok=True)
34
+
35
+ # Files to download
36
+ files = {
37
+ "iter-525000-avg-2.pt": "https://huggingface.co/datasets/meandyou200175/temp_file/resolve/main/zip/epoch-46-all-speak-600h-en-norm.pt",
38
+ "model.json": "https://huggingface.co/datasets/meandyou200175/temp_file/resolve/main/zip/model.json",
39
+ "tokens.txt": "https://huggingface.co/datasets/meandyou200175/temp_file/resolve/main/zip/tokens.txt",
40
+ }
41
+
42
+ for filename, url in files.items():
43
+ dest = os.path.join("zipvoice_finetune", filename)
44
+ download_with_token(url, dest, token)
45
+
46
+ # Install requirements
47
+ if os.path.exists("requirements.txt"):
48
+ run_cmd("pip install -r requirements.txt")
49
+ else:
50
+ print("⚠️ KhΓ΄ng tΓ¬m thαΊ₯y requirements.txt")
51
+
52
+ print("\nπŸŽ‰ Setup hoΓ n tαΊ₯t!")
53
+
54
+ if __name__ == "__main__":
55
+ main()
zipvoice/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ warnings.filterwarnings(
4
+ "ignore",
5
+ category=UserWarning,
6
+ message="pkg_resources is deprecated as an API.*",
7
+ )
zipvoice/bin/compute_fbank.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2024-2025 Xiaomi Corp. (authors: Wei Kang
3
+ # Han Zhu)
4
+ #
5
+ # See ../../../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+ """
19
+ Usage:
20
+ python3 -m zipvoice.bin.compute_fbank \
21
+ --source-dir data/manifests \
22
+ --dest-dir data/fbank \
23
+ --dataset libritts \
24
+ --subset dev-other \
25
+ --sampling-rate 24000 \
26
+ --num-jobs 20
27
+
28
+ The input would be data/manifests/libritts_cuts_dev-other.jsonl.gz or
29
+ (libritts_supervisions_dev-other.jsonl.gz and libritts_recordings_dev-other.jsonl.gz)
30
+
31
+ The output would be data/fbank/libritts_cuts_dev-other.jsonl.gz
32
+ """
33
+
34
+
35
+ import argparse
36
+ import logging
37
+ from concurrent.futures import ProcessPoolExecutor as Pool
38
+ from pathlib import Path
39
+
40
+ import lhotse
41
+ import torch
42
+ from lhotse import CutSet, LilcomChunkyWriter, load_manifest_lazy
43
+
44
+ from zipvoice.utils.common import str2bool
45
+ from zipvoice.utils.feature import VocosFbank
46
+
47
+ # Torch's multithreaded behavior needs to be disabled or
48
+ # it wastes a lot of CPU and slows things down.
49
+ # Do this outside of main() in case it needs to take effect
50
+ # even when we are not invoking the main (e.g. when spawning subprocesses).
51
+ torch.set_num_threads(1)
52
+ torch.set_num_interop_threads(1)
53
+
54
+ lhotse.set_audio_duration_mismatch_tolerance(0.1)
55
+
56
+
57
+ def get_args():
58
+ parser = argparse.ArgumentParser()
59
+
60
+ parser.add_argument(
61
+ "--sampling-rate",
62
+ type=int,
63
+ default=24000,
64
+ help="The target sampling rate, the audio will be resampled to it.",
65
+ )
66
+
67
+ parser.add_argument(
68
+ "--type",
69
+ type=str,
70
+ default="vocos",
71
+ help="fbank type",
72
+ )
73
+
74
+ parser.add_argument(
75
+ "--dataset",
76
+ type=str,
77
+ help="Dataset name.",
78
+ )
79
+
80
+ parser.add_argument(
81
+ "--subset",
82
+ type=str,
83
+ help="The subset of the dataset.",
84
+ )
85
+
86
+ parser.add_argument(
87
+ "--source-dir",
88
+ type=str,
89
+ default="data/manifests",
90
+ help="The source directory of manifest files.",
91
+ )
92
+
93
+ parser.add_argument(
94
+ "--dest-dir",
95
+ type=str,
96
+ default="data/fbank",
97
+ help="The destination directory of manifest files.",
98
+ )
99
+
100
+ parser.add_argument(
101
+ "--split-cuts",
102
+ type=str2bool,
103
+ default=False,
104
+ help="Whether to use splited cuts.",
105
+ )
106
+
107
+ parser.add_argument(
108
+ "--split-begin",
109
+ type=int,
110
+ help="Start idx of splited cuts.",
111
+ )
112
+
113
+ parser.add_argument(
114
+ "--split-end",
115
+ type=int,
116
+ help="End idx of splited cuts.",
117
+ )
118
+
119
+ parser.add_argument(
120
+ "--batch-duration",
121
+ type=int,
122
+ default=1000,
123
+ help="The batch duration when computing the features.",
124
+ )
125
+
126
+ parser.add_argument(
127
+ "--num-jobs",
128
+ type=int,
129
+ default=20,
130
+ help="The number of extractor workers.",
131
+ )
132
+
133
+ return parser.parse_args()
134
+
135
+
136
+ def compute_fbank_split_single(params, idx):
137
+ logging.info(
138
+ f"Computing features for {idx}-th split of "
139
+ f"{params.dataset} dataset {params.subset} subset"
140
+ )
141
+ lhotse.set_audio_duration_mismatch_tolerance(0.1) # for emilia
142
+ src_dir = Path(params.source_dir)
143
+ output_dir = Path(params.dest_dir)
144
+
145
+ if not src_dir.exists():
146
+ logging.error(f"{src_dir} not exists")
147
+ return
148
+
149
+ if not output_dir.exists():
150
+ output_dir.mkdir(parents=True, exist_ok=True)
151
+
152
+ num_digits = 8
153
+ if params.type == "vocos":
154
+ extractor = VocosFbank()
155
+ else:
156
+ raise NotImplementedError(f"{params.type} is not supported")
157
+
158
+ prefix = params.dataset
159
+ subset = params.subset
160
+ suffix = "jsonl.gz"
161
+
162
+ idx = f"{idx}".zfill(num_digits)
163
+ cuts_filename = f"{prefix}_cuts_{subset}.{idx}.{suffix}"
164
+
165
+ if (src_dir / cuts_filename).is_file():
166
+ logging.info(f"Loading manifests {src_dir / cuts_filename}")
167
+ cut_set = load_manifest_lazy(src_dir / cuts_filename)
168
+ else:
169
+ logging.warning(f"Raw {cuts_filename} not exists, skipping")
170
+ return
171
+
172
+ cut_set = cut_set.resample(params.sampling_rate)
173
+
174
+ if (output_dir / cuts_filename).is_file():
175
+ logging.info(f"{cuts_filename} already exists - skipping.")
176
+ return
177
+
178
+ logging.info(f"Processing {subset}.{idx} of {prefix}")
179
+
180
+ cut_set = cut_set.compute_and_store_features_batch(
181
+ extractor=extractor,
182
+ storage_path=f"{output_dir}/{prefix}_feats_{subset}_{idx}",
183
+ num_workers=4,
184
+ batch_duration=params.batch_duration,
185
+ storage_type=LilcomChunkyWriter,
186
+ overwrite=True,
187
+ )
188
+ logging.info(f"Saving file to {output_dir / cuts_filename}")
189
+ cut_set.to_file(output_dir / cuts_filename)
190
+
191
+
192
+ def compute_fbank_split(params):
193
+ if params.split_end < params.split_begin:
194
+ logging.warning(
195
+ f"Split begin should be smaller than split end, given "
196
+ f"{params.split_begin} -> {params.split_end}."
197
+ )
198
+
199
+ with Pool(max_workers=params.num_jobs) as pool:
200
+ futures = [
201
+ pool.submit(compute_fbank_split_single, params, i)
202
+ for i in range(params.split_begin, params.split_end)
203
+ ]
204
+ for f in futures:
205
+ f.result()
206
+ f.done()
207
+
208
+
209
+ def compute_fbank(params):
210
+ logging.info(
211
+ f"Computing features for {params.dataset} dataset {params.subset} subset"
212
+ )
213
+ src_dir = Path(params.source_dir)
214
+ output_dir = Path(params.dest_dir)
215
+ num_jobs = params.num_jobs
216
+ if not output_dir.exists():
217
+ output_dir.mkdir(parents=True, exist_ok=True)
218
+
219
+ prefix = params.dataset
220
+ subset = params.subset
221
+ suffix = "jsonl.gz"
222
+
223
+ cut_set_name = f"{prefix}_cuts_{subset}.{suffix}"
224
+
225
+ if (src_dir / cut_set_name).is_file():
226
+ logging.info(f"Loading manifests {src_dir / cut_set_name}")
227
+ cut_set = load_manifest_lazy(src_dir / cut_set_name)
228
+ else:
229
+ recordings = load_manifest_lazy(
230
+ src_dir / f"{prefix}_recordings_{subset}.{suffix}"
231
+ )
232
+ supervisions = load_manifest_lazy(
233
+ src_dir / f"{prefix}_supervisions_{subset}.{suffix}"
234
+ )
235
+ cut_set = CutSet.from_manifests(
236
+ recordings=recordings,
237
+ supervisions=supervisions,
238
+ )
239
+
240
+ cut_set = cut_set.resample(params.sampling_rate)
241
+ if params.type == "vocos":
242
+ extractor = VocosFbank()
243
+ else:
244
+ raise NotImplementedError(f"{params.type} is not supported")
245
+
246
+ cuts_filename = f"{prefix}_cuts_{subset}.{suffix}"
247
+ if (output_dir / cuts_filename).is_file():
248
+ logging.info(f"{prefix} {subset} already exists - skipping.")
249
+ return
250
+ logging.info(f"Processing {subset} of {prefix}")
251
+
252
+ cut_set = cut_set.compute_and_store_features(
253
+ extractor=extractor,
254
+ storage_path=f"{output_dir}/{prefix}_feats_{subset}",
255
+ num_jobs=num_jobs,
256
+ storage_type=LilcomChunkyWriter,
257
+ )
258
+ logging.info(f"Saving file to {output_dir / cuts_filename}")
259
+ cut_set.to_file(output_dir / cuts_filename)
260
+
261
+
262
+ if __name__ == "__main__":
263
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
264
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
265
+
266
+ args = get_args()
267
+ logging.info(vars(args))
268
+ if args.split_cuts:
269
+ compute_fbank_split(params=args)
270
+ else:
271
+ compute_fbank(params=args)
272
+ logging.info("Done!")
zipvoice/bin/generate_averaged_model.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ #
3
+ # Copyright 2021-2022 Xiaomi Corporation
4
+ #
5
+ # See ../../../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+ """
19
+ Usage:
20
+ This script loads checkpoints and averages them.
21
+
22
+ python3 -m zipvoice.bin.generate_averaged_model \
23
+ --epoch 11 \
24
+ --avg 4 \
25
+ --model-name zipvoice \
26
+ --exp-dir exp/zipvoice
27
+
28
+ It will generate a file `epoch-11-avg-4.pt` in the given `exp_dir`.
29
+ You can later load it by `torch.load("epoch-11-avg-4.pt")`.
30
+ """
31
+
32
+ import argparse
33
+ import json
34
+ import logging
35
+ from pathlib import Path
36
+
37
+ import torch
38
+
39
+ from zipvoice.models.zipvoice import ZipVoice
40
+ from zipvoice.models.zipvoice_dialog import ZipVoiceDialog, ZipVoiceDialogStereo
41
+ from zipvoice.models.zipvoice_distill import ZipVoiceDistill
42
+ from zipvoice.tokenizer.tokenizer import SimpleTokenizer
43
+ from zipvoice.utils.checkpoint import (
44
+ average_checkpoints_with_averaged_model,
45
+ find_checkpoints,
46
+ )
47
+ from zipvoice.utils.common import AttributeDict
48
+
49
+
50
+ def get_parser():
51
+ parser = argparse.ArgumentParser(
52
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
53
+ )
54
+
55
+ parser.add_argument(
56
+ "--epoch",
57
+ type=int,
58
+ default=11,
59
+ help="""It specifies the checkpoint to use for decoding.
60
+ Note: Epoch counts from 1.
61
+ You can specify --avg to use more checkpoints for model averaging.""",
62
+ )
63
+
64
+ parser.add_argument(
65
+ "--iter",
66
+ type=int,
67
+ default=0,
68
+ help="""If positive, --epoch is ignored and it
69
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
70
+ You can specify --avg to use more checkpoints for model averaging.
71
+ """,
72
+ )
73
+
74
+ parser.add_argument(
75
+ "--avg",
76
+ type=int,
77
+ default=4,
78
+ help="Number of checkpoints to average. Automatically select "
79
+ "consecutive checkpoints before the checkpoint specified by "
80
+ "'--epoch' or --iter",
81
+ )
82
+
83
+ parser.add_argument(
84
+ "--exp-dir",
85
+ type=str,
86
+ default="exp/zipvoice",
87
+ help="The experiment dir",
88
+ )
89
+
90
+ parser.add_argument(
91
+ "--model-name",
92
+ type=str,
93
+ default="zipvoice",
94
+ choices=[
95
+ "zipvoice",
96
+ "zipvoice_distill",
97
+ "zipvoice_dialog",
98
+ "zipvoice_dialog_stereo",
99
+ ],
100
+ help="The model type to be averaged. ",
101
+ )
102
+
103
+ return parser
104
+
105
+
106
+ @torch.no_grad()
107
+ def main():
108
+ parser = get_parser()
109
+ args = parser.parse_args()
110
+ params = AttributeDict()
111
+ params.update(vars(args))
112
+ params.exp_dir = Path(params.exp_dir)
113
+
114
+ with open(params.exp_dir / "model.json", "r") as f:
115
+ model_config = json.load(f)
116
+
117
+ # Any tokenizer can be used here.
118
+ # Use SimpleTokenizer for simplicity.
119
+ tokenizer = SimpleTokenizer(token_file=params.exp_dir / "tokens.txt")
120
+ if params.model_name in ["zipvoice", "zipvoice_distill"]:
121
+ tokenizer_config = {
122
+ "vocab_size": tokenizer.vocab_size,
123
+ "pad_id": tokenizer.pad_id,
124
+ }
125
+ elif params.model_name in ["zipvoice_dialog", "zipvoice_dialog_stereo"]:
126
+ tokenizer_config = {
127
+ "vocab_size": tokenizer.vocab_size,
128
+ "pad_id": tokenizer.pad_id,
129
+ "spk_a_id": tokenizer.spk_a_id,
130
+ "spk_b_id": tokenizer.spk_b_id,
131
+ }
132
+
133
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
134
+
135
+ logging.info("Script started")
136
+
137
+ params.device = torch.device("cpu")
138
+ logging.info(f"Device: {params.device}")
139
+
140
+ logging.info("About to create model")
141
+ if params.model_name == "zipvoice":
142
+ model = ZipVoice(
143
+ **model_config["model"],
144
+ **tokenizer_config,
145
+ )
146
+ elif params.model_name == "zipvoice_distill":
147
+ model = ZipVoiceDistill(
148
+ **model_config["model"],
149
+ **tokenizer_config,
150
+ )
151
+ elif params.model_name == "zipvoice_dialog":
152
+ model = ZipVoiceDialog(
153
+ **model_config["model"],
154
+ **tokenizer_config,
155
+ )
156
+ elif params.model_name == "zipvoice_dialog_stereo":
157
+ model = ZipVoiceDialogStereo(
158
+ **model_config["model"],
159
+ **tokenizer_config,
160
+ )
161
+ else:
162
+ raise ValueError(f"Unknown model name: {params.model_name}")
163
+
164
+ if params.iter > 0:
165
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
166
+ : params.avg + 1
167
+ ]
168
+ if len(filenames) == 0:
169
+ raise ValueError(
170
+ f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
171
+ )
172
+ elif len(filenames) < params.avg + 1:
173
+ raise ValueError(
174
+ f"Not enough checkpoints ({len(filenames)}) found for"
175
+ f" --iter {params.iter}, --avg {params.avg}"
176
+ )
177
+ filename_start = filenames[-1]
178
+ filename_end = filenames[0]
179
+ logging.info(
180
+ "Calculating the averaged model over iteration checkpoints"
181
+ f" from {filename_start} (excluded) to {filename_end}"
182
+ )
183
+ model.to(params.device)
184
+ model.load_state_dict(
185
+ average_checkpoints_with_averaged_model(
186
+ filename_start=filename_start,
187
+ filename_end=filename_end,
188
+ device=params.device,
189
+ ),
190
+ strict=True,
191
+ )
192
+ else:
193
+ assert params.avg > 0, params.avg
194
+ start = params.epoch - params.avg
195
+ assert start >= 1, start
196
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
197
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
198
+ logging.info(
199
+ f"Calculating the averaged model over epoch range from "
200
+ f"{start} (excluded) to {params.epoch}"
201
+ )
202
+ model.to(params.device)
203
+ model.load_state_dict(
204
+ average_checkpoints_with_averaged_model(
205
+ filename_start=filename_start,
206
+ filename_end=filename_end,
207
+ device=params.device,
208
+ ),
209
+ strict=True,
210
+ )
211
+ if params.iter > 0:
212
+ filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
213
+ else:
214
+ filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
215
+
216
+ logging.info(f"Saving the averaged checkpoint to {filename}")
217
+ torch.save({"model": model.state_dict()}, filename)
218
+
219
+ num_param = sum([p.numel() for p in model.parameters()])
220
+ logging.info(f"Number of model parameters: {num_param}")
221
+
222
+ logging.info("Done!")
223
+
224
+
225
+ if __name__ == "__main__":
226
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
227
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
228
+
229
+ main()
zipvoice/bin/infer_zipvoice.py ADDED
@@ -0,0 +1,614 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ This script generates speech with our pre-trained ZipVoice or
20
+ ZipVoice-Distill models. If no local model is specified,
21
+ required files will be automatically downloaded from HuggingFace.
22
+
23
+ Usage:
24
+
25
+ Note: If you are having trouble connecting to HuggingFace,
26
+ try switching endpoint to mirror site:
27
+ export HF_ENDPOINT=https://hf-mirror.com
28
+
29
+ (1) Inference of a single sentence:
30
+
31
+ python3 -m zipvoice.bin.infer_zipvoice \
32
+ --model-name zipvoice \
33
+ --prompt-wav prompt.wav \
34
+ --prompt-text "I am a prompt." \
35
+ --text "I am a sentence." \
36
+ --res-wav-path result.wav
37
+
38
+ (2) Inference of a list of sentences:
39
+
40
+ python3 -m zipvoice.bin.infer_zipvoice \
41
+ --model-name zipvoice \
42
+ --test-list test.tsv \
43
+ --res-dir results
44
+
45
+ `--model-name` can be `zipvoice` or `zipvoice_distill`,
46
+ which are the models before and after distillation, respectively.
47
+
48
+ Each line of `test.tsv` is in the format of
49
+ `{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`.
50
+ """
51
+
52
+ import argparse
53
+ import datetime as dt
54
+ import json
55
+ import logging
56
+ import os
57
+ from pathlib import Path
58
+ from typing import Optional
59
+
60
+ import numpy as np
61
+ import safetensors.torch
62
+ import torch
63
+ import torchaudio
64
+ from huggingface_hub import hf_hub_download
65
+ from lhotse.utils import fix_random_seed
66
+ from vocos import Vocos
67
+
68
+ from zipvoice.models.zipvoice import ZipVoice
69
+ from zipvoice.models.zipvoice_distill import ZipVoiceDistill
70
+ from zipvoice.tokenizer.tokenizer import (
71
+ EmiliaTokenizer,
72
+ EspeakTokenizer,
73
+ LibriTTSTokenizer,
74
+ SimpleTokenizer,
75
+ )
76
+ from zipvoice.utils.checkpoint import load_checkpoint
77
+ from zipvoice.utils.common import AttributeDict
78
+ from zipvoice.utils.feature import VocosFbank
79
+
80
+ HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
81
+ MODEL_DIR = {
82
+ "zipvoice": "zipvoice",
83
+ "zipvoice_distill": "zipvoice_distill",
84
+ }
85
+
86
+
87
+ def get_parser():
88
+ parser = argparse.ArgumentParser(
89
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
90
+ )
91
+
92
+ parser.add_argument(
93
+ "--model-name",
94
+ type=str,
95
+ default="zipvoice",
96
+ choices=["zipvoice", "zipvoice_distill"],
97
+ help="The model used for inference",
98
+ )
99
+
100
+ parser.add_argument(
101
+ "--model-dir",
102
+ type=str,
103
+ default=None,
104
+ help="The model directory that contains model checkpoint, configuration "
105
+ "file model.json, and tokens file tokens.txt. Will download pre-trained "
106
+ "checkpoint from huggingface if not specified.",
107
+ )
108
+
109
+ parser.add_argument(
110
+ "--checkpoint-name",
111
+ type=str,
112
+ default="model.pt",
113
+ help="The name of model checkpoint.",
114
+ )
115
+
116
+ parser.add_argument(
117
+ "--vocoder-path",
118
+ type=str,
119
+ default=None,
120
+ help="The vocoder checkpoint. "
121
+ "Will download pre-trained vocoder from huggingface if not specified.",
122
+ )
123
+
124
+ parser.add_argument(
125
+ "--tokenizer",
126
+ type=str,
127
+ default="emilia",
128
+ choices=["emilia", "libritts", "espeak", "simple"],
129
+ help="Tokenizer type.",
130
+ )
131
+
132
+ parser.add_argument(
133
+ "--lang",
134
+ type=str,
135
+ default="en-us",
136
+ help="Language identifier, used when tokenizer type is espeak. see"
137
+ "https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
138
+ )
139
+
140
+ parser.add_argument(
141
+ "--test-list",
142
+ type=str,
143
+ default=None,
144
+ help="The list of prompt speech, prompt_transcription, "
145
+ "and text to synthesizein the format of "
146
+ "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.",
147
+ )
148
+
149
+ parser.add_argument(
150
+ "--prompt-wav",
151
+ type=str,
152
+ default=None,
153
+ help="The prompt wav to mimic",
154
+ )
155
+
156
+ parser.add_argument(
157
+ "--prompt-text",
158
+ type=str,
159
+ default=None,
160
+ help="The transcription of the prompt wav",
161
+ )
162
+
163
+ parser.add_argument(
164
+ "--text",
165
+ type=str,
166
+ default=None,
167
+ help="The text to synthesize",
168
+ )
169
+
170
+ parser.add_argument(
171
+ "--res-dir",
172
+ type=str,
173
+ default="results",
174
+ help="""
175
+ Path name of the generated wavs dir,
176
+ used when test-list is not None
177
+ """,
178
+ )
179
+
180
+ parser.add_argument(
181
+ "--res-wav-path",
182
+ type=str,
183
+ default="result.wav",
184
+ help="""
185
+ Path name of the generated wav path,
186
+ used when test-list is None
187
+ """,
188
+ )
189
+
190
+ parser.add_argument(
191
+ "--guidance-scale",
192
+ type=float,
193
+ default=None,
194
+ help="The scale of classifier-free guidance during inference.",
195
+ )
196
+
197
+ parser.add_argument(
198
+ "--num-step",
199
+ type=int,
200
+ default=None,
201
+ help="The number of sampling steps.",
202
+ )
203
+
204
+ parser.add_argument(
205
+ "--feat-scale",
206
+ type=float,
207
+ default=0.1,
208
+ help="The scale factor of fbank feature",
209
+ )
210
+
211
+ parser.add_argument(
212
+ "--speed",
213
+ type=float,
214
+ default=1.0,
215
+ help="Control speech speed, 1.0 means normal, >1.0 means speed up",
216
+ )
217
+
218
+ parser.add_argument(
219
+ "--t-shift",
220
+ type=float,
221
+ default=0.5,
222
+ help="Shift t to smaller ones if t_shift < 1.0",
223
+ )
224
+
225
+ parser.add_argument(
226
+ "--target-rms",
227
+ type=float,
228
+ default=0.1,
229
+ help="Target speech normalization rms value, set to 0 to disable normalization",
230
+ )
231
+
232
+ parser.add_argument(
233
+ "--seed",
234
+ type=int,
235
+ default=666,
236
+ help="Random seed",
237
+ )
238
+
239
+ return parser
240
+
241
+
242
+ def get_vocoder(vocos_local_path: Optional[str] = None):
243
+ if vocos_local_path:
244
+ vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
245
+ state_dict = torch.load(
246
+ f"{vocos_local_path}/pytorch_model.bin",
247
+ weights_only=True,
248
+ map_location="cpu",
249
+ )
250
+ vocoder.load_state_dict(state_dict)
251
+ else:
252
+ vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
253
+ return vocoder
254
+
255
+
256
+ def generate_sentence(
257
+ save_path: str,
258
+ prompt_text: str,
259
+ prompt_wav: str,
260
+ text: str,
261
+ model: torch.nn.Module,
262
+ vocoder: torch.nn.Module,
263
+ tokenizer: EmiliaTokenizer,
264
+ feature_extractor: VocosFbank,
265
+ device: torch.device,
266
+ num_step: int = 16,
267
+ guidance_scale: float = 1.0,
268
+ speed: float = 1.0,
269
+ t_shift: float = 0.5,
270
+ target_rms: float = 0.1,
271
+ feat_scale: float = 0.1,
272
+ sampling_rate: int = 24000,
273
+ ):
274
+ """
275
+ Generate waveform of a text based on a given prompt
276
+ waveform and its transcription.
277
+
278
+ Args:
279
+ save_path (str): Path to save the generated wav.
280
+ prompt_text (str): Transcription of the prompt wav.
281
+ prompt_wav (str): Path to the prompt wav file.
282
+ text (str): Text to be synthesized into a waveform.
283
+ model (torch.nn.Module): The model used for generation.
284
+ vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
285
+ tokenizer (EmiliaTokenizer): The tokenizer used to convert text to tokens.
286
+ feature_extractor (VocosFbank): The feature extractor used to
287
+ extract acoustic features.
288
+ device (torch.device): The device on which computations are performed.
289
+ num_step (int, optional): Number of steps for decoding. Defaults to 16.
290
+ guidance_scale (float, optional): Scale for classifier-free guidance.
291
+ Defaults to 1.0.
292
+ speed (float, optional): Speed control. Defaults to 1.0.
293
+ t_shift (float, optional): Time shift. Defaults to 0.5.
294
+ target_rms (float, optional): Target RMS for waveform normalization.
295
+ Defaults to 0.1.
296
+ feat_scale (float, optional): Scale for features.
297
+ Defaults to 0.1.
298
+ sampling_rate (int, optional): Sampling rate for the waveform.
299
+ Defaults to 24000.
300
+ Returns:
301
+ metrics (dict): Dictionary containing time and real-time
302
+ factor metrics for processing.
303
+ """
304
+ # Convert text to tokens
305
+ tokens = tokenizer.texts_to_token_ids([text])
306
+ prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
307
+
308
+ # Load and preprocess prompt wav
309
+ prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)
310
+
311
+ if prompt_sampling_rate != sampling_rate:
312
+ resampler = torchaudio.transforms.Resample(
313
+ orig_freq=prompt_sampling_rate, new_freq=sampling_rate
314
+ )
315
+ prompt_wav = resampler(prompt_wav)
316
+
317
+ prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
318
+ if prompt_rms < target_rms:
319
+ prompt_wav = prompt_wav * target_rms / prompt_rms
320
+
321
+ # Extract features from prompt wav
322
+ prompt_features = feature_extractor.extract(
323
+ prompt_wav, sampling_rate=sampling_rate
324
+ ).to(device)
325
+
326
+ prompt_features = prompt_features.unsqueeze(0) * feat_scale
327
+ prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
328
+
329
+ # Start timing
330
+ start_t = dt.datetime.now()
331
+
332
+ # Generate features
333
+ (
334
+ pred_features,
335
+ pred_features_lens,
336
+ pred_prompt_features,
337
+ pred_prompt_features_lens,
338
+ ) = model.sample(
339
+ tokens=tokens,
340
+ prompt_tokens=prompt_tokens,
341
+ prompt_features=prompt_features,
342
+ prompt_features_lens=prompt_features_lens,
343
+ speed=speed,
344
+ t_shift=t_shift,
345
+ duration="predict",
346
+ num_step=num_step,
347
+ guidance_scale=guidance_scale,
348
+ )
349
+
350
+ # Postprocess predicted features
351
+ pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
352
+
353
+ # Start vocoder processing
354
+ start_vocoder_t = dt.datetime.now()
355
+ wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
356
+
357
+ # Calculate processing times and real-time factors
358
+ t = (dt.datetime.now() - start_t).total_seconds()
359
+ t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
360
+ t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
361
+ wav_seconds = wav.shape[-1] / sampling_rate
362
+ rtf = t / wav_seconds
363
+ rtf_no_vocoder = t_no_vocoder / wav_seconds
364
+ rtf_vocoder = t_vocoder / wav_seconds
365
+ metrics = {
366
+ "t": t,
367
+ "t_no_vocoder": t_no_vocoder,
368
+ "t_vocoder": t_vocoder,
369
+ "wav_seconds": wav_seconds,
370
+ "rtf": rtf,
371
+ "rtf_no_vocoder": rtf_no_vocoder,
372
+ "rtf_vocoder": rtf_vocoder,
373
+ }
374
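+ # For reference: rtf < 1.0 means synthesis runs faster than real time,
+ # e.g. 2 s of compute for 10 s of audio gives rtf = 2 / 10 = 0.2.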
+
375
+ # Adjust wav volume if necessary
376
+ if prompt_rms < target_rms:
377
+ wav = wav * prompt_rms / target_rms
378
+ torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
379
+
380
+ return metrics
381
+
382
+
383
+ def generate_list(
384
+ res_dir: str,
385
+ test_list: str,
386
+ model: torch.nn.Module,
387
+ vocoder: torch.nn.Module,
388
+ tokenizer: EmiliaTokenizer,
389
+ feature_extractor: VocosFbank,
390
+ device: torch.device,
391
+ num_step: int = 16,
392
+ guidance_scale: float = 1.0,
393
+ speed: float = 1.0,
394
+ t_shift: float = 0.5,
395
+ target_rms: float = 0.1,
396
+ feat_scale: float = 0.1,
397
+ sampling_rate: int = 24000,
398
+ ):
399
+ total_t = []
400
+ total_t_no_vocoder = []
401
+ total_t_vocoder = []
402
+ total_wav_seconds = []
403
+
404
+ with open(test_list, "r") as fr:
405
+ lines = fr.readlines()
406
+
407
+ for i, line in enumerate(lines):
408
+ wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
409
+ save_path = f"{res_dir}/{wav_name}.wav"
410
+ metrics = generate_sentence(
411
+ save_path=save_path,
412
+ prompt_text=prompt_text,
413
+ prompt_wav=prompt_wav,
414
+ text=text,
415
+ model=model,
416
+ vocoder=vocoder,
417
+ tokenizer=tokenizer,
418
+ feature_extractor=feature_extractor,
419
+ device=device,
420
+ num_step=num_step,
421
+ guidance_scale=guidance_scale,
422
+ speed=speed,
423
+ t_shift=t_shift,
424
+ target_rms=target_rms,
425
+ feat_scale=feat_scale,
426
+ sampling_rate=sampling_rate,
427
+ )
428
+ logging.info(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
429
+ total_t.append(metrics["t"])
430
+ total_t_no_vocoder.append(metrics["t_no_vocoder"])
431
+ total_t_vocoder.append(metrics["t_vocoder"])
432
+ total_wav_seconds.append(metrics["wav_seconds"])
433
+
434
+ logging.info(f"Average RTF: {np.sum(total_t) / np.sum(total_wav_seconds):.4f}")
435
+ logging.info(
436
+ f"Average RTF w/o vocoder: "
437
+ f"{np.sum(total_t_no_vocoder) / np.sum(total_wav_seconds):.4f}"
438
+ )
439
+ logging.info(
440
+ f"Average RTF vocoder: "
441
+ f"{np.sum(total_t_vocoder) / np.sum(total_wav_seconds):.4f}"
442
+ )
443
+
444
+
445
+ @torch.inference_mode()
446
+ def main():
447
+ parser = get_parser()
448
+ args = parser.parse_args()
449
+
450
+ params = AttributeDict()
451
+ params.update(vars(args))
452
+ fix_random_seed(params.seed)
453
+
454
+ model_defaults = {
455
+ "zipvoice": {
456
+ "num_step": 16,
457
+ "guidance_scale": 1.0,
458
+ },
459
+ "zipvoice_distill": {
460
+ "num_step": 8,
461
+ "guidance_scale": 3.0,
462
+ },
463
+ }
464
+
465
+ model_specific_defaults = model_defaults.get(params.model_name, {})
466
+
467
+ for param, value in model_specific_defaults.items():
468
+ if getattr(params, param) is None:
469
+ setattr(params, param, value)
470
+ logging.info(f"Setting {param} to default value: {value}")
471
+
472
+ assert (params.test_list is not None) ^ (
473
+ bool(params.prompt_wav and params.prompt_text and params.text)
474
+ ), (
475
+ "For inference, please provide prompts and text with either '--test-list'"
476
+ " or '--prompt-wav, --prompt-text and --text'."
477
+ )
478
+
479
+ if params.model_dir is not None:
480
+ params.model_dir = Path(params.model_dir)
481
+ if not params.model_dir.is_dir():
482
+ raise FileNotFoundError(f"{params.model_dir} does not exist")
483
+ for filename in [params.checkpoint_name, "model.json", "tokens.txt"]:
484
+ if not (params.model_dir / filename).is_file():
485
+ raise FileNotFoundError(f"{params.model_dir / filename} does not exist")
486
+ model_ckpt = params.model_dir / params.checkpoint_name
487
+ model_config = params.model_dir / "model.json"
488
+ token_file = params.model_dir / "tokens.txt"
489
+ logging.info(
490
+ f"Using local model dir {params.model_dir}, "
491
+ f"checkpoint {params.checkpoint_name}"
492
+ )
493
+ else:
494
+ logging.info("Using pretrained model from the huggingface")
495
+ logging.info("Downloading the requires files from HuggingFace")
496
+ model_ckpt = hf_hub_download(
497
+ HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/model.pt"
498
+ )
499
+ model_config = hf_hub_download(
500
+ HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/model.json"
501
+ )
502
+
503
+ token_file = hf_hub_download(
504
+ HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/tokens.txt"
505
+ )
506
+
507
+ logging.info("Loading model...")
508
+
509
+ if params.tokenizer == "emilia":
510
+ tokenizer = EmiliaTokenizer(token_file=token_file)
511
+ elif params.tokenizer == "libritts":
512
+ tokenizer = LibriTTSTokenizer(token_file=token_file)
513
+ elif params.tokenizer == "espeak":
514
+ tokenizer = EspeakTokenizer(token_file=token_file, lang=params.lang)
515
+ else:
516
+ assert params.tokenizer == "simple"
517
+ tokenizer = SimpleTokenizer(token_file=token_file)
518
+
519
+ tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
520
+
521
+ with open(model_config, "r") as f:
522
+ model_config = json.load(f)
523
+
524
+ if params.model_name == "zipvoice":
525
+ model = ZipVoice(
526
+ **model_config["model"],
527
+ **tokenizer_config,
528
+ )
529
+ else:
530
+ assert params.model_name == "zipvoice_distill"
531
+ model = ZipVoiceDistill(
532
+ **model_config["model"],
533
+ **tokenizer_config,
534
+ )
535
+
536
+ if str(model_ckpt).endswith(".safetensors"):
537
+ safetensors.torch.load_model(model, model_ckpt)
538
+ elif str(model_ckpt).endswith(".pt"):
539
+ load_checkpoint(filename=model_ckpt, model=model, strict=True)
540
+ else:
541
+ raise NotImplementedError(f"Unsupported model checkpoint format: {model_ckpt}")
542
+
543
+ if torch.cuda.is_available():
544
+ params.device = torch.device("cuda", 0)
545
+ elif torch.backends.mps.is_available():
546
+ params.device = torch.device("mps")
547
+ else:
548
+ params.device = torch.device("cpu")
549
+ logging.info(f"Device: {params.device}")
550
+
551
+ model = model.to(params.device)
552
+ model.eval()
553
+
554
+ vocoder = get_vocoder(params.vocoder_path)
555
+ vocoder = vocoder.to(params.device)
556
+ vocoder.eval()
557
+
558
+ if model_config["feature"]["type"] == "vocos":
559
+ feature_extractor = VocosFbank()
560
+ else:
561
+ raise NotImplementedError(
562
+ f"Unsupported feature type: {model_config['feature']['type']}"
563
+ )
564
+ params.sampling_rate = model_config["feature"]["sampling_rate"]
565
+
566
+ logging.info("Start generating...")
567
+ if params.test_list:
568
+ os.makedirs(params.res_dir, exist_ok=True)
569
+ generate_list(
570
+ res_dir=params.res_dir,
571
+ test_list=params.test_list,
572
+ model=model,
573
+ vocoder=vocoder,
574
+ tokenizer=tokenizer,
575
+ feature_extractor=feature_extractor,
576
+ device=params.device,
577
+ num_step=params.num_step,
578
+ guidance_scale=params.guidance_scale,
579
+ speed=params.speed,
580
+ t_shift=params.t_shift,
581
+ target_rms=params.target_rms,
582
+ feat_scale=params.feat_scale,
583
+ sampling_rate=params.sampling_rate,
584
+ )
585
+ else:
586
+ generate_sentence(
587
+ save_path=params.res_wav_path,
588
+ prompt_text=params.prompt_text,
589
+ prompt_wav=params.prompt_wav,
590
+ text=params.text,
591
+ model=model,
592
+ vocoder=vocoder,
593
+ tokenizer=tokenizer,
594
+ feature_extractor=feature_extractor,
595
+ device=params.device,
596
+ num_step=params.num_step,
597
+ guidance_scale=params.guidance_scale,
598
+ speed=params.speed,
599
+ t_shift=params.t_shift,
600
+ target_rms=params.target_rms,
601
+ feat_scale=params.feat_scale,
602
+ sampling_rate=params.sampling_rate,
603
+ )
604
+ logging.info("Done")
605
+
606
+
607
+ if __name__ == "__main__":
608
+ torch.set_num_threads(1)
609
+ torch.set_num_interop_threads(1)
610
+
611
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
612
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
613
+
614
+ main()
zipvoice/bin/infer_zipvoice_dialog.py ADDED
@@ -0,0 +1,756 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ This script generates speech with our pre-trained ZipVoice-Dialog or
20
+ ZipVoice-Dialog-Stereo models. If no local model is specified,
21
+ required files will be automatically downloaded from HuggingFace.
22
+
23
+ Usage:
24
+
25
+ Note: If you are having trouble connecting to HuggingFace,
26
+ try switching endpoint to mirror site:
27
+ export HF_ENDPOINT=https://hf-mirror.com
28
+
29
+ python3 -m zipvoice.bin.infer_zipvoice_dialog \
30
+ --model-name zipvoice_dialog \
31
+ --test-list test.tsv \
32
+ --res-dir results
33
+
34
+ `--model-name` can be `zipvoice_dialog` or `zipvoice_dialog_stereo`,
35
+ which generate mono and stereo dialogues, respectively.
36
+
37
+ Each line of `test.tsv` is in the format of a merged conversation:
38
+ '{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'
39
+ or a split conversation:
40
+ '{wav_name}\t{spk1_prompt_transcription}\t{spk2_prompt_transcription}
41
+ \t{spk1_prompt_wav}\t{spk2_prompt_wav}\t{text}'
42
+ """
43
+
44
+ import argparse
45
+ import datetime as dt
46
+ import json
47
+ import logging
48
+ import os
49
+ from pathlib import Path
50
+ from typing import List, Optional, Union
51
+
52
+ import numpy as np
53
+ import safetensors.torch
54
+ import torch
55
+ import torchaudio
56
+ from huggingface_hub import hf_hub_download
57
+ from lhotse.utils import fix_random_seed
58
+ from vocos import Vocos
59
+
60
+ from zipvoice.models.zipvoice_dialog import ZipVoiceDialog, ZipVoiceDialogStereo
61
+ from zipvoice.tokenizer.tokenizer import DialogTokenizer
62
+ from zipvoice.utils.checkpoint import load_checkpoint
63
+ from zipvoice.utils.common import AttributeDict
64
+ from zipvoice.utils.feature import VocosFbank
65
+
66
+ HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
67
+ MODEL_DIR = {
68
+ "zipvoice_dialog": "zipvoice_dialog",
69
+ "zipvoice_dialog_stereo": "zipvoice_dialog_stereo",
70
+ }
71
+
72
+
73
+ def get_parser():
74
+ parser = argparse.ArgumentParser(
75
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
76
+ )
77
+
78
+ parser.add_argument(
79
+ "--model-name",
80
+ type=str,
81
+ default="zipvoice_dialog",
82
+ choices=["zipvoice_dialog", "zipvoice_dialog_stereo"],
83
+ help="The model used for inference",
84
+ )
85
+
86
+ parser.add_argument(
87
+ "--model-dir",
88
+ type=str,
89
+ default=None,
90
+ help="The model directory that contains model checkpoint, configuration "
91
+ "file model.json, and tokens file tokens.txt. Will download pre-trained "
92
+ "checkpoint from huggingface if not specified.",
93
+ )
94
+
95
+ parser.add_argument(
96
+ "--checkpoint-name",
97
+ type=str,
98
+ default="model.pt",
99
+ help="The name of model checkpoint.",
100
+ )
101
+
102
+ parser.add_argument(
103
+ "--vocoder-path",
104
+ type=str,
105
+ default=None,
106
+ help="The vocoder checkpoint. "
107
+ "Will download pre-trained vocoder from huggingface if not specified.",
108
+ )
109
+
110
+ parser.add_argument(
111
+ "--test-list",
112
+ type=str,
113
+ default=None,
114
+ help="The list of prompt speech, prompt_transcription, "
115
+ "and text to synthesizein the format of merged conversation: "
116
+ "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}' "
117
+ "or splited conversation: "
118
+ "'{wav_name}\t{spk1_prompt_transcription}\t{spk2_prompt_transcription}"
119
+ "\t{spk1_prompt_wav}\t{spk2_prompt_wav}\t{text}'.",
120
+ )
121
+
122
+ parser.add_argument(
123
+ "--res-dir",
124
+ type=str,
125
+ default="results",
126
+ help="""
127
+ Path name of the generated wavs dir,
128
+ used when test-list is not None
129
+ """,
130
+ )
131
+
132
+ parser.add_argument(
133
+ "--guidance-scale",
134
+ type=float,
135
+ default=1.5,
136
+ help="The scale of classifier-free guidance during inference.",
137
+ )
138
+
139
+ parser.add_argument(
140
+ "--num-step",
141
+ type=int,
142
+ default=16,
143
+ help="The number of sampling steps.",
144
+ )
145
+
146
+ parser.add_argument(
147
+ "--feat-scale",
148
+ type=float,
149
+ default=0.1,
150
+ help="The scale factor of fbank feature",
151
+ )
152
+
153
+ parser.add_argument(
154
+ "--speed",
155
+ type=float,
156
+ default=1.0,
157
+ help="Control speech speed, 1.0 means normal, >1.0 means speed up",
158
+ )
159
+
160
+ parser.add_argument(
161
+ "--t-shift",
162
+ type=float,
163
+ default=0.5,
164
+ help="Shift t to smaller ones if t_shift < 1.0",
165
+ )
166
+
167
+ parser.add_argument(
168
+ "--target-rms",
169
+ type=float,
170
+ default=0.1,
171
+ help="Target speech normalization rms value, set to 0 to disable normalization",
172
+ )
173
+
174
+ parser.add_argument(
175
+ "--seed",
176
+ type=int,
177
+ default=666,
178
+ help="Random seed",
179
+ )
180
+
181
+ parser.add_argument(
182
+ "--silence-wav",
183
+ type=str,
184
+ default="assets/silence.wav",
185
+ help="Path of the silence wav file, used in two-channel generation "
186
+ "with single-channel prompts",
187
+ )
188
+
189
+ return parser
190
+
191
+
192
+ def get_vocoder(vocos_local_path: Optional[str] = None):
193
+ if vocos_local_path:
194
+ vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
195
+ state_dict = torch.load(
196
+ f"{vocos_local_path}/pytorch_model.bin",
197
+ weights_only=True,
198
+ map_location="cpu",
199
+ )
200
+ vocoder.load_state_dict(state_dict)
201
+ else:
202
+ vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
203
+ return vocoder
204
+
205
+
206
+ def generate_sentence(
207
+ save_path: str,
208
+ prompt_text: str,
209
+ prompt_wav: Union[str, List[str]],
210
+ text: str,
211
+ model: torch.nn.Module,
212
+ vocoder: torch.nn.Module,
213
+ tokenizer: DialogTokenizer,
214
+ feature_extractor: VocosFbank,
215
+ device: torch.device,
216
+ num_step: int = 16,
217
+ guidance_scale: float = 1.0,
218
+ speed: float = 1.0,
219
+ t_shift: float = 0.5,
220
+ target_rms: float = 0.1,
221
+ feat_scale: float = 0.1,
222
+ sampling_rate: int = 24000,
223
+ ):
224
+ """
225
+ Generate waveform of a text based on a given prompt
226
+ waveform and its transcription.
227
+
228
+ Args:
229
+ save_path (str): Path to save the generated wav.
230
+ prompt_text (str): Transcription of the prompt wav.
231
+ prompt_wav (Union[str, List[str]]): Path to the prompt wav; can be
232
+ one or two wav files, corresponding to a merged conversational
233
+ recording or two separate speakers' recordings.
234
+ text (str): Text to be synthesized into a waveform.
235
+ model (torch.nn.Module): The model used for generation.
236
+ vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
237
+ tokenizer (DialogTokenizer): The tokenizer used to convert text to tokens.
238
+ feature_extractor (VocosFbank): The feature extractor used to
239
+ extract acoustic features.
240
+ device (torch.device): The device on which computations are performed.
241
+ num_step (int, optional): Number of steps for decoding. Defaults to 16.
242
+ guidance_scale (float, optional): Scale for classifier-free guidance.
243
+ Defaults to 1.0.
244
+ speed (float, optional): Speed control. Defaults to 1.0.
245
+ t_shift (float, optional): Time shift. Defaults to 0.5.
246
+ target_rms (float, optional): Target RMS for waveform normalization.
247
+ Defaults to 0.1.
248
+ feat_scale (float, optional): Scale for features.
249
+ Defaults to 0.1.
250
+ sampling_rate (int, optional): Sampling rate for the waveform.
251
+ Defaults to 24000.
252
+ Returns:
253
+ metrics (dict): Dictionary containing time and real-time
254
+ factor metrics for processing.
255
+ """
256
+ # Convert text to tokens
257
+ tokens = tokenizer.texts_to_token_ids([text])
258
+ prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
259
+
260
+ # Load and preprocess prompt wav
261
+ if isinstance(prompt_wav, str):
262
+ prompt_wav = [
263
+ prompt_wav,
264
+ ]
265
+ else:
266
+ assert len(prompt_wav) == 2 and isinstance(prompt_wav[0], str)
267
+
268
+ loaded_prompt_wavs = list(prompt_wav)  # copy so the input list is not aliased
269
+ for i in range(len(prompt_wav)):
270
+ loaded_prompt_wavs[i], prompt_sampling_rate = torchaudio.load(prompt_wav[i])
271
+ if prompt_sampling_rate != sampling_rate:
272
+ resampler = torchaudio.transforms.Resample(
273
+ orig_freq=prompt_sampling_rate, new_freq=sampling_rate
274
+ )
275
+ loaded_prompt_wavs[i] = resampler(loaded_prompt_wavs[i])
276
+ if loaded_prompt_wavs[i].size(0) != 1:
277
+ loaded_prompt_wavs[i] = loaded_prompt_wavs[i].mean(0, keepdim=True)
278
+
279
+ if len(loaded_prompt_wavs) == 1:
280
+ prompt_wav = loaded_prompt_wavs[0]
281
+ else:
282
+ prompt_wav = torch.cat(loaded_prompt_wavs, dim=1)
283
+
284
+ prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
285
+ if prompt_rms < target_rms:
286
+ prompt_wav = prompt_wav * target_rms / prompt_rms
287
+
288
+ # Extract features from prompt wav
289
+ prompt_features = feature_extractor.extract(
290
+ prompt_wav, sampling_rate=sampling_rate
291
+ ).to(device)
292
+
293
+ prompt_features = prompt_features.unsqueeze(0) * feat_scale
294
+ prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
295
+
296
+ # Start timing
297
+ start_t = dt.datetime.now()
298
+
299
+ # Generate features
300
+ (
301
+ pred_features,
302
+ pred_features_lens,
303
+ pred_prompt_features,
304
+ pred_prompt_features_lens,
305
+ ) = model.sample(
306
+ tokens=tokens,
307
+ prompt_tokens=prompt_tokens,
308
+ prompt_features=prompt_features,
309
+ prompt_features_lens=prompt_features_lens,
310
+ speed=speed,
311
+ t_shift=t_shift,
312
+ duration="predict",
313
+ num_step=num_step,
314
+ guidance_scale=guidance_scale,
315
+ )
316
+
317
+ # Postprocess predicted features
318
+ pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
319
+
320
+ # Start vocoder processing
321
+ start_vocoder_t = dt.datetime.now()
322
+ wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
323
+
324
+ # Calculate processing times and real-time factors
325
+ t = (dt.datetime.now() - start_t).total_seconds()
326
+ t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
327
+ t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
328
+ wav_seconds = wav.shape[-1] / sampling_rate
329
+ rtf = t / wav_seconds
330
+ rtf_no_vocoder = t_no_vocoder / wav_seconds
331
+ rtf_vocoder = t_vocoder / wav_seconds
332
+ metrics = {
333
+ "t": t,
334
+ "t_no_vocoder": t_no_vocoder,
335
+ "t_vocoder": t_vocoder,
336
+ "wav_seconds": wav_seconds,
337
+ "rtf": rtf,
338
+ "rtf_no_vocoder": rtf_no_vocoder,
339
+ "rtf_vocoder": rtf_vocoder,
340
+ }
341
+
342
+ # Adjust wav volume if necessary
343
+ if prompt_rms < target_rms:
344
+ wav = wav * prompt_rms / target_rms
345
+ torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
346
+
347
+ return metrics
348
+
349
+
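+ # A minimal usage sketch for generate_sentence (assuming a loaded `model`,
+ # `vocoder`, `tokenizer`, and `feature_extractor`; paths and texts below are
+ # hypothetical):
+ #
+ #     metrics = generate_sentence(
+ #         save_path="out.wav",
+ #         prompt_text="[S1]Hi there.[S2]Hello!",
+ #         prompt_wav=["spk1.wav", "spk2.wav"],
+ #         text="[S1]How are you?[S2]Great, thanks.",
+ #         model=model,
+ #         vocoder=vocoder,
+ #         tokenizer=tokenizer,
+ #         feature_extractor=feature_extractor,
+ #         device=torch.device("cpu"),
+ #     )
+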
350
+ def generate_sentence_stereo(
351
+ save_path: str,
352
+ prompt_text: str,
353
+ prompt_wav: Union[str, List[str]],
354
+ text: str,
355
+ model: torch.nn.Module,
356
+ vocoder: torch.nn.Module,
357
+ tokenizer: DialogTokenizer,
358
+ feature_extractor: VocosFbank,
359
+ device: torch.device,
360
+ num_step: int = 16,
361
+ guidance_scale: float = 1.0,
362
+ speed: float = 1.0,
363
+ t_shift: float = 0.5,
364
+ target_rms: float = 0.1,
365
+ feat_scale: float = 0.1,
366
+ sampling_rate: int = 24000,
367
+ silence_wav: Optional[str] = None,
368
+ ):
369
+ """
370
+ Generate waveform of a text based on a given prompt
371
+ waveform and its transcription.
372
+
373
+ Args:
374
+ save_path (str): Path to save the generated wav.
375
+ prompt_text (str): Transcription of the prompt wav.
376
+ prompt_wav (Union[str, List[str]]): Path to the prompt wav file; can be
377
+ one or two wav files, corresponding to a merged conversational
378
+ recording or two separate speakers' recordings.
379
+ text (str): Text to be synthesized into a waveform.
380
+ model (torch.nn.Module): The model used for generation.
381
+ vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
382
+ tokenizer (DialogTokenizer): The tokenizer used to convert text to tokens.
383
+ feature_extractor (VocosFbank): The feature extractor used to
384
+ extract acoustic features.
385
+ device (torch.device): The device on which computations are performed.
386
+ num_step (int, optional): Number of steps for decoding. Defaults to 16.
387
+ guidance_scale (float, optional): Scale for classifier-free guidance.
388
+ Defaults to 1.0.
389
+ speed (float, optional): Speed control. Defaults to 1.0.
390
+ t_shift (float, optional): Time shift. Defaults to 0.5.
391
+ target_rms (float, optional): Target RMS for waveform normalization.
392
+ Defaults to 0.1.
393
+ feat_scale (float, optional): Scale for features.
394
+ Defaults to 0.1.
395
+ sampling_rate (int, optional): Sampling rate for the waveform.
396
+ Defaults to 24000.
397
+ silence_wav (str, optional): Path of the silence wav file, used in two-channel
398
+ generation with single-channel prompts.
399
+ Returns:
400
+ metrics (dict): Dictionary containing time and real-time
401
+ factor metrics for processing.
402
+ """
403
+ # Convert text to tokens
404
+ tokens = tokenizer.texts_to_token_ids([text])
405
+ prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
406
+
407
+ # Load and preprocess prompt wav
408
+ if isinstance(prompt_wav, str):
409
+ prompt_wav = [
410
+ prompt_wav,
411
+ ]
412
+ else:
413
+ assert len(prompt_wav) == 2 and isinstance(prompt_wav[0], str)
414
+
415
+ loaded_prompt_wavs = list(prompt_wav)  # copy so the caller's list is not mutated
416
+ for i in range(len(prompt_wav)):
417
+ loaded_prompt_wavs[i], prompt_sampling_rate = torchaudio.load(prompt_wav[i])
418
+ if prompt_sampling_rate != sampling_rate:
419
+ resampler = torchaudio.transforms.Resample(
420
+ orig_freq=prompt_sampling_rate, new_freq=sampling_rate
421
+ )
422
+ loaded_prompt_wavs[i] = resampler(loaded_prompt_wavs[i])
423
+
424
+ if len(loaded_prompt_wavs) == 1:
425
+ assert (
426
+ loaded_prompt_wavs[0].size(0) == 2
427
+ ), "Merged prompt wav must be stereo for stereo dialogue generation"
428
+ prompt_wav = loaded_prompt_wavs[0]
429
+
430
+ else:
431
+ assert len(loaded_prompt_wavs) == 2
432
+ if loaded_prompt_wavs[0].size(0) == 2:
433
+ prompt_wav = torch.cat(loaded_prompt_wavs, dim=1)
434
+ else:
435
+ assert loaded_prompt_wavs[0].size(0) == 1
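+ # Stitch two mono prompts into one stereo prompt on top of a silence bed:
+ # speaker 1 fills the left channel first, speaker 2 fills the right channel
+ # afterwards, so the two speakers never overlap in time.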
436
+ silence_wav, silence_sampling_rate = torchaudio.load(silence_wav)
437
+ assert silence_sampling_rate == sampling_rate
438
+ prompt_wav = silence_wav[
439
+ :, : loaded_prompt_wavs[0].size(1) + loaded_prompt_wavs[1].size(1)
440
+ ]
441
+ prompt_wav[0, : loaded_prompt_wavs[0].size(1)] = loaded_prompt_wavs[0]
442
+ prompt_wav[1, loaded_prompt_wavs[0].size(1) :] = loaded_prompt_wavs[1]
443
+
444
+ prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
445
+ if prompt_rms < target_rms:
446
+ prompt_wav = prompt_wav * target_rms / prompt_rms
447
+
448
+ # Extract features from prompt wav
449
+ prompt_features = feature_extractor.extract(
450
+ prompt_wav, sampling_rate=sampling_rate
451
+ ).to(device)
452
+
453
+ prompt_features = prompt_features.unsqueeze(0) * feat_scale
454
+ prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
455
+
456
+ # Start timing
457
+ start_t = dt.datetime.now()
458
+
459
+ # Generate features
460
+ (
461
+ pred_features,
462
+ pred_features_lens,
463
+ pred_prompt_features,
464
+ pred_prompt_features_lens,
465
+ ) = model.sample(
466
+ tokens=tokens,
467
+ prompt_tokens=prompt_tokens,
468
+ prompt_features=prompt_features,
469
+ prompt_features_lens=prompt_features_lens,
470
+ speed=speed,
471
+ t_shift=t_shift,
472
+ duration="predict",
473
+ num_step=num_step,
474
+ guidance_scale=guidance_scale,
475
+ )
476
+
477
+ # Postprocess predicted features
478
+ pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
479
+
480
+ # Start vocoder processing
481
+ start_vocoder_t = dt.datetime.now()
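+ # Stereo features are stacked along the feature axis; decode each half with
+ # the mono vocoder and stack the outputs as the left/right channels.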
482
+ feat_dim = pred_features.size(1) // 2
483
+ wav_left = vocoder.decode(pred_features[:, :feat_dim]).squeeze(1).clamp(-1, 1)
484
+ wav_right = (
485
+ vocoder.decode(pred_features[:, feat_dim : feat_dim * 2])
486
+ .squeeze(1)
487
+ .clamp(-1, 1)
488
+ )
489
+
490
+ wav = torch.cat([wav_left, wav_right], dim=0)
491
+
492
+ # Calculate processing times and real-time factors
493
+ t = (dt.datetime.now() - start_t).total_seconds()
494
+ t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
495
+ t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
496
+ wav_seconds = wav.shape[-1] / sampling_rate
497
+ rtf = t / wav_seconds
498
+ rtf_no_vocoder = t_no_vocoder / wav_seconds
499
+ rtf_vocoder = t_vocoder / wav_seconds
500
+ metrics = {
501
+ "t": t,
502
+ "t_no_vocoder": t_no_vocoder,
503
+ "t_vocoder": t_vocoder,
504
+ "wav_seconds": wav_seconds,
505
+ "rtf": rtf,
506
+ "rtf_no_vocoder": rtf_no_vocoder,
507
+ "rtf_vocoder": rtf_vocoder,
508
+ }
509
+
510
+ # Adjust wav volume if necessary
511
+ if prompt_rms < target_rms:
512
+ wav = wav * prompt_rms / target_rms
513
+ torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
514
+
515
+ return metrics
516
+
517
+
518
+ def generate_list(
519
+ model_name: str,
520
+ res_dir: str,
521
+ test_list: str,
522
+ model: torch.nn.Module,
523
+ vocoder: torch.nn.Module,
524
+ tokenizer: DialogTokenizer,
525
+ feature_extractor: VocosFbank,
526
+ device: torch.device,
527
+ num_step: int = 16,
528
+ guidance_scale: float = 1.5,
529
+ speed: float = 1.0,
530
+ t_shift: float = 0.5,
531
+ target_rms: float = 0.1,
532
+ feat_scale: float = 0.1,
533
+ sampling_rate: int = 24000,
534
+ silence_wav: Optional[str] = None,
535
+ ):
536
+ total_t = []
537
+ total_t_no_vocoder = []
538
+ total_t_vocoder = []
539
+ total_wav_seconds = []
540
+
541
+ with open(test_list, "r") as fr:
542
+ lines = fr.readlines()
543
+
544
+ for i, line in enumerate(lines):
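+ # Each line carries either 6 fields (two separate mono prompts):
+ #   {wav_name}\t{prompt_text_1}\t{prompt_text_2}\t{prompt_wav_1}\t{prompt_wav_2}\t{text}
+ # or 4 fields (one merged conversational prompt):
+ #   {wav_name}\t{prompt_text}\t{prompt_wav}\t{text}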
545
+ items = line.strip().split("\t")
546
+ if len(items) == 6:
547
+ (
548
+ wav_name,
549
+ prompt_text_1,
550
+ prompt_text_2,
551
+ prompt_wav_1,
552
+ prompt_wav_2,
553
+ text,
554
+ ) = items
555
+ prompt_text = f"[S1]{prompt_text_1}[S2]{prompt_text_2}"
556
+ prompt_wav = [prompt_wav_1, prompt_wav_2]
557
+ elif len(items) == 4:
558
+ wav_name, prompt_text, prompt_wav, text = items
559
+ else:
560
+ raise ValueError(f"Invalid line: {line}")
561
+ assert text.startswith("[S1]")
562
+
563
+ save_path = f"{res_dir}/{wav_name}.wav"
564
+
565
+ if model_name == "zipvoice_dialog":
566
+
567
+ metrics = generate_sentence(
568
+ save_path=save_path,
569
+ prompt_text=prompt_text,
570
+ prompt_wav=prompt_wav,
571
+ text=text,
572
+ model=model,
573
+ vocoder=vocoder,
574
+ tokenizer=tokenizer,
575
+ feature_extractor=feature_extractor,
576
+ device=device,
577
+ num_step=num_step,
578
+ guidance_scale=guidance_scale,
579
+ speed=speed,
580
+ t_shift=t_shift,
581
+ target_rms=target_rms,
582
+ feat_scale=feat_scale,
583
+ sampling_rate=sampling_rate,
584
+ )
585
+ else:
586
+ assert model_name == "zipvoice_dialog_stereo"
587
+ metrics = generate_sentence_stereo(
588
+ save_path=save_path,
589
+ prompt_text=prompt_text,
590
+ prompt_wav=prompt_wav,
591
+ text=text,
592
+ model=model,
593
+ vocoder=vocoder,
594
+ tokenizer=tokenizer,
595
+ feature_extractor=feature_extractor,
596
+ device=device,
597
+ num_step=num_step,
598
+ guidance_scale=guidance_scale,
599
+ speed=speed,
600
+ t_shift=t_shift,
601
+ target_rms=target_rms,
602
+ feat_scale=feat_scale,
603
+ sampling_rate=sampling_rate,
604
+ silence_wav=silence_wav,
605
+ )
606
+
607
+ logging.info(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
608
+ total_t.append(metrics["t"])
609
+ total_t_no_vocoder.append(metrics["t_no_vocoder"])
610
+ total_t_vocoder.append(metrics["t_vocoder"])
611
+ total_wav_seconds.append(metrics["wav_seconds"])
612
+
613
+ logging.info(f"Average RTF: {np.sum(total_t) / np.sum(total_wav_seconds):.4f}")
614
+ logging.info(
615
+ f"Average RTF w/o vocoder: "
616
+ f"{np.sum(total_t_no_vocoder) / np.sum(total_wav_seconds):.4f}"
617
+ )
618
+ logging.info(
619
+ f"Average RTF vocoder: "
620
+ f"{np.sum(total_t_vocoder) / np.sum(total_wav_seconds):.4f}"
621
+ )
622
+
623
+
624
+ @torch.inference_mode()
625
+ def main():
626
+ parser = get_parser()
627
+ args = parser.parse_args()
628
+
629
+ params = AttributeDict()
630
+ params.update(vars(args))
631
+ fix_random_seed(params.seed)
632
+
633
+ assert (
634
+ params.test_list is not None
635
+ ), "For inference, please provide prompts and text with '--test-list'"
636
+
637
+ if params.model_dir is not None:
638
+ params.model_dir = Path(params.model_dir)
639
+ if not params.model_dir.is_dir():
640
+ raise FileNotFoundError(f"{params.model_dir} does not exist")
641
+ for filename in [params.checkpoint_name, "model.json", "tokens.txt"]:
642
+ if not (params.model_dir / filename).is_file():
643
+ raise FileNotFoundError(f"{params.model_dir / filename} does not exist")
644
+ model_ckpt = params.model_dir / params.checkpoint_name
645
+ model_config = params.model_dir / "model.json"
646
+ token_file = params.model_dir / "tokens.txt"
647
+ logging.info(
648
+ f"Using local model dir {params.model_dir}, "
649
+ f"checkpoint {params.checkpoint_name}"
650
+ )
651
+ else:
652
+ logging.info("Using pretrained model from the huggingface")
653
+ logging.info("Downloading the requires files from HuggingFace")
654
+ model_ckpt = hf_hub_download(
655
+ HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/model.pt"
656
+ )
657
+ model_config = hf_hub_download(
658
+ HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/model.json"
659
+ )
660
+
661
+ token_file = hf_hub_download(
662
+ HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/tokens.txt"
663
+ )
664
+
665
+ logging.info("Loading model...")
666
+
667
+ tokenizer = DialogTokenizer(token_file=token_file)
668
+
669
+ tokenizer_config = {
670
+ "vocab_size": tokenizer.vocab_size,
671
+ "pad_id": tokenizer.pad_id,
672
+ "spk_a_id": tokenizer.spk_a_id,
673
+ "spk_b_id": tokenizer.spk_b_id,
674
+ }
675
+
676
+ with open(model_config, "r") as f:
677
+ model_config = json.load(f)
678
+
679
+ if params.model_name == "zipvoice_dialog":
680
+ model = ZipVoiceDialog(
681
+ **model_config["model"],
682
+ **tokenizer_config,
683
+ )
684
+ else:
685
+ assert params.model_name == "zipvoice_dialog_stereo"
686
+ model = ZipVoiceDialogStereo(
687
+ **model_config["model"],
688
+ **tokenizer_config,
689
+ )
690
+
691
+ if str(model_ckpt).endswith(".safetensors"):
692
+ safetensors.torch.load_model(model, model_ckpt)
693
+ elif str(model_ckpt).endswith(".pt"):
694
+ load_checkpoint(filename=model_ckpt, model=model, strict=True)
695
+ else:
696
+ raise NotImplementedError(f"Unsupported model checkpoint format: {model_ckpt}")
697
+
698
+ if torch.cuda.is_available():
699
+ params.device = torch.device("cuda", 0)
700
+ elif torch.backends.mps.is_available():
701
+ params.device = torch.device("mps")
702
+ else:
703
+ params.device = torch.device("cpu")
704
+ logging.info(f"Device: {params.device}")
705
+
706
+ model = model.to(params.device)
707
+ model.eval()
708
+
709
+ vocoder = get_vocoder(params.vocoder_path)
710
+ vocoder = vocoder.to(params.device)
711
+ vocoder.eval()
712
+
713
+ if model_config["feature"]["type"] == "vocos":
714
+ if params.model_name == "zipvoice_dialog":
715
+ num_channels = 1
716
+ else:
717
+ assert params.model_name == "zipvoice_dialog_stereo"
718
+ num_channels = 2
719
+ feature_extractor = VocosFbank(num_channels=num_channels)
720
+ else:
721
+ raise NotImplementedError(
722
+ f"Unsupported feature type: {model_config['feature']['type']}"
723
+ )
724
+ params.sampling_rate = model_config["feature"]["sampling_rate"]
725
+
726
+ logging.info("Start generating...")
727
+ os.makedirs(params.res_dir, exist_ok=True)
728
+ generate_list(
729
+ model_name=params.model_name,
730
+ res_dir=params.res_dir,
731
+ test_list=params.test_list,
732
+ model=model,
733
+ vocoder=vocoder,
734
+ tokenizer=tokenizer,
735
+ feature_extractor=feature_extractor,
736
+ device=params.device,
737
+ num_step=params.num_step,
738
+ guidance_scale=params.guidance_scale,
739
+ speed=params.speed,
740
+ t_shift=params.t_shift,
741
+ target_rms=params.target_rms,
742
+ feat_scale=params.feat_scale,
743
+ sampling_rate=params.sampling_rate,
744
+ silence_wav=params.silence_wav,
745
+ )
746
+ logging.info("Done")
747
+
748
+
749
+ if __name__ == "__main__":
750
+ torch.set_num_threads(1)
751
+ torch.set_num_interop_threads(1)
752
+
753
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
754
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
755
+
756
+ main()
zipvoice/bin/infer_zipvoice_onnx.py ADDED
@@ -0,0 +1,712 @@
1
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu,
2
+ # Zengwei Yao)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ This script generates speech with our pre-trained ZipVoice or ZipVoice-Distill
20
+ ONNX models. If no local model is specified, the
21
+ required files will be automatically downloaded from HuggingFace.
22
+
23
+ Usage:
24
+
25
+ Note: If you have trouble connecting to HuggingFace,
26
+ try switching the endpoint to the mirror site:
27
+ export HF_ENDPOINT=https://hf-mirror.com
28
+
29
+ (1) Inference of a single sentence:
30
+
31
+ python3 -m zipvoice.bin.infer_zipvoice_onnx \
32
+ --onnx-int8 False \
33
+ --model-name zipvoice \
34
+ --prompt-wav prompt.wav \
35
+ --prompt-text "I am a prompt." \
36
+ --text "I am a sentence." \
37
+ --res-wav-path result.wav
38
+
39
+ (2) Inference of a list of sentences:
40
+ python3 -m zipvoice.bin.infer_zipvoice_onnx \
41
+ --onnx-int8 False \
42
+ --model-name zipvoice \
43
+ --test-list test.tsv \
44
+ --res-dir results
45
+
46
+ `--model-name` can be `zipvoice` or `zipvoice_distill`,
47
+ which are the models before and after distillation, respectively.
48
+
49
+ Each line of `test.tsv` is in the format of
50
+ `{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`.
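+
+ An illustrative line (names and paths are hypothetical):
+ `utt_001\tHello there.\tprompts/alice.wav\tThis is the sentence to synthesize.`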
51
+
52
+ Set `--onnx-int8 True` to use the int8-quantized ONNX model.
53
+ """
54
+
55
+ import argparse
56
+ import datetime as dt
57
+ import json
58
+ import logging
59
+ import os
60
+ from pathlib import Path
61
+ from typing import List
62
+
63
+ import numpy as np
64
+ import onnxruntime as ort
65
+ import torch
66
+ import torchaudio
67
+ from huggingface_hub import hf_hub_download
68
+ from lhotse.utils import fix_random_seed
69
+ from torch import Tensor, nn
70
+
71
+ from zipvoice.bin.infer_zipvoice import get_vocoder
72
+ from zipvoice.models.modules.solver import get_time_steps
73
+ from zipvoice.tokenizer.tokenizer import (
74
+ EmiliaTokenizer,
75
+ EspeakTokenizer,
76
+ LibriTTSTokenizer,
77
+ SimpleTokenizer,
78
+ )
79
+ from zipvoice.utils.common import AttributeDict, str2bool
80
+ from zipvoice.utils.feature import VocosFbank
81
+
82
+ HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
83
+ MODEL_DIR = {
84
+ "zipvoice": "zipvoice",
85
+ "zipvoice_distill": "zipvoice_distill",
86
+ }
87
+
88
+
89
+ def get_parser():
90
+ parser = argparse.ArgumentParser(
91
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
92
+ )
93
+
94
+ parser.add_argument(
95
+ "--onnx-int8",
96
+ type=str2bool,
97
+ default=False,
98
+ help="Whether to use the int8 model",
99
+ )
100
+
101
+ parser.add_argument(
102
+ "--model-name",
103
+ type=str,
104
+ default="zipvoice",
105
+ choices=["zipvoice", "zipvoice_distill"],
106
+ help="The model used for inference",
107
+ )
108
+
109
+ parser.add_argument(
110
+ "--model-dir",
111
+ type=str,
112
+ default=None,
113
+ help="The path to the local onnx model. "
114
+ "Will download pre-trained checkpoint from huggingface if not specified.",
115
+ )
116
+
117
+ parser.add_argument(
118
+ "--vocoder-path",
119
+ type=str,
120
+ default=None,
121
+ help="The vocoder checkpoint. "
122
+ "Will download pre-trained vocoder from huggingface if not specified.",
123
+ )
124
+
125
+ parser.add_argument(
126
+ "--tokenizer",
127
+ type=str,
128
+ default="emilia",
129
+ choices=["emilia", "libritts", "espeak", "simple"],
130
+ help="Tokenizer type.",
131
+ )
132
+
133
+ parser.add_argument(
134
+ "--lang",
135
+ type=str,
136
+ default="en-us",
137
+ help="Language identifier, used when tokenizer type is espeak. see"
138
+ "https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
139
+ )
140
+
141
+ parser.add_argument(
142
+ "--test-list",
143
+ type=str,
144
+ default=None,
145
+ help="The list of prompt speech, prompt_transcription, "
146
+ "and text to synthesizein the format of "
147
+ "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.",
148
+ )
149
+
150
+ parser.add_argument(
151
+ "--prompt-wav",
152
+ type=str,
153
+ default=None,
154
+ help="The prompt wav to mimic",
155
+ )
156
+
157
+ parser.add_argument(
158
+ "--prompt-text",
159
+ type=str,
160
+ default=None,
161
+ help="The transcription of the prompt wav",
162
+ )
163
+
164
+ parser.add_argument(
165
+ "--text",
166
+ type=str,
167
+ default=None,
168
+ help="The text to synthesize",
169
+ )
170
+
171
+ parser.add_argument(
172
+ "--res-dir",
173
+ type=str,
174
+ default="results",
175
+ help="""
176
+ Path name of the generated wavs dir,
177
+ used when test-list is not None
178
+ """,
179
+ )
180
+
181
+ parser.add_argument(
182
+ "--res-wav-path",
183
+ type=str,
184
+ default="result.wav",
185
+ help="""
186
+ Path name of the generated wav path,
187
+ used when test-list is None
188
+ """,
189
+ )
190
+
191
+ parser.add_argument(
192
+ "--guidance-scale",
193
+ type=float,
194
+ default=None,
195
+ help="The scale of classifier-free guidance during inference.",
196
+ )
197
+
198
+ parser.add_argument(
199
+ "--num-step",
200
+ type=int,
201
+ default=None,
202
+ help="The number of sampling steps.",
203
+ )
204
+
205
+ parser.add_argument(
206
+ "--feat-scale",
207
+ type=float,
208
+ default=0.1,
209
+ help="The scale factor of fbank feature",
210
+ )
211
+
212
+ parser.add_argument(
213
+ "--speed",
214
+ type=float,
215
+ default=1.0,
216
+ help="Control speech speed, 1.0 means normal, >1.0 means speed up",
217
+ )
218
+
219
+ parser.add_argument(
220
+ "--t-shift",
221
+ type=float,
222
+ default=0.5,
223
+ help="Shift t to smaller ones if t_shift < 1.0",
224
+ )
225
+
226
+ parser.add_argument(
227
+ "--target-rms",
228
+ type=float,
229
+ default=0.1,
230
+ help="Target speech normalization rms value, set to 0 to disable normalization",
231
+ )
232
+
233
+ parser.add_argument(
234
+ "--seed",
235
+ type=int,
236
+ default=666,
237
+ help="Random seed",
238
+ )
239
+
240
+ return parser
241
+
242
+
243
+ class OnnxModel:
244
+ def __init__(
245
+ self,
246
+ text_encoder_path: str,
247
+ fm_decoder_path: str,
248
+ ):
249
+ session_opts = ort.SessionOptions()
250
+ session_opts.inter_op_num_threads = 1
251
+ session_opts.intra_op_num_threads = 1
252
+
253
+ self.session_opts = session_opts
254
+
255
+ self.init_text_encoder(text_encoder_path)
256
+ self.init_fm_decoder(fm_decoder_path)
257
+
258
+ def init_text_encoder(self, model_path: str):
259
+ self.text_encoder = ort.InferenceSession(
260
+ model_path,
261
+ sess_options=self.session_opts,
262
+ providers=["CPUExecutionProvider"],
263
+ )
264
+
265
+ def init_fm_decoder(self, model_path: str):
266
+ self.fm_decoder = ort.InferenceSession(
267
+ model_path,
268
+ sess_options=self.session_opts,
269
+ providers=["CPUExecutionProvider"],
270
+ )
271
+ meta = self.fm_decoder.get_modelmeta().custom_metadata_map
272
+ self.feat_dim = int(meta["feat_dim"])
273
+
274
+ def run_text_encoder(
275
+ self,
276
+ tokens: Tensor,
277
+ prompt_tokens: Tensor,
278
+ prompt_features_len: Tensor,
279
+ speed: Tensor,
280
+ ) -> Tensor:
281
+ out = self.text_encoder.run(
282
+ [
283
+ self.text_encoder.get_outputs()[0].name,
284
+ ],
285
+ {
286
+ self.text_encoder.get_inputs()[0].name: tokens.numpy(),
287
+ self.text_encoder.get_inputs()[1].name: prompt_tokens.numpy(),
288
+ self.text_encoder.get_inputs()[2].name: prompt_features_len.numpy(),
289
+ self.text_encoder.get_inputs()[3].name: speed.numpy(),
290
+ },
291
+ )
292
+ return torch.from_numpy(out[0])
293
+
294
+ def run_fm_decoder(
295
+ self,
296
+ t: Tensor,
297
+ x: Tensor,
298
+ text_condition: Tensor,
299
+ speech_condition: torch.Tensor,
300
+ guidance_scale: Tensor,
301
+ ) -> Tensor:
302
+ out = self.fm_decoder.run(
303
+ [
304
+ self.fm_decoder.get_outputs()[0].name,
305
+ ],
306
+ {
307
+ self.fm_decoder.get_inputs()[0].name: t.numpy(),
308
+ self.fm_decoder.get_inputs()[1].name: x.numpy(),
309
+ self.fm_decoder.get_inputs()[2].name: text_condition.numpy(),
310
+ self.fm_decoder.get_inputs()[3].name: speech_condition.numpy(),
311
+ self.fm_decoder.get_inputs()[4].name: guidance_scale.numpy(),
312
+ },
313
+ )
314
+ return torch.from_numpy(out[0])
315
+
316
+
317
+ def sample(
318
+ model: OnnxModel,
319
+ tokens: List[List[int]],
320
+ prompt_tokens: List[List[int]],
321
+ prompt_features: Tensor,
322
+ speed: float = 1.0,
323
+ t_shift: float = 0.5,
324
+ guidance_scale: float = 1.0,
325
+ num_step: int = 16,
326
+ ) -> torch.Tensor:
327
+ """
328
+ Generate acoustic features, given text tokens, prompts feature and prompt
329
+ transcription's text tokens.
330
+
331
+ Args:
332
+ tokens: a list of list of text tokens.
333
+ prompt_tokens: a list of list of prompt tokens.
334
+ prompt_features: the prompt feature with the shape
335
+ (batch_size, seq_len, feat_dim).
336
+ speed : speed control.
337
+ t_shift: time shift.
338
+ guidance_scale: the guidance scale for classifier-free guidance.
339
+ num_step: the number of steps to use in the ODE solver.
340
+ """
341
+ # Run text encoder
342
+ assert len(tokens) == len(prompt_tokens) == 1
343
+ tokens = torch.tensor(tokens, dtype=torch.int64)
344
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.int64)
345
+ prompt_features_len = torch.tensor(prompt_features.size(1), dtype=torch.int64)
346
+ speed = torch.tensor(speed, dtype=torch.float32)
347
+
348
+ text_condition = model.run_text_encoder(
349
+ tokens, prompt_tokens, prompt_features_len, speed
350
+ )
351
+
352
+ batch_size, num_frames, _ = text_condition.shape
353
+ assert batch_size == 1
354
+ feat_dim = model.feat_dim
355
+
356
+ # Run flow matching model
357
+ timesteps = get_time_steps(
358
+ t_start=0.0,
359
+ t_end=1.0,
360
+ num_step=num_step,
361
+ t_shift=t_shift,
362
+ )
363
+ x = torch.randn(batch_size, num_frames, feat_dim)
364
+ speech_condition = torch.nn.functional.pad(
365
+ prompt_features, (0, 0, 0, num_frames - prompt_features.shape[1])
366
+ ) # (B, T, F)
367
+ guidance_scale = torch.tensor(guidance_scale, dtype=torch.float32)
368
+
369
+ for step in range(num_step):
370
+ v = model.run_fm_decoder(
371
+ t=timesteps[step],
372
+ x=x,
373
+ text_condition=text_condition,
374
+ speech_condition=speech_condition,
375
+ guidance_scale=guidance_scale,
376
+ )
377
+ x = x + v * (timesteps[step + 1] - timesteps[step])
378
+
379
+ x = x[:, prompt_features_len.item() :, :]
380
+ return x
381
+
382
+
383
+ # Copied from zipvoice/bin/infer_zipvoice.py, but call an external sample function
384
+ def generate_sentence(
385
+ save_path: str,
386
+ prompt_text: str,
387
+ prompt_wav: str,
388
+ text: str,
389
+ model: OnnxModel,
390
+ vocoder: nn.Module,
391
+ tokenizer: EmiliaTokenizer,
392
+ feature_extractor: VocosFbank,
393
+ num_step: int = 16,
394
+ guidance_scale: float = 1.0,
395
+ speed: float = 1.0,
396
+ t_shift: float = 0.5,
397
+ target_rms: float = 0.1,
398
+ feat_scale: float = 0.1,
399
+ sampling_rate: int = 24000,
400
+ ):
401
+ """
402
+ Generate waveform of a text based on a given prompt
403
+ waveform and its transcription.
404
+
405
+ Args:
406
+ save_path (str): Path to save the generated wav.
407
+ prompt_text (str): Transcription of the prompt wav.
408
+ prompt_wav (str): Path to the prompt wav file.
409
+ text (str): Text to be synthesized into a waveform.
410
+ model (torch.nn.Module): The model used for generation.
411
+ vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
412
+ tokenizer (EmiliaTokenizer): The tokenizer used to convert text to tokens.
413
+ feature_extractor (VocosFbank): The feature extractor used to
414
+ extract acoustic features.
415
+ num_step (int, optional): Number of steps for decoding. Defaults to 16.
416
+ guidance_scale (float, optional): Scale for classifier-free guidance.
417
+ Defaults to 1.0.
418
+ speed (float, optional): Speed control. Defaults to 1.0.
419
+ t_shift (float, optional): Time shift. Defaults to 0.5.
420
+ target_rms (float, optional): Target RMS for waveform normalization.
421
+ Defaults to 0.1.
422
+ feat_scale (float, optional): Scale for features.
423
+ Defaults to 0.1.
424
+ sampling_rate (int, optional): Sampling rate for the waveform.
425
+ Defaults to 24000.
426
+ Returns:
427
+ metrics (dict): Dictionary containing time and real-time
428
+ factor metrics for processing.
429
+ """
430
+ # Convert text to tokens
431
+ tokens = tokenizer.texts_to_token_ids([text])
432
+ prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
433
+
434
+ # Load and preprocess prompt wav
435
+ prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)
436
+
437
+ if prompt_sampling_rate != sampling_rate:
438
+ resampler = torchaudio.transforms.Resample(
439
+ orig_freq=prompt_sampling_rate, new_freq=sampling_rate
440
+ )
441
+ prompt_wav = resampler(prompt_wav)
442
+
443
+ prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
444
+ if prompt_rms < target_rms:
445
+ prompt_wav = prompt_wav * target_rms / prompt_rms
446
+
447
+ # Extract features from prompt wav
448
+ prompt_features = feature_extractor.extract(prompt_wav, sampling_rate=sampling_rate)
449
+
450
+ prompt_features = prompt_features.unsqueeze(0) * feat_scale
451
+
452
+ # Start timing
453
+ start_t = dt.datetime.now()
454
+
455
+ # Generate features
456
+ pred_features = sample(
457
+ model=model,
458
+ tokens=tokens,
459
+ prompt_tokens=prompt_tokens,
460
+ prompt_features=prompt_features,
461
+ speed=speed,
462
+ t_shift=t_shift,
463
+ guidance_scale=guidance_scale,
464
+ num_step=num_step,
465
+ )
466
+
467
+ # Postprocess predicted features
468
+ pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
469
+
470
+ # Start vocoder processing
471
+ start_vocoder_t = dt.datetime.now()
472
+ wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
473
+
474
+ # Calculate processing times and real-time factors
475
+ t = (dt.datetime.now() - start_t).total_seconds()
476
+ t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
477
+ t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
478
+ wav_seconds = wav.shape[-1] / sampling_rate
479
+ rtf = t / wav_seconds
480
+ rtf_no_vocoder = t_no_vocoder / wav_seconds
481
+ rtf_vocoder = t_vocoder / wav_seconds
482
+ metrics = {
483
+ "t": t,
484
+ "t_no_vocoder": t_no_vocoder,
485
+ "t_vocoder": t_vocoder,
486
+ "wav_seconds": wav_seconds,
487
+ "rtf": rtf,
488
+ "rtf_no_vocoder": rtf_no_vocoder,
489
+ "rtf_vocoder": rtf_vocoder,
490
+ }
491
+
492
+ # Adjust wav volume if necessary
493
+ if prompt_rms < target_rms:
494
+ wav = wav * prompt_rms / target_rms
495
+ torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
496
+
497
+ return metrics
498
+
499
+
500
+ def generate_list(
501
+ res_dir: str,
502
+ test_list: str,
503
+ model: OnnxModel,
504
+ vocoder: nn.Module,
505
+ tokenizer: EmiliaTokenizer,
506
+ feature_extractor: VocosFbank,
507
+ num_step: int = 16,
508
+ guidance_scale: float = 1.0,
509
+ speed: float = 1.0,
510
+ t_shift: float = 0.5,
511
+ target_rms: float = 0.1,
512
+ feat_scale: float = 0.1,
513
+ sampling_rate: int = 24000,
514
+ ):
515
+ total_t = []
516
+ total_t_no_vocoder = []
517
+ total_t_vocoder = []
518
+ total_wav_seconds = []
519
+
520
+ with open(test_list, "r") as fr:
521
+ lines = fr.readlines()
522
+
523
+ for i, line in enumerate(lines):
524
+ wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
525
+ save_path = f"{res_dir}/{wav_name}.wav"
526
+ metrics = generate_sentence(
527
+ save_path=save_path,
528
+ prompt_text=prompt_text,
529
+ prompt_wav=prompt_wav,
530
+ text=text,
531
+ model=model,
532
+ vocoder=vocoder,
533
+ tokenizer=tokenizer,
534
+ feature_extractor=feature_extractor,
535
+ num_step=num_step,
536
+ guidance_scale=guidance_scale,
537
+ speed=speed,
538
+ t_shift=t_shift,
539
+ target_rms=target_rms,
540
+ feat_scale=feat_scale,
541
+ sampling_rate=sampling_rate,
542
+ )
543
+ logging.info(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
544
+ total_t.append(metrics["t"])
545
+ total_t_no_vocoder.append(metrics["t_no_vocoder"])
546
+ total_t_vocoder.append(metrics["t_vocoder"])
547
+ total_wav_seconds.append(metrics["wav_seconds"])
548
+
549
+ logging.info(f"Average RTF: {np.sum(total_t) / np.sum(total_wav_seconds):.4f}")
550
+ logging.info(
551
+ f"Average RTF w/o vocoder: "
552
+ f"{np.sum(total_t_no_vocoder) / np.sum(total_wav_seconds):.4f}"
553
+ )
554
+ logging.info(
555
+ f"Average RTF vocoder: "
556
+ f"{np.sum(total_t_vocoder) / np.sum(total_wav_seconds):.4f}"
557
+ )
558
+
559
+
560
+ @torch.inference_mode()
561
+ def main():
562
+ parser = get_parser()
563
+ args = parser.parse_args()
564
+
565
+ params = AttributeDict()
566
+ params.update(vars(args))
567
+ fix_random_seed(params.seed)
568
+
569
+ model_defaults = {
570
+ "zipvoice": {
571
+ "num_step": 16,
572
+ "guidance_scale": 1.0,
573
+ },
574
+ "zipvoice_distill": {
575
+ "num_step": 8,
576
+ "guidance_scale": 3.0,
577
+ },
578
+ }
579
+
580
+ model_specific_defaults = model_defaults.get(params.model_name, {})
581
+
582
+ for param, value in model_specific_defaults.items():
583
+ if getattr(params, param) is None:
584
+ setattr(params, param, value)
585
+ logging.info(f"Setting {param} to default value: {value}")
586
+
587
+ assert (params.test_list is not None) ^ all(
589
+ [params.prompt_wav, params.prompt_text, params.text]
589
+ ), (
590
+ "For inference, please provide prompts and text with either '--test-list'"
591
+ " or '--prompt-wav, --prompt-text and --text'."
592
+ )
593
+
594
+ if params.onnx_int8:
595
+ text_encoder_name = "text_encoder_int8.onnx"
596
+ fm_decoder_name = "fm_decoder_int8.onnx"
597
+ else:
598
+ text_encoder_name = "text_encoder.onnx"
599
+ fm_decoder_name = "fm_decoder.onnx"
600
+
601
+ if params.model_dir is not None:
602
+ params.model_dir = Path(params.model_dir)
603
+ if not params.model_dir.is_dir():
604
+ raise FileNotFoundError(f"{params.model_dir} does not exist")
605
+
606
+ for filename in [
607
+ text_encoder_name,
608
+ fm_decoder_name,
609
+ "model.json",
610
+ "tokens.txt",
611
+ ]:
612
+ if not (params.model_dir / filename).is_file():
613
+ raise FileNotFoundError(f"{params.model_dir / filename} does not exist")
614
+ text_encoder_path = params.model_dir / text_encoder_name
615
+ fm_decoder_path = params.model_dir / fm_decoder_name
616
+ model_config = params.model_dir / "model.json"
617
+ token_file = params.model_dir / "tokens.txt"
618
+ logging.info(f"Using local model dir {params.model_dir}.")
619
+ else:
620
+ logging.info("Using pretrained model from the huggingface")
621
+ logging.info("Downloading the requires files from HuggingFace")
622
+ text_encoder_path = hf_hub_download(
623
+ HUGGINGFACE_REPO,
624
+ filename=f"{MODEL_DIR[params.model_name]}/{text_encoder_name}",
625
+ )
626
+ fm_decoder_path = hf_hub_download(
627
+ HUGGINGFACE_REPO,
628
+ filename=f"{MODEL_DIR[params.model_name]}/{fm_decoder_name}",
629
+ )
630
+ model_config = hf_hub_download(
631
+ HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/model.json"
632
+ )
633
+
634
+ token_file = hf_hub_download(
635
+ HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/tokens.txt"
636
+ )
637
+
638
+ logging.info("Loading model...")
639
+
640
+ if params.tokenizer == "emilia":
641
+ tokenizer = EmiliaTokenizer(token_file=token_file)
642
+ elif params.tokenizer == "libritts":
643
+ tokenizer = LibriTTSTokenizer(token_file=token_file)
644
+ elif params.tokenizer == "espeak":
645
+ tokenizer = EspeakTokenizer(token_file=token_file, lang=params.lang)
646
+ else:
647
+ assert params.tokenizer == "simple"
648
+ tokenizer = SimpleTokenizer(token_file=token_file)
649
+
650
+ with open(model_config, "r") as f:
651
+ model_config = json.load(f)
652
+
653
+ model = OnnxModel(text_encoder_path, fm_decoder_path)
654
+
655
+ vocoder = get_vocoder(params.vocoder_path)
656
+ vocoder.eval()
657
+
658
+ if model_config["feature"]["type"] == "vocos":
659
+ feature_extractor = VocosFbank()
660
+ else:
661
+ raise NotImplementedError(
662
+ f"Unsupported feature type: {model_config['feature']['type']}"
663
+ )
664
+ params.sampling_rate = model_config["feature"]["sampling_rate"]
665
+
666
+ logging.info("Start generating...")
667
+ if params.test_list:
668
+ os.makedirs(params.res_dir, exist_ok=True)
669
+ generate_list(
670
+ res_dir=params.res_dir,
671
+ test_list=params.test_list,
672
+ model=model,
673
+ vocoder=vocoder,
674
+ tokenizer=tokenizer,
675
+ feature_extractor=feature_extractor,
676
+ num_step=params.num_step,
677
+ guidance_scale=params.guidance_scale,
678
+ speed=params.speed,
679
+ t_shift=params.t_shift,
680
+ target_rms=params.target_rms,
681
+ feat_scale=params.feat_scale,
682
+ sampling_rate=params.sampling_rate,
683
+ )
684
+ else:
685
+ generate_sentence(
686
+ save_path=params.res_wav_path,
687
+ prompt_text=params.prompt_text,
688
+ prompt_wav=params.prompt_wav,
689
+ text=params.text,
690
+ model=model,
691
+ vocoder=vocoder,
692
+ tokenizer=tokenizer,
693
+ feature_extractor=feature_extractor,
694
+ num_step=params.num_step,
695
+ guidance_scale=params.guidance_scale,
696
+ speed=params.speed,
697
+ t_shift=params.t_shift,
698
+ target_rms=params.target_rms,
699
+ feat_scale=params.feat_scale,
700
+ sampling_rate=params.sampling_rate,
701
+ )
702
+ logging.info("Done")
703
+
704
+
705
+ if __name__ == "__main__":
706
+ torch.set_num_threads(1)
707
+ torch.set_num_interop_threads(1)
708
+
709
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
710
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
711
+
712
+ main()
zipvoice/bin/onnx_export.py ADDED
@@ -0,0 +1,410 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Zengwei Yao)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ This script exports a pre-trained ZipVoice or ZipVoice-Distill model from PyTorch to
20
+ ONNX.
21
+
22
+ Usage:
23
+
24
+ python3 -m zipvoice.bin.onnx_export \
25
+ --model-name zipvoice \
26
+ --model-dir exp/zipvoice \
27
+ --checkpoint-name epoch-11-avg-4.pt \
28
+ --onnx-model-dir exp/zipvoice
29
+
30
+ `--model-name` can be `zipvoice` or `zipvoice_distill`,
31
+ which are the models before and after distillation, respectively.
32
+ """
33
+
34
+
35
+ import argparse
36
+ import json
37
+ import logging
38
+ from pathlib import Path
39
+ from typing import Dict
40
+
41
+ import onnx
42
+ import safetensors.torch
43
+ import torch
44
+ from onnxruntime.quantization import QuantType, quantize_dynamic
45
+ from torch import Tensor, nn
46
+
47
+ from zipvoice.models.zipvoice import ZipVoice
48
+ from zipvoice.models.zipvoice_distill import ZipVoiceDistill
49
+ from zipvoice.tokenizer.tokenizer import SimpleTokenizer
50
+ from zipvoice.utils.checkpoint import load_checkpoint
51
+ from zipvoice.utils.common import AttributeDict
52
+ from zipvoice.utils.scaling_converter import convert_scaled_to_non_scaled
53
+
54
+
55
+ def get_parser():
56
+ parser = argparse.ArgumentParser(
57
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
58
+ )
59
+
60
+ parser.add_argument(
61
+ "--onnx-model-dir",
62
+ type=str,
63
+ default="exp",
64
+ help="Dir to the exported models",
65
+ )
66
+
67
+ parser.add_argument(
68
+ "--model-name",
69
+ type=str,
70
+ default="zipvoice",
71
+ choices=["zipvoice", "zipvoice_distill"],
72
+ help="The model used for inference",
73
+ )
74
+
75
+ parser.add_argument(
76
+ "--model-dir",
77
+ type=str,
78
+ default=None,
79
+ help="The model directory that contains model checkpoint, configuration "
80
+ "file model.json, and tokens file tokens.txt. Will download pre-trained "
81
+ "checkpoint from huggingface if not specified.",
82
+ )
83
+
84
+ parser.add_argument(
85
+ "--checkpoint-name",
86
+ type=str,
87
+ default="model.pt",
88
+ help="The name of model checkpoint.",
89
+ )
90
+
91
+ return parser
92
+
93
+
94
+ def add_meta_data(filename: str, meta_data: Dict[str, str]):
95
+ """Add meta data to an ONNX model. It is changed in-place.
96
+
97
+ Args:
98
+ filename:
99
+ Filename of the ONNX model to be changed.
100
+ meta_data:
101
+ Key-value pairs.
102
+ """
103
+ model = onnx.load(filename)
104
+ for key, value in meta_data.items():
105
+ meta = model.metadata_props.add()
106
+ meta.key = key
107
+ meta.value = value
108
+
109
+ onnx.save(model, filename)
110
+
111
+
112
+ class OnnxTextModel(nn.Module):
113
+ def __init__(self, model: nn.Module):
114
+ """A wrapper for ZipVoice text encoder."""
115
+ super().__init__()
116
+ self.embed = model.embed
117
+ self.text_encoder = model.text_encoder
118
+ self.pad_id = model.pad_id
119
+
120
+ def forward(
121
+ self,
122
+ tokens: Tensor,
123
+ prompt_tokens: Tensor,
124
+ prompt_features_len: Tensor,
125
+ speed: Tensor,
126
+ ) -> Tensor:
127
+ cat_tokens = torch.cat([prompt_tokens, tokens], dim=1)
128
+ cat_tokens = nn.functional.pad(cat_tokens, (0, 1), value=self.pad_id)
129
+ tokens_len = cat_tokens.shape[1] - 1
130
+ padding_mask = (torch.arange(tokens_len + 1) == tokens_len).unsqueeze(0)
131
+
132
+ embed = self.embed(cat_tokens)
133
+ embed = self.text_encoder(x=embed, t=None, padding_mask=padding_mask)
134
+
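+ # Estimate the total number of output frames from the prompt's
+ # frames-per-token ratio, scaled by the requested speed.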
135
+ features_len = torch.ceil(
136
+ (prompt_features_len / prompt_tokens.shape[1] * tokens_len / speed)
137
+ ).to(dtype=torch.int64)
138
+
139
+ token_dur = torch.div(features_len, tokens_len, rounding_mode="floor").to(
140
+ dtype=torch.int64
141
+ )
142
+
143
+ text_condition = embed[:, :-1, :].unsqueeze(2).expand(-1, -1, token_dur, -1)
144
+ text_condition = text_condition.reshape(embed.shape[0], -1, embed.shape[2])
145
+
146
+ text_condition = torch.cat(
147
+ [
148
+ text_condition,
149
+ embed[:, -1:, :].expand(-1, features_len - text_condition.shape[1], -1),
150
+ ],
151
+ dim=1,
152
+ )
153
+
154
+ return text_condition
155
+
156
+
157
+ class OnnxFlowMatchingModel(nn.Module):
158
+ def __init__(self, model: nn.Module, distill: bool = False):
159
+ """A wrapper for ZipVoice flow-matching decoder."""
160
+ super().__init__()
161
+ self.distill = distill
162
+ self.fm_decoder = model.fm_decoder
163
+ self.model_func = getattr(model, "forward_fm_decoder")
164
+ self.feat_dim = model.feat_dim
165
+
166
+ def forward(
167
+ self,
168
+ t: Tensor,
169
+ x: Tensor,
170
+ text_condition: Tensor,
171
+ speech_condition: torch.Tensor,
172
+ guidance_scale: Tensor,
173
+ ) -> Tensor:
174
+ if self.distill:
175
+ return self.model_func(
176
+ t=t,
177
+ xt=x,
178
+ text_condition=text_condition,
179
+ speech_condition=speech_condition,
180
+ guidance_scale=guidance_scale,
181
+ )
182
+ else:
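+ # Classifier-free guidance: run an unconditional and a conditional copy
+ # through the decoder in one batch, then extrapolate:
+ # v = (1 + g) * cond - g * uncond.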
183
+ x = x.repeat(2, 1, 1)
184
+ text_condition = torch.cat(
185
+ [torch.zeros_like(text_condition), text_condition], dim=0
186
+ )
187
+ speech_condition = torch.cat(
188
+ [
189
+ torch.where(
190
+ t > 0.5, torch.zeros_like(speech_condition), speech_condition
191
+ ),
192
+ speech_condition,
193
+ ],
194
+ dim=0,
195
+ )
196
+ guidance_scale = torch.where(t > 0.5, guidance_scale, guidance_scale * 2.0)
197
+ data_uncond, data_cond = self.model_func(
198
+ t=t,
199
+ xt=x,
200
+ text_condition=text_condition,
201
+ speech_condition=speech_condition,
202
+ ).chunk(2, dim=0)
203
+ v = (1 + guidance_scale) * data_cond - guidance_scale * data_uncond
204
+ return v
205
+
206
+
207
+ def export_text_encoder(
208
+ model: OnnxTextModel,
209
+ filename: str,
210
+ opset_version: int = 11,
211
+ ) -> None:
212
+ """Export the text encoder model to ONNX format.
213
+
214
+ Args:
215
+ model:
216
+ The input model
217
+ filename:
218
+ The filename to save the exported ONNX model.
219
+ opset_version:
220
+ The opset version to use.
221
+ """
222
+ tokens = torch.tensor([[2, 3, 4, 5]], dtype=torch.int64)
223
+ prompt_tokens = torch.tensor([[0, 1]], dtype=torch.int64)
224
+ prompt_features_len = torch.tensor(10, dtype=torch.int64)
225
+ speed = torch.tensor(1.0, dtype=torch.float32)
226
+
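+ # Trace with the dummy inputs above so data-dependent Python control flow
+ # is fixed in the graph before the ONNX export below.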
227
+ model = torch.jit.trace(model, (tokens, prompt_tokens, prompt_features_len, speed))
228
+
229
+ torch.onnx.export(
230
+ model,
231
+ (tokens, prompt_tokens, prompt_features_len, speed),
232
+ filename,
233
+ verbose=False,
234
+ opset_version=opset_version,
235
+ input_names=["tokens", "prompt_tokens", "prompt_features_len", "speed"],
236
+ output_names=["text_condition"],
237
+ dynamic_axes={
238
+ "tokens": {0: "N", 1: "T"},
239
+ "prompt_tokens": {0: "N", 1: "T"},
240
+ "text_condition": {0: "N", 1: "T"},
241
+ },
242
+ )
243
+
244
+ meta_data = {
245
+ "version": "1",
246
+ "model_author": "k2-fsa",
247
+ "comment": "ZipVoice text encoder",
248
+ }
249
+ logging.info(f"meta_data: {meta_data}")
250
+ add_meta_data(filename=filename, meta_data=meta_data)
251
+
252
+ logging.info(f"Exported to {filename}")
253
+
254
+
255
+ def export_fm_decoder(
256
+ model: OnnxFlowMatchingModel,
257
+ filename: str,
258
+ opset_version: int = 11,
259
+ ) -> None:
260
+ """Export the flow matching decoder model to ONNX format.
261
+
262
+ Args:
263
+ model:
264
+ The input model
265
+ filename:
266
+ The filename to save the exported ONNX model.
267
+ opset_version:
268
+ The opset version to use.
269
+ """
270
+ feat_dim = model.feat_dim
271
+ seq_len = 200
272
+ t = torch.tensor(0.5, dtype=torch.float32)
273
+ x = torch.randn(1, seq_len, feat_dim, dtype=torch.float32)
274
+ text_condition = torch.randn(1, seq_len, feat_dim, dtype=torch.float32)
275
+ speech_condition = torch.randn(1, seq_len, feat_dim, dtype=torch.float32)
276
+ guidance_scale = torch.tensor(1.0, dtype=torch.float32)
277
+
278
+ model = torch.jit.trace(
279
+ model, (t, x, text_condition, speech_condition, guidance_scale)
280
+ )
281
+
282
+ torch.onnx.export(
283
+ model,
284
+ (t, x, text_condition, speech_condition, guidance_scale),
285
+ filename,
286
+ verbose=False,
287
+ opset_version=opset_version,
288
+ input_names=["t", "x", "text_condition", "speech_condition", "guidance_scale"],
289
+ output_names=["v"],
290
+ dynamic_axes={
291
+ "x": {0: "N", 1: "T"},
292
+ "text_condition": {0: "N", 1: "T"},
293
+ "speech_condition": {0: "N", 1: "T"},
294
+ "v": {0: "N", 1: "T"},
295
+ },
296
+ )
297
+
298
+ meta_data = {
299
+ "version": "1",
300
+ "model_author": "k2-fsa",
301
+ "comment": "ZipVoice flow-matching decoder",
302
+ "feat_dim": str(feat_dim),
303
+ }
304
+ logging.info(f"meta_data: {meta_data}")
305
+ add_meta_data(filename=filename, meta_data=meta_data)
306
+
307
+ logging.info(f"Exported to {filename}")
308
+
309
+
310
+ @torch.no_grad()
311
+ def main():
312
+ parser = get_parser()
313
+ args = parser.parse_args()
314
+
315
+ params = AttributeDict()
316
+ params.update(vars(args))
317
+
318
+ params.model_dir = Path(params.model_dir)
319
+ if not params.model_dir.is_dir():
320
+ raise FileNotFoundError(f"{params.model_dir} does not exist")
321
+ for filename in [params.checkpoint_name, "model.json", "tokens.txt"]:
322
+ if not (params.model_dir / filename).is_file():
323
+ raise FileNotFoundError(f"{params.model_dir / filename} does not exist")
324
+ model_ckpt = params.model_dir / params.checkpoint_name
325
+ model_config = params.model_dir / "model.json"
326
+ token_file = params.model_dir / "tokens.txt"
327
+
328
+ logging.info(f"Loading model from {params.model_dir}")
329
+
330
+ tokenizer = SimpleTokenizer(token_file)
331
+ tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
332
+
333
+ with open(model_config, "r") as f:
334
+ model_config = json.load(f)
335
+
336
+ if params.model_name == "zipvoice":
337
+ model = ZipVoice(
338
+ **model_config["model"],
339
+ **tokenizer_config,
340
+ )
341
+ distill = False
342
+ else:
343
+ assert params.model_name == "zipvoice_distill"
344
+ model = ZipVoiceDistill(
345
+ **model_config["model"],
346
+ **tokenizer_config,
347
+ )
348
+ distill = True
349
+
350
+ if str(model_ckpt).endswith(".safetensors"):
351
+ safetensors.torch.load_model(model, model_ckpt)
352
+ elif str(model_ckpt).endswith(".pt"):
353
+ load_checkpoint(filename=model_ckpt, model=model, strict=True)
354
+ else:
355
+ raise NotImplementedError(f"Unsupported model checkpoint format: {model_ckpt}")
356
+
357
+ device = torch.device("cpu")
358
+ model = model.to(device)
359
+ model.eval()
360
+
361
+ convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
362
+
363
+ logging.info("Exporting model")
364
+ onnx_model_dir = Path(params.onnx_model_dir)
365
+ onnx_model_dir.mkdir(parents=True, exist_ok=True)
366
+ opset_version = 11
367
+
368
+ text_encoder = OnnxTextModel(model=model)
369
+ text_encoder_file = onnx_model_dir / "text_encoder.onnx"
370
+ export_text_encoder(
371
+ model=text_encoder,
372
+ filename=text_encoder_file,
373
+ opset_version=opset_version,
374
+ )
375
+
376
+ fm_decoder = OnnxFlowMatchingModel(model=model, distill=distill)
377
+ fm_decoder_file = onnx_model_dir / "fm_decoder.onnx"
378
+ export_fm_decoder(
379
+ model=fm_decoder,
380
+ filename=fm_decoder_file,
381
+ opset_version=opset_version,
382
+ )
383
+
384
+ logging.info("Generate int8 quantization models")
385
+
386
+ text_encoder_int8_file = onnx_model_dir / "text_encoder_int8.onnx"
387
+ quantize_dynamic(
388
+ model_input=text_encoder_file,
389
+ model_output=text_encoder_int8_file,
390
+ op_types_to_quantize=["MatMul"],
391
+ weight_type=QuantType.QInt8,
392
+ )
393
+
394
+ fm_decoder_int8_file = onnx_model_dir / "fm_decoder_int8.onnx"
395
+ quantize_dynamic(
396
+ model_input=fm_decoder_file,
397
+ model_output=fm_decoder_int8_file,
398
+ op_types_to_quantize=["MatMul"],
399
+ weight_type=QuantType.QInt8,
400
+ )
401
+
402
+ logging.info("Done!")
403
+
404
+
405
+ if __name__ == "__main__":
406
+
407
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
408
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
409
+
410
+ main()
zipvoice/bin/prepare_dataset.py ADDED
@@ -0,0 +1,274 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ This script generates lhotse manifest files from TSV files for custom datasets.
20
+
21
+ Each line of the TSV files should be in one of the following formats:
22
+ 1. "{uniq_id}\t{text}\t{wav_path}" if the text corresponds to the full wav",
23
+ 2. "{uniq_id}\t{text}\t{wav_path}\t{start_time}\t{end_time} if text corresponds
24
+ to part of the wav. The start_time and end_time specify the start and end
25
+ times of the text within the wav, which should be in seconds.
26
+
27
+ Note: {uniq_id} must be unique for each line.
28
+
29
+ Usage:
30
+
31
+ Suppose you have two TSV files: "custom_train.tsv" and "custom_dev.tsv",
32
+ where "custom" is your dataset name, "train"/"dev" are used for training and
33
+ validation respectively.
34
+
35
+ (1) Prepare the training data
36
+
37
+ python3 -m zipvoice.bin.prepare_dataset \
38
+ --tsv-path data/raw/custom_train.tsv \
39
+ --prefix "custom" \
40
+ --subset "train" \
41
+ --num-jobs 20 \
42
+ --output-dir "data/manifests"
43
+
44
+ The output file would be "data/manifests/custom_cuts_train.jsonl.gz".
45
+
46
+ (2) Prepare the validation data
47
+
48
+ python3 -m zipvoice.bin.prepare_dataset \
49
+ --tsv-path data/raw/custom_dev.tsv \
50
+ --prefix "custom" \
51
+ --subset "dev" \
52
+ --num-jobs 1 \
53
+ --output-dir "data/manifests"
54
+
55
+ The output file would be "data/manifests/custom_cuts_dev.jsonl.gz".
56
+
57
+ """
58
+
59
+ import argparse
60
+ import logging
61
+ import re
62
+ from concurrent.futures import ThreadPoolExecutor
63
+ from pathlib import Path
64
+ from typing import List, Optional, Tuple
65
+
66
+ from lhotse import CutSet, validate_recordings_and_supervisions
67
+ from lhotse.audio import Recording, RecordingSet
68
+ from lhotse.qa import fix_manifests
69
+ from lhotse.supervision import SupervisionSegment, SupervisionSet
70
+ from lhotse.utils import Pathlike
71
+ from tqdm.auto import tqdm
72
+
73
+
74
+ def get_args():
75
+ parser = argparse.ArgumentParser()
76
+
77
+ parser.add_argument(
78
+ "--tsv-path",
79
+ type=str,
80
+ help="The path of the tsv file. Each line should be in the format: "
81
+ "{uniq_id}\t{text}\t{wav_path}\t{start_time}\t{end_time} "
82
+ "if text corresponds to part of the wav or {uniq_id}\t{text}\t{wav_path} "
83
+ "if the text corresponds to the full wav",
84
+ )
85
+ parser.add_argument(
86
+ "--prefix",
87
+ type=str,
88
+ default="custom",
89
+ help="Prefix of the output manifest file.",
90
+ )
91
+
92
+ parser.add_argument(
93
+ "--subset",
94
+ type=str,
95
+ default="train",
96
+ help="Subset name manifest file, typically train or dev.",
97
+ )
98
+
99
+ parser.add_argument(
100
+ "--num-jobs",
101
+ type=int,
102
+ default=20,
103
+ help="Number of jobs to processing.",
104
+ )
105
+
106
+ parser.add_argument(
107
+ "--output-dir",
108
+ type=str,
109
+ default="data/manifests",
110
+ help="The destination directory of manifest files.",
111
+ )
112
+ parser.add_argument(
113
+ "--sampling-rate",
114
+ type=int,
115
+ default=24000,
116
+ help="The target sampling rate.",
117
+ )
118
+ return parser.parse_args()
119
+
120
+
121
+ def _parse_recording(
122
+ wav_path: str,
123
+ ) -> Tuple[Recording, str]:
124
+ """
125
+ :param wav_path: Path to the audio file
126
+ :return: a tuple of "recording" and "recording_id"
127
+ """
128
+
129
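+ # e.g., "/data/wavs/a.wav" -> recording_id "_data_wavs_a_wav"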
+ recording_id = wav_path.replace("/", "_").replace(".", "_")
130
+ recording = Recording.from_file(path=wav_path, recording_id=recording_id)
131
+
132
+ return recording, recording_id
133
+
134
+
135
+ def _parse_supervision(
136
+ supervision: List, recording_dict: dict
137
+ ) -> Optional[SupervisionSegment]:
138
+ """
139
+ :param supervision: A tuple (uniq_id, text, wav_path, start, end) parsed from a TSV line
140
+ :param recording_dict: Dictionary mapping recording IDs to Recording objects
141
+ :return: A SupervisionSegment object, or None if parsing fails
142
+ """
143
+
144
+ uniq_id, text, wav_path, start, end = supervision
145
+ try:
146
+ recording_id = wav_path.replace("/", "_").replace(".", "_")
147
+
148
+ recording = recording_dict[recording_id]
149
+ duration = end - start if end is not None else recording.duration
150
+ assert duration <= recording.duration, (
151
+ f"Duration {duration} is greater than recording duration {recording.duration}")
152
+
153
+ text = re.sub("_", " ", text) # "_" is treated as padding symbol
154
+ text = re.sub(r"\s+", " ", text) # remove extra whitespace
155
+
156
+ return SupervisionSegment(
157
+ id=f"{uniq_id}",
158
+ recording_id=recording.id,
159
+ start=start,
160
+ duration=duration,
161
+ channel=recording.channel_ids,
162
+ text=text.strip(),
163
+ )
164
+ except Exception as e:
165
+ logging.warning(f"Error processing line: {e}")
166
+ return None
167
+
168
+
169
+ def prepare_dataset(
170
+ tsv_path: Pathlike,
171
+ prefix: str,
172
+ subset: str,
173
+ sampling_rate: int,
174
+ num_jobs: int,
175
+ output_dir: Pathlike,
176
+ ):
177
+ """
178
+ Builds the Recording and Supervision manifests and writes the resulting CutSet to disk.
179
+
180
+ :param tsv_path: Path to the TSV file
181
+ :param output_dir: Path where to write the manifests
182
+ :param num_jobs: Number of parallel workers (threads) for reading recordings
183
+ :return: None. The CutSet is written to "{output_dir}/{prefix}_cuts_{subset}.jsonl.gz"
184
+ """
185
+ logging.info(f"Preparing {prefix} dataset {subset} subset.")
186
+ output_dir = Path(output_dir)
187
+ output_dir.mkdir(parents=True, exist_ok=True)
188
+ file_name = f"{prefix}_cuts_{subset}.jsonl.gz"
189
+ if (output_dir / file_name).is_file():
190
+ logging.info(f"{file_name} exists, skipping.")
191
+ return
192
+
193
+ # Step 1: Read all unique recording paths
194
+ recordings_path_set = set()
195
+ supervision_list = list()
196
+ with open(tsv_path, "r") as fr:
197
+ for line in fr:
198
+ items = line.strip().split("\t")
199
+ if len(items) == 3:
200
+ uniq_id, text, wav_path = items
201
+ start, end = 0, None
202
+ elif len(items) == 5:
203
+ uniq_id, text, wav_path, start, end = items
204
+ start, end = float(start), float(end)
205
+ else:
206
+ raise ValueError(
207
+ f"Invalid line format: {line},"
208
+ "requries to be 3 columns or 5 columns"
209
+ )
210
+ recordings_path_set.add(wav_path)
211
+ supervision_list.append((uniq_id, text, wav_path, start, end))
212
+
213
+ logging.info("Starting to process recordings...")
214
+ # Step 2: Process recordings
215
+ futures = []
216
+ recording_dict = {}
217
+ with ThreadPoolExecutor(max_workers=num_jobs) as ex:
218
+ for wav_path in tqdm(recordings_path_set, desc="Submitting jobs"):
219
+ futures.append(ex.submit(_parse_recording, wav_path))
220
+
221
+ for future in tqdm(futures, desc="Processing recordings"):
222
+ try:
223
+ recording, recording_id = future.result()
224
+ recording_dict[recording_id] = recording
225
+ except Exception as e:
226
+ logging.warning(
227
+ f"Error processing a recording: {e}"
228
+ )
229
+
230
+ recording_set = RecordingSet.from_recordings(recording_dict.values())
231
+
232
+ logging.info("Starting to process supervisions...")
233
+ # Step 3: Process supervisions
234
+ supervisions = []
235
+ for supervision in tqdm(supervision_list, desc="Processing supervisions"):
236
+ seg = _parse_supervision(supervision, recording_dict)
237
+ if seg is not None:
238
+ supervisions.append(seg)
239
+
240
+ logging.info("Processing Cuts...")
241
+
242
+ # Step 4: Create and validate manifests
243
+ supervision_set = SupervisionSet.from_segments(supervisions)
244
+
245
+ recording_set, supervision_set = fix_manifests(recording_set, supervision_set)
246
+ validate_recordings_and_supervisions(recording_set, supervision_set)
247
+
248
+ cut_set = CutSet.from_manifests(
249
+ recordings=recording_set, supervisions=supervision_set
250
+ )
251
+ cut_set = cut_set.sort_by_recording_id()
252
+ cut_set = cut_set.resample(sampling_rate)
253
+ cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
254
+
255
+ logging.info(f"Saving file to {output_dir / file_name}")
256
+ # Step 5: Write manifests to disk
257
+ cut_set.to_file(output_dir / file_name)
258
+ logging.info("Done!")
259
+
260
+
261
+ if __name__ == "__main__":
262
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
263
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
264
+
265
+ args = get_args()
266
+
267
+ prepare_dataset(
268
+ tsv_path=args.tsv_path,
269
+ prefix=args.prefix,
270
+ subset=args.subset,
271
+ sampling_rate=args.sampling_rate,
272
+ num_jobs=args.num_jobs,
273
+ output_dir=args.output_dir,
274
+ )
zipvoice/bin/prepare_tokens.py ADDED
@@ -0,0 +1,102 @@
1
+ """
2
+ This script reads the texts in the given manifest and saves new cuts with prepared tokens.
3
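+
+ Example usage (file names are illustrative):
+
+ python3 -m zipvoice.bin.prepare_tokens \
+ --input-file data/manifests/custom_cuts_train.jsonl.gz \
+ --output-file data/manifests/custom_cuts_train_tokens.jsonl.gz \
+ --tokenizer emilia \
+ --num-jobs 20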
+ """
4
+
5
+ import argparse
6
+ import logging
7
+ from functools import partial
8
+ from pathlib import Path
9
+
10
+ from lhotse import load_manifest, split_parallelize_combine
11
+
12
+ from zipvoice.tokenizer.tokenizer import add_tokens
13
+
14
+
15
+ def get_args():
16
+ parser = argparse.ArgumentParser()
17
+
18
+ parser.add_argument(
19
+ "--input-file",
20
+ type=str,
21
+ help="Input manifest without tokens",
22
+ )
23
+
24
+ parser.add_argument(
25
+ "--output-file",
26
+ type=str,
27
+ help="Output manifest with tokens.",
28
+ )
29
+
30
+ parser.add_argument(
31
+ "--num-jobs",
32
+ type=int,
33
+ default=20,
34
+ help="Number of jobs to run in parallel.",
35
+ )
36
+
37
+ parser.add_argument(
38
+ "--tokenizer",
39
+ type=str,
40
+ default="emilia",
41
+ help="The destination directory of manifest files.",
42
+ )
43
+
44
+ parser.add_argument(
45
+ "--lang",
46
+ type=str,
47
+ default="en-us",
48
+ help="Language identifier, used when tokenizer type is espeak. see"
49
+ "https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
50
+ )
51
+
52
+ return parser.parse_args()
53
+
54
+
55
+ def prepare_tokens(
56
+ input_file: Path,
57
+ output_file: Path,
58
+ num_jobs: int,
59
+ tokenizer: str,
60
+ lang: str = "en-us",
61
+ ):
62
+ logging.info(f"Processing {input_file}")
63
+ if output_file.is_file():
64
+ logging.info(f"{output_file} exists, skipping.")
65
+ return
66
+ logging.info(f"loading manifest from {input_file}")
67
+ cut_set = load_manifest(input_file)
68
+
69
+ _add_tokens = partial(add_tokens, tokenizer=tokenizer, lang=lang)
70
+
71
+ logging.info("Adding tokens")
72
+
73
+ cut_set = split_parallelize_combine(
74
+ num_jobs=num_jobs, manifest=cut_set, fn=_add_tokens
75
+ )
76
+
77
+ logging.info(f"Saving file to {output_file}")
78
+ cut_set.to_file(output_file)
79
+
80
+
81
+ if __name__ == "__main__":
82
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
83
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
84
+
85
+ args = get_args()
86
+ input_file = Path(args.input_file)
87
+ output_file = Path(args.output_file)
88
+ num_jobs = args.num_jobs
89
+ tokenizer = args.tokenizer
90
+ lang = args.lang
91
+
92
+ output_file.parent.mkdir(parents=True, exist_ok=True)
93
+
94
+ prepare_tokens(
95
+ input_file=input_file,
96
+ output_file=output_file,
97
+ num_jobs=num_jobs,
98
+ tokenizer=tokenizer,
99
+ lang=lang,
100
+ )
101
+
102
+ logging.info("Done!")
zipvoice/bin/train_zipvoice.py ADDED
@@ -0,0 +1,1136 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2024-2025 Xiaomi Corp. (authors: Wei Kang,
3
+ # Han Zhu)
4
+ #
5
+ # See ../../../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ """
20
+ This script trains a ZipVoice model with the flow-matching loss.
21
+
22
+ Usage:
23
+
24
+ python3 -m zipvoice.bin.train_zipvoice \
25
+ --world-size 8 \
26
+ --use-fp16 1 \
27
+ --num-epochs 11 \
28
+ --max-duration 500 \
29
+ --lr-hours 30000 \
30
+ --model-config conf/zipvoice_base.json \
31
+ --tokenizer emilia \
32
+ --token-file "data/tokens_emilia.txt" \
33
+ --dataset emilia \
34
+ --manifest-dir data/fbank \
35
+ --exp-dir exp/zipvoice
36
+ """
37
+
38
+ import argparse
39
+ import copy
40
+ import json
41
+ import logging
42
+ import os
43
+ from functools import partial
44
+ from pathlib import Path
45
+ from shutil import copyfile
46
+ from typing import List, Optional, Tuple, Union
47
+
48
+ import torch
49
+ import torch.multiprocessing as mp
50
+ import torch.nn as nn
51
+ from lhotse.cut import Cut, CutSet
52
+ from lhotse.utils import fix_random_seed
53
+ from torch import Tensor
54
+ from torch.amp.grad_scaler import GradScaler
55
+ from torch.nn.parallel import DistributedDataParallel as DDP
56
+ from torch.optim import Optimizer
57
+ from torch.utils.tensorboard import SummaryWriter
58
+
59
+ import zipvoice.utils.diagnostics as diagnostics
60
+ from zipvoice.dataset.datamodule import TtsDataModule
61
+ from zipvoice.models.zipvoice import ZipVoice
62
+ from zipvoice.tokenizer.tokenizer import (
63
+ EmiliaTokenizer,
64
+ EspeakTokenizer,
65
+ LibriTTSTokenizer,
66
+ SimpleTokenizer,
67
+ SimpleTokenizer2,
68
+ )
69
+ from zipvoice.utils.checkpoint import (
70
+ load_checkpoint,
71
+ remove_checkpoints,
72
+ resume_checkpoint,
73
+ save_checkpoint,
74
+ save_checkpoint_with_global_batch_idx,
75
+ update_averaged_model,
76
+ )
77
+ from zipvoice.utils.common import (
78
+ AttributeDict,
79
+ MetricsTracker,
80
+ cleanup_dist,
81
+ create_grad_scaler,
82
+ get_adjusted_batch_count,
83
+ get_env_info,
84
+ get_parameter_groups_with_lrs,
85
+ prepare_input,
86
+ set_batch_count,
87
+ setup_dist,
88
+ setup_logger,
89
+ str2bool,
90
+ torch_autocast,
91
+ )
92
+ from zipvoice.utils.hooks import register_inf_check_hooks
93
+ from zipvoice.utils.lr_scheduler import Eden, FixedLRScheduler, LRScheduler
94
+ from zipvoice.utils.optim import ScaledAdam
95
+
96
+ LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler]
97
+
98
+
99
+ def get_parser():
100
+ parser = argparse.ArgumentParser(
101
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
102
+ )
103
+
104
+ parser.add_argument(
105
+ "--world-size",
106
+ type=int,
107
+ default=1,
108
+ help="Number of GPUs for DDP training.",
109
+ )
110
+
111
+ parser.add_argument(
112
+ "--master-port",
113
+ type=int,
114
+ default=12356,
115
+ help="Master port to use for DDP training.",
116
+ )
117
+
118
+ parser.add_argument(
119
+ "--tensorboard",
120
+ type=str2bool,
121
+ default=True,
122
+ help="Should various information be logged in tensorboard.",
123
+ )
124
+
125
+ parser.add_argument(
126
+ "--num-epochs",
127
+ type=int,
128
+ default=11,
129
+ help="Number of epochs to train.",
130
+ )
131
+
132
+ parser.add_argument(
133
+ "--num-iters",
134
+ type=int,
135
+ default=0,
136
+ help="Number of iter to train, will ignore num_epochs if > 0.",
137
+ )
138
+
139
+ parser.add_argument(
140
+ "--start-epoch",
141
+ type=int,
142
+ default=1,
143
+ help="""Resume training from this epoch. It should be positive.
144
+ If larger than 1, it will load checkpoint from
145
+ exp-dir/epoch-{start_epoch-1}.pt
146
+ """,
147
+ )
148
+
149
+ parser.add_argument(
150
+ "--checkpoint",
151
+ type=str,
152
+ default=None,
153
+ help="""Checkpoints of pre-trained models, will load it if not None
154
+ """,
155
+ )
156
+
157
+ parser.add_argument(
158
+ "--exp-dir",
159
+ type=str,
160
+ default="exp/zipvoice",
161
+ help="""The experiment dir.
162
+ It specifies the directory where all training related
163
+ files, e.g., checkpoints, log, etc, are saved
164
+ """,
165
+ )
166
+
167
+ parser.add_argument(
168
+ "--base-lr", type=float, default=0.02, help="The base learning rate."
169
+ )
170
+
171
+ parser.add_argument(
172
+ "--lr-batches",
173
+ type=float,
174
+ default=7500,
175
+ help="""Number of steps that affects how rapidly the learning rate
176
+ decreases. We suggest not changing this."""
177
+ )
178
+
179
+ parser.add_argument(
180
+ "--lr-epochs",
181
+ type=float,
182
+ default=10,
183
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
184
+ """,
185
+ )
186
+
187
+ parser.add_argument(
188
+ "--lr-hours",
189
+ type=float,
190
+ default=0,
191
+ help="""If positive, --epoch is ignored and it specifies the number of hours
192
+ that affects how rapidly the learning rate decreases.
193
+ """,
194
+ )
195
+
196
+ parser.add_argument(
197
+ "--ref-duration",
198
+ type=float,
199
+ default=50,
200
+ help="""Reference batch duration for purposes of adjusting batch counts for"
201
+ setting various schedules inside the model".
202
+ """,
203
+ )
204
+
205
+ parser.add_argument(
206
+ "--finetune",
207
+ type=str2bool,
208
+ default=False,
209
+ help="Whether to use the fine-tuning mode, will used a fixed learning rate "
210
+ "schedule and skip the large dropout phase.",
211
+ )
212
+
213
+ parser.add_argument(
214
+ "--seed",
215
+ type=int,
216
+ default=42,
217
+ help="The seed for random generators intended for reproducibility",
218
+ )
219
+
220
+ parser.add_argument(
221
+ "--print-diagnostics",
222
+ type=str2bool,
223
+ default=False,
224
+ help="Accumulate stats on activations, print them and exit.",
225
+ )
226
+
227
+ parser.add_argument(
228
+ "--scan-oom",
229
+ type=str2bool,
230
+ default=False,
231
+ help="Scan pessimistic batches to see whether they cause OOMs.",
232
+ )
233
+
234
+ parser.add_argument(
235
+ "--inf-check",
236
+ type=str2bool,
237
+ default=False,
238
+ help="Add hooks to check for infinite module outputs and gradients.",
239
+ )
240
+
241
+ parser.add_argument(
242
+ "--save-every-n",
243
+ type=int,
244
+ default=5000,
245
+ help="""Save checkpoint after processing this number of batches"
246
+ periodically. We save checkpoint to exp-dir/ whenever
247
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
248
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
249
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
250
+ end of each epoch where `xxx` is the epoch number counting from 1.
251
+ """,
252
+ )
253
+
254
+ parser.add_argument(
255
+ "--valid-by-epoch",
256
+ type=str2bool,
257
+ default=False,
258
+ help="""Whether to validate after each epoch. If False, will validate
259
+ after every save_every_n iterations.
260
+ """,
261
+ )
262
+
263
+ parser.add_argument(
264
+ "--keep-last-k",
265
+ type=int,
266
+ default=30,
267
+ help="""Only keep this number of checkpoints on disk.
268
+ For instance, if it is 3, there are only 3 checkpoints
269
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
270
+ It does not affect checkpoints with name `epoch-xxx.pt`.
271
+ """,
272
+ )
273
+
274
+ parser.add_argument(
275
+ "--average-period",
276
+ type=int,
277
+ default=200,
278
+ help="""Update the averaged model, namely `model_avg`, after processing
279
+ this number of batches. `model_avg` is a separate version of model,
280
+ in which each floating-point parameter is the average of all the
281
+ parameters from the start of training. Each time we take the average,
282
+ we do: `model_avg = model * (average_period / batch_idx_train) +
283
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
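+ For example, at batch_idx_train=1000 with average_period=200, the update is
+ model * 0.2 + model_avg * 0.8.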
284
+ """,
285
+ )
286
+
287
+ parser.add_argument(
288
+ "--use-fp16",
289
+ type=str2bool,
290
+ default=True,
291
+ help="Whether to use half precision training.",
292
+ )
293
+
294
+ parser.add_argument(
295
+ "--feat-scale",
296
+ type=float,
297
+ default=0.1,
298
+ help="The scale factor of fbank feature",
299
+ )
300
+
301
+ parser.add_argument(
302
+ "--condition-drop-ratio",
303
+ type=float,
304
+ default=0.2,
305
+ help="The drop rate of text condition during training.",
306
+ )
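+ # Dropping the text condition during training presumably enables
+ # classifier-free guidance at inference.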
307
+
308
+ parser.add_argument(
309
+ "--dataset",
310
+ type=str,
311
+ default="emilia",
312
+ choices=["emilia", "libritts", "custom"],
313
+ help="The used training dataset",
314
+ )
315
+
316
+ parser.add_argument(
317
+ "--train-manifest",
318
+ type=str,
319
+ help="Path of the training manifest",
320
+ )
321
+
322
+ parser.add_argument(
323
+ "--dev-manifest",
324
+ type=str,
325
+ help="Path of the validation manifest",
326
+ )
327
+
328
+ parser.add_argument(
329
+ "--min-len",
330
+ type=float,
331
+ default=1.0,
332
+ help="The minimum audio length used for training",
333
+ )
334
+
335
+ parser.add_argument(
336
+ "--max-len",
337
+ type=float,
338
+ default=30.0,
339
+ help="The maximum audio length used for training",
340
+ )
341
+
342
+ parser.add_argument(
343
+ "--model-config",
344
+ type=str,
345
+ default="conf/zipvoice_base.json",
346
+ help="The model configuration file.",
347
+ )
348
+
349
+ parser.add_argument(
350
+ "--tokenizer",
351
+ type=str,
352
+ default="emilia",
353
+ help="Tokenizer type.",
354
+ )
355
+
356
+ parser.add_argument(
357
+ "--lang",
358
+ type=str,
359
+ default="en-us",
360
+ help="Language identifier, used when tokenizer type is espeak. see"
361
+ "https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
362
+ )
363
+
364
+ parser.add_argument(
365
+ "--token-file",
366
+ type=str,
367
+ default="data/tokens_emilia.txt",
368
+ help="The file that contains information that maps tokens to ids,"
369
+ "which is a text file with '{token}\t{token_id}' per line.",
370
+ )
371
+
372
+ return parser
373
+
374
+
375
+ def get_params() -> AttributeDict:
376
+ """Return a dict containing training parameters.
377
+
378
+ All training related parameters that are not passed from the commandline
379
+ are saved in the variable `params`.
380
+
381
+ Commandline options are merged into `params` after they are parsed, so
382
+ you can also access them via `params`.
383
+
384
+ Explanation of options saved in `params`:
385
+
386
+ - best_train_loss: Best training loss so far. It is used to select
387
+ the model that has the lowest training loss. It is
388
+ updated during the training.
389
+
390
+ - best_valid_loss: Best validation loss so far. It is used to select
391
+ the model that has the lowest validation loss. It is
392
+ updated during the training.
393
+
394
+ - best_train_epoch: It is the epoch that has the best training loss.
395
+
396
+ - best_valid_epoch: It is the epoch that has the best validation loss.
397
+
398
+ - batch_idx_train: Used to write statistics to tensorboard. It
399
+ contains the number of batches trained so far across
400
+ epochs.
401
+
402
+ - log_interval: Print training loss if batch_idx % log_interval is 0
403
+
404
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
405
+
406
+ - env_info: A dict containing information about the environment.
407
+
408
+ """
409
+ params = AttributeDict(
410
+ {
411
+ "best_train_loss": float("inf"),
412
+ "best_valid_loss": float("inf"),
413
+ "best_train_epoch": -1,
414
+ "best_valid_epoch": -1,
415
+ "batch_idx_train": 0,
416
+ "log_interval": 50,
417
+ "reset_interval": 200,
418
+ "env_info": get_env_info(),
419
+ }
420
+ )
421
+
422
+ return params
423
+
424
+
425
+ def compute_fbank_loss(
426
+ params: AttributeDict,
427
+ model: Union[nn.Module, DDP],
428
+ features: Tensor,
429
+ features_lens: Tensor,
430
+ tokens: List[List[int]],
431
+ is_training: bool,
432
+ ) -> Tuple[Tensor, MetricsTracker]:
433
+ """
434
+ Compute loss given the model and its inputs.
435
+
436
+ Args:
437
+ params:
438
+ Parameters for training. See :func:`get_params`.
439
+ model:
440
+ The model for training.
441
+ features:
442
+ The target acoustic feature.
443
+ features_lens:
444
+ The number of frames of each utterance.
445
+ tokens:
446
+ Input tokens representing the transcripts.
447
+ is_training:
448
+ True for training. False for validation. When it is True, this
449
+ function enables autograd during computation; when it is False, it
450
+ disables autograd.
451
+ """
452
+
453
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
454
+
455
+ batch_size, num_frames, _ = features.shape
456
+
457
+ features = torch.nn.functional.pad(
458
+ features, (0, 0, 0, num_frames - features.size(1))
459
+ ) # (B, T, F)
460
+ noise = torch.randn_like(features) # (B, T, F)
461
+
462
+ # Sampling t from uniform distribution
463
+ if is_training:
464
+ t = torch.rand(batch_size, 1, 1, device=device)
465
+ else:
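+ # At validation, use fixed evenly spaced t values for a deterministic loss.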
466
+ t = (
467
+ (torch.arange(batch_size, device=device) / batch_size)
468
+ .unsqueeze(1)
469
+ .unsqueeze(2)
470
+ )
471
+ with torch.set_grad_enabled(is_training):
472
+
473
+ loss = model(
474
+ tokens=tokens,
475
+ features=features,
476
+ features_lens=features_lens,
477
+ noise=noise,
478
+ t=t,
479
+ condition_drop_ratio=params.condition_drop_ratio,
480
+ )
481
+
482
+ assert loss.requires_grad == is_training
483
+ info = MetricsTracker()
484
+ num_frames = features_lens.sum().item()
485
+ info["frames"] = num_frames
486
+ info["loss"] = loss.detach().cpu().item() * num_frames
487
+
488
+ return loss, info
489
+
490
+
491
+ def train_one_epoch(
492
+ params: AttributeDict,
493
+ model: Union[nn.Module, DDP],
494
+ optimizer: Optimizer,
495
+ scheduler: LRSchedulerType,
496
+ train_dl: torch.utils.data.DataLoader,
497
+ valid_dl: torch.utils.data.DataLoader,
498
+ scaler: GradScaler,
499
+ model_avg: Optional[nn.Module] = None,
500
+ tb_writer: Optional[SummaryWriter] = None,
501
+ world_size: int = 1,
502
+ rank: int = 0,
503
+ ) -> None:
504
+ """Train the model for one epoch.
505
+
506
+ The training loss from the mean of all frames is saved in
507
+ `params.train_loss`. It runs the validation process every
508
+ `params.valid_interval` batches or once per epoch.
509
+
510
+ Args:
511
+ params:
512
+ It is returned by :func:`get_params`.
513
+ model:
514
+ The model for training.
515
+ optimizer:
516
+ The optimizer.
517
+ scheduler:
518
+ The learning rate scheduler, we call step() every epoch.
519
+ train_dl:
520
+ Dataloader for the training dataset.
521
+ valid_dl:
522
+ Dataloader for the validation dataset.
523
+ scaler:
524
+ The scaler used for mixed precision training.
525
+ tb_writer:
526
+ Writer to write log messages to tensorboard.
527
+ world_size:
528
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
529
+ rank:
530
+ The rank of the node in DDP training. If no DDP is used, it should
531
+ be set to 0.
532
+ """
533
+ model.train()
534
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
535
+
536
+ # used to track the stats over iterations in one epoch
537
+ tot_loss = MetricsTracker()
538
+
539
+ saved_bad_model = False
540
+
541
+ def save_bad_model(suffix: str = ""):
542
+ save_checkpoint(
543
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
544
+ model=model,
545
+ model_avg=model_avg,
546
+ params=params,
547
+ optimizer=optimizer,
548
+ scheduler=scheduler,
549
+ sampler=train_dl.sampler,
550
+ scaler=scaler,
551
+ rank=0,
552
+ )
553
+
554
+ for batch_idx, batch in enumerate(train_dl):
555
+
556
+ if batch_idx % 10 == 0:
557
+ if params.finetune:
558
+ set_batch_count(model, get_adjusted_batch_count(params) + 100000)
559
+ else:
560
+ set_batch_count(model, get_adjusted_batch_count(params))
561
+
562
+ if (
563
+ params.valid_by_epoch and batch_idx == 0 and not params.print_diagnostics
564
+ ) or (
565
+ not params.valid_by_epoch
566
+ and params.batch_idx_train % params.valid_interval == 0
567
+ and not params.print_diagnostics
568
+ ):
569
+ logging.info("Computing validation loss")
570
+ valid_info = compute_validation_loss(
571
+ params=params,
572
+ model=model,
573
+ valid_dl=valid_dl,
574
+ world_size=world_size,
575
+ )
576
+ model.train()
577
+ logging.info(
578
+ f"Epoch {params.cur_epoch}, global_batch_idx: {params.batch_idx_train},"
579
+ f" validation: {valid_info}"
580
+ )
581
+ logging.info(
582
+ f"Maximum memory allocated so far is "
583
+ f"{torch.cuda.max_memory_allocated() // 1000000}MB"
584
+ )
585
+ if tb_writer is not None:
586
+ valid_info.write_summary(
587
+ tb_writer, "train/valid_", params.batch_idx_train
588
+ )
589
+
590
+ params.batch_idx_train += 1
591
+
592
+ batch_size = len(batch["text"])
593
+
594
+ tokens, features, features_lens = prepare_input(
595
+ params=params,
596
+ batch=batch,
597
+ device=device,
598
+ return_tokens=True,
599
+ return_feature=True,
600
+ )
601
+
602
+ try:
603
+ with torch_autocast(dtype=torch.float16, enabled=params.use_fp16):
604
+ loss, loss_info = compute_fbank_loss(
605
+ params=params,
606
+ model=model,
607
+ features=features,
608
+ features_lens=features_lens,
609
+ tokens=tokens,
610
+ is_training=True,
611
+ )
612
+
613
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
614
+
615
+ scaler.scale(loss).backward()
616
+
617
+ scheduler.step_batch(params.batch_idx_train)
618
+ # Use the number of hours of speech to adjust the learning rate
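+ # e.g., 10000 batches * 500 s of audio per batch * 8 GPUs / 3600 = ~11111 hours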
619
+ if params.lr_hours > 0:
620
+ scheduler.step_epoch(
621
+ params.batch_idx_train
622
+ * params.max_duration
623
+ * params.world_size
624
+ / 3600
625
+ )
626
+ scaler.step(optimizer)
627
+ scaler.update()
628
+ optimizer.zero_grad()
629
+ except Exception as e:
630
+ logging.info(f"Caught exception : {e}.")
631
+ save_bad_model()
632
+ raise
633
+
634
+ if params.print_diagnostics and batch_idx == 5:
635
+ return
636
+
637
+ if (
638
+ rank == 0
639
+ and params.batch_idx_train > 0
640
+ and params.batch_idx_train % params.average_period == 0
641
+ ):
642
+ update_averaged_model(
643
+ params=params,
644
+ model_cur=model,
645
+ model_avg=model_avg,
646
+ )
647
+
648
+ if (
649
+ params.batch_idx_train > 0
650
+ and params.batch_idx_train % params.save_every_n == 0
651
+ ):
652
+ save_checkpoint_with_global_batch_idx(
653
+ out_dir=params.exp_dir,
654
+ global_batch_idx=params.batch_idx_train,
655
+ model=model,
656
+ model_avg=model_avg,
657
+ params=params,
658
+ optimizer=optimizer,
659
+ scheduler=scheduler,
660
+ sampler=train_dl.sampler,
661
+ scaler=scaler,
662
+ rank=rank,
663
+ )
664
+ remove_checkpoints(
665
+ out_dir=params.exp_dir,
666
+ topk=params.keep_last_k,
667
+ rank=rank,
668
+ )
669
+ if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
670
+ break
671
+ if params.batch_idx_train % 100 == 0 and params.use_fp16:
672
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
673
+ # of the grad scaler is configurable, but we can't configure it to have
674
+ # different behavior depending on the current grad scale.
675
+ cur_grad_scale = scaler._scale.item()
676
+
677
+ if cur_grad_scale < 1024.0 or (
678
+ cur_grad_scale < 4096.0 and params.batch_idx_train % 400 == 0
679
+ ):
680
+ scaler.update(cur_grad_scale * 2.0)
681
+ if cur_grad_scale < 0.01:
682
+ if not saved_bad_model:
683
+ save_bad_model(suffix="-first-warning")
684
+ saved_bad_model = True
685
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
686
+ if cur_grad_scale < 1.0e-05:
687
+ save_bad_model()
688
+ raise RuntimeError(
689
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
690
+ )
691
+
692
+ if params.batch_idx_train % params.log_interval == 0:
693
+ cur_lr = max(scheduler.get_last_lr())
694
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
695
+
696
+ logging.info(
697
+ f"Epoch {params.cur_epoch}, batch {batch_idx}, "
698
+ f"global_batch_idx: {params.batch_idx_train}, "
699
+ f"batch size: {batch_size}, "
700
+ f"loss[{loss_info}], tot_loss[{tot_loss}], "
701
+ f"cur_lr: {cur_lr:.2e}, "
702
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
703
+ )
704
+
705
+ if tb_writer is not None:
706
+ tb_writer.add_scalar(
707
+ "train/learning_rate", cur_lr, params.batch_idx_train
708
+ )
709
+ loss_info.write_summary(
710
+ tb_writer, "train/current_", params.batch_idx_train
711
+ )
712
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
713
+ if params.use_fp16:
714
+ tb_writer.add_scalar(
715
+ "train/grad_scale",
716
+ cur_grad_scale,
717
+ params.batch_idx_train,
718
+ )
719
+
720
+ loss_value = tot_loss["loss"]
721
+ params.train_loss = loss_value
722
+ if params.train_loss < params.best_train_loss:
723
+ params.best_train_epoch = params.cur_epoch
724
+ params.best_train_loss = params.train_loss
725
+
726
+
727
+ def compute_validation_loss(
728
+ params: AttributeDict,
729
+ model: Union[nn.Module, DDP],
730
+ valid_dl: torch.utils.data.DataLoader,
731
+ world_size: int = 1,
732
+ ) -> MetricsTracker:
733
+ """Run the validation process."""
734
+
735
+ model.eval()
736
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
737
+
738
+ # used to summary the stats over iterations
739
+ tot_loss = MetricsTracker()
740
+
741
+ for batch_idx, batch in enumerate(valid_dl):
742
+ tokens, features, features_lens = prepare_input(
743
+ params=params,
744
+ batch=batch,
745
+ device=device,
746
+ return_tokens=True,
747
+ return_feature=True,
748
+ )
749
+
750
+ loss, loss_info = compute_fbank_loss(
751
+ params=params,
752
+ model=model,
753
+ features=features,
754
+ features_lens=features_lens,
755
+ tokens=tokens,
756
+ is_training=False,
757
+ )
758
+ assert loss.requires_grad is False
759
+ tot_loss = tot_loss + loss_info
760
+
761
+ if world_size > 1:
762
+ tot_loss.reduce(loss.device)
763
+
764
+ loss_value = tot_loss["loss"]
765
+ if loss_value < params.best_valid_loss:
766
+ params.best_valid_epoch = params.cur_epoch
767
+ params.best_valid_loss = loss_value
768
+
769
+ return tot_loss
770
+
771
+
772
+ def display_and_save_batch(
773
+ batch: dict,
774
+ params: AttributeDict,
775
+ ) -> None:
776
+ """Display the batch statistics and save the batch into disk.
777
+
778
+ Args:
779
+ batch:
780
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
781
+ for the content in it.
782
+ params:
783
+ Parameters for training. See :func:`get_params`.
784
786
+ """
787
+ from lhotse.utils import uuid4
788
+
789
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
790
+ logging.info(f"Saving batch to {filename}")
791
+ torch.save(batch, filename)
792
+
793
+ features = batch["features"]
794
+ tokens = batch["tokens"]
795
+
796
+ logging.info(f"features shape: {features.shape}")
797
+ num_tokens = sum(len(i) for i in tokens)
798
+ logging.info(f"num tokens: {num_tokens}")
799
+
800
+
801
+ def scan_pessimistic_batches_for_oom(
802
+ model: Union[nn.Module, DDP],
803
+ train_dl: torch.utils.data.DataLoader,
804
+ optimizer: torch.optim.Optimizer,
805
+ params: AttributeDict,
806
+ ):
807
+ from lhotse.dataset import find_pessimistic_batches
808
+
809
+ logging.info(
810
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
811
+ )
812
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
813
+
814
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
815
+ for criterion, cuts in batches.items():
816
+ batch = train_dl.dataset[cuts]
817
+ tokens, features, features_lens = prepare_input(
818
+ params=params,
819
+ batch=batch,
820
+ device=device,
821
+ return_tokens=True,
822
+ return_feature=True,
823
+ )
824
+ try:
825
+ with torch_autocast(dtype=torch.float16, enabled=params.use_fp16):
826
+
827
+ loss, loss_info = compute_fbank_loss(
828
+ params=params,
829
+ model=model,
830
+ features=features,
831
+ features_lens=features_lens,
832
+ tokens=tokens,
833
+ is_training=True,
834
+ )
835
+ loss.backward()
836
+ optimizer.zero_grad()
837
+ except Exception as e:
838
+ if "CUDA out of memory" in str(e):
839
+ logging.error(
840
+ "Your GPU ran out of memory with the current "
841
+ "max_duration setting. We recommend decreasing "
842
+ "max_duration and trying again.\n"
843
+ f"Failing criterion: {criterion} "
844
+ f"(={crit_values[criterion]}) ..."
845
+ )
846
+ display_and_save_batch(batch, params=params)
847
+ raise
848
+ logging.info(
849
+ f"Maximum memory allocated so far is "
850
+ f"{torch.cuda.max_memory_allocated() // 1000000}MB"
851
+ )
852
+
853
+
854
+ def tokenize_text(c: Cut, tokenizer):
855
+ if hasattr(c.supervisions[0], "tokens"):
856
+ tokens = tokenizer.tokens_to_token_ids([c.supervisions[0].tokens])
857
+ else:
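+ # fall back to tokenizing the raw text when precomputed tokens are absent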
858
+ tokens = tokenizer.texts_to_token_ids([c.supervisions[0].text])
859
+ print("ko tΓ¬m được tokens")
860
+ c.supervisions[0].tokens = tokens[0]
861
+ return c
862
+
863
+
864
+ def run(rank, world_size, args):
865
+ """
866
+ Args:
867
+ rank:
868
+ It is a value between 0 and `world_size-1`, which is
869
+ passed automatically by `mp.spawn()` in :func:`main`.
870
+ The node with rank 0 is responsible for saving checkpoint.
871
+ world_size:
872
+ Number of GPUs for DDP training.
873
+ args:
874
+ The return value of get_parser().parse_args()
875
+ """
876
+ params = get_params()
877
+ params.update(vars(args))
878
+ params.valid_interval = params.save_every_n
879
+ # Set epoch to a large number to ignore it.
880
+ if params.num_iters > 0:
881
+ params.num_epochs = 1000000
882
+ with open(params.model_config, "r") as f:
883
+ model_config = json.load(f)
884
+ params.update(model_config["model"])
885
+ params.update(model_config["feature"])
886
+
887
+ fix_random_seed(params.seed)
888
+ if world_size > 1:
889
+ setup_dist(rank, world_size, params.master_port)
890
+
891
+ os.makedirs(f"{params.exp_dir}", exist_ok=True)
892
+ copyfile(src=params.model_config, dst=f"{params.exp_dir}/model.json")
893
+ copyfile(src=params.token_file, dst=f"{params.exp_dir}/tokens.txt")
894
+ setup_logger(f"{params.exp_dir}/log/log-train")
895
+
896
+ if args.tensorboard and rank == 0:
897
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
898
+ else:
899
+ tb_writer = None
900
+
901
+ if torch.cuda.is_available():
902
+ params.device = torch.device("cuda", rank)
903
+ else:
904
+ params.device = torch.device("cpu")
905
+ logging.info(f"Device: {params.device}")
906
+
907
+ if params.tokenizer == "emilia":
908
+ tokenizer = EmiliaTokenizer(token_file=params.token_file)
909
+ elif params.tokenizer == "libritts":
910
+ tokenizer = LibriTTSTokenizer(token_file=params.token_file)
911
+ elif params.tokenizer == "espeak":
912
+ tokenizer = EspeakTokenizer(token_file=params.token_file, lang=params.lang)
913
+ elif params.tokenizer == "simple2":
914
+ tokenizer = SimpleTokenizer2(token_file=params.token_file)
915
+ else:
916
+ assert params.tokenizer == "simple"
917
+ tokenizer = SimpleTokenizer(token_file=params.token_file)
918
+
919
+ tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
920
+ params.update(tokenizer_config)
921
+
922
+ logging.info(params)
923
+
924
+ logging.info("About to create model")
925
+
926
+ model = ZipVoice(
927
+ **model_config["model"],
928
+ **tokenizer_config,
929
+ )
930
+
931
+ if params.checkpoint is not None:
932
+ logging.info(f"Loading pre-trained model from {params.checkpoint}")
933
+ _ = load_checkpoint(filename=params.checkpoint, model=model, strict=True)
934
+ num_param = sum([p.numel() for p in model.parameters()])
935
+ logging.info(f"Number of parameters : {num_param}")
936
+
937
+ model_avg: Optional[nn.Module] = None
938
+ if rank == 0:
939
+ # model_avg is only used with rank 0
940
+ model_avg = copy.deepcopy(model).to(torch.float64)
941
+
942
+ assert params.start_epoch > 0, params.start_epoch
943
+ if params.start_epoch > 1:
944
+ checkpoints = resume_checkpoint(params=params, model=model, model_avg=model_avg)
945
+
946
+ model = model.to(params.device)
947
+ if world_size > 1:
948
+ logging.info("Using DDP")
949
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
950
+
951
+ optimizer = ScaledAdam(
952
+ get_parameter_groups_with_lrs(
953
+ model,
954
+ lr=params.base_lr,
955
+ include_names=True,
956
+ ),
957
+ lr=params.base_lr, # should have no effect
958
+ clipping_scale=2.0,
959
+ )
960
+
961
+ assert params.lr_hours >= 0
962
+
963
+ if params.finetune:
964
+ scheduler = FixedLRScheduler(optimizer)
965
+ elif params.lr_hours > 0:
966
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_hours)
967
+ else:
968
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
969
+
970
+ scaler = create_grad_scaler(enabled=params.use_fp16)
971
+
972
+ if params.start_epoch > 1 and checkpoints is not None:
973
+ # load state_dict for optimizers
974
+ if "optimizer" in checkpoints:
975
+ logging.info("Loading optimizer state dict")
976
+ optimizer.load_state_dict(checkpoints["optimizer"])
977
+
978
+ # load state_dict for schedulers
979
+ if "scheduler" in checkpoints:
980
+ logging.info("Loading scheduler state dict")
981
+ scheduler.load_state_dict(checkpoints["scheduler"])
982
+
983
+ if "grad_scaler" in checkpoints:
984
+ logging.info("Loading grad scaler state dict")
985
+ scaler.load_state_dict(checkpoints["grad_scaler"])
986
+
987
+ if params.print_diagnostics:
988
+ opts = diagnostics.TensorDiagnosticOptions(
989
+ 512
990
+ ) # allow 4 megabytes per sub-module
991
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
992
+
993
+ if params.inf_check:
994
+ register_inf_check_hooks(model)
995
+
996
+ def remove_short_and_long_utt(c: Cut, min_len: float, max_len: float):
997
+ if c.duration < min_len or c.duration > max_len:
998
+ return False
999
+ return True
1000
+
1001
+ _remove_short_and_long_utt = partial(
1002
+ remove_short_and_long_utt, min_len=params.min_len, max_len=params.max_len
1003
+ )
1004
+
1005
+ datamodule = TtsDataModule(args)
1006
+ if params.dataset == "emilia":
1007
+ train_cuts = CutSet.mux(
1008
+ datamodule.train_emilia_EN_cuts(),
1009
+ datamodule.train_emilia_ZH_cuts(),
1010
+ weights=[46000, 49000],
1011
+ )
1012
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
1013
+ dev_cuts = CutSet.mux(
1014
+ datamodule.dev_emilia_EN_cuts(),
1015
+ datamodule.dev_emilia_ZH_cuts(),
1016
+ weights=[0.5, 0.5],
1017
+ )
1018
+ elif params.dataset == "libritts":
1019
+ train_cuts = datamodule.train_libritts_cuts()
1020
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
1021
+ dev_cuts = datamodule.dev_libritts_cuts()
1022
+ else:
1023
+ assert params.dataset == "custom"
1024
+ train_cuts = datamodule.train_custom_cuts(params.train_manifest)
1025
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
1026
+ dev_cuts = datamodule.dev_custom_cuts(params.dev_manifest)
1027
+ # To avoid OOM issues due to too long dev cuts
1028
+ dev_cuts = dev_cuts.filter(_remove_short_and_long_utt)
1029
+
1030
+ if params.tokenizer in ["emilia", "espeak", "dialog"]:
1031
+ if not hasattr(train_cuts[0].supervisions[0], "tokens") or not hasattr(
1032
+ dev_cuts[0].supervisions[0], "tokens"
1033
+ ):
1034
+ logging.warning(
1035
+ f"Using {params.tokenizer} tokenizer but tokens are not prepared,"
1036
+ f"will tokenize on-the-fly, which can slow down training significantly."
1037
+ )
1038
+ _tokenize_text = partial(tokenize_text, tokenizer=tokenizer)
1039
+ train_cuts = train_cuts.map(_tokenize_text)
1040
+ dev_cuts = dev_cuts.map(_tokenize_text)
1041
+
1042
+ train_dl = datamodule.train_dataloaders(train_cuts)
1043
+
1044
+ valid_dl = datamodule.dev_dataloaders(dev_cuts)
1045
+
1046
+ if params.scan_oom:
1047
+ scan_pessimistic_batches_for_oom(
1048
+ model=model,
1049
+ train_dl=train_dl,
1050
+ optimizer=optimizer,
1051
+ params=params,
1052
+ )
1053
+
1054
+ logging.info("Training started")
1055
+
1056
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
1057
+ logging.info(f"Start epoch {epoch}")
1058
+
1059
+ if params.lr_hours == 0:
1060
+ scheduler.step_epoch(epoch - 1)
1061
+ fix_random_seed(params.seed + epoch - 1)
1062
+ train_dl.sampler.set_epoch(epoch - 1)
1063
+
1064
+ params.cur_epoch = epoch
1065
+
1066
+ if tb_writer is not None:
1067
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
1068
+
1069
+ train_one_epoch(
1070
+ params=params,
1071
+ model=model,
1072
+ model_avg=model_avg,
1073
+ optimizer=optimizer,
1074
+ scheduler=scheduler,
1075
+ train_dl=train_dl,
1076
+ valid_dl=valid_dl,
1077
+ scaler=scaler,
1078
+ tb_writer=tb_writer,
1079
+ world_size=world_size,
1080
+ rank=rank,
1081
+ )
1082
+
1083
+ if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
1084
+ break
1085
+
1086
+ if params.print_diagnostics:
1087
+ diagnostic.print_diagnostics()
1088
+ break
1089
+
1090
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
1091
+ save_checkpoint(
1092
+ filename=filename,
1093
+ params=params,
1094
+ model=model,
1095
+ model_avg=model_avg,
1096
+ optimizer=optimizer,
1097
+ scheduler=scheduler,
1098
+ sampler=train_dl.sampler,
1099
+ scaler=scaler,
1100
+ rank=rank,
1101
+ )
1102
+
1103
+ if rank == 0:
1104
+ if params.best_train_epoch == params.cur_epoch:
1105
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
1106
+ copyfile(src=filename, dst=best_train_filename)
1107
+
1108
+ if params.best_valid_epoch == params.cur_epoch:
1109
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
1110
+ copyfile(src=filename, dst=best_valid_filename)
1111
+
1112
+ logging.info("Done!")
1113
+
1114
+ if world_size > 1:
1115
+ torch.distributed.barrier()
1116
+ cleanup_dist()
1117
+
1118
+
1119
+ def main():
1120
+ parser = get_parser()
1121
+ TtsDataModule.add_arguments(parser)
1122
+ args = parser.parse_args()
1123
+ args.exp_dir = Path(args.exp_dir)
1124
+
1125
+ world_size = args.world_size
1126
+ assert world_size >= 1
1127
+ if world_size > 1:
1128
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
1129
+ else:
1130
+ run(rank=0, world_size=1, args=args)
1131
+
1132
+
1133
+ if __name__ == "__main__":
1134
+ torch.set_num_threads(1)
1135
+ torch.set_num_interop_threads(1)
1136
+ main()