diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..0652054b20b554fa1cb6208a1f8dbc2ceab35233
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2026 OpenMOSS Team, Fudan University, SII and MOSI
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
index c841f159ecb1fe6f17fce367c4c2d4e0caf14cc3..67eca1ad32cab79957e1af754a0574948138fee9 100644
--- a/README.md
+++ b/README.md
@@ -9,6 +9,12 @@ python_version: '3.12'
app_file: app.py
pinned: false
license: apache-2.0
+tags:
+ - zerogpu
+ - aoti
+ - text-to-speech
---
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+This Space serves the dots.tts Gradio demo on Hugging Face ZeroGPU, with optional PyTorch AOTInductor (AOTI) compilation at startup.
+
+Set `DOTS_TTS_MODEL_NAME_OR_PATH` to a local model directory or Hugging Face model repo id. The app defaults to `rednote-hilab/dots.tts`.
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..e16f0b22d540f13c38c491e225529dd9d38b77ef
--- /dev/null
+++ b/app.py
@@ -0,0 +1,176 @@
+from __future__ import annotations
+
+import os
+import sys
+from pathlib import Path
+from typing import Any, Callable
+
+
+REPO_ROOT = Path(__file__).resolve().parent
+SRC_ROOT = REPO_ROOT / "src"
+
+for import_root in (REPO_ROOT, SRC_ROOT):
+ import_root_str = str(import_root)
+ if import_root_str not in sys.path:
+ sys.path.insert(0, import_root_str)
+
+
+class _SpacesFallback:
+ @staticmethod
+ def GPU(*decorator_args, **_decorator_kwargs):
+ if decorator_args and callable(decorator_args[0]):
+ return decorator_args[0]
+
+ def decorate(fn: Callable[..., Any]) -> Callable[..., Any]:
+ return fn
+
+ return decorate
+
+
+# Prefer the real ``spaces`` package; if importing it fails for any reason,
+# bind the no-op shim instead so ``spaces.GPU`` can still be used as a decorator.
+try:
+ import spaces # type: ignore
+except Exception: # pragma: no cover - only used outside Hugging Face Spaces.
+ spaces = _SpacesFallback() # type: ignore
+
+
+def _env_bool(name: str, default: bool) -> bool:
+ value = os.environ.get(name)
+ if value is None:
+ return default
+ return value.strip().lower() in {"1", "true", "yes", "on"}
+
+
+def _env_int(name: str, default: int) -> int:
+ value = os.environ.get(name)
+ if value is None or not value.strip():
+ return default
+ return int(value)
+
+
+def _configure_zero_gpu_environment() -> None:
+ os.environ.setdefault("DOTS_TTS_COMPILE_BACKEND", "aoti")
+ os.environ.setdefault("DOTS_TTS_SKIP_INIT_WARMUP", "1")
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
+
+
+def _preload_runtime(app_service, app_config, compile_backend: str):
+ runtime, resolved_model_name_or_path = app_service._get_runtime( # noqa: SLF001
+ app_config.default_model_name_or_path,
+ )
+ runtime.optimize = bool(app_config.optimize)
+ runtime.model.set_optimize(bool(app_config.optimize))
+ if hasattr(runtime.model, "set_compile_backend"):
+ runtime.model.set_compile_backend(compile_backend)
+ return runtime, resolved_model_name_or_path
+
+
+def main() -> None:
+ """Configure, optionally pre-compile, and launch the dots.tts Gradio Space.
+
+ Settings are read from ``DOTS_TTS_*`` environment variables, falling back to
+ the literals passed to ``os.environ.get`` / ``_env_int`` / ``_env_bool`` below.
+ """
+ # Seed the environment defaults before the imports below are executed.
+ _configure_zero_gpu_environment()
+
+ import gradio as gr
+ from loguru import logger
+
+ from apps.gradio.app import PLAYGROUND_CSS, build_demo, build_playground_theme
+ from apps.gradio.service import GradioAppService, build_gradio_app_config
+ from dots_tts.utils.logging import configure_logging
+
+ # Server, model and generation settings.
+ host = os.environ.get("DOTS_TTS_HOST", "0.0.0.0")
+ port = _env_int("DOTS_TTS_PORT", 7860)
+ model_name_or_path = os.environ.get(
+ "DOTS_TTS_MODEL_NAME_OR_PATH",
+ "rednote-hilab/dots.tts",
+ )
+ precision = os.environ.get("DOTS_TTS_PRECISION", "bfloat16")
+ execution_mode = os.environ.get("DOTS_TTS_EXECUTION_MODE", "generate_stream")
+ max_generate_length = _env_int("DOTS_TTS_MAX_GENERATE_LENGTH", 500)
+ default_num_steps = _env_int("DOTS_TTS_DEFAULT_NUM_STEPS", 10)
+ # Compilation switches and the durations passed to ``spaces.GPU``.
+ compile_backend = os.environ.get("DOTS_TTS_COMPILE_BACKEND", "aoti").strip().lower()
+ enable_aoti = _env_bool("DOTS_TTS_ENABLE_AOTI", True)
+ startup_compile = _env_bool("DOTS_TTS_AOTI_COMPILE_ON_STARTUP", True)
+ optimize = _env_bool("DOTS_TTS_OPTIMIZE", True)
+ generation_duration = _env_int("DOTS_TTS_ZERO_GPU_DURATION", 600)
+ compile_duration = _env_int("DOTS_TTS_ZERO_GPU_COMPILE_DURATION", 1500)
+ # Output and log locations default to paths under /tmp.
+ output_dir = Path(os.environ.get("DOTS_TTS_OUTPUT_DIR", "/tmp/dots_tts_outputs"))
+ log_file = Path(os.environ.get("DOTS_TTS_LOG_FILE", "/tmp/dots_tts_gradio.log"))
+
+ configure_logging(log_file=log_file)
+ logger.info(
+ "Space app starting: model={} execution_mode={} precision={} optimize={} "
+ "compile_backend={} enable_aoti={} startup_compile={} max_generate_length={}",
+ model_name_or_path,
+ execution_mode,
+ precision,
+ optimize,
+ compile_backend,
+ enable_aoti,
+ startup_compile,
+ max_generate_length,
+ )
+
+ # The same max_generate_length is used for both the runtime limit and the UI default.
+ app_config = build_gradio_app_config(
+ host=host,
+ port=port,
+ execution_mode=execution_mode,
+ precision=precision,
+ optimize=optimize,
+ model_name_or_path=model_name_or_path,
+ output_dir=output_dir,
+ max_generate_length=max_generate_length,
+ default_num_steps=default_num_steps,
+ default_max_generate_length=max_generate_length,
+ repo_root=REPO_ROOT,
+ )
+ app_service = GradioAppService(app_config)
+ # With AOTI disabled the backend name passed on is "torch_compile".
+ runtime, resolved_model_name_or_path = _preload_runtime(
+ app_service,
+ app_config,
+ compile_backend if enable_aoti else "torch_compile",
+ )
+
+ # Startup compilation runs only when AOTI, startup compile and optimize are all enabled.
+ if enable_aoti and startup_compile and optimize:
+
+ @spaces.GPU(duration=compile_duration)
+ def compile_aoti_cache():
+ # Preload again inside the GPU-decorated call, run the warmup, and
+ # return whatever the model exports as its compiled models.
+ child_runtime, _ = _preload_runtime(
+ app_service,
+ app_config,
+ compile_backend,
+ )
+ child_runtime.model.run_warmup(
+ max_generate_length=app_config.max_generate_length,
+ precision=app_config.precision,
+ num_steps=app_config.default_num_steps,
+ guidance_scale=app_config.default_guidance_scale,
+ )
+ return child_runtime.model.export_compiled_models()
+
+ # Import the exported models into the runtime preloaded above, if any came back.
+ compiled_models = compile_aoti_cache()
+ if compiled_models:
+ runtime.model.import_compiled_models(compiled_models)
+ logger.info(
+ "AOTI startup compile completed: compiled_target_count={}",
+ len(compiled_models or {}),
+ )
+
+ # Replace the service's generate method with a GPU-decorated wrapper.
+ app_service.generate = spaces.GPU(duration=generation_duration)(app_service.generate)
+
+ demo = build_demo(gr, app_config, app_service)
+ # NOTE(review): unlike the log call above, the len() below has no ``or {}``
+ # guard — confirm export_compiled_models() never returns None.
+ logger.info(
+ "Space app ready: host={} port={} resolved_model={} compiled_target_count={}",
+ app_config.host,
+ app_config.port,
+ resolved_model_name_or_path,
+ len(runtime.model.export_compiled_models())
+ if hasattr(runtime.model, "export_compiled_models")
+ else 0,
+ )
+ demo.launch(
+ server_name=app_config.host,
+ server_port=app_config.port,
+ theme=build_playground_theme(gr),
+ css=PLAYGROUND_CSS,
+ )
+
+
+# Launch the Space only when this file is executed as a script.
+if __name__ == "__main__":
+ main()
diff --git a/apps/__init__.py b/apps/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d6c872ad474f5be0ff27e8b3de592df998368ba
--- /dev/null
+++ b/apps/__init__.py
@@ -0,0 +1 @@
+"""Application entrypoints for dots.tts."""
diff --git a/apps/gradio/__init__.py b/apps/gradio/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f921e101e49432bb5030192aa97efd3aee0280d8
--- /dev/null
+++ b/apps/gradio/__init__.py
@@ -0,0 +1 @@
+"""Gradio application for dots.tts."""
diff --git a/apps/gradio/app.py b/apps/gradio/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..59164b8e7c23f10014b0f9f7f2e37b679029a64d
--- /dev/null
+++ b/apps/gradio/app.py
@@ -0,0 +1,663 @@
+from __future__ import annotations
+
+import argparse
+import os
+import sys
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+REPO_ROOT = Path(__file__).resolve().parents[2]
+SRC_ROOT = REPO_ROOT / "src"
+
+for import_root in (REPO_ROOT, SRC_ROOT):
+ import_root_str = str(import_root)
+ if import_root_str not in sys.path:
+ sys.path.insert(0, import_root_str)
+
+from apps.gradio.constants import ( # noqa: E402
+ DEFAULT_EXECUTION_MODE,
+ DEFAULT_GUIDANCE_SCALE,
+ DEFAULT_HOST,
+ DEFAULT_INPUT_TEXT,
+ DEFAULT_LOG_FILE,
+ DEFAULT_MAX_GENERATE_LENGTH,
+ DEFAULT_NUM_STEPS,
+ DEFAULT_ODE_METHOD,
+ DEFAULT_OUTPUT_DIR,
+ DEFAULT_OUTPUT_RETENTION,
+ DEFAULT_PORT,
+ DEFAULT_PRECISION,
+ DEFAULT_PROMPT_NAME,
+ DEFAULT_SEED,
+ DEFAULT_SPEAKER_SCALE,
+)
+
+if TYPE_CHECKING:
+ import gradio as gr
+
+# True only when the DEBUG_GRADIO environment variable is exactly "1";
+# build_demo uses it to decide whether the Debug accordion is rendered.
+DEBUG_GRADIO_ENABLED = os.environ.get("DEBUG_GRADIO", "0") == "1"
+
+
+PLAYGROUND_CSS = """
+.gradio-container {
+ width: min(1600px, calc(100vw - 32px)) !important;
+ max-width: none !important;
+ margin: 0 auto !important;
+ padding-left: 0 !important;
+ padding-right: 0 !important;
+}
+
+.gradio-container,
+.gradio-container .gradio-container {
+ --block-label-background-fill: #CCE5FF;
+ --block-label-text-color: #6666FF;
+ --block-label-border-color: #99c7ee;
+ --block-label-text-weight: 600;
+ --block-title-background-fill: #CCE5FF;
+ --block-title-text-color: #6666FF;
+ --block-title-border-color: #99c7ee;
+ --block-title-border-width: var(--block-label-border-width);
+ --block-title-radius: var(--block-label-radius);
+ --block-title-padding: var(--block-label-padding);
+ --block-title-text-size: var(--block-label-text-size);
+ --block-title-text-weight: 600;
+}
+
+.gradio-container label[data-testid="block-label"],
+.gradio-container label[data-testid="block-label"] *,
+.gradio-container span[data-testid="block-info"],
+.gradio-container span[data-testid="block-info"] * {
+ background: #CCE5FF !important;
+ border-color: #99c7ee !important;
+ color: #6666FF !important;
+ fill: #6666FF !important;
+ font-family: Verdana, Geneva, "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei", "Noto Sans CJK SC", sans-serif !important;
+ font-style: normal !important;
+ font-size: 0.78rem !important;
+ line-height: 1.2 !important;
+ letter-spacing: 0 !important;
+ text-transform: none !important;
+}
+.gradio-container label[data-testid="block-label"],
+.gradio-container span[data-testid="block-info"],
+.gradio-container [data-testid="block-title"],
+.gradio-container .block-title {
+ border: var(--block-label-border-width) solid #99c7ee !important;
+ border-top: none !important;
+ border-left: none !important;
+ border-radius: var(--block-label-radius) !important;
+ box-shadow: var(--block-label-shadow) !important;
+ padding: var(--block-label-padding) !important;
+}
+.gradio-container label[data-testid="block-label"],
+.gradio-container label[data-testid="block-label"] *,
+.gradio-container span[data-testid="block-info"],
+.gradio-container span[data-testid="block-info"] *,
+.gradio-container [data-testid="block-title"],
+.gradio-container [data-testid="block-title"] *,
+.gradio-container .block-title,
+.gradio-container .block-title * {
+ font-weight: 600 !important;
+}
+.gradio-container .block label > span,
+.gradio-container .block label > span *,
+.gradio-container .form label > span,
+.gradio-container .form label > span *,
+.gradio-container label > span:first-child,
+.gradio-container label > span:first-child * {
+ font-weight: 600 !important;
+}
+.strong-label [data-testid="block-label"],
+.strong-label [data-testid="block-label"] *,
+.strong-label span[data-testid="block-info"],
+.strong-label span[data-testid="block-info"] *,
+.strong-label [data-testid="block-title"],
+.strong-label [data-testid="block-title"] *,
+.strong-label .block-label,
+.strong-label .block-label *,
+.strong-label .block-title,
+.strong-label .block-title *,
+.strong-label label > span:first-child,
+.strong-label label > span:first-child * {
+ font-weight: 600 !important;
+}
+.gradio-container .info-text,
+.gradio-container .info-text * {
+ font-weight: 400 !important;
+}
+.gradio-container input,
+.gradio-container textarea,
+.gradio-container select,
+.gradio-container [role="textbox"],
+.gradio-container [contenteditable="true"] {
+ font-weight: 400 !important;
+}
+.gradio-container label[data-testid="block-label"] > span:first-child {
+ display: none !important;
+}
+
+.generate-button {
+ background: #6666FF !important;
+ color: #ffffff !important;
+ border: 1px solid #5555ee !important;
+ font-family: Verdana, Geneva, sans-serif !important;
+}
+.generate-button:hover {
+ background: #5555ee !important;
+}
+
+#playground-banner {
+ padding: 0;
+ border-radius: 0;
+ margin-bottom: 18px;
+ background: transparent;
+ border: 0;
+}
+#playground-banner h1 {
+ margin: 0 0 4px 0;
+ font-size: 1.7rem;
+ font-weight: 700;
+ color: #0f172a;
+ letter-spacing: 0;
+}
+#playground-banner .subtitle {
+ margin: 0;
+ color: #1e293b;
+ font-size: 0.9rem;
+}
+
+.info-card {
+ padding: 14px 18px;
+ border-radius: 8px;
+ border: 1px solid #99c7ee;
+ border-left: 4px solid #2563eb;
+ background: transparent;
+ font-size: 0.86rem;
+ line-height: 1.55;
+ margin-bottom: 16px;
+ box-sizing: border-box;
+ color: #0f172a;
+}
+.info-card .card-title,
+.info-card .notice-title {
+ display: block;
+ font-weight: 600;
+ font-size: 0.92rem;
+ color: #0f172a;
+}
+.info-card .card-title {
+ margin-bottom: 4px;
+}
+.info-card .notice-title {
+ margin-top: 8px;
+ margin-bottom: 4px;
+}
+.info-card ol,
+.info-card ul {
+ margin: 0;
+ padding-left: 18px;
+}
+.info-card li {
+ margin: 2px 0;
+}
+
+.main-workspace {
+ gap: 18px !important;
+ align-items: stretch !important;
+}
+
+.prompt-column,
+.synthesis-column {
+ gap: 14px !important;
+}
+
+.control-row,
+.settings-slider-row {
+ gap: 14px !important;
+}
+
+.settings-card {
+ margin-top: 2px !important;
+}
+
+.generate-button {
+ margin-top: 2px !important;
+ width: 100% !important;
+ box-sizing: border-box !important;
+ flex: 0 0 auto !important;
+ min-height: 44px !important;
+ padding-top: 10px !important;
+ padding-bottom: 10px !important;
+ font-size: 1rem !important;
+ font-weight: 600 !important;
+}
+
+.output-audio {
+ flex: 0 0 auto !important;
+ min-height: 190px !important;
+}
+.output-audio audio {
+ width: 100% !important;
+}
+
+@media (max-width: 768px) {
+ .gradio-container {
+ width: calc(100vw - 20px) !important;
+ }
+
+}
+
+"""
+
+
+def build_playground_theme(gr):
+ return gr.themes.Soft(
+ primary_hue="slate",
+ secondary_hue="slate",
+ neutral_hue="slate",
+ radius_size="md",
+ text_size="md",
+ spacing_size="md",
+ font=[gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"],
+ )
+
+
+def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
+ parser = argparse.ArgumentParser(description="dots.tts Gradio app.")
+ parser.add_argument("--host", default=DEFAULT_HOST, help="Server host")
+ parser.add_argument("--port", type=int, default=DEFAULT_PORT, help="Server port")
+ parser.add_argument(
+ "--execution-mode",
+ choices=("generate", "generate_stream"),
+ default=DEFAULT_EXECUTION_MODE,
+ help="Runtime execution mode fixed for the app",
+ )
+ parser.add_argument(
+ "--precision",
+ default=DEFAULT_PRECISION,
+ help="Inference precision fixed for the app runtime",
+ )
+ parser.add_argument(
+ "--optimize",
+ action="store_true",
+ help="Enable runtime optimize acceleration",
+ )
+ parser.add_argument(
+ "--model-name-or-path",
+ default=None,
+ help="Default model directory or Hugging Face repo id",
+ )
+ parser.add_argument(
+ "--output-dir",
+ default=str(DEFAULT_OUTPUT_DIR),
+ help="Directory for generated wav outputs",
+ )
+ parser.add_argument(
+ "--log-file",
+ default=str(DEFAULT_LOG_FILE),
+ help="Path to the Gradio log file",
+ )
+ parser.add_argument(
+ "--output-retention-count",
+ type=int,
+ default=DEFAULT_OUTPUT_RETENTION,
+ help="Maximum number of generated wav files to keep",
+ )
+ parser.add_argument(
+ "--max-generate-length",
+ type=int,
+ default=DEFAULT_MAX_GENERATE_LENGTH,
+ help="Maximum generation schedule length fixed for the app runtime",
+ )
+ parser.add_argument(
+ "--default-prompt-name",
+ default=DEFAULT_PROMPT_NAME,
+ help="Default built-in voice preset name",
+ )
+ parser.add_argument(
+ "--default-precision",
+ default=DEFAULT_PRECISION,
+ choices=["bfloat16", "float32", "float16"],
+ help="Default precision selected in the UI",
+ )
+ parser.add_argument(
+ "--default-num-steps",
+ type=int,
+ default=DEFAULT_NUM_STEPS,
+ help="Default Num Steps selected in the UI",
+ )
+ parser.add_argument(
+ "--default-guidance-scale",
+ type=float,
+ default=DEFAULT_GUIDANCE_SCALE,
+ help="Default Guidance Scale selected in the UI",
+ )
+ parser.add_argument(
+ "--default-speaker-scale",
+ type=float,
+ default=DEFAULT_SPEAKER_SCALE,
+ help="Default Speaker Scale selected in the UI",
+ )
+ parser.add_argument(
+ "--default-max-generate-length",
+ type=int,
+ default=DEFAULT_MAX_GENERATE_LENGTH,
+ help="Default Max Generate Length selected in the UI",
+ )
+ parser.add_argument(
+ "--skip-warmup",
+ action="store_true",
+ help="Start the Gradio server without running an initial synthesis warmup.",
+ )
+ return parser.parse_args(argv)
+
+
+def build_startup_config_panel(gr, app_config) -> None:
+ with gr.Accordion("启动固定参数", open=False):
+ gr.Markdown("只读。修改这部分需要重启服务并传入新的启动参数。")
+ gr.Textbox(
+ label="Model",
+ value=app_config.default_model_name_or_path,
+ interactive=False,
+ )
+ with gr.Row():
+ gr.Textbox(
+ label="Execution Mode",
+ value=app_config.execution_mode,
+ interactive=False,
+ )
+ gr.Textbox(
+ label="Precision",
+ value=app_config.precision,
+ interactive=False,
+ )
+ with gr.Row():
+ gr.Number(
+ label="Max Generate Length",
+ value=app_config.max_generate_length,
+ precision=0,
+ interactive=False,
+ )
+ gr.Checkbox(
+ label="Optimize",
+ value=app_config.optimize,
+ interactive=False,
+ )
+
+
+def build_demo(gr, app_config, app_service) -> "gr.Blocks":
+ """Assemble the dots.tts playground UI and return it with queuing enabled.
+
+ Args:
+ gr: The imported ``gradio`` module.
+ app_config: Config object supplying presets and default values for the widgets.
+ app_service: Service whose ``generate`` method performs the synthesis.
+
+ Returns:
+ The result of ``demo.queue(default_concurrency_limit=1, max_size=8)``.
+ """
+ from apps.gradio.service import (
+ GRADIO_SYNTHESIS_MODE_CHOICES,
+ SynthesisRequest,
+ build_prompt_choice_items,
+ resolve_prompt_selection,
+ )
+
+ # Callback for the preset dropdown: returns the preset's audio path and transcript.
+ def select_prompt_preset(prompt_name: str):
+ audio_path, prompt_text = resolve_prompt_selection(
+ prompt_name,
+ app_config.prompt_presets,
+ )
+ return audio_path, prompt_text
+
+ # Callback for the Generate button: builds a SynthesisRequest from the
+ # widget values and returns the generated audio path plus metrics.
+ def run_synthesis(
+ text: str,
+ synthesis_mode: str,
+ prompt_audio_path: str | None,
+ prompt_text: str,
+ ode_method: str,
+ num_steps: float,
+ guidance_scale: float,
+ speaker_scale: float,
+ normalize_text: bool,
+ seed: float,
+ ):
+ # Outside debug mode the template is always "tts", whatever value arrives.
+ resolved_synthesis_mode = synthesis_mode if DEBUG_GRADIO_ENABLED else "tts"
+ request = SynthesisRequest(
+ model_name_or_path=app_config.default_model_name_or_path,
+ text=text,
+ prompt_audio_path=prompt_audio_path,
+ prompt_text=prompt_text,
+ execution_mode=app_config.execution_mode,
+ template_name=resolved_synthesis_mode,
+ ode_method=ode_method,
+ num_steps=int(num_steps),
+ guidance_scale=float(guidance_scale),
+ speaker_scale=float(speaker_scale),
+ normalize_text=normalize_text,
+ seed=int(seed),
+ )
+ result = app_service.generate(request)
+ return result.audio_path, result.metrics
+
+ # The preset dropdown is only visible when at least one preset is configured.
+ show_prompt_preset = bool(app_config.prompt_presets)
+
+ with gr.Blocks(title="dots.tts") as demo:
+ # NOTE(review): the markup inside the two HTML literals below appears to have
+ # been stripped in this rendering of the patch — confirm against the repository file.
+ gr.HTML(
+ "\n"
+ + """
+
+
dots.tts
+
Fully-continuous Autoregressive TTS · 48 kHz · Voice Cloning
+
+ """,
+ )
+
+ gr.HTML(
+ """
+
+
使用说明 · Instructions
+
+ - 上传参考音频并填写对应转写文本 · Upload prompt audio and fill in its transcript.
+ - 在文本框中输入要合成的内容 · Enter the text to synthesize.
+ - 点击 Generate 合成声音 · Click Generate to synthesize speech.
+
+
+ """,
+ )
+
+ with gr.Row(equal_height=True, elem_classes="main-workspace"):
+ # Left column: voice preset, prompt audio and its transcript.
+ with gr.Column(scale=1, min_width=480, elem_classes="prompt-column"):
+ prompt_preset = gr.Dropdown(
+ label="音色 · Voice Preset",
+ choices=build_prompt_choice_items(app_config.prompt_presets),
+ value=app_config.default_prompt_name,
+ info="内置音色clone样本;选择后自动填入参考音频与转写。",
+ elem_id="voice-preset-dropdown",
+ elem_classes="strong-label",
+ visible=show_prompt_preset,
+ )
+ prompt_audio_path = gr.Audio(
+ label="参考音频 · Prompt Audio",
+ sources=["upload"],
+ type="filepath",
+ value=app_config.default_prompt_audio_path,
+ elem_classes="strong-label",
+ )
+ prompt_text = gr.Textbox(
+ label="参考音频转写 · Prompt Text",
+ lines=5,
+ value=app_config.default_prompt_text,
+ placeholder="Prompt audio 对应的文本转写(continuation cloning 必填)",
+ elem_classes="strong-label",
+ )
+
+ # Right column: text to synthesize, settings, Generate button and output audio.
+ with gr.Column(scale=1, min_width=480, elem_classes="synthesis-column"):
+ text = gr.Textbox(
+ label="待合成文本 · Text",
+ lines=5,
+ max_lines=8,
+ value=DEFAULT_INPUT_TEXT,
+ placeholder="输入待合成的文本",
+ elem_classes="strong-label",
+ )
+ with gr.Accordion("⚙️ Settings", open=False, elem_classes="settings-card"):
+ with gr.Row(elem_classes="settings-slider-row"):
+ num_steps = gr.Slider(
+ label="Num Steps",
+ minimum=1,
+ maximum=32,
+ step=1,
+ value=app_config.default_num_steps,
+ )
+ with gr.Row(elem_classes="settings-slider-row"):
+ guidance_scale = gr.Slider(
+ label="Guidance Scale",
+ minimum=1.0,
+ maximum=3.0,
+ step=0.1,
+ value=app_config.default_guidance_scale,
+ )
+ with gr.Row(elem_classes="control-row"):
+ seed = gr.Number(
+ label="Seed",
+ value=DEFAULT_SEED,
+ precision=0,
+ scale=1,
+ min_width=180,
+ )
+ normalize_text = gr.Checkbox(
+ label="Normalize Text",
+ value=False,
+ scale=1,
+ min_width=180,
+ )
+ generate = gr.Button(
+ "Generate",
+ variant="primary",
+ size="lg",
+ elem_classes="generate-button",
+ )
+ audio_out = gr.Audio(
+ label="生成音频 · Output",
+ type="filepath",
+ elem_classes="output-audio",
+ )
+
+ # Debug mode exposes the extra controls; otherwise gr.State placeholders
+ # carry fixed values so the click handler below gets the same inputs and outputs.
+ if DEBUG_GRADIO_ENABLED:
+ with gr.Accordion("Debug", open=False):
+ synthesis_mode = gr.Dropdown(
+ label="SynthesisMode",
+ choices=list(GRADIO_SYNTHESIS_MODE_CHOICES),
+ value="tts",
+ info="选择合成模式;界面显示名会自动映射到 runtime 对应模板。",
+ )
+ ode_method = gr.Textbox(
+ label="ODE Method",
+ value=DEFAULT_ODE_METHOD,
+ lines=1,
+ )
+ speaker_scale = gr.Slider(
+ label="Speaker Scale",
+ minimum=0.0,
+ maximum=3.0,
+ step=0.1,
+ value=app_config.default_speaker_scale,
+ info="说话人 x-vector 强度",
+ )
+ metrics = gr.JSON(label="Metrics", value=app_service.metadata())
+ build_startup_config_panel(gr, app_config)
+ else:
+ synthesis_mode = gr.State(value="tts")
+ ode_method = gr.State(value=DEFAULT_ODE_METHOD)
+ speaker_scale = gr.State(value=app_config.default_speaker_scale)
+ metrics = gr.State(value={})
+
+ # The order of ``inputs`` must match run_synthesis's parameter order.
+ generate.click(
+ fn=run_synthesis,
+ inputs=[
+ text,
+ synthesis_mode,
+ prompt_audio_path,
+ prompt_text,
+ ode_method,
+ num_steps,
+ guidance_scale,
+ speaker_scale,
+ normalize_text,
+ seed,
+ ],
+ outputs=[audio_out, metrics],
+ concurrency_limit=1,
+ )
+ # Choosing a preset fills in the prompt audio and prompt text widgets.
+ prompt_preset.change(
+ fn=select_prompt_preset,
+ inputs=[prompt_preset],
+ outputs=[prompt_audio_path, prompt_text],
+ concurrency_limit=1,
+ )
+
+ return demo.queue(default_concurrency_limit=1, max_size=8)
+
+
+def main() -> None:
+ """Run the Gradio app from the command line.
+
+ Parses the CLI arguments, configures logging, builds the app config and
+ service, runs the warmup unless ``--skip-warmup`` was given, then builds
+ and launches the demo.
+ """
+ # Arguments are parsed before the imports below are executed.
+ args = parse_args()
+ import gradio as gr
+ from loguru import logger
+
+ from apps.gradio.service import GradioAppService, build_gradio_app_config
+ from dots_tts.utils.logging import configure_logging
+
+ configure_logging(log_file=args.log_file)
+ logger.info(
+ "Gradio app starting: host={} port={} model_name_or_path={} output_dir={} "
+ "log_file={} output_retention_count={} max_generate_length={} execution_mode={} precision={} optimize={} "
+ "default_prompt_name={} skip_warmup={}",
+ args.host,
+ args.port,
+ args.model_name_or_path,
+ args.output_dir,
+ args.log_file,
+ args.output_retention_count,
+ args.max_generate_length,
+ args.execution_mode,
+ args.precision,
+ args.optimize,
+ args.default_prompt_name,
+ args.skip_warmup,
+ )
+ # output_dir is converted to a Path here; the other values are passed through as parsed.
+ app_config = build_gradio_app_config(
+ host=args.host,
+ port=args.port,
+ execution_mode=args.execution_mode,
+ precision=args.precision,
+ optimize=args.optimize,
+ model_name_or_path=args.model_name_or_path,
+ output_dir=Path(args.output_dir),
+ output_retention_count=args.output_retention_count,
+ max_generate_length=args.max_generate_length,
+ default_prompt_name=args.default_prompt_name,
+ default_precision=args.default_precision,
+ default_num_steps=args.default_num_steps,
+ default_guidance_scale=args.default_guidance_scale,
+ default_speaker_scale=args.default_speaker_scale,
+ default_max_generate_length=args.default_max_generate_length,
+ )
+ app_service = GradioAppService(app_config)
+ # The warmup runs before the UI is built unless it was explicitly skipped.
+ if args.skip_warmup:
+ logger.info("Gradio app warmup skipped by --skip-warmup.")
+ else:
+ warmup_metrics = app_service.warmup()
+ logger.info("Gradio app warmup metrics: {}", warmup_metrics)
+ demo = build_demo(gr, app_config, app_service)
+ logger.info(
+ "Gradio app ready: host={} port={} execution_mode={} precision={} optimize={} default_model_name_or_path={}",
+ app_config.host,
+ app_config.port,
+ app_config.execution_mode,
+ app_config.precision,
+ app_config.optimize,
+ app_config.default_model_name_or_path,
+ )
+ demo.launch(
+ server_name=app_config.host,
+ server_port=app_config.port,
+ theme=build_playground_theme(gr),
+ css=PLAYGROUND_CSS,
+ )
+
+
+# Start the app only when this module is executed as a script.
+if __name__ == "__main__":
+ main()
diff --git a/apps/gradio/constants.py b/apps/gradio/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..f60261f89f7ec298736b7c8cd0395cff1b1993b5
--- /dev/null
+++ b/apps/gradio/constants.py
@@ -0,0 +1,26 @@
+from __future__ import annotations
+
+from pathlib import Path
+
+REPO_ROOT = Path(__file__).resolve().parents[2]  # third ancestor of this file: apps/gradio/constants.py -> repo root
+DEFAULT_HOST = "0.0.0.0"  # NOTE(review): presumably binds every network interface - confirm the exposure is intended
+DEFAULT_PORT = 7860
+DEFAULT_OUTPUT_DIR = REPO_ROOT / "apps" / "gradio" / "outputs"
+DEFAULT_LOG_FILE = REPO_ROOT / "apps" / "gradio" / "gradio.log"
+DEFAULT_PROMPTS_DIR = REPO_ROOT / "apps" / "gradio" / "default_prompts"
+DEFAULT_PROMPT_SOURCE_DIR = DEFAULT_PROMPTS_DIR  # prompt sync source; same directory as the sync target by default
+DEFAULT_PROMPT_MAPPING_FILE = DEFAULT_PROMPTS_DIR / "prompt_text"  # text file with one "name|transcript" entry per line
+DEFAULT_OUTPUT_RETENTION = 20  # number of newest generated .wav files kept in the output directory
+DEFAULT_EXECUTION_MODE = "generate_stream"  # the other accepted mode is "generate"
+DEFAULT_PRECISION = "bfloat16"
+DEFAULT_ODE_METHOD = "euler"
+DEFAULT_NUM_STEPS = 10
+DEFAULT_GUIDANCE_SCALE = 1.2
+DEFAULT_SPEAKER_SCALE = 1.5
+DEFAULT_MAX_GENERATE_LENGTH = 500
+DEFAULT_SEED = 42
+DEFAULT_INPUT_TEXT = ""
+DEFAULT_WARMUP_TEXT = "dots.tts is a 2B-parameter fully continuous, end-to-end autoregressive (AR) text-to-speech system. The backbone pairs a semantic encoder, an LLM, and an autoregressive flow-matching acoustic head over a 48 kHz AudioVAE"
+DEFAULT_PROMPT_NAME = "male_zh"  # preset preselected when a preset with this name exists
+DEFAULT_PROMPT_NONE = "__none__"  # sentinel choice value meaning "no preset"
+PROMPT_AUDIO_SUFFIXES = (".wav", ".mp3", ".flac", ".m4a", ".ogg")  # compared against the lower-cased file suffix
diff --git a/apps/gradio/default_prompts/prompt_text b/apps/gradio/default_prompts/prompt_text
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/apps/gradio/languages.py b/apps/gradio/languages.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffea99232718c44b4642c81485722476413be613
--- /dev/null
+++ b/apps/gradio/languages.py
@@ -0,0 +1,115 @@
+from __future__ import annotations
+
+SUPPORTED_LANGUAGE_CODE_BY_NAME = {  # display name (Chinese) -> language value; dialect entries use a "口音:<name>" value
+    "普通话": "ZH",
+    "粤语": "口音:粤语",
+    "北京话": "口音:北京官话",
+    "东北话": "口音:东北话",
+    "四川话": "口音:四川话",
+    "闽南话": "口音:闽南话",
+    "吴语": "口音:吴语",
+    "英语": "EN",
+    "西班牙语": "ES",
+    "印地语": "HI",
+    "阿拉伯语": "AR",
+    "孟加拉语": "BN",
+    "葡萄牙语": "PT",
+    "俄语": "RU",
+    "日语": "JA",
+    "法语": "FR",
+    "德语": "DE",
+    "韩语": "KO",
+    "意大利语": "IT",
+    "土耳其语": "TR",
+    "越南语": "VI",
+    "印尼语": "ID",
+    "乌尔都语": "UR",
+    "波斯语": "FA",
+    "泰米尔语": "TA",
+    "泰卢固语": "TE",
+    "菲律宾语": "FIL",
+    "马来语": "MS",
+    "旁遮普语": "PA",
+    "马拉地语": "MR",
+    "古吉拉特语": "GU",
+    "马拉雅拉姆语": "ML",
+    "卡纳达语": "KN",
+    "波兰语": "PL",
+    "乌克兰语": "UK",
+    "荷兰语": "NL",
+    "泰语": "TH",
+    "罗马尼亚语": "RO",
+    "斯瓦希里语": "SW",
+    "希伯来语": "HE",
+    "捷克语": "CS",
+    "希腊语": "EL",
+    "匈牙利语": "HU",
+    "瑞典语": "SV",
+    "丹麦语": "DA",
+    "芬兰语": "FI",
+    "书面挪威语": "NB",
+    "斯洛伐克语": "SK",
+    "斯洛文尼亚语": "SL",
+    "塞尔维亚语": "SR",
+    "波斯尼亚语": "BS",
+    "克罗地亚语": "HR",
+    "保加利亚语": "BG",
+    "马其顿语": "MK",
+    "立陶宛语": "LT",
+    "拉脱维亚语": "LV",
+    "爱沙尼亚语": "ET",
+    "冰岛语": "IS",
+    "爱尔兰语": "GA",
+    "威尔士语": "CY",
+    "加泰罗尼亚语": "CA",
+    "加利西亚语": "GL",
+    "奥克语": "OC",
+    "阿斯图里亚斯语": "AST",
+    "尼泊尔语": "NE",
+    "信德语": "SD",
+    "奥里亚语": "OR",
+    "阿萨姆语": "AS",
+    "普什图语": "PS",
+    "缅甸语": "MY",
+    "高棉语": "KM",
+    "老挝语": "LO",
+    "哈萨克语": "KK",
+    "乌兹别克语": "UZ",
+    "吉尔吉斯语": "KY",
+    "塔吉克语": "TG",
+    "阿塞拜疆语": "AZ",
+    "格鲁吉亚语": "KA",
+    "亚美尼亚语": "HY",
+    "白俄罗斯语": "BE",
+    "卢森堡语": "LB",
+    "马耳他语": "MT",
+    "毛利语": "MI",
+    "南非荷兰语": "AF",
+    "祖鲁语": "ZU",
+    "科萨语": "XH",
+    "约鲁巴语": "YO",
+    "豪萨语": "HA",
+    "伊博语": "IG",
+    "阿姆哈拉语": "AM",
+    "奥罗莫语": "OM",
+    "北索托语": "NSO",
+    "尼扬贾语": "NY",
+    "修纳语": "SN",
+    "索马里语": "SO",
+    "卢干达语": "LG",
+    "林加拉语": "LN",
+    "卢奥语": "LUO",
+    "坎巴语": "KAM",
+    "翁本杜语": "UMB",
+    "富拉语": "FF",
+    "沃洛夫语": "WO",
+    "中库尔德语": "CKB",
+    "宿务语": "CEB",
+    "佛得角克里奥尔语": "KEA",
+    "蒙古语": "MN",
+    "爪哇语": "JV",
+}
+
+
+def build_language_choice_items() -> list[tuple[str, str]]:
+    return [("不指定", ""), *SUPPORTED_LANGUAGE_CODE_BY_NAME.items()]  # empty-value entry first, then every (name, code) pair
diff --git a/apps/gradio/service.py b/apps/gradio/service.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a0b43db92b4c1356fc6cec189d6935983df8adb
--- /dev/null
+++ b/apps/gradio/service.py
@@ -0,0 +1,773 @@
+from __future__ import annotations
+
+import shutil
+import sys
+import threading
+import time
+import uuid
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Literal
+
+REPO_ROOT = Path(__file__).resolve().parents[2]
+SRC_ROOT = REPO_ROOT / "src"
+
+for import_root in (REPO_ROOT, SRC_ROOT):  # put both roots on sys.path before the project imports below run
+    import_root_str = str(import_root)
+    if import_root_str not in sys.path:  # leave sys.path alone when the root is already listed
+        sys.path.insert(0, import_root_str)
+
+import soundfile as sf # noqa: E402
+import torch # noqa: E402
+from loguru import logger # noqa: E402
+
+from apps.gradio.constants import ( # noqa: E402
+ DEFAULT_EXECUTION_MODE,
+ DEFAULT_GUIDANCE_SCALE,
+ DEFAULT_HOST,
+ DEFAULT_MAX_GENERATE_LENGTH,
+ DEFAULT_NUM_STEPS,
+ DEFAULT_ODE_METHOD,
+ DEFAULT_OUTPUT_DIR,
+ DEFAULT_OUTPUT_RETENTION,
+ DEFAULT_PORT,
+ DEFAULT_PRECISION,
+ DEFAULT_PROMPT_MAPPING_FILE,
+ DEFAULT_PROMPT_NAME,
+ DEFAULT_PROMPT_NONE,
+ DEFAULT_PROMPT_SOURCE_DIR,
+ DEFAULT_PROMPTS_DIR,
+ DEFAULT_SEED,
+ DEFAULT_SPEAKER_SCALE,
+ DEFAULT_WARMUP_TEXT,
+ PROMPT_AUDIO_SUFFIXES,
+)
+from apps.gradio.languages import ( # noqa: E402
+ SUPPORTED_LANGUAGE_CODE_BY_NAME,
+ build_language_choice_items,
+)
+from dots_tts.runtime import DotsTtsRuntime # noqa: E402
+from dots_tts.utils.util import seed_everything # noqa: E402
+
+# Execution paths the service distinguishes: one-shot "generate" or chunked "generate_stream".
+ExecutionMode = Literal["generate", "generate_stream"]
+# Pairs of (mode id, template name); only the template names (second items) are used for validation.
+GRADIO_SYNTHESIS_MODE_CHOICES = (
+    ("tts", "tts"),
+    ("instruct_tts", "instruction_tts"),
+    ("instruct_tts_general", "text_to_audio"),
+)
+GRADIO_SYNTHESIS_MODE_TEMPLATE_NAMES = tuple(template for _, template in GRADIO_SYNTHESIS_MODE_CHOICES)
+
+
+@dataclass(frozen=True)
+class PromptPreset:  # one prompt-audio preset discovered in the prompts directory
+    name: str  # stem of the audio file; also the value used to select the preset
+    audio_path: str  # resolved absolute path of the audio file
+    prompt_text: str  # transcript looked up by name in the mapping file, "" when not listed
+
+
+def _is_prompt_asset(path: Path) -> bool:
+    """Return True for a regular file that is named ``prompt_text`` or has a prompt audio suffix."""
+    has_asset_name = path.name == "prompt_text" or path.suffix.lower() in PROMPT_AUDIO_SUFFIXES
+    return path.is_file() and has_asset_name
+
+
+def sync_default_prompt_library(
+    source_dir: Path = DEFAULT_PROMPT_SOURCE_DIR,
+    target_dir: Path = DEFAULT_PROMPTS_DIR,
+) -> None:
+    """Mirror the prompt assets of ``source_dir`` into ``target_dir``.
+
+    Assets are the files accepted by ``_is_prompt_asset``. An asset is copied with
+    ``shutil.copy2`` when the target is missing or differs in size or mtime; target
+    assets that have no same-named source asset are deleted.
+
+    Args:
+        source_dir: Directory to read from. When it is not a directory the sync is
+            skipped and only an info line is logged.
+        target_dir: Directory to update; created (with parents) when missing.
+    """
+    src_root = Path(source_dir)
+    if not src_root.is_dir():
+        logger.info("Prompt library sync skipped: source_dir={} does not exist.", src_root)
+        return
+
+    dst_root = Path(target_dir)
+    dst_root.mkdir(parents=True, exist_ok=True)
+    logger.info("Prompt library sync started: source_dir={} target_dir={}", src_root, dst_root)
+
+    assets_by_name = {entry.name: entry for entry in sorted(src_root.iterdir()) if _is_prompt_asset(entry)}
+    copied_count = 0
+    for asset_name, src_file in assets_by_name.items():
+        dst_file = dst_root / asset_name
+        up_to_date = (
+            dst_file.exists()
+            and dst_file.stat().st_size == src_file.stat().st_size
+            and dst_file.stat().st_mtime_ns == src_file.stat().st_mtime_ns
+        )
+        if not up_to_date:
+            shutil.copy2(src_file, dst_file)
+            copied_count += 1
+
+    stale_files = [item for item in sorted(dst_root.iterdir()) if _is_prompt_asset(item) and item.name not in assets_by_name]
+    for stale_file in stale_files:
+        stale_file.unlink(missing_ok=True)
+
+    logger.info(
+        "Prompt library sync completed: copied_assets={} removed_assets={} available_assets={}",
+        copied_count, len(stale_files), len(assets_by_name),
+    )
+
+
+def _load_prompt_text_map(mapping_file: Path) -> dict[str, str]:
+    """Parse ``name|text`` lines of ``mapping_file`` into a dict ({} when the file is missing).
+
+    Blank lines, lines starting with ``#`` and lines without ``|`` are skipped; only
+    the first ``|`` splits, both parts are stripped, and a repeated name keeps the last text.
+    """
+    if not mapping_file.is_file():
+        return {}
+    with mapping_file.open(encoding="utf-8") as handle:
+        fields = (raw_line.strip().partition("|") for raw_line in handle)
+        return {
+            name.strip(): text.strip() for name, sep, text in fields if sep and not name.startswith("#")
+        }
+
+
+def discover_prompt_presets(
+    prompts_dir: Path = DEFAULT_PROMPTS_DIR,
+    mapping_file: Path = DEFAULT_PROMPT_MAPPING_FILE,
+) -> tuple[PromptPreset, ...]:
+    """Build one ``PromptPreset`` per prompt audio file found directly in ``prompts_dir``.
+
+    Presets are ordered by file stem, except that the stem ``"child"`` sorts last.
+    Transcripts come from ``mapping_file`` keyed by stem ("" when not listed).
+    Returns ``()`` when ``prompts_dir`` is not a directory.
+    """
+    root = Path(prompts_dir)
+    if not root.is_dir():
+        return ()
+
+    transcripts = _load_prompt_text_map(Path(mapping_file))
+    presets: list[PromptPreset] = []
+    # The key (stem == "child", stem) sorts False before True, which pushes "child" to the end.
+    for entry in sorted(root.iterdir(), key=lambda item: (item.stem == "child", item.stem)):
+        if entry.is_file() and entry.suffix.lower() in PROMPT_AUDIO_SUFFIXES:
+            stem = entry.stem
+            presets.append(PromptPreset(name=stem, audio_path=str(entry.resolve()), prompt_text=transcripts.get(stem, "")))
+    return tuple(presets)
+
+
+def build_prompt_choice_items(prompt_presets: tuple[PromptPreset, ...]) -> list[tuple[str, str]]:
+    """Return choice pairs: a "No Preset" entry first, then one ``(name, name)`` pair per preset."""
+    choices = [("No Preset", DEFAULT_PROMPT_NONE)]
+    return choices + [(preset.name, preset.name) for preset in prompt_presets]
+
+
+def resolve_default_prompt_selection(
+    prompt_presets: tuple[PromptPreset, ...],
+    default_prompt_name: str = DEFAULT_PROMPT_NAME,
+) -> tuple[str, str | None, str]:
+    """Return (name, audio_path, prompt_text) for ``default_prompt_name``, else for the first preset name, else the "none" triple."""
+    if not prompt_presets:
+        return DEFAULT_PROMPT_NONE, None, ""
+    by_name = {preset.name: preset for preset in prompt_presets}
+    fallback_name = prompt_presets[0].name
+    chosen = by_name[default_prompt_name if default_prompt_name in by_name else fallback_name]
+    return chosen.name, chosen.audio_path, chosen.prompt_text
+
+
+def resolve_prompt_selection(
+    prompt_name: str,
+    prompt_presets: tuple[PromptPreset, ...],
+) -> tuple[str | None, str]:
+    """Return ``(audio_path, prompt_text)`` of the first preset named ``prompt_name``.
+
+    ``(None, "")`` is returned for ``DEFAULT_PROMPT_NONE`` and for a name no preset has.
+    """
+    candidates = () if prompt_name == DEFAULT_PROMPT_NONE else prompt_presets
+    match = next((preset for preset in candidates if preset.name == prompt_name), None)
+    return (None, "") if match is None else (match.audio_path, match.prompt_text)
+
+
+def discover_local_model_choices(repo_root: Path = REPO_ROOT) -> list[str]:
+    """List every directory named ``model`` below ``<repo_root>/pretrained_models``.
+
+    Paths are POSIX-style, relative to ``repo_root`` and sorted; ``[]`` is returned
+    when the ``pretrained_models`` directory does not exist.
+    """
+    model_root = Path(repo_root) / "pretrained_models"
+    candidates = model_root.glob("**/model") if model_root.is_dir() else ()
+    return sorted(hit.relative_to(repo_root).as_posix() for hit in candidates if hit.is_dir())
+
+
+def resolve_model_name_or_path(model_name_or_path: str, repo_root: Path = REPO_ROOT) -> str:
+    """Turn a model reference into an absolute local path when such a path exists.
+
+    The stripped value is tried as a path (``~`` expanded), then relative to ``repo_root``;
+    if neither exists it is returned unchanged. Raises ``ValueError`` when it is empty.
+    """
+    normalized = model_name_or_path.strip()
+    if not normalized:
+        raise ValueError("model_name_or_path 不能为空。")
+
+    for candidate in (Path(normalized).expanduser(), Path(repo_root) / normalized):
+        if candidate.exists():
+            return str(candidate.resolve())
+    return normalized
+
+
+def default_model_name_or_path(repo_root: Path = REPO_ROOT) -> str:
+    """Return the first path from ``discover_local_model_choices`` (sorted order),
+    or "" when no local model is found."""
+    models = discover_local_model_choices(repo_root=repo_root)
+    return models[0] if models else ""
+
+
+@dataclass(frozen=True)
+class GradioAppConfig:  # immutable settings bundle produced by build_gradio_app_config
+    host: str  # server address handed to the Gradio launch call
+    port: int  # server port handed to the Gradio launch call
+    execution_mode: ExecutionMode  # "generate" or "generate_stream"; decides how warmup() runs
+    precision: str  # forwarded to DotsTtsRuntime.from_pretrained
+    optimize: bool  # forwarded to DotsTtsRuntime.from_pretrained
+    output_dir: Path  # directory the generated .wav files are written to
+    prompts_dir: Path  # directory scanned for prompt presets
+    output_retention_count: int  # newest .wav files to keep; <= 0 disables cleanup
+    max_generate_length: int  # forwarded to DotsTtsRuntime.from_pretrained; validated to be > 0
+    default_model_name_or_path: str  # model loaded by warmup()
+    prompt_presets: tuple[PromptPreset, ...]
+    default_prompt_name: str  # resolved preset name, or DEFAULT_PROMPT_NONE when there are no presets
+    default_prompt_audio_path: str | None
+    default_prompt_text: str
+    default_precision: str  # NOTE(review): the default_* fields look like UI defaults; in this module they feed warmup()/metadata()
+    default_num_steps: int
+    default_guidance_scale: float
+    default_speaker_scale: float
+    default_max_generate_length: int
+    local_model_choices: tuple[str, ...]  # repo-relative model directories found under pretrained_models/
+    repo_root: Path = REPO_ROOT  # base directory for resolving relative model paths
+
+
+def build_gradio_app_config(
+    *,
+    host: str = DEFAULT_HOST,
+    port: int = DEFAULT_PORT,
+    execution_mode: ExecutionMode = DEFAULT_EXECUTION_MODE,
+    precision: str = DEFAULT_PRECISION,
+    optimize: bool = False,
+    output_dir: Path = DEFAULT_OUTPUT_DIR,
+    output_retention_count: int = DEFAULT_OUTPUT_RETENTION,
+    max_generate_length: int = DEFAULT_MAX_GENERATE_LENGTH,
+    model_name_or_path: str | None = None,
+    default_prompt_name: str = DEFAULT_PROMPT_NAME,
+    default_precision: str = DEFAULT_PRECISION,
+    default_num_steps: int = DEFAULT_NUM_STEPS,
+    default_guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
+    default_speaker_scale: float = DEFAULT_SPEAKER_SCALE,
+    default_max_generate_length: int = DEFAULT_MAX_GENERATE_LENGTH,
+    repo_root: Path = REPO_ROOT,
+    prompts_dir: Path = DEFAULT_PROMPTS_DIR,
+    prompt_source_dir: Path = DEFAULT_PROMPT_SOURCE_DIR,
+) -> GradioAppConfig:
+    """Sync the prompt library, discover local models and presets, and build a validated config.
+
+    ``prompts_dir`` may be given as ``str`` or ``Path``. When ``model_name_or_path`` is ``None``
+    the first locally discovered model is used.
+
+    Raises:
+        ValueError: No model given or found, unsupported ``execution_mode``, or ``max_generate_length`` <= 0.
+    """
+    prompts_root = Path(prompts_dir)  # coerce once: a plain ``str / "prompt_text"`` below would raise TypeError
+    sync_default_prompt_library(source_dir=prompt_source_dir, target_dir=prompts_root)
+    discovered_models = discover_local_model_choices(repo_root=repo_root)
+    prompt_presets = discover_prompt_presets(prompts_dir=prompts_root, mapping_file=prompts_root / "prompt_text")
+    resolved_default_prompt_name, default_prompt_audio_path, default_prompt_text = resolve_default_prompt_selection(
+        prompt_presets, default_prompt_name=default_prompt_name
+    )
+    if model_name_or_path is not None:
+        selected_model_name_or_path = model_name_or_path.strip()
+    else:
+        # Reuse the scan above instead of globbing pretrained_models/ a second time.
+        selected_model_name_or_path = discovered_models[0] if discovered_models else ""
+    if not selected_model_name_or_path:
+        raise ValueError("No default model found. Please pass --model-name-or-path.")
+    if execution_mode not in ("generate", "generate_stream"):
+        raise ValueError(f"Unsupported execution_mode: {execution_mode}")
+    resolved_max_generate_length = int(max_generate_length)
+    if resolved_max_generate_length <= 0:
+        raise ValueError("max_generate_length must be positive.")
+    resolved_precision = precision.strip() or DEFAULT_PRECISION
+    logger.info(
+        "Gradio app config prepared: host={} port={} output_dir={} "
+        "output_retention_count={} max_generate_length={} execution_mode={} precision={} optimize={} "
+        "default_model_name_or_path={} prompt_preset_count={} language_count={} local_model_choice_count={}",
+        host,
+        port,
+        output_dir,
+        output_retention_count,
+        resolved_max_generate_length,
+        execution_mode,
+        resolved_precision,
+        bool(optimize),
+        selected_model_name_or_path,
+        len(prompt_presets),
+        len(SUPPORTED_LANGUAGE_CODE_BY_NAME),
+        len(discovered_models),
+    )
+    return GradioAppConfig(
+        host=host,
+        port=int(port),
+        execution_mode=execution_mode,
+        precision=resolved_precision,
+        optimize=bool(optimize),
+        output_dir=Path(output_dir),
+        prompts_dir=prompts_root,
+        output_retention_count=int(output_retention_count),
+        max_generate_length=resolved_max_generate_length,
+        default_model_name_or_path=selected_model_name_or_path,
+        prompt_presets=prompt_presets,
+        default_prompt_name=resolved_default_prompt_name,
+        default_prompt_audio_path=default_prompt_audio_path,
+        default_prompt_text=default_prompt_text,
+        default_precision=default_precision,
+        default_num_steps=int(default_num_steps),
+        default_guidance_scale=float(default_guidance_scale),
+        default_speaker_scale=float(default_speaker_scale),
+        default_max_generate_length=int(default_max_generate_length),
+        local_model_choices=tuple(discovered_models),
+        repo_root=repo_root,
+    )
+
+
+@dataclass(frozen=True)
+class SynthesisRequest:  # parameters of one synthesis call handled by GradioAppService.generate
+    model_name_or_path: str  # local path, repo-relative path or other model reference
+    text: str  # text to synthesize; must be non-empty after stripping
+    prompt_audio_path: str | None = None
+    prompt_text: str | None = None  # only accepted together with prompt_audio_path
+    execution_mode: ExecutionMode = DEFAULT_EXECUTION_MODE  # "generate_stream" streams; any other value takes the one-shot path
+    template_name: str = "tts"  # must be one of GRADIO_SYNTHESIS_MODE_TEMPLATE_NAMES
+    language: str | None = None  # None/blank, or a value of SUPPORTED_LANGUAGE_CODE_BY_NAME
+    ode_method: str = DEFAULT_ODE_METHOD  # this and the next four fields are forwarded to the runtime as-is
+    num_steps: int = DEFAULT_NUM_STEPS
+    guidance_scale: float = DEFAULT_GUIDANCE_SCALE
+    speaker_scale: float = DEFAULT_SPEAKER_SCALE
+    normalize_text: bool = False
+    seed: int = DEFAULT_SEED  # passed to seed_everything before generation
+
+
+@dataclass(frozen=True)
+class SynthesisResult:  # outcome of GradioAppService.generate
+    audio_path: str  # path of the written .wav file
+    metrics: dict[str, Any]  # request id, timings, RTF, seed, output path, ...
+    status: str  # one-line human-readable summary of the run
+
+
+class GradioAppService:
+    def __init__(self, config: GradioAppConfig):
+        """Store ``config``, create its output directory and start with no runtime loaded.
+
+        Args:
+            config: Application settings; ``config.output_dir`` is created when missing.
+        """
+        self.config = config
+        self.config.output_dir.mkdir(parents=True, exist_ok=True)
+        self._lock = threading.Lock()  # held by warmup() and generate() around their runtime work
+        self._runtime: DotsTtsRuntime | None = None
+        self._runtime_model_name_or_path: str | None = None
+        logger.info(
+            "Gradio service initialized: output_dir={} default_model_name_or_path={} "
+            "output_retention_count={} max_generate_length={} execution_mode={} precision={} optimize={}",
+            config.output_dir, config.default_model_name_or_path, config.output_retention_count,
+            config.max_generate_length, config.execution_mode, config.precision, config.optimize,
+        )
+
+    def metadata(self) -> dict[str, Any]:
+        """Return a plain-dict snapshot of the configuration and the runtime state."""
+        cfg = self.config
+        # loaded_* values are taken from the config and are None while no runtime exists.
+        is_loaded = self._runtime is not None
+        return {
+            "repo_root": str(cfg.repo_root),
+            "default_model_name_or_path": cfg.default_model_name_or_path,
+            "local_model_choices": list(cfg.local_model_choices),
+            "prompts_dir": str(cfg.prompts_dir),
+            "prompt_preset_names": [preset.name for preset in cfg.prompt_presets],
+            "default_prompt_name": cfg.default_prompt_name,
+            "output_dir": str(cfg.output_dir),
+            "output_retention_count": cfg.output_retention_count,
+            "configured_max_generate_length": cfg.max_generate_length,
+            "configured_execution_mode": cfg.execution_mode,
+            "configured_precision": cfg.precision,
+            "optimize": cfg.optimize,
+            "loaded_model_name_or_path": self._runtime_model_name_or_path,
+            "loaded_max_generate_length": cfg.max_generate_length if is_loaded else None,
+            "loaded_precision": cfg.precision if is_loaded else None,
+            "model_loaded": is_loaded,
+            "host": cfg.host,
+            "port": cfg.port,
+            "default_precision": cfg.default_precision,
+            "default_num_steps": cfg.default_num_steps,
+            "default_guidance_scale": cfg.default_guidance_scale,
+            "default_speaker_scale": cfg.default_speaker_scale,
+            "default_max_generate_length": cfg.default_max_generate_length,
+            "supported_languages": build_language_choice_items()[1:],
+            "supported_template_names": list(GRADIO_SYNTHESIS_MODE_TEMPLATE_NAMES),
+        }
+
+    def _get_runtime(
+        self,
+        model_name_or_path: str,
+    ) -> tuple[DotsTtsRuntime, str]:
+        """Return the runtime for ``model_name_or_path``, loading it when it is not the cached one.
+
+        A single runtime is cached; asking for a different resolved model replaces it.
+
+        Returns:
+            ``(runtime, resolved_model_name_or_path)``.
+        """
+        cfg = self.config
+        resolved = resolve_model_name_or_path(model_name_or_path, repo_root=cfg.repo_root)
+        log_args = (
+            model_name_or_path,
+            resolved,
+            cfg.max_generate_length,
+            cfg.execution_mode,
+            cfg.precision,
+            cfg.optimize,
+        )
+        if self._runtime is not None and self._runtime_model_name_or_path == resolved:
+            logger.info(
+                "Gradio runtime cache hit: requested_model={} resolved_model={} "
+                "max_generate_length={} execution_mode={} precision={} optimize={}",
+                *log_args,
+            )
+            return self._runtime, resolved
+        logger.info(
+            "Gradio runtime cache miss: requested_model={} resolved_model={} "
+            "max_generate_length={} execution_mode={} precision={} optimize={}",
+            *log_args,
+        )
+        self._runtime = DotsTtsRuntime.from_pretrained(
+            resolved,
+            precision=cfg.precision,
+            optimize=cfg.optimize,
+            max_generate_length=cfg.max_generate_length,
+        )
+        self._runtime_model_name_or_path = resolved
+        return self._runtime, resolved
+
+    def _build_stream_request_id(
+        self,
+        runtime: DotsTtsRuntime,
+        request: SynthesisRequest,
+    ) -> str:
+        """Build the request id for ``request`` with the runtime's private helpers.
+
+        The text and prompt text are first passed through ``runtime._process_text`` and
+        ``runtime._process_prompt_text``. When a language was resolved and the processed
+        prompt text is empty, the language tag is attached to the text before hashing.
+        NOTE(review): presumably mirrors what the runtime does for non-streamed runs - confirm.
+        """
+        text, language = runtime._process_text(  # noqa: SLF001
+            request.text, language=request.language, normalize=request.normalize_text
+        )
+        prompt_text = runtime._process_prompt_text(request.prompt_text, language=language)  # noqa: SLF001
+        has_language = language is not None
+        if has_language and not prompt_text:
+            from dots_tts.utils.text import attach_language_tag  # noqa: PLC0415
+
+            text = attach_language_tag(text, language)
+        id_kwargs = {
+            "text": text,
+            "prompt_audio_path": request.prompt_audio_path,
+            "prompt_text": prompt_text,
+            "template_name": request.template_name,
+        }
+        if has_language:
+            id_kwargs["language"] = language
+        return runtime._build_request_id(  # noqa: SLF001
+            **id_kwargs,
+        )
+
+    @staticmethod
+    def _build_runtime_generate_kwargs(request: SynthesisRequest) -> dict[str, Any]:
+        """Map ``request`` onto keyword arguments for the runtime's generate calls.
+
+        ``language`` is included only when it is not ``None``.
+        """
+        # Request attributes forwarded under the same keyword name, in this order.
+        forwarded_fields = (
+            "text", "prompt_audio_path", "prompt_text",
+            "template_name", "ode_method", "num_steps",
+            "guidance_scale", "speaker_scale", "normalize_text",
+        )
+        runtime_kwargs: dict[str, Any] = {field: getattr(request, field) for field in forwarded_fields}
+        if request.language is not None:
+            runtime_kwargs["language"] = request.language
+        return runtime_kwargs
+
+    def _run_stream_generation(
+        self,
+        runtime: DotsTtsRuntime,
+        request: SynthesisRequest,
+    ) -> dict[str, Any]:
+        """Consume ``runtime.generate_stream`` completely and return one concatenated result.
+
+        Returns a dict with ``fid``, ``audio`` (chunks joined on the last axis, float, on CPU),
+        ``sample_rate``, ``time_used``, ``rtf`` and ``chunk_count``. Raises ``ValueError`` if no chunk arrives.
+        """
+        started_at = time.time()
+        stream = runtime.generate_stream(**self._build_runtime_generate_kwargs(request))
+        chunks = [piece.detach().float().cpu() for piece in stream]
+        if not chunks:
+            raise ValueError("流式生成未返回任何音频块。")
+
+        audio = torch.cat(chunks, dim=-1)
+        elapsed_seconds = time.time() - started_at
+        audio_seconds = audio.shape[-1] / runtime.sample_rate
+        return {
+            "fid": self._build_stream_request_id(runtime, request),
+            "audio": audio,
+            "sample_rate": runtime.sample_rate,
+            "time_used": elapsed_seconds,
+            "rtf": elapsed_seconds / audio_seconds if audio_seconds > 0 else float("inf"),  # inf for zero-length audio
+            "chunk_count": len(chunks),
+        }
+
+    def warmup(self, text: str | None = None) -> dict[str, Any]:
+        """Load the default model and run one synthesis so the first real request starts warm.
+
+        Args:
+            text: Warmup text; ``None`` or blank falls back to ``DEFAULT_WARMUP_TEXT``.
+
+        Returns:
+            Metrics of the run (request id, mode, chunk count, model, timings, RTF, seed, text).
+
+        Raises:
+            ValueError: If neither ``text`` nor ``DEFAULT_WARMUP_TEXT`` has content.
+        """
+        warmup_text = (text or "").strip() or DEFAULT_WARMUP_TEXT.strip()
+        if not warmup_text:
+            raise ValueError("DEFAULT_WARMUP_TEXT 不能为空。")
+        cfg = self.config
+        with self._lock:
+            logger.info(
+                "Gradio warmup requested: default_model_name_or_path={} execution_mode={} precision={} optimize={} seed={}",
+                cfg.default_model_name_or_path, cfg.execution_mode, cfg.precision, cfg.optimize, DEFAULT_SEED,
+            )
+            try:
+                seed_everything(DEFAULT_SEED)
+                runtime, resolved_model_name_or_path = self._get_runtime(cfg.default_model_name_or_path)
+                warmup_request = SynthesisRequest(
+                    model_name_or_path=cfg.default_model_name_or_path,
+                    text=warmup_text,
+                    execution_mode=cfg.execution_mode,
+                    template_name="tts",
+                    ode_method=DEFAULT_ODE_METHOD,
+                    num_steps=cfg.default_num_steps,
+                    guidance_scale=cfg.default_guidance_scale,
+                    speaker_scale=cfg.default_speaker_scale,
+                    normalize_text=False,
+                    seed=DEFAULT_SEED,
+                )
+                request_id = self._build_stream_request_id(runtime, warmup_request)
+                if cfg.execution_mode == "generate_stream":
+                    result = self._run_stream_generation(runtime, warmup_request)
+                else:
+                    started_at = time.time()
+                    result = runtime.generate(**self._build_runtime_generate_kwargs(warmup_request))
+                    result["time_used"] = time.time() - started_at  # replace with the wall time measured here
+                    result["chunk_count"] = 1
+                audio_samples = int(result["audio"].shape[-1])
+            except Exception:
+                logger.exception("Gradio warmup failed: default_model_name_or_path={}", cfg.default_model_name_or_path)
+                raise
+            audio_seconds = audio_samples / runtime.sample_rate
+            metrics = {
+                "request_id": request_id,
+                "execution_mode": cfg.execution_mode,
+                "chunk_count": int(result["chunk_count"]),
+                "resolved_model_name_or_path": resolved_model_name_or_path,
+                "sample_rate": runtime.sample_rate,
+                "elapsed_seconds": round(float(result["time_used"]), 3),
+                "audio_seconds": round(float(audio_seconds), 3),
+                "rtf": round(float(result["rtf"]), 4),
+                "seed": DEFAULT_SEED,
+                "text": warmup_text,
+            }
+            logger.info(
+                "Gradio warmup ready: request_id={} execution_mode={} resolved_model_name_or_path={}",
+                metrics["request_id"], metrics["execution_mode"], metrics["resolved_model_name_or_path"],
+            )
+            return metrics
+
+    def _normalize_request(self, request: SynthesisRequest) -> SynthesisRequest:
+        """Return a stripped, type-coerced and validated copy of ``request``.
+
+        Raises ``ValueError`` for empty ``text``, ``prompt_text`` without
+        ``prompt_audio_path``, or an unsupported ``template_name`` / ``language``.
+        """
+        text = request.text.strip()
+        if not text:
+            raise ValueError("text 不能为空。")
+        prompt_audio_path = request.prompt_audio_path or None
+        prompt_text = (request.prompt_text or "").strip() or None
+        if prompt_text and not prompt_audio_path:
+            raise ValueError("prompt_text requires prompt_audio_path.")
+        template_name = request.template_name.strip() or "tts"
+        if template_name not in GRADIO_SYNTHESIS_MODE_TEMPLATE_NAMES:
+            raise ValueError(
+                f"Unsupported template_name={template_name!r}. "
+                f"Expected one of {list(GRADIO_SYNTHESIS_MODE_TEMPLATE_NAMES)}."
+            )
+        language = (request.language or "").strip() or None
+        supported_language_codes = set(SUPPORTED_LANGUAGE_CODE_BY_NAME.values())
+        if language is not None and language not in supported_language_codes:
+            raise ValueError(
+                f"Unsupported language={language!r}. Expected one of {sorted(supported_language_codes)}."
+            )
+
+        seed = int(request.seed)
+        return SynthesisRequest(
+            model_name_or_path=request.model_name_or_path.strip(),
+            text=text,
+            prompt_audio_path=prompt_audio_path,
+            prompt_text=prompt_text,
+            execution_mode=request.execution_mode,
+            template_name=template_name,
+            language=language,
+            ode_method=request.ode_method.strip() or DEFAULT_ODE_METHOD,
+            num_steps=int(request.num_steps),
+            guidance_scale=float(request.guidance_scale),
+            speaker_scale=float(request.speaker_scale),
+            normalize_text=bool(request.normalize_text),
+            seed=seed,
+        )
+
+    def _build_output_path(self) -> Path:
+        """Return ``<output_dir>/<YYYYmmdd-HHMMSS>-<8 hex chars>.wav`` for a new output file."""
+        return self.config.output_dir / f"{time.strftime('%Y%m%d-%H%M%S')}-{uuid.uuid4().hex[:8]}.wav"
+
+    def _cleanup_outputs(self) -> None:
+        """Delete all but the newest ``output_retention_count`` ``*.wav`` files in the output dir.
+
+        "Newest" is judged by modification time. A retention count <= 0 disables cleanup.
+        """
+        keep = self.config.output_retention_count
+        if keep <= 0:
+            return
+
+        newest_first = sorted(self.config.output_dir.glob("*.wav"), key=lambda item: item.stat().st_mtime, reverse=True)
+        stale_files = newest_first[keep:]
+        for stale_file in stale_files:
+            stale_file.unlink(missing_ok=True)
+        if stale_files:
+            logger.info(
+                "Gradio output cleanup completed: removed_files={} retention_limit={}",
+                len(stale_files),
+                keep,
+            )
+
+    @staticmethod
+    def _waveform_to_numpy(audio: torch.Tensor):
+        """Return ``audio`` squeezed as a float32 NumPy array; raise ``ValueError`` for 0 or 1 samples."""
+        if audio.numel() <= 1:  # 0 elements survive squeeze() as an empty axis; 1 element would squeeze to a 0-d scalar
+            raise ValueError("生成音频为空。")
+        return audio.detach().float().cpu().squeeze().numpy()
+
+    def _write_audio(self, audio: torch.Tensor, sample_rate: int) -> str:
+        """Write ``audio`` to a new file in the output dir, prune old outputs and return the path as ``str``."""
+        target = self._build_output_path()
+        logger.info(
+            "Writing synthesized audio: output_path={} sample_rate={} samples={}",
+            target, sample_rate, audio.shape[-1],
+        )
+        samples = self._waveform_to_numpy(audio)
+        sf.write(target, samples, sample_rate)
+        self._cleanup_outputs()
+        logger.info("Synthesized audio written: output_path={}", target)
+        return str(target)
+
+    def generate(self, request: SynthesisRequest) -> SynthesisResult:
+        """Synthesize ``request`` and write the audio to a new ``.wav`` file.
+
+        The request is normalized first. Seeding, runtime lookup, generation and the
+        file write all run while holding the service lock; any failure in that section
+        is logged together with the request parameters and re-raised.
+
+        Returns:
+            ``SynthesisResult`` with the output path, a metrics dict and a status line.
+
+        Raises:
+            ValueError: If ``request`` is rejected by ``_normalize_request``.
+        """
+        req = self._normalize_request(request)
+        cfg = self.config
+
+        with self._lock:
+            try:
+                seed_everything(req.seed)
+                runtime, resolved_model_name_or_path = self._get_runtime(req.model_name_or_path)
+                logger.info(
+                    "Gradio request accepted: resolved_model_name_or_path={} execution_mode={} seed={}",
+                    resolved_model_name_or_path, req.execution_mode, req.seed,
+                )
+                if req.execution_mode == "generate_stream":
+                    result = self._run_stream_generation(runtime, req)
+                else:
+                    result = runtime.generate(**self._build_runtime_generate_kwargs(req))
+                    result["chunk_count"] = 1  # a non-streamed run counts as a single chunk
+                audio_path = self._write_audio(result["audio"], result["sample_rate"])
+            except Exception:
+                logger.exception(
+                    "Gradio request failed: model_name_or_path={} execution_mode={} text_len={} has_prompt_audio={} has_prompt_text={} template_name={} language={} "
+                    "precision={} ode_method={} num_steps={} guidance_scale={} speaker_scale={} max_generate_length={} "
+                    "normalize_text={} seed={}",
+                    req.model_name_or_path,
+                    req.execution_mode,
+                    len(req.text),
+                    bool(req.prompt_audio_path),
+                    bool(req.prompt_text),
+                    req.template_name,
+                    req.language,
+                    cfg.precision,
+                    req.ode_method,
+                    req.num_steps,
+                    req.guidance_scale,
+                    req.speaker_scale,
+                    cfg.max_generate_length,
+                    req.normalize_text,
+                    req.seed,
+                )
+                raise
+            audio_seconds = result["audio"].shape[-1] / result["sample_rate"]
+            metrics = {
+                "request_id": result["fid"],
+                "execution_mode": req.execution_mode,
+                "chunk_count": int(result["chunk_count"]),
+                "template_name": req.template_name,
+                "language": req.language,
+                "resolved_model_name_or_path": resolved_model_name_or_path,
+                "sample_rate": result["sample_rate"],
+                "elapsed_seconds": round(float(result["time_used"]), 3),
+                "audio_seconds": round(float(audio_seconds), 3),
+                "rtf": round(float(result["rtf"]), 4),
+                "seed": req.seed,
+                "output_path": audio_path,
+            }
+            logger.info(
+                "Gradio request output ready: request_id={} execution_mode={} resolved_model_name_or_path={} output_path={}",
+                metrics["request_id"], metrics["execution_mode"], metrics["resolved_model_name_or_path"], metrics["output_path"],
+            )
+            status_parts = [
+                f"完成:{Path(audio_path).name}",
+                f"模式 {metrics['execution_mode']}",
+                f"耗时 {metrics['elapsed_seconds']}s",
+                f"音频 {metrics['audio_seconds']}s",
+                f"RTF {metrics['rtf']}",
+            ]
+            return SynthesisResult(audio_path=audio_path, metrics=metrics, status=" | ".join(status_parts))
diff --git a/configs/dots_tts.yaml b/configs/dots_tts.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c25aa9927938c9200d4a9a80881d21d127c61ab2
--- /dev/null
+++ b/configs/dots_tts.yaml
@@ -0,0 +1,76 @@
+train_data:
+ train_audio_sample_rate: 48000
+ audio_samples_per_llm_token: 7680
+ sources:
+ - name: ljspeech_basic
+ weight: 1.0
+ pipeline: basic
+ adapter:
+ class_name: JsonlManifestSourceAdapter
+ params:
+ manifest_path: downloaded_data/ljspeech_48khz_manifest_train.jsonl
+ shuffle: true
+ - name: ljspeech_interleave
+ weight: 1.0
+ pipeline: interleave
+ adapter:
+ class_name: JsonlManifestSourceAdapter
+ params:
+ manifest_path: downloaded_data/ljspeech_48khz_manifest_train.jsonl
+ shuffle: true
+ # append other sources here if needed
+ num_tokens_per_epoch: 2000000
+ num_workers: 20
+ pin_memory: true
+ max_audio_seconds_in_batch: 30.0
+ max_text_tokens_in_batch: 2048
+ max_samples_per_batch: null
+ bucketing_pool_size: 100
+val_data:
+ train_audio_sample_rate: 48000
+ audio_samples_per_llm_token: 7680
+ sources:
+ - name: ljspeech_valid_basic
+ weight: 1.0
+ adapter:
+ class_name: JsonlManifestSourceAdapter
+ params:
+ manifest_path: downloaded_data/ljspeech_48khz_manifest_valid.jsonl
+ shuffle: false
+ pipeline: basic
+ - name: ljspeech_valid_interleave
+ weight: 1.0
+ pipeline: interleave
+ adapter:
+ class_name: JsonlManifestSourceAdapter
+ params:
+ manifest_path: downloaded_data/ljspeech_48khz_manifest_valid.jsonl
+ shuffle: false
+ # (pipeline for this source is set once, right after weight, as in the training sources)
+ # append other sources here if needed
+ num_workers: 4
+ pin_memory: true
+ max_audio_seconds_in_batch: 30.0
+ max_text_tokens_in_batch: 2048
+ max_samples_per_batch: null
+ bucketing_pool_size: 64
+train:
+ pretrained_model_path: pretrained_models/pretrain_cpt_decay/latest/model/
+ output_dir: debug_train/run_003
+ seed: 42
+ learning_rate: 1.0e-05
+ weight_decay: 0.01
+ warmup_steps: 50
+ max_train_steps: 500
+ gradient_accumulation_steps: 2
+ grad_clip_norm: 1
+ save_interval: 500
+ max_checkpoints_to_keep: 40
+ log_interval: 10
+ eval_interval: 100
+ max_eval_batches: null
+ run_eval_on_start: false
+loss:
+ ce_weight: 1.0
+ fm_weight: 1.0
+ eos_weight: 1.0
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..943217629151ffbaabaaeccdf124eb0858cc2c66
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,19 @@
+spaces>=0.40.1
+torch>=2.8.0
+torchaudio>=2.8.0
+transformers>=4.57.0
+huggingface-hub>=0.36.0
+gradio>=6.16.0
+loguru>=0.7.3
+langcodes[data]>=3.5.0
+einops>=0.8.1
+librosa>=0.11.0
+soundfile>=0.13.1
+numpy>=2.2.6
+pydantic>=2.12.5,<3
+PyYAML>=6.0.3
+safetensors>=0.8.0rc0
+torchdiffeq>=0.2.5
+tqdm>=4.67.1
+lingua-language-detector>=2.1.1
+WeTextProcessing>=1.0.4
diff --git a/src/dots_tts/__init__.py b/src/dots_tts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b25576d5f2ecf89d144162771ebb197943801da
--- /dev/null
+++ b/src/dots_tts/__init__.py
@@ -0,0 +1 @@
+"""dots.tts package."""
diff --git a/src/dots_tts/cli.py b/src/dots_tts/cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c6e6bc5734ef1cc3eb903a08f50bfba0c66f2e3
--- /dev/null
+++ b/src/dots_tts/cli.py
@@ -0,0 +1,152 @@
+from __future__ import annotations
+
+import argparse
+from pathlib import Path
+
+
+def parse_args(argv=None):
+ parser = argparse.ArgumentParser(description="dots.tts inference CLI.")
+ template_choices = ("tts", "instruction_tts", "text_to_audio", "tts_interleave")
+ parser.add_argument(
+ "--model-name-or-path",
+ required=True,
+ help="Local pretrained directory or Hugging Face repo id",
+ )
+ parser.add_argument(
+ "--revision", default=None, help="Optional Hugging Face revision"
+ )
+ parser.add_argument(
+ "--cache-dir", default=None, help="Optional Hugging Face cache dir"
+ )
+ parser.add_argument("--text", type=str, required=True, help="Input text")
+ parser.add_argument("--output", default="output.wav", help="Output wav file path")
+ parser.add_argument(
+ "--precision", type=str, default="bfloat16", help="Inference precision"
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="Random seed for inference.",
+ )
+ parser.add_argument(
+ "--prompt-audio", type=str, default=None, help="Path to prompt audio"
+ )
+ parser.add_argument(
+ "--prompt-text", type=str, default=None, help="Transcript of prompt audio"
+ )
+ parser.add_argument(
+ "--language",
+ type=str,
+ default=None,
+ help="Language tag mode. Default: none. Supported values: none, auto_detect, or a language code/name such as EN/en/english/chinese.",
+ )
+ parser.add_argument(
+ "--template-name",
+ choices=template_choices,
+ default=None,
+ help="Named template preset for generation.",
+ )
+ parser.add_argument(
+ "--ode-method", type=str, default="euler", help="ODE solver method"
+ )
+ parser.add_argument(
+ "--num-steps", type=int, default=10, help="Diffusion sampling steps"
+ )
+ parser.add_argument(
+ "--guidance-scale",
+ type=float,
+ default=1.2,
+ help="Classifier-free guidance scale",
+ )
+ parser.add_argument(
+ "--speaker-scale",
+ type=float,
+ default=1.5,
+ help="Scale applied to the reference speaker embedding",
+ )
+ parser.add_argument(
+ "--max-generate-length",
+ type=int,
+ default=500,
+ help="Maximum total audio patch count (prompt + generated)",
+ )
+ parser.add_argument(
+ "--normalize-text",
+ action="store_true",
+ help="Whether to normalize text before inference",
+ )
+ parser.add_argument(
+ "--profile-inference",
+ action="store_true",
+ help="Collect per-module inference timing statistics",
+ )
+ return parser.parse_args(argv)
+
+
+def main(argv=None):
+ args = parse_args(argv)
+ import soundfile as sf
+ from loguru import logger
+
+ from dots_tts.runtime import DotsTtsRuntime
+ from dots_tts.utils.logging import configure_logging
+ from dots_tts.utils.util import seed_everything
+
+ configure_logging()
+ seed_everything(args.seed)
+ output_path = Path(args.output)
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+
+ logger.info(
+ "CLI command started: model={} output={} seed={}",
+ args.model_name_or_path,
+ output_path,
+ args.seed,
+ )
+
+ try:
+ runtime = DotsTtsRuntime.from_pretrained(
+ args.model_name_or_path,
+ revision=args.revision,
+ cache_dir=args.cache_dir,
+ precision=args.precision,
+ max_generate_length=args.max_generate_length,
+ )
+ result = runtime.generate(
+ text=args.text,
+ prompt_audio_path=args.prompt_audio,
+ prompt_text=args.prompt_text,
+ language=args.language,
+ template_name=args.template_name,
+ ode_method=args.ode_method,
+ num_steps=args.num_steps,
+ guidance_scale=args.guidance_scale,
+ speaker_scale=args.speaker_scale,
+ normalize_text=args.normalize_text,
+ profile_inference=args.profile_inference,
+ )
+ sf.write(
+ output_path,
+ result["audio"].float().cpu().squeeze().numpy(),
+ result["sample_rate"],
+ )
+ except Exception:
+ logger.exception(
+ "CLI inference failed: model={} output={}",
+ args.model_name_or_path,
+ output_path,
+ )
+ raise
+
+ logger.info(
+ "CLI output written: request_id={} output={} sample_rate={} samples={}",
+ result["fid"],
+ output_path,
+ result["sample_rate"],
+ int(result["audio"].shape[-1]),
+ )
+
+
+# Script entry point. main() has no return statement, so SystemExit receives
+# None and a successful run exits with status 0.
+if __name__ == "__main__":
+ raise SystemExit(main())
diff --git a/src/dots_tts/config/__init__.py b/src/dots_tts/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..56096f22e33373e8aa0dd3cbb6a5e2f99bdbca1b
--- /dev/null
+++ b/src/dots_tts/config/__init__.py
@@ -0,0 +1 @@
+"""Configuration package."""
diff --git a/src/dots_tts/config/app.py b/src/dots_tts/config/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdf0ce2d3794feedf1bf1d36bcc05551df93907e
--- /dev/null
+++ b/src/dots_tts/config/app.py
@@ -0,0 +1,32 @@
+from __future__ import annotations
+
+from pathlib import Path
+
+import yaml
+
+from dots_tts.config.base import StrictConfigBase
+from dots_tts.config.data import DataConfig
+from dots_tts.config.train import TrainConfig
+from dots_tts.models.dots_tts.config import LossConfig
+
+DEFAULT_CONFIG_PATH = "configs/dots_tts.yaml"
+
+
+class AppConfig(StrictConfigBase):
+ train_data: DataConfig
+ val_data: DataConfig | None = None
+ loss: LossConfig
+ train: TrainConfig
+
+ @classmethod
+ def from_yaml(cls, config_path: str = DEFAULT_CONFIG_PATH) -> AppConfig:
+ with Path(config_path).open(encoding="utf-8") as fin:
+ raw_config = yaml.safe_load(fin)
+ return cls.model_validate(raw_config)
+
+
+def load_config(config_path: str = DEFAULT_CONFIG_PATH) -> AppConfig:
+ return AppConfig.from_yaml(config_path)
+
+
+__all__ = ["AppConfig", "DEFAULT_CONFIG_PATH", "load_config"]
diff --git a/src/dots_tts/config/base.py b/src/dots_tts/config/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..b81ff06c3a149a616bc7a5be4575fbd3c75d19a5
--- /dev/null
+++ b/src/dots_tts/config/base.py
@@ -0,0 +1,64 @@
+from __future__ import annotations
+
+from typing import Any
+
+from pydantic import BaseModel, ConfigDict
+
+
+class ConfigBase(BaseModel):
+ model_config = ConfigDict(
+ extra="allow",
+ validate_assignment=True,
+ arbitrary_types_allowed=True,
+ )
+
+ def get(self, key: str, default=None):
+ value = getattr(self, key, default)
+ if value is default:
+ return value
+
+ fields_set = self.model_fields_set
+ if value is None and key not in fields_set:
+ return default
+ return value
+
+ def to_dict(self) -> dict[str, Any]:
+ return self.model_dump(exclude_none=True)
+
+ @classmethod
+ def _declared_field_names(cls) -> list[str]:
+ return [name for name in cls.model_fields if name != "model_config"]
+
+ @classmethod
+ def _serialize_declared_value(cls, value):
+ if isinstance(value, ConfigBase):
+ return value.to_declared_dict()
+ if isinstance(value, list):
+ return [cls._serialize_declared_value(item) for item in value]
+ if isinstance(value, tuple):
+ return [cls._serialize_declared_value(item) for item in value]
+ if isinstance(value, dict):
+ return {
+ key: cls._serialize_declared_value(item) for key, item in value.items()
+ }
+ return value
+
+ def to_declared_dict(self) -> dict[str, Any]:
+ data = {}
+ for name in self._declared_field_names():
+ value = getattr(self, name, None)
+ if value is None:
+ continue
+ data[name] = self._serialize_declared_value(value)
+ return data
+
+
+class StrictConfigBase(ConfigBase):
+ """ConfigBase variant whose model config forbids undeclared keys."""
+
+ # Same settings as ConfigBase except extra="forbid" instead of "allow".
+ model_config = ConfigDict(
+ extra="forbid",
+ validate_assignment=True,
+ arbitrary_types_allowed=True,
+ )
+
+
+__all__ = ["ConfigBase", "StrictConfigBase"]
diff --git a/src/dots_tts/config/data.py b/src/dots_tts/config/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..6035115b076bef81dc0bab131a43224f0060b141
--- /dev/null
+++ b/src/dots_tts/config/data.py
@@ -0,0 +1,63 @@
+from __future__ import annotations
+
+from typing import Any, Literal
+
+from pydantic import Field, model_validator
+
+from dots_tts.config.base import StrictConfigBase
+
+DEFAULT_SOURCE_ADAPTER_CLASS_NAME = "JsonlManifestSourceAdapter"
+
+
+class SourceAdapterConfig(StrictConfigBase):
+ """Selects the source adapter class for one data source and its arguments."""
+
+ # Only the names listed in this Literal pass validation; a new adapter
+ # class must be added here before a config can reference it.
+ class_name: Literal["JsonlManifestSourceAdapter"] = (
+ DEFAULT_SOURCE_ADAPTER_CLASS_NAME
+ )
+ # Free-form keyword arguments for the adapter; defaults to an empty dict.
+ params: dict[str, Any] = Field(default_factory=dict)
+
+
+class DataSourceConfig(StrictConfigBase):
+ """Configuration of a single named data source."""
+
+ # Required; DataConfig rejects configs in which two sources share a name.
+ name: str
+ # Must be strictly positive.
+ weight: float = Field(default=1.0, gt=0.0)
+ # Only the names listed in this Literal pass validation.
+ pipeline: Literal["basic", "interleave"] = "basic"
+ # Falls back to a default SourceAdapterConfig when omitted.
+ adapter: SourceAdapterConfig = Field(default_factory=SourceAdapterConfig)
+
+
+class DataConfig(StrictConfigBase):
+ sources: list[DataSourceConfig]
+ train_audio_sample_rate: int = Field(ge=1)
+ audio_samples_per_llm_token: int = Field(ge=1)
+ num_tokens_per_epoch: int | None = Field(
+ default=None,
+ ge=1,
+ description="Global token budget across all ranks for one training epoch.",
+ )
+ num_workers: int = Field(default=0, ge=0)
+ pin_memory: bool = False
+ prefetch_factor: int = Field(
+ default=2,
+ ge=1,
+ description="Samples prefetched by each DataLoader worker.",
+ )
+ max_audio_seconds_in_batch: float = Field(gt=0.0)
+ max_text_tokens_in_batch: int = Field(ge=1)
+ max_samples_per_batch: int | None = Field(default=None, ge=1)
+ bucketing_pool_size: int = Field(default=64, ge=1)
+
+ @model_validator(mode="after")
+ def _validate_unique_source_names(self) -> "DataConfig":
+ counts: dict[str, int] = {}
+ for source in self.sources:
+ counts[source.name] = counts.get(source.name, 0) + 1
+ duplicated = [name for name, count in counts.items() if count > 1]
+ if duplicated:
+ raise ValueError(f"Source names must be unique: {duplicated}")
+ return self
+
+
+__all__ = [
+ "DEFAULT_SOURCE_ADAPTER_CLASS_NAME",
+ "DataConfig",
+ "DataSourceConfig",
+ "SourceAdapterConfig",
+]
diff --git a/src/dots_tts/config/train.py b/src/dots_tts/config/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..b975ad7bebeebe9e408fad37981bab691407844d
--- /dev/null
+++ b/src/dots_tts/config/train.py
@@ -0,0 +1,28 @@
+from __future__ import annotations
+
+from pydantic import Field
+
+from dots_tts.config.base import StrictConfigBase
+
+
+class TrainConfig(StrictConfigBase):
+ """Training-run settings; undeclared keys are rejected (StrictConfigBase)."""
+
+ # Fields without a default are required: pretrained_model_path,
+ # output_dir, learning_rate and max_train_steps.
+ pretrained_model_path: str
+ output_dir: str
+ seed: int = 42
+ learning_rate: float
+ cfg_droprate: float = 0.0
+ xvec_drop_rate: float = 0.5
+ weight_decay: float = 0.01
+ warmup_steps: int = 0
+ max_train_steps: int
+ # The interval/accumulation fields below are validated to be >= 1.
+ gradient_accumulation_steps: int = Field(default=1, ge=1)
+ grad_clip_norm: float = 1.0
+ save_interval: int = Field(default=1000, ge=1)
+ max_checkpoints_to_keep: int = 10
+ log_interval: int = Field(default=10, ge=1)
+ # None is accepted; when a value is given it must be >= 1.
+ eval_interval: int | None = Field(default=None, ge=1)
+ max_eval_batches: int | None = None
+ run_eval_on_start: bool = False
+
+
+__all__ = ["TrainConfig"]
diff --git a/src/dots_tts/data/EXTENSION.md b/src/dots_tts/data/EXTENSION.md
new file mode 100644
index 0000000000000000000000000000000000000000..660815f8e2c0379aa577496c5a67bb2b251de7a6
--- /dev/null
+++ b/src/dots_tts/data/EXTENSION.md
@@ -0,0 +1,124 @@
+# Data Source Extension Guide
+
+This document answers exactly one question: how to plug a new training data source into the current `dots_tts` data pipeline.
+
+If you only need to swap in a different JSONL manifest, no code changes are required. To support a new raw data format, you usually only need to add:
+
+- one **source adapter**
+- optionally one **sample pipeline**
+
+## Data flow
+
+1. An **adapter** reads from the raw data source and yields raw samples.
+2. A **pipeline** turns each raw sample into a training sample (1:1).
+3. A **multi-source wrapper** handles mixing across sources and resume state.
+4. `StreamingSampleDataset` / `DataLoader` pulls samples.
+5. `OnlineBatcher` assembles batches and `PadCollator` performs padding.
+
+## What an adapter must implement
+
+Subclass `BaseSourceAdapter`:
+
+```python
+class BaseSourceAdapter(ABC):
+ @abstractmethod
+ def initial_state(self) -> dict[str, Any]:
+ ...
+
+ @abstractmethod
+ def iter_samples(
+ self,
+ context: SourceContext,
+ *,
+ state: dict[str, Any] | None = None,
+ ) -> Iterable[dict[str, Any]]:
+ ...
+
+ @abstractmethod
+ def is_cycle_start_state(self, state: dict[str, Any] | None) -> bool:
+ ...
+
+ # Optional — only required when used under WeightedMultiSourceAdapter,
+ # which cycles each finite child source independently. The default
+ # implementation raises, so skip it only if your adapter is never re-cycled.
+ def advance_cycle(self, state: dict[str, Any] | None) -> dict[str, Any]:
+ ...
+```
+
+Each emitted sample **must** carry these fields:
+
+- `fid`
+- `text`
+- `audio`
+- `_adapter_state`
+
+Key constraints:
+
+- `_adapter_state` must describe **where to resume next**, not the position of the current item.
+- The state must be plain Python data — serializable and recoverable after a restart.
+- If your source needs to be split across workers, use `context.global_worker_id` and `context.global_worker_count` (or subclass `ShardableSourceAdapter` and use its `is_assigned_index` / `shard_items` helpers).
+- If the source will participate in weighted cyclic sampling, you must implement `advance_cycle` and make `is_cycle_start_state` correct — otherwise `WeightedMultiSourceAdapter` cannot detect an empty cycle and will raise.
+
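+After implementing the adapter, register the class in `dots_tts/data/builders.py::_SOURCE_ADAPTER_CLASSES` and add its name to the `class_name` `Literal` on `SourceAdapterConfig` in `dots_tts/config/data.py`. Config validation rejects any `class_name` that is not listed there, so both steps are needed before the YAML config can resolve it by `class_name`.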
+
+## What a pipeline must implement
+
+Pipelines must subclass `BaseSamplePipeline` and perform a strict **1:1** sample transform.
+
+Minimum implementation:
+
+```python
+class MyPipeline(BaseSamplePipeline):
+ def process_sample(self, sample: dict) -> dict:
+ sample["text"] = str(sample["text"]).strip()
+ return sample
+```
+
+Do **not**:
+
+- filter samples out
+- expand a single sample into multiple samples
+- assemble batches inside the pipeline
+
+`BaseSamplePipeline.__call__` automatically merges the original raw sample (including `_adapter_state` and any extra fields the adapter attached) with whatever your `process_sample` returns. You do not need to copy these fields manually — just return the fields you produced or want to overwrite.
+
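+To wire a new pipeline into config, extend `dots_tts/data/builders.py::_build_source_pipeline` and add the new name to the `pipeline` `Literal` on `DataSourceConfig` in `dots_tts/config/data.py` (currently `basic` and `interleave`), so it can be selected by name in YAML.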
+
+## How multi-source wrappers affect you
+
+There are two wrappers in the current codebase:
+
+- `SequentialMultiSourceAdapter` — used for validation. Reads sources in the configured order, exhaustively, once.
+- `WeightedMultiSourceAdapter` — used for training. Draws sources by weight, cycles each child source independently when exhausted.
+
+Both wrappers **replace** the `_adapter_state` produced by your child adapter with their own resume state before yielding to the dataset. Even so, the child adapter must still emit its own `_adapter_state` — the wrapper reads it to track how far each sub-source has been read.
+
+## Config
+
+Each source is configured independently:
+
+```yaml
+train_data:
+ sources:
+ - name: train_a
+ weight: 1.0
+ pipeline: basic
+ adapter:
+ class_name: JsonlManifestSourceAdapter
+ params:
+ manifest_path: train_a.jsonl
+ - name: train_b
+ weight: 2.0
+ pipeline: interleave
+ adapter:
+ class_name: JsonlManifestSourceAdapter
+ params:
+ manifest_path: train_b.jsonl
+```
+
+Constraints:
+
+- `sources[].name` must be unique within the same `train_data` / `val_data` block (it is used as a dict key for resume state).
+- `sources[].pipeline` is a per-source setting, not shared across the dataset.
+- All sources must ultimately produce the same training-sample structure, since they feed into the same batcher and collator.
+- `class_name` must match a key registered in `_SOURCE_ADAPTER_CLASSES`; `params` is forwarded verbatim as kwargs to the adapter constructor.
diff --git a/src/dots_tts/data/__init__.py b/src/dots_tts/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..68c760d27de425293851d8186f8ad18d9987b4e2
--- /dev/null
+++ b/src/dots_tts/data/__init__.py
@@ -0,0 +1 @@
+"""Data package."""
diff --git a/src/dots_tts/data/batchers.py b/src/dots_tts/data/batchers.py
new file mode 100644
index 0000000000000000000000000000000000000000..1503064c0e3f8fa095ddd43461efc5ca79b13f9a
--- /dev/null
+++ b/src/dots_tts/data/batchers.py
@@ -0,0 +1,188 @@
+from __future__ import annotations
+
+import warnings
+from collections.abc import Iterable, Iterator
+from dataclasses import dataclass
+
+from dots_tts.utils.profiling import ensure_data_profiler
+
+
+@dataclass(slots=True)
+class BatchDecision:
+ """Result of one batching step produced by OnlineBatcher."""
+
+ # Samples skipped because they exceed a batching limit on their own.
+ dropped_samples: list[dict]
+ # Samples chosen for the next batch; empty when the step dropped a sample.
+ batch_samples: list[dict]
+
+
+@dataclass(slots=True)
+class _PoolSample:
+ """A pooled sample together with the sizes used for batching decisions."""
+
+ # The sample dict exactly as received from the input iterator.
+ sample: dict
+ # Read from sample["num_audio_tokens"] (0 when the key is absent).
+ num_audio_tokens: int
+ # Read from sample["input_ids_length"] (0 when the key is absent).
+ num_text_tokens: int
+ # Value of the batcher's decision counter when the sample entered the pool.
+ arrival_step: int
+
+
+class OnlineBatcher:
+ """Greedy batcher that groups streamed samples under padded-size budgets.
+
+ Samples are buffered in a pool of at most ``sample_pool_size`` entries.
+ Each step sorts the pool largest-first, picks an anchor sample, and adds
+ further samples while ``longest length * batch size`` stays within both
+ the audio-token and the text-token budget.
+ """
+
+ def __init__(
+ self,
+ *,
+ max_audio_tokens_in_batch: int,
+ max_text_tokens_in_batch: int,
+ max_batch_size: int | None,
+ sample_pool_size: int,
+ profiler=None,
+ ):
+ # Budgets and pool size are clamped to at least 1; max_batch_size is
+ # stored as given (None means no cap on the number of samples).
+ self.max_audio_tokens_in_batch = max(1, int(max_audio_tokens_in_batch))
+ self.max_text_tokens_in_batch = max(1, int(max_text_tokens_in_batch))
+ self.max_batch_size = max_batch_size
+ self.sample_pool_size = max(1, int(sample_pool_size))
+ self.profiler = ensure_data_profiler(profiler)
+
+ @staticmethod
+ def _sort_pool(pool: list[_PoolSample]) -> None:
+ """Sort *pool* in place, largest first.
+
+ Order: more audio tokens, then more text tokens, then earlier arrival
+ (the arrival step is negated because the sort is reversed).
+ """
+ pool.sort(
+ key=lambda item: (
+ item.num_audio_tokens,
+ item.num_text_tokens,
+ -item.arrival_step,
+ ),
+ reverse=True,
+ )
+
+ def _choose_anchor_index(
+ self,
+ pool: list[_PoolSample],
+ *,
+ decision_step: int,
+ ) -> int:
+ """Return the pool index of the sample the next batch is built around.
+
+ A sample that has waited at least ``sample_pool_size`` decision steps
+ takes priority (the one with the smallest arrival step wins);
+ otherwise index 0 is returned.
+ """
+ oldest_waiting_index = -1
+ oldest_waiting_step = decision_step
+
+ for index, item in enumerate(pool):
+ waited_steps = decision_step - item.arrival_step
+ if waited_steps < self.sample_pool_size:
+ continue
+ # "<=" means that among equally old samples the later index wins.
+ if item.arrival_step <= oldest_waiting_step:
+ oldest_waiting_index = index
+ oldest_waiting_step = item.arrival_step
+
+ return 0 if oldest_waiting_index < 0 else oldest_waiting_index
+
+ def _build_next_decision(
+ self,
+ pool: list[_PoolSample],
+ *,
+ decision_step: int,
+ ) -> BatchDecision:
+ """Remove either one oversized sample or one batch from *pool*.
+
+ Returns a BatchDecision whose ``dropped_samples`` holds the anchor when
+ it cannot fit the limits alone, or whose ``batch_samples`` holds the
+ anchor plus every other pooled sample that fits the budgets.
+ """
+ dropped_samples: list[dict] = []
+ batch_samples: list[dict] = []
+ selected_indices: list[int] = []
+ anchor_index = self._choose_anchor_index(pool, decision_step=decision_step)
+ anchor = pool[anchor_index]
+
+ # An anchor that breaks a limit by itself can never be batched: drop
+ # it with a warning instead of stalling the stream.
+ exceed_audio_budget = anchor.num_audio_tokens > self.max_audio_tokens_in_batch
+ exceed_text_budget = anchor.num_text_tokens > self.max_text_tokens_in_batch
+ exceed_batch_size = self.max_batch_size is not None and self.max_batch_size < 1
+ if exceed_audio_budget or exceed_text_budget or exceed_batch_size:
+ skipped = pool.pop(anchor_index).sample
+ dropped_samples.append(skipped)
+ warnings.warn(
+ "Skipping sample that exceeds batching limits on its own: "
+ f"fid={skipped.get('fid')!r}, "
+ f"num_audio_tokens={anchor.num_audio_tokens}, "
+ f"input_ids_length={anchor.num_text_tokens}, "
+ f"max_audio_tokens_in_batch={self.max_audio_tokens_in_batch}, "
+ f"max_text_tokens_in_batch={self.max_text_tokens_in_batch}, "
+ f"max_batch_size={self.max_batch_size}",
+ RuntimeWarning,
+ stacklevel=2,
+ )
+ return BatchDecision(
+ dropped_samples=dropped_samples,
+ batch_samples=batch_samples,
+ )
+
+ longest_audio_tokens = anchor.num_audio_tokens
+ longest_text_tokens = anchor.num_text_tokens
+ batch_samples.append(anchor.sample)
+ selected_indices.append(anchor_index)
+
+ # Greedy fill in pool order; a candidate that does not fit is skipped
+ # and later (smaller) candidates are still tried.
+ for index, item in enumerate(pool):
+ if index == anchor_index:
+ continue
+ if (
+ self.max_batch_size is not None
+ and len(batch_samples) >= self.max_batch_size
+ ):
+ break
+
+ # Cost model: every sample is padded to the longest in the batch,
+ # so the padded size is longest length times batch size.
+ proposed_batch_size = len(batch_samples) + 1
+ proposed_longest_audio_tokens = max(
+ longest_audio_tokens,
+ item.num_audio_tokens,
+ )
+ proposed_longest_text_tokens = max(
+ longest_text_tokens,
+ item.num_text_tokens,
+ )
+ if (
+ proposed_longest_audio_tokens * proposed_batch_size
+ > self.max_audio_tokens_in_batch
+ ):
+ continue
+ if (
+ proposed_longest_text_tokens * proposed_batch_size
+ > self.max_text_tokens_in_batch
+ ):
+ continue
+
+ batch_samples.append(item.sample)
+ selected_indices.append(index)
+ longest_audio_tokens = proposed_longest_audio_tokens
+ longest_text_tokens = proposed_longest_text_tokens
+
+ # Pop from the highest index down so the remaining indices stay valid.
+ for index in sorted(set(selected_indices), reverse=True):
+ pool.pop(index)
+
+ return BatchDecision(
+ dropped_samples=dropped_samples,
+ batch_samples=batch_samples,
+ )
+
+ def build_decisions(self, sample_iter: Iterable[dict]) -> Iterator[BatchDecision]:
+ """Yield BatchDecision objects until *sample_iter* and the pool are empty.
+
+ Raises:
+ RuntimeError: If a step neither drops nor batches any sample.
+ """
+ pool: list[_PoolSample] = []
+ source_exhausted = False
+ decision_step = 0
+ iterator = iter(sample_iter)
+
+ while not source_exhausted or pool:
+ # Refill the pool up to its capacity before each decision.
+ while not source_exhausted and len(pool) < self.sample_pool_size:
+ try:
+ sample = next(iterator)
+ except StopIteration:
+ source_exhausted = True
+ break
+ pool.append(
+ _PoolSample(
+ sample=sample,
+ num_audio_tokens=int(sample.get("num_audio_tokens", 0)),
+ num_text_tokens=int(sample.get("input_ids_length", 0)),
+ arrival_step=decision_step,
+ )
+ )
+
+ if not pool:
+ break
+
+ profiler = self.profiler
+ with profiler.measure("main.sort_pool", count=len(pool)):
+ self._sort_pool(pool)
+ with profiler.measure("main.build_batch_decision"):
+ decision = self._build_next_decision(
+ pool,
+ decision_step=decision_step,
+ )
+ if decision.dropped_samples or decision.batch_samples:
+ decision_step += 1
+ yield decision
+ continue
+ # Safety net: an empty decision would otherwise loop forever.
+ raise RuntimeError("OnlineBatcher failed to make progress on a non-empty pool.")
diff --git a/src/dots_tts/data/builders.py b/src/dots_tts/data/builders.py
new file mode 100644
index 0000000000000000000000000000000000000000..a750951c78e47be5d28dea0baf7141f213f6a6aa
--- /dev/null
+++ b/src/dots_tts/data/builders.py
@@ -0,0 +1,194 @@
+from __future__ import annotations
+
+from torch.utils.data import DataLoader
+
+from dots_tts.config.data import DataConfig
+from dots_tts.data.pipelines.base import BaseSamplePipeline
+from dots_tts.data.pipelines.tts_pipeline import BasicTtsPipeline, InterleaveTtsPipeline
+from dots_tts.data.source_adapters.jsonl_manifest_adapter import (
+ JsonlManifestSourceAdapter,
+)
+from dots_tts.data.source_adapters.multi_source_adapter import (
+ SequentialMultiSourceAdapter,
+ SourceSpec,
+ WeightedMultiSourceAdapter,
+)
+from dots_tts.data.streaming import (
+ BatchedDataStream,
+ StreamingSampleDataset,
+ identity_collate,
+)
+
+_SOURCE_ADAPTER_CLASSES = {
+ "JsonlManifestSourceAdapter": JsonlManifestSourceAdapter,
+}
+
+
+def _build_source_pipeline(
+ tokenizer, data_cfg, pipeline_name: str, *, profiler=None
+) -> BaseSamplePipeline:
+ if pipeline_name == "basic":
+ return BasicTtsPipeline(tokenizer, data_cfg, profiler=profiler)
+ if pipeline_name == "interleave":
+ return InterleaveTtsPipeline(tokenizer, data_cfg, profiler=profiler)
+ raise ValueError(f"Unsupported data pipeline: {pipeline_name!r}")
+
+
+def _build_source_specs(data_cfg, tokenizer, *, profiler=None) -> list[SourceSpec]:
+ specs = []
+ for source_cfg in data_cfg.sources:
+ adapter_cls = _SOURCE_ADAPTER_CLASSES[source_cfg.adapter.class_name]
+ adapter = adapter_cls(**source_cfg.adapter.params)
+ specs.append(
+ SourceSpec(
+ name=source_cfg.name,
+ weight=float(source_cfg.weight),
+ adapter=adapter,
+ pipeline=_build_source_pipeline(
+ tokenizer, data_cfg, source_cfg.pipeline, profiler=profiler
+ ),
+ )
+ )
+ return specs
+
+
+def _resolve_rank_info(accelerator=None) -> tuple[int, int]:
+ rank = (
+ int(getattr(accelerator, "process_index", 0)) if accelerator is not None else 0
+ )
+ world_size = (
+ int(getattr(accelerator, "num_processes", 1)) if accelerator is not None else 1
+ )
+ return rank, world_size
+
+
+def _local_num_tokens_per_epoch(
+ global_num_tokens_per_epoch: int, *, rank: int, world_size: int
+) -> int:
+ if world_size <= 0:
+ raise ValueError(f"world_size must be positive, but got {world_size}.")
+ if rank < 0 or rank >= world_size:
+ raise ValueError(
+ f"rank must be in [0, {world_size}), but got rank={rank}."
+ )
+
+ base, remainder = divmod(int(global_num_tokens_per_epoch), int(world_size))
+ return base + int(rank < remainder)
+
+
+def _build_dataset(
+ data_cfg: DataConfig,
+ *,
+ tokenizer,
+ seed: int,
+ accelerator=None,
+ sequential: bool,
+ profiler=None,
+):
+ rank, world_size = _resolve_rank_info(accelerator)
+ source_cls = SequentialMultiSourceAdapter if sequential else WeightedMultiSourceAdapter
+ source = source_cls(
+ sources=_build_source_specs(data_cfg, tokenizer, profiler=profiler)
+ )
+ return StreamingSampleDataset(
+ source=source,
+ rank=rank,
+ world_size=world_size,
+ seed=int(seed),
+ )
+
+
+def build_training_dataset(
+ data_cfg: DataConfig,
+ tokenizer,
+ *,
+ seed: int,
+ accelerator=None,
+ profiler=None,
+):
+ if data_cfg.num_tokens_per_epoch is None:
+ raise ValueError("Training data requires num_tokens_per_epoch.")
+ return _build_dataset(
+ data_cfg,
+ tokenizer=tokenizer,
+ seed=seed,
+ accelerator=accelerator,
+ sequential=False,
+ profiler=profiler,
+ )
+
+
+def build_validation_dataset(
+ data_cfg: DataConfig,
+ tokenizer,
+ *,
+ seed: int,
+ accelerator=None,
+ profiler=None,
+):
+ return _build_dataset(
+ data_cfg,
+ tokenizer=tokenizer,
+ seed=seed,
+ accelerator=accelerator,
+ sequential=True,
+ profiler=profiler,
+ )
+
+
+def _build_sample_loader(dataset, data_cfg: DataConfig) -> DataLoader:
+ loader_kwargs = {
+ "dataset": dataset,
+ "batch_size": None,
+ "collate_fn": identity_collate,
+ "num_workers": data_cfg.num_workers,
+ "pin_memory": data_cfg.pin_memory,
+ "persistent_workers": data_cfg.num_workers > 0,
+ }
+ if data_cfg.num_workers > 0:
+ loader_kwargs["prefetch_factor"] = int(data_cfg.prefetch_factor)
+ sample_loader = DataLoader(**loader_kwargs)
+ return sample_loader
+
+
+def build_training_dataloader(
+ dataset, data_cfg: DataConfig, tokenizer, *, profiler=None
+):
+ local_num_tokens_per_epoch = _local_num_tokens_per_epoch(
+ int(data_cfg.num_tokens_per_epoch),
+ rank=int(dataset.rank),
+ world_size=int(dataset.world_size),
+ )
+ sample_loader = _build_sample_loader(dataset, data_cfg)
+ batched_stream = BatchedDataStream(
+ sample_dataset=dataset,
+ data_cfg=data_cfg,
+ tokenizer=tokenizer,
+ num_tokens_per_epoch=local_num_tokens_per_epoch,
+ profiler=profiler,
+ )
+ batched_stream.attach_loader(sample_loader)
+ return batched_stream
+
+
+def build_validation_dataloader(
+ dataset, data_cfg: DataConfig, tokenizer, *, profiler=None
+):
+ sample_loader = _build_sample_loader(dataset, data_cfg)
+ batched_stream = BatchedDataStream(
+ sample_dataset=dataset,
+ data_cfg=data_cfg,
+ tokenizer=tokenizer,
+ num_tokens_per_epoch=None,
+ profiler=profiler,
+ )
+ batched_stream.attach_loader(sample_loader)
+ return batched_stream
+
+
+__all__ = [
+ "build_training_dataloader",
+ "build_training_dataset",
+ "build_validation_dataloader",
+ "build_validation_dataset",
+]
diff --git a/src/dots_tts/data/collator.py b/src/dots_tts/data/collator.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ec98fa95b9d41e5b4faf4f282ccafeb0221ddd
--- /dev/null
+++ b/src/dots_tts/data/collator.py
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+from typing import Any
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+
+class PadCollator:
+ def __init__(self, tokenizer):
+ self.tokenizer = tokenizer
+ self.pad_token_id = tokenizer.pad_token_id
+ if self.pad_token_id is None:
+ self.pad_token_id = tokenizer.eos_token_id or 0
+
+ def __call__(self, samples: list[dict[str, Any]]) -> dict[str, Any]:
+ if not samples:
+ raise ValueError("PadCollator received an empty sample list.")
+
+ order = sorted(
+ range(len(samples)),
+ key=lambda idx: samples[idx]["sample_length"],
+ reverse=True,
+ )
+ ordered = [samples[idx] for idx in order]
+
+ input_ids = [
+ torch.tensor(sample["input_ids"], dtype=torch.long) for sample in ordered
+ ]
+ labels = [
+ torch.tensor(sample["labels"], dtype=torch.long) for sample in ordered
+ ]
+ loss_masks = [
+ torch.tensor(sample["loss_mask"], dtype=torch.float32) for sample in ordered
+ ]
+ waveforms = [sample["sample"].squeeze(0) for sample in ordered]
+ fbank = [sample["fbank"] for sample in ordered]
+
+ return {
+ "fids": [sample["fid"] for sample in ordered],
+ "source_names": [sample.get("source_name") for sample in ordered],
+ "input_ids": pad_sequence(
+ input_ids,
+ batch_first=True,
+ padding_value=self.pad_token_id,
+ ),
+ "input_ids_lengths": torch.tensor(
+ [len(sample["input_ids"]) for sample in ordered],
+ dtype=torch.long,
+ ),
+ "labels": pad_sequence(
+ labels,
+ batch_first=True,
+ padding_value=self.pad_token_id,
+ ),
+ "loss_mask": pad_sequence(
+ loss_masks,
+ batch_first=True,
+ padding_value=0.0,
+ ),
+ "sample": pad_sequence(
+ waveforms,
+ batch_first=True,
+ padding_value=0.0,
+ ).unsqueeze(1),
+ "sample_lengths": torch.tensor(
+ [sample["sample_length"] for sample in ordered],
+ dtype=torch.long,
+ ),
+ "num_text_tokens": torch.tensor(
+ [sample["num_text_tokens"] for sample in ordered],
+ dtype=torch.long,
+ ),
+ "num_audio_tokens": torch.tensor(
+ [sample["num_audio_tokens"] for sample in ordered],
+ dtype=torch.long,
+ ),
+ "fbank": pad_sequence(
+ fbank,
+ batch_first=True,
+ padding_value=0.0,
+ ),
+ "fbank_lengths": torch.tensor(
+ [sample["fbank_length"] for sample in ordered],
+ dtype=torch.long,
+ ),
+ }
diff --git a/src/dots_tts/data/pipelines/__init__.py b/src/dots_tts/data/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2888992530b44dc91d9d3f3b81c927f82ea256f1
--- /dev/null
+++ b/src/dots_tts/data/pipelines/__init__.py
@@ -0,0 +1 @@
+"""Data pipelines package."""
diff --git a/src/dots_tts/data/pipelines/base.py b/src/dots_tts/data/pipelines/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..c33c080bfa986a0078f80c0f67373e486081b484
--- /dev/null
+++ b/src/dots_tts/data/pipelines/base.py
@@ -0,0 +1,32 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from collections.abc import Iterable, Iterator
+
+
+class BaseSamplePipeline(ABC):
+ """1:1 sample pipeline that preserves adapter resume metadata."""
+
+ @staticmethod
+ def _validate_input_sample(sample: dict) -> None:
+ if "_adapter_state" not in sample:
+ raise RuntimeError(
+ "Source sample is missing required '_adapter_state' for resume."
+ )
+
+ @abstractmethod
+ def process_sample(self, sample: dict) -> dict:
+ """Transform one raw sample into one processed sample."""
+
+ def __call__(self, samples: Iterable[dict]) -> Iterator[dict]:
+ for raw_sample in samples:
+ self._validate_input_sample(raw_sample)
+ processed = self.process_sample(dict(raw_sample))
+ if not isinstance(processed, dict):
+ raise RuntimeError(
+ f"{self.__class__.__name__}.process_sample() must return a dict."
+ )
+ item = dict(raw_sample)
+ item.update(processed)
+ self._validate_input_sample(item)
+ yield item
diff --git a/src/dots_tts/data/pipelines/preprocessing.py b/src/dots_tts/data/pipelines/preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..b598b9d2f34500ae9326a6228e6f293b7497f517
--- /dev/null
+++ b/src/dots_tts/data/pipelines/preprocessing.py
@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+import torch
+import torch.nn.functional as F
+
+# Default amount of silence, in milliseconds, that normalize_edge_silence_duration
+# leaves at the start and at the end of a clip.
+DEFAULT_EDGE_SILENCE_MS = 250.0
+# Default silence threshold, in dB below the clip's peak amplitude, used by
+# normalize_edge_silence_duration to decide where the audible content begins and ends.
+DEFAULT_EDGE_SILENCE_TOP_DB = 30.0
+
+
+def align_length(num_samples: int, multiple_of: int | None) -> int:
+    """Round ``num_samples`` up to the nearest multiple of ``multiple_of``.
+
+    When ``multiple_of`` is ``None`` or not positive, alignment is disabled and
+    ``num_samples`` is returned unchanged (converted to ``int``).
+    """
+    aligned = num_samples
+    if multiple_of is not None and multiple_of > 0:
+        remainder = num_samples % multiple_of
+        if remainder:
+            # Top up by exactly what is missing to reach the next multiple.
+            aligned = num_samples + (multiple_of - remainder)
+    return int(aligned)
+
+
+def pad_waveform_align_only(
+    waveform: torch.Tensor,
+    *,
+    multiple_of: int | None,
+) -> torch.Tensor:
+    """Zero-pad the end of ``waveform`` so its last dim is a multiple of ``multiple_of``.
+
+    The input tensor itself is returned when ``multiple_of`` is ``None`` or not
+    positive, or when the length is already aligned.
+    """
+    alignment_enabled = multiple_of is not None and multiple_of > 0
+    if not alignment_enabled:
+        return waveform
+
+    current_length = waveform.size(-1)
+    missing = align_length(current_length, multiple_of) - current_length
+    if missing > 0:
+        # Padding is appended on the right of the last dimension only.
+        return F.pad(waveform, (0, missing), "constant", 0.0)
+    return waveform
+
+
+def normalize_edge_silence_duration(
+    waveform: torch.Tensor,
+    *,
+    sample_rate: int,
+    target_silence_duration_ms: float = DEFAULT_EDGE_SILENCE_MS,
+    top_db: float = DEFAULT_EDGE_SILENCE_TOP_DB,
+) -> torch.Tensor:
+    """Pad or trim both ends of a clip so each edge holds the target amount of silence.
+
+    Silence is detected on ``waveform[0]`` only; padding and trimming are applied
+    along the last dimension of the whole tensor.
+    NOTE(review): this presumes a ``(channels, samples)`` layout -- confirm with callers.
+
+    Args:
+        waveform: Audio tensor to normalize.
+        sample_rate: Sample rate used to convert the target duration to samples.
+        target_silence_duration_ms: Silence wanted at each edge, in milliseconds.
+        top_db: A sample counts as non-silent when its magnitude exceeds the
+            peak magnitude attenuated by this many dB.
+
+    Returns:
+        The adjusted waveform. A clip with no signal at all (all zeros, or zero
+        samples long) comes back as exactly ``target`` samples of zeros.
+    """
+    mono_waveform = waveform[0]
+    target_samples = int(round(float(sample_rate) * float(target_silence_duration_ms) / 1000.0))
+    amplitude = mono_waveform.abs()
+    # ``max()`` raises on a tensor with no elements, so a zero-length clip is
+    # routed through the same path as an all-zero clip instead of crashing.
+    peak = float(amplitude.max().item()) if amplitude.numel() > 0 else 0.0
+    if peak <= 0.0:
+        waveform = waveform[..., :target_samples]
+        current_length = int(waveform.size(-1))
+        if current_length < target_samples:
+            waveform = F.pad(waveform, (0, target_samples - current_length), "constant", 0.0)
+        return waveform
+
+    # Amplitude ratio equivalent to ``top_db`` decibels below the peak.
+    threshold = peak * (10.0 ** (-float(top_db) / 20.0))
+    non_silent = torch.nonzero(amplitude > threshold, as_tuple=False).flatten()
+    first_non_silent = int(non_silent[0].item())
+    last_non_silent = int(non_silent[-1].item())
+
+    # Both counts are taken on the untouched clip, before any edge is modified.
+    leading_silence_samples = first_non_silent
+    trailing_silence_samples = int(mono_waveform.numel()) - last_non_silent - 1
+
+    leading_delta = target_samples - leading_silence_samples
+    if leading_delta > 0:
+        waveform = F.pad(waveform, (leading_delta, 0), "constant", 0.0)
+    else:
+        trim_from_start = min(-leading_delta, int(waveform.size(-1)))
+        waveform = waveform[..., trim_from_start:]
+
+    trailing_delta = target_samples - trailing_silence_samples
+    if trailing_delta > 0:
+        return F.pad(waveform, (0, trailing_delta), "constant", 0.0)
+
+    trim_from_end = min(-trailing_delta, int(waveform.size(-1)))
+    if trim_from_end <= 0:
+        # ``[..., :-0]`` would produce an empty tensor, so skip the slice entirely.
+        return waveform
+    return waveform[..., :-trim_from_end]
+
+
+def compute_num_audio_tokens(
+    num_samples: int, *, audio_samples_per_llm_token: int
+) -> int:
+    """Return how many token hops of ``audio_samples_per_llm_token`` samples fit in ``num_samples``.
+
+    Raises:
+        ValueError: if ``num_samples`` is not an exact multiple of the hop size.
+    """
+    token_count, leftover = divmod(num_samples, audio_samples_per_llm_token)
+    if leftover:
+        raise ValueError(
+            f"Waveform length {num_samples} is not aligned to token hop {audio_samples_per_llm_token}."
+        )
+    return token_count
diff --git a/src/dots_tts/data/pipelines/tokenizing.py b/src/dots_tts/data/pipelines/tokenizing.py
new file mode 100644
index 0000000000000000000000000000000000000000..0133fec16c1a6b19b7f97110058ee4bdcf3b06c8
--- /dev/null
+++ b/src/dots_tts/data/pipelines/tokenizing.py
@@ -0,0 +1,339 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+import re
+from typing import Any
+
+from loguru import logger
+
+from dots_tts.utils.tokenizer import (
+ AUDIO_GEN_END_TOKEN,
+ AUDIO_GEN_SPAN_TOKEN,
+ AUDIO_GEN_START_TOKEN,
+ TEXT_COND_END_TOKEN,
+ require_token_id,
+)
+
+# Splits a template into the three supported placeholders and the runs of
+# literal text between them. A "{" that does not open one of the placeholders
+# matches no alternative, so ``findall`` skips over it.
+TEMPLATE_PATTERN = re.compile(r"\{text\}|\{audio\}|\{interleave\}|[^\{]+")
+
+
+@dataclass(frozen=True)
+class ParsedTemplate:
+    """A template split into ordered pieces, plus which generation placeholder it uses."""
+
+    # Pieces in template order: a placeholder verbatim ("{text}", "{audio}",
+    # "{interleave}") or a run of literal text.
+    parts: tuple[str, ...]
+    has_audio_placeholder: bool
+    has_interleave_placeholder: bool
+
+
+@dataclass(frozen=True)
+class TokenizedTemplatePart:
+    """One template piece together with the token ids it contributes, if any."""
+
+    # "text", "audio", "interleave" or "literal".
+    kind: str
+    token_ids: tuple[int, ...] = ()
+    # Original template text; only set for "literal" pieces.
+    raw_text: str | None = None
+
+
+def parse_template(template: str) -> ParsedTemplate:
+    """Split ``template`` into placeholders and literal text and validate it.
+
+    Raises:
+        ValueError: if the template mixes ``{audio}`` with ``{interleave}``,
+            contains more than one ``{interleave}``, or contains a ``{`` that
+            does not start a supported placeholder.
+    """
+    parts = tuple(re.findall(TEMPLATE_PATTERN, template))
+    # The pattern silently drops an unmatched "{", which would turn e.g.
+    # "{speaker}" into the literal "speaker}". Reject such templates instead of
+    # training or generating with a quietly altered prompt.
+    if "".join(parts) != template:
+        raise ValueError(
+            "Template contains a '{' that does not start a supported placeholder "
+            f"({{text}}, {{audio}}, {{interleave}}): {template!r}"
+        )
+    has_audio_placeholder = "{audio}" in parts
+    interleave_count = parts.count("{interleave}")
+    if has_audio_placeholder and interleave_count:
+        raise ValueError("Template cannot mix audio and interleave placeholders.")
+    if interleave_count > 1:
+        raise ValueError(
+            "Interleave generation template must contain exactly one interleave placeholder."
+        )
+    return ParsedTemplate(
+        parts=parts,
+        has_audio_placeholder=has_audio_placeholder,
+        has_interleave_placeholder=interleave_count == 1,
+    )
+
+
+def _prepare_template_tokens(
+    *, text: str, tokenizer, template: str
+) -> tuple[ParsedTemplate, list[int]]:
+    """Parse ``template`` and encode ``text`` without special tokens.
+
+    The template is parsed first, so an invalid template fails before the
+    tokenizer is invoked.
+    """
+    parsed = parse_template(template)
+    text_token_ids = tokenizer.encode(text, add_special_tokens=False)
+    return parsed, text_token_ids
+
+
+def _iter_tokenized_template_parts(
+    *,
+    parsed_template: ParsedTemplate,
+    tokenizer,
+    text_tokens: list[int],
+):
+    """Yield one ``TokenizedTemplatePart`` per template piece, in template order.
+
+    ``{text}`` carries the pre-encoded ``text_tokens``; ``{audio}`` and
+    ``{interleave}`` carry no tokens (callers expand them); anything else is
+    literal text, encoded here without special tokens.
+    """
+    marker_kinds = {"{audio}": "audio", "{interleave}": "interleave"}
+    for piece in parsed_template.parts:
+        if piece == "{text}":
+            yield TokenizedTemplatePart(kind="text", token_ids=tuple(text_tokens))
+        elif piece in marker_kinds:
+            yield TokenizedTemplatePart(kind=marker_kinds[piece])
+        else:
+            literal_ids = tokenizer.encode(piece, add_special_tokens=False)
+            yield TokenizedTemplatePart(
+                kind="literal",
+                token_ids=tuple(literal_ids),
+                raw_text=piece,
+            )
+
+
+def _extend_tokens_with_loss(
+    *, full_ids: list[int], loss_mask: list[float], token_ids: tuple[int, ...], loss: float
+) -> None:
+    """Append ``token_ids`` to ``full_ids`` and one ``loss`` entry per token to ``loss_mask``.
+
+    Both lists are modified in place.
+    """
+    for token_id in token_ids:
+        full_ids.append(token_id)
+        loss_mask.append(loss)
+
+
+def build_tokenized_example(
+    *, text: str, tokenizer, template: str, num_audio_tokens: int
+) -> dict[str, Any]:
+    """Render ``template`` into next-token training inputs, labels and a loss mask.
+
+    The full sequence is the rendered template followed by the tokenizer's EOS
+    id. ``input_ids`` is that sequence without its last element, ``labels`` is
+    the sequence without its first, and ``loss_mask`` is aligned with
+    ``labels``. Only audio span tokens get a loss weight of 1.0; text, literal,
+    audio start/end, text-condition-end and EOS tokens get 0.0.
+
+    Raises:
+        ValueError: if the tokenizer has no ``eos_token_id``.
+    """
+    if tokenizer.eos_token_id is None:
+        raise ValueError("Tokenizer eos_token_id is required for generation targets.")
+
+    parsed_template, text_tokens = _prepare_template_tokens(
+        text=text,
+        tokenizer=tokenizer,
+        template=template,
+    )
+
+    sequence_ids: list[int] = []
+    sequence_loss: list[float] = []
+    audio_block: list[int] | None = None
+    # Special-token ids are resolved up front, only for the placeholder in use.
+    if parsed_template.has_audio_placeholder:
+        start_id = require_token_id(tokenizer, AUDIO_GEN_START_TOKEN)
+        span_id = require_token_id(tokenizer, AUDIO_GEN_SPAN_TOKEN)
+        end_id = require_token_id(tokenizer, AUDIO_GEN_END_TOKEN)
+        audio_block = [start_id, *([span_id] * num_audio_tokens), end_id]
+    elif parsed_template.has_interleave_placeholder:
+        span_id = require_token_id(tokenizer, AUDIO_GEN_SPAN_TOKEN)
+        end_id = require_token_id(tokenizer, AUDIO_GEN_END_TOKEN)
+        text_cond_end_id = require_token_id(tokenizer, TEXT_COND_END_TOKEN)
+
+    for part in _iter_tokenized_template_parts(
+        parsed_template=parsed_template,
+        tokenizer=tokenizer,
+        text_tokens=text_tokens,
+    ):
+        if part.kind == "audio":
+            if audio_block is None:
+                raise RuntimeError("Audio placeholder tokens were not initialized.")
+            sequence_ids.extend(audio_block)
+            # Start marker, then the spans, then the end marker.
+            span_count = max(0, len(audio_block) - 2)
+            sequence_loss.extend([0.0, *([1.0] * span_count), 0.0])
+        elif part.kind == "interleave":
+            _append_interleave_generation_tokens(
+                full_ids=sequence_ids,
+                loss_mask=sequence_loss,
+                text_tokens=text_tokens,
+                num_audio_tokens=num_audio_tokens,
+                audio_span_id=span_id,
+                audio_end_id=end_id,
+                text_cond_end_id=text_cond_end_id,
+            )
+        else:
+            # "text" and "literal" parts are conditioning only: no loss.
+            _extend_tokens_with_loss(
+                full_ids=sequence_ids,
+                loss_mask=sequence_loss,
+                token_ids=part.token_ids,
+                loss=0.0,
+            )
+
+    sequence_ids.append(tokenizer.eos_token_id)
+    sequence_loss.append(0.0)
+
+    return {
+        "input_ids": sequence_ids[:-1],
+        "labels": sequence_ids[1:],
+        "loss_mask": sequence_loss[1:],
+        "text_token_count": len(text_tokens),
+    }
+
+
+def _log_generation_schedule(
+    *,
+    tokenizer,
+    schedule_ids: list[int],
+    audio_span_id: int,
+    interleave: bool,
+    max_audio_tokens: int,
+) -> None:
+    """Log a schedule with its audio span placeholders removed, decoded when possible."""
+    visible_schedule_ids = [
+        token_id for token_id in schedule_ids if token_id != audio_span_id
+    ]
+    decoded_schedule = (
+        tokenizer.decode(
+            visible_schedule_ids,
+            skip_special_tokens=False,
+            clean_up_tokenization_spaces=False,
+        )
+        if hasattr(tokenizer, "decode")
+        else repr(visible_schedule_ids)
+    )
+    logger.info(
+        "Built generation schedule: interleave={} max_audio_tokens={} sequence={!r}",
+        interleave,
+        int(max_audio_tokens),
+        decoded_schedule,
+    )
+
+
+def build_generation_schedule(
+    *,
+    text: str,
+    tokenizer,
+    template: str,
+    max_audio_tokens: int,
+) -> dict[str, Any]:
+    """Build the fixed token schedule that drives generation for ``template``.
+
+    For an ``{audio}`` template, every non-audio piece contributes its tokens
+    and the placeholder expands to the audio start token followed by
+    ``max_audio_tokens`` span tokens. For an ``{interleave}`` template, the
+    literal prefix is followed by text tokens each paired with one span token,
+    then the text-condition-end token, then the remaining span tokens.
+
+    Returns:
+        ``{"schedule_ids": [...], "interleave": bool}``.
+
+    Raises:
+        ValueError: if ``max_audio_tokens`` is not positive; if the template has
+            neither ``{audio}`` nor ``{interleave}``; or, for interleave
+            templates, if there are fewer audio tokens than text tokens, the
+            template also contains ``{text}``/``{audio}``, or non-blank literal
+            text follows the interleave placeholder.
+    """
+    if max_audio_tokens <= 0:
+        raise ValueError("max_audio_tokens must be positive for generation.")
+
+    parsed_template, text_tokens = _prepare_template_tokens(
+        text=text,
+        tokenizer=tokenizer,
+        template=template,
+    )
+    schedule_ids: list[int] = []
+    # Both ids are resolved for every template kind, so a tokenizer lacking
+    # either token is rejected on the interleave path as well.
+    audio_gen_start_id = require_token_id(tokenizer, AUDIO_GEN_START_TOKEN)
+    audio_gen_span_id = require_token_id(tokenizer, AUDIO_GEN_SPAN_TOKEN)
+
+    if parsed_template.has_audio_placeholder:
+        for part in _iter_tokenized_template_parts(
+            parsed_template=parsed_template,
+            tokenizer=tokenizer,
+            text_tokens=text_tokens,
+        ):
+            if part.kind == "audio":
+                schedule_ids.append(audio_gen_start_id)
+                schedule_ids.extend([audio_gen_span_id] * max_audio_tokens)
+                continue
+            schedule_ids.extend(part.token_ids)
+        _log_generation_schedule(
+            tokenizer=tokenizer,
+            schedule_ids=schedule_ids,
+            audio_span_id=audio_gen_span_id,
+            interleave=False,
+            max_audio_tokens=max_audio_tokens,
+        )
+        return {
+            "schedule_ids": schedule_ids,
+            "interleave": False,
+        }
+
+    if not parsed_template.has_interleave_placeholder:
+        raise ValueError(
+            "Generation template must contain either {audio} or {interleave}."
+        )
+    text_cond_end_id = require_token_id(tokenizer, TEXT_COND_END_TOKEN)
+    if max_audio_tokens < len(text_tokens):
+        raise ValueError(
+            "Interleave generation requires at least one audio span per text token: "
+            f"text_token_count={len(text_tokens)} "
+            f"max_audio_patch_count={max_audio_tokens}."
+        )
+
+    interleave_started = False
+    for part in _iter_tokenized_template_parts(
+        parsed_template=parsed_template,
+        tokenizer=tokenizer,
+        text_tokens=text_tokens,
+    ):
+        if part.kind == "interleave":
+            _append_interleave_schedule_tokens(
+                schedule_ids=schedule_ids,
+                text_tokens=text_tokens,
+                max_audio_tokens=max_audio_tokens,
+                audio_span_id=audio_gen_span_id,
+                text_cond_end_id=text_cond_end_id,
+            )
+            interleave_started = True
+            continue
+        if part.kind == "text":
+            raise ValueError(
+                "Generation schedule does not support {text} inside an interleave template."
+            )
+        if part.kind == "audio":
+            raise ValueError(
+                "Generation schedule does not support {audio} inside an interleave template."
+            )
+        if interleave_started:
+            # Only whitespace may follow the placeholder, and it is dropped.
+            if (part.raw_text or "").strip():
+                raise ValueError(
+                    "Generation schedule does not support non-empty suffix text after the interleave placeholder."
+                )
+            continue
+        schedule_ids.extend(part.token_ids)
+
+    _log_generation_schedule(
+        tokenizer=tokenizer,
+        schedule_ids=schedule_ids,
+        audio_span_id=audio_gen_span_id,
+        interleave=True,
+        max_audio_tokens=max_audio_tokens,
+    )
+    return {
+        "schedule_ids": schedule_ids,
+        "interleave": True,
+    }
+
+
+def _append_interleave_generation_tokens(
+    *,
+    full_ids: list[int],
+    loss_mask: list[float],
+    text_tokens: list[int],
+    num_audio_tokens: int,
+    audio_span_id: int,
+    audio_end_id: int,
+    text_cond_end_id: int,
+) -> None:
+    """Append an alternating text/audio training sequence to ``full_ids`` and ``loss_mask``.
+
+    Each step emits one text token (while any remain) followed by one audio
+    token (while any remain). The audio stream is ``num_audio_tokens`` span
+    tokens followed by the end token. The first step after the text runs out
+    emits the text-condition-end token in the text slot; if the text outlasts
+    the audio, that token is appended after the loop instead. Only span tokens
+    carry a loss weight of 1.0.
+    """
+    audio_stream = [audio_span_id] * num_audio_tokens + [audio_end_id]
+    text_total = len(text_tokens)
+    audio_total = len(audio_stream)
+    cond_end_emitted = False
+
+    # Both cursors advance in lockstep, so a single step counter drives them.
+    for step in range(max(text_total, audio_total)):
+        if step < text_total:
+            full_ids.append(text_tokens[step])
+            loss_mask.append(0.0)
+        elif not cond_end_emitted:
+            full_ids.append(text_cond_end_id)
+            loss_mask.append(0.0)
+            cond_end_emitted = True
+
+        if step < audio_total:
+            full_ids.append(audio_stream[step])
+            loss_mask.append(1.0 if step < num_audio_tokens else 0.0)
+
+    if not cond_end_emitted:
+        full_ids.append(text_cond_end_id)
+        loss_mask.append(0.0)
+
+
+def _append_interleave_schedule_tokens(
+    *,
+    schedule_ids: list[int],
+    text_tokens: list[int],
+    max_audio_tokens: int,
+    audio_span_id: int,
+    text_cond_end_id: int,
+) -> None:
+    """Append the interleaved generation schedule to ``schedule_ids`` in place.
+
+    Every text token is followed by one audio span token; then comes the
+    text-condition-end token and whatever remains of the ``max_audio_tokens``
+    span budget.
+    """
+    for text_token in text_tokens:
+        schedule_ids += (text_token, audio_span_id)
+    schedule_ids.append(text_cond_end_id)
+    # A non-positive remainder multiplies out to an empty list, adding nothing.
+    leftover_spans = max_audio_tokens - len(text_tokens)
+    schedule_ids.extend([audio_span_id] * leftover_spans)
diff --git a/src/dots_tts/data/pipelines/tts_pipeline.py b/src/dots_tts/data/pipelines/tts_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9406d610c2d02a0301f57d8533d26d90c2d895f
--- /dev/null
+++ b/src/dots_tts/data/pipelines/tts_pipeline.py
@@ -0,0 +1,132 @@
+from __future__ import annotations
+
+import soundfile as sf
+import torch
+
+from dots_tts.utils.profiling import ensure_data_profiler
+from dots_tts.data.pipelines.base import BaseSamplePipeline
+from dots_tts.data.pipelines.preprocessing import (
+ compute_num_audio_tokens,
+ normalize_edge_silence_duration,
+ pad_waveform_align_only,
+)
+from dots_tts.data.pipelines.tokenizing import build_tokenized_example
+from dots_tts.modules.speaker.fbank import extract_speaker_fbank
+from dots_tts.utils.audio import high_quality_resample
+
+# Literal prompt markers placed in front of each template section. The strings
+# are Chinese; English glosses are given next to each one.
+TTS_TEXT_PREFIX = "[文本]"  # "[text]"
+TTS_AUDIO_PREFIX = "[文本对应语音]"  # "[speech corresponding to the text]"
+TTS_INSTRUCTION_TEXT_PREFIX = "[带指令文本]"  # "[text with instructions]"
+TTA_TEXT_PREFIX = "[声音描述]"  # "[sound description]"
+TTA_AUDIO_PREFIX = "[描述对应声音]"  # "[sound corresponding to the description]"
+TTS_INTERLEAVE_PREFIX = "[流式语音合成]"  # "[streaming speech synthesis]"
+# Templates built from the markers above. "{text}", "{audio}" and
+# "{interleave}" are left as literal placeholders (the doubled braces escape
+# them inside the f-strings).
+DEFAULT_TRAIN_TEMPLATE = f"{TTS_TEXT_PREFIX}{{text}}{TTS_AUDIO_PREFIX}{{audio}}"
+DEFAULT_INSTRUCTION_TTS_TEMPLATE = (
+    f"{TTS_INSTRUCTION_TEXT_PREFIX}{{text}}{TTS_AUDIO_PREFIX}{{audio}}"
+)
+DEFAULT_TEXT_TO_AUDIO_TEMPLATE = f"{TTA_TEXT_PREFIX}{{text}}{TTA_AUDIO_PREFIX}{{audio}}"
+DEFAULT_INTERLEAVE_TRAIN_TEMPLATE = f"{TTS_INTERLEAVE_PREFIX}{{interleave}}"
+
+
+class BasicTtsPipeline(BaseSamplePipeline):
+    """Training pipeline that turns an adapter sample into model-ready features.
+
+    For each sample it loads the audio file, resamples it to the training rate,
+    normalizes the edge silence, pads the waveform to a whole number of token
+    hops, tokenizes the text with the class ``template`` and extracts the
+    speaker fbank features.
+    """
+
+    template = DEFAULT_TRAIN_TEMPLATE
+
+    def __init__(self, tokenizer, data_cfg, *, profiler=None):
+        """Store the tokenizer, read the audio settings from ``data_cfg`` and set up profiling."""
+        self.tokenizer = tokenizer
+        self.train_audio_sample_rate = int(data_cfg.train_audio_sample_rate)
+        self.audio_samples_per_llm_token = int(data_cfg.audio_samples_per_llm_token)
+        self.profiler = ensure_data_profiler(profiler)
+
+    @staticmethod
+    def _load_waveform(audio_path: str) -> tuple[torch.Tensor, int]:
+        """Read ``audio_path`` as float32 and return ``(waveform, sample_rate)``.
+
+        Multi-channel audio is averaged down to one channel; the returned
+        waveform keeps a leading channel dimension of size one.
+
+        Raises:
+            TypeError: if ``audio_path`` is not a ``str``.
+        """
+        if not isinstance(audio_path, str):
+            raise TypeError(
+                f"Training audio must be a filesystem path, got {type(audio_path)}."
+            )
+        frames, native_rate = sf.read(audio_path, dtype="float32", always_2d=True)
+        # ``always_2d`` yields (frames, channels); transpose to channels-first.
+        channels_first = torch.from_numpy(frames.T)
+        if channels_first.size(0) > 1:
+            channels_first = channels_first.mean(dim=0, keepdim=True)
+        return channels_first.contiguous(), int(native_rate)
+
+    @staticmethod
+    def _validate_source_sample(sample: dict) -> None:
+        """Raise ``ValueError`` unless ``sample`` has the ``fid``, ``text`` and ``audio`` keys."""
+        missing = []
+        for field in ("fid", "text", "audio"):
+            if field not in sample:
+                missing.append(field)
+        if not missing:
+            return
+        raise ValueError(
+            "Source adapter must emit fid/text/audio. "
+            f"Missing fields: {missing}. Sample keys: {sorted(sample.keys())}"
+        )
+
+    def process_sample(self, raw_sample: dict) -> dict:
+        """Validate ``raw_sample`` and return a processed copy with all training fields."""
+        working = dict(raw_sample)
+        self._validate_source_sample(working)
+        working["fid"] = str(working["fid"])
+
+        with self.profiler.measure("worker.process_sample_total"):
+            return self._process_sample_impl(working)
+
+    def _process_sample_impl(self, sample: dict) -> dict:
+        """Run the audio and text stages, writing their results into ``sample``."""
+        measure = self.profiler.measure
+        target_rate = self.train_audio_sample_rate
+        hop = self.audio_samples_per_llm_token
+
+        with measure("worker.load_audio"):
+            audio, native_rate = self._load_waveform(sample["audio"])
+        with measure("worker.resample_audio"):
+            audio = high_quality_resample(
+                audio,
+                orig_sr=native_rate,
+                target_sr=target_rate,
+            )
+        with measure("worker.normalize_edge_silence"):
+            audio = normalize_edge_silence_duration(audio, sample_rate=target_rate)
+        sample["sample"] = audio
+        sample["sample_rate"] = target_rate
+        # Length before the alignment padding below is applied.
+        sample["unpadded_sample_length"] = int(audio.size(-1))
+
+        with measure("worker.pad_audio"):
+            audio = pad_waveform_align_only(audio, multiple_of=hop)
+        sample["sample"] = audio
+        sample["sample_length"] = int(audio.size(-1))
+
+        audio_token_count = compute_num_audio_tokens(
+            sample["sample_length"],
+            audio_samples_per_llm_token=hop,
+        )
+        with measure("worker.tokenize"):
+            tokenized = build_tokenized_example(
+                text=sample["text"],
+                tokenizer=self.tokenizer,
+                template=self.template,
+                num_audio_tokens=audio_token_count,
+            )
+        for field in ("input_ids", "labels", "loss_mask"):
+            sample[field] = tokenized[field]
+        sample["input_ids_length"] = len(tokenized["input_ids"])
+        sample["num_text_tokens"] = tokenized["text_token_count"]
+        sample["num_audio_tokens"] = audio_token_count
+        sample["num_total_tokens"] = sample["input_ids_length"]
+
+        with measure("worker.extract_fbank"):
+            speaker_fbank = extract_speaker_fbank(
+                sample["sample"],
+                sample_rate=sample["sample_rate"],
+            )
+        sample["fbank"] = speaker_fbank
+        sample["fbank_length"] = int(speaker_fbank.size(0))
+        return sample
+
+
+class InterleaveTtsPipeline(BasicTtsPipeline):
+    """``BasicTtsPipeline`` variant that tokenizes with the interleave training template."""
+
+    template = DEFAULT_INTERLEAVE_TRAIN_TEMPLATE
diff --git a/src/dots_tts/data/source_adapters/__init__.py b/src/dots_tts/data/source_adapters/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..20e336cead881a402227d769a4664a70e48900c2
--- /dev/null
+++ b/src/dots_tts/data/source_adapters/__init__.py
@@ -0,0 +1 @@
+"""Source adapter package."""
diff --git a/src/dots_tts/data/source_adapters/base_adapter.py b/src/dots_tts/data/source_adapters/base_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffc12aa413eab257a871250595f76f9eb595471b
--- /dev/null
+++ b/src/dots_tts/data/source_adapters/base_adapter.py
@@ -0,0 +1,91 @@
+from __future__ import annotations
+
+import random
+from abc import ABC, abstractmethod
+from collections.abc import Iterable, Sequence
+from copy import deepcopy
+from dataclasses import dataclass
+from typing import Any, TypeVar
+
+
+@dataclass(frozen=True)
+class SourceContext:
+    """Immutable description of where one adapter iterator runs.
+
+    Holds the epoch, the rank and world size, the loader worker id and worker
+    count, and the base seed.
+    """
+
+    epoch: int
+    rank: int
+    world_size: int
+    worker_id: int
+    num_workers: int
+    seed: int
+
+    @property
+    def global_worker_count(self) -> int:
+        """Total workers across all ranks, never less than one."""
+        total = self.world_size * self.num_workers
+        return total if total > 1 else 1
+
+    @property
+    def global_worker_id(self) -> int:
+        """Index of this worker in rank-major order over all workers."""
+        workers_before_this_rank = self.rank * self.num_workers
+        return workers_before_this_rank + self.worker_id
+
+
+class BaseSourceAdapter(ABC):
+    """Interface for resumable streaming sources consumed by the training pipeline.
+
+    An adapter describes its position with a plain ``dict`` state. Concrete
+    adapters supply the default state, the sample iterator and the
+    cycle-start test; state merging and cloning are provided here.
+    """
+
+    @abstractmethod
+    def initial_state(self) -> dict[str, Any]:
+        """Return the default iterator state for a new worker/epoch."""
+
+    @abstractmethod
+    def iter_samples(
+        self,
+        context: SourceContext,
+        *,
+        state: dict[str, Any] | None = None,
+    ) -> Iterable[dict[str, Any]]:
+        """Yield raw samples and attach the next adapter state to each item."""
+
+    @abstractmethod
+    def is_cycle_start_state(self, state: dict[str, Any] | None) -> bool:
+        """Return whether ``state`` points at the beginning of a source cycle."""
+
+    def normalize_state(self, state: dict[str, Any] | None) -> dict[str, Any]:
+        """Return ``initial_state()`` overlaid with a deep copy of ``state``.
+
+        A missing or empty ``state`` yields the initial state unchanged.
+        """
+        defaults = self.initial_state()
+        if not state:
+            return defaults
+        defaults.update(deepcopy(state))
+        return defaults
+
+    def clone_state(self, state: dict[str, Any] | None) -> dict[str, Any]:
+        """Return an independent, normalized copy of ``state``."""
+        normalized = self.normalize_state(state)
+        return deepcopy(normalized)
+
+    def advance_cycle(self, state: dict[str, Any] | None) -> dict[str, Any]:
+        """Return the state that starts the next cycle.
+
+        The base implementation always raises ``RuntimeError``; adapters that
+        can be replayed override it.
+        """
+        raise RuntimeError(
+            f"{self.__class__.__name__} does not support repeated cycling."
+        )
+
+
+# Element type of the sequence handed to ShardableSourceAdapter.shard_items.
+_T = TypeVar("_T")
+
+
+class ShardableSourceAdapter(BaseSourceAdapter):
+    """Adapter base with helpers for deterministic rank/worker sharding.
+
+    Items are dealt out round-robin: position ``i`` belongs to the worker whose
+    global id equals ``i`` modulo the global worker count.
+    """
+
+    @staticmethod
+    def is_assigned_index(index: int, context: SourceContext) -> bool:
+        """Return whether position ``index`` belongs to the worker described by ``context``."""
+        owner = index % context.global_worker_count
+        return owner == context.global_worker_id
+
+    @staticmethod
+    def shard_items(
+        items: Sequence[_T],
+        context: SourceContext,
+        *,
+        shuffle: bool = False,
+        seed_offset: int = 0,
+    ) -> list[_T]:
+        """Return the subset of ``items`` owned by this worker, in order.
+
+        With ``shuffle`` set, a copy of ``items`` is shuffled before it is
+        split, using a generator seeded from the context seed, the epoch and
+        ``seed_offset``.
+        """
+        pool = list(items)
+        if shuffle:
+            rng = random.Random(context.seed + context.epoch + seed_offset)
+            rng.shuffle(pool)
+        owned: list[_T] = []
+        for position, item in enumerate(pool):
+            if ShardableSourceAdapter.is_assigned_index(position, context):
+                owned.append(item)
+        return owned
diff --git a/src/dots_tts/data/source_adapters/jsonl_manifest_adapter.py b/src/dots_tts/data/source_adapters/jsonl_manifest_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3b0759bbbbadde00173abc178fdf2cae1d9c2ac
--- /dev/null
+++ b/src/dots_tts/data/source_adapters/jsonl_manifest_adapter.py
@@ -0,0 +1,132 @@
+from __future__ import annotations
+
+import json
+import random
+from collections.abc import Iterable, Iterator
+from pathlib import Path
+from typing import Any
+
+from dots_tts.data.source_adapters.base_adapter import (
+ BaseSourceAdapter,
+ ShardableSourceAdapter,
+ SourceContext,
+)
+
+
+class JsonlManifestSourceAdapter(ShardableSourceAdapter, BaseSourceAdapter):
+    """Finite adapter for line-delimited JSON manifests.
+
+    Every non-blank manifest line is one JSON record. Records are loaded once,
+    on first use, and kept in memory. Each worker walks its own round-robin
+    share of the (optionally shuffled) records; the resume state is
+    ``{"cycle": int, "cursor": int}`` where ``cursor`` counts samples already
+    emitted from that share.
+    """
+
+    def __init__(
+        self,
+        *,
+        manifest_path: str,
+        fid_key: str = "fid",
+        text_key: str = "text",
+        audio_key: str = "audio",
+        shuffle: bool = False,
+        encoding: str = "utf-8",
+    ):
+        """Remember the manifest location and the record keys to read.
+
+        ``fid_key``, ``text_key`` and ``audio_key`` name the record fields that
+        are emitted as ``fid``, ``text`` and ``audio``. The manifest is not
+        opened here.
+        """
+        self.manifest_path = Path(manifest_path)
+        self.fid_key = fid_key
+        self.text_key = text_key
+        self.audio_key = audio_key
+        self.shuffle = shuffle
+        self.encoding = encoding
+        # Lazily filled cache of parsed manifest records.
+        self._records: list[dict[str, Any]] | None = None
+
+    def initial_state(self) -> dict[str, Any]:
+        """Return the state at the start of the first cycle."""
+        return {"cycle": 0, "cursor": 0}
+
+    def is_cycle_start_state(self, state: dict[str, Any] | None) -> bool:
+        """Return whether ``state`` has not emitted anything in its cycle yet."""
+        normalized = self.normalize_state(state)
+        return int(normalized["cursor"]) == 0
+
+    def advance_cycle(self, state: dict[str, Any] | None) -> dict[str, Any]:
+        """Return the state at the start of the cycle after ``state``'s cycle."""
+        normalized = self.normalize_state(state)
+        return {"cycle": int(normalized["cycle"]) + 1, "cursor": 0}
+
+    def _iter_records(self) -> Iterator[dict[str, Any]]:
+        """Parse the manifest line by line, skipping blank lines.
+
+        Raises:
+            FileNotFoundError: if the manifest path is not a file.
+            ValueError: if a line is not valid JSON (chained to the decode error).
+        """
+        if not self.manifest_path.is_file():
+            raise FileNotFoundError(f"Manifest file not found: {self.manifest_path!s}")
+        with self.manifest_path.open("r", encoding=self.encoding) as fin:
+            for line_no, raw_line in enumerate(fin, start=1):
+                line = raw_line.strip()
+                if not line:
+                    continue
+                try:
+                    yield json.loads(line)
+                except json.JSONDecodeError as exc:
+                    raise ValueError(
+                        f"Invalid JSON at {self.manifest_path}:{line_no}"
+                    ) from exc
+
+    def _base_records(self) -> list[dict[str, Any]]:
+        """Return all manifest records, reading the file on the first call only."""
+        if self._records is None:
+            self._records = list(self._iter_records())
+        return self._records
+
+    def _build_sample(self, record: dict[str, Any]) -> dict[str, Any]:
+        """Convert one manifest record into a sample with canonical field names.
+
+        The configured id/text/audio fields become ``fid`` (as ``str``),
+        ``text`` and ``audio``. All other record fields are copied through
+        unchanged, except fields named ``fid``, ``text`` or ``audio``: those
+        names are reserved for the configured fields and are skipped.
+
+        Raises:
+            KeyError: if any of the configured keys is absent from ``record``.
+        """
+        source_keys = (self.fid_key, self.text_key, self.audio_key)
+        missing = [key for key in source_keys if key not in record]
+        if missing:
+            raise KeyError(
+                f"Manifest record is missing required keys {missing}: {record}"
+            )
+
+        sample = {
+            "fid": str(record[self.fid_key]),
+            "text": record[self.text_key],
+            "audio": record[self.audio_key],
+        }
+        # With custom key names, a leftover record field literally called
+        # "fid"/"text"/"audio" must not overwrite the value just extracted from
+        # the configured key.
+        reserved = set(source_keys) | {"fid", "text", "audio"}
+        for key, value in record.items():
+            if key in reserved:
+                continue
+            sample[key] = value
+        return sample
+
+    def _indices_for_cycle(
+        self,
+        context: SourceContext,
+        *,
+        cycle: int,
+    ) -> list[int]:
+        """Return the record indices this worker visits in ``cycle``, in order.
+
+        With shuffling enabled, the order is reshuffled per cycle from a seed
+        built from the context seed, the epoch and the cycle number, and
+        ownership is decided by position in the shuffled order. Without
+        shuffling, ownership is decided by the record index itself.
+        """
+        indices = list(range(len(self._base_records())))
+        if self.shuffle:
+            random.Random(context.seed + context.epoch + 1009 * int(cycle)).shuffle(
+                indices
+            )
+            indices = [
+                record_index
+                for shuffled_index, record_index in enumerate(indices)
+                if self.is_assigned_index(shuffled_index, context)
+            ]
+        else:
+            indices = [
+                record_index
+                for record_index in indices
+                if self.is_assigned_index(record_index, context)
+            ]
+        return indices
+
+    def iter_samples(
+        self,
+        context: SourceContext,
+        *,
+        state: dict[str, Any] | None = None,
+    ) -> Iterable[dict[str, Any]]:
+        """Yield this worker's remaining samples for the cycle recorded in ``state``.
+
+        Iteration starts at ``state["cursor"]``. Every sample carries
+        ``'_adapter_state'`` holding the state to resume from after it.
+        """
+        live_state = self.normalize_state(state)
+        cycle = int(live_state["cycle"])
+        cursor = int(live_state["cursor"])
+        records = self._base_records()
+        indices = self._indices_for_cycle(context, cycle=cycle)
+
+        for position in range(cursor, len(indices)):
+            sample = self._build_sample(records[indices[position]])
+            # The cursor stored with a sample points at the sample after it.
+            sample["_adapter_state"] = {
+                "cycle": cycle,
+                "cursor": position + 1,
+            }
+            yield sample
diff --git a/src/dots_tts/data/source_adapters/multi_source_adapter.py b/src/dots_tts/data/source_adapters/multi_source_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..c43a2bc88cb0325d90dc44c62b1f3ba3137ded25
--- /dev/null
+++ b/src/dots_tts/data/source_adapters/multi_source_adapter.py
@@ -0,0 +1,222 @@
+from __future__ import annotations
+
+from collections.abc import Iterable
+from copy import deepcopy
+from dataclasses import dataclass
+
+from dots_tts.data.pipelines.base import BaseSamplePipeline
+from dots_tts.data.source_adapters.base_adapter import (
+ BaseSourceAdapter,
+ SourceContext,
+)
+
+
+@dataclass(frozen=True)
+class SourceSpec:
+    """One named child source of a multi-source adapter."""
+
+    # Key under which the child's resume state is stored; also written to each
+    # emitted sample as "source_name".
+    name: str
+    # Relative sampling weight. WeightedMultiSourceAdapter requires it to be
+    # positive; SequentialMultiSourceAdapter does not read it.
+    weight: float
+    # Produces the raw samples for this source.
+    adapter: BaseSourceAdapter
+    # Applied to the adapter's sample iterator before samples are emitted.
+    pipeline: BaseSamplePipeline
+
+
+# All hashing arithmetic below is reduced modulo 2**64 with this mask.
+_UINT64_MASK = 0xFFFFFFFFFFFFFFFF
+
+
+def _mix_uint64(value: int) -> int:
+    """Scramble ``value`` with two xor-shift-multiply rounds and a final xor-shift.
+
+    The result always fits in 64 bits.
+    """
+    rounds = ((30, 0xBF58476D1CE4E5B9), (27, 0x94D049BB133111EB))
+    for shift, multiplier in rounds:
+        value = ((value ^ (value >> shift)) * multiplier) & _UINT64_MASK
+    return (value ^ (value >> 31)) & _UINT64_MASK
+
+
+def _stable_seed(*parts: int) -> int:
+    """Fold ``parts`` into one 64-bit value that depends only on the inputs.
+
+    Unlike the built-in ``hash``, the result is the same in every process, and
+    the order of ``parts`` matters.
+    """
+    increment = 0x9E3779B97F4A7C15
+    state = increment
+    for part in parts:
+        state = _mix_uint64((state + int(part) + increment) & _UINT64_MASK)
+    return state
+
+
+class SequentialMultiSourceAdapter(BaseSourceAdapter):
+    """Finite adapter that plays its sources back to back in the configured order.
+
+    The state is ``{"source_index": int, "sources": {name: child_state}}``:
+    the index of the source being read plus the latest state of every child
+    adapter.
+    """
+
+    def __init__(self, *, sources: list[SourceSpec]):
+        """Keep a private list of ``sources``; at least one is required."""
+        if not sources:
+            raise ValueError(
+                "SequentialMultiSourceAdapter requires at least one source."
+            )
+        self.sources = list(sources)
+
+    def initial_state(self) -> dict:
+        """Return the state that points at the first source with every child at its start."""
+        per_source = {}
+        for source in self.sources:
+            per_source[source.name] = source.adapter.initial_state()
+        return {"source_index": 0, "sources": per_source}
+
+    def is_cycle_start_state(self, state: dict | None) -> bool:
+        """Return whether ``state`` is at the first source and every child is at a cycle start."""
+        normalized = self.normalize_state(state)
+        at_first_source = int(normalized["source_index"]) == 0
+        return at_first_source and all(
+            source.adapter.is_cycle_start_state(normalized["sources"][source.name])
+            for source in self.sources
+        )
+
+    def normalize_state(self, state: dict | None) -> dict:
+        """Return a complete state: one entry per configured source and an ``int`` index.
+
+        Child states absent from ``state`` are filled in by the child adapter;
+        entries for names that are not configured are dropped.
+        """
+        merged = super().normalize_state(state)
+        saved_children = merged.get("sources") or {}
+        merged["sources"] = {
+            source.name: source.adapter.clone_state(saved_children.get(source.name))
+            for source in self.sources
+        }
+        merged["source_index"] = int(merged.get("source_index", 0))
+        return merged
+
+    def clone_state(self, state: dict | None) -> dict:
+        """Return an independent, normalized copy of ``state``."""
+        normalized = self.normalize_state(state)
+        return deepcopy(normalized)
+
+    def iter_samples(
+        self,
+        context: SourceContext,
+        *,
+        state: dict | None = None,
+    ) -> Iterable[dict]:
+        """Yield pipeline-processed samples from each source in turn, resuming from ``state``.
+
+        Each emitted sample gets ``source_name`` and has its child
+        ``'_adapter_state'`` replaced by a snapshot of the combined state.
+
+        Raises:
+            RuntimeError: if a processed sample carries no ``'_adapter_state'``.
+        """
+        live_state = self.normalize_state(state)
+        first_position = int(live_state["source_index"])
+        for position in range(first_position, len(self.sources)):
+            source = self.sources[position]
+            adapter = source.adapter
+            raw_samples = adapter.iter_samples(
+                context, state=live_state["sources"][source.name]
+            )
+            for processed in source.pipeline(raw_samples):
+                item = dict(processed)
+                child_next_state = item.pop("_adapter_state", None)
+                if child_next_state is None:
+                    raise RuntimeError(
+                        f"{source.adapter.__class__.__name__} must attach '_adapter_state' to samples."
+                    )
+                live_state["source_index"] = position
+                live_state["sources"][source.name] = adapter.clone_state(
+                    child_next_state
+                )
+                item["source_name"] = source.name
+                # Snapshot, so later updates to live_state cannot change a
+                # state that was already handed out.
+                item["_adapter_state"] = self.clone_state(live_state)
+                yield item
+            live_state["source_index"] = position + 1
+
+
+class WeightedMultiSourceAdapter(BaseSourceAdapter):
+    """Infinite weighted sampler that cycles each child source independently.
+
+    The state is ``{"draw_count": int, "sources": {name: child_state}}``. The
+    source for each draw is derived from the context and the draw counter, so
+    the same context and state reproduce the same sequence of picks.
+    """
+
+    def __init__(self, *, sources: list[SourceSpec]):
+        """Store ``sources`` and precompute their cumulative weights.
+
+        Raises:
+            ValueError: if ``sources`` is empty or any weight is not positive.
+        """
+        if not sources:
+            raise ValueError("WeightedMultiSourceAdapter requires at least one source.")
+        invalid = [source.name for source in sources if float(source.weight) <= 0.0]
+        if invalid:
+            raise ValueError(f"Source weights must be positive: {invalid}")
+        self.sources = list(sources)
+        # Running totals of the weights, parallel to self.sources.
+        self._cumulative_weights = []
+        total = 0.0
+        for source in self.sources:
+            total += float(source.weight)
+            self._cumulative_weights.append(total)
+        self._total_weight = total
+
+    def initial_state(self) -> dict:
+        """Return the state with no draws made and every child at its initial state."""
+        return {
+            "draw_count": 0,
+            "sources": {
+                source.name: source.adapter.initial_state() for source in self.sources
+            },
+        }
+
+    def is_cycle_start_state(self, state: dict | None) -> bool:
+        """Return whether no draw has been made and every child is at a cycle start."""
+        normalized = self.normalize_state(state)
+        if int(normalized["draw_count"]) != 0:
+            return False
+        return all(
+            source.adapter.is_cycle_start_state(normalized["sources"][source.name])
+            for source in self.sources
+        )
+
+    def normalize_state(self, state: dict | None) -> dict:
+        """Return a complete state: one entry per configured source and an ``int`` draw count."""
+        normalized = super().normalize_state(state)
+        source_states = normalized.get("sources") or {}
+        # Rebuilt from self.sources, so entries for unknown names are dropped
+        # and missing ones are filled in by the child adapter.
+        normalized["sources"] = {
+            source.name: source.adapter.clone_state(source_states.get(source.name))
+            for source in self.sources
+        }
+        normalized["draw_count"] = int(normalized.get("draw_count", 0))
+        return normalized
+
+    def clone_state(self, state: dict | None) -> dict:
+        """Return an independent, normalized copy of ``state``."""
+        return deepcopy(self.normalize_state(state))
+
+    def _source_draw_value(self, context: SourceContext, draw_count: int) -> float:
+        """Map (seed, epoch, rank, worker, draw) to a value scaled to the total weight."""
+        raw = _stable_seed(
+            context.seed,
+            context.epoch,
+            context.rank,
+            context.worker_id,
+            draw_count,
+        )
+        # raw is a 64-bit integer; dividing by 2**64 scales it to a fraction.
+        return (raw / float(1 << 64)) * self._total_weight
+
+    def _pick_source(self, context: SourceContext, draw_count: int) -> SourceSpec:
+        """Return the source whose cumulative-weight interval contains this draw's value."""
+        draw_value = self._source_draw_value(context, draw_count)
+        for source, upper in zip(self.sources, self._cumulative_weights, strict=True):
+            if draw_value < upper:
+                return source
+        # Reached only if the draw value is not below the last cumulative weight
+        # (presumably a float rounding guard).
+        return self.sources[-1]
+
+    def iter_samples(
+        self,
+        context: SourceContext,
+        *,
+        state: dict | None = None,
+    ) -> Iterable[dict]:
+        """Yield samples forever, picking a source per draw and cycling exhausted children.
+
+        Each emitted sample gets ``source_name`` and has its child
+        ``'_adapter_state'`` replaced by a snapshot of the combined state.
+
+        Raises:
+            RuntimeError: if a child yields nothing from the start of a cycle, or
+                if a processed sample carries no ``'_adapter_state'``.
+        """
+        live_state = self.normalize_state(state)
+        # Open child iterators, keyed by source name and created on first use.
+        iterators: dict[str, object] = {}
+
+        while True:
+            draw_count = int(live_state["draw_count"])
+            source = self._pick_source(context, draw_count)
+
+            # Retry loop: runs a second time only after the picked child was
+            # exhausted and moved on to its next cycle.
+            while True:
+                child_state = live_state["sources"][source.name]
+                child_iter = iterators.get(source.name)
+                if child_iter is None:
+                    raw_iter = source.adapter.iter_samples(context, state=child_state)
+                    child_iter = iter(source.pipeline(raw_iter))
+                    iterators[source.name] = child_iter
+
+                try:
+                    sample = dict(next(child_iter))
+                except StopIteration:
+                    # Exhausted while still at a cycle start means this worker's
+                    # share of the source is empty; cycling again would never
+                    # produce a sample.
+                    if source.adapter.is_cycle_start_state(child_state):
+                        raise RuntimeError(
+                            "Weighted source yielded no samples for this worker. "
+                            f"source={source.name!r}, worker={context.global_worker_id}, "
+                            f"epoch={context.epoch}"
+                        )
+                    # Drop the spent iterator and restart the child on its next cycle.
+                    iterators.pop(source.name, None)
+                    live_state["sources"][source.name] = source.adapter.advance_cycle(
+                        child_state
+                    )
+                    continue
+
+                next_child_state = sample.pop("_adapter_state", None)
+                if next_child_state is None:
+                    raise RuntimeError(
+                        f"{source.adapter.__class__.__name__} must attach '_adapter_state' to samples."
+                    )
+                live_state["sources"][source.name] = source.adapter.clone_state(
+                    next_child_state
+                )
+                # The counter moves only when a sample is actually emitted.
+                live_state["draw_count"] = draw_count + 1
+                sample["source_name"] = source.name
+                sample["_adapter_state"] = self.clone_state(live_state)
+                yield sample
+                break
diff --git a/src/dots_tts/data/streaming.py b/src/dots_tts/data/streaming.py
new file mode 100644
index 0000000000000000000000000000000000000000..4356b9faea69635bb23203bafaaaf00fdd0092c8
--- /dev/null
+++ b/src/dots_tts/data/streaming.py
@@ -0,0 +1,400 @@
+from __future__ import annotations
+
+import math
+import multiprocessing as mp
+from collections.abc import Iterable
+from copy import deepcopy
+
+from torch.utils.data import DataLoader, IterableDataset, get_worker_info
+
+from dots_tts.data.batchers import OnlineBatcher
+from dots_tts.utils.profiling import ensure_data_profiler
+from dots_tts.data.source_adapters.base_adapter import BaseSourceAdapter, SourceContext
+
+# Key under which _DataStateTracker.stage_sample attaches per-sample resume metadata.
+_TRACKING_KEY = "__tracking_state__"
+# Key in a saved state dict that holds the worker topology recorded at save time.
+_RESUME_TOPOLOGY_KEY = "resume_topology"
+
+
+def identity_collate(sample):
+    """Return ``sample`` unchanged (a collate function that does nothing)."""
+    return sample
+
+
+class StreamingSampleDataset(IterableDataset):
+    """Stream samples from one source adapter for a given rank and worker.
+
+    Every ``__iter__`` call builds a ``SourceContext`` from the rank, world
+    size, seed, current epoch and DataLoader worker info, optionally resumes
+    from a previously loaded state, and tags each yielded sample with the ids
+    of the worker that produced it.
+    """
+
+    def __init__(
+        self,
+        *,
+        source: BaseSourceAdapter,
+        rank: int,
+        world_size: int,
+        seed: int,
+    ):
+        self.source = source
+        self.rank = int(rank)
+        self.world_size = int(world_size)
+        self.seed = int(seed)
+        # The epoch is kept in a multiprocessing.Value (signed 64-bit) and is
+        # always accessed under its lock; presumably this lets loader worker
+        # processes observe set_epoch() calls -- confirm against the loader setup.
+        self._epoch = mp.Value("q", 0)
+        self._pending_resume_state: dict | None = None
+
+    def load_state_dict(self, state: dict | None) -> None:
+        """Store a deep copy of ``state`` for later resume; a falsy state clears it."""
+        self._pending_resume_state = deepcopy(state) if state else None
+
+    def set_epoch(self, epoch: int) -> None:
+        """Record ``epoch`` as the epoch used by the next ``__iter__`` call."""
+        with self._epoch.get_lock():
+            self._epoch.value = int(epoch)
+
+    def _current_epoch(self) -> int:
+        """Return the epoch most recently stored by ``set_epoch``."""
+        with self._epoch.get_lock():
+            return int(self._epoch.value)
+
+    def _take_resume_state(self, epoch: int) -> dict | None:
+        """Return and clear the pending resume state if it belongs to ``epoch``.
+
+        A pending state recorded for a different epoch is left in place and
+        ``None`` is returned.
+        """
+        if (
+            self._pending_resume_state is None
+            or int(self._pending_resume_state.get("epoch", -1)) != int(epoch)
+        ):
+            return None
+        state = deepcopy(self._pending_resume_state)
+        self._pending_resume_state = None
+        return state
+
+    @staticmethod
+    def _validate_resume_topology(
+        resume_state: dict,
+        *,
+        context: SourceContext,
+        loader_num_workers: int,
+    ) -> None:
+        """Ensure ``resume_state`` was saved with the current worker topology.
+
+        Raises:
+            RuntimeError: if the topology metadata is absent, incomplete, or
+                differs from the current world size / worker counts.
+        """
+        resume_topology = resume_state.get(_RESUME_TOPOLOGY_KEY)
+        if not isinstance(resume_topology, dict):
+            raise RuntimeError(
+                "Resume state is missing required worker topology metadata."
+            )
+        # Report incomplete metadata with the same error type as the other
+        # resume failures instead of leaking a bare KeyError.
+        required_fields = ("world_size", "loader_num_workers", "global_worker_count")
+        missing_fields = [
+            field for field in required_fields if field not in resume_topology
+        ]
+        if missing_fields:
+            raise RuntimeError(
+                "Resume state worker topology metadata is missing required "
+                f"field(s): {', '.join(missing_fields)}."
+            )
+        expected_world_size = int(resume_topology["world_size"])
+        expected_num_workers = int(resume_topology["loader_num_workers"])
+        expected_global_worker_count = int(resume_topology["global_worker_count"])
+        current_num_workers = int(loader_num_workers)
+        current_global_worker_count = int(context.global_worker_count)
+        if (
+            expected_world_size != int(context.world_size)
+            or expected_num_workers != current_num_workers
+            or expected_global_worker_count != current_global_worker_count
+        ):
+            raise RuntimeError(
+                "Resume requires the same data worker topology as the saved state. "
+                f"saved(world_size={expected_world_size}, "
+                f"num_workers_per_rank={expected_num_workers}, "
+                f"global_worker_count={expected_global_worker_count}), "
+                f"current(world_size={context.world_size}, "
+                f"num_workers_per_rank={current_num_workers}, "
+                f"global_worker_count={current_global_worker_count})."
+            )
+
+    def __iter__(self) -> Iterable[dict]:
+        """Yield samples from the source, tagged with worker identifiers.
+
+        Outside a DataLoader worker the dataset behaves as worker 0 of 1, but
+        reports ``loader_num_workers == 0`` for the topology check so that it
+        matches what ``DataLoader.num_workers`` holds in that configuration.
+        """
+        worker_info = get_worker_info()
+        if worker_info is None:
+            worker_id = 0
+            loader_num_workers = 0
+            effective_num_workers = 1
+        else:
+            worker_id = worker_info.id
+            loader_num_workers = worker_info.num_workers
+            effective_num_workers = worker_info.num_workers
+
+        epoch = self._current_epoch()
+        context = SourceContext(
+            epoch=epoch,
+            rank=self.rank,
+            world_size=self.world_size,
+            worker_id=worker_id,
+            num_workers=effective_num_workers,
+            seed=self.seed,
+        )
+        resume_state = self._take_resume_state(epoch)
+        if resume_state is not None:
+            self._validate_resume_topology(
+                resume_state,
+                context=context,
+                loader_num_workers=loader_num_workers,
+            )
+        # Per-worker entries are keyed by the stringified global worker id.
+        worker_state = (
+            None
+            if resume_state is None
+            else (resume_state.get("workers") or {}).get(str(context.global_worker_id))
+        )
+        sample_iter = self.source.iter_samples(
+            context,
+            state=None if worker_state is None else worker_state.get("adapter_state"),
+        )
+        for sample in sample_iter:
+            sample["data_worker_id"] = context.worker_id
+            sample["data_global_worker_id"] = context.global_worker_id
+            yield sample
+
+
+class _DataStateTracker:
+    """Track per-epoch token counters and per-worker resume positions.
+
+    Samples are first *staged* (tracking metadata attached under
+    ``_TRACKING_KEY``) and later either *committed* as part of a batch or
+    marked as *dropped*.  Both outcomes advance the owning worker's resume
+    position; only committed samples contribute to the counters.
+    """
+
+    def __init__(self, *, num_tokens_per_epoch: int | None):
+        self.num_tokens_per_epoch = (
+            None if num_tokens_per_epoch is None else int(num_tokens_per_epoch)
+        )
+        self._pending_state: dict | None = None
+        self._reset_for_epoch(epoch=0)
+
+    def _reset_for_epoch(self, *, epoch: int) -> None:
+        """Zero every counter and forget all worker positions for ``epoch``."""
+        self.epoch = int(epoch)
+        self.samples_emitted = 0
+        self.num_text_tokens = 0
+        self.num_audio_tokens = 0
+        self.num_total_tokens = 0
+        self.workers: dict[str, dict] = {}
+        self._next_sample_order_by_worker: dict[str, int] = {}
+
+    def load_state_dict(self, state: dict | None) -> None:
+        """Store a deep copy of ``state``; it is applied by a matching ``set_epoch``."""
+        self._pending_state = deepcopy(state) if state else None
+
+    def set_epoch(self, epoch: int) -> None:
+        """Start ``epoch``, restoring the pending state when it is for this epoch.
+
+        If no pending state matches, all counters are reset instead.
+        """
+        if self._pending_state is not None and int(
+            self._pending_state.get("epoch", -1)
+        ) == int(epoch):
+            state = deepcopy(self._pending_state)
+            self._pending_state = None
+            self.epoch = int(state.get("epoch", epoch))
+            self.samples_emitted = int(state.get("samples_emitted", 0))
+            self.num_text_tokens = int(state.get("num_text_tokens", 0))
+            self.num_audio_tokens = int(state.get("num_audio_tokens", 0))
+            self.num_total_tokens = int(state.get("num_total_tokens", 0))
+            self.workers = deepcopy(state.get("workers") or {})
+            # Continue numbering right after each worker's last recorded sample.
+            self._next_sample_order_by_worker = {
+                worker_key: int((worker_state or {}).get("sample_order", -1)) + 1
+                for worker_key, worker_state in self.workers.items()
+            }
+            return
+        self._reset_for_epoch(epoch=int(epoch))
+
+    def should_stop(self) -> bool:
+        """Return True once the committed token total reaches the epoch budget.
+
+        Always False when no budget (``num_tokens_per_epoch``) is configured.
+        """
+        return (
+            self.num_tokens_per_epoch is not None
+            and self.num_total_tokens >= self.num_tokens_per_epoch
+        )
+
+    def stage_sample(self, sample: dict) -> dict:
+        """Return a copy of ``sample`` with tracking metadata attached.
+
+        The worker id fields and ``"_adapter_state"`` are removed from the
+        copy and folded into the metadata stored under ``_TRACKING_KEY``,
+        together with a per-worker sequence number and the token counts.
+
+        Raises:
+            KeyError: if ``data_global_worker_id``, ``num_text_tokens`` or
+                ``num_audio_tokens`` is missing, or if neither
+                ``num_total_tokens`` nor ``input_ids_length`` is present.
+        """
+        item = dict(sample)
+        worker_key = str(item.pop("data_global_worker_id"))
+        item.pop("data_worker_id", None)
+        adapter_state = item.pop("_adapter_state", None)
+        sample_order = int(self._next_sample_order_by_worker.get(worker_key, 0))
+        self._next_sample_order_by_worker[worker_key] = sample_order + 1
+        # Only consult the fallback key when the primary one is absent; a
+        # dict.get() default would be evaluated (and could raise) regardless.
+        total_tokens = (
+            item["num_total_tokens"]
+            if "num_total_tokens" in item
+            else item["input_ids_length"]
+        )
+        item[_TRACKING_KEY] = {
+            "worker_key": worker_key,
+            "adapter_state": deepcopy(adapter_state),
+            "sample_order": sample_order,
+            "num_text_tokens": int(item["num_text_tokens"]),
+            "num_audio_tokens": int(item["num_audio_tokens"]),
+            "num_total_tokens": int(total_tokens),
+        }
+        return item
+
+    def _pop_tracking(self, sample: dict) -> tuple[dict, dict]:
+        """Split a staged sample into (payload copy, tracking metadata).
+
+        Raises:
+            RuntimeError: if the sample carries no tracking metadata dict.
+        """
+        item = dict(sample)
+        tracking = item.pop(_TRACKING_KEY, None)
+        if not isinstance(tracking, dict):
+            raise RuntimeError("Tracked sample is missing internal resume metadata.")
+        return item, tracking
+
+    def _advance_worker(self, tracking: dict) -> None:
+        """Move a worker's resume position forward to ``tracking``'s sample.
+
+        Samples without an adapter state, or older than the position already
+        recorded for the worker, leave the position untouched.
+        """
+        adapter_state = tracking.get("adapter_state")
+        if adapter_state is None:
+            return
+        worker_key = str(tracking["worker_key"])
+        sample_order = int(tracking.get("sample_order", -1))
+        current_state = self.workers.get(worker_key)
+        current_order = int((current_state or {}).get("sample_order", -1))
+        if current_order >= sample_order:
+            return
+        self.workers[worker_key] = {
+            "adapter_state": deepcopy(adapter_state),
+            "sample_order": sample_order,
+        }
+
+    def mark_samples_dropped(self, samples: list[dict]) -> None:
+        """Advance worker positions for dropped samples without counting them."""
+        for sample in samples:
+            _, tracking = self._pop_tracking(sample)
+            self._advance_worker(tracking)
+
+    def commit_batch(self, samples: list[dict]) -> list[dict]:
+        """Count ``samples`` as emitted and return them without tracking metadata."""
+        committed: list[dict] = []
+        for sample in samples:
+            item, tracking = self._pop_tracking(sample)
+            self._advance_worker(tracking)
+            self.samples_emitted += 1
+            self.num_text_tokens += int(tracking["num_text_tokens"])
+            self.num_audio_tokens += int(tracking["num_audio_tokens"])
+            self.num_total_tokens += int(tracking["num_total_tokens"])
+            committed.append(item)
+        return committed
+
+    def state_dict(self) -> dict:
+        """Return a snapshot of the counters and worker positions."""
+        return {
+            "epoch": int(self.epoch),
+            "samples_emitted": int(self.samples_emitted),
+            "num_text_tokens": int(self.num_text_tokens),
+            "num_audio_tokens": int(self.num_audio_tokens),
+            "num_total_tokens": int(self.num_total_tokens),
+            "workers": deepcopy(self.workers),
+            "num_tokens_per_epoch": self.num_tokens_per_epoch,
+        }
+
+
+class BatchedDataStream:
+    """Turn a stream of samples into collated batches with resumable state.
+
+    Batches are produced in two phases: ``peek_batch`` builds the next batch
+    and holds it as *pending*; ``commit_batch`` then records it in the data
+    state, or ``discard_batch`` drops it.  Iterating the stream commits every
+    batch it yields.
+    """
+
+    def __init__(
+        self,
+        *,
+        sample_dataset: StreamingSampleDataset,
+        data_cfg,
+        tokenizer,
+        num_tokens_per_epoch: int | None,
+        profiler=None,
+    ):
+        from dots_tts.data.collator import PadCollator
+
+        self.sample_dataset = sample_dataset
+        self.profiler = ensure_data_profiler(profiler)
+        # Ratio used to convert the per-batch audio budget from seconds into
+        # a token count for the batcher.
+        llm_token_rate = (
+            float(data_cfg.train_audio_sample_rate)
+            / float(data_cfg.audio_samples_per_llm_token)
+        )
+        self.batcher = OnlineBatcher(
+            max_audio_tokens_in_batch=max(
+                1,
+                math.ceil(float(data_cfg.max_audio_seconds_in_batch) * llm_token_rate),
+            ),
+            max_text_tokens_in_batch=data_cfg.max_text_tokens_in_batch,
+            max_batch_size=data_cfg.max_samples_per_batch,
+            sample_pool_size=data_cfg.bucketing_pool_size,
+            profiler=self.profiler,
+        )
+        self.sample_loader = None
+        self.collator = PadCollator(tokenizer)
+        self.data_state = _DataStateTracker(
+            num_tokens_per_epoch=num_tokens_per_epoch
+        )
+        self._decision_iterator = None
+        self._sample_iterator = None
+        self._pending_batch = None
+        self._pending_samples = None
+
+    def attach_loader(self, loader: DataLoader) -> None:
+        """Set the DataLoader that supplies individual samples."""
+        self.sample_loader = loader
+
+    def close(self) -> None:
+        """Abandon any in-flight iteration and detach the loader."""
+        self._reset_iteration_state()
+        self.sample_loader = None
+
+    def load_state_dict(self, state: dict | None) -> None:
+        """Hand ``state`` to the tracker and the dataset, then reset iteration."""
+        self.data_state.load_state_dict(state)
+        self.sample_dataset.load_state_dict(state)
+        self._reset_iteration_state()
+
+    def state_dict(self) -> dict:
+        """Return the tracker state plus the current worker topology.
+
+        Raises:
+            RuntimeError: if no loader is attached or a batch is pending.
+        """
+        if self.sample_loader is None:
+            raise RuntimeError("BatchedDataStream has no attached sample loader.")
+        if self._pending_batch is not None or self._pending_samples is not None:
+            raise RuntimeError(
+                "Cannot serialize BatchedDataStream while a batch is pending commit."
+            )
+        loader_num_workers = int(getattr(self.sample_loader, "num_workers", 0))
+        # A loader with zero workers still iterates as a single worker.
+        effective_num_workers = max(1, loader_num_workers)
+        state = self.data_state.state_dict()
+        state[_RESUME_TOPOLOGY_KEY] = {
+            "world_size": int(self.sample_dataset.world_size),
+            "loader_num_workers": loader_num_workers,
+            "global_worker_count": int(self.sample_dataset.world_size)
+            * effective_num_workers,
+        }
+        return state
+
+    def set_epoch(self, epoch: int) -> None:
+        """Propagate ``epoch`` to the dataset and tracker, then reset iteration."""
+        self.sample_dataset.set_epoch(epoch)
+        self.data_state.set_epoch(epoch)
+        self._reset_iteration_state()
+
+    def _reset_iteration_state(self) -> None:
+        """Close the decision iterator (if it can be closed) and clear pending data."""
+        close_iterator = getattr(self._decision_iterator, "close", None)
+        if callable(close_iterator):
+            close_iterator()
+        self._decision_iterator = None
+        self._sample_iterator = None
+        self._pending_batch = None
+        self._pending_samples = None
+
+    def _iter_staged_samples(self):
+        """Yield staged samples from the loader until exhausted or over budget.
+
+        ``None`` items produced by the loader are skipped.
+
+        Raises:
+            RuntimeError: if no loader is attached.
+        """
+        if self.sample_loader is None:
+            raise RuntimeError("BatchedDataStream has no attached sample loader.")
+        self._sample_iterator = iter(self.sample_loader)
+        profiler = self.profiler
+        try:
+            while True:
+                if self.data_state.should_stop():
+                    return
+                try:
+                    with profiler.measure("main.loader_wait_next_sample"):
+                        sample = next(self._sample_iterator)
+                except StopIteration:
+                    return
+                if sample is None:
+                    continue
+                with profiler.measure("main.stage_sample"):
+                    staged = self.data_state.stage_sample(sample)
+                yield staged
+        finally:
+            self._sample_iterator = None
+
+    def _decision_stream(self):
+        """Return the batching-decision iterator, creating it on first use."""
+        if self._decision_iterator is None:
+            self._decision_iterator = iter(
+                self.batcher.build_decisions(self._iter_staged_samples())
+            )
+        return self._decision_iterator
+
+    def peek_batch(self) -> tuple[dict | None, bool]:
+        """Return ``(batch, True)`` for the pending batch, building one if needed.
+
+        Returns ``(None, False)`` once the decision stream is exhausted.
+        Samples a decision reports as dropped are recorded immediately.
+        """
+        if self._pending_batch is not None:
+            return self._pending_batch, True
+
+        for decision in self._decision_stream():
+            if decision.dropped_samples:
+                self.data_state.mark_samples_dropped(decision.dropped_samples)
+            if not decision.batch_samples:
+                continue
+            batch_samples = decision.batch_samples
+            with self.profiler.measure(
+                "main.collate_batch",
+                count=len(batch_samples),
+            ):
+                collated = self.collator(batch_samples)
+            # Publish both halves together, and only after collation worked,
+            # so a collator error cannot leave the stream half-pending.
+            self._pending_samples = batch_samples
+            self._pending_batch = collated
+            return collated, True
+        return None, False
+
+    def commit_batch(self) -> dict:
+        """Record the pending batch in the data state and return it.
+
+        Raises:
+            RuntimeError: if there is no pending batch.
+        """
+        if self._pending_batch is None or self._pending_samples is None:
+            raise RuntimeError("BatchedDataStream has no pending batch to commit.")
+        pending_batch = self._pending_batch
+        self.data_state.commit_batch(self._pending_samples)
+        self._pending_batch = None
+        self._pending_samples = None
+        return pending_batch
+
+    def discard_batch(self) -> None:
+        """Drop the pending batch without recording it in the data state.
+
+        Raises:
+            RuntimeError: if there is no pending batch.
+        """
+        if self._pending_batch is None or self._pending_samples is None:
+            raise RuntimeError("BatchedDataStream has no pending batch to discard.")
+        self._pending_batch = None
+        self._pending_samples = None
+
+    def __iter__(self):
+        """Yield committed batches until the stream ends or the budget is hit."""
+        while True:
+            batch, has_batch = self.peek_batch()
+            if not has_batch:
+                return
+            self.commit_batch()
+            yield batch
+            if self.data_state.should_stop():
+                return
diff --git a/src/dots_tts/models/__init__.py b/src/dots_tts/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..482ee568273d12024e3b0bff255e4150a15a50d0
--- /dev/null
+++ b/src/dots_tts/models/__init__.py
@@ -0,0 +1 @@
+"""Model families."""
diff --git a/src/dots_tts/models/dots_tts/__init__.py b/src/dots_tts/models/dots_tts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe9bc4ee612df1718db3054ac4383cf1f6c5a4b4
--- /dev/null
+++ b/src/dots_tts/models/dots_tts/__init__.py
@@ -0,0 +1 @@
+"""dots_tts model package."""
diff --git a/src/dots_tts/models/dots_tts/config.py b/src/dots_tts/models/dots_tts/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..56a1a61991693c091725ce9e6f5761f453f92f76
--- /dev/null
+++ b/src/dots_tts/models/dots_tts/config.py
@@ -0,0 +1,71 @@
+from __future__ import annotations
+
+from dots_tts.config.base import ConfigBase, StrictConfigBase
+from dots_tts.modules.vocoder.config import AudioVAEConfig
+
+
+class _EncoderConfig(ConfigBase):
+    """Transformer settings used for the ``PatchEncoder`` field of ``ModelConfig``.
+
+    Compared with ``_DiTConfig`` the defaults differ in ``num_layers``,
+    ``modulation`` and ``rotary_bias``, and two extra fields (``input_dim``,
+    ``causal``) are present.
+    """
+
+    num_layers: int = 6
+    num_heads: int = 16
+    hidden_size: int = 1024
+    ffn_hidden_size: int = 4096
+    modulation: bool = False
+    qkv_bias: bool = False
+    qk_norm: bool = False
+    attn_dropout: float = 0.0
+    dropout: float = 0.0
+    # Name of the normalization layer; how it is resolved is not shown here.
+    norm_layer: str = "LayerNorm"
+    alibi_bias: bool = False
+    rotary_bias: bool = False
+    rotary_theta: float | None = 10000
+    input_dim: int = 1024
+    causal: bool = True
+
+
+class _DiTConfig(ConfigBase):
+    """Transformer settings used for the ``DiT`` field of ``ModelConfig``.
+
+    ``DotsTtsCore`` passes this object to ``DiT`` as ``transformer_config``
+    and uses ``hidden_size`` as the flow-matching hidden width.
+    """
+
+    num_layers: int = 18
+    num_heads: int = 16
+    hidden_size: int = 1024
+    ffn_hidden_size: int = 4096
+    modulation: bool = True
+    qkv_bias: bool = False
+    qk_norm: bool = False
+    attn_dropout: float = 0.0
+    dropout: float = 0.0
+    # Name of the normalization layer; how it is resolved is not shown here.
+    norm_layer: str = "LayerNorm"
+    alibi_bias: bool = False
+    rotary_bias: bool = True
+    rotary_theta: float | None = 10000
+
+
+class LossConfig(StrictConfigBase):
+    """Per-term loss weights.
+
+    NOTE(review): the names suggest cross-entropy (``ce``), flow-matching
+    (``fm``) and end-of-sequence (``eos``) terms -- confirm against the loss
+    implementation, which is not part of this module.
+    """
+
+    ce_weight: float = 1.0
+    fm_weight: float = 1.0
+    eos_weight: float = 1.0
+
+
+class MeanFlowConfig(ConfigBase):
+    """Optional MeanFlow settings (the ``meanflow`` field of ``ModelConfig``)."""
+
+    # When True, DotsTtsCore runs in "meanflow" mode instead of "flow_matching".
+    enabled: bool = False
+    # In meanflow mode, selects whether the DiT itself is built in "meanflow" mode.
+    use_duration_embedding: bool = True
+
+
+class ModelConfig(ConfigBase):
+    """Top-level model configuration consumed by ``DotsTtsCore``.
+
+    ``latent_dim``, ``patch_size``, ``PatchEncoder``, ``DiT`` and ``vocoder``
+    have no defaults and must be supplied.
+    """
+
+    model_type: str = "dots_tts"
+    latent_dim: int
+    patch_size: int
+    # Passed to the DiT input preparation as ``cfg_droprate`` during training.
+    cfg_droprate: float = 0.2
+    PatchEncoder: _EncoderConfig
+    DiT: _DiTConfig
+    vocoder: AudioVAEConfig
+    # Sigma handed to FlowMatchingHelper.
+    fm_sigma: float = 0.0
+    # Threshold compared against a uniform draw to drop the x-vector condition.
+    xvec_drop_rate: float = 0.2
+    # Input width of the x-vector projection in DotsTtsCore.
+    campplus_embedding_size: int | None = 512
+    xvec_max_audio_seconds: float = 10.0
+    meanflow: MeanFlowConfig | None = None
+
+
+# Public names of this module; the underscore-prefixed transformer configs
+# are not exported.
+__all__ = [
+    "LossConfig",
+    "MeanFlowConfig",
+    "ModelConfig",
+]
diff --git a/src/dots_tts/models/dots_tts/core.py b/src/dots_tts/models/dots_tts/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0666c82b1ce4c8f9a7cc3099816b927dcbeaacc
--- /dev/null
+++ b/src/dots_tts/models/dots_tts/core.py
@@ -0,0 +1,910 @@
+import copy
+from dataclasses import dataclass
+from typing import Any, Callable
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from loguru import logger
+from torch.nn.utils.rnn import pad_sequence
+from torchdiffeq import odeint
+from transformers import Qwen2Config, Qwen2ForCausalLM
+
+from dots_tts.models.dots_tts.config import ModelConfig
+from dots_tts.modules.backbone.dit import DiT
+from dots_tts.modules.backbone.semantic_encoder import VAESemanticEncoder
+from dots_tts.utils.tokenizer import (
+ AUDIO_COMP_SPAN_TOKEN,
+ AUDIO_GEN_SPAN_TOKEN,
+ TEXT_COND_END_TOKEN,
+ require_token_id,
+)
+from dots_tts.utils.util import get_mask_from_lengths, mask_data
+
+
+@dataclass(frozen=True)
+class DotsTtsForwardOutput:
+    """Immutable bundle of tensors returned by ``DotsTtsCore.forward``."""
+
+    # Logits from the language model ([B, L, V] per the forward pass).
+    llm_logits: torch.Tensor
+    # Flow-matching prediction and the target it is trained against.
+    pred: torch.Tensor
+    target: torch.Tensor
+    # Output of the EOS head (last dimension 2), computed on detached hidden states.
+    eos_out: torch.Tensor
+
+
+class DotsTtsCore(nn.Module):
+ # region Module construction
+    def __init__(
+        self,
+        config: ModelConfig,
+        llm_config: Qwen2Config,
+        tokenizer=None,
+        *,
+        latent_stats_path,
+    ):
+        """Build the LLM, patch encoder, flow-matching head and EOS head.
+
+        Args:
+            config: Model configuration (latent size, patch size, DiT and
+                encoder settings, optional MeanFlow section).
+            llm_config: Qwen2 configuration; it is deep-copied and its
+                ``vocab_size`` is overridden with ``len(tokenizer)``.
+            tokenizer: Required despite the ``None`` default; must define the
+                audio span and text-condition-end tokens.
+            latent_stats_path: Forwarded to ``IOHelper``.
+
+        Raises:
+            RuntimeError: if ``tokenizer`` or ``llm_config`` is ``None``.
+        """
+        super().__init__()
+        self.config = config
+        self.fm_hidden_size = config.DiT.hidden_size
+        self.hidden_patch_size = 1
+        self.cfg_droprate = config.get("cfg_droprate", 0.2)
+        self.latent_patch_size = config.patch_size
+        self.latent_dim = config.latent_dim
+        self.xvec_dim = config.campplus_embedding_size
+        self.xvec_drop_rate = config.get("xvec_drop_rate", 0.2)
+
+        # Setup tokenizer
+        self.tokenizer = tokenizer
+        if self.tokenizer is None:
+            raise RuntimeError("Tokenizer must be provided before building the model.")
+        if llm_config is None:
+            raise RuntimeError("LLM config must be provided before building the model.")
+        self.pad_token_id = getattr(self.tokenizer, "pad_token_id", None)
+        self.audio_gen_span_id = require_token_id(self.tokenizer, AUDIO_GEN_SPAN_TOKEN)
+        self.audio_comp_span_id = require_token_id(
+            self.tokenizer, AUDIO_COMP_SPAN_TOKEN
+        )
+        self.text_cond_end_id = require_token_id(self.tokenizer, TEXT_COND_END_TOKEN)
+
+        # Setup LLM with language modeling head so we can obtain logits directly
+        # The caller's config object is left untouched; only the copy is resized.
+        llm_config = copy.deepcopy(llm_config)
+        llm_config.vocab_size = len(self.tokenizer)
+        self.llm = Qwen2ForCausalLM._from_config(
+            llm_config,
+            dtype=torch.float32,
+        )
+        self.llm_hidden_size = self.llm.config.hidden_size
+
+        # Maps latents (latent_dim) to vectors of the LLM hidden width.
+        self.patch_encoder = VAESemanticEncoder(
+            in_dim=self.latent_dim,
+            out_dim=self.llm_hidden_size,
+            config=config,
+        )
+
+        # Setup Flow matching related modules
+        self.hidden_proj = nn.Linear(self.llm_hidden_size, self.fm_hidden_size)
+        self.latent_proj = nn.Linear(self.latent_dim, self.fm_hidden_size)
+        self.coordinate_proj = nn.Linear(self.latent_dim, self.fm_hidden_size)
+        self.xvec_proj = nn.Sequential(
+            nn.Linear(self.xvec_dim, self.fm_hidden_size),
+            nn.LayerNorm(self.fm_hidden_size),
+        )
+        # Evaluates to config.meanflow in both cases (None when the section is absent).
+        self.meanflow_config = config.meanflow if config.meanflow is not None else None
+        self.mode = (
+            "meanflow"
+            if self.meanflow_config is not None and self.meanflow_config.enabled
+            else "flow_matching"
+        )
+        # The DiT is only built in "meanflow" mode when duration embedding is
+        # requested; otherwise it stays a plain flow-matching DiT.
+        dit_mode = (
+            "meanflow"
+            if self.mode == "meanflow"
+            and self.meanflow_config.use_duration_embedding
+            else "flow_matching"
+        )
+        self.velocity_field_predictor = DiT(
+            in_dim=self.fm_hidden_size,
+            out_dim=self.latent_dim,
+            transformer_config=config.DiT,
+            mode=dit_mode,
+        )
+
+        # Setup eos predictor
+        self.eos_proj = nn.Sequential(
+            nn.Linear(self.llm_hidden_size, self.llm_hidden_size),
+            nn.SiLU(),
+            nn.Linear(self.llm_hidden_size, 2),
+        )
+
+        # Helpers
+        self.fm_helper = FlowMatchingHelper(sigma=config.get("fm_sigma", 0.0))
+        self.causal_helper = CausalHelper()
+        self.io_helper = IOHelper(latent_stats_path=latent_stats_path)
+        self.audio_span_token_ids: list[int] = [
+            self.audio_gen_span_id,
+            self.audio_comp_span_id,
+        ]
+ # endregion Module construction
+
+ # region Training forward path
+    def forward(self, data: dict[str, Any]) -> DotsTtsForwardOutput:
+        """Run the training forward pass for the LLM, EOS head and DiT.
+
+        Args:
+            data: Batch dict.  Required keys: ``input_ids``,
+                ``input_ids_lengths``, ``input_span_mask`` and
+                ``output_span_mask``.  Optional keys: ``latents``,
+                ``latents_sampled``, ``latent_lengths`` and ``vocal_mask``.
+                ``xvector`` is read whenever ``output_span_mask`` selects at
+                least one position.
+
+        Returns:
+            ``DotsTtsForwardOutput`` with the LLM logits, the flow-matching
+            prediction/target pair and the EOS head output.
+
+        Raises:
+            RuntimeError: if span tokens are present without latents, if a
+                sample's span-token count differs from its latent patch
+                count, or if flow matching is requested without latents.
+        """
+        input_ids: torch.Tensor = data["input_ids"]
+        input_ids_lengths: torch.Tensor = data["input_ids_lengths"]
+        input_span_mask: torch.Tensor = data["input_span_mask"]
+        output_span_mask: torch.Tensor = data["output_span_mask"]
+        batch_size = input_ids.size(0)
+        device = input_ids.device
+
+        latents: torch.Tensor | None = data.get("latents")
+        latents_sampled: torch.Tensor | None = data.get("latents_sampled")
+        latent_lengths: torch.Tensor | None = data.get("latent_lengths")
+        has_latents = latents is not None or latents_sampled is not None
+
+        patch_embeddings: torch.Tensor | None
+        valid_patch_counts: torch.Tensor | None
+        if has_latents:
+            if latents_sampled is None:
+                latents_sampled = self.io_helper.sample_from_latent(latents)
+            # The patch encoder sees the sampled latents before normalization;
+            # the normalized version is what flow matching uses further down.
+            patch_embeddings = self.patch_encoder(
+                latents_sampled, x_lens=latent_lengths
+            )
+            valid_patch_counts = latent_lengths // self.latent_patch_size
+            latents_sampled = self.io_helper.normalize(latents_sampled)
+        else:
+            latents_sampled = None
+            patch_embeddings = None
+            valid_patch_counts = torch.zeros(
+                batch_size, dtype=torch.long, device=device
+            )
+
+        input_span_counts = input_span_mask.sum(dim=1)
+        if input_span_counts.sum() > 0 and patch_embeddings is None:
+            raise RuntimeError(
+                "Found audio span tokens but no latents provided to compute patch embeddings."
+            )
+
+        # Token embeddings with audio span replacement
+        inputs_embeds = self.llm.get_input_embeddings()(input_ids)
+        if patch_embeddings is not None:
+            # Clone before the in-place writes below.
+            inputs_embeds = inputs_embeds.clone()
+            patch_embeddings = patch_embeddings.to(inputs_embeds.dtype)
+            for b in range(batch_size):
+                span_num = input_span_counts[b].item()
+                if span_num == 0:
+                    continue
+                expected = valid_patch_counts[b].item()
+                if expected != span_num:
+                    raise RuntimeError(
+                        f"Mismatch between span tokens ({span_num}) and latent patches ({expected}) for sample {b}."
+                    )
+                indices = input_span_mask[b].nonzero(as_tuple=False).squeeze(-1)
+                inputs_embeds[b, indices, :] = patch_embeddings[b, :span_num, :]
+
+        # LLM forward pass to obtain logits & hidden states
+        # Only the sequence (padding) mask from the helper is passed to the LLM.
+        _llm_attn_mask, llm_seq_mask, _ = self.causal_helper.create_causal_mask_and_pos(
+            seq_lens=input_ids_lengths, max_len=input_ids.size(1)
+        )
+        llm_outputs = self.llm(
+            inputs_embeds=inputs_embeds,
+            attention_mask=llm_seq_mask.long(),
+            use_cache=False,
+            output_hidden_states=True,
+            return_dict=True,
+        )
+        llm_logits = llm_outputs.logits  # [B, L, V]
+        llm_hidden = llm_outputs.hidden_states[-1]  # [B, L, H]
+
+        # eos prediction, before cfg masking
+        # detach(): the EOS head does not back-propagate into the LLM.
+        eos = self.eos_proj(llm_hidden.detach())
+
+        # Flow matching forward
+        total_patches = int(output_span_mask.sum().item())
+        if total_patches > 0 and latents_sampled is None:
+            raise RuntimeError("Flow matching requested but latents are missing.")
+        if total_patches > 0:
+            xvec_cond = self.xvec_proj(data["xvector"])
+            vocal_mask = data.get("vocal_mask")
+            if vocal_mask is None:
+                vocal_mask = torch.ones((batch_size,), device=device, dtype=torch.bool)
+            xvec_drop_mask = (
+                torch.empty((batch_size,), device=device, dtype=torch.float32).uniform_(
+                    0, 1
+                )
+                < self.xvec_drop_rate
+            )
+            # A sample is only selected for dropping where vocal_mask is True.
+            # NOTE(review): presumably mask_data blanks the selected rows -- confirm.
+            xvec_drop_mask = xvec_drop_mask & vocal_mask
+            xvec_cond = mask_data(xvec_cond, xvec_drop_mask)
+
+            # LLM hidden states at output-span positions, input embeddings elsewhere.
+            hiddens_for_fm = torch.where(
+                output_span_mask.unsqueeze(-1), llm_hidden, inputs_embeds
+            )
+
+            # Prepare DiT inputs
+            (
+                fm_seq,
+                target,
+                fm_attn_mask,
+                fm_seq_mask,
+                fm_pos_ids,
+                times,
+                fm_prefix_lengths,
+                fm_gen_lengths,
+                fm_gen_patch_size,
+            ) = self.io_helper.prepare_inputs_for_dit(
+                hiddens=hiddens_for_fm,
+                hidden_lens=input_ids_lengths,
+                latents=latents_sampled,
+                latent_lens=latent_lengths,
+                hidden_proj=self.hidden_proj,
+                latent_proj=self.latent_proj,
+                noisy_proj=self.coordinate_proj,
+                span_mask=output_span_mask,
+                hidden_patch_size=self.hidden_patch_size,
+                latent_patch_size=self.latent_patch_size,
+                fm_helper=self.fm_helper,
+                cfg_droprate=self.cfg_droprate,
+            )
+
+            # Predict velocity field
+            vt = self.velocity_field_predictor(
+                x=fm_seq,
+                timesteps=times,
+                pos_ids=fm_pos_ids,
+                mask=fm_seq_mask,
+                attn_mask=fm_attn_mask,
+                return_hidden_stats=False,
+                g_cond=xvec_cond,
+            )
+
+            # Get predictions and targets
+            pred = self.io_helper.get_dit_outputs(
+                pred_v=vt,
+                fm_prefix_lengths=fm_prefix_lengths,
+                fm_gen_lengths=fm_gen_lengths,
+                fm_gen_patch_size=fm_gen_patch_size,
+                latent_patch_size=self.latent_patch_size,
+            )
+        else:
+            # Dummy forward for velocity_field_predictor to keep gradients connected in DDP
+            dummy_length = self.latent_patch_size
+            dummy_seq_h = llm_hidden.new_zeros((1, dummy_length, self.llm_hidden_size))
+            dummy_seq_h = self.hidden_proj(dummy_seq_h) * 0.0  # dummy op for ddp
+            dummy_seq_l = llm_hidden.new_zeros((1, dummy_length, self.latent_dim))
+            dummy_seq_l = self.latent_proj(dummy_seq_l) * 0.0  # dummy op for ddp
+            dummy_seq_c = llm_hidden.new_zeros((1, dummy_length, self.latent_dim))
+            dummy_seq_c = self.coordinate_proj(dummy_seq_c) * 0.0  # dummy op for ddp
+            dummy_seq = dummy_seq_h + dummy_seq_l + dummy_seq_c
+            dummy_times = torch.zeros((1,), device=device, dtype=torch.float32)
+            dummy_attn_mask = torch.ones(
+                (1, dummy_length, dummy_length), device=device, dtype=torch.bool
+            )
+            dummy_out = self.velocity_field_predictor(
+                x=dummy_seq,
+                timesteps=dummy_times,
+                attn_mask=dummy_attn_mask,
+            )
+            pred = dummy_out[:, -self.latent_patch_size :, :]
+            # The target equals the prediction, so pred - target is exactly zero.
+            target = pred.detach()
+
+        return DotsTtsForwardOutput(
+            llm_logits=llm_logits,
+            pred=pred,
+            target=target,
+            eos_out=eos,
+        )
+ # endregion Training forward path
+
+ # region Autoregressive and flow-matching inference steps
+ @torch.no_grad()
+    def fm_solver_step(
+        self,
+        t: torch.Tensor,
+        z: torch.Tensor,
+        *,
+        input_sequence: torch.Tensor,
+        cfg_sequence: torch.Tensor,
+        attn_mask: torch.Tensor,
+        pos_ids: torch.Tensor | None,
+        hidden_size: int,
+        patch_size: int,
+        g_cond: torch.Tensor | None,
+        guidance_scale: torch.Tensor | float,
+    ) -> torch.Tensor:
+        """Evaluate the guided velocity for the ODE solver at time ``t``.
+
+        The projected coordinate ``z`` is written into the last ``patch_size``
+        positions of both the conditional (``input_sequence``) and the CFG
+        (``cfg_sequence``) sequence.  Both are evaluated in one stacked call,
+        the CFG branch with a zeroed ``g_cond``, and the result is
+        ``v_cond + guidance_scale * (v_cond - v_cfg)`` over the latent slots.
+
+        Raises:
+            ValueError: if the two sequences differ in shape or are shorter
+                than ``patch_size``.
+        """
+        batch = input_sequence.size(0)
+        seq_len = input_sequence.size(1)
+        if cfg_sequence.shape != input_sequence.shape:
+            raise ValueError(
+                "FM input_sequence and cfg_sequence must share the same shape."
+            )
+        if patch_size > seq_len:
+            raise ValueError(
+                "FM input sequence must reserve at least one latent patch slot."
+            )
+        tail = slice(seq_len - patch_size, None)
+        projected = self.coordinate_proj(z)
+
+        def with_latent(base: torch.Tensor) -> torch.Tensor:
+            # Copy so the caller's sequence is never modified in place.
+            filled = base.clone()
+            filled[:, tail] = projected
+            return filled
+
+        cond_seq = with_latent(input_sequence)
+        uncond_seq = with_latent(cfg_sequence)
+
+        if g_cond is None:
+            stacked_cond = None
+        else:
+            cond_vec = g_cond.to(device=cond_seq.device, dtype=cond_seq.dtype)
+            stacked_cond = torch.cat([cond_vec, torch.zeros_like(cond_vec)], dim=0)
+
+        velocity = self.velocity_field_predictor(
+            x=torch.cat([cond_seq, uncond_seq], dim=0),
+            timesteps=t.reshape(1).repeat(2),
+            attn_mask=attn_mask,
+            pos_ids=pos_ids,
+            g_cond=stacked_cond,
+            hidden_size=patch_size * 2 + hidden_size,
+            patch_size=patch_size + 1,
+        )
+        latent_velocity = velocity[:, tail]
+        cond_v = latent_velocity[:batch]
+        uncond_v = latent_velocity[batch:]
+        if torch.is_tensor(guidance_scale):
+            scale = guidance_scale.to(device=cond_v.device, dtype=cond_v.dtype)
+        else:
+            scale = cond_v.new_tensor(float(guidance_scale))
+        return cond_v + scale * (cond_v - uncond_v)
+
+ @torch.no_grad()
+    def step_llm(
+        self,
+        inputs_embeds: torch.Tensor | None = None,
+        input_ids: torch.Tensor | None = None,
+        past_key_values: Any | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any | None]:
+        """Run one cached LLM step from either embeddings or token ids.
+
+        Returns:
+            ``(inputs_embeds, last_hidden_state, logits, past_key_values)``,
+            where ``inputs_embeds`` is the embedding actually fed to the LLM.
+
+        Raises:
+            ValueError: unless exactly one of ``inputs_embeds`` and
+                ``input_ids`` is given.
+        """
+        has_embeds = inputs_embeds is not None
+        has_ids = input_ids is not None
+        if has_embeds == has_ids:
+            raise ValueError(
+                "Exactly one of inputs_embeds or input_ids must be provided to step_llm()."
+            )
+
+        if not has_embeds:
+            embed_tokens = self.llm.get_input_embeddings()
+            inputs_embeds = embed_tokens(input_ids)
+
+        llm_out = self.llm(
+            inputs_embeds=inputs_embeds,
+            past_key_values=past_key_values,
+            use_cache=True,
+            output_hidden_states=True,
+            return_dict=True,
+        )
+        return (
+            inputs_embeds,
+            llm_out.hidden_states[-1],
+            llm_out.logits,
+            llm_out.past_key_values,
+        )
+
+ @torch.no_grad()
+    def _meanflow_step_fm(
+        self,
+        *,
+        input_sequence: torch.Tensor,
+        attn_mask: torch.Tensor,
+        pos_ids: torch.Tensor | None,
+        patch_size: int,
+        g_cond: torch.Tensor | None = None,
+        nfe: int = 2,
+        solver_step: Callable[..., torch.Tensor] | None = None,
+    ) -> torch.Tensor:
+        """Integrate from Gaussian noise over ``nfe`` equal steps on [0, 1].
+
+        Each step calls ``solver_step`` (``self.meanflow_solver_step`` when
+        none is given) with the current time and step width, and keeps a
+        clone of what it returns.
+
+        Raises:
+            ValueError: if ``nfe`` is not positive.
+        """
+        if nfe <= 0:
+            raise ValueError(f"MeanFlow nfe must be positive, got {nfe}.")
+        batch = input_sequence.size(0)
+        tensor_opts = {"device": input_sequence.device, "dtype": input_sequence.dtype}
+        if solver_step is None:
+            solver_step = self.meanflow_solver_step
+
+        state = torch.randn((batch, patch_size, self.latent_dim), **tensor_opts)
+        grid = torch.linspace(0.0, 1.0, nfe + 1, **tensor_opts)
+
+        # Walk consecutive grid points: (t_0, t_1), (t_1, t_2), ...
+        for t_now, t_next in zip(grid[:-1], grid[1:]):
+            state = solver_step(
+                state,
+                t=t_now.expand(batch),
+                dt=(t_next - t_now).expand(batch),
+                input_sequence=input_sequence,
+                attn_mask=attn_mask,
+                pos_ids=pos_ids,
+                patch_size=patch_size,
+                g_cond=g_cond,
+            ).clone()
+        return state
+
+ @torch.no_grad()
+    def meanflow_solver_step(
+        self,
+        z: torch.Tensor,
+        *,
+        t: torch.Tensor,
+        dt: torch.Tensor,
+        input_sequence: torch.Tensor,
+        attn_mask: torch.Tensor,
+        pos_ids: torch.Tensor | None,
+        patch_size: int,
+        g_cond: torch.Tensor | None,
+    ) -> torch.Tensor:
+        """Advance ``z`` by one MeanFlow step of width ``dt``.
+
+        The projected ``z`` replaces the last ``patch_size`` positions of a
+        copy of ``input_sequence``; the predictor output at those positions
+        is scaled by ``dt`` and added to ``z``.
+
+        Raises:
+            ValueError: if ``input_sequence`` is shorter than ``patch_size``.
+        """
+        seq_len = input_sequence.size(1)
+        if patch_size > seq_len:
+            raise ValueError(
+                "MeanFlow input sequence must reserve at least one latent patch slot."
+            )
+        start = seq_len - patch_size
+        projected = self.coordinate_proj(z)
+        conditioned = input_sequence.clone()
+        conditioned[:, start:] = projected
+        predicted = self.velocity_field_predictor(
+            x=conditioned,
+            timesteps=t,
+            duration=dt,
+            attn_mask=attn_mask,
+            pos_ids=pos_ids,
+            g_cond=g_cond,
+        )
+        step_width = dt.view(-1, 1, 1)
+        return z + predicted[:, start:] * step_width
+
+ @torch.no_grad()
+    def _flow_matching_step_fm(
+        self,
+        *,
+        input_sequence: torch.Tensor,
+        cfg_sequence: torch.Tensor,
+        attn_mask: torch.Tensor,
+        pos_ids: torch.Tensor | None,
+        hidden_size: int,
+        patch_size: int,
+        g_cond: torch.Tensor | None = None,
+        ode_method: str = "euler",
+        num_steps: int = 10,
+        guidance_scale: float = 3.0,
+        solver_step: Callable[..., torch.Tensor] | None = None,
+    ) -> torch.Tensor:
+        """Sample one latent patch by solving the flow ODE from t=0 to t=1.
+
+        Args:
+            ode_method: ``odeint`` method name.  ``"euler"``, ``"midpoint"``
+                and ``"rk4"`` run with a fixed step of ``1 / num_steps``; any
+                other method runs adaptively and ignores ``num_steps``.
+            solver_step: Velocity callback; defaults to ``self.fm_solver_step``.
+
+        Returns:
+            The final point of the trajectory, shaped
+            ``(batch, patch_size, latent_dim)``.
+
+        Raises:
+            ValueError: if a fixed-step method is used with ``num_steps <= 0``.
+        """
+        if ode_method in ["euler", "midpoint", "rk4"]:  # fixed step size methods
+            # Reject the value up front: zero would divide by zero and a
+            # negative count would yield a negative step size.
+            if num_steps <= 0:
+                raise ValueError(
+                    f"Flow matching num_steps must be positive, got {num_steps}."
+                )
+            options = {"step_size": 1.0 / num_steps}
+        else:
+            logger.warning(
+                "Using adaptive step size ODE solver for FM, NFE is not guaranteed: "
+                "ode_method={}",
+                ode_method,
+            )
+            options = {}
+
+        batch_size = input_sequence.size(0)
+        solver_step = self.fm_solver_step if solver_step is None else solver_step
+        guidance_scale_tensor = input_sequence.new_tensor(float(guidance_scale))
+
+        # Velocity function handed to the ODE solver
+        def solver(t, z):
+            return solver_step(
+                t,
+                z,
+                input_sequence=input_sequence,
+                cfg_sequence=cfg_sequence,
+                attn_mask=attn_mask,
+                pos_ids=pos_ids,
+                hidden_size=hidden_size,
+                patch_size=patch_size,
+                g_cond=g_cond,
+                guidance_scale=guidance_scale_tensor,
+            )
+
+        # Prepare noise as initial coordinate
+        noise = torch.randn(
+            (batch_size, patch_size, self.latent_dim),
+            dtype=input_sequence.dtype,
+            device=input_sequence.device,
+        )
+        # Solve
+        times = torch.tensor(
+            [0.0, 1.0], dtype=input_sequence.dtype, device=input_sequence.device
+        )
+        trajectory = odeint(
+            func=solver,
+            y0=noise,
+            t=times,
+            atol=1e-5,
+            rtol=1e-5,
+            method=ode_method,
+            options=options,
+        )
+        return trajectory[-1]
+
+ @torch.no_grad()
+    def step_fm(
+        self,
+        input_sequence: torch.Tensor,
+        cfg_sequence: torch.Tensor,
+        attn_mask: torch.Tensor,
+        pos_ids: torch.Tensor | None,
+        hidden_size: int,
+        patch_size: int,
+        g_cond: torch.Tensor | None = None,
+        ode_method: str = "euler",
+        num_steps: int = 10,
+        guidance_scale: float = 3.0,
+        solver_step: Callable[..., torch.Tensor] | None = None,
+    ) -> torch.Tensor:
+        """Generate one latent patch using the sampler selected by ``self.mode``.
+
+        In ``"meanflow"`` mode ``num_steps`` is forwarded as ``nfe`` and the
+        CFG/ODE arguments (``cfg_sequence``, ``hidden_size``, ``ode_method``,
+        ``guidance_scale``) are not used; otherwise everything is passed to
+        the flow-matching sampler.
+        """
+        shared = {
+            "input_sequence": input_sequence,
+            "attn_mask": attn_mask,
+            "pos_ids": pos_ids,
+            "patch_size": patch_size,
+            "g_cond": g_cond,
+            "solver_step": solver_step,
+        }
+        if self.mode == "meanflow":
+            return self._meanflow_step_fm(nfe=num_steps, **shared)
+
+        return self._flow_matching_step_fm(
+            cfg_sequence=cfg_sequence,
+            hidden_size=hidden_size,
+            ode_method=ode_method,
+            num_steps=num_steps,
+            guidance_scale=guidance_scale,
+            **shared,
+        )
+ # endregion Autoregressive and flow-matching inference steps
+
+
+class FlowMatchingHelper:
+    """Flow-matching path helper: builds x_t and the target velocity u_t.
+
+    Given a target ``x1`` and noise ``x0`` the path is
+    ``x_t = t * x1 + (1 - (1 - sigma) * t) * x0`` with velocity
+    ``u_t = x1 - (1 - sigma) * x0``.
+    ref: Flow matching for generative modeling, Lipman
+    """
+
+    def __init__(self, sigma=1e-5):
+        self.sigma = sigma
+
+    def compute_mu_t(self, x1, t):
+        """Mean of the path at time ``t``."""
+        return t * x1
+
+    def compute_sigma_t(self, t):
+        """Noise scale of the path at time ``t``."""
+        shrink = 1 - self.sigma
+        return 1 - shrink * t
+
+    def sample_x_t(self, x0, x1, t):
+        """Point on the path: mean plus scaled noise."""
+        return self.compute_mu_t(x1, t) + self.compute_sigma_t(t) * x0
+
+    def compute_u_t(self, x0, x1):
+        """Target velocity for the pair ``(x0, x1)``."""
+        shrink = 1 - self.sigma
+        return x1 - shrink * x0
+
+    def compute_xt_ut(self, x1, t=None, x0=None):
+        """Return ``(x_t, u_t, t)`` for target ``x1``.
+
+        Missing ``x0`` is drawn from a standard normal shaped like ``x1``;
+        missing ``t`` is drawn uniformly in [0, 1) with one value per row of
+        ``x1``. The returned ``t`` is the un-reshaped time tensor.
+        """
+        # Noise is sampled before time so the RNG consumption order is fixed.
+        if x0 is None:
+            x0 = torch.randn_like(x1, device=x1.device)
+        if t is None:
+            t = torch.rand(x1.size(0), dtype=x1.dtype, device=x1.device)
+        # Append singleton dims so t broadcasts over the non-batch dims of x1.
+        t_broadcast = t.reshape((-1,) + (1,) * (x1.dim() - 1))
+        return self.sample_x_t(x0, x1, t_broadcast), self.compute_u_t(x0, x1), t
+
+
+class CausalHelper:
+ """Builds boolean attention masks and position ids for causal and chunked-causal sequences."""
+
+ def create_causal_mask_and_pos(self, seq_lens, max_len):
+ """Return ``(attn_mask, seq_mask, None)`` for plain causal attention.
+
+ ``attn_mask`` is the length mask (with a singleton dim inserted at 1)
+ AND-ed with a ``max_len x max_len`` lower-triangular mask that includes
+ the diagonal. The third element (position ids) is always ``None``.
+ """
+ # presumably get_mask_from_lengths returns a (batch, max_len) bool mask — confirm in its definition
+ seq_mask = get_mask_from_lengths(seq_lens, max_len=max_len).unsqueeze(1)
+ causal_mask = (
+ torch.ones((max_len, max_len), device=seq_lens.device).triu(1).bool()
+ )
+ # Invert the strict upper triangle -> True on and below the diagonal.
+ causal_mask = ~causal_mask.unsqueeze(0)
+ attn_mask = seq_mask & causal_mask
+ return attn_mask, seq_mask.squeeze(1), None
+
+ def create_causal_chunk_mask_and_pos(
+ self,
+ batch_size,
+ C_lens,
+ Z_lens,
+ span_mask,
+ patch_size=8,
+ ):
+ """Build the mask and position ids for a context part C followed by a patch part Z.
+
+ Per sample the mask has three filled regions: C->C is causal, Z->Z is
+ block-diagonal with ``patch_size`` blocks, and Z->C lets each Z patch see
+ the C positions strictly before the index of its matching True entry in
+ ``span_mask``. The C->Z region stays False.
+
+ Returns:
+ ``(attn_mask, seq_mask, pos_ids)`` where ``attn_mask`` is a bool tensor of
+ shape ``(batch_size, T, T)`` with ``T = (C_lens + Z_lens).max()``, and
+ ``pos_ids`` are float32 position ids right-padded with 0.0.
+
+ Raises:
+ AssertionError: if a sample's Z length is not a multiple of ``patch_size``.
+ """
+ device = C_lens.device
+ total_lens = C_lens + Z_lens
+ attn_mask = torch.zeros(
+ (batch_size, total_lens.max(), total_lens.max()),
+ device=device,
+ dtype=torch.bool,
+ )
+ pos_ids = []
+ # | C2C | |
+ # | Z2C | Z2Z |
+ # NOTE(review): indentation is collapsed in this view; the loop body is read as
+ # running through `pos_ids.append(...)` below — confirm against the real file.
+ for i in range(batch_size):
+ C_len = C_lens[i]
+ Z_len = Z_lens[i]
+
+ # C2C parts are standard causal attention
+ attn_mask[i, :C_len, :C_len] = (
+ torch.ones((C_len, C_len), device=device, dtype=torch.bool)
+ .triu(1)
+ .logical_not()
+ )
+ # Position ids in C parts are 0, 1, 2, ..., n
+ c_pos = torch.arange(C_len, device=device, dtype=torch.float32)
+
+ # Z2Z parts are block diag attention
+ assert Z_len % patch_size == 0, "Z_len must be multiple of patch_size"
+ attn_mask[i, C_len : C_len + Z_len, C_len : C_len + Z_len] = (
+ torch.block_diag(
+ *[
+ torch.ones(
+ (patch_size, patch_size), device=device, dtype=torch.bool
+ )
+ ]
+ * (Z_len // patch_size)
+ )
+ )
+
+ # Z2C parts is full attention before current patch latents
+ # build according to span_mask
+ j_indices = torch.arange(Z_len, device=device)
+ patch_indices = j_indices // patch_size
+ # Index (within C) of the True span_mask entry that corresponds to each Z position's patch.
+ patch_in_c_indices = torch.where(span_mask[i])[0][patch_indices]
+ # Row C_len + j may attend column c exactly when c < patch_in_c_indices[j].
+ attn_mask[
+ i,
+ C_len + j_indices.unsqueeze(1),
+ torch.arange(C_len, device=device).unsqueeze(0),
+ ] = torch.arange(C_len, device=device).unsqueeze(
+ 0
+ ) < patch_in_c_indices.unsqueeze(1)
+ # Position ids in Z parts start from current patch latents index in C parts
+ z_pos = (patch_in_c_indices + j_indices % patch_size).to(torch.float32)
+ pos_ids.append(torch.cat([c_pos, z_pos]))
+ seq_mask = get_mask_from_lengths(total_lens, max_len=total_lens.max().item())
+ pos_ids = pad_sequence(pos_ids, batch_first=True, padding_value=0.0).to(
+ C_lens.device
+ )
+ return attn_mask, seq_mask, pos_ids
+
+
+class IOHelper:
+ """Latent (de)normalization plus packing/unpacking of sequences fed to the flow-matching network."""
+
+ def __init__(self, latent_stats_path=None):
+ """Load global latent mean/variance from ``latent_stats_path`` if given; otherwise normalization is a no-op."""
+ if latent_stats_path is not None:
+ # NOTE(review): weights_only=False unpickles arbitrary objects, so this file must be
+ # trusted. Switching to weights_only=True is only safe if the stats file holds plain
+ # tensors/containers — confirm the saved format before changing.
+ latent_stats = torch.load(latent_stats_path, weights_only=False)
+ self.global_mean = torch.as_tensor(latent_stats["mean"])
+ self.global_var = torch.as_tensor(latent_stats["var"])
+ else:
+ self.global_mean = None
+ self.global_var = None
+
+ def normalize(self, x):
+ """Return ``(x - mean) / sqrt(var)`` when stats are loaded, else ``x`` unchanged."""
+ if self.global_mean is not None and self.global_var is not None:
+ x = (x - self.global_mean.to(x.device)) / torch.sqrt(
+ self.global_var.to(x.device)
+ )
+ return x
+
+ def denormalize(self, x):
+ """Inverse of ``normalize``: ``x * sqrt(var) + mean`` when stats are loaded, else ``x`` unchanged."""
+ if self.global_mean is not None and self.global_var is not None:
+ x = x * torch.sqrt(self.global_var.to(x.device)) + self.global_mean.to(
+ x.device
+ )
+ return x
+
+ @staticmethod
+ def sample_from_latent(latent):
+ """Draw ``mean + eps * exp(log_std)`` from a tensor holding mean and log-std stacked along dim 1.
+
+ The result has dims 1 and 2 swapped relative to the sampled tensor.
+ """
+ mean, log_std = latent.chunk(2, 1)
+ z = mean + torch.randn_like(mean) * torch.exp(log_std)
+ return z.transpose(1, 2)
+
+ @staticmethod
+ def prepare_inputs_for_dit(
+ hiddens,
+ hidden_lens,
+ latents,
+ latent_lens,
+ hidden_proj,
+ latent_proj,
+ noisy_proj,
+ span_mask,
+ hidden_patch_size,
+ latent_patch_size,
+ fm_helper,
+ cfg_droprate=-1,
+ ):
+ """Pack hidden states, history latents and noised latents into one flow-matching sequence.
+
+ Per sample the sequence is a prefix (each span hidden followed by its projected
+ history-latent patch) followed by a generation part (each span hidden followed by
+ its projected noisy-latent patch).
+
+ Args:
+ hiddens: 3-D tensor (unpacked as ``B, _, _``); only rows selected by ``span_mask`` are kept.
+ hidden_lens: overwritten inside with the per-sample count of selected rows.
+ latents, latent_lens: target latents and their valid lengths.
+ hidden_proj, latent_proj, noisy_proj: callables projecting hiddens, history
+ latents and noisy latents to the flow-matching width.
+ span_mask: per-sample boolean selector over ``hiddens`` positions.
+ hidden_patch_size: must be 1.
+ latent_patch_size: number of latent frames per patch.
+ fm_helper: object providing ``compute_xt_ut``.
+ cfg_droprate: per-sample probability threshold for the mask passed to ``mask_data``;
+ the default -1 never selects a sample.
+
+ Returns:
+ Tuple ``(fm_seq, fm_target, fm_attn_mask, fm_seq_mask, fm_pos_ids, times,
+ fm_prefix_lengths, fm_gen_lengths, fm_gen_patch_size)``; the two length
+ tensors are returned with a trailing singleton dim.
+
+ Raises:
+ AssertionError: if ``hidden_patch_size != 1`` or the history-latent length is
+ not compatible with ``latent_patch_size``.
+ """
+ assert hidden_patch_size == 1, "Hidden patch size > 1 is not supported."
+
+ B, _, _, device = *hiddens.shape, hiddens.device
+
+ # Gather span hidden states for flow matching using span_mask
+ span_hidden_list = []
+ for b in range(B):
+ indices = span_mask[b].nonzero(as_tuple=False).squeeze(-1)
+ span_hidden_list.append(hiddens[b, indices, :])
+ hiddens = pad_sequence(span_hidden_list, batch_first=True, padding_value=0.0)
+ hidden_lens = torch.tensor(
+ [t.size(0) for t in span_hidden_list], device=device, dtype=torch.long
+ )
+
+ # Update span_mask to be all True for the new lengths
+ max_len = hiddens.size(1)
+ span_mask = torch.arange(max_len, device=device).expand(
+ B, max_len
+ ) < hidden_lens.unsqueeze(1)
+
+ # Prepare history latents
+ history_latents = latent_proj(latents)
+ fm_dim = history_latents.shape[-1]
+ assert (latent_patch_size * history_latents.size(1) % latents.size(1)) == 0
+ # Patch size measured in projected history frames (latent_proj may change the time length).
+ latent_history_patch_size = (
+ latent_patch_size * history_latents.size(1) // latents.size(1)
+ )
+
+ # Prepare llm hidden with cfg masking
+ cfg_mask = (
+ torch.empty((B,), dtype=torch.float, device=latents.device).uniform_(0, 1)
+ < cfg_droprate
+ )
+ hiddens = hidden_proj(mask_data(hiddens, cfg_mask))
+
+ # Prepare noise latents
+ xt, ut, times = fm_helper.compute_xt_ut(latents)
+ projected_noise = noisy_proj(xt)
+
+ # Initialize empty fm_seq
+ hist_chunk_size = hidden_patch_size + latent_history_patch_size
+ valid_patch_counts = latent_lens // latent_patch_size
+ fm_prefix_lengths = hidden_lens + valid_patch_counts * (
+ hist_chunk_size - hidden_patch_size
+ )
+ fm_gen_lengths = latent_lens + valid_patch_counts * hidden_patch_size
+ fm_gen_patch_size = hidden_patch_size + latent_patch_size
+ fm_seq_lengths = fm_prefix_lengths + fm_gen_lengths
+ fm_seq = torch.zeros(
+ (B, fm_seq_lengths.max().item(), fm_dim),
+ dtype=history_latents.dtype,
+ device=device,
+ )
+ fm_target = []
+ patch_context_lengths = []
+ history_latent_span_mask = torch.zeros(
+ (B, fm_seq_lengths.max().item()), dtype=torch.bool, device=device
+ ) # to mark start positions of each history latents
+
+ # Fill fm_seq
+ # NOTE(review): indentation is collapsed in this view; Steps 1-6 are read as the
+ # body of this per-sample loop — confirm against the real file.
+ for b in range(B):
+ # Step 1: Interleave hiddens at span positions with patched_latents
+ interleaved = []
+ span_mask_b = span_mask[b, : hidden_lens[b]]
+ # The reshape requires the number of selected hiddens to equal valid_patch_counts[b].
+ interleaved.append(
+ hiddens[b, : hidden_lens[b]][span_mask_b].reshape(
+ valid_patch_counts[b], hidden_patch_size, fm_dim
+ )
+ )
+ interleaved.append(
+ history_latents[
+ b, : valid_patch_counts[b] * latent_history_patch_size, :
+ ].reshape(valid_patch_counts[b], latent_history_patch_size, fm_dim)
+ )
+ interleaved = torch.cat(interleaved, dim=1)
+ interleaved = rearrange(
+ interleaved, "n h d -> (n h) d"
+ ) # [num_spans*hist_chunk_size, D]
+
+ # Step 2: Build mapping from input positions to fm positions
+ position_increment = torch.where(
+ span_mask_b, hist_chunk_size, 1
+ ) # span->hist_chunk_size, non-span->1
+ # Exclusive cumulative sum: start offset of every input position inside fm_seq.
+ fm_seq_positions = (
+ torch.cumsum(position_increment, dim=0) - position_increment
+ )
+
+ # Step 3: Scatter non-span hiddens
+ # span_mask_b was rebuilt above as a pure length mask, so it is all True over
+ # [:hidden_lens[b]] and this selection is empty.
+ non_span_mask = ~span_mask_b
+ non_span_indices = fm_seq_positions[non_span_mask] # [num_non_spans]
+ fm_seq[b, non_span_indices, :] = hiddens[b, : hidden_lens[b]][
+ non_span_mask, :
+ ]
+
+ # Step 4: Scatter interleaved span tokens
+ span_indices = fm_seq_positions[span_mask_b] # [num_spans]
+ span_indices_expanded = torch.stack(
+ [span_indices + i for i in range(hist_chunk_size)], dim=1
+ ) # [num_spans, hist_chunk_size]
+ span_indices_flat = span_indices_expanded.reshape(
+ -1
+ ) # [num_spans*hist_chunk_size]
+ fm_seq[b, span_indices_flat, :] = interleaved
+ history_latent_span_mask[b, span_indices] = True
+ # Collected but not returned or read again within this function.
+ patch_context_lengths.append(span_indices.clone())
+
+ # Step 5: Fill with noise latents at the end
+ noise_part = []
+ span_mask_b = span_mask[b, : hidden_lens[b]]
+ noise_part.append(
+ hiddens[b, : hidden_lens[b]][span_mask_b].reshape(
+ valid_patch_counts[b], hidden_patch_size, fm_dim
+ )
+ )
+ noise_part.append(
+ projected_noise[b, : latent_lens[b], :].reshape(
+ valid_patch_counts[b], latent_patch_size, fm_dim
+ )
+ )
+ noise_part = torch.cat(noise_part, dim=1)
+ noise_part = rearrange(noise_part, "n h d -> (n h) d")
+ # The generation part starts right after the last prefix element.
+ noise_start = fm_seq_positions[-1] + position_increment[-1]
+ noise_end = noise_start + fm_gen_lengths[b]
+ fm_seq[b, noise_start:noise_end, :] = noise_part
+
+ # Step 6: prepare fm_target
+ ut_b = ut[b, : latent_lens[b], :]
+ fm_target.append(rearrange(ut_b, "(n p) d -> n p d", p=latent_patch_size))
+
+ # Construct fm_attn_mask and fm_pos_ids
+ fm_attn_mask, fm_seq_mask, fm_pos_ids = (
+ CausalHelper().create_causal_chunk_mask_and_pos(
+ batch_size=B,
+ C_lens=fm_prefix_lengths,
+ Z_lens=fm_gen_lengths,
+ span_mask=history_latent_span_mask,
+ patch_size=fm_gen_patch_size,
+ )
+ )
+ fm_prefix_lengths = fm_prefix_lengths.unsqueeze(1)
+ fm_gen_lengths = fm_gen_lengths.unsqueeze(1)
+ # All samples' target patches are stacked along dim 0.
+ fm_target = torch.cat(fm_target, dim=0)
+ results = [
+ fm_seq,
+ fm_target,
+ fm_attn_mask,
+ fm_seq_mask,
+ fm_pos_ids,
+ times,
+ fm_prefix_lengths,
+ fm_gen_lengths,
+ fm_gen_patch_size,
+ ]
+ return tuple(results)
+
+ @staticmethod
+ def get_dit_outputs(
+ pred_v,
+ fm_prefix_lengths,
+ fm_gen_lengths,
+ fm_gen_patch_size,
+ latent_patch_size,
+ ):
+ """Extract the latent predictions from the generation part of each packed sequence.
+
+ ``fm_prefix_lengths`` and ``fm_gen_lengths`` are 2-D ``(B, P)``. For each entry the
+ slice following the prefix is cut out of ``pred_v``, split into patches of
+ ``fm_gen_patch_size`` and reduced to the last ``latent_patch_size`` positions of
+ each patch. All patches are concatenated along dim 0.
+ """
+ B, P = fm_prefix_lengths.shape
+ fm_pred = []
+ # NOTE(review): nesting read from syntax — the append and the p_offset update are
+ # taken to be inside the inner loop; confirm against the real file.
+ for b in range(B):
+ p_offset = 0
+ for p in range(P):
+ latents_b = pred_v[
+ b,
+ p_offset + fm_prefix_lengths[b][p] : p_offset
+ + fm_prefix_lengths[b][p]
+ + fm_gen_lengths[b][p],
+ ]
+ latents_b = rearrange(
+ latents_b, "(n p) d -> n p d", p=fm_gen_patch_size
+ )
+ # extract only the latent parts
+ latents_b = latents_b[:, -latent_patch_size:, :]
+ fm_pred.append(latents_b)
+ p_offset += fm_prefix_lengths[b][p] + fm_gen_lengths[b][p]
+ return torch.cat(fm_pred, dim=0)
diff --git a/src/dots_tts/models/dots_tts/model.py b/src/dots_tts/models/dots_tts/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7f3e1c6989be57dc521057c2ec47a9b4b5a0c50
--- /dev/null
+++ b/src/dots_tts/models/dots_tts/model.py
@@ -0,0 +1,1958 @@
+from __future__ import annotations
+
+import json
+import math
+import os
+import shutil
+from dataclasses import dataclass
+from functools import partial
+from pathlib import Path
+from typing import Any, Callable, Iterator
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from loguru import logger
+from safetensors.torch import load_file, save_file
+from transformers import AutoTokenizer, Qwen2Config
+
+from dots_tts.models.dots_tts.config import ModelConfig
+from dots_tts.models.dots_tts.core import DotsTtsCore, DotsTtsForwardOutput
+from dots_tts.modules.speaker.encoder import SpeakerXVectorFeatures
+from dots_tts.modules.vocoder.bigvgan import AudioVAE
+from dots_tts.training.losses import LossMasks, LossTerm, LossTerms
+from dots_tts.utils.profiling import measure_inference
+from dots_tts.utils.tokenizer import AUDIO_GEN_START_TOKEN, require_token_id
+from dots_tts.utils.util import get_dtype
+
+
+# Backend-name spellings that select the lazy AOTI compile path; compared against the
+# stripped, lower-cased backend string.
+_AOTI_BACKENDS = {"aoti", "aot", "aotinductor", "aot_inductor"}
+
+
+class _AotiMethodModule(nn.Module):
+    """Expose one method of ``owner`` as this module's ``forward``.
+
+    If the method is found on the owner's class, a ``__wrapped__`` attribute
+    (when present) is followed one level, so a decorator that sets it is
+    bypassed. Otherwise the attribute is looked up on the instance.
+    """
+
+    def __init__(self, owner: nn.Module, method_name: str):
+        super().__init__()
+        self.owner = owner
+        self.method_name = method_name
+
+    def forward(self, *args, **kwargs):
+        class_attr = getattr(type(self.owner), self.method_name, None)
+        if class_attr is not None:
+            target = getattr(class_attr, "__wrapped__", class_attr)
+            return target(self.owner, *args, **kwargs)
+        # Not defined on the class: fall back to the instance attribute.
+        return getattr(self.owner, self.method_name)(*args, **kwargs)
+
+
+class _LazyAotiCompiledMethod:
+    """Callable that AOTI-compiles ``owner.method_name`` on first use.
+
+    The first call exports the method with that call's arguments and compiles
+    it through ``spaces.aoti_compile``. If anything in that path raises, the
+    error propagates unless the environment variable
+    ``DOTS_TTS_AOTI_ALLOW_EAGER_FALLBACK`` equals ``"1"``, in which case the
+    eager bound method is stored and used from then on.
+    """
+
+    def __init__(
+        self,
+        *,
+        key: str,
+        owner: nn.Module,
+        method_name: str,
+        signature: tuple[Any, ...] | None,
+    ):
+        self.key = key
+        self.owner = owner
+        self.method_name = method_name
+        self.signature = signature
+        self.compiled: Callable[..., Any] | None = None
+        self.fallback: Callable[..., Any] | None = None
+
+    def _export_and_compile(self, args, kwargs) -> None:
+        """Export the wrapped method for these example inputs and store the compiled callable."""
+        import spaces  # noqa: PLC0415
+
+        if not hasattr(spaces, "aoti_compile"):
+            raise RuntimeError("spaces.aoti_compile is not available.")
+        wrapper = _AotiMethodModule(self.owner, self.method_name).eval()
+        exported = torch.export.export(wrapper, args=args, kwargs=kwargs)
+        self.compiled = spaces.aoti_compile(exported)
+        logger.info(
+            "AOTI compiled inference target: key={} method={} signature={}",
+            self.key,
+            self.method_name,
+            self.signature,
+        )
+
+    def __call__(self, *args, **kwargs):
+        # A compiled callable takes priority over a stored eager fallback.
+        for ready in (self.compiled, self.fallback):
+            if ready is not None:
+                return ready(*args, **kwargs)
+
+        try:
+            self._export_and_compile(args, kwargs)
+            # Kept inside the try so a failure of the first compiled call can also fall back.
+            return self.compiled(*args, **kwargs)
+        except Exception:
+            fallback_allowed = (
+                os.environ.get("DOTS_TTS_AOTI_ALLOW_EAGER_FALLBACK", "0") == "1"
+            )
+            if not fallback_allowed:
+                raise
+            logger.exception(
+                "AOTI compile failed; falling back to eager method: key={} method={} signature={}",
+                self.key,
+                self.method_name,
+                self.signature,
+            )
+            self.fallback = getattr(self.owner, self.method_name)
+            return self.fallback(*args, **kwargs)
+
+
+@dataclass
+class _GenerateState:
+ # Mutable per-generation state; every field has a default so it can be built empty.
+ # presumably the LLM's incremental decoding cache — confirm at the call sites
+ llm_cache: Any | None = None
+ llm_hiddens: torch.Tensor | None = None
+ # Decode state created by core.patch_encoder.init_decode_state, or None.
+ patch_encoder_state: Any | None = None
+ # presumably the number of fm_sequence positions filled so far — confirm at the call sites
+ fm_seq_len: int = 0
+ # Length of dim 1 of fm_sequence / fm_cfg_sequence as allocated.
+ fm_capacity: int = 0
+ # Preallocated (1, fm_capacity, fm_hidden_size) buffers, shared through a workspace cache.
+ fm_sequence: torch.Tensor | None = None
+ fm_cfg_sequence: torch.Tensor | None = None
+ # Preallocated (1, fm_hidden_size) zero tensor.
+ fm_null_g_cond: torch.Tensor | None = None
+ end_flag: bool = False
+
+
+@dataclass(frozen=True)
+class _PromptConditioning:
+ # Immutable bundle of optional prompt-derived tensors; all default to None.
+ # NOTE(review): shapes and producers are not visible in this view — confirm at the call sites.
+ prompt_patches: torch.Tensor | None = None
+ prompt_latents: torch.Tensor | None = None
+ g_cond: torch.Tensor | None = None
+
+
+@dataclass(frozen=True)
+class _GenerateLengthBucket:
+    """One generation-length bucket, identified by its size in audio patches."""
+
+    size: int
+
+    def run_warmup(
+        self,
+        model: "DotsTtsModel",
+        *,
+        precision: str,
+        ode_method: str,
+        num_steps: int,
+        guidance_scale: float,
+    ) -> None:
+        """Warm up ``model`` for this bucket size.
+
+        Runs the FM and patch-encoder warmups, then starts a streaming
+        generation over a schedule of ``size + 1`` tokens (start token followed
+        by span tokens) and stops after the first yielded chunk.
+
+        Raises:
+            RuntimeError: if the stream yields nothing.
+        """
+        sampler_kwargs = dict(
+            precision=precision,
+            ode_method=ode_method,
+            num_steps=num_steps,
+            guidance_scale=guidance_scale,
+        )
+        model._warmup_fm_bucket(max_audio_patch_count=self.size, **sampler_kwargs)
+        model._warmup_patch_encoder_bucket(
+            max_audio_patch_count=self.size,
+            precision=precision,
+        )
+
+        device = next(model.core.parameters()).device
+        schedule = torch.full(
+            (1, self.size + 1),
+            fill_value=model.core.audio_gen_span_id,
+            dtype=torch.long,
+            device=device,
+        )
+        schedule[0, 0] = model.audio_gen_start_id
+
+        stream = model.generate_audio_stream(
+            {"generation_schedule": schedule},
+            **sampler_kwargs,
+        )
+        # One chunk is enough to exercise the whole pipeline.
+        for _ in stream:
+            return
+        raise RuntimeError(
+            f"Warmup produced no audio chunk for generate bucket {self.size}."
+        )
+
+
+class DotsTtsModel(nn.Module):
+ """Full train/infer model assembly around the dots.tts core network."""
+
+ # Bucket sizes in ascending order; _resolve_generate_length_bucket returns the first
+ # bucket whose size is >= the requested length.
+ _GENERATE_LENGTH_BUCKETS = (
+ _GenerateLengthBucket(32),
+ _GenerateLengthBucket(64),
+ _GenerateLengthBucket(128),
+ _GenerateLengthBucket(256),
+ _GenerateLengthBucket(512),
+ _GenerateLengthBucket(1024),
+ )
+ # Compile keys are "<target>.<name>"; only keys whose target prefix is listed here are compiled.
+ _COMPILE_TARGETS = frozenset(
+ {
+ "FM",
+ "patch_encoder",
+ "vocoder",
+ }
+ )
+ # Class-level default; __init__ and set_optimize set the per-instance value.
+ _optimize_enabled = True
+ # File names used inside a pretrained directory.
+ CONFIG_FILENAME = "config.json"
+ # Written into config.json as "model_type" / "architectures" by save_pretrained.
+ HF_MODEL_TYPE = "dots_tts"
+ HF_ARCHITECTURES = ["DotsTTSForConditionalGeneration"]
+ LATENT_STATS_FILENAME = "latent_stats.pt"
+ LLM_CONFIG_FILENAME = "llm_config.json"
+ MODEL_FILENAME = "model.safetensors"
+ VOCODER_FILENAME = "vocoder.safetensors"
+ SPEAKER_ENCODER_FILENAME = "speaker_encoder.safetensors"
+ # (redundant_key, canonical_key) pairs: the redundant entry is left out of a saved state
+ # dict when it shares storage with the canonical one, and restored from it on load.
+ _ARTIFACT_ALIASES = (("llm.lm_head.weight", "llm.model.embed_tokens.weight"),)
+ # Files that _validate_pretrained_directory requires to exist.
+ REQUIRED_ARTIFACT_FILES = (
+ CONFIG_FILENAME,
+ LATENT_STATS_FILENAME,
+ LLM_CONFIG_FILENAME,
+ MODEL_FILENAME,
+ VOCODER_FILENAME,
+ SPEAKER_ENCODER_FILENAME,
+ )
+
+ # region Module assembly and checkpoint IO
+    def __init__(
+        self,
+        config: ModelConfig,
+        tokenizer,
+        latent_stats_path: str | Path,
+        llm_config: Qwen2Config,
+    ):
+        """Assemble the core network, the frozen vocoder and the frozen x-vector extractor.
+
+        Args:
+            config: model configuration; also supplies the vocoder and speaker settings.
+            tokenizer: tokenizer that must contain ``AUDIO_GEN_START_TOKEN``.
+            latent_stats_path: path to the latent statistics file, stored as a ``Path``.
+            llm_config: configuration passed to the core network for its LLM.
+        """
+        super().__init__()
+        self.config = config
+        self.tokenizer = tokenizer
+        self.latent_stats_path = Path(latent_stats_path)
+        self.audio_gen_start_id = require_token_id(
+            self.tokenizer, AUDIO_GEN_START_TOKEN
+        )
+
+        self.core = DotsTtsCore(
+            config,
+            llm_config=llm_config,
+            tokenizer=tokenizer,
+            latent_stats_path=self.latent_stats_path,
+        )
+
+        self.vocoder = AudioVAE(config.vocoder).eval()
+        self.vocoder.remove_weight_norm()
+        self.hop_size = self.vocoder.hop_size
+
+        self.xvector_extractor = SpeakerXVectorFeatures(
+            sample_rate=self.vocoder.sample_rate,
+            campplus_embedding_size=config.campplus_embedding_size,
+            max_audio_seconds=config.xvec_max_audio_seconds,
+        ).eval()
+
+        # Vocoder and speaker encoder are never trained through this model.
+        for frozen_module in (self.vocoder, self.xvector_extractor):
+            for param in frozen_module.parameters():
+                param.requires_grad = False
+
+        self._optimize_enabled = True
+        # Compiled callables keyed by (compile key, signature).
+        self._compiled_models: dict[
+            tuple[str, tuple[Any, ...] | None], Callable[..., Any]
+        ] = {}
+        backend_from_env = os.environ.get(
+            "DOTS_TTS_COMPILE_BACKEND",
+            "torch_compile",
+        )
+        self._compile_backend = backend_from_env.strip().lower()
+        # Reusable tensor buffers, keyed by allocation parameters.
+        self._static_generate_workspaces: dict[tuple[Any, ...], dict[str, Any]] = {}
+        self._fm_decode_workspaces: dict[tuple[Any, ...], dict[str, torch.Tensor]] = {}
+
+    def set_optimize(self, optimize: bool) -> None:
+        """Enable or disable compilation; disabling also drops all cached compiled callables."""
+        enabled = bool(optimize)
+        self._optimize_enabled = enabled
+        if enabled:
+            return
+        self._compiled_models.clear()
+
+    def set_compile_backend(self, backend: str) -> None:
+        """Select the compile backend by name (stripped, lower-cased; empty means ``"torch_compile"``).
+
+        Cached compiled callables are dropped only when the backend actually changes.
+        """
+        requested = (backend or "torch_compile").strip().lower()
+        if requested == self._compile_backend:
+            return
+        self._compiled_models.clear()
+        self._compile_backend = requested
+
+ def export_compiled_models(
+ self,
+ ) -> dict[tuple[str, tuple[Any, ...] | None], Callable[..., Any]]:
+ """Return a new dict of the cached compiled callables, keyed by ``(key, signature)``.
+
+ For lazy AOTI wrappers the inner compiled callable is exported instead of the wrapper.
+ """
+ exported: dict[tuple[str, tuple[Any, ...] | None], Callable[..., Any]] = {}
+ for cache_key, compiled in self._compiled_models.items():
+ if isinstance(compiled, _LazyAotiCompiledMethod):
+ if compiled.compiled is not None:
+ exported[cache_key] = compiled.compiled
+ # NOTE(review): indentation is collapsed in this view. `continue` presumably belongs to
+ # the isinstance branch, so wrappers that have not compiled yet are skipped rather
+ # than exported as wrappers — confirm against the real file.
+ continue
+ exported[cache_key] = compiled
+ return exported
+
+    def import_compiled_models(
+        self,
+        compiled_models: dict[tuple[str, tuple[Any, ...] | None], Callable[..., Any]],
+    ) -> None:
+        """Merge ``compiled_models`` into the compile cache, overwriting entries with equal keys."""
+        cache = self._compiled_models
+        cache.update(compiled_models)
+
+    def set_cfg_droprate(
+        self,
+        cfg_droprate: float | None = None,
+        xvec_drop_rate: float | None = None,
+    ) -> None:
+        """Update the drop rates on this model's config, the core's config and the core itself.
+
+        A value of ``None`` leaves the corresponding rate untouched.
+        """
+        requested = (
+            ("cfg_droprate", cfg_droprate),
+            ("xvec_drop_rate", xvec_drop_rate),
+        )
+        for attr_name, value in requested:
+            if value is None:
+                continue
+            # The same value is mirrored in three places.
+            for holder in (self.config, self.core.config, self.core):
+                setattr(holder, attr_name, value)
+
+    @classmethod
+    def _resolve_generate_length_bucket(
+        cls,
+        max_generate_length: int,
+    ) -> _GenerateLengthBucket:
+        """Return the first bucket whose size is at least ``max_generate_length``.
+
+        Raises:
+            ValueError: if the length is not positive or exceeds the largest bucket.
+        """
+        requested = int(max_generate_length)
+        if requested <= 0:
+            raise ValueError("max_generate_length must be positive.")
+        fitting = next(
+            (
+                bucket
+                for bucket in cls._GENERATE_LENGTH_BUCKETS
+                if requested <= bucket.size
+            ),
+            None,
+        )
+        if fitting is not None:
+            return fitting
+        raise ValueError(
+            "max_generate_length exceeds the largest supported compile bucket: "
+            f"max_generate_length={requested} "
+            f"max_supported={cls._GENERATE_LENGTH_BUCKETS[-1].size}."
+        )
+
+    @torch.no_grad()
+    def run_warmup(
+        self,
+        *,
+        max_generate_length: int,
+        precision: str = "bfloat16",
+        ode_method: str = "euler",
+        num_steps: int = 10,
+        guidance_scale: float = 1.2,
+    ) -> None:
+        """Warm up every length bucket up to the one that fits ``max_generate_length``.
+
+        Raises:
+            ValueError: if ``max_generate_length`` is not positive or exceeds the largest bucket.
+        """
+        ceiling = self._resolve_generate_length_bucket(max_generate_length).size
+        selected = [
+            bucket
+            for bucket in self._GENERATE_LENGTH_BUCKETS
+            if bucket.size <= ceiling
+        ]
+        bucket_sizes = [bucket.size for bucket in selected]
+        sampler_kwargs = dict(
+            precision=precision,
+            ode_method=ode_method,
+            num_steps=num_steps,
+            guidance_scale=guidance_scale,
+        )
+
+        logger.info(
+            "Inference warmup started: requested_max_generate_length={} bucket_sizes={}",
+            int(max_generate_length),
+            bucket_sizes,
+        )
+        for bucket in selected:
+            bucket.run_warmup(self, **sampler_kwargs)
+        logger.info(
+            "Inference warmup completed: requested_max_generate_length={} bucket_sizes={}",
+            int(max_generate_length),
+            bucket_sizes,
+        )
+
+    def _resolve_state_audio_patch_count(self, max_audio_patch_count: int) -> int:
+        """Return the patch count to allocate state for.
+
+        With optimization enabled the count is rounded up to its length bucket;
+        otherwise the requested count is returned as an ``int``.
+
+        Raises:
+            ValueError: if the count is not positive (or, when optimizing, exceeds the largest bucket).
+        """
+        requested = int(max_audio_patch_count)
+        if requested <= 0:
+            raise ValueError("max_audio_patch_count must be positive.")
+        if self._optimize_enabled:
+            return self._resolve_generate_length_bucket(requested).size
+        return requested
+
+    def _warmup_fm_bucket(
+        self,
+        *,
+        max_audio_patch_count: int,
+        precision: str,
+        ode_method: str,
+        num_steps: int,
+        guidance_scale: float,
+    ) -> None:
+        """Run one audio decode step on a generate state marked as filled to capacity.
+
+        Autocast is enabled only on CUDA with a float16/bfloat16 precision.
+        """
+        dtype = get_dtype(precision)
+        device = next(self.core.parameters()).device
+        half_precision = dtype in {torch.float16, torch.bfloat16}
+        use_amp = device.type == "cuda" and half_precision
+        autocast_ctx = torch.autocast(
+            device_type=device.type, dtype=dtype, enabled=use_amp
+        )
+        with autocast_ctx:
+            state = self._allocate_generate_state(
+                max_audio_patch_count=max_audio_patch_count,
+                device=device,
+                dtype=dtype,
+            )
+            # Mark the whole buffer as in use for this decode call.
+            state.fm_seq_len = state.fm_capacity
+            self._decode_next_audio(
+                state,
+                device=device,
+                g_cond=None,
+                ode_method=ode_method,
+                num_steps=num_steps,
+                guidance_scale=guidance_scale,
+            )
+
+    def _warmup_patch_encoder_bucket(
+        self,
+        *,
+        max_audio_patch_count: int,
+        precision: str,
+    ) -> None:
+        """Run the patch encoder's ``decode_patch`` once on a zero patch for this bucket.
+
+        State tensors use the requested dtype on CUDA and float32 elsewhere;
+        autocast is enabled only on CUDA with a float16/bfloat16 precision.
+        """
+        dtype = get_dtype(precision)
+        device = next(self.core.parameters()).device
+        on_cuda = device.type == "cuda"
+        state_dtype = dtype if on_cuda else torch.float32
+        use_amp = on_cuda and dtype in {torch.float16, torch.bfloat16}
+        encoder = self.core.patch_encoder
+
+        with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_amp):
+            patch_count = self._resolve_state_audio_patch_count(
+                max_audio_patch_count
+            )
+            decode_state = encoder.init_decode_state(
+                max_audio_patch_count=patch_count,
+                batch_size=1,
+                device=device,
+                dtype=state_dtype,
+            )
+            patch_shape = (1, encoder.patch_size, self.core.latent_dim)
+            zero_patch = torch.zeros(patch_shape, dtype=state_dtype, device=device)
+            zero_patch = self.core.io_helper.denormalize(zero_patch)
+
+            decode_patch = self._get_compiled_method(
+                "patch_encoder.decode_patch",
+                encoder,
+                "decode_patch",
+                signature=self._patch_encoder_compile_signature(decode_state),
+            )
+            positions = torch.arange(
+                encoder.out_ds_rate,
+                device=device,
+                dtype=torch.long,
+            )
+            with measure_inference("patch_encoder"):
+                decode_patch(
+                    zero_patch,
+                    decode_state.conv_tail,
+                    decode_state.layer_caches,
+                    positions,
+                )
+
+ def _compile_callable(
+ self,
+ key: str,
+ model: Callable[..., Any],
+ *,
+ signature: tuple[Any, ...] | None = None,
+ ) -> Callable[..., Any]:
+ """Return the ``torch.compile``-d version of ``model``, cached under ``(key, signature)``.
+
+ A cache hit returns the stored callable without looking at ``model``.
+ """
+ # Text before the first "." of the key; used only in the log message here.
+ compile_target = key.split(".", maxsplit=1)[0]
+ cache_key = (key, signature)
+ compiled = self._compiled_models.get(cache_key)
+ if compiled is None:
+ # "patch_encoder.decode_patch" uses the default mode; every other key uses "reduce-overhead".
+ mode = (
+ "default"
+ if key == "patch_encoder.decode_patch"
+ else "reduce-overhead"
+ )
+ compiled = torch.compile(
+ model,
+ mode=mode,
+ fullgraph=True,
+ dynamic=False,
+ )
+ self._compiled_models[cache_key] = compiled
+ # NOTE(review): indentation is collapsed in this view. This log call presumably sits inside
+ # the `if compiled is None` branch (logged once per new compile) — confirm against the real file.
+ logger.info(
+ "Compiled inference target: key={} target={} signature={}",
+ key,
+ compile_target,
+ signature,
+ )
+ return compiled
+
+    def _get_compiled_model(
+        self,
+        key: str,
+        model: Callable[..., Any],
+        *,
+        signature: tuple[Any, ...] | None = None,
+    ) -> Callable[..., Any]:
+        """Return ``model`` compiled, or ``model`` itself when compilation does not apply.
+
+        Compilation applies only if optimization is enabled and the key's prefix
+        (text before the first ".") is one of ``_COMPILE_TARGETS``.
+        """
+        target = key.split(".", maxsplit=1)[0]
+        if self._optimize_enabled and target in self._COMPILE_TARGETS:
+            return self._compile_callable(key, model, signature=signature)
+        return model
+
+    def _get_compiled_method(
+        self,
+        key: str,
+        owner: Any,
+        method_name: str,
+        *,
+        signature: tuple[Any, ...] | None = None,
+    ) -> Callable[..., Any]:
+        """Return a callable for ``owner.method_name``, compiled when compilation applies.
+
+        Without optimization, or for a key whose prefix is not a compile
+        target, the plain bound method is returned. With an AOTI backend a
+        cached lazy wrapper is returned. Otherwise the class-level function
+        (following ``__wrapped__`` one level if present) is compiled and bound
+        to ``owner`` with ``functools.partial``.
+        """
+        bound_method = getattr(owner, method_name)
+        target = key.split(".", maxsplit=1)[0]
+        compile_applies = self._optimize_enabled and target in self._COMPILE_TARGETS
+        if not compile_applies:
+            return bound_method
+
+        if self._compile_backend not in _AOTI_BACKENDS:
+            class_attr = getattr(type(owner), method_name)
+            unwrapped = getattr(class_attr, "__wrapped__", class_attr)
+            compiled_fn = self._compile_callable(
+                key,
+                unwrapped,
+                signature=signature,
+            )
+            return partial(compiled_fn, owner)
+
+        cache_key = (key, signature)
+        lazy = self._compiled_models.get(cache_key)
+        if lazy is None:
+            lazy = _LazyAotiCompiledMethod(
+                key=key,
+                owner=owner,
+                method_name=method_name,
+                signature=signature,
+            )
+            self._compiled_models[cache_key] = lazy
+        return lazy
+
+    def _allocate_generate_state(
+        self,
+        *,
+        max_audio_patch_count: int,
+        device: torch.device,
+        dtype: torch.dtype,
+    ) -> _GenerateState:
+        """Return a fresh ``_GenerateState`` backed by cached, zeroed FM buffers.
+
+        Buffers are cached per (resolved patch count, device string, state
+        dtype) and shared by every state built for that key; on reuse the two
+        sequence buffers are zeroed in place. A patch-encoder decode state is
+        created only when optimization is disabled.
+        """
+        state_dtype = dtype if device.type == "cuda" else torch.float32
+        patch_count = self._resolve_state_audio_patch_count(max_audio_patch_count)
+        tokens_per_patch = self.core.hidden_patch_size + self.core.latent_patch_size
+        fm_capacity = patch_count * tokens_per_patch
+
+        workspace_key = (patch_count, str(device), state_dtype)
+        workspace = self._static_generate_workspaces.get(workspace_key)
+        if workspace is not None:
+            # Reused buffers: clear whatever the previous generation left behind.
+            workspace["fm_sequence"].zero_()
+            workspace["fm_cfg_sequence"].zero_()
+        else:
+
+            def zeros(*shape):
+                return torch.zeros(shape, dtype=state_dtype, device=device)
+
+            workspace = {
+                "fm_sequence": zeros(1, fm_capacity, self.core.fm_hidden_size),
+                "fm_cfg_sequence": zeros(1, fm_capacity, self.core.fm_hidden_size),
+                "fm_null_g_cond": zeros(1, self.core.fm_hidden_size),
+            }
+            self._static_generate_workspaces[workspace_key] = workspace
+
+        if self._optimize_enabled:
+            patch_encoder_state = None
+        else:
+            patch_encoder_state = self.core.patch_encoder.init_decode_state(
+                max_audio_patch_count=patch_count,
+                batch_size=1,
+                device=device,
+                dtype=state_dtype,
+            )
+
+        return _GenerateState(
+            patch_encoder_state=patch_encoder_state,
+            fm_seq_len=0,
+            fm_capacity=fm_capacity,
+            fm_sequence=workspace["fm_sequence"],
+            fm_cfg_sequence=workspace["fm_cfg_sequence"],
+            fm_null_g_cond=workspace["fm_null_g_cond"],
+        )
+
+    @staticmethod
+    def _tensor_storage_signature(tensor: torch.Tensor) -> tuple:
+        """Return a hashable key describing which memory ``tensor`` views and how.
+
+        The key is ``(storage data pointer, storage offset, size, stride, dtype)``.
+        """
+        storage_ptr = tensor.untyped_storage().data_ptr()
+        view_layout = (
+            tensor.storage_offset(),
+            tuple(tensor.size()),
+            tuple(tensor.stride()),
+        )
+        return (storage_ptr, *view_layout, tensor.dtype)
+
+    @classmethod
+    def _build_artifact_state_dict(cls, module) -> dict[str, torch.Tensor]:
+        """Return ``module.state_dict()`` with entries that alias the same memory removed.
+
+        Two kinds of entries are dropped: the redundant key of an
+        ``_ARTIFACT_ALIASES`` pair when it shares its storage signature with the
+        canonical key, and any later key whose storage signature was already
+        seen. Kept tensors are detached, moved to CPU and made contiguous.
+
+        Zero-element tensors are never treated as duplicates: they hold no data,
+        and distinct empty tensors can report the same signature (for example a
+        null storage pointer), which would otherwise drop keys that the strict
+        check in ``_load_artifact_module`` later reports as missing.
+        """
+        state_dict = module.state_dict()
+        skip_keys = set()
+
+        for redundant_key, canonical_key in cls._ARTIFACT_ALIASES:
+            redundant_tensor = state_dict.get(redundant_key)
+            canonical_tensor = state_dict.get(canonical_key)
+            if (
+                redundant_tensor is not None
+                and canonical_tensor is not None
+                and cls._tensor_storage_signature(redundant_tensor)
+                == cls._tensor_storage_signature(canonical_tensor)
+            ):
+                skip_keys.add(redundant_key)
+
+        cleaned_state_dict = {}
+        seen_storage = set()
+        for key, value in state_dict.items():
+            if key in skip_keys:
+                continue
+
+            # Only tensors that actually own data take part in de-duplication.
+            if value.numel() > 0:
+                storage_signature = cls._tensor_storage_signature(value)
+                if storage_signature in seen_storage:
+                    continue
+                seen_storage.add(storage_signature)
+
+            cleaned_state_dict[key] = value.detach().cpu().contiguous()
+
+        return cleaned_state_dict
+
+    @classmethod
+    def _restore_artifact_state_dict(cls, state_dict: dict, module) -> dict:
+        """Return a copy of ``state_dict`` with omitted alias keys filled in.
+
+        For each ``_ARTIFACT_ALIASES`` pair, the redundant key is set to the
+        canonical tensor when the canonical key is present, the redundant key
+        is absent, and ``module`` expects the redundant key.
+        """
+        restored = dict(state_dict)
+        for redundant_key, canonical_key in cls._ARTIFACT_ALIASES:
+            if canonical_key not in restored:
+                continue
+            if redundant_key in restored:
+                continue
+            if redundant_key not in module.state_dict():
+                continue
+            restored[redundant_key] = restored[canonical_key]
+        return restored
+
+    @classmethod
+    def _save_artifact_module(cls, module, path: Path) -> None:
+        """Write the de-duplicated state dict of ``module`` to ``path`` in safetensors format."""
+        artifact_tensors = cls._build_artifact_state_dict(module)
+        save_file(artifact_tensors, path)
+
+    @classmethod
+    def _load_artifact_module(cls, module, path: Path):
+        """Load the safetensors file at ``path`` into ``module`` and return ``module``.
+
+        Alias keys omitted on save are restored first. Loading is non-strict so
+        that all mismatches can be reported together.
+
+        Raises:
+            RuntimeError: if any key is missing or unexpected.
+        """
+        loaded = load_file(path, device="cpu")
+        load_result = module.load_state_dict(
+            cls._restore_artifact_state_dict(loaded, module),
+            strict=False,
+        )
+        if not (load_result.missing_keys or load_result.unexpected_keys):
+            return module
+        raise RuntimeError(f"Failed to load {path}: {load_result}")
+
+    @classmethod
+    def _validate_pretrained_directory(
+        cls, pretrained_model_name_or_path: str | Path
+    ) -> Path:
+        """Return the resolved pretrained directory after checking it is complete.
+
+        Raises:
+            FileNotFoundError: if any of ``REQUIRED_ARTIFACT_FILES`` is not a regular file there.
+        """
+        root = Path(pretrained_model_name_or_path).expanduser().resolve()
+        missing_files = []
+        for name in cls.REQUIRED_ARTIFACT_FILES:
+            if not (root / name).is_file():
+                missing_files.append(name)
+        if not missing_files:
+            return root
+        raise FileNotFoundError(
+            f"Pretrained path {root} is missing required files: {missing_files}"
+        )
+
+    @classmethod
+    def _load_pretrained_config(cls, pretrained_path: Path) -> ModelConfig:
+        """Read ``CONFIG_FILENAME`` from ``pretrained_path`` (UTF-8 JSON) and validate it as a ``ModelConfig``."""
+        config_file = pretrained_path / cls.CONFIG_FILENAME
+        payload = json.loads(config_file.read_text(encoding="utf-8"))
+        return ModelConfig.model_validate(payload)
+
+    @staticmethod
+    def _save_llm_config(llm_config: Qwen2Config, path: Path) -> None:
+        """Write ``llm_config.to_dict()`` to ``path`` as ASCII-escaped, 2-space-indented JSON."""
+        serialized = json.dumps(llm_config.to_dict(), ensure_ascii=True, indent=2)
+        path.write_text(serialized, encoding="utf-8")
+
+    @staticmethod
+    def _load_llm_config(path: Path) -> Qwen2Config:
+        """Build a ``Qwen2Config`` from the UTF-8 JSON file at ``path``."""
+        raw_json = path.read_text(encoding="utf-8")
+        return Qwen2Config.from_dict(json.loads(raw_json))
+
+    def _tie_llm_weights(self) -> None:
+        """Call ``tie_weights()`` on the core's LLM when it provides that method."""
+        llm = self.core.llm
+        if not hasattr(llm, "tie_weights"):
+            return
+        llm.tie_weights()
+
+    def save_pretrained(self, save_directory: str | Path) -> Path:
+        """Write the complete model artifact set into ``save_directory``.
+
+        The directory is created if needed and receives: the model config
+        (with ``model_type`` and ``architectures`` added), the LLM config, the
+        tokenizer files, a copy of the latent statistics file, and one
+        safetensors file each for the core, the vocoder and the speaker encoder.
+
+        Saving into the directory the model was loaded from is supported: the
+        latent statistics file is then already in place and is left untouched.
+
+        Returns:
+            ``save_directory`` as a ``Path``.
+        """
+        save_directory = Path(save_directory)
+        save_directory.mkdir(parents=True, exist_ok=True)
+
+        config_payload = self.config.to_declared_dict()
+        config_payload["model_type"] = self.HF_MODEL_TYPE
+        config_payload["architectures"] = list(self.HF_ARCHITECTURES)
+        (save_directory / self.CONFIG_FILENAME).write_text(
+            json.dumps(config_payload, ensure_ascii=True, indent=2),
+            encoding="utf-8",
+        )
+        self._save_llm_config(
+            self.core.llm.config,
+            save_directory / self.LLM_CONFIG_FILENAME,
+        )
+        self.tokenizer.save_pretrained(save_directory)
+        try:
+            shutil.copy2(
+                self.latent_stats_path,
+                save_directory / self.LATENT_STATS_FILENAME,
+            )
+        except shutil.SameFileError:
+            # Source and destination are the same file (in-place re-save);
+            # nothing to copy, and the remaining artifacts must still be written.
+            pass
+        self._save_artifact_module(self.core, save_directory / self.MODEL_FILENAME)
+        self._save_artifact_module(self.vocoder, save_directory / self.VOCODER_FILENAME)
+        self._save_artifact_module(
+            self.xvector_extractor,
+            save_directory / self.SPEAKER_ENCODER_FILENAME,
+        )
+        return save_directory
+
+    def _load_pretrained_artifacts(self, pretrained_path: Path) -> None:
+        """Load latent stats and all weights from ``pretrained_path`` and switch to eval mode.
+
+        The core's IO helper is rebuilt (same class) on the directory's latent
+        statistics file, then the core, vocoder and speaker-encoder weights are
+        loaded. LLM weights are re-tied right after the core is loaded.
+        """
+        stats_path = pretrained_path / self.LATENT_STATS_FILENAME
+        self.latent_stats_path = stats_path
+        io_helper_cls = type(self.core.io_helper)
+        self.core.io_helper = io_helper_cls(latent_stats_path=self.latent_stats_path)
+
+        self._load_artifact_module(self.core, pretrained_path / self.MODEL_FILENAME)
+        self._tie_llm_weights()
+        remaining = (
+            (self.vocoder, self.VOCODER_FILENAME),
+            (self.xvector_extractor, self.SPEAKER_ENCODER_FILENAME),
+        )
+        for module, filename in remaining:
+            self._load_artifact_module(module, pretrained_path / filename)
+
+        for module in (self.core, self.vocoder, self.xvector_extractor):
+            module.eval()
+
+    def load_pretrained_weights(
+        self, pretrained_model_name_or_path: str | Path
+    ) -> None:
+        """Load weights from a pretrained directory into this already-built model.
+
+        Raises:
+            FileNotFoundError: if the directory lacks a required file.
+            ValueError: if the saved model config or LLM config differs from this model's.
+        """
+        pretrained_path = self._validate_pretrained_directory(
+            pretrained_model_name_or_path
+        )
+
+        saved_config = self._load_pretrained_config(pretrained_path)
+        config_differs = (
+            saved_config.to_declared_dict() != self.config.to_declared_dict()
+        )
+        if config_differs:
+            raise ValueError(
+                f"Pretrained config at {pretrained_path} does not match the current model."
+            )
+
+        saved_llm_config = self._load_llm_config(
+            pretrained_path / self.LLM_CONFIG_FILENAME
+        )
+        llm_config_differs = (
+            saved_llm_config.to_dict() != self.core.llm.config.to_dict()
+        )
+        if llm_config_differs:
+            raise ValueError(
+                f"Pretrained LLM config at {pretrained_path} does not match the current model."
+            )
+
+        self._load_pretrained_artifacts(pretrained_path)
+
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str | Path):
    """Build a model from a local checkpoint directory and load its weights.

    Reads the saved model config, LLM config and tokenizer (local files
    only), constructs the model, loads all artifacts and returns the model
    in eval mode.

    Args:
        pretrained_model_name_or_path: Checkpoint location, validated through
            ``_validate_pretrained_directory``.

    Returns:
        The constructed model after ``.eval()``.
    """
    logger.info(
        "DotsTtsModel load started: pretrained_path={}",
        pretrained_model_name_or_path,
    )
    checkpoint_dir = cls._validate_pretrained_directory(
        pretrained_model_name_or_path
    )
    model_config = cls._load_pretrained_config(checkpoint_dir)
    llm_config = cls._load_llm_config(checkpoint_dir / cls.LLM_CONFIG_FILENAME)
    logger.info(
        "DotsTtsModel config loaded: pretrained_path={} sample_rate={} patch_size={}",
        checkpoint_dir,
        model_config.vocoder.sample_rate,
        model_config.patch_size,
    )

    text_tokenizer = AutoTokenizer.from_pretrained(
        str(checkpoint_dir),
        local_files_only=True,
    )
    instance = cls(
        model_config,
        tokenizer=text_tokenizer,
        latent_stats_path=checkpoint_dir / cls.LATENT_STATS_FILENAME,
        llm_config=llm_config,
    )
    instance._load_pretrained_artifacts(checkpoint_dir)
    logger.info(
        "DotsTtsModel load completed: pretrained_path={}",
        checkpoint_dir,
    )
    return instance.eval()
+
+ # endregion Module assembly and checkpoint IO
+
+ # region Training batch preparation
@torch.no_grad()
def prepare_training_inputs(self, data: dict[str, Any]) -> dict[str, Any]:
    """Derive latents and speaker embeddings for a training batch.

    Returns a shallow copy of ``data``. When ``data["sample"]`` is present the
    copy gains ``latents``, ``latent_lengths``, ``latents_sampled`` and
    ``xvector``; otherwise ``latents`` and ``latent_lengths`` are set to
    ``None``. The vocoder and speaker encoder are put in eval mode first.

    Args:
        data: Batch dictionary; reads ``sample``, ``sample_lengths``,
            ``fbank`` and ``fbank_lengths`` when available.

    Returns:
        The augmented copy of ``data``.
    """
    self.vocoder.eval()
    self.xvector_extractor.eval()
    processed = dict(data)
    sample: torch.Tensor | None = data.get("sample")
    sample_lengths: torch.Tensor | None = data.get("sample_lengths")

    if sample is None:
        processed["latents"] = None
        processed["latent_lengths"] = None
        return processed

    latents = self.vocoder.extract_latents(sample)
    processed["latents"] = latents
    if sample_lengths is None:
        # No explicit lengths: every item spans the full last latent axis.
        latent_lengths = torch.full(
            (latents.size(0),),
            latents.size(-1),
            dtype=torch.long,
            device=latents.device,
        )
    else:
        latent_lengths = sample_lengths // self.hop_size
    processed["latent_lengths"] = latent_lengths
    processed["latents_sampled"] = self.core.io_helper.sample_from_latent(latents)
    processed["xvector"] = self.xvector_extractor(
        sample,
        audio_lengths=sample_lengths,
        fbank=data.get("fbank"),
        fbank_lengths=data.get("fbank_lengths"),
    )
    return processed
+
+ def _build_audio_span_mask(self, token_ids: torch.Tensor) -> torch.Tensor:
+ span_mask = torch.zeros_like(token_ids, dtype=torch.bool)
+ for token_id in self.core.audio_span_token_ids:
+ span_mask = span_mask | (token_ids == token_id)
+ return span_mask
+
+ def _prepare_loss_metadata(self, data: dict[str, Any]) -> dict[str, Any]:
+ input_ids: torch.Tensor = data["input_ids"]
+ labels: torch.Tensor = data["labels"]
+ loss_mask: torch.Tensor = data["loss_mask"]
+ input_span_mask = self._build_audio_span_mask(input_ids)
+ output_span_mask = self._build_audio_span_mask(labels)
+ output_span_mask_float = output_span_mask.to(loss_mask.dtype)
+ llm_loss_mask = loss_mask * (1.0 - output_span_mask_float)
+ fm_loss_mask = loss_mask * output_span_mask_float
+ patch_counts = output_span_mask.sum(dim=1)
+ max_patch_count = max(1, int(patch_counts.max().item()))
+ fm_patch_mask = loss_mask.new_zeros((loss_mask.size(0), max_patch_count))
+ for batch_idx in range(loss_mask.size(0)):
+ count = int(patch_counts[batch_idx].item())
+ if count <= 0:
+ continue
+ fm_patch_mask[batch_idx, :count] = fm_loss_mask[batch_idx].masked_select(
+ output_span_mask[batch_idx]
+ )
+
+ return {
+ "input_span_mask": input_span_mask,
+ "output_span_mask": output_span_mask,
+ "loss_masks": {
+ "ce_loss": llm_loss_mask,
+ "fm_loss": fm_patch_mask,
+ "eos_loss": self._build_eos_loss_mask(fm_loss_mask),
+ },
+ }
+
+ @staticmethod
+ def _build_eos_loss_mask(eos_loss_mask: torch.Tensor) -> torch.Tensor:
+ batch_size, seq_len = eos_loss_mask.shape
+ mask = eos_loss_mask.to(dtype=torch.bool)
+ target = torch.zeros((batch_size, seq_len), dtype=torch.bool, device=mask.device)
+ mask_counts = mask.sum(dim=1, keepdim=True)
+ cumulative = mask.long().cumsum(dim=1)
+ target[mask & (cumulative == mask_counts)] = True
+
+ mask_counts_flat = mask_counts.squeeze(1)
+ neg_counts = (mask_counts_flat - 1).clamp_min(0).to(eos_loss_mask.dtype)
+ pos_weight = torch.where(
+ neg_counts > 0,
+ torch.full_like(neg_counts, 0.5),
+ torch.ones_like(neg_counts),
+ ).unsqueeze(1)
+ neg_weight = torch.where(
+ neg_counts > 0,
+ 0.5 / neg_counts,
+ torch.zeros_like(neg_counts),
+ ).unsqueeze(1)
+
+ positive_mask = target & mask
+ negative_mask = mask & ~positive_mask
+ return torch.where(
+ positive_mask,
+ pos_weight,
+ negative_mask.to(eos_loss_mask.dtype) * neg_weight,
+ )
+ # endregion Training batch preparation
+
+ # region Training loss assembly and forward
@staticmethod
def _compute_ce_loss_term(
    llm_logits: torch.Tensor,
    llm_labels: torch.Tensor,
    llm_loss_mask: torch.Tensor,
) -> LossTerm:
    """Compute the unreduced token-level cross-entropy term.

    Args:
        llm_logits: Logits whose last axis is the vocabulary.
        llm_labels: Target token ids.
        llm_loss_mask: Per-token weights for the loss.

    Returns:
        ``LossTerm`` holding the per-token loss (shaped like ``llm_labels``)
        and the mask cast to the loss dtype.
    """
    flat_logits = llm_logits.view(-1, llm_logits.size(-1))
    flat_labels = llm_labels.view(-1)
    per_token = F.cross_entropy(flat_logits, flat_labels, reduction="none")
    per_token = per_token.view_as(llm_labels)
    return LossTerm(loss=per_token, mask=llm_loss_mask.to(per_token.dtype))
+
@staticmethod
def _compute_fm_loss_term(
    pred: torch.Tensor,
    target: torch.Tensor,
    fm_patch_mask: torch.Tensor,
) -> LossTerm:
    """Arrange flat per-patch flow-matching losses into a (batch, patch) grid.

    ``pred`` and ``target`` are reduced to one mean-squared-error value per
    leading index. Those values are expected to correspond, in row-major
    order, to the entries of ``fm_patch_mask`` that are greater than zero.

    Each loss is written to the slot of the mask entry it belongs to, so
    ``loss * mask`` stays aligned even when a row's positive mask entries are
    not a contiguous prefix (for example when leading patches carry zero
    weight).

    Args:
        pred: Predicted tensor; axes 1 and 2 are averaged away.
        target: Regression target with the same shape as ``pred``.
        fm_patch_mask: 2-D per-patch weight matrix.

    Returns:
        ``LossTerm`` with a loss grid shaped like ``fm_patch_mask`` and the
        mask cast to the loss dtype.

    Raises:
        RuntimeError: If the mask has positive entries but their count
            differs from the number of per-patch losses.
    """
    batch_size, max_patch_count = fm_patch_mask.shape
    fm_loss = (pred - target).pow(2).mean(dim=2).mean(dim=1)
    loss = fm_loss.new_zeros((batch_size, max_patch_count))

    active = fm_patch_mask.gt(0).to(device=loss.device)
    expected_count = int(active.sum().item())
    if expected_count > 0 and int(fm_loss.numel()) != expected_count:
        raise RuntimeError(
            "Flow-matching loss count mismatch: "
            f"expected {expected_count}, got {int(fm_loss.numel())}."
        )

    if expected_count > 0:
        # Boolean-mask assignment fills the True slots in row-major order,
        # i.e. sample by sample and patch by patch within each sample.
        loss[active] = fm_loss
    return LossTerm(loss=loss, mask=fm_patch_mask.to(loss.dtype))
+
@staticmethod
def _compute_eos_loss_term(
    eos_out: torch.Tensor,
    eos_loss_mask: torch.Tensor,
) -> LossTerm:
    """Compute the unreduced EOS classification term.

    The target is class 1 at the last position of each row whose weight is
    positive, and class 0 everywhere else.

    Args:
        eos_out: 3-D EOS logits; the last axis holds the classes.
        eos_loss_mask: Per-position weights; positive entries are supervised.

    Returns:
        ``LossTerm`` with the per-position cross-entropy and the weights cast
        to the loss dtype.
    """
    batch_size, seq_len, _ = eos_out.shape
    weights = eos_loss_mask.to(device=eos_out.device)
    supervised = weights.gt(0)
    supervised_per_row = supervised.sum(dim=1, keepdim=True)
    # The running count reaches the row total exactly at the last supervised slot.
    is_last_supervised = supervised & (
        supervised.long().cumsum(dim=1) == supervised_per_row
    )

    target = torch.zeros(
        (batch_size, seq_len),
        dtype=torch.long,
        device=eos_out.device,
    )
    target[is_last_supervised] = 1

    # cross_entropy expects the class axis in position 1.
    class_first_logits = rearrange(eos_out, "b n c -> b c n")
    per_token = F.cross_entropy(class_first_logits, target, reduction="none")
    return LossTerm(loss=per_token, mask=weights.to(per_token.dtype))
+
@staticmethod
def _compute_eos_loss_stats(
    eos_out: torch.Tensor,
    eos_loss_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Reduce the EOS loss to per-row sums plus a per-row supervision flag.

    Args:
        eos_out: EOS logits passed to ``_compute_eos_loss_term``.
        eos_loss_mask: Raw EOS mask; balanced through ``_build_eos_loss_mask``
            before the loss term is computed.

    Returns:
        ``(eos_loss_sum, eos_sample_count)``: the weighted loss summed along
        axis 1, and a flag in the loss dtype that is 1 for rows with any
        positive mask entry and 0 otherwise.
    """
    balanced_weights = DotsTtsModel._build_eos_loss_mask(eos_loss_mask)
    term = DotsTtsModel._compute_eos_loss_term(eos_out, balanced_weights)
    per_token = term.loss
    weights = term.mask.to(device=per_token.device, dtype=per_token.dtype)
    eos_loss_sum = (per_token * weights).sum(dim=1)

    row_has_supervision = eos_loss_mask.to(device=per_token.device).gt(0).any(dim=1)
    eos_sample_count = row_has_supervision.to(per_token.dtype)
    return eos_loss_sum, eos_sample_count
+
+ def _compute_loss_terms(
+ self,
+ outputs: DotsTtsForwardOutput,
+ *,
+ labels: torch.Tensor,
+ loss_masks: LossMasks,
+ ) -> LossTerms:
+ return {
+ "ce_loss": self._compute_ce_loss_term(
+ outputs.llm_logits,
+ labels,
+ loss_masks["ce_loss"],
+ ),
+ "fm_loss": self._compute_fm_loss_term(
+ outputs.pred,
+ outputs.target,
+ loss_masks["fm_loss"],
+ ),
+ "eos_loss": self._compute_eos_loss_term(
+ outputs.eos_out,
+ loss_masks["eos_loss"],
+ ),
+ }
+
def prepare_training_batch(self, data: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of ``data`` extended with the loss metadata.

    The caller's dictionary is not modified; the metadata is computed from,
    and merged into, a shallow copy.
    """
    batch = dict(data)
    for key, value in self._prepare_loss_metadata(batch).items():
        batch[key] = value
    return batch
+
def forward(self, data: dict[str, Any]) -> LossTerms:
    """Run the training forward pass and return the unreduced loss terms.

    Args:
        data: Batch that already carries ``loss_masks``, ``input_span_mask``
            and ``output_span_mask``; ``labels`` must be present after
            ``prepare_training_inputs``.

    Returns:
        The loss terms produced by ``_compute_loss_terms``.
    """
    loss_masks: LossMasks = data["loss_masks"]
    processed = self.prepare_training_inputs(data)
    # Carry the span masks over from the incoming batch explicitly.
    for key in ("input_span_mask", "output_span_mask"):
        processed[key] = data[key]

    outputs = self.core(processed)
    return self._compute_loss_terms(
        outputs,
        labels=processed["labels"],
        loss_masks=loss_masks,
    )
+ # endregion Training loss assembly and forward
+
+ # region Prompt conditioning and decode state helpers
@torch.no_grad()
def _prepare_prompt_conditioning(
    self,
    prompt_audio: torch.Tensor | None,
    *,
    use_prompt_prefill: bool,
    speaker_scale: float = 1.5,
) -> _PromptConditioning:
    """Derive speaker and (optionally) latent conditioning from prompt audio.

    Args:
        prompt_audio: Prompt waveform, 1-D or 2-D, or ``None`` for no prompt.
        use_prompt_prefill: When true, the prompt is also encoded to latents
            and patches for prefill; otherwise only ``g_cond`` is produced.
        speaker_scale: Multiplier applied to the speaker-encoder output.

    Returns:
        An empty ``_PromptConditioning`` without prompt audio, one carrying
        only ``g_cond`` without prefill, or one carrying ``prompt_patches``,
        ``prompt_latents`` and ``g_cond``.
    """
    if prompt_audio is None:
        logger.info("Prompt conditioning skipped: no prompt audio provided.")
        return _PromptConditioning()

    self.vocoder.eval()
    self.xvector_extractor.eval()
    device = next(self.core.parameters()).device
    if prompt_audio.ndim == 1:
        prompt_audio = prompt_audio.unsqueeze(0)
    prompt_audio = prompt_audio.to(device=device)

    # Right-pad with zeros up to a whole number of patches.
    samples_per_patch = self.config.patch_size * self.hop_size
    padded_len = math.ceil(prompt_audio.size(1) / samples_per_patch) * samples_per_patch
    pad_len = padded_len - prompt_audio.size(1)
    if pad_len > 0:
        prompt_audio = F.pad(prompt_audio, (0, pad_len))

    speaker_encoder = self._get_compiled_model(
        "speaker_encoder",
        self.xvector_extractor,
    )
    with measure_inference("speaker_encoder"):
        speaker_embedding = speaker_encoder(prompt_audio[None, :]) * float(
            speaker_scale
        )
    g_cond = self.core.xvec_proj(speaker_embedding)

    if not use_prompt_prefill:
        logger.info(
            "Reference-audio-only conditioning prepared: prompt_samples={} speaker_scale={} device={}",
            prompt_audio.shape[-1],
            speaker_scale,
            device,
        )
        return _PromptConditioning(g_cond=g_cond)

    latent_encoder = self._get_compiled_model(
        "latent_encoder",
        self.vocoder.extract_latents,
    )
    with measure_inference("latent_encoder"):
        prompt_latents = latent_encoder(prompt_audio[None, :])
    sampled_latents = self.core.io_helper.sample_from_latent(prompt_latents)
    # Drop the trailing `patch_size` frames along axis 1.
    sampled_latents = sampled_latents[:, : -self.config.patch_size]
    prompt_patches = rearrange(
        self.core.io_helper.normalize(sampled_latents),
        "b (s p) d -> b s p d",
        p=self.config.patch_size,
    )
    logger.info(
        "Prompt conditioning prepared: prompt_samples={} prompt_patch_count={} "
        "speaker_scale={} device={}",
        prompt_audio.shape[-1],
        prompt_patches.size(1),
        speaker_scale,
        device,
    )
    return _PromptConditioning(
        prompt_patches=prompt_patches,
        prompt_latents=sampled_latents,
        g_cond=g_cond,
    )
+
+ @staticmethod
+ def _patch_encoder_compile_signature(
+ patch_encoder_state: Any,
+ ) -> tuple[int, torch.dtype]:
+ key_cache, _ = patch_encoder_state.layer_caches[0]
+ return int(key_cache.size(2)), key_cache.dtype
+
+ def _resolve_patch_encoder_audio_bucket(self, required_seq_len: int) -> int:
+ requested = int(required_seq_len)
+ if requested <= 0:
+ raise ValueError("required_seq_len must be positive.")
+ requested_patch_count = math.ceil(
+ requested / self.core.patch_encoder.out_ds_rate
+ )
+ if not self._optimize_enabled:
+ return requested_patch_count
+ return self._resolve_generate_length_bucket(requested_patch_count).size
+
+ def _copy_patch_encoder_state(self, source: Any, target: Any) -> None:
+ seq_len = source.seq_len
+ target_capacity = int(target.layer_caches[0][0].size(2))
+ if seq_len > target_capacity:
+ raise ValueError(
+ "Patch encoder state copy exceeds target capacity: "
+ f"seq_len={seq_len} capacity={target_capacity}."
+ )
+
+ target.conv_tail.copy_(source.conv_tail)
+ target.seq_len = seq_len
+ for (source_key, source_value), (target_key, target_value) in zip(
+ source.layer_caches,
+ target.layer_caches,
+ strict=True,
+ ):
+ if seq_len > 0:
+ target_key[:, :, :seq_len, :].copy_(source_key[:, :, :seq_len, :])
+ target_value[:, :, :seq_len, :].copy_(source_value[:, :, :seq_len, :])
+
+ def _ensure_patch_encoder_state_capacity(
+ self,
+ state: _GenerateState,
+ *,
+ required_seq_len: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ ) -> None:
+ current_state = state.patch_encoder_state
+ if current_state is not None:
+ current_capacity = int(current_state.layer_caches[0][0].size(2))
+ if current_capacity >= required_seq_len:
+ return
+
+ target_audio_patch_count = self._resolve_patch_encoder_audio_bucket(
+ required_seq_len
+ )
+ next_state = self.core.patch_encoder.init_decode_state(
+ max_audio_patch_count=target_audio_patch_count,
+ batch_size=1,
+ device=device,
+ dtype=dtype,
+ )
+ if current_state is not None:
+ self._copy_patch_encoder_state(current_state, next_state)
+ state.patch_encoder_state = next_state
+
def _prefill_prompt_latents(
    self,
    prompt_latents: torch.Tensor | None,
    *,
    state: _GenerateState,
) -> torch.Tensor | None:
    """Run the patch encoder over prompt latents and keep its decode state.

    Args:
        prompt_latents: Prompt latents, or ``None`` when there is no prompt.
        state: Generation state; its ``patch_encoder_state`` is grown if
            needed and replaced by the state returned from the prefill.

    Returns:
        ``None`` without prompt latents, an empty
        ``(batch, 0, llm_hidden_size)`` tensor when axis 1 of the latents is
        empty, otherwise the embeddings returned by the patch encoder.
    """
    if prompt_latents is None:
        return None
    if prompt_latents.size(1) == 0:
        empty_shape = (prompt_latents.size(0), 0, self.core.llm_hidden_size)
        return prompt_latents.new_zeros(empty_shape)

    whole_patches = prompt_latents.size(1) // self.core.patch_encoder.patch_size
    required_seq_len = whole_patches * self.core.patch_encoder.out_ds_rate
    if state.fm_sequence is not None:
        cache_dtype = state.fm_sequence.dtype
    else:
        cache_dtype = prompt_latents.dtype
    self._ensure_patch_encoder_state_capacity(
        state,
        required_seq_len=required_seq_len,
        device=prompt_latents.device,
        dtype=cache_dtype,
    )

    with measure_inference("patch_encoder"):
        embeddings, state.patch_encoder_state = self.core.patch_encoder.prefill(
            prompt_latents,
            state.patch_encoder_state,
        )
    return embeddings
+
+ def _get_fm_decode_workspace(
+ self,
+ *,
+ total_len: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ ) -> dict[str, torch.Tensor]:
+ workspace_key = (total_len, str(device), dtype)
+ workspace = self._fm_decode_workspaces.get(workspace_key)
+ if workspace is None:
+ workspace = {
+ "input_sequence": torch.zeros(
+ (1, total_len, self.core.fm_hidden_size),
+ dtype=dtype,
+ device=device,
+ ),
+ "cfg_sequence": torch.zeros(
+ (1, total_len, self.core.fm_hidden_size),
+ dtype=dtype,
+ device=device,
+ ),
+ "attn_mask": torch.zeros(
+ (1, total_len, total_len),
+ dtype=torch.bool,
+ device=device,
+ ),
+ "pos_ids": torch.zeros(
+ (1, total_len),
+ dtype=torch.float32,
+ device=device,
+ ),
+ }
+ self._fm_decode_workspaces[workspace_key] = workspace
+ else:
+ workspace["input_sequence"].zero_()
+ workspace["cfg_sequence"].zero_()
+ return workspace
+
+ def _resolve_fm_history_bucket_capacity(self, fm_seq_len: int) -> int:
+ requested = int(fm_seq_len)
+ if requested <= 0:
+ raise ValueError("fm_seq_len must be positive.")
+ if not self._optimize_enabled:
+ return requested
+ history_stride = self.core.hidden_patch_size + self.core.latent_patch_size
+ requested_patch_count = math.ceil(requested / history_stride)
+ return self._resolve_generate_length_bucket(
+ requested_patch_count
+ ).size * history_stride
+
+ def _build_fm_attn_mask(
+ self,
+ *,
+ state: _GenerateState,
+ attn_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ if state.fm_seq_len <= 0:
+ raise RuntimeError("FM sequence length must be positive before decode.")
+ hidden_patch_size = self.core.hidden_patch_size
+ latent_start = attn_mask.size(-1) - self.core.latent_patch_size
+ attn_mask.zero_()
+ block_start = state.fm_seq_len - hidden_patch_size
+ if block_start > 0:
+ causal_mask = torch.ones(
+ (block_start, block_start),
+ device=attn_mask.device,
+ dtype=torch.bool,
+ ).triu(1).logical_not()
+ attn_mask[:, :block_start, :block_start] = causal_mask
+
+ attn_mask[:, block_start : state.fm_seq_len, : state.fm_seq_len] = True
+ attn_mask[:, block_start : state.fm_seq_len, latent_start:] = True
+ attn_mask[:, latent_start:, : state.fm_seq_len] = True
+ attn_mask[:, latent_start:, latent_start:] = True
+ if latent_start > state.fm_seq_len:
+ padding_indices = torch.arange(
+ state.fm_seq_len,
+ latent_start,
+ device=attn_mask.device,
+ )
+ attn_mask[:, padding_indices, padding_indices] = True
+ return attn_mask
+
+ def _build_fm_pos_ids(
+ self,
+ *,
+ state: _GenerateState,
+ pos_ids: torch.Tensor,
+ ) -> torch.Tensor:
+ if state.fm_seq_len <= 0:
+ raise RuntimeError("FM sequence length must be positive before decode.")
+ pos_ids.zero_()
+ latent_start = pos_ids.size(-1) - self.core.latent_patch_size
+ pos_ids[:, : state.fm_seq_len] = torch.arange(
+ state.fm_seq_len,
+ device=pos_ids.device,
+ dtype=pos_ids.dtype,
+ )
+ pos_ids[:, latent_start:] = torch.arange(
+ state.fm_seq_len,
+ state.fm_seq_len + self.core.latent_patch_size,
+ device=pos_ids.device,
+ dtype=pos_ids.dtype,
+ )
+ return pos_ids
+
+ def _prepare_fm_decode_inputs(
+ self,
+ state: _GenerateState,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
+ sequence = state.fm_sequence
+ cfg_sequence = state.fm_cfg_sequence
+ if sequence is None or cfg_sequence is None:
+ raise RuntimeError("FM static buffers are not initialized.")
+ history_bucket_capacity = self._resolve_fm_history_bucket_capacity(
+ state.fm_seq_len
+ )
+ total_len = history_bucket_capacity + self.core.latent_patch_size
+ workspace = self._get_fm_decode_workspace(
+ total_len=total_len,
+ device=sequence.device,
+ dtype=sequence.dtype,
+ )
+ workspace["input_sequence"][:, : state.fm_seq_len].copy_(
+ sequence[:, : state.fm_seq_len]
+ )
+ workspace["cfg_sequence"][:, : state.fm_seq_len].copy_(
+ cfg_sequence[:, : state.fm_seq_len]
+ )
+ return (
+ workspace["input_sequence"],
+ workspace["cfg_sequence"],
+ workspace["attn_mask"],
+ workspace["pos_ids"],
+ history_bucket_capacity,
+ )
+
+ def _append_to_fm_buffer(
+ self,
+ buffer: torch.Tensor | None,
+ state: _GenerateState,
+ chunk: torch.Tensor,
+ ) -> tuple[int, int]:
+ if buffer is None:
+ raise RuntimeError("FM static buffer is not initialized.")
+ start = state.fm_seq_len
+ end = start + chunk.size(1)
+ if end > state.fm_capacity:
+ raise RuntimeError(
+ "FM StaticBuffer capacity exceeded: "
+ f"next_length={end} capacity={state.fm_capacity}."
+ )
+ buffer[:, start:end].copy_(chunk.to(buffer.dtype))
+ return start, end
+
+ def _append_hidden_chunk(
+ self, state: _GenerateState, hidden_chunk: torch.Tensor
+ ) -> None:
+ last_hidden = hidden_chunk[:, -self.core.hidden_patch_size :, :]
+ projected = self.core.hidden_proj(last_hidden)
+ null_projected = self.core.hidden_proj(torch.zeros_like(last_hidden))
+ _start, end = self._append_to_fm_buffer(
+ state.fm_sequence,
+ state,
+ projected,
+ )
+ cfg_buffer = state.fm_cfg_sequence
+ if cfg_buffer is None:
+ raise RuntimeError("FM cfg static buffer is not initialized.")
+ cfg_buffer[:, state.fm_seq_len : end].copy_(null_projected.to(cfg_buffer.dtype))
+ state.fm_seq_len = end
+
+ def _append_history_chunk(
+ self, state: _GenerateState, latent_chunk: torch.Tensor
+ ) -> None:
+ history_latent = self.core.latent_proj(latent_chunk)
+ _start, end = self._append_to_fm_buffer(
+ state.fm_sequence,
+ state,
+ history_latent,
+ )
+ cfg_buffer = state.fm_cfg_sequence
+ if cfg_buffer is None:
+ raise RuntimeError("FM cfg static buffer is not initialized.")
+ cfg_buffer[:, state.fm_seq_len : end].copy_(history_latent.to(cfg_buffer.dtype))
+ state.fm_seq_len = end
+
def _consume_text_schedule(
    self,
    generation_schedule: torch.Tensor,
    *,
    position: int,
    next_audio_position: int,
    state: _GenerateState,
) -> int:
    """Feed the schedule tokens in ``[position, next_audio_position)`` to the LLM.

    Updates ``state.llm_hiddens`` and ``state.llm_cache`` from the LLM step
    and appends the resulting hidden chunk to the FM history.

    Args:
        generation_schedule: Token schedule indexed ``(batch, position)``.
        position: First schedule position to consume.
        next_audio_position: Exclusive end of the slice to consume.
        state: Generation state to update.

    Returns:
        ``next_audio_position``, i.e. the new schedule cursor.
    """
    with measure_inference("LLM"):
        token_slice = generation_schedule[:, position:next_audio_position]
        step_result = self.core.step_llm(
            input_ids=token_slice,
            past_key_values=state.llm_cache,
        )
        _, state.llm_hiddens, _, state.llm_cache = step_result
    self._append_hidden_chunk(state, state.llm_hiddens)
    return next_audio_position
+
+ def _locate_prefill_boundary(
+ self,
+ *,
+ span_positions: torch.Tensor,
+ prompt_patch_count: int,
+ ) -> tuple[int, torch.Tensor]:
+ if span_positions.numel() > prompt_patch_count:
+ return int(span_positions[prompt_patch_count].item()), span_positions[
+ :prompt_patch_count
+ ]
+ raise RuntimeError(
+ "Prefill boundary discovery failed despite prior schedule validation."
+ )
+
+ @staticmethod
+ def _find_audio_span_positions(
+ generation_schedule: torch.Tensor,
+ *,
+ audio_placeholder_ids: set[int],
+ ) -> torch.Tensor:
+ schedule = generation_schedule[0]
+ placeholder_ids = torch.tensor(
+ sorted(audio_placeholder_ids),
+ device=schedule.device,
+ dtype=schedule.dtype,
+ )
+ return torch.nonzero(
+ torch.isin(schedule, placeholder_ids),
+ as_tuple=False,
+ ).squeeze(-1)
+
+ @staticmethod
+ def _next_token_is_audio_span(
+ generation_schedule: torch.Tensor,
+ *,
+ position: int,
+ audio_placeholder_ids: set[int],
+ ) -> bool:
+ next_position = position + 1
+ if next_position >= generation_schedule.size(1):
+ return False
+ return int(generation_schedule[0, next_position].item()) in audio_placeholder_ids
+
+ def _build_prefill_inputs_embeds(
+ self,
+ generation_schedule: torch.Tensor,
+ *,
+ prompt_patch_embeddings: torch.Tensor | None,
+ prompt_span_positions: torch.Tensor,
+ ) -> torch.Tensor:
+ inputs_embeds = self.core.llm.get_input_embeddings()(
+ generation_schedule
+ ).clone()
+ if prompt_span_positions.numel() > 0:
+ if prompt_patch_embeddings is None:
+ raise RuntimeError(
+ "Prompt patch embeddings are required when prefill includes prompt audio spans."
+ )
+ patch_embeddings = prompt_patch_embeddings[
+ :, : prompt_span_positions.numel()
+ ].to(inputs_embeds.dtype)
+ if patch_embeddings.size(1) != prompt_span_positions.numel():
+ raise RuntimeError(
+ f"Prompt patch embeddings ({patch_embeddings.size(1)}) do not match prompt span count ({prompt_span_positions.numel()})."
+ )
+ inputs_embeds[:, prompt_span_positions, :] = patch_embeddings
+ return inputs_embeds
+
def _prefill(
    self,
    generation_schedule: torch.Tensor,
    *,
    state: _GenerateState,
    span_positions: torch.Tensor,
    prompt_patches: torch.Tensor | None,
    prompt_patch_embeddings: torch.Tensor | None,
    audio_placeholder_ids: set[int],
) -> int:
    """Run the LLM over the schedule prefix and seed the FM history from it.

    The prefix ends at the first audio span that is not covered by the
    prompt. After one LLM step over that prefix, the FM history is built by
    interleaving LLM hidden states with the prompt patches in schedule order.

    Args:
        generation_schedule: Full token schedule.
        state: Generation state to update (LLM cache, hiddens, FM history).
        span_positions: Schedule positions of all audio spans.
        prompt_patches: Prompt patches indexed along axis 1, or ``None``.
        prompt_patch_embeddings: LLM-side embeddings of the prompt patches.
        audio_placeholder_ids: Token ids that mark an audio span.

    Returns:
        The schedule position at which decoding should start. ``0`` means
        nothing was prefilled.
    """
    if prompt_patches is None:
        prompt_patch_count = 0
    else:
        prompt_patch_count = int(prompt_patches.size(1))
    prefill_end, prompt_span_positions = self._locate_prefill_boundary(
        span_positions=span_positions,
        prompt_patch_count=prompt_patch_count,
    )
    if prefill_end == 0:
        return 0

    inputs_embeds = self._build_prefill_inputs_embeds(
        generation_schedule[:, :prefill_end],
        prompt_patch_embeddings=prompt_patch_embeddings,
        prompt_span_positions=prompt_span_positions,
    )
    with measure_inference("LLM"):
        _, llm_hiddens, _, state.llm_cache = self.core.step_llm(
            inputs_embeds=inputs_embeds,
            past_key_values=state.llm_cache,
        )
    state.llm_hiddens = llm_hiddens[:, -1:, :]

    def hidden_at(index: int) -> torch.Tensor:
        # Single-step slice that keeps the step axis.
        return llm_hiddens[:, index : index + 1, :]

    cursor = 0
    for prompt_index, span_position in enumerate(prompt_span_positions.tolist()):
        if span_position > cursor:
            # Tokens precede this span: condition on the hidden state just before it.
            self._append_hidden_chunk(state, hidden_at(span_position - 1))
        self._append_history_chunk(state, prompt_patches[:, prompt_index])
        if self._next_token_is_audio_span(
            generation_schedule,
            position=span_position,
            audio_placeholder_ids=audio_placeholder_ids,
        ):
            # Back-to-back spans: the span's own hidden state conditions the next one.
            self._append_hidden_chunk(state, hidden_at(span_position))
        cursor = span_position + 1

    if prefill_end > cursor:
        self._append_hidden_chunk(state, hidden_at(prefill_end - 1))
    return prefill_end
+
def _decode_next_audio(
    self,
    state: _GenerateState,
    *,
    device: torch.device,
    g_cond: torch.Tensor | None,
    ode_method: str,
    num_steps: int,
    guidance_scale: float,
) -> torch.Tensor:
    """Generate the next audio patch with the flow-matching head.

    Dispatches to the mean-flow path when ``self.core.mode == "meanflow"``
    and to the flow-matching path otherwise.

    Args:
        state: Generation state holding the FM history buffers.
        device: Accepted for interface compatibility; not used in this method.
        g_cond: Global conditioning; ``state.fm_null_g_cond`` is used when
            ``None``, otherwise the value is moved to that buffer's device
            and dtype.
        ode_method: Solver name (flow-matching path only).
        num_steps: Number of solver steps (``nfe`` on the mean-flow path).
        guidance_scale: Guidance scale (flow-matching path only).

    Returns:
        The tensor returned by the selected core step function.

    Raises:
        RuntimeError: If the FM history is empty or a required buffer is
            missing.
    """
    if state.fm_seq_len <= 0:
        raise RuntimeError(
            "Cannot decode audio before any conditioning state has been prefetched."
        )
    if state.fm_sequence is None or state.fm_cfg_sequence is None:
        raise RuntimeError("FM static buffers are not initialized.")
    null_g_cond = state.fm_null_g_cond
    if null_g_cond is None:
        raise RuntimeError("FM null conditioning buffer is not initialized.")

    (
        fm_sequence,
        fm_cfg_sequence,
        attn_workspace,
        pos_workspace,
        history_bucket_capacity,
    ) = self._prepare_fm_decode_inputs(state)

    # Compiled steps are keyed by the bucketed length when optimization is
    # enabled, and by the exact history length otherwise.
    if self._optimize_enabled:
        signature_length = history_bucket_capacity
    else:
        signature_length = state.fm_seq_len
    compile_signature = (signature_length, state.fm_sequence.dtype)

    if g_cond is None:
        g_cond = null_g_cond
    else:
        g_cond = g_cond.to(device=null_g_cond.device, dtype=null_g_cond.dtype)

    with measure_inference("FM"):
        attn_mask = self._build_fm_attn_mask(state=state, attn_mask=attn_workspace)
        pos_ids = self._build_fm_pos_ids(state=state, pos_ids=pos_workspace)

        if self.core.mode == "meanflow":
            solver_step = self._get_compiled_method(
                "FM.meanflow.solver_step",
                self.core,
                "meanflow_solver_step",
                signature=compile_signature,
            )
            return self.core._meanflow_step_fm(
                input_sequence=fm_sequence,
                attn_mask=attn_mask,
                pos_ids=pos_ids,
                patch_size=self.core.latent_patch_size,
                g_cond=g_cond,
                nfe=num_steps,
                solver_step=solver_step,
            )

        solver_step = self._get_compiled_method(
            "FM.flow_matching.solver_step",
            self.core,
            "fm_solver_step",
            signature=compile_signature,
        )
        return self.core._flow_matching_step_fm(
            input_sequence=fm_sequence,
            cfg_sequence=fm_cfg_sequence,
            attn_mask=attn_mask,
            pos_ids=pos_ids,
            hidden_size=self.core.hidden_patch_size,
            patch_size=self.core.latent_patch_size,
            g_cond=g_cond,
            ode_method=ode_method,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            solver_step=solver_step,
        )
+
def _consume_audio_patch(
    self,
    state: _GenerateState,
    *,
    audio_patch: torch.Tensor,
) -> None:
    """Feed a freshly generated audio patch back into the model state.

    The patch is appended to the FM history, encoded by the patch encoder
    (growing its decode state if needed), and the resulting embedding is run
    through one LLM step to refresh ``state.llm_hiddens`` / ``llm_cache``.

    Args:
        state: Generation state to update.
        audio_patch: Generated patch; it is denormalized before being passed
            to the patch encoder.
    """
    patch_for_llm = self.core.io_helper.denormalize(audio_patch)
    self._append_history_chunk(state, audio_patch)

    if state.patch_encoder_state is None:
        consumed_len = 0
    else:
        consumed_len = state.patch_encoder_state.seq_len
    if state.fm_sequence is not None:
        cache_dtype = state.fm_sequence.dtype
    else:
        cache_dtype = patch_for_llm.dtype
    self._ensure_patch_encoder_state_capacity(
        state,
        required_seq_len=consumed_len + self.core.patch_encoder.out_ds_rate,
        device=patch_for_llm.device,
        dtype=cache_dtype,
    )

    # Re-read the state: the capacity check may have replaced it.
    encoder_state = state.patch_encoder_state
    decode_patch = self._get_compiled_method(
        "patch_encoder.decode_patch",
        self.core.patch_encoder,
        "decode_patch",
        signature=self._patch_encoder_compile_signature(encoder_state),
    )
    step_offsets = torch.arange(
        self.core.patch_encoder.out_ds_rate,
        device=patch_for_llm.device,
        dtype=torch.long,
    )
    patch_positions = step_offsets + encoder_state.seq_len

    with measure_inference("patch_encoder"):
        llm_embedding, conv_tail = decode_patch(
            patch_for_llm,
            encoder_state.conv_tail,
            encoder_state.layer_caches,
            patch_positions,
        )
    encoder_state.conv_tail.copy_(conv_tail)
    encoder_state.seq_len += self.core.patch_encoder.out_ds_rate

    with measure_inference("LLM"):
        _, state.llm_hiddens, _, state.llm_cache = self.core.step_llm(
            inputs_embeds=llm_embedding,
            past_key_values=state.llm_cache,
        )
+
+ def _decode(
+ self,
+ generation_schedule: torch.Tensor,
+ *,
+ position: int,
+ state: _GenerateState,
+ audio_placeholder_ids: set[int],
+ span_positions: torch.Tensor,
+ device: torch.device,
+ g_cond: torch.Tensor | None,
+ ode_method: str,
+ num_steps: int,
+ guidance_scale: float,
+ eos_threshold: float,
+ ) -> Iterator[torch.Tensor]:
+ span_cursor = torch.searchsorted(
+ span_positions,
+ torch.tensor(
+ position,
+ device=span_positions.device,
+ dtype=span_positions.dtype,
+ ),
+ ).item()
+ while position < generation_schedule.size(1):
+ token_id = int(generation_schedule[0, position].item())
+ if token_id in audio_placeholder_ids:
+ stop_after_current_audio = self._should_stop_after_current_audio(
+ state,
+ eos_threshold=eos_threshold,
+ )
+ audio_patch = self._decode_next_audio(
+ state,
+ device=device,
+ g_cond=g_cond,
+ ode_method=ode_method,
+ num_steps=num_steps,
+ guidance_scale=guidance_scale,
+ )
+ self._consume_audio_patch(
+ state,
+ audio_patch=audio_patch,
+ )
+ if self._next_token_is_audio_span(
+ generation_schedule,
+ position=position,
+ audio_placeholder_ids=audio_placeholder_ids,
+ ):
+ self._append_hidden_chunk(state, state.llm_hiddens)
+ position += 1
+ span_cursor += 1
+ yield audio_patch
+ if stop_after_current_audio:
+ state.end_flag = True
+ return
+ continue
+ next_audio_position = (
+ int(span_positions[span_cursor].item())
+ if span_cursor < span_positions.numel()
+ else generation_schedule.size(1)
+ )
+ position = self._consume_text_schedule(
+ generation_schedule,
+ position=position,
+ next_audio_position=next_audio_position,
+ state=state,
+ )
+
+ def _should_stop_after_current_audio(
+ self, state: _GenerateState, *, eos_threshold: float
+ ) -> bool:
+ if state.llm_hiddens is None:
+ return False
+ eos = (
+ self.core.eos_proj(state.llm_hiddens).softmax(dim=-1)[:, -1, 1]
+ > eos_threshold
+ )
+ return state.end_flag or bool(eos.item())
+
+ # endregion Prompt conditioning and decode state helpers
+
+ # region Public generation APIs
@torch.no_grad()
def _generate_latents_stream(
    self,
    data: dict[str, Any],
    *,
    precision: str,
    ode_method: str,
    num_steps: int,
    guidance_scale: float,
    speaker_scale: float = 1.5,
    eos_threshold: float = 0.8,
) -> Iterator[torch.Tensor]:
    """Stream denormalized latent patches for one generation request.

    Prepares prompt conditioning, validates the schedule, prefills the LLM
    and FM history, and then yields each decoded patch after
    denormalization. With prompt prefill, the first decoded patch is
    discarded instead of yielded.

    Args:
        data: Request dictionary; requires ``generation_schedule`` (batch
            size 1) and optionally ``prompt_audio`` / ``prompt_text``.
        precision: Precision name resolved through ``get_dtype``.
        ode_method: Solver name forwarded to decoding.
        num_steps: Solver step count forwarded to decoding.
        guidance_scale: Guidance scale forwarded to decoding.
        speaker_scale: Multiplier for the speaker embedding.
        eos_threshold: EOS probability threshold forwarded to decoding.

    Yields:
        One denormalized latent patch per payload audio span.

    Raises:
        ValueError: If the schedule batch size is not 1 or it has too few
            audio spans for the prompt plus one decode span.
        RuntimeError: If decoding finished without yielding any patch.
    """
    dtype = get_dtype(precision)
    device = next(self.core.parameters()).device
    # Autocast is only enabled for half-precision dtypes on CUDA.
    use_amp = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}
    with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_amp):
        generation_schedule: torch.Tensor = data["generation_schedule"]
        if generation_schedule.size(0) != 1:
            raise ValueError(
                "DotsTtsModel.generate expects batch size 1 for generation_schedule."
            )

        # Prefill needs both the prompt audio and a non-empty prompt text.
        prompt_audio = data.get("prompt_audio")
        use_prompt_prefill = prompt_audio is not None and bool(
            data.get("prompt_text")
        )
        prompt_conditioning = self._prepare_prompt_conditioning(
            prompt_audio,
            use_prompt_prefill=use_prompt_prefill,
            speaker_scale=speaker_scale,
        )
        has_prompt_prefill = prompt_conditioning.prompt_patches is not None
        if has_prompt_prefill:
            prompt_patch_count = int(prompt_conditioning.prompt_patches.size(1))
        else:
            prompt_patch_count = 0

        audio_placeholder_ids = set(self.core.audio_span_token_ids)
        span_positions = self._find_audio_span_positions(
            generation_schedule,
            audio_placeholder_ids=audio_placeholder_ids,
        )
        span_count = int(span_positions.numel())
        minimum_required_spans = prompt_patch_count + 1
        if span_count < minimum_required_spans:
            raise ValueError(
                f"generation_schedule provides {span_count} audio spans, but prompt prefill requires "
                f"{prompt_patch_count} spans and generation requires at least one additional decode span."
            )
        logger.info(
            "Latent generation prepared: schedule_audio_spans={} prompt_patch_count={} "
            "minimum_required_spans={}",
            span_count,
            prompt_patch_count,
            minimum_required_spans,
        )

        state = self._allocate_generate_state(
            max_audio_patch_count=span_count,
            device=device,
            dtype=dtype,
        )
        prompt_patch_embeddings = self._prefill_prompt_latents(
            prompt_conditioning.prompt_latents,
            state=state,
        )
        position = self._prefill(
            generation_schedule,
            state=state,
            span_positions=span_positions,
            prompt_patches=prompt_conditioning.prompt_patches,
            prompt_patch_embeddings=prompt_patch_embeddings,
            audio_placeholder_ids=audio_placeholder_ids,
        )

        patch_stream = self._decode(
            generation_schedule,
            position=position,
            state=state,
            audio_placeholder_ids=audio_placeholder_ids,
            span_positions=span_positions,
            device=device,
            g_cond=prompt_conditioning.g_cond,
            ode_method=ode_method,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            eos_threshold=eos_threshold,
        )
        payload_patch_count = 0
        # With prompt prefill the first decoded patch is not payload.
        drop_next_patch = has_prompt_prefill
        for audio_patch in patch_stream:
            if drop_next_patch:
                drop_next_patch = False
                continue
            payload_patch_count += 1
            if payload_patch_count == 1 or payload_patch_count % 10 == 0:
                logger.info(
                    "Latent generation progress: payload_audio_patches={}",
                    payload_patch_count,
                )
            yield self.core.io_helper.denormalize(audio_patch)

        if payload_patch_count == 0:
            if has_prompt_prefill:
                raise RuntimeError(
                    "Generation produced no payload latents after discarding the regenerated prompt-tail patch. "
                    "This usually means EOS triggered immediately after prompt continuation "
                    "or the generation schedule did not provide an effective decode span."
                )
            raise RuntimeError(
                "Generation produced no decodable latents. "
                "This usually means EOS triggered before the first decode patch "
                "or the generation schedule did not provide an effective decode span."
            )
        logger.info(
            "Latent generation completed: payload_audio_patches={}",
            payload_patch_count,
        )
+
@torch.no_grad()
def _decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
    """Run the vocoder once over a complete latent sequence.

    Dims 1 and 2 of ``latents`` are swapped and the result is cast to float32
    before it is passed to ``self.vocoder.inference_from_latents`` with
    ``do_sample=False``. The work happens inside the
    ``measure_inference("latent_decoder")`` context.
    """
    with measure_inference("latent_decoder"):
        vocoder_input = latents.transpose(1, 2).float()
        waveform = self.vocoder.inference_from_latents(
            vocoder_input,
            do_sample=False,
        )
        return waveform
+
+ @torch.no_grad()
+ def _init_vocoder_stream_state(self) -> Any:
+ return self.vocoder.init_stream_state(
+ batch_size=1,
+ chunk_size=self.core.latent_patch_size,
+ )
+
@torch.no_grad()
def _stream_vocoder_patch(
    self,
    latent_patch: torch.Tensor,
    *,
    stream_state: Any,
) -> torch.Tensor:
    """Push one latent patch through the streaming vocoder and return its audio.

    When ``self._optimize_enabled`` is false the call is delegated to
    ``self.vocoder.stream_step``. Otherwise the compiled step callable is
    invoked with the LSTM hidden tensors and decoder window unpacked from
    ``stream_state``, and the state is updated here from the returned tensors.

    Args:
        latent_patch: Latent patch; dims 1 and 2 are swapped before use.
        stream_state: Vocoder stream state. On the optimized path its
            ``lstm_hidden``, ``decoder.window`` and ``decoder.total_frames``
            are reassigned.

    Returns:
        The audio chunk produced for this patch.
    """
    latents = latent_patch.transpose(1, 2)
    if not self._optimize_enabled:
        with measure_inference("vocoder"):
            return self.vocoder.stream_step(latents, stream_state)

    decoder_state = stream_state.decoder
    # Frames seen so far, capped at the length of the decoder window.
    valid_frames = min(decoder_state.total_frames, decoder_state.window.size(-1))
    valid_frames_tensor = decoder_state.window.new_tensor(valid_frames, dtype=torch.int64)
    compiled_step = self._get_compiled_method(
        "vocoder.step",
        self.vocoder,
        "compiled_stream_step",
    )
    with measure_inference("vocoder"):
        audio_window, hidden_h, hidden_c, new_window = compiled_step(
            latents,
            stream_state.lstm_hidden[0],
            stream_state.lstm_hidden[1],
            stream_state.decoder.window,
            valid_frames_tensor,
        )

    # Every tensor taken from the compiled step is cloned before it is stored
    # or returned — presumably because the compiled callable may reuse its
    # output buffers between calls; TODO confirm.
    stream_state.lstm_hidden = (hidden_h.clone(), hidden_c.clone())
    stream_state.decoder.window = new_window.clone()
    stream_state.decoder.total_frames += int(latents.size(-1))
    audio_chunk = self.vocoder._slice_stream_audio_window(
        audio_window,
        stream_state,
        final=False,
    )
    return audio_chunk.clone()
+
@torch.no_grad()
def _flush_vocoder_stream(self, stream_state: Any) -> torch.Tensor:
    """Return whatever ``self.vocoder.stream_flush`` produces for ``stream_state``.

    The call runs inside the ``measure_inference("vocoder")`` context.
    """
    with measure_inference("vocoder"):
        remaining_audio = self.vocoder.stream_flush(stream_state)
        return remaining_audio
+
@torch.no_grad()
def generate_audio_stream(
    self,
    data: dict[str, Any],
    *,
    precision: str,
    ode_method: str,
    num_steps: int,
    guidance_scale: float,
    speaker_scale: float = 1.5,
    eos_threshold: float = 0.8,
) -> Iterator[torch.Tensor]:
    """Yield audio chunks as latent patches are generated.

    A vocoder stream state is created first; each latent patch coming out of
    ``self._generate_latents_stream`` is then vocoded immediately. Chunks whose
    last dimension is empty are skipped. After the latent stream ends, the
    vocoder is flushed and the remaining audio (if any) is yielded last.

    Args:
        data: Request payload forwarded unchanged to the latent generator.
        precision, ode_method, num_steps, guidance_scale, speaker_scale,
        eos_threshold: Forwarded unchanged to the latent generator.

    Yields:
        Non-empty audio tensors in generation order.
    """
    vocoder_state = self._init_vocoder_stream_state()
    latent_stream = self._generate_latents_stream(
        data,
        precision=precision,
        ode_method=ode_method,
        num_steps=num_steps,
        guidance_scale=guidance_scale,
        speaker_scale=speaker_scale,
        eos_threshold=eos_threshold,
    )
    for latent_patch in latent_stream:
        chunk = self._stream_vocoder_patch(latent_patch, stream_state=vocoder_state)
        if chunk.size(-1) == 0:
            # Nothing audible was produced for this patch yet.
            continue
        yield chunk

    tail = self._flush_vocoder_stream(vocoder_state)
    if tail.size(-1) > 0:
        yield tail
+
@torch.no_grad()
def generate_audio(
    self,
    data: dict[str, Any],
    *,
    precision: str,
    ode_method: str,
    num_steps: int,
    guidance_scale: float,
    speaker_scale: float = 1.5,
    eos_threshold: float | None = None,
) -> torch.Tensor:
    """Generate every latent patch, then vocode them in one pass.

    Args:
        data: Request payload forwarded unchanged to the latent generator.
        precision, ode_method, num_steps, guidance_scale, speaker_scale:
            Forwarded unchanged to ``self._generate_latents_stream``.
        eos_threshold: Optional EOS threshold. It is forwarded only when it is
            not ``None``, so omitting it keeps the latent generator's own
            default, exactly as before this parameter existed. This mirrors
            the knob ``generate_audio_stream`` already exposes.

    Returns:
        The waveform returned by ``self._decode_latents`` for the latent
        patches concatenated along dim 1.
    """
    stream_kwargs: dict[str, Any] = {
        "precision": precision,
        "ode_method": ode_method,
        "num_steps": num_steps,
        "guidance_scale": guidance_scale,
        "speaker_scale": speaker_scale,
    }
    if eos_threshold is not None:
        stream_kwargs["eos_threshold"] = eos_threshold

    latent_patches = list(self._generate_latents_stream(data, **stream_kwargs))
    logger.info(
        "Vocoder decode started: latent_patch_count={}",
        len(latent_patches),
    )
    audio = self._decode_latents(torch.cat(latent_patches, dim=1))
    logger.info(
        "Vocoder decode completed: waveform_samples={}",
        audio.shape[-1],
    )
    return audio
+ # endregion Public generation APIs
diff --git a/src/dots_tts/modules/__init__.py b/src/dots_tts/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/dots_tts/modules/backbone/__init__.py b/src/dots_tts/modules/backbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c175d372db793678414083ad6efd623645551d7
--- /dev/null
+++ b/src/dots_tts/modules/backbone/__init__.py
@@ -0,0 +1 @@
+"""Backbone modules."""
diff --git a/src/dots_tts/modules/backbone/dit.py b/src/dots_tts/modules/backbone/dit.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9673d48bdaf2df5f2fab6a942451205b9a279b0
--- /dev/null
+++ b/src/dots_tts/modules/backbone/dit.py
@@ -0,0 +1,205 @@
+import math
+
+import torch
+import torch.nn as nn
+
+from dots_tts.modules.backbone.layers import Mlp, MultiHeadAttention
+
+
def modulate(x, shift, scale, **_kwargs):
    """Return ``x * (1 + scale) + shift`` with ``shift``/``scale`` broadcast over dim 1.

    A new axis is inserted at position 1 of both ``shift`` and ``scale`` before
    the arithmetic. Extra keyword arguments are accepted and ignored.
    """
    gain = 1 + scale[:, None]
    offset = shift[:, None]
    return x * gain + offset
+
+
class TimestepEmbedder(nn.Module):
    """Map scalar timesteps to vectors: sinusoidal features followed by a small MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        layers = [
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Build ``[cos | sin]`` features of width ``dim`` for a 1-D tensor ``t``.

        The frequencies are ``exp(-log(max_period) * k / half)`` for
        ``k = 0 .. half - 1`` with ``half = dim // 2``. They are computed in
        float32 and then moved to ``t``'s device. When ``dim`` is odd a single
        zero column is appended.
        """
        half = dim // 2
        exponents = torch.arange(start=0, end=half, dtype=torch.float32)
        freqs = torch.exp(-math.log(max_period) * exponents / half).to(device=t.device)
        angles = t[:, None].float() * freqs[None]
        pieces = [torch.cos(angles), torch.sin(angles)]
        if dim % 2:
            pieces.append(torch.zeros_like(angles[:, :1]))
        return torch.cat(pieces, dim=-1)

    def forward(self, t):
        """Embed ``t`` sinusoidally, then project it through the MLP."""
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(features)
+
+
class FinalLayer(nn.Module):
    """Output head: parameter-free LayerNorm, conditional shift/scale, then a linear map."""

    def __init__(self, hidden_size, output_size):
        super().__init__()
        # Produces 2 * hidden_size values per sample: a shift and a scale.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )
        self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-5)
        self.linear = nn.Linear(hidden_size, output_size, bias=True)

    def forward(self, x, c, **_kwargs):
        """Modulate the normalized ``x`` with ``c`` and project it; extra kwargs are ignored."""
        shift, scale = torch.chunk(self.adaLN_modulation(c), 2, dim=1)
        normalized = self.norm(x)
        return self.linear(modulate(normalized, shift, scale))
+
+
class DiTBlock(nn.Module):
    """Pre-norm transformer block with optional adaLN conditioning.

    With ``modulation=True`` the block expects a global ``condition`` vector
    and derives per-sample shift/scale/gate values for the attention and
    feed-forward branches. With ``modulation=False`` it is a plain residual
    block with affine LayerNorms and must be called without a condition.
    """

    def __init__(
        self,
        attention: nn.Module,
        ffn: nn.Module,
        hidden_size: int = 1024,
        modulation: bool = False,
        eps: float = 1e-5,
        **_kwargs,
    ):
        super().__init__()
        # The LayerNorms carry their own affine parameters only when no
        # external modulation supplies shift/scale.
        self.norm1 = nn.LayerNorm(
            hidden_size, elementwise_affine=not modulation, eps=eps
        )
        self.norm2 = nn.LayerNorm(
            hidden_size, elementwise_affine=not modulation, eps=eps
        )
        self.attn = attention
        self.ffn = ffn
        self.modulation = modulation
        if modulation:
            # 6 chunks: (shift, scale, gate) for attention and for the FFN.
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(hidden_size, 6 * hidden_size, bias=True),
            )

    def forward(self, x, condition=None, mask=None, **kwargs):
        """Apply the attention and feed-forward residual branches.

        Args:
            x: Input hidden states.
            condition: Global conditioning vector; required iff the block was
                built with ``modulation=True``.
            mask: Attention mask forwarded to the attention module only.
            **kwargs: Forwarded to the attention module and to ``modulate``.
                ``pack_indices``, when present and not ``None``, is used to
                index the gates instead of broadcasting them over dim 1.

        Returns:
            The updated hidden states.
        """
        if condition is None:
            assert not self.modulation, (
                "Without global condition, must set modulation to False"
            )
            x = x + self.attn(self.norm1(x), mask=mask, **kwargs)
            # The feed-forward is called exactly as in the modulated branch,
            # without a mask: Mlp.forward has no ``mask`` keyword, so passing
            # ``mask=`` here used to raise TypeError.
            return x + self.ffn(self.norm2(x))

        assert self.modulation, "With global condition, must set modulation to True"
        shift_attn, scale_attn, gate_attn, shift_ffn, scale_ffn, gate_ffn = (
            self.adaLN_modulation(condition).chunk(6, dim=1)
        )

        pack_indices = kwargs.get("pack_indices")
        if pack_indices is not None:
            gate_attn = gate_attn[pack_indices]
            gate_ffn = gate_ffn[pack_indices]
        else:
            gate_attn = gate_attn.unsqueeze(1)
            gate_ffn = gate_ffn.unsqueeze(1)

        x = x + gate_attn * self.attn(
            modulate(self.norm1(x), shift_attn, scale_attn, **kwargs),
            mask=mask,
            **kwargs,
        )
        x = x + gate_ffn * self.ffn(
            modulate(self.norm2(x), shift_ffn, scale_ffn, **kwargs)
        )
        return x
+
+
class DiT(nn.Module):
    """Stack of ``DiTBlock`` layers conditioned on a timestep embedding.

    In ``"meanflow"`` mode a second ``TimestepEmbedder`` is created for an
    optional ``duration`` input. An optional global condition ``g_cond`` is
    added to the conditioning vector as well.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        transformer_config,
        *,
        mode: str = "flow_matching",
    ):
        super().__init__()
        supported_modes = {"flow_matching", "meanflow"}
        if mode not in supported_modes:
            raise ValueError(
                f"DiT mode must be 'flow_matching' or 'meanflow', got {mode!r}."
            )

        block_kwargs = transformer_config.to_dict()
        width = transformer_config.hidden_size
        self.mode = mode
        self.num_layers = transformer_config.num_layers

        self.input_layer = nn.Linear(in_dim, width)
        self.time_embedder = TimestepEmbedder(width)
        if mode == "meanflow":
            self.duration_embedder = TimestepEmbedder(width)

        self.blocks = nn.ModuleList()
        for layer_index in range(self.num_layers):
            # Keyword arguments are evaluated left to right, so the attention
            # module is still constructed before the feed-forward module.
            self.blocks.append(
                DiTBlock(
                    attention=MultiHeadAttention(
                        **block_kwargs, name=f"layer_{layer_index}"
                    ),
                    ffn=Mlp(
                        act_layer=lambda: nn.GELU(approximate="tanh"),
                        **block_kwargs,
                    ),
                    **block_kwargs,
                )
            )

        self.output_layer = FinalLayer(width, out_dim)
        self.initialize_weights()

    def initialize_weights(self):
        """Re-initialize the weights.

        Every ``nn.Linear`` gets Xavier-uniform weights and zero bias. The two
        linear layers of the time embedder are then redrawn from a normal with
        std 0.02, and the adaLN projections and the output projection are
        zeroed.
        """

        def _reset_linear(module):
            if not isinstance(module, nn.Linear):
                return
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        self.apply(_reset_linear)

        for linear_index in (0, 2):
            nn.init.normal_(self.time_embedder.mlp[linear_index].weight, std=0.02)

        for block in self.blocks:
            if hasattr(block, "adaLN_modulation"):
                projection = block.adaLN_modulation[-1]
                nn.init.constant_(projection.weight, 0)
                nn.init.constant_(projection.bias, 0)

        for layer in (self.output_layer.adaLN_modulation[-1], self.output_layer.linear):
            nn.init.constant_(layer.weight, 0)
            nn.init.constant_(layer.bias, 0)

    def forward(
        self,
        x,
        timesteps,
        duration: torch.Tensor | None = None,
        mask=None,
        attn_mask=None,
        g_cond: torch.Tensor | None = None,
        **kwargs,
    ):
        """Run the conditioned transformer stack.

        ``attn_mask`` is what the blocks receive as their ``mask``; the
        ``mask`` parameter itself is accepted but not used in this method.
        ``duration`` only contributes when a duration embedder exists.
        Remaining keyword arguments are forwarded to every block and to the
        output layer.
        """
        cond = self.time_embedder(timesteps)
        duration_embedder = getattr(self, "duration_embedder", None)
        if duration_embedder is not None and duration is not None:
            cond = cond + duration_embedder(duration)
        if g_cond is not None:
            cond = cond + g_cond

        hidden = self.input_layer(x)
        for block in self.blocks:
            hidden = block(hidden, cond, mask=attn_mask, **kwargs)
        return self.output_layer(hidden, cond, **kwargs)
diff --git a/src/dots_tts/modules/backbone/layers.py b/src/dots_tts/modules/backbone/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ae1dcedce28456e9916b805e274605075a67b10
--- /dev/null
+++ b/src/dots_tts/modules/backbone/layers.py
@@ -0,0 +1,333 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+
class Dropout(nn.Module):
    """Dropout layer that can be forced on regardless of train/eval mode."""

    def __init__(
        self, p: float = 0.5, inplace: bool = False, force_drop: bool = False, **_kwargs
    ):
        super().__init__()
        if not 0.0 <= p <= 1.0:
            raise ValueError(
                f"dropout probability has to be between 0 and 1, but got {p}"
            )
        self.p = p
        self.inplace = inplace
        self.force_drop = force_drop

    def forward(self, x, **_kwargs):
        """Apply dropout when ``force_drop`` is set or the module is in training mode."""
        if self.force_drop:
            active = True
        else:
            active = self.training
        return F.dropout(x, p=self.p, training=active, inplace=self.inplace)
+
+
class Conv1d(nn.Conv1d):
    """``nn.Conv1d`` with default "same"-style padding and an optional causal mode.

    In causal mode the input is left-padded with ``dilation * (kernel_size - 1)``
    zeros in ``forward``, so an output step never depends on later input steps.

    Attributes set here:
        causal: Whether causal left-padding is applied in ``forward``.
        left_padding: Number of zeros prepended in causal mode; ``0`` for
            non-causal layers. It is always defined, so code that reads it
            (and ``forward`` itself) no longer fails when a causal layer is
            built with an explicit ``padding``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 1,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        padding_mode: str = "zeros",
        bias: bool = True,
        padding=None,
        causal: bool = False,
        **_kwargs,
    ):
        receptive_span = dilation * (kernel_size - 1)
        if padding is None:
            # Causal layers pad manually in forward(); non-causal layers
            # split the receptive span between both sides.
            padding = 0 if causal else receptive_span // 2

        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            padding_mode=padding_mode,
            bias=bias,
        )

        self.causal = causal
        self.left_padding = receptive_span if causal else 0

    def forward(self, x):
        """Convolve ``x``, prepending ``left_padding`` zeros on the last dim when causal."""
        if self.causal and self.left_padding > 0:
            x = F.pad(x, (self.left_padding, 0))
        return super().forward(x)
+
+
class ConvTranspose1d(nn.ConvTranspose1d):
    """``nn.ConvTranspose1d`` with a default padding rule and an optional causal mode.

    In causal mode the layer requires ``padding == 0`` and
    ``kernel_size == 2 * stride``, and the last ``stride`` output steps are
    dropped in ``forward``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        output_padding: int = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: int = 1,
        padding=None,
        padding_mode: str = "zeros",
        causal: bool = False,
        **_kwargs,
    ):
        if padding is None:
            if causal:
                padding = 0
            else:
                padding = (kernel_size - stride) // 2
        if causal:
            assert padding == 0, "padding is not allowed in causal ConvTranspose1d."
            assert kernel_size == 2 * stride, (
                "kernel_size must be equal to 2*stride in Causal ConvTranspose1d."
            )

        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
            dilation=dilation,
            padding_mode=padding_mode,
        )

        self.causal = causal
        # Replaces the value stored by the base constructor with the raw
        # ``stride`` argument; forward() uses it as the trim length.
        self.stride = stride

    def forward(self, x):
        """Upsample ``x``; in causal mode trim the trailing ``stride`` steps."""
        upsampled = super().forward(x)
        if not self.causal:
            return upsampled
        return upsampled[:, :, : -self.stride]
+
+
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> dropout -> Linear -> dropout."""

    def __init__(
        self,
        hidden_size,
        ffn_hidden_size=4096,
        act_layer=nn.GELU,
        dropout=0.0,
        **_kwargs,
    ):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, ffn_hidden_size)
        # ``act_layer`` is a zero-argument factory (a class or a lambda).
        self.act = act_layer()
        self.fc2 = nn.Linear(ffn_hidden_size, hidden_size)
        self.drop = Dropout(dropout)

    def forward(self, x, _mask=None, **_kwargs):
        """Apply the feed-forward network to ``x``.

        ``_mask`` and any extra keyword arguments (for example ``mask=``, which
        ``DiTBlock`` passes on its unmodulated path) are accepted and ignored,
        so the module can be called through attention-style call sites without
        raising ``TypeError``.
        """
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
+
+
def rotate_half(x):
    """Split the last dim into two chunks ``(a, b)`` and return ``(-b, a)`` concatenated."""
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat([-second, first], dim=-1)
+
+
@torch.autocast(enabled=False, device_type="cuda")
def apply_rotary_pos_emb(pos, t):
    """Rotate ``t`` by the angles in ``pos``: ``t * cos(pos) + rotate_half(t) * sin(pos)``.

    A 3-D ``pos`` gets an extra axis at dim 1 so it can broadcast against ``t``.
    CUDA autocast is disabled for the duration of the call.
    """
    if pos.dim() == 3:
        pos = pos[:, None]
    cos_term = t * pos.cos()
    sin_term = rotate_half(t) * pos.sin()
    return cos_term + sin_term
+
+
class RotaryEmbedding(nn.Module):
    """Produce rotary angle tables from position indices.

    ``inv_freq`` is a non-persistent buffer that is kept in float32 even when
    the module is cast to another dtype.
    """

    def __init__(self, dim, theta=50000):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (theta**exponents), persistent=False)
        self._theta = float(theta)

    def _apply(self, fn):
        # Let the base class move/cast everything, then restore the original
        # frequencies on the buffer's new device in float32.
        original_inv_freq = self.inv_freq
        super()._apply(fn)
        target_device = self.inv_freq.device
        self.inv_freq = original_inv_freq.to(device=target_device, dtype=torch.float32)
        return self

    @torch.autocast(enabled=False, device_type="cuda")
    def forward(self, t):
        """Return angles ``t x inv_freq`` duplicated along the last dim.

        A 1-D ``t`` gives a 2-D result; otherwise ``t`` is treated as 2-D and
        the result is 3-D.

        Raises:
            RuntimeError: If ``t`` lives on a different device than the buffer.
        """
        inv_freq = self.inv_freq
        if t.device != inv_freq.device:
            raise RuntimeError(
                "RotaryEmbedding buffer device mismatch: "
                f"inv_freq={inv_freq.device} input={t.device}."
            )
        t = t.to(dtype=inv_freq.dtype)
        equation = "i , j -> i j" if t.dim() == 1 else "bi, j -> bij"
        freqs = torch.einsum(equation, t, inv_freq)
        return torch.cat((freqs, freqs), dim=-1)
+
+
class MultiHeadAttention(nn.Module):
    """Multi-head attention with optional QK normalization and rotary embeddings.

    ``forward`` attends over full sequences; ``decode_step`` processes a block
    of new positions against a preallocated key/value cache.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        dropout: float = 0.0,
        norm_layer: str = "LayerNorm",
        rotary_bias: bool = False,
        rotary_theta: float | None = 50000,
        **_kwargs,
    ):
        super().__init__()
        assert hidden_size % num_heads == 0, (
            "hidden_size should be divisible by num_heads"
        )
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = self.head_dim**-0.5
        self.rotary_bias = rotary_bias

        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)

        # ``norm_layer`` names a class in ``torch.nn``; it is applied per head.
        norm_layer = getattr(nn, norm_layer)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()

        self.attn_drop = Dropout(attn_drop)
        self.o_proj = nn.Linear(hidden_size, hidden_size)
        self.o_dropout = Dropout(dropout)

        if self.rotary_bias:
            self.rotary = RotaryEmbedding(self.head_dim, theta=rotary_theta)

    def forward(self, q, k=None, v=None, mask=None, pos_ids=None, **_kwargs):
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q: Query input of shape ``[B, L, hidden]``.
            k: Key input; defaults to ``q`` when ``None``.
            v: Value input of shape ``[B, S, hidden]``; defaults to ``q`` when
                ``None``.
            mask: Optional boolean mask where ``True`` marks positions that may
                be attended. ``[B, L]`` (requires ``L == S``) and ``[B, L, S]``
                are expanded over heads; other ranks are used as given.
            pos_ids: Optional position indices for the rotary embedding, used
                only when ``L == S``.

        Returns:
            Tensor of shape ``[B, L, hidden]``.
        """
        # Explicit None checks: ``k or q`` would call bool() on a tensor, which
        # raises for any tensor holding more than one element.
        if k is None:
            k = q
        if v is None:
            v = q
        B, L, _ = q.shape
        _, S, _ = v.shape
        if mask is not None:
            if mask.ndim == 2:  # [B, L]
                assert L == S
                mask = rearrange(mask, "b j -> b 1 1 j")
                mask = mask.expand(-1, self.num_heads, L, -1)
            elif mask.ndim == 3:  # [B, L, S]
                assert mask.size(1) == L and mask.size(2) == S
                mask = mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)

        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
        q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
        k = rearrange(k, "b n (h d) -> b h n d", h=self.num_heads)
        v = rearrange(v, "b n (h d) -> b h n d", h=self.num_heads)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.rotary_bias:
            if L == S:
                if pos_ids is None:
                    rotary_emb = self.rotary(torch.arange(L, device=q.device))
                else:
                    rotary_emb = self.rotary(pos_ids)
                q = apply_rotary_pos_emb(rotary_emb, q)
                k = apply_rotary_pos_emb(rotary_emb, k)
            else:
                # Different lengths: each side gets its own 0..N-1 positions.
                q_rotary_emb = self.rotary(torch.arange(L, device=q.device))
                k_rotary_emb = self.rotary(torch.arange(S, device=k.device))
                q = apply_rotary_pos_emb(q_rotary_emb, q)
                k = apply_rotary_pos_emb(k_rotary_emb, k)

        # Additive bias: 0 where attention is allowed, -inf where it is masked.
        attn_bias = torch.zeros(B, self.num_heads, L, S, dtype=q.dtype, device=q.device)
        if mask is not None:
            attn_bias.masked_fill_(mask.logical_not(), float("-inf"))

        out = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_bias,
            dropout_p=self.attn_drop.p if self.training else 0.0,
        )

        out = rearrange(out, "b h n d -> b n (h d)")
        return self.o_dropout(self.o_proj(out))

    def decode_step(self, x, *, cache, positions: torch.Tensor):
        """Attend a block of new tokens against the key/value cache.

        The new keys and values are written into ``cache`` in place at
        ``positions`` (dim 2). Each query may attend to cache slots whose index
        is at most its own position and at most ``positions[-1]``.

        Args:
            x: New inputs of shape ``[B, N, hidden]`` with ``N > 0``.
            cache: ``(cached_k, cached_v)`` tensors indexed by position on dim 2.
            positions: 1-D tensor with one cache index per new token.

        Returns:
            ``(output, cache)`` where ``output`` has shape ``[B, N, hidden]``.

        Raises:
            ValueError: If ``x`` is empty or ``positions`` does not match it.
        """
        if x.size(1) <= 0:
            raise ValueError("MultiHeadAttention.decode_step expects a non-empty input.")
        if positions.ndim != 1 or positions.size(0) != x.size(1):
            raise ValueError(
                "MultiHeadAttention.decode_step positions must match the decode block length."
            )

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
        k = rearrange(k, "b n (h d) -> b h n d", h=self.num_heads)
        v = rearrange(v, "b n (h d) -> b h n d", h=self.num_heads)
        q, k = self.q_norm(q), self.k_norm(k)
        block_len = q.size(2)

        if self.rotary_bias:
            rotary_emb = self.rotary(positions)
            q = apply_rotary_pos_emb(rotary_emb, q)
            k = apply_rotary_pos_emb(rotary_emb, k)

        cached_k, cached_v = cache
        cached_k.index_copy_(2, positions, k)
        cached_v.index_copy_(2, positions, v)

        cache_capacity = cached_k.size(2)
        key_positions = torch.arange(
            cache_capacity,
            device=x.device,
            dtype=torch.long,
        ).unsqueeze(0)
        query_positions = positions.unsqueeze(1)
        causal_mask = key_positions <= query_positions
        valid_mask = key_positions <= positions[-1]
        attn_bias = torch.zeros(
            q.size(0),
            self.num_heads,
            block_len,
            cache_capacity,
            dtype=q.dtype,
            device=q.device,
        )
        attn_bias.masked_fill_(
            (causal_mask & valid_mask).unsqueeze(0).unsqueeze(0).logical_not(),
            float("-inf"),
        )

        out = F.scaled_dot_product_attention(
            q,
            cached_k,
            cached_v,
            attn_mask=attn_bias,
            dropout_p=self.attn_drop.p if self.training else 0.0,
        )
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.o_dropout(self.o_proj(out)), cache
diff --git a/src/dots_tts/modules/backbone/semantic_encoder.py b/src/dots_tts/modules/backbone/semantic_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..df9e7f226f13edf5391c8742d8aca3d3cd89e673
--- /dev/null
+++ b/src/dots_tts/modules/backbone/semantic_encoder.py
@@ -0,0 +1,356 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+from dots_tts.modules.backbone.layers import Conv1d, Mlp, MultiHeadAttention
+
+
@dataclass
class SemanticEncoderDecodeState:
    """Mutable state carried between incremental semantic-encoder calls."""

    # Trailing raw input frames kept as left context for the next
    # downsampling convolution.
    conv_tail: torch.Tensor
    # One (key_cache, value_cache) pair per encoder layer.
    layer_caches: tuple[tuple[torch.Tensor, torch.Tensor], ...]
    # Number of encoder tokens already written to the caches.
    seq_len: int
+
+
class TransformerEncoderLayer(nn.Module):
    """Pre-norm transformer layer: self-attention then a SiLU feed-forward, both residual."""

    def __init__(
        self,
        hidden_size,
        num_heads=16,
        ffn_hidden_size=4096,
        attn_dropout=0.0,
        ffn_dropout=0.0,
        norm_layer="LayerNorm",
        **kwargs,
    ):
        super().__init__()
        self.attn = MultiHeadAttention(
            hidden_size,
            num_heads,
            attn_drop=attn_dropout,
            norm_layer=norm_layer,
            **kwargs,
        )
        # ``norm_layer`` names a class in ``torch.nn``.
        norm_factory = getattr(nn, norm_layer)
        self.attn_norm = norm_factory(hidden_size)
        self.ffn = Mlp(
            hidden_size, ffn_hidden_size, dropout=ffn_dropout, act_layer=nn.SiLU
        )
        self.ffn_norm = norm_factory(hidden_size)
        self.hidden_size = hidden_size

    def _build_causal_mask(self, T: int, device):
        """Return a ``[T, T]`` boolean lower-triangular mask."""
        full = torch.ones(T, T, dtype=torch.bool, device=device)
        return torch.tril(full)

    def _build_padding_mask(self, x_lens, max_len: int, device):
        """Return a ``[B, max_len]`` mask that is True where position < length."""
        batch = x_lens.size(0)
        steps = torch.arange(max_len, device=device)[None, :].expand(batch, -1)
        return steps < x_lens[:, None]

    def _fuse_attn_mask(self, causal_mask, padding_mask):
        """Combine the optional causal and padding masks into one attention mask.

        Returns ``None`` when both are absent. A padding mask is turned into a
        pairwise mask (query valid AND key valid); a causal mask gains a
        leading batch axis.
        """
        has_causal = causal_mask is not None
        has_padding = padding_mask is not None
        if not has_causal and not has_padding:
            return None
        if not has_padding:
            return causal_mask.unsqueeze(0)
        if not has_causal:
            return padding_mask.unsqueeze(2) & padding_mask.unsqueeze(1)

        # Unpacking fails unless the padding mask is 2-D.
        _batch, _steps = padding_mask.shape
        batched_causal = causal_mask.unsqueeze(0)
        pairwise_valid = padding_mask.unsqueeze(2) & padding_mask.unsqueeze(1)
        return batched_causal & pairwise_valid

    def forward(
        self,
        x,
        x_lens=None,
        causal=True,
    ):
        """Run the layer over a full sequence ``x`` of shape ``[B, T, hidden_size]``.

        Args:
            x: Input hidden states.
            x_lens: Optional per-item lengths used to build a padding mask.
            causal: Whether to apply a lower-triangular attention mask.
        """
        _batch, steps, channels = x.shape
        assert self.hidden_size == channels
        device = x.device

        causal_mask = None
        if causal:
            causal_mask = self._build_causal_mask(steps, device)
        padding_mask = None
        if x_lens is not None:
            padding_mask = self._build_padding_mask(x_lens, steps, device)
        attn_mask = self._fuse_attn_mask(causal_mask, padding_mask)

        x = x + self.attn(q=self.attn_norm(x), mask=attn_mask)
        return x + self.ffn(self.ffn_norm(x))

    def decode_step(
        self,
        x,
        *,
        cache: tuple[torch.Tensor, torch.Tensor],
        positions: torch.Tensor,
    ):
        """Run the layer on a block of new tokens using the attention cache.

        Returns:
            ``(output, cache)`` with the cache returned by the attention module.

        Raises:
            ValueError: If ``x`` has no tokens on dim 1.
        """
        if x.size(1) <= 0:
            raise ValueError(
                "TransformerEncoderLayer.decode_step expects a non-empty input."
            )

        attn_out, cache = self.attn.decode_step(
            self.attn_norm(x), cache=cache, positions=positions
        )
        x = x + attn_out
        ffn_out = self.ffn(self.ffn_norm(x))
        return x + ffn_out, cache
+
+
class SuperviseEncoder(nn.Module):
    """Stack of ``TransformerEncoderLayer`` modules configured from a mapping-like ``config``."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.get("hidden_size", 1024)
        encoder_layers = []
        for _ in range(config.get("num_layers", 6)):
            encoder_layers.append(
                TransformerEncoderLayer(
                    hidden_size=self.hidden_size,
                    num_heads=config.get("num_heads", 16),
                    ffn_hidden_size=config.get("ffn_hidden_size", 4096),
                    norm_layer=config.get("norm_layer", "LayerNorm"),
                )
            )
        self.layers = nn.ModuleList(encoder_layers)
        self.causal = config.get("causal", False)

    def forward(self, x, x_lens=None):
        """Encode a full batch; missing ``x_lens`` means every item uses the full length."""
        batch_size, seq_len, _ = x.shape
        if x_lens is None:
            x_lens = torch.full(
                (batch_size,), seq_len, device=x.device, dtype=torch.long
            )
        hidden = x
        for layer in self.layers:
            hidden = layer(hidden, x_lens=x_lens, causal=self.causal)
        return hidden

    def init_decode_state(
        self,
        *,
        batch_size: int,
        max_seq_len: int,
        device: torch.device,
        dtype: torch.dtype,
    ):
        """Allocate zeroed ``(key, value)`` caches, one pair per layer.

        Each cache has shape ``[batch_size, num_heads, max_seq_len, head_dim]``.
        """

        def _empty_cache(layer):
            shape = (
                batch_size,
                layer.attn.num_heads,
                max_seq_len,
                layer.attn.head_dim,
            )
            return torch.zeros(shape, dtype=dtype, device=device)

        return tuple((_empty_cache(layer), _empty_cache(layer)) for layer in self.layers)

    def reset_decode_state(
        self,
        layer_caches: tuple[tuple[torch.Tensor, torch.Tensor], ...],
    ) -> None:
        """Zero every cache tensor in place.

        Raises:
            ValueError: If the number of cache pairs differs from the layer count.
        """
        if len(layer_caches) != len(self.layers):
            raise ValueError("Layer cache count does not match encoder depth.")
        for cache_pair in layer_caches:
            key_cache, value_cache = cache_pair
            key_cache.zero_()
            value_cache.zero_()

    def decode_step(self, x, *, layer_caches, positions: torch.Tensor):
        """Run one incremental step through all layers.

        The cache returned by each layer is discarded; the caches passed in are
        the ones the attention modules write to.

        Raises:
            ValueError: If the number of cache pairs differs from the layer count.
        """
        if len(layer_caches) != len(self.layers):
            raise ValueError("Layer cache count does not match encoder depth.")

        hidden = x
        for layer, cache in zip(self.layers, layer_caches, strict=True):
            hidden, _ = layer.decode_step(hidden, cache=cache, positions=positions)
        return hidden
+
+
class VAESemanticEncoder(nn.Module):
    """Encode latent frames into one embedding per patch.

    Pipeline: causal stride-2 convolution, linear input projection, transformer
    encoder, then the ``out_ds_rate`` encoder tokens of each patch are
    concatenated and projected to ``out_dim``. The module supports a full-batch
    ``forward`` as well as incremental ``prefill`` / ``decode_patch`` calls
    backed by a ``SemanticEncoderDecodeState``.
    """

    def __init__(self, in_dim, out_dim, config):
        super().__init__()
        in_ds_rate = 2
        self.patch_size = int(config.patch_size)
        self.in_ds_rate = in_ds_rate
        self.ds_proj = Conv1d(
            in_dim, in_dim, kernel_size=in_ds_rate, stride=in_ds_rate, causal=True
        )
        self.in_proj = nn.Linear(in_dim, config.PatchEncoder.hidden_size)
        self.encoder = SuperviseEncoder(config.PatchEncoder)
        # Encoder tokens per patch after the stride-2 downsampling.
        self.out_ds_rate = self.patch_size // in_ds_rate
        self.out_proj = nn.Linear(
            config.PatchEncoder.hidden_size * self.out_ds_rate, out_dim
        )

    def forward(self, x, x_lens=None):
        """Encode a full sequence in one pass (no decode state involved)."""
        x = self._downsample(x)
        x = self.in_proj(x)
        z = self.encoder(x, x_lens=x_lens)
        return self._project_embeddings(z)

    def init_decode_state(
        self,
        *,
        max_audio_patch_count: int,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> SemanticEncoderDecodeState:
        """Allocate a zeroed decode state sized for ``max_audio_patch_count`` patches."""
        return SemanticEncoderDecodeState(
            conv_tail=torch.zeros(
                (batch_size, self.ds_proj.in_channels, self.ds_proj.left_padding),
                dtype=dtype,
                device=device,
            ),
            layer_caches=self.encoder.init_decode_state(
                batch_size=batch_size,
                max_seq_len=max_audio_patch_count * self.out_ds_rate,
                device=device,
                dtype=dtype,
            ),
            seq_len=0,
        )

    def reset_decode_state(self, state: SemanticEncoderDecodeState) -> None:
        """Zero the conv tail and all caches in place and rewind ``seq_len`` to 0."""
        state.conv_tail.zero_()
        self.encoder.reset_decode_state(state.layer_caches)
        state.seq_len = 0

    def prefill(
        self,
        x,
        state: SemanticEncoderDecodeState,
    ) -> tuple[torch.Tensor, SemanticEncoderDecodeState]:
        """Encode a block of whole patches and record it in ``state``.

        The downsampling convolution uses ``state.conv_tail`` as its left
        context, the same way ``decode_patch`` does. On a fresh (zeroed) state
        this equals the zero left-padding of ``forward``; on a state that
        already holds tokens it keeps the convolution continuous with what was
        processed before, matching the position offset taken from
        ``state.seq_len``.

        Args:
            x: Rank-3 input whose dim-1 length is a multiple of ``patch_size``.
            state: Decode state; its caches, ``conv_tail`` and ``seq_len`` are
                updated in place. It is left untouched when ``x`` is empty.

        Returns:
            ``(embedding, state)`` with one embedding per patch.

        Raises:
            ValueError: On wrong rank, a length that is not a multiple of
                ``patch_size``, a batch-size mismatch with ``state``, or when
                the caches are too small.
            RuntimeError: If downsampling yields an unexpected token count.
        """
        if x.ndim != 3:
            raise ValueError(
                f"VAESemanticEncoder.prefill expects rank-3 input, got {tuple(x.shape)}."
            )
        if x.size(1) % self.patch_size != 0:
            raise ValueError(
                f"Prompt latent length {x.size(1)} must be divisible by patch_size={self.patch_size}."
            )

        if x.size(1) == 0:
            return (
                x.new_zeros((x.size(0), 0, self.out_proj.out_features)),
                state,
            )
        if state.conv_tail.size(0) != x.size(0):
            raise ValueError(
                "VAESemanticEncoder.prefill batch size does not match decode state."
            )

        # The tail is cast to the input dtype so the convolution runs in the
        # same dtype as it would on the raw input alone.
        step_inputs, new_conv_tail = self._downsample_step(
            x,
            conv_tail=state.conv_tail.to(dtype=x.dtype),
        )
        expected_token_count = (x.size(1) // self.patch_size) * self.out_ds_rate
        if step_inputs.size(1) != expected_token_count:
            raise RuntimeError(
                "Patch encoder prefill produced an unexpected token count: "
                f"expected={expected_token_count} actual={step_inputs.size(1)}."
            )

        current_seq_len = state.seq_len
        next_seq_len = current_seq_len + step_inputs.size(1)
        cache_capacity = state.layer_caches[0][0].size(2)
        if next_seq_len > cache_capacity:
            raise ValueError(
                "Patch encoder prefill exceeds decode-state capacity: "
                f"required={next_seq_len} capacity={cache_capacity}."
            )

        positions = (
            torch.arange(step_inputs.size(1), device=x.device, dtype=torch.long)
            + current_seq_len
        )
        encoded = self.encoder.decode_step(
            step_inputs,
            layer_caches=state.layer_caches,
            positions=positions,
        )
        embedding = self._project_embeddings(encoded)
        # The state is only advanced after the encoder step has succeeded.
        state.conv_tail.copy_(new_conv_tail)
        state.seq_len = next_seq_len
        return embedding, state

    def decode_patch(
        self,
        latent_patch,
        conv_tail: torch.Tensor,
        layer_caches: tuple[tuple[torch.Tensor, torch.Tensor], ...],
        positions: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode exactly one patch incrementally.

        Args:
            latent_patch: Rank-3 input with ``patch_size`` steps on dim 1.
            conv_tail: Left context for the downsampling convolution.
            layer_caches: Per-layer key/value caches, written in place.
            positions: 1-D cache indices, one per produced encoder token
                (``out_ds_rate`` entries).

        Returns:
            ``(embedding, new_conv_tail)``.

        Raises:
            ValueError: On invalid input rank, patch length or positions.
            RuntimeError: If downsampling yields an unexpected token count.
        """
        if latent_patch.ndim != 3:
            raise ValueError(
                f"VAESemanticEncoder.decode_patch expects rank-3 input, got {tuple(latent_patch.shape)}."
            )
        if latent_patch.size(1) != self.patch_size:
            raise ValueError(
                f"decode_patch expects patch length {self.patch_size}, got {latent_patch.size(1)}."
            )
        if positions.ndim != 1 or positions.size(0) != self.out_ds_rate:
            raise ValueError(
                "decode_patch positions must be a rank-1 tensor matching out_ds_rate."
            )

        step_inputs, conv_tail = self._downsample_step(
            latent_patch,
            conv_tail=conv_tail,
        )
        if step_inputs.size(1) != self.out_ds_rate:
            raise RuntimeError(
                f"Downsample step produced {step_inputs.size(1)} tokens, expected {self.out_ds_rate}."
            )

        encoded = self.encoder.decode_step(
            step_inputs,
            layer_caches=layer_caches,
            positions=positions,
        )
        embedding = self._project_embeddings(encoded)
        return embedding, conv_tail

    def _downsample(self, x):
        """Apply the causal strided convolution along dim 1 of ``x``."""
        return self.ds_proj(x.transpose(1, 2)).transpose(1, 2)

    def _project_embeddings(self, z):
        """Concatenate each group of ``out_ds_rate`` tokens and project to the output size."""
        if self.out_ds_rate > 1:
            z = rearrange(z, "b (s d) h -> b s (d h)", d=self.out_ds_rate)
        return self.out_proj(z)

    def _downsample_step(self, latent_patch, *, conv_tail):
        """Downsample new frames using ``conv_tail`` as explicit left context.

        Returns:
            ``(projected_tokens, new_conv_tail)`` where the new tail holds the
            last ``ds_proj.left_padding`` raw frames of ``latent_patch``.
        """
        raw = latent_patch.transpose(1, 2)
        conv_input = torch.cat([conv_tail, raw], dim=-1)

        # Same weights as ds_proj, but with no padding: the tail supplies the
        # left context instead.
        projected = F.conv1d(
            conv_input,
            self.ds_proj.weight,
            self.ds_proj.bias,
            stride=self.ds_proj.stride[0],
            padding=0,
            dilation=self.ds_proj.dilation[0],
            groups=self.ds_proj.groups,
        ).transpose(1, 2)
        new_conv_tail = raw[..., -self.ds_proj.left_padding :]
        return self.in_proj(projected), new_conv_tail
diff --git a/src/dots_tts/modules/speaker/__init__.py b/src/dots_tts/modules/speaker/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5ad248684adb71ec62e118df7d5f6ff367e9db5
--- /dev/null
+++ b/src/dots_tts/modules/speaker/__init__.py
@@ -0,0 +1 @@
+"""Speaker modules."""
diff --git a/src/dots_tts/modules/speaker/campplus.py b/src/dots_tts/modules/speaker/campplus.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc32030c09cf1c9115132483cd075024503c0a7d
--- /dev/null
+++ b/src/dots_tts/modules/speaker/campplus.py
@@ -0,0 +1,200 @@
+# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+from collections import OrderedDict
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from dots_tts.modules.speaker.campplus_layers import (
+ BasicResBlock,
+ CAMDenseTDNNBlock,
+ DenseLayer,
+ StatsPool,
+ TDNNLayer,
+ TransitLayer,
+ get_nonlinear,
+)
+from dots_tts.modules.speaker.fbank import _SPEAKER_FBANK_N_MELS
+
+
class FCM(nn.Module):
    """2-D convolutional front end whose output merges the channel and feature axes.

    ``out_channels`` is ``m_channels * (feat_dim // 8)``.
    """

    def __init__(
        self,
        block=BasicResBlock,
        num_blocks=(2, 2),
        m_channels=32,
        feat_dim=_SPEAKER_FBANK_N_MELS,
    ):
        super().__init__()
        self.in_planes = m_channels
        self.conv1 = nn.Conv2d(
            1, m_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(m_channels)

        first_count, second_count = num_blocks[0], num_blocks[1]
        self.layer1 = self._make_layer(block, m_channels, first_count, stride=2)
        self.layer2 = self._make_layer(block, m_channels, second_count, stride=2)

        self.conv2 = nn.Conv2d(
            m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(m_channels)
        self.out_channels = m_channels * (feat_dim // 8)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build a stage of blocks; only the first block uses ``stride``, the rest use 1."""
        stage = []
        for block_stride in [stride, *([1] * (num_blocks - 1))]:
            stage.append(block(self.in_planes, planes, block_stride))
            # The next block consumes what this one emits.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Add a channel axis, run the conv stack and merge dims 1 and 2 of the result."""
        out = x.unsqueeze(1)
        out = F.relu(self.bn1(self.conv1(out)))
        out = self.layer2(self.layer1(out))
        out = F.relu(self.bn2(self.conv2(out)))

        batch, channels, feats, frames = out.shape
        return out.reshape(batch, channels * feats, frames)
+
+
class CAMPPlus(nn.Module):
    """Speaker embedding network: FCM front end followed by a dense TDNN x-vector stack.

    When ``lengths`` is passed to ``forward``, the statistics pooling only
    covers the valid frames of each item.
    """

    # Geometry of the first TDNN layer; also used to track valid lengths.
    _TDNN_KERNEL_SIZE = 5
    _TDNN_STRIDE = 2
    _TDNN_PADDING = 2

    def __init__(
        self,
        feat_dim=_SPEAKER_FBANK_N_MELS,
        embedding_size=512,
        growth_rate=32,
        bn_size=4,
        init_channels=128,
        config_str="batchnorm-relu",
        memory_efficient=True,
    ):
        super().__init__()

        self.head = FCM(feat_dim=feat_dim)
        channels = self.head.out_channels

        entry_tdnn = TDNNLayer(
            channels,
            init_channels,
            self._TDNN_KERNEL_SIZE,
            stride=self._TDNN_STRIDE,
            dilation=1,
            padding=-1,
            config_str=config_str,
        )
        self.xvector = nn.Sequential(OrderedDict([("tdnn", entry_tdnn)]))
        channels = init_channels

        # (num_layers, kernel_size, dilation) for each dense block.
        block_specs = ((12, 3, 1), (24, 3, 2), (16, 3, 2))
        for stage, (num_layers, kernel_size, dilation) in enumerate(block_specs, start=1):
            dense_block = CAMDenseTDNNBlock(
                num_layers=num_layers,
                in_channels=channels,
                out_channels=growth_rate,
                bn_channels=bn_size * growth_rate,
                kernel_size=kernel_size,
                dilation=dilation,
                config_str=config_str,
                memory_efficient=memory_efficient,
            )
            self.xvector.add_module(f"block{stage}", dense_block)
            channels = channels + num_layers * growth_rate
            # Each transit layer halves the channel count.
            self.xvector.add_module(
                f"transit{stage}",
                TransitLayer(
                    channels, channels // 2, bias=False, config_str=config_str
                ),
            )
            channels //= 2

        self.xvector.add_module("out_nonlinear", get_nonlinear(config_str, channels))

        self.xvector.add_module("stats", StatsPool())
        self.xvector.add_module(
            "dense", DenseLayer(channels * 2, embedding_size, config_str="batchnorm_")
        )

        for module in self.modules():
            if not isinstance(module, (nn.Conv1d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(module.weight.data)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    @staticmethod
    def _conv_output_lengths(lengths, kernel_size, stride=1, padding=0, dilation=1):
        """Return ``floor((L + 2p - d(k - 1) - 1) / s) + 1`` for every length ``L``."""
        effective = lengths + 2 * padding - dilation * (kernel_size - 1) - 1
        return torch.div(effective, stride, rounding_mode="floor") + 1

    @staticmethod
    def _make_length_mask(lengths, max_len, device):
        """Return a ``[B, max_len]`` boolean mask, True where index < clamped length."""
        clamped = lengths.to(device=device, dtype=torch.long).clamp(min=0, max=max_len)
        frame_index = torch.arange(max_len, device=device).unsqueeze(0)
        return frame_index < clamped.unsqueeze(1)

    def _masked_stats_pooling(self, x, lengths, unbiased=True, eps=1e-2):
        """Mean and standard deviation over the valid frames of each item.

        Lengths are clamped to ``[1, x.size(-1)]``. The variance is clamped
        from below by ``eps`` before the square root. Returns
        ``cat([mean, std], dim=1)``.
        """
        num_frames = x.size(-1)
        lengths = lengths.to(device=x.device, dtype=torch.long).clamp(
            min=1, max=num_frames
        )
        mask = self._make_length_mask(lengths, num_frames, x.device)
        mask = mask.unsqueeze(1).to(dtype=x.dtype)

        count = lengths.to(dtype=x.dtype).view(-1, 1).clamp_min(1.0)
        mean = (x * mask).sum(dim=-1) / count

        centered = (x - mean.unsqueeze(-1)) * mask
        if unbiased:
            var_count = (lengths - 1).clamp_min(1).to(dtype=x.dtype).view(-1, 1)
        else:
            var_count = count
        variance = centered.pow(2).sum(dim=-1) / var_count
        std = torch.sqrt(variance.clamp_min(eps))
        return torch.cat([mean, std], dim=1)

    def forward(self, x, lengths=None):
        """Compute embeddings for features ``x`` of shape (B, T, F).

        Args:
            x: Input features; permuted to (B, F, T) before the front end.
            lengths: Optional valid frame counts per item. They are clamped to
                at least 1, mapped through the first TDNN layer's geometry and
                then used for masked statistics pooling.
        """
        x = self.head(x.permute(0, 2, 1))
        if lengths is not None:
            lengths = lengths.to(device=x.device, dtype=torch.long).clamp(min=1)

        for name, module in self.xvector.named_children():
            if name == "stats":
                if lengths is None:
                    x = module(x)
                else:
                    x = self._masked_stats_pooling(x, lengths)
                continue

            x = module(x)
            if name == "tdnn" and lengths is not None:
                lengths = self._conv_output_lengths(
                    lengths,
                    kernel_size=self._TDNN_KERNEL_SIZE,
                    stride=self._TDNN_STRIDE,
                    padding=self._TDNN_PADDING,
                )

        return x
diff --git a/src/dots_tts/modules/speaker/campplus_layers.py b/src/dots_tts/modules/speaker/campplus_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..9189009bc8e8e68a6cf61a5e11a4977cdaf15e9f
--- /dev/null
+++ b/src/dots_tts/modules/speaker/campplus_layers.py
@@ -0,0 +1,258 @@
+# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from torch import nn
+
+
+ def get_nonlinear(config_str, channels):
+     """Build the ``nn.Sequential`` described by a dash-separated spec.
+
+     Recognised tokens are ``relu``, ``prelu``, ``batchnorm`` and
+     ``batchnorm_`` (batch norm without affine parameters). Both batch-norm
+     tokens register their module under the key ``"batchnorm"``.
+
+     Raises:
+         ValueError: If a token is not one of the names above.
+     """
+     builders = {
+         "relu": ("relu", lambda: nn.ReLU(inplace=True)),
+         "prelu": ("prelu", lambda: nn.PReLU(channels)),
+         "batchnorm": ("batchnorm", lambda: nn.BatchNorm1d(channels)),
+         "batchnorm_": ("batchnorm", lambda: nn.BatchNorm1d(channels, affine=False)),
+     }
+     stack = nn.Sequential()
+     for token in config_str.split("-"):
+         if token not in builders:
+             raise ValueError(f"Unexpected module ({token}).")
+         key, build = builders[token]
+         stack.add_module(key, build())
+     return stack
+
+
+ def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, _eps=1e-2):
+     """Concatenate the mean and standard deviation of ``x`` taken along ``dim``.
+
+     The two statistics are joined on the last axis of the reduced tensor.
+     With ``keepdim`` a singleton axis is re-inserted at ``dim``. ``_eps`` is
+     accepted but not used.
+     """
+     moments = [x.mean(dim=dim), x.std(dim=dim, unbiased=unbiased)]
+     pooled = torch.cat(moments, dim=-1)
+     return pooled.unsqueeze(dim=dim) if keepdim else pooled
+
+
+ class StatsPool(nn.Module):
+     """Module wrapper that applies ``statistics_pooling`` with its defaults."""
+
+     def forward(self, x):
+         # Defaults: reduce the last axis, unbiased std, no kept dimension.
+         return statistics_pooling(x)
+
+
+ class TDNNLayer(nn.Module):
+     """``nn.Conv1d`` followed by the stack built by ``get_nonlinear``.
+
+     A negative ``padding`` requests "same"-style padding derived from the
+     kernel size and dilation; this requires an odd ``kernel_size``.
+     """
+
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size,
+         stride=1,
+         padding=0,
+         dilation=1,
+         bias=False,
+         config_str="batchnorm-relu",
+     ):
+         super().__init__()
+         if padding < 0:
+             assert kernel_size % 2 == 1, (
+                 f"Expect equal paddings, but got even kernel size ({kernel_size})"
+             )
+             padding = (kernel_size - 1) // 2 * dilation
+         conv_options = {
+             "stride": stride,
+             "padding": padding,
+             "dilation": dilation,
+             "bias": bias,
+         }
+         self.linear = nn.Conv1d(in_channels, out_channels, kernel_size, **conv_options)
+         self.nonlinear = get_nonlinear(config_str, out_channels)
+
+     def forward(self, x):
+         return self.nonlinear(self.linear(x))
+
+
+ class CAMLayer(nn.Module):
+     """Convolution whose output is gated by a sigmoid mask built from context.
+
+     The context is the sum of the mean over the last axis and a segment-wise
+     pooling of the input. It passes through two 1x1 convolutions (ReLU in
+     between) and a sigmoid, and the result multiplies the local convolution
+     output.
+     """
+
+     def __init__(
+         self,
+         bn_channels,
+         out_channels,
+         kernel_size,
+         stride,
+         padding,
+         dilation,
+         bias,
+         reduction=2,
+     ):
+         super().__init__()
+         squeezed_channels = bn_channels // reduction
+         self.linear_local = nn.Conv1d(
+             bn_channels,
+             out_channels,
+             kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             bias=bias,
+         )
+         self.linear1 = nn.Conv1d(bn_channels, squeezed_channels, 1)
+         self.relu = nn.ReLU(inplace=True)
+         self.linear2 = nn.Conv1d(squeezed_channels, out_channels, 1)
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, x):
+         local = self.linear_local(x)
+         context = x.mean(-1, keepdim=True) + self.seg_pooling(x)
+         hidden = self.relu(self.linear1(context))
+         gate = self.sigmoid(self.linear2(hidden))
+         return local * gate
+
+     def seg_pooling(self, x, seg_len=100, stype="avg"):
+         """Pool ``x`` in windows of ``seg_len`` and stretch back to its length.
+
+         Each window's pooled value is repeated ``seg_len`` times and the result
+         is cut to ``x.shape[-1]``.
+
+         Raises:
+             ValueError: If ``stype`` is neither ``"avg"`` nor ``"max"``.
+         """
+         pool_fns = {"avg": F.avg_pool1d, "max": F.max_pool1d}
+         if stype not in pool_fns:
+             raise ValueError("Wrong segment pooling type.")
+         pooled = pool_fns[stype](
+             x, kernel_size=seg_len, stride=seg_len, ceil_mode=True
+         )
+         stretched = pooled.repeat_interleave(seg_len, dim=-1)
+         return stretched[..., : x.shape[-1]]
+
+
+ class CAMDenseTDNNLayer(nn.Module):
+     """Bottleneck (nonlinearity + 1x1 conv) followed by a ``CAMLayer``.
+
+     With ``memory_efficient`` set, the bottleneck runs under
+     ``torch.utils.checkpoint`` while the module is in training mode.
+     """
+
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         bn_channels,
+         kernel_size,
+         stride=1,
+         dilation=1,
+         bias=False,
+         config_str="batchnorm-relu",
+         memory_efficient=False,
+     ):
+         super().__init__()
+         assert kernel_size % 2 == 1, (
+             f"Expect equal paddings, but got even kernel size ({kernel_size})"
+         )
+         same_padding = (kernel_size - 1) // 2 * dilation
+         self.memory_efficient = memory_efficient
+         self.nonlinear1 = get_nonlinear(config_str, in_channels)
+         self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
+         self.nonlinear2 = get_nonlinear(config_str, bn_channels)
+         self.cam_layer = CAMLayer(
+             bn_channels,
+             out_channels,
+             kernel_size,
+             stride=stride,
+             padding=same_padding,
+             dilation=dilation,
+             bias=bias,
+         )
+
+     def bn_function(self, x):
+         """Apply the first nonlinearity and the 1x1 bottleneck convolution."""
+         activated = self.nonlinear1(x)
+         return self.linear1(activated)
+
+     def forward(self, x):
+         use_checkpoint = self.training and self.memory_efficient
+         if use_checkpoint:
+             bottleneck = cp.checkpoint(self.bn_function, x)
+         else:
+             bottleneck = self.bn_function(x)
+         return self.cam_layer(self.nonlinear2(bottleneck))
+
+
+ class CAMDenseTDNNBlock(nn.ModuleList):
+     """Densely connected chain of ``CAMDenseTDNNLayer`` modules.
+
+     Layer ``i`` (1-based, registered as ``tdnnd{i}``) takes
+     ``in_channels + (i - 1) * out_channels`` input channels, because every
+     layer's output is concatenated onto the running feature map.
+     """
+
+     def __init__(
+         self,
+         num_layers,
+         in_channels,
+         out_channels,
+         bn_channels,
+         kernel_size,
+         stride=1,
+         dilation=1,
+         bias=False,
+         config_str="batchnorm-relu",
+         memory_efficient=False,
+     ):
+         super().__init__()
+         for index in range(num_layers):
+             self.add_module(
+                 f"tdnnd{index + 1}",
+                 CAMDenseTDNNLayer(
+                     in_channels=in_channels + index * out_channels,
+                     out_channels=out_channels,
+                     bn_channels=bn_channels,
+                     kernel_size=kernel_size,
+                     stride=stride,
+                     dilation=dilation,
+                     bias=bias,
+                     config_str=config_str,
+                     memory_efficient=memory_efficient,
+                 ),
+             )
+
+     def forward(self, x):
+         features = x
+         for layer in self:
+             # Append the new features along the channel axis.
+             features = torch.cat([features, layer(features)], dim=1)
+         return features
+
+
+ class TransitLayer(nn.Module):
+     """Nonlinearity stack followed by a 1x1 convolution that changes the channel count."""
+
+     def __init__(
+         self, in_channels, out_channels, bias=True, config_str="batchnorm-relu"
+     ):
+         super().__init__()
+         self.nonlinear = get_nonlinear(config_str, in_channels)
+         self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
+
+     def forward(self, x):
+         return self.linear(self.nonlinear(x))
+
+
+ class DenseLayer(nn.Module):
+     """1x1 convolution followed by a nonlinearity stack.
+
+     A 2-D input gets a temporary trailing axis so it can pass through the
+     ``nn.Conv1d``; that axis is removed again before the nonlinearity.
+     """
+
+     def __init__(
+         self, in_channels, out_channels, bias=False, config_str="batchnorm-relu"
+     ):
+         super().__init__()
+         self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
+         self.nonlinear = get_nonlinear(config_str, out_channels)
+
+     def forward(self, x):
+         if x.dim() != 2:
+             projected = self.linear(x)
+         else:
+             projected = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
+         return self.nonlinear(projected)
+
+
+ class BasicResBlock(nn.Module):
+     """Two 3x3 Conv2d + BatchNorm2d layers with a residual shortcut.
+
+     ``stride`` is applied only on the first spatial axis (``(stride, 1)``).
+     The shortcut is the identity unless the stride or channel count changes,
+     in which case it is a 1x1 convolution followed by batch norm.
+     """
+
+     expansion = 1
+
+     def __init__(self, in_planes, planes, stride=1):
+         super().__init__()
+         out_planes = self.expansion * planes
+         self.conv1 = nn.Conv2d(
+             in_planes, planes, kernel_size=3, stride=(stride, 1), padding=1, bias=False
+         )
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.conv2 = nn.Conv2d(
+             planes, planes, kernel_size=3, stride=1, padding=1, bias=False
+         )
+         self.bn2 = nn.BatchNorm2d(planes)
+
+         projection = []
+         if stride != 1 or in_planes != out_planes:
+             projection = [
+                 nn.Conv2d(
+                     in_planes,
+                     out_planes,
+                     kernel_size=1,
+                     stride=(stride, 1),
+                     bias=False,
+                 ),
+                 nn.BatchNorm2d(out_planes),
+             ]
+         self.shortcut = nn.Sequential(*projection)
+
+     def forward(self, x):
+         hidden = F.relu(self.bn1(self.conv1(x)))
+         hidden = self.bn2(self.conv2(hidden))
+         hidden += self.shortcut(x)
+         return F.relu(hidden)
diff --git a/src/dots_tts/modules/speaker/encoder.py b/src/dots_tts/modules/speaker/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1f2e0e55bd5900705255ce782770279500cc21d
--- /dev/null
+++ b/src/dots_tts/modules/speaker/encoder.py
@@ -0,0 +1,226 @@
+import math
+import random
+
+import torch
+import torch.nn as nn
+import torchaudio
+from torch.nn.utils.rnn import pad_sequence
+
+from dots_tts.modules.speaker.campplus import CAMPPlus
+from dots_tts.modules.speaker.fbank import (
+ _SPEAKER_FBANK_N_MELS,
+ _SPEAKER_FBANK_SAMPLE_RATE,
+ extract_speaker_fbank,
+)
+
+
+ class SpeakerXVectorFeatures(nn.Module):
+     """
+     Speaker embedding extractor based on 3D-Speaker CAM++.
+
+     Wraps a ``CAMPPlus`` model whose parameters are frozen
+     (``requires_grad = False``). ``forward`` optionally crops each utterance to
+     at most ``max_audio_seconds`` at a random offset, obtains filter-bank
+     features (computed here or supplied by the caller) and returns the model
+     output for them.
+     """
+
+     def __init__(
+         self,
+         sample_rate=_SPEAKER_FBANK_SAMPLE_RATE,
+         campplus_embedding_size=512,
+         max_audio_seconds=10.0,
+     ):
+         """
+         Args:
+             sample_rate: Sample rate of the audio passed to ``forward``. A
+                 resampler to ``_SPEAKER_FBANK_SAMPLE_RATE`` is created when it
+                 differs from that rate.
+             campplus_embedding_size: ``embedding_size`` passed to ``CAMPPlus``.
+             max_audio_seconds: Maximum crop length in seconds; a value
+                 ``<= 0`` disables cropping (see ``_crop_audio``).
+         """
+         super().__init__()
+
+         self.sample_rate = sample_rate
+         self.max_audio_seconds = float(max_audio_seconds)
+         self.model = CAMPPlus(
+             feat_dim=_SPEAKER_FBANK_N_MELS,
+             embedding_size=campplus_embedding_size,
+         )
+         self.resample = None
+         if self.sample_rate != _SPEAKER_FBANK_SAMPLE_RATE:
+             self.resample = torchaudio.transforms.Resample(
+                 orig_freq=sample_rate,
+                 new_freq=_SPEAKER_FBANK_SAMPLE_RATE,
+             )
+
+         # The speaker model is used as a fixed feature extractor.
+         for param in self.model.parameters():
+             param.requires_grad = False
+
+     @staticmethod
+     def _normalize_lengths(lengths, batch_size, max_length, device, *, min_length):
+         """Return per-item lengths as a long tensor on ``device``.
+
+         ``None`` means every item spans ``max_length``; otherwise the given
+         lengths are clamped to ``[min_length, max_length]``.
+         """
+         if lengths is None:
+             return torch.full(
+                 (batch_size,),
+                 max_length,
+                 device=device,
+                 dtype=torch.long,
+             )
+         return lengths.to(device=device, dtype=torch.long).clamp(
+             min=min_length,
+             max=max_length,
+         )
+
+     def _crop_audio(self, audio, audio_lengths=None):
+         """Crop each waveform to at most ``max_audio_seconds`` at a random start.
+
+         Returns:
+             A 4-tuple ``(audio, original_lengths, cropped_lengths, starts)``.
+             When cropping is disabled the input audio is returned unchanged,
+             ``cropped_lengths`` equals ``original_lengths`` and ``starts`` is
+             all zeros. Otherwise the crops are re-padded with zeros to the
+             longest crop in the batch.
+         """
+         original_lengths = self._normalize_lengths(
+             audio_lengths,
+             audio.size(0),
+             audio.size(-1),
+             audio.device,
+             min_length=0,
+         )
+         if self.max_audio_seconds <= 0:
+             return audio, original_lengths, original_lengths, torch.zeros_like(
+                 original_lengths
+             )
+
+         max_input_length = round(self.sample_rate * self.max_audio_seconds)
+         cropped_audio = []
+         cropped_lengths = []
+         starts = []
+
+         for index, total_length_tensor in enumerate(original_lengths):
+             total_length = int(total_length_tensor.item())
+             cropped_length = min(total_length, max_input_length)
+             # Uses the global `random` module state, so the chosen window is
+             # not reproducible unless the caller seeds `random`.
+             start = (
+                 random.randint(0, total_length - cropped_length)
+                 if total_length > cropped_length
+                 else 0
+             )
+             cropped_audio.append(audio[index, start : start + cropped_length])
+             cropped_lengths.append(cropped_length)
+             starts.append(start)
+
+         return pad_sequence(
+             cropped_audio,
+             batch_first=True,
+             padding_value=0.0,
+         ), original_lengths, torch.tensor(
+             cropped_lengths,
+             device=audio.device,
+             dtype=torch.long,
+         ), torch.tensor(starts, device=audio.device, dtype=torch.long)
+
+     def _crop_fbank(
+         self,
+         fbank,
+         fbank_lengths,
+         original_audio_lengths,
+         cropped_audio_lengths,
+         starts,
+     ):
+         """Cut caller-supplied features to the window chosen by ``_crop_audio``.
+
+         The audio window ``[start, start + cropped_length)`` is mapped to
+         feature frames proportionally (``frame = sample * total_frames /
+         total_samples``), rounding the start down and the end up. At least one
+         frame is always kept.
+
+         Returns:
+             ``(fbank, fbank_lengths)`` with the crops zero-padded to a common
+             length.
+         """
+         original_fbank_lengths = self._normalize_lengths(
+             fbank_lengths,
+             fbank.size(0),
+             fbank.size(1),
+             fbank.device,
+             min_length=1,
+         )
+         cropped_fbank = []
+         cropped_fbank_lengths = []
+
+         for index, total_feat_length_tensor in enumerate(original_fbank_lengths):
+             total_audio_length = int(original_audio_lengths[index].item())
+             total_feat_length = int(total_feat_length_tensor.item())
+             start_audio = int(starts[index].item())
+             end_audio = start_audio + int(cropped_audio_lengths[index].item())
+
+             if total_audio_length > 0:
+                 start_feat = math.floor(
+                     start_audio * total_feat_length / total_audio_length
+                 )
+                 end_feat = math.ceil(end_audio * total_feat_length / total_audio_length)
+             else:
+                 # Empty audio: fall back to the first feature frame.
+                 start_feat = 0
+                 end_feat = 1
+
+             # Keep the window inside the features and at least one frame long.
+             start_feat = min(start_feat, total_feat_length - 1)
+             end_feat = min(max(end_feat, start_feat + 1), total_feat_length)
+             cropped_fbank.append(fbank[index, start_feat:end_feat])
+             cropped_fbank_lengths.append(end_feat - start_feat)
+
+         return pad_sequence(
+             cropped_fbank,
+             batch_first=True,
+             padding_value=0.0,
+         ), torch.tensor(
+             cropped_fbank_lengths,
+             device=fbank.device,
+             dtype=torch.long,
+         )
+
+     def _extract_fbank_batch(self, audio, audio_lengths):
+         """Compute filter-bank features for every item of a padded batch.
+
+         The batch is resampled first when a resampler was configured, with
+         the lengths scaled by the rate ratio and rounded up. Features are
+         extracted one item at a time on CPU from the valid part of the
+         waveform only.
+
+         Returns:
+             ``(fbank, fbank_lengths)`` on ``audio``'s device; ``fbank`` is
+             zero-padded along dim 1 and cast to ``audio``'s dtype.
+         """
+         if self.resample is not None:
+             audio = self.resample(audio)
+             audio_lengths = torch.ceil(
+                 audio_lengths.float()
+                 * (_SPEAKER_FBANK_SAMPLE_RATE / self.sample_rate)
+             ).long()
+
+         audio_cpu = audio.detach().cpu()
+         features = []
+
+         for index, valid_length_tensor in enumerate(audio_lengths):
+             valid_length = int(valid_length_tensor.item())
+             waveform = audio_cpu[index, :valid_length]
+             if waveform.numel() == 0:
+                 # Substitute a single zero sample for an empty waveform.
+                 waveform = audio_cpu.new_zeros(1)
+             features.append(
+                 extract_speaker_fbank(
+                     waveform,
+                     sample_rate=_SPEAKER_FBANK_SAMPLE_RATE,
+                 )
+             )
+
+         fbank_lengths = torch.tensor(
+             [feature.size(0) for feature in features],
+             device=audio.device,
+             dtype=torch.long,
+         )
+         fbank = pad_sequence(
+             features,
+             batch_first=True,
+             padding_value=0.0,
+         ).to(device=audio.device, dtype=audio.dtype)
+         return fbank, fbank_lengths
+
+     @torch.no_grad()
+     @torch.autocast(enabled=False, device_type="cuda")
+     def forward(
+         self, audio, audio_lengths=None, fbank=None, fbank_lengths=None, **_kwargs
+     ):
+         """Return the CAM++ output for a batch of utterances.
+
+         Runs without gradients, with CUDA autocast disabled and with the
+         wrapped model forced into eval mode.
+
+         Args:
+             audio: 2-D tensor, or 3-D tensor whose dim 1 has size 1 (mono).
+             audio_lengths: Optional valid sample counts per item.
+             fbank: Optional precomputed 3-D features with the same batch size
+                 as ``audio``; cropped to match the audio window when given.
+             fbank_lengths: Optional valid frame counts for ``fbank``.
+             **_kwargs: Accepted and ignored.
+
+         Raises:
+             ValueError: For non-mono or wrongly shaped ``audio`` / ``fbank``.
+             TypeError: If ``fbank`` is given but is not a tensor.
+         """
+         self.model.eval()
+         audio = audio.float()
+         if audio.dim() == 3:
+             if audio.size(1) != 1:
+                 raise ValueError(
+                     f"Speaker encoder expects mono audio, got shape {tuple(audio.shape)}."
+                 )
+             audio = audio[:, 0]
+         elif audio.dim() != 2:
+             raise ValueError(
+                 f"Speaker encoder expects a 2D or 3D audio tensor, got shape {tuple(audio.shape)}."
+             )
+
+         # The crop is drawn even when `fbank` is supplied, so the features
+         # can be cut to the same window.
+         audio, original_audio_lengths, cropped_audio_lengths, starts = self._crop_audio(
+             audio,
+             audio_lengths=audio_lengths,
+         )
+
+         if fbank is None:
+             fbank, fbank_lengths = self._extract_fbank_batch(
+                 audio,
+                 cropped_audio_lengths,
+             )
+         else:
+             if not isinstance(fbank, torch.Tensor):
+                 raise TypeError("Speaker encoder expects `fbank` to be a torch.Tensor.")
+             if fbank.dim() != 3 or fbank.size(0) != audio.size(0):
+                 raise ValueError(
+                     f"Speaker encoder expects `fbank` with shape (B, T, F) and matching batch size, got {tuple(fbank.shape)}."
+                 )
+             fbank, fbank_lengths = self._crop_fbank(
+                 fbank.to(device=audio.device, dtype=torch.float32),
+                 fbank_lengths,
+                 original_audio_lengths,
+                 cropped_audio_lengths,
+                 starts,
+             )
+
+         return self.model(fbank, lengths=fbank_lengths)
diff --git a/src/dots_tts/modules/speaker/fbank.py b/src/dots_tts/modules/speaker/fbank.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d416e9d6b34c6d471c73d712709652fe0a9a038
--- /dev/null
+++ b/src/dots_tts/modules/speaker/fbank.py
@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+import torch
+
+from dots_tts.utils.audio import extract_fbank, high_quality_resample
+
+_SPEAKER_FBANK_SAMPLE_RATE = 16000
+_SPEAKER_FBANK_N_MELS = 80
+_SPEAKER_FBANK_MEAN_NORM = True
+_SPEAKER_FBANK_DITHER = 0.0
+
+
+ def extract_speaker_fbank(
+     waveform: torch.Tensor,
+     *,
+     sample_rate: int,
+ ) -> torch.Tensor:
+     """Compute the filter-bank features used by the speaker encoder.
+
+     The waveform is resampled to ``_SPEAKER_FBANK_SAMPLE_RATE`` when
+     ``sample_rate`` differs, then passed to ``extract_fbank`` with the
+     module-level mel count, dither and mean-normalisation settings.
+     """
+     if sample_rate == _SPEAKER_FBANK_SAMPLE_RATE:
+         feature_input = waveform
+     else:
+         feature_input = high_quality_resample(
+             waveform,
+             orig_sr=sample_rate,
+             target_sr=_SPEAKER_FBANK_SAMPLE_RATE,
+         )
+     fbank_options = {
+         "sample_rate": _SPEAKER_FBANK_SAMPLE_RATE,
+         "n_mels": _SPEAKER_FBANK_N_MELS,
+         "dither": _SPEAKER_FBANK_DITHER,
+         "mean_norm": _SPEAKER_FBANK_MEAN_NORM,
+     }
+     return extract_fbank(feature_input, **fbank_options)
diff --git a/src/dots_tts/modules/vocoder/__init__.py b/src/dots_tts/modules/vocoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f16ef8bd262da6f4d8d4715162bc600d8b140ef7
--- /dev/null
+++ b/src/dots_tts/modules/vocoder/__init__.py
@@ -0,0 +1 @@
+"""Vocoder modules."""
diff --git a/src/dots_tts/modules/vocoder/alias_free_act.py b/src/dots_tts/modules/vocoder/alias_free_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..948d2c69eaaddb9cacc7cf075ca00020f368b8d2
--- /dev/null
+++ b/src/dots_tts/modules/vocoder/alias_free_act.py
@@ -0,0 +1,163 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+import torch
+import torch.nn as nn
+from torch import pow, sin
+from torch.nn import Parameter
+
+from .alias_free_resample import DownSample1d, UpSample1d
+
+
+ class Activation1d(nn.Module):
+     """Apply an activation between an up-sampling and a down-sampling filter.
+
+     The channel count for both resamplers is read from
+     ``activation.in_features``.
+     """
+
+     def __init__(
+         self,
+         activation,
+         up_ratio: int = 2,
+         down_ratio: int = 2,
+         up_kernel_size: int = 12,
+         down_kernel_size: int = 12,
+         causal=True,
+         fixed_filter=False,
+     ):
+         super().__init__()
+         self.up_ratio = up_ratio
+         self.down_ratio = down_ratio
+         self.act = activation
+         channels = activation.in_features
+         shared = {"causal": causal, "fixed_filter": fixed_filter}
+         self.upsample = UpSample1d(up_ratio, up_kernel_size, channels, **shared)
+         self.downsample = DownSample1d(
+             down_ratio, down_kernel_size, channels, **shared
+         )
+
+     # x: [B,C,T]
+     def forward(self, x):
+         return self.downsample(self.act(self.upsample(x)))
+
+
+ class Snake(nn.Module):
+     """
+     Sine-based periodic activation with one frequency parameter per channel.
+     Shape:
+         - Input: (B, C, T)
+         - Output: (B, C, T), same shape as the input
+     Parameters:
+         - alpha - per-channel parameter, trainable unless disabled
+     References:
+         - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+         https://arxiv.org/abs/2006.08195
+     Examples:
+         >>> act = Snake(256)
+         >>> x = torch.randn(4, 256, 100)
+         >>> x = act(x)
+     """
+
+     def __init__(
+         self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
+     ):
+         """
+         Initialization.
+         INPUT:
+             - in_features: number of channels (length of ``alpha``)
+             - alpha: initial scale; higher values = higher-frequency.
+             - alpha_trainable: whether ``alpha`` receives gradients
+             - alpha_logscale: store ``alpha`` in log scale. The stored value
+               then starts at zero regardless of ``alpha`` (``zeros * alpha``).
+         """
+         super().__init__()
+         self.in_features = in_features
+         self.alpha_logscale = alpha_logscale
+
+         # Log-scale parameters start from zeros, linear-scale from ones.
+         base = torch.zeros if alpha_logscale else torch.ones
+         self.alpha = Parameter(base(in_features) * alpha)
+         self.alpha.requires_grad = alpha_trainable
+
+         self.no_div_by_zero = 1e-9
+
+     def forward(self, x):
+         """
+         Forward pass of the function.
+         Applies the function to the input elementwise.
+         Snake := x + 1/a * sin^2 (xa)
+         """
+         scale = self.alpha[None, :, None]  # line up with x to [B, C, T]
+         if self.alpha_logscale:
+             scale = torch.exp(scale)
+         periodic = torch.pow(torch.sin(x * scale), 2)
+         return x + (1.0 / (scale + self.no_div_by_zero)) * periodic
+
+
+ class SnakeBeta(nn.Module):
+     """
+     Snake variant with separate per-channel frequency and magnitude parameters.
+     Shape:
+         - Input: (B, C, T)
+         - Output: (B, C, T), same shape as the input
+     Parameters:
+         - alpha - parameter that controls frequency
+         - beta - parameter that controls magnitude
+     References:
+         - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+         https://arxiv.org/abs/2006.08195
+     Examples:
+         >>> act = SnakeBeta(256)
+         >>> x = torch.randn(4, 256, 100)
+         >>> x = act(x)
+     """
+
+     def __init__(
+         self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
+     ):
+         """
+         Initialization.
+         INPUT:
+             - in_features: number of channels (length of ``alpha`` and ``beta``)
+             - alpha: initial value used for both ``alpha`` and ``beta``;
+               higher values = higher-frequency / higher-magnitude.
+             - alpha_trainable: whether both parameters receive gradients
+             - alpha_logscale: store both parameters in log scale. The stored
+               values then start at zero regardless of ``alpha``.
+         """
+         super().__init__()
+         self.in_features = in_features
+         self.alpha_logscale = alpha_logscale
+
+         # Log-scale parameters start from zeros, linear-scale from ones.
+         base = torch.zeros if alpha_logscale else torch.ones
+         self.alpha = Parameter(base(in_features) * alpha)
+         self.beta = Parameter(base(in_features) * alpha)
+
+         for param in (self.alpha, self.beta):
+             param.requires_grad = alpha_trainable
+
+         self.no_div_by_zero = 1e-9
+
+     def forward(self, x):
+         """
+         Forward pass of the function.
+         Applies the function to the input elementwise.
+         SnakeBeta := x + 1/b * sin^2 (xa)
+         """
+         freq = self.alpha[None, :, None]  # line up with x to [B, C, T]
+         magnitude = self.beta[None, :, None]
+         if self.alpha_logscale:
+             freq = torch.exp(freq)
+             magnitude = torch.exp(magnitude)
+         periodic = torch.pow(torch.sin(x * freq), 2)
+         return x + (1.0 / (magnitude + self.no_div_by_zero)) * periodic
diff --git a/src/dots_tts/modules/vocoder/alias_free_filter.py b/src/dots_tts/modules/vocoder/alias_free_filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..4913366ea7a73de56a4136887a51b9584e618e33
--- /dev/null
+++ b/src/dots_tts/modules/vocoder/alias_free_filter.py
@@ -0,0 +1,114 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+ if "sinc" not in dir(torch):
+     # This code is adopted from adefossez's julius.core.sinc under the MIT License
+     # https://adefossez.github.io/julius/julius/core.html
+     # LICENSE is in incl_licenses directory.
+     def sinc(x: torch.Tensor):
+         """
+         Implementation of sinc, i.e. sin(pi * x) / (pi * x)
+         __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
+         """
+         one = torch.tensor(1.0, device=x.device, dtype=x.dtype)
+         ratio = torch.sin(math.pi * x) / math.pi / x
+         # The ratio is 0/0 at x == 0; substitute the limit value 1 there.
+         return torch.where(x == 0, one, ratio)
+ else:
+     # Prefer the built-in implementation when this torch build provides it.
+     sinc = torch.sinc
+
+
+# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
+# https://adefossez.github.io/julius/julius/lowpass.html
+# LICENSE is in incl_licenses directory.
+ def kaiser_sinc_filter1d(
+     cutoff, half_width, kernel_size
+ ):  # return filter [1,1,kernel_size]
+     """Design a Kaiser-windowed sinc low-pass filter.
+
+     Args:
+         cutoff: Cutoff as a fraction of the sampling rate. ``0`` yields an
+             all-zero filter.
+         half_width: Transition half-width; sets the Kaiser ``beta`` through
+             the attenuation estimate ``A``.
+         kernel_size: Number of taps (even or odd).
+
+     Returns:
+         Floating-point tensor of shape ``(1, 1, kernel_size)``. For a non-zero
+         cutoff the taps are normalised to sum to 1.
+     """
+     even = kernel_size % 2 == 0
+     half_size = kernel_size // 2
+
+     # For kaiser window
+     delta_f = 4 * half_width
+     A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+     if A > 50.0:
+         beta = 0.1102 * (A - 8.7)
+     elif A >= 21.0:
+         beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
+     else:
+         beta = 0.0
+     window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
+
+     if cutoff == 0:
+         # Nothing to normalise: dividing by the zero sum would give NaN.
+         # Take the dtype from the window so odd kernel sizes (whose time
+         # axis is int64) still return a floating-point filter.
+         return torch.zeros_like(window).view(1, 1, kernel_size)
+
+     # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
+     if even:
+         time = torch.arange(-half_size, half_size) + 0.5
+     else:
+         time = torch.arange(kernel_size) - half_size
+     filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
+     # Normalize filter to have sum = 1, otherwise we will have a small leakage
+     # of the constant component in the input signal.
+     filter_ /= filter_.sum()
+     return filter_.view(1, 1, kernel_size)
+
+
+ class LowPassFilter1d(nn.Module):
+     """Depth-wise 1-D low-pass filter built from ``kaiser_sinc_filter1d``.
+
+     With ``fixed_filter`` the taps are a single shared buffer that is
+     expanded to the input's channel count on every call. Otherwise they are
+     a learnable parameter with one copy per channel.
+
+     Raises:
+         ValueError: If ``cutoff`` is negative or above 0.5.
+     """
+
+     def __init__(
+         self,
+         cutoff=0.5,
+         half_width=0.6,
+         stride: int = 1,
+         padding: bool = True,
+         padding_mode: str = "replicate",
+         kernel_size: int = 12,
+         channels: int = 1,
+         causal: bool = True,
+         fixed_filter: bool = False,
+     ):
+         # kernel_size should be even number for stylegan3 setup,
+         # in this implementation, odd number is also possible.
+         super().__init__()
+         if cutoff < 0:
+             raise ValueError("Minimum cutoff must be larger than zero.")
+         if cutoff > 0.5:
+             raise ValueError("A cutoff above 0.5 does not make sense.")
+         self.kernel_size = kernel_size
+         if not causal:
+             # Centre the kernel; an even kernel pads one sample less on the left.
+             self.even = kernel_size % 2 == 0
+             self.pad_left = kernel_size // 2 - int(self.even)
+             self.pad_right = kernel_size // 2
+         else:
+             # All padding goes on the left so no future samples are used.
+             self.pad_left, self.pad_right = kernel_size - 1, 0
+         self.stride = stride
+         self.padding = padding
+         self.padding_mode = padding_mode
+         self.fixed_filter = fixed_filter
+         taps = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
+         if fixed_filter:
+             self.register_buffer("filter", taps)
+         else:
+             self.filter = nn.Parameter(taps.expand(channels, -1, -1).clone())
+
+     # input [B, C, T]
+     def forward(self, x):
+         _, num_channels, _ = x.shape
+         if self.padding:
+             x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
+         if self.fixed_filter:
+             kernel = self.filter.expand(num_channels, -1, -1)
+         else:
+             kernel = self.filter
+         return F.conv1d(x, kernel, stride=self.stride, groups=num_channels)
diff --git a/src/dots_tts/modules/vocoder/alias_free_resample.py b/src/dots_tts/modules/vocoder/alias_free_resample.py
new file mode 100644
index 0000000000000000000000000000000000000000..e10cf7e76c53f7bb3f19d0f337bcdc1931505fc2
--- /dev/null
+++ b/src/dots_tts/modules/vocoder/alias_free_resample.py
@@ -0,0 +1,81 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+import torch.nn as nn
+from torch.nn import functional as F
+
+from .alias_free_filter import LowPassFilter1d, kaiser_sinc_filter1d
+
+
+ class UpSample1d(nn.Module):
+     """Up-sample by ``ratio`` with a transposed Kaiser-sinc convolution.
+
+     Args:
+         ratio: Integer up-sampling factor (also the transposed-conv stride).
+         kernel_size: Filter length; defaults to ``int(6 * ratio // 2) * 2``.
+         channels: Channel count; required when ``fixed_filter`` is False,
+             because the learnable filter holds one copy per channel.
+         causal: Trim only the tail of the transposed-conv output instead of
+             padding the input and trimming both ends.
+         fixed_filter: Keep the taps as a shared buffer rather than a
+             learnable per-channel parameter.
+     """
+
+     def __init__(
+         self, ratio=2, kernel_size=None, channels=None, causal=True, fixed_filter=False
+     ):
+         super().__init__()
+         self.ratio = ratio
+         self.kernel_size = (
+             int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+         )
+         self.stride = ratio
+         self.channels = channels
+         self.causal = causal
+         self.fixed_filter = fixed_filter
+         if causal:
+             self.pad = 0
+         else:
+             self.pad = self.kernel_size // ratio - 1
+             self.pad_left = (
+                 self.pad * self.stride + (self.kernel_size - self.stride) // 2
+             )
+             self.pad_right = (
+                 self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+             )
+         filter = kaiser_sinc_filter1d(
+             cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
+         )
+         if self.fixed_filter:
+             self.register_buffer("filter", filter)
+         else:
+             self.filter = nn.Parameter(filter.expand(channels, -1, -1).clone())
+
+     # x: [B, C, T]
+     def forward(self, x):
+         _, C, _ = x.shape
+         x = F.pad(x, (self.pad, self.pad), mode="replicate")
+         kernel = self.filter.expand(C, -1, -1) if self.fixed_filter else self.filter
+         # Scale by ratio to compensate for the zero-insertion of the
+         # transposed convolution.
+         x = self.ratio * F.conv_transpose1d(x, kernel, stride=self.stride, groups=C)
+
+         # Trim with explicit end indices: a negative-stop slice such as
+         # ``x[..., :-0]`` would return an empty tensor when the trim is 0.
+         total = x.shape[-1]
+         if self.causal:
+             x = x[..., : total - (self.kernel_size - self.stride)]
+         else:
+             x = x[..., self.pad_left : total - self.pad_right]
+
+         return x
+
+
+ class DownSample1d(nn.Module):
+     """Down-sample by ``ratio`` using a strided ``LowPassFilter1d``.
+
+     The filter uses cutoff ``0.5 / ratio`` and half-width ``0.6 / ratio``.
+     ``kernel_size`` defaults to ``int(6 * ratio // 2) * 2``.
+     """
+
+     def __init__(
+         self, ratio=2, kernel_size=None, channels=None, causal=True, fixed_filter=False
+     ):
+         super().__init__()
+         self.ratio = ratio
+         if kernel_size is None:
+             kernel_size = int(6 * ratio // 2) * 2
+         self.kernel_size = kernel_size
+         self.lowpass = LowPassFilter1d(
+             cutoff=0.5 / ratio,
+             half_width=0.6 / ratio,
+             stride=ratio,
+             kernel_size=kernel_size,
+             channels=channels,
+             causal=causal,
+             fixed_filter=fixed_filter,
+         )
+
+     def forward(self, x):
+         filtered = self.lowpass(x)
+         return filtered
diff --git a/src/dots_tts/modules/vocoder/bigvgan.py b/src/dots_tts/modules/vocoder/bigvgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0e21c4c4a06c2b13a941828c7c16f49ce2924b4
--- /dev/null
+++ b/src/dots_tts/modules/vocoder/bigvgan.py
@@ -0,0 +1,1070 @@
+# Copyright (c) 2022 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+from __future__ import annotations
+
+import itertools
+from dataclasses import dataclass
+from fractions import Fraction
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils import remove_weight_norm, weight_norm
+
+from dots_tts.modules.backbone.layers import Conv1d, ConvTranspose1d
+from dots_tts.modules.vocoder.alias_free_act import Activation1d, Snake, SnakeBeta
+from dots_tts.modules.vocoder.config import AudioVAEConfig
+
+
+ @dataclass(slots=True)
+ class BigVGANStreamState:
+     """Streaming state bundle: an LSTM hidden pair plus a decoder stream state."""
+
+     # Pair of tensors; presumably the (h, c) state exchanged with
+     # SLSTM.stream_step — confirm against the caller.
+     lstm_hidden: tuple[torch.Tensor, torch.Tensor]
+     # Quoted because DecoderStreamState is declared after this class.
+     decoder: "DecoderStreamState"
+
+
+ @dataclass(slots=True)
+ class DecoderStreamState:
+     """Mutable per-stream state carried between decoder streaming calls.
+
+     NOTE(review): the fields are consumed outside this view; the comments
+     below are read from the names and types only — confirm against the
+     decoder's streaming code.
+     """
+
+     # Tensor of buffered context; presumably the frames retained between chunks.
+     window: torch.Tensor
+     # Integer chunk size; units are presumably frames.
+     chunk_size: int
+     # Counters that start at zero; presumably frames received / frames emitted.
+     total_frames: int = 0
+     emitted_frames: int = 0
+
+ def _empty_chunk(
+     ref: torch.Tensor,
+     *,
+     channels: int | None = None,
+     dtype: torch.dtype | None = None,
+ ) -> torch.Tensor:
+     """Return a zero-length chunk shaped ``(ref.size(0), channels, 0)``.
+
+     The result is created with ``ref.new_zeros``, so it lives on ``ref``'s
+     device.
+
+     Args:
+         ref: Tensor supplying the batch size and the defaults for the
+             channel count (``ref.size(1)``) and dtype.
+         channels: Channel count of the result; ``None`` copies it from
+             ``ref``. An explicit ``0`` is honoured.
+         dtype: Dtype of the result; ``None`` copies it from ``ref``.
+     """
+     # Test against None explicitly: ``channels or ...`` would silently
+     # replace a legitimate 0 with ref's channel count.
+     out_channels = ref.size(1) if channels is None else channels
+     out_dtype = ref.dtype if dtype is None else dtype
+     return ref.new_zeros((ref.size(0), out_channels, 0), dtype=out_dtype)
+
+
+ def _module_state_device_dtype(module: nn.Module) -> tuple[torch.device, torch.dtype]:
+     """Infer the ``(device, dtype)`` that streaming state for ``module`` should use.
+
+     The module's own ``weight``, ``bias`` and ``filter`` attributes are tried
+     in that order, then the first tensor among its parameters and buffers.
+
+     Raises:
+         RuntimeError: If the module exposes no tensor at all.
+     """
+     named = (getattr(module, attr, None) for attr in ("weight", "bias", "filter"))
+     fallback = itertools.chain(module.parameters(), module.buffers())
+     candidates = itertools.chain(named, fallback)
+     tensor = next((item for item in candidates if item is not None), None)
+     if tensor is None:
+         raise RuntimeError(
+             f"Unable to infer state dtype/device for {type(module).__name__}."
+         )
+     return tensor.device, tensor.dtype
+
+
+ def _stream_state_zeros(
+     batch_size: int,
+     channels: int,
+     length: int,
+     *,
+     device: torch.device,
+     dtype: torch.dtype,
+ ) -> torch.Tensor:
+     """Allocate a zero tensor of shape ``(batch_size, channels, length)``.
+
+     ``length`` is converted with ``int()`` and negative values are treated
+     as 0.
+     """
+     frames = max(0, int(length))
+     shape = (batch_size, channels, frames)
+     return torch.zeros(shape, device=device, dtype=dtype)
+
+
+ def init_weights(m, mean=0.0, std=0.01):
+     """Re-initialise a convolution's weight in place from ``N(mean, std)``.
+
+     Intended for ``module.apply(init_weights)``: a module is touched only
+     when its class name contains ``"Conv"``.
+
+     Args:
+         m: Module visited by ``apply``.
+         mean: Mean of the normal distribution.
+         std: Standard deviation of the normal distribution.
+     """
+     if "Conv" not in m.__class__.__name__:
+         return
+     weight = getattr(m, "weight", None)
+     if weight is None:
+         # Wrapper modules such as Conv1d_S match by name but keep the real
+         # convolution in a child, which ``apply`` visits on its own.
+         return
+     weight.data.normal_(mean, std)
+
+
+ class Conv1d_S(nn.Module):
+     """Weight-normalised ``nn.Conv1d`` with optional causal (left-only) padding.
+
+     In causal mode the wrapped convolution uses no padding of its own and
+     ``forward`` left-pads the input with ``dilation * (kernel_size - 1)``
+     zeros. Otherwise the convolution pads both sides by half that amount.
+     """
+
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size=1,
+         stride=1,
+         dilation=1,
+         groups=1,
+         causal=False,
+     ):
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.kernel_size = kernel_size
+         self.stride = stride
+         self.dilation = dilation
+         self.groups = groups
+         self.causal = causal
+
+         span = dilation * (kernel_size - 1)
+         if causal:
+             symmetric_pad, self.causal_pad = 0, span
+         else:
+             symmetric_pad, self.causal_pad = span // 2, 0
+
+         conv = nn.Conv1d(
+             in_channels,
+             out_channels,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=symmetric_pad,
+             dilation=dilation,
+             groups=groups,
+         )
+         self.layer = weight_norm(conv)
+
+     def forward(self, inputs):
+         needs_left_pad = self.causal and self.causal_pad > 0
+         if needs_left_pad:
+             inputs = F.pad(inputs, (self.causal_pad, 0))
+         return self.layer(inputs)
+
+
+ class SLSTM(nn.Module):
+     """
+     LSTM wrapper that hides the hidden state in ``forward`` and adds an
+     optional skip connection.
+
+     The wrapped ``nn.LSTM`` is built with ``batch_first=True`` and no axes are
+     permuted here, so inputs are ``(batch, time, dimension)``.
+     ``stream_step`` re-implements the recurrence frame by frame so the hidden
+     state can be carried across chunks.
+     """
+
+     def __init__(
+         self,
+         dimension: int,
+         num_layers: int = 2,
+         skip: bool = True,
+         bidirectional: bool = False,
+     ):
+         """
+         Args:
+             dimension: Input size and hidden size of the LSTM.
+             num_layers: Number of stacked LSTM layers.
+             skip: Add the input to the output (residual connection).
+             bidirectional: Use a bidirectional LSTM followed by a linear
+                 projection from ``2 * dimension`` back to ``dimension``.
+         """
+         super().__init__()
+         self.skip = skip
+         self.bidirectional = bidirectional
+         self.lstm = nn.LSTM(
+             input_size=dimension,
+             hidden_size=dimension,
+             num_layers=num_layers,
+             bidirectional=bidirectional,
+             batch_first=True,
+         )
+         # Per-layer references to the forward-direction LSTM parameters, used
+         # by stream_step.
+         # NOTE(review): these tuples alias the parameters captured at
+         # construction time; anything that replaces (rather than updates in
+         # place) the nn.LSTM parameters would leave them stale — confirm.
+         self._stream_num_layers = num_layers
+         self._stream_weight_ih = tuple(
+             getattr(self.lstm, f"weight_ih_l{layer_idx}")
+             for layer_idx in range(num_layers)
+         )
+         self._stream_weight_hh = tuple(
+             getattr(self.lstm, f"weight_hh_l{layer_idx}")
+             for layer_idx in range(num_layers)
+         )
+         self._stream_bias_ih = tuple(
+             getattr(self.lstm, f"bias_ih_l{layer_idx}")
+             for layer_idx in range(num_layers)
+         )
+         self._stream_bias_hh = tuple(
+             getattr(self.lstm, f"bias_hh_l{layer_idx}")
+             for layer_idx in range(num_layers)
+         )
+         if self.bidirectional:
+             self.proj_out = nn.Linear(dimension * 2, dimension)
+
+     def forward(self, x):
+         """Run the whole sequence through the LSTM from a zero initial state."""
+         y, _ = self.lstm(x)
+         if self.bidirectional:
+             y = self.proj_out(y)
+         if self.skip:
+             y = y + x
+         return y
+
+     def stream_step(
+         self,
+         x: torch.Tensor,
+         hidden: tuple[torch.Tensor, torch.Tensor] | None,
+     ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+         """Process one chunk of frames while carrying the LSTM state.
+
+         Args:
+             x: Chunk indexed as ``x[:, frame, :]`` (batch first).
+             hidden: ``(h, c)`` pair whose first axis indexes layers, or
+                 ``None`` to start from zeros.
+
+         Returns:
+             ``(y, (h, c))`` where ``y`` includes the skip connection when
+             enabled and ``(h, c)`` is the state after the last frame.
+
+         Raises:
+             RuntimeError: If the module is bidirectional.
+         """
+         if self.bidirectional:
+             raise RuntimeError("Streaming only supports unidirectional SLSTM.")
+
+         residual = x
+         if hidden is None:
+             hidden = self.init_stream_state(x.size(0))
+
+         hidden_h, hidden_c = hidden
+         next_hidden_h = []
+         next_hidden_c = []
+         for layer_idx in range(self._stream_num_layers):
+             # `x` holds the previous layer's outputs (or the chunk itself for
+             # the first layer).
+             layer_input = x
+             hx = hidden_h[layer_idx]
+             cx = hidden_c[layer_idx]
+             outputs = []
+             weight_ih = self._stream_weight_ih[layer_idx]
+             weight_hh = self._stream_weight_hh[layer_idx]
+             bias_ih = self._stream_bias_ih[layer_idx]
+             bias_hh = self._stream_bias_hh[layer_idx]
+
+             for frame_idx in range(x.size(1)):
+                 gates = F.linear(layer_input[:, frame_idx, :], weight_ih, bias_ih)
+                 gates = gates + F.linear(hx, weight_hh, bias_hh)
+                 # The four chunks are split in the order input, forget, cell,
+                 # output.
+                 input_gate, forget_gate, cell_gate, output_gate = gates.chunk(4, dim=-1)
+                 input_gate = torch.sigmoid(input_gate)
+                 forget_gate = torch.sigmoid(forget_gate)
+                 cell_gate = torch.tanh(cell_gate)
+                 output_gate = torch.sigmoid(output_gate)
+                 cx = forget_gate * cx + input_gate * cell_gate
+                 hx = output_gate * torch.tanh(cx)
+                 outputs.append(hx)
+
+             # This layer's outputs become the next layer's input.
+             x = torch.stack(outputs, dim=1)
+             next_hidden_h.append(hx)
+             next_hidden_c.append(cx)
+
+         y = x
+         if self.skip:
+             y = y + residual
+         return y, (torch.stack(next_hidden_h, dim=0), torch.stack(next_hidden_c, dim=0))
+
+     def init_stream_state(
+         self,
+         batch_size: int,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """Return zero ``(h, c)`` tensors of shape
+         ``(num_layers * num_directions, batch_size, hidden_size)``, created
+         with ``new_zeros`` from the first layer's recurrent weight so they
+         share its device and dtype."""
+         num_directions = 2 if self.bidirectional else 1
+         state_shape = (
+             self.lstm.num_layers * num_directions,
+             batch_size,
+             self.lstm.hidden_size,
+         )
+         weight = self.lstm.weight_hh_l0
+         return (
+             weight.new_zeros(state_shape),
+             weight.new_zeros(state_shape),
+         )
+
+
+ class ResStack(nn.Module):
+     """Stack of residual blocks built from dilated, weight-normalised convolutions.
+
+     Block ``i`` applies LeakyReLU, a convolution with dilation ``base**i``,
+     LeakyReLU and a convolution with dilation 1. Its output is added to its
+     input. In causal mode each convolution is preceded by an explicit
+     left-only zero pad instead of using symmetric padding.
+     """
+
+     def __init__(self, channel, kernel_size=3, base=3, nums=4, causal=False):
+         super().__init__()
+
+         def conv_stage(dilation, pad):
+             """Return ``[optional left pad, weight-normed conv]`` for one stage."""
+             stage = []
+             if causal and pad > 0:
+                 stage.append(nn.ConstantPad1d((pad, 0), 0.0))
+             stage.append(
+                 nn.utils.weight_norm(
+                     nn.Conv1d(
+                         channel,
+                         channel,
+                         kernel_size=kernel_size,
+                         dilation=dilation,
+                         padding=0 if causal else pad,
+                     )
+                 )
+             )
+             return stage
+
+         self.layers = nn.ModuleList([])
+         for power in range(nums):
+             dil = base**power
+             dilated_pad = dil * (kernel_size - 1) if causal else dil
+             plain_pad = (kernel_size - 1) if causal else 1
+             self.layers.append(
+                 nn.Sequential(
+                     nn.LeakyReLU(),
+                     *conv_stage(dil, dilated_pad),
+                     nn.LeakyReLU(),
+                     *conv_stage(1, plain_pad),
+                 )
+             )
+
+     def forward(self, x):
+         out = x
+         for block in self.layers:
+             out = out + block(out)
+         return out
+
+
+class Encoder(nn.Module):
+ # Strided convolutional encoder: an input projection, one down-sampling stage
+ # per entry of `down_sample_factors`, then an output projection to
+ # `out_channels`. All layers are chained in `self.generator`.
+ def __init__(
+ self,
+ in_channels=1,
+ out_channels=100,
+ base_channels=12,
+ proj_kernel_size=3,
+ stack_kernel_size=3,
+ stack_dilation_base=2,
+ stacks=6,
+ channels=(12, 24, 48, 96, 192, 384, 768),
+ down_sample_factors=(2, 2, 2, 2, 4, 4),
+ causal=False,
+ lookahead=0,
+ ):
+ super().__init__()
+
+ act_slope = 0.2
+ layers = []
+ # pre proj_layer
+ # NOTE(review): this projection emits `base_channels` channels while the
+ # first stage below consumes `channels[0]`, so the two must be equal.
+ layers += [
+ Conv1d_S(
+ in_channels,
+ base_channels,
+ kernel_size=proj_kernel_size,
+ stride=1,
+ causal=causal,
+ ),
+ nn.LeakyReLU(act_slope, True),
+ ]
+
+ # One stage per consecutive (in_c, out_c) pair of `channels`, strided by the
+ # matching entry of `down_sample_factors`; strict=True makes zip raise when
+ # the number of pairs and the number of factors differ.
+ for (in_c, out_c), down_f in zip(
+ itertools.pairwise(channels), down_sample_factors, strict=True
+ ):
+ layers += [
+ Conv1d_S(
+ in_c, out_c, kernel_size=down_f * 2, stride=down_f, causal=causal
+ ),
+ ResStack(
+ out_c, stack_kernel_size, stack_dilation_base, stacks, causal=causal
+ ),
+ nn.LeakyReLU(act_slope, True),
+ ]
+
+ # post layers
+ # With lookahead the output projection is explicitly non-causal and uses a
+ # kernel of 2 * lookahead + 1; otherwise it follows the `causal` flag.
+ if lookahead > 0:
+ layers += [
+ Conv1d_S(
+ channels[-1],
+ out_channels,
+ kernel_size=lookahead * 2 + 1,
+ stride=1,
+ causal=False,
+ ),
+ ]
+ else:
+ layers += [
+ Conv1d_S(
+ channels[-1],
+ out_channels,
+ kernel_size=proj_kernel_size,
+ stride=1,
+ causal=causal,
+ ),
+ ]
+ self.generator = nn.Sequential(*layers)
+
+ # Runs `conditions` through the layer chain; `_z_inputs` is accepted but unused.
+ def forward(self, conditions, _z_inputs=None):
+ return self.generator(conditions)
+
+
+class AMPBlock1(torch.nn.Module):
+ # Residual block made of three (activation, dilated conv, activation, conv)
+ # stages. `activation` must be "snake" or "snakebeta"; the default None ends
+ # in the NotImplementedError branch below.
+ def __init__(
+ self,
+ h,
+ channels,
+ kernel_size=3,
+ dilation=(1, 3, 5),
+ activation=None,
+ causal=True,
+ ):
+ super().__init__()
+ self.h = h
+
+ # First convolution of each stage, one per entry of `dilation`.
+ self.convs1 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ causal=causal,
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ causal=causal,
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ causal=causal,
+ )
+ ),
+ ]
+ )
+ self.convs1.apply(init_weights)
+
+ # Second convolution of each stage, always with dilation 1.
+ self.convs2 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels, channels, kernel_size, 1, dilation=1, causal=causal
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels, channels, kernel_size, 1, dilation=1, causal=causal
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels, channels, kernel_size, 1, dilation=1, causal=causal
+ )
+ ),
+ ]
+ )
+ self.convs2.apply(init_weights)
+
+ self.num_layers = len(self.convs1) + len(
+ self.convs2
+ ) # total number of conv layers
+
+ # One activation module per convolution (num_layers in total).
+ if (
+ activation == "snake"
+ ): # periodic nonlinearity with snake function and anti-aliasing
+ self.activations = nn.ModuleList(
+ [
+ Activation1d(
+ activation=Snake(channels, alpha_logscale=h.snake_logscale),
+ causal=causal,
+ fixed_filter=True,
+ )
+ for _ in range(self.num_layers)
+ ]
+ )
+ elif (
+ activation == "snakebeta"
+ ): # periodic nonlinearity with snakebeta function and anti-aliasing
+ self.activations = nn.ModuleList(
+ [
+ Activation1d(
+ activation=SnakeBeta(channels, alpha_logscale=h.snake_logscale),
+ causal=causal,
+ fixed_filter=True,
+ )
+ for _ in range(self.num_layers)
+ ]
+ )
+ else:
+ raise NotImplementedError(
+ "activation incorrectly specified. check the config file and look for 'activation'."
+ )
+
+ # Even-indexed activations precede convs1, odd-indexed ones precede convs2;
+ # each stage's result is added back onto its input.
+ def forward(self, x):
+ acts1, acts2 = self.activations[::2], self.activations[1::2]
+ for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2, strict=True):
+ xt = a1(x)
+ xt = c1(xt)
+ xt = a2(xt)
+ xt = c2(xt)
+ x = xt + x
+
+ return x
+
+ # Strips weight normalisation from every convolution in convs1 and convs2.
+ def remove_weight_norm(self):
+ for layer in self.convs1:
+ remove_weight_norm(layer)
+ for layer in self.convs2:
+ remove_weight_norm(layer)
+
+
+class AMPBlock2(torch.nn.Module):
+ # Residual block made of two (activation, dilated conv) stages. `activation`
+ # must be "snake" or "snakebeta"; the default None ends in the
+ # NotImplementedError branch below.
+ def __init__(
+ self, h, channels, kernel_size=3, dilation=(1, 3), activation=None, causal=True
+ ):
+ super().__init__()
+ self.h = h
+
+ # One weight-normalised convolution per entry of `dilation`.
+ self.convs = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ causal=causal,
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ causal=causal,
+ )
+ ),
+ ]
+ )
+ self.convs.apply(init_weights)
+
+ self.num_layers = len(self.convs) # total number of conv layers
+
+ # One activation module per convolution.
+ if (
+ activation == "snake"
+ ): # periodic nonlinearity with snake function and anti-aliasing
+ self.activations = nn.ModuleList(
+ [
+ Activation1d(
+ activation=Snake(channels, alpha_logscale=h.snake_logscale),
+ causal=causal,
+ fixed_filter=True,
+ )
+ for _ in range(self.num_layers)
+ ]
+ )
+ elif (
+ activation == "snakebeta"
+ ): # periodic nonlinearity with snakebeta function and anti-aliasing
+ self.activations = nn.ModuleList(
+ [
+ Activation1d(
+ activation=SnakeBeta(channels, alpha_logscale=h.snake_logscale),
+ causal=causal,
+ fixed_filter=True,
+ )
+ for _ in range(self.num_layers)
+ ]
+ )
+ else:
+ raise NotImplementedError(
+ "activation incorrectly specified. check the config file and look for 'activation'."
+ )
+
+ # Each stage applies its activation, then its convolution, and adds the
+ # result back onto the stage input.
+ def forward(self, x):
+ for c, a in zip(self.convs, self.activations, strict=True):
+ xt = a(x)
+ xt = c(xt)
+ x = xt + x
+
+ return x
+
+ # Strips weight normalisation from every convolution in convs.
+ def remove_weight_norm(self):
+ for layer in self.convs:
+ remove_weight_norm(layer)
+
+
+class Decoder(nn.Module):
+ # Builds the vocoder from config `h`: a non-causal pre-convolution, one
+ # transposed-convolution up-sampler per entry of `h.upsample_rates`,
+ # `num_kernels` AMP residual blocks per up-sampling stage, and a final
+ # activation plus single-channel output convolution.
+ def __init__(self, h):
+ super().__init__()
+ self.h = h
+ causal = h.causal
+ # Cache used by stream_window_size(), keyed by chunk size.
+ self._stream_window_sizes: dict[int, int] = {}
+
+ self.num_kernels = len(h.resblock_kernel_sizes)
+ self.num_upsamples = len(h.upsample_rates)
+
+ num_decoder_lookahead = h.get("num_decoder_lookahead", 2)
+ # pre conv
+ # Always non-causal; the kernel spans `num_decoder_lookahead` frames on each
+ # side of the current one.
+ self.conv_pre = weight_norm(
+ Conv1d(
+ h.latent_dim,
+ h.upsample_initial_channel,
+ kernel_size=2 * num_decoder_lookahead + 1,
+ stride=1,
+ causal=False,
+ )
+ )
+
+ # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
+ resblock = AMPBlock1 if h.resblock == "1" else AMPBlock2
+
+ # transposed conv-based upsamplers. does not apply anti-aliasing
+ # Stage i maps upsample_initial_channel // 2**i channels to half as many.
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(
+ zip(h.upsample_rates, h.upsample_kernel_sizes, strict=True)
+ ):
+ self.ups.append(
+ nn.ModuleList(
+ [
+ weight_norm(
+ ConvTranspose1d(
+ h.upsample_initial_channel // (2**i),
+ h.upsample_initial_channel // (2 ** (i + 1)),
+ k,
+ u,
+ causal=causal,
+ )
+ )
+ ]
+ )
+ )
+
+ # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
+ # Stored flat: the blocks of stage i occupy indices
+ # [i * num_kernels, (i + 1) * num_kernels).
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
+ for k, d in zip(
+ h.resblock_kernel_sizes, h.resblock_dilation_sizes, strict=True
+ ):
+ self.resblocks.append(
+ resblock(h, ch, k, d, activation=h.activation, causal=causal)
+ )
+
+ # post conv
+ # NOTE(review): `ch` is only bound inside the loop above, so an empty
+ # `h.upsample_rates` would leave it undefined here.
+ if (
+ h.activation == "snake"
+ ): # periodic nonlinearity with snake function and anti-aliasing
+ activation_post = Snake(ch, alpha_logscale=h.snake_logscale)
+ self.activation_post = Activation1d(
+ activation=activation_post, causal=causal, fixed_filter=False
+ )
+ elif (
+ h.activation == "snakebeta"
+ ): # periodic nonlinearity with snakebeta function and anti-aliasing
+ activation_post = SnakeBeta(ch, alpha_logscale=h.snake_logscale)
+ self.activation_post = Activation1d(
+ activation=activation_post, causal=causal, fixed_filter=False
+ )
+ else:
+ raise NotImplementedError(
+ "activation incorrectly specified. check the config file and look for 'activation'."
+ )
+
+ self.conv_post = weight_norm(
+ Conv1d(ch, 1, 7, 1, causal=causal, bias=h.get("use_bias_at_final", True))
+ )
+
+ # weight initialization
+ for i in range(len(self.ups)):
+ self.ups[i].apply(init_weights)
+ self.conv_post.apply(init_weights)
+
+ # Decodes `z` to a waveform: pre-conv, then for each up-sampling stage the
+ # up-sampler(s) followed by the mean of that stage's AMP blocks, then the
+ # post activation/convolution and a final tanh or clamp to [-1, 1].
+ def forward(self, z):
+ # pre conv
+ x = self.conv_pre(z)
+
+ for i in range(self.num_upsamples):
+ # upsampling
+ for i_up in range(len(self.ups[i])):
+ x = self.ups[i][i_up](x)
+ # AMP blocks
+ # Every block of the stage sees the same input; outputs are summed
+ # (in place after the first) and divided by the block count.
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i * self.num_kernels + j](x)
+ else:
+ xs += self.resblocks[i * self.num_kernels + j](x)
+ x = xs / self.num_kernels
+
+ # post conv
+ x = self.activation_post(x)
+ x = self.conv_post(x)
+ if self.h.get("use_tanh_at_final", True):
+ x = torch.tanh(x)
+ else:
+ x = torch.clamp(x, min=-1.0, max=1.0) # Bound the output to [-1, 1]
+
+ return x
+
+    @property
+    def stream_lookahead(self) -> int:
+        """Half-width of ``conv_pre``'s kernel: ``(kernel_size - 1) // 2``."""
+        pre_kernel = self.conv_pre.kernel_size[0]
+        return (pre_kernel - 1) // 2
+
+    @staticmethod
+    def _conv1d_left_context(layer) -> int:
+        """Return how many past frames a convolution reads.
+
+        For a layer flagged ``causal`` this is ``dilation * (kernel_size - 1)``;
+        otherwise it is the layer's ``padding``.  Tuple-valued attributes
+        contribute their first element.
+        """
+
+        def _first(value):
+            return value[0] if isinstance(value, tuple) else value
+
+        dilation = _first(layer.dilation)
+        kernel_size = _first(layer.kernel_size)
+        if not getattr(layer, "causal", False):
+            return _first(layer.padding)
+        return dilation * (kernel_size - 1)
+
+    @staticmethod
+    def _convtranspose1d_left_context(layer) -> int:
+        """Return the left context (always 1) of a supported transposed conv.
+
+        Raises:
+            NotImplementedError: If the layer is not flagged ``causal``.
+            ValueError: If ``kernel_size`` is not exactly ``2 * stride``.
+        """
+
+        def _first(value):
+            return value[0] if isinstance(value, tuple) else value
+
+        stride = _first(layer.stride)
+        kernel_size = _first(layer.kernel_size)
+        is_causal = getattr(layer, "causal", False)
+        if not is_causal:
+            raise NotImplementedError("Streaming only supports causal ConvTranspose1d.")
+        if kernel_size == 2 * stride:
+            return 1
+        raise ValueError(
+            "Streaming ConvTranspose1d expects kernel_size == 2 * stride, got "
+            f"kernel_size={kernel_size} stride={stride}."
+        )
+
+    @classmethod
+    def _activation_left_context(cls, activation: Activation1d) -> int:
+        """Return the left context of an alias-free activation, in input frames.
+
+        The up- and down-sampling filter taps (each ``kernel_size - 1``) are
+        summed and divided by the resampling ratio, rounding up.
+
+        Raises:
+            NotImplementedError: If the up-sampler is not causal, the low-pass
+                filter does not pad, or it pads on the right.
+            ValueError: If the up-sampling ratio differs from the low-pass
+                stride.
+        """
+        up = activation.upsample
+        lowpass = activation.downsample.lowpass
+        if not (up.causal and lowpass.padding and lowpass.pad_right == 0):
+            raise NotImplementedError("Streaming only supports causal alias-free activations.")
+        ratio = int(up.ratio)
+        if int(lowpass.stride) != ratio:
+            raise ValueError(
+                "Alias-free activation expects matched up/down ratios, got "
+                f"up_ratio={ratio} down_ratio={lowpass.stride}."
+            )
+        taps = (up.kernel_size - 1) + (lowpass.kernel_size - 1)
+        return (taps + ratio - 1) // ratio
+
+    @classmethod
+    def _ampblock_left_context(cls, block) -> int:
+        """Return the summed left context of every stage in an AMP block.
+
+        Raises:
+            TypeError: If ``block`` is neither ``AMPBlock1`` nor ``AMPBlock2``.
+        """
+        act_ctx = cls._activation_left_context
+        conv_ctx = cls._conv1d_left_context
+
+        if isinstance(block, AMPBlock1):
+            # Even-indexed activations feed convs1, odd-indexed ones feed convs2.
+            stages = zip(
+                block.convs1,
+                block.convs2,
+                block.activations[::2],
+                block.activations[1::2],
+                strict=True,
+            )
+            total = 0
+            for first_conv, second_conv, first_act, second_act in stages:
+                total += (
+                    act_ctx(first_act)
+                    + conv_ctx(first_conv)
+                    + act_ctx(second_act)
+                    + conv_ctx(second_conv)
+                )
+            return total
+
+        if isinstance(block, AMPBlock2):
+            total = 0
+            for conv, activation in zip(block.convs, block.activations, strict=True):
+                total += act_ctx(activation) + conv_ctx(conv)
+            return total
+
+        raise TypeError(f"Unsupported resblock type: {type(block).__name__}.")
+
+ # Accumulates, as an exact Fraction measured in latent frames, the left
+ # context needed by every layer of the decoder, and returns it rounded up.
+ def _stream_left_context(self) -> int:
+ left_context = Fraction(self._conv1d_left_context(self.conv_pre), 1)
+ # Size of one frame at the current resolution, expressed in latent frames;
+ # it is divided by the stride of each up-sampler that is passed.
+ current_scale = Fraction(1, 1)
+ for stage_idx, upsample_layers in enumerate(self.ups):
+ for upsample in upsample_layers:
+ left_context += current_scale * self._convtranspose1d_left_context(
+ upsample
+ )
+ stride = (
+ upsample.stride[0]
+ if isinstance(upsample.stride, tuple)
+ else upsample.stride
+ )
+ current_scale /= int(stride)
+
+ # The stage's AMP blocks all read the same input, so only the largest of
+ # their contexts counts.
+ stage_start = stage_idx * self.num_kernels
+ stage_end = stage_start + self.num_kernels
+ stage_context = max(
+ self._ampblock_left_context(block)
+ for block in self.resblocks[stage_start:stage_end]
+ )
+ left_context += current_scale * stage_context
+
+ left_context += current_scale * self._activation_left_context(self.activation_post)
+ left_context += current_scale * self._conv1d_left_context(self.conv_post)
+ return int(left_context.__ceil__())
+
+    def stream_window_size(self, chunk_size: int) -> int:
+        """Return the window length used for streaming with ``chunk_size``.
+
+        The value is ``chunk_size + stream_lookahead + left context`` and is
+        memoised per chunk size in ``self._stream_window_sizes``.
+
+        Raises:
+            ValueError: If ``chunk_size`` is smaller than 1.
+        """
+        if chunk_size < 1:
+            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}.")
+        sizes = self._stream_window_sizes
+        known = sizes.get(chunk_size)
+        if known is None:
+            known = chunk_size + self.stream_lookahead + self._stream_left_context()
+            sizes[chunk_size] = known
+        return known
+
+    def remove_weight_norm(self):
+        """Strip weight normalisation from every layer of the decoder.
+
+        Order: up-samplers, residual blocks (through their own
+        ``remove_weight_norm``), then ``conv_pre`` and ``conv_post``.
+        """
+        for stage in self.ups:
+            for layer in stage:
+                remove_weight_norm(layer)
+        for block in self.resblocks:
+            block.remove_weight_norm()
+        for conv in (self.conv_pre, self.conv_post):
+            remove_weight_norm(conv)
+
+
+class AudioVAE(nn.Module):
+ # Assembles the audio VAE from config `h`: encoder, encoder- and decoder-side
+ # Linear -> SLSTM -> Linear stacks, 1x1 projections around the latent, and the
+ # Decoder.
+ def __init__(self, h: AudioVAEConfig):
+ super().__init__()
+ # The same config object is reachable as both `config` and `h`.
+ self.config = h
+
+ self.h = h
+ # Product of all down-sampling rates; used to convert frames to samples.
+ self.hop_size = int(np.prod(h.downsample_rates))
+ self.sample_rate = h.sample_rate
+ self.decoder_lookahead = int(h.get("num_decoder_lookahead", 2))
+
+ self.audio_encoder = Encoder(
+ out_channels=h.latent_dim,
+ down_sample_factors=h.downsample_rates,
+ channels=h.downsample_channels,
+ causal=h.causal_encoder,
+ lookahead=h.get("num_encoder_lookahead", 2),
+ )
+
+ intermediate_size = h.latent_dim * 4
+ self.enc_mi_layer = nn.Sequential(
+ nn.Linear(h.latent_dim, intermediate_size),
+ SLSTM(intermediate_size, num_layers=h.mi_num_layers),
+ nn.Linear(intermediate_size, h.latent_dim),
+ )
+ # Streaming code indexes this Sequential as [0] Linear, [1] SLSTM, [2] Linear.
+ self.dec_mi_layer = nn.Sequential(
+ nn.Linear(h.latent_dim, intermediate_size),
+ SLSTM(intermediate_size, num_layers=h.mi_num_layers),
+ nn.Linear(intermediate_size, h.latent_dim),
+ )
+ # Doubles the channel count so the result can be split into two
+ # latent_dim-sized halves (m_q and logs_q) when sampling.
+ self.pre_proj = Conv1d(
+ in_channels=h.latent_dim,
+ out_channels=h.latent_dim * 2,
+ kernel_size=1,
+ stride=1,
+ )
+ self.post_proj = Conv1d(
+ in_channels=h.latent_dim, out_channels=h.latent_dim, kernel_size=1, stride=1
+ )
+
+ self.decoder = Decoder(h)
+
+    def inference(self, data):
+        """Encode ``data["sample"]`` to latents, decode them, return ``{"sample": audio}``."""
+        encoded = self.extract_latents(data["sample"])
+        decoded = self.inference_from_latents(encoded)
+        return {"sample": decoded}
+
+    @torch.autocast(enabled=False, device_type="cuda")
+    def extract_latents(self, x, do_sample=False):
+        """Encode audio ``x`` into latents with CUDA autocast disabled.
+
+        The input is cast to float32, encoded, passed through ``enc_mi_layer``
+        (with the last two axes swapped around the call) and projected by
+        ``pre_proj``.  With ``do_sample`` the projection is split along dim 1
+        into ``m_q`` / ``logs_q`` and ``m_q + noise * exp(logs_q)`` is returned;
+        otherwise the projection itself is returned.
+        """
+        encoded = self.audio_encoder(x.float())
+        mixed = self.enc_mi_layer(encoded.permute(0, 2, 1))
+        projected = self.pre_proj(mixed.permute(0, 2, 1))
+        if not do_sample:
+            return projected
+        m_q, logs_q = torch.split(projected, self.h.latent_dim, dim=1)
+        return m_q + torch.randn_like(m_q) * torch.exp(logs_q)
+
+ # Decodes latents to audio. With do_sample the input must carry
+ # 2 * latent_dim channels (split into m_q / logs_q and sampled with noise
+ # scaled by `noise_scale`); without it the input must carry latent_dim
+ # channels. The channel checks are asserts, so they raise AssertionError.
+ def inference_from_latents(self, x, do_sample=True, noise_scale=1.0):
+ if do_sample:
+ assert x.size(1) == self.h.latent_dim * 2, (
+ f"Input must be like [B, D, H], got {x.shape}"
+ )
+ m_q, logs_q = torch.split(x, self.h.latent_dim, dim=1)
+ x = m_q + torch.randn_like(m_q) * torch.exp(logs_q) * noise_scale
+ else:
+ assert x.size(1) == self.h.latent_dim, (
+ f"Input must be like [B, D, H], got {x.shape}"
+ )
+ x = self.post_proj(x)
+ # dec_mi_layer is applied with the last two axes swapped, then swapped back.
+ x = x.permute(0, 2, 1)
+ x = self.dec_mi_layer(x)
+ x = x.permute(0, 2, 1)
+ return self.decoder(x)
+
+    def _validate_stream_latents(self, latents: torch.Tensor) -> None:
+        """Check that ``latents`` is 3-D with ``latent_dim`` channels on dim 1.
+
+        Raises:
+            ValueError: If the rank is not 3 or dim 1 differs from
+                ``self.h.latent_dim``.
+        """
+        if latents.ndim != 3:
+            raise ValueError(
+                "Streaming latents must have shape [batch, latent_dim, frames], "
+                f"got {tuple(latents.shape)}."
+            )
+        if latents.size(1) == self.h.latent_dim:
+            return
+        raise ValueError(
+            f"Streaming latent_dim must be {self.h.latent_dim}, got {latents.size(1)}."
+        )
+
+ # Creates a fresh streaming state: zero LSTM hidden state from
+ # dec_mi_layer[1] and a zero decoder window sized by
+ # decoder.stream_window_size(chunk_size).
+ # Raises RuntimeError when the config is not causal and ValueError when
+ # batch_size < 1 (stream_window_size also rejects chunk_size < 1).
+ def init_stream_state(
+ self,
+ batch_size: int = 1,
+ chunk_size: int = 8,
+ ) -> BigVGANStreamState:
+ if not self.h.causal:
+ raise RuntimeError("Strict streaming requires a causal vocoder.")
+ if batch_size < 1:
+ raise ValueError(f"batch_size must be >= 1, got {batch_size}.")
+ window_size = self.decoder.stream_window_size(chunk_size)
+ # The window is allocated on the device/dtype reported for conv_pre.
+ device, dtype = _module_state_device_dtype(self.decoder.conv_pre)
+ return BigVGANStreamState(
+ lstm_hidden=self.dec_mi_layer[1].init_stream_state(batch_size),
+ decoder=DecoderStreamState(
+ window=_stream_state_zeros(
+ batch_size,
+ self.h.latent_dim,
+ window_size,
+ device=device,
+ dtype=dtype,
+ ),
+ chunk_size=int(chunk_size),
+ ),
+ )
+
+ # Streaming counterpart of the latent post-processing: validates `latents`,
+ # casts them to float32, applies post_proj and the three dec_mi_layer
+ # sub-modules one by one so the SLSTM can be stepped with an explicit hidden
+ # state. Returns the decoder input (cast to conv_pre's weight dtype) and the
+ # updated hidden state.
+ def _decode_stream_latents(
+ self,
+ latents: torch.Tensor,
+ lstm_hidden: tuple[torch.Tensor, torch.Tensor],
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+ self._validate_stream_latents(latents)
+ latents = latents.float()
+ x = self.post_proj(latents)
+ x = x.permute(0, 2, 1)
+ x = self.dec_mi_layer[0](x)
+ x, lstm_hidden = self.dec_mi_layer[1].stream_step(
+ x, lstm_hidden
+ )
+ x = self.dec_mi_layer[2](x)
+ decoder_dtype = self.decoder.conv_pre.weight.dtype
+ return x.permute(0, 2, 1).to(dtype=decoder_dtype), lstm_hidden
+
+    def _prepare_stream_decoder_input(
+        self,
+        latents: torch.Tensor,
+        state: BigVGANStreamState,
+    ) -> torch.Tensor:
+        """Run ``_decode_stream_latents`` and store the new hidden state on ``state``.
+
+        Returns the decoder input tensor; ``state.lstm_hidden`` is overwritten
+        with the hidden state produced by this step.
+        """
+        decoder_input, next_hidden = self._decode_stream_latents(
+            latents,
+            state.lstm_hidden,
+        )
+        state.lstm_hidden = next_hidden
+        return decoder_input
+
+ # Pure-tensor window update (no Python-side state). The window keeps its
+ # valid frames at the front; the new chunk is written right after them and,
+ # if the result would overflow, the oldest frames are dropped.
+ # Raises ValueError when the chunk is not shorter than the window.
+ def _append_stream_decoder_input_tensor(
+ self,
+ decoder_input: torch.Tensor,
+ window: torch.Tensor,
+ valid_frames: torch.Tensor,
+ ) -> torch.Tensor:
+ if window.dtype != decoder_input.dtype:
+ window = window.to(dtype=decoder_input.dtype)
+ chunk_size = int(decoder_input.size(-1))
+ if chunk_size >= window.size(-1):
+ raise ValueError(
+ f"decoder window size {window.size(-1)} must be larger than chunk_size {chunk_size}."
+ )
+ positions = torch.arange(
+ window.size(-1),
+ device=window.device,
+ dtype=valid_frames.dtype,
+ )
+ clipped_valid = valid_frames.clamp(min=0, max=window.size(-1))
+ # Scratch buffer of length window + chunk so the chunk always fits after
+ # the valid frames.
+ combined = torch.cat(
+ [window, decoder_input.new_zeros(window.size(0), window.size(1), chunk_size)],
+ dim=-1,
+ )
+ insert_index = clipped_valid + torch.arange(
+ chunk_size,
+ device=window.device,
+ dtype=valid_frames.dtype,
+ )
+ combined.scatter_(
+ -1,
+ insert_index.view(1, 1, -1).expand_as(decoder_input),
+ decoder_input,
+ )
+ new_valid = (clipped_valid + chunk_size).clamp(max=window.size(-1))
+ # Number of oldest frames that no longer fit into the window.
+ start = (clipped_valid + chunk_size - window.size(-1)).clamp(min=0)
+ gather_index = (start + positions).clamp(max=combined.size(-1) - 1)
+ gathered = combined.gather(
+ -1,
+ gather_index.view(1, 1, -1).expand_as(window),
+ )
+ # Zero every position at or beyond the new valid length.
+ mask = (positions < new_valid).to(dtype=window.dtype).view(1, 1, -1)
+ return gathered * mask
+
+ # Stateful wrapper around _append_stream_decoder_input_tensor: checks that the
+ # chunk length matches the one fixed in the state, appends the chunk, stores
+ # the new window on the state and advances total_frames by the chunk length.
+ # Raises ValueError when the chunk length differs from state.decoder.chunk_size.
+ def _append_stream_decoder_input(
+ self,
+ decoder_input: torch.Tensor,
+ state: BigVGANStreamState,
+ ) -> torch.Tensor:
+ decoder_state = state.decoder
+ chunk_size = int(decoder_input.size(-1))
+ if chunk_size != decoder_state.chunk_size:
+ raise ValueError(
+ f"Streaming chunk_size must stay fixed at {decoder_state.chunk_size}, got {chunk_size}."
+ )
+ window = decoder_state.window
+ # Frames currently held by the window, capped at its length.
+ valid_frames = min(decoder_state.total_frames, window.size(-1))
+ valid_frames_tensor = window.new_tensor(valid_frames, dtype=torch.int64)
+ new_window = self._append_stream_decoder_input_tensor(
+ decoder_input,
+ window,
+ valid_frames_tensor,
+ )
+ decoder_state.window = new_window
+ decoder_state.total_frames += chunk_size
+ return new_window
+
+    def compiled_stream_step(
+        self,
+        latents: torch.Tensor,
+        hidden_h: torch.Tensor,
+        hidden_c: torch.Tensor,
+        window: torch.Tensor,
+        valid_frames: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Run one streaming step using tensors only (no state object).
+
+        Returns ``(audio_window, hidden_h, hidden_c, new_window)``: the decoded
+        audio for the whole updated window, the next LSTM hidden/cell tensors
+        and the updated decoder window.
+        """
+        decoder_input, next_hidden = self._decode_stream_latents(
+            latents,
+            (hidden_h, hidden_c),
+        )
+        next_h, next_c = next_hidden
+        updated_window = self._append_stream_decoder_input_tensor(
+            decoder_input,
+            window,
+            valid_frames,
+        )
+        decoded = self.decode_stream_window(updated_window)
+        return decoded, next_h, next_c, updated_window
+
+    def decode_stream_window(self, window: torch.Tensor) -> torch.Tensor:
+        """Decode ``window`` after casting it to ``decoder.conv_pre``'s weight dtype."""
+        target_dtype = self.decoder.conv_pre.weight.dtype
+        decoder_input = window if window.dtype == target_dtype else window.to(dtype=target_dtype)
+        return self.decoder(decoder_input)
+
+ # Cuts the not-yet-emitted, already-stable part out of a decoded window and
+ # advances state.decoder.emitted_frames accordingly.
+ # Raises RuntimeError when frames that were never emitted have already left
+ # the window.
+ def _slice_stream_audio_window(
+ self,
+ audio_window: torch.Tensor,
+ state: BigVGANStreamState,
+ *,
+ final: bool,
+ ) -> torch.Tensor:
+ decoder_state = state.decoder
+ # Frames up to `stable_end` may be emitted: everything when flushing,
+ # otherwise everything except the last `stream_lookahead` frames.
+ stable_end = (
+ decoder_state.total_frames
+ if final
+ else max(0, decoder_state.total_frames - self.decoder.stream_lookahead)
+ )
+ if stable_end <= decoder_state.emitted_frames:
+ return _empty_chunk(audio_window, channels=1)
+
+ window_size = decoder_state.window.size(-1)
+ valid_frames = min(decoder_state.total_frames, window_size)
+ # Absolute index of the first frame still held by the window.
+ window_start = decoder_state.total_frames - valid_frames
+ if decoder_state.emitted_frames < window_start:
+ raise RuntimeError(
+ "Decoder stream window is too short for fixed-graph decoding."
+ )
+
+ # Convert absolute frame indices to window-local ones, then to samples.
+ local_start = decoder_state.emitted_frames - window_start
+ local_end = stable_end - window_start
+ sample_start = local_start * self.hop_size
+ sample_end = local_end * self.hop_size
+ decoder_state.emitted_frames = stable_end
+ return audio_window[..., sample_start:sample_end]
+
+    def stream_step(
+        self,
+        latents: torch.Tensor,
+        state: BigVGANStreamState,
+    ) -> torch.Tensor:
+        """Consume one chunk of latents and return the newly stable audio.
+
+        ``state`` is updated in place by the helpers called here (LSTM hidden
+        state, decoder window and frame counters).
+        """
+        prepared = self._prepare_stream_decoder_input(latents, state)
+        updated_window = self._append_stream_decoder_input(prepared, state)
+        decoded = self.decode_stream_window(updated_window)
+        return self._slice_stream_audio_window(decoded, state, final=False)
+
+    def stream_flush(self, state: BigVGANStreamState) -> torch.Tensor:
+        """Decode the current window once more and emit all remaining audio."""
+        decoded = self.decode_stream_window(state.decoder.window)
+        remaining = self._slice_stream_audio_window(decoded, state, final=True)
+        return remaining
+
+ # Delegates to the decoder; no other sub-module is touched here.
+ def remove_weight_norm(self):
+ self.decoder.remove_weight_norm()
diff --git a/src/dots_tts/modules/vocoder/config.py b/src/dots_tts/modules/vocoder/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3bd68d9afdcf03acc5624b31dced42319f2e4e2
--- /dev/null
+++ b/src/dots_tts/modules/vocoder/config.py
@@ -0,0 +1,28 @@
+from __future__ import annotations
+
+from pydantic import Field
+
+from dots_tts.config.base import ConfigBase
+
+
+class AudioVAEConfig(ConfigBase):
+ # Declarative configuration for the audio VAE / vocoder. List-valued fields
+ # default to empty lists and are expected to be filled by the loaded config.
+ # NOTE(review): the model code also reads "num_decoder_lookahead" and
+ # "num_encoder_lookahead" through `h.get(..., 2)`; neither is declared here,
+ # so they presumably fall back to 2 unless ConfigBase accepts extra keys.
+ sample_rate: int = 24000
+ # Decoder up-sampling stages: stride and kernel size per stage.
+ upsample_rates: list[int] = Field(default_factory=list)
+ upsample_kernel_sizes: list[int] = Field(default_factory=list)
+ upsample_initial_channel: int = 1536
+ # "1" selects AMPBlock1 in the decoder, anything else AMPBlock2.
+ resblock: str = "1"
+ resblock_kernel_sizes: list[int] = Field(default_factory=list)
+ resblock_dilation_sizes: list[list[int]] = Field(default_factory=list)
+ # Encoder down-sampling stages; the product of the rates is the hop size.
+ downsample_rates: list[int] = Field(default_factory=list)
+ downsample_channels: list[int] = Field(default_factory=list)
+ # The decoder accepts "snake" or "snakebeta".
+ activation: str = "snakebeta"
+ snake_logscale: bool = True
+ latent_dim: int = 128
+ # Streaming initialisation requires `causal` to be true.
+ causal: bool = False
+ mi_num_layers: int = 4
+ causal_encoder: bool = False
+ use_bias_at_final: bool = True
+ use_tanh_at_final: bool = True
+
+
+# Names exported by `from ... import *`.
+__all__ = ["AudioVAEConfig"]
diff --git a/src/dots_tts/runtime.py b/src/dots_tts/runtime.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a4445b7684c4ddb59971ff4e9421db7fe52511f
--- /dev/null
+++ b/src/dots_tts/runtime.py
@@ -0,0 +1,566 @@
+from __future__ import annotations
+
+import hashlib
+import json
+import os
+import time
+from pathlib import Path
+from typing import Any, Iterator, TypedDict
+
+import librosa
+import torch
+from huggingface_hub import snapshot_download
+from loguru import logger
+
+from dots_tts.data.pipelines.tokenizing import build_generation_schedule
+from dots_tts.data.pipelines.tts_pipeline import (
+ DEFAULT_INSTRUCTION_TTS_TEMPLATE,
+ DEFAULT_INTERLEAVE_TRAIN_TEMPLATE,
+ DEFAULT_TEXT_TO_AUDIO_TEMPLATE,
+ DEFAULT_TRAIN_TEMPLATE,
+)
+from dots_tts.models.dots_tts.model import DotsTtsModel
+from dots_tts.utils.audio import high_quality_resample
+from dots_tts.utils.profiling import (
+ InferenceProfiler,
+ activate_inference_profiler,
+ inference_profiling,
+ log_inference_profile,
+)
+from dots_tts.utils.text import (
+ attach_language_tag,
+ detect,
+ normalize_language_code,
+ normalize_text,
+)
+from dots_tts.utils.util import get_dtype
+
+# Maps the public `template_name` values accepted by the runtime to the prompt
+# templates imported from the TTS pipeline module.
+RUNTIME_TEMPLATE_BY_NAME = {
+ "tts": DEFAULT_TRAIN_TEMPLATE,
+ "instruction_tts": DEFAULT_INSTRUCTION_TTS_TEMPLATE,
+ "text_to_audio": DEFAULT_TEXT_TO_AUDIO_TEMPLATE,
+ "tts_interleave": DEFAULT_INTERLEAVE_TRAIN_TEMPLATE,
+}
+
+
+class RuntimeInputs(TypedDict, total=False):
+ # Per-request inputs assembled by the runtime. total=False: every key is
+ # optional at the type level.
+ # Request id derived from a hash of the request fields.
+ fid: str
+ language: str
+ text: str
+ prompt_text: str
+ template_name: str
+ generation_schedule: torch.Tensor
+ prompt_audio: torch.Tensor
+
+
+class DotsTtsRuntime:
+ # region Lifecycle and pretrained loading
+ # Wraps a loaded model for inference: picks the device, casts and moves the
+ # model, applies the `optimize` setting and optionally runs a warm-up.
+ def __init__(
+ self,
+ model: DotsTtsModel,
+ pretrained_path: Path,
+ *,
+ precision: str = "bfloat16",
+ optimize: bool = False,
+ max_generate_length: int = 500,
+ ):
+ self.model = model
+ self.pretrained_path = pretrained_path
+ self.precision = precision
+ # CUDA is used whenever it is available, otherwise CPU.
+ if torch.cuda.is_available():
+ self.device = torch.device("cuda")
+ else:
+ self.device = torch.device("cpu")
+ torch.set_num_threads(1)
+ # float32 on CUDA: set the float32 matmul precision to "high".
+ if self.device.type == "cuda" and self.precision.lower() in {
+ "fp32",
+ "torch.float32",
+ "float32",
+ }:
+ torch.set_float32_matmul_precision("high")
+ # Only `model.core` is cast to the target dtype; the whole model is then
+ # moved to the device and put in eval mode.
+ target_dtype = get_dtype(self.precision)
+ self.model.core.to(dtype=target_dtype)
+ self.model = self.model.to(self.device).eval()
+ self.optimize = bool(optimize)
+ self.max_generate_length = int(max_generate_length)
+ self.model.set_optimize(self.optimize)
+ self.sample_rate = int(self.model.config.vocoder.sample_rate)
+ # Setting DOTS_TTS_SKIP_INIT_WARMUP=1 in the environment skips the warm-up.
+ skip_init_warmup = os.environ.get("DOTS_TTS_SKIP_INIT_WARMUP", "0") == "1"
+ if self.optimize and hasattr(self.model, "run_warmup") and not skip_init_warmup:
+ self.model.run_warmup(
+ max_generate_length=self.max_generate_length,
+ precision=self.precision,
+ )
+ logger.info(
+ "Runtime initialized: pretrained_path={} device={} sample_rate={} "
+ "precision={} "
+ "optimize={} max_audio_patch_count={}",
+ self.pretrained_path,
+ self.device,
+ self.sample_rate,
+ self.precision,
+ self.optimize,
+ self.max_generate_length,
+ )
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_name_or_path: str,
+        *,
+        revision: str | None = None,
+        cache_dir: str | None = None,
+        precision: str = "bfloat16",
+        optimize: bool = False,
+        max_generate_length: int = 500,
+    ) -> DotsTtsRuntime:
+        """Resolve a checkpoint location, load the model and build a runtime.
+
+        ``model_name_or_path``, ``revision`` and ``cache_dir`` are forwarded to
+        ``_resolve_pretrained_path``; the remaining keyword arguments are
+        forwarded to the constructor.
+        """
+        logger.info(
+            "Runtime load started: model={} revision={} cache_dir={} precision={}",
+            model_name_or_path,
+            revision,
+            cache_dir,
+            precision,
+        )
+        checkpoint_dir = cls._resolve_pretrained_path(
+            model_name_or_path,
+            revision=revision,
+            cache_dir=cache_dir,
+        )
+        model = DotsTtsModel.from_pretrained(checkpoint_dir)
+        logger.info("Runtime load completed: pretrained_path={}", checkpoint_dir)
+        runtime = cls(
+            model=model,
+            pretrained_path=checkpoint_dir,
+            precision=precision,
+            optimize=optimize,
+            max_generate_length=max_generate_length,
+        )
+        return runtime
+
+    @classmethod
+    def _resolve_pretrained_path(
+        cls,
+        model_name_or_path: str,
+        revision: str | None = None,
+        cache_dir: str | None = None,
+    ) -> Path:
+        """Return a local directory for ``model_name_or_path``.
+
+        If the argument names an existing local path, its resolved form is
+        returned.  Otherwise it is treated as a Hugging Face repo id and the
+        snapshot is downloaded with ``snapshot_download``.
+        """
+        logger.info(
+            "Resolving pretrained path: model={} revision={} cache_dir={}",
+            model_name_or_path,
+            revision,
+            cache_dir,
+        )
+        local_candidate = Path(model_name_or_path).expanduser().resolve()
+        if not local_candidate.exists():
+            logger.info(
+                "Downloading pretrained snapshot: repo_id={} revision={} cache_dir={}",
+                model_name_or_path,
+                revision,
+                cache_dir,
+            )
+            downloaded = snapshot_download(
+                repo_id=model_name_or_path,
+                revision=revision,
+                cache_dir=cache_dir,
+            )
+            snapshot_path = Path(downloaded).expanduser().resolve()
+            logger.info("Pretrained snapshot ready: path={}", snapshot_path)
+            return snapshot_path
+
+        logger.info("Using local pretrained directory: path={}", local_candidate)
+        return local_candidate
+ # endregion Lifecycle and pretrained loading
+
+ # region Request normalization and metadata
+    @staticmethod
+    def _build_request_id(
+        *,
+        text: str,
+        prompt_audio_path: str | None,
+        prompt_text: str | None,
+        template_name: str,
+        language: str | None = None,
+    ) -> str:
+        """Return a 16-hex-character id derived from the request fields.
+
+        The fields are serialised as key-sorted JSON (``language`` is included
+        only when it is not ``None``) and hashed with SHA-1; the first 16
+        characters of the hex digest are returned.
+        """
+        fields = dict(
+            text=text,
+            prompt_audio_path=prompt_audio_path,
+            prompt_text=prompt_text,
+            template_name=template_name,
+        )
+        if language is not None:
+            fields["language"] = language
+        serialized = json.dumps(fields, ensure_ascii=False, sort_keys=True)
+        return hashlib.sha1(serialized.encode("utf-8")).hexdigest()[:16]
+
+ # Loads the prompt audio as mono at its native sample rate, trims it with
+ # librosa.effects.trim(top_db=30), resamples it to the runtime sample rate
+ # and returns it as a tensor.
+ def _load_prompt_audio(
+ self,
+ prompt_audio_path: str,
+ ) -> torch.Tensor:
+ logger.info("Loading prompt audio: path={}", prompt_audio_path)
+ # sr=None keeps the file's own sample rate.
+ prompt_audio, sample_rate = librosa.load(prompt_audio_path, sr=None, mono=True)
+ prompt_audio = librosa.effects.trim(prompt_audio, top_db=30)[0]
+ # Add a leading axis before resampling.
+ prompt_audio = torch.from_numpy(prompt_audio).unsqueeze(0)
+ prompt_audio = high_quality_resample(
+ prompt_audio,
+ orig_sr=sample_rate,
+ target_sr=self.sample_rate,
+ )
+ # Restore the leading axis if the resampler returned a 1-D tensor.
+ if prompt_audio.ndim == 1:
+ prompt_audio = prompt_audio.unsqueeze(0)
+ logger.info(
+ "Prompt audio loaded: path={} original_sample_rate={} resampled_sample_rate={} "
+ "samples={}",
+ prompt_audio_path,
+ sample_rate,
+ self.sample_rate,
+ prompt_audio.shape[-1],
+ )
+ return prompt_audio
+
+    def _resolve_language(
+        self,
+        language: str | None,
+        *,
+        text: str,
+    ) -> str | None:
+        """Turn a user-supplied language setting into a normalised code.
+
+        ``None``, blank strings and ``"none"`` (any case) yield ``None``;
+        ``"auto_detect"`` (any case) detects the language of ``text``; anything
+        else is passed through ``normalize_language_code``.
+
+        Raises:
+            ValueError: If the value cannot be normalised.
+        """
+        if language is None:
+            return None
+
+        candidate = language.strip()
+        lowered = candidate.lower()
+        if not candidate or lowered == "none":
+            return None
+        if lowered == "auto_detect":
+            return normalize_language_code(detect(text))
+
+        resolved = normalize_language_code(candidate)
+        if resolved is not None:
+            return resolved
+        raise ValueError(
+            f"Unsupported language={language!r}. "
+            "Expected 'none', 'auto_detect', or a valid language code/name."
+        )
+
+ # Normalises the prompt transcript: returns "" for None/blank input,
+ # otherwise the stripped text, optionally followed by a space and optionally
+ # passed through attach_language_tag.
+ def _process_prompt_text(
+ self,
+ prompt_text: str | None,
+ *,
+ language: str | None = None,
+ ) -> str:
+ if prompt_text is None:
+ return ""
+ prompt_text = prompt_text.strip()
+ if not prompt_text:
+ return ""
+
+ # Without an explicit language, detect it from the prompt text itself; the
+ # detected value is only used for the spacing decision below.
+ prompt_language = language
+ if prompt_language is None:
+ prompt_language = normalize_language_code(detect(prompt_text))
+
+ # No separating space is appended for the language codes listed here.
+ if prompt_language not in {"ZH", "YUE", "JA", "口音:粤语"}:
+ prompt_text += " "
+ if language is not None:
+ prompt_text = attach_language_tag(prompt_text, language)
+ return prompt_text
+
+    def _process_text(
+        self,
+        text: str,
+        *,
+        language: str | None = None,
+        normalize: bool = False,
+    ) -> tuple[str, str | None]:
+        """Strip (and optionally normalise) ``text`` and resolve its language.
+
+        Returns ``(cleaned_text, resolved_language)``; the language is resolved
+        against the cleaned text.
+        """
+        cleaned = text.strip()
+        cleaned = normalize_text(cleaned) if normalize else cleaned
+        return cleaned, self._resolve_language(language, text=cleaned)
+
+    def _estimate_prompt_audio_patch_count(
+        self,
+        *,
+        prompt_audio: torch.Tensor | None,
+        prompt_text: str,
+    ) -> int:
+        """Return how many patches the prompt audio occupies, rounded up.
+
+        A patch spans ``patch_size * hop_size`` samples.  The result is 0 when
+        there is no prompt audio or the prompt text is empty.
+        """
+        has_prompt = prompt_audio is not None and bool(prompt_text)
+        if not has_prompt:
+            return 0
+        patch_samples = int(self.model.config.patch_size * self.model.hop_size)
+        total_samples = int(prompt_audio.shape[-1])
+        return (total_samples + patch_samples - 1) // patch_samples
+ # endregion Request normalization and metadata
+
+ # region Generation schedule assembly
+    def _normalize_template_name(self, template_name: str | None) -> str:
+        """Return a valid template name, defaulting to ``"tts"`` for ``None``.
+
+        Raises:
+            ValueError: If the name is not a key of ``RUNTIME_TEMPLATE_BY_NAME``.
+        """
+        if template_name is None:
+            return "tts"
+        if template_name in RUNTIME_TEMPLATE_BY_NAME:
+            return template_name
+        raise ValueError(
+            f"Unknown template_name={template_name!r}. "
+            f"Expected one of {sorted(RUNTIME_TEMPLATE_BY_NAME)}."
+        )
+
+ # Turns raw request arguments into RuntimeInputs: normalises template, text,
+ # language and prompt text, loads the prompt audio when a path is given,
+ # checks the prompt against max_generate_length and builds the generation
+ # schedule tensor.
+ # Raises ValueError when prompt_text is given without prompt_audio_path, or
+ # when the prompt already needs max_generate_length patches or more.
+ def _prepare_inputs(
+ self,
+ *,
+ text: str,
+ prompt_audio_path: str | None,
+ prompt_text: str | None,
+ template_name: str | None,
+ language: str | None = None,
+ normalize_text: bool = False,
+ ) -> RuntimeInputs:
+ normalized_template_name = self._normalize_template_name(template_name)
+ template = RUNTIME_TEMPLATE_BY_NAME[normalized_template_name]
+ if prompt_text and not prompt_audio_path:
+ raise ValueError("prompt_text requires prompt_audio_path.")
+
+ normalized_text, normalized_language = self._process_text(
+ text,
+ language=language,
+ normalize=normalize_text,
+ )
+ normalized_prompt_text = self._process_prompt_text(
+ prompt_text,
+ language=normalized_language,
+ )
+ # The language tag is attached to the target text only when there is no
+ # prompt text to carry it.
+ if normalized_language is not None and not normalized_prompt_text:
+ normalized_text = attach_language_tag(normalized_text, normalized_language)
+ inputs: RuntimeInputs = {
+ "fid": self._build_request_id(
+ text=normalized_text,
+ prompt_audio_path=prompt_audio_path,
+ prompt_text=normalized_prompt_text,
+ template_name=normalized_template_name,
+ language=normalized_language,
+ ),
+ "language": normalized_language or "",
+ "text": normalized_text,
+ "prompt_text": normalized_prompt_text,
+ "template_name": normalized_template_name,
+ }
+
+ if prompt_audio_path:
+ inputs["prompt_audio"] = self._load_prompt_audio(prompt_audio_path)
+ # 0 unless both prompt audio and prompt text are present.
+ prompt_audio_patch_count = self._estimate_prompt_audio_patch_count(
+ prompt_audio=inputs.get("prompt_audio"),
+ prompt_text=normalized_prompt_text,
+ )
+ if (
+ prompt_audio_patch_count > 0
+ and self.max_generate_length <= prompt_audio_patch_count
+ ):
+ raise ValueError(
+ "max_generate_length must exceed prompt audio patch count when prompt_text is provided: "
+ f"max_generate_length={self.max_generate_length} "
+ f"prompt_audio_patch_count={prompt_audio_patch_count}."
+ )
+
+ # The schedule is built over the prompt text immediately followed by the
+ # target text.
+ schedule_spec = build_generation_schedule(
+ text=f"{normalized_prompt_text}{normalized_text}",
+ tokenizer=self.model.tokenizer,
+ template=template,
+ max_audio_tokens=self.max_generate_length,
+ )
+ schedule = torch.tensor(
+ schedule_spec["schedule_ids"],
+ dtype=torch.long,
+ device=self.device,
+ )
+ # A leading axis of size 1 is added before storing the schedule.
+ inputs["generation_schedule"] = schedule.unsqueeze(0)
+ logger.info(
+ "Inputs prepared: request_id={} template_name={} "
+ "language={} text_len={} prompt_text_len={} schedule_length={} "
+ "prompt_audio_patch_count={} max_audio_patch_count={} has_prompt_audio={}",
+ inputs["fid"],
+ normalized_template_name,
+ normalized_language,
+ len(normalized_text),
+ len(normalized_prompt_text),
+ schedule.numel(),
+ prompt_audio_patch_count,
+ self.max_generate_length,
+ bool(prompt_audio_path),
+ )
+ return inputs
+ # endregion Generation schedule assembly
+
+ # region Public generation APIs
def generate_stream(
    self,
    *,
    text: str,
    prompt_audio_path: str | None = None,
    prompt_text: str | None = None,
    template_name: str | None = None,
    language: str | None = None,
    speaker_scale: float = 1.5,
    ode_method: str = "euler",
    num_steps: int = 10,
    guidance_scale: float = 1.2,
    normalize_text: bool = False,
    profile_inference: bool = False,
) -> Iterator[torch.Tensor]:
    """Synthesize ``text`` and yield audio chunks as the model produces them.

    This is a generator function: input preparation, validation and logging
    only run once the caller starts iterating.

    Args:
        text: Text to synthesize.
        prompt_audio_path: Optional path of the prompt audio.
        prompt_text: Optional prompt text; requires ``prompt_audio_path``.
        template_name: Runtime template name; ``None`` selects the default.
        language: Optional language passed to text processing.
        speaker_scale: Forwarded to ``model.generate_audio_stream``.
        ode_method: Forwarded to ``model.generate_audio_stream``.
        num_steps: Forwarded to ``model.generate_audio_stream``.
        guidance_scale: Forwarded to ``model.generate_audio_stream``.
        normalize_text: Whether text processing should normalize ``text``.
        profile_inference: When true, each step of the model stream runs
            under an ``InferenceProfiler`` and a profile is logged at the end.

    Yields:
        Audio chunks in the order the model stream produces them.

    Raises:
        Exception: Any error raised while streaming is logged with the
            request id and re-raised unchanged.
    """
    inputs = self._prepare_inputs(
        text=text,
        prompt_audio_path=prompt_audio_path,
        prompt_text=prompt_text,
        template_name=template_name,
        language=language,
        normalize_text=normalize_text,
    )
    logger.info(
        "Streaming generation started: request_id={} text_len={} has_prompt_audio={} "
        "has_prompt_text={} template_name={} language={} precision={} ode_method={} num_steps={} "
        "guidance_scale={} speaker_scale={} max_audio_patch_count={} normalize_text={}",
        inputs["fid"],
        len(inputs["text"]),
        bool(prompt_audio_path),
        bool(inputs["prompt_text"]),
        inputs["template_name"],
        inputs["language"] or None,
        self.precision,
        ode_method,
        num_steps,
        guidance_scale,
        speaker_scale,
        self.max_generate_length,
        normalize_text,
    )
    # perf_counter() is monotonic, so the measured interval cannot be skewed
    # by wall-clock adjustments the way time.time() can.
    start_time = time.perf_counter()
    emitted_samples = 0
    chunk_count = 0
    profiler: InferenceProfiler | None = None
    try:
        profiler = (
            InferenceProfiler(self.device) if profile_inference else None
        )
        stream = self.model.generate_audio_stream(
            inputs,
            precision=self.precision,
            ode_method=ode_method,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            speaker_scale=speaker_scale,
        )
        while True:
            # The stream is advanced manually so that only the model step,
            # not the time spent in the consumer, runs under the profiler.
            try:
                with activate_inference_profiler(profiler):
                    chunk = next(stream)
            except StopIteration:
                break
            emitted_samples += int(chunk.shape[-1])
            chunk_count += 1
            yield chunk
    except Exception:
        logger.exception(
            "Streaming generation failed: request_id={}",
            inputs["fid"],
        )
        raise
    time_used = time.perf_counter() - start_time
    duration_seconds = emitted_samples / self.sample_rate
    # Real-time factor; infinite when no audio was produced.
    rtf = time_used / duration_seconds if duration_seconds > 0 else float("inf")
    # The profiler only exists when profiling was requested.
    if profiler is not None:
        log_inference_profile(
            request_id=inputs["fid"],
            profiling=profiler.summary(duration_seconds=duration_seconds),
            duration_seconds=duration_seconds,
        )
    logger.info(
        "Streaming generation finished: request_id={} chunk_count={} elapsed_seconds={:.3f} "
        "audio_seconds={:.3f} rtf={:.4f} sample_rate={}",
        inputs["fid"],
        chunk_count,
        time_used,
        duration_seconds,
        rtf,
        self.sample_rate,
    )
+
def generate(
    self,
    *,
    text: str,
    prompt_audio_path: str | None = None,
    prompt_text: str | None = None,
    template_name: str | None = None,
    language: str | None = None,
    speaker_scale: float = 1.5,
    ode_method: str = "euler",
    num_steps: int = 10,
    guidance_scale: float = 1.2,
    normalize_text: bool = False,
    profile_inference: bool = False,
) -> dict[str, Any]:
    """Synthesize ``text`` in one call and return the audio with timing data.

    Args:
        text: Text to synthesize.
        prompt_audio_path: Optional path of the prompt audio.
        prompt_text: Optional prompt text; requires ``prompt_audio_path``.
        template_name: Runtime template name; ``None`` selects the default.
        language: Optional language passed to text processing.
        speaker_scale: Forwarded to ``model.generate_audio``.
        ode_method: Forwarded to ``model.generate_audio``.
        num_steps: Forwarded to ``model.generate_audio``.
        guidance_scale: Forwarded to ``model.generate_audio``.
        normalize_text: Whether text processing should normalize ``text``.
        profile_inference: Passed as ``enabled`` to ``inference_profiling``.

    Returns:
        A dict with the keys ``fid`` (request id), ``audio``,
        ``sample_rate``, ``time_used`` (elapsed seconds), ``rtf``
        (elapsed seconds divided by audio seconds, infinite for empty audio)
        and ``profiling`` (the profiler summary, or ``None``).

    Raises:
        Exception: Any error raised by the model call is logged with the
            request id and re-raised unchanged.
    """
    inputs = self._prepare_inputs(
        text=text,
        prompt_audio_path=prompt_audio_path,
        prompt_text=prompt_text,
        template_name=template_name,
        language=language,
        normalize_text=normalize_text,
    )
    logger.info(
        "Generation started: request_id={} text_len={} has_prompt_audio={} "
        "has_prompt_text={} template_name={} language={} precision={} ode_method={} num_steps={} "
        "guidance_scale={} speaker_scale={} max_audio_patch_count={} normalize_text={}",
        inputs["fid"],
        len(inputs["text"]),
        bool(prompt_audio_path),
        bool(inputs["prompt_text"]),
        inputs["template_name"],
        inputs["language"] or None,
        self.precision,
        ode_method,
        num_steps,
        guidance_scale,
        speaker_scale,
        self.max_generate_length,
        normalize_text,
    )
    # perf_counter() is monotonic, so the measured interval cannot be skewed
    # by wall-clock adjustments the way time.time() can.
    start_time = time.perf_counter()
    profiling = None
    try:
        with inference_profiling(
            enabled=profile_inference,
            device=self.device,
        ) as profiler:
            audio = self.model.generate_audio(
                inputs,
                precision=self.precision,
                ode_method=ode_method,
                num_steps=num_steps,
                guidance_scale=guidance_scale,
                speaker_scale=speaker_scale,
            )
    except Exception:
        logger.exception("Generation failed: request_id={}", inputs["fid"])
        raise
    time_used = time.perf_counter() - start_time
    duration_seconds = audio.shape[-1] / self.sample_rate
    # Real-time factor; infinite when the returned audio is empty.
    rtf = time_used / duration_seconds if duration_seconds > 0 else float("inf")
    if profiler is not None:
        profiling = profiler.summary(duration_seconds=duration_seconds)
        log_inference_profile(
            request_id=inputs["fid"],
            profiling=profiling,
            duration_seconds=duration_seconds,
        )
    logger.info(
        "Generation completed: request_id={} elapsed_seconds={:.3f} audio_seconds={:.3f} "
        "rtf={:.4f} sample_rate={}",
        inputs["fid"],
        time_used,
        duration_seconds,
        rtf,
        self.sample_rate,
    )
    return {
        "fid": inputs["fid"],
        "audio": audio,
        "sample_rate": self.sample_rate,
        "time_used": time_used,
        "rtf": rtf,
        "profiling": profiling,
    }
+ # endregion Public generation APIs
diff --git a/src/dots_tts/runtime_double_streaming.py b/src/dots_tts/runtime_double_streaming.py
new file mode 100644
index 0000000000000000000000000000000000000000..208f5dd246f65151da1b18e756f58bd692bcf64c
--- /dev/null
+++ b/src/dots_tts/runtime_double_streaming.py
@@ -0,0 +1,355 @@
+from __future__ import annotations
+
+from pathlib import Path
+
+import torch
+from loguru import logger
+
+from dots_tts.data.pipelines.tts_pipeline import TTS_INTERLEAVE_PREFIX
+from dots_tts.runtime import DotsTtsRuntime
+from dots_tts.utils.util import get_dtype
+
+
class DoubleStreamingSession:
    """Incremental interleave session for text-token to audio-chunk generation.

    Text tokens are pushed one at a time with :meth:`push_text_token`; every
    push feeds the token to the model and decodes one audio patch.
    :meth:`finish_text` ends the text stream, keeps decoding until the
    generation state's end flag is set, and finally flushes the vocoder.
    """

    # Number of silence patches precomputed per cache entry. It is also the
    # upper bound for ``initial_silence_audio_tokens`` so every silence patch
    # index that can be requested is guaranteed to exist in the cache.
    _SILENCE_CACHE_PATCH_COUNT = 10

    def __init__(
        self,
        runtime: DotsTtsRuntime,
        *,
        prompt_audio_path: str | None = None,
        prompt_text: str | None = None,
        ode_method: str = "euler",
        num_steps: int = 10,
        guidance_scale: float = 1.2,
        speaker_scale: float = 1.5,
        eos_threshold: float = 0.8,
        initial_silence_audio_tokens: int = 1,
    ) -> None:
        """Allocate generation and vocoder state for one streaming session.

        Args:
            runtime: Runtime providing the model, device, precision and
                ``max_generate_length``.
            prompt_audio_path: Optional prompt audio; when given, its
                conditioning is computed once and cached on the runtime.
            prompt_text: Must normalize to an empty value.
            ode_method: Forwarded to the model when decoding audio.
            num_steps: Forwarded to the model when decoding audio.
            guidance_scale: Forwarded to the model when decoding audio.
            speaker_scale: Used when preparing the prompt conditioning.
            eos_threshold: Forwarded to the model's stop check.
            initial_silence_audio_tokens: How many leading decoded patches
                are replaced by cached silence patches; clamped to
                ``[0, _SILENCE_CACHE_PATCH_COUNT]``.

        Raises:
            ValueError: If ``prompt_text`` normalizes to a non-empty value.
        """
        normalized_prompt_text = runtime._process_prompt_text(prompt_text)
        if normalized_prompt_text:
            raise ValueError("Double streaming does not support prompt_text.")

        self.runtime = runtime
        self.model = runtime.model
        self.device = runtime.device
        self.ode_method = ode_method
        self.num_steps = int(num_steps)
        self.guidance_scale = float(guidance_scale)
        self.speaker_scale = float(speaker_scale)
        self.eos_threshold = float(eos_threshold)
        self.max_generate_length = runtime.max_generate_length
        self._initial_silence_audio_tokens = max(
            0,
            min(
                self._SILENCE_CACHE_PATCH_COUNT,
                int(initial_silence_audio_tokens or 0),
            ),
        )

        self._dtype = get_dtype(runtime.precision)
        # Autocast is only enabled for reduced precision on CUDA.
        self._use_amp = self.device.type == "cuda" and self._dtype in {
            torch.float16,
            torch.bfloat16,
        }
        self._prefix_token_ids = tuple(
            self.model.tokenizer.encode(
                TTS_INTERLEAVE_PREFIX,
                add_special_tokens=False,
            )
        )
        self._state = self.model._allocate_generate_state(
            max_audio_patch_count=self.max_generate_length,
            device=self.device,
            dtype=self._dtype,
        )
        self._vocoder_state = self.model.vocoder.init_stream_state(
            batch_size=1,
            chunk_size=self.model.core.latent_patch_size,
        )
        self._g_cond = None
        self._started = False
        self._text_finished = False
        self._closed = False
        self._decoded_patch_count = 0

        if prompt_audio_path is not None:
            self._g_cond = self._load_prompt_g_cond(prompt_audio_path)

        logger.info(
            "Double streaming session started: prefix_token_count={} precision={} "
            "ode_method={} num_steps={} guidance_scale={} speaker_scale={} max_audio_patch_count={} "
            "initial_silence_audio_tokens={} has_ref_audio_only={}",
            len(self._prefix_token_ids),
            runtime.precision,
            self.ode_method,
            self.num_steps,
            self.guidance_scale,
            self.speaker_scale,
            self.max_generate_length,
            self._initial_silence_audio_tokens,
            self._g_cond is not None,
        )

    @property
    def is_finished(self) -> bool:
        """Whether the session has been closed by :meth:`finish_text`."""
        return self._closed

    def push_text_token(self, text_token: int) -> torch.Tensor | None:
        """Feed one text token and decode the next audio patch.

        The first pushed token is preceded by the interleave prefix tokens.

        Returns:
            The audio chunk produced by the vocoder for this step, or
            ``None`` when the vocoder emitted no samples.

        Raises:
            RuntimeError: If the session is closed, if :meth:`finish_text`
                already ended the text, or if generation already reached EOS.
        """
        self._ensure_active()
        if self._text_finished:
            raise RuntimeError("Cannot push text tokens after finish_text().")
        if self._state.end_flag:
            raise RuntimeError(
                "Double streaming generation has already reached EOS. "
                "Call finish_text() to flush the remaining audio tail."
            )

        token_id = int(text_token)
        if not self._started:
            chunk_token_ids = [*self._prefix_token_ids, token_id]
            self._started = True
        else:
            chunk_token_ids = [token_id]

        self._consume_text_chunk(chunk_token_ids)
        return self._decode_audio_chunk()

    def finish_text(self):
        """End the text stream and yield the remaining audio chunks.

        This is a generator: nothing happens until the caller iterates it.
        When EOS has not been reached yet, the text-end token is consumed
        (preceded by the prefix if no token was pushed before) and audio is
        decoded until the end flag is set. The vocoder is then flushed, the
        session is marked closed, and a non-empty final chunk is yielded.

        Raises:
            RuntimeError: If the session is already closed, or if decoding
                exceeds ``max_generate_length`` before reaching EOS.
        """
        self._ensure_active()

        if not self._state.end_flag:
            if not self._text_finished:
                text_end_chunk = [self.model.core.text_cond_end_id]
                if not self._started:
                    text_end_chunk = [*self._prefix_token_ids, *text_end_chunk]
                    self._started = True
                self._consume_text_chunk(text_end_chunk)
                self._text_finished = True

            while not self._state.end_flag:
                audio_chunk = self._decode_audio_chunk(continue_audio_span=True)
                if audio_chunk is not None:
                    yield audio_chunk
        else:
            self._text_finished = True

        final_chunk = self.model.vocoder.stream_flush(self._vocoder_state)
        self._closed = True
        logger.info(
            "Double streaming session finished: decoded_patch_count={}",
            self._decoded_patch_count,
        )
        if final_chunk.size(-1) > 0:
            yield final_chunk

    def _ensure_active(self) -> None:
        """Raise ``RuntimeError`` when the session has already been closed."""
        if self._closed:
            raise RuntimeError("Double streaming session is already closed.")

    def _runtime_cache(self, attr_name: str) -> dict:
        """Return the dict stored on the runtime under ``attr_name``.

        The dict is created on first use. Storing it on the runtime lets
        every session started from the same runtime share its entries.
        """
        cache = getattr(self.runtime, attr_name, None)
        if cache is None:
            cache = {}
            setattr(self.runtime, attr_name, cache)
        return cache

    def _load_prompt_g_cond(self, prompt_audio_path: str) -> torch.Tensor:
        """Return the prompt conditioning for ``prompt_audio_path``.

        The result is cached on the runtime, keyed by the resolved path,
        device, dtype and speaker scale, so it is computed once per key.
        """
        cache = self._runtime_cache("_double_streaming_prompt_g_cond_cache")
        prompt_cache_key = (
            str(Path(prompt_audio_path).expanduser().resolve()),
            str(self.device),
            str(self._dtype),
            self.speaker_scale,
        )
        cached_g_cond = cache.get(prompt_cache_key)
        if cached_g_cond is None:
            prompt_audio = self.runtime._load_prompt_audio(prompt_audio_path)
            with torch.no_grad():
                with torch.autocast(
                    device_type=self.device.type,
                    dtype=self._dtype,
                    enabled=self._use_amp,
                ):
                    prompt_conditioning = self.model._prepare_prompt_conditioning(
                        prompt_audio,
                        use_prompt_prefill=False,
                        speaker_scale=self.speaker_scale,
                    )
            cached_g_cond = prompt_conditioning.g_cond.detach()
            cache[prompt_cache_key] = cached_g_cond
            logger.info(
                "Double streaming prompt conditioning cached: path={} device={} "
                "dtype={} speaker_scale={}",
                prompt_cache_key[0],
                self.device,
                self._dtype,
                self.speaker_scale,
            )
        else:
            logger.info(
                "Double streaming prompt conditioning cache hit: path={} device={} "
                "dtype={} speaker_scale={}",
                prompt_cache_key[0],
                self.device,
                self._dtype,
                self.speaker_scale,
            )
        return cached_g_cond

    def _consume_text_chunk(self, token_ids: list[int]) -> None:
        """Feed ``token_ids`` to the model as a single-row text schedule."""
        schedule = torch.tensor(
            [token_ids],
            dtype=torch.long,
            device=self.device,
        )
        with torch.no_grad():
            with torch.autocast(
                device_type=self.device.type,
                dtype=self._dtype,
                enabled=self._use_amp,
            ):
                self.model._consume_text_schedule(
                    schedule,
                    position=0,
                    next_audio_position=schedule.size(1),
                    state=self._state,
                )

    def _get_initial_silence_audio_patch(
        self,
        patch_index: int,
        audio_patch: torch.Tensor,
    ) -> torch.Tensor:
        """Return a copy of the cached silence patch at ``patch_index``.

        On a cache miss, ``_SILENCE_CACHE_PATCH_COUNT`` silence patches are
        built by encoding zero audio with the vocoder, normalizing the
        latents and reshaping them into patches. ``audio_patch`` is only used
        for its last-dimension size and its dtype.
        """
        cache = self._runtime_cache("_double_streaming_silence_audio_patch_cache")

        cache_count = self._SILENCE_CACHE_PATCH_COUNT
        patch_size = int(self.model.core.latent_patch_size)
        latent_dim = int(audio_patch.size(-1))
        # The key uses the dtype the cached tensor is actually stored in, so
        # an entry can never be returned for a different patch dtype.
        key = (
            str(self.device),
            str(audio_patch.dtype),
            patch_size,
            latent_dim,
            cache_count,
        )
        cached_patches = cache.get(key)
        if cached_patches is None:
            hop_size = int(getattr(self.model.vocoder, "hop_size", 1))
            zero_samples = cache_count * patch_size * hop_size
            zero_audio = torch.zeros(
                (1, 1, zero_samples),
                device=self.device,
                dtype=torch.float32,
            )
            silence_latents = self.model.vocoder.extract_latents(zero_audio)
            # Only the first ``latent_dim`` channels are kept; the unpacking
            # requires the split to produce exactly two parts.
            silence_latents, _ = torch.split(
                silence_latents,
                latent_dim,
                dim=1,
            )
            silence_latents = silence_latents.transpose(1, 2)
            target_frames = cache_count * patch_size
            if silence_latents.size(1) < target_frames:
                # Zero-pad along the frame axis when the vocoder returned
                # fewer frames than the cache needs.
                silence_latents = torch.cat(
                    [
                        silence_latents,
                        silence_latents.new_zeros(
                            (
                                silence_latents.size(0),
                                target_frames - silence_latents.size(1),
                                silence_latents.size(2),
                            )
                        ),
                    ],
                    dim=1,
                )
            silence_latents = silence_latents[:, :target_frames, :]
            cached_patches = self.model.core.io_helper.normalize(silence_latents)
            cached_patches = cached_patches.to(
                device=self.device,
                dtype=audio_patch.dtype,
            )
            cached_patches = cached_patches.reshape(
                1,
                cache_count,
                patch_size,
                latent_dim,
            ).detach()
            cache[key] = cached_patches
            logger.info(
                "Double streaming initial silence cache built: patches={} patch_size={} "
                "hop_size={} device={} dtype={}",
                cache_count,
                patch_size,
                hop_size,
                self.device,
                audio_patch.dtype,
            )
        # clone() keeps callers from mutating the shared cache entry.
        return cached_patches[:, int(patch_index)].clone()

    def _consume_audio_patch(self, audio_patch: torch.Tensor) -> None:
        """Feed a decoded audio patch back into the generation state."""
        self.model._consume_audio_patch(self._state, audio_patch=audio_patch)

    def _decode_audio_chunk(
        self, *, continue_audio_span: bool = False
    ) -> torch.Tensor | None:
        """Decode one audio patch and run it through the streaming vocoder.

        The first ``_initial_silence_audio_tokens`` decoded patches are
        replaced by cached silence patches before being consumed. When the
        stop check fires, the state's end flag is set after this patch.

        Args:
            continue_audio_span: When true, the state's current hidden
                states are appended as a hidden chunk after the patch is
                consumed.

        Returns:
            The vocoder output for this patch, or ``None`` when it is empty.

        Raises:
            RuntimeError: If ``max_generate_length`` patches were already
                decoded.
        """
        if self._decoded_patch_count >= self.max_generate_length:
            raise RuntimeError(
                "Double streaming exceeded max_generate_length before reaching EOS."
            )

        with torch.no_grad():
            with torch.autocast(
                device_type=self.device.type,
                dtype=self._dtype,
                enabled=self._use_amp,
            ):
                # The stop decision is taken before the state is advanced by
                # the patch decoded below.
                stop_after_current_audio = self.model._should_stop_after_current_audio(
                    self._state,
                    eos_threshold=self.eos_threshold,
                )
                audio_patch = self.model._decode_next_audio(
                    self._state,
                    device=self.device,
                    g_cond=self._g_cond,
                    ode_method=self.ode_method,
                    num_steps=self.num_steps,
                    guidance_scale=self.guidance_scale,
                )
                if self._decoded_patch_count < self._initial_silence_audio_tokens:
                    audio_patch = self._get_initial_silence_audio_patch(
                        self._decoded_patch_count,
                        audio_patch,
                    )
                self._consume_audio_patch(audio_patch)
                if continue_audio_span:
                    self.model._append_hidden_chunk(
                        self._state, self._state.llm_hiddens
                    )
                self._decoded_patch_count += 1
                latent_patch = self.model.core.io_helper.denormalize(audio_patch)
                audio_chunk = self.model.vocoder.stream_step(
                    latent_patch.transpose(1, 2),
                    self._vocoder_state,
                )
                if stop_after_current_audio:
                    self._state.end_flag = True

        if audio_chunk.size(-1) == 0:
            return None
        return audio_chunk
+
+
class DotsTtsRuntimeDoubleStreaming(DotsTtsRuntime):
    """Runtime variant that can open :class:`DoubleStreamingSession` objects."""

    def start_double_streaming(
        self,
        *,
        prompt_audio_path: str | None = None,
        prompt_text: str | None = None,
        ode_method: str = "euler",
        num_steps: int = 10,
        guidance_scale: float = 1.2,
        speaker_scale: float = 1.5,
        eos_threshold: float = 0.8,
        initial_silence_audio_tokens: int = 1,
    ) -> DoubleStreamingSession:
        """Create a double-streaming session bound to this runtime.

        Every argument is forwarded unchanged to
        :class:`DoubleStreamingSession`.
        """
        session_options = {
            "prompt_audio_path": prompt_audio_path,
            "prompt_text": prompt_text,
            "ode_method": ode_method,
            "num_steps": num_steps,
            "guidance_scale": guidance_scale,
            "speaker_scale": speaker_scale,
            "eos_threshold": eos_threshold,
            "initial_silence_audio_tokens": initial_silence_audio_tokens,
        }
        return DoubleStreamingSession(self, **session_options)
+
+
+__all__ = ["DotsTtsRuntimeDoubleStreaming", "DoubleStreamingSession"]
diff --git a/src/dots_tts/training/__init__.py b/src/dots_tts/training/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4d0dfdc6a265d0c56e3bff22522478b338c5ba3
--- /dev/null
+++ b/src/dots_tts/training/__init__.py
@@ -0,0 +1 @@
+"""Training package."""
diff --git a/src/dots_tts/training/checkpoint.py b/src/dots_tts/training/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c448e8848183aa56831d28947a0efe346aa5c6d
--- /dev/null
+++ b/src/dots_tts/training/checkpoint.py
@@ -0,0 +1,315 @@
+"""Checkpoint helpers for distributed dots_tts training.
+
+This module persists not only model/optimizer/scheduler state, but also
+rank-local RNG state and data-loader progress so resumed training can continue
+from the same point with minimal drift.
+"""
+
+from __future__ import annotations
+
+import json
+import random
+import shutil
+from dataclasses import fields
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+
def _checkpoint_dir(log_dir: str, step: int) -> Path:
    """Build the checkpoint directory path for ``step`` under ``log_dir``.

    The step number is zero-padded to at least eight digits.
    """
    directory_name = "checkpoint-{:08d}".format(step)
    return Path(log_dir).joinpath(directory_name)
+
+
def _checkpoint_entries(log_dir: str) -> list[tuple[int, Path]]:
    """List valid ``checkpoint-*`` directories sorted by step number.

    Entries that are not directories, or whose suffix after ``checkpoint-``
    is not a plain decimal number, are ignored.

    Returns:
        ``(step, path)`` pairs in ascending step order.
    """
    entries: list[tuple[int, Path]] = []
    for path in Path(log_dir).glob("checkpoint-*"):
        if not path.is_dir():
            continue
        suffix = path.name.removeprefix("checkpoint-")
        # isdecimal() matches what int() can parse; isdigit() would also
        # accept characters such as superscript digits and make int() raise.
        if suffix.isdecimal():
            entries.append((int(suffix), path))
    return sorted(entries)
+
+
def resolve_latest_train_checkpoint(log_dir: str) -> Path:
    """Resolve the checkpoint directory that should be used for resume.

    Preference order:
    1. The ``latest`` entry inside ``log_dir``, if present.
    2. The numerically largest ``checkpoint-*`` directory.

    Raises:
        FileNotFoundError: If neither exists, or if ``latest`` is a symlink
            whose target is missing.
    """
    latest_link = Path(log_dir) / "latest"
    # is_symlink() also catches a dangling link, which exists() reports as
    # missing; the strict resolve then fails instead of silently falling back.
    if latest_link.is_symlink() or latest_link.exists():
        return latest_link.resolve(strict=True)

    candidates = _checkpoint_entries(log_dir)
    if candidates:
        _, newest_dir = candidates[-1]
        return newest_dir.resolve(strict=True)
    raise FileNotFoundError(
        f"No checkpoint found under {log_dir!s}; expected latest or checkpoint-*."
    )
+
+
def _rng_state() -> dict:
    """Snapshot the Python, NumPy and PyTorch RNG state for a later resume.

    The NumPy state is converted to plain Python values. A ``"cuda"`` entry
    with the state of every CUDA device is added only when CUDA is available.
    """
    algorithm, keys, position, has_gauss, cached_gaussian = np.random.get_state()
    snapshot = {
        "torch": torch.get_rng_state(),
        "python": random.getstate(),
        "numpy": {
            "bit_generator": str(algorithm),
            "keys": keys.tolist(),
            "pos": int(position),
            "has_gauss": int(has_gauss),
            "cached_gaussian": float(cached_gaussian),
        },
    }
    if torch.cuda.is_available():
        snapshot["cuda"] = torch.cuda.get_rng_state_all()
    return snapshot
+
+
def _restore_rng_state(state: dict) -> None:
    """Restore RNG state previously produced by :func:`_rng_state`.

    Args:
        state: Mapping with the keys ``"torch"``, ``"python"`` and
            ``"numpy"``, plus an optional ``"cuda"`` entry.
    """
    torch.set_rng_state(state["torch"])
    random.setstate(state["python"])
    numpy_state = state["numpy"]
    np.random.set_state(
        (
            numpy_state["bit_generator"],
            np.asarray(numpy_state["keys"], dtype=np.uint32),
            int(numpy_state["pos"]),
            int(numpy_state["has_gauss"]),
            float(numpy_state["cached_gaussian"]),
        )
    )
    # The "cuda" entry is only written when CUDA was available at save time,
    # so a checkpoint from a CPU-only run must not fail on a CUDA machine.
    if torch.cuda.is_available() and "cuda" in state:
        torch.cuda.set_rng_state_all(state["cuda"])
+
+
def _replace_latest_symlink(log_dir: str, save_dir: Path) -> None:
    """Refresh the ``latest`` symlink so it points at ``save_dir``.

    The new link is created under a temporary name and then moved over
    ``latest`` with a replacing rename, so an existing link or file is
    swapped in one step instead of being deleted first. The link target is
    the relative directory name ``save_dir.name``.

    Args:
        log_dir: Directory that holds the checkpoints and the ``latest`` link.
        save_dir: Checkpoint directory the link should point at.
    """
    log_path = Path(log_dir)
    link_path = log_path / "latest"
    tmp_link_path = log_path / "latest.tmp"

    # Clear a leftover temporary link from an interrupted earlier call.
    if tmp_link_path.exists() or tmp_link_path.is_symlink():
        tmp_link_path.unlink()
    tmp_link_path.symlink_to(save_dir.name)

    # A real directory cannot be overwritten by a rename, so it is the only
    # case that still needs an explicit removal beforehand.
    if link_path.is_dir() and not link_path.is_symlink():
        shutil.rmtree(link_path)
    tmp_link_path.replace(link_path)
+
+
def _cleanup_old_checkpoints(log_dir: str, keep_max: int) -> None:
    """Remove all but the ``keep_max`` highest-numbered checkpoints.

    Nothing is deleted when ``keep_max`` is zero or negative. Errors during
    removal are ignored.
    """
    if keep_max > 0:
        stale_entries = _checkpoint_entries(log_dir)[:-keep_max]
        for _step, stale_dir in stale_entries:
            shutil.rmtree(stale_dir, ignore_errors=True)
+
+
def _pack_rank_payload(accelerator, payload: dict, *, payload_name: str) -> dict | None:
    """Gather each rank's payload and pack the result on the main process.

    State such as RNG state or data-loader progress differs per rank, so the
    checkpoint stores it as ``{"world_size": ..., "per_rank": {rank: payload}}``
    with the rank as a string key.

    Returns:
        The packed mapping on the main process, ``None`` on every other rank.

    Raises:
        RuntimeError: On the main process, if a gathered item is not a dict.
    """
    rank_record = {
        "rank": int(accelerator.process_index),
        "payload": payload,
    }
    if not (dist.is_available() and dist.is_initialized()):
        gathered: list[dict | None] = [rank_record]
    else:
        # Every rank has to take part in the collective, which is why the
        # gather runs before the main-process check below.
        gathered = [None] * int(accelerator.num_processes)
        dist.all_gather_object(gathered, rank_record)

    if not accelerator.is_main_process:
        return None

    per_rank = {}
    for record in gathered:
        if not isinstance(record, dict):
            raise RuntimeError(
                f"Failed to gather rank-scoped {payload_name} for checkpointing."
            )
        rank_key = str(int(record["rank"]))
        per_rank[rank_key] = record["payload"]
    return {
        "world_size": len(gathered),
        "per_rank": per_rank,
    }
+
+
def _extract_rank_payload(
    accelerator, payload: dict | None, *, payload_name: str
) -> dict:
    """Return the current rank's entry from a packed per-rank payload.

    Returns:
        An empty dict when ``payload`` is ``None``, otherwise the entry
        stored for this process index.

    Raises:
        RuntimeError: If the stored world size differs from the current
            number of processes, or if this rank has no entry.
    """
    if payload is None:
        return {}

    current_world_size = int(accelerator.num_processes)
    stored_world_size = int(payload["world_size"])
    if stored_world_size != current_world_size:
        raise RuntimeError(
            f"Checkpoint {payload_name} payload does not match the current world."
        )

    per_rank = payload["per_rank"]
    rank_key = str(int(accelerator.process_index))
    if rank_key in per_rank:
        return per_rank[rank_key]
    raise RuntimeError(f"Checkpoint {payload_name} is missing rank {rank_key}.")
+
+
def save_train_checkpoint(
    accelerator,
    model,
    optimizer,
    progress,
    log_dir: str,
    keep_max: int,
    data_state: dict,
    scheduler_state: dict,
) -> None:
    """Write a resumable checkpoint for the current training step.

    All ranks contribute their data-pipeline and RNG state; only the main
    process writes files. The checkpoint contains:
    - ``model/``: weights written by ``save_pretrained``
    - ``optimizer.pt``, ``scheduler.pt``, ``scaler.pt``
    - ``rng_state.pt`` and ``data_state.pt``: per-rank payloads
    - ``trainer_state.json``: the integer fields of ``progress``

    After a successful write the ``latest`` symlink is refreshed and older
    checkpoints beyond ``keep_max`` are removed. If writing fails, the
    temporary directory is deleted and the error is re-raised.
    """
    accelerator.wait_for_everyone()
    # Both gathers are collectives and therefore run on every rank.
    packed_data_state = _pack_rank_payload(
        accelerator,
        data_state,
        payload_name="data_state",
    )
    packed_rng_state = _pack_rank_payload(
        accelerator,
        _rng_state(),
        payload_name="rng_state",
    )

    if not accelerator.is_main_process:
        accelerator.wait_for_everyone()
        return

    unwrapped_model = accelerator.unwrap_model(model)
    final_dir = _checkpoint_dir(log_dir, progress.global_step)
    staging_dir = final_dir.with_name(f"{final_dir.name}.tmp")
    weights_dir = staging_dir / "model"
    grad_scaler = getattr(accelerator, "scaler", None)

    # Start from a clean staging directory in case an earlier save was cut short.
    if staging_dir.exists():
        shutil.rmtree(staging_dir)
    weights_dir.mkdir(parents=True, exist_ok=True)

    try:
        # Everything is written to the staging directory first and renamed at
        # the end, so an interrupted save never looks like a valid checkpoint.
        unwrapped_model.save_pretrained(weights_dir)

        torch.save(optimizer.state_dict(), staging_dir / "optimizer.pt")
        torch.save(scheduler_state, staging_dir / "scheduler.pt")
        scaler_state = {} if grad_scaler is None else grad_scaler.state_dict()
        torch.save(scaler_state, staging_dir / "scaler.pt")
        torch.save(packed_rng_state, staging_dir / "rng_state.pt")
        torch.save(packed_data_state, staging_dir / "data_state.pt")

        trainer_state = {
            field.name: int(getattr(progress, field.name))
            for field in fields(progress)
        }
        trainer_state_json = json.dumps(trainer_state, ensure_ascii=True, indent=2)
        (staging_dir / "trainer_state.json").write_text(
            trainer_state_json,
            encoding="utf-8",
        )

        if final_dir.exists():
            shutil.rmtree(final_dir)
        staging_dir.rename(final_dir)
        _replace_latest_symlink(log_dir, final_dir)
        _cleanup_old_checkpoints(log_dir, keep_max)
        accelerator.print(f"Checkpoint saved: {final_dir}")
    except Exception:
        if staging_dir.exists():
            shutil.rmtree(staging_dir, ignore_errors=True)
        raise

    accelerator.wait_for_everyone()
+
+
def load_train_checkpoint(
    accelerator,
    model,
    optimizer,
    progress,
    checkpoint_dir: str | Path,
    scheduler,
) -> dict:
    """Restore a checkpoint written by :func:`save_train_checkpoint`.

    Model weights, optimizer, scheduler, scaler (when both a scaler and a
    non-empty saved state exist), this rank's RNG state and the fields of
    ``progress`` are restored in place.

    Returns:
        A dict with ``checkpoint_dir`` (as a ``Path``), ``data_state`` (this
        rank's data-pipeline state) and ``scheduler_state`` (the full saved
        scheduler payload).

    Raises:
        FileNotFoundError: If the checkpoint has no ``model`` directory.
    """
    root = Path(checkpoint_dir)
    weights_dir = root / "model"
    if not weights_dir.is_dir():
        raise FileNotFoundError(f"Checkpoint model directory not found: {weights_dir!s}")

    def read_artifact(file_name: str):
        # Every artifact is first loaded onto the CPU.
        return torch.load(root / file_name, map_location="cpu")

    accelerator.wait_for_everyone()

    accelerator.unwrap_model(model).load_pretrained_weights(weights_dir)
    optimizer.load_state_dict(read_artifact("optimizer.pt"))

    scheduler_payload = read_artifact("scheduler.pt")
    scheduler.load_state_dict(scheduler_payload["state_dict"])

    grad_scaler = getattr(accelerator, "scaler", None)
    scaler_state = read_artifact("scaler.pt")
    if grad_scaler is not None and scaler_state:
        grad_scaler.load_state_dict(scaler_state)

    local_rng_state = _extract_rank_payload(
        accelerator,
        read_artifact("rng_state.pt"),
        payload_name="rng_state",
    )
    _restore_rng_state(local_rng_state)
    data_state_payload = read_artifact("data_state.pt")

    saved_progress = json.loads(
        (root / "trainer_state.json").read_text(encoding="utf-8")
    )
    for field in fields(progress):
        setattr(progress, field.name, int(saved_progress[field.name]))

    accelerator.wait_for_everyone()
    local_data_state = _extract_rank_payload(
        accelerator,
        data_state_payload,
        payload_name="data_state",
    )
    return {
        "checkpoint_dir": root,
        "data_state": local_data_state,
        "scheduler_state": scheduler_payload,
    }
diff --git a/src/dots_tts/training/losses.py b/src/dots_tts/training/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..32b387373570cee83b91e0b913cfd13d478b4b45
--- /dev/null
+++ b/src/dots_tts/training/losses.py
@@ -0,0 +1,365 @@
+"""Loss aggregation helpers shared by training and validation.
+
+The model returns masked per-token / per-patch loss tensors. This module turns
+them into numerators/denominators for logging, combines configured loss weights,
+and provides distributed reduction helpers.
+"""
+
+from __future__ import annotations
+
+from collections import defaultdict
+from collections.abc import Iterable
+from dataclasses import dataclass
+from typing import Any, TypeAlias
+
+import torch
+import torch.distributed as dist
+
+from dots_tts.utils.util import scalar_as_float
+
+
+ @dataclass(frozen=True)
+ class LossTerm:
+     """Unreduced loss values bundled with a mask of identical shape.
+
+     The shapes are validated at construction time; a mismatch raises
+     ``ValueError`` so that later element-wise products cannot silently
+     broadcast.
+     """
+
+     # Per-element loss values, not yet summed or averaged.
+     loss: torch.Tensor
+     # Weight/validity for each position of ``loss``.
+     mask: torch.Tensor
+
+     def __post_init__(self) -> None:
+         loss_shape = tuple(self.loss.shape)
+         mask_shape = tuple(self.mask.shape)
+         if loss_shape == mask_shape:
+             return
+         raise ValueError(
+             "LossTerm expects loss and mask to have the same shape, "
+             f"but got {loss_shape} and {mask_shape}."
+         )
+
+
+ # Mapping from a loss name to its unreduced ``LossTerm`` (loss + mask).
+ LossTerms: TypeAlias = dict[str, LossTerm]
+ # Mapping from a loss name to the mask tensor alone.
+ LossMasks: TypeAlias = dict[str, torch.Tensor]
+
+
+ def _safe_average(numerator: Any, denominator: Any) -> Any:
+     """Return ``numerator / denominator``, or zero for a non-positive denominator.
+
+     Two paths are supported:
+
+     * tensor ``numerator``: the result is a tensor derived from ``numerator``
+       (``numerator * 0.0`` when the denominator is not positive). The
+       denominator may be a tensor or a plain number.
+     * anything else: both arguments are converted with ``float`` and a
+       Python float is returned.
+
+     Both paths divide by the true denominator. The tensor path used to clamp
+     the denominator to at least ``1.0`` after the zero check, which silently
+     changed the average for positive denominators below one (for example
+     fractional mask weights) and disagreed with the scalar path.
+     """
+     if isinstance(numerator, torch.Tensor):
+         denom = denominator
+         if not isinstance(denom, torch.Tensor):
+             denom = numerator.new_tensor(float(denominator))
+         # ``.item()`` reads the value on the host; it requires a one-element
+         # denominator.
+         if float(denom.detach().item()) <= 0.0:
+             # Multiply instead of returning a constant so the result is
+             # still derived from ``numerator``.
+             return numerator * 0.0
+         return numerator / denom.to(numerator.dtype)
+
+     denom = float(denominator)
+     if denom <= 0.0:
+         return 0.0
+     return float(numerator) / denom
+
+
+ def _as_weight_map(loss_config) -> dict[str, float]:
+     """Map every ``<name>_weight`` config field to a ``<name>_loss`` weight.
+
+     ``loss_config`` must expose ``model_dump()`` returning a field mapping;
+     fields whose name does not end in ``_weight`` are ignored.
+     """
+     suffix = "_weight"
+     return {
+         f"{field_name[: -len(suffix)]}_loss": float(raw_value)
+         for field_name, raw_value in loss_config.model_dump().items()
+         if field_name.endswith(suffix)
+     }
+
+
+ def accumulate_named_scalars_(
+     target: dict[str, float],
+     source: dict[str, float],
+ ) -> dict[str, float]:
+     """Add every scalar in ``source`` onto ``target`` in place, by key.
+
+     Keys missing from ``target`` start at ``0.0``, so a plain ``dict`` works
+     as well as a ``defaultdict``. ``target`` itself is returned.
+     """
+     for name, value in source.items():
+         target[name] = target.get(name, 0.0) + float(value)
+     return target
+
+
+ def to_host_named_scalars(values: dict[str, Any]) -> dict[str, float]:
+     """Run every value through ``scalar_as_float``, keeping the keys.
+
+     Used to turn scalar tensors into plain floats for logging/serialization.
+     """
+     host_values: dict[str, float] = {}
+     for key, scalar in values.items():
+         host_values[key] = scalar_as_float(scalar)
+     return host_values
+
+
+ def collapse_loss_masks(
+     loss_masks: LossMasks,
+ ) -> dict[str, Any]:
+     """Sum each mask over all of its elements, keyed by loss name."""
+     totals: dict[str, Any] = {}
+     for key, mask_tensor in loss_masks.items():
+         totals[key] = mask_tensor.sum()
+     return totals
+
+
+ def collapse_loss_terms(
+     loss_terms: LossTerms,
+     *,
+     indices: list[int] | None = None,
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
+     """Convert masked loss terms into ``sum(loss * mask)`` statistics.
+
+     Args:
+         loss_terms: Mapping from loss name to a term exposing ``loss`` and
+             ``mask`` tensors.
+         indices: Optional positions along dim 0 to keep before summing. When
+             ``None`` every position contributes.
+
+     Returns:
+         ``(numerators, normalizers)``: for each name the masked loss sum and
+         the mask sum, so the caller can aggregate them across batches or
+         ranks before taking the final average. An empty ``loss_terms``
+         yields two empty dicts, also when ``indices`` is given.
+     """
+     index = None
+     # Only peek at the first term when there is one; ``next()`` on an empty
+     # mapping would leak ``StopIteration`` to the caller.
+     if indices is not None and loss_terms:
+         first = next(iter(loss_terms.values()))
+         # The index tensor is built once, on the first term's device.
+         index = torch.tensor(indices, device=first.loss.device, dtype=torch.long)
+
+     numerators = {}
+     normalizers = {}
+     for name, term in loss_terms.items():
+         loss = term.loss
+         mask = term.mask
+         if index is not None:
+             loss = loss.index_select(0, index)
+             mask = mask.index_select(0, index)
+         # Cast the mask so the product and the normalizer share loss's dtype.
+         mask = mask.to(loss.dtype)
+         numerators[name] = (loss * mask).sum()
+         normalizers[name] = mask.sum()
+     return numerators, normalizers
+
+
+ def collapse_loss_terms_by_source(
+     loss_terms: LossTerms,
+     *,
+     source_names: list[str | None],
+ ) -> tuple[dict[str, dict[str, float]], dict[str, dict[str, float]]]:
+     """Collapse loss statistics separately for each source name in a batch.
+
+     ``source_names`` needs one non-``None`` entry per position along dim 0
+     of the loss tensors; otherwise ``RuntimeError`` is raised. The result is
+     ``(numerators_by_source, normalizers_by_source)`` holding host floats.
+     """
+     batch_size = int(next(iter(loss_terms.values())).loss.size(0))
+     if len(source_names) != batch_size:
+         raise RuntimeError(
+             "source_names must align with the batch size for source loss statistics. "
+             f"Expected {batch_size}, got {len(source_names)}."
+         )
+
+     # Group the batch positions by the source they belong to.
+     rows_by_source: dict[str, list[int]] = defaultdict(list)
+     for row, source_name in enumerate(source_names):
+         if source_name is None:
+             raise RuntimeError("source_names must not contain None.")
+         rows_by_source[str(source_name)].append(row)
+
+     numerators_by_source = {}
+     normalizers_by_source = {}
+     for source_name, rows in rows_by_source.items():
+         totals, weights = collapse_loss_terms(loss_terms, indices=rows)
+         numerators_by_source[source_name] = to_host_named_scalars(totals)
+         normalizers_by_source[source_name] = to_host_named_scalars(weights)
+     return numerators_by_source, normalizers_by_source
+
+
+ def reduce_loss_statistics(
+     numerators: dict[str, Any],
+     normalizers: dict[str, Any],
+     *,
+     loss_config,
+ ) -> dict[str, Any]:
+     """Average each aggregated loss and assemble the weighted total.
+
+     Every name in ``numerators`` is divided by its entry in ``normalizers``
+     (zero-safe). The result holds those averages plus ``"loss"``, the sum of
+     the averages scaled by the weights from ``loss_config``; names without a
+     configured weight count with weight ``1.0``.
+     """
+     weights = _as_weight_map(loss_config)
+     reduced: dict[str, Any] = {}
+     total_loss: Any = 0.0
+     # Names are visited in sorted order so the summation order is stable.
+     for name in sorted(numerators):
+         average = _safe_average(numerators[name], normalizers[name])
+         reduced[name] = average
+         total_loss = total_loss + average * weights.get(name, 1.0)
+     reduced["loss"] = total_loss
+     return reduced
+
+
+ def reduce_loss_statistics_by_source(
+     numerators_by_source: dict[str, dict[str, float]],
+     normalizers_by_source: dict[str, dict[str, float]],
+     *,
+     loss_config,
+ ) -> dict[str, dict[str, Any]]:
+     """Run :func:`reduce_loss_statistics` once per source, in sorted order."""
+     per_source: dict[str, dict[str, Any]] = {}
+     for source_name in sorted(numerators_by_source):
+         per_source[source_name] = reduce_loss_statistics(
+             numerators_by_source[source_name],
+             normalizers_by_source[source_name],
+             loss_config=loss_config,
+         )
+     return per_source
+
+
+ def compute_gradient_loss(
+     loss_terms: LossTerms,
+     *,
+     global_normalizers: dict[str, float],
+     loss_config,
+     ddp_world_size: int,
+     gradient_accumulation_steps: int,
+ ) -> torch.Tensor:
+     """Assemble the scalar loss that ``backward()`` is called on.
+
+     Each local masked loss sum is divided by the matching entry of
+     ``global_normalizers`` (expected to already be the cross-rank total),
+     weighted per ``loss_config`` and summed. The sum is then multiplied by
+     the world size and the accumulation steps to compensate for the mean
+     reduction that DDP/Accelerate applies during gradient synchronization.
+     """
+     numerators, _ = collapse_loss_terms(loss_terms)
+     weights = _as_weight_map(loss_config)
+
+     weighted_sum: Any = 0.0
+     for name in sorted(numerators):
+         term_average = _safe_average(numerators[name], global_normalizers[name])
+         weighted_sum = weighted_sum + term_average * weights.get(name, 1.0)
+     return weighted_sum * float(ddp_world_size) * float(gradient_accumulation_steps)
+
+
+ def sum_named_scalars_across_ranks(
+     values: dict[str, float],
+     *,
+     device: torch.device,
+ ) -> dict[str, float]:
+     """Sum a ``name -> scalar`` mapping across ranks into host floats.
+
+     The key set is first unioned across ranks; a key missing locally
+     contributes ``0.0``. Without an initialized process group the local
+     values are returned unchanged (as floats).
+     """
+     names = _gather_string_union_across_ranks(values, device=device)
+     if not names:
+         return {}
+
+     local_values = [float(values.get(name, 0.0)) for name in names]
+     packed = torch.tensor(local_values, device=device, dtype=torch.float64)
+     if dist.is_available() and dist.is_initialized():
+         dist.all_reduce(packed, op=dist.ReduceOp.SUM)
+     summed = packed.tolist()
+     return {name: float(total) for name, total in zip(names, summed, strict=True)}
+
+
+ def sum_grouped_named_scalars_across_ranks(
+     values: dict[str, dict[str, float]],
+     *,
+     device: torch.device,
+ ) -> dict[str, dict[str, float]]:
+     """Sum a nested ``group -> metric -> value`` mapping across ranks.
+
+     Group names and metric names are unioned across ranks first, so the
+     result is a dense table: every returned group carries every metric, with
+     ``0.0`` contributed wherever a rank had no local value.
+     """
+     group_names = _gather_string_union_across_ranks(values, device=device)
+     if not group_names:
+         return {}
+
+     local_metric_names = (
+         metric_name for group_values in values.values() for metric_name in group_values
+     )
+     metric_names = _gather_string_union_across_ranks(local_metric_names, device=device)
+     if not metric_names:
+         return {group_name: {} for group_name in group_names}
+
+     # Flatten the table row by row (one row per group) for one all-reduce.
+     flat_values: list[float] = []
+     for group_name in group_names:
+         group_values = values.get(group_name, {})
+         flat_values.extend(
+             float(group_values.get(metric_name, 0.0)) for metric_name in metric_names
+         )
+     packed = torch.tensor(flat_values, device=device, dtype=torch.float64)
+     if dist.is_available() and dist.is_initialized():
+         dist.all_reduce(packed, op=dist.ReduceOp.SUM)
+
+     table = packed.view(len(group_names), len(metric_names))
+     summed: dict[str, dict[str, float]] = {}
+     for group_index, group_name in enumerate(group_names):
+         summed[group_name] = {
+             metric_name: float(table[group_index, metric_index].item())
+             for metric_index, metric_name in enumerate(metric_names)
+         }
+     return summed
+
+
+ def accumulate_grouped_named_scalars_(
+     target: dict[str, dict[str, float]],
+     source: dict[str, dict[str, float]],
+ ) -> dict[str, dict[str, float]]:
+     """Add a nested ``group -> metric -> value`` mapping onto ``target`` in place.
+
+     Groups and metrics missing from ``target`` start at ``0.0``. This covers
+     a metric that first appears for a group ``target`` already tracks, which
+     used to raise ``KeyError``. ``target`` itself is returned.
+     """
+     for group_name, values in source.items():
+         group_target = target.get(group_name)
+         if group_target is None:
+             group_target = {}
+             target[group_name] = group_target
+         for name, value in values.items():
+             group_target[name] = group_target.get(name, 0.0) + float(value)
+     return target
+
+
+ def _gather_string_union_across_ranks(
+ values: Iterable[str],
+ *,
+ device: torch.device,
+ ) -> list[str]:
+ """Return the sorted union of ``values`` over all distributed ranks.
+
+ Local values are stringified, de-duplicated and sorted. Without an
+ initialized process group that local list is returned directly. Otherwise
+ each rank's list is serialized to bytes and exchanged with two
+ ``all_gather`` calls (payload sizes first, then the padded payloads).
+ """
+ strings = sorted({str(value) for value in values})
+ if not (dist.is_available() and dist.is_initialized()):
+ return strings
+
+ payload = _encode_string_list(strings)
+ world_size = dist.get_world_size()
+ # First collective: exchange payload lengths so buffers can be padded alike.
+ local_size = torch.tensor([len(payload)], device=device, dtype=torch.int64)
+ size_tensors = [torch.zeros_like(local_size) for _ in range(world_size)]
+ dist.all_gather(size_tensors, local_size)
+
+ max_size = max(int(size.item()) for size in size_tensors)
+ # The gathered sizes are the same on every rank, so this early return is
+ # taken by all ranks together and the second collective is skipped by all.
+ if max_size <= 0:
+ return []
+
+ # Zero-padded to ``max_size``; the true length travels in ``size_tensors``.
+ local_bytes = torch.zeros(max_size, device=device, dtype=torch.uint8)
+ if payload:
+ local_bytes[: len(payload)] = torch.tensor(
+ list(payload),
+ device=device,
+ dtype=torch.uint8,
+ )
+
+ # Second collective: exchange the padded payloads themselves.
+ gathered_bytes = [
+ torch.empty(max_size, device=device, dtype=torch.uint8)
+ for _ in range(world_size)
+ ]
+ dist.all_gather(gathered_bytes, local_bytes)
+
+ union = set(strings)
+ for size_tensor, byte_tensor in zip(size_tensors, gathered_bytes, strict=True):
+ size = int(size_tensor.item())
+ if size <= 0:
+ continue
+ # Trim the padding before decoding this rank's payload.
+ union.update(
+ _decode_string_list(bytes(byte_tensor[:size].cpu().tolist()))
+ )
+ return sorted(union)
+
+
+ def _encode_string_list(values: list[str]) -> bytes:
+     """Serialize strings as NUL-separated UTF-8 bytes.
+
+     Raises ``ValueError`` if any string contains a NUL character, because
+     NUL is the separator.
+     """
+     for value in values:
+         if "\0" in value:
+             raise ValueError("Distributed scalar keys must not contain NUL characters.")
+     joined = "\0".join(values)
+     return joined.encode("utf-8")
+
+
+ def _decode_string_list(payload: bytes) -> list[str]:
+     """Split NUL-separated UTF-8 bytes back into strings.
+
+     An empty payload decodes to an empty list.
+     """
+     return payload.decode("utf-8").split("\0") if payload else []
+
+
+ # Public names of this module; the underscore-prefixed helpers above are
+ # deliberately not listed.
+ __all__ = [
+ "accumulate_grouped_named_scalars_",
+ "LossMasks",
+ "LossTerm",
+ "LossTerms",
+ "accumulate_named_scalars_",
+ "collapse_loss_masks",
+ "collapse_loss_terms",
+ "collapse_loss_terms_by_source",
+ "compute_gradient_loss",
+ "reduce_loss_statistics",
+ "reduce_loss_statistics_by_source",
+ "sum_grouped_named_scalars_across_ranks",
+ "sum_named_scalars_across_ranks",
+ "to_host_named_scalars",
+ ]
diff --git a/src/dots_tts/training/utils.py b/src/dots_tts/training/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..74bbad852efd4cafd2abb4ecb6c8e525eb845a62
--- /dev/null
+++ b/src/dots_tts/training/utils.py
@@ -0,0 +1,650 @@
+"""Shared helpers for the dots_tts training entrypoints."""
+
+from __future__ import annotations
+
+import math
+import os
+import sys
+import traceback
+from collections import Counter
+from dataclasses import dataclass, fields, is_dataclass
+from typing import Any
+
+import torch
+import torch.distributed as dist
+
+from dots_tts.training import losses as loss_ops
+
+# ---------------------------------------------------------------------------
+# Training State
+# ---------------------------------------------------------------------------
+
+
+ @dataclass(slots=True)
+ class TrainProgress:
+ """Minimal progress counters that must survive checkpoint save/load.
+
+ Every counter starts at zero. The logging helpers in this module report
+ the three token counters as ``consumed_*`` values.
+ """
+
+ # Shown as ``iteration <global_step>/<max_train_steps>`` in the console line.
+ global_step: int = 0
+ # Logged as ``train/epoch``.
+ epoch: int = 0
+ # Logged as ``train/consumed_tokens``.
+ total_tokens: int = 0
+ # Logged as ``train/consumed_audio_tokens``.
+ audio_tokens: int = 0
+ # Logged as ``train/consumed_text_tokens``.
+ text_tokens: int = 0
+
+
+ @dataclass(slots=True)
+ class TrainStepReport:
+ """Both renderings of one logged training step."""
+
+ # Flat metric mapping; filled from ``build_train_log_dict``.
+ log_values: dict[str, float]
+ # Human-readable summary; filled from ``format_train_line``.
+ console_line: str
+
+
+# ---------------------------------------------------------------------------
+# Distributed Helpers
+# ---------------------------------------------------------------------------
+
+
+ def any_rank_true(flag: bool, *, device: torch.device) -> bool:
+     """Tell whether ``flag`` is true on at least one rank.
+
+     The flag is max-reduced across ranks when a process group is
+     initialized; otherwise the local value is returned.
+     """
+     flag_tensor = torch.tensor(int(flag), device=device, dtype=torch.int32)
+     distributed = dist.is_available() and dist.is_initialized()
+     if distributed:
+         dist.all_reduce(flag_tensor, op=dist.ReduceOp.MAX)
+     return bool(flag_tensor.item())
+
+
+ def sum_integer_counters_across_ranks(
+     values: list[int],
+     *,
+     device: torch.device,
+ ) -> list[int]:
+     """Sum a list of integer counters element-wise across ranks.
+
+     Without an initialized process group the local counters come back
+     unchanged (as Python ints).
+     """
+     totals = torch.tensor(values, device=device, dtype=torch.int64)
+     if dist.is_available() and dist.is_initialized():
+         dist.all_reduce(totals, op=dist.ReduceOp.SUM)
+     return list(map(int, totals.tolist()))
+
+
+ def move_to_device(value, device):
+     """Recursively move nested tensors/dataclasses onto ``device``.
+
+     Tensors are moved with ``non_blocking=True``. Dicts, lists, tuples and
+     dataclass instances are rebuilt with their members moved; anything else
+     is returned as is. Namedtuples keep their type and field names instead
+     of degrading to a plain ``tuple``.
+     """
+     if isinstance(value, torch.Tensor):
+         return value.to(device, non_blocking=True)
+     if isinstance(value, dict):
+         return {key: move_to_device(item, device) for key, item in value.items()}
+     if isinstance(value, list):
+         return [move_to_device(item, device) for item in value]
+     if isinstance(value, tuple):
+         moved = [move_to_device(item, device) for item in value]
+         if hasattr(value, "_fields"):
+             # Namedtuple: rebuild positionally so attribute access survives.
+             return type(value)(*moved)
+         return tuple(moved)
+     if is_dataclass(value) and not isinstance(value, type):
+         # Instances only; a dataclass *class* falls through unchanged.
+         return type(value)(
+             **{
+                 field.name: move_to_device(getattr(value, field.name), device)
+                 for field in fields(value)
+             }
+         )
+     return value
+
+
+# ---------------------------------------------------------------------------
+# Failure Handling
+# ---------------------------------------------------------------------------
+
+
+ def abort_on_out_of_memory(
+     exc: BaseException,
+     *,
+     stage: str,
+     batch: dict[str, object] | None,
+     progress: TrainProgress,
+     device: torch.device,
+     process_index: int,
+     num_processes: int,
+ ) -> None:
+     """Report an out-of-memory failure and hard-exit multi-process runs.
+
+     Does nothing unless ``exc`` is recognized as an out-of-memory error.
+     Otherwise a diagnostic line and the traceback are written to stderr;
+     with more than one process the interpreter is then terminated through
+     ``os._exit(1)``. In a single-process run the function returns normally.
+     """
+     if not _is_out_of_memory_error(exc):
+         return
+
+     batch_summary = _build_batch_memory_summary(batch)
+     memory_summary = _build_cuda_memory_summary(device)
+     location = (
+         f"stage={stage}, "
+         f"epoch={progress.epoch}, "
+         f"global_step={progress.global_step}, "
+         f"rank={process_index}/{num_processes}"
+     )
+     message = (
+         f"Fatal out-of-memory during training. {location}. "
+         f"{batch_summary}. {memory_summary}."
+     )
+     print(message, file=sys.stderr, flush=True)
+     traceback.print_exception(type(exc), exc, exc.__traceback__, file=sys.stderr)
+     sys.stderr.flush()
+
+     if num_processes > 1:
+         # ``os._exit`` ends the process immediately, skipping normal
+         # interpreter shutdown.
+         os._exit(1)
+
+
+ def _is_out_of_memory_error(exc: BaseException) -> bool:
+     """Recognize out-of-memory failures by type or by message.
+
+     Matches ``torch.OutOfMemoryError`` when this torch build defines it, and
+     any ``RuntimeError`` whose message contains "out of memory"
+     (case-insensitive).
+     """
+     oom_cls = getattr(torch, "OutOfMemoryError", None)
+     if oom_cls is not None and isinstance(exc, oom_cls):
+         return True
+     return isinstance(exc, RuntimeError) and "out of memory" in str(exc).lower()
+
+
+ def _build_batch_memory_summary(batch: dict[str, object] | None) -> str:
+     """Summarize the size-related tensors of ``batch`` for an OOM report.
+
+     Reports the shapes of ``input_ids`` and ``sample`` and the maxima of the
+     length/count tensors, skipping entries that are missing, not tensors, or
+     (for the maxima) empty. Returns ``"batch=unavailable"`` when nothing can
+     be reported.
+     """
+     if not isinstance(batch, dict):
+         return "batch=unavailable"
+
+     parts: list[str] = []
+     for key in ("input_ids", "sample"):
+         tensor = batch.get(key)
+         if isinstance(tensor, torch.Tensor):
+             parts.append(f"{key}_shape={tuple(tensor.shape)}")
+
+     maxima = (
+         ("input_ids_lengths", "max_input_ids_length"),
+         ("num_audio_tokens", "max_audio_tokens"),
+         ("num_text_tokens", "max_text_tokens"),
+     )
+     for key, label in maxima:
+         tensor = batch.get(key)
+         if isinstance(tensor, torch.Tensor) and tensor.numel() > 0:
+             parts.append(f"{label}={int(tensor.max().detach().item())}")
+
+     return ", ".join(parts) if parts else "batch=unavailable"
+
+
+ def _build_cuda_memory_summary(device: torch.device) -> str:
+     """Render current and peak CUDA allocator readings for ``device``.
+
+     Byte counts are divided by ``1024**3`` and shown with two decimals.
+     Returns ``"device_memory=unavailable"`` for non-CUDA devices or when
+     CUDA is not available.
+     """
+     if device.type != "cuda" or not torch.cuda.is_available():
+         return "device_memory=unavailable"
+
+     bytes_per_unit = 1024**3
+     readings = {
+         "allocated_gb": torch.cuda.memory_allocated(device),
+         "reserved_gb": torch.cuda.memory_reserved(device),
+         "max_allocated_gb": torch.cuda.max_memory_allocated(device),
+         "max_reserved_gb": torch.cuda.max_memory_reserved(device),
+     }
+     rendered = ", ".join(
+         f"{label}={num_bytes / bytes_per_unit:.2f}"
+         for label, num_bytes in readings.items()
+     )
+     return f"device={device}, {rendered}"
+
+
+# ---------------------------------------------------------------------------
+# Debug Helpers
+# ---------------------------------------------------------------------------
+
+
+ def build_data_debug_lines(
+ batch: dict[str, object],
+ *,
+ batch_index: int,
+ tokenizer: Any,
+ sample_rate: int,
+ ) -> list[str]:
+ """Build ``[debug:data]`` lines describing one batch.
+
+ Requires the tensor entries ``input_ids``, ``input_ids_lengths``,
+ ``sample``, ``sample_lengths``, ``num_audio_tokens`` and
+ ``num_text_tokens`` (``KeyError`` if absent, ``TypeError`` if not tensors).
+ ``source_names``, ``fbank``, ``fbank_lengths``, ``loss_masks`` and ``fids``
+ are optional and only reported when present.
+
+ The output holds batch-level shape and length statistics followed by one
+ line for each of the first (at most three) samples, including the decoded
+ text of that sample.
+ """
+ input_ids = batch["input_ids"]
+ input_ids_lengths = batch["input_ids_lengths"]
+ sample = batch["sample"]
+ sample_lengths = batch["sample_lengths"]
+ num_audio_tokens = batch["num_audio_tokens"]
+ num_text_tokens = batch["num_text_tokens"]
+
+ # Validate up front so the formatting below can rely on tensor methods.
+ if not isinstance(input_ids, torch.Tensor) or not isinstance(
+ input_ids_lengths, torch.Tensor
+ ):
+ raise TypeError("Debug batch requires tensor input_ids and input_ids_lengths.")
+ if not isinstance(sample, torch.Tensor) or not isinstance(
+ sample_lengths, torch.Tensor
+ ):
+ raise TypeError("Debug batch requires tensor sample and sample_lengths.")
+ if not isinstance(num_audio_tokens, torch.Tensor) or not isinstance(
+ num_text_tokens, torch.Tensor
+ ):
+ raise TypeError(
+ "Debug batch requires tensor num_audio_tokens and num_text_tokens."
+ )
+
+ source_names = batch.get("source_names")
+ # Two batch-level lines: shapes plus a per-source count, then
+ # min/mean/max statistics of the length tensors.
+ debug_lines = [
+ (
+ "[debug:data] "
+ f"batch_index={batch_index} "
+ f"batch_size={int(input_ids.size(0))} "
+ f"input_ids_shape={tuple(input_ids.shape)} "
+ f"sample_shape={tuple(sample.shape)} "
+ f"sample_rate={sample_rate} "
+ f"sources={dict(Counter(source_names or []))}"
+ ),
+ (
+ "[debug:data] "
+ f"input_tokens(min/mean/max)={_format_tensor_triplet(input_ids_lengths)} "
+ f"text_tokens(min/mean/max)={_format_tensor_triplet(num_text_tokens)} "
+ f"audio_tokens(min/mean/max)={_format_tensor_triplet(num_audio_tokens)} "
+ f"audio_seconds(min/mean/max)={_format_audio_seconds_triplet(sample_lengths, sample_rate)}"
+ ),
+ ]
+
+ fbank = batch.get("fbank")
+ fbank_lengths = batch.get("fbank_lengths")
+ if isinstance(fbank, torch.Tensor):
+ # ``_format_tensor_triplet`` prints "n/a" if ``fbank_lengths`` is not a tensor.
+ debug_lines.append(
+ "[debug:data] "
+ f"fbank_shape={tuple(fbank.shape)} "
+ f"fbank_frames(min/mean/max)={_format_tensor_triplet(fbank_lengths)}"
+ )
+
+ loss_masks = batch.get("loss_masks")
+ if isinstance(loss_masks, dict):
+ # One ``name:positive/total`` entry per mask, in sorted name order.
+ debug_lines.append(
+ "[debug:data] "
+ "loss_masks="
+ + ", ".join(
+ f"{name}:{_format_mask_density(mask)}"
+ for name, mask in sorted(loss_masks.items())
+ )
+ )
+
+ fids = batch.get("fids") or []
+ # Per-sample detail is capped at the first three samples of the batch.
+ sample_count = min(int(input_ids.size(0)), 3)
+ for sample_idx in range(sample_count):
+ input_length = int(input_ids_lengths[sample_idx].item())
+ audio_length = int(sample_lengths[sample_idx].item())
+ fbank_shape = "unavailable"
+ if isinstance(fbank, torch.Tensor) and isinstance(fbank_lengths, torch.Tensor):
+ fbank_shape = (
+ f"({int(fbank_lengths[sample_idx].item())}, {int(fbank.size(-1))})"
+ )
+ # Missing fids fall back to ``sample_NN``; only the first
+ # ``input_length`` token ids of the row are decoded.
+ # NOTE(review): ``source_names[sample_idx]`` assumes a non-empty
+ # ``source_names`` has one entry per sample; confirm with the collator.
+ debug_lines.append(
+ "[debug:data] "
+ f"sample_index={sample_idx} "
+ f"fid={str(fids[sample_idx]) if sample_idx < len(fids) else f'sample_{sample_idx:02d}'} "
+ f"source_name={source_names[sample_idx] if source_names else None} "
+ f"input_ids_shape=({input_length},) "
+ f"sample_shape=(1, {audio_length}) "
+ f"fbank_shape={fbank_shape} "
+ f"num_text_tokens={int(num_text_tokens[sample_idx].item())} "
+ f"num_audio_tokens={int(num_audio_tokens[sample_idx].item())} "
+ f"audio_seconds={audio_length / float(sample_rate):.2f} "
+ "text="
+ f"{tokenizer.decode(input_ids[sample_idx, :input_length].detach().cpu().tolist(), skip_special_tokens=False, clean_up_tokenization_spaces=False)!r}"
+ )
+ return debug_lines
+
+
+ def should_print_gradient_debug(
+     *,
+     debug_enabled: bool,
+     is_main_process: bool,
+     next_global_step: int,
+     log_interval: int,
+     early_step_limit: int,
+ ) -> bool:
+     """Decide whether gradient debug output is due for the upcoming step.
+
+     Output is only produced on the main process with debugging enabled, and
+     then for every step up to ``early_step_limit`` plus every step that is a
+     multiple of ``log_interval``.
+     """
+     if not (debug_enabled and is_main_process):
+         return False
+     if next_global_step <= early_step_limit:
+         return True
+     return next_global_step % log_interval == 0
+
+
+ def build_gradient_debug_lines(
+ model: torch.nn.Module,
+ *,
+ global_step: int,
+ grad_norm: float,
+ grad_clip_norm: float,
+ ) -> list[str]:
+ """Build ``[debug:grad]`` lines summarizing the gradients held by ``model``.
+
+ Only parameters with ``requires_grad`` are inspected. The first line
+ reports the given ``grad_norm``, its ratio to ``grad_clip_norm``, how many
+ parameters have or lack a gradient, how many gradients contain non-finite
+ values, and the max/mean absolute gradient over all elements. Optional
+ further lines list the six parameters with the largest gradient norm and
+ up to eight parameter names with non-finite gradients.
+ """
+ top_param_candidates: list[tuple[str, float, float, float]] = []
+ nonfinite_grad_params: list[str] = []
+ nonfinite_param_count = 0
+ params_with_grad = 0
+ params_without_grad = 0
+ abs_sum = 0.0
+ abs_count = 0
+ max_abs_grad = 0.0
+
+ for name, parameter in model.named_parameters():
+ if not parameter.requires_grad:
+ continue
+ grad = parameter.grad
+ if grad is None:
+ params_without_grad += 1
+ continue
+
+ # Statistics are computed on a detached float32 copy of the gradient.
+ grad_tensor = grad.detach().float()
+ params_with_grad += 1
+ if not bool(torch.isfinite(grad_tensor).all().item()):
+ nonfinite_param_count += 1
+ # Keep the listed names short; the count above stays exact.
+ if len(nonfinite_grad_params) < 8:
+ nonfinite_grad_params.append(name)
+ grad_abs = grad_tensor.abs()
+ param_norm = float(torch.linalg.vector_norm(grad_tensor).item())
+ # NOTE(review): ``max()`` on a zero-element gradient would raise;
+ # presumably no trainable parameter is empty, confirm.
+ param_max_abs = float(grad_abs.max().item())
+ param_mean_abs = float(grad_abs.mean().item())
+ max_abs_grad = max(max_abs_grad, param_max_abs)
+ abs_sum += float(grad_abs.sum().item())
+ abs_count += int(grad_abs.numel())
+ top_param_candidates.append((name, param_norm, param_max_abs, param_mean_abs))
+
+ # Element-weighted mean over all gradients; NaN when nothing had a gradient.
+ mean_abs_grad = math.nan if abs_count == 0 else abs_sum / float(abs_count)
+ # Six largest per-parameter gradient norms, descending.
+ top_param_norms = sorted(
+ top_param_candidates,
+ key=lambda item: item[1],
+ reverse=True,
+ )[:6]
+
+ debug_lines = [
+ (
+ "[debug:grad] "
+ f"step={global_step} "
+ f"pre_clip_grad_norm={format_scalar(grad_norm)} "
+ f"clip_ratio={format_scalar(_safe_grad_clip_ratio(grad_norm, grad_clip_norm))} "
+ f"params_with_grad={params_with_grad} "
+ f"params_without_grad={params_without_grad} "
+ f"nonfinite_param_count={nonfinite_param_count} "
+ f"max_abs_grad={format_scalar(max_abs_grad)} "
+ f"mean_abs_grad={format_scalar(mean_abs_grad)}"
+ )
+ ]
+ if top_param_norms:
+ debug_lines.append(
+ "[debug:grad] top_params="
+ + ", ".join(
+ (
+ f"{name}:{param_norm:.4f}"
+ f"(max={param_max_abs:.4e},mean={param_mean_abs:.4e})"
+ )
+ for name, param_norm, param_max_abs, param_mean_abs in top_param_norms
+ )
+ )
+ if nonfinite_grad_params:
+ debug_lines.append(
+ "[debug:grad] nonfinite_params=" + ", ".join(nonfinite_grad_params)
+ )
+ return debug_lines
+
+
+ def _format_tensor_triplet(values: object) -> str:
+     """Render ``min/mean/max`` of a tensor, or ``"n/a"`` if unavailable.
+
+     Min and max are truncated to integers; the mean keeps two decimals.
+     Non-tensor and empty inputs yield ``"n/a"``.
+     """
+     if not isinstance(values, torch.Tensor) or values.numel() == 0:
+         return "n/a"
+     as_float = values.detach().cpu().to(torch.float32)
+     lowest = int(as_float.min().item())
+     average = as_float.mean().item()
+     highest = int(as_float.max().item())
+     return f"{lowest}/{average:.2f}/{highest}"
+
+
+ def _format_audio_seconds_triplet(values: object, sample_rate: int) -> str:
+     """Render ``min/mean/max`` of ``values / sample_rate`` with two decimals.
+
+     Non-tensor and empty inputs yield ``"n/a"``.
+     """
+     if not isinstance(values, torch.Tensor) or values.numel() == 0:
+         return "n/a"
+     seconds = values.detach().cpu().to(torch.float32) / float(sample_rate)
+     shortest = seconds.min().item()
+     average = seconds.mean().item()
+     longest = seconds.max().item()
+     return f"{shortest:.2f}/{average:.2f}/{longest:.2f}"
+
+
+ def _format_mask_density(mask: object) -> str:
+     """Render a mask as ``<positive elements>/<total elements>``.
+
+     Non-tensor and empty inputs yield ``"n/a"``.
+     """
+     if not isinstance(mask, torch.Tensor) or mask.numel() == 0:
+         return "n/a"
+     positive = int(mask.detach().gt(0).sum().item())
+     total = int(mask.numel())
+     return f"{positive}/{total}"
+
+
+ def _safe_grad_clip_ratio(grad_norm: float, grad_clip_norm: float) -> float:
+     """Return ``grad_norm / grad_clip_norm``, or NaN when it is undefined.
+
+     NaN is returned for a non-finite ``grad_norm`` and for a clip norm that
+     is zero or negative, so this debug-only helper can never raise
+     ``ZeroDivisionError`` in the middle of a training step.
+     """
+     clip_norm = float(grad_clip_norm)
+     if not math.isfinite(grad_norm) or clip_norm <= 0.0:
+         return math.nan
+     return grad_norm / clip_norm
+
+
+# ---------------------------------------------------------------------------
+# Step Reporting
+# ---------------------------------------------------------------------------
+
+
+ def should_log_training_step(global_step: int, log_interval: int) -> bool:
+     """Return whether ``global_step`` is a multiple of ``log_interval``."""
+     remainder = global_step % log_interval
+     return remainder == 0
+
+
+ def reduce_source_metrics(
+     source_loss_totals: dict[str, dict[str, float]],
+     source_loss_denominators: dict[str, dict[str, float]],
+     *,
+     device: torch.device,
+     loss_config: Any,
+ ) -> dict[str, dict[str, float]]:
+     """Sum per-source loss statistics across ranks and average them.
+
+     The totals are reduced first, then the denominators; both go through
+     ``loss_ops.sum_grouped_named_scalars_across_ranks`` and the results are
+     handed to ``loss_ops.reduce_loss_statistics_by_source``.
+     """
+     summed_totals, summed_denominators = [
+         loss_ops.sum_grouped_named_scalars_across_ranks(local_stats, device=device)
+         for local_stats in (source_loss_totals, source_loss_denominators)
+     ]
+     return loss_ops.reduce_loss_statistics_by_source(
+         summed_totals,
+         summed_denominators,
+         loss_config=loss_config,
+     )
+
+
+ def build_train_step_report(
+     metrics: dict[str, Any],
+     *,
+     learning_rate: float,
+     grad_norm: float,
+     current_time: float,
+     last_log_step: int,
+     last_log_time: float,
+     progress: TrainProgress,
+     max_train_steps: int,
+     reduced_by_source: dict[str, dict[str, float]],
+ ) -> TrainStepReport:
+     """Derive throughput/ETA and render one step as a ``TrainStepReport``.
+
+     Steps per second is measured since the previous log point and is NaN
+     when no steps or no time have elapsed. The ETA is the remaining steps
+     divided by that rate, or NaN when the rate is not usable.
+     """
+     steps_since_log = progress.global_step - last_log_step
+     seconds_since_log = current_time - last_log_time
+     if steps_since_log <= 0 or seconds_since_log <= 0.0:
+         steps_per_second = math.nan
+     else:
+         steps_per_second = float(steps_since_log) / seconds_since_log
+
+     if math.isfinite(steps_per_second) and steps_per_second > 0.0:
+         remaining_steps = float(max_train_steps - progress.global_step)
+         eta_seconds = remaining_steps / steps_per_second
+     else:
+         eta_seconds = math.nan
+
+     # Keyword arguments shared by both renderers.
+     shared = {
+         "learning_rate": learning_rate,
+         "grad_norm": grad_norm,
+         "steps_per_second": steps_per_second,
+         "eta_seconds": eta_seconds,
+         "progress": progress,
+         "reduced_by_source": reduced_by_source,
+     }
+     log_values = build_train_log_dict(metrics, **shared)
+     console_line = format_train_line(
+         metrics,
+         max_train_steps=max_train_steps,
+         **shared,
+     )
+     return TrainStepReport(log_values=log_values, console_line=console_line)
+
+
+# ---------------------------------------------------------------------------
+# Formatting Helpers
+# ---------------------------------------------------------------------------
+
+
+ def flatten_config(values, parent_key="", sep="/"):
+     """Flatten a nested config dict into ``path/to/key -> value`` pairs.
+
+     Nested dicts are expanded recursively with ``sep`` joining the path.
+     Lists and tuples are stored as their ``str()`` form, ``None`` as the
+     string ``"None"``; every other value is kept unchanged.
+     """
+     flat = {}
+     for key, value in values.items():
+         path = f"{parent_key}{sep}{key}" if parent_key else key
+         if isinstance(value, dict):
+             flat.update(flatten_config(value, path, sep))
+         elif isinstance(value, (list, tuple)):
+             flat[path] = str(value)
+         elif value is None:
+             flat[path] = "None"
+         else:
+             flat[path] = value
+     return flat
+
+
+ def format_scalar(value: float) -> str:
+     """Format a scalar for concise console logging.
+
+     NaN renders as ``"nan"`` and infinities as ``"inf"`` / ``"-inf"``, so a
+     diverged value (for example an infinite gradient norm) is no longer
+     reported as NaN. Whole numbers print without a fractional part;
+     everything else uses four decimals.
+     """
+     if math.isnan(value):
+         return "nan"
+     if math.isinf(value):
+         return "inf" if value > 0 else "-inf"
+     if float(value).is_integer():
+         return str(int(value))
+     return f"{value:.4f}"
+
+
+ def _format_eta(eta_seconds: float) -> str:
+     """Render ETA seconds as ``HH:MM:SS`` or ``n/a``.
+
+     Non-finite and negative inputs yield ``"n/a"``; the seconds are rounded
+     to the nearest integer before being split.
+     """
+     if not math.isfinite(eta_seconds) or eta_seconds < 0.0:
+         return "n/a"
+     whole_minutes, seconds = divmod(int(round(eta_seconds)), 60)
+     hours, minutes = divmod(whole_minutes, 60)
+     return f"{hours:02d}:{minutes:02d}:{seconds:02d}"
+
+
+ def build_train_log_dict(
+     metrics: dict[str, Any],
+     *,
+     learning_rate: float,
+     grad_norm: float,
+     steps_per_second: float,
+     eta_seconds: float,
+     progress: TrainProgress,
+     reduced_by_source: dict[str, dict[str, Any]],
+ ) -> dict[str, float]:
+     """Build the flat metric dict sent to experiment trackers.
+
+     Keys are laid out as ``train/<name>`` for the bookkeeping values and the
+     global metrics, and ``train/<source>/<name>`` for per-source metrics.
+     Metric values are converted with ``float``.
+     """
+     log_dict = {
+         "train/epoch": float(progress.epoch),
+         "train/learning_rate": learning_rate,
+         "train/grad_norm": grad_norm,
+         "train/steps_per_second": steps_per_second,
+         "train/eta_seconds": eta_seconds,
+         "train/consumed_tokens": float(progress.total_tokens),
+         "train/consumed_audio_tokens": float(progress.audio_tokens),
+         "train/consumed_text_tokens": float(progress.text_tokens),
+     }
+     log_dict.update({f"train/{name}": float(value) for name, value in metrics.items()})
+     for source_name, source_metrics in reduced_by_source.items():
+         for name, value in source_metrics.items():
+             log_dict[f"train/{source_name}/{name}"] = float(value)
+     return log_dict
+
+
+ def format_train_line(
+     metrics: dict[str, Any],
+     *,
+     learning_rate: float,
+     grad_norm: float,
+     steps_per_second: float,
+     eta_seconds: float,
+     progress: TrainProgress,
+     max_train_steps: int,
+     reduced_by_source: dict[str, dict[str, Any]],
+ ) -> str:
+     """Build a single human-readable console line for one training step.
+
+     Fields are joined with ``" | "``: progress and bookkeeping first, then
+     the global metrics, then each source's metrics prefixed with
+     ``<source>_``. Within every metric group the names are sorted and
+     ``loss`` is always placed last.
+     """
+
+     def _render(values: dict[str, Any], prefix: str = "") -> list[str]:
+         # Sorted names with ``loss`` moved to the end of the group.
+         ordered = sorted(key for key in values if key != "loss")
+         if "loss" in values:
+             ordered.append("loss")
+         return [
+             f"{prefix}{key}: {format_scalar(float(values[key]))}" for key in ordered
+         ]
+
+     parts = [
+         f"iteration {progress.global_step}/{max_train_steps}",
+         f"epoch: {progress.epoch}",
+         f"consumed_tokens: {progress.total_tokens}",
+         f"consumed_audio_tokens: {progress.audio_tokens}",
+         f"consumed_text_tokens: {progress.text_tokens}",
+         f"learning_rate: {learning_rate:.2e}",
+         f"steps_per_second: {format_scalar(steps_per_second)}",
+         f"job_eta: {_format_eta(eta_seconds)}",
+         f"grad_norm: {format_scalar(grad_norm)}",
+     ]
+     parts.extend(_render(metrics))
+     for source_name, source_metrics in reduced_by_source.items():
+         parts.extend(_render(source_metrics, prefix=f"{source_name}_"))
+     return " | ".join(parts)
+
+
+ def build_validation_log_dict(
+     metrics: dict[str, Any],
+     *,
+     reduced_by_source: dict[str, dict[str, Any]],
+ ) -> dict[str, float]:
+     """Build the flat validation metric dict sent to experiment trackers.
+
+     Global metrics become ``val/<name>`` and per-source metrics
+     ``val/<source>/<name>``; every value is converted with ``float``.
+     """
+     log_dict: dict[str, float] = {}
+     for name, value in metrics.items():
+         log_dict[f"val/{name}"] = float(value)
+     for source_name, source_metrics in reduced_by_source.items():
+         for name, value in source_metrics.items():
+             log_dict[f"val/{source_name}/{name}"] = float(value)
+     return log_dict
+
+
+ def format_validation_line(
+     metrics: dict[str, Any],
+     *,
+     global_step: int,
+     reduced_by_source: dict[str, dict[str, Any]],
+ ) -> str:
+     """Build the console summary line printed after a validation pass.
+
+     Fields are joined with ``" | "``: a header naming the iteration, the
+     global metrics, then each source's metrics prefixed with ``<source>_``.
+     Within every metric group the names are sorted and ``loss`` comes last.
+     """
+
+     def _render(values: dict[str, Any], prefix: str = "") -> list[str]:
+         # Sorted names with ``loss`` moved to the end of the group.
+         ordered = sorted(key for key in values if key != "loss")
+         if "loss" in values:
+             ordered.append("loss")
+         return [
+             f"{prefix}{key}: {format_scalar(float(values[key]))}" for key in ordered
+         ]
+
+     parts = [f"validation at iteration {global_step}"]
+     parts.extend(_render(metrics))
+     for source_name, source_metrics in reduced_by_source.items():
+         parts.extend(_render(source_metrics, prefix=f"{source_name}_"))
+     return " | ".join(parts)
+
+
+ # Public names of this module; the underscore-prefixed helpers above are
+ # deliberately not listed.
+ __all__ = [
+ "TrainProgress",
+ "TrainStepReport",
+ "abort_on_out_of_memory",
+ "any_rank_true",
+ "build_data_debug_lines",
+ "build_gradient_debug_lines",
+ "build_train_step_report",
+ "build_train_log_dict",
+ "build_validation_log_dict",
+ "flatten_config",
+ "format_scalar",
+ "format_train_line",
+ "format_validation_line",
+ "move_to_device",
+ "reduce_source_metrics",
+ "should_log_training_step",
+ "should_print_gradient_debug",
+ "sum_integer_counters_across_ranks",
+ ]
diff --git a/src/dots_tts/utils/__init__.py b/src/dots_tts/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/dots_tts/utils/audio.py b/src/dots_tts/utils/audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e2202b650d864c8b492b3693043ce254efcff8f
--- /dev/null
+++ b/src/dots_tts/utils/audio.py
@@ -0,0 +1,45 @@
+"""Audio helpers used by the retained train/infer pipeline."""
+
+from __future__ import annotations
+
+import torch
+import torchaudio.compliance.kaldi as Kaldi
+import torchaudio.functional as AF
+
+
def high_quality_resample(x, orig_sr, target_sr):
    """Resample ``x`` from ``orig_sr`` to ``target_sr`` with a Kaiser-windowed sinc filter.

    Thin wrapper around ``torchaudio.functional.resample`` that pins the
    filter width (64), roll-off (0.95) and resampling method.
    """
    kaiser_sinc_options = {
        "lowpass_filter_width": 64,
        "rolloff": 0.95,
        "resampling_method": "sinc_interp_kaiser",
    }
    return AF.resample(
        x,
        orig_freq=orig_sr,
        new_freq=target_sr,
        **kaiser_sinc_options,
    )
+
+
def extract_fbank(
    waveform: torch.Tensor,
    *,
    sample_rate: int,
    n_mels: int,
    dither: float = 0.0,
    mean_norm: bool = False,
) -> torch.Tensor:
    """Compute Kaldi-compatible filter-bank features for a waveform.

    A 1D waveform gets a leading axis added. A 2D waveform is used as-is when
    its first dimension is 1; otherwise only its first row is kept.

    Args:
        waveform: 1D or 2D tensor of samples.
        sample_rate: Sampling frequency handed to ``Kaldi.fbank``.
        n_mels: Number of mel bins to compute.
        dither: Dithering constant forwarded to ``Kaldi.fbank``.
        mean_norm: When true, subtract the mean taken over the first axis of
            the feature matrix.

    Raises:
        ValueError: If ``waveform`` is neither 1D nor 2D.
    """
    rank = waveform.ndim
    if rank not in (1, 2):
        raise ValueError(
            f"FBank expects a 1D or 2D waveform, got shape {tuple(waveform.shape)}."
        )
    if rank == 1:
        single_row = waveform.unsqueeze(0)
    elif waveform.size(0) == 1:
        single_row = waveform
    else:
        # More than one row: keep only the first, preserving the 2D shape.
        single_row = waveform[0:1, :]

    fbank = Kaldi.fbank(
        single_row,
        num_mel_bins=n_mels,
        sample_frequency=sample_rate,
        dither=dither,
    )
    if not mean_norm:
        return fbank
    return fbank - fbank.mean(dim=0, keepdim=True)
diff --git a/src/dots_tts/utils/logging.py b/src/dots_tts/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1e7bcdf983d2cab32ba8d7825da9daff76ab954
--- /dev/null
+++ b/src/dots_tts/utils/logging.py
@@ -0,0 +1,42 @@
+from __future__ import annotations
+
+import os
+import sys
+from pathlib import Path
+
+from loguru import logger
+
# Level applied when neither the caller nor DOTS_TTS_LOG_LEVEL supplies one.
DEFAULT_LOG_LEVEL = "INFO"
# Sink layout: timestamp | left-padded level | call site | message.
DEFAULT_LOG_FORMAT = " | ".join(
    (
        "{time:YYYY-MM-DD HH:mm:ss.SSS}",
        "{level:<8}",
        "{name}:{function}:{line}",
        "{message}",
    )
)
+
+
def configure_logging(
    *,
    level: str | None = None,
    log_file: str | os.PathLike[str] | None = None,
) -> None:
    """Reset loguru and install the project's sinks.

    All previously registered sinks are removed, a stderr sink is added, and
    a UTF-8 file sink is added as well when ``log_file`` is given. Parent
    directories of the log file are created as needed.

    Args:
        level: Log level name. Falls back to the ``DOTS_TTS_LOG_LEVEL``
            environment variable and then to ``DEFAULT_LOG_LEVEL``. The
            chosen value is upper-cased.
        log_file: Optional path of an additional file sink; ``~`` is expanded.
    """
    if level:
        chosen_level = level
    else:
        chosen_level = os.environ.get("DOTS_TTS_LOG_LEVEL") or DEFAULT_LOG_LEVEL
    chosen_level = chosen_level.upper()

    # Options shared by every sink so stderr and file output stay in lockstep.
    sink_options = {
        "level": chosen_level,
        "format": DEFAULT_LOG_FORMAT,
        "backtrace": True,
        "diagnose": False,
        "enqueue": False,
    }

    logger.remove()
    logger.add(sys.stderr, **sink_options)
    if not log_file:
        return

    destination = Path(log_file).expanduser()
    destination.parent.mkdir(parents=True, exist_ok=True)
    logger.add(destination, encoding="utf-8", **sink_options)
diff --git a/src/dots_tts/utils/profiling.py b/src/dots_tts/utils/profiling.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf5a54c0862542a9b11fcbfeae3ce79700d916de
--- /dev/null
+++ b/src/dots_tts/utils/profiling.py
@@ -0,0 +1,215 @@
+from __future__ import annotations
+
+import os
+import time
+from contextlib import contextmanager
+from contextvars import ContextVar, Token
+from dataclasses import dataclass
+from multiprocessing import Queue
+from typing import Iterator
+
+import torch
+from loguru import logger
+
# Canonical stage labels, in the order they are reported.
INFERENCE_STAGE_NAMES = (
    "FM",
    "latent_encoder",
    "patch_encoder",
    "LLM",
    "latent_decoder",
    "speaker_encoder",
    "vocoder",
)

# Case-insensitive lookup table: lower-cased label -> canonical label.
_INFERENCE_STAGE_NAME_MAP = dict(
    zip((label.lower() for label in INFERENCE_STAGE_NAMES), INFERENCE_STAGE_NAMES)
)
# Profiler bound to the current context; None while no profiler is active.
_CURRENT_INFERENCE_PROFILER: ContextVar[InferenceProfiler | None] = ContextVar(
    "current_inference_profiler", default=None
)
+
+
def normalize_inference_stage_name(name: str) -> str:
    """Map a stage label to its canonical spelling.

    Matching ignores letter case and surrounding whitespace.

    Raises:
        ValueError: If the label matches no entry of ``INFERENCE_STAGE_NAMES``.
    """
    lookup_key = name.strip().lower()
    if lookup_key not in _INFERENCE_STAGE_NAME_MAP:
        raise ValueError(
            f"Unsupported inference stage '{name}'. "
            f"Expected one of: {', '.join(INFERENCE_STAGE_NAMES)}."
        )
    return _INFERENCE_STAGE_NAME_MAP[lookup_key]
+
+
@dataclass(slots=True)
class InferenceStageStat:
    """Running totals for one inference stage."""

    # Accumulated wall-clock seconds across all measurements of the stage.
    seconds: float = 0.0
    # Sum of the ``count`` values passed to each measurement.
    count: int = 0
+
+
@dataclass(frozen=True, slots=True)
class ProfileEvent:
    """Immutable record of one timed region, as queued by ``DataProfiler.measure``."""

    # Caller-supplied stage label.
    stage: str
    # Elapsed wall-clock seconds of the timed region.
    seconds: float
    # Caller-supplied count attributed to the region.
    count: int
    # PID captured by the DataProfiler that produced the event.
    pid: int
+
+
class DataProfiler:
    """Times named stages and publishes each measurement to a queue.

    Without a queue the profiler is disabled: ``measure`` simply runs the
    wrapped block and records nothing.
    """

    def __init__(self, queue: Queue | None = None):
        # The PID is read once, when the profiler is constructed.
        self._pid = os.getpid()
        self._queue = queue

    @property
    def enabled(self) -> bool:
        """Whether a queue is attached."""
        return self._queue is not None

    @contextmanager
    def measure(self, stage: str, *, count: int = 1) -> Iterator[None]:
        """Time the wrapped block and enqueue a ``ProfileEvent`` for it.

        The event is emitted even when the block raises. Nothing is timed
        or emitted when the profiler has no queue.
        """
        if not self.enabled:
            yield
            return

        began = time.perf_counter()
        try:
            yield
        finally:
            elapsed = time.perf_counter() - began
            event = ProfileEvent(
                stage=stage,
                seconds=elapsed,
                count=int(count),
                pid=self._pid,
            )
            self._queue.put(event)

    def child(self) -> DataProfiler:
        """Return a new profiler on the same queue (its PID is read afresh)."""
        return DataProfiler(self._queue)
+
+
def ensure_data_profiler(profiler: DataProfiler | None) -> DataProfiler:
    """Return ``profiler`` unchanged, or a new queue-less ``DataProfiler`` when it is None."""
    if profiler is not None:
        return profiler
    return DataProfiler()
+
+
class InferenceProfiler:
    """Accumulates wall-clock time per inference stage for one device."""

    def __init__(self, device: torch.device):
        self._device = device
        # One accumulator per known stage, created up front.
        self._stats = {}
        for stage_name in INFERENCE_STAGE_NAMES:
            self._stats[stage_name] = InferenceStageStat()

    def _sync(self) -> None:
        """Block until queued CUDA work on the device has finished (no-op otherwise)."""
        if self._device.type != "cuda":
            return
        torch.cuda.synchronize(self._device)

    @contextmanager
    def measure(self, stage: str, *, count: int = 1) -> Iterator[None]:
        """Time the wrapped block and add it to the stage's totals.

        The device is synchronized before starting and before stopping the
        clock. Totals are updated even when the block raises.

        Raises:
            ValueError: If ``stage`` is not a recognised stage name.
        """
        canonical = normalize_inference_stage_name(stage)
        self._sync()
        began = time.perf_counter()
        try:
            yield
        finally:
            self._sync()
            bucket = self._stats[canonical]
            bucket.seconds += time.perf_counter() - began
            bucket.count += int(count)

    def summary(
        self,
        *,
        duration_seconds: float | None = None,
    ) -> dict[str, dict[str, float | int]]:
        """Return ``{stage: {"seconds", "count"[, "rtf"]}}`` for every stage.

        ``rtf`` is included only when ``duration_seconds`` is given; it is
        ``seconds / duration_seconds``, or infinity for a non-positive
        duration.
        """
        report: dict[str, dict[str, float | int]] = {}
        for stage_name in INFERENCE_STAGE_NAMES:
            bucket = self._stats[stage_name]
            entry: dict[str, float | int] = {
                "seconds": bucket.seconds,
                "count": bucket.count,
            }
            if duration_seconds is not None:
                if duration_seconds > 0:
                    entry["rtf"] = bucket.seconds / duration_seconds
                else:
                    entry["rtf"] = float("inf")
            report[stage_name] = entry
        return report
+
+
@contextmanager
def inference_profiling(
    *,
    enabled: bool,
    device: torch.device,
) -> Iterator[InferenceProfiler | None]:
    """Optionally create an ``InferenceProfiler`` and activate it for the block.

    Yields the profiler when ``enabled`` is true, otherwise ``None``.
    """
    if enabled:
        profiler = InferenceProfiler(device)
    else:
        profiler = None
    with activate_inference_profiler(profiler):
        yield profiler
+
+
@contextmanager
def activate_inference_profiler(
    profiler: InferenceProfiler | None,
) -> Iterator[InferenceProfiler | None]:
    """Make ``profiler`` the current context's profiler for the wrapped block.

    The previous value is restored on exit, including when the block raises.
    Passing ``None`` leaves the context variable untouched.
    """
    if profiler is not None:
        previous = _CURRENT_INFERENCE_PROFILER.set(profiler)
        try:
            yield profiler
        finally:
            _CURRENT_INFERENCE_PROFILER.reset(previous)
    else:
        yield None
+
+
@contextmanager
def measure_inference(stage: str, *, count: int = 1) -> Iterator[None]:
    """Time the wrapped block with the context's active profiler, if there is one."""
    active = _CURRENT_INFERENCE_PROFILER.get()
    if active is not None:
        with active.measure(stage, count=count):
            yield
    else:
        yield
+
+
def log_inference_profile(
    *,
    request_id: str,
    profiling: dict[str, dict[str, float | int]],
    duration_seconds: float,
) -> None:
    """Log the profiling result of one request, one line per active stage.

    A stage is active when its ``count`` is positive. When no stage is
    active a single summary line is logged instead.

    Args:
        request_id: Identifier included in every log line.
        profiling: Mapping ``stage -> {"seconds", "count"[, "rtf"]}`` that
            contains every name in ``INFERENCE_STAGE_NAMES``.
        duration_seconds: Duration used for the summary line and to derive
            ``rtf`` for entries that do not carry one.
    """
    active_stages = [
        stage
        for stage in INFERENCE_STAGE_NAMES
        if int(profiling[stage]["count"]) > 0
    ]
    if not active_stages:
        logger.info(
            "Inference profiling summary: request_id={} no_profiled_stages duration_seconds={:.3f}",
            request_id,
            duration_seconds,
        )
        return
    for stage in active_stages:
        stats = profiling[stage]
        seconds = float(stats["seconds"])
        rtf = stats.get("rtf")
        if rtf is None:
            # A summary built without a duration has no "rtf" entry; derive it
            # here with the same rule instead of failing with KeyError.
            rtf = seconds / duration_seconds if duration_seconds > 0 else float("inf")
        logger.info(
            "Inference profiling: request_id={} stage={} seconds={:.4f} count={} rtf={:.4f}",
            request_id,
            stage,
            seconds,
            int(stats["count"]),
            float(rtf),
        )
+
+
# Public names of the profiling module (exported by a star-import).
__all__ = [
    "DataProfiler",
    "ProfileEvent",
    "INFERENCE_STAGE_NAMES",
    "activate_inference_profiler",
    "ensure_data_profiler",
    "InferenceProfiler",
    "InferenceStageStat",
    "inference_profiling",
    "log_inference_profile",
    "measure_inference",
    "normalize_inference_stage_name",
]
diff --git a/src/dots_tts/utils/text.py b/src/dots_tts/utils/text.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccda50f61412718f42193b3f8852189e96dff04e
--- /dev/null
+++ b/src/dots_tts/utils/text.py
@@ -0,0 +1,132 @@
+from __future__ import annotations
+
+import re
+from functools import lru_cache
+from typing import Literal
+
+from langcodes import Language as LangcodesLanguage
+from lingua import Language, LanguageDetectorBuilder
+from tn.chinese.normalizer import Normalizer as ZhNormalizer
+from tn.english.normalizer import Normalizer as EnNormalizer
+
# Result of detect_text_language: one of the two languages that have a
# dedicated normalizer, or "unknown" for everything else.
TextLanguage = Literal["zh", "en", "unknown"]

# Any run of whitespace; collapsed to a single space after normalization.
_WHITESPACE_PATTERN = re.compile(r"\s+")
+
+
@lru_cache(maxsize=1)
def get_chinese_text_normalizer() -> ZhNormalizer:
    """Return the cached Chinese text normalizer, constructing it on first use."""
    zh_normalizer = ZhNormalizer()
    return zh_normalizer
+
+
@lru_cache(maxsize=1)
def get_english_text_normalizer() -> EnNormalizer:
    """Return the cached English text normalizer, constructing it on first use."""
    en_normalizer = EnNormalizer()
    return en_normalizer
+
+
@lru_cache(maxsize=1)
def get_language_detector():
    """Return the cached lingua detector covering every language lingua offers.

    The languages are handed to the builder sorted by name.
    """
    candidates = list(Language.all())
    candidates.sort(key=lambda candidate: candidate.name)
    builder = LanguageDetectorBuilder.from_languages(*candidates)
    return builder.build()
+
+
def _lingua_language_to_code(language: Language | None) -> str | None:
    """Convert a lingua language into a lower-case code.

    Preference order: the ISO 639-1 name, then the ISO 639-3 name, then the
    language's own name. ``None`` is passed through.
    """
    if language is None:
        return None
    for iso_attribute in ("iso_code_639_1", "iso_code_639_3"):
        iso_name = getattr(getattr(language, iso_attribute), "name", None)
        if iso_name:
            return iso_name.lower()
    return language.name.lower()
+
+
def detect(text: str) -> str | None:
    """Detect the language of ``text`` and return its lower-case code.

    Returns ``None`` for blank input or when the detector yields no language.
    """
    candidate = text.strip()
    if candidate:
        detected = get_language_detector().detect_language_of(candidate)
        return _lingua_language_to_code(detected)
    return None
+
+
def normalize_language_code(language: str | None) -> str | None:
    """Normalize a free-form language label to an upper-case language code.

    ``None``, blank strings and the words "none"/"unknown" (any case) yield
    ``None``. Labels starting with ``口音:`` are returned stripped but
    otherwise unchanged. Everything else is resolved through langcodes,
    first as a tag and then by name, preferring the macrolanguage; the first
    usable code that is not "UND" wins. ``None`` is returned when neither
    lookup succeeds.
    """
    if language is None:
        return None

    label = language.strip()
    if label == "" or label.lower() in ("none", "unknown"):
        return None
    if label.startswith("口音:"):
        return label

    for lookup in (LangcodesLanguage.get, LangcodesLanguage.find):
        try:
            resolved = lookup(label).prefer_macrolanguage()
        except Exception:
            # Best effort: a failed lookup just means trying the next one.
            continue
        code = (resolved.language or "").strip().upper()
        if code and code != "UND":
            return code
    return None
+
+
def attach_language_tag(text: str, language: str | None) -> str:
    """Prefix ``text`` with a ``[CODE]`` language tag.

    The text is returned unchanged when it is empty, when ``language`` does
    not normalize to a code, or when it already starts with the tag. The
    code "YUE" is rendered as the accent label ``口音:粤语``.
    """
    if not text:
        return text

    code = normalize_language_code(language)
    if code is None:
        return text

    label = "口音:粤语" if code == "YUE" else code
    tag = f"[{label}]"
    return text if text.startswith(tag) else f"{tag}{text}"
+
+
def detect_text_language(text: str) -> TextLanguage:
    """Classify ``text`` as "zh", "en" or "unknown" using the language detector."""
    detected_code = detect(text)
    return detected_code if detected_code in ("zh", "en") else "unknown"
+
+
def _normalize_with(normalizer, text: str) -> str:
    """Apply ``normalizer.normalize`` and tidy the whitespace of its output.

    Whitespace runs become a single space and the result is stripped.
    """
    collapsed = _WHITESPACE_PATTERN.sub(" ", normalizer.normalize(text))
    return collapsed.strip()
+
+
def normalize_chinese_text(text: str) -> str:
    """Normalize ``text`` with the Chinese normalizer; blank input yields ""."""
    content = text.strip()
    if content:
        return _normalize_with(get_chinese_text_normalizer(), content)
    return ""
+
+
def normalize_english_text(text: str) -> str:
    """Normalize ``text`` with the English normalizer; blank input yields ""."""
    content = text.strip()
    if content:
        return _normalize_with(get_english_text_normalizer(), content)
    return ""
+
+
def normalize_text(text: str) -> str:
    """Normalize ``text`` with the normalizer matching its detected language.

    Blank input yields "". Text detected as neither Chinese nor English is
    returned stripped but otherwise untouched.
    """
    content = text.strip()
    if not content:
        return ""

    detected = detect_text_language(content)
    if detected == "zh":
        normalizer = get_chinese_text_normalizer()
    elif detected == "en":
        normalizer = get_english_text_normalizer()
    else:
        return content
    return _normalize_with(normalizer, content)
diff --git a/src/dots_tts/utils/tokenizer.py b/src/dots_tts/utils/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..afa6b2a64e8c5e1c3a790870fa824f32cd6f73fd
--- /dev/null
+++ b/src/dots_tts/utils/tokenizer.py
@@ -0,0 +1,28 @@
+from __future__ import annotations
+
# Special-token strings looked up in the tokenizer vocabulary.
# NOTE(review): "comp" / "gen" presumably mean audio comprehension /
# generation spans — confirm against the model's prompt format.
AUDIO_COMP_START_TOKEN = "<|audio_comp_start|>"
AUDIO_COMP_SPAN_TOKEN = "<|audio_comp_span|>"
AUDIO_COMP_END_TOKEN = "<|audio_comp_end|>"
AUDIO_GEN_START_TOKEN = "<|audio_gen_start|>"
AUDIO_GEN_SPAN_TOKEN = "<|audio_gen_span|>"
AUDIO_GEN_END_TOKEN = "<|audio_gen_end|>"
TEXT_COND_END_TOKEN = "<|text_cond_end|>"
+
+
def require_token_id(tokenizer, token: str) -> int:
    """Return the vocabulary id of a required special token.

    Args:
        tokenizer: Object exposing ``convert_tokens_to_ids``; ``unk_token``
            and ``unk_token_id`` are consulted when present.
        token: The special-token string to resolve.

    Returns:
        The token id as a plain ``int``.

    Raises:
        ValueError: If the tokenizer does not know the token, i.e. the lookup
            yields ``None``, a negative id, or the unknown-token id for a
            token that is not itself the unknown token.
    """
    token_id = tokenizer.convert_tokens_to_ids(token)
    missing = token_id is None or token_id < 0
    if not missing:
        # Tokenizers that define an unknown token map every out-of-vocabulary
        # string to its id instead of None, which would otherwise pass as valid.
        unk_id = getattr(tokenizer, "unk_token_id", None)
        missing = (
            unk_id is not None
            and token_id == unk_id
            and token != getattr(tokenizer, "unk_token", None)
        )
    if missing:
        raise ValueError(f"Artifact tokenizer is missing required special token: {token}")
    return int(token_id)
+
+
# Public names of the tokenizer helpers (exported by a star-import).
__all__ = [
    "AUDIO_COMP_END_TOKEN",
    "AUDIO_COMP_SPAN_TOKEN",
    "AUDIO_COMP_START_TOKEN",
    "AUDIO_GEN_END_TOKEN",
    "AUDIO_GEN_SPAN_TOKEN",
    "AUDIO_GEN_START_TOKEN",
    "TEXT_COND_END_TOKEN",
    "require_token_id",
]
diff --git a/src/dots_tts/utils/util.py b/src/dots_tts/utils/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b68727fbab5187d56861da2abea20b93727442b
--- /dev/null
+++ b/src/dots_tts/utils/util.py
@@ -0,0 +1,48 @@
+import random
+from typing import Any
+
+import numpy as np
+import torch
+
+
def get_dtype(x):
    """Translate a dtype name into the matching ``torch.dtype``.

    Recognised names (case-insensitive, surrounding whitespace ignored) are
    ``bf16``/``bfloat16``/``torch.bfloat16``, ``fp16``/``float16``/
    ``torch.float16`` and ``fp32``/``float32``/``torch.float32``. A value
    that already is a ``torch.dtype`` is returned unchanged.

    Raises:
        ValueError: If the name is not one of the supported aliases.
    """
    if isinstance(x, torch.dtype):
        return x

    aliases = {
        "bf16": torch.bfloat16,
        "torch.bfloat16": torch.bfloat16,
        "bfloat16": torch.bfloat16,
        "fp16": torch.float16,
        "torch.float16": torch.float16,
        "float16": torch.float16,
        "fp32": torch.float32,
        "torch.float32": torch.float32,
        "float32": torch.float32,
    }
    key = x.strip().lower()
    if key not in aliases:
        # Name the offending value so a bad config entry is easy to spot.
        raise ValueError(f"Unsupported dtype value: {x!r}.")
    return aliases[key]
+
+
def seed_everything(seed: int = 42):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Also switches cuDNN to deterministic mode and turns off its benchmark
    autotuner.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
+
+
def mask_data(x, mask, masking_value=0.0):
    """Return a copy of ``x`` with ``masking_value`` wherever ``mask`` is true.

    Trailing singleton axes are appended to ``mask`` until it has as many
    dimensions as ``x``. ``masking_value`` may be a Python scalar or a tensor
    that can be expanded to the shape of ``x``.
    """
    missing_axes = x.dim() - mask.dim()
    for _ in range(max(missing_axes, 0)):
        mask = mask.unsqueeze(-1)

    if isinstance(masking_value, torch.Tensor):
        replacement = masking_value.expand_as(x)
    else:
        replacement = torch.full(
            x.shape, masking_value, dtype=x.dtype, device=x.device
        )
    return torch.where(mask, replacement, x)
+
+
def get_mask_from_lengths(lengths, max_len=None):
    """Build a boolean mask that is true at positions below each length.

    Args:
        lengths: Tensor of sequence lengths (expected to be 1D).
        max_len: Width of the mask. Defaults to the largest value in
            ``lengths``, or 0 when ``lengths`` is empty.

    Returns:
        A ``torch.bool`` tensor with one row per length and ``max_len``
        columns, on the same device as ``lengths``.
    """
    if max_len is None:
        # torch.max raises on an empty tensor; an empty batch has width 0.
        max_len = lengths.max().item() if lengths.numel() > 0 else 0
    # Create the position index directly on the target device rather than
    # through an uninitialized legacy LongTensor passed as ``out=``.
    positions = torch.arange(int(max_len), dtype=torch.long, device=lengths.device)
    # The comparison already yields a bool tensor.
    return positions < lengths.unsqueeze(1)
+
+
def scalar_as_float(value: Any) -> float:
    """Convert a one-element tensor or any float-compatible value to ``float``.

    Tensors are detached and cast to float32 before their single item is read.
    """
    if not isinstance(value, torch.Tensor):
        return float(value)
    item = value.detach().float().item()
    return float(item)