diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..83cfd8dbb643612f79f25d84b65ac7e4b3c4fb7f
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,36 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b7e841d6a18ea413867d688d50a8c9f1fbea0483
--- /dev/null
+++ b/README.md
@@ -0,0 +1,9 @@
+---
+license: cc-by-sa-4.0
+title: Text_to_speech_Vietnamese
+sdk: gradio
+emoji: 🚀
+colorFrom: red
+colorTo: yellow
+pinned: false
+---
\ No newline at end of file
diff --git a/T_English.wav b/T_English.wav
new file mode 100644
index 0000000000000000000000000000000000000000..508f4a76173337cba25ab3258ada4251162f6e12
--- /dev/null
+++ b/T_English.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d6ffd499fdf637243bdd630cb52660635d5c7cb580b87f52aa7efca90a33311f
+size 328364
diff --git a/T_English_output.wav b/T_English_output.wav
new file mode 100644
index 0000000000000000000000000000000000000000..524f017bd4c2f0a8ddd67d796f8d2ced4982df72
--- /dev/null
+++ b/T_English_output.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:756daab105a8f4508e6d4c237f1e72a25ec326ad1f6665dce974f96e9b86db7a
+size 954742
diff --git "a/To\341\272\241i.wav" "b/To\341\272\241i.wav"
new file mode 100644
index 0000000000000000000000000000000000000000..b86bc4342732296ef56c3733a9e95fdc83971c92
--- /dev/null
+++ "b/To\341\272\241i.wav"
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:798c60882fa0e9a758fd72cf506e6b71293c5ccb5ed27b92569d042a23624bdc
+size 200782
diff --git "a/To\341\272\241i_output.wav" "b/To\341\272\241i_output.wav"
new file mode 100644
index 0000000000000000000000000000000000000000..1a0edc515d450ff262cc07d76c3a3b6282b99595
--- /dev/null
+++ "b/To\341\272\241i_output.wav"
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3916769643ff756700e08ae355a0f7e30fd7a0e5299b06a866facad5ff31afd1
+size 2154910
diff --git a/Trung.wav b/Trung.wav
new file mode 100644
index 0000000000000000000000000000000000000000..959987c09e00963357f86b2d52ab7ccd781facf9
--- /dev/null
+++ b/Trung.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0bf7087aa7452978b6ec6c25eb8c078eb4ca337660c9d3e3661b8017da9238e9
+size 199376
diff --git a/Trung_output.wav b/Trung_output.wav
new file mode 100644
index 0000000000000000000000000000000000000000..d1298f3af169e3c7db36112c959a7ca58bf73729
--- /dev/null
+++ b/Trung_output.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:068c6cfd893846473d177b9636b4689119a0c0ee7e19f4079cd6c98e27bb94a3
+size 745196
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..15e13bd1c750a8a9dbfcbc0a7255a333f8cf0390
--- /dev/null
+++ b/app.py
@@ -0,0 +1,346 @@
+import spaces
+import os
+from download_models import download_all_models
+from huggingface_hub import login
+
+import gradio as gr
+
+# ======================= HF LOGIN & DOWNLOAD MODEL =======================
+hf_token = os.getenv("HF_TOKEN")
+if hf_token:
+ login(token=hf_token)
+
+# Tải model khi Space ACTIVE
+download_all_models()
+
+from infer import run_zipvoice
+
+# NEW: ASR + DENOISE
+from chunkformer import ChunkFormerModel
+from clearvoice import ClearVoice
+from proccess_wav import enhance_ref_audio, transcribe_ref_audio
+
+# (Nếu 2 dòng test này không cần thì bạn có thể xoá bớt cho nhẹ)
+enhanced = enhance_ref_audio("Toại.wav")
+text = transcribe_ref_audio(enhanced)
+
+
+def infer_ref_text_ui(ref_audio_path: str) -> str:
+ """
+ Dùng cho nút 'Infer Text':
+ - Enhance WAV (ClearVoice + xử lý khoảng lặng + cắt 5–10s)
+ - ASR theo khoảng lặng
+ - Đổ kết quả vào ô Reference Text
+ """
+ if not ref_audio_path:
+ raise gr.Error("Vui lòng upload file giọng mẫu trước khi infer text.")
+
+ try:
+ enhanced = enhance_ref_audio(ref_audio_path)
+ text = transcribe_ref_audio(enhanced)
+ except Exception as e:
+ raise gr.Error(f"Lỗi khi nhận dạng từ audio tham chiếu: {e}")
+
+ if not text:
+ raise gr.Error("Không nhận dạng được nội dung từ audio tham chiếu.")
+ return text
+
+
+# ======================= CẤU HÌNH DEMO SẴN =======================
+SAMPLE_CONFIGS = [
+ {
+ "name": "Sample 1 – Kể chuyện",
+ "ref_audio": "Toại.wav",
+ "ref_text": "Trong bóng tối, Toại nói cái gì đó mà Thoan không nghe thấy.",
+ "gen_text": "Đêm nay trời nhiều mây, ánh trăng bị che khuất, chỉ còn lại một dải sáng yếu ớt rơi xuống con đường đất trải dài giữa cánh đồng. Cậu bé tên Tín đang dắt chiếc xe đạp cũ đi về nhà, bánh xe bị cán đinh nên lăn nặng và chậm như con trâu mệt nhọc sau vụ mùa. Gió thổi lạnh buốt, mùi bùn đất ngai ngái quấn lấy chân cậu. Tới đoạn rẽ dẫn vào xóm, Tín nghe tiếng nước chảy khe khẽ từ con mương bên đường. Tiếng ấy vẫn quen thuộc, nhưng tối nay lại vang khác lạ, như có giọng người đang hòa vào nhịp nước, lúc trầm lúc cao, nghe mơ hồ mà lạnh sống lưng. Cậu dừng lại, nghiêng tai lắng nghe, tim đập nhanh như muốn vượt khỏi lồng ngực.",
+ "out_audio": "Toại_output.wav",
+ },
+ {
+ "name": "Sample 2 – Nữ",
+ "ref_audio": "Trung.wav",
+ "ref_text": "Mùa hè không chỉ là khoảng thời gian nghỉ ngơi, mà còn là khoảng thời gian tuyệt vời.",
+ "gen_text": "Từ các kết quả này, chúng tôi đề xuất rằng sự kết hợp nhuần nhuyễn giữa adaptive optimization, robust training pipelines và interpretable model design sẽ là chìa khóa để phát triển các hệ thống ây ai vừa mạnh mẽ vừa đáng tin cậy trong môi trường thực tế.",
+ "out_audio": "Trung_output.wav",
+ },
+ {
+ "name": "Sample 3 – English",
+ "ref_audio": "T_English.wav",
+ "ref_text": "And turning to the pole which he had dragged, He drew it close beneath the widowed bough, And what was of it unto it left bound.",
+ "gen_text": "Recent experiments indicate that the current model architecture still exhibits significant overfitting, especially when evaluated on out of distribution samples. Although the training accuracy remains consistently high, the performance drops sharply when the model is exposed to noise perturbed inputs, suggesting limited robustness.",
+ "out_audio": "T_English_output.wav",
+ },
+]
+
+# Hàm dùng khi bấm "Dùng sample này"
+def make_sample_loader(sample):
+ def _load_sample():
+ return (
+ sample["ref_audio"], # ref_audio -> input Audio
+ sample["ref_text"], # ref_text -> Textbox
+ sample["gen_text"], # gen_text -> Textbox
+ sample["out_audio"], # output_audio -> Audio
+ )
+ return _load_sample
+
+
+# ======================= STYLE TÙY CHỈNH (LÀM SÁNG HƠN) =======================
+custom_css = """
+#app-container {
+ max-width: 1000px;
+ margin: 0 auto;
+}
+.gradio-container {
+ background: radial-gradient(circle at top, #ffffff 0, #f9fafb 55%);
+ color: #111827;
+}
+
+/* Tiêu đề lớn */
+#title-block h1 {
+ font-size: 2.4rem !important;
+ font-weight: 800 !important;
+ background: linear-gradient(120deg, #f97316, #eab308, #22c55e);
+ -webkit-background-clip: text;
+ color: transparent;
+ text-align: center;
+}
+#title-block p {
+ text-align:center;
+ font-size: 0.95rem;
+ color: #6b7280;
+}
+
+/* Card sáng hơn */
+.sample-card {
+ border-radius: 16px;
+ padding: 16px;
+ background: rgba(255, 255, 255, 0.96);
+ border: 1px solid rgba(148, 163, 184, 0.6);
+ box-shadow: 0 18px 28px rgba(148, 163, 184, 0.35);
+}
+
+/* Nút bấm */
+button.primary {
+ border-radius: 999px !important;
+ font-weight: 600 !important;
+}
+
+/* Tabs */
+.svelte-1ipelgc, .tabitem {
+ font-weight: 600;
+}
+"""
+
+# ======================= XỬ LÝ TEXT (NẾU CẦN) =======================
+def post_process(text: str) -> str:
+ text = " " + text + " "
+ text = text.replace(" . . ", " . ")
+ text = " " + text + " "
+ text = text.replace(" .. ", " . ")
+ text = " " + text + " "
+ text = text.replace(" , , ", " , ")
+ text = " " + text + " "
+ text = text.replace(" ,, ", " , ")
+ text = " " + text + " "
+ text = text.replace('"', "")
+ return " ".join(text.split())
+
+
+@spaces.GPU
+def infer_tts(ref_audio_path, ref_text, gen_text, steps, request: gr.Request = None):
+ if not ref_audio_path:
+ raise gr.Error("Please upload a sample audio file.")
+
+ if not gen_text.strip():
+ raise gr.Error("Please enter the text content to generate voice.")
+
+ # Giới hạn độ dài nội dung (4000 từ)
+ if len(gen_text.split()) > 4000:
+ raise gr.Error("Please enter text content with less than 4000 words.")
+
+ # 1) Enhance ref audio: clearvoice + xử lý khoảng lặng + cắt 5–10s
+ try:
+ enhanced_ref_audio = enhance_ref_audio(ref_audio_path)
+ except Exception as e:
+ raise gr.Error(f"Lỗi khi xử lý audio tham chiếu: {e}")
+
+ # 2) Nếu không có ref_text thì chạy ASR theo khoảng lặng
+ if not ref_text or not ref_text.strip():
+ try:
+ inferred = transcribe_ref_audio(enhanced_ref_audio)
+ if not inferred:
+ raise gr.Error(
+ "Không nhận dạng được nội dung từ audio tham chiếu. "
+ "Vui lòng nhập Reference Text thủ công."
+ )
+ ref_text = inferred
+ print(f"[ASR] Inferred ref_text: {ref_text}")
+ except gr.Error:
+ raise
+ except Exception as e:
+ raise gr.Error(f"Lỗi khi tự động nhận dạng Reference Text: {e}")
+
+ try:
+ out_path = "result.wav"
+
+ run_zipvoice(
+ model_name="zipvoice",
+ prompt_wav=enhanced_ref_audio, # dùng file đã xử lý
+ prompt_text=ref_text.strip() if ref_text else "xin chào các bạn",
+ text=gen_text,
+ res_wav_path=out_path,
+ lang="vi",
+ tokenizer_name="espeak",
+ num_step=steps,
+ seed=123456,
+ speed=1.0,
+ )
+
+ return out_path
+
+ except Exception as e:
+ raise gr.Error(f"Error generating voice: {e}")
+
+
+# ======================= UI =======================
+with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
+ with gr.Column(elem_id="app-container"):
+ # --------- TIÊU ĐỀ ----------
+ gr.Markdown(
+ """
+
+
🎤 ZipVoice – Zero-shot Vietnamese TTS
+
Upload một mẫu giọng + nhập nội dung → hệ thống sẽ bắt chước giọng nói và đọc đoạn text của bạn.
+
+ """,
+ elem_id="title-block",
+ )
+
+ with gr.Tabs():
+ # Chỉ còn 1 tab chính, demo cũng nằm trong tab này
+ with gr.TabItem("🎯 Tự tạo giọng nói"):
+ # --------- KHỐI INPUT / OUTPUT CHÍNH ----------
+ with gr.Row():
+ with gr.Column(elem_classes=["sample-card"]):
+ gr.Markdown("#### 1️⃣ Tải giọng mẫu & nhập text")
+
+ ref_audio = gr.Audio(
+ label="🔊 Sample Voice (upload hoặc kéo thả)",
+ type="filepath",
+ )
+
+ ref_text = gr.Textbox(
+ label="📝 Reference Text (optional)",
+ placeholder="Nội dung đang được nói trong file giọng mẫu (nên tự viết cho chính xác)",
+ lines=3,
+ )
+
+ # Nút infer text từ audio tham chiếu (ASR + khử nhiễu)
+ btn_infer_text = gr.Button(
+ "✨ Infer Text từ audio tham chiếu"
+ )
+
+ gen_text = gr.Textbox(
+ label="📝 Text to Generate",
+ placeholder="Nhập nội dung tiếng Việt bạn muốn tổng hợp...",
+ lines=6,
+ )
+
+ steps = gr.Slider(
+ 8,
+ 64,
+ value=25,
+ step=1,
+ label="⚡ Step (càng lớn, càng tốt, càng lâu)",
+ )
+
+ btn_synthesize = gr.Button(
+ "🔥 Generate Voice",
+ variant="primary",
+ )
+
+ with gr.Column(elem_classes=["sample-card"]):
+ gr.Markdown("#### 2️⃣ Kết quả tổng hợp")
+ output_audio = gr.Audio(
+ label="🎧 Generated Audio",
+ type="filepath",
+ )
+ gr.Markdown(
+ """
+- Bạn có thể tải file `.wav` về sau khi tạo.
+- Nếu nghe chưa ổn, hãy thử:
+ - Dùng **ref audio ngắn 3-8s, phát âm chuẩn hơn.
+ """
+ )
+
+ # mapping nút Generate -> infer_tts
+ btn_synthesize.click(
+ infer_tts,
+ inputs=[ref_audio, ref_text, gen_text, steps],
+ outputs=[output_audio],
+ )
+
+ # mapping nút Infer Text -> điền ref_text (có khử nhiễu trước)
+ btn_infer_text.click(
+ infer_ref_text_ui,
+ inputs=[ref_audio],
+ outputs=[ref_text],
+ )
+
+ # --------- KHỐI DEMO NẰM NGAY TRONG TAB CHÍNH ----------
+ gr.Markdown(
+ """
+### 🎧 Demo có sẵn
+Click vào một sample bên dưới để tự động nạp:
+- 🔊 Giọng mẫu (ref voice)
+- 📝 Reference text
+- 📝 Text to generate
+- 🎧 Output audio mẫu
+ """
+ )
+
+ for sample in SAMPLE_CONFIGS:
+ with gr.Column(elem_classes=["sample-card"]):
+ gr.Markdown(f"### {sample['name']}")
+ with gr.Row():
+ gr.Audio(
+ value=sample["ref_audio"],
+ label="🔊 Reference Voice",
+ interactive=False,
+ )
+ gr.Textbox(
+ value=sample["ref_text"],
+ label="📝 Reference Text",
+ interactive=False,
+ lines=3,
+ )
+
+ gr.Audio(
+ value=sample["out_audio"],
+ label="🎧 Generated Sample (TTS)",
+ interactive=False,
+ )
+
+ if sample.get("gen_text"):
+ gr.Markdown(
+ f"**Text dùng để synth:** {sample['gen_text']}"
+ )
+
+ # Nút này sẽ fill luôn ref_audio, ref_text, gen_text, output_audio
+ use_btn = gr.Button(f"➡️ Dùng {sample['name']}")
+
+ use_btn.click(
+ make_sample_loader(sample),
+ inputs=[],
+ outputs=[ref_audio, ref_text, gen_text, output_audio],
+ )
+
+ gr.Markdown(
+ """
+### ⚠️ Model Limitations
+1. Có thể xử lý chưa tốt với số, ngày tháng, ký tự đặc biệt.
+2. Nhịp điệu đôi khi chưa tự nhiên.
+3. Chất lượng phụ thuộc khá nhiều vào chất lượng ref audio.
+"""
+ )
+
+demo.queue().launch()
diff --git a/assets/silence.wav b/assets/silence.wav
new file mode 100644
index 0000000000000000000000000000000000000000..21baf5596af5b0270a37394b8a7cd652f69f9385
--- /dev/null
+++ b/assets/silence.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ca5a251f2d1439929f1c6b44d98299e53d402da45306af79cbfab5005501fed9
+size 4800044
diff --git a/download_models.py b/download_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b03faace7dd6891b4ed7a206643601234a51700
--- /dev/null
+++ b/download_models.py
@@ -0,0 +1,38 @@
+import os
+import requests
+
+MODEL_DIR = "zipvoice_finetune"
+os.makedirs(MODEL_DIR, exist_ok=True)
+
+files = {
+ "iter-525000-avg-2.pt": "https://huggingface.co/datasets/kjanh/demo_zip/resolve/main/epoch-46-all-speak-600h-en-norm.pt",
+ "model.json": "https://huggingface.co/datasets/kjanh/demo_zip/resolve/main/model.json",
+ "tokens.txt": "https://huggingface.co/datasets/kjanh/demo_zip/resolve/main/tokens.txt",
+}
+
+HF_TOKEN = os.getenv("HF_TOKEN")
+
+def download_with_token(url, dest_path):
+ if os.path.exists(dest_path):
+ print(f"✔ File tồn tại: {dest_path}")
+ return
+
+ if HF_TOKEN is None:
+ raise RuntimeError("❌ Missing HF_TOKEN in Secrets!")
+
+ print(f"⬇ Downloading {dest_path} ...")
+
+ headers = {"Authorization": f"Bearer {HF_TOKEN}"}
+ r = requests.get(url, headers=headers, stream=True)
+ r.raise_for_status()
+
+ with open(dest_path, "wb") as f:
+ for chunk in r.iter_content(1024 * 1024):
+ f.write(chunk)
+
+ print(f"✅ Downloaded {dest_path}")
+# demo
+def download_all_models():
+ for filename, url in files.items():
+ dest = os.path.join(MODEL_DIR, filename)
+ download_with_token(url, dest)
diff --git a/egs/zipvoice/README.md b/egs/zipvoice/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0aaed14807c929e3d140b54fb4298090701ea358
--- /dev/null
+++ b/egs/zipvoice/README.md
@@ -0,0 +1,15 @@
+# ZipVoice Recipe
+
+This recipe contains the following examples:
+
+- Training ZipVoice on Emilia from scratch, see [run_emilia.sh](run_emilia.sh)
+- Training ZipVoice on LibriTTS from scratch, see [run_libritts.sh](run_libritts.sh).
+- Training ZipVoice on custom datasets (any language) from scratch, see [run_custom.sh](run_custom.sh).
+- Fine-tuning pre-trained ZipVoice on custom datasets (any language), see [run_finetune.sh](run_finetune.sh).
+- Evaluate TTS models with objective metrics reported in ZipVoice paper, see [run_eval.sh](run_eval.sh).
+
+> **NOTE:** [run_emilia.sh](run_emilia.sh) is the most complete example, which covers: data preparation, ZipVoice trainnig, ZipVoice-Distill training, onnx export, and inference with all PyTorch and ONNX models.
+
+> **NOTE:** For evaluation, first install packages from [../../requirements_eval.txt](../../requirements_eval.txt)
+>
+> `pip install -r ../../requirements_eval.txt`
diff --git a/egs/zipvoice/conf/zipvoice_base.json b/egs/zipvoice/conf/zipvoice_base.json
new file mode 100644
index 0000000000000000000000000000000000000000..013bd157061a155827570f829f850e8ca6832a00
--- /dev/null
+++ b/egs/zipvoice/conf/zipvoice_base.json
@@ -0,0 +1,26 @@
+{
+ "model" : {
+ "fm_decoder_downsampling_factor" : [1,2,4,2,1],
+ "fm_decoder_num_layers" : [2,2,4,4,4],
+ "fm_decoder_cnn_module_kernel" : [31,15,7,15,31],
+ "fm_decoder_feedforward_dim" : 1536,
+ "fm_decoder_num_heads" : 4,
+ "fm_decoder_dim" : 512,
+ "text_encoder_num_layers" : 4,
+ "text_encoder_feedforward_dim" : 512,
+ "text_encoder_cnn_module_kernel" : 9,
+ "text_encoder_num_heads" : 4,
+ "text_encoder_dim" : 192,
+ "query_head_dim" : 32,
+ "value_head_dim" : 12,
+ "pos_head_dim" : 4,
+ "pos_dim" : 48,
+ "time_embed_dim" : 192,
+ "text_embed_dim" : 192,
+ "feat_dim": 100
+ },
+ "feature" : {
+ "sampling_rate": 24000,
+ "type": "vocos"
+ }
+}
\ No newline at end of file
diff --git a/egs/zipvoice/local/pinyin.txt b/egs/zipvoice/local/pinyin.txt
new file mode 100644
index 0000000000000000000000000000000000000000..cd8d14dc38a1ee8f0cac16560c404f1c0795192e
--- /dev/null
+++ b/egs/zipvoice/local/pinyin.txt
@@ -0,0 +1,1550 @@
+a
+a1
+a2
+a3
+a4
+ai1
+ai2
+ai3
+ai4
+an1
+an2
+an3
+an4
+ang1
+ang2
+ang3
+ang4
+ao1
+ao2
+ao3
+ao4
+ba
+ba1
+ba2
+ba3
+ba4
+bai
+bai1
+bai2
+bai3
+bai4
+ban
+ban1
+ban3
+ban4
+bang1
+bang3
+bang4
+bao1
+bao2
+bao3
+bao4
+bei
+bei1
+bei3
+bei4
+ben1
+ben3
+ben4
+beng
+beng1
+beng2
+beng3
+beng4
+bi1
+bi2
+bi3
+bi4
+bian
+bian1
+bian3
+bian4
+biang2
+biao1
+biao3
+biao4
+bie1
+bie2
+bie3
+bie4
+bin
+bin1
+bin3
+bin4
+bing1
+bing3
+bing4
+bo
+bo1
+bo2
+bo3
+bo4
+bu1
+bu2
+bu3
+bu4
+ca1
+ca3
+ca4
+cai1
+cai2
+cai3
+cai4
+can1
+can2
+can3
+can4
+cang1
+cang2
+cang3
+cang4
+cao1
+cao2
+cao3
+cao4
+ce4
+cei4
+cen1
+cen2
+ceng1
+ceng2
+ceng4
+cha1
+cha2
+cha3
+cha4
+chai1
+chai2
+chai3
+chai4
+chan1
+chan2
+chan3
+chan4
+chang
+chang1
+chang2
+chang3
+chang4
+chao1
+chao2
+chao3
+chao4
+che1
+che2
+che3
+che4
+chen
+chen1
+chen2
+chen3
+chen4
+cheng1
+cheng2
+cheng3
+cheng4
+chi
+chi1
+chi2
+chi3
+chi4
+chong1
+chong2
+chong3
+chong4
+chou1
+chou2
+chou3
+chou4
+chu
+chu1
+chu2
+chu3
+chu4
+chua1
+chua3
+chua4
+chuai1
+chuai2
+chuai3
+chuai4
+chuan1
+chuan2
+chuan3
+chuan4
+chuang1
+chuang2
+chuang3
+chuang4
+chui1
+chui2
+chui3
+chui4
+chun1
+chun2
+chun3
+chuo1
+chuo4
+ci1
+ci2
+ci3
+ci4
+cong1
+cong2
+cong3
+cong4
+cou1
+cou2
+cou3
+cou4
+cu1
+cu2
+cu3
+cu4
+cuan1
+cuan2
+cuan4
+cui
+cui1
+cui3
+cui4
+cun1
+cun2
+cun3
+cun4
+cuo1
+cuo2
+cuo3
+cuo4
+da
+da1
+da2
+da3
+da4
+dai
+dai1
+dai3
+dai4
+dan1
+dan3
+dan4
+dang
+dang1
+dang3
+dang4
+dao1
+dao2
+dao3
+dao4
+de
+de1
+de2
+dei1
+dei3
+den4
+deng1
+deng3
+deng4
+di1
+di2
+di3
+di4
+dia3
+dian1
+dian2
+dian3
+dian4
+diao1
+diao3
+diao4
+die1
+die2
+die3
+die4
+din4
+ding1
+ding3
+ding4
+diu1
+dong1
+dong3
+dong4
+dou1
+dou3
+dou4
+du1
+du2
+du3
+du4
+duan1
+duan3
+duan4
+dui1
+dui3
+dui4
+dun1
+dun3
+dun4
+duo
+duo1
+duo2
+duo3
+duo4
+e
+e1
+e2
+e3
+e4
+ei1
+ei2
+ei3
+ei4
+en1
+en3
+en4
+eng1
+er
+er2
+er3
+er4
+fa
+fa1
+fa2
+fa3
+fa4
+fan1
+fan2
+fan3
+fan4
+fang
+fang1
+fang2
+fang3
+fang4
+fei1
+fei2
+fei3
+fei4
+fen1
+fen2
+fen3
+fen4
+feng1
+feng2
+feng3
+feng4
+fiao4
+fo2
+fou1
+fou2
+fou3
+fu
+fu1
+fu2
+fu3
+fu4
+ga1
+ga2
+ga3
+ga4
+gai1
+gai3
+gai4
+gan1
+gan3
+gan4
+gang1
+gang3
+gang4
+gao1
+gao3
+gao4
+ge1
+ge2
+ge3
+ge4
+gei3
+gen1
+gen2
+gen3
+gen4
+geng1
+geng3
+geng4
+gong
+gong1
+gong3
+gong4
+gou1
+gou3
+gou4
+gu
+gu1
+gu2
+gu3
+gu4
+gua1
+gua2
+gua3
+gua4
+guai1
+guai3
+guai4
+guan1
+guan3
+guan4
+guang
+guang1
+guang3
+guang4
+gui1
+gui3
+gui4
+gun3
+gun4
+guo
+guo1
+guo2
+guo3
+guo4
+ha1
+ha2
+ha3
+ha4
+hai
+hai1
+hai2
+hai3
+hai4
+han
+han1
+han2
+han3
+han4
+hang1
+hang2
+hang3
+hang4
+hao1
+hao2
+hao3
+hao4
+he1
+he2
+he3
+he4
+hei1
+hen1
+hen2
+hen3
+hen4
+heng1
+heng2
+heng4
+hm
+hng
+hong1
+hong2
+hong3
+hong4
+hou1
+hou2
+hou3
+hou4
+hu
+hu1
+hu2
+hu3
+hu4
+hua1
+hua2
+hua4
+huai
+huai2
+huai4
+huan1
+huan2
+huan3
+huan4
+huang
+huang1
+huang2
+huang3
+huang4
+hui
+hui1
+hui2
+hui3
+hui4
+hun1
+hun2
+hun3
+hun4
+huo
+huo1
+huo2
+huo3
+huo4
+ji1
+ji2
+ji3
+ji4
+jia
+jia1
+jia2
+jia3
+jia4
+jian
+jian1
+jian3
+jian4
+jiang
+jiang1
+jiang3
+jiang4
+jiao
+jiao1
+jiao2
+jiao3
+jiao4
+jie
+jie1
+jie2
+jie3
+jie4
+jin1
+jin3
+jin4
+jing
+jing1
+jing3
+jing4
+jiong1
+jiong3
+jiong4
+jiu
+jiu1
+jiu2
+jiu3
+jiu4
+ju
+ju1
+ju2
+ju3
+ju4
+juan1
+juan3
+juan4
+jue1
+jue2
+jue3
+jue4
+jun1
+jun3
+jun4
+ka1
+ka3
+kai1
+kai3
+kai4
+kan1
+kan3
+kan4
+kang1
+kang2
+kang3
+kang4
+kao1
+kao3
+kao4
+ke
+ke1
+ke2
+ke3
+ke4
+kei1
+ken1
+ken3
+ken4
+keng1
+keng3
+kong1
+kong3
+kong4
+kou1
+kou3
+kou4
+ku1
+ku2
+ku3
+ku4
+kua1
+kua3
+kua4
+kuai3
+kuai4
+kuan1
+kuan3
+kuang1
+kuang2
+kuang3
+kuang4
+kui1
+kui2
+kui3
+kui4
+kun
+kun1
+kun3
+kun4
+kuo4
+la
+la1
+la2
+la3
+la4
+lai2
+lai3
+lai4
+lan2
+lan3
+lan4
+lang
+lang1
+lang2
+lang3
+lang4
+lao
+lao1
+lao2
+lao3
+lao4
+le
+le1
+le4
+lei
+lei1
+lei2
+lei3
+lei4
+len4
+leng1
+leng2
+leng3
+leng4
+li
+li1
+li2
+li3
+li4
+lia3
+lian2
+lian3
+lian4
+liang
+liang2
+liang3
+liang4
+liao1
+liao2
+liao3
+liao4
+lie
+lie1
+lie2
+lie3
+lie4
+lin1
+lin2
+lin3
+lin4
+ling
+ling1
+ling2
+ling3
+ling4
+liu1
+liu2
+liu3
+liu4
+lo
+long1
+long2
+long3
+long4
+lou
+lou1
+lou2
+lou3
+lou4
+lu
+lu1
+lu2
+lu3
+lu4
+luan2
+luan3
+luan4
+lun1
+lun2
+lun3
+lun4
+luo
+luo1
+luo2
+luo3
+luo4
+lv2
+lv3
+lv4
+lve3
+lve4
+m1
+m2
+m4
+ma
+ma1
+ma2
+ma3
+ma4
+mai2
+mai3
+mai4
+man1
+man2
+man3
+man4
+mang1
+mang2
+mang3
+mang4
+mao1
+mao2
+mao3
+mao4
+me
+me1
+mei2
+mei3
+mei4
+men
+men1
+men2
+men4
+meng
+meng1
+meng2
+meng3
+meng4
+mi1
+mi2
+mi3
+mi4
+mian2
+mian3
+mian4
+miao1
+miao2
+miao3
+miao4
+mie
+mie1
+mie2
+mie4
+min
+min2
+min3
+ming
+ming2
+ming3
+ming4
+miu3
+miu4
+mo
+mo1
+mo2
+mo3
+mo4
+mou1
+mou2
+mou3
+mou4
+mu2
+mu3
+mu4
+n
+n2
+n3
+n4
+na
+na1
+na2
+na3
+na4
+nai2
+nai3
+nai4
+nan1
+nan2
+nan3
+nan4
+nang
+nang1
+nang2
+nang3
+nang4
+nao1
+nao2
+nao3
+nao4
+ne
+ne2
+ne4
+nei2
+nei3
+nei4
+nen4
+neng2
+neng3
+neng4
+ng
+ng2
+ng3
+ng4
+ni1
+ni2
+ni3
+ni4
+nia1
+nian1
+nian2
+nian3
+nian4
+niang2
+niang3
+niang4
+niao3
+niao4
+nie1
+nie2
+nie3
+nie4
+nin
+nin2
+nin3
+ning2
+ning3
+ning4
+niu1
+niu2
+niu3
+niu4
+nong2
+nong3
+nong4
+nou2
+nou3
+nou4
+nu2
+nu3
+nu4
+nuan2
+nuan3
+nuan4
+nun2
+nun4
+nuo2
+nuo3
+nuo4
+nv2
+nv3
+nv4
+nve4
+o
+o1
+o2
+o3
+o4
+ou
+ou1
+ou2
+ou3
+ou4
+pa1
+pa2
+pa3
+pa4
+pai1
+pai2
+pai3
+pai4
+pan1
+pan2
+pan3
+pan4
+pang1
+pang2
+pang3
+pang4
+pao1
+pao2
+pao3
+pao4
+pei1
+pei2
+pei3
+pei4
+pen1
+pen2
+pen3
+pen4
+peng1
+peng2
+peng3
+peng4
+pi1
+pi2
+pi3
+pi4
+pian1
+pian2
+pian3
+pian4
+piao1
+piao2
+piao3
+piao4
+pie1
+pie3
+pie4
+pin1
+pin2
+pin3
+pin4
+ping1
+ping2
+ping3
+ping4
+po
+po1
+po2
+po3
+po4
+pou1
+pou2
+pou3
+pou4
+pu
+pu1
+pu2
+pu3
+pu4
+qi
+qi1
+qi2
+qi3
+qi4
+qia1
+qia2
+qia3
+qia4
+qian
+qian1
+qian2
+qian3
+qian4
+qiang1
+qiang2
+qiang3
+qiang4
+qiao1
+qiao2
+qiao3
+qiao4
+qie1
+qie2
+qie3
+qie4
+qin1
+qin2
+qin3
+qin4
+qing
+qing1
+qing2
+qing3
+qing4
+qiong1
+qiong2
+qiong4
+qiu1
+qiu2
+qiu3
+qiu4
+qu
+qu1
+qu2
+qu3
+qu4
+quan
+quan1
+quan2
+quan3
+quan4
+que1
+que2
+que4
+qun1
+qun2
+qun3
+ran2
+ran3
+ran4
+rang1
+rang2
+rang3
+rang4
+rao2
+rao3
+rao4
+re2
+re3
+re4
+ren2
+ren3
+ren4
+reng1
+reng2
+reng4
+ri4
+rong
+rong1
+rong2
+rong3
+rong4
+rou2
+rou3
+rou4
+ru
+ru2
+ru3
+ru4
+rua2
+ruan2
+ruan3
+ruan4
+rui2
+rui3
+rui4
+run2
+run3
+run4
+ruo2
+ruo4
+sa
+sa1
+sa3
+sa4
+sai1
+sai3
+sai4
+san
+san1
+san3
+san4
+sang1
+sang3
+sang4
+sao1
+sao3
+sao4
+se1
+se4
+sen1
+sen3
+seng1
+seng4
+sha
+sha1
+sha2
+sha3
+sha4
+shai1
+shai3
+shai4
+shan1
+shan2
+shan3
+shan4
+shang
+shang1
+shang3
+shang4
+shao1
+shao2
+shao3
+shao4
+she1
+she2
+she3
+she4
+shei2
+shen1
+shen2
+shen3
+shen4
+sheng1
+sheng2
+sheng3
+sheng4
+shi
+shi1
+shi2
+shi3
+shi4
+shou
+shou1
+shou2
+shou3
+shou4
+shu1
+shu2
+shu3
+shu4
+shua1
+shua3
+shua4
+shuai1
+shuai3
+shuai4
+shuan1
+shuan4
+shuang1
+shuang3
+shuang4
+shui
+shui2
+shui3
+shui4
+shun3
+shun4
+shuo1
+shuo2
+shuo4
+si
+si1
+si2
+si3
+si4
+song1
+song2
+song3
+song4
+sou1
+sou3
+sou4
+su1
+su2
+su3
+su4
+suan1
+suan3
+suan4
+sui1
+sui2
+sui3
+sui4
+sun1
+sun3
+sun4
+suo
+suo1
+suo2
+suo3
+suo4
+ta
+ta1
+ta2
+ta3
+ta4
+tai
+tai1
+tai2
+tai3
+tai4
+tan1
+tan2
+tan3
+tan4
+tang1
+tang2
+tang3
+tang4
+tao1
+tao2
+tao3
+tao4
+te
+te4
+tei1
+teng1
+teng2
+teng4
+ti
+ti1
+ti2
+ti3
+ti4
+tian1
+tian2
+tian3
+tian4
+tiao
+tiao1
+tiao2
+tiao3
+tiao4
+tie1
+tie2
+tie3
+tie4
+ting1
+ting2
+ting3
+ting4
+tong1
+tong2
+tong3
+tong4
+tou
+tou1
+tou2
+tou3
+tou4
+tu
+tu1
+tu2
+tu3
+tu4
+tuan1
+tuan2
+tuan3
+tuan4
+tui1
+tui2
+tui3
+tui4
+tun1
+tun2
+tun3
+tun4
+tuo1
+tuo2
+tuo3
+tuo4
+wa
+wa1
+wa2
+wa3
+wa4
+wai
+wai1
+wai3
+wai4
+wan1
+wan2
+wan3
+wan4
+wang1
+wang2
+wang3
+wang4
+wei
+wei1
+wei2
+wei3
+wei4
+wen
+wen1
+wen2
+wen3
+wen4
+weng1
+weng3
+weng4
+wo1
+wo3
+wo4
+wong4
+wu
+wu1
+wu2
+wu3
+wu4
+xi1
+xi2
+xi3
+xi4
+xia1
+xia2
+xia3
+xia4
+xian
+xian1
+xian2
+xian3
+xian4
+xiang1
+xiang2
+xiang3
+xiang4
+xiao
+xiao1
+xiao2
+xiao3
+xiao4
+xie1
+xie2
+xie3
+xie4
+xin
+xin1
+xin2
+xin3
+xin4
+xing
+xing1
+xing2
+xing3
+xing4
+xiong1
+xiong2
+xiong3
+xiong4
+xiu1
+xiu2
+xiu3
+xiu4
+xu
+xu1
+xu2
+xu3
+xu4
+xuan1
+xuan2
+xuan3
+xuan4
+xue1
+xue2
+xue3
+xue4
+xun1
+xun2
+xun4
+ya
+ya1
+ya2
+ya3
+ya4
+yan1
+yan2
+yan3
+yan4
+yang
+yang1
+yang2
+yang3
+yang4
+yao1
+yao2
+yao3
+yao4
+ye
+ye1
+ye2
+ye3
+ye4
+yi
+yi1
+yi2
+yi3
+yi4
+yin
+yin1
+yin2
+yin3
+yin4
+ying1
+ying2
+ying3
+ying4
+yo
+yo1
+yong1
+yong2
+yong3
+yong4
+you
+you1
+you2
+you3
+you4
+yu
+yu1
+yu2
+yu3
+yu4
+yuan1
+yuan2
+yuan3
+yuan4
+yue1
+yue2
+yue3
+yue4
+yun
+yun1
+yun2
+yun3
+yun4
+za1
+za2
+za3
+za4
+zai1
+zai3
+zai4
+zan
+zan1
+zan2
+zan3
+zan4
+zang1
+zang3
+zang4
+zao1
+zao2
+zao3
+zao4
+ze
+ze2
+ze4
+zei2
+zen
+zen1
+zen3
+zen4
+zeng1
+zeng3
+zeng4
+zha
+zha1
+zha2
+zha3
+zha4
+zhai1
+zhai2
+zhai3
+zhai4
+zhan1
+zhan2
+zhan3
+zhan4
+zhang
+zhang1
+zhang3
+zhang4
+zhao
+zhao1
+zhao2
+zhao3
+zhao4
+zhe
+zhe1
+zhe2
+zhe3
+zhe4
+zhei4
+zhen1
+zhen2
+zhen3
+zhen4
+zheng1
+zheng3
+zheng4
+zhi
+zhi1
+zhi2
+zhi3
+zhi4
+zhong1
+zhong3
+zhong4
+zhou1
+zhou2
+zhou3
+zhou4
+zhu1
+zhu2
+zhu3
+zhu4
+zhua1
+zhua3
+zhuai1
+zhuai3
+zhuai4
+zhuan1
+zhuan2
+zhuan3
+zhuan4
+zhuang1
+zhuang3
+zhuang4
+zhui1
+zhui3
+zhui4
+zhun1
+zhun3
+zhun4
+zhuo
+zhuo1
+zhuo2
+zhuo4
+zi
+zi1
+zi2
+zi3
+zi4
+zong
+zong1
+zong3
+zong4
+zou1
+zou3
+zou4
+zu1
+zu2
+zu3
+zu4
+zuan1
+zuan3
+zuan4
+zui
+zui1
+zui2
+zui3
+zui4
+zun1
+zun2
+zun3
+zun4
+zuo
+zuo1
+zuo2
+zuo3
+zuo4
+ê1
+ê2
+ê3
+ê4
diff --git a/egs/zipvoice/local/prepare_emilia.sh b/egs/zipvoice/local/prepare_emilia.sh
new file mode 100644
index 0000000000000000000000000000000000000000..df64e5c0cfd826a651c6e904609c833c0e419ca2
--- /dev/null
+++ b/egs/zipvoice/local/prepare_emilia.sh
@@ -0,0 +1,149 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+export PYTHONPATH=../../:$PYTHONPATH
+
+set -eou pipefail
+
+stage=0
+stop_stage=5
+sampling_rate=24000
+nj=32
+
+dl_dir=$PWD/download
+
+. scripts/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Download data"
+
+ # Your download directory should look like this:
+ #
+ # download/Amphion___Emilia
+ # ├── metafile.yaml
+ # ├── raw
+ # │ ├── DE
+ # │ ├── EN
+ # │ ├── FR
+ # │ ├── JA
+ # │ ├── KO
+ # │ ├── openemilia_45batches.tar.gz
+ # │ ├── openemilia_all.tar.gz
+ # │ └── ZH
+ # └── README.md
+
+ if [ ! -d $dl_dir/Amphion___Emilia/raw ]; then
+ log "Please refer https://openxlab.org.cn/datasets/Amphion/Emilia to download the dataset."
+ exit(-1)
+ fi
+
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare emilia manifests (EN and ZH only)"
+ # We assume that you have downloaded the Emilia corpus
+ # to $dl_dir/Amphion___Emilia
+ # see stage 0 for the directory structure
+ mkdir -p data/manifests
+ if [ ! -e data/manifests/.emilia.done ]; then
+ lhotse prepare emilia --lang en --num-jobs ${nj} $dl_dir/Amphion___Emilia data/manifests
+ lhotse prepare emilia --lang zh --num-jobs ${nj} $dl_dir/Amphion___Emilia data/manifests
+ touch data/manifests/.emilia.done
+ fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Preprocess Emilia dataset, mainly for cleaning"
+ mkdir -p data/manifests/splits_raw
+ if [ ! -e data/manifests/split_raw/.emilia.split.done ]; then
+ lhotse split-lazy data/manifests/emilia_cuts_EN.jsonl.gz data/manifests/splits_raw 10000
+ lhotse split-lazy data/manifests/emilia_cuts_ZH.jsonl.gz data/manifests/splits_raw 10000
+ touch data/manifests/splits_raw/.emilia.split.done
+ fi
+
+ mkdir -p data/manifests/splits
+
+ if [ ! -e data/manifests/splits/.emilia.preprocess.done ]; then
+ python local/preprocess_emilia.py --subset EN
+ python local/preprocess_emilia.py --subset ZH
+ touch data/manifests/splits/.emilia.preprocess.done
+ fi
+
+fi
+
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Add tokens to manifests"
+
+ mkdir -p data/manifests/tokenized_splits
+
+ if [ ! -e data/manifests/tokenized_splits/.emilia.preprocess.done ]; then
+ for subset in EN ZH; do
+ log "Tokenizing Emilia ${subset}"
+ python local/prepare_emilia.py \
+ --subset ${subset} \
+ --jobs ${nj} \
+ --source-dir data/manifests/splits/ \
+ --output-dir data/manifests/tokenized_splits/
+ done
+ touch data/manifests/tokenized_splits/.emilia.preprocess.done
+ fi
+
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Extract Fbank for Emilia"
+ mkdir -p data/fbank/emilia_splits
+ if [ ! -e data/fbank/emilia_splits/.emilia.fbank.done ]; then
+ # You can speed up the extraction by distributing splits to multiple machines.
+ for subset in EN ZH; do
+ python3 -m zipvoice.bin.compute_fbank \
+ --source-dir data/manifests/tokenized_splits \
+ --dest-dir data/fbank/emilia_splits \
+ --dataset emilia \
+ --subset ${subset} \
+ --splits-cuts 1 \
+ --split-begin 0 \
+ --split-end 2000 \
+ --num-jobs ${nj}
+ done
+ touch data/fbank/emilia_splits/.emilia.fbank.done
+ fi
+
+ if [ ! -e data/fbank/emilia_cuts_EN.jsonl.gz ]; then
+ log "Combining EN fbank cuts and spliting EN dev set"
+ gunzip -c data/fbank/emilia_splits/emilia_cuts_EN.*.jsonl.gz > data/fbank/emilia_cuts_EN.jsonl
+ head -n 1500 data/fbank/emilia_cuts_EN.jsonl | gzip -c > data/fbank/emilia_cuts_EN_dev.jsonl.gz
+ sed -i '1,1500d' data/fbank/emilia_cuts_EN.jsonl
+ gzip data/fbank/emilia_cuts_EN.jsonl
+ fi
+
+ if [ ! -e data/fbank/emilia_cuts_ZH.jsonl.gz ]; then
+ log "Combining ZH fbank cuts and spliting ZH dev set"
+ gunzip -c data/fbank/emilia_splits/emilia_cuts_ZH.*.jsonl.gz > data/fbank/emilia_cuts_ZH.jsonl
+ head -n 1500 data/fbank/emilia_cuts_ZH.jsonl | gzip -c > data/fbank/emilia_cuts_ZH_dev.jsonl.gz
+ sed -i '1,1500d' data/fbank/emilia_cuts_ZH.jsonl
+ gzip data/fbank/emilia_cuts_ZH.jsonl
+ fi
+
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 5: Generate token file"
+ if [ ! -e data/tokens_emilia.txt ]; then
+ ./local/prepare_token_file_emilia.py --tokens data/tokens_emilia.txt
+ fi
+fi
diff --git a/egs/zipvoice/local/prepare_libritts.sh b/egs/zipvoice/local/prepare_libritts.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1b7fe1caf0b60ba64f8465fc8ac6a6ba5d170f18
--- /dev/null
+++ b/egs/zipvoice/local/prepare_libritts.sh
@@ -0,0 +1,100 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+export PYTHONPATH=../../:$PYTHONPATH
+
+set -eou pipefail
+
+stage=0
+stop_stage=5
+sampling_rate=24000
+nj=20
+
+dl_dir=$PWD/download
+
+. utils/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Download data"
+
+ # If you have pre-downloaded it to /path/to/LibriTTS,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/LibriTTS $dl_dir/LibriTTS
+ #
+ if [ ! -d $dl_dir/LibriTTS ]; then
+ lhotse download libritts $dl_dir
+ fi
+
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare LibriTTS manifest"
+ # We assume that you have downloaded the LibriTTS corpus
+ # to $dl_dir/LibriTTS
+
+ # We did not add tokens to this manifest, as on-the-fly
+ # tokenization with LibriTTSTokenizer is not slow.
+ mkdir -p data/manifests
+ if [ ! -e data/manifests/.libritts.done ]; then
+ lhotse prepare libritts --num-jobs ${nj} $dl_dir/LibriTTS data/manifests
+ touch data/manifests/.libritts.done
+ fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Compute Fbank for LibriTTS"
+ mkdir -p data/fbank
+
+ if [ ! -e data/fbank/.libritts.done ]; then
+ for subset in train-clean-100 train-clean-360 train-other-500 dev-clean test-clean; do
+ python3 -m zipvoice.bin.compute_fbank \
+ --source-dir data/manifests \
+ --dest-dir data/fbank \
+ --dataset libritts \
+ --subset ${subset} \
+ --sampling-rate $sampling_rate \
+ --num-jobs ${nj}
+ done
+ touch data/fbank/.libritts.done
+ fi
+
+ # Here we shuffle and combine the train-clean-100, train-clean-360 and
+ # train-other-500 together to form the training set.
+ if [ ! -f data/fbank/libritts_cuts_train-all-shuf.jsonl.gz ]; then
+ cat <(gunzip -c data/fbank/libritts_cuts_train-clean-100.jsonl.gz) \
+ <(gunzip -c data/fbank/libritts_cuts_train-clean-360.jsonl.gz) \
+ <(gunzip -c data/fbank/libritts_cuts_train-other-500.jsonl.gz) | \
+ shuf | gzip -c > data/fbank/libritts_cuts_train-all-shuf.jsonl.gz
+ fi
+
+
+ if [ ! -e data/fbank/.libritts-validated.done ]; then
+ log "Validating data/fbank for LibriTTS"
+ python3 ./utils/validate_manifest.py \
+ data/fbank/libritts_cuts_train-all-shuf.jsonl.gz
+ touch data/fbank/.libritts-validated.done
+ fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Generate token file"
+ if [ ! -e data/tokens_libritts.txt ]; then
+ python3 ./local/prepare_token_file_char.py \
+ --manifest data/fbank/libritts_cuts_train-all-shuf.jsonl.gz \
+ --tokens data/tokens_libritts.txt
+ fi
+fi
diff --git a/egs/zipvoice/local/prepare_token_file_char.py b/egs/zipvoice/local/prepare_token_file_char.py
new file mode 100644
index 0000000000000000000000000000000000000000..1967c2b87353bb19ec001c5ebdd5ccffa2d5775d
--- /dev/null
+++ b/egs/zipvoice/local/prepare_token_file_char.py
@@ -0,0 +1,67 @@
+#!/usr/bin/env python3
+# Copyright 2024-2025 Xiaomi Corp. (authors: Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import re
+from collections import Counter
+from pathlib import Path
+
+from lhotse import load_manifest_lazy
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--tokens",
+ type=Path,
+ help="Path to the dict that maps the text tokens to IDs",
+ )
+
+ parser.add_argument(
+ "--manifest",
+ type=Path,
+ help="Path to the manifest file",
+ )
+
+ return parser.parse_args()
+
+
+def prepare_tokens(manifest_file, token_file):
+ counter = Counter()
+ manifest = load_manifest_lazy(manifest_file)
+ for cut in manifest:
+ line = re.sub(r"\s+", " ", cut.supervisions[0].text)
+ counter.update(line)
+
+ unique_chars = set(counter.keys())
+
+ if "_" in unique_chars:
+ unique_chars.remove("_")
+
+ sorted_chars = sorted(unique_chars, key=lambda char: counter[char], reverse=True)
+
+ result = ["_"] + sorted_chars
+
+ with open(token_file, "w", encoding="utf-8") as file:
+ for index, char in enumerate(result):
+ file.write(f"{char}\t{index}\n")
+
+
+if __name__ == "__main__":
+ args = get_args()
+ prepare_tokens(args.manifest, args.tokens)
diff --git a/egs/zipvoice/local/prepare_token_file_emilia.py b/egs/zipvoice/local/prepare_token_file_emilia.py
new file mode 100644
index 0000000000000000000000000000000000000000..65aa302991a0b6a85037a30d928f7a5a871b4041
--- /dev/null
+++ b/egs/zipvoice/local/prepare_token_file_emilia.py
@@ -0,0 +1,91 @@
+#!/usr/bin/env python3
+# Copyright 2024 Xiaomi Corp. (authors: Zengwei Yao,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file generates the file that maps tokens to IDs.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import List
+
+from piper_phonemize import get_espeak_map
+from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--tokens",
+ type=Path,
+ default=Path("data/tokens_emilia.txt"),
+ help="Path to the dict that maps the text tokens to IDs",
+ )
+
+ parser.add_argument(
+ "--pinyin",
+ type=Path,
+ default=Path("resources/pinyin.txt"),
+ help="Path to the all unique pinyin",
+ )
+
+ return parser.parse_args()
+
+
+def get_pinyin_tokens(pinyin: Path) -> List[str]:
+ phones = set()
+ with open(pinyin, "r") as f:
+ for line in f:
+ x = line.strip()
+ initial = to_initials(x, strict=False)
+ # don't want to share tokens with espeak tokens, so use tone3 style
+ finals = to_finals_tone3(x, strict=False, neutral_tone_with_five=True)
+ if initial != "":
+ # don't want to share tokens with espeak tokens,
+ # so add a '0' after each initial
+ phones.add(initial + "0")
+ if finals != "":
+ phones.add(finals)
+ return sorted(phones)
+
+
+def get_token2id(args):
+ """Get a dict that maps token to IDs, and save it to the given filename."""
+ all_tokens = get_espeak_map() # token: [token_id]
+ all_tokens = {token: token_id[0] for token, token_id in all_tokens.items()}
+ # sort by token_id
+ all_tokens = sorted(all_tokens.items(), key=lambda x: x[1])
+
+ all_pinyin = get_pinyin_tokens(args.pinyin)
+ with open(args.tokens, "w", encoding="utf-8") as f:
+ for token, token_id in all_tokens:
+ f.write(f"{token}\t{token_id}\n")
+ num_espeak_tokens = len(all_tokens)
+ for i, pinyin in enumerate(all_pinyin):
+ f.write(f"{pinyin}\t{num_espeak_tokens + i}\n")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
+
+ args = get_args()
+ get_token2id(args)
diff --git a/egs/zipvoice/local/prepare_tokens_emilia.py b/egs/zipvoice/local/prepare_tokens_emilia.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5a22c8a2a61bc395f33b4f7edc8280e4be33fda
--- /dev/null
+++ b/egs/zipvoice/local/prepare_tokens_emilia.py
@@ -0,0 +1,88 @@
+"""
+This file reads the texts in given manifest and save the new cuts with phoneme tokens.
+"""
+
+import argparse
+import glob
+import logging
+from concurrent.futures import ProcessPoolExecutor as Pool
+from pathlib import Path
+
+from lhotse import load_manifest_lazy
+
+from zipvoice.tokenizer.tokenizer import add_tokens
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--subset",
+ type=str,
+ help="Subset of emilia, (ZH, EN, etc.)",
+ )
+
+ parser.add_argument(
+ "--jobs",
+ type=int,
+ default=50,
+ help="Number of jobs to processing.",
+ )
+
+ parser.add_argument(
+ "--source-dir",
+ type=str,
+ default="data/manifests/splits",
+ help="The source directory of manifest files.",
+ )
+
+ parser.add_argument(
+ "--dest-dir",
+ type=str,
+ help="The destination directory of manifest files.",
+ )
+
+ return parser.parse_args()
+
+
+def prepare_tokens_emilia(file_name: str, input_dir: Path, output_dir: Path):
+ logging.info(f"Processing {file_name}")
+ if (output_dir / file_name).is_file():
+ logging.info(f"{file_name} exists, skipping.")
+ return
+
+ try:
+ cut_set = load_manifest_lazy(input_dir / file_name)
+ cut_set = add_tokens(cut_set=cut_set, tokenizer="emilia")
+ cut_set.to_file(output_dir / file_name)
+ except Exception as e:
+ logging.error(f"Manifest {file_name} failed with error: {e}")
+ raise
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
+
+ args = get_args()
+
+ input_dir = Path(args.source_dir)
+ output_dir = Path(args.dest_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ cut_files = glob.glob(f"{args.source_dir}/emilia_cuts_{args.subset}.*.jsonl.gz")
+
+ with Pool(max_workers=args.jobs) as pool:
+ futures = [
+ pool.submit(
+ prepare_tokens_emilia, filename.split("/")[-1], input_dir, output_dir
+ )
+ for filename in cut_files
+ ]
+ for f in futures:
+ try:
+ f.result()
+ f.done()
+ except Exception as e:
+ logging.error(f"Future failed with error: {e}")
+ logging.info("Processing done.")
diff --git a/egs/zipvoice/local/preprocess_emilia.py b/egs/zipvoice/local/preprocess_emilia.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb732f1515735d63fe2e120dc0593eab08f20e1e
--- /dev/null
+++ b/egs/zipvoice/local/preprocess_emilia.py
@@ -0,0 +1,210 @@
+#!/usr/bin/env python3
+# Copyright 2024-2025 Xiaomi Corp. (authors: Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file reads the texts in given manifest and save the cleaned new cuts.
+"""
+
+import argparse
+import glob
+import logging
+import os
+import re
+import unicodedata
+from concurrent.futures import ProcessPoolExecutor as Pool
+from pathlib import Path
+
+from lhotse import load_manifest_lazy
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--subset",
+ type=str,
+ help="Subset of emilia, (ZH, EN, etc.)",
+ )
+
+ parser.add_argument(
+ "--jobs",
+ type=int,
+ default=20,
+ help="Number of jobs to processing.",
+ )
+
+ parser.add_argument(
+ "--source-dir",
+ type=str,
+ default="data/manifests/splits_raw",
+ help="The source directory of manifest files.",
+ )
+
+ parser.add_argument(
+ "--dest-dir",
+ type=str,
+ default="data/manifests/splits",
+ help="The destination directory of manifest files.",
+ )
+
+ return parser.parse_args()
+
+
+def tokenize_by_CJK_char(text: str) -> str:
+ """
+ Tokenize a line of text with CJK char.
+
+ Example:
+ input = "你好世界是 hello world 的中文"
+ output = ["你", "好", "世", "界", "是", "hello", "world", "的", "中", "文"]
+ """
+ pattern = re.compile(
+ r"([\u1100-\u11ff"
+ r"\u2e80-\ua4cf"
+ r"\ua840-\uD7AF"
+ r"\uF900-\uFAFF"
+ r"\uFE30-\uFE4F"
+ r"\uFF65-\uFFDC"
+ r"\U00020000-\U0002FFFF])"
+ )
+ chars = pattern.split(text.strip())
+ merged = " ".join([w.strip() for w in chars if w.strip()])
+ return merged.split()
+
+
+def is_hangul(char):
+ letters = unicodedata.normalize("NFD", char)
+ return all(
+ ["\u1100" <= c <= "\u11ff" or "\u3131" <= c <= "\u318e" for c in letters]
+ )
+
+
+def is_japanese(char):
+ return any(
+ [
+ start <= char <= end
+ for start, end in [
+ ("\u3041", "\u3096"),
+ ("\u30a0", "\u30ff"),
+ ("\uff5f", "\uff9f"),
+ ("\u31f0", "\u31ff"),
+ ("\u3220", "\u3243"),
+ ("\u3280", "\u337f"),
+ ]
+ ]
+ )
+
+
+def is_chinese(char):
+ if char >= "\u4e00" and char <= "\u9fa5":
+ return True
+ else:
+ return False
+
+
+def is_alphabet(char):
+ if (char >= "\u0041" and char <= "\u005a") or (
+ char >= "\u0061" and char <= "\u007a"
+ ):
+ return True
+ else:
+ return False
+
+
+def preprocess_emilia(file_name: str, input_dir: Path, output_dir: Path):
+ logging.info(f"Processing {file_name}")
+ if (output_dir / file_name).is_file():
+ logging.info(f"{file_name} exists, skipping.")
+ return
+
+ def _filter_cut(cut):
+ text = cut.supervisions[0].text
+ duration = cut.supervisions[0].duration
+ chinese = []
+ english = []
+
+ # only contains chinese and space and alphabets
+ clean_chars = []
+ for x in text:
+ if is_hangul(x):
+ logging.warning(f"Delete cut with text containing Korean : {text}")
+ return False
+ if is_japanese(x):
+ logging.warning(f"Delete cut with text containing Japanese : {text}")
+ return False
+ if is_chinese(x):
+ chinese.append(x)
+ clean_chars.append(x)
+ if is_alphabet(x):
+ english.append(x)
+ clean_chars.append(x)
+ if x == " ":
+ clean_chars.append(x)
+ if len(english) + len(chinese) == 0:
+ logging.warning(f"Delete cut with text has no valid chars : {text}")
+ return False
+
+ words = tokenize_by_CJK_char("".join(clean_chars))
+ for i in range(len(words) - 10):
+ if words[i : i + 10].count(words[i]) == 10:
+ logging.warning(f"Delete cut with text with too much repeats : {text}")
+ return False
+ # word speed, 20 - 600 / minute
+ if duration < len(words) / 600 * 60 or duration > len(words) / 20 * 60:
+ logging.warning(
+ f"Delete cut with audio text mismatch, duration : {duration}s, "
+ f"words : {len(words)}, text : {text}"
+ )
+ return False
+ return True
+
+ try:
+ cut_set = load_manifest_lazy(input_dir / file_name)
+ cut_set = cut_set.filter(_filter_cut)
+ cut_set.to_file(output_dir / file_name)
+ except Exception as e:
+ logging.error(f"Manifest {file_name} failed with error: {e}")
+ os.remove(str(output_dir / file_name))
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
+
+ args = get_args()
+
+ input_dir = Path(args.source_dir)
+ output_dir = Path(args.dest_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ cut_files = glob.glob(f"{args.source_dir}/emilia_cuts_{args.subset}.*.jsonl.gz")
+
+ with Pool(max_workers=args.jobs) as pool:
+ futures = [
+ pool.submit(
+ preprocess_emilia,
+ filename.split("/")[-1],
+ input_dir,
+ output_dir,
+ )
+ for filename in cut_files
+ ]
+ for f in futures:
+ f.result()
+ f.done()
+ logging.info("Processing done.")
diff --git a/egs/zipvoice/run_custom.sh b/egs/zipvoice/run_custom.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8f41d00fcf7f92885aca6e6003e4a6aaffd12891
--- /dev/null
+++ b/egs/zipvoice/run_custom.sh
@@ -0,0 +1,138 @@
+#!/bin/bash
+
+# This script is an example of training ZipVoice on your custom datasets from scratch.
+
+# Add project root to PYTHONPATH
+export PYTHONPATH=../../:$PYTHONPATH
+
+# Set bash to 'debug' mode, it will exit on:
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+stage=1
+stop_stage=6
+
+# Number of jobs for data preparation
+nj=20
+
+# You can set `train_hours` and `max_len` according to statistics from
+# the command `lhotse cut describe data/fbank/custom_cuts_train.jsonl.gz`.
+# Set `train_hours` to "Total speech duration", and set `max_len` to 99% duration.
+
+# Number of hours in training set, will affect the learning rate schedule
+train_hours=500
+# Maximum length (seconds) of the training utterance, will filter out longer utterances
+max_len=20
+
+# We suppose you have two TSV files: "data/raw/custom_train.tsv" and
+# "data/raw/custom_dev.tsv", where "custom" is your dataset name,
+# "train"/"dev" are used for training and validation respectively.
+
+# Each line of the TSV files should be in one of the following formats:
+# (1) `{uniq_id}\t{text}\t{wav_path}` if the text corresponds to the full wav,
+# (2) `{uniq_id}\t{text}\t{wav_path}\t{start_time}\t{end_time}` if text corresponds
+# to part of the wav. The start_time and end_time specify the start and end
+# times of the text within the wav, which should be in seconds.
+# > Note: {uniq_id} must be unique for each line.
+for subset in train dev;do
+ file_path=data/raw/custom_${subset}.tsv
+ [ -f "$file_path" ] || { echo "Error: expect $file_path !" >&2; exit 1; }
+done
+
+### Prepare the training data (1 - 3)
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "Stage 1: Prepare manifests for custom dataset from tsv files"
+
+ for subset in train dev;do
+ python3 -m zipvoice.bin.prepare_dataset \
+ --tsv-path data/raw/custom_${subset}.tsv \
+ --prefix custom \
+ --subset ${subset} \
+ --num-jobs ${nj} \
+ --output-dir data/manifests
+ done
+ # The output manifest files are "data/manifests/custom_cuts_train.jsonl.gz".
+ # and "data/manifests/custom_cuts_dev.jsonl.gz".
+
+ # We did not add tokens to the manifests, as on-the-fly tokenization
+ # with the simple tokenizer used in this example is not slow.
+ # If you change to a complex tokenizer, e.g., with g2p and heavy text normalization,
+ # you may need to add tokens to the manifests to speed up the training.
+ # Refer to the fine-tuning example for adding tokens to the manifests.
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ echo "Stage 2: Compute Fbank for custom dataset"
+ # You can skip this step and use `--on-the-fly-feats 1` in training stage
+ for subset in train dev; do
+ python3 -m zipvoice.bin.compute_fbank \
+ --source-dir data/manifests \
+ --dest-dir data/fbank \
+ --dataset custom \
+ --subset ${subset} \
+ --num-jobs ${nj}
+ done
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "Stage 3: Prepare tokens file for custom dataset"
+ # In this example, we use the simplest tokenizer that
+ # treat every character as a token.
+ python3 ./local/prepare_token_file_char.py \
+ --manifest data/manifests/custom_cuts_train.jsonl.gz \
+ --tokens data/tokens_custom.txt
+fi
+
+
+### Training (4 - 5)
+
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "Stage 4: Train the ZipVoice model"
+
+ [ -z "$train_hours" ] && { echo "Error: train_hours is not set!" >&2; exit 1; }
+ [ -z "$max_len" ] && { echo "Error: max_len is not set!" >&2; exit 1; }
+
+ # lr-hours will be set according to the `train_hours`,
+ # i.e., lr_hours = 1000 * (train_hours ** 0.3).
+ lr_hours=$(python3 -c "print(round(1000 * ($train_hours ** 0.3)))" )
+ python3 -m zipvoice.bin.train_zipvoice \
+ --world-size 4 \
+ --use-fp16 1 \
+ --num-iters 60000 \
+ --max-duration 500 \
+ --lr-hours ${lr_hours} \
+ --max-len ${max_len} \
+ --model-config conf/zipvoice_base.json \
+ --tokenizer simple \
+ --token-file data/tokens_custom.txt \
+ --dataset custom \
+ --train-manifest data/fbank/custom_cuts_train.jsonl.gz \
+ --dev-manifest data/fbank/custom_cuts_dev.jsonl.gz \
+ --exp-dir exp/zipvoice_custom
+fi
+
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "Stage 5: Average the checkpoints for ZipVoice"
+ python3 -m zipvoice.bin.generate_averaged_model \
+ --iter 60000 \
+ --avg 2 \
+ --model-name zipvoice \
+ --exp-dir exp/zipvoice_custom
+ # The generated model is exp/zipvoice_custom/iter-60000-avg-2.pt
+fi
+
+### Inference with PyTorch models (6)
+
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+ echo "Stage 6: Inference of the ZipVoice model"
+ python3 -m zipvoice.bin.infer_zipvoice \
+ --model-name zipvoice \
+ --model-dir exp/zipvoice_custom \
+ --checkpoint-name iter-60000-avg-2.pt \
+ --tokenizer simple \
+ --test-list test.tsv \
+ --res-dir results/test_custom
+fi
diff --git a/egs/zipvoice/run_emilia.sh b/egs/zipvoice/run_emilia.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b566a03096180802a3ba35c26c7fb334e258a2ce
--- /dev/null
+++ b/egs/zipvoice/run_emilia.sh
@@ -0,0 +1,178 @@
+#!/bin/bash
+
+# This is an example script for training ZipVoice on Emilia dataset.
+
+# This script covers data preparation, ZipVoice trainnig,
+# ZipVoice-Distill training, onnx export, and
+# inference with all PyTorch and ONNX models.
+
+
+# Add project root to PYTHONPATH
+export PYTHONPATH=../../:$PYTHONPATH
+
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+stage=1
+stop_stage=12
+
+#### Prepare datasets (1)
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "Stage 1: Data Preparation for Emilia dataset"
+ bash local/prepare_emilia.sh
+fi
+
+### Training ZipVoice (2 - 3)
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ echo "Stage 2: Train the ZipVoice model"
+ python3 -m zipvoice.bin.train_zipvoice \
+ --world-size 8 \
+ --use-fp16 1 \
+ --num-epochs 11 \
+ --max-duration 500 \
+ --lr-hours 30000 \
+ --model-config conf/zipvoice_base.json \
+ --tokenizer emilia \
+ --token-file data/tokens_emilia.txt \
+ --dataset emilia \
+ --manifest-dir data/fbank \
+ --exp-dir exp/zipvoice
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "Stage 3: Average the checkpoints for ZipVoice"
+ python3 -m zipvoice.bin.generate_averaged_model \
+ --epoch 11 \
+ --avg 4 \
+ --model-name zipvoice \
+ --exp-dir exp/zipvoice
+ # The generated model is exp/zipvoice/epoch-11-avg-4.pt
+fi
+
+#### (Optional) Training ZipVoice-Distill model (4 - 6)
+
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "Stage 4: Train the ZipVoice-Distill model (first stage)"
+ python3 -m zipvoice.bin.train_zipvoice_distill \
+ --world-size 8 \
+ --use-fp16 1 \
+ --num-iters 60000 \
+ --max-duration 500 \
+ --base-lr 0.0005 \
+ --tokenizer emilia \
+ --token-file data/tokens_emilia.txt \
+ --dataset emilia \
+ --manifest-dir data/fbank \
+ --teacher-model zipvoice/exp_zipvoice/epoch-11-avg-4.pt \
+ --distill-stage first \
+ --exp-dir exp/zipvoice_distill_1stage
+fi
+
+
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "Stage 5: Average the checkpoints for ZipVoice-Distill (first stage)"
+ python3 -m zipvoice.bin.generate_averaged_model \
+ --iter 60000 \
+ --avg 7 \
+ --model-name zipvoice_distill \
+ --exp-dir exp/zipvoice_distill_1stage
+ # The generated model is exp/zipvoice_distill_1stage/iter-60000-avg-7.pt
+fi
+
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+ echo "Stage 6: Train the ZipVoice-Distill model (second stage)"
+
+ python3 -m zipvoice.bin.train_zipvoice_distill \
+ --world-size 8 \
+ --use-fp16 1 \
+ --num-iters 2000 \
+ --save-every-n 1000 \
+ --max-duration 500 \
+ --base-lr 0.0001 \
+ --model-config conf/zipvoice_base.json \
+ --tokenizer emilia \
+ --token-file data/tokens_emilia.txt \
+ --dataset emilia \
+ --manifest-dir data/fbank \
+ --teacher-model exp/zipvoice_distill_1stage/iter-60000-avg-7.pt \
+ --distill-stage second \
+ --exp-dir exp/zipvoice_distill
+fi
+
+### Export ONNX model (7 - 8)
+
+if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
+ echo "Stage 7: Export ZipVoice ONNX model"
+ python3 -m zipvoice.bin.onnx_export \
+ --model-name zipvoice \
+ --model-dir exp/zipvoice/ \
+ --checkpoint-name epoch-11-avg-4.pt \
+ --onnx-model-dir exp/zipvoice/
+fi
+
+if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
+ echo "Stage 8: Export ZipVoice-Distill ONNX model"
+ python3 -m zipvoice.bin.onnx_export \
+ --model-name zipvoice_distill \
+ --model-dir exp/zipvoice_distill/ \
+ --checkpoint-name checkpoint-2000.pt \
+ --onnx-model-dir exp/zipvoice_distill/
+fi
+
+
+### Inference with PyTorch and ONNX models (9 - 12)
+
+if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then
+ echo "Stage 9: Inference of the ZipVoice model"
+ python3 -m zipvoice.bin.infer_zipvoice \
+ --model-name zipvoice \
+ --model-dir exp/zipvoice/ \
+ --checkpoint-name epoch-11-avg-4.pt \
+ --tokenizer emilia \
+ --test-list test.tsv \
+ --res-dir results/test \
+ --num-step 16 \
+ --guidance-scale 1
+fi
+
+
+if [ ${stage} -le 10 ] && [ ${stop_stage} -ge 10 ]; then
+ echo "Stage 10: Inference of the ZipVoice-Distill model"
+ python3 -m zipvoice.bin.infer_zipvoice \
+ --model-name zipvoice_distill \
+ --model-dir exp/zipvoice_distill/ \
+ --checkpoint-name checkpoint-2000.pt \
+ --tokenizer emilia \
+ --test-list test.tsv \
+ --res-dir results/test_distill \
+ --num-step 8 \
+ --guidance-scale 3
+fi
+
+
+if [ ${stage} -le 11 ] && [ ${stop_stage} -ge 11 ]; then
+ echo "Stage 11: Inference with ZipVoice ONNX model"
+ python3 -m zipvoice.bin.infer_zipvoice_onnx \
+ --model-name zipvoice \
+ --onnx-int8 False \
+ --model-dir exp/zipvoice \
+ --tokenizer emilia \
+ --test-list test.tsv \
+ --res-dir results/test_onnx
+fi
+
+if [ ${stage} -le 12 ] && [ ${stop_stage} -ge 12 ]; then
+ echo "Stage 12: Inference with ZipVoic-Distill ONNX model"
+ python3 -m zipvoice.bin.infer_zipvoice_onnx \
+ --model-name zipvoice_distill \
+ --onnx-int8 False \
+ --model-dir exp/zipvoice_distill \
+ --tokenizer emilia \
+ --test-list test.tsv \
+ --res-dir results/test_distill_onnx
+fi
\ No newline at end of file
diff --git a/egs/zipvoice/run_eval.sh b/egs/zipvoice/run_eval.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0e8bbd2f385393c4e8567bfc8382a013305acde8
--- /dev/null
+++ b/egs/zipvoice/run_eval.sh
@@ -0,0 +1,142 @@
+#!/bin/bash
+
+# This script is an example of evaluate TTS models with objective metrics reported in ZipVoice paper.
+
+# Add project root to PYTHONPATH
+export PYTHONPATH=../../:$PYTHONPATH
+
+# Set bash to 'debug' mode, it will exit on:
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+stage=1
+stop_stage=7
+
+download_dir=download/
+
+# Uncomment this line to use HF mirror
+# export HF_ENDPOINT=https://hf-mirror.com
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "Stage 1: Download test sets (LibriSpeech-PC and Seed-TTS)"
+
+ hf_repo=k2-fsa/TTS_eval_datasets
+ mkdir -p ${download_dir}/
+ for file in librispeech_pc_testset.tar.gz seedtts_testset.tar.gz; do
+ echo "Downloading ${file}..."
+ huggingface-cli download \
+ --repo-type dataset \
+ --local-dir ${download_dir}/ \
+ ${hf_repo} \
+ ${file}
+ echo "Extracting ${file}..."
+ tar -xzf ${download_dir}/${file} -C ${download_dir}/
+ done
+fi
+
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ echo "Stage 2: Download all required evaluation models"
+ hf_repo=k2-fsa/TTS_eval_models
+ mkdir -p ${download_dir}/tts_eval_models
+ huggingface-cli download \
+ --local-dir ${download_dir}/tts_eval_models \
+ ${hf_repo}
+fi
+
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "Stage 3: Inference with the pre-trained ZipVoice model from huggingface"
+
+ for testset in librispeech_pc seedtts_en seedtts_zh; do
+
+ if [ "$testset" = "librispeech_pc" ]; then
+ test_tsv=${download_dir}/librispeech_pc_testset/test.tsv
+
+ elif [ "$testset" = "seedtts_en" ]; then
+ test_tsv=${download_dir}/seedtts_testset/en/test.tsv
+ elif [ "$testset" = "seedtts_zh" ]; then
+ test_tsv=${download_dir}/seedtts_testset/zh/test.tsv
+ else
+ echo "Error: unknown testset ${testset}" >&2
+ exit 1
+ fi
+ echo "Inference on tetset ${testset}..."
+ python3 -m zipvoice.bin.infer_zipvoice \
+ --model-name zipvoice \
+ --test-list ${test_tsv} \
+ --res-dir results/${testset}
+ done
+fi
+
+
+
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "Stage 4: Evaluation on LibriSpeech-PC"
+ model_path=${download_dir}/tts_eval_models
+ wav_path=results/librispeech_pc
+ test_tsv=${download_dir}/librispeech_pc_testset/test.tsv
+ # Use LibriSpeech style transcripts for WER evaluation
+ transcript_tsv=${download_dir}/librispeech_pc_testset/transcript.tsv
+
+ python3 -m zipvoice.eval.speaker_similarity.sim \
+ --wav-path ${wav_path} \
+ --test-list ${test_tsv} \
+ --model-dir ${model_path}
+
+ python3 -m zipvoice.eval.wer.hubert \
+ --wav-path ${wav_path} \
+ --test-list ${transcript_tsv} \
+ --model-dir ${model_path}
+
+ python3 -m zipvoice.eval.mos.utmos \
+ --wav-path ${wav_path} \
+ --model-dir ${model_path}
+fi
+
+
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "Stage 5: Evaluation on Seed-TTS test en"
+ model_path=${download_dir}/tts_eval_models
+ wav_path=results/seedtts_en
+ test_tsv=${download_dir}/seedtts_testset/en/test.tsv
+
+ python3 -m zipvoice.eval.speaker_similarity.sim \
+ --wav-path ${wav_path} \
+ --test-list ${test_tsv} \
+ --model-dir ${model_path}
+
+ python3 -m zipvoice.eval.wer.seedtts \
+ --wav-path ${wav_path} \
+ --test-list ${test_tsv} \
+ --model-dir ${model_path} \
+ --lang en
+
+ python3 -m zipvoice.eval.mos.utmos \
+ --wav-path ${wav_path} \
+ --model-dir ${model_path}
+fi
+
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+ echo "Stage 6: Evaluation on Seed-TTS test en"
+ model_path=${download_dir}/tts_eval_models
+ wav_path=results/seedtts_zh
+ test_tsv=${download_dir}/seedtts_testset/zh/test.tsv
+
+ python3 -m zipvoice.eval.speaker_similarity.sim \
+ --wav-path ${wav_path} \
+ --test-list ${test_tsv} \
+ --model-dir ${model_path}
+
+ python3 -m zipvoice.eval.wer.seedtts \
+ --wav-path ${wav_path} \
+ --test-list ${test_tsv} \
+ --model-dir ${model_path} \
+ --lang zh
+
+ python3 -m zipvoice.eval.mos.utmos \
+ --wav-path ${wav_path} \
+ --model-dir ${model_path}
+fi
\ No newline at end of file
diff --git a/egs/zipvoice/run_finetune.sh b/egs/zipvoice/run_finetune.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5aefe0b7fc766c639e3ffb40dff85810cc7ed0fe
--- /dev/null
+++ b/egs/zipvoice/run_finetune.sh
@@ -0,0 +1,175 @@
+#!/bin/bash
+
+# This script is an example of fine-tuning ZipVoice on your custom datasets.
+
+# Add project root to PYTHONPATH
+# export PYTHONPATH=../../:$PYTHONPATH
+
+# Set bash to 'debug' mode, it will exit on:
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+stage=1
+stop_stage=6
+
+# Number of jobs for data preparation
+nj=4
+
+# Whether the language of training data is one of Chinese and English
+is_zh_en=0
+
+# Language identifier, used when language is not Chinese or English
+# see https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md
+# Example of French: lang=fr
+lang=vi
+
+if [ $is_zh_en -eq 1 ]; then
+ tokenizer=espeak
+else
+ tokenizer=espeak
+ [ "$lang" = "default" ] && { echo "Error: lang is not set!" >&2; exit 1; }
+fi
+
+# You can set `max_len` according to statistics from the command
+# `lhotse cut describe data/fbank/custom_cuts_train.jsonl.gz`.
+# Set `max_len` to 99% duration.
+
+# Maximum length (seconds) of the training utterance, will filter out longer utterances
+max_len=25
+
+# Download directory for pre-trained models
+download_dir=download
+
+# We suppose you have two TSV files: "data/raw/custom_train.tsv" and
+# "data/raw/custom_dev.tsv", where "custom" is your dataset name,
+# "train"/"dev" are used for training and validation respectively.
+
+# Each line of the TSV files should be in one of the following formats:
+# (1) `{uniq_id}\t{text}\t{wav_path}` if the text corresponds to the full wav,
+# (2) `{uniq_id}\t{text}\t{wav_path}\t{start_time}\t{end_time}` if text corresponds
+# to part of the wav. The start_time and end_time specify the start and end
+# times of the text within the wav, which should be in seconds.
+# > Note: {uniq_id} must be unique for each line.
+# for subset in train dev;do
+# file_path=data/raw/custom_${subset}.tsv
+# [ -f "$file_path" ] || { echo "Error: expect $file_path !" >&2; exit 1; }
+# done
+
+# ### Prepare the training data (1 - 4)
+
+# if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+# echo "Stage 1: Prepare manifests for custom dataset from tsv files"
+
+# for subset in train dev;do
+# python3 -m zipvoice.bin.prepare_dataset \
+# --tsv-path data/raw/custom_${subset}.tsv \
+# --prefix custom-finetune \
+# --subset raw_${subset} \
+# --num-jobs ${nj} \
+# --output-dir data/manifests
+# done
+# # The output manifest files are "data/manifests/custom-finetune_cuts_raw_train.jsonl.gz".
+# # and "data/manifests/custom-finetune_cuts_raw_dev.jsonl.gz".
+# fi
+
+
+# if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+# echo "Stage 2: Add tokens to manifests"
+# # For "emilia" and "espeak" tokenizers, it's better to prepare the tokens
+# # before training. Otherwise, the on-the-fly tokenization can significantly
+# # slow down the training.
+# for subset in train dev;do
+# python3 -m zipvoice.bin.prepare_tokens \
+# --input-file data/manifests/custom-finetune_cuts_raw_${subset}.jsonl.gz \
+# --output-file data/manifests/custom-finetune_cuts_${subset}.jsonl.gz \
+# --tokenizer ${tokenizer} \
+# --lang ${lang}
+# done
+# # The output manifest files are "data/manifests/custom-finetune_cuts_train.jsonl.gz".
+# # and "data/manifests/custom-finetune_cuts_dev.jsonl.gz".
+# fi
+
+# if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+# echo "Stage 3: Compute Fbank for custom dataset"
+# # You can skip this step and use `--on-the-fly-feats 1` in training stage
+# for subset in train dev; do
+# python3 -m zipvoice.bin.compute_fbank \
+# --source-dir data/manifests \
+# --dest-dir data/fbank \
+# --dataset custom-finetune \
+# --subset ${subset} \
+# --num-jobs ${nj}
+# done
+# fi
+
+# # if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+# # echo "Stage 4: Download pre-trained model, tokens file, and model config"
+# # # Uncomment this line to use HF mirror
+# # # export HF_ENDPOINT=https://hf-mirror.com
+# # hf_repo=k2-fsa/ZipVoice
+# # mkdir -p ${download_dir}
+# # for file in model.pt tokens.txt model.json; do
+# # huggingface-cli download \
+# # --local-dir ${download_dir} \
+# # ${hf_repo} \
+# # zipvoice/${file}
+# # done
+# # fi
+
+# # ### Training ZipVoice (5 - 6)
+
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "Stage 5: Fine-tune the ZipVoice model"
+
+ [ -z "$max_len" ] && { echo "Error: max_len is not set!" >&2; exit 1; }
+
+ python3 -m zipvoice.bin.train_zipvoice \
+ --world-size 1 \
+ --use-fp16 1 \
+ --finetune 1 \
+ --base-lr 0.00006 \
+ --num-epochs 2 \
+ --save-every-n 1000 \
+ --keep-last-k 4 \
+ --max-duration 650 \
+ --max-len ${max_len} \
+ --min-len 1 \
+ --model-config ${download_dir}/zipvoice/model.json \
+ --checkpoint ${download_dir}/zipvoice/model.pt \
+ --tokenizer ${tokenizer} \
+ --lang ${lang} \
+ --token-file ${download_dir}/zipvoice/tokens.txt \
+ --dataset custom \
+ --train-manifest data/fbank/train_all.jsonl.gz \
+ --dev-manifest data/fbank/dev_all.jsonl.gz \
+ --exp-dir exp/zipvoice_finetune
+
+fi
+
+# if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+# echo "Stage 6: Average the checkpoints for ZipVoice"
+# python3 -m zipvoice.bin.generate_averaged_model \
+# --iter 10000 \
+# --avg 2 \
+# --model-name zipvoice \
+# --exp-dir exp/zipvoice_finetune
+# # The generated model is exp/zipvoice_finetune/iter-10000-avg-2.pt
+# fi
+
+# ### Inference with PyTorch models (7)
+
+# if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
+# echo "Stage 7: Inference of the ZipVoice model"
+
+# python3 -m zipvoice.bin.infer_zipvoice \
+# --model-name zipvoice \
+# --model-dir exp/zipvoice_finetune/ \
+# --checkpoint-name iter-10000-avg-2.pt \
+# --tokenizer ${tokenizer} \
+# --lang ${lang} \
+# --test-list test.tsv \
+# --res-dir results/test_finetune\
+# --num-step 16
+# fi
diff --git a/egs/zipvoice/run_libritts.sh b/egs/zipvoice/run_libritts.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1dc00d31fe130ef2579912b121c24c5e526dae62
--- /dev/null
+++ b/egs/zipvoice/run_libritts.sh
@@ -0,0 +1,148 @@
+#!/bin/bash
+
+# This is an example script for training ZipVoice on LibriTTS dataset.
+
+# Add project root to PYTHONPATH
+export PYTHONPATH=../../:$PYTHONPATH
+
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+stage=1
+stop_stage=9
+
+#### Prepare datasets (1)
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "Stage 1: Data Preparation for LibriTTS dataset"
+ bash local/prepare_libritts.sh
+fi
+
+### Training ZipVoice (2 - 3)
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ echo "Stage 2: Train the ZipVoice model"
+ python3 -m zipvoice.bin.train_zipvoice \
+ --world-size 8 \
+ --use-fp16 0 \
+ --num-epochs 60 \
+ --max-duration 250 \
+ --lr-epochs 10 \
+ --max-len 20 \
+ --valid-by-epoch 1 \
+ --model-config conf/zipvoice_base.json \
+ --tokenizer libritts \
+ --token-file data/tokens_libritts.txt \
+ --dataset libritts \
+ --manifest-dir data/fbank \
+ --exp-dir exp/zipvoice_libritts
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "Stage 3: Average the checkpoints for ZipVoice"
+ python3 -m zipvoice.bin.generate_averaged_model \
+ --epoch 60 \
+ --avg 10 \
+ --model-name zipvoice \
+ --exp-dir exp/zipvoice_libritts
+ # The generated model is exp/zipvoice_libritts/epoch-60-avg-10.pt
+fi
+
+#### (Optional) Training ZipVoice-Distill model (4 - 7)
+
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "Stage 4: Train the ZipVoice-Distill model (first stage)"
+ python3 -m zipvoice.bin.train_zipvoice_distill \
+ --world-size 8 \
+ --use-fp16 0 \
+ --num-epochs 6 \
+ --max-duration 250 \
+ --base-lr 0.001 \
+ --max-len 20 \
+ --valid-by-epoch 1 \
+ --model-config conf/zipvoice_base.json \
+ --tokenizer libritts \
+ --token-file data/tokens_libritts.txt \
+ --dataset "libritts" \
+ --manifest-dir "data/fbank" \
+ --teacher-model exp/zipvoice_libritts/epoch-60-avg-10.pt \
+ --distill-stage "first" \
+ --exp-dir exp/zipvoice_distill_1stage_libritts
+fi
+
+
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "Stage 5: Average the checkpoints for ZipVoice-Distill (first stage)"
+ python3 -m zipvoice.bin.generate_averaged_model \
+ --epoch 6 \
+ --avg 3 \
+ --model-name zipvoice_distill \
+ --exp-dir exp/zipvoice_distill_1stage_libritts
+ # The generated model is exp/zipvoice_distill_1stage_libritts/epoch-6-avg-3.pt
+fi
+
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+ echo "Stage 6: Train the ZipVoice-Distill model (second stage)"
+
+ python3 -m zipvoice.bin.train_zipvoice_distill \
+ --world-size 8 \
+ --use-fp16 1 \
+ --num-epochs 6 \
+ --max-duration 250 \
+ --base-lr 0.001 \
+ --max-len 20 \
+ --valid-by-epoch 1 \
+ --model-config conf/zipvoice_base.json \
+ --tokenizer libritts \
+ --token-file data/tokens_libritts.txt \
+ --dataset libritts \
+ --manifest-dir data/fbank \
+ --teacher-model exp/zipvoice_distill_1stage_libritts/epoch-6-avg-3.pt \
+ --distill-stage second \
+ --exp-dir exp/zipvoice_distill_libritts
+fi
+
+
+if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
+ echo "Stage 7: Average the checkpoints for ZipVoice-Distill (second stage)"
+ python3 -m zipvoice.bin.generate_averaged_model \
+ --epoch 6 \
+ --avg 3 \
+ --model-name zipvoice_distill \
+ --exp-dir exp/zipvoice_distill_libritts
+ # The generated model is exp/zipvoice_distill_libritts/epoch-6-avg-3.pt
+fi
+
+### Inference with PyTorch models (8 - 9)
+
+if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
+ echo "Stage 8: Inference of the ZipVoice model"
+ python3 -m zipvoice.bin.infer_zipvoice \
+ --model-name zipvoice \
+ --model-dir exp/zipvoice_libritts \
+ --checkpoint-name epoch-60-avg-10.pt \
+ --tokenizer libritts \
+ --test-list test.tsv \
+ --res-dir results/test_libritts \
+ --num-step 8 \
+ --guidance-scale 1 \
+ --t-shift 0.7
+fi
+
+
+if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then
+ echo "Stage 9: Inference of the ZipVoice-Distill model"
+ python3 -m zipvoice.bin.infer_zipvoice \
+ --model-name zipvoice_distill \
+ --model-dir exp/zipvoice_distill_libritts \
+ --checkpoint-name epoch-6-avg-3.pt \
+ --tokenizer libritts \
+ --test-list test.tsv \
+ --res-dir results/test_distill_libritts \
+ --num-step 4 \
+ --guidance-scale 3 \
+ --t-shift 0.7
+fi
diff --git a/egs/zipvoice/utils/parse_options.sh b/egs/zipvoice/utils/parse_options.sh
new file mode 100644
index 0000000000000000000000000000000000000000..71fb9e5ea1db641ff5bf82a18dc292db6ceba146
--- /dev/null
+++ b/egs/zipvoice/utils/parse_options.sh
@@ -0,0 +1,97 @@
+#!/usr/bin/env bash
+
+# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
+# Arnab Ghoshal, Karel Vesely
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Parse command-line options.
+# To be sourced by another script (as in ". parse_options.sh").
+# Option format is: --option-name arg
+# and shell variable "option_name" gets set to value "arg."
+# The exception is --help, which takes no arguments, but prints the
+# $help_message variable (if defined).
+
+
+###
+### The --config file options have lower priority to command line
+### options, so we need to import them first...
+###
+
+# Now import all the configs specified by command-line, in left-to-right order
+for ((argpos=1; argpos<$#; argpos++)); do
+ if [ "${!argpos}" == "--config" ]; then
+ argpos_plus1=$((argpos+1))
+ config=${!argpos_plus1}
+ [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
+ . $config # source the config file.
+ fi
+done
+
+
+###
+### Now we process the command line options
+###
+while true; do
+ [ -z "${1:-}" ] && break; # break if there are no arguments
+ case "$1" in
+ # If the enclosing script is called with --help option, print the help
+ # message and exit. Scripts should put help messages in $help_message
+ --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
+ else printf "$help_message\n" 1>&2 ; fi;
+ exit 0 ;;
+ --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
+ exit 1 ;;
+ # If the first command-line argument begins with "--" (e.g. --foo-bar),
+ # then work out the variable name as $name, which will equal "foo_bar".
+ --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
+ # Next we test whether the variable in question is undefned-- if so it's
+ # an invalid option and we die. Note: $0 evaluates to the name of the
+ # enclosing script.
+ # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
+ # is undefined. We then have to wrap this test inside "eval" because
+ # foo_bar is itself inside a variable ($name).
+ eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+
+ oldval="`eval echo \\$$name`";
+ # Work out whether we seem to be expecting a Boolean argument.
+ if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
+ was_bool=true;
+ else
+ was_bool=false;
+ fi
+
+ # Set the variable to the right value-- the escaped quotes make it work if
+ # the option had spaces, like --cmd "queue.pl -sync y"
+ eval $name=\"$2\";
+
+ # Check that Boolean-valued arguments are really Boolean.
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+ exit 1;
+ fi
+ shift 2;
+ ;;
+ *) break;
+ esac
+done
+
+
+# Check for an empty argument to the --cmd option, which can easily occur as a
+# result of scripting errors.
+[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
+
+
+true; # so this script returns exit code 0.
diff --git a/egs/zipvoice/utils/validate_manifest.py b/egs/zipvoice/utils/validate_manifest.py
new file mode 100644
index 0000000000000000000000000000000000000000..635c5e7886cc219da901409aa175679642549c05
--- /dev/null
+++ b/egs/zipvoice/utils/validate_manifest.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script checks the following assumptions of the generated manifest:
+
+- Single supervision per cut
+
+We will add more checks later if needed.
+
+Usage example:
+
+ python3 ./utils/validate_manifest.py \
+ ./data/spectrogram/ljspeech_cuts_all.jsonl.gz
+
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+from lhotse import CutSet, load_manifest_lazy
+from lhotse.dataset.speech_synthesis import validate_for_tts
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "manifest",
+ type=Path,
+ help="Path to the manifest file",
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+
+ manifest = args.manifest
+ logging.info(f"Validating {manifest}")
+
+ assert manifest.is_file(), f"{manifest} does not exist"
+ cut_set = load_manifest_lazy(manifest)
+ assert isinstance(cut_set, CutSet), type(cut_set)
+
+ validate_for_tts(cut_set)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
+
+ main()
diff --git a/egs/zipvoice_dialog/README.md b/egs/zipvoice_dialog/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0252b5e5cff5a2bd07b5e854b5f07bd436fd8037
--- /dev/null
+++ b/egs/zipvoice_dialog/README.md
@@ -0,0 +1,12 @@
+# ZipVoice-Dialog Recipe
+
+This recipe contains the following examples:
+
+- Training ZipVoice-Dialog on OpenDialog dataset, see [run_opendialog.sh](run_opendialog.sh)
+- Training ZipVoice-Dialog on custom datasets (Chinese/English), see [run_custom.sh](run_custom.sh).
+- Fine-tuning pre-trained ZipVoice-Dialog on custom datasets (Chinese/English), see [run_finetune.sh](run_finetune.sh).
+- Evaluate models with objective metrics reported in ZipVoice-Dialog paper, see [run_eval.sh](run_eval.sh).
+
+> **NOTE:** For evaluation, first install packages from [../../requirements_eval.txt](../../requirements_eval.txt)
+>
+> `pip install -r ../../requirements_eval.txt`
\ No newline at end of file
diff --git a/egs/zipvoice_dialog/local/prepare_opendialog.py b/egs/zipvoice_dialog/local/prepare_opendialog.py
new file mode 100644
index 0000000000000000000000000000000000000000..4934dc382b57ba65c0af6ee164cb34f8437283b9
--- /dev/null
+++ b/egs/zipvoice_dialog/local/prepare_opendialog.py
@@ -0,0 +1,262 @@
+#!/usr/bin/env python3
+# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script prepares lhotse manifest files from the raw OpenDialog datasets.
+
+We assume that you have downloaded the OpenDialog dataset and untarred the
+tar files in audio/en and audio/zh so that the mp3 files are placed under
+these two directories.
+
+Download OpenDialog at https://huggingface.co/datasets/k2-fsa/OpenDialog
+or https://www.modelscope.cn/datasets/k2-fsa/OpenDialog
+
+"""
+
+import argparse
+import json
+import logging
+import math
+import re
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from pathlib import Path
+from typing import List, Optional, Tuple
+
+from lhotse import CutSet, validate_recordings_and_supervisions
+from lhotse.audio import Recording, RecordingSet
+from lhotse.cut import Cut
+from lhotse.qa import fix_manifests
+from lhotse.supervision import SupervisionSegment, SupervisionSet
+from lhotse.utils import Pathlike
+from tqdm.auto import tqdm
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--dataset-path",
+ type=str,
+ help="The path of OpenDialog dataset.",
+ )
+
+ parser.add_argument(
+ "--num-jobs",
+ type=int,
+ default=20,
+ help="Number of jobs to processing.",
+ )
+
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ default="data/manifests",
+ help="The destination directory of manifest files.",
+ )
+ parser.add_argument(
+ "--sampling-rate",
+ type=int,
+ default=24000,
+ help="The target sampling rate.",
+ )
+ return parser.parse_args()
+
+
+def _parse_recording(
+ wav_path: str,
+) -> Tuple[Recording, str]:
+ """
+ :param wav_path: Path to the audio file
+ :return: a tuple of "recording" and "recording_id"
+ """
+
+ recording_id = Path(wav_path).stem
+ recording = Recording.from_file(path=wav_path, recording_id=recording_id)
+
+ return recording, recording_id
+
+
+def _parse_supervision(
+ supervision: List, recording_dict: dict
+) -> Optional[SupervisionSegment]:
+ """
+ :param line: A line from the TSV file
+ :param recording_dict: Dictionary mapping recording IDs to Recording objects
+ :return: A SupervisionSegment object
+ """
+
+ def _round_down(num, ndigits=0):
+ factor = 10**ndigits
+ return math.floor(num * factor) / factor
+
+ uniq_id, text, wav_path, start, end = supervision
+ try:
+ recording_id = Path(wav_path).stem
+
+ recording = recording_dict[recording_id]
+ duration = (
+ _round_down(end - start, ndigits=8)
+ if end is not None
+ else _round_down(recording.duration, ndigits=8)
+ )
+ assert duration <= recording.duration, f"Duration {duration} is greater than "
+ f"recording duration {recording.duration}"
+
+ text = re.sub("_", " ", text) # "_" is treated as padding symbol
+ text = re.sub(r"\s+", " ", text) # remove extra whitespace
+
+ return SupervisionSegment(
+ id=f"{uniq_id}",
+ recording_id=recording.id,
+ start=start,
+ duration=duration,
+ channel=recording.channel_ids,
+ text=text.strip(),
+ )
+ except Exception as e:
+ logging.info(f"Error processing line: {e}")
+ return None
+
+
+def prepare_subset(
+ jsonl_path: Pathlike,
+ lang: str,
+ sampling_rate: int,
+ num_jobs: int,
+ output_dir: Pathlike,
+):
+ """
+ Returns the manifests which consist of the Recordings and Supervisions
+
+ :param jsonl_path: Path to the jsonl file
+ :param lang: Language of the subset
+ :param sampling_rate: Target sampling rate of the audio
+ :param num_jobs: Number of processes for parallel processing
+ :param output_dir: Path where to write the manifests
+ """
+ logging.info(f"Preparing {lang} subset")
+
+ # Step 1: Read all unique recording paths
+ logging.info(f"Reading {jsonl_path}")
+ recordings_path_set = set()
+ supervision_list = list()
+ with open(jsonl_path, "r") as fr:
+ for line in fr:
+ try:
+ items = json.loads(line)
+ uniq_id, text, wav_path = items["id"], items["text"], items["path"]
+ start, end = 0, None
+ recordings_path_set.add(jsonl_path.parent / wav_path)
+ supervision_list.append((uniq_id, text, wav_path, start, end))
+ except Exception as e:
+ logging.warning(f"Error {e} when decoding JSON line: {line}")
+ continue
+ logging.info("Starting to process recordings...")
+ # Step 2: Process recordings
+ futures = []
+ recording_dict = {}
+ with ThreadPoolExecutor(max_workers=num_jobs) as ex:
+ for wav_path in tqdm(recordings_path_set, desc="Submitting jobs"):
+ futures.append(ex.submit(_parse_recording, wav_path))
+
+ for future in tqdm(futures, desc="Processing recordings"):
+ try:
+ recording, recording_id = future.result()
+ recording_dict[recording_id] = recording
+ except Exception as e:
+ logging.warning(
+ f"Error processing recording {recording_id} with error: {e}"
+ )
+
+ recording_set = RecordingSet.from_recordings(recording_dict.values())
+
+ logging.info("Starting to process supervisions...")
+ # Step 3: Process supervisions
+ supervisions = []
+ for supervision in tqdm(supervision_list, desc="Processing supervisions"):
+ seg = _parse_supervision(supervision, recording_dict)
+ if seg is not None:
+ supervisions.append(seg)
+
+ logging.info("Processing Cuts...")
+
+ # Step 4: Create and validate manifests
+ supervision_set = SupervisionSet.from_segments(supervisions)
+
+ recording_set, supervision_set = fix_manifests(recording_set, supervision_set)
+ validate_recordings_and_supervisions(recording_set, supervision_set)
+
+ cut_set = CutSet.from_manifests(
+ recordings=recording_set, supervisions=supervision_set
+ )
+ cut_set = cut_set.sort_by_recording_id()
+ if sampling_rate != 24000:
+ # All OpenDialog audios are 24kHz
+ cut_set = cut_set.resample(sampling_rate)
+ cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
+
+ logging.info("Saving cuts to disk...")
+ # Step 5: Write manifests to disk
+ cut_set.to_file(output_dir / f"opendialog_cuts_raw_{lang.upper()}-all.jsonl.gz")
+ dev_cut_set = cut_set.subset(first=1000)
+ dev_cut_set.to_file(output_dir / f"opendialog_cuts_raw_{lang.upper()}-dev.jsonl.gz")
+
+ def remove_dev(c: Cut, set: set):
+ if c.id in set:
+ return False
+ return True
+
+ _remove_dev = partial(remove_dev, set=set(dev_cut_set.ids))
+ train_cut_set = cut_set.filter(_remove_dev)
+ train_cut_set.to_file(
+ output_dir / f"opendialog_cuts_raw_{lang.upper()}-train.jsonl.gz"
+ )
+
+
+def prepare_dataset(
+ dataset_path: Pathlike,
+ sampling_rate: int,
+ num_jobs: int,
+ output_dir: Pathlike,
+):
+ for lang in ["en", "zh"]:
+ jsonl_path = dataset_path / f"manifest.{lang}.jsonl"
+ prepare_subset(
+ jsonl_path=jsonl_path,
+ lang=lang,
+ sampling_rate=sampling_rate,
+ num_jobs=num_jobs,
+ output_dir=output_dir,
+ )
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
+
+ args = get_args()
+ dataset_path = Path(args.dataset_path)
+ output_dir = Path(args.output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ prepare_dataset(
+ dataset_path=dataset_path,
+ sampling_rate=args.sampling_rate,
+ num_jobs=args.num_jobs,
+ output_dir=output_dir,
+ )
diff --git a/egs/zipvoice_dialog/run_custom.sh b/egs/zipvoice_dialog/run_custom.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0a3892bba92268aff08e614177d85733609dd75f
--- /dev/null
+++ b/egs/zipvoice_dialog/run_custom.sh
@@ -0,0 +1,145 @@
+#!/bin/bash
+
+# This script is an example of training ZipVoice-Dialog on your custom datasets.
+# Only support English and Chinese for now.
+
+# Add project root to PYTHONPATH
+export PYTHONPATH=../../:$PYTHONPATH
+
+# Set bash to 'debug' mode, it will exit on:
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+stage=1
+stop_stage=6
+
+# Number of jobs for data preparation
+nj=20
+download_dir=download/
+
+# Maximum length (seconds) of the training utterance, will filter out longer utterances
+max_len=60
+
+# We suppose you have two TSV files: "data/raw/custom_train.tsv" and
+# "data/raw/custom_dev.tsv", where "custom" is your dataset name,
+# "train"/"dev" are used for training and validation respectively.
+
+# Each line of the TSV files should be in one of the following formats:
+# (1) `{uniq_id}\t{text}\t{wav_path}` if the text corresponds to the full wav,
+# (2) `{uniq_id}\t{text}\t{wav_path}\t{start_time}\t{end_time}` if text corresponds
+# to part of the wav. The start_time and end_time specify the start and end
+# times of the text within the wav, which should be in seconds.
+# > Note: {uniq_id} must be unique for each line.
+# > Note: {text} uses [S1] and [S2] tags to distinguish speakers, and must be begin with [S1].
+# > eg: "[S1] Hello. [S2] How are you? [S1] I'm fine. [S2] What's your name?"
+for subset in train dev;do
+ file_path=data/raw/custom_${subset}.tsv
+ [ -f "$file_path" ] || { echo "Error: expect $file_path !" >&2; exit 1; }
+done
+
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "Stage 1: Prepare manifests for custom dataset from tsv files"
+
+ for subset in train dev;do
+ python3 -m zipvoice.bin.prepare_dataset \
+ --tsv-path data/raw/custom_${subset}.tsv \
+ --prefix custom \
+ --subset raw_${subset} \
+ --num-jobs ${nj} \
+ --output-dir data/manifests
+ done
+ # The output manifest files are "data/manifests/custom_cuts_raw_train.jsonl.gz".
+ # and "data/manifests/custom_cuts_raw_dev.jsonl.gz".
+fi
+
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ echo "Stage 2: Add tokens to manifests"
+ for subset in train dev;do
+ python3 -m zipvoice.bin.prepare_tokens \
+ --input-file data/manifests/custom_cuts_raw_${subset}.jsonl.gz \
+ --output-file data/manifests/custom_cuts_${subset}.jsonl.gz \
+ --tokenizer dialog
+ done
+ # The output manifest files are "data/manifests/custom_cuts_train.jsonl.gz".
+ # and "data/manifests/custom_cuts_dev.jsonl.gz".
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ echo "Stage 3: Compute Fbank for custom dataset"
+ # You can skip this step and use `--on-the-fly-feats 1` in training stage
+ for subset in train dev; do
+ python3 -m zipvoice.bin.compute_fbank \
+ --source-dir data/manifests \
+ --dest-dir data/fbank \
+ --dataset custom \
+ --subset ${subset} \
+ --num-jobs ${nj}
+ done
+fi
+
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "Stage 4: Download tokens file, pretrained models"
+ # Uncomment this line to use HF mirror
+ # export HF_ENDPOINT=https://hf-mirror.com
+
+ # The token file is obtained by extending some tokens
+ # on the bases of the Emilia token file.
+ mkdir -p ${download_dir}
+ hf_repo=k2-fsa/ZipVoice
+ huggingface-cli download \
+ --local-dir ${download_dir} \
+ ${hf_repo} \
+ zipvoice_dialog/tokens.txt
+
+ # Pre-trained ZipVoice model is required as
+ # the initialization model.
+ for file in model.pt tokens.txt model.json; do
+ huggingface-cli download \
+ --local-dir ${download_dir} \
+ ${hf_repo} \
+ zipvoice/${file}
+ done
+fi
+
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "Stage 5: Train the ZipVoice-Dialog model"
+ python3 -m zipvoice.bin.train_zipvoice_dialog \
+ --world-size 4 \
+ --use-fp16 1 \
+ --base-lr 0.0001 \
+ --num-iters 60000 \
+ --max-duration 500 \
+ --max-len ${max_len} \
+ --checkpoint ${download_dir}/zipvoice/model.pt \
+ --model-config ${download_dir}/zipvoice/model.json \
+ --token-file ${download_dir}/zipvoice_dialog/tokens.txt \
+ --dataset custom \
+ --train-manifest data/fbank/custom_cuts_train.jsonl.gz \
+ --dev-manifest data/fbank/custom_cuts_dev.jsonl.gz \
+ --exp-dir exp/zipvoice_dialog_custom
+fi
+
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+ echo "Stage 6: Average the checkpoints for ZipVoice"
+ python3 -m zipvoice.bin.generate_averaged_model \
+ --iter 60000 \
+ --avg 2 \
+ --model-name zipvoice_dialog \
+ --exp-dir exp/zipvoice_dialog_custom
+ # The generated model is exp/zipvoice_dialog/iter-60000-avg-2.pt
+fi
+
+
+if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
+ echo "Stage 6: Inference of the ZipVoice model"
+ python3 -m zipvoice.bin.infer_zipvoice_dialog \
+ --model-name zipvoice_dialog \
+ --model-dir exp/zipvoice_dialog_custom \
+ --checkpoint-name iter-60000-avg-2.pt \
+ --test-list test.tsv \
+ --res-dir results/test_dialog_custom
+fi
\ No newline at end of file
diff --git a/egs/zipvoice_dialog/run_eval.sh b/egs/zipvoice_dialog/run_eval.sh
new file mode 100644
index 0000000000000000000000000000000000000000..289ff822cbbbd62998a10b05dcffa553803d9a6f
--- /dev/null
+++ b/egs/zipvoice_dialog/run_eval.sh
@@ -0,0 +1,120 @@
+#!/bin/bash
+
+# This script is an example of evaluate TTS models with objective metrics reported in ZipVoice-Dialog paper.
+
+# Add project root to PYTHONPATH
+export PYTHONPATH=../../:$PYTHONPATH
+
+# Set bash to 'debug' mode, it will exit on:
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+stage=1
+stop_stage=6
+
+download_dir=download/
+
+# Uncomment this line to use HF mirror
+# export HF_ENDPOINT=https://hf-mirror.com
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "Stage 1: Download test sets (test-dialog)"
+ hf_repo=k2-fsa/TTS_eval_datasets
+ mkdir -p ${download_dir}/
+ file=dialog_testset.tar.gz
+ echo "Downloading ${file}..."
+ huggingface-cli download \
+ --repo-type dataset \
+ --local-dir ${download_dir}/ \
+ ${hf_repo} \
+ ${file}
+ echo "Extracting ${file}..."
+ tar -xzf ${download_dir}/${file} -C ${download_dir}/
+fi
+
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ echo "Stage 2: Download all required evaluation models"
+ mkdir -p ${download_dir}/tts_eval_models
+ mkdir -p ${download_dir}
+ huggingface-cli download \
+ --local-dir ${download_dir}/tts_eval_models \
+ ${hf_repo}
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "Stage 3: Inference with the pre-trained ZipVoice model from huggingface"
+
+ for testset in test_dialog_en test_dialog_zh; do
+ if [ "$testset" = "test_dialog_en" ]; then
+ test_tsv=${download_dir}/dialog_testset/en/test.tsv
+ elif [ "$testset" = "test_dialog_zh" ]; then
+ test_tsv=${download_dir}/dialog_testset/zh/test.tsv
+ else
+ echo "Error: unknown testset ${testset}" >&2
+ exit 1
+ fi
+ echo "Inference on tetset ${testset}..."
+ python3 -m zipvoice.bin.infer_zipvoice_dialog \
+ --model-name zipvoice_dialog \
+ --test-list ${test_tsv} \
+ --res-dir results/${testset}
+ done
+fi
+
+
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "Stage 4: Evaluation on test-dialog-en"
+ model_path=${download_dir}/tts_eval_models
+ wav_path=results/test_dialog_en
+ test_tsv=${download_dir}/dialog_testset/en/test.tsv
+
+ python3 -m zipvoice.eval.speaker_similarity.cpsim \
+ --wav-path ${wav_path} \
+ --test-list ${test_tsv} \
+ --model-dir ${model_path}
+
+ python3 -m zipvoice.eval.wer.dialog \
+ --wav-path ${wav_path} \
+ --test-list ${test_tsv} \
+ --model-dir ${model_path} \
+ --lang en
+
+ # cpWER mode: will only compute WER and cpWER
+ # for speech less than 30s
+ python3 -m zipvoice.eval.wer.dialog \
+ --wav-path ${wav_path} \
+ --test-list ${test_tsv} \
+ --model-dir ${model_path} \
+ --lang en \
+ --cpwer
+
+ python3 -m zipvoice.eval.mos.utmos \
+ --wav-path ${wav_path} \
+ --model-dir ${model_path}
+fi
+
+
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "Stage 5: Evaluation on test-dialog-zh"
+ model_path=${download_dir}/tts_eval_models
+ wav_path=results/test_dialog_zh
+ test_tsv=${download_dir}/dialog_testset/zh/test.tsv
+
+ python3 -m zipvoice.eval.speaker_similarity.cpsim \
+ --wav-path ${wav_path} \
+ --test-list ${test_tsv} \
+ --model-dir ${model_path}
+
+ python3 -m zipvoice.eval.wer.dialog \
+ --wav-path ${wav_path} \
+ --test-list ${test_tsv} \
+ --model-dir ${model_path} \
+ --lang zh
+
+ python3 -m zipvoice.eval.mos.utmos \
+ --wav-path ${wav_path} \
+ --model-dir ${model_path}
+fi
\ No newline at end of file
diff --git a/egs/zipvoice_dialog/run_finetune.sh b/egs/zipvoice_dialog/run_finetune.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ec12bfce1fc6a1bf74a50e00b029e350541410bf
--- /dev/null
+++ b/egs/zipvoice_dialog/run_finetune.sh
@@ -0,0 +1,135 @@
+#!/bin/bash
+
+# This script is an example of fine-tune our pre-trained ZipVoice-Dialog on your custom datasets.
+# Only support English and Chinese for now.
+
+# Add project root to PYTHONPATH
+export PYTHONPATH=../../:$PYTHONPATH
+
+# Set bash to 'debug' mode, it will exit on:
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+stage=1
+stop_stage=6
+
+# Number of jobs for data preparation
+nj=20
+# Maximum length (seconds) of the training utterance, will filter out longer utterances
+max_len=60
+download_dir=download/
+
+# We suppose you have two TSV files: "data/raw/custom_train.tsv" and
+# "data/raw/custom_dev.tsv", where "custom" is your dataset name,
+# "train"/"dev" are used for training and validation respectively.
+
+# Each line of the TSV files should be in one of the following formats:
+# (1) `{uniq_id}\t{text}\t{wav_path}` if the text corresponds to the full wav,
+# (2) `{uniq_id}\t{text}\t{wav_path}\t{start_time}\t{end_time}` if text corresponds
+# to part of the wav. The start_time and end_time specify the start and end
+# times of the text within the wav, which should be in seconds.
+# > Note: {uniq_id} must be unique for each line.
+# > Note: {text} uses [S1] and [S2] tags to distinguish speakers, and must be begin with [S1].
+# > eg: "[S1] Hello. [S2] How are you? [S1] I'm fine. [S2] What's your name?"
+for subset in train dev;do
+ file_path=data/raw/custom_${subset}.tsv
+ [ -f "$file_path" ] || { echo "Error: expect $file_path !" >&2; exit 1; }
+done
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "Stage 1: Prepare manifests for custom dataset from tsv files"
+
+ for subset in train dev;do
+ python3 -m zipvoice.bin.prepare_dataset \
+ --tsv-path data/raw/custom_${subset}.tsv \
+ --prefix custom-finetune \
+ --subset raw_${subset} \
+ --num-jobs ${nj} \
+ --output-dir data/manifests
+ done
+ # The output manifest files are "data/manifests/custom-finetune_cuts_raw_train.jsonl.gz".
+ # and "data/manifests/custom-finetune_cuts_raw_dev.jsonl.gz".
+fi
+
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ echo "Stage 2: Add tokens to manifests"
+ for subset in train dev;do
+ python3 -m zipvoice.bin.prepare_tokens \
+ --input-file data/manifests/custom-finetune_cuts_raw_${subset}.jsonl.gz \
+ --output-file data/manifests/custom-finetune_cuts_${subset}.jsonl.gz \
+ --tokenizer dialog
+ done
+ # The output manifest files are "data/manifests/custom-finetune_cuts_train.jsonl.gz".
+ # and "data/manifests/custom-finetune_cuts_dev.jsonl.gz".
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ echo "Stage 3: Compute Fbank for custom dataset"
+ # You can skip this step and use `--on-the-fly-feats 1` in training stage
+ for subset in train dev; do
+ python3 -m zipvoice.bin.compute_fbank \
+ --source-dir data/manifests \
+ --dest-dir data/fbank \
+ --dataset custom-finetune \
+ --subset ${subset} \
+ --num-jobs ${nj}
+ done
+fi
+
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "Stage 4: Download pre-trained model, tokens file, and model config"
+ # Uncomment this line to use HF mirror
+ # export HF_ENDPOINT=https://hf-mirror.com
+
+ mkdir -p ${download_dir}
+ hf_repo=k2-fsa/ZipVoice
+ for file in model.pt tokens.txt model.json; do
+ huggingface-cli download \
+ --local-dir ${download_dir} \
+ ${hf_repo} \
+ zipvoice_dialog/${file}
+ done
+fi
+
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "Stage 5: Fine-tune the ZipVoice-Dialog model"
+ python3 -m zipvoice.bin.train_zipvoice_dialog \
+ --world-size 4 \
+ --use-fp16 1 \
+ --finetune 1 \
+ --base-lr 0.0001 \
+ --num-iters 10000 \
+ --save-every-n 1000 \
+ --max-duration 500 \
+ --max-len ${max_len} \
+ --checkpoint ${download_dir}/zipvoice_dialog/model.pt \
+ --model-config ${download_dir}/zipvoice_dialog/model.json \
+ --token-file ${download_dir}/zipvoice_dialog/tokens.txt \
+ --dataset custom \
+ --train-manifest data/fbank/custom-finetune_cuts_train.jsonl.gz \
+ --dev-manifest data/fbank/custom-finetune_cuts_dev.jsonl.gz \
+ --exp-dir exp/zipvoice_dialog_finetune
+fi
+
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+ echo "Stage 6: Average the checkpoints for ZipVoice"
+ python3 -m zipvoice.bin.generate_averaged_model \
+ --iter 10000 \
+ --avg 2 \
+ --model-name zipvoice_dialog \
+ --exp-dir exp/zipvoice_dialog_finetune
+ # The generated model is exp/zipvoice_dialog_finetune/iter-10000-avg-2.pt
+fi
+
+if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
+ echo "Stage 7: Inference of the ZipVoice model"
+ python3 -m zipvoice.bin.infer_zipvoice_dialog \
+ --model-name zipvoice_dialog \
+ --model-dir exp/zipvoice_dialog_finetune \
+ --checkpoint-name iter-10000-avg-2.pt \
+ --test-list test.tsv \
+ --res-dir results/test_dialog_finetune
+fi
\ No newline at end of file
diff --git a/egs/zipvoice_dialog/run_opendialog.sh b/egs/zipvoice_dialog/run_opendialog.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b6c84815f302f3db7d42bafeffba3e1015344b19
--- /dev/null
+++ b/egs/zipvoice_dialog/run_opendialog.sh
@@ -0,0 +1,122 @@
+#!/bin/bash
+
+# This script is an example of training ZipVoice-Dialog on OpenDialog dataset.
+
+# Add project root to PYTHONPATH
+export PYTHONPATH=../../:$PYTHONPATH
+
+# Set bash to 'debug' mode, it will exit on:
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+stage=1
+stop_stage=6
+
+# Number of jobs for data preparation
+nj=20
+
+# We assume that you have downloaded the OpenDialog dataset
+# to download/OpenDialog and untarred the tar files in audio/en
+# and audio/zh so that the mp3 files are placed under these two directories.
+
+# Download OpenDialog at https://huggingface.co/datasets/k2-fsa/OpenDialog
+# or https://www.modelscope.cn/datasets/k2-fsa/OpenDialog
+data_dir=download/OpenDialog
+download_dir=download/
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "Stage 1: Prepare manifests for OpenDialog dataset"
+
+ python3 local/prepare_opendialog.py \
+ --dataset-path ${data_dir} \
+ --num-jobs ${nj} \
+ --output-dir data/manifests
+fi
+
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ echo "Stage 2: Add tokens to manifests"
+ for subset in ZH-dev ZH-train EN-dev EN-train;do
+ python3 -m zipvoice.bin.prepare_tokens \
+ --input-file data/manifests/opendialog_cuts_raw_${subset}.jsonl.gz \
+ --output-file data/manifests/opendialog_cuts_${subset}.jsonl.gz \
+ --tokenizer dialog
+ done
+fi
+
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ echo "Stage 3: Compute Fbank for opendialog dataset"
+ # You can skip this step and use `--on-the-fly-feats 1` in training stage
+ for subset in ZH-dev ZH-train EN-dev EN-train;do
+ python3 -m zipvoice.bin.compute_fbank \
+ --source-dir data/manifests \
+ --dest-dir data/fbank \
+ --dataset opendialog \
+ --subset ${subset} \
+ --num-jobs ${nj}
+ done
+fi
+
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "Stage 4: Download tokens file, pretrained models"
+ # Uncomment this line to use HF mirror
+ # export HF_ENDPOINT=https://hf-mirror.com
+
+ # The token file is obtained by extending some tokens
+ # on the bases of the Emilia token file.
+ mkdir -p ${download_dir}
+ hf_repo=k2-fsa/ZipVoice
+ huggingface-cli download \
+ --local-dir ${download_dir} \
+ ${hf_repo} \
+ zipvoice_dialog/tokens.txt
+
+ # Pre-trained ZipVoice model is required as
+ # the initialization model.
+ for file in model.pt tokens.txt model.json; do
+ huggingface-cli download \
+ --local-dir ${download_dir} \
+ ${hf_repo} \
+ zipvoice/${file}
+ done
+fi
+
+
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "Stage 5: Train the ZipVoice-Dialog model"
+ python3 -m zipvoice.bin.train_zipvoice_dialog \
+ --world-size 8 \
+ --use-fp16 1 \
+ --base-lr 0.0001 \
+ --max-duration 500 \
+ --checkpoint ${download_dir}/zipvoice/model.pt \
+ --model-config ${download_dir}/zipvoice/model.json \
+ --token-file ${download_dir}/zipvoice_dialog/tokens.txt \
+ --dataset opendialog \
+ --manifest-dir data/fbank \
+ --exp-dir exp/zipvoice_dialog_opendialog
+fi
+
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+ echo "Stage 6: Average the checkpoints for ZipVoice"
+ python3 -m zipvoice.bin.generate_averaged_model \
+ --iter 60000 \
+ --avg 2 \
+ --model-name zipvoice_dialog \
+ --exp-dir exp/zipvoice_dialog_opendialog
+ # The generated model is exp/zipvoice_dialog_opendialog/iter-60000-avg-2.pt
+fi
+
+if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
+ echo "Stage 7: Inference of the ZipVoice model"
+
+ python3 -m zipvoice.bin.infer_zipvoice_dialog \
+ --model-name zipvoice_dialog \
+ --model-dir exp/zipvoice_dialog_opendialog \
+ --checkpoint-name iter-60000-avg-2.pt \
+ --test-list test.tsv \
+ --res-dir results/test_dialog
+fi
\ No newline at end of file
diff --git a/infer.py b/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2316aacff87bc70349ebb9a5f035ca83dd7e383e
--- /dev/null
+++ b/infer.py
@@ -0,0 +1,578 @@
+from typing import List, Dict, Tuple
+import torch
+from transformers import (
+ AutoTokenizer, AutoModelForTokenClassification,
+ DataCollatorForTokenClassification, Trainer, TrainingArguments
+)
+LABEL_LIST = ["O", "B-EN", "I-EN"]
+LABEL2ID = {l:i for i,l in enumerate(LABEL_LIST)}
+ID2LABEL = {i:l for l,i in LABEL2ID.items()}
+
+model_name = "meandyou200175/detect_english"
+model_detect = AutoModelForTokenClassification.from_pretrained(
+ model_name, num_labels=len(LABEL_LIST),
+ id2label=ID2LABEL, label2id=LABEL2ID
+)
+tokenizer_detect = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+
+def tokens_to_pred_spans(offsets: List[Tuple[int,int]], pred_ids: List[int]) -> List[Tuple[int,int]]:
+ spans=[]; cur=None
+ for (start,end), lid in zip(offsets, pred_ids):
+ if start==end: continue
+ lab = ID2LABEL.get(lid,"O")
+ if lab=="B-EN":
+ if cur: spans.append(cur)
+ cur=[start,end]
+ elif lab=="I-EN":
+ if cur: cur[1]=end
+ else: cur=[start,end]
+ else:
+ if cur: spans.append(cur); cur=None
+ if cur: spans.append(cur)
+ return [tuple(x) for x in spans]
+
+def merge_close_spans(spans: List[Dict], max_gap: int = 2) -> List[Dict]:
+ if not spans:
+ return []
+ merged = [spans[0]]
+ for cur in spans[1:]:
+ prev = merged[-1]
+ if cur["start"] - prev["end"] <= max_gap:
+ # gộp lại
+ prev["end"] = cur["end"]
+ else:
+ merged.append(cur)
+ return merged
+
+
+def infer_spans(text: str, tokenizer, model, max_length: int = 256) -> List[Dict]:
+ text = text.lower()
+ enc = tokenizer(text, return_offsets_mapping=True, truncation=True,
+ max_length=max_length, return_tensors="pt")
+ offsets = enc["offset_mapping"][0].tolist()
+ with torch.no_grad():
+ out = model(**{k: v for k, v in enc.items() if k != "offset_mapping"})
+ pred_ids = out.logits.argmax(-1)[0].tolist()
+ spans = tokens_to_pred_spans(offsets, pred_ids)
+ spans = [{"start": s, "end": e} for (s, e) in spans]
+ spans = merge_close_spans(spans, max_gap=2)
+ # print(spans)
+ return spans
+
+import unicodedata
+
+def is_letter(ch: str) -> bool:
+ if not ch:
+ return False
+ # Nếu người dùng lỡ truyền vào tổ hợp có dấu (e + ◌́), chuẩn hoá về NFC:
+ ch = unicodedata.normalize("NFC", ch)
+ # Chỉ chấp nhận đúng 1 ký tự sau chuẩn hoá
+ if len(ch) != 1:
+ return False
+ # Nhóm 'L*' của Unicode: Lu, Ll, Lt, Lm, Lo
+ return unicodedata.category(ch).startswith('L')
+
+import re
+from itertools import chain
+from typing import List, Dict, Optional
+import logging
+from functools import reduce
+from piper_phonemize import phonemize_espeak
+
+class EspeakTokenizer():
+ """A tokenizer with Espeak g2p function, hỗ trợ English + Vietnamese."""
+
+ def __init__(self, token_file: Optional[str] = None, lang: str = "vi",
+ tokenizer=None, model=None):
+ self.has_tokens = False
+ self.lang = lang
+ self.detector_tokenizer = tokenizer
+ self.detector_model = model
+
+ if token_file is None:
+ logging.debug("Initialize Tokenizer without tokens file, "
+ "will fail when map to ids.")
+ return
+
+ self.token2id: Dict[str, int] = {}
+ with open(token_file, "r", encoding="utf-8") as f:
+ for line in f.readlines():
+ info = line.rstrip().split("\t")
+ token, id = info[0], int(info[1])
+ assert token not in self.token2id, token
+ self.token2id[token] = id
+ self.pad_id = self.token2id["_"]
+ self.vocab_size = len(self.token2id)
+ self.has_tokens = True
+
+ @staticmethod
+ def _flatten(phs):
+ """Phẳng hóa list-of-lists (hoặc trả lại list nếu đã phẳng)."""
+ if not phs:
+ return []
+ if isinstance(phs[0], (list, tuple)):
+ return list(chain.from_iterable(phs))
+ return list(phs)
+
+ def g2p_chunk(self, text: str, lang: str):
+ tokens = []
+ start = 0
+ for t in text:
+ if is_letter(t):
+ break
+ start = start + 1
+
+ # Giữ lại: khoảng trắng (\s+), từ (\w+), ký tự khác [^\w\s]
+ if start > 0 :
+ tokens.extend(self._flatten(text[0:start]))
+ phs = phonemize_espeak(text[start:], lang) # có thể trả về list-of-lists
+ tokens.extend(self._flatten(phs))
+ return tokens
+
+ def g2p(self, text: str) -> List[str]:
+ """Tách text thành spans EN/VI rồi phonemize tương ứng, bảo toàn khoảng trắng/dấu câu."""
+ try:
+ # Fallback: không có detector => phonemize toàn chuỗi theo self.lang,
+ # nhưng qua g2p_chunk để không mất khoảng trắng/dấu câu.
+ if self.detector_tokenizer is None or self.detector_model is None:
+ return self.g2p_chunk(text, self.lang)
+
+ spans = infer_spans(text, self.detector_tokenizer, self.detector_model)
+ spans = sorted(spans, key=lambda x: x["start"])
+
+ tokens_all = []
+ last = 0
+ for sp in spans:
+ s, e = sp["start"], sp["end"]
+ # phần trước đoạn EN -> VI
+ if s > last:
+ vi_chunk = text[last:s]
+ if vi_chunk:
+ tokens_all.extend(self.g2p_chunk(vi_chunk, "vi"))
+ # đoạn EN
+ en_chunk = text[s:e]
+ if en_chunk:
+ tokens_all.extend([" "])
+ tokens_all.extend(self.g2p_chunk(en_chunk, "en"))
+ last = e
+
+ # phần còn lại sau EN -> VI
+ if last < len(text):
+ vi_chunk = text[last:]
+ if vi_chunk:
+ tokens_all.extend(self.g2p_chunk(vi_chunk, "vi"))
+
+ return tokens_all
+
+ except Exception as ex:
+ logging.warning(f"Tokenization of mixed {self.lang} texts failed: {ex}")
+ return []
+ def texts_to_token_ids(
+ self,
+ texts: List[str],
+ ) -> List[List[int]]:
+ return self.tokens_to_token_ids(self.texts_to_tokens(texts))
+
+ def texts_to_tokens(
+ self,
+ texts: List[str],
+ ) -> List[List[str]]:
+ tokens_list = [self.g2p(texts[i]) for i in range(len(texts))]
+ return tokens_list
+
+ def tokens_to_token_ids(
+ self,
+ tokens_list: List[List[str]],
+ ) -> List[List[int]]:
+ assert self.has_tokens, "Please initialize Tokenizer with a tokens file."
+
+ token_ids_list = []
+
+ for tokens in tokens_list:
+ token_ids = []
+ for t in tokens:
+ if t not in self.token2id:
+ logging.debug(f"Skip OOV {t}")
+ continue
+ token_ids.append(self.token2id[t])
+
+ token_ids_list.append(token_ids)
+
+ return token_ids_list
+import re # <-- thêm
+import random
+import datetime as dt
+import json
+import logging
+import os
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import safetensors.torch
+import torch
+import torchaudio
+from huggingface_hub import hf_hub_download
+from lhotse.utils import fix_random_seed
+from vocos import Vocos
+
+from zipvoice.models.zipvoice import ZipVoice
+from zipvoice.models.zipvoice_distill import ZipVoiceDistill
+# from zipvoice.tokenizer.tokenizer import EmiliaTokenizer, EspeakTokenizer, LibriTTSTokenizer, SimpleTokenizer, SimpleTokenizer2
+from zipvoice.utils.checkpoint import load_checkpoint
+from zipvoice.utils.common import AttributeDict
+from zipvoice.utils.feature import VocosFbank
+def load_vocab(file_path):
+ """Đọc file vocab dạng char id -> trả về dict {id: char}"""
+ id2char = {}
+ with open(file_path, "r", encoding="utf-8") as f:
+ for line in f:
+ if not line.strip():
+ continue
+ # bỏ \n nhưng giữ lại space đầu dòng
+ line = line.rstrip("\n")
+ parts = line.split("\t")
+ if len(parts) != 2:
+ continue # bỏ qua dòng lỗi
+ char, idx = parts
+ id2char[int(idx)] = char
+ return id2char
+
+
+def tokens_to_text(tokens, id2char):
+ """Chuyển list token về string"""
+ return "".join(id2char.get(t, "") for t in tokens)
+
+def get_vocoder(vocos_local_path: Optional[str] = None):
+ if vocos_local_path:
+ vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
+ state_dict = torch.load(
+ f"{vocos_local_path}/pytorch_model.bin",
+ weights_only=True,
+ map_location="cpu",
+ )
+ vocoder.load_state_dict(state_dict)
+ else:
+ vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
+ return vocoder
+
+
+HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
+MODEL_DIR = {
+ "zipvoice": "zipvoice",
+ "zipvoice_distill": "zipvoice_distill",
+}
+
+model_dir="zipvoice_finetune/"
+checkpoint_name="iter-525000-avg-2.pt"
+# checkpoint_name="model.pt"
+model_dir = Path(model_dir)
+model_ckpt = model_dir / checkpoint_name
+model_config_path = model_dir / "model.json"
+token_file = model_dir / "tokens.txt"
+
+
+tokenizer = EspeakTokenizer(token_file=token_file, tokenizer=tokenizer_detect, model=model_detect)
+
+
+tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
+
+with open(model_config_path, "r") as f:
+ model_config = json.load(f)
+
+# --- Init model ---
+
+model = ZipVoice(**model_config["model"], **tokenizer_config)
+
+if str(model_ckpt).endswith(".safetensors"):
+ safetensors.torch.load_model(model, model_ckpt)
+else:
+ load_checkpoint(filename=model_ckpt, model=model, strict=True)
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = model.to(device).eval()
+
+# --- Vocoder & features ---
+vocoder = get_vocoder(None).to(device).eval()
+feature_extractor = VocosFbank()
+sampling_rate = model_config["feature"]["sampling_rate"]
+import torch
+import numpy as np
+
+import torch
+import numpy as np
+def score_tokens(A):
+ B = [9, 14, 18, 21, 27, 33, 37, 39, 42, 45, 50, 51, 52, 54, 58, 59, 61, 62, 63, 69, 73, 74, 79, 85, 99, 100, 102, 105, 119, 120, 121, 122, 123, 124, 141, 143, 144, 145, 146, 157, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 349, 350, 353, 356, 357, 358, 359]
+
+ total_score = 0
+ # Thêm 3 vào đầu và cuối
+ tokens = [3] + A + [3]
+
+ # Tách chuỗi theo số 3
+ segment = []
+ for t in tokens:
+ if t == 3:
+ if segment: # xử lý 1 đoạn
+ count = 0
+ for i in range(len(segment) - 1):
+ if (segment[i] in B and segment[i+1] not in B):
+ # print(f"{segment[i]} in B and {segment[i+1]} not in B)")
+ count += 1
+ if segment[-1] in B:
+ # print(f"{segment[-1]} in B")
+ count += 1
+ if count > 0:
+ total_score += 1 + (count - 1) * 0.5
+ segment = []
+ else:
+ segment.append(t)
+
+ return total_score
+
+
+def trim_leading_silence_torch(
+ wav: torch.Tensor,
+ sample_rate: int,
+ silence_thresh: float = 0.05,
+ chunk_ms: int = 10,
+ extend_ms: int = 20,
+ ratio: float = 0.95, # % sample phải dưới ngưỡng để coi là im lặng
+):
+ wav_np = wav.squeeze(0).cpu().numpy().astype(np.float32)
+ norm_wav = wav_np / (np.max(np.abs(wav_np)) + 1e-8)
+
+ chunk_size = int(sample_rate * chunk_ms / 1000)
+ total_chunks = int(len(norm_wav) / chunk_size)
+
+ start_idx = 0
+ for i in range(total_chunks):
+ chunk = norm_wav[i * chunk_size : (i + 1) * chunk_size]
+ # Tính tỷ lệ sample dưới ngưỡng
+ silent_ratio = np.mean(np.abs(chunk) < silence_thresh)
+ if silent_ratio < ratio: # nếu ít hơn 95% sample im lặng → coi là có tiếng
+ start_idx = max(0, i * chunk_size - int(sample_rate * extend_ms / 1000))
+ break
+
+ return wav[:, start_idx:]
+
+
+
+
+@torch.inference_mode()
+def run_zipvoice(
+ model_name="zipvoice",
+ model_dir="zipvoice_finetune",
+ checkpoint_name="model.pt",
+ vocoder_path=None,
+ tokenizer_name="emilia",
+ lang="en-us",
+ test_list=None, # path to tsv file
+ prompt_wav=None,
+ prompt_text=None,
+ text=None,
+ res_dir="results",
+ res_wav_path="result.wav",
+ guidance_scale=None,
+ num_step=None,
+ feat_scale=0.1,
+ speed=1.0,
+ t_shift=0.5,
+ target_rms=0.1,
+ seed=666,
+):
+ text = text.lower()
+ # --- Default settings per model ---
+ model_defaults = {
+ "zipvoice": {"num_step": 16, "guidance_scale": 1.0},
+ "zipvoice_distill": {"num_step": 8, "guidance_scale": 3.0},
+ }
+ # sửa cách gán mặc định (không dùng locals() nữa)
+ if guidance_scale is None:
+ guidance_scale = model_defaults.get(model_name, {}).get("guidance_scale", 1.0)
+ if num_step is None:
+ num_step = model_defaults.get(model_name, {}).get("num_step", 16)
+
+ # --- Check inputs ---
+ assert (test_list is not None) ^ ((prompt_wav and prompt_text and text) is not None), \
+ "Cần test_list hoặc (prompt_wav + prompt_text + text)"
+
+ fix_random_seed(seed)
+
+ # --- Load tokenizer, model, vocoder, features ... (phần này giữ nguyên) ---
+ # [giữ nguyên toàn bộ phần load tokenizer/model/vocoder/feature_extractor/sampling_rate]
+
+ # ---------------------------
+ # NEW: Hàm chia đoạn văn bản
+ # ---------------------------
+ def split_text_into_chunks(s: str, min_chars: int = 15, max_chars: int = 30):
+ """
+ Chia theo dấu ',' hoặc '.', sau đó gộp/xẻ để mỗi đoạn dài trong [min_chars, max_chars].
+ Không cắt giữa từ.
+ """
+ # normalize khoảng trắng
+ s = re.sub(r"\s+", " ", (s or "").strip())
+ if not s:
+ return []
+
+ # tách theo dấu , hoặc .
+ raw_segs = [seg.strip() for seg in re.split(r"\s*[.,]\s*", s) if seg.strip()]
+
+ chunks = []
+ i = 0
+ while i < len(raw_segs):
+ cur = raw_segs[i]
+ i += 1
+
+ # gộp tiếp theo nếu cur quá ngắn
+ while len(cur) < min_chars and i < len(raw_segs):
+ cur = (cur + ", " + raw_segs[i]).strip()
+ i += 1
+
+ # nếu cur quá dài, xẻ theo từ để <= max_chars
+ if len(cur) > max_chars:
+ words = cur.split()
+ buf = []
+ cur_len = 0
+ for w in words:
+ # +1 cho khoảng trắng nếu cần
+ add_len = len(w) if cur_len == 0 else len(w) + 1
+ if cur_len + add_len <= max_chars:
+ buf.append(w)
+ cur_len += add_len
+ else:
+ # đóng lại một chunk
+ part = ", ".join(buf).strip()
+ if part:
+ chunks.append(part)
+ # bắt đầu chunk mới
+ buf = [w]
+ cur_len = len(w)
+ # phần còn lại
+ last = " ".join(buf).strip()
+ if last:
+ # nếu phần cuối vẫn < min_chars và có thể gộp với chunk trước đó
+ if len(last) < min_chars and chunks:
+ merged = (chunks[-1] + " " + last).strip()
+ if len(merged) <= max_chars:
+ chunks[-1] = merged
+ else:
+ chunks.append(last) # đành chấp nhận (nhưng thường ít gặp)
+ else:
+ chunks.append(last)
+ else:
+ chunks.append(cur)
+
+ # vòng tinh chỉnh cuối: nếu chunk cuối quá ngắn, gộp vào trước đó
+ if len(chunks) >= 2 and len(chunks[-1]) < min_chars:
+ merged = (chunks[-2] + ", " + chunks[-1]).strip()
+ if len(merged) <= max_chars:
+ chunks[-2] = merged
+ chunks.pop()
+ # print(chunks)
+ final_chunk = []
+ for chunk in chunks:
+ chunk = ", " + chunk + ","
+ final_chunk.append(chunk)
+ return final_chunk
+
+ # ---------------------------
+ # MODIFIED: generate_sentence synth theo từng đoạn + nối lại
+ # ---------------------------
+ def generate_sentence(save_path, prompt_text, prompt_wav, text):
+ # chuẩn hoá & chia đoạn
+ segments = split_text_into_chunks(text, min_chars=50, max_chars=200)
+ if not segments:
+ # không có gì để nói: xuất file rỗng 0.2s
+ silence = torch.zeros((1, int(0.2 * sampling_rate)))
+ torchaudio.save(save_path, silence, sample_rate=sampling_rate)
+ return
+
+ # chuẩn bị prompt (làm 1 lần)
+ prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
+ prompt_wav_tensor, sr = torchaudio.load(prompt_wav)
+ if sr != sampling_rate:
+ prompt_wav_tensor = torchaudio.transforms.Resample(sr, sampling_rate)(prompt_wav_tensor)
+ prompt_rms_val = torch.sqrt(torch.mean(prompt_wav_tensor**2))
+ if prompt_rms_val < target_rms:
+ prompt_wav_tensor *= target_rms / prompt_rms_val
+
+ prompt_features = feature_extractor.extract(
+ prompt_wav_tensor, sampling_rate=sampling_rate
+ ).to(device)
+ prompt_features = prompt_features.unsqueeze(0) * feat_scale
+ prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
+ # print(prompt_features_lens)
+
+ num_space_prompt = prompt_text.count(" ")
+
+ # khoảng lặng 0.2s
+
+
+ gap_duration = random.uniform(0.17, 0.2) # số ngẫu nhiên từ 0.17 đến 0.2
+ gap = torch.zeros((1, int(gap_duration * sampling_rate)))
+
+ wav_parts = []
+ print("segments",segments)
+ for idx, seg in enumerate(segments):
+ # print(seg)
+ num_space_text = seg.count(" ")
+ tokens = tokenizer.texts_to_token_ids([seg])
+ # print(tokens)
+ score = score_tokens(tokens[0])
+ # print(score)
+ # print(prompt_tokens)
+ score_prompt = score_tokens(prompt_tokens[0])
+ # print(score_prompt)
+ vocab_file = "zipvoice_finetune/tokens.txt" # file txt dạng bạn đưa
+
+ id2char = load_vocab(vocab_file)
+ decoded_text = tokens_to_text(tokens[0], id2char)
+
+ print(decoded_text)
+
+ pred_features, _, _, _ = model.sample(
+ num_space_text=[num_space_text],
+ num_space_prompt=[num_space_prompt],
+ tokens=tokens,
+ prompt_tokens=prompt_tokens,
+ prompt_features=prompt_features,
+ prompt_features_lens=prompt_features_lens,
+ speed= speed,
+ t_shift= t_shift,
+ duration="predict",
+ num_step= num_step,
+ guidance_scale= guidance_scale,
+ )
+ pred_features = pred_features.permute(0, 2, 1) / feat_scale
+ wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
+
+ # phục hồi mức âm lượng tương quan prompt
+ if prompt_rms_val < target_rms:
+ wav *= prompt_rms_val / target_rms
+ wav = trim_leading_silence_torch(
+ wav, sample_rate=sampling_rate, silence_thresh=0.086, chunk_ms=10, extend_ms=20
+ )
+ wav_parts.append(wav.cpu())
+ if idx < len(segments) - 1:
+ wav_parts.append(gap) # chèn khoảng lặng
+
+ final_wav = torch.cat(wav_parts, dim=-1) # [1, T_total]
+ torchaudio.save(save_path, final_wav, sample_rate=sampling_rate)
+
+ # --- generate_list giữ nguyên: gọi generate_sentence nên tự áp dụng chia đoạn ---
+ def generate_list(res_dir, test_list):
+ os.makedirs(res_dir, exist_ok=True)
+ with open(test_list, "r", encoding="utf-8") as fr:
+ for i, line in enumerate(fr):
+ wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
+ save_path = f"{res_dir}/{wav_name}.wav"
+ generate_sentence(save_path, prompt_text, prompt_wav, text)
+
+ # --- Run ---
+ if test_list:
+ generate_list(res_dir, test_list)
+ else:
+ generate_sentence(res_wav_path, prompt_text, prompt_wav, text)
+
+ print("✅ Hoàn thành!")
+ return text,
diff --git a/proccess_wav.py b/proccess_wav.py
new file mode 100644
index 0000000000000000000000000000000000000000..5544e02a5c41a9d9ba6a1912b2f145cf0daec302
--- /dev/null
+++ b/proccess_wav.py
@@ -0,0 +1,364 @@
+from typing import List, Tuple
+import numpy as np
+from pydub import AudioSegment
+import os
+from chunkformer import ChunkFormerModel
+from clearvoice import ClearVoice
+# ======================= ASR + CLEARVOICE + AUDIO PROCESSING =======================
+
+ASR_MODEL = None
+CLEARVOICE_MODEL = None
+REF_AUDIO_CACHE = {} # cache: đường dẫn input -> đường dẫn output đã xử lý
+
+
+def get_asr_model() -> ChunkFormerModel:
+ """Lazy-load ChunkFormer (ASR, chạy trên CPU)."""
+ global ASR_MODEL
+ if ASR_MODEL is None:
+ ASR_MODEL = ChunkFormerModel.from_pretrained("khanhld/chunkformer-ctc-large-vie")
+ return ASR_MODEL
+
+
+def get_clearvoice_model() -> ClearVoice:
+ """Lazy-load ClearVoice để khử nhiễu ref audio."""
+ global CLEARVOICE_MODEL
+ if CLEARVOICE_MODEL is None:
+ CLEARVOICE_MODEL = ClearVoice(
+ task="speech_enhancement",
+ model_names=["MossFormer2_SE_48K"],
+ )
+ return CLEARVOICE_MODEL
+
+
+def find_silent_regions(
+ audio: AudioSegment,
+ silence_thresh: float = 0.05, # biên độ sau chuẩn hoá [-1, 1]
+ chunk_ms: int = 10,
+ min_silence_len: int = 200,
+) -> List[Tuple[int, int]]:
+ """
+ Tìm các khoảng lặng (start_ms, end_ms) trong AudioSegment dựa trên biên độ.
+ """
+ samples = np.array(audio.get_array_of_samples(), dtype=np.float32)
+ if audio.channels > 1:
+ samples = samples.reshape((-1, audio.channels)).mean(axis=1)
+
+ norm = samples / (2 ** (audio.sample_width * 8 - 1))
+ sr = audio.frame_rate
+
+ chunk_size = max(1, int(sr * chunk_ms / 1000))
+ total_chunks = len(norm) // chunk_size
+
+ silent_regions: List[Tuple[int, int]] = []
+ start = None
+ for i in range(total_chunks):
+ chunk = norm[i * chunk_size: (i + 1) * chunk_size]
+ if chunk.size == 0:
+ continue
+
+ if np.all((chunk > -silence_thresh) & (chunk < silence_thresh)):
+ if start is None:
+ start = i
+ else:
+ if start is not None:
+ dur = (i - start) * chunk_ms
+ if dur >= min_silence_len:
+ silent_regions.append((start * chunk_ms, i * chunk_ms))
+ start = None
+
+ if start is not None:
+ dur = (total_chunks - start) * chunk_ms
+ if dur >= min_silence_len:
+ silent_regions.append((start * chunk_ms, total_chunks * chunk_ms))
+
+ return silent_regions
+
+
+def trim_leading_trailing_silence(
+ audio: AudioSegment,
+ silence_thresh: float = 0.05,
+ chunk_ms: int = 10,
+ min_silence_len: int = 200,
+) -> AudioSegment:
+ """
+ Bỏ khoảng lặng đầu/cuối file.
+ """
+ duration = len(audio)
+ silent_regions = find_silent_regions(
+ audio,
+ silence_thresh=silence_thresh,
+ chunk_ms=chunk_ms,
+ min_silence_len=min_silence_len,
+ )
+
+ if not silent_regions:
+ return audio
+
+ start_trim = 0
+ end_trim = duration
+
+ # khoảng lặng đầu file
+ first_start, first_end = silent_regions[0]
+ if first_start <= 0:
+ start_trim = max(start_trim, first_end)
+
+ # khoảng lặng cuối file
+ last_start, last_end = silent_regions[-1]
+ if last_end >= duration:
+ end_trim = min(end_trim, last_start)
+
+ return audio[start_trim:end_trim]
+
+
+def compress_internal_silence(
+ audio: AudioSegment,
+ max_silence_ms: int = 300,
+ silence_thresh: float = 0.05,
+ chunk_ms: int = 10,
+ min_silence_len: int = 50,
+) -> AudioSegment:
+ """
+ Rút ngắn khoảng lặng giữa file:
+ - Khoảng lặng <= max_silence_ms: giữ nguyên
+ - Khoảng lặng > max_silence_ms: cắt còn max_silence_ms
+ """
+ duration = len(audio)
+ silent_regions = find_silent_regions(
+ audio,
+ silence_thresh=silence_thresh,
+ chunk_ms=chunk_ms,
+ min_silence_len=min_silence_len,
+ )
+ if not silent_regions:
+ return audio
+
+ new_audio = AudioSegment.silent(duration=0, frame_rate=audio.frame_rate)
+ cursor = 0
+
+ for s_start, s_end in silent_regions:
+ # phần có tiếng nói trước khoảng lặng
+ if s_start > cursor:
+ new_audio += audio[cursor:s_start]
+
+ silence_len = s_end - s_start
+ if silence_len <= max_silence_ms:
+ new_audio += audio[s_start:s_end]
+ else:
+ new_audio += audio[s_start: s_start + max_silence_ms]
+
+ cursor = s_end
+
+ # phần còn lại sau khoảng lặng cuối
+ if cursor < duration:
+ new_audio += audio[cursor:]
+
+ return new_audio
+
+
+def select_subsegment_by_silence(
+ audio: AudioSegment,
+ min_len_ms: int = 5000,
+ max_len_ms: int = 10000,
+ silence_thresh: float = 0.05,
+ chunk_ms: int = 10,
+ min_silence_len: int = 200,
+) -> AudioSegment:
+ """
+ Nếu audio > max_len_ms, chọn 1 đoạn dài trong khoảng [min_len_ms, max_len_ms],
+ cắt tại điểm nằm trong khoảng lặng để tránh cắt dính giọng nói.
+ """
+ duration = len(audio)
+ if duration <= max_len_ms:
+ return audio
+
+ silent_regions = find_silent_regions(
+ audio,
+ silence_thresh=silence_thresh,
+ chunk_ms=chunk_ms,
+ min_silence_len=min_silence_len,
+ )
+
+ if not silent_regions:
+ # không tìm được khoảng lặng -> lấy đoạn giữa
+ target_len = min(max_len_ms, duration)
+ start = max(0, (duration - target_len) // 2)
+ end = start + target_len
+ return audio[start:end]
+
+ # boundary là midpoint của khoảng lặng (chắc chắn nằm trong vùng im lặng)
+ boundaries = [0]
+ for s_start, s_end in silent_regions:
+ mid = (s_start + s_end) // 2
+ if 0 < mid < duration:
+ boundaries.append(mid)
+ boundaries.append(duration)
+ boundaries = sorted(set(boundaries))
+
+ # ưu tiên đoạn đầu tiên thỏa 5–10s
+ for i in range(len(boundaries)):
+ for j in range(i + 1, len(boundaries)):
+ seg_len = boundaries[j] - boundaries[i]
+ if min_len_ms <= seg_len <= max_len_ms:
+ return audio[boundaries[i]:boundaries[j]]
+
+ # nếu không có đoạn nào nằm trọn trong [min, max], chọn đoạn gần max_len nhất
+ best_i, best_j, best_diff = 0, None, None
+ for i in range(len(boundaries)):
+ for j in range(i + 1, len(boundaries)):
+ seg_len = boundaries[j] - boundaries[i]
+ if seg_len >= min_len_ms:
+ diff = abs(seg_len - max_len_ms)
+ if best_diff is None or diff < best_diff:
+ best_diff = diff
+ best_i, best_j = i, j
+
+ if best_j is not None:
+ return audio[boundaries[best_i]:boundaries[best_j]]
+
+ # fallback cuối cùng
+ target_len = min(max_len_ms, duration)
+ start = max(0, (duration - target_len) // 2)
+ end = start + target_len
+ return audio[start:end]
+
+
+def enhance_ref_audio(input_path: str) -> str:
+ """
+ Pipeline xử lý WAV cho TTS:
+ - ClearVoice khử nhiễu
+ - Bỏ khoảng lặng đầu/cuối
+ - Rút ngắn khoảng lặng giữa > 0.3s thành 0.3s
+ - Nếu audio > 10s: chọn 1 đoạn 5–10s, cắt tại khoảng lặng
+ Trả về đường dẫn file wav đã xử lý.
+ """
+ if not input_path:
+ raise ValueError("No input audio path for enhancement.")
+
+ # cache để cùng 1 file không phải xử lý nhiều lần
+ if input_path in REF_AUDIO_CACHE:
+ return REF_AUDIO_CACHE[input_path]
+
+ cv = get_clearvoice_model()
+
+ # 1) khử nhiễu
+ try:
+ cv_out = cv(input_path=input_path, online_write=False)
+ base = os.path.basename(input_path)
+ name, ext = os.path.splitext(base)
+ if not ext:
+ ext = ".wav"
+ denoised_path = os.path.join(os.path.dirname(input_path), f"{name}_denoised{ext}")
+ cv.write(cv_out, output_path=denoised_path)
+ except Exception as e:
+ print(f"[ClearVoice] Error during denoising, fallback to original: {e}")
+ denoised_path = input_path
+
+ # 2) pydub xử lý khoảng lặng + length
+ audio = AudioSegment.from_file(denoised_path)
+
+ # bỏ khoảng lặng đầu/cuối
+ audio = trim_leading_trailing_silence(audio)
+
+ # rút ngắn khoảng lặng giữa
+ audio = compress_internal_silence(audio, max_silence_ms=300)
+
+ # nếu >10s thì chọn đoạn trong khoảng 5–10s
+ audio = select_subsegment_by_silence(audio, min_len_ms=5000, max_len_ms=10000)
+
+ # 3) ghi ra file mới
+ enhanced_path = os.path.join(os.path.dirname(denoised_path), f"{name}_enhanced.wav")
+ audio.export(enhanced_path, format="wav")
+
+ REF_AUDIO_CACHE[input_path] = enhanced_path
+ return enhanced_path
+
+def split_audio_by_silence(
+ audio: AudioSegment,
+ silence_thresh: float = 0.05,
+ chunk_ms: int = 10,
+ min_silence_len: int = 200,
+ min_segment_len: int = 200,
+) -> List[Tuple[int, int]]:
+ """
+ Từ AudioSegment, trả về các đoạn có tiếng nói (non-silent)
+ được tách bằng khoảng lặng.
+ """
+ duration = len(audio)
+ silent_regions = find_silent_regions(
+ audio,
+ silence_thresh=silence_thresh,
+ chunk_ms=chunk_ms,
+ min_silence_len=min_silence_len,
+ )
+
+ segments: List[Tuple[int, int]] = []
+ cur_start = 0
+
+ for s_start, s_end in silent_regions:
+ if cur_start < s_start:
+ if s_start - cur_start >= min_segment_len:
+ segments.append((cur_start, s_start))
+ cur_start = s_end
+
+ if cur_start < duration and duration - cur_start >= min_segment_len:
+ segments.append((cur_start, duration))
+
+ # nếu không tìm được đoạn nào, lấy cả file
+ if not segments:
+ segments.append((0, duration))
+
+ return segments
+
+
+def transcribe_ref_audio(audio_path: str) -> str:
+ """
+ ASR theo yêu cầu:
+ - Cắt âm thanh theo khoảng lặng
+ - ASR từng đoạn
+ - Nối text bằng dấu phẩy
+ """
+ if not audio_path:
+ raise ValueError("No audio path for ASR.")
+
+ model = get_asr_model()
+ audio = AudioSegment.from_file(audio_path)
+ segments = split_audio_by_silence(audio)
+
+ texts = []
+ base, _ = os.path.splitext(audio_path)
+
+ for idx, (start_ms, end_ms) in enumerate(segments):
+ seg_audio = audio[start_ms:end_ms]
+ seg_path = f"{base}_seg_{idx}.wav"
+ seg_audio.export(seg_path, format="wav")
+
+ try:
+ transcription = model.endless_decode(
+ audio_path=seg_path,
+ chunk_size=32,
+ left_context_size=0,
+ right_context_size=0,
+ total_batch_duration=400,
+ return_timestamps=False,
+ )
+ except TypeError:
+ transcription = model.endless_decode(
+ audio_path=seg_path,
+ chunk_size=32,
+ left_context_size=0,
+ right_context_size=0,
+ total_batch_duration=400,
+ )
+
+ if isinstance(transcription, str):
+ text = transcription
+ else:
+ text = str(transcription)
+
+ text = text.strip()
+ if text:
+ texts.append(text)
+
+ return ", ".join(texts)
+
+
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..22df24d65d1a4a9ead1252f76af633e1e27212ff
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,5 @@
+[tool.isort]
+profile = "black"
+
+[tool.black]
+line-length = 88
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..22b5752bf1ce93627914c10cdc702c19e89245a1
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,23 @@
+--find-links https://k2-fsa.github.io/icefall/piper_phonemize.html
+transformers==4.57.1
+torch
+torchaudio
+torchcodec
+numpy
+lhotse
+huggingface_hub
+safetensors
+tensorboard
+vocos
+
+# Normalization
+cn2an
+inflect
+
+# Tokenization
+jieba
+piper_phonemize
+pypinyin
+setuptools<81
+chunkformer
+clearvoice
\ No newline at end of file
diff --git a/requirements_eval.txt b/requirements_eval.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1e8eba00777f4c38e6a75f5ae0ce107db92cfad0
--- /dev/null
+++ b/requirements_eval.txt
@@ -0,0 +1,19 @@
+torch
+numpy
+
+# Audio processing
+librosa
+soundfile
+
+# Model
+s3prl
+pyannote.audio
+funasr
+transformers
+
+# WER
+jiwer==3.1.0
+
+# Normalization
+zhconv
+zhon
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8801cd5b234145ebf6220d1537d7814022c0512
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,55 @@
+import os
+import subprocess
+import requests
+from dotenv import load_dotenv
+
+def run_cmd(cmd):
+ print(f"🔹 Chạy lệnh: {cmd}")
+ result = subprocess.run(cmd, shell=True)
+ if result.returncode != 0:
+ raise RuntimeError(f"Lệnh thất bại: {cmd}")
+
+def download_with_token(url, dest_path, token):
+ headers = {"Authorization": f"Bearer {token}"}
+ with requests.get(url, headers=headers, stream=True) as r:
+ r.raise_for_status()
+ with open(dest_path, "wb") as f:
+ for chunk in r.iter_content(chunk_size=8192):
+ f.write(chunk)
+ print(f"✅ Đã tải: {dest_path}")
+
+def main():
+ # Load biến môi trường từ .env
+ load_dotenv()
+ token = os.getenv("HF_TOKEN")
+
+ if not token:
+ raise EnvironmentError("❌ Thiếu biến môi trường HF_TOKEN. Hãy tạo file .env với dòng:\nHF_TOKEN=hf_your_token_here")
+
+ # Đăng nhập vào Hugging Face CLI
+ run_cmd(f"huggingface-cli login --token {token}")
+
+ # Tạo thư mục chứa model
+ os.makedirs("zipvoice_finetune", exist_ok=True)
+
+ # Danh sách file cần tải
+ files = {
+ "iter-525000-avg-2.pt": "https://huggingface.co/datasets/meandyou200175/temp_file/resolve/main/zip/epoch-46-all-speak-600h-en-norm.pt",
+ "model.json": "https://huggingface.co/datasets/meandyou200175/temp_file/resolve/main/zip/model.json",
+ "tokens.txt": "https://huggingface.co/datasets/meandyou200175/temp_file/resolve/main/zip/tokens.txt",
+ }
+
+ for filename, url in files.items():
+ dest = os.path.join("zipvoice_finetune", filename)
+ download_with_token(url, dest, token)
+
+ # Cài đặt requirements
+ if os.path.exists("requirements.txt"):
+ run_cmd("pip install -r requirements.txt")
+ else:
+ print("⚠️ Không tìm thấy requirements.txt")
+
+ print("\n🎉 Setup hoàn tất!")
+
+if __name__ == "__main__":
+ main()
diff --git a/zipvoice/__init__.py b/zipvoice/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e221d75a4b6581539366e7a49b6856bda71b8c7
--- /dev/null
+++ b/zipvoice/__init__.py
@@ -0,0 +1,7 @@
+import warnings
+
+warnings.filterwarnings(
+ "ignore",
+ category=UserWarning,
+ message="pkg_resources is deprecated as an API.*",
+)
diff --git a/zipvoice/bin/compute_fbank.py b/zipvoice/bin/compute_fbank.py
new file mode 100644
index 0000000000000000000000000000000000000000..b707bbbd0ae9470c07ad8cc5e902e4ace88b5eed
--- /dev/null
+++ b/zipvoice/bin/compute_fbank.py
@@ -0,0 +1,272 @@
+#!/usr/bin/env python3
+# Copyright 2024-2025 Xiaomi Corp. (authors: Wei Kang
+# Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+ python3 -m zipvoice.bin.compute_fbank \
+ --source-dir data/manifests \
+ --dest-dir data/fbank \
+ --dataset libritts \
+ --subset dev-other \
+ --sampling-rate 24000 \
+ --num-jobs 20
+
+The input would be data/manifests/libritts-cuts_dev-other.jsonl.gz or
+ (libritts_supervisions_dev-other.jsonl.gz and librittsrecordings_dev-other.jsonl.gz)
+
+The output would be data/fbank/libritts-cuts_dev-other.jsonl.gz
+"""
+
+
+import argparse
+import logging
+from concurrent.futures import ProcessPoolExecutor as Pool
+from pathlib import Path
+
+import lhotse
+import torch
+from lhotse import CutSet, LilcomChunkyWriter, load_manifest_lazy
+
+from zipvoice.utils.common import str2bool
+from zipvoice.utils.feature import VocosFbank
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+lhotse.set_audio_duration_mismatch_tolerance(0.1)
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--sampling-rate",
+ type=int,
+ default=24000,
+ help="The target sampling rate, the audio will be resampled to it.",
+ )
+
+ parser.add_argument(
+ "--type",
+ type=str,
+ default="vocos",
+ help="fbank type",
+ )
+
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ help="Dataset name.",
+ )
+
+ parser.add_argument(
+ "--subset",
+ type=str,
+ help="The subset of the dataset.",
+ )
+
+ parser.add_argument(
+ "--source-dir",
+ type=str,
+ default="data/manifests",
+ help="The source directory of manifest files.",
+ )
+
+ parser.add_argument(
+ "--dest-dir",
+ type=str,
+ default="data/fbank",
+ help="The destination directory of manifest files.",
+ )
+
+ parser.add_argument(
+ "--split-cuts",
+ type=str2bool,
+ default=False,
+ help="Whether to use splited cuts.",
+ )
+
+ parser.add_argument(
+ "--split-begin",
+ type=int,
+ help="Start idx of splited cuts.",
+ )
+
+ parser.add_argument(
+ "--split-end",
+ type=int,
+ help="End idx of splited cuts.",
+ )
+
+ parser.add_argument(
+ "--batch-duration",
+ type=int,
+ default=1000,
+ help="The batch duration when computing the features.",
+ )
+
+ parser.add_argument(
+ "--num-jobs",
+ type=int,
+ default=20,
+ help="The number of extractor workers.",
+ )
+
+ return parser.parse_args()
+
+
+def compute_fbank_split_single(params, idx):
+ logging.info(
+ f"Computing features for {idx}-th split of "
+ f"{params.dataset} dataset {params.subset} subset"
+ )
+ lhotse.set_audio_duration_mismatch_tolerance(0.1) # for emilia
+ src_dir = Path(params.source_dir)
+ output_dir = Path(params.dest_dir)
+
+ if not src_dir.exists():
+ logging.error(f"{src_dir} not exists")
+ return
+
+ if not output_dir.exists():
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ num_digits = 8
+ if params.type == "vocos":
+ extractor = VocosFbank()
+ else:
+ raise NotImplementedError(f"{params.type} is not supported")
+
+ prefix = params.dataset
+ subset = params.subset
+ suffix = "jsonl.gz"
+
+ idx = f"{idx}".zfill(num_digits)
+ cuts_filename = f"{prefix}_cuts_{subset}.{idx}.{suffix}"
+
+ if (src_dir / cuts_filename).is_file():
+ logging.info(f"Loading manifests {src_dir / cuts_filename}")
+ cut_set = load_manifest_lazy(src_dir / cuts_filename)
+ else:
+ logging.warning(f"Raw {cuts_filename} not exists, skipping")
+ return
+
+ cut_set = cut_set.resample(params.sampling_rate)
+
+ if (output_dir / cuts_filename).is_file():
+ logging.info(f"{cuts_filename} already exists - skipping.")
+ return
+
+ logging.info(f"Processing {subset}.{idx} of {prefix}")
+
+ cut_set = cut_set.compute_and_store_features_batch(
+ extractor=extractor,
+ storage_path=f"{output_dir}/{prefix}_feats_{subset}_{idx}",
+ num_workers=4,
+ batch_duration=params.batch_duration,
+ storage_type=LilcomChunkyWriter,
+ overwrite=True,
+ )
+ logging.info(f"Saving file to {output_dir / cuts_filename}")
+ cut_set.to_file(output_dir / cuts_filename)
+
+
+def compute_fbank_split(params):
+ if params.split_end < params.split_begin:
+ logging.warning(
+ f"Split begin should be smaller than split end, given "
+ f"{params.split_begin} -> {params.split_end}."
+ )
+
+ with Pool(max_workers=params.num_jobs) as pool:
+ futures = [
+ pool.submit(compute_fbank_split_single, params, i)
+ for i in range(params.split_begin, params.split_end)
+ ]
+ for f in futures:
+ f.result()
+ f.done()
+
+
+def compute_fbank(params):
+ logging.info(
+ f"Computing features for {params.dataset} dataset {params.subset} subset"
+ )
+ src_dir = Path(params.source_dir)
+ output_dir = Path(params.dest_dir)
+ num_jobs = params.num_jobs
+ if not output_dir.exists():
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ prefix = params.dataset
+ subset = params.subset
+ suffix = "jsonl.gz"
+
+ cut_set_name = f"{prefix}_cuts_{subset}.{suffix}"
+
+ if (src_dir / cut_set_name).is_file():
+ logging.info(f"Loading manifests {src_dir / cut_set_name}")
+ cut_set = load_manifest_lazy(src_dir / cut_set_name)
+ else:
+ recordings = load_manifest_lazy(
+ src_dir / f"{prefix}_recordings_{subset}.{suffix}"
+ )
+ supervisions = load_manifest_lazy(
+ src_dir / f"{prefix}_supervisions_{subset}.{suffix}"
+ )
+ cut_set = CutSet.from_manifests(
+ recordings=recordings,
+ supervisions=supervisions,
+ )
+
+ cut_set = cut_set.resample(params.sampling_rate)
+ if params.type == "vocos":
+ extractor = VocosFbank()
+ else:
+ raise NotImplementedError(f"{params.type} is not supported")
+
+ cuts_filename = f"{prefix}_cuts_{subset}.{suffix}"
+ if (output_dir / cuts_filename).is_file():
+ logging.info(f"{prefix} {subset} already exists - skipping.")
+ return
+ logging.info(f"Processing {subset} of {prefix}")
+
+ cut_set = cut_set.compute_and_store_features(
+ extractor=extractor,
+ storage_path=f"{output_dir}/{prefix}_feats_{subset}",
+ num_jobs=num_jobs,
+ storage_type=LilcomChunkyWriter,
+ )
+ logging.info(f"Saving file to {output_dir / cuts_filename}")
+ cut_set.to_file(output_dir / cuts_filename)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
+
+ args = get_args()
+ logging.info(vars(args))
+ if args.split_cuts:
+ compute_fbank_split(params=args)
+ else:
+ compute_fbank(params=args)
+ logging.info("Done!")
diff --git a/zipvoice/bin/generate_averaged_model.py b/zipvoice/bin/generate_averaged_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ff432b7146fef2ae095d72d46a5b8b798bd7388
--- /dev/null
+++ b/zipvoice/bin/generate_averaged_model.py
@@ -0,0 +1,229 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+This script loads checkpoints and averages them.
+
+python3 -m zipvoice.bin.generate_averaged_model \
+ --epoch 11 \
+ --avg 4 \
+ --model-name zipvoice \
+ --exp-dir exp/zipvoice
+
+It will generate a file `epoch-11-avg-14.pt` in the given `exp_dir`.
+You can later load it by `torch.load("epoch-11-avg-4.pt")`.
+"""
+
+import argparse
+import json
+import logging
+from pathlib import Path
+
+import torch
+
+from zipvoice.models.zipvoice import ZipVoice
+from zipvoice.models.zipvoice_dialog import ZipVoiceDialog, ZipVoiceDialogStereo
+from zipvoice.models.zipvoice_distill import ZipVoiceDistill
+from zipvoice.tokenizer.tokenizer import SimpleTokenizer
+from zipvoice.utils.checkpoint import (
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+)
+from zipvoice.utils.common import AttributeDict
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=11,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=4,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' or --iter",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="exp/zipvoice",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--model-name",
+ type=str,
+ default="zipvoice",
+ choices=[
+ "zipvoice",
+ "zipvoice_distill",
+ "zipvoice_dialog",
+ "zipvoice_dialog_stereo",
+ ],
+ help="The model type to be averaged. ",
+ )
+
+ return parser
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+ params = AttributeDict()
+ params.update(vars(args))
+ params.exp_dir = Path(params.exp_dir)
+
+ with open(params.exp_dir / "model.json", "r") as f:
+ model_config = json.load(f)
+
+ # Any tokenizer can be used here.
+ # Use SimpleTokenizer for simplicity.
+ tokenizer = SimpleTokenizer(token_file=params.exp_dir / "tokens.txt")
+ if params.model_name in ["zipvoice", "zipvoice_distill"]:
+ tokenizer_config = {
+ "vocab_size": tokenizer.vocab_size,
+ "pad_id": tokenizer.pad_id,
+ }
+ elif params.model_name in ["zipvoice_dialog", "zipvoice_dialog_stereo"]:
+ tokenizer_config = {
+ "vocab_size": tokenizer.vocab_size,
+ "pad_id": tokenizer.pad_id,
+ "spk_a_id": tokenizer.spk_a_id,
+ "spk_b_id": tokenizer.spk_b_id,
+ }
+
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ logging.info("Script started")
+
+ params.device = torch.device("cpu")
+ logging.info(f"Device: {params.device}")
+
+ logging.info("About to create model")
+ if params.model_name == "zipvoice":
+ model = ZipVoice(
+ **model_config["model"],
+ **tokenizer_config,
+ )
+ elif params.model_name == "zipvoice_distill":
+ model = ZipVoiceDistill(
+ **model_config["model"],
+ **tokenizer_config,
+ )
+ elif params.model_name == "zipvoice_dialog":
+ model = ZipVoiceDialog(
+ **model_config["model"],
+ **tokenizer_config,
+ )
+ elif params.model_name == "zipvoice_dialog_stereo":
+ model = ZipVoiceDialogStereo(
+ **model_config["model"],
+ **tokenizer_config,
+ )
+ else:
+ raise ValueError(f"Unknown model name: {params.model_name}")
+
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(params.device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=params.device,
+ ),
+ strict=True,
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(params.device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=params.device,
+ ),
+ strict=True,
+ )
+ if params.iter > 0:
+ filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
+ else:
+ filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
+
+ logging.info(f"Saving the averaged checkpoint to {filename}")
+ torch.save({"model": model.state_dict()}, filename)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
+
+ main()
diff --git a/zipvoice/bin/infer_zipvoice.py b/zipvoice/bin/infer_zipvoice.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dc2aac68dcddb0bddf48724fa98e04195f7c92f
--- /dev/null
+++ b/zipvoice/bin/infer_zipvoice.py
@@ -0,0 +1,614 @@
+#!/usr/bin/env python3
+# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script generates speech with our pre-trained ZipVoice or
+ ZipVoice-Distill models. If no local model is specified,
+ Required files will be automatically downloaded from HuggingFace.
+
+Usage:
+
+Note: If you having trouble connecting to HuggingFace,
+ try switching endpoint to mirror site:
+export HF_ENDPOINT=https://hf-mirror.com
+
+(1) Inference of a single sentence:
+
+python3 -m zipvoice.bin.infer_zipvoice \
+ --model-name zipvoice \
+ --prompt-wav prompt.wav \
+ --prompt-text "I am a prompt." \
+ --text "I am a sentence." \
+ --res-wav-path result.wav
+
+(2) Inference of a list of sentences:
+
+python3 -m zipvoice.bin.infer_zipvoice \
+ --model-name zipvoice \
+ --test-list test.tsv \
+ --res-dir results
+
+`--model-name` can be `zipvoice` or `zipvoice_distill`,
+ which are the models before and after distillation, respectively.
+
+Each line of `test.tsv` is in the format of
+ `{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`.
+"""
+
+import argparse
+import datetime as dt
+import json
+import logging
+import os
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import safetensors.torch
+import torch
+import torchaudio
+from huggingface_hub import hf_hub_download
+from lhotse.utils import fix_random_seed
+from vocos import Vocos
+
+from zipvoice.models.zipvoice import ZipVoice
+from zipvoice.models.zipvoice_distill import ZipVoiceDistill
+from zipvoice.tokenizer.tokenizer import (
+ EmiliaTokenizer,
+ EspeakTokenizer,
+ LibriTTSTokenizer,
+ SimpleTokenizer,
+)
+from zipvoice.utils.checkpoint import load_checkpoint
+from zipvoice.utils.common import AttributeDict
+from zipvoice.utils.feature import VocosFbank
+
+HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
+MODEL_DIR = {
+ "zipvoice": "zipvoice",
+ "zipvoice_distill": "zipvoice_distill",
+}
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--model-name",
+ type=str,
+ default="zipvoice",
+ choices=["zipvoice", "zipvoice_distill"],
+ help="The model used for inference",
+ )
+
+ parser.add_argument(
+ "--model-dir",
+ type=str,
+ default=None,
+ help="The model directory that contains model checkpoint, configuration "
+ "file model.json, and tokens file tokens.txt. Will download pre-trained "
+ "checkpoint from huggingface if not specified.",
+ )
+
+ parser.add_argument(
+ "--checkpoint-name",
+ type=str,
+ default="model.pt",
+ help="The name of model checkpoint.",
+ )
+
+ parser.add_argument(
+ "--vocoder-path",
+ type=str,
+ default=None,
+ help="The vocoder checkpoint. "
+ "Will download pre-trained vocoder from huggingface if not specified.",
+ )
+
+ parser.add_argument(
+ "--tokenizer",
+ type=str,
+ default="emilia",
+ choices=["emilia", "libritts", "espeak", "simple"],
+ help="Tokenizer type.",
+ )
+
+ parser.add_argument(
+ "--lang",
+ type=str,
+ default="en-us",
+ help="Language identifier, used when tokenizer type is espeak. see"
+ "https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
+ )
+
+ parser.add_argument(
+ "--test-list",
+ type=str,
+ default=None,
+ help="The list of prompt speech, prompt_transcription, "
+ "and text to synthesizein the format of "
+ "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.",
+ )
+
+ parser.add_argument(
+ "--prompt-wav",
+ type=str,
+ default=None,
+ help="The prompt wav to mimic",
+ )
+
+ parser.add_argument(
+ "--prompt-text",
+ type=str,
+ default=None,
+ help="The transcription of the prompt wav",
+ )
+
+ parser.add_argument(
+ "--text",
+ type=str,
+ default=None,
+ help="The text to synthesize",
+ )
+
+ parser.add_argument(
+ "--res-dir",
+ type=str,
+ default="results",
+ help="""
+ Path name of the generated wavs dir,
+ used when test-list is not None
+ """,
+ )
+
+ parser.add_argument(
+ "--res-wav-path",
+ type=str,
+ default="result.wav",
+ help="""
+ Path name of the generated wav path,
+ used when test-list is None
+ """,
+ )
+
+ parser.add_argument(
+ "--guidance-scale",
+ type=float,
+ default=None,
+ help="The scale of classifier-free guidance during inference.",
+ )
+
+ parser.add_argument(
+ "--num-step",
+ type=int,
+ default=None,
+ help="The number of sampling steps.",
+ )
+
+ parser.add_argument(
+ "--feat-scale",
+ type=float,
+ default=0.1,
+ help="The scale factor of fbank feature",
+ )
+
+ parser.add_argument(
+ "--speed",
+ type=float,
+ default=1.0,
+ help="Control speech speed, 1.0 means normal, >1.0 means speed up",
+ )
+
+ parser.add_argument(
+ "--t-shift",
+ type=float,
+ default=0.5,
+ help="Shift t to smaller ones if t_shift < 1.0",
+ )
+
+ parser.add_argument(
+ "--target-rms",
+ type=float,
+ default=0.1,
+ help="Target speech normalization rms value, set to 0 to disable normalization",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=666,
+ help="Random seed",
+ )
+
+ return parser
+
+
+def get_vocoder(vocos_local_path: Optional[str] = None):
+ if vocos_local_path:
+ vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
+ state_dict = torch.load(
+ f"{vocos_local_path}/pytorch_model.bin",
+ weights_only=True,
+ map_location="cpu",
+ )
+ vocoder.load_state_dict(state_dict)
+ else:
+ vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
+ return vocoder
+
+
+def generate_sentence(
+ save_path: str,
+ prompt_text: str,
+ prompt_wav: str,
+ text: str,
+ model: torch.nn.Module,
+ vocoder: torch.nn.Module,
+ tokenizer: EmiliaTokenizer,
+ feature_extractor: VocosFbank,
+ device: torch.device,
+ num_step: int = 16,
+ guidance_scale: float = 1.0,
+ speed: float = 1.0,
+ t_shift: float = 0.5,
+ target_rms: float = 0.1,
+ feat_scale: float = 0.1,
+ sampling_rate: int = 24000,
+):
+ """
+ Generate waveform of a text based on a given prompt
+ waveform and its transcription.
+
+ Args:
+ save_path (str): Path to save the generated wav.
+ prompt_text (str): Transcription of the prompt wav.
+ prompt_wav (str): Path to the prompt wav file.
+ text (str): Text to be synthesized into a waveform.
+ model (torch.nn.Module): The model used for generation.
+ vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
+ tokenizer (EmiliaTokenizer): The tokenizer used to convert text to tokens.
+ feature_extractor (VocosFbank): The feature extractor used to
+ extract acoustic features.
+ device (torch.device): The device on which computations are performed.
+ num_step (int, optional): Number of steps for decoding. Defaults to 16.
+ guidance_scale (float, optional): Scale for classifier-free guidance.
+ Defaults to 1.0.
+ speed (float, optional): Speed control. Defaults to 1.0.
+ t_shift (float, optional): Time shift. Defaults to 0.5.
+ target_rms (float, optional): Target RMS for waveform normalization.
+ Defaults to 0.1.
+ feat_scale (float, optional): Scale for features.
+ Defaults to 0.1.
+ sampling_rate (int, optional): Sampling rate for the waveform.
+ Defaults to 24000.
+ Returns:
+ metrics (dict): Dictionary containing time and real-time
+ factor metrics for processing.
+ """
+ # Convert text to tokens
+ tokens = tokenizer.texts_to_token_ids([text])
+ prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
+
+ # Load and preprocess prompt wav
+ prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)
+
+ if prompt_sampling_rate != sampling_rate:
+ resampler = torchaudio.transforms.Resample(
+ orig_freq=prompt_sampling_rate, new_freq=sampling_rate
+ )
+ prompt_wav = resampler(prompt_wav)
+
+ prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
+ if prompt_rms < target_rms:
+ prompt_wav = prompt_wav * target_rms / prompt_rms
+
+ # Extract features from prompt wav
+ prompt_features = feature_extractor.extract(
+ prompt_wav, sampling_rate=sampling_rate
+ ).to(device)
+
+ prompt_features = prompt_features.unsqueeze(0) * feat_scale
+ prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
+
+ # Start timing
+ start_t = dt.datetime.now()
+
+ # Generate features
+ (
+ pred_features,
+ pred_features_lens,
+ pred_prompt_features,
+ pred_prompt_features_lens,
+ ) = model.sample(
+ tokens=tokens,
+ prompt_tokens=prompt_tokens,
+ prompt_features=prompt_features,
+ prompt_features_lens=prompt_features_lens,
+ speed=speed,
+ t_shift=t_shift,
+ duration="predict",
+ num_step=num_step,
+ guidance_scale=guidance_scale,
+ )
+
+ # Postprocess predicted features
+ pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
+
+ # Start vocoder processing
+ start_vocoder_t = dt.datetime.now()
+ wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
+
+ # Calculate processing times and real-time factors
+ t = (dt.datetime.now() - start_t).total_seconds()
+ t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
+ t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
+ wav_seconds = wav.shape[-1] / sampling_rate
+ rtf = t / wav_seconds
+ rtf_no_vocoder = t_no_vocoder / wav_seconds
+ rtf_vocoder = t_vocoder / wav_seconds
+ metrics = {
+ "t": t,
+ "t_no_vocoder": t_no_vocoder,
+ "t_vocoder": t_vocoder,
+ "wav_seconds": wav_seconds,
+ "rtf": rtf,
+ "rtf_no_vocoder": rtf_no_vocoder,
+ "rtf_vocoder": rtf_vocoder,
+ }
+
+ # Adjust wav volume if necessary
+ if prompt_rms < target_rms:
+ wav = wav * prompt_rms / target_rms
+ torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
+
+ return metrics
+
+
+def generate_list(
+ res_dir: str,
+ test_list: str,
+ model: torch.nn.Module,
+ vocoder: torch.nn.Module,
+ tokenizer: EmiliaTokenizer,
+ feature_extractor: VocosFbank,
+ device: torch.device,
+ num_step: int = 16,
+ guidance_scale: float = 1.0,
+ speed: float = 1.0,
+ t_shift: float = 0.5,
+ target_rms: float = 0.1,
+ feat_scale: float = 0.1,
+ sampling_rate: int = 24000,
+):
+ total_t = []
+ total_t_no_vocoder = []
+ total_t_vocoder = []
+ total_wav_seconds = []
+
+ with open(test_list, "r") as fr:
+ lines = fr.readlines()
+
+ for i, line in enumerate(lines):
+ wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
+ save_path = f"{res_dir}/{wav_name}.wav"
+ metrics = generate_sentence(
+ save_path=save_path,
+ prompt_text=prompt_text,
+ prompt_wav=prompt_wav,
+ text=text,
+ model=model,
+ vocoder=vocoder,
+ tokenizer=tokenizer,
+ feature_extractor=feature_extractor,
+ device=device,
+ num_step=num_step,
+ guidance_scale=guidance_scale,
+ speed=speed,
+ t_shift=t_shift,
+ target_rms=target_rms,
+ feat_scale=feat_scale,
+ sampling_rate=sampling_rate,
+ )
+ logging.info(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
+ total_t.append(metrics["t"])
+ total_t_no_vocoder.append(metrics["t_no_vocoder"])
+ total_t_vocoder.append(metrics["t_vocoder"])
+ total_wav_seconds.append(metrics["wav_seconds"])
+
+ logging.info(f"Average RTF: {np.sum(total_t) / np.sum(total_wav_seconds):.4f}")
+ logging.info(
+ f"Average RTF w/o vocoder: "
+ f"{np.sum(total_t_no_vocoder) / np.sum(total_wav_seconds):.4f}"
+ )
+ logging.info(
+ f"Average RTF vocoder: "
+ f"{np.sum(total_t_vocoder) / np.sum(total_wav_seconds):.4f}"
+ )
+
+
+@torch.inference_mode()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ params = AttributeDict()
+ params.update(vars(args))
+ fix_random_seed(params.seed)
+
+ model_defaults = {
+ "zipvoice": {
+ "num_step": 16,
+ "guidance_scale": 1.0,
+ },
+ "zipvoice_distill": {
+ "num_step": 8,
+ "guidance_scale": 3.0,
+ },
+ }
+
+ model_specific_defaults = model_defaults.get(params.model_name, {})
+
+ for param, value in model_specific_defaults.items():
+ if getattr(params, param) is None:
+ setattr(params, param, value)
+ logging.info(f"Setting {param} to default value: {value}")
+
+ assert (params.test_list is not None) ^ (
+ (params.prompt_wav and params.prompt_text and params.text) is not None
+ ), (
+ "For inference, please provide prompts and text with either '--test-list'"
+ " or '--prompt-wav, --prompt-text and --text'."
+ )
+
+ if params.model_dir is not None:
+ params.model_dir = Path(params.model_dir)
+ if not params.model_dir.is_dir():
+ raise FileNotFoundError(f"{params.model_dir} does not exist")
+ for filename in [params.checkpoint_name, "model.json", "tokens.txt"]:
+ if not (params.model_dir / filename).is_file():
+ raise FileNotFoundError(f"{params.model_dir / filename} does not exist")
+ model_ckpt = params.model_dir / params.checkpoint_name
+ model_config = params.model_dir / "model.json"
+ token_file = params.model_dir / "tokens.txt"
+ logging.info(
+ f"Using local model dir {params.model_dir}, "
+ f"checkpoint {params.checkpoint_name}"
+ )
+ else:
+ logging.info("Using pretrained model from the huggingface")
+ logging.info("Downloading the requires files from HuggingFace")
+ model_ckpt = hf_hub_download(
+ HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/model.pt"
+ )
+ model_config = hf_hub_download(
+ HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/model.json"
+ )
+
+ token_file = hf_hub_download(
+ HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/tokens.txt"
+ )
+
+ logging.info("Loading model...")
+
+ if params.tokenizer == "emilia":
+ tokenizer = EmiliaTokenizer(token_file=token_file)
+ elif params.tokenizer == "libritts":
+ tokenizer = LibriTTSTokenizer(token_file=token_file)
+ elif params.tokenizer == "espeak":
+ tokenizer = EspeakTokenizer(token_file=token_file, lang=params.lang)
+ else:
+ assert params.tokenizer == "simple"
+ tokenizer = SimpleTokenizer(token_file=token_file)
+
+ tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
+
+ with open(model_config, "r") as f:
+ model_config = json.load(f)
+
+ if params.model_name == "zipvoice":
+ model = ZipVoice(
+ **model_config["model"],
+ **tokenizer_config,
+ )
+ else:
+ assert params.model_name == "zipvoice_distill"
+ model = ZipVoiceDistill(
+ **model_config["model"],
+ **tokenizer_config,
+ )
+
+ if str(model_ckpt).endswith(".safetensors"):
+ safetensors.torch.load_model(model, model_ckpt)
+ elif str(model_ckpt).endswith(".pt"):
+ load_checkpoint(filename=model_ckpt, model=model, strict=True)
+ else:
+ raise NotImplementedError(f"Unsupported model checkpoint format: {model_ckpt}")
+
+ if torch.cuda.is_available():
+ params.device = torch.device("cuda", 0)
+ elif torch.backends.mps.is_available():
+ params.device = torch.device("mps")
+ else:
+ params.device = torch.device("cpu")
+ logging.info(f"Device: {params.device}")
+
+ model = model.to(params.device)
+ model.eval()
+
+ vocoder = get_vocoder(params.vocoder_path)
+ vocoder = vocoder.to(params.device)
+ vocoder.eval()
+
+ if model_config["feature"]["type"] == "vocos":
+ feature_extractor = VocosFbank()
+ else:
+ raise NotImplementedError(
+ f"Unsupported feature type: {model_config['feature']['type']}"
+ )
+ params.sampling_rate = model_config["feature"]["sampling_rate"]
+
+ logging.info("Start generating...")
+ if params.test_list:
+ os.makedirs(params.res_dir, exist_ok=True)
+ generate_list(
+ res_dir=params.res_dir,
+ test_list=params.test_list,
+ model=model,
+ vocoder=vocoder,
+ tokenizer=tokenizer,
+ feature_extractor=feature_extractor,
+ device=params.device,
+ num_step=params.num_step,
+ guidance_scale=params.guidance_scale,
+ speed=params.speed,
+ t_shift=params.t_shift,
+ target_rms=params.target_rms,
+ feat_scale=params.feat_scale,
+ sampling_rate=params.sampling_rate,
+ )
+ else:
+ generate_sentence(
+ save_path=params.res_wav_path,
+ prompt_text=params.prompt_text,
+ prompt_wav=params.prompt_wav,
+ text=params.text,
+ model=model,
+ vocoder=vocoder,
+ tokenizer=tokenizer,
+ feature_extractor=feature_extractor,
+ device=params.device,
+ num_step=params.num_step,
+ guidance_scale=params.guidance_scale,
+ speed=params.speed,
+ t_shift=params.t_shift,
+ target_rms=params.target_rms,
+ feat_scale=params.feat_scale,
+ sampling_rate=params.sampling_rate,
+ )
+ logging.info("Done")
+
+
+if __name__ == "__main__":
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
+
+ main()
diff --git a/zipvoice/bin/infer_zipvoice_dialog.py b/zipvoice/bin/infer_zipvoice_dialog.py
new file mode 100644
index 0000000000000000000000000000000000000000..26f9c7766ddfde5becb405e823e078503fb5297f
--- /dev/null
+++ b/zipvoice/bin/infer_zipvoice_dialog.py
@@ -0,0 +1,756 @@
+#!/usr/bin/env python3
+# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script generates speech with our pre-trained ZipVoice-Dialog or
+ ZipVoice-Dialog-Stereo models. If no local model is specified,
+ Required files will be automatically downloaded from HuggingFace.
+
+Usage:
+
+Note: If you having trouble connecting to HuggingFace,
+ try switching endpoint to mirror site:
+export HF_ENDPOINT=https://hf-mirror.com
+
+python3 -m zipvoice.bin.infer_zipvoice_dialog \
+ --model-name zipvoice_dialog \
+ --test-list test.tsv \
+ --res-dir results
+
+`--model-name` can be `zipvoice_dialog` or `zipvoice_dialog_stereo`,
+ which generate mono and stereo dialogues, respectively.
+
+Each line of `test.tsv` is in the format of merged conversation:
+ '{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'
+ or splited conversation:
+ '{wav_name}\t{spk1_prompt_transcription}\t{spk2_prompt_transcription}
+ \t{spk1_prompt_wav}\t{spk2_prompt_wav}\t{text}'
+"""
+
+import argparse
+import datetime as dt
+import json
+import logging
+import os
+from pathlib import Path
+from typing import List, Optional, Union
+
+import numpy as np
+import safetensors.torch
+import torch
+import torchaudio
+from huggingface_hub import hf_hub_download
+from lhotse.utils import fix_random_seed
+from vocos import Vocos
+
+from zipvoice.models.zipvoice_dialog import ZipVoiceDialog, ZipVoiceDialogStereo
+from zipvoice.tokenizer.tokenizer import DialogTokenizer
+from zipvoice.utils.checkpoint import load_checkpoint
+from zipvoice.utils.common import AttributeDict
+from zipvoice.utils.feature import VocosFbank
+
+HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
+MODEL_DIR = {
+ "zipvoice_dialog": "zipvoice_dialog",
+ "zipvoice_dialog_stereo": "zipvoice_dialog_stereo",
+}
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--model-name",
+ type=str,
+ default="zipvoice_dialog",
+ choices=["zipvoice_dialog", "zipvoice_dialog_stereo"],
+ help="The model used for inference",
+ )
+
+ parser.add_argument(
+ "--model-dir",
+ type=str,
+ default=None,
+ help="The model directory that contains model checkpoint, configuration "
+ "file model.json, and tokens file tokens.txt. Will download pre-trained "
+ "checkpoint from huggingface if not specified.",
+ )
+
+ parser.add_argument(
+ "--checkpoint-name",
+ type=str,
+ default="model.pt",
+ help="The name of model checkpoint.",
+ )
+
+ parser.add_argument(
+ "--vocoder-path",
+ type=str,
+ default=None,
+ help="The vocoder checkpoint. "
+ "Will download pre-trained vocoder from huggingface if not specified.",
+ )
+
+ parser.add_argument(
+ "--test-list",
+ type=str,
+ default=None,
+ help="The list of prompt speech, prompt_transcription, "
+ "and text to synthesizein the format of merged conversation: "
+ "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}' "
+ "or splited conversation: "
+ "'{wav_name}\t{spk1_prompt_transcription}\t{spk2_prompt_transcription}"
+ "\t{spk1_prompt_wav}\t{spk2_prompt_wav}\t{text}'.",
+ )
+
+ parser.add_argument(
+ "--res-dir",
+ type=str,
+ default="results",
+ help="""
+ Path name of the generated wavs dir,
+ used when test-list is not None
+ """,
+ )
+
+ parser.add_argument(
+ "--guidance-scale",
+ type=float,
+ default=1.5,
+ help="The scale of classifier-free guidance during inference.",
+ )
+
+ parser.add_argument(
+ "--num-step",
+ type=int,
+ default=16,
+ help="The number of sampling steps.",
+ )
+
+ parser.add_argument(
+ "--feat-scale",
+ type=float,
+ default=0.1,
+ help="The scale factor of fbank feature",
+ )
+
+ parser.add_argument(
+ "--speed",
+ type=float,
+ default=1.0,
+ help="Control speech speed, 1.0 means normal, >1.0 means speed up",
+ )
+
+ parser.add_argument(
+ "--t-shift",
+ type=float,
+ default=0.5,
+ help="Shift t to smaller ones if t_shift < 1.0",
+ )
+
+ parser.add_argument(
+ "--target-rms",
+ type=float,
+ default=0.1,
+ help="Target speech normalization rms value, set to 0 to disable normalization",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=666,
+ help="Random seed",
+ )
+
+ parser.add_argument(
+ "--silence-wav",
+ type=str,
+ default="assets/silence.wav",
+ help="Path of the silence wav file, used in two-channel generation "
+ "with single-channel prompts",
+ )
+
+ return parser
+
+
+def get_vocoder(vocos_local_path: Optional[str] = None):
+ if vocos_local_path:
+ vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
+ state_dict = torch.load(
+ f"{vocos_local_path}/pytorch_model.bin",
+ weights_only=True,
+ map_location="cpu",
+ )
+ vocoder.load_state_dict(state_dict)
+ else:
+ vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
+ return vocoder
+
+
+def generate_sentence(
+ save_path: str,
+ prompt_text: str,
+ prompt_wav: Union[str, List[str]],
+ text: str,
+ model: torch.nn.Module,
+ vocoder: torch.nn.Module,
+ tokenizer: DialogTokenizer,
+ feature_extractor: VocosFbank,
+ device: torch.device,
+ num_step: int = 16,
+ guidance_scale: float = 1.0,
+ speed: float = 1.0,
+ t_shift: float = 0.5,
+ target_rms: float = 0.1,
+ feat_scale: float = 0.1,
+ sampling_rate: int = 24000,
+):
+ """
+ Generate waveform of a text based on a given prompt
+ waveform and its transcription.
+
+ Args:
+ save_path (str): Path to save the generated wav.
+ prompt_text (str): Transcription of the prompt wav.
+ prompt_wav (Union[str, List[str]]): Path to the prompt wav file, can be
+ one or two wav files, which corresponding to a merged conversational
+ speech or two seperate speaker's speech.
+ text (str): Text to be synthesized into a waveform.
+ model (torch.nn.Module): The model used for generation.
+ vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
+ tokenizer (DialogTokenizer): The tokenizer used to convert text to tokens.
+ feature_extractor (VocosFbank): The feature extractor used to
+ extract acoustic features.
+ device (torch.device): The device on which computations are performed.
+ num_step (int, optional): Number of steps for decoding. Defaults to 16.
+ guidance_scale (float, optional): Scale for classifier-free guidance.
+ Defaults to 1.0.
+ speed (float, optional): Speed control. Defaults to 1.0.
+ t_shift (float, optional): Time shift. Defaults to 0.5.
+ target_rms (float, optional): Target RMS for waveform normalization.
+ Defaults to 0.1.
+ feat_scale (float, optional): Scale for features.
+ Defaults to 0.1.
+ sampling_rate (int, optional): Sampling rate for the waveform.
+ Defaults to 24000.
+ Returns:
+ metrics (dict): Dictionary containing time and real-time
+ factor metrics for processing.
+ """
+ # Convert text to tokens
+ tokens = tokenizer.texts_to_token_ids([text])
+ prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
+
+ # Load and preprocess prompt wav
+ if isinstance(prompt_wav, str):
+ prompt_wav = [
+ prompt_wav,
+ ]
+ else:
+ assert len(prompt_wav) == 2 and isinstance(prompt_wav[0], str)
+
+ loaded_prompt_wavs = prompt_wav
+ for i in range(len(prompt_wav)):
+ loaded_prompt_wavs[i], prompt_sampling_rate = torchaudio.load(prompt_wav[i])
+ if prompt_sampling_rate != sampling_rate:
+ resampler = torchaudio.transforms.Resample(
+ orig_freq=prompt_sampling_rate, new_freq=sampling_rate
+ )
+ loaded_prompt_wavs[i] = resampler(loaded_prompt_wavs[i])
+ if loaded_prompt_wavs[i].size(0) != 1:
+ loaded_prompt_wavs[i] = loaded_prompt_wavs[i].mean(0, keepdim=True)
+
+ if len(loaded_prompt_wavs) == 1:
+ prompt_wav = loaded_prompt_wavs[0]
+ else:
+ prompt_wav = torch.cat(loaded_prompt_wavs, dim=1)
+
+ prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
+ if prompt_rms < target_rms:
+ prompt_wav = prompt_wav * target_rms / prompt_rms
+
+ # Extract features from prompt wav
+ prompt_features = feature_extractor.extract(
+ prompt_wav, sampling_rate=sampling_rate
+ ).to(device)
+
+ prompt_features = prompt_features.unsqueeze(0) * feat_scale
+ prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
+
+ # Start timing
+ start_t = dt.datetime.now()
+
+ # Generate features
+ (
+ pred_features,
+ pred_features_lens,
+ pred_prompt_features,
+ pred_prompt_features_lens,
+ ) = model.sample(
+ tokens=tokens,
+ prompt_tokens=prompt_tokens,
+ prompt_features=prompt_features,
+ prompt_features_lens=prompt_features_lens,
+ speed=speed,
+ t_shift=t_shift,
+ duration="predict",
+ num_step=num_step,
+ guidance_scale=guidance_scale,
+ )
+
+ # Postprocess predicted features
+ pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
+
+ # Start vocoder processing
+ start_vocoder_t = dt.datetime.now()
+ wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
+
+ # Calculate processing times and real-time factors
+ t = (dt.datetime.now() - start_t).total_seconds()
+ t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
+ t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
+ wav_seconds = wav.shape[-1] / sampling_rate
+ rtf = t / wav_seconds
+ rtf_no_vocoder = t_no_vocoder / wav_seconds
+ rtf_vocoder = t_vocoder / wav_seconds
+ metrics = {
+ "t": t,
+ "t_no_vocoder": t_no_vocoder,
+ "t_vocoder": t_vocoder,
+ "wav_seconds": wav_seconds,
+ "rtf": rtf,
+ "rtf_no_vocoder": rtf_no_vocoder,
+ "rtf_vocoder": rtf_vocoder,
+ }
+
+ # Adjust wav volume if necessary
+ if prompt_rms < target_rms:
+ wav = wav * prompt_rms / target_rms
+ torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
+
+ return metrics
+
+
+def generate_sentence_stereo(
+ save_path: str,
+ prompt_text: str,
+ prompt_wav: Union[str, List[str]],
+ text: str,
+ model: torch.nn.Module,
+ vocoder: torch.nn.Module,
+ tokenizer: DialogTokenizer,
+ feature_extractor: VocosFbank,
+ device: torch.device,
+ num_step: int = 16,
+ guidance_scale: float = 1.0,
+ speed: float = 1.0,
+ t_shift: float = 0.5,
+ target_rms: float = 0.1,
+ feat_scale: float = 0.1,
+ sampling_rate: int = 24000,
+ silence_wav: Optional[str] = None,
+):
+ """
+ Generate waveform of a text based on a given prompt
+ waveform and its transcription.
+
+ Args:
+ save_path (str): Path to save the generated wav.
+ prompt_text (str): Transcription of the prompt wav.
+ prompt_wav (Union[str, List[str]]): Path to the prompt wav file, can be
+ one or two wav files, which corresponding to a merged conversational
+ speech or two seperate speaker's speech.
+ text (str): Text to be synthesized into a waveform.
+ model (torch.nn.Module): The model used for generation.
+ vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
+ tokenizer (DialogTokenizer): The tokenizer used to convert text to tokens.
+ feature_extractor (VocosFbank): The feature extractor used to
+ extract acoustic features.
+ device (torch.device): The device on which computations are performed.
+ num_step (int, optional): Number of steps for decoding. Defaults to 16.
+ guidance_scale (float, optional): Scale for classifier-free guidance.
+ Defaults to 1.0.
+ speed (float, optional): Speed control. Defaults to 1.0.
+ t_shift (float, optional): Time shift. Defaults to 0.5.
+ target_rms (float, optional): Target RMS for waveform normalization.
+ Defaults to 0.1.
+ feat_scale (float, optional): Scale for features.
+ Defaults to 0.1.
+ sampling_rate (int, optional): Sampling rate for the waveform.
+ Defaults to 24000.
+ silence_wav (str): Path of the silence wav file, used in two-channel
+ generation with single-channel prompts
+ Returns:
+ metrics (dict): Dictionary containing time and real-time
+ factor metrics for processing.
+ """
+ # Convert text to tokens
+ tokens = tokenizer.texts_to_token_ids([text])
+ prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
+
+ # Load and preprocess prompt wav
+ if isinstance(prompt_wav, str):
+ prompt_wav = [
+ prompt_wav,
+ ]
+ else:
+ assert len(prompt_wav) == 2 and isinstance(prompt_wav[0], str)
+
+ loaded_prompt_wavs = prompt_wav
+ for i in range(len(prompt_wav)):
+ loaded_prompt_wavs[i], prompt_sampling_rate = torchaudio.load(prompt_wav[i])
+ if prompt_sampling_rate != sampling_rate:
+ resampler = torchaudio.transforms.Resample(
+ orig_freq=prompt_sampling_rate, new_freq=sampling_rate
+ )
+ loaded_prompt_wavs[i] = resampler(loaded_prompt_wavs[i])
+
+ if len(loaded_prompt_wavs) == 1:
+ assert (
+ loaded_prompt_wavs[0].size(0) == 2
+ ), "Merged prompt wav must be stereo for stereo dialogue generation"
+ prompt_wav = loaded_prompt_wavs[0]
+
+ else:
+ assert len(loaded_prompt_wavs) == 2
+ if loaded_prompt_wavs[0].size(0) == 2:
+ prompt_wav = torch.cat(loaded_prompt_wavs, dim=1)
+ else:
+ assert loaded_prompt_wavs[0].size(0) == 1
+ silence_wav, silence_sampling_rate = torchaudio.load(silence_wav)
+ assert silence_sampling_rate == sampling_rate
+ prompt_wav = silence_wav[
+ :, : loaded_prompt_wavs[0].size(1) + loaded_prompt_wavs[1].size(1)
+ ]
+ prompt_wav[0, : loaded_prompt_wavs[0].size(1)] = loaded_prompt_wavs[0]
+ prompt_wav[1, loaded_prompt_wavs[0].size(1) :] = loaded_prompt_wavs[1]
+
+ prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
+ if prompt_rms < target_rms:
+ prompt_wav = prompt_wav * target_rms / prompt_rms
+
+ # Extract features from prompt wav
+ prompt_features = feature_extractor.extract(
+ prompt_wav, sampling_rate=sampling_rate
+ ).to(device)
+
+ prompt_features = prompt_features.unsqueeze(0) * feat_scale
+ prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
+
+ # Start timing
+ start_t = dt.datetime.now()
+
+ # Generate features
+ (
+ pred_features,
+ pred_features_lens,
+ pred_prompt_features,
+ pred_prompt_features_lens,
+ ) = model.sample(
+ tokens=tokens,
+ prompt_tokens=prompt_tokens,
+ prompt_features=prompt_features,
+ prompt_features_lens=prompt_features_lens,
+ speed=speed,
+ t_shift=t_shift,
+ duration="predict",
+ num_step=num_step,
+ guidance_scale=guidance_scale,
+ )
+
+ # Postprocess predicted features
+ pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
+
+ # Start vocoder processing
+ start_vocoder_t = dt.datetime.now()
+ feat_dim = pred_features.size(1) // 2
+ wav_left = vocoder.decode(pred_features[:, :feat_dim]).squeeze(1).clamp(-1, 1)
+ wav_right = (
+ vocoder.decode(pred_features[:, feat_dim : feat_dim * 2])
+ .squeeze(1)
+ .clamp(-1, 1)
+ )
+
+ wav = torch.cat([wav_left, wav_right], dim=0)
+
+ # Calculate processing times and real-time factors
+ t = (dt.datetime.now() - start_t).total_seconds()
+ t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
+ t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
+ wav_seconds = wav.shape[-1] / sampling_rate
+ rtf = t / wav_seconds
+ rtf_no_vocoder = t_no_vocoder / wav_seconds
+ rtf_vocoder = t_vocoder / wav_seconds
+ metrics = {
+ "t": t,
+ "t_no_vocoder": t_no_vocoder,
+ "t_vocoder": t_vocoder,
+ "wav_seconds": wav_seconds,
+ "rtf": rtf,
+ "rtf_no_vocoder": rtf_no_vocoder,
+ "rtf_vocoder": rtf_vocoder,
+ }
+
+ # Adjust wav volume if necessary
+ if prompt_rms < target_rms:
+ wav = wav * prompt_rms / target_rms
+ torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
+
+ return metrics
+
+
+def generate_list(
+ model_name: str,
+ res_dir: str,
+ test_list: str,
+ model: torch.nn.Module,
+ vocoder: torch.nn.Module,
+ tokenizer: DialogTokenizer,
+ feature_extractor: VocosFbank,
+ device: torch.device,
+ num_step: int = 16,
+ guidance_scale: float = 1.5,
+ speed: float = 1.0,
+ t_shift: float = 0.5,
+ target_rms: float = 0.1,
+ feat_scale: float = 0.1,
+ sampling_rate: int = 24000,
+ silence_wav: Optional[str] = None,
+):
+ total_t = []
+ total_t_no_vocoder = []
+ total_t_vocoder = []
+ total_wav_seconds = []
+
+ with open(test_list, "r") as fr:
+ lines = fr.readlines()
+
+ for i, line in enumerate(lines):
+ items = line.strip().split("\t")
+ if len(items) == 6:
+ (
+ wav_name,
+ prompt_text_1,
+ prompt_text_2,
+ prompt_wav_1,
+ prompt_wav_2,
+ text,
+ ) = items
+ prompt_text = f"[S1]{prompt_text_1}[S2]{prompt_text_2}"
+ prompt_wav = [prompt_wav_1, prompt_wav_2]
+ elif len(items) == 4:
+ wav_name, prompt_text, prompt_wav, text = items
+ else:
+ raise ValueError(f"Invalid line: {line}")
+ assert text.startswith("[S1]")
+
+ save_path = f"{res_dir}/{wav_name}.wav"
+
+ if model_name == "zipvoice_dialog":
+
+ metrics = generate_sentence(
+ save_path=save_path,
+ prompt_text=prompt_text,
+ prompt_wav=prompt_wav,
+ text=text,
+ model=model,
+ vocoder=vocoder,
+ tokenizer=tokenizer,
+ feature_extractor=feature_extractor,
+ device=device,
+ num_step=num_step,
+ guidance_scale=guidance_scale,
+ speed=speed,
+ t_shift=t_shift,
+ target_rms=target_rms,
+ feat_scale=feat_scale,
+ sampling_rate=sampling_rate,
+ )
+ else:
+ assert model_name == "zipvoice_dialog_stereo"
+ metrics = generate_sentence_stereo(
+ save_path=save_path,
+ prompt_text=prompt_text,
+ prompt_wav=prompt_wav,
+ text=text,
+ model=model,
+ vocoder=vocoder,
+ tokenizer=tokenizer,
+ feature_extractor=feature_extractor,
+ device=device,
+ num_step=num_step,
+ guidance_scale=guidance_scale,
+ speed=speed,
+ t_shift=t_shift,
+ target_rms=target_rms,
+ feat_scale=feat_scale,
+ sampling_rate=sampling_rate,
+ silence_wav=silence_wav,
+ )
+
+ logging.info(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
+ total_t.append(metrics["t"])
+ total_t_no_vocoder.append(metrics["t_no_vocoder"])
+ total_t_vocoder.append(metrics["t_vocoder"])
+ total_wav_seconds.append(metrics["wav_seconds"])
+
+ logging.info(f"Average RTF: {np.sum(total_t) / np.sum(total_wav_seconds):.4f}")
+ logging.info(
+ f"Average RTF w/o vocoder: "
+ f"{np.sum(total_t_no_vocoder) / np.sum(total_wav_seconds):.4f}"
+ )
+ logging.info(
+ f"Average RTF vocoder: "
+ f"{np.sum(total_t_vocoder) / np.sum(total_wav_seconds):.4f}"
+ )
+
+
+@torch.inference_mode()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ params = AttributeDict()
+ params.update(vars(args))
+ fix_random_seed(params.seed)
+
+ assert (
+ params.test_list is not None
+ ), "For inference, please provide prompts and text with '--test-list'"
+
+ if params.model_dir is not None:
+ params.model_dir = Path(params.model_dir)
+ if not params.model_dir.is_dir():
+ raise FileNotFoundError(f"{params.model_dir} does not exist")
+ for filename in [params.checkpoint_name, "model.json", "tokens.txt"]:
+ if not (params.model_dir / filename).is_file():
+ raise FileNotFoundError(f"{params.model_dir / filename} does not exist")
+ model_ckpt = params.model_dir / params.checkpoint_name
+ model_config = params.model_dir / "model.json"
+ token_file = params.model_dir / "tokens.txt"
+ logging.info(
+ f"Using local model dir {params.model_dir}, "
+ f"checkpoint {params.checkpoint_name}"
+ )
+ else:
+ logging.info("Using pretrained model from the huggingface")
+ logging.info("Downloading the requires files from HuggingFace")
+ model_ckpt = hf_hub_download(
+ HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/model.pt"
+ )
+ model_config = hf_hub_download(
+ HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/model.json"
+ )
+
+ token_file = hf_hub_download(
+ HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/tokens.txt"
+ )
+
+ logging.info("Loading model...")
+
+ tokenizer = DialogTokenizer(token_file=token_file)
+
+ tokenizer_config = {
+ "vocab_size": tokenizer.vocab_size,
+ "pad_id": tokenizer.pad_id,
+ "spk_a_id": tokenizer.spk_a_id,
+ "spk_b_id": tokenizer.spk_b_id,
+ }
+
+ with open(model_config, "r") as f:
+ model_config = json.load(f)
+
+ if params.model_name == "zipvoice_dialog":
+ model = ZipVoiceDialog(
+ **model_config["model"],
+ **tokenizer_config,
+ )
+ else:
+ assert params.model_name == "zipvoice_dialog_stereo"
+ model = ZipVoiceDialogStereo(
+ **model_config["model"],
+ **tokenizer_config,
+ )
+
+ if str(model_ckpt).endswith(".safetensors"):
+ safetensors.torch.load_model(model, model_ckpt)
+ elif str(model_ckpt).endswith(".pt"):
+ load_checkpoint(filename=model_ckpt, model=model, strict=True)
+ else:
+ raise NotImplementedError(f"Unsupported model checkpoint format: {model_ckpt}")
+
+ if torch.cuda.is_available():
+ params.device = torch.device("cuda", 0)
+ elif torch.backends.mps.is_available():
+ params.device = torch.device("mps")
+ else:
+ params.device = torch.device("cpu")
+ logging.info(f"Device: {params.device}")
+
+ model = model.to(params.device)
+ model.eval()
+
+ vocoder = get_vocoder(params.vocoder_path)
+ vocoder = vocoder.to(params.device)
+ vocoder.eval()
+
+ if model_config["feature"]["type"] == "vocos":
+ if params.model_name == "zipvoice_dialog":
+ num_channels = 1
+ else:
+ assert params.model_name == "zipvoice_dialog_stereo"
+ num_channels = 2
+ feature_extractor = VocosFbank(num_channels=num_channels)
+ else:
+ raise NotImplementedError(
+ f"Unsupported feature type: {model_config['feature']['type']}"
+ )
+ params.sampling_rate = model_config["feature"]["sampling_rate"]
+
+ logging.info("Start generating...")
+ os.makedirs(params.res_dir, exist_ok=True)
+ generate_list(
+ model_name=params.model_name,
+ res_dir=params.res_dir,
+ test_list=params.test_list,
+ model=model,
+ vocoder=vocoder,
+ tokenizer=tokenizer,
+ feature_extractor=feature_extractor,
+ device=params.device,
+ num_step=params.num_step,
+ guidance_scale=params.guidance_scale,
+ speed=params.speed,
+ t_shift=params.t_shift,
+ target_rms=params.target_rms,
+ feat_scale=params.feat_scale,
+ sampling_rate=params.sampling_rate,
+ silence_wav=params.silence_wav,
+ )
+ logging.info("Done")
+
+
+if __name__ == "__main__":
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
+
+ main()
diff --git a/zipvoice/bin/infer_zipvoice_onnx.py b/zipvoice/bin/infer_zipvoice_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a59757733ddb0ac8327064c44737266b6519f5a
--- /dev/null
+++ b/zipvoice/bin/infer_zipvoice_onnx.py
@@ -0,0 +1,712 @@
+# Copyright 2025 Xiaomi Corp. (authors: Han Zhu,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script generates speech with our pre-trained ZipVoice or ZipVoice-Distill
+ ONNX models. If no local model is specified,
+ Required files will be automatically downloaded from HuggingFace.
+
+Usage:
+
+Note: If you having trouble connecting to HuggingFace,
+ try switching endpoint to mirror site:
+export HF_ENDPOINT=https://hf-mirror.com
+
+(1) Inference of a single sentence:
+
+python3 -m zipvoice.bin.infer_zipvoice_onnx \
+ --onnx-int8 False \
+ --model-name zipvoice \
+ --prompt-wav prompt.wav \
+ --prompt-text "I am a prompt." \
+ --text "I am a sentence." \
+ --res-wav-path result.wav
+
+(2) Inference of a list of sentences:
+python3 -m zipvoice.bin.infer_zipvoice_onnx \
+ --onnx-int8 False \
+ --model-name zipvoice \
+ --test-list test.tsv \
+ --res-dir results
+
+`--model-name` can be `zipvoice` or `zipvoice_distill`,
+ which are the models before and after distillation, respectively.
+
+Each line of `test.tsv` is in the format of
+ `{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`.
+
+Set `--onnx-int8 True` to use int8 quantizated ONNX model.
+"""
+
+import argparse
+import datetime as dt
+import json
+import logging
+import os
+from pathlib import Path
+from typing import List, Tuple
+
+import numpy as np
+import onnxruntime as ort
+import torch
+import torchaudio
+from huggingface_hub import hf_hub_download
+from lhotse.utils import fix_random_seed
+from torch import Tensor, nn
+
+from zipvoice.bin.infer_zipvoice import get_vocoder
+from zipvoice.models.modules.solver import get_time_steps
+from zipvoice.tokenizer.tokenizer import (
+ EmiliaTokenizer,
+ EspeakTokenizer,
+ LibriTTSTokenizer,
+ SimpleTokenizer,
+)
+from zipvoice.utils.common import AttributeDict, str2bool
+from zipvoice.utils.feature import VocosFbank
+
+HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
+MODEL_DIR = {
+ "zipvoice": "zipvoice",
+ "zipvoice_distill": "zipvoice_distill",
+}
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--onnx-int8",
+ type=str2bool,
+ default=False,
+ help="Whether to use the int8 model",
+ )
+
+ parser.add_argument(
+ "--model-name",
+ type=str,
+ default="zipvoice",
+ choices=["zipvoice", "zipvoice_distill"],
+ help="The model used for inference",
+ )
+
+ parser.add_argument(
+ "--model-dir",
+ type=str,
+ default=None,
+ help="The path to the local onnx model. "
+ "Will download pre-trained checkpoint from huggingface if not specified.",
+ )
+
+ parser.add_argument(
+ "--vocoder-path",
+ type=str,
+ default=None,
+ help="The vocoder checkpoint. "
+ "Will download pre-trained vocoder from huggingface if not specified.",
+ )
+
+ parser.add_argument(
+ "--tokenizer",
+ type=str,
+ default="emilia",
+ choices=["emilia", "libritts", "espeak", "simple"],
+ help="Tokenizer type.",
+ )
+
+ parser.add_argument(
+ "--lang",
+ type=str,
+ default="en-us",
+ help="Language identifier, used when tokenizer type is espeak. see"
+ "https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
+ )
+
+ parser.add_argument(
+ "--test-list",
+ type=str,
+ default=None,
+ help="The list of prompt speech, prompt_transcription, "
+ "and text to synthesizein the format of "
+ "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.",
+ )
+
+ parser.add_argument(
+ "--prompt-wav",
+ type=str,
+ default=None,
+ help="The prompt wav to mimic",
+ )
+
+ parser.add_argument(
+ "--prompt-text",
+ type=str,
+ default=None,
+ help="The transcription of the prompt wav",
+ )
+
+ parser.add_argument(
+ "--text",
+ type=str,
+ default=None,
+ help="The text to synthesize",
+ )
+
+ parser.add_argument(
+ "--res-dir",
+ type=str,
+ default="results",
+ help="""
+ Path name of the generated wavs dir,
+ used when test-list is not None
+ """,
+ )
+
+ parser.add_argument(
+ "--res-wav-path",
+ type=str,
+ default="result.wav",
+ help="""
+ Path name of the generated wav path,
+ used when test-list is None
+ """,
+ )
+
+ parser.add_argument(
+ "--guidance-scale",
+ type=float,
+ default=None,
+ help="The scale of classifier-free guidance during inference.",
+ )
+
+ parser.add_argument(
+ "--num-step",
+ type=int,
+ default=None,
+ help="The number of sampling steps.",
+ )
+
+ parser.add_argument(
+ "--feat-scale",
+ type=float,
+ default=0.1,
+ help="The scale factor of fbank feature",
+ )
+
+ parser.add_argument(
+ "--speed",
+ type=float,
+ default=1.0,
+ help="Control speech speed, 1.0 means normal, >1.0 means speed up",
+ )
+
+ parser.add_argument(
+ "--t-shift",
+ type=float,
+ default=0.5,
+ help="Shift t to smaller ones if t_shift < 1.0",
+ )
+
+ parser.add_argument(
+ "--target-rms",
+ type=float,
+ default=0.1,
+ help="Target speech normalization rms value, set to 0 to disable normalization",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=666,
+ help="Random seed",
+ )
+
+ return parser
+
+
+class OnnxModel:
+ def __init__(
+ self,
+ text_encoder_path: str,
+ fm_decoder_path: str,
+ ):
+ session_opts = ort.SessionOptions()
+ session_opts.inter_op_num_threads = 1
+ session_opts.intra_op_num_threads = 1
+
+ self.session_opts = session_opts
+
+ self.init_text_encoder(text_encoder_path)
+ self.init_fm_decoder(fm_decoder_path)
+
+ def init_text_encoder(self, model_path: str):
+ self.text_encoder = ort.InferenceSession(
+ model_path,
+ sess_options=self.session_opts,
+ providers=["CPUExecutionProvider"],
+ )
+
+ def init_fm_decoder(self, model_path: str):
+ self.fm_decoder = ort.InferenceSession(
+ model_path,
+ sess_options=self.session_opts,
+ providers=["CPUExecutionProvider"],
+ )
+ meta = self.fm_decoder.get_modelmeta().custom_metadata_map
+ self.feat_dim = int(meta["feat_dim"])
+
+ def run_text_encoder(
+ self,
+ tokens: Tensor,
+ prompt_tokens: Tensor,
+ prompt_features_len: Tensor,
+ speed: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ out = self.text_encoder.run(
+ [
+ self.text_encoder.get_outputs()[0].name,
+ ],
+ {
+ self.text_encoder.get_inputs()[0].name: tokens.numpy(),
+ self.text_encoder.get_inputs()[1].name: prompt_tokens.numpy(),
+ self.text_encoder.get_inputs()[2].name: prompt_features_len.numpy(),
+ self.text_encoder.get_inputs()[3].name: speed.numpy(),
+ },
+ )
+ return torch.from_numpy(out[0])
+
+ def run_fm_decoder(
+ self,
+ t: Tensor,
+ x: Tensor,
+ text_condition: Tensor,
+ speech_condition: torch.Tensor,
+ guidance_scale: Tensor,
+ ) -> Tensor:
+ out = self.fm_decoder.run(
+ [
+ self.fm_decoder.get_outputs()[0].name,
+ ],
+ {
+ self.fm_decoder.get_inputs()[0].name: t.numpy(),
+ self.fm_decoder.get_inputs()[1].name: x.numpy(),
+ self.fm_decoder.get_inputs()[2].name: text_condition.numpy(),
+ self.fm_decoder.get_inputs()[3].name: speech_condition.numpy(),
+ self.fm_decoder.get_inputs()[4].name: guidance_scale.numpy(),
+ },
+ )
+ return torch.from_numpy(out[0])
+
+
+def sample(
+ model: OnnxModel,
+ tokens: List[List[int]],
+ prompt_tokens: List[List[int]],
+ prompt_features: Tensor,
+ speed: float = 1.0,
+ t_shift: float = 0.5,
+ guidance_scale: float = 1.0,
+ num_step: int = 16,
+) -> torch.Tensor:
+ """
+ Generate acoustic features, given text tokens, prompts feature and prompt
+ transcription's text tokens.
+
+ Args:
+ tokens: a list of list of text tokens.
+ prompt_tokens: a list of list of prompt tokens.
+ prompt_features: the prompt feature with the shape
+ (batch_size, seq_len, feat_dim).
+ speed : speed control.
+ t_shift: time shift.
+ guidance_scale: the guidance scale for classifier-free guidance.
+ num_step: the number of steps to use in the ODE solver.
+ """
+ # Run text encoder
+ assert len(tokens) == len(prompt_tokens) == 1
+ tokens = torch.tensor(tokens, dtype=torch.int64)
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.int64)
+ prompt_features_len = torch.tensor(prompt_features.size(1), dtype=torch.int64)
+ speed = torch.tensor(speed, dtype=torch.float32)
+
+ text_condition = model.run_text_encoder(
+ tokens, prompt_tokens, prompt_features_len, speed
+ )
+
+ batch_size, num_frames, _ = text_condition.shape
+ assert batch_size == 1
+ feat_dim = model.feat_dim
+
+ # Run flow matching model
+ timesteps = get_time_steps(
+ t_start=0.0,
+ t_end=1.0,
+ num_step=num_step,
+ t_shift=t_shift,
+ )
+ x = torch.randn(batch_size, num_frames, feat_dim)
+ speech_condition = torch.nn.functional.pad(
+ prompt_features, (0, 0, 0, num_frames - prompt_features.shape[1])
+ ) # (B, T, F)
+ guidance_scale = torch.tensor(guidance_scale, dtype=torch.float32)
+
+ for step in range(num_step):
+ v = model.run_fm_decoder(
+ t=timesteps[step],
+ x=x,
+ text_condition=text_condition,
+ speech_condition=speech_condition,
+ guidance_scale=guidance_scale,
+ )
+ x = x + v * (timesteps[step + 1] - timesteps[step])
+
+ x = x[:, prompt_features_len.item() :, :]
+ return x
+
+
+# Copied from zipvoice/bin/infer_zipvoice.py, but call an external sample function
+def generate_sentence(
+ save_path: str,
+ prompt_text: str,
+ prompt_wav: str,
+ text: str,
+ model: OnnxModel,
+ vocoder: nn.Module,
+ tokenizer: EmiliaTokenizer,
+ feature_extractor: VocosFbank,
+ num_step: int = 16,
+ guidance_scale: float = 1.0,
+ speed: float = 1.0,
+ t_shift: float = 0.5,
+ target_rms: float = 0.1,
+ feat_scale: float = 0.1,
+ sampling_rate: int = 24000,
+):
+ """
+ Generate waveform of a text based on a given prompt
+ waveform and its transcription.
+
+ Args:
+ save_path (str): Path to save the generated wav.
+ prompt_text (str): Transcription of the prompt wav.
+ prompt_wav (str): Path to the prompt wav file.
+ text (str): Text to be synthesized into a waveform.
+ model (torch.nn.Module): The model used for generation.
+ vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
+ tokenizer (EmiliaTokenizer): The tokenizer used to convert text to tokens.
+ feature_extractor (VocosFbank): The feature extractor used to
+ extract acoustic features.
+ num_step (int, optional): Number of steps for decoding. Defaults to 16.
+ guidance_scale (float, optional): Scale for classifier-free guidance.
+ Defaults to 1.0.
+ speed (float, optional): Speed control. Defaults to 1.0.
+ t_shift (float, optional): Time shift. Defaults to 0.5.
+ target_rms (float, optional): Target RMS for waveform normalization.
+ Defaults to 0.1.
+ feat_scale (float, optional): Scale for features.
+ Defaults to 0.1.
+ sampling_rate (int, optional): Sampling rate for the waveform.
+ Defaults to 24000.
+ Returns:
+ metrics (dict): Dictionary containing time and real-time
+ factor metrics for processing.
+ """
+ # Convert text to tokens
+ tokens = tokenizer.texts_to_token_ids([text])
+ prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
+
+ # Load and preprocess prompt wav
+ prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)
+
+ if prompt_sampling_rate != sampling_rate:
+ resampler = torchaudio.transforms.Resample(
+ orig_freq=prompt_sampling_rate, new_freq=sampling_rate
+ )
+ prompt_wav = resampler(prompt_wav)
+
+ prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
+ if prompt_rms < target_rms:
+ prompt_wav = prompt_wav * target_rms / prompt_rms
+
+ # Extract features from prompt wav
+ prompt_features = feature_extractor.extract(prompt_wav, sampling_rate=sampling_rate)
+
+ prompt_features = prompt_features.unsqueeze(0) * feat_scale
+
+ # Start timing
+ start_t = dt.datetime.now()
+
+ # Generate features
+ pred_features = sample(
+ model=model,
+ tokens=tokens,
+ prompt_tokens=prompt_tokens,
+ prompt_features=prompt_features,
+ speed=speed,
+ t_shift=t_shift,
+ guidance_scale=guidance_scale,
+ num_step=num_step,
+ )
+
+ # Postprocess predicted features
+ pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
+
+ # Start vocoder processing
+ start_vocoder_t = dt.datetime.now()
+ wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
+
+ # Calculate processing times and real-time factors
+ t = (dt.datetime.now() - start_t).total_seconds()
+ t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
+ t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
+ wav_seconds = wav.shape[-1] / sampling_rate
+ rtf = t / wav_seconds
+ rtf_no_vocoder = t_no_vocoder / wav_seconds
+ rtf_vocoder = t_vocoder / wav_seconds
+ metrics = {
+ "t": t,
+ "t_no_vocoder": t_no_vocoder,
+ "t_vocoder": t_vocoder,
+ "wav_seconds": wav_seconds,
+ "rtf": rtf,
+ "rtf_no_vocoder": rtf_no_vocoder,
+ "rtf_vocoder": rtf_vocoder,
+ }
+
+ # Adjust wav volume if necessary
+ if prompt_rms < target_rms:
+ wav = wav * prompt_rms / target_rms
+ torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
+
+ return metrics
+
+
+def generate_list(
+ res_dir: str,
+ test_list: str,
+ model: OnnxModel,
+ vocoder: nn.Module,
+ tokenizer: EmiliaTokenizer,
+ feature_extractor: VocosFbank,
+ num_step: int = 16,
+ guidance_scale: float = 1.0,
+ speed: float = 1.0,
+ t_shift: float = 0.5,
+ target_rms: float = 0.1,
+ feat_scale: float = 0.1,
+ sampling_rate: int = 24000,
+):
+ total_t = []
+ total_t_no_vocoder = []
+ total_t_vocoder = []
+ total_wav_seconds = []
+
+ with open(test_list, "r") as fr:
+ lines = fr.readlines()
+
+ for i, line in enumerate(lines):
+ wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
+ save_path = f"{res_dir}/{wav_name}.wav"
+ metrics = generate_sentence(
+ save_path=save_path,
+ prompt_text=prompt_text,
+ prompt_wav=prompt_wav,
+ text=text,
+ model=model,
+ vocoder=vocoder,
+ tokenizer=tokenizer,
+ feature_extractor=feature_extractor,
+ num_step=num_step,
+ guidance_scale=guidance_scale,
+ speed=speed,
+ t_shift=t_shift,
+ target_rms=target_rms,
+ feat_scale=feat_scale,
+ sampling_rate=sampling_rate,
+ )
+ logging.info(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
+ total_t.append(metrics["t"])
+ total_t_no_vocoder.append(metrics["t_no_vocoder"])
+ total_t_vocoder.append(metrics["t_vocoder"])
+ total_wav_seconds.append(metrics["wav_seconds"])
+
+ logging.info(f"Average RTF: {np.sum(total_t) / np.sum(total_wav_seconds):.4f}")
+ logging.info(
+ f"Average RTF w/o vocoder: "
+ f"{np.sum(total_t_no_vocoder) / np.sum(total_wav_seconds):.4f}"
+ )
+ logging.info(
+ f"Average RTF vocoder: "
+ f"{np.sum(total_t_vocoder) / np.sum(total_wav_seconds):.4f}"
+ )
+
+
+@torch.inference_mode()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ params = AttributeDict()
+ params.update(vars(args))
+ fix_random_seed(params.seed)
+
+ model_defaults = {
+ "zipvoice": {
+ "num_step": 16,
+ "guidance_scale": 1.0,
+ },
+ "zipvoice_distill": {
+ "num_step": 8,
+ "guidance_scale": 3.0,
+ },
+ }
+
+ model_specific_defaults = model_defaults.get(params.model_name, {})
+
+ for param, value in model_specific_defaults.items():
+ if getattr(params, param) is None:
+ setattr(params, param, value)
+ logging.info(f"Setting {param} to default value: {value}")
+
+ assert (params.test_list is not None) ^ (
+ (params.prompt_wav and params.prompt_text and params.text) is not None
+ ), (
+ "For inference, please provide prompts and text with either '--test-list'"
+ " or '--prompt-wav, --prompt-text and --text'."
+ )
+
+ if params.onnx_int8:
+ text_encoder_name = "text_encoder_int8.onnx"
+ fm_decoder_name = "fm_decoder_int8.onnx"
+ else:
+ text_encoder_name = "text_encoder.onnx"
+ fm_decoder_name = "fm_decoder.onnx"
+
+ if params.model_dir is not None:
+ params.model_dir = Path(params.model_dir)
+ if not params.model_dir.is_dir():
+ raise FileNotFoundError(f"{params.model_dir} does not exist")
+
+ for filename in [
+ text_encoder_name,
+ fm_decoder_name,
+ "model.json",
+ "tokens.txt",
+ ]:
+ if not (params.model_dir / filename).is_file():
+ raise FileNotFoundError(f"{params.model_dir / filename} does not exist")
+ text_encoder_path = params.model_dir / text_encoder_name
+ fm_decoder_path = params.model_dir / fm_decoder_name
+ model_config = params.model_dir / "model.json"
+ token_file = params.model_dir / "tokens.txt"
+ logging.info(f"Using local model dir {params.model_dir}.")
+ else:
+ logging.info("Using pretrained model from the huggingface")
+ logging.info("Downloading the requires files from HuggingFace")
+ text_encoder_path = hf_hub_download(
+ HUGGINGFACE_REPO,
+ filename=f"{MODEL_DIR[params.model_name]}/{text_encoder_name}",
+ )
+ fm_decoder_path = hf_hub_download(
+ HUGGINGFACE_REPO,
+ filename=f"{MODEL_DIR[params.model_name]}/{fm_decoder_name}",
+ )
+ model_config = hf_hub_download(
+ HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/model.json"
+ )
+
+ token_file = hf_hub_download(
+ HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/tokens.txt"
+ )
+
+ logging.info("Loading model...")
+
+ if params.tokenizer == "emilia":
+ tokenizer = EmiliaTokenizer(token_file=token_file)
+ elif params.tokenizer == "libritts":
+ tokenizer = LibriTTSTokenizer(token_file=token_file)
+ elif params.tokenizer == "espeak":
+ tokenizer = EspeakTokenizer(token_file=token_file, lang=params.lang)
+ else:
+ assert params.tokenizer == "simple"
+ tokenizer = SimpleTokenizer(token_file=token_file)
+
+ with open(model_config, "r") as f:
+ model_config = json.load(f)
+
+ model = OnnxModel(text_encoder_path, fm_decoder_path)
+
+ vocoder = get_vocoder(params.vocoder_path)
+ vocoder.eval()
+
+ if model_config["feature"]["type"] == "vocos":
+ feature_extractor = VocosFbank()
+ else:
+ raise NotImplementedError(
+ f"Unsupported feature type: {model_config['feature']['type']}"
+ )
+ params.sampling_rate = model_config["feature"]["sampling_rate"]
+
+ logging.info("Start generating...")
+ if params.test_list:
+ os.makedirs(params.res_dir, exist_ok=True)
+ generate_list(
+ res_dir=params.res_dir,
+ test_list=params.test_list,
+ model=model,
+ vocoder=vocoder,
+ tokenizer=tokenizer,
+ feature_extractor=feature_extractor,
+ num_step=params.num_step,
+ guidance_scale=params.guidance_scale,
+ speed=params.speed,
+ t_shift=params.t_shift,
+ target_rms=params.target_rms,
+ feat_scale=params.feat_scale,
+ sampling_rate=params.sampling_rate,
+ )
+ else:
+ generate_sentence(
+ save_path=params.res_wav_path,
+ prompt_text=params.prompt_text,
+ prompt_wav=params.prompt_wav,
+ text=params.text,
+ model=model,
+ vocoder=vocoder,
+ tokenizer=tokenizer,
+ feature_extractor=feature_extractor,
+ num_step=params.num_step,
+ guidance_scale=params.guidance_scale,
+ speed=params.speed,
+ t_shift=params.t_shift,
+ target_rms=params.target_rms,
+ feat_scale=params.feat_scale,
+ sampling_rate=params.sampling_rate,
+ )
+ logging.info("Done")
+
+
+if __name__ == "__main__":
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
+
+ main()
diff --git a/zipvoice/bin/onnx_export.py b/zipvoice/bin/onnx_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..49cb3bac2f0d206a02cd31ace62e2ef44b55fd55
--- /dev/null
+++ b/zipvoice/bin/onnx_export.py
@@ -0,0 +1,410 @@
+#!/usr/bin/env python3
+# Copyright 2025 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script exports a pre-trained ZipVoice or ZipVoice-Distill model from PyTorch to
+ONNX.
+
+Usage:
+
+python3 -m zipvoice.bin.onnx_export \
+ --model-name zipvoice \
+ --model-dir exp/zipvoice \
+ --checkpoint-name epoch-11-avg-4.pt \
+ --onnx-model-dir exp/zipvoice
+
+`--model-name` can be `zipvoice` or `zipvoice_distill`,
+ which are the models before and after distillation, respectively.
+"""
+
+
+import argparse
+import json
+import logging
+from pathlib import Path
+from typing import Dict
+
+import onnx
+import safetensors.torch
+import torch
+from onnxruntime.quantization import QuantType, quantize_dynamic
+from torch import Tensor, nn
+
+from zipvoice.models.zipvoice import ZipVoice
+from zipvoice.models.zipvoice_distill import ZipVoiceDistill
+from zipvoice.tokenizer.tokenizer import SimpleTokenizer
+from zipvoice.utils.checkpoint import load_checkpoint
+from zipvoice.utils.common import AttributeDict
+from zipvoice.utils.scaling_converter import convert_scaled_to_non_scaled
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--onnx-model-dir",
+ type=str,
+ default="exp",
+ help="Dir to the exported models",
+ )
+
+ parser.add_argument(
+ "--model-name",
+ type=str,
+ default="zipvoice",
+ choices=["zipvoice", "zipvoice_distill"],
+ help="The model used for inference",
+ )
+
+ parser.add_argument(
+ "--model-dir",
+ type=str,
+ default=None,
+ help="The model directory that contains model checkpoint, configuration "
+ "file model.json, and tokens file tokens.txt. Will download pre-trained "
+ "checkpoint from huggingface if not specified.",
+ )
+
+ parser.add_argument(
+ "--checkpoint-name",
+ type=str,
+ default="model.pt",
+ help="The name of model checkpoint.",
+ )
+
+ return parser
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, str]):
+ """Add meta data to an ONNX model. It is changed in-place.
+
+ Args:
+ filename:
+ Filename of the ONNX model to be changed.
+ meta_data:
+ Key-value pairs.
+ """
+ model = onnx.load(filename)
+ for key, value in meta_data.items():
+ meta = model.metadata_props.add()
+ meta.key = key
+ meta.value = value
+
+ onnx.save(model, filename)
+
+
+class OnnxTextModel(nn.Module):
+ def __init__(self, model: nn.Module):
+ """A wrapper for ZipVoice text encoder."""
+ super().__init__()
+ self.embed = model.embed
+ self.text_encoder = model.text_encoder
+ self.pad_id = model.pad_id
+
+ def forward(
+ self,
+ tokens: Tensor,
+ prompt_tokens: Tensor,
+ prompt_features_len: Tensor,
+ speed: Tensor,
+ ) -> Tensor:
+ cat_tokens = torch.cat([prompt_tokens, tokens], dim=1)
+ cat_tokens = nn.functional.pad(cat_tokens, (0, 1), value=self.pad_id)
+ tokens_len = cat_tokens.shape[1] - 1
+ padding_mask = (torch.arange(tokens_len + 1) == tokens_len).unsqueeze(0)
+
+ embed = self.embed(cat_tokens)
+ embed = self.text_encoder(x=embed, t=None, padding_mask=padding_mask)
+
+ features_len = torch.ceil(
+ (prompt_features_len / prompt_tokens.shape[1] * tokens_len / speed)
+ ).to(dtype=torch.int64)
+
+ token_dur = torch.div(features_len, tokens_len, rounding_mode="floor").to(
+ dtype=torch.int64
+ )
+
+ text_condition = embed[:, :-1, :].unsqueeze(2).expand(-1, -1, token_dur, -1)
+ text_condition = text_condition.reshape(embed.shape[0], -1, embed.shape[2])
+
+ text_condition = torch.cat(
+ [
+ text_condition,
+ embed[:, -1:, :].expand(-1, features_len - text_condition.shape[1], -1),
+ ],
+ dim=1,
+ )
+
+ return text_condition
+
+
+class OnnxFlowMatchingModel(nn.Module):
+ def __init__(self, model: nn.Module, distill: bool = False):
+ """A wrapper for ZipVoice flow-matching decoder."""
+ super().__init__()
+ self.distill = distill
+ self.fm_decoder = model.fm_decoder
+ self.model_func = getattr(model, "forward_fm_decoder")
+ self.feat_dim = model.feat_dim
+
+ def forward(
+ self,
+ t: Tensor,
+ x: Tensor,
+ text_condition: Tensor,
+ speech_condition: torch.Tensor,
+ guidance_scale: Tensor,
+ ) -> Tensor:
+ if self.distill:
+ return self.model_func(
+ t=t,
+ xt=x,
+ text_condition=text_condition,
+ speech_condition=speech_condition,
+ guidance_scale=guidance_scale,
+ )
+ else:
+ x = x.repeat(2, 1, 1)
+ text_condition = torch.cat(
+ [torch.zeros_like(text_condition), text_condition], dim=0
+ )
+ speech_condition = torch.cat(
+ [
+ torch.where(
+ t > 0.5, torch.zeros_like(speech_condition), speech_condition
+ ),
+ speech_condition,
+ ],
+ dim=0,
+ )
+ guidance_scale = torch.where(t > 0.5, guidance_scale, guidance_scale * 2.0)
+ data_uncond, data_cond = self.model_func(
+ t=t,
+ xt=x,
+ text_condition=text_condition,
+ speech_condition=speech_condition,
+ ).chunk(2, dim=0)
+ v = (1 + guidance_scale) * data_cond - guidance_scale * data_uncond
+ return v
+
+
+def export_text_encoder(
+ model: OnnxTextModel,
+ filename: str,
+ opset_version: int = 11,
+) -> None:
+ """Export the text encoder model to ONNX format.
+
+ Args:
+ model:
+ The input model
+ filename:
+ The filename to save the exported ONNX model.
+ opset_version:
+ The opset version to use.
+ """
+ tokens = torch.tensor([[2, 3, 4, 5]], dtype=torch.int64)
+ prompt_tokens = torch.tensor([[0, 1]], dtype=torch.int64)
+ prompt_features_len = torch.tensor(10, dtype=torch.int64)
+ speed = torch.tensor(1.0, dtype=torch.float32)
+
+ model = torch.jit.trace(model, (tokens, prompt_tokens, prompt_features_len, speed))
+
+ torch.onnx.export(
+ model,
+ (tokens, prompt_tokens, prompt_features_len, speed),
+ filename,
+ verbose=False,
+ opset_version=opset_version,
+ input_names=["tokens", "prompt_tokens", "prompt_features_len", "speed"],
+ output_names=["text_condition"],
+ dynamic_axes={
+ "tokens": {0: "N", 1: "T"},
+ "prompt_tokens": {0: "N", 1: "T"},
+ "text_condition": {0: "N", 1: "T"},
+ },
+ )
+
+ meta_data = {
+ "version": "1",
+ "model_author": "k2-fsa",
+ "comment": "ZipVoice text encoder",
+ }
+ logging.info(f"meta_data: {meta_data}")
+ add_meta_data(filename=filename, meta_data=meta_data)
+
+ logging.info(f"Exported to {filename}")
+
+
+def export_fm_decoder(
+ model: OnnxFlowMatchingModel,
+ filename: str,
+ opset_version: int = 11,
+) -> None:
+ """Export the flow matching decoder model to ONNX format.
+
+ Args:
+ model:
+ The input model
+ filename:
+ The filename to save the exported ONNX model.
+ opset_version:
+ The opset version to use.
+ """
+ feat_dim = model.feat_dim
+ seq_len = 200
+ t = torch.tensor(0.5, dtype=torch.float32)
+ x = torch.randn(1, seq_len, feat_dim, dtype=torch.float32)
+ text_condition = torch.randn(1, seq_len, feat_dim, dtype=torch.float32)
+ speech_condition = torch.randn(1, seq_len, feat_dim, dtype=torch.float32)
+ guidance_scale = torch.tensor(1.0, dtype=torch.float32)
+
+ model = torch.jit.trace(
+ model, (t, x, text_condition, speech_condition, guidance_scale)
+ )
+
+ torch.onnx.export(
+ model,
+ (t, x, text_condition, speech_condition, guidance_scale),
+ filename,
+ verbose=False,
+ opset_version=opset_version,
+ input_names=["t", "x", "text_condition", "speech_condition", "guidance_scale"],
+ output_names=["v"],
+ dynamic_axes={
+ "x": {0: "N", 1: "T"},
+ "text_condition": {0: "N", 1: "T"},
+ "speech_condition": {0: "N", 1: "T"},
+ "v": {0: "N", 1: "T"},
+ },
+ )
+
+ meta_data = {
+ "version": "1",
+ "model_author": "k2-fsa",
+ "comment": "ZipVoice flow-matching decoder",
+ "feat_dim": str(feat_dim),
+ }
+ logging.info(f"meta_data: {meta_data}")
+ add_meta_data(filename=filename, meta_data=meta_data)
+
+ logging.info(f"Exported to {filename}")
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ params = AttributeDict()
+ params.update(vars(args))
+
+ params.model_dir = Path(params.model_dir)
+ if not params.model_dir.is_dir():
+ raise FileNotFoundError(f"{params.model_dir} does not exist")
+ for filename in [params.checkpoint_name, "model.json", "tokens.txt"]:
+ if not (params.model_dir / filename).is_file():
+ raise FileNotFoundError(f"{params.model_dir / filename} does not exist")
+ model_ckpt = params.model_dir / params.checkpoint_name
+ model_config = params.model_dir / "model.json"
+ token_file = params.model_dir / "tokens.txt"
+
+ logging.info(f"Loading model from {params.model_dir}")
+
+ tokenizer = SimpleTokenizer(token_file)
+ tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
+
+ with open(model_config, "r") as f:
+ model_config = json.load(f)
+
+ if params.model_name == "zipvoice":
+ model = ZipVoice(
+ **model_config["model"],
+ **tokenizer_config,
+ )
+ distill = False
+ else:
+ assert params.model_name == "zipvoice_distill"
+ model = ZipVoiceDistill(
+ **model_config["model"],
+ **tokenizer_config,
+ )
+ distill = True
+
+ if str(model_ckpt).endswith(".safetensors"):
+ safetensors.torch.load_model(model, model_ckpt)
+ elif str(model_ckpt).endswith(".pt"):
+ load_checkpoint(filename=model_ckpt, model=model, strict=True)
+ else:
+ raise NotImplementedError(f"Unsupported model checkpoint format: {model_ckpt}")
+
+ device = torch.device("cpu")
+ model = model.to(device)
+ model.eval()
+
+ convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
+
+ logging.info("Exporting model")
+ onnx_model_dir = Path(params.onnx_model_dir)
+ onnx_model_dir.mkdir(parents=True, exist_ok=True)
+ opset_version = 11
+
+ text_encoder = OnnxTextModel(model=model)
+ text_encoder_file = onnx_model_dir / "text_encoder.onnx"
+ export_text_encoder(
+ model=text_encoder,
+ filename=text_encoder_file,
+ opset_version=opset_version,
+ )
+
+ fm_decoder = OnnxFlowMatchingModel(model=model, distill=distill)
+ fm_decoder_file = onnx_model_dir / "fm_decoder.onnx"
+ export_fm_decoder(
+ model=fm_decoder,
+ filename=fm_decoder_file,
+ opset_version=opset_version,
+ )
+
+ logging.info("Generate int8 quantization models")
+
+ text_encoder_int8_file = onnx_model_dir / "text_encoder_int8.onnx"
+ quantize_dynamic(
+ model_input=text_encoder_file,
+ model_output=text_encoder_int8_file,
+ op_types_to_quantize=["MatMul"],
+ weight_type=QuantType.QInt8,
+ )
+
+ fm_decoder_int8_file = onnx_model_dir / "fm_decoder_int8.onnx"
+ quantize_dynamic(
+ model_input=fm_decoder_file,
+ model_output=fm_decoder_int8_file,
+ op_types_to_quantize=["MatMul"],
+ weight_type=QuantType.QInt8,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
+
+ main()
diff --git a/zipvoice/bin/prepare_dataset.py b/zipvoice/bin/prepare_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3a2b28b40fa26ea971c61d969f9860221fe2b14
--- /dev/null
+++ b/zipvoice/bin/prepare_dataset.py
@@ -0,0 +1,274 @@
+#!/usr/bin/env python3
+# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script generates lhotse manifest files from TSV files for custom datasets.
+
+Each line of the TSV files should be in one of the following formats:
+1. "{uniq_id}\t{text}\t{wav_path}" if the text corresponds to the full wav",
+2. "{uniq_id}\t{text}\t{wav_path}\t{start_time}\t{end_time} if text corresponds
+ to part of the wav. The start_time and end_time specify the start and end
+ times of the text within the wav, which should be in seconds.
+
+Note: {uniq_id} must be unique for each line.
+
+Usage:
+
+Suppose you have two TSV files: "custom_train.tsv" and "custom_dev.tsv",
+where "custom" is your dataset name, "train"/"dev" are used for training and
+validation respectively.
+
+(1) Prepare the training data
+
+python3 -m zipvoice.bin.prepare_dataset \
+ --tsv-path data/raw/custom_train.tsv \
+ --prefix "custom" \
+ --subset "train" \
+ --num-jobs 20 \
+ --output-dir "data/manifests"
+
+The output file would be "data/manifests/custom_cuts_train.jsonl.gz".
+
+(2) Prepare the validation data
+
+python3 -m zipvoice.bin.prepare_dataset \
+ --tsv-path data/raw/custom_dev.tsv \
+ --prefix "custom" \
+ --subset "dev" \
+ --num-jobs 1 \
+ --output-dir "data/manifests"
+
+The output file would be "data/manifests/custom_cuts_dev.jsonl.gz".
+
+"""
+
+import argparse
+import logging
+import re
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+from typing import List, Optional, Tuple
+
+from lhotse import CutSet, validate_recordings_and_supervisions
+from lhotse.audio import Recording, RecordingSet
+from lhotse.qa import fix_manifests
+from lhotse.supervision import SupervisionSegment, SupervisionSet
+from lhotse.utils import Pathlike
+from tqdm.auto import tqdm
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--tsv-path",
+ type=str,
+ help="The path of the tsv file. Each line should be in the format: "
+ "{uniq_id}\t{text}\t{wav_path}\t{start_time}\t{end_time} "
+ "if text corresponds to part of the wav or {uniq_id}\t{text}\t{wav_path} "
+ "if the text corresponds to the full wav",
+ )
+ parser.add_argument(
+ "--prefix",
+ type=str,
+ default="custom",
+ help="Prefix of the output manifest file.",
+ )
+
+ parser.add_argument(
+ "--subset",
+ type=str,
+ default="train",
+ help="Subset name manifest file, typically train or dev.",
+ )
+
+ parser.add_argument(
+ "--num-jobs",
+ type=int,
+ default=20,
+ help="Number of jobs to processing.",
+ )
+
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ default="data/manifests",
+ help="The destination directory of manifest files.",
+ )
+ parser.add_argument(
+ "--sampling-rate",
+ type=int,
+ default=24000,
+ help="The target sampling rate.",
+ )
+ return parser.parse_args()
+
+
+def _parse_recording(
+ wav_path: str,
+) -> Tuple[Recording, str]:
+ """
+ :param wav_path: Path to the audio file
+ :return: a tuple of "recording" and "recording_id"
+ """
+
+ recording_id = wav_path.replace("/", "_").replace(".", "_")
+ recording = Recording.from_file(path=wav_path, recording_id=recording_id)
+
+ return recording, recording_id
+
+
+def _parse_supervision(
+ supervision: List, recording_dict: dict
+) -> Optional[SupervisionSegment]:
+ """
+ :param line: A line from the TSV file
+ :param recording_dict: Dictionary mapping recording IDs to Recording objects
+ :return: A SupervisionSegment object
+ """
+
+ uniq_id, text, wav_path, start, end = supervision
+ try:
+ recording_id = wav_path.replace("/", "_").replace(".", "_")
+
+ recording = recording_dict[recording_id]
+ duration = end - start if end is not None else recording.duration
+ assert duration <= recording.duration, f"Duration {duration} is greater than "
+ f"recording duration {recording.duration}"
+
+ text = re.sub("_", " ", text) # "_" is treated as padding symbol
+ text = re.sub(r"\s+", " ", text) # remove extra whitespace
+
+ return SupervisionSegment(
+ id=f"{uniq_id}",
+ recording_id=recording.id,
+ start=start,
+ duration=duration,
+ channel=recording.channel_ids,
+ text=text.strip(),
+ )
+ except Exception as e:
+ logging.warning(f"Error processing line: {e}")
+ return None
+
+
+def prepare_dataset(
+ tsv_path: Pathlike,
+ prefix: str,
+ subset: str,
+ sampling_rate: int,
+ num_jobs: int,
+ output_dir: Pathlike,
+):
+ """
+ Returns the manifests which consist of the Recordings and Supervisions
+
+ :param tsv_path: Path to the TSV file
+ :param output_dir: Path where to write the manifests
+ :param num_jobs: Number of processes for parallel processing
+ :return: The CutSet containing the data
+ """
+ logging.info(f"Preparing {prefix} dataset {subset} subset.")
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+ file_name = f"{prefix}_cuts_{subset}.jsonl.gz"
+ if (output_dir / file_name).is_file():
+ logging.info(f"{file_name} exists, skipping.")
+ return
+
+ # Step 1: Read all unique recording paths
+ recordings_path_set = set()
+ supervision_list = list()
+ with open(tsv_path, "r") as fr:
+ for line in fr:
+ items = line.strip().split("\t")
+ if len(items) == 3:
+ uniq_id, text, wav_path = items
+ start, end = 0, None
+ elif len(items) == 5:
+ uniq_id, text, wav_path, start, end = items
+ start, end = float(start), float(end)
+ else:
+ raise ValueError(
+ f"Invalid line format: {line},"
+ "requries to be 3 columns or 5 columns"
+ )
+ recordings_path_set.add(wav_path)
+ supervision_list.append((uniq_id, text, wav_path, start, end))
+
+ logging.info("Starting to process recordings...")
+ # Step 2: Process recordings
+ futures = []
+ recording_dict = {}
+ with ThreadPoolExecutor(max_workers=num_jobs) as ex:
+ for wav_path in tqdm(recordings_path_set, desc="Submitting jobs"):
+ futures.append(ex.submit(_parse_recording, wav_path))
+
+ for future in tqdm(futures, desc="Processing recordings"):
+ try:
+ recording, recording_id = future.result()
+ recording_dict[recording_id] = recording
+ except Exception as e:
+ logging.warning(
+ f"Error processing recording {recording_id} with error: {e}"
+ )
+
+ recording_set = RecordingSet.from_recordings(recording_dict.values())
+
+ logging.info("Starting to process supervisions...")
+ # Step 3: Process supervisions
+ supervisions = []
+ for supervision in tqdm(supervision_list, desc="Processing supervisions"):
+ seg = _parse_supervision(supervision, recording_dict)
+ if seg is not None:
+ supervisions.append(seg)
+
+ logging.info("Processing Cuts...")
+
+ # Step 4: Create and validate manifests
+ supervision_set = SupervisionSet.from_segments(supervisions)
+
+ recording_set, supervision_set = fix_manifests(recording_set, supervision_set)
+ validate_recordings_and_supervisions(recording_set, supervision_set)
+
+ cut_set = CutSet.from_manifests(
+ recordings=recording_set, supervisions=supervision_set
+ )
+ cut_set = cut_set.sort_by_recording_id()
+ cut_set = cut_set.resample(sampling_rate)
+ cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
+
+ logging.info(f"Saving file to {output_dir / file_name}")
+ # Step 5: Write manifests to disk
+ cut_set.to_file(output_dir / file_name)
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
+
+ args = get_args()
+
+ prepare_dataset(
+ tsv_path=args.tsv_path,
+ prefix=args.prefix,
+ subset=args.subset,
+ sampling_rate=args.sampling_rate,
+ num_jobs=args.num_jobs,
+ output_dir=args.output_dir,
+ )
diff --git a/zipvoice/bin/prepare_tokens.py b/zipvoice/bin/prepare_tokens.py
new file mode 100644
index 0000000000000000000000000000000000000000..724d742784fd7ec3dbaa22ffbf23d3d07836c013
--- /dev/null
+++ b/zipvoice/bin/prepare_tokens.py
@@ -0,0 +1,102 @@
+"""
+This file reads the texts in given manifest and save the new cuts with prepared tokens.
+"""
+
+import argparse
+import logging
+from functools import partial
+from pathlib import Path
+
+from lhotse import load_manifest, split_parallelize_combine
+
+from zipvoice.tokenizer.tokenizer import add_tokens
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--input-file",
+ type=str,
+ help="Input manifest without tokens",
+ )
+
+ parser.add_argument(
+ "--output-file",
+ type=str,
+ help="Output manifest with tokens.",
+ )
+
+ parser.add_argument(
+ "--num-jobs",
+ type=int,
+ default=20,
+ help="Number of jobs to run in parallel.",
+ )
+
+ parser.add_argument(
+ "--tokenizer",
+ type=str,
+ default="emilia",
+ help="The destination directory of manifest files.",
+ )
+
+ parser.add_argument(
+ "--lang",
+ type=str,
+ default="en-us",
+ help="Language identifier, used when tokenizer type is espeak. see"
+ "https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
+ )
+
+ return parser.parse_args()
+
+
+def prepare_tokens(
+ input_file: Path,
+ output_file: Path,
+ num_jobs: int,
+ tokenizer: str,
+ lang: str = "en-us",
+):
+ logging.info(f"Processing {input_file}")
+ if output_file.is_file():
+ logging.info(f"{output_file} exists, skipping.")
+ return
+ logging.info(f"loading manifest from {input_file}")
+ cut_set = load_manifest(input_file)
+
+ _add_tokens = partial(add_tokens, tokenizer=tokenizer, lang=lang)
+
+ logging.info("Adding tokens")
+
+ cut_set = split_parallelize_combine(
+ num_jobs=num_jobs, manifest=cut_set, fn=_add_tokens
+ )
+
+ logging.info(f"Saving file to {output_file}")
+ cut_set.to_file(output_file)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
+
+ args = get_args()
+ input_file = Path(args.input_file)
+ output_file = Path(args.output_file)
+ num_jobs = args.num_jobs
+ tokenizer = args.tokenizer
+ lang = args.lang
+
+ output_file.parent.mkdir(parents=True, exist_ok=True)
+
+ prepare_tokens(
+ input_file=input_file,
+ output_file=output_file,
+ num_jobs=num_jobs,
+ tokenizer=tokenizer,
+ lang=lang,
+ )
+
+ logging.info("Done!")
diff --git a/zipvoice/bin/train_zipvoice.py b/zipvoice/bin/train_zipvoice.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f1376c8f4376e8e9dc0af78e8dd1bafa572264e
--- /dev/null
+++ b/zipvoice/bin/train_zipvoice.py
@@ -0,0 +1,1136 @@
+#!/usr/bin/env python3
+# Copyright 2024-2025 Xiaomi Corp. (authors: Wei Kang,
+# Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script trains a ZipVoice model with the flow-matching loss.
+
+Usage:
+
+python3 -m zipvoice.bin.train_zipvoice \
+ --world-size 8 \
+ --use-fp16 1 \
+ --num-epochs 11 \
+ --max-duration 500 \
+ --lr-hours 30000 \
+ --model-config conf/zipvoice_base.json \
+ --tokenizer emilia \
+ --token-file "data/tokens_emilia.txt" \
+ --dataset emilia \
+ --manifest-dir data/fbank \
+ --exp-dir exp/zipvoice
+"""
+
+import argparse
+import copy
+import json
+import logging
+import os
+from functools import partial
+from pathlib import Path
+from shutil import copyfile
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from lhotse.cut import Cut, CutSet
+from lhotse.utils import fix_random_seed
+from torch import Tensor
+from torch.amp.grad_scaler import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
+from torch.utils.tensorboard import SummaryWriter
+
+import zipvoice.utils.diagnostics as diagnostics
+from zipvoice.dataset.datamodule import TtsDataModule
+from zipvoice.models.zipvoice import ZipVoice
+from zipvoice.tokenizer.tokenizer import (
+ EmiliaTokenizer,
+ EspeakTokenizer,
+ LibriTTSTokenizer,
+ SimpleTokenizer,
+ SimpleTokenizer2,
+)
+from zipvoice.utils.checkpoint import (
+ load_checkpoint,
+ remove_checkpoints,
+ resume_checkpoint,
+ save_checkpoint,
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from zipvoice.utils.common import (
+ AttributeDict,
+ MetricsTracker,
+ cleanup_dist,
+ create_grad_scaler,
+ get_adjusted_batch_count,
+ get_env_info,
+ get_parameter_groups_with_lrs,
+ prepare_input,
+ set_batch_count,
+ setup_dist,
+ setup_logger,
+ str2bool,
+ torch_autocast,
+)
+from zipvoice.utils.hooks import register_inf_check_hooks
+from zipvoice.utils.lr_scheduler import Eden, FixedLRScheduler, LRScheduler
+from zipvoice.utils.optim import ScaledAdam
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler]
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12356,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=11,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--num-iters",
+ type=int,
+ default=0,
+ help="Number of iter to train, will ignore num_epochs if > 0.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--checkpoint",
+ type=str,
+ default=None,
+ help="""Checkpoints of pre-trained models, will load it if not None
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="exp/zipvoice",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.02, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=7500,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=10,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--lr-hours",
+ type=float,
+ default=0,
+ help="""If positive, --epoch is ignored and it specifies the number of hours
+ that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=50,
+ help="""Reference batch duration for purposes of adjusting batch counts for"
+ setting various schedules inside the model".
+ """,
+ )
+
+ parser.add_argument(
+ "--finetune",
+ type=str2bool,
+ default=False,
+ help="Whether to use the fine-tuning mode, will used a fixed learning rate "
+ "schedule and skip the large dropout phase.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--scan-oom",
+ type=str2bool,
+ default=False,
+ help="Scan pessimistic batches to see whether they cause OOMs.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=5000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--valid-by-epoch",
+ type=str2bool,
+ default=False,
+ help="""Whether to validate after each epoch. If False, will validate
+ after every save_every_n iterations.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=True,
+ help="Whether to use half precision training.",
+ )
+
+ parser.add_argument(
+ "--feat-scale",
+ type=float,
+ default=0.1,
+ help="The scale factor of fbank feature",
+ )
+
+ parser.add_argument(
+ "--condition-drop-ratio",
+ type=float,
+ default=0.2,
+ help="The drop rate of text condition during training.",
+ )
+
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ default="emilia",
+ choices=["emilia", "libritts", "custom"],
+ help="The used training dataset",
+ )
+
+ parser.add_argument(
+ "--train-manifest",
+ type=str,
+ help="Path of the training manifest",
+ )
+
+ parser.add_argument(
+ "--dev-manifest",
+ type=str,
+ help="Path of the validation manifest",
+ )
+
+ parser.add_argument(
+ "--min-len",
+ type=float,
+ default=1.0,
+ help="The minimum audio length used for training",
+ )
+
+ parser.add_argument(
+ "--max-len",
+ type=float,
+ default=30.0,
+ help="The maximum audio length used for training",
+ )
+
+ parser.add_argument(
+ "--model-config",
+ type=str,
+ default="conf/zipvoice_base.json",
+ help="The model configuration file.",
+ )
+
+ parser.add_argument(
+ "--tokenizer",
+ type=str,
+ default="emilia",
+ help="Tokenizer type.",
+ )
+
+ parser.add_argument(
+ "--lang",
+ type=str,
+ default="en-us",
+ help="Language identifier, used when tokenizer type is espeak. see"
+ "https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
+ )
+
+ parser.add_argument(
+ "--token-file",
+ type=str,
+ default="data/tokens_emilia.txt",
+ help="The file that contains information that maps tokens to ids,"
+ "which is a text file with '{token}\t{token_id}' per line.",
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+ - batch_idx_train: Used to writing statistics to tensorboard. It
+ contains number of batches trained so far across
+ epochs.
+
+ - log_interval: Print training loss if batch_idx % log_interval` is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - env_info: A dict containing information about the environment.
+
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def compute_fbank_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ features: Tensor,
+ features_lens: Tensor,
+ tokens: List[List[int]],
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training.
+ features:
+ The target acoustic feature.
+ features_lens:
+ The number of frames of each utterance.
+ tokens:
+ Input tokens that representing the transcripts.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ batch_size, num_frames, _ = features.shape
+
+ features = torch.nn.functional.pad(
+ features, (0, 0, 0, num_frames - features.size(1))
+ ) # (B, T, F)
+ noise = torch.randn_like(features) # (B, T, F)
+
+ # Sampling t from uniform distribution
+ if is_training:
+ t = torch.rand(batch_size, 1, 1, device=device)
+ else:
+ t = (
+ (torch.arange(batch_size, device=device) / batch_size)
+ .unsqueeze(1)
+ .unsqueeze(2)
+ )
+ with torch.set_grad_enabled(is_training):
+
+ loss = model(
+ tokens=tokens,
+ features=features,
+ features_lens=features_lens,
+ noise=noise,
+ t=t,
+ condition_drop_ratio=params.condition_drop_ratio,
+ )
+
+ assert loss.requires_grad == is_training
+ info = MetricsTracker()
+ num_frames = features_lens.sum().item()
+ info["frames"] = num_frames
+ info["loss"] = loss.detach().cpu().item() * num_frames
+
+ return loss, info
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: Optimizer,
+ scheduler: LRSchedulerType,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches or every epochs.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer.
+ scheduler:
+ The learning rate scheduler, we call step() every epoch.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mix precision training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ # used to track the stats over iterations in one epoch
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+
+ if batch_idx % 10 == 0:
+ if params.finetune:
+ set_batch_count(model, get_adjusted_batch_count(params) + 100000)
+ else:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ if (
+ params.valid_by_epoch and batch_idx == 0 and not params.print_diagnostics
+ ) or (
+ not params.valid_by_epoch
+ and params.batch_idx_train % params.valid_interval == 0
+ and not params.print_diagnostics
+ ):
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(
+ f"Epoch {params.cur_epoch}, global_batch_idx: {params.batch_idx_train},"
+ f" validation: {valid_info}"
+ )
+ logging.info(
+ f"Maximum memory allocated so far is "
+ f"{torch.cuda.max_memory_allocated() // 1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ params.batch_idx_train += 1
+
+ batch_size = len(batch["text"])
+
+ tokens, features, features_lens = prepare_input(
+ params=params,
+ batch=batch,
+ device=device,
+ return_tokens=True,
+ return_feature=True,
+ )
+
+ try:
+ with torch_autocast(dtype=torch.float16, enabled=params.use_fp16):
+ loss, loss_info = compute_fbank_loss(
+ params=params,
+ model=model,
+ features=features,
+ features_lens=features_lens,
+ tokens=tokens,
+ is_training=True,
+ )
+
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ scaler.scale(loss).backward()
+
+ scheduler.step_batch(params.batch_idx_train)
+ # Use the number of hours of speech to adjust the learning rate
+ if params.lr_hours > 0:
+ scheduler.step_epoch(
+ params.batch_idx_train
+ * params.max_duration
+ * params.world_size
+ / 3600
+ )
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except Exception as e:
+ logging.info(f"Caught exception : {e}.")
+ save_bad_model()
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+ if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
+ break
+ if params.batch_idx_train % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have
+ # different behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 1024.0 or (
+ cur_grad_scale < 4096.0 and params.batch_idx_train % 400 == 0
+ ):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if params.batch_idx_train % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, batch {batch_idx}, "
+ f"global_batch_idx: {params.batch_idx_train}, "
+ f"batch size: {batch_size}, "
+ f"loss[{loss_info}], tot_loss[{tot_loss}], "
+ f"cur_lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale",
+ cur_grad_scale,
+ params.batch_idx_train,
+ )
+
+ loss_value = tot_loss["loss"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+
+ model.eval()
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ # used to summary the stats over iterations
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ tokens, features, features_lens = prepare_input(
+ params=params,
+ batch=batch,
+ device=device,
+ return_tokens=True,
+ return_feature=True,
+ )
+
+ loss, loss_info = compute_fbank_loss(
+ params=params,
+ model=model,
+ features=features,
+ features_lens=features_lens,
+ tokens=tokens,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ features = batch["features"]
+ tokens = batch["tokens"]
+
+ logging.info(f"features shape: {features.shape}")
+ num_tokens = sum(len(i) for i in tokens)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ tokens, features, features_lens = prepare_input(
+ params=params,
+ batch=batch,
+ device=device,
+ return_tokens=True,
+ return_feature=True,
+ )
+ try:
+ with torch_autocast(dtype=torch.float16, enabled=params.use_fp16):
+
+ loss, loss_info = compute_fbank_loss(
+ params=params,
+ model=model,
+ features=features,
+ features_lens=features_lens,
+ tokens=tokens,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is "
+ f"{torch.cuda.max_memory_allocated() // 1000000}MB"
+ )
+
+
+def tokenize_text(c: Cut, tokenizer):
+ if hasattr(c.supervisions[0], "tokens"):
+ tokens = tokenizer.tokens_to_token_ids([c.supervisions[0].tokens])
+ else:
+ tokens = tokenizer.texts_to_token_ids([c.supervisions[0].text])
+ print("ko tìm được tokens")
+ c.supervisions[0].tokens = tokens[0]
+ return c
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+ params.valid_interval = params.save_every_n
+ # Set epoch to a large number to ignore it.
+ if params.num_iters > 0:
+ params.num_epochs = 1000000
+ with open(params.model_config, "r") as f:
+ model_config = json.load(f)
+ params.update(model_config["model"])
+ params.update(model_config["feature"])
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ os.makedirs(f"{params.exp_dir}", exist_ok=True)
+ copyfile(src=params.model_config, dst=f"{params.exp_dir}/model.json")
+ copyfile(src=params.token_file, dst=f"{params.exp_dir}/tokens.txt")
+ setup_logger(f"{params.exp_dir}/log/log-train")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ if torch.cuda.is_available():
+ params.device = torch.device("cuda", rank)
+ else:
+ params.device = torch.device("cpu")
+ logging.info(f"Device: {params.device}")
+
+ if params.tokenizer == "emilia":
+ tokenizer = EmiliaTokenizer(token_file=params.token_file)
+ elif params.tokenizer == "libritts":
+ tokenizer = LibriTTSTokenizer(token_file=params.token_file)
+ elif params.tokenizer == "espeak":
+ tokenizer = EspeakTokenizer(token_file=params.token_file, lang=params.lang)
+ elif params.tokenizer == "simple2":
+ tokenizer = SimpleTokenizer2(token_file=params.token_file)
+ else:
+ assert params.tokenizer == "simple"
+ tokenizer = SimpleTokenizer(token_file=params.token_file)
+
+ tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
+ params.update(tokenizer_config)
+
+ logging.info(params)
+
+ logging.info("About to create model")
+
+ model = ZipVoice(
+ **model_config["model"],
+ **tokenizer_config,
+ )
+
+ if params.checkpoint is not None:
+ logging.info(f"Loading pre-trained model from {params.checkpoint}")
+ _ = load_checkpoint(filename=params.checkpoint, model=model, strict=True)
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of parameters : {num_param}")
+
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ if params.start_epoch > 1:
+ checkpoints = resume_checkpoint(params=params, model=model, model_avg=model_avg)
+
+ model = model.to(params.device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(
+ model,
+ lr=params.base_lr,
+ include_names=True,
+ ),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ assert params.lr_hours >= 0
+
+ if params.finetune:
+ scheduler = FixedLRScheduler(optimizer)
+ elif params.lr_hours > 0:
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_hours)
+ else:
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ scaler = create_grad_scaler(enabled=params.use_fp16)
+
+ if params.start_epoch > 1 and checkpoints is not None:
+ # load state_dict for optimizers
+ if "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ # load state_dict for schedulers
+ if "scheduler" in checkpoints:
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ def remove_short_and_long_utt(c: Cut, min_len: float, max_len: float):
+ if c.duration < min_len or c.duration > max_len:
+ return False
+ return True
+
+ _remove_short_and_long_utt = partial(
+ remove_short_and_long_utt, min_len=params.min_len, max_len=params.max_len
+ )
+
+ datamodule = TtsDataModule(args)
+ if params.dataset == "emilia":
+ train_cuts = CutSet.mux(
+ datamodule.train_emilia_EN_cuts(),
+ datamodule.train_emilia_ZH_cuts(),
+ weights=[46000, 49000],
+ )
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
+ dev_cuts = CutSet.mux(
+ datamodule.dev_emilia_EN_cuts(),
+ datamodule.dev_emilia_ZH_cuts(),
+ weights=[0.5, 0.5],
+ )
+ elif params.dataset == "libritts":
+ train_cuts = datamodule.train_libritts_cuts()
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
+ dev_cuts = datamodule.dev_libritts_cuts()
+ else:
+ assert params.dataset == "custom"
+ train_cuts = datamodule.train_custom_cuts(params.train_manifest)
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
+ dev_cuts = datamodule.dev_custom_cuts(params.dev_manifest)
+ # To avoid OOM issues due to too long dev cuts
+ dev_cuts = dev_cuts.filter(_remove_short_and_long_utt)
+
+ if params.tokenizer in ["emilia", "espeak", "dialog"]:
+ if not hasattr(train_cuts[0].supervisions[0], "tokens") or not hasattr(
+ dev_cuts[0].supervisions[0], "tokens"
+ ):
+ logging.warning(
+ f"Using {params.tokenizer} tokenizer but tokens are not prepared,"
+ f"will tokenize on-the-fly, which can slow down training significantly."
+ )
+ _tokenize_text = partial(tokenize_text, tokenizer=tokenizer)
+ train_cuts = train_cuts.map(_tokenize_text)
+ dev_cuts = dev_cuts.map(_tokenize_text)
+
+ train_dl = datamodule.train_dataloaders(train_cuts)
+
+ valid_dl = datamodule.dev_dataloaders(dev_cuts)
+
+ if params.scan_oom:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ params=params,
+ )
+
+ logging.info("Training started")
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ logging.info(f"Start epoch {epoch}")
+
+ if params.lr_hours == 0:
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ params.cur_epoch = epoch
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
+ break
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint(
+ filename=filename,
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if rank == 0:
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def main():
+ parser = get_parser()
+ TtsDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+if __name__ == "__main__":
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ main()
diff --git a/zipvoice/bin/train_zipvoice_dialog.py b/zipvoice/bin/train_zipvoice_dialog.py
new file mode 100644
index 0000000000000000000000000000000000000000..2515b8b65daef95fc1fc580793156dfaf42bd88b
--- /dev/null
+++ b/zipvoice/bin/train_zipvoice_dialog.py
@@ -0,0 +1,983 @@
+#!/usr/bin/env python3
+# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script trains a ZipVoice-Dialog model.
+
+Usage:
+
+python3 -m zipvoice.bin.train_zipvoice_dialog \
+ --world-size 8 \
+ --use-fp16 1 \
+ --base-lr 0.0001 \
+ --max-duration 500 \
+ --checkpoint download/zipvoice/model.pt \
+ --model-config conf/zipvoice_base.json \
+ --token-file "data/tokens_dialog.txt" \
+ --dataset opendialog \
+ --manifest-dir data/fbank \
+ --exp-dir exp/zipvoice_dialog
+"""
+
+import argparse
+import copy
+import json
+import logging
+import os
+from functools import partial
+from pathlib import Path
+from shutil import copyfile
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from lhotse.cut import Cut, CutSet
+from lhotse.utils import fix_random_seed
+from torch import Tensor
+from torch.amp.grad_scaler import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
+from torch.utils.tensorboard import SummaryWriter
+
+import zipvoice.utils.diagnostics as diagnostics
+from zipvoice.bin.train_zipvoice import (
+ display_and_save_batch,
+ get_params,
+ tokenize_text,
+)
+from zipvoice.dataset.datamodule import TtsDataModule
+from zipvoice.models.zipvoice_dialog import ZipVoiceDialog
+from zipvoice.tokenizer.tokenizer import DialogTokenizer
+from zipvoice.utils.checkpoint import (
+ load_checkpoint,
+ load_checkpoint_extend_vocab_size,
+ remove_checkpoints,
+ resume_checkpoint,
+ save_checkpoint,
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from zipvoice.utils.common import (
+ AttributeDict,
+ MetricsTracker,
+ cleanup_dist,
+ create_grad_scaler,
+ get_adjusted_batch_count,
+ get_parameter_groups_with_lrs,
+ prepare_input,
+ set_batch_count,
+ setup_dist,
+ setup_logger,
+ str2bool,
+ torch_autocast,
+)
+from zipvoice.utils.hooks import register_inf_check_hooks
+from zipvoice.utils.lr_scheduler import FixedLRScheduler, LRScheduler
+from zipvoice.utils.optim import ScaledAdam
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler]
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12356,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=8,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--num-iters",
+ type=int,
+ default=60000,
+ help="Number of iter to train, will ignore num_epochs if > 0.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--checkpoint",
+ type=str,
+ required=True,
+ help="""Checkpoints of pre-trained models, either a ZipVoice model or a
+ ZipVoice-Dialog model.
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="exp/zipvoice_dialog",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.0001, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=50,
+ help="""Reference batch duration for purposes of adjusting batch counts for"
+ setting various schedules inside the model".
+ """,
+ )
+
+ parser.add_argument(
+ "--finetune",
+ type=str2bool,
+ default=False,
+ help="Whether to fine-tune from our pre-traied ZipVoice-Dialog model."
+ "False means to fine-tune from a pre-trained ZipVoice model.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--scan-oom",
+ type=str2bool,
+ default=False,
+ help="Scan pessimistic batches to see whether they cause OOMs.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=5000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=True,
+ help="Whether to use half precision training.",
+ )
+
+ parser.add_argument(
+ "--feat-scale",
+ type=float,
+ default=0.1,
+ help="The scale factor of fbank feature",
+ )
+
+ parser.add_argument(
+ "--condition-drop-ratio",
+ type=float,
+ default=0.2,
+ help="The drop rate of text condition during training.",
+ )
+
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ default="opendialog",
+ choices=["opendialog", "custom"],
+ help="The used training dataset",
+ )
+
+ parser.add_argument(
+ "--train-manifest",
+ type=str,
+ help="Path of the training manifest",
+ )
+
+ parser.add_argument(
+ "--dev-manifest",
+ type=str,
+ help="Path of the validation manifest",
+ )
+
+ parser.add_argument(
+ "--min-len",
+ type=float,
+ default=1.0,
+ help="The minimum audio length used for training",
+ )
+
+ parser.add_argument(
+ "--max-len",
+ type=float,
+ default=30.0,
+ help="The maximum audio length used for training",
+ )
+
+ parser.add_argument(
+ "--model-config",
+ type=str,
+ default="zipvoice_base.json",
+ help="The model configuration file.",
+ )
+
+ parser.add_argument(
+ "--token-file",
+ type=str,
+ default="data/tokens_dialog.txt",
+ help="The file that contains information that maps tokens to ids,"
+ "which is a text file with '{token}\t{token_id}' per line.",
+ )
+
+ return parser
+
+
+def compute_fbank_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ features: Tensor,
+ features_lens: Tensor,
+ tokens: List[List[int]],
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training.
+ features:
+ The target acoustic feature.
+ features_lens:
+ The number of frames of each utterance.
+ tokens:
+ Input tokens that representing the transcripts.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ batch_size, num_frames, _ = features.shape
+
+ features = torch.nn.functional.pad(
+ features, (0, 0, 0, num_frames - features.size(1))
+ ) # (B, T, F)
+ noise = torch.randn_like(features) # (B, T, F)
+
+ # Sampling t from uniform distribution
+ if is_training:
+ t = torch.rand(batch_size, 1, 1, device=device)
+ else:
+ t = (
+ (torch.arange(batch_size, device=device) / batch_size)
+ .unsqueeze(1)
+ .unsqueeze(2)
+ )
+ with torch.set_grad_enabled(is_training):
+
+ loss = model(
+ tokens=tokens,
+ features=features,
+ features_lens=features_lens,
+ noise=noise,
+ t=t,
+ condition_drop_ratio=params.condition_drop_ratio,
+ )
+
+ assert loss.requires_grad == is_training
+ info = MetricsTracker()
+ num_frames = features_lens.sum().item()
+ info["frames"] = num_frames
+ info["loss"] = loss.detach().cpu().item() * num_frames
+
+ return loss, info
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: Optimizer,
+ scheduler: LRSchedulerType,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer.
+ scheduler:
+ The learning rate scheduler, we call step() every epoch.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mix precision training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ # used to track the stats over iterations in one epoch
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params) + 100000)
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.valid_interval == 0
+ and not params.print_diagnostics
+ ):
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(
+ f"Epoch {params.cur_epoch}, global_batch_idx: {params.batch_idx_train},"
+ f" validation: {valid_info}"
+ )
+ logging.info(
+ f"Maximum memory allocated so far is "
+ f"{torch.cuda.max_memory_allocated() // 1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ params.batch_idx_train += 1
+
+ batch_size = len(batch["text"])
+
+ tokens, features, features_lens = prepare_input(
+ params=params,
+ batch=batch,
+ device=device,
+ return_tokens=True,
+ return_feature=True,
+ )
+
+ try:
+ with torch_autocast(dtype=torch.float16, enabled=params.use_fp16):
+ loss, loss_info = compute_fbank_loss(
+ params=params,
+ model=model,
+ features=features,
+ features_lens=features_lens,
+ tokens=tokens,
+ is_training=True,
+ )
+
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ scaler.scale(loss).backward()
+
+ scheduler.step_batch(params.batch_idx_train)
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except Exception as e:
+ logging.info(f"Caught exception : {e}.")
+ save_bad_model()
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+ if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
+ break
+ if params.batch_idx_train % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have
+ # different behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 1024.0 or (
+ cur_grad_scale < 4096.0 and params.batch_idx_train % 400 == 0
+ ):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if params.batch_idx_train % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, batch {batch_idx}, "
+ f"global_batch_idx: {params.batch_idx_train}, "
+ f"batch size: {batch_size}, "
+ f"loss[{loss_info}], tot_loss[{tot_loss}], "
+ f"cur_lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale",
+ cur_grad_scale,
+ params.batch_idx_train,
+ )
+
+ loss_value = tot_loss["loss"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+
+ model.eval()
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ # used to summary the stats over iterations
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ tokens, features, features_lens = prepare_input(
+ params=params,
+ batch=batch,
+ device=device,
+ return_tokens=True,
+ return_feature=True,
+ )
+
+ loss, loss_info = compute_fbank_loss(
+ params=params,
+ model=model,
+ features=features,
+ features_lens=features_lens,
+ tokens=tokens,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ tokens, features, features_lens = prepare_input(
+ params=params,
+ batch=batch,
+ device=device,
+ return_tokens=True,
+ return_feature=True,
+ )
+ try:
+ with torch_autocast(dtype=torch.float16, enabled=params.use_fp16):
+
+ loss, loss_info = compute_fbank_loss(
+ params=params,
+ model=model,
+ features=features,
+ features_lens=features_lens,
+ tokens=tokens,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is "
+ f"{torch.cuda.max_memory_allocated() // 1000000}MB"
+ )
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+ params.valid_interval = params.save_every_n
+ # Set epoch to a large number to ignore it.
+ if params.num_iters > 0:
+ params.num_epochs = 1000000
+ with open(params.model_config, "r") as f:
+ model_config = json.load(f)
+ params.update(model_config["model"])
+ params.update(model_config["feature"])
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ os.makedirs(f"{params.exp_dir}", exist_ok=True)
+ copyfile(src=params.model_config, dst=f"{params.exp_dir}/model.json")
+ copyfile(src=params.token_file, dst=f"{params.exp_dir}/tokens.txt")
+ setup_logger(f"{params.exp_dir}/log/log-train")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ if torch.cuda.is_available():
+ params.device = torch.device("cuda", rank)
+ else:
+ params.device = torch.device("cpu")
+ logging.info(f"Device: {params.device}")
+
+ tokenizer = DialogTokenizer(token_file=params.token_file)
+ tokenizer_config = {
+ "vocab_size": tokenizer.vocab_size,
+ "pad_id": tokenizer.pad_id,
+ "spk_a_id": tokenizer.spk_a_id,
+ "spk_b_id": tokenizer.spk_b_id,
+ }
+ params.update(tokenizer_config)
+
+ logging.info(params)
+
+ logging.info("About to create model")
+
+ model = ZipVoiceDialog(
+ **model_config["model"],
+ **tokenizer_config,
+ )
+
+ assert params.checkpoint is not None, (
+ "require a pre-trained checkpoint, as training from random initialization "
+ "leads to uninteligible dialogue speech"
+ )
+ logging.info(f"Loading pre-trained model from {params.checkpoint}")
+
+ if params.finetune:
+ # load a pre-trained ZipVoice-Dialog model
+ _ = load_checkpoint(filename=params.checkpoint, model=model, strict=True)
+ else:
+ # load a pre-trained ZipVoice model, extend the vocab size for additional tokens
+ _ = load_checkpoint_extend_vocab_size(
+ filename=params.checkpoint,
+ extend_size=28,
+ model=model,
+ strict=True,
+ )
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of parameters : {num_param}")
+
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ if params.start_epoch > 1:
+ checkpoints = resume_checkpoint(params=params, model=model, model_avg=model_avg)
+
+ model = model.to(params.device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(
+ model,
+ lr=params.base_lr,
+ include_names=True,
+ ),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = FixedLRScheduler(optimizer)
+
+ scaler = create_grad_scaler(enabled=params.use_fp16)
+
+ if params.start_epoch > 1 and checkpoints is not None:
+ # load state_dict for optimizers
+ if "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ # load state_dict for schedulers
+ if "scheduler" in checkpoints:
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ def remove_short_and_long_utt(c: Cut, min_len: float, max_len: float):
+ if c.duration < min_len or c.duration > max_len:
+ return False
+ return True
+
+ _remove_short_and_long_utt = partial(
+ remove_short_and_long_utt, min_len=params.min_len, max_len=params.max_len
+ )
+
+ datamodule = TtsDataModule(args)
+ if params.dataset == "opendialog":
+ train_opendialog_en_cuts = datamodule.train_opendialog_en_cuts()
+ train_opendialog_zh_cuts = datamodule.train_opendialog_zh_cuts().repeat(2)
+
+ train_cuts = CutSet.mux(
+ train_opendialog_en_cuts,
+ train_opendialog_zh_cuts,
+ weights=[
+ len(train_opendialog_en_cuts),
+ len(train_opendialog_zh_cuts),
+ ],
+ )
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
+
+ dev_cuts = CutSet.mux(
+ datamodule.dev_opendialog_en_cuts(),
+ datamodule.dev_opendialog_zh_cuts(),
+ )
+ else:
+ assert params.dataset == "custom"
+ train_cuts = datamodule.train_custom_cuts(params.train_manifest)
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
+ dev_cuts = datamodule.dev_custom_cuts(params.dev_manifest)
+ # To avoid OOM issues due to too long dev cuts
+ dev_cuts = dev_cuts.filter(_remove_short_and_long_utt)
+
+ if not hasattr(train_cuts[0].supervisions[0], "tokens") or not hasattr(
+ dev_cuts[0].supervisions[0], "tokens"
+ ):
+ logging.warning(
+ "Tokens are not prepared, will tokenize on-the-fly, "
+ "which can slow down training significantly."
+ )
+ _tokenize_text = partial(tokenize_text, tokenizer=tokenizer)
+ train_cuts = train_cuts.map(_tokenize_text)
+ dev_cuts = dev_cuts.map(_tokenize_text)
+
+ train_dl = datamodule.train_dataloaders(train_cuts)
+
+ valid_dl = datamodule.dev_dataloaders(dev_cuts)
+
+ if params.scan_oom:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ params=params,
+ )
+
+ logging.info("Training started")
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ logging.info(f"Start epoch {epoch}")
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ params.cur_epoch = epoch
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
+ break
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint(
+ filename=filename,
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if rank == 0:
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def main():
+ parser = get_parser()
+ TtsDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+if __name__ == "__main__":
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ main()
diff --git a/zipvoice/bin/train_zipvoice_dialog_stereo.py b/zipvoice/bin/train_zipvoice_dialog_stereo.py
new file mode 100644
index 0000000000000000000000000000000000000000..5764514279befd2a07081397e731121498e5e7ac
--- /dev/null
+++ b/zipvoice/bin/train_zipvoice_dialog_stereo.py
@@ -0,0 +1,966 @@
+#!/usr/bin/env python3
+# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script trains a ZipVoice-Dialog model.
+
+Usage:
+
+python3 -m zipvoice.bin.train_zipvoice_dialog_stereo \
+ --world-size 8 \
+ --use-fp16 1 \
+ --base-lr 0.002 \
+ --max-duration 500 \
+ --model-config conf/zipvoice_base.json \
+ --token-file "data/tokens_dialog.txt" \
+ --manifest-dir data/fbank \
+ --exp-dir exp/zipvoice_dialog_stereo
+"""
+
+import argparse
+import copy
+import json
+import logging
+import os
+from functools import partial
+from pathlib import Path
+from shutil import copyfile
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from lhotse.cut import Cut
+from lhotse.utils import fix_random_seed
+from torch import Tensor
+from torch.amp.grad_scaler import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
+from torch.utils.tensorboard import SummaryWriter
+
+import zipvoice.utils.diagnostics as diagnostics
+from zipvoice.bin.train_zipvoice import (
+ display_and_save_batch,
+ get_params,
+ tokenize_text,
+)
+from zipvoice.dataset.datamodule import TtsDataModule
+from zipvoice.models.zipvoice_dialog import ZipVoiceDialogStereo
+from zipvoice.tokenizer.tokenizer import DialogTokenizer
+from zipvoice.utils.checkpoint import (
+ load_checkpoint,
+ load_checkpoint_copy_proj_three_channel_alter,
+ remove_checkpoints,
+ resume_checkpoint,
+ save_checkpoint,
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from zipvoice.utils.common import (
+ AttributeDict,
+ MetricsTracker,
+ cleanup_dist,
+ create_grad_scaler,
+ get_adjusted_batch_count,
+ get_parameter_groups_with_lrs,
+ prepare_input,
+ set_batch_count,
+ setup_dist,
+ setup_logger,
+ str2bool,
+ torch_autocast,
+)
+from zipvoice.utils.hooks import register_inf_check_hooks
+from zipvoice.utils.lr_scheduler import FixedLRScheduler, LRScheduler
+from zipvoice.utils.optim import ScaledAdam
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler]
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12356,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=8,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--num-iters",
+ type=int,
+ default=25000,
+ help="Number of iter to train, will ignore num_epochs if > 0.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--checkpoint",
+ type=str,
+ required=True,
+ help="""Checkpoints of pre-trained models, either a ZipVoice model or a
+ ZipVoice-Dialog model.
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="exp/zipvoice_dialog",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.002, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=50,
+ help="""Reference batch duration for purposes of adjusting batch counts for"
+ setting various schedules inside the model".
+ """,
+ )
+
+ parser.add_argument(
+ "--finetune",
+ type=str2bool,
+ default=False,
+ help="Whether to fine-tune from our pre-traied ZipVoice-Dialog model."
+ "False means to fine-tune from a pre-trained ZipVoice model.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--scan-oom",
+ type=str2bool,
+ default=False,
+ help="Scan pessimistic batches to see whether they cause OOMs.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=5000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=True,
+ help="Whether to use half precision training.",
+ )
+
+ parser.add_argument(
+ "--feat-scale",
+ type=float,
+ default=0.1,
+ help="The scale factor of fbank feature",
+ )
+
+ parser.add_argument(
+ "--condition-drop-ratio",
+ type=float,
+ default=0.2,
+ help="The drop rate of text condition during training.",
+ )
+
+ parser.add_argument(
+ "--train-manifest",
+ type=str,
+ help="Path of the training manifest",
+ )
+
+ parser.add_argument(
+ "--dev-manifest",
+ type=str,
+ help="Path of the validation manifest",
+ )
+
+ parser.add_argument(
+ "--min-len",
+ type=float,
+ default=1.0,
+ help="The minimum audio length used for training",
+ )
+
+ parser.add_argument(
+ "--max-len",
+ type=float,
+ default=60.0,
+ help="The maximum audio length used for training",
+ )
+
+ parser.add_argument(
+ "--model-config",
+ type=str,
+ default="zipvoice_base.json",
+ help="The model configuration file.",
+ )
+
+ parser.add_argument(
+ "--token-file",
+ type=str,
+ default="data/tokens_dialog.txt",
+ help="The file that contains information that maps tokens to ids,"
+ "which is a text file with '{token}\t{token_id}' per line.",
+ )
+
+ return parser
+
+
+def compute_fbank_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ features: Tensor,
+ features_lens: Tensor,
+ tokens: List[List[int]],
+ is_training: bool,
+ use_two_channel: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training.
+ features:
+ The target acoustic feature.
+ features_lens:
+ The number of frames of each utterance.
+ tokens:
+ Input tokens that representing the transcripts.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ use_two_channel:
+ True for using two channel features, False for using one channel features.
+ """
+
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ batch_size, num_frames, _ = features.shape
+
+ features = torch.nn.functional.pad(
+ features, (0, 0, 0, num_frames - features.size(1))
+ ) # (B, T, F)
+ assert (
+ features.size(2) == 3 * params.feat_dim
+ ), "we assume three channel features, the last channel is the mixed-channel feature"
+ if use_two_channel:
+ features = features[:, :, : params.feat_dim * 2]
+ else:
+ features = features[:, :, params.feat_dim * 2 :]
+
+ noise = torch.randn_like(features) # (B, T, F)
+
+ # Sampling t from uniform distribution
+ if is_training:
+ t = torch.rand(batch_size, 1, 1, device=device)
+ else:
+ t = (
+ (torch.arange(batch_size, device=device) / batch_size)
+ .unsqueeze(1)
+ .unsqueeze(2)
+ )
+ with torch.set_grad_enabled(is_training):
+
+ loss = model(
+ tokens=tokens,
+ features=features,
+ features_lens=features_lens,
+ noise=noise,
+ t=t,
+ condition_drop_ratio=params.condition_drop_ratio,
+ se_weight=1 if use_two_channel else 0,
+ )
+
+ assert loss.requires_grad == is_training
+ info = MetricsTracker()
+ num_frames = features_lens.sum().item()
+ info["frames"] = num_frames
+ info["loss"] = loss.detach().cpu().item() * num_frames
+
+ return loss, info
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: Optimizer,
+ scheduler: LRSchedulerType,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer.
+ scheduler:
+ The learning rate scheduler, we call step() every epoch.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mix precision training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ # used to track the stats over iterations in one epoch
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params) + 100000)
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.valid_interval == 0
+ and not params.print_diagnostics
+ ):
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(
+ f"Epoch {params.cur_epoch}, global_batch_idx: {params.batch_idx_train},"
+ f" validation: {valid_info}"
+ )
+ logging.info(
+ f"Maximum memory allocated so far is "
+ f"{torch.cuda.max_memory_allocated() // 1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ params.batch_idx_train += 1
+
+ batch_size = len(batch["text"])
+
+ tokens, features, features_lens = prepare_input(
+ params=params,
+ batch=batch,
+ device=device,
+ return_tokens=True,
+ return_feature=True,
+ )
+
+ try:
+ with torch_autocast(dtype=torch.float16, enabled=params.use_fp16):
+ loss, loss_info = compute_fbank_loss(
+ params=params,
+ model=model,
+ features=features,
+ features_lens=features_lens,
+ tokens=tokens,
+ is_training=True,
+ use_two_channel=(batch_idx % 2 == 1),
+ )
+
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ scaler.scale(loss).backward()
+
+ scheduler.step_batch(params.batch_idx_train)
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except Exception as e:
+ logging.info(f"Caught exception : {e}.")
+ save_bad_model()
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+ if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
+ break
+ if params.batch_idx_train % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have
+ # different behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 1024.0 or (
+ cur_grad_scale < 4096.0 and params.batch_idx_train % 400 == 0
+ ):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if params.batch_idx_train % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, batch {batch_idx}, "
+ f"global_batch_idx: {params.batch_idx_train}, "
+ f"batch size: {batch_size}, "
+ f"loss[{loss_info}], tot_loss[{tot_loss}], "
+ f"cur_lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale",
+ cur_grad_scale,
+ params.batch_idx_train,
+ )
+
+ loss_value = tot_loss["loss"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+
+ model.eval()
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ # used to summary the stats over iterations
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ tokens, features, features_lens = prepare_input(
+ params=params,
+ batch=batch,
+ device=device,
+ return_tokens=True,
+ return_feature=True,
+ )
+
+ loss, loss_info = compute_fbank_loss(
+ params=params,
+ model=model,
+ features=features,
+ features_lens=features_lens,
+ tokens=tokens,
+ is_training=False,
+ use_two_channel=True,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ tokens, features, features_lens = prepare_input(
+ params=params,
+ batch=batch,
+ device=device,
+ return_tokens=True,
+ return_feature=True,
+ )
+ try:
+ with torch_autocast(dtype=torch.float16, enabled=params.use_fp16):
+
+ loss, loss_info = compute_fbank_loss(
+ params=params,
+ model=model,
+ features=features,
+ features_lens=features_lens,
+ tokens=tokens,
+ is_training=True,
+ use_two_channel=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is "
+ f"{torch.cuda.max_memory_allocated() // 1000000}MB"
+ )
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+ params.valid_interval = params.save_every_n
+ # Set epoch to a large number to ignore it.
+ if params.num_iters > 0:
+ params.num_epochs = 1000000
+ with open(params.model_config, "r") as f:
+ model_config = json.load(f)
+ params.update(model_config["model"])
+ params.update(model_config["feature"])
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ os.makedirs(f"{params.exp_dir}", exist_ok=True)
+ copyfile(src=params.model_config, dst=f"{params.exp_dir}/model.json")
+ copyfile(src=params.token_file, dst=f"{params.exp_dir}/tokens.txt")
+ setup_logger(f"{params.exp_dir}/log/log-train")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ if torch.cuda.is_available():
+ params.device = torch.device("cuda", rank)
+ else:
+ params.device = torch.device("cpu")
+ logging.info(f"Device: {params.device}")
+
+ tokenizer = DialogTokenizer(token_file=params.token_file)
+ tokenizer_config = {
+ "vocab_size": tokenizer.vocab_size,
+ "pad_id": tokenizer.pad_id,
+ "spk_a_id": tokenizer.spk_a_id,
+ "spk_b_id": tokenizer.spk_b_id,
+ }
+ params.update(tokenizer_config)
+
+ logging.info(params)
+
+ logging.info("About to create model")
+
+ model = ZipVoiceDialogStereo(
+ **model_config["model"],
+ **tokenizer_config,
+ )
+
+ assert params.checkpoint is not None
+ logging.info(f"Loading pre-trained model from {params.checkpoint}")
+
+ if params.finetune:
+ # load a pre-trained ZipVoice-Dialog-Stereo model
+ _ = load_checkpoint(filename=params.checkpoint, model=model, strict=True)
+ else:
+ # load a pre-trained ZipVoice-Dialog model, duplicate the proj layers
+ load_checkpoint_copy_proj_three_channel_alter(
+ filename=params.checkpoint,
+ in_proj_key="fm_decoder.in_proj",
+ out_proj_key="fm_decoder.out_proj",
+ dim=params.feat_dim,
+ model=model,
+ )
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of parameters : {num_param}")
+
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ if params.start_epoch > 1:
+ checkpoints = resume_checkpoint(params=params, model=model, model_avg=model_avg)
+
+ model = model.to(params.device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(
+ model,
+ lr=params.base_lr,
+ include_names=True,
+ ),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = FixedLRScheduler(optimizer)
+
+ scaler = create_grad_scaler(enabled=params.use_fp16)
+
+ if params.start_epoch > 1 and checkpoints is not None:
+ # load state_dict for optimizers
+ if "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ # load state_dict for schedulers
+ if "scheduler" in checkpoints:
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ def remove_short_and_long_utt(c: Cut, min_len: float, max_len: float):
+ if c.duration < min_len or c.duration > max_len:
+ return False
+ return True
+
+ _remove_short_and_long_utt = partial(
+ remove_short_and_long_utt, min_len=params.min_len, max_len=params.max_len
+ )
+
+ datamodule = TtsDataModule(args)
+ train_cuts = datamodule.train_custom_cuts(params.train_manifest)
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
+ dev_cuts = datamodule.dev_custom_cuts(params.dev_manifest)
+ # To avoid OOM issues due to too long dev cuts
+ dev_cuts = dev_cuts.filter(_remove_short_and_long_utt)
+
+ if not hasattr(train_cuts[0].supervisions[0], "tokens") or not hasattr(
+ dev_cuts[0].supervisions[0], "tokens"
+ ):
+ logging.warning(
+ "Tokens are not prepared, will tokenize on-the-fly, "
+ "which can slow down training significantly."
+ )
+ _tokenize_text = partial(tokenize_text, tokenizer=tokenizer)
+ train_cuts = train_cuts.map(_tokenize_text)
+ dev_cuts = dev_cuts.map(_tokenize_text)
+
+ train_dl = datamodule.train_dataloaders(train_cuts)
+
+ valid_dl = datamodule.dev_dataloaders(dev_cuts)
+
+ if params.scan_oom:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ params=params,
+ )
+
+ logging.info("Training started")
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ logging.info(f"Start epoch {epoch}")
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ params.cur_epoch = epoch
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
+ break
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint(
+ filename=filename,
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if rank == 0:
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def main():
+ parser = get_parser()
+ TtsDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+if __name__ == "__main__":
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ main()
diff --git a/zipvoice/bin/train_zipvoice_distill.py b/zipvoice/bin/train_zipvoice_distill.py
new file mode 100644
index 0000000000000000000000000000000000000000..9701398cbb57c436a094bcc08a7dd3ed6488476e
--- /dev/null
+++ b/zipvoice/bin/train_zipvoice_distill.py
@@ -0,0 +1,1161 @@
+#!/usr/bin/env python3
+# Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script trains a ZipVoice-Distill model starting from a ZipVoice model.
+It has two distillation stages.
+
+Usage:
+
+(1) The first distillation stage with a fixed ZipVoice model as the teacher.
+
+python3 -m zipvoice.bin.train_zipvoice_distill \
+ --world-size 8 \
+ --use-fp16 1 \
+ --num-iters 60000 \
+ --max-duration 500 \
+ --base-lr 0.0005 \
+ --tokenizer emilia \
+ --token-file data/tokens_emilia.txt \
+ --dataset emilia \
+ --manifest-dir data/fbank \
+ --teacher-model exp/zipvoice/epoch-11-avg-4.pt \
+ --distill-stage first \
+ --exp-dir exp/zipvoice_distill_1stage
+
+(2) The second distillation stage with a EMA model as the teacher.
+python3 -m zipvoice.bin.train_zipvoice_distill \
+ --world-size 8 \
+ --use-fp16 1 \
+ --num-iters 2000 \
+ --save-every-n 1000 \
+ --max-duration 500 \
+ --base-lr 0.0001 \
+ --model-config conf/zipvoice_base.json \
+ --tokenizer emilia \
+ --token-file data/tokens_emilia.txt \
+ --dataset emilia \
+ --manifest-dir data/fbank \
+ --teacher-model exp/zipvoice_distill_1stage/iter-60000-avg-7.pt \
+ --distill-stage second \
+ --exp-dir exp/zipvoice_distill
+"""
+
+import argparse
+import copy
+import json
+import logging
+import os
+import random
+from functools import partial
+from pathlib import Path
+from shutil import copyfile
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from lhotse.cut import Cut, CutSet
+from lhotse.utils import fix_random_seed
+from torch import Tensor
+from torch.amp.grad_scaler import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
+from torch.utils.tensorboard import SummaryWriter
+
+import zipvoice.utils.diagnostics as diagnostics
+from zipvoice.bin.train_zipvoice import (
+ display_and_save_batch,
+ get_params,
+ tokenize_text,
+)
+from zipvoice.dataset.datamodule import TtsDataModule
+from zipvoice.models.zipvoice import ZipVoice
+from zipvoice.models.zipvoice_distill import ZipVoiceDistill
+from zipvoice.tokenizer.tokenizer import (
+ EmiliaTokenizer,
+ EspeakTokenizer,
+ LibriTTSTokenizer,
+ SimpleTokenizer,
+)
+from zipvoice.utils.checkpoint import (
+ load_checkpoint,
+ remove_checkpoints,
+ resume_checkpoint,
+ save_checkpoint,
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from zipvoice.utils.common import (
+ AttributeDict,
+ MetricsTracker,
+ cleanup_dist,
+ condition_time_mask,
+ create_grad_scaler,
+ get_adjusted_batch_count,
+ get_parameter_groups_with_lrs,
+ make_pad_mask,
+ prepare_input,
+ set_batch_count,
+ setup_dist,
+ setup_logger,
+ str2bool,
+ torch_autocast,
+)
+from zipvoice.utils.hooks import register_inf_check_hooks
+from zipvoice.utils.lr_scheduler import FixedLRScheduler, LRScheduler
+from zipvoice.utils.optim import ScaledAdam
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler]
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12356,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=1,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--num-iters",
+ type=int,
+ default=0,
+ help="Number of iter to train, will ignore num_epochs if > 0.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--teacher-model",
+ type=str,
+ help="""Checkpoints of pre-trained teacher model""",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="exp/zipvoice_distill",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.001, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=50,
+ help="Reference batch duration for purposes of adjusting batch counts for "
+ "setting various schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--scan-oom",
+ type=str2bool,
+ default=False,
+ help="Scan pessimistic batches to see whether they cause OOMs.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=1000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=True,
+ help="Whether to use half precision training.",
+ )
+
+ parser.add_argument(
+ "--feat-scale",
+ type=float,
+ default=0.1,
+ help="The scale factor of fbank feature",
+ )
+
+ parser.add_argument(
+ "--ema-decay",
+ type=float,
+ default=0.9999,
+ help="The EMA decay factor of target model in distillation.",
+ )
+ parser.add_argument(
+ "--distill-stage",
+ type=str,
+ choices=["first", "second"],
+ help="The stage of distillation.",
+ )
+
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ default="emilia",
+ choices=["emilia", "libritts", "custom"],
+ help="The used training dataset",
+ )
+
+ parser.add_argument(
+ "--train-manifest",
+ type=str,
+ help="Path of the training manifest",
+ )
+
+ parser.add_argument(
+ "--dev-manifest",
+ type=str,
+ help="Path of the validation manifest",
+ )
+
+ parser.add_argument(
+ "--min-len",
+ type=float,
+ default=1.0,
+ help="The minimum audio length used for training",
+ )
+
+ parser.add_argument(
+ "--max-len",
+ type=float,
+ default=30.0,
+ help="The maximum audio length used for training",
+ )
+
+ parser.add_argument(
+ "--model-config",
+ type=str,
+ default="conf/zipvoice_base.json",
+ help="The model configuration file.",
+ )
+
+ parser.add_argument(
+ "--tokenizer",
+ type=str,
+ default="emilia",
+ choices=["emilia", "libritts", "espeak", "simple"],
+ help="Tokenizer type.",
+ )
+
+ parser.add_argument(
+ "--lang",
+ type=str,
+ default="en-us",
+ help="Language identifier, used when tokenizer type is espeak. see"
+ "https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
+ )
+
+ parser.add_argument(
+ "--token-file",
+ type=str,
+ default="data/tokens_emilia.txt",
+ help="The file that contains information that maps tokens to ids,"
+ "which is a text file with '{token}\t{token_id}' per line.",
+ )
+
+ return parser
+
+
+def ema(new_model, ema_model, decay):
+ if isinstance(new_model, DDP):
+ new_model = new_model.module
+ if isinstance(ema_model, DDP):
+ ema_model = ema_model.module
+ new_model_dict = new_model.state_dict()
+ ema_model_dict = ema_model.state_dict()
+ for key in new_model_dict.keys():
+ ema_model_dict[key].data.copy_(
+ ema_model_dict[key].data * decay + new_model_dict[key].data * (1 - decay)
+ )
+
+
+def compute_fbank_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ teacher_model: Union[nn.Module, DDP],
+ features: Tensor,
+ features_lens: Tensor,
+ tokens: List[List[int]],
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training.
+ teacher_model:
+ The teacher model for distillation.
+ features:
+ The target acoustic feature.
+ features_lens:
+ The number of frames of each utterance.
+ tokens:
+ Input tokens that representing the transcripts.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ batch_size, num_frames, _ = features.shape
+
+ features = torch.nn.functional.pad(
+ features, (0, 0, 0, num_frames - features.size(1))
+ ) # (B, T, F)
+ noise = torch.randn_like(features) # (B, T, F)
+
+ # Sampling t and guidance_scale from uniform distribution
+
+ t_value = random.random()
+ t = torch.ones(batch_size, 1, 1, device=device) * t_value
+ if params.distill_stage == "first":
+ guidance_scale = torch.rand(batch_size, 1, 1, device=device) * 2
+ else:
+ guidance_scale = torch.rand(batch_size, 1, 1, device=device) * 2 + 1
+ xt = features * t + noise * (1 - t)
+ t_delta_fix = random.uniform(0.0, min(0.3, 1 - t_value))
+ t_delta_ema = random.uniform(0.0, min(0.3, 1 - t_value - t_delta_fix))
+ t_dest = t_value + t_delta_fix + t_delta_ema
+
+ with torch.no_grad():
+ speech_condition_mask = condition_time_mask(
+ features_lens=features_lens,
+ mask_percent=(0.7, 1.0),
+ max_len=features.size(1),
+ )
+
+ if params.distill_stage == "first":
+ teacher_x_t_mid, _ = teacher_model.sample_intermediate(
+ tokens=tokens,
+ features=features,
+ features_lens=features_lens,
+ noise=xt,
+ speech_condition_mask=speech_condition_mask,
+ t_start=t_value,
+ t_end=t_value + t_delta_fix,
+ num_step=1,
+ guidance_scale=guidance_scale,
+ )
+
+ target_x1, _ = teacher_model.sample_intermediate(
+ tokens=tokens,
+ features=features,
+ features_lens=features_lens,
+ noise=teacher_x_t_mid,
+ speech_condition_mask=speech_condition_mask,
+ t_start=t_value + t_delta_fix,
+ t_end=t_dest,
+ num_step=1,
+ guidance_scale=guidance_scale,
+ )
+ else:
+ teacher_x_t_mid, _ = teacher_model(
+ tokens=tokens,
+ features=features,
+ features_lens=features_lens,
+ noise=xt,
+ speech_condition_mask=speech_condition_mask,
+ t_start=t_value,
+ t_end=t_value + t_delta_fix,
+ num_step=1,
+ guidance_scale=guidance_scale,
+ )
+
+ target_x1, _ = teacher_model(
+ tokens=tokens,
+ features=features,
+ features_lens=features_lens,
+ noise=teacher_x_t_mid,
+ speech_condition_mask=speech_condition_mask,
+ t_start=t_value + t_delta_fix,
+ t_end=t_dest,
+ num_step=1,
+ guidance_scale=guidance_scale,
+ )
+
+ with torch.set_grad_enabled(is_training):
+
+ pred_x1, _ = model(
+ tokens=tokens,
+ features=features,
+ features_lens=features_lens,
+ noise=xt,
+ speech_condition_mask=speech_condition_mask,
+ t_start=t_value,
+ t_end=t_dest,
+ num_step=1,
+ guidance_scale=guidance_scale,
+ )
+ pred_v = (pred_x1 - xt) / (t_dest - t)
+
+ padding_mask = make_pad_mask(features_lens, max_len=num_frames) # (B, T)
+ loss_mask = speech_condition_mask & (~padding_mask)
+
+ target_v = (target_x1 - xt) / (t_dest - t)
+ loss = torch.mean((pred_v[loss_mask] - target_v[loss_mask]) ** 2)
+
+ ut = features - noise # (B, T, F)
+
+ ref_loss = torch.mean((pred_v[loss_mask] - ut[loss_mask]) ** 2)
+
+ assert loss.requires_grad == is_training
+ info = MetricsTracker()
+ num_frames = features_lens.sum().item()
+ info["frames"] = num_frames
+ info["loss"] = loss.detach().cpu().item() * num_frames
+ info["ref_loss"] = ref_loss.detach().cpu().item() * num_frames
+ return loss, info
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ teacher_model: Union[nn.Module, DDP],
+ optimizer: Optimizer,
+ scheduler: LRSchedulerType,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ teacher_model:
+ The model for distillation.
+ Used to convert text to tokens.
+ optimizer:
+ The optimizer.
+ scheduler:
+ The learning rate scheduler, we call step() every epoch.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mix precision training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ # used to track the stats over iterations in one epoch
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ model_ema=teacher_model,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params) + 100000)
+
+ if (
+ params.batch_idx_train % params.valid_interval == 0
+ and not params.print_diagnostics
+ ):
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ teacher_model=teacher_model,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(
+ f"Epoch {params.cur_epoch}, global_batch_idx: {params.batch_idx_train},"
+ f" validation: {valid_info}"
+ )
+ logging.info(
+ f"Maximum memory allocated so far is "
+ f"{torch.cuda.max_memory_allocated() // 1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ params.batch_idx_train += 1
+
+ batch_size = len(batch["text"])
+
+ tokens, features, features_lens = prepare_input(
+ params=params,
+ batch=batch,
+ device=device,
+ return_tokens=True,
+ return_feature=True,
+ )
+
+ try:
+ with torch_autocast(dtype=torch.float16, enabled=params.use_fp16):
+ loss, loss_info = compute_fbank_loss(
+ params=params,
+ model=model,
+ teacher_model=teacher_model,
+ features=features,
+ features_lens=features_lens,
+ tokens=tokens,
+ is_training=True,
+ )
+
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ scaler.scale(loss).backward()
+
+ scheduler.step_batch(params.batch_idx_train)
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ if params.distill_stage == "second":
+ ema(model, teacher_model, params.ema_decay)
+ except Exception as e:
+ logging.info(f"Caught exception : {e}.")
+ save_bad_model()
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+ if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
+ break
+ if params.batch_idx_train % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have
+ # different behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 1024.0 or (
+ cur_grad_scale < 4096.0 and params.batch_idx_train % 400 == 0
+ ):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if params.batch_idx_train % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, batch {batch_idx}, "
+ f"global_batch_idx: {params.batch_idx_train}, "
+ f"batch size: {batch_size}, "
+ f"loss[{loss_info}], tot_loss[{tot_loss}], "
+ f"cur_lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale",
+ cur_grad_scale,
+ params.batch_idx_train,
+ )
+
+ loss_value = tot_loss["loss"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ teacher_model: Optional[nn.Module],
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+
+ model.eval()
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ # used to summary the stats over iterations
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ tokens, features, features_lens = prepare_input(
+ params=params,
+ batch=batch,
+ device=device,
+ return_tokens=True,
+ return_feature=True,
+ )
+
+ loss, loss_info = compute_fbank_loss(
+ params=params,
+ model=model,
+ teacher_model=teacher_model,
+ features=features,
+ features_lens=features_lens,
+ tokens=tokens,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ teacher_model: nn.Module,
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ tokens, features, features_lens = prepare_input(
+ params=params,
+ batch=batch,
+ device=device,
+ return_tokens=True,
+ return_feature=True,
+ )
+ try:
+ with torch_autocast(dtype=torch.float16, enabled=params.use_fp16):
+
+ loss, loss_info = compute_fbank_loss(
+ params=params,
+ model=model,
+ teacher_model=teacher_model,
+ features=features,
+ features_lens=features_lens,
+ tokens=tokens,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is "
+ f"{torch.cuda.max_memory_allocated() // 1000000}MB"
+ )
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+ params.valid_interval = params.save_every_n
+ # Set epoch to a large number to ignore it.
+ if params.num_iters > 0:
+ params.num_epochs = 1000000
+ with open(params.model_config, "r") as f:
+ model_config = json.load(f)
+ params.update(model_config["model"])
+ params.update(model_config["feature"])
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ os.makedirs(f"{params.exp_dir}", exist_ok=True)
+ copyfile(src=params.model_config, dst=f"{params.exp_dir}/model.json")
+ copyfile(src=params.token_file, dst=f"{params.exp_dir}/tokens.txt")
+ setup_logger(f"{params.exp_dir}/log/log-train")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ if torch.cuda.is_available():
+ params.device = torch.device("cuda", rank)
+ else:
+ params.device = torch.device("cpu")
+ logging.info(f"Device: {params.device}")
+
+ if params.tokenizer == "emilia":
+ tokenizer = EmiliaTokenizer(token_file=params.token_file)
+ elif params.tokenizer == "libritts":
+ tokenizer = LibriTTSTokenizer(token_file=params.token_file)
+ elif params.tokenizer == "espeak":
+ tokenizer = EspeakTokenizer(token_file=params.token_file, lang=params.lang)
+ else:
+ assert params.tokenizer == "simple"
+ tokenizer = SimpleTokenizer(token_file=params.token_file)
+
+ tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
+ params.update(tokenizer_config)
+
+ logging.info(params)
+
+ logging.info("About to create model")
+
+ assert params.teacher_model is not None
+ logging.info(f"Loading pre-trained model from {params.teacher_model}")
+ model = ZipVoiceDistill(
+ **model_config["model"],
+ **tokenizer_config,
+ )
+ _ = load_checkpoint(
+ filename=params.teacher_model,
+ model=model,
+ strict=(params.distill_stage == "second"),
+ )
+
+ if params.distill_stage == "first":
+ teacher_model = ZipVoice(
+ **model_config["model"],
+ **tokenizer_config,
+ )
+ _ = load_checkpoint(
+ filename=params.teacher_model, model=teacher_model, strict=True
+ )
+ else:
+ teacher_model = copy.deepcopy(model)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of parameters : {num_param}")
+
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+ assert params.start_epoch > 0, params.start_epoch
+ if params.start_epoch > 1:
+ logging.info(f"Resuming from epoch {params.start_epoch}")
+ if params.distill_stage == "first":
+ checkpoints = resume_checkpoint(
+ params=params, model=model, model_avg=model_avg
+ )
+ else:
+ checkpoints = resume_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ model_ema=teacher_model,
+ )
+
+ model = model.to(params.device)
+ teacher_model.to(params.device)
+ teacher_model.eval()
+
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ # only update the fm_decoder
+ num_trainable = 0
+ for name, p in model.named_parameters():
+ if "fm_decoder" in name:
+ p.requires_grad = True
+ num_trainable += p.numel()
+ else:
+ p.requires_grad = False
+
+ logging.info(
+ "A total of {} trainable parameters ({:.3f}% of the whole model)".format(
+ num_trainable, num_trainable / num_param * 100
+ )
+ )
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(
+ model,
+ lr=params.base_lr,
+ include_names=True,
+ ),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = FixedLRScheduler(optimizer)
+
+ scaler = create_grad_scaler(enabled=params.use_fp16)
+
+ if params.start_epoch > 1 and checkpoints is not None:
+ # load state_dict for optimizers
+ if "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ # load state_dict for schedulers
+ if "scheduler" in checkpoints:
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ def remove_short_and_long_utt(c: Cut, min_len: float, max_len: float):
+ if c.duration < min_len or c.duration > max_len:
+ return False
+ return True
+
+ _remove_short_and_long_utt = partial(
+ remove_short_and_long_utt, min_len=params.min_len, max_len=params.max_len
+ )
+
+ datamodule = TtsDataModule(args)
+ if params.dataset == "emilia":
+ train_cuts = CutSet.mux(
+ datamodule.train_emilia_EN_cuts(),
+ datamodule.train_emilia_ZH_cuts(),
+ weights=[46000, 49000],
+ )
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
+ dev_cuts = CutSet.mux(
+ datamodule.dev_emilia_EN_cuts(),
+ datamodule.dev_emilia_ZH_cuts(),
+ weights=[0.5, 0.5],
+ )
+ elif params.dataset == "libritts":
+ train_cuts = datamodule.train_libritts_cuts()
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
+ dev_cuts = datamodule.dev_libritts_cuts()
+ else:
+ assert params.dataset == "custom"
+ train_cuts = datamodule.train_custom_cuts(params.train_manifest)
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
+ dev_cuts = datamodule.dev_custom_cuts(params.dev_manifest)
+ # To avoid OOM issues due to too long dev cuts
+ dev_cuts = dev_cuts.filter(_remove_short_and_long_utt)
+
+ if params.tokenizer in ["emilia", "espeak", "dialog"]:
+ if not hasattr(train_cuts[0].supervisions[0], "tokens") or not hasattr(
+ dev_cuts[0].supervisions[0], "tokens"
+ ):
+ logging.warning(
+ f"Using {params.tokenizer} tokenizer but tokens are not prepared,"
+ f"will tokenize on-the-fly, which can slow down training significantly."
+ )
+ _tokenize_text = partial(tokenize_text, tokenizer=tokenizer)
+ train_cuts = train_cuts.map(_tokenize_text)
+ dev_cuts = dev_cuts.map(_tokenize_text)
+
+ train_dl = datamodule.train_dataloaders(train_cuts)
+
+ valid_dl = datamodule.dev_dataloaders(dev_cuts)
+
+ if params.scan_oom:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ teacher_model=teacher_model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ params=params,
+ )
+ logging.info("Training started")
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ logging.info(f"Start epoch {epoch}")
+
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ params.cur_epoch = epoch
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ teacher_model=teacher_model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
+ break
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint(
+ filename=filename,
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ model_ema=teacher_model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if rank == 0:
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def main():
+ parser = get_parser()
+ TtsDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+if __name__ == "__main__":
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ main()
diff --git a/zipvoice/dataset/datamodule.py b/zipvoice/dataset/datamodule.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e9f668989fbf6da6021c353c05ba91f84b0eb77
--- /dev/null
+++ b/zipvoice/dataset/datamodule.py
@@ -0,0 +1,347 @@
+# Copyright 2021 Piotr Żelasko
+# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo,
+# Zengwei Yao,
+# Zengrui Jin,
+# Han Zhu,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, load_manifest_lazy
+from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler
+from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from zipvoice.dataset.dataset import SpeechSynthesisDataset
+from zipvoice.utils.common import str2bool
+from zipvoice.utils.feature import VocosFbank
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+SAMPLING_RATE = 24000
+
+
+class TtsDataModule:
+ """
+ DataModule for tts experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+ and test-other).
+
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - on-the-fly feature extraction
+
+ This class should be derived for specific corpora used in ASR tasks.
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="TTS data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/fbank"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, each batch will have the "
+ "field: batch['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=8,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--input-strategy",
+ type=str,
+ default="PrecomputedFeatures",
+ help="AudioSamples or PrecomputedFeatures",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ logging.info("About to create train dataset")
+
+ train = SpeechSynthesisDataset(
+ return_text=True,
+ return_tokens=True,
+ return_spk_ids=True,
+ feature_input_strategy=OnTheFlyFeatures(VocosFbank())
+ if self.args.on_the_fly_feats
+ else PrecomputedFeatures(),
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
+ drop_last=self.args.drop_last,
+ )
+ else:
+ logging.info("Using SimpleCutSampler.")
+ train_sampler = SimpleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ worker_init_fn=worker_init_fn,
+ )
+
+ return train_dl
+
+ def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ logging.info("About to create dev dataset")
+ validate = SpeechSynthesisDataset(
+ return_text=True,
+ return_tokens=True,
+ return_spk_ids=True,
+ feature_input_strategy=OnTheFlyFeatures(VocosFbank())
+ if self.args.on_the_fly_feats
+ else PrecomputedFeatures(),
+ return_cuts=self.args.return_cuts,
+ )
+ dev_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create valid dataloader")
+ dev_dl = DataLoader(
+ validate,
+ sampler=dev_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return dev_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.info("About to create test dataset")
+ test = SpeechSynthesisDataset(
+ return_text=True,
+ return_tokens=True,
+ return_spk_ids=True,
+ feature_input_strategy=OnTheFlyFeatures(VocosFbank())
+ if self.args.on_the_fly_feats
+ else PrecomputedFeatures(),
+ return_cuts=self.args.return_cuts,
+ return_audio=True,
+ )
+ test_sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=test_sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_custom_cuts(self, manifest_file) -> CutSet:
+ logging.info(f"About to get the custom training cuts {manifest_file}")
+ return load_manifest_lazy(manifest_file)
+
+ @lru_cache()
+ def dev_custom_cuts(self, manifest_file) -> CutSet:
+ logging.info(f"About to get the custom validation cuts {manifest_file}")
+ return load_manifest_lazy(manifest_file)
+
+ @lru_cache()
+ def train_emilia_EN_cuts(self) -> CutSet:
+ logging.info("About to get train the EN subset")
+ return load_manifest_lazy(self.args.manifest_dir / "emilia_cuts_EN.jsonl.gz")
+
+ @lru_cache()
+ def train_emilia_ZH_cuts(self) -> CutSet:
+ logging.info("About to get train the ZH subset")
+ return load_manifest_lazy(self.args.manifest_dir / "emilia_cuts_ZH.jsonl.gz")
+
+ @lru_cache()
+ def dev_emilia_EN_cuts(self) -> CutSet:
+ logging.info("About to get dev the EN subset")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "emilia_cuts_EN-dev.jsonl.gz"
+ )
+
+ @lru_cache()
+ def dev_emilia_ZH_cuts(self) -> CutSet:
+ logging.info("About to get dev the ZH subset")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "emilia_cuts_ZH-dev.jsonl.gz"
+ )
+
+ @lru_cache()
+ def train_libritts_cuts(self) -> CutSet:
+ logging.info(
+ "About to get the shuffled train-clean-100, \
+ train-clean-360 and train-other-500 cuts"
+ )
+ return load_manifest_lazy(
+ self.args.manifest_dir / "libritts_cuts_train-all-shuf.jsonl.gz"
+ )
+
+ @lru_cache()
+ def dev_libritts_cuts(self) -> CutSet:
+ logging.info("About to get dev-clean cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz"
+ )
+
+ @lru_cache()
+ def train_opendialog_en_cuts(self) -> CutSet:
+ logging.info("About to ge the EN train subset of OpenDialog")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "opendialog_cuts_EN-train.jsonl.gz"
+ )
+
+ @lru_cache()
+ def train_opendialog_zh_cuts(self) -> CutSet:
+ logging.info("About to get the ZH train subset of OpenDialog")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "opendialog_cuts_ZH-train.jsonl.gz"
+ )
+
+ @lru_cache()
+ def dev_opendialog_en_cuts(self) -> CutSet:
+ logging.info("About to ge the EN dev subset of OpenDialog")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "opendialog_cuts_EN-dev.jsonl.gz"
+ )
+
+ @lru_cache()
+ def dev_opendialog_zh_cuts(self) -> CutSet:
+ logging.info("About to get the ZH dev subset of OpenDialog")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "opendialog_cuts_ZH-dev.jsonl.gz"
+ )
diff --git a/zipvoice/dataset/dataset.py b/zipvoice/dataset/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5948e13ecff7c845f9ce1c8c99da7eefc71ce4f
--- /dev/null
+++ b/zipvoice/dataset/dataset.py
@@ -0,0 +1,105 @@
+from typing import Callable, Dict, List, Sequence, Union
+
+import torch
+from lhotse import CutSet, validate
+from lhotse.dataset import PrecomputedFeatures
+from lhotse.dataset.collation import collate_audio
+from lhotse.dataset.input_strategies import BatchIO
+from lhotse.utils import ifnone
+
+
+class SpeechSynthesisDataset(torch.utils.data.Dataset):
+ """
+ The PyTorch Dataset for the speech synthesis task.
+ Each item in this dataset is a dict of:
+
+ .. code-block::
+
+ {
+ 'audio': (B x NumSamples) float tensor
+ 'features': (B x NumFrames x NumFeatures) float tensor
+ 'audio_lens': (B, ) int tensor
+ 'features_lens': (B, ) int tensor
+ 'text': List[str] of len B # when return_text=True
+ 'tokens': List[List[str]] # when return_tokens=True
+ 'speakers': List[str] of len B # when return_spk_ids=True
+ 'cut': List of Cuts # when return_cuts=True
+ }
+ """
+
+ def __init__(
+ self,
+ cut_transforms: List[Callable[[CutSet], CutSet]] = None,
+ feature_input_strategy: BatchIO = PrecomputedFeatures(),
+ feature_transforms: Union[Sequence[Callable], Callable] = None,
+ return_text: bool = True,
+ return_tokens: bool = False,
+ return_spk_ids: bool = False,
+ return_cuts: bool = False,
+ return_audio: bool = False,
+ ) -> None:
+ super().__init__()
+
+ self.cut_transforms = ifnone(cut_transforms, [])
+ self.feature_input_strategy = feature_input_strategy
+
+ self.return_text = return_text
+ self.return_tokens = return_tokens
+ self.return_spk_ids = return_spk_ids
+ self.return_cuts = return_cuts
+ self.return_audio = return_audio
+
+ if feature_transforms is None:
+ feature_transforms = []
+ elif not isinstance(feature_transforms, Sequence):
+ feature_transforms = [feature_transforms]
+
+ assert all(
+ isinstance(transform, Callable) for transform in feature_transforms
+ ), "Feature transforms must be Callable"
+ self.feature_transforms = feature_transforms
+
+ def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
+ validate_for_tts(cuts)
+
+ for transform in self.cut_transforms:
+ cuts = transform(cuts)
+
+ features, features_lens = self.feature_input_strategy(cuts)
+
+ for transform in self.feature_transforms:
+ features = transform(features)
+
+ batch = {
+ "features": features,
+ "features_lens": features_lens,
+ }
+
+ if self.return_audio:
+ audio, audio_lens = collate_audio(cuts)
+ batch["audio"] = audio
+ batch["audio_lens"] = audio_lens
+
+ if self.return_text:
+ text = [cut.supervisions[0].text for cut in cuts]
+ batch["text"] = text
+
+ if self.return_tokens:
+ tokens = [cut.supervisions[0].tokens for cut in cuts]
+ batch["tokens"] = tokens
+
+ if self.return_spk_ids:
+ batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts]
+
+ if self.return_cuts:
+ batch["cut"] = [cut for cut in cuts]
+
+ return batch
+
+
+def validate_for_tts(cuts: CutSet) -> None:
+ validate(cuts)
+ for cut in cuts:
+ assert (
+ len(cut.supervisions) == 1
+ ), "Only the Cuts with single supervision are supported."
diff --git a/zipvoice/eval/models/ecapa_tdnn_wavllm.py b/zipvoice/eval/models/ecapa_tdnn_wavllm.py
new file mode 100644
index 0000000000000000000000000000000000000000..bedbdcd5f660eaaa867b51a2d95d7c30fe31db91
--- /dev/null
+++ b/zipvoice/eval/models/ecapa_tdnn_wavllm.py
@@ -0,0 +1,357 @@
+import os
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ECAPA_TDNN_WAVLLM(nn.Module):
+ def __init__(
+ self,
+ feat_dim=80,
+ channels=512,
+ emb_dim=192,
+ global_context_att=False,
+ sr=16000,
+ ssl_model_path=None,
+ ):
+ super().__init__()
+ self.sr = sr
+
+ if ssl_model_path is None:
+ self.feature_extract = torch.hub.load("s3prl/s3prl", "wavlm_large")
+ else:
+ self.feature_extract = torch.hub.load(
+ os.path.dirname(ssl_model_path),
+ "wavlm_local",
+ source="local",
+ ckpt=ssl_model_path,
+ )
+
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
+ self.feature_extract.model.encoder.layers[23].self_attn,
+ "fp32_attention",
+ ):
+ self.feature_extract.model.encoder.layers[
+ 23
+ ].self_attn.fp32_attention = False
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
+ self.feature_extract.model.encoder.layers[11].self_attn,
+ "fp32_attention",
+ ):
+ self.feature_extract.model.encoder.layers[
+ 11
+ ].self_attn.fp32_attention = False
+
+ self.feat_num = self.get_feat_num()
+ self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
+
+ self.instance_norm = nn.InstanceNorm1d(feat_dim)
+ # self.channels = [channels] * 4 + [channels * 3]
+ self.channels = [channels] * 4 + [1536]
+
+ self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
+ self.layer2 = SE_Res2Block(
+ self.channels[0],
+ self.channels[1],
+ kernel_size=3,
+ stride=1,
+ padding=2,
+ dilation=2,
+ scale=8,
+ se_bottleneck_dim=128,
+ )
+ self.layer3 = SE_Res2Block(
+ self.channels[1],
+ self.channels[2],
+ kernel_size=3,
+ stride=1,
+ padding=3,
+ dilation=3,
+ scale=8,
+ se_bottleneck_dim=128,
+ )
+ self.layer4 = SE_Res2Block(
+ self.channels[2],
+ self.channels[3],
+ kernel_size=3,
+ stride=1,
+ padding=4,
+ dilation=4,
+ scale=8,
+ se_bottleneck_dim=128,
+ )
+
+ # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
+ cat_channels = channels * 3
+ self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
+ self.pooling = AttentiveStatsPool(
+ self.channels[-1],
+ attention_channels=128,
+ global_context_att=global_context_att,
+ )
+ self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
+ self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
+
+ def get_feat_num(self):
+ self.feature_extract.eval()
+ wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
+ with torch.no_grad():
+ features = self.feature_extract(wav)
+ select_feature = features["hidden_states"]
+ if isinstance(select_feature, (list, tuple)):
+ return len(select_feature)
+ else:
+ return 1
+
+ def get_feat(self, x):
+ with torch.no_grad():
+ x = self.feature_extract([sample for sample in x])
+
+ x = x["hidden_states"]
+ if isinstance(x, (list, tuple)):
+ x = torch.stack(x, dim=0)
+ else:
+ x = x.unsqueeze(0)
+ norm_weights = (
+ F.softmax(self.feature_weight, dim=-1)
+ .unsqueeze(-1)
+ .unsqueeze(-1)
+ .unsqueeze(-1)
+ )
+ x = (norm_weights * x).sum(dim=0)
+ x = torch.transpose(x, 1, 2) + 1e-6
+
+ x = self.instance_norm(x)
+ return x
+
+ def forward(self, x):
+ x = self.get_feat(x)
+
+ out1 = self.layer1(x)
+ out2 = self.layer2(out1)
+ out3 = self.layer3(out2)
+ out4 = self.layer4(out3)
+
+ out = torch.cat([out2, out3, out4], dim=1)
+ out = F.relu(self.conv(out))
+ out = self.bn(self.pooling(out))
+ out = self.linear(out)
+
+ return out
+
+
+# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
+
+""" Res2Conv1d + BatchNorm1d + ReLU
+"""
+
+
+class Res2Conv1dReluBn(nn.Module):
+ """
+ in_channels == out_channels == channels
+ """
+
+ def __init__(
+ self,
+ channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ dilation=1,
+ bias=True,
+ scale=4,
+ ):
+ super().__init__()
+ assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
+ self.scale = scale
+ self.width = channels // scale
+ self.nums = scale if scale == 1 else scale - 1
+
+ self.convs = []
+ self.bns = []
+ for i in range(self.nums):
+ self.convs.append(
+ nn.Conv1d(
+ self.width,
+ self.width,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ bias=bias,
+ )
+ )
+ self.bns.append(nn.BatchNorm1d(self.width))
+ self.convs = nn.ModuleList(self.convs)
+ self.bns = nn.ModuleList(self.bns)
+
+ def forward(self, x):
+ out = []
+ spx = torch.split(x, self.width, 1)
+ for i in range(self.nums):
+ if i == 0:
+ sp = spx[i]
+ else:
+ sp = sp + spx[i]
+ # Order: conv -> relu -> bn
+ sp = self.convs[i](sp)
+ sp = self.bns[i](F.relu(sp))
+ out.append(sp)
+ if self.scale != 1:
+ out.append(spx[self.nums])
+ out = torch.cat(out, dim=1)
+
+ return out
+
+
+""" Conv1d + BatchNorm1d + ReLU
+"""
+
+
+class Conv1dReluBn(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ dilation=1,
+ bias=True,
+ ):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ bias=bias,
+ )
+ self.bn = nn.BatchNorm1d(out_channels)
+
+ def forward(self, x):
+ return self.bn(F.relu(self.conv(x)))
+
+
+""" The SE connection of 1D case.
+"""
+
+
+class SE_Connect(nn.Module):
+ def __init__(self, channels, se_bottleneck_dim=128):
+ super().__init__()
+ self.linear1 = nn.Linear(channels, se_bottleneck_dim)
+ self.linear2 = nn.Linear(se_bottleneck_dim, channels)
+
+ def forward(self, x):
+ out = x.mean(dim=2)
+ out = F.relu(self.linear1(out))
+ out = torch.sigmoid(self.linear2(out))
+ out = x * out.unsqueeze(2)
+
+ return out
+
+
+""" SE-Res2Block of the ECAPA-TDNN architecture.
+"""
+
+
+# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
+# return nn.Sequential(
+# Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
+# Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
+# Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
+# SE_Connect(channels)
+# )
+
+
+class SE_Res2Block(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ scale,
+ se_bottleneck_dim,
+ ):
+ super().__init__()
+ self.Conv1dReluBn1 = Conv1dReluBn(
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.Res2Conv1dReluBn = Res2Conv1dReluBn(
+ out_channels, kernel_size, stride, padding, dilation, scale=scale
+ )
+ self.Conv1dReluBn2 = Conv1dReluBn(
+ out_channels, out_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
+
+ self.shortcut = None
+ if in_channels != out_channels:
+ self.shortcut = nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ )
+
+ def forward(self, x):
+ residual = x
+ if self.shortcut:
+ residual = self.shortcut(x)
+
+ x = self.Conv1dReluBn1(x)
+ x = self.Res2Conv1dReluBn(x)
+ x = self.Conv1dReluBn2(x)
+ x = self.SE_Connect(x)
+
+ return x + residual
+
+
+""" Attentive weighted mean and standard deviation pooling.
+"""
+
+
+class AttentiveStatsPool(nn.Module):
+ def __init__(self, in_dim, attention_channels=128, global_context_att=False):
+ super().__init__()
+ self.global_context_att = global_context_att
+
+ # Use Conv1d with stride == 1 rather than Linear,
+ # then we don't need to transpose inputs.
+ if global_context_att:
+ self.linear1 = nn.Conv1d(
+ in_dim * 3, attention_channels, kernel_size=1
+ ) # equals W and b in the paper
+ else:
+ self.linear1 = nn.Conv1d(
+ in_dim, attention_channels, kernel_size=1
+ ) # equals W and b in the paper
+ self.linear2 = nn.Conv1d(
+ attention_channels, in_dim, kernel_size=1
+ ) # equals V and k in the paper
+
+ def forward(self, x):
+
+ if self.global_context_att:
+ context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
+ context_std = torch.sqrt(
+ torch.var(x, dim=-1, keepdim=True) + 1e-10
+ ).expand_as(x)
+ x_in = torch.cat((x, context_mean, context_std), dim=1)
+ else:
+ x_in = x
+
+ # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
+ alpha = torch.tanh(self.linear1(x_in))
+ # alpha = F.relu(self.linear1(x_in))
+ alpha = torch.softmax(self.linear2(alpha), dim=2)
+ mean = torch.sum(alpha * x, dim=2)
+ residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
+ std = torch.sqrt(residuals.clamp(min=1e-9))
+ return torch.cat([mean, std], dim=1)
diff --git a/zipvoice/eval/models/ecapa_tdnn_wavlm.py b/zipvoice/eval/models/ecapa_tdnn_wavlm.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2aa40b2671533580a83aaa70064a519de4570d1
--- /dev/null
+++ b/zipvoice/eval/models/ecapa_tdnn_wavlm.py
@@ -0,0 +1,357 @@
+import os
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ECAPA_TDNN_WAVLM(nn.Module):
+ def __init__(
+ self,
+ feat_dim=80,
+ channels=512,
+ emb_dim=192,
+ global_context_att=False,
+ sr=16000,
+ ssl_model_path=None,
+ ):
+ super().__init__()
+ self.sr = sr
+
+ if ssl_model_path is None:
+ self.feature_extract = torch.hub.load("s3prl/s3prl", "wavlm_large")
+ else:
+ self.feature_extract = torch.hub.load(
+ os.path.dirname(ssl_model_path),
+ "wavlm_local",
+ source="local",
+ ckpt=os.path.join(ssl_model_path, "wavlm_large.pt"),
+ )
+
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
+ self.feature_extract.model.encoder.layers[23].self_attn,
+ "fp32_attention",
+ ):
+ self.feature_extract.model.encoder.layers[
+ 23
+ ].self_attn.fp32_attention = False
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
+ self.feature_extract.model.encoder.layers[11].self_attn,
+ "fp32_attention",
+ ):
+ self.feature_extract.model.encoder.layers[
+ 11
+ ].self_attn.fp32_attention = False
+
+ self.feat_num = self.get_feat_num()
+ self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
+
+ self.instance_norm = nn.InstanceNorm1d(feat_dim)
+ # self.channels = [channels] * 4 + [channels * 3]
+ self.channels = [channels] * 4 + [1536]
+
+ self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
+ self.layer2 = SE_Res2Block(
+ self.channels[0],
+ self.channels[1],
+ kernel_size=3,
+ stride=1,
+ padding=2,
+ dilation=2,
+ scale=8,
+ se_bottleneck_dim=128,
+ )
+ self.layer3 = SE_Res2Block(
+ self.channels[1],
+ self.channels[2],
+ kernel_size=3,
+ stride=1,
+ padding=3,
+ dilation=3,
+ scale=8,
+ se_bottleneck_dim=128,
+ )
+ self.layer4 = SE_Res2Block(
+ self.channels[2],
+ self.channels[3],
+ kernel_size=3,
+ stride=1,
+ padding=4,
+ dilation=4,
+ scale=8,
+ se_bottleneck_dim=128,
+ )
+
+ # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
+ cat_channels = channels * 3
+ self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
+ self.pooling = AttentiveStatsPool(
+ self.channels[-1],
+ attention_channels=128,
+ global_context_att=global_context_att,
+ )
+ self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
+ self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
+
+ def get_feat_num(self):
+ self.feature_extract.eval()
+ wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
+ with torch.no_grad():
+ features = self.feature_extract(wav)
+ select_feature = features["hidden_states"]
+ if isinstance(select_feature, (list, tuple)):
+ return len(select_feature)
+ else:
+ return 1
+
+ def get_feat(self, x):
+ with torch.no_grad():
+ x = self.feature_extract([sample for sample in x])
+
+ x = x["hidden_states"]
+ if isinstance(x, (list, tuple)):
+ x = torch.stack(x, dim=0)
+ else:
+ x = x.unsqueeze(0)
+ norm_weights = (
+ F.softmax(self.feature_weight, dim=-1)
+ .unsqueeze(-1)
+ .unsqueeze(-1)
+ .unsqueeze(-1)
+ )
+ x = (norm_weights * x).sum(dim=0)
+ x = torch.transpose(x, 1, 2) + 1e-6
+
+ x = self.instance_norm(x)
+ return x
+
+ def forward(self, x):
+ x = self.get_feat(x)
+
+ out1 = self.layer1(x)
+ out2 = self.layer2(out1)
+ out3 = self.layer3(out2)
+ out4 = self.layer4(out3)
+
+ out = torch.cat([out2, out3, out4], dim=1)
+ out = F.relu(self.conv(out))
+ out = self.bn(self.pooling(out))
+ out = self.linear(out)
+
+ return out
+
+
+# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
+
+""" Res2Conv1d + BatchNorm1d + ReLU
+"""
+
+
+class Res2Conv1dReluBn(nn.Module):
+ """
+ in_channels == out_channels == channels
+ """
+
+ def __init__(
+ self,
+ channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ dilation=1,
+ bias=True,
+ scale=4,
+ ):
+ super().__init__()
+ assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
+ self.scale = scale
+ self.width = channels // scale
+ self.nums = scale if scale == 1 else scale - 1
+
+ self.convs = []
+ self.bns = []
+ for i in range(self.nums):
+ self.convs.append(
+ nn.Conv1d(
+ self.width,
+ self.width,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ bias=bias,
+ )
+ )
+ self.bns.append(nn.BatchNorm1d(self.width))
+ self.convs = nn.ModuleList(self.convs)
+ self.bns = nn.ModuleList(self.bns)
+
+ def forward(self, x):
+ out = []
+ spx = torch.split(x, self.width, 1)
+ for i in range(self.nums):
+ if i == 0:
+ sp = spx[i]
+ else:
+ sp = sp + spx[i]
+ # Order: conv -> relu -> bn
+ sp = self.convs[i](sp)
+ sp = self.bns[i](F.relu(sp))
+ out.append(sp)
+ if self.scale != 1:
+ out.append(spx[self.nums])
+ out = torch.cat(out, dim=1)
+
+ return out
+
+
+""" Conv1d + BatchNorm1d + ReLU
+"""
+
+
+class Conv1dReluBn(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ dilation=1,
+ bias=True,
+ ):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ bias=bias,
+ )
+ self.bn = nn.BatchNorm1d(out_channels)
+
+ def forward(self, x):
+ return self.bn(F.relu(self.conv(x)))
+
+
+""" The SE connection of 1D case.
+"""
+
+
+class SE_Connect(nn.Module):
+ def __init__(self, channels, se_bottleneck_dim=128):
+ super().__init__()
+ self.linear1 = nn.Linear(channels, se_bottleneck_dim)
+ self.linear2 = nn.Linear(se_bottleneck_dim, channels)
+
+ def forward(self, x):
+ out = x.mean(dim=2)
+ out = F.relu(self.linear1(out))
+ out = torch.sigmoid(self.linear2(out))
+ out = x * out.unsqueeze(2)
+
+ return out
+
+
+""" SE-Res2Block of the ECAPA-TDNN architecture.
+"""
+
+
+# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
+# return nn.Sequential(
+# Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
+# Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
+# Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
+# SE_Connect(channels)
+# )
+
+
+class SE_Res2Block(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ scale,
+ se_bottleneck_dim,
+ ):
+ super().__init__()
+ self.Conv1dReluBn1 = Conv1dReluBn(
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.Res2Conv1dReluBn = Res2Conv1dReluBn(
+ out_channels, kernel_size, stride, padding, dilation, scale=scale
+ )
+ self.Conv1dReluBn2 = Conv1dReluBn(
+ out_channels, out_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
+
+ self.shortcut = None
+ if in_channels != out_channels:
+ self.shortcut = nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ )
+
+ def forward(self, x):
+ residual = x
+ if self.shortcut:
+ residual = self.shortcut(x)
+
+ x = self.Conv1dReluBn1(x)
+ x = self.Res2Conv1dReluBn(x)
+ x = self.Conv1dReluBn2(x)
+ x = self.SE_Connect(x)
+
+ return x + residual
+
+
+""" Attentive weighted mean and standard deviation pooling.
+"""
+
+
+class AttentiveStatsPool(nn.Module):
+ def __init__(self, in_dim, attention_channels=128, global_context_att=False):
+ super().__init__()
+ self.global_context_att = global_context_att
+
+ # Use Conv1d with stride == 1 rather than Linear,
+ # then we don't need to transpose inputs.
+ if global_context_att:
+ self.linear1 = nn.Conv1d(
+ in_dim * 3, attention_channels, kernel_size=1
+ ) # equals W and b in the paper
+ else:
+ self.linear1 = nn.Conv1d(
+ in_dim, attention_channels, kernel_size=1
+ ) # equals W and b in the paper
+ self.linear2 = nn.Conv1d(
+ attention_channels, in_dim, kernel_size=1
+ ) # equals V and k in the paper
+
+ def forward(self, x):
+
+ if self.global_context_att:
+ context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
+ context_std = torch.sqrt(
+ torch.var(x, dim=-1, keepdim=True) + 1e-10
+ ).expand_as(x)
+ x_in = torch.cat((x, context_mean, context_std), dim=1)
+ else:
+ x_in = x
+
+ # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
+ alpha = torch.tanh(self.linear1(x_in))
+ # alpha = F.relu(self.linear1(x_in))
+ alpha = torch.softmax(self.linear2(alpha), dim=2)
+ mean = torch.sum(alpha * x, dim=2)
+ residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
+ std = torch.sqrt(residuals.clamp(min=1e-9))
+ return torch.cat([mean, std], dim=1)
diff --git a/zipvoice/eval/models/utmos.py b/zipvoice/eval/models/utmos.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd102aeef7691c9bc4318e238c5a78ef95d69fde
--- /dev/null
+++ b/zipvoice/eval/models/utmos.py
@@ -0,0 +1,354 @@
+"""
+UTMOS strong model.
+Implementation from https://github.com/tarepan/SpeechMOS
+
+"""
+
+import math
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+import torchaudio # pyright: ignore [reportMissingTypeStubs]
+from torch import Tensor, nn
+
+
+class UTMOS22Strong(nn.Module):
+ """Saeki_2022 paper's `UTMOS strong learner` inference model
+ (w/o Phoneme encoder)."""
+
+ def __init__(self):
+ """Init."""
+
+ super().__init__() # pyright: ignore [reportUnknownMemberType]
+
+ feat_ssl, feat_domain_emb, feat_judge_emb, feat_rnn_h, feat_proj_h = (
+ 768,
+ 128,
+ 128,
+ 512,
+ 2048,
+ )
+ feat_cat = feat_ssl + feat_domain_emb + feat_judge_emb
+
+ # SSL/DataDomainEmb/JudgeIdEmb/BLSTM/Projection
+ self.wav2vec2 = Wav2Vec2Model()
+ self.domain_emb = nn.Parameter(
+ data=torch.empty(1, feat_domain_emb), requires_grad=False
+ )
+ self.judge_emb = nn.Parameter(
+ data=torch.empty(1, feat_judge_emb), requires_grad=False
+ )
+ self.blstm = nn.LSTM(
+ input_size=feat_cat,
+ hidden_size=feat_rnn_h,
+ batch_first=True,
+ bidirectional=True,
+ )
+ self.projection = nn.Sequential(
+ nn.Linear(feat_rnn_h * 2, feat_proj_h), nn.ReLU(), nn.Linear(feat_proj_h, 1)
+ )
+
+ def forward(self, wave: Tensor, sr: int) -> Tensor: # pylint: disable=invalid-name
+ """wave-to-score :: (B, T) -> (B,)"""
+
+ # Feature extraction :: (B, T) -> (B, Frame, Feat)
+ unit_series = self.wav2vec2(wave)
+ bsz, frm, _ = unit_series.size()
+
+ # DataDomain/JudgeId Embedding's Batch/Time expansion ::
+ # (B=1, Feat) -> (B=bsz, Frame=frm, Feat)
+ domain_series = self.domain_emb.unsqueeze(1).expand(bsz, frm, -1)
+ judge_series = self.judge_emb.unsqueeze(1).expand(bsz, frm, -1)
+
+ # Feature concatenation :: (B, Frame, Feat=f1) + (B, Frame, Feat=f2) +
+ # (B, Frame, Feat=f3) -> (B, Frame, Feat=f1+f2+f3)
+ cat_series = torch.cat([unit_series, domain_series, judge_series], dim=2)
+
+ # Frame-scale score estimation :: (B, Frame, Feat) -> (B, Frame, Feat)
+ # -> (B, Frame, Feat=1) - BLSTM/Projection
+ feat_series = self.blstm(cat_series)[0]
+ score_series = self.projection(feat_series)
+
+ # Utterance-scale score :: (B, Frame, Feat=1) -> (B, Feat=1)
+ # -> (B,) - Time averaging
+ utter_score = score_series.mean(dim=1).squeeze(1) * 2 + 3
+
+ return utter_score
+
+
+class Wav2Vec2Model(nn.Module):
+ """Wav2Vev2."""
+
+ def __init__(self):
+ super().__init__() # pyright: ignore [reportUnknownMemberType]
+
+ feat_h1, feat_h2 = 512, 768
+ feature_enc_layers = (
+ [(feat_h1, 10, 5)] + [(feat_h1, 3, 2)] * 4 + [(feat_h1, 2, 2)] * 2
+ )
+
+ self.feature_extractor = ConvFeatureExtractionModel(
+ conv_layers=feature_enc_layers
+ ) # pyright: ignore [reportGeneralTypeIssues]
+ self.layer_norm = nn.LayerNorm(feat_h1)
+ self.post_extract_proj = nn.Linear(feat_h1, feat_h2)
+ self.dropout_input = nn.Dropout(0.1)
+ self.encoder = TransformerEncoder(feat_h2)
+
+ # Remnants
+ self.mask_emb = nn.Parameter(torch.FloatTensor(feat_h2))
+
+ def forward(self, source: Tensor):
+ """FeatureEncoder + ContextTransformer"""
+
+ # Feature encoding
+ features = self.feature_extractor(source)
+ features = features.transpose(1, 2)
+ features = self.layer_norm(features)
+ features = self.post_extract_proj(features)
+
+ # Context transformer
+ x = self.encoder(features)
+
+ return x
+
+
+class ConvFeatureExtractionModel(nn.Module):
+ """Feature Encoder."""
+
+ def __init__(self, conv_layers: List[Tuple[int, int, int]]):
+ super().__init__() # pyright: ignore [reportUnknownMemberType]
+
+ def block(
+ n_in: int, n_out: int, k: int, stride: int, is_group_norm: bool = False
+ ):
+ if is_group_norm:
+ return nn.Sequential(
+ nn.Conv1d(n_in, n_out, k, stride=stride, bias=False),
+ nn.Dropout(p=0.0),
+ nn.GroupNorm(dim, dim, affine=True),
+ nn.GELU(),
+ )
+ else:
+ return nn.Sequential(
+ nn.Conv1d(n_in, n_out, k, stride=stride, bias=False),
+ nn.Dropout(p=0.0),
+ nn.GELU(),
+ )
+
+ in_d = 1
+ self.conv_layers = nn.ModuleList()
+ for i, params in enumerate(conv_layers):
+ (dim, k, stride) = params
+ self.conv_layers.append(block(in_d, dim, k, stride, is_group_norm=i == 0))
+ in_d = dim
+
+ def forward(self, series: Tensor) -> Tensor:
+ """:: (B, T) -> (B, Feat, Frame)"""
+
+ series = series.unsqueeze(1)
+ for conv in self.conv_layers:
+ series = conv(series)
+
+ return series
+
+
+class TransformerEncoder(nn.Module):
+ """Transformer."""
+
+ def build_encoder_layer(self, feat: int):
+ """Layer builder."""
+ return TransformerSentenceEncoderLayer(
+ embedding_dim=feat,
+ ffn_embedding_dim=3072,
+ num_attention_heads=12,
+ activation_fn="gelu",
+ dropout=0.1,
+ attention_dropout=0.1,
+ activation_dropout=0.0,
+ layer_norm_first=False,
+ )
+
+ def __init__(self, feat: int):
+ super().__init__() # pyright: ignore [reportUnknownMemberType]
+
+ self.required_seq_len_multiple = 2
+
+ self.pos_conv = nn.Sequential(
+ *[
+ nn.utils.weight_norm(
+ nn.Conv1d(feat, feat, kernel_size=128, padding=128 // 2, groups=16),
+ name="weight",
+ dim=2,
+ ),
+ SamePad(128),
+ nn.GELU(),
+ ]
+ )
+ self.layer_norm = nn.LayerNorm(feat)
+ self.layers = nn.ModuleList([self.build_encoder_layer(feat) for _ in range(12)])
+
+ def forward(self, x: Tensor) -> Tensor:
+
+ x_conv = self.pos_conv(x.transpose(1, 2)).transpose(1, 2)
+ x = x + x_conv
+
+ x = self.layer_norm(x)
+
+ # pad to the sequence length dimension
+ x, pad_length = pad_to_multiple(
+ x, self.required_seq_len_multiple, dim=-2, value=0
+ )
+ if pad_length > 0:
+ padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
+ padding_mask[:, -pad_length:] = True
+ else:
+ padding_mask, _ = pad_to_multiple(
+ None, self.required_seq_len_multiple, dim=-1, value=True
+ )
+
+ # :: (B, T, Feat) -> (T, B, Feat)
+ x = x.transpose(0, 1)
+ for layer in self.layers:
+ x = layer(x, padding_mask)
+ # :: (T, B, Feat) -> (B, T, Feat)
+ x = x.transpose(0, 1)
+
+ # undo paddding
+ if pad_length > 0:
+ x = x[:, :-pad_length]
+
+ return x
+
+
+class SamePad(nn.Module):
+ """Tail inverse padding."""
+
+ def __init__(self, kernel_size: int):
+ super().__init__() # pyright: ignore [reportUnknownMemberType]
+ assert kernel_size % 2 == 0, "`SamePad` now support only even kernel."
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x[:, :, :-1]
+
+
+def pad_to_multiple(
+ x: Optional[Tensor], multiple: int, dim: int = -1, value: float = 0
+) -> Tuple[Optional[Tensor], int]:
+ """Tail padding."""
+ if x is None:
+ return None, 0
+ tsz = x.size(dim)
+ m = tsz / multiple
+ remainder = math.ceil(m) * multiple - tsz
+ if m.is_integer():
+ return x, 0
+ pad_offset = (0,) * (-1 - dim) * 2
+
+ return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
+
+
+class TransformerSentenceEncoderLayer(nn.Module):
+ """Transformer Encoder Layer used in BERT/XLM style pre-trained models."""
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ ffn_embedding_dim: int,
+ num_attention_heads: int,
+ activation_fn: str,
+ dropout: float,
+ attention_dropout: float,
+ activation_dropout: float,
+ layer_norm_first: bool,
+ ) -> None:
+ super().__init__() # pyright: ignore [reportUnknownMemberType]
+
+ assert layer_norm_first is False, "`layer_norm_first` is fixed to `False`"
+ assert activation_fn == "gelu", "`activation_fn` is fixed to `gelu`"
+
+ feat = embedding_dim
+
+ self.self_attn = MultiheadAttention(
+ feat, num_attention_heads, attention_dropout
+ )
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(activation_dropout)
+ self.dropout3 = nn.Dropout(dropout)
+ self.fc1 = nn.Linear(feat, ffn_embedding_dim)
+ self.fc2 = nn.Linear(ffn_embedding_dim, feat)
+ self.self_attn_layer_norm = nn.LayerNorm(feat)
+ self.final_layer_norm = nn.LayerNorm(feat)
+
+ def forward(self, x: Tensor, self_attn_padding_mask: Optional[Tensor]):
+ # Res[Attn-Do]-LN
+ residual = x
+ x = self.self_attn(x, x, x, self_attn_padding_mask)
+ x = self.dropout1(x)
+ x = residual + x
+ x = self.self_attn_layer_norm(x)
+
+ # Res[SegFC-GELU-Do-SegFC-Do]-LN
+ residual = x
+ x = F.gelu(self.fc1(x)) # pyright: ignore [reportUnknownMemberType]
+ x = self.dropout2(x)
+ x = self.fc2(x)
+ x = self.dropout3(x)
+ x = residual + x
+ x = self.final_layer_norm(x)
+
+ return x
+
+
+class MultiheadAttention(nn.Module):
+ """Multi-headed attention."""
+
+ def __init__(self, embed_dim: int, num_heads: int, dropout: float):
+ super().__init__() # pyright: ignore [reportUnknownMemberType]
+
+ self.embed_dim, self.num_heads, self.p_dropout = embed_dim, num_heads, dropout
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+
+ def forward(
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ key_padding_mask: Optional[Tensor],
+ ) -> Tensor:
+ """
+ Args:
+ query :: (T, B, Feat)
+ key_padding_mask :: (B, src_len) - mask to exclude keys that are pads
+ , where padding elements are indicated by 1s.
+ """
+ return F.multi_head_attention_forward(
+ query=query,
+ key=key,
+ value=value,
+ embed_dim_to_check=self.embed_dim,
+ num_heads=self.num_heads,
+ in_proj_weight=torch.empty([0]),
+ in_proj_bias=torch.cat(
+ (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)
+ ),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=self.p_dropout,
+ out_proj_weight=self.out_proj.weight,
+ out_proj_bias=self.out_proj.bias,
+ training=False,
+ key_padding_mask=key_padding_mask.bool()
+ if key_padding_mask is not None
+ else None,
+ need_weights=False,
+ use_separate_proj_weight=True,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ )[0]
diff --git a/zipvoice/eval/mos/utmos.py b/zipvoice/eval/mos/utmos.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcd96684d3f8d940fdbcd551fc07c8d353155347
--- /dev/null
+++ b/zipvoice/eval/mos/utmos.py
@@ -0,0 +1,174 @@
+#!/usr/bin/env python3
+# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Calculate UTMOS score with automatic Mean Opinion Score (MOS) prediction system
+"""
+import argparse
+import logging
+import os
+from typing import List
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from zipvoice.eval.models.utmos import UTMOS22Strong
+from zipvoice.eval.utils import load_waveform
+
+
+def get_parser() -> argparse.ArgumentParser:
+ parser = argparse.ArgumentParser(
+ description="Calculate UTMOS score using UTMOS22Strong model."
+ )
+
+ parser.add_argument(
+ "--wav-path",
+ type=str,
+ required=True,
+ help="Path to the directory containing evaluated speech files.",
+ )
+ parser.add_argument(
+ "--model-dir",
+ type=str,
+ required=True,
+ help="Local path of our evaluatioin model repository."
+ "Download from https://huggingface.co/k2-fsa/TTS_eval_models."
+ "Will use 'tts_eval_models/mos/utmos22_strong_step7459_v1.pt'"
+ " in this script",
+ )
+
+ parser.add_argument(
+ "--extension",
+ type=str,
+ default="wav",
+ help="Extension of the speech files. Default: wav",
+ )
+ return parser
+
+
+class UTMOSScore:
+ """Predicting UTMOS score for each audio clip."""
+
+ def __init__(self, model_path: str):
+ """
+ Initializes the UTMOS score evaluator with the specified model.
+
+ Args:
+ model_path (str): Path of the UTMOS model checkpoint.
+ """
+ self.sample_rate = 16000
+ self.device = (
+ torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ )
+ logging.info(f"Using device: {self.device}")
+
+ # Initialize and load the model
+ self.model = UTMOS22Strong()
+ try:
+ state_dict = torch.load(
+ model_path, map_location=lambda storage, loc: storage
+ )
+ self.model.load_state_dict(state_dict)
+ except Exception as e:
+ logging.error(f"Failed to load model from {model_path}: {e}")
+ raise
+
+ self.model.to(self.device)
+ self.model.eval()
+
+ @torch.no_grad()
+ def score_files(self, wav_paths: List[str]) -> List[float]:
+ """
+ Computes UTMOS scores for a list of audio files.
+
+ Args:
+ wav_paths (List[str]): List of paths to audio files.
+
+ Returns:
+ List[float]: List of UTMOS scores.
+ """
+ scores = []
+ for wav_path in tqdm(wav_paths, desc="Scoring audio files"):
+ # Load and preprocess waveform
+ speech = load_waveform(wav_path, self.sample_rate, device=self.device)
+ # Compute score
+ score = self.model(speech.unsqueeze(0), self.sample_rate)
+ scores.append(score.item())
+
+ return scores
+
+ def score_dir(self, dir_path: str, extension: str) -> float:
+ """
+ Computes the average UTMOS score for all files in a directory.
+
+ Args:
+ dir_path (str): Path to the directory containing audio files.
+
+ Returns:
+ float: Average UTMOS score for the directory.
+ """
+ logging.info(f"Calculating UTMOS score for {dir_path}")
+
+ # Get list of wav files
+ wav_files = [
+ os.path.join(dir_path, f)
+ for f in os.listdir(dir_path)
+ if f.lower().endswith(extension)
+ ]
+
+ if not wav_files:
+ raise ValueError(f"No audio files found in {dir_path}")
+
+ # Compute scores
+ scores = self.score_files(wav_files)
+
+ return float(np.mean(scores))
+
+
+if __name__ == "__main__":
+
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
+
+ parser = get_parser()
+ args = parser.parse_args()
+
+ # Validate input path
+ if not os.path.isdir(args.wav_path):
+ logging.error(f"Invalid directory: {args.wav_path}")
+ exit(1)
+
+ # Initialize evaluator
+ model_path = os.path.join(args.model_dir, "mos/utmos22_strong_step7459_v1.pt")
+ if not os.path.exists(model_path):
+ logging.error(
+ "Please download evaluation models from "
+ "https://huggingface.co/k2-fsa/TTS_eval_models"
+ " and pass this dir with --model-dir"
+ )
+ exit(1)
+ utmos_evaluator = UTMOSScore(model_path)
+
+ # Compute UTMOS score
+ score = utmos_evaluator.score_dir(args.wav_path, args.extension)
+ print("-" * 50)
+ logging.info(f"UTMOS score: {score:.2f}")
+ print("-" * 50)
diff --git a/zipvoice/eval/speaker_similarity/cpsim.py b/zipvoice/eval/speaker_similarity/cpsim.py
new file mode 100644
index 0000000000000000000000000000000000000000..998926c47a34adc84fe8493b902a18289e30c280
--- /dev/null
+++ b/zipvoice/eval/speaker_similarity/cpsim.py
@@ -0,0 +1,411 @@
+#!/usr/bin/env python3
+# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Computes concatenated maximum permutation speaker similarity (cpSIM) scores using:
+- A WavLM-based ECAPA-TDNN model for speaker embedding extraction.
+- A pyannote pipeline for speaker diarization (segmenting speakers).
+"""
+import argparse
+import logging
+import os
+import warnings
+from typing import List, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from pyannote.audio import Pipeline
+from tqdm import tqdm
+
+from zipvoice.eval.models.ecapa_tdnn_wavlm import ECAPA_TDNN_WAVLM
+from zipvoice.eval.utils import load_waveform
+
+warnings.filterwarnings("ignore")
+
+
+def get_parser() -> argparse.ArgumentParser:
+ parser = argparse.ArgumentParser(
+ description="Calculate concatenated maximum permutation speaker "
+ "similarity (cpSIM) score."
+ )
+ parser.add_argument(
+ "--wav-path",
+ type=str,
+ required=True,
+ help="Path to the directory containing evaluated speech files.",
+ )
+ parser.add_argument(
+ "--test-list",
+ type=str,
+ help="Path to the tsv file for speaker splitted prompts. "
+ "Each line contains (audio_name, prompt_text_1, prompt_text_2, "
+ "prompt_audio_1, prompt_audio_2, text) separated by tabs.",
+ )
+
+ parser.add_argument(
+ "--test-list-merge",
+ type=str,
+ help="Path to the tsv file for merged dialogue prompts. "
+ "Each line contains (audio_name, prompt_text_dialogue, "
+ "prompt_audio_dialogue, text) separated by tabs.",
+ )
+ parser.add_argument(
+ "--model-dir",
+ type=str,
+ required=True,
+ help="Local path of our evaluatioin model repository."
+ "Download from https://huggingface.co/k2-fsa/TTS_eval_models."
+ "Will use 'tts_eval_models/speaker_similarity/wavlm_large_finetune.pth'"
+ ", 'tts_eval_models/speaker_similarity/wavlm_large/' and "
+ "tts_eval_models/speaker_similarity/pyannote/ in this script",
+ )
+
+ parser.add_argument(
+ "--extension",
+ type=str,
+ default="wav",
+ help="Extension of the speech files. Default: wav",
+ )
+ return parser
+
+
+class CpSpeakerSimilarity:
+ """
+ Computes concatenated maximum permutation speaker similarity (cpSIM) scores using:
+ - A WavLM-based ECAPA-TDNN model for speaker embedding extraction.
+ - A pyannote pipeline for speaker diarization (segmenting speakers).
+ """
+
+ def __init__(
+ self,
+ sv_model_path: str = "speaker_similarity/wavlm_large_finetune.pth",
+ ssl_model_path: str = "speaker_similarity/wavlm_large/",
+ pyannote_model_path: str = "speaker_similarity/pyannote/",
+ ):
+ """
+ Initializes the cpSIM evaluator with the specified models.
+
+ Args:
+ sv_model_path (str): Path of the wavlm-based ECAPA-TDNN model checkpoint.
+ ssl_model_path (str): Path of the wavlm SSL model directory.
+ pyannote_model_path (str): Path of the pyannote diarization model directory.
+ """
+ self.sample_rate = 16000
+ self.device = (
+ torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ )
+ logging.info(f"Using device: {self.device}")
+
+ # Initialize speaker verification model
+ self.sv_model = ECAPA_TDNN_WAVLM(
+ feat_dim=1024,
+ channels=512,
+ emb_dim=256,
+ sr=self.sample_rate,
+ ssl_model_path=ssl_model_path,
+ )
+ state_dict = torch.load(
+ sv_model_path, map_location=lambda storage, loc: storage
+ )
+ self.sv_model.load_state_dict(state_dict["model"], strict=False)
+ self.sv_model.to(self.device)
+ self.sv_model.eval()
+
+ # Initialize diarization pipeline
+ self.diarization_pipeline = Pipeline.from_pretrained(
+ os.path.join(pyannote_model_path, "pyannote_diarization_config.yaml")
+ )
+ self.diarization_pipeline.to(self.device)
+
+ @torch.no_grad()
+ def get_embeddings_with_diarization(
+ self, audio_paths: List[str]
+ ) -> List[List[torch.Tensor]]:
+ """
+ Extracts speaker embeddings from audio files
+ with speaker diarization (for 2-speaker conversations).
+
+ Args:
+ audio_paths: List of paths to audio files (each containing 2 speakers).
+
+ Returns:
+ List of embedding pairs, where each pair is
+ [embedding_speaker1, embedding_speaker2].
+ """
+
+ embeddings_list = []
+ for audio_path in tqdm(
+ audio_paths, desc="Extracting embeddings with diarization"
+ ):
+ # Load audio waveform
+ speech = load_waveform(
+ audio_path, self.sample_rate, device=self.device, max_seconds=120
+ )
+
+ # Perform speaker diarization (assumes 2 speakers)
+ diarization = self.diarization_pipeline(
+ {"waveform": speech.unsqueeze(0), "sample_rate": self.sample_rate},
+ num_speakers=2,
+ )
+
+ # Collect speech chunks for each speaker
+ speaker1_chunks = []
+ speaker2_chunks = []
+ for turn, _, speaker in diarization.itertracks(yield_label=True):
+ start_frame = int(turn.start * self.sample_rate)
+ end_frame = int(turn.end * self.sample_rate)
+ chunk = speech[start_frame:end_frame]
+
+ if speaker == "SPEAKER_00":
+ speaker1_chunks.append(chunk)
+ elif speaker == "SPEAKER_01":
+ speaker2_chunks.append(chunk)
+
+ # Handle cases where diarization fails to detect 2 speakers
+ if not (speaker1_chunks and speaker2_chunks):
+ logging.debug(
+ f"Insufficient speaker chunks in {audio_path} "
+ f"using full audio for both speakers"
+ )
+ speaker1_speech = speech
+ speaker2_speech = speech
+ else:
+ speaker1_speech = torch.cat(speaker1_chunks, dim=0)
+ speaker2_speech = torch.cat(speaker2_chunks, dim=0)
+
+ # Extract embeddings with no gradient computation
+ try:
+ emb_speaker1 = self.sv_model([speaker1_speech])
+ emb_speaker2 = self.sv_model([speaker2_speech])
+ except Exception as e:
+ logging.debug(
+ f"Encountered an error {e} when extracting embeddings with "
+ f"segmented speech, will use full audio for both speakers."
+ )
+ emb_speaker1 = self.sv_model([speech])
+ emb_speaker2 = self.sv_model([speech])
+
+ embeddings_list.append([emb_speaker1, emb_speaker2])
+
+ return embeddings_list
+
+ @torch.no_grad()
+ def get_embeddings_from_pairs(
+ self, audio_pairs: List[Tuple[str, str]]
+ ) -> List[List[torch.Tensor]]:
+ """
+ Extracts speaker embeddings from pairs of single-speaker audio files.
+
+ Args:
+ audio_pairs: List of tuples (path_speaker1, path_speaker2).
+
+ Returns:
+ List of embedding pairs, where each pair is
+ [embedding_speaker1, embedding_speaker2].
+ """
+ embeddings_list = []
+ for (path1, path2) in tqdm(
+ audio_pairs, desc="Extracting embeddings from pairs"
+ ):
+ # Load audio for each speaker
+ speech1 = load_waveform(path1, self.sample_rate, device=self.device)
+ speech2 = load_waveform(path2, self.sample_rate, device=self.device)
+
+ # Extract embeddings
+ emb_speaker1 = self.sv_model([speech1])
+ emb_speaker2 = self.sv_model([speech2])
+
+ embeddings_list.append([emb_speaker1, emb_speaker2])
+
+ return embeddings_list
+
+ def score(
+ self,
+ wav_path: str,
+ extension: str,
+ test_list: str,
+ prompt_mode: str,
+ ) -> float:
+ """
+ Computes the cpSIM score by comparing embeddings of prompt and evaluated speech.
+
+ Args:
+ wav_path: Directory containing evaluated speech files.
+ test_list: Path to test list file mapping evaluated files to prompts.
+ prompt_mode: Either "merge" (2-speaker prompt) or "split"
+ (two single-speaker prompts).
+
+ Returns:
+ Average cpSIM score across all test pairs.
+ """
+ logging.info(f"Calculating cpSIM score for {wav_path} (mode: {prompt_mode})")
+
+ # Load and parse test list
+ try:
+ with open(test_list, "r", encoding="utf-8") as f:
+ lines = [line.strip() for line in f if line.strip()]
+ except Exception as e:
+ logging.error(f"Failed to read test list {test_list}: {e}")
+ raise
+
+ if not lines:
+ raise ValueError(f"Test list {test_list} is empty")
+
+ # Collect valid prompt-eval audio pairs
+ prompt_audios = [] # For "merge": [path]; for "split": [(path1, path2)]
+ eval_audios = []
+
+ for line_num, line in enumerate(lines, 1):
+ parts = line.split("\t")
+ if prompt_mode == "merge":
+ if len(parts) != 4:
+ raise ValueError(f"Expected 4 columns, got {len(parts)}")
+ audio_name, prompt_text, prompt_audio, text = parts
+ eval_audio_path = os.path.join(wav_path, f"{audio_name}.{extension}")
+ prompt_audios.append(prompt_audio)
+
+ elif prompt_mode == "split":
+ if len(parts) != 6:
+ raise ValueError(f"Expected 6 columns, got {len(parts)}")
+ (
+ audio_name,
+ prompt_text1,
+ prompt_text2,
+ prompt_audio_1,
+ prompt_audio_2,
+ text,
+ ) = parts
+ eval_audio_path = os.path.join(wav_path, f"{audio_name}.{extension}")
+ prompt_audios.append((prompt_audio_1, prompt_audio_2))
+
+ else:
+ raise ValueError(f"Invalid prompt_mode: {prompt_mode}")
+
+ # Validate file existence
+ if not os.path.exists(eval_audio_path):
+ raise FileNotFoundError(f"Evaluated file not found: {eval_audio_path}")
+
+ if prompt_mode == "merge":
+ if not os.path.exists(prompt_audio):
+ raise FileNotFoundError(
+ f"Prompt merge file not found: {prompt_audio}"
+ )
+ else:
+ if not (
+ os.path.exists(prompt_audio_1) and os.path.exists(prompt_audio_2)
+ ):
+ raise FileNotFoundError(
+ f"One or more prompt files missing in {prompt_audio_1}, "
+ f"{prompt_audio_2}"
+ )
+
+ eval_audios.append(eval_audio_path)
+
+ if not prompt_audios or not eval_audios:
+ raise ValueError(f"No valid prompt-eval pairs found in {test_list}")
+
+ logging.info(f"Processing {len(prompt_audios)} valid test pairs")
+
+ # Extract embeddings for prompts and evaluations
+ if prompt_mode == "merge":
+ prompt_embeddings = self.get_embeddings_with_diarization(prompt_audios)
+ else:
+ prompt_embeddings = self.get_embeddings_from_pairs(prompt_audios)
+
+ eval_embeddings = self.get_embeddings_with_diarization(eval_audios)
+
+ if len(prompt_embeddings) != len(eval_embeddings):
+ raise RuntimeError(
+ f"Mismatch: {len(prompt_embeddings)} prompt vs "
+ f" {len(eval_embeddings)} eval embeddings"
+ )
+
+ # Calculate maximum permutation similarity scores
+ scores = []
+ for prompt_embs, eval_embs in zip(prompt_embeddings, eval_embeddings):
+ # Prompt and eval each have 2 embeddings: [emb1, emb2]
+ sim1 = F.cosine_similarity(
+ prompt_embs[0], eval_embs[0], dim=-1
+ ) + F.cosine_similarity(prompt_embs[1], eval_embs[1], dim=-1)
+ sim2 = F.cosine_similarity(
+ prompt_embs[0], eval_embs[1], dim=-1
+ ) + F.cosine_similarity(prompt_embs[1], eval_embs[0], dim=-1)
+ max_sim = torch.max(sim1, sim2).item() / 2 # Average the sum
+ scores.append(max_sim)
+
+ return float(np.mean(scores))
+
+
+if __name__ == "__main__":
+
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
+
+ parser = get_parser()
+ args = parser.parse_args()
+
+ # Validate test list arguments
+ if not (args.test_list or args.test_list_merge):
+ raise ValueError("Either --test-list or --test-list-merge must be provided")
+ if args.test_list and args.test_list_merge:
+ raise ValueError(
+ "Only one of --test-list-split or --test-list-merge can be provided"
+ )
+ # Determine mode and test list
+ if args.test_list:
+ prompt_mode = "split"
+ test_list = args.test_list
+ else:
+ prompt_mode = "merge"
+ test_list = args.test_list_merge
+
+ # Initialize evaluator
+ sv_model_path = os.path.join(
+ args.model_dir, "speaker_similarity/wavlm_large_finetune.pth"
+ )
+ ssl_model_path = os.path.join(args.model_dir, "speaker_similarity/wavlm_large/")
+ pyannote_model_path = os.path.join(args.model_dir, "speaker_similarity/pyannote/")
+ if (
+ not os.path.exists(sv_model_path)
+ or not os.path.exists(ssl_model_path)
+ or not os.path.exists(pyannote_model_path)
+ ):
+ logging.error(
+ "Please download evaluation models from "
+ "https://huggingface.co/k2-fsa/TTS_eval_models"
+ " and pass this dir with --model-dir"
+ )
+ exit(1)
+ cp_sim = CpSpeakerSimilarity(
+ sv_model_path=sv_model_path,
+ ssl_model_path=ssl_model_path,
+ pyannote_model_path=pyannote_model_path,
+ )
+ # Compute similarity score
+ score = cp_sim.score(
+ wav_path=args.wav_path,
+ extension=args.extension,
+ test_list=test_list,
+ prompt_mode=prompt_mode,
+ )
+ print("-" * 50)
+ logging.info(f"cpSIM score: {score:.3f}")
+ print("-" * 50)
diff --git a/zipvoice/eval/speaker_similarity/sim.py b/zipvoice/eval/speaker_similarity/sim.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fa462069c8ce6cd02c599fdd01487c587bb6ee3
--- /dev/null
+++ b/zipvoice/eval/speaker_similarity/sim.py
@@ -0,0 +1,229 @@
+#!/usr/bin/env python3
+# Copyright 2025 Xiaomi Corp. (authors: Han Zhu
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Computes speaker similarity (SIM-o) using a WavLM-based
+ ECAPA-TDNN speaker verification model.
+"""
+import argparse
+import logging
+import os
+import warnings
+from typing import List
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from zipvoice.eval.models.ecapa_tdnn_wavlm import ECAPA_TDNN_WAVLM
+from zipvoice.eval.utils import load_waveform
+
+warnings.filterwarnings("ignore")
+
+
+def get_parser() -> argparse.ArgumentParser:
+ parser = argparse.ArgumentParser(
+ description="Calculate speaker similarity (SIM-o) score."
+ )
+
+ parser.add_argument(
+ "--wav-path",
+ type=str,
+ required=True,
+ help="Path to the directory containing evaluated speech files.",
+ )
+ parser.add_argument(
+ "--test-list",
+ type=str,
+ required=True,
+ help="Path to the file list that contains the correspondence between prompts "
+ "and evaluated speech. Each line contains (audio_name, prompt_text_1, "
+ "prompt_text_2, prompt_audio_1, prompt_audio_2, text) separated by tabs.",
+ )
+ parser.add_argument(
+ "--model-dir",
+ type=str,
+ required=True,
+ help="Local path of our evaluatioin model repository."
+ "Download from https://huggingface.co/k2-fsa/TTS_eval_models."
+ "Will use 'tts_eval_models/speaker_similarity/wavlm_large_finetune.pth'"
+ "and 'tts_eval_models/speaker_similarity/wavlm_large/' in this script",
+ )
+
+ parser.add_argument(
+ "--extension",
+ type=str,
+ default="wav",
+ help="Extension of the speech files. Default: wav",
+ )
+ return parser
+
+
+class SpeakerSimilarity:
+ """
+ Computes speaker similarity (SIM-o) using a WavLM-based
+ ECAPA-TDNN speaker verification model.
+ """
+
+ def __init__(
+ self,
+ sv_model_path: str = "speaker_similarity/wavlm_large_finetune.pth",
+ ssl_model_path: str = "speaker_similarity/wavlm_large/",
+ ):
+ """
+ Initializes the speaker similarity evaluator with the specified models.
+
+ Args:
+ sv_model_path (str): Path of the wavlm-based ECAPA-TDNN model checkpoint.
+ ssl_model_path (str): Path of the wavlm SSL model directory.
+ """
+ self.sample_rate = 16000
+ self.device = (
+ torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ )
+ logging.info(f"Using device: {self.device}")
+ self.model = ECAPA_TDNN_WAVLM(
+ feat_dim=1024,
+ channels=512,
+ emb_dim=256,
+ sr=self.sample_rate,
+ ssl_model_path=ssl_model_path,
+ )
+ state_dict = torch.load(
+ sv_model_path, map_location=lambda storage, loc: storage
+ )
+ self.model.load_state_dict(state_dict["model"], strict=False)
+ self.model.to(self.device)
+ self.model.eval()
+
+ @torch.no_grad()
+ def get_embeddings(self, wav_paths: List[str]) -> List[torch.Tensor]:
+ """
+ Extracts speaker embeddings from a list of audio files.
+
+ Args:
+ wav_paths (List[str]): List of paths to audio files.
+
+ Returns:
+ List[torch.Tensor]: List of speaker embeddings.
+ """
+ embeddings = []
+ for wav_path in tqdm(wav_paths, desc="Extracting speaker embeddings"):
+ # Load and preprocess waveform
+ speech = load_waveform(
+ wav_path, self.sample_rate, device=self.device, max_seconds=120
+ )
+ # Extract embedding
+ embedding = self.model([speech])
+ embeddings.append(embedding)
+
+ return embeddings
+
+ def score(self, wav_path: str, extension: str, test_list: str) -> float:
+ """
+ Computes the Speaker Similarity (SIM-o) score between reference and
+ evaluated speech.
+
+ Args:
+ wav_path (str): Path to the directory containing evaluated speech files.
+ test_list (str): Path to the test list file mapping evaluated files
+ to reference prompts.
+
+ Returns:
+ float: Average similarity score between reference and evaluated embeddings.
+ """
+ logging.info(f"Calculating Speaker Similarity (SIM-o) score for {wav_path}")
+ # Read test pairs
+ try:
+ with open(test_list, "r", encoding="utf-8") as f:
+ lines = [line.strip().split("\t") for line in f if line.strip()]
+ except Exception as e:
+ logging.error(f"Failed to read test list: {e}")
+ raise
+
+ if not lines:
+ raise ValueError(f"Test list {test_list} is empty or malformed")
+ # Parse test pairs
+ prompt_wavs = []
+ eval_wavs = []
+ for line in lines:
+ if len(line) != 4:
+ raise ValueError(f"Invalid line: {line}")
+ wav_name, prompt_text, prompt_wav, text = line
+ eval_wav_path = os.path.join(wav_path, f"{wav_name}.{extension}")
+ # Validate file existence
+ if not os.path.exists(prompt_wav):
+ raise FileNotFoundError(f"Prompt file not found: {prompt_wav}")
+ if not os.path.exists(eval_wav_path):
+ raise FileNotFoundError(f"Evaluated file not found: {eval_wav_path}")
+ prompt_wavs.append(prompt_wav)
+ eval_wavs.append(eval_wav_path)
+ logging.info(f"Found {len(prompt_wavs)} valid test pairs")
+ # Extract embeddings
+
+ prompt_embeddings = self.get_embeddings(prompt_wavs)
+ eval_embeddings = self.get_embeddings(eval_wavs)
+
+ if len(prompt_embeddings) != len(eval_embeddings):
+ raise RuntimeError(
+ f"Mismatch: {len(prompt_embeddings)} prompt vs "
+ f" {len(eval_embeddings)} eval embeddings"
+ )
+
+ # Calculate similarity scores
+ scores = []
+ for prompt_emb, eval_emb in zip(prompt_embeddings, eval_embeddings):
+ # Compute cosine similarity
+ similarity = torch.nn.functional.cosine_similarity(
+ prompt_emb, eval_emb, dim=-1
+ )
+ scores.append(similarity.item())
+
+ return float(np.mean(scores))
+
+
+if __name__ == "__main__":
+
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
+
+ parser = get_parser()
+ args = parser.parse_args()
+ # Initialize evaluator
+ sv_model_path = os.path.join(
+ args.model_dir, "speaker_similarity/wavlm_large_finetune.pth"
+ )
+ ssl_model_path = os.path.join(args.model_dir, "speaker_similarity/wavlm_large/")
+ if not os.path.exists(sv_model_path) or not os.path.exists(ssl_model_path):
+ logging.error(
+ "Please download evaluation models from "
+ "https://huggingface.co/k2-fsa/TTS_eval_models"
+ " and pass this dir with --model-dir"
+ )
+ exit(1)
+ sim_evaluator = SpeakerSimilarity(
+ sv_model_path=sv_model_path, ssl_model_path=ssl_model_path
+ )
+ # Compute similarity score
+ score = sim_evaluator.score(args.wav_path, args.extension, args.test_list)
+ print("-" * 50)
+ logging.info(f"SIM-o score: {score:.3f}")
+ print("-" * 50)
diff --git a/zipvoice/eval/utils.py b/zipvoice/eval/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b80abda2d601eaa8edf4f1c4edac3759ebc346cd
--- /dev/null
+++ b/zipvoice/eval/utils.py
@@ -0,0 +1,62 @@
+import logging
+
+import librosa
+import soundfile as sf
+import torch
+
+
+def load_waveform(
+ fname: str,
+ sample_rate: int,
+ dtype: str = "float32",
+ device: torch.device = torch.device("cpu"),
+ return_numpy: bool = False,
+ max_seconds: float = None,
+) -> torch.Tensor:
+ """
+ Load an audio file, preprocess it, and convert to a PyTorch tensor.
+
+ Args:
+ fname (str): Path to the audio file.
+ sample_rate (int): Target sample rate for resampling.
+ dtype (str, optional): Data type to load audio as (default: "float32").
+ device (torch.device, optional): Device to place the resulting tensor
+ on (default: CPU).
+ return_numpy (bool): If True, returns a NumPy array instead of a
+ PyTorch tensor.
+ max_seconds (int): Maximum length (seconds) of the audio tensor.
+ If the audio is longer than this, it will be truncated.
+
+ Returns:
+ torch.Tensor: Processed audio waveform as a PyTorch tensor,
+ with shape (num_samples,).
+
+ Notes:
+ - If the audio is stereo, it will be converted to mono by averaging channels.
+ - If the audio's sample rate differs from the target, it will be resampled.
+ """
+ # Load audio file with specified data type
+ wav_data, sr = sf.read(fname, dtype=dtype)
+
+ # Convert stereo to mono if necessary
+ if len(wav_data.shape) == 2:
+ wav_data = wav_data.mean(1)
+
+ # Resample to target sample rate if needed
+ if sr != sample_rate:
+ wav_data = librosa.resample(wav_data, orig_sr=sr, target_sr=sample_rate)
+
+ if max_seconds is not None:
+ # Trim to max length
+ max_length = sample_rate * max_seconds
+ if len(wav_data) > max_length:
+ wav_data = wav_data[:max_length]
+ logging.warning(
+ f"Wav file {fname} is longer than 2 minutes, "
+ f"truncated to 2 minutes to avoid OOM."
+ )
+ if return_numpy:
+ return wav_data
+ else:
+ wav_data = torch.from_numpy(wav_data)
+ return wav_data.to(device)
diff --git a/zipvoice/eval/wer/dialog.py b/zipvoice/eval/wer/dialog.py
new file mode 100644
index 0000000000000000000000000000000000000000..00a0eeb8d63e730f000d68023a0f21b419eb644a
--- /dev/null
+++ b/zipvoice/eval/wer/dialog.py
@@ -0,0 +1,493 @@
+#!/usr/bin/env python3
+# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Computes WER or cpWER for English dialogue speech with WhisperD
+or compute WER for Chinese with Paraformer.
+"""
+
+import argparse
+import logging
+import os
+import re
+import string
+from typing import List, Tuple
+
+import numpy as np
+import torch
+import zhconv
+from funasr import AutoModel
+from jiwer import compute_measures
+from tqdm import tqdm
+from transformers import (
+ WhisperForConditionalGeneration,
+ WhisperProcessor,
+ WhisperTokenizer,
+ pipeline,
+)
+from zhon.hanzi import punctuation
+
+from zipvoice.eval.utils import load_waveform
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="Computes WER or cpWER for English dialogue speech"
+ " with WhisperD or compute WER for Chinese with Paraformer.",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+
+ parser.add_argument(
+ "--wav-path",
+ type=str,
+ required=True,
+ help="Path to the directory containing speech files.",
+ )
+
+ parser.add_argument(
+ "--extension",
+ type=str,
+ default="wav",
+ help="Extension of the speech files. Default: wav",
+ )
+
+ parser.add_argument(
+ "--decode-path",
+ type=str,
+ default=None,
+ help="Path to the output file where WER information will be saved. "
+ "If not provided, results are only printed to console.",
+ )
+ parser.add_argument(
+ "--model-dir",
+ type=str,
+ required=True,
+ help="Local path of evaluation models repository. "
+ "Download from https://huggingface.co/k2-fsa/TTS_eval_models. "
+ "This script expects 'tts_eval_models/wer/whisper-d-v1a/' for English "
+ "and 'tts_eval_models/wer/paraformer-zh/' for Chinese within this directory.",
+ )
+ parser.add_argument(
+ "--test-list",
+ type=str,
+ default="test.tsv",
+ help="Path to the tsv file for speaker splitted prompts. "
+ "Each line contains (audio_name, prompt_text_1, prompt_text_2, "
+ "prompt_audio_1, prompt_audio_2, text) separated by tabs.",
+ )
+ parser.add_argument(
+ "--lang",
+ type=str,
+ choices=["zh", "en"],
+ required=True,
+ help="Language of the audio and transcripts for "
+ "decoding ('zh' for Chinese or 'en' for English).",
+ )
+ parser.add_argument(
+ "--cpwer",
+ action="store_true",
+ help="whether to compute the cpWER",
+ )
+ return parser
+
+
+def load_en_model(model_dir, device):
+ model_path = os.path.join(model_dir, "wer/whisper-d-v1a/")
+ if not os.path.exists(model_path):
+ logging.error(
+ f"Error: Whisper model not found at {model_path}. "
+ "Please download evaluation modelss from "
+ "https://huggingface.co/k2-fsa/TTS_eval_models "
+ "and pass this directory with --model-dir."
+ )
+ exit(1)
+ logging.info(f"Loading Whisper model from: {model_path}")
+ processor = WhisperProcessor.from_pretrained(model_path)
+ tokenizer = WhisperTokenizer.from_pretrained(model_path)
+ model = WhisperForConditionalGeneration.from_pretrained(
+ model_path, torch_dtype=torch.float16
+ )
+
+ model.generation_config.suppress_tokens = None
+ model.generation_config.forced_decoder_ids = None
+ # Using pipline to handle long audios
+ pipe = pipeline(
+ "automatic-speech-recognition",
+ model=model,
+ tokenizer=tokenizer,
+ feature_extractor=processor.feature_extractor,
+ chunk_length_s=30,
+ device=device,
+ )
+ return pipe
+
+
+def load_zh_model(model_dir):
+ model_path = os.path.join(model_dir, "wer/paraformer-zh/")
+ if not os.path.exists(model_path):
+ logging.error(
+ f"Error: Paraformer model not found at {model_path}. "
+ "Please download evaluation modelss from "
+ "https://huggingface.co/k2-fsa/TTS_eval_models "
+ "and pass this directory with --model-dir."
+ )
+ exit(1)
+ logging.info(f"Loading Paraformer model from: {model_path}")
+ model = AutoModel(model=model_path, disable_update=True)
+ return model
+
+
+def post_process(text: str, lang: str) -> str:
+ """
+ Cleans and normalizes text for WER calculation.
+ Args:
+ text (str): The input text to be processed.
+ lang (str): The language of the input text.
+
+ Returns:
+ str: The cleaned and normalized text.
+ """
+ punctuation_all = punctuation + string.punctuation
+ text = re.sub(r"\[.*?\]|<.*?>|\(.*?\)", "", text)
+ for x in punctuation_all:
+ if x == "'":
+ continue
+ text = text.replace(x, "")
+ text = re.sub(r"\s+", " ", text).strip()
+ if lang == "zh":
+ text = " ".join([x for x in text])
+ elif lang == "en":
+ text = text.lower()
+ else:
+ raise NotImplementedError
+ return text
+
+
+def process_one(hypothesis: str, truth: str, lang: str) -> tuple:
+ """
+ Computes WER and related metrics for a single hypothesis-truth pair.
+
+ Args:
+ hypothesis (str): The transcribed text from the ASR model.
+ truth (str): The ground truth transcript.
+
+ Returns:
+ tuple: A tuple containing:
+ - truth (str): Post-processed ground truth text.
+ - hypothesis (str): Post-processed hypothesis text.
+ - wer (float): Word Error Rate.
+ - substitutions (int): Number of substitutions.
+ - deletions (int): Number of deletions.
+ - insertions (int): Number of insertions.
+ - word_num (int): Number of words in the post-processed ground truth.
+ """
+ truth_processed = post_process(truth, lang)
+ hypothesis_processed = post_process(hypothesis, lang)
+
+ measures = compute_measures(truth_processed, hypothesis_processed)
+ word_num = len(truth_processed.split(" "))
+
+ return (
+ truth_processed,
+ hypothesis_processed,
+ measures["wer"],
+ measures["substitutions"],
+ measures["deletions"],
+ measures["insertions"],
+ word_num,
+ )
+
+
+def process_one_cpwer(hypothesis: str, truth: str, lang: str) -> tuple:
+ """
+ Computes cpWER and related metrics for a single hypothesis-truth pair.
+
+ Args:
+ hypothesis (str): The transcribed text from the ASR model.
+ truth (str): The ground truth transcript.
+
+ Returns:
+ tuple: A tuple containing:
+ - truth (str): Post-processed ground truth text.
+ - hypothesis (str): Post-processed hypothesis text.
+ - wer (float): Word Error Rate.
+ - substitutions (int): Number of substitutions.
+ - deletions (int): Number of deletions.
+ - insertions (int): Number of insertions.
+ - word_num (int): Number of words in the post-processed ground truth.
+ """
+ assert lang == "en"
+ truths = split_dialogue(truth)
+ hypotheses = split_dialogue(hypothesis)
+ for i in range(2):
+ truths[i] = post_process(truths[i], lang)
+ hypotheses[i] = post_process(hypotheses[i], lang)
+
+ measures_1 = compute_measures(
+ f"{truths[0]} {truths[1]}", f"{hypotheses[0]} {hypotheses[1]}"
+ )
+ measures_2 = compute_measures(
+ f"{truths[0]} {truths[1]}", f"{hypotheses[1]} {hypotheses[0]}"
+ )
+ truth = f"[S1] {truths[0]} [S2] {truths[1]}"
+ if measures_1["wer"] < measures_2["wer"]:
+ measures = measures_1
+ hypothesis = f"[S1] {hypotheses[0]} [S2] {hypotheses[1]}"
+ else:
+ measures = measures_2
+ hypothesis = f"[S1] {hypotheses[1]} [S2] {hypotheses[0]}"
+ truth = re.sub(r"\s+", " ", truth)
+ hypothesis = re.sub(r"\s+", " ", hypothesis)
+ word_num = len(truth.split(" ")) - 2
+ return (
+ truth,
+ hypothesis,
+ measures["wer"],
+ measures["substitutions"],
+ measures["deletions"],
+ measures["insertions"],
+ word_num,
+ )
+
+
+def split_dialogue(text):
+ segments = re.split(r"\[S[1-9]\]", text)
+ segments = [segment.strip() for segment in segments]
+ spk1_texts = " ".join(segments[::2])
+ spk2_texts = " ".join(segments[1::2])
+ return [spk1_texts, spk2_texts]
+
+
+class SpeechEvalDataset(torch.utils.data.Dataset):
+ """
+ A PyTorch Dataset for loading speech waveforms and their transcripts
+ for evaluation. Will only keep shorter-than-30s waveforms if in `cpwer` mode.
+ """
+
+ def __init__(
+ self, wav_transcript_path_pair: List[Tuple[str, str]], cpwer: bool = False
+ ):
+ super().__init__()
+ if cpwer:
+ self.wav_transcript_path_pair = []
+ for wav_path, transcript in wav_transcript_path_pair:
+ waveform = load_waveform(
+ wav_path,
+ sample_rate=16000,
+ )
+ if len(waveform) / 16000 <= 30:
+ self.wav_transcript_path_pair.append((wav_path, transcript))
+ else:
+ self.wav_transcript_path_pair = wav_transcript_path_pair
+
+ def __len__(self):
+ return len(self.wav_transcript_path_pair)
+
+ def __getitem__(self, index: int):
+ waveform = load_waveform(
+ self.wav_transcript_path_pair[index][0],
+ sample_rate=16000,
+ return_numpy=True,
+ )
+ item = {
+ "array": waveform,
+ "sampling_rate": 16000,
+ "reference": self.wav_transcript_path_pair[index][1],
+ "wav_path": self.wav_transcript_path_pair[index][0],
+ }
+ return item
+
+
+def main(test_list, wav_dir, extension, model_dir, decode_path, lang, cpwer, device):
+ logging.info(f"Calculating WER for {wav_dir} (cpwer={cpwer})")
+ if lang == "en":
+ model = load_en_model(model_dir, device=device)
+ elif lang == "zh":
+ model = load_zh_model(model_dir)
+ params = []
+ for line in open(test_list).readlines():
+ line = line.strip()
+ assert len(line.split("\t")) == 6
+ items = line.split("\t")
+ wav_name, text_ref = items[0], items[-1]
+ file_path = os.path.join(wav_dir, wav_name + "." + extension)
+ assert os.path.exists(file_path), f"{file_path}"
+ params.append((file_path, text_ref))
+
+ if decode_path:
+ # Ensure the output directory exists
+ decode_dir = os.path.dirname(decode_path)
+ if decode_dir and not os.path.exists(decode_dir):
+ os.makedirs(decode_dir)
+ fout = open(decode_path, "w", encoding="utf8")
+ logging.info(f"Saving detailed WER results to: {decode_path}")
+ fout.write(
+ "Name\tWER\tTruth\tHypothesis\tInsertions\tDeletions\tSubstitutions\n"
+ )
+
+ # Initialize metrics for overall WER calculation
+ wers = []
+ inses = []
+ deles = []
+ subses = []
+ word_nums = 0
+ if cpwer:
+ cp_wers = []
+ cp_inses = []
+ cp_deles = []
+ cp_subses = []
+ cp_word_nums = 0
+ if decode_path:
+ fout = open(decode_path, "w")
+ if lang == "zh":
+ for wav_path, text_ref in tqdm(params):
+ res = model.generate(input=wav_path, batch_size_s=300, disable_pbar=True)
+ transcription = res[0]["text"]
+ transcription = zhconv.convert(transcription, "zh-cn")
+
+ truth, hypo, wer, subs, dele, inse, word_num = process_one(
+ transcription, text_ref, lang
+ )
+ if decode_path:
+ fout.write(
+ f"{wav_path}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n"
+ )
+ wers.append(float(wer))
+ inses.append(float(inse))
+ deles.append(float(dele))
+ subses.append(float(subs))
+ word_nums += word_num
+ elif lang == "en":
+ dataset = SpeechEvalDataset(params, cpwer)
+ bar = tqdm(
+ model(
+ dataset,
+ generate_kwargs={"language": lang, "task": "transcribe"},
+ batch_size=16,
+ ),
+ total=len(dataset),
+ )
+ for out in bar:
+ transcription = out["text"]
+ text_ref = out["reference"][0]
+ wav_path = out["wav_path"][0]
+ if cpwer:
+ (
+ cp_truth,
+ cp_hypo,
+ cp_wer,
+ cp_subs,
+ cp_dele,
+ cp_inse,
+ cp_word_num,
+ ) = process_one_cpwer(transcription, text_ref, lang)
+ if decode_path:
+ fout.write(
+ f"{wav_path}\t{cp_wer}\t{cp_truth}\t"
+ f"{cp_hypo}\t{cp_inse}\t{cp_dele}\t{cp_subs}\n"
+ )
+ cp_wers.append(float(cp_wer))
+ cp_inses.append(float(cp_inse))
+ cp_deles.append(float(cp_dele))
+ cp_subses.append(float(cp_subs))
+ cp_word_nums += cp_word_num
+ truth, hypo, wer, subs, dele, inse, word_num = process_one(
+ transcription, text_ref, lang
+ )
+ if decode_path:
+ fout.write(
+ f"{wav_path}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n"
+ )
+ wers.append(float(wer))
+ inses.append(float(inse))
+ deles.append(float(dele))
+ subses.append(float(subs))
+ word_nums += word_num
+ if cpwer:
+ assert (
+ word_num == cp_word_num
+ ), f"{wav_path} has {word_num} words, but {cp_word_num} cp words"
+
+ print("-" * 50)
+ if cpwer:
+ cp_wer = round(
+ (np.sum(cp_subses) + np.sum(cp_deles) + np.sum(cp_inses))
+ / cp_word_nums
+ * 100,
+ 2,
+ )
+ cp_inse = np.sum(cp_inses)
+ cp_dele = np.sum(cp_deles)
+ cp_subs = np.sum(cp_subses)
+ logging.info(f"cpWER = {cp_wer}%")
+ logging.info(
+ f"Errors: {cp_inse} insertions, {cp_dele} deletions, {cp_subs} "
+ f"substitutions, over {cp_word_nums} reference words"
+ )
+ if decode_path:
+ fout.write(f"cpWER = {cp_wer}%\n")
+ fout.write(
+ f"Errors: {cp_inse} insertions, {cp_dele} deletions, {cp_subs} "
+ f"substitutions, over {cp_word_nums} reference words\n"
+ )
+ wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 2)
+ inse = np.sum(inses)
+ dele = np.sum(deles)
+ subs = np.sum(subses)
+
+ logging.info(f"WER = {wer}%")
+ logging.info(
+ f"Errors: {inse} insertions, {dele} deletions, {subs} substitutions, "
+ f"over {word_nums} reference words"
+ )
+ print("-" * 50)
+
+ if decode_path:
+ fout.write(f"WER = {wer}%\n")
+ fout.write(
+ f"Errors: {inse} insertions, {dele} deletions, {subs} substitutions, "
+ f"over {word_nums} reference words\n"
+ )
+ fout.flush()
+
+
+if __name__ == "__main__":
+
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
+
+ parser = get_parser()
+ args = parser.parse_args()
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+ else:
+ device = torch.device("cpu")
+ if args.cpwer:
+ assert args.lang == "en", "Only English is supported for cpWER"
+ main(
+ args.test_list,
+ args.wav_path,
+ args.extension,
+ args.model_dir,
+ args.decode_path,
+ args.lang,
+ args.cpwer,
+ device,
+ )
diff --git a/zipvoice/eval/wer/hubert.py b/zipvoice/eval/wer/hubert.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0681bea6bc58fe817a37665bae0941bee60c473
--- /dev/null
+++ b/zipvoice/eval/wer/hubert.py
@@ -0,0 +1,285 @@
+#!/usr/bin/env python3
+# Copyright 2025 Xiaomi Corp. (authors: Han Zhu,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Computes word error rate (WER) with Hubert models for LibriSpeech test sets.
+"""
+import argparse
+import logging
+import os
+import re
+from pathlib import Path
+
+import numpy as np
+import torch
+from jiwer import compute_measures
+from tqdm import tqdm
+from transformers import pipeline
+
+from zipvoice.eval.utils import load_waveform
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="Computes WER with Hubert models.",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+
+ parser.add_argument(
+ "--wav-path",
+ type=str,
+ required=True,
+ help="Path to the directory containing speech files.",
+ )
+
+ parser.add_argument(
+ "--extension",
+ type=str,
+ default="wav",
+ help="Extension of the speech files. Default: wav",
+ )
+
+ parser.add_argument(
+ "--decode-path",
+ type=str,
+ default=None,
+ help="Path to the output file where WER information will be saved. "
+ "If not provided, results are only printed to console.",
+ )
+ parser.add_argument(
+ "--model-dir",
+ type=str,
+ required=True,
+ help="Local path of our evaluatioin model repository."
+ "Download from https://huggingface.co/k2-fsa/TTS_eval_models."
+ "Will use 'tts_eval_models/wer/hubert-large-ls960-ft/'"
+ " in this script",
+ )
+ parser.add_argument(
+ "--test-list",
+ type=str,
+ default="transcript.tsv",
+ help="path of the tsv file. Each line is in the format:"
+ "(audio_name, text) separated by tabs.",
+ )
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=16,
+ help="Batch size for decoding with the Hugging Face pipeline.",
+ )
+ return parser
+
+
+def post_process(text: str) -> str:
+ """
+ Cleans and normalizes text for WER calculation.
+ Args:
+ text (str): The input text to be processed.
+
+ Returns:
+ str: The cleaned and normalized text.
+ """
+ text = text.replace("‘", "'").replace("’", "'")
+ text = re.sub(r"[^a-zA-Z0-9']", " ", text.lower())
+ text = re.sub(r"\s+", " ", text).strip()
+ return text
+
+
+def process_one(hypothesis: str, truth: str) -> tuple:
+ """
+ Computes WER and related metrics for a single hypothesis-truth pair.
+
+ Args:
+ hypothesis (str): The transcribed text from the ASR model.
+ truth (str): The ground truth transcript.
+
+ Returns:
+ tuple: A tuple containing:
+ - truth (str): Post-processed ground truth text.
+ - hypothesis (str): Post-processed hypothesis text.
+ - wer (float): Word Error Rate.
+ - substitutions (int): Number of substitutions.
+ - deletions (int): Number of deletions.
+ - insertions (int): Number of insertions.
+ - word_num (int): Number of words in the post-processed ground truth.
+ """
+ truth_processed = post_process(truth)
+ hypothesis_processed = post_process(hypothesis)
+
+ measures = compute_measures(truth_processed, hypothesis_processed)
+ word_num = len(truth_processed.split(" "))
+
+ return (
+ truth_processed,
+ hypothesis_processed,
+ measures["wer"],
+ measures["substitutions"],
+ measures["deletions"],
+ measures["insertions"],
+ word_num,
+ )
+
+
+class SpeechEvalDataset(torch.utils.data.Dataset):
+ """
+ A PyTorch Dataset for loading speech waveforms and their transcripts
+ for evaluation.
+ """
+
+ def __init__(self, wav_path: str, test_list: str, extension: str = "wav"):
+ """
+ Initializes the dataset.
+
+ Args:
+ wav_path (str): Path to the directory containing speech files.
+ test_list (str): Path to the TSV file with speech file names and
+ transcripts.
+ """
+ super().__init__()
+ self.wav_names = []
+ self.wav_paths = []
+ self.transcripts = []
+ with Path(test_list).open("r", encoding="utf8") as f:
+ meta = [item.split("\t") for item in f.read().rstrip().split("\n")]
+ for item in meta:
+ self.wav_names.append(item[0])
+ self.wav_paths.append(Path(wav_path, item[0] + "." + extension))
+ self.transcripts.append(item[-1])
+
+ def __len__(self):
+ return len(self.wav_paths)
+
+ def __getitem__(self, index: int):
+ waveform = load_waveform(
+ self.wav_paths[index],
+ sample_rate=16000,
+ return_numpy=True,
+ )
+ item = {
+ "array": waveform,
+ "sampling_rate": 16000,
+ "reference": self.transcripts[index],
+ "wav_name": self.wav_names[index],
+ }
+ return item
+
+
+def main(test_list, wav_path, extension, model_dir, decode_path, batch_size, device):
+ logging.info(f"Calculating WER for {wav_path}")
+ model_path = os.path.join(model_dir, "wer/hubert-large-ls960-ft/")
+ if not os.path.exists(model_path):
+ logging.error(
+ "Please download evaluation models from "
+ "https://huggingface.co/k2-fsa/TTS_eval_models"
+ " and pass this dir with --model-dir"
+ )
+ exit(1)
+
+ asr_pipeline = pipeline(
+ "automatic-speech-recognition",
+ model=model_path,
+ device=device,
+ tokenizer=model_path,
+ )
+
+ dataset = SpeechEvalDataset(wav_path, test_list, extension)
+
+ transcription_results = tqdm(
+ asr_pipeline(
+ dataset,
+ generate_kwargs={"language": "english", "task": "transcribe"},
+ batch_size=batch_size,
+ ),
+ total=len(dataset),
+ )
+
+ # Initialize metrics for overall WER calculation
+ wers = []
+ inses = []
+ deles = []
+ subses = []
+ word_nums = 0
+ if decode_path:
+ # Ensure the output directory exists
+ decode_dir = os.path.dirname(decode_path)
+ if decode_dir and not os.path.exists(decode_dir):
+ os.makedirs(decode_dir)
+ fout = open(decode_path, "w", encoding="utf8")
+ logging.info(f"Saving detailed WER results to: {decode_path}")
+ fout.write(
+ "Name\tWER\tTruth\tHypothesis\tInsertions\tDeletions\tSubstitutions\n"
+ )
+ for out in transcription_results:
+ wav_name = out["wav_name"][0]
+ transcription = out["text"].strip()
+ text_ref = out["reference"][0].strip()
+ truth, hypo, wer, subs, dele, inse, word_num = process_one(
+ transcription, text_ref
+ )
+ if decode_path:
+ fout.write(f"{wav_name}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n")
+ wers.append(float(wer))
+ inses.append(float(inse))
+ deles.append(float(dele))
+ subses.append(float(subs))
+ word_nums += word_num
+
+ wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 2)
+ inse = np.sum(inses)
+ dele = np.sum(deles)
+ subs = np.sum(subses)
+ print("-" * 50)
+ logging.info(f"WER = {wer}%")
+ logging.info(
+ f"Errors: {inse} insertions, {dele} deletions, {subs} substitutions, "
+ f"over {word_nums} reference words"
+ )
+ print("-" * 50)
+ if decode_path:
+ fout.write(f"WER = {wer}%\n")
+ fout.write(
+ f"Errors: {inse} insertions, {dele} deletions, {subs} substitutions, "
+ f"over {word_nums} reference words\n"
+ )
+ fout.flush()
+
+
+if __name__ == "__main__":
+
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
+
+ parser = get_parser()
+ args = parser.parse_args()
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+ else:
+ device = torch.device("cpu")
+ main(
+ args.test_list,
+ args.wav_path,
+ args.extension,
+ args.model_dir,
+ args.decode_path,
+ args.batch_size,
+ device,
+ )
diff --git a/zipvoice/eval/wer/seedtts.py b/zipvoice/eval/wer/seedtts.py
new file mode 100644
index 0000000000000000000000000000000000000000..e32f9f999495ded349cbc305cd102abaa893afd5
--- /dev/null
+++ b/zipvoice/eval/wer/seedtts.py
@@ -0,0 +1,298 @@
+#!/usr/bin/env python3
+# Copyright 2025 Xiaomi Corp. (authors: Han Zhu,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Computes word error rate (WER) with Whisper-large-v3 for English and
+Paraformer for Chinese. Intended to evaluate WERs on Seed-TTS test sets.
+"""
+
+import argparse
+import logging
+import os
+import string
+
+import numpy as np
+import scipy
+import soundfile as sf
+import torch
+import zhconv
+from funasr import AutoModel
+from jiwer import compute_measures
+from tqdm import tqdm
+from transformers import WhisperForConditionalGeneration, WhisperProcessor
+from zhon.hanzi import punctuation
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="Computes WER with Whisper and Paraformer models, "
+ "following Seed-TTS.",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+
+ parser.add_argument(
+ "--wav-path",
+ type=str,
+ required=True,
+ help="Path to the directory containing speech files.",
+ )
+
+ parser.add_argument(
+ "--extension",
+ type=str,
+ default="wav",
+ help="Extension of the speech files. Default: wav",
+ )
+
+ parser.add_argument(
+ "--decode-path",
+ type=str,
+ default=None,
+ help="Path to the output file where WER information will be saved. "
+ "If not provided, results are only printed to console.",
+ )
+ parser.add_argument(
+ "--model-dir",
+ type=str,
+ required=True,
+ help="Local path of evaluation models repository. "
+ "Download from https://huggingface.co/k2-fsa/TTS_eval_models. "
+ "This script expects 'tts_eval_models/wer/whisper-large-v3/' for English "
+ "and 'tts_eval_models/wer/paraformer-zh/' for Chinese within this directory.",
+ )
+ parser.add_argument(
+ "--test-list",
+ type=str,
+ default="test.tsv",
+ help="path of the tsv file. Each line is in the format:"
+ "(audio_name, prompt_text,prompt_audio, text) separated by tabs.",
+ )
+ parser.add_argument(
+ "--lang",
+ type=str,
+ choices=["zh", "en"],
+ required=True,
+ help="Language of the audio and transcripts for "
+ "decoding ('zh' for Chinese or 'en' for English).",
+ )
+ return parser
+
+
+def load_en_model(model_dir):
+ model_path = os.path.join(model_dir, "wer/whisper-large-v3/")
+ if not os.path.exists(model_path):
+ logging.error(
+ f"Error: Whisper model not found at {model_path}. "
+ "Please download evaluation modelss from "
+ "https://huggingface.co/k2-fsa/TTS_eval_models "
+ "and pass this directory with --model-dir."
+ )
+ exit(1)
+ logging.info(f"Loading Whisper model from: {model_path}")
+ processor = WhisperProcessor.from_pretrained(model_path)
+ model = WhisperForConditionalGeneration.from_pretrained(model_path)
+ return processor, model
+
+
+def load_zh_model(model_dir):
+ model_path = os.path.join(model_dir, "wer/paraformer-zh/")
+ if not os.path.exists(model_path):
+ logging.error(
+ f"Error: Paraformer model not found at {model_path}. "
+ "Please download evaluation modelss from "
+ "https://huggingface.co/k2-fsa/TTS_eval_models "
+ "and pass this directory with --model-dir."
+ )
+ exit(1)
+ logging.info(f"Loading Paraformer model from: {model_path}")
+ model = AutoModel(model=model_path, disable_update=True)
+ return model
+
+
+def post_process(text: str, lang: str) -> str:
+ """
+ Cleans and normalizes text for WER calculation.
+ Args:
+ text (str): The input text to be processed.
+ lang (str): The language of the input text.
+
+ Returns:
+ str: The cleaned and normalized text.
+ """
+ punctuation_all = punctuation + string.punctuation
+ for x in punctuation_all:
+ if x == "'":
+ continue
+ text = text.replace(x, "")
+
+ text = text.replace(" ", " ")
+
+ if lang == "zh":
+ text = " ".join([x for x in text])
+ elif lang == "en":
+ text = text.lower()
+ else:
+ raise NotImplementedError
+ return text
+
+
+def process_one(hypothesis: str, truth: str, lang: str) -> tuple:
+ """
+ Computes WER and related metrics for a single hypothesis-truth pair.
+
+ Args:
+ hypothesis (str): The transcribed text from the ASR model.
+ truth (str): The ground truth transcript.
+
+ Returns:
+ tuple: A tuple containing:
+ - truth (str): Post-processed ground truth text.
+ - hypothesis (str): Post-processed hypothesis text.
+ - wer (float): Word Error Rate.
+ - substitutions (int): Number of substitutions.
+ - deletions (int): Number of deletions.
+ - insertions (int): Number of insertions.
+ - word_num (int): Number of words in the post-processed ground truth.
+ """
+ truth_processed = post_process(truth, lang)
+ hypothesis_processed = post_process(hypothesis, lang)
+
+ measures = compute_measures(truth_processed, hypothesis_processed)
+ word_num = len(truth_processed.split(" "))
+
+ return (
+ truth_processed,
+ hypothesis_processed,
+ measures["wer"],
+ measures["substitutions"],
+ measures["deletions"],
+ measures["insertions"],
+ word_num,
+ )
+
+
+def main(test_list, wav_path, extension, model_path, decode_path, lang, device):
+ logging.info(f"Calculating WER for {wav_path}")
+ if lang == "en":
+ processor, model = load_en_model(model_path)
+ model.to(device)
+ elif lang == "zh":
+ model = load_zh_model(model_path)
+ params = []
+ for line in open(test_list).readlines():
+ line = line.strip()
+ items = line.split("\t")
+ wav_name, text_ref = items[0], items[-1]
+ file_path = os.path.join(wav_path, wav_name + "." + extension)
+ assert os.path.exists(file_path), f"{file_path}"
+
+ params.append((file_path, text_ref))
+ # Initialize metrics for overall WER calculation
+ wers = []
+ inses = []
+ deles = []
+ subses = []
+ word_nums = 0
+ if decode_path:
+ # Ensure the output directory exists
+ decode_dir = os.path.dirname(decode_path)
+ if decode_dir and not os.path.exists(decode_dir):
+ os.makedirs(decode_dir)
+ fout = open(decode_path, "w")
+ for wav_path, text_ref in tqdm(params):
+ if lang == "en":
+ wav, sr = sf.read(wav_path)
+ if sr != 16000:
+ wav = scipy.signal.resample(wav, int(len(wav) * 16000 / sr))
+ input_features = processor(
+ wav, sampling_rate=16000, return_tensors="pt"
+ ).input_features
+ input_features = input_features.to(device)
+ forced_decoder_ids = processor.get_decoder_prompt_ids(
+ language="english", task="transcribe"
+ )
+ predicted_ids = model.generate(
+ input_features, forced_decoder_ids=forced_decoder_ids
+ )
+ transcription = processor.batch_decode(
+ predicted_ids, skip_special_tokens=True
+ )[0]
+ elif lang == "zh":
+ res = model.generate(input=wav_path, batch_size_s=300, disable_pbar=True)
+ transcription = res[0]["text"]
+ transcription = zhconv.convert(transcription, "zh-cn")
+
+ truth, hypo, wer, subs, dele, inse, word_num = process_one(
+ transcription, text_ref, lang
+ )
+ if decode_path:
+ fout.write(f"{wav_path}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n")
+ wers.append(float(wer))
+ inses.append(float(inse))
+ deles.append(float(dele))
+ subses.append(float(subs))
+ word_nums += word_num
+
+ wer_avg = round(np.mean(wers) * 100, 2)
+ wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 2)
+ inse = np.sum(inses)
+ dele = np.sum(deles)
+ subs = np.sum(subses)
+ print("-" * 50)
+ # The official evaluation codes of Seed-TTS uses the average of WERs
+ # instead of the weighted average of WERs.
+ logging.info(f"Seed-TTS WER: {wer_avg}%\n")
+ logging.info(f"WER: {wer}%\n")
+ logging.info(
+ f"Errors: {inse} insertions, {dele} deletions, {subs} substitutions, "
+ f"over {word_nums} reference words"
+ )
+ print("-" * 50)
+ if decode_path:
+ fout.write(f"SeedTTS WER: {wer_avg}%\n")
+ fout.write(f"WER: {wer}%\n")
+ fout.write(
+ f"Errors: {inse} insertions, {dele} deletions, {subs} substitutions, "
+ f"over {word_nums} reference words\n"
+ )
+ fout.flush()
+
+
+if __name__ == "__main__":
+
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
+
+ parser = get_parser()
+ args = parser.parse_args()
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+ else:
+ device = torch.device("cpu")
+ main(
+ args.test_list,
+ args.wav_path,
+ args.extension,
+ args.model_dir,
+ args.decode_path,
+ args.lang,
+ device,
+ )
diff --git a/zipvoice/models/modules/scaling.py b/zipvoice/models/modules/scaling.py
new file mode 100644
index 0000000000000000000000000000000000000000..5532b1a785f0216a268701b4784194f7a1adf86c
--- /dev/null
+++ b/zipvoice/models/modules/scaling.py
@@ -0,0 +1,1575 @@
+# Copyright 2022-2025 Xiaomi Corp. (authors: Daniel Povey
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import math
+import random
+import sys
+from typing import Optional, Tuple, Union
+
+try:
+ import k2
+except Exception as e:
+ logging.warning(
+ f"Failed import k2 with error {e}. Swoosh functions will fallback to PyTorch"
+ f" implementation, leading to slower speed and higher memory consumption."
+ )
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+custom_bwd = lambda func: torch.amp.custom_bwd(func, device_type="cuda")
+custom_fwd = lambda func: torch.amp.custom_fwd(func, device_type="cuda")
+
+
+def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
+ max_value = torch.max(x, y)
+ diff = torch.abs(x - y)
+ return max_value + torch.log1p(torch.exp(-diff))
+
+
+# RuntimeError: Exporting the operator logaddexp to ONNX opset version
+# 14 is not supported. Please feel free to request support or submit
+# a pull request on PyTorch GitHub.
+#
+# The following function is to solve the above error when exporting
+# models to ONNX via torch.jit.trace()
+def logaddexp(x: Tensor, y: Tensor) -> Tensor:
+ # Caution(fangjun): Put torch.jit.is_scripting() before
+ # torch.onnx.is_in_onnx_export();
+ # otherwise, it will cause errors for torch.jit.script().
+ #
+ # torch.logaddexp() works for both torch.jit.script() and
+ # torch.jit.trace() but it causes errors for ONNX export.
+ #
+ if torch.jit.is_scripting():
+ # Note: We cannot use torch.jit.is_tracing() here as it also
+ # matches torch.onnx.export().
+ return torch.logaddexp(x, y)
+ elif torch.onnx.is_in_onnx_export():
+ return logaddexp_onnx(x, y)
+ else:
+ # for torch.jit.trace()
+ return torch.logaddexp(x, y)
+
+
+class PiecewiseLinear(object):
+ """
+ Piecewise linear function, from float to float, specified as nonempty list of (x,y)
+ pairs with the x values in order. x values <[initial x] or >[final x] are map to
+ [initial y], [final y] respectively.
+ """
+
+ def __init__(self, *args):
+ assert len(args) >= 1, len(args)
+ if len(args) == 1 and isinstance(args[0], PiecewiseLinear):
+ self.pairs = list(args[0].pairs)
+ else:
+ self.pairs = [(float(x), float(y)) for x, y in args]
+ for x, y in self.pairs:
+ assert isinstance(x, (float, int)), type(x)
+ assert isinstance(y, (float, int)), type(y)
+
+ for i in range(len(self.pairs) - 1):
+ assert self.pairs[i + 1][0] > self.pairs[i][0], (
+ i,
+ self.pairs[i],
+ self.pairs[i + 1],
+ )
+
+ def __str__(self):
+ # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))'
+ return f"PiecewiseLinear({str(self.pairs)[1:-1]})"
+
+ def __call__(self, x):
+ if x <= self.pairs[0][0]:
+ return self.pairs[0][1]
+ elif x >= self.pairs[-1][0]:
+ return self.pairs[-1][1]
+ else:
+ cur_x, cur_y = self.pairs[0]
+ for i in range(1, len(self.pairs)):
+ next_x, next_y = self.pairs[i]
+ if x >= cur_x and x <= next_x:
+ return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x)
+ cur_x, cur_y = next_x, next_y
+ assert False
+
+ def __mul__(self, alpha):
+ return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs])
+
+ def __add__(self, x):
+ if isinstance(x, (float, int)):
+ return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs])
+ s, x = self.get_common_basis(x)
+ return PiecewiseLinear(
+ *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)]
+ )
+
+ def max(self, x):
+ if isinstance(x, (float, int)):
+ x = PiecewiseLinear((0, x))
+ s, x = self.get_common_basis(x, include_crossings=True)
+ return PiecewiseLinear(
+ *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]
+ )
+
+ def min(self, x):
+ if isinstance(x, float) or isinstance(x, int):
+ x = PiecewiseLinear((0, x))
+ s, x = self.get_common_basis(x, include_crossings=True)
+ return PiecewiseLinear(
+ *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]
+ )
+
+ def __eq__(self, other):
+ return self.pairs == other.pairs
+
+ def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False):
+ """
+ Returns (self_mod, p_mod) which are equivalent piecewise linear
+ functions to self and p, but with the same x values.
+
+ p: the other piecewise linear function
+ include_crossings: if true, include in the x values positions
+ where the functions indicate by this and p crosss.
+ """
+ assert isinstance(p, PiecewiseLinear), type(p)
+
+ # get sorted x-values without repetition.
+ x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs]))
+ y_vals1 = [self(x) for x in x_vals]
+ y_vals2 = [p(x) for x in x_vals]
+
+ if include_crossings:
+ extra_x_vals = []
+ for i in range(len(x_vals) - 1):
+ if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]):
+ # if the two lines in this subsegment potentially cross each other..
+ diff_cur = abs(y_vals1[i] - y_vals2[i])
+ diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1])
+ # `pos`, between 0 and 1, gives the relative x position,
+ # with 0 being x_vals[i] and 1 being x_vals[i+1].
+ pos = diff_cur / (diff_cur + diff_next)
+ extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i])
+ extra_x_vals.append(extra_x_val)
+ if len(extra_x_vals) > 0:
+ x_vals = sorted(set(x_vals + extra_x_vals))
+ y_vals1 = [self(x) for x in x_vals]
+ y_vals2 = [p(x) for x in x_vals]
+ return (
+ PiecewiseLinear(*zip(x_vals, y_vals1)),
+ PiecewiseLinear(*zip(x_vals, y_vals2)),
+ )
+
+
+class ScheduledFloat(torch.nn.Module):
+ """
+ This object is a torch.nn.Module only because we want it to show up in
+ [top_level module].modules(); it does not have a working forward() function.
+ You are supposed to cast it to float, as in, float(parent_module.whatever), and use
+ it as something like a dropout prob.
+
+ It is a floating point value whose value changes depending on the batch count of the
+ training loop. It is a piecewise linear function where you specify the (x,y) pairs
+ in sorted order on x; x corresponds to the batch index. For batch-index values
+ before the first x or after the last x, we just use the first or last y value.
+
+ Example:
+ self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
+
+ `default` is used when self.batch_count is not set or not in training mode or in
+ torch.jit scripting mode.
+ """
+
+ def __init__(self, *args, default: float = 0.0):
+ super().__init__()
+ # self.batch_count and self.name will be written to in the training loop.
+ self.batch_count = None
+ self.name = None
+ self.default = default
+ self.schedule = PiecewiseLinear(*args)
+
+ def extra_repr(self) -> str:
+ return (
+ f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}"
+ )
+
+ def __float__(self):
+ batch_count = self.batch_count
+ if (
+ batch_count is None
+ or not self.training
+ or torch.jit.is_scripting()
+ or torch.jit.is_tracing()
+ ):
+ return float(self.default)
+ else:
+ ans = self.schedule(self.batch_count)
+ if random.random() < 0.0002:
+ logging.debug(
+ f"ScheduledFloat: name={self.name}, "
+ f"batch_count={self.batch_count}, ans={ans}"
+ )
+ return ans
+
+ def __add__(self, x):
+ if isinstance(x, float) or isinstance(x, int):
+ return ScheduledFloat(self.schedule + x, default=self.default)
+ else:
+ return ScheduledFloat(
+ self.schedule + x.schedule, default=self.default + x.default
+ )
+
+ def max(self, x):
+ if isinstance(x, float) or isinstance(x, int):
+ return ScheduledFloat(self.schedule.max(x), default=self.default)
+ else:
+ return ScheduledFloat(
+ self.schedule.max(x.schedule),
+ default=max(self.default, x.default),
+ )
+
+
+FloatLike = Union[float, ScheduledFloat]
+
+
+class CutoffEstimator:
+ """
+ Estimates cutoffs of an arbitrary numerical quantity such that a specified
+ proportion of items will be above the cutoff on average.
+
+ p is the proportion of items that should be above the cutoff.
+ """
+
+ def __init__(self, p: float):
+ self.p = p
+ # total count of items
+ self.count = 0
+ # total count of items that were above the cutoff
+ self.count_above = 0
+ # initial cutoff value
+ self.cutoff = 0
+
+ def __call__(self, x: float) -> bool:
+ """
+ Returns true if x is above the cutoff.
+ """
+ ans = x > self.cutoff
+ self.count += 1
+ if ans:
+ self.count_above += 1
+ cur_p = self.count_above / self.count
+ delta_p = cur_p - self.p
+ if (delta_p > 0) == ans:
+ q = abs(delta_p)
+ self.cutoff = x * q + self.cutoff * (1 - q)
+ return ans
+
+
+class SoftmaxFunction(torch.autograd.Function):
+ """
+ Tries to handle half-precision derivatives in a randomized way that should
+ be more accurate for training than the default behavior.
+ """
+
+ @staticmethod
+ def forward(ctx, x: Tensor, dim: int):
+ ans = x.softmax(dim=dim)
+ # if x dtype is float16, x.softmax() returns a float32 because
+ # (presumably) that op does not support float16, and autocast
+ # is enabled.
+ if torch.is_autocast_enabled():
+ ans = ans.to(torch.float16)
+ ctx.save_for_backward(ans)
+ ctx.x_dtype = x.dtype
+ ctx.dim = dim
+ return ans
+
+ @staticmethod
+ def backward(ctx, ans_grad: Tensor):
+ (ans,) = ctx.saved_tensors
+ with torch.amp.autocast("cuda", enabled=False):
+ ans_grad = ans_grad.to(torch.float32)
+ ans = ans.to(torch.float32)
+ x_grad = ans_grad * ans
+ x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
+ return x_grad, None
+
+
+def softmax(x: Tensor, dim: int):
+ if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing():
+ return x.softmax(dim=dim)
+
+ return SoftmaxFunction.apply(x, dim)
+
+
+class BiasNormFunction(torch.autograd.Function):
+ # This computes:
+ # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
+ # return x * scales
+ # (after unsqueezing the bias), but it does it in a memory-efficient way so that
+ # it can just store the returned value (chances are, this will also be needed for
+ # some other reason, related to the next operation, so we can save memory).
+ @staticmethod
+ def forward(
+ ctx,
+ x: Tensor,
+ bias: Tensor,
+ log_scale: Tensor,
+ channel_dim: int,
+ store_output_for_backprop: bool,
+ ) -> Tensor:
+ assert bias.ndim == 1
+ if channel_dim < 0:
+ channel_dim = channel_dim + x.ndim
+ ctx.store_output_for_backprop = store_output_for_backprop
+ ctx.channel_dim = channel_dim
+ for _ in range(channel_dim + 1, x.ndim):
+ bias = bias.unsqueeze(-1)
+ scales = (
+ torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
+ ) * log_scale.exp()
+ ans = x * scales
+ ctx.save_for_backward(
+ ans.detach() if store_output_for_backprop else x,
+ scales.detach(),
+ bias.detach(),
+ log_scale.detach(),
+ )
+ return ans
+
+ @staticmethod
+ def backward(ctx, ans_grad: Tensor) -> Tensor:
+ ans_or_x, scales, bias, log_scale = ctx.saved_tensors
+ if ctx.store_output_for_backprop:
+ x = ans_or_x / scales
+ else:
+ x = ans_or_x
+ x = x.detach()
+ x.requires_grad = True
+ bias.requires_grad = True
+ log_scale.requires_grad = True
+ with torch.enable_grad():
+ # recompute scales from x, bias and log_scale.
+ scales = (
+ torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5
+ ) * log_scale.exp()
+ ans = x * scales
+ ans.backward(gradient=ans_grad)
+ return x.grad, bias.grad.flatten(), log_scale.grad, None, None
+
+
+class BiasNorm(torch.nn.Module):
+ """
+ This is intended to be a simpler, and hopefully cheaper, replacement for
+ LayerNorm. The observation this is based on, is that Transformer-type
+ networks, especially with pre-norm, sometimes seem to set one of the
+ feature dimensions to a large constant value (e.g. 50), which "defeats"
+ the LayerNorm because the output magnitude is then not strongly dependent
+ on the other (useful) features. Presumably the weight and bias of the
+ LayerNorm are required to allow it to do this.
+
+ Instead, we give the BiasNorm a trainable bias that it can use when
+ computing the scale for normalization. We also give it a (scalar)
+ trainable scale on the output.
+
+
+ Args:
+ num_channels: the number of channels, e.g. 512.
+ channel_dim: the axis/dimension corresponding to the channel,
+ interpreted as an offset from the input's ndim if negative.
+ This is NOT the num_channels; it should typically be one of
+ {-2, -1, 0, 1, 2, 3}.
+ log_scale: the initial log-scale that we multiply the output by; this
+ is learnable.
+ log_scale_min: FloatLike, minimum allowed value of log_scale
+ log_scale_max: FloatLike, maximum allowed value of log_scale
+ store_output_for_backprop: only possibly affects memory use; recommend
+ to set to True if you think the output of this module is more likely
+ than the input of this module to be required to be stored for the
+ backprop.
+ """
+
+ def __init__(
+ self,
+ num_channels: int,
+ channel_dim: int = -1, # CAUTION: see documentation.
+ log_scale: float = 1.0,
+ log_scale_min: float = -1.5,
+ log_scale_max: float = 1.5,
+ store_output_for_backprop: bool = False,
+ ) -> None:
+ super(BiasNorm, self).__init__()
+ self.num_channels = num_channels
+ self.channel_dim = channel_dim
+ self.log_scale = nn.Parameter(torch.tensor(log_scale))
+ self.bias = nn.Parameter(torch.zeros(num_channels))
+
+ self.log_scale_min = log_scale_min
+ self.log_scale_max = log_scale_max
+
+ self.store_output_for_backprop = store_output_for_backprop
+
+ def forward(self, x: Tensor) -> Tensor:
+ assert x.shape[self.channel_dim] == self.num_channels
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ channel_dim = self.channel_dim
+ if channel_dim < 0:
+ channel_dim += x.ndim
+ bias = self.bias
+ for _ in range(channel_dim + 1, x.ndim):
+ bias = bias.unsqueeze(-1)
+ scales = (
+ torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
+ ) * self.log_scale.exp()
+ return x * scales
+
+ log_scale = limit_param_value(
+ self.log_scale,
+ min=float(self.log_scale_min),
+ max=float(self.log_scale_max),
+ training=self.training,
+ )
+
+ return BiasNormFunction.apply(
+ x,
+ self.bias,
+ log_scale,
+ self.channel_dim,
+ self.store_output_for_backprop,
+ )
+
+
+def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
+ """
+ Behaves like a constructor of a modified version of nn.Linear
+ that gives an easy way to set the default initial parameter scale.
+
+ Args:
+ Accepts the standard args and kwargs that nn.Linear accepts
+ e.g. in_features, out_features, bias=False.
+
+ initial_scale: you can override this if you want to increase
+ or decrease the initial magnitude of the module's output
+ (affects the initialization of weight_scale and bias_scale).
+ Another option, if you want to do something like this, is
+ to re-initialize the parameters.
+ """
+ ans = nn.Linear(*args, **kwargs)
+ with torch.no_grad():
+ ans.weight[:] *= initial_scale
+ if ans.bias is not None:
+ torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
+ return ans
+
+
+class BalancerFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x: Tensor,
+ min_mean: float,
+ max_mean: float,
+ min_rms: float,
+ max_rms: float,
+ grad_scale: float,
+ channel_dim: int,
+ ) -> Tensor:
+ if channel_dim < 0:
+ channel_dim += x.ndim
+ ctx.channel_dim = channel_dim
+ ctx.save_for_backward(x)
+ ctx.config = (
+ min_mean,
+ max_mean,
+ min_rms,
+ max_rms,
+ grad_scale,
+ channel_dim,
+ )
+ return x
+
+ @staticmethod
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
+ (x,) = ctx.saved_tensors
+ (
+ min_mean,
+ max_mean,
+ min_rms,
+ max_rms,
+ grad_scale,
+ channel_dim,
+ ) = ctx.config
+
+ try:
+ with torch.enable_grad():
+ with torch.amp.autocast("cuda", enabled=False):
+ x = x.to(torch.float32)
+ x = x.detach()
+ x.requires_grad = True
+ mean_dims = [i for i in range(x.ndim) if i != channel_dim]
+ uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True)
+ mean = x.mean(dim=mean_dims, keepdim=True)
+ stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
+ rms = uncentered_var.clamp(min=1.0e-20).sqrt()
+
+ m = mean / stddev
+ # part of loss that relates to mean / stddev
+ m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
+
+ # put a much larger scale on the RMS-max-limit loss, so that if both
+ # it and the m_loss are violated we fix the RMS loss first.
+ rms_clamped = rms.clamp(min=min_rms, max=max_rms)
+ r_loss = (rms_clamped / rms).log().abs()
+
+ loss = m_loss + r_loss
+
+ loss.backward(gradient=torch.ones_like(loss))
+ loss_grad = x.grad
+ loss_grad_rms = (
+ (loss_grad**2)
+ .mean(dim=mean_dims, keepdim=True)
+ .sqrt()
+ .clamp(min=1.0e-20)
+ )
+
+ loss_grad = loss_grad * (grad_scale / loss_grad_rms)
+
+ x_grad_float = x_grad.to(torch.float32)
+ # scale each element of loss_grad by the absolute value of the
+ # corresponding element of x_grad, which we view as a noisy estimate
+ # of its magnitude for that (frame and dimension). later we can
+ # consider factored versions.
+ x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
+ x_grad = x_grad_mod.to(x_grad.dtype)
+ except Exception as e:
+ logging.info(
+ f"Caught exception in Balancer backward: {e}, "
+ f"size={list(x_grad.shape)}, will continue."
+ )
+
+ return x_grad, None, None, None, None, None, None
+
+
+class Balancer(torch.nn.Module):
+ """
+ Modifies the backpropped derivatives of a function to try to encourage, for
+ each channel, that it is positive at least a proportion `threshold` of the
+ time. It does this by multiplying negative derivative values by up to
+ (1+max_factor), and positive derivative values by up to (1-max_factor),
+ interpolated from 1 at the threshold to those extremal values when none
+ of the inputs are positive.
+
+ Args:
+ num_channels: the number of channels
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
+ min_positive: the minimum, per channel, of the proportion of the time
+ that (x > 0), below which we start to modify the derivatives.
+ max_positive: the maximum, per channel, of the proportion of the time
+ that (x > 0), above which we start to modify the derivatives.
+ scale_gain_factor: determines the 'gain' with which we increase the
+ change in gradient once the constraints on min_abs and max_abs
+ are violated.
+ min_abs: the minimum average-absolute-value difference from the mean
+ value per channel, which we allow, before we start to modify
+ the derivatives to prevent this.
+ max_abs: the maximum average-absolute-value difference from the mean
+ value per channel, which we allow, before we start to modify
+ the derivatives to prevent this.
+ prob: determines the minimum probability with which we modify the
+ gradients for the {min,max}_positive and {min,max}_abs constraints,
+ on each forward(). This is done randomly to prevent all layers
+ from doing it at the same time.
+ """
+
+ def __init__(
+ self,
+ num_channels: int,
+ channel_dim: int,
+ min_positive: FloatLike = 0.05,
+ max_positive: FloatLike = 0.95,
+ min_abs: FloatLike = 0.2,
+ max_abs: FloatLike = 100.0,
+ grad_scale: FloatLike = 0.04,
+ prob: Optional[FloatLike] = None,
+ ):
+ super().__init__()
+
+ if prob is None:
+ prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4)
+ self.prob = prob
+ # 5% of the time we will return and do nothing because memory usage is
+ # too high.
+ self.mem_cutoff = CutoffEstimator(0.05)
+
+ # actually self.num_channels is no longer needed except for an assertion.
+ self.num_channels = num_channels
+ self.channel_dim = channel_dim
+ self.min_positive = min_positive
+ self.max_positive = max_positive
+ self.min_abs = min_abs
+ self.max_abs = max_abs
+ self.grad_scale = grad_scale
+
+ def forward(self, x: Tensor) -> Tensor:
+ if (
+ torch.jit.is_scripting()
+ or not x.requires_grad
+ or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))
+ ):
+ return _no_op(x)
+
+ prob = float(self.prob)
+ if random.random() < prob:
+ # The following inner-functions convert from the way we historically
+ # specified these limitations, as limits on the absolute value and the
+ # proportion of positive values, to limits on the RMS value and
+ # the (mean / stddev).
+ def _abs_to_rms(x):
+ # for normally distributed data, if the expected absolute value is x,
+ # the expected rms value will be sqrt(pi/2) * x.
+ return 1.25331413732 * x
+
+ def _proportion_positive_to_mean(x):
+ def _atanh(x):
+ eps = 1.0e-10
+ # eps is to prevent crashes if x is exactly 0 or 1.
+ # we'll just end up returning a fairly large value.
+ return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0
+
+ def _approx_inverse_erf(x):
+ # 1 / (sqrt(pi) * ln(2)),
+ # see https://math.stackexchange.com/questions/321569/
+ # approximating-the-error-function-erf-by-analytical-functions
+ # this approximation is extremely crude and gets progressively worse
+ # for x very close to -1 or +1, but we mostly care about the
+ # "middle" region
+ # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772,
+ # and math.erf(0.0407316414078772) = 0.045935330944660666,
+ # which is pretty close to 0.05.
+ return 0.8139535143 * _atanh(x)
+
+ # first convert x from the range 0..1 to the range -1..1 which the error
+ # function returns
+ x = -1 + (2 * x)
+ return _approx_inverse_erf(x)
+
+ min_mean = _proportion_positive_to_mean(float(self.min_positive))
+ max_mean = _proportion_positive_to_mean(float(self.max_positive))
+ min_rms = _abs_to_rms(float(self.min_abs))
+ max_rms = _abs_to_rms(float(self.max_abs))
+ grad_scale = float(self.grad_scale)
+
+ assert x.shape[self.channel_dim] == self.num_channels
+
+ return BalancerFunction.apply(
+ x,
+ min_mean,
+ max_mean,
+ min_rms,
+ max_rms,
+ grad_scale,
+ self.channel_dim,
+ )
+ else:
+ return _no_op(x)
+
+
+def penalize_abs_values_gt(
+ x: Tensor, limit: float, penalty: float, name: str = None
+) -> Tensor:
+ """
+ Returns x unmodified, but in backprop will put a penalty for the excess of
+ the absolute values of elements of x over the limit "limit". E.g. if
+ limit == 10.0, then if x has any values over 10 it will get a penalty.
+
+ Caution: the value of this penalty will be affected by grad scaling used
+ in automatic mixed precision training. For this reasons we use this,
+ it shouldn't really matter, or may even be helpful; we just use this
+ to disallow really implausible values of scores to be given to softmax.
+
+ The name is for randomly printed debug info.
+ """
+ x_sign = x.sign()
+ over_limit = (x.abs() - limit) > 0
+ # The following is a memory efficient way to penalize the absolute values of
+ # x that's over the limit. (The memory efficiency comes when you think
+ # about which items torch needs to cache for the autograd, and which ones it
+ # can throw away). The numerical value of aux_loss as computed here will
+ # actually be larger than it should be, by limit * over_limit.sum(), but it
+ # has the same derivative as the real aux_loss which is penalty * (x.abs() -
+ # limit).relu().
+ aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
+ # note: we don't do sum() here on aux)_loss, but it's as if we had done
+ # sum() due to how with_loss() works.
+ x = with_loss(x, aux_loss, name)
+ # you must use x for something, or this will be ineffective.
+ return x
+
+
+def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
+ if x.ndim == 2:
+ return x.diag()
+ else:
+ (batch, dim, dim) = x.shape
+ x = x.reshape(batch, dim * dim)
+ x = x[:, :: dim + 1]
+ assert x.shape == (batch, dim)
+ return x
+
+
+def _whitening_metric(x: Tensor, num_groups: int):
+ """
+ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of
+ of the centered feature covariance are the same within each group's covariance
+ matrix and also between groups.
+ Args:
+ x: a Tensor of shape (*, num_channels)
+ num_groups: the number of groups of channels, a number >=1 that divides
+ num_channels
+ Returns:
+ Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
+ greater than 1.0 otherwise.
+ """
+ assert x.dtype != torch.float16
+ x = x.reshape(-1, x.shape[-1])
+ (num_frames, num_channels) = x.shape
+ assert num_channels % num_groups == 0
+ channels_per_group = num_channels // num_groups
+ x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
+ # x now has shape (num_groups, num_frames, channels_per_group)
+ # subtract the mean so we use the centered, not uncentered, covariance.
+ # My experience has been that when we "mess with the gradients" like this,
+ # it's better not do anything that tries to move the mean around, because
+ # that can easily cause instability.
+ x = x - x.mean(dim=1, keepdim=True)
+ # x_covar: (num_groups, channels_per_group, channels_per_group)
+ x_covar = torch.matmul(x.transpose(1, 2), x)
+ x_covar_mean_diag = _diag(x_covar).mean()
+ # the following expression is what we'd get if we took the matrix product
+ # of each covariance and measured the mean of its trace, i.e.
+ # the same as _diag(torch.matmul(x_covar, x_covar)).mean().
+ x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group)
+ # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
+ metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20)
+ return metric
+
+
+class WhiteningPenaltyFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x: Tensor, module: nn.Module) -> Tensor:
+ ctx.save_for_backward(x)
+ ctx.module = module
+ return x
+
+ @staticmethod
+ def backward(ctx, x_grad: Tensor):
+ (x_orig,) = ctx.saved_tensors
+ w = ctx.module
+
+ try:
+ with torch.enable_grad():
+ with torch.amp.autocast("cuda", enabled=False):
+ x_detached = x_orig.to(torch.float32).detach()
+ x_detached.requires_grad = True
+
+ metric = _whitening_metric(x_detached, w.num_groups)
+
+ if random.random() < 0.005 or __name__ == "__main__":
+ logging.debug(
+ f"Whitening: name={w.name}, num_groups={w.num_groups},"
+ f"num_channels={x_orig.shape[-1]}, "
+ f"metric={metric.item():.2f}"
+ f" vs. limit={float(w.whitening_limit)}"
+ )
+
+ if metric < float(w.whitening_limit):
+ w.prob = w.min_prob
+ return x_grad, None
+ else:
+ w.prob = w.max_prob
+ metric.backward()
+ penalty_grad = x_detached.grad
+ scale = w.grad_scale * (
+ x_grad.to(torch.float32).norm()
+ / (penalty_grad.norm() + 1.0e-20)
+ )
+ penalty_grad = penalty_grad * scale
+ return x_grad + penalty_grad.to(x_grad.dtype), None
+ except Exception as e:
+ logging.info(
+ f"Caught exception in Whiten backward: {e}, "
+ f"size={list(x_grad.shape)}, will continue."
+ )
+ return x_grad, None
+
+
+class Whiten(nn.Module):
+ def __init__(
+ self,
+ num_groups: int,
+ whitening_limit: FloatLike,
+ prob: Union[float, Tuple[float, float]],
+ grad_scale: FloatLike,
+ ):
+ """
+ Args:
+ num_groups: the number of groups to divide the channel dim into before
+ whitening. We will attempt to make the feature covariance
+ within each group, after mean subtraction, as "white" as possible,
+ while having the same trace across all groups.
+ whitening_limit: a value greater than 1.0, that dictates how much
+ freedom we have to violate the constraints. 1.0 would mean perfectly
+ white, with exactly the same trace across groups; larger values
+ give more freedom. E.g. 2.0.
+ prob: the probability with which we apply the gradient modification
+ (also affects the grad scale). May be supplied as a float,
+ or as a pair (min_prob, max_prob)
+
+ grad_scale: determines the scale on the gradient term from this object,
+ relative to the rest of the gradient on the attention weights.
+ E.g. 0.02 (you may want to use smaller values than this if prob is large)
+ """
+ super(Whiten, self).__init__()
+ assert num_groups >= 1
+ assert float(whitening_limit) >= 1
+ assert grad_scale >= 0
+ self.num_groups = num_groups
+ self.whitening_limit = whitening_limit
+ self.grad_scale = grad_scale
+
+ if isinstance(prob, float):
+ prob = (prob, prob)
+ (self.min_prob, self.max_prob) = prob
+ assert 0 < self.min_prob <= self.max_prob <= 1
+ self.prob = self.max_prob
+ self.name = None # will be set in training loop
+
+ def forward(self, x: Tensor) -> Tensor:
+ """
+ In the forward pass, this function just returns the input unmodified.
+ In the backward pass, it will modify the gradients to ensure that the
+ distribution in each group has close to (lambda times I) as the covariance
+ after mean subtraction, with the same lambda across groups.
+ For whitening_limit > 1, there will be more freedom to violate this
+ constraint.
+
+ Args:
+ x: the input of shape (*, num_channels)
+
+ Returns:
+ x, unmodified. You should make sure
+ you use the returned value, or the graph will be freed
+ and nothing will happen in backprop.
+ """
+ grad_scale = float(self.grad_scale)
+ if not x.requires_grad or random.random() > self.prob or grad_scale == 0:
+ return _no_op(x)
+ else:
+ return WhiteningPenaltyFunction.apply(x, self)
+
+
+class WithLoss(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x: Tensor, y: Tensor, name: str):
+ ctx.y_shape = y.shape
+ if random.random() < 0.002 and name is not None:
+ loss_sum = y.sum().item()
+ logging.debug(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}")
+ return x
+
+ @staticmethod
+ def backward(ctx, ans_grad: Tensor):
+ return (
+ ans_grad,
+ torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device),
+ None,
+ )
+
+
+def with_loss(x, y, name):
+ # returns x but adds y.sum() to the loss function.
+ return WithLoss.apply(x, y, name)
+
+
+class LimitParamValue(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x: Tensor, min: float, max: float):
+ ctx.save_for_backward(x)
+ assert max >= min
+ ctx.min = min
+ ctx.max = max
+ return x
+
+ @staticmethod
+ def backward(ctx, x_grad: Tensor):
+ (x,) = ctx.saved_tensors
+ # where x < ctx.min, ensure all grads are negative (this will tend to make
+ # x more positive).
+ x_grad = x_grad * torch.where(
+ torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0
+ )
+ # where x > ctx.max, ensure all grads are positive (this will tend to make
+ # x more negative).
+ x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
+ return x_grad, None, None
+
+
+def limit_param_value(
+ x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True
+):
+ # You apply this to (typically) an nn.Parameter during training to ensure that its
+ # (elements mostly) stays within a supplied range. This is done by modifying the
+ # gradients in backprop.
+ # It's not necessary to do this on every batch: do it only some of the time,
+ # to save a little time.
+ if training and random.random() < prob:
+ return LimitParamValue.apply(x, min, max)
+ else:
+ return x
+
+
+def _no_op(x: Tensor) -> Tensor:
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ return x
+ else:
+ # a no-op function that will have a node in the autograd graph,
+ # to avoid certain bugs relating to backward hooks
+ return x.chunk(1, dim=-1)[0]
+
+
+# Identity more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+class Identity(torch.nn.Module):
+ def __init__(self):
+ super(Identity, self).__init__()
+
+ def forward(self, x):
+ return _no_op(x)
+
+
+# Dropout2 is just like normal dropout, except it supports schedules
+# on the dropout rates.
+class Dropout2(nn.Module):
+ def __init__(self, p: FloatLike):
+ super().__init__()
+ self.p = p
+
+ def forward(self, x: Tensor) -> Tensor:
+ return torch.nn.functional.dropout(x, p=float(self.p), training=self.training)
+
+
+class MulForDropout3(torch.autograd.Function):
+ # returns (x * y * alpha) where alpha is a float and y doesn't require
+ # grad and is zero-or-one.
+ @staticmethod
+ @custom_fwd
+ def forward(ctx, x, y, alpha):
+ assert not y.requires_grad
+ ans = x * y * alpha
+ ctx.save_for_backward(ans)
+ ctx.alpha = alpha
+ return ans
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, ans_grad):
+ (ans,) = ctx.saved_tensors
+ x_grad = ctx.alpha * ans_grad * (ans != 0)
+ return x_grad, None, None
+
+
+# Dropout3 is just like normal dropout, except it supports schedules on the dropout
+# rates, and it lets you choose one dimension to share the dropout mask over
+class Dropout3(nn.Module):
+ def __init__(self, p: FloatLike, shared_dim: int):
+ super().__init__()
+ self.p = p
+ self.shared_dim = shared_dim
+
+ def forward(self, x: Tensor) -> Tensor:
+ p = float(self.p)
+ if not self.training or p == 0:
+ return _no_op(x)
+ scale = 1.0 / (1 - p)
+ rand_shape = list(x.shape)
+ rand_shape[self.shared_dim] = 1
+ mask = torch.rand(*rand_shape, device=x.device) > p
+ ans = MulForDropout3.apply(x, mask, scale)
+ return ans
+
+
+class SwooshLFunction(torch.autograd.Function):
+ """
+ swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035
+ """
+
+ @staticmethod
+ def forward(ctx, x: Tensor) -> Tensor:
+ requires_grad = x.requires_grad
+ if x.dtype == torch.float16:
+ x = x.to(torch.float32)
+
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+
+ coeff = -0.08
+
+ with torch.amp.autocast("cuda", enabled=False):
+ with torch.enable_grad():
+ x = x.detach()
+ x.requires_grad = True
+ y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035
+
+ if not requires_grad:
+ return y
+
+ y.backward(gradient=torch.ones_like(y))
+
+ grad = x.grad
+ floor = coeff
+ ceil = 1.0 + coeff + 0.005
+
+ d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
+ grad
+ )
+ if __name__ == "__main__":
+ # for self-testing only.
+ assert d_scaled.min() >= 0.0
+ assert d_scaled.max() < 256.0
+
+ d_int = d_scaled.to(torch.uint8)
+ ctx.save_for_backward(d_int)
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
+ y = y.to(torch.float16)
+ return y
+
+ @staticmethod
+ def backward(ctx, y_grad: Tensor) -> Tensor:
+ (d,) = ctx.saved_tensors
+ # the same constants as used in forward pass.
+ coeff = -0.08
+ floor = coeff
+ ceil = 1.0 + coeff + 0.005
+ d = d * ((ceil - floor) / 255.0) + floor
+ return y_grad * d
+
+
+class SwooshL(torch.nn.Module):
+ def forward(self, x: Tensor) -> Tensor:
+ """Return Swoosh-L activation."""
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+ return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
+ elif "k2" not in sys.modules:
+ return SwooshLFunction.apply(x)
+ else:
+ if not x.requires_grad:
+ return k2.swoosh_l_forward(x)
+ else:
+ return k2.swoosh_l(x)
+
+
+class SwooshLOnnx(torch.nn.Module):
+ def forward(self, x: Tensor) -> Tensor:
+ """Return Swoosh-L activation."""
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+ return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035
+
+
+class SwooshRFunction(torch.autograd.Function):
+ """
+ swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687
+
+ derivatives are between -0.08 and 0.92.
+ """
+
+ @staticmethod
+ def forward(ctx, x: Tensor) -> Tensor:
+ requires_grad = x.requires_grad
+
+ if x.dtype == torch.float16:
+ x = x.to(torch.float32)
+
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+
+ with torch.amp.autocast("cuda", enabled=False):
+ with torch.enable_grad():
+ x = x.detach()
+ x.requires_grad = True
+ y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
+
+ if not requires_grad:
+ return y
+ y.backward(gradient=torch.ones_like(y))
+
+ grad = x.grad
+ floor = -0.08
+ ceil = 0.925
+
+ d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
+ grad
+ )
+ if __name__ == "__main__":
+ # for self-testing only.
+ assert d_scaled.min() >= 0.0
+ assert d_scaled.max() < 256.0
+
+ d_int = d_scaled.to(torch.uint8)
+ ctx.save_for_backward(d_int)
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
+ y = y.to(torch.float16)
+ return y
+
+ @staticmethod
+ def backward(ctx, y_grad: Tensor) -> Tensor:
+ (d,) = ctx.saved_tensors
+ # the same constants as used in forward pass.
+ floor = -0.08
+ ceil = 0.925
+ d = d * ((ceil - floor) / 255.0) + floor
+ return y_grad * d
+
+
+class SwooshR(torch.nn.Module):
+ def forward(self, x: Tensor) -> Tensor:
+ """Return Swoosh-R activation."""
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+ return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
+ elif "k2" not in sys.modules:
+ return SwooshRFunction.apply(x)
+ else:
+ if not x.requires_grad:
+ return k2.swoosh_r_forward(x)
+ else:
+ return k2.swoosh_r(x)
+
+
+class SwooshROnnx(torch.nn.Module):
+ def forward(self, x: Tensor) -> Tensor:
+ """Return Swoosh-R activation."""
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+ return logaddexp_onnx(zero, x - 1.0) - 0.08 * x - 0.313261687
+
+
+# simple version of SwooshL that does not redefine the backprop, used in
+# ActivationDropoutAndLinearFunction.
+def SwooshLForward(x: Tensor):
+ with torch.amp.autocast("cuda", enabled=False):
+ x = x.to(torch.float32)
+ x_offset = x - 4.0
+ log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
+ log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
+ return log_sum - 0.08 * x - 0.035
+
+
+# simple version of SwooshR that does not redefine the backprop, used in
+# ActivationDropoutAndLinearFunction.
+def SwooshRForward(x: Tensor):
+ with torch.amp.autocast("cuda", enabled=False):
+ x = x.to(torch.float32)
+ x_offset = x - 1.0
+ log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
+ log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
+ return log_sum - 0.08 * x - 0.313261687
+
+
+class ActivationDropoutAndLinearFunction(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx,
+ x: Tensor,
+ weight: Tensor,
+ bias: Optional[Tensor],
+ activation: str,
+ dropout_p: float,
+ dropout_shared_dim: Optional[int],
+ ):
+ if dropout_p != 0.0:
+ dropout_shape = list(x.shape)
+ if dropout_shared_dim is not None:
+ dropout_shape[dropout_shared_dim] = 1
+ # else it won't be very memory efficient.
+ dropout_mask = (1.0 / (1.0 - dropout_p)) * (
+ torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p
+ )
+ else:
+ dropout_mask = None
+
+ ctx.save_for_backward(x, weight, bias, dropout_mask)
+
+ ctx.activation = activation
+
+ forward_activation_dict = {
+ "SwooshL": k2.swoosh_l_forward,
+ "SwooshR": k2.swoosh_r_forward,
+ }
+ # it will raise a KeyError if this fails. This will be an error. We let it
+ # propagate to the user.
+ activation_func = forward_activation_dict[activation]
+ x = activation_func(x)
+ if dropout_mask is not None:
+ x = x * dropout_mask
+ x = torch.nn.functional.linear(x, weight, bias)
+ return x
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, ans_grad: Tensor):
+ saved = ctx.saved_tensors
+ (x, weight, bias, dropout_mask) = saved
+
+ forward_and_deriv_activation_dict = {
+ "SwooshL": k2.swoosh_l_forward_and_deriv,
+ "SwooshR": k2.swoosh_r_forward_and_deriv,
+ }
+ # the following lines a KeyError if the activation is unrecognized.
+ # This will be an error. We let it propagate to the user.
+ func = forward_and_deriv_activation_dict[ctx.activation]
+
+ y, func_deriv = func(x)
+ if dropout_mask is not None:
+ y = y * dropout_mask
+ # now compute derivative of y w.r.t. weight and bias..
+ # y: (..., in_channels), ans_grad: (..., out_channels),
+ (out_channels, in_channels) = weight.shape
+
+ in_channels = y.shape[-1]
+ g = ans_grad.reshape(-1, out_channels)
+ weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels))
+ y_deriv = torch.matmul(ans_grad, weight)
+ bias_deriv = None if bias is None else g.sum(dim=0)
+ x_deriv = y_deriv * func_deriv
+ if dropout_mask is not None:
+ # order versus func_deriv does not matter
+ x_deriv = x_deriv * dropout_mask
+
+ return x_deriv, weight_deriv, bias_deriv, None, None, None
+
+
+class ActivationDropoutAndLinear(torch.nn.Module):
+ """
+ This merges an activation function followed by dropout and then a nn.Linear module;
+ it does so in a memory efficient way so that it only stores the input to the whole
+ module. If activation == SwooshL and dropout_shared_dim != None, this will be
+ equivalent to:
+ nn.Sequential(SwooshL(),
+ Dropout3(dropout_p, shared_dim=dropout_shared_dim),
+ ScaledLinear(in_channels, out_channels, bias=bias,
+ initial_scale=initial_scale))
+ If dropout_shared_dim is None, the dropout would be equivalent to
+ Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout
+ mask is smaller.
+
+ Args:
+ in_channels: number of input channels, e.g. 256
+ out_channels: number of output channels, e.g. 256
+ bias: if true, have a bias
+ activation: the activation function, for now just support SwooshL.
+ dropout_p: the dropout probability or schedule (happens after nonlinearity).
+ dropout_shared_dim: the dimension, if any, across which the dropout mask is
+ shared (e.g. the time dimension). If None, this may be less memory
+ efficient if there are modules before this one that cache the input
+ for their backprop (e.g. Balancer or Whiten).
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ bias: bool = True,
+ activation: str = "SwooshL",
+ dropout_p: FloatLike = 0.0,
+ dropout_shared_dim: Optional[int] = -1,
+ initial_scale: float = 1.0,
+ ):
+ super().__init__()
+ # create a temporary module of nn.Linear that we'll steal the
+ # weights and bias from
+ l = ScaledLinear(
+ in_channels, out_channels, bias=bias, initial_scale=initial_scale
+ )
+
+ self.weight = l.weight
+ # register_parameter properly handles making it a parameter when l.bias
+ # is None. I think there is some reason for doing it this way rather
+ # than just setting it to None but I don't know what it is, maybe
+ # something to do with exporting the module..
+ self.register_parameter("bias", l.bias)
+
+ self.activation = activation
+ self.dropout_p = dropout_p
+ self.dropout_shared_dim = dropout_shared_dim
+
+ def forward(self, x: Tensor):
+ if (
+ torch.jit.is_scripting()
+ or torch.jit.is_tracing()
+ or "k2" not in sys.modules
+ ):
+ if self.activation == "SwooshL":
+ x = SwooshLForward(x)
+ elif self.activation == "SwooshR":
+ x = SwooshRForward(x)
+ else:
+ assert False, self.activation
+ return torch.nn.functional.linear(x, self.weight, self.bias)
+
+ return ActivationDropoutAndLinearFunction.apply(
+ x,
+ self.weight,
+ self.bias,
+ self.activation,
+ float(self.dropout_p),
+ self.dropout_shared_dim,
+ )
+
+
+def _test_whiten():
+ for proportion in [0.1, 0.5, 10.0]:
+ logging.info(f"_test_whiten(): proportion = {proportion}")
+ x = torch.randn(100, 128)
+ direction = torch.randn(128)
+ coeffs = torch.randn(100, 1)
+ x += proportion * direction * coeffs
+
+ x.requires_grad = True
+
+ m = Whiten(
+ 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit,
+ ) # grad_scale
+
+ for _ in range(4):
+ y = m(x)
+
+ y_grad = torch.randn_like(x)
+ y.backward(gradient=y_grad)
+
+ if proportion < 0.2:
+ assert torch.allclose(x.grad, y_grad)
+ elif proportion > 1.0:
+ assert not torch.allclose(x.grad, y_grad)
+
+
+def _test_balancer_sign():
+ probs = torch.arange(0, 1, 0.01)
+ N = 1000
+ x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
+ x = x.detach()
+ x.requires_grad = True
+ m = Balancer(
+ probs.numel(),
+ channel_dim=0,
+ min_positive=0.05,
+ max_positive=0.95,
+ min_abs=0.0,
+ prob=1.0,
+ )
+
+ y_grad = torch.sign(torch.randn(probs.numel(), N))
+
+ y = m(x)
+ y.backward(gradient=y_grad)
+ print("_test_balancer_sign: x = ", x)
+ print("_test_balancer_sign: y grad = ", y_grad)
+ print("_test_balancer_sign: x grad = ", x.grad)
+
+
+def _test_balancer_magnitude():
+ magnitudes = torch.arange(0, 1, 0.01)
+ N = 1000
+ x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
+ x = x.detach()
+ x.requires_grad = True
+ m = Balancer(
+ magnitudes.numel(),
+ channel_dim=0,
+ min_positive=0.0,
+ max_positive=1.0,
+ min_abs=0.2,
+ max_abs=0.7,
+ prob=1.0,
+ )
+
+ y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
+
+ y = m(x)
+ y.backward(gradient=y_grad)
+ print("_test_balancer_magnitude: x = ", x)
+ print("_test_balancer_magnitude: y grad = ", y_grad)
+ print("_test_balancer_magnitude: x grad = ", x.grad)
+
+
+def _test_swooshl_deriv():
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ m = SwooshL()
+
+ tol = 1.0 / 255.0
+ torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
+
+ # for self-test.
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ y = m(x)
+ return y
+
+
+def _test_swooshr_deriv():
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ m = SwooshR()
+
+ tol = 1.0 / 255.0
+ torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
+
+ # for self-test.
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ y = m(x)
+ return y
+
+
+def _test_softmax():
+ a = torch.randn(2, 10, dtype=torch.float64)
+ b = a.clone()
+ a.requires_grad = True
+ b.requires_grad = True
+ a.softmax(dim=1)[:, 0].sum().backward()
+ print("a grad = ", a.grad)
+ softmax(b, dim=1)[:, 0].sum().backward()
+ print("b grad = ", b.grad)
+ assert torch.allclose(a.grad, b.grad)
+
+
+def _test_piecewise_linear():
+ p = PiecewiseLinear((0, 10.0))
+ for x in [-100, 0, 100]:
+ assert p(x) == 10.0
+ p = PiecewiseLinear((0, 10.0), (1, 0.0))
+ for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]:
+ print("x, y = ", x, y)
+ assert p(x) == y, (x, p(x), y)
+
+ q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0))
+ x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0]
+ pq = p.max(q)
+ for x in x_vals:
+ y1 = max(p(x), q(x))
+ y2 = pq(x)
+ assert abs(y1 - y2) < 0.001
+ pq = p.min(q)
+ for x in x_vals:
+ y1 = min(p(x), q(x))
+ y2 = pq(x)
+ assert abs(y1 - y2) < 0.001
+ pq = p + q
+ for x in x_vals:
+ y1 = p(x) + q(x)
+ y2 = pq(x)
+ assert abs(y1 - y2) < 0.001
+
+
+def _test_activation_dropout_and_linear():
+ in_channels = 20
+ out_channels = 30
+
+ for bias in [True, False]:
+ # actually we don't test for dropout_p != 0.0 because forward functions will
+ # different answers. This is because we are using the k2 implementation of
+ # swoosh_l an swoosh_r inside SwooshL() and SwooshR(), and they call randn()
+ # internally, messing up the random state.
+ for dropout_p in [0.0]:
+ for activation in ["SwooshL", "SwooshR"]:
+ m1 = nn.Sequential(
+ SwooshL() if activation == "SwooshL" else SwooshR(),
+ Dropout3(p=dropout_p, shared_dim=-1),
+ ScaledLinear(
+ in_channels, out_channels, bias=bias, initial_scale=0.5
+ ),
+ )
+ m2 = ActivationDropoutAndLinear(
+ in_channels,
+ out_channels,
+ bias=bias,
+ initial_scale=0.5,
+ activation=activation,
+ dropout_p=dropout_p,
+ )
+ with torch.no_grad():
+ m2.weight[:] = m1[2].weight
+ if bias:
+ m2.bias[:] = m1[2].bias
+ # make sure forward gives same result.
+ x1 = torch.randn(10, in_channels)
+ x1.requires_grad = True
+
+ # TEMP.
+ assert torch.allclose(
+ SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03
+ )
+
+ x2 = x1.clone().detach()
+ x2.requires_grad = True
+ seed = 10
+ torch.manual_seed(seed)
+ y1 = m1(x1)
+ y_grad = torch.randn_like(y1)
+ y1.backward(gradient=y_grad)
+ torch.manual_seed(seed)
+ y2 = m2(x2)
+ y2.backward(gradient=y_grad)
+
+ print(
+ f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}"
+ )
+ print("y1 = ", y1)
+ print("y2 = ", y2)
+ assert torch.allclose(y1, y2, atol=0.02)
+ assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05)
+ if bias:
+ assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05)
+ print("x1.grad = ", x1.grad)
+ print("x2.grad = ", x2.grad)
+
+ def isclose(a, b):
+ # return true if cosine similarity is > 0.9.
+ return (a * b).sum() > 0.9 * (
+ (a**2).sum() * (b**2).sum()
+ ).sqrt()
+
+ # the SwooshL() implementation has a noisy gradient due to 1-byte
+ # storage of it.
+ assert isclose(x1.grad, x2.grad)
+
+
+if __name__ == "__main__":
+ logging.getLogger().setLevel(logging.DEBUG)
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ _test_piecewise_linear()
+ _test_softmax()
+ _test_whiten()
+ _test_balancer_sign()
+ _test_balancer_magnitude()
+ _test_swooshr_deriv()
+ _test_swooshl_deriv()
+ _test_activation_dropout_and_linear()
diff --git a/zipvoice/models/modules/solver.py b/zipvoice/models/modules/solver.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4b9c4ee90046629264f26da1095921a2bed67c6
--- /dev/null
+++ b/zipvoice/models/modules/solver.py
@@ -0,0 +1,281 @@
+#!/usr/bin/env python3
+# Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import torch
+
+
+class DiffusionModel(torch.nn.Module):
+ """A wrapper of diffusion models for inference.
+ Args:
+ model: The diffusion model.
+ func_name: The function name to call.
+ """
+
+ def __init__(
+ self,
+ model: torch.nn.Module,
+ func_name: str = "forward_fm_decoder",
+ ):
+ super().__init__()
+ self.model = model
+ self.func_name = func_name
+ self.model_func = getattr(self.model, func_name)
+
+ def forward(
+ self,
+ t: torch.Tensor,
+ x: torch.Tensor,
+ text_condition: torch.Tensor,
+ speech_condition: torch.Tensor,
+ padding_mask: Optional[torch.Tensor] = None,
+ guidance_scale: Union[float, torch.Tensor] = 0.0,
+ **kwargs
+ ) -> torch.Tensor:
+ """
+ Forward function that Handles the classifier-free guidance.
+ Args:
+ t: The current timestep, a tensor of a tensor of a single float.
+ x: The initial value, with the shape (batch, seq_len, emb_dim).
+ text_condition: The text_condition of the diffision model, with
+ the shape (batch, seq_len, emb_dim).
+ speech_condition: The speech_condition of the diffision model, with the
+ shape (batch, seq_len, emb_dim).
+ padding_mask: The mask for padding; True means masked position, with the
+ shape (batch, seq_len).
+ guidance_scale: The scale of classifier-free guidance, a float or a tensor
+ of shape (batch, 1, 1).
+ Retrun:
+ The prediction with the shape (batch, seq_len, emb_dim).
+ """
+ if not torch.is_tensor(guidance_scale):
+ guidance_scale = torch.tensor(
+ guidance_scale, dtype=t.dtype, device=t.device
+ )
+
+ if (guidance_scale == 0.0).all():
+ return self.model_func(
+ t=t,
+ xt=x,
+ text_condition=text_condition,
+ speech_condition=speech_condition,
+ padding_mask=padding_mask,
+ **kwargs
+ )
+ else:
+ assert t.dim() == 0
+
+ x = torch.cat([x] * 2, dim=0)
+ padding_mask = torch.cat([padding_mask] * 2, dim=0)
+
+ text_condition = torch.cat(
+ [torch.zeros_like(text_condition), text_condition], dim=0
+ )
+
+ if t > 0.5:
+ speech_condition = torch.cat(
+ [torch.zeros_like(speech_condition), speech_condition], dim=0
+ )
+ else:
+ guidance_scale = guidance_scale * 2
+ speech_condition = torch.cat(
+ [speech_condition, speech_condition], dim=0
+ )
+
+ data_uncond, data_cond = self.model_func(
+ t=t,
+ xt=x,
+ text_condition=text_condition,
+ speech_condition=speech_condition,
+ padding_mask=padding_mask,
+ **kwargs
+ ).chunk(2, dim=0)
+
+ res = (1 + guidance_scale) * data_cond - guidance_scale * data_uncond
+ return res
+
+
+class DistillDiffusionModel(DiffusionModel):
+ """A wrapper of distilled diffusion models for inference.
+ Args:
+ model: The distilled diffusion model.
+ func_name: The function name to call.
+ """
+
+ def __init__(
+ self,
+ model: torch.nn.Module,
+ func_name: str = "forward_fm_decoder",
+ ):
+ super().__init__(model=model, func_name=func_name)
+
+ def forward(
+ self,
+ t: torch.Tensor,
+ x: torch.Tensor,
+ text_condition: torch.Tensor,
+ speech_condition: torch.Tensor,
+ padding_mask: Optional[torch.Tensor] = None,
+ guidance_scale: Union[float, torch.Tensor] = 0.0,
+ **kwargs
+ ) -> torch.Tensor:
+ """
+ Forward function that Handles the classifier-free guidance.
+ Args:
+ t: The current timestep, a tensor of a single float.
+ x: The initial value, with the shape (batch, seq_len, emb_dim).
+ text_condition: The text_condition of the diffision model, with
+ the shape (batch, seq_len, emb_dim).
+ speech_condition: The speech_condition of the diffision model, with the
+ shape (batch, seq_len, emb_dim).
+ padding_mask: The mask for padding; True means masked position, with the
+ shape (batch, seq_len).
+ guidance_scale: The scale of classifier-free guidance, a float or a tensor
+ of shape (batch, 1, 1).
+ Retrun:
+ The prediction with the shape (batch, seq_len, emb_dim).
+ """
+ if not torch.is_tensor(guidance_scale):
+ guidance_scale = torch.tensor(
+ guidance_scale, dtype=t.dtype, device=t.device
+ )
+ return self.model_func(
+ t=t,
+ xt=x,
+ text_condition=text_condition,
+ speech_condition=speech_condition,
+ padding_mask=padding_mask,
+ guidance_scale=guidance_scale,
+ **kwargs
+ )
+
+
+class EulerSolver:
+ def __init__(
+ self,
+ model: torch.nn.Module,
+ func_name: str = "forward_fm_decoder",
+ ):
+ """Construct a Euler Solver
+ Args:
+ model: The diffusion model.
+ func_name: The function name to call.
+ """
+
+ self.model = DiffusionModel(model, func_name=func_name)
+
+ def sample(
+ self,
+ x: torch.Tensor,
+ text_condition: torch.Tensor,
+ speech_condition: torch.Tensor,
+ padding_mask: torch.Tensor,
+ num_step: int = 10,
+ guidance_scale: Union[float, torch.Tensor] = 0.0,
+ t_start: float = 0.0,
+ t_end: float = 1.0,
+ t_shift: float = 1.0,
+ **kwargs
+ ) -> torch.Tensor:
+ """
+ Compute the sample at time `t_end` by Euler Solver.
+ Args:
+ x: The initial value at time `t_start`, with the shape (batch, seq_len,
+ emb_dim).
+ text_condition: The text condition of the diffision mode, with the
+ shape (batch, seq_len, emb_dim).
+ speech_condition: The speech condition of the diffision model, with the
+ shape (batch, seq_len, emb_dim).
+ padding_mask: The mask for padding; True means masked position, with the
+ shape (batch, seq_len).
+ num_step: The number of ODE steps.
+ guidance_scale: The scale for classifier-free guidance, which is
+ a float or a tensor with the shape (batch, 1, 1).
+ t_start: the start timestep in the range of [0, 1].
+ t_end: the end time_step in the range of [0, 1].
+ t_shift: shift the t toward smaller numbers so that the sampling
+ will emphasize low SNR region. Should be in the range of (0, 1].
+ The shifting will be more significant when the number is smaller.
+
+ Returns:
+ The approximated solution at time `t_end`.
+ """
+ device = x.device
+ assert isinstance(t_start, float) and isinstance(t_end, float)
+
+ timesteps = get_time_steps(
+ t_start=t_start,
+ t_end=t_end,
+ num_step=num_step,
+ t_shift=t_shift,
+ device=device,
+ )
+
+ for step in range(num_step):
+ v = self.model(
+ t=timesteps[step],
+ x=x,
+ text_condition=text_condition,
+ speech_condition=speech_condition,
+ padding_mask=padding_mask,
+ guidance_scale=guidance_scale,
+ **kwargs
+ )
+ x = x + v * (timesteps[step + 1] - timesteps[step])
+ return x
+
+
+class DistillEulerSolver(EulerSolver):
+ def __init__(
+ self,
+ model: torch.nn.Module,
+ func_name: str = "forward_fm_decoder",
+ ):
+ """Construct a Euler Solver for distilled diffusion models.
+ Args:
+ model: The diffusion model.
+ """
+ self.model = DistillDiffusionModel(model, func_name=func_name)
+
+
+def get_time_steps(
+ t_start: float = 0.0,
+ t_end: float = 1.0,
+ num_step: int = 10,
+ t_shift: float = 1.0,
+ device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+ """Compute the intermediate time steps for sampling.
+
+ Args:
+ t_start: The starting time of the sampling (default is 0).
+ t_end: The starting time of the sampling (default is 1).
+ num_step: The number of sampling.
+ t_shift: shift the t toward smaller numbers so that the sampling
+ will emphasize low SNR region. Should be in the range of (0, 1].
+ The shifting will be more significant when the number is smaller.
+ device: A torch device.
+ Returns:
+ The time step with the shape (num_step + 1,).
+ """
+
+ timesteps = torch.linspace(t_start, t_end, num_step + 1).to(device)
+
+ timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps)
+
+ return timesteps
diff --git a/zipvoice/models/modules/zipformer.py b/zipvoice/models/modules/zipformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c00bd144dc7b5c82eec1330d010356d2eabd4ca
--- /dev/null
+++ b/zipvoice/models/modules/zipformer.py
@@ -0,0 +1,1680 @@
+#!/usr/bin/env python3
+# Copyright 2022-2024 Xiaomi Corp. (authors: Daniel Povey,
+# Zengwei Yao,
+# Wei Kang
+# Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import logging
+import math
+import random
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import Tensor, nn
+
+from zipvoice.models.modules.scaling import (
+ ActivationDropoutAndLinear,
+ Balancer,
+ BiasNorm,
+ Dropout2,
+ FloatLike,
+ Identity,
+ ScaledLinear,
+ ScheduledFloat,
+ SwooshR,
+ Whiten,
+ limit_param_value,
+ penalize_abs_values_gt,
+ softmax,
+)
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+ """Create sinusoidal timestep embeddings.
+
+ :param timesteps: shape of (N) or (N, T)
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an Tensor of positional embeddings. shape of (N, dim) or (T, N, dim)
+ """
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period)
+ * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
+ / half
+ )
+
+ if timesteps.dim() == 2:
+ timesteps = timesteps.transpose(0, 1) # (N, T) -> (T, N)
+
+ args = timesteps[..., None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[..., :1])], dim=-1)
+ return embedding
+
+
+class TTSZipformer(nn.Module):
+ """
+ Args:
+
+ Note: all "int or Tuple[int]" arguments below will be treated as lists of the same
+ length as downsampling_factor if they are single ints or one-element tuples.
+ The length of downsampling_factor defines the number of stacks.
+
+ downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
+ Note: this is in addition to the downsampling factor of 2 that is applied in
+ the frontend (self.encoder_embed).
+ encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks,
+ one per encoder stack.
+ num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack
+ query_head_dim (int or Tuple[int]): dimension of query and key per attention
+ head: per stack, if a tuple..
+ pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection
+ per attention head
+ value_head_dim (int or Tuple[int]): dimension of value in each attention head
+ num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
+ Must be at least 4.
+ feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
+ cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module
+
+ pos_dim (int): the dimension of each positional-encoding vector prior to
+ projection, e.g. 128.
+
+ dropout (float): dropout rate
+ warmup_batches (float): number of batches to warm up over; this controls
+ dropout of encoder layers.
+ use_time_embed: (bool): if True, take time embedding as an additional input.
+ time_embed_dim: (int): the dimension of the time embedding.
+ use_guidance_scale_embed (bool): if True, take guidance scale embedding as
+ an additional input.
+ guidance_scale_embed_dim: (int): the dimension of the guidance scale embedding.
+ """
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ downsampling_factor: Union[int, Tuple[int]] = (2, 4),
+ num_encoder_layers: Union[int, Tuple[int]] = 4,
+ cnn_module_kernel: Union[int, Tuple[int]] = 31,
+ encoder_dim: int = 384,
+ query_head_dim: int = 24,
+ pos_head_dim: int = 4,
+ value_head_dim: int = 12,
+ num_heads: int = 8,
+ feedforward_dim: int = 1536,
+ pos_dim: int = 192,
+ dropout: FloatLike = None, # see code below for default
+ warmup_batches: float = 4000.0,
+ use_time_embed: bool = True,
+ time_embed_dim: int = 192,
+ use_guidance_scale_embed: bool = False,
+ guidance_scale_embed_dim: int = 192,
+ use_conv: bool = True,
+ ) -> None:
+ super(TTSZipformer, self).__init__()
+
+ if dropout is None:
+ dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
+ if isinstance(downsampling_factor, int):
+ downsampling_factor = (downsampling_factor,)
+
+ def _to_tuple(x):
+ """Converts a single int or a 1-tuple of an int to a tuple with the same
+ length as downsampling_factor"""
+ if isinstance(x, int):
+ x = (x,)
+ if len(x) == 1:
+ x = x * len(downsampling_factor)
+ else:
+ assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
+ return x
+
+ def _assert_downsampling_factor(factors):
+ """assert downsampling_factor follows u-net style"""
+ assert factors[0] == 1 and factors[-1] == 1
+
+ for i in range(1, len(factors) // 2 + 1):
+ assert factors[i] == factors[i - 1] * 2
+
+ for i in range(len(factors) // 2 + 1, len(factors)):
+ assert factors[i] * 2 == factors[i - 1]
+
+ _assert_downsampling_factor(downsampling_factor)
+ self.downsampling_factor = downsampling_factor # tuple
+ num_encoder_layers = _to_tuple(num_encoder_layers)
+ self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
+ self.encoder_dim = encoder_dim
+ self.num_encoder_layers = num_encoder_layers
+ self.query_head_dim = query_head_dim
+ self.value_head_dim = value_head_dim
+ self.num_heads = num_heads
+
+ self.use_time_embed = use_time_embed
+ self.use_guidance_scale_embed = use_guidance_scale_embed
+
+ self.time_embed_dim = time_embed_dim
+ if self.use_time_embed:
+ assert time_embed_dim != -1
+ else:
+ time_embed_dim = -1
+ self.guidance_scale_embed_dim = guidance_scale_embed_dim
+
+ self.in_proj = nn.Linear(in_dim, encoder_dim)
+ self.out_proj = nn.Linear(encoder_dim, out_dim)
+
+ # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
+ encoders = []
+
+ num_encoders = len(downsampling_factor)
+ for i in range(num_encoders):
+ encoder_layer = Zipformer2EncoderLayer(
+ embed_dim=encoder_dim,
+ pos_dim=pos_dim,
+ num_heads=num_heads,
+ query_head_dim=query_head_dim,
+ pos_head_dim=pos_head_dim,
+ value_head_dim=value_head_dim,
+ feedforward_dim=feedforward_dim,
+ use_conv=use_conv,
+ cnn_module_kernel=cnn_module_kernel[i],
+ dropout=dropout,
+ )
+
+ # For the segment of the warmup period, we let the Conv2dSubsampling
+ # layer learn something. Then we start to warm up the other encoders.
+ encoder = Zipformer2Encoder(
+ encoder_layer,
+ num_encoder_layers[i],
+ embed_dim=encoder_dim,
+ time_embed_dim=time_embed_dim,
+ pos_dim=pos_dim,
+ warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
+ warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
+ final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
+ )
+
+ if downsampling_factor[i] != 1:
+ encoder = DownsampledZipformer2Encoder(
+ encoder,
+ dim=encoder_dim,
+ downsample=downsampling_factor[i],
+ )
+
+ encoders.append(encoder)
+
+ self.encoders = nn.ModuleList(encoders)
+ if self.use_time_embed:
+ self.time_embed = nn.Sequential(
+ nn.Linear(time_embed_dim, time_embed_dim * 2),
+ SwooshR(),
+ nn.Linear(time_embed_dim * 2, time_embed_dim),
+ )
+ else:
+ self.time_embed = None
+
+ if self.use_guidance_scale_embed:
+ self.guidance_scale_embed = ScaledLinear(
+ guidance_scale_embed_dim,
+ time_embed_dim,
+ bias=False,
+ initial_scale=0.1,
+ )
+ else:
+ self.guidance_scale_embed = None
+
+ def forward(
+ self,
+ x: Tensor,
+ t: Optional[Tensor] = None,
+ padding_mask: Optional[Tensor] = None,
+ guidance_scale: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ x:
+ The input tensor. Its shape is (batch_size, seq_len, feature_dim).
+ t:
+ A t tensor of shape (batch_size,) or (batch_size, seq_len)
+ padding_mask:
+ The mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+ guidance_scale:
+ The guidance scale in classifier-free guidance of distillation model.
+ Returns:
+ Return the output embeddings. its shape is
+ (batch_size, output_seq_len, encoder_dim)
+ """
+ x = x.permute(1, 0, 2)
+ x = self.in_proj(x)
+
+ if t is not None:
+ assert t.dim() == 1 or t.dim() == 2, t.shape
+ time_emb = timestep_embedding(t, self.time_embed_dim)
+ if guidance_scale is not None:
+ assert (
+ guidance_scale.dim() == 1 or guidance_scale.dim() == 2
+ ), guidance_scale.shape
+ guidance_scale_emb = self.guidance_scale_embed(
+ timestep_embedding(guidance_scale, self.guidance_scale_embed_dim)
+ )
+ time_emb = time_emb + guidance_scale_emb
+ time_emb = self.time_embed(time_emb)
+ else:
+ time_emb = None
+
+ attn_mask = None
+
+ for i, module in enumerate(self.encoders):
+ x = module(
+ x,
+ time_emb=time_emb,
+ src_key_padding_mask=padding_mask,
+ attn_mask=attn_mask,
+ )
+ x = self.out_proj(x)
+ x = x.permute(1, 0, 2)
+ return x
+
+
+def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
+ return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x)
+
+
+class Zipformer2EncoderLayer(nn.Module):
+ """
+ Args:
+ embed_dim: the number of expected features in the input (required).
+ nhead: the number of heads in the multiheadattention models (required).
+ feedforward_dim: the dimension of the feedforward network model (required).
+ dropout: the dropout value (default=0.1).
+ cnn_module_kernel (int): Kernel size of convolution module (default=31).
+
+ Examples::
+ >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
+ >>> src = torch.rand(10, 32, 512)
+ >>> pos_emb = torch.rand(32, 19, 512)
+ >>> out = encoder_layer(src, pos_emb)
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ pos_dim: int,
+ num_heads: int,
+ query_head_dim: int,
+ pos_head_dim: int,
+ value_head_dim: int,
+ feedforward_dim: int,
+ dropout: FloatLike = 0.1,
+ cnn_module_kernel: int = 31,
+ use_conv: bool = True,
+ attention_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
+ ),
+ conv_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
+ ),
+ const_attention_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.25), (4000.0, 0.025), default=0
+ ),
+ ff2_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
+ ),
+ ff3_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
+ ),
+ bypass_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.5), (4000.0, 0.02), default=0
+ ),
+ ) -> None:
+ super(Zipformer2EncoderLayer, self).__init__()
+ self.embed_dim = embed_dim
+
+ # self.bypass implements layer skipping as well as bypass.
+ self.bypass = BypassModule(
+ embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0
+ )
+ # bypass_mid is bypass used in the middle of the layer.
+ self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0)
+
+ # skip probability for dynamic modules (meaning: anything but feedforward).
+ self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
+ # an additional skip probability that applies to ConvModule to stop it from
+ # contributing too much early on.
+ self.conv_skip_rate = copy.deepcopy(conv_skip_rate)
+
+ # ff2_skip_rate is to prevent the ff2 module from having output that's too big
+ # compared to its residual.
+ self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
+ self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate)
+
+ self.const_attention_rate = copy.deepcopy(const_attention_rate)
+
+ self.self_attn_weights = RelPositionMultiheadAttentionWeights(
+ embed_dim,
+ pos_dim=pos_dim,
+ num_heads=num_heads,
+ query_head_dim=query_head_dim,
+ pos_head_dim=pos_head_dim,
+ dropout=0.0,
+ )
+
+ self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim)
+
+ self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim)
+
+ self.feed_forward1 = FeedforwardModule(
+ embed_dim, (feedforward_dim * 3) // 4, dropout
+ )
+
+ self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout)
+
+ self.feed_forward3 = FeedforwardModule(
+ embed_dim, (feedforward_dim * 5) // 4, dropout
+ )
+
+ self.nonlin_attention = NonlinAttention(
+ embed_dim, hidden_channels=3 * embed_dim // 4
+ )
+
+ self.use_conv = use_conv
+
+ if self.use_conv:
+ self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel)
+
+ self.conv_module2 = ConvolutionModule(embed_dim, cnn_module_kernel)
+
+ self.norm = BiasNorm(embed_dim)
+
+ self.balancer1 = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.45,
+ max_positive=0.55,
+ min_abs=0.2,
+ max_abs=4.0,
+ )
+
+ # balancer for output of NonlinAttentionModule
+ self.balancer_na = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.3,
+ max_positive=0.7,
+ min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
+ prob=0.05, # out of concern for memory usage
+ )
+
+ # balancer for output of feedforward2, prevent it from staying too
+ # small. give this a very small probability, even at the start of
+ # training, it's to fix a rare problem and it's OK to fix it slowly.
+ self.balancer_ff2 = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.3,
+ max_positive=0.7,
+ min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0),
+ max_abs=2.0,
+ prob=0.05,
+ )
+
+ self.balancer_ff3 = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.3,
+ max_positive=0.7,
+ min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0),
+ max_abs=4.0,
+ prob=0.05,
+ )
+
+ self.whiten = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(4.0, ratio=3.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ self.balancer2 = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.45,
+ max_positive=0.55,
+ min_abs=0.1,
+ max_abs=4.0,
+ )
+
+ def get_sequence_dropout_mask(
+ self, x: Tensor, dropout_rate: float
+ ) -> Optional[Tensor]:
+ if (
+ dropout_rate == 0.0
+ or not self.training
+ or torch.jit.is_scripting()
+ or torch.jit.is_tracing()
+ ):
+ return None
+ batch_size = x.shape[1]
+ mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
+ return mask
+
+ def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor:
+ """
+ Apply sequence-level dropout to x.
+ x shape: (seq_len, batch_size, embed_dim)
+ """
+ dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate)
+ if dropout_mask is None:
+ return x
+ else:
+ return x * dropout_mask
+
+ def forward(
+ self,
+ src: Tensor,
+ pos_emb: Tensor,
+ time_emb: Optional[Tensor] = None,
+ attn_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """
+ Pass the input through the encoder layer.
+ Args:
+ src: the sequence to the encoder (required):
+ shape (seq_len, batch_size, embedding_dim).
+ pos_emb: (1, 2*seq_len-1, pos_emb_dim) or
+ (batch_size, 2*seq_len-1, pos_emb_dim)
+ time_emb: the embedding representing the current timestep
+ shape (batch_size, embedding_dim) or (seq_len, batch_size, embedding_dim).
+ attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len)
+ or (seq_len, seq_len), interpreted as (batch_size, tgt_seq_len, src_seq_len)
+ or (tgt_seq_len, src_seq_len). True means masked position. May be None.
+ src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len);
+ True means masked position. May be None.
+
+ Returns:
+ A tensor which has the same shape as src
+ """
+ src_orig = src
+
+ # dropout rate for non-feedforward submodules
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ attention_skip_rate = 0.0
+ else:
+ attention_skip_rate = (
+ float(self.attention_skip_rate) if self.training else 0.0
+ )
+
+ # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+ attn_weights = self.self_attn_weights(
+ src,
+ pos_emb=pos_emb,
+ attn_mask=attn_mask,
+ key_padding_mask=src_key_padding_mask,
+ )
+ if time_emb is not None:
+
+ src = src + time_emb
+
+ src = src + self.feed_forward1(src)
+
+ self_attn_dropout_mask = self.get_sequence_dropout_mask(
+ src, attention_skip_rate
+ )
+
+ selected_attn_weights = attn_weights[0:1]
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ pass
+ elif self.training and random.random() < float(self.const_attention_rate):
+ # Make attention weights constant. The intention is to
+ # encourage these modules to do something similar to an
+ # averaging-over-time operation.
+ # only need the mask, can just use the 1st one and expand later
+ selected_attn_weights = selected_attn_weights[0:1]
+ selected_attn_weights = (selected_attn_weights > 0.0).to(
+ selected_attn_weights.dtype
+ )
+ selected_attn_weights = selected_attn_weights * (
+ 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)
+ )
+
+ na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))
+
+ src = src + (
+ na if self_attn_dropout_mask is None else na * self_attn_dropout_mask
+ )
+
+ self_attn = self.self_attn1(src, attn_weights)
+
+ src = src + (
+ self_attn
+ if self_attn_dropout_mask is None
+ else self_attn * self_attn_dropout_mask
+ )
+
+ if self.use_conv:
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ conv_skip_rate = 0.0
+ else:
+ conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
+
+ if time_emb is not None:
+ src = src + time_emb
+
+ src = src + self.sequence_dropout(
+ self.conv_module1(
+ src,
+ src_key_padding_mask=src_key_padding_mask,
+ ),
+ conv_skip_rate,
+ )
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ ff2_skip_rate = 0.0
+ else:
+ ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
+ src = src + self.sequence_dropout(
+ self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate
+ )
+
+ # bypass in the middle of the layer.
+ src = self.bypass_mid(src_orig, src)
+
+ self_attn = self.self_attn2(src, attn_weights)
+
+ src = src + (
+ self_attn
+ if self_attn_dropout_mask is None
+ else self_attn * self_attn_dropout_mask
+ )
+
+ if self.use_conv:
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ conv_skip_rate = 0.0
+ else:
+ conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
+
+ if time_emb is not None:
+ src = src + time_emb
+
+ src = src + self.sequence_dropout(
+ self.conv_module2(
+ src,
+ src_key_padding_mask=src_key_padding_mask,
+ ),
+ conv_skip_rate,
+ )
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ ff3_skip_rate = 0.0
+ else:
+ ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
+ src = src + self.sequence_dropout(
+ self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate
+ )
+
+ src = self.balancer1(src)
+ src = self.norm(src)
+
+ src = self.bypass(src_orig, src)
+
+ src = self.balancer2(src)
+ src = self.whiten(src)
+
+ return src
+
+
+class Zipformer2Encoder(nn.Module):
+ r"""Zipformer2Encoder is a stack of N encoder layers
+
+ Args:
+ encoder_layer: an instance of the Zipformer2EncoderLayer() class (required).
+ num_layers: the number of sub-encoder-layers in the encoder (required).
+ pos_dim: the dimension for the relative positional encoding
+
+ Examples::
+ >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
+ >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6)
+ >>> src = torch.rand(10, 32, 512)
+ >>> out = zipformer_encoder(src)
+ """
+
+ def __init__(
+ self,
+ encoder_layer: nn.Module,
+ num_layers: int,
+ embed_dim: int,
+ time_embed_dim: int,
+ pos_dim: int,
+ warmup_begin: float,
+ warmup_end: float,
+ initial_layerdrop_rate: float = 0.5,
+ final_layerdrop_rate: float = 0.05,
+ ) -> None:
+ super().__init__()
+ self.encoder_pos = CompactRelPositionalEncoding(
+ pos_dim, dropout_rate=0.15, length_factor=1.0
+ )
+ if time_embed_dim != -1:
+ self.time_emb = nn.Sequential(
+ SwooshR(),
+ nn.Linear(time_embed_dim, embed_dim),
+ )
+ else:
+ self.time_emb = None
+
+ self.layers = nn.ModuleList(
+ [copy.deepcopy(encoder_layer) for i in range(num_layers)]
+ )
+ self.num_layers = num_layers
+
+ assert 0 <= warmup_begin <= warmup_end
+
+ delta = (1.0 / num_layers) * (warmup_end - warmup_begin)
+ cur_begin = warmup_begin # interpreted as a training batch index
+ for i in range(num_layers):
+ cur_end = cur_begin + delta
+ self.layers[i].bypass.skip_rate = ScheduledFloat(
+ (cur_begin, initial_layerdrop_rate),
+ (cur_end, final_layerdrop_rate),
+ default=0.0,
+ )
+ cur_begin = cur_end
+
+ def forward(
+ self,
+ src: Tensor,
+ time_emb: Optional[Tensor] = None,
+ attn_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ r"""Pass the input through the encoder layers in turn.
+
+ Args:
+ src: the sequence to the encoder (required):
+ shape (seq_len, batch_size, embedding_dim).
+ time_emb: the embedding representing the current timestep:
+ shape (batch_size, embedding_dim)
+ or (seq_len, batch_size, embedding_dim) .
+ attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len)
+ or (seq_len, seq_len), interpreted as
+ (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+ True means masked position. May be None.
+ src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len);
+ True means masked position. May be None.
+
+ Returns: a Tensor with the same shape as src.
+ """
+ pos_emb = self.encoder_pos(src)
+ if self.time_emb is not None:
+ assert time_emb is not None
+ time_emb = self.time_emb(time_emb)
+ else:
+ assert time_emb is None
+
+ output = src
+
+ for i, mod in enumerate(self.layers):
+ output = mod(
+ output,
+ pos_emb,
+ time_emb=time_emb,
+ attn_mask=attn_mask,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+
+ return output
+
+
+class BypassModule(nn.Module):
+ """
+ An nn.Module that implements a learnable bypass scale, and also randomized
+ per-sequence layer-skipping. The bypass is limited during early stages of training
+ to be close to "straight-through", i.e. to not do the bypass operation much
+ initially, in order to force all the modules to learn something.
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ skip_rate: FloatLike = 0.0,
+ straight_through_rate: FloatLike = 0.0,
+ scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
+ scale_max: FloatLike = 1.0,
+ ):
+ super().__init__()
+ self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
+ self.skip_rate = copy.deepcopy(skip_rate)
+ self.straight_through_rate = copy.deepcopy(straight_through_rate)
+ self.scale_min = copy.deepcopy(scale_min)
+ self.scale_max = copy.deepcopy(scale_max)
+
+ def _get_bypass_scale(self, batch_size: int):
+ # returns bypass-scale of shape (num_channels,),
+ # or (batch_size, num_channels,). This is actually the
+ # scale on the non-residual term, so 0 corresponds to bypassing
+ # this module.
+ if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
+ return self.bypass_scale
+ else:
+ ans = limit_param_value(
+ self.bypass_scale,
+ min=float(self.scale_min),
+ max=float(self.scale_max),
+ )
+ skip_rate = float(self.skip_rate)
+ if skip_rate != 0.0:
+ mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate
+ ans = ans * mask
+ # now ans is of shape (batch_size, num_channels), and is zero for
+ # sequences on which we have randomly chosen to do layer-skipping.
+ straight_through_rate = float(self.straight_through_rate)
+ if straight_through_rate != 0.0:
+ mask = (
+ torch.rand((batch_size, 1), device=ans.device)
+ < straight_through_rate
+ )
+ ans = torch.maximum(ans, mask.to(ans.dtype))
+ return ans
+
+ def forward(self, src_orig: Tensor, src: Tensor):
+ """
+ Args: src_orig and src are both of shape (seq_len, batch_size, num_channels)
+ Returns: something with the same shape as src and src_orig
+ """
+ bypass_scale = self._get_bypass_scale(src.shape[1])
+ return src_orig + (src - src_orig) * bypass_scale
+
+
+class DownsampledZipformer2Encoder(nn.Module):
+ r"""
+ DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame
+ rate, after convolutional downsampling, and then upsampled again at the output, and
+ combined with the origin input, so that the output has the same shape as the input.
+ """
+
+ def __init__(self, encoder: nn.Module, dim: int, downsample: int):
+ super(DownsampledZipformer2Encoder, self).__init__()
+ self.downsample_factor = downsample
+ self.downsample = SimpleDownsample(downsample)
+ self.num_layers = encoder.num_layers
+ self.encoder = encoder
+ self.upsample = SimpleUpsample(downsample)
+ self.out_combiner = BypassModule(dim, straight_through_rate=0)
+
+ def forward(
+ self,
+ src: Tensor,
+ time_emb: Optional[Tensor] = None,
+ attn_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ r"""Downsample, go through encoder, upsample.
+
+ Args:
+ src: the sequence to the encoder (required):
+ shape (seq_len, batch_size, embedding_dim).
+ time_emb: the embedding representing the current timestep:
+ shape (batch_size, embedding_dim)
+ or (seq_len, batch_size, embedding_dim) .
+ feature_mask: something that broadcasts with src, that we'll multiply `src`
+ by at every layer: if a Tensor, likely of shape
+ (seq_len, batch_size, embedding_dim)
+ attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len)
+ or (seq_len, seq_len), interpreted as
+ (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+ True means masked position. May be None.
+ src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len);
+ True means masked position. May be None.
+
+ Returns: a Tensor with the same shape as src.
+ """
+ src_orig = src
+ src = self.downsample(src)
+ ds = self.downsample_factor
+ if time_emb is not None and time_emb.dim() == 3:
+ time_emb = time_emb[::ds]
+ if attn_mask is not None:
+ attn_mask = attn_mask[::ds, ::ds]
+ if src_key_padding_mask is not None:
+ src_key_padding_mask = src_key_padding_mask[..., ::ds]
+
+ src = self.encoder(
+ src,
+ time_emb=time_emb,
+ attn_mask=attn_mask,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+ src = self.upsample(src)
+ # remove any extra frames that are not a multiple of downsample_factor
+ src = src[: src_orig.shape[0]]
+
+ return self.out_combiner(src_orig, src)
+
+
+class SimpleDownsample(torch.nn.Module):
+ """
+ Does downsampling with attention, by weighted sum.
+ """
+
+ def __init__(self, downsample: int):
+ super(SimpleDownsample, self).__init__()
+
+ self.bias = nn.Parameter(torch.zeros(downsample))
+
+ self.name = None # will be set from training code
+
+ self.downsample = downsample
+
+ def forward(self, src: Tensor) -> Tensor:
+ """
+ x: (seq_len, batch_size, in_channels)
+ Returns a tensor of shape
+ ( (seq_len+downsample-1)//downsample, batch_size, channels)
+ """
+ (seq_len, batch_size, in_channels) = src.shape
+ ds = self.downsample
+ d_seq_len = (seq_len + ds - 1) // ds
+
+ # Pad to an exact multiple of self.downsample
+ # right-pad src, repeating the last element.
+ pad = d_seq_len * ds - seq_len
+ src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
+ src = torch.cat((src, src_extra), dim=0)
+ assert src.shape[0] == d_seq_len * ds
+
+ src = src.reshape(d_seq_len, ds, batch_size, in_channels)
+
+ weights = self.bias.softmax(dim=0)
+ # weights: (downsample, 1, 1)
+ weights = weights.unsqueeze(-1).unsqueeze(-1)
+
+ # ans1 is the first `in_channels` channels of the output
+ ans = (src * weights).sum(dim=1)
+
+ return ans
+
+
+class SimpleUpsample(torch.nn.Module):
+ """
+ A very simple form of upsampling that just repeats the input.
+ """
+
+ def __init__(self, upsample: int):
+ super(SimpleUpsample, self).__init__()
+ self.upsample = upsample
+
+ def forward(self, src: Tensor) -> Tensor:
+ """
+ x: (seq_len, batch_size, num_channels)
+ Returns a tensor of shape
+ ( (seq_len*upsample), batch_size, num_channels)
+ """
+ upsample = self.upsample
+ (seq_len, batch_size, num_channels) = src.shape
+ src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
+ src = src.reshape(seq_len * upsample, batch_size, num_channels)
+ return src
+
+
+class CompactRelPositionalEncoding(torch.nn.Module):
+ """
+ Relative positional encoding module. This version is "compact" meaning it is able
+ to encode the important information about the relative position in a relatively
+ small number of dimensions. The goal is to make it so that small differences between
+ large relative offsets (e.g. 1000 vs. 1001) make very little difference to the
+ embedding. Such differences were potentially important when encoding absolute
+ position, but not important when encoding relative position because there is now no
+ need to compare two large offsets with each other.
+
+ Our embedding works by projecting the interval [-infinity,infinity] to a finite
+ interval using the atan() function, before doing the Fourier transform of that fixed
+ interval. The atan() function would compress the "long tails" too small, making it
+ hard to distinguish between different magnitudes of large offsets, so we use a
+ logarithmic function to compress large offsets to a smaller range before applying
+ atan(). Scalings are chosen in such a way that the embedding can clearly distinguish
+ individual offsets as long as they are quite close to the origin, e.g. abs(offset)
+ <= about sqrt(embedding_dim)
+
+
+ Args:
+ embed_dim: Embedding dimension.
+ dropout_rate: Dropout rate.
+ max_len: Maximum input length: just a heuristic for initialization.
+ length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
+ less weight to small differences of offset near the origin.
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ dropout_rate: FloatLike,
+ max_len: int = 1000,
+ length_factor: float = 1.0,
+ ) -> None:
+ """Construct a CompactRelPositionalEncoding object."""
+ super(CompactRelPositionalEncoding, self).__init__()
+ self.embed_dim = embed_dim
+ assert embed_dim % 2 == 0, embed_dim
+ self.dropout = Dropout2(dropout_rate)
+ self.pe = None
+ assert length_factor >= 1.0, length_factor
+ self.length_factor = length_factor
+ self.extend_pe(torch.tensor(0.0).expand(max_len))
+
+ def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None:
+ """Reset the positional encodings."""
+ T = x.size(0) + left_context_len
+
+ if self.pe is not None:
+ # self.pe contains both positive and negative parts
+ # the length of self.pe is 2 * input_len - 1
+ if self.pe.size(0) >= T * 2 - 1:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+
+ # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ]
+ x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1)
+
+ freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)
+
+ # `compression_length` this is arbitrary/heuristic, if it is larger we have more
+ # resolution for small time offsets but less resolution for large time offsets.
+ compression_length = self.embed_dim**0.5
+ # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity
+ # to infinity; but it does so more slowly than T for large absolute values of T.
+ # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which is
+ # important.
+ x_compressed = (
+ compression_length
+ * x.sign()
+ * ((x.abs() + compression_length).log() - math.log(compression_length))
+ )
+
+ # if self.length_factor == 1.0, then length_scale is chosen so that the
+ # FFT can exactly separate points close to the origin (T == 0). So this
+ # part of the formulation is not really heuristic.
+ # But empirically, for ASR at least, length_factor > 1.0 seems to work better.
+ length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi)
+
+ # note for machine implementations: if atan is not available, we can use:
+ # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2)
+ # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 ,
+ # atan(x))
+ x_atan = (x_compressed / length_scale).atan() # results between -pi and pi
+
+ cosines = (x_atan * freqs).cos()
+ sines = (x_atan * freqs).sin()
+
+ pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
+ pe[:, 0::2] = cosines
+ pe[:, 1::2] = sines
+ pe[:, -1] = 1.0 # for bias.
+
+ self.pe = pe.to(dtype=x.dtype)
+
+ def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor:
+ """Create positional encoding.
+
+ Args:
+ x (Tensor): Input tensor (time, batch, `*`).
+ left_context_len: (int): Length of cached left context.
+
+ Returns:
+ positional embedding, of shape (batch, left_context_len + 2*time-1, `*`).
+ """
+ self.extend_pe(x, left_context_len)
+ x_size_left = x.size(0) + left_context_len
+ # length of positive side: x.size(0) + left_context_len
+ # length of negative side: x.size(0)
+ pos_emb = self.pe[
+ self.pe.size(0) // 2
+ - x_size_left
+ + 1 : self.pe.size(0) // 2 # noqa E203
+ + x.size(0),
+ :,
+ ]
+ pos_emb = pos_emb.unsqueeze(0)
+ return self.dropout(pos_emb)
+
+
+class RelPositionMultiheadAttentionWeights(nn.Module):
+ r"""Module that computes multi-head attention weights with relative position
+ encoding. Various other modules consume the resulting attention weights:
+ see, for example, the SimpleAttention module which allows you to compute
+ conventional attention.
+
+ This is a quite heavily modified from: "Transformer-XL: Attentive Language
+ Models Beyond a Fixed-Length Context",
+ we have to write up the differences.
+
+
+ Args:
+ embed_dim: number of channels at the input to this module, e.g. 256
+ pos_dim: dimension of the positional encoding vectors, e.g. 128.
+ num_heads: number of heads to compute weights for, e.g. 8
+ query_head_dim: dimension of the query (and key), per head. e.g. 24.
+ pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
+ dropout: dropout probability for attn_output_weights. Default: 0.0.
+ pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
+ any given call to forward(), in training time.
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ pos_dim: int,
+ num_heads: int,
+ query_head_dim: int,
+ pos_head_dim: int,
+ dropout: float = 0.0,
+ pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
+ ) -> None:
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.query_head_dim = query_head_dim
+ self.pos_head_dim = pos_head_dim
+ self.dropout = dropout
+ self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
+ self.name = None # will be overwritten in training code; for diagnostics.
+
+ key_head_dim = query_head_dim
+ in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
+
+ # the initial_scale is supposed to take over the "scaling" factor of
+ # head_dim ** -0.5 that has been used in previous forms of attention,
+ # dividing it between the query and key. Note: this module is intended
+ # to be used with the ScaledAdam optimizer; with most other optimizers,
+ # it would be necessary to apply the scaling factor in the forward function.
+ self.in_proj = ScaledLinear(
+ embed_dim,
+ in_proj_dim,
+ bias=True,
+ initial_scale=query_head_dim**-0.25,
+ )
+
+ self.whiten_keys = Whiten(
+ num_groups=num_heads,
+ whitening_limit=_whitening_schedule(3.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.025,
+ )
+
+ # add a balancer for the keys that runs with very small probability, and
+ # tries to enforce that all dimensions have mean around zero. The
+ # weights produced by this module are invariant to adding a constant to
+ # the keys, so the derivative of the bias is mathematically zero; but
+ # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero
+ # bias because the small numerical roundoff tends to have a non-random
+ # sign. This module is intended to prevent that. Use a very small
+ # probability; that should be sufficient to fix the problem.
+ self.balance_keys = Balancer(
+ key_head_dim * num_heads,
+ channel_dim=-1,
+ min_positive=0.4,
+ max_positive=0.6,
+ min_abs=0.0,
+ max_abs=100.0,
+ prob=0.025,
+ )
+
+ # linear transformation for positional encoding.
+ self.linear_pos = ScaledLinear(
+ pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05
+ )
+
+ # the following are for diagnostics only, see --print-diagnostics option
+ self.copy_pos_query = Identity()
+ self.copy_query = Identity()
+
+ def forward(
+ self,
+ x: Tensor,
+ pos_emb: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ attn_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ r"""
+ Args:
+ x: input of shape (seq_len, batch_size, embed_dim)
+ pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
+ key_padding_mask: a bool tensor of shape (batch_size, seq_len).
+ Positions that are True in this mask will be ignored as sources in the
+ attention weighting.
+ attn_mask: mask of shape (seq_len, seq_len) or
+ (batch_size, seq_len, seq_len), interpreted as
+ ([batch_size,] tgt_seq_len, src_seq_len)
+ saying which positions are allowed to attend to which other positions.
+ Returns:
+ a tensor of attention weights, of
+ shape (hum_heads, batch_size, seq_len, seq_len)
+ interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len).
+ """
+ x = self.in_proj(x)
+ query_head_dim = self.query_head_dim
+ pos_head_dim = self.pos_head_dim
+ num_heads = self.num_heads
+
+ seq_len, batch_size, _ = x.shape
+
+ query_dim = query_head_dim * num_heads
+
+ # self-attention
+ q = x[..., 0:query_dim]
+ k = x[..., query_dim : 2 * query_dim]
+ # p is the position-encoding query
+ p = x[..., 2 * query_dim :]
+ assert p.shape[-1] == num_heads * pos_head_dim, (
+ p.shape[-1],
+ num_heads,
+ pos_head_dim,
+ )
+
+ q = self.copy_query(q) # for diagnostics only, does nothing.
+ k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass.
+ p = self.copy_pos_query(p) # for diagnostics only, does nothing.
+
+ q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
+ p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
+ k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)
+
+ # time1 refers to target, time2 refers to source.
+ q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim)
+ p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim)
+ k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2)
+
+ attn_scores = torch.matmul(q, k)
+
+ use_pos_scores = False
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ # We can't put random.random() in the same line
+ use_pos_scores = True
+ elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
+ use_pos_scores = True
+
+ if use_pos_scores:
+ pos_emb = self.linear_pos(pos_emb)
+ seq_len2 = 2 * seq_len - 1
+ pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
+ 2, 0, 3, 1
+ )
+ # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
+
+ # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head,
+ # batch, time1, seq_len2) [where seq_len2 represents relative position.]
+ pos_scores = torch.matmul(p, pos_emb)
+ # the following .as_strided() expression converts the last axis of
+ # pos_scores from relative to absolute position. I don't know whether I
+ # might have got the time-offsets backwards or not, but let this code define
+ # which way round it is supposed to be.
+ if torch.jit.is_tracing():
+ (num_heads, batch_size, time1, n) = pos_scores.shape
+ rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+ cols = torch.arange(seq_len)
+ rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+ indexes = rows + cols
+ pos_scores = pos_scores.reshape(-1, n)
+ pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
+ pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
+ else:
+ pos_scores = pos_scores.as_strided(
+ (num_heads, batch_size, seq_len, seq_len),
+ (
+ pos_scores.stride(0),
+ pos_scores.stride(1),
+ pos_scores.stride(2) - pos_scores.stride(3),
+ pos_scores.stride(3),
+ ),
+ storage_offset=pos_scores.stride(3) * (seq_len - 1),
+ )
+
+ attn_scores = attn_scores + pos_scores
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ pass
+ elif self.training and random.random() < 0.1:
+ # This is a harder way of limiting the attention scores to not be
+ # too large. It incurs a penalty if any of them has an absolute
+ # value greater than 50.0. this should be outside the normal range
+ # of the attention scores. We use this mechanism instead of, say,
+ # something added to the loss function involving the entropy,
+ # because once the entropy gets very small gradients through the
+ # softmax can become very small, and we'd get zero derivatives. The
+ # choices of 1.0e-04 as the scale on the penalty makes this
+ # mechanism vulnerable to the absolute scale of the loss function,
+ # but we view this as a failsafe to avoid "implausible" parameter
+ # values rather than a regularization method that should be active
+ # under normal circumstances.
+ attn_scores = penalize_abs_values_gt(
+ attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
+ )
+
+ assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
+
+ if attn_mask is not None:
+ assert attn_mask.dtype == torch.bool
+ # use -1000 to avoid nan's where attn_mask and key_padding_mask make
+ # all scores zero. It's important that this be large enough that exp(-1000)
+ # is exactly zero, for reasons related to const_attention_rate, it
+ # compares the final weights with zero.
+ attn_scores = attn_scores.masked_fill(attn_mask, -1000)
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.shape == (
+ batch_size,
+ seq_len,
+ ), key_padding_mask.shape
+ attn_scores = attn_scores.masked_fill(
+ key_padding_mask.unsqueeze(1),
+ -1000,
+ )
+
+ # We use our own version of softmax, defined in scaling.py, which should
+ # save a little of the memory used in backprop by, if we are in
+ # automatic mixed precision mode (amp / autocast), by only storing the
+ # half-precision output for backprop purposes.
+ attn_weights = softmax(attn_scores, dim=-1)
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ pass
+ elif random.random() < 0.001 and not self.training:
+ self._print_attn_entropy(attn_weights)
+
+ attn_weights = nn.functional.dropout(
+ attn_weights, p=self.dropout, training=self.training
+ )
+
+ return attn_weights
+
+ def _print_attn_entropy(self, attn_weights: Tensor):
+ # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+ (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
+
+ with torch.no_grad():
+ with torch.amp.autocast("cuda", enabled=False):
+ attn_weights = attn_weights.to(torch.float32)
+ attn_weights_entropy = (
+ -((attn_weights + 1.0e-20).log() * attn_weights)
+ .sum(dim=-1)
+ .mean(dim=(1, 2))
+ )
+ logging.debug(
+ f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}"
+ )
+
+
+class SelfAttention(nn.Module):
+ """
+ The simplest possible attention module. This one works with already-computed
+ attention weights, e.g. as computed by RelPositionMultiheadAttentionWeights.
+
+ Args:
+ embed_dim: the input and output embedding dimension
+ num_heads: the number of attention heads
+ value_head_dim: the value dimension per head
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ value_head_dim: int,
+ ) -> None:
+ super().__init__()
+ self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True)
+
+ self.out_proj = ScaledLinear(
+ num_heads * value_head_dim,
+ embed_dim,
+ bias=True,
+ initial_scale=0.05,
+ )
+
+ self.whiten = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(7.5, ratio=3.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ def forward(
+ self,
+ x: Tensor,
+ attn_weights: Tensor,
+ ) -> Tensor:
+ """
+ Args:
+ x: input tensor, of shape (seq_len, batch_size, embed_dim)
+ attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
+ with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect
+ attn_weights.sum(dim=-1) == 1.
+ Returns:
+ a tensor with the same shape as x.
+ """
+ (seq_len, batch_size, embed_dim) = x.shape
+ num_heads = attn_weights.shape[0]
+ assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
+
+ x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim)
+ x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+ # now x: (num_heads, batch_size, seq_len, value_head_dim)
+ value_head_dim = x.shape[-1]
+
+ # todo: see whether there is benefit in overriding matmul
+ x = torch.matmul(attn_weights, x)
+ # v: (num_heads, batch_size, seq_len, value_head_dim)
+
+ x = (
+ x.permute(2, 1, 0, 3)
+ .contiguous()
+ .view(seq_len, batch_size, num_heads * value_head_dim)
+ )
+
+ # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
+ x = self.out_proj(x)
+ x = self.whiten(x)
+
+ return x
+
+
+class FeedforwardModule(nn.Module):
+ """Feedforward module in TTSZipformer model."""
+
+ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike):
+ super(FeedforwardModule, self).__init__()
+ self.in_proj = nn.Linear(embed_dim, feedforward_dim)
+
+ self.hidden_balancer = Balancer(
+ feedforward_dim,
+ channel_dim=-1,
+ min_positive=0.3,
+ max_positive=1.0,
+ min_abs=0.75,
+ max_abs=5.0,
+ )
+
+ # shared_dim=0 means we share the dropout mask along the time axis
+ self.out_proj = ActivationDropoutAndLinear(
+ feedforward_dim,
+ embed_dim,
+ activation="SwooshL",
+ dropout_p=dropout,
+ dropout_shared_dim=0,
+ bias=True,
+ initial_scale=0.1,
+ )
+
+ self.out_whiten = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(7.5),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ def forward(self, x: Tensor):
+ x = self.in_proj(x)
+ x = self.hidden_balancer(x)
+ # out_proj contains SwooshL activation, then dropout, then linear.
+ x = self.out_proj(x)
+ x = self.out_whiten(x)
+ return x
+
+
+class NonlinAttention(nn.Module):
+ """This is like the ConvolutionModule, but refactored so that we use multiplication
+ by attention weights (borrowed from the attention module) in place of actual
+ convolution. We also took out the second nonlinearity, the one after the
+ attention mechanism.
+
+ Args:
+ channels (int): The number of channels of conv layers.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ hidden_channels: int,
+ ) -> None:
+ super().__init__()
+
+ self.hidden_channels = hidden_channels
+
+ self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)
+
+ # balancer that goes before the sigmoid. Have quite a large min_abs value, at
+ # 2.0, because we noticed that well-trained instances of this module have
+ # abs-value before the sigmoid starting from about 3, and poorly-trained
+ # instances of the module have smaller abs values before the sigmoid.
+ self.balancer = Balancer(
+ hidden_channels,
+ channel_dim=-1,
+ min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)),
+ max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)),
+ min_abs=0.5,
+ max_abs=5.0,
+ )
+ self.tanh = nn.Tanh()
+
+ self.identity1 = Identity() # for diagnostics.
+ self.identity2 = Identity() # for diagnostics.
+ self.identity3 = Identity() # for diagnostics.
+
+ self.out_proj = ScaledLinear(
+ hidden_channels, channels, bias=True, initial_scale=0.05
+ )
+
+ self.whiten1 = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(5.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ self.whiten2 = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(5.0, ratio=3.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ def forward(
+ self,
+ x: Tensor,
+ attn_weights: Tensor,
+ ) -> Tensor:
+ """.
+ Args:
+ x: a Tensor of shape (seq_len, batch_size, num_channels)
+ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
+ Returns:
+ a Tensor with the same shape as x
+ """
+ x = self.in_proj(x)
+
+ (seq_len, batch_size, _) = x.shape
+ hidden_channels = self.hidden_channels
+
+ s, x, y = x.chunk(3, dim=2)
+
+ # s will go through tanh.
+
+ s = self.balancer(s)
+ s = self.tanh(s)
+
+ s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
+ x = self.whiten1(x)
+ x = x * s
+ x = self.identity1(x) # diagnostics only, it's the identity.
+
+ (seq_len, batch_size, embed_dim) = x.shape
+ num_heads = attn_weights.shape[0]
+ assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
+
+ x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+ # now x: (num_heads, batch_size, seq_len, head_dim)
+ x = torch.matmul(attn_weights, x)
+ # now x: (num_heads, batch_size, seq_len, head_dim)
+ x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
+
+ y = self.identity2(y)
+ x = x * y
+ x = self.identity3(x)
+
+ x = self.out_proj(x)
+ x = self.whiten2(x)
+ return x
+
+
+class ConvolutionModule(nn.Module):
+ """ConvolutionModule in Zipformer2 model.
+
+ Args:
+ channels (int): The number of channels of conv layers.
+ kernel_size (int): Kernerl size of conv layers.
+ bias (bool): Whether to use bias in conv layers (default=True).
+
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int,
+ ) -> None:
+ """Construct a ConvolutionModule object."""
+ super(ConvolutionModule, self).__init__()
+ # kernerl_size should be a odd number for 'SAME' padding
+ assert (kernel_size - 1) % 2 == 0
+
+ bottleneck_dim = channels
+
+ self.in_proj = nn.Linear(
+ channels,
+ 2 * bottleneck_dim,
+ )
+ # the gradients on in_proj are a little noisy, likely to do with the
+ # sigmoid in glu.
+
+ # after in_proj we put x through a gated linear unit (nn.functional.glu). For
+ # most layers the normal rms value of channels of x seems to be in the range 1
+ # to 4, but sometimes, for some reason, for layer 0 the rms ends up being very
+ # large, between 50 and 100 for different channels. This will cause very peaky
+ # and sparse derivatives for the sigmoid gating function, which will tend to
+ # make the loss function not learn effectively. (for most layers the average
+ # absolute values are in the range 0.5..9.0, and the average p(x>0), i.e.
+ # positive proportion, at the output of pointwise_conv1.output is around 0.35 to
+ # 0.45 for different layers, which likely breaks down as 0.5 for the "linear"
+ # half and 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that
+ # if we constrain the rms values to a reasonable range via a constraint of
+ # max_abs=10.0, it will be in a better position to start learning something,
+ # i.e. to latch onto the correct range.
+ self.balancer1 = Balancer(
+ bottleneck_dim,
+ channel_dim=-1,
+ min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)),
+ max_positive=1.0,
+ min_abs=1.5,
+ max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
+ )
+
+ self.activation1 = Identity() # for diagnostics
+
+ self.sigmoid = nn.Sigmoid()
+
+ self.activation2 = Identity() # for diagnostics
+
+ assert kernel_size % 2 == 1
+
+ self.depthwise_conv = nn.Conv1d(
+ in_channels=bottleneck_dim,
+ out_channels=bottleneck_dim,
+ groups=bottleneck_dim,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ )
+
+ self.balancer2 = Balancer(
+ bottleneck_dim,
+ channel_dim=1,
+ min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
+ max_positive=1.0,
+ min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
+ max_abs=10.0,
+ )
+
+ self.whiten = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(7.5),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ self.out_proj = ActivationDropoutAndLinear(
+ bottleneck_dim,
+ channels,
+ activation="SwooshR",
+ dropout_p=0.0,
+ initial_scale=0.05,
+ )
+
+ def forward(
+ self,
+ x: Tensor,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """Compute convolution module.
+
+ Args:
+ x: Input tensor (#time, batch, channels).
+ src_key_padding_mask: the mask for the src keys per batch (optional):
+ (batch, #time), contains True in masked positions.
+
+ Returns:
+ Tensor: Output tensor (#time, batch, channels).
+
+ """
+
+ x = self.in_proj(x) # (time, batch, 2*channels)
+
+ x, s = x.chunk(2, dim=2)
+ s = self.balancer1(s)
+ s = self.sigmoid(s)
+ x = self.activation1(x) # identity.
+ x = x * s
+ x = self.activation2(x) # identity
+
+ # (time, batch, channels)
+
+ # exchange the temporal dimension and the feature dimension
+ x = x.permute(1, 2, 0) # (#batch, channels, time).
+
+ if src_key_padding_mask is not None:
+ x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+
+ x = self.depthwise_conv(x)
+
+ x = self.balancer2(x)
+ x = x.permute(2, 0, 1) # (time, batch, channels)
+
+ x = self.whiten(x) # (time, batch, channels)
+ x = self.out_proj(x) # (time, batch, channels)
+
+ return x
diff --git a/zipvoice/models/modules/zipformer_two_stream.py b/zipvoice/models/modules/zipformer_two_stream.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dd05094070d8ae94a72b5774ed33fa9a0217e1
--- /dev/null
+++ b/zipvoice/models/modules/zipformer_two_stream.py
@@ -0,0 +1,264 @@
+#!/usr/bin/env python3
+# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import Tensor, nn
+
+from zipvoice.models.modules.scaling import FloatLike, ScheduledFloat, SwooshR
+from zipvoice.models.modules.zipformer import (
+ DownsampledZipformer2Encoder,
+ TTSZipformer,
+ Zipformer2Encoder,
+ Zipformer2EncoderLayer,
+)
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+ """Create sinusoidal timestep embeddings.
+
+ :param timesteps: shape of (N) or (N, T)
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an Tensor of positional embeddings. shape of (N, dim) or (T, N, dim)
+ """
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period)
+ * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
+ / half
+ )
+
+ if timesteps.dim() == 2:
+ timesteps = timesteps.transpose(0, 1) # (N, T) -> (T, N)
+
+ args = timesteps[..., None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[..., :1])], dim=-1)
+ return embedding
+
+
+class TTSZipformerTwoStream(TTSZipformer):
+ """
+ Args:
+
+ Note: all "int or Tuple[int]" arguments below will be treated as lists of the same
+ length as downsampling_factor if they are single ints or one-element tuples.
+ The length of downsampling_factor defines the number of stacks.
+
+ downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
+ Note: this is in addition to the downsampling factor of 2 that is applied in
+ the frontend (self.encoder_embed).
+ encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks,
+ one per encoder stack.
+ num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack
+ query_head_dim (int or Tuple[int]): dimension of query and key per attention
+ head: per stack, if a tuple..
+ pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection
+ per attention head
+ value_head_dim (int or Tuple[int]): dimension of value in each attention head
+ num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
+ Must be at least 4.
+ feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
+ cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module
+
+ pos_dim (int): the dimension of each positional-encoding vector prior to
+ projection, e.g. 128.
+
+ dropout (float): dropout rate
+ warmup_batches (float): number of batches to warm up over; this controls
+ dropout of encoder layers.
+ use_time_embed: (bool): if True, do not take time embedding as additional input.
+ time_embed_dim: (int): the dimension of the time embedding.
+ """
+
+ def __init__(
+ self,
+ in_dim: Tuple[int],
+ out_dim: Tuple[int],
+ downsampling_factor: Tuple[int] = (2, 4),
+ num_encoder_layers: Union[int, Tuple[int]] = 4,
+ cnn_module_kernel: Union[int, Tuple[int]] = 31,
+ encoder_dim: int = 384,
+ query_head_dim: int = 24,
+ pos_head_dim: int = 4,
+ value_head_dim: int = 12,
+ num_heads: int = 8,
+ feedforward_dim: int = 1536,
+ pos_dim: int = 192,
+ dropout: FloatLike = None, # see code below for default
+ warmup_batches: float = 4000.0,
+ use_time_embed: bool = True,
+ time_embed_dim: int = 192,
+ use_conv: bool = True,
+ ) -> None:
+ nn.Module.__init__(self)
+
+ if dropout is None:
+ dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
+ if isinstance(downsampling_factor, int):
+ downsampling_factor = (downsampling_factor,)
+
+ def _to_tuple(x):
+ """Converts a single int or a 1-tuple of an int to a tuple with the same
+ length as downsampling_factor"""
+ if isinstance(x, int):
+ x = (x,)
+ if len(x) == 1:
+ x = x * len(downsampling_factor)
+ else:
+ assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
+ return x
+
+ def _assert_downsampling_factor(factors):
+ """assert downsampling_factor follows u-net style"""
+ assert factors[0] == 1 and factors[-1] == 1
+
+ for i in range(1, len(factors) // 2 + 1):
+ assert factors[i] == factors[i - 1] * 2
+
+ for i in range(len(factors) // 2 + 1, len(factors)):
+ assert factors[i] * 2 == factors[i - 1]
+
+ _assert_downsampling_factor(downsampling_factor)
+ self.downsampling_factor = downsampling_factor # tuple
+ num_encoder_layers = _to_tuple(num_encoder_layers)
+ self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
+ self.encoder_dim = encoder_dim
+ self.num_encoder_layers = num_encoder_layers
+ self.query_head_dim = query_head_dim
+ self.value_head_dim = value_head_dim
+ self.num_heads = num_heads
+
+ self.use_time_embed = use_time_embed
+
+ self.time_embed_dim = time_embed_dim
+ if self.use_time_embed:
+ assert time_embed_dim != -1
+ else:
+ time_embed_dim = -1
+
+ assert len(in_dim) == len(out_dim) == 2
+
+ self.in_dim = in_dim
+ self.in_proj = nn.ModuleList(
+ [nn.Linear(in_dim[0], encoder_dim), nn.Linear(in_dim[1], encoder_dim)]
+ )
+ self.out_dim = out_dim
+ self.out_proj = nn.ModuleList(
+ [nn.Linear(encoder_dim, out_dim[0]), nn.Linear(encoder_dim, out_dim[1])]
+ )
+
+ # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
+ encoders = []
+
+ num_encoders = len(downsampling_factor)
+ for i in range(num_encoders):
+ encoder_layer = Zipformer2EncoderLayer(
+ embed_dim=encoder_dim,
+ pos_dim=pos_dim,
+ num_heads=num_heads,
+ query_head_dim=query_head_dim,
+ pos_head_dim=pos_head_dim,
+ value_head_dim=value_head_dim,
+ feedforward_dim=feedforward_dim,
+ use_conv=use_conv,
+ cnn_module_kernel=cnn_module_kernel[i],
+ dropout=dropout,
+ )
+
+ # For the segment of the warmup period, we let the Conv2dSubsampling
+ # layer learn something. Then we start to warm up the other encoders.
+ encoder = Zipformer2Encoder(
+ encoder_layer,
+ num_encoder_layers[i],
+ embed_dim=encoder_dim,
+ time_embed_dim=time_embed_dim,
+ pos_dim=pos_dim,
+ warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
+ warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
+ final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
+ )
+
+ if downsampling_factor[i] != 1:
+ encoder = DownsampledZipformer2Encoder(
+ encoder,
+ dim=encoder_dim,
+ downsample=downsampling_factor[i],
+ )
+
+ encoders.append(encoder)
+
+ self.encoders = nn.ModuleList(encoders)
+ if self.use_time_embed:
+ self.time_embed = nn.Sequential(
+ nn.Linear(time_embed_dim, time_embed_dim * 2),
+ SwooshR(),
+ nn.Linear(time_embed_dim * 2, time_embed_dim),
+ )
+ else:
+ self.time_embed = None
+
+ def forward(
+ self,
+ x: Tensor,
+ t: Optional[Tensor] = None,
+ padding_mask: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ x:
+ The input tensor. Its shape is (batch_size, seq_len, feature_dim).
+ t:
+ A t tensor of shape (batch_size,) or (batch_size, seq_len)
+ padding_mask:
+ The mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+ Returns:
+ Return the output embeddings. its shape is
+ (batch_size, output_seq_len, encoder_dim)
+ """
+ assert x.size(2) in self.in_dim, f"{x.size(2)} in {self.in_dim}"
+ if x.size(2) == self.in_dim[0]:
+ index = 0
+ else:
+ index = 1
+ x = x.permute(1, 0, 2)
+ x = self.in_proj[index](x)
+
+ if t is not None:
+ assert t.dim() == 1 or t.dim() == 2, t.shape
+ time_emb = timestep_embedding(t, self.time_embed_dim)
+ time_emb = self.time_embed(time_emb)
+ else:
+ time_emb = None
+
+ attn_mask = None
+
+ for i, module in enumerate(self.encoders):
+ x = module(
+ x,
+ time_emb=time_emb,
+ src_key_padding_mask=padding_mask,
+ attn_mask=attn_mask,
+ )
+ x = self.out_proj[index](x)
+ x = x.permute(1, 0, 2)
+ return x
diff --git a/zipvoice/models/zipvoice.py b/zipvoice/models/zipvoice.py
new file mode 100644
index 0000000000000000000000000000000000000000..83780424d90d623696ebdf23a077928ca6846810
--- /dev/null
+++ b/zipvoice/models/zipvoice.py
@@ -0,0 +1,590 @@
+# Copyright 2024 Xiaomi Corp. (authors: Wei Kang
+# Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from zipvoice.models.modules.solver import EulerSolver
+from zipvoice.models.modules.zipformer import TTSZipformer
+from zipvoice.utils.common import (
+ condition_time_mask,
+ get_tokens_index,
+ make_pad_mask,
+ pad_labels,
+ prepare_avg_tokens_durations,
+)
+def score_tokens(A):
+ B = [9, 14, 18, 21, 27, 33, 37, 39, 42, 45, 50, 51, 52, 54, 58, 59, 61, 62, 63, 69, 73, 74, 79, 85, 99, 100, 102, 105, 119, 120, 121, 122, 123, 124, 141, 143, 144, 145, 146, 157, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 349, 350, 353, 356, 357, 358, 359]
+
+ total_score = 0
+ # Thêm 3 vào đầu và cuối
+ tokens = [3] + A + [3]
+
+ # Tách chuỗi theo số 3
+ segment = []
+ for t in tokens:
+ if t == 3:
+ if segment: # xử lý 1 đoạn
+ count = 0
+ for i in range(len(segment) - 1):
+ if (segment[i] in B and segment[i+1] not in B):
+ # print(f"{segment[i]} in B and {segment[i+1]} not in B)")
+ count += 1
+ if segment[-1] in B:
+ # print(f"{segment[-1]} in B")
+ count += 1
+ if count > 0:
+ total_score += 1 + (count - 1) * 0.85
+ segment = []
+ else:
+ segment.append(t)
+
+ return total_score
+
+
+
+class ZipVoice(nn.Module):
+ """The ZipVoice model."""
+
+ def __init__(
+ self,
+ fm_decoder_downsampling_factor: List[int] = [1, 2, 4, 2, 1],
+ fm_decoder_num_layers: List[int] = [2, 2, 4, 4, 4],
+ fm_decoder_cnn_module_kernel: List[int] = [31, 15, 7, 15, 31],
+ fm_decoder_feedforward_dim: int = 1536,
+ fm_decoder_num_heads: int = 4,
+ fm_decoder_dim: int = 512,
+ text_encoder_num_layers: int = 4,
+ text_encoder_feedforward_dim: int = 512,
+ text_encoder_cnn_module_kernel: int = 9,
+ text_encoder_num_heads: int = 4,
+ text_encoder_dim: int = 192,
+ time_embed_dim: int = 192,
+ text_embed_dim: int = 192,
+ query_head_dim: int = 32,
+ value_head_dim: int = 12,
+ pos_head_dim: int = 4,
+ pos_dim: int = 48,
+ feat_dim: int = 100,
+ vocab_size: int = 26,
+ pad_id: int = 0,
+ ):
+ """
+ Initialize the model with specified configuration parameters.
+
+ Args:
+ fm_decoder_downsampling_factor: List of downsampling factors for each layer
+ in the flow-matching decoder.
+ fm_decoder_num_layers: List of the number of layers for each block in the
+ flow-matching decoder.
+ fm_decoder_cnn_module_kernel: List of kernel sizes for CNN modules in the
+ flow-matching decoder.
+ fm_decoder_feedforward_dim: Dimension of the feedforward network in the
+ flow-matching decoder.
+ fm_decoder_num_heads: Number of attention heads in the flow-matching
+ decoder.
+ fm_decoder_dim: Hidden dimension of the flow-matching decoder.
+ text_encoder_num_layers: Number of layers in the text encoder.
+ text_encoder_feedforward_dim: Dimension of the feedforward network in the
+ text encoder.
+ text_encoder_cnn_module_kernel: Kernel size for the CNN module in the
+ text encoder.
+ text_encoder_num_heads: Number of attention heads in the text encoder.
+ text_encoder_dim: Hidden dimension of the text encoder.
+ time_embed_dim: Dimension of the time embedding.
+ text_embed_dim: Dimension of the text embedding.
+ query_head_dim: Dimension of the query attention head.
+ value_head_dim: Dimension of the value attention head.
+ pos_head_dim: Dimension of the position attention head.
+ pos_dim: Dimension of the positional encoding.
+ feat_dim: Dimension of the acoustic features.
+ vocab_size: Size of the vocabulary.
+ pad_id: ID used for padding tokens.
+ """
+ super().__init__()
+
+ self.fm_decoder = TTSZipformer(
+ in_dim=feat_dim * 3,
+ out_dim=feat_dim,
+ downsampling_factor=fm_decoder_downsampling_factor,
+ num_encoder_layers=fm_decoder_num_layers,
+ cnn_module_kernel=fm_decoder_cnn_module_kernel,
+ encoder_dim=fm_decoder_dim,
+ feedforward_dim=fm_decoder_feedforward_dim,
+ num_heads=fm_decoder_num_heads,
+ query_head_dim=query_head_dim,
+ pos_head_dim=pos_head_dim,
+ value_head_dim=value_head_dim,
+ pos_dim=pos_dim,
+ use_time_embed=True,
+ time_embed_dim=time_embed_dim,
+ )
+
+ self.text_encoder = TTSZipformer(
+ in_dim=text_embed_dim,
+ out_dim=feat_dim,
+ downsampling_factor=1,
+ num_encoder_layers=text_encoder_num_layers,
+ cnn_module_kernel=text_encoder_cnn_module_kernel,
+ encoder_dim=text_encoder_dim,
+ feedforward_dim=text_encoder_feedforward_dim,
+ num_heads=text_encoder_num_heads,
+ query_head_dim=query_head_dim,
+ pos_head_dim=pos_head_dim,
+ value_head_dim=value_head_dim,
+ pos_dim=pos_dim,
+ use_time_embed=False,
+ )
+
+ self.feat_dim = feat_dim
+ self.text_embed_dim = text_embed_dim
+ self.pad_id = pad_id
+
+ self.embed = nn.Embedding(vocab_size, text_embed_dim)
+ self.solver = EulerSolver(self, func_name="forward_fm_decoder")
+
+ def forward_fm_decoder(
+ self,
+ t: torch.Tensor,
+ xt: torch.Tensor,
+ text_condition: torch.Tensor,
+ speech_condition: torch.Tensor,
+ padding_mask: Optional[torch.Tensor] = None,
+ guidance_scale: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Compute velocity.
+ Args:
+ t: A tensor of shape (N, 1, 1) or a tensor of a float,
+ in the range of (0, 1).
+ xt: the input of the current timestep, including condition
+ embeddings and noisy acoustic features.
+ text_condition: the text condition embeddings, with the
+ shape (batch, seq_len, emb_dim).
+ speech_condition: the speech condition embeddings, with the
+ shape (batch, seq_len, emb_dim).
+ padding_mask: The mask for padding, True means masked
+ position, with the shape (N, T).
+ guidance_scale: The guidance scale in classifier-free guidance,
+ which is a tensor of shape (N, 1, 1) or a tensor of a float.
+
+ Returns:
+ predicted velocity, with the shape (batch, seq_len, emb_dim).
+ """
+
+ xt = torch.cat([xt, text_condition, speech_condition], dim=2)
+
+ assert t.dim() in (0, 3)
+ # Handle t with the shape (N, 1, 1):
+ # squeeze the last dimension if it's size is 1.
+ while t.dim() > 1 and t.size(-1) == 1:
+ t = t.squeeze(-1)
+ # Handle t with a single value: expand to the size of batch size.
+ if t.dim() == 0:
+ t = t.repeat(xt.shape[0])
+
+ if guidance_scale is not None:
+ while guidance_scale.dim() > 1 and guidance_scale.size(-1) == 1:
+ guidance_scale = guidance_scale.squeeze(-1)
+ if guidance_scale.dim() == 0:
+ guidance_scale = guidance_scale.repeat(xt.shape[0])
+
+ vt = self.fm_decoder(
+ x=xt, t=t, padding_mask=padding_mask, guidance_scale=guidance_scale
+ )
+ else:
+ vt = self.fm_decoder(x=xt, t=t, padding_mask=padding_mask)
+ return vt
+
+ def forward_text_embed(
+ self,
+ tokens: List[List[int]],
+ ):
+ """
+ Get the text embeddings.
+ Args:
+ tokens: a list of list of token ids.
+ Returns:
+ embed: the text embeddings, shape (batch, seq_len, emb_dim).
+ tokens_lens: the length of each token sequence, shape (batch,).
+ """
+ device = (
+ self.device if isinstance(self, DDP) else next(self.parameters()).device
+ )
+ tokens_padded = pad_labels(tokens, pad_id=self.pad_id, device=device) # (B, S)
+ embed = self.embed(tokens_padded) # (B, S, C)
+ tokens_lens = torch.tensor(
+ [len(token) for token in tokens], dtype=torch.int64, device=device
+ )
+ tokens_padding_mask = make_pad_mask(tokens_lens, embed.shape[1]) # (B, S)
+
+ embed = self.text_encoder(
+ x=embed, t=None, padding_mask=tokens_padding_mask
+ ) # (B, S, C)
+ return embed, tokens_lens
+
+ def forward_text_condition(
+ self,
+ embed: torch.Tensor,
+ tokens_lens: torch.Tensor,
+ features_lens: torch.Tensor,
+ ):
+ """
+ Get the text condition with the same length of the acoustic feature.
+ Args:
+ embed: the text embeddings, shape (batch, token_seq_len, emb_dim).
+ tokens_lens: the length of each token sequence, shape (batch,).
+ features_lens: the length of each acoustic feature sequence,
+ shape (batch,).
+ Returns:
+ text_condition: the text condition, shape
+ (batch, feature_seq_len, emb_dim).
+ padding_mask: the padding mask of text condition, shape
+ (batch, feature_seq_len).
+ """
+
+ num_frames = int(features_lens.max())
+
+ padding_mask = make_pad_mask(features_lens, max_len=num_frames) # (B, T)
+
+ tokens_durations = prepare_avg_tokens_durations(features_lens, tokens_lens)
+
+ tokens_index = get_tokens_index(tokens_durations, num_frames).to(
+ embed.device
+ ) # (B, T)
+
+ text_condition = torch.gather(
+ embed,
+ dim=1,
+ index=tokens_index.unsqueeze(-1).expand(
+ embed.size(0), num_frames, embed.size(-1)
+ ),
+ ) # (B, T, F)
+ return text_condition, padding_mask
+
+ def forward_text_train(
+ self,
+ tokens: List[List[int]],
+ features_lens: torch.Tensor,
+ ):
+ """
+ Process text for training, given text tokens and real feature lengths.
+ """
+ embed, tokens_lens = self.forward_text_embed(tokens)
+ text_condition, padding_mask = self.forward_text_condition(
+ embed, tokens_lens, features_lens
+ )
+ return (
+ text_condition,
+ padding_mask,
+ )
+
+ def forward_text_inference_gt_duration(
+ self,
+ tokens: List[List[int]],
+ features_lens: torch.Tensor,
+ prompt_tokens: List[List[int]],
+ prompt_features_lens: torch.Tensor,
+ ):
+ """
+ Process text for inference, given text tokens, real feature lengths and prompts.
+ """
+ tokens = [
+ prompt_token + token for prompt_token, token in zip(prompt_tokens, tokens)
+ ]
+ features_lens = prompt_features_lens + features_lens
+ embed, tokens_lens = self.forward_text_embed(tokens)
+ text_condition, padding_mask = self.forward_text_condition(
+ embed, tokens_lens, features_lens
+ )
+ return text_condition, padding_mask
+
+ def forward_text_inference_ratio_duration(
+ self,
+ tokens: List[List[int]],
+ prompt_tokens: List[List[int]],
+ prompt_features_lens: torch.Tensor,
+ speed: float,
+ num_space_text=[-1],
+ num_space_prompt=[-1]
+ ):
+ """
+ Process text for inference, length ước lượng theo số khoảng trắng (token=3) + 1.
+ """
+ device = (
+ self.device if isinstance(self, DDP) else next(self.parameters()).device
+ )
+
+ cat_tokens = [
+ prompt_token + token for prompt_token, token in zip(prompt_tokens, tokens)
+ ]
+
+ prompt_tokens_lens = torch.tensor(
+ [len(token) for token in prompt_tokens],
+ dtype=torch.int64,
+ device=device,
+ )
+
+ tokens_lens = torch.tensor(
+ [len(token) for token in tokens],
+ dtype=torch.int64,
+ device=device,
+ )
+
+ # 🔑 số khoảng trắng + 1
+ prompt_space_lens = torch.tensor(
+ [(score_tokens(token)-(token.count(3) - numsp - 1) + token.count(8)*0.5)*100 for token, numsp in zip(prompt_tokens,num_space_prompt)],
+ dtype=torch.int64,
+ device=device,
+ )
+ tokens_space_lens = torch.tensor(
+ [(score_tokens(token)-(token.count(3) - numsp - 1) + token.count(8)*0.5)*100 for token, numsp in zip (tokens,num_space_text)],
+ dtype=torch.int64,
+ device=device,
+ )
+ print("tokens_space_lens: ", tokens_space_lens)
+ print("prompt_space_lens: ", prompt_space_lens)
+
+ cat_embed, cat_tokens_lens = self.forward_text_embed(cat_tokens)
+ def alpha(prompt_space_lens: float) -> float:
+ if prompt_space_lens <= 1:
+ return 1.1
+ elif prompt_space_lens >= 30:
+ return 1.03
+ else:
+ return 1.1 - (prompt_space_lens - 1) / (30 - 1) * (1.1 - 1.03)
+
+ # frames_per_word * số_word_trong_text
+ features_lens = prompt_features_lens + torch.ceil(
+ (prompt_features_lens / prompt_space_lens * tokens_space_lens / speed * alpha(tokens_space_lens/100))
+ ).to(dtype=torch.int64)
+
+ text_condition, padding_mask = self.forward_text_condition(
+ cat_embed, cat_tokens_lens, features_lens
+ )
+ return text_condition, padding_mask
+
+
+ def forward(
+ self,
+ tokens: List[List[int]],
+ features: torch.Tensor,
+ features_lens: torch.Tensor,
+ noise: torch.Tensor,
+ t: torch.Tensor,
+ condition_drop_ratio: float = 0.0,
+ ) -> torch.Tensor:
+ """Forward pass of the model for training.
+ Args:
+ tokens: a list of list of token ids.
+ features: the acoustic features, with the shape (batch, seq_len, feat_dim).
+ features_lens: the length of each acoustic feature sequence, shape (batch,).
+ noise: the intitial noise, with the shape (batch, seq_len, feat_dim).
+ t: the time step, with the shape (batch, 1, 1).
+ condition_drop_ratio: the ratio of dropped text condition.
+ Returns:
+ fm_loss: the flow-matching loss.
+ """
+
+ (text_condition, padding_mask,) = self.forward_text_train(
+ tokens=tokens,
+ features_lens=features_lens,
+ )
+
+ speech_condition_mask = condition_time_mask(
+ features_lens=features_lens,
+ mask_percent=(0.7, 1.0),
+ max_len=features.size(1),
+ )
+ speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)
+
+ if condition_drop_ratio > 0.0:
+ drop_mask = (
+ torch.rand(text_condition.size(0), 1, 1).to(text_condition.device)
+ > condition_drop_ratio
+ )
+ text_condition = text_condition * drop_mask
+
+ xt = features * t + noise * (1 - t)
+ ut = features - noise # (B, T, F)
+
+ vt = self.forward_fm_decoder(
+ t=t,
+ xt=xt,
+ text_condition=text_condition,
+ speech_condition=speech_condition,
+ padding_mask=padding_mask,
+ )
+
+ loss_mask = speech_condition_mask & (~padding_mask)
+ fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2)
+
+ return fm_loss
+
+ def sample(
+ self,
+ tokens: List[List[int]],
+ prompt_tokens: List[List[int]],
+ prompt_features: torch.Tensor,
+ prompt_features_lens: torch.Tensor,
+ features_lens: Optional[torch.Tensor] = None,
+ speed: float = 1.0,
+ t_shift: float = 1.0,
+ duration: str = "predict",
+ num_step: int = 5,
+ guidance_scale: float = 0.5,
+ num_space_text=[-1],
+ num_space_prompt=[-1]
+ ) -> torch.Tensor:
+ """
+ Generate acoustic features, given text tokens, prompts feature
+ and prompt transcription's text tokens.
+ Args:
+ tokens: a list of list of text tokens.
+ prompt_tokens: a list of list of prompt tokens.
+ prompt_features: the prompt feature with the shape
+ (batch_size, seq_len, feat_dim).
+ prompt_features_lens: the length of each prompt feature,
+ with the shape (batch_size,).
+ features_lens: the length of the predicted eature, with the
+ shape (batch_size,). It is used only when duration is "real".
+ duration: "real" or "predict". If "real", the predicted
+ feature length is given by features_lens.
+ num_step: the number of steps to use in the ODE solver.
+ guidance_scale: the guidance scale for classifier-free guidance.
+ """
+
+ assert duration in ["real", "predict"]
+
+ if duration == "predict":
+ (
+ text_condition,
+ padding_mask,
+ ) = self.forward_text_inference_ratio_duration(
+ tokens=tokens,
+ prompt_tokens=prompt_tokens,
+ prompt_features_lens=prompt_features_lens,
+ speed=speed,
+ num_space_text=num_space_text,
+ num_space_prompt=num_space_prompt,
+ )
+ else:
+ assert features_lens is not None
+ text_condition, padding_mask = self.forward_text_inference_gt_duration(
+ tokens=tokens,
+ features_lens=features_lens,
+ prompt_tokens=prompt_tokens,
+ prompt_features_lens=prompt_features_lens,
+ )
+ batch_size, num_frames, _ = text_condition.shape
+
+ speech_condition = torch.nn.functional.pad(
+ prompt_features, (0, 0, 0, num_frames - prompt_features.size(1))
+ ) # (B, T, F)
+
+ # False means speech condition positions.
+ speech_condition_mask = make_pad_mask(prompt_features_lens, num_frames)
+ speech_condition = torch.where(
+ speech_condition_mask.unsqueeze(-1),
+ torch.zeros_like(speech_condition),
+ speech_condition,
+ )
+
+ x0 = torch.randn(
+ batch_size,
+ num_frames,
+ prompt_features.size(-1),
+ device=text_condition.device,
+ )
+
+ x1 = self.solver.sample(
+ x=x0,
+ text_condition=text_condition,
+ speech_condition=speech_condition,
+ padding_mask=padding_mask,
+ num_step=num_step,
+ guidance_scale=guidance_scale,
+ t_shift=t_shift,
+ )
+ x1_wo_prompt_lens = (~padding_mask).sum(-1) - prompt_features_lens
+ x1_prompt = torch.zeros(
+ x1.size(0), prompt_features_lens.max(), x1.size(2), device=x1.device
+ )
+ x1_wo_prompt = torch.zeros(
+ x1.size(0), x1_wo_prompt_lens.max(), x1.size(2), device=x1.device
+ )
+ for i in range(x1.size(0)):
+ x1_wo_prompt[i, : x1_wo_prompt_lens[i], :] = x1[
+ i,
+ prompt_features_lens[i] : prompt_features_lens[i]
+ + x1_wo_prompt_lens[i],
+ ]
+ x1_prompt[i, : prompt_features_lens[i], :] = x1[
+ i, : prompt_features_lens[i]
+ ]
+
+ return x1_wo_prompt, x1_wo_prompt_lens, x1_prompt, prompt_features_lens
+
+ def sample_intermediate(
+ self,
+ tokens: List[List[int]],
+ features: torch.Tensor,
+ features_lens: torch.Tensor,
+ noise: torch.Tensor,
+ speech_condition_mask: torch.Tensor,
+ t_start: float,
+ t_end: float,
+ num_step: int = 1,
+ guidance_scale: torch.Tensor = None,
+ ) -> torch.Tensor:
+ """
+ Generate acoustic features in intermediate timesteps.
+ Args:
+ tokens: List of list of token ids.
+ features: The acoustic features, with the shape (batch, seq_len, feat_dim).
+ features_lens: The length of each acoustic feature sequence,
+ with the shape (batch,).
+ noise: The initial noise, with the shape (batch, seq_len, feat_dim).
+ speech_condition_mask: The mask for speech condition, True means
+ non-condition positions, with the shape (batch, seq_len).
+ t_start: The start timestep.
+ t_end: The end timestep.
+ num_step: The number of steps for sampling.
+ guidance_scale: The scale for classifier-free guidance inference,
+ with the shape (batch, 1, 1).
+ """
+ (text_condition, padding_mask,) = self.forward_text_train(
+ tokens=tokens,
+ features_lens=features_lens,
+ )
+
+ speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)
+
+ x_t_end = self.solver.sample(
+ x=noise,
+ text_condition=text_condition,
+ speech_condition=speech_condition,
+ padding_mask=padding_mask,
+ num_step=num_step,
+ guidance_scale=guidance_scale,
+ t_start=t_start,
+ t_end=t_end,
+ )
+ x_t_end_lens = (~padding_mask).sum(-1)
+ return x_t_end, x_t_end_lens
diff --git a/zipvoice/models/zipvoice_dialog.py b/zipvoice/models/zipvoice_dialog.py
new file mode 100644
index 0000000000000000000000000000000000000000..1137d5cebb135b128a0551bcf15c3825f296e01f
--- /dev/null
+++ b/zipvoice/models/zipvoice_dialog.py
@@ -0,0 +1,358 @@
+# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import torch
+import torch.nn as nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from zipvoice.models.modules.zipformer_two_stream import TTSZipformerTwoStream
+from zipvoice.models.zipvoice import ZipVoice
+from zipvoice.utils.common import condition_time_mask_suffix, make_pad_mask, pad_labels
+
+
+class ZipVoiceDialog(ZipVoice):
+ """The ZipVoice-Dialog model."""
+
+ def __init__(
+ self,
+ fm_decoder_downsampling_factor: List[int] = [1, 2, 4, 2, 1],
+ fm_decoder_num_layers: List[int] = [2, 2, 4, 4, 4],
+ fm_decoder_cnn_module_kernel: List[int] = [31, 15, 7, 15, 31],
+ fm_decoder_feedforward_dim: int = 1536,
+ fm_decoder_num_heads: int = 4,
+ fm_decoder_dim: int = 512,
+ text_encoder_num_layers: int = 4,
+ text_encoder_feedforward_dim: int = 512,
+ text_encoder_cnn_module_kernel: int = 9,
+ text_encoder_num_heads: int = 4,
+ text_encoder_dim: int = 192,
+ time_embed_dim: int = 192,
+ text_embed_dim: int = 192,
+ query_head_dim: int = 32,
+ value_head_dim: int = 12,
+ pos_head_dim: int = 4,
+ pos_dim: int = 48,
+ feat_dim: int = 100,
+ vocab_size: int = 26,
+ pad_id: int = 0,
+ spk_a_id: int = 360,
+ spk_b_id: int = 361,
+ ):
+ """
+ Initialize the model with specified configuration parameters.
+
+ Args:
+ fm_decoder_downsampling_factor: List of downsampling factors for each layer
+ in the flow-matching decoder.
+ fm_decoder_num_layers: List of the number of layers for each block in the
+ flow-matching decoder.
+ fm_decoder_cnn_module_kernel: List of kernel sizes for CNN modules in the
+ flow-matching decoder.
+ fm_decoder_feedforward_dim: Dimension of the feedforward network in the
+ flow-matching decoder.
+ fm_decoder_num_heads: Number of attention heads in the flow-matching
+ decoder.
+ fm_decoder_dim: Hidden dimension of the flow-matching decoder.
+ text_encoder_num_layers: Number of layers in the text encoder.
+ text_encoder_feedforward_dim: Dimension of the feedforward network in the
+ text encoder.
+ text_encoder_cnn_module_kernel: Kernel size for the CNN module in the
+ text encoder.
+ text_encoder_num_heads: Number of attention heads in the text encoder.
+ text_encoder_dim: Hidden dimension of the text encoder.
+ time_embed_dim: Dimension of the time embedding.
+ text_embed_dim: Dimension of the text embedding.
+ query_head_dim: Dimension of the query attention head.
+ value_head_dim: Dimension of the value attention head.
+ pos_head_dim: Dimension of the position attention head.
+ pos_dim: Dimension of the positional encoding.
+ feat_dim: Dimension of the acoustic features.
+ vocab_size: Size of the vocabulary.
+ pad_id: ID used for padding tokens.
+ spk_a_id: ID of speaker A / [S1].
+ spk_b_id: ID of speaker B / [S2].
+ """
+ super().__init__(
+ fm_decoder_downsampling_factor=fm_decoder_downsampling_factor,
+ fm_decoder_num_layers=fm_decoder_num_layers,
+ fm_decoder_cnn_module_kernel=fm_decoder_cnn_module_kernel,
+ fm_decoder_feedforward_dim=fm_decoder_feedforward_dim,
+ fm_decoder_num_heads=fm_decoder_num_heads,
+ fm_decoder_dim=fm_decoder_dim,
+ text_encoder_num_layers=text_encoder_num_layers,
+ text_encoder_feedforward_dim=text_encoder_feedforward_dim,
+ text_encoder_cnn_module_kernel=text_encoder_cnn_module_kernel,
+ text_encoder_num_heads=text_encoder_num_heads,
+ text_encoder_dim=text_encoder_dim,
+ time_embed_dim=time_embed_dim,
+ text_embed_dim=text_embed_dim,
+ query_head_dim=query_head_dim,
+ value_head_dim=value_head_dim,
+ pos_head_dim=pos_head_dim,
+ pos_dim=pos_dim,
+ feat_dim=feat_dim,
+ vocab_size=vocab_size,
+ pad_id=pad_id,
+ )
+
+ self.spk_a_id = spk_a_id
+ self.spk_b_id = spk_b_id
+ self.spk_embed = nn.Embedding(2, feat_dim)
+ torch.nn.init.normal_(self.spk_embed.weight, mean=0, std=0.1)
+
+ def extract_spk_indices(self, tensor):
+ turn_mask = ((tensor == self.spk_a_id) | (tensor == self.spk_b_id)).long()
+ turn_counts = turn_mask.cumsum(dim=1)
+ spk_mask = turn_counts % 2
+ spk_mask = torch.where(tensor == self.pad_id, -1, spk_mask)
+ spk_a_indices = torch.where(spk_mask == 0)
+ spk_b_indices = torch.where(spk_mask == 1)
+ return spk_a_indices, spk_b_indices
+
+ def forward_text_embed(
+ self,
+ tokens: List[List[int]],
+ ):
+ """
+ Get the text embeddings.
+ Args:
+ tokens: a list of list of token ids.
+ Returns:
+ embed: the text embeddings, shape (batch, seq_len, emb_dim).
+ tokens_lens: the length of each token sequence, shape (batch,).
+ """
+ device = (
+ self.device if isinstance(self, DDP) else next(self.parameters()).device
+ )
+ tokens_padded = pad_labels(tokens, pad_id=self.pad_id, device=device) # (B, S)
+ embed = self.embed(tokens_padded) # (B, S, C)
+ spk_a_indices, spk_b_indices = self.extract_spk_indices(tokens_padded)
+ tokens_lens = torch.tensor(
+ [len(token) for token in tokens], dtype=torch.int64, device=device
+ )
+ tokens_padding_mask = make_pad_mask(tokens_lens, embed.shape[1]) # (B, S)
+
+ embed = self.text_encoder(
+ x=embed, t=None, padding_mask=tokens_padding_mask
+ ) # (B, S, C)
+ embed[spk_a_indices] += self.spk_embed(torch.tensor(0, device=device)).to(
+ embed.dtype
+ )
+ embed[spk_b_indices] += self.spk_embed(torch.tensor(1, device=device)).to(
+ embed.dtype
+ )
+ return embed, tokens_lens
+
+ def forward(
+ self,
+ tokens: List[List[int]],
+ features: torch.Tensor,
+ features_lens: torch.Tensor,
+ noise: torch.Tensor,
+ t: torch.Tensor,
+ condition_drop_ratio: float = 0.0,
+ ) -> torch.Tensor:
+ """Forward pass of the model for training.
+ Args:
+ tokens: a list of list of token ids.
+ features: the acoustic features, with the shape (batch, seq_len, feat_dim).
+ features_lens: the length of each acoustic feature sequence, shape (batch,).
+ noise: the intitial noise, with the shape (batch, seq_len, feat_dim).
+ t: the time step, with the shape (batch, 1, 1).
+ condition_drop_ratio: the ratio of dropped text condition.
+ Returns:
+ fm_loss: the flow-matching loss.
+ """
+
+ (text_condition, padding_mask,) = self.forward_text_train(
+ tokens=tokens,
+ features_lens=features_lens,
+ )
+
+ speech_condition_mask = condition_time_mask_suffix(
+ features_lens=features_lens,
+ mask_percent=(0.5, 1.0),
+ max_len=features.size(1),
+ )
+ speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)
+
+ if condition_drop_ratio > 0.0:
+ drop_mask = (
+ torch.rand(text_condition.size(0), 1, 1).to(text_condition.device)
+ > condition_drop_ratio
+ )
+ text_condition = text_condition * drop_mask
+
+ xt = features * t + noise * (1 - t)
+ ut = features - noise # (B, T, F)
+
+ vt = self.forward_fm_decoder(
+ t=t,
+ xt=xt,
+ text_condition=text_condition,
+ speech_condition=speech_condition,
+ padding_mask=padding_mask,
+ )
+
+ loss_mask = speech_condition_mask & (~padding_mask)
+ fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2)
+
+ return fm_loss
+
+
+class ZipVoiceDialogStereo(ZipVoiceDialog):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ required_params = {
+ "feat_dim",
+ "fm_decoder_downsampling_factor",
+ "fm_decoder_num_layers",
+ "fm_decoder_cnn_module_kernel",
+ "fm_decoder_dim",
+ "fm_decoder_feedforward_dim",
+ "fm_decoder_num_heads",
+ "query_head_dim",
+ "pos_head_dim",
+ "value_head_dim",
+ "pos_dim",
+ "time_embed_dim",
+ }
+
+ missing = [p for p in required_params if p not in kwargs]
+ if missing:
+ raise ValueError(f"Missing required parameters: {', '.join(missing)}")
+
+ self.fm_decoder = TTSZipformerTwoStream(
+ in_dim=(kwargs["feat_dim"] * 5, kwargs["feat_dim"] * 3),
+ out_dim=(kwargs["feat_dim"] * 2, kwargs["feat_dim"]),
+ downsampling_factor=kwargs["fm_decoder_downsampling_factor"],
+ num_encoder_layers=kwargs["fm_decoder_num_layers"],
+ cnn_module_kernel=kwargs["fm_decoder_cnn_module_kernel"],
+ encoder_dim=kwargs["fm_decoder_dim"],
+ feedforward_dim=kwargs["fm_decoder_feedforward_dim"],
+ num_heads=kwargs["fm_decoder_num_heads"],
+ query_head_dim=kwargs["query_head_dim"],
+ pos_head_dim=kwargs["pos_head_dim"],
+ value_head_dim=kwargs["value_head_dim"],
+ pos_dim=kwargs["pos_dim"],
+ use_time_embed=True,
+ time_embed_dim=kwargs["time_embed_dim"],
+ )
+
+ def forward(
+ self,
+ tokens: List[List[int]],
+ features: torch.Tensor,
+ features_lens: torch.Tensor,
+ noise: torch.Tensor,
+ t: torch.Tensor,
+ condition_drop_ratio: float = 0.0,
+ se_weight: float = 1.0,
+ ) -> torch.Tensor:
+ """Forward pass of the model for training.
+ Args:
+ tokens: a list of list of token ids.
+ features: the acoustic features, with the shape (batch, seq_len, feat_dim).
+ features_lens: the length of each acoustic feature sequence, shape (batch,).
+ noise: the intitial noise, with the shape (batch, seq_len, feat_dim).
+ t: the time step, with the shape (batch, 1, 1).
+ condition_drop_ratio: the ratio of dropped text condition.
+ se_weight: the weight of the speaker exclusive loss.
+ Returns:
+ fm_loss: the flow-matching loss.
+ """
+
+ (text_condition, padding_mask,) = self.forward_text_train(
+ tokens=tokens,
+ features_lens=features_lens,
+ )
+
+ speech_condition_mask = condition_time_mask_suffix(
+ features_lens=features_lens,
+ mask_percent=(0.5, 1.0),
+ max_len=features.size(1),
+ )
+ speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)
+
+ if condition_drop_ratio > 0.0:
+ drop_mask = (
+ torch.rand(text_condition.size(0), 1, 1).to(text_condition.device)
+ > condition_drop_ratio
+ )
+ text_condition = text_condition * drop_mask
+
+ xt = features * t + noise * (1 - t)
+ ut = features - noise # (B, T, F)
+
+ vt = self.forward_fm_decoder(
+ t=t,
+ xt=xt,
+ text_condition=text_condition,
+ speech_condition=speech_condition,
+ padding_mask=padding_mask,
+ )
+
+ loss_mask = speech_condition_mask & (~padding_mask)
+ fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2)
+
+ if se_weight > 0:
+ target = xt + vt * (1 - t)
+ fbank_1 = target[:, :, : self.feat_dim]
+ fbank_2 = target[:, :, self.feat_dim :]
+ energy_loss = torch.mean(
+ self.energy_based_loss(fbank_1, fbank_2, features)[loss_mask]
+ )
+ loss = fm_loss + energy_loss * se_weight
+ else:
+ loss = fm_loss
+
+ return loss
+
+ def energy_based_loss(self, fbank1, fbank2, gt_fbank):
+ energy1 = self.energy(fbank1)
+ energy2 = self.energy(fbank2)
+
+ energy_thresholds = self.adaptive_threshold_from_gt(
+ torch.cat(
+ [
+ gt_fbank[:, :, : self.feat_dim],
+ gt_fbank[:, :, self.feat_dim :],
+ ],
+ dim=1,
+ )
+ )
+
+ both_speaking = (
+ (energy1 > energy_thresholds) & (energy2 > energy_thresholds)
+ ).float()
+
+ penalty = (
+ both_speaking
+ * (energy1 - energy_thresholds)
+ * (energy2 - energy_thresholds)
+ )
+ return penalty
+
+ def energy(self, fbank):
+ return torch.mean(fbank, dim=-1)
+
+ def adaptive_threshold_from_gt(self, gt_fbank, percentile=50):
+ frame_energies = self.energy(gt_fbank)
+ thresholds = torch.quantile(frame_energies, q=percentile / 100, dim=1)
+ return thresholds.unsqueeze(1)
diff --git a/zipvoice/models/zipvoice_distill.py b/zipvoice/models/zipvoice_distill.py
new file mode 100644
index 0000000000000000000000000000000000000000..29481bd4dd0d4aa96268ab3672dbc528123f1727
--- /dev/null
+++ b/zipvoice/models/zipvoice_distill.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Xiaomi Corp. (authors: Wei Kang
+# Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import torch
+
+from zipvoice.models.modules.solver import DistillEulerSolver
+from zipvoice.models.modules.zipformer import TTSZipformer
+from zipvoice.models.zipvoice import ZipVoice
+
+
+class ZipVoiceDistill(ZipVoice):
+ """ZipVoice-Distill model."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ required_params = {
+ "feat_dim",
+ "fm_decoder_downsampling_factor",
+ "fm_decoder_num_layers",
+ "fm_decoder_cnn_module_kernel",
+ "fm_decoder_dim",
+ "fm_decoder_feedforward_dim",
+ "fm_decoder_num_heads",
+ "query_head_dim",
+ "pos_head_dim",
+ "value_head_dim",
+ "pos_dim",
+ "time_embed_dim",
+ }
+
+ missing = [p for p in required_params if p not in kwargs]
+ if missing:
+ raise ValueError(f"Missing required parameters: {', '.join(missing)}")
+
+ self.fm_decoder = TTSZipformer(
+ in_dim=kwargs["feat_dim"] * 3,
+ out_dim=kwargs["feat_dim"],
+ downsampling_factor=kwargs["fm_decoder_downsampling_factor"],
+ num_encoder_layers=kwargs["fm_decoder_num_layers"],
+ cnn_module_kernel=kwargs["fm_decoder_cnn_module_kernel"],
+ encoder_dim=kwargs["fm_decoder_dim"],
+ feedforward_dim=kwargs["fm_decoder_feedforward_dim"],
+ num_heads=kwargs["fm_decoder_num_heads"],
+ query_head_dim=kwargs["query_head_dim"],
+ pos_head_dim=kwargs["pos_head_dim"],
+ value_head_dim=kwargs["value_head_dim"],
+ pos_dim=kwargs["pos_dim"],
+ use_time_embed=True,
+ time_embed_dim=kwargs["time_embed_dim"],
+ use_guidance_scale_embed=True,
+ )
+ self.solver = DistillEulerSolver(self, func_name="forward_fm_decoder")
+
+ def forward(
+ self,
+ tokens: List[List[int]],
+ features: torch.Tensor,
+ features_lens: torch.Tensor,
+ noise: torch.Tensor,
+ speech_condition_mask: torch.Tensor,
+ t_start: float,
+ t_end: float,
+ num_step: int = 1,
+ guidance_scale: torch.Tensor = None,
+ ) -> torch.Tensor:
+
+ return self.sample_intermediate(
+ tokens=tokens,
+ features=features,
+ features_lens=features_lens,
+ noise=noise,
+ speech_condition_mask=speech_condition_mask,
+ t_start=t_start,
+ t_end=t_end,
+ num_step=num_step,
+ guidance_scale=guidance_scale,
+ )
diff --git a/zipvoice/tokenizer/normalizer.py b/zipvoice/tokenizer/normalizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d29558938e569ad3ae64da1b01e4f0fd92eaa904
--- /dev/null
+++ b/zipvoice/tokenizer/normalizer.py
@@ -0,0 +1,170 @@
+import re
+from abc import ABC, abstractmethod
+
+import cn2an
+import inflect
+
+
+class TextNormalizer(ABC):
+ """Abstract base class for text normalization, defining common interface."""
+
+ @abstractmethod
+ def normalize(self, text: str) -> str:
+ """Normalize text."""
+ raise NotImplementedError
+
+
+class EnglishTextNormalizer(TextNormalizer):
+ """
+ A class to handle preprocessing of English text including normalization. Following:
+ https://github.com/espnet/espnet_tts_frontend/blob/master/tacotron_cleaner/cleaners.py
+ """
+
+ def __init__(self):
+ # List of (regular expression, replacement) pairs for abbreviations:
+ self._abbreviations = [
+ (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
+ for x in [
+ ("mrs", "misess"),
+ ("mr", "mister"),
+ ("dr", "doctor"),
+ ("st", "saint"),
+ ("co", "company"),
+ ("jr", "junior"),
+ ("maj", "major"),
+ ("gen", "general"),
+ ("drs", "doctors"),
+ ("rev", "reverend"),
+ ("lt", "lieutenant"),
+ ("hon", "honorable"),
+ ("sgt", "sergeant"),
+ ("capt", "captain"),
+ ("esq", "esquire"),
+ ("ltd", "limited"),
+ ("col", "colonel"),
+ ("ft", "fort"),
+ ("etc", "et cetera"),
+ ("btw", "by the way"),
+ ]
+ ]
+
+ self._inflect = inflect.engine()
+ self._comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
+ self._decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
+ self._percent_number_re = re.compile(r"([0-9\.\,]*[0-9]+%)")
+ self._pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
+ self._dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
+ self._fraction_re = re.compile(r"([0-9]+)/([0-9]+)")
+ self._ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
+ self._number_re = re.compile(r"[0-9]+")
+ self._whitespace_re = re.compile(r"\s+")
+
+ def normalize(self, text: str) -> str:
+ """Custom pipeline for English text,
+ including number and abbreviation expansion."""
+ text = self.expand_abbreviations(text)
+ text = self.normalize_numbers(text)
+
+ return text
+
+ def fraction_to_words(self, numerator, denominator):
+ if numerator == 1 and denominator == 2:
+ return " one half "
+ if numerator == 1 and denominator == 4:
+ return " one quarter "
+ if denominator == 2:
+ return " " + self._inflect.number_to_words(numerator) + " halves "
+ if denominator == 4:
+ return " " + self._inflect.number_to_words(numerator) + " quarters "
+ return (
+ " "
+ + self._inflect.number_to_words(numerator)
+ + " "
+ + self._inflect.ordinal(self._inflect.number_to_words(denominator))
+ + " "
+ )
+
+ def _remove_commas(self, m):
+ return m.group(1).replace(",", "")
+
+ def _expand_dollars(self, m):
+ match = m.group(1)
+ parts = match.split(".")
+ if len(parts) > 2:
+ return " " + match + " dollars " # Unexpected format
+ dollars = int(parts[0]) if parts[0] else 0
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
+ if dollars and cents:
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
+ cent_unit = "cent" if cents == 1 else "cents"
+ return " %s %s, %s %s " % (dollars, dollar_unit, cents, cent_unit)
+ elif dollars:
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
+ return " %s %s " % (dollars, dollar_unit)
+ elif cents:
+ cent_unit = "cent" if cents == 1 else "cents"
+ return " %s %s " % (cents, cent_unit)
+ else:
+ return " zero dollars "
+
+ def _expand_fraction(self, m):
+ numerator = int(m.group(1))
+ denominator = int(m.group(2))
+ return self.fraction_to_words(numerator, denominator)
+
+ def _expand_decimal_point(self, m):
+ return m.group(1).replace(".", " point ")
+
+ def _expand_percent(self, m):
+ return m.group(1).replace("%", " percent ")
+
+ def _expand_ordinal(self, m):
+ return " " + self._inflect.number_to_words(m.group(0)) + " "
+
+ def _expand_number(self, m):
+ num = int(m.group(0))
+ if num > 1000 and num < 3000:
+ if num == 2000:
+ return " two thousand "
+ elif num > 2000 and num < 2010:
+ return " two thousand " + self._inflect.number_to_words(num % 100) + " "
+ elif num % 100 == 0:
+ return " " + self._inflect.number_to_words(num // 100) + " hundred "
+ else:
+ return (
+ " "
+ + self._inflect.number_to_words(
+ num, andword="", zero="oh", group=2
+ ).replace(", ", " ")
+ + " "
+ )
+ else:
+ return " " + self._inflect.number_to_words(num, andword="") + " "
+
+ def normalize_numbers(self, text):
+ text = re.sub(self._comma_number_re, self._remove_commas, text)
+ text = re.sub(self._pounds_re, r"\1 pounds", text)
+ text = re.sub(self._dollars_re, self._expand_dollars, text)
+ text = re.sub(self._fraction_re, self._expand_fraction, text)
+ text = re.sub(self._decimal_number_re, self._expand_decimal_point, text)
+ text = re.sub(self._percent_number_re, self._expand_percent, text)
+ text = re.sub(self._ordinal_re, self._expand_ordinal, text)
+ text = re.sub(self._number_re, self._expand_number, text)
+ return text
+
+ def expand_abbreviations(self, text):
+ for regex, replacement in self._abbreviations:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+class ChineseTextNormalizer(TextNormalizer):
+ """
+ A class to handle preprocessing of Chinese text including normalization.
+ """
+
+ def normalize(self, text: str) -> str:
+ """Normalize text."""
+ # Convert numbers to Chinese
+ text = cn2an.transform(text, "an2cn")
+ return text
diff --git a/zipvoice/tokenizer/tokenizer.py b/zipvoice/tokenizer/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..95033b46a4c6769e035626b45117fc62d11d7aa4
--- /dev/null
+++ b/zipvoice/tokenizer/tokenizer.py
@@ -0,0 +1,774 @@
+# Copyright 2023-2024 Xiaomi Corp. (authors: Zengwei Yao
+# Han Zhu,
+# Wei Kang)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+from abc import ABC, abstractmethod
+from functools import reduce
+from typing import Dict, List, Optional
+
+import jieba
+from lhotse import CutSet
+from pypinyin import Style, lazy_pinyin
+from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials
+
+from zipvoice.tokenizer.normalizer import ChineseTextNormalizer, EnglishTextNormalizer
+
+try:
+ from piper_phonemize import phonemize_espeak
+except Exception as ex:
+ raise RuntimeError(
+ f"{ex}\nPlease run\n"
+ "pip install piper_phonemize -f \
+ https://k2-fsa.github.io/icefall/piper_phonemize.html"
+ )
+
+jieba.default_logger.setLevel(logging.INFO)
+
+
+class Tokenizer(ABC):
+ """Abstract base class for tokenizers, defining common interface."""
+
+ @abstractmethod
+ def texts_to_token_ids(self, texts: List[str]) -> List[List[int]]:
+ """Convert list of texts to list of token id sequences."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def texts_to_tokens(self, texts: List[str]) -> List[List[str]]:
+ """Convert list of texts to list of token sequences."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def tokens_to_token_ids(self, tokens: List[List[str]]) -> List[List[int]]:
+ """Convert list of token sequences to list of token id sequences."""
+ raise NotImplementedError
+
+import logging
+from typing import List, Dict, Optional
+
+def text_to_phoneme_simple(text: str) -> str:
+ text = text.lower()
+ mapping = {
+ "nh": "ɲ",
+ "đ": "ɗ",
+
+ # u, ư
+ "u": "u",
+ "ú": "u3", "ù": "u2", "ủ": "u4", "ũ": "u5", "ụ": "u6",
+ "ư": "'u",
+ "ứ": "'u3", "ừ": "'u2", "ử": "'u4", "ữ": "'u5", "ự": "'u6",
+
+ # e, ê
+ "e": "e",
+ "é": "e3", "è": "e2", "ẻ": "e4", "ẽ": "e5", "ẹ": "e6",
+ "ê": "'e",
+ "ế": "'e3", "ề": "'e2", "ể": "'e4", "ễ": "'e5", "ệ": "'e6",
+
+ # o, ô, ơ
+ "o": "o",
+ "ó": "o3", "ò": "o2", "ỏ": "o4", "õ": "o5", "ọ": "o6",
+ "ô": "'o",
+ "ố": "'o3", "ồ": "'o2", "ổ": "'o4", "ỗ": "'o5", "ộ": "'o6",
+ "ơ": "əː",
+ "ớ": "əː3", "ờ": "əː2", "ở": "əː4", "ỡ": "əː5", "ợ": "əː6",
+
+ # a, ă, â
+ "a": "a",
+ "á": "a3", "à": "a2", "ả": "a4", "ã": "a5", "ạ": "a6",
+ "ă": "'a",
+ "ắ": "'a3", "ằ": "'a2", "ẳ": "'a4", "ẵ": "'a5", "ặ": "'a6",
+ "â": "'ə",
+ "ấ": "'ə3", "ầ": "'ə2", "ẩ": "'ə4", "ẫ": "'ə5", "ậ": "'ə6",
+
+ # i, y
+ "i": "i",
+ "í": "i3", "ì": "i2", "ỉ": "i4", "ĩ": "i5", "ị": "i6",
+ "y": "y",
+ "ý": "y3", "ỳ": "y2", "ỷ": "y4", "ỹ": "y5", "ỵ": "y6",
+ '́': "3",
+ }
+
+ # Ưu tiên key dài hơn (vd: "nh" > "n")
+ keys_sorted = sorted(mapping.keys(), key=lambda x: -len(x))
+
+ result = ""
+ i = 0
+ while i < len(text):
+ matched = False
+ for k in keys_sorted:
+ if text[i:i+len(k)] == k:
+ result += mapping[k]
+ i += len(k)
+ matched = True
+ break
+ if not matched:
+ result += text[i] # giữ nguyên nếu không có trong bảng
+ i += 1
+ return result
+
+
+class SimpleTokenizer2(Tokenizer):
+ """Tokenizer dựa trên phoneme:
+ - chuyển text sang phoneme
+ - sau đó map từng ký tự phoneme sang id
+ """
+
+ def __init__(self, token_file: Optional[str] = None):
+ self.has_tokens = False
+ if token_file is None:
+ logging.debug(
+ "Initialize Tokenizer without tokens file, "
+ "will fail when map to ids."
+ )
+ return
+
+ self.token2id: Dict[str, int] = {}
+ with open(token_file, "r", encoding="utf-8") as f:
+ for line in f.readlines():
+ info = line.rstrip().split("\t")
+ token, id = info[0], int(info[1])
+ assert token not in self.token2id, token
+ self.token2id[token] = id
+
+ self.pad_id = self.token2id["_"] # padding
+ self.vocab_size = len(self.token2id)
+ self.has_tokens = True
+
+ def texts_to_token_ids(
+ self,
+ texts: List[str],
+ ) -> List[List[int]]:
+ return self.tokens_to_token_ids(self.texts_to_tokens(texts))
+
+ def texts_to_tokens(
+ self,
+ texts: List[str],
+ ) -> List[List[str]]:
+ # Chuyển mỗi text thành phoneme trước khi tách ký tự
+ phoneme_texts = [text_to_phoneme_simple(t) for t in texts]
+ # print(phoneme_texts)
+ tokens_list = [list(p) for p in phoneme_texts]
+ return tokens_list
+
+ def tokens_to_token_ids(
+ self,
+ tokens_list: List[List[str]],
+ ) -> List[List[int]]:
+ assert self.has_tokens, "Please initialize Tokenizer with a tokens file."
+
+ token_ids_list = []
+ for tokens in tokens_list:
+ token_ids = []
+ for t in tokens:
+ if t not in self.token2id:
+ print(f"OOV token skipped: {t}")
+ continue
+ token_ids.append(self.token2id[t])
+ token_ids_list.append(token_ids)
+
+ return token_ids_list
+
+class SimpleTokenizer(Tokenizer):
+ """The simplpest tokenizer, treat every character as a token,
+ without text normalization.
+ """
+
+ def __init__(self, token_file: Optional[str] = None):
+ """
+ Args:
+ tokens: the file that contains information that maps tokens to ids,
+ which is a text file with '{token}\t{token_id}' per line.
+ """
+ # Parse token file
+ self.has_tokens = False
+ if token_file is None:
+ logging.debug(
+ "Initialize Tokenizer without tokens file, \
+ will fail when map to ids."
+ )
+ return
+ self.token2id: Dict[str, int] = {}
+ with open(token_file, "r", encoding="utf-8") as f:
+ for line in f.readlines():
+ info = line.rstrip().split("\t")
+ token, id = info[0], int(info[1])
+ assert token not in self.token2id, token
+ self.token2id[token] = id
+ self.pad_id = self.token2id["_"] # padding
+ self.vocab_size = len(self.token2id)
+ self.has_tokens = True
+
+ def texts_to_token_ids(
+ self,
+ texts: List[str],
+ ) -> List[List[int]]:
+ return self.tokens_to_token_ids(self.texts_to_tokens(texts))
+
+ def texts_to_tokens(
+ self,
+ texts: List[str],
+ ) -> List[List[str]]:
+ tokens_list = [list(texts[i]) for i in range(len(texts))]
+ return tokens_list
+
+ def tokens_to_token_ids(
+ self,
+ tokens_list: List[List[str]],
+ ) -> List[List[int]]:
+ assert self.has_tokens, "Please initialize Tokenizer with a tokens file."
+
+ token_ids_list = []
+
+ for tokens in tokens_list:
+ token_ids = []
+ for t in tokens:
+ if t not in self.token2id:
+ logging.debug(f"Skip OOV {t}")
+ continue
+ token_ids.append(self.token2id[t])
+
+ token_ids_list.append(token_ids)
+
+ return token_ids_list
+
+
+class EspeakTokenizer(Tokenizer):
+ """A simple tokenizer with Espeak g2p function."""
+
+ def __init__(self, token_file: Optional[str] = None, lang: str = "en-us"):
+ """
+ Args:
+ tokens: the file that contains information that maps tokens to ids,
+ which is a text file with '{token}\t{token_id}' per line.
+ lang: the language identifier, see
+ https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md
+ """
+ # Parse token file
+ self.has_tokens = False
+ self.lang = lang
+ if token_file is None:
+ logging.debug(
+ "Initialize Tokenizer without tokens file, \
+ will fail when map to ids."
+ )
+ return
+ self.token2id: Dict[str, int] = {}
+ with open(token_file, "r", encoding="utf-8") as f:
+ for line in f.readlines():
+ info = line.rstrip().split("\t")
+ token, id = info[0], int(info[1])
+ assert token not in self.token2id, token
+ self.token2id[token] = id
+ self.pad_id = self.token2id["_"] # padding
+ self.vocab_size = len(self.token2id)
+ self.has_tokens = True
+
+ def g2p(self, text: str) -> List[str]:
+ try:
+ tokens = phonemize_espeak(text, self.lang)
+ tokens = reduce(lambda x, y: x + y, tokens)
+ return tokens
+ except Exception as ex:
+ logging.warning(f"Tokenization of {self.lang} texts failed: {ex}")
+ return []
+
+ def texts_to_token_ids(
+ self,
+ texts: List[str],
+ ) -> List[List[int]]:
+ return self.tokens_to_token_ids(self.texts_to_tokens(texts))
+
+ def texts_to_tokens(
+ self,
+ texts: List[str],
+ ) -> List[List[str]]:
+ tokens_list = [self.g2p(texts[i]) for i in range(len(texts))]
+ return tokens_list
+
+ def tokens_to_token_ids(
+ self,
+ tokens_list: List[List[str]],
+ ) -> List[List[int]]:
+ assert self.has_tokens, "Please initialize Tokenizer with a tokens file."
+
+ token_ids_list = []
+
+ for tokens in tokens_list:
+ token_ids = []
+ for t in tokens:
+ if t not in self.token2id:
+ logging.debug(f"Skip OOV {t}")
+ continue
+ token_ids.append(self.token2id[t])
+
+ token_ids_list.append(token_ids)
+
+ return token_ids_list
+
+
+class EmiliaTokenizer(Tokenizer):
+ def __init__(self, token_file: Optional[str] = None, token_type="phone"):
+ """
+ Args:
+ tokens: the file that contains information that maps tokens to ids,
+ which is a text file with '{token}\t{token_id}' per line.
+ """
+ assert (
+ token_type == "phone"
+ ), f"Only support phone tokenizer for Emilia, but get {token_type}."
+
+ self.english_normalizer = EnglishTextNormalizer()
+ self.chinese_normalizer = ChineseTextNormalizer()
+
+ self.has_tokens = False
+ if token_file is None:
+ logging.debug(
+ "Initialize Tokenizer without tokens file, \
+ will fail when map to ids."
+ )
+ return
+ self.token2id: Dict[str, int] = {}
+ with open(token_file, "r", encoding="utf-8") as f:
+ for line in f.readlines():
+ info = line.rstrip().split("\t")
+ token, id = info[0], int(info[1])
+ assert token not in self.token2id, token
+ self.token2id[token] = id
+ self.pad_id = self.token2id["_"] # padding
+
+ self.vocab_size = len(self.token2id)
+ self.has_tokens = True
+
+ def texts_to_token_ids(
+ self,
+ texts: List[str],
+ ) -> List[List[int]]:
+ return self.tokens_to_token_ids(self.texts_to_tokens(texts))
+
+ def preprocess_text(
+ self,
+ text: str,
+ ) -> str:
+ return self.map_punctuations(text)
+
+ def texts_to_tokens(
+ self,
+ texts: List[str],
+ ) -> List[List[str]]:
+ for i in range(len(texts)):
+ # Text normalization
+ texts[i] = self.preprocess_text(texts[i])
+
+ phoneme_list = []
+ for text in texts:
+ # now only en and ch
+ segments = self.get_segment(text)
+ all_phoneme = []
+ for index in range(len(segments)):
+ seg = segments[index]
+ if seg[1] == "zh":
+ phoneme = self.tokenize_ZH(seg[0])
+ elif seg[1] == "en":
+ phoneme = self.tokenize_EN(seg[0])
+ elif seg[1] == "pinyin":
+ phoneme = self.tokenize_pinyin(seg[0])
+ elif seg[1] == "tag":
+ phoneme = [seg[0]]
+ else:
+ logging.warning(
+ f"No English or Chinese characters found, \
+ skipping segment of unknown language: {seg}"
+ )
+ continue
+ all_phoneme += phoneme
+ phoneme_list.append(all_phoneme)
+ return phoneme_list
+
+ def tokens_to_token_ids(
+ self,
+ tokens_list: List[List[str]],
+ ) -> List[List[int]]:
+ assert self.has_tokens, "Please initialize Tokenizer with a tokens file."
+ token_ids_list = []
+
+ for tokens in tokens_list:
+ token_ids = []
+ for t in tokens:
+ if t not in self.token2id:
+ logging.debug(f"Skip OOV {t}")
+ continue
+ token_ids.append(self.token2id[t])
+
+ token_ids_list.append(token_ids)
+
+ return token_ids_list
+
+ def tokenize_ZH(self, text: str) -> List[str]:
+ try:
+ text = self.chinese_normalizer.normalize(text)
+ segs = list(jieba.cut(text))
+ full = lazy_pinyin(
+ segs,
+ style=Style.TONE3,
+ tone_sandhi=True,
+ neutral_tone_with_five=True,
+ )
+ phones = []
+ for x in full:
+ # valid pinyin (in tone3 style) is alphabet + 1 number in [1-5].
+ if not (x[0:-1].isalpha() and x[-1] in ("1", "2", "3", "4", "5")):
+ phones.append(x)
+ continue
+ else:
+ phones.extend(self.seperate_pinyin(x))
+ return phones
+ except Exception as ex:
+ logging.warning(f"Tokenization of Chinese texts failed: {ex}")
+ return []
+
+ def tokenize_EN(self, text: str) -> List[str]:
+ try:
+ text = self.english_normalizer.normalize(text)
+ tokens = phonemize_espeak(text, "en-us")
+ tokens = reduce(lambda x, y: x + y, tokens)
+ return tokens
+ except Exception as ex:
+ logging.warning(f"Tokenization of English texts failed: {ex}")
+ return []
+
+ def tokenize_pinyin(self, text: str) -> List[str]:
+ try:
+ assert text.startswith("<") and text.endswith(">")
+ text = text.lstrip("<").rstrip(">")
+ # valid pinyin (in tone3 style) is alphabet + 1 number in [1-5].
+ if not (text[0:-1].isalpha() and text[-1] in ("1", "2", "3", "4", "5")):
+ logging.warning(
+ f"Strings enclosed with <> should be pinyin, \
+ but got: {text}. Skipped it. "
+ )
+ return []
+ else:
+ return self.seperate_pinyin(text)
+ except Exception as ex:
+ logging.warning(f"Tokenize pinyin failed: {ex}")
+ return []
+
+ def seperate_pinyin(self, text: str) -> List[str]:
+ """
+ Separate pinyin into initial and final
+ """
+ pinyins = []
+ initial = to_initials(text, strict=False)
+ # don't want to share tokens with espeak tokens,
+ # so use tone3 style
+ final = to_finals_tone3(
+ text,
+ strict=False,
+ neutral_tone_with_five=True,
+ )
+ if initial != "":
+ # don't want to share tokens with espeak tokens,
+ # so add a '0' after each initial
+ pinyins.append(initial + "0")
+ if final != "":
+ pinyins.append(final)
+ return pinyins
+
+ def map_punctuations(self, text):
+ text = text.replace(",", ",")
+ text = text.replace("。", ".")
+ text = text.replace("!", "!")
+ text = text.replace("?", "?")
+ text = text.replace(";", ";")
+ text = text.replace(":", ":")
+ text = text.replace("、", ",")
+ text = text.replace("‘", "'")
+ text = text.replace("“", '"')
+ text = text.replace("”", '"')
+ text = text.replace("’", "'")
+ text = text.replace("⋯", "…")
+ text = text.replace("···", "…")
+ text = text.replace("・・・", "…")
+ text = text.replace("...", "…")
+ return text
+
+ def get_segment(self, text: str) -> List[str]:
+ """
+ Split a text into segments based on language types
+ (Chinese, English, Pinyin, tags, etc.)
+
+ Args:
+ text (str): Input text to be segmented
+
+ Returns:
+ List[str]: Segmented text parts with their language types
+
+ Example:
+ Input: 我们是小米人,是吗? Yes I think so!霍...啦啦啦
+ Output: [('我们是小米人,是吗? ', 'zh'),
+ ('Yes I think so!', 'en'), ('霍...啦啦啦', 'zh')]
+ """
+ # Stores the final segmented parts and their language types
+ segments = []
+ # Stores the language type of each character in the input text
+ types = []
+ temp_seg = ""
+ temp_lang = ""
+
+ # Each part is a character, or a special string enclosed in <> and []
+ # <> denotes pinyin string, [] denotes other special strings.
+ _part_pattern = re.compile(r"[<[].*?[>\]]|.")
+ text = _part_pattern.findall(text)
+
+ for i, part in enumerate(text):
+ if self.is_chinese(part) or self.is_pinyin(part):
+ types.append("zh")
+ elif self.is_alphabet(part):
+ types.append("en")
+ else:
+ types.append("other")
+
+ assert len(types) == len(text)
+
+ for i in range(len(types)):
+ # find the first char of the seg
+ if i == 0:
+ temp_seg += text[i]
+ temp_lang = types[i]
+ else:
+ if temp_lang == "other":
+ temp_seg += text[i]
+ temp_lang = types[i]
+ else:
+ if types[i] in [temp_lang, "other"]:
+ temp_seg += text[i]
+ else:
+ segments.append((temp_seg, temp_lang))
+ temp_seg = text[i]
+ temp_lang = types[i]
+
+ segments.append((temp_seg, temp_lang))
+
+ # Handle "pinyin" and "tag" types
+ segments = self.split_segments(segments)
+ return segments
+
+ def split_segments(self, segments):
+ """
+ split segments into smaller parts if special strings enclosed by [] or <>
+ are found, where <> denotes pinyin strings, [] denotes other special strings.
+
+ Args:
+ segments (list): A list of tuples where each tuple contains:
+ - temp_seg (str): The text segment to be split.
+ - temp_lang (str): The language code associated with the segment.
+
+ Returns:
+ list: A list of smaller segments.
+ """
+ result = []
+ for temp_seg, temp_lang in segments:
+ parts = re.split(r"([<[].*?[>\]])", temp_seg)
+ for part in parts:
+ if not part:
+ continue
+ if self.is_pinyin(part):
+ result.append((part, "pinyin"))
+ elif self.is_tag(part):
+ result.append((part, "tag"))
+ else:
+ result.append((part, temp_lang))
+ return result
+
+ def is_chinese(self, char: str) -> bool:
+ if char >= "\u4e00" and char <= "\u9fa5":
+ return True
+ else:
+ return False
+
+ def is_alphabet(self, char: str) -> bool:
+ if (char >= "\u0041" and char <= "\u005a") or (
+ char >= "\u0061" and char <= "\u007a"
+ ):
+ return True
+ else:
+ return False
+
+ def is_pinyin(self, part: str) -> bool:
+ if part.startswith("<") and part.endswith(">"):
+ return True
+ else:
+ return False
+
+ def is_tag(self, part: str) -> bool:
+ if part.startswith("[") and part.endswith("]"):
+ return True
+ else:
+ return False
+
+
+class DialogTokenizer(EmiliaTokenizer):
+ def __init__(self, token_file: Optional[str] = None, token_type="phone"):
+ super().__init__(token_file=token_file, token_type=token_type)
+ if token_file:
+ self.spk_a_id = self.token2id["[S1]"]
+ self.spk_b_id = self.token2id["[S2]"]
+
+ def preprocess_text(
+ self,
+ text: str,
+ ) -> str:
+ text = re.sub(r"\s*(\[S[12]\])\s*", r"\1", text)
+ text = self.map_punctuations(text)
+ return text
+
+
+class LibriTTSTokenizer(Tokenizer):
+ def __init__(self, token_file: Optional[str] = None, token_type="char"):
+ """
+ Args:
+ type: the type of tokenizer, e.g., bpe, char, phone.
+ tokens: the file that contains information that maps tokens to ids,
+ which is a text file with '{token}\t{token_id}' per line if type is
+ char or phone, otherwise it is a bpe_model file.
+ """
+ self.type = token_type
+ assert token_type in ["bpe", "char", "phone"]
+ try:
+ import tacotron_cleaner.cleaners
+ except Exception as ex:
+ raise RuntimeError(f"{ex}\nPlease run\n" "pip install espnet_tts_frontend")
+
+ self.normalize = tacotron_cleaner.cleaners.custom_english_cleaners
+
+ self.has_tokens = False
+ if token_file is None:
+ logging.debug(
+ "Initialize Tokenizer without tokens file, \
+ will fail when map to ids."
+ )
+ return
+ if token_type == "bpe":
+ import sentencepiece as spm
+
+ self.sp = spm.SentencePieceProcessor()
+ self.sp.load(token_file)
+ self.pad_id = self.sp.piece_to_id("")
+ self.vocab_size = self.sp.get_piece_size()
+ else:
+ self.token2id: Dict[str, int] = {}
+ with open(token_file, "r", encoding="utf-8") as f:
+ for line in f.readlines():
+ info = line.rstrip().split("\t")
+ token, id = info[0], int(info[1])
+ assert token not in self.token2id, token
+ self.token2id[token] = id
+ self.pad_id = self.token2id["_"] # padding
+ self.vocab_size = len(self.token2id)
+ self.has_tokens = True
+
+ def texts_to_token_ids(
+ self,
+ texts: List[str],
+ ) -> List[List[int]]:
+ if self.type == "bpe":
+ for i in range(len(texts)):
+ texts[i] = self.normalize(texts[i])
+ return self.sp.encode(texts)
+ else:
+ return self.tokens_to_token_ids(self.texts_to_tokens(texts))
+
+ def texts_to_tokens(
+ self,
+ texts: List[str],
+ ) -> List[List[str]]:
+ for i in range(len(texts)):
+ texts[i] = self.normalize(texts[i])
+
+ if self.type == "char":
+ tokens_list = [list(texts[i]) for i in range(len(texts))]
+ elif self.type == "phone":
+ tokens_list = [
+ phonemize_espeak(texts[i].lower(), "en-us") for i in range(len(texts))
+ ]
+ elif self.type == "bpe":
+ tokens_list = self.sp.encode(texts, out_type=str)
+
+ return tokens_list
+
+ def tokens_to_token_ids(
+ self,
+ tokens_list: List[List[str]],
+ ) -> List[List[int]]:
+ assert self.has_tokens, "Please initialize Tokenizer with a tokens file."
+
+ assert self.type != "bpe", "BPE tokenizer does not support this function."
+
+ token_ids_list = []
+
+ for tokens in tokens_list:
+ token_ids = []
+ for t in tokens:
+ if t not in self.token2id:
+ logging.debug(f"Skip OOV {t}")
+ continue
+ token_ids.append(self.token2id[t])
+
+ token_ids_list.append(token_ids)
+
+ return token_ids_list
+
+
+def add_tokens(cut_set: CutSet, tokenizer: str, lang: str):
+ if tokenizer == "emilia":
+ tokenizer = EmiliaTokenizer()
+ elif tokenizer == "espeak":
+ tokenizer = EspeakTokenizer(lang=lang)
+ elif tokenizer == "dialog":
+ tokenizer = DialogTokenizer()
+ elif tokenizer == "libritts":
+ tokenizer = LibriTTSTokenizer()
+ elif tokenizer == "simple":
+ tokenizer = SimpleTokenizer()
+ elif tokenizer == "simple2":
+ tokenizer = SimpleTokenizer2()
+ else:
+ raise ValueError(f"Unsupported tokenizer: {tokenizer}.")
+
+ def _prepare_cut(cut):
+ # Each cut only contains one supervision
+ assert len(cut.supervisions) == 1, (len(cut.supervisions), cut)
+ text = cut.supervisions[0].text
+ tokens = tokenizer.texts_to_tokens([text])[0]
+ cut.supervisions[0].tokens = tokens
+ return cut
+
+ cut_set = cut_set.map(_prepare_cut)
+ return cut_set
+
+
+if __name__ == "__main__":
+ text = (
+ "我们是5年小米人,是吗? Yes I think so! "
+ "mr king, 5 years, from 2019 to 2024."
+ "霍...啦啦啦超过90%的人...?!9204"
+ )
+ tokenizer = EmiliaTokenizer()
+ tokens = tokenizer.texts_to_tokens([text])
+ print(f"tokens: {'|'.join(tokens[0])}")
diff --git a/zipvoice/utils/checkpoint.py b/zipvoice/utils/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..35ebe611174d463e208e463527ef990c3d1ddddc
--- /dev/null
+++ b/zipvoice/utils/checkpoint.py
@@ -0,0 +1,572 @@
+# Copyright 2021-2025 Xiaomi Corporation (authors: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import glob
+import logging
+import os
+import re
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+import torch.nn as nn
+from lhotse.dataset.sampling.base import CutSampler
+from torch.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
+
+from zipvoice.utils.common import AttributeDict
+
+# use duck typing for LRScheduler since we have different possibilities, see
+# our class LRScheduler.
+LRSchedulerType = object
+
+
+def save_checkpoint(
+ filename: Path,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ model_ema: Optional[nn.Module] = None,
+ params: Optional[Dict[str, Any]] = None,
+ optimizer: Optional[Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ scaler: Optional[GradScaler] = None,
+ sampler: Optional[CutSampler] = None,
+ rank: int = 0,
+) -> None:
+ """Save training information to a file.
+
+ Args:
+ filename:
+ The checkpoint filename.
+ model:
+ The model to be saved. We only save its `state_dict()`.
+ model_avg:
+ The stored model averaged from the start of training.
+ model_ema:
+ The EMA version of model.
+ params:
+ User defined parameters, e.g., epoch, loss.
+ optimizer:
+ The optimizer to be saved. We only save its `state_dict()`.
+ scheduler:
+ The scheduler to be saved. We only save its `state_dict()`.
+ scalar:
+ The GradScaler to be saved. We only save its `state_dict()`.
+ sampler:
+ The sampler used in the labeled training dataset. We only
+ save its `state_dict()`.
+ rank:
+ Used in DDP. We save checkpoint only for the node whose
+ rank is 0.
+ Returns:
+ Return None.
+ """
+ if rank != 0:
+ return
+
+ logging.info(f"Saving checkpoint to {filename}")
+
+ if isinstance(model, DDP):
+ model = model.module
+
+ checkpoint = {
+ "model": model.state_dict(),
+ "optimizer": optimizer.state_dict() if optimizer is not None else None,
+ "scheduler": scheduler.state_dict() if scheduler is not None else None,
+ "grad_scaler": scaler.state_dict() if scaler is not None else None,
+ "sampler": sampler.state_dict() if sampler is not None else None,
+ }
+
+ if model_avg is not None:
+ checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict()
+ if model_ema is not None:
+ checkpoint["model_ema"] = model_ema.to(torch.float32).state_dict()
+
+ if params:
+ for k, v in params.items():
+ assert k not in checkpoint
+ checkpoint[k] = v
+
+ torch.save(checkpoint, filename)
+
+
+def load_checkpoint(
+ filename: Path,
+ model: Optional[nn.Module] = None,
+ model_avg: Optional[nn.Module] = None,
+ model_ema: Optional[nn.Module] = None,
+ strict: bool = False,
+) -> Dict[str, Any]:
+ logging.info(f"Loading checkpoint from {filename}")
+ checkpoint = torch.load(filename, map_location="cpu", weights_only=False)
+
+ if model is not None:
+
+ if next(iter(checkpoint["model"])).startswith("module."):
+ logging.info("Loading checkpoint saved by DDP")
+
+ dst_state_dict = model.state_dict()
+ src_state_dict = checkpoint["model"]
+ for key in dst_state_dict.keys():
+ src_key = "{}.{}".format("module", key)
+ dst_state_dict[key] = src_state_dict.pop(src_key)
+ assert len(src_state_dict) == 0
+ model.load_state_dict(dst_state_dict, strict=strict)
+ else:
+ logging.info("Loading checkpoint")
+ model.load_state_dict(checkpoint["model"], strict=strict)
+
+ checkpoint.pop("model")
+
+ if model_avg is not None and "model_avg" in checkpoint:
+ logging.info("Loading averaged model")
+ model_avg.load_state_dict(checkpoint["model_avg"], strict=strict)
+ checkpoint.pop("model_avg")
+
+ if model_ema is not None and "model_ema" in checkpoint:
+ logging.info("Loading ema model")
+ model_ema.load_state_dict(checkpoint["model_ema"], strict=strict)
+ checkpoint.pop("model_ema")
+
+ return checkpoint
+
+
+def load_checkpoint_extend_vocab_size(
+ filename: Path, extend_size: int, model: nn.Module, strict: bool = True
+) -> Dict[str, Any]:
+ logging.info(f"Loading checkpoint from {filename}")
+ checkpoint = torch.load(filename, map_location="cpu", weights_only=False)
+
+ if model is not None:
+ if next(iter(checkpoint["model"])).startswith("module."):
+ logging.info("Loading checkpoint saved by DDP")
+ dst_state_dict = model.state_dict()
+ src_state_dict = checkpoint["model"]
+ for key in dst_state_dict.keys():
+ src_key = "{}.{}".format("module", key)
+ dst_state_dict[key] = src_state_dict.pop(src_key)
+ assert len(src_state_dict) == 0
+ else:
+ logging.info("Loading checkpoint")
+ dst_state_dict = checkpoint["model"]
+ dst_state_dict["spk_embed.weight"] = model.state_dict()["spk_embed.weight"]
+ embed_weight = model.state_dict()["embed.weight"]
+ embed_weight[:-extend_size, :] = dst_state_dict["embed.weight"]
+ dst_state_dict["embed.weight"] = embed_weight
+
+ model.load_state_dict(dst_state_dict, strict=strict)
+
+
+def load_checkpoint_copy_proj_three_channel_alter(
+ filename: Path,
+ in_proj_key: str,
+ out_proj_key: str,
+ dim: int,
+ model: nn.Module,
+) -> Dict[str, Any]:
+ logging.info(f"Loading checkpoint from {filename}")
+ checkpoint = torch.load(filename, map_location="cpu", weights_only=False)
+
+ if model is not None:
+ if next(iter(checkpoint["model"])).startswith("module."):
+ logging.info("Loading checkpoint saved by DDP")
+
+ dst_state_dict = dict()
+ src_state_dict = checkpoint["model"]
+ for key in src_state_dict.keys():
+ dst_state_dict[key.lstrip("module.")] = src_state_dict.pop(key)
+ assert len(src_state_dict) == 0
+ else:
+ logging.info("Loading checkpoint")
+ dst_state_dict = checkpoint["model"]
+ keys = list(dst_state_dict.keys())
+ for key in keys:
+ if in_proj_key in key:
+ if "weight" in key:
+ weight = dst_state_dict.pop(key)
+ dst_state_dict[key.replace("weight", "0.weight")] = torch.cat(
+ [
+ weight[:, :dim] / 2,
+ weight[:, :dim] / 2,
+ weight[:, dim : dim * 2],
+ weight[:, dim * 2 :] / 2,
+ weight[:, dim * 2 :] / 2,
+ ],
+ dim=-1,
+ )
+ dst_state_dict[key.replace("weight", "1.weight")] = weight
+ if "bias" in key:
+ bias = dst_state_dict.pop(key)
+ dst_state_dict[key.replace("bias", "0.bias")] = bias
+ dst_state_dict[key.replace("bias", "1.bias")] = bias
+ if out_proj_key in key:
+ if "weight" in key:
+ weight = dst_state_dict.pop(key)
+ dst_state_dict[key.replace("weight", "0.weight")] = torch.cat(
+ [weight, weight], dim=0
+ )
+ dst_state_dict[key.replace("weight", "1.weight")] = weight
+ elif "bias" in key:
+ bias = dst_state_dict.pop(key)
+ dst_state_dict[key.replace("bias", "0.bias")] = torch.cat(
+ [bias, bias], dim=0
+ )
+ dst_state_dict[key.replace("bias", "1.bias")] = bias
+
+ model.load_state_dict(dst_state_dict, strict=True)
+
+
+def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
+ """Find all available checkpoints in a directory.
+
+ The checkpoint filenames have the form: `checkpoint-xxx.pt`
+ where xxx is a numerical value.
+
+ Assume you have the following checkpoints in the folder `foo`:
+
+ - checkpoint-1.pt
+ - checkpoint-20.pt
+ - checkpoint-300.pt
+ - checkpoint-4000.pt
+
+ Case 1 (Return all checkpoints)::
+
+ find_checkpoints(out_dir='foo')
+
+ Case 2 (Return checkpoints newer than checkpoint-20.pt, i.e.,
+ checkpoint-4000.pt, checkpoint-300.pt, and checkpoint-20.pt)
+
+ find_checkpoints(out_dir='foo', iteration=20)
+
+ Case 3 (Return checkpoints older than checkpoint-20.pt, i.e.,
+ checkpoint-20.pt, checkpoint-1.pt)::
+
+ find_checkpoints(out_dir='foo', iteration=-20)
+
+ Args:
+ out_dir:
+ The directory where to search for checkpoints.
+ iteration:
+ If it is 0, return all available checkpoints.
+ If it is positive, return the checkpoints whose iteration number is
+ greater than or equal to `iteration`.
+ If it is negative, return the checkpoints whose iteration number is
+ less than or equal to `-iteration`.
+ Returns:
+ Return a list of checkpoint filenames, sorted in descending
+ order by the numerical value in the filename.
+ """
+ checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
+ pattern = re.compile(r"checkpoint-([0-9]+).pt")
+ iter_checkpoints = []
+ for c in checkpoints:
+ result = pattern.search(c)
+ if not result:
+ logging.warn(f"Invalid checkpoint filename {c}")
+ continue
+
+ iter_checkpoints.append((int(result.group(1)), c))
+
+ # iter_checkpoints is a list of tuples. Each tuple contains
+ # two elements: (iteration_number, checkpoint-iteration_number.pt)
+
+ iter_checkpoints = sorted(iter_checkpoints, reverse=True, key=lambda x: x[0])
+ if iteration >= 0:
+ ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration]
+ else:
+ ans = [ic[1] for ic in iter_checkpoints if ic[0] <= -iteration]
+
+ return ans
+
+
+def average_checkpoints_with_averaged_model(
+ filename_start: str,
+ filename_end: str,
+ device: torch.device = torch.device("cpu"),
+) -> Dict[str, torch.Tensor]:
+ """Average model parameters over the range with given
+ start model (excluded) and end model.
+
+ Let start = batch_idx_train of model-start;
+ end = batch_idx_train of model-end;
+ interval = end - start.
+ Then the average model over range from start (excluded) to end is
+ (1) avg = (model_end * end - model_start * start) / interval.
+ It can be written as
+ (2) avg = model_end * weight_end + model_start * weight_start,
+ where weight_end = end / interval,
+ weight_start = -start / interval = 1 - weight_end.
+ Since the terms `weight_end` and `weight_start` would be large
+ if the model has been trained for lots of batches, which would cause
+ overflow when multiplying the model parameters.
+ To avoid this, we rewrite (2) as:
+ (3) avg = (model_end + model_start * (weight_start / weight_end))
+ * weight_end
+
+ The model index could be epoch number or iteration number.
+
+ Args:
+ filename_start:
+ Checkpoint filename of the start model. We assume it
+ is saved by :func:`save_checkpoint`.
+ filename_end:
+ Checkpoint filename of the end model. We assume it
+ is saved by :func:`save_checkpoint`.
+ device:
+ Move checkpoints to this device before averaging.
+ """
+ state_dict_start = torch.load(
+ filename_start, map_location=device, weights_only=False
+ )
+ state_dict_end = torch.load(filename_end, map_location=device, weights_only=False)
+
+ average_period = state_dict_start["average_period"]
+
+ batch_idx_train_start = state_dict_start["batch_idx_train"]
+ batch_idx_train_start = (batch_idx_train_start // average_period) * average_period
+ batch_idx_train_end = state_dict_end["batch_idx_train"]
+ batch_idx_train_end = (batch_idx_train_end // average_period) * average_period
+ interval = batch_idx_train_end - batch_idx_train_start
+ assert interval > 0, interval
+ weight_end = batch_idx_train_end / interval
+ weight_start = 1 - weight_end
+
+ model_end = state_dict_end["model_avg"]
+ model_start = state_dict_start["model_avg"]
+ avg = model_end
+
+ # scale the weight to avoid overflow
+ average_state_dict(
+ state_dict_1=avg,
+ state_dict_2=model_start,
+ weight_1=1.0,
+ weight_2=weight_start / weight_end,
+ scaling_factor=weight_end,
+ )
+
+ return avg
+
+
+def remove_checkpoints(
+ out_dir: Path,
+ topk: int,
+ rank: int = 0,
+):
+ """Remove checkpoints from the given directory.
+
+ We assume that checkpoint filename has the form `checkpoint-xxx.pt`
+ where xxx is a number, representing the number of processed batches
+ when saving that checkpoint. We sort checkpoints by filename and keep
+ only the `topk` checkpoints with the highest `xxx`.
+
+ Args:
+ out_dir:
+ The directory containing checkpoints to be removed.
+ topk:
+ Number of checkpoints to keep.
+ rank:
+ If using DDP for training, it is the rank of the current node.
+ Use 0 if no DDP is used for training.
+ """
+ assert topk >= 1, topk
+ if rank != 0:
+ return
+ checkpoints = find_checkpoints(out_dir)
+
+ if len(checkpoints) == 0:
+ logging.warn(f"No checkpoints found in {out_dir}")
+ return
+
+ if len(checkpoints) <= topk:
+ return
+
+ to_remove = checkpoints[topk:]
+ for c in to_remove:
+ os.remove(c)
+
+
+def resume_checkpoint(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module,
+ model_ema: Optional[nn.Module] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ filename = params.exp_dir / f"epoch-{params.start_epoch - 1}.pt"
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ model_ema=model_ema,
+ strict=True,
+ )
+
+ if params.start_epoch > 1:
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ return saved_params
+
+
+def average_state_dict(
+ state_dict_1: Dict[str, torch.Tensor],
+ state_dict_2: Dict[str, torch.Tensor],
+ weight_1: float,
+ weight_2: float,
+ scaling_factor: float = 1.0,
+) -> Dict[str, torch.Tensor]:
+ """Average two state_dict with given weights:
+ state_dict_1 = (state_dict_1 * weight_1 + state_dict_2 * weight_2)
+ * scaling_factor
+ It is an in-place operation on state_dict_1 itself.
+ """
+ # Identify shared parameters. Two parameters are said to be shared
+ # if they have the same data_ptr
+ uniqued: Dict[int, str] = dict()
+ for k, v in state_dict_1.items():
+ v_data_ptr = v.data_ptr()
+ if v_data_ptr in uniqued:
+ continue
+ uniqued[v_data_ptr] = k
+
+ uniqued_names = list(uniqued.values())
+ for k in uniqued_names:
+ v = state_dict_1[k]
+ if torch.is_floating_point(v):
+ v *= weight_1
+ v += state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
+ v *= scaling_factor
+
+
+def update_averaged_model(
+ params: Dict[str, torch.Tensor],
+ model_cur: Union[nn.Module, DDP],
+ model_avg: nn.Module,
+) -> None:
+ """Update the averaged model:
+ model_avg = model_cur * (average_period / batch_idx_train)
+ + model_avg * ((batch_idx_train - average_period) / batch_idx_train)
+
+ Args:
+ params:
+ User defined parameters, e.g., epoch, loss.
+ model_cur:
+ The current model.
+ model_avg:
+ The averaged model to be updated.
+ """
+ weight_cur = params.average_period / params.batch_idx_train
+ weight_avg = 1 - weight_cur
+
+ if isinstance(model_cur, DDP):
+ model_cur = model_cur.module
+
+ cur = model_cur.state_dict()
+ avg = model_avg.state_dict()
+
+ average_state_dict(
+ state_dict_1=avg,
+ state_dict_2=cur,
+ weight_1=weight_avg,
+ weight_2=weight_cur,
+ )
+
+
+def save_checkpoint_with_global_batch_idx(
+ out_dir: Path,
+ global_batch_idx: int,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ params: Optional[Dict[str, Any]] = None,
+ optimizer: Optional[Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ scaler: Optional[GradScaler] = None,
+ sampler: Optional[CutSampler] = None,
+ rank: int = 0,
+):
+ """Save training info after processing given number of batches.
+
+ Args:
+ out_dir:
+ The directory to save the checkpoint.
+ global_batch_idx:
+ The number of batches processed so far from the very start of the
+ training. The saved checkpoint will have the following filename:
+
+ f'out_dir / checkpoint-{global_batch_idx}.pt'
+ model:
+ The neural network model whose `state_dict` will be saved in the
+ checkpoint.
+ model_avg:
+ The stored model averaged from the start of training.
+ params:
+ A dict of training configurations to be saved.
+ optimizer:
+ The optimizer used in the training. Its `state_dict` will be saved.
+ scheduler:
+ The learning rate scheduler used in the training. Its `state_dict` will
+ be saved.
+ scaler:
+ The scaler used for mix precision training. Its `state_dict` will
+ be saved.
+ sampler:
+ The sampler used in the training dataset.
+ rank:
+ The rank ID used in DDP training of the current node. Set it to 0
+ if DDP is not used.
+ """
+ out_dir = Path(out_dir)
+ out_dir.mkdir(parents=True, exist_ok=True)
+ filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
+ save_checkpoint(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ scaler=scaler,
+ sampler=sampler,
+ rank=rank,
+ )
diff --git a/zipvoice/utils/common.py b/zipvoice/utils/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa80aedc55eaf03d9347269344ec7c318e6aff85
--- /dev/null
+++ b/zipvoice/utils/common.py
@@ -0,0 +1,664 @@
+import argparse
+import collections
+import json
+import logging
+import os
+import socket
+import subprocess
+import sys
+import warnings
+from collections import defaultdict
+from contextlib import contextmanager
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List, Tuple, Union
+
+import torch
+from packaging import version
+from torch import distributed as dist
+from torch import nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+Pathlike = Union[str, Path]
+
+
+class AttributeDict(dict):
+ def __getattr__(self, key):
+ if key in self:
+ return self[key]
+ raise AttributeError(f"No such attribute '{key}'")
+
+ def __setattr__(self, key, value):
+ self[key] = value
+
+ def __delattr__(self, key):
+ if key in self:
+ del self[key]
+ return
+ raise AttributeError(f"No such attribute '{key}'")
+
+ def __str__(self, indent: int = 2):
+ tmp = {}
+ for k, v in self.items():
+ # PosixPath is ont JSON serializable
+ if isinstance(v, (Path, torch.device, torch.dtype)):
+ v = str(v)
+ tmp[k] = v
+ return json.dumps(tmp, indent=indent, sort_keys=True)
+
+
+class MetricsTracker(collections.defaultdict):
+ def __init__(self):
+ # Passing the type 'int' to the base-class constructor
+ # makes undefined items default to int() which is zero.
+ # This class will play a role as metrics tracker.
+ # It can record many metrics, including but not limited to loss.
+ super(MetricsTracker, self).__init__(int)
+
+ def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
+ ans = MetricsTracker()
+ for k, v in self.items():
+ ans[k] = v
+ for k, v in other.items():
+ if v - v == 0:
+ ans[k] = ans[k] + v
+ return ans
+
+ def __mul__(self, alpha: float) -> "MetricsTracker":
+ ans = MetricsTracker()
+ for k, v in self.items():
+ ans[k] = v * alpha
+ return ans
+
+ def __str__(self) -> str:
+ ans_frames = ""
+ ans_utterances = ""
+ for k, v in self.norm_items():
+ norm_value = "%.4g" % v
+ if "utt_" not in k:
+ ans_frames += str(k) + "=" + str(norm_value) + ", "
+ else:
+ ans_utterances += str(k) + "=" + str(norm_value)
+ if k == "utt_duration":
+ ans_utterances += " frames, "
+ elif k == "utt_pad_proportion":
+ ans_utterances += ", "
+ else:
+ raise ValueError(f"Unexpected key: {k}")
+ frames = "%.2f" % self["frames"]
+ ans_frames += "over " + str(frames) + " frames. "
+ if ans_utterances != "":
+ utterances = "%.2f" % self["utterances"]
+ ans_utterances += "over " + str(utterances) + " utterances."
+
+ return ans_frames + ans_utterances
+
+ def norm_items(self) -> List[Tuple[str, float]]:
+ """
+ Returns a list of pairs, like:
+ [('ctc_loss', 0.1), ('att_loss', 0.07)]
+ """
+ num_frames = self["frames"] if "frames" in self else 1
+ num_utterances = self["utterances"] if "utterances" in self else 1
+ ans = []
+ for k, v in self.items():
+ if k == "frames" or k == "utterances":
+ continue
+ norm_value = (
+ float(v) / num_frames if "utt_" not in k else float(v) / num_utterances
+ )
+ ans.append((k, norm_value))
+ return ans
+
+ def reduce(self, device):
+ """
+ Reduce using torch.distributed, which I believe ensures that
+ all processes get the total.
+ """
+ keys = sorted(self.keys())
+ s = torch.tensor([float(self[k]) for k in keys], device=device)
+ dist.all_reduce(s, op=dist.ReduceOp.SUM)
+ for k, v in zip(keys, s.cpu().tolist()):
+ self[k] = v
+
+ def write_summary(
+ self,
+ tb_writer: SummaryWriter,
+ prefix: str,
+ batch_idx: int,
+ ) -> None:
+ """Add logging information to a TensorBoard writer.
+
+ Args:
+ tb_writer: a TensorBoard writer
+ prefix: a prefix for the name of the loss, e.g. "train/valid_",
+ or "train/current_"
+ batch_idx: The current batch index, used as the x-axis of the plot.
+ """
+ for k, v in self.norm_items():
+ tb_writer.add_scalar(prefix + k, v, batch_idx)
+
+
+@contextmanager
+def torch_autocast(device_type="cuda", **kwargs):
+ """
+ To fix the following warnings:
+ FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated.
+ Please use `torch.amp.autocast('cuda', args...)` instead.
+ with torch.cuda.amp.autocast(enabled=False):
+ """
+ if version.parse(torch.__version__) >= version.parse("2.3.0"):
+ # Use new unified API
+ with torch.amp.autocast(device_type=device_type, **kwargs):
+ yield
+ else:
+ # Suppress deprecation warning and use old CUDA-specific autocast
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=FutureWarning)
+ with torch.cuda.amp.autocast(**kwargs):
+ yield
+
+
+def create_grad_scaler(device="cuda", **kwargs):
+ """
+ Creates a GradScaler compatible with both torch < 2.3.0 and >= 2.3.0.
+ Accepts all kwargs like: enabled, init_scale, growth_factor, etc.
+
+ FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated.
+ Please use `torch.amp.GradScaler('cuda', args...)` instead.
+ """
+ if version.parse(torch.__version__) >= version.parse("2.3.0"):
+ from torch.amp import GradScaler
+
+ return GradScaler(device=device, **kwargs)
+ else:
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=FutureWarning)
+ return torch.cuda.amp.GradScaler(**kwargs)
+
+
+def setup_dist(
+ rank=None,
+ world_size=None,
+ master_port=None,
+ use_ddp_launch=False,
+ master_addr=None,
+):
+ """
+ rank and world_size are used only if use_ddp_launch is False.
+ """
+ if "MASTER_ADDR" not in os.environ:
+ os.environ["MASTER_ADDR"] = (
+ "localhost" if master_addr is None else str(master_addr)
+ )
+
+ if "MASTER_PORT" not in os.environ:
+ os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port)
+
+ if use_ddp_launch is False:
+ dist.init_process_group("nccl", rank=rank, world_size=world_size)
+ torch.cuda.set_device(rank)
+ else:
+ dist.init_process_group("nccl")
+
+
+def cleanup_dist():
+ dist.destroy_process_group()
+
+
+def prepare_input(
+ params: AttributeDict,
+ batch: dict,
+ device: torch.device,
+ return_tokens: bool = True,
+ return_feature: bool = True,
+ return_audio: bool = False,
+):
+ """
+ Parse the features and targets of the current batch.
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ device:
+ The device of Tensor.
+ """
+ return_list = []
+
+ if return_tokens:
+ return_list += [batch["tokens"]]
+
+ if return_feature:
+ features = batch["features"].to(device)
+ features_lens = batch["features_lens"].to(device)
+ return_list += [features * params.feat_scale, features_lens]
+
+ if return_audio:
+ return_list += [batch["audio"], batch["audio_lens"]]
+
+ return return_list
+
+
+def prepare_avg_tokens_durations(features_lens, tokens_lens):
+ tokens_durations = []
+ for i in range(len(features_lens)):
+ utt_duration = features_lens[i]
+ avg_token_duration = utt_duration // tokens_lens[i]
+ tokens_durations.append([avg_token_duration] * tokens_lens[i])
+ return tokens_durations
+
+
+def pad_labels(y: List[List[int]], pad_id: int, device: torch.device):
+ """
+ Pad the transcripts to the same length with zeros.
+
+ Args:
+ y: the transcripts, which is a list of a list
+
+ Returns:
+ Return a Tensor of padded transcripts.
+ """
+ y = [token_ids + [pad_id] for token_ids in y]
+ length = max([len(token_ids) for token_ids in y])
+ y = [token_ids + [pad_id] * (length - len(token_ids)) for token_ids in y]
+ return torch.tensor(y, dtype=torch.int64, device=device)
+
+
+def get_tokens_index(durations: List[List[int]], num_frames: int) -> torch.Tensor:
+ """
+ Gets position in the transcript for each frame, i.e. the position
+ in the symbol-sequence to look up.
+
+ Args:
+ durations:
+ Duration of each token in transcripts.
+ num_frames:
+ The maximum frame length of the current batch.
+
+ Returns:
+ Return a Tensor of shape (batch_size, num_frames)
+ """
+ durations = [x + [num_frames - sum(x)] for x in durations]
+ batch_size = len(durations)
+ ans = torch.zeros(batch_size, num_frames, dtype=torch.int64)
+ for b in range(batch_size):
+ this_dur = durations[b]
+ cur_frame = 0
+ for i, d in enumerate(this_dur):
+ ans[b, cur_frame : cur_frame + d] = i
+ cur_frame += d
+ assert cur_frame == num_frames, (cur_frame, num_frames)
+ return ans
+
+
+def to_int_tuple(s: Union[str, int]):
+ if isinstance(s, int):
+ return (s,)
+ return tuple(map(int, s.split(",")))
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+ # returns the number of batches we would have used so far if we had used the
+ # reference duration. This is for purposes of set_batch_count().
+ return (
+ params.batch_idx_train
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ )
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def condition_time_mask(
+ features_lens: torch.Tensor,
+ mask_percent: Tuple[float, float],
+ max_len: int = 0,
+) -> torch.Tensor:
+ """
+ Apply Time masking.
+ Args:
+ features_lens:
+ input tensor of shape ``(B)``
+ mask_size:
+ the width size for masking.
+ max_len:
+ the maximum length of the mask.
+ Returns:
+ Return a 2-D bool tensor (B, T), where masked positions
+ are filled with `True` and non-masked positions are
+ filled with `False`.
+ """
+ mask_size = (
+ torch.zeros_like(features_lens, dtype=torch.float32).uniform_(*mask_percent)
+ * features_lens
+ ).to(torch.int64)
+ mask_starts = (
+ torch.rand_like(mask_size, dtype=torch.float32) * (features_lens - mask_size)
+ ).to(torch.int64)
+ mask_ends = mask_starts + mask_size
+ max_len = max(max_len, features_lens.max())
+ seq_range = torch.arange(0, max_len, device=features_lens.device)
+ mask = (seq_range[None, :] >= mask_starts[:, None]) & (
+ seq_range[None, :] < mask_ends[:, None]
+ )
+ return mask
+
+
+def condition_time_mask_suffix(
+ features_lens: torch.Tensor,
+ mask_percent: Tuple[float, float],
+ max_len: int = 0,
+) -> torch.Tensor:
+ """
+ Apply Time masking, mask from the end time index.
+ Args:
+ features_lens:
+ input tensor of shape ``(B)``
+ mask_size:
+ the width size for masking.
+ max_len:
+ the maximum length of the mask.
+ Returns:
+ Return a 2-D bool tensor (B, T), where masked positions
+ are filled with `True` and non-masked positions are
+ filled with `False`.
+ """
+ mask_size = (
+ torch.zeros_like(features_lens, dtype=torch.float32).uniform_(*mask_percent)
+ * features_lens
+ ).to(torch.int64)
+ mask_starts = (
+ torch.ones_like(mask_size, dtype=torch.float32) * (features_lens - mask_size)
+ ).to(torch.int64)
+ mask_ends = mask_starts + mask_size
+ max_len = max(max_len, features_lens.max())
+ seq_range = torch.arange(0, max_len, device=features_lens.device)
+ mask = (seq_range[None, :] >= mask_starts[:, None]) & (
+ seq_range[None, :] < mask_ends[:, None]
+ )
+ return mask
+
+
+def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
+ """
+ Args:
+ lengths:
+ A 1-D tensor containing sentence lengths.
+ max_len:
+ The length of masks.
+ Returns:
+ Return a 2-D bool tensor, where masked positions
+ are filled with `True` and non-masked positions are
+ filled with `False`.
+
+ >>> lengths = torch.tensor([1, 3, 2, 5])
+ >>> make_pad_mask(lengths)
+ tensor([[False, True, True, True, True],
+ [False, False, False, True, True],
+ [False, False, True, True, True],
+ [False, False, False, False, False]])
+ """
+ assert lengths.ndim == 1, lengths.ndim
+ max_len = max(max_len, lengths.max())
+ n = lengths.size(0)
+ seq_range = torch.arange(0, max_len, device=lengths.device)
+ expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
+
+ return expaned_lengths >= lengths.unsqueeze(-1)
+
+
+def str2bool(v):
+ """Used in argparse.ArgumentParser.add_argument to indicate
+ that a type is a bool type and user can enter
+
+ - yes, true, t, y, 1, to represent True
+ - no, false, f, n, 0, to represent False
+
+ See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa
+ """
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ("yes", "true", "t", "y", "1"):
+ return True
+ elif v.lower() in ("no", "false", "f", "n", "0"):
+ return False
+ else:
+ raise argparse.ArgumentTypeError("Boolean value expected.")
+
+
+def setup_logger(
+ log_filename: Pathlike,
+ log_level: str = "info",
+ use_console: bool = True,
+) -> None:
+ """Setup log level.
+
+ Args:
+ log_filename:
+ The filename to save the log.
+ log_level:
+ The log level to use, e.g., "debug", "info", "warning", "error",
+ "critical"
+ use_console:
+ True to also print logs to console.
+ """
+ now = datetime.now()
+ date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
+ if dist.is_available() and dist.is_initialized():
+ world_size = dist.get_world_size()
+ rank = dist.get_rank()
+ formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" # noqa
+ log_filename = f"{log_filename}-{date_time}-{rank}"
+ else:
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ log_filename = f"{log_filename}-{date_time}"
+
+ os.makedirs(os.path.dirname(log_filename), exist_ok=True)
+
+ level = logging.ERROR
+ if log_level == "debug":
+ level = logging.DEBUG
+ elif log_level == "info":
+ level = logging.INFO
+ elif log_level == "warning":
+ level = logging.WARNING
+ elif log_level == "critical":
+ level = logging.CRITICAL
+
+ logging.basicConfig(
+ filename=log_filename,
+ format=formatter,
+ level=level,
+ filemode="w",
+ force=True,
+ )
+ if use_console:
+ console = logging.StreamHandler()
+ console.setLevel(level)
+ console.setFormatter(logging.Formatter(formatter))
+ logging.getLogger("").addHandler(console)
+
+
+def get_git_sha1():
+ try:
+ git_commit = (
+ subprocess.run(
+ ["git", "rev-parse", "--short", "HEAD"],
+ check=True,
+ stdout=subprocess.PIPE,
+ )
+ .stdout.decode()
+ .rstrip("\n")
+ .strip()
+ )
+ dirty_commit = (
+ len(
+ subprocess.run(
+ ["git", "diff", "--shortstat"],
+ check=True,
+ stdout=subprocess.PIPE,
+ )
+ .stdout.decode()
+ .rstrip("\n")
+ .strip()
+ )
+ > 0
+ )
+ git_commit = git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
+ except: # noqa
+ return None
+
+ return git_commit
+
+
+def get_git_date():
+ try:
+ git_date = (
+ subprocess.run(
+ ["git", "log", "-1", "--format=%ad", "--date=local"],
+ check=True,
+ stdout=subprocess.PIPE,
+ )
+ .stdout.decode()
+ .rstrip("\n")
+ .strip()
+ )
+ except: # noqa
+ return None
+
+ return git_date
+
+
+def get_git_branch_name():
+ try:
+ git_date = (
+ subprocess.run(
+ ["git", "rev-parse", "--abbrev-ref", "HEAD"],
+ check=True,
+ stdout=subprocess.PIPE,
+ )
+ .stdout.decode()
+ .rstrip("\n")
+ .strip()
+ )
+ except: # noqa
+ return None
+
+ return git_date
+
+
+def get_env_info() -> Dict[str, Any]:
+ """Get the environment information."""
+ return {
+ "torch-version": str(torch.__version__),
+ "torch-cuda-available": torch.cuda.is_available(),
+ "torch-cuda-version": torch.version.cuda,
+ "python-version": sys.version[:4],
+ "zipvoice-git-branch": get_git_branch_name(),
+ "zipvoice-git-sha1": get_git_sha1(),
+ "zipvoice-git-date": get_git_date(),
+ "zipvoice-path": str(Path(__file__).resolve().parent.parent),
+ "hostname": socket.gethostname(),
+ "IP address": socket.gethostbyname(socket.gethostname()),
+ }
+
+
+def get_parameter_groups_with_lrs(
+ model: nn.Module,
+ lr: float,
+ include_names: bool = False,
+ freeze_modules: List[str] = [],
+ unfreeze_modules: List[str] = [],
+) -> List[dict]:
+ """
+ This is for use with the ScaledAdam optimizers (more recent versions that accept
+ lists of named-parameters; we can, if needed, create a version without the names).
+
+ It provides a way to specify learning-rate scales inside the module, so that if
+ any nn.Module in the hierarchy has a floating-point parameter 'lr_scale', it will
+ scale the LR of any parameters inside that module or its submodules. Note: you
+ can set module parameters outside the __init__ function, e.g.:
+ >>> a = nn.Linear(10, 10)
+ >>> a.lr_scale = 0.5
+
+ Returns: a list of dicts, of the following form:
+ if include_names == False:
+ [ { 'params': [ tensor1, tensor2, ... ], 'lr': 0.01 },
+ { 'params': [ tensor3, tensor4, ... ], 'lr': 0.005 },
+ ... ]
+ if include_names == true:
+ [ { 'named_params': [ (name1, tensor1, (name2, tensor2), ... ], 'lr': 0.01 },
+ { 'named_params': [ (name3, tensor3), (name4, tensor4), ... ], 'lr': 0.005 },
+ ... ]
+
+ """
+ # Use freeze_modules or unfreeze_modules to freeze or unfreeze modules
+ assert not (len(freeze_modules) and len(unfreeze_modules))
+
+ # flat_lr_scale just contains the lr_scale explicitly specified
+ # for each prefix of the name, e.g. 'encoder.layers.3', these need
+ # to be multiplied for all prefix of the name of any given parameter.
+ flat_lr_scale = defaultdict(lambda: 1.0)
+ names = []
+ for name, m in model.named_modules():
+ names.append(name)
+ if hasattr(m, "lr_scale"):
+ flat_lr_scale[name] = m.lr_scale
+
+ # lr_to_parames is a dict from learning rate (floating point) to: if
+ # include_names == true, a list of (name, parameter) for that learning rate;
+ # otherwise a list of parameters for that learning rate.
+ lr_to_params = defaultdict(list)
+
+ for name, parameter in model.named_parameters():
+ if not parameter.requires_grad:
+ logging.info(f"Remove {name} from parameter")
+ continue
+ split_name = name.split(".")
+ # caution: as a special case, if the name is '', split_name will be [ '' ].
+ prefix = split_name[0]
+ if len(freeze_modules) > 0:
+ if prefix == "module": # DDP
+ module_name = split_name[1]
+ if module_name in freeze_modules:
+ logging.info(f"Remove {name} from parameters")
+ continue
+ else:
+ if prefix in freeze_modules:
+ logging.info(f"Remove {name} from parameters")
+ continue
+ elif len(unfreeze_modules) > 0:
+ if prefix == "module": # DDP
+ module_name = split_name[1]
+ if module_name not in unfreeze_modules:
+ logging.info(f"Remove {name} from parameters")
+ continue
+ else:
+ if prefix not in unfreeze_modules:
+ logging.info(f"Remove {name} from parameters")
+ continue
+ cur_lr = lr * flat_lr_scale[prefix]
+ if prefix != "":
+ cur_lr *= flat_lr_scale[""]
+ for part in split_name[1:]:
+ prefix = ".".join([prefix, part])
+ cur_lr *= flat_lr_scale[prefix]
+ lr_to_params[cur_lr].append((name, parameter) if include_names else parameter)
+
+ if include_names:
+ return [{"named_params": pairs, "lr": lr} for lr, pairs in lr_to_params.items()]
+ else:
+ return [{"params": params, "lr": lr} for lr, params in lr_to_params.items()]
diff --git a/zipvoice/utils/diagnostics.py b/zipvoice/utils/diagnostics.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdefaa38019db9f1028333756ebc75535c40b6c4
--- /dev/null
+++ b/zipvoice/utils/diagnostics.py
@@ -0,0 +1,723 @@
+# Copyright 2022-2024 Xiaomi Corp. (authors: Daniel Povey
+# Zengwei Yao
+# Mingshuang Luo,
+# Zengrui Jin,)
+#
+# See ../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import random
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor, nn
+
+
+class TensorDiagnosticOptions(object):
+ """Options object for tensor diagnostics:
+
+ Args:
+ max_eig_dim:
+ The maximum dimension for which we print out eigenvalues
+ (limited for speed reasons).
+ """
+
+ def __init__(self, max_eig_dim: int = 512):
+ self.max_eig_dim = max_eig_dim
+
+ def dim_is_summarized(self, size: int):
+ return size > 10 and size != 31
+
+
+def get_tensor_stats(
+ x: Tensor,
+ dim: int,
+ stats_type: str,
+) -> Tuple[Tensor, int]:
+ """
+ Returns the specified transformation of the Tensor (either x or x.abs()
+ or (x > 0), summed over all but the index `dim`.
+
+ Args:
+ x:
+ Tensor, tensor to be analyzed
+ dim:
+ Dimension with 0 <= dim < x.ndim
+ stats_type:
+ The stats_type includes several types:
+ "abs" -> take abs() before summing
+ "positive" -> take (x > 0) before summing
+ "rms" -> square before summing, we'll take sqrt later
+ "value" -> just sum x itself
+ "max", "min" -> take the maximum or minimum [over all other dims but dim]
+ instead of summing
+ "rms-sort" -> this is a bit different than the others, it's based on computing
+ the rms over the specified dim and returning percentiles of the result
+ (11 of them).
+ Returns:
+ stats: a Tensor of shape (x.shape[dim],).
+ count: an integer saying how many items were counted in each element
+ of stats.
+ """
+
+ if stats_type == "rms-sort":
+ rms = (x**2).mean(dim=dim).sqrt()
+ rms = rms.flatten()
+ rms = rms.sort()[0]
+ rms = rms[(torch.arange(11) * rms.numel() // 10).clamp(max=rms.numel() - 1)]
+ count = 1.0
+ return rms, count
+
+ count = x.numel() // x.shape[dim]
+
+ if stats_type == "eigs":
+ x = x.transpose(dim, -1)
+ x = x.reshape(-1, x.shape[-1])
+ # shape of returned tensor: (s, s),
+ # where s is size of dimension `dim` of original x.
+ return torch.matmul(x.transpose(0, 1), x), count
+ elif stats_type == "abs":
+ x = x.abs()
+ elif stats_type == "rms":
+ x = x**2
+ elif stats_type == "positive":
+ x = (x > 0).to(dtype=torch.float)
+ else:
+ assert stats_type in ["value", "max", "min"]
+
+ sum_dims = [d for d in range(x.ndim) if d != dim]
+ if len(sum_dims) > 0:
+ if stats_type == "max":
+ for dim in reversed(sum_dims):
+ x = torch.max(x, dim=dim)[0]
+ elif stats_type == "min":
+ for dim in reversed(sum_dims):
+ x = torch.min(x, dim=dim)[0]
+ else:
+ x = torch.sum(x, dim=sum_dims)
+ x = x.flatten().clone()
+ return x, count
+
+
+@dataclass
+class TensorAndCount:
+ tensor: Tensor
+ count: int
+
+
+class TensorDiagnostic(object):
+ """This class is not directly used by the user, it is responsible for
+ collecting diagnostics for a module or parameter tensor of a torch.nn.Module.
+
+ Args:
+ opts:
+ Options object.
+ name:
+ The name associated with this diagnostics object, will probably be
+ {module_name}.X where X is "output" or "grad", or {parameter_name}.
+ Y where Y is param_value or param_grad.
+ """
+
+ def __init__(self, opts: TensorDiagnosticOptions, name: str):
+ self.opts = opts
+ self.name = name
+ self.class_name = None # will assign in accumulate()
+
+ self.stats = None # we'll later assign a list to self.stats.
+ # It's a list of dicts, indexed by dim (i.e. by the
+ # axis of the tensor). The dicts, in turn, are
+ # indexed by `stats-type` which are strings in
+ # ["abs", "max", "min", "positive", "value", "rms"].
+
+ # scalar_stats contains some analysis of the activations and gradients,
+ self.scalar_stats = None
+
+ # the keys into self.stats[dim] are strings, whose values can be
+ # "abs", "max", "min" ,"value", "positive", "rms", "value".
+ # The values e.g. self.stats[dim]["rms"] are lists of dataclass TensorAndCount,
+ # containing a tensor and its associated count (which is the sum of the other
+ # dims that we aggregated over, e.g. the number of frames and/or batch elements
+ # and/or channels.
+ # ... we actually accumulate the Tensors / counts any time we have the same-dim
+ # tensor, only adding a new element to the list if there was a different dim.
+ # if the string in the key is "eigs", if we detect a length mismatch we put None
+ # as the value.
+
+ def accumulate(self, x, class_name: Optional[str] = None):
+ """
+ Accumulate tensors.
+ """
+ if class_name is not None:
+ self.class_name = class_name
+ if isinstance(x, Tuple):
+ x = x[0]
+ if not isinstance(x, Tensor):
+ return
+ if x.numel() == 0: # for empty tensor
+ return
+ x = x.detach().clone()
+ if x.ndim == 0:
+ x = x.unsqueeze(0)
+ ndim = x.ndim
+ if self.stats is None:
+ self.stats = [dict() for _ in range(ndim)]
+
+ for dim in range(ndim):
+ this_dim_stats = self.stats[dim]
+ if ndim > 1:
+ # rms-sort is different from the others, it's based on summing over just
+ # this dim, then sorting and returning the percentiles.
+ stats_types = [
+ "abs",
+ "max",
+ "min",
+ "positive",
+ "value",
+ "rms",
+ "rms-sort",
+ ]
+ if x.shape[dim] <= self.opts.max_eig_dim:
+ stats_types.append("eigs")
+ else:
+ stats_types = ["value", "abs", "max", "min"]
+
+ for stats_type in stats_types:
+ stats, count = get_tensor_stats(x, dim, stats_type)
+ if stats_type not in this_dim_stats:
+ this_dim_stats[stats_type] = [] # list of TensorAndCount
+
+ done = False
+ if this_dim_stats[stats_type] is None:
+ # we can reach here if we detected for stats_type "eigs" that
+ # where was more than one different size for this dim. Then we
+ # disable accumulating this stats type, as it uses too much memory.
+ continue
+ for s in this_dim_stats[stats_type]:
+ if s.tensor.shape == stats.shape:
+ if stats_type == "max":
+ s.tensor = torch.maximum(s.tensor, stats)
+
+ elif stats_type == "min":
+ s.tensor = torch.minimum(s.tensor, stats)
+ else:
+ assert stats_type != "max"
+ s.tensor += stats
+ s.count += count
+ done = True
+ break
+ if not done:
+ if this_dim_stats[stats_type] != [] and stats_type == "eigs":
+ # >1 size encountered on this dim, e.g. it's a batch or time
+ # dimension, don't accumulat "eigs" stats type, it uses too much
+ # memory
+ this_dim_stats[stats_type] = None
+ else:
+ this_dim_stats[stats_type].append(TensorAndCount(stats, count))
+
+ def print_diagnostics(self):
+ """Print diagnostics for each dimension of the tensor."""
+ if self.stats is None:
+ print(f"Warning: the stats of {self.name} is None.")
+ return
+ for dim, this_dim_stats in enumerate(self.stats):
+ if "rms" in this_dim_stats and "value" in this_dim_stats:
+ # produce "stddev" stats, which is centered RMS.
+ rms_stats_list = this_dim_stats["rms"]
+ value_stats_list = this_dim_stats["value"]
+ if len(rms_stats_list) == len(value_stats_list):
+ stddev_stats_list = []
+ for r, v in zip(rms_stats_list, value_stats_list):
+ stddev_stats_list.append(
+ # r.count and v.count should be the same, but we don't check
+ # this.
+ TensorAndCount(
+ r.tensor - v.tensor * v.tensor / (v.count + 1.0e-20),
+ r.count,
+ )
+ )
+ this_dim_stats["stddev"] = stddev_stats_list
+
+ for stats_type, stats_list in this_dim_stats.items():
+ # stats_type could be "rms", "value", "abs", "eigs", "positive", "min"
+ # or "max". "stats_list" could be a list of TensorAndCount (one list per
+ # distinct tensor shape of the stats), or None
+ if stats_list is None:
+ assert stats_type == "eigs"
+ continue
+
+ def get_count(count):
+ return 1 if stats_type in ["max", "min"] else count
+
+ if len(stats_list) == 1:
+ stats = stats_list[0].tensor / get_count(stats_list[0].count)
+ else:
+ # a dimension that has variable size in different nnet
+ # forwards, e.g. a time dimension in an ASR model.
+ stats = torch.cat(
+ [x.tensor / get_count(x.count) for x in stats_list], dim=0
+ )
+
+ if stats_type == "eigs":
+ try:
+ if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
+ eigs, _ = torch.linalg.eigh(stats)
+ else:
+ eigs, _ = torch.symeig(stats)
+ stats = eigs.abs().sqrt()
+ except: # noqa
+ print("Error getting eigenvalues, trying another method.")
+ if hasattr(torch, "linalg") and hasattr(torch.linalg, "eig"):
+ eigs, _ = torch.linalg.eig(stats)
+ eigs = eigs.abs()
+ else:
+ eigs, _ = torch.eig(stats)
+ eigs = eigs.norm(dim=1)
+ stats = eigs.sqrt()
+ # sqrt so it reflects data magnitude, like stddev- not variance
+
+ if stats_type in ["rms", "stddev"]:
+ # we stored the square; after aggregation we need to take sqrt.
+ stats = stats.sqrt()
+
+ # if `summarize` we print percentiles of the stats; else,
+ # we print out individual elements.
+ summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(
+ stats.numel()
+ )
+ if summarize: # usually `summarize` will be true
+ # print out percentiles.
+ stats = stats.sort()[0]
+ num_percentiles = 10
+ size = stats.numel()
+ percentiles = []
+ for i in range(num_percentiles + 1):
+ index = (i * (size - 1)) // num_percentiles
+ percentiles.append(stats[index].item())
+ percentiles = ["%.2g" % x for x in percentiles]
+ percentiles = " ".join(percentiles)
+ ans = f"percentiles: [{percentiles}]"
+ else:
+ ans = stats.tolist()
+ ans = ["%.2g" % x for x in ans]
+ ans = "[" + " ".join(ans) + "]"
+ if stats_type in ["value", "rms", "stddev", "eigs"]:
+ # This norm is useful because it is strictly less than the largest
+ # sqrt(eigenvalue) of the variance, which we print out, and shows,
+ # speaking in an approximate way, how much of that largest
+ # eigenvalue can be attributed to the mean of the distribution.
+ norm = (stats**2).sum().sqrt().item()
+ ans += f", norm={norm:.2g}"
+ mean = stats.mean().item()
+ rms = (stats**2).mean().sqrt().item()
+ ans += f", mean={mean:.3g}, rms={rms:.3g}"
+
+ # OK, "ans" contains the actual stats, e.g.
+ # ans = "percentiles: \
+ # [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], \
+ # mean=0.5, rms=0.5"
+
+ sizes = [x.tensor.shape[0] for x in stats_list]
+ size_str = (
+ f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
+ )
+ maybe_class_name = (
+ f" type={self.class_name}," if self.class_name is not None else ""
+ )
+ print(
+ f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, "
+ f"{stats_type} {ans}"
+ )
+
+
+class ScalarDiagnostic(object):
+ """This class is not directly used by the user, it is responsible for
+ collecting diagnostics for a single module (subclass of torch.nn.Module) that
+ represents some kind of nonlinearity, e.g. ReLU, sigmoid, etc.
+ """
+
+ def __init__(self, opts: TensorDiagnosticOptions, name: str):
+ self.opts = opts
+ self.name = name
+ self.class_name = None # will assign in accumulate()
+ self.is_forward_pass = True
+
+ self.tick_scale = None
+
+ self.saved_inputs = []
+ self.is_ok = True
+
+ self.counts = None
+ self.sum_grad = None
+ self.sum_gradsq = None
+ self.sum_abs_grad = None
+
+ def accumulate_input(self, x: Tensor, class_name: Optional[str] = None):
+ """
+ Called in forward pass.
+ """
+ if not self.is_forward_pass:
+ # in case we did a forward pass without a backward pass, for some reason.
+ self.saved_inputs = []
+ self.is_forward_pass = True
+
+ if class_name is not None:
+ self.class_name = class_name
+ if not self.is_ok:
+ return
+
+ limit = 10
+ if len(self.saved_inputs) > limit:
+ print(
+ f"ERROR: forward pass called for this module over {limit} times "
+ f"with no backward pass. Will not accumulate scalar stats."
+ )
+ self.is_ok = False
+ return
+ self.saved_inputs.append(x)
+
+ def accumulate_output_grad(self, grad: Tensor):
+ if not self.is_ok:
+ return
+ if self.is_forward_pass:
+ self.is_forward_pass = False
+
+ last_shape = (
+ "n/a" if len(self.saved_inputs) == 0 else self.saved_inputs[-1].shape
+ )
+ if len(self.saved_inputs) == 0 or grad.shape != last_shape:
+ print(
+ f"ERROR: shape mismatch or no forward activation present when backward "
+ f"pass called: grad shape ={tuple(grad.shape)}"
+ f", num-saved-inputs={len(self.saved_inputs)}"
+ f", shape-of-last-saved-input={last_shape}"
+ )
+ self.is_ok = False
+ return
+
+ x = self.saved_inputs.pop()
+ self.process_input_and_grad(x, grad)
+
+ def process_input_and_grad(self, x: Tensor, grad: Tensor):
+ assert x.shape == grad.shape
+ x = x.flatten()
+ grad = grad.flatten()
+
+ num_ticks_per_side = 256
+
+ if self.tick_scale is None:
+ x_abs_sorted = x.abs().sort()[0]
+ # take the 98th percentile as the largest value we count separately.
+ index = int(x.numel() * 0.98)
+ self.tick_scale = float(x_abs_sorted[index] / num_ticks_per_side)
+
+ # integerize from tick * (-num ticks_per_side .. num_ticks_per_side - 1]
+ self.counts = torch.zeros(
+ 2 * num_ticks_per_side, dtype=torch.long, device=x.device
+ )
+ self.sum_grad = torch.zeros(
+ 2 * num_ticks_per_side, dtype=torch.double, device=x.device
+ )
+ # sum_gradsq is for getting error bars.
+ self.sum_gradsq = torch.zeros(
+ 2 * num_ticks_per_side, dtype=torch.double, device=x.device
+ )
+ self.sum_abs_grad = torch.zeros(
+ 2 * num_ticks_per_side, dtype=torch.double, device=x.device
+ )
+
+ # this will round down.
+ x = (x / self.tick_scale).to(torch.long)
+ x = x.clamp_(min=-num_ticks_per_side, max=num_ticks_per_side - 1)
+ x = x + num_ticks_per_side
+
+ self.counts.index_add_(dim=0, index=x, source=torch.ones_like(x))
+ self.sum_grad.index_add_(dim=0, index=x, source=grad.to(torch.double))
+ self.sum_gradsq.index_add_(
+ dim=0, index=x, source=(grad * grad).to(torch.double)
+ )
+ self.sum_abs_grad.index_add_(dim=0, index=x, source=grad.abs().to(torch.double))
+
+ def print_diagnostics(self):
+ """Print diagnostics."""
+ if self.is_ok is False or self.counts is None:
+ print(f"Warning: no stats accumulated for {self.name}, is_ok={self.is_ok}")
+ return
+
+ counts = self.counts.to("cpu")
+ sum_grad = self.sum_grad.to(device="cpu", dtype=torch.float32)
+ sum_gradsq = self.sum_gradsq.to(device="cpu", dtype=torch.float32)
+ sum_abs_grad = self.sum_abs_grad.to(device="cpu", dtype=torch.float32)
+
+ counts_cumsum = counts.cumsum(dim=0)
+ counts_tot = counts_cumsum[-1]
+
+ # subdivide the distribution up into `num_bins` intervals for analysis, for
+ # greater statistical significance. each bin corresponds to multiple of the
+ # original 'tick' intervals.
+ num_bins = 20
+
+ # integer division
+ counts_per_bin = (counts_tot // num_bins) + 1
+ bin_indexes = counts_cumsum // counts_per_bin
+ bin_indexes = bin_indexes.clamp(min=0, max=num_bins).to(torch.long)
+
+ bin_counts = torch.zeros(num_bins, dtype=torch.long)
+ bin_counts.index_add_(dim=0, index=bin_indexes, source=counts)
+ bin_grad = torch.zeros(num_bins)
+ bin_grad.index_add_(dim=0, index=bin_indexes, source=sum_grad)
+ bin_gradsq = torch.zeros(num_bins)
+ bin_gradsq.index_add_(dim=0, index=bin_indexes, source=sum_gradsq)
+ bin_abs_grad = torch.zeros(num_bins)
+ bin_abs_grad.index_add_(dim=0, index=bin_indexes, source=sum_abs_grad)
+
+ bin_boundary_counts = (
+ torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin
+ )
+ bin_tick_indexes = torch.searchsorted(counts_cumsum, bin_boundary_counts)
+ # boundaries are the "x" values between the bins, e.g. corresponding to the
+ # locations of percentiles of the distribution.
+ num_ticks_per_side = counts.numel() // 2
+ bin_boundaries = (bin_tick_indexes - num_ticks_per_side) * self.tick_scale
+
+ bin_grad = bin_grad / (bin_counts + 1)
+ bin_conf_interval = bin_gradsq.sqrt() / (
+ bin_counts + 1
+ ) # consider this a standard deviation.
+ # bin_grad / bin_abs_grad will give us a sense for how important in a practical
+ # sense, the gradients are.
+ bin_abs_grad = bin_abs_grad / (bin_counts + 1)
+
+ bin_rel_grad = bin_grad / (bin_abs_grad + 1.0e-20)
+ bin_conf = bin_grad / (bin_conf_interval + 1.0e-20)
+
+ def tensor_to_str(x: Tensor):
+ x = ["%.2g" % f for f in x]
+ x = "[" + " ".join(x) + "]"
+ return x
+
+ maybe_class_name = (
+ f" type={self.class_name}," if self.class_name is not None else ""
+ )
+
+ print(
+ f"module={self.name},{maybe_class_name} "
+ f"bin-boundaries={tensor_to_str(bin_boundaries)}, "
+ f"rel_grad={tensor_to_str(bin_rel_grad)}, "
+ f"grad_conf={tensor_to_str(bin_conf)}"
+ )
+
+
+class ModelDiagnostic(object):
+ """This class stores diagnostics for all tensors in the torch.nn.Module.
+
+ Args:
+ opts:
+ Options object.
+ """
+
+ def __init__(self, opts: Optional[TensorDiagnosticOptions] = None):
+ # In this dictionary, the keys are tensors names and the values
+ # are corresponding TensorDiagnostic objects.
+ if opts is None:
+ self.opts = TensorDiagnosticOptions()
+ else:
+ self.opts = opts
+ self.diagnostics = dict()
+
+ def __getitem__(self, name: str):
+ T = ScalarDiagnostic if name[-7:] == ".scalar" else TensorDiagnostic
+ if name not in self.diagnostics:
+ self.diagnostics[name] = T(self.opts, name)
+ return self.diagnostics[name]
+
+ def print_diagnostics(self):
+ """Print diagnostics for each tensor."""
+ for k in sorted(self.diagnostics.keys()):
+ self.diagnostics[k].print_diagnostics()
+
+
+def get_class_name(module: nn.Module):
+ ans = type(module).__name__
+ # we put the below in try blocks in case anyone is using a different version of
+ # these modules that might have different member names.
+ if ans == "Balancer" or ans == "ActivationBalancer":
+ try:
+ ans += f"[{float(module.min_positive)},{float(module.max_positive)},"
+ f"{float(module.min_abs)},{float(module.max_abs)}]"
+ except:
+ pass
+ elif ans == "AbsValuePenalizer":
+ try:
+ ans += f"[{module.limit}]"
+ except:
+ pass
+ return ans
+
+
+def attach_diagnostics(
+ model: nn.Module, opts: Optional[TensorDiagnosticOptions] = None
+) -> ModelDiagnostic:
+ """Attach a ModelDiagnostic object to the model by
+ 1) registering forward hook and backward hook on each module, to accumulate
+ its output tensors and gradient tensors, respectively;
+ 2) registering backward hook on each module parameter, to accumulate its
+ values and gradients.
+
+ Args:
+ model:
+ the model to be analyzed.
+ opts:
+ Options object.
+
+ Returns:
+ The ModelDiagnostic object attached to the model.
+ """
+
+ ans = ModelDiagnostic(opts)
+ for name, module in model.named_modules():
+ if name == "":
+ name = ""
+
+ # Setting model_diagnostic=ans and n=name below, instead of trying to
+ # capture the variables, ensures that we use the current values.
+ # (this matters for `name`, since the variable gets overwritten).
+ # These closures don't really capture by value, only by
+ # "the final value the variable got in the function" :-(
+ def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
+ if isinstance(_output, tuple) and len(_output) == 1:
+ _output = _output[0]
+
+ if isinstance(_output, Tensor) and _output.dtype in (
+ torch.float32,
+ torch.float16,
+ torch.float64,
+ ):
+ _model_diagnostic[f"{_name}.output"].accumulate(
+ _output, class_name=get_class_name(_module)
+ )
+ elif isinstance(_output, tuple):
+ for i, o in enumerate(_output):
+ if isinstance(o, Tensor) and o.dtype in (
+ torch.float32,
+ torch.float16,
+ torch.float64,
+ ):
+ _model_diagnostic[f"{_name}.output[{i}]"].accumulate(
+ o, class_name=get_class_name(_module)
+ )
+
+ def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
+ if isinstance(_output, tuple) and len(_output) == 1:
+ _output = _output[0]
+ if isinstance(_output, Tensor) and _output.dtype in (
+ torch.float32,
+ torch.float16,
+ torch.float64,
+ ):
+ _model_diagnostic[f"{_name}.grad"].accumulate(
+ _output, class_name=get_class_name(_module)
+ )
+ elif isinstance(_output, tuple):
+ for i, o in enumerate(_output):
+ if isinstance(o, Tensor) and o.dtype in (
+ torch.float32,
+ torch.float16,
+ torch.float64,
+ ):
+ _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
+ o, class_name=get_class_name(_module)
+ )
+
+ module.register_forward_hook(forward_hook)
+ module.register_backward_hook(backward_hook)
+
+ if type(module).__name__ in [
+ "Sigmoid",
+ "Tanh",
+ "ReLU",
+ "TanSwish",
+ "Swish",
+ "DoubleSwish",
+ "Swoosh",
+ ]:
+ # For these specific module types, accumulate some additional diagnostics
+ # that can help us improve the activation function. These require a lot of
+ # memory, to save the forward activations, so limit this to some select
+ # classes. Note: this will not work correctly for all model types.
+ def scalar_forward_hook(
+ _module, _input, _output, _model_diagnostic=ans, _name=name
+ ):
+ if isinstance(_input, tuple):
+ (_input,) = _input
+ assert isinstance(_input, Tensor)
+ _model_diagnostic[f"{_name}.scalar"].accumulate_input(
+ _input, class_name=get_class_name(_module)
+ )
+
+ def scalar_backward_hook(
+ _module, _input, _output, _model_diagnostic=ans, _name=name
+ ):
+ if isinstance(_output, tuple):
+ (_output,) = _output
+ assert isinstance(_output, Tensor)
+ _model_diagnostic[f"{_name}.scalar"].accumulate_output_grad(_output)
+
+ module.register_forward_hook(scalar_forward_hook)
+ module.register_backward_hook(scalar_backward_hook)
+
+ for name, parameter in model.named_parameters():
+
+ def param_backward_hook(
+ grad, _parameter=parameter, _model_diagnostic=ans, _name=name
+ ):
+ _model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
+ _model_diagnostic[f"{_name}.param_grad"].accumulate(grad)
+
+ try:
+ parameter.register_hook(param_backward_hook)
+ except:
+ logging.warning(
+ f"Warning: could not register backward hook for parameter {name}, "
+ f"it might not be differentiable."
+ )
+
+ return ans
+
+
+def _test_tensor_diagnostic():
+ opts = TensorDiagnosticOptions(512)
+
+ diagnostic = TensorDiagnostic(opts, "foo")
+
+ for _ in range(10):
+ diagnostic.accumulate(torch.randn(50, 100) * 10.0)
+
+ diagnostic.print_diagnostics()
+
+ model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 80))
+
+ diagnostic = attach_diagnostics(model, opts)
+ for _ in range(10):
+ T = random.randint(200, 300)
+ x = torch.randn(T, 100)
+ y = model(x)
+ y.sum().backward()
+
+ diagnostic.print_diagnostics()
+
+
+if __name__ == "__main__":
+ _test_tensor_diagnostic()
diff --git a/zipvoice/utils/feature.py b/zipvoice/utils/feature.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6cbc34a39ed7e2b214ffea3ed15be06a6e82153
--- /dev/null
+++ b/zipvoice/utils/feature.py
@@ -0,0 +1,120 @@
+#!/usr/bin/env python3
+# Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Union
+
+import numpy as np
+import torch
+import torchaudio
+from lhotse.features.base import FeatureExtractor, register_extractor
+from lhotse.utils import Seconds, compute_num_frames
+
+
+@dataclass
+class VocosFbankConfig:
+ sampling_rate: int = 24000
+ n_mels: int = 100
+ n_fft: int = 1024
+ hop_length: int = 256
+
+
+@register_extractor
+class VocosFbank(FeatureExtractor):
+
+ name = "VocosFbank"
+ config_type = VocosFbankConfig
+
+ def __init__(self, num_channels: int = 1):
+ config = VocosFbankConfig
+ super().__init__(config=config)
+ assert num_channels in (1, 2)
+ self.num_channels = num_channels
+ self.fbank = torchaudio.transforms.MelSpectrogram(
+ sample_rate=self.config.sampling_rate,
+ n_fft=self.config.n_fft,
+ hop_length=self.config.hop_length,
+ n_mels=self.config.n_mels,
+ center=True,
+ power=1,
+ )
+
+ def _feature_fn(self, sample):
+ mel = self.fbank(sample)
+ logmel = mel.clamp(min=1e-7).log()
+
+ return logmel
+
+ @property
+ def device(self) -> Union[str, torch.device]:
+ return self.config.device
+
+ def feature_dim(self, sampling_rate: int) -> int:
+ return self.config.n_mels
+
+ def extract(
+ self,
+ samples: Union[np.ndarray, torch.Tensor],
+ sampling_rate: int,
+ ) -> Union[np.ndarray, torch.Tensor]:
+ # Check for sampling rate compatibility.
+ expected_sr = self.config.sampling_rate
+ assert sampling_rate == expected_sr, (
+ f"Mismatched sampling rate: extractor expects {expected_sr}, "
+ f"got {sampling_rate}"
+ )
+ is_numpy = False
+ if not isinstance(samples, torch.Tensor):
+ samples = torch.from_numpy(samples)
+ is_numpy = True
+
+ if len(samples.shape) == 1:
+ samples = samples.unsqueeze(0)
+ else:
+ assert samples.ndim == 2, samples.shape
+
+ if self.num_channels == 1:
+ if samples.shape[0] == 2:
+ samples = samples.mean(dim=0, keepdims=True)
+ else:
+ assert samples.shape[0] == 2, samples.shape
+
+ mel = self._feature_fn(samples)
+ # (1, n_mels, time) or (2, n_mels, time)
+ mel = mel.reshape(-1, mel.shape[-1]).t()
+ # (time, n_mels) or (time, 2 * n_mels)
+
+ num_frames = compute_num_frames(
+ samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
+ )
+
+ if mel.shape[0] > num_frames:
+ mel = mel[:num_frames]
+ elif mel.shape[0] < num_frames:
+ mel = mel.unsqueeze(0)
+ mel = torch.nn.functional.pad(
+ mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
+ ).squeeze(0)
+
+ if is_numpy:
+ return mel.cpu().numpy()
+ else:
+ return mel
+
+ @property
+ def frame_shift(self) -> Seconds:
+ return self.config.hop_length / self.config.sampling_rate
diff --git a/zipvoice/utils/hooks.py b/zipvoice/utils/hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a16581464dd447ca3dbc5bc89dbac4c0d4b4f64
--- /dev/null
+++ b/zipvoice/utils/hooks.py
@@ -0,0 +1,111 @@
+# Copyright 2021-2024 Xiaomi Corporation (authors: Zengwei Yao,
+# Daniel Povey,
+# Zengrui Jin,)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import random
+
+import torch
+from torch import Tensor, nn
+
+
+def register_inf_check_hooks(model: nn.Module) -> None:
+ """Registering forward hook on each module, to check
+ whether its output tensors is not finite.
+
+ Args:
+ model:
+ the model to be analyzed.
+ """
+
+ for name, module in model.named_modules():
+ if name == "":
+ name = ""
+
+ # default param _name is a way to capture the current value of the variable
+ # "name".
+ def forward_hook(_module, _input, _output, _name=name):
+ if isinstance(_output, Tensor):
+ try:
+ if not torch.isfinite(_output.to(torch.float32).sum()):
+ logging.warning(f"The sum of {_name}.output is not finite")
+ except RuntimeError: # e.g. CUDA out of memory
+ pass
+ elif isinstance(_output, tuple):
+ for i, o in enumerate(_output):
+ if isinstance(o, tuple):
+ o = o[0]
+ if not isinstance(o, Tensor):
+ continue
+ try:
+ if not torch.isfinite(o.to(torch.float32).sum()):
+ logging.warning(
+ f"The sum of {_name}.output[{i}] is not finite"
+ )
+ except RuntimeError: # e.g. CUDA out of memory
+ pass
+
+ # default param _name is a way to capture the current value of the variable
+ # "name".
+ def backward_hook(_module, _input, _output, _name=name):
+ if isinstance(_output, Tensor):
+ try:
+ if not torch.isfinite(_output.to(torch.float32).sum()):
+ logging.warning(f"The sum of {_name}.grad is not finite")
+ except RuntimeError: # e.g. CUDA out of memory
+ pass
+
+ elif isinstance(_output, tuple):
+ for i, o in enumerate(_output):
+ if isinstance(o, tuple):
+ o = o[0]
+ if not isinstance(o, Tensor):
+ continue
+ if not torch.isfinite(o.to(torch.float32).sum()):
+ logging.warning(f"The sum of {_name}.grad[{i}] is not finite")
+
+ module.register_forward_hook(forward_hook)
+ module.register_backward_hook(backward_hook)
+
+ for name, parameter in model.named_parameters():
+
+ def param_backward_hook(grad, _name=name):
+ if not torch.isfinite(grad.to(torch.float32).sum()):
+ logging.warning(f"The sum of {_name}.param_grad is not finite")
+
+ try:
+ parameter.register_hook(param_backward_hook)
+ except Exception as e:
+ logging.warning(
+ f"Warning: could not register backward hook for parameter {name}"
+ f" with error {e}, it might not be differentiable."
+ )
+
+
+def _test_inf_check_hooks():
+ model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))
+
+ register_inf_check_hooks(model)
+ for _ in range(10):
+ T = random.randint(200, 300)
+ x = torch.randn(T, 100) + float("inf") * (T % 2)
+ y = model(x)
+ y.sum().backward()
+
+
+if __name__ == "__main__":
+ _test_inf_check_hooks()
diff --git a/zipvoice/utils/lr_scheduler.py b/zipvoice/utils/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6b6b7bae012a2e2091db006f3ca2cf3c771e6c4
--- /dev/null
+++ b/zipvoice/utils/lr_scheduler.py
@@ -0,0 +1,245 @@
+# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
+#
+# See ../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import List, Optional, Union
+
+import torch
+from torch.optim import Optimizer
+
+
+class LRScheduler(object):
+ """
+ Base-class for learning rate schedulers where the learning-rate depends on both the
+ batch and the epoch.
+ """
+
+ def __init__(self, optimizer: Optimizer, verbose: bool = False):
+ # Attach optimizer
+ if not isinstance(optimizer, Optimizer):
+ raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
+ self.optimizer = optimizer
+ self.verbose = verbose
+
+ for group in optimizer.param_groups:
+ group.setdefault("base_lr", group["lr"])
+
+ self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]
+
+ self.epoch = 0
+ self.batch = 0
+
+ def state_dict(self):
+ """Returns the state of the scheduler as a :class:`dict`.
+
+ It contains an entry for every variable in self.__dict__ which
+ is not the optimizer.
+ """
+ return {
+ # the user might try to override the base_lr, so don't include this in the
+ # state. previously they were included.
+ # "base_lrs": self.base_lrs,
+ "epoch": self.epoch,
+ "batch": self.batch,
+ }
+
+ def load_state_dict(self, state_dict):
+ """Loads the schedulers state.
+
+ Args:
+ state_dict (dict): scheduler state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ # the things with base_lrs are a work-around for a previous problem
+ # where base_lrs were written with the state dict.
+ base_lrs = self.base_lrs
+ self.__dict__.update(state_dict)
+ self.base_lrs = base_lrs
+
+ def get_last_lr(self) -> List[float]:
+ """Return last computed learning rate by current scheduler.
+ Will be a list of float."""
+ return self._last_lr
+
+ def get_lr(self):
+ # Compute list of learning rates from self.epoch and self.batch and
+ # self.base_lrs; this must be overloaded by the user.
+ # e.g. return [some_formula(self.batch, self.epoch, base_lr)
+ # for base_lr in self.base_lrs ]
+ raise NotImplementedError
+
+ def step_batch(self, batch: Optional[int] = None) -> None:
+ # Step the batch index, or just set it. If `batch` is specified, it
+ # must be the batch index from the start of training, i.e. summed over
+ # all epochs.
+ # You can call this in any order; if you don't provide 'batch', it should
+ # of course be called once per batch.
+ if batch is not None:
+ self.batch = batch
+ else:
+ self.batch = self.batch + 1
+ self._set_lrs()
+
+ def step_epoch(self, epoch: Optional[int] = None):
+ # Step the epoch index, or just set it. If you provide the 'epoch' arg, you
+ # should call this at the start of the epoch; if you don't provide the 'epoch'
+ # arg, you should call it at the end of the epoch.
+ if epoch is not None:
+ self.epoch = epoch
+ else:
+ self.epoch = self.epoch + 1
+ self._set_lrs()
+
+ def _set_lrs(self):
+ values = self.get_lr()
+ assert len(values) == len(self.optimizer.param_groups)
+
+ for i, data in enumerate(zip(self.optimizer.param_groups, values)):
+ param_group, lr = data
+ param_group["lr"] = lr
+ self.print_lr(self.verbose, i, lr)
+ self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
+
+ def print_lr(self, is_verbose, group, lr):
+ """Display the current learning rate."""
+ if is_verbose:
+ logging.warning(
+ f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
+ f" of group {group} to {lr:.4e}."
+ )
+
+
+class Eden(LRScheduler):
+ """
+ Eden scheduler.
+ The basic formula (before warmup) is:
+ lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
+ (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup
+ where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches
+ and then stays constant at 1.
+
+ If you don't have the concept of epochs, or one epoch takes a very long time,
+ you can replace the notion of 'epoch' with some measure of the amount of data
+ processed, e.g. hours of data or frames of data, with 'lr_epochs' being set to
+ some measure representing "quite a lot of data": say, one fifth or one third
+ of an entire training run, but it doesn't matter much. You could also use
+ Eden2 which has only the notion of batches.
+
+ We suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
+
+ Args:
+ optimizer: the optimizer to change the learning rates on
+ lr_batches: the number of batches after which we start significantly
+ decreasing the learning rate, suggest 5000.
+ lr_epochs: the number of epochs after which we start significantly
+ decreasing the learning rate, suggest 6 if you plan to do e.g.
+ 20 to 40 epochs, but may need smaller number if dataset is huge
+ and you will do few epochs.
+ """
+
+ def __init__(
+ self,
+ optimizer: Optimizer,
+ lr_batches: Union[int, float],
+ lr_epochs: Union[int, float],
+ warmup_batches: Union[int, float] = 500.0,
+ warmup_start: float = 0.5,
+ verbose: bool = False,
+ ):
+ super(Eden, self).__init__(optimizer, verbose)
+ self.lr_batches = lr_batches
+ self.lr_epochs = lr_epochs
+ self.warmup_batches = warmup_batches
+
+ assert 0.0 <= warmup_start <= 1.0, warmup_start
+ self.warmup_start = warmup_start
+
+ def get_lr(self):
+ factor = (
+ (self.batch**2 + self.lr_batches**2) / self.lr_batches**2
+ ) ** -0.25 * (
+ ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25
+ )
+ warmup_factor = (
+ 1.0
+ if self.batch >= self.warmup_batches
+ else self.warmup_start
+ + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
+ # else 0.5 + 0.5 * (self.batch / self.warmup_batches)
+ )
+
+ return [x * factor * warmup_factor for x in self.base_lrs]
+
+
+class FixedLRScheduler(LRScheduler):
+ """
+ Fixed learning rate scheduler.
+
+ Args:
+ optimizer: the optimizer to change the learning rates on
+ """
+
+ def __init__(
+ self,
+ optimizer: Optimizer,
+ verbose: bool = False,
+ ):
+ super(FixedLRScheduler, self).__init__(optimizer, verbose)
+
+ def get_lr(self):
+
+ return [x for x in self.base_lrs]
+
+
+def _test_eden():
+ m = torch.nn.Linear(100, 100)
+ from zipvoice.utils.optim import ScaledAdam
+
+ optim = ScaledAdam(m.parameters(), lr=0.03)
+
+ scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)
+
+ for epoch in range(10):
+ scheduler.step_epoch(epoch) # sets epoch to `epoch`
+
+ for step in range(20):
+ x = torch.randn(200, 100).detach()
+ x.requires_grad = True
+ y = m(x)
+ dy = torch.randn(200, 100).detach()
+ f = (y * dy).sum()
+ f.backward()
+
+ optim.step()
+ scheduler.step_batch()
+ optim.zero_grad()
+
+ logging.info(f"last lr = {scheduler.get_last_lr()}")
+ logging.info(f"state dict = {scheduler.state_dict()}")
+
+
+if __name__ == "__main__":
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ logging.getLogger().setLevel(logging.INFO)
+ import subprocess
+
+ s = subprocess.check_output(
+ "git status -uno .; git log -1; git diff HEAD .", shell=True
+ )
+ logging.info(s)
+
+ _test_eden()
diff --git a/zipvoice/utils/optim.py b/zipvoice/utils/optim.py
new file mode 100644
index 0000000000000000000000000000000000000000..f90466ab8d3e77ae411151e38db1debd56221da6
--- /dev/null
+++ b/zipvoice/utils/optim.py
@@ -0,0 +1,868 @@
+# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
+#
+# See ../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import logging
+from collections import defaultdict
+from typing import Dict, List, Tuple
+
+import torch
+from lhotse.utils import fix_random_seed
+from torch import Tensor
+from torch.optim import Optimizer
+
+
+class BatchedOptimizer(Optimizer):
+ """
+ This class adds to class Optimizer the capability to optimize parameters in batches:
+ it will stack the parameters and their grads for you so the optimizer can work
+ on tensors with an extra leading dimension. This is intended for speed with GPUs,
+ as it reduces the number of kernels launched in the optimizer.
+
+ Args:
+ params:
+ """
+
+ def __init__(self, params, defaults):
+ super(BatchedOptimizer, self).__init__(params, defaults)
+
+ @contextlib.contextmanager
+ def batched_params(self, param_group, group_params_names):
+ """
+ This function returns (technically, yields) a list of
+ of tuples (p, state), where
+ p is a `fake` parameter that is stacked (over axis 0) from real parameters
+ that share the same shape, and its gradient is also stacked;
+ `state` is the state corresponding to this batch of parameters
+ (it will be physically located in the "state" for one of the real
+ parameters, the last one that has any particular shape and dtype).
+
+ This function is decorated as a context manager so that it can
+ write parameters back to their "real" locations.
+
+ The idea is, instead of doing:
+
+ for p in group["params"]:
+ state = self.state[p]
+ ...
+
+ you can do:
+
+ with self.batched_params(group["params"]) as batches:
+ for p, state, p_names in batches:
+ ...
+
+
+ Args:
+ group: a parameter group, which is a list of parameters; should be
+ one of self.param_groups.
+ group_params_names: name for each parameter in group,
+ which is List[str].
+ """
+ batches = defaultdict(
+ list
+ ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
+ batches_names = defaultdict(
+ list
+ ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
+
+ assert len(param_group) == len(group_params_names)
+ for p, named_p in zip(param_group, group_params_names):
+ key = (str(p.dtype), *p.shape)
+ batches[key].append(p)
+ batches_names[key].append(named_p)
+
+ batches_names_keys = list(batches_names.keys())
+ sorted_idx = sorted(
+ range(len(batches_names)), key=lambda i: batches_names_keys[i]
+ )
+ batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
+ batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
+
+ stacked_params_dict = dict()
+
+ # turn batches into a list, in deterministic order.
+ # tuples will contain tuples of (stacked_param, state, stacked_params_names),
+ # one for each batch in `batches`.
+ tuples = []
+
+ for batch, batch_names in zip(batches, batches_names):
+ p = batch[0]
+ # we arbitrarily store the state in the
+ # state corresponding to the 1st parameter in the
+ # group. class Optimizer will take care of saving/loading state.
+ state = self.state[p]
+ p_stacked = torch.stack(batch)
+ grad = torch.stack(
+ [torch.zeros_like(p) if p.grad is None else p.grad for p in batch]
+ )
+ p_stacked.grad = grad
+ stacked_params_dict[key] = p_stacked
+ tuples.append((p_stacked, state, batch_names))
+
+ yield tuples # <-- calling code will do the actual optimization here!
+
+ for (stacked_params, _state, _names), batch in zip(tuples, batches):
+ for i, p in enumerate(batch): # batch is list of Parameter
+ p.copy_(stacked_params[i])
+
+
+def basic_step(group, p, state, grad):
+ # computes basic Adam update using beta2 (dividing by gradient stddev) only. no
+ # momentum yet.
+ lr = group["lr"]
+ if p.numel() == p.shape[0]:
+ lr = lr * group["scalar_lr_scale"]
+ beta2 = group["betas"][1]
+ eps = group["eps"]
+ # p shape: (batch_size,) or (batch_size, 1, [1,..])
+ try:
+ exp_avg_sq = state[
+ "exp_avg_sq"
+ ] # shape: (batch_size,) or (batch_size, 1, [1,..])
+ except KeyError:
+ exp_avg_sq = torch.zeros(*p.shape, device=p.device, dtype=torch.float)
+ state["exp_avg_sq"] = exp_avg_sq
+
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+ # bias_correction2 is like in Adam.
+ # slower update at the start will help stability anyway.
+ bias_correction2 = 1 - beta2 ** (state["step"] + 1)
+ if bias_correction2 < 0.99:
+ # note: not in-place.
+ exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
+ denom = exp_avg_sq.sqrt().add_(eps)
+
+ return -lr * grad / denom
+
+
+def scaling_step(group, p, state, grad):
+ delta = basic_step(group, p, state, grad)
+ if p.numel() == p.shape[0]:
+ return delta
+ # there is no scaling for scalar parameters.
+ # (p.shape[0] is the batch of parameters.)
+
+ step = state["step"]
+ size_update_period = group["size_update_period"]
+
+ try:
+ param_rms = state["param_rms"]
+ scale_grads = state["scale_grads"]
+ scale_exp_avg_sq = state["scale_exp_avg_sq"]
+ except KeyError:
+ # we know p.ndim > 1 because we'd have returned above if not, so don't worry
+ # about the speial case of dim=[] that pytorch treats inconsistently.
+ param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
+ param_rms = param_rms.to(torch.float)
+ scale_exp_avg_sq = torch.zeros_like(param_rms)
+ scale_grads = torch.zeros(
+ size_update_period,
+ *param_rms.shape,
+ dtype=torch.float,
+ device=p.device,
+ )
+ state["param_rms"] = param_rms
+ state["scale_grads"] = scale_grads
+ state["scale_exp_avg_sq"] = scale_exp_avg_sq
+
+ # on every step, update the gradient w.r.t. the scale of the parameter, we
+ # store these as a batch and periodically update the size (for speed only, to
+ # avoid too many operations).
+ scale_grads[step % size_update_period] = (p * grad).sum(
+ dim=list(range(1, p.ndim)), keepdim=True
+ )
+
+ # periodically recompute the value of param_rms.
+ if step % size_update_period == size_update_period - 1:
+ param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
+
+ param_min_rms = group["param_min_rms"]
+
+ # scale the step size by param_rms. This is the most important "scaling" part of
+ # ScaledAdam
+ delta *= param_rms.clamp(min=param_min_rms)
+
+ if step % size_update_period == size_update_period - 1 and step > 0:
+ # This block updates the size of parameter by adding a step ("delta") value in
+ # the direction of either shrinking or growing it.
+ beta2 = group["betas"][1]
+ size_lr = group["lr"] * group["scalar_lr_scale"]
+ param_max_rms = group["param_max_rms"]
+ eps = group["eps"]
+ # correct beta2 for the size update period: we will have
+ # faster decay at this level.
+ beta2_corr = beta2**size_update_period
+ scale_exp_avg_sq.mul_(beta2_corr).add_(
+ (scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
+ alpha=1 - beta2_corr,
+ ) # shape is (batch_size, 1, 1, ...)
+
+ # The 1st time we reach here is when size_step == 1.
+ size_step = (step + 1) // size_update_period
+ bias_correction2 = 1 - beta2_corr**size_step
+
+ denom = scale_exp_avg_sq.sqrt() + eps
+
+ scale_step = (
+ -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
+ )
+
+ is_too_small = param_rms < param_min_rms
+
+ # when the param gets too small, just don't shrink it any further.
+ scale_step.masked_fill_(is_too_small, 0.0)
+
+ # The following may help prevent instability: don't allow the scale step to be
+ # too large in either direction.
+ scale_step.clamp_(min=-0.1, max=0.1)
+
+ # and ensure the parameter rms after update never exceeds param_max_rms.
+ # We have to look at the trained model for parameters at or around the
+ # param_max_rms, because sometimes they can indicate a problem with the
+ # topology or settings.
+ scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
+
+ delta.add_(p * scale_step)
+
+ return delta
+
+
+def momentum_step(group, p, state, grad):
+ delta = scaling_step(group, p, state, grad)
+ beta1 = group["betas"][0]
+ try:
+ stored_delta = state["delta"]
+ except KeyError:
+ stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float)
+ state["delta"] = stored_delta
+ stored_delta.mul_(beta1)
+ stored_delta.add_(delta, alpha=(1 - beta1))
+ # we don't bother doing the "bias correction" part of Adam for beta1 because this is
+ # just an edge effect that affects the first 10 or so batches; and the effect of not
+ # doing it is just to do a slower update for the first few batches, which will help
+ # stability.
+ return stored_delta
+
+
+class ScaledAdam(BatchedOptimizer):
+ """
+ Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
+ proportional to the norm of that parameter; and also learn the scale of the
+ parameter, in log space, subject to upper and lower limits (as if we had factored
+ each parameter as param = underlying_param * log_scale.exp())
+
+
+ Args:
+ params: The parameters or param_groups to optimize (like other Optimizer
+ subclasses) Unlike common optimizers, which accept
+ model.parameters() or groups of parameters(), this optimizer
+ could accept model.named_parameters() or groups of
+ named_parameters(). See comments of function
+ _get_names_of_parameters for its 4 possible cases.
+ lr: The learning rate. We will typically use a learning rate schedule
+ that starts at 0.03 and decreases over time, i.e. much higher
+ than other common optimizers.
+ clipping_scale: (e.g. 2.0)
+ A scale for gradient-clipping: if specified, the normalized gradients
+ over the whole model will be clipped to have 2-norm equal to
+ `clipping_scale` times the median 2-norm over the most recent period
+ of `clipping_update_period` minibatches. By "normalized gradients",
+ we mean after multiplying by the rms parameter value for this tensor
+ [for non-scalars]; this is appropriate because our update is scaled
+ by this quantity.
+ betas: beta1,beta2 are momentum constants for regular momentum, and moving
+ sum-sq grad. Must satisfy 0 < beta <= beta2 < 1.
+ scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
+ scale of each parameter tensor and scalar parameters of the mode..
+ If each parameter were decomposed as p * p_scale.exp(),
+ where (p**2).mean().sqrt() == 1.0, scalar_lr_scale would be a the
+ scaling factor on the learning rate of p_scale.
+ eps: A general-purpose epsilon to prevent division by zero
+ param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
+ learning the scale on the parameters (we'll constrain the rms of
+ each non-scalar parameter tensor to be >= this value)
+ param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
+ learning the scale on the parameters (we'll constrain the rms of
+ each non-scalar parameter tensor to be <= this value)
+ scalar_max: Maximum absolute value for scalar parameters (applicable if your
+ model has any parameters with numel() == 1).
+ size_update_period: The periodicity, in steps, with which we update the size (scale)
+ of the parameter tensor. This is provided to save a little time
+ in the update.
+ clipping_update_period: if clipping_scale is specified, this is the period
+ """
+
+ def __init__(
+ self,
+ params,
+ lr=3e-02,
+ clipping_scale=None,
+ betas=(0.9, 0.98),
+ scalar_lr_scale=0.1,
+ eps=1.0e-08,
+ param_min_rms=1.0e-05,
+ param_max_rms=3.0,
+ scalar_max=10.0,
+ size_update_period=4,
+ clipping_update_period=100,
+ ):
+
+ defaults = dict(
+ lr=lr,
+ clipping_scale=clipping_scale,
+ betas=betas,
+ scalar_lr_scale=scalar_lr_scale,
+ eps=eps,
+ param_min_rms=param_min_rms,
+ param_max_rms=param_max_rms,
+ scalar_max=scalar_max,
+ size_update_period=size_update_period,
+ clipping_update_period=clipping_update_period,
+ )
+
+ # If params only contains parameters or group of parameters,
+ # i.e when parameter names are not given,
+ # this flag will be set to False in funciton _get_names_of_parameters.
+ self.show_dominant_parameters = True
+ param_groups, parameters_names = self._get_names_of_parameters(params)
+ super(ScaledAdam, self).__init__(param_groups, defaults)
+ assert len(self.param_groups) == len(parameters_names)
+ self.parameters_names = parameters_names
+
+ def _get_names_of_parameters(
+ self, params_or_named_params
+ ) -> Tuple[List[Dict], List[List[str]]]:
+ """
+ Args:
+ params_or_named_params: according to the way ScaledAdam is initialized
+ in train.py, this argument could be one of following 4 cases,
+ case 1, a generator of parameter, e.g.:
+ optimizer = ScaledAdam(model.parameters(), lr=params.base_lr,
+ clipping_scale=3.0)
+
+ case 2, a list of parameter groups with different config, e.g.:
+ model_param_groups = [
+ {'params': model.encoder.parameters(), 'lr': 0.05},
+ {'params': model.decoder.parameters(), 'lr': 0.01},
+ {'params': model.joiner.parameters(), 'lr': 0.03},
+ ]
+ optimizer = ScaledAdam(model_param_groups, lr=params.base_lr,
+ clipping_scale=3.0)
+
+ case 3, a generator of named_parameter, e.g.:
+ optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr,
+ clipping_scale=3.0)
+
+ case 4, a list of named_parameter groups with different config, e.g.:
+ model_named_param_groups = [
+ {'named_params': model.encoder.named_parameters(), 'lr': 0.05},
+ {'named_params': model.decoder.named_parameters(), 'lr': 0.01},
+ {'named_params': model.joiner.named_parameters(), 'lr': 0.03},
+ ]
+ optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr,
+ clipping_scale=3.0)
+
+ For case 1 and case 2, input params is used to initialize the underlying
+ torch.optimizer.
+ For case 3 and case 4, firstly, names and params are extracted from input
+ named_params, then, these extracted params are used to initialize the
+ underlying torch.optimizer, and these extracted names are mainly used by
+ function `_show_gradient_dominating_parameter`
+
+ Returns:
+ Returns a tuple containing 2 elements:
+ - `param_groups` with type List[Dict], each Dict element is a parameter
+ group. An example of `param_groups` could be:
+ [
+ {'params': `one iterable of Parameter`, 'lr': 0.05},
+ {'params': `another iterable of Parameter`, 'lr': 0.08},
+ {'params': `a third iterable of Parameter`, 'lr': 0.1},
+ ]
+ - `param_gruops_names` with type List[List[str]],
+ each `List[str]` is for a group['params'] in param_groups,
+ and each `str` is the name of a parameter.
+ A dummy name "foo" is related to each parameter,
+ if input are params without names, i.e. case 1 or case 2.
+ """
+ # variable naming convention in this function:
+ # p is short for param.
+ # np is short for named_param.
+ # p_or_np is short for param_or_named_param.
+ # cur is short for current.
+ # group is a dict,
+ # e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}.
+ # groups is a List[group]
+
+ iterable_or_groups = list(params_or_named_params)
+ if len(iterable_or_groups) == 0:
+ raise ValueError("optimizer got an empty parameter list")
+
+ # The first value of returned tuple. A list of dicts containing at
+ # least 'params' as a key.
+ param_groups = []
+
+ # The second value of returned tuple,
+ # a List[List[str]], each sub-List is for a group.
+ param_groups_names = []
+
+ if not isinstance(iterable_or_groups[0], dict):
+ # case 1 or case 3,
+ # the input is an iterable of parameter or named parameter.
+ param_iterable_cur_group = []
+ param_names_cur_group = []
+ for p_or_np in iterable_or_groups:
+ if isinstance(p_or_np, tuple):
+ # case 3
+ name, param = p_or_np
+ else:
+ # case 1
+ assert isinstance(p_or_np, torch.Tensor)
+ param = p_or_np
+ # Assign a dummy name as a placeholder
+ name = "foo"
+ self.show_dominant_parameters = False
+ param_iterable_cur_group.append(param)
+ param_names_cur_group.append(name)
+ param_groups.append({"params": param_iterable_cur_group})
+ param_groups_names.append(param_names_cur_group)
+ else:
+ # case 2 or case 4
+ # the input is groups of parameter or named parameter.
+ for cur_group in iterable_or_groups:
+ if "named_params" in cur_group:
+ name_list = [x[0] for x in cur_group["named_params"]]
+ p_list = [x[1] for x in cur_group["named_params"]]
+ del cur_group["named_params"]
+ cur_group["params"] = p_list
+ else:
+ assert "params" in cur_group
+ name_list = ["foo" for _ in cur_group["params"]]
+ param_groups.append(cur_group)
+ param_groups_names.append(name_list)
+
+ return param_groups, param_groups_names
+
+ def __setstate__(self, state):
+ super(ScaledAdam, self).__setstate__(state)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group, group_params_names in zip(self.param_groups, self.parameters_names):
+
+ with self.batched_params(group["params"], group_params_names) as batches:
+
+ # batches is list of pairs (stacked_param, state). stacked_param is
+ # like a regular parameter, and will have a .grad, but the 1st dim
+ # corresponds to a stacking dim, it is not a real dim.
+
+ if (
+ len(batches[0][1]) == 0
+ ): # if len(first state) == 0: not yet initialized
+ clipping_scale = 1
+ else:
+ clipping_scale = self._get_clipping_scale(group, batches)
+
+ for p, state, _ in batches:
+ # Perform optimization step.
+ # grad is not going to be None, we handled that when creating the
+ # batches.
+ grad = p.grad
+ if grad.is_sparse:
+ raise RuntimeError(
+ "ScaledAdam optimizer does not support sparse gradients"
+ )
+
+ try:
+ cur_step = state["step"]
+ except KeyError:
+ state["step"] = 0
+ cur_step = 0
+
+ grad = (
+ p.grad if clipping_scale == 1.0 else p.grad.mul_(clipping_scale)
+ )
+ p += momentum_step(group, p.detach(), state, grad)
+
+ if p.numel() == p.shape[0]: # scalar parameter
+ scalar_max = group["scalar_max"]
+ p.clamp_(min=-scalar_max, max=scalar_max)
+
+ state["step"] = cur_step + 1
+
+ return loss
+
+ def _get_clipping_scale(
+ self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
+ ) -> float:
+ """
+ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will
+ scale the gradients by this amount before applying the rest of the update.
+
+ Args:
+ group: the parameter group, an item in self.param_groups
+ tuples: a list of tuples of (param, state, param_names)
+ where param is a batched set of parameters,
+ with a .grad (1st dim is batch dim)
+ and state is the state-dict where optimization parameters are kept.
+ param_names is a List[str] while each str is name for a parameter
+ in batched set of parameters "param".
+ """
+ assert len(tuples) >= 1
+ clipping_scale = group["clipping_scale"]
+ (first_p, first_state, _) = tuples[0]
+ step = first_state["step"]
+ if clipping_scale is None or step == 0:
+ # no clipping. return early on step == 0 because the other
+ # parameters' state won't have been initialized yet.
+ return 1.0
+ clipping_update_period = group["clipping_update_period"]
+ scalar_lr_scale = group["scalar_lr_scale"]
+
+ tot_sumsq = torch.tensor(0.0, device=first_p.device)
+ for p, state, param_names in tuples:
+ grad = p.grad
+ if grad.is_sparse:
+ raise RuntimeError(
+ "ScaledAdam optimizer does not support sparse gradients"
+ )
+ if p.numel() == p.shape[0]: # a batch of scalars
+ tot_sumsq += (grad**2).sum() * (
+ scalar_lr_scale**2
+ ) # sum() to change shape [1] to []
+ else:
+ tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
+
+ tot_norm = tot_sumsq.sqrt()
+ if "model_norms" not in first_state:
+ first_state["model_norms"] = torch.zeros(
+ clipping_update_period, device=p.device
+ )
+ first_state["model_norms"][step % clipping_update_period] = tot_norm
+
+ irregular_estimate_steps = [
+ i for i in [10, 20, 40] if i < clipping_update_period
+ ]
+ if step % clipping_update_period == 0 or step in irregular_estimate_steps:
+ # Print some stats.
+ # We don't reach here if step == 0 because we would have returned
+ # above.
+ sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
+ if step in irregular_estimate_steps:
+ sorted_norms = sorted_norms[-step:]
+ num_norms = sorted_norms.numel()
+ quartiles = []
+ for n in range(0, 5):
+ index = min(num_norms - 1, (num_norms // 4) * n)
+ quartiles.append(sorted_norms[index].item())
+
+ median = quartiles[2]
+ if median - median != 0:
+ raise RuntimeError("Too many grads were not finite")
+ threshold = clipping_scale * median
+ if step in irregular_estimate_steps:
+ # use larger thresholds on first few steps of estimating threshold,
+ # as norm may be changing rapidly.
+ threshold = threshold * 2.0
+ first_state["model_norm_threshold"] = threshold
+ percent_clipped = (
+ first_state["num_clipped"] * 100.0 / num_norms
+ if "num_clipped" in first_state
+ else 0.0
+ )
+ first_state["num_clipped"] = 0
+ quartiles = " ".join(["%.3e" % x for x in quartiles])
+ logging.warning(
+ f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
+ f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
+ )
+
+ try:
+ model_norm_threshold = first_state["model_norm_threshold"]
+ except KeyError:
+ return 1.0 # threshold has not yet been set.
+
+ ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
+ if ans != ans: # e.g. ans is nan
+ ans = 0.0
+ if ans < 1.0:
+ first_state["num_clipped"] += 1
+ if ans < 0.5:
+ logging.debug(
+ f"Scaling gradients by {ans}, "
+ f"model_norm_threshold={model_norm_threshold}"
+ )
+ if self.show_dominant_parameters:
+ assert p.shape[0] == len(param_names)
+ self._show_gradient_dominating_parameter(
+ tuples, tot_sumsq, group["scalar_lr_scale"]
+ )
+ self._show_param_with_unusual_grad(tuples)
+
+ if ans == 0.0:
+ for p, state, param_names in tuples:
+ p.grad.zero_() # get rid of infinity()
+
+ return ans
+
+ def _show_param_with_unusual_grad(
+ self,
+ tuples: List[Tuple[Tensor, dict, List[str]]],
+ ):
+ """
+ Print information about parameter which has the largest ratio of
+ grad-on-this-batch divided by normal grad size.
+ tuples: a list of tuples of (param, state, param_names)
+ where param is a batched set of parameters,
+ with a .grad (1st dim is batch dim)
+ and state is the state-dict where optimization parameters are kept.
+ param_names is a List[str] while each str is name for a parameter
+ in batched set of parameters "param".
+ """
+ # ratios_names is a list of 3-tuples: (grad_ratio, param_name, tensor)
+ ratios_names = []
+ for p, state, batch_param_names in tuples:
+ dims = list(range(1, p.ndim))
+
+ def mean(x):
+ # workaround for bad interface of torch's "mean" for when dims is the
+ # empty list.
+ if len(dims) > 0:
+ return x.mean(dim=dims)
+ else:
+ return x
+
+ grad_ratio = (
+ (mean(p.grad**2) / state["exp_avg_sq"].mean(dim=dims))
+ .sqrt()
+ .to("cpu")
+ )
+
+ ratios_names += zip(
+ grad_ratio.tolist(), batch_param_names, p.grad.unbind(dim=0)
+ )
+
+ ratios_names = sorted(ratios_names, reverse=True)
+ ratios_names = ratios_names[:10]
+ ratios_names = [
+ (ratio, name, largest_index(tensor))
+ for (ratio, name, tensor) in ratios_names
+ ]
+
+ logging.debug(
+ f"Parameters with most larger-than-usual grads, with ratios, "
+ f"are: {ratios_names}"
+ )
+
+ def _show_gradient_dominating_parameter(
+ self,
+ tuples: List[Tuple[Tensor, dict, List[str]]],
+ tot_sumsq: Tensor,
+ scalar_lr_scale: float,
+ ):
+ """
+ Show information of parameter which dominates tot_sumsq.
+
+ Args:
+ tuples: a list of tuples of (param, state, param_names)
+ where param is a batched set of parameters,
+ with a .grad (1st dim is batch dim)
+ and state is the state-dict where optimization parameters are kept.
+ param_names is a List[str] while each str is name for a parameter
+ in batched set of parameters "param".
+ tot_sumsq: sumsq of all parameters. Though it's could be calculated
+ from tuples, we still pass it to save some time.
+ """
+ all_sumsq_orig = {}
+ for p, state, batch_param_names in tuples:
+ # p is a stacked batch parameters.
+ batch_grad = p.grad
+ if p.numel() == p.shape[0]: # a batch of scalars
+ # Dummy values used by following `zip` statement.
+ batch_rms_orig = torch.full(
+ p.shape, scalar_lr_scale, device=batch_grad.device
+ )
+ else:
+ batch_rms_orig = state["param_rms"]
+ batch_sumsq_orig = (batch_grad * batch_rms_orig) ** 2
+ if batch_grad.ndim > 1:
+ # need to guard it with if-statement because sum() sums over
+ # all dims if dim == ().
+ batch_sumsq_orig = batch_sumsq_orig.sum(
+ dim=list(range(1, batch_grad.ndim))
+ )
+ for name, sumsq_orig, rms, grad in zip(
+ batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
+ ):
+
+ proportion_orig = sumsq_orig / tot_sumsq
+ all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
+
+ sorted_by_proportion = {
+ k: v
+ for k, v in sorted(
+ all_sumsq_orig.items(),
+ key=lambda item: item[1][0],
+ reverse=True,
+ )
+ }
+ dominant_param_name = next(iter(sorted_by_proportion))
+ (
+ dominant_proportion,
+ dominant_sumsq,
+ dominant_rms,
+ dominant_grad,
+ ) = sorted_by_proportion[dominant_param_name]
+ logging.debug(
+ f"Parameter dominating tot_sumsq {dominant_param_name}"
+ f" with proportion {dominant_proportion:.2f},"
+ f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
+ f"={dominant_sumsq:.3e},"
+ f" grad_sumsq={(dominant_grad**2).sum():.3e},"
+ f" orig_rms_sq={(dominant_rms**2).item():.3e}"
+ )
+
+
+def largest_index(x: Tensor):
+ x = x.contiguous()
+ argmax = x.abs().argmax().item()
+ return [(argmax // x.stride(i)) % x.size(i) for i in range(x.ndim)]
+
+
+def _test_scaled_adam(hidden_dim: int):
+ import timeit
+
+ from zipvoice.models.modules.scaling import ScaledLinear
+ from zipvoice.utils.lr_scheduler import Eden
+
+ E = 100
+ B = 4
+ T = 2
+ logging.info("in test_eve_cain")
+ # device = torch.device('cuda')
+ device = torch.device("cpu")
+ dtype = torch.float32
+
+ fix_random_seed(42)
+ # these input_magnitudes and output_magnitudes are to test that
+ # Abel is working as we expect and is able to adjust scales of
+ # different dims differently.
+ input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
+ output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
+
+ fix_random_seed(42)
+ Linear = ScaledLinear
+
+ m = torch.nn.Sequential(
+ Linear(E, hidden_dim),
+ torch.nn.PReLU(),
+ Linear(hidden_dim, hidden_dim),
+ torch.nn.PReLU(),
+ Linear(hidden_dim, E),
+ ).to(device)
+
+ train_pairs = [
+ (
+ 100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes,
+ torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes,
+ )
+ for _ in range(20)
+ ]
+ optim = ScaledAdam(m.named_parameters(), lr=0.03, clipping_scale=2.0)
+ scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
+
+ start = timeit.default_timer()
+ avg_loss = 0.0
+ for epoch in range(180):
+ scheduler.step_epoch()
+ # if epoch == 100 and iter in [2,3]:
+ # optim.reset_speedup() # check it doesn't crash.
+
+ # if epoch == 130:
+ # opts = diagnostics.TensorDiagnosticOptions(
+ # 512
+ # ) # allow 4 megabytes per sub-module
+ # diagnostic = diagnostics.attach_diagnostics(m, opts)
+
+ for n, (x, y) in enumerate(train_pairs):
+ y_out = m(x)
+ loss = ((y_out - y) ** 2).mean() * 100.0
+ if epoch == 0 and n == 0:
+ avg_loss = loss.item()
+ else:
+ avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
+ if n == 0 and epoch % 5 == 0:
+ # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
+ # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
+ # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
+ # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
+ # scale1 = '%.2e' % (m[0].weight_scale.exp().item())
+ # scale1b = '%.2e' % (m[0].bias_scale.exp().item())
+ # scale2 = '%.2e' % (m[2].weight_scale.exp().item())
+ # scale2b = '%.2e' % (m[2].bias_scale.exp().item())
+ lr = scheduler.get_last_lr()[0]
+ logging.info(
+ f"Iter {iter}, epoch {epoch}, batch {n}, "
+ f"avg_loss {avg_loss:.4g}, lr={lr:.4e}"
+ ) # , norms={norm1,norm1b,norm2,norm2b}")
+ # scales={scale1,scale1b,scale2,scale2b}
+ loss.log().backward()
+ optim.step()
+ optim.zero_grad()
+ scheduler.step_batch()
+
+ # diagnostic.print_diagnostics()
+
+ stop = timeit.default_timer()
+ logging.info(f"Iter={iter}, Time taken: {stop - start}")
+
+ logging.info(f"last lr = {scheduler.get_last_lr()}")
+ # logging.info("state dict = ", scheduler.state_dict())
+ # logging.info("optim state_dict = ", optim.state_dict())
+ logging.info(f"input_magnitudes = {input_magnitudes}")
+ logging.info(f"output_magnitudes = {output_magnitudes}")
+
+
+if __name__ == "__main__":
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ logging.getLogger().setLevel(logging.INFO)
+ import subprocess
+
+ s = subprocess.check_output(
+ "git status -uno .; git log -1; git diff HEAD .", shell=True
+ )
+ logging.info(s)
+ import sys
+
+ if len(sys.argv) > 1:
+ hidden_dim = int(sys.argv[1])
+ else:
+ hidden_dim = 200
+
+ _test_scaled_adam(hidden_dim)
diff --git a/zipvoice/utils/scaling_converter.py b/zipvoice/utils/scaling_converter.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f9ff213e6f58f0ea1500135a07804f5a9519571
--- /dev/null
+++ b/zipvoice/utils/scaling_converter.py
@@ -0,0 +1,105 @@
+# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file replaces various modules in a model.
+Specifically, ActivationBalancer is replaced with an identity operator;
+Whiten is also replaced with an identity operator;
+BasicNorm is replaced by a module with `exp` removed.
+"""
+
+import copy
+from typing import List
+
+import torch
+import torch.nn as nn
+
+from zipvoice.models.modules.scaling import (
+ Balancer,
+ Dropout3,
+ SwooshL,
+ SwooshLOnnx,
+ SwooshR,
+ SwooshROnnx,
+ Whiten,
+)
+from zipvoice.models.modules.zipformer import CompactRelPositionalEncoding
+
+
+# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
+# get_submodule was added to nn.Module at v1.9.0
+def get_submodule(model, target):
+ if target == "":
+ return model
+ atoms: List[str] = target.split(".")
+ mod: torch.nn.Module = model
+ for item in atoms:
+ if not hasattr(mod, item):
+ raise AttributeError(
+ mod._get_name() + " has no " "attribute `" + item + "`"
+ )
+ mod = getattr(mod, item)
+ if not isinstance(mod, torch.nn.Module):
+ raise AttributeError("`" + item + "` is not " "an nn.Module")
+ return mod
+
+
+def convert_scaled_to_non_scaled(
+ model: nn.Module,
+ inplace: bool = False,
+ is_pnnx: bool = False,
+ is_onnx: bool = False,
+):
+ """
+ Args:
+ model:
+ The model to be converted.
+ inplace:
+ If True, the input model is modified inplace.
+ If False, the input model is copied and we modify the copied version.
+ is_pnnx:
+ True if we are going to export the model for PNNX.
+ is_onnx:
+ True if we are going to export the model for ONNX.
+ Return:
+ Return a model without scaled layers.
+ """
+ if not inplace:
+ model = copy.deepcopy(model)
+
+ d = {}
+ for name, m in model.named_modules():
+ if isinstance(m, (Balancer, Dropout3, Whiten)):
+ d[name] = nn.Identity()
+ elif is_onnx and isinstance(m, SwooshR):
+ d[name] = SwooshROnnx()
+ elif is_onnx and isinstance(m, SwooshL):
+ d[name] = SwooshLOnnx()
+ elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
+ # We want to recreate the positional encoding vector when
+ # the input changes, so we have to use torch.jit.script()
+ # to replace torch.jit.trace()
+ d[name] = torch.jit.script(m)
+
+ for k, v in d.items():
+ if "." in k:
+ parent, child = k.rsplit(".", maxsplit=1)
+ setattr(get_submodule(model, parent), child, v)
+ else:
+ setattr(model, k, v)
+
+ return model