diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..0652054b20b554fa1cb6208a1f8dbc2ceab35233 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2026 OpenMOSS Team, Fudan University, SII and MOSI + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index c841f159ecb1fe6f17fce367c4c2d4e0caf14cc3..67eca1ad32cab79957e1af754a0574948138fee9 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,12 @@ python_version: '3.12' app_file: app.py pinned: false license: apache-2.0 +tags: + - zerogpu + - aoti + - text-to-speech --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +dots.tts Gradio Space for Hugging Face ZeroGPU with optional PyTorch AOTInductor startup compilation. + +Set `DOTS_TTS_MODEL_NAME_OR_PATH` to a local model directory or Hugging Face model repo id. The app defaults to `rednote-hilab/dots.tts`. diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..e16f0b22d540f13c38c491e225529dd9d38b77ef --- /dev/null +++ b/app.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +import os +import sys +from pathlib import Path +from typing import Any, Callable + + +REPO_ROOT = Path(__file__).resolve().parent +SRC_ROOT = REPO_ROOT / "src" + +for import_root in (REPO_ROOT, SRC_ROOT): + import_root_str = str(import_root) + if import_root_str not in sys.path: + sys.path.insert(0, import_root_str) + + +class _SpacesFallback: + @staticmethod + def GPU(*decorator_args, **_decorator_kwargs): + if decorator_args and callable(decorator_args[0]): + return decorator_args[0] + + def decorate(fn: Callable[..., Any]) -> Callable[..., Any]: + return fn + + return decorate + + +try: + import spaces # type: ignore +except Exception: # pragma: no cover - only used outside Hugging Face Spaces. + spaces = _SpacesFallback() # type: ignore + + +def _env_bool(name: str, default: bool) -> bool: + value = os.environ.get(name) + if value is None: + return default + return value.strip().lower() in {"1", "true", "yes", "on"} + + +def _env_int(name: str, default: int) -> int: + value = os.environ.get(name) + if value is None or not value.strip(): + return default + return int(value) + + +def _configure_zero_gpu_environment() -> None: + os.environ.setdefault("DOTS_TTS_COMPILE_BACKEND", "aoti") + os.environ.setdefault("DOTS_TTS_SKIP_INIT_WARMUP", "1") + os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") + + +def _preload_runtime(app_service, app_config, compile_backend: str): + runtime, resolved_model_name_or_path = app_service._get_runtime( # noqa: SLF001 + app_config.default_model_name_or_path, + ) + runtime.optimize = bool(app_config.optimize) + runtime.model.set_optimize(bool(app_config.optimize)) + if hasattr(runtime.model, "set_compile_backend"): + runtime.model.set_compile_backend(compile_backend) + return runtime, resolved_model_name_or_path + + +def main() -> None: + _configure_zero_gpu_environment() + + import gradio as gr + from loguru import logger + + from apps.gradio.app import PLAYGROUND_CSS, build_demo, build_playground_theme + from apps.gradio.service import GradioAppService, build_gradio_app_config + from dots_tts.utils.logging import configure_logging + + host = os.environ.get("DOTS_TTS_HOST", "0.0.0.0") + port = _env_int("DOTS_TTS_PORT", 7860) + model_name_or_path = os.environ.get( + "DOTS_TTS_MODEL_NAME_OR_PATH", + "rednote-hilab/dots.tts", + ) + precision = os.environ.get("DOTS_TTS_PRECISION", "bfloat16") + execution_mode = os.environ.get("DOTS_TTS_EXECUTION_MODE", "generate_stream") + max_generate_length = _env_int("DOTS_TTS_MAX_GENERATE_LENGTH", 500) + default_num_steps = _env_int("DOTS_TTS_DEFAULT_NUM_STEPS", 10) + compile_backend = os.environ.get("DOTS_TTS_COMPILE_BACKEND", "aoti").strip().lower() + enable_aoti = _env_bool("DOTS_TTS_ENABLE_AOTI", True) + startup_compile = _env_bool("DOTS_TTS_AOTI_COMPILE_ON_STARTUP", True) + optimize = _env_bool("DOTS_TTS_OPTIMIZE", True) + generation_duration = _env_int("DOTS_TTS_ZERO_GPU_DURATION", 600) + compile_duration = _env_int("DOTS_TTS_ZERO_GPU_COMPILE_DURATION", 1500) + output_dir = Path(os.environ.get("DOTS_TTS_OUTPUT_DIR", "/tmp/dots_tts_outputs")) + log_file = Path(os.environ.get("DOTS_TTS_LOG_FILE", "/tmp/dots_tts_gradio.log")) + + configure_logging(log_file=log_file) + logger.info( + "Space app starting: model={} execution_mode={} precision={} optimize={} " + "compile_backend={} enable_aoti={} startup_compile={} max_generate_length={}", + model_name_or_path, + execution_mode, + precision, + optimize, + compile_backend, + enable_aoti, + startup_compile, + max_generate_length, + ) + + app_config = build_gradio_app_config( + host=host, + port=port, + execution_mode=execution_mode, + precision=precision, + optimize=optimize, + model_name_or_path=model_name_or_path, + output_dir=output_dir, + max_generate_length=max_generate_length, + default_num_steps=default_num_steps, + default_max_generate_length=max_generate_length, + repo_root=REPO_ROOT, + ) + app_service = GradioAppService(app_config) + runtime, resolved_model_name_or_path = _preload_runtime( + app_service, + app_config, + compile_backend if enable_aoti else "torch_compile", + ) + + if enable_aoti and startup_compile and optimize: + + @spaces.GPU(duration=compile_duration) + def compile_aoti_cache(): + child_runtime, _ = _preload_runtime( + app_service, + app_config, + compile_backend, + ) + child_runtime.model.run_warmup( + max_generate_length=app_config.max_generate_length, + precision=app_config.precision, + num_steps=app_config.default_num_steps, + guidance_scale=app_config.default_guidance_scale, + ) + return child_runtime.model.export_compiled_models() + + compiled_models = compile_aoti_cache() + if compiled_models: + runtime.model.import_compiled_models(compiled_models) + logger.info( + "AOTI startup compile completed: compiled_target_count={}", + len(compiled_models or {}), + ) + + app_service.generate = spaces.GPU(duration=generation_duration)(app_service.generate) + + demo = build_demo(gr, app_config, app_service) + logger.info( + "Space app ready: host={} port={} resolved_model={} compiled_target_count={}", + app_config.host, + app_config.port, + resolved_model_name_or_path, + len(runtime.model.export_compiled_models()) + if hasattr(runtime.model, "export_compiled_models") + else 0, + ) + demo.launch( + server_name=app_config.host, + server_port=app_config.port, + theme=build_playground_theme(gr), + css=PLAYGROUND_CSS, + ) + + +if __name__ == "__main__": + main() diff --git a/apps/__init__.py b/apps/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d6c872ad474f5be0ff27e8b3de592df998368ba --- /dev/null +++ b/apps/__init__.py @@ -0,0 +1 @@ +"""Application entrypoints for dots.tts.""" diff --git a/apps/gradio/__init__.py b/apps/gradio/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f921e101e49432bb5030192aa97efd3aee0280d8 --- /dev/null +++ b/apps/gradio/__init__.py @@ -0,0 +1 @@ +"""Gradio application for dots.tts.""" diff --git a/apps/gradio/app.py b/apps/gradio/app.py new file mode 100644 index 0000000000000000000000000000000000000000..59164b8e7c23f10014b0f9f7f2e37b679029a64d --- /dev/null +++ b/apps/gradio/app.py @@ -0,0 +1,663 @@ +from __future__ import annotations + +import argparse +import os +import sys +from pathlib import Path +from typing import TYPE_CHECKING + +REPO_ROOT = Path(__file__).resolve().parents[2] +SRC_ROOT = REPO_ROOT / "src" + +for import_root in (REPO_ROOT, SRC_ROOT): + import_root_str = str(import_root) + if import_root_str not in sys.path: + sys.path.insert(0, import_root_str) + +from apps.gradio.constants import ( # noqa: E402 + DEFAULT_EXECUTION_MODE, + DEFAULT_GUIDANCE_SCALE, + DEFAULT_HOST, + DEFAULT_INPUT_TEXT, + DEFAULT_LOG_FILE, + DEFAULT_MAX_GENERATE_LENGTH, + DEFAULT_NUM_STEPS, + DEFAULT_ODE_METHOD, + DEFAULT_OUTPUT_DIR, + DEFAULT_OUTPUT_RETENTION, + DEFAULT_PORT, + DEFAULT_PRECISION, + DEFAULT_PROMPT_NAME, + DEFAULT_SEED, + DEFAULT_SPEAKER_SCALE, +) + +if TYPE_CHECKING: + import gradio as gr + +DEBUG_GRADIO_ENABLED = os.environ.get("DEBUG_GRADIO", "0") == "1" + + +PLAYGROUND_CSS = """ +.gradio-container { + width: min(1600px, calc(100vw - 32px)) !important; + max-width: none !important; + margin: 0 auto !important; + padding-left: 0 !important; + padding-right: 0 !important; +} + +.gradio-container, +.gradio-container .gradio-container { + --block-label-background-fill: #CCE5FF; + --block-label-text-color: #6666FF; + --block-label-border-color: #99c7ee; + --block-label-text-weight: 600; + --block-title-background-fill: #CCE5FF; + --block-title-text-color: #6666FF; + --block-title-border-color: #99c7ee; + --block-title-border-width: var(--block-label-border-width); + --block-title-radius: var(--block-label-radius); + --block-title-padding: var(--block-label-padding); + --block-title-text-size: var(--block-label-text-size); + --block-title-text-weight: 600; +} + +.gradio-container label[data-testid="block-label"], +.gradio-container label[data-testid="block-label"] *, +.gradio-container span[data-testid="block-info"], +.gradio-container span[data-testid="block-info"] * { + background: #CCE5FF !important; + border-color: #99c7ee !important; + color: #6666FF !important; + fill: #6666FF !important; + font-family: Verdana, Geneva, "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei", "Noto Sans CJK SC", sans-serif !important; + font-style: normal !important; + font-size: 0.78rem !important; + line-height: 1.2 !important; + letter-spacing: 0 !important; + text-transform: none !important; +} +.gradio-container label[data-testid="block-label"], +.gradio-container span[data-testid="block-info"], +.gradio-container [data-testid="block-title"], +.gradio-container .block-title { + border: var(--block-label-border-width) solid #99c7ee !important; + border-top: none !important; + border-left: none !important; + border-radius: var(--block-label-radius) !important; + box-shadow: var(--block-label-shadow) !important; + padding: var(--block-label-padding) !important; +} +.gradio-container label[data-testid="block-label"], +.gradio-container label[data-testid="block-label"] *, +.gradio-container span[data-testid="block-info"], +.gradio-container span[data-testid="block-info"] *, +.gradio-container [data-testid="block-title"], +.gradio-container [data-testid="block-title"] *, +.gradio-container .block-title, +.gradio-container .block-title * { + font-weight: 600 !important; +} +.gradio-container .block label > span, +.gradio-container .block label > span *, +.gradio-container .form label > span, +.gradio-container .form label > span *, +.gradio-container label > span:first-child, +.gradio-container label > span:first-child * { + font-weight: 600 !important; +} +.strong-label [data-testid="block-label"], +.strong-label [data-testid="block-label"] *, +.strong-label span[data-testid="block-info"], +.strong-label span[data-testid="block-info"] *, +.strong-label [data-testid="block-title"], +.strong-label [data-testid="block-title"] *, +.strong-label .block-label, +.strong-label .block-label *, +.strong-label .block-title, +.strong-label .block-title *, +.strong-label label > span:first-child, +.strong-label label > span:first-child * { + font-weight: 600 !important; +} +.gradio-container .info-text, +.gradio-container .info-text * { + font-weight: 400 !important; +} +.gradio-container input, +.gradio-container textarea, +.gradio-container select, +.gradio-container [role="textbox"], +.gradio-container [contenteditable="true"] { + font-weight: 400 !important; +} +.gradio-container label[data-testid="block-label"] > span:first-child { + display: none !important; +} + +.generate-button { + background: #6666FF !important; + color: #ffffff !important; + border: 1px solid #5555ee !important; + font-family: Verdana, Geneva, sans-serif !important; +} +.generate-button:hover { + background: #5555ee !important; +} + +#playground-banner { + padding: 0; + border-radius: 0; + margin-bottom: 18px; + background: transparent; + border: 0; +} +#playground-banner h1 { + margin: 0 0 4px 0; + font-size: 1.7rem; + font-weight: 700; + color: #0f172a; + letter-spacing: 0; +} +#playground-banner .subtitle { + margin: 0; + color: #1e293b; + font-size: 0.9rem; +} + +.info-card { + padding: 14px 18px; + border-radius: 8px; + border: 1px solid #99c7ee; + border-left: 4px solid #2563eb; + background: transparent; + font-size: 0.86rem; + line-height: 1.55; + margin-bottom: 16px; + box-sizing: border-box; + color: #0f172a; +} +.info-card .card-title, +.info-card .notice-title { + display: block; + font-weight: 600; + font-size: 0.92rem; + color: #0f172a; +} +.info-card .card-title { + margin-bottom: 4px; +} +.info-card .notice-title { + margin-top: 8px; + margin-bottom: 4px; +} +.info-card ol, +.info-card ul { + margin: 0; + padding-left: 18px; +} +.info-card li { + margin: 2px 0; +} + +.main-workspace { + gap: 18px !important; + align-items: stretch !important; +} + +.prompt-column, +.synthesis-column { + gap: 14px !important; +} + +.control-row, +.settings-slider-row { + gap: 14px !important; +} + +.settings-card { + margin-top: 2px !important; +} + +.generate-button { + margin-top: 2px !important; + width: 100% !important; + box-sizing: border-box !important; + flex: 0 0 auto !important; + min-height: 44px !important; + padding-top: 10px !important; + padding-bottom: 10px !important; + font-size: 1rem !important; + font-weight: 600 !important; +} + +.output-audio { + flex: 0 0 auto !important; + min-height: 190px !important; +} +.output-audio audio { + width: 100% !important; +} + +@media (max-width: 768px) { + .gradio-container { + width: calc(100vw - 20px) !important; + } + +} + +""" + + +def build_playground_theme(gr): + return gr.themes.Soft( + primary_hue="slate", + secondary_hue="slate", + neutral_hue="slate", + radius_size="md", + text_size="md", + spacing_size="md", + font=[gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"], + ) + + +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: + parser = argparse.ArgumentParser(description="dots.tts Gradio app.") + parser.add_argument("--host", default=DEFAULT_HOST, help="Server host") + parser.add_argument("--port", type=int, default=DEFAULT_PORT, help="Server port") + parser.add_argument( + "--execution-mode", + choices=("generate", "generate_stream"), + default=DEFAULT_EXECUTION_MODE, + help="Runtime execution mode fixed for the app", + ) + parser.add_argument( + "--precision", + default=DEFAULT_PRECISION, + help="Inference precision fixed for the app runtime", + ) + parser.add_argument( + "--optimize", + action="store_true", + help="Enable runtime optimize acceleration", + ) + parser.add_argument( + "--model-name-or-path", + default=None, + help="Default model directory or Hugging Face repo id", + ) + parser.add_argument( + "--output-dir", + default=str(DEFAULT_OUTPUT_DIR), + help="Directory for generated wav outputs", + ) + parser.add_argument( + "--log-file", + default=str(DEFAULT_LOG_FILE), + help="Path to the Gradio log file", + ) + parser.add_argument( + "--output-retention-count", + type=int, + default=DEFAULT_OUTPUT_RETENTION, + help="Maximum number of generated wav files to keep", + ) + parser.add_argument( + "--max-generate-length", + type=int, + default=DEFAULT_MAX_GENERATE_LENGTH, + help="Maximum generation schedule length fixed for the app runtime", + ) + parser.add_argument( + "--default-prompt-name", + default=DEFAULT_PROMPT_NAME, + help="Default built-in voice preset name", + ) + parser.add_argument( + "--default-precision", + default=DEFAULT_PRECISION, + choices=["bfloat16", "float32", "float16"], + help="Default precision selected in the UI", + ) + parser.add_argument( + "--default-num-steps", + type=int, + default=DEFAULT_NUM_STEPS, + help="Default Num Steps selected in the UI", + ) + parser.add_argument( + "--default-guidance-scale", + type=float, + default=DEFAULT_GUIDANCE_SCALE, + help="Default Guidance Scale selected in the UI", + ) + parser.add_argument( + "--default-speaker-scale", + type=float, + default=DEFAULT_SPEAKER_SCALE, + help="Default Speaker Scale selected in the UI", + ) + parser.add_argument( + "--default-max-generate-length", + type=int, + default=DEFAULT_MAX_GENERATE_LENGTH, + help="Default Max Generate Length selected in the UI", + ) + parser.add_argument( + "--skip-warmup", + action="store_true", + help="Start the Gradio server without running an initial synthesis warmup.", + ) + return parser.parse_args(argv) + + +def build_startup_config_panel(gr, app_config) -> None: + with gr.Accordion("启动固定参数", open=False): + gr.Markdown("只读。修改这部分需要重启服务并传入新的启动参数。") + gr.Textbox( + label="Model", + value=app_config.default_model_name_or_path, + interactive=False, + ) + with gr.Row(): + gr.Textbox( + label="Execution Mode", + value=app_config.execution_mode, + interactive=False, + ) + gr.Textbox( + label="Precision", + value=app_config.precision, + interactive=False, + ) + with gr.Row(): + gr.Number( + label="Max Generate Length", + value=app_config.max_generate_length, + precision=0, + interactive=False, + ) + gr.Checkbox( + label="Optimize", + value=app_config.optimize, + interactive=False, + ) + + +def build_demo(gr, app_config, app_service) -> "gr.Blocks": + from apps.gradio.service import ( + GRADIO_SYNTHESIS_MODE_CHOICES, + SynthesisRequest, + build_prompt_choice_items, + resolve_prompt_selection, + ) + + def select_prompt_preset(prompt_name: str): + audio_path, prompt_text = resolve_prompt_selection( + prompt_name, + app_config.prompt_presets, + ) + return audio_path, prompt_text + + def run_synthesis( + text: str, + synthesis_mode: str, + prompt_audio_path: str | None, + prompt_text: str, + ode_method: str, + num_steps: float, + guidance_scale: float, + speaker_scale: float, + normalize_text: bool, + seed: float, + ): + resolved_synthesis_mode = synthesis_mode if DEBUG_GRADIO_ENABLED else "tts" + request = SynthesisRequest( + model_name_or_path=app_config.default_model_name_or_path, + text=text, + prompt_audio_path=prompt_audio_path, + prompt_text=prompt_text, + execution_mode=app_config.execution_mode, + template_name=resolved_synthesis_mode, + ode_method=ode_method, + num_steps=int(num_steps), + guidance_scale=float(guidance_scale), + speaker_scale=float(speaker_scale), + normalize_text=normalize_text, + seed=int(seed), + ) + result = app_service.generate(request) + return result.audio_path, result.metrics + + show_prompt_preset = bool(app_config.prompt_presets) + + with gr.Blocks(title="dots.tts") as demo: + gr.HTML( + "\n" + + """ +
+

dots.tts

+

Fully-continuous Autoregressive TTS · 48 kHz · Voice Cloning

+
+ """, + ) + + gr.HTML( + """ +
+ 使用说明 · Instructions +
    +
  1. 上传参考音频并填写对应转写文本 · Upload prompt audio and fill in its transcript.
  2. +
  3. 在文本框中输入要合成的内容 · Enter the text to synthesize.
  4. +
  5. 点击 Generate 合成声音 · Click Generate to synthesize speech.
  6. +
+
+ """, + ) + + with gr.Row(equal_height=True, elem_classes="main-workspace"): + with gr.Column(scale=1, min_width=480, elem_classes="prompt-column"): + prompt_preset = gr.Dropdown( + label="音色 · Voice Preset", + choices=build_prompt_choice_items(app_config.prompt_presets), + value=app_config.default_prompt_name, + info="内置音色clone样本;选择后自动填入参考音频与转写。", + elem_id="voice-preset-dropdown", + elem_classes="strong-label", + visible=show_prompt_preset, + ) + prompt_audio_path = gr.Audio( + label="参考音频 · Prompt Audio", + sources=["upload"], + type="filepath", + value=app_config.default_prompt_audio_path, + elem_classes="strong-label", + ) + prompt_text = gr.Textbox( + label="参考音频转写 · Prompt Text", + lines=5, + value=app_config.default_prompt_text, + placeholder="Prompt audio 对应的文本转写(continuation cloning 必填)", + elem_classes="strong-label", + ) + + with gr.Column(scale=1, min_width=480, elem_classes="synthesis-column"): + text = gr.Textbox( + label="待合成文本 · Text", + lines=5, + max_lines=8, + value=DEFAULT_INPUT_TEXT, + placeholder="输入待合成的文本", + elem_classes="strong-label", + ) + with gr.Accordion("⚙️ Settings", open=False, elem_classes="settings-card"): + with gr.Row(elem_classes="settings-slider-row"): + num_steps = gr.Slider( + label="Num Steps", + minimum=1, + maximum=32, + step=1, + value=app_config.default_num_steps, + ) + with gr.Row(elem_classes="settings-slider-row"): + guidance_scale = gr.Slider( + label="Guidance Scale", + minimum=1.0, + maximum=3.0, + step=0.1, + value=app_config.default_guidance_scale, + ) + with gr.Row(elem_classes="control-row"): + seed = gr.Number( + label="Seed", + value=DEFAULT_SEED, + precision=0, + scale=1, + min_width=180, + ) + normalize_text = gr.Checkbox( + label="Normalize Text", + value=False, + scale=1, + min_width=180, + ) + generate = gr.Button( + "Generate", + variant="primary", + size="lg", + elem_classes="generate-button", + ) + audio_out = gr.Audio( + label="生成音频 · Output", + type="filepath", + elem_classes="output-audio", + ) + + if DEBUG_GRADIO_ENABLED: + with gr.Accordion("Debug", open=False): + synthesis_mode = gr.Dropdown( + label="SynthesisMode", + choices=list(GRADIO_SYNTHESIS_MODE_CHOICES), + value="tts", + info="选择合成模式;界面显示名会自动映射到 runtime 对应模板。", + ) + ode_method = gr.Textbox( + label="ODE Method", + value=DEFAULT_ODE_METHOD, + lines=1, + ) + speaker_scale = gr.Slider( + label="Speaker Scale", + minimum=0.0, + maximum=3.0, + step=0.1, + value=app_config.default_speaker_scale, + info="说话人 x-vector 强度", + ) + metrics = gr.JSON(label="Metrics", value=app_service.metadata()) + build_startup_config_panel(gr, app_config) + else: + synthesis_mode = gr.State(value="tts") + ode_method = gr.State(value=DEFAULT_ODE_METHOD) + speaker_scale = gr.State(value=app_config.default_speaker_scale) + metrics = gr.State(value={}) + + generate.click( + fn=run_synthesis, + inputs=[ + text, + synthesis_mode, + prompt_audio_path, + prompt_text, + ode_method, + num_steps, + guidance_scale, + speaker_scale, + normalize_text, + seed, + ], + outputs=[audio_out, metrics], + concurrency_limit=1, + ) + prompt_preset.change( + fn=select_prompt_preset, + inputs=[prompt_preset], + outputs=[prompt_audio_path, prompt_text], + concurrency_limit=1, + ) + + return demo.queue(default_concurrency_limit=1, max_size=8) + + +def main() -> None: + args = parse_args() + import gradio as gr + from loguru import logger + + from apps.gradio.service import GradioAppService, build_gradio_app_config + from dots_tts.utils.logging import configure_logging + + configure_logging(log_file=args.log_file) + logger.info( + "Gradio app starting: host={} port={} model_name_or_path={} output_dir={} " + "log_file={} output_retention_count={} max_generate_length={} execution_mode={} precision={} optimize={} " + "default_prompt_name={} skip_warmup={}", + args.host, + args.port, + args.model_name_or_path, + args.output_dir, + args.log_file, + args.output_retention_count, + args.max_generate_length, + args.execution_mode, + args.precision, + args.optimize, + args.default_prompt_name, + args.skip_warmup, + ) + app_config = build_gradio_app_config( + host=args.host, + port=args.port, + execution_mode=args.execution_mode, + precision=args.precision, + optimize=args.optimize, + model_name_or_path=args.model_name_or_path, + output_dir=Path(args.output_dir), + output_retention_count=args.output_retention_count, + max_generate_length=args.max_generate_length, + default_prompt_name=args.default_prompt_name, + default_precision=args.default_precision, + default_num_steps=args.default_num_steps, + default_guidance_scale=args.default_guidance_scale, + default_speaker_scale=args.default_speaker_scale, + default_max_generate_length=args.default_max_generate_length, + ) + app_service = GradioAppService(app_config) + if args.skip_warmup: + logger.info("Gradio app warmup skipped by --skip-warmup.") + else: + warmup_metrics = app_service.warmup() + logger.info("Gradio app warmup metrics: {}", warmup_metrics) + demo = build_demo(gr, app_config, app_service) + logger.info( + "Gradio app ready: host={} port={} execution_mode={} precision={} optimize={} default_model_name_or_path={}", + app_config.host, + app_config.port, + app_config.execution_mode, + app_config.precision, + app_config.optimize, + app_config.default_model_name_or_path, + ) + demo.launch( + server_name=app_config.host, + server_port=app_config.port, + theme=build_playground_theme(gr), + css=PLAYGROUND_CSS, + ) + + +if __name__ == "__main__": + main() diff --git a/apps/gradio/constants.py b/apps/gradio/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..f60261f89f7ec298736b7c8cd0395cff1b1993b5 --- /dev/null +++ b/apps/gradio/constants.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[2] +DEFAULT_HOST = "0.0.0.0" +DEFAULT_PORT = 7860 +DEFAULT_OUTPUT_DIR = REPO_ROOT / "apps" / "gradio" / "outputs" +DEFAULT_LOG_FILE = REPO_ROOT / "apps" / "gradio" / "gradio.log" +DEFAULT_PROMPTS_DIR = REPO_ROOT / "apps" / "gradio" / "default_prompts" +DEFAULT_PROMPT_SOURCE_DIR = DEFAULT_PROMPTS_DIR +DEFAULT_PROMPT_MAPPING_FILE = DEFAULT_PROMPTS_DIR / "prompt_text" +DEFAULT_OUTPUT_RETENTION = 20 +DEFAULT_EXECUTION_MODE = "generate_stream" +DEFAULT_PRECISION = "bfloat16" +DEFAULT_ODE_METHOD = "euler" +DEFAULT_NUM_STEPS = 10 +DEFAULT_GUIDANCE_SCALE = 1.2 +DEFAULT_SPEAKER_SCALE = 1.5 +DEFAULT_MAX_GENERATE_LENGTH = 500 +DEFAULT_SEED = 42 +DEFAULT_INPUT_TEXT = "" +DEFAULT_WARMUP_TEXT = "dots.tts is a 2B-parameter fully continuous, end-to-end autoregressive (AR) text-to-speech system. The backbone pairs a semantic encoder, an LLM, and an autoregressive flow-matching acoustic head over a 48 kHz AudioVAE" +DEFAULT_PROMPT_NAME = "male_zh" +DEFAULT_PROMPT_NONE = "__none__" +PROMPT_AUDIO_SUFFIXES = (".wav", ".mp3", ".flac", ".m4a", ".ogg") diff --git a/apps/gradio/default_prompts/prompt_text b/apps/gradio/default_prompts/prompt_text new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/apps/gradio/languages.py b/apps/gradio/languages.py new file mode 100644 index 0000000000000000000000000000000000000000..ffea99232718c44b4642c81485722476413be613 --- /dev/null +++ b/apps/gradio/languages.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +SUPPORTED_LANGUAGE_CODE_BY_NAME = { + "普通话": "ZH", + "粤语": "口音:粤语", + "北京话": "口音:北京官话", + "东北话": "口音:东北话", + "四川话": "口音:四川话", + "闽南话": "口音:闽南话", + "吴语": "口音:吴语", + "英语": "EN", + "西班牙语": "ES", + "印地语": "HI", + "阿拉伯语": "AR", + "孟加拉语": "BN", + "葡萄牙语": "PT", + "俄语": "RU", + "日语": "JA", + "法语": "FR", + "德语": "DE", + "韩语": "KO", + "意大利语": "IT", + "土耳其语": "TR", + "越南语": "VI", + "印尼语": "ID", + "乌尔都语": "UR", + "波斯语": "FA", + "泰米尔语": "TA", + "泰卢固语": "TE", + "菲律宾语": "FIL", + "马来语": "MS", + "旁遮普语": "PA", + "马拉地语": "MR", + "古吉拉特语": "GU", + "马拉雅拉姆语": "ML", + "卡纳达语": "KN", + "波兰语": "PL", + "乌克兰语": "UK", + "荷兰语": "NL", + "泰语": "TH", + "罗马尼亚语": "RO", + "斯瓦希里语": "SW", + "希伯来语": "HE", + "捷克语": "CS", + "希腊语": "EL", + "匈牙利语": "HU", + "瑞典语": "SV", + "丹麦语": "DA", + "芬兰语": "FI", + "书面挪威语": "NB", + "斯洛伐克语": "SK", + "斯洛文尼亚语": "SL", + "塞尔维亚语": "SR", + "波斯尼亚语": "BS", + "克罗地亚语": "HR", + "保加利亚语": "BG", + "马其顿语": "MK", + "立陶宛语": "LT", + "拉脱维亚语": "LV", + "爱沙尼亚语": "ET", + "冰岛语": "IS", + "爱尔兰语": "GA", + "威尔士语": "CY", + "加泰罗尼亚语": "CA", + "加利西亚语": "GL", + "奥克语": "OC", + "阿斯图里亚斯语": "AST", + "尼泊尔语": "NE", + "信德语": "SD", + "奥里亚语": "OR", + "阿萨姆语": "AS", + "普什图语": "PS", + "缅甸语": "MY", + "高棉语": "KM", + "老挝语": "LO", + "哈萨克语": "KK", + "乌兹别克语": "UZ", + "吉尔吉斯语": "KY", + "塔吉克语": "TG", + "阿塞拜疆语": "AZ", + "格鲁吉亚语": "KA", + "亚美尼亚语": "HY", + "白俄罗斯语": "BE", + "卢森堡语": "LB", + "马耳他语": "MT", + "毛利语": "MI", + "南非荷兰语": "AF", + "祖鲁语": "ZU", + "科萨语": "XH", + "约鲁巴语": "YO", + "豪萨语": "HA", + "伊博语": "IG", + "阿姆哈拉语": "AM", + "奥罗莫语": "OM", + "北索托语": "NSO", + "尼扬贾语": "NY", + "修纳语": "SN", + "索马里语": "SO", + "卢干达语": "LG", + "林加拉语": "LN", + "卢奥语": "LUO", + "坎巴语": "KAM", + "翁本杜语": "UMB", + "富拉语": "FF", + "沃洛夫语": "WO", + "中库尔德语": "CKB", + "宿务语": "CEB", + "佛得角克里奥尔语": "KEA", + "蒙古语": "MN", + "爪哇语": "JV", +} + + +def build_language_choice_items() -> list[tuple[str, str]]: + return [("不指定", ""), *[(name, code) for name, code in SUPPORTED_LANGUAGE_CODE_BY_NAME.items()]] diff --git a/apps/gradio/service.py b/apps/gradio/service.py new file mode 100644 index 0000000000000000000000000000000000000000..1a0b43db92b4c1356fc6cec189d6935983df8adb --- /dev/null +++ b/apps/gradio/service.py @@ -0,0 +1,773 @@ +from __future__ import annotations + +import shutil +import sys +import threading +import time +import uuid +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Literal + +REPO_ROOT = Path(__file__).resolve().parents[2] +SRC_ROOT = REPO_ROOT / "src" + +for import_root in (REPO_ROOT, SRC_ROOT): + import_root_str = str(import_root) + if import_root_str not in sys.path: + sys.path.insert(0, import_root_str) + +import soundfile as sf # noqa: E402 +import torch # noqa: E402 +from loguru import logger # noqa: E402 + +from apps.gradio.constants import ( # noqa: E402 + DEFAULT_EXECUTION_MODE, + DEFAULT_GUIDANCE_SCALE, + DEFAULT_HOST, + DEFAULT_MAX_GENERATE_LENGTH, + DEFAULT_NUM_STEPS, + DEFAULT_ODE_METHOD, + DEFAULT_OUTPUT_DIR, + DEFAULT_OUTPUT_RETENTION, + DEFAULT_PORT, + DEFAULT_PRECISION, + DEFAULT_PROMPT_MAPPING_FILE, + DEFAULT_PROMPT_NAME, + DEFAULT_PROMPT_NONE, + DEFAULT_PROMPT_SOURCE_DIR, + DEFAULT_PROMPTS_DIR, + DEFAULT_SEED, + DEFAULT_SPEAKER_SCALE, + DEFAULT_WARMUP_TEXT, + PROMPT_AUDIO_SUFFIXES, +) +from apps.gradio.languages import ( # noqa: E402 + SUPPORTED_LANGUAGE_CODE_BY_NAME, + build_language_choice_items, +) +from dots_tts.runtime import DotsTtsRuntime # noqa: E402 +from dots_tts.utils.util import seed_everything # noqa: E402 + +ExecutionMode = Literal["generate", "generate_stream"] +GRADIO_SYNTHESIS_MODE_CHOICES = ( + ("tts", "tts"), + ("instruct_tts", "instruction_tts"), + ("instruct_tts_general", "text_to_audio"), +) +GRADIO_SYNTHESIS_MODE_TEMPLATE_NAMES = tuple( + value for _, value in GRADIO_SYNTHESIS_MODE_CHOICES +) + + +@dataclass(frozen=True) +class PromptPreset: + name: str + audio_path: str + prompt_text: str + + +def _is_prompt_asset(path: Path) -> bool: + return path.is_file() and ( + path.name == "prompt_text" or path.suffix.lower() in PROMPT_AUDIO_SUFFIXES + ) + + +def sync_default_prompt_library( + source_dir: Path = DEFAULT_PROMPT_SOURCE_DIR, + target_dir: Path = DEFAULT_PROMPTS_DIR, +) -> None: + source_dir = Path(source_dir) + if not source_dir.is_dir(): + logger.info( + "Prompt library sync skipped: source_dir={} does not exist.", + source_dir, + ) + return + + target_dir = Path(target_dir) + target_dir.mkdir(parents=True, exist_ok=True) + logger.info( + "Prompt library sync started: source_dir={} target_dir={}", + source_dir, + target_dir, + ) + + source_assets = { + asset.name: asset for asset in sorted(source_dir.iterdir()) if _is_prompt_asset(asset) + } + copied_count = 0 + for asset_name, source_asset in source_assets.items(): + target_asset = target_dir / asset_name + if ( + not target_asset.exists() + or target_asset.stat().st_size != source_asset.stat().st_size + or target_asset.stat().st_mtime_ns != source_asset.stat().st_mtime_ns + ): + shutil.copy2(source_asset, target_asset) + copied_count += 1 + + removed_count = 0 + for target_asset in sorted(target_dir.iterdir()): + if _is_prompt_asset(target_asset) and target_asset.name not in source_assets: + target_asset.unlink(missing_ok=True) + removed_count += 1 + logger.info( + "Prompt library sync completed: copied_assets={} removed_assets={} " + "available_assets={}", + copied_count, + removed_count, + len(source_assets), + ) + + +def _load_prompt_text_map(mapping_file: Path) -> dict[str, str]: + if not mapping_file.is_file(): + return {} + + prompt_text_map: dict[str, str] = {} + with mapping_file.open(encoding="utf-8") as file_obj: + for raw_line in file_obj: + line = raw_line.strip() + if not line or line.startswith("#") or "|" not in line: + continue + name, text = line.split("|", 1) + prompt_text_map[name.strip()] = text.strip() + return prompt_text_map + + +def discover_prompt_presets( + prompts_dir: Path = DEFAULT_PROMPTS_DIR, + mapping_file: Path = DEFAULT_PROMPT_MAPPING_FILE, +) -> tuple[PromptPreset, ...]: + prompts_dir = Path(prompts_dir) + if not prompts_dir.is_dir(): + return () + + prompt_text_map = _load_prompt_text_map(Path(mapping_file)) + prompt_audio_paths = [ + audio_path + for audio_path in sorted(prompts_dir.iterdir(), key=lambda path: (path.stem == "child", path.stem)) + if audio_path.is_file() and audio_path.suffix.lower() in PROMPT_AUDIO_SUFFIXES + ] + return tuple( + PromptPreset( + name=audio_path.stem, + audio_path=str(audio_path.resolve()), + prompt_text=prompt_text_map.get(audio_path.stem, ""), + ) + for audio_path in prompt_audio_paths + ) + + +def build_prompt_choice_items( + prompt_presets: tuple[PromptPreset, ...], +) -> list[tuple[str, str]]: + return [("No Preset", DEFAULT_PROMPT_NONE), *[(preset.name, preset.name) for preset in prompt_presets]] + + +def resolve_default_prompt_selection( + prompt_presets: tuple[PromptPreset, ...], + default_prompt_name: str = DEFAULT_PROMPT_NAME, +) -> tuple[str, str | None, str]: + if not prompt_presets: + return DEFAULT_PROMPT_NONE, None, "" + + preset_by_name = {preset.name: preset for preset in prompt_presets} + selected_name = default_prompt_name if default_prompt_name in preset_by_name else prompt_presets[0].name + selected_preset = preset_by_name[selected_name] + return selected_name, selected_preset.audio_path, selected_preset.prompt_text + + +def resolve_prompt_selection( + prompt_name: str, + prompt_presets: tuple[PromptPreset, ...], +) -> tuple[str | None, str]: + if prompt_name == DEFAULT_PROMPT_NONE: + return None, "" + + for preset in prompt_presets: + if preset.name == prompt_name: + return preset.audio_path, preset.prompt_text + return None, "" + + +def discover_local_model_choices(repo_root: Path = REPO_ROOT) -> list[str]: + model_root = Path(repo_root) / "pretrained_models" + if not model_root.is_dir(): + return [] + return sorted( + path.relative_to(repo_root).as_posix() + for path in model_root.glob("**/model") + if path.is_dir() + ) + + +def resolve_model_name_or_path(model_name_or_path: str, repo_root: Path = REPO_ROOT) -> str: + normalized = model_name_or_path.strip() + if not normalized: + raise ValueError("model_name_or_path 不能为空。") + + direct_path = Path(normalized).expanduser() + if direct_path.exists(): + return str(direct_path.resolve()) + + repo_relative_path = Path(repo_root) / normalized + if repo_relative_path.exists(): + return str(repo_relative_path.resolve()) + + return normalized + + +def default_model_name_or_path(repo_root: Path = REPO_ROOT) -> str: + discovered = discover_local_model_choices(repo_root=repo_root) + if not discovered: + return "" + return discovered[0] + + +@dataclass(frozen=True) +class GradioAppConfig: + host: str + port: int + execution_mode: ExecutionMode + precision: str + optimize: bool + output_dir: Path + prompts_dir: Path + output_retention_count: int + max_generate_length: int + default_model_name_or_path: str + prompt_presets: tuple[PromptPreset, ...] + default_prompt_name: str + default_prompt_audio_path: str | None + default_prompt_text: str + default_precision: str + default_num_steps: int + default_guidance_scale: float + default_speaker_scale: float + default_max_generate_length: int + local_model_choices: tuple[str, ...] + repo_root: Path = REPO_ROOT + + +def build_gradio_app_config( + *, + host: str = DEFAULT_HOST, + port: int = DEFAULT_PORT, + execution_mode: ExecutionMode = DEFAULT_EXECUTION_MODE, + precision: str = DEFAULT_PRECISION, + optimize: bool = False, + output_dir: Path = DEFAULT_OUTPUT_DIR, + output_retention_count: int = DEFAULT_OUTPUT_RETENTION, + max_generate_length: int = DEFAULT_MAX_GENERATE_LENGTH, + model_name_or_path: str | None = None, + default_prompt_name: str = DEFAULT_PROMPT_NAME, + default_precision: str = DEFAULT_PRECISION, + default_num_steps: int = DEFAULT_NUM_STEPS, + default_guidance_scale: float = DEFAULT_GUIDANCE_SCALE, + default_speaker_scale: float = DEFAULT_SPEAKER_SCALE, + default_max_generate_length: int = DEFAULT_MAX_GENERATE_LENGTH, + repo_root: Path = REPO_ROOT, + prompts_dir: Path = DEFAULT_PROMPTS_DIR, + prompt_source_dir: Path = DEFAULT_PROMPT_SOURCE_DIR, +) -> GradioAppConfig: + sync_default_prompt_library( + source_dir=prompt_source_dir, + target_dir=prompts_dir, + ) + discovered_models = discover_local_model_choices(repo_root=repo_root) + prompt_presets = discover_prompt_presets( + prompts_dir=prompts_dir, + mapping_file=prompts_dir / "prompt_text", + ) + resolved_default_prompt_name, default_prompt_audio_path, default_prompt_text = ( + resolve_default_prompt_selection( + prompt_presets, + default_prompt_name=default_prompt_name, + ) + ) + selected_model_name_or_path = ( + model_name_or_path.strip() + if model_name_or_path is not None + else default_model_name_or_path(repo_root=repo_root) + ) + if not selected_model_name_or_path: + raise ValueError("No default model found. Please pass --model-name-or-path.") + if execution_mode not in ("generate", "generate_stream"): + raise ValueError(f"Unsupported execution_mode: {execution_mode}") + resolved_max_generate_length = int(max_generate_length) + if resolved_max_generate_length <= 0: + raise ValueError("max_generate_length must be positive.") + resolved_precision = precision.strip() or DEFAULT_PRECISION + logger.info( + "Gradio app config prepared: host={} port={} output_dir={} " + "output_retention_count={} max_generate_length={} execution_mode={} precision={} optimize={} " + "default_model_name_or_path={} prompt_preset_count={} language_count={} local_model_choice_count={}", + host, + port, + output_dir, + output_retention_count, + resolved_max_generate_length, + execution_mode, + resolved_precision, + bool(optimize), + selected_model_name_or_path, + len(prompt_presets), + len(SUPPORTED_LANGUAGE_CODE_BY_NAME), + len(discovered_models), + ) + return GradioAppConfig( + host=host, + port=int(port), + execution_mode=execution_mode, + precision=resolved_precision, + optimize=bool(optimize), + output_dir=Path(output_dir), + prompts_dir=Path(prompts_dir), + output_retention_count=int(output_retention_count), + max_generate_length=resolved_max_generate_length, + default_model_name_or_path=selected_model_name_or_path, + prompt_presets=prompt_presets, + default_prompt_name=resolved_default_prompt_name, + default_prompt_audio_path=default_prompt_audio_path, + default_prompt_text=default_prompt_text, + default_precision=default_precision, + default_num_steps=int(default_num_steps), + default_guidance_scale=float(default_guidance_scale), + default_speaker_scale=float(default_speaker_scale), + default_max_generate_length=int(default_max_generate_length), + local_model_choices=tuple(discovered_models), + repo_root=repo_root, + ) + + +@dataclass(frozen=True) +class SynthesisRequest: + model_name_or_path: str + text: str + prompt_audio_path: str | None = None + prompt_text: str | None = None + execution_mode: ExecutionMode = DEFAULT_EXECUTION_MODE + template_name: str = "tts" + language: str | None = None + ode_method: str = DEFAULT_ODE_METHOD + num_steps: int = DEFAULT_NUM_STEPS + guidance_scale: float = DEFAULT_GUIDANCE_SCALE + speaker_scale: float = DEFAULT_SPEAKER_SCALE + normalize_text: bool = False + seed: int = DEFAULT_SEED + + +@dataclass(frozen=True) +class SynthesisResult: + audio_path: str + metrics: dict[str, Any] + status: str + + +class GradioAppService: + def __init__(self, config: GradioAppConfig): + self.config = config + self.config.output_dir.mkdir(parents=True, exist_ok=True) + self._lock = threading.Lock() + self._runtime: DotsTtsRuntime | None = None + self._runtime_model_name_or_path: str | None = None + logger.info( + "Gradio service initialized: output_dir={} default_model_name_or_path={} " + "output_retention_count={} max_generate_length={} execution_mode={} precision={} optimize={}", + self.config.output_dir, + self.config.default_model_name_or_path, + self.config.output_retention_count, + self.config.max_generate_length, + self.config.execution_mode, + self.config.precision, + self.config.optimize, + ) + + def metadata(self) -> dict[str, Any]: + return { + "repo_root": str(self.config.repo_root), + "default_model_name_or_path": self.config.default_model_name_or_path, + "local_model_choices": list(self.config.local_model_choices), + "prompts_dir": str(self.config.prompts_dir), + "prompt_preset_names": [preset.name for preset in self.config.prompt_presets], + "default_prompt_name": self.config.default_prompt_name, + "output_dir": str(self.config.output_dir), + "output_retention_count": self.config.output_retention_count, + "configured_max_generate_length": self.config.max_generate_length, + "configured_execution_mode": self.config.execution_mode, + "configured_precision": self.config.precision, + "optimize": self.config.optimize, + "loaded_model_name_or_path": self._runtime_model_name_or_path, + "loaded_max_generate_length": ( + self.config.max_generate_length if self._runtime is not None else None + ), + "loaded_precision": ( + self.config.precision if self._runtime is not None else None + ), + "model_loaded": self._runtime is not None, + "host": self.config.host, + "port": self.config.port, + "default_precision": self.config.default_precision, + "default_num_steps": self.config.default_num_steps, + "default_guidance_scale": self.config.default_guidance_scale, + "default_speaker_scale": self.config.default_speaker_scale, + "default_max_generate_length": self.config.default_max_generate_length, + "supported_languages": build_language_choice_items()[1:], + "supported_template_names": list(GRADIO_SYNTHESIS_MODE_TEMPLATE_NAMES), + } + + def _get_runtime( + self, + model_name_or_path: str, + ) -> tuple[DotsTtsRuntime, str]: + resolved_model_name_or_path = resolve_model_name_or_path( + model_name_or_path, + repo_root=self.config.repo_root, + ) + if ( + self._runtime is None + or self._runtime_model_name_or_path != resolved_model_name_or_path + ): + logger.info( + "Gradio runtime cache miss: requested_model={} resolved_model={} " + "max_generate_length={} execution_mode={} precision={} optimize={}", + model_name_or_path, + resolved_model_name_or_path, + self.config.max_generate_length, + self.config.execution_mode, + self.config.precision, + self.config.optimize, + ) + self._runtime = DotsTtsRuntime.from_pretrained( + resolved_model_name_or_path, + precision=self.config.precision, + optimize=self.config.optimize, + max_generate_length=self.config.max_generate_length, + ) + self._runtime_model_name_or_path = resolved_model_name_or_path + else: + logger.info( + "Gradio runtime cache hit: requested_model={} resolved_model={} " + "max_generate_length={} execution_mode={} precision={} optimize={}", + model_name_or_path, + resolved_model_name_or_path, + self.config.max_generate_length, + self.config.execution_mode, + self.config.precision, + self.config.optimize, + ) + return self._runtime, resolved_model_name_or_path + + def _build_stream_request_id( + self, + runtime: DotsTtsRuntime, + request: SynthesisRequest, + ) -> str: + normalized_text, normalized_language = runtime._process_text( # noqa: SLF001 + request.text, + language=request.language, + normalize=request.normalize_text, + ) + normalized_prompt_text = runtime._process_prompt_text( # noqa: SLF001 + request.prompt_text, + language=normalized_language, + ) + if normalized_language is not None and not normalized_prompt_text: + from dots_tts.utils.text import attach_language_tag # noqa: PLC0415 + + normalized_text = attach_language_tag( + normalized_text, + normalized_language, + ) + request_id_kwargs = { + "text": normalized_text, + "prompt_audio_path": request.prompt_audio_path, + "prompt_text": normalized_prompt_text, + "template_name": request.template_name, + } + if normalized_language is not None: + request_id_kwargs["language"] = normalized_language + return runtime._build_request_id( # noqa: SLF001 + **request_id_kwargs, + ) + + @staticmethod + def _build_runtime_generate_kwargs(request: SynthesisRequest) -> dict[str, Any]: + runtime_kwargs: dict[str, Any] = { + "text": request.text, + "prompt_audio_path": request.prompt_audio_path, + "prompt_text": request.prompt_text, + "template_name": request.template_name, + "ode_method": request.ode_method, + "num_steps": request.num_steps, + "guidance_scale": request.guidance_scale, + "speaker_scale": request.speaker_scale, + "normalize_text": request.normalize_text, + } + if request.language is not None: + runtime_kwargs["language"] = request.language + return runtime_kwargs + + def _run_stream_generation( + self, + runtime: DotsTtsRuntime, + request: SynthesisRequest, + ) -> dict[str, Any]: + start_time = time.time() + chunks = [ + chunk.detach().float().cpu() + for chunk in runtime.generate_stream( + **self._build_runtime_generate_kwargs(request) + ) + ] + if not chunks: + raise ValueError("流式生成未返回任何音频块。") + + audio = torch.cat(chunks, dim=-1) + elapsed_seconds = time.time() - start_time + audio_seconds = audio.shape[-1] / runtime.sample_rate + rtf = elapsed_seconds / audio_seconds if audio_seconds > 0 else float("inf") + return { + "fid": self._build_stream_request_id(runtime, request), + "audio": audio, + "sample_rate": runtime.sample_rate, + "time_used": elapsed_seconds, + "rtf": rtf, + "chunk_count": len(chunks), + } + + def warmup(self, text: str | None = None) -> dict[str, Any]: + warmup_text = (text or "").strip() or DEFAULT_WARMUP_TEXT.strip() + if not warmup_text: + raise ValueError("DEFAULT_WARMUP_TEXT 不能为空。") + + with self._lock: + logger.info( + "Gradio warmup requested: default_model_name_or_path={} execution_mode={} precision={} optimize={} seed={}", + self.config.default_model_name_or_path, + self.config.execution_mode, + self.config.precision, + self.config.optimize, + DEFAULT_SEED, + ) + try: + seed_everything(DEFAULT_SEED) + runtime, resolved_model_name_or_path = self._get_runtime( + self.config.default_model_name_or_path, + ) + warmup_request = SynthesisRequest( + model_name_or_path=self.config.default_model_name_or_path, + text=warmup_text, + execution_mode=self.config.execution_mode, + template_name="tts", + ode_method=DEFAULT_ODE_METHOD, + num_steps=self.config.default_num_steps, + guidance_scale=self.config.default_guidance_scale, + speaker_scale=self.config.default_speaker_scale, + normalize_text=False, + seed=DEFAULT_SEED, + ) + request_id = self._build_stream_request_id(runtime, warmup_request) + if self.config.execution_mode == "generate_stream": + result = self._run_stream_generation(runtime, warmup_request) + else: + start_time = time.time() + result = runtime.generate(**self._build_runtime_generate_kwargs(warmup_request)) + result["time_used"] = time.time() - start_time + result["chunk_count"] = 1 + audio_samples = int(result["audio"].shape[-1]) + except Exception: + logger.exception( + "Gradio warmup failed: default_model_name_or_path={}", + self.config.default_model_name_or_path, + ) + raise + audio_seconds = audio_samples / runtime.sample_rate + metrics = { + "request_id": request_id, + "execution_mode": self.config.execution_mode, + "chunk_count": int(result["chunk_count"]), + "resolved_model_name_or_path": resolved_model_name_or_path, + "sample_rate": runtime.sample_rate, + "elapsed_seconds": round(float(result["time_used"]), 3), + "audio_seconds": round(float(audio_seconds), 3), + "rtf": round(float(result["rtf"]), 4), + "seed": DEFAULT_SEED, + "text": warmup_text, + } + logger.info( + "Gradio warmup ready: request_id={} execution_mode={} resolved_model_name_or_path={}", + metrics["request_id"], + metrics["execution_mode"], + metrics["resolved_model_name_or_path"], + ) + return metrics + + def _normalize_request(self, request: SynthesisRequest) -> SynthesisRequest: + normalized_text = request.text.strip() + if not normalized_text: + raise ValueError("text 不能为空。") + + normalized_prompt_audio_path = request.prompt_audio_path or None + normalized_prompt_text = (request.prompt_text or "").strip() or None + if normalized_prompt_text and not normalized_prompt_audio_path: + raise ValueError("prompt_text requires prompt_audio_path.") + normalized_template_name = request.template_name.strip() or "tts" + if normalized_template_name not in GRADIO_SYNTHESIS_MODE_TEMPLATE_NAMES: + raise ValueError( + f"Unsupported template_name={normalized_template_name!r}. " + f"Expected one of {list(GRADIO_SYNTHESIS_MODE_TEMPLATE_NAMES)}." + ) + normalized_language = (request.language or "").strip() or None + supported_language_codes = set(SUPPORTED_LANGUAGE_CODE_BY_NAME.values()) + if ( + normalized_language is not None + and normalized_language not in supported_language_codes + ): + raise ValueError( + f"Unsupported language={normalized_language!r}. " + f"Expected one of {sorted(supported_language_codes)}." + ) + + resolved_seed = int(request.seed) + return SynthesisRequest( + model_name_or_path=request.model_name_or_path.strip(), + text=normalized_text, + prompt_audio_path=normalized_prompt_audio_path, + prompt_text=normalized_prompt_text, + execution_mode=request.execution_mode, + template_name=normalized_template_name, + language=normalized_language, + ode_method=request.ode_method.strip() or DEFAULT_ODE_METHOD, + num_steps=int(request.num_steps), + guidance_scale=float(request.guidance_scale), + speaker_scale=float(request.speaker_scale), + normalize_text=bool(request.normalize_text), + seed=resolved_seed, + ) + + def _build_output_path(self) -> Path: + output_name = f"{time.strftime('%Y%m%d-%H%M%S')}-{uuid.uuid4().hex[:8]}.wav" + return self.config.output_dir / output_name + + def _cleanup_outputs(self) -> None: + if self.config.output_retention_count <= 0: + return + + wav_files = sorted( + self.config.output_dir.glob("*.wav"), + key=lambda path: path.stat().st_mtime, + reverse=True, + ) + removed_count = 0 + for stale_file in wav_files[self.config.output_retention_count :]: + stale_file.unlink(missing_ok=True) + removed_count += 1 + if removed_count > 0: + logger.info( + "Gradio output cleanup completed: removed_files={} retention_limit={}", + removed_count, + self.config.output_retention_count, + ) + + @staticmethod + def _waveform_to_numpy(audio: torch.Tensor): + waveform = audio.detach().float().cpu().squeeze() + if waveform.ndim == 0: + raise ValueError("生成音频为空。") + return waveform.numpy() + + def _write_audio(self, audio: torch.Tensor, sample_rate: int) -> str: + output_path = self._build_output_path() + logger.info( + "Writing synthesized audio: output_path={} sample_rate={} samples={}", + output_path, + sample_rate, + audio.shape[-1], + ) + sf.write(output_path, self._waveform_to_numpy(audio), sample_rate) + self._cleanup_outputs() + logger.info("Synthesized audio written: output_path={}", output_path) + return str(output_path) + + def generate(self, request: SynthesisRequest) -> SynthesisResult: + normalized_request = self._normalize_request(request) + + with self._lock: + try: + seed_everything(normalized_request.seed) + runtime, resolved_model_name_or_path = self._get_runtime( + normalized_request.model_name_or_path, + ) + logger.info( + "Gradio request accepted: resolved_model_name_or_path={} execution_mode={} seed={}", + resolved_model_name_or_path, + normalized_request.execution_mode, + normalized_request.seed, + ) + if normalized_request.execution_mode == "generate_stream": + result = self._run_stream_generation(runtime, normalized_request) + else: + result = runtime.generate( + **self._build_runtime_generate_kwargs(normalized_request) + ) + result["chunk_count"] = 1 + audio_path = self._write_audio(result["audio"], result["sample_rate"]) + except Exception: + logger.exception( + "Gradio request failed: model_name_or_path={} execution_mode={} text_len={} has_prompt_audio={} has_prompt_text={} template_name={} language={} " + "precision={} ode_method={} num_steps={} guidance_scale={} speaker_scale={} max_generate_length={} " + "normalize_text={} seed={}", + normalized_request.model_name_or_path, + normalized_request.execution_mode, + len(normalized_request.text), + bool(normalized_request.prompt_audio_path), + bool(normalized_request.prompt_text), + normalized_request.template_name, + normalized_request.language, + self.config.precision, + normalized_request.ode_method, + normalized_request.num_steps, + normalized_request.guidance_scale, + normalized_request.speaker_scale, + self.config.max_generate_length, + normalized_request.normalize_text, + normalized_request.seed, + ) + raise + audio_seconds = result["audio"].shape[-1] / result["sample_rate"] + metrics = { + "request_id": result["fid"], + "execution_mode": normalized_request.execution_mode, + "chunk_count": int(result["chunk_count"]), + "template_name": normalized_request.template_name, + "language": normalized_request.language, + "resolved_model_name_or_path": resolved_model_name_or_path, + "sample_rate": result["sample_rate"], + "elapsed_seconds": round(float(result["time_used"]), 3), + "audio_seconds": round(float(audio_seconds), 3), + "rtf": round(float(result["rtf"]), 4), + "seed": normalized_request.seed, + "output_path": audio_path, + } + logger.info( + "Gradio request output ready: request_id={} execution_mode={} resolved_model_name_or_path={} output_path={}", + metrics["request_id"], + metrics["execution_mode"], + metrics["resolved_model_name_or_path"], + metrics["output_path"], + ) + status = ( + f"完成:{Path(audio_path).name} | " + f"模式 {metrics['execution_mode']} | " + f"耗时 {metrics['elapsed_seconds']}s | " + f"音频 {metrics['audio_seconds']}s | " + f"RTF {metrics['rtf']}" + ) + return SynthesisResult( + audio_path=audio_path, + metrics=metrics, + status=status, + ) diff --git a/configs/dots_tts.yaml b/configs/dots_tts.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c25aa9927938c9200d4a9a80881d21d127c61ab2 --- /dev/null +++ b/configs/dots_tts.yaml @@ -0,0 +1,76 @@ +train_data: + train_audio_sample_rate: 48000 + audio_samples_per_llm_token: 7680 + sources: + - name: ljspeech_basic + weight: 1.0 + pipeline: basic + adapter: + class_name: JsonlManifestSourceAdapter + params: + manifest_path: downloaded_data/ljspeech_48khz_manifest_train.jsonl + shuffle: true + - name: ljspeech_interleave + weight: 1.0 + pipeline: interleave + adapter: + class_name: JsonlManifestSourceAdapter + params: + manifest_path: downloaded_data/ljspeech_48khz_manifest_train.jsonl + shuffle: true + # append other sources here if need + num_tokens_per_epoch: 2000000 + num_workers: 20 + pin_memory: true + max_audio_seconds_in_batch: 30.0 + max_text_tokens_in_batch: 2048 + max_samples_per_batch: null + bucketing_pool_size: 100 +val_data: + train_audio_sample_rate: 48000 + audio_samples_per_llm_token: 7680 + sources: + - name: ljspeech_valid_basic + weight: 1.0 + adapter: + class_name: JsonlManifestSourceAdapter + params: + manifest_path: downloaded_data/ljspeech_48khz_manifest_valid.jsonl + shuffle: false + pipeline: basic + - name: ljspeech_valid_interleave + weight: 1.0 + pipeline: interleave + adapter: + class_name: JsonlManifestSourceAdapter + params: + manifest_path: downloaded_data/ljspeech_48khz_manifest_valid.jsonl + shuffle: false + pipeline: interleave + # append other sources here if need + num_workers: 4 + pin_memory: true + max_audio_seconds_in_batch: 30.0 + max_text_tokens_in_batch: 2048 + max_samples_per_batch: null + bucketing_pool_size: 64 +train: + pretrained_model_path: pretrained_models/pretrain_cpt_decay/latest/model/ + output_dir: debug_train/run_003 + seed: 42 + learning_rate: 1.0e-05 + weight_decay: 0.01 + warmup_steps: 50 + max_train_steps: 500 + gradient_accumulation_steps: 2 + grad_clip_norm: 1 + save_interval: 500 + max_checkpoints_to_keep: 40 + log_interval: 10 + eval_interval: 100 + max_eval_batches: null + run_eval_on_start: false +loss: + ce_weight: 1.0 + fm_weight: 1.0 + eos_weight: 1.0 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..943217629151ffbaabaaeccdf124eb0858cc2c66 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,19 @@ +spaces>=0.40.1 +torch>=2.8.0 +torchaudio>=2.8.0 +transformers>=4.57.0 +huggingface-hub>=0.36.0 +gradio>=6.16.0 +loguru>=0.7.3 +langcodes[data]>=3.5.0 +einops>=0.8.1 +librosa>=0.11.0 +soundfile>=0.13.1 +numpy>=2.2.6 +pydantic>=2.12.5,<3 +PyYAML>=6.0.3 +safetensors>=0.8.0rc0 +torchdiffeq>=0.2.5 +tqdm>=4.67.1 +lingua-language-detector>=2.1.1 +WeTextProcessing>=1.0.4 diff --git a/src/dots_tts/__init__.py b/src/dots_tts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5b25576d5f2ecf89d144162771ebb197943801da --- /dev/null +++ b/src/dots_tts/__init__.py @@ -0,0 +1 @@ +"""dots.tts package.""" diff --git a/src/dots_tts/cli.py b/src/dots_tts/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..4c6e6bc5734ef1cc3eb903a08f50bfba0c66f2e3 --- /dev/null +++ b/src/dots_tts/cli.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +import argparse +from pathlib import Path + + +def parse_args(argv=None): + parser = argparse.ArgumentParser(description="dots.tts inference CLI.") + template_choices = ("tts", "instruction_tts", "text_to_audio", "tts_interleave") + parser.add_argument( + "--model-name-or-path", + required=True, + help="Local pretrained directory or Hugging Face repo id", + ) + parser.add_argument( + "--revision", default=None, help="Optional Hugging Face revision" + ) + parser.add_argument( + "--cache-dir", default=None, help="Optional Hugging Face cache dir" + ) + parser.add_argument("--text", type=str, required=True, help="Input text") + parser.add_argument("--output", default="output.wav", help="Output wav file path") + parser.add_argument( + "--precision", type=str, default="bfloat16", help="Inference precision" + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for inference.", + ) + parser.add_argument( + "--prompt-audio", type=str, default=None, help="Path to prompt audio" + ) + parser.add_argument( + "--prompt-text", type=str, default=None, help="Transcript of prompt audio" + ) + parser.add_argument( + "--language", + type=str, + default=None, + help="Language tag mode. Default: none. Supported values: none, auto_detect, or a language code/name such as EN/en/english/chinese.", + ) + parser.add_argument( + "--template-name", + choices=template_choices, + default=None, + help="Named template preset for generation.", + ) + parser.add_argument( + "--ode-method", type=str, default="euler", help="ODE solver method" + ) + parser.add_argument( + "--num-steps", type=int, default=10, help="Diffusion sampling steps" + ) + parser.add_argument( + "--guidance-scale", + type=float, + default=1.2, + help="Classifier-free guidance scale", + ) + parser.add_argument( + "--speaker-scale", + type=float, + default=1.5, + help="Scale applied to the reference speaker embedding", + ) + parser.add_argument( + "--max-generate-length", + type=int, + default=500, + help="Maximum total audio patch count (prompt + generated)", + ) + parser.add_argument( + "--normalize-text", + action="store_true", + help="Whether to normalize text before inference", + ) + parser.add_argument( + "--profile-inference", + action="store_true", + help="Collect per-module inference timing statistics", + ) + return parser.parse_args(argv) + + +def main(argv=None): + args = parse_args(argv) + import soundfile as sf + from loguru import logger + + from dots_tts.runtime import DotsTtsRuntime + from dots_tts.utils.logging import configure_logging + from dots_tts.utils.util import seed_everything + + configure_logging() + seed_everything(args.seed) + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + + logger.info( + "CLI command started: model={} output={} seed={}", + args.model_name_or_path, + output_path, + args.seed, + ) + + try: + runtime = DotsTtsRuntime.from_pretrained( + args.model_name_or_path, + revision=args.revision, + cache_dir=args.cache_dir, + precision=args.precision, + max_generate_length=args.max_generate_length, + ) + result = runtime.generate( + text=args.text, + prompt_audio_path=args.prompt_audio, + prompt_text=args.prompt_text, + language=args.language, + template_name=args.template_name, + ode_method=args.ode_method, + num_steps=args.num_steps, + guidance_scale=args.guidance_scale, + speaker_scale=args.speaker_scale, + normalize_text=args.normalize_text, + profile_inference=args.profile_inference, + ) + sf.write( + output_path, + result["audio"].float().cpu().squeeze().numpy(), + result["sample_rate"], + ) + except Exception: + logger.exception( + "CLI inference failed: model={} output={}", + args.model_name_or_path, + output_path, + ) + raise + + logger.info( + "CLI output written: request_id={} output={} sample_rate={} samples={}", + result["fid"], + output_path, + result["sample_rate"], + int(result["audio"].shape[-1]), + ) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/dots_tts/config/__init__.py b/src/dots_tts/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..56096f22e33373e8aa0dd3cbb6a5e2f99bdbca1b --- /dev/null +++ b/src/dots_tts/config/__init__.py @@ -0,0 +1 @@ +"""Configuration package.""" diff --git a/src/dots_tts/config/app.py b/src/dots_tts/config/app.py new file mode 100644 index 0000000000000000000000000000000000000000..fdf0ce2d3794feedf1bf1d36bcc05551df93907e --- /dev/null +++ b/src/dots_tts/config/app.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from pathlib import Path + +import yaml + +from dots_tts.config.base import StrictConfigBase +from dots_tts.config.data import DataConfig +from dots_tts.config.train import TrainConfig +from dots_tts.models.dots_tts.config import LossConfig + +DEFAULT_CONFIG_PATH = "configs/dots_tts.yaml" + + +class AppConfig(StrictConfigBase): + train_data: DataConfig + val_data: DataConfig | None = None + loss: LossConfig + train: TrainConfig + + @classmethod + def from_yaml(cls, config_path: str = DEFAULT_CONFIG_PATH) -> AppConfig: + with Path(config_path).open(encoding="utf-8") as fin: + raw_config = yaml.safe_load(fin) + return cls.model_validate(raw_config) + + +def load_config(config_path: str = DEFAULT_CONFIG_PATH) -> AppConfig: + return AppConfig.from_yaml(config_path) + + +__all__ = ["AppConfig", "DEFAULT_CONFIG_PATH", "load_config"] diff --git a/src/dots_tts/config/base.py b/src/dots_tts/config/base.py new file mode 100644 index 0000000000000000000000000000000000000000..b81ff06c3a149a616bc7a5be4575fbd3c75d19a5 --- /dev/null +++ b/src/dots_tts/config/base.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel, ConfigDict + + +class ConfigBase(BaseModel): + model_config = ConfigDict( + extra="allow", + validate_assignment=True, + arbitrary_types_allowed=True, + ) + + def get(self, key: str, default=None): + value = getattr(self, key, default) + if value is default: + return value + + fields_set = self.model_fields_set + if value is None and key not in fields_set: + return default + return value + + def to_dict(self) -> dict[str, Any]: + return self.model_dump(exclude_none=True) + + @classmethod + def _declared_field_names(cls) -> list[str]: + return [name for name in cls.model_fields if name != "model_config"] + + @classmethod + def _serialize_declared_value(cls, value): + if isinstance(value, ConfigBase): + return value.to_declared_dict() + if isinstance(value, list): + return [cls._serialize_declared_value(item) for item in value] + if isinstance(value, tuple): + return [cls._serialize_declared_value(item) for item in value] + if isinstance(value, dict): + return { + key: cls._serialize_declared_value(item) for key, item in value.items() + } + return value + + def to_declared_dict(self) -> dict[str, Any]: + data = {} + for name in self._declared_field_names(): + value = getattr(self, name, None) + if value is None: + continue + data[name] = self._serialize_declared_value(value) + return data + + +class StrictConfigBase(ConfigBase): + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + arbitrary_types_allowed=True, + ) + + +__all__ = ["ConfigBase", "StrictConfigBase"] diff --git a/src/dots_tts/config/data.py b/src/dots_tts/config/data.py new file mode 100644 index 0000000000000000000000000000000000000000..6035115b076bef81dc0bab131a43224f0060b141 --- /dev/null +++ b/src/dots_tts/config/data.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from typing import Any, Literal + +from pydantic import Field, model_validator + +from dots_tts.config.base import StrictConfigBase + +DEFAULT_SOURCE_ADAPTER_CLASS_NAME = "JsonlManifestSourceAdapter" + + +class SourceAdapterConfig(StrictConfigBase): + class_name: Literal["JsonlManifestSourceAdapter"] = ( + DEFAULT_SOURCE_ADAPTER_CLASS_NAME + ) + params: dict[str, Any] = Field(default_factory=dict) + + +class DataSourceConfig(StrictConfigBase): + name: str + weight: float = Field(default=1.0, gt=0.0) + pipeline: Literal["basic", "interleave"] = "basic" + adapter: SourceAdapterConfig = Field(default_factory=SourceAdapterConfig) + + +class DataConfig(StrictConfigBase): + sources: list[DataSourceConfig] + train_audio_sample_rate: int = Field(ge=1) + audio_samples_per_llm_token: int = Field(ge=1) + num_tokens_per_epoch: int | None = Field( + default=None, + ge=1, + description="Global token budget across all ranks for one training epoch.", + ) + num_workers: int = Field(default=0, ge=0) + pin_memory: bool = False + prefetch_factor: int = Field( + default=2, + ge=1, + description="Samples prefetched by each DataLoader worker.", + ) + max_audio_seconds_in_batch: float = Field(gt=0.0) + max_text_tokens_in_batch: int = Field(ge=1) + max_samples_per_batch: int | None = Field(default=None, ge=1) + bucketing_pool_size: int = Field(default=64, ge=1) + + @model_validator(mode="after") + def _validate_unique_source_names(self) -> "DataConfig": + counts: dict[str, int] = {} + for source in self.sources: + counts[source.name] = counts.get(source.name, 0) + 1 + duplicated = [name for name, count in counts.items() if count > 1] + if duplicated: + raise ValueError(f"Source names must be unique: {duplicated}") + return self + + +__all__ = [ + "DEFAULT_SOURCE_ADAPTER_CLASS_NAME", + "DataConfig", + "DataSourceConfig", + "SourceAdapterConfig", +] diff --git a/src/dots_tts/config/train.py b/src/dots_tts/config/train.py new file mode 100644 index 0000000000000000000000000000000000000000..b975ad7bebeebe9e408fad37981bab691407844d --- /dev/null +++ b/src/dots_tts/config/train.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from pydantic import Field + +from dots_tts.config.base import StrictConfigBase + + +class TrainConfig(StrictConfigBase): + pretrained_model_path: str + output_dir: str + seed: int = 42 + learning_rate: float + cfg_droprate: float = 0.0 + xvec_drop_rate: float = 0.5 + weight_decay: float = 0.01 + warmup_steps: int = 0 + max_train_steps: int + gradient_accumulation_steps: int = Field(default=1, ge=1) + grad_clip_norm: float = 1.0 + save_interval: int = Field(default=1000, ge=1) + max_checkpoints_to_keep: int = 10 + log_interval: int = Field(default=10, ge=1) + eval_interval: int | None = Field(default=None, ge=1) + max_eval_batches: int | None = None + run_eval_on_start: bool = False + + +__all__ = ["TrainConfig"] diff --git a/src/dots_tts/data/EXTENSION.md b/src/dots_tts/data/EXTENSION.md new file mode 100644 index 0000000000000000000000000000000000000000..660815f8e2c0379aa577496c5a67bb2b251de7a6 --- /dev/null +++ b/src/dots_tts/data/EXTENSION.md @@ -0,0 +1,124 @@ +# Data Source Extension Guide + +This document answers exactly one question: how to plug a new training data source into the current `dots_tts` data pipeline. + +If you only need to swap in a different JSONL manifest, no code changes are required. To support a new raw data format, you usually only need to add: + +- one **source adapter** +- optionally one **sample pipeline** + +## Data flow + +1. An **adapter** reads from the raw data source and yields raw samples. +2. A **pipeline** turns each raw sample into a training sample (1:1). +3. A **multi-source wrapper** handles mixing across sources and resume state. +4. `StreamingSampleDataset` / `DataLoader` pulls samples. +5. `OnlineBatcher` assembles batches and `PadCollator` performs padding. + +## What an adapter must implement + +Subclass `BaseSourceAdapter`: + +```python +class BaseSourceAdapter(ABC): + @abstractmethod + def initial_state(self) -> dict[str, Any]: + ... + + @abstractmethod + def iter_samples( + self, + context: SourceContext, + *, + state: dict[str, Any] | None = None, + ) -> Iterable[dict[str, Any]]: + ... + + @abstractmethod + def is_cycle_start_state(self, state: dict[str, Any] | None) -> bool: + ... + + # Optional — only required when used under WeightedMultiSourceAdapter, + # which cycles each finite child source independently. The default + # implementation raises if your adapter never gets re-cycled. + def advance_cycle(self, state: dict[str, Any] | None) -> dict[str, Any]: + ... +``` + +Each emitted sample **must** carry these fields: + +- `fid` +- `text` +- `audio` +- `_adapter_state` + +Key constraints: + +- `_adapter_state` must describe **where to resume next**, not the position of the current item. +- The state must be plain Python data — serializable and recoverable after a restart. +- If your source needs to be split across workers, use `context.global_worker_id` and `context.global_worker_count` (or subclass `ShardableSourceAdapter` and use its `is_assigned_index` / `shard_items` helpers). +- If the source will participate in weighted cyclic sampling, you must implement `advance_cycle` and make `is_cycle_start_state` correct — otherwise `WeightedMultiSourceAdapter` cannot detect an empty cycle and will raise. + +After implementing the adapter, register the class in `dots_tts/data/builders.py::_SOURCE_ADAPTER_CLASSES` so that the YAML config can resolve it by `class_name`. + +## What a pipeline must implement + +Pipelines must subclass `BaseSamplePipeline` and perform a strict **1:1** sample transform. + +Minimum implementation: + +```python +class MyPipeline(BaseSamplePipeline): + def process_sample(self, sample: dict) -> dict: + sample["text"] = str(sample["text"]).strip() + return sample +``` + +Do **not**: + +- filter samples out +- expand a single sample into multiple samples +- assemble batches inside the pipeline + +`BaseSamplePipeline.__call__` automatically merges the original raw sample (including `_adapter_state` and any extra fields the adapter attached) with whatever your `process_sample` returns. You do not need to copy these fields manually — just return the fields you produced or want to overwrite. + +To wire a new pipeline into config, also extend `dots_tts/data/builders.py::_build_source_pipeline` so it can be selected by name in YAML. + +## How multi-source wrappers affect you + +There are two wrappers in the current codebase: + +- `SequentialMultiSourceAdapter` — used for validation. Reads sources in the configured order, exhaustively, once. +- `WeightedMultiSourceAdapter` — used for training. Draws sources by weight, cycles each child source independently when exhausted. + +Both wrappers **replace** the `_adapter_state` produced by your child adapter with their own resume state before yielding to the dataset. Even so, the child adapter must still emit its own `_adapter_state` — the wrapper reads it to track where each sub-source has read to. + +## Config + +Each source is configured independently: + +```yaml +train_data: + sources: + - name: train_a + weight: 1.0 + pipeline: basic + adapter: + class_name: JsonlManifestSourceAdapter + params: + manifest_path: train_a.jsonl + - name: train_b + weight: 2.0 + pipeline: interleave + adapter: + class_name: JsonlManifestSourceAdapter + params: + manifest_path: train_b.jsonl +``` + +Constraints: + +- `sources[].name` must be unique within the same `train_data` / `val_data` block (it is used as a dict key for resume state). +- `sources[].pipeline` is a per-source setting, not shared across the dataset. +- All sources must ultimately produce the same training-sample structure, since they feed into the same batcher and collator. +- `class_name` must match a key registered in `_SOURCE_ADAPTER_CLASSES`; `params` is forwarded verbatim as kwargs to the adapter constructor. diff --git a/src/dots_tts/data/__init__.py b/src/dots_tts/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..68c760d27de425293851d8186f8ad18d9987b4e2 --- /dev/null +++ b/src/dots_tts/data/__init__.py @@ -0,0 +1 @@ +"""Data package.""" diff --git a/src/dots_tts/data/batchers.py b/src/dots_tts/data/batchers.py new file mode 100644 index 0000000000000000000000000000000000000000..1503064c0e3f8fa095ddd43461efc5ca79b13f9a --- /dev/null +++ b/src/dots_tts/data/batchers.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +import warnings +from collections.abc import Iterable, Iterator +from dataclasses import dataclass + +from dots_tts.utils.profiling import ensure_data_profiler + + +@dataclass(slots=True) +class BatchDecision: + dropped_samples: list[dict] + batch_samples: list[dict] + + +@dataclass(slots=True) +class _PoolSample: + sample: dict + num_audio_tokens: int + num_text_tokens: int + arrival_step: int + + +class OnlineBatcher: + def __init__( + self, + *, + max_audio_tokens_in_batch: int, + max_text_tokens_in_batch: int, + max_batch_size: int | None, + sample_pool_size: int, + profiler=None, + ): + self.max_audio_tokens_in_batch = max(1, int(max_audio_tokens_in_batch)) + self.max_text_tokens_in_batch = max(1, int(max_text_tokens_in_batch)) + self.max_batch_size = max_batch_size + self.sample_pool_size = max(1, int(sample_pool_size)) + self.profiler = ensure_data_profiler(profiler) + + @staticmethod + def _sort_pool(pool: list[_PoolSample]) -> None: + pool.sort( + key=lambda item: ( + item.num_audio_tokens, + item.num_text_tokens, + -item.arrival_step, + ), + reverse=True, + ) + + def _choose_anchor_index( + self, + pool: list[_PoolSample], + *, + decision_step: int, + ) -> int: + oldest_waiting_index = -1 + oldest_waiting_step = decision_step + + for index, item in enumerate(pool): + waited_steps = decision_step - item.arrival_step + if waited_steps < self.sample_pool_size: + continue + if item.arrival_step <= oldest_waiting_step: + oldest_waiting_index = index + oldest_waiting_step = item.arrival_step + + return 0 if oldest_waiting_index < 0 else oldest_waiting_index + + def _build_next_decision( + self, + pool: list[_PoolSample], + *, + decision_step: int, + ) -> BatchDecision: + dropped_samples: list[dict] = [] + batch_samples: list[dict] = [] + selected_indices: list[int] = [] + anchor_index = self._choose_anchor_index(pool, decision_step=decision_step) + anchor = pool[anchor_index] + + exceed_audio_budget = anchor.num_audio_tokens > self.max_audio_tokens_in_batch + exceed_text_budget = anchor.num_text_tokens > self.max_text_tokens_in_batch + exceed_batch_size = self.max_batch_size is not None and self.max_batch_size < 1 + if exceed_audio_budget or exceed_text_budget or exceed_batch_size: + skipped = pool.pop(anchor_index).sample + dropped_samples.append(skipped) + warnings.warn( + "Skipping sample that exceeds batching limits on its own: " + f"fid={skipped.get('fid')!r}, " + f"num_audio_tokens={anchor.num_audio_tokens}, " + f"input_ids_length={anchor.num_text_tokens}, " + f"max_audio_tokens_in_batch={self.max_audio_tokens_in_batch}, " + f"max_text_tokens_in_batch={self.max_text_tokens_in_batch}, " + f"max_batch_size={self.max_batch_size}", + RuntimeWarning, + stacklevel=2, + ) + return BatchDecision( + dropped_samples=dropped_samples, + batch_samples=batch_samples, + ) + + longest_audio_tokens = anchor.num_audio_tokens + longest_text_tokens = anchor.num_text_tokens + batch_samples.append(anchor.sample) + selected_indices.append(anchor_index) + + for index, item in enumerate(pool): + if index == anchor_index: + continue + if ( + self.max_batch_size is not None + and len(batch_samples) >= self.max_batch_size + ): + break + + proposed_batch_size = len(batch_samples) + 1 + proposed_longest_audio_tokens = max( + longest_audio_tokens, + item.num_audio_tokens, + ) + proposed_longest_text_tokens = max( + longest_text_tokens, + item.num_text_tokens, + ) + if ( + proposed_longest_audio_tokens * proposed_batch_size + > self.max_audio_tokens_in_batch + ): + continue + if ( + proposed_longest_text_tokens * proposed_batch_size + > self.max_text_tokens_in_batch + ): + continue + + batch_samples.append(item.sample) + selected_indices.append(index) + longest_audio_tokens = proposed_longest_audio_tokens + longest_text_tokens = proposed_longest_text_tokens + + for index in sorted(set(selected_indices), reverse=True): + pool.pop(index) + + return BatchDecision( + dropped_samples=dropped_samples, + batch_samples=batch_samples, + ) + + def build_decisions(self, sample_iter: Iterable[dict]) -> Iterator[BatchDecision]: + pool: list[_PoolSample] = [] + source_exhausted = False + decision_step = 0 + iterator = iter(sample_iter) + + while not source_exhausted or pool: + while not source_exhausted and len(pool) < self.sample_pool_size: + try: + sample = next(iterator) + except StopIteration: + source_exhausted = True + break + pool.append( + _PoolSample( + sample=sample, + num_audio_tokens=int(sample.get("num_audio_tokens", 0)), + num_text_tokens=int(sample.get("input_ids_length", 0)), + arrival_step=decision_step, + ) + ) + + if not pool: + break + + profiler = self.profiler + with profiler.measure("main.sort_pool", count=len(pool)): + self._sort_pool(pool) + with profiler.measure("main.build_batch_decision"): + decision = self._build_next_decision( + pool, + decision_step=decision_step, + ) + if decision.dropped_samples or decision.batch_samples: + decision_step += 1 + yield decision + continue + raise RuntimeError("OnlineBatcher failed to make progress on a non-empty pool.") diff --git a/src/dots_tts/data/builders.py b/src/dots_tts/data/builders.py new file mode 100644 index 0000000000000000000000000000000000000000..a750951c78e47be5d28dea0baf7141f213f6a6aa --- /dev/null +++ b/src/dots_tts/data/builders.py @@ -0,0 +1,194 @@ +from __future__ import annotations + +from torch.utils.data import DataLoader + +from dots_tts.config.data import DataConfig +from dots_tts.data.pipelines.base import BaseSamplePipeline +from dots_tts.data.pipelines.tts_pipeline import BasicTtsPipeline, InterleaveTtsPipeline +from dots_tts.data.source_adapters.jsonl_manifest_adapter import ( + JsonlManifestSourceAdapter, +) +from dots_tts.data.source_adapters.multi_source_adapter import ( + SequentialMultiSourceAdapter, + SourceSpec, + WeightedMultiSourceAdapter, +) +from dots_tts.data.streaming import ( + BatchedDataStream, + StreamingSampleDataset, + identity_collate, +) + +_SOURCE_ADAPTER_CLASSES = { + "JsonlManifestSourceAdapter": JsonlManifestSourceAdapter, +} + + +def _build_source_pipeline( + tokenizer, data_cfg, pipeline_name: str, *, profiler=None +) -> BaseSamplePipeline: + if pipeline_name == "basic": + return BasicTtsPipeline(tokenizer, data_cfg, profiler=profiler) + if pipeline_name == "interleave": + return InterleaveTtsPipeline(tokenizer, data_cfg, profiler=profiler) + raise ValueError(f"Unsupported data pipeline: {pipeline_name!r}") + + +def _build_source_specs(data_cfg, tokenizer, *, profiler=None) -> list[SourceSpec]: + specs = [] + for source_cfg in data_cfg.sources: + adapter_cls = _SOURCE_ADAPTER_CLASSES[source_cfg.adapter.class_name] + adapter = adapter_cls(**source_cfg.adapter.params) + specs.append( + SourceSpec( + name=source_cfg.name, + weight=float(source_cfg.weight), + adapter=adapter, + pipeline=_build_source_pipeline( + tokenizer, data_cfg, source_cfg.pipeline, profiler=profiler + ), + ) + ) + return specs + + +def _resolve_rank_info(accelerator=None) -> tuple[int, int]: + rank = ( + int(getattr(accelerator, "process_index", 0)) if accelerator is not None else 0 + ) + world_size = ( + int(getattr(accelerator, "num_processes", 1)) if accelerator is not None else 1 + ) + return rank, world_size + + +def _local_num_tokens_per_epoch( + global_num_tokens_per_epoch: int, *, rank: int, world_size: int +) -> int: + if world_size <= 0: + raise ValueError(f"world_size must be positive, but got {world_size}.") + if rank < 0 or rank >= world_size: + raise ValueError( + f"rank must be in [0, {world_size}), but got rank={rank}." + ) + + base, remainder = divmod(int(global_num_tokens_per_epoch), int(world_size)) + return base + int(rank < remainder) + + +def _build_dataset( + data_cfg: DataConfig, + *, + tokenizer, + seed: int, + accelerator=None, + sequential: bool, + profiler=None, +): + rank, world_size = _resolve_rank_info(accelerator) + source_cls = SequentialMultiSourceAdapter if sequential else WeightedMultiSourceAdapter + source = source_cls( + sources=_build_source_specs(data_cfg, tokenizer, profiler=profiler) + ) + return StreamingSampleDataset( + source=source, + rank=rank, + world_size=world_size, + seed=int(seed), + ) + + +def build_training_dataset( + data_cfg: DataConfig, + tokenizer, + *, + seed: int, + accelerator=None, + profiler=None, +): + if data_cfg.num_tokens_per_epoch is None: + raise ValueError("Training data requires num_tokens_per_epoch.") + return _build_dataset( + data_cfg, + tokenizer=tokenizer, + seed=seed, + accelerator=accelerator, + sequential=False, + profiler=profiler, + ) + + +def build_validation_dataset( + data_cfg: DataConfig, + tokenizer, + *, + seed: int, + accelerator=None, + profiler=None, +): + return _build_dataset( + data_cfg, + tokenizer=tokenizer, + seed=seed, + accelerator=accelerator, + sequential=True, + profiler=profiler, + ) + + +def _build_sample_loader(dataset, data_cfg: DataConfig) -> DataLoader: + loader_kwargs = { + "dataset": dataset, + "batch_size": None, + "collate_fn": identity_collate, + "num_workers": data_cfg.num_workers, + "pin_memory": data_cfg.pin_memory, + "persistent_workers": data_cfg.num_workers > 0, + } + if data_cfg.num_workers > 0: + loader_kwargs["prefetch_factor"] = int(data_cfg.prefetch_factor) + sample_loader = DataLoader(**loader_kwargs) + return sample_loader + + +def build_training_dataloader( + dataset, data_cfg: DataConfig, tokenizer, *, profiler=None +): + local_num_tokens_per_epoch = _local_num_tokens_per_epoch( + int(data_cfg.num_tokens_per_epoch), + rank=int(dataset.rank), + world_size=int(dataset.world_size), + ) + sample_loader = _build_sample_loader(dataset, data_cfg) + batched_stream = BatchedDataStream( + sample_dataset=dataset, + data_cfg=data_cfg, + tokenizer=tokenizer, + num_tokens_per_epoch=local_num_tokens_per_epoch, + profiler=profiler, + ) + batched_stream.attach_loader(sample_loader) + return batched_stream + + +def build_validation_dataloader( + dataset, data_cfg: DataConfig, tokenizer, *, profiler=None +): + sample_loader = _build_sample_loader(dataset, data_cfg) + batched_stream = BatchedDataStream( + sample_dataset=dataset, + data_cfg=data_cfg, + tokenizer=tokenizer, + num_tokens_per_epoch=None, + profiler=profiler, + ) + batched_stream.attach_loader(sample_loader) + return batched_stream + + +__all__ = [ + "build_training_dataloader", + "build_training_dataset", + "build_validation_dataloader", + "build_validation_dataset", +] diff --git a/src/dots_tts/data/collator.py b/src/dots_tts/data/collator.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ec98fa95b9d41e5b4faf4f282ccafeb0221ddd --- /dev/null +++ b/src/dots_tts/data/collator.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from typing import Any + +import torch +from torch.nn.utils.rnn import pad_sequence + + +class PadCollator: + def __init__(self, tokenizer): + self.tokenizer = tokenizer + self.pad_token_id = tokenizer.pad_token_id + if self.pad_token_id is None: + self.pad_token_id = tokenizer.eos_token_id or 0 + + def __call__(self, samples: list[dict[str, Any]]) -> dict[str, Any]: + if not samples: + raise ValueError("PadCollator received an empty sample list.") + + order = sorted( + range(len(samples)), + key=lambda idx: samples[idx]["sample_length"], + reverse=True, + ) + ordered = [samples[idx] for idx in order] + + input_ids = [ + torch.tensor(sample["input_ids"], dtype=torch.long) for sample in ordered + ] + labels = [ + torch.tensor(sample["labels"], dtype=torch.long) for sample in ordered + ] + loss_masks = [ + torch.tensor(sample["loss_mask"], dtype=torch.float32) for sample in ordered + ] + waveforms = [sample["sample"].squeeze(0) for sample in ordered] + fbank = [sample["fbank"] for sample in ordered] + + return { + "fids": [sample["fid"] for sample in ordered], + "source_names": [sample.get("source_name") for sample in ordered], + "input_ids": pad_sequence( + input_ids, + batch_first=True, + padding_value=self.pad_token_id, + ), + "input_ids_lengths": torch.tensor( + [len(sample["input_ids"]) for sample in ordered], + dtype=torch.long, + ), + "labels": pad_sequence( + labels, + batch_first=True, + padding_value=self.pad_token_id, + ), + "loss_mask": pad_sequence( + loss_masks, + batch_first=True, + padding_value=0.0, + ), + "sample": pad_sequence( + waveforms, + batch_first=True, + padding_value=0.0, + ).unsqueeze(1), + "sample_lengths": torch.tensor( + [sample["sample_length"] for sample in ordered], + dtype=torch.long, + ), + "num_text_tokens": torch.tensor( + [sample["num_text_tokens"] for sample in ordered], + dtype=torch.long, + ), + "num_audio_tokens": torch.tensor( + [sample["num_audio_tokens"] for sample in ordered], + dtype=torch.long, + ), + "fbank": pad_sequence( + fbank, + batch_first=True, + padding_value=0.0, + ), + "fbank_lengths": torch.tensor( + [sample["fbank_length"] for sample in ordered], + dtype=torch.long, + ), + } diff --git a/src/dots_tts/data/pipelines/__init__.py b/src/dots_tts/data/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2888992530b44dc91d9d3f3b81c927f82ea256f1 --- /dev/null +++ b/src/dots_tts/data/pipelines/__init__.py @@ -0,0 +1 @@ +"""Data pipelines package.""" diff --git a/src/dots_tts/data/pipelines/base.py b/src/dots_tts/data/pipelines/base.py new file mode 100644 index 0000000000000000000000000000000000000000..c33c080bfa986a0078f80c0f67373e486081b484 --- /dev/null +++ b/src/dots_tts/data/pipelines/base.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Iterable, Iterator + + +class BaseSamplePipeline(ABC): + """1:1 sample pipeline that preserves adapter resume metadata.""" + + @staticmethod + def _validate_input_sample(sample: dict) -> None: + if "_adapter_state" not in sample: + raise RuntimeError( + "Source sample is missing required '_adapter_state' for resume." + ) + + @abstractmethod + def process_sample(self, sample: dict) -> dict: + """Transform one raw sample into one processed sample.""" + + def __call__(self, samples: Iterable[dict]) -> Iterator[dict]: + for raw_sample in samples: + self._validate_input_sample(raw_sample) + processed = self.process_sample(dict(raw_sample)) + if not isinstance(processed, dict): + raise RuntimeError( + f"{self.__class__.__name__}.process_sample() must return a dict." + ) + item = dict(raw_sample) + item.update(processed) + self._validate_input_sample(item) + yield item diff --git a/src/dots_tts/data/pipelines/preprocessing.py b/src/dots_tts/data/pipelines/preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..b598b9d2f34500ae9326a6228e6f293b7497f517 --- /dev/null +++ b/src/dots_tts/data/pipelines/preprocessing.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import torch +import torch.nn.functional as F + +DEFAULT_EDGE_SILENCE_MS = 250.0 +DEFAULT_EDGE_SILENCE_TOP_DB = 30.0 + + +def align_length(num_samples: int, multiple_of: int | None) -> int: + if multiple_of is None or multiple_of <= 0: + return int(num_samples) + if num_samples % multiple_of == 0: + return int(num_samples) + return int(((num_samples + multiple_of - 1) // multiple_of) * multiple_of) + + +def pad_waveform_align_only( + waveform: torch.Tensor, + *, + multiple_of: int | None, +) -> torch.Tensor: + if multiple_of is None or multiple_of <= 0: + return waveform + + target_length = align_length(waveform.size(-1), multiple_of) + delta = target_length - waveform.size(-1) + if delta <= 0: + return waveform + + return F.pad(waveform, (0, delta), "constant", 0.0) + + +def normalize_edge_silence_duration( + waveform: torch.Tensor, + *, + sample_rate: int, + target_silence_duration_ms: float = DEFAULT_EDGE_SILENCE_MS, + top_db: float = DEFAULT_EDGE_SILENCE_TOP_DB, +) -> torch.Tensor: + mono_waveform = waveform[0] + target_samples = int(round(float(sample_rate) * float(target_silence_duration_ms) / 1000.0)) + amplitude = mono_waveform.abs() + peak = float(amplitude.max().item()) + if peak <= 0.0: + waveform = waveform[..., :target_samples] + current_length = int(waveform.size(-1)) + if current_length < target_samples: + waveform = F.pad(waveform, (0, target_samples - current_length), "constant", 0.0) + return waveform + + threshold = peak * (10.0 ** (-float(top_db) / 20.0)) + non_silent = torch.nonzero(amplitude > threshold, as_tuple=False).flatten() + first_non_silent = int(non_silent[0].item()) + last_non_silent = int(non_silent[-1].item()) + + leading_silence_samples = first_non_silent + trailing_silence_samples = int(mono_waveform.numel()) - last_non_silent - 1 + + leading_delta = target_samples - leading_silence_samples + if leading_delta > 0: + waveform = F.pad(waveform, (leading_delta, 0), "constant", 0.0) + else: + trim_from_start = min(-leading_delta, int(waveform.size(-1))) + waveform = waveform[..., trim_from_start:] + + trailing_delta = target_samples - trailing_silence_samples + if trailing_delta > 0: + return F.pad(waveform, (0, trailing_delta), "constant", 0.0) + + trim_from_end = min(-trailing_delta, int(waveform.size(-1))) + if trim_from_end <= 0: + return waveform + return waveform[..., :-trim_from_end] + + +def compute_num_audio_tokens( + num_samples: int, *, audio_samples_per_llm_token: int +) -> int: + if num_samples % audio_samples_per_llm_token != 0: + raise ValueError( + f"Waveform length {num_samples} is not aligned to token hop {audio_samples_per_llm_token}." + ) + return num_samples // audio_samples_per_llm_token diff --git a/src/dots_tts/data/pipelines/tokenizing.py b/src/dots_tts/data/pipelines/tokenizing.py new file mode 100644 index 0000000000000000000000000000000000000000..0133fec16c1a6b19b7f97110058ee4bdcf3b06c8 --- /dev/null +++ b/src/dots_tts/data/pipelines/tokenizing.py @@ -0,0 +1,339 @@ +from __future__ import annotations + +from dataclasses import dataclass +import re +from typing import Any + +from loguru import logger + +from dots_tts.utils.tokenizer import ( + AUDIO_GEN_END_TOKEN, + AUDIO_GEN_SPAN_TOKEN, + AUDIO_GEN_START_TOKEN, + TEXT_COND_END_TOKEN, + require_token_id, +) + +TEMPLATE_PATTERN = re.compile(r"\{text\}|\{audio\}|\{interleave\}|[^\{]+") + + +@dataclass(frozen=True) +class ParsedTemplate: + parts: tuple[str, ...] + has_audio_placeholder: bool + has_interleave_placeholder: bool + + +@dataclass(frozen=True) +class TokenizedTemplatePart: + kind: str + token_ids: tuple[int, ...] = () + raw_text: str | None = None + + +def parse_template(template: str) -> ParsedTemplate: + parts = tuple(re.findall(TEMPLATE_PATTERN, template)) + has_audio_placeholder = "{audio}" in parts + interleave_count = parts.count("{interleave}") + if has_audio_placeholder and interleave_count: + raise ValueError("Template cannot mix audio and interleave placeholders.") + if interleave_count > 1: + raise ValueError( + "Interleave generation template must contain exactly one interleave placeholder." + ) + return ParsedTemplate( + parts=parts, + has_audio_placeholder=has_audio_placeholder, + has_interleave_placeholder=interleave_count == 1, + ) + + +def _prepare_template_tokens( + *, text: str, tokenizer, template: str +) -> tuple[ParsedTemplate, list[int]]: + return parse_template(template), tokenizer.encode(text, add_special_tokens=False) + + +def _iter_tokenized_template_parts( + *, + parsed_template: ParsedTemplate, + tokenizer, + text_tokens: list[int], +): + for part in parsed_template.parts: + if part == "{text}": + yield TokenizedTemplatePart(kind="text", token_ids=tuple(text_tokens)) + continue + if part == "{audio}": + yield TokenizedTemplatePart(kind="audio") + continue + if part == "{interleave}": + yield TokenizedTemplatePart(kind="interleave") + continue + yield TokenizedTemplatePart( + kind="literal", + token_ids=tuple(tokenizer.encode(part, add_special_tokens=False)), + raw_text=part, + ) + + +def _extend_tokens_with_loss( + *, full_ids: list[int], loss_mask: list[float], token_ids: tuple[int, ...], loss: float +) -> None: + full_ids.extend(token_ids) + loss_mask.extend([loss] * len(token_ids)) + + +def build_tokenized_example( + *, text: str, tokenizer, template: str, num_audio_tokens: int +) -> dict[str, Any]: + if tokenizer.eos_token_id is None: + raise ValueError("Tokenizer eos_token_id is required for generation targets.") + + parsed_template, text_tokens = _prepare_template_tokens( + text=text, + tokenizer=tokenizer, + template=template, + ) + + full_ids: list[int] = [] + loss_mask: list[float] = [] + audio_tokens: list[int] | None = None + if parsed_template.has_audio_placeholder: + audio_gen_start_id = require_token_id(tokenizer, AUDIO_GEN_START_TOKEN) + audio_gen_span_id = require_token_id(tokenizer, AUDIO_GEN_SPAN_TOKEN) + audio_gen_end_id = require_token_id(tokenizer, AUDIO_GEN_END_TOKEN) + audio_tokens = ( + [audio_gen_start_id] + + [audio_gen_span_id] * num_audio_tokens + + [audio_gen_end_id] + ) + elif parsed_template.has_interleave_placeholder: + audio_gen_span_id = require_token_id(tokenizer, AUDIO_GEN_SPAN_TOKEN) + audio_gen_end_id = require_token_id(tokenizer, AUDIO_GEN_END_TOKEN) + text_cond_end_id = require_token_id(tokenizer, TEXT_COND_END_TOKEN) + + for part in _iter_tokenized_template_parts( + parsed_template=parsed_template, + tokenizer=tokenizer, + text_tokens=text_tokens, + ): + if part.kind == "text": + _extend_tokens_with_loss( + full_ids=full_ids, + loss_mask=loss_mask, + token_ids=part.token_ids, + loss=0.0, + ) + continue + + if part.kind == "audio": + if audio_tokens is None: + raise RuntimeError("Audio placeholder tokens were not initialized.") + full_ids.extend(audio_tokens) + loss_mask.extend([0.0]) + loss_mask.extend([1.0] * max(0, len(audio_tokens) - 2)) + loss_mask.append(0.0) + continue + + if part.kind == "interleave": + _append_interleave_generation_tokens( + full_ids=full_ids, + loss_mask=loss_mask, + text_tokens=text_tokens, + num_audio_tokens=num_audio_tokens, + audio_span_id=audio_gen_span_id, + audio_end_id=audio_gen_end_id, + text_cond_end_id=text_cond_end_id, + ) + continue + + _extend_tokens_with_loss( + full_ids=full_ids, + loss_mask=loss_mask, + token_ids=part.token_ids, + loss=0.0, + ) + + full_ids.append(tokenizer.eos_token_id) + loss_mask.append(0.0) + + return { + "input_ids": full_ids[:-1], + "labels": full_ids[1:], + "loss_mask": loss_mask[1:], + "text_token_count": len(text_tokens), + } + + +def build_generation_schedule( + *, + text: str, + tokenizer, + template: str, + max_audio_tokens: int, +) -> dict[str, Any]: + if max_audio_tokens <= 0: + raise ValueError("max_audio_tokens must be positive for generation.") + + parsed_template, text_tokens = _prepare_template_tokens( + text=text, + tokenizer=tokenizer, + template=template, + ) + schedule_ids: list[int] = [] + audio_gen_start_id = require_token_id(tokenizer, AUDIO_GEN_START_TOKEN) + audio_gen_span_id = require_token_id(tokenizer, AUDIO_GEN_SPAN_TOKEN) + + if parsed_template.has_audio_placeholder: + for part in _iter_tokenized_template_parts( + parsed_template=parsed_template, + tokenizer=tokenizer, + text_tokens=text_tokens, + ): + if part.kind == "audio": + schedule_ids.append(audio_gen_start_id) + schedule_ids.extend([audio_gen_span_id] * max_audio_tokens) + continue + schedule_ids.extend(part.token_ids) + visible_schedule_ids = [ + token_id for token_id in schedule_ids if token_id != audio_gen_span_id + ] + decoded_schedule = ( + tokenizer.decode( + visible_schedule_ids, + skip_special_tokens=False, + clean_up_tokenization_spaces=False, + ) + if hasattr(tokenizer, "decode") + else repr(visible_schedule_ids) + ) + logger.info( + "Built generation schedule: interleave={} max_audio_tokens={} sequence={!r}", + False, + int(max_audio_tokens), + decoded_schedule, + ) + return { + "schedule_ids": schedule_ids, + "interleave": False, + } + + if not parsed_template.has_interleave_placeholder: + raise ValueError( + "Generation template must contain either {audio} or {interleave}." + ) + text_cond_end_id = require_token_id(tokenizer, TEXT_COND_END_TOKEN) + if max_audio_tokens < len(text_tokens): + raise ValueError( + "Interleave generation requires at least one audio span per text token: " + f"text_token_count={len(text_tokens)} " + f"max_audio_patch_count={max_audio_tokens}." + ) + + interleave_started = False + for part in _iter_tokenized_template_parts( + parsed_template=parsed_template, + tokenizer=tokenizer, + text_tokens=text_tokens, + ): + if part.kind == "interleave": + _append_interleave_schedule_tokens( + schedule_ids=schedule_ids, + text_tokens=text_tokens, + max_audio_tokens=max_audio_tokens, + audio_span_id=audio_gen_span_id, + text_cond_end_id=text_cond_end_id, + ) + interleave_started = True + continue + if part.kind == "text": + raise ValueError( + "Generation schedule does not support {text} inside an interleave template." + ) + if part.kind == "audio": + raise ValueError( + "Generation schedule does not support {audio} inside an interleave template." + ) + if interleave_started: + if (part.raw_text or "").strip(): + raise ValueError( + "Generation schedule does not support non-empty suffix text after the interleave placeholder." + ) + continue + schedule_ids.extend(part.token_ids) + + visible_schedule_ids = [ + token_id for token_id in schedule_ids if token_id != audio_gen_span_id + ] + decoded_schedule = ( + tokenizer.decode( + visible_schedule_ids, + skip_special_tokens=False, + clean_up_tokenization_spaces=False, + ) + if hasattr(tokenizer, "decode") + else repr(visible_schedule_ids) + ) + logger.info( + "Built generation schedule: interleave={} max_audio_tokens={} sequence={!r}", + True, + int(max_audio_tokens), + decoded_schedule, + ) + return { + "schedule_ids": schedule_ids, + "interleave": True, + } + + +def _append_interleave_generation_tokens( + *, + full_ids: list[int], + loss_mask: list[float], + text_tokens: list[int], + num_audio_tokens: int, + audio_span_id: int, + audio_end_id: int, + text_cond_end_id: int, +) -> None: + audio_tokens = [audio_span_id] * num_audio_tokens + [audio_end_id] + text_index = 0 + audio_index = 0 + text_cond_end_added = False + + while text_index < len(text_tokens) or audio_index < len(audio_tokens): + if text_index < len(text_tokens): + full_ids.append(text_tokens[text_index]) + loss_mask.append(0.0) + text_index += 1 + elif not text_cond_end_added: + full_ids.append(text_cond_end_id) + loss_mask.append(0.0) + text_cond_end_added = True + + if audio_index < len(audio_tokens): + full_ids.append(audio_tokens[audio_index]) + loss_mask.append(1.0 if audio_index < num_audio_tokens else 0.0) + audio_index += 1 + + if not text_cond_end_added: + full_ids.append(text_cond_end_id) + loss_mask.append(0.0) + + +def _append_interleave_schedule_tokens( + *, + schedule_ids: list[int], + text_tokens: list[int], + max_audio_tokens: int, + audio_span_id: int, + text_cond_end_id: int, +) -> None: + for token_id in text_tokens: + schedule_ids.append(token_id) + schedule_ids.append(audio_span_id) + schedule_ids.append(text_cond_end_id) + remaining_audio_tokens = max_audio_tokens - len(text_tokens) + if remaining_audio_tokens > 0: + schedule_ids.extend([audio_span_id] * remaining_audio_tokens) diff --git a/src/dots_tts/data/pipelines/tts_pipeline.py b/src/dots_tts/data/pipelines/tts_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..e9406d610c2d02a0301f57d8533d26d90c2d895f --- /dev/null +++ b/src/dots_tts/data/pipelines/tts_pipeline.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import soundfile as sf +import torch + +from dots_tts.utils.profiling import ensure_data_profiler +from dots_tts.data.pipelines.base import BaseSamplePipeline +from dots_tts.data.pipelines.preprocessing import ( + compute_num_audio_tokens, + normalize_edge_silence_duration, + pad_waveform_align_only, +) +from dots_tts.data.pipelines.tokenizing import build_tokenized_example +from dots_tts.modules.speaker.fbank import extract_speaker_fbank +from dots_tts.utils.audio import high_quality_resample + +TTS_TEXT_PREFIX = "[文本]" +TTS_AUDIO_PREFIX = "[文本对应语音]" +TTS_INSTRUCTION_TEXT_PREFIX = "[带指令文本]" +TTA_TEXT_PREFIX = "[声音描述]" +TTA_AUDIO_PREFIX = "[描述对应声音]" +TTS_INTERLEAVE_PREFIX = "[流式语音合成]" +DEFAULT_TRAIN_TEMPLATE = f"{TTS_TEXT_PREFIX}{{text}}{TTS_AUDIO_PREFIX}{{audio}}" +DEFAULT_INSTRUCTION_TTS_TEMPLATE = ( + f"{TTS_INSTRUCTION_TEXT_PREFIX}{{text}}{TTS_AUDIO_PREFIX}{{audio}}" +) +DEFAULT_TEXT_TO_AUDIO_TEMPLATE = f"{TTA_TEXT_PREFIX}{{text}}{TTA_AUDIO_PREFIX}{{audio}}" +DEFAULT_INTERLEAVE_TRAIN_TEMPLATE = f"{TTS_INTERLEAVE_PREFIX}{{interleave}}" + + +class BasicTtsPipeline(BaseSamplePipeline): + """Fixed internal training pipeline for adapter-emitted samples.""" + + template = DEFAULT_TRAIN_TEMPLATE + + def __init__(self, tokenizer, data_cfg, *, profiler=None): + self.tokenizer = tokenizer + self.train_audio_sample_rate = int(data_cfg.train_audio_sample_rate) + self.audio_samples_per_llm_token = int(data_cfg.audio_samples_per_llm_token) + self.profiler = ensure_data_profiler(profiler) + + @staticmethod + def _load_waveform(audio_path: str) -> tuple[torch.Tensor, int]: + if not isinstance(audio_path, str): + raise TypeError( + f"Training audio must be a filesystem path, got {type(audio_path)}." + ) + audio_data, sample_rate = sf.read( + audio_path, + dtype="float32", + always_2d=True, + ) + waveform = torch.from_numpy(audio_data.T) + if waveform.size(0) > 1: + waveform = waveform.mean(dim=0, keepdim=True) + return waveform.contiguous(), int(sample_rate) + + @staticmethod + def _validate_source_sample(sample: dict) -> None: + missing = [field for field in ("fid", "text", "audio") if field not in sample] + if missing: + raise ValueError( + "Source adapter must emit fid/text/audio. " + f"Missing fields: {missing}. Sample keys: {sorted(sample.keys())}" + ) + + def process_sample(self, raw_sample: dict) -> dict: + sample = dict(raw_sample) + self._validate_source_sample(sample) + sample["fid"] = str(sample["fid"]) + + with self.profiler.measure("worker.process_sample_total"): + return self._process_sample_impl(sample) + + def _process_sample_impl(self, sample: dict) -> dict: + profiler = self.profiler + with profiler.measure("worker.load_audio"): + waveform, sample_rate = self._load_waveform(sample["audio"]) + with profiler.measure("worker.resample_audio"): + waveform = high_quality_resample( + waveform, + orig_sr=sample_rate, + target_sr=self.train_audio_sample_rate, + ) + with profiler.measure("worker.normalize_edge_silence"): + waveform = normalize_edge_silence_duration( + waveform, + sample_rate=self.train_audio_sample_rate, + ) + sample["sample"] = waveform + sample["sample_rate"] = self.train_audio_sample_rate + sample["unpadded_sample_length"] = int(waveform.size(-1)) + + with profiler.measure("worker.pad_audio"): + waveform = pad_waveform_align_only( + waveform, + multiple_of=self.audio_samples_per_llm_token, + ) + sample["sample"] = waveform + sample["sample_length"] = int(waveform.size(-1)) + + num_audio_tokens = compute_num_audio_tokens( + sample["sample_length"], + audio_samples_per_llm_token=self.audio_samples_per_llm_token, + ) + with profiler.measure("worker.tokenize"): + tokenized = build_tokenized_example( + text=sample["text"], + tokenizer=self.tokenizer, + template=self.template, + num_audio_tokens=num_audio_tokens, + ) + sample["input_ids"] = tokenized["input_ids"] + sample["labels"] = tokenized["labels"] + sample["loss_mask"] = tokenized["loss_mask"] + sample["input_ids_length"] = len(tokenized["input_ids"]) + sample["num_text_tokens"] = tokenized["text_token_count"] + sample["num_audio_tokens"] = num_audio_tokens + sample["num_total_tokens"] = sample["input_ids_length"] + + with profiler.measure("worker.extract_fbank"): + fbank = extract_speaker_fbank( + sample["sample"], + sample_rate=sample["sample_rate"], + ) + sample["fbank"] = fbank + sample["fbank_length"] = int(fbank.size(0)) + return sample + + +class InterleaveTtsPipeline(BasicTtsPipeline): + template = DEFAULT_INTERLEAVE_TRAIN_TEMPLATE diff --git a/src/dots_tts/data/source_adapters/__init__.py b/src/dots_tts/data/source_adapters/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..20e336cead881a402227d769a4664a70e48900c2 --- /dev/null +++ b/src/dots_tts/data/source_adapters/__init__.py @@ -0,0 +1 @@ +"""Source adapter package.""" diff --git a/src/dots_tts/data/source_adapters/base_adapter.py b/src/dots_tts/data/source_adapters/base_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..ffc12aa413eab257a871250595f76f9eb595471b --- /dev/null +++ b/src/dots_tts/data/source_adapters/base_adapter.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import random +from abc import ABC, abstractmethod +from collections.abc import Iterable, Sequence +from copy import deepcopy +from dataclasses import dataclass +from typing import Any, TypeVar + + +@dataclass(frozen=True) +class SourceContext: + """Execution context for a single adapter iterator.""" + + epoch: int + rank: int + world_size: int + worker_id: int + num_workers: int + seed: int + + @property + def global_worker_count(self) -> int: + return max(1, self.world_size * self.num_workers) + + @property + def global_worker_id(self) -> int: + return self.rank * self.num_workers + self.worker_id + + +class BaseSourceAdapter(ABC): + """State-aware streaming source interface used by the training pipeline.""" + + @abstractmethod + def initial_state(self) -> dict[str, Any]: + """Return the default iterator state for a new worker/epoch.""" + + @abstractmethod + def iter_samples( + self, + context: SourceContext, + *, + state: dict[str, Any] | None = None, + ) -> Iterable[dict[str, Any]]: + """Yield raw samples and attach the next adapter state to each item.""" + + @abstractmethod + def is_cycle_start_state(self, state: dict[str, Any] | None) -> bool: + """Return whether ``state`` points at the beginning of a source cycle.""" + + def normalize_state(self, state: dict[str, Any] | None) -> dict[str, Any]: + merged = self.initial_state() + if state: + merged.update(deepcopy(state)) + return merged + + def clone_state(self, state: dict[str, Any] | None) -> dict[str, Any]: + return deepcopy(self.normalize_state(state)) + + def advance_cycle(self, state: dict[str, Any] | None) -> dict[str, Any]: + raise RuntimeError( + f"{self.__class__.__name__} does not support repeated cycling." + ) + + +_T = TypeVar("_T") + + +class ShardableSourceAdapter(BaseSourceAdapter): + """Helper mixin for deterministic rank/worker sharding.""" + + @staticmethod + def is_assigned_index(index: int, context: SourceContext) -> bool: + return index % context.global_worker_count == context.global_worker_id + + @staticmethod + def shard_items( + items: Sequence[_T], + context: SourceContext, + *, + shuffle: bool = False, + seed_offset: int = 0, + ) -> list[_T]: + assigned = list(items) + if shuffle: + random.Random(context.seed + context.epoch + seed_offset).shuffle(assigned) + return [ + item + for index, item in enumerate(assigned) + if ShardableSourceAdapter.is_assigned_index(index, context) + ] diff --git a/src/dots_tts/data/source_adapters/jsonl_manifest_adapter.py b/src/dots_tts/data/source_adapters/jsonl_manifest_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..f3b0759bbbbadde00173abc178fdf2cae1d9c2ac --- /dev/null +++ b/src/dots_tts/data/source_adapters/jsonl_manifest_adapter.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import json +import random +from collections.abc import Iterable, Iterator +from pathlib import Path +from typing import Any + +from dots_tts.data.source_adapters.base_adapter import ( + BaseSourceAdapter, + ShardableSourceAdapter, + SourceContext, +) + + +class JsonlManifestSourceAdapter(ShardableSourceAdapter, BaseSourceAdapter): + """Finite adapter for line-delimited JSON manifests.""" + + def __init__( + self, + *, + manifest_path: str, + fid_key: str = "fid", + text_key: str = "text", + audio_key: str = "audio", + shuffle: bool = False, + encoding: str = "utf-8", + ): + self.manifest_path = Path(manifest_path) + self.fid_key = fid_key + self.text_key = text_key + self.audio_key = audio_key + self.shuffle = shuffle + self.encoding = encoding + self._records: list[dict[str, Any]] | None = None + + def initial_state(self) -> dict[str, Any]: + return {"cycle": 0, "cursor": 0} + + def is_cycle_start_state(self, state: dict[str, Any] | None) -> bool: + normalized = self.normalize_state(state) + return int(normalized["cursor"]) == 0 + + def advance_cycle(self, state: dict[str, Any] | None) -> dict[str, Any]: + normalized = self.normalize_state(state) + return {"cycle": int(normalized["cycle"]) + 1, "cursor": 0} + + def _iter_records(self) -> Iterator[dict[str, Any]]: + if not self.manifest_path.is_file(): + raise FileNotFoundError(f"Manifest file not found: {self.manifest_path!s}") + with self.manifest_path.open("r", encoding=self.encoding) as fin: + for line_no, raw_line in enumerate(fin, start=1): + line = raw_line.strip() + if not line: + continue + try: + yield json.loads(line) + except json.JSONDecodeError as exc: + raise ValueError( + f"Invalid JSON at {self.manifest_path}:{line_no}" + ) from exc + + def _base_records(self) -> list[dict[str, Any]]: + if self._records is None: + self._records = list(self._iter_records()) + return self._records + + def _build_sample(self, record: dict[str, Any]) -> dict[str, Any]: + missing = [ + key + for key in (self.fid_key, self.text_key, self.audio_key) + if key not in record + ] + if missing: + raise KeyError( + f"Manifest record is missing required keys {missing}: {record}" + ) + + sample = { + "fid": str(record[self.fid_key]), + "text": record[self.text_key], + "audio": record[self.audio_key], + } + for key, value in record.items(): + if key in {self.fid_key, self.text_key, self.audio_key}: + continue + sample[key] = value + return sample + + def _indices_for_cycle( + self, + context: SourceContext, + *, + cycle: int, + ) -> list[int]: + indices = list(range(len(self._base_records()))) + if self.shuffle: + random.Random(context.seed + context.epoch + 1009 * int(cycle)).shuffle( + indices + ) + indices = [ + record_index + for shuffled_index, record_index in enumerate(indices) + if self.is_assigned_index(shuffled_index, context) + ] + else: + indices = [ + record_index + for record_index in indices + if self.is_assigned_index(record_index, context) + ] + return indices + + def iter_samples( + self, + context: SourceContext, + *, + state: dict[str, Any] | None = None, + ) -> Iterable[dict[str, Any]]: + live_state = self.normalize_state(state) + cycle = int(live_state["cycle"]) + cursor = int(live_state["cursor"]) + records = self._base_records() + indices = self._indices_for_cycle(context, cycle=cycle) + + for position in range(cursor, len(indices)): + sample = self._build_sample(records[indices[position]]) + sample["_adapter_state"] = { + "cycle": cycle, + "cursor": position + 1, + } + yield sample diff --git a/src/dots_tts/data/source_adapters/multi_source_adapter.py b/src/dots_tts/data/source_adapters/multi_source_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..c43a2bc88cb0325d90dc44c62b1f3ba3137ded25 --- /dev/null +++ b/src/dots_tts/data/source_adapters/multi_source_adapter.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +from collections.abc import Iterable +from copy import deepcopy +from dataclasses import dataclass + +from dots_tts.data.pipelines.base import BaseSamplePipeline +from dots_tts.data.source_adapters.base_adapter import ( + BaseSourceAdapter, + SourceContext, +) + + +@dataclass(frozen=True) +class SourceSpec: + name: str + weight: float + adapter: BaseSourceAdapter + pipeline: BaseSamplePipeline + + +_UINT64_MASK = 0xFFFFFFFFFFFFFFFF + + +def _mix_uint64(value: int) -> int: + value = (value ^ (value >> 30)) * 0xBF58476D1CE4E5B9 + value &= _UINT64_MASK + value = (value ^ (value >> 27)) * 0x94D049BB133111EB + value &= _UINT64_MASK + return (value ^ (value >> 31)) & _UINT64_MASK + + +def _stable_seed(*parts: int) -> int: + value = 0x9E3779B97F4A7C15 + for part in parts: + value = (value + int(part) + 0x9E3779B97F4A7C15) & _UINT64_MASK + value = _mix_uint64(value) + return value + + +class SequentialMultiSourceAdapter(BaseSourceAdapter): + """Finite adapter that concatenates sources in the configured order.""" + + def __init__(self, *, sources: list[SourceSpec]): + if not sources: + raise ValueError( + "SequentialMultiSourceAdapter requires at least one source." + ) + self.sources = list(sources) + + def initial_state(self) -> dict: + return { + "source_index": 0, + "sources": { + source.name: source.adapter.initial_state() for source in self.sources + }, + } + + def is_cycle_start_state(self, state: dict | None) -> bool: + normalized = self.normalize_state(state) + if int(normalized["source_index"]) != 0: + return False + return all( + source.adapter.is_cycle_start_state(normalized["sources"][source.name]) + for source in self.sources + ) + + def normalize_state(self, state: dict | None) -> dict: + normalized = super().normalize_state(state) + source_states = normalized.get("sources") or {} + normalized["sources"] = { + source.name: source.adapter.clone_state(source_states.get(source.name)) + for source in self.sources + } + normalized["source_index"] = int(normalized.get("source_index", 0)) + return normalized + + def clone_state(self, state: dict | None) -> dict: + return deepcopy(self.normalize_state(state)) + + def iter_samples( + self, + context: SourceContext, + *, + state: dict | None = None, + ) -> Iterable[dict]: + live_state = self.normalize_state(state) + start_index = int(live_state["source_index"]) + for index in range(start_index, len(self.sources)): + source = self.sources[index] + child_state = live_state["sources"][source.name] + raw_iter = source.adapter.iter_samples(context, state=child_state) + for sample in source.pipeline(raw_iter): + item = dict(sample) + next_child_state = item.pop("_adapter_state", None) + if next_child_state is None: + raise RuntimeError( + f"{source.adapter.__class__.__name__} must attach '_adapter_state' to samples." + ) + live_state["source_index"] = index + live_state["sources"][source.name] = source.adapter.clone_state( + next_child_state + ) + item["source_name"] = source.name + item["_adapter_state"] = self.clone_state(live_state) + yield item + live_state["source_index"] = index + 1 + + +class WeightedMultiSourceAdapter(BaseSourceAdapter): + """Infinite weighted sampler that cycles each child source independently.""" + + def __init__(self, *, sources: list[SourceSpec]): + if not sources: + raise ValueError("WeightedMultiSourceAdapter requires at least one source.") + invalid = [source.name for source in sources if float(source.weight) <= 0.0] + if invalid: + raise ValueError(f"Source weights must be positive: {invalid}") + self.sources = list(sources) + self._cumulative_weights = [] + total = 0.0 + for source in self.sources: + total += float(source.weight) + self._cumulative_weights.append(total) + self._total_weight = total + + def initial_state(self) -> dict: + return { + "draw_count": 0, + "sources": { + source.name: source.adapter.initial_state() for source in self.sources + }, + } + + def is_cycle_start_state(self, state: dict | None) -> bool: + normalized = self.normalize_state(state) + if int(normalized["draw_count"]) != 0: + return False + return all( + source.adapter.is_cycle_start_state(normalized["sources"][source.name]) + for source in self.sources + ) + + def normalize_state(self, state: dict | None) -> dict: + normalized = super().normalize_state(state) + source_states = normalized.get("sources") or {} + normalized["sources"] = { + source.name: source.adapter.clone_state(source_states.get(source.name)) + for source in self.sources + } + normalized["draw_count"] = int(normalized.get("draw_count", 0)) + return normalized + + def clone_state(self, state: dict | None) -> dict: + return deepcopy(self.normalize_state(state)) + + def _source_draw_value(self, context: SourceContext, draw_count: int) -> float: + raw = _stable_seed( + context.seed, + context.epoch, + context.rank, + context.worker_id, + draw_count, + ) + return (raw / float(1 << 64)) * self._total_weight + + def _pick_source(self, context: SourceContext, draw_count: int) -> SourceSpec: + draw_value = self._source_draw_value(context, draw_count) + for source, upper in zip(self.sources, self._cumulative_weights, strict=True): + if draw_value < upper: + return source + return self.sources[-1] + + def iter_samples( + self, + context: SourceContext, + *, + state: dict | None = None, + ) -> Iterable[dict]: + live_state = self.normalize_state(state) + iterators: dict[str, object] = {} + + while True: + draw_count = int(live_state["draw_count"]) + source = self._pick_source(context, draw_count) + + while True: + child_state = live_state["sources"][source.name] + child_iter = iterators.get(source.name) + if child_iter is None: + raw_iter = source.adapter.iter_samples(context, state=child_state) + child_iter = iter(source.pipeline(raw_iter)) + iterators[source.name] = child_iter + + try: + sample = dict(next(child_iter)) + except StopIteration: + if source.adapter.is_cycle_start_state(child_state): + raise RuntimeError( + "Weighted source yielded no samples for this worker. " + f"source={source.name!r}, worker={context.global_worker_id}, " + f"epoch={context.epoch}" + ) + iterators.pop(source.name, None) + live_state["sources"][source.name] = source.adapter.advance_cycle( + child_state + ) + continue + + next_child_state = sample.pop("_adapter_state", None) + if next_child_state is None: + raise RuntimeError( + f"{source.adapter.__class__.__name__} must attach '_adapter_state' to samples." + ) + live_state["sources"][source.name] = source.adapter.clone_state( + next_child_state + ) + live_state["draw_count"] = draw_count + 1 + sample["source_name"] = source.name + sample["_adapter_state"] = self.clone_state(live_state) + yield sample + break diff --git a/src/dots_tts/data/streaming.py b/src/dots_tts/data/streaming.py new file mode 100644 index 0000000000000000000000000000000000000000..4356b9faea69635bb23203bafaaaf00fdd0092c8 --- /dev/null +++ b/src/dots_tts/data/streaming.py @@ -0,0 +1,400 @@ +from __future__ import annotations + +import math +import multiprocessing as mp +from collections.abc import Iterable +from copy import deepcopy + +from torch.utils.data import DataLoader, IterableDataset, get_worker_info + +from dots_tts.data.batchers import OnlineBatcher +from dots_tts.utils.profiling import ensure_data_profiler +from dots_tts.data.source_adapters.base_adapter import BaseSourceAdapter, SourceContext + +_TRACKING_KEY = "__tracking_state__" +_RESUME_TOPOLOGY_KEY = "resume_topology" + + +def identity_collate(sample): + return sample + + +class StreamingSampleDataset(IterableDataset): + def __init__( + self, + *, + source: BaseSourceAdapter, + rank: int, + world_size: int, + seed: int, + ): + self.source = source + self.rank = int(rank) + self.world_size = int(world_size) + self.seed = int(seed) + self._epoch = mp.Value("q", 0) + self._pending_resume_state: dict | None = None + + def load_state_dict(self, state: dict | None) -> None: + self._pending_resume_state = deepcopy(state) if state else None + + def set_epoch(self, epoch: int) -> None: + with self._epoch.get_lock(): + self._epoch.value = int(epoch) + + def _current_epoch(self) -> int: + with self._epoch.get_lock(): + return int(self._epoch.value) + + def _take_resume_state(self, epoch: int) -> dict | None: + if ( + self._pending_resume_state is None + or int(self._pending_resume_state.get("epoch", -1)) != int(epoch) + ): + return None + state = deepcopy(self._pending_resume_state) + self._pending_resume_state = None + return state + + @staticmethod + def _validate_resume_topology( + resume_state: dict, + *, + context: SourceContext, + loader_num_workers: int, + ) -> None: + resume_topology = resume_state.get(_RESUME_TOPOLOGY_KEY) + if not isinstance(resume_topology, dict): + raise RuntimeError( + "Resume state is missing required worker topology metadata." + ) + expected_world_size = int(resume_topology["world_size"]) + expected_num_workers = int(resume_topology["loader_num_workers"]) + expected_global_worker_count = int(resume_topology["global_worker_count"]) + current_num_workers = int(loader_num_workers) + current_global_worker_count = int(context.global_worker_count) + if ( + expected_world_size != int(context.world_size) + or expected_num_workers != current_num_workers + or expected_global_worker_count != current_global_worker_count + ): + raise RuntimeError( + "Resume requires the same data worker topology as the saved state. " + f"saved(world_size={expected_world_size}, " + f"num_workers_per_rank={expected_num_workers}, " + f"global_worker_count={expected_global_worker_count}), " + f"current(world_size={context.world_size}, " + f"num_workers_per_rank={current_num_workers}, " + f"global_worker_count={current_global_worker_count})." + ) + + def __iter__(self) -> Iterable[dict]: + worker_info = get_worker_info() + if worker_info is None: + worker_id = 0 + loader_num_workers = 0 + effective_num_workers = 1 + else: + worker_id = worker_info.id + loader_num_workers = worker_info.num_workers + effective_num_workers = worker_info.num_workers + + epoch = self._current_epoch() + context = SourceContext( + epoch=epoch, + rank=self.rank, + world_size=self.world_size, + worker_id=worker_id, + num_workers=effective_num_workers, + seed=self.seed, + ) + resume_state = self._take_resume_state(epoch) + if resume_state is not None: + self._validate_resume_topology( + resume_state, + context=context, + loader_num_workers=loader_num_workers, + ) + worker_state = ( + None + if resume_state is None + else (resume_state.get("workers") or {}).get(str(context.global_worker_id)) + ) + sample_iter = self.source.iter_samples( + context, + state=None if worker_state is None else worker_state.get("adapter_state"), + ) + for sample in sample_iter: + sample["data_worker_id"] = context.worker_id + sample["data_global_worker_id"] = context.global_worker_id + yield sample + + +class _DataStateTracker: + def __init__(self, *, num_tokens_per_epoch: int | None): + self.num_tokens_per_epoch = ( + None if num_tokens_per_epoch is None else int(num_tokens_per_epoch) + ) + self._pending_state: dict | None = None + self._reset_for_epoch(epoch=0) + + def _reset_for_epoch(self, *, epoch: int) -> None: + self.epoch = int(epoch) + self.samples_emitted = 0 + self.num_text_tokens = 0 + self.num_audio_tokens = 0 + self.num_total_tokens = 0 + self.workers: dict[str, dict] = {} + self._next_sample_order_by_worker: dict[str, int] = {} + + def load_state_dict(self, state: dict | None) -> None: + self._pending_state = deepcopy(state) if state else None + + def set_epoch(self, epoch: int) -> None: + if self._pending_state is not None and int( + self._pending_state.get("epoch", -1) + ) == int(epoch): + state = deepcopy(self._pending_state) + self._pending_state = None + self.epoch = int(state.get("epoch", epoch)) + self.samples_emitted = int(state.get("samples_emitted", 0)) + self.num_text_tokens = int(state.get("num_text_tokens", 0)) + self.num_audio_tokens = int(state.get("num_audio_tokens", 0)) + self.num_total_tokens = int(state.get("num_total_tokens", 0)) + self.workers = deepcopy(state.get("workers") or {}) + self._next_sample_order_by_worker = { + worker_key: int((worker_state or {}).get("sample_order", -1)) + 1 + for worker_key, worker_state in self.workers.items() + } + return + self._reset_for_epoch(epoch=int(epoch)) + + def should_stop(self) -> bool: + return ( + self.num_tokens_per_epoch is not None + and self.num_total_tokens >= self.num_tokens_per_epoch + ) + + def stage_sample(self, sample: dict) -> dict: + item = dict(sample) + worker_key = str(item.pop("data_global_worker_id")) + item.pop("data_worker_id", None) + adapter_state = item.pop("_adapter_state", None) + sample_order = int(self._next_sample_order_by_worker.get(worker_key, 0)) + self._next_sample_order_by_worker[worker_key] = sample_order + 1 + item[_TRACKING_KEY] = { + "worker_key": worker_key, + "adapter_state": deepcopy(adapter_state), + "sample_order": sample_order, + "num_text_tokens": int(item["num_text_tokens"]), + "num_audio_tokens": int(item["num_audio_tokens"]), + "num_total_tokens": int( + item.get("num_total_tokens", item["input_ids_length"]) + ), + } + return item + + def _pop_tracking(self, sample: dict) -> tuple[dict, dict]: + item = dict(sample) + tracking = item.pop(_TRACKING_KEY, None) + if not isinstance(tracking, dict): + raise RuntimeError("Tracked sample is missing internal resume metadata.") + return item, tracking + + def _advance_worker(self, tracking: dict) -> None: + adapter_state = tracking.get("adapter_state") + if adapter_state is None: + return + worker_key = str(tracking["worker_key"]) + sample_order = int(tracking.get("sample_order", -1)) + current_state = self.workers.get(worker_key) + current_order = int((current_state or {}).get("sample_order", -1)) + if current_order >= sample_order: + return + self.workers[worker_key] = { + "adapter_state": deepcopy(adapter_state), + "sample_order": sample_order, + } + + def mark_samples_dropped(self, samples: list[dict]) -> None: + for sample in samples: + _, tracking = self._pop_tracking(sample) + self._advance_worker(tracking) + + def commit_batch(self, samples: list[dict]) -> list[dict]: + committed: list[dict] = [] + for sample in samples: + item, tracking = self._pop_tracking(sample) + self._advance_worker(tracking) + self.samples_emitted += 1 + self.num_text_tokens += int(tracking["num_text_tokens"]) + self.num_audio_tokens += int(tracking["num_audio_tokens"]) + self.num_total_tokens += int(tracking["num_total_tokens"]) + committed.append(item) + return committed + + def state_dict(self) -> dict: + return { + "epoch": int(self.epoch), + "samples_emitted": int(self.samples_emitted), + "num_text_tokens": int(self.num_text_tokens), + "num_audio_tokens": int(self.num_audio_tokens), + "num_total_tokens": int(self.num_total_tokens), + "workers": deepcopy(self.workers), + "num_tokens_per_epoch": self.num_tokens_per_epoch, + } + + +class BatchedDataStream: + def __init__( + self, + *, + sample_dataset: StreamingSampleDataset, + data_cfg, + tokenizer, + num_tokens_per_epoch: int | None, + profiler=None, + ): + from dots_tts.data.collator import PadCollator + + self.sample_dataset = sample_dataset + self.profiler = ensure_data_profiler(profiler) + llm_token_rate = ( + float(data_cfg.train_audio_sample_rate) + / float(data_cfg.audio_samples_per_llm_token) + ) + self.batcher = OnlineBatcher( + max_audio_tokens_in_batch=max( + 1, + math.ceil(float(data_cfg.max_audio_seconds_in_batch) * llm_token_rate), + ), + max_text_tokens_in_batch=data_cfg.max_text_tokens_in_batch, + max_batch_size=data_cfg.max_samples_per_batch, + sample_pool_size=data_cfg.bucketing_pool_size, + profiler=self.profiler, + ) + self.sample_loader = None + self.collator = PadCollator(tokenizer) + self.data_state = _DataStateTracker( + num_tokens_per_epoch=num_tokens_per_epoch + ) + self._decision_iterator = None + self._sample_iterator = None + self._pending_batch = None + self._pending_samples = None + + def attach_loader(self, loader: DataLoader) -> None: + self.sample_loader = loader + + def close(self) -> None: + self._reset_iteration_state() + self.sample_loader = None + + def load_state_dict(self, state: dict | None) -> None: + self.data_state.load_state_dict(state) + self.sample_dataset.load_state_dict(state) + self._reset_iteration_state() + + def state_dict(self) -> dict: + if self.sample_loader is None: + raise RuntimeError("BatchedDataStream has no attached sample loader.") + if self._pending_batch is not None or self._pending_samples is not None: + raise RuntimeError( + "Cannot serialize BatchedDataStream while a batch is pending commit." + ) + loader_num_workers = int(getattr(self.sample_loader, "num_workers", 0)) + effective_num_workers = max(1, loader_num_workers) + state = self.data_state.state_dict() + state[_RESUME_TOPOLOGY_KEY] = { + "world_size": int(self.sample_dataset.world_size), + "loader_num_workers": loader_num_workers, + "global_worker_count": int(self.sample_dataset.world_size) + * effective_num_workers, + } + return state + + def set_epoch(self, epoch: int) -> None: + self.sample_dataset.set_epoch(epoch) + self.data_state.set_epoch(epoch) + self._reset_iteration_state() + + def _reset_iteration_state(self) -> None: + close_iterator = getattr(self._decision_iterator, "close", None) + if callable(close_iterator): + close_iterator() + self._decision_iterator = None + self._sample_iterator = None + self._pending_batch = None + self._pending_samples = None + + def _iter_staged_samples(self): + if self.sample_loader is None: + raise RuntimeError("BatchedDataStream has no attached sample loader.") + self._sample_iterator = iter(self.sample_loader) + profiler = self.profiler + try: + while True: + if self.data_state.should_stop(): + return + try: + with profiler.measure("main.loader_wait_next_sample"): + sample = next(self._sample_iterator) + except StopIteration: + return + if sample is None: + continue + with profiler.measure("main.stage_sample"): + staged = self.data_state.stage_sample(sample) + yield staged + finally: + self._sample_iterator = None + + def _decision_stream(self): + if self._decision_iterator is None: + self._decision_iterator = iter( + self.batcher.build_decisions(self._iter_staged_samples()) + ) + return self._decision_iterator + + def peek_batch(self) -> tuple[dict | None, bool]: + if self._pending_batch is not None: + return self._pending_batch, True + + for decision in self._decision_stream(): + if decision.dropped_samples: + self.data_state.mark_samples_dropped(decision.dropped_samples) + if not decision.batch_samples: + continue + self._pending_samples = decision.batch_samples + with self.profiler.measure( + "main.collate_batch", + count=len(decision.batch_samples), + ): + self._pending_batch = self.collator(decision.batch_samples) + return self._pending_batch, True + return None, False + + def commit_batch(self) -> dict: + if self._pending_batch is None or self._pending_samples is None: + raise RuntimeError("BatchedDataStream has no pending batch to commit.") + pending_batch = self._pending_batch + self.data_state.commit_batch(self._pending_samples) + self._pending_batch = None + self._pending_samples = None + return pending_batch + + def discard_batch(self) -> None: + if self._pending_batch is None or self._pending_samples is None: + raise RuntimeError("BatchedDataStream has no pending batch to discard.") + self._pending_batch = None + self._pending_samples = None + + def __iter__(self): + while True: + batch, has_batch = self.peek_batch() + if not has_batch: + return + self.commit_batch() + yield batch + if self.data_state.should_stop(): + return diff --git a/src/dots_tts/models/__init__.py b/src/dots_tts/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..482ee568273d12024e3b0bff255e4150a15a50d0 --- /dev/null +++ b/src/dots_tts/models/__init__.py @@ -0,0 +1 @@ +"""Model families.""" diff --git a/src/dots_tts/models/dots_tts/__init__.py b/src/dots_tts/models/dots_tts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fe9bc4ee612df1718db3054ac4383cf1f6c5a4b4 --- /dev/null +++ b/src/dots_tts/models/dots_tts/__init__.py @@ -0,0 +1 @@ +"""dots_tts model package.""" diff --git a/src/dots_tts/models/dots_tts/config.py b/src/dots_tts/models/dots_tts/config.py new file mode 100644 index 0000000000000000000000000000000000000000..56a1a61991693c091725ce9e6f5761f453f92f76 --- /dev/null +++ b/src/dots_tts/models/dots_tts/config.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from dots_tts.config.base import ConfigBase, StrictConfigBase +from dots_tts.modules.vocoder.config import AudioVAEConfig + + +class _EncoderConfig(ConfigBase): + num_layers: int = 6 + num_heads: int = 16 + hidden_size: int = 1024 + ffn_hidden_size: int = 4096 + modulation: bool = False + qkv_bias: bool = False + qk_norm: bool = False + attn_dropout: float = 0.0 + dropout: float = 0.0 + norm_layer: str = "LayerNorm" + alibi_bias: bool = False + rotary_bias: bool = False + rotary_theta: float | None = 10000 + input_dim: int = 1024 + causal: bool = True + + +class _DiTConfig(ConfigBase): + num_layers: int = 18 + num_heads: int = 16 + hidden_size: int = 1024 + ffn_hidden_size: int = 4096 + modulation: bool = True + qkv_bias: bool = False + qk_norm: bool = False + attn_dropout: float = 0.0 + dropout: float = 0.0 + norm_layer: str = "LayerNorm" + alibi_bias: bool = False + rotary_bias: bool = True + rotary_theta: float | None = 10000 + + +class LossConfig(StrictConfigBase): + ce_weight: float = 1.0 + fm_weight: float = 1.0 + eos_weight: float = 1.0 + + +class MeanFlowConfig(ConfigBase): + enabled: bool = False + use_duration_embedding: bool = True + + +class ModelConfig(ConfigBase): + model_type: str = "dots_tts" + latent_dim: int + patch_size: int + cfg_droprate: float = 0.2 + PatchEncoder: _EncoderConfig + DiT: _DiTConfig + vocoder: AudioVAEConfig + fm_sigma: float = 0.0 + xvec_drop_rate: float = 0.2 + campplus_embedding_size: int | None = 512 + xvec_max_audio_seconds: float = 10.0 + meanflow: MeanFlowConfig | None = None + + +__all__ = [ + "LossConfig", + "MeanFlowConfig", + "ModelConfig", +] diff --git a/src/dots_tts/models/dots_tts/core.py b/src/dots_tts/models/dots_tts/core.py new file mode 100644 index 0000000000000000000000000000000000000000..c0666c82b1ce4c8f9a7cc3099816b927dcbeaacc --- /dev/null +++ b/src/dots_tts/models/dots_tts/core.py @@ -0,0 +1,910 @@ +import copy +from dataclasses import dataclass +from typing import Any, Callable + +import torch +import torch.nn as nn +from einops import rearrange +from loguru import logger +from torch.nn.utils.rnn import pad_sequence +from torchdiffeq import odeint +from transformers import Qwen2Config, Qwen2ForCausalLM + +from dots_tts.models.dots_tts.config import ModelConfig +from dots_tts.modules.backbone.dit import DiT +from dots_tts.modules.backbone.semantic_encoder import VAESemanticEncoder +from dots_tts.utils.tokenizer import ( + AUDIO_COMP_SPAN_TOKEN, + AUDIO_GEN_SPAN_TOKEN, + TEXT_COND_END_TOKEN, + require_token_id, +) +from dots_tts.utils.util import get_mask_from_lengths, mask_data + + +@dataclass(frozen=True) +class DotsTtsForwardOutput: + llm_logits: torch.Tensor + pred: torch.Tensor + target: torch.Tensor + eos_out: torch.Tensor + + +class DotsTtsCore(nn.Module): + # region Module construction + def __init__( + self, + config: ModelConfig, + llm_config: Qwen2Config, + tokenizer=None, + *, + latent_stats_path, + ): + super().__init__() + self.config = config + self.fm_hidden_size = config.DiT.hidden_size + self.hidden_patch_size = 1 + self.cfg_droprate = config.get("cfg_droprate", 0.2) + self.latent_patch_size = config.patch_size + self.latent_dim = config.latent_dim + self.xvec_dim = config.campplus_embedding_size + self.xvec_drop_rate = config.get("xvec_drop_rate", 0.2) + + # Setup tokenizer + self.tokenizer = tokenizer + if self.tokenizer is None: + raise RuntimeError("Tokenizer must be provided before building the model.") + if llm_config is None: + raise RuntimeError("LLM config must be provided before building the model.") + self.pad_token_id = getattr(self.tokenizer, "pad_token_id", None) + self.audio_gen_span_id = require_token_id(self.tokenizer, AUDIO_GEN_SPAN_TOKEN) + self.audio_comp_span_id = require_token_id( + self.tokenizer, AUDIO_COMP_SPAN_TOKEN + ) + self.text_cond_end_id = require_token_id(self.tokenizer, TEXT_COND_END_TOKEN) + + # Setup LLM with language modeling head so we can obtain logits directly + llm_config = copy.deepcopy(llm_config) + llm_config.vocab_size = len(self.tokenizer) + self.llm = Qwen2ForCausalLM._from_config( + llm_config, + dtype=torch.float32, + ) + self.llm_hidden_size = self.llm.config.hidden_size + + self.patch_encoder = VAESemanticEncoder( + in_dim=self.latent_dim, + out_dim=self.llm_hidden_size, + config=config, + ) + + # Setup Flow matching related modules + self.hidden_proj = nn.Linear(self.llm_hidden_size, self.fm_hidden_size) + self.latent_proj = nn.Linear(self.latent_dim, self.fm_hidden_size) + self.coordinate_proj = nn.Linear(self.latent_dim, self.fm_hidden_size) + self.xvec_proj = nn.Sequential( + nn.Linear(self.xvec_dim, self.fm_hidden_size), + nn.LayerNorm(self.fm_hidden_size), + ) + self.meanflow_config = config.meanflow if config.meanflow is not None else None + self.mode = ( + "meanflow" + if self.meanflow_config is not None and self.meanflow_config.enabled + else "flow_matching" + ) + dit_mode = ( + "meanflow" + if self.mode == "meanflow" + and self.meanflow_config.use_duration_embedding + else "flow_matching" + ) + self.velocity_field_predictor = DiT( + in_dim=self.fm_hidden_size, + out_dim=self.latent_dim, + transformer_config=config.DiT, + mode=dit_mode, + ) + + # Setup eos predictor + self.eos_proj = nn.Sequential( + nn.Linear(self.llm_hidden_size, self.llm_hidden_size), + nn.SiLU(), + nn.Linear(self.llm_hidden_size, 2), + ) + + # Helpers + self.fm_helper = FlowMatchingHelper(sigma=config.get("fm_sigma", 0.0)) + self.causal_helper = CausalHelper() + self.io_helper = IOHelper(latent_stats_path=latent_stats_path) + self.audio_span_token_ids: list[int] = [ + self.audio_gen_span_id, + self.audio_comp_span_id, + ] + # endregion Module construction + + # region Training forward path + def forward(self, data: dict[str, Any]) -> DotsTtsForwardOutput: + input_ids: torch.Tensor = data["input_ids"] + input_ids_lengths: torch.Tensor = data["input_ids_lengths"] + input_span_mask: torch.Tensor = data["input_span_mask"] + output_span_mask: torch.Tensor = data["output_span_mask"] + batch_size = input_ids.size(0) + device = input_ids.device + + latents: torch.Tensor | None = data.get("latents") + latents_sampled: torch.Tensor | None = data.get("latents_sampled") + latent_lengths: torch.Tensor | None = data.get("latent_lengths") + has_latents = latents is not None or latents_sampled is not None + + patch_embeddings: torch.Tensor | None + valid_patch_counts: torch.Tensor | None + if has_latents: + if latents_sampled is None: + latents_sampled = self.io_helper.sample_from_latent(latents) + patch_embeddings = self.patch_encoder( + latents_sampled, x_lens=latent_lengths + ) + valid_patch_counts = latent_lengths // self.latent_patch_size + latents_sampled = self.io_helper.normalize(latents_sampled) + else: + latents_sampled = None + patch_embeddings = None + valid_patch_counts = torch.zeros( + batch_size, dtype=torch.long, device=device + ) + + input_span_counts = input_span_mask.sum(dim=1) + if input_span_counts.sum() > 0 and patch_embeddings is None: + raise RuntimeError( + "Found audio span tokens but no latents provided to compute patch embeddings." + ) + + # Token embeddings with audio span replacement + inputs_embeds = self.llm.get_input_embeddings()(input_ids) + if patch_embeddings is not None: + inputs_embeds = inputs_embeds.clone() + patch_embeddings = patch_embeddings.to(inputs_embeds.dtype) + for b in range(batch_size): + span_num = input_span_counts[b].item() + if span_num == 0: + continue + expected = valid_patch_counts[b].item() + if expected != span_num: + raise RuntimeError( + f"Mismatch between span tokens ({span_num}) and latent patches ({expected}) for sample {b}." + ) + indices = input_span_mask[b].nonzero(as_tuple=False).squeeze(-1) + inputs_embeds[b, indices, :] = patch_embeddings[b, :span_num, :] + + # LLM forward pass to obtain logits & hidden states + _llm_attn_mask, llm_seq_mask, _ = self.causal_helper.create_causal_mask_and_pos( + seq_lens=input_ids_lengths, max_len=input_ids.size(1) + ) + llm_outputs = self.llm( + inputs_embeds=inputs_embeds, + attention_mask=llm_seq_mask.long(), + use_cache=False, + output_hidden_states=True, + return_dict=True, + ) + llm_logits = llm_outputs.logits # [B, L, V] + llm_hidden = llm_outputs.hidden_states[-1] # [B, L, H] + + # eos prediction, before cfg masking + eos = self.eos_proj(llm_hidden.detach()) + + # Flow matching forward + total_patches = int(output_span_mask.sum().item()) + if total_patches > 0 and latents_sampled is None: + raise RuntimeError("Flow matching requested but latents are missing.") + if total_patches > 0: + xvec_cond = self.xvec_proj(data["xvector"]) + vocal_mask = data.get("vocal_mask") + if vocal_mask is None: + vocal_mask = torch.ones((batch_size,), device=device, dtype=torch.bool) + xvec_drop_mask = ( + torch.empty((batch_size,), device=device, dtype=torch.float32).uniform_( + 0, 1 + ) + < self.xvec_drop_rate + ) + xvec_drop_mask = xvec_drop_mask & vocal_mask + xvec_cond = mask_data(xvec_cond, xvec_drop_mask) + + hiddens_for_fm = torch.where( + output_span_mask.unsqueeze(-1), llm_hidden, inputs_embeds + ) + + # Prepare DiT inputs + ( + fm_seq, + target, + fm_attn_mask, + fm_seq_mask, + fm_pos_ids, + times, + fm_prefix_lengths, + fm_gen_lengths, + fm_gen_patch_size, + ) = self.io_helper.prepare_inputs_for_dit( + hiddens=hiddens_for_fm, + hidden_lens=input_ids_lengths, + latents=latents_sampled, + latent_lens=latent_lengths, + hidden_proj=self.hidden_proj, + latent_proj=self.latent_proj, + noisy_proj=self.coordinate_proj, + span_mask=output_span_mask, + hidden_patch_size=self.hidden_patch_size, + latent_patch_size=self.latent_patch_size, + fm_helper=self.fm_helper, + cfg_droprate=self.cfg_droprate, + ) + + # Predict velocity field + vt = self.velocity_field_predictor( + x=fm_seq, + timesteps=times, + pos_ids=fm_pos_ids, + mask=fm_seq_mask, + attn_mask=fm_attn_mask, + return_hidden_stats=False, + g_cond=xvec_cond, + ) + + # Get predictions and targets + pred = self.io_helper.get_dit_outputs( + pred_v=vt, + fm_prefix_lengths=fm_prefix_lengths, + fm_gen_lengths=fm_gen_lengths, + fm_gen_patch_size=fm_gen_patch_size, + latent_patch_size=self.latent_patch_size, + ) + else: + # Dummy forward for velocity_field_predictor to keep gradients connected in DDP + dummy_length = self.latent_patch_size + dummy_seq_h = llm_hidden.new_zeros((1, dummy_length, self.llm_hidden_size)) + dummy_seq_h = self.hidden_proj(dummy_seq_h) * 0.0 # dummy op for ddp + dummy_seq_l = llm_hidden.new_zeros((1, dummy_length, self.latent_dim)) + dummy_seq_l = self.latent_proj(dummy_seq_l) * 0.0 # dummy op for ddp + dummy_seq_c = llm_hidden.new_zeros((1, dummy_length, self.latent_dim)) + dummy_seq_c = self.coordinate_proj(dummy_seq_c) * 0.0 # dummy op for ddp + dummy_seq = dummy_seq_h + dummy_seq_l + dummy_seq_c + dummy_times = torch.zeros((1,), device=device, dtype=torch.float32) + dummy_attn_mask = torch.ones( + (1, dummy_length, dummy_length), device=device, dtype=torch.bool + ) + dummy_out = self.velocity_field_predictor( + x=dummy_seq, + timesteps=dummy_times, + attn_mask=dummy_attn_mask, + ) + pred = dummy_out[:, -self.latent_patch_size :, :] + target = pred.detach() + + return DotsTtsForwardOutput( + llm_logits=llm_logits, + pred=pred, + target=target, + eos_out=eos, + ) + # endregion Training forward path + + # region Autoregressive and flow-matching inference steps + @torch.no_grad() + def fm_solver_step( + self, + t: torch.Tensor, + z: torch.Tensor, + *, + input_sequence: torch.Tensor, + cfg_sequence: torch.Tensor, + attn_mask: torch.Tensor, + pos_ids: torch.Tensor | None, + hidden_size: int, + patch_size: int, + g_cond: torch.Tensor | None, + guidance_scale: torch.Tensor | float, + ) -> torch.Tensor: + batch_size = input_sequence.size(0) + if input_sequence.shape != cfg_sequence.shape: + raise ValueError( + "FM input_sequence and cfg_sequence must share the same shape." + ) + if input_sequence.size(1) < patch_size: + raise ValueError( + "FM input sequence must reserve at least one latent patch slot." + ) + latent_start = input_sequence.size(1) - patch_size + z = self.coordinate_proj(z) + z_c = input_sequence.clone() + z_c[:, latent_start:] = z + z_branches = [z_c] + g_cond_t = ( + None if g_cond is None else g_cond.to(device=z_c.device, dtype=z_c.dtype) + ) + g_cond_branches = None if g_cond_t is None else [g_cond_t] + + z_cfg = cfg_sequence.clone() + z_cfg[:, latent_start:] = z + z_branches.append(z_cfg) + if g_cond_branches is not None: + g_cond_branches.append(torch.zeros_like(g_cond_t)) + + z_z = torch.cat(z_branches, dim=0) + t_t = t.reshape(1).repeat(len(z_branches)) + if g_cond_branches is not None: + g_cond_t = torch.cat(g_cond_branches, dim=0) + vt = self.velocity_field_predictor( + x=z_z, + timesteps=t_t, + attn_mask=attn_mask, + pos_ids=pos_ids, + g_cond=g_cond_t, + hidden_size=patch_size * 2 + hidden_size, + patch_size=patch_size + 1, + ) + vt = vt[:, latent_start:] + vt_c = vt[:batch_size] + vt_u = vt[batch_size:] + if not torch.is_tensor(guidance_scale): + guidance_scale = vt_c.new_tensor(float(guidance_scale)) + else: + guidance_scale = guidance_scale.to(device=vt_c.device, dtype=vt_c.dtype) + return vt_c + guidance_scale * (vt_c - vt_u) + + @torch.no_grad() + def step_llm( + self, + inputs_embeds: torch.Tensor | None = None, + input_ids: torch.Tensor | None = None, + past_key_values: Any | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any | None]: + provided = int(inputs_embeds is not None) + int(input_ids is not None) + if provided != 1: + raise ValueError( + "Exactly one of inputs_embeds or input_ids must be provided to step_llm()." + ) + + if inputs_embeds is not None: + pass + else: + inputs_embeds = self.llm.get_input_embeddings()(input_ids) + + outputs = self.llm( + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=True, + output_hidden_states=True, + return_dict=True, + ) + + hidden = outputs.hidden_states[-1] + logits = outputs.logits + past_key_values = outputs.past_key_values + + return inputs_embeds, hidden, logits, past_key_values + + @torch.no_grad() + def _meanflow_step_fm( + self, + *, + input_sequence: torch.Tensor, + attn_mask: torch.Tensor, + pos_ids: torch.Tensor | None, + patch_size: int, + g_cond: torch.Tensor | None = None, + nfe: int = 2, + solver_step: Callable[..., torch.Tensor] | None = None, + ) -> torch.Tensor: + if nfe <= 0: + raise ValueError(f"MeanFlow nfe must be positive, got {nfe}.") + batch_size = input_sequence.size(0) + device = input_sequence.device + dtype = input_sequence.dtype + solver_step = self.meanflow_solver_step if solver_step is None else solver_step + z = ( + torch.randn( + (batch_size, patch_size, self.latent_dim), + device=device, + dtype=dtype, + ) + ) + times = torch.linspace(0.0, 1.0, nfe + 1, device=device, dtype=dtype) + + for step in range(nfe): + t = times[step].expand(batch_size) + dt = (times[step + 1] - times[step]).expand(batch_size) + z = solver_step( + z, + t=t, + dt=dt, + input_sequence=input_sequence, + attn_mask=attn_mask, + pos_ids=pos_ids, + patch_size=patch_size, + g_cond=g_cond, + ).clone() + return z + + @torch.no_grad() + def meanflow_solver_step( + self, + z: torch.Tensor, + *, + t: torch.Tensor, + dt: torch.Tensor, + input_sequence: torch.Tensor, + attn_mask: torch.Tensor, + pos_ids: torch.Tensor | None, + patch_size: int, + g_cond: torch.Tensor | None, + ) -> torch.Tensor: + if input_sequence.size(1) < patch_size: + raise ValueError( + "MeanFlow input sequence must reserve at least one latent patch slot." + ) + latent_start = input_sequence.size(1) - patch_size + z_proj = self.coordinate_proj(z) + z_c = input_sequence.clone() + z_c[:, latent_start:] = z_proj + vt = self.velocity_field_predictor( + x=z_c, + timesteps=t, + duration=dt, + attn_mask=attn_mask, + pos_ids=pos_ids, + g_cond=g_cond, + ) + velocity = vt[:, latent_start:] + return z + velocity * dt.view(-1, 1, 1) + + @torch.no_grad() + def _flow_matching_step_fm( + self, + *, + input_sequence: torch.Tensor, + cfg_sequence: torch.Tensor, + attn_mask: torch.Tensor, + pos_ids: torch.Tensor | None, + hidden_size: int, + patch_size: int, + g_cond: torch.Tensor | None = None, + ode_method: str = "euler", + num_steps: int = 10, + guidance_scale: float = 3.0, + solver_step: Callable[..., torch.Tensor] | None = None, + ) -> torch.Tensor: + batch_size = input_sequence.size(0) + num_evals = 0 + solver_step = self.fm_solver_step if solver_step is None else solver_step + guidance_scale_tensor = input_sequence.new_tensor(float(guidance_scale)) + + # Prepare ODE solver + def solver(t, z): + nonlocal num_evals + num_evals += 1 + return solver_step( + t, + z, + input_sequence=input_sequence, + cfg_sequence=cfg_sequence, + attn_mask=attn_mask, + pos_ids=pos_ids, + hidden_size=hidden_size, + patch_size=patch_size, + g_cond=g_cond, + guidance_scale=guidance_scale_tensor, + ) + + # Prepare noise as initial coordinate + noise = torch.randn( + (batch_size, patch_size, self.latent_dim), + dtype=input_sequence.dtype, + device=input_sequence.device, + ) + # Solve + times = torch.tensor( + [0.0, 1.0], dtype=input_sequence.dtype, device=input_sequence.device + ) + if ode_method in ["euler", "midpoint", "rk4"]: # fixed step size methods + options = {"step_size": 1.0 / num_steps} + else: + logger.warning( + "Using adaptive step size ODE solver for FM, NFE is not guaranteed: " + "ode_method={}", + ode_method, + ) + options = {} + trajectory = odeint( + func=solver, + y0=noise, + t=times, + atol=1e-5, + rtol=1e-5, + method=ode_method, + options=options, + ) + # print(f"Expected NFE: {num_steps}, Actual NFE: {num_evals}") + return trajectory[-1] + + @torch.no_grad() + def step_fm( + self, + input_sequence: torch.Tensor, + cfg_sequence: torch.Tensor, + attn_mask: torch.Tensor, + pos_ids: torch.Tensor | None, + hidden_size: int, + patch_size: int, + g_cond: torch.Tensor | None = None, + ode_method: str = "euler", + num_steps: int = 10, + guidance_scale: float = 3.0, + solver_step: Callable[..., torch.Tensor] | None = None, + ) -> torch.Tensor: + if self.mode == "meanflow": + return self._meanflow_step_fm( + input_sequence=input_sequence, + attn_mask=attn_mask, + pos_ids=pos_ids, + patch_size=patch_size, + g_cond=g_cond, + nfe=num_steps, + solver_step=solver_step, + ) + + return self._flow_matching_step_fm( + input_sequence=input_sequence, + cfg_sequence=cfg_sequence, + attn_mask=attn_mask, + pos_ids=pos_ids, + hidden_size=hidden_size, + patch_size=patch_size, + g_cond=g_cond, + ode_method=ode_method, + num_steps=num_steps, + guidance_scale=guidance_scale, + solver_step=solver_step, + ) + # endregion Autoregressive and flow-matching inference steps + + +class FlowMatchingHelper: + """ + Base helper for computing x_t and u_t, given target x_1 and noise x_0 + ref: Flow matching for generative modeling, Lipman + """ + + def __init__(self, sigma=1e-5): + self.sigma = sigma + + def compute_mu_t(self, x1, t): + return t * x1 + + def compute_sigma_t(self, t): + return 1 - (1 - self.sigma) * t + + def sample_x_t(self, x0, x1, t): + mu_t = self.compute_mu_t(x1, t) + sigma_t = self.compute_sigma_t(t) + return mu_t + sigma_t * x0 + + def compute_u_t(self, x0, x1): + return x1 - (1 - self.sigma) * x0 + + def compute_xt_ut(self, x1, t=None, x0=None): + if x0 is None: + x0 = torch.randn_like(x1, device=x1.device) + if t is None: + t = torch.rand(x1.size(0), dtype=x1.dtype, device=x1.device) + times = t + t = t.reshape(-1, *([1] * (x1.dim() - 1))) + xt = self.sample_x_t(x0, x1, t) + ut = self.compute_u_t(x0, x1) + return xt, ut, times + + +class CausalHelper: + def create_causal_mask_and_pos(self, seq_lens, max_len): + seq_mask = get_mask_from_lengths(seq_lens, max_len=max_len).unsqueeze(1) + causal_mask = ( + torch.ones((max_len, max_len), device=seq_lens.device).triu(1).bool() + ) + causal_mask = ~causal_mask.unsqueeze(0) + attn_mask = seq_mask & causal_mask + return attn_mask, seq_mask.squeeze(1), None + + def create_causal_chunk_mask_and_pos( + self, + batch_size, + C_lens, + Z_lens, + span_mask, + patch_size=8, + ): + device = C_lens.device + total_lens = C_lens + Z_lens + attn_mask = torch.zeros( + (batch_size, total_lens.max(), total_lens.max()), + device=device, + dtype=torch.bool, + ) + pos_ids = [] + # | C2C | | + # | Z2C | Z2Z | + for i in range(batch_size): + C_len = C_lens[i] + Z_len = Z_lens[i] + + # C2C parts are standard causal attention + attn_mask[i, :C_len, :C_len] = ( + torch.ones((C_len, C_len), device=device, dtype=torch.bool) + .triu(1) + .logical_not() + ) + # Position ids in C parts are 0, 1, 2, ..., n + c_pos = torch.arange(C_len, device=device, dtype=torch.float32) + + # Z2Z parts are block diag attention + assert Z_len % patch_size == 0, "Z_len must be multiple of patch_size" + attn_mask[i, C_len : C_len + Z_len, C_len : C_len + Z_len] = ( + torch.block_diag( + *[ + torch.ones( + (patch_size, patch_size), device=device, dtype=torch.bool + ) + ] + * (Z_len // patch_size) + ) + ) + + # Z2C parts is full attention before current patch latents + # build according to span_mask + j_indices = torch.arange(Z_len, device=device) + patch_indices = j_indices // patch_size + patch_in_c_indices = torch.where(span_mask[i])[0][patch_indices] + attn_mask[ + i, + C_len + j_indices.unsqueeze(1), + torch.arange(C_len, device=device).unsqueeze(0), + ] = torch.arange(C_len, device=device).unsqueeze( + 0 + ) < patch_in_c_indices.unsqueeze(1) + # Position ids in Z parts start from current patch latents index in C parts + z_pos = (patch_in_c_indices + j_indices % patch_size).to(torch.float32) + pos_ids.append(torch.cat([c_pos, z_pos])) + seq_mask = get_mask_from_lengths(total_lens, max_len=total_lens.max().item()) + pos_ids = pad_sequence(pos_ids, batch_first=True, padding_value=0.0).to( + C_lens.device + ) + return attn_mask, seq_mask, pos_ids + + +class IOHelper: + def __init__(self, latent_stats_path=None): + if latent_stats_path is not None: + latent_stats = torch.load(latent_stats_path, weights_only=False) + self.global_mean = torch.as_tensor(latent_stats["mean"]) + self.global_var = torch.as_tensor(latent_stats["var"]) + else: + self.global_mean = None + self.global_var = None + + def normalize(self, x): + if self.global_mean is not None and self.global_var is not None: + x = (x - self.global_mean.to(x.device)) / torch.sqrt( + self.global_var.to(x.device) + ) + return x + + def denormalize(self, x): + if self.global_mean is not None and self.global_var is not None: + x = x * torch.sqrt(self.global_var.to(x.device)) + self.global_mean.to( + x.device + ) + return x + + @staticmethod + def sample_from_latent(latent): + mean, log_std = latent.chunk(2, 1) + z = mean + torch.randn_like(mean) * torch.exp(log_std) + return z.transpose(1, 2) + + @staticmethod + def prepare_inputs_for_dit( + hiddens, + hidden_lens, + latents, + latent_lens, + hidden_proj, + latent_proj, + noisy_proj, + span_mask, + hidden_patch_size, + latent_patch_size, + fm_helper, + cfg_droprate=-1, + ): + assert hidden_patch_size == 1, "Hidden patch size > 1 is not supported." + + B, _, _, device = *hiddens.shape, hiddens.device + + # Gather span hidden states for flow matching using span_mask + span_hidden_list = [] + for b in range(B): + indices = span_mask[b].nonzero(as_tuple=False).squeeze(-1) + span_hidden_list.append(hiddens[b, indices, :]) + hiddens = pad_sequence(span_hidden_list, batch_first=True, padding_value=0.0) + hidden_lens = torch.tensor( + [t.size(0) for t in span_hidden_list], device=device, dtype=torch.long + ) + + # Update span_mask to be all True for the new lengths + max_len = hiddens.size(1) + span_mask = torch.arange(max_len, device=device).expand( + B, max_len + ) < hidden_lens.unsqueeze(1) + + # Prepare history latents + history_latents = latent_proj(latents) + fm_dim = history_latents.shape[-1] + assert (latent_patch_size * history_latents.size(1) % latents.size(1)) == 0 + latent_history_patch_size = ( + latent_patch_size * history_latents.size(1) // latents.size(1) + ) + + # Prepare llm hidden with cfg masking + cfg_mask = ( + torch.empty((B,), dtype=torch.float, device=latents.device).uniform_(0, 1) + < cfg_droprate + ) + hiddens = hidden_proj(mask_data(hiddens, cfg_mask)) + + # Prepare noise latents + xt, ut, times = fm_helper.compute_xt_ut(latents) + projected_noise = noisy_proj(xt) + + # Initialize empty fm_seq + hist_chunk_size = hidden_patch_size + latent_history_patch_size + valid_patch_counts = latent_lens // latent_patch_size + fm_prefix_lengths = hidden_lens + valid_patch_counts * ( + hist_chunk_size - hidden_patch_size + ) + fm_gen_lengths = latent_lens + valid_patch_counts * hidden_patch_size + fm_gen_patch_size = hidden_patch_size + latent_patch_size + fm_seq_lengths = fm_prefix_lengths + fm_gen_lengths + fm_seq = torch.zeros( + (B, fm_seq_lengths.max().item(), fm_dim), + dtype=history_latents.dtype, + device=device, + ) + fm_target = [] + patch_context_lengths = [] + history_latent_span_mask = torch.zeros( + (B, fm_seq_lengths.max().item()), dtype=torch.bool, device=device + ) # to mark start positions of each history latents + + # Fill fm_seq + for b in range(B): + # Step 1: Interleave hiddens at span positions with patched_latents + interleaved = [] + span_mask_b = span_mask[b, : hidden_lens[b]] + interleaved.append( + hiddens[b, : hidden_lens[b]][span_mask_b].reshape( + valid_patch_counts[b], hidden_patch_size, fm_dim + ) + ) + interleaved.append( + history_latents[ + b, : valid_patch_counts[b] * latent_history_patch_size, : + ].reshape(valid_patch_counts[b], latent_history_patch_size, fm_dim) + ) + interleaved = torch.cat(interleaved, dim=1) + interleaved = rearrange( + interleaved, "n h d -> (n h) d" + ) # [num_spans*hist_chunk_size, D] + + # Step 2: Build mapping from input positions to fm positions + position_increment = torch.where( + span_mask_b, hist_chunk_size, 1 + ) # span->hist_chunk_size, non-span->1 + fm_seq_positions = ( + torch.cumsum(position_increment, dim=0) - position_increment + ) + + # Step 3: Scatter non-span hiddens + non_span_mask = ~span_mask_b + non_span_indices = fm_seq_positions[non_span_mask] # [num_non_spans] + fm_seq[b, non_span_indices, :] = hiddens[b, : hidden_lens[b]][ + non_span_mask, : + ] + + # Step 4: Scatter interleaved span tokens + span_indices = fm_seq_positions[span_mask_b] # [num_spans] + span_indices_expanded = torch.stack( + [span_indices + i for i in range(hist_chunk_size)], dim=1 + ) # [num_spans, hist_chunk_size] + span_indices_flat = span_indices_expanded.reshape( + -1 + ) # [num_spans*hist_chunk_size] + fm_seq[b, span_indices_flat, :] = interleaved + history_latent_span_mask[b, span_indices] = True + patch_context_lengths.append(span_indices.clone()) + + # Step 5: Fill with noise latents at the end + noise_part = [] + span_mask_b = span_mask[b, : hidden_lens[b]] + noise_part.append( + hiddens[b, : hidden_lens[b]][span_mask_b].reshape( + valid_patch_counts[b], hidden_patch_size, fm_dim + ) + ) + noise_part.append( + projected_noise[b, : latent_lens[b], :].reshape( + valid_patch_counts[b], latent_patch_size, fm_dim + ) + ) + noise_part = torch.cat(noise_part, dim=1) + noise_part = rearrange(noise_part, "n h d -> (n h) d") + noise_start = fm_seq_positions[-1] + position_increment[-1] + noise_end = noise_start + fm_gen_lengths[b] + fm_seq[b, noise_start:noise_end, :] = noise_part + + # Step 6: prepare fm_target + ut_b = ut[b, : latent_lens[b], :] + fm_target.append(rearrange(ut_b, "(n p) d -> n p d", p=latent_patch_size)) + + # Construct fm_attn_mask and fm_pos_ids + fm_attn_mask, fm_seq_mask, fm_pos_ids = ( + CausalHelper().create_causal_chunk_mask_and_pos( + batch_size=B, + C_lens=fm_prefix_lengths, + Z_lens=fm_gen_lengths, + span_mask=history_latent_span_mask, + patch_size=fm_gen_patch_size, + ) + ) + fm_prefix_lengths = fm_prefix_lengths.unsqueeze(1) + fm_gen_lengths = fm_gen_lengths.unsqueeze(1) + fm_target = torch.cat(fm_target, dim=0) + results = [ + fm_seq, + fm_target, + fm_attn_mask, + fm_seq_mask, + fm_pos_ids, + times, + fm_prefix_lengths, + fm_gen_lengths, + fm_gen_patch_size, + ] + return tuple(results) + + @staticmethod + def get_dit_outputs( + pred_v, + fm_prefix_lengths, + fm_gen_lengths, + fm_gen_patch_size, + latent_patch_size, + ): + B, P = fm_prefix_lengths.shape + fm_pred = [] + for b in range(B): + p_offset = 0 + for p in range(P): + latents_b = pred_v[ + b, + p_offset + fm_prefix_lengths[b][p] : p_offset + + fm_prefix_lengths[b][p] + + fm_gen_lengths[b][p], + ] + latents_b = rearrange( + latents_b, "(n p) d -> n p d", p=fm_gen_patch_size + ) + # extract only the latent parts + latents_b = latents_b[:, -latent_patch_size:, :] + fm_pred.append(latents_b) + p_offset += fm_prefix_lengths[b][p] + fm_gen_lengths[b][p] + return torch.cat(fm_pred, dim=0) diff --git a/src/dots_tts/models/dots_tts/model.py b/src/dots_tts/models/dots_tts/model.py new file mode 100644 index 0000000000000000000000000000000000000000..d7f3e1c6989be57dc521057c2ec47a9b4b5a0c50 --- /dev/null +++ b/src/dots_tts/models/dots_tts/model.py @@ -0,0 +1,1958 @@ +from __future__ import annotations + +import json +import math +import os +import shutil +from dataclasses import dataclass +from functools import partial +from pathlib import Path +from typing import Any, Callable, Iterator + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from loguru import logger +from safetensors.torch import load_file, save_file +from transformers import AutoTokenizer, Qwen2Config + +from dots_tts.models.dots_tts.config import ModelConfig +from dots_tts.models.dots_tts.core import DotsTtsCore, DotsTtsForwardOutput +from dots_tts.modules.speaker.encoder import SpeakerXVectorFeatures +from dots_tts.modules.vocoder.bigvgan import AudioVAE +from dots_tts.training.losses import LossMasks, LossTerm, LossTerms +from dots_tts.utils.profiling import measure_inference +from dots_tts.utils.tokenizer import AUDIO_GEN_START_TOKEN, require_token_id +from dots_tts.utils.util import get_dtype + + +_AOTI_BACKENDS = {"aoti", "aot", "aotinductor", "aot_inductor"} + + +class _AotiMethodModule(nn.Module): + def __init__(self, owner: nn.Module, method_name: str): + super().__init__() + self.owner = owner + self.method_name = method_name + + def forward(self, *args, **kwargs): + raw_method = getattr(type(self.owner), self.method_name, None) + if raw_method is None: + return getattr(self.owner, self.method_name)(*args, **kwargs) + raw_callable = getattr(raw_method, "__wrapped__", raw_method) + return raw_callable(self.owner, *args, **kwargs) + + +class _LazyAotiCompiledMethod: + def __init__( + self, + *, + key: str, + owner: nn.Module, + method_name: str, + signature: tuple[Any, ...] | None, + ): + self.key = key + self.owner = owner + self.method_name = method_name + self.signature = signature + self.compiled: Callable[..., Any] | None = None + self.fallback: Callable[..., Any] | None = None + + def __call__(self, *args, **kwargs): + if self.compiled is not None: + return self.compiled(*args, **kwargs) + if self.fallback is not None: + return self.fallback(*args, **kwargs) + + try: + import spaces # noqa: PLC0415 + + if not hasattr(spaces, "aoti_compile"): + raise RuntimeError("spaces.aoti_compile is not available.") + exported = torch.export.export( + _AotiMethodModule(self.owner, self.method_name).eval(), + args=args, + kwargs=kwargs, + ) + self.compiled = spaces.aoti_compile(exported) + logger.info( + "AOTI compiled inference target: key={} method={} signature={}", + self.key, + self.method_name, + self.signature, + ) + return self.compiled(*args, **kwargs) + except Exception: + if os.environ.get("DOTS_TTS_AOTI_ALLOW_EAGER_FALLBACK", "0") != "1": + raise + logger.exception( + "AOTI compile failed; falling back to eager method: key={} method={} signature={}", + self.key, + self.method_name, + self.signature, + ) + self.fallback = getattr(self.owner, self.method_name) + return self.fallback(*args, **kwargs) + + +@dataclass +class _GenerateState: + llm_cache: Any | None = None + llm_hiddens: torch.Tensor | None = None + patch_encoder_state: Any | None = None + fm_seq_len: int = 0 + fm_capacity: int = 0 + fm_sequence: torch.Tensor | None = None + fm_cfg_sequence: torch.Tensor | None = None + fm_null_g_cond: torch.Tensor | None = None + end_flag: bool = False + + +@dataclass(frozen=True) +class _PromptConditioning: + prompt_patches: torch.Tensor | None = None + prompt_latents: torch.Tensor | None = None + g_cond: torch.Tensor | None = None + + +@dataclass(frozen=True) +class _GenerateLengthBucket: + size: int + + def run_warmup( + self, + model: "DotsTtsModel", + *, + precision: str, + ode_method: str, + num_steps: int, + guidance_scale: float, + ) -> None: + model._warmup_fm_bucket( + max_audio_patch_count=self.size, + precision=precision, + ode_method=ode_method, + num_steps=num_steps, + guidance_scale=guidance_scale, + ) + model._warmup_patch_encoder_bucket( + max_audio_patch_count=self.size, + precision=precision, + ) + device = next(model.core.parameters()).device + generation_schedule = torch.full( + (1, self.size + 1), + fill_value=model.core.audio_gen_span_id, + dtype=torch.long, + device=device, + ) + generation_schedule[0, 0] = model.audio_gen_start_id + warmup_inputs = {"generation_schedule": generation_schedule} + + for _ in model.generate_audio_stream( + warmup_inputs, + precision=precision, + ode_method=ode_method, + num_steps=num_steps, + guidance_scale=guidance_scale, + ): + return + raise RuntimeError( + f"Warmup produced no audio chunk for generate bucket {self.size}." + ) + + +class DotsTtsModel(nn.Module): + """Full train/infer model assembly around the dots.tts core network.""" + + _GENERATE_LENGTH_BUCKETS = ( + _GenerateLengthBucket(32), + _GenerateLengthBucket(64), + _GenerateLengthBucket(128), + _GenerateLengthBucket(256), + _GenerateLengthBucket(512), + _GenerateLengthBucket(1024), + ) + _COMPILE_TARGETS = frozenset( + { + "FM", + "patch_encoder", + "vocoder", + } + ) + _optimize_enabled = True + CONFIG_FILENAME = "config.json" + HF_MODEL_TYPE = "dots_tts" + HF_ARCHITECTURES = ["DotsTTSForConditionalGeneration"] + LATENT_STATS_FILENAME = "latent_stats.pt" + LLM_CONFIG_FILENAME = "llm_config.json" + MODEL_FILENAME = "model.safetensors" + VOCODER_FILENAME = "vocoder.safetensors" + SPEAKER_ENCODER_FILENAME = "speaker_encoder.safetensors" + _ARTIFACT_ALIASES = (("llm.lm_head.weight", "llm.model.embed_tokens.weight"),) + REQUIRED_ARTIFACT_FILES = ( + CONFIG_FILENAME, + LATENT_STATS_FILENAME, + LLM_CONFIG_FILENAME, + MODEL_FILENAME, + VOCODER_FILENAME, + SPEAKER_ENCODER_FILENAME, + ) + + # region Module assembly and checkpoint IO + def __init__( + self, + config: ModelConfig, + tokenizer, + latent_stats_path: str | Path, + llm_config: Qwen2Config, + ): + super().__init__() + self.config = config + self.tokenizer = tokenizer + self.latent_stats_path = Path(latent_stats_path) + self.audio_gen_start_id = require_token_id( + self.tokenizer, AUDIO_GEN_START_TOKEN + ) + + self.core = DotsTtsCore( + config, + llm_config=llm_config, + tokenizer=tokenizer, + latent_stats_path=self.latent_stats_path, + ) + self.vocoder = AudioVAE(config.vocoder).eval() + self.vocoder.remove_weight_norm() + self.hop_size = self.vocoder.hop_size + self.xvector_extractor = SpeakerXVectorFeatures( + sample_rate=self.vocoder.sample_rate, + campplus_embedding_size=config.campplus_embedding_size, + max_audio_seconds=config.xvec_max_audio_seconds, + ).eval() + + for param in self.vocoder.parameters(): + param.requires_grad = False + for param in self.xvector_extractor.parameters(): + param.requires_grad = False + self._optimize_enabled = True + self._compiled_models: dict[ + tuple[str, tuple[Any, ...] | None], Callable[..., Any] + ] = {} + self._compile_backend = os.environ.get( + "DOTS_TTS_COMPILE_BACKEND", + "torch_compile", + ).strip().lower() + self._static_generate_workspaces: dict[tuple[Any, ...], dict[str, Any]] = {} + self._fm_decode_workspaces: dict[tuple[Any, ...], dict[str, torch.Tensor]] = {} + + def set_optimize(self, optimize: bool) -> None: + self._optimize_enabled = bool(optimize) + if not self._optimize_enabled: + self._compiled_models.clear() + + def set_compile_backend(self, backend: str) -> None: + normalized_backend = (backend or "torch_compile").strip().lower() + if normalized_backend != self._compile_backend: + self._compiled_models.clear() + self._compile_backend = normalized_backend + + def export_compiled_models( + self, + ) -> dict[tuple[str, tuple[Any, ...] | None], Callable[..., Any]]: + exported: dict[tuple[str, tuple[Any, ...] | None], Callable[..., Any]] = {} + for cache_key, compiled in self._compiled_models.items(): + if isinstance(compiled, _LazyAotiCompiledMethod): + if compiled.compiled is not None: + exported[cache_key] = compiled.compiled + continue + exported[cache_key] = compiled + return exported + + def import_compiled_models( + self, + compiled_models: dict[tuple[str, tuple[Any, ...] | None], Callable[..., Any]], + ) -> None: + self._compiled_models.update(compiled_models) + + def set_cfg_droprate( + self, + cfg_droprate: float | None = None, + xvec_drop_rate: float | None = None, + ) -> None: + if cfg_droprate is not None: + self.config.cfg_droprate = cfg_droprate + self.core.config.cfg_droprate = cfg_droprate + self.core.cfg_droprate = cfg_droprate + + if xvec_drop_rate is not None: + self.config.xvec_drop_rate = xvec_drop_rate + self.core.config.xvec_drop_rate = xvec_drop_rate + self.core.xvec_drop_rate = xvec_drop_rate + + @classmethod + def _resolve_generate_length_bucket( + cls, + max_generate_length: int, + ) -> _GenerateLengthBucket: + requested = int(max_generate_length) + if requested <= 0: + raise ValueError("max_generate_length must be positive.") + for bucket in cls._GENERATE_LENGTH_BUCKETS: + if requested <= bucket.size: + return bucket + raise ValueError( + "max_generate_length exceeds the largest supported compile bucket: " + f"max_generate_length={requested} " + f"max_supported={cls._GENERATE_LENGTH_BUCKETS[-1].size}." + ) + + @torch.no_grad() + def run_warmup( + self, + *, + max_generate_length: int, + precision: str = "bfloat16", + ode_method: str = "euler", + num_steps: int = 10, + guidance_scale: float = 1.2, + ) -> None: + ceiling_bucket = self._resolve_generate_length_bucket(max_generate_length) + warmup_buckets = tuple( + bucket + for bucket in self._GENERATE_LENGTH_BUCKETS + if bucket.size <= ceiling_bucket.size + ) + bucket_sizes = [bucket.size for bucket in warmup_buckets] + logger.info( + "Inference warmup started: requested_max_generate_length={} bucket_sizes={}", + int(max_generate_length), + bucket_sizes, + ) + for bucket in warmup_buckets: + bucket.run_warmup( + self, + precision=precision, + ode_method=ode_method, + num_steps=num_steps, + guidance_scale=guidance_scale, + ) + logger.info( + "Inference warmup completed: requested_max_generate_length={} bucket_sizes={}", + int(max_generate_length), + bucket_sizes, + ) + + def _resolve_state_audio_patch_count(self, max_audio_patch_count: int) -> int: + requested = int(max_audio_patch_count) + if requested <= 0: + raise ValueError("max_audio_patch_count must be positive.") + if not self._optimize_enabled: + return requested + return self._resolve_generate_length_bucket(requested).size + + def _warmup_fm_bucket( + self, + *, + max_audio_patch_count: int, + precision: str, + ode_method: str, + num_steps: int, + guidance_scale: float, + ) -> None: + dtype = get_dtype(precision) + device = next(self.core.parameters()).device + use_amp = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16} + with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_amp): + state = self._allocate_generate_state( + max_audio_patch_count=max_audio_patch_count, + device=device, + dtype=dtype, + ) + state.fm_seq_len = state.fm_capacity + self._decode_next_audio( + state, + device=device, + g_cond=None, + ode_method=ode_method, + num_steps=num_steps, + guidance_scale=guidance_scale, + ) + + def _warmup_patch_encoder_bucket( + self, + *, + max_audio_patch_count: int, + precision: str, + ) -> None: + dtype = get_dtype(precision) + device = next(self.core.parameters()).device + state_dtype = dtype if device.type == "cuda" else torch.float32 + use_amp = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16} + with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_amp): + state_audio_patch_count = self._resolve_state_audio_patch_count( + max_audio_patch_count + ) + patch_encoder_state = self.core.patch_encoder.init_decode_state( + max_audio_patch_count=state_audio_patch_count, + batch_size=1, + device=device, + dtype=state_dtype, + ) + audio_patch = torch.zeros( + ( + 1, + self.core.patch_encoder.patch_size, + self.core.latent_dim, + ), + dtype=state_dtype, + device=device, + ) + audio_patch = self.core.io_helper.denormalize(audio_patch) + patch_encoder_decode = self._get_compiled_method( + "patch_encoder.decode_patch", + self.core.patch_encoder, + "decode_patch", + signature=self._patch_encoder_compile_signature(patch_encoder_state), + ) + positions = torch.arange( + self.core.patch_encoder.out_ds_rate, + device=device, + dtype=torch.long, + ) + with measure_inference("patch_encoder"): + patch_encoder_decode( + audio_patch, + patch_encoder_state.conv_tail, + patch_encoder_state.layer_caches, + positions, + ) + + def _compile_callable( + self, + key: str, + model: Callable[..., Any], + *, + signature: tuple[Any, ...] | None = None, + ) -> Callable[..., Any]: + compile_target = key.split(".", maxsplit=1)[0] + cache_key = (key, signature) + compiled = self._compiled_models.get(cache_key) + if compiled is None: + mode = ( + "default" + if key == "patch_encoder.decode_patch" + else "reduce-overhead" + ) + compiled = torch.compile( + model, + mode=mode, + fullgraph=True, + dynamic=False, + ) + self._compiled_models[cache_key] = compiled + logger.info( + "Compiled inference target: key={} target={} signature={}", + key, + compile_target, + signature, + ) + return compiled + + def _get_compiled_model( + self, + key: str, + model: Callable[..., Any], + *, + signature: tuple[Any, ...] | None = None, + ) -> Callable[..., Any]: + compile_target = key.split(".", maxsplit=1)[0] + if not self._optimize_enabled or compile_target not in self._COMPILE_TARGETS: + return model + return self._compile_callable( + key, + model, + signature=signature, + ) + + def _get_compiled_method( + self, + key: str, + owner: Any, + method_name: str, + *, + signature: tuple[Any, ...] | None = None, + ) -> Callable[..., Any]: + bound_method = getattr(owner, method_name) + compile_target = key.split(".", maxsplit=1)[0] + if not self._optimize_enabled or compile_target not in self._COMPILE_TARGETS: + return bound_method + + cache_key = (key, signature) + if self._compile_backend in _AOTI_BACKENDS: + compiled = self._compiled_models.get(cache_key) + if compiled is None: + compiled = _LazyAotiCompiledMethod( + key=key, + owner=owner, + method_name=method_name, + signature=signature, + ) + self._compiled_models[cache_key] = compiled + return compiled + + raw_method = getattr(type(owner), method_name) + raw_callable = getattr(raw_method, "__wrapped__", raw_method) + compiled = self._compile_callable( + key, + raw_callable, + signature=signature, + ) + return partial(compiled, owner) + + def _allocate_generate_state( + self, + *, + max_audio_patch_count: int, + device: torch.device, + dtype: torch.dtype, + ) -> _GenerateState: + state_dtype = dtype if device.type == "cuda" else torch.float32 + state_audio_patch_count = self._resolve_state_audio_patch_count( + max_audio_patch_count + ) + fm_capacity = state_audio_patch_count * ( + self.core.hidden_patch_size + self.core.latent_patch_size + ) + workspace_key = ( + state_audio_patch_count, + str(device), + state_dtype, + ) + workspace = self._static_generate_workspaces.get(workspace_key) + if workspace is None: + workspace = { + "fm_sequence": torch.zeros( + (1, fm_capacity, self.core.fm_hidden_size), + dtype=state_dtype, + device=device, + ), + "fm_cfg_sequence": torch.zeros( + (1, fm_capacity, self.core.fm_hidden_size), + dtype=state_dtype, + device=device, + ), + "fm_null_g_cond": torch.zeros( + (1, self.core.fm_hidden_size), + dtype=state_dtype, + device=device, + ), + } + self._static_generate_workspaces[workspace_key] = workspace + else: + workspace["fm_sequence"].zero_() + workspace["fm_cfg_sequence"].zero_() + + patch_encoder_state = None + if not self._optimize_enabled: + patch_encoder_state = self.core.patch_encoder.init_decode_state( + max_audio_patch_count=state_audio_patch_count, + batch_size=1, + device=device, + dtype=state_dtype, + ) + + return _GenerateState( + patch_encoder_state=patch_encoder_state, + fm_seq_len=0, + fm_capacity=fm_capacity, + fm_sequence=workspace["fm_sequence"], + fm_cfg_sequence=workspace["fm_cfg_sequence"], + fm_null_g_cond=workspace["fm_null_g_cond"], + ) + + @staticmethod + def _tensor_storage_signature(tensor: torch.Tensor) -> tuple: + return ( + tensor.untyped_storage().data_ptr(), + tensor.storage_offset(), + tuple(tensor.size()), + tuple(tensor.stride()), + tensor.dtype, + ) + + @classmethod + def _build_artifact_state_dict(cls, module) -> dict[str, torch.Tensor]: + state_dict = module.state_dict() + skip_keys = set() + + for redundant_key, canonical_key in cls._ARTIFACT_ALIASES: + redundant_tensor = state_dict.get(redundant_key) + canonical_tensor = state_dict.get(canonical_key) + if ( + redundant_tensor is not None + and canonical_tensor is not None + and cls._tensor_storage_signature(redundant_tensor) + == cls._tensor_storage_signature(canonical_tensor) + ): + skip_keys.add(redundant_key) + + cleaned_state_dict = {} + seen_storage = set() + for key, value in state_dict.items(): + if key in skip_keys: + continue + + storage_signature = cls._tensor_storage_signature(value) + if storage_signature in seen_storage: + continue + + seen_storage.add(storage_signature) + cleaned_state_dict[key] = value.detach().cpu().contiguous() + + return cleaned_state_dict + + @classmethod + def _restore_artifact_state_dict(cls, state_dict: dict, module) -> dict: + restored_state_dict = dict(state_dict) + for redundant_key, canonical_key in cls._ARTIFACT_ALIASES: + if ( + canonical_key in restored_state_dict + and redundant_key not in restored_state_dict + and redundant_key in module.state_dict() + ): + restored_state_dict[redundant_key] = restored_state_dict[canonical_key] + return restored_state_dict + + @classmethod + def _save_artifact_module(cls, module, path: Path) -> None: + save_file(cls._build_artifact_state_dict(module), path) + + @classmethod + def _load_artifact_module(cls, module, path: Path): + state_dict = load_file(path, device="cpu") + restored_state_dict = cls._restore_artifact_state_dict(state_dict, module) + mismatch = module.load_state_dict(restored_state_dict, strict=False) + if mismatch.missing_keys or mismatch.unexpected_keys: + raise RuntimeError(f"Failed to load {path}: {mismatch}") + return module + + @classmethod + def _validate_pretrained_directory( + cls, pretrained_model_name_or_path: str | Path + ) -> Path: + pretrained_path = Path(pretrained_model_name_or_path).expanduser().resolve() + missing_files = [ + name + for name in cls.REQUIRED_ARTIFACT_FILES + if not (pretrained_path / name).is_file() + ] + if missing_files: + raise FileNotFoundError( + f"Pretrained path {pretrained_path} is missing required files: {missing_files}" + ) + return pretrained_path + + @classmethod + def _load_pretrained_config(cls, pretrained_path: Path) -> ModelConfig: + return ModelConfig.model_validate( + json.loads( + (pretrained_path / cls.CONFIG_FILENAME).read_text(encoding="utf-8") + ) + ) + + @staticmethod + def _save_llm_config(llm_config: Qwen2Config, path: Path) -> None: + path.write_text( + json.dumps(llm_config.to_dict(), ensure_ascii=True, indent=2), + encoding="utf-8", + ) + + @staticmethod + def _load_llm_config(path: Path) -> Qwen2Config: + return Qwen2Config.from_dict(json.loads(path.read_text(encoding="utf-8"))) + + def _tie_llm_weights(self) -> None: + if hasattr(self.core.llm, "tie_weights"): + self.core.llm.tie_weights() + + def save_pretrained(self, save_directory: str | Path) -> Path: + save_directory = Path(save_directory) + save_directory.mkdir(parents=True, exist_ok=True) + + config_payload = self.config.to_declared_dict() + config_payload["model_type"] = self.HF_MODEL_TYPE + config_payload["architectures"] = list(self.HF_ARCHITECTURES) + (save_directory / self.CONFIG_FILENAME).write_text( + json.dumps(config_payload, ensure_ascii=True, indent=2), + encoding="utf-8", + ) + self._save_llm_config( + self.core.llm.config, + save_directory / self.LLM_CONFIG_FILENAME, + ) + self.tokenizer.save_pretrained(save_directory) + shutil.copy2( + self.latent_stats_path, + save_directory / self.LATENT_STATS_FILENAME, + ) + self._save_artifact_module(self.core, save_directory / self.MODEL_FILENAME) + self._save_artifact_module(self.vocoder, save_directory / self.VOCODER_FILENAME) + self._save_artifact_module( + self.xvector_extractor, + save_directory / self.SPEAKER_ENCODER_FILENAME, + ) + return save_directory + + def _load_pretrained_artifacts(self, pretrained_path: Path) -> None: + self.latent_stats_path = pretrained_path / self.LATENT_STATS_FILENAME + self.core.io_helper = type(self.core.io_helper)( + latent_stats_path=self.latent_stats_path + ) + self._load_artifact_module(self.core, pretrained_path / self.MODEL_FILENAME) + self._tie_llm_weights() + self._load_artifact_module( + self.vocoder, pretrained_path / self.VOCODER_FILENAME + ) + self._load_artifact_module( + self.xvector_extractor, + pretrained_path / self.SPEAKER_ENCODER_FILENAME, + ) + self.core.eval() + self.vocoder.eval() + self.xvector_extractor.eval() + + def load_pretrained_weights( + self, pretrained_model_name_or_path: str | Path + ) -> None: + pretrained_path = self._validate_pretrained_directory( + pretrained_model_name_or_path + ) + saved_config = self._load_pretrained_config(pretrained_path) + if saved_config.to_declared_dict() != self.config.to_declared_dict(): + raise ValueError( + f"Pretrained config at {pretrained_path} does not match the current model." + ) + saved_llm_config = self._load_llm_config( + pretrained_path / self.LLM_CONFIG_FILENAME + ) + if saved_llm_config.to_dict() != self.core.llm.config.to_dict(): + raise ValueError( + f"Pretrained LLM config at {pretrained_path} does not match the current model." + ) + self._load_pretrained_artifacts(pretrained_path) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str | Path): + logger.info( + "DotsTtsModel load started: pretrained_path={}", + pretrained_model_name_or_path, + ) + pretrained_model_name_or_path = cls._validate_pretrained_directory( + pretrained_model_name_or_path + ) + config = cls._load_pretrained_config(pretrained_model_name_or_path) + llm_config = cls._load_llm_config( + pretrained_model_name_or_path / cls.LLM_CONFIG_FILENAME + ) + logger.info( + "DotsTtsModel config loaded: pretrained_path={} sample_rate={} patch_size={}", + pretrained_model_name_or_path, + config.vocoder.sample_rate, + config.patch_size, + ) + tokenizer = AutoTokenizer.from_pretrained( + str(pretrained_model_name_or_path), + local_files_only=True, + ) + model = cls( + config, + tokenizer=tokenizer, + latent_stats_path=pretrained_model_name_or_path / cls.LATENT_STATS_FILENAME, + llm_config=llm_config, + ) + model._load_pretrained_artifacts(pretrained_model_name_or_path) + logger.info( + "DotsTtsModel load completed: pretrained_path={}", + pretrained_model_name_or_path, + ) + return model.eval() + + # endregion Module assembly and checkpoint IO + + # region Training batch preparation + @torch.no_grad() + def prepare_training_inputs(self, data: dict[str, Any]) -> dict[str, Any]: + self.vocoder.eval() + self.xvector_extractor.eval() + processed = dict(data) + sample: torch.Tensor | None = data.get("sample") + sample_lengths: torch.Tensor | None = data.get("sample_lengths") + + if sample is not None: + latents = self.vocoder.extract_latents(sample) + processed["latents"] = latents + if sample_lengths is not None: + processed["latent_lengths"] = sample_lengths // self.hop_size + else: + processed["latent_lengths"] = torch.full( + (latents.size(0),), + latents.size(-1), + dtype=torch.long, + device=latents.device, + ) + processed["latents_sampled"] = self.core.io_helper.sample_from_latent( + latents + ) + fbank = data.get("fbank") + fbank_lengths = data.get("fbank_lengths") + processed["xvector"] = self.xvector_extractor( + sample, + audio_lengths=sample_lengths, + fbank=fbank, + fbank_lengths=fbank_lengths, + ) + else: + processed["latents"] = None + processed["latent_lengths"] = None + + return processed + + def _build_audio_span_mask(self, token_ids: torch.Tensor) -> torch.Tensor: + span_mask = torch.zeros_like(token_ids, dtype=torch.bool) + for token_id in self.core.audio_span_token_ids: + span_mask = span_mask | (token_ids == token_id) + return span_mask + + def _prepare_loss_metadata(self, data: dict[str, Any]) -> dict[str, Any]: + input_ids: torch.Tensor = data["input_ids"] + labels: torch.Tensor = data["labels"] + loss_mask: torch.Tensor = data["loss_mask"] + input_span_mask = self._build_audio_span_mask(input_ids) + output_span_mask = self._build_audio_span_mask(labels) + output_span_mask_float = output_span_mask.to(loss_mask.dtype) + llm_loss_mask = loss_mask * (1.0 - output_span_mask_float) + fm_loss_mask = loss_mask * output_span_mask_float + patch_counts = output_span_mask.sum(dim=1) + max_patch_count = max(1, int(patch_counts.max().item())) + fm_patch_mask = loss_mask.new_zeros((loss_mask.size(0), max_patch_count)) + for batch_idx in range(loss_mask.size(0)): + count = int(patch_counts[batch_idx].item()) + if count <= 0: + continue + fm_patch_mask[batch_idx, :count] = fm_loss_mask[batch_idx].masked_select( + output_span_mask[batch_idx] + ) + + return { + "input_span_mask": input_span_mask, + "output_span_mask": output_span_mask, + "loss_masks": { + "ce_loss": llm_loss_mask, + "fm_loss": fm_patch_mask, + "eos_loss": self._build_eos_loss_mask(fm_loss_mask), + }, + } + + @staticmethod + def _build_eos_loss_mask(eos_loss_mask: torch.Tensor) -> torch.Tensor: + batch_size, seq_len = eos_loss_mask.shape + mask = eos_loss_mask.to(dtype=torch.bool) + target = torch.zeros((batch_size, seq_len), dtype=torch.bool, device=mask.device) + mask_counts = mask.sum(dim=1, keepdim=True) + cumulative = mask.long().cumsum(dim=1) + target[mask & (cumulative == mask_counts)] = True + + mask_counts_flat = mask_counts.squeeze(1) + neg_counts = (mask_counts_flat - 1).clamp_min(0).to(eos_loss_mask.dtype) + pos_weight = torch.where( + neg_counts > 0, + torch.full_like(neg_counts, 0.5), + torch.ones_like(neg_counts), + ).unsqueeze(1) + neg_weight = torch.where( + neg_counts > 0, + 0.5 / neg_counts, + torch.zeros_like(neg_counts), + ).unsqueeze(1) + + positive_mask = target & mask + negative_mask = mask & ~positive_mask + return torch.where( + positive_mask, + pos_weight, + negative_mask.to(eos_loss_mask.dtype) * neg_weight, + ) + # endregion Training batch preparation + + # region Training loss assembly and forward + @staticmethod + def _compute_ce_loss_term( + llm_logits: torch.Tensor, + llm_labels: torch.Tensor, + llm_loss_mask: torch.Tensor, + ) -> LossTerm: + vocab_size = llm_logits.size(-1) + ce_loss = F.cross_entropy( + llm_logits.view(-1, vocab_size), + llm_labels.view(-1), + reduction="none", + ).view_as(llm_labels) + return LossTerm(loss=ce_loss, mask=llm_loss_mask.to(ce_loss.dtype)) + + @staticmethod + def _compute_fm_loss_term( + pred: torch.Tensor, + target: torch.Tensor, + fm_patch_mask: torch.Tensor, + ) -> LossTerm: + batch_size, max_patch_count = fm_patch_mask.shape + fm_loss = (pred - target).pow(2).mean(dim=2).mean(dim=1) + loss = fm_loss.new_zeros((batch_size, max_patch_count)) + patch_counts = fm_patch_mask.gt(0).sum(dim=1).tolist() + expected_count = int(sum(patch_counts)) + if expected_count > 0 and int(fm_loss.numel()) != expected_count: + raise RuntimeError( + "Flow-matching loss count mismatch: " + f"expected {expected_count}, got {int(fm_loss.numel())}." + ) + + offset = 0 + for batch_idx, patch_count in enumerate(patch_counts): + if patch_count <= 0: + continue + next_offset = offset + int(patch_count) + loss[batch_idx, :patch_count] = fm_loss[offset:next_offset] + offset = next_offset + return LossTerm(loss=loss, mask=fm_patch_mask.to(loss.dtype)) + + @staticmethod + def _compute_eos_loss_term( + eos_out: torch.Tensor, + eos_loss_mask: torch.Tensor, + ) -> LossTerm: + batch_size, seq_len, _ = eos_out.shape + weights = eos_loss_mask.to(device=eos_out.device) + mask = weights.gt(0) + target = torch.zeros( + (batch_size, seq_len), + dtype=torch.long, + device=eos_out.device, + ) + mask_counts = mask.sum(dim=1, keepdim=True) + cumulative = mask.long().cumsum(dim=1) + target[mask & (cumulative == mask_counts)] = 1 + + logits = rearrange(eos_out, "b n c -> b c n") + ce_per_token = F.cross_entropy(logits, target, reduction="none") + return LossTerm(loss=ce_per_token, mask=weights.to(ce_per_token.dtype)) + + @staticmethod + def _compute_eos_loss_stats( + eos_out: torch.Tensor, + eos_loss_mask: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + weights = DotsTtsModel._build_eos_loss_mask(eos_loss_mask) + term = DotsTtsModel._compute_eos_loss_term(eos_out, weights) + mask = term.mask.to(device=term.loss.device, dtype=term.loss.dtype) + eos_loss_sum = (term.loss * mask).sum(dim=1) + eos_sample_count = eos_loss_mask.to(device=term.loss.device).gt(0).any( + dim=1 + ).to(term.loss.dtype) + return eos_loss_sum, eos_sample_count + + def _compute_loss_terms( + self, + outputs: DotsTtsForwardOutput, + *, + labels: torch.Tensor, + loss_masks: LossMasks, + ) -> LossTerms: + return { + "ce_loss": self._compute_ce_loss_term( + outputs.llm_logits, + labels, + loss_masks["ce_loss"], + ), + "fm_loss": self._compute_fm_loss_term( + outputs.pred, + outputs.target, + loss_masks["fm_loss"], + ), + "eos_loss": self._compute_eos_loss_term( + outputs.eos_out, + loss_masks["eos_loss"], + ), + } + + def prepare_training_batch(self, data: dict[str, Any]) -> dict[str, Any]: + prepared = dict(data) + prepared.update(self._prepare_loss_metadata(prepared)) + return prepared + + def forward(self, data: dict[str, Any]) -> LossTerms: + loss_masks: LossMasks = data["loss_masks"] + processed = self.prepare_training_inputs(data) + processed["input_span_mask"] = data["input_span_mask"] + processed["output_span_mask"] = data["output_span_mask"] + return self._compute_loss_terms( + self.core(processed), + labels=processed["labels"], + loss_masks=loss_masks, + ) + # endregion Training loss assembly and forward + + # region Prompt conditioning and decode state helpers + @torch.no_grad() + def _prepare_prompt_conditioning( + self, + prompt_audio: torch.Tensor | None, + *, + use_prompt_prefill: bool, + speaker_scale: float = 1.5, + ) -> _PromptConditioning: + if prompt_audio is None: + logger.info("Prompt conditioning skipped: no prompt audio provided.") + return _PromptConditioning() + + self.vocoder.eval() + self.xvector_extractor.eval() + device = next(self.core.parameters()).device + if prompt_audio.ndim == 1: + prompt_audio = prompt_audio.unsqueeze(0) + prompt_audio = prompt_audio.to(device=device) + + target_len = math.ceil( + prompt_audio.size(1) / (self.config.patch_size * self.hop_size) + ) * (self.config.patch_size * self.hop_size) + pad_len = target_len - prompt_audio.size(1) + if pad_len > 0: + prompt_audio = F.pad(prompt_audio, (0, pad_len)) + + speaker_encoder = self._get_compiled_model( + "speaker_encoder", + self.xvector_extractor, + ) + with measure_inference("speaker_encoder"): + speaker_embedding = ( + speaker_encoder(prompt_audio[None, :]) * float(speaker_scale) + ) + g_cond = self.core.xvec_proj(speaker_embedding) + if not use_prompt_prefill: + logger.info( + "Reference-audio-only conditioning prepared: prompt_samples={} speaker_scale={} device={}", + prompt_audio.shape[-1], + speaker_scale, + device, + ) + return _PromptConditioning(g_cond=g_cond) + + latent_encoder = self._get_compiled_model( + "latent_encoder", + self.vocoder.extract_latents, + ) + with measure_inference("latent_encoder"): + prompt_latents = latent_encoder(prompt_audio[None, :]) + prompt_latents_sampled = self.core.io_helper.sample_from_latent(prompt_latents) + prompt_latents_sampled = prompt_latents_sampled[:, : -self.config.patch_size] + prompt_patches = rearrange( + self.core.io_helper.normalize(prompt_latents_sampled), + "b (s p) d -> b s p d", + p=self.config.patch_size, + ) + logger.info( + "Prompt conditioning prepared: prompt_samples={} prompt_patch_count={} " + "speaker_scale={} device={}", + prompt_audio.shape[-1], + prompt_patches.size(1), + speaker_scale, + device, + ) + return _PromptConditioning( + prompt_patches=prompt_patches, + prompt_latents=prompt_latents_sampled, + g_cond=g_cond, + ) + + @staticmethod + def _patch_encoder_compile_signature( + patch_encoder_state: Any, + ) -> tuple[int, torch.dtype]: + key_cache, _ = patch_encoder_state.layer_caches[0] + return int(key_cache.size(2)), key_cache.dtype + + def _resolve_patch_encoder_audio_bucket(self, required_seq_len: int) -> int: + requested = int(required_seq_len) + if requested <= 0: + raise ValueError("required_seq_len must be positive.") + requested_patch_count = math.ceil( + requested / self.core.patch_encoder.out_ds_rate + ) + if not self._optimize_enabled: + return requested_patch_count + return self._resolve_generate_length_bucket(requested_patch_count).size + + def _copy_patch_encoder_state(self, source: Any, target: Any) -> None: + seq_len = source.seq_len + target_capacity = int(target.layer_caches[0][0].size(2)) + if seq_len > target_capacity: + raise ValueError( + "Patch encoder state copy exceeds target capacity: " + f"seq_len={seq_len} capacity={target_capacity}." + ) + + target.conv_tail.copy_(source.conv_tail) + target.seq_len = seq_len + for (source_key, source_value), (target_key, target_value) in zip( + source.layer_caches, + target.layer_caches, + strict=True, + ): + if seq_len > 0: + target_key[:, :, :seq_len, :].copy_(source_key[:, :, :seq_len, :]) + target_value[:, :, :seq_len, :].copy_(source_value[:, :, :seq_len, :]) + + def _ensure_patch_encoder_state_capacity( + self, + state: _GenerateState, + *, + required_seq_len: int, + device: torch.device, + dtype: torch.dtype, + ) -> None: + current_state = state.patch_encoder_state + if current_state is not None: + current_capacity = int(current_state.layer_caches[0][0].size(2)) + if current_capacity >= required_seq_len: + return + + target_audio_patch_count = self._resolve_patch_encoder_audio_bucket( + required_seq_len + ) + next_state = self.core.patch_encoder.init_decode_state( + max_audio_patch_count=target_audio_patch_count, + batch_size=1, + device=device, + dtype=dtype, + ) + if current_state is not None: + self._copy_patch_encoder_state(current_state, next_state) + state.patch_encoder_state = next_state + + def _prefill_prompt_latents( + self, + prompt_latents: torch.Tensor | None, + *, + state: _GenerateState, + ) -> torch.Tensor | None: + if prompt_latents is None: + return None + if prompt_latents.size(1) == 0: + return prompt_latents.new_zeros( + (prompt_latents.size(0), 0, self.core.llm_hidden_size) + ) + self._ensure_patch_encoder_state_capacity( + state, + required_seq_len=( + (prompt_latents.size(1) // self.core.patch_encoder.patch_size) + * self.core.patch_encoder.out_ds_rate + ), + device=prompt_latents.device, + dtype=( + state.fm_sequence.dtype + if state.fm_sequence is not None + else prompt_latents.dtype + ), + ) + with measure_inference("patch_encoder"): + prompt_patch_embeddings, state.patch_encoder_state = ( + self.core.patch_encoder.prefill( + prompt_latents, + state.patch_encoder_state, + ) + ) + return prompt_patch_embeddings + + def _get_fm_decode_workspace( + self, + *, + total_len: int, + device: torch.device, + dtype: torch.dtype, + ) -> dict[str, torch.Tensor]: + workspace_key = (total_len, str(device), dtype) + workspace = self._fm_decode_workspaces.get(workspace_key) + if workspace is None: + workspace = { + "input_sequence": torch.zeros( + (1, total_len, self.core.fm_hidden_size), + dtype=dtype, + device=device, + ), + "cfg_sequence": torch.zeros( + (1, total_len, self.core.fm_hidden_size), + dtype=dtype, + device=device, + ), + "attn_mask": torch.zeros( + (1, total_len, total_len), + dtype=torch.bool, + device=device, + ), + "pos_ids": torch.zeros( + (1, total_len), + dtype=torch.float32, + device=device, + ), + } + self._fm_decode_workspaces[workspace_key] = workspace + else: + workspace["input_sequence"].zero_() + workspace["cfg_sequence"].zero_() + return workspace + + def _resolve_fm_history_bucket_capacity(self, fm_seq_len: int) -> int: + requested = int(fm_seq_len) + if requested <= 0: + raise ValueError("fm_seq_len must be positive.") + if not self._optimize_enabled: + return requested + history_stride = self.core.hidden_patch_size + self.core.latent_patch_size + requested_patch_count = math.ceil(requested / history_stride) + return self._resolve_generate_length_bucket( + requested_patch_count + ).size * history_stride + + def _build_fm_attn_mask( + self, + *, + state: _GenerateState, + attn_mask: torch.Tensor, + ) -> torch.Tensor: + if state.fm_seq_len <= 0: + raise RuntimeError("FM sequence length must be positive before decode.") + hidden_patch_size = self.core.hidden_patch_size + latent_start = attn_mask.size(-1) - self.core.latent_patch_size + attn_mask.zero_() + block_start = state.fm_seq_len - hidden_patch_size + if block_start > 0: + causal_mask = torch.ones( + (block_start, block_start), + device=attn_mask.device, + dtype=torch.bool, + ).triu(1).logical_not() + attn_mask[:, :block_start, :block_start] = causal_mask + + attn_mask[:, block_start : state.fm_seq_len, : state.fm_seq_len] = True + attn_mask[:, block_start : state.fm_seq_len, latent_start:] = True + attn_mask[:, latent_start:, : state.fm_seq_len] = True + attn_mask[:, latent_start:, latent_start:] = True + if latent_start > state.fm_seq_len: + padding_indices = torch.arange( + state.fm_seq_len, + latent_start, + device=attn_mask.device, + ) + attn_mask[:, padding_indices, padding_indices] = True + return attn_mask + + def _build_fm_pos_ids( + self, + *, + state: _GenerateState, + pos_ids: torch.Tensor, + ) -> torch.Tensor: + if state.fm_seq_len <= 0: + raise RuntimeError("FM sequence length must be positive before decode.") + pos_ids.zero_() + latent_start = pos_ids.size(-1) - self.core.latent_patch_size + pos_ids[:, : state.fm_seq_len] = torch.arange( + state.fm_seq_len, + device=pos_ids.device, + dtype=pos_ids.dtype, + ) + pos_ids[:, latent_start:] = torch.arange( + state.fm_seq_len, + state.fm_seq_len + self.core.latent_patch_size, + device=pos_ids.device, + dtype=pos_ids.dtype, + ) + return pos_ids + + def _prepare_fm_decode_inputs( + self, + state: _GenerateState, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]: + sequence = state.fm_sequence + cfg_sequence = state.fm_cfg_sequence + if sequence is None or cfg_sequence is None: + raise RuntimeError("FM static buffers are not initialized.") + history_bucket_capacity = self._resolve_fm_history_bucket_capacity( + state.fm_seq_len + ) + total_len = history_bucket_capacity + self.core.latent_patch_size + workspace = self._get_fm_decode_workspace( + total_len=total_len, + device=sequence.device, + dtype=sequence.dtype, + ) + workspace["input_sequence"][:, : state.fm_seq_len].copy_( + sequence[:, : state.fm_seq_len] + ) + workspace["cfg_sequence"][:, : state.fm_seq_len].copy_( + cfg_sequence[:, : state.fm_seq_len] + ) + return ( + workspace["input_sequence"], + workspace["cfg_sequence"], + workspace["attn_mask"], + workspace["pos_ids"], + history_bucket_capacity, + ) + + def _append_to_fm_buffer( + self, + buffer: torch.Tensor | None, + state: _GenerateState, + chunk: torch.Tensor, + ) -> tuple[int, int]: + if buffer is None: + raise RuntimeError("FM static buffer is not initialized.") + start = state.fm_seq_len + end = start + chunk.size(1) + if end > state.fm_capacity: + raise RuntimeError( + "FM StaticBuffer capacity exceeded: " + f"next_length={end} capacity={state.fm_capacity}." + ) + buffer[:, start:end].copy_(chunk.to(buffer.dtype)) + return start, end + + def _append_hidden_chunk( + self, state: _GenerateState, hidden_chunk: torch.Tensor + ) -> None: + last_hidden = hidden_chunk[:, -self.core.hidden_patch_size :, :] + projected = self.core.hidden_proj(last_hidden) + null_projected = self.core.hidden_proj(torch.zeros_like(last_hidden)) + _start, end = self._append_to_fm_buffer( + state.fm_sequence, + state, + projected, + ) + cfg_buffer = state.fm_cfg_sequence + if cfg_buffer is None: + raise RuntimeError("FM cfg static buffer is not initialized.") + cfg_buffer[:, state.fm_seq_len : end].copy_(null_projected.to(cfg_buffer.dtype)) + state.fm_seq_len = end + + def _append_history_chunk( + self, state: _GenerateState, latent_chunk: torch.Tensor + ) -> None: + history_latent = self.core.latent_proj(latent_chunk) + _start, end = self._append_to_fm_buffer( + state.fm_sequence, + state, + history_latent, + ) + cfg_buffer = state.fm_cfg_sequence + if cfg_buffer is None: + raise RuntimeError("FM cfg static buffer is not initialized.") + cfg_buffer[:, state.fm_seq_len : end].copy_(history_latent.to(cfg_buffer.dtype)) + state.fm_seq_len = end + + def _consume_text_schedule( + self, + generation_schedule: torch.Tensor, + *, + position: int, + next_audio_position: int, + state: _GenerateState, + ) -> int: + with measure_inference("LLM"): + text_chunk = generation_schedule[:, position:next_audio_position] + _, state.llm_hiddens, _, state.llm_cache = self.core.step_llm( + input_ids=text_chunk, + past_key_values=state.llm_cache, + ) + self._append_hidden_chunk(state, state.llm_hiddens) + return next_audio_position + + def _locate_prefill_boundary( + self, + *, + span_positions: torch.Tensor, + prompt_patch_count: int, + ) -> tuple[int, torch.Tensor]: + if span_positions.numel() > prompt_patch_count: + return int(span_positions[prompt_patch_count].item()), span_positions[ + :prompt_patch_count + ] + raise RuntimeError( + "Prefill boundary discovery failed despite prior schedule validation." + ) + + @staticmethod + def _find_audio_span_positions( + generation_schedule: torch.Tensor, + *, + audio_placeholder_ids: set[int], + ) -> torch.Tensor: + schedule = generation_schedule[0] + placeholder_ids = torch.tensor( + sorted(audio_placeholder_ids), + device=schedule.device, + dtype=schedule.dtype, + ) + return torch.nonzero( + torch.isin(schedule, placeholder_ids), + as_tuple=False, + ).squeeze(-1) + + @staticmethod + def _next_token_is_audio_span( + generation_schedule: torch.Tensor, + *, + position: int, + audio_placeholder_ids: set[int], + ) -> bool: + next_position = position + 1 + if next_position >= generation_schedule.size(1): + return False + return int(generation_schedule[0, next_position].item()) in audio_placeholder_ids + + def _build_prefill_inputs_embeds( + self, + generation_schedule: torch.Tensor, + *, + prompt_patch_embeddings: torch.Tensor | None, + prompt_span_positions: torch.Tensor, + ) -> torch.Tensor: + inputs_embeds = self.core.llm.get_input_embeddings()( + generation_schedule + ).clone() + if prompt_span_positions.numel() > 0: + if prompt_patch_embeddings is None: + raise RuntimeError( + "Prompt patch embeddings are required when prefill includes prompt audio spans." + ) + patch_embeddings = prompt_patch_embeddings[ + :, : prompt_span_positions.numel() + ].to(inputs_embeds.dtype) + if patch_embeddings.size(1) != prompt_span_positions.numel(): + raise RuntimeError( + f"Prompt patch embeddings ({patch_embeddings.size(1)}) do not match prompt span count ({prompt_span_positions.numel()})." + ) + inputs_embeds[:, prompt_span_positions, :] = patch_embeddings + return inputs_embeds + + def _prefill( + self, + generation_schedule: torch.Tensor, + *, + state: _GenerateState, + span_positions: torch.Tensor, + prompt_patches: torch.Tensor | None, + prompt_patch_embeddings: torch.Tensor | None, + audio_placeholder_ids: set[int], + ) -> int: + prompt_patch_count = ( + 0 if prompt_patches is None else int(prompt_patches.size(1)) + ) + prefill_end, prompt_span_positions = self._locate_prefill_boundary( + span_positions=span_positions, + prompt_patch_count=prompt_patch_count, + ) + if prefill_end == 0: + return 0 + inputs_embeds = self._build_prefill_inputs_embeds( + generation_schedule[:, :prefill_end], + prompt_patch_embeddings=prompt_patch_embeddings, + prompt_span_positions=prompt_span_positions, + ) + with measure_inference("LLM"): + _, llm_hiddens, _, state.llm_cache = self.core.step_llm( + inputs_embeds=inputs_embeds, + past_key_values=state.llm_cache, + ) + state.llm_hiddens = llm_hiddens[:, -1:, :] + + cursor = 0 + for prompt_index, span_position in enumerate(prompt_span_positions.tolist()): + if span_position > cursor: + self._append_hidden_chunk( + state, llm_hiddens[:, span_position - 1 : span_position, :] + ) + self._append_history_chunk(state, prompt_patches[:, prompt_index]) + if self._next_token_is_audio_span( + generation_schedule, + position=span_position, + audio_placeholder_ids=audio_placeholder_ids, + ): + self._append_hidden_chunk( + state, llm_hiddens[:, span_position : span_position + 1, :] + ) + cursor = span_position + 1 + if prefill_end > cursor: + self._append_hidden_chunk( + state, llm_hiddens[:, prefill_end - 1 : prefill_end, :] + ) + return prefill_end + + def _decode_next_audio( + self, + state: _GenerateState, + *, + device: torch.device, + g_cond: torch.Tensor | None, + ode_method: str, + num_steps: int, + guidance_scale: float, + ) -> torch.Tensor: + if state.fm_seq_len <= 0: + raise RuntimeError( + "Cannot decode audio before any conditioning state has been prefetched." + ) + if state.fm_sequence is None or state.fm_cfg_sequence is None: + raise RuntimeError("FM static buffers are not initialized.") + if state.fm_null_g_cond is None: + raise RuntimeError("FM null conditioning buffer is not initialized.") + fm_sequence, fm_cfg_sequence, fm_attn_mask, fm_pos_ids, history_bucket_capacity = ( + self._prepare_fm_decode_inputs(state) + ) + compile_signature = ( + (history_bucket_capacity, state.fm_sequence.dtype) + if self._optimize_enabled + else (state.fm_seq_len, state.fm_sequence.dtype) + ) + if g_cond is None: + g_cond = state.fm_null_g_cond + else: + g_cond = g_cond.to( + device=state.fm_null_g_cond.device, + dtype=state.fm_null_g_cond.dtype, + ) + with measure_inference("FM"): + attn_mask = self._build_fm_attn_mask( + state=state, + attn_mask=fm_attn_mask, + ) + pos_ids = self._build_fm_pos_ids( + state=state, + pos_ids=fm_pos_ids, + ) + if self.core.mode == "meanflow": + fm_solver_step = self._get_compiled_method( + "FM.meanflow.solver_step", + self.core, + "meanflow_solver_step", + signature=compile_signature, + ) + return self.core._meanflow_step_fm( + input_sequence=fm_sequence, + attn_mask=attn_mask, + pos_ids=pos_ids, + patch_size=self.core.latent_patch_size, + g_cond=g_cond, + nfe=num_steps, + solver_step=fm_solver_step, + ) + + fm_solver_step = self._get_compiled_method( + "FM.flow_matching.solver_step", + self.core, + "fm_solver_step", + signature=compile_signature, + ) + return self.core._flow_matching_step_fm( + input_sequence=fm_sequence, + cfg_sequence=fm_cfg_sequence, + attn_mask=attn_mask, + pos_ids=pos_ids, + hidden_size=self.core.hidden_patch_size, + patch_size=self.core.latent_patch_size, + g_cond=g_cond, + ode_method=ode_method, + num_steps=num_steps, + guidance_scale=guidance_scale, + solver_step=fm_solver_step, + ) + + def _consume_audio_patch( + self, + state: _GenerateState, + *, + audio_patch: torch.Tensor, + ) -> None: + audio_patch_for_llm = self.core.io_helper.denormalize(audio_patch) + self._append_history_chunk(state, audio_patch) + current_seq_len = ( + 0 + if state.patch_encoder_state is None + else state.patch_encoder_state.seq_len + ) + self._ensure_patch_encoder_state_capacity( + state, + required_seq_len=current_seq_len + self.core.patch_encoder.out_ds_rate, + device=audio_patch_for_llm.device, + dtype=( + state.fm_sequence.dtype + if state.fm_sequence is not None + else audio_patch_for_llm.dtype + ), + ) + patch_encoder_decode = self._get_compiled_method( + "patch_encoder.decode_patch", + self.core.patch_encoder, + "decode_patch", + signature=self._patch_encoder_compile_signature(state.patch_encoder_state), + ) + patch_positions = ( + torch.arange( + self.core.patch_encoder.out_ds_rate, + device=audio_patch_for_llm.device, + dtype=torch.long, + ) + + state.patch_encoder_state.seq_len + ) + with measure_inference("patch_encoder"): + llm_embedding, conv_tail = patch_encoder_decode( + audio_patch_for_llm, + state.patch_encoder_state.conv_tail, + state.patch_encoder_state.layer_caches, + patch_positions, + ) + state.patch_encoder_state.conv_tail.copy_(conv_tail) + state.patch_encoder_state.seq_len += self.core.patch_encoder.out_ds_rate + with measure_inference("LLM"): + _, state.llm_hiddens, _, state.llm_cache = self.core.step_llm( + inputs_embeds=llm_embedding, + past_key_values=state.llm_cache, + ) + + def _decode( + self, + generation_schedule: torch.Tensor, + *, + position: int, + state: _GenerateState, + audio_placeholder_ids: set[int], + span_positions: torch.Tensor, + device: torch.device, + g_cond: torch.Tensor | None, + ode_method: str, + num_steps: int, + guidance_scale: float, + eos_threshold: float, + ) -> Iterator[torch.Tensor]: + span_cursor = torch.searchsorted( + span_positions, + torch.tensor( + position, + device=span_positions.device, + dtype=span_positions.dtype, + ), + ).item() + while position < generation_schedule.size(1): + token_id = int(generation_schedule[0, position].item()) + if token_id in audio_placeholder_ids: + stop_after_current_audio = self._should_stop_after_current_audio( + state, + eos_threshold=eos_threshold, + ) + audio_patch = self._decode_next_audio( + state, + device=device, + g_cond=g_cond, + ode_method=ode_method, + num_steps=num_steps, + guidance_scale=guidance_scale, + ) + self._consume_audio_patch( + state, + audio_patch=audio_patch, + ) + if self._next_token_is_audio_span( + generation_schedule, + position=position, + audio_placeholder_ids=audio_placeholder_ids, + ): + self._append_hidden_chunk(state, state.llm_hiddens) + position += 1 + span_cursor += 1 + yield audio_patch + if stop_after_current_audio: + state.end_flag = True + return + continue + next_audio_position = ( + int(span_positions[span_cursor].item()) + if span_cursor < span_positions.numel() + else generation_schedule.size(1) + ) + position = self._consume_text_schedule( + generation_schedule, + position=position, + next_audio_position=next_audio_position, + state=state, + ) + + def _should_stop_after_current_audio( + self, state: _GenerateState, *, eos_threshold: float + ) -> bool: + if state.llm_hiddens is None: + return False + eos = ( + self.core.eos_proj(state.llm_hiddens).softmax(dim=-1)[:, -1, 1] + > eos_threshold + ) + return state.end_flag or bool(eos.item()) + + # endregion Prompt conditioning and decode state helpers + + # region Public generation APIs + @torch.no_grad() + def _generate_latents_stream( + self, + data: dict[str, Any], + *, + precision: str, + ode_method: str, + num_steps: int, + guidance_scale: float, + speaker_scale: float = 1.5, + eos_threshold: float = 0.8, + ) -> Iterator[torch.Tensor]: + dtype = get_dtype(precision) + device = next(self.core.parameters()).device + use_amp = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16} + with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_amp): + generation_schedule: torch.Tensor = data["generation_schedule"] + if generation_schedule.size(0) != 1: + raise ValueError( + "DotsTtsModel.generate expects batch size 1 for generation_schedule." + ) + + use_prompt_prefill = data.get("prompt_audio") is not None and bool( + data.get("prompt_text") + ) + prompt_conditioning = self._prepare_prompt_conditioning( + data.get("prompt_audio"), + use_prompt_prefill=use_prompt_prefill, + speaker_scale=speaker_scale, + ) + has_prompt_prefill = prompt_conditioning.prompt_patches is not None + prompt_patch_count = ( + 0 + if not has_prompt_prefill + else int(prompt_conditioning.prompt_patches.size(1)) + ) + audio_placeholder_ids = set(self.core.audio_span_token_ids) + span_positions = self._find_audio_span_positions( + generation_schedule, + audio_placeholder_ids=audio_placeholder_ids, + ) + span_count = int(span_positions.numel()) + minimum_required_spans = prompt_patch_count + 1 + if span_count < minimum_required_spans: + raise ValueError( + f"generation_schedule provides {span_count} audio spans, but prompt prefill requires " + f"{prompt_patch_count} spans and generation requires at least one additional decode span." + ) + logger.info( + "Latent generation prepared: schedule_audio_spans={} prompt_patch_count={} " + "minimum_required_spans={}", + span_count, + prompt_patch_count, + minimum_required_spans, + ) + + state = self._allocate_generate_state( + max_audio_patch_count=span_count, + device=device, + dtype=dtype, + ) + prompt_patch_embeddings = self._prefill_prompt_latents( + prompt_conditioning.prompt_latents, + state=state, + ) + position = self._prefill( + generation_schedule, + state=state, + span_positions=span_positions, + prompt_patches=prompt_conditioning.prompt_patches, + prompt_patch_embeddings=prompt_patch_embeddings, + audio_placeholder_ids=audio_placeholder_ids, + ) + + payload_patch_count = 0 + should_drop_regenerated_prompt_patch = has_prompt_prefill + for audio_patch in self._decode( + generation_schedule, + position=position, + state=state, + audio_placeholder_ids=audio_placeholder_ids, + span_positions=span_positions, + device=device, + g_cond=prompt_conditioning.g_cond, + ode_method=ode_method, + num_steps=num_steps, + guidance_scale=guidance_scale, + eos_threshold=eos_threshold, + ): + if should_drop_regenerated_prompt_patch: + should_drop_regenerated_prompt_patch = False + continue + payload_patch_count += 1 + if payload_patch_count == 1 or payload_patch_count % 10 == 0: + logger.info( + "Latent generation progress: payload_audio_patches={}", + payload_patch_count, + ) + yield self.core.io_helper.denormalize(audio_patch) + + if payload_patch_count == 0: + if has_prompt_prefill: + raise RuntimeError( + "Generation produced no payload latents after discarding the regenerated prompt-tail patch. " + "This usually means EOS triggered immediately after prompt continuation " + "or the generation schedule did not provide an effective decode span." + ) + raise RuntimeError( + "Generation produced no decodable latents. " + "This usually means EOS triggered before the first decode patch " + "or the generation schedule did not provide an effective decode span." + ) + logger.info( + "Latent generation completed: payload_audio_patches={}", + payload_patch_count, + ) + + @torch.no_grad() + def _decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + with measure_inference("latent_decoder"): + return self.vocoder.inference_from_latents( + latents.transpose(1, 2).float(), + do_sample=False, + ) + + @torch.no_grad() + def _init_vocoder_stream_state(self) -> Any: + return self.vocoder.init_stream_state( + batch_size=1, + chunk_size=self.core.latent_patch_size, + ) + + @torch.no_grad() + def _stream_vocoder_patch( + self, + latent_patch: torch.Tensor, + *, + stream_state: Any, + ) -> torch.Tensor: + latents = latent_patch.transpose(1, 2) + if not self._optimize_enabled: + with measure_inference("vocoder"): + return self.vocoder.stream_step(latents, stream_state) + + valid_frames = min( + stream_state.decoder.total_frames, + stream_state.decoder.window.size(-1), + ) + valid_frames_tensor = stream_state.decoder.window.new_tensor( + valid_frames, + dtype=torch.int64, + ) + vocoder_step = self._get_compiled_method( + "vocoder.step", + self.vocoder, + "compiled_stream_step", + ) + with measure_inference("vocoder"): + audio_window, hidden_h, hidden_c, new_window = vocoder_step( + latents, + stream_state.lstm_hidden[0], + stream_state.lstm_hidden[1], + stream_state.decoder.window, + valid_frames_tensor, + ) + stream_state.lstm_hidden = (hidden_h.clone(), hidden_c.clone()) + stream_state.decoder.window = new_window.clone() + stream_state.decoder.total_frames += int(latents.size(-1)) + audio_chunk = self.vocoder._slice_stream_audio_window( + audio_window, + stream_state, + final=False, + ) + return audio_chunk.clone() + + @torch.no_grad() + def _flush_vocoder_stream(self, stream_state: Any) -> torch.Tensor: + with measure_inference("vocoder"): + return self.vocoder.stream_flush(stream_state) + + @torch.no_grad() + def generate_audio_stream( + self, + data: dict[str, Any], + *, + precision: str, + ode_method: str, + num_steps: int, + guidance_scale: float, + speaker_scale: float = 1.5, + eos_threshold: float = 0.8, + ) -> Iterator[torch.Tensor]: + stream_state = self._init_vocoder_stream_state() + for latent_patch in self._generate_latents_stream( + data, + precision=precision, + ode_method=ode_method, + num_steps=num_steps, + guidance_scale=guidance_scale, + speaker_scale=speaker_scale, + eos_threshold=eos_threshold, + ): + audio_chunk = self._stream_vocoder_patch( + latent_patch, + stream_state=stream_state, + ) + if audio_chunk.size(-1) > 0: + yield audio_chunk + + final_chunk = self._flush_vocoder_stream(stream_state) + if final_chunk.size(-1) > 0: + yield final_chunk + + @torch.no_grad() + def generate_audio( + self, + data: dict[str, Any], + *, + precision: str, + ode_method: str, + num_steps: int, + guidance_scale: float, + speaker_scale: float = 1.5, + ) -> torch.Tensor: + latent_patches = list( + self._generate_latents_stream( + data, + precision=precision, + ode_method=ode_method, + num_steps=num_steps, + guidance_scale=guidance_scale, + speaker_scale=speaker_scale, + ) + ) + logger.info( + "Vocoder decode started: latent_patch_count={}", + len(latent_patches), + ) + audio = self._decode_latents(torch.cat(latent_patches, dim=1)) + logger.info( + "Vocoder decode completed: waveform_samples={}", + audio.shape[-1], + ) + return audio + # endregion Public generation APIs diff --git a/src/dots_tts/modules/__init__.py b/src/dots_tts/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/dots_tts/modules/backbone/__init__.py b/src/dots_tts/modules/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5c175d372db793678414083ad6efd623645551d7 --- /dev/null +++ b/src/dots_tts/modules/backbone/__init__.py @@ -0,0 +1 @@ +"""Backbone modules.""" diff --git a/src/dots_tts/modules/backbone/dit.py b/src/dots_tts/modules/backbone/dit.py new file mode 100644 index 0000000000000000000000000000000000000000..c9673d48bdaf2df5f2fab6a942451205b9a279b0 --- /dev/null +++ b/src/dots_tts/modules/backbone/dit.py @@ -0,0 +1,205 @@ +import math + +import torch +import torch.nn as nn + +from dots_tts.modules.backbone.layers import Mlp, MultiHeadAttention + + +def modulate(x, shift, scale, **_kwargs): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class TimestepEmbedder(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + return self.mlp(t_freq) + + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, output_size): + super().__init__() + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True), + ) + self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-5) + self.linear = nn.Linear(hidden_size, output_size, bias=True) + + def forward(self, x, c, **_kwargs): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm(x), shift, scale) + return self.linear(x) + + +class DiTBlock(nn.Module): + def __init__( + self, + attention: nn.Module, + ffn: nn.Module, + hidden_size: int = 1024, + modulation: bool = False, + eps: float = 1e-5, + **_kwargs, + ): + super().__init__() + self.norm1 = nn.LayerNorm( + hidden_size, elementwise_affine=not modulation, eps=eps + ) + self.norm2 = nn.LayerNorm( + hidden_size, elementwise_affine=not modulation, eps=eps + ) + self.attn = attention + self.ffn = ffn + self.modulation = modulation + if modulation: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 6 * hidden_size, bias=True), + ) + + def forward(self, x, condition=None, mask=None, **kwargs): + if condition is None: + assert not self.modulation, ( + "Without global condition, must set modulation to False" + ) + else: + assert self.modulation, "With global condition, must set modulation to True" + shift_attn, scale_attn, gate_attn, shift_ffn, scale_ffn, gate_ffn = ( + self.adaLN_modulation(condition).chunk(6, dim=1) + ) + + if condition is not None: + pack_indices = kwargs.get("pack_indices") + if pack_indices is not None: + gate_attn = gate_attn[pack_indices] + gate_ffn = gate_ffn[pack_indices] + else: + gate_attn = gate_attn.unsqueeze(1) + gate_ffn = gate_ffn.unsqueeze(1) + + if condition is not None: + x = x + gate_attn * self.attn( + modulate(self.norm1(x), shift_attn, scale_attn, **kwargs), + mask=mask, + **kwargs, + ) + else: + x = x + self.attn(self.norm1(x), mask=mask, **kwargs) + + if condition is not None: + x = x + gate_ffn * self.ffn( + modulate(self.norm2(x), shift_ffn, scale_ffn, **kwargs) + ) + else: + x = x + self.ffn(self.norm2(x), mask=mask) + return x + + +class DiT(nn.Module): + def __init__( + self, + in_dim, + out_dim, + transformer_config, + *, + mode: str = "flow_matching", + ): + super().__init__() + if mode not in {"flow_matching", "meanflow"}: + raise ValueError( + f"DiT mode must be 'flow_matching' or 'meanflow', got {mode!r}." + ) + + transformer_kwargs = transformer_config.to_dict() + model_dim = transformer_config.hidden_size + self.mode = mode + self.num_layers = transformer_config.num_layers + + self.input_layer = nn.Linear(in_dim, model_dim) + self.time_embedder = TimestepEmbedder(model_dim) + if mode == "meanflow": + self.duration_embedder = TimestepEmbedder(model_dim) + + self.blocks = nn.ModuleList() + for i in range(self.num_layers): + attn_block = MultiHeadAttention(**transformer_kwargs, name=f"layer_{i}") + ffn_block = Mlp( + act_layer=lambda: nn.GELU(approximate="tanh"), **transformer_kwargs + ) + self.blocks.append( + DiTBlock(attention=attn_block, ffn=ffn_block, **transformer_kwargs) + ) + + self.output_layer = FinalLayer(model_dim, out_dim) + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + nn.init.normal_(self.time_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.time_embedder.mlp[2].weight, std=0.02) + + for block in self.blocks: + if hasattr(block, "adaLN_modulation"): + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + nn.init.constant_(self.output_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.output_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.output_layer.linear.weight, 0) + nn.init.constant_(self.output_layer.linear.bias, 0) + + def forward( + self, + x, + timesteps, + duration: torch.Tensor | None = None, + mask=None, + attn_mask=None, + g_cond: torch.Tensor | None = None, + **kwargs, + ): + t = self.time_embedder(timesteps) + c = t + duration_embedder = getattr(self, "duration_embedder", None) + if duration_embedder is not None and duration is not None: + c = c + duration_embedder(duration) + if g_cond is not None: + c = c + g_cond + + x = self.input_layer(x) + for block in self.blocks: + x = block(x, c, mask=attn_mask, **kwargs) + return self.output_layer(x, c, **kwargs) diff --git a/src/dots_tts/modules/backbone/layers.py b/src/dots_tts/modules/backbone/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..8ae1dcedce28456e9916b805e274605075a67b10 --- /dev/null +++ b/src/dots_tts/modules/backbone/layers.py @@ -0,0 +1,333 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +class Dropout(nn.Module): + def __init__( + self, p: float = 0.5, inplace: bool = False, force_drop: bool = False, **_kwargs + ): + super().__init__() + if p < 0.0 or p > 1.0: + raise ValueError( + f"dropout probability has to be between 0 and 1, but got {p}" + ) + self.p = p + self.inplace = inplace + self.force_drop = force_drop + + def forward(self, x, **_kwargs): + return F.dropout( + x, + p=self.p, + training=True if self.force_drop else self.training, + inplace=self.inplace, + ) + + +class Conv1d(nn.Conv1d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 1, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + padding_mode: str = "zeros", + bias: bool = True, + padding=None, + causal: bool = False, + **_kwargs, + ): + self.causal = causal + if padding is None: + if causal: + padding = 0 + self.left_padding = dilation * (kernel_size - 1) + else: + padding = int((kernel_size * dilation - dilation) / 2) + + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + padding_mode=padding_mode, + bias=bias, + ) + + self.in_channels = in_channels + + def forward(self, x): + if self.causal: + x = F.pad(x.unsqueeze(2), (self.left_padding, 0, 0, 0)).squeeze(2) + return super().forward(x) + + +class ConvTranspose1d(nn.ConvTranspose1d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + output_padding: int = 0, + groups: int = 1, + bias: bool = True, + dilation: int = 1, + padding=None, + padding_mode: str = "zeros", + causal: bool = False, + **_kwargs, + ): + if padding is None: + padding = 0 if causal else (kernel_size - stride) // 2 + if causal: + assert padding == 0, "padding is not allowed in causal ConvTranspose1d." + assert kernel_size == 2 * stride, ( + "kernel_size must be equal to 2*stride in Causal ConvTranspose1d." + ) + + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + dilation=dilation, + padding_mode=padding_mode, + ) + + self.causal = causal + self.stride = stride + + def forward(self, x): + x = super().forward(x) + if self.causal: + x = x[:, :, : -self.stride] + return x + + +class Mlp(nn.Module): + def __init__( + self, + hidden_size, + ffn_hidden_size=4096, + act_layer=nn.GELU, + dropout=0.0, + **_kwargs, + ): + super().__init__() + self.fc1 = nn.Linear(hidden_size, ffn_hidden_size) + self.act = act_layer() + self.fc2 = nn.Linear(ffn_hidden_size, hidden_size) + self.drop = Dropout(dropout) + + def forward(self, x, _mask=None): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + return self.drop(x) + + +def rotate_half(x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +@torch.autocast(enabled=False, device_type="cuda") +def apply_rotary_pos_emb(pos, t): + if pos.dim() == 3: + pos = pos.unsqueeze(1) + return t * pos.cos() + rotate_half(t) * pos.sin() + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, theta=50000): + super().__init__() + self.register_buffer( + "inv_freq", + 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)), + persistent=False, + ) + self._theta = float(theta) + + def _apply(self, fn): + inv_freq = self.inv_freq + super()._apply(fn) + self.inv_freq = inv_freq.to(device=self.inv_freq.device, dtype=torch.float32) + return self + + @torch.autocast(enabled=False, device_type="cuda") + def forward(self, t): + inv_freq = self.inv_freq + if inv_freq.device != t.device: + raise RuntimeError( + "RotaryEmbedding buffer device mismatch: " + f"inv_freq={inv_freq.device} input={t.device}." + ) + t = t.to(dtype=inv_freq.dtype) + if t.dim() == 1: + freqs = torch.einsum("i , j -> i j", t, inv_freq) + else: + freqs = torch.einsum("bi, j -> bij", t, inv_freq) + return torch.cat((freqs, freqs), dim=-1) + + +class MultiHeadAttention(nn.Module): + """Multi-head attention""" + + def __init__( + self, + hidden_size: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0.0, + dropout: float = 0.0, + norm_layer: str = "LayerNorm", + rotary_bias: bool = False, + rotary_theta: float | None = 50000, + **_kwargs, + ): + super().__init__() + assert hidden_size % num_heads == 0, ( + "hidden_size should be divisible by num_heads" + ) + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.scale = self.head_dim**-0.5 + self.rotary_bias = rotary_bias + + self.q_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + self.k_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + self.v_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + + norm_layer = getattr(nn, norm_layer) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + + self.attn_drop = Dropout(attn_drop) + self.o_proj = nn.Linear(hidden_size, hidden_size) + self.o_dropout = Dropout(dropout) + + if self.rotary_bias: + self.rotary = RotaryEmbedding(self.head_dim, theta=rotary_theta) + + def forward(self, q, k=None, v=None, mask=None, pos_ids=None, **_kwargs): + k = k or q + v = v or q + B, L, _ = q.shape + _, S, _ = v.shape + if mask is not None: + if mask.ndim == 2: # [B, L] + assert L == S + mask = rearrange(mask, "b j -> b 1 1 j") + mask = mask.expand(-1, self.num_heads, L, -1) + elif mask.ndim == 3: # [B, L, S] + assert mask.size(1) == L and mask.size(2) == S + mask = mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1) + + q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v) + q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads) + k = rearrange(k, "b n (h d) -> b h n d", h=self.num_heads) + v = rearrange(v, "b n (h d) -> b h n d", h=self.num_heads) + q, k = self.q_norm(q), self.k_norm(k) + + # Apply rotary + if self.rotary_bias: + if L == S: + if pos_ids is None: + rotary_emb = self.rotary(torch.arange(L, device=q.device)) + else: + rotary_emb = self.rotary(pos_ids) + q, k = (apply_rotary_pos_emb(rotary_emb, tensor) for tensor in (q, k)) + else: + q_rotary_emb = self.rotary(torch.arange(L, device=q.device)) + k_rotary_emb = self.rotary(torch.arange(S, device=k.device)) + q = apply_rotary_pos_emb(q_rotary_emb, q) + k = apply_rotary_pos_emb(k_rotary_emb, k) + + attn_bias = torch.zeros(B, self.num_heads, L, S, dtype=q.dtype, device=q.device) + + if mask is not None: + attn_bias.masked_fill_(mask.logical_not(), float("-inf")) + + out = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_bias, + dropout_p=self.attn_drop.p if self.training else 0.0, + ) + + out = rearrange(out, "b h n d -> b n (h d)") + return self.o_dropout(self.o_proj(out)) + + def decode_step(self, x, *, cache, positions: torch.Tensor): + if x.size(1) <= 0: + raise ValueError("MultiHeadAttention.decode_step expects a non-empty input.") + if positions.ndim != 1 or positions.size(0) != x.size(1): + raise ValueError( + "MultiHeadAttention.decode_step positions must match the decode block length." + ) + + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + + q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads) + k = rearrange(k, "b n (h d) -> b h n d", h=self.num_heads) + v = rearrange(v, "b n (h d) -> b h n d", h=self.num_heads) + q, k = self.q_norm(q), self.k_norm(k) + block_len = q.size(2) + + if self.rotary_bias: + rotary_emb = self.rotary(positions) + q = apply_rotary_pos_emb(rotary_emb, q) + k = apply_rotary_pos_emb(rotary_emb, k) + + cached_k, cached_v = cache + cached_k.index_copy_(2, positions, k) + cached_v.index_copy_(2, positions, v) + + cache_capacity = cached_k.size(2) + key_positions = torch.arange( + cache_capacity, + device=x.device, + dtype=torch.long, + ).unsqueeze(0) + query_positions = positions.unsqueeze(1) + causal_mask = key_positions <= query_positions + valid_mask = key_positions <= positions[-1] + attn_bias = torch.zeros( + q.size(0), + self.num_heads, + block_len, + cache_capacity, + dtype=q.dtype, + device=q.device, + ) + attn_bias.masked_fill_( + (causal_mask & valid_mask).unsqueeze(0).unsqueeze(0).logical_not(), + float("-inf"), + ) + + out = F.scaled_dot_product_attention( + q, + cached_k, + cached_v, + attn_mask=attn_bias, + dropout_p=self.attn_drop.p if self.training else 0.0, + ) + out = rearrange(out, "b h n d -> b n (h d)") + return self.o_dropout(self.o_proj(out)), cache diff --git a/src/dots_tts/modules/backbone/semantic_encoder.py b/src/dots_tts/modules/backbone/semantic_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..df9e7f226f13edf5391c8742d8aca3d3cd89e673 --- /dev/null +++ b/src/dots_tts/modules/backbone/semantic_encoder.py @@ -0,0 +1,356 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from dots_tts.modules.backbone.layers import Conv1d, Mlp, MultiHeadAttention + + +@dataclass +class SemanticEncoderDecodeState: + conv_tail: torch.Tensor + layer_caches: tuple[tuple[torch.Tensor, torch.Tensor], ...] + seq_len: int + + +class TransformerEncoderLayer(nn.Module): + def __init__( + self, + hidden_size, + num_heads=16, + ffn_hidden_size=4096, + attn_dropout=0.0, + ffn_dropout=0.0, + norm_layer="LayerNorm", + **kwargs, + ): + super().__init__() + self.attn = MultiHeadAttention( + hidden_size, + num_heads, + attn_drop=attn_dropout, + norm_layer=norm_layer, + **kwargs, + ) + norm_cls = getattr(nn, norm_layer) + self.attn_norm = norm_cls(hidden_size) + self.ffn = Mlp( + hidden_size, ffn_hidden_size, dropout=ffn_dropout, act_layer=nn.SiLU + ) + self.ffn_norm = norm_cls(hidden_size) + self.hidden_size = hidden_size + + def _build_causal_mask(self, T: int, device): + return torch.tril(torch.ones(T, T, dtype=torch.bool, device=device)) + + def _build_padding_mask(self, x_lens, max_len: int, device): + B = x_lens.size(0) + positions = torch.arange(max_len, device=device).unsqueeze(0).expand(B, -1) + return positions < x_lens.unsqueeze(1) + + def _fuse_attn_mask(self, causal_mask, padding_mask): + if causal_mask is None and padding_mask is None: + return None + if causal_mask is None: + row = padding_mask.unsqueeze(2) + col = padding_mask.unsqueeze(1) + return row & col + if padding_mask is None: + return causal_mask.unsqueeze(0) + + _B, _T = padding_mask.shape + causal = causal_mask.unsqueeze(0) + row = padding_mask.unsqueeze(2) + col = padding_mask.unsqueeze(1) + pad_2d = row & col + return causal & pad_2d + + def forward( + self, + x, + x_lens=None, + causal=True, + ): + _B, T, C = x.shape + assert self.hidden_size == C + device = x.device + + causal_mask = self._build_causal_mask(T, device) if causal else None + if x_lens is not None: + padding_mask = self._build_padding_mask(x_lens, T, device) + else: + padding_mask = None + fused_mask = self._fuse_attn_mask(causal_mask, padding_mask) + + h = self.attn_norm(x) + h = self.attn( + q=h, + mask=fused_mask, + ) + x = x + h + + h = self.ffn_norm(x) + h = self.ffn(h) + return x + h + + def decode_step( + self, + x, + *, + cache: tuple[torch.Tensor, torch.Tensor], + positions: torch.Tensor, + ): + if x.size(1) <= 0: + raise ValueError( + "TransformerEncoderLayer.decode_step expects a non-empty input." + ) + + h = self.attn_norm(x) + h, cache = self.attn.decode_step(h, cache=cache, positions=positions) + x = x + h + + h = self.ffn_norm(x) + h = self.ffn(h) + return x + h, cache + + +class SuperviseEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.get("hidden_size", 1024) + self.layers = nn.ModuleList( + [ + TransformerEncoderLayer( + hidden_size=self.hidden_size, + num_heads=config.get("num_heads", 16), + ffn_hidden_size=config.get("ffn_hidden_size", 4096), + norm_layer=config.get("norm_layer", "LayerNorm"), + ) + for _ in range(config.get("num_layers", 6)) + ] + ) + self.causal = config.get("causal", False) + + def forward(self, x, x_lens=None): + batch_size, seq_len, _ = x.shape + if x_lens is None: + x_lens = torch.full( + (batch_size,), seq_len, device=x.device, dtype=torch.long + ) + for layer in self.layers: + x = layer(x, x_lens=x_lens, causal=self.causal) + return x + + def init_decode_state( + self, + *, + batch_size: int, + max_seq_len: int, + device: torch.device, + dtype: torch.dtype, + ): + layer_caches = [] + for layer in self.layers: + cache_shape = ( + batch_size, + layer.attn.num_heads, + max_seq_len, + layer.attn.head_dim, + ) + layer_caches.append( + ( + torch.zeros(cache_shape, dtype=dtype, device=device), + torch.zeros(cache_shape, dtype=dtype, device=device), + ) + ) + return tuple(layer_caches) + + def reset_decode_state( + self, + layer_caches: tuple[tuple[torch.Tensor, torch.Tensor], ...], + ) -> None: + if len(layer_caches) != len(self.layers): + raise ValueError("Layer cache count does not match encoder depth.") + for key_cache, value_cache in layer_caches: + key_cache.zero_() + value_cache.zero_() + + def decode_step(self, x, *, layer_caches, positions: torch.Tensor): + if len(layer_caches) != len(self.layers): + raise ValueError("Layer cache count does not match encoder depth.") + + for layer, cache in zip(self.layers, layer_caches, strict=True): + x, _ = layer.decode_step(x, cache=cache, positions=positions) + return x + + +class VAESemanticEncoder(nn.Module): + def __init__(self, in_dim, out_dim, config): + super().__init__() + in_ds_rate = 2 + self.patch_size = int(config.patch_size) + self.in_ds_rate = in_ds_rate + self.ds_proj = Conv1d( + in_dim, in_dim, kernel_size=in_ds_rate, stride=in_ds_rate, causal=True + ) + self.in_proj = nn.Linear(in_dim, config.PatchEncoder.hidden_size) + self.encoder = SuperviseEncoder(config.PatchEncoder) + self.out_ds_rate = self.patch_size // in_ds_rate + self.out_proj = nn.Linear( + config.PatchEncoder.hidden_size * self.out_ds_rate, out_dim + ) + + def forward(self, x, x_lens=None): + x = self._downsample(x) + x = self.in_proj(x) + z = self.encoder(x, x_lens=x_lens) + return self._project_embeddings(z) + + def init_decode_state( + self, + *, + max_audio_patch_count: int, + batch_size: int, + device: torch.device, + dtype: torch.dtype, + ) -> SemanticEncoderDecodeState: + return SemanticEncoderDecodeState( + conv_tail=torch.zeros( + (batch_size, self.ds_proj.in_channels, self.ds_proj.left_padding), + dtype=dtype, + device=device, + ), + layer_caches=self.encoder.init_decode_state( + batch_size=batch_size, + max_seq_len=max_audio_patch_count * self.out_ds_rate, + device=device, + dtype=dtype, + ), + seq_len=0, + ) + + def reset_decode_state(self, state: SemanticEncoderDecodeState) -> None: + state.conv_tail.zero_() + self.encoder.reset_decode_state(state.layer_caches) + state.seq_len = 0 + + def prefill( + self, + x, + state: SemanticEncoderDecodeState, + ) -> tuple[torch.Tensor, SemanticEncoderDecodeState]: + if x.ndim != 3: + raise ValueError( + f"VAESemanticEncoder.prefill expects rank-3 input, got {tuple(x.shape)}." + ) + if x.size(1) % self.patch_size != 0: + raise ValueError( + f"Prompt latent length {x.size(1)} must be divisible by patch_size={self.patch_size}." + ) + + if x.size(1) == 0: + return ( + x.new_zeros((x.size(0), 0, self.out_proj.out_features)), + state, + ) + if state.conv_tail.size(0) != x.size(0): + raise ValueError( + "VAESemanticEncoder.prefill batch size does not match decode state." + ) + + step_inputs = self.in_proj(self._downsample(x)) + expected_token_count = (x.size(1) // self.patch_size) * self.out_ds_rate + if step_inputs.size(1) != expected_token_count: + raise RuntimeError( + "Patch encoder prefill produced an unexpected token count: " + f"expected={expected_token_count} actual={step_inputs.size(1)}." + ) + + current_seq_len = state.seq_len + next_seq_len = current_seq_len + step_inputs.size(1) + cache_capacity = state.layer_caches[0][0].size(2) + if next_seq_len > cache_capacity: + raise ValueError( + "Patch encoder prefill exceeds decode-state capacity: " + f"required={next_seq_len} capacity={cache_capacity}." + ) + + positions = ( + torch.arange(step_inputs.size(1), device=x.device, dtype=torch.long) + + current_seq_len + ) + encoded = self.encoder.decode_step( + step_inputs, + layer_caches=state.layer_caches, + positions=positions, + ) + embedding = self._project_embeddings(encoded) + raw = x.transpose(1, 2) + state.conv_tail.copy_(raw[..., -self.ds_proj.left_padding :]) + state.seq_len = next_seq_len + return embedding, state + + def decode_patch( + self, + latent_patch, + conv_tail: torch.Tensor, + layer_caches: tuple[tuple[torch.Tensor, torch.Tensor], ...], + positions: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + if latent_patch.ndim != 3: + raise ValueError( + f"VAESemanticEncoder.decode_patch expects rank-3 input, got {tuple(latent_patch.shape)}." + ) + if latent_patch.size(1) != self.patch_size: + raise ValueError( + f"decode_patch expects patch length {self.patch_size}, got {latent_patch.size(1)}." + ) + if positions.ndim != 1 or positions.size(0) != self.out_ds_rate: + raise ValueError( + "decode_patch positions must be a rank-1 tensor matching out_ds_rate." + ) + + step_inputs, conv_tail = self._downsample_step( + latent_patch, + conv_tail=conv_tail, + ) + if step_inputs.size(1) != self.out_ds_rate: + raise RuntimeError( + f"Downsample step produced {step_inputs.size(1)} tokens, expected {self.out_ds_rate}." + ) + + encoded = self.encoder.decode_step( + step_inputs, + layer_caches=layer_caches, + positions=positions, + ) + embedding = self._project_embeddings(encoded) + return embedding, conv_tail + + def _downsample(self, x): + return self.ds_proj(x.transpose(1, 2)).transpose(1, 2) + + def _project_embeddings(self, z): + if self.out_ds_rate > 1: + z = rearrange(z, "b (s d) h -> b s (d h)", d=self.out_ds_rate) + return self.out_proj(z) + + def _downsample_step(self, latent_patch, *, conv_tail): + raw = latent_patch.transpose(1, 2) + conv_input = torch.cat([conv_tail, raw], dim=-1) + + projected = F.conv1d( + conv_input, + self.ds_proj.weight, + self.ds_proj.bias, + stride=self.ds_proj.stride[0], + padding=0, + dilation=self.ds_proj.dilation[0], + groups=self.ds_proj.groups, + ).transpose(1, 2) + new_conv_tail = raw[..., -self.ds_proj.left_padding :] + return self.in_proj(projected), new_conv_tail diff --git a/src/dots_tts/modules/speaker/__init__.py b/src/dots_tts/modules/speaker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c5ad248684adb71ec62e118df7d5f6ff367e9db5 --- /dev/null +++ b/src/dots_tts/modules/speaker/__init__.py @@ -0,0 +1 @@ +"""Speaker modules.""" diff --git a/src/dots_tts/modules/speaker/campplus.py b/src/dots_tts/modules/speaker/campplus.py new file mode 100644 index 0000000000000000000000000000000000000000..fc32030c09cf1c9115132483cd075024503c0a7d --- /dev/null +++ b/src/dots_tts/modules/speaker/campplus.py @@ -0,0 +1,200 @@ +# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +from collections import OrderedDict + +import torch +import torch.nn.functional as F +from torch import nn + +from dots_tts.modules.speaker.campplus_layers import ( + BasicResBlock, + CAMDenseTDNNBlock, + DenseLayer, + StatsPool, + TDNNLayer, + TransitLayer, + get_nonlinear, +) +from dots_tts.modules.speaker.fbank import _SPEAKER_FBANK_N_MELS + + +class FCM(nn.Module): + def __init__( + self, + block=BasicResBlock, + num_blocks=(2, 2), + m_channels=32, + feat_dim=_SPEAKER_FBANK_N_MELS, + ): + super().__init__() + self.in_planes = m_channels + self.conv1 = nn.Conv2d( + 1, m_channels, kernel_size=3, stride=1, padding=1, bias=False + ) + self.bn1 = nn.BatchNorm2d(m_channels) + + self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2) + self.layer2 = self._make_layer(block, m_channels, num_blocks[1], stride=2) + + self.conv2 = nn.Conv2d( + m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(m_channels) + self.out_channels = m_channels * (feat_dim // 8) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + x = x.unsqueeze(1) + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = F.relu(self.bn2(self.conv2(out))) + + shape = out.shape + return out.reshape(shape[0], shape[1] * shape[2], shape[3]) + + +class CAMPPlus(nn.Module): + _TDNN_KERNEL_SIZE = 5 + _TDNN_STRIDE = 2 + _TDNN_PADDING = 2 + + def __init__( + self, + feat_dim=_SPEAKER_FBANK_N_MELS, + embedding_size=512, + growth_rate=32, + bn_size=4, + init_channels=128, + config_str="batchnorm-relu", + memory_efficient=True, + ): + super().__init__() + + self.head = FCM(feat_dim=feat_dim) + channels = self.head.out_channels + + self.xvector = nn.Sequential( + OrderedDict( + [ + ( + "tdnn", + TDNNLayer( + channels, + init_channels, + self._TDNN_KERNEL_SIZE, + stride=self._TDNN_STRIDE, + dilation=1, + padding=-1, + config_str=config_str, + ), + ), + ] + ) + ) + channels = init_channels + for i, (num_layers, kernel_size, dilation) in enumerate( + zip((12, 24, 16), (3, 3, 3), (1, 2, 2), strict=True) + ): + block = CAMDenseTDNNBlock( + num_layers=num_layers, + in_channels=channels, + out_channels=growth_rate, + bn_channels=bn_size * growth_rate, + kernel_size=kernel_size, + dilation=dilation, + config_str=config_str, + memory_efficient=memory_efficient, + ) + self.xvector.add_module(f"block{i + 1}", block) + channels = channels + num_layers * growth_rate + self.xvector.add_module( + f"transit{i + 1}", + TransitLayer( + channels, channels // 2, bias=False, config_str=config_str + ), + ) + channels //= 2 + + self.xvector.add_module("out_nonlinear", get_nonlinear(config_str, channels)) + + self.xvector.add_module("stats", StatsPool()) + self.xvector.add_module( + "dense", DenseLayer(channels * 2, embedding_size, config_str="batchnorm_") + ) + + for m in self.modules(): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.kaiming_normal_(m.weight.data) + if m.bias is not None: + nn.init.zeros_(m.bias) + + @staticmethod + def _conv_output_lengths(lengths, kernel_size, stride=1, padding=0, dilation=1): + return ( + torch.div( + lengths + 2 * padding - dilation * (kernel_size - 1) - 1, + stride, + rounding_mode="floor", + ) + + 1 + ) + + @staticmethod + def _make_length_mask(lengths, max_len, device): + lengths = lengths.to(device=device, dtype=torch.long).clamp(min=0, max=max_len) + return torch.arange(max_len, device=device).unsqueeze(0) < lengths.unsqueeze(1) + + def _masked_stats_pooling(self, x, lengths, unbiased=True, eps=1e-2): + lengths = lengths.to(device=x.device, dtype=torch.long).clamp( + min=1, max=x.size(-1) + ) + mask = self._make_length_mask(lengths, x.size(-1), x.device).unsqueeze(1) + mask = mask.to(dtype=x.dtype) + + denom = lengths.to(dtype=x.dtype).view(-1, 1).clamp_min(1.0) + mean = (x * mask).sum(dim=-1) / denom + + centered = (x - mean.unsqueeze(-1)) * mask + var_denom = ( + (lengths - 1).clamp_min(1).to(dtype=x.dtype).view(-1, 1) + if unbiased + else denom + ) + var = centered.pow(2).sum(dim=-1) / var_denom + std = torch.sqrt(var.clamp_min(eps)) + return torch.cat([mean, std], dim=1) + + def forward(self, x, lengths=None): + x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) + x = self.head(x) + if lengths is not None: + lengths = lengths.to(device=x.device, dtype=torch.long).clamp(min=1) + + for name, module in self.xvector.named_children(): + if name == "stats": + x = ( + self._masked_stats_pooling(x, lengths) + if lengths is not None + else module(x) + ) + continue + + x = module(x) + if name == "tdnn" and lengths is not None: + lengths = self._conv_output_lengths( + lengths, + kernel_size=self._TDNN_KERNEL_SIZE, + stride=self._TDNN_STRIDE, + padding=self._TDNN_PADDING, + ) + + return x diff --git a/src/dots_tts/modules/speaker/campplus_layers.py b/src/dots_tts/modules/speaker/campplus_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..9189009bc8e8e68a6cf61a5e11a4977cdaf15e9f --- /dev/null +++ b/src/dots_tts/modules/speaker/campplus_layers.py @@ -0,0 +1,258 @@ +# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from torch import nn + + +def get_nonlinear(config_str, channels): + nonlinear = nn.Sequential() + for name in config_str.split("-"): + if name == "relu": + nonlinear.add_module("relu", nn.ReLU(inplace=True)) + elif name == "prelu": + nonlinear.add_module("prelu", nn.PReLU(channels)) + elif name == "batchnorm": + nonlinear.add_module("batchnorm", nn.BatchNorm1d(channels)) + elif name == "batchnorm_": + nonlinear.add_module("batchnorm", nn.BatchNorm1d(channels, affine=False)) + else: + raise ValueError(f"Unexpected module ({name}).") + return nonlinear + + +def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, _eps=1e-2): + mean = x.mean(dim=dim) + std = x.std(dim=dim, unbiased=unbiased) + stats = torch.cat([mean, std], dim=-1) + if keepdim: + stats = stats.unsqueeze(dim=dim) + return stats + + +class StatsPool(nn.Module): + def forward(self, x): + return statistics_pooling(x) + + +class TDNNLayer(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + bias=False, + config_str="batchnorm-relu", + ): + super().__init__() + if padding < 0: + assert kernel_size % 2 == 1, ( + f"Expect equal paddings, but got even kernel size ({kernel_size})" + ) + padding = (kernel_size - 1) // 2 * dilation + self.linear = nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + ) + self.nonlinear = get_nonlinear(config_str, out_channels) + + def forward(self, x): + x = self.linear(x) + return self.nonlinear(x) + + +class CAMLayer(nn.Module): + def __init__( + self, + bn_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + bias, + reduction=2, + ): + super().__init__() + self.linear_local = nn.Conv1d( + bn_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + ) + self.linear1 = nn.Conv1d(bn_channels, bn_channels // reduction, 1) + self.relu = nn.ReLU(inplace=True) + self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + y = self.linear_local(x) + context = x.mean(-1, keepdim=True) + self.seg_pooling(x) + context = self.relu(self.linear1(context)) + m = self.sigmoid(self.linear2(context)) + return y * m + + def seg_pooling(self, x, seg_len=100, stype="avg"): + if stype == "avg": + seg = F.avg_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True) + elif stype == "max": + seg = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True) + else: + raise ValueError("Wrong segment pooling type.") + shape = seg.shape + seg = seg.unsqueeze(-1).expand(*shape, seg_len).reshape(*shape[:-1], -1) + return seg[..., : x.shape[-1]] + + +class CAMDenseTDNNLayer(nn.Module): + def __init__( + self, + in_channels, + out_channels, + bn_channels, + kernel_size, + stride=1, + dilation=1, + bias=False, + config_str="batchnorm-relu", + memory_efficient=False, + ): + super().__init__() + assert kernel_size % 2 == 1, ( + f"Expect equal paddings, but got even kernel size ({kernel_size})" + ) + padding = (kernel_size - 1) // 2 * dilation + self.memory_efficient = memory_efficient + self.nonlinear1 = get_nonlinear(config_str, in_channels) + self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False) + self.nonlinear2 = get_nonlinear(config_str, bn_channels) + self.cam_layer = CAMLayer( + bn_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + ) + + def bn_function(self, x): + return self.linear1(self.nonlinear1(x)) + + def forward(self, x): + if self.training and self.memory_efficient: + x = cp.checkpoint(self.bn_function, x) + else: + x = self.bn_function(x) + return self.cam_layer(self.nonlinear2(x)) + + +class CAMDenseTDNNBlock(nn.ModuleList): + def __init__( + self, + num_layers, + in_channels, + out_channels, + bn_channels, + kernel_size, + stride=1, + dilation=1, + bias=False, + config_str="batchnorm-relu", + memory_efficient=False, + ): + super().__init__() + for i in range(num_layers): + layer = CAMDenseTDNNLayer( + in_channels=in_channels + i * out_channels, + out_channels=out_channels, + bn_channels=bn_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + bias=bias, + config_str=config_str, + memory_efficient=memory_efficient, + ) + self.add_module(f"tdnnd{i + 1}", layer) + + def forward(self, x): + for layer in self: + x = torch.cat([x, layer(x)], dim=1) + return x + + +class TransitLayer(nn.Module): + def __init__( + self, in_channels, out_channels, bias=True, config_str="batchnorm-relu" + ): + super().__init__() + self.nonlinear = get_nonlinear(config_str, in_channels) + self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias) + + def forward(self, x): + x = self.nonlinear(x) + return self.linear(x) + + +class DenseLayer(nn.Module): + def __init__( + self, in_channels, out_channels, bias=False, config_str="batchnorm-relu" + ): + super().__init__() + self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias) + self.nonlinear = get_nonlinear(config_str, out_channels) + + def forward(self, x): + if len(x.shape) == 2: + x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1) + else: + x = self.linear(x) + return self.nonlinear(x) + + +class BasicResBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super().__init__() + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, stride=(stride, 1), padding=1, bias=False + ) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=1, padding=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d( + in_planes, + self.expansion * planes, + kernel_size=1, + stride=(stride, 1), + bias=False, + ), + nn.BatchNorm2d(self.expansion * planes), + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + return F.relu(out) diff --git a/src/dots_tts/modules/speaker/encoder.py b/src/dots_tts/modules/speaker/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b1f2e0e55bd5900705255ce782770279500cc21d --- /dev/null +++ b/src/dots_tts/modules/speaker/encoder.py @@ -0,0 +1,226 @@ +import math +import random + +import torch +import torch.nn as nn +import torchaudio +from torch.nn.utils.rnn import pad_sequence + +from dots_tts.modules.speaker.campplus import CAMPPlus +from dots_tts.modules.speaker.fbank import ( + _SPEAKER_FBANK_N_MELS, + _SPEAKER_FBANK_SAMPLE_RATE, + extract_speaker_fbank, +) + + +class SpeakerXVectorFeatures(nn.Module): + """ + Speaker embedding extractor based on 3D-Speaker CAM++. + """ + + def __init__( + self, + sample_rate=_SPEAKER_FBANK_SAMPLE_RATE, + campplus_embedding_size=512, + max_audio_seconds=10.0, + ): + super().__init__() + + self.sample_rate = sample_rate + self.max_audio_seconds = float(max_audio_seconds) + self.model = CAMPPlus( + feat_dim=_SPEAKER_FBANK_N_MELS, + embedding_size=campplus_embedding_size, + ) + self.resample = None + if self.sample_rate != _SPEAKER_FBANK_SAMPLE_RATE: + self.resample = torchaudio.transforms.Resample( + orig_freq=sample_rate, + new_freq=_SPEAKER_FBANK_SAMPLE_RATE, + ) + + for param in self.model.parameters(): + param.requires_grad = False + + @staticmethod + def _normalize_lengths(lengths, batch_size, max_length, device, *, min_length): + if lengths is None: + return torch.full( + (batch_size,), + max_length, + device=device, + dtype=torch.long, + ) + return lengths.to(device=device, dtype=torch.long).clamp( + min=min_length, + max=max_length, + ) + + def _crop_audio(self, audio, audio_lengths=None): + original_lengths = self._normalize_lengths( + audio_lengths, + audio.size(0), + audio.size(-1), + audio.device, + min_length=0, + ) + if self.max_audio_seconds <= 0: + return audio, original_lengths, original_lengths, torch.zeros_like( + original_lengths + ) + + max_input_length = round(self.sample_rate * self.max_audio_seconds) + cropped_audio = [] + cropped_lengths = [] + starts = [] + + for index, total_length_tensor in enumerate(original_lengths): + total_length = int(total_length_tensor.item()) + cropped_length = min(total_length, max_input_length) + start = ( + random.randint(0, total_length - cropped_length) + if total_length > cropped_length + else 0 + ) + cropped_audio.append(audio[index, start : start + cropped_length]) + cropped_lengths.append(cropped_length) + starts.append(start) + + return pad_sequence( + cropped_audio, + batch_first=True, + padding_value=0.0, + ), original_lengths, torch.tensor( + cropped_lengths, + device=audio.device, + dtype=torch.long, + ), torch.tensor(starts, device=audio.device, dtype=torch.long) + + def _crop_fbank( + self, + fbank, + fbank_lengths, + original_audio_lengths, + cropped_audio_lengths, + starts, + ): + original_fbank_lengths = self._normalize_lengths( + fbank_lengths, + fbank.size(0), + fbank.size(1), + fbank.device, + min_length=1, + ) + cropped_fbank = [] + cropped_fbank_lengths = [] + + for index, total_feat_length_tensor in enumerate(original_fbank_lengths): + total_audio_length = int(original_audio_lengths[index].item()) + total_feat_length = int(total_feat_length_tensor.item()) + start_audio = int(starts[index].item()) + end_audio = start_audio + int(cropped_audio_lengths[index].item()) + + if total_audio_length > 0: + start_feat = math.floor( + start_audio * total_feat_length / total_audio_length + ) + end_feat = math.ceil(end_audio * total_feat_length / total_audio_length) + else: + start_feat = 0 + end_feat = 1 + + start_feat = min(start_feat, total_feat_length - 1) + end_feat = min(max(end_feat, start_feat + 1), total_feat_length) + cropped_fbank.append(fbank[index, start_feat:end_feat]) + cropped_fbank_lengths.append(end_feat - start_feat) + + return pad_sequence( + cropped_fbank, + batch_first=True, + padding_value=0.0, + ), torch.tensor( + cropped_fbank_lengths, + device=fbank.device, + dtype=torch.long, + ) + + def _extract_fbank_batch(self, audio, audio_lengths): + if self.resample is not None: + audio = self.resample(audio) + audio_lengths = torch.ceil( + audio_lengths.float() + * (_SPEAKER_FBANK_SAMPLE_RATE / self.sample_rate) + ).long() + + audio_cpu = audio.detach().cpu() + features = [] + + for index, valid_length_tensor in enumerate(audio_lengths): + valid_length = int(valid_length_tensor.item()) + waveform = audio_cpu[index, :valid_length] + if waveform.numel() == 0: + waveform = audio_cpu.new_zeros(1) + features.append( + extract_speaker_fbank( + waveform, + sample_rate=_SPEAKER_FBANK_SAMPLE_RATE, + ) + ) + + fbank_lengths = torch.tensor( + [feature.size(0) for feature in features], + device=audio.device, + dtype=torch.long, + ) + fbank = pad_sequence( + features, + batch_first=True, + padding_value=0.0, + ).to(device=audio.device, dtype=audio.dtype) + return fbank, fbank_lengths + + @torch.no_grad() + @torch.autocast(enabled=False, device_type="cuda") + def forward( + self, audio, audio_lengths=None, fbank=None, fbank_lengths=None, **_kwargs + ): + self.model.eval() + audio = audio.float() + if audio.dim() == 3: + if audio.size(1) != 1: + raise ValueError( + f"Speaker encoder expects mono audio, got shape {tuple(audio.shape)}." + ) + audio = audio[:, 0] + elif audio.dim() != 2: + raise ValueError( + f"Speaker encoder expects a 2D or 3D audio tensor, got shape {tuple(audio.shape)}." + ) + + audio, original_audio_lengths, cropped_audio_lengths, starts = self._crop_audio( + audio, + audio_lengths=audio_lengths, + ) + + if fbank is None: + fbank, fbank_lengths = self._extract_fbank_batch( + audio, + cropped_audio_lengths, + ) + else: + if not isinstance(fbank, torch.Tensor): + raise TypeError("Speaker encoder expects `fbank` to be a torch.Tensor.") + if fbank.dim() != 3 or fbank.size(0) != audio.size(0): + raise ValueError( + f"Speaker encoder expects `fbank` with shape (B, T, F) and matching batch size, got {tuple(fbank.shape)}." + ) + fbank, fbank_lengths = self._crop_fbank( + fbank.to(device=audio.device, dtype=torch.float32), + fbank_lengths, + original_audio_lengths, + cropped_audio_lengths, + starts, + ) + + return self.model(fbank, lengths=fbank_lengths) diff --git a/src/dots_tts/modules/speaker/fbank.py b/src/dots_tts/modules/speaker/fbank.py new file mode 100644 index 0000000000000000000000000000000000000000..2d416e9d6b34c6d471c73d712709652fe0a9a038 --- /dev/null +++ b/src/dots_tts/modules/speaker/fbank.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import torch + +from dots_tts.utils.audio import extract_fbank, high_quality_resample + +_SPEAKER_FBANK_SAMPLE_RATE = 16000 +_SPEAKER_FBANK_N_MELS = 80 +_SPEAKER_FBANK_MEAN_NORM = True +_SPEAKER_FBANK_DITHER = 0.0 + + +def extract_speaker_fbank( + waveform: torch.Tensor, + *, + sample_rate: int, +) -> torch.Tensor: + feature_input = waveform + if sample_rate != _SPEAKER_FBANK_SAMPLE_RATE: + feature_input = high_quality_resample( + waveform, + orig_sr=sample_rate, + target_sr=_SPEAKER_FBANK_SAMPLE_RATE, + ) + return extract_fbank( + feature_input, + sample_rate=_SPEAKER_FBANK_SAMPLE_RATE, + n_mels=_SPEAKER_FBANK_N_MELS, + dither=_SPEAKER_FBANK_DITHER, + mean_norm=_SPEAKER_FBANK_MEAN_NORM, + ) diff --git a/src/dots_tts/modules/vocoder/__init__.py b/src/dots_tts/modules/vocoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f16ef8bd262da6f4d8d4715162bc600d8b140ef7 --- /dev/null +++ b/src/dots_tts/modules/vocoder/__init__.py @@ -0,0 +1 @@ +"""Vocoder modules.""" diff --git a/src/dots_tts/modules/vocoder/alias_free_act.py b/src/dots_tts/modules/vocoder/alias_free_act.py new file mode 100644 index 0000000000000000000000000000000000000000..948d2c69eaaddb9cacc7cf075ca00020f368b8d2 --- /dev/null +++ b/src/dots_tts/modules/vocoder/alias_free_act.py @@ -0,0 +1,163 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch +import torch.nn as nn +from torch import pow, sin +from torch.nn import Parameter + +from .alias_free_resample import DownSample1d, UpSample1d + + +class Activation1d(nn.Module): + def __init__( + self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + causal=True, + fixed_filter=False, + ): + super().__init__() + # causal=False + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d( + up_ratio, + up_kernel_size, + activation.in_features, + causal=causal, + fixed_filter=fixed_filter, + ) + self.downsample = DownSample1d( + down_ratio, + down_kernel_size, + activation.in_features, + causal=causal, + fixed_filter=fixed_filter, + ) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + return self.downsample(x) + + +class Snake(nn.Module): + """ + Implementation of a sine-based periodic activation function + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter + References: + - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snake(256) + >>> x = torch.randn(256) + >>> x = a1(x) + """ + + def __init__( + self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False + ): + """ + Initialization. + INPUT: + - in_features: shape of the input + - alpha: trainable parameter + alpha is initialized to 1 by default, higher values = higher-frequency. + alpha will be trained along with the rest of your model. + """ + super().__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + """ + Forward pass of the function. + Applies the function to the input elementwise. + Snake := x + 1/a * sin^2 (xa) + """ + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + if self.alpha_logscale: + alpha = torch.exp(alpha) + return x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + """ + + def __init__( + self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False + ): + """ + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + """ + super().__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + self.beta = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + """ + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta := x + 1/b * sin^2 (xa) + """ + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + return x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) diff --git a/src/dots_tts/modules/vocoder/alias_free_filter.py b/src/dots_tts/modules/vocoder/alias_free_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..4913366ea7a73de56a4136887a51b9584e618e33 --- /dev/null +++ b/src/dots_tts/modules/vocoder/alias_free_filter.py @@ -0,0 +1,114 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +if "sinc" in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + # LICENSE is in incl_licenses directory. + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! + """ + return torch.where( + x == 0, + torch.tensor(1.0, device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x, + ) + + +# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License +# https://adefossez.github.io/julius/julius/lowpass.html +# LICENSE is in incl_licenses directory. +def kaiser_sinc_filter1d( + cutoff, half_width, kernel_size +): # return filter [1,1,kernel_size] + even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + + # For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.0: + beta = 0.1102 * (A - 8.7) + elif A >= 21.0: + beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0) + else: + beta = 0.0 + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = torch.arange(-half_size, half_size) + 0.5 + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + # Normalize filter to have sum = 1, otherwise we will have a small leakage + # of the constant component in the input signal. + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter + + +class LowPassFilter1d(nn.Module): + def __init__( + self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = "replicate", + kernel_size: int = 12, + channels: int = 1, + causal: bool = True, + fixed_filter: bool = False, + ): + # kernel_size should be even number for stylegan3 setup, + # in this implementation, odd number is also possible. + super().__init__() + if cutoff < -0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + if causal: + self.pad_left = kernel_size - 1 + self.pad_right = 0 + else: + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + self.fixed_filter = fixed_filter + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + if fixed_filter: + self.register_buffer("filter", filter) + else: + self.filter = nn.Parameter(filter.expand(channels, -1, -1).clone()) + + # input [B, C, T] + def forward(self, x): + _, C, _ = x.shape + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) + if self.fixed_filter: + out = F.conv1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C + ) + else: + out = F.conv1d(x, self.filter, stride=self.stride, groups=C) + return out diff --git a/src/dots_tts/modules/vocoder/alias_free_resample.py b/src/dots_tts/modules/vocoder/alias_free_resample.py new file mode 100644 index 0000000000000000000000000000000000000000..e10cf7e76c53f7bb3f19d0f337bcdc1931505fc2 --- /dev/null +++ b/src/dots_tts/modules/vocoder/alias_free_resample.py @@ -0,0 +1,81 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch.nn as nn +from torch.nn import functional as F + +from .alias_free_filter import LowPassFilter1d, kaiser_sinc_filter1d + + +class UpSample1d(nn.Module): + def __init__( + self, ratio=2, kernel_size=None, channels=None, causal=True, fixed_filter=False + ): + super().__init__() + self.ratio = ratio + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.stride = ratio + self.channels = channels + self.causal = causal + self.fixed_filter = fixed_filter + if causal: + self.pad = 0 + else: + self.pad = self.kernel_size // ratio - 1 + self.pad_left = ( + self.pad * self.stride + (self.kernel_size - self.stride) // 2 + ) + self.pad_right = ( + self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + ) + filter = kaiser_sinc_filter1d( + cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size + ) + if self.fixed_filter: + self.register_buffer("filter", filter) + else: + self.filter = nn.Parameter(filter.expand(channels, -1, -1).clone()) + + # x: [B, C, T] + def forward(self, x): + _, C, _ = x.shape + x = F.pad(x, (self.pad, self.pad), mode="replicate") + if self.fixed_filter: + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C + ) + else: + x = self.ratio * F.conv_transpose1d( + x, self.filter, stride=self.stride, groups=C + ) + if self.causal: + x = x[..., : -(self.kernel_size - self.stride)] + else: + x = x[..., self.pad_left : -self.pad_right] + + return x + + +class DownSample1d(nn.Module): + def __init__( + self, ratio=2, kernel_size=None, channels=None, causal=True, fixed_filter=False + ): + super().__init__() + self.ratio = ratio + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.lowpass = LowPassFilter1d( + cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size, + channels=channels, + causal=causal, + fixed_filter=fixed_filter, + ) + + def forward(self, x): + return self.lowpass(x) diff --git a/src/dots_tts/modules/vocoder/bigvgan.py b/src/dots_tts/modules/vocoder/bigvgan.py new file mode 100644 index 0000000000000000000000000000000000000000..b0e21c4c4a06c2b13a941828c7c16f49ce2924b4 --- /dev/null +++ b/src/dots_tts/modules/vocoder/bigvgan.py @@ -0,0 +1,1070 @@ +# Copyright (c) 2022 NVIDIA CORPORATION. +# Licensed under the MIT license. +from __future__ import annotations + +import itertools +from dataclasses import dataclass +from fractions import Fraction + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils import remove_weight_norm, weight_norm + +from dots_tts.modules.backbone.layers import Conv1d, ConvTranspose1d +from dots_tts.modules.vocoder.alias_free_act import Activation1d, Snake, SnakeBeta +from dots_tts.modules.vocoder.config import AudioVAEConfig + + +@dataclass(slots=True) +class BigVGANStreamState: + lstm_hidden: tuple[torch.Tensor, torch.Tensor] + decoder: "DecoderStreamState" + + +@dataclass(slots=True) +class DecoderStreamState: + window: torch.Tensor + chunk_size: int + total_frames: int = 0 + emitted_frames: int = 0 + +def _empty_chunk( + ref: torch.Tensor, + *, + channels: int | None = None, + dtype: torch.dtype | None = None, +) -> torch.Tensor: + return ref.new_zeros( + (ref.size(0), channels or ref.size(1), 0), + dtype=dtype or ref.dtype, + ) + + +def _module_state_device_dtype(module: nn.Module) -> tuple[torch.device, torch.dtype]: + for name in ("weight", "bias", "filter"): + tensor = getattr(module, name, None) + if tensor is not None: + return tensor.device, tensor.dtype + for tensor in itertools.chain(module.parameters(), module.buffers()): + return tensor.device, tensor.dtype + raise RuntimeError(f"Unable to infer state dtype/device for {type(module).__name__}.") + + +def _stream_state_zeros( + batch_size: int, + channels: int, + length: int, + *, + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + return torch.zeros( + (batch_size, channels, max(0, int(length))), + device=device, + dtype=dtype, + ) + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +class Conv1d_S(nn.Module): + "Conv1d for spectral normalisation and orthogonal initialisation" + + def __init__( + self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + dilation=1, + groups=1, + causal=False, + ): + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + self.groups = groups + self.causal = causal + pad = 0 if causal else dilation * (kernel_size - 1) // 2 + self.causal_pad = dilation * (kernel_size - 1) if causal else 0 + + self.layer = weight_norm( + nn.Conv1d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=pad, + dilation=dilation, + groups=groups, + ) + ) + + def forward(self, inputs): + if self.causal and self.causal_pad > 0: + inputs = F.pad(inputs, (self.causal_pad, 0)) + return self.layer(inputs) + + +class SLSTM(nn.Module): + """ + LSTM without worrying about the hidden state, nor the layout of the data. + Expects input as convolutional layout. + """ + + def __init__( + self, + dimension: int, + num_layers: int = 2, + skip: bool = True, + bidirectional: bool = False, + ): + super().__init__() + self.skip = skip + self.bidirectional = bidirectional + self.lstm = nn.LSTM( + input_size=dimension, + hidden_size=dimension, + num_layers=num_layers, + bidirectional=bidirectional, + batch_first=True, + ) + self._stream_num_layers = num_layers + self._stream_weight_ih = tuple( + getattr(self.lstm, f"weight_ih_l{layer_idx}") + for layer_idx in range(num_layers) + ) + self._stream_weight_hh = tuple( + getattr(self.lstm, f"weight_hh_l{layer_idx}") + for layer_idx in range(num_layers) + ) + self._stream_bias_ih = tuple( + getattr(self.lstm, f"bias_ih_l{layer_idx}") + for layer_idx in range(num_layers) + ) + self._stream_bias_hh = tuple( + getattr(self.lstm, f"bias_hh_l{layer_idx}") + for layer_idx in range(num_layers) + ) + if self.bidirectional: + self.proj_out = nn.Linear(dimension * 2, dimension) + + def forward(self, x): + y, _ = self.lstm(x) + if self.bidirectional: + y = self.proj_out(y) + if self.skip: + y = y + x + return y + + def stream_step( + self, + x: torch.Tensor, + hidden: tuple[torch.Tensor, torch.Tensor] | None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + if self.bidirectional: + raise RuntimeError("Streaming only supports unidirectional SLSTM.") + + residual = x + if hidden is None: + hidden = self.init_stream_state(x.size(0)) + + hidden_h, hidden_c = hidden + next_hidden_h = [] + next_hidden_c = [] + for layer_idx in range(self._stream_num_layers): + layer_input = x + hx = hidden_h[layer_idx] + cx = hidden_c[layer_idx] + outputs = [] + weight_ih = self._stream_weight_ih[layer_idx] + weight_hh = self._stream_weight_hh[layer_idx] + bias_ih = self._stream_bias_ih[layer_idx] + bias_hh = self._stream_bias_hh[layer_idx] + + for frame_idx in range(x.size(1)): + gates = F.linear(layer_input[:, frame_idx, :], weight_ih, bias_ih) + gates = gates + F.linear(hx, weight_hh, bias_hh) + input_gate, forget_gate, cell_gate, output_gate = gates.chunk(4, dim=-1) + input_gate = torch.sigmoid(input_gate) + forget_gate = torch.sigmoid(forget_gate) + cell_gate = torch.tanh(cell_gate) + output_gate = torch.sigmoid(output_gate) + cx = forget_gate * cx + input_gate * cell_gate + hx = output_gate * torch.tanh(cx) + outputs.append(hx) + + x = torch.stack(outputs, dim=1) + next_hidden_h.append(hx) + next_hidden_c.append(cx) + + y = x + if self.skip: + y = y + residual + return y, (torch.stack(next_hidden_h, dim=0), torch.stack(next_hidden_c, dim=0)) + + def init_stream_state( + self, + batch_size: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + num_directions = 2 if self.bidirectional else 1 + state_shape = ( + self.lstm.num_layers * num_directions, + batch_size, + self.lstm.hidden_size, + ) + weight = self.lstm.weight_hh_l0 + return ( + weight.new_zeros(state_shape), + weight.new_zeros(state_shape), + ) + + +class ResStack(nn.Module): + def __init__(self, channel, kernel_size=3, base=3, nums=4, causal=False): + super().__init__() + + self.layers = nn.ModuleList([]) + for i in range(nums): + dil = base**i + pad1 = dil * (kernel_size - 1) if causal else dil + pad2 = (kernel_size - 1) if causal else 1 + block = [ + nn.LeakyReLU(), + ] + if causal and pad1 > 0: + block.append(nn.ConstantPad1d((pad1, 0), 0.0)) + block.append( + nn.utils.weight_norm( + nn.Conv1d( + channel, + channel, + kernel_size=kernel_size, + dilation=dil, + padding=0 if causal else pad1, + ) + ) + ) + block.append(nn.LeakyReLU()) + if causal and pad2 > 0: + block.append(nn.ConstantPad1d((pad2, 0), 0.0)) + block.append( + nn.utils.weight_norm( + nn.Conv1d( + channel, + channel, + kernel_size=kernel_size, + dilation=1, + padding=0 if causal else pad2, + ) + ) + ) + self.layers.append(nn.Sequential(*block)) + + def forward(self, x): + for layer in self.layers: + x = x + layer(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + in_channels=1, + out_channels=100, + base_channels=12, + proj_kernel_size=3, + stack_kernel_size=3, + stack_dilation_base=2, + stacks=6, + channels=(12, 24, 48, 96, 192, 384, 768), + down_sample_factors=(2, 2, 2, 2, 4, 4), + causal=False, + lookahead=0, + ): + super().__init__() + + act_slope = 0.2 + layers = [] + # pre proj_layer + layers += [ + Conv1d_S( + in_channels, + base_channels, + kernel_size=proj_kernel_size, + stride=1, + causal=causal, + ), + nn.LeakyReLU(act_slope, True), + ] + + # channels: [512, 256, 128, 64], upsample_factors: [5, 2, 2] + for (in_c, out_c), down_f in zip( + itertools.pairwise(channels), down_sample_factors, strict=True + ): + layers += [ + Conv1d_S( + in_c, out_c, kernel_size=down_f * 2, stride=down_f, causal=causal + ), + ResStack( + out_c, stack_kernel_size, stack_dilation_base, stacks, causal=causal + ), + nn.LeakyReLU(act_slope, True), + ] + + # post layers + if lookahead > 0: + layers += [ + Conv1d_S( + channels[-1], + out_channels, + kernel_size=lookahead * 2 + 1, + stride=1, + causal=False, + ), + ] + else: + layers += [ + Conv1d_S( + channels[-1], + out_channels, + kernel_size=proj_kernel_size, + stride=1, + causal=causal, + ), + ] + self.generator = nn.Sequential(*layers) + + def forward(self, conditions, _z_inputs=None): + return self.generator(conditions) + + +class AMPBlock1(torch.nn.Module): + def __init__( + self, + h, + channels, + kernel_size=3, + dilation=(1, 3, 5), + activation=None, + causal=True, + ): + super().__init__() + self.h = h + + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + causal=causal, + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + causal=causal, + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + causal=causal, + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, channels, kernel_size, 1, dilation=1, causal=causal + ) + ), + weight_norm( + Conv1d( + channels, channels, kernel_size, 1, dilation=1, causal=causal + ) + ), + weight_norm( + Conv1d( + channels, channels, kernel_size, 1, dilation=1, causal=causal + ) + ), + ] + ) + self.convs2.apply(init_weights) + + self.num_layers = len(self.convs1) + len( + self.convs2 + ) # total number of conv layers + + if ( + activation == "snake" + ): # periodic nonlinearity with snake function and anti-aliasing + self.activations = nn.ModuleList( + [ + Activation1d( + activation=Snake(channels, alpha_logscale=h.snake_logscale), + causal=causal, + fixed_filter=True, + ) + for _ in range(self.num_layers) + ] + ) + elif ( + activation == "snakebeta" + ): # periodic nonlinearity with snakebeta function and anti-aliasing + self.activations = nn.ModuleList( + [ + Activation1d( + activation=SnakeBeta(channels, alpha_logscale=h.snake_logscale), + causal=causal, + fixed_filter=True, + ) + for _ in range(self.num_layers) + ] + ) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + def forward(self, x): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2, strict=True): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = xt + x + + return x + + def remove_weight_norm(self): + for layer in self.convs1: + remove_weight_norm(layer) + for layer in self.convs2: + remove_weight_norm(layer) + + +class AMPBlock2(torch.nn.Module): + def __init__( + self, h, channels, kernel_size=3, dilation=(1, 3), activation=None, causal=True + ): + super().__init__() + self.h = h + + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + causal=causal, + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + causal=causal, + ) + ), + ] + ) + self.convs.apply(init_weights) + + self.num_layers = len(self.convs) # total number of conv layers + + if ( + activation == "snake" + ): # periodic nonlinearity with snake function and anti-aliasing + self.activations = nn.ModuleList( + [ + Activation1d( + activation=Snake(channels, alpha_logscale=h.snake_logscale), + causal=causal, + fixed_filter=True, + ) + for _ in range(self.num_layers) + ] + ) + elif ( + activation == "snakebeta" + ): # periodic nonlinearity with snakebeta function and anti-aliasing + self.activations = nn.ModuleList( + [ + Activation1d( + activation=SnakeBeta(channels, alpha_logscale=h.snake_logscale), + causal=causal, + fixed_filter=True, + ) + for _ in range(self.num_layers) + ] + ) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + def forward(self, x): + for c, a in zip(self.convs, self.activations, strict=True): + xt = a(x) + xt = c(xt) + x = xt + x + + return x + + def remove_weight_norm(self): + for layer in self.convs: + remove_weight_norm(layer) + + +class Decoder(nn.Module): + def __init__(self, h): + super().__init__() + self.h = h + causal = h.causal + self._stream_window_sizes: dict[int, int] = {} + + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + + num_decoder_lookahead = h.get("num_decoder_lookahead", 2) + # pre conv + self.conv_pre = weight_norm( + Conv1d( + h.latent_dim, + h.upsample_initial_channel, + kernel_size=2 * num_decoder_lookahead + 1, + stride=1, + causal=False, + ) + ) + + # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default + resblock = AMPBlock1 if h.resblock == "1" else AMPBlock2 + + # transposed conv-based upsamplers. does not apply anti-aliasing + self.ups = nn.ModuleList() + for i, (u, k) in enumerate( + zip(h.upsample_rates, h.upsample_kernel_sizes, strict=True) + ): + self.ups.append( + nn.ModuleList( + [ + weight_norm( + ConvTranspose1d( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2 ** (i + 1)), + k, + u, + causal=causal, + ) + ) + ] + ) + ) + + # residual blocks using anti-aliased multi-periodicity composition modules (AMP) + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for k, d in zip( + h.resblock_kernel_sizes, h.resblock_dilation_sizes, strict=True + ): + self.resblocks.append( + resblock(h, ch, k, d, activation=h.activation, causal=causal) + ) + + # post conv + if ( + h.activation == "snake" + ): # periodic nonlinearity with snake function and anti-aliasing + activation_post = Snake(ch, alpha_logscale=h.snake_logscale) + self.activation_post = Activation1d( + activation=activation_post, causal=causal, fixed_filter=False + ) + elif ( + h.activation == "snakebeta" + ): # periodic nonlinearity with snakebeta function and anti-aliasing + activation_post = SnakeBeta(ch, alpha_logscale=h.snake_logscale) + self.activation_post = Activation1d( + activation=activation_post, causal=causal, fixed_filter=False + ) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + self.conv_post = weight_norm( + Conv1d(ch, 1, 7, 1, causal=causal, bias=h.get("use_bias_at_final", True)) + ) + + # weight initialization + for i in range(len(self.ups)): + self.ups[i].apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, z): + # pre conv + x = self.conv_pre(z) + + for i in range(self.num_upsamples): + # upsampling + for i_up in range(len(self.ups[i])): + x = self.ups[i][i_up](x) + # AMP blocks + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + # post conv + x = self.activation_post(x) + x = self.conv_post(x) + if self.h.get("use_tanh_at_final", True): + x = torch.tanh(x) + else: + x = torch.clamp(x, min=-1.0, max=1.0) # Bound the output to [-1, 1] + + return x + + @property + def stream_lookahead(self) -> int: + return (self.conv_pre.kernel_size[0] - 1) // 2 + + @staticmethod + def _conv1d_left_context(layer) -> int: + dilation = layer.dilation[0] if isinstance(layer.dilation, tuple) else layer.dilation + kernel_size = ( + layer.kernel_size[0] + if isinstance(layer.kernel_size, tuple) + else layer.kernel_size + ) + if getattr(layer, "causal", False): + return dilation * (kernel_size - 1) + return layer.padding[0] if isinstance(layer.padding, tuple) else layer.padding + + @staticmethod + def _convtranspose1d_left_context(layer) -> int: + stride = layer.stride[0] if isinstance(layer.stride, tuple) else layer.stride + kernel_size = ( + layer.kernel_size[0] + if isinstance(layer.kernel_size, tuple) + else layer.kernel_size + ) + if not getattr(layer, "causal", False): + raise NotImplementedError("Streaming only supports causal ConvTranspose1d.") + if kernel_size != 2 * stride: + raise ValueError( + "Streaming ConvTranspose1d expects kernel_size == 2 * stride, got " + f"kernel_size={kernel_size} stride={stride}." + ) + return 1 + + @classmethod + def _activation_left_context(cls, activation: Activation1d) -> int: + upsample = activation.upsample + downsample = activation.downsample.lowpass + if not upsample.causal or not downsample.padding or downsample.pad_right != 0: + raise NotImplementedError("Streaming only supports causal alias-free activations.") + ratio = int(upsample.ratio) + if ratio != int(downsample.stride): + raise ValueError( + "Alias-free activation expects matched up/down ratios, got " + f"up_ratio={ratio} down_ratio={downsample.stride}." + ) + total_left = (upsample.kernel_size - 1) + (downsample.kernel_size - 1) + return (total_left + ratio - 1) // ratio + + @classmethod + def _ampblock_left_context(cls, block) -> int: + if isinstance(block, AMPBlock1): + left_context = 0 + acts1 = block.activations[::2] + acts2 = block.activations[1::2] + for conv1, conv2, act1, act2 in zip( + block.convs1, + block.convs2, + acts1, + acts2, + strict=True, + ): + left_context += ( + cls._activation_left_context(act1) + + cls._conv1d_left_context(conv1) + + cls._activation_left_context(act2) + + cls._conv1d_left_context(conv2) + ) + return left_context + if isinstance(block, AMPBlock2): + left_context = 0 + for conv, activation in zip(block.convs, block.activations, strict=True): + left_context += ( + cls._activation_left_context(activation) + + cls._conv1d_left_context(conv) + ) + return left_context + raise TypeError(f"Unsupported resblock type: {type(block).__name__}.") + + def _stream_left_context(self) -> int: + left_context = Fraction(self._conv1d_left_context(self.conv_pre), 1) + current_scale = Fraction(1, 1) + for stage_idx, upsample_layers in enumerate(self.ups): + for upsample in upsample_layers: + left_context += current_scale * self._convtranspose1d_left_context( + upsample + ) + stride = ( + upsample.stride[0] + if isinstance(upsample.stride, tuple) + else upsample.stride + ) + current_scale /= int(stride) + + stage_start = stage_idx * self.num_kernels + stage_end = stage_start + self.num_kernels + stage_context = max( + self._ampblock_left_context(block) + for block in self.resblocks[stage_start:stage_end] + ) + left_context += current_scale * stage_context + + left_context += current_scale * self._activation_left_context(self.activation_post) + left_context += current_scale * self._conv1d_left_context(self.conv_post) + return int(left_context.__ceil__()) + + def stream_window_size(self, chunk_size: int) -> int: + if chunk_size < 1: + raise ValueError(f"chunk_size must be >= 1, got {chunk_size}.") + cached = self._stream_window_sizes.get(chunk_size) + if cached is not None: + return cached + + window_size = chunk_size + self.stream_lookahead + self._stream_left_context() + self._stream_window_sizes[chunk_size] = window_size + return window_size + + def remove_weight_norm(self): + for upsample_layers in self.ups: + for upsample_layer in upsample_layers: + remove_weight_norm(upsample_layer) + for resblock in self.resblocks: + resblock.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class AudioVAE(nn.Module): + def __init__(self, h: AudioVAEConfig): + super().__init__() + self.config = h + + self.h = h + self.hop_size = int(np.prod(h.downsample_rates)) + self.sample_rate = h.sample_rate + self.decoder_lookahead = int(h.get("num_decoder_lookahead", 2)) + + self.audio_encoder = Encoder( + out_channels=h.latent_dim, + down_sample_factors=h.downsample_rates, + channels=h.downsample_channels, + causal=h.causal_encoder, + lookahead=h.get("num_encoder_lookahead", 2), + ) + + intermediate_size = h.latent_dim * 4 + self.enc_mi_layer = nn.Sequential( + nn.Linear(h.latent_dim, intermediate_size), + SLSTM(intermediate_size, num_layers=h.mi_num_layers), + nn.Linear(intermediate_size, h.latent_dim), + ) + self.dec_mi_layer = nn.Sequential( + nn.Linear(h.latent_dim, intermediate_size), + SLSTM(intermediate_size, num_layers=h.mi_num_layers), + nn.Linear(intermediate_size, h.latent_dim), + ) + self.pre_proj = Conv1d( + in_channels=h.latent_dim, + out_channels=h.latent_dim * 2, + kernel_size=1, + stride=1, + ) + self.post_proj = Conv1d( + in_channels=h.latent_dim, out_channels=h.latent_dim, kernel_size=1, stride=1 + ) + + self.decoder = Decoder(h) + + def inference(self, data): + latents = self.extract_latents(data["sample"]) + return {"sample": self.inference_from_latents(latents)} + + @torch.autocast(enabled=False, device_type="cuda") + def extract_latents(self, x, do_sample=False): + x = x.float() + x = self.audio_encoder(x) + x = x.permute(0, 2, 1) + x = self.enc_mi_layer(x) + x = x.permute(0, 2, 1) + x = self.pre_proj(x) + if do_sample: + m_q, logs_q = torch.split(x, self.h.latent_dim, dim=1) + x = m_q + torch.randn_like(m_q) * torch.exp(logs_q) + return x + + def inference_from_latents(self, x, do_sample=True, noise_scale=1.0): + if do_sample: + assert x.size(1) == self.h.latent_dim * 2, ( + f"Input must be like [B, D, H], got {x.shape}" + ) + m_q, logs_q = torch.split(x, self.h.latent_dim, dim=1) + x = m_q + torch.randn_like(m_q) * torch.exp(logs_q) * noise_scale + else: + assert x.size(1) == self.h.latent_dim, ( + f"Input must be like [B, D, H], got {x.shape}" + ) + x = self.post_proj(x) + x = x.permute(0, 2, 1) + x = self.dec_mi_layer(x) + x = x.permute(0, 2, 1) + return self.decoder(x) + + def _validate_stream_latents(self, latents: torch.Tensor) -> None: + if latents.ndim != 3: + raise ValueError( + "Streaming latents must have shape [batch, latent_dim, frames], " + f"got {tuple(latents.shape)}." + ) + if latents.size(1) != self.h.latent_dim: + raise ValueError( + f"Streaming latent_dim must be {self.h.latent_dim}, got {latents.size(1)}." + ) + + def init_stream_state( + self, + batch_size: int = 1, + chunk_size: int = 8, + ) -> BigVGANStreamState: + if not self.h.causal: + raise RuntimeError("Strict streaming requires a causal vocoder.") + if batch_size < 1: + raise ValueError(f"batch_size must be >= 1, got {batch_size}.") + window_size = self.decoder.stream_window_size(chunk_size) + device, dtype = _module_state_device_dtype(self.decoder.conv_pre) + return BigVGANStreamState( + lstm_hidden=self.dec_mi_layer[1].init_stream_state(batch_size), + decoder=DecoderStreamState( + window=_stream_state_zeros( + batch_size, + self.h.latent_dim, + window_size, + device=device, + dtype=dtype, + ), + chunk_size=int(chunk_size), + ), + ) + + def _decode_stream_latents( + self, + latents: torch.Tensor, + lstm_hidden: tuple[torch.Tensor, torch.Tensor], + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + self._validate_stream_latents(latents) + latents = latents.float() + x = self.post_proj(latents) + x = x.permute(0, 2, 1) + x = self.dec_mi_layer[0](x) + x, lstm_hidden = self.dec_mi_layer[1].stream_step( + x, lstm_hidden + ) + x = self.dec_mi_layer[2](x) + decoder_dtype = self.decoder.conv_pre.weight.dtype + return x.permute(0, 2, 1).to(dtype=decoder_dtype), lstm_hidden + + def _prepare_stream_decoder_input( + self, + latents: torch.Tensor, + state: BigVGANStreamState, + ) -> torch.Tensor: + decoder_input, state.lstm_hidden = self._decode_stream_latents( + latents, + state.lstm_hidden, + ) + return decoder_input + + def _append_stream_decoder_input_tensor( + self, + decoder_input: torch.Tensor, + window: torch.Tensor, + valid_frames: torch.Tensor, + ) -> torch.Tensor: + if window.dtype != decoder_input.dtype: + window = window.to(dtype=decoder_input.dtype) + chunk_size = int(decoder_input.size(-1)) + if chunk_size >= window.size(-1): + raise ValueError( + f"decoder window size {window.size(-1)} must be larger than chunk_size {chunk_size}." + ) + positions = torch.arange( + window.size(-1), + device=window.device, + dtype=valid_frames.dtype, + ) + clipped_valid = valid_frames.clamp(min=0, max=window.size(-1)) + combined = torch.cat( + [window, decoder_input.new_zeros(window.size(0), window.size(1), chunk_size)], + dim=-1, + ) + insert_index = clipped_valid + torch.arange( + chunk_size, + device=window.device, + dtype=valid_frames.dtype, + ) + combined.scatter_( + -1, + insert_index.view(1, 1, -1).expand_as(decoder_input), + decoder_input, + ) + new_valid = (clipped_valid + chunk_size).clamp(max=window.size(-1)) + start = (clipped_valid + chunk_size - window.size(-1)).clamp(min=0) + gather_index = (start + positions).clamp(max=combined.size(-1) - 1) + gathered = combined.gather( + -1, + gather_index.view(1, 1, -1).expand_as(window), + ) + mask = (positions < new_valid).to(dtype=window.dtype).view(1, 1, -1) + return gathered * mask + + def _append_stream_decoder_input( + self, + decoder_input: torch.Tensor, + state: BigVGANStreamState, + ) -> torch.Tensor: + decoder_state = state.decoder + chunk_size = int(decoder_input.size(-1)) + if chunk_size != decoder_state.chunk_size: + raise ValueError( + f"Streaming chunk_size must stay fixed at {decoder_state.chunk_size}, got {chunk_size}." + ) + window = decoder_state.window + valid_frames = min(decoder_state.total_frames, window.size(-1)) + valid_frames_tensor = window.new_tensor(valid_frames, dtype=torch.int64) + new_window = self._append_stream_decoder_input_tensor( + decoder_input, + window, + valid_frames_tensor, + ) + decoder_state.window = new_window + decoder_state.total_frames += chunk_size + return new_window + + def compiled_stream_step( + self, + latents: torch.Tensor, + hidden_h: torch.Tensor, + hidden_c: torch.Tensor, + window: torch.Tensor, + valid_frames: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + decoder_input, (hidden_h, hidden_c) = self._decode_stream_latents( + latents, + (hidden_h, hidden_c), + ) + new_window = self._append_stream_decoder_input_tensor( + decoder_input, + window, + valid_frames, + ) + audio_window = self.decode_stream_window(new_window) + return audio_window, hidden_h, hidden_c, new_window + + def decode_stream_window(self, window: torch.Tensor) -> torch.Tensor: + decoder_dtype = self.decoder.conv_pre.weight.dtype + if window.dtype != decoder_dtype: + window = window.to(dtype=decoder_dtype) + return self.decoder(window) + + def _slice_stream_audio_window( + self, + audio_window: torch.Tensor, + state: BigVGANStreamState, + *, + final: bool, + ) -> torch.Tensor: + decoder_state = state.decoder + stable_end = ( + decoder_state.total_frames + if final + else max(0, decoder_state.total_frames - self.decoder.stream_lookahead) + ) + if stable_end <= decoder_state.emitted_frames: + return _empty_chunk(audio_window, channels=1) + + window_size = decoder_state.window.size(-1) + valid_frames = min(decoder_state.total_frames, window_size) + window_start = decoder_state.total_frames - valid_frames + if decoder_state.emitted_frames < window_start: + raise RuntimeError( + "Decoder stream window is too short for fixed-graph decoding." + ) + + local_start = decoder_state.emitted_frames - window_start + local_end = stable_end - window_start + sample_start = local_start * self.hop_size + sample_end = local_end * self.hop_size + decoder_state.emitted_frames = stable_end + return audio_window[..., sample_start:sample_end] + + def stream_step( + self, + latents: torch.Tensor, + state: BigVGANStreamState, + ) -> torch.Tensor: + decoder_input = self._prepare_stream_decoder_input(latents, state) + window = self._append_stream_decoder_input(decoder_input, state) + audio_window = self.decode_stream_window(window) + return self._slice_stream_audio_window(audio_window, state, final=False) + + def stream_flush(self, state: BigVGANStreamState) -> torch.Tensor: + audio_window = self.decode_stream_window(state.decoder.window) + return self._slice_stream_audio_window(audio_window, state, final=True) + + def remove_weight_norm(self): + self.decoder.remove_weight_norm() diff --git a/src/dots_tts/modules/vocoder/config.py b/src/dots_tts/modules/vocoder/config.py new file mode 100644 index 0000000000000000000000000000000000000000..f3bd68d9afdcf03acc5624b31dced42319f2e4e2 --- /dev/null +++ b/src/dots_tts/modules/vocoder/config.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from pydantic import Field + +from dots_tts.config.base import ConfigBase + + +class AudioVAEConfig(ConfigBase): + sample_rate: int = 24000 + upsample_rates: list[int] = Field(default_factory=list) + upsample_kernel_sizes: list[int] = Field(default_factory=list) + upsample_initial_channel: int = 1536 + resblock: str = "1" + resblock_kernel_sizes: list[int] = Field(default_factory=list) + resblock_dilation_sizes: list[list[int]] = Field(default_factory=list) + downsample_rates: list[int] = Field(default_factory=list) + downsample_channels: list[int] = Field(default_factory=list) + activation: str = "snakebeta" + snake_logscale: bool = True + latent_dim: int = 128 + causal: bool = False + mi_num_layers: int = 4 + causal_encoder: bool = False + use_bias_at_final: bool = True + use_tanh_at_final: bool = True + + +__all__ = ["AudioVAEConfig"] diff --git a/src/dots_tts/runtime.py b/src/dots_tts/runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..5a4445b7684c4ddb59971ff4e9421db7fe52511f --- /dev/null +++ b/src/dots_tts/runtime.py @@ -0,0 +1,566 @@ +from __future__ import annotations + +import hashlib +import json +import time +from pathlib import Path +from typing import Any, Iterator, TypedDict + +import librosa +import torch +from huggingface_hub import snapshot_download +from loguru import logger + +from dots_tts.data.pipelines.tokenizing import build_generation_schedule +from dots_tts.data.pipelines.tts_pipeline import ( + DEFAULT_INSTRUCTION_TTS_TEMPLATE, + DEFAULT_INTERLEAVE_TRAIN_TEMPLATE, + DEFAULT_TEXT_TO_AUDIO_TEMPLATE, + DEFAULT_TRAIN_TEMPLATE, +) +from dots_tts.models.dots_tts.model import DotsTtsModel +from dots_tts.utils.audio import high_quality_resample +from dots_tts.utils.profiling import ( + InferenceProfiler, + activate_inference_profiler, + inference_profiling, + log_inference_profile, +) +from dots_tts.utils.text import ( + attach_language_tag, + detect, + normalize_language_code, + normalize_text, +) +from dots_tts.utils.util import get_dtype + +RUNTIME_TEMPLATE_BY_NAME = { + "tts": DEFAULT_TRAIN_TEMPLATE, + "instruction_tts": DEFAULT_INSTRUCTION_TTS_TEMPLATE, + "text_to_audio": DEFAULT_TEXT_TO_AUDIO_TEMPLATE, + "tts_interleave": DEFAULT_INTERLEAVE_TRAIN_TEMPLATE, +} + + +class RuntimeInputs(TypedDict, total=False): + fid: str + language: str + text: str + prompt_text: str + template_name: str + generation_schedule: torch.Tensor + prompt_audio: torch.Tensor + + +class DotsTtsRuntime: + # region Lifecycle and pretrained loading + def __init__( + self, + model: DotsTtsModel, + pretrained_path: Path, + *, + precision: str = "bfloat16", + optimize: bool = False, + max_generate_length: int = 500, + ): + self.model = model + self.pretrained_path = pretrained_path + self.precision = precision + if torch.cuda.is_available(): + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + torch.set_num_threads(1) + if self.device.type == "cuda" and self.precision.lower() in { + "fp32", + "torch.float32", + "float32", + }: + torch.set_float32_matmul_precision("high") + target_dtype = get_dtype(self.precision) + self.model.core.to(dtype=target_dtype) + self.model = self.model.to(self.device).eval() + self.optimize = bool(optimize) + self.max_generate_length = int(max_generate_length) + self.model.set_optimize(self.optimize) + self.sample_rate = int(self.model.config.vocoder.sample_rate) + skip_init_warmup = os.environ.get("DOTS_TTS_SKIP_INIT_WARMUP", "0") == "1" + if self.optimize and hasattr(self.model, "run_warmup") and not skip_init_warmup: + self.model.run_warmup( + max_generate_length=self.max_generate_length, + precision=self.precision, + ) + logger.info( + "Runtime initialized: pretrained_path={} device={} sample_rate={} " + "precision={} " + "optimize={} max_audio_patch_count={}", + self.pretrained_path, + self.device, + self.sample_rate, + self.precision, + self.optimize, + self.max_generate_length, + ) + + @classmethod + def from_pretrained( + cls, + model_name_or_path: str, + *, + revision: str | None = None, + cache_dir: str | None = None, + precision: str = "bfloat16", + optimize: bool = False, + max_generate_length: int = 500, + ) -> DotsTtsRuntime: + logger.info( + "Runtime load started: model={} revision={} cache_dir={} precision={}", + model_name_or_path, + revision, + cache_dir, + precision, + ) + pretrained_path = cls._resolve_pretrained_path( + model_name_or_path, + revision=revision, + cache_dir=cache_dir, + ) + loaded_model = DotsTtsModel.from_pretrained(pretrained_path) + logger.info("Runtime load completed: pretrained_path={}", pretrained_path) + return cls( + model=loaded_model, + pretrained_path=pretrained_path, + precision=precision, + optimize=optimize, + max_generate_length=max_generate_length, + ) + + @classmethod + def _resolve_pretrained_path( + cls, + model_name_or_path: str, + revision: str | None = None, + cache_dir: str | None = None, + ) -> Path: + logger.info( + "Resolving pretrained path: model={} revision={} cache_dir={}", + model_name_or_path, + revision, + cache_dir, + ) + resolved_path = Path(model_name_or_path).expanduser().resolve() + if resolved_path.exists(): + logger.info("Using local pretrained directory: path={}", resolved_path) + return resolved_path + + logger.info( + "Downloading pretrained snapshot: repo_id={} revision={} cache_dir={}", + model_name_or_path, + revision, + cache_dir, + ) + snapshot_dir = snapshot_download( + repo_id=model_name_or_path, + revision=revision, + cache_dir=cache_dir, + ) + resolved_path = Path(snapshot_dir).expanduser().resolve() + logger.info("Pretrained snapshot ready: path={}", resolved_path) + return resolved_path + # endregion Lifecycle and pretrained loading + + # region Request normalization and metadata + @staticmethod + def _build_request_id( + *, + text: str, + prompt_audio_path: str | None, + prompt_text: str | None, + template_name: str, + language: str | None = None, + ) -> str: + payload = { + "text": text, + "prompt_audio_path": prompt_audio_path, + "prompt_text": prompt_text, + "template_name": template_name, + } + if language is not None: + payload["language"] = language + digest = hashlib.sha1( + json.dumps(payload, ensure_ascii=False, sort_keys=True).encode("utf-8") + ).hexdigest() + return digest[:16] + + def _load_prompt_audio( + self, + prompt_audio_path: str, + ) -> torch.Tensor: + logger.info("Loading prompt audio: path={}", prompt_audio_path) + prompt_audio, sample_rate = librosa.load(prompt_audio_path, sr=None, mono=True) + prompt_audio = librosa.effects.trim(prompt_audio, top_db=30)[0] + prompt_audio = torch.from_numpy(prompt_audio).unsqueeze(0) + prompt_audio = high_quality_resample( + prompt_audio, + orig_sr=sample_rate, + target_sr=self.sample_rate, + ) + if prompt_audio.ndim == 1: + prompt_audio = prompt_audio.unsqueeze(0) + logger.info( + "Prompt audio loaded: path={} original_sample_rate={} resampled_sample_rate={} " + "samples={}", + prompt_audio_path, + sample_rate, + self.sample_rate, + prompt_audio.shape[-1], + ) + return prompt_audio + + def _resolve_language( + self, + language: str | None, + *, + text: str, + ) -> str | None: + if language is None: + return None + + stripped = language.strip() + if not stripped or stripped.lower() == "none": + return None + if stripped.lower() == "auto_detect": + return normalize_language_code(detect(text)) + + normalized_language = normalize_language_code(stripped) + if normalized_language is None: + raise ValueError( + f"Unsupported language={language!r}. " + "Expected 'none', 'auto_detect', or a valid language code/name." + ) + return normalized_language + + def _process_prompt_text( + self, + prompt_text: str | None, + *, + language: str | None = None, + ) -> str: + if prompt_text is None: + return "" + prompt_text = prompt_text.strip() + if not prompt_text: + return "" + + prompt_language = language + if prompt_language is None: + prompt_language = normalize_language_code(detect(prompt_text)) + + if prompt_language not in {"ZH", "YUE", "JA", "口音:粤语"}: + prompt_text += " " + if language is not None: + prompt_text = attach_language_tag(prompt_text, language) + return prompt_text + + def _process_text( + self, + text: str, + *, + language: str | None = None, + normalize: bool = False, + ) -> tuple[str, str | None]: + stripped = text.strip() + if normalize: + stripped = normalize_text(stripped) + resolved_language = self._resolve_language(language, text=stripped) + return stripped, resolved_language + + def _estimate_prompt_audio_patch_count( + self, + *, + prompt_audio: torch.Tensor | None, + prompt_text: str, + ) -> int: + if prompt_audio is None or not prompt_text: + return 0 + samples_per_patch = int(self.model.config.patch_size * self.model.hop_size) + prompt_samples = int(prompt_audio.shape[-1]) + return (prompt_samples + samples_per_patch - 1) // samples_per_patch + # endregion Request normalization and metadata + + # region Generation schedule assembly + def _normalize_template_name(self, template_name: str | None) -> str: + if template_name is None: + return "tts" + if template_name not in RUNTIME_TEMPLATE_BY_NAME: + raise ValueError( + f"Unknown template_name={template_name!r}. " + f"Expected one of {sorted(RUNTIME_TEMPLATE_BY_NAME)}." + ) + return template_name + + def _prepare_inputs( + self, + *, + text: str, + prompt_audio_path: str | None, + prompt_text: str | None, + template_name: str | None, + language: str | None = None, + normalize_text: bool = False, + ) -> RuntimeInputs: + normalized_template_name = self._normalize_template_name(template_name) + template = RUNTIME_TEMPLATE_BY_NAME[normalized_template_name] + if prompt_text and not prompt_audio_path: + raise ValueError("prompt_text requires prompt_audio_path.") + + normalized_text, normalized_language = self._process_text( + text, + language=language, + normalize=normalize_text, + ) + normalized_prompt_text = self._process_prompt_text( + prompt_text, + language=normalized_language, + ) + if normalized_language is not None and not normalized_prompt_text: + normalized_text = attach_language_tag(normalized_text, normalized_language) + inputs: RuntimeInputs = { + "fid": self._build_request_id( + text=normalized_text, + prompt_audio_path=prompt_audio_path, + prompt_text=normalized_prompt_text, + template_name=normalized_template_name, + language=normalized_language, + ), + "language": normalized_language or "", + "text": normalized_text, + "prompt_text": normalized_prompt_text, + "template_name": normalized_template_name, + } + + if prompt_audio_path: + inputs["prompt_audio"] = self._load_prompt_audio(prompt_audio_path) + prompt_audio_patch_count = self._estimate_prompt_audio_patch_count( + prompt_audio=inputs.get("prompt_audio"), + prompt_text=normalized_prompt_text, + ) + if ( + prompt_audio_patch_count > 0 + and self.max_generate_length <= prompt_audio_patch_count + ): + raise ValueError( + "max_generate_length must exceed prompt audio patch count when prompt_text is provided: " + f"max_generate_length={self.max_generate_length} " + f"prompt_audio_patch_count={prompt_audio_patch_count}." + ) + + schedule_spec = build_generation_schedule( + text=f"{normalized_prompt_text}{normalized_text}", + tokenizer=self.model.tokenizer, + template=template, + max_audio_tokens=self.max_generate_length, + ) + schedule = torch.tensor( + schedule_spec["schedule_ids"], + dtype=torch.long, + device=self.device, + ) + inputs["generation_schedule"] = schedule.unsqueeze(0) + logger.info( + "Inputs prepared: request_id={} template_name={} " + "language={} text_len={} prompt_text_len={} schedule_length={} " + "prompt_audio_patch_count={} max_audio_patch_count={} has_prompt_audio={}", + inputs["fid"], + normalized_template_name, + normalized_language, + len(normalized_text), + len(normalized_prompt_text), + schedule.numel(), + prompt_audio_patch_count, + self.max_generate_length, + bool(prompt_audio_path), + ) + return inputs + # endregion Generation schedule assembly + + # region Public generation APIs + def generate_stream( + self, + *, + text: str, + prompt_audio_path: str | None = None, + prompt_text: str | None = None, + template_name: str | None = None, + language: str | None = None, + speaker_scale: float = 1.5, + ode_method: str = "euler", + num_steps: int = 10, + guidance_scale: float = 1.2, + normalize_text: bool = False, + profile_inference: bool = False, + ) -> Iterator[torch.Tensor]: + inputs = self._prepare_inputs( + text=text, + prompt_audio_path=prompt_audio_path, + prompt_text=prompt_text, + template_name=template_name, + language=language, + normalize_text=normalize_text, + ) + logger.info( + "Streaming generation started: request_id={} text_len={} has_prompt_audio={} " + "has_prompt_text={} template_name={} language={} precision={} ode_method={} num_steps={} " + "guidance_scale={} speaker_scale={} max_audio_patch_count={} normalize_text={}", + inputs["fid"], + len(inputs["text"]), + bool(prompt_audio_path), + bool(inputs["prompt_text"]), + inputs["template_name"], + inputs["language"] or None, + self.precision, + ode_method, + num_steps, + guidance_scale, + speaker_scale, + self.max_generate_length, + normalize_text, + ) + start_time = time.time() + emitted_samples = 0 + chunk_count = 0 + profiler: InferenceProfiler | None = None + try: + profiler = ( + InferenceProfiler(self.device) if profile_inference else None + ) + stream = self.model.generate_audio_stream( + inputs, + precision=self.precision, + ode_method=ode_method, + num_steps=num_steps, + guidance_scale=guidance_scale, + speaker_scale=speaker_scale, + ) + while True: + try: + with activate_inference_profiler(profiler): + chunk = next(stream) + except StopIteration: + break + emitted_samples += int(chunk.shape[-1]) + chunk_count += 1 + yield chunk + except Exception: + logger.exception( + "Streaming generation failed: request_id={}", + inputs["fid"], + ) + raise + time_used = time.time() - start_time + duration_seconds = emitted_samples / self.sample_rate + rtf = time_used / duration_seconds if duration_seconds > 0 else float("inf") + if profile_inference and profiler is not None: + log_inference_profile( + request_id=inputs["fid"], + profiling=profiler.summary(duration_seconds=duration_seconds), + duration_seconds=duration_seconds, + ) + logger.info( + "Streaming generation finished: request_id={} chunk_count={} elapsed_seconds={:.3f} " + "audio_seconds={:.3f} rtf={:.4f} sample_rate={}", + inputs["fid"], + chunk_count, + time_used, + duration_seconds, + rtf, + self.sample_rate, + ) + + def generate( + self, + *, + text: str, + prompt_audio_path: str | None = None, + prompt_text: str | None = None, + template_name: str | None = None, + language: str | None = None, + speaker_scale: float = 1.5, + ode_method: str = "euler", + num_steps: int = 10, + guidance_scale: float = 1.2, + normalize_text: bool = False, + profile_inference: bool = False, + ) -> dict[str, Any]: + inputs = self._prepare_inputs( + text=text, + prompt_audio_path=prompt_audio_path, + prompt_text=prompt_text, + template_name=template_name, + language=language, + normalize_text=normalize_text, + ) + logger.info( + "Generation started: request_id={} text_len={} has_prompt_audio={} " + "has_prompt_text={} template_name={} language={} precision={} ode_method={} num_steps={} " + "guidance_scale={} speaker_scale={} max_audio_patch_count={} normalize_text={}", + inputs["fid"], + len(inputs["text"]), + bool(prompt_audio_path), + bool(inputs["prompt_text"]), + inputs["template_name"], + inputs["language"] or None, + self.precision, + ode_method, + num_steps, + guidance_scale, + speaker_scale, + self.max_generate_length, + normalize_text, + ) + start_time = time.time() + profiling = None + try: + with inference_profiling( + enabled=profile_inference, + device=self.device, + ) as profiler: + audio = self.model.generate_audio( + inputs, + precision=self.precision, + ode_method=ode_method, + num_steps=num_steps, + guidance_scale=guidance_scale, + speaker_scale=speaker_scale, + ) + except Exception: + logger.exception("Generation failed: request_id={}", inputs["fid"]) + raise + time_used = time.time() - start_time + duration_seconds = audio.shape[-1] / self.sample_rate + rtf = time_used / duration_seconds if duration_seconds > 0 else float("inf") + if profiler is not None: + profiling = profiler.summary(duration_seconds=duration_seconds) + log_inference_profile( + request_id=inputs["fid"], + profiling=profiling, + duration_seconds=duration_seconds, + ) + logger.info( + "Generation completed: request_id={} elapsed_seconds={:.3f} audio_seconds={:.3f} " + "rtf={:.4f} sample_rate={}", + inputs["fid"], + time_used, + duration_seconds, + rtf, + self.sample_rate, + ) + return { + "fid": inputs["fid"], + "audio": audio, + "sample_rate": self.sample_rate, + "time_used": time_used, + "rtf": rtf, + "profiling": profiling, + } + # endregion Public generation APIs diff --git a/src/dots_tts/runtime_double_streaming.py b/src/dots_tts/runtime_double_streaming.py new file mode 100644 index 0000000000000000000000000000000000000000..208f5dd246f65151da1b18e756f58bd692bcf64c --- /dev/null +++ b/src/dots_tts/runtime_double_streaming.py @@ -0,0 +1,355 @@ +from __future__ import annotations + +from pathlib import Path + +import torch +from loguru import logger + +from dots_tts.data.pipelines.tts_pipeline import TTS_INTERLEAVE_PREFIX +from dots_tts.runtime import DotsTtsRuntime +from dots_tts.utils.util import get_dtype + + +class DoubleStreamingSession: + """Incremental interleave session for text-token to audio-chunk generation.""" + + def __init__( + self, + runtime: DotsTtsRuntime, + *, + prompt_audio_path: str | None = None, + prompt_text: str | None = None, + ode_method: str = "euler", + num_steps: int = 10, + guidance_scale: float = 1.2, + speaker_scale: float = 1.5, + eos_threshold: float = 0.8, + initial_silence_audio_tokens: int = 1, + ) -> None: + normalized_prompt_text = runtime._process_prompt_text(prompt_text) + if normalized_prompt_text: + raise ValueError("Double streaming does not support prompt_text.") + + self.runtime = runtime + self.model = runtime.model + self.device = runtime.device + self.ode_method = ode_method + self.num_steps = int(num_steps) + self.guidance_scale = float(guidance_scale) + self.speaker_scale = float(speaker_scale) + self.eos_threshold = float(eos_threshold) + self.max_generate_length = runtime.max_generate_length + self._initial_silence_audio_tokens = max( + 0, + min(10, int(initial_silence_audio_tokens or 0)), + ) + + self._dtype = get_dtype(runtime.precision) + self._use_amp = self.device.type == "cuda" and self._dtype in { + torch.float16, + torch.bfloat16, + } + self._prefix_token_ids = tuple( + self.model.tokenizer.encode( + TTS_INTERLEAVE_PREFIX, + add_special_tokens=False, + ) + ) + self._state = self.model._allocate_generate_state( + max_audio_patch_count=self.max_generate_length, + device=self.device, + dtype=self._dtype, + ) + self._vocoder_state = self.model.vocoder.init_stream_state( + batch_size=1, + chunk_size=self.model.core.latent_patch_size, + ) + self._g_cond = None + self._started = False + self._text_finished = False + self._closed = False + self._decoded_patch_count = 0 + + if prompt_audio_path is not None: + cache = getattr(self.runtime, "_double_streaming_prompt_g_cond_cache", None) + if cache is None: + cache = {} + setattr(self.runtime, "_double_streaming_prompt_g_cond_cache", cache) + prompt_cache_key = ( + str(Path(prompt_audio_path).expanduser().resolve()), + str(self.device), + str(self._dtype), + self.speaker_scale, + ) + cached_g_cond = cache.get(prompt_cache_key) + if cached_g_cond is None: + prompt_audio = self.runtime._load_prompt_audio(prompt_audio_path) + with torch.no_grad(): + with torch.autocast( + device_type=self.device.type, + dtype=self._dtype, + enabled=self._use_amp, + ): + prompt_conditioning = self.model._prepare_prompt_conditioning( + prompt_audio, + use_prompt_prefill=False, + speaker_scale=self.speaker_scale, + ) + cached_g_cond = prompt_conditioning.g_cond.detach() + cache[prompt_cache_key] = cached_g_cond + logger.info( + "Double streaming prompt conditioning cached: path={} device={} " + "dtype={} speaker_scale={}", + prompt_cache_key[0], + self.device, + self._dtype, + self.speaker_scale, + ) + else: + logger.info( + "Double streaming prompt conditioning cache hit: path={} device={} " + "dtype={} speaker_scale={}", + prompt_cache_key[0], + self.device, + self._dtype, + self.speaker_scale, + ) + self._g_cond = cached_g_cond + + logger.info( + "Double streaming session started: prefix_token_count={} precision={} " + "ode_method={} num_steps={} guidance_scale={} speaker_scale={} max_audio_patch_count={} " + "initial_silence_audio_tokens={} has_ref_audio_only={}", + len(self._prefix_token_ids), + runtime.precision, + self.ode_method, + self.num_steps, + self.guidance_scale, + self.speaker_scale, + self.max_generate_length, + self._initial_silence_audio_tokens, + self._g_cond is not None, + ) + + @property + def is_finished(self) -> bool: + return self._closed + + def push_text_token(self, text_token: int) -> torch.Tensor | None: + self._ensure_active() + if self._text_finished: + raise RuntimeError("Cannot push text tokens after finish_text().") + if self._state.end_flag: + raise RuntimeError( + "Double streaming generation has already reached EOS. " + "Call finish_text() to flush the remaining audio tail." + ) + + token_id = int(text_token) + if not self._started: + chunk_token_ids = [*self._prefix_token_ids, token_id] + self._started = True + else: + chunk_token_ids = [token_id] + + self._consume_text_chunk(chunk_token_ids) + return self._decode_audio_chunk() + + def finish_text(self): + self._ensure_active() + + if not self._state.end_flag: + if not self._text_finished: + text_end_chunk = [self.model.core.text_cond_end_id] + if not self._started: + text_end_chunk = [*self._prefix_token_ids, *text_end_chunk] + self._started = True + self._consume_text_chunk(text_end_chunk) + self._text_finished = True + + while not self._state.end_flag: + audio_chunk = self._decode_audio_chunk(continue_audio_span=True) + if audio_chunk is not None: + yield audio_chunk + else: + self._text_finished = True + + final_chunk = self.model.vocoder.stream_flush(self._vocoder_state) + self._closed = True + logger.info( + "Double streaming session finished: decoded_patch_count={}", + self._decoded_patch_count, + ) + if final_chunk.size(-1) > 0: + yield final_chunk + + def _ensure_active(self) -> None: + if self._closed: + raise RuntimeError("Double streaming session is already closed.") + + def _consume_text_chunk(self, token_ids: list[int]) -> None: + schedule = torch.tensor( + [token_ids], + dtype=torch.long, + device=self.device, + ) + with torch.no_grad(): + with torch.autocast( + device_type=self.device.type, + dtype=self._dtype, + enabled=self._use_amp, + ): + self.model._consume_text_schedule( + schedule, + position=0, + next_audio_position=schedule.size(1), + state=self._state, + ) + + def _get_initial_silence_audio_patch( + self, + patch_index: int, + audio_patch: torch.Tensor, + ) -> torch.Tensor: + cache = getattr(self.runtime, "_double_streaming_silence_audio_patch_cache", None) + if cache is None: + cache = {} + setattr(self.runtime, "_double_streaming_silence_audio_patch_cache", cache) + + cache_count = 10 + patch_size = int(self.model.core.latent_patch_size) + key = ( + str(self.device), + str(self._dtype), + patch_size, + int(audio_patch.size(-1)), + cache_count, + ) + cached_patches = cache.get(key) + if cached_patches is None: + hop_size = int(getattr(self.model.vocoder, "hop_size", 1)) + zero_samples = cache_count * patch_size * hop_size + zero_audio = torch.zeros( + (1, 1, zero_samples), + device=self.device, + dtype=torch.float32, + ) + silence_latents = self.model.vocoder.extract_latents(zero_audio) + silence_latents, _ = torch.split( + silence_latents, + int(audio_patch.size(-1)), + dim=1, + ) + silence_latents = silence_latents.transpose(1, 2) + target_frames = cache_count * patch_size + if silence_latents.size(1) < target_frames: + silence_latents = torch.cat( + [ + silence_latents, + silence_latents.new_zeros( + ( + silence_latents.size(0), + target_frames - silence_latents.size(1), + silence_latents.size(2), + ) + ), + ], + dim=1, + ) + silence_latents = silence_latents[:, :target_frames, :] + cached_patches = self.model.core.io_helper.normalize(silence_latents) + cached_patches = cached_patches.to(device=self.device, dtype=audio_patch.dtype) + cached_patches = cached_patches.reshape( + 1, + cache_count, + patch_size, + int(audio_patch.size(-1)), + ).detach() + cache[key] = cached_patches + logger.info( + "Double streaming initial silence cache built: patches={} patch_size={} " + "hop_size={} device={} dtype={}", + cache_count, + patch_size, + hop_size, + self.device, + audio_patch.dtype, + ) + return cached_patches[:, int(patch_index)].clone() + + def _consume_audio_patch(self, audio_patch: torch.Tensor) -> None: + self.model._consume_audio_patch(self._state, audio_patch=audio_patch) + + def _decode_audio_chunk(self, *, continue_audio_span: bool = False) -> torch.Tensor | None: + if self._decoded_patch_count >= self.max_generate_length: + raise RuntimeError( + "Double streaming exceeded max_generate_length before reaching EOS." + ) + + with torch.no_grad(): + with torch.autocast( + device_type=self.device.type, + dtype=self._dtype, + enabled=self._use_amp, + ): + stop_after_current_audio = self.model._should_stop_after_current_audio( + self._state, + eos_threshold=self.eos_threshold, + ) + audio_patch = self.model._decode_next_audio( + self._state, + device=self.device, + g_cond=self._g_cond, + ode_method=self.ode_method, + num_steps=self.num_steps, + guidance_scale=self.guidance_scale, + ) + if self._decoded_patch_count < self._initial_silence_audio_tokens: + audio_patch = self._get_initial_silence_audio_patch( + self._decoded_patch_count, + audio_patch, + ) + self._consume_audio_patch(audio_patch) + if continue_audio_span: + self.model._append_hidden_chunk(self._state, self._state.llm_hiddens) + self._decoded_patch_count += 1 + latent_patch = self.model.core.io_helper.denormalize(audio_patch) + audio_chunk = self.model.vocoder.stream_step( + latent_patch.transpose(1, 2), + self._vocoder_state, + ) + if stop_after_current_audio: + self._state.end_flag = True + + if audio_chunk.size(-1) == 0: + return None + return audio_chunk + + +class DotsTtsRuntimeDoubleStreaming(DotsTtsRuntime): + def start_double_streaming( + self, + *, + prompt_audio_path: str | None = None, + prompt_text: str | None = None, + ode_method: str = "euler", + num_steps: int = 10, + guidance_scale: float = 1.2, + speaker_scale: float = 1.5, + eos_threshold: float = 0.8, + initial_silence_audio_tokens: int = 1, + ) -> DoubleStreamingSession: + return DoubleStreamingSession( + self, + prompt_audio_path=prompt_audio_path, + prompt_text=prompt_text, + ode_method=ode_method, + num_steps=num_steps, + guidance_scale=guidance_scale, + speaker_scale=speaker_scale, + eos_threshold=eos_threshold, + initial_silence_audio_tokens=initial_silence_audio_tokens, + ) + + +__all__ = ["DotsTtsRuntimeDoubleStreaming", "DoubleStreamingSession"] diff --git a/src/dots_tts/training/__init__.py b/src/dots_tts/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a4d0dfdc6a265d0c56e3bff22522478b338c5ba3 --- /dev/null +++ b/src/dots_tts/training/__init__.py @@ -0,0 +1 @@ +"""Training package.""" diff --git a/src/dots_tts/training/checkpoint.py b/src/dots_tts/training/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..2c448e8848183aa56831d28947a0efe346aa5c6d --- /dev/null +++ b/src/dots_tts/training/checkpoint.py @@ -0,0 +1,315 @@ +"""Checkpoint helpers for distributed dots_tts training. + +This module persists not only model/optimizer/scheduler state, but also +rank-local RNG state and data-loader progress so resumed training can continue +from the same point with minimal drift. +""" + +from __future__ import annotations + +import json +import random +import shutil +from dataclasses import fields +from pathlib import Path + +import numpy as np +import torch +import torch.distributed as dist + + +def _checkpoint_dir(log_dir: str, step: int) -> Path: + """Return the canonical directory name for a training step checkpoint.""" + return Path(log_dir) / f"checkpoint-{step:08d}" + + +def _checkpoint_entries(log_dir: str) -> list[tuple[int, Path]]: + """List valid ``checkpoint-*`` directories sorted by step number.""" + entries = [] + for path in Path(log_dir).glob("checkpoint-*"): + if not path.is_dir(): + continue + suffix = path.name.removeprefix("checkpoint-") + if suffix.isdigit(): + entries.append((int(suffix), path)) + return sorted(entries) + + +def resolve_latest_train_checkpoint(log_dir: str) -> Path: + """Resolve the checkpoint directory that should be used for resume. + + Preference order: + 1. ``/latest`` symlink, if present. + 2. The numerically largest ``checkpoint-*`` directory. + """ + latest_path = Path(log_dir) / "latest" + if latest_path.exists() or latest_path.is_symlink(): + return latest_path.resolve(strict=True) + + entries = _checkpoint_entries(log_dir) + if not entries: + raise FileNotFoundError( + f"No checkpoint found under {log_dir!s}; expected latest or checkpoint-*." + ) + return entries[-1][1].resolve(strict=True) + + +def _rng_state() -> dict: + """Capture Python/NumPy/PyTorch RNG state for deterministic resume.""" + numpy_state = np.random.get_state() + state = { + "torch": torch.get_rng_state(), + "python": random.getstate(), + "numpy": { + "bit_generator": str(numpy_state[0]), + "keys": numpy_state[1].tolist(), + "pos": int(numpy_state[2]), + "has_gauss": int(numpy_state[3]), + "cached_gaussian": float(numpy_state[4]), + }, + } + if torch.cuda.is_available(): + state["cuda"] = torch.cuda.get_rng_state_all() + return state + + +def _restore_rng_state(state: dict) -> None: + """Restore RNG state previously produced by :func:`_rng_state`.""" + torch.set_rng_state(state["torch"]) + random.setstate(state["python"]) + numpy_state = state["numpy"] + np.random.set_state( + ( + numpy_state["bit_generator"], + np.asarray(numpy_state["keys"], dtype=np.uint32), + int(numpy_state["pos"]), + int(numpy_state["has_gauss"]), + float(numpy_state["cached_gaussian"]), + ) + ) + if torch.cuda.is_available(): + torch.cuda.set_rng_state_all(state["cuda"]) + + +def _replace_latest_symlink(log_dir: str, save_dir: Path) -> None: + """Atomically refresh the ``latest`` symlink to point at ``save_dir``.""" + log_path = Path(log_dir) + link_path = log_path / "latest" + tmp_link_path = log_path / "latest.tmp" + + if tmp_link_path.exists() or tmp_link_path.is_symlink(): + tmp_link_path.unlink() + tmp_link_path.symlink_to(save_dir.name) + + if link_path.exists() or link_path.is_symlink(): + if link_path.is_dir() and not link_path.is_symlink(): + shutil.rmtree(link_path) + else: + link_path.unlink() + tmp_link_path.rename(link_path) + + +def _cleanup_old_checkpoints(log_dir: str, keep_max: int) -> None: + """Delete older checkpoints while keeping the newest ``keep_max`` ones.""" + if keep_max <= 0: + return + for _, path in _checkpoint_entries(log_dir)[:-keep_max]: + shutil.rmtree(path, ignore_errors=True) + + +def _pack_rank_payload(accelerator, payload: dict, *, payload_name: str) -> dict | None: + """Collect rank-local payloads onto the main process for checkpointing. + + Some training state is intentionally local to each rank, for example RNG + state or data-loader shard progress. We therefore gather a per-rank payload + and store it in the checkpoint as ``{world_size, per_rank}``. + """ + local_payload = { + "rank": int(accelerator.process_index), + "payload": payload, + } + if dist.is_available() and dist.is_initialized(): + gathered: list[dict | None] = [None] * int(accelerator.num_processes) + dist.all_gather_object(gathered, local_payload) + else: + gathered = [local_payload] + + if not accelerator.is_main_process: + return None + + per_rank = {} + for item in gathered: + if not isinstance(item, dict): + raise RuntimeError( + f"Failed to gather rank-scoped {payload_name} for checkpointing." + ) + per_rank[str(int(item["rank"]))] = item["payload"] + return { + "world_size": len(gathered), + "per_rank": per_rank, + } + + +def _extract_rank_payload( + accelerator, payload: dict | None, *, payload_name: str +) -> dict: + """Recover the payload for the current rank from a packed checkpoint blob.""" + if payload is None: + return {} + + expected_world_size = int(accelerator.num_processes) + if int(payload["world_size"]) != expected_world_size: + raise RuntimeError( + f"Checkpoint {payload_name} payload does not match the current world." + ) + + local_rank = str(int(accelerator.process_index)) + if local_rank not in payload["per_rank"]: + raise RuntimeError(f"Checkpoint {payload_name} is missing rank {local_rank}.") + return payload["per_rank"][local_rank] + + +def save_train_checkpoint( + accelerator, + model, + optimizer, + progress, + log_dir: str, + keep_max: int, + data_state: dict, + scheduler_state: dict, +) -> None: + """Save a full resumable training checkpoint. + + Stored artifacts include: + - model weights in ``save_pretrained`` format + - optimizer / scheduler / scaler state + - training progress counters + - rank-local RNG state + - rank-local data pipeline state + """ + accelerator.wait_for_everyone() + packed_data_state = _pack_rank_payload( + accelerator, + data_state, + payload_name="data_state", + ) + packed_rng_state = _pack_rank_payload( + accelerator, + _rng_state(), + payload_name="rng_state", + ) + + if accelerator.is_main_process: + unwrapped_model = accelerator.unwrap_model(model) + save_dir = _checkpoint_dir(log_dir, progress.global_step) + tmp_dir = save_dir.with_name(f"{save_dir.name}.tmp") + model_dir = tmp_dir / "model" + scaler = getattr(accelerator, "scaler", None) + + if tmp_dir.exists(): + shutil.rmtree(tmp_dir) + model_dir.mkdir(parents=True, exist_ok=True) + + try: + # Write into a temporary directory first so interrupted saves never + # leave behind a half-written checkpoint that looks valid. + unwrapped_model.save_pretrained(model_dir) + + torch.save(optimizer.state_dict(), tmp_dir / "optimizer.pt") + torch.save(scheduler_state, tmp_dir / "scheduler.pt") + torch.save( + {} if scaler is None else scaler.state_dict(), + tmp_dir / "scaler.pt", + ) + torch.save(packed_rng_state, tmp_dir / "rng_state.pt") + torch.save(packed_data_state, tmp_dir / "data_state.pt") + (tmp_dir / "trainer_state.json").write_text( + json.dumps( + { + field.name: int(getattr(progress, field.name)) + for field in fields(progress) + }, + ensure_ascii=True, + indent=2, + ), + encoding="utf-8", + ) + + if save_dir.exists(): + shutil.rmtree(save_dir) + tmp_dir.rename(save_dir) + _replace_latest_symlink(log_dir, save_dir) + _cleanup_old_checkpoints(log_dir, keep_max) + accelerator.print(f"Checkpoint saved: {save_dir}") + except Exception: + if tmp_dir.exists(): + shutil.rmtree(tmp_dir, ignore_errors=True) + raise + + accelerator.wait_for_everyone() + + +def load_train_checkpoint( + accelerator, + model, + optimizer, + progress, + checkpoint_dir: str | Path, + scheduler, +) -> dict: + """Restore a checkpoint previously written by :func:`save_train_checkpoint`. + + Returns auxiliary state that the caller usually needs to resume the input + pipeline and scheduler bookkeeping. + """ + checkpoint_dir = Path(checkpoint_dir) + model_dir = checkpoint_dir / "model" + if not model_dir.is_dir(): + raise FileNotFoundError(f"Checkpoint model directory not found: {model_dir!s}") + + accelerator.wait_for_everyone() + + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.load_pretrained_weights(model_dir) + + optimizer.load_state_dict( + torch.load(checkpoint_dir / "optimizer.pt", map_location="cpu") + ) + + scheduler_payload = torch.load(checkpoint_dir / "scheduler.pt", map_location="cpu") + scheduler.load_state_dict(scheduler_payload["state_dict"]) + + scaler = getattr(accelerator, "scaler", None) + scaler_state = torch.load(checkpoint_dir / "scaler.pt", map_location="cpu") + if scaler is not None and scaler_state: + scaler.load_state_dict(scaler_state) + + rng_state_payload = torch.load(checkpoint_dir / "rng_state.pt", map_location="cpu") + _restore_rng_state( + _extract_rank_payload( + accelerator, + rng_state_payload, + payload_name="rng_state", + ) + ) + data_state_payload = torch.load( + checkpoint_dir / "data_state.pt", map_location="cpu" + ) + + trainer_state = json.loads( + (checkpoint_dir / "trainer_state.json").read_text(encoding="utf-8") + ) + for field in fields(progress): + setattr(progress, field.name, int(trainer_state[field.name])) + + accelerator.wait_for_everyone() + return { + "checkpoint_dir": checkpoint_dir, + "data_state": _extract_rank_payload( + accelerator, + data_state_payload, + payload_name="data_state", + ), + "scheduler_state": scheduler_payload, + } diff --git a/src/dots_tts/training/losses.py b/src/dots_tts/training/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..32b387373570cee83b91e0b913cfd13d478b4b45 --- /dev/null +++ b/src/dots_tts/training/losses.py @@ -0,0 +1,365 @@ +"""Loss aggregation helpers shared by training and validation. + +The model returns masked per-token / per-patch loss tensors. This module turns +them into numerators/denominators for logging, combines configured loss weights, +and provides distributed reduction helpers. +""" + +from __future__ import annotations + +from collections import defaultdict +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Any, TypeAlias + +import torch +import torch.distributed as dist + +from dots_tts.utils.util import scalar_as_float + + +@dataclass(frozen=True) +class LossTerm: + """A loss tensor paired with a same-shape mask. + + ``loss`` stores unreduced per-element values. + ``mask`` stores the weighting/validity for the same positions. + """ + + loss: torch.Tensor + mask: torch.Tensor + + def __post_init__(self) -> None: + if self.loss.shape != self.mask.shape: + raise ValueError( + "LossTerm expects loss and mask to have the same shape, " + f"but got {tuple(self.loss.shape)} and {tuple(self.mask.shape)}." + ) + + +LossTerms: TypeAlias = dict[str, LossTerm] +LossMasks: TypeAlias = dict[str, torch.Tensor] + + +def _safe_average(numerator: Any, denominator: Any) -> Any: + """Average safely when a mask may produce a zero denominator.""" + if isinstance(numerator, torch.Tensor): + denom = denominator + if not isinstance(denom, torch.Tensor): + denom = numerator.new_tensor(float(denominator)) + if float(denom.detach().item()) <= 0.0: + return numerator * 0.0 + return numerator / denom.clamp_min(1.0).to(numerator.dtype) + + denom = float(denominator) + if denom <= 0.0: + return 0.0 + return float(numerator) / denom + + +def _as_weight_map(loss_config) -> dict[str, float]: + """Extract ``*_weight`` fields from config into ``*_loss`` weights.""" + weights = {} + for name, value in loss_config.model_dump().items(): + if name.endswith("_weight"): + weights[f"{name[:-7]}_loss"] = float(value) + return weights + + +def accumulate_named_scalars_( + target: dict[str, float], + source: dict[str, float], +) -> dict[str, float]: + """In-place add ``source`` scalar values into ``target`` by key.""" + for name, value in source.items(): + target[name] += float(value) + return target + + +def to_host_named_scalars(values: dict[str, Any]) -> dict[str, float]: + """Convert scalar tensors into plain Python floats for logging/serialization.""" + return {name: scalar_as_float(value) for name, value in values.items()} + + +def collapse_loss_masks( + loss_masks: LossMasks, +) -> dict[str, Any]: + """Reduce each loss mask to its total effective weight.""" + return {name: mask.sum() for name, mask in loss_masks.items()} + + +def collapse_loss_terms( + loss_terms: LossTerms, + *, + indices: list[int] | None = None, +) -> tuple[dict[str, Any], dict[str, Any]]: + """Convert masked per-sample loss terms into ``sum(loss*mask)`` statistics. + + Returns ``(numerators, normalizers)`` so the caller can aggregate them across + batches or ranks before taking the final average. + """ + index = None + if indices is not None: + first = next(iter(loss_terms.values())) + index = torch.tensor(indices, device=first.loss.device, dtype=torch.long) + + numerators = {} + normalizers = {} + for name, term in loss_terms.items(): + loss = term.loss + mask = term.mask + if index is not None: + loss = loss.index_select(0, index) + mask = mask.index_select(0, index) + mask = mask.to(loss.dtype) + numerators[name] = (loss * mask).sum() + normalizers[name] = mask.sum() + return numerators, normalizers + + +def collapse_loss_terms_by_source( + loss_terms: LossTerms, + *, + source_names: list[str | None], +) -> tuple[dict[str, dict[str, float]], dict[str, dict[str, float]]]: + """Group collapsed loss statistics by dataset/source name within a batch.""" + first = next(iter(loss_terms.values())) + batch_size = int(first.loss.size(0)) + if len(source_names) != batch_size: + raise RuntimeError( + "source_names must align with the batch size for source loss statistics. " + f"Expected {batch_size}, got {len(source_names)}." + ) + + source_indices: dict[str, list[int]] = defaultdict(list) + for index, source_name in enumerate(source_names): + if source_name is None: + raise RuntimeError("source_names must not contain None.") + source_indices[str(source_name)].append(index) + + numerators_by_source = {} + normalizers_by_source = {} + for source_name, indices in source_indices.items(): + numerators, normalizers = collapse_loss_terms(loss_terms, indices=indices) + numerators_by_source[source_name] = to_host_named_scalars(numerators) + normalizers_by_source[source_name] = to_host_named_scalars(normalizers) + return numerators_by_source, normalizers_by_source + + +def reduce_loss_statistics( + numerators: dict[str, Any], + normalizers: dict[str, Any], + *, + loss_config, +) -> dict[str, Any]: + """Turn aggregated numerators/normalizers into averaged metrics. + + The returned mapping includes each individual loss plus a weighted ``loss`` + field assembled from ``loss_config``. + """ + weights = _as_weight_map(loss_config) + reduced = {} + total_loss: Any = 0.0 + for name, numerator in sorted(numerators.items()): + value = _safe_average(numerator, normalizers[name]) + reduced[name] = value + total_loss = total_loss + value * weights.get(name, 1.0) + reduced["loss"] = total_loss + return reduced + + +def reduce_loss_statistics_by_source( + numerators_by_source: dict[str, dict[str, float]], + normalizers_by_source: dict[str, dict[str, float]], + *, + loss_config, +) -> dict[str, dict[str, Any]]: + """Apply :func:`reduce_loss_statistics` independently for each source.""" + return { + source_name: reduce_loss_statistics( + numerators, + normalizers_by_source[source_name], + loss_config=loss_config, + ) + for source_name, numerators in sorted(numerators_by_source.items()) + } + + +def compute_gradient_loss( + loss_terms: LossTerms, + *, + global_normalizers: dict[str, float], + loss_config, + ddp_world_size: int, + gradient_accumulation_steps: int, +) -> torch.Tensor: + """Build the scalar loss used for ``backward()``. + + ``global_normalizers`` is expected to already include cross-rank totals. The + final scaling by world size and accumulation steps compensates for the mean + reduction that DDP/Accelerate applies during gradient synchronization. + """ + numerators, _ = collapse_loss_terms(loss_terms) + weights = _as_weight_map(loss_config) + + total_loss: Any = 0.0 + for name, numerator in sorted(numerators.items()): + total_loss = total_loss + _safe_average( + numerator, + global_normalizers[name], + ) * weights.get(name, 1.0) + return total_loss * float(ddp_world_size) * float(gradient_accumulation_steps) + + +def sum_named_scalars_across_ranks( + values: dict[str, float], + *, + device: torch.device, +) -> dict[str, float]: + """All-reduce a dict of scalar values and return summed host floats.""" + names = _gather_string_union_across_ranks(values, device=device) + if not names: + return {} + + packed = torch.tensor( + [float(values.get(name, 0.0)) for name in names], + device=device, + dtype=torch.float64, + ) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(packed, op=dist.ReduceOp.SUM) + return { + name: float(value) + for name, value in zip(names, packed.tolist(), strict=True) + } + + +def sum_grouped_named_scalars_across_ranks( + values: dict[str, dict[str, float]], + *, + device: torch.device, +) -> dict[str, dict[str, float]]: + """All-reduce nested ``group -> metric -> value`` scalar mappings.""" + group_names = _gather_string_union_across_ranks(values, device=device) + if not group_names: + return {} + + metric_names = _gather_string_union_across_ranks( + ( + metric_name + for group_values in values.values() + for metric_name in group_values + ), + device=device, + ) + if not metric_names: + return {group_name: {} for group_name in group_names} + + packed = torch.tensor( + [ + float(values.get(group_name, {}).get(metric_name, 0.0)) + for group_name in group_names + for metric_name in metric_names + ], + device=device, + dtype=torch.float64, + ) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(packed, op=dist.ReduceOp.SUM) + packed = packed.view(len(group_names), len(metric_names)) + return { + group_name: { + metric_name: float(packed[group_index, metric_index].item()) + for metric_index, metric_name in enumerate(metric_names) + } + for group_index, group_name in enumerate(group_names) + } + + +def accumulate_grouped_named_scalars_( + target: dict[str, dict[str, float]], + source: dict[str, dict[str, float]], +) -> dict[str, dict[str, float]]: + """In-place add nested ``group -> metric -> value`` scalar mappings.""" + for group_name, values in source.items(): + group_target = target.get(group_name) + if group_target is None: + group_target = {name: 0.0 for name in values} + target[group_name] = group_target + for name, value in values.items(): + group_target[name] += float(value) + return target + + +def _gather_string_union_across_ranks( + values: Iterable[str], + *, + device: torch.device, +) -> list[str]: + strings = sorted({str(value) for value in values}) + if not (dist.is_available() and dist.is_initialized()): + return strings + + payload = _encode_string_list(strings) + world_size = dist.get_world_size() + local_size = torch.tensor([len(payload)], device=device, dtype=torch.int64) + size_tensors = [torch.zeros_like(local_size) for _ in range(world_size)] + dist.all_gather(size_tensors, local_size) + + max_size = max(int(size.item()) for size in size_tensors) + if max_size <= 0: + return [] + + local_bytes = torch.zeros(max_size, device=device, dtype=torch.uint8) + if payload: + local_bytes[: len(payload)] = torch.tensor( + list(payload), + device=device, + dtype=torch.uint8, + ) + + gathered_bytes = [ + torch.empty(max_size, device=device, dtype=torch.uint8) + for _ in range(world_size) + ] + dist.all_gather(gathered_bytes, local_bytes) + + union = set(strings) + for size_tensor, byte_tensor in zip(size_tensors, gathered_bytes, strict=True): + size = int(size_tensor.item()) + if size <= 0: + continue + union.update( + _decode_string_list(bytes(byte_tensor[:size].cpu().tolist())) + ) + return sorted(union) + + +def _encode_string_list(values: list[str]) -> bytes: + if any("\0" in value for value in values): + raise ValueError("Distributed scalar keys must not contain NUL characters.") + return "\0".join(values).encode("utf-8") + + +def _decode_string_list(payload: bytes) -> list[str]: + if not payload: + return [] + return payload.decode("utf-8").split("\0") + + +__all__ = [ + "accumulate_grouped_named_scalars_", + "LossMasks", + "LossTerm", + "LossTerms", + "accumulate_named_scalars_", + "collapse_loss_masks", + "collapse_loss_terms", + "collapse_loss_terms_by_source", + "compute_gradient_loss", + "reduce_loss_statistics", + "reduce_loss_statistics_by_source", + "sum_grouped_named_scalars_across_ranks", + "sum_named_scalars_across_ranks", + "to_host_named_scalars", +] diff --git a/src/dots_tts/training/utils.py b/src/dots_tts/training/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..74bbad852efd4cafd2abb4ecb6c8e525eb845a62 --- /dev/null +++ b/src/dots_tts/training/utils.py @@ -0,0 +1,650 @@ +"""Shared helpers for the dots_tts training entrypoints.""" + +from __future__ import annotations + +import math +import os +import sys +import traceback +from collections import Counter +from dataclasses import dataclass, fields, is_dataclass +from typing import Any + +import torch +import torch.distributed as dist + +from dots_tts.training import losses as loss_ops + +# --------------------------------------------------------------------------- +# Training State +# --------------------------------------------------------------------------- + + +@dataclass(slots=True) +class TrainProgress: + """Minimal progress counters that must survive checkpoint save/load.""" + + global_step: int = 0 + epoch: int = 0 + total_tokens: int = 0 + audio_tokens: int = 0 + text_tokens: int = 0 + + +@dataclass(slots=True) +class TrainStepReport: + log_values: dict[str, float] + console_line: str + + +# --------------------------------------------------------------------------- +# Distributed Helpers +# --------------------------------------------------------------------------- + + +def any_rank_true(flag: bool, *, device: torch.device) -> bool: + """Return ``True`` if any distributed rank reports ``flag=True``.""" + packed = torch.tensor(int(flag), device=device, dtype=torch.int32) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(packed, op=dist.ReduceOp.MAX) + return bool(packed.item()) + + +def sum_integer_counters_across_ranks( + values: list[int], + *, + device: torch.device, +) -> list[int]: + """All-reduce integer counters and return their cross-rank sums.""" + packed = torch.tensor(values, device=device, dtype=torch.int64) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(packed, op=dist.ReduceOp.SUM) + return [int(value) for value in packed.tolist()] + + +def move_to_device(value, device): + """Recursively move nested tensors/dataclasses onto ``device``.""" + if isinstance(value, torch.Tensor): + return value.to(device, non_blocking=True) + if isinstance(value, dict): + return {key: move_to_device(item, device) for key, item in value.items()} + if isinstance(value, list): + return [move_to_device(item, device) for item in value] + if isinstance(value, tuple): + return tuple(move_to_device(item, device) for item in value) + if is_dataclass(value) and not isinstance(value, type): + return type(value)( + **{ + field.name: move_to_device(getattr(value, field.name), device) + for field in fields(value) + } + ) + return value + + +# --------------------------------------------------------------------------- +# Failure Handling +# --------------------------------------------------------------------------- + + +def abort_on_out_of_memory( + exc: BaseException, + *, + stage: str, + batch: dict[str, object] | None, + progress: TrainProgress, + device: torch.device, + process_index: int, + num_processes: int, +) -> None: + if not _is_out_of_memory_error(exc): + return + + message = ( + "Fatal out-of-memory during training. " + f"stage={stage}, " + f"epoch={progress.epoch}, " + f"global_step={progress.global_step}, " + f"rank={process_index}/{num_processes}. " + f"{_build_batch_memory_summary(batch)}. " + f"{_build_cuda_memory_summary(device)}." + ) + print(message, file=sys.stderr, flush=True) + traceback.print_exception(type(exc), exc, exc.__traceback__, file=sys.stderr) + sys.stderr.flush() + + if num_processes > 1: + os._exit(1) + + +def _is_out_of_memory_error(exc: BaseException) -> bool: + oom_error_type = getattr(torch, "OutOfMemoryError", None) + if oom_error_type is not None and isinstance(exc, oom_error_type): + return True + if not isinstance(exc, RuntimeError): + return False + return "out of memory" in str(exc).lower() + + +def _build_batch_memory_summary(batch: dict[str, object] | None) -> str: + if not isinstance(batch, dict): + return "batch=unavailable" + + fields = [] + input_ids = batch.get("input_ids") + if isinstance(input_ids, torch.Tensor): + fields.append(f"input_ids_shape={tuple(input_ids.shape)}") + sample = batch.get("sample") + if isinstance(sample, torch.Tensor): + fields.append(f"sample_shape={tuple(sample.shape)}") + input_ids_lengths = batch.get("input_ids_lengths") + if isinstance(input_ids_lengths, torch.Tensor) and input_ids_lengths.numel() > 0: + fields.append( + f"max_input_ids_length={int(input_ids_lengths.max().detach().item())}" + ) + num_audio_tokens = batch.get("num_audio_tokens") + if isinstance(num_audio_tokens, torch.Tensor) and num_audio_tokens.numel() > 0: + fields.append(f"max_audio_tokens={int(num_audio_tokens.max().detach().item())}") + num_text_tokens = batch.get("num_text_tokens") + if isinstance(num_text_tokens, torch.Tensor) and num_text_tokens.numel() > 0: + fields.append(f"max_text_tokens={int(num_text_tokens.max().detach().item())}") + return ", ".join(fields) if fields else "batch=unavailable" + + +def _build_cuda_memory_summary(device: torch.device) -> str: + if device.type != "cuda" or not torch.cuda.is_available(): + return "device_memory=unavailable" + allocated = torch.cuda.memory_allocated(device) / (1024**3) + reserved = torch.cuda.memory_reserved(device) / (1024**3) + max_allocated = torch.cuda.max_memory_allocated(device) / (1024**3) + max_reserved = torch.cuda.max_memory_reserved(device) / (1024**3) + return ( + f"device={device}, " + f"allocated_gb={allocated:.2f}, " + f"reserved_gb={reserved:.2f}, " + f"max_allocated_gb={max_allocated:.2f}, " + f"max_reserved_gb={max_reserved:.2f}" + ) + + +# --------------------------------------------------------------------------- +# Debug Helpers +# --------------------------------------------------------------------------- + + +def build_data_debug_lines( + batch: dict[str, object], + *, + batch_index: int, + tokenizer: Any, + sample_rate: int, +) -> list[str]: + input_ids = batch["input_ids"] + input_ids_lengths = batch["input_ids_lengths"] + sample = batch["sample"] + sample_lengths = batch["sample_lengths"] + num_audio_tokens = batch["num_audio_tokens"] + num_text_tokens = batch["num_text_tokens"] + + if not isinstance(input_ids, torch.Tensor) or not isinstance( + input_ids_lengths, torch.Tensor + ): + raise TypeError("Debug batch requires tensor input_ids and input_ids_lengths.") + if not isinstance(sample, torch.Tensor) or not isinstance( + sample_lengths, torch.Tensor + ): + raise TypeError("Debug batch requires tensor sample and sample_lengths.") + if not isinstance(num_audio_tokens, torch.Tensor) or not isinstance( + num_text_tokens, torch.Tensor + ): + raise TypeError( + "Debug batch requires tensor num_audio_tokens and num_text_tokens." + ) + + source_names = batch.get("source_names") + debug_lines = [ + ( + "[debug:data] " + f"batch_index={batch_index} " + f"batch_size={int(input_ids.size(0))} " + f"input_ids_shape={tuple(input_ids.shape)} " + f"sample_shape={tuple(sample.shape)} " + f"sample_rate={sample_rate} " + f"sources={dict(Counter(source_names or []))}" + ), + ( + "[debug:data] " + f"input_tokens(min/mean/max)={_format_tensor_triplet(input_ids_lengths)} " + f"text_tokens(min/mean/max)={_format_tensor_triplet(num_text_tokens)} " + f"audio_tokens(min/mean/max)={_format_tensor_triplet(num_audio_tokens)} " + f"audio_seconds(min/mean/max)={_format_audio_seconds_triplet(sample_lengths, sample_rate)}" + ), + ] + + fbank = batch.get("fbank") + fbank_lengths = batch.get("fbank_lengths") + if isinstance(fbank, torch.Tensor): + debug_lines.append( + "[debug:data] " + f"fbank_shape={tuple(fbank.shape)} " + f"fbank_frames(min/mean/max)={_format_tensor_triplet(fbank_lengths)}" + ) + + loss_masks = batch.get("loss_masks") + if isinstance(loss_masks, dict): + debug_lines.append( + "[debug:data] " + "loss_masks=" + + ", ".join( + f"{name}:{_format_mask_density(mask)}" + for name, mask in sorted(loss_masks.items()) + ) + ) + + fids = batch.get("fids") or [] + sample_count = min(int(input_ids.size(0)), 3) + for sample_idx in range(sample_count): + input_length = int(input_ids_lengths[sample_idx].item()) + audio_length = int(sample_lengths[sample_idx].item()) + fbank_shape = "unavailable" + if isinstance(fbank, torch.Tensor) and isinstance(fbank_lengths, torch.Tensor): + fbank_shape = ( + f"({int(fbank_lengths[sample_idx].item())}, {int(fbank.size(-1))})" + ) + debug_lines.append( + "[debug:data] " + f"sample_index={sample_idx} " + f"fid={str(fids[sample_idx]) if sample_idx < len(fids) else f'sample_{sample_idx:02d}'} " + f"source_name={source_names[sample_idx] if source_names else None} " + f"input_ids_shape=({input_length},) " + f"sample_shape=(1, {audio_length}) " + f"fbank_shape={fbank_shape} " + f"num_text_tokens={int(num_text_tokens[sample_idx].item())} " + f"num_audio_tokens={int(num_audio_tokens[sample_idx].item())} " + f"audio_seconds={audio_length / float(sample_rate):.2f} " + "text=" + f"{tokenizer.decode(input_ids[sample_idx, :input_length].detach().cpu().tolist(), skip_special_tokens=False, clean_up_tokenization_spaces=False)!r}" + ) + return debug_lines + + +def should_print_gradient_debug( + *, + debug_enabled: bool, + is_main_process: bool, + next_global_step: int, + log_interval: int, + early_step_limit: int, +) -> bool: + return bool( + debug_enabled + and is_main_process + and ( + next_global_step <= early_step_limit + or next_global_step % log_interval == 0 + ) + ) + + +def build_gradient_debug_lines( + model: torch.nn.Module, + *, + global_step: int, + grad_norm: float, + grad_clip_norm: float, +) -> list[str]: + top_param_candidates: list[tuple[str, float, float, float]] = [] + nonfinite_grad_params: list[str] = [] + nonfinite_param_count = 0 + params_with_grad = 0 + params_without_grad = 0 + abs_sum = 0.0 + abs_count = 0 + max_abs_grad = 0.0 + + for name, parameter in model.named_parameters(): + if not parameter.requires_grad: + continue + grad = parameter.grad + if grad is None: + params_without_grad += 1 + continue + + grad_tensor = grad.detach().float() + params_with_grad += 1 + if not bool(torch.isfinite(grad_tensor).all().item()): + nonfinite_param_count += 1 + if len(nonfinite_grad_params) < 8: + nonfinite_grad_params.append(name) + grad_abs = grad_tensor.abs() + param_norm = float(torch.linalg.vector_norm(grad_tensor).item()) + param_max_abs = float(grad_abs.max().item()) + param_mean_abs = float(grad_abs.mean().item()) + max_abs_grad = max(max_abs_grad, param_max_abs) + abs_sum += float(grad_abs.sum().item()) + abs_count += int(grad_abs.numel()) + top_param_candidates.append((name, param_norm, param_max_abs, param_mean_abs)) + + mean_abs_grad = math.nan if abs_count == 0 else abs_sum / float(abs_count) + top_param_norms = sorted( + top_param_candidates, + key=lambda item: item[1], + reverse=True, + )[:6] + + debug_lines = [ + ( + "[debug:grad] " + f"step={global_step} " + f"pre_clip_grad_norm={format_scalar(grad_norm)} " + f"clip_ratio={format_scalar(_safe_grad_clip_ratio(grad_norm, grad_clip_norm))} " + f"params_with_grad={params_with_grad} " + f"params_without_grad={params_without_grad} " + f"nonfinite_param_count={nonfinite_param_count} " + f"max_abs_grad={format_scalar(max_abs_grad)} " + f"mean_abs_grad={format_scalar(mean_abs_grad)}" + ) + ] + if top_param_norms: + debug_lines.append( + "[debug:grad] top_params=" + + ", ".join( + ( + f"{name}:{param_norm:.4f}" + f"(max={param_max_abs:.4e},mean={param_mean_abs:.4e})" + ) + for name, param_norm, param_max_abs, param_mean_abs in top_param_norms + ) + ) + if nonfinite_grad_params: + debug_lines.append( + "[debug:grad] nonfinite_params=" + ", ".join(nonfinite_grad_params) + ) + return debug_lines + + +def _format_tensor_triplet(values: object) -> str: + if not isinstance(values, torch.Tensor) or values.numel() == 0: + return "n/a" + flattened = values.detach().cpu().to(torch.float32) + return ( + f"{int(flattened.min().item())}/" + f"{flattened.mean().item():.2f}/" + f"{int(flattened.max().item())}" + ) + + +def _format_audio_seconds_triplet(values: object, sample_rate: int) -> str: + if not isinstance(values, torch.Tensor) or values.numel() == 0: + return "n/a" + seconds = values.detach().cpu().to(torch.float32) / float(sample_rate) + return ( + f"{seconds.min().item():.2f}/" + f"{seconds.mean().item():.2f}/" + f"{seconds.max().item():.2f}" + ) + + +def _format_mask_density(mask: object) -> str: + if not isinstance(mask, torch.Tensor) or mask.numel() == 0: + return "n/a" + return f"{int(mask.detach().gt(0).sum().item())}/{int(mask.numel())}" + + +def _safe_grad_clip_ratio(grad_norm: float, grad_clip_norm: float) -> float: + if not math.isfinite(grad_norm): + return math.nan + return grad_norm / float(grad_clip_norm) + + +# --------------------------------------------------------------------------- +# Step Reporting +# --------------------------------------------------------------------------- + + +def should_log_training_step(global_step: int, log_interval: int) -> bool: + return global_step % log_interval == 0 + + +def reduce_source_metrics( + source_loss_totals: dict[str, dict[str, float]], + source_loss_denominators: dict[str, dict[str, float]], + *, + device: torch.device, + loss_config: Any, +) -> dict[str, dict[str, float]]: + reduced_source_totals = loss_ops.sum_grouped_named_scalars_across_ranks( + source_loss_totals, + device=device, + ) + reduced_source_denominators = loss_ops.sum_grouped_named_scalars_across_ranks( + source_loss_denominators, + device=device, + ) + return loss_ops.reduce_loss_statistics_by_source( + reduced_source_totals, + reduced_source_denominators, + loss_config=loss_config, + ) + + +def build_train_step_report( + metrics: dict[str, Any], + *, + learning_rate: float, + grad_norm: float, + current_time: float, + last_log_step: int, + last_log_time: float, + progress: TrainProgress, + max_train_steps: int, + reduced_by_source: dict[str, dict[str, float]], +) -> TrainStepReport: + logged_steps = progress.global_step - last_log_step + elapsed = current_time - last_log_time + steps_per_second = ( + math.nan + if logged_steps <= 0 or elapsed <= 0.0 + else float(logged_steps) / elapsed + ) + eta_seconds = ( + math.nan + if not math.isfinite(steps_per_second) or steps_per_second <= 0.0 + else float(max_train_steps - progress.global_step) / steps_per_second + ) + return TrainStepReport( + log_values=build_train_log_dict( + metrics, + learning_rate=learning_rate, + grad_norm=grad_norm, + steps_per_second=steps_per_second, + eta_seconds=eta_seconds, + progress=progress, + reduced_by_source=reduced_by_source, + ), + console_line=format_train_line( + metrics, + learning_rate=learning_rate, + grad_norm=grad_norm, + steps_per_second=steps_per_second, + eta_seconds=eta_seconds, + progress=progress, + max_train_steps=max_train_steps, + reduced_by_source=reduced_by_source, + ), + ) + + +# --------------------------------------------------------------------------- +# Formatting Helpers +# --------------------------------------------------------------------------- + + +def flatten_config(values, parent_key="", sep="/"): + """Flatten a nested config dict into ``path/to/key -> value`` pairs.""" + items = [] + for key, value in values.items(): + new_key = f"{parent_key}{sep}{key}" if parent_key else key + if isinstance(value, dict): + items.extend(flatten_config(value, new_key, sep).items()) + elif isinstance(value, (list, tuple)): + items.append((new_key, str(value))) + elif value is None: + items.append((new_key, "None")) + else: + items.append((new_key, value)) + return dict(items) + + +def format_scalar(value: float) -> str: + """Format a scalar for concise console logging.""" + if not math.isfinite(value): + return "nan" + if float(value).is_integer(): + return str(int(value)) + return f"{value:.4f}" + + +def _format_eta(eta_seconds: float) -> str: + """Render ETA seconds as ``HH:MM:SS`` or ``n/a``.""" + if not math.isfinite(eta_seconds) or eta_seconds < 0.0: + return "n/a" + total_seconds = int(round(eta_seconds)) + hours, remainder = divmod(total_seconds, 3600) + minutes, seconds = divmod(remainder, 60) + return f"{hours:02d}:{minutes:02d}:{seconds:02d}" + + +def build_train_log_dict( + metrics: dict[str, Any], + *, + learning_rate: float, + grad_norm: float, + steps_per_second: float, + eta_seconds: float, + progress: TrainProgress, + reduced_by_source: dict[str, dict[str, Any]], +) -> dict[str, float]: + """Build the flat metric dict sent to experiment trackers.""" + log_dict = { + "train/epoch": float(progress.epoch), + "train/learning_rate": learning_rate, + "train/grad_norm": grad_norm, + "train/steps_per_second": steps_per_second, + "train/eta_seconds": eta_seconds, + "train/consumed_tokens": float(progress.total_tokens), + "train/consumed_audio_tokens": float(progress.audio_tokens), + "train/consumed_text_tokens": float(progress.text_tokens), + } + for name, value in metrics.items(): + log_dict[f"train/{name}"] = float(value) + for source_name, source_metrics in reduced_by_source.items(): + log_dict.update( + { + f"train/{source_name}/{name}": float(value) + for name, value in source_metrics.items() + } + ) + return log_dict + + +def format_train_line( + metrics: dict[str, Any], + *, + learning_rate: float, + grad_norm: float, + steps_per_second: float, + eta_seconds: float, + progress: TrainProgress, + max_train_steps: int, + reduced_by_source: dict[str, dict[str, Any]], +) -> str: + """Build a single human-readable console line for one training step.""" + parts = [ + f"iteration {progress.global_step}/{max_train_steps}", + f"epoch: {progress.epoch}", + f"consumed_tokens: {progress.total_tokens}", + f"consumed_audio_tokens: {progress.audio_tokens}", + f"consumed_text_tokens: {progress.text_tokens}", + f"learning_rate: {learning_rate:.2e}", + f"steps_per_second: {format_scalar(steps_per_second)}", + f"job_eta: {_format_eta(eta_seconds)}", + f"grad_norm: {format_scalar(grad_norm)}", + ] + for name in sorted(name for name in metrics if name != "loss"): + parts.append(f"{name}: {format_scalar(float(metrics[name]))}") + if "loss" in metrics: + parts.append(f"loss: {format_scalar(float(metrics['loss']))}") + for source_name, source_metrics in reduced_by_source.items(): + for name in sorted(name for name in source_metrics if name != "loss"): + parts.append( + f"{source_name}_{name}: {format_scalar(float(source_metrics[name]))}" + ) + if "loss" in source_metrics: + parts.append( + f"{source_name}_loss: {format_scalar(float(source_metrics['loss']))}" + ) + return " | ".join(parts) + + +def build_validation_log_dict( + metrics: dict[str, Any], + *, + reduced_by_source: dict[str, dict[str, Any]], +) -> dict[str, float]: + """Build the flat validation metric dict sent to experiment trackers.""" + log_dict = {f"val/{name}": float(value) for name, value in metrics.items()} + for source_name, source_metrics in reduced_by_source.items(): + log_dict.update( + { + f"val/{source_name}/{name}": float(value) + for name, value in source_metrics.items() + } + ) + return log_dict + + +def format_validation_line( + metrics: dict[str, Any], + *, + global_step: int, + reduced_by_source: dict[str, dict[str, Any]], +) -> str: + """Build the console summary line printed after a validation pass.""" + parts = [f"validation at iteration {global_step}"] + for name in sorted(name for name in metrics if name != "loss"): + parts.append(f"{name}: {format_scalar(float(metrics[name]))}") + if "loss" in metrics: + parts.append(f"loss: {format_scalar(float(metrics['loss']))}") + for source_name, source_metrics in reduced_by_source.items(): + for name in sorted(name for name in source_metrics if name != "loss"): + parts.append( + f"{source_name}_{name}: {format_scalar(float(source_metrics[name]))}" + ) + if "loss" in source_metrics: + parts.append( + f"{source_name}_loss: {format_scalar(float(source_metrics['loss']))}" + ) + return " | ".join(parts) + + +__all__ = [ + "TrainProgress", + "TrainStepReport", + "abort_on_out_of_memory", + "any_rank_true", + "build_data_debug_lines", + "build_gradient_debug_lines", + "build_train_step_report", + "build_train_log_dict", + "build_validation_log_dict", + "flatten_config", + "format_scalar", + "format_train_line", + "format_validation_line", + "move_to_device", + "reduce_source_metrics", + "should_log_training_step", + "should_print_gradient_debug", + "sum_integer_counters_across_ranks", +] diff --git a/src/dots_tts/utils/__init__.py b/src/dots_tts/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/dots_tts/utils/audio.py b/src/dots_tts/utils/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..3e2202b650d864c8b492b3693043ce254efcff8f --- /dev/null +++ b/src/dots_tts/utils/audio.py @@ -0,0 +1,45 @@ +"""Audio helpers used by the retained train/infer pipeline.""" + +from __future__ import annotations + +import torch +import torchaudio.compliance.kaldi as Kaldi +import torchaudio.functional as AF + + +def high_quality_resample(x, orig_sr, target_sr): + return AF.resample( + x, + orig_freq=orig_sr, + new_freq=target_sr, + lowpass_filter_width=64, + rolloff=0.95, + resampling_method="sinc_interp_kaiser", + ) + + +def extract_fbank( + waveform: torch.Tensor, + *, + sample_rate: int, + n_mels: int, + dither: float = 0.0, + mean_norm: bool = False, +) -> torch.Tensor: + if waveform.ndim == 1: + feature_input = waveform.unsqueeze(0) + elif waveform.ndim == 2: + feature_input = waveform if waveform.size(0) == 1 else waveform[0:1, :] + else: + raise ValueError( + f"FBank expects a 1D or 2D waveform, got shape {tuple(waveform.shape)}." + ) + features = Kaldi.fbank( + feature_input, + num_mel_bins=n_mels, + sample_frequency=sample_rate, + dither=dither, + ) + if mean_norm: + features = features - features.mean(dim=0, keepdim=True) + return features diff --git a/src/dots_tts/utils/logging.py b/src/dots_tts/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..f1e7bcdf983d2cab32ba8d7825da9daff76ab954 --- /dev/null +++ b/src/dots_tts/utils/logging.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import os +import sys +from pathlib import Path + +from loguru import logger + +DEFAULT_LOG_LEVEL = "INFO" +DEFAULT_LOG_FORMAT = ( + "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level:<8} | " + "{name}:{function}:{line} | {message}" +) + + +def configure_logging( + *, + level: str | None = None, + log_file: str | os.PathLike[str] | None = None, +) -> None: + resolved_level = (level or os.environ.get("DOTS_TTS_LOG_LEVEL") or DEFAULT_LOG_LEVEL).upper() + logger.remove() + logger.add( + sys.stderr, + level=resolved_level, + format=DEFAULT_LOG_FORMAT, + backtrace=True, + diagnose=False, + enqueue=False, + ) + if log_file: + log_path = Path(log_file).expanduser() + log_path.parent.mkdir(parents=True, exist_ok=True) + logger.add( + log_path, + level=resolved_level, + format=DEFAULT_LOG_FORMAT, + backtrace=True, + diagnose=False, + enqueue=False, + encoding="utf-8", + ) diff --git a/src/dots_tts/utils/profiling.py b/src/dots_tts/utils/profiling.py new file mode 100644 index 0000000000000000000000000000000000000000..bf5a54c0862542a9b11fcbfeae3ce79700d916de --- /dev/null +++ b/src/dots_tts/utils/profiling.py @@ -0,0 +1,215 @@ +from __future__ import annotations + +import os +import time +from contextlib import contextmanager +from contextvars import ContextVar, Token +from dataclasses import dataclass +from multiprocessing import Queue +from typing import Iterator + +import torch +from loguru import logger + +INFERENCE_STAGE_NAMES = ( + "FM", + "latent_encoder", + "patch_encoder", + "LLM", + "latent_decoder", + "speaker_encoder", + "vocoder", +) + +_INFERENCE_STAGE_NAME_MAP = { + name.lower(): name for name in INFERENCE_STAGE_NAMES +} +_CURRENT_INFERENCE_PROFILER: ContextVar[InferenceProfiler | None] = ContextVar( + "current_inference_profiler", + default=None, +) + + +def normalize_inference_stage_name(name: str) -> str: + canonical = _INFERENCE_STAGE_NAME_MAP.get(name.strip().lower()) + if canonical is None: + raise ValueError( + f"Unsupported inference stage '{name}'. " + f"Expected one of: {', '.join(INFERENCE_STAGE_NAMES)}." + ) + return canonical + + +@dataclass(slots=True) +class InferenceStageStat: + seconds: float = 0.0 + count: int = 0 + + +@dataclass(frozen=True, slots=True) +class ProfileEvent: + stage: str + seconds: float + count: int + pid: int + + +class DataProfiler: + def __init__(self, queue: Queue | None = None): + self._queue = queue + self._pid = os.getpid() + + @property + def enabled(self) -> bool: + return self._queue is not None + + @contextmanager + def measure(self, stage: str, *, count: int = 1) -> Iterator[None]: + if self._queue is None: + yield + return + start = time.perf_counter() + try: + yield + finally: + self._queue.put( + ProfileEvent( + stage=stage, + seconds=time.perf_counter() - start, + count=int(count), + pid=self._pid, + ) + ) + + def child(self) -> DataProfiler: + return DataProfiler(self._queue) + + +def ensure_data_profiler(profiler: DataProfiler | None) -> DataProfiler: + return DataProfiler() if profiler is None else profiler + + +class InferenceProfiler: + def __init__(self, device: torch.device): + self._device = device + self._stats = { + stage: InferenceStageStat() for stage in INFERENCE_STAGE_NAMES + } + + def _sync(self) -> None: + if self._device.type == "cuda": + torch.cuda.synchronize(self._device) + + @contextmanager + def measure(self, stage: str, *, count: int = 1) -> Iterator[None]: + stage = normalize_inference_stage_name(stage) + self._sync() + start = time.perf_counter() + try: + yield + finally: + self._sync() + stat = self._stats[stage] + stat.seconds += time.perf_counter() - start + stat.count += int(count) + + def summary( + self, + *, + duration_seconds: float | None = None, + ) -> dict[str, dict[str, float | int]]: + summary: dict[str, dict[str, float | int]] = {} + for stage in INFERENCE_STAGE_NAMES: + stat = self._stats[stage] + payload: dict[str, float | int] = { + "seconds": stat.seconds, + "count": stat.count, + } + if duration_seconds is not None: + payload["rtf"] = ( + stat.seconds / duration_seconds + if duration_seconds > 0 + else float("inf") + ) + summary[stage] = payload + return summary + + +@contextmanager +def inference_profiling( + *, + enabled: bool, + device: torch.device, +) -> Iterator[InferenceProfiler | None]: + profiler = InferenceProfiler(device) if enabled else None + with activate_inference_profiler(profiler): + yield profiler + + +@contextmanager +def activate_inference_profiler( + profiler: InferenceProfiler | None, +) -> Iterator[InferenceProfiler | None]: + if profiler is None: + yield None + return + token: Token[InferenceProfiler | None] = _CURRENT_INFERENCE_PROFILER.set(profiler) + try: + yield profiler + finally: + _CURRENT_INFERENCE_PROFILER.reset(token) + + +@contextmanager +def measure_inference(stage: str, *, count: int = 1) -> Iterator[None]: + profiler = _CURRENT_INFERENCE_PROFILER.get() + if profiler is None: + yield + return + with profiler.measure(stage, count=count): + yield + + +def log_inference_profile( + *, + request_id: str, + profiling: dict[str, dict[str, float | int]], + duration_seconds: float, +) -> None: + active_stages = [ + stage + for stage in INFERENCE_STAGE_NAMES + if int(profiling[stage]["count"]) > 0 + ] + if not active_stages: + logger.info( + "Inference profiling summary: request_id={} no_profiled_stages duration_seconds={:.3f}", + request_id, + duration_seconds, + ) + return + for stage in active_stages: + stats = profiling[stage] + logger.info( + "Inference profiling: request_id={} stage={} seconds={:.4f} count={} rtf={:.4f}", + request_id, + stage, + float(stats["seconds"]), + int(stats["count"]), + float(stats["rtf"]), + ) + + +__all__ = [ + "DataProfiler", + "ProfileEvent", + "INFERENCE_STAGE_NAMES", + "activate_inference_profiler", + "ensure_data_profiler", + "InferenceProfiler", + "InferenceStageStat", + "inference_profiling", + "log_inference_profile", + "measure_inference", + "normalize_inference_stage_name", +] diff --git a/src/dots_tts/utils/text.py b/src/dots_tts/utils/text.py new file mode 100644 index 0000000000000000000000000000000000000000..ccda50f61412718f42193b3f8852189e96dff04e --- /dev/null +++ b/src/dots_tts/utils/text.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import re +from functools import lru_cache +from typing import Literal + +from langcodes import Language as LangcodesLanguage +from lingua import Language, LanguageDetectorBuilder +from tn.chinese.normalizer import Normalizer as ZhNormalizer +from tn.english.normalizer import Normalizer as EnNormalizer + +TextLanguage = Literal["zh", "en", "unknown"] + +_WHITESPACE_PATTERN = re.compile(r"\s+") + + +@lru_cache(maxsize=1) +def get_chinese_text_normalizer() -> ZhNormalizer: + return ZhNormalizer() + + +@lru_cache(maxsize=1) +def get_english_text_normalizer() -> EnNormalizer: + return EnNormalizer() + + +@lru_cache(maxsize=1) +def get_language_detector(): + supported_languages = tuple( + sorted(Language.all(), key=lambda language: language.name) + ) + return LanguageDetectorBuilder.from_languages(*supported_languages).build() + + +def _lingua_language_to_code(language: Language | None) -> str | None: + if language is None: + return None + iso_code_639_1 = getattr(language.iso_code_639_1, "name", None) + if iso_code_639_1: + return iso_code_639_1.lower() + iso_code_639_3 = getattr(language.iso_code_639_3, "name", None) + if iso_code_639_3: + return iso_code_639_3.lower() + return language.name.lower() + + +def detect(text: str) -> str | None: + stripped = text.strip() + if not stripped: + return None + language = get_language_detector().detect_language_of(stripped) + return _lingua_language_to_code(language) + + +def normalize_language_code(language: str | None) -> str | None: + if language is None: + return None + + stripped = language.strip() + if not stripped or stripped.lower() in {"none", "unknown"}: + return None + if stripped.startswith("口音:"): + return stripped + + for resolver in (LangcodesLanguage.get, LangcodesLanguage.find): + try: + normalized_language = resolver(stripped).prefer_macrolanguage() + except Exception: + continue + + language_code = (normalized_language.language or "").strip().upper() + if language_code and language_code != "UND": + return language_code + return None + + +def attach_language_tag(text: str, language: str | None) -> str: + if not text: + return text + + language_code = normalize_language_code(language) + if language_code is None: + return text + + if language_code == "YUE": + language_code = "口音:粤语" + + language_tag = f"[{language_code}]" + if text.startswith(language_tag): + return text + return f"{language_tag}{text}" + + +def detect_text_language(text: str) -> TextLanguage: + language_code = detect(text) + if language_code == "zh": + return "zh" + if language_code == "en": + return "en" + return "unknown" + + +def _normalize_with(normalizer, text: str) -> str: + normalized = normalizer.normalize(text) + return _WHITESPACE_PATTERN.sub(" ", normalized).strip() + + +def normalize_chinese_text(text: str) -> str: + stripped = text.strip() + if not stripped: + return "" + return _normalize_with(get_chinese_text_normalizer(), stripped) + + +def normalize_english_text(text: str) -> str: + stripped = text.strip() + if not stripped: + return "" + return _normalize_with(get_english_text_normalizer(), stripped) + + +def normalize_text(text: str) -> str: + stripped = text.strip() + if not stripped: + return "" + + language = detect_text_language(stripped) + if language == "zh": + return _normalize_with(get_chinese_text_normalizer(), stripped) + if language == "en": + return _normalize_with(get_english_text_normalizer(), stripped) + return stripped diff --git a/src/dots_tts/utils/tokenizer.py b/src/dots_tts/utils/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..afa6b2a64e8c5e1c3a790870fa824f32cd6f73fd --- /dev/null +++ b/src/dots_tts/utils/tokenizer.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +AUDIO_COMP_START_TOKEN = "<|audio_comp_start|>" +AUDIO_COMP_SPAN_TOKEN = "<|audio_comp_span|>" +AUDIO_COMP_END_TOKEN = "<|audio_comp_end|>" +AUDIO_GEN_START_TOKEN = "<|audio_gen_start|>" +AUDIO_GEN_SPAN_TOKEN = "<|audio_gen_span|>" +AUDIO_GEN_END_TOKEN = "<|audio_gen_end|>" +TEXT_COND_END_TOKEN = "<|text_cond_end|>" + + +def require_token_id(tokenizer, token: str) -> int: + token_id = tokenizer.convert_tokens_to_ids(token) + if token_id is None or token_id < 0: + raise ValueError(f"Artifact tokenizer is missing required special token: {token}") + return int(token_id) + + +__all__ = [ + "AUDIO_COMP_END_TOKEN", + "AUDIO_COMP_SPAN_TOKEN", + "AUDIO_COMP_START_TOKEN", + "AUDIO_GEN_END_TOKEN", + "AUDIO_GEN_SPAN_TOKEN", + "AUDIO_GEN_START_TOKEN", + "TEXT_COND_END_TOKEN", + "require_token_id", +] diff --git a/src/dots_tts/utils/util.py b/src/dots_tts/utils/util.py new file mode 100644 index 0000000000000000000000000000000000000000..3b68727fbab5187d56861da2abea20b93727442b --- /dev/null +++ b/src/dots_tts/utils/util.py @@ -0,0 +1,48 @@ +import random +from typing import Any + +import numpy as np +import torch + + +def get_dtype(x): + if x.lower() in ("bf16", "torch.bfloat16", "bfloat16"): + return torch.bfloat16 + if x.lower() in ("fp16", "torch.float16", "float16"): + return torch.float16 + if x.lower() in ("fp32", "torch.float32", "float32"): + return torch.float32 + raise ValueError("Unsupported dtype value.") + + +def seed_everything(seed: int = 42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def mask_data(x, mask, masking_value=0.0): + while mask.dim() < x.dim(): + mask = mask.unsqueeze(-1) + if isinstance(masking_value, torch.Tensor): + return torch.where(mask, masking_value.expand_as(x), x) + return torch.where( + mask, torch.full(x.shape, masking_value, dtype=x.dtype, device=x.device), x + ) + + +def get_mask_from_lengths(lengths, max_len=None): + if max_len is None: + max_len = torch.max(lengths).item() + ids = torch.arange(0, max_len, out=torch.LongTensor(max_len).to(lengths.device)) + return (ids < lengths.unsqueeze(1)).bool() + + +def scalar_as_float(value: Any) -> float: + if isinstance(value, torch.Tensor): + return float(value.detach().float().item()) + return float(value)