Commit af11ce4 · "up"
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +36 -0
- LICENSE +201 -0
- README.md +9 -0
- T_English.wav +3 -0
- T_English_output.wav +3 -0
- Toại.wav +3 -0
- Toại_output.wav +3 -0
- Trung.wav +3 -0
- Trung_output.wav +3 -0
- app.py +346 -0
- assets/silence.wav +3 -0
- download_models.py +38 -0
- egs/zipvoice/README.md +15 -0
- egs/zipvoice/conf/zipvoice_base.json +26 -0
- egs/zipvoice/local/pinyin.txt +1550 -0
- egs/zipvoice/local/prepare_emilia.sh +149 -0
- egs/zipvoice/local/prepare_libritts.sh +100 -0
- egs/zipvoice/local/prepare_token_file_char.py +67 -0
- egs/zipvoice/local/prepare_token_file_emilia.py +91 -0
- egs/zipvoice/local/prepare_tokens_emilia.py +88 -0
- egs/zipvoice/local/preprocess_emilia.py +210 -0
- egs/zipvoice/run_custom.sh +138 -0
- egs/zipvoice/run_emilia.sh +178 -0
- egs/zipvoice/run_eval.sh +142 -0
- egs/zipvoice/run_finetune.sh +175 -0
- egs/zipvoice/run_libritts.sh +148 -0
- egs/zipvoice/utils/parse_options.sh +97 -0
- egs/zipvoice/utils/validate_manifest.py +70 -0
- egs/zipvoice_dialog/README.md +12 -0
- egs/zipvoice_dialog/local/prepare_opendialog.py +262 -0
- egs/zipvoice_dialog/run_custom.sh +145 -0
- egs/zipvoice_dialog/run_eval.sh +120 -0
- egs/zipvoice_dialog/run_finetune.sh +135 -0
- egs/zipvoice_dialog/run_opendialog.sh +122 -0
- infer.py +578 -0
- proccess_wav.py +364 -0
- pyproject.toml +5 -0
- requirements.txt +23 -0
- requirements_eval.txt +19 -0
- setup.py +55 -0
- zipvoice/__init__.py +7 -0
- zipvoice/bin/compute_fbank.py +272 -0
- zipvoice/bin/generate_averaged_model.py +229 -0
- zipvoice/bin/infer_zipvoice.py +614 -0
- zipvoice/bin/infer_zipvoice_dialog.py +756 -0
- zipvoice/bin/infer_zipvoice_onnx.py +712 -0
- zipvoice/bin/onnx_export.py +410 -0
- zipvoice/bin/prepare_dataset.py +274 -0
- zipvoice/bin/prepare_tokens.py +102 -0
- zipvoice/bin/train_zipvoice.py +1136 -0
.gitattributes
ADDED
@@ -0,0 +1,36 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
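Entries in this format are what `git lfs track <pattern>` appends to .gitattributes; as a minimal sketch, the final `*.wav` rule (the one that keeps the sample audio below out of regular Git storage) could have been produced with standard git-lfs commands, nothing repo-specific:

```bash
# Appends "*.wav filter=lfs diff=lfs merge=lfs -text" to .gitattributes
git lfs install
git lfs track "*.wav"
git add .gitattributes
```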
LICENSE
ADDED
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
README.md
ADDED
@@ -0,0 +1,9 @@
+---
+license: cc-by-sa-4.0
+title: Text_to_speech_Vietnamese
+sdk: gradio
+emoji: π
+colorFrom: red
+colorTo: yellow
+pinned: false
+---
T_English.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d6ffd499fdf637243bdd630cb52660635d5c7cb580b87f52aa7efca90a33311f
+size 328364

T_English_output.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:756daab105a8f4508e6d4c237f1e72a25ec326ad1f6665dce974f96e9b86db7a
+size 954742

Toại.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:798c60882fa0e9a758fd72cf506e6b71293c5ccb5ed27b92569d042a23624bdc
+size 200782

Toại_output.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3916769643ff756700e08ae355a0f7e30fd7a0e5299b06a866facad5ff31afd1
+size 2154910

Trung.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0bf7087aa7452978b6ec6c25eb8c078eb4ca337660c9d3e3661b8017da9238e9
+size 199376

Trung_output.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:068c6cfd893846473d177b9636b4689119a0c0ee7e19f4079cd6c98e27bb94a3
+size 745196
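The six audio files above are stored as Git LFS pointer stubs, so a plain clone contains only these `version`/`oid`/`size` lines; fetching the real audio is standard git-lfs usage (the Space URL below is a hypothetical placeholder, since the owner is not shown in this view):

```bash
git clone https://huggingface.co/spaces/<owner>/<space-name>   # hypothetical URL
cd <space-name>
git lfs pull   # replaces the pointer files with the actual .wav payloads
```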
app.py
ADDED
@@ -0,0 +1,346 @@
+import spaces
+import os
+from download_models import download_all_models
+from huggingface_hub import login
+
+import gradio as gr
+
+# ======================= HF LOGIN & DOWNLOAD MODEL =======================
+hf_token = os.getenv("HF_TOKEN")
+if hf_token:
+    login(token=hf_token)
+
+# Download the models while the Space is starting up
+download_all_models()
+
+from infer import run_zipvoice
+
+# NEW: ASR + DENOISE
+from chunkformer import ChunkFormerModel
+from clearvoice import ClearVoice
+from proccess_wav import enhance_ref_audio, transcribe_ref_audio
+
+# (These two test lines can be removed if they are not needed)
+enhanced = enhance_ref_audio("Toại.wav")
+text = transcribe_ref_audio(enhanced)
+
+
+def infer_ref_text_ui(ref_audio_path: str) -> str:
+    """
+    Used by the 'Infer Text' button:
+    - Enhance the WAV (ClearVoice + silence handling + trim to 5-10 s)
+    - Run silence-based ASR
+    - Put the result into the Reference Text box
+    """
+    if not ref_audio_path:
+        raise gr.Error("Vui lòng upload file giọng mẫu trước khi infer text.")
+
+    try:
+        enhanced = enhance_ref_audio(ref_audio_path)
+        text = transcribe_ref_audio(enhanced)
+    except Exception as e:
+        raise gr.Error(f"Lỗi khi nhận dạng từ audio tham chiếu: {e}")
+
+    if not text:
+        raise gr.Error("Không nhận dạng được nội dung từ audio tham chiếu.")
+    return text
+
+
+# ======================= PRE-BUILT DEMO SAMPLES =======================
+SAMPLE_CONFIGS = [
+    {
+        "name": "Sample 1 – Kể chuyện",
+        "ref_audio": "Toại.wav",
+        "ref_text": "Trong bóng tối, Toại nói cái gì đó mà Thoan không nghe thấy.",
+        "gen_text": "Đêm nay trời nhiều mây, ánh trăng bị che khuất, chỉ còn lại một dải sáng yếu ớt rơi xuống con đường đất trải dài giữa cánh đồng. Cậu bé tên Tín đang dắt chiếc xe đạp cũ đi về nhà, bánh xe bị cán đinh nên lăn nặng và chậm như con trâu mệt nhọc sau vụ mùa. Gió thổi lạnh buốt, mùi bùn đất ngai ngái quấn lấy chân cậu. Tới đoạn rẽ dẫn vào xóm, Tín nghe tiếng nước chảy khe khẽ từ con mương bên đường. Tiếng ấy vẫn quen thuộc, nhưng tối nay lại vang khác lạ, như có giọng người đang hòa vào nhịp nước, lúc trầm lúc cao, nghe mơ hồ mà lạnh sống lưng. Cậu dừng lại, nghiêng tai lắng nghe, tim đập nhanh như muốn vượt khỏi lồng ngực.",
+        "out_audio": "Toại_output.wav",
+    },
+    {
+        "name": "Sample 2 – Nữ",
+        "ref_audio": "Trung.wav",
+        "ref_text": "Mùa hè không chỉ là khoảng thời gian nghỉ ngơi, mà còn là khoảng thời gian tuyệt vời.",
+        "gen_text": "Từ các kết quả này, chúng tôi đề xuất rằng sự kết hợp nhuần nhuyễn giữa adaptive optimization, robust training pipelines và interpretable model design sẽ là chìa khóa để phát triển các hệ thống ây ai vừa mạnh mẽ vừa đáng tin cậy trong môi trường thực tế.",
+        "out_audio": "Trung_output.wav",
+    },
+    {
+        "name": "Sample 3 – English",
+        "ref_audio": "T_English.wav",
+        "ref_text": "And turning to the pole which he had dragged, He drew it close beneath the widowed bough, And what was of it unto it left bound.",
+        "gen_text": "Recent experiments indicate that the current model architecture still exhibits significant overfitting, especially when evaluated on out of distribution samples. Although the training accuracy remains consistently high, the performance drops sharply when the model is exposed to noise perturbed inputs, suggesting limited robustness.",
+        "out_audio": "T_English_output.wav",
+    },
+]
+
+# Handler used when a "Use this sample" button is clicked
+def make_sample_loader(sample):
+    def _load_sample():
+        return (
+            sample["ref_audio"],   # ref_audio -> input Audio
+            sample["ref_text"],    # ref_text -> Textbox
+            sample["gen_text"],    # gen_text -> Textbox
+            sample["out_audio"],   # output_audio -> Audio
+        )
+    return _load_sample
+
+
+# ======================= CUSTOM STYLE (BRIGHTER THEME) =======================
+custom_css = """
+#app-container {
+    max-width: 1000px;
+    margin: 0 auto;
+}
+.gradio-container {
+    background: radial-gradient(circle at top, #ffffff 0, #f9fafb 55%);
+    color: #111827;
+}
+
+/* Large title */
+#title-block h1 {
+    font-size: 2.4rem !important;
+    font-weight: 800 !important;
+    background: linear-gradient(120deg, #f97316, #eab308, #22c55e);
+    -webkit-background-clip: text;
+    color: transparent;
+    text-align: center;
+}
+#title-block p {
+    text-align: center;
+    font-size: 0.95rem;
+    color: #6b7280;
+}
+
+/* Brighter cards */
+.sample-card {
+    border-radius: 16px;
+    padding: 16px;
+    background: rgba(255, 255, 255, 0.96);
+    border: 1px solid rgba(148, 163, 184, 0.6);
+    box-shadow: 0 18px 28px rgba(148, 163, 184, 0.35);
+}
+
+/* Buttons */
+button.primary {
+    border-radius: 999px !important;
+    font-weight: 600 !important;
+}
+
+/* Tabs */
+.svelte-1ipelgc, .tabitem {
+    font-weight: 600;
+}
+"""
+
+# ======================= TEXT POST-PROCESSING (IF NEEDED) =======================
+def post_process(text: str) -> str:
+    text = " " + text + " "
+    text = text.replace(" . . ", " . ")
+    text = " " + text + " "
+    text = text.replace(" .. ", " . ")
+    text = " " + text + " "
+    text = text.replace(" , , ", " , ")
+    text = " " + text + " "
+    text = text.replace(" ,, ", " , ")
+    text = " " + text + " "
+    text = text.replace('"', "")
+    return " ".join(text.split())
+
+
+@spaces.GPU
+def infer_tts(ref_audio_path, ref_text, gen_text, steps, request: gr.Request = None):
+    if not ref_audio_path:
+        raise gr.Error("Please upload a sample audio file.")
+
+    if not gen_text.strip():
+        raise gr.Error("Please enter the text content to generate voice.")
+
+    # Limit the content length (4000 words)
+    if len(gen_text.split()) > 4000:
+        raise gr.Error("Please enter text content with less than 4000 words.")
+
+    # 1) Enhance ref audio: ClearVoice + silence handling + trim to 5-10 s
+    try:
+        enhanced_ref_audio = enhance_ref_audio(ref_audio_path)
+    except Exception as e:
+        raise gr.Error(f"Lỗi khi xử lý audio tham chiếu: {e}")
+
+    # 2) If ref_text is empty, run silence-based ASR
+    if not ref_text or not ref_text.strip():
+        try:
+            inferred = transcribe_ref_audio(enhanced_ref_audio)
+            if not inferred:
+                raise gr.Error(
+                    "Không nhận dạng được nội dung từ audio tham chiếu. "
+                    "Vui lòng nhập Reference Text thủ công."
+                )
+            ref_text = inferred
+            print(f"[ASR] Inferred ref_text: {ref_text}")
+        except gr.Error:
+            raise
+        except Exception as e:
+            raise gr.Error(f"Lỗi khi tự động nhận dạng Reference Text: {e}")
+
+    try:
+        out_path = "result.wav"
+
+        run_zipvoice(
+            model_name="zipvoice",
+            prompt_wav=enhanced_ref_audio,  # use the processed file
+            prompt_text=ref_text.strip() if ref_text else "xin chào các bạn",
+            text=gen_text,
+            res_wav_path=out_path,
+            lang="vi",
+            tokenizer_name="espeak",
+            num_step=steps,
+            seed=123456,
+            speed=1.0,
+        )
+
+        return out_path
+
+    except Exception as e:
+        raise gr.Error(f"Error generating voice: {e}")
+
+
+# ======================= UI =======================
+with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
+    with gr.Column(elem_id="app-container"):
+        # --------- TITLE ----------
+        gr.Markdown(
+            """
+            <div id="title-block">
+                <h1>ZipVoice – Zero-shot Vietnamese TTS</h1>
+                <p>Upload một mẫu giọng + nhập nội dung → hệ thống sẽ bắt chước giọng nói và đọc đoạn text của bạn.</p>
+            </div>
+            """,
+            elem_id="title-block",
+        )
+
+        with gr.Tabs():
+            # Only one main tab remains; the built-in demos live inside it as well
+            with gr.TabItem("Tự tạo giọng nói"):
+                # --------- MAIN INPUT / OUTPUT BLOCK ----------
+                with gr.Row():
+                    with gr.Column(elem_classes=["sample-card"]):
+                        gr.Markdown("#### 1. Tải giọng mẫu & nhập text")
+
+                        ref_audio = gr.Audio(
+                            label="Sample Voice (upload hoặc kéo thả)",
+                            type="filepath",
+                        )
+
+                        ref_text = gr.Textbox(
+                            label="Reference Text (optional)",
+                            placeholder="Nội dung đang được nói trong file giọng mẫu (nên tự viết cho chính xác)",
+                            lines=3,
+                        )
+
+                        # Button that infers the text from the reference audio (ASR + denoising)
+                        btn_infer_text = gr.Button(
+                            "Infer Text từ audio tham chiếu"
+                        )
+
+                        gen_text = gr.Textbox(
+                            label="Text to Generate",
+                            placeholder="Nhập nội dung tiếng Việt bạn muốn tổng hợp...",
+                            lines=6,
+                        )
+
+                        steps = gr.Slider(
+                            8,
+                            64,
+                            value=25,
+                            step=1,
+                            label="Step (càng lớn, càng tốt, càng lâu)",
+                        )
+
+                        btn_synthesize = gr.Button(
+                            "Generate Voice",
+                            variant="primary",
+                        )
+
+                    with gr.Column(elem_classes=["sample-card"]):
+                        gr.Markdown("#### 2. Kết quả tổng hợp")
+                        output_audio = gr.Audio(
+                            label="Generated Audio",
+                            type="filepath",
+                        )
+                        gr.Markdown(
+                            """
+                            - Bạn có thể tải file `.wav` về sau khi tạo.
+                            - Nếu nghe chưa ổn, hãy thử:
+                              - Dùng **ref audio ngắn 3-8s, phát âm chuẩn hơn**.
+                            """
+                        )
+
+                # Wire the Generate button -> infer_tts
+                btn_synthesize.click(
+                    infer_tts,
+                    inputs=[ref_audio, ref_text, gen_text, steps],
+                    outputs=[output_audio],
+                )
+
+                # Wire the Infer Text button -> fill ref_text (denoised first)
+                btn_infer_text.click(
+                    infer_ref_text_ui,
+                    inputs=[ref_audio],
+                    outputs=[ref_text],
+                )
+
+                # --------- DEMO BLOCK, DIRECTLY INSIDE THE MAIN TAB ----------
+                gr.Markdown(
+                    """
+                    ### Demo có sẵn
+                    Click vào một sample bên dưới để tự động nạp:
+                    - Giọng mẫu (ref voice)
+                    - Reference text
+                    - Text to generate
+                    - Output audio mẫu
+                    """
+                )
+
+                for sample in SAMPLE_CONFIGS:
+                    with gr.Column(elem_classes=["sample-card"]):
+                        gr.Markdown(f"### {sample['name']}")
+                        with gr.Row():
+                            gr.Audio(
+                                value=sample["ref_audio"],
+                                label="Reference Voice",
+                                interactive=False,
+                            )
+                            gr.Textbox(
+                                value=sample["ref_text"],
+                                label="Reference Text",
+                                interactive=False,
+                                lines=3,
+                            )
+
+                        gr.Audio(
+                            value=sample["out_audio"],
+                            label="Generated Sample (TTS)",
+                            interactive=False,
+                        )
+
+                        if sample.get("gen_text"):
+                            gr.Markdown(
+                                f"**Text dùng để synth:** {sample['gen_text']}"
+                            )
+
+                        # This button fills ref_audio, ref_text, gen_text and output_audio at once
+                        use_btn = gr.Button(f"Dùng {sample['name']}")
+
+                        use_btn.click(
+                            make_sample_loader(sample),
+                            inputs=[],
+                            outputs=[ref_audio, ref_text, gen_text, output_audio],
+                        )
+
+        gr.Markdown(
+            """
+            ### Model Limitations
+            1. Có thể xử lý chưa tốt với số, ngày tháng, ký tự đặc biệt.
+            2. Nhịp điệu đôi khi chưa tự nhiên.
+            3. Chất lượng phụ thuộc khá nhiều vào chất lượng ref audio.
+            """
+        )
+
+demo.queue().launch()
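For running this demo outside the Space, a rough sketch follows; it assumes `HF_TOKEN` grants access to the private checkpoint dataset (download_models.py raises without it) and that the `spaces` package tolerates local execution:

```bash
export HF_TOKEN=hf_xxxxx          # token with access to the kjanh/demo_zip dataset
pip install -r requirements.txt
python app.py                     # downloads the checkpoints, then serves the Gradio UI
```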
assets/silence.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ca5a251f2d1439929f1c6b44d98299e53d402da45306af79cbfab5005501fed9
+size 4800044
download_models.py
ADDED
@@ -0,0 +1,38 @@
+import os
+import requests
+
+MODEL_DIR = "zipvoice_finetune"
+os.makedirs(MODEL_DIR, exist_ok=True)
+
+files = {
+    "iter-525000-avg-2.pt": "https://huggingface.co/datasets/kjanh/demo_zip/resolve/main/epoch-46-all-speak-600h-en-norm.pt",
+    "model.json": "https://huggingface.co/datasets/kjanh/demo_zip/resolve/main/model.json",
+    "tokens.txt": "https://huggingface.co/datasets/kjanh/demo_zip/resolve/main/tokens.txt",
+}
+
+HF_TOKEN = os.getenv("HF_TOKEN")
+
+def download_with_token(url, dest_path):
+    if os.path.exists(dest_path):
+        print(f"File already exists: {dest_path}")
+        return
+
+    if HF_TOKEN is None:
+        raise RuntimeError("Missing HF_TOKEN in Secrets!")
+
+    print(f"Downloading {dest_path} ...")
+
+    headers = {"Authorization": f"Bearer {HF_TOKEN}"}
+    r = requests.get(url, headers=headers, stream=True)
+    r.raise_for_status()
+
+    with open(dest_path, "wb") as f:
+        for chunk in r.iter_content(1024 * 1024):
+            f.write(chunk)
+
+    print(f"Downloaded {dest_path}")
+
+# demo
+def download_all_models():
+    for filename, url in files.items():
+        dest = os.path.join(MODEL_DIR, filename)
+        download_with_token(url, dest)
egs/zipvoice/README.md
ADDED
@@ -0,0 +1,15 @@
+# ZipVoice Recipe
+
+This recipe contains the following examples:
+
+- Training ZipVoice on Emilia from scratch, see [run_emilia.sh](run_emilia.sh).
+- Training ZipVoice on LibriTTS from scratch, see [run_libritts.sh](run_libritts.sh).
+- Training ZipVoice on custom datasets (any language) from scratch, see [run_custom.sh](run_custom.sh).
+- Fine-tuning pre-trained ZipVoice on custom datasets (any language), see [run_finetune.sh](run_finetune.sh).
+- Evaluating TTS models with the objective metrics reported in the ZipVoice paper, see [run_eval.sh](run_eval.sh).
+
+> **NOTE:** [run_emilia.sh](run_emilia.sh) is the most complete example; it covers data preparation, ZipVoice training, ZipVoice-Distill training, ONNX export, and inference with all PyTorch and ONNX models.
+
+> **NOTE:** For evaluation, first install the packages from [../../requirements_eval.txt](../../requirements_eval.txt):
+>
+> `pip install -r ../../requirements_eval.txt`
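Since the recipe ships utils/parse_options.sh (the Kaldi-style option parser), these scripts can presumably be driven with `--variable value` flags that override shell variables defined at the top of each script; a hypothetical invocation (the stage numbers are assumptions, check the script itself):

```bash
cd egs/zipvoice
bash run_emilia.sh --stage 0 --stop-stage 1   # hypothetical: data-preparation stages only
```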
egs/zipvoice/conf/zipvoice_base.json
ADDED
@@ -0,0 +1,26 @@
+{
+    "model": {
+        "fm_decoder_downsampling_factor": [1, 2, 4, 2, 1],
+        "fm_decoder_num_layers": [2, 2, 4, 4, 4],
+        "fm_decoder_cnn_module_kernel": [31, 15, 7, 15, 31],
+        "fm_decoder_feedforward_dim": 1536,
+        "fm_decoder_num_heads": 4,
+        "fm_decoder_dim": 512,
+        "text_encoder_num_layers": 4,
+        "text_encoder_feedforward_dim": 512,
+        "text_encoder_cnn_module_kernel": 9,
+        "text_encoder_num_heads": 4,
+        "text_encoder_dim": 192,
+        "query_head_dim": 32,
+        "value_head_dim": 12,
+        "pos_head_dim": 4,
+        "pos_dim": 48,
+        "time_embed_dim": 192,
+        "text_embed_dim": 192,
+        "feat_dim": 100
+    },
+    "feature": {
+        "sampling_rate": 24000,
+        "type": "vocos"
+    }
+}
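A quick way to peek at these hyperparameters from the shell, assuming `jq` is installed (the keys are verbatim from the file above):

```bash
# Prints 24000 and 512: the feature sampling rate and flow-matching decoder width
jq '.feature.sampling_rate, .model.fm_decoder_dim' egs/zipvoice/conf/zipvoice_base.json
```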
egs/zipvoice/local/pinyin.txt
ADDED
@@ -0,0 +1,1550 @@
+a
+a1
+a2
+a3
+a4
+ai1
+ai2
+ai3
+ai4
+an1
+an2
+an3
+an4
+ang1
+ang2
+ang3
+ang4
+ao1
+ao2
+ao3
+ao4
+ba
+ba1
+ba2
+ba3
+ba4
+bai
+bai1
+bai2
+bai3
+bai4
+ban
+ban1
+ban3
+ban4
+bang1
+bang3
+bang4
+bao1
+bao2
+bao3
+bao4
+bei
+bei1
+bei3
+bei4
+ben1
+ben3
+ben4
+beng
+beng1
+beng2
+beng3
+beng4
+bi1
+bi2
+bi3
+bi4
+bian
+bian1
+bian3
+bian4
+biang2
+biao1
+biao3
+biao4
+bie1
+bie2
+bie3
+bie4
+bin
+bin1
+bin3
+bin4
+bing1
+bing3
+bing4
+bo
+bo1
+bo2
+bo3
+bo4
+bu1
+bu2
+bu3
+bu4
+ca1
+ca3
+ca4
+cai1
+cai2
+cai3
+cai4
+can1
+can2
+can3
+can4
+cang1
+cang2
+cang3
+cang4
+cao1
+cao2
+cao3
+cao4
+ce4
+cei4
+cen1
+cen2
+ceng1
+ceng2
+ceng4
+cha1
+cha2
+cha3
+cha4
+chai1
+chai2
+chai3
+chai4
+chan1
+chan2
+chan3
+chan4
+chang
+chang1
+chang2
+chang3
+chang4
+chao1
+chao2
+chao3
+chao4
+che1
+che2
+che3
+che4
+chen
+chen1
+chen2
+chen3
+chen4
+cheng1
+cheng2
+cheng3
+cheng4
+chi
+chi1
+chi2
+chi3
+chi4
+chong1
+chong2
+chong3
+chong4
+chou1
+chou2
+chou3
+chou4
+chu
+chu1
+chu2
+chu3
+chu4
+chua1
+chua3
+chua4
+chuai1
+chuai2
+chuai3
+chuai4
+chuan1
+chuan2
+chuan3
+chuan4
+chuang1
+chuang2
+chuang3
+chuang4
+chui1
+chui2
+chui3
+chui4
+chun1
+chun2
+chun3
+chuo1
+chuo4
+ci1
+ci2
+ci3
+ci4
+cong1
+cong2
+cong3
+cong4
+cou1
+cou2
+cou3
+cou4
+cu1
+cu2
+cu3
+cu4
+cuan1
+cuan2
+cuan4
+cui
+cui1
+cui3
+cui4
+cun1
+cun2
+cun3
+cun4
+cuo1
+cuo2
+cuo3
+cuo4
+da
+da1
+da2
+da3
+da4
+dai
+dai1
+dai3
+dai4
+dan1
+dan3
+dan4
+dang
+dang1
+dang3
+dang4
+dao1
+dao2
+dao3
+dao4
+de
+de1
+de2
+dei1
+dei3
+den4
+deng1
+deng3
+deng4
+di1
+di2
+di3
+di4
+dia3
+dian1
+dian2
+dian3
+dian4
+diao1
+diao3
+diao4
+die1
+die2
+die3
+die4
+din4
+ding1
+ding3
+ding4
+diu1
+dong1
+dong3
+dong4
+dou1
+dou3
+dou4
+du1
+du2
+du3
+du4
+duan1
+duan3
+duan4
+dui1
+dui3
+dui4
+dun1
+dun3
+dun4
+duo
+duo1
+duo2
+duo3
+duo4
+e
+e1
+e2
+e3
+e4
+ei1
+ei2
+ei3
+ei4
+en1
+en3
+en4
+eng1
+er
+er2
+er3
+er4
+fa
+fa1
+fa2
+fa3
+fa4
+fan1
+fan2
+fan3
+fan4
+fang
+fang1
+fang2
+fang3
+fang4
+fei1
+fei2
+fei3
+fei4
+fen1
+fen2
+fen3
+fen4
+feng1
+feng2
+feng3
+feng4
+fiao4
+fo2
+fou1
+fou2
+fou3
+fu
+fu1
+fu2
+fu3
+fu4
+ga1
+ga2
+ga3
+ga4
+gai1
+gai3
+gai4
+gan1
+gan3
+gan4
+gang1
+gang3
+gang4
+gao1
+gao3
+gao4
+ge1
+ge2
+ge3
+ge4
+gei3
+gen1
+gen2
+gen3
+gen4
+geng1
+geng3
+geng4
+gong
+gong1
+gong3
+gong4
+gou1
+gou3
+gou4
+gu
+gu1
+gu2
+gu3
+gu4
+gua1
+gua2
+gua3
+gua4
+guai1
+guai3
+guai4
+guan1
+guan3
+guan4
+guang
+guang1
+guang3
+guang4
+gui1
+gui3
+gui4
+gun3
+gun4
+guo
+guo1
+guo2
+guo3
+guo4
+ha1
+ha2
+ha3
+ha4
+hai
+hai1
+hai2
+hai3
+hai4
+han
+han1
+han2
+han3
+han4
+hang1
+hang2
+hang3
+hang4
+hao1
+hao2
+hao3
+hao4
+he1
+he2
+he3
+he4
+hei1
+hen1
+hen2
+hen3
+hen4
+heng1
+heng2
+heng4
+hm
+hng
+hong1
+hong2
+hong3
+hong4
+hou1
+hou2
+hou3
+hou4
+hu
+hu1
+hu2
+hu3
+hu4
+hua1
+hua2
+hua4
+huai
+huai2
+huai4
+huan1
+huan2
+huan3
+huan4
+huang
+huang1
+huang2
+huang3
+huang4
+hui
+hui1
+hui2
+hui3
+hui4
+hun1
+hun2
+hun3
+hun4
+huo
+huo1
+huo2
+huo3
+huo4
+ji1
+ji2
+ji3
+ji4
+jia
+jia1
+jia2
+jia3
+jia4
+jian
+jian1
+jian3
+jian4
+jiang
+jiang1
+jiang3
+jiang4
+jiao
+jiao1
+jiao2
+jiao3
+jiao4
+jie
+jie1
+jie2
+jie3
+jie4
+jin1
+jin3
+jin4
+jing
+jing1
+jing3
+jing4
+jiong1
+jiong3
+jiong4
+jiu
+jiu1
+jiu2
+jiu3
+jiu4
+ju
+ju1
+ju2
+ju3
+ju4
+juan1
+juan3
+juan4
+jue1
+jue2
+jue3
+jue4
+jun1
+jun3
+jun4
+ka1
+ka3
+kai1
+kai3
+kai4
+kan1
+kan3
+kan4
+kang1
+kang2
+kang3
+kang4
+kao1
+kao3
+kao4
+ke
+ke1
+ke2
+ke3
+ke4
+kei1
+ken1
+ken3
+ken4
+keng1
+keng3
+kong1
+kong3
+kong4
+kou1
+kou3
+kou4
+ku1
+ku2
+ku3
+ku4
+kua1
+kua3
+kua4
+kuai3
+kuai4
+kuan1
+kuan3
+kuang1
+kuang2
+kuang3
+kuang4
+kui1
+kui2
+kui3
+kui4
+kun
+kun1
+kun3
+kun4
+kuo4
+la
+la1
+la2
+la3
+la4
+lai2
+lai3
+lai4
+lan2
+lan3
+lan4
+lang
+lang1
+lang2
+lang3
+lang4
+lao
+lao1
+lao2
+lao3
+lao4
+le
+le1
+le4
+lei
+lei1
+lei2
+lei3
+lei4
+len4
+leng1
+leng2
+leng3
+leng4
+li
+li1
+li2
+li3
+li4
+lia3
+lian2
+lian3
+lian4
+liang
+liang2
+liang3
+liang4
+liao1
+
liao2
|
| 651 |
+
liao3
|
| 652 |
+
liao4
|
| 653 |
+
lie
|
| 654 |
+
lie1
|
| 655 |
+
lie2
|
| 656 |
+
lie3
|
| 657 |
+
lie4
|
| 658 |
+
lin1
|
| 659 |
+
lin2
|
| 660 |
+
lin3
|
| 661 |
+
lin4
|
| 662 |
+
ling
|
| 663 |
+
ling1
|
| 664 |
+
ling2
|
| 665 |
+
ling3
|
| 666 |
+
ling4
|
| 667 |
+
liu1
|
| 668 |
+
liu2
|
| 669 |
+
liu3
|
| 670 |
+
liu4
|
| 671 |
+
lo
|
| 672 |
+
long1
|
| 673 |
+
long2
|
| 674 |
+
long3
|
| 675 |
+
long4
|
| 676 |
+
lou
|
| 677 |
+
lou1
|
| 678 |
+
lou2
|
| 679 |
+
lou3
|
| 680 |
+
lou4
|
| 681 |
+
lu
|
| 682 |
+
lu1
|
| 683 |
+
lu2
|
| 684 |
+
lu3
|
| 685 |
+
lu4
|
| 686 |
+
luan2
|
| 687 |
+
luan3
|
| 688 |
+
luan4
|
| 689 |
+
lun1
|
| 690 |
+
lun2
|
| 691 |
+
lun3
|
| 692 |
+
lun4
|
| 693 |
+
luo
|
| 694 |
+
luo1
|
| 695 |
+
luo2
|
| 696 |
+
luo3
|
| 697 |
+
luo4
|
| 698 |
+
lv2
|
| 699 |
+
lv3
|
| 700 |
+
lv4
|
| 701 |
+
lve3
|
| 702 |
+
lve4
|
| 703 |
+
m1
|
| 704 |
+
m2
|
| 705 |
+
m4
|
| 706 |
+
ma
|
| 707 |
+
ma1
|
| 708 |
+
ma2
|
| 709 |
+
ma3
|
| 710 |
+
ma4
|
| 711 |
+
mai2
|
| 712 |
+
mai3
|
| 713 |
+
mai4
|
| 714 |
+
man1
|
| 715 |
+
man2
|
| 716 |
+
man3
|
| 717 |
+
man4
|
| 718 |
+
mang1
|
| 719 |
+
mang2
|
| 720 |
+
mang3
|
| 721 |
+
mang4
|
| 722 |
+
mao1
|
| 723 |
+
mao2
|
| 724 |
+
mao3
|
| 725 |
+
mao4
|
| 726 |
+
me
|
| 727 |
+
me1
|
| 728 |
+
mei2
|
| 729 |
+
mei3
|
| 730 |
+
mei4
|
| 731 |
+
men
|
| 732 |
+
men1
|
| 733 |
+
men2
|
| 734 |
+
men4
|
| 735 |
+
meng
|
| 736 |
+
meng1
|
| 737 |
+
meng2
|
| 738 |
+
meng3
|
| 739 |
+
meng4
|
| 740 |
+
mi1
|
| 741 |
+
mi2
|
| 742 |
+
mi3
|
| 743 |
+
mi4
|
| 744 |
+
mian2
|
| 745 |
+
mian3
|
| 746 |
+
mian4
|
| 747 |
+
miao1
|
| 748 |
+
miao2
|
| 749 |
+
miao3
|
| 750 |
+
miao4
|
| 751 |
+
mie
|
| 752 |
+
mie1
|
| 753 |
+
mie2
|
| 754 |
+
mie4
|
| 755 |
+
min
|
| 756 |
+
min2
|
| 757 |
+
min3
|
| 758 |
+
ming
|
| 759 |
+
ming2
|
| 760 |
+
ming3
|
| 761 |
+
ming4
|
| 762 |
+
miu3
|
| 763 |
+
miu4
|
| 764 |
+
mo
|
| 765 |
+
mo1
|
| 766 |
+
mo2
|
| 767 |
+
mo3
|
| 768 |
+
mo4
|
| 769 |
+
mou1
|
| 770 |
+
mou2
|
| 771 |
+
mou3
|
| 772 |
+
mou4
|
| 773 |
+
mu2
|
| 774 |
+
mu3
|
| 775 |
+
mu4
|
| 776 |
+
n
|
| 777 |
+
n2
|
| 778 |
+
n3
|
| 779 |
+
n4
|
| 780 |
+
na
|
| 781 |
+
na1
|
| 782 |
+
na2
|
| 783 |
+
na3
|
| 784 |
+
na4
|
| 785 |
+
nai2
|
| 786 |
+
nai3
|
| 787 |
+
nai4
|
| 788 |
+
nan1
|
| 789 |
+
nan2
|
| 790 |
+
nan3
|
| 791 |
+
nan4
|
| 792 |
+
nang
|
| 793 |
+
nang1
|
| 794 |
+
nang2
|
| 795 |
+
nang3
|
| 796 |
+
nang4
|
| 797 |
+
nao1
|
| 798 |
+
nao2
|
| 799 |
+
nao3
|
| 800 |
+
nao4
|
| 801 |
+
ne
|
| 802 |
+
ne2
|
| 803 |
+
ne4
|
| 804 |
+
nei2
|
| 805 |
+
nei3
|
| 806 |
+
nei4
|
| 807 |
+
nen4
|
| 808 |
+
neng2
|
| 809 |
+
neng3
|
| 810 |
+
neng4
|
| 811 |
+
ng
|
| 812 |
+
ng2
|
| 813 |
+
ng3
|
| 814 |
+
ng4
|
| 815 |
+
ni1
|
| 816 |
+
ni2
|
| 817 |
+
ni3
|
| 818 |
+
ni4
|
| 819 |
+
nia1
|
| 820 |
+
nian1
|
| 821 |
+
nian2
|
| 822 |
+
nian3
|
| 823 |
+
nian4
|
| 824 |
+
niang2
|
| 825 |
+
niang3
|
| 826 |
+
niang4
|
| 827 |
+
niao3
|
| 828 |
+
niao4
|
| 829 |
+
nie1
|
| 830 |
+
nie2
|
| 831 |
+
nie3
|
| 832 |
+
nie4
|
| 833 |
+
nin
|
| 834 |
+
nin2
|
| 835 |
+
nin3
|
| 836 |
+
ning2
|
| 837 |
+
ning3
|
| 838 |
+
ning4
|
| 839 |
+
niu1
|
| 840 |
+
niu2
|
| 841 |
+
niu3
|
| 842 |
+
niu4
|
| 843 |
+
nong2
|
| 844 |
+
nong3
|
| 845 |
+
nong4
|
| 846 |
+
nou2
|
| 847 |
+
nou3
|
| 848 |
+
nou4
|
| 849 |
+
nu2
|
| 850 |
+
nu3
|
| 851 |
+
nu4
|
| 852 |
+
nuan2
|
| 853 |
+
nuan3
|
| 854 |
+
nuan4
|
| 855 |
+
nun2
|
| 856 |
+
nun4
|
| 857 |
+
nuo2
|
| 858 |
+
nuo3
|
| 859 |
+
nuo4
|
| 860 |
+
nv2
|
| 861 |
+
nv3
|
| 862 |
+
nv4
|
| 863 |
+
nve4
|
| 864 |
+
o
|
| 865 |
+
o1
|
| 866 |
+
o2
|
| 867 |
+
o3
|
| 868 |
+
o4
|
| 869 |
+
ou
|
| 870 |
+
ou1
|
| 871 |
+
ou2
|
| 872 |
+
ou3
|
| 873 |
+
ou4
|
| 874 |
+
pa1
|
| 875 |
+
pa2
|
| 876 |
+
pa3
|
| 877 |
+
pa4
|
| 878 |
+
pai1
|
| 879 |
+
pai2
|
| 880 |
+
pai3
|
| 881 |
+
pai4
|
| 882 |
+
pan1
|
| 883 |
+
pan2
|
| 884 |
+
pan3
|
| 885 |
+
pan4
|
| 886 |
+
pang1
|
| 887 |
+
pang2
|
| 888 |
+
pang3
|
| 889 |
+
pang4
|
| 890 |
+
pao1
|
| 891 |
+
pao2
|
| 892 |
+
pao3
|
| 893 |
+
pao4
|
| 894 |
+
pei1
|
| 895 |
+
pei2
|
| 896 |
+
pei3
|
| 897 |
+
pei4
|
| 898 |
+
pen1
|
| 899 |
+
pen2
|
| 900 |
+
pen3
|
| 901 |
+
pen4
|
| 902 |
+
peng1
|
| 903 |
+
peng2
|
| 904 |
+
peng3
|
| 905 |
+
peng4
|
| 906 |
+
pi1
|
| 907 |
+
pi2
|
| 908 |
+
pi3
|
| 909 |
+
pi4
|
| 910 |
+
pian1
|
| 911 |
+
pian2
|
| 912 |
+
pian3
|
| 913 |
+
pian4
|
| 914 |
+
piao1
|
| 915 |
+
piao2
|
| 916 |
+
piao3
|
| 917 |
+
piao4
|
| 918 |
+
pie1
|
| 919 |
+
pie3
|
| 920 |
+
pie4
|
| 921 |
+
pin1
|
| 922 |
+
pin2
|
| 923 |
+
pin3
|
| 924 |
+
pin4
|
| 925 |
+
ping1
|
| 926 |
+
ping2
|
| 927 |
+
ping3
|
| 928 |
+
ping4
|
| 929 |
+
po
|
| 930 |
+
po1
|
| 931 |
+
po2
|
| 932 |
+
po3
|
| 933 |
+
po4
|
| 934 |
+
pou1
|
| 935 |
+
pou2
|
| 936 |
+
pou3
|
| 937 |
+
pou4
|
| 938 |
+
pu
|
| 939 |
+
pu1
|
| 940 |
+
pu2
|
| 941 |
+
pu3
|
| 942 |
+
pu4
|
| 943 |
+
qi
|
| 944 |
+
qi1
|
| 945 |
+
qi2
|
| 946 |
+
qi3
|
| 947 |
+
qi4
|
| 948 |
+
qia1
|
| 949 |
+
qia2
|
| 950 |
+
qia3
|
| 951 |
+
qia4
|
| 952 |
+
qian
|
| 953 |
+
qian1
|
| 954 |
+
qian2
|
| 955 |
+
qian3
|
| 956 |
+
qian4
|
| 957 |
+
qiang1
|
| 958 |
+
qiang2
|
| 959 |
+
qiang3
|
| 960 |
+
qiang4
|
| 961 |
+
qiao1
|
| 962 |
+
qiao2
|
| 963 |
+
qiao3
|
| 964 |
+
qiao4
|
| 965 |
+
qie1
|
| 966 |
+
qie2
|
| 967 |
+
qie3
|
| 968 |
+
qie4
|
| 969 |
+
qin1
|
| 970 |
+
qin2
|
| 971 |
+
qin3
|
| 972 |
+
qin4
|
| 973 |
+
qing
|
| 974 |
+
qing1
|
| 975 |
+
qing2
|
| 976 |
+
qing3
|
| 977 |
+
qing4
|
| 978 |
+
qiong1
|
| 979 |
+
qiong2
|
| 980 |
+
qiong4
|
| 981 |
+
qiu1
|
| 982 |
+
qiu2
|
| 983 |
+
qiu3
|
| 984 |
+
qiu4
|
| 985 |
+
qu
|
| 986 |
+
qu1
|
| 987 |
+
qu2
|
| 988 |
+
qu3
|
| 989 |
+
qu4
|
| 990 |
+
quan
|
| 991 |
+
quan1
|
| 992 |
+
quan2
|
| 993 |
+
quan3
|
| 994 |
+
quan4
|
| 995 |
+
que1
|
| 996 |
+
que2
|
| 997 |
+
que4
|
| 998 |
+
qun1
|
| 999 |
+
qun2
|
| 1000 |
+
qun3
|
| 1001 |
+
ran2
|
| 1002 |
+
ran3
|
| 1003 |
+
ran4
|
| 1004 |
+
rang1
|
| 1005 |
+
rang2
|
| 1006 |
+
rang3
|
| 1007 |
+
rang4
|
| 1008 |
+
rao2
|
| 1009 |
+
rao3
|
| 1010 |
+
rao4
|
| 1011 |
+
re2
|
| 1012 |
+
re3
|
| 1013 |
+
re4
|
| 1014 |
+
ren2
|
| 1015 |
+
ren3
|
| 1016 |
+
ren4
|
| 1017 |
+
reng1
|
| 1018 |
+
reng2
|
| 1019 |
+
reng4
|
| 1020 |
+
ri4
|
| 1021 |
+
rong
|
| 1022 |
+
rong1
|
| 1023 |
+
rong2
|
| 1024 |
+
rong3
|
| 1025 |
+
rong4
|
| 1026 |
+
rou2
|
| 1027 |
+
rou3
|
| 1028 |
+
rou4
|
| 1029 |
+
ru
|
| 1030 |
+
ru2
|
| 1031 |
+
ru3
|
| 1032 |
+
ru4
|
| 1033 |
+
rua2
|
| 1034 |
+
ruan2
|
| 1035 |
+
ruan3
|
| 1036 |
+
ruan4
|
| 1037 |
+
rui2
|
| 1038 |
+
rui3
|
| 1039 |
+
rui4
|
| 1040 |
+
run2
|
| 1041 |
+
run3
|
| 1042 |
+
run4
|
| 1043 |
+
ruo2
|
| 1044 |
+
ruo4
|
| 1045 |
+
sa
|
| 1046 |
+
sa1
|
| 1047 |
+
sa3
|
| 1048 |
+
sa4
|
| 1049 |
+
sai1
|
| 1050 |
+
sai3
|
| 1051 |
+
sai4
|
| 1052 |
+
san
|
| 1053 |
+
san1
|
| 1054 |
+
san3
|
| 1055 |
+
san4
|
| 1056 |
+
sang1
|
| 1057 |
+
sang3
|
| 1058 |
+
sang4
|
| 1059 |
+
sao1
|
| 1060 |
+
sao3
|
| 1061 |
+
sao4
|
| 1062 |
+
se1
|
| 1063 |
+
se4
|
| 1064 |
+
sen1
|
| 1065 |
+
sen3
|
| 1066 |
+
seng1
|
| 1067 |
+
seng4
|
| 1068 |
+
sha
|
| 1069 |
+
sha1
|
| 1070 |
+
sha2
|
| 1071 |
+
sha3
|
| 1072 |
+
sha4
|
| 1073 |
+
shai1
|
| 1074 |
+
shai3
|
| 1075 |
+
shai4
|
| 1076 |
+
shan1
|
| 1077 |
+
shan2
|
| 1078 |
+
shan3
|
| 1079 |
+
shan4
|
| 1080 |
+
shang
|
| 1081 |
+
shang1
|
| 1082 |
+
shang3
|
| 1083 |
+
shang4
|
| 1084 |
+
shao1
|
| 1085 |
+
shao2
|
| 1086 |
+
shao3
|
| 1087 |
+
shao4
|
| 1088 |
+
she1
|
| 1089 |
+
she2
|
| 1090 |
+
she3
|
| 1091 |
+
she4
|
| 1092 |
+
shei2
|
| 1093 |
+
shen1
|
| 1094 |
+
shen2
|
| 1095 |
+
shen3
|
| 1096 |
+
shen4
|
| 1097 |
+
sheng1
|
| 1098 |
+
sheng2
|
| 1099 |
+
sheng3
|
| 1100 |
+
sheng4
|
| 1101 |
+
shi
|
| 1102 |
+
shi1
|
| 1103 |
+
shi2
|
| 1104 |
+
shi3
|
| 1105 |
+
shi4
|
| 1106 |
+
shou
|
| 1107 |
+
shou1
|
| 1108 |
+
shou2
|
| 1109 |
+
shou3
|
| 1110 |
+
shou4
|
| 1111 |
+
shu1
|
| 1112 |
+
shu2
|
| 1113 |
+
shu3
|
| 1114 |
+
shu4
|
| 1115 |
+
shua1
|
| 1116 |
+
shua3
|
| 1117 |
+
shua4
|
| 1118 |
+
shuai1
|
| 1119 |
+
shuai3
|
| 1120 |
+
shuai4
|
| 1121 |
+
shuan1
|
| 1122 |
+
shuan4
|
| 1123 |
+
shuang1
|
| 1124 |
+
shuang3
|
| 1125 |
+
shuang4
|
| 1126 |
+
shui
|
| 1127 |
+
shui2
|
| 1128 |
+
shui3
|
| 1129 |
+
shui4
|
| 1130 |
+
shun3
|
| 1131 |
+
shun4
|
| 1132 |
+
shuo1
|
| 1133 |
+
shuo2
|
| 1134 |
+
shuo4
|
| 1135 |
+
si
|
| 1136 |
+
si1
|
| 1137 |
+
si2
|
| 1138 |
+
si3
|
| 1139 |
+
si4
|
| 1140 |
+
song1
|
| 1141 |
+
song2
|
| 1142 |
+
song3
|
| 1143 |
+
song4
|
| 1144 |
+
sou1
|
| 1145 |
+
sou3
|
| 1146 |
+
sou4
|
| 1147 |
+
su1
|
| 1148 |
+
su2
|
| 1149 |
+
su3
|
| 1150 |
+
su4
|
| 1151 |
+
suan1
|
| 1152 |
+
suan3
|
| 1153 |
+
suan4
|
| 1154 |
+
sui1
|
| 1155 |
+
sui2
|
| 1156 |
+
sui3
|
| 1157 |
+
sui4
|
| 1158 |
+
sun1
|
| 1159 |
+
sun3
|
| 1160 |
+
sun4
|
| 1161 |
+
suo
|
| 1162 |
+
suo1
|
| 1163 |
+
suo2
|
| 1164 |
+
suo3
|
| 1165 |
+
suo4
|
| 1166 |
+
ta
|
| 1167 |
+
ta1
|
| 1168 |
+
ta2
|
| 1169 |
+
ta3
|
| 1170 |
+
ta4
|
| 1171 |
+
tai
|
| 1172 |
+
tai1
|
| 1173 |
+
tai2
|
| 1174 |
+
tai3
|
| 1175 |
+
tai4
|
| 1176 |
+
tan1
|
| 1177 |
+
tan2
|
| 1178 |
+
tan3
|
| 1179 |
+
tan4
|
| 1180 |
+
tang1
|
| 1181 |
+
tang2
|
| 1182 |
+
tang3
|
| 1183 |
+
tang4
|
| 1184 |
+
tao1
|
| 1185 |
+
tao2
|
| 1186 |
+
tao3
|
| 1187 |
+
tao4
|
| 1188 |
+
te
|
| 1189 |
+
te4
|
| 1190 |
+
tei1
|
| 1191 |
+
teng1
|
| 1192 |
+
teng2
|
| 1193 |
+
teng4
|
| 1194 |
+
ti
|
| 1195 |
+
ti1
|
| 1196 |
+
ti2
|
| 1197 |
+
ti3
|
| 1198 |
+
ti4
|
| 1199 |
+
tian1
|
| 1200 |
+
tian2
|
| 1201 |
+
tian3
|
| 1202 |
+
tian4
|
| 1203 |
+
tiao
|
| 1204 |
+
tiao1
|
| 1205 |
+
tiao2
|
| 1206 |
+
tiao3
|
| 1207 |
+
tiao4
|
| 1208 |
+
tie1
|
| 1209 |
+
tie2
|
| 1210 |
+
tie3
|
| 1211 |
+
tie4
|
| 1212 |
+
ting1
|
| 1213 |
+
ting2
|
| 1214 |
+
ting3
|
| 1215 |
+
ting4
|
| 1216 |
+
tong1
|
| 1217 |
+
tong2
|
| 1218 |
+
tong3
|
| 1219 |
+
tong4
|
| 1220 |
+
tou
|
| 1221 |
+
tou1
|
| 1222 |
+
tou2
|
| 1223 |
+
tou3
|
| 1224 |
+
tou4
|
| 1225 |
+
tu
|
| 1226 |
+
tu1
|
| 1227 |
+
tu2
|
| 1228 |
+
tu3
|
| 1229 |
+
tu4
|
| 1230 |
+
tuan1
|
| 1231 |
+
tuan2
|
| 1232 |
+
tuan3
|
| 1233 |
+
tuan4
|
| 1234 |
+
tui1
|
| 1235 |
+
tui2
|
| 1236 |
+
tui3
|
| 1237 |
+
tui4
|
| 1238 |
+
tun1
|
| 1239 |
+
tun2
|
| 1240 |
+
tun3
|
| 1241 |
+
tun4
|
| 1242 |
+
tuo1
|
| 1243 |
+
tuo2
|
| 1244 |
+
tuo3
|
| 1245 |
+
tuo4
|
| 1246 |
+
wa
|
| 1247 |
+
wa1
|
| 1248 |
+
wa2
|
| 1249 |
+
wa3
|
| 1250 |
+
wa4
|
| 1251 |
+
wai
|
| 1252 |
+
wai1
|
| 1253 |
+
wai3
|
| 1254 |
+
wai4
|
| 1255 |
+
wan1
|
| 1256 |
+
wan2
|
| 1257 |
+
wan3
|
| 1258 |
+
wan4
|
| 1259 |
+
wang1
|
| 1260 |
+
wang2
|
| 1261 |
+
wang3
|
| 1262 |
+
wang4
|
| 1263 |
+
wei
|
| 1264 |
+
wei1
|
| 1265 |
+
wei2
|
| 1266 |
+
wei3
|
| 1267 |
+
wei4
|
| 1268 |
+
wen
|
| 1269 |
+
wen1
|
| 1270 |
+
wen2
|
| 1271 |
+
wen3
|
| 1272 |
+
wen4
|
| 1273 |
+
weng1
|
| 1274 |
+
weng3
|
| 1275 |
+
weng4
|
| 1276 |
+
wo1
|
| 1277 |
+
wo3
|
| 1278 |
+
wo4
|
| 1279 |
+
wong4
|
| 1280 |
+
wu
|
| 1281 |
+
wu1
|
| 1282 |
+
wu2
|
| 1283 |
+
wu3
|
| 1284 |
+
wu4
|
| 1285 |
+
xi1
|
| 1286 |
+
xi2
|
| 1287 |
+
xi3
|
| 1288 |
+
xi4
|
| 1289 |
+
xia1
|
| 1290 |
+
xia2
|
| 1291 |
+
xia3
|
| 1292 |
+
xia4
|
| 1293 |
+
xian
|
| 1294 |
+
xian1
|
| 1295 |
+
xian2
|
| 1296 |
+
xian3
|
| 1297 |
+
xian4
|
| 1298 |
+
xiang1
|
| 1299 |
+
xiang2
|
| 1300 |
+
xiang3
|
| 1301 |
+
xiang4
|
| 1302 |
+
xiao
|
| 1303 |
+
xiao1
|
| 1304 |
+
xiao2
|
| 1305 |
+
xiao3
|
| 1306 |
+
xiao4
|
| 1307 |
+
xie1
|
| 1308 |
+
xie2
|
| 1309 |
+
xie3
|
| 1310 |
+
xie4
|
| 1311 |
+
xin
|
| 1312 |
+
xin1
|
| 1313 |
+
xin2
|
| 1314 |
+
xin3
|
| 1315 |
+
xin4
|
| 1316 |
+
xing
|
| 1317 |
+
xing1
|
| 1318 |
+
xing2
|
| 1319 |
+
xing3
|
| 1320 |
+
xing4
|
| 1321 |
+
xiong1
|
| 1322 |
+
xiong2
|
| 1323 |
+
xiong3
|
| 1324 |
+
xiong4
|
| 1325 |
+
xiu1
|
| 1326 |
+
xiu2
|
| 1327 |
+
xiu3
|
| 1328 |
+
xiu4
|
| 1329 |
+
xu
|
| 1330 |
+
xu1
|
| 1331 |
+
xu2
|
| 1332 |
+
xu3
|
| 1333 |
+
xu4
|
| 1334 |
+
xuan1
|
| 1335 |
+
xuan2
|
| 1336 |
+
xuan3
|
| 1337 |
+
xuan4
|
| 1338 |
+
xue1
|
| 1339 |
+
xue2
|
| 1340 |
+
xue3
|
| 1341 |
+
xue4
|
| 1342 |
+
xun1
|
| 1343 |
+
xun2
|
| 1344 |
+
xun4
|
| 1345 |
+
ya
|
| 1346 |
+
ya1
|
| 1347 |
+
ya2
|
| 1348 |
+
ya3
|
| 1349 |
+
ya4
|
| 1350 |
+
yan1
|
| 1351 |
+
yan2
|
| 1352 |
+
yan3
|
| 1353 |
+
yan4
|
| 1354 |
+
yang
|
| 1355 |
+
yang1
|
| 1356 |
+
yang2
|
| 1357 |
+
yang3
|
| 1358 |
+
yang4
|
| 1359 |
+
yao1
|
| 1360 |
+
yao2
|
| 1361 |
+
yao3
|
| 1362 |
+
yao4
|
| 1363 |
+
ye
|
| 1364 |
+
ye1
|
| 1365 |
+
ye2
|
| 1366 |
+
ye3
|
| 1367 |
+
ye4
|
| 1368 |
+
yi
|
| 1369 |
+
yi1
|
| 1370 |
+
yi2
|
| 1371 |
+
yi3
|
| 1372 |
+
yi4
|
| 1373 |
+
yin
|
| 1374 |
+
yin1
|
| 1375 |
+
yin2
|
| 1376 |
+
yin3
|
| 1377 |
+
yin4
|
| 1378 |
+
ying1
|
| 1379 |
+
ying2
|
| 1380 |
+
ying3
|
| 1381 |
+
ying4
|
| 1382 |
+
yo
|
| 1383 |
+
yo1
|
| 1384 |
+
yong1
|
| 1385 |
+
yong2
|
| 1386 |
+
yong3
|
| 1387 |
+
yong4
|
| 1388 |
+
you
|
| 1389 |
+
you1
|
| 1390 |
+
you2
|
| 1391 |
+
you3
|
| 1392 |
+
you4
|
| 1393 |
+
yu
|
| 1394 |
+
yu1
|
| 1395 |
+
yu2
|
| 1396 |
+
yu3
|
| 1397 |
+
yu4
|
| 1398 |
+
yuan1
|
| 1399 |
+
yuan2
|
| 1400 |
+
yuan3
|
| 1401 |
+
yuan4
|
| 1402 |
+
yue1
|
| 1403 |
+
yue2
|
| 1404 |
+
yue3
|
| 1405 |
+
yue4
|
| 1406 |
+
yun
|
| 1407 |
+
yun1
|
| 1408 |
+
yun2
|
| 1409 |
+
yun3
|
| 1410 |
+
yun4
|
| 1411 |
+
za1
|
| 1412 |
+
za2
|
| 1413 |
+
za3
|
| 1414 |
+
za4
|
| 1415 |
+
zai1
|
| 1416 |
+
zai3
|
| 1417 |
+
zai4
|
| 1418 |
+
zan
|
| 1419 |
+
zan1
|
| 1420 |
+
zan2
|
| 1421 |
+
zan3
|
| 1422 |
+
zan4
|
| 1423 |
+
zang1
|
| 1424 |
+
zang3
|
| 1425 |
+
zang4
|
| 1426 |
+
zao1
|
| 1427 |
+
zao2
|
| 1428 |
+
zao3
|
| 1429 |
+
zao4
|
| 1430 |
+
ze
|
| 1431 |
+
ze2
|
| 1432 |
+
ze4
|
| 1433 |
+
zei2
|
| 1434 |
+
zen
|
| 1435 |
+
zen1
|
| 1436 |
+
zen3
|
| 1437 |
+
zen4
|
| 1438 |
+
zeng1
|
| 1439 |
+
zeng3
|
| 1440 |
+
zeng4
|
| 1441 |
+
zha
|
| 1442 |
+
zha1
|
| 1443 |
+
zha2
|
| 1444 |
+
zha3
|
| 1445 |
+
zha4
|
| 1446 |
+
zhai1
|
| 1447 |
+
zhai2
|
| 1448 |
+
zhai3
|
| 1449 |
+
zhai4
|
| 1450 |
+
zhan1
|
| 1451 |
+
zhan2
|
| 1452 |
+
zhan3
|
| 1453 |
+
zhan4
|
| 1454 |
+
zhang
|
| 1455 |
+
zhang1
|
| 1456 |
+
zhang3
|
| 1457 |
+
zhang4
|
| 1458 |
+
zhao
|
| 1459 |
+
zhao1
|
| 1460 |
+
zhao2
|
| 1461 |
+
zhao3
|
| 1462 |
+
zhao4
|
| 1463 |
+
zhe
|
| 1464 |
+
zhe1
|
| 1465 |
+
zhe2
|
| 1466 |
+
zhe3
|
| 1467 |
+
zhe4
|
| 1468 |
+
zhei4
|
| 1469 |
+
zhen1
|
| 1470 |
+
zhen2
|
| 1471 |
+
zhen3
|
| 1472 |
+
zhen4
|
| 1473 |
+
zheng1
|
| 1474 |
+
zheng3
|
| 1475 |
+
zheng4
|
| 1476 |
+
zhi
|
| 1477 |
+
zhi1
|
| 1478 |
+
zhi2
|
| 1479 |
+
zhi3
|
| 1480 |
+
zhi4
|
| 1481 |
+
zhong1
|
| 1482 |
+
zhong3
|
| 1483 |
+
zhong4
|
| 1484 |
+
zhou1
|
| 1485 |
+
zhou2
|
| 1486 |
+
zhou3
|
| 1487 |
+
zhou4
|
| 1488 |
+
zhu1
|
| 1489 |
+
zhu2
|
| 1490 |
+
zhu3
|
| 1491 |
+
zhu4
|
| 1492 |
+
zhua1
|
| 1493 |
+
zhua3
|
| 1494 |
+
zhuai1
|
| 1495 |
+
zhuai3
|
| 1496 |
+
zhuai4
|
| 1497 |
+
zhuan1
|
| 1498 |
+
zhuan2
|
| 1499 |
+
zhuan3
|
| 1500 |
+
zhuan4
|
| 1501 |
+
zhuang1
|
| 1502 |
+
zhuang3
|
| 1503 |
+
zhuang4
|
| 1504 |
+
zhui1
|
| 1505 |
+
zhui3
|
| 1506 |
+
zhui4
|
| 1507 |
+
zhun1
|
| 1508 |
+
zhun3
|
| 1509 |
+
zhun4
|
| 1510 |
+
zhuo
|
| 1511 |
+
zhuo1
|
| 1512 |
+
zhuo2
|
| 1513 |
+
zhuo4
|
| 1514 |
+
zi
|
| 1515 |
+
zi1
|
| 1516 |
+
zi2
|
| 1517 |
+
zi3
|
| 1518 |
+
zi4
|
| 1519 |
+
zong
|
| 1520 |
+
zong1
|
| 1521 |
+
zong3
|
| 1522 |
+
zong4
|
| 1523 |
+
zou1
|
| 1524 |
+
zou3
|
| 1525 |
+
zou4
|
| 1526 |
+
zu1
|
| 1527 |
+
zu2
|
| 1528 |
+
zu3
|
| 1529 |
+
zu4
|
| 1530 |
+
zuan1
|
| 1531 |
+
zuan3
|
| 1532 |
+
zuan4
|
| 1533 |
+
zui
|
| 1534 |
+
zui1
|
| 1535 |
+
zui2
|
| 1536 |
+
zui3
|
| 1537 |
+
zui4
|
| 1538 |
+
zun1
|
| 1539 |
+
zun2
|
| 1540 |
+
zun3
|
| 1541 |
+
zun4
|
| 1542 |
+
zuo
|
| 1543 |
+
zuo1
|
| 1544 |
+
zuo2
|
| 1545 |
+
zuo3
|
| 1546 |
+
zuo4
|
| 1547 |
+
Γͺ1
|
| 1548 |
+
Γͺ2
|
| 1549 |
+
Γͺ3
|
| 1550 |
+
Γͺ4
|
egs/zipvoice/local/prepare_emilia.sh
ADDED
@@ -0,0 +1,149 @@
#!/usr/bin/env bash

# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
export PYTHONPATH=../../:$PYTHONPATH

set -eou pipefail

stage=0
stop_stage=5
sampling_rate=24000
nj=32

dl_dir=$PWD/download

. utils/parse_options.sh || exit 1

# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "dl_dir: $dl_dir"

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "Stage 0: Download data"

  # Your download directory should look like this:
  #
  # download/Amphion___Emilia
  # ├── metafile.yaml
  # ├── raw
  # │   ├── DE
  # │   ├── EN
  # │   ├── FR
  # │   ├── JA
  # │   ├── KO
  # │   ├── openemilia_45batches.tar.gz
  # │   ├── openemilia_all.tar.gz
  # │   └── ZH
  # └── README.md

  if [ ! -d $dl_dir/Amphion___Emilia/raw ]; then
    log "Please refer to https://openxlab.org.cn/datasets/Amphion/Emilia to download the dataset."
    exit 1
  fi

fi

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "Stage 1: Prepare emilia manifests (EN and ZH only)"
  # We assume that you have downloaded the Emilia corpus
  # to $dl_dir/Amphion___Emilia
  # see stage 0 for the directory structure
  mkdir -p data/manifests
  if [ ! -e data/manifests/.emilia.done ]; then
    lhotse prepare emilia --lang en --num-jobs ${nj} $dl_dir/Amphion___Emilia data/manifests
    lhotse prepare emilia --lang zh --num-jobs ${nj} $dl_dir/Amphion___Emilia data/manifests
    touch data/manifests/.emilia.done
  fi
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Preprocess Emilia dataset, mainly for cleaning"
  mkdir -p data/manifests/splits_raw
  if [ ! -e data/manifests/splits_raw/.emilia.split.done ]; then
    lhotse split-lazy data/manifests/emilia_cuts_EN.jsonl.gz data/manifests/splits_raw 10000
    lhotse split-lazy data/manifests/emilia_cuts_ZH.jsonl.gz data/manifests/splits_raw 10000
    touch data/manifests/splits_raw/.emilia.split.done
  fi

  mkdir -p data/manifests/splits

  if [ ! -e data/manifests/splits/.emilia.preprocess.done ]; then
    python local/preprocess_emilia.py --subset EN
    python local/preprocess_emilia.py --subset ZH
    touch data/manifests/splits/.emilia.preprocess.done
  fi

fi


if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Add tokens to manifests"

  mkdir -p data/manifests/tokenized_splits

  if [ ! -e data/manifests/tokenized_splits/.emilia.preprocess.done ]; then
    for subset in EN ZH; do
      log "Tokenizing Emilia ${subset}"
      python local/prepare_tokens_emilia.py \
        --subset ${subset} \
        --jobs ${nj} \
        --source-dir data/manifests/splits/ \
        --dest-dir data/manifests/tokenized_splits/
    done
    touch data/manifests/tokenized_splits/.emilia.preprocess.done
  fi

fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "Stage 4: Extract Fbank for Emilia"
  mkdir -p data/fbank/emilia_splits
  if [ ! -e data/fbank/emilia_splits/.emilia.fbank.done ]; then
    # You can speed up the extraction by distributing splits to multiple machines.
    for subset in EN ZH; do
      python3 -m zipvoice.bin.compute_fbank \
        --source-dir data/manifests/tokenized_splits \
        --dest-dir data/fbank/emilia_splits \
        --dataset emilia \
        --subset ${subset} \
        --splits-cuts 1 \
        --split-begin 0 \
        --split-end 2000 \
        --num-jobs ${nj}
    done
    touch data/fbank/emilia_splits/.emilia.fbank.done
  fi

  if [ ! -e data/fbank/emilia_cuts_EN.jsonl.gz ]; then
    log "Combining EN fbank cuts and splitting EN dev set"
    gunzip -c data/fbank/emilia_splits/emilia_cuts_EN.*.jsonl.gz > data/fbank/emilia_cuts_EN.jsonl
    head -n 1500 data/fbank/emilia_cuts_EN.jsonl | gzip -c > data/fbank/emilia_cuts_EN_dev.jsonl.gz
    sed -i '1,1500d' data/fbank/emilia_cuts_EN.jsonl
    gzip data/fbank/emilia_cuts_EN.jsonl
  fi

  if [ ! -e data/fbank/emilia_cuts_ZH.jsonl.gz ]; then
    log "Combining ZH fbank cuts and splitting ZH dev set"
    gunzip -c data/fbank/emilia_splits/emilia_cuts_ZH.*.jsonl.gz > data/fbank/emilia_cuts_ZH.jsonl
    head -n 1500 data/fbank/emilia_cuts_ZH.jsonl | gzip -c > data/fbank/emilia_cuts_ZH_dev.jsonl.gz
    sed -i '1,1500d' data/fbank/emilia_cuts_ZH.jsonl
    gzip data/fbank/emilia_cuts_ZH.jsonl
  fi

fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  log "Stage 5: Generate token file"
  if [ ! -e data/tokens_emilia.txt ]; then
    ./local/prepare_token_file_emilia.py --tokens data/tokens_emilia.txt
  fi
fi
egs/zipvoice/local/prepare_libritts.sh
ADDED
@@ -0,0 +1,100 @@
#!/usr/bin/env bash

# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
export PYTHONPATH=../../:$PYTHONPATH

set -eou pipefail

stage=0
stop_stage=5
sampling_rate=24000
nj=20

dl_dir=$PWD/download

. utils/parse_options.sh || exit 1

# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "dl_dir: $dl_dir"

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "Stage 0: Download data"

  # If you have pre-downloaded it to /path/to/LibriTTS,
  # you can create a symlink
  #
  #   ln -sfv /path/to/LibriTTS $dl_dir/LibriTTS
  #
  if [ ! -d $dl_dir/LibriTTS ]; then
    lhotse download libritts $dl_dir
  fi

fi

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "Stage 1: Prepare LibriTTS manifest"
  # We assume that you have downloaded the LibriTTS corpus
  # to $dl_dir/LibriTTS

  # We did not add tokens to this manifest, as on-the-fly
  # tokenization with LibriTTSTokenizer is not slow.
  mkdir -p data/manifests
  if [ ! -e data/manifests/.libritts.done ]; then
    lhotse prepare libritts --num-jobs ${nj} $dl_dir/LibriTTS data/manifests
    touch data/manifests/.libritts.done
  fi
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Compute Fbank for LibriTTS"
  mkdir -p data/fbank

  if [ ! -e data/fbank/.libritts.done ]; then
    for subset in train-clean-100 train-clean-360 train-other-500 dev-clean test-clean; do
      python3 -m zipvoice.bin.compute_fbank \
        --source-dir data/manifests \
        --dest-dir data/fbank \
        --dataset libritts \
        --subset ${subset} \
        --sampling-rate $sampling_rate \
        --num-jobs ${nj}
    done
    touch data/fbank/.libritts.done
  fi

  # Here we shuffle and combine the train-clean-100, train-clean-360 and
  # train-other-500 together to form the training set.
  if [ ! -f data/fbank/libritts_cuts_train-all-shuf.jsonl.gz ]; then
    cat <(gunzip -c data/fbank/libritts_cuts_train-clean-100.jsonl.gz) \
      <(gunzip -c data/fbank/libritts_cuts_train-clean-360.jsonl.gz) \
      <(gunzip -c data/fbank/libritts_cuts_train-other-500.jsonl.gz) | \
      shuf | gzip -c > data/fbank/libritts_cuts_train-all-shuf.jsonl.gz
  fi

  if [ ! -e data/fbank/.libritts-validated.done ]; then
    log "Validating data/fbank for LibriTTS"
    python3 ./utils/validate_manifest.py \
      data/fbank/libritts_cuts_train-all-shuf.jsonl.gz
    touch data/fbank/.libritts-validated.done
  fi
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Generate token file"
  if [ ! -e data/tokens_libritts.txt ]; then
    python3 ./local/prepare_token_file_char.py \
      --manifest data/fbank/libritts_cuts_train-all-shuf.jsonl.gz \
      --tokens data/tokens_libritts.txt
  fi
fi
egs/zipvoice/local/prepare_token_file_char.py
ADDED
@@ -0,0 +1,67 @@
#!/usr/bin/env python3
# Copyright 2024-2025 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import re
from collections import Counter
from pathlib import Path

from lhotse import load_manifest_lazy


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--tokens",
        type=Path,
        help="Path to the dict that maps the text tokens to IDs",
    )

    parser.add_argument(
        "--manifest",
        type=Path,
        help="Path to the manifest file",
    )

    return parser.parse_args()


def prepare_tokens(manifest_file, token_file):
    counter = Counter()
    manifest = load_manifest_lazy(manifest_file)
    for cut in manifest:
        line = re.sub(r"\s+", " ", cut.supervisions[0].text)
        counter.update(line)

    unique_chars = set(counter.keys())

    if "_" in unique_chars:
        unique_chars.remove("_")

    sorted_chars = sorted(unique_chars, key=lambda char: counter[char], reverse=True)

    result = ["_"] + sorted_chars

    with open(token_file, "w", encoding="utf-8") as file:
        for index, char in enumerate(result):
            file.write(f"{char}\t{index}\n")


if __name__ == "__main__":
    args = get_args()
    prepare_tokens(args.manifest, args.tokens)
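The tokens file written by prepare_tokens() above is a two-column TSV mapping each character to an integer ID, with "_" pinned to ID 0 and the remaining characters ordered by descending frequency. As a minimal sketch (not part of the recipe; the path is only an example), it can be read back like this:

    def load_token_map(token_file: str) -> dict:
        """Load a {char: id} mapping from a tokens TSV file."""
        token2id = {}
        with open(token_file, encoding="utf-8") as f:
            for line in f:
                token, idx = line.rstrip("\n").split("\t")
                token2id[token] = int(idx)
        return token2id

    # e.g. load_token_map("data/tokens_custom.txt")["_"] == 0 by construction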
egs/zipvoice/local/prepare_token_file_emilia.py
ADDED
@@ -0,0 +1,91 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Zengwei Yao,
#                                        Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file generates the file that maps tokens to IDs.
"""

import argparse
import logging
from pathlib import Path
from typing import List

from piper_phonemize import get_espeak_map
from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--tokens",
        type=Path,
        default=Path("data/tokens_emilia.txt"),
        help="Path to the dict that maps the text tokens to IDs",
    )

    parser.add_argument(
        "--pinyin",
        type=Path,
        default=Path("resources/pinyin.txt"),
        help="Path to the file of all unique pinyin",
    )

    return parser.parse_args()


def get_pinyin_tokens(pinyin: Path) -> List[str]:
    phones = set()
    with open(pinyin, "r") as f:
        for line in f:
            x = line.strip()
            initial = to_initials(x, strict=False)
            # don't want to share tokens with espeak tokens, so use tone3 style
            finals = to_finals_tone3(x, strict=False, neutral_tone_with_five=True)
            if initial != "":
                # don't want to share tokens with espeak tokens,
                # so add a '0' after each initial
                phones.add(initial + "0")
            if finals != "":
                phones.add(finals)
    return sorted(phones)


def get_token2id(args):
    """Get a dict that maps token to IDs, and save it to the given filename."""
    all_tokens = get_espeak_map()  # token: [token_id]
    all_tokens = {token: token_id[0] for token, token_id in all_tokens.items()}
    # sort by token_id
    all_tokens = sorted(all_tokens.items(), key=lambda x: x[1])

    all_pinyin = get_pinyin_tokens(args.pinyin)
    with open(args.tokens, "w", encoding="utf-8") as f:
        for token, token_id in all_tokens:
            f.write(f"{token}\t{token_id}\n")
        num_espeak_tokens = len(all_tokens)
        for i, pinyin in enumerate(all_pinyin):
            f.write(f"{pinyin}\t{num_espeak_tokens + i}\n")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    args = get_args()
    get_token2id(args)
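To make the initial/final split concrete, here is a minimal sketch of what get_pinyin_tokens() does to one syllable from the pinyin list, e.g. "dang4" (assumptions: pypinyin is installed; the commented values are the expected outputs for this input):

    from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials

    syllable = "dang4"  # one entry of local/pinyin.txt
    initial = to_initials(syllable, strict=False)  # expected: "d"
    final = to_finals_tone3(
        syllable, strict=False, neutral_tone_with_five=True
    )  # expected: "ang4"
    # "d" becomes the token "d0" (the trailing "0" avoids clashes with espeak
    # tokens), while "ang4" is kept as-is. A toneless entry like "dang" should
    # yield a neutral-tone final "ang5" because neutral_tone_with_five=True.
    tokens = [initial + "0", final]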
egs/zipvoice/local/prepare_tokens_emilia.py
ADDED
@@ -0,0 +1,88 @@
"""
This file reads the texts in the given manifests and saves new cuts with
phoneme tokens.
"""

import argparse
import glob
import logging
from concurrent.futures import ProcessPoolExecutor as Pool
from pathlib import Path

from lhotse import load_manifest_lazy

from zipvoice.tokenizer.tokenizer import add_tokens


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--subset",
        type=str,
        help="Subset of emilia, (ZH, EN, etc.)",
    )

    parser.add_argument(
        "--jobs",
        type=int,
        default=50,
        help="Number of parallel jobs for processing.",
    )

    parser.add_argument(
        "--source-dir",
        type=str,
        default="data/manifests/splits",
        help="The source directory of manifest files.",
    )

    parser.add_argument(
        "--dest-dir",
        type=str,
        help="The destination directory of manifest files.",
    )

    return parser.parse_args()


def prepare_tokens_emilia(file_name: str, input_dir: Path, output_dir: Path):
    logging.info(f"Processing {file_name}")
    if (output_dir / file_name).is_file():
        logging.info(f"{file_name} exists, skipping.")
        return

    try:
        cut_set = load_manifest_lazy(input_dir / file_name)
        cut_set = add_tokens(cut_set=cut_set, tokenizer="emilia")
        cut_set.to_file(output_dir / file_name)
    except Exception as e:
        logging.error(f"Manifest {file_name} failed with error: {e}")
        raise


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    args = get_args()

    input_dir = Path(args.source_dir)
    output_dir = Path(args.dest_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    cut_files = glob.glob(f"{args.source_dir}/emilia_cuts_{args.subset}.*.jsonl.gz")

    with Pool(max_workers=args.jobs) as pool:
        futures = [
            pool.submit(
                prepare_tokens_emilia, filename.split("/")[-1], input_dir, output_dir
            )
            for filename in cut_files
        ]
        for f in futures:
            try:
                f.result()
            except Exception as e:
                logging.error(f"Future failed with error: {e}")
    logging.info("Processing done.")
egs/zipvoice/local/preprocess_emilia.py
ADDED
@@ -0,0 +1,210 @@
#!/usr/bin/env python3
# Copyright 2024-2025 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file reads the texts in the given manifest and saves the cleaned new cuts.
"""

import argparse
import glob
import logging
import os
import re
import unicodedata
from concurrent.futures import ProcessPoolExecutor as Pool
from pathlib import Path
from typing import List

from lhotse import load_manifest_lazy


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--subset",
        type=str,
        help="Subset of emilia, (ZH, EN, etc.)",
    )

    parser.add_argument(
        "--jobs",
        type=int,
        default=20,
        help="Number of parallel jobs for processing.",
    )

    parser.add_argument(
        "--source-dir",
        type=str,
        default="data/manifests/splits_raw",
        help="The source directory of manifest files.",
    )

    parser.add_argument(
        "--dest-dir",
        type=str,
        default="data/manifests/splits",
        help="The destination directory of manifest files.",
    )

    return parser.parse_args()


def tokenize_by_CJK_char(text: str) -> List[str]:
    """
    Tokenize a line of text into CJK chars and non-CJK words.

    Example:
      input = "你好世界是 hello world 的中文"
      output = ["你", "好", "世", "界", "是", "hello", "world", "的", "中", "文"]
    """
    pattern = re.compile(
        r"([\u1100-\u11ff"
        r"\u2e80-\ua4cf"
        r"\ua840-\uD7AF"
        r"\uF900-\uFAFF"
        r"\uFE30-\uFE4F"
        r"\uFF65-\uFFDC"
        r"\U00020000-\U0002FFFF])"
    )
    chars = pattern.split(text.strip())
    merged = " ".join([w.strip() for w in chars if w.strip()])
    return merged.split()


def is_hangul(char):
    letters = unicodedata.normalize("NFD", char)
    return all(
        ["\u1100" <= c <= "\u11ff" or "\u3131" <= c <= "\u318e" for c in letters]
    )


def is_japanese(char):
    return any(
        [
            start <= char <= end
            for start, end in [
                ("\u3041", "\u3096"),
                ("\u30a0", "\u30ff"),
                ("\uff5f", "\uff9f"),
                ("\u31f0", "\u31ff"),
                ("\u3220", "\u3243"),
                ("\u3280", "\u337f"),
            ]
        ]
    )


def is_chinese(char):
    if char >= "\u4e00" and char <= "\u9fa5":
        return True
    else:
        return False


def is_alphabet(char):
    if (char >= "\u0041" and char <= "\u005a") or (
        char >= "\u0061" and char <= "\u007a"
    ):
        return True
    else:
        return False


def preprocess_emilia(file_name: str, input_dir: Path, output_dir: Path):
    logging.info(f"Processing {file_name}")
    if (output_dir / file_name).is_file():
        logging.info(f"{file_name} exists, skipping.")
        return

    def _filter_cut(cut):
        text = cut.supervisions[0].text
        duration = cut.supervisions[0].duration
        chinese = []
        english = []

        # only keep chinese chars, spaces, and alphabets
        clean_chars = []
        for x in text:
            if is_hangul(x):
                logging.warning(f"Delete cut with text containing Korean : {text}")
                return False
            if is_japanese(x):
                logging.warning(f"Delete cut with text containing Japanese : {text}")
                return False
            if is_chinese(x):
                chinese.append(x)
                clean_chars.append(x)
            if is_alphabet(x):
                english.append(x)
                clean_chars.append(x)
            if x == " ":
                clean_chars.append(x)
        if len(english) + len(chinese) == 0:
            logging.warning(f"Delete cut whose text has no valid chars : {text}")
            return False

        words = tokenize_by_CJK_char("".join(clean_chars))
        for i in range(len(words) - 10):
            if words[i : i + 10].count(words[i]) == 10:
                logging.warning(f"Delete cut whose text has too many repeats : {text}")
                return False
        # word speed, 20 - 600 words / minute
        if duration < len(words) / 600 * 60 or duration > len(words) / 20 * 60:
            logging.warning(
                f"Delete cut with audio text mismatch, duration : {duration}s, "
                f"words : {len(words)}, text : {text}"
            )
            return False
        return True

    try:
        cut_set = load_manifest_lazy(input_dir / file_name)
        cut_set = cut_set.filter(_filter_cut)
        cut_set.to_file(output_dir / file_name)
    except Exception as e:
        logging.error(f"Manifest {file_name} failed with error: {e}")
        # Remove a partially written output file so that a rerun redoes this split.
        if (output_dir / file_name).is_file():
            os.remove(str(output_dir / file_name))


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    args = get_args()

    input_dir = Path(args.source_dir)
    output_dir = Path(args.dest_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    cut_files = glob.glob(f"{args.source_dir}/emilia_cuts_{args.subset}.*.jsonl.gz")

    with Pool(max_workers=args.jobs) as pool:
        futures = [
            pool.submit(
                preprocess_emilia,
                filename.split("/")[-1],
                input_dir,
                output_dir,
            )
            for filename in cut_files
        ]
        for f in futures:
            f.result()
    logging.info("Processing done.")
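A worked example of the speech-rate filter in _filter_cut() above (the numbers are illustrative, not from the recipe): a cut passes only if its duration corresponds to a rate of 20 to 600 words per minute.

    # For a cut whose text tokenizes into 30 words:
    words = 30
    lower = words / 600 * 60  # 3.0 s: anything shorter implies > 600 words/minute
    upper = words / 20 * 60   # 90.0 s: anything longer implies < 20 words/minute

    def keep(duration: float) -> bool:
        return lower <= duration <= upper

    assert keep(10.0) and not keep(2.0) and not keep(120.0)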
egs/zipvoice/run_custom.sh
ADDED
@@ -0,0 +1,138 @@
#!/bin/bash

# This script is an example of training ZipVoice on your custom datasets from scratch.

# Add project root to PYTHONPATH
export PYTHONPATH=../../:$PYTHONPATH

# Set bash to 'debug' mode, it will exit on:
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail

stage=1
stop_stage=6

# Number of jobs for data preparation
nj=20

# You can set `train_hours` and `max_len` according to statistics from
# the command `lhotse cut describe data/fbank/custom_cuts_train.jsonl.gz`.
# Set `train_hours` to "Total speech duration", and set `max_len` to the
# 99th-percentile duration.

# Number of hours in the training set; this affects the learning rate schedule.
train_hours=500
# Maximum length (seconds) of a training utterance; longer utterances are filtered out.
max_len=20

# We suppose you have two TSV files: "data/raw/custom_train.tsv" and
# "data/raw/custom_dev.tsv", where "custom" is your dataset name,
# "train"/"dev" are used for training and validation respectively.

# Each line of the TSV files should be in one of the following formats
# (see the example sketch after this script):
# (1) `{uniq_id}\t{text}\t{wav_path}` if the text corresponds to the full wav,
# (2) `{uniq_id}\t{text}\t{wav_path}\t{start_time}\t{end_time}` if the text
#     corresponds to part of the wav. The start_time and end_time specify the
#     start and end times of the text within the wav, in seconds.
# > Note: {uniq_id} must be unique for each line.
for subset in train dev; do
  file_path=data/raw/custom_${subset}.tsv
  [ -f "$file_path" ] || { echo "Error: expect $file_path !" >&2; exit 1; }
done

### Prepare the training data (1 - 3)

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  echo "Stage 1: Prepare manifests for custom dataset from tsv files"

  for subset in train dev; do
    python3 -m zipvoice.bin.prepare_dataset \
      --tsv-path data/raw/custom_${subset}.tsv \
      --prefix custom \
      --subset ${subset} \
      --num-jobs ${nj} \
      --output-dir data/manifests
  done
  # The output manifest files are "data/manifests/custom_cuts_train.jsonl.gz"
  # and "data/manifests/custom_cuts_dev.jsonl.gz".

  # We did not add tokens to the manifests, as on-the-fly tokenization
  # with the simple tokenizer used in this example is not slow.
  # If you change to a complex tokenizer, e.g., with g2p and heavy text normalization,
  # you may need to add tokens to the manifests to speed up the training.
  # Refer to the fine-tuning example for adding tokens to the manifests.
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  echo "Stage 2: Compute Fbank for custom dataset"
  # You can skip this step and use `--on-the-fly-feats 1` in the training stage
  for subset in train dev; do
    python3 -m zipvoice.bin.compute_fbank \
      --source-dir data/manifests \
      --dest-dir data/fbank \
      --dataset custom \
      --subset ${subset} \
      --num-jobs ${nj}
  done
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  echo "Stage 3: Prepare tokens file for custom dataset"
  # In this example, we use the simplest tokenizer that
  # treats every character as a token.
  python3 ./local/prepare_token_file_char.py \
    --manifest data/manifests/custom_cuts_train.jsonl.gz \
    --tokens data/tokens_custom.txt
fi


### Training (4 - 5)

if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  echo "Stage 4: Train the ZipVoice model"

  [ -z "$train_hours" ] && { echo "Error: train_hours is not set!" >&2; exit 1; }
  [ -z "$max_len" ] && { echo "Error: max_len is not set!" >&2; exit 1; }

  # lr-hours is set according to `train_hours`,
  # i.e., lr_hours = 1000 * (train_hours ** 0.3).
  lr_hours=$(python3 -c "print(round(1000 * ($train_hours ** 0.3)))" )
  python3 -m zipvoice.bin.train_zipvoice \
    --world-size 4 \
    --use-fp16 1 \
    --num-iters 60000 \
    --max-duration 500 \
    --lr-hours ${lr_hours} \
    --max-len ${max_len} \
    --model-config conf/zipvoice_base.json \
    --tokenizer simple \
    --token-file data/tokens_custom.txt \
    --dataset custom \
    --train-manifest data/fbank/custom_cuts_train.jsonl.gz \
    --dev-manifest data/fbank/custom_cuts_dev.jsonl.gz \
    --exp-dir exp/zipvoice_custom
fi

if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  echo "Stage 5: Average the checkpoints for ZipVoice"
  python3 -m zipvoice.bin.generate_averaged_model \
    --iter 60000 \
    --avg 2 \
    --model-name zipvoice \
    --exp-dir exp/zipvoice_custom
  # The generated model is exp/zipvoice_custom/iter-60000-avg-2.pt
fi

### Inference with PyTorch models (6)

if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
  echo "Stage 6: Inference of the ZipVoice model"
  python3 -m zipvoice.bin.infer_zipvoice \
    --model-name zipvoice \
    --model-dir exp/zipvoice_custom \
    --checkpoint-name iter-60000-avg-2.pt \
    --tokenizer simple \
    --test-list test.tsv \
    --res-dir results/test_custom
fi
egs/zipvoice/run_emilia.sh
ADDED
@@ -0,0 +1,178 @@
#!/bin/bash

# This is an example script for training ZipVoice on Emilia dataset.

# This script covers data preparation, ZipVoice training,
# ZipVoice-Distill training, onnx export, and
# inference with all PyTorch and ONNX models.


# Add project root to PYTHONPATH
export PYTHONPATH=../../:$PYTHONPATH

# Set bash to 'debug' mode, it will exit on:
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail

stage=1
stop_stage=12

#### Prepare datasets (1)

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  echo "Stage 1: Data Preparation for Emilia dataset"
  bash local/prepare_emilia.sh
fi

### Training ZipVoice (2 - 3)

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  echo "Stage 2: Train the ZipVoice model"
  python3 -m zipvoice.bin.train_zipvoice \
    --world-size 8 \
    --use-fp16 1 \
    --num-epochs 11 \
    --max-duration 500 \
    --lr-hours 30000 \
    --model-config conf/zipvoice_base.json \
    --tokenizer emilia \
    --token-file data/tokens_emilia.txt \
    --dataset emilia \
    --manifest-dir data/fbank \
    --exp-dir exp/zipvoice
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  echo "Stage 3: Average the checkpoints for ZipVoice"
  python3 -m zipvoice.bin.generate_averaged_model \
    --epoch 11 \
    --avg 4 \
    --model-name zipvoice \
    --exp-dir exp/zipvoice
  # The generated model is exp/zipvoice/epoch-11-avg-4.pt
fi

#### (Optional) Training ZipVoice-Distill model (4 - 6)

if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  echo "Stage 4: Train the ZipVoice-Distill model (first stage)"
  python3 -m zipvoice.bin.train_zipvoice_distill \
    --world-size 8 \
    --use-fp16 1 \
    --num-iters 60000 \
    --max-duration 500 \
    --base-lr 0.0005 \
    --tokenizer emilia \
    --token-file data/tokens_emilia.txt \
    --dataset emilia \
    --manifest-dir data/fbank \
    --teacher-model exp/zipvoice/epoch-11-avg-4.pt \
    --distill-stage first \
    --exp-dir exp/zipvoice_distill_1stage
fi


if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  echo "Stage 5: Average the checkpoints for ZipVoice-Distill (first stage)"
  python3 -m zipvoice.bin.generate_averaged_model \
    --iter 60000 \
    --avg 7 \
    --model-name zipvoice_distill \
    --exp-dir exp/zipvoice_distill_1stage
  # The generated model is exp/zipvoice_distill_1stage/iter-60000-avg-7.pt
fi

if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
  echo "Stage 6: Train the ZipVoice-Distill model (second stage)"

  python3 -m zipvoice.bin.train_zipvoice_distill \
    --world-size 8 \
    --use-fp16 1 \
    --num-iters 2000 \
    --save-every-n 1000 \
    --max-duration 500 \
    --base-lr 0.0001 \
    --model-config conf/zipvoice_base.json \
    --tokenizer emilia \
    --token-file data/tokens_emilia.txt \
    --dataset emilia \
    --manifest-dir data/fbank \
    --teacher-model exp/zipvoice_distill_1stage/iter-60000-avg-7.pt \
    --distill-stage second \
    --exp-dir exp/zipvoice_distill
fi

### Export ONNX model (7 - 8)

if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
  echo "Stage 7: Export ZipVoice ONNX model"
  python3 -m zipvoice.bin.onnx_export \
    --model-name zipvoice \
    --model-dir exp/zipvoice/ \
    --checkpoint-name epoch-11-avg-4.pt \
    --onnx-model-dir exp/zipvoice/
fi

if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
  echo "Stage 8: Export ZipVoice-Distill ONNX model"
  python3 -m zipvoice.bin.onnx_export \
    --model-name zipvoice_distill \
    --model-dir exp/zipvoice_distill/ \
    --checkpoint-name checkpoint-2000.pt \
    --onnx-model-dir exp/zipvoice_distill/
fi


### Inference with PyTorch and ONNX models (9 - 12)

if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then
  echo "Stage 9: Inference of the ZipVoice model"
  python3 -m zipvoice.bin.infer_zipvoice \
    --model-name zipvoice \
    --model-dir exp/zipvoice/ \
    --checkpoint-name epoch-11-avg-4.pt \
    --tokenizer emilia \
    --test-list test.tsv \
    --res-dir results/test \
    --num-step 16 \
    --guidance-scale 1
fi


if [ ${stage} -le 10 ] && [ ${stop_stage} -ge 10 ]; then
  echo "Stage 10: Inference of the ZipVoice-Distill model"
  python3 -m zipvoice.bin.infer_zipvoice \
    --model-name zipvoice_distill \
    --model-dir exp/zipvoice_distill/ \
    --checkpoint-name checkpoint-2000.pt \
    --tokenizer emilia \
    --test-list test.tsv \
    --res-dir results/test_distill \
    --num-step 8 \
    --guidance-scale 3
fi


if [ ${stage} -le 11 ] && [ ${stop_stage} -ge 11 ]; then
  echo "Stage 11: Inference with ZipVoice ONNX model"
  python3 -m zipvoice.bin.infer_zipvoice_onnx \
    --model-name zipvoice \
    --onnx-int8 False \
    --model-dir exp/zipvoice \
    --tokenizer emilia \
    --test-list test.tsv \
    --res-dir results/test_onnx
fi

if [ ${stage} -le 12 ] && [ ${stop_stage} -ge 12 ]; then
  echo "Stage 12: Inference with ZipVoice-Distill ONNX model"
  python3 -m zipvoice.bin.infer_zipvoice_onnx \
    --model-name zipvoice_distill \
    --onnx-int8 False \
    --model-dir exp/zipvoice_distill \
    --tokenizer emilia \
    --test-list test.tsv \
    --res-dir results/test_distill_onnx
fi
egs/zipvoice/run_eval.sh
ADDED
@@ -0,0 +1,142 @@
#!/bin/bash

# This script is an example of evaluating TTS models with the objective metrics reported in the ZipVoice paper.

# Add project root to PYTHONPATH
export PYTHONPATH=../../:$PYTHONPATH

# Set bash to 'debug' mode, it will exit on:
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail

stage=1
stop_stage=7

download_dir=download/

# Uncomment this line to use HF mirror
# export HF_ENDPOINT=https://hf-mirror.com

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  echo "Stage 1: Download test sets (LibriSpeech-PC and Seed-TTS)"

  hf_repo=k2-fsa/TTS_eval_datasets
  mkdir -p ${download_dir}/
  for file in librispeech_pc_testset.tar.gz seedtts_testset.tar.gz; do
    echo "Downloading ${file}..."
    huggingface-cli download \
      --repo-type dataset \
      --local-dir ${download_dir}/ \
      ${hf_repo} \
      ${file}
    echo "Extracting ${file}..."
    tar -xzf ${download_dir}/${file} -C ${download_dir}/
  done
fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  echo "Stage 2: Download all required evaluation models"
  hf_repo=k2-fsa/TTS_eval_models
  mkdir -p ${download_dir}/tts_eval_models
  huggingface-cli download \
    --local-dir ${download_dir}/tts_eval_models \
    ${hf_repo}
fi


if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  echo "Stage 3: Inference with the pre-trained ZipVoice model from huggingface"

  for testset in librispeech_pc seedtts_en seedtts_zh; do

    if [ "$testset" = "librispeech_pc" ]; then
      test_tsv=${download_dir}/librispeech_pc_testset/test.tsv
    elif [ "$testset" = "seedtts_en" ]; then
      test_tsv=${download_dir}/seedtts_testset/en/test.tsv
    elif [ "$testset" = "seedtts_zh" ]; then
      test_tsv=${download_dir}/seedtts_testset/zh/test.tsv
    else
      echo "Error: unknown testset ${testset}" >&2
      exit 1
    fi
    echo "Inference on testset ${testset}..."
    python3 -m zipvoice.bin.infer_zipvoice \
      --model-name zipvoice \
      --test-list ${test_tsv} \
      --res-dir results/${testset}
  done
fi



if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  echo "Stage 4: Evaluation on LibriSpeech-PC"
  model_path=${download_dir}/tts_eval_models
  wav_path=results/librispeech_pc
  test_tsv=${download_dir}/librispeech_pc_testset/test.tsv
  # Use LibriSpeech style transcripts for WER evaluation
  transcript_tsv=${download_dir}/librispeech_pc_testset/transcript.tsv

  python3 -m zipvoice.eval.speaker_similarity.sim \
    --wav-path ${wav_path} \
    --test-list ${test_tsv} \
    --model-dir ${model_path}

  python3 -m zipvoice.eval.wer.hubert \
    --wav-path ${wav_path} \
    --test-list ${transcript_tsv} \
    --model-dir ${model_path}

  python3 -m zipvoice.eval.mos.utmos \
    --wav-path ${wav_path} \
    --model-dir ${model_path}
fi


if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  echo "Stage 5: Evaluation on Seed-TTS test en"
  model_path=${download_dir}/tts_eval_models
  wav_path=results/seedtts_en
  test_tsv=${download_dir}/seedtts_testset/en/test.tsv

  python3 -m zipvoice.eval.speaker_similarity.sim \
    --wav-path ${wav_path} \
    --test-list ${test_tsv} \
    --model-dir ${model_path}

  python3 -m zipvoice.eval.wer.seedtts \
    --wav-path ${wav_path} \
    --test-list ${test_tsv} \
    --model-dir ${model_path} \
    --lang en

  python3 -m zipvoice.eval.mos.utmos \
    --wav-path ${wav_path} \
    --model-dir ${model_path}
fi

if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
  echo "Stage 6: Evaluation on Seed-TTS test zh"
  model_path=${download_dir}/tts_eval_models
  wav_path=results/seedtts_zh
  test_tsv=${download_dir}/seedtts_testset/zh/test.tsv

  python3 -m zipvoice.eval.speaker_similarity.sim \
    --wav-path ${wav_path} \
    --test-list ${test_tsv} \
    --model-dir ${model_path}

  python3 -m zipvoice.eval.wer.seedtts \
    --wav-path ${wav_path} \
    --test-list ${test_tsv} \
    --model-dir ${model_path} \
    --lang zh

  python3 -m zipvoice.eval.mos.utmos \
    --wav-path ${wav_path} \
    --model-dir ${model_path}
fi
egs/zipvoice/run_finetune.sh
ADDED
@@ -0,0 +1,175 @@
#!/bin/bash

# This script is an example of fine-tuning ZipVoice on your custom datasets.

# Add project root to PYTHONPATH
# export PYTHONPATH=../../:$PYTHONPATH

# Set bash to 'debug' mode, it will exit on:
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail

stage=1
stop_stage=6

# Number of jobs for data preparation
nj=4

# Whether the language of the training data is Chinese or English
is_zh_en=0

# Language identifier, used when the language is not Chinese or English,
# see https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md
# Example for French: lang=fr
lang=vi

if [ $is_zh_en -eq 1 ]; then
  tokenizer=espeak
else
  tokenizer=espeak
  [ "$lang" = "default" ] && { echo "Error: lang is not set!" >&2; exit 1; }
fi

# You can set `max_len` according to statistics from the command
# `lhotse cut describe data/fbank/custom_cuts_train.jsonl.gz`.
# Set `max_len` to the 99th-percentile duration.

# Maximum length (seconds) of a training utterance; longer utterances are filtered out
max_len=25
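# For example (hypothetical numbers): if ~99% of your utterances are shorter
# than 25 seconds, max_len=25 is a reasonable choice.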

# Download directory for pre-trained models
download_dir=download

# We suppose you have two TSV files: "data/raw/custom_train.tsv" and
# "data/raw/custom_dev.tsv", where "custom" is your dataset name,
# "train"/"dev" are used for training and validation respectively.

# Each line of the TSV files should be in one of the following formats:
# (1) `{uniq_id}\t{text}\t{wav_path}` if the text corresponds to the full wav,
# (2) `{uniq_id}\t{text}\t{wav_path}\t{start_time}\t{end_time}` if the text corresponds
#     to part of the wav. The start_time and end_time specify the start and end
#     times of the text within the wav, which should be in seconds.
# > Note: {uniq_id} must be unique for each line.
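# Hypothetical examples (ids and paths are made up; fields are tab-separated):
#   utt_0001\txin chào\t/data/wavs/utt_0001.wav
#   utt_0002\tcảm ơn\t/data/wavs/long_recording.wav\t3.25\t7.80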
# for subset in train dev; do
#   file_path=data/raw/custom_${subset}.tsv
#   [ -f "$file_path" ] || { echo "Error: expect $file_path !" >&2; exit 1; }
# done

# ### Prepare the training data (1 - 4)

# if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
#   echo "Stage 1: Prepare manifests for custom dataset from tsv files"
#
#   for subset in train dev; do
#     python3 -m zipvoice.bin.prepare_dataset \
#       --tsv-path data/raw/custom_${subset}.tsv \
#       --prefix custom-finetune \
#       --subset raw_${subset} \
#       --num-jobs ${nj} \
#       --output-dir data/manifests
#   done
#   # The output manifest files are "data/manifests/custom-finetune_cuts_raw_train.jsonl.gz"
#   # and "data/manifests/custom-finetune_cuts_raw_dev.jsonl.gz".
# fi


# if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
#   echo "Stage 2: Add tokens to manifests"
#   # For "emilia" and "espeak" tokenizers, it's better to prepare the tokens
#   # before training. Otherwise, the on-the-fly tokenization can significantly
#   # slow down the training.
#   for subset in train dev; do
#     python3 -m zipvoice.bin.prepare_tokens \
#       --input-file data/manifests/custom-finetune_cuts_raw_${subset}.jsonl.gz \
#       --output-file data/manifests/custom-finetune_cuts_${subset}.jsonl.gz \
#       --tokenizer ${tokenizer} \
#       --lang ${lang}
#   done
#   # The output manifest files are "data/manifests/custom-finetune_cuts_train.jsonl.gz"
#   # and "data/manifests/custom-finetune_cuts_dev.jsonl.gz".
# fi

# if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
#   echo "Stage 3: Compute Fbank for custom dataset"
#   # You can skip this step and use `--on-the-fly-feats 1` in the training stage
#   for subset in train dev; do
#     python3 -m zipvoice.bin.compute_fbank \
#       --source-dir data/manifests \
#       --dest-dir data/fbank \
#       --dataset custom-finetune \
#       --subset ${subset} \
#       --num-jobs ${nj}
#   done
# fi

# # if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# #   echo "Stage 4: Download pre-trained model, tokens file, and model config"
# #   # Uncomment this line to use HF mirror
# #   # export HF_ENDPOINT=https://hf-mirror.com
# #   hf_repo=k2-fsa/ZipVoice
# #   mkdir -p ${download_dir}
# #   for file in model.pt tokens.txt model.json; do
# #     huggingface-cli download \
# #       --local-dir ${download_dir} \
# #       ${hf_repo} \
# #       zipvoice/${file}
# #   done
# # fi

# # ### Training ZipVoice (5 - 6)

if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  echo "Stage 5: Fine-tune the ZipVoice model"

  [ -z "$max_len" ] && { echo "Error: max_len is not set!" >&2; exit 1; }

  python3 -m zipvoice.bin.train_zipvoice \
    --world-size 1 \
    --use-fp16 1 \
    --finetune 1 \
    --base-lr 0.00006 \
    --num-epochs 2 \
    --save-every-n 1000 \
    --keep-last-k 4 \
    --max-duration 650 \
    --max-len ${max_len} \
    --min-len 1 \
    --model-config ${download_dir}/zipvoice/model.json \
    --checkpoint ${download_dir}/zipvoice/model.pt \
    --tokenizer ${tokenizer} \
    --lang ${lang} \
    --token-file ${download_dir}/zipvoice/tokens.txt \
    --dataset custom \
    --train-manifest data/fbank/train_all.jsonl.gz \
    --dev-manifest data/fbank/dev_all.jsonl.gz \
    --exp-dir exp/zipvoice_finetune

fi

# if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
#   echo "Stage 6: Average the checkpoints for ZipVoice"
#   python3 -m zipvoice.bin.generate_averaged_model \
#     --iter 10000 \
#     --avg 2 \
#     --model-name zipvoice \
#     --exp-dir exp/zipvoice_finetune
#   # The generated model is exp/zipvoice_finetune/iter-10000-avg-2.pt
# fi

# ### Inference with PyTorch models (7)

# if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
#   echo "Stage 7: Inference of the ZipVoice model"
#
#   python3 -m zipvoice.bin.infer_zipvoice \
#     --model-name zipvoice \
#     --model-dir exp/zipvoice_finetune/ \
#     --checkpoint-name iter-10000-avg-2.pt \
#     --tokenizer ${tokenizer} \
#     --lang ${lang} \
#     --test-list test.tsv \
#     --res-dir results/test_finetune \
#     --num-step 16
# fi
egs/zipvoice/run_libritts.sh
ADDED
@@ -0,0 +1,148 @@
#!/bin/bash

# This is an example script for training ZipVoice on LibriTTS dataset.

# Add project root to PYTHONPATH
export PYTHONPATH=../../:$PYTHONPATH

# Set bash to 'debug' mode, it will exit on:
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail

stage=1
stop_stage=9

#### Prepare datasets (1)

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  echo "Stage 1: Data Preparation for LibriTTS dataset"
  bash local/prepare_libritts.sh
fi

### Training ZipVoice (2 - 3)

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  echo "Stage 2: Train the ZipVoice model"
  python3 -m zipvoice.bin.train_zipvoice \
    --world-size 8 \
    --use-fp16 0 \
    --num-epochs 60 \
    --max-duration 250 \
    --lr-epochs 10 \
    --max-len 20 \
    --valid-by-epoch 1 \
    --model-config conf/zipvoice_base.json \
    --tokenizer libritts \
    --token-file data/tokens_libritts.txt \
    --dataset libritts \
    --manifest-dir data/fbank \
    --exp-dir exp/zipvoice_libritts
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  echo "Stage 3: Average the checkpoints for ZipVoice"
  python3 -m zipvoice.bin.generate_averaged_model \
    --epoch 60 \
    --avg 10 \
    --model-name zipvoice \
    --exp-dir exp/zipvoice_libritts
  # The generated model is exp/zipvoice_libritts/epoch-60-avg-10.pt
fi

#### (Optional) Training ZipVoice-Distill model (4 - 7)

if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  echo "Stage 4: Train the ZipVoice-Distill model (first stage)"
  python3 -m zipvoice.bin.train_zipvoice_distill \
    --world-size 8 \
    --use-fp16 0 \
    --num-epochs 6 \
    --max-duration 250 \
    --base-lr 0.001 \
    --max-len 20 \
    --valid-by-epoch 1 \
    --model-config conf/zipvoice_base.json \
    --tokenizer libritts \
    --token-file data/tokens_libritts.txt \
    --dataset "libritts" \
    --manifest-dir "data/fbank" \
    --teacher-model exp/zipvoice_libritts/epoch-60-avg-10.pt \
    --distill-stage "first" \
    --exp-dir exp/zipvoice_distill_1stage_libritts
fi


if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  echo "Stage 5: Average the checkpoints for ZipVoice-Distill (first stage)"
  python3 -m zipvoice.bin.generate_averaged_model \
    --epoch 6 \
    --avg 3 \
    --model-name zipvoice_distill \
    --exp-dir exp/zipvoice_distill_1stage_libritts
  # The generated model is exp/zipvoice_distill_1stage_libritts/epoch-6-avg-3.pt
fi

if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
  echo "Stage 6: Train the ZipVoice-Distill model (second stage)"

  python3 -m zipvoice.bin.train_zipvoice_distill \
    --world-size 8 \
    --use-fp16 1 \
    --num-epochs 6 \
    --max-duration 250 \
    --base-lr 0.001 \
    --max-len 20 \
    --valid-by-epoch 1 \
    --model-config conf/zipvoice_base.json \
    --tokenizer libritts \
    --token-file data/tokens_libritts.txt \
    --dataset libritts \
    --manifest-dir data/fbank \
    --teacher-model exp/zipvoice_distill_1stage_libritts/epoch-6-avg-3.pt \
    --distill-stage second \
    --exp-dir exp/zipvoice_distill_libritts
fi


if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
  echo "Stage 7: Average the checkpoints for ZipVoice-Distill (second stage)"
  python3 -m zipvoice.bin.generate_averaged_model \
    --epoch 6 \
    --avg 3 \
    --model-name zipvoice_distill \
    --exp-dir exp/zipvoice_distill_libritts
  # The generated model is exp/zipvoice_distill_libritts/epoch-6-avg-3.pt
fi

### Inference with PyTorch models (8 - 9)

if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
  echo "Stage 8: Inference of the ZipVoice model"
  python3 -m zipvoice.bin.infer_zipvoice \
    --model-name zipvoice \
    --model-dir exp/zipvoice_libritts \
    --checkpoint-name epoch-60-avg-10.pt \
    --tokenizer libritts \
    --test-list test.tsv \
    --res-dir results/test_libritts \
    --num-step 8 \
    --guidance-scale 1 \
    --t-shift 0.7
fi


if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then
  echo "Stage 9: Inference of the ZipVoice-Distill model"
  python3 -m zipvoice.bin.infer_zipvoice \
    --model-name zipvoice_distill \
    --model-dir exp/zipvoice_distill_libritts \
    --checkpoint-name epoch-6-avg-3.pt \
    --tokenizer libritts \
    --test-list test.tsv \
    --res-dir results/test_distill_libritts \
    --num-step 4 \
    --guidance-scale 3 \
    --t-shift 0.7
fi
egs/zipvoice/utils/parse_options.sh
ADDED
@@ -0,0 +1,97 @@
#!/usr/bin/env bash

# Copyright 2012  Johns Hopkins University (Author: Daniel Povey);
#                 Arnab Ghoshal, Karel Vesely

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.


# Parse command-line options.
# To be sourced by another script (as in ". parse_options.sh").
# Option format is: --option-name arg
# and shell variable "option_name" gets set to value "arg."
# The exception is --help, which takes no arguments, but prints the
# $help_message variable (if defined).
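#
# A minimal usage sketch (hypothetical script and variable names): define
# defaults first, then source this file so command-line flags override them:
#
#   stage=1
#   nj=4
#   . ./utils/parse_options.sh
#   # running "bash myscript.sh --stage 3 --nj 8" now sets stage=3 and nj=8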


###
### The --config file options have lower priority to command line
### options, so we need to import them first...
###

# Now import all the configs specified by command-line, in left-to-right order
for ((argpos=1; argpos<$#; argpos++)); do
  if [ "${!argpos}" == "--config" ]; then
    argpos_plus1=$((argpos+1))
    config=${!argpos_plus1}
    [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
    . $config  # source the config file.
  fi
done


###
### Now we process the command line options
###
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    # If the enclosing script is called with --help option, print the help
    # message and exit.  Scripts should put help messages in $help_message
    --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
      else printf "$help_message\n" 1>&2 ; fi;
      exit 0 ;;
    --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
      exit 1 ;;
    # If the first command-line argument begins with "--" (e.g. --foo-bar),
    # then work out the variable name as $name, which will equal "foo_bar".
    --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
      # Next we test whether the variable in question is undefined -- if so it's
      # an invalid option and we die.  Note: $0 evaluates to the name of the
      # enclosing script.
      # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
      # is undefined.  We then have to wrap this test inside "eval" because
      # foo_bar is itself inside a variable ($name).
      eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;

      oldval="`eval echo \\$$name`";
      # Work out whether we seem to be expecting a Boolean argument.
      if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
        was_bool=true;
      else
        was_bool=false;
      fi

      # Set the variable to the right value-- the escaped quotes make it work if
      # the option had spaces, like --cmd "queue.pl -sync y"
      eval $name=\"$2\";

      # Check that Boolean-valued arguments are really Boolean.
      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
        exit 1;
      fi
      shift 2;
      ;;
    *) break;
  esac
done


# Check for an empty argument to the --cmd option, which can easily occur as a
# result of scripting errors.
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;


true; # so this script returns exit code 0.
egs/zipvoice/utils/validate_manifest.py
ADDED
@@ -0,0 +1,70 @@
#!/usr/bin/env python3
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
#                                            Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks the following assumptions of the generated manifest:

- Single supervision per cut

We will add more checks later if needed.

Usage example:

    python3 ./utils/validate_manifest.py \
        ./data/spectrogram/ljspeech_cuts_all.jsonl.gz

"""

import argparse
import logging
from pathlib import Path

from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset.speech_synthesis import validate_for_tts


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "manifest",
        type=Path,
        help="Path to the manifest file",
    )

    return parser.parse_args()


def main():
    args = get_args()

    manifest = args.manifest
    logging.info(f"Validating {manifest}")

    assert manifest.is_file(), f"{manifest} does not exist"
    cut_set = load_manifest_lazy(manifest)
    assert isinstance(cut_set, CutSet), type(cut_set)

    validate_for_tts(cut_set)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    main()
egs/zipvoice_dialog/README.md
ADDED
@@ -0,0 +1,12 @@
# ZipVoice-Dialog Recipe

This recipe contains the following examples:

- Training ZipVoice-Dialog on the OpenDialog dataset, see [run_opendialog.sh](run_opendialog.sh).
- Training ZipVoice-Dialog on custom datasets (Chinese/English), see [run_custom.sh](run_custom.sh).
- Fine-tuning pre-trained ZipVoice-Dialog on custom datasets (Chinese/English), see [run_finetune.sh](run_finetune.sh).
- Evaluating models with the objective metrics reported in the ZipVoice-Dialog paper, see [run_eval.sh](run_eval.sh).

> **NOTE:** For evaluation, first install packages from [../../requirements_eval.txt](../../requirements_eval.txt)
>
> `pip install -r ../../requirements_eval.txt`
egs/zipvoice_dialog/local/prepare_opendialog.py
ADDED
@@ -0,0 +1,262 @@
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script prepares lhotse manifest files from the raw OpenDialog datasets.

We assume that you have downloaded the OpenDialog dataset and untarred the
tar files in audio/en and audio/zh so that the mp3 files are placed under
these two directories.

Download OpenDialog at https://huggingface.co/datasets/k2-fsa/OpenDialog
or https://www.modelscope.cn/datasets/k2-fsa/OpenDialog
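
A sketch of the expected layout and manifest format (field names are taken
from the parsing code below; the concrete values are hypothetical):

    <dataset-path>/manifest.en.jsonl
    <dataset-path>/manifest.zh.jsonl
    <dataset-path>/audio/en/*.mp3
    <dataset-path>/audio/zh/*.mp3

Each line of manifest.{lang}.jsonl is a JSON object with "id", "text", and
"path" fields, where "path" is relative to the manifest's directory, e.g.
{"id": "en_000001", "text": "[S1] Hi. [S2] Hello.", "path": "audio/en/en_000001.mp3"}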

"""

import argparse
import json
import logging
import math
import re
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pathlib import Path
from typing import List, Optional, Tuple

from lhotse import CutSet, validate_recordings_and_supervisions
from lhotse.audio import Recording, RecordingSet
from lhotse.cut import Cut
from lhotse.qa import fix_manifests
from lhotse.supervision import SupervisionSegment, SupervisionSet
from lhotse.utils import Pathlike
from tqdm.auto import tqdm


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--dataset-path",
        type=str,
        help="The path of OpenDialog dataset.",
    )

    parser.add_argument(
        "--num-jobs",
        type=int,
        default=20,
        help="Number of jobs for processing.",
    )

    parser.add_argument(
        "--output-dir",
        type=str,
        default="data/manifests",
        help="The destination directory of manifest files.",
    )
    parser.add_argument(
        "--sampling-rate",
        type=int,
        default=24000,
        help="The target sampling rate.",
    )
    return parser.parse_args()


def _parse_recording(
    wav_path: str,
) -> Tuple[Recording, str]:
    """
    :param wav_path: Path to the audio file
    :return: a tuple of "recording" and "recording_id"
    """

    recording_id = Path(wav_path).stem
    recording = Recording.from_file(path=wav_path, recording_id=recording_id)

    return recording, recording_id


def _parse_supervision(
    supervision: List, recording_dict: dict
) -> Optional[SupervisionSegment]:
    """
    :param supervision: A tuple of (uniq_id, text, wav_path, start, end)
    :param recording_dict: Dictionary mapping recording IDs to Recording objects
    :return: A SupervisionSegment object, or None if parsing fails
    """

    def _round_down(num, ndigits=0):
        factor = 10**ndigits
        return math.floor(num * factor) / factor
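    # For example, _round_down(3.14159, ndigits=2) == 3.14 (floor, not round).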

    uniq_id, text, wav_path, start, end = supervision
    try:
        recording_id = Path(wav_path).stem

        recording = recording_dict[recording_id]
        duration = (
            _round_down(end - start, ndigits=8)
            if end is not None
            else _round_down(recording.duration, ndigits=8)
        )
        assert duration <= recording.duration, (
            f"Duration {duration} is greater than "
            f"recording duration {recording.duration}"
        )

        text = re.sub("_", " ", text)  # "_" is treated as padding symbol
        text = re.sub(r"\s+", " ", text)  # remove extra whitespace

        return SupervisionSegment(
            id=f"{uniq_id}",
            recording_id=recording.id,
            start=start,
            duration=duration,
            channel=recording.channel_ids,
            text=text.strip(),
        )
    except Exception as e:
        logging.info(f"Error processing line: {e}")
        return None


def prepare_subset(
    jsonl_path: Pathlike,
    lang: str,
    sampling_rate: int,
    num_jobs: int,
    output_dir: Pathlike,
):
    """
    Returns the manifests which consist of the Recordings and Supervisions

    :param jsonl_path: Path to the jsonl file
    :param lang: Language of the subset
    :param sampling_rate: Target sampling rate of the audio
    :param num_jobs: Number of processes for parallel processing
    :param output_dir: Path where to write the manifests
    """
    logging.info(f"Preparing {lang} subset")

    # Step 1: Read all unique recording paths
    logging.info(f"Reading {jsonl_path}")
    recordings_path_set = set()
    supervision_list = list()
    with open(jsonl_path, "r") as fr:
        for line in fr:
            try:
                items = json.loads(line)
                uniq_id, text, wav_path = items["id"], items["text"], items["path"]
                start, end = 0, None
                recordings_path_set.add(jsonl_path.parent / wav_path)
                supervision_list.append((uniq_id, text, wav_path, start, end))
            except Exception as e:
                logging.warning(f"Error {e} when decoding JSON line: {line}")
                continue
    logging.info("Starting to process recordings...")
    # Step 2: Process recordings. Futures are kept alongside their wav_path
    # so that failures can be reported against the right file.
    futures = []
    recording_dict = {}
    with ThreadPoolExecutor(max_workers=num_jobs) as ex:
        for wav_path in tqdm(recordings_path_set, desc="Submitting jobs"):
            futures.append((wav_path, ex.submit(_parse_recording, wav_path)))

        for wav_path, future in tqdm(futures, desc="Processing recordings"):
            try:
                recording, recording_id = future.result()
                recording_dict[recording_id] = recording
            except Exception as e:
                logging.warning(
                    f"Error processing recording {wav_path} with error: {e}"
                )

    recording_set = RecordingSet.from_recordings(recording_dict.values())

    logging.info("Starting to process supervisions...")
    # Step 3: Process supervisions
    supervisions = []
    for supervision in tqdm(supervision_list, desc="Processing supervisions"):
        seg = _parse_supervision(supervision, recording_dict)
        if seg is not None:
            supervisions.append(seg)

    logging.info("Processing Cuts...")

    # Step 4: Create and validate manifests
    supervision_set = SupervisionSet.from_segments(supervisions)

    recording_set, supervision_set = fix_manifests(recording_set, supervision_set)
    validate_recordings_and_supervisions(recording_set, supervision_set)

    cut_set = CutSet.from_manifests(
        recordings=recording_set, supervisions=supervision_set
    )
    cut_set = cut_set.sort_by_recording_id()
    if sampling_rate != 24000:
        # All OpenDialog audios are 24kHz
        cut_set = cut_set.resample(sampling_rate)
    cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)

    logging.info("Saving cuts to disk...")
    # Step 5: Write manifests to disk
    cut_set.to_file(output_dir / f"opendialog_cuts_raw_{lang.upper()}-all.jsonl.gz")
    dev_cut_set = cut_set.subset(first=1000)
    dev_cut_set.to_file(output_dir / f"opendialog_cuts_raw_{lang.upper()}-dev.jsonl.gz")

    def remove_dev(c: Cut, set: set):
        if c.id in set:
            return False
        return True

    _remove_dev = partial(remove_dev, set=set(dev_cut_set.ids))
    train_cut_set = cut_set.filter(_remove_dev)
    train_cut_set.to_file(
        output_dir / f"opendialog_cuts_raw_{lang.upper()}-train.jsonl.gz"
    )


def prepare_dataset(
    dataset_path: Pathlike,
    sampling_rate: int,
    num_jobs: int,
    output_dir: Pathlike,
):
    for lang in ["en", "zh"]:
        jsonl_path = dataset_path / f"manifest.{lang}.jsonl"
        prepare_subset(
            jsonl_path=jsonl_path,
            lang=lang,
            sampling_rate=sampling_rate,
            num_jobs=num_jobs,
            output_dir=output_dir,
        )


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    args = get_args()
    dataset_path = Path(args.dataset_path)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    prepare_dataset(
        dataset_path=dataset_path,
        sampling_rate=args.sampling_rate,
        num_jobs=args.num_jobs,
        output_dir=output_dir,
    )
egs/zipvoice_dialog/run_custom.sh
ADDED
@@ -0,0 +1,145 @@
#!/bin/bash

# This script is an example of training ZipVoice-Dialog on your custom datasets.
# Only English and Chinese are supported for now.

# Add project root to PYTHONPATH
export PYTHONPATH=../../:$PYTHONPATH

# Set bash to 'debug' mode, it will exit on:
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail

stage=1
stop_stage=6

# Number of jobs for data preparation
nj=20
download_dir=download/

# Maximum length (seconds) of a training utterance; longer utterances are filtered out
max_len=60

# We suppose you have two TSV files: "data/raw/custom_train.tsv" and
# "data/raw/custom_dev.tsv", where "custom" is your dataset name,
# "train"/"dev" are used for training and validation respectively.

# Each line of the TSV files should be in one of the following formats:
# (1) `{uniq_id}\t{text}\t{wav_path}` if the text corresponds to the full wav,
# (2) `{uniq_id}\t{text}\t{wav_path}\t{start_time}\t{end_time}` if the text corresponds
#     to part of the wav. The start_time and end_time specify the start and end
#     times of the text within the wav, which should be in seconds.
# > Note: {uniq_id} must be unique for each line.
# > Note: {text} uses [S1] and [S2] tags to distinguish speakers, and must begin with [S1].
# > eg: "[S1] Hello. [S2] How are you? [S1] I'm fine. [S2] What's your name?"
for subset in train dev; do
  file_path=data/raw/custom_${subset}.tsv
  [ -f "$file_path" ] || { echo "Error: expect $file_path !" >&2; exit 1; }
done


if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  echo "Stage 1: Prepare manifests for custom dataset from tsv files"

  for subset in train dev; do
    python3 -m zipvoice.bin.prepare_dataset \
      --tsv-path data/raw/custom_${subset}.tsv \
      --prefix custom \
      --subset raw_${subset} \
      --num-jobs ${nj} \
      --output-dir data/manifests
  done
  # The output manifest files are "data/manifests/custom_cuts_raw_train.jsonl.gz"
  # and "data/manifests/custom_cuts_raw_dev.jsonl.gz".
fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  echo "Stage 2: Add tokens to manifests"
  for subset in train dev; do
    python3 -m zipvoice.bin.prepare_tokens \
      --input-file data/manifests/custom_cuts_raw_${subset}.jsonl.gz \
      --output-file data/manifests/custom_cuts_${subset}.jsonl.gz \
      --tokenizer dialog
  done
  # The output manifest files are "data/manifests/custom_cuts_train.jsonl.gz"
  # and "data/manifests/custom_cuts_dev.jsonl.gz".
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  echo "Stage 3: Compute Fbank for custom dataset"
  # You can skip this step and use `--on-the-fly-feats 1` in the training stage
  for subset in train dev; do
    python3 -m zipvoice.bin.compute_fbank \
      --source-dir data/manifests \
      --dest-dir data/fbank \
      --dataset custom \
      --subset ${subset} \
      --num-jobs ${nj}
  done
fi

if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  echo "Stage 4: Download tokens file, pretrained models"
  # Uncomment this line to use HF mirror
  # export HF_ENDPOINT=https://hf-mirror.com

  # The token file is obtained by extending some tokens
  # on the basis of the Emilia token file.
  mkdir -p ${download_dir}
  hf_repo=k2-fsa/ZipVoice
  huggingface-cli download \
    --local-dir ${download_dir} \
    ${hf_repo} \
    zipvoice_dialog/tokens.txt

  # The pre-trained ZipVoice model is required as
  # the initialization model.
  for file in model.pt tokens.txt model.json; do
    huggingface-cli download \
      --local-dir ${download_dir} \
      ${hf_repo} \
      zipvoice/${file}
  done
fi

if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  echo "Stage 5: Train the ZipVoice-Dialog model"
  python3 -m zipvoice.bin.train_zipvoice_dialog \
    --world-size 4 \
    --use-fp16 1 \
    --base-lr 0.0001 \
    --num-iters 60000 \
    --max-duration 500 \
    --max-len ${max_len} \
    --checkpoint ${download_dir}/zipvoice/model.pt \
    --model-config ${download_dir}/zipvoice/model.json \
    --token-file ${download_dir}/zipvoice_dialog/tokens.txt \
    --dataset custom \
    --train-manifest data/fbank/custom_cuts_train.jsonl.gz \
    --dev-manifest data/fbank/custom_cuts_dev.jsonl.gz \
    --exp-dir exp/zipvoice_dialog_custom
fi

if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
  echo "Stage 6: Average the checkpoints for ZipVoice-Dialog"
  python3 -m zipvoice.bin.generate_averaged_model \
    --iter 60000 \
    --avg 2 \
    --model-name zipvoice_dialog \
    --exp-dir exp/zipvoice_dialog_custom
  # The generated model is exp/zipvoice_dialog_custom/iter-60000-avg-2.pt
fi


if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
  echo "Stage 7: Inference of the ZipVoice-Dialog model"
  python3 -m zipvoice.bin.infer_zipvoice_dialog \
    --model-name zipvoice_dialog \
    --model-dir exp/zipvoice_dialog_custom \
    --checkpoint-name iter-60000-avg-2.pt \
    --test-list test.tsv \
    --res-dir results/test_dialog_custom
fi
egs/zipvoice_dialog/run_eval.sh
ADDED
@@ -0,0 +1,120 @@
#!/bin/bash

# This script is an example of evaluating TTS models with the objective metrics reported in the ZipVoice-Dialog paper.

# Add project root to PYTHONPATH
export PYTHONPATH=../../:$PYTHONPATH

# Set bash to 'debug' mode, it will exit on:
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail

stage=1
stop_stage=6

download_dir=download/

# Uncomment this line to use HF mirror
# export HF_ENDPOINT=https://hf-mirror.com

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  echo "Stage 1: Download test sets (test-dialog)"
  hf_repo=k2-fsa/TTS_eval_datasets
  mkdir -p ${download_dir}/
  file=dialog_testset.tar.gz
  echo "Downloading ${file}..."
  huggingface-cli download \
    --repo-type dataset \
    --local-dir ${download_dir}/ \
    ${hf_repo} \
    ${file}
  echo "Extracting ${file}..."
  tar -xzf ${download_dir}/${file} -C ${download_dir}/
fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  echo "Stage 2: Download all required evaluation models"
  # Note: the evaluation models live in a different repo than the test sets.
  hf_repo=k2-fsa/TTS_eval_models
  mkdir -p ${download_dir}/tts_eval_models
  huggingface-cli download \
    --local-dir ${download_dir}/tts_eval_models \
    ${hf_repo}
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  echo "Stage 3: Inference with the pre-trained ZipVoice-Dialog model from huggingface"

  for testset in test_dialog_en test_dialog_zh; do
    if [ "$testset" = "test_dialog_en" ]; then
      test_tsv=${download_dir}/dialog_testset/en/test.tsv
    elif [ "$testset" = "test_dialog_zh" ]; then
      test_tsv=${download_dir}/dialog_testset/zh/test.tsv
    else
      echo "Error: unknown testset ${testset}" >&2
      exit 1
    fi
    echo "Inference on testset ${testset}..."
    python3 -m zipvoice.bin.infer_zipvoice_dialog \
      --model-name zipvoice_dialog \
      --test-list ${test_tsv} \
      --res-dir results/${testset}
  done
fi


if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  echo "Stage 4: Evaluation on test-dialog-en"
  model_path=${download_dir}/tts_eval_models
  wav_path=results/test_dialog_en
  test_tsv=${download_dir}/dialog_testset/en/test.tsv

  python3 -m zipvoice.eval.speaker_similarity.cpsim \
    --wav-path ${wav_path} \
    --test-list ${test_tsv} \
    --model-dir ${model_path}

  python3 -m zipvoice.eval.wer.dialog \
    --wav-path ${wav_path} \
    --test-list ${test_tsv} \
    --model-dir ${model_path} \
    --lang en

  # cpWER mode (concatenated minimum-permutation WER): only computes
  # WER and cpWER for speech shorter than 30s
  python3 -m zipvoice.eval.wer.dialog \
    --wav-path ${wav_path} \
    --test-list ${test_tsv} \
    --model-dir ${model_path} \
    --lang en \
    --cpwer

  python3 -m zipvoice.eval.mos.utmos \
    --wav-path ${wav_path} \
    --model-dir ${model_path}
fi


if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  echo "Stage 5: Evaluation on test-dialog-zh"
  model_path=${download_dir}/tts_eval_models
  wav_path=results/test_dialog_zh
  test_tsv=${download_dir}/dialog_testset/zh/test.tsv

  python3 -m zipvoice.eval.speaker_similarity.cpsim \
    --wav-path ${wav_path} \
    --test-list ${test_tsv} \
    --model-dir ${model_path}

  python3 -m zipvoice.eval.wer.dialog \
    --wav-path ${wav_path} \
    --test-list ${test_tsv} \
    --model-dir ${model_path} \
    --lang zh

  python3 -m zipvoice.eval.mos.utmos \
    --wav-path ${wav_path} \
    --model-dir ${model_path}
fi
egs/zipvoice_dialog/run_finetune.sh
ADDED
@@ -0,0 +1,135 @@
#!/bin/bash

# This script is an example of fine-tuning our pre-trained ZipVoice-Dialog model on your custom datasets.
# Only English and Chinese are supported for now.

# Add project root to PYTHONPATH
export PYTHONPATH=../../:$PYTHONPATH

# Set bash to 'debug' mode, it will exit on:
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail

stage=1
stop_stage=6

# Number of jobs for data preparation
nj=20
# Maximum length (seconds) of the training utterances; longer utterances are filtered out
max_len=60
download_dir=download/

# We suppose you have two TSV files: "data/raw/custom_train.tsv" and
# "data/raw/custom_dev.tsv", where "custom" is your dataset name,
# "train"/"dev" are used for training and validation respectively.

# Each line of the TSV files should be in one of the following formats:
# (1) `{uniq_id}\t{text}\t{wav_path}` if the text corresponds to the full wav,
# (2) `{uniq_id}\t{text}\t{wav_path}\t{start_time}\t{end_time}` if the text corresponds
#     to part of the wav. The start_time and end_time specify the start and end
#     times of the text within the wav, which should be in seconds.
# > Note: {uniq_id} must be unique for each line.
# > Note: {text} uses [S1] and [S2] tags to distinguish speakers, and must begin with [S1].
# > e.g.: "[S1] Hello. [S2] How are you? [S1] I'm fine. [S2] What's your name?"
for subset in train dev;do
  file_path=data/raw/custom_${subset}.tsv
  [ -f "$file_path" ] || { echo "Error: expect $file_path !" >&2; exit 1; }
done

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  echo "Stage 1: Prepare manifests for custom dataset from tsv files"

  for subset in train dev;do
    python3 -m zipvoice.bin.prepare_dataset \
      --tsv-path data/raw/custom_${subset}.tsv \
      --prefix custom-finetune \
      --subset raw_${subset} \
      --num-jobs ${nj} \
      --output-dir data/manifests
  done
  # The output manifest files are "data/manifests/custom-finetune_cuts_raw_train.jsonl.gz"
  # and "data/manifests/custom-finetune_cuts_raw_dev.jsonl.gz".
fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  echo "Stage 2: Add tokens to manifests"
  for subset in train dev;do
    python3 -m zipvoice.bin.prepare_tokens \
      --input-file data/manifests/custom-finetune_cuts_raw_${subset}.jsonl.gz \
      --output-file data/manifests/custom-finetune_cuts_${subset}.jsonl.gz \
      --tokenizer dialog
  done
  # The output manifest files are "data/manifests/custom-finetune_cuts_train.jsonl.gz"
  # and "data/manifests/custom-finetune_cuts_dev.jsonl.gz".
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  echo "Stage 3: Compute Fbank for custom dataset"
  # You can skip this step and use `--on-the-fly-feats 1` in the training stage
  for subset in train dev; do
    python3 -m zipvoice.bin.compute_fbank \
      --source-dir data/manifests \
      --dest-dir data/fbank \
      --dataset custom-finetune \
      --subset ${subset} \
      --num-jobs ${nj}
  done
fi

if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  echo "Stage 4: Download pre-trained model, tokens file, and model config"
  # Uncomment this line to use HF mirror
  # export HF_ENDPOINT=https://hf-mirror.com

  mkdir -p ${download_dir}
  hf_repo=k2-fsa/ZipVoice
  for file in model.pt tokens.txt model.json; do
    huggingface-cli download \
      --local-dir ${download_dir} \
      ${hf_repo} \
      zipvoice_dialog/${file}
  done
fi

if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  echo "Stage 5: Fine-tune the ZipVoice-Dialog model"
  python3 -m zipvoice.bin.train_zipvoice_dialog \
    --world-size 4 \
    --use-fp16 1 \
    --finetune 1 \
    --base-lr 0.0001 \
    --num-iters 10000 \
    --save-every-n 1000 \
    --max-duration 500 \
    --max-len ${max_len} \
    --checkpoint ${download_dir}/zipvoice_dialog/model.pt \
    --model-config ${download_dir}/zipvoice_dialog/model.json \
    --token-file ${download_dir}/zipvoice_dialog/tokens.txt \
    --dataset custom \
    --train-manifest data/fbank/custom-finetune_cuts_train.jsonl.gz \
    --dev-manifest data/fbank/custom-finetune_cuts_dev.jsonl.gz \
    --exp-dir exp/zipvoice_dialog_finetune
fi

if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
  echo "Stage 6: Average the checkpoints for ZipVoice-Dialog"
  python3 -m zipvoice.bin.generate_averaged_model \
    --iter 10000 \
    --avg 2 \
    --model-name zipvoice_dialog \
    --exp-dir exp/zipvoice_dialog_finetune
  # The generated model is exp/zipvoice_dialog_finetune/iter-10000-avg-2.pt
fi

if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
  echo "Stage 7: Inference of the ZipVoice-Dialog model"
  python3 -m zipvoice.bin.infer_zipvoice_dialog \
    --model-name zipvoice_dialog \
    --model-dir exp/zipvoice_dialog_finetune \
    --checkpoint-name iter-10000-avg-2.pt \
    --test-list test.tsv \
    --res-dir results/test_dialog_finetune
fi
egs/zipvoice_dialog/run_opendialog.sh
ADDED
@@ -0,0 +1,122 @@
#!/bin/bash

# This script is an example of training ZipVoice-Dialog on the OpenDialog dataset.

# Add project root to PYTHONPATH
export PYTHONPATH=../../:$PYTHONPATH

# Set bash to 'debug' mode, it will exit on:
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail

stage=1
stop_stage=6

# Number of jobs for data preparation
nj=20

# We assume that you have downloaded the OpenDialog dataset
# to download/OpenDialog and untarred the tar files in audio/en
# and audio/zh so that the mp3 files are placed under these two directories.

# Download OpenDialog at https://huggingface.co/datasets/k2-fsa/OpenDialog
# or https://www.modelscope.cn/datasets/k2-fsa/OpenDialog
data_dir=download/OpenDialog
download_dir=download/

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  echo "Stage 1: Prepare manifests for OpenDialog dataset"

  python3 local/prepare_opendialog.py \
    --dataset-path ${data_dir} \
    --num-jobs ${nj} \
    --output-dir data/manifests
fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  echo "Stage 2: Add tokens to manifests"
  for subset in ZH-dev ZH-train EN-dev EN-train;do
    python3 -m zipvoice.bin.prepare_tokens \
      --input-file data/manifests/opendialog_cuts_raw_${subset}.jsonl.gz \
      --output-file data/manifests/opendialog_cuts_${subset}.jsonl.gz \
      --tokenizer dialog
  done
fi


if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  echo "Stage 3: Compute Fbank for OpenDialog dataset"
  # You can skip this step and use `--on-the-fly-feats 1` in the training stage
  for subset in ZH-dev ZH-train EN-dev EN-train;do
    python3 -m zipvoice.bin.compute_fbank \
      --source-dir data/manifests \
      --dest-dir data/fbank \
      --dataset opendialog \
      --subset ${subset} \
      --num-jobs ${nj}
  done
fi

if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  echo "Stage 4: Download tokens file and pretrained models"
  # Uncomment this line to use HF mirror
  # export HF_ENDPOINT=https://hf-mirror.com

  # The token file is obtained by extending some tokens
  # on the basis of the Emilia token file.
  mkdir -p ${download_dir}
  hf_repo=k2-fsa/ZipVoice
  huggingface-cli download \
    --local-dir ${download_dir} \
    ${hf_repo} \
    zipvoice_dialog/tokens.txt

  # The pre-trained ZipVoice model is required as
  # the initialization model.
  for file in model.pt tokens.txt model.json; do
    huggingface-cli download \
      --local-dir ${download_dir} \
      ${hf_repo} \
      zipvoice/${file}
  done
fi


if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  echo "Stage 5: Train the ZipVoice-Dialog model"
  python3 -m zipvoice.bin.train_zipvoice_dialog \
    --world-size 8 \
    --use-fp16 1 \
    --base-lr 0.0001 \
    --max-duration 500 \
    --checkpoint ${download_dir}/zipvoice/model.pt \
    --model-config ${download_dir}/zipvoice/model.json \
    --token-file ${download_dir}/zipvoice_dialog/tokens.txt \
    --dataset opendialog \
    --manifest-dir data/fbank \
    --exp-dir exp/zipvoice_dialog_opendialog
fi

if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
  echo "Stage 6: Average the checkpoints for ZipVoice-Dialog"
  python3 -m zipvoice.bin.generate_averaged_model \
    --iter 60000 \
    --avg 2 \
    --model-name zipvoice_dialog \
    --exp-dir exp/zipvoice_dialog_opendialog
  # The generated model is exp/zipvoice_dialog_opendialog/iter-60000-avg-2.pt
fi

if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
  echo "Stage 7: Inference of the ZipVoice-Dialog model"

  python3 -m zipvoice.bin.infer_zipvoice_dialog \
    --model-name zipvoice_dialog \
    --model-dir exp/zipvoice_dialog_opendialog \
    --checkpoint-name iter-60000-avg-2.pt \
    --test-list test.tsv \
    --res-dir results/test_dialog
fi
infer.py
ADDED
@@ -0,0 +1,578 @@
from typing import List, Dict, Tuple
import torch
from transformers import (
    AutoTokenizer, AutoModelForTokenClassification,
    DataCollatorForTokenClassification, Trainer, TrainingArguments
)

# BIO label set for the English-span token classifier.
LABEL_LIST = ["O", "B-EN", "I-EN"]
LABEL2ID = {l: i for i, l in enumerate(LABEL_LIST)}
ID2LABEL = {i: l for l, i in LABEL2ID.items()}

model_name = "meandyou200175/detect_english"
model_detect = AutoModelForTokenClassification.from_pretrained(
    model_name, num_labels=len(LABEL_LIST),
    id2label=ID2LABEL, label2id=LABEL2ID
)
tokenizer_detect = AutoTokenizer.from_pretrained(model_name, use_fast=True)

def tokens_to_pred_spans(offsets: List[Tuple[int, int]], pred_ids: List[int]) -> List[Tuple[int, int]]:
    """Turn per-token BIO predictions into character-offset spans."""
    spans = []
    cur = None
    for (start, end), lid in zip(offsets, pred_ids):
        if start == end:  # special tokens have empty offsets
            continue
        lab = ID2LABEL.get(lid, "O")
        if lab == "B-EN":
            if cur:
                spans.append(cur)
            cur = [start, end]
        elif lab == "I-EN":
            if cur:
                cur[1] = end
            else:
                cur = [start, end]
        else:
            if cur:
                spans.append(cur)
                cur = None
    if cur:
        spans.append(cur)
    return [tuple(x) for x in spans]

def merge_close_spans(spans: List[Dict], max_gap: int = 2) -> List[Dict]:
    if not spans:
        return []
    merged = [spans[0]]
    for cur in spans[1:]:
        prev = merged[-1]
        if cur["start"] - prev["end"] <= max_gap:
            # merge with the previous span
            prev["end"] = cur["end"]
        else:
            merged.append(cur)
    return merged


def infer_spans(text: str, tokenizer, model, max_length: int = 256) -> List[Dict]:
    text = text.lower()
    enc = tokenizer(text, return_offsets_mapping=True, truncation=True,
                    max_length=max_length, return_tensors="pt")
    offsets = enc["offset_mapping"][0].tolist()
    with torch.no_grad():
        out = model(**{k: v for k, v in enc.items() if k != "offset_mapping"})
    pred_ids = out.logits.argmax(-1)[0].tolist()
    spans = tokens_to_pred_spans(offsets, pred_ids)
    spans = [{"start": s, "end": e} for (s, e) in spans]
    spans = merge_close_spans(spans, max_gap=2)
    return spans
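# Illustration (hypothetical input/output; actual offsets depend on the detector):
# for a Vietnamese sentence with an embedded English phrase,
#   infer_spans("toi dang hoc machine learning", tokenizer_detect, model_detect)
# returns character-offset spans of the detected English text, e.g.
#   [{"start": 13, "end": 29}]
# Spans separated by gaps of <= 2 characters are merged by merge_close_spans.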

import unicodedata

def is_letter(ch: str) -> bool:
    if not ch:
        return False
    # If the caller passed a decomposed combining sequence (e.g. "e" + diacritic),
    # normalize to NFC first:
    ch = unicodedata.normalize("NFC", ch)
    # Only accept exactly one character after normalization
    if len(ch) != 1:
        return False
    # Unicode 'L*' categories: Lu, Ll, Lt, Lm, Lo
    return unicodedata.category(ch).startswith('L')

import re
from itertools import chain
from typing import List, Dict, Optional
import logging
from functools import reduce
from piper_phonemize import phonemize_espeak

class EspeakTokenizer():
    """A tokenizer with an espeak g2p function, supporting English + Vietnamese."""

    def __init__(self, token_file: Optional[str] = None, lang: str = "vi",
                 tokenizer=None, model=None):
        self.has_tokens = False
        self.lang = lang
        self.detector_tokenizer = tokenizer
        self.detector_model = model

        if token_file is None:
            logging.debug("Initialize Tokenizer without tokens file, "
                          "will fail when map to ids.")
            return

        self.token2id: Dict[str, int] = {}
        with open(token_file, "r", encoding="utf-8") as f:
            for line in f.readlines():
                info = line.rstrip().split("\t")
                token, id = info[0], int(info[1])
                assert token not in self.token2id, token
                self.token2id[token] = id
        self.pad_id = self.token2id["_"]
        self.vocab_size = len(self.token2id)
        self.has_tokens = True

    @staticmethod
    def _flatten(phs):
        """Flatten a list of lists (or return the list if already flat)."""
        if not phs:
            return []
        if isinstance(phs[0], (list, tuple)):
            return list(chain.from_iterable(phs))
        return list(phs)

    def g2p_chunk(self, text: str, lang: str):
        tokens = []
        start = 0
        for t in text:
            if is_letter(t):
                break
            start = start + 1

        # Keep leading whitespace/punctuation as-is; phonemize the rest.
        if start > 0:
            tokens.extend(self._flatten(text[0:start]))
        phs = phonemize_espeak(text[start:], lang)  # may return a list of lists
        tokens.extend(self._flatten(phs))
        return tokens

    def g2p(self, text: str) -> List[str]:
        """Split text into EN/VI spans and phonemize each accordingly,
        preserving whitespace and punctuation."""
        try:
            # Fallback: without a detector, phonemize the whole string with
            # self.lang, but go through g2p_chunk so whitespace and
            # punctuation are not lost.
            if self.detector_tokenizer is None or self.detector_model is None:
                return self.g2p_chunk(text, self.lang)

            spans = infer_spans(text, self.detector_tokenizer, self.detector_model)
            spans = sorted(spans, key=lambda x: x["start"])

            tokens_all = []
            last = 0
            for sp in spans:
                s, e = sp["start"], sp["end"]
                # the part before the EN span -> VI
                if s > last:
                    vi_chunk = text[last:s]
                    if vi_chunk:
                        tokens_all.extend(self.g2p_chunk(vi_chunk, "vi"))
                # the EN span
                en_chunk = text[s:e]
                if en_chunk:
                    tokens_all.extend([" "])
                    tokens_all.extend(self.g2p_chunk(en_chunk, "en"))
                last = e

            # the remainder after the last EN span -> VI
            if last < len(text):
                vi_chunk = text[last:]
                if vi_chunk:
                    tokens_all.extend(self.g2p_chunk(vi_chunk, "vi"))

            return tokens_all

        except Exception as ex:
            logging.warning(f"Tokenization of mixed {self.lang} texts failed: {ex}")
            return []

    def texts_to_token_ids(
        self,
        texts: List[str],
    ) -> List[List[int]]:
        return self.tokens_to_token_ids(self.texts_to_tokens(texts))

    def texts_to_tokens(
        self,
        texts: List[str],
    ) -> List[List[str]]:
        tokens_list = [self.g2p(texts[i]) for i in range(len(texts))]
        return tokens_list

    def tokens_to_token_ids(
        self,
        tokens_list: List[List[str]],
    ) -> List[List[int]]:
        assert self.has_tokens, "Please initialize Tokenizer with a tokens file."

        token_ids_list = []

        for tokens in tokens_list:
            token_ids = []
            for t in tokens:
                if t not in self.token2id:
                    logging.debug(f"Skip OOV {t}")
                    continue
                token_ids.append(self.token2id[t])

            token_ids_list.append(token_ids)

        return token_ids_list
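# Minimal usage sketch (assumes a valid tokens.txt and the detector loaded above):
#   tok = EspeakTokenizer(token_file="tokens.txt", lang="vi",
#                         tokenizer=tokenizer_detect, model=model_detect)
#   ids = tok.texts_to_token_ids(["xin chao the gioi"])
# texts_to_tokens() phonemizes each detected span with espeak ("vi" or "en"),
# and tokens_to_token_ids() maps phonemes to ids, silently skipping OOV tokens.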
import random
import datetime as dt
import json
import logging
import os
from pathlib import Path
from typing import Optional

import numpy as np
import safetensors.torch
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from lhotse.utils import fix_random_seed
from vocos import Vocos

from zipvoice.models.zipvoice import ZipVoice
from zipvoice.models.zipvoice_distill import ZipVoiceDistill
# from zipvoice.tokenizer.tokenizer import EmiliaTokenizer, EspeakTokenizer, LibriTTSTokenizer, SimpleTokenizer, SimpleTokenizer2
from zipvoice.utils.checkpoint import load_checkpoint
from zipvoice.utils.common import AttributeDict
from zipvoice.utils.feature import VocosFbank

def load_vocab(file_path):
    """Read a vocab file of `char<TAB>id` lines -> return a dict {id: char}."""
    id2char = {}
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            # strip the trailing newline but keep leading spaces
            line = line.rstrip("\n")
            parts = line.split("\t")
            if len(parts) != 2:
                continue  # skip malformed lines
            char, idx = parts
            id2char[int(idx)] = char
    return id2char


def tokens_to_text(tokens, id2char):
    """Convert a list of token ids back to a string."""
    return "".join(id2char.get(t, "<unk>") for t in tokens)

def get_vocoder(vocos_local_path: Optional[str] = None):
    if vocos_local_path:
        vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
        state_dict = torch.load(
            f"{vocos_local_path}/pytorch_model.bin",
            weights_only=True,
            map_location="cpu",
        )
        vocoder.load_state_dict(state_dict)
    else:
        vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
    return vocoder


HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
MODEL_DIR = {
    "zipvoice": "zipvoice",
    "zipvoice_distill": "zipvoice_distill",
}

model_dir = "zipvoice_finetune/"
checkpoint_name = "iter-525000-avg-2.pt"
# checkpoint_name = "model.pt"
model_dir = Path(model_dir)
model_ckpt = model_dir / checkpoint_name
model_config_path = model_dir / "model.json"
token_file = model_dir / "tokens.txt"


tokenizer = EspeakTokenizer(token_file=token_file, tokenizer=tokenizer_detect, model=model_detect)


tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}

with open(model_config_path, "r") as f:
    model_config = json.load(f)

# --- Init model ---

model = ZipVoice(**model_config["model"], **tokenizer_config)

if str(model_ckpt).endswith(".safetensors"):
    safetensors.torch.load_model(model, model_ckpt)
else:
    load_checkpoint(filename=model_ckpt, model=model, strict=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()

# --- Vocoder & features ---
vocoder = get_vocoder(None).to(device).eval()
feature_extractor = VocosFbank()
sampling_rate = model_config["feature"]["sampling_rate"]

def score_tokens(A):
    B = [9, 14, 18, 21, 27, 33, 37, 39, 42, 45, 50, 51, 52, 54, 58, 59, 61, 62, 63, 69, 73, 74, 79, 85, 99, 100, 102, 105, 119, 120, 121, 122, 123, 124, 141, 143, 144, 145, 146, 157, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 349, 350, 353, 356, 357, 358, 359]

    total_score = 0
    # Pad with token id 3 (segment boundary) at both ends
    tokens = [3] + A + [3]

    # Split the sequence on token id 3
    segment = []
    for t in tokens:
        if t == 3:
            if segment:  # process one segment
                count = 0
                for i in range(len(segment) - 1):
                    if (segment[i] in B and segment[i + 1] not in B):
                        count += 1
                if segment[-1] in B:
                    count += 1
                if count > 0:
                    total_score += 1 + (count - 1) * 0.5
                segment = []
        else:
            segment.append(t)

    return total_score
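# Worked example: score_tokens([9, 10, 14]) == 1.5.
# tokens become [3, 9, 10, 14, 3]; a single segment [9, 10, 14] is extracted.
# 9 is in B and its successor 10 is not -> count = 1; the last element 14 is
# in B -> count = 2; the segment contributes 1 + (2 - 1) * 0.5 = 1.5.
# (B appears to list vowel-like phoneme ids and 3 a boundary token, so the
# score roughly counts vowel-group endings per segment; this interpretation
# is an assumption, not documented in the original code.)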


def trim_leading_silence_torch(
    wav: torch.Tensor,
    sample_rate: int,
    silence_thresh: float = 0.05,
    chunk_ms: int = 10,
    extend_ms: int = 20,
    ratio: float = 0.95,  # fraction of samples that must be below the threshold to count as silence
):
    wav_np = wav.squeeze(0).cpu().numpy().astype(np.float32)
    norm_wav = wav_np / (np.max(np.abs(wav_np)) + 1e-8)

    chunk_size = int(sample_rate * chunk_ms / 1000)
    total_chunks = int(len(norm_wav) / chunk_size)

    start_idx = 0
    for i in range(total_chunks):
        chunk = norm_wav[i * chunk_size : (i + 1) * chunk_size]
        # fraction of samples below the threshold
        silent_ratio = np.mean(np.abs(chunk) < silence_thresh)
        if silent_ratio < ratio:  # fewer than `ratio` silent samples -> treat as speech
            start_idx = max(0, i * chunk_size - int(sample_rate * extend_ms / 1000))
            break

    return wav[:, start_idx:]




@torch.inference_mode()
def run_zipvoice(
    model_name="zipvoice",
    model_dir="zipvoice_finetune",
    checkpoint_name="model.pt",
    vocoder_path=None,
    tokenizer_name="emilia",
    lang="en-us",
    test_list=None,  # path to tsv file
    prompt_wav=None,
    prompt_text=None,
    text=None,
    res_dir="results",
    res_wav_path="result.wav",
    guidance_scale=None,
    num_step=None,
    feat_scale=0.1,
    speed=1.0,
    t_shift=0.5,
    target_rms=0.1,
    seed=666,
):
    if text is not None:
        text = text.lower()
    # --- Default settings per model ---
    model_defaults = {
        "zipvoice": {"num_step": 16, "guidance_scale": 1.0},
        "zipvoice_distill": {"num_step": 8, "guidance_scale": 3.0},
    }
    # assign defaults explicitly (instead of poking at locals())
    if guidance_scale is None:
        guidance_scale = model_defaults.get(model_name, {}).get("guidance_scale", 1.0)
    if num_step is None:
        num_step = model_defaults.get(model_name, {}).get("num_step", 16)

    # --- Check inputs ---
    assert (test_list is not None) ^ bool(prompt_wav and prompt_text and text), \
        "Need either test_list or (prompt_wav + prompt_text + text)"

    fix_random_seed(seed)

    # --- Tokenizer, model, vocoder, and feature extractor are loaded once at
    # module level above. ---

    # ---------------------------
    # NEW: text chunking helper
    # ---------------------------
    def split_text_into_chunks(s: str, min_chars: int = 15, max_chars: int = 30):
        """
        Split on ',' or '.', then merge/split so every chunk length lies in
        [min_chars, max_chars]. Never cuts in the middle of a word.
        """
        # normalize whitespace
        s = re.sub(r"\s+", " ", (s or "").strip())
        if not s:
            return []

        # split on ',' or '.'
        raw_segs = [seg.strip() for seg in re.split(r"\s*[.,]\s*", s) if seg.strip()]

        chunks = []
        i = 0
        while i < len(raw_segs):
            cur = raw_segs[i]
            i += 1

            # merge following segments while cur is too short
            while len(cur) < min_chars and i < len(raw_segs):
                cur = (cur + ", " + raw_segs[i]).strip()
                i += 1

            # if cur is too long, re-split at word boundaries so pieces are <= max_chars
            if len(cur) > max_chars:
                words = cur.split()
                buf = []
                cur_len = 0
                for w in words:
                    # +1 for the joining space if needed
                    add_len = len(w) if cur_len == 0 else len(w) + 1
                    if cur_len + add_len <= max_chars:
                        buf.append(w)
                        cur_len += add_len
                    else:
                        # close the current chunk; join with spaces
                        # (matches the length accounting above)
                        part = " ".join(buf).strip()
                        if part:
                            chunks.append(part)
                        # start a new chunk
                        buf = [w]
                        cur_len = len(w)
                # the remainder
                last = " ".join(buf).strip()
                if last:
                    # if the tail is still < min_chars, try merging it backwards
                    if len(last) < min_chars and chunks:
                        merged = (chunks[-1] + " " + last).strip()
                        if len(merged) <= max_chars:
                            chunks[-1] = merged
                        else:
                            chunks.append(last)  # accept as-is (rare)
                    else:
                        chunks.append(last)
            else:
                chunks.append(cur)

        # final pass: if the last chunk is too short, merge it into the previous one
        if len(chunks) >= 2 and len(chunks[-1]) < min_chars:
            merged = (chunks[-2] + ", " + chunks[-1]).strip()
            if len(merged) <= max_chars:
                chunks[-2] = merged
                chunks.pop()
        final_chunk = []
        for chunk in chunks:
            chunk = ", " + chunk + ","
            final_chunk.append(chunk)
        return final_chunk
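    # Behavior sketch (hypothetical input): with min_chars=15 and max_chars=30,
    # a long sentence is first split on ',' and '.', short pieces are merged up
    # to min_chars, over-long pieces are re-split at word boundaries, and every
    # returned chunk is wrapped as ", <chunk>," so the model sees comma context.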

    # ---------------------------
    # MODIFIED: generate_sentence synthesizes chunk by chunk, then concatenates
    # ---------------------------
    def generate_sentence(save_path, prompt_text, prompt_wav, text):
        # normalize & split into chunks
        segments = split_text_into_chunks(text, min_chars=50, max_chars=200)
        if not segments:
            # nothing to say: write a 0.2 s empty file
            silence = torch.zeros((1, int(0.2 * sampling_rate)))
            torchaudio.save(save_path, silence, sample_rate=sampling_rate)
            return

        # prepare the prompt (done once)
        prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
        prompt_wav_tensor, sr = torchaudio.load(prompt_wav)
        if sr != sampling_rate:
            prompt_wav_tensor = torchaudio.transforms.Resample(sr, sampling_rate)(prompt_wav_tensor)
        prompt_rms_val = torch.sqrt(torch.mean(prompt_wav_tensor**2))
        if prompt_rms_val < target_rms:
            prompt_wav_tensor *= target_rms / prompt_rms_val

        prompt_features = feature_extractor.extract(
            prompt_wav_tensor, sampling_rate=sampling_rate
        ).to(device)
        prompt_features = prompt_features.unsqueeze(0) * feat_scale
        prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)

        num_space_prompt = prompt_text.count(" ")

        # short silence inserted between chunks
        gap_duration = random.uniform(0.17, 0.2)  # random value between 0.17 and 0.2 s
        gap = torch.zeros((1, int(gap_duration * sampling_rate)))

        wav_parts = []
        print("segments", segments)
        for idx, seg in enumerate(segments):
            num_space_text = seg.count(" ")
            tokens = tokenizer.texts_to_token_ids([seg])
            score = score_tokens(tokens[0])
            score_prompt = score_tokens(prompt_tokens[0])
            vocab_file = "zipvoice_finetune/tokens.txt"  # vocab in `char<TAB>id` format

            id2char = load_vocab(vocab_file)
            decoded_text = tokens_to_text(tokens[0], id2char)

            print(decoded_text)

            pred_features, _, _, _ = model.sample(
                num_space_text=[num_space_text],
                num_space_prompt=[num_space_prompt],
                tokens=tokens,
                prompt_tokens=prompt_tokens,
                prompt_features=prompt_features,
                prompt_features_lens=prompt_features_lens,
                speed=speed,
                t_shift=t_shift,
                duration="predict",
                num_step=num_step,
                guidance_scale=guidance_scale,
            )
            pred_features = pred_features.permute(0, 2, 1) / feat_scale
            wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)

            # restore the volume level relative to the prompt
            if prompt_rms_val < target_rms:
                wav *= prompt_rms_val / target_rms
            wav = trim_leading_silence_torch(
                wav, sample_rate=sampling_rate, silence_thresh=0.086, chunk_ms=10, extend_ms=20
            )
            wav_parts.append(wav.cpu())
            if idx < len(segments) - 1:
                wav_parts.append(gap)  # insert silence between chunks

        final_wav = torch.cat(wav_parts, dim=-1)  # [1, T_total]
        torchaudio.save(save_path, final_wav, sample_rate=sampling_rate)

    # --- generate_list unchanged: it calls generate_sentence, so chunking is applied automatically ---
    def generate_list(res_dir, test_list):
        os.makedirs(res_dir, exist_ok=True)
        with open(test_list, "r", encoding="utf-8") as fr:
            for i, line in enumerate(fr):
                wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
                save_path = f"{res_dir}/{wav_name}.wav"
                generate_sentence(save_path, prompt_text, prompt_wav, text)

    # --- Run ---
    if test_list:
        generate_list(res_dir, test_list)
    else:
        generate_sentence(res_wav_path, prompt_text, prompt_wav, text)

    print("Done!")
    return text
proccess_wav.py
ADDED
@@ -0,0 +1,364 @@
from typing import List, Tuple
import numpy as np
from pydub import AudioSegment
import os
from chunkformer import ChunkFormerModel
from clearvoice import ClearVoice
# ======================= ASR + CLEARVOICE + AUDIO PROCESSING =======================

ASR_MODEL = None
CLEARVOICE_MODEL = None
REF_AUDIO_CACHE = {}  # cache: input path -> processed output path


def get_asr_model() -> ChunkFormerModel:
    """Lazy-load ChunkFormer (ASR, runs on CPU)."""
    global ASR_MODEL
    if ASR_MODEL is None:
        ASR_MODEL = ChunkFormerModel.from_pretrained("khanhld/chunkformer-ctc-large-vie")
    return ASR_MODEL


def get_clearvoice_model() -> ClearVoice:
    """Lazy-load ClearVoice for denoising the reference audio."""
    global CLEARVOICE_MODEL
    if CLEARVOICE_MODEL is None:
        CLEARVOICE_MODEL = ClearVoice(
            task="speech_enhancement",
            model_names=["MossFormer2_SE_48K"],
        )
    return CLEARVOICE_MODEL


def find_silent_regions(
    audio: AudioSegment,
    silence_thresh: float = 0.05,  # amplitude after normalization to [-1, 1]
    chunk_ms: int = 10,
    min_silence_len: int = 200,
) -> List[Tuple[int, int]]:
    """
    Find the silent regions (start_ms, end_ms) in an AudioSegment based on amplitude.
    """
    samples = np.array(audio.get_array_of_samples(), dtype=np.float32)
    if audio.channels > 1:
        samples = samples.reshape((-1, audio.channels)).mean(axis=1)

    norm = samples / (2 ** (audio.sample_width * 8 - 1))
    sr = audio.frame_rate

    chunk_size = max(1, int(sr * chunk_ms / 1000))
    total_chunks = len(norm) // chunk_size

    silent_regions: List[Tuple[int, int]] = []
    start = None
    for i in range(total_chunks):
        chunk = norm[i * chunk_size: (i + 1) * chunk_size]
        if chunk.size == 0:
            continue

        if np.all((chunk > -silence_thresh) & (chunk < silence_thresh)):
            if start is None:
                start = i
        else:
            if start is not None:
                dur = (i - start) * chunk_ms
                if dur >= min_silence_len:
                    silent_regions.append((start * chunk_ms, i * chunk_ms))
                start = None

    if start is not None:
        dur = (total_chunks - start) * chunk_ms
        if dur >= min_silence_len:
            silent_regions.append((start * chunk_ms, total_chunks * chunk_ms))

    return silent_regions


def trim_leading_trailing_silence(
    audio: AudioSegment,
    silence_thresh: float = 0.05,
    chunk_ms: int = 10,
    min_silence_len: int = 200,
) -> AudioSegment:
    """
    Remove the silence at the beginning and end of the file.
    """
    duration = len(audio)
    silent_regions = find_silent_regions(
        audio,
        silence_thresh=silence_thresh,
        chunk_ms=chunk_ms,
        min_silence_len=min_silence_len,
    )

    if not silent_regions:
        return audio

    start_trim = 0
    end_trim = duration

    # silence at the start of the file
    first_start, first_end = silent_regions[0]
    if first_start <= 0:
        start_trim = max(start_trim, first_end)

    # silence at the end of the file
    last_start, last_end = silent_regions[-1]
    if last_end >= duration:
        end_trim = min(end_trim, last_start)

    return audio[start_trim:end_trim]


def compress_internal_silence(
    audio: AudioSegment,
    max_silence_ms: int = 300,
    silence_thresh: float = 0.05,
    chunk_ms: int = 10,
    min_silence_len: int = 50,
) -> AudioSegment:
    """
    Shorten the silences inside the file:
    - silences <= max_silence_ms: keep as-is
    - silences > max_silence_ms: cut down to max_silence_ms
    """
    duration = len(audio)
    silent_regions = find_silent_regions(
        audio,
        silence_thresh=silence_thresh,
        chunk_ms=chunk_ms,
        min_silence_len=min_silence_len,
    )
    if not silent_regions:
        return audio

    new_audio = AudioSegment.silent(duration=0, frame_rate=audio.frame_rate)
    cursor = 0

    for s_start, s_end in silent_regions:
        # the speech part before this silence
        if s_start > cursor:
            new_audio += audio[cursor:s_start]

        silence_len = s_end - s_start
        if silence_len <= max_silence_ms:
            new_audio += audio[s_start:s_end]
        else:
            new_audio += audio[s_start: s_start + max_silence_ms]

        cursor = s_end

    # the remainder after the last silence
    if cursor < duration:
        new_audio += audio[cursor:]

    return new_audio


def select_subsegment_by_silence(
    audio: AudioSegment,
    min_len_ms: int = 5000,
    max_len_ms: int = 10000,
    silence_thresh: float = 0.05,
    chunk_ms: int = 10,
    min_silence_len: int = 200,
) -> AudioSegment:
    """
    If the audio is longer than max_len_ms, pick a segment whose length lies in
    [min_len_ms, max_len_ms], cutting at points inside silent regions to avoid
    clipping speech.
    """
    duration = len(audio)
    if duration <= max_len_ms:
        return audio

    silent_regions = find_silent_regions(
        audio,
        silence_thresh=silence_thresh,
        chunk_ms=chunk_ms,
        min_silence_len=min_silence_len,
    )

    if not silent_regions:
        # no silence found -> take the middle segment
        target_len = min(max_len_ms, duration)
        start = max(0, (duration - target_len) // 2)
        end = start + target_len
        return audio[start:end]

    # boundaries are silence midpoints (guaranteed to fall inside a silent region)
    boundaries = [0]
    for s_start, s_end in silent_regions:
        mid = (s_start + s_end) // 2
        if 0 < mid < duration:
            boundaries.append(mid)
    boundaries.append(duration)
    boundaries = sorted(set(boundaries))

    # prefer the first segment that satisfies 5-10 s
    for i in range(len(boundaries)):
        for j in range(i + 1, len(boundaries)):
            seg_len = boundaries[j] - boundaries[i]
            if min_len_ms <= seg_len <= max_len_ms:
                return audio[boundaries[i]:boundaries[j]]

    # if no segment lies entirely within [min, max], pick the one closest to max_len
    best_i, best_j, best_diff = 0, None, None
    for i in range(len(boundaries)):
        for j in range(i + 1, len(boundaries)):
            seg_len = boundaries[j] - boundaries[i]
            if seg_len >= min_len_ms:
                diff = abs(seg_len - max_len_ms)
                if best_diff is None or diff < best_diff:
                    best_diff = diff
                    best_i, best_j = i, j

    if best_j is not None:
        return audio[boundaries[best_i]:boundaries[best_j]]

    # final fallback
    target_len = min(max_len_ms, duration)
    start = max(0, (duration - target_len) // 2)
    end = start + target_len
    return audio[start:end]


def enhance_ref_audio(input_path: str) -> str:
    """
    WAV preprocessing pipeline for TTS:
    - ClearVoice denoising
    - trim leading/trailing silence
    - shorten internal silences longer than 0.3 s down to 0.3 s
    - if the audio is longer than 10 s: select a 5-10 s segment, cut at silences
    Returns the path of the processed wav file.
    """
    if not input_path:
        raise ValueError("No input audio path for enhancement.")

    # cache so the same file is not processed more than once
    if input_path in REF_AUDIO_CACHE:
        return REF_AUDIO_CACHE[input_path]

    cv = get_clearvoice_model()

    # compute the output names up front so `name` is defined even if
    # denoising fails below
    base = os.path.basename(input_path)
    name, ext = os.path.splitext(base)
    if not ext:
        ext = ".wav"

    # 1) denoise
    try:
        cv_out = cv(input_path=input_path, online_write=False)
        denoised_path = os.path.join(os.path.dirname(input_path), f"{name}_denoised{ext}")
        cv.write(cv_out, output_path=denoised_path)
    except Exception as e:
        print(f"[ClearVoice] Error during denoising, fallback to original: {e}")
        denoised_path = input_path

    # 2) handle silences and length with pydub
    audio = AudioSegment.from_file(denoised_path)

    # trim leading/trailing silence
    audio = trim_leading_trailing_silence(audio)

    # shorten internal silences
    audio = compress_internal_silence(audio, max_silence_ms=300)

    # if longer than 10 s, select a segment in the 5-10 s range
    audio = select_subsegment_by_silence(audio, min_len_ms=5000, max_len_ms=10000)

    # 3) write out the new file
    enhanced_path = os.path.join(os.path.dirname(denoised_path), f"{name}_enhanced.wav")
    audio.export(enhanced_path, format="wav")

    REF_AUDIO_CACHE[input_path] = enhanced_path
    return enhanced_path
| 274 |
+
|
| 275 |
+
def split_audio_by_silence(
|
| 276 |
+
audio: AudioSegment,
|
| 277 |
+
silence_thresh: float = 0.05,
|
| 278 |
+
chunk_ms: int = 10,
|
| 279 |
+
min_silence_len: int = 200,
|
| 280 |
+
min_segment_len: int = 200,
|
| 281 |
+
) -> List[Tuple[int, int]]:
|
| 282 |
+
"""
|
| 283 |
+
Tα»« AudioSegment, trαΊ£ vα» cΓ‘c ΔoαΊ‘n cΓ³ tiαΊΏng nΓ³i (non-silent)
|
| 284 |
+
Δược tΓ‘ch bαΊ±ng khoαΊ£ng lαΊ·ng.
|
| 285 |
+
"""
|
| 286 |
+
duration = len(audio)
|
| 287 |
+
silent_regions = find_silent_regions(
|
| 288 |
+
audio,
|
| 289 |
+
silence_thresh=silence_thresh,
|
| 290 |
+
chunk_ms=chunk_ms,
|
| 291 |
+
min_silence_len=min_silence_len,
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
segments: List[Tuple[int, int]] = []
|
| 295 |
+
cur_start = 0
|
| 296 |
+
|
| 297 |
+
for s_start, s_end in silent_regions:
|
| 298 |
+
if cur_start < s_start:
|
| 299 |
+
if s_start - cur_start >= min_segment_len:
|
| 300 |
+
segments.append((cur_start, s_start))
|
| 301 |
+
cur_start = s_end
|
| 302 |
+
|
| 303 |
+
if cur_start < duration and duration - cur_start >= min_segment_len:
|
| 304 |
+
segments.append((cur_start, duration))
|
| 305 |
+
|
| 306 |
+
# nαΊΏu khΓ΄ng tΓ¬m Δược ΔoαΊ‘n nΓ o, lαΊ₯y cαΊ£ file
|
| 307 |
+
if not segments:
|
| 308 |
+
segments.append((0, duration))
|
| 309 |
+
|
| 310 |
+
return segments
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def transcribe_ref_audio(audio_path: str) -> str:
|
| 314 |
+
"""
|
| 315 |
+
ASR theo yΓͺu cαΊ§u:
|
| 316 |
+
- CαΊ―t Γ’m thanh theo khoαΊ£ng lαΊ·ng
|
| 317 |
+
- ASR tα»«ng ΔoαΊ‘n
|
| 318 |
+
- Nα»i text bαΊ±ng dαΊ₯u phαΊ©y
|
| 319 |
+
"""
|
| 320 |
+
if not audio_path:
|
| 321 |
+
raise ValueError("No audio path for ASR.")
|
| 322 |
+
|
| 323 |
+
model = get_asr_model()
|
| 324 |
+
audio = AudioSegment.from_file(audio_path)
|
| 325 |
+
segments = split_audio_by_silence(audio)
|
| 326 |
+
|
| 327 |
+
texts = []
|
| 328 |
+
base, _ = os.path.splitext(audio_path)
|
| 329 |
+
|
| 330 |
+
for idx, (start_ms, end_ms) in enumerate(segments):
|
| 331 |
+
seg_audio = audio[start_ms:end_ms]
|
| 332 |
+
seg_path = f"{base}_seg_{idx}.wav"
|
| 333 |
+
seg_audio.export(seg_path, format="wav")
|
| 334 |
+
|
| 335 |
+
try:
|
| 336 |
+
transcription = model.endless_decode(
|
| 337 |
+
audio_path=seg_path,
|
| 338 |
+
chunk_size=32,
|
| 339 |
+
left_context_size=0,
|
| 340 |
+
right_context_size=0,
|
| 341 |
+
total_batch_duration=400,
|
| 342 |
+
return_timestamps=False,
|
| 343 |
+
)
|
| 344 |
+
except TypeError:
|
| 345 |
+
transcription = model.endless_decode(
|
| 346 |
+
audio_path=seg_path,
|
| 347 |
+
chunk_size=32,
|
| 348 |
+
left_context_size=0,
|
| 349 |
+
right_context_size=0,
|
| 350 |
+
total_batch_duration=400,
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
if isinstance(transcription, str):
|
| 354 |
+
text = transcription
|
| 355 |
+
else:
|
| 356 |
+
text = str(transcription)
|
| 357 |
+
|
| 358 |
+
text = text.strip()
|
| 359 |
+
if text:
|
| 360 |
+
texts.append(text)
|
| 361 |
+
|
| 362 |
+
return ", ".join(texts)
|
| 363 |
+
|
| 364 |
+
|
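Note: a minimal usage sketch of the two entry points above (the input path is hypothetical; both functions are defined in this file):

    # denoise, trim/compress silence, cut a 5-10s sub-segment, then transcribe
    # the enhanced clip segment by segment (texts are joined with commas)
    enhanced_path = enhance_ref_audio("ref_voice.wav")
    prompt_text = transcribe_ref_audio(enhanced_path)
    print(enhanced_path, prompt_text)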
pyproject.toml ADDED
@@ -0,0 +1,5 @@
[tool.isort]
profile = "black"

[tool.black]
line-length = 88
requirements.txt ADDED
@@ -0,0 +1,23 @@
--find-links https://k2-fsa.github.io/icefall/piper_phonemize.html
transformers==4.57.1
torch
torchaudio
torchcodec
numpy
lhotse
huggingface_hub
safetensors
tensorboard
vocos

# Normalization
cn2an
inflect

# Tokenization
jieba
piper_phonemize
pypinyin
setuptools<81
chunkformer
clearvoice
requirements_eval.txt ADDED
@@ -0,0 +1,19 @@
torch
numpy

# Audio processing
librosa
soundfile

# Model
s3prl
pyannote.audio
funasr
transformers

# WER
jiwer==3.1.0

# Normalization
zhconv
zhon
setup.py ADDED
@@ -0,0 +1,55 @@
import os
import subprocess
import requests
from dotenv import load_dotenv

def run_cmd(cmd):
    print(f"Running command: {cmd}")
    result = subprocess.run(cmd, shell=True)
    if result.returncode != 0:
        raise RuntimeError(f"Command failed: {cmd}")

def download_with_token(url, dest_path, token):
    headers = {"Authorization": f"Bearer {token}"}
    with requests.get(url, headers=headers, stream=True) as r:
        r.raise_for_status()
        with open(dest_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
    print(f"Downloaded: {dest_path}")

def main():
    # Load environment variables from .env
    load_dotenv()
    token = os.getenv("HF_TOKEN")

    if not token:
        raise EnvironmentError("Missing the HF_TOKEN environment variable. Create a .env file containing:\nHF_TOKEN=hf_your_token_here")

    # Log in to the Hugging Face CLI
    run_cmd(f"huggingface-cli login --token {token}")

    # Create the model directory
    os.makedirs("zipvoice_finetune", exist_ok=True)

    # Files to download
    files = {
        "iter-525000-avg-2.pt": "https://huggingface.co/datasets/meandyou200175/temp_file/resolve/main/zip/epoch-46-all-speak-600h-en-norm.pt",
        "model.json": "https://huggingface.co/datasets/meandyou200175/temp_file/resolve/main/zip/model.json",
        "tokens.txt": "https://huggingface.co/datasets/meandyou200175/temp_file/resolve/main/zip/tokens.txt",
    }

    for filename, url in files.items():
        dest = os.path.join("zipvoice_finetune", filename)
        download_with_token(url, dest, token)

    # Install requirements
    if os.path.exists("requirements.txt"):
        run_cmd("pip install -r requirements.txt")
    else:
        print("Warning: requirements.txt not found")

    print("\nSetup complete!")

if __name__ == "__main__":
    main()
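Note: a minimal sanity-check sketch to run after setup.py (the file names come from the `files` dict above; this snippet is not part of the commit):

    import os

    for fname in ("iter-525000-avg-2.pt", "model.json", "tokens.txt"):
        path = os.path.join("zipvoice_finetune", fname)
        print(path, "OK" if os.path.isfile(path) else "MISSING")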
zipvoice/__init__.py ADDED
@@ -0,0 +1,7 @@
import warnings

warnings.filterwarnings(
    "ignore",
    category=UserWarning,
    message="pkg_resources is deprecated as an API.*",
)
zipvoice/bin/compute_fbank.py ADDED
@@ -0,0 +1,272 @@
#!/usr/bin/env python3
# Copyright 2024-2025 Xiaomi Corp. (authors: Wei Kang
#                                            Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
python3 -m zipvoice.bin.compute_fbank \
    --source-dir data/manifests \
    --dest-dir data/fbank \
    --dataset libritts \
    --subset dev-other \
    --sampling-rate 24000 \
    --num-jobs 20

The input would be data/manifests/libritts_cuts_dev-other.jsonl.gz or
(libritts_supervisions_dev-other.jsonl.gz and libritts_recordings_dev-other.jsonl.gz)

The output would be data/fbank/libritts_cuts_dev-other.jsonl.gz
"""


import argparse
import logging
from concurrent.futures import ProcessPoolExecutor as Pool
from pathlib import Path

import lhotse
import torch
from lhotse import CutSet, LilcomChunkyWriter, load_manifest_lazy

from zipvoice.utils.common import str2bool
from zipvoice.utils.feature import VocosFbank

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

lhotse.set_audio_duration_mismatch_tolerance(0.1)


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--sampling-rate",
        type=int,
        default=24000,
        help="The target sampling rate, the audio will be resampled to it.",
    )

    parser.add_argument(
        "--type",
        type=str,
        default="vocos",
        help="fbank type",
    )

    parser.add_argument(
        "--dataset",
        type=str,
        help="Dataset name.",
    )

    parser.add_argument(
        "--subset",
        type=str,
        help="The subset of the dataset.",
    )

    parser.add_argument(
        "--source-dir",
        type=str,
        default="data/manifests",
        help="The source directory of manifest files.",
    )

    parser.add_argument(
        "--dest-dir",
        type=str,
        default="data/fbank",
        help="The destination directory of manifest files.",
    )

    parser.add_argument(
        "--split-cuts",
        type=str2bool,
        default=False,
        help="Whether to use split cuts.",
    )

    parser.add_argument(
        "--split-begin",
        type=int,
        help="Start idx of split cuts.",
    )

    parser.add_argument(
        "--split-end",
        type=int,
        help="End idx of split cuts.",
    )

    parser.add_argument(
        "--batch-duration",
        type=int,
        default=1000,
        help="The batch duration when computing the features.",
    )

    parser.add_argument(
        "--num-jobs",
        type=int,
        default=20,
        help="The number of extractor workers.",
    )

    return parser.parse_args()


def compute_fbank_split_single(params, idx):
    logging.info(
        f"Computing features for {idx}-th split of "
        f"{params.dataset} dataset {params.subset} subset"
    )
    lhotse.set_audio_duration_mismatch_tolerance(0.1)  # for emilia
    src_dir = Path(params.source_dir)
    output_dir = Path(params.dest_dir)

    if not src_dir.exists():
        logging.error(f"{src_dir} does not exist")
        return

    if not output_dir.exists():
        output_dir.mkdir(parents=True, exist_ok=True)

    num_digits = 8
    if params.type == "vocos":
        extractor = VocosFbank()
    else:
        raise NotImplementedError(f"{params.type} is not supported")

    prefix = params.dataset
    subset = params.subset
    suffix = "jsonl.gz"

    idx = f"{idx}".zfill(num_digits)
    cuts_filename = f"{prefix}_cuts_{subset}.{idx}.{suffix}"

    if (src_dir / cuts_filename).is_file():
        logging.info(f"Loading manifests {src_dir / cuts_filename}")
        cut_set = load_manifest_lazy(src_dir / cuts_filename)
    else:
        logging.warning(f"Raw {cuts_filename} does not exist, skipping")
        return

    cut_set = cut_set.resample(params.sampling_rate)

    if (output_dir / cuts_filename).is_file():
        logging.info(f"{cuts_filename} already exists - skipping.")
        return

    logging.info(f"Processing {subset}.{idx} of {prefix}")

    cut_set = cut_set.compute_and_store_features_batch(
        extractor=extractor,
        storage_path=f"{output_dir}/{prefix}_feats_{subset}_{idx}",
        num_workers=4,
        batch_duration=params.batch_duration,
        storage_type=LilcomChunkyWriter,
        overwrite=True,
    )
    logging.info(f"Saving file to {output_dir / cuts_filename}")
    cut_set.to_file(output_dir / cuts_filename)


def compute_fbank_split(params):
    if params.split_end < params.split_begin:
        logging.warning(
            f"Split begin should be smaller than split end, given "
            f"{params.split_begin} -> {params.split_end}."
        )

    with Pool(max_workers=params.num_jobs) as pool:
        futures = [
            pool.submit(compute_fbank_split_single, params, i)
            for i in range(params.split_begin, params.split_end)
        ]
        for f in futures:
            f.result()
            f.done()


def compute_fbank(params):
    logging.info(
        f"Computing features for {params.dataset} dataset {params.subset} subset"
    )
    src_dir = Path(params.source_dir)
    output_dir = Path(params.dest_dir)
    num_jobs = params.num_jobs
    if not output_dir.exists():
        output_dir.mkdir(parents=True, exist_ok=True)

    prefix = params.dataset
    subset = params.subset
    suffix = "jsonl.gz"

    cut_set_name = f"{prefix}_cuts_{subset}.{suffix}"

    if (src_dir / cut_set_name).is_file():
        logging.info(f"Loading manifests {src_dir / cut_set_name}")
        cut_set = load_manifest_lazy(src_dir / cut_set_name)
    else:
        recordings = load_manifest_lazy(
            src_dir / f"{prefix}_recordings_{subset}.{suffix}"
        )
        supervisions = load_manifest_lazy(
            src_dir / f"{prefix}_supervisions_{subset}.{suffix}"
        )
        cut_set = CutSet.from_manifests(
            recordings=recordings,
            supervisions=supervisions,
        )

    cut_set = cut_set.resample(params.sampling_rate)
    if params.type == "vocos":
        extractor = VocosFbank()
    else:
        raise NotImplementedError(f"{params.type} is not supported")

    cuts_filename = f"{prefix}_cuts_{subset}.{suffix}"
    if (output_dir / cuts_filename).is_file():
        logging.info(f"{prefix} {subset} already exists - skipping.")
        return
    logging.info(f"Processing {subset} of {prefix}")

    cut_set = cut_set.compute_and_store_features(
        extractor=extractor,
        storage_path=f"{output_dir}/{prefix}_feats_{subset}",
        num_jobs=num_jobs,
        storage_type=LilcomChunkyWriter,
    )
    logging.info(f"Saving file to {output_dir / cuts_filename}")
    cut_set.to_file(output_dir / cuts_filename)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    args = get_args()
    logging.info(vars(args))
    if args.split_cuts:
        compute_fbank_split(params=args)
    else:
        compute_fbank(params=args)
    logging.info("Done!")
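Note: the split-cuts branch above is normally driven by the CLI flags; a minimal programmatic sketch, assuming the module is imported and that split manifests named like emilia_cuts_ZH.00000000.jsonl.gz exist (dataset, subset, and directories are illustrative, not from this commit):

    from argparse import Namespace

    from zipvoice.bin.compute_fbank import compute_fbank_split

    params = Namespace(
        sampling_rate=24000, type="vocos",
        dataset="emilia", subset="ZH",
        source_dir="data/manifests/splits", dest_dir="data/fbank/splits",
        batch_duration=1000, num_jobs=20,
        split_begin=0, split_end=20,  # processes split indices 0..19 in parallel
    )
    compute_fbank_split(params)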
zipvoice/bin/generate_averaged_model.py ADDED
@@ -0,0 +1,229 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
This script loads checkpoints and averages them.

python3 -m zipvoice.bin.generate_averaged_model \
    --epoch 11 \
    --avg 4 \
    --model-name zipvoice \
    --exp-dir exp/zipvoice

It will generate a file `epoch-11-avg-4.pt` in the given `exp_dir`.
You can later load it by `torch.load("epoch-11-avg-4.pt")`.
"""

import argparse
import json
import logging
from pathlib import Path

import torch

from zipvoice.models.zipvoice import ZipVoice
from zipvoice.models.zipvoice_dialog import ZipVoiceDialog, ZipVoiceDialogStereo
from zipvoice.models.zipvoice_distill import ZipVoiceDistill
from zipvoice.tokenizer.tokenizer import SimpleTokenizer
from zipvoice.utils.checkpoint import (
    average_checkpoints_with_averaged_model,
    find_checkpoints,
)
from zipvoice.utils.common import AttributeDict


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=11,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=4,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' or '--iter'",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="exp/zipvoice",
        help="The experiment dir",
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="zipvoice",
        choices=[
            "zipvoice",
            "zipvoice_distill",
            "zipvoice_dialog",
            "zipvoice_dialog_stereo",
        ],
        help="The model type to be averaged.",
    )

    return parser


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()
    params = AttributeDict()
    params.update(vars(args))
    params.exp_dir = Path(params.exp_dir)

    with open(params.exp_dir / "model.json", "r") as f:
        model_config = json.load(f)

    # Any tokenizer can be used here.
    # Use SimpleTokenizer for simplicity.
    tokenizer = SimpleTokenizer(token_file=params.exp_dir / "tokens.txt")
    if params.model_name in ["zipvoice", "zipvoice_distill"]:
        tokenizer_config = {
            "vocab_size": tokenizer.vocab_size,
            "pad_id": tokenizer.pad_id,
        }
    elif params.model_name in ["zipvoice_dialog", "zipvoice_dialog_stereo"]:
        tokenizer_config = {
            "vocab_size": tokenizer.vocab_size,
            "pad_id": tokenizer.pad_id,
            "spk_a_id": tokenizer.spk_a_id,
            "spk_b_id": tokenizer.spk_b_id,
        }

    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    logging.info("Script started")

    params.device = torch.device("cpu")
    logging.info(f"Device: {params.device}")

    logging.info("About to create model")
    if params.model_name == "zipvoice":
        model = ZipVoice(
            **model_config["model"],
            **tokenizer_config,
        )
    elif params.model_name == "zipvoice_distill":
        model = ZipVoiceDistill(
            **model_config["model"],
            **tokenizer_config,
        )
    elif params.model_name == "zipvoice_dialog":
        model = ZipVoiceDialog(
            **model_config["model"],
            **tokenizer_config,
        )
    elif params.model_name == "zipvoice_dialog_stereo":
        model = ZipVoiceDialogStereo(
            **model_config["model"],
            **tokenizer_config,
        )
    else:
        raise ValueError(f"Unknown model name: {params.model_name}")

    if params.iter > 0:
        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
            : params.avg + 1
        ]
        if len(filenames) == 0:
            raise ValueError(
                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
            )
        elif len(filenames) < params.avg + 1:
            raise ValueError(
                f"Not enough checkpoints ({len(filenames)}) found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        filename_start = filenames[-1]
        filename_end = filenames[0]
        logging.info(
            "Calculating the averaged model over iteration checkpoints"
            f" from {filename_start} (excluded) to {filename_end}"
        )
        model.to(params.device)
        model.load_state_dict(
            average_checkpoints_with_averaged_model(
                filename_start=filename_start,
                filename_end=filename_end,
                device=params.device,
            ),
            strict=True,
        )
    else:
        assert params.avg > 0, params.avg
        start = params.epoch - params.avg
        assert start >= 1, start
        filename_start = f"{params.exp_dir}/epoch-{start}.pt"
        filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
        logging.info(
            f"Calculating the averaged model over epoch range from "
            f"{start} (excluded) to {params.epoch}"
        )
        model.to(params.device)
        model.load_state_dict(
            average_checkpoints_with_averaged_model(
                filename_start=filename_start,
                filename_end=filename_end,
                device=params.device,
            ),
            strict=True,
        )
    if params.iter > 0:
        filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
    else:
        filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"

    logging.info(f"Saving the averaged checkpoint to {filename}")
    torch.save({"model": model.state_dict()}, filename)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    logging.info("Done!")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    main()
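Note: a minimal sketch of consuming the averaged checkpoint (epoch mode shown; the state dict is stored under the "model" key, matching the torch.save call above):

    import torch

    ckpt = torch.load("exp/zipvoice/epoch-11-avg-4.pt", map_location="cpu")
    state_dict = ckpt["model"]  # pass to model.load_state_dict(...)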
zipvoice/bin/infer_zipvoice.py
ADDED
|
@@ -0,0 +1,614 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""
|
| 19 |
+
This script generates speech with our pre-trained ZipVoice or
|
| 20 |
+
ZipVoice-Distill models. If no local model is specified,
|
| 21 |
+
Required files will be automatically downloaded from HuggingFace.
|
| 22 |
+
|
| 23 |
+
Usage:
|
| 24 |
+
|
| 25 |
+
Note: If you having trouble connecting to HuggingFace,
|
| 26 |
+
try switching endpoint to mirror site:
|
| 27 |
+
export HF_ENDPOINT=https://hf-mirror.com
|
| 28 |
+
|
| 29 |
+
(1) Inference of a single sentence:
|
| 30 |
+
|
| 31 |
+
python3 -m zipvoice.bin.infer_zipvoice \
|
| 32 |
+
--model-name zipvoice \
|
| 33 |
+
--prompt-wav prompt.wav \
|
| 34 |
+
--prompt-text "I am a prompt." \
|
| 35 |
+
--text "I am a sentence." \
|
| 36 |
+
--res-wav-path result.wav
|
| 37 |
+
|
| 38 |
+
(2) Inference of a list of sentences:
|
| 39 |
+
|
| 40 |
+
python3 -m zipvoice.bin.infer_zipvoice \
|
| 41 |
+
--model-name zipvoice \
|
| 42 |
+
--test-list test.tsv \
|
| 43 |
+
--res-dir results
|
| 44 |
+
|
| 45 |
+
`--model-name` can be `zipvoice` or `zipvoice_distill`,
|
| 46 |
+
which are the models before and after distillation, respectively.
|
| 47 |
+
|
| 48 |
+
Each line of `test.tsv` is in the format of
|
| 49 |
+
`{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
import argparse
|
| 53 |
+
import datetime as dt
|
| 54 |
+
import json
|
| 55 |
+
import logging
|
| 56 |
+
import os
|
| 57 |
+
from pathlib import Path
|
| 58 |
+
from typing import Optional
|
| 59 |
+
|
| 60 |
+
import numpy as np
|
| 61 |
+
import safetensors.torch
|
| 62 |
+
import torch
|
| 63 |
+
import torchaudio
|
| 64 |
+
from huggingface_hub import hf_hub_download
|
| 65 |
+
from lhotse.utils import fix_random_seed
|
| 66 |
+
from vocos import Vocos
|
| 67 |
+
|
| 68 |
+
from zipvoice.models.zipvoice import ZipVoice
|
| 69 |
+
from zipvoice.models.zipvoice_distill import ZipVoiceDistill
|
| 70 |
+
from zipvoice.tokenizer.tokenizer import (
|
| 71 |
+
EmiliaTokenizer,
|
| 72 |
+
EspeakTokenizer,
|
| 73 |
+
LibriTTSTokenizer,
|
| 74 |
+
SimpleTokenizer,
|
| 75 |
+
)
|
| 76 |
+
from zipvoice.utils.checkpoint import load_checkpoint
|
| 77 |
+
from zipvoice.utils.common import AttributeDict
|
| 78 |
+
from zipvoice.utils.feature import VocosFbank
|
| 79 |
+
|
| 80 |
+
HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
|
| 81 |
+
MODEL_DIR = {
|
| 82 |
+
"zipvoice": "zipvoice",
|
| 83 |
+
"zipvoice_distill": "zipvoice_distill",
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def get_parser():
|
| 88 |
+
parser = argparse.ArgumentParser(
|
| 89 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
parser.add_argument(
|
| 93 |
+
"--model-name",
|
| 94 |
+
type=str,
|
| 95 |
+
default="zipvoice",
|
| 96 |
+
choices=["zipvoice", "zipvoice_distill"],
|
| 97 |
+
help="The model used for inference",
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
parser.add_argument(
|
| 101 |
+
"--model-dir",
|
| 102 |
+
type=str,
|
| 103 |
+
default=None,
|
| 104 |
+
help="The model directory that contains model checkpoint, configuration "
|
| 105 |
+
"file model.json, and tokens file tokens.txt. Will download pre-trained "
|
| 106 |
+
"checkpoint from huggingface if not specified.",
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
parser.add_argument(
|
| 110 |
+
"--checkpoint-name",
|
| 111 |
+
type=str,
|
| 112 |
+
default="model.pt",
|
| 113 |
+
help="The name of model checkpoint.",
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
parser.add_argument(
|
| 117 |
+
"--vocoder-path",
|
| 118 |
+
type=str,
|
| 119 |
+
default=None,
|
| 120 |
+
help="The vocoder checkpoint. "
|
| 121 |
+
"Will download pre-trained vocoder from huggingface if not specified.",
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
parser.add_argument(
|
| 125 |
+
"--tokenizer",
|
| 126 |
+
type=str,
|
| 127 |
+
default="emilia",
|
| 128 |
+
choices=["emilia", "libritts", "espeak", "simple"],
|
| 129 |
+
help="Tokenizer type.",
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
parser.add_argument(
|
| 133 |
+
"--lang",
|
| 134 |
+
type=str,
|
| 135 |
+
default="en-us",
|
| 136 |
+
help="Language identifier, used when tokenizer type is espeak. see"
|
| 137 |
+
"https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
parser.add_argument(
|
| 141 |
+
"--test-list",
|
| 142 |
+
type=str,
|
| 143 |
+
default=None,
|
| 144 |
+
help="The list of prompt speech, prompt_transcription, "
|
| 145 |
+
"and text to synthesizein the format of "
|
| 146 |
+
"'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.",
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
parser.add_argument(
|
| 150 |
+
"--prompt-wav",
|
| 151 |
+
type=str,
|
| 152 |
+
default=None,
|
| 153 |
+
help="The prompt wav to mimic",
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
parser.add_argument(
|
| 157 |
+
"--prompt-text",
|
| 158 |
+
type=str,
|
| 159 |
+
default=None,
|
| 160 |
+
help="The transcription of the prompt wav",
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
parser.add_argument(
|
| 164 |
+
"--text",
|
| 165 |
+
type=str,
|
| 166 |
+
default=None,
|
| 167 |
+
help="The text to synthesize",
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
parser.add_argument(
|
| 171 |
+
"--res-dir",
|
| 172 |
+
type=str,
|
| 173 |
+
default="results",
|
| 174 |
+
help="""
|
| 175 |
+
Path name of the generated wavs dir,
|
| 176 |
+
used when test-list is not None
|
| 177 |
+
""",
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
parser.add_argument(
|
| 181 |
+
"--res-wav-path",
|
| 182 |
+
type=str,
|
| 183 |
+
default="result.wav",
|
| 184 |
+
help="""
|
| 185 |
+
Path name of the generated wav path,
|
| 186 |
+
used when test-list is None
|
| 187 |
+
""",
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
parser.add_argument(
|
| 191 |
+
"--guidance-scale",
|
| 192 |
+
type=float,
|
| 193 |
+
default=None,
|
| 194 |
+
help="The scale of classifier-free guidance during inference.",
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
parser.add_argument(
|
| 198 |
+
"--num-step",
|
| 199 |
+
type=int,
|
| 200 |
+
default=None,
|
| 201 |
+
help="The number of sampling steps.",
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
parser.add_argument(
|
| 205 |
+
"--feat-scale",
|
| 206 |
+
type=float,
|
| 207 |
+
default=0.1,
|
| 208 |
+
help="The scale factor of fbank feature",
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
parser.add_argument(
|
| 212 |
+
"--speed",
|
| 213 |
+
type=float,
|
| 214 |
+
default=1.0,
|
| 215 |
+
help="Control speech speed, 1.0 means normal, >1.0 means speed up",
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
parser.add_argument(
|
| 219 |
+
"--t-shift",
|
| 220 |
+
type=float,
|
| 221 |
+
default=0.5,
|
| 222 |
+
help="Shift t to smaller ones if t_shift < 1.0",
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
parser.add_argument(
|
| 226 |
+
"--target-rms",
|
| 227 |
+
type=float,
|
| 228 |
+
default=0.1,
|
| 229 |
+
help="Target speech normalization rms value, set to 0 to disable normalization",
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
parser.add_argument(
|
| 233 |
+
"--seed",
|
| 234 |
+
type=int,
|
| 235 |
+
default=666,
|
| 236 |
+
help="Random seed",
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
return parser
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def get_vocoder(vocos_local_path: Optional[str] = None):
|
| 243 |
+
if vocos_local_path:
|
| 244 |
+
vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
|
| 245 |
+
state_dict = torch.load(
|
| 246 |
+
f"{vocos_local_path}/pytorch_model.bin",
|
| 247 |
+
weights_only=True,
|
| 248 |
+
map_location="cpu",
|
| 249 |
+
)
|
| 250 |
+
vocoder.load_state_dict(state_dict)
|
| 251 |
+
else:
|
| 252 |
+
vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
| 253 |
+
return vocoder
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def generate_sentence(
|
| 257 |
+
save_path: str,
|
| 258 |
+
prompt_text: str,
|
| 259 |
+
prompt_wav: str,
|
| 260 |
+
text: str,
|
| 261 |
+
model: torch.nn.Module,
|
| 262 |
+
vocoder: torch.nn.Module,
|
| 263 |
+
tokenizer: EmiliaTokenizer,
|
| 264 |
+
feature_extractor: VocosFbank,
|
| 265 |
+
device: torch.device,
|
| 266 |
+
num_step: int = 16,
|
| 267 |
+
guidance_scale: float = 1.0,
|
| 268 |
+
speed: float = 1.0,
|
| 269 |
+
t_shift: float = 0.5,
|
| 270 |
+
target_rms: float = 0.1,
|
| 271 |
+
feat_scale: float = 0.1,
|
| 272 |
+
sampling_rate: int = 24000,
|
| 273 |
+
):
|
| 274 |
+
"""
|
| 275 |
+
Generate waveform of a text based on a given prompt
|
| 276 |
+
waveform and its transcription.
|
| 277 |
+
|
| 278 |
+
Args:
|
| 279 |
+
save_path (str): Path to save the generated wav.
|
| 280 |
+
prompt_text (str): Transcription of the prompt wav.
|
| 281 |
+
prompt_wav (str): Path to the prompt wav file.
|
| 282 |
+
text (str): Text to be synthesized into a waveform.
|
| 283 |
+
model (torch.nn.Module): The model used for generation.
|
| 284 |
+
vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
|
| 285 |
+
tokenizer (EmiliaTokenizer): The tokenizer used to convert text to tokens.
|
| 286 |
+
feature_extractor (VocosFbank): The feature extractor used to
|
| 287 |
+
extract acoustic features.
|
| 288 |
+
device (torch.device): The device on which computations are performed.
|
| 289 |
+
num_step (int, optional): Number of steps for decoding. Defaults to 16.
|
| 290 |
+
guidance_scale (float, optional): Scale for classifier-free guidance.
|
| 291 |
+
Defaults to 1.0.
|
| 292 |
+
speed (float, optional): Speed control. Defaults to 1.0.
|
| 293 |
+
t_shift (float, optional): Time shift. Defaults to 0.5.
|
| 294 |
+
target_rms (float, optional): Target RMS for waveform normalization.
|
| 295 |
+
Defaults to 0.1.
|
| 296 |
+
feat_scale (float, optional): Scale for features.
|
| 297 |
+
Defaults to 0.1.
|
| 298 |
+
sampling_rate (int, optional): Sampling rate for the waveform.
|
| 299 |
+
Defaults to 24000.
|
| 300 |
+
Returns:
|
| 301 |
+
metrics (dict): Dictionary containing time and real-time
|
| 302 |
+
factor metrics for processing.
|
| 303 |
+
"""
|
| 304 |
+
# Convert text to tokens
|
| 305 |
+
tokens = tokenizer.texts_to_token_ids([text])
|
| 306 |
+
prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
|
| 307 |
+
|
| 308 |
+
# Load and preprocess prompt wav
|
| 309 |
+
prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)
|
| 310 |
+
|
| 311 |
+
if prompt_sampling_rate != sampling_rate:
|
| 312 |
+
resampler = torchaudio.transforms.Resample(
|
| 313 |
+
orig_freq=prompt_sampling_rate, new_freq=sampling_rate
|
| 314 |
+
)
|
| 315 |
+
prompt_wav = resampler(prompt_wav)
|
| 316 |
+
|
| 317 |
+
prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
|
| 318 |
+
if prompt_rms < target_rms:
|
| 319 |
+
prompt_wav = prompt_wav * target_rms / prompt_rms
|
| 320 |
+
|
| 321 |
+
# Extract features from prompt wav
|
| 322 |
+
prompt_features = feature_extractor.extract(
|
| 323 |
+
prompt_wav, sampling_rate=sampling_rate
|
| 324 |
+
).to(device)
|
| 325 |
+
|
| 326 |
+
prompt_features = prompt_features.unsqueeze(0) * feat_scale
|
| 327 |
+
prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
|
| 328 |
+
|
| 329 |
+
# Start timing
|
| 330 |
+
start_t = dt.datetime.now()
|
| 331 |
+
|
| 332 |
+
# Generate features
|
| 333 |
+
(
|
| 334 |
+
pred_features,
|
| 335 |
+
pred_features_lens,
|
| 336 |
+
pred_prompt_features,
|
| 337 |
+
pred_prompt_features_lens,
|
| 338 |
+
) = model.sample(
|
| 339 |
+
tokens=tokens,
|
| 340 |
+
prompt_tokens=prompt_tokens,
|
| 341 |
+
prompt_features=prompt_features,
|
| 342 |
+
prompt_features_lens=prompt_features_lens,
|
| 343 |
+
speed=speed,
|
| 344 |
+
t_shift=t_shift,
|
| 345 |
+
duration="predict",
|
| 346 |
+
num_step=num_step,
|
| 347 |
+
guidance_scale=guidance_scale,
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
# Postprocess predicted features
|
| 351 |
+
pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
|
| 352 |
+
|
| 353 |
+
# Start vocoder processing
|
| 354 |
+
start_vocoder_t = dt.datetime.now()
|
| 355 |
+
wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
|
| 356 |
+
|
| 357 |
+
# Calculate processing times and real-time factors
|
| 358 |
+
t = (dt.datetime.now() - start_t).total_seconds()
|
| 359 |
+
t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
|
| 360 |
+
t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
|
| 361 |
+
wav_seconds = wav.shape[-1] / sampling_rate
|
| 362 |
+
rtf = t / wav_seconds
|
| 363 |
+
rtf_no_vocoder = t_no_vocoder / wav_seconds
|
| 364 |
+
rtf_vocoder = t_vocoder / wav_seconds
|
| 365 |
+
metrics = {
|
| 366 |
+
"t": t,
|
| 367 |
+
"t_no_vocoder": t_no_vocoder,
|
| 368 |
+
"t_vocoder": t_vocoder,
|
| 369 |
+
"wav_seconds": wav_seconds,
|
| 370 |
+
"rtf": rtf,
|
| 371 |
+
"rtf_no_vocoder": rtf_no_vocoder,
|
| 372 |
+
"rtf_vocoder": rtf_vocoder,
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
# Adjust wav volume if necessary
|
| 376 |
+
if prompt_rms < target_rms:
|
| 377 |
+
wav = wav * prompt_rms / target_rms
|
| 378 |
+
torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
|
| 379 |
+
|
| 380 |
+
return metrics
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def generate_list(
|
| 384 |
+
res_dir: str,
|
| 385 |
+
test_list: str,
|
| 386 |
+
model: torch.nn.Module,
|
| 387 |
+
vocoder: torch.nn.Module,
|
| 388 |
+
tokenizer: EmiliaTokenizer,
|
| 389 |
+
feature_extractor: VocosFbank,
|
| 390 |
+
device: torch.device,
|
| 391 |
+
num_step: int = 16,
|
| 392 |
+
guidance_scale: float = 1.0,
|
| 393 |
+
speed: float = 1.0,
|
| 394 |
+
t_shift: float = 0.5,
|
| 395 |
+
target_rms: float = 0.1,
|
| 396 |
+
feat_scale: float = 0.1,
|
| 397 |
+
sampling_rate: int = 24000,
|
| 398 |
+
):
|
| 399 |
+
total_t = []
|
| 400 |
+
total_t_no_vocoder = []
|
| 401 |
+
total_t_vocoder = []
|
| 402 |
+
total_wav_seconds = []
|
| 403 |
+
|
| 404 |
+
with open(test_list, "r") as fr:
|
| 405 |
+
lines = fr.readlines()
|
| 406 |
+
|
| 407 |
+
for i, line in enumerate(lines):
|
| 408 |
+
wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
|
| 409 |
+
save_path = f"{res_dir}/{wav_name}.wav"
|
| 410 |
+
metrics = generate_sentence(
|
| 411 |
+
save_path=save_path,
|
| 412 |
+
prompt_text=prompt_text,
|
| 413 |
+
prompt_wav=prompt_wav,
|
| 414 |
+
text=text,
|
| 415 |
+
model=model,
|
| 416 |
+
vocoder=vocoder,
|
| 417 |
+
tokenizer=tokenizer,
|
| 418 |
+
feature_extractor=feature_extractor,
|
| 419 |
+
device=device,
|
| 420 |
+
num_step=num_step,
|
| 421 |
+
guidance_scale=guidance_scale,
|
| 422 |
+
speed=speed,
|
| 423 |
+
t_shift=t_shift,
|
| 424 |
+
target_rms=target_rms,
|
| 425 |
+
feat_scale=feat_scale,
|
| 426 |
+
sampling_rate=sampling_rate,
|
| 427 |
+
)
|
| 428 |
+
logging.info(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
|
| 429 |
+
total_t.append(metrics["t"])
|
| 430 |
+
total_t_no_vocoder.append(metrics["t_no_vocoder"])
|
| 431 |
+
total_t_vocoder.append(metrics["t_vocoder"])
|
| 432 |
+
total_wav_seconds.append(metrics["wav_seconds"])
|
| 433 |
+
|
| 434 |
+
logging.info(f"Average RTF: {np.sum(total_t) / np.sum(total_wav_seconds):.4f}")
|
| 435 |
+
logging.info(
|
| 436 |
+
f"Average RTF w/o vocoder: "
|
| 437 |
+
f"{np.sum(total_t_no_vocoder) / np.sum(total_wav_seconds):.4f}"
|
| 438 |
+
)
|
| 439 |
+
logging.info(
|
| 440 |
+
f"Average RTF vocoder: "
|
| 441 |
+
f"{np.sum(total_t_vocoder) / np.sum(total_wav_seconds):.4f}"
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
@torch.inference_mode()
|
| 446 |
+
def main():
|
| 447 |
+
parser = get_parser()
|
| 448 |
+
args = parser.parse_args()
|
| 449 |
+
|
| 450 |
+
params = AttributeDict()
|
| 451 |
+
params.update(vars(args))
|
| 452 |
+
fix_random_seed(params.seed)
|
| 453 |
+
|
| 454 |
+
model_defaults = {
|
| 455 |
+
"zipvoice": {
|
| 456 |
+
"num_step": 16,
|
| 457 |
+
"guidance_scale": 1.0,
|
| 458 |
+
},
|
| 459 |
+
"zipvoice_distill": {
|
| 460 |
+
"num_step": 8,
|
| 461 |
+
"guidance_scale": 3.0,
|
| 462 |
+
},
|
| 463 |
+
}
|
| 464 |
+
|
| 465 |
+
model_specific_defaults = model_defaults.get(params.model_name, {})
|
| 466 |
+
|
| 467 |
+
for param, value in model_specific_defaults.items():
|
| 468 |
+
if getattr(params, param) is None:
|
| 469 |
+
setattr(params, param, value)
|
| 470 |
+
logging.info(f"Setting {param} to default value: {value}")
|
| 471 |
+
|
| 472 |
+
assert (params.test_list is not None) ^ (
|
| 473 |
+
(params.prompt_wav and params.prompt_text and params.text) is not None
|
| 474 |
+
), (
|
| 475 |
+
"For inference, please provide prompts and text with either '--test-list'"
|
| 476 |
+
" or '--prompt-wav, --prompt-text and --text'."
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
if params.model_dir is not None:
|
| 480 |
+
params.model_dir = Path(params.model_dir)
|
| 481 |
+
if not params.model_dir.is_dir():
|
| 482 |
+
raise FileNotFoundError(f"{params.model_dir} does not exist")
|
| 483 |
+
for filename in [params.checkpoint_name, "model.json", "tokens.txt"]:
|
| 484 |
+
if not (params.model_dir / filename).is_file():
|
| 485 |
+
raise FileNotFoundError(f"{params.model_dir / filename} does not exist")
|
| 486 |
+
model_ckpt = params.model_dir / params.checkpoint_name
|
| 487 |
+
model_config = params.model_dir / "model.json"
|
| 488 |
+
token_file = params.model_dir / "tokens.txt"
|
| 489 |
+
logging.info(
|
| 490 |
+
f"Using local model dir {params.model_dir}, "
|
| 491 |
+
f"checkpoint {params.checkpoint_name}"
|
| 492 |
+
)
|
| 493 |
+
else:
|
| 494 |
+
logging.info("Using pretrained model from the huggingface")
|
| 495 |
+
logging.info("Downloading the requires files from HuggingFace")
|
| 496 |
+
model_ckpt = hf_hub_download(
|
| 497 |
+
HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/model.pt"
|
| 498 |
+
)
|
| 499 |
+
model_config = hf_hub_download(
|
| 500 |
+
HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/model.json"
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
token_file = hf_hub_download(
|
| 504 |
+
HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/tokens.txt"
|
| 505 |
+
)
|
| 506 |
+
|
| 507 |
+
logging.info("Loading model...")
|
| 508 |
+
|
| 509 |
+
if params.tokenizer == "emilia":
|
| 510 |
+
tokenizer = EmiliaTokenizer(token_file=token_file)
|
| 511 |
+
elif params.tokenizer == "libritts":
|
| 512 |
+
tokenizer = LibriTTSTokenizer(token_file=token_file)
|
| 513 |
+
elif params.tokenizer == "espeak":
|
| 514 |
+
tokenizer = EspeakTokenizer(token_file=token_file, lang=params.lang)
|
| 515 |
+
else:
|
| 516 |
+
assert params.tokenizer == "simple"
|
| 517 |
+
tokenizer = SimpleTokenizer(token_file=token_file)
|
| 518 |
+
|
| 519 |
+
tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
|
| 520 |
+
|
| 521 |
+
with open(model_config, "r") as f:
|
| 522 |
+
model_config = json.load(f)
|
| 523 |
+
|
| 524 |
+
if params.model_name == "zipvoice":
|
| 525 |
+
model = ZipVoice(
|
| 526 |
+
**model_config["model"],
|
| 527 |
+
**tokenizer_config,
|
| 528 |
+
)
|
| 529 |
+
else:
|
| 530 |
+
assert params.model_name == "zipvoice_distill"
|
| 531 |
+
model = ZipVoiceDistill(
|
| 532 |
+
**model_config["model"],
|
| 533 |
+
**tokenizer_config,
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
+
if str(model_ckpt).endswith(".safetensors"):
|
| 537 |
+
safetensors.torch.load_model(model, model_ckpt)
|
| 538 |
+
elif str(model_ckpt).endswith(".pt"):
|
| 539 |
+
load_checkpoint(filename=model_ckpt, model=model, strict=True)
|
| 540 |
+
else:
|
| 541 |
+
raise NotImplementedError(f"Unsupported model checkpoint format: {model_ckpt}")
|
| 542 |
+
|
| 543 |
+
if torch.cuda.is_available():
|
| 544 |
+
params.device = torch.device("cuda", 0)
|
| 545 |
+
elif torch.backends.mps.is_available():
|
| 546 |
+
params.device = torch.device("mps")
|
| 547 |
+
else:
|
| 548 |
+
params.device = torch.device("cpu")
|
| 549 |
+
logging.info(f"Device: {params.device}")
|
| 550 |
+
|
| 551 |
+
model = model.to(params.device)
|
| 552 |
+
model.eval()
|
| 553 |
+
|
| 554 |
+
vocoder = get_vocoder(params.vocoder_path)
|
| 555 |
+
vocoder = vocoder.to(params.device)
|
| 556 |
+
vocoder.eval()
|
| 557 |
+
|
| 558 |
+
if model_config["feature"]["type"] == "vocos":
|
| 559 |
+
feature_extractor = VocosFbank()
|
| 560 |
+
else:
|
| 561 |
+
raise NotImplementedError(
|
| 562 |
+
f"Unsupported feature type: {model_config['feature']['type']}"
|
| 563 |
+
)
|
| 564 |
+
params.sampling_rate = model_config["feature"]["sampling_rate"]
|
| 565 |
+
|
| 566 |
+
logging.info("Start generating...")
|
| 567 |
+
if params.test_list:
|
| 568 |
+
os.makedirs(params.res_dir, exist_ok=True)
|
| 569 |
+
generate_list(
|
| 570 |
+
res_dir=params.res_dir,
|
| 571 |
+
test_list=params.test_list,
|
| 572 |
+
model=model,
|
| 573 |
+
vocoder=vocoder,
|
| 574 |
+
tokenizer=tokenizer,
|
| 575 |
+
feature_extractor=feature_extractor,
|
| 576 |
+
device=params.device,
|
| 577 |
+
num_step=params.num_step,
|
| 578 |
+
guidance_scale=params.guidance_scale,
|
| 579 |
+
speed=params.speed,
|
| 580 |
+
t_shift=params.t_shift,
|
| 581 |
+
target_rms=params.target_rms,
|
| 582 |
+
feat_scale=params.feat_scale,
|
| 583 |
+
sampling_rate=params.sampling_rate,
|
| 584 |
+
)
|
| 585 |
+
else:
|
| 586 |
+
generate_sentence(
|
| 587 |
+
save_path=params.res_wav_path,
|
| 588 |
+
prompt_text=params.prompt_text,
|
| 589 |
+
prompt_wav=params.prompt_wav,
|
| 590 |
+
text=params.text,
|
| 591 |
+
model=model,
|
| 592 |
+
vocoder=vocoder,
|
| 593 |
+
tokenizer=tokenizer,
|
| 594 |
+
feature_extractor=feature_extractor,
|
| 595 |
+
device=params.device,
|
| 596 |
+
num_step=params.num_step,
|
| 597 |
+
guidance_scale=params.guidance_scale,
|
| 598 |
+
speed=params.speed,
|
| 599 |
+
t_shift=params.t_shift,
|
| 600 |
+
target_rms=params.target_rms,
|
| 601 |
+
feat_scale=params.feat_scale,
|
| 602 |
+
sampling_rate=params.sampling_rate,
|
| 603 |
+
)
|
| 604 |
+
logging.info("Done")
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
if __name__ == "__main__":
|
| 608 |
+
torch.set_num_threads(1)
|
| 609 |
+
torch.set_num_interop_threads(1)
|
| 610 |
+
|
| 611 |
+
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
| 612 |
+
logging.basicConfig(format=formatter, level=logging.INFO, force=True)
|
| 613 |
+
|
| 614 |
+
main()
|
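The block above completes the single-file inference flow: resolve files (local dir or HuggingFace), build a tokenizer, construct the model from model.json, load the checkpoint, pick a device, and generate. A hedged sketch of the same flow driven programmatically rather than via the CLI; the `zipvoice.models.zipvoice` import path and all paths and prompt strings below are assumptions for illustration, not a fixed API:

# Minimal programmatic driver mirroring the single-sentence path above.
# The ZipVoice import path and all file/prompt strings are assumptions.
import json

import torch
from huggingface_hub import hf_hub_download

from zipvoice.bin.infer_zipvoice import generate_sentence, get_vocoder
from zipvoice.models.zipvoice import ZipVoice  # assumed module path
from zipvoice.tokenizer.tokenizer import EmiliaTokenizer
from zipvoice.utils.checkpoint import load_checkpoint
from zipvoice.utils.feature import VocosFbank

repo = "k2-fsa/ZipVoice"
ckpt = hf_hub_download(repo, filename="zipvoice/model.pt")
token_file = hf_hub_download(repo, filename="zipvoice/tokens.txt")
with open(hf_hub_download(repo, filename="zipvoice/model.json")) as f:
    config = json.load(f)

tokenizer = EmiliaTokenizer(token_file=token_file)
model = ZipVoice(
    **config["model"],
    vocab_size=tokenizer.vocab_size,
    pad_id=tokenizer.pad_id,
)
load_checkpoint(filename=ckpt, model=model, strict=True)

device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device).eval()
vocoder = get_vocoder(None).to(device).eval()  # default Vocos from HuggingFace

with torch.inference_mode():
    generate_sentence(
        save_path="result.wav",        # hypothetical output path
        prompt_text="I am a prompt.",  # hypothetical prompt transcription
        prompt_wav="prompt.wav",       # hypothetical prompt recording
        text="I am a sentence.",
        model=model,
        vocoder=vocoder,
        tokenizer=tokenizer,
        feature_extractor=VocosFbank(),
        device=device,
        num_step=16,
        guidance_scale=1.0,
        speed=1.0,
        t_shift=0.5,
        target_rms=0.1,
        feat_scale=0.1,
        sampling_rate=config["feature"]["sampling_rate"],
    )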
zipvoice/bin/infer_zipvoice_dialog.py
ADDED
|
@@ -0,0 +1,756 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""
|
| 19 |
+
This script generates speech with our pre-trained ZipVoice-Dialog or
|
| 20 |
+
ZipVoice-Dialog-Stereo models. If no local model is specified,
|
| 21 |
+
required files will be automatically downloaded from HuggingFace.
|
| 22 |
+
|
| 23 |
+
Usage:
|
| 24 |
+
|
| 25 |
+
Note: If you have trouble connecting to HuggingFace,
|
| 26 |
+
try switching endpoint to mirror site:
|
| 27 |
+
export HF_ENDPOINT=https://hf-mirror.com
|
| 28 |
+
|
| 29 |
+
python3 -m zipvoice.bin.infer_zipvoice_dialog \
|
| 30 |
+
--model-name zipvoice_dialog \
|
| 31 |
+
--test-list test.tsv \
|
| 32 |
+
--res-dir results
|
| 33 |
+
|
| 34 |
+
`--model-name` can be `zipvoice_dialog` or `zipvoice_dialog_stereo`,
|
| 35 |
+
which generate mono and stereo dialogues, respectively.
|
| 36 |
+
|
| 37 |
+
Each line of `test.tsv` is in the format of a merged conversation:
|
| 38 |
+
'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'
|
| 39 |
+
or a split conversation:
|
| 40 |
+
'{wav_name}\t{spk1_prompt_transcription}\t{spk2_prompt_transcription}
|
| 41 |
+
\t{spk1_prompt_wav}\t{spk2_prompt_wav}\t{text}'
|
| 42 |
+
"""
|
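Since the two row layouts above are easy to get wrong, here is a small sketch that writes one merged-conversation row (four fields) and one split-conversation row (six fields); all wav names, paths, and transcriptions are hypothetical placeholders:

# Sketch: write one merged row (4 fields) and one split row (6 fields).
# All wav names, paths, and transcriptions are placeholders.
rows = [
    # merged: the prompt wav already contains both speakers
    ["dlg_merged", "[S1]Hi there.[S2]Hello!", "merged_prompt.wav",
     "[S1]How are you?[S2]Fine, thanks."],
    # split: one mono prompt wav and one transcription per speaker
    ["dlg_split", "Hi there.", "Hello!",
     "spk1_prompt.wav", "spk2_prompt.wav",
     "[S1]How are you?[S2]Fine, thanks."],
]
with open("test.tsv", "w", encoding="utf-8") as f:
    for row in rows:
        f.write("\t".join(row) + "\n")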
| 43 |
+
|
| 44 |
+
import argparse
|
| 45 |
+
import datetime as dt
|
| 46 |
+
import json
|
| 47 |
+
import logging
|
| 48 |
+
import os
|
| 49 |
+
from pathlib import Path
|
| 50 |
+
from typing import List, Optional, Union
|
| 51 |
+
|
| 52 |
+
import numpy as np
|
| 53 |
+
import safetensors.torch
|
| 54 |
+
import torch
|
| 55 |
+
import torchaudio
|
| 56 |
+
from huggingface_hub import hf_hub_download
|
| 57 |
+
from lhotse.utils import fix_random_seed
|
| 58 |
+
from vocos import Vocos
|
| 59 |
+
|
| 60 |
+
from zipvoice.models.zipvoice_dialog import ZipVoiceDialog, ZipVoiceDialogStereo
|
| 61 |
+
from zipvoice.tokenizer.tokenizer import DialogTokenizer
|
| 62 |
+
from zipvoice.utils.checkpoint import load_checkpoint
|
| 63 |
+
from zipvoice.utils.common import AttributeDict
|
| 64 |
+
from zipvoice.utils.feature import VocosFbank
|
| 65 |
+
|
| 66 |
+
HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
|
| 67 |
+
MODEL_DIR = {
|
| 68 |
+
"zipvoice_dialog": "zipvoice_dialog",
|
| 69 |
+
"zipvoice_dialog_stereo": "zipvoice_dialog_stereo",
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def get_parser():
|
| 74 |
+
parser = argparse.ArgumentParser(
|
| 75 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
parser.add_argument(
|
| 79 |
+
"--model-name",
|
| 80 |
+
type=str,
|
| 81 |
+
default="zipvoice_dialog",
|
| 82 |
+
choices=["zipvoice_dialog", "zipvoice_dialog_stereo"],
|
| 83 |
+
help="The model used for inference",
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
parser.add_argument(
|
| 87 |
+
"--model-dir",
|
| 88 |
+
type=str,
|
| 89 |
+
default=None,
|
| 90 |
+
help="The model directory that contains model checkpoint, configuration "
|
| 91 |
+
"file model.json, and tokens file tokens.txt. Will download pre-trained "
|
| 92 |
+
"checkpoint from huggingface if not specified.",
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
parser.add_argument(
|
| 96 |
+
"--checkpoint-name",
|
| 97 |
+
type=str,
|
| 98 |
+
default="model.pt",
|
| 99 |
+
help="The name of model checkpoint.",
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
parser.add_argument(
|
| 103 |
+
"--vocoder-path",
|
| 104 |
+
type=str,
|
| 105 |
+
default=None,
|
| 106 |
+
help="The vocoder checkpoint. "
|
| 107 |
+
"Will download pre-trained vocoder from huggingface if not specified.",
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
parser.add_argument(
|
| 111 |
+
"--test-list",
|
| 112 |
+
type=str,
|
| 113 |
+
default=None,
|
| 114 |
+
help="The list of prompt speech, prompt_transcription, "
|
| 115 |
+
"and text to synthesizein the format of merged conversation: "
|
| 116 |
+
"'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}' "
|
| 117 |
+
"or splited conversation: "
|
| 118 |
+
"'{wav_name}\t{spk1_prompt_transcription}\t{spk2_prompt_transcription}"
|
| 119 |
+
"\t{spk1_prompt_wav}\t{spk2_prompt_wav}\t{text}'.",
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
parser.add_argument(
|
| 123 |
+
"--res-dir",
|
| 124 |
+
type=str,
|
| 125 |
+
default="results",
|
| 126 |
+
help="""
|
| 127 |
+
Path name of the generated wavs dir,
|
| 128 |
+
used when test-list is not None
|
| 129 |
+
""",
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
parser.add_argument(
|
| 133 |
+
"--guidance-scale",
|
| 134 |
+
type=float,
|
| 135 |
+
default=1.5,
|
| 136 |
+
help="The scale of classifier-free guidance during inference.",
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
parser.add_argument(
|
| 140 |
+
"--num-step",
|
| 141 |
+
type=int,
|
| 142 |
+
default=16,
|
| 143 |
+
help="The number of sampling steps.",
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
parser.add_argument(
|
| 147 |
+
"--feat-scale",
|
| 148 |
+
type=float,
|
| 149 |
+
default=0.1,
|
| 150 |
+
help="The scale factor of fbank feature",
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
parser.add_argument(
|
| 154 |
+
"--speed",
|
| 155 |
+
type=float,
|
| 156 |
+
default=1.0,
|
| 157 |
+
help="Control speech speed, 1.0 means normal, >1.0 means speed up",
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
parser.add_argument(
|
| 161 |
+
"--t-shift",
|
| 162 |
+
type=float,
|
| 163 |
+
default=0.5,
|
| 164 |
+
help="Shift t to smaller ones if t_shift < 1.0",
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
parser.add_argument(
|
| 168 |
+
"--target-rms",
|
| 169 |
+
type=float,
|
| 170 |
+
default=0.1,
|
| 171 |
+
help="Target speech normalization rms value, set to 0 to disable normalization",
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
parser.add_argument(
|
| 175 |
+
"--seed",
|
| 176 |
+
type=int,
|
| 177 |
+
default=666,
|
| 178 |
+
help="Random seed",
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
parser.add_argument(
|
| 182 |
+
"--silence-wav",
|
| 183 |
+
type=str,
|
| 184 |
+
default="assets/silence.wav",
|
| 185 |
+
help="Path of the silence wav file, used in two-channel generation "
|
| 186 |
+
"with single-channel prompts",
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
return parser
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def get_vocoder(vocos_local_path: Optional[str] = None):
|
| 193 |
+
if vocos_local_path:
|
| 194 |
+
vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
|
| 195 |
+
state_dict = torch.load(
|
| 196 |
+
f"{vocos_local_path}/pytorch_model.bin",
|
| 197 |
+
weights_only=True,
|
| 198 |
+
map_location="cpu",
|
| 199 |
+
)
|
| 200 |
+
vocoder.load_state_dict(state_dict)
|
| 201 |
+
else:
|
| 202 |
+
vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
| 203 |
+
return vocoder
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def generate_sentence(
|
| 207 |
+
save_path: str,
|
| 208 |
+
prompt_text: str,
|
| 209 |
+
prompt_wav: Union[str, List[str]],
|
| 210 |
+
text: str,
|
| 211 |
+
model: torch.nn.Module,
|
| 212 |
+
vocoder: torch.nn.Module,
|
| 213 |
+
tokenizer: DialogTokenizer,
|
| 214 |
+
feature_extractor: VocosFbank,
|
| 215 |
+
device: torch.device,
|
| 216 |
+
num_step: int = 16,
|
| 217 |
+
guidance_scale: float = 1.0,
|
| 218 |
+
speed: float = 1.0,
|
| 219 |
+
t_shift: float = 0.5,
|
| 220 |
+
target_rms: float = 0.1,
|
| 221 |
+
feat_scale: float = 0.1,
|
| 222 |
+
sampling_rate: int = 24000,
|
| 223 |
+
):
|
| 224 |
+
"""
|
| 225 |
+
Generate waveform of a text based on a given prompt
|
| 226 |
+
waveform and its transcription.
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
save_path (str): Path to save the generated wav.
|
| 230 |
+
prompt_text (str): Transcription of the prompt wav.
|
| 231 |
+
prompt_wav (Union[str, List[str]]): Path to the prompt wav file, can be
|
| 232 |
+
one or two wav files, corresponding to a merged conversational
|
| 233 |
+
recording or to the two speakers' separate recordings.
|
| 234 |
+
text (str): Text to be synthesized into a waveform.
|
| 235 |
+
model (torch.nn.Module): The model used for generation.
|
| 236 |
+
vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
|
| 237 |
+
tokenizer (DialogTokenizer): The tokenizer used to convert text to tokens.
|
| 238 |
+
feature_extractor (VocosFbank): The feature extractor used to
|
| 239 |
+
extract acoustic features.
|
| 240 |
+
device (torch.device): The device on which computations are performed.
|
| 241 |
+
num_step (int, optional): Number of steps for decoding. Defaults to 16.
|
| 242 |
+
guidance_scale (float, optional): Scale for classifier-free guidance.
|
| 243 |
+
Defaults to 1.0.
|
| 244 |
+
speed (float, optional): Speed control. Defaults to 1.0.
|
| 245 |
+
t_shift (float, optional): Time shift. Defaults to 0.5.
|
| 246 |
+
target_rms (float, optional): Target RMS for waveform normalization.
|
| 247 |
+
Defaults to 0.1.
|
| 248 |
+
feat_scale (float, optional): Scale for features.
|
| 249 |
+
Defaults to 0.1.
|
| 250 |
+
sampling_rate (int, optional): Sampling rate for the waveform.
|
| 251 |
+
Defaults to 24000.
|
| 252 |
+
Returns:
|
| 253 |
+
metrics (dict): Dictionary containing time and real-time
|
| 254 |
+
factor metrics for processing.
|
| 255 |
+
"""
|
| 256 |
+
# Convert text to tokens
|
| 257 |
+
tokens = tokenizer.texts_to_token_ids([text])
|
| 258 |
+
prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
|
| 259 |
+
|
| 260 |
+
# Load and preprocess prompt wav
|
| 261 |
+
if isinstance(prompt_wav, str):
|
| 262 |
+
prompt_wav = [
|
| 263 |
+
prompt_wav,
|
| 264 |
+
]
|
| 265 |
+
else:
|
| 266 |
+
assert len(prompt_wav) == 2 and isinstance(prompt_wav[0], str)
|
| 267 |
+
|
| 268 |
+
loaded_prompt_wavs = list(prompt_wav)  # copy so the caller's path list is not mutated
|
| 269 |
+
for i in range(len(prompt_wav)):
|
| 270 |
+
loaded_prompt_wavs[i], prompt_sampling_rate = torchaudio.load(prompt_wav[i])
|
| 271 |
+
if prompt_sampling_rate != sampling_rate:
|
| 272 |
+
resampler = torchaudio.transforms.Resample(
|
| 273 |
+
orig_freq=prompt_sampling_rate, new_freq=sampling_rate
|
| 274 |
+
)
|
| 275 |
+
loaded_prompt_wavs[i] = resampler(loaded_prompt_wavs[i])
|
| 276 |
+
if loaded_prompt_wavs[i].size(0) != 1:
|
| 277 |
+
loaded_prompt_wavs[i] = loaded_prompt_wavs[i].mean(0, keepdim=True)
|
| 278 |
+
|
| 279 |
+
if len(loaded_prompt_wavs) == 1:
|
| 280 |
+
prompt_wav = loaded_prompt_wavs[0]
|
| 281 |
+
else:
|
| 282 |
+
prompt_wav = torch.cat(loaded_prompt_wavs, dim=1)
|
| 283 |
+
|
| 284 |
+
prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
|
| 285 |
+
if prompt_rms < target_rms:
|
| 286 |
+
prompt_wav = prompt_wav * target_rms / prompt_rms
|
| 287 |
+
|
| 288 |
+
# Extract features from prompt wav
|
| 289 |
+
prompt_features = feature_extractor.extract(
|
| 290 |
+
prompt_wav, sampling_rate=sampling_rate
|
| 291 |
+
).to(device)
|
| 292 |
+
|
| 293 |
+
prompt_features = prompt_features.unsqueeze(0) * feat_scale
|
| 294 |
+
prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
|
| 295 |
+
|
| 296 |
+
# Start timing
|
| 297 |
+
start_t = dt.datetime.now()
|
| 298 |
+
|
| 299 |
+
# Generate features
|
| 300 |
+
(
|
| 301 |
+
pred_features,
|
| 302 |
+
pred_features_lens,
|
| 303 |
+
pred_prompt_features,
|
| 304 |
+
pred_prompt_features_lens,
|
| 305 |
+
) = model.sample(
|
| 306 |
+
tokens=tokens,
|
| 307 |
+
prompt_tokens=prompt_tokens,
|
| 308 |
+
prompt_features=prompt_features,
|
| 309 |
+
prompt_features_lens=prompt_features_lens,
|
| 310 |
+
speed=speed,
|
| 311 |
+
t_shift=t_shift,
|
| 312 |
+
duration="predict",
|
| 313 |
+
num_step=num_step,
|
| 314 |
+
guidance_scale=guidance_scale,
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
# Postprocess predicted features
|
| 318 |
+
pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
|
| 319 |
+
|
| 320 |
+
# Start vocoder processing
|
| 321 |
+
start_vocoder_t = dt.datetime.now()
|
| 322 |
+
wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
|
| 323 |
+
|
| 324 |
+
# Calculate processing times and real-time factors
|
| 325 |
+
t = (dt.datetime.now() - start_t).total_seconds()
|
| 326 |
+
t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
|
| 327 |
+
t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
|
| 328 |
+
wav_seconds = wav.shape[-1] / sampling_rate
|
| 329 |
+
rtf = t / wav_seconds
|
| 330 |
+
rtf_no_vocoder = t_no_vocoder / wav_seconds
|
| 331 |
+
rtf_vocoder = t_vocoder / wav_seconds
|
| 332 |
+
metrics = {
|
| 333 |
+
"t": t,
|
| 334 |
+
"t_no_vocoder": t_no_vocoder,
|
| 335 |
+
"t_vocoder": t_vocoder,
|
| 336 |
+
"wav_seconds": wav_seconds,
|
| 337 |
+
"rtf": rtf,
|
| 338 |
+
"rtf_no_vocoder": rtf_no_vocoder,
|
| 339 |
+
"rtf_vocoder": rtf_vocoder,
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
# Adjust wav volume if necessary
|
| 343 |
+
if prompt_rms < target_rms:
|
| 344 |
+
wav = wav * prompt_rms / target_rms
|
| 345 |
+
torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
|
| 346 |
+
|
| 347 |
+
return metrics
|
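In the returned metrics, `rtf` is wall-clock processing time divided by the duration of the generated audio, so an RTF of 0.5 means synthesis runs twice as fast as real time; `rtf_no_vocoder` and `rtf_vocoder` split that between the flow-matching model and the vocoder. A hedged call sketch (paths and texts are placeholders; `model`, `vocoder`, `tokenizer`, `feature_extractor`, and `device` are assumed to be set up as in `main()` below):

# Sketch only: assumes model/vocoder/tokenizer/feature_extractor/device
# were built as in main() below; paths and texts are placeholders.
metrics = generate_sentence(
    save_path="dialogue.wav",
    prompt_text="[S1]Hi there.[S2]Hello!",
    prompt_wav=["spk1_prompt.wav", "spk2_prompt.wav"],  # or one merged wav
    text="[S1]How are you?[S2]Fine, thanks.",
    model=model,
    vocoder=vocoder,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    device=device,
)
print(
    f"RTF {metrics['rtf']:.3f} = model {metrics['rtf_no_vocoder']:.3f} "
    f"+ vocoder {metrics['rtf_vocoder']:.3f}"
)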
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def generate_sentence_stereo(
|
| 351 |
+
save_path: str,
|
| 352 |
+
prompt_text: str,
|
| 353 |
+
prompt_wav: Union[str, List[str]],
|
| 354 |
+
text: str,
|
| 355 |
+
model: torch.nn.Module,
|
| 356 |
+
vocoder: torch.nn.Module,
|
| 357 |
+
tokenizer: DialogTokenizer,
|
| 358 |
+
feature_extractor: VocosFbank,
|
| 359 |
+
device: torch.device,
|
| 360 |
+
num_step: int = 16,
|
| 361 |
+
guidance_scale: float = 1.0,
|
| 362 |
+
speed: float = 1.0,
|
| 363 |
+
t_shift: float = 0.5,
|
| 364 |
+
target_rms: float = 0.1,
|
| 365 |
+
feat_scale: float = 0.1,
|
| 366 |
+
sampling_rate: int = 24000,
|
| 367 |
+
silence_wav: Optional[str] = None,
|
| 368 |
+
):
|
| 369 |
+
"""
|
| 370 |
+
Generate waveform of a text based on a given prompt
|
| 371 |
+
waveform and its transcription.
|
| 372 |
+
|
| 373 |
+
Args:
|
| 374 |
+
save_path (str): Path to save the generated wav.
|
| 375 |
+
prompt_text (str): Transcription of the prompt wav.
|
| 376 |
+
prompt_wav (Union[str, List[str]]): Path to the prompt wav file, can be
|
| 377 |
+
one or two wav files, which corresponding to a merged conversational
|
| 378 |
+
speech or two seperate speaker's speech.
|
| 379 |
+
text (str): Text to be synthesized into a waveform.
|
| 380 |
+
model (torch.nn.Module): The model used for generation.
|
| 381 |
+
vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
|
| 382 |
+
tokenizer (DialogTokenizer): The tokenizer used to convert text to tokens.
|
| 383 |
+
feature_extractor (VocosFbank): The feature extractor used to
|
| 384 |
+
extract acoustic features.
|
| 385 |
+
device (torch.device): The device on which computations are performed.
|
| 386 |
+
num_step (int, optional): Number of steps for decoding. Defaults to 16.
|
| 387 |
+
guidance_scale (float, optional): Scale for classifier-free guidance.
|
| 388 |
+
Defaults to 1.0.
|
| 389 |
+
speed (float, optional): Speed control. Defaults to 1.0.
|
| 390 |
+
t_shift (float, optional): Time shift. Defaults to 0.5.
|
| 391 |
+
target_rms (float, optional): Target RMS for waveform normalization.
|
| 392 |
+
Defaults to 0.1.
|
| 393 |
+
feat_scale (float, optional): Scale for features.
|
| 394 |
+
Defaults to 0.1.
|
| 395 |
+
sampling_rate (int, optional): Sampling rate for the waveform.
|
| 396 |
+
Defaults to 24000.
|
| 397 |
+
silence_wav (str): Path of the silence wav file, used in two-channel
|
| 398 |
+
generation with single-channel prompts
|
| 399 |
+
Returns:
|
| 400 |
+
metrics (dict): Dictionary containing time and real-time
|
| 401 |
+
factor metrics for processing.
|
| 402 |
+
"""
|
| 403 |
+
# Convert text to tokens
|
| 404 |
+
tokens = tokenizer.texts_to_token_ids([text])
|
| 405 |
+
prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
|
| 406 |
+
|
| 407 |
+
# Load and preprocess prompt wav
|
| 408 |
+
if isinstance(prompt_wav, str):
|
| 409 |
+
prompt_wav = [
|
| 410 |
+
prompt_wav,
|
| 411 |
+
]
|
| 412 |
+
else:
|
| 413 |
+
assert len(prompt_wav) == 2 and isinstance(prompt_wav[0], str)
|
| 414 |
+
|
| 415 |
+
loaded_prompt_wavs = prompt_wav
|
| 416 |
+
for i in range(len(prompt_wav)):
|
| 417 |
+
loaded_prompt_wavs[i], prompt_sampling_rate = torchaudio.load(prompt_wav[i])
|
| 418 |
+
if prompt_sampling_rate != sampling_rate:
|
| 419 |
+
resampler = torchaudio.transforms.Resample(
|
| 420 |
+
orig_freq=prompt_sampling_rate, new_freq=sampling_rate
|
| 421 |
+
)
|
| 422 |
+
loaded_prompt_wavs[i] = resampler(loaded_prompt_wavs[i])
|
| 423 |
+
|
| 424 |
+
if len(loaded_prompt_wavs) == 1:
|
| 425 |
+
assert (
|
| 426 |
+
loaded_prompt_wavs[0].size(0) == 2
|
| 427 |
+
), "Merged prompt wav must be stereo for stereo dialogue generation"
|
| 428 |
+
prompt_wav = loaded_prompt_wavs[0]
|
| 429 |
+
|
| 430 |
+
else:
|
| 431 |
+
assert len(loaded_prompt_wavs) == 2
|
| 432 |
+
if loaded_prompt_wavs[0].size(0) == 2:
|
| 433 |
+
prompt_wav = torch.cat(loaded_prompt_wavs, dim=1)
|
| 434 |
+
else:
|
| 435 |
+
assert loaded_prompt_wavs[0].size(0) == 1
|
| 436 |
+
silence_wav, silence_sampling_rate = torchaudio.load(silence_wav)
|
| 437 |
+
assert silence_sampling_rate == sampling_rate
|
| 438 |
+
prompt_wav = silence_wav[
|
| 439 |
+
:, : loaded_prompt_wavs[0].size(1) + loaded_prompt_wavs[1].size(1)
|
| 440 |
+
]
|
| 441 |
+
prompt_wav[0, : loaded_prompt_wavs[0].size(1)] = loaded_prompt_wavs[0]
|
| 442 |
+
prompt_wav[1, loaded_prompt_wavs[0].size(1) :] = loaded_prompt_wavs[1]
|
| 443 |
+
|
| 444 |
+
prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
|
| 445 |
+
if prompt_rms < target_rms:
|
| 446 |
+
prompt_wav = prompt_wav * target_rms / prompt_rms
|
| 447 |
+
|
| 448 |
+
# Extract features from prompt wav
|
| 449 |
+
prompt_features = feature_extractor.extract(
|
| 450 |
+
prompt_wav, sampling_rate=sampling_rate
|
| 451 |
+
).to(device)
|
| 452 |
+
|
| 453 |
+
prompt_features = prompt_features.unsqueeze(0) * feat_scale
|
| 454 |
+
prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
|
| 455 |
+
|
| 456 |
+
# Start timing
|
| 457 |
+
start_t = dt.datetime.now()
|
| 458 |
+
|
| 459 |
+
# Generate features
|
| 460 |
+
(
|
| 461 |
+
pred_features,
|
| 462 |
+
pred_features_lens,
|
| 463 |
+
pred_prompt_features,
|
| 464 |
+
pred_prompt_features_lens,
|
| 465 |
+
) = model.sample(
|
| 466 |
+
tokens=tokens,
|
| 467 |
+
prompt_tokens=prompt_tokens,
|
| 468 |
+
prompt_features=prompt_features,
|
| 469 |
+
prompt_features_lens=prompt_features_lens,
|
| 470 |
+
speed=speed,
|
| 471 |
+
t_shift=t_shift,
|
| 472 |
+
duration="predict",
|
| 473 |
+
num_step=num_step,
|
| 474 |
+
guidance_scale=guidance_scale,
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
# Postprocess predicted features
|
| 478 |
+
pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
|
| 479 |
+
|
| 480 |
+
# Start vocoder processing
|
| 481 |
+
start_vocoder_t = dt.datetime.now()
|
| 482 |
+
feat_dim = pred_features.size(1) // 2
|
| 483 |
+
wav_left = vocoder.decode(pred_features[:, :feat_dim]).squeeze(1).clamp(-1, 1)
|
| 484 |
+
wav_right = (
|
| 485 |
+
vocoder.decode(pred_features[:, feat_dim : feat_dim * 2])
|
| 486 |
+
.squeeze(1)
|
| 487 |
+
.clamp(-1, 1)
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
wav = torch.cat([wav_left, wav_right], dim=0)
|
| 491 |
+
|
| 492 |
+
# Calculate processing times and real-time factors
|
| 493 |
+
t = (dt.datetime.now() - start_t).total_seconds()
|
| 494 |
+
t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
|
| 495 |
+
t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
|
| 496 |
+
wav_seconds = wav.shape[-1] / sampling_rate
|
| 497 |
+
rtf = t / wav_seconds
|
| 498 |
+
rtf_no_vocoder = t_no_vocoder / wav_seconds
|
| 499 |
+
rtf_vocoder = t_vocoder / wav_seconds
|
| 500 |
+
metrics = {
|
| 501 |
+
"t": t,
|
| 502 |
+
"t_no_vocoder": t_no_vocoder,
|
| 503 |
+
"t_vocoder": t_vocoder,
|
| 504 |
+
"wav_seconds": wav_seconds,
|
| 505 |
+
"rtf": rtf,
|
| 506 |
+
"rtf_no_vocoder": rtf_no_vocoder,
|
| 507 |
+
"rtf_vocoder": rtf_vocoder,
|
| 508 |
+
}
|
| 509 |
+
|
| 510 |
+
# Adjust wav volume if necessary
|
| 511 |
+
if prompt_rms < target_rms:
|
| 512 |
+
wav = wav * prompt_rms / target_rms
|
| 513 |
+
torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
|
| 514 |
+
|
| 515 |
+
return metrics
|
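For two mono prompts, the branch above lays speaker 1 at the start of the left channel and speaker 2 immediately after it on the right channel, on top of a recorded silence bed (a real recording presumably matches the vocoder's notion of silence better than digital zeros). The channel geometry alone, sketched with a zero tensor and illustrative lengths:

import torch

spk1 = torch.randn(1, 24000)  # 1.0 s mono prompt for speaker 1 (placeholder)
spk2 = torch.randn(1, 36000)  # 1.5 s mono prompt for speaker 2 (placeholder)

total = spk1.size(1) + spk2.size(1)
stereo = torch.zeros(2, total)        # stands in for the silence recording
stereo[0, : spk1.size(1)] = spk1[0]   # left channel: speaker 1 first
stereo[1, spk1.size(1) :] = spk2[0]   # right channel: speaker 2 afterwards
assert stereo.shape == (2, total)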
| 516 |
+
|
| 517 |
+
|
| 518 |
+
def generate_list(
|
| 519 |
+
model_name: str,
|
| 520 |
+
res_dir: str,
|
| 521 |
+
test_list: str,
|
| 522 |
+
model: torch.nn.Module,
|
| 523 |
+
vocoder: torch.nn.Module,
|
| 524 |
+
tokenizer: DialogTokenizer,
|
| 525 |
+
feature_extractor: VocosFbank,
|
| 526 |
+
device: torch.device,
|
| 527 |
+
num_step: int = 16,
|
| 528 |
+
guidance_scale: float = 1.5,
|
| 529 |
+
speed: float = 1.0,
|
| 530 |
+
t_shift: float = 0.5,
|
| 531 |
+
target_rms: float = 0.1,
|
| 532 |
+
feat_scale: float = 0.1,
|
| 533 |
+
sampling_rate: int = 24000,
|
| 534 |
+
silence_wav: Optional[str] = None,
|
| 535 |
+
):
|
| 536 |
+
total_t = []
|
| 537 |
+
total_t_no_vocoder = []
|
| 538 |
+
total_t_vocoder = []
|
| 539 |
+
total_wav_seconds = []
|
| 540 |
+
|
| 541 |
+
with open(test_list, "r") as fr:
|
| 542 |
+
lines = fr.readlines()
|
| 543 |
+
|
| 544 |
+
for i, line in enumerate(lines):
|
| 545 |
+
items = line.strip().split("\t")
|
| 546 |
+
if len(items) == 6:
|
| 547 |
+
(
|
| 548 |
+
wav_name,
|
| 549 |
+
prompt_text_1,
|
| 550 |
+
prompt_text_2,
|
| 551 |
+
prompt_wav_1,
|
| 552 |
+
prompt_wav_2,
|
| 553 |
+
text,
|
| 554 |
+
) = items
|
| 555 |
+
prompt_text = f"[S1]{prompt_text_1}[S2]{prompt_text_2}"
|
| 556 |
+
prompt_wav = [prompt_wav_1, prompt_wav_2]
|
| 557 |
+
elif len(items) == 4:
|
| 558 |
+
wav_name, prompt_text, prompt_wav, text = items
|
| 559 |
+
else:
|
| 560 |
+
raise ValueError(f"Invalid line: {line}")
|
| 561 |
+
assert text.startswith("[S1]")
|
| 562 |
+
|
| 563 |
+
save_path = f"{res_dir}/{wav_name}.wav"
|
| 564 |
+
|
| 565 |
+
if model_name == "zipvoice_dialog":
|
| 566 |
+
|
| 567 |
+
metrics = generate_sentence(
|
| 568 |
+
save_path=save_path,
|
| 569 |
+
prompt_text=prompt_text,
|
| 570 |
+
prompt_wav=prompt_wav,
|
| 571 |
+
text=text,
|
| 572 |
+
model=model,
|
| 573 |
+
vocoder=vocoder,
|
| 574 |
+
tokenizer=tokenizer,
|
| 575 |
+
feature_extractor=feature_extractor,
|
| 576 |
+
device=device,
|
| 577 |
+
num_step=num_step,
|
| 578 |
+
guidance_scale=guidance_scale,
|
| 579 |
+
speed=speed,
|
| 580 |
+
t_shift=t_shift,
|
| 581 |
+
target_rms=target_rms,
|
| 582 |
+
feat_scale=feat_scale,
|
| 583 |
+
sampling_rate=sampling_rate,
|
| 584 |
+
)
|
| 585 |
+
else:
|
| 586 |
+
assert model_name == "zipvoice_dialog_stereo"
|
| 587 |
+
metrics = generate_sentence_stereo(
|
| 588 |
+
save_path=save_path,
|
| 589 |
+
prompt_text=prompt_text,
|
| 590 |
+
prompt_wav=prompt_wav,
|
| 591 |
+
text=text,
|
| 592 |
+
model=model,
|
| 593 |
+
vocoder=vocoder,
|
| 594 |
+
tokenizer=tokenizer,
|
| 595 |
+
feature_extractor=feature_extractor,
|
| 596 |
+
device=device,
|
| 597 |
+
num_step=num_step,
|
| 598 |
+
guidance_scale=guidance_scale,
|
| 599 |
+
speed=speed,
|
| 600 |
+
t_shift=t_shift,
|
| 601 |
+
target_rms=target_rms,
|
| 602 |
+
feat_scale=feat_scale,
|
| 603 |
+
sampling_rate=sampling_rate,
|
| 604 |
+
silence_wav=silence_wav,
|
| 605 |
+
)
|
| 606 |
+
|
| 607 |
+
logging.info(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
|
| 608 |
+
total_t.append(metrics["t"])
|
| 609 |
+
total_t_no_vocoder.append(metrics["t_no_vocoder"])
|
| 610 |
+
total_t_vocoder.append(metrics["t_vocoder"])
|
| 611 |
+
total_wav_seconds.append(metrics["wav_seconds"])
|
| 612 |
+
|
| 613 |
+
logging.info(f"Average RTF: {np.sum(total_t) / np.sum(total_wav_seconds):.4f}")
|
| 614 |
+
logging.info(
|
| 615 |
+
f"Average RTF w/o vocoder: "
|
| 616 |
+
f"{np.sum(total_t_no_vocoder) / np.sum(total_wav_seconds):.4f}"
|
| 617 |
+
)
|
| 618 |
+
logging.info(
|
| 619 |
+
f"Average RTF vocoder: "
|
| 620 |
+
f"{np.sum(total_t_vocoder) / np.sum(total_wav_seconds):.4f}"
|
| 621 |
+
)
|
| 622 |
+
|
| 623 |
+
|
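Whichever row layout is used, the synthesis text itself is one string of `[S1]`/`[S2]`-tagged turns and, as asserted above, must open with `[S1]`. A tiny, hypothetical helper for composing such strings:

def tag_turns(turns):
    """turns: list of (speaker, text) pairs with speaker in {1, 2}."""
    assert turns and turns[0][0] == 1, "dialogue text must start with [S1]"
    return "".join(f"[S{spk}]{txt}" for spk, txt in turns)

text = tag_turns([(1, "How are you?"), (2, "Fine, thanks."), (1, "Great.")])
# -> "[S1]How are you?[S2]Fine, thanks.[S1]Great."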
| 624 |
+
@torch.inference_mode()
|
| 625 |
+
def main():
|
| 626 |
+
parser = get_parser()
|
| 627 |
+
args = parser.parse_args()
|
| 628 |
+
|
| 629 |
+
params = AttributeDict()
|
| 630 |
+
params.update(vars(args))
|
| 631 |
+
fix_random_seed(params.seed)
|
| 632 |
+
|
| 633 |
+
assert (
|
| 634 |
+
params.test_list is not None
|
| 635 |
+
), "For inference, please provide prompts and text with '--test-list'"
|
| 636 |
+
|
| 637 |
+
if params.model_dir is not None:
|
| 638 |
+
params.model_dir = Path(params.model_dir)
|
| 639 |
+
if not params.model_dir.is_dir():
|
| 640 |
+
raise FileNotFoundError(f"{params.model_dir} does not exist")
|
| 641 |
+
for filename in [params.checkpoint_name, "model.json", "tokens.txt"]:
|
| 642 |
+
if not (params.model_dir / filename).is_file():
|
| 643 |
+
raise FileNotFoundError(f"{params.model_dir / filename} does not exist")
|
| 644 |
+
model_ckpt = params.model_dir / params.checkpoint_name
|
| 645 |
+
model_config = params.model_dir / "model.json"
|
| 646 |
+
token_file = params.model_dir / "tokens.txt"
|
| 647 |
+
logging.info(
|
| 648 |
+
f"Using local model dir {params.model_dir}, "
|
| 649 |
+
f"checkpoint {params.checkpoint_name}"
|
| 650 |
+
)
|
| 651 |
+
else:
|
| 652 |
+
logging.info("Using pretrained model from the huggingface")
|
| 653 |
+
logging.info("Downloading the requires files from HuggingFace")
|
| 654 |
+
model_ckpt = hf_hub_download(
|
| 655 |
+
HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/model.pt"
|
| 656 |
+
)
|
| 657 |
+
model_config = hf_hub_download(
|
| 658 |
+
HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/model.json"
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
+
token_file = hf_hub_download(
|
| 662 |
+
HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/tokens.txt"
|
| 663 |
+
)
|
| 664 |
+
|
| 665 |
+
logging.info("Loading model...")
|
| 666 |
+
|
| 667 |
+
tokenizer = DialogTokenizer(token_file=token_file)
|
| 668 |
+
|
| 669 |
+
tokenizer_config = {
|
| 670 |
+
"vocab_size": tokenizer.vocab_size,
|
| 671 |
+
"pad_id": tokenizer.pad_id,
|
| 672 |
+
"spk_a_id": tokenizer.spk_a_id,
|
| 673 |
+
"spk_b_id": tokenizer.spk_b_id,
|
| 674 |
+
}
|
| 675 |
+
|
| 676 |
+
with open(model_config, "r") as f:
|
| 677 |
+
model_config = json.load(f)
|
| 678 |
+
|
| 679 |
+
if params.model_name == "zipvoice_dialog":
|
| 680 |
+
model = ZipVoiceDialog(
|
| 681 |
+
**model_config["model"],
|
| 682 |
+
**tokenizer_config,
|
| 683 |
+
)
|
| 684 |
+
else:
|
| 685 |
+
assert params.model_name == "zipvoice_dialog_stereo"
|
| 686 |
+
model = ZipVoiceDialogStereo(
|
| 687 |
+
**model_config["model"],
|
| 688 |
+
**tokenizer_config,
|
| 689 |
+
)
|
| 690 |
+
|
| 691 |
+
if str(model_ckpt).endswith(".safetensors"):
|
| 692 |
+
safetensors.torch.load_model(model, model_ckpt)
|
| 693 |
+
elif str(model_ckpt).endswith(".pt"):
|
| 694 |
+
load_checkpoint(filename=model_ckpt, model=model, strict=True)
|
| 695 |
+
else:
|
| 696 |
+
raise NotImplementedError(f"Unsupported model checkpoint format: {model_ckpt}")
|
| 697 |
+
|
| 698 |
+
if torch.cuda.is_available():
|
| 699 |
+
params.device = torch.device("cuda", 0)
|
| 700 |
+
elif torch.backends.mps.is_available():
|
| 701 |
+
params.device = torch.device("mps")
|
| 702 |
+
else:
|
| 703 |
+
params.device = torch.device("cpu")
|
| 704 |
+
logging.info(f"Device: {params.device}")
|
| 705 |
+
|
| 706 |
+
model = model.to(params.device)
|
| 707 |
+
model.eval()
|
| 708 |
+
|
| 709 |
+
vocoder = get_vocoder(params.vocoder_path)
|
| 710 |
+
vocoder = vocoder.to(params.device)
|
| 711 |
+
vocoder.eval()
|
| 712 |
+
|
| 713 |
+
if model_config["feature"]["type"] == "vocos":
|
| 714 |
+
if params.model_name == "zipvoice_dialog":
|
| 715 |
+
num_channels = 1
|
| 716 |
+
else:
|
| 717 |
+
assert params.model_name == "zipvoice_dialog_stereo"
|
| 718 |
+
num_channels = 2
|
| 719 |
+
feature_extractor = VocosFbank(num_channels=num_channels)
|
| 720 |
+
else:
|
| 721 |
+
raise NotImplementedError(
|
| 722 |
+
f"Unsupported feature type: {model_config['feature']['type']}"
|
| 723 |
+
)
|
| 724 |
+
params.sampling_rate = model_config["feature"]["sampling_rate"]
|
| 725 |
+
|
| 726 |
+
logging.info("Start generating...")
|
| 727 |
+
os.makedirs(params.res_dir, exist_ok=True)
|
| 728 |
+
generate_list(
|
| 729 |
+
model_name=params.model_name,
|
| 730 |
+
res_dir=params.res_dir,
|
| 731 |
+
test_list=params.test_list,
|
| 732 |
+
model=model,
|
| 733 |
+
vocoder=vocoder,
|
| 734 |
+
tokenizer=tokenizer,
|
| 735 |
+
feature_extractor=feature_extractor,
|
| 736 |
+
device=params.device,
|
| 737 |
+
num_step=params.num_step,
|
| 738 |
+
guidance_scale=params.guidance_scale,
|
| 739 |
+
speed=params.speed,
|
| 740 |
+
t_shift=params.t_shift,
|
| 741 |
+
target_rms=params.target_rms,
|
| 742 |
+
feat_scale=params.feat_scale,
|
| 743 |
+
sampling_rate=params.sampling_rate,
|
| 744 |
+
silence_wav=params.silence_wav,
|
| 745 |
+
)
|
| 746 |
+
logging.info("Done")
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
if __name__ == "__main__":
|
| 750 |
+
torch.set_num_threads(1)
|
| 751 |
+
torch.set_num_interop_threads(1)
|
| 752 |
+
|
| 753 |
+
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
| 754 |
+
logging.basicConfig(format=formatter, level=logging.INFO, force=True)
|
| 755 |
+
|
| 756 |
+
main()
|
zipvoice/bin/infer_zipvoice_onnx.py
ADDED
|
@@ -0,0 +1,712 @@
|
| 1 |
+
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu,
|
| 2 |
+
# Zengwei Yao)
|
| 3 |
+
#
|
| 4 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""
|
| 19 |
+
This script generates speech with our pre-trained ZipVoice or ZipVoice-Distill
|
| 20 |
+
ONNX models. If no local model is specified,
|
| 21 |
+
required files will be automatically downloaded from HuggingFace.
|
| 22 |
+
|
| 23 |
+
Usage:
|
| 24 |
+
|
| 25 |
+
Note: If you have trouble connecting to HuggingFace,
|
| 26 |
+
try switching endpoint to mirror site:
|
| 27 |
+
export HF_ENDPOINT=https://hf-mirror.com
|
| 28 |
+
|
| 29 |
+
(1) Inference of a single sentence:
|
| 30 |
+
|
| 31 |
+
python3 -m zipvoice.bin.infer_zipvoice_onnx \
|
| 32 |
+
--onnx-int8 False \
|
| 33 |
+
--model-name zipvoice \
|
| 34 |
+
--prompt-wav prompt.wav \
|
| 35 |
+
--prompt-text "I am a prompt." \
|
| 36 |
+
--text "I am a sentence." \
|
| 37 |
+
--res-wav-path result.wav
|
| 38 |
+
|
| 39 |
+
(2) Inference of a list of sentences:
|
| 40 |
+
python3 -m zipvoice.bin.infer_zipvoice_onnx \
|
| 41 |
+
--onnx-int8 False \
|
| 42 |
+
--model-name zipvoice \
|
| 43 |
+
--test-list test.tsv \
|
| 44 |
+
--res-dir results
|
| 45 |
+
|
| 46 |
+
`--model-name` can be `zipvoice` or `zipvoice_distill`,
|
| 47 |
+
which are the models before and after distillation, respectively.
|
| 48 |
+
|
| 49 |
+
Each line of `test.tsv` is in the format of
|
| 50 |
+
`{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`.
|
| 51 |
+
|
| 52 |
+
Set `--onnx-int8 True` to use the int8-quantized ONNX model.
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
import argparse
|
| 56 |
+
import datetime as dt
|
| 57 |
+
import json
|
| 58 |
+
import logging
|
| 59 |
+
import os
|
| 60 |
+
from pathlib import Path
|
| 61 |
+
from typing import List, Tuple
|
| 62 |
+
|
| 63 |
+
import numpy as np
|
| 64 |
+
import onnxruntime as ort
|
| 65 |
+
import torch
|
| 66 |
+
import torchaudio
|
| 67 |
+
from huggingface_hub import hf_hub_download
|
| 68 |
+
from lhotse.utils import fix_random_seed
|
| 69 |
+
from torch import Tensor, nn
|
| 70 |
+
|
| 71 |
+
from zipvoice.bin.infer_zipvoice import get_vocoder
|
| 72 |
+
from zipvoice.models.modules.solver import get_time_steps
|
| 73 |
+
from zipvoice.tokenizer.tokenizer import (
|
| 74 |
+
EmiliaTokenizer,
|
| 75 |
+
EspeakTokenizer,
|
| 76 |
+
LibriTTSTokenizer,
|
| 77 |
+
SimpleTokenizer,
|
| 78 |
+
)
|
| 79 |
+
from zipvoice.utils.common import AttributeDict, str2bool
|
| 80 |
+
from zipvoice.utils.feature import VocosFbank
|
| 81 |
+
|
| 82 |
+
HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
|
| 83 |
+
MODEL_DIR = {
|
| 84 |
+
"zipvoice": "zipvoice",
|
| 85 |
+
"zipvoice_distill": "zipvoice_distill",
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def get_parser():
|
| 90 |
+
parser = argparse.ArgumentParser(
|
| 91 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
parser.add_argument(
|
| 95 |
+
"--onnx-int8",
|
| 96 |
+
type=str2bool,
|
| 97 |
+
default=False,
|
| 98 |
+
help="Whether to use the int8 model",
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
parser.add_argument(
|
| 102 |
+
"--model-name",
|
| 103 |
+
type=str,
|
| 104 |
+
default="zipvoice",
|
| 105 |
+
choices=["zipvoice", "zipvoice_distill"],
|
| 106 |
+
help="The model used for inference",
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
parser.add_argument(
|
| 110 |
+
"--model-dir",
|
| 111 |
+
type=str,
|
| 112 |
+
default=None,
|
| 113 |
+
help="The path to the local onnx model. "
|
| 114 |
+
"Will download pre-trained checkpoint from huggingface if not specified.",
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
parser.add_argument(
|
| 118 |
+
"--vocoder-path",
|
| 119 |
+
type=str,
|
| 120 |
+
default=None,
|
| 121 |
+
help="The vocoder checkpoint. "
|
| 122 |
+
"Will download pre-trained vocoder from huggingface if not specified.",
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
parser.add_argument(
|
| 126 |
+
"--tokenizer",
|
| 127 |
+
type=str,
|
| 128 |
+
default="emilia",
|
| 129 |
+
choices=["emilia", "libritts", "espeak", "simple"],
|
| 130 |
+
help="Tokenizer type.",
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
parser.add_argument(
|
| 134 |
+
"--lang",
|
| 135 |
+
type=str,
|
| 136 |
+
default="en-us",
|
| 137 |
+
help="Language identifier, used when tokenizer type is espeak. see"
|
| 138 |
+
"https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
parser.add_argument(
|
| 142 |
+
"--test-list",
|
| 143 |
+
type=str,
|
| 144 |
+
default=None,
|
| 145 |
+
help="The list of prompt speech, prompt_transcription, "
|
| 146 |
+
"and text to synthesizein the format of "
|
| 147 |
+
"'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.",
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
parser.add_argument(
|
| 151 |
+
"--prompt-wav",
|
| 152 |
+
type=str,
|
| 153 |
+
default=None,
|
| 154 |
+
help="The prompt wav to mimic",
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
parser.add_argument(
|
| 158 |
+
"--prompt-text",
|
| 159 |
+
type=str,
|
| 160 |
+
default=None,
|
| 161 |
+
help="The transcription of the prompt wav",
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
parser.add_argument(
|
| 165 |
+
"--text",
|
| 166 |
+
type=str,
|
| 167 |
+
default=None,
|
| 168 |
+
help="The text to synthesize",
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
parser.add_argument(
|
| 172 |
+
"--res-dir",
|
| 173 |
+
type=str,
|
| 174 |
+
default="results",
|
| 175 |
+
help="""
|
| 176 |
+
Path name of the generated wavs dir,
|
| 177 |
+
used when test-list is not None
|
| 178 |
+
""",
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
parser.add_argument(
|
| 182 |
+
"--res-wav-path",
|
| 183 |
+
type=str,
|
| 184 |
+
default="result.wav",
|
| 185 |
+
help="""
|
| 186 |
+
Path name of the generated wav path,
|
| 187 |
+
used when test-list is None
|
| 188 |
+
""",
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
parser.add_argument(
|
| 192 |
+
"--guidance-scale",
|
| 193 |
+
type=float,
|
| 194 |
+
default=None,
|
| 195 |
+
help="The scale of classifier-free guidance during inference.",
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
parser.add_argument(
|
| 199 |
+
"--num-step",
|
| 200 |
+
type=int,
|
| 201 |
+
default=None,
|
| 202 |
+
help="The number of sampling steps.",
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
parser.add_argument(
|
| 206 |
+
"--feat-scale",
|
| 207 |
+
type=float,
|
| 208 |
+
default=0.1,
|
| 209 |
+
help="The scale factor of fbank feature",
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
parser.add_argument(
|
| 213 |
+
"--speed",
|
| 214 |
+
type=float,
|
| 215 |
+
default=1.0,
|
| 216 |
+
help="Control speech speed, 1.0 means normal, >1.0 means speed up",
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
parser.add_argument(
|
| 220 |
+
"--t-shift",
|
| 221 |
+
type=float,
|
| 222 |
+
default=0.5,
|
| 223 |
+
help="Shift t to smaller ones if t_shift < 1.0",
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
parser.add_argument(
|
| 227 |
+
"--target-rms",
|
| 228 |
+
type=float,
|
| 229 |
+
default=0.1,
|
| 230 |
+
help="Target speech normalization rms value, set to 0 to disable normalization",
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
parser.add_argument(
|
| 234 |
+
"--seed",
|
| 235 |
+
type=int,
|
| 236 |
+
default=666,
|
| 237 |
+
help="Random seed",
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
return parser
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class OnnxModel:
|
| 244 |
+
def __init__(
|
| 245 |
+
self,
|
| 246 |
+
text_encoder_path: str,
|
| 247 |
+
fm_decoder_path: str,
|
| 248 |
+
):
|
| 249 |
+
session_opts = ort.SessionOptions()
|
| 250 |
+
session_opts.inter_op_num_threads = 1
|
| 251 |
+
session_opts.intra_op_num_threads = 1
|
| 252 |
+
|
| 253 |
+
self.session_opts = session_opts
|
| 254 |
+
|
| 255 |
+
self.init_text_encoder(text_encoder_path)
|
| 256 |
+
self.init_fm_decoder(fm_decoder_path)
|
| 257 |
+
|
| 258 |
+
def init_text_encoder(self, model_path: str):
|
| 259 |
+
self.text_encoder = ort.InferenceSession(
|
| 260 |
+
model_path,
|
| 261 |
+
sess_options=self.session_opts,
|
| 262 |
+
providers=["CPUExecutionProvider"],
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
def init_fm_decoder(self, model_path: str):
|
| 266 |
+
self.fm_decoder = ort.InferenceSession(
|
| 267 |
+
model_path,
|
| 268 |
+
sess_options=self.session_opts,
|
| 269 |
+
providers=["CPUExecutionProvider"],
|
| 270 |
+
)
|
| 271 |
+
meta = self.fm_decoder.get_modelmeta().custom_metadata_map
|
| 272 |
+
self.feat_dim = int(meta["feat_dim"])
|
| 273 |
+
|
| 274 |
+
def run_text_encoder(
|
| 275 |
+
self,
|
| 276 |
+
tokens: Tensor,
|
| 277 |
+
prompt_tokens: Tensor,
|
| 278 |
+
prompt_features_len: Tensor,
|
| 279 |
+
speed: Tensor,
|
| 280 |
+
) -> Tensor:
|
| 281 |
+
out = self.text_encoder.run(
|
| 282 |
+
[
|
| 283 |
+
self.text_encoder.get_outputs()[0].name,
|
| 284 |
+
],
|
| 285 |
+
{
|
| 286 |
+
self.text_encoder.get_inputs()[0].name: tokens.numpy(),
|
| 287 |
+
self.text_encoder.get_inputs()[1].name: prompt_tokens.numpy(),
|
| 288 |
+
self.text_encoder.get_inputs()[2].name: prompt_features_len.numpy(),
|
| 289 |
+
self.text_encoder.get_inputs()[3].name: speed.numpy(),
|
| 290 |
+
},
|
| 291 |
+
)
|
| 292 |
+
return torch.from_numpy(out[0])
|
| 293 |
+
|
| 294 |
+
def run_fm_decoder(
|
| 295 |
+
self,
|
| 296 |
+
t: Tensor,
|
| 297 |
+
x: Tensor,
|
| 298 |
+
text_condition: Tensor,
|
| 299 |
+
speech_condition: torch.Tensor,
|
| 300 |
+
guidance_scale: Tensor,
|
| 301 |
+
) -> Tensor:
|
| 302 |
+
out = self.fm_decoder.run(
|
| 303 |
+
[
|
| 304 |
+
self.fm_decoder.get_outputs()[0].name,
|
| 305 |
+
],
|
| 306 |
+
{
|
| 307 |
+
self.fm_decoder.get_inputs()[0].name: t.numpy(),
|
| 308 |
+
self.fm_decoder.get_inputs()[1].name: x.numpy(),
|
| 309 |
+
self.fm_decoder.get_inputs()[2].name: text_condition.numpy(),
|
| 310 |
+
self.fm_decoder.get_inputs()[3].name: speech_condition.numpy(),
|
| 311 |
+
self.fm_decoder.get_inputs()[4].name: guidance_scale.numpy(),
|
| 312 |
+
},
|
| 313 |
+
)
|
| 314 |
+
return torch.from_numpy(out[0])
|
| 315 |
+
|
| 316 |
+
|
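A short construction sketch for the wrapper above; the two .onnx file names are hypothetical placeholders for exported text-encoder and flow-matching-decoder models:

# Build the two ONNX sessions and inspect their I/O signatures.
model = OnnxModel(
    text_encoder_path="text_encoder.onnx",  # hypothetical exported file
    fm_decoder_path="fm_decoder.onnx",      # hypothetical exported file
)
print("feat_dim from decoder metadata:", model.feat_dim)
for name, sess in [("text_encoder", model.text_encoder),
                   ("fm_decoder", model.fm_decoder)]:
    ins = [i.name for i in sess.get_inputs()]
    outs = [o.name for o in sess.get_outputs()]
    print(f"{name}: {ins} -> {outs}")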
| 317 |
+
def sample(
|
| 318 |
+
model: OnnxModel,
|
| 319 |
+
tokens: List[List[int]],
|
| 320 |
+
prompt_tokens: List[List[int]],
|
| 321 |
+
prompt_features: Tensor,
|
| 322 |
+
speed: float = 1.0,
|
| 323 |
+
t_shift: float = 0.5,
|
| 324 |
+
guidance_scale: float = 1.0,
|
| 325 |
+
num_step: int = 16,
|
| 326 |
+
) -> torch.Tensor:
|
| 327 |
+
"""
|
| 328 |
+
Generate acoustic features, given text tokens, prompt features, and the prompt
|
| 329 |
+
transcription's text tokens.
|
| 330 |
+
|
| 331 |
+
Args:
|
| 332 |
+
tokens: a list of list of text tokens.
|
| 333 |
+
prompt_tokens: a list of list of prompt tokens.
|
| 334 |
+
prompt_features: the prompt feature with the shape
|
| 335 |
+
(batch_size, seq_len, feat_dim).
|
| 336 |
+
speed: speed control.
|
| 337 |
+
t_shift: time shift.
|
| 338 |
+
guidance_scale: the guidance scale for classifier-free guidance.
|
| 339 |
+
num_step: the number of steps to use in the ODE solver.
|
| 340 |
+
"""
|
| 341 |
+
# Run text encoder
|
| 342 |
+
assert len(tokens) == len(prompt_tokens) == 1
|
| 343 |
+
tokens = torch.tensor(tokens, dtype=torch.int64)
|
| 344 |
+
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.int64)
|
| 345 |
+
prompt_features_len = torch.tensor(prompt_features.size(1), dtype=torch.int64)
|
| 346 |
+
speed = torch.tensor(speed, dtype=torch.float32)
|
| 347 |
+
|
| 348 |
+
text_condition = model.run_text_encoder(
|
| 349 |
+
tokens, prompt_tokens, prompt_features_len, speed
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
batch_size, num_frames, _ = text_condition.shape
|
| 353 |
+
assert batch_size == 1
|
| 354 |
+
feat_dim = model.feat_dim
|
| 355 |
+
|
| 356 |
+
# Run flow matching model
|
| 357 |
+
timesteps = get_time_steps(
|
| 358 |
+
t_start=0.0,
|
| 359 |
+
t_end=1.0,
|
| 360 |
+
num_step=num_step,
|
| 361 |
+
t_shift=t_shift,
|
| 362 |
+
)
|
| 363 |
+
x = torch.randn(batch_size, num_frames, feat_dim)
|
| 364 |
+
speech_condition = torch.nn.functional.pad(
|
| 365 |
+
prompt_features, (0, 0, 0, num_frames - prompt_features.shape[1])
|
| 366 |
+
) # (B, T, F)
|
| 367 |
+
guidance_scale = torch.tensor(guidance_scale, dtype=torch.float32)
|
| 368 |
+
|
| 369 |
+
for step in range(num_step):
|
| 370 |
+
v = model.run_fm_decoder(
|
| 371 |
+
t=timesteps[step],
|
| 372 |
+
x=x,
|
| 373 |
+
text_condition=text_condition,
|
| 374 |
+
speech_condition=speech_condition,
|
| 375 |
+
guidance_scale=guidance_scale,
|
| 376 |
+
)
|
| 377 |
+
x = x + v * (timesteps[step + 1] - timesteps[step])
|
| 378 |
+
|
| 379 |
+
x = x[:, prompt_features_len.item() :, :]
|
| 380 |
+
return x
|
| 381 |
+
|
| 382 |
+
|
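The loop above is a plain Euler solver for the flow-matching ODE: starting from Gaussian noise, each step advances the state along the predicted velocity, x <- x + v(t_k, x) * (t_{k+1} - t_k). Isolated from the model, the integrator reduces to the following self-contained sketch with a toy velocity field:

import torch

def euler_integrate(v_fn, x, timesteps):
    # timesteps: 1-D increasing tensor from t_start to t_end
    for k in range(len(timesteps) - 1):
        v = v_fn(timesteps[k], x)  # velocity prediction at t_k
        x = x + v * (timesteps[k + 1] - timesteps[k])
    return x

# Toy check: with constant velocity v(t, x) = 1, integrating t from 0 to 1
# over 16 Euler steps adds exactly 1 to the state.
x0 = torch.zeros(3)
ts = torch.linspace(0.0, 1.0, steps=17)
x1 = euler_integrate(lambda t, x: torch.ones_like(x), x0, ts)
assert torch.allclose(x1, torch.ones(3))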
| 383 |
+
# Copied from zipvoice/bin/infer_zipvoice.py, but call an external sample function
|
| 384 |
+
def generate_sentence(
|
| 385 |
+
save_path: str,
|
| 386 |
+
prompt_text: str,
|
| 387 |
+
    prompt_wav: str,
    text: str,
    model: OnnxModel,
    vocoder: nn.Module,
    tokenizer: EmiliaTokenizer,
    feature_extractor: VocosFbank,
    num_step: int = 16,
    guidance_scale: float = 1.0,
    speed: float = 1.0,
    t_shift: float = 0.5,
    target_rms: float = 0.1,
    feat_scale: float = 0.1,
    sampling_rate: int = 24000,
):
    """
    Generate waveform of a text based on a given prompt
    waveform and its transcription.

    Args:
        save_path (str): Path to save the generated wav.
        prompt_text (str): Transcription of the prompt wav.
        prompt_wav (str): Path to the prompt wav file.
        text (str): Text to be synthesized into a waveform.
        model (OnnxModel): The model used for generation.
        vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
        tokenizer (EmiliaTokenizer): The tokenizer used to convert text to tokens.
        feature_extractor (VocosFbank): The feature extractor used to
            extract acoustic features.
        num_step (int, optional): Number of steps for decoding. Defaults to 16.
        guidance_scale (float, optional): Scale for classifier-free guidance.
            Defaults to 1.0.
        speed (float, optional): Speed control. Defaults to 1.0.
        t_shift (float, optional): Time shift. Defaults to 0.5.
        target_rms (float, optional): Target RMS for waveform normalization.
            Defaults to 0.1.
        feat_scale (float, optional): Scale for features.
            Defaults to 0.1.
        sampling_rate (int, optional): Sampling rate for the waveform.
            Defaults to 24000.

    Returns:
        metrics (dict): Dictionary containing time and real-time
            factor metrics for processing.
    """
    # Convert text to tokens
    tokens = tokenizer.texts_to_token_ids([text])
    prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])

    # Load and preprocess prompt wav
    prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)

    if prompt_sampling_rate != sampling_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=prompt_sampling_rate, new_freq=sampling_rate
        )
        prompt_wav = resampler(prompt_wav)

    prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
    if prompt_rms < target_rms:
        prompt_wav = prompt_wav * target_rms / prompt_rms

    # Extract features from prompt wav
    prompt_features = feature_extractor.extract(prompt_wav, sampling_rate=sampling_rate)

    prompt_features = prompt_features.unsqueeze(0) * feat_scale

    # Start timing
    start_t = dt.datetime.now()

    # Generate features
    pred_features = sample(
        model=model,
        tokens=tokens,
        prompt_tokens=prompt_tokens,
        prompt_features=prompt_features,
        speed=speed,
        t_shift=t_shift,
        guidance_scale=guidance_scale,
        num_step=num_step,
    )

    # Postprocess predicted features
    pred_features = pred_features.permute(0, 2, 1) / feat_scale  # (B, C, T)

    # Start vocoder processing
    start_vocoder_t = dt.datetime.now()
    wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)

    # Calculate processing times and real-time factors
    t = (dt.datetime.now() - start_t).total_seconds()
    t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
    t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
    wav_seconds = wav.shape[-1] / sampling_rate
    rtf = t / wav_seconds
    rtf_no_vocoder = t_no_vocoder / wav_seconds
    rtf_vocoder = t_vocoder / wav_seconds
    metrics = {
        "t": t,
        "t_no_vocoder": t_no_vocoder,
        "t_vocoder": t_vocoder,
        "wav_seconds": wav_seconds,
        "rtf": rtf,
        "rtf_no_vocoder": rtf_no_vocoder,
        "rtf_vocoder": rtf_vocoder,
    }

    # Adjust wav volume if necessary
    if prompt_rms < target_rms:
        wav = wav * prompt_rms / target_rms
    torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)

    return metrics


def generate_list(
    res_dir: str,
    test_list: str,
    model: OnnxModel,
    vocoder: nn.Module,
    tokenizer: EmiliaTokenizer,
    feature_extractor: VocosFbank,
    num_step: int = 16,
    guidance_scale: float = 1.0,
    speed: float = 1.0,
    t_shift: float = 0.5,
    target_rms: float = 0.1,
    feat_scale: float = 0.1,
    sampling_rate: int = 24000,
):
    total_t = []
    total_t_no_vocoder = []
    total_t_vocoder = []
    total_wav_seconds = []

    with open(test_list, "r") as fr:
        lines = fr.readlines()

    for i, line in enumerate(lines):
        wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
        save_path = f"{res_dir}/{wav_name}.wav"
        metrics = generate_sentence(
            save_path=save_path,
            prompt_text=prompt_text,
            prompt_wav=prompt_wav,
            text=text,
            model=model,
            vocoder=vocoder,
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
            num_step=num_step,
            guidance_scale=guidance_scale,
            speed=speed,
            t_shift=t_shift,
            target_rms=target_rms,
            feat_scale=feat_scale,
            sampling_rate=sampling_rate,
        )
        logging.info(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
        total_t.append(metrics["t"])
        total_t_no_vocoder.append(metrics["t_no_vocoder"])
        total_t_vocoder.append(metrics["t_vocoder"])
        total_wav_seconds.append(metrics["wav_seconds"])

    logging.info(f"Average RTF: {np.sum(total_t) / np.sum(total_wav_seconds):.4f}")
    logging.info(
        f"Average RTF w/o vocoder: "
        f"{np.sum(total_t_no_vocoder) / np.sum(total_wav_seconds):.4f}"
    )
    logging.info(
        f"Average RTF vocoder: "
        f"{np.sum(total_t_vocoder) / np.sum(total_wav_seconds):.4f}"
    )


@torch.inference_mode()
def main():
    parser = get_parser()
    args = parser.parse_args()

    params = AttributeDict()
    params.update(vars(args))
    fix_random_seed(params.seed)

    model_defaults = {
        "zipvoice": {
            "num_step": 16,
            "guidance_scale": 1.0,
        },
        "zipvoice_distill": {
            "num_step": 8,
            "guidance_scale": 3.0,
        },
    }

    model_specific_defaults = model_defaults.get(params.model_name, {})

    for param, value in model_specific_defaults.items():
        if getattr(params, param) is None:
            setattr(params, param, value)
            logging.info(f"Setting {param} to default value: {value}")

    assert (params.test_list is not None) ^ (
        (params.prompt_wav and params.prompt_text and params.text) is not None
    ), (
        "For inference, please provide prompts and text with either '--test-list'"
        " or '--prompt-wav, --prompt-text and --text'."
    )

    if params.onnx_int8:
        text_encoder_name = "text_encoder_int8.onnx"
        fm_decoder_name = "fm_decoder_int8.onnx"
    else:
        text_encoder_name = "text_encoder.onnx"
        fm_decoder_name = "fm_decoder.onnx"

    if params.model_dir is not None:
        params.model_dir = Path(params.model_dir)
        if not params.model_dir.is_dir():
            raise FileNotFoundError(f"{params.model_dir} does not exist")

        for filename in [
            text_encoder_name,
            fm_decoder_name,
            "model.json",
            "tokens.txt",
        ]:
            if not (params.model_dir / filename).is_file():
                raise FileNotFoundError(f"{params.model_dir / filename} does not exist")
        text_encoder_path = params.model_dir / text_encoder_name
        fm_decoder_path = params.model_dir / fm_decoder_name
        model_config = params.model_dir / "model.json"
        token_file = params.model_dir / "tokens.txt"
        logging.info(f"Using local model dir {params.model_dir}.")
    else:
        logging.info("Using a pre-trained model from HuggingFace.")
        logging.info("Downloading the required files from HuggingFace.")
        text_encoder_path = hf_hub_download(
            HUGGINGFACE_REPO,
            filename=f"{MODEL_DIR[params.model_name]}/{text_encoder_name}",
        )
        fm_decoder_path = hf_hub_download(
            HUGGINGFACE_REPO,
            filename=f"{MODEL_DIR[params.model_name]}/{fm_decoder_name}",
        )
        model_config = hf_hub_download(
            HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/model.json"
        )

        token_file = hf_hub_download(
            HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/tokens.txt"
        )

    logging.info("Loading model...")

    if params.tokenizer == "emilia":
        tokenizer = EmiliaTokenizer(token_file=token_file)
    elif params.tokenizer == "libritts":
        tokenizer = LibriTTSTokenizer(token_file=token_file)
    elif params.tokenizer == "espeak":
        tokenizer = EspeakTokenizer(token_file=token_file, lang=params.lang)
    else:
        assert params.tokenizer == "simple"
        tokenizer = SimpleTokenizer(token_file=token_file)

    with open(model_config, "r") as f:
        model_config = json.load(f)

    model = OnnxModel(text_encoder_path, fm_decoder_path)

    vocoder = get_vocoder(params.vocoder_path)
    vocoder.eval()

    if model_config["feature"]["type"] == "vocos":
        feature_extractor = VocosFbank()
    else:
        raise NotImplementedError(
            f"Unsupported feature type: {model_config['feature']['type']}"
        )
    params.sampling_rate = model_config["feature"]["sampling_rate"]

    logging.info("Start generating...")
    if params.test_list:
        os.makedirs(params.res_dir, exist_ok=True)
        generate_list(
            res_dir=params.res_dir,
            test_list=params.test_list,
            model=model,
            vocoder=vocoder,
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
            num_step=params.num_step,
            guidance_scale=params.guidance_scale,
            speed=params.speed,
            t_shift=params.t_shift,
            target_rms=params.target_rms,
            feat_scale=params.feat_scale,
            sampling_rate=params.sampling_rate,
        )
    else:
        generate_sentence(
            save_path=params.res_wav_path,
            prompt_text=params.prompt_text,
            prompt_wav=params.prompt_wav,
            text=params.text,
            model=model,
            vocoder=vocoder,
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
            num_step=params.num_step,
            guidance_scale=params.guidance_scale,
            speed=params.speed,
            t_shift=params.t_shift,
            target_rms=params.target_rms,
            feat_scale=params.feat_scale,
            sampling_rate=params.sampling_rate,
        )
    logging.info("Done")


if __name__ == "__main__":
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)

    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    main()
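Note on the RTF bookkeeping in generate_list above: the reported averages are pooled (total processing time divided by total generated audio seconds), which weights long utterances correctly, rather than a mean of per-sentence RTFs. A standalone sketch with made-up timings illustrates the difference:

import numpy as np

# Hypothetical per-sentence measurements: processing time and audio duration.
t = np.array([0.5, 4.0])             # processing seconds per sentence
wav_seconds = np.array([1.0, 20.0])  # generated audio seconds per sentence

mean_of_rtfs = np.mean(t / wav_seconds)       # 0.35: over-weights short clips
pooled_rtf = np.sum(t) / np.sum(wav_seconds)  # ~0.21: what generate_list reports

print(f"mean of per-sentence RTFs: {mean_of_rtfs:.4f}")
print(f"pooled RTF (total time / total audio): {pooled_rtf:.4f}")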
zipvoice/bin/onnx_export.py
ADDED
@@ -0,0 +1,410 @@
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script exports a pre-trained ZipVoice or ZipVoice-Distill model from PyTorch to
ONNX.

Usage:

python3 -m zipvoice.bin.onnx_export \
    --model-name zipvoice \
    --model-dir exp/zipvoice \
    --checkpoint-name epoch-11-avg-4.pt \
    --onnx-model-dir exp/zipvoice

`--model-name` can be `zipvoice` or `zipvoice_distill`,
which are the models before and after distillation, respectively.
"""


import argparse
import json
import logging
from pathlib import Path
from typing import Dict

import onnx
import safetensors.torch
import torch
from onnxruntime.quantization import QuantType, quantize_dynamic
from torch import Tensor, nn

from zipvoice.models.zipvoice import ZipVoice
from zipvoice.models.zipvoice_distill import ZipVoiceDistill
from zipvoice.tokenizer.tokenizer import SimpleTokenizer
from zipvoice.utils.checkpoint import load_checkpoint
from zipvoice.utils.common import AttributeDict
from zipvoice.utils.scaling_converter import convert_scaled_to_non_scaled


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--onnx-model-dir",
        type=str,
        default="exp",
        help="Dir to the exported models",
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="zipvoice",
        choices=["zipvoice", "zipvoice_distill"],
        help="The model used for inference",
    )

    parser.add_argument(
        "--model-dir",
        type=str,
        default=None,
        help="The model directory that contains model checkpoint, configuration "
        "file model.json, and tokens file tokens.txt. Will download pre-trained "
        "checkpoint from huggingface if not specified.",
    )

    parser.add_argument(
        "--checkpoint-name",
        type=str,
        default="model.pt",
        help="The name of model checkpoint.",
    )

    return parser


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = value

    onnx.save(model, filename)


class OnnxTextModel(nn.Module):
    def __init__(self, model: nn.Module):
        """A wrapper for ZipVoice text encoder."""
        super().__init__()
        self.embed = model.embed
        self.text_encoder = model.text_encoder
        self.pad_id = model.pad_id

    def forward(
        self,
        tokens: Tensor,
        prompt_tokens: Tensor,
        prompt_features_len: Tensor,
        speed: Tensor,
    ) -> Tensor:
        cat_tokens = torch.cat([prompt_tokens, tokens], dim=1)
        cat_tokens = nn.functional.pad(cat_tokens, (0, 1), value=self.pad_id)
        tokens_len = cat_tokens.shape[1] - 1
        padding_mask = (torch.arange(tokens_len + 1) == tokens_len).unsqueeze(0)

        embed = self.embed(cat_tokens)
        embed = self.text_encoder(x=embed, t=None, padding_mask=padding_mask)

        features_len = torch.ceil(
            (prompt_features_len / prompt_tokens.shape[1] * tokens_len / speed)
        ).to(dtype=torch.int64)

        token_dur = torch.div(features_len, tokens_len, rounding_mode="floor").to(
            dtype=torch.int64
        )

        text_condition = embed[:, :-1, :].unsqueeze(2).expand(-1, -1, token_dur, -1)
        text_condition = text_condition.reshape(embed.shape[0], -1, embed.shape[2])

        text_condition = torch.cat(
            [
                text_condition,
                embed[:, -1:, :].expand(-1, features_len - text_condition.shape[1], -1),
            ],
            dim=1,
        )

        return text_condition


class OnnxFlowMatchingModel(nn.Module):
    def __init__(self, model: nn.Module, distill: bool = False):
        """A wrapper for ZipVoice flow-matching decoder."""
        super().__init__()
        self.distill = distill
        self.fm_decoder = model.fm_decoder
        self.model_func = getattr(model, "forward_fm_decoder")
        self.feat_dim = model.feat_dim

    def forward(
        self,
        t: Tensor,
        x: Tensor,
        text_condition: Tensor,
        speech_condition: torch.Tensor,
        guidance_scale: Tensor,
    ) -> Tensor:
        if self.distill:
            return self.model_func(
                t=t,
                xt=x,
                text_condition=text_condition,
                speech_condition=speech_condition,
                guidance_scale=guidance_scale,
            )
        else:
            x = x.repeat(2, 1, 1)
            text_condition = torch.cat(
                [torch.zeros_like(text_condition), text_condition], dim=0
            )
            speech_condition = torch.cat(
                [
                    torch.where(
                        t > 0.5, torch.zeros_like(speech_condition), speech_condition
                    ),
                    speech_condition,
                ],
                dim=0,
            )
            guidance_scale = torch.where(t > 0.5, guidance_scale, guidance_scale * 2.0)
            data_uncond, data_cond = self.model_func(
                t=t,
                xt=x,
                text_condition=text_condition,
                speech_condition=speech_condition,
            ).chunk(2, dim=0)
            v = (1 + guidance_scale) * data_cond - guidance_scale * data_uncond
            return v


def export_text_encoder(
    model: OnnxTextModel,
    filename: str,
    opset_version: int = 11,
) -> None:
    """Export the text encoder model to ONNX format.

    Args:
      model:
        The input model
      filename:
        The filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """
    tokens = torch.tensor([[2, 3, 4, 5]], dtype=torch.int64)
    prompt_tokens = torch.tensor([[0, 1]], dtype=torch.int64)
    prompt_features_len = torch.tensor(10, dtype=torch.int64)
    speed = torch.tensor(1.0, dtype=torch.float32)

    model = torch.jit.trace(model, (tokens, prompt_tokens, prompt_features_len, speed))

    torch.onnx.export(
        model,
        (tokens, prompt_tokens, prompt_features_len, speed),
        filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["tokens", "prompt_tokens", "prompt_features_len", "speed"],
        output_names=["text_condition"],
        dynamic_axes={
            "tokens": {0: "N", 1: "T"},
            "prompt_tokens": {0: "N", 1: "T"},
            "text_condition": {0: "N", 1: "T"},
        },
    )

    meta_data = {
        "version": "1",
        "model_author": "k2-fsa",
        "comment": "ZipVoice text encoder",
    }
    logging.info(f"meta_data: {meta_data}")
    add_meta_data(filename=filename, meta_data=meta_data)

    logging.info(f"Exported to {filename}")


def export_fm_decoder(
    model: OnnxFlowMatchingModel,
    filename: str,
    opset_version: int = 11,
) -> None:
    """Export the flow matching decoder model to ONNX format.

    Args:
      model:
        The input model
      filename:
        The filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """
    feat_dim = model.feat_dim
    seq_len = 200
    t = torch.tensor(0.5, dtype=torch.float32)
    x = torch.randn(1, seq_len, feat_dim, dtype=torch.float32)
    text_condition = torch.randn(1, seq_len, feat_dim, dtype=torch.float32)
    speech_condition = torch.randn(1, seq_len, feat_dim, dtype=torch.float32)
    guidance_scale = torch.tensor(1.0, dtype=torch.float32)

    model = torch.jit.trace(
        model, (t, x, text_condition, speech_condition, guidance_scale)
    )

    torch.onnx.export(
        model,
        (t, x, text_condition, speech_condition, guidance_scale),
        filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["t", "x", "text_condition", "speech_condition", "guidance_scale"],
        output_names=["v"],
        dynamic_axes={
            "x": {0: "N", 1: "T"},
            "text_condition": {0: "N", 1: "T"},
            "speech_condition": {0: "N", 1: "T"},
            "v": {0: "N", 1: "T"},
        },
    )

    meta_data = {
        "version": "1",
        "model_author": "k2-fsa",
        "comment": "ZipVoice flow-matching decoder",
        "feat_dim": str(feat_dim),
    }
    logging.info(f"meta_data: {meta_data}")
    add_meta_data(filename=filename, meta_data=meta_data)

    logging.info(f"Exported to {filename}")


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()

    params = AttributeDict()
    params.update(vars(args))

    params.model_dir = Path(params.model_dir)
    if not params.model_dir.is_dir():
        raise FileNotFoundError(f"{params.model_dir} does not exist")
    for filename in [params.checkpoint_name, "model.json", "tokens.txt"]:
        if not (params.model_dir / filename).is_file():
            raise FileNotFoundError(f"{params.model_dir / filename} does not exist")
    model_ckpt = params.model_dir / params.checkpoint_name
    model_config = params.model_dir / "model.json"
    token_file = params.model_dir / "tokens.txt"

    logging.info(f"Loading model from {params.model_dir}")

    tokenizer = SimpleTokenizer(token_file)
    tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}

    with open(model_config, "r") as f:
        model_config = json.load(f)

    if params.model_name == "zipvoice":
        model = ZipVoice(
            **model_config["model"],
            **tokenizer_config,
        )
        distill = False
    else:
        assert params.model_name == "zipvoice_distill"
        model = ZipVoiceDistill(
            **model_config["model"],
            **tokenizer_config,
        )
        distill = True

    if str(model_ckpt).endswith(".safetensors"):
        safetensors.torch.load_model(model, model_ckpt)
    elif str(model_ckpt).endswith(".pt"):
        load_checkpoint(filename=model_ckpt, model=model, strict=True)
    else:
        raise NotImplementedError(f"Unsupported model checkpoint format: {model_ckpt}")

    device = torch.device("cpu")
    model = model.to(device)
    model.eval()

    convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)

    logging.info("Exporting model")
    onnx_model_dir = Path(params.onnx_model_dir)
    onnx_model_dir.mkdir(parents=True, exist_ok=True)
    opset_version = 11

    text_encoder = OnnxTextModel(model=model)
    text_encoder_file = onnx_model_dir / "text_encoder.onnx"
    export_text_encoder(
        model=text_encoder,
        filename=text_encoder_file,
        opset_version=opset_version,
    )

    fm_decoder = OnnxFlowMatchingModel(model=model, distill=distill)
    fm_decoder_file = onnx_model_dir / "fm_decoder.onnx"
    export_fm_decoder(
        model=fm_decoder,
        filename=fm_decoder_file,
        opset_version=opset_version,
    )

    logging.info("Generate int8 quantization models")

    text_encoder_int8_file = onnx_model_dir / "text_encoder_int8.onnx"
    quantize_dynamic(
        model_input=text_encoder_file,
        model_output=text_encoder_int8_file,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )

    fm_decoder_int8_file = onnx_model_dir / "fm_decoder_int8.onnx"
    quantize_dynamic(
        model_input=fm_decoder_file,
        model_output=fm_decoder_int8_file,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )

    logging.info("Done!")


if __name__ == "__main__":

    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    main()
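For the non-distilled model, OnnxFlowMatchingModel.forward above folds classifier-free guidance into the exported graph: it duplicates the batch, runs an unconditional and a conditional pass, and combines them as v = (1 + s) * cond - s * uncond. Once exported, the graphs can be smoke-tested with onnxruntime. The sketch below is illustrative only: the file paths assume the default output locations of this script, the input/output names come from the torch.onnx.export calls above, and the zero speech_condition is a placeholder rather than real prompt features.

import numpy as np
import onnxruntime as ort

# Run the text encoder with the same dummy shapes used during export.
enc = ort.InferenceSession("exp/zipvoice/text_encoder.onnx")
(text_condition,) = enc.run(
    None,
    {
        "tokens": np.array([[2, 3, 4, 5]], dtype=np.int64),
        "prompt_tokens": np.array([[0, 1]], dtype=np.int64),
        "prompt_features_len": np.array(10, dtype=np.int64),  # 0-d tensor
        "speed": np.array(1.0, dtype=np.float32),
    },
)

# Feed the resulting text condition to one flow-matching decoder step.
dec = ort.InferenceSession("exp/zipvoice/fm_decoder.onnx")
T = text_condition.shape[1]
feat_dim = text_condition.shape[-1]  # should match the "feat_dim" metadata
(v,) = dec.run(
    None,
    {
        "t": np.array(0.5, dtype=np.float32),
        "x": np.random.randn(1, T, feat_dim).astype(np.float32),
        "text_condition": text_condition.astype(np.float32),
        "speech_condition": np.zeros((1, T, feat_dim), dtype=np.float32),
        "guidance_scale": np.array(1.0, dtype=np.float32),
    },
)
print(v.shape)  # expected: (1, T, feat_dim)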
zipvoice/bin/prepare_dataset.py
ADDED
@@ -0,0 +1,274 @@
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script generates lhotse manifest files from TSV files for custom datasets.

Each line of the TSV files should be in one of the following formats:
1. "{uniq_id}\t{text}\t{wav_path}" if the text corresponds to the full wav;
2. "{uniq_id}\t{text}\t{wav_path}\t{start_time}\t{end_time}" if the text
   corresponds to part of the wav. The start_time and end_time specify the
   start and end times of the text within the wav, in seconds.

Note: {uniq_id} must be unique for each line.

Usage:

Suppose you have two TSV files: "custom_train.tsv" and "custom_dev.tsv",
where "custom" is your dataset name and "train"/"dev" are used for training
and validation, respectively.

(1) Prepare the training data

python3 -m zipvoice.bin.prepare_dataset \
    --tsv-path data/raw/custom_train.tsv \
    --prefix "custom" \
    --subset "train" \
    --num-jobs 20 \
    --output-dir "data/manifests"

The output file would be "data/manifests/custom_cuts_train.jsonl.gz".

(2) Prepare the validation data

python3 -m zipvoice.bin.prepare_dataset \
    --tsv-path data/raw/custom_dev.tsv \
    --prefix "custom" \
    --subset "dev" \
    --num-jobs 1 \
    --output-dir "data/manifests"

The output file would be "data/manifests/custom_cuts_dev.jsonl.gz".

"""

import argparse
import logging
import re
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import List, Optional, Tuple

from lhotse import CutSet, validate_recordings_and_supervisions
from lhotse.audio import Recording, RecordingSet
from lhotse.qa import fix_manifests
from lhotse.supervision import SupervisionSegment, SupervisionSet
from lhotse.utils import Pathlike
from tqdm.auto import tqdm


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--tsv-path",
        type=str,
        help="The path of the tsv file. Each line should be in the format: "
        "{uniq_id}\t{text}\t{wav_path}\t{start_time}\t{end_time} "
        "if the text corresponds to part of the wav, or {uniq_id}\t{text}\t{wav_path} "
        "if the text corresponds to the full wav.",
    )
    parser.add_argument(
        "--prefix",
        type=str,
        default="custom",
        help="Prefix of the output manifest file.",
    )

    parser.add_argument(
        "--subset",
        type=str,
        default="train",
        help="Subset name of the manifest file, typically train or dev.",
    )

    parser.add_argument(
        "--num-jobs",
        type=int,
        default=20,
        help="Number of jobs for processing.",
    )

    parser.add_argument(
        "--output-dir",
        type=str,
        default="data/manifests",
        help="The destination directory of manifest files.",
    )
    parser.add_argument(
        "--sampling-rate",
        type=int,
        default=24000,
        help="The target sampling rate.",
    )
    return parser.parse_args()


def _parse_recording(
    wav_path: str,
) -> Tuple[Recording, str]:
    """
    :param wav_path: Path to the audio file
    :return: a tuple of "recording" and "recording_id"
    """

    recording_id = wav_path.replace("/", "_").replace(".", "_")
    recording = Recording.from_file(path=wav_path, recording_id=recording_id)

    return recording, recording_id


def _parse_supervision(
    supervision: List, recording_dict: dict
) -> Optional[SupervisionSegment]:
    """
    :param supervision: A parsed line from the TSV file
    :param recording_dict: Dictionary mapping recording IDs to Recording objects
    :return: A SupervisionSegment object
    """

    uniq_id, text, wav_path, start, end = supervision
    try:
        recording_id = wav_path.replace("/", "_").replace(".", "_")

        recording = recording_dict[recording_id]
        duration = end - start if end is not None else recording.duration
        assert duration <= recording.duration, (
            f"Duration {duration} is greater than "
            f"recording duration {recording.duration}"
        )

        text = re.sub("_", " ", text)  # "_" is treated as padding symbol
        text = re.sub(r"\s+", " ", text)  # remove extra whitespace

        return SupervisionSegment(
            id=f"{uniq_id}",
            recording_id=recording.id,
            start=start,
            duration=duration,
            channel=recording.channel_ids,
            text=text.strip(),
        )
    except Exception as e:
        logging.warning(f"Error processing line: {e}")
        return None


def prepare_dataset(
    tsv_path: Pathlike,
    prefix: str,
    subset: str,
    sampling_rate: int,
    num_jobs: int,
    output_dir: Pathlike,
):
    """
    Returns the manifests which consist of the Recordings and Supervisions

    :param tsv_path: Path to the TSV file
    :param output_dir: Path where to write the manifests
    :param num_jobs: Number of processes for parallel processing
    :return: The CutSet containing the data
    """
    logging.info(f"Preparing {prefix} dataset {subset} subset.")
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    file_name = f"{prefix}_cuts_{subset}.jsonl.gz"
    if (output_dir / file_name).is_file():
        logging.info(f"{file_name} exists, skipping.")
        return

    # Step 1: Read all unique recording paths
    recordings_path_set = set()
    supervision_list = list()
    with open(tsv_path, "r") as fr:
        for line in fr:
            items = line.strip().split("\t")
            if len(items) == 3:
                uniq_id, text, wav_path = items
                start, end = 0, None
            elif len(items) == 5:
                uniq_id, text, wav_path, start, end = items
                start, end = float(start), float(end)
            else:
                raise ValueError(
                    f"Invalid line format: {line}, "
                    "requires to be 3 columns or 5 columns"
                )
            recordings_path_set.add(wav_path)
            supervision_list.append((uniq_id, text, wav_path, start, end))

    logging.info("Starting to process recordings...")
    # Step 2: Process recordings
    futures = []
    recording_dict = {}
    with ThreadPoolExecutor(max_workers=num_jobs) as ex:
        for wav_path in tqdm(recordings_path_set, desc="Submitting jobs"):
            futures.append(ex.submit(_parse_recording, wav_path))

        for future in tqdm(futures, desc="Processing recordings"):
            try:
                recording, recording_id = future.result()
                recording_dict[recording_id] = recording
            except Exception as e:
                logging.warning(f"Error processing a recording: {e}")

    recording_set = RecordingSet.from_recordings(recording_dict.values())

    logging.info("Starting to process supervisions...")
    # Step 3: Process supervisions
    supervisions = []
    for supervision in tqdm(supervision_list, desc="Processing supervisions"):
        seg = _parse_supervision(supervision, recording_dict)
        if seg is not None:
            supervisions.append(seg)

    logging.info("Processing Cuts...")

    # Step 4: Create and validate manifests
    supervision_set = SupervisionSet.from_segments(supervisions)

    recording_set, supervision_set = fix_manifests(recording_set, supervision_set)
    validate_recordings_and_supervisions(recording_set, supervision_set)

    cut_set = CutSet.from_manifests(
        recordings=recording_set, supervisions=supervision_set
    )
    cut_set = cut_set.sort_by_recording_id()
    cut_set = cut_set.resample(sampling_rate)
    cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)

    logging.info(f"Saving file to {output_dir / file_name}")
    # Step 5: Write manifests to disk
    cut_set.to_file(output_dir / file_name)
    logging.info("Done!")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    args = get_args()

    prepare_dataset(
        tsv_path=args.tsv_path,
        prefix=args.prefix,
        subset=args.subset,
        sampling_rate=args.sampling_rate,
        num_jobs=args.num_jobs,
        output_dir=args.output_dir,
    )
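To make the accepted TSV layout concrete, the following illustrative snippet writes one 3-column row (the text covers the whole wav) and one 5-column row (the text covers the 3.2s-7.8s span of the wav); the ids and wav paths are hypothetical:

# Illustrative only: ids and wav paths below are made up.
rows = [
    # 3-column form: the transcript covers the full recording.
    ("utt_0001", "Hello world.", "wavs/utt_0001.wav"),
    # 5-column form: the transcript covers a segment, with start/end in seconds.
    ("utt_0002", "Only this span is transcribed.", "wavs/long_recording.wav", 3.2, 7.8),
]
with open("data/raw/custom_train.tsv", "w") as f:
    for row in rows:
        f.write("\t".join(str(x) for x in row) + "\n")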
zipvoice/bin/prepare_tokens.py
ADDED
@@ -0,0 +1,102 @@
"""
This script reads the texts in the given manifest and saves new cuts with
prepared tokens.
"""

import argparse
import logging
from functools import partial
from pathlib import Path

from lhotse import load_manifest, split_parallelize_combine

from zipvoice.tokenizer.tokenizer import add_tokens


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--input-file",
        type=str,
        help="Input manifest without tokens",
    )

    parser.add_argument(
        "--output-file",
        type=str,
        help="Output manifest with tokens.",
    )

    parser.add_argument(
        "--num-jobs",
        type=int,
        default=20,
        help="Number of jobs to run in parallel.",
    )

    parser.add_argument(
        "--tokenizer",
        type=str,
        default="emilia",
        help="The type of tokenizer used to convert texts into tokens.",
    )

    parser.add_argument(
        "--lang",
        type=str,
        default="en-us",
        help="Language identifier, used when tokenizer type is espeak. See "
        "https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
    )

    return parser.parse_args()


def prepare_tokens(
    input_file: Path,
    output_file: Path,
    num_jobs: int,
    tokenizer: str,
    lang: str = "en-us",
):
    logging.info(f"Processing {input_file}")
    if output_file.is_file():
        logging.info(f"{output_file} exists, skipping.")
        return
    logging.info(f"Loading manifest from {input_file}")
    cut_set = load_manifest(input_file)

    _add_tokens = partial(add_tokens, tokenizer=tokenizer, lang=lang)

    logging.info("Adding tokens")

    cut_set = split_parallelize_combine(
        num_jobs=num_jobs, manifest=cut_set, fn=_add_tokens
    )

    logging.info(f"Saving file to {output_file}")
    cut_set.to_file(output_file)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    args = get_args()
    input_file = Path(args.input_file)
    output_file = Path(args.output_file)
    num_jobs = args.num_jobs
    tokenizer = args.tokenizer
    lang = args.lang

    output_file.parent.mkdir(parents=True, exist_ok=True)

    prepare_tokens(
        input_file=input_file,
        output_file=output_file,
        num_jobs=num_jobs,
        tokenizer=tokenizer,
        lang=lang,
    )

    logging.info("Done!")
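The same step can be driven from Python instead of the command line by calling prepare_tokens directly; a minimal sketch, where the manifest paths are hypothetical and simply mirror the prepare_dataset.py outputs above:

from pathlib import Path

from zipvoice.bin.prepare_tokens import prepare_tokens

# Hypothetical paths: input is a manifest from prepare_dataset.py.
output_file = Path("data/fbank/custom_cuts_train_tokens.jsonl.gz")
output_file.parent.mkdir(parents=True, exist_ok=True)
prepare_tokens(
    input_file=Path("data/manifests/custom_cuts_train.jsonl.gz"),
    output_file=output_file,
    num_jobs=20,
    tokenizer="emilia",  # or "espeak" together with a suitable lang identifier
    lang="en-us",
)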
zipvoice/bin/train_zipvoice.py
ADDED
@@ -0,0 +1,1136 @@
#!/usr/bin/env python3
# Copyright 2024-2025 Xiaomi Corp. (authors: Wei Kang,
#                                            Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script trains a ZipVoice model with the flow-matching loss.

Usage:

python3 -m zipvoice.bin.train_zipvoice \
    --world-size 8 \
    --use-fp16 1 \
    --num-epochs 11 \
    --max-duration 500 \
    --lr-hours 30000 \
    --model-config conf/zipvoice_base.json \
    --tokenizer emilia \
    --token-file "data/tokens_emilia.txt" \
    --dataset emilia \
    --manifest-dir data/fbank \
    --exp-dir exp/zipvoice
"""

import argparse
import copy
import json
import logging
import os
from functools import partial
from pathlib import Path
from shutil import copyfile
from typing import List, Optional, Tuple, Union

import torch
import torch.multiprocessing as mp
import torch.nn as nn
from lhotse.cut import Cut, CutSet
from lhotse.utils import fix_random_seed
from torch import Tensor
from torch.amp.grad_scaler import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter

import zipvoice.utils.diagnostics as diagnostics
from zipvoice.dataset.datamodule import TtsDataModule
from zipvoice.models.zipvoice import ZipVoice
from zipvoice.tokenizer.tokenizer import (
    EmiliaTokenizer,
    EspeakTokenizer,
    LibriTTSTokenizer,
    SimpleTokenizer,
    SimpleTokenizer2,
)
from zipvoice.utils.checkpoint import (
    load_checkpoint,
    remove_checkpoints,
    resume_checkpoint,
    save_checkpoint,
    save_checkpoint_with_global_batch_idx,
    update_averaged_model,
)
from zipvoice.utils.common import (
    AttributeDict,
    MetricsTracker,
    cleanup_dist,
    create_grad_scaler,
    get_adjusted_batch_count,
    get_env_info,
    get_parameter_groups_with_lrs,
    prepare_input,
    set_batch_count,
    setup_dist,
    setup_logger,
    str2bool,
    torch_autocast,
)
from zipvoice.utils.hooks import register_inf_check_hooks
from zipvoice.utils.lr_scheduler import Eden, FixedLRScheduler, LRScheduler
from zipvoice.utils.optim import ScaledAdam

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler]


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--world-size",
        type=int,
        default=1,
        help="Number of GPUs for DDP training.",
    )

    parser.add_argument(
        "--master-port",
        type=int,
        default=12356,
        help="Master port to use for DDP training.",
    )

    parser.add_argument(
        "--tensorboard",
        type=str2bool,
        default=True,
        help="Should various information be logged in tensorboard.",
    )

    parser.add_argument(
        "--num-epochs",
        type=int,
        default=11,
        help="Number of epochs to train.",
    )

    parser.add_argument(
        "--num-iters",
        type=int,
        default=0,
        help="Number of iters to train, will ignore num_epochs if > 0.",
    )

    parser.add_argument(
        "--start-epoch",
        type=int,
        default=1,
        help="""Resume training from this epoch. It should be positive.
        If larger than 1, it will load checkpoint from
        exp-dir/epoch-{start_epoch-1}.pt
        """,
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="""Checkpoint of a pre-trained model, will load it if not None
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="exp/zipvoice",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--base-lr", type=float, default=0.02, help="The base learning rate."
    )

    parser.add_argument(
        "--lr-batches",
        type=float,
        default=7500,
        help="""Number of steps that affects how rapidly the learning rate
        decreases. We suggest not to change this.""",
    )

    parser.add_argument(
        "--lr-epochs",
        type=float,
        default=10,
        help="""Number of epochs that affects how rapidly the learning rate decreases.
        """,
    )

    parser.add_argument(
        "--lr-hours",
        type=float,
        default=0,
        help="""If positive, --lr-epochs is ignored and this specifies the number of
        hours that affects how rapidly the learning rate decreases.
        """,
    )

    parser.add_argument(
        "--ref-duration",
        type=float,
        default=50,
        help="""Reference batch duration for purposes of adjusting batch counts for
        setting various schedules inside the model.
        """,
    )

    parser.add_argument(
        "--finetune",
        type=str2bool,
        default=False,
        help="Whether to use the fine-tuning mode, which will use a fixed learning "
        "rate schedule and skip the large dropout phase.",
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="The seed for random generators intended for reproducibility",
    )

    parser.add_argument(
        "--print-diagnostics",
        type=str2bool,
        default=False,
        help="Accumulate stats on activations, print them and exit.",
    )

    parser.add_argument(
        "--scan-oom",
        type=str2bool,
        default=False,
        help="Scan pessimistic batches to see whether they cause OOMs.",
    )

    parser.add_argument(
        "--inf-check",
        type=str2bool,
        default=False,
        help="Add hooks to check for infinite module outputs and gradients.",
    )

    parser.add_argument(
        "--save-every-n",
        type=int,
        default=5000,
        help="""Save checkpoint after processing this number of batches
        periodically. We save checkpoint to exp-dir/ whenever
        params.batch_idx_train % save_every_n == 0. The checkpoint filename
        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
        end of each epoch where `xxx` is the epoch number counting from 1.
        """,
    )

    parser.add_argument(
        "--valid-by-epoch",
        type=str2bool,
        default=False,
        help="""Whether to validate after each epoch. If False, will validate
        after every save_every_n iterations.
        """,
    )

    parser.add_argument(
        "--keep-last-k",
        type=int,
        default=30,
        help="""Only keep this number of checkpoints on disk.
        For instance, if it is 3, there are only 3 checkpoints
        in the exp-dir with filenames `checkpoint-xxx.pt`.
        It does not affect checkpoints with name `epoch-xxx.pt`.
        """,
    )

    parser.add_argument(
        "--average-period",
        type=int,
        default=200,
        help="""Update the averaged model, namely `model_avg`, after processing
        this number of batches. `model_avg` is a separate version of model,
        in which each floating-point parameter is the average of all the
        parameters from the start of training. Each time we take the average,
        we do: `model_avg = model * (average_period / batch_idx_train) +
we do: `model_avg = model * (average_period / batch_idx_train) +
|
| 283 |
+
model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
|
| 284 |
+
""",
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
parser.add_argument(
|
| 288 |
+
"--use-fp16",
|
| 289 |
+
type=str2bool,
|
| 290 |
+
default=True,
|
| 291 |
+
help="Whether to use half precision training.",
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
parser.add_argument(
|
| 295 |
+
"--feat-scale",
|
| 296 |
+
type=float,
|
| 297 |
+
default=0.1,
|
| 298 |
+
help="The scale factor of fbank feature",
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
parser.add_argument(
|
| 302 |
+
"--condition-drop-ratio",
|
| 303 |
+
type=float,
|
| 304 |
+
default=0.2,
|
| 305 |
+
help="The drop rate of text condition during training.",
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
parser.add_argument(
|
| 309 |
+
"--dataset",
|
| 310 |
+
type=str,
|
| 311 |
+
default="emilia",
|
| 312 |
+
choices=["emilia", "libritts", "custom"],
|
| 313 |
+
help="The used training dataset",
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
parser.add_argument(
|
| 317 |
+
"--train-manifest",
|
| 318 |
+
type=str,
|
| 319 |
+
help="Path of the training manifest",
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
parser.add_argument(
|
| 323 |
+
"--dev-manifest",
|
| 324 |
+
type=str,
|
| 325 |
+
help="Path of the validation manifest",
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
parser.add_argument(
|
| 329 |
+
"--min-len",
|
| 330 |
+
type=float,
|
| 331 |
+
default=1.0,
|
| 332 |
+
help="The minimum audio length used for training",
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
parser.add_argument(
|
| 336 |
+
"--max-len",
|
| 337 |
+
type=float,
|
| 338 |
+
default=30.0,
|
| 339 |
+
help="The maximum audio length used for training",
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
parser.add_argument(
|
| 343 |
+
"--model-config",
|
| 344 |
+
type=str,
|
| 345 |
+
default="conf/zipvoice_base.json",
|
| 346 |
+
help="The model configuration file.",
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
parser.add_argument(
|
| 350 |
+
"--tokenizer",
|
| 351 |
+
type=str,
|
| 352 |
+
default="emilia",
|
| 353 |
+
help="Tokenizer type.",
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
parser.add_argument(
|
| 357 |
+
"--lang",
|
| 358 |
+
type=str,
|
| 359 |
+
default="en-us",
|
| 360 |
+
help="Language identifier, used when tokenizer type is espeak. see"
|
| 361 |
+
"https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
parser.add_argument(
|
| 365 |
+
"--token-file",
|
| 366 |
+
type=str,
|
| 367 |
+
default="data/tokens_emilia.txt",
|
| 368 |
+
help="The file that contains information that maps tokens to ids,"
|
| 369 |
+
"which is a text file with '{token}\t{token_id}' per line.",
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
return parser
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def get_params() -> AttributeDict:
    """Return a dict containing training parameters.

    All training related parameters that are not passed from the commandline
    are saved in the variable `params`.

    Commandline options are merged into `params` after they are parsed, so
    you can also access them via `params`.

    Explanation of options saved in `params`:

        - best_train_loss: Best training loss so far. It is used to select
                           the model that has the lowest training loss. It is
                           updated during training.

        - best_valid_loss: Best validation loss so far. It is used to select
                           the model that has the lowest validation loss. It is
                           updated during training.

        - best_train_epoch: The epoch that has the best training loss.

        - best_valid_epoch: The epoch that has the best validation loss.

        - batch_idx_train: Used for writing statistics to tensorboard. It
                           contains the number of batches trained so far across
                           epochs.

        - log_interval: Print training loss if batch_idx % log_interval is 0.

        - reset_interval: Reset statistics if batch_idx % reset_interval is 0.

        - env_info: A dict containing information about the environment.

    """
    params = AttributeDict(
        {
            "best_train_loss": float("inf"),
            "best_valid_loss": float("inf"),
            "best_train_epoch": -1,
            "best_valid_epoch": -1,
            "batch_idx_train": 0,
            "log_interval": 50,
            "reset_interval": 200,
            "env_info": get_env_info(),
        }
    )

    return params

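# Usage sketch (mirroring `run()` below): command-line options are merged
# into `params` after parsing, so defaults and CLI flags are both reachable
# as attributes:
#
#   params = get_params()
#   params.update(vars(get_parser().parse_args()))
#   print(params.base_lr, params.log_interval)
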
def compute_fbank_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    features: Tensor,
    features_lens: Tensor,
    tokens: List[List[int]],
    is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
    """
    Compute the loss given the model and its inputs.

    Args:
      params:
        Parameters for training. See :func:`get_params`.
      model:
        The model for training.
      features:
        The target acoustic features.
      features_lens:
        The number of frames of each utterance.
      tokens:
        Input tokens representing the transcripts.
      is_training:
        True for training. False for validation. When it is True, this
        function enables autograd during computation; when it is False, it
        disables autograd.
    """

    device = model.device if isinstance(model, DDP) else next(model.parameters()).device

    batch_size, num_frames, _ = features.shape

    features = torch.nn.functional.pad(
        features, (0, 0, 0, num_frames - features.size(1))
    )  # (B, T, F)
    noise = torch.randn_like(features)  # (B, T, F)

    # Sample t from a uniform distribution during training; use a fixed,
    # evenly spaced grid over [0, 1) during validation so that validation
    # losses are comparable across runs.
    if is_training:
        t = torch.rand(batch_size, 1, 1, device=device)
    else:
        t = (
            (torch.arange(batch_size, device=device) / batch_size)
            .unsqueeze(1)
            .unsqueeze(2)
        )
    with torch.set_grad_enabled(is_training):
        loss = model(
            tokens=tokens,
            features=features,
            features_lens=features_lens,
            noise=noise,
            t=t,
            condition_drop_ratio=params.condition_drop_ratio,
        )

    assert loss.requires_grad == is_training
    info = MetricsTracker()
    num_frames = features_lens.sum().item()
    info["frames"] = num_frames
    info["loss"] = loss.detach().cpu().item() * num_frames

    return loss, info

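# A minimal sketch of the conditional flow-matching objective assumed inside
# `model(...)` above (the actual implementation lives in the ZipVoice model
# class; the names here are illustrative only):
#
#   x_t = (1 - t) * noise + t * features          # linear probability path
#   target = features - noise                     # constant velocity field
#   loss = mse(predict_velocity(x_t, t, tokens), target)
#
# `condition_drop_ratio` randomly drops the text condition during training so
# that classifier-free guidance can be applied at inference time.
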
def train_one_epoch(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    optimizer: Optimizer,
    scheduler: LRSchedulerType,
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    scaler: GradScaler,
    model_avg: Optional[nn.Module] = None,
    tb_writer: Optional[SummaryWriter] = None,
    world_size: int = 1,
    rank: int = 0,
) -> None:
    """Train the model for one epoch.

    The training loss from the mean of all frames is saved in
    `params.train_loss`. It runs the validation process every
    `params.valid_interval` batches, or once per epoch when
    `params.valid_by_epoch` is True.

    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The model for training.
      optimizer:
        The optimizer.
      scheduler:
        The learning rate scheduler; we call step() every epoch.
      train_dl:
        Dataloader for the training dataset.
      valid_dl:
        Dataloader for the validation dataset.
      scaler:
        The scaler used for mixed precision training.
      tb_writer:
        Writer to write log messages to tensorboard.
      world_size:
        Number of nodes in DDP training. If it is 1, DDP is disabled.
      rank:
        The rank of the node in DDP training. If no DDP is used, it should
        be set to 0.
    """
    model.train()
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device

    # used to track the stats over iterations in one epoch
    tot_loss = MetricsTracker()

    saved_bad_model = False

    def save_bad_model(suffix: str = ""):
        save_checkpoint(
            filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
            model=model,
            model_avg=model_avg,
            params=params,
            optimizer=optimizer,
            scheduler=scheduler,
            sampler=train_dl.sampler,
            scaler=scaler,
            rank=0,
        )

    for batch_idx, batch in enumerate(train_dl):

        if batch_idx % 10 == 0:
            if params.finetune:
                set_batch_count(model, get_adjusted_batch_count(params) + 100000)
            else:
                set_batch_count(model, get_adjusted_batch_count(params))

        if (
            params.valid_by_epoch and batch_idx == 0 and not params.print_diagnostics
        ) or (
            not params.valid_by_epoch
            and params.batch_idx_train % params.valid_interval == 0
            and not params.print_diagnostics
        ):
            logging.info("Computing validation loss")
            valid_info = compute_validation_loss(
                params=params,
                model=model,
                valid_dl=valid_dl,
                world_size=world_size,
            )
            model.train()
            logging.info(
                f"Epoch {params.cur_epoch}, global_batch_idx: {params.batch_idx_train},"
                f" validation: {valid_info}"
            )
            logging.info(
                f"Maximum memory allocated so far is "
                f"{torch.cuda.max_memory_allocated() // 1000000}MB"
            )
            if tb_writer is not None:
                valid_info.write_summary(
                    tb_writer, "train/valid_", params.batch_idx_train
                )

        params.batch_idx_train += 1

        batch_size = len(batch["text"])

        tokens, features, features_lens = prepare_input(
            params=params,
            batch=batch,
            device=device,
            return_tokens=True,
            return_feature=True,
        )

        try:
            with torch_autocast(dtype=torch.float16, enabled=params.use_fp16):
                loss, loss_info = compute_fbank_loss(
                    params=params,
                    model=model,
                    features=features,
                    features_lens=features_lens,
                    tokens=tokens,
                    is_training=True,
                )

            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

            scaler.scale(loss).backward()

            scheduler.step_batch(params.batch_idx_train)
            # Use the number of hours of speech to adjust the learning rate
            if params.lr_hours > 0:
                scheduler.step_epoch(
                    params.batch_idx_train
                    * params.max_duration
                    * params.world_size
                    / 3600
                )
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        except Exception as e:
            logging.info(f"Caught exception: {e}.")
            save_bad_model()
            raise

        if params.print_diagnostics and batch_idx == 5:
            return

        if (
            rank == 0
            and params.batch_idx_train > 0
            and params.batch_idx_train % params.average_period == 0
        ):
            update_averaged_model(
                params=params,
                model_cur=model,
                model_avg=model_avg,
            )

        if (
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
            save_checkpoint_with_global_batch_idx(
                out_dir=params.exp_dir,
                global_batch_idx=params.batch_idx_train,
                model=model,
                model_avg=model_avg,
                params=params,
                optimizer=optimizer,
                scheduler=scheduler,
                sampler=train_dl.sampler,
                scaler=scaler,
                rank=rank,
            )
            remove_checkpoints(
                out_dir=params.exp_dir,
                topk=params.keep_last_k,
                rank=rank,
            )
        if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
            break
        if params.batch_idx_train % 100 == 0 and params.use_fp16:
            # If the grad scale was less than 1, try increasing it. The _growth_interval
            # of the grad scaler is configurable, but we can't configure it to have
            # different behavior depending on the current grad scale.
            cur_grad_scale = scaler._scale.item()

            if cur_grad_scale < 1024.0 or (
                cur_grad_scale < 4096.0 and params.batch_idx_train % 400 == 0
            ):
                scaler.update(cur_grad_scale * 2.0)
            if cur_grad_scale < 0.01:
                if not saved_bad_model:
                    save_bad_model(suffix="-first-warning")
                    saved_bad_model = True
                logging.warning(f"Grad scale is small: {cur_grad_scale}")
            if cur_grad_scale < 1.0e-05:
                save_bad_model()
                raise RuntimeError(
                    f"grad_scale is too small, exiting: {cur_grad_scale}"
                )

        if params.batch_idx_train % params.log_interval == 0:
            cur_lr = max(scheduler.get_last_lr())
            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0

            logging.info(
                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
                f"global_batch_idx: {params.batch_idx_train}, "
                f"batch size: {batch_size}, "
                f"loss[{loss_info}], tot_loss[{tot_loss}], "
                f"cur_lr: {cur_lr:.2e}, "
                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
            )

            if tb_writer is not None:
                tb_writer.add_scalar(
                    "train/learning_rate", cur_lr, params.batch_idx_train
                )
                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
                )
                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                if params.use_fp16:
                    tb_writer.add_scalar(
                        "train/grad_scale",
                        cur_grad_scale,
                        params.batch_idx_train,
                    )

    loss_value = tot_loss["loss"]
    params.train_loss = loss_value
    if params.train_loss < params.best_train_loss:
        params.best_train_epoch = params.cur_epoch
        params.best_train_loss = params.train_loss

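# Note on `tot_loss` in `train_one_epoch` above: it is an exponential moving
# average over roughly `reset_interval` batches,
#
#   tot_loss <- tot_loss * (1 - 1 / reset_interval) + loss_info
#
# so the logged tot_loss is smoother than, and slightly lags, the per-batch
# loss.
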
def compute_validation_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    valid_dl: torch.utils.data.DataLoader,
    world_size: int = 1,
) -> MetricsTracker:
    """Run the validation process."""

    model.eval()
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device

    # used to summarize the stats over iterations
    tot_loss = MetricsTracker()

    for batch_idx, batch in enumerate(valid_dl):
        tokens, features, features_lens = prepare_input(
            params=params,
            batch=batch,
            device=device,
            return_tokens=True,
            return_feature=True,
        )

        loss, loss_info = compute_fbank_loss(
            params=params,
            model=model,
            features=features,
            features_lens=features_lens,
            tokens=tokens,
            is_training=False,
        )
        assert loss.requires_grad is False
        tot_loss = tot_loss + loss_info

    if world_size > 1:
        tot_loss.reduce(loss.device)

    loss_value = tot_loss["loss"]
    if loss_value < params.best_valid_loss:
        params.best_valid_epoch = params.cur_epoch
        params.best_valid_loss = loss_value

    return tot_loss

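# In DDP runs, `tot_loss.reduce(loss.device)` above aggregates the per-rank
# statistics across processes, so every rank ends up tracking the same global
# validation loss for the best-loss bookkeeping.
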
def display_and_save_batch(
    batch: dict,
    params: AttributeDict,
) -> None:
    """Display the batch statistics and save the batch to disk.

    Args:
      batch:
        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
        for the content in it.
      params:
        Parameters for training. See :func:`get_params`.
    """
    from lhotse.utils import uuid4

    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
    logging.info(f"Saving batch to {filename}")
    torch.save(batch, filename)

    features = batch["features"]
    tokens = batch["tokens"]

    logging.info(f"features shape: {features.shape}")
    num_tokens = sum(len(i) for i in tokens)
    logging.info(f"num tokens: {num_tokens}")

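# A batch saved by `display_and_save_batch` can be reloaded for offline
# debugging, e.g. (hypothetical path):
#
#   batch = torch.load("exp/zipvoice/batch-<uuid>.pt")
#   print(batch["features"].shape, len(batch["tokens"]))
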
def scan_pessimistic_batches_for_oom(
    model: Union[nn.Module, DDP],
    train_dl: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    params: AttributeDict,
):
    from lhotse.dataset import find_pessimistic_batches

    logging.info(
        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
    )
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device

    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
    for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]
        tokens, features, features_lens = prepare_input(
            params=params,
            batch=batch,
            device=device,
            return_tokens=True,
            return_feature=True,
        )
        try:
            with torch_autocast(dtype=torch.float16, enabled=params.use_fp16):
                loss, loss_info = compute_fbank_loss(
                    params=params,
                    model=model,
                    features=features,
                    features_lens=features_lens,
                    tokens=tokens,
                    is_training=True,
                )
            loss.backward()
            optimizer.zero_grad()
        except Exception as e:
            if "CUDA out of memory" in str(e):
                logging.error(
                    "Your GPU ran out of memory with the current "
                    "max_duration setting. We recommend decreasing "
                    "max_duration and trying again.\n"
                    f"Failing criterion: {criterion} "
                    f"(={crit_values[criterion]}) ..."
                )
            display_and_save_batch(batch, params=params)
            raise
    logging.info(
        f"Maximum memory allocated so far is "
        f"{torch.cuda.max_memory_allocated() // 1000000}MB"
    )

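# `find_pessimistic_batches` (from lhotse) returns, for each sampler
# criterion (e.g., largest total duration), the batch that maximizes that
# criterion, i.e., the batch most likely to OOM. A failed forward/backward
# here usually means the sampler's max_duration should be lowered.
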
def tokenize_text(c: Cut, tokenizer):
    if hasattr(c.supervisions[0], "tokens"):
        tokens = tokenizer.tokens_to_token_ids([c.supervisions[0].tokens])
    else:
        tokens = tokenizer.texts_to_token_ids([c.supervisions[0].text])
        print("tokens not found; tokenizing from text")
    c.supervisions[0].tokens = tokens[0]
    return c

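# `tokenize_text` is mapped lazily over the CutSets in `run()` only as a
# fallback; precomputing tokens offline (e.g., with
# zipvoice/bin/prepare_tokens.py in this repo) and storing them on each
# supervision avoids this per-cut tokenization cost during training.
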
def run(rank, world_size, args):
    """
    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoints.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    """
    params = get_params()
    params.update(vars(args))
    params.valid_interval = params.save_every_n
    # Set the number of epochs to a large value so that it is ignored.
    if params.num_iters > 0:
        params.num_epochs = 1000000
    with open(params.model_config, "r") as f:
        model_config = json.load(f)
    params.update(model_config["model"])
    params.update(model_config["feature"])

    fix_random_seed(params.seed)
    if world_size > 1:
        setup_dist(rank, world_size, params.master_port)

    os.makedirs(f"{params.exp_dir}", exist_ok=True)
    copyfile(src=params.model_config, dst=f"{params.exp_dir}/model.json")
    copyfile(src=params.token_file, dst=f"{params.exp_dir}/tokens.txt")
    setup_logger(f"{params.exp_dir}/log/log-train")

    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
    else:
        tb_writer = None

    if torch.cuda.is_available():
        params.device = torch.device("cuda", rank)
    else:
        params.device = torch.device("cpu")
    logging.info(f"Device: {params.device}")

    if params.tokenizer == "emilia":
        tokenizer = EmiliaTokenizer(token_file=params.token_file)
    elif params.tokenizer == "libritts":
        tokenizer = LibriTTSTokenizer(token_file=params.token_file)
    elif params.tokenizer == "espeak":
        tokenizer = EspeakTokenizer(token_file=params.token_file, lang=params.lang)
    elif params.tokenizer == "simple2":
        tokenizer = SimpleTokenizer2(token_file=params.token_file)
    else:
        assert params.tokenizer == "simple"
        tokenizer = SimpleTokenizer(token_file=params.token_file)

    tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
    params.update(tokenizer_config)

    logging.info(params)

    logging.info("About to create model")

    model = ZipVoice(
        **model_config["model"],
        **tokenizer_config,
    )

    if params.checkpoint is not None:
        logging.info(f"Loading pre-trained model from {params.checkpoint}")
        _ = load_checkpoint(filename=params.checkpoint, model=model, strict=True)
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of parameters: {num_param}")

    model_avg: Optional[nn.Module] = None
    if rank == 0:
        # model_avg is only used with rank 0
        model_avg = copy.deepcopy(model).to(torch.float64)

    assert params.start_epoch > 0, params.start_epoch
    checkpoints = None
    if params.start_epoch > 1:
        checkpoints = resume_checkpoint(params=params, model=model, model_avg=model_avg)

    model = model.to(params.device)
    if world_size > 1:
        logging.info("Using DDP")
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    optimizer = ScaledAdam(
        get_parameter_groups_with_lrs(
            model,
            lr=params.base_lr,
            include_names=True,
        ),
        lr=params.base_lr,  # should have no effect
        clipping_scale=2.0,
    )

    assert params.lr_hours >= 0

    if params.finetune:
        scheduler = FixedLRScheduler(optimizer)
    elif params.lr_hours > 0:
        scheduler = Eden(optimizer, params.lr_batches, params.lr_hours)
    else:
        scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
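    # Scheduler choice above: fine-tuning keeps the learning rate fixed;
    # otherwise Eden decays it by batch count and by epochs, or by cumulative
    # hours of speech when --lr-hours > 0 (in that case step_epoch() is driven
    # from within train_one_epoch).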

    scaler = create_grad_scaler(enabled=params.use_fp16)

    if params.start_epoch > 1 and checkpoints is not None:
        # load state_dict for optimizers
        if "optimizer" in checkpoints:
            logging.info("Loading optimizer state dict")
            optimizer.load_state_dict(checkpoints["optimizer"])

        # load state_dict for schedulers
        if "scheduler" in checkpoints:
            logging.info("Loading scheduler state dict")
            scheduler.load_state_dict(checkpoints["scheduler"])

        if "grad_scaler" in checkpoints:
            logging.info("Loading grad scaler state dict")
            scaler.load_state_dict(checkpoints["grad_scaler"])

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            512
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

    if params.inf_check:
        register_inf_check_hooks(model)

    def remove_short_and_long_utt(c: Cut, min_len: float, max_len: float):
        if c.duration < min_len or c.duration > max_len:
            return False
        return True

    _remove_short_and_long_utt = partial(
        remove_short_and_long_utt, min_len=params.min_len, max_len=params.max_len
    )

    datamodule = TtsDataModule(args)
    if params.dataset == "emilia":
        train_cuts = CutSet.mux(
            datamodule.train_emilia_EN_cuts(),
            datamodule.train_emilia_ZH_cuts(),
            weights=[46000, 49000],
        )
        train_cuts = train_cuts.filter(_remove_short_and_long_utt)
        dev_cuts = CutSet.mux(
            datamodule.dev_emilia_EN_cuts(),
            datamodule.dev_emilia_ZH_cuts(),
            weights=[0.5, 0.5],
        )
    elif params.dataset == "libritts":
        train_cuts = datamodule.train_libritts_cuts()
        train_cuts = train_cuts.filter(_remove_short_and_long_utt)
        dev_cuts = datamodule.dev_libritts_cuts()
    else:
        assert params.dataset == "custom"
        train_cuts = datamodule.train_custom_cuts(params.train_manifest)
        train_cuts = train_cuts.filter(_remove_short_and_long_utt)
        dev_cuts = datamodule.dev_custom_cuts(params.dev_manifest)
        # To avoid OOM issues caused by overly long dev cuts
        dev_cuts = dev_cuts.filter(_remove_short_and_long_utt)

    if params.tokenizer in ["emilia", "espeak", "dialog"]:
        if not hasattr(train_cuts[0].supervisions[0], "tokens") or not hasattr(
            dev_cuts[0].supervisions[0], "tokens"
        ):
            logging.warning(
                f"Using {params.tokenizer} tokenizer but tokens are not prepared; "
                "will tokenize on-the-fly, which can slow down training significantly."
            )
            _tokenize_text = partial(tokenize_text, tokenizer=tokenizer)
            train_cuts = train_cuts.map(_tokenize_text)
            dev_cuts = dev_cuts.map(_tokenize_text)

    train_dl = datamodule.train_dataloaders(train_cuts)

    valid_dl = datamodule.dev_dataloaders(dev_cuts)

    if params.scan_oom:
        scan_pessimistic_batches_for_oom(
            model=model,
            train_dl=train_dl,
            optimizer=optimizer,
            params=params,
        )

    logging.info("Training started")

    for epoch in range(params.start_epoch, params.num_epochs + 1):
        logging.info(f"Start epoch {epoch}")

        if params.lr_hours == 0:
            scheduler.step_epoch(epoch - 1)
        fix_random_seed(params.seed + epoch - 1)
        train_dl.sampler.set_epoch(epoch - 1)

        params.cur_epoch = epoch

        if tb_writer is not None:
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

        train_one_epoch(
            params=params,
            model=model,
            model_avg=model_avg,
            optimizer=optimizer,
            scheduler=scheduler,
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,
        )

        if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
            break

        if params.print_diagnostics:
            diagnostic.print_diagnostics()
            break

        filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
        save_checkpoint(
            filename=filename,
            params=params,
            model=model,
            model_avg=model_avg,
            optimizer=optimizer,
            scheduler=scheduler,
            sampler=train_dl.sampler,
            scaler=scaler,
            rank=rank,
        )

        if rank == 0:
            if params.best_train_epoch == params.cur_epoch:
                best_train_filename = params.exp_dir / "best-train-loss.pt"
                copyfile(src=filename, dst=best_train_filename)

            if params.best_valid_epoch == params.cur_epoch:
                best_valid_filename = params.exp_dir / "best-valid-loss.pt"
                copyfile(src=filename, dst=best_valid_filename)

    logging.info("Done!")

    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()

def main():
    parser = get_parser()
    TtsDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    world_size = args.world_size
    assert world_size >= 1
    if world_size > 1:
        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1, args=args)


if __name__ == "__main__":
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    main()
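
# Example invocation (assumed paths; see egs/zipvoice/run_*.sh in this repo
# for the full recipes):
#
#   python3 -m zipvoice.bin.train_zipvoice \
#     --world-size 4 \
#     --dataset custom \
#     --train-manifest data/train.jsonl.gz \
#     --dev-manifest data/dev.jsonl.gz \
#     --tokenizer espeak --lang en-us \
#     --token-file data/tokens.txt \
#     --exp-dir exp/zipvoice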
|