YuMS commited on
Commit
804ee23
·
1 Parent(s): 02617d2

add inference code with AOTI support for hf space

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. LICENSE +201 -0
  2. README.md +7 -1
  3. app.py +176 -0
  4. apps/__init__.py +1 -0
  5. apps/gradio/__init__.py +1 -0
  6. apps/gradio/app.py +663 -0
  7. apps/gradio/constants.py +26 -0
  8. apps/gradio/default_prompts/prompt_text +0 -0
  9. apps/gradio/languages.py +115 -0
  10. apps/gradio/service.py +773 -0
  11. configs/dots_tts.yaml +76 -0
  12. requirements.txt +19 -0
  13. src/dots_tts/__init__.py +1 -0
  14. src/dots_tts/cli.py +152 -0
  15. src/dots_tts/config/__init__.py +1 -0
  16. src/dots_tts/config/app.py +32 -0
  17. src/dots_tts/config/base.py +64 -0
  18. src/dots_tts/config/data.py +63 -0
  19. src/dots_tts/config/train.py +28 -0
  20. src/dots_tts/data/EXTENSION.md +124 -0
  21. src/dots_tts/data/__init__.py +1 -0
  22. src/dots_tts/data/batchers.py +188 -0
  23. src/dots_tts/data/builders.py +194 -0
  24. src/dots_tts/data/collator.py +87 -0
  25. src/dots_tts/data/pipelines/__init__.py +1 -0
  26. src/dots_tts/data/pipelines/base.py +32 -0
  27. src/dots_tts/data/pipelines/preprocessing.py +84 -0
  28. src/dots_tts/data/pipelines/tokenizing.py +339 -0
  29. src/dots_tts/data/pipelines/tts_pipeline.py +132 -0
  30. src/dots_tts/data/source_adapters/__init__.py +1 -0
  31. src/dots_tts/data/source_adapters/base_adapter.py +91 -0
  32. src/dots_tts/data/source_adapters/jsonl_manifest_adapter.py +132 -0
  33. src/dots_tts/data/source_adapters/multi_source_adapter.py +222 -0
  34. src/dots_tts/data/streaming.py +400 -0
  35. src/dots_tts/models/__init__.py +1 -0
  36. src/dots_tts/models/dots_tts/__init__.py +1 -0
  37. src/dots_tts/models/dots_tts/config.py +71 -0
  38. src/dots_tts/models/dots_tts/core.py +910 -0
  39. src/dots_tts/models/dots_tts/model.py +1958 -0
  40. src/dots_tts/modules/__init__.py +0 -0
  41. src/dots_tts/modules/backbone/__init__.py +1 -0
  42. src/dots_tts/modules/backbone/dit.py +205 -0
  43. src/dots_tts/modules/backbone/layers.py +333 -0
  44. src/dots_tts/modules/backbone/semantic_encoder.py +356 -0
  45. src/dots_tts/modules/speaker/__init__.py +1 -0
  46. src/dots_tts/modules/speaker/campplus.py +200 -0
  47. src/dots_tts/modules/speaker/campplus_layers.py +258 -0
  48. src/dots_tts/modules/speaker/encoder.py +226 -0
  49. src/dots_tts/modules/speaker/fbank.py +31 -0
  50. src/dots_tts/modules/vocoder/__init__.py +1 -0
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2026 OpenMOSS Team, Fudan University, SII and MOSI
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -9,6 +9,12 @@ python_version: '3.12'
9
  app_file: app.py
10
  pinned: false
11
  license: apache-2.0
 
 
 
 
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
9
  app_file: app.py
10
  pinned: false
11
  license: apache-2.0
12
+ tags:
13
+ - zerogpu
14
+ - aoti
15
+ - text-to-speech
16
  ---
17
 
18
+ dots.tts Gradio Space for Hugging Face ZeroGPU with optional PyTorch AOTInductor startup compilation.
19
+
20
+ Set `DOTS_TTS_MODEL_NAME_OR_PATH` to a local model directory or Hugging Face model repo id. The app defaults to `rednote-hilab/dots.tts`.
app.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import sys
5
+ from pathlib import Path
6
+ from typing import Any, Callable
7
+
8
+
9
+ REPO_ROOT = Path(__file__).resolve().parent
10
+ SRC_ROOT = REPO_ROOT / "src"
11
+
12
+ for import_root in (REPO_ROOT, SRC_ROOT):
13
+ import_root_str = str(import_root)
14
+ if import_root_str not in sys.path:
15
+ sys.path.insert(0, import_root_str)
16
+
17
+
18
+ class _SpacesFallback:
19
+ @staticmethod
20
+ def GPU(*decorator_args, **_decorator_kwargs):
21
+ if decorator_args and callable(decorator_args[0]):
22
+ return decorator_args[0]
23
+
24
+ def decorate(fn: Callable[..., Any]) -> Callable[..., Any]:
25
+ return fn
26
+
27
+ return decorate
28
+
29
+
30
+ try:
31
+ import spaces # type: ignore
32
+ except Exception: # pragma: no cover - only used outside Hugging Face Spaces.
33
+ spaces = _SpacesFallback() # type: ignore
34
+
35
+
36
+ def _env_bool(name: str, default: bool) -> bool:
37
+ value = os.environ.get(name)
38
+ if value is None:
39
+ return default
40
+ return value.strip().lower() in {"1", "true", "yes", "on"}
41
+
42
+
43
+ def _env_int(name: str, default: int) -> int:
44
+ value = os.environ.get(name)
45
+ if value is None or not value.strip():
46
+ return default
47
+ return int(value)
48
+
49
+
50
+ def _configure_zero_gpu_environment() -> None:
51
+ os.environ.setdefault("DOTS_TTS_COMPILE_BACKEND", "aoti")
52
+ os.environ.setdefault("DOTS_TTS_SKIP_INIT_WARMUP", "1")
53
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
54
+
55
+
56
+ def _preload_runtime(app_service, app_config, compile_backend: str):
57
+ runtime, resolved_model_name_or_path = app_service._get_runtime( # noqa: SLF001
58
+ app_config.default_model_name_or_path,
59
+ )
60
+ runtime.optimize = bool(app_config.optimize)
61
+ runtime.model.set_optimize(bool(app_config.optimize))
62
+ if hasattr(runtime.model, "set_compile_backend"):
63
+ runtime.model.set_compile_backend(compile_backend)
64
+ return runtime, resolved_model_name_or_path
65
+
66
+
67
+ def main() -> None:
68
+ _configure_zero_gpu_environment()
69
+
70
+ import gradio as gr
71
+ from loguru import logger
72
+
73
+ from apps.gradio.app import PLAYGROUND_CSS, build_demo, build_playground_theme
74
+ from apps.gradio.service import GradioAppService, build_gradio_app_config
75
+ from dots_tts.utils.logging import configure_logging
76
+
77
+ host = os.environ.get("DOTS_TTS_HOST", "0.0.0.0")
78
+ port = _env_int("DOTS_TTS_PORT", 7860)
79
+ model_name_or_path = os.environ.get(
80
+ "DOTS_TTS_MODEL_NAME_OR_PATH",
81
+ "rednote-hilab/dots.tts",
82
+ )
83
+ precision = os.environ.get("DOTS_TTS_PRECISION", "bfloat16")
84
+ execution_mode = os.environ.get("DOTS_TTS_EXECUTION_MODE", "generate_stream")
85
+ max_generate_length = _env_int("DOTS_TTS_MAX_GENERATE_LENGTH", 500)
86
+ default_num_steps = _env_int("DOTS_TTS_DEFAULT_NUM_STEPS", 10)
87
+ compile_backend = os.environ.get("DOTS_TTS_COMPILE_BACKEND", "aoti").strip().lower()
88
+ enable_aoti = _env_bool("DOTS_TTS_ENABLE_AOTI", True)
89
+ startup_compile = _env_bool("DOTS_TTS_AOTI_COMPILE_ON_STARTUP", True)
90
+ optimize = _env_bool("DOTS_TTS_OPTIMIZE", True)
91
+ generation_duration = _env_int("DOTS_TTS_ZERO_GPU_DURATION", 600)
92
+ compile_duration = _env_int("DOTS_TTS_ZERO_GPU_COMPILE_DURATION", 1500)
93
+ output_dir = Path(os.environ.get("DOTS_TTS_OUTPUT_DIR", "/tmp/dots_tts_outputs"))
94
+ log_file = Path(os.environ.get("DOTS_TTS_LOG_FILE", "/tmp/dots_tts_gradio.log"))
95
+
96
+ configure_logging(log_file=log_file)
97
+ logger.info(
98
+ "Space app starting: model={} execution_mode={} precision={} optimize={} "
99
+ "compile_backend={} enable_aoti={} startup_compile={} max_generate_length={}",
100
+ model_name_or_path,
101
+ execution_mode,
102
+ precision,
103
+ optimize,
104
+ compile_backend,
105
+ enable_aoti,
106
+ startup_compile,
107
+ max_generate_length,
108
+ )
109
+
110
+ app_config = build_gradio_app_config(
111
+ host=host,
112
+ port=port,
113
+ execution_mode=execution_mode,
114
+ precision=precision,
115
+ optimize=optimize,
116
+ model_name_or_path=model_name_or_path,
117
+ output_dir=output_dir,
118
+ max_generate_length=max_generate_length,
119
+ default_num_steps=default_num_steps,
120
+ default_max_generate_length=max_generate_length,
121
+ repo_root=REPO_ROOT,
122
+ )
123
+ app_service = GradioAppService(app_config)
124
+ runtime, resolved_model_name_or_path = _preload_runtime(
125
+ app_service,
126
+ app_config,
127
+ compile_backend if enable_aoti else "torch_compile",
128
+ )
129
+
130
+ if enable_aoti and startup_compile and optimize:
131
+
132
+ @spaces.GPU(duration=compile_duration)
133
+ def compile_aoti_cache():
134
+ child_runtime, _ = _preload_runtime(
135
+ app_service,
136
+ app_config,
137
+ compile_backend,
138
+ )
139
+ child_runtime.model.run_warmup(
140
+ max_generate_length=app_config.max_generate_length,
141
+ precision=app_config.precision,
142
+ num_steps=app_config.default_num_steps,
143
+ guidance_scale=app_config.default_guidance_scale,
144
+ )
145
+ return child_runtime.model.export_compiled_models()
146
+
147
+ compiled_models = compile_aoti_cache()
148
+ if compiled_models:
149
+ runtime.model.import_compiled_models(compiled_models)
150
+ logger.info(
151
+ "AOTI startup compile completed: compiled_target_count={}",
152
+ len(compiled_models or {}),
153
+ )
154
+
155
+ app_service.generate = spaces.GPU(duration=generation_duration)(app_service.generate)
156
+
157
+ demo = build_demo(gr, app_config, app_service)
158
+ logger.info(
159
+ "Space app ready: host={} port={} resolved_model={} compiled_target_count={}",
160
+ app_config.host,
161
+ app_config.port,
162
+ resolved_model_name_or_path,
163
+ len(runtime.model.export_compiled_models())
164
+ if hasattr(runtime.model, "export_compiled_models")
165
+ else 0,
166
+ )
167
+ demo.launch(
168
+ server_name=app_config.host,
169
+ server_port=app_config.port,
170
+ theme=build_playground_theme(gr),
171
+ css=PLAYGROUND_CSS,
172
+ )
173
+
174
+
175
+ if __name__ == "__main__":
176
+ main()
apps/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Application entrypoints for dots.tts."""
apps/gradio/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Gradio application for dots.tts."""
apps/gradio/app.py ADDED
@@ -0,0 +1,663 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import os
5
+ import sys
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING
8
+
9
+ REPO_ROOT = Path(__file__).resolve().parents[2]
10
+ SRC_ROOT = REPO_ROOT / "src"
11
+
12
+ for import_root in (REPO_ROOT, SRC_ROOT):
13
+ import_root_str = str(import_root)
14
+ if import_root_str not in sys.path:
15
+ sys.path.insert(0, import_root_str)
16
+
17
+ from apps.gradio.constants import ( # noqa: E402
18
+ DEFAULT_EXECUTION_MODE,
19
+ DEFAULT_GUIDANCE_SCALE,
20
+ DEFAULT_HOST,
21
+ DEFAULT_INPUT_TEXT,
22
+ DEFAULT_LOG_FILE,
23
+ DEFAULT_MAX_GENERATE_LENGTH,
24
+ DEFAULT_NUM_STEPS,
25
+ DEFAULT_ODE_METHOD,
26
+ DEFAULT_OUTPUT_DIR,
27
+ DEFAULT_OUTPUT_RETENTION,
28
+ DEFAULT_PORT,
29
+ DEFAULT_PRECISION,
30
+ DEFAULT_PROMPT_NAME,
31
+ DEFAULT_SEED,
32
+ DEFAULT_SPEAKER_SCALE,
33
+ )
34
+
35
+ if TYPE_CHECKING:
36
+ import gradio as gr
37
+
38
+ DEBUG_GRADIO_ENABLED = os.environ.get("DEBUG_GRADIO", "0") == "1"
39
+
40
+
41
+ PLAYGROUND_CSS = """
42
+ .gradio-container {
43
+ width: min(1600px, calc(100vw - 32px)) !important;
44
+ max-width: none !important;
45
+ margin: 0 auto !important;
46
+ padding-left: 0 !important;
47
+ padding-right: 0 !important;
48
+ }
49
+
50
+ .gradio-container,
51
+ .gradio-container .gradio-container {
52
+ --block-label-background-fill: #CCE5FF;
53
+ --block-label-text-color: #6666FF;
54
+ --block-label-border-color: #99c7ee;
55
+ --block-label-text-weight: 600;
56
+ --block-title-background-fill: #CCE5FF;
57
+ --block-title-text-color: #6666FF;
58
+ --block-title-border-color: #99c7ee;
59
+ --block-title-border-width: var(--block-label-border-width);
60
+ --block-title-radius: var(--block-label-radius);
61
+ --block-title-padding: var(--block-label-padding);
62
+ --block-title-text-size: var(--block-label-text-size);
63
+ --block-title-text-weight: 600;
64
+ }
65
+
66
+ .gradio-container label[data-testid="block-label"],
67
+ .gradio-container label[data-testid="block-label"] *,
68
+ .gradio-container span[data-testid="block-info"],
69
+ .gradio-container span[data-testid="block-info"] * {
70
+ background: #CCE5FF !important;
71
+ border-color: #99c7ee !important;
72
+ color: #6666FF !important;
73
+ fill: #6666FF !important;
74
+ font-family: Verdana, Geneva, "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei", "Noto Sans CJK SC", sans-serif !important;
75
+ font-style: normal !important;
76
+ font-size: 0.78rem !important;
77
+ line-height: 1.2 !important;
78
+ letter-spacing: 0 !important;
79
+ text-transform: none !important;
80
+ }
81
+ .gradio-container label[data-testid="block-label"],
82
+ .gradio-container span[data-testid="block-info"],
83
+ .gradio-container [data-testid="block-title"],
84
+ .gradio-container .block-title {
85
+ border: var(--block-label-border-width) solid #99c7ee !important;
86
+ border-top: none !important;
87
+ border-left: none !important;
88
+ border-radius: var(--block-label-radius) !important;
89
+ box-shadow: var(--block-label-shadow) !important;
90
+ padding: var(--block-label-padding) !important;
91
+ }
92
+ .gradio-container label[data-testid="block-label"],
93
+ .gradio-container label[data-testid="block-label"] *,
94
+ .gradio-container span[data-testid="block-info"],
95
+ .gradio-container span[data-testid="block-info"] *,
96
+ .gradio-container [data-testid="block-title"],
97
+ .gradio-container [data-testid="block-title"] *,
98
+ .gradio-container .block-title,
99
+ .gradio-container .block-title * {
100
+ font-weight: 600 !important;
101
+ }
102
+ .gradio-container .block label > span,
103
+ .gradio-container .block label > span *,
104
+ .gradio-container .form label > span,
105
+ .gradio-container .form label > span *,
106
+ .gradio-container label > span:first-child,
107
+ .gradio-container label > span:first-child * {
108
+ font-weight: 600 !important;
109
+ }
110
+ .strong-label [data-testid="block-label"],
111
+ .strong-label [data-testid="block-label"] *,
112
+ .strong-label span[data-testid="block-info"],
113
+ .strong-label span[data-testid="block-info"] *,
114
+ .strong-label [data-testid="block-title"],
115
+ .strong-label [data-testid="block-title"] *,
116
+ .strong-label .block-label,
117
+ .strong-label .block-label *,
118
+ .strong-label .block-title,
119
+ .strong-label .block-title *,
120
+ .strong-label label > span:first-child,
121
+ .strong-label label > span:first-child * {
122
+ font-weight: 600 !important;
123
+ }
124
+ .gradio-container .info-text,
125
+ .gradio-container .info-text * {
126
+ font-weight: 400 !important;
127
+ }
128
+ .gradio-container input,
129
+ .gradio-container textarea,
130
+ .gradio-container select,
131
+ .gradio-container [role="textbox"],
132
+ .gradio-container [contenteditable="true"] {
133
+ font-weight: 400 !important;
134
+ }
135
+ .gradio-container label[data-testid="block-label"] > span:first-child {
136
+ display: none !important;
137
+ }
138
+
139
+ .generate-button {
140
+ background: #6666FF !important;
141
+ color: #ffffff !important;
142
+ border: 1px solid #5555ee !important;
143
+ font-family: Verdana, Geneva, sans-serif !important;
144
+ }
145
+ .generate-button:hover {
146
+ background: #5555ee !important;
147
+ }
148
+
149
+ #playground-banner {
150
+ padding: 0;
151
+ border-radius: 0;
152
+ margin-bottom: 18px;
153
+ background: transparent;
154
+ border: 0;
155
+ }
156
+ #playground-banner h1 {
157
+ margin: 0 0 4px 0;
158
+ font-size: 1.7rem;
159
+ font-weight: 700;
160
+ color: #0f172a;
161
+ letter-spacing: 0;
162
+ }
163
+ #playground-banner .subtitle {
164
+ margin: 0;
165
+ color: #1e293b;
166
+ font-size: 0.9rem;
167
+ }
168
+
169
+ .info-card {
170
+ padding: 14px 18px;
171
+ border-radius: 8px;
172
+ border: 1px solid #99c7ee;
173
+ border-left: 4px solid #2563eb;
174
+ background: transparent;
175
+ font-size: 0.86rem;
176
+ line-height: 1.55;
177
+ margin-bottom: 16px;
178
+ box-sizing: border-box;
179
+ color: #0f172a;
180
+ }
181
+ .info-card .card-title,
182
+ .info-card .notice-title {
183
+ display: block;
184
+ font-weight: 600;
185
+ font-size: 0.92rem;
186
+ color: #0f172a;
187
+ }
188
+ .info-card .card-title {
189
+ margin-bottom: 4px;
190
+ }
191
+ .info-card .notice-title {
192
+ margin-top: 8px;
193
+ margin-bottom: 4px;
194
+ }
195
+ .info-card ol,
196
+ .info-card ul {
197
+ margin: 0;
198
+ padding-left: 18px;
199
+ }
200
+ .info-card li {
201
+ margin: 2px 0;
202
+ }
203
+
204
+ .main-workspace {
205
+ gap: 18px !important;
206
+ align-items: stretch !important;
207
+ }
208
+
209
+ .prompt-column,
210
+ .synthesis-column {
211
+ gap: 14px !important;
212
+ }
213
+
214
+ .control-row,
215
+ .settings-slider-row {
216
+ gap: 14px !important;
217
+ }
218
+
219
+ .settings-card {
220
+ margin-top: 2px !important;
221
+ }
222
+
223
+ .generate-button {
224
+ margin-top: 2px !important;
225
+ width: 100% !important;
226
+ box-sizing: border-box !important;
227
+ flex: 0 0 auto !important;
228
+ min-height: 44px !important;
229
+ padding-top: 10px !important;
230
+ padding-bottom: 10px !important;
231
+ font-size: 1rem !important;
232
+ font-weight: 600 !important;
233
+ }
234
+
235
+ .output-audio {
236
+ flex: 0 0 auto !important;
237
+ min-height: 190px !important;
238
+ }
239
+ .output-audio audio {
240
+ width: 100% !important;
241
+ }
242
+
243
+ @media (max-width: 768px) {
244
+ .gradio-container {
245
+ width: calc(100vw - 20px) !important;
246
+ }
247
+
248
+ }
249
+
250
+ """
251
+
252
+
253
+ def build_playground_theme(gr):
254
+ return gr.themes.Soft(
255
+ primary_hue="slate",
256
+ secondary_hue="slate",
257
+ neutral_hue="slate",
258
+ radius_size="md",
259
+ text_size="md",
260
+ spacing_size="md",
261
+ font=[gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"],
262
+ )
263
+
264
+
265
+ def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
266
+ parser = argparse.ArgumentParser(description="dots.tts Gradio app.")
267
+ parser.add_argument("--host", default=DEFAULT_HOST, help="Server host")
268
+ parser.add_argument("--port", type=int, default=DEFAULT_PORT, help="Server port")
269
+ parser.add_argument(
270
+ "--execution-mode",
271
+ choices=("generate", "generate_stream"),
272
+ default=DEFAULT_EXECUTION_MODE,
273
+ help="Runtime execution mode fixed for the app",
274
+ )
275
+ parser.add_argument(
276
+ "--precision",
277
+ default=DEFAULT_PRECISION,
278
+ help="Inference precision fixed for the app runtime",
279
+ )
280
+ parser.add_argument(
281
+ "--optimize",
282
+ action="store_true",
283
+ help="Enable runtime optimize acceleration",
284
+ )
285
+ parser.add_argument(
286
+ "--model-name-or-path",
287
+ default=None,
288
+ help="Default model directory or Hugging Face repo id",
289
+ )
290
+ parser.add_argument(
291
+ "--output-dir",
292
+ default=str(DEFAULT_OUTPUT_DIR),
293
+ help="Directory for generated wav outputs",
294
+ )
295
+ parser.add_argument(
296
+ "--log-file",
297
+ default=str(DEFAULT_LOG_FILE),
298
+ help="Path to the Gradio log file",
299
+ )
300
+ parser.add_argument(
301
+ "--output-retention-count",
302
+ type=int,
303
+ default=DEFAULT_OUTPUT_RETENTION,
304
+ help="Maximum number of generated wav files to keep",
305
+ )
306
+ parser.add_argument(
307
+ "--max-generate-length",
308
+ type=int,
309
+ default=DEFAULT_MAX_GENERATE_LENGTH,
310
+ help="Maximum generation schedule length fixed for the app runtime",
311
+ )
312
+ parser.add_argument(
313
+ "--default-prompt-name",
314
+ default=DEFAULT_PROMPT_NAME,
315
+ help="Default built-in voice preset name",
316
+ )
317
+ parser.add_argument(
318
+ "--default-precision",
319
+ default=DEFAULT_PRECISION,
320
+ choices=["bfloat16", "float32", "float16"],
321
+ help="Default precision selected in the UI",
322
+ )
323
+ parser.add_argument(
324
+ "--default-num-steps",
325
+ type=int,
326
+ default=DEFAULT_NUM_STEPS,
327
+ help="Default Num Steps selected in the UI",
328
+ )
329
+ parser.add_argument(
330
+ "--default-guidance-scale",
331
+ type=float,
332
+ default=DEFAULT_GUIDANCE_SCALE,
333
+ help="Default Guidance Scale selected in the UI",
334
+ )
335
+ parser.add_argument(
336
+ "--default-speaker-scale",
337
+ type=float,
338
+ default=DEFAULT_SPEAKER_SCALE,
339
+ help="Default Speaker Scale selected in the UI",
340
+ )
341
+ parser.add_argument(
342
+ "--default-max-generate-length",
343
+ type=int,
344
+ default=DEFAULT_MAX_GENERATE_LENGTH,
345
+ help="Default Max Generate Length selected in the UI",
346
+ )
347
+ parser.add_argument(
348
+ "--skip-warmup",
349
+ action="store_true",
350
+ help="Start the Gradio server without running an initial synthesis warmup.",
351
+ )
352
+ return parser.parse_args(argv)
353
+
354
+
355
+ def build_startup_config_panel(gr, app_config) -> None:
356
+ with gr.Accordion("启动固定参数", open=False):
357
+ gr.Markdown("只读。修改这部分需要重启服务并传入新的启动参数。")
358
+ gr.Textbox(
359
+ label="Model",
360
+ value=app_config.default_model_name_or_path,
361
+ interactive=False,
362
+ )
363
+ with gr.Row():
364
+ gr.Textbox(
365
+ label="Execution Mode",
366
+ value=app_config.execution_mode,
367
+ interactive=False,
368
+ )
369
+ gr.Textbox(
370
+ label="Precision",
371
+ value=app_config.precision,
372
+ interactive=False,
373
+ )
374
+ with gr.Row():
375
+ gr.Number(
376
+ label="Max Generate Length",
377
+ value=app_config.max_generate_length,
378
+ precision=0,
379
+ interactive=False,
380
+ )
381
+ gr.Checkbox(
382
+ label="Optimize",
383
+ value=app_config.optimize,
384
+ interactive=False,
385
+ )
386
+
387
+
388
+ def build_demo(gr, app_config, app_service) -> "gr.Blocks":
389
+ from apps.gradio.service import (
390
+ GRADIO_SYNTHESIS_MODE_CHOICES,
391
+ SynthesisRequest,
392
+ build_prompt_choice_items,
393
+ resolve_prompt_selection,
394
+ )
395
+
396
+ def select_prompt_preset(prompt_name: str):
397
+ audio_path, prompt_text = resolve_prompt_selection(
398
+ prompt_name,
399
+ app_config.prompt_presets,
400
+ )
401
+ return audio_path, prompt_text
402
+
403
+ def run_synthesis(
404
+ text: str,
405
+ synthesis_mode: str,
406
+ prompt_audio_path: str | None,
407
+ prompt_text: str,
408
+ ode_method: str,
409
+ num_steps: float,
410
+ guidance_scale: float,
411
+ speaker_scale: float,
412
+ normalize_text: bool,
413
+ seed: float,
414
+ ):
415
+ resolved_synthesis_mode = synthesis_mode if DEBUG_GRADIO_ENABLED else "tts"
416
+ request = SynthesisRequest(
417
+ model_name_or_path=app_config.default_model_name_or_path,
418
+ text=text,
419
+ prompt_audio_path=prompt_audio_path,
420
+ prompt_text=prompt_text,
421
+ execution_mode=app_config.execution_mode,
422
+ template_name=resolved_synthesis_mode,
423
+ ode_method=ode_method,
424
+ num_steps=int(num_steps),
425
+ guidance_scale=float(guidance_scale),
426
+ speaker_scale=float(speaker_scale),
427
+ normalize_text=normalize_text,
428
+ seed=int(seed),
429
+ )
430
+ result = app_service.generate(request)
431
+ return result.audio_path, result.metrics
432
+
433
+ show_prompt_preset = bool(app_config.prompt_presets)
434
+
435
+ with gr.Blocks(title="dots.tts") as demo:
436
+ gr.HTML(
437
+ "<style>\n"
438
+ + PLAYGROUND_CSS
439
+ + "\n</style>\n"
440
+ + """
441
+ <div id="playground-banner">
442
+ <h1>dots.tts</h1>
443
+ <p class="subtitle">Fully-continuous Autoregressive TTS · 48 kHz · Voice Cloning</p>
444
+ </div>
445
+ """,
446
+ )
447
+
448
+ gr.HTML(
449
+ """
450
+ <div class="info-card">
451
+ <span class="card-title">使用说明 · Instructions</span>
452
+ <ol>
453
+ <li>上传参考音频并填写对应转写文本 · Upload prompt audio and fill in its transcript.</li>
454
+ <li>在文本框中输入要合成的内容 · Enter the text to synthesize.</li>
455
+ <li>点击 <b>Generate</b> 合成声音 · Click <b>Generate</b> to synthesize speech.</li>
456
+ </ol>
457
+ </div>
458
+ """,
459
+ )
460
+
461
+ with gr.Row(equal_height=True, elem_classes="main-workspace"):
462
+ with gr.Column(scale=1, min_width=480, elem_classes="prompt-column"):
463
+ prompt_preset = gr.Dropdown(
464
+ label="音色 · Voice Preset",
465
+ choices=build_prompt_choice_items(app_config.prompt_presets),
466
+ value=app_config.default_prompt_name,
467
+ info="内置音色clone样本;选择后自动填入参考音频与转写。",
468
+ elem_id="voice-preset-dropdown",
469
+ elem_classes="strong-label",
470
+ visible=show_prompt_preset,
471
+ )
472
+ prompt_audio_path = gr.Audio(
473
+ label="参考音频 · Prompt Audio",
474
+ sources=["upload"],
475
+ type="filepath",
476
+ value=app_config.default_prompt_audio_path,
477
+ elem_classes="strong-label",
478
+ )
479
+ prompt_text = gr.Textbox(
480
+ label="参考音频转写 · Prompt Text",
481
+ lines=5,
482
+ value=app_config.default_prompt_text,
483
+ placeholder="Prompt audio 对应的文本转写(continuation cloning 必填)",
484
+ elem_classes="strong-label",
485
+ )
486
+
487
+ with gr.Column(scale=1, min_width=480, elem_classes="synthesis-column"):
488
+ text = gr.Textbox(
489
+ label="待合成文本 · Text",
490
+ lines=5,
491
+ max_lines=8,
492
+ value=DEFAULT_INPUT_TEXT,
493
+ placeholder="输入待合成的文本",
494
+ elem_classes="strong-label",
495
+ )
496
+ with gr.Accordion("⚙️ Settings", open=False, elem_classes="settings-card"):
497
+ with gr.Row(elem_classes="settings-slider-row"):
498
+ num_steps = gr.Slider(
499
+ label="Num Steps",
500
+ minimum=1,
501
+ maximum=32,
502
+ step=1,
503
+ value=app_config.default_num_steps,
504
+ )
505
+ with gr.Row(elem_classes="settings-slider-row"):
506
+ guidance_scale = gr.Slider(
507
+ label="Guidance Scale",
508
+ minimum=1.0,
509
+ maximum=3.0,
510
+ step=0.1,
511
+ value=app_config.default_guidance_scale,
512
+ )
513
+ with gr.Row(elem_classes="control-row"):
514
+ seed = gr.Number(
515
+ label="Seed",
516
+ value=DEFAULT_SEED,
517
+ precision=0,
518
+ scale=1,
519
+ min_width=180,
520
+ )
521
+ normalize_text = gr.Checkbox(
522
+ label="Normalize Text",
523
+ value=False,
524
+ scale=1,
525
+ min_width=180,
526
+ )
527
+ generate = gr.Button(
528
+ "Generate",
529
+ variant="primary",
530
+ size="lg",
531
+ elem_classes="generate-button",
532
+ )
533
+ audio_out = gr.Audio(
534
+ label="生成音频 · Output",
535
+ type="filepath",
536
+ elem_classes="output-audio",
537
+ )
538
+
539
+ if DEBUG_GRADIO_ENABLED:
540
+ with gr.Accordion("Debug", open=False):
541
+ synthesis_mode = gr.Dropdown(
542
+ label="SynthesisMode",
543
+ choices=list(GRADIO_SYNTHESIS_MODE_CHOICES),
544
+ value="tts",
545
+ info="选择合成模式;界面显示名会自动映射到 runtime 对应模板。",
546
+ )
547
+ ode_method = gr.Textbox(
548
+ label="ODE Method",
549
+ value=DEFAULT_ODE_METHOD,
550
+ lines=1,
551
+ )
552
+ speaker_scale = gr.Slider(
553
+ label="Speaker Scale",
554
+ minimum=0.0,
555
+ maximum=3.0,
556
+ step=0.1,
557
+ value=app_config.default_speaker_scale,
558
+ info="说话人 x-vector 强度",
559
+ )
560
+ metrics = gr.JSON(label="Metrics", value=app_service.metadata())
561
+ build_startup_config_panel(gr, app_config)
562
+ else:
563
+ synthesis_mode = gr.State(value="tts")
564
+ ode_method = gr.State(value=DEFAULT_ODE_METHOD)
565
+ speaker_scale = gr.State(value=app_config.default_speaker_scale)
566
+ metrics = gr.State(value={})
567
+
568
+ generate.click(
569
+ fn=run_synthesis,
570
+ inputs=[
571
+ text,
572
+ synthesis_mode,
573
+ prompt_audio_path,
574
+ prompt_text,
575
+ ode_method,
576
+ num_steps,
577
+ guidance_scale,
578
+ speaker_scale,
579
+ normalize_text,
580
+ seed,
581
+ ],
582
+ outputs=[audio_out, metrics],
583
+ concurrency_limit=1,
584
+ )
585
+ prompt_preset.change(
586
+ fn=select_prompt_preset,
587
+ inputs=[prompt_preset],
588
+ outputs=[prompt_audio_path, prompt_text],
589
+ concurrency_limit=1,
590
+ )
591
+
592
+ return demo.queue(default_concurrency_limit=1, max_size=8)
593
+
594
+
595
+ def main() -> None:
596
+ args = parse_args()
597
+ import gradio as gr
598
+ from loguru import logger
599
+
600
+ from apps.gradio.service import GradioAppService, build_gradio_app_config
601
+ from dots_tts.utils.logging import configure_logging
602
+
603
+ configure_logging(log_file=args.log_file)
604
+ logger.info(
605
+ "Gradio app starting: host={} port={} model_name_or_path={} output_dir={} "
606
+ "log_file={} output_retention_count={} max_generate_length={} execution_mode={} precision={} optimize={} "
607
+ "default_prompt_name={} skip_warmup={}",
608
+ args.host,
609
+ args.port,
610
+ args.model_name_or_path,
611
+ args.output_dir,
612
+ args.log_file,
613
+ args.output_retention_count,
614
+ args.max_generate_length,
615
+ args.execution_mode,
616
+ args.precision,
617
+ args.optimize,
618
+ args.default_prompt_name,
619
+ args.skip_warmup,
620
+ )
621
+ app_config = build_gradio_app_config(
622
+ host=args.host,
623
+ port=args.port,
624
+ execution_mode=args.execution_mode,
625
+ precision=args.precision,
626
+ optimize=args.optimize,
627
+ model_name_or_path=args.model_name_or_path,
628
+ output_dir=Path(args.output_dir),
629
+ output_retention_count=args.output_retention_count,
630
+ max_generate_length=args.max_generate_length,
631
+ default_prompt_name=args.default_prompt_name,
632
+ default_precision=args.default_precision,
633
+ default_num_steps=args.default_num_steps,
634
+ default_guidance_scale=args.default_guidance_scale,
635
+ default_speaker_scale=args.default_speaker_scale,
636
+ default_max_generate_length=args.default_max_generate_length,
637
+ )
638
+ app_service = GradioAppService(app_config)
639
+ if args.skip_warmup:
640
+ logger.info("Gradio app warmup skipped by --skip-warmup.")
641
+ else:
642
+ warmup_metrics = app_service.warmup()
643
+ logger.info("Gradio app warmup metrics: {}", warmup_metrics)
644
+ demo = build_demo(gr, app_config, app_service)
645
+ logger.info(
646
+ "Gradio app ready: host={} port={} execution_mode={} precision={} optimize={} default_model_name_or_path={}",
647
+ app_config.host,
648
+ app_config.port,
649
+ app_config.execution_mode,
650
+ app_config.precision,
651
+ app_config.optimize,
652
+ app_config.default_model_name_or_path,
653
+ )
654
+ demo.launch(
655
+ server_name=app_config.host,
656
+ server_port=app_config.port,
657
+ theme=build_playground_theme(gr),
658
+ css=PLAYGROUND_CSS,
659
+ )
660
+
661
+
662
+ if __name__ == "__main__":
663
+ main()
apps/gradio/constants.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+
5
+ REPO_ROOT = Path(__file__).resolve().parents[2]
6
+ DEFAULT_HOST = "0.0.0.0"
7
+ DEFAULT_PORT = 7860
8
+ DEFAULT_OUTPUT_DIR = REPO_ROOT / "apps" / "gradio" / "outputs"
9
+ DEFAULT_LOG_FILE = REPO_ROOT / "apps" / "gradio" / "gradio.log"
10
+ DEFAULT_PROMPTS_DIR = REPO_ROOT / "apps" / "gradio" / "default_prompts"
11
+ DEFAULT_PROMPT_SOURCE_DIR = DEFAULT_PROMPTS_DIR
12
+ DEFAULT_PROMPT_MAPPING_FILE = DEFAULT_PROMPTS_DIR / "prompt_text"
13
+ DEFAULT_OUTPUT_RETENTION = 20
14
+ DEFAULT_EXECUTION_MODE = "generate_stream"
15
+ DEFAULT_PRECISION = "bfloat16"
16
+ DEFAULT_ODE_METHOD = "euler"
17
+ DEFAULT_NUM_STEPS = 10
18
+ DEFAULT_GUIDANCE_SCALE = 1.2
19
+ DEFAULT_SPEAKER_SCALE = 1.5
20
+ DEFAULT_MAX_GENERATE_LENGTH = 500
21
+ DEFAULT_SEED = 42
22
+ DEFAULT_INPUT_TEXT = ""
23
+ DEFAULT_WARMUP_TEXT = "dots.tts is a 2B-parameter fully continuous, end-to-end autoregressive (AR) text-to-speech system. The backbone pairs a semantic encoder, an LLM, and an autoregressive flow-matching acoustic head over a 48 kHz AudioVAE"
24
+ DEFAULT_PROMPT_NAME = "male_zh"
25
+ DEFAULT_PROMPT_NONE = "__none__"
26
+ PROMPT_AUDIO_SUFFIXES = (".wav", ".mp3", ".flac", ".m4a", ".ogg")
apps/gradio/default_prompts/prompt_text ADDED
File without changes
apps/gradio/languages.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ SUPPORTED_LANGUAGE_CODE_BY_NAME = {
4
+ "普通话": "ZH",
5
+ "粤语": "口音:粤语",
6
+ "北京话": "口音:北京官话",
7
+ "东北话": "口音:东北话",
8
+ "四川话": "口音:四川话",
9
+ "闽南话": "口音:闽南话",
10
+ "吴语": "口音:吴语",
11
+ "英语": "EN",
12
+ "西班牙语": "ES",
13
+ "印地语": "HI",
14
+ "阿拉伯语": "AR",
15
+ "孟加拉语": "BN",
16
+ "葡萄牙语": "PT",
17
+ "俄语": "RU",
18
+ "日语": "JA",
19
+ "法语": "FR",
20
+ "德语": "DE",
21
+ "韩语": "KO",
22
+ "意大利语": "IT",
23
+ "土耳其语": "TR",
24
+ "越南语": "VI",
25
+ "印尼语": "ID",
26
+ "乌尔都语": "UR",
27
+ "波斯语": "FA",
28
+ "泰米尔语": "TA",
29
+ "泰卢固语": "TE",
30
+ "菲律宾语": "FIL",
31
+ "马来语": "MS",
32
+ "旁遮普语": "PA",
33
+ "马拉地语": "MR",
34
+ "古吉拉特语": "GU",
35
+ "马拉雅拉姆语": "ML",
36
+ "卡纳达语": "KN",
37
+ "波兰语": "PL",
38
+ "乌克兰语": "UK",
39
+ "荷兰语": "NL",
40
+ "泰语": "TH",
41
+ "罗马尼亚语": "RO",
42
+ "斯瓦希里语": "SW",
43
+ "希伯来语": "HE",
44
+ "捷克语": "CS",
45
+ "希腊语": "EL",
46
+ "匈牙利语": "HU",
47
+ "瑞典语": "SV",
48
+ "丹麦语": "DA",
49
+ "芬兰语": "FI",
50
+ "书面挪威语": "NB",
51
+ "斯洛伐克语": "SK",
52
+ "斯洛文尼亚语": "SL",
53
+ "塞尔维亚语": "SR",
54
+ "波斯尼亚语": "BS",
55
+ "克罗地亚语": "HR",
56
+ "保加利亚语": "BG",
57
+ "马其顿语": "MK",
58
+ "立陶宛语": "LT",
59
+ "拉脱维亚语": "LV",
60
+ "爱沙尼亚语": "ET",
61
+ "冰岛语": "IS",
62
+ "爱尔兰语": "GA",
63
+ "威尔士语": "CY",
64
+ "加泰罗尼亚语": "CA",
65
+ "加利西亚语": "GL",
66
+ "奥克语": "OC",
67
+ "阿斯图里亚斯语": "AST",
68
+ "尼泊尔语": "NE",
69
+ "信德语": "SD",
70
+ "奥里亚语": "OR",
71
+ "阿萨姆语": "AS",
72
+ "普什图语": "PS",
73
+ "缅甸语": "MY",
74
+ "高棉语": "KM",
75
+ "老挝语": "LO",
76
+ "哈萨克语": "KK",
77
+ "乌兹别克语": "UZ",
78
+ "吉尔吉斯语": "KY",
79
+ "塔吉克语": "TG",
80
+ "阿塞拜疆语": "AZ",
81
+ "格鲁吉亚语": "KA",
82
+ "亚美尼亚语": "HY",
83
+ "白俄罗斯语": "BE",
84
+ "卢森堡语": "LB",
85
+ "马耳他语": "MT",
86
+ "毛利语": "MI",
87
+ "南非荷兰语": "AF",
88
+ "祖鲁语": "ZU",
89
+ "科萨语": "XH",
90
+ "约鲁巴语": "YO",
91
+ "豪萨语": "HA",
92
+ "伊博语": "IG",
93
+ "阿姆哈拉语": "AM",
94
+ "奥罗莫语": "OM",
95
+ "北索托语": "NSO",
96
+ "尼扬贾语": "NY",
97
+ "修纳语": "SN",
98
+ "索马里语": "SO",
99
+ "卢干达语": "LG",
100
+ "林加拉语": "LN",
101
+ "卢奥语": "LUO",
102
+ "坎巴语": "KAM",
103
+ "翁本杜语": "UMB",
104
+ "富拉语": "FF",
105
+ "沃洛夫语": "WO",
106
+ "中库尔德语": "CKB",
107
+ "宿务语": "CEB",
108
+ "佛得角克里奥尔语": "KEA",
109
+ "蒙古语": "MN",
110
+ "爪哇语": "JV",
111
+ }
112
+
113
+
114
+ def build_language_choice_items() -> list[tuple[str, str]]:
115
+ return [("不指定", ""), *[(name, code) for name, code in SUPPORTED_LANGUAGE_CODE_BY_NAME.items()]]
apps/gradio/service.py ADDED
@@ -0,0 +1,773 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import shutil
4
+ import sys
5
+ import threading
6
+ import time
7
+ import uuid
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Any, Literal
11
+
12
+ REPO_ROOT = Path(__file__).resolve().parents[2]
13
+ SRC_ROOT = REPO_ROOT / "src"
14
+
15
+ for import_root in (REPO_ROOT, SRC_ROOT):
16
+ import_root_str = str(import_root)
17
+ if import_root_str not in sys.path:
18
+ sys.path.insert(0, import_root_str)
19
+
20
+ import soundfile as sf # noqa: E402
21
+ import torch # noqa: E402
22
+ from loguru import logger # noqa: E402
23
+
24
+ from apps.gradio.constants import ( # noqa: E402
25
+ DEFAULT_EXECUTION_MODE,
26
+ DEFAULT_GUIDANCE_SCALE,
27
+ DEFAULT_HOST,
28
+ DEFAULT_MAX_GENERATE_LENGTH,
29
+ DEFAULT_NUM_STEPS,
30
+ DEFAULT_ODE_METHOD,
31
+ DEFAULT_OUTPUT_DIR,
32
+ DEFAULT_OUTPUT_RETENTION,
33
+ DEFAULT_PORT,
34
+ DEFAULT_PRECISION,
35
+ DEFAULT_PROMPT_MAPPING_FILE,
36
+ DEFAULT_PROMPT_NAME,
37
+ DEFAULT_PROMPT_NONE,
38
+ DEFAULT_PROMPT_SOURCE_DIR,
39
+ DEFAULT_PROMPTS_DIR,
40
+ DEFAULT_SEED,
41
+ DEFAULT_SPEAKER_SCALE,
42
+ DEFAULT_WARMUP_TEXT,
43
+ PROMPT_AUDIO_SUFFIXES,
44
+ )
45
+ from apps.gradio.languages import ( # noqa: E402
46
+ SUPPORTED_LANGUAGE_CODE_BY_NAME,
47
+ build_language_choice_items,
48
+ )
49
+ from dots_tts.runtime import DotsTtsRuntime # noqa: E402
50
+ from dots_tts.utils.util import seed_everything # noqa: E402
51
+
52
+ ExecutionMode = Literal["generate", "generate_stream"]
53
+ GRADIO_SYNTHESIS_MODE_CHOICES = (
54
+ ("tts", "tts"),
55
+ ("instruct_tts", "instruction_tts"),
56
+ ("instruct_tts_general", "text_to_audio"),
57
+ )
58
+ GRADIO_SYNTHESIS_MODE_TEMPLATE_NAMES = tuple(
59
+ value for _, value in GRADIO_SYNTHESIS_MODE_CHOICES
60
+ )
61
+
62
+
63
+ @dataclass(frozen=True)
64
+ class PromptPreset:
65
+ name: str
66
+ audio_path: str
67
+ prompt_text: str
68
+
69
+
70
+ def _is_prompt_asset(path: Path) -> bool:
71
+ return path.is_file() and (
72
+ path.name == "prompt_text" or path.suffix.lower() in PROMPT_AUDIO_SUFFIXES
73
+ )
74
+
75
+
76
+ def sync_default_prompt_library(
77
+ source_dir: Path = DEFAULT_PROMPT_SOURCE_DIR,
78
+ target_dir: Path = DEFAULT_PROMPTS_DIR,
79
+ ) -> None:
80
+ source_dir = Path(source_dir)
81
+ if not source_dir.is_dir():
82
+ logger.info(
83
+ "Prompt library sync skipped: source_dir={} does not exist.",
84
+ source_dir,
85
+ )
86
+ return
87
+
88
+ target_dir = Path(target_dir)
89
+ target_dir.mkdir(parents=True, exist_ok=True)
90
+ logger.info(
91
+ "Prompt library sync started: source_dir={} target_dir={}",
92
+ source_dir,
93
+ target_dir,
94
+ )
95
+
96
+ source_assets = {
97
+ asset.name: asset for asset in sorted(source_dir.iterdir()) if _is_prompt_asset(asset)
98
+ }
99
+ copied_count = 0
100
+ for asset_name, source_asset in source_assets.items():
101
+ target_asset = target_dir / asset_name
102
+ if (
103
+ not target_asset.exists()
104
+ or target_asset.stat().st_size != source_asset.stat().st_size
105
+ or target_asset.stat().st_mtime_ns != source_asset.stat().st_mtime_ns
106
+ ):
107
+ shutil.copy2(source_asset, target_asset)
108
+ copied_count += 1
109
+
110
+ removed_count = 0
111
+ for target_asset in sorted(target_dir.iterdir()):
112
+ if _is_prompt_asset(target_asset) and target_asset.name not in source_assets:
113
+ target_asset.unlink(missing_ok=True)
114
+ removed_count += 1
115
+ logger.info(
116
+ "Prompt library sync completed: copied_assets={} removed_assets={} "
117
+ "available_assets={}",
118
+ copied_count,
119
+ removed_count,
120
+ len(source_assets),
121
+ )
122
+
123
+
124
+ def _load_prompt_text_map(mapping_file: Path) -> dict[str, str]:
125
+ if not mapping_file.is_file():
126
+ return {}
127
+
128
+ prompt_text_map: dict[str, str] = {}
129
+ with mapping_file.open(encoding="utf-8") as file_obj:
130
+ for raw_line in file_obj:
131
+ line = raw_line.strip()
132
+ if not line or line.startswith("#") or "|" not in line:
133
+ continue
134
+ name, text = line.split("|", 1)
135
+ prompt_text_map[name.strip()] = text.strip()
136
+ return prompt_text_map
137
+
138
+
139
+ def discover_prompt_presets(
140
+ prompts_dir: Path = DEFAULT_PROMPTS_DIR,
141
+ mapping_file: Path = DEFAULT_PROMPT_MAPPING_FILE,
142
+ ) -> tuple[PromptPreset, ...]:
143
+ prompts_dir = Path(prompts_dir)
144
+ if not prompts_dir.is_dir():
145
+ return ()
146
+
147
+ prompt_text_map = _load_prompt_text_map(Path(mapping_file))
148
+ prompt_audio_paths = [
149
+ audio_path
150
+ for audio_path in sorted(prompts_dir.iterdir(), key=lambda path: (path.stem == "child", path.stem))
151
+ if audio_path.is_file() and audio_path.suffix.lower() in PROMPT_AUDIO_SUFFIXES
152
+ ]
153
+ return tuple(
154
+ PromptPreset(
155
+ name=audio_path.stem,
156
+ audio_path=str(audio_path.resolve()),
157
+ prompt_text=prompt_text_map.get(audio_path.stem, ""),
158
+ )
159
+ for audio_path in prompt_audio_paths
160
+ )
161
+
162
+
163
+ def build_prompt_choice_items(
164
+ prompt_presets: tuple[PromptPreset, ...],
165
+ ) -> list[tuple[str, str]]:
166
+ return [("No Preset", DEFAULT_PROMPT_NONE), *[(preset.name, preset.name) for preset in prompt_presets]]
167
+
168
+
169
+ def resolve_default_prompt_selection(
170
+ prompt_presets: tuple[PromptPreset, ...],
171
+ default_prompt_name: str = DEFAULT_PROMPT_NAME,
172
+ ) -> tuple[str, str | None, str]:
173
+ if not prompt_presets:
174
+ return DEFAULT_PROMPT_NONE, None, ""
175
+
176
+ preset_by_name = {preset.name: preset for preset in prompt_presets}
177
+ selected_name = default_prompt_name if default_prompt_name in preset_by_name else prompt_presets[0].name
178
+ selected_preset = preset_by_name[selected_name]
179
+ return selected_name, selected_preset.audio_path, selected_preset.prompt_text
180
+
181
+
182
+ def resolve_prompt_selection(
183
+ prompt_name: str,
184
+ prompt_presets: tuple[PromptPreset, ...],
185
+ ) -> tuple[str | None, str]:
186
+ if prompt_name == DEFAULT_PROMPT_NONE:
187
+ return None, ""
188
+
189
+ for preset in prompt_presets:
190
+ if preset.name == prompt_name:
191
+ return preset.audio_path, preset.prompt_text
192
+ return None, ""
193
+
194
+
195
+ def discover_local_model_choices(repo_root: Path = REPO_ROOT) -> list[str]:
196
+ model_root = Path(repo_root) / "pretrained_models"
197
+ if not model_root.is_dir():
198
+ return []
199
+ return sorted(
200
+ path.relative_to(repo_root).as_posix()
201
+ for path in model_root.glob("**/model")
202
+ if path.is_dir()
203
+ )
204
+
205
+
206
+ def resolve_model_name_or_path(model_name_or_path: str, repo_root: Path = REPO_ROOT) -> str:
207
+ normalized = model_name_or_path.strip()
208
+ if not normalized:
209
+ raise ValueError("model_name_or_path 不能为空。")
210
+
211
+ direct_path = Path(normalized).expanduser()
212
+ if direct_path.exists():
213
+ return str(direct_path.resolve())
214
+
215
+ repo_relative_path = Path(repo_root) / normalized
216
+ if repo_relative_path.exists():
217
+ return str(repo_relative_path.resolve())
218
+
219
+ return normalized
220
+
221
+
222
+ def default_model_name_or_path(repo_root: Path = REPO_ROOT) -> str:
223
+ discovered = discover_local_model_choices(repo_root=repo_root)
224
+ if not discovered:
225
+ return ""
226
+ return discovered[0]
227
+
228
+
229
+ @dataclass(frozen=True)
230
+ class GradioAppConfig:
231
+ host: str
232
+ port: int
233
+ execution_mode: ExecutionMode
234
+ precision: str
235
+ optimize: bool
236
+ output_dir: Path
237
+ prompts_dir: Path
238
+ output_retention_count: int
239
+ max_generate_length: int
240
+ default_model_name_or_path: str
241
+ prompt_presets: tuple[PromptPreset, ...]
242
+ default_prompt_name: str
243
+ default_prompt_audio_path: str | None
244
+ default_prompt_text: str
245
+ default_precision: str
246
+ default_num_steps: int
247
+ default_guidance_scale: float
248
+ default_speaker_scale: float
249
+ default_max_generate_length: int
250
+ local_model_choices: tuple[str, ...]
251
+ repo_root: Path = REPO_ROOT
252
+
253
+
254
+ def build_gradio_app_config(
255
+ *,
256
+ host: str = DEFAULT_HOST,
257
+ port: int = DEFAULT_PORT,
258
+ execution_mode: ExecutionMode = DEFAULT_EXECUTION_MODE,
259
+ precision: str = DEFAULT_PRECISION,
260
+ optimize: bool = False,
261
+ output_dir: Path = DEFAULT_OUTPUT_DIR,
262
+ output_retention_count: int = DEFAULT_OUTPUT_RETENTION,
263
+ max_generate_length: int = DEFAULT_MAX_GENERATE_LENGTH,
264
+ model_name_or_path: str | None = None,
265
+ default_prompt_name: str = DEFAULT_PROMPT_NAME,
266
+ default_precision: str = DEFAULT_PRECISION,
267
+ default_num_steps: int = DEFAULT_NUM_STEPS,
268
+ default_guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
269
+ default_speaker_scale: float = DEFAULT_SPEAKER_SCALE,
270
+ default_max_generate_length: int = DEFAULT_MAX_GENERATE_LENGTH,
271
+ repo_root: Path = REPO_ROOT,
272
+ prompts_dir: Path = DEFAULT_PROMPTS_DIR,
273
+ prompt_source_dir: Path = DEFAULT_PROMPT_SOURCE_DIR,
274
+ ) -> GradioAppConfig:
275
+ sync_default_prompt_library(
276
+ source_dir=prompt_source_dir,
277
+ target_dir=prompts_dir,
278
+ )
279
+ discovered_models = discover_local_model_choices(repo_root=repo_root)
280
+ prompt_presets = discover_prompt_presets(
281
+ prompts_dir=prompts_dir,
282
+ mapping_file=prompts_dir / "prompt_text",
283
+ )
284
+ resolved_default_prompt_name, default_prompt_audio_path, default_prompt_text = (
285
+ resolve_default_prompt_selection(
286
+ prompt_presets,
287
+ default_prompt_name=default_prompt_name,
288
+ )
289
+ )
290
+ selected_model_name_or_path = (
291
+ model_name_or_path.strip()
292
+ if model_name_or_path is not None
293
+ else default_model_name_or_path(repo_root=repo_root)
294
+ )
295
+ if not selected_model_name_or_path:
296
+ raise ValueError("No default model found. Please pass --model-name-or-path.")
297
+ if execution_mode not in ("generate", "generate_stream"):
298
+ raise ValueError(f"Unsupported execution_mode: {execution_mode}")
299
+ resolved_max_generate_length = int(max_generate_length)
300
+ if resolved_max_generate_length <= 0:
301
+ raise ValueError("max_generate_length must be positive.")
302
+ resolved_precision = precision.strip() or DEFAULT_PRECISION
303
+ logger.info(
304
+ "Gradio app config prepared: host={} port={} output_dir={} "
305
+ "output_retention_count={} max_generate_length={} execution_mode={} precision={} optimize={} "
306
+ "default_model_name_or_path={} prompt_preset_count={} language_count={} local_model_choice_count={}",
307
+ host,
308
+ port,
309
+ output_dir,
310
+ output_retention_count,
311
+ resolved_max_generate_length,
312
+ execution_mode,
313
+ resolved_precision,
314
+ bool(optimize),
315
+ selected_model_name_or_path,
316
+ len(prompt_presets),
317
+ len(SUPPORTED_LANGUAGE_CODE_BY_NAME),
318
+ len(discovered_models),
319
+ )
320
+ return GradioAppConfig(
321
+ host=host,
322
+ port=int(port),
323
+ execution_mode=execution_mode,
324
+ precision=resolved_precision,
325
+ optimize=bool(optimize),
326
+ output_dir=Path(output_dir),
327
+ prompts_dir=Path(prompts_dir),
328
+ output_retention_count=int(output_retention_count),
329
+ max_generate_length=resolved_max_generate_length,
330
+ default_model_name_or_path=selected_model_name_or_path,
331
+ prompt_presets=prompt_presets,
332
+ default_prompt_name=resolved_default_prompt_name,
333
+ default_prompt_audio_path=default_prompt_audio_path,
334
+ default_prompt_text=default_prompt_text,
335
+ default_precision=default_precision,
336
+ default_num_steps=int(default_num_steps),
337
+ default_guidance_scale=float(default_guidance_scale),
338
+ default_speaker_scale=float(default_speaker_scale),
339
+ default_max_generate_length=int(default_max_generate_length),
340
+ local_model_choices=tuple(discovered_models),
341
+ repo_root=repo_root,
342
+ )
343
+
344
+
345
+ @dataclass(frozen=True)
346
+ class SynthesisRequest:
347
+ model_name_or_path: str
348
+ text: str
349
+ prompt_audio_path: str | None = None
350
+ prompt_text: str | None = None
351
+ execution_mode: ExecutionMode = DEFAULT_EXECUTION_MODE
352
+ template_name: str = "tts"
353
+ language: str | None = None
354
+ ode_method: str = DEFAULT_ODE_METHOD
355
+ num_steps: int = DEFAULT_NUM_STEPS
356
+ guidance_scale: float = DEFAULT_GUIDANCE_SCALE
357
+ speaker_scale: float = DEFAULT_SPEAKER_SCALE
358
+ normalize_text: bool = False
359
+ seed: int = DEFAULT_SEED
360
+
361
+
362
+ @dataclass(frozen=True)
363
+ class SynthesisResult:
364
+ audio_path: str
365
+ metrics: dict[str, Any]
366
+ status: str
367
+
368
+
369
+ class GradioAppService:
370
+ def __init__(self, config: GradioAppConfig):
371
+ self.config = config
372
+ self.config.output_dir.mkdir(parents=True, exist_ok=True)
373
+ self._lock = threading.Lock()
374
+ self._runtime: DotsTtsRuntime | None = None
375
+ self._runtime_model_name_or_path: str | None = None
376
+ logger.info(
377
+ "Gradio service initialized: output_dir={} default_model_name_or_path={} "
378
+ "output_retention_count={} max_generate_length={} execution_mode={} precision={} optimize={}",
379
+ self.config.output_dir,
380
+ self.config.default_model_name_or_path,
381
+ self.config.output_retention_count,
382
+ self.config.max_generate_length,
383
+ self.config.execution_mode,
384
+ self.config.precision,
385
+ self.config.optimize,
386
+ )
387
+
388
+ def metadata(self) -> dict[str, Any]:
389
+ return {
390
+ "repo_root": str(self.config.repo_root),
391
+ "default_model_name_or_path": self.config.default_model_name_or_path,
392
+ "local_model_choices": list(self.config.local_model_choices),
393
+ "prompts_dir": str(self.config.prompts_dir),
394
+ "prompt_preset_names": [preset.name for preset in self.config.prompt_presets],
395
+ "default_prompt_name": self.config.default_prompt_name,
396
+ "output_dir": str(self.config.output_dir),
397
+ "output_retention_count": self.config.output_retention_count,
398
+ "configured_max_generate_length": self.config.max_generate_length,
399
+ "configured_execution_mode": self.config.execution_mode,
400
+ "configured_precision": self.config.precision,
401
+ "optimize": self.config.optimize,
402
+ "loaded_model_name_or_path": self._runtime_model_name_or_path,
403
+ "loaded_max_generate_length": (
404
+ self.config.max_generate_length if self._runtime is not None else None
405
+ ),
406
+ "loaded_precision": (
407
+ self.config.precision if self._runtime is not None else None
408
+ ),
409
+ "model_loaded": self._runtime is not None,
410
+ "host": self.config.host,
411
+ "port": self.config.port,
412
+ "default_precision": self.config.default_precision,
413
+ "default_num_steps": self.config.default_num_steps,
414
+ "default_guidance_scale": self.config.default_guidance_scale,
415
+ "default_speaker_scale": self.config.default_speaker_scale,
416
+ "default_max_generate_length": self.config.default_max_generate_length,
417
+ "supported_languages": build_language_choice_items()[1:],
418
+ "supported_template_names": list(GRADIO_SYNTHESIS_MODE_TEMPLATE_NAMES),
419
+ }
420
+
421
+ def _get_runtime(
422
+ self,
423
+ model_name_or_path: str,
424
+ ) -> tuple[DotsTtsRuntime, str]:
425
+ resolved_model_name_or_path = resolve_model_name_or_path(
426
+ model_name_or_path,
427
+ repo_root=self.config.repo_root,
428
+ )
429
+ if (
430
+ self._runtime is None
431
+ or self._runtime_model_name_or_path != resolved_model_name_or_path
432
+ ):
433
+ logger.info(
434
+ "Gradio runtime cache miss: requested_model={} resolved_model={} "
435
+ "max_generate_length={} execution_mode={} precision={} optimize={}",
436
+ model_name_or_path,
437
+ resolved_model_name_or_path,
438
+ self.config.max_generate_length,
439
+ self.config.execution_mode,
440
+ self.config.precision,
441
+ self.config.optimize,
442
+ )
443
+ self._runtime = DotsTtsRuntime.from_pretrained(
444
+ resolved_model_name_or_path,
445
+ precision=self.config.precision,
446
+ optimize=self.config.optimize,
447
+ max_generate_length=self.config.max_generate_length,
448
+ )
449
+ self._runtime_model_name_or_path = resolved_model_name_or_path
450
+ else:
451
+ logger.info(
452
+ "Gradio runtime cache hit: requested_model={} resolved_model={} "
453
+ "max_generate_length={} execution_mode={} precision={} optimize={}",
454
+ model_name_or_path,
455
+ resolved_model_name_or_path,
456
+ self.config.max_generate_length,
457
+ self.config.execution_mode,
458
+ self.config.precision,
459
+ self.config.optimize,
460
+ )
461
+ return self._runtime, resolved_model_name_or_path
462
+
463
+ def _build_stream_request_id(
464
+ self,
465
+ runtime: DotsTtsRuntime,
466
+ request: SynthesisRequest,
467
+ ) -> str:
468
+ normalized_text, normalized_language = runtime._process_text( # noqa: SLF001
469
+ request.text,
470
+ language=request.language,
471
+ normalize=request.normalize_text,
472
+ )
473
+ normalized_prompt_text = runtime._process_prompt_text( # noqa: SLF001
474
+ request.prompt_text,
475
+ language=normalized_language,
476
+ )
477
+ if normalized_language is not None and not normalized_prompt_text:
478
+ from dots_tts.utils.text import attach_language_tag # noqa: PLC0415
479
+
480
+ normalized_text = attach_language_tag(
481
+ normalized_text,
482
+ normalized_language,
483
+ )
484
+ request_id_kwargs = {
485
+ "text": normalized_text,
486
+ "prompt_audio_path": request.prompt_audio_path,
487
+ "prompt_text": normalized_prompt_text,
488
+ "template_name": request.template_name,
489
+ }
490
+ if normalized_language is not None:
491
+ request_id_kwargs["language"] = normalized_language
492
+ return runtime._build_request_id( # noqa: SLF001
493
+ **request_id_kwargs,
494
+ )
495
+
496
+ @staticmethod
497
+ def _build_runtime_generate_kwargs(request: SynthesisRequest) -> dict[str, Any]:
498
+ runtime_kwargs: dict[str, Any] = {
499
+ "text": request.text,
500
+ "prompt_audio_path": request.prompt_audio_path,
501
+ "prompt_text": request.prompt_text,
502
+ "template_name": request.template_name,
503
+ "ode_method": request.ode_method,
504
+ "num_steps": request.num_steps,
505
+ "guidance_scale": request.guidance_scale,
506
+ "speaker_scale": request.speaker_scale,
507
+ "normalize_text": request.normalize_text,
508
+ }
509
+ if request.language is not None:
510
+ runtime_kwargs["language"] = request.language
511
+ return runtime_kwargs
512
+
513
+ def _run_stream_generation(
514
+ self,
515
+ runtime: DotsTtsRuntime,
516
+ request: SynthesisRequest,
517
+ ) -> dict[str, Any]:
518
+ start_time = time.time()
519
+ chunks = [
520
+ chunk.detach().float().cpu()
521
+ for chunk in runtime.generate_stream(
522
+ **self._build_runtime_generate_kwargs(request)
523
+ )
524
+ ]
525
+ if not chunks:
526
+ raise ValueError("流式生成未返回任何音频块。")
527
+
528
+ audio = torch.cat(chunks, dim=-1)
529
+ elapsed_seconds = time.time() - start_time
530
+ audio_seconds = audio.shape[-1] / runtime.sample_rate
531
+ rtf = elapsed_seconds / audio_seconds if audio_seconds > 0 else float("inf")
532
+ return {
533
+ "fid": self._build_stream_request_id(runtime, request),
534
+ "audio": audio,
535
+ "sample_rate": runtime.sample_rate,
536
+ "time_used": elapsed_seconds,
537
+ "rtf": rtf,
538
+ "chunk_count": len(chunks),
539
+ }
540
+
541
+ def warmup(self, text: str | None = None) -> dict[str, Any]:
542
+ warmup_text = (text or "").strip() or DEFAULT_WARMUP_TEXT.strip()
543
+ if not warmup_text:
544
+ raise ValueError("DEFAULT_WARMUP_TEXT 不能为空。")
545
+
546
+ with self._lock:
547
+ logger.info(
548
+ "Gradio warmup requested: default_model_name_or_path={} execution_mode={} precision={} optimize={} seed={}",
549
+ self.config.default_model_name_or_path,
550
+ self.config.execution_mode,
551
+ self.config.precision,
552
+ self.config.optimize,
553
+ DEFAULT_SEED,
554
+ )
555
+ try:
556
+ seed_everything(DEFAULT_SEED)
557
+ runtime, resolved_model_name_or_path = self._get_runtime(
558
+ self.config.default_model_name_or_path,
559
+ )
560
+ warmup_request = SynthesisRequest(
561
+ model_name_or_path=self.config.default_model_name_or_path,
562
+ text=warmup_text,
563
+ execution_mode=self.config.execution_mode,
564
+ template_name="tts",
565
+ ode_method=DEFAULT_ODE_METHOD,
566
+ num_steps=self.config.default_num_steps,
567
+ guidance_scale=self.config.default_guidance_scale,
568
+ speaker_scale=self.config.default_speaker_scale,
569
+ normalize_text=False,
570
+ seed=DEFAULT_SEED,
571
+ )
572
+ request_id = self._build_stream_request_id(runtime, warmup_request)
573
+ if self.config.execution_mode == "generate_stream":
574
+ result = self._run_stream_generation(runtime, warmup_request)
575
+ else:
576
+ start_time = time.time()
577
+ result = runtime.generate(**self._build_runtime_generate_kwargs(warmup_request))
578
+ result["time_used"] = time.time() - start_time
579
+ result["chunk_count"] = 1
580
+ audio_samples = int(result["audio"].shape[-1])
581
+ except Exception:
582
+ logger.exception(
583
+ "Gradio warmup failed: default_model_name_or_path={}",
584
+ self.config.default_model_name_or_path,
585
+ )
586
+ raise
587
+ audio_seconds = audio_samples / runtime.sample_rate
588
+ metrics = {
589
+ "request_id": request_id,
590
+ "execution_mode": self.config.execution_mode,
591
+ "chunk_count": int(result["chunk_count"]),
592
+ "resolved_model_name_or_path": resolved_model_name_or_path,
593
+ "sample_rate": runtime.sample_rate,
594
+ "elapsed_seconds": round(float(result["time_used"]), 3),
595
+ "audio_seconds": round(float(audio_seconds), 3),
596
+ "rtf": round(float(result["rtf"]), 4),
597
+ "seed": DEFAULT_SEED,
598
+ "text": warmup_text,
599
+ }
600
+ logger.info(
601
+ "Gradio warmup ready: request_id={} execution_mode={} resolved_model_name_or_path={}",
602
+ metrics["request_id"],
603
+ metrics["execution_mode"],
604
+ metrics["resolved_model_name_or_path"],
605
+ )
606
+ return metrics
607
+
608
+ def _normalize_request(self, request: SynthesisRequest) -> SynthesisRequest:
609
+ normalized_text = request.text.strip()
610
+ if not normalized_text:
611
+ raise ValueError("text 不能为空。")
612
+
613
+ normalized_prompt_audio_path = request.prompt_audio_path or None
614
+ normalized_prompt_text = (request.prompt_text or "").strip() or None
615
+ if normalized_prompt_text and not normalized_prompt_audio_path:
616
+ raise ValueError("prompt_text requires prompt_audio_path.")
617
+ normalized_template_name = request.template_name.strip() or "tts"
618
+ if normalized_template_name not in GRADIO_SYNTHESIS_MODE_TEMPLATE_NAMES:
619
+ raise ValueError(
620
+ f"Unsupported template_name={normalized_template_name!r}. "
621
+ f"Expected one of {list(GRADIO_SYNTHESIS_MODE_TEMPLATE_NAMES)}."
622
+ )
623
+ normalized_language = (request.language or "").strip() or None
624
+ supported_language_codes = set(SUPPORTED_LANGUAGE_CODE_BY_NAME.values())
625
+ if (
626
+ normalized_language is not None
627
+ and normalized_language not in supported_language_codes
628
+ ):
629
+ raise ValueError(
630
+ f"Unsupported language={normalized_language!r}. "
631
+ f"Expected one of {sorted(supported_language_codes)}."
632
+ )
633
+
634
+ resolved_seed = int(request.seed)
635
+ return SynthesisRequest(
636
+ model_name_or_path=request.model_name_or_path.strip(),
637
+ text=normalized_text,
638
+ prompt_audio_path=normalized_prompt_audio_path,
639
+ prompt_text=normalized_prompt_text,
640
+ execution_mode=request.execution_mode,
641
+ template_name=normalized_template_name,
642
+ language=normalized_language,
643
+ ode_method=request.ode_method.strip() or DEFAULT_ODE_METHOD,
644
+ num_steps=int(request.num_steps),
645
+ guidance_scale=float(request.guidance_scale),
646
+ speaker_scale=float(request.speaker_scale),
647
+ normalize_text=bool(request.normalize_text),
648
+ seed=resolved_seed,
649
+ )
650
+
651
+ def _build_output_path(self) -> Path:
652
+ output_name = f"{time.strftime('%Y%m%d-%H%M%S')}-{uuid.uuid4().hex[:8]}.wav"
653
+ return self.config.output_dir / output_name
654
+
655
+ def _cleanup_outputs(self) -> None:
656
+ if self.config.output_retention_count <= 0:
657
+ return
658
+
659
+ wav_files = sorted(
660
+ self.config.output_dir.glob("*.wav"),
661
+ key=lambda path: path.stat().st_mtime,
662
+ reverse=True,
663
+ )
664
+ removed_count = 0
665
+ for stale_file in wav_files[self.config.output_retention_count :]:
666
+ stale_file.unlink(missing_ok=True)
667
+ removed_count += 1
668
+ if removed_count > 0:
669
+ logger.info(
670
+ "Gradio output cleanup completed: removed_files={} retention_limit={}",
671
+ removed_count,
672
+ self.config.output_retention_count,
673
+ )
674
+
675
+ @staticmethod
676
+ def _waveform_to_numpy(audio: torch.Tensor):
677
+ waveform = audio.detach().float().cpu().squeeze()
678
+ if waveform.ndim == 0:
679
+ raise ValueError("生成音频为空。")
680
+ return waveform.numpy()
681
+
682
+ def _write_audio(self, audio: torch.Tensor, sample_rate: int) -> str:
683
+ output_path = self._build_output_path()
684
+ logger.info(
685
+ "Writing synthesized audio: output_path={} sample_rate={} samples={}",
686
+ output_path,
687
+ sample_rate,
688
+ audio.shape[-1],
689
+ )
690
+ sf.write(output_path, self._waveform_to_numpy(audio), sample_rate)
691
+ self._cleanup_outputs()
692
+ logger.info("Synthesized audio written: output_path={}", output_path)
693
+ return str(output_path)
694
+
695
+ def generate(self, request: SynthesisRequest) -> SynthesisResult:
696
+ normalized_request = self._normalize_request(request)
697
+
698
+ with self._lock:
699
+ try:
700
+ seed_everything(normalized_request.seed)
701
+ runtime, resolved_model_name_or_path = self._get_runtime(
702
+ normalized_request.model_name_or_path,
703
+ )
704
+ logger.info(
705
+ "Gradio request accepted: resolved_model_name_or_path={} execution_mode={} seed={}",
706
+ resolved_model_name_or_path,
707
+ normalized_request.execution_mode,
708
+ normalized_request.seed,
709
+ )
710
+ if normalized_request.execution_mode == "generate_stream":
711
+ result = self._run_stream_generation(runtime, normalized_request)
712
+ else:
713
+ result = runtime.generate(
714
+ **self._build_runtime_generate_kwargs(normalized_request)
715
+ )
716
+ result["chunk_count"] = 1
717
+ audio_path = self._write_audio(result["audio"], result["sample_rate"])
718
+ except Exception:
719
+ logger.exception(
720
+ "Gradio request failed: model_name_or_path={} execution_mode={} text_len={} has_prompt_audio={} has_prompt_text={} template_name={} language={} "
721
+ "precision={} ode_method={} num_steps={} guidance_scale={} speaker_scale={} max_generate_length={} "
722
+ "normalize_text={} seed={}",
723
+ normalized_request.model_name_or_path,
724
+ normalized_request.execution_mode,
725
+ len(normalized_request.text),
726
+ bool(normalized_request.prompt_audio_path),
727
+ bool(normalized_request.prompt_text),
728
+ normalized_request.template_name,
729
+ normalized_request.language,
730
+ self.config.precision,
731
+ normalized_request.ode_method,
732
+ normalized_request.num_steps,
733
+ normalized_request.guidance_scale,
734
+ normalized_request.speaker_scale,
735
+ self.config.max_generate_length,
736
+ normalized_request.normalize_text,
737
+ normalized_request.seed,
738
+ )
739
+ raise
740
+ audio_seconds = result["audio"].shape[-1] / result["sample_rate"]
741
+ metrics = {
742
+ "request_id": result["fid"],
743
+ "execution_mode": normalized_request.execution_mode,
744
+ "chunk_count": int(result["chunk_count"]),
745
+ "template_name": normalized_request.template_name,
746
+ "language": normalized_request.language,
747
+ "resolved_model_name_or_path": resolved_model_name_or_path,
748
+ "sample_rate": result["sample_rate"],
749
+ "elapsed_seconds": round(float(result["time_used"]), 3),
750
+ "audio_seconds": round(float(audio_seconds), 3),
751
+ "rtf": round(float(result["rtf"]), 4),
752
+ "seed": normalized_request.seed,
753
+ "output_path": audio_path,
754
+ }
755
+ logger.info(
756
+ "Gradio request output ready: request_id={} execution_mode={} resolved_model_name_or_path={} output_path={}",
757
+ metrics["request_id"],
758
+ metrics["execution_mode"],
759
+ metrics["resolved_model_name_or_path"],
760
+ metrics["output_path"],
761
+ )
762
+ status = (
763
+ f"完成:{Path(audio_path).name} | "
764
+ f"模式 {metrics['execution_mode']} | "
765
+ f"耗时 {metrics['elapsed_seconds']}s | "
766
+ f"音频 {metrics['audio_seconds']}s | "
767
+ f"RTF {metrics['rtf']}"
768
+ )
769
+ return SynthesisResult(
770
+ audio_path=audio_path,
771
+ metrics=metrics,
772
+ status=status,
773
+ )
configs/dots_tts.yaml ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ train_data:
2
+ train_audio_sample_rate: 48000
3
+ audio_samples_per_llm_token: 7680
4
+ sources:
5
+ - name: ljspeech_basic
6
+ weight: 1.0
7
+ pipeline: basic
8
+ adapter:
9
+ class_name: JsonlManifestSourceAdapter
10
+ params:
11
+ manifest_path: downloaded_data/ljspeech_48khz_manifest_train.jsonl
12
+ shuffle: true
13
+ - name: ljspeech_interleave
14
+ weight: 1.0
15
+ pipeline: interleave
16
+ adapter:
17
+ class_name: JsonlManifestSourceAdapter
18
+ params:
19
+ manifest_path: downloaded_data/ljspeech_48khz_manifest_train.jsonl
20
+ shuffle: true
21
+ # append other sources here if need
22
+ num_tokens_per_epoch: 2000000
23
+ num_workers: 20
24
+ pin_memory: true
25
+ max_audio_seconds_in_batch: 30.0
26
+ max_text_tokens_in_batch: 2048
27
+ max_samples_per_batch: null
28
+ bucketing_pool_size: 100
29
+ val_data:
30
+ train_audio_sample_rate: 48000
31
+ audio_samples_per_llm_token: 7680
32
+ sources:
33
+ - name: ljspeech_valid_basic
34
+ weight: 1.0
35
+ adapter:
36
+ class_name: JsonlManifestSourceAdapter
37
+ params:
38
+ manifest_path: downloaded_data/ljspeech_48khz_manifest_valid.jsonl
39
+ shuffle: false
40
+ pipeline: basic
41
+ - name: ljspeech_valid_interleave
42
+ weight: 1.0
43
+ pipeline: interleave
44
+ adapter:
45
+ class_name: JsonlManifestSourceAdapter
46
+ params:
47
+ manifest_path: downloaded_data/ljspeech_48khz_manifest_valid.jsonl
48
+ shuffle: false
49
+ pipeline: interleave
50
+ # append other sources here if need
51
+ num_workers: 4
52
+ pin_memory: true
53
+ max_audio_seconds_in_batch: 30.0
54
+ max_text_tokens_in_batch: 2048
55
+ max_samples_per_batch: null
56
+ bucketing_pool_size: 64
57
+ train:
58
+ pretrained_model_path: pretrained_models/pretrain_cpt_decay/latest/model/
59
+ output_dir: debug_train/run_003
60
+ seed: 42
61
+ learning_rate: 1.0e-05
62
+ weight_decay: 0.01
63
+ warmup_steps: 50
64
+ max_train_steps: 500
65
+ gradient_accumulation_steps: 2
66
+ grad_clip_norm: 1
67
+ save_interval: 500
68
+ max_checkpoints_to_keep: 40
69
+ log_interval: 10
70
+ eval_interval: 100
71
+ max_eval_batches: null
72
+ run_eval_on_start: false
73
+ loss:
74
+ ce_weight: 1.0
75
+ fm_weight: 1.0
76
+ eos_weight: 1.0
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ spaces>=0.40.1
2
+ torch>=2.8.0
3
+ torchaudio>=2.8.0
4
+ transformers>=4.57.0
5
+ huggingface-hub>=0.36.0
6
+ gradio>=6.16.0
7
+ loguru>=0.7.3
8
+ langcodes[data]>=3.5.0
9
+ einops>=0.8.1
10
+ librosa>=0.11.0
11
+ soundfile>=0.13.1
12
+ numpy>=2.2.6
13
+ pydantic>=2.12.5,<3
14
+ PyYAML>=6.0.3
15
+ safetensors>=0.8.0rc0
16
+ torchdiffeq>=0.2.5
17
+ tqdm>=4.67.1
18
+ lingua-language-detector>=2.1.1
19
+ WeTextProcessing>=1.0.4
src/dots_tts/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """dots.tts package."""
src/dots_tts/cli.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ from pathlib import Path
5
+
6
+
7
+ def parse_args(argv=None):
8
+ parser = argparse.ArgumentParser(description="dots.tts inference CLI.")
9
+ template_choices = ("tts", "instruction_tts", "text_to_audio", "tts_interleave")
10
+ parser.add_argument(
11
+ "--model-name-or-path",
12
+ required=True,
13
+ help="Local pretrained directory or Hugging Face repo id",
14
+ )
15
+ parser.add_argument(
16
+ "--revision", default=None, help="Optional Hugging Face revision"
17
+ )
18
+ parser.add_argument(
19
+ "--cache-dir", default=None, help="Optional Hugging Face cache dir"
20
+ )
21
+ parser.add_argument("--text", type=str, required=True, help="Input text")
22
+ parser.add_argument("--output", default="output.wav", help="Output wav file path")
23
+ parser.add_argument(
24
+ "--precision", type=str, default="bfloat16", help="Inference precision"
25
+ )
26
+ parser.add_argument(
27
+ "--seed",
28
+ type=int,
29
+ default=42,
30
+ help="Random seed for inference.",
31
+ )
32
+ parser.add_argument(
33
+ "--prompt-audio", type=str, default=None, help="Path to prompt audio"
34
+ )
35
+ parser.add_argument(
36
+ "--prompt-text", type=str, default=None, help="Transcript of prompt audio"
37
+ )
38
+ parser.add_argument(
39
+ "--language",
40
+ type=str,
41
+ default=None,
42
+ help="Language tag mode. Default: none. Supported values: none, auto_detect, or a language code/name such as EN/en/english/chinese.",
43
+ )
44
+ parser.add_argument(
45
+ "--template-name",
46
+ choices=template_choices,
47
+ default=None,
48
+ help="Named template preset for generation.",
49
+ )
50
+ parser.add_argument(
51
+ "--ode-method", type=str, default="euler", help="ODE solver method"
52
+ )
53
+ parser.add_argument(
54
+ "--num-steps", type=int, default=10, help="Diffusion sampling steps"
55
+ )
56
+ parser.add_argument(
57
+ "--guidance-scale",
58
+ type=float,
59
+ default=1.2,
60
+ help="Classifier-free guidance scale",
61
+ )
62
+ parser.add_argument(
63
+ "--speaker-scale",
64
+ type=float,
65
+ default=1.5,
66
+ help="Scale applied to the reference speaker embedding",
67
+ )
68
+ parser.add_argument(
69
+ "--max-generate-length",
70
+ type=int,
71
+ default=500,
72
+ help="Maximum total audio patch count (prompt + generated)",
73
+ )
74
+ parser.add_argument(
75
+ "--normalize-text",
76
+ action="store_true",
77
+ help="Whether to normalize text before inference",
78
+ )
79
+ parser.add_argument(
80
+ "--profile-inference",
81
+ action="store_true",
82
+ help="Collect per-module inference timing statistics",
83
+ )
84
+ return parser.parse_args(argv)
85
+
86
+
87
+ def main(argv=None):
88
+ args = parse_args(argv)
89
+ import soundfile as sf
90
+ from loguru import logger
91
+
92
+ from dots_tts.runtime import DotsTtsRuntime
93
+ from dots_tts.utils.logging import configure_logging
94
+ from dots_tts.utils.util import seed_everything
95
+
96
+ configure_logging()
97
+ seed_everything(args.seed)
98
+ output_path = Path(args.output)
99
+ output_path.parent.mkdir(parents=True, exist_ok=True)
100
+
101
+ logger.info(
102
+ "CLI command started: model={} output={} seed={}",
103
+ args.model_name_or_path,
104
+ output_path,
105
+ args.seed,
106
+ )
107
+
108
+ try:
109
+ runtime = DotsTtsRuntime.from_pretrained(
110
+ args.model_name_or_path,
111
+ revision=args.revision,
112
+ cache_dir=args.cache_dir,
113
+ precision=args.precision,
114
+ max_generate_length=args.max_generate_length,
115
+ )
116
+ result = runtime.generate(
117
+ text=args.text,
118
+ prompt_audio_path=args.prompt_audio,
119
+ prompt_text=args.prompt_text,
120
+ language=args.language,
121
+ template_name=args.template_name,
122
+ ode_method=args.ode_method,
123
+ num_steps=args.num_steps,
124
+ guidance_scale=args.guidance_scale,
125
+ speaker_scale=args.speaker_scale,
126
+ normalize_text=args.normalize_text,
127
+ profile_inference=args.profile_inference,
128
+ )
129
+ sf.write(
130
+ output_path,
131
+ result["audio"].float().cpu().squeeze().numpy(),
132
+ result["sample_rate"],
133
+ )
134
+ except Exception:
135
+ logger.exception(
136
+ "CLI inference failed: model={} output={}",
137
+ args.model_name_or_path,
138
+ output_path,
139
+ )
140
+ raise
141
+
142
+ logger.info(
143
+ "CLI output written: request_id={} output={} sample_rate={} samples={}",
144
+ result["fid"],
145
+ output_path,
146
+ result["sample_rate"],
147
+ int(result["audio"].shape[-1]),
148
+ )
149
+
150
+
151
+ if __name__ == "__main__":
152
+ raise SystemExit(main())
src/dots_tts/config/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Configuration package."""
src/dots_tts/config/app.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+
5
+ import yaml
6
+
7
+ from dots_tts.config.base import StrictConfigBase
8
+ from dots_tts.config.data import DataConfig
9
+ from dots_tts.config.train import TrainConfig
10
+ from dots_tts.models.dots_tts.config import LossConfig
11
+
12
+ DEFAULT_CONFIG_PATH = "configs/dots_tts.yaml"
13
+
14
+
15
+ class AppConfig(StrictConfigBase):
16
+ train_data: DataConfig
17
+ val_data: DataConfig | None = None
18
+ loss: LossConfig
19
+ train: TrainConfig
20
+
21
+ @classmethod
22
+ def from_yaml(cls, config_path: str = DEFAULT_CONFIG_PATH) -> AppConfig:
23
+ with Path(config_path).open(encoding="utf-8") as fin:
24
+ raw_config = yaml.safe_load(fin)
25
+ return cls.model_validate(raw_config)
26
+
27
+
28
+ def load_config(config_path: str = DEFAULT_CONFIG_PATH) -> AppConfig:
29
+ return AppConfig.from_yaml(config_path)
30
+
31
+
32
+ __all__ = ["AppConfig", "DEFAULT_CONFIG_PATH", "load_config"]
src/dots_tts/config/base.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ from pydantic import BaseModel, ConfigDict
6
+
7
+
8
+ class ConfigBase(BaseModel):
9
+ model_config = ConfigDict(
10
+ extra="allow",
11
+ validate_assignment=True,
12
+ arbitrary_types_allowed=True,
13
+ )
14
+
15
+ def get(self, key: str, default=None):
16
+ value = getattr(self, key, default)
17
+ if value is default:
18
+ return value
19
+
20
+ fields_set = self.model_fields_set
21
+ if value is None and key not in fields_set:
22
+ return default
23
+ return value
24
+
25
+ def to_dict(self) -> dict[str, Any]:
26
+ return self.model_dump(exclude_none=True)
27
+
28
+ @classmethod
29
+ def _declared_field_names(cls) -> list[str]:
30
+ return [name for name in cls.model_fields if name != "model_config"]
31
+
32
+ @classmethod
33
+ def _serialize_declared_value(cls, value):
34
+ if isinstance(value, ConfigBase):
35
+ return value.to_declared_dict()
36
+ if isinstance(value, list):
37
+ return [cls._serialize_declared_value(item) for item in value]
38
+ if isinstance(value, tuple):
39
+ return [cls._serialize_declared_value(item) for item in value]
40
+ if isinstance(value, dict):
41
+ return {
42
+ key: cls._serialize_declared_value(item) for key, item in value.items()
43
+ }
44
+ return value
45
+
46
+ def to_declared_dict(self) -> dict[str, Any]:
47
+ data = {}
48
+ for name in self._declared_field_names():
49
+ value = getattr(self, name, None)
50
+ if value is None:
51
+ continue
52
+ data[name] = self._serialize_declared_value(value)
53
+ return data
54
+
55
+
56
+ class StrictConfigBase(ConfigBase):
57
+ model_config = ConfigDict(
58
+ extra="forbid",
59
+ validate_assignment=True,
60
+ arbitrary_types_allowed=True,
61
+ )
62
+
63
+
64
+ __all__ = ["ConfigBase", "StrictConfigBase"]
src/dots_tts/config/data.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Literal
4
+
5
+ from pydantic import Field, model_validator
6
+
7
+ from dots_tts.config.base import StrictConfigBase
8
+
9
+ DEFAULT_SOURCE_ADAPTER_CLASS_NAME = "JsonlManifestSourceAdapter"
10
+
11
+
12
+ class SourceAdapterConfig(StrictConfigBase):
13
+ class_name: Literal["JsonlManifestSourceAdapter"] = (
14
+ DEFAULT_SOURCE_ADAPTER_CLASS_NAME
15
+ )
16
+ params: dict[str, Any] = Field(default_factory=dict)
17
+
18
+
19
+ class DataSourceConfig(StrictConfigBase):
20
+ name: str
21
+ weight: float = Field(default=1.0, gt=0.0)
22
+ pipeline: Literal["basic", "interleave"] = "basic"
23
+ adapter: SourceAdapterConfig = Field(default_factory=SourceAdapterConfig)
24
+
25
+
26
+ class DataConfig(StrictConfigBase):
27
+ sources: list[DataSourceConfig]
28
+ train_audio_sample_rate: int = Field(ge=1)
29
+ audio_samples_per_llm_token: int = Field(ge=1)
30
+ num_tokens_per_epoch: int | None = Field(
31
+ default=None,
32
+ ge=1,
33
+ description="Global token budget across all ranks for one training epoch.",
34
+ )
35
+ num_workers: int = Field(default=0, ge=0)
36
+ pin_memory: bool = False
37
+ prefetch_factor: int = Field(
38
+ default=2,
39
+ ge=1,
40
+ description="Samples prefetched by each DataLoader worker.",
41
+ )
42
+ max_audio_seconds_in_batch: float = Field(gt=0.0)
43
+ max_text_tokens_in_batch: int = Field(ge=1)
44
+ max_samples_per_batch: int | None = Field(default=None, ge=1)
45
+ bucketing_pool_size: int = Field(default=64, ge=1)
46
+
47
+ @model_validator(mode="after")
48
+ def _validate_unique_source_names(self) -> "DataConfig":
49
+ counts: dict[str, int] = {}
50
+ for source in self.sources:
51
+ counts[source.name] = counts.get(source.name, 0) + 1
52
+ duplicated = [name for name, count in counts.items() if count > 1]
53
+ if duplicated:
54
+ raise ValueError(f"Source names must be unique: {duplicated}")
55
+ return self
56
+
57
+
58
+ __all__ = [
59
+ "DEFAULT_SOURCE_ADAPTER_CLASS_NAME",
60
+ "DataConfig",
61
+ "DataSourceConfig",
62
+ "SourceAdapterConfig",
63
+ ]
src/dots_tts/config/train.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pydantic import Field
4
+
5
+ from dots_tts.config.base import StrictConfigBase
6
+
7
+
8
+ class TrainConfig(StrictConfigBase):
9
+ pretrained_model_path: str
10
+ output_dir: str
11
+ seed: int = 42
12
+ learning_rate: float
13
+ cfg_droprate: float = 0.0
14
+ xvec_drop_rate: float = 0.5
15
+ weight_decay: float = 0.01
16
+ warmup_steps: int = 0
17
+ max_train_steps: int
18
+ gradient_accumulation_steps: int = Field(default=1, ge=1)
19
+ grad_clip_norm: float = 1.0
20
+ save_interval: int = Field(default=1000, ge=1)
21
+ max_checkpoints_to_keep: int = 10
22
+ log_interval: int = Field(default=10, ge=1)
23
+ eval_interval: int | None = Field(default=None, ge=1)
24
+ max_eval_batches: int | None = None
25
+ run_eval_on_start: bool = False
26
+
27
+
28
+ __all__ = ["TrainConfig"]
src/dots_tts/data/EXTENSION.md ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Data Source Extension Guide
2
+
3
+ This document answers exactly one question: how to plug a new training data source into the current `dots_tts` data pipeline.
4
+
5
+ If you only need to swap in a different JSONL manifest, no code changes are required. To support a new raw data format, you usually only need to add:
6
+
7
+ - one **source adapter**
8
+ - optionally one **sample pipeline**
9
+
10
+ ## Data flow
11
+
12
+ 1. An **adapter** reads from the raw data source and yields raw samples.
13
+ 2. A **pipeline** turns each raw sample into a training sample (1:1).
14
+ 3. A **multi-source wrapper** handles mixing across sources and resume state.
15
+ 4. `StreamingSampleDataset` / `DataLoader` pulls samples.
16
+ 5. `OnlineBatcher` assembles batches and `PadCollator` performs padding.
17
+
18
+ ## What an adapter must implement
19
+
20
+ Subclass `BaseSourceAdapter`:
21
+
22
+ ```python
23
+ class BaseSourceAdapter(ABC):
24
+ @abstractmethod
25
+ def initial_state(self) -> dict[str, Any]:
26
+ ...
27
+
28
+ @abstractmethod
29
+ def iter_samples(
30
+ self,
31
+ context: SourceContext,
32
+ *,
33
+ state: dict[str, Any] | None = None,
34
+ ) -> Iterable[dict[str, Any]]:
35
+ ...
36
+
37
+ @abstractmethod
38
+ def is_cycle_start_state(self, state: dict[str, Any] | None) -> bool:
39
+ ...
40
+
41
+ # Optional — only required when used under WeightedMultiSourceAdapter,
42
+ # which cycles each finite child source independently. The default
43
+ # implementation raises if your adapter never gets re-cycled.
44
+ def advance_cycle(self, state: dict[str, Any] | None) -> dict[str, Any]:
45
+ ...
46
+ ```
47
+
48
+ Each emitted sample **must** carry these fields:
49
+
50
+ - `fid`
51
+ - `text`
52
+ - `audio`
53
+ - `_adapter_state`
54
+
55
+ Key constraints:
56
+
57
+ - `_adapter_state` must describe **where to resume next**, not the position of the current item.
58
+ - The state must be plain Python data — serializable and recoverable after a restart.
59
+ - If your source needs to be split across workers, use `context.global_worker_id` and `context.global_worker_count` (or subclass `ShardableSourceAdapter` and use its `is_assigned_index` / `shard_items` helpers).
60
+ - If the source will participate in weighted cyclic sampling, you must implement `advance_cycle` and make `is_cycle_start_state` correct — otherwise `WeightedMultiSourceAdapter` cannot detect an empty cycle and will raise.
61
+
62
+ After implementing the adapter, register the class in `dots_tts/data/builders.py::_SOURCE_ADAPTER_CLASSES` so that the YAML config can resolve it by `class_name`.
63
+
64
+ ## What a pipeline must implement
65
+
66
+ Pipelines must subclass `BaseSamplePipeline` and perform a strict **1:1** sample transform.
67
+
68
+ Minimum implementation:
69
+
70
+ ```python
71
+ class MyPipeline(BaseSamplePipeline):
72
+ def process_sample(self, sample: dict) -> dict:
73
+ sample["text"] = str(sample["text"]).strip()
74
+ return sample
75
+ ```
76
+
77
+ Do **not**:
78
+
79
+ - filter samples out
80
+ - expand a single sample into multiple samples
81
+ - assemble batches inside the pipeline
82
+
83
+ `BaseSamplePipeline.__call__` automatically merges the original raw sample (including `_adapter_state` and any extra fields the adapter attached) with whatever your `process_sample` returns. You do not need to copy these fields manually — just return the fields you produced or want to overwrite.
84
+
85
+ To wire a new pipeline into config, also extend `dots_tts/data/builders.py::_build_source_pipeline` so it can be selected by name in YAML.
86
+
87
+ ## How multi-source wrappers affect you
88
+
89
+ There are two wrappers in the current codebase:
90
+
91
+ - `SequentialMultiSourceAdapter` — used for validation. Reads sources in the configured order, exhaustively, once.
92
+ - `WeightedMultiSourceAdapter` — used for training. Draws sources by weight, cycles each child source independently when exhausted.
93
+
94
+ Both wrappers **replace** the `_adapter_state` produced by your child adapter with their own resume state before yielding to the dataset. Even so, the child adapter must still emit its own `_adapter_state` — the wrapper reads it to track where each sub-source has read to.
95
+
96
+ ## Config
97
+
98
+ Each source is configured independently:
99
+
100
+ ```yaml
101
+ train_data:
102
+ sources:
103
+ - name: train_a
104
+ weight: 1.0
105
+ pipeline: basic
106
+ adapter:
107
+ class_name: JsonlManifestSourceAdapter
108
+ params:
109
+ manifest_path: train_a.jsonl
110
+ - name: train_b
111
+ weight: 2.0
112
+ pipeline: interleave
113
+ adapter:
114
+ class_name: JsonlManifestSourceAdapter
115
+ params:
116
+ manifest_path: train_b.jsonl
117
+ ```
118
+
119
+ Constraints:
120
+
121
+ - `sources[].name` must be unique within the same `train_data` / `val_data` block (it is used as a dict key for resume state).
122
+ - `sources[].pipeline` is a per-source setting, not shared across the dataset.
123
+ - All sources must ultimately produce the same training-sample structure, since they feed into the same batcher and collator.
124
+ - `class_name` must match a key registered in `_SOURCE_ADAPTER_CLASSES`; `params` is forwarded verbatim as kwargs to the adapter constructor.
src/dots_tts/data/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Data package."""
src/dots_tts/data/batchers.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import warnings
4
+ from collections.abc import Iterable, Iterator
5
+ from dataclasses import dataclass
6
+
7
+ from dots_tts.utils.profiling import ensure_data_profiler
8
+
9
+
10
+ @dataclass(slots=True)
11
+ class BatchDecision:
12
+ dropped_samples: list[dict]
13
+ batch_samples: list[dict]
14
+
15
+
16
+ @dataclass(slots=True)
17
+ class _PoolSample:
18
+ sample: dict
19
+ num_audio_tokens: int
20
+ num_text_tokens: int
21
+ arrival_step: int
22
+
23
+
24
+ class OnlineBatcher:
25
+ def __init__(
26
+ self,
27
+ *,
28
+ max_audio_tokens_in_batch: int,
29
+ max_text_tokens_in_batch: int,
30
+ max_batch_size: int | None,
31
+ sample_pool_size: int,
32
+ profiler=None,
33
+ ):
34
+ self.max_audio_tokens_in_batch = max(1, int(max_audio_tokens_in_batch))
35
+ self.max_text_tokens_in_batch = max(1, int(max_text_tokens_in_batch))
36
+ self.max_batch_size = max_batch_size
37
+ self.sample_pool_size = max(1, int(sample_pool_size))
38
+ self.profiler = ensure_data_profiler(profiler)
39
+
40
+ @staticmethod
41
+ def _sort_pool(pool: list[_PoolSample]) -> None:
42
+ pool.sort(
43
+ key=lambda item: (
44
+ item.num_audio_tokens,
45
+ item.num_text_tokens,
46
+ -item.arrival_step,
47
+ ),
48
+ reverse=True,
49
+ )
50
+
51
+ def _choose_anchor_index(
52
+ self,
53
+ pool: list[_PoolSample],
54
+ *,
55
+ decision_step: int,
56
+ ) -> int:
57
+ oldest_waiting_index = -1
58
+ oldest_waiting_step = decision_step
59
+
60
+ for index, item in enumerate(pool):
61
+ waited_steps = decision_step - item.arrival_step
62
+ if waited_steps < self.sample_pool_size:
63
+ continue
64
+ if item.arrival_step <= oldest_waiting_step:
65
+ oldest_waiting_index = index
66
+ oldest_waiting_step = item.arrival_step
67
+
68
+ return 0 if oldest_waiting_index < 0 else oldest_waiting_index
69
+
70
+ def _build_next_decision(
71
+ self,
72
+ pool: list[_PoolSample],
73
+ *,
74
+ decision_step: int,
75
+ ) -> BatchDecision:
76
+ dropped_samples: list[dict] = []
77
+ batch_samples: list[dict] = []
78
+ selected_indices: list[int] = []
79
+ anchor_index = self._choose_anchor_index(pool, decision_step=decision_step)
80
+ anchor = pool[anchor_index]
81
+
82
+ exceed_audio_budget = anchor.num_audio_tokens > self.max_audio_tokens_in_batch
83
+ exceed_text_budget = anchor.num_text_tokens > self.max_text_tokens_in_batch
84
+ exceed_batch_size = self.max_batch_size is not None and self.max_batch_size < 1
85
+ if exceed_audio_budget or exceed_text_budget or exceed_batch_size:
86
+ skipped = pool.pop(anchor_index).sample
87
+ dropped_samples.append(skipped)
88
+ warnings.warn(
89
+ "Skipping sample that exceeds batching limits on its own: "
90
+ f"fid={skipped.get('fid')!r}, "
91
+ f"num_audio_tokens={anchor.num_audio_tokens}, "
92
+ f"input_ids_length={anchor.num_text_tokens}, "
93
+ f"max_audio_tokens_in_batch={self.max_audio_tokens_in_batch}, "
94
+ f"max_text_tokens_in_batch={self.max_text_tokens_in_batch}, "
95
+ f"max_batch_size={self.max_batch_size}",
96
+ RuntimeWarning,
97
+ stacklevel=2,
98
+ )
99
+ return BatchDecision(
100
+ dropped_samples=dropped_samples,
101
+ batch_samples=batch_samples,
102
+ )
103
+
104
+ longest_audio_tokens = anchor.num_audio_tokens
105
+ longest_text_tokens = anchor.num_text_tokens
106
+ batch_samples.append(anchor.sample)
107
+ selected_indices.append(anchor_index)
108
+
109
+ for index, item in enumerate(pool):
110
+ if index == anchor_index:
111
+ continue
112
+ if (
113
+ self.max_batch_size is not None
114
+ and len(batch_samples) >= self.max_batch_size
115
+ ):
116
+ break
117
+
118
+ proposed_batch_size = len(batch_samples) + 1
119
+ proposed_longest_audio_tokens = max(
120
+ longest_audio_tokens,
121
+ item.num_audio_tokens,
122
+ )
123
+ proposed_longest_text_tokens = max(
124
+ longest_text_tokens,
125
+ item.num_text_tokens,
126
+ )
127
+ if (
128
+ proposed_longest_audio_tokens * proposed_batch_size
129
+ > self.max_audio_tokens_in_batch
130
+ ):
131
+ continue
132
+ if (
133
+ proposed_longest_text_tokens * proposed_batch_size
134
+ > self.max_text_tokens_in_batch
135
+ ):
136
+ continue
137
+
138
+ batch_samples.append(item.sample)
139
+ selected_indices.append(index)
140
+ longest_audio_tokens = proposed_longest_audio_tokens
141
+ longest_text_tokens = proposed_longest_text_tokens
142
+
143
+ for index in sorted(set(selected_indices), reverse=True):
144
+ pool.pop(index)
145
+
146
+ return BatchDecision(
147
+ dropped_samples=dropped_samples,
148
+ batch_samples=batch_samples,
149
+ )
150
+
151
+ def build_decisions(self, sample_iter: Iterable[dict]) -> Iterator[BatchDecision]:
152
+ pool: list[_PoolSample] = []
153
+ source_exhausted = False
154
+ decision_step = 0
155
+ iterator = iter(sample_iter)
156
+
157
+ while not source_exhausted or pool:
158
+ while not source_exhausted and len(pool) < self.sample_pool_size:
159
+ try:
160
+ sample = next(iterator)
161
+ except StopIteration:
162
+ source_exhausted = True
163
+ break
164
+ pool.append(
165
+ _PoolSample(
166
+ sample=sample,
167
+ num_audio_tokens=int(sample.get("num_audio_tokens", 0)),
168
+ num_text_tokens=int(sample.get("input_ids_length", 0)),
169
+ arrival_step=decision_step,
170
+ )
171
+ )
172
+
173
+ if not pool:
174
+ break
175
+
176
+ profiler = self.profiler
177
+ with profiler.measure("main.sort_pool", count=len(pool)):
178
+ self._sort_pool(pool)
179
+ with profiler.measure("main.build_batch_decision"):
180
+ decision = self._build_next_decision(
181
+ pool,
182
+ decision_step=decision_step,
183
+ )
184
+ if decision.dropped_samples or decision.batch_samples:
185
+ decision_step += 1
186
+ yield decision
187
+ continue
188
+ raise RuntimeError("OnlineBatcher failed to make progress on a non-empty pool.")
src/dots_tts/data/builders.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from torch.utils.data import DataLoader
4
+
5
+ from dots_tts.config.data import DataConfig
6
+ from dots_tts.data.pipelines.base import BaseSamplePipeline
7
+ from dots_tts.data.pipelines.tts_pipeline import BasicTtsPipeline, InterleaveTtsPipeline
8
+ from dots_tts.data.source_adapters.jsonl_manifest_adapter import (
9
+ JsonlManifestSourceAdapter,
10
+ )
11
+ from dots_tts.data.source_adapters.multi_source_adapter import (
12
+ SequentialMultiSourceAdapter,
13
+ SourceSpec,
14
+ WeightedMultiSourceAdapter,
15
+ )
16
+ from dots_tts.data.streaming import (
17
+ BatchedDataStream,
18
+ StreamingSampleDataset,
19
+ identity_collate,
20
+ )
21
+
22
+ _SOURCE_ADAPTER_CLASSES = {
23
+ "JsonlManifestSourceAdapter": JsonlManifestSourceAdapter,
24
+ }
25
+
26
+
27
+ def _build_source_pipeline(
28
+ tokenizer, data_cfg, pipeline_name: str, *, profiler=None
29
+ ) -> BaseSamplePipeline:
30
+ if pipeline_name == "basic":
31
+ return BasicTtsPipeline(tokenizer, data_cfg, profiler=profiler)
32
+ if pipeline_name == "interleave":
33
+ return InterleaveTtsPipeline(tokenizer, data_cfg, profiler=profiler)
34
+ raise ValueError(f"Unsupported data pipeline: {pipeline_name!r}")
35
+
36
+
37
+ def _build_source_specs(data_cfg, tokenizer, *, profiler=None) -> list[SourceSpec]:
38
+ specs = []
39
+ for source_cfg in data_cfg.sources:
40
+ adapter_cls = _SOURCE_ADAPTER_CLASSES[source_cfg.adapter.class_name]
41
+ adapter = adapter_cls(**source_cfg.adapter.params)
42
+ specs.append(
43
+ SourceSpec(
44
+ name=source_cfg.name,
45
+ weight=float(source_cfg.weight),
46
+ adapter=adapter,
47
+ pipeline=_build_source_pipeline(
48
+ tokenizer, data_cfg, source_cfg.pipeline, profiler=profiler
49
+ ),
50
+ )
51
+ )
52
+ return specs
53
+
54
+
55
+ def _resolve_rank_info(accelerator=None) -> tuple[int, int]:
56
+ rank = (
57
+ int(getattr(accelerator, "process_index", 0)) if accelerator is not None else 0
58
+ )
59
+ world_size = (
60
+ int(getattr(accelerator, "num_processes", 1)) if accelerator is not None else 1
61
+ )
62
+ return rank, world_size
63
+
64
+
65
+ def _local_num_tokens_per_epoch(
66
+ global_num_tokens_per_epoch: int, *, rank: int, world_size: int
67
+ ) -> int:
68
+ if world_size <= 0:
69
+ raise ValueError(f"world_size must be positive, but got {world_size}.")
70
+ if rank < 0 or rank >= world_size:
71
+ raise ValueError(
72
+ f"rank must be in [0, {world_size}), but got rank={rank}."
73
+ )
74
+
75
+ base, remainder = divmod(int(global_num_tokens_per_epoch), int(world_size))
76
+ return base + int(rank < remainder)
77
+
78
+
79
+ def _build_dataset(
80
+ data_cfg: DataConfig,
81
+ *,
82
+ tokenizer,
83
+ seed: int,
84
+ accelerator=None,
85
+ sequential: bool,
86
+ profiler=None,
87
+ ):
88
+ rank, world_size = _resolve_rank_info(accelerator)
89
+ source_cls = SequentialMultiSourceAdapter if sequential else WeightedMultiSourceAdapter
90
+ source = source_cls(
91
+ sources=_build_source_specs(data_cfg, tokenizer, profiler=profiler)
92
+ )
93
+ return StreamingSampleDataset(
94
+ source=source,
95
+ rank=rank,
96
+ world_size=world_size,
97
+ seed=int(seed),
98
+ )
99
+
100
+
101
+ def build_training_dataset(
102
+ data_cfg: DataConfig,
103
+ tokenizer,
104
+ *,
105
+ seed: int,
106
+ accelerator=None,
107
+ profiler=None,
108
+ ):
109
+ if data_cfg.num_tokens_per_epoch is None:
110
+ raise ValueError("Training data requires num_tokens_per_epoch.")
111
+ return _build_dataset(
112
+ data_cfg,
113
+ tokenizer=tokenizer,
114
+ seed=seed,
115
+ accelerator=accelerator,
116
+ sequential=False,
117
+ profiler=profiler,
118
+ )
119
+
120
+
121
+ def build_validation_dataset(
122
+ data_cfg: DataConfig,
123
+ tokenizer,
124
+ *,
125
+ seed: int,
126
+ accelerator=None,
127
+ profiler=None,
128
+ ):
129
+ return _build_dataset(
130
+ data_cfg,
131
+ tokenizer=tokenizer,
132
+ seed=seed,
133
+ accelerator=accelerator,
134
+ sequential=True,
135
+ profiler=profiler,
136
+ )
137
+
138
+
139
+ def _build_sample_loader(dataset, data_cfg: DataConfig) -> DataLoader:
140
+ loader_kwargs = {
141
+ "dataset": dataset,
142
+ "batch_size": None,
143
+ "collate_fn": identity_collate,
144
+ "num_workers": data_cfg.num_workers,
145
+ "pin_memory": data_cfg.pin_memory,
146
+ "persistent_workers": data_cfg.num_workers > 0,
147
+ }
148
+ if data_cfg.num_workers > 0:
149
+ loader_kwargs["prefetch_factor"] = int(data_cfg.prefetch_factor)
150
+ sample_loader = DataLoader(**loader_kwargs)
151
+ return sample_loader
152
+
153
+
154
+ def build_training_dataloader(
155
+ dataset, data_cfg: DataConfig, tokenizer, *, profiler=None
156
+ ):
157
+ local_num_tokens_per_epoch = _local_num_tokens_per_epoch(
158
+ int(data_cfg.num_tokens_per_epoch),
159
+ rank=int(dataset.rank),
160
+ world_size=int(dataset.world_size),
161
+ )
162
+ sample_loader = _build_sample_loader(dataset, data_cfg)
163
+ batched_stream = BatchedDataStream(
164
+ sample_dataset=dataset,
165
+ data_cfg=data_cfg,
166
+ tokenizer=tokenizer,
167
+ num_tokens_per_epoch=local_num_tokens_per_epoch,
168
+ profiler=profiler,
169
+ )
170
+ batched_stream.attach_loader(sample_loader)
171
+ return batched_stream
172
+
173
+
174
+ def build_validation_dataloader(
175
+ dataset, data_cfg: DataConfig, tokenizer, *, profiler=None
176
+ ):
177
+ sample_loader = _build_sample_loader(dataset, data_cfg)
178
+ batched_stream = BatchedDataStream(
179
+ sample_dataset=dataset,
180
+ data_cfg=data_cfg,
181
+ tokenizer=tokenizer,
182
+ num_tokens_per_epoch=None,
183
+ profiler=profiler,
184
+ )
185
+ batched_stream.attach_loader(sample_loader)
186
+ return batched_stream
187
+
188
+
189
+ __all__ = [
190
+ "build_training_dataloader",
191
+ "build_training_dataset",
192
+ "build_validation_dataloader",
193
+ "build_validation_dataset",
194
+ ]
src/dots_tts/data/collator.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ import torch
6
+ from torch.nn.utils.rnn import pad_sequence
7
+
8
+
9
+ class PadCollator:
10
+ def __init__(self, tokenizer):
11
+ self.tokenizer = tokenizer
12
+ self.pad_token_id = tokenizer.pad_token_id
13
+ if self.pad_token_id is None:
14
+ self.pad_token_id = tokenizer.eos_token_id or 0
15
+
16
+ def __call__(self, samples: list[dict[str, Any]]) -> dict[str, Any]:
17
+ if not samples:
18
+ raise ValueError("PadCollator received an empty sample list.")
19
+
20
+ order = sorted(
21
+ range(len(samples)),
22
+ key=lambda idx: samples[idx]["sample_length"],
23
+ reverse=True,
24
+ )
25
+ ordered = [samples[idx] for idx in order]
26
+
27
+ input_ids = [
28
+ torch.tensor(sample["input_ids"], dtype=torch.long) for sample in ordered
29
+ ]
30
+ labels = [
31
+ torch.tensor(sample["labels"], dtype=torch.long) for sample in ordered
32
+ ]
33
+ loss_masks = [
34
+ torch.tensor(sample["loss_mask"], dtype=torch.float32) for sample in ordered
35
+ ]
36
+ waveforms = [sample["sample"].squeeze(0) for sample in ordered]
37
+ fbank = [sample["fbank"] for sample in ordered]
38
+
39
+ return {
40
+ "fids": [sample["fid"] for sample in ordered],
41
+ "source_names": [sample.get("source_name") for sample in ordered],
42
+ "input_ids": pad_sequence(
43
+ input_ids,
44
+ batch_first=True,
45
+ padding_value=self.pad_token_id,
46
+ ),
47
+ "input_ids_lengths": torch.tensor(
48
+ [len(sample["input_ids"]) for sample in ordered],
49
+ dtype=torch.long,
50
+ ),
51
+ "labels": pad_sequence(
52
+ labels,
53
+ batch_first=True,
54
+ padding_value=self.pad_token_id,
55
+ ),
56
+ "loss_mask": pad_sequence(
57
+ loss_masks,
58
+ batch_first=True,
59
+ padding_value=0.0,
60
+ ),
61
+ "sample": pad_sequence(
62
+ waveforms,
63
+ batch_first=True,
64
+ padding_value=0.0,
65
+ ).unsqueeze(1),
66
+ "sample_lengths": torch.tensor(
67
+ [sample["sample_length"] for sample in ordered],
68
+ dtype=torch.long,
69
+ ),
70
+ "num_text_tokens": torch.tensor(
71
+ [sample["num_text_tokens"] for sample in ordered],
72
+ dtype=torch.long,
73
+ ),
74
+ "num_audio_tokens": torch.tensor(
75
+ [sample["num_audio_tokens"] for sample in ordered],
76
+ dtype=torch.long,
77
+ ),
78
+ "fbank": pad_sequence(
79
+ fbank,
80
+ batch_first=True,
81
+ padding_value=0.0,
82
+ ),
83
+ "fbank_lengths": torch.tensor(
84
+ [sample["fbank_length"] for sample in ordered],
85
+ dtype=torch.long,
86
+ ),
87
+ }
src/dots_tts/data/pipelines/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Data pipelines package."""
src/dots_tts/data/pipelines/base.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from collections.abc import Iterable, Iterator
5
+
6
+
7
+ class BaseSamplePipeline(ABC):
8
+ """1:1 sample pipeline that preserves adapter resume metadata."""
9
+
10
+ @staticmethod
11
+ def _validate_input_sample(sample: dict) -> None:
12
+ if "_adapter_state" not in sample:
13
+ raise RuntimeError(
14
+ "Source sample is missing required '_adapter_state' for resume."
15
+ )
16
+
17
+ @abstractmethod
18
+ def process_sample(self, sample: dict) -> dict:
19
+ """Transform one raw sample into one processed sample."""
20
+
21
+ def __call__(self, samples: Iterable[dict]) -> Iterator[dict]:
22
+ for raw_sample in samples:
23
+ self._validate_input_sample(raw_sample)
24
+ processed = self.process_sample(dict(raw_sample))
25
+ if not isinstance(processed, dict):
26
+ raise RuntimeError(
27
+ f"{self.__class__.__name__}.process_sample() must return a dict."
28
+ )
29
+ item = dict(raw_sample)
30
+ item.update(processed)
31
+ self._validate_input_sample(item)
32
+ yield item
src/dots_tts/data/pipelines/preprocessing.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ DEFAULT_EDGE_SILENCE_MS = 250.0
7
+ DEFAULT_EDGE_SILENCE_TOP_DB = 30.0
8
+
9
+
10
+ def align_length(num_samples: int, multiple_of: int | None) -> int:
11
+ if multiple_of is None or multiple_of <= 0:
12
+ return int(num_samples)
13
+ if num_samples % multiple_of == 0:
14
+ return int(num_samples)
15
+ return int(((num_samples + multiple_of - 1) // multiple_of) * multiple_of)
16
+
17
+
18
+ def pad_waveform_align_only(
19
+ waveform: torch.Tensor,
20
+ *,
21
+ multiple_of: int | None,
22
+ ) -> torch.Tensor:
23
+ if multiple_of is None or multiple_of <= 0:
24
+ return waveform
25
+
26
+ target_length = align_length(waveform.size(-1), multiple_of)
27
+ delta = target_length - waveform.size(-1)
28
+ if delta <= 0:
29
+ return waveform
30
+
31
+ return F.pad(waveform, (0, delta), "constant", 0.0)
32
+
33
+
34
+ def normalize_edge_silence_duration(
35
+ waveform: torch.Tensor,
36
+ *,
37
+ sample_rate: int,
38
+ target_silence_duration_ms: float = DEFAULT_EDGE_SILENCE_MS,
39
+ top_db: float = DEFAULT_EDGE_SILENCE_TOP_DB,
40
+ ) -> torch.Tensor:
41
+ mono_waveform = waveform[0]
42
+ target_samples = int(round(float(sample_rate) * float(target_silence_duration_ms) / 1000.0))
43
+ amplitude = mono_waveform.abs()
44
+ peak = float(amplitude.max().item())
45
+ if peak <= 0.0:
46
+ waveform = waveform[..., :target_samples]
47
+ current_length = int(waveform.size(-1))
48
+ if current_length < target_samples:
49
+ waveform = F.pad(waveform, (0, target_samples - current_length), "constant", 0.0)
50
+ return waveform
51
+
52
+ threshold = peak * (10.0 ** (-float(top_db) / 20.0))
53
+ non_silent = torch.nonzero(amplitude > threshold, as_tuple=False).flatten()
54
+ first_non_silent = int(non_silent[0].item())
55
+ last_non_silent = int(non_silent[-1].item())
56
+
57
+ leading_silence_samples = first_non_silent
58
+ trailing_silence_samples = int(mono_waveform.numel()) - last_non_silent - 1
59
+
60
+ leading_delta = target_samples - leading_silence_samples
61
+ if leading_delta > 0:
62
+ waveform = F.pad(waveform, (leading_delta, 0), "constant", 0.0)
63
+ else:
64
+ trim_from_start = min(-leading_delta, int(waveform.size(-1)))
65
+ waveform = waveform[..., trim_from_start:]
66
+
67
+ trailing_delta = target_samples - trailing_silence_samples
68
+ if trailing_delta > 0:
69
+ return F.pad(waveform, (0, trailing_delta), "constant", 0.0)
70
+
71
+ trim_from_end = min(-trailing_delta, int(waveform.size(-1)))
72
+ if trim_from_end <= 0:
73
+ return waveform
74
+ return waveform[..., :-trim_from_end]
75
+
76
+
77
+ def compute_num_audio_tokens(
78
+ num_samples: int, *, audio_samples_per_llm_token: int
79
+ ) -> int:
80
+ if num_samples % audio_samples_per_llm_token != 0:
81
+ raise ValueError(
82
+ f"Waveform length {num_samples} is not aligned to token hop {audio_samples_per_llm_token}."
83
+ )
84
+ return num_samples // audio_samples_per_llm_token
src/dots_tts/data/pipelines/tokenizing.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ import re
5
+ from typing import Any
6
+
7
+ from loguru import logger
8
+
9
+ from dots_tts.utils.tokenizer import (
10
+ AUDIO_GEN_END_TOKEN,
11
+ AUDIO_GEN_SPAN_TOKEN,
12
+ AUDIO_GEN_START_TOKEN,
13
+ TEXT_COND_END_TOKEN,
14
+ require_token_id,
15
+ )
16
+
17
+ TEMPLATE_PATTERN = re.compile(r"\{text\}|\{audio\}|\{interleave\}|[^\{]+")
18
+
19
+
20
+ @dataclass(frozen=True)
21
+ class ParsedTemplate:
22
+ parts: tuple[str, ...]
23
+ has_audio_placeholder: bool
24
+ has_interleave_placeholder: bool
25
+
26
+
27
+ @dataclass(frozen=True)
28
+ class TokenizedTemplatePart:
29
+ kind: str
30
+ token_ids: tuple[int, ...] = ()
31
+ raw_text: str | None = None
32
+
33
+
34
+ def parse_template(template: str) -> ParsedTemplate:
35
+ parts = tuple(re.findall(TEMPLATE_PATTERN, template))
36
+ has_audio_placeholder = "{audio}" in parts
37
+ interleave_count = parts.count("{interleave}")
38
+ if has_audio_placeholder and interleave_count:
39
+ raise ValueError("Template cannot mix audio and interleave placeholders.")
40
+ if interleave_count > 1:
41
+ raise ValueError(
42
+ "Interleave generation template must contain exactly one interleave placeholder."
43
+ )
44
+ return ParsedTemplate(
45
+ parts=parts,
46
+ has_audio_placeholder=has_audio_placeholder,
47
+ has_interleave_placeholder=interleave_count == 1,
48
+ )
49
+
50
+
51
+ def _prepare_template_tokens(
52
+ *, text: str, tokenizer, template: str
53
+ ) -> tuple[ParsedTemplate, list[int]]:
54
+ return parse_template(template), tokenizer.encode(text, add_special_tokens=False)
55
+
56
+
57
+ def _iter_tokenized_template_parts(
58
+ *,
59
+ parsed_template: ParsedTemplate,
60
+ tokenizer,
61
+ text_tokens: list[int],
62
+ ):
63
+ for part in parsed_template.parts:
64
+ if part == "{text}":
65
+ yield TokenizedTemplatePart(kind="text", token_ids=tuple(text_tokens))
66
+ continue
67
+ if part == "{audio}":
68
+ yield TokenizedTemplatePart(kind="audio")
69
+ continue
70
+ if part == "{interleave}":
71
+ yield TokenizedTemplatePart(kind="interleave")
72
+ continue
73
+ yield TokenizedTemplatePart(
74
+ kind="literal",
75
+ token_ids=tuple(tokenizer.encode(part, add_special_tokens=False)),
76
+ raw_text=part,
77
+ )
78
+
79
+
80
+ def _extend_tokens_with_loss(
81
+ *, full_ids: list[int], loss_mask: list[float], token_ids: tuple[int, ...], loss: float
82
+ ) -> None:
83
+ full_ids.extend(token_ids)
84
+ loss_mask.extend([loss] * len(token_ids))
85
+
86
+
87
+ def build_tokenized_example(
88
+ *, text: str, tokenizer, template: str, num_audio_tokens: int
89
+ ) -> dict[str, Any]:
90
+ if tokenizer.eos_token_id is None:
91
+ raise ValueError("Tokenizer eos_token_id is required for generation targets.")
92
+
93
+ parsed_template, text_tokens = _prepare_template_tokens(
94
+ text=text,
95
+ tokenizer=tokenizer,
96
+ template=template,
97
+ )
98
+
99
+ full_ids: list[int] = []
100
+ loss_mask: list[float] = []
101
+ audio_tokens: list[int] | None = None
102
+ if parsed_template.has_audio_placeholder:
103
+ audio_gen_start_id = require_token_id(tokenizer, AUDIO_GEN_START_TOKEN)
104
+ audio_gen_span_id = require_token_id(tokenizer, AUDIO_GEN_SPAN_TOKEN)
105
+ audio_gen_end_id = require_token_id(tokenizer, AUDIO_GEN_END_TOKEN)
106
+ audio_tokens = (
107
+ [audio_gen_start_id]
108
+ + [audio_gen_span_id] * num_audio_tokens
109
+ + [audio_gen_end_id]
110
+ )
111
+ elif parsed_template.has_interleave_placeholder:
112
+ audio_gen_span_id = require_token_id(tokenizer, AUDIO_GEN_SPAN_TOKEN)
113
+ audio_gen_end_id = require_token_id(tokenizer, AUDIO_GEN_END_TOKEN)
114
+ text_cond_end_id = require_token_id(tokenizer, TEXT_COND_END_TOKEN)
115
+
116
+ for part in _iter_tokenized_template_parts(
117
+ parsed_template=parsed_template,
118
+ tokenizer=tokenizer,
119
+ text_tokens=text_tokens,
120
+ ):
121
+ if part.kind == "text":
122
+ _extend_tokens_with_loss(
123
+ full_ids=full_ids,
124
+ loss_mask=loss_mask,
125
+ token_ids=part.token_ids,
126
+ loss=0.0,
127
+ )
128
+ continue
129
+
130
+ if part.kind == "audio":
131
+ if audio_tokens is None:
132
+ raise RuntimeError("Audio placeholder tokens were not initialized.")
133
+ full_ids.extend(audio_tokens)
134
+ loss_mask.extend([0.0])
135
+ loss_mask.extend([1.0] * max(0, len(audio_tokens) - 2))
136
+ loss_mask.append(0.0)
137
+ continue
138
+
139
+ if part.kind == "interleave":
140
+ _append_interleave_generation_tokens(
141
+ full_ids=full_ids,
142
+ loss_mask=loss_mask,
143
+ text_tokens=text_tokens,
144
+ num_audio_tokens=num_audio_tokens,
145
+ audio_span_id=audio_gen_span_id,
146
+ audio_end_id=audio_gen_end_id,
147
+ text_cond_end_id=text_cond_end_id,
148
+ )
149
+ continue
150
+
151
+ _extend_tokens_with_loss(
152
+ full_ids=full_ids,
153
+ loss_mask=loss_mask,
154
+ token_ids=part.token_ids,
155
+ loss=0.0,
156
+ )
157
+
158
+ full_ids.append(tokenizer.eos_token_id)
159
+ loss_mask.append(0.0)
160
+
161
+ return {
162
+ "input_ids": full_ids[:-1],
163
+ "labels": full_ids[1:],
164
+ "loss_mask": loss_mask[1:],
165
+ "text_token_count": len(text_tokens),
166
+ }
167
+
168
+
169
+ def build_generation_schedule(
170
+ *,
171
+ text: str,
172
+ tokenizer,
173
+ template: str,
174
+ max_audio_tokens: int,
175
+ ) -> dict[str, Any]:
176
+ if max_audio_tokens <= 0:
177
+ raise ValueError("max_audio_tokens must be positive for generation.")
178
+
179
+ parsed_template, text_tokens = _prepare_template_tokens(
180
+ text=text,
181
+ tokenizer=tokenizer,
182
+ template=template,
183
+ )
184
+ schedule_ids: list[int] = []
185
+ audio_gen_start_id = require_token_id(tokenizer, AUDIO_GEN_START_TOKEN)
186
+ audio_gen_span_id = require_token_id(tokenizer, AUDIO_GEN_SPAN_TOKEN)
187
+
188
+ if parsed_template.has_audio_placeholder:
189
+ for part in _iter_tokenized_template_parts(
190
+ parsed_template=parsed_template,
191
+ tokenizer=tokenizer,
192
+ text_tokens=text_tokens,
193
+ ):
194
+ if part.kind == "audio":
195
+ schedule_ids.append(audio_gen_start_id)
196
+ schedule_ids.extend([audio_gen_span_id] * max_audio_tokens)
197
+ continue
198
+ schedule_ids.extend(part.token_ids)
199
+ visible_schedule_ids = [
200
+ token_id for token_id in schedule_ids if token_id != audio_gen_span_id
201
+ ]
202
+ decoded_schedule = (
203
+ tokenizer.decode(
204
+ visible_schedule_ids,
205
+ skip_special_tokens=False,
206
+ clean_up_tokenization_spaces=False,
207
+ )
208
+ if hasattr(tokenizer, "decode")
209
+ else repr(visible_schedule_ids)
210
+ )
211
+ logger.info(
212
+ "Built generation schedule: interleave={} max_audio_tokens={} sequence={!r}",
213
+ False,
214
+ int(max_audio_tokens),
215
+ decoded_schedule,
216
+ )
217
+ return {
218
+ "schedule_ids": schedule_ids,
219
+ "interleave": False,
220
+ }
221
+
222
+ if not parsed_template.has_interleave_placeholder:
223
+ raise ValueError(
224
+ "Generation template must contain either {audio} or {interleave}."
225
+ )
226
+ text_cond_end_id = require_token_id(tokenizer, TEXT_COND_END_TOKEN)
227
+ if max_audio_tokens < len(text_tokens):
228
+ raise ValueError(
229
+ "Interleave generation requires at least one audio span per text token: "
230
+ f"text_token_count={len(text_tokens)} "
231
+ f"max_audio_patch_count={max_audio_tokens}."
232
+ )
233
+
234
+ interleave_started = False
235
+ for part in _iter_tokenized_template_parts(
236
+ parsed_template=parsed_template,
237
+ tokenizer=tokenizer,
238
+ text_tokens=text_tokens,
239
+ ):
240
+ if part.kind == "interleave":
241
+ _append_interleave_schedule_tokens(
242
+ schedule_ids=schedule_ids,
243
+ text_tokens=text_tokens,
244
+ max_audio_tokens=max_audio_tokens,
245
+ audio_span_id=audio_gen_span_id,
246
+ text_cond_end_id=text_cond_end_id,
247
+ )
248
+ interleave_started = True
249
+ continue
250
+ if part.kind == "text":
251
+ raise ValueError(
252
+ "Generation schedule does not support {text} inside an interleave template."
253
+ )
254
+ if part.kind == "audio":
255
+ raise ValueError(
256
+ "Generation schedule does not support {audio} inside an interleave template."
257
+ )
258
+ if interleave_started:
259
+ if (part.raw_text or "").strip():
260
+ raise ValueError(
261
+ "Generation schedule does not support non-empty suffix text after the interleave placeholder."
262
+ )
263
+ continue
264
+ schedule_ids.extend(part.token_ids)
265
+
266
+ visible_schedule_ids = [
267
+ token_id for token_id in schedule_ids if token_id != audio_gen_span_id
268
+ ]
269
+ decoded_schedule = (
270
+ tokenizer.decode(
271
+ visible_schedule_ids,
272
+ skip_special_tokens=False,
273
+ clean_up_tokenization_spaces=False,
274
+ )
275
+ if hasattr(tokenizer, "decode")
276
+ else repr(visible_schedule_ids)
277
+ )
278
+ logger.info(
279
+ "Built generation schedule: interleave={} max_audio_tokens={} sequence={!r}",
280
+ True,
281
+ int(max_audio_tokens),
282
+ decoded_schedule,
283
+ )
284
+ return {
285
+ "schedule_ids": schedule_ids,
286
+ "interleave": True,
287
+ }
288
+
289
+
290
+ def _append_interleave_generation_tokens(
291
+ *,
292
+ full_ids: list[int],
293
+ loss_mask: list[float],
294
+ text_tokens: list[int],
295
+ num_audio_tokens: int,
296
+ audio_span_id: int,
297
+ audio_end_id: int,
298
+ text_cond_end_id: int,
299
+ ) -> None:
300
+ audio_tokens = [audio_span_id] * num_audio_tokens + [audio_end_id]
301
+ text_index = 0
302
+ audio_index = 0
303
+ text_cond_end_added = False
304
+
305
+ while text_index < len(text_tokens) or audio_index < len(audio_tokens):
306
+ if text_index < len(text_tokens):
307
+ full_ids.append(text_tokens[text_index])
308
+ loss_mask.append(0.0)
309
+ text_index += 1
310
+ elif not text_cond_end_added:
311
+ full_ids.append(text_cond_end_id)
312
+ loss_mask.append(0.0)
313
+ text_cond_end_added = True
314
+
315
+ if audio_index < len(audio_tokens):
316
+ full_ids.append(audio_tokens[audio_index])
317
+ loss_mask.append(1.0 if audio_index < num_audio_tokens else 0.0)
318
+ audio_index += 1
319
+
320
+ if not text_cond_end_added:
321
+ full_ids.append(text_cond_end_id)
322
+ loss_mask.append(0.0)
323
+
324
+
325
+ def _append_interleave_schedule_tokens(
326
+ *,
327
+ schedule_ids: list[int],
328
+ text_tokens: list[int],
329
+ max_audio_tokens: int,
330
+ audio_span_id: int,
331
+ text_cond_end_id: int,
332
+ ) -> None:
333
+ for token_id in text_tokens:
334
+ schedule_ids.append(token_id)
335
+ schedule_ids.append(audio_span_id)
336
+ schedule_ids.append(text_cond_end_id)
337
+ remaining_audio_tokens = max_audio_tokens - len(text_tokens)
338
+ if remaining_audio_tokens > 0:
339
+ schedule_ids.extend([audio_span_id] * remaining_audio_tokens)
src/dots_tts/data/pipelines/tts_pipeline.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import soundfile as sf
4
+ import torch
5
+
6
+ from dots_tts.utils.profiling import ensure_data_profiler
7
+ from dots_tts.data.pipelines.base import BaseSamplePipeline
8
+ from dots_tts.data.pipelines.preprocessing import (
9
+ compute_num_audio_tokens,
10
+ normalize_edge_silence_duration,
11
+ pad_waveform_align_only,
12
+ )
13
+ from dots_tts.data.pipelines.tokenizing import build_tokenized_example
14
+ from dots_tts.modules.speaker.fbank import extract_speaker_fbank
15
+ from dots_tts.utils.audio import high_quality_resample
16
+
17
+ TTS_TEXT_PREFIX = "[文本]"
18
+ TTS_AUDIO_PREFIX = "[文本对应语音]"
19
+ TTS_INSTRUCTION_TEXT_PREFIX = "[带指令文本]"
20
+ TTA_TEXT_PREFIX = "[声音描述]"
21
+ TTA_AUDIO_PREFIX = "[描述对应声音]"
22
+ TTS_INTERLEAVE_PREFIX = "[流式语音合成]"
23
+ DEFAULT_TRAIN_TEMPLATE = f"{TTS_TEXT_PREFIX}{{text}}{TTS_AUDIO_PREFIX}{{audio}}"
24
+ DEFAULT_INSTRUCTION_TTS_TEMPLATE = (
25
+ f"{TTS_INSTRUCTION_TEXT_PREFIX}{{text}}{TTS_AUDIO_PREFIX}{{audio}}"
26
+ )
27
+ DEFAULT_TEXT_TO_AUDIO_TEMPLATE = f"{TTA_TEXT_PREFIX}{{text}}{TTA_AUDIO_PREFIX}{{audio}}"
28
+ DEFAULT_INTERLEAVE_TRAIN_TEMPLATE = f"{TTS_INTERLEAVE_PREFIX}{{interleave}}"
29
+
30
+
31
+ class BasicTtsPipeline(BaseSamplePipeline):
32
+ """Fixed internal training pipeline for adapter-emitted samples."""
33
+
34
+ template = DEFAULT_TRAIN_TEMPLATE
35
+
36
+ def __init__(self, tokenizer, data_cfg, *, profiler=None):
37
+ self.tokenizer = tokenizer
38
+ self.train_audio_sample_rate = int(data_cfg.train_audio_sample_rate)
39
+ self.audio_samples_per_llm_token = int(data_cfg.audio_samples_per_llm_token)
40
+ self.profiler = ensure_data_profiler(profiler)
41
+
42
+ @staticmethod
43
+ def _load_waveform(audio_path: str) -> tuple[torch.Tensor, int]:
44
+ if not isinstance(audio_path, str):
45
+ raise TypeError(
46
+ f"Training audio must be a filesystem path, got {type(audio_path)}."
47
+ )
48
+ audio_data, sample_rate = sf.read(
49
+ audio_path,
50
+ dtype="float32",
51
+ always_2d=True,
52
+ )
53
+ waveform = torch.from_numpy(audio_data.T)
54
+ if waveform.size(0) > 1:
55
+ waveform = waveform.mean(dim=0, keepdim=True)
56
+ return waveform.contiguous(), int(sample_rate)
57
+
58
+ @staticmethod
59
+ def _validate_source_sample(sample: dict) -> None:
60
+ missing = [field for field in ("fid", "text", "audio") if field not in sample]
61
+ if missing:
62
+ raise ValueError(
63
+ "Source adapter must emit fid/text/audio. "
64
+ f"Missing fields: {missing}. Sample keys: {sorted(sample.keys())}"
65
+ )
66
+
67
+ def process_sample(self, raw_sample: dict) -> dict:
68
+ sample = dict(raw_sample)
69
+ self._validate_source_sample(sample)
70
+ sample["fid"] = str(sample["fid"])
71
+
72
+ with self.profiler.measure("worker.process_sample_total"):
73
+ return self._process_sample_impl(sample)
74
+
75
+ def _process_sample_impl(self, sample: dict) -> dict:
76
+ profiler = self.profiler
77
+ with profiler.measure("worker.load_audio"):
78
+ waveform, sample_rate = self._load_waveform(sample["audio"])
79
+ with profiler.measure("worker.resample_audio"):
80
+ waveform = high_quality_resample(
81
+ waveform,
82
+ orig_sr=sample_rate,
83
+ target_sr=self.train_audio_sample_rate,
84
+ )
85
+ with profiler.measure("worker.normalize_edge_silence"):
86
+ waveform = normalize_edge_silence_duration(
87
+ waveform,
88
+ sample_rate=self.train_audio_sample_rate,
89
+ )
90
+ sample["sample"] = waveform
91
+ sample["sample_rate"] = self.train_audio_sample_rate
92
+ sample["unpadded_sample_length"] = int(waveform.size(-1))
93
+
94
+ with profiler.measure("worker.pad_audio"):
95
+ waveform = pad_waveform_align_only(
96
+ waveform,
97
+ multiple_of=self.audio_samples_per_llm_token,
98
+ )
99
+ sample["sample"] = waveform
100
+ sample["sample_length"] = int(waveform.size(-1))
101
+
102
+ num_audio_tokens = compute_num_audio_tokens(
103
+ sample["sample_length"],
104
+ audio_samples_per_llm_token=self.audio_samples_per_llm_token,
105
+ )
106
+ with profiler.measure("worker.tokenize"):
107
+ tokenized = build_tokenized_example(
108
+ text=sample["text"],
109
+ tokenizer=self.tokenizer,
110
+ template=self.template,
111
+ num_audio_tokens=num_audio_tokens,
112
+ )
113
+ sample["input_ids"] = tokenized["input_ids"]
114
+ sample["labels"] = tokenized["labels"]
115
+ sample["loss_mask"] = tokenized["loss_mask"]
116
+ sample["input_ids_length"] = len(tokenized["input_ids"])
117
+ sample["num_text_tokens"] = tokenized["text_token_count"]
118
+ sample["num_audio_tokens"] = num_audio_tokens
119
+ sample["num_total_tokens"] = sample["input_ids_length"]
120
+
121
+ with profiler.measure("worker.extract_fbank"):
122
+ fbank = extract_speaker_fbank(
123
+ sample["sample"],
124
+ sample_rate=sample["sample_rate"],
125
+ )
126
+ sample["fbank"] = fbank
127
+ sample["fbank_length"] = int(fbank.size(0))
128
+ return sample
129
+
130
+
131
+ class InterleaveTtsPipeline(BasicTtsPipeline):
132
+ template = DEFAULT_INTERLEAVE_TRAIN_TEMPLATE
src/dots_tts/data/source_adapters/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Source adapter package."""
src/dots_tts/data/source_adapters/base_adapter.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import random
4
+ from abc import ABC, abstractmethod
5
+ from collections.abc import Iterable, Sequence
6
+ from copy import deepcopy
7
+ from dataclasses import dataclass
8
+ from typing import Any, TypeVar
9
+
10
+
11
+ @dataclass(frozen=True)
12
+ class SourceContext:
13
+ """Execution context for a single adapter iterator."""
14
+
15
+ epoch: int
16
+ rank: int
17
+ world_size: int
18
+ worker_id: int
19
+ num_workers: int
20
+ seed: int
21
+
22
+ @property
23
+ def global_worker_count(self) -> int:
24
+ return max(1, self.world_size * self.num_workers)
25
+
26
+ @property
27
+ def global_worker_id(self) -> int:
28
+ return self.rank * self.num_workers + self.worker_id
29
+
30
+
31
+ class BaseSourceAdapter(ABC):
32
+ """State-aware streaming source interface used by the training pipeline."""
33
+
34
+ @abstractmethod
35
+ def initial_state(self) -> dict[str, Any]:
36
+ """Return the default iterator state for a new worker/epoch."""
37
+
38
+ @abstractmethod
39
+ def iter_samples(
40
+ self,
41
+ context: SourceContext,
42
+ *,
43
+ state: dict[str, Any] | None = None,
44
+ ) -> Iterable[dict[str, Any]]:
45
+ """Yield raw samples and attach the next adapter state to each item."""
46
+
47
+ @abstractmethod
48
+ def is_cycle_start_state(self, state: dict[str, Any] | None) -> bool:
49
+ """Return whether ``state`` points at the beginning of a source cycle."""
50
+
51
+ def normalize_state(self, state: dict[str, Any] | None) -> dict[str, Any]:
52
+ merged = self.initial_state()
53
+ if state:
54
+ merged.update(deepcopy(state))
55
+ return merged
56
+
57
+ def clone_state(self, state: dict[str, Any] | None) -> dict[str, Any]:
58
+ return deepcopy(self.normalize_state(state))
59
+
60
+ def advance_cycle(self, state: dict[str, Any] | None) -> dict[str, Any]:
61
+ raise RuntimeError(
62
+ f"{self.__class__.__name__} does not support repeated cycling."
63
+ )
64
+
65
+
66
+ _T = TypeVar("_T")
67
+
68
+
69
+ class ShardableSourceAdapter(BaseSourceAdapter):
70
+ """Helper mixin for deterministic rank/worker sharding."""
71
+
72
+ @staticmethod
73
+ def is_assigned_index(index: int, context: SourceContext) -> bool:
74
+ return index % context.global_worker_count == context.global_worker_id
75
+
76
+ @staticmethod
77
+ def shard_items(
78
+ items: Sequence[_T],
79
+ context: SourceContext,
80
+ *,
81
+ shuffle: bool = False,
82
+ seed_offset: int = 0,
83
+ ) -> list[_T]:
84
+ assigned = list(items)
85
+ if shuffle:
86
+ random.Random(context.seed + context.epoch + seed_offset).shuffle(assigned)
87
+ return [
88
+ item
89
+ for index, item in enumerate(assigned)
90
+ if ShardableSourceAdapter.is_assigned_index(index, context)
91
+ ]
src/dots_tts/data/source_adapters/jsonl_manifest_adapter.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import random
5
+ from collections.abc import Iterable, Iterator
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ from dots_tts.data.source_adapters.base_adapter import (
10
+ BaseSourceAdapter,
11
+ ShardableSourceAdapter,
12
+ SourceContext,
13
+ )
14
+
15
+
16
+ class JsonlManifestSourceAdapter(ShardableSourceAdapter, BaseSourceAdapter):
17
+ """Finite adapter for line-delimited JSON manifests."""
18
+
19
+ def __init__(
20
+ self,
21
+ *,
22
+ manifest_path: str,
23
+ fid_key: str = "fid",
24
+ text_key: str = "text",
25
+ audio_key: str = "audio",
26
+ shuffle: bool = False,
27
+ encoding: str = "utf-8",
28
+ ):
29
+ self.manifest_path = Path(manifest_path)
30
+ self.fid_key = fid_key
31
+ self.text_key = text_key
32
+ self.audio_key = audio_key
33
+ self.shuffle = shuffle
34
+ self.encoding = encoding
35
+ self._records: list[dict[str, Any]] | None = None
36
+
37
+ def initial_state(self) -> dict[str, Any]:
38
+ return {"cycle": 0, "cursor": 0}
39
+
40
+ def is_cycle_start_state(self, state: dict[str, Any] | None) -> bool:
41
+ normalized = self.normalize_state(state)
42
+ return int(normalized["cursor"]) == 0
43
+
44
+ def advance_cycle(self, state: dict[str, Any] | None) -> dict[str, Any]:
45
+ normalized = self.normalize_state(state)
46
+ return {"cycle": int(normalized["cycle"]) + 1, "cursor": 0}
47
+
48
+ def _iter_records(self) -> Iterator[dict[str, Any]]:
49
+ if not self.manifest_path.is_file():
50
+ raise FileNotFoundError(f"Manifest file not found: {self.manifest_path!s}")
51
+ with self.manifest_path.open("r", encoding=self.encoding) as fin:
52
+ for line_no, raw_line in enumerate(fin, start=1):
53
+ line = raw_line.strip()
54
+ if not line:
55
+ continue
56
+ try:
57
+ yield json.loads(line)
58
+ except json.JSONDecodeError as exc:
59
+ raise ValueError(
60
+ f"Invalid JSON at {self.manifest_path}:{line_no}"
61
+ ) from exc
62
+
63
+ def _base_records(self) -> list[dict[str, Any]]:
64
+ if self._records is None:
65
+ self._records = list(self._iter_records())
66
+ return self._records
67
+
68
+ def _build_sample(self, record: dict[str, Any]) -> dict[str, Any]:
69
+ missing = [
70
+ key
71
+ for key in (self.fid_key, self.text_key, self.audio_key)
72
+ if key not in record
73
+ ]
74
+ if missing:
75
+ raise KeyError(
76
+ f"Manifest record is missing required keys {missing}: {record}"
77
+ )
78
+
79
+ sample = {
80
+ "fid": str(record[self.fid_key]),
81
+ "text": record[self.text_key],
82
+ "audio": record[self.audio_key],
83
+ }
84
+ for key, value in record.items():
85
+ if key in {self.fid_key, self.text_key, self.audio_key}:
86
+ continue
87
+ sample[key] = value
88
+ return sample
89
+
90
+ def _indices_for_cycle(
91
+ self,
92
+ context: SourceContext,
93
+ *,
94
+ cycle: int,
95
+ ) -> list[int]:
96
+ indices = list(range(len(self._base_records())))
97
+ if self.shuffle:
98
+ random.Random(context.seed + context.epoch + 1009 * int(cycle)).shuffle(
99
+ indices
100
+ )
101
+ indices = [
102
+ record_index
103
+ for shuffled_index, record_index in enumerate(indices)
104
+ if self.is_assigned_index(shuffled_index, context)
105
+ ]
106
+ else:
107
+ indices = [
108
+ record_index
109
+ for record_index in indices
110
+ if self.is_assigned_index(record_index, context)
111
+ ]
112
+ return indices
113
+
114
+ def iter_samples(
115
+ self,
116
+ context: SourceContext,
117
+ *,
118
+ state: dict[str, Any] | None = None,
119
+ ) -> Iterable[dict[str, Any]]:
120
+ live_state = self.normalize_state(state)
121
+ cycle = int(live_state["cycle"])
122
+ cursor = int(live_state["cursor"])
123
+ records = self._base_records()
124
+ indices = self._indices_for_cycle(context, cycle=cycle)
125
+
126
+ for position in range(cursor, len(indices)):
127
+ sample = self._build_sample(records[indices[position]])
128
+ sample["_adapter_state"] = {
129
+ "cycle": cycle,
130
+ "cursor": position + 1,
131
+ }
132
+ yield sample
src/dots_tts/data/source_adapters/multi_source_adapter.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Iterable
4
+ from copy import deepcopy
5
+ from dataclasses import dataclass
6
+
7
+ from dots_tts.data.pipelines.base import BaseSamplePipeline
8
+ from dots_tts.data.source_adapters.base_adapter import (
9
+ BaseSourceAdapter,
10
+ SourceContext,
11
+ )
12
+
13
+
14
+ @dataclass(frozen=True)
15
+ class SourceSpec:
16
+ name: str
17
+ weight: float
18
+ adapter: BaseSourceAdapter
19
+ pipeline: BaseSamplePipeline
20
+
21
+
22
+ _UINT64_MASK = 0xFFFFFFFFFFFFFFFF
23
+
24
+
25
+ def _mix_uint64(value: int) -> int:
26
+ value = (value ^ (value >> 30)) * 0xBF58476D1CE4E5B9
27
+ value &= _UINT64_MASK
28
+ value = (value ^ (value >> 27)) * 0x94D049BB133111EB
29
+ value &= _UINT64_MASK
30
+ return (value ^ (value >> 31)) & _UINT64_MASK
31
+
32
+
33
+ def _stable_seed(*parts: int) -> int:
34
+ value = 0x9E3779B97F4A7C15
35
+ for part in parts:
36
+ value = (value + int(part) + 0x9E3779B97F4A7C15) & _UINT64_MASK
37
+ value = _mix_uint64(value)
38
+ return value
39
+
40
+
41
+ class SequentialMultiSourceAdapter(BaseSourceAdapter):
42
+ """Finite adapter that concatenates sources in the configured order."""
43
+
44
+ def __init__(self, *, sources: list[SourceSpec]):
45
+ if not sources:
46
+ raise ValueError(
47
+ "SequentialMultiSourceAdapter requires at least one source."
48
+ )
49
+ self.sources = list(sources)
50
+
51
+ def initial_state(self) -> dict:
52
+ return {
53
+ "source_index": 0,
54
+ "sources": {
55
+ source.name: source.adapter.initial_state() for source in self.sources
56
+ },
57
+ }
58
+
59
+ def is_cycle_start_state(self, state: dict | None) -> bool:
60
+ normalized = self.normalize_state(state)
61
+ if int(normalized["source_index"]) != 0:
62
+ return False
63
+ return all(
64
+ source.adapter.is_cycle_start_state(normalized["sources"][source.name])
65
+ for source in self.sources
66
+ )
67
+
68
+ def normalize_state(self, state: dict | None) -> dict:
69
+ normalized = super().normalize_state(state)
70
+ source_states = normalized.get("sources") or {}
71
+ normalized["sources"] = {
72
+ source.name: source.adapter.clone_state(source_states.get(source.name))
73
+ for source in self.sources
74
+ }
75
+ normalized["source_index"] = int(normalized.get("source_index", 0))
76
+ return normalized
77
+
78
+ def clone_state(self, state: dict | None) -> dict:
79
+ return deepcopy(self.normalize_state(state))
80
+
81
+ def iter_samples(
82
+ self,
83
+ context: SourceContext,
84
+ *,
85
+ state: dict | None = None,
86
+ ) -> Iterable[dict]:
87
+ live_state = self.normalize_state(state)
88
+ start_index = int(live_state["source_index"])
89
+ for index in range(start_index, len(self.sources)):
90
+ source = self.sources[index]
91
+ child_state = live_state["sources"][source.name]
92
+ raw_iter = source.adapter.iter_samples(context, state=child_state)
93
+ for sample in source.pipeline(raw_iter):
94
+ item = dict(sample)
95
+ next_child_state = item.pop("_adapter_state", None)
96
+ if next_child_state is None:
97
+ raise RuntimeError(
98
+ f"{source.adapter.__class__.__name__} must attach '_adapter_state' to samples."
99
+ )
100
+ live_state["source_index"] = index
101
+ live_state["sources"][source.name] = source.adapter.clone_state(
102
+ next_child_state
103
+ )
104
+ item["source_name"] = source.name
105
+ item["_adapter_state"] = self.clone_state(live_state)
106
+ yield item
107
+ live_state["source_index"] = index + 1
108
+
109
+
110
+ class WeightedMultiSourceAdapter(BaseSourceAdapter):
111
+ """Infinite weighted sampler that cycles each child source independently."""
112
+
113
+ def __init__(self, *, sources: list[SourceSpec]):
114
+ if not sources:
115
+ raise ValueError("WeightedMultiSourceAdapter requires at least one source.")
116
+ invalid = [source.name for source in sources if float(source.weight) <= 0.0]
117
+ if invalid:
118
+ raise ValueError(f"Source weights must be positive: {invalid}")
119
+ self.sources = list(sources)
120
+ self._cumulative_weights = []
121
+ total = 0.0
122
+ for source in self.sources:
123
+ total += float(source.weight)
124
+ self._cumulative_weights.append(total)
125
+ self._total_weight = total
126
+
127
+ def initial_state(self) -> dict:
128
+ return {
129
+ "draw_count": 0,
130
+ "sources": {
131
+ source.name: source.adapter.initial_state() for source in self.sources
132
+ },
133
+ }
134
+
135
+ def is_cycle_start_state(self, state: dict | None) -> bool:
136
+ normalized = self.normalize_state(state)
137
+ if int(normalized["draw_count"]) != 0:
138
+ return False
139
+ return all(
140
+ source.adapter.is_cycle_start_state(normalized["sources"][source.name])
141
+ for source in self.sources
142
+ )
143
+
144
+ def normalize_state(self, state: dict | None) -> dict:
145
+ normalized = super().normalize_state(state)
146
+ source_states = normalized.get("sources") or {}
147
+ normalized["sources"] = {
148
+ source.name: source.adapter.clone_state(source_states.get(source.name))
149
+ for source in self.sources
150
+ }
151
+ normalized["draw_count"] = int(normalized.get("draw_count", 0))
152
+ return normalized
153
+
154
+ def clone_state(self, state: dict | None) -> dict:
155
+ return deepcopy(self.normalize_state(state))
156
+
157
+ def _source_draw_value(self, context: SourceContext, draw_count: int) -> float:
158
+ raw = _stable_seed(
159
+ context.seed,
160
+ context.epoch,
161
+ context.rank,
162
+ context.worker_id,
163
+ draw_count,
164
+ )
165
+ return (raw / float(1 << 64)) * self._total_weight
166
+
167
+ def _pick_source(self, context: SourceContext, draw_count: int) -> SourceSpec:
168
+ draw_value = self._source_draw_value(context, draw_count)
169
+ for source, upper in zip(self.sources, self._cumulative_weights, strict=True):
170
+ if draw_value < upper:
171
+ return source
172
+ return self.sources[-1]
173
+
174
+ def iter_samples(
175
+ self,
176
+ context: SourceContext,
177
+ *,
178
+ state: dict | None = None,
179
+ ) -> Iterable[dict]:
180
+ live_state = self.normalize_state(state)
181
+ iterators: dict[str, object] = {}
182
+
183
+ while True:
184
+ draw_count = int(live_state["draw_count"])
185
+ source = self._pick_source(context, draw_count)
186
+
187
+ while True:
188
+ child_state = live_state["sources"][source.name]
189
+ child_iter = iterators.get(source.name)
190
+ if child_iter is None:
191
+ raw_iter = source.adapter.iter_samples(context, state=child_state)
192
+ child_iter = iter(source.pipeline(raw_iter))
193
+ iterators[source.name] = child_iter
194
+
195
+ try:
196
+ sample = dict(next(child_iter))
197
+ except StopIteration:
198
+ if source.adapter.is_cycle_start_state(child_state):
199
+ raise RuntimeError(
200
+ "Weighted source yielded no samples for this worker. "
201
+ f"source={source.name!r}, worker={context.global_worker_id}, "
202
+ f"epoch={context.epoch}"
203
+ )
204
+ iterators.pop(source.name, None)
205
+ live_state["sources"][source.name] = source.adapter.advance_cycle(
206
+ child_state
207
+ )
208
+ continue
209
+
210
+ next_child_state = sample.pop("_adapter_state", None)
211
+ if next_child_state is None:
212
+ raise RuntimeError(
213
+ f"{source.adapter.__class__.__name__} must attach '_adapter_state' to samples."
214
+ )
215
+ live_state["sources"][source.name] = source.adapter.clone_state(
216
+ next_child_state
217
+ )
218
+ live_state["draw_count"] = draw_count + 1
219
+ sample["source_name"] = source.name
220
+ sample["_adapter_state"] = self.clone_state(live_state)
221
+ yield sample
222
+ break
src/dots_tts/data/streaming.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ import multiprocessing as mp
5
+ from collections.abc import Iterable
6
+ from copy import deepcopy
7
+
8
+ from torch.utils.data import DataLoader, IterableDataset, get_worker_info
9
+
10
+ from dots_tts.data.batchers import OnlineBatcher
11
+ from dots_tts.utils.profiling import ensure_data_profiler
12
+ from dots_tts.data.source_adapters.base_adapter import BaseSourceAdapter, SourceContext
13
+
14
+ _TRACKING_KEY = "__tracking_state__"
15
+ _RESUME_TOPOLOGY_KEY = "resume_topology"
16
+
17
+
18
+ def identity_collate(sample):
19
+ return sample
20
+
21
+
22
+ class StreamingSampleDataset(IterableDataset):
23
+ def __init__(
24
+ self,
25
+ *,
26
+ source: BaseSourceAdapter,
27
+ rank: int,
28
+ world_size: int,
29
+ seed: int,
30
+ ):
31
+ self.source = source
32
+ self.rank = int(rank)
33
+ self.world_size = int(world_size)
34
+ self.seed = int(seed)
35
+ self._epoch = mp.Value("q", 0)
36
+ self._pending_resume_state: dict | None = None
37
+
38
+ def load_state_dict(self, state: dict | None) -> None:
39
+ self._pending_resume_state = deepcopy(state) if state else None
40
+
41
+ def set_epoch(self, epoch: int) -> None:
42
+ with self._epoch.get_lock():
43
+ self._epoch.value = int(epoch)
44
+
45
+ def _current_epoch(self) -> int:
46
+ with self._epoch.get_lock():
47
+ return int(self._epoch.value)
48
+
49
+ def _take_resume_state(self, epoch: int) -> dict | None:
50
+ if (
51
+ self._pending_resume_state is None
52
+ or int(self._pending_resume_state.get("epoch", -1)) != int(epoch)
53
+ ):
54
+ return None
55
+ state = deepcopy(self._pending_resume_state)
56
+ self._pending_resume_state = None
57
+ return state
58
+
59
+ @staticmethod
60
+ def _validate_resume_topology(
61
+ resume_state: dict,
62
+ *,
63
+ context: SourceContext,
64
+ loader_num_workers: int,
65
+ ) -> None:
66
+ resume_topology = resume_state.get(_RESUME_TOPOLOGY_KEY)
67
+ if not isinstance(resume_topology, dict):
68
+ raise RuntimeError(
69
+ "Resume state is missing required worker topology metadata."
70
+ )
71
+ expected_world_size = int(resume_topology["world_size"])
72
+ expected_num_workers = int(resume_topology["loader_num_workers"])
73
+ expected_global_worker_count = int(resume_topology["global_worker_count"])
74
+ current_num_workers = int(loader_num_workers)
75
+ current_global_worker_count = int(context.global_worker_count)
76
+ if (
77
+ expected_world_size != int(context.world_size)
78
+ or expected_num_workers != current_num_workers
79
+ or expected_global_worker_count != current_global_worker_count
80
+ ):
81
+ raise RuntimeError(
82
+ "Resume requires the same data worker topology as the saved state. "
83
+ f"saved(world_size={expected_world_size}, "
84
+ f"num_workers_per_rank={expected_num_workers}, "
85
+ f"global_worker_count={expected_global_worker_count}), "
86
+ f"current(world_size={context.world_size}, "
87
+ f"num_workers_per_rank={current_num_workers}, "
88
+ f"global_worker_count={current_global_worker_count})."
89
+ )
90
+
91
+ def __iter__(self) -> Iterable[dict]:
92
+ worker_info = get_worker_info()
93
+ if worker_info is None:
94
+ worker_id = 0
95
+ loader_num_workers = 0
96
+ effective_num_workers = 1
97
+ else:
98
+ worker_id = worker_info.id
99
+ loader_num_workers = worker_info.num_workers
100
+ effective_num_workers = worker_info.num_workers
101
+
102
+ epoch = self._current_epoch()
103
+ context = SourceContext(
104
+ epoch=epoch,
105
+ rank=self.rank,
106
+ world_size=self.world_size,
107
+ worker_id=worker_id,
108
+ num_workers=effective_num_workers,
109
+ seed=self.seed,
110
+ )
111
+ resume_state = self._take_resume_state(epoch)
112
+ if resume_state is not None:
113
+ self._validate_resume_topology(
114
+ resume_state,
115
+ context=context,
116
+ loader_num_workers=loader_num_workers,
117
+ )
118
+ worker_state = (
119
+ None
120
+ if resume_state is None
121
+ else (resume_state.get("workers") or {}).get(str(context.global_worker_id))
122
+ )
123
+ sample_iter = self.source.iter_samples(
124
+ context,
125
+ state=None if worker_state is None else worker_state.get("adapter_state"),
126
+ )
127
+ for sample in sample_iter:
128
+ sample["data_worker_id"] = context.worker_id
129
+ sample["data_global_worker_id"] = context.global_worker_id
130
+ yield sample
131
+
132
+
133
+ class _DataStateTracker:
134
+ def __init__(self, *, num_tokens_per_epoch: int | None):
135
+ self.num_tokens_per_epoch = (
136
+ None if num_tokens_per_epoch is None else int(num_tokens_per_epoch)
137
+ )
138
+ self._pending_state: dict | None = None
139
+ self._reset_for_epoch(epoch=0)
140
+
141
+ def _reset_for_epoch(self, *, epoch: int) -> None:
142
+ self.epoch = int(epoch)
143
+ self.samples_emitted = 0
144
+ self.num_text_tokens = 0
145
+ self.num_audio_tokens = 0
146
+ self.num_total_tokens = 0
147
+ self.workers: dict[str, dict] = {}
148
+ self._next_sample_order_by_worker: dict[str, int] = {}
149
+
150
+ def load_state_dict(self, state: dict | None) -> None:
151
+ self._pending_state = deepcopy(state) if state else None
152
+
153
+ def set_epoch(self, epoch: int) -> None:
154
+ if self._pending_state is not None and int(
155
+ self._pending_state.get("epoch", -1)
156
+ ) == int(epoch):
157
+ state = deepcopy(self._pending_state)
158
+ self._pending_state = None
159
+ self.epoch = int(state.get("epoch", epoch))
160
+ self.samples_emitted = int(state.get("samples_emitted", 0))
161
+ self.num_text_tokens = int(state.get("num_text_tokens", 0))
162
+ self.num_audio_tokens = int(state.get("num_audio_tokens", 0))
163
+ self.num_total_tokens = int(state.get("num_total_tokens", 0))
164
+ self.workers = deepcopy(state.get("workers") or {})
165
+ self._next_sample_order_by_worker = {
166
+ worker_key: int((worker_state or {}).get("sample_order", -1)) + 1
167
+ for worker_key, worker_state in self.workers.items()
168
+ }
169
+ return
170
+ self._reset_for_epoch(epoch=int(epoch))
171
+
172
+ def should_stop(self) -> bool:
173
+ return (
174
+ self.num_tokens_per_epoch is not None
175
+ and self.num_total_tokens >= self.num_tokens_per_epoch
176
+ )
177
+
178
+ def stage_sample(self, sample: dict) -> dict:
179
+ item = dict(sample)
180
+ worker_key = str(item.pop("data_global_worker_id"))
181
+ item.pop("data_worker_id", None)
182
+ adapter_state = item.pop("_adapter_state", None)
183
+ sample_order = int(self._next_sample_order_by_worker.get(worker_key, 0))
184
+ self._next_sample_order_by_worker[worker_key] = sample_order + 1
185
+ item[_TRACKING_KEY] = {
186
+ "worker_key": worker_key,
187
+ "adapter_state": deepcopy(adapter_state),
188
+ "sample_order": sample_order,
189
+ "num_text_tokens": int(item["num_text_tokens"]),
190
+ "num_audio_tokens": int(item["num_audio_tokens"]),
191
+ "num_total_tokens": int(
192
+ item.get("num_total_tokens", item["input_ids_length"])
193
+ ),
194
+ }
195
+ return item
196
+
197
+ def _pop_tracking(self, sample: dict) -> tuple[dict, dict]:
198
+ item = dict(sample)
199
+ tracking = item.pop(_TRACKING_KEY, None)
200
+ if not isinstance(tracking, dict):
201
+ raise RuntimeError("Tracked sample is missing internal resume metadata.")
202
+ return item, tracking
203
+
204
+ def _advance_worker(self, tracking: dict) -> None:
205
+ adapter_state = tracking.get("adapter_state")
206
+ if adapter_state is None:
207
+ return
208
+ worker_key = str(tracking["worker_key"])
209
+ sample_order = int(tracking.get("sample_order", -1))
210
+ current_state = self.workers.get(worker_key)
211
+ current_order = int((current_state or {}).get("sample_order", -1))
212
+ if current_order >= sample_order:
213
+ return
214
+ self.workers[worker_key] = {
215
+ "adapter_state": deepcopy(adapter_state),
216
+ "sample_order": sample_order,
217
+ }
218
+
219
+ def mark_samples_dropped(self, samples: list[dict]) -> None:
220
+ for sample in samples:
221
+ _, tracking = self._pop_tracking(sample)
222
+ self._advance_worker(tracking)
223
+
224
+ def commit_batch(self, samples: list[dict]) -> list[dict]:
225
+ committed: list[dict] = []
226
+ for sample in samples:
227
+ item, tracking = self._pop_tracking(sample)
228
+ self._advance_worker(tracking)
229
+ self.samples_emitted += 1
230
+ self.num_text_tokens += int(tracking["num_text_tokens"])
231
+ self.num_audio_tokens += int(tracking["num_audio_tokens"])
232
+ self.num_total_tokens += int(tracking["num_total_tokens"])
233
+ committed.append(item)
234
+ return committed
235
+
236
+ def state_dict(self) -> dict:
237
+ return {
238
+ "epoch": int(self.epoch),
239
+ "samples_emitted": int(self.samples_emitted),
240
+ "num_text_tokens": int(self.num_text_tokens),
241
+ "num_audio_tokens": int(self.num_audio_tokens),
242
+ "num_total_tokens": int(self.num_total_tokens),
243
+ "workers": deepcopy(self.workers),
244
+ "num_tokens_per_epoch": self.num_tokens_per_epoch,
245
+ }
246
+
247
+
248
+ class BatchedDataStream:
249
+ def __init__(
250
+ self,
251
+ *,
252
+ sample_dataset: StreamingSampleDataset,
253
+ data_cfg,
254
+ tokenizer,
255
+ num_tokens_per_epoch: int | None,
256
+ profiler=None,
257
+ ):
258
+ from dots_tts.data.collator import PadCollator
259
+
260
+ self.sample_dataset = sample_dataset
261
+ self.profiler = ensure_data_profiler(profiler)
262
+ llm_token_rate = (
263
+ float(data_cfg.train_audio_sample_rate)
264
+ / float(data_cfg.audio_samples_per_llm_token)
265
+ )
266
+ self.batcher = OnlineBatcher(
267
+ max_audio_tokens_in_batch=max(
268
+ 1,
269
+ math.ceil(float(data_cfg.max_audio_seconds_in_batch) * llm_token_rate),
270
+ ),
271
+ max_text_tokens_in_batch=data_cfg.max_text_tokens_in_batch,
272
+ max_batch_size=data_cfg.max_samples_per_batch,
273
+ sample_pool_size=data_cfg.bucketing_pool_size,
274
+ profiler=self.profiler,
275
+ )
276
+ self.sample_loader = None
277
+ self.collator = PadCollator(tokenizer)
278
+ self.data_state = _DataStateTracker(
279
+ num_tokens_per_epoch=num_tokens_per_epoch
280
+ )
281
+ self._decision_iterator = None
282
+ self._sample_iterator = None
283
+ self._pending_batch = None
284
+ self._pending_samples = None
285
+
286
+ def attach_loader(self, loader: DataLoader) -> None:
287
+ self.sample_loader = loader
288
+
289
+ def close(self) -> None:
290
+ self._reset_iteration_state()
291
+ self.sample_loader = None
292
+
293
+ def load_state_dict(self, state: dict | None) -> None:
294
+ self.data_state.load_state_dict(state)
295
+ self.sample_dataset.load_state_dict(state)
296
+ self._reset_iteration_state()
297
+
298
+ def state_dict(self) -> dict:
299
+ if self.sample_loader is None:
300
+ raise RuntimeError("BatchedDataStream has no attached sample loader.")
301
+ if self._pending_batch is not None or self._pending_samples is not None:
302
+ raise RuntimeError(
303
+ "Cannot serialize BatchedDataStream while a batch is pending commit."
304
+ )
305
+ loader_num_workers = int(getattr(self.sample_loader, "num_workers", 0))
306
+ effective_num_workers = max(1, loader_num_workers)
307
+ state = self.data_state.state_dict()
308
+ state[_RESUME_TOPOLOGY_KEY] = {
309
+ "world_size": int(self.sample_dataset.world_size),
310
+ "loader_num_workers": loader_num_workers,
311
+ "global_worker_count": int(self.sample_dataset.world_size)
312
+ * effective_num_workers,
313
+ }
314
+ return state
315
+
316
+ def set_epoch(self, epoch: int) -> None:
317
+ self.sample_dataset.set_epoch(epoch)
318
+ self.data_state.set_epoch(epoch)
319
+ self._reset_iteration_state()
320
+
321
+ def _reset_iteration_state(self) -> None:
322
+ close_iterator = getattr(self._decision_iterator, "close", None)
323
+ if callable(close_iterator):
324
+ close_iterator()
325
+ self._decision_iterator = None
326
+ self._sample_iterator = None
327
+ self._pending_batch = None
328
+ self._pending_samples = None
329
+
330
+ def _iter_staged_samples(self):
331
+ if self.sample_loader is None:
332
+ raise RuntimeError("BatchedDataStream has no attached sample loader.")
333
+ self._sample_iterator = iter(self.sample_loader)
334
+ profiler = self.profiler
335
+ try:
336
+ while True:
337
+ if self.data_state.should_stop():
338
+ return
339
+ try:
340
+ with profiler.measure("main.loader_wait_next_sample"):
341
+ sample = next(self._sample_iterator)
342
+ except StopIteration:
343
+ return
344
+ if sample is None:
345
+ continue
346
+ with profiler.measure("main.stage_sample"):
347
+ staged = self.data_state.stage_sample(sample)
348
+ yield staged
349
+ finally:
350
+ self._sample_iterator = None
351
+
352
+ def _decision_stream(self):
353
+ if self._decision_iterator is None:
354
+ self._decision_iterator = iter(
355
+ self.batcher.build_decisions(self._iter_staged_samples())
356
+ )
357
+ return self._decision_iterator
358
+
359
+ def peek_batch(self) -> tuple[dict | None, bool]:
360
+ if self._pending_batch is not None:
361
+ return self._pending_batch, True
362
+
363
+ for decision in self._decision_stream():
364
+ if decision.dropped_samples:
365
+ self.data_state.mark_samples_dropped(decision.dropped_samples)
366
+ if not decision.batch_samples:
367
+ continue
368
+ self._pending_samples = decision.batch_samples
369
+ with self.profiler.measure(
370
+ "main.collate_batch",
371
+ count=len(decision.batch_samples),
372
+ ):
373
+ self._pending_batch = self.collator(decision.batch_samples)
374
+ return self._pending_batch, True
375
+ return None, False
376
+
377
+ def commit_batch(self) -> dict:
378
+ if self._pending_batch is None or self._pending_samples is None:
379
+ raise RuntimeError("BatchedDataStream has no pending batch to commit.")
380
+ pending_batch = self._pending_batch
381
+ self.data_state.commit_batch(self._pending_samples)
382
+ self._pending_batch = None
383
+ self._pending_samples = None
384
+ return pending_batch
385
+
386
+ def discard_batch(self) -> None:
387
+ if self._pending_batch is None or self._pending_samples is None:
388
+ raise RuntimeError("BatchedDataStream has no pending batch to discard.")
389
+ self._pending_batch = None
390
+ self._pending_samples = None
391
+
392
+ def __iter__(self):
393
+ while True:
394
+ batch, has_batch = self.peek_batch()
395
+ if not has_batch:
396
+ return
397
+ self.commit_batch()
398
+ yield batch
399
+ if self.data_state.should_stop():
400
+ return
src/dots_tts/models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Model families."""
src/dots_tts/models/dots_tts/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """dots_tts model package."""
src/dots_tts/models/dots_tts/config.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dots_tts.config.base import ConfigBase, StrictConfigBase
4
+ from dots_tts.modules.vocoder.config import AudioVAEConfig
5
+
6
+
7
+ class _EncoderConfig(ConfigBase):
8
+ num_layers: int = 6
9
+ num_heads: int = 16
10
+ hidden_size: int = 1024
11
+ ffn_hidden_size: int = 4096
12
+ modulation: bool = False
13
+ qkv_bias: bool = False
14
+ qk_norm: bool = False
15
+ attn_dropout: float = 0.0
16
+ dropout: float = 0.0
17
+ norm_layer: str = "LayerNorm"
18
+ alibi_bias: bool = False
19
+ rotary_bias: bool = False
20
+ rotary_theta: float | None = 10000
21
+ input_dim: int = 1024
22
+ causal: bool = True
23
+
24
+
25
+ class _DiTConfig(ConfigBase):
26
+ num_layers: int = 18
27
+ num_heads: int = 16
28
+ hidden_size: int = 1024
29
+ ffn_hidden_size: int = 4096
30
+ modulation: bool = True
31
+ qkv_bias: bool = False
32
+ qk_norm: bool = False
33
+ attn_dropout: float = 0.0
34
+ dropout: float = 0.0
35
+ norm_layer: str = "LayerNorm"
36
+ alibi_bias: bool = False
37
+ rotary_bias: bool = True
38
+ rotary_theta: float | None = 10000
39
+
40
+
41
+ class LossConfig(StrictConfigBase):
42
+ ce_weight: float = 1.0
43
+ fm_weight: float = 1.0
44
+ eos_weight: float = 1.0
45
+
46
+
47
+ class MeanFlowConfig(ConfigBase):
48
+ enabled: bool = False
49
+ use_duration_embedding: bool = True
50
+
51
+
52
+ class ModelConfig(ConfigBase):
53
+ model_type: str = "dots_tts"
54
+ latent_dim: int
55
+ patch_size: int
56
+ cfg_droprate: float = 0.2
57
+ PatchEncoder: _EncoderConfig
58
+ DiT: _DiTConfig
59
+ vocoder: AudioVAEConfig
60
+ fm_sigma: float = 0.0
61
+ xvec_drop_rate: float = 0.2
62
+ campplus_embedding_size: int | None = 512
63
+ xvec_max_audio_seconds: float = 10.0
64
+ meanflow: MeanFlowConfig | None = None
65
+
66
+
67
+ __all__ = [
68
+ "LossConfig",
69
+ "MeanFlowConfig",
70
+ "ModelConfig",
71
+ ]
src/dots_tts/models/dots_tts/core.py ADDED
@@ -0,0 +1,910 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from dataclasses import dataclass
3
+ from typing import Any, Callable
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from einops import rearrange
8
+ from loguru import logger
9
+ from torch.nn.utils.rnn import pad_sequence
10
+ from torchdiffeq import odeint
11
+ from transformers import Qwen2Config, Qwen2ForCausalLM
12
+
13
+ from dots_tts.models.dots_tts.config import ModelConfig
14
+ from dots_tts.modules.backbone.dit import DiT
15
+ from dots_tts.modules.backbone.semantic_encoder import VAESemanticEncoder
16
+ from dots_tts.utils.tokenizer import (
17
+ AUDIO_COMP_SPAN_TOKEN,
18
+ AUDIO_GEN_SPAN_TOKEN,
19
+ TEXT_COND_END_TOKEN,
20
+ require_token_id,
21
+ )
22
+ from dots_tts.utils.util import get_mask_from_lengths, mask_data
23
+
24
+
25
+ @dataclass(frozen=True)
26
+ class DotsTtsForwardOutput:
27
+ llm_logits: torch.Tensor
28
+ pred: torch.Tensor
29
+ target: torch.Tensor
30
+ eos_out: torch.Tensor
31
+
32
+
33
+ class DotsTtsCore(nn.Module):
34
+ # region Module construction
35
+ def __init__(
36
+ self,
37
+ config: ModelConfig,
38
+ llm_config: Qwen2Config,
39
+ tokenizer=None,
40
+ *,
41
+ latent_stats_path,
42
+ ):
43
+ super().__init__()
44
+ self.config = config
45
+ self.fm_hidden_size = config.DiT.hidden_size
46
+ self.hidden_patch_size = 1
47
+ self.cfg_droprate = config.get("cfg_droprate", 0.2)
48
+ self.latent_patch_size = config.patch_size
49
+ self.latent_dim = config.latent_dim
50
+ self.xvec_dim = config.campplus_embedding_size
51
+ self.xvec_drop_rate = config.get("xvec_drop_rate", 0.2)
52
+
53
+ # Setup tokenizer
54
+ self.tokenizer = tokenizer
55
+ if self.tokenizer is None:
56
+ raise RuntimeError("Tokenizer must be provided before building the model.")
57
+ if llm_config is None:
58
+ raise RuntimeError("LLM config must be provided before building the model.")
59
+ self.pad_token_id = getattr(self.tokenizer, "pad_token_id", None)
60
+ self.audio_gen_span_id = require_token_id(self.tokenizer, AUDIO_GEN_SPAN_TOKEN)
61
+ self.audio_comp_span_id = require_token_id(
62
+ self.tokenizer, AUDIO_COMP_SPAN_TOKEN
63
+ )
64
+ self.text_cond_end_id = require_token_id(self.tokenizer, TEXT_COND_END_TOKEN)
65
+
66
+ # Setup LLM with language modeling head so we can obtain logits directly
67
+ llm_config = copy.deepcopy(llm_config)
68
+ llm_config.vocab_size = len(self.tokenizer)
69
+ self.llm = Qwen2ForCausalLM._from_config(
70
+ llm_config,
71
+ dtype=torch.float32,
72
+ )
73
+ self.llm_hidden_size = self.llm.config.hidden_size
74
+
75
+ self.patch_encoder = VAESemanticEncoder(
76
+ in_dim=self.latent_dim,
77
+ out_dim=self.llm_hidden_size,
78
+ config=config,
79
+ )
80
+
81
+ # Setup Flow matching related modules
82
+ self.hidden_proj = nn.Linear(self.llm_hidden_size, self.fm_hidden_size)
83
+ self.latent_proj = nn.Linear(self.latent_dim, self.fm_hidden_size)
84
+ self.coordinate_proj = nn.Linear(self.latent_dim, self.fm_hidden_size)
85
+ self.xvec_proj = nn.Sequential(
86
+ nn.Linear(self.xvec_dim, self.fm_hidden_size),
87
+ nn.LayerNorm(self.fm_hidden_size),
88
+ )
89
+ self.meanflow_config = config.meanflow if config.meanflow is not None else None
90
+ self.mode = (
91
+ "meanflow"
92
+ if self.meanflow_config is not None and self.meanflow_config.enabled
93
+ else "flow_matching"
94
+ )
95
+ dit_mode = (
96
+ "meanflow"
97
+ if self.mode == "meanflow"
98
+ and self.meanflow_config.use_duration_embedding
99
+ else "flow_matching"
100
+ )
101
+ self.velocity_field_predictor = DiT(
102
+ in_dim=self.fm_hidden_size,
103
+ out_dim=self.latent_dim,
104
+ transformer_config=config.DiT,
105
+ mode=dit_mode,
106
+ )
107
+
108
+ # Setup eos predictor
109
+ self.eos_proj = nn.Sequential(
110
+ nn.Linear(self.llm_hidden_size, self.llm_hidden_size),
111
+ nn.SiLU(),
112
+ nn.Linear(self.llm_hidden_size, 2),
113
+ )
114
+
115
+ # Helpers
116
+ self.fm_helper = FlowMatchingHelper(sigma=config.get("fm_sigma", 0.0))
117
+ self.causal_helper = CausalHelper()
118
+ self.io_helper = IOHelper(latent_stats_path=latent_stats_path)
119
+ self.audio_span_token_ids: list[int] = [
120
+ self.audio_gen_span_id,
121
+ self.audio_comp_span_id,
122
+ ]
123
+ # endregion Module construction
124
+
125
+ # region Training forward path
126
+ def forward(self, data: dict[str, Any]) -> DotsTtsForwardOutput:
127
+ input_ids: torch.Tensor = data["input_ids"]
128
+ input_ids_lengths: torch.Tensor = data["input_ids_lengths"]
129
+ input_span_mask: torch.Tensor = data["input_span_mask"]
130
+ output_span_mask: torch.Tensor = data["output_span_mask"]
131
+ batch_size = input_ids.size(0)
132
+ device = input_ids.device
133
+
134
+ latents: torch.Tensor | None = data.get("latents")
135
+ latents_sampled: torch.Tensor | None = data.get("latents_sampled")
136
+ latent_lengths: torch.Tensor | None = data.get("latent_lengths")
137
+ has_latents = latents is not None or latents_sampled is not None
138
+
139
+ patch_embeddings: torch.Tensor | None
140
+ valid_patch_counts: torch.Tensor | None
141
+ if has_latents:
142
+ if latents_sampled is None:
143
+ latents_sampled = self.io_helper.sample_from_latent(latents)
144
+ patch_embeddings = self.patch_encoder(
145
+ latents_sampled, x_lens=latent_lengths
146
+ )
147
+ valid_patch_counts = latent_lengths // self.latent_patch_size
148
+ latents_sampled = self.io_helper.normalize(latents_sampled)
149
+ else:
150
+ latents_sampled = None
151
+ patch_embeddings = None
152
+ valid_patch_counts = torch.zeros(
153
+ batch_size, dtype=torch.long, device=device
154
+ )
155
+
156
+ input_span_counts = input_span_mask.sum(dim=1)
157
+ if input_span_counts.sum() > 0 and patch_embeddings is None:
158
+ raise RuntimeError(
159
+ "Found audio span tokens but no latents provided to compute patch embeddings."
160
+ )
161
+
162
+ # Token embeddings with audio span replacement
163
+ inputs_embeds = self.llm.get_input_embeddings()(input_ids)
164
+ if patch_embeddings is not None:
165
+ inputs_embeds = inputs_embeds.clone()
166
+ patch_embeddings = patch_embeddings.to(inputs_embeds.dtype)
167
+ for b in range(batch_size):
168
+ span_num = input_span_counts[b].item()
169
+ if span_num == 0:
170
+ continue
171
+ expected = valid_patch_counts[b].item()
172
+ if expected != span_num:
173
+ raise RuntimeError(
174
+ f"Mismatch between span tokens ({span_num}) and latent patches ({expected}) for sample {b}."
175
+ )
176
+ indices = input_span_mask[b].nonzero(as_tuple=False).squeeze(-1)
177
+ inputs_embeds[b, indices, :] = patch_embeddings[b, :span_num, :]
178
+
179
+ # LLM forward pass to obtain logits & hidden states
180
+ _llm_attn_mask, llm_seq_mask, _ = self.causal_helper.create_causal_mask_and_pos(
181
+ seq_lens=input_ids_lengths, max_len=input_ids.size(1)
182
+ )
183
+ llm_outputs = self.llm(
184
+ inputs_embeds=inputs_embeds,
185
+ attention_mask=llm_seq_mask.long(),
186
+ use_cache=False,
187
+ output_hidden_states=True,
188
+ return_dict=True,
189
+ )
190
+ llm_logits = llm_outputs.logits # [B, L, V]
191
+ llm_hidden = llm_outputs.hidden_states[-1] # [B, L, H]
192
+
193
+ # eos prediction, before cfg masking
194
+ eos = self.eos_proj(llm_hidden.detach())
195
+
196
+ # Flow matching forward
197
+ total_patches = int(output_span_mask.sum().item())
198
+ if total_patches > 0 and latents_sampled is None:
199
+ raise RuntimeError("Flow matching requested but latents are missing.")
200
+ if total_patches > 0:
201
+ xvec_cond = self.xvec_proj(data["xvector"])
202
+ vocal_mask = data.get("vocal_mask")
203
+ if vocal_mask is None:
204
+ vocal_mask = torch.ones((batch_size,), device=device, dtype=torch.bool)
205
+ xvec_drop_mask = (
206
+ torch.empty((batch_size,), device=device, dtype=torch.float32).uniform_(
207
+ 0, 1
208
+ )
209
+ < self.xvec_drop_rate
210
+ )
211
+ xvec_drop_mask = xvec_drop_mask & vocal_mask
212
+ xvec_cond = mask_data(xvec_cond, xvec_drop_mask)
213
+
214
+ hiddens_for_fm = torch.where(
215
+ output_span_mask.unsqueeze(-1), llm_hidden, inputs_embeds
216
+ )
217
+
218
+ # Prepare DiT inputs
219
+ (
220
+ fm_seq,
221
+ target,
222
+ fm_attn_mask,
223
+ fm_seq_mask,
224
+ fm_pos_ids,
225
+ times,
226
+ fm_prefix_lengths,
227
+ fm_gen_lengths,
228
+ fm_gen_patch_size,
229
+ ) = self.io_helper.prepare_inputs_for_dit(
230
+ hiddens=hiddens_for_fm,
231
+ hidden_lens=input_ids_lengths,
232
+ latents=latents_sampled,
233
+ latent_lens=latent_lengths,
234
+ hidden_proj=self.hidden_proj,
235
+ latent_proj=self.latent_proj,
236
+ noisy_proj=self.coordinate_proj,
237
+ span_mask=output_span_mask,
238
+ hidden_patch_size=self.hidden_patch_size,
239
+ latent_patch_size=self.latent_patch_size,
240
+ fm_helper=self.fm_helper,
241
+ cfg_droprate=self.cfg_droprate,
242
+ )
243
+
244
+ # Predict velocity field
245
+ vt = self.velocity_field_predictor(
246
+ x=fm_seq,
247
+ timesteps=times,
248
+ pos_ids=fm_pos_ids,
249
+ mask=fm_seq_mask,
250
+ attn_mask=fm_attn_mask,
251
+ return_hidden_stats=False,
252
+ g_cond=xvec_cond,
253
+ )
254
+
255
+ # Get predictions and targets
256
+ pred = self.io_helper.get_dit_outputs(
257
+ pred_v=vt,
258
+ fm_prefix_lengths=fm_prefix_lengths,
259
+ fm_gen_lengths=fm_gen_lengths,
260
+ fm_gen_patch_size=fm_gen_patch_size,
261
+ latent_patch_size=self.latent_patch_size,
262
+ )
263
+ else:
264
+ # Dummy forward for velocity_field_predictor to keep gradients connected in DDP
265
+ dummy_length = self.latent_patch_size
266
+ dummy_seq_h = llm_hidden.new_zeros((1, dummy_length, self.llm_hidden_size))
267
+ dummy_seq_h = self.hidden_proj(dummy_seq_h) * 0.0 # dummy op for ddp
268
+ dummy_seq_l = llm_hidden.new_zeros((1, dummy_length, self.latent_dim))
269
+ dummy_seq_l = self.latent_proj(dummy_seq_l) * 0.0 # dummy op for ddp
270
+ dummy_seq_c = llm_hidden.new_zeros((1, dummy_length, self.latent_dim))
271
+ dummy_seq_c = self.coordinate_proj(dummy_seq_c) * 0.0 # dummy op for ddp
272
+ dummy_seq = dummy_seq_h + dummy_seq_l + dummy_seq_c
273
+ dummy_times = torch.zeros((1,), device=device, dtype=torch.float32)
274
+ dummy_attn_mask = torch.ones(
275
+ (1, dummy_length, dummy_length), device=device, dtype=torch.bool
276
+ )
277
+ dummy_out = self.velocity_field_predictor(
278
+ x=dummy_seq,
279
+ timesteps=dummy_times,
280
+ attn_mask=dummy_attn_mask,
281
+ )
282
+ pred = dummy_out[:, -self.latent_patch_size :, :]
283
+ target = pred.detach()
284
+
285
+ return DotsTtsForwardOutput(
286
+ llm_logits=llm_logits,
287
+ pred=pred,
288
+ target=target,
289
+ eos_out=eos,
290
+ )
291
+ # endregion Training forward path
292
+
293
+ # region Autoregressive and flow-matching inference steps
294
+ @torch.no_grad()
295
+ def fm_solver_step(
296
+ self,
297
+ t: torch.Tensor,
298
+ z: torch.Tensor,
299
+ *,
300
+ input_sequence: torch.Tensor,
301
+ cfg_sequence: torch.Tensor,
302
+ attn_mask: torch.Tensor,
303
+ pos_ids: torch.Tensor | None,
304
+ hidden_size: int,
305
+ patch_size: int,
306
+ g_cond: torch.Tensor | None,
307
+ guidance_scale: torch.Tensor | float,
308
+ ) -> torch.Tensor:
309
+ batch_size = input_sequence.size(0)
310
+ if input_sequence.shape != cfg_sequence.shape:
311
+ raise ValueError(
312
+ "FM input_sequence and cfg_sequence must share the same shape."
313
+ )
314
+ if input_sequence.size(1) < patch_size:
315
+ raise ValueError(
316
+ "FM input sequence must reserve at least one latent patch slot."
317
+ )
318
+ latent_start = input_sequence.size(1) - patch_size
319
+ z = self.coordinate_proj(z)
320
+ z_c = input_sequence.clone()
321
+ z_c[:, latent_start:] = z
322
+ z_branches = [z_c]
323
+ g_cond_t = (
324
+ None if g_cond is None else g_cond.to(device=z_c.device, dtype=z_c.dtype)
325
+ )
326
+ g_cond_branches = None if g_cond_t is None else [g_cond_t]
327
+
328
+ z_cfg = cfg_sequence.clone()
329
+ z_cfg[:, latent_start:] = z
330
+ z_branches.append(z_cfg)
331
+ if g_cond_branches is not None:
332
+ g_cond_branches.append(torch.zeros_like(g_cond_t))
333
+
334
+ z_z = torch.cat(z_branches, dim=0)
335
+ t_t = t.reshape(1).repeat(len(z_branches))
336
+ if g_cond_branches is not None:
337
+ g_cond_t = torch.cat(g_cond_branches, dim=0)
338
+ vt = self.velocity_field_predictor(
339
+ x=z_z,
340
+ timesteps=t_t,
341
+ attn_mask=attn_mask,
342
+ pos_ids=pos_ids,
343
+ g_cond=g_cond_t,
344
+ hidden_size=patch_size * 2 + hidden_size,
345
+ patch_size=patch_size + 1,
346
+ )
347
+ vt = vt[:, latent_start:]
348
+ vt_c = vt[:batch_size]
349
+ vt_u = vt[batch_size:]
350
+ if not torch.is_tensor(guidance_scale):
351
+ guidance_scale = vt_c.new_tensor(float(guidance_scale))
352
+ else:
353
+ guidance_scale = guidance_scale.to(device=vt_c.device, dtype=vt_c.dtype)
354
+ return vt_c + guidance_scale * (vt_c - vt_u)
355
+
356
+ @torch.no_grad()
357
+ def step_llm(
358
+ self,
359
+ inputs_embeds: torch.Tensor | None = None,
360
+ input_ids: torch.Tensor | None = None,
361
+ past_key_values: Any | None = None,
362
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any | None]:
363
+ provided = int(inputs_embeds is not None) + int(input_ids is not None)
364
+ if provided != 1:
365
+ raise ValueError(
366
+ "Exactly one of inputs_embeds or input_ids must be provided to step_llm()."
367
+ )
368
+
369
+ if inputs_embeds is not None:
370
+ pass
371
+ else:
372
+ inputs_embeds = self.llm.get_input_embeddings()(input_ids)
373
+
374
+ outputs = self.llm(
375
+ inputs_embeds=inputs_embeds,
376
+ past_key_values=past_key_values,
377
+ use_cache=True,
378
+ output_hidden_states=True,
379
+ return_dict=True,
380
+ )
381
+
382
+ hidden = outputs.hidden_states[-1]
383
+ logits = outputs.logits
384
+ past_key_values = outputs.past_key_values
385
+
386
+ return inputs_embeds, hidden, logits, past_key_values
387
+
388
+ @torch.no_grad()
389
+ def _meanflow_step_fm(
390
+ self,
391
+ *,
392
+ input_sequence: torch.Tensor,
393
+ attn_mask: torch.Tensor,
394
+ pos_ids: torch.Tensor | None,
395
+ patch_size: int,
396
+ g_cond: torch.Tensor | None = None,
397
+ nfe: int = 2,
398
+ solver_step: Callable[..., torch.Tensor] | None = None,
399
+ ) -> torch.Tensor:
400
+ if nfe <= 0:
401
+ raise ValueError(f"MeanFlow nfe must be positive, got {nfe}.")
402
+ batch_size = input_sequence.size(0)
403
+ device = input_sequence.device
404
+ dtype = input_sequence.dtype
405
+ solver_step = self.meanflow_solver_step if solver_step is None else solver_step
406
+ z = (
407
+ torch.randn(
408
+ (batch_size, patch_size, self.latent_dim),
409
+ device=device,
410
+ dtype=dtype,
411
+ )
412
+ )
413
+ times = torch.linspace(0.0, 1.0, nfe + 1, device=device, dtype=dtype)
414
+
415
+ for step in range(nfe):
416
+ t = times[step].expand(batch_size)
417
+ dt = (times[step + 1] - times[step]).expand(batch_size)
418
+ z = solver_step(
419
+ z,
420
+ t=t,
421
+ dt=dt,
422
+ input_sequence=input_sequence,
423
+ attn_mask=attn_mask,
424
+ pos_ids=pos_ids,
425
+ patch_size=patch_size,
426
+ g_cond=g_cond,
427
+ ).clone()
428
+ return z
429
+
430
+ @torch.no_grad()
431
+ def meanflow_solver_step(
432
+ self,
433
+ z: torch.Tensor,
434
+ *,
435
+ t: torch.Tensor,
436
+ dt: torch.Tensor,
437
+ input_sequence: torch.Tensor,
438
+ attn_mask: torch.Tensor,
439
+ pos_ids: torch.Tensor | None,
440
+ patch_size: int,
441
+ g_cond: torch.Tensor | None,
442
+ ) -> torch.Tensor:
443
+ if input_sequence.size(1) < patch_size:
444
+ raise ValueError(
445
+ "MeanFlow input sequence must reserve at least one latent patch slot."
446
+ )
447
+ latent_start = input_sequence.size(1) - patch_size
448
+ z_proj = self.coordinate_proj(z)
449
+ z_c = input_sequence.clone()
450
+ z_c[:, latent_start:] = z_proj
451
+ vt = self.velocity_field_predictor(
452
+ x=z_c,
453
+ timesteps=t,
454
+ duration=dt,
455
+ attn_mask=attn_mask,
456
+ pos_ids=pos_ids,
457
+ g_cond=g_cond,
458
+ )
459
+ velocity = vt[:, latent_start:]
460
+ return z + velocity * dt.view(-1, 1, 1)
461
+
462
+ @torch.no_grad()
463
+ def _flow_matching_step_fm(
464
+ self,
465
+ *,
466
+ input_sequence: torch.Tensor,
467
+ cfg_sequence: torch.Tensor,
468
+ attn_mask: torch.Tensor,
469
+ pos_ids: torch.Tensor | None,
470
+ hidden_size: int,
471
+ patch_size: int,
472
+ g_cond: torch.Tensor | None = None,
473
+ ode_method: str = "euler",
474
+ num_steps: int = 10,
475
+ guidance_scale: float = 3.0,
476
+ solver_step: Callable[..., torch.Tensor] | None = None,
477
+ ) -> torch.Tensor:
478
+ batch_size = input_sequence.size(0)
479
+ num_evals = 0
480
+ solver_step = self.fm_solver_step if solver_step is None else solver_step
481
+ guidance_scale_tensor = input_sequence.new_tensor(float(guidance_scale))
482
+
483
+ # Prepare ODE solver
484
+ def solver(t, z):
485
+ nonlocal num_evals
486
+ num_evals += 1
487
+ return solver_step(
488
+ t,
489
+ z,
490
+ input_sequence=input_sequence,
491
+ cfg_sequence=cfg_sequence,
492
+ attn_mask=attn_mask,
493
+ pos_ids=pos_ids,
494
+ hidden_size=hidden_size,
495
+ patch_size=patch_size,
496
+ g_cond=g_cond,
497
+ guidance_scale=guidance_scale_tensor,
498
+ )
499
+
500
+ # Prepare noise as initial coordinate
501
+ noise = torch.randn(
502
+ (batch_size, patch_size, self.latent_dim),
503
+ dtype=input_sequence.dtype,
504
+ device=input_sequence.device,
505
+ )
506
+ # Solve
507
+ times = torch.tensor(
508
+ [0.0, 1.0], dtype=input_sequence.dtype, device=input_sequence.device
509
+ )
510
+ if ode_method in ["euler", "midpoint", "rk4"]: # fixed step size methods
511
+ options = {"step_size": 1.0 / num_steps}
512
+ else:
513
+ logger.warning(
514
+ "Using adaptive step size ODE solver for FM, NFE is not guaranteed: "
515
+ "ode_method={}",
516
+ ode_method,
517
+ )
518
+ options = {}
519
+ trajectory = odeint(
520
+ func=solver,
521
+ y0=noise,
522
+ t=times,
523
+ atol=1e-5,
524
+ rtol=1e-5,
525
+ method=ode_method,
526
+ options=options,
527
+ )
528
+ # print(f"Expected NFE: {num_steps}, Actual NFE: {num_evals}")
529
+ return trajectory[-1]
530
+
531
+ @torch.no_grad()
532
+ def step_fm(
533
+ self,
534
+ input_sequence: torch.Tensor,
535
+ cfg_sequence: torch.Tensor,
536
+ attn_mask: torch.Tensor,
537
+ pos_ids: torch.Tensor | None,
538
+ hidden_size: int,
539
+ patch_size: int,
540
+ g_cond: torch.Tensor | None = None,
541
+ ode_method: str = "euler",
542
+ num_steps: int = 10,
543
+ guidance_scale: float = 3.0,
544
+ solver_step: Callable[..., torch.Tensor] | None = None,
545
+ ) -> torch.Tensor:
546
+ if self.mode == "meanflow":
547
+ return self._meanflow_step_fm(
548
+ input_sequence=input_sequence,
549
+ attn_mask=attn_mask,
550
+ pos_ids=pos_ids,
551
+ patch_size=patch_size,
552
+ g_cond=g_cond,
553
+ nfe=num_steps,
554
+ solver_step=solver_step,
555
+ )
556
+
557
+ return self._flow_matching_step_fm(
558
+ input_sequence=input_sequence,
559
+ cfg_sequence=cfg_sequence,
560
+ attn_mask=attn_mask,
561
+ pos_ids=pos_ids,
562
+ hidden_size=hidden_size,
563
+ patch_size=patch_size,
564
+ g_cond=g_cond,
565
+ ode_method=ode_method,
566
+ num_steps=num_steps,
567
+ guidance_scale=guidance_scale,
568
+ solver_step=solver_step,
569
+ )
570
+ # endregion Autoregressive and flow-matching inference steps
571
+
572
+
573
+ class FlowMatchingHelper:
574
+ """
575
+ Base helper for computing x_t and u_t, given target x_1 and noise x_0
576
+ ref: Flow matching for generative modeling, Lipman
577
+ """
578
+
579
+ def __init__(self, sigma=1e-5):
580
+ self.sigma = sigma
581
+
582
+ def compute_mu_t(self, x1, t):
583
+ return t * x1
584
+
585
+ def compute_sigma_t(self, t):
586
+ return 1 - (1 - self.sigma) * t
587
+
588
+ def sample_x_t(self, x0, x1, t):
589
+ mu_t = self.compute_mu_t(x1, t)
590
+ sigma_t = self.compute_sigma_t(t)
591
+ return mu_t + sigma_t * x0
592
+
593
+ def compute_u_t(self, x0, x1):
594
+ return x1 - (1 - self.sigma) * x0
595
+
596
+ def compute_xt_ut(self, x1, t=None, x0=None):
597
+ if x0 is None:
598
+ x0 = torch.randn_like(x1, device=x1.device)
599
+ if t is None:
600
+ t = torch.rand(x1.size(0), dtype=x1.dtype, device=x1.device)
601
+ times = t
602
+ t = t.reshape(-1, *([1] * (x1.dim() - 1)))
603
+ xt = self.sample_x_t(x0, x1, t)
604
+ ut = self.compute_u_t(x0, x1)
605
+ return xt, ut, times
606
+
607
+
608
+ class CausalHelper:
609
+ def create_causal_mask_and_pos(self, seq_lens, max_len):
610
+ seq_mask = get_mask_from_lengths(seq_lens, max_len=max_len).unsqueeze(1)
611
+ causal_mask = (
612
+ torch.ones((max_len, max_len), device=seq_lens.device).triu(1).bool()
613
+ )
614
+ causal_mask = ~causal_mask.unsqueeze(0)
615
+ attn_mask = seq_mask & causal_mask
616
+ return attn_mask, seq_mask.squeeze(1), None
617
+
618
+ def create_causal_chunk_mask_and_pos(
619
+ self,
620
+ batch_size,
621
+ C_lens,
622
+ Z_lens,
623
+ span_mask,
624
+ patch_size=8,
625
+ ):
626
+ device = C_lens.device
627
+ total_lens = C_lens + Z_lens
628
+ attn_mask = torch.zeros(
629
+ (batch_size, total_lens.max(), total_lens.max()),
630
+ device=device,
631
+ dtype=torch.bool,
632
+ )
633
+ pos_ids = []
634
+ # | C2C | |
635
+ # | Z2C | Z2Z |
636
+ for i in range(batch_size):
637
+ C_len = C_lens[i]
638
+ Z_len = Z_lens[i]
639
+
640
+ # C2C parts are standard causal attention
641
+ attn_mask[i, :C_len, :C_len] = (
642
+ torch.ones((C_len, C_len), device=device, dtype=torch.bool)
643
+ .triu(1)
644
+ .logical_not()
645
+ )
646
+ # Position ids in C parts are 0, 1, 2, ..., n
647
+ c_pos = torch.arange(C_len, device=device, dtype=torch.float32)
648
+
649
+ # Z2Z parts are block diag attention
650
+ assert Z_len % patch_size == 0, "Z_len must be multiple of patch_size"
651
+ attn_mask[i, C_len : C_len + Z_len, C_len : C_len + Z_len] = (
652
+ torch.block_diag(
653
+ *[
654
+ torch.ones(
655
+ (patch_size, patch_size), device=device, dtype=torch.bool
656
+ )
657
+ ]
658
+ * (Z_len // patch_size)
659
+ )
660
+ )
661
+
662
+ # Z2C parts is full attention before current patch latents
663
+ # build according to span_mask
664
+ j_indices = torch.arange(Z_len, device=device)
665
+ patch_indices = j_indices // patch_size
666
+ patch_in_c_indices = torch.where(span_mask[i])[0][patch_indices]
667
+ attn_mask[
668
+ i,
669
+ C_len + j_indices.unsqueeze(1),
670
+ torch.arange(C_len, device=device).unsqueeze(0),
671
+ ] = torch.arange(C_len, device=device).unsqueeze(
672
+ 0
673
+ ) < patch_in_c_indices.unsqueeze(1)
674
+ # Position ids in Z parts start from current patch latents index in C parts
675
+ z_pos = (patch_in_c_indices + j_indices % patch_size).to(torch.float32)
676
+ pos_ids.append(torch.cat([c_pos, z_pos]))
677
+ seq_mask = get_mask_from_lengths(total_lens, max_len=total_lens.max().item())
678
+ pos_ids = pad_sequence(pos_ids, batch_first=True, padding_value=0.0).to(
679
+ C_lens.device
680
+ )
681
+ return attn_mask, seq_mask, pos_ids
682
+
683
+
684
+ class IOHelper:
685
+ def __init__(self, latent_stats_path=None):
686
+ if latent_stats_path is not None:
687
+ latent_stats = torch.load(latent_stats_path, weights_only=False)
688
+ self.global_mean = torch.as_tensor(latent_stats["mean"])
689
+ self.global_var = torch.as_tensor(latent_stats["var"])
690
+ else:
691
+ self.global_mean = None
692
+ self.global_var = None
693
+
694
+ def normalize(self, x):
695
+ if self.global_mean is not None and self.global_var is not None:
696
+ x = (x - self.global_mean.to(x.device)) / torch.sqrt(
697
+ self.global_var.to(x.device)
698
+ )
699
+ return x
700
+
701
+ def denormalize(self, x):
702
+ if self.global_mean is not None and self.global_var is not None:
703
+ x = x * torch.sqrt(self.global_var.to(x.device)) + self.global_mean.to(
704
+ x.device
705
+ )
706
+ return x
707
+
708
+ @staticmethod
709
+ def sample_from_latent(latent):
710
+ mean, log_std = latent.chunk(2, 1)
711
+ z = mean + torch.randn_like(mean) * torch.exp(log_std)
712
+ return z.transpose(1, 2)
713
+
714
+ @staticmethod
715
+ def prepare_inputs_for_dit(
716
+ hiddens,
717
+ hidden_lens,
718
+ latents,
719
+ latent_lens,
720
+ hidden_proj,
721
+ latent_proj,
722
+ noisy_proj,
723
+ span_mask,
724
+ hidden_patch_size,
725
+ latent_patch_size,
726
+ fm_helper,
727
+ cfg_droprate=-1,
728
+ ):
729
+ assert hidden_patch_size == 1, "Hidden patch size > 1 is not supported."
730
+
731
+ B, _, _, device = *hiddens.shape, hiddens.device
732
+
733
+ # Gather span hidden states for flow matching using span_mask
734
+ span_hidden_list = []
735
+ for b in range(B):
736
+ indices = span_mask[b].nonzero(as_tuple=False).squeeze(-1)
737
+ span_hidden_list.append(hiddens[b, indices, :])
738
+ hiddens = pad_sequence(span_hidden_list, batch_first=True, padding_value=0.0)
739
+ hidden_lens = torch.tensor(
740
+ [t.size(0) for t in span_hidden_list], device=device, dtype=torch.long
741
+ )
742
+
743
+ # Update span_mask to be all True for the new lengths
744
+ max_len = hiddens.size(1)
745
+ span_mask = torch.arange(max_len, device=device).expand(
746
+ B, max_len
747
+ ) < hidden_lens.unsqueeze(1)
748
+
749
+ # Prepare history latents
750
+ history_latents = latent_proj(latents)
751
+ fm_dim = history_latents.shape[-1]
752
+ assert (latent_patch_size * history_latents.size(1) % latents.size(1)) == 0
753
+ latent_history_patch_size = (
754
+ latent_patch_size * history_latents.size(1) // latents.size(1)
755
+ )
756
+
757
+ # Prepare llm hidden with cfg masking
758
+ cfg_mask = (
759
+ torch.empty((B,), dtype=torch.float, device=latents.device).uniform_(0, 1)
760
+ < cfg_droprate
761
+ )
762
+ hiddens = hidden_proj(mask_data(hiddens, cfg_mask))
763
+
764
+ # Prepare noise latents
765
+ xt, ut, times = fm_helper.compute_xt_ut(latents)
766
+ projected_noise = noisy_proj(xt)
767
+
768
+ # Initialize empty fm_seq
769
+ hist_chunk_size = hidden_patch_size + latent_history_patch_size
770
+ valid_patch_counts = latent_lens // latent_patch_size
771
+ fm_prefix_lengths = hidden_lens + valid_patch_counts * (
772
+ hist_chunk_size - hidden_patch_size
773
+ )
774
+ fm_gen_lengths = latent_lens + valid_patch_counts * hidden_patch_size
775
+ fm_gen_patch_size = hidden_patch_size + latent_patch_size
776
+ fm_seq_lengths = fm_prefix_lengths + fm_gen_lengths
777
+ fm_seq = torch.zeros(
778
+ (B, fm_seq_lengths.max().item(), fm_dim),
779
+ dtype=history_latents.dtype,
780
+ device=device,
781
+ )
782
+ fm_target = []
783
+ patch_context_lengths = []
784
+ history_latent_span_mask = torch.zeros(
785
+ (B, fm_seq_lengths.max().item()), dtype=torch.bool, device=device
786
+ ) # to mark start positions of each history latents
787
+
788
+ # Fill fm_seq
789
+ for b in range(B):
790
+ # Step 1: Interleave hiddens at span positions with patched_latents
791
+ interleaved = []
792
+ span_mask_b = span_mask[b, : hidden_lens[b]]
793
+ interleaved.append(
794
+ hiddens[b, : hidden_lens[b]][span_mask_b].reshape(
795
+ valid_patch_counts[b], hidden_patch_size, fm_dim
796
+ )
797
+ )
798
+ interleaved.append(
799
+ history_latents[
800
+ b, : valid_patch_counts[b] * latent_history_patch_size, :
801
+ ].reshape(valid_patch_counts[b], latent_history_patch_size, fm_dim)
802
+ )
803
+ interleaved = torch.cat(interleaved, dim=1)
804
+ interleaved = rearrange(
805
+ interleaved, "n h d -> (n h) d"
806
+ ) # [num_spans*hist_chunk_size, D]
807
+
808
+ # Step 2: Build mapping from input positions to fm positions
809
+ position_increment = torch.where(
810
+ span_mask_b, hist_chunk_size, 1
811
+ ) # span->hist_chunk_size, non-span->1
812
+ fm_seq_positions = (
813
+ torch.cumsum(position_increment, dim=0) - position_increment
814
+ )
815
+
816
+ # Step 3: Scatter non-span hiddens
817
+ non_span_mask = ~span_mask_b
818
+ non_span_indices = fm_seq_positions[non_span_mask] # [num_non_spans]
819
+ fm_seq[b, non_span_indices, :] = hiddens[b, : hidden_lens[b]][
820
+ non_span_mask, :
821
+ ]
822
+
823
+ # Step 4: Scatter interleaved span tokens
824
+ span_indices = fm_seq_positions[span_mask_b] # [num_spans]
825
+ span_indices_expanded = torch.stack(
826
+ [span_indices + i for i in range(hist_chunk_size)], dim=1
827
+ ) # [num_spans, hist_chunk_size]
828
+ span_indices_flat = span_indices_expanded.reshape(
829
+ -1
830
+ ) # [num_spans*hist_chunk_size]
831
+ fm_seq[b, span_indices_flat, :] = interleaved
832
+ history_latent_span_mask[b, span_indices] = True
833
+ patch_context_lengths.append(span_indices.clone())
834
+
835
+ # Step 5: Fill with noise latents at the end
836
+ noise_part = []
837
+ span_mask_b = span_mask[b, : hidden_lens[b]]
838
+ noise_part.append(
839
+ hiddens[b, : hidden_lens[b]][span_mask_b].reshape(
840
+ valid_patch_counts[b], hidden_patch_size, fm_dim
841
+ )
842
+ )
843
+ noise_part.append(
844
+ projected_noise[b, : latent_lens[b], :].reshape(
845
+ valid_patch_counts[b], latent_patch_size, fm_dim
846
+ )
847
+ )
848
+ noise_part = torch.cat(noise_part, dim=1)
849
+ noise_part = rearrange(noise_part, "n h d -> (n h) d")
850
+ noise_start = fm_seq_positions[-1] + position_increment[-1]
851
+ noise_end = noise_start + fm_gen_lengths[b]
852
+ fm_seq[b, noise_start:noise_end, :] = noise_part
853
+
854
+ # Step 6: prepare fm_target
855
+ ut_b = ut[b, : latent_lens[b], :]
856
+ fm_target.append(rearrange(ut_b, "(n p) d -> n p d", p=latent_patch_size))
857
+
858
+ # Construct fm_attn_mask and fm_pos_ids
859
+ fm_attn_mask, fm_seq_mask, fm_pos_ids = (
860
+ CausalHelper().create_causal_chunk_mask_and_pos(
861
+ batch_size=B,
862
+ C_lens=fm_prefix_lengths,
863
+ Z_lens=fm_gen_lengths,
864
+ span_mask=history_latent_span_mask,
865
+ patch_size=fm_gen_patch_size,
866
+ )
867
+ )
868
+ fm_prefix_lengths = fm_prefix_lengths.unsqueeze(1)
869
+ fm_gen_lengths = fm_gen_lengths.unsqueeze(1)
870
+ fm_target = torch.cat(fm_target, dim=0)
871
+ results = [
872
+ fm_seq,
873
+ fm_target,
874
+ fm_attn_mask,
875
+ fm_seq_mask,
876
+ fm_pos_ids,
877
+ times,
878
+ fm_prefix_lengths,
879
+ fm_gen_lengths,
880
+ fm_gen_patch_size,
881
+ ]
882
+ return tuple(results)
883
+
884
+ @staticmethod
885
+ def get_dit_outputs(
886
+ pred_v,
887
+ fm_prefix_lengths,
888
+ fm_gen_lengths,
889
+ fm_gen_patch_size,
890
+ latent_patch_size,
891
+ ):
892
+ B, P = fm_prefix_lengths.shape
893
+ fm_pred = []
894
+ for b in range(B):
895
+ p_offset = 0
896
+ for p in range(P):
897
+ latents_b = pred_v[
898
+ b,
899
+ p_offset + fm_prefix_lengths[b][p] : p_offset
900
+ + fm_prefix_lengths[b][p]
901
+ + fm_gen_lengths[b][p],
902
+ ]
903
+ latents_b = rearrange(
904
+ latents_b, "(n p) d -> n p d", p=fm_gen_patch_size
905
+ )
906
+ # extract only the latent parts
907
+ latents_b = latents_b[:, -latent_patch_size:, :]
908
+ fm_pred.append(latents_b)
909
+ p_offset += fm_prefix_lengths[b][p] + fm_gen_lengths[b][p]
910
+ return torch.cat(fm_pred, dim=0)
src/dots_tts/models/dots_tts/model.py ADDED
@@ -0,0 +1,1958 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import math
5
+ import os
6
+ import shutil
7
+ from dataclasses import dataclass
8
+ from functools import partial
9
+ from pathlib import Path
10
+ from typing import Any, Callable, Iterator
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from einops import rearrange
16
+ from loguru import logger
17
+ from safetensors.torch import load_file, save_file
18
+ from transformers import AutoTokenizer, Qwen2Config
19
+
20
+ from dots_tts.models.dots_tts.config import ModelConfig
21
+ from dots_tts.models.dots_tts.core import DotsTtsCore, DotsTtsForwardOutput
22
+ from dots_tts.modules.speaker.encoder import SpeakerXVectorFeatures
23
+ from dots_tts.modules.vocoder.bigvgan import AudioVAE
24
+ from dots_tts.training.losses import LossMasks, LossTerm, LossTerms
25
+ from dots_tts.utils.profiling import measure_inference
26
+ from dots_tts.utils.tokenizer import AUDIO_GEN_START_TOKEN, require_token_id
27
+ from dots_tts.utils.util import get_dtype
28
+
29
+
30
+ _AOTI_BACKENDS = {"aoti", "aot", "aotinductor", "aot_inductor"}
31
+
32
+
33
+ class _AotiMethodModule(nn.Module):
34
+ def __init__(self, owner: nn.Module, method_name: str):
35
+ super().__init__()
36
+ self.owner = owner
37
+ self.method_name = method_name
38
+
39
+ def forward(self, *args, **kwargs):
40
+ raw_method = getattr(type(self.owner), self.method_name, None)
41
+ if raw_method is None:
42
+ return getattr(self.owner, self.method_name)(*args, **kwargs)
43
+ raw_callable = getattr(raw_method, "__wrapped__", raw_method)
44
+ return raw_callable(self.owner, *args, **kwargs)
45
+
46
+
47
+ class _LazyAotiCompiledMethod:
48
+ def __init__(
49
+ self,
50
+ *,
51
+ key: str,
52
+ owner: nn.Module,
53
+ method_name: str,
54
+ signature: tuple[Any, ...] | None,
55
+ ):
56
+ self.key = key
57
+ self.owner = owner
58
+ self.method_name = method_name
59
+ self.signature = signature
60
+ self.compiled: Callable[..., Any] | None = None
61
+ self.fallback: Callable[..., Any] | None = None
62
+
63
+ def __call__(self, *args, **kwargs):
64
+ if self.compiled is not None:
65
+ return self.compiled(*args, **kwargs)
66
+ if self.fallback is not None:
67
+ return self.fallback(*args, **kwargs)
68
+
69
+ try:
70
+ import spaces # noqa: PLC0415
71
+
72
+ if not hasattr(spaces, "aoti_compile"):
73
+ raise RuntimeError("spaces.aoti_compile is not available.")
74
+ exported = torch.export.export(
75
+ _AotiMethodModule(self.owner, self.method_name).eval(),
76
+ args=args,
77
+ kwargs=kwargs,
78
+ )
79
+ self.compiled = spaces.aoti_compile(exported)
80
+ logger.info(
81
+ "AOTI compiled inference target: key={} method={} signature={}",
82
+ self.key,
83
+ self.method_name,
84
+ self.signature,
85
+ )
86
+ return self.compiled(*args, **kwargs)
87
+ except Exception:
88
+ if os.environ.get("DOTS_TTS_AOTI_ALLOW_EAGER_FALLBACK", "0") != "1":
89
+ raise
90
+ logger.exception(
91
+ "AOTI compile failed; falling back to eager method: key={} method={} signature={}",
92
+ self.key,
93
+ self.method_name,
94
+ self.signature,
95
+ )
96
+ self.fallback = getattr(self.owner, self.method_name)
97
+ return self.fallback(*args, **kwargs)
98
+
99
+
100
+ @dataclass
101
+ class _GenerateState:
102
+ llm_cache: Any | None = None
103
+ llm_hiddens: torch.Tensor | None = None
104
+ patch_encoder_state: Any | None = None
105
+ fm_seq_len: int = 0
106
+ fm_capacity: int = 0
107
+ fm_sequence: torch.Tensor | None = None
108
+ fm_cfg_sequence: torch.Tensor | None = None
109
+ fm_null_g_cond: torch.Tensor | None = None
110
+ end_flag: bool = False
111
+
112
+
113
+ @dataclass(frozen=True)
114
+ class _PromptConditioning:
115
+ prompt_patches: torch.Tensor | None = None
116
+ prompt_latents: torch.Tensor | None = None
117
+ g_cond: torch.Tensor | None = None
118
+
119
+
120
+ @dataclass(frozen=True)
121
+ class _GenerateLengthBucket:
122
+ size: int
123
+
124
+ def run_warmup(
125
+ self,
126
+ model: "DotsTtsModel",
127
+ *,
128
+ precision: str,
129
+ ode_method: str,
130
+ num_steps: int,
131
+ guidance_scale: float,
132
+ ) -> None:
133
+ model._warmup_fm_bucket(
134
+ max_audio_patch_count=self.size,
135
+ precision=precision,
136
+ ode_method=ode_method,
137
+ num_steps=num_steps,
138
+ guidance_scale=guidance_scale,
139
+ )
140
+ model._warmup_patch_encoder_bucket(
141
+ max_audio_patch_count=self.size,
142
+ precision=precision,
143
+ )
144
+ device = next(model.core.parameters()).device
145
+ generation_schedule = torch.full(
146
+ (1, self.size + 1),
147
+ fill_value=model.core.audio_gen_span_id,
148
+ dtype=torch.long,
149
+ device=device,
150
+ )
151
+ generation_schedule[0, 0] = model.audio_gen_start_id
152
+ warmup_inputs = {"generation_schedule": generation_schedule}
153
+
154
+ for _ in model.generate_audio_stream(
155
+ warmup_inputs,
156
+ precision=precision,
157
+ ode_method=ode_method,
158
+ num_steps=num_steps,
159
+ guidance_scale=guidance_scale,
160
+ ):
161
+ return
162
+ raise RuntimeError(
163
+ f"Warmup produced no audio chunk for generate bucket {self.size}."
164
+ )
165
+
166
+
167
+ class DotsTtsModel(nn.Module):
168
+ """Full train/infer model assembly around the dots.tts core network."""
169
+
170
+ _GENERATE_LENGTH_BUCKETS = (
171
+ _GenerateLengthBucket(32),
172
+ _GenerateLengthBucket(64),
173
+ _GenerateLengthBucket(128),
174
+ _GenerateLengthBucket(256),
175
+ _GenerateLengthBucket(512),
176
+ _GenerateLengthBucket(1024),
177
+ )
178
+ _COMPILE_TARGETS = frozenset(
179
+ {
180
+ "FM",
181
+ "patch_encoder",
182
+ "vocoder",
183
+ }
184
+ )
185
+ _optimize_enabled = True
186
+ CONFIG_FILENAME = "config.json"
187
+ HF_MODEL_TYPE = "dots_tts"
188
+ HF_ARCHITECTURES = ["DotsTTSForConditionalGeneration"]
189
+ LATENT_STATS_FILENAME = "latent_stats.pt"
190
+ LLM_CONFIG_FILENAME = "llm_config.json"
191
+ MODEL_FILENAME = "model.safetensors"
192
+ VOCODER_FILENAME = "vocoder.safetensors"
193
+ SPEAKER_ENCODER_FILENAME = "speaker_encoder.safetensors"
194
+ _ARTIFACT_ALIASES = (("llm.lm_head.weight", "llm.model.embed_tokens.weight"),)
195
+ REQUIRED_ARTIFACT_FILES = (
196
+ CONFIG_FILENAME,
197
+ LATENT_STATS_FILENAME,
198
+ LLM_CONFIG_FILENAME,
199
+ MODEL_FILENAME,
200
+ VOCODER_FILENAME,
201
+ SPEAKER_ENCODER_FILENAME,
202
+ )
203
+
204
+ # region Module assembly and checkpoint IO
205
+ def __init__(
206
+ self,
207
+ config: ModelConfig,
208
+ tokenizer,
209
+ latent_stats_path: str | Path,
210
+ llm_config: Qwen2Config,
211
+ ):
212
+ super().__init__()
213
+ self.config = config
214
+ self.tokenizer = tokenizer
215
+ self.latent_stats_path = Path(latent_stats_path)
216
+ self.audio_gen_start_id = require_token_id(
217
+ self.tokenizer, AUDIO_GEN_START_TOKEN
218
+ )
219
+
220
+ self.core = DotsTtsCore(
221
+ config,
222
+ llm_config=llm_config,
223
+ tokenizer=tokenizer,
224
+ latent_stats_path=self.latent_stats_path,
225
+ )
226
+ self.vocoder = AudioVAE(config.vocoder).eval()
227
+ self.vocoder.remove_weight_norm()
228
+ self.hop_size = self.vocoder.hop_size
229
+ self.xvector_extractor = SpeakerXVectorFeatures(
230
+ sample_rate=self.vocoder.sample_rate,
231
+ campplus_embedding_size=config.campplus_embedding_size,
232
+ max_audio_seconds=config.xvec_max_audio_seconds,
233
+ ).eval()
234
+
235
+ for param in self.vocoder.parameters():
236
+ param.requires_grad = False
237
+ for param in self.xvector_extractor.parameters():
238
+ param.requires_grad = False
239
+ self._optimize_enabled = True
240
+ self._compiled_models: dict[
241
+ tuple[str, tuple[Any, ...] | None], Callable[..., Any]
242
+ ] = {}
243
+ self._compile_backend = os.environ.get(
244
+ "DOTS_TTS_COMPILE_BACKEND",
245
+ "torch_compile",
246
+ ).strip().lower()
247
+ self._static_generate_workspaces: dict[tuple[Any, ...], dict[str, Any]] = {}
248
+ self._fm_decode_workspaces: dict[tuple[Any, ...], dict[str, torch.Tensor]] = {}
249
+
250
+ def set_optimize(self, optimize: bool) -> None:
251
+ self._optimize_enabled = bool(optimize)
252
+ if not self._optimize_enabled:
253
+ self._compiled_models.clear()
254
+
255
+ def set_compile_backend(self, backend: str) -> None:
256
+ normalized_backend = (backend or "torch_compile").strip().lower()
257
+ if normalized_backend != self._compile_backend:
258
+ self._compiled_models.clear()
259
+ self._compile_backend = normalized_backend
260
+
261
+ def export_compiled_models(
262
+ self,
263
+ ) -> dict[tuple[str, tuple[Any, ...] | None], Callable[..., Any]]:
264
+ exported: dict[tuple[str, tuple[Any, ...] | None], Callable[..., Any]] = {}
265
+ for cache_key, compiled in self._compiled_models.items():
266
+ if isinstance(compiled, _LazyAotiCompiledMethod):
267
+ if compiled.compiled is not None:
268
+ exported[cache_key] = compiled.compiled
269
+ continue
270
+ exported[cache_key] = compiled
271
+ return exported
272
+
273
+ def import_compiled_models(
274
+ self,
275
+ compiled_models: dict[tuple[str, tuple[Any, ...] | None], Callable[..., Any]],
276
+ ) -> None:
277
+ self._compiled_models.update(compiled_models)
278
+
279
+ def set_cfg_droprate(
280
+ self,
281
+ cfg_droprate: float | None = None,
282
+ xvec_drop_rate: float | None = None,
283
+ ) -> None:
284
+ if cfg_droprate is not None:
285
+ self.config.cfg_droprate = cfg_droprate
286
+ self.core.config.cfg_droprate = cfg_droprate
287
+ self.core.cfg_droprate = cfg_droprate
288
+
289
+ if xvec_drop_rate is not None:
290
+ self.config.xvec_drop_rate = xvec_drop_rate
291
+ self.core.config.xvec_drop_rate = xvec_drop_rate
292
+ self.core.xvec_drop_rate = xvec_drop_rate
293
+
294
+ @classmethod
295
+ def _resolve_generate_length_bucket(
296
+ cls,
297
+ max_generate_length: int,
298
+ ) -> _GenerateLengthBucket:
299
+ requested = int(max_generate_length)
300
+ if requested <= 0:
301
+ raise ValueError("max_generate_length must be positive.")
302
+ for bucket in cls._GENERATE_LENGTH_BUCKETS:
303
+ if requested <= bucket.size:
304
+ return bucket
305
+ raise ValueError(
306
+ "max_generate_length exceeds the largest supported compile bucket: "
307
+ f"max_generate_length={requested} "
308
+ f"max_supported={cls._GENERATE_LENGTH_BUCKETS[-1].size}."
309
+ )
310
+
311
+ @torch.no_grad()
312
+ def run_warmup(
313
+ self,
314
+ *,
315
+ max_generate_length: int,
316
+ precision: str = "bfloat16",
317
+ ode_method: str = "euler",
318
+ num_steps: int = 10,
319
+ guidance_scale: float = 1.2,
320
+ ) -> None:
321
+ ceiling_bucket = self._resolve_generate_length_bucket(max_generate_length)
322
+ warmup_buckets = tuple(
323
+ bucket
324
+ for bucket in self._GENERATE_LENGTH_BUCKETS
325
+ if bucket.size <= ceiling_bucket.size
326
+ )
327
+ bucket_sizes = [bucket.size for bucket in warmup_buckets]
328
+ logger.info(
329
+ "Inference warmup started: requested_max_generate_length={} bucket_sizes={}",
330
+ int(max_generate_length),
331
+ bucket_sizes,
332
+ )
333
+ for bucket in warmup_buckets:
334
+ bucket.run_warmup(
335
+ self,
336
+ precision=precision,
337
+ ode_method=ode_method,
338
+ num_steps=num_steps,
339
+ guidance_scale=guidance_scale,
340
+ )
341
+ logger.info(
342
+ "Inference warmup completed: requested_max_generate_length={} bucket_sizes={}",
343
+ int(max_generate_length),
344
+ bucket_sizes,
345
+ )
346
+
347
+ def _resolve_state_audio_patch_count(self, max_audio_patch_count: int) -> int:
348
+ requested = int(max_audio_patch_count)
349
+ if requested <= 0:
350
+ raise ValueError("max_audio_patch_count must be positive.")
351
+ if not self._optimize_enabled:
352
+ return requested
353
+ return self._resolve_generate_length_bucket(requested).size
354
+
355
+ def _warmup_fm_bucket(
356
+ self,
357
+ *,
358
+ max_audio_patch_count: int,
359
+ precision: str,
360
+ ode_method: str,
361
+ num_steps: int,
362
+ guidance_scale: float,
363
+ ) -> None:
364
+ dtype = get_dtype(precision)
365
+ device = next(self.core.parameters()).device
366
+ use_amp = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}
367
+ with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_amp):
368
+ state = self._allocate_generate_state(
369
+ max_audio_patch_count=max_audio_patch_count,
370
+ device=device,
371
+ dtype=dtype,
372
+ )
373
+ state.fm_seq_len = state.fm_capacity
374
+ self._decode_next_audio(
375
+ state,
376
+ device=device,
377
+ g_cond=None,
378
+ ode_method=ode_method,
379
+ num_steps=num_steps,
380
+ guidance_scale=guidance_scale,
381
+ )
382
+
383
+ def _warmup_patch_encoder_bucket(
384
+ self,
385
+ *,
386
+ max_audio_patch_count: int,
387
+ precision: str,
388
+ ) -> None:
389
+ dtype = get_dtype(precision)
390
+ device = next(self.core.parameters()).device
391
+ state_dtype = dtype if device.type == "cuda" else torch.float32
392
+ use_amp = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}
393
+ with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_amp):
394
+ state_audio_patch_count = self._resolve_state_audio_patch_count(
395
+ max_audio_patch_count
396
+ )
397
+ patch_encoder_state = self.core.patch_encoder.init_decode_state(
398
+ max_audio_patch_count=state_audio_patch_count,
399
+ batch_size=1,
400
+ device=device,
401
+ dtype=state_dtype,
402
+ )
403
+ audio_patch = torch.zeros(
404
+ (
405
+ 1,
406
+ self.core.patch_encoder.patch_size,
407
+ self.core.latent_dim,
408
+ ),
409
+ dtype=state_dtype,
410
+ device=device,
411
+ )
412
+ audio_patch = self.core.io_helper.denormalize(audio_patch)
413
+ patch_encoder_decode = self._get_compiled_method(
414
+ "patch_encoder.decode_patch",
415
+ self.core.patch_encoder,
416
+ "decode_patch",
417
+ signature=self._patch_encoder_compile_signature(patch_encoder_state),
418
+ )
419
+ positions = torch.arange(
420
+ self.core.patch_encoder.out_ds_rate,
421
+ device=device,
422
+ dtype=torch.long,
423
+ )
424
+ with measure_inference("patch_encoder"):
425
+ patch_encoder_decode(
426
+ audio_patch,
427
+ patch_encoder_state.conv_tail,
428
+ patch_encoder_state.layer_caches,
429
+ positions,
430
+ )
431
+
432
+ def _compile_callable(
433
+ self,
434
+ key: str,
435
+ model: Callable[..., Any],
436
+ *,
437
+ signature: tuple[Any, ...] | None = None,
438
+ ) -> Callable[..., Any]:
439
+ compile_target = key.split(".", maxsplit=1)[0]
440
+ cache_key = (key, signature)
441
+ compiled = self._compiled_models.get(cache_key)
442
+ if compiled is None:
443
+ mode = (
444
+ "default"
445
+ if key == "patch_encoder.decode_patch"
446
+ else "reduce-overhead"
447
+ )
448
+ compiled = torch.compile(
449
+ model,
450
+ mode=mode,
451
+ fullgraph=True,
452
+ dynamic=False,
453
+ )
454
+ self._compiled_models[cache_key] = compiled
455
+ logger.info(
456
+ "Compiled inference target: key={} target={} signature={}",
457
+ key,
458
+ compile_target,
459
+ signature,
460
+ )
461
+ return compiled
462
+
463
+ def _get_compiled_model(
464
+ self,
465
+ key: str,
466
+ model: Callable[..., Any],
467
+ *,
468
+ signature: tuple[Any, ...] | None = None,
469
+ ) -> Callable[..., Any]:
470
+ compile_target = key.split(".", maxsplit=1)[0]
471
+ if not self._optimize_enabled or compile_target not in self._COMPILE_TARGETS:
472
+ return model
473
+ return self._compile_callable(
474
+ key,
475
+ model,
476
+ signature=signature,
477
+ )
478
+
479
+ def _get_compiled_method(
480
+ self,
481
+ key: str,
482
+ owner: Any,
483
+ method_name: str,
484
+ *,
485
+ signature: tuple[Any, ...] | None = None,
486
+ ) -> Callable[..., Any]:
487
+ bound_method = getattr(owner, method_name)
488
+ compile_target = key.split(".", maxsplit=1)[0]
489
+ if not self._optimize_enabled or compile_target not in self._COMPILE_TARGETS:
490
+ return bound_method
491
+
492
+ cache_key = (key, signature)
493
+ if self._compile_backend in _AOTI_BACKENDS:
494
+ compiled = self._compiled_models.get(cache_key)
495
+ if compiled is None:
496
+ compiled = _LazyAotiCompiledMethod(
497
+ key=key,
498
+ owner=owner,
499
+ method_name=method_name,
500
+ signature=signature,
501
+ )
502
+ self._compiled_models[cache_key] = compiled
503
+ return compiled
504
+
505
+ raw_method = getattr(type(owner), method_name)
506
+ raw_callable = getattr(raw_method, "__wrapped__", raw_method)
507
+ compiled = self._compile_callable(
508
+ key,
509
+ raw_callable,
510
+ signature=signature,
511
+ )
512
+ return partial(compiled, owner)
513
+
514
+ def _allocate_generate_state(
515
+ self,
516
+ *,
517
+ max_audio_patch_count: int,
518
+ device: torch.device,
519
+ dtype: torch.dtype,
520
+ ) -> _GenerateState:
521
+ state_dtype = dtype if device.type == "cuda" else torch.float32
522
+ state_audio_patch_count = self._resolve_state_audio_patch_count(
523
+ max_audio_patch_count
524
+ )
525
+ fm_capacity = state_audio_patch_count * (
526
+ self.core.hidden_patch_size + self.core.latent_patch_size
527
+ )
528
+ workspace_key = (
529
+ state_audio_patch_count,
530
+ str(device),
531
+ state_dtype,
532
+ )
533
+ workspace = self._static_generate_workspaces.get(workspace_key)
534
+ if workspace is None:
535
+ workspace = {
536
+ "fm_sequence": torch.zeros(
537
+ (1, fm_capacity, self.core.fm_hidden_size),
538
+ dtype=state_dtype,
539
+ device=device,
540
+ ),
541
+ "fm_cfg_sequence": torch.zeros(
542
+ (1, fm_capacity, self.core.fm_hidden_size),
543
+ dtype=state_dtype,
544
+ device=device,
545
+ ),
546
+ "fm_null_g_cond": torch.zeros(
547
+ (1, self.core.fm_hidden_size),
548
+ dtype=state_dtype,
549
+ device=device,
550
+ ),
551
+ }
552
+ self._static_generate_workspaces[workspace_key] = workspace
553
+ else:
554
+ workspace["fm_sequence"].zero_()
555
+ workspace["fm_cfg_sequence"].zero_()
556
+
557
+ patch_encoder_state = None
558
+ if not self._optimize_enabled:
559
+ patch_encoder_state = self.core.patch_encoder.init_decode_state(
560
+ max_audio_patch_count=state_audio_patch_count,
561
+ batch_size=1,
562
+ device=device,
563
+ dtype=state_dtype,
564
+ )
565
+
566
+ return _GenerateState(
567
+ patch_encoder_state=patch_encoder_state,
568
+ fm_seq_len=0,
569
+ fm_capacity=fm_capacity,
570
+ fm_sequence=workspace["fm_sequence"],
571
+ fm_cfg_sequence=workspace["fm_cfg_sequence"],
572
+ fm_null_g_cond=workspace["fm_null_g_cond"],
573
+ )
574
+
575
+ @staticmethod
576
+ def _tensor_storage_signature(tensor: torch.Tensor) -> tuple:
577
+ return (
578
+ tensor.untyped_storage().data_ptr(),
579
+ tensor.storage_offset(),
580
+ tuple(tensor.size()),
581
+ tuple(tensor.stride()),
582
+ tensor.dtype,
583
+ )
584
+
585
+ @classmethod
586
+ def _build_artifact_state_dict(cls, module) -> dict[str, torch.Tensor]:
587
+ state_dict = module.state_dict()
588
+ skip_keys = set()
589
+
590
+ for redundant_key, canonical_key in cls._ARTIFACT_ALIASES:
591
+ redundant_tensor = state_dict.get(redundant_key)
592
+ canonical_tensor = state_dict.get(canonical_key)
593
+ if (
594
+ redundant_tensor is not None
595
+ and canonical_tensor is not None
596
+ and cls._tensor_storage_signature(redundant_tensor)
597
+ == cls._tensor_storage_signature(canonical_tensor)
598
+ ):
599
+ skip_keys.add(redundant_key)
600
+
601
+ cleaned_state_dict = {}
602
+ seen_storage = set()
603
+ for key, value in state_dict.items():
604
+ if key in skip_keys:
605
+ continue
606
+
607
+ storage_signature = cls._tensor_storage_signature(value)
608
+ if storage_signature in seen_storage:
609
+ continue
610
+
611
+ seen_storage.add(storage_signature)
612
+ cleaned_state_dict[key] = value.detach().cpu().contiguous()
613
+
614
+ return cleaned_state_dict
615
+
616
+ @classmethod
617
+ def _restore_artifact_state_dict(cls, state_dict: dict, module) -> dict:
618
+ restored_state_dict = dict(state_dict)
619
+ for redundant_key, canonical_key in cls._ARTIFACT_ALIASES:
620
+ if (
621
+ canonical_key in restored_state_dict
622
+ and redundant_key not in restored_state_dict
623
+ and redundant_key in module.state_dict()
624
+ ):
625
+ restored_state_dict[redundant_key] = restored_state_dict[canonical_key]
626
+ return restored_state_dict
627
+
628
+ @classmethod
629
+ def _save_artifact_module(cls, module, path: Path) -> None:
630
+ save_file(cls._build_artifact_state_dict(module), path)
631
+
632
+ @classmethod
633
+ def _load_artifact_module(cls, module, path: Path):
634
+ state_dict = load_file(path, device="cpu")
635
+ restored_state_dict = cls._restore_artifact_state_dict(state_dict, module)
636
+ mismatch = module.load_state_dict(restored_state_dict, strict=False)
637
+ if mismatch.missing_keys or mismatch.unexpected_keys:
638
+ raise RuntimeError(f"Failed to load {path}: {mismatch}")
639
+ return module
640
+
641
+ @classmethod
642
+ def _validate_pretrained_directory(
643
+ cls, pretrained_model_name_or_path: str | Path
644
+ ) -> Path:
645
+ pretrained_path = Path(pretrained_model_name_or_path).expanduser().resolve()
646
+ missing_files = [
647
+ name
648
+ for name in cls.REQUIRED_ARTIFACT_FILES
649
+ if not (pretrained_path / name).is_file()
650
+ ]
651
+ if missing_files:
652
+ raise FileNotFoundError(
653
+ f"Pretrained path {pretrained_path} is missing required files: {missing_files}"
654
+ )
655
+ return pretrained_path
656
+
657
+ @classmethod
658
+ def _load_pretrained_config(cls, pretrained_path: Path) -> ModelConfig:
659
+ return ModelConfig.model_validate(
660
+ json.loads(
661
+ (pretrained_path / cls.CONFIG_FILENAME).read_text(encoding="utf-8")
662
+ )
663
+ )
664
+
665
+ @staticmethod
666
+ def _save_llm_config(llm_config: Qwen2Config, path: Path) -> None:
667
+ path.write_text(
668
+ json.dumps(llm_config.to_dict(), ensure_ascii=True, indent=2),
669
+ encoding="utf-8",
670
+ )
671
+
672
+ @staticmethod
673
+ def _load_llm_config(path: Path) -> Qwen2Config:
674
+ return Qwen2Config.from_dict(json.loads(path.read_text(encoding="utf-8")))
675
+
676
+ def _tie_llm_weights(self) -> None:
677
+ if hasattr(self.core.llm, "tie_weights"):
678
+ self.core.llm.tie_weights()
679
+
680
+ def save_pretrained(self, save_directory: str | Path) -> Path:
681
+ save_directory = Path(save_directory)
682
+ save_directory.mkdir(parents=True, exist_ok=True)
683
+
684
+ config_payload = self.config.to_declared_dict()
685
+ config_payload["model_type"] = self.HF_MODEL_TYPE
686
+ config_payload["architectures"] = list(self.HF_ARCHITECTURES)
687
+ (save_directory / self.CONFIG_FILENAME).write_text(
688
+ json.dumps(config_payload, ensure_ascii=True, indent=2),
689
+ encoding="utf-8",
690
+ )
691
+ self._save_llm_config(
692
+ self.core.llm.config,
693
+ save_directory / self.LLM_CONFIG_FILENAME,
694
+ )
695
+ self.tokenizer.save_pretrained(save_directory)
696
+ shutil.copy2(
697
+ self.latent_stats_path,
698
+ save_directory / self.LATENT_STATS_FILENAME,
699
+ )
700
+ self._save_artifact_module(self.core, save_directory / self.MODEL_FILENAME)
701
+ self._save_artifact_module(self.vocoder, save_directory / self.VOCODER_FILENAME)
702
+ self._save_artifact_module(
703
+ self.xvector_extractor,
704
+ save_directory / self.SPEAKER_ENCODER_FILENAME,
705
+ )
706
+ return save_directory
707
+
708
+ def _load_pretrained_artifacts(self, pretrained_path: Path) -> None:
709
+ self.latent_stats_path = pretrained_path / self.LATENT_STATS_FILENAME
710
+ self.core.io_helper = type(self.core.io_helper)(
711
+ latent_stats_path=self.latent_stats_path
712
+ )
713
+ self._load_artifact_module(self.core, pretrained_path / self.MODEL_FILENAME)
714
+ self._tie_llm_weights()
715
+ self._load_artifact_module(
716
+ self.vocoder, pretrained_path / self.VOCODER_FILENAME
717
+ )
718
+ self._load_artifact_module(
719
+ self.xvector_extractor,
720
+ pretrained_path / self.SPEAKER_ENCODER_FILENAME,
721
+ )
722
+ self.core.eval()
723
+ self.vocoder.eval()
724
+ self.xvector_extractor.eval()
725
+
726
+ def load_pretrained_weights(
727
+ self, pretrained_model_name_or_path: str | Path
728
+ ) -> None:
729
+ pretrained_path = self._validate_pretrained_directory(
730
+ pretrained_model_name_or_path
731
+ )
732
+ saved_config = self._load_pretrained_config(pretrained_path)
733
+ if saved_config.to_declared_dict() != self.config.to_declared_dict():
734
+ raise ValueError(
735
+ f"Pretrained config at {pretrained_path} does not match the current model."
736
+ )
737
+ saved_llm_config = self._load_llm_config(
738
+ pretrained_path / self.LLM_CONFIG_FILENAME
739
+ )
740
+ if saved_llm_config.to_dict() != self.core.llm.config.to_dict():
741
+ raise ValueError(
742
+ f"Pretrained LLM config at {pretrained_path} does not match the current model."
743
+ )
744
+ self._load_pretrained_artifacts(pretrained_path)
745
+
746
+ @classmethod
747
+ def from_pretrained(cls, pretrained_model_name_or_path: str | Path):
748
+ logger.info(
749
+ "DotsTtsModel load started: pretrained_path={}",
750
+ pretrained_model_name_or_path,
751
+ )
752
+ pretrained_model_name_or_path = cls._validate_pretrained_directory(
753
+ pretrained_model_name_or_path
754
+ )
755
+ config = cls._load_pretrained_config(pretrained_model_name_or_path)
756
+ llm_config = cls._load_llm_config(
757
+ pretrained_model_name_or_path / cls.LLM_CONFIG_FILENAME
758
+ )
759
+ logger.info(
760
+ "DotsTtsModel config loaded: pretrained_path={} sample_rate={} patch_size={}",
761
+ pretrained_model_name_or_path,
762
+ config.vocoder.sample_rate,
763
+ config.patch_size,
764
+ )
765
+ tokenizer = AutoTokenizer.from_pretrained(
766
+ str(pretrained_model_name_or_path),
767
+ local_files_only=True,
768
+ )
769
+ model = cls(
770
+ config,
771
+ tokenizer=tokenizer,
772
+ latent_stats_path=pretrained_model_name_or_path / cls.LATENT_STATS_FILENAME,
773
+ llm_config=llm_config,
774
+ )
775
+ model._load_pretrained_artifacts(pretrained_model_name_or_path)
776
+ logger.info(
777
+ "DotsTtsModel load completed: pretrained_path={}",
778
+ pretrained_model_name_or_path,
779
+ )
780
+ return model.eval()
781
+
782
+ # endregion Module assembly and checkpoint IO
783
+
784
+ # region Training batch preparation
785
+ @torch.no_grad()
786
+ def prepare_training_inputs(self, data: dict[str, Any]) -> dict[str, Any]:
787
+ self.vocoder.eval()
788
+ self.xvector_extractor.eval()
789
+ processed = dict(data)
790
+ sample: torch.Tensor | None = data.get("sample")
791
+ sample_lengths: torch.Tensor | None = data.get("sample_lengths")
792
+
793
+ if sample is not None:
794
+ latents = self.vocoder.extract_latents(sample)
795
+ processed["latents"] = latents
796
+ if sample_lengths is not None:
797
+ processed["latent_lengths"] = sample_lengths // self.hop_size
798
+ else:
799
+ processed["latent_lengths"] = torch.full(
800
+ (latents.size(0),),
801
+ latents.size(-1),
802
+ dtype=torch.long,
803
+ device=latents.device,
804
+ )
805
+ processed["latents_sampled"] = self.core.io_helper.sample_from_latent(
806
+ latents
807
+ )
808
+ fbank = data.get("fbank")
809
+ fbank_lengths = data.get("fbank_lengths")
810
+ processed["xvector"] = self.xvector_extractor(
811
+ sample,
812
+ audio_lengths=sample_lengths,
813
+ fbank=fbank,
814
+ fbank_lengths=fbank_lengths,
815
+ )
816
+ else:
817
+ processed["latents"] = None
818
+ processed["latent_lengths"] = None
819
+
820
+ return processed
821
+
822
+ def _build_audio_span_mask(self, token_ids: torch.Tensor) -> torch.Tensor:
823
+ span_mask = torch.zeros_like(token_ids, dtype=torch.bool)
824
+ for token_id in self.core.audio_span_token_ids:
825
+ span_mask = span_mask | (token_ids == token_id)
826
+ return span_mask
827
+
828
+ def _prepare_loss_metadata(self, data: dict[str, Any]) -> dict[str, Any]:
829
+ input_ids: torch.Tensor = data["input_ids"]
830
+ labels: torch.Tensor = data["labels"]
831
+ loss_mask: torch.Tensor = data["loss_mask"]
832
+ input_span_mask = self._build_audio_span_mask(input_ids)
833
+ output_span_mask = self._build_audio_span_mask(labels)
834
+ output_span_mask_float = output_span_mask.to(loss_mask.dtype)
835
+ llm_loss_mask = loss_mask * (1.0 - output_span_mask_float)
836
+ fm_loss_mask = loss_mask * output_span_mask_float
837
+ patch_counts = output_span_mask.sum(dim=1)
838
+ max_patch_count = max(1, int(patch_counts.max().item()))
839
+ fm_patch_mask = loss_mask.new_zeros((loss_mask.size(0), max_patch_count))
840
+ for batch_idx in range(loss_mask.size(0)):
841
+ count = int(patch_counts[batch_idx].item())
842
+ if count <= 0:
843
+ continue
844
+ fm_patch_mask[batch_idx, :count] = fm_loss_mask[batch_idx].masked_select(
845
+ output_span_mask[batch_idx]
846
+ )
847
+
848
+ return {
849
+ "input_span_mask": input_span_mask,
850
+ "output_span_mask": output_span_mask,
851
+ "loss_masks": {
852
+ "ce_loss": llm_loss_mask,
853
+ "fm_loss": fm_patch_mask,
854
+ "eos_loss": self._build_eos_loss_mask(fm_loss_mask),
855
+ },
856
+ }
857
+
858
+ @staticmethod
859
+ def _build_eos_loss_mask(eos_loss_mask: torch.Tensor) -> torch.Tensor:
860
+ batch_size, seq_len = eos_loss_mask.shape
861
+ mask = eos_loss_mask.to(dtype=torch.bool)
862
+ target = torch.zeros((batch_size, seq_len), dtype=torch.bool, device=mask.device)
863
+ mask_counts = mask.sum(dim=1, keepdim=True)
864
+ cumulative = mask.long().cumsum(dim=1)
865
+ target[mask & (cumulative == mask_counts)] = True
866
+
867
+ mask_counts_flat = mask_counts.squeeze(1)
868
+ neg_counts = (mask_counts_flat - 1).clamp_min(0).to(eos_loss_mask.dtype)
869
+ pos_weight = torch.where(
870
+ neg_counts > 0,
871
+ torch.full_like(neg_counts, 0.5),
872
+ torch.ones_like(neg_counts),
873
+ ).unsqueeze(1)
874
+ neg_weight = torch.where(
875
+ neg_counts > 0,
876
+ 0.5 / neg_counts,
877
+ torch.zeros_like(neg_counts),
878
+ ).unsqueeze(1)
879
+
880
+ positive_mask = target & mask
881
+ negative_mask = mask & ~positive_mask
882
+ return torch.where(
883
+ positive_mask,
884
+ pos_weight,
885
+ negative_mask.to(eos_loss_mask.dtype) * neg_weight,
886
+ )
887
+ # endregion Training batch preparation
888
+
889
+ # region Training loss assembly and forward
890
+ @staticmethod
891
+ def _compute_ce_loss_term(
892
+ llm_logits: torch.Tensor,
893
+ llm_labels: torch.Tensor,
894
+ llm_loss_mask: torch.Tensor,
895
+ ) -> LossTerm:
896
+ vocab_size = llm_logits.size(-1)
897
+ ce_loss = F.cross_entropy(
898
+ llm_logits.view(-1, vocab_size),
899
+ llm_labels.view(-1),
900
+ reduction="none",
901
+ ).view_as(llm_labels)
902
+ return LossTerm(loss=ce_loss, mask=llm_loss_mask.to(ce_loss.dtype))
903
+
904
+ @staticmethod
905
+ def _compute_fm_loss_term(
906
+ pred: torch.Tensor,
907
+ target: torch.Tensor,
908
+ fm_patch_mask: torch.Tensor,
909
+ ) -> LossTerm:
910
+ batch_size, max_patch_count = fm_patch_mask.shape
911
+ fm_loss = (pred - target).pow(2).mean(dim=2).mean(dim=1)
912
+ loss = fm_loss.new_zeros((batch_size, max_patch_count))
913
+ patch_counts = fm_patch_mask.gt(0).sum(dim=1).tolist()
914
+ expected_count = int(sum(patch_counts))
915
+ if expected_count > 0 and int(fm_loss.numel()) != expected_count:
916
+ raise RuntimeError(
917
+ "Flow-matching loss count mismatch: "
918
+ f"expected {expected_count}, got {int(fm_loss.numel())}."
919
+ )
920
+
921
+ offset = 0
922
+ for batch_idx, patch_count in enumerate(patch_counts):
923
+ if patch_count <= 0:
924
+ continue
925
+ next_offset = offset + int(patch_count)
926
+ loss[batch_idx, :patch_count] = fm_loss[offset:next_offset]
927
+ offset = next_offset
928
+ return LossTerm(loss=loss, mask=fm_patch_mask.to(loss.dtype))
929
+
930
+ @staticmethod
931
+ def _compute_eos_loss_term(
932
+ eos_out: torch.Tensor,
933
+ eos_loss_mask: torch.Tensor,
934
+ ) -> LossTerm:
935
+ batch_size, seq_len, _ = eos_out.shape
936
+ weights = eos_loss_mask.to(device=eos_out.device)
937
+ mask = weights.gt(0)
938
+ target = torch.zeros(
939
+ (batch_size, seq_len),
940
+ dtype=torch.long,
941
+ device=eos_out.device,
942
+ )
943
+ mask_counts = mask.sum(dim=1, keepdim=True)
944
+ cumulative = mask.long().cumsum(dim=1)
945
+ target[mask & (cumulative == mask_counts)] = 1
946
+
947
+ logits = rearrange(eos_out, "b n c -> b c n")
948
+ ce_per_token = F.cross_entropy(logits, target, reduction="none")
949
+ return LossTerm(loss=ce_per_token, mask=weights.to(ce_per_token.dtype))
950
+
951
+ @staticmethod
952
+ def _compute_eos_loss_stats(
953
+ eos_out: torch.Tensor,
954
+ eos_loss_mask: torch.Tensor,
955
+ ) -> tuple[torch.Tensor, torch.Tensor]:
956
+ weights = DotsTtsModel._build_eos_loss_mask(eos_loss_mask)
957
+ term = DotsTtsModel._compute_eos_loss_term(eos_out, weights)
958
+ mask = term.mask.to(device=term.loss.device, dtype=term.loss.dtype)
959
+ eos_loss_sum = (term.loss * mask).sum(dim=1)
960
+ eos_sample_count = eos_loss_mask.to(device=term.loss.device).gt(0).any(
961
+ dim=1
962
+ ).to(term.loss.dtype)
963
+ return eos_loss_sum, eos_sample_count
964
+
965
+ def _compute_loss_terms(
966
+ self,
967
+ outputs: DotsTtsForwardOutput,
968
+ *,
969
+ labels: torch.Tensor,
970
+ loss_masks: LossMasks,
971
+ ) -> LossTerms:
972
+ return {
973
+ "ce_loss": self._compute_ce_loss_term(
974
+ outputs.llm_logits,
975
+ labels,
976
+ loss_masks["ce_loss"],
977
+ ),
978
+ "fm_loss": self._compute_fm_loss_term(
979
+ outputs.pred,
980
+ outputs.target,
981
+ loss_masks["fm_loss"],
982
+ ),
983
+ "eos_loss": self._compute_eos_loss_term(
984
+ outputs.eos_out,
985
+ loss_masks["eos_loss"],
986
+ ),
987
+ }
988
+
989
+ def prepare_training_batch(self, data: dict[str, Any]) -> dict[str, Any]:
990
+ prepared = dict(data)
991
+ prepared.update(self._prepare_loss_metadata(prepared))
992
+ return prepared
993
+
994
+ def forward(self, data: dict[str, Any]) -> LossTerms:
995
+ loss_masks: LossMasks = data["loss_masks"]
996
+ processed = self.prepare_training_inputs(data)
997
+ processed["input_span_mask"] = data["input_span_mask"]
998
+ processed["output_span_mask"] = data["output_span_mask"]
999
+ return self._compute_loss_terms(
1000
+ self.core(processed),
1001
+ labels=processed["labels"],
1002
+ loss_masks=loss_masks,
1003
+ )
1004
+ # endregion Training loss assembly and forward
1005
+
1006
+ # region Prompt conditioning and decode state helpers
1007
+ @torch.no_grad()
1008
+ def _prepare_prompt_conditioning(
1009
+ self,
1010
+ prompt_audio: torch.Tensor | None,
1011
+ *,
1012
+ use_prompt_prefill: bool,
1013
+ speaker_scale: float = 1.5,
1014
+ ) -> _PromptConditioning:
1015
+ if prompt_audio is None:
1016
+ logger.info("Prompt conditioning skipped: no prompt audio provided.")
1017
+ return _PromptConditioning()
1018
+
1019
+ self.vocoder.eval()
1020
+ self.xvector_extractor.eval()
1021
+ device = next(self.core.parameters()).device
1022
+ if prompt_audio.ndim == 1:
1023
+ prompt_audio = prompt_audio.unsqueeze(0)
1024
+ prompt_audio = prompt_audio.to(device=device)
1025
+
1026
+ target_len = math.ceil(
1027
+ prompt_audio.size(1) / (self.config.patch_size * self.hop_size)
1028
+ ) * (self.config.patch_size * self.hop_size)
1029
+ pad_len = target_len - prompt_audio.size(1)
1030
+ if pad_len > 0:
1031
+ prompt_audio = F.pad(prompt_audio, (0, pad_len))
1032
+
1033
+ speaker_encoder = self._get_compiled_model(
1034
+ "speaker_encoder",
1035
+ self.xvector_extractor,
1036
+ )
1037
+ with measure_inference("speaker_encoder"):
1038
+ speaker_embedding = (
1039
+ speaker_encoder(prompt_audio[None, :]) * float(speaker_scale)
1040
+ )
1041
+ g_cond = self.core.xvec_proj(speaker_embedding)
1042
+ if not use_prompt_prefill:
1043
+ logger.info(
1044
+ "Reference-audio-only conditioning prepared: prompt_samples={} speaker_scale={} device={}",
1045
+ prompt_audio.shape[-1],
1046
+ speaker_scale,
1047
+ device,
1048
+ )
1049
+ return _PromptConditioning(g_cond=g_cond)
1050
+
1051
+ latent_encoder = self._get_compiled_model(
1052
+ "latent_encoder",
1053
+ self.vocoder.extract_latents,
1054
+ )
1055
+ with measure_inference("latent_encoder"):
1056
+ prompt_latents = latent_encoder(prompt_audio[None, :])
1057
+ prompt_latents_sampled = self.core.io_helper.sample_from_latent(prompt_latents)
1058
+ prompt_latents_sampled = prompt_latents_sampled[:, : -self.config.patch_size]
1059
+ prompt_patches = rearrange(
1060
+ self.core.io_helper.normalize(prompt_latents_sampled),
1061
+ "b (s p) d -> b s p d",
1062
+ p=self.config.patch_size,
1063
+ )
1064
+ logger.info(
1065
+ "Prompt conditioning prepared: prompt_samples={} prompt_patch_count={} "
1066
+ "speaker_scale={} device={}",
1067
+ prompt_audio.shape[-1],
1068
+ prompt_patches.size(1),
1069
+ speaker_scale,
1070
+ device,
1071
+ )
1072
+ return _PromptConditioning(
1073
+ prompt_patches=prompt_patches,
1074
+ prompt_latents=prompt_latents_sampled,
1075
+ g_cond=g_cond,
1076
+ )
1077
+
1078
+ @staticmethod
1079
+ def _patch_encoder_compile_signature(
1080
+ patch_encoder_state: Any,
1081
+ ) -> tuple[int, torch.dtype]:
1082
+ key_cache, _ = patch_encoder_state.layer_caches[0]
1083
+ return int(key_cache.size(2)), key_cache.dtype
1084
+
1085
+ def _resolve_patch_encoder_audio_bucket(self, required_seq_len: int) -> int:
1086
+ requested = int(required_seq_len)
1087
+ if requested <= 0:
1088
+ raise ValueError("required_seq_len must be positive.")
1089
+ requested_patch_count = math.ceil(
1090
+ requested / self.core.patch_encoder.out_ds_rate
1091
+ )
1092
+ if not self._optimize_enabled:
1093
+ return requested_patch_count
1094
+ return self._resolve_generate_length_bucket(requested_patch_count).size
1095
+
1096
+ def _copy_patch_encoder_state(self, source: Any, target: Any) -> None:
1097
+ seq_len = source.seq_len
1098
+ target_capacity = int(target.layer_caches[0][0].size(2))
1099
+ if seq_len > target_capacity:
1100
+ raise ValueError(
1101
+ "Patch encoder state copy exceeds target capacity: "
1102
+ f"seq_len={seq_len} capacity={target_capacity}."
1103
+ )
1104
+
1105
+ target.conv_tail.copy_(source.conv_tail)
1106
+ target.seq_len = seq_len
1107
+ for (source_key, source_value), (target_key, target_value) in zip(
1108
+ source.layer_caches,
1109
+ target.layer_caches,
1110
+ strict=True,
1111
+ ):
1112
+ if seq_len > 0:
1113
+ target_key[:, :, :seq_len, :].copy_(source_key[:, :, :seq_len, :])
1114
+ target_value[:, :, :seq_len, :].copy_(source_value[:, :, :seq_len, :])
1115
+
1116
+ def _ensure_patch_encoder_state_capacity(
1117
+ self,
1118
+ state: _GenerateState,
1119
+ *,
1120
+ required_seq_len: int,
1121
+ device: torch.device,
1122
+ dtype: torch.dtype,
1123
+ ) -> None:
1124
+ current_state = state.patch_encoder_state
1125
+ if current_state is not None:
1126
+ current_capacity = int(current_state.layer_caches[0][0].size(2))
1127
+ if current_capacity >= required_seq_len:
1128
+ return
1129
+
1130
+ target_audio_patch_count = self._resolve_patch_encoder_audio_bucket(
1131
+ required_seq_len
1132
+ )
1133
+ next_state = self.core.patch_encoder.init_decode_state(
1134
+ max_audio_patch_count=target_audio_patch_count,
1135
+ batch_size=1,
1136
+ device=device,
1137
+ dtype=dtype,
1138
+ )
1139
+ if current_state is not None:
1140
+ self._copy_patch_encoder_state(current_state, next_state)
1141
+ state.patch_encoder_state = next_state
1142
+
1143
+ def _prefill_prompt_latents(
1144
+ self,
1145
+ prompt_latents: torch.Tensor | None,
1146
+ *,
1147
+ state: _GenerateState,
1148
+ ) -> torch.Tensor | None:
1149
+ if prompt_latents is None:
1150
+ return None
1151
+ if prompt_latents.size(1) == 0:
1152
+ return prompt_latents.new_zeros(
1153
+ (prompt_latents.size(0), 0, self.core.llm_hidden_size)
1154
+ )
1155
+ self._ensure_patch_encoder_state_capacity(
1156
+ state,
1157
+ required_seq_len=(
1158
+ (prompt_latents.size(1) // self.core.patch_encoder.patch_size)
1159
+ * self.core.patch_encoder.out_ds_rate
1160
+ ),
1161
+ device=prompt_latents.device,
1162
+ dtype=(
1163
+ state.fm_sequence.dtype
1164
+ if state.fm_sequence is not None
1165
+ else prompt_latents.dtype
1166
+ ),
1167
+ )
1168
+ with measure_inference("patch_encoder"):
1169
+ prompt_patch_embeddings, state.patch_encoder_state = (
1170
+ self.core.patch_encoder.prefill(
1171
+ prompt_latents,
1172
+ state.patch_encoder_state,
1173
+ )
1174
+ )
1175
+ return prompt_patch_embeddings
1176
+
1177
+ def _get_fm_decode_workspace(
1178
+ self,
1179
+ *,
1180
+ total_len: int,
1181
+ device: torch.device,
1182
+ dtype: torch.dtype,
1183
+ ) -> dict[str, torch.Tensor]:
1184
+ workspace_key = (total_len, str(device), dtype)
1185
+ workspace = self._fm_decode_workspaces.get(workspace_key)
1186
+ if workspace is None:
1187
+ workspace = {
1188
+ "input_sequence": torch.zeros(
1189
+ (1, total_len, self.core.fm_hidden_size),
1190
+ dtype=dtype,
1191
+ device=device,
1192
+ ),
1193
+ "cfg_sequence": torch.zeros(
1194
+ (1, total_len, self.core.fm_hidden_size),
1195
+ dtype=dtype,
1196
+ device=device,
1197
+ ),
1198
+ "attn_mask": torch.zeros(
1199
+ (1, total_len, total_len),
1200
+ dtype=torch.bool,
1201
+ device=device,
1202
+ ),
1203
+ "pos_ids": torch.zeros(
1204
+ (1, total_len),
1205
+ dtype=torch.float32,
1206
+ device=device,
1207
+ ),
1208
+ }
1209
+ self._fm_decode_workspaces[workspace_key] = workspace
1210
+ else:
1211
+ workspace["input_sequence"].zero_()
1212
+ workspace["cfg_sequence"].zero_()
1213
+ return workspace
1214
+
1215
+ def _resolve_fm_history_bucket_capacity(self, fm_seq_len: int) -> int:
1216
+ requested = int(fm_seq_len)
1217
+ if requested <= 0:
1218
+ raise ValueError("fm_seq_len must be positive.")
1219
+ if not self._optimize_enabled:
1220
+ return requested
1221
+ history_stride = self.core.hidden_patch_size + self.core.latent_patch_size
1222
+ requested_patch_count = math.ceil(requested / history_stride)
1223
+ return self._resolve_generate_length_bucket(
1224
+ requested_patch_count
1225
+ ).size * history_stride
1226
+
1227
+ def _build_fm_attn_mask(
1228
+ self,
1229
+ *,
1230
+ state: _GenerateState,
1231
+ attn_mask: torch.Tensor,
1232
+ ) -> torch.Tensor:
1233
+ if state.fm_seq_len <= 0:
1234
+ raise RuntimeError("FM sequence length must be positive before decode.")
1235
+ hidden_patch_size = self.core.hidden_patch_size
1236
+ latent_start = attn_mask.size(-1) - self.core.latent_patch_size
1237
+ attn_mask.zero_()
1238
+ block_start = state.fm_seq_len - hidden_patch_size
1239
+ if block_start > 0:
1240
+ causal_mask = torch.ones(
1241
+ (block_start, block_start),
1242
+ device=attn_mask.device,
1243
+ dtype=torch.bool,
1244
+ ).triu(1).logical_not()
1245
+ attn_mask[:, :block_start, :block_start] = causal_mask
1246
+
1247
+ attn_mask[:, block_start : state.fm_seq_len, : state.fm_seq_len] = True
1248
+ attn_mask[:, block_start : state.fm_seq_len, latent_start:] = True
1249
+ attn_mask[:, latent_start:, : state.fm_seq_len] = True
1250
+ attn_mask[:, latent_start:, latent_start:] = True
1251
+ if latent_start > state.fm_seq_len:
1252
+ padding_indices = torch.arange(
1253
+ state.fm_seq_len,
1254
+ latent_start,
1255
+ device=attn_mask.device,
1256
+ )
1257
+ attn_mask[:, padding_indices, padding_indices] = True
1258
+ return attn_mask
1259
+
1260
+ def _build_fm_pos_ids(
1261
+ self,
1262
+ *,
1263
+ state: _GenerateState,
1264
+ pos_ids: torch.Tensor,
1265
+ ) -> torch.Tensor:
1266
+ if state.fm_seq_len <= 0:
1267
+ raise RuntimeError("FM sequence length must be positive before decode.")
1268
+ pos_ids.zero_()
1269
+ latent_start = pos_ids.size(-1) - self.core.latent_patch_size
1270
+ pos_ids[:, : state.fm_seq_len] = torch.arange(
1271
+ state.fm_seq_len,
1272
+ device=pos_ids.device,
1273
+ dtype=pos_ids.dtype,
1274
+ )
1275
+ pos_ids[:, latent_start:] = torch.arange(
1276
+ state.fm_seq_len,
1277
+ state.fm_seq_len + self.core.latent_patch_size,
1278
+ device=pos_ids.device,
1279
+ dtype=pos_ids.dtype,
1280
+ )
1281
+ return pos_ids
1282
+
1283
+ def _prepare_fm_decode_inputs(
1284
+ self,
1285
+ state: _GenerateState,
1286
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
1287
+ sequence = state.fm_sequence
1288
+ cfg_sequence = state.fm_cfg_sequence
1289
+ if sequence is None or cfg_sequence is None:
1290
+ raise RuntimeError("FM static buffers are not initialized.")
1291
+ history_bucket_capacity = self._resolve_fm_history_bucket_capacity(
1292
+ state.fm_seq_len
1293
+ )
1294
+ total_len = history_bucket_capacity + self.core.latent_patch_size
1295
+ workspace = self._get_fm_decode_workspace(
1296
+ total_len=total_len,
1297
+ device=sequence.device,
1298
+ dtype=sequence.dtype,
1299
+ )
1300
+ workspace["input_sequence"][:, : state.fm_seq_len].copy_(
1301
+ sequence[:, : state.fm_seq_len]
1302
+ )
1303
+ workspace["cfg_sequence"][:, : state.fm_seq_len].copy_(
1304
+ cfg_sequence[:, : state.fm_seq_len]
1305
+ )
1306
+ return (
1307
+ workspace["input_sequence"],
1308
+ workspace["cfg_sequence"],
1309
+ workspace["attn_mask"],
1310
+ workspace["pos_ids"],
1311
+ history_bucket_capacity,
1312
+ )
1313
+
1314
+ def _append_to_fm_buffer(
1315
+ self,
1316
+ buffer: torch.Tensor | None,
1317
+ state: _GenerateState,
1318
+ chunk: torch.Tensor,
1319
+ ) -> tuple[int, int]:
1320
+ if buffer is None:
1321
+ raise RuntimeError("FM static buffer is not initialized.")
1322
+ start = state.fm_seq_len
1323
+ end = start + chunk.size(1)
1324
+ if end > state.fm_capacity:
1325
+ raise RuntimeError(
1326
+ "FM StaticBuffer capacity exceeded: "
1327
+ f"next_length={end} capacity={state.fm_capacity}."
1328
+ )
1329
+ buffer[:, start:end].copy_(chunk.to(buffer.dtype))
1330
+ return start, end
1331
+
1332
+ def _append_hidden_chunk(
1333
+ self, state: _GenerateState, hidden_chunk: torch.Tensor
1334
+ ) -> None:
1335
+ last_hidden = hidden_chunk[:, -self.core.hidden_patch_size :, :]
1336
+ projected = self.core.hidden_proj(last_hidden)
1337
+ null_projected = self.core.hidden_proj(torch.zeros_like(last_hidden))
1338
+ _start, end = self._append_to_fm_buffer(
1339
+ state.fm_sequence,
1340
+ state,
1341
+ projected,
1342
+ )
1343
+ cfg_buffer = state.fm_cfg_sequence
1344
+ if cfg_buffer is None:
1345
+ raise RuntimeError("FM cfg static buffer is not initialized.")
1346
+ cfg_buffer[:, state.fm_seq_len : end].copy_(null_projected.to(cfg_buffer.dtype))
1347
+ state.fm_seq_len = end
1348
+
1349
+ def _append_history_chunk(
1350
+ self, state: _GenerateState, latent_chunk: torch.Tensor
1351
+ ) -> None:
1352
+ history_latent = self.core.latent_proj(latent_chunk)
1353
+ _start, end = self._append_to_fm_buffer(
1354
+ state.fm_sequence,
1355
+ state,
1356
+ history_latent,
1357
+ )
1358
+ cfg_buffer = state.fm_cfg_sequence
1359
+ if cfg_buffer is None:
1360
+ raise RuntimeError("FM cfg static buffer is not initialized.")
1361
+ cfg_buffer[:, state.fm_seq_len : end].copy_(history_latent.to(cfg_buffer.dtype))
1362
+ state.fm_seq_len = end
1363
+
1364
+ def _consume_text_schedule(
1365
+ self,
1366
+ generation_schedule: torch.Tensor,
1367
+ *,
1368
+ position: int,
1369
+ next_audio_position: int,
1370
+ state: _GenerateState,
1371
+ ) -> int:
1372
+ with measure_inference("LLM"):
1373
+ text_chunk = generation_schedule[:, position:next_audio_position]
1374
+ _, state.llm_hiddens, _, state.llm_cache = self.core.step_llm(
1375
+ input_ids=text_chunk,
1376
+ past_key_values=state.llm_cache,
1377
+ )
1378
+ self._append_hidden_chunk(state, state.llm_hiddens)
1379
+ return next_audio_position
1380
+
1381
+ def _locate_prefill_boundary(
1382
+ self,
1383
+ *,
1384
+ span_positions: torch.Tensor,
1385
+ prompt_patch_count: int,
1386
+ ) -> tuple[int, torch.Tensor]:
1387
+ if span_positions.numel() > prompt_patch_count:
1388
+ return int(span_positions[prompt_patch_count].item()), span_positions[
1389
+ :prompt_patch_count
1390
+ ]
1391
+ raise RuntimeError(
1392
+ "Prefill boundary discovery failed despite prior schedule validation."
1393
+ )
1394
+
1395
+ @staticmethod
1396
+ def _find_audio_span_positions(
1397
+ generation_schedule: torch.Tensor,
1398
+ *,
1399
+ audio_placeholder_ids: set[int],
1400
+ ) -> torch.Tensor:
1401
+ schedule = generation_schedule[0]
1402
+ placeholder_ids = torch.tensor(
1403
+ sorted(audio_placeholder_ids),
1404
+ device=schedule.device,
1405
+ dtype=schedule.dtype,
1406
+ )
1407
+ return torch.nonzero(
1408
+ torch.isin(schedule, placeholder_ids),
1409
+ as_tuple=False,
1410
+ ).squeeze(-1)
1411
+
1412
+ @staticmethod
1413
+ def _next_token_is_audio_span(
1414
+ generation_schedule: torch.Tensor,
1415
+ *,
1416
+ position: int,
1417
+ audio_placeholder_ids: set[int],
1418
+ ) -> bool:
1419
+ next_position = position + 1
1420
+ if next_position >= generation_schedule.size(1):
1421
+ return False
1422
+ return int(generation_schedule[0, next_position].item()) in audio_placeholder_ids
1423
+
1424
+ def _build_prefill_inputs_embeds(
1425
+ self,
1426
+ generation_schedule: torch.Tensor,
1427
+ *,
1428
+ prompt_patch_embeddings: torch.Tensor | None,
1429
+ prompt_span_positions: torch.Tensor,
1430
+ ) -> torch.Tensor:
1431
+ inputs_embeds = self.core.llm.get_input_embeddings()(
1432
+ generation_schedule
1433
+ ).clone()
1434
+ if prompt_span_positions.numel() > 0:
1435
+ if prompt_patch_embeddings is None:
1436
+ raise RuntimeError(
1437
+ "Prompt patch embeddings are required when prefill includes prompt audio spans."
1438
+ )
1439
+ patch_embeddings = prompt_patch_embeddings[
1440
+ :, : prompt_span_positions.numel()
1441
+ ].to(inputs_embeds.dtype)
1442
+ if patch_embeddings.size(1) != prompt_span_positions.numel():
1443
+ raise RuntimeError(
1444
+ f"Prompt patch embeddings ({patch_embeddings.size(1)}) do not match prompt span count ({prompt_span_positions.numel()})."
1445
+ )
1446
+ inputs_embeds[:, prompt_span_positions, :] = patch_embeddings
1447
+ return inputs_embeds
1448
+
1449
+ def _prefill(
1450
+ self,
1451
+ generation_schedule: torch.Tensor,
1452
+ *,
1453
+ state: _GenerateState,
1454
+ span_positions: torch.Tensor,
1455
+ prompt_patches: torch.Tensor | None,
1456
+ prompt_patch_embeddings: torch.Tensor | None,
1457
+ audio_placeholder_ids: set[int],
1458
+ ) -> int:
1459
+ prompt_patch_count = (
1460
+ 0 if prompt_patches is None else int(prompt_patches.size(1))
1461
+ )
1462
+ prefill_end, prompt_span_positions = self._locate_prefill_boundary(
1463
+ span_positions=span_positions,
1464
+ prompt_patch_count=prompt_patch_count,
1465
+ )
1466
+ if prefill_end == 0:
1467
+ return 0
1468
+ inputs_embeds = self._build_prefill_inputs_embeds(
1469
+ generation_schedule[:, :prefill_end],
1470
+ prompt_patch_embeddings=prompt_patch_embeddings,
1471
+ prompt_span_positions=prompt_span_positions,
1472
+ )
1473
+ with measure_inference("LLM"):
1474
+ _, llm_hiddens, _, state.llm_cache = self.core.step_llm(
1475
+ inputs_embeds=inputs_embeds,
1476
+ past_key_values=state.llm_cache,
1477
+ )
1478
+ state.llm_hiddens = llm_hiddens[:, -1:, :]
1479
+
1480
+ cursor = 0
1481
+ for prompt_index, span_position in enumerate(prompt_span_positions.tolist()):
1482
+ if span_position > cursor:
1483
+ self._append_hidden_chunk(
1484
+ state, llm_hiddens[:, span_position - 1 : span_position, :]
1485
+ )
1486
+ self._append_history_chunk(state, prompt_patches[:, prompt_index])
1487
+ if self._next_token_is_audio_span(
1488
+ generation_schedule,
1489
+ position=span_position,
1490
+ audio_placeholder_ids=audio_placeholder_ids,
1491
+ ):
1492
+ self._append_hidden_chunk(
1493
+ state, llm_hiddens[:, span_position : span_position + 1, :]
1494
+ )
1495
+ cursor = span_position + 1
1496
+ if prefill_end > cursor:
1497
+ self._append_hidden_chunk(
1498
+ state, llm_hiddens[:, prefill_end - 1 : prefill_end, :]
1499
+ )
1500
+ return prefill_end
1501
+
1502
+ def _decode_next_audio(
1503
+ self,
1504
+ state: _GenerateState,
1505
+ *,
1506
+ device: torch.device,
1507
+ g_cond: torch.Tensor | None,
1508
+ ode_method: str,
1509
+ num_steps: int,
1510
+ guidance_scale: float,
1511
+ ) -> torch.Tensor:
1512
+ if state.fm_seq_len <= 0:
1513
+ raise RuntimeError(
1514
+ "Cannot decode audio before any conditioning state has been prefetched."
1515
+ )
1516
+ if state.fm_sequence is None or state.fm_cfg_sequence is None:
1517
+ raise RuntimeError("FM static buffers are not initialized.")
1518
+ if state.fm_null_g_cond is None:
1519
+ raise RuntimeError("FM null conditioning buffer is not initialized.")
1520
+ fm_sequence, fm_cfg_sequence, fm_attn_mask, fm_pos_ids, history_bucket_capacity = (
1521
+ self._prepare_fm_decode_inputs(state)
1522
+ )
1523
+ compile_signature = (
1524
+ (history_bucket_capacity, state.fm_sequence.dtype)
1525
+ if self._optimize_enabled
1526
+ else (state.fm_seq_len, state.fm_sequence.dtype)
1527
+ )
1528
+ if g_cond is None:
1529
+ g_cond = state.fm_null_g_cond
1530
+ else:
1531
+ g_cond = g_cond.to(
1532
+ device=state.fm_null_g_cond.device,
1533
+ dtype=state.fm_null_g_cond.dtype,
1534
+ )
1535
+ with measure_inference("FM"):
1536
+ attn_mask = self._build_fm_attn_mask(
1537
+ state=state,
1538
+ attn_mask=fm_attn_mask,
1539
+ )
1540
+ pos_ids = self._build_fm_pos_ids(
1541
+ state=state,
1542
+ pos_ids=fm_pos_ids,
1543
+ )
1544
+ if self.core.mode == "meanflow":
1545
+ fm_solver_step = self._get_compiled_method(
1546
+ "FM.meanflow.solver_step",
1547
+ self.core,
1548
+ "meanflow_solver_step",
1549
+ signature=compile_signature,
1550
+ )
1551
+ return self.core._meanflow_step_fm(
1552
+ input_sequence=fm_sequence,
1553
+ attn_mask=attn_mask,
1554
+ pos_ids=pos_ids,
1555
+ patch_size=self.core.latent_patch_size,
1556
+ g_cond=g_cond,
1557
+ nfe=num_steps,
1558
+ solver_step=fm_solver_step,
1559
+ )
1560
+
1561
+ fm_solver_step = self._get_compiled_method(
1562
+ "FM.flow_matching.solver_step",
1563
+ self.core,
1564
+ "fm_solver_step",
1565
+ signature=compile_signature,
1566
+ )
1567
+ return self.core._flow_matching_step_fm(
1568
+ input_sequence=fm_sequence,
1569
+ cfg_sequence=fm_cfg_sequence,
1570
+ attn_mask=attn_mask,
1571
+ pos_ids=pos_ids,
1572
+ hidden_size=self.core.hidden_patch_size,
1573
+ patch_size=self.core.latent_patch_size,
1574
+ g_cond=g_cond,
1575
+ ode_method=ode_method,
1576
+ num_steps=num_steps,
1577
+ guidance_scale=guidance_scale,
1578
+ solver_step=fm_solver_step,
1579
+ )
1580
+
1581
+ def _consume_audio_patch(
1582
+ self,
1583
+ state: _GenerateState,
1584
+ *,
1585
+ audio_patch: torch.Tensor,
1586
+ ) -> None:
1587
+ audio_patch_for_llm = self.core.io_helper.denormalize(audio_patch)
1588
+ self._append_history_chunk(state, audio_patch)
1589
+ current_seq_len = (
1590
+ 0
1591
+ if state.patch_encoder_state is None
1592
+ else state.patch_encoder_state.seq_len
1593
+ )
1594
+ self._ensure_patch_encoder_state_capacity(
1595
+ state,
1596
+ required_seq_len=current_seq_len + self.core.patch_encoder.out_ds_rate,
1597
+ device=audio_patch_for_llm.device,
1598
+ dtype=(
1599
+ state.fm_sequence.dtype
1600
+ if state.fm_sequence is not None
1601
+ else audio_patch_for_llm.dtype
1602
+ ),
1603
+ )
1604
+ patch_encoder_decode = self._get_compiled_method(
1605
+ "patch_encoder.decode_patch",
1606
+ self.core.patch_encoder,
1607
+ "decode_patch",
1608
+ signature=self._patch_encoder_compile_signature(state.patch_encoder_state),
1609
+ )
1610
+ patch_positions = (
1611
+ torch.arange(
1612
+ self.core.patch_encoder.out_ds_rate,
1613
+ device=audio_patch_for_llm.device,
1614
+ dtype=torch.long,
1615
+ )
1616
+ + state.patch_encoder_state.seq_len
1617
+ )
1618
+ with measure_inference("patch_encoder"):
1619
+ llm_embedding, conv_tail = patch_encoder_decode(
1620
+ audio_patch_for_llm,
1621
+ state.patch_encoder_state.conv_tail,
1622
+ state.patch_encoder_state.layer_caches,
1623
+ patch_positions,
1624
+ )
1625
+ state.patch_encoder_state.conv_tail.copy_(conv_tail)
1626
+ state.patch_encoder_state.seq_len += self.core.patch_encoder.out_ds_rate
1627
+ with measure_inference("LLM"):
1628
+ _, state.llm_hiddens, _, state.llm_cache = self.core.step_llm(
1629
+ inputs_embeds=llm_embedding,
1630
+ past_key_values=state.llm_cache,
1631
+ )
1632
+
1633
+ def _decode(
1634
+ self,
1635
+ generation_schedule: torch.Tensor,
1636
+ *,
1637
+ position: int,
1638
+ state: _GenerateState,
1639
+ audio_placeholder_ids: set[int],
1640
+ span_positions: torch.Tensor,
1641
+ device: torch.device,
1642
+ g_cond: torch.Tensor | None,
1643
+ ode_method: str,
1644
+ num_steps: int,
1645
+ guidance_scale: float,
1646
+ eos_threshold: float,
1647
+ ) -> Iterator[torch.Tensor]:
1648
+ span_cursor = torch.searchsorted(
1649
+ span_positions,
1650
+ torch.tensor(
1651
+ position,
1652
+ device=span_positions.device,
1653
+ dtype=span_positions.dtype,
1654
+ ),
1655
+ ).item()
1656
+ while position < generation_schedule.size(1):
1657
+ token_id = int(generation_schedule[0, position].item())
1658
+ if token_id in audio_placeholder_ids:
1659
+ stop_after_current_audio = self._should_stop_after_current_audio(
1660
+ state,
1661
+ eos_threshold=eos_threshold,
1662
+ )
1663
+ audio_patch = self._decode_next_audio(
1664
+ state,
1665
+ device=device,
1666
+ g_cond=g_cond,
1667
+ ode_method=ode_method,
1668
+ num_steps=num_steps,
1669
+ guidance_scale=guidance_scale,
1670
+ )
1671
+ self._consume_audio_patch(
1672
+ state,
1673
+ audio_patch=audio_patch,
1674
+ )
1675
+ if self._next_token_is_audio_span(
1676
+ generation_schedule,
1677
+ position=position,
1678
+ audio_placeholder_ids=audio_placeholder_ids,
1679
+ ):
1680
+ self._append_hidden_chunk(state, state.llm_hiddens)
1681
+ position += 1
1682
+ span_cursor += 1
1683
+ yield audio_patch
1684
+ if stop_after_current_audio:
1685
+ state.end_flag = True
1686
+ return
1687
+ continue
1688
+ next_audio_position = (
1689
+ int(span_positions[span_cursor].item())
1690
+ if span_cursor < span_positions.numel()
1691
+ else generation_schedule.size(1)
1692
+ )
1693
+ position = self._consume_text_schedule(
1694
+ generation_schedule,
1695
+ position=position,
1696
+ next_audio_position=next_audio_position,
1697
+ state=state,
1698
+ )
1699
+
1700
+ def _should_stop_after_current_audio(
1701
+ self, state: _GenerateState, *, eos_threshold: float
1702
+ ) -> bool:
1703
+ if state.llm_hiddens is None:
1704
+ return False
1705
+ eos = (
1706
+ self.core.eos_proj(state.llm_hiddens).softmax(dim=-1)[:, -1, 1]
1707
+ > eos_threshold
1708
+ )
1709
+ return state.end_flag or bool(eos.item())
1710
+
1711
+ # endregion Prompt conditioning and decode state helpers
1712
+
1713
+ # region Public generation APIs
1714
+ @torch.no_grad()
1715
+ def _generate_latents_stream(
1716
+ self,
1717
+ data: dict[str, Any],
1718
+ *,
1719
+ precision: str,
1720
+ ode_method: str,
1721
+ num_steps: int,
1722
+ guidance_scale: float,
1723
+ speaker_scale: float = 1.5,
1724
+ eos_threshold: float = 0.8,
1725
+ ) -> Iterator[torch.Tensor]:
1726
+ dtype = get_dtype(precision)
1727
+ device = next(self.core.parameters()).device
1728
+ use_amp = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}
1729
+ with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_amp):
1730
+ generation_schedule: torch.Tensor = data["generation_schedule"]
1731
+ if generation_schedule.size(0) != 1:
1732
+ raise ValueError(
1733
+ "DotsTtsModel.generate expects batch size 1 for generation_schedule."
1734
+ )
1735
+
1736
+ use_prompt_prefill = data.get("prompt_audio") is not None and bool(
1737
+ data.get("prompt_text")
1738
+ )
1739
+ prompt_conditioning = self._prepare_prompt_conditioning(
1740
+ data.get("prompt_audio"),
1741
+ use_prompt_prefill=use_prompt_prefill,
1742
+ speaker_scale=speaker_scale,
1743
+ )
1744
+ has_prompt_prefill = prompt_conditioning.prompt_patches is not None
1745
+ prompt_patch_count = (
1746
+ 0
1747
+ if not has_prompt_prefill
1748
+ else int(prompt_conditioning.prompt_patches.size(1))
1749
+ )
1750
+ audio_placeholder_ids = set(self.core.audio_span_token_ids)
1751
+ span_positions = self._find_audio_span_positions(
1752
+ generation_schedule,
1753
+ audio_placeholder_ids=audio_placeholder_ids,
1754
+ )
1755
+ span_count = int(span_positions.numel())
1756
+ minimum_required_spans = prompt_patch_count + 1
1757
+ if span_count < minimum_required_spans:
1758
+ raise ValueError(
1759
+ f"generation_schedule provides {span_count} audio spans, but prompt prefill requires "
1760
+ f"{prompt_patch_count} spans and generation requires at least one additional decode span."
1761
+ )
1762
+ logger.info(
1763
+ "Latent generation prepared: schedule_audio_spans={} prompt_patch_count={} "
1764
+ "minimum_required_spans={}",
1765
+ span_count,
1766
+ prompt_patch_count,
1767
+ minimum_required_spans,
1768
+ )
1769
+
1770
+ state = self._allocate_generate_state(
1771
+ max_audio_patch_count=span_count,
1772
+ device=device,
1773
+ dtype=dtype,
1774
+ )
1775
+ prompt_patch_embeddings = self._prefill_prompt_latents(
1776
+ prompt_conditioning.prompt_latents,
1777
+ state=state,
1778
+ )
1779
+ position = self._prefill(
1780
+ generation_schedule,
1781
+ state=state,
1782
+ span_positions=span_positions,
1783
+ prompt_patches=prompt_conditioning.prompt_patches,
1784
+ prompt_patch_embeddings=prompt_patch_embeddings,
1785
+ audio_placeholder_ids=audio_placeholder_ids,
1786
+ )
1787
+
1788
+ payload_patch_count = 0
1789
+ should_drop_regenerated_prompt_patch = has_prompt_prefill
1790
+ for audio_patch in self._decode(
1791
+ generation_schedule,
1792
+ position=position,
1793
+ state=state,
1794
+ audio_placeholder_ids=audio_placeholder_ids,
1795
+ span_positions=span_positions,
1796
+ device=device,
1797
+ g_cond=prompt_conditioning.g_cond,
1798
+ ode_method=ode_method,
1799
+ num_steps=num_steps,
1800
+ guidance_scale=guidance_scale,
1801
+ eos_threshold=eos_threshold,
1802
+ ):
1803
+ if should_drop_regenerated_prompt_patch:
1804
+ should_drop_regenerated_prompt_patch = False
1805
+ continue
1806
+ payload_patch_count += 1
1807
+ if payload_patch_count == 1 or payload_patch_count % 10 == 0:
1808
+ logger.info(
1809
+ "Latent generation progress: payload_audio_patches={}",
1810
+ payload_patch_count,
1811
+ )
1812
+ yield self.core.io_helper.denormalize(audio_patch)
1813
+
1814
+ if payload_patch_count == 0:
1815
+ if has_prompt_prefill:
1816
+ raise RuntimeError(
1817
+ "Generation produced no payload latents after discarding the regenerated prompt-tail patch. "
1818
+ "This usually means EOS triggered immediately after prompt continuation "
1819
+ "or the generation schedule did not provide an effective decode span."
1820
+ )
1821
+ raise RuntimeError(
1822
+ "Generation produced no decodable latents. "
1823
+ "This usually means EOS triggered before the first decode patch "
1824
+ "or the generation schedule did not provide an effective decode span."
1825
+ )
1826
+ logger.info(
1827
+ "Latent generation completed: payload_audio_patches={}",
1828
+ payload_patch_count,
1829
+ )
1830
+
1831
+ @torch.no_grad()
1832
+ def _decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
1833
+ with measure_inference("latent_decoder"):
1834
+ return self.vocoder.inference_from_latents(
1835
+ latents.transpose(1, 2).float(),
1836
+ do_sample=False,
1837
+ )
1838
+
1839
+ @torch.no_grad()
1840
+ def _init_vocoder_stream_state(self) -> Any:
1841
+ return self.vocoder.init_stream_state(
1842
+ batch_size=1,
1843
+ chunk_size=self.core.latent_patch_size,
1844
+ )
1845
+
1846
+ @torch.no_grad()
1847
+ def _stream_vocoder_patch(
1848
+ self,
1849
+ latent_patch: torch.Tensor,
1850
+ *,
1851
+ stream_state: Any,
1852
+ ) -> torch.Tensor:
1853
+ latents = latent_patch.transpose(1, 2)
1854
+ if not self._optimize_enabled:
1855
+ with measure_inference("vocoder"):
1856
+ return self.vocoder.stream_step(latents, stream_state)
1857
+
1858
+ valid_frames = min(
1859
+ stream_state.decoder.total_frames,
1860
+ stream_state.decoder.window.size(-1),
1861
+ )
1862
+ valid_frames_tensor = stream_state.decoder.window.new_tensor(
1863
+ valid_frames,
1864
+ dtype=torch.int64,
1865
+ )
1866
+ vocoder_step = self._get_compiled_method(
1867
+ "vocoder.step",
1868
+ self.vocoder,
1869
+ "compiled_stream_step",
1870
+ )
1871
+ with measure_inference("vocoder"):
1872
+ audio_window, hidden_h, hidden_c, new_window = vocoder_step(
1873
+ latents,
1874
+ stream_state.lstm_hidden[0],
1875
+ stream_state.lstm_hidden[1],
1876
+ stream_state.decoder.window,
1877
+ valid_frames_tensor,
1878
+ )
1879
+ stream_state.lstm_hidden = (hidden_h.clone(), hidden_c.clone())
1880
+ stream_state.decoder.window = new_window.clone()
1881
+ stream_state.decoder.total_frames += int(latents.size(-1))
1882
+ audio_chunk = self.vocoder._slice_stream_audio_window(
1883
+ audio_window,
1884
+ stream_state,
1885
+ final=False,
1886
+ )
1887
+ return audio_chunk.clone()
1888
+
1889
+ @torch.no_grad()
1890
+ def _flush_vocoder_stream(self, stream_state: Any) -> torch.Tensor:
1891
+ with measure_inference("vocoder"):
1892
+ return self.vocoder.stream_flush(stream_state)
1893
+
1894
+ @torch.no_grad()
1895
+ def generate_audio_stream(
1896
+ self,
1897
+ data: dict[str, Any],
1898
+ *,
1899
+ precision: str,
1900
+ ode_method: str,
1901
+ num_steps: int,
1902
+ guidance_scale: float,
1903
+ speaker_scale: float = 1.5,
1904
+ eos_threshold: float = 0.8,
1905
+ ) -> Iterator[torch.Tensor]:
1906
+ stream_state = self._init_vocoder_stream_state()
1907
+ for latent_patch in self._generate_latents_stream(
1908
+ data,
1909
+ precision=precision,
1910
+ ode_method=ode_method,
1911
+ num_steps=num_steps,
1912
+ guidance_scale=guidance_scale,
1913
+ speaker_scale=speaker_scale,
1914
+ eos_threshold=eos_threshold,
1915
+ ):
1916
+ audio_chunk = self._stream_vocoder_patch(
1917
+ latent_patch,
1918
+ stream_state=stream_state,
1919
+ )
1920
+ if audio_chunk.size(-1) > 0:
1921
+ yield audio_chunk
1922
+
1923
+ final_chunk = self._flush_vocoder_stream(stream_state)
1924
+ if final_chunk.size(-1) > 0:
1925
+ yield final_chunk
1926
+
1927
+ @torch.no_grad()
1928
+ def generate_audio(
1929
+ self,
1930
+ data: dict[str, Any],
1931
+ *,
1932
+ precision: str,
1933
+ ode_method: str,
1934
+ num_steps: int,
1935
+ guidance_scale: float,
1936
+ speaker_scale: float = 1.5,
1937
+ ) -> torch.Tensor:
1938
+ latent_patches = list(
1939
+ self._generate_latents_stream(
1940
+ data,
1941
+ precision=precision,
1942
+ ode_method=ode_method,
1943
+ num_steps=num_steps,
1944
+ guidance_scale=guidance_scale,
1945
+ speaker_scale=speaker_scale,
1946
+ )
1947
+ )
1948
+ logger.info(
1949
+ "Vocoder decode started: latent_patch_count={}",
1950
+ len(latent_patches),
1951
+ )
1952
+ audio = self._decode_latents(torch.cat(latent_patches, dim=1))
1953
+ logger.info(
1954
+ "Vocoder decode completed: waveform_samples={}",
1955
+ audio.shape[-1],
1956
+ )
1957
+ return audio
1958
+ # endregion Public generation APIs
src/dots_tts/modules/__init__.py ADDED
File without changes
src/dots_tts/modules/backbone/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Backbone modules."""
src/dots_tts/modules/backbone/dit.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from dots_tts.modules.backbone.layers import Mlp, MultiHeadAttention
7
+
8
+
9
+ def modulate(x, shift, scale, **_kwargs):
10
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
11
+
12
+
13
+ class TimestepEmbedder(nn.Module):
14
+ def __init__(self, hidden_size, frequency_embedding_size=256):
15
+ super().__init__()
16
+ self.mlp = nn.Sequential(
17
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
18
+ nn.SiLU(),
19
+ nn.Linear(hidden_size, hidden_size, bias=True),
20
+ )
21
+ self.frequency_embedding_size = frequency_embedding_size
22
+
23
+ @staticmethod
24
+ def timestep_embedding(t, dim, max_period=10000):
25
+ half = dim // 2
26
+ freqs = torch.exp(
27
+ -math.log(max_period)
28
+ * torch.arange(start=0, end=half, dtype=torch.float32)
29
+ / half
30
+ ).to(device=t.device)
31
+ args = t[:, None].float() * freqs[None]
32
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
33
+ if dim % 2:
34
+ embedding = torch.cat(
35
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
36
+ )
37
+ return embedding
38
+
39
+ def forward(self, t):
40
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
41
+ return self.mlp(t_freq)
42
+
43
+
44
+ class FinalLayer(nn.Module):
45
+ def __init__(self, hidden_size, output_size):
46
+ super().__init__()
47
+ self.adaLN_modulation = nn.Sequential(
48
+ nn.SiLU(),
49
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True),
50
+ )
51
+ self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-5)
52
+ self.linear = nn.Linear(hidden_size, output_size, bias=True)
53
+
54
+ def forward(self, x, c, **_kwargs):
55
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
56
+ x = modulate(self.norm(x), shift, scale)
57
+ return self.linear(x)
58
+
59
+
60
+ class DiTBlock(nn.Module):
61
+ def __init__(
62
+ self,
63
+ attention: nn.Module,
64
+ ffn: nn.Module,
65
+ hidden_size: int = 1024,
66
+ modulation: bool = False,
67
+ eps: float = 1e-5,
68
+ **_kwargs,
69
+ ):
70
+ super().__init__()
71
+ self.norm1 = nn.LayerNorm(
72
+ hidden_size, elementwise_affine=not modulation, eps=eps
73
+ )
74
+ self.norm2 = nn.LayerNorm(
75
+ hidden_size, elementwise_affine=not modulation, eps=eps
76
+ )
77
+ self.attn = attention
78
+ self.ffn = ffn
79
+ self.modulation = modulation
80
+ if modulation:
81
+ self.adaLN_modulation = nn.Sequential(
82
+ nn.SiLU(),
83
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True),
84
+ )
85
+
86
+ def forward(self, x, condition=None, mask=None, **kwargs):
87
+ if condition is None:
88
+ assert not self.modulation, (
89
+ "Without global condition, must set modulation to False"
90
+ )
91
+ else:
92
+ assert self.modulation, "With global condition, must set modulation to True"
93
+ shift_attn, scale_attn, gate_attn, shift_ffn, scale_ffn, gate_ffn = (
94
+ self.adaLN_modulation(condition).chunk(6, dim=1)
95
+ )
96
+
97
+ if condition is not None:
98
+ pack_indices = kwargs.get("pack_indices")
99
+ if pack_indices is not None:
100
+ gate_attn = gate_attn[pack_indices]
101
+ gate_ffn = gate_ffn[pack_indices]
102
+ else:
103
+ gate_attn = gate_attn.unsqueeze(1)
104
+ gate_ffn = gate_ffn.unsqueeze(1)
105
+
106
+ if condition is not None:
107
+ x = x + gate_attn * self.attn(
108
+ modulate(self.norm1(x), shift_attn, scale_attn, **kwargs),
109
+ mask=mask,
110
+ **kwargs,
111
+ )
112
+ else:
113
+ x = x + self.attn(self.norm1(x), mask=mask, **kwargs)
114
+
115
+ if condition is not None:
116
+ x = x + gate_ffn * self.ffn(
117
+ modulate(self.norm2(x), shift_ffn, scale_ffn, **kwargs)
118
+ )
119
+ else:
120
+ x = x + self.ffn(self.norm2(x), mask=mask)
121
+ return x
122
+
123
+
124
+ class DiT(nn.Module):
125
+ def __init__(
126
+ self,
127
+ in_dim,
128
+ out_dim,
129
+ transformer_config,
130
+ *,
131
+ mode: str = "flow_matching",
132
+ ):
133
+ super().__init__()
134
+ if mode not in {"flow_matching", "meanflow"}:
135
+ raise ValueError(
136
+ f"DiT mode must be 'flow_matching' or 'meanflow', got {mode!r}."
137
+ )
138
+
139
+ transformer_kwargs = transformer_config.to_dict()
140
+ model_dim = transformer_config.hidden_size
141
+ self.mode = mode
142
+ self.num_layers = transformer_config.num_layers
143
+
144
+ self.input_layer = nn.Linear(in_dim, model_dim)
145
+ self.time_embedder = TimestepEmbedder(model_dim)
146
+ if mode == "meanflow":
147
+ self.duration_embedder = TimestepEmbedder(model_dim)
148
+
149
+ self.blocks = nn.ModuleList()
150
+ for i in range(self.num_layers):
151
+ attn_block = MultiHeadAttention(**transformer_kwargs, name=f"layer_{i}")
152
+ ffn_block = Mlp(
153
+ act_layer=lambda: nn.GELU(approximate="tanh"), **transformer_kwargs
154
+ )
155
+ self.blocks.append(
156
+ DiTBlock(attention=attn_block, ffn=ffn_block, **transformer_kwargs)
157
+ )
158
+
159
+ self.output_layer = FinalLayer(model_dim, out_dim)
160
+ self.initialize_weights()
161
+
162
+ def initialize_weights(self):
163
+ def _basic_init(module):
164
+ if isinstance(module, nn.Linear):
165
+ torch.nn.init.xavier_uniform_(module.weight)
166
+ if module.bias is not None:
167
+ nn.init.constant_(module.bias, 0)
168
+
169
+ self.apply(_basic_init)
170
+
171
+ nn.init.normal_(self.time_embedder.mlp[0].weight, std=0.02)
172
+ nn.init.normal_(self.time_embedder.mlp[2].weight, std=0.02)
173
+
174
+ for block in self.blocks:
175
+ if hasattr(block, "adaLN_modulation"):
176
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
177
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
178
+
179
+ nn.init.constant_(self.output_layer.adaLN_modulation[-1].weight, 0)
180
+ nn.init.constant_(self.output_layer.adaLN_modulation[-1].bias, 0)
181
+ nn.init.constant_(self.output_layer.linear.weight, 0)
182
+ nn.init.constant_(self.output_layer.linear.bias, 0)
183
+
184
+ def forward(
185
+ self,
186
+ x,
187
+ timesteps,
188
+ duration: torch.Tensor | None = None,
189
+ mask=None,
190
+ attn_mask=None,
191
+ g_cond: torch.Tensor | None = None,
192
+ **kwargs,
193
+ ):
194
+ t = self.time_embedder(timesteps)
195
+ c = t
196
+ duration_embedder = getattr(self, "duration_embedder", None)
197
+ if duration_embedder is not None and duration is not None:
198
+ c = c + duration_embedder(duration)
199
+ if g_cond is not None:
200
+ c = c + g_cond
201
+
202
+ x = self.input_layer(x)
203
+ for block in self.blocks:
204
+ x = block(x, c, mask=attn_mask, **kwargs)
205
+ return self.output_layer(x, c, **kwargs)
src/dots_tts/modules/backbone/layers.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+
6
+
7
+ class Dropout(nn.Module):
8
+ def __init__(
9
+ self, p: float = 0.5, inplace: bool = False, force_drop: bool = False, **_kwargs
10
+ ):
11
+ super().__init__()
12
+ if p < 0.0 or p > 1.0:
13
+ raise ValueError(
14
+ f"dropout probability has to be between 0 and 1, but got {p}"
15
+ )
16
+ self.p = p
17
+ self.inplace = inplace
18
+ self.force_drop = force_drop
19
+
20
+ def forward(self, x, **_kwargs):
21
+ return F.dropout(
22
+ x,
23
+ p=self.p,
24
+ training=True if self.force_drop else self.training,
25
+ inplace=self.inplace,
26
+ )
27
+
28
+
29
+ class Conv1d(nn.Conv1d):
30
+ def __init__(
31
+ self,
32
+ in_channels: int,
33
+ out_channels: int,
34
+ kernel_size: int = 1,
35
+ stride: int = 1,
36
+ dilation: int = 1,
37
+ groups: int = 1,
38
+ padding_mode: str = "zeros",
39
+ bias: bool = True,
40
+ padding=None,
41
+ causal: bool = False,
42
+ **_kwargs,
43
+ ):
44
+ self.causal = causal
45
+ if padding is None:
46
+ if causal:
47
+ padding = 0
48
+ self.left_padding = dilation * (kernel_size - 1)
49
+ else:
50
+ padding = int((kernel_size * dilation - dilation) / 2)
51
+
52
+ super().__init__(
53
+ in_channels,
54
+ out_channels,
55
+ kernel_size,
56
+ stride=stride,
57
+ padding=padding,
58
+ dilation=dilation,
59
+ groups=groups,
60
+ padding_mode=padding_mode,
61
+ bias=bias,
62
+ )
63
+
64
+ self.in_channels = in_channels
65
+
66
+ def forward(self, x):
67
+ if self.causal:
68
+ x = F.pad(x.unsqueeze(2), (self.left_padding, 0, 0, 0)).squeeze(2)
69
+ return super().forward(x)
70
+
71
+
72
+ class ConvTranspose1d(nn.ConvTranspose1d):
73
+ def __init__(
74
+ self,
75
+ in_channels: int,
76
+ out_channels: int,
77
+ kernel_size: int,
78
+ stride: int = 1,
79
+ output_padding: int = 0,
80
+ groups: int = 1,
81
+ bias: bool = True,
82
+ dilation: int = 1,
83
+ padding=None,
84
+ padding_mode: str = "zeros",
85
+ causal: bool = False,
86
+ **_kwargs,
87
+ ):
88
+ if padding is None:
89
+ padding = 0 if causal else (kernel_size - stride) // 2
90
+ if causal:
91
+ assert padding == 0, "padding is not allowed in causal ConvTranspose1d."
92
+ assert kernel_size == 2 * stride, (
93
+ "kernel_size must be equal to 2*stride in Causal ConvTranspose1d."
94
+ )
95
+
96
+ super().__init__(
97
+ in_channels,
98
+ out_channels,
99
+ kernel_size,
100
+ stride=stride,
101
+ padding=padding,
102
+ output_padding=output_padding,
103
+ groups=groups,
104
+ bias=bias,
105
+ dilation=dilation,
106
+ padding_mode=padding_mode,
107
+ )
108
+
109
+ self.causal = causal
110
+ self.stride = stride
111
+
112
+ def forward(self, x):
113
+ x = super().forward(x)
114
+ if self.causal:
115
+ x = x[:, :, : -self.stride]
116
+ return x
117
+
118
+
119
+ class Mlp(nn.Module):
120
+ def __init__(
121
+ self,
122
+ hidden_size,
123
+ ffn_hidden_size=4096,
124
+ act_layer=nn.GELU,
125
+ dropout=0.0,
126
+ **_kwargs,
127
+ ):
128
+ super().__init__()
129
+ self.fc1 = nn.Linear(hidden_size, ffn_hidden_size)
130
+ self.act = act_layer()
131
+ self.fc2 = nn.Linear(ffn_hidden_size, hidden_size)
132
+ self.drop = Dropout(dropout)
133
+
134
+ def forward(self, x, _mask=None):
135
+ x = self.fc1(x)
136
+ x = self.act(x)
137
+ x = self.drop(x)
138
+ x = self.fc2(x)
139
+ return self.drop(x)
140
+
141
+
142
+ def rotate_half(x):
143
+ x1, x2 = x.chunk(2, dim=-1)
144
+ return torch.cat((-x2, x1), dim=-1)
145
+
146
+
147
+ @torch.autocast(enabled=False, device_type="cuda")
148
+ def apply_rotary_pos_emb(pos, t):
149
+ if pos.dim() == 3:
150
+ pos = pos.unsqueeze(1)
151
+ return t * pos.cos() + rotate_half(t) * pos.sin()
152
+
153
+
154
+ class RotaryEmbedding(nn.Module):
155
+ def __init__(self, dim, theta=50000):
156
+ super().__init__()
157
+ self.register_buffer(
158
+ "inv_freq",
159
+ 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)),
160
+ persistent=False,
161
+ )
162
+ self._theta = float(theta)
163
+
164
+ def _apply(self, fn):
165
+ inv_freq = self.inv_freq
166
+ super()._apply(fn)
167
+ self.inv_freq = inv_freq.to(device=self.inv_freq.device, dtype=torch.float32)
168
+ return self
169
+
170
+ @torch.autocast(enabled=False, device_type="cuda")
171
+ def forward(self, t):
172
+ inv_freq = self.inv_freq
173
+ if inv_freq.device != t.device:
174
+ raise RuntimeError(
175
+ "RotaryEmbedding buffer device mismatch: "
176
+ f"inv_freq={inv_freq.device} input={t.device}."
177
+ )
178
+ t = t.to(dtype=inv_freq.dtype)
179
+ if t.dim() == 1:
180
+ freqs = torch.einsum("i , j -> i j", t, inv_freq)
181
+ else:
182
+ freqs = torch.einsum("bi, j -> bij", t, inv_freq)
183
+ return torch.cat((freqs, freqs), dim=-1)
184
+
185
+
186
+ class MultiHeadAttention(nn.Module):
187
+ """Multi-head attention"""
188
+
189
+ def __init__(
190
+ self,
191
+ hidden_size: int,
192
+ num_heads: int = 8,
193
+ qkv_bias: bool = False,
194
+ qk_norm: bool = False,
195
+ attn_drop: float = 0.0,
196
+ dropout: float = 0.0,
197
+ norm_layer: str = "LayerNorm",
198
+ rotary_bias: bool = False,
199
+ rotary_theta: float | None = 50000,
200
+ **_kwargs,
201
+ ):
202
+ super().__init__()
203
+ assert hidden_size % num_heads == 0, (
204
+ "hidden_size should be divisible by num_heads"
205
+ )
206
+ self.num_heads = num_heads
207
+ self.head_dim = hidden_size // num_heads
208
+ self.scale = self.head_dim**-0.5
209
+ self.rotary_bias = rotary_bias
210
+
211
+ self.q_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
212
+ self.k_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
213
+ self.v_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
214
+
215
+ norm_layer = getattr(nn, norm_layer)
216
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
217
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
218
+
219
+ self.attn_drop = Dropout(attn_drop)
220
+ self.o_proj = nn.Linear(hidden_size, hidden_size)
221
+ self.o_dropout = Dropout(dropout)
222
+
223
+ if self.rotary_bias:
224
+ self.rotary = RotaryEmbedding(self.head_dim, theta=rotary_theta)
225
+
226
+ def forward(self, q, k=None, v=None, mask=None, pos_ids=None, **_kwargs):
227
+ k = k or q
228
+ v = v or q
229
+ B, L, _ = q.shape
230
+ _, S, _ = v.shape
231
+ if mask is not None:
232
+ if mask.ndim == 2: # [B, L]
233
+ assert L == S
234
+ mask = rearrange(mask, "b j -> b 1 1 j")
235
+ mask = mask.expand(-1, self.num_heads, L, -1)
236
+ elif mask.ndim == 3: # [B, L, S]
237
+ assert mask.size(1) == L and mask.size(2) == S
238
+ mask = mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
239
+
240
+ q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
241
+ q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
242
+ k = rearrange(k, "b n (h d) -> b h n d", h=self.num_heads)
243
+ v = rearrange(v, "b n (h d) -> b h n d", h=self.num_heads)
244
+ q, k = self.q_norm(q), self.k_norm(k)
245
+
246
+ # Apply rotary
247
+ if self.rotary_bias:
248
+ if L == S:
249
+ if pos_ids is None:
250
+ rotary_emb = self.rotary(torch.arange(L, device=q.device))
251
+ else:
252
+ rotary_emb = self.rotary(pos_ids)
253
+ q, k = (apply_rotary_pos_emb(rotary_emb, tensor) for tensor in (q, k))
254
+ else:
255
+ q_rotary_emb = self.rotary(torch.arange(L, device=q.device))
256
+ k_rotary_emb = self.rotary(torch.arange(S, device=k.device))
257
+ q = apply_rotary_pos_emb(q_rotary_emb, q)
258
+ k = apply_rotary_pos_emb(k_rotary_emb, k)
259
+
260
+ attn_bias = torch.zeros(B, self.num_heads, L, S, dtype=q.dtype, device=q.device)
261
+
262
+ if mask is not None:
263
+ attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
264
+
265
+ out = F.scaled_dot_product_attention(
266
+ q,
267
+ k,
268
+ v,
269
+ attn_mask=attn_bias,
270
+ dropout_p=self.attn_drop.p if self.training else 0.0,
271
+ )
272
+
273
+ out = rearrange(out, "b h n d -> b n (h d)")
274
+ return self.o_dropout(self.o_proj(out))
275
+
276
+ def decode_step(self, x, *, cache, positions: torch.Tensor):
277
+ if x.size(1) <= 0:
278
+ raise ValueError("MultiHeadAttention.decode_step expects a non-empty input.")
279
+ if positions.ndim != 1 or positions.size(0) != x.size(1):
280
+ raise ValueError(
281
+ "MultiHeadAttention.decode_step positions must match the decode block length."
282
+ )
283
+
284
+ q = self.q_proj(x)
285
+ k = self.k_proj(x)
286
+ v = self.v_proj(x)
287
+
288
+ q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
289
+ k = rearrange(k, "b n (h d) -> b h n d", h=self.num_heads)
290
+ v = rearrange(v, "b n (h d) -> b h n d", h=self.num_heads)
291
+ q, k = self.q_norm(q), self.k_norm(k)
292
+ block_len = q.size(2)
293
+
294
+ if self.rotary_bias:
295
+ rotary_emb = self.rotary(positions)
296
+ q = apply_rotary_pos_emb(rotary_emb, q)
297
+ k = apply_rotary_pos_emb(rotary_emb, k)
298
+
299
+ cached_k, cached_v = cache
300
+ cached_k.index_copy_(2, positions, k)
301
+ cached_v.index_copy_(2, positions, v)
302
+
303
+ cache_capacity = cached_k.size(2)
304
+ key_positions = torch.arange(
305
+ cache_capacity,
306
+ device=x.device,
307
+ dtype=torch.long,
308
+ ).unsqueeze(0)
309
+ query_positions = positions.unsqueeze(1)
310
+ causal_mask = key_positions <= query_positions
311
+ valid_mask = key_positions <= positions[-1]
312
+ attn_bias = torch.zeros(
313
+ q.size(0),
314
+ self.num_heads,
315
+ block_len,
316
+ cache_capacity,
317
+ dtype=q.dtype,
318
+ device=q.device,
319
+ )
320
+ attn_bias.masked_fill_(
321
+ (causal_mask & valid_mask).unsqueeze(0).unsqueeze(0).logical_not(),
322
+ float("-inf"),
323
+ )
324
+
325
+ out = F.scaled_dot_product_attention(
326
+ q,
327
+ cached_k,
328
+ cached_v,
329
+ attn_mask=attn_bias,
330
+ dropout_p=self.attn_drop.p if self.training else 0.0,
331
+ )
332
+ out = rearrange(out, "b h n d -> b n (h d)")
333
+ return self.o_dropout(self.o_proj(out)), cache
src/dots_tts/modules/backbone/semantic_encoder.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+
10
+ from dots_tts.modules.backbone.layers import Conv1d, Mlp, MultiHeadAttention
11
+
12
+
13
+ @dataclass
14
+ class SemanticEncoderDecodeState:
15
+ conv_tail: torch.Tensor
16
+ layer_caches: tuple[tuple[torch.Tensor, torch.Tensor], ...]
17
+ seq_len: int
18
+
19
+
20
+ class TransformerEncoderLayer(nn.Module):
21
+ def __init__(
22
+ self,
23
+ hidden_size,
24
+ num_heads=16,
25
+ ffn_hidden_size=4096,
26
+ attn_dropout=0.0,
27
+ ffn_dropout=0.0,
28
+ norm_layer="LayerNorm",
29
+ **kwargs,
30
+ ):
31
+ super().__init__()
32
+ self.attn = MultiHeadAttention(
33
+ hidden_size,
34
+ num_heads,
35
+ attn_drop=attn_dropout,
36
+ norm_layer=norm_layer,
37
+ **kwargs,
38
+ )
39
+ norm_cls = getattr(nn, norm_layer)
40
+ self.attn_norm = norm_cls(hidden_size)
41
+ self.ffn = Mlp(
42
+ hidden_size, ffn_hidden_size, dropout=ffn_dropout, act_layer=nn.SiLU
43
+ )
44
+ self.ffn_norm = norm_cls(hidden_size)
45
+ self.hidden_size = hidden_size
46
+
47
+ def _build_causal_mask(self, T: int, device):
48
+ return torch.tril(torch.ones(T, T, dtype=torch.bool, device=device))
49
+
50
+ def _build_padding_mask(self, x_lens, max_len: int, device):
51
+ B = x_lens.size(0)
52
+ positions = torch.arange(max_len, device=device).unsqueeze(0).expand(B, -1)
53
+ return positions < x_lens.unsqueeze(1)
54
+
55
+ def _fuse_attn_mask(self, causal_mask, padding_mask):
56
+ if causal_mask is None and padding_mask is None:
57
+ return None
58
+ if causal_mask is None:
59
+ row = padding_mask.unsqueeze(2)
60
+ col = padding_mask.unsqueeze(1)
61
+ return row & col
62
+ if padding_mask is None:
63
+ return causal_mask.unsqueeze(0)
64
+
65
+ _B, _T = padding_mask.shape
66
+ causal = causal_mask.unsqueeze(0)
67
+ row = padding_mask.unsqueeze(2)
68
+ col = padding_mask.unsqueeze(1)
69
+ pad_2d = row & col
70
+ return causal & pad_2d
71
+
72
+ def forward(
73
+ self,
74
+ x,
75
+ x_lens=None,
76
+ causal=True,
77
+ ):
78
+ _B, T, C = x.shape
79
+ assert self.hidden_size == C
80
+ device = x.device
81
+
82
+ causal_mask = self._build_causal_mask(T, device) if causal else None
83
+ if x_lens is not None:
84
+ padding_mask = self._build_padding_mask(x_lens, T, device)
85
+ else:
86
+ padding_mask = None
87
+ fused_mask = self._fuse_attn_mask(causal_mask, padding_mask)
88
+
89
+ h = self.attn_norm(x)
90
+ h = self.attn(
91
+ q=h,
92
+ mask=fused_mask,
93
+ )
94
+ x = x + h
95
+
96
+ h = self.ffn_norm(x)
97
+ h = self.ffn(h)
98
+ return x + h
99
+
100
+ def decode_step(
101
+ self,
102
+ x,
103
+ *,
104
+ cache: tuple[torch.Tensor, torch.Tensor],
105
+ positions: torch.Tensor,
106
+ ):
107
+ if x.size(1) <= 0:
108
+ raise ValueError(
109
+ "TransformerEncoderLayer.decode_step expects a non-empty input."
110
+ )
111
+
112
+ h = self.attn_norm(x)
113
+ h, cache = self.attn.decode_step(h, cache=cache, positions=positions)
114
+ x = x + h
115
+
116
+ h = self.ffn_norm(x)
117
+ h = self.ffn(h)
118
+ return x + h, cache
119
+
120
+
121
+ class SuperviseEncoder(nn.Module):
122
+ def __init__(self, config):
123
+ super().__init__()
124
+ self.hidden_size = config.get("hidden_size", 1024)
125
+ self.layers = nn.ModuleList(
126
+ [
127
+ TransformerEncoderLayer(
128
+ hidden_size=self.hidden_size,
129
+ num_heads=config.get("num_heads", 16),
130
+ ffn_hidden_size=config.get("ffn_hidden_size", 4096),
131
+ norm_layer=config.get("norm_layer", "LayerNorm"),
132
+ )
133
+ for _ in range(config.get("num_layers", 6))
134
+ ]
135
+ )
136
+ self.causal = config.get("causal", False)
137
+
138
+ def forward(self, x, x_lens=None):
139
+ batch_size, seq_len, _ = x.shape
140
+ if x_lens is None:
141
+ x_lens = torch.full(
142
+ (batch_size,), seq_len, device=x.device, dtype=torch.long
143
+ )
144
+ for layer in self.layers:
145
+ x = layer(x, x_lens=x_lens, causal=self.causal)
146
+ return x
147
+
148
+ def init_decode_state(
149
+ self,
150
+ *,
151
+ batch_size: int,
152
+ max_seq_len: int,
153
+ device: torch.device,
154
+ dtype: torch.dtype,
155
+ ):
156
+ layer_caches = []
157
+ for layer in self.layers:
158
+ cache_shape = (
159
+ batch_size,
160
+ layer.attn.num_heads,
161
+ max_seq_len,
162
+ layer.attn.head_dim,
163
+ )
164
+ layer_caches.append(
165
+ (
166
+ torch.zeros(cache_shape, dtype=dtype, device=device),
167
+ torch.zeros(cache_shape, dtype=dtype, device=device),
168
+ )
169
+ )
170
+ return tuple(layer_caches)
171
+
172
+ def reset_decode_state(
173
+ self,
174
+ layer_caches: tuple[tuple[torch.Tensor, torch.Tensor], ...],
175
+ ) -> None:
176
+ if len(layer_caches) != len(self.layers):
177
+ raise ValueError("Layer cache count does not match encoder depth.")
178
+ for key_cache, value_cache in layer_caches:
179
+ key_cache.zero_()
180
+ value_cache.zero_()
181
+
182
+ def decode_step(self, x, *, layer_caches, positions: torch.Tensor):
183
+ if len(layer_caches) != len(self.layers):
184
+ raise ValueError("Layer cache count does not match encoder depth.")
185
+
186
+ for layer, cache in zip(self.layers, layer_caches, strict=True):
187
+ x, _ = layer.decode_step(x, cache=cache, positions=positions)
188
+ return x
189
+
190
+
191
+ class VAESemanticEncoder(nn.Module):
192
+ def __init__(self, in_dim, out_dim, config):
193
+ super().__init__()
194
+ in_ds_rate = 2
195
+ self.patch_size = int(config.patch_size)
196
+ self.in_ds_rate = in_ds_rate
197
+ self.ds_proj = Conv1d(
198
+ in_dim, in_dim, kernel_size=in_ds_rate, stride=in_ds_rate, causal=True
199
+ )
200
+ self.in_proj = nn.Linear(in_dim, config.PatchEncoder.hidden_size)
201
+ self.encoder = SuperviseEncoder(config.PatchEncoder)
202
+ self.out_ds_rate = self.patch_size // in_ds_rate
203
+ self.out_proj = nn.Linear(
204
+ config.PatchEncoder.hidden_size * self.out_ds_rate, out_dim
205
+ )
206
+
207
+ def forward(self, x, x_lens=None):
208
+ x = self._downsample(x)
209
+ x = self.in_proj(x)
210
+ z = self.encoder(x, x_lens=x_lens)
211
+ return self._project_embeddings(z)
212
+
213
+ def init_decode_state(
214
+ self,
215
+ *,
216
+ max_audio_patch_count: int,
217
+ batch_size: int,
218
+ device: torch.device,
219
+ dtype: torch.dtype,
220
+ ) -> SemanticEncoderDecodeState:
221
+ return SemanticEncoderDecodeState(
222
+ conv_tail=torch.zeros(
223
+ (batch_size, self.ds_proj.in_channels, self.ds_proj.left_padding),
224
+ dtype=dtype,
225
+ device=device,
226
+ ),
227
+ layer_caches=self.encoder.init_decode_state(
228
+ batch_size=batch_size,
229
+ max_seq_len=max_audio_patch_count * self.out_ds_rate,
230
+ device=device,
231
+ dtype=dtype,
232
+ ),
233
+ seq_len=0,
234
+ )
235
+
236
+ def reset_decode_state(self, state: SemanticEncoderDecodeState) -> None:
237
+ state.conv_tail.zero_()
238
+ self.encoder.reset_decode_state(state.layer_caches)
239
+ state.seq_len = 0
240
+
241
+ def prefill(
242
+ self,
243
+ x,
244
+ state: SemanticEncoderDecodeState,
245
+ ) -> tuple[torch.Tensor, SemanticEncoderDecodeState]:
246
+ if x.ndim != 3:
247
+ raise ValueError(
248
+ f"VAESemanticEncoder.prefill expects rank-3 input, got {tuple(x.shape)}."
249
+ )
250
+ if x.size(1) % self.patch_size != 0:
251
+ raise ValueError(
252
+ f"Prompt latent length {x.size(1)} must be divisible by patch_size={self.patch_size}."
253
+ )
254
+
255
+ if x.size(1) == 0:
256
+ return (
257
+ x.new_zeros((x.size(0), 0, self.out_proj.out_features)),
258
+ state,
259
+ )
260
+ if state.conv_tail.size(0) != x.size(0):
261
+ raise ValueError(
262
+ "VAESemanticEncoder.prefill batch size does not match decode state."
263
+ )
264
+
265
+ step_inputs = self.in_proj(self._downsample(x))
266
+ expected_token_count = (x.size(1) // self.patch_size) * self.out_ds_rate
267
+ if step_inputs.size(1) != expected_token_count:
268
+ raise RuntimeError(
269
+ "Patch encoder prefill produced an unexpected token count: "
270
+ f"expected={expected_token_count} actual={step_inputs.size(1)}."
271
+ )
272
+
273
+ current_seq_len = state.seq_len
274
+ next_seq_len = current_seq_len + step_inputs.size(1)
275
+ cache_capacity = state.layer_caches[0][0].size(2)
276
+ if next_seq_len > cache_capacity:
277
+ raise ValueError(
278
+ "Patch encoder prefill exceeds decode-state capacity: "
279
+ f"required={next_seq_len} capacity={cache_capacity}."
280
+ )
281
+
282
+ positions = (
283
+ torch.arange(step_inputs.size(1), device=x.device, dtype=torch.long)
284
+ + current_seq_len
285
+ )
286
+ encoded = self.encoder.decode_step(
287
+ step_inputs,
288
+ layer_caches=state.layer_caches,
289
+ positions=positions,
290
+ )
291
+ embedding = self._project_embeddings(encoded)
292
+ raw = x.transpose(1, 2)
293
+ state.conv_tail.copy_(raw[..., -self.ds_proj.left_padding :])
294
+ state.seq_len = next_seq_len
295
+ return embedding, state
296
+
297
+ def decode_patch(
298
+ self,
299
+ latent_patch,
300
+ conv_tail: torch.Tensor,
301
+ layer_caches: tuple[tuple[torch.Tensor, torch.Tensor], ...],
302
+ positions: torch.Tensor,
303
+ ) -> tuple[torch.Tensor, torch.Tensor]:
304
+ if latent_patch.ndim != 3:
305
+ raise ValueError(
306
+ f"VAESemanticEncoder.decode_patch expects rank-3 input, got {tuple(latent_patch.shape)}."
307
+ )
308
+ if latent_patch.size(1) != self.patch_size:
309
+ raise ValueError(
310
+ f"decode_patch expects patch length {self.patch_size}, got {latent_patch.size(1)}."
311
+ )
312
+ if positions.ndim != 1 or positions.size(0) != self.out_ds_rate:
313
+ raise ValueError(
314
+ "decode_patch positions must be a rank-1 tensor matching out_ds_rate."
315
+ )
316
+
317
+ step_inputs, conv_tail = self._downsample_step(
318
+ latent_patch,
319
+ conv_tail=conv_tail,
320
+ )
321
+ if step_inputs.size(1) != self.out_ds_rate:
322
+ raise RuntimeError(
323
+ f"Downsample step produced {step_inputs.size(1)} tokens, expected {self.out_ds_rate}."
324
+ )
325
+
326
+ encoded = self.encoder.decode_step(
327
+ step_inputs,
328
+ layer_caches=layer_caches,
329
+ positions=positions,
330
+ )
331
+ embedding = self._project_embeddings(encoded)
332
+ return embedding, conv_tail
333
+
334
+ def _downsample(self, x):
335
+ return self.ds_proj(x.transpose(1, 2)).transpose(1, 2)
336
+
337
+ def _project_embeddings(self, z):
338
+ if self.out_ds_rate > 1:
339
+ z = rearrange(z, "b (s d) h -> b s (d h)", d=self.out_ds_rate)
340
+ return self.out_proj(z)
341
+
342
+ def _downsample_step(self, latent_patch, *, conv_tail):
343
+ raw = latent_patch.transpose(1, 2)
344
+ conv_input = torch.cat([conv_tail, raw], dim=-1)
345
+
346
+ projected = F.conv1d(
347
+ conv_input,
348
+ self.ds_proj.weight,
349
+ self.ds_proj.bias,
350
+ stride=self.ds_proj.stride[0],
351
+ padding=0,
352
+ dilation=self.ds_proj.dilation[0],
353
+ groups=self.ds_proj.groups,
354
+ ).transpose(1, 2)
355
+ new_conv_tail = raw[..., -self.ds_proj.left_padding :]
356
+ return self.in_proj(projected), new_conv_tail
src/dots_tts/modules/speaker/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Speaker modules."""
src/dots_tts/modules/speaker/campplus.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ from collections import OrderedDict
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+ from dots_tts.modules.speaker.campplus_layers import (
11
+ BasicResBlock,
12
+ CAMDenseTDNNBlock,
13
+ DenseLayer,
14
+ StatsPool,
15
+ TDNNLayer,
16
+ TransitLayer,
17
+ get_nonlinear,
18
+ )
19
+ from dots_tts.modules.speaker.fbank import _SPEAKER_FBANK_N_MELS
20
+
21
+
22
+ class FCM(nn.Module):
23
+ def __init__(
24
+ self,
25
+ block=BasicResBlock,
26
+ num_blocks=(2, 2),
27
+ m_channels=32,
28
+ feat_dim=_SPEAKER_FBANK_N_MELS,
29
+ ):
30
+ super().__init__()
31
+ self.in_planes = m_channels
32
+ self.conv1 = nn.Conv2d(
33
+ 1, m_channels, kernel_size=3, stride=1, padding=1, bias=False
34
+ )
35
+ self.bn1 = nn.BatchNorm2d(m_channels)
36
+
37
+ self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
38
+ self.layer2 = self._make_layer(block, m_channels, num_blocks[1], stride=2)
39
+
40
+ self.conv2 = nn.Conv2d(
41
+ m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False
42
+ )
43
+ self.bn2 = nn.BatchNorm2d(m_channels)
44
+ self.out_channels = m_channels * (feat_dim // 8)
45
+
46
+ def _make_layer(self, block, planes, num_blocks, stride):
47
+ strides = [stride] + [1] * (num_blocks - 1)
48
+ layers = []
49
+ for stride in strides:
50
+ layers.append(block(self.in_planes, planes, stride))
51
+ self.in_planes = planes * block.expansion
52
+ return nn.Sequential(*layers)
53
+
54
+ def forward(self, x):
55
+ x = x.unsqueeze(1)
56
+ out = F.relu(self.bn1(self.conv1(x)))
57
+ out = self.layer1(out)
58
+ out = self.layer2(out)
59
+ out = F.relu(self.bn2(self.conv2(out)))
60
+
61
+ shape = out.shape
62
+ return out.reshape(shape[0], shape[1] * shape[2], shape[3])
63
+
64
+
65
+ class CAMPPlus(nn.Module):
66
+ _TDNN_KERNEL_SIZE = 5
67
+ _TDNN_STRIDE = 2
68
+ _TDNN_PADDING = 2
69
+
70
+ def __init__(
71
+ self,
72
+ feat_dim=_SPEAKER_FBANK_N_MELS,
73
+ embedding_size=512,
74
+ growth_rate=32,
75
+ bn_size=4,
76
+ init_channels=128,
77
+ config_str="batchnorm-relu",
78
+ memory_efficient=True,
79
+ ):
80
+ super().__init__()
81
+
82
+ self.head = FCM(feat_dim=feat_dim)
83
+ channels = self.head.out_channels
84
+
85
+ self.xvector = nn.Sequential(
86
+ OrderedDict(
87
+ [
88
+ (
89
+ "tdnn",
90
+ TDNNLayer(
91
+ channels,
92
+ init_channels,
93
+ self._TDNN_KERNEL_SIZE,
94
+ stride=self._TDNN_STRIDE,
95
+ dilation=1,
96
+ padding=-1,
97
+ config_str=config_str,
98
+ ),
99
+ ),
100
+ ]
101
+ )
102
+ )
103
+ channels = init_channels
104
+ for i, (num_layers, kernel_size, dilation) in enumerate(
105
+ zip((12, 24, 16), (3, 3, 3), (1, 2, 2), strict=True)
106
+ ):
107
+ block = CAMDenseTDNNBlock(
108
+ num_layers=num_layers,
109
+ in_channels=channels,
110
+ out_channels=growth_rate,
111
+ bn_channels=bn_size * growth_rate,
112
+ kernel_size=kernel_size,
113
+ dilation=dilation,
114
+ config_str=config_str,
115
+ memory_efficient=memory_efficient,
116
+ )
117
+ self.xvector.add_module(f"block{i + 1}", block)
118
+ channels = channels + num_layers * growth_rate
119
+ self.xvector.add_module(
120
+ f"transit{i + 1}",
121
+ TransitLayer(
122
+ channels, channels // 2, bias=False, config_str=config_str
123
+ ),
124
+ )
125
+ channels //= 2
126
+
127
+ self.xvector.add_module("out_nonlinear", get_nonlinear(config_str, channels))
128
+
129
+ self.xvector.add_module("stats", StatsPool())
130
+ self.xvector.add_module(
131
+ "dense", DenseLayer(channels * 2, embedding_size, config_str="batchnorm_")
132
+ )
133
+
134
+ for m in self.modules():
135
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
136
+ nn.init.kaiming_normal_(m.weight.data)
137
+ if m.bias is not None:
138
+ nn.init.zeros_(m.bias)
139
+
140
+ @staticmethod
141
+ def _conv_output_lengths(lengths, kernel_size, stride=1, padding=0, dilation=1):
142
+ return (
143
+ torch.div(
144
+ lengths + 2 * padding - dilation * (kernel_size - 1) - 1,
145
+ stride,
146
+ rounding_mode="floor",
147
+ )
148
+ + 1
149
+ )
150
+
151
+ @staticmethod
152
+ def _make_length_mask(lengths, max_len, device):
153
+ lengths = lengths.to(device=device, dtype=torch.long).clamp(min=0, max=max_len)
154
+ return torch.arange(max_len, device=device).unsqueeze(0) < lengths.unsqueeze(1)
155
+
156
+ def _masked_stats_pooling(self, x, lengths, unbiased=True, eps=1e-2):
157
+ lengths = lengths.to(device=x.device, dtype=torch.long).clamp(
158
+ min=1, max=x.size(-1)
159
+ )
160
+ mask = self._make_length_mask(lengths, x.size(-1), x.device).unsqueeze(1)
161
+ mask = mask.to(dtype=x.dtype)
162
+
163
+ denom = lengths.to(dtype=x.dtype).view(-1, 1).clamp_min(1.0)
164
+ mean = (x * mask).sum(dim=-1) / denom
165
+
166
+ centered = (x - mean.unsqueeze(-1)) * mask
167
+ var_denom = (
168
+ (lengths - 1).clamp_min(1).to(dtype=x.dtype).view(-1, 1)
169
+ if unbiased
170
+ else denom
171
+ )
172
+ var = centered.pow(2).sum(dim=-1) / var_denom
173
+ std = torch.sqrt(var.clamp_min(eps))
174
+ return torch.cat([mean, std], dim=1)
175
+
176
+ def forward(self, x, lengths=None):
177
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
178
+ x = self.head(x)
179
+ if lengths is not None:
180
+ lengths = lengths.to(device=x.device, dtype=torch.long).clamp(min=1)
181
+
182
+ for name, module in self.xvector.named_children():
183
+ if name == "stats":
184
+ x = (
185
+ self._masked_stats_pooling(x, lengths)
186
+ if lengths is not None
187
+ else module(x)
188
+ )
189
+ continue
190
+
191
+ x = module(x)
192
+ if name == "tdnn" and lengths is not None:
193
+ lengths = self._conv_output_lengths(
194
+ lengths,
195
+ kernel_size=self._TDNN_KERNEL_SIZE,
196
+ stride=self._TDNN_STRIDE,
197
+ padding=self._TDNN_PADDING,
198
+ )
199
+
200
+ return x
src/dots_tts/modules/speaker/campplus_layers.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import torch.utils.checkpoint as cp
7
+ from torch import nn
8
+
9
+
10
+ def get_nonlinear(config_str, channels):
11
+ nonlinear = nn.Sequential()
12
+ for name in config_str.split("-"):
13
+ if name == "relu":
14
+ nonlinear.add_module("relu", nn.ReLU(inplace=True))
15
+ elif name == "prelu":
16
+ nonlinear.add_module("prelu", nn.PReLU(channels))
17
+ elif name == "batchnorm":
18
+ nonlinear.add_module("batchnorm", nn.BatchNorm1d(channels))
19
+ elif name == "batchnorm_":
20
+ nonlinear.add_module("batchnorm", nn.BatchNorm1d(channels, affine=False))
21
+ else:
22
+ raise ValueError(f"Unexpected module ({name}).")
23
+ return nonlinear
24
+
25
+
26
+ def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, _eps=1e-2):
27
+ mean = x.mean(dim=dim)
28
+ std = x.std(dim=dim, unbiased=unbiased)
29
+ stats = torch.cat([mean, std], dim=-1)
30
+ if keepdim:
31
+ stats = stats.unsqueeze(dim=dim)
32
+ return stats
33
+
34
+
35
+ class StatsPool(nn.Module):
36
+ def forward(self, x):
37
+ return statistics_pooling(x)
38
+
39
+
40
+ class TDNNLayer(nn.Module):
41
+ def __init__(
42
+ self,
43
+ in_channels,
44
+ out_channels,
45
+ kernel_size,
46
+ stride=1,
47
+ padding=0,
48
+ dilation=1,
49
+ bias=False,
50
+ config_str="batchnorm-relu",
51
+ ):
52
+ super().__init__()
53
+ if padding < 0:
54
+ assert kernel_size % 2 == 1, (
55
+ f"Expect equal paddings, but got even kernel size ({kernel_size})"
56
+ )
57
+ padding = (kernel_size - 1) // 2 * dilation
58
+ self.linear = nn.Conv1d(
59
+ in_channels,
60
+ out_channels,
61
+ kernel_size,
62
+ stride=stride,
63
+ padding=padding,
64
+ dilation=dilation,
65
+ bias=bias,
66
+ )
67
+ self.nonlinear = get_nonlinear(config_str, out_channels)
68
+
69
+ def forward(self, x):
70
+ x = self.linear(x)
71
+ return self.nonlinear(x)
72
+
73
+
74
+ class CAMLayer(nn.Module):
75
+ def __init__(
76
+ self,
77
+ bn_channels,
78
+ out_channels,
79
+ kernel_size,
80
+ stride,
81
+ padding,
82
+ dilation,
83
+ bias,
84
+ reduction=2,
85
+ ):
86
+ super().__init__()
87
+ self.linear_local = nn.Conv1d(
88
+ bn_channels,
89
+ out_channels,
90
+ kernel_size,
91
+ stride=stride,
92
+ padding=padding,
93
+ dilation=dilation,
94
+ bias=bias,
95
+ )
96
+ self.linear1 = nn.Conv1d(bn_channels, bn_channels // reduction, 1)
97
+ self.relu = nn.ReLU(inplace=True)
98
+ self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1)
99
+ self.sigmoid = nn.Sigmoid()
100
+
101
+ def forward(self, x):
102
+ y = self.linear_local(x)
103
+ context = x.mean(-1, keepdim=True) + self.seg_pooling(x)
104
+ context = self.relu(self.linear1(context))
105
+ m = self.sigmoid(self.linear2(context))
106
+ return y * m
107
+
108
+ def seg_pooling(self, x, seg_len=100, stype="avg"):
109
+ if stype == "avg":
110
+ seg = F.avg_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
111
+ elif stype == "max":
112
+ seg = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
113
+ else:
114
+ raise ValueError("Wrong segment pooling type.")
115
+ shape = seg.shape
116
+ seg = seg.unsqueeze(-1).expand(*shape, seg_len).reshape(*shape[:-1], -1)
117
+ return seg[..., : x.shape[-1]]
118
+
119
+
120
+ class CAMDenseTDNNLayer(nn.Module):
121
+ def __init__(
122
+ self,
123
+ in_channels,
124
+ out_channels,
125
+ bn_channels,
126
+ kernel_size,
127
+ stride=1,
128
+ dilation=1,
129
+ bias=False,
130
+ config_str="batchnorm-relu",
131
+ memory_efficient=False,
132
+ ):
133
+ super().__init__()
134
+ assert kernel_size % 2 == 1, (
135
+ f"Expect equal paddings, but got even kernel size ({kernel_size})"
136
+ )
137
+ padding = (kernel_size - 1) // 2 * dilation
138
+ self.memory_efficient = memory_efficient
139
+ self.nonlinear1 = get_nonlinear(config_str, in_channels)
140
+ self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
141
+ self.nonlinear2 = get_nonlinear(config_str, bn_channels)
142
+ self.cam_layer = CAMLayer(
143
+ bn_channels,
144
+ out_channels,
145
+ kernel_size,
146
+ stride=stride,
147
+ padding=padding,
148
+ dilation=dilation,
149
+ bias=bias,
150
+ )
151
+
152
+ def bn_function(self, x):
153
+ return self.linear1(self.nonlinear1(x))
154
+
155
+ def forward(self, x):
156
+ if self.training and self.memory_efficient:
157
+ x = cp.checkpoint(self.bn_function, x)
158
+ else:
159
+ x = self.bn_function(x)
160
+ return self.cam_layer(self.nonlinear2(x))
161
+
162
+
163
+ class CAMDenseTDNNBlock(nn.ModuleList):
164
+ def __init__(
165
+ self,
166
+ num_layers,
167
+ in_channels,
168
+ out_channels,
169
+ bn_channels,
170
+ kernel_size,
171
+ stride=1,
172
+ dilation=1,
173
+ bias=False,
174
+ config_str="batchnorm-relu",
175
+ memory_efficient=False,
176
+ ):
177
+ super().__init__()
178
+ for i in range(num_layers):
179
+ layer = CAMDenseTDNNLayer(
180
+ in_channels=in_channels + i * out_channels,
181
+ out_channels=out_channels,
182
+ bn_channels=bn_channels,
183
+ kernel_size=kernel_size,
184
+ stride=stride,
185
+ dilation=dilation,
186
+ bias=bias,
187
+ config_str=config_str,
188
+ memory_efficient=memory_efficient,
189
+ )
190
+ self.add_module(f"tdnnd{i + 1}", layer)
191
+
192
+ def forward(self, x):
193
+ for layer in self:
194
+ x = torch.cat([x, layer(x)], dim=1)
195
+ return x
196
+
197
+
198
+ class TransitLayer(nn.Module):
199
+ def __init__(
200
+ self, in_channels, out_channels, bias=True, config_str="batchnorm-relu"
201
+ ):
202
+ super().__init__()
203
+ self.nonlinear = get_nonlinear(config_str, in_channels)
204
+ self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
205
+
206
+ def forward(self, x):
207
+ x = self.nonlinear(x)
208
+ return self.linear(x)
209
+
210
+
211
+ class DenseLayer(nn.Module):
212
+ def __init__(
213
+ self, in_channels, out_channels, bias=False, config_str="batchnorm-relu"
214
+ ):
215
+ super().__init__()
216
+ self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
217
+ self.nonlinear = get_nonlinear(config_str, out_channels)
218
+
219
+ def forward(self, x):
220
+ if len(x.shape) == 2:
221
+ x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
222
+ else:
223
+ x = self.linear(x)
224
+ return self.nonlinear(x)
225
+
226
+
227
+ class BasicResBlock(nn.Module):
228
+ expansion = 1
229
+
230
+ def __init__(self, in_planes, planes, stride=1):
231
+ super().__init__()
232
+ self.conv1 = nn.Conv2d(
233
+ in_planes, planes, kernel_size=3, stride=(stride, 1), padding=1, bias=False
234
+ )
235
+ self.bn1 = nn.BatchNorm2d(planes)
236
+ self.conv2 = nn.Conv2d(
237
+ planes, planes, kernel_size=3, stride=1, padding=1, bias=False
238
+ )
239
+ self.bn2 = nn.BatchNorm2d(planes)
240
+
241
+ self.shortcut = nn.Sequential()
242
+ if stride != 1 or in_planes != self.expansion * planes:
243
+ self.shortcut = nn.Sequential(
244
+ nn.Conv2d(
245
+ in_planes,
246
+ self.expansion * planes,
247
+ kernel_size=1,
248
+ stride=(stride, 1),
249
+ bias=False,
250
+ ),
251
+ nn.BatchNorm2d(self.expansion * planes),
252
+ )
253
+
254
+ def forward(self, x):
255
+ out = F.relu(self.bn1(self.conv1(x)))
256
+ out = self.bn2(self.conv2(out))
257
+ out += self.shortcut(x)
258
+ return F.relu(out)
src/dots_tts/modules/speaker/encoder.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchaudio
7
+ from torch.nn.utils.rnn import pad_sequence
8
+
9
+ from dots_tts.modules.speaker.campplus import CAMPPlus
10
+ from dots_tts.modules.speaker.fbank import (
11
+ _SPEAKER_FBANK_N_MELS,
12
+ _SPEAKER_FBANK_SAMPLE_RATE,
13
+ extract_speaker_fbank,
14
+ )
15
+
16
+
17
+ class SpeakerXVectorFeatures(nn.Module):
18
+ """
19
+ Speaker embedding extractor based on 3D-Speaker CAM++.
20
+ """
21
+
22
+ def __init__(
23
+ self,
24
+ sample_rate=_SPEAKER_FBANK_SAMPLE_RATE,
25
+ campplus_embedding_size=512,
26
+ max_audio_seconds=10.0,
27
+ ):
28
+ super().__init__()
29
+
30
+ self.sample_rate = sample_rate
31
+ self.max_audio_seconds = float(max_audio_seconds)
32
+ self.model = CAMPPlus(
33
+ feat_dim=_SPEAKER_FBANK_N_MELS,
34
+ embedding_size=campplus_embedding_size,
35
+ )
36
+ self.resample = None
37
+ if self.sample_rate != _SPEAKER_FBANK_SAMPLE_RATE:
38
+ self.resample = torchaudio.transforms.Resample(
39
+ orig_freq=sample_rate,
40
+ new_freq=_SPEAKER_FBANK_SAMPLE_RATE,
41
+ )
42
+
43
+ for param in self.model.parameters():
44
+ param.requires_grad = False
45
+
46
+ @staticmethod
47
+ def _normalize_lengths(lengths, batch_size, max_length, device, *, min_length):
48
+ if lengths is None:
49
+ return torch.full(
50
+ (batch_size,),
51
+ max_length,
52
+ device=device,
53
+ dtype=torch.long,
54
+ )
55
+ return lengths.to(device=device, dtype=torch.long).clamp(
56
+ min=min_length,
57
+ max=max_length,
58
+ )
59
+
60
+ def _crop_audio(self, audio, audio_lengths=None):
61
+ original_lengths = self._normalize_lengths(
62
+ audio_lengths,
63
+ audio.size(0),
64
+ audio.size(-1),
65
+ audio.device,
66
+ min_length=0,
67
+ )
68
+ if self.max_audio_seconds <= 0:
69
+ return audio, original_lengths, original_lengths, torch.zeros_like(
70
+ original_lengths
71
+ )
72
+
73
+ max_input_length = round(self.sample_rate * self.max_audio_seconds)
74
+ cropped_audio = []
75
+ cropped_lengths = []
76
+ starts = []
77
+
78
+ for index, total_length_tensor in enumerate(original_lengths):
79
+ total_length = int(total_length_tensor.item())
80
+ cropped_length = min(total_length, max_input_length)
81
+ start = (
82
+ random.randint(0, total_length - cropped_length)
83
+ if total_length > cropped_length
84
+ else 0
85
+ )
86
+ cropped_audio.append(audio[index, start : start + cropped_length])
87
+ cropped_lengths.append(cropped_length)
88
+ starts.append(start)
89
+
90
+ return pad_sequence(
91
+ cropped_audio,
92
+ batch_first=True,
93
+ padding_value=0.0,
94
+ ), original_lengths, torch.tensor(
95
+ cropped_lengths,
96
+ device=audio.device,
97
+ dtype=torch.long,
98
+ ), torch.tensor(starts, device=audio.device, dtype=torch.long)
99
+
100
+ def _crop_fbank(
101
+ self,
102
+ fbank,
103
+ fbank_lengths,
104
+ original_audio_lengths,
105
+ cropped_audio_lengths,
106
+ starts,
107
+ ):
108
+ original_fbank_lengths = self._normalize_lengths(
109
+ fbank_lengths,
110
+ fbank.size(0),
111
+ fbank.size(1),
112
+ fbank.device,
113
+ min_length=1,
114
+ )
115
+ cropped_fbank = []
116
+ cropped_fbank_lengths = []
117
+
118
+ for index, total_feat_length_tensor in enumerate(original_fbank_lengths):
119
+ total_audio_length = int(original_audio_lengths[index].item())
120
+ total_feat_length = int(total_feat_length_tensor.item())
121
+ start_audio = int(starts[index].item())
122
+ end_audio = start_audio + int(cropped_audio_lengths[index].item())
123
+
124
+ if total_audio_length > 0:
125
+ start_feat = math.floor(
126
+ start_audio * total_feat_length / total_audio_length
127
+ )
128
+ end_feat = math.ceil(end_audio * total_feat_length / total_audio_length)
129
+ else:
130
+ start_feat = 0
131
+ end_feat = 1
132
+
133
+ start_feat = min(start_feat, total_feat_length - 1)
134
+ end_feat = min(max(end_feat, start_feat + 1), total_feat_length)
135
+ cropped_fbank.append(fbank[index, start_feat:end_feat])
136
+ cropped_fbank_lengths.append(end_feat - start_feat)
137
+
138
+ return pad_sequence(
139
+ cropped_fbank,
140
+ batch_first=True,
141
+ padding_value=0.0,
142
+ ), torch.tensor(
143
+ cropped_fbank_lengths,
144
+ device=fbank.device,
145
+ dtype=torch.long,
146
+ )
147
+
148
+ def _extract_fbank_batch(self, audio, audio_lengths):
149
+ if self.resample is not None:
150
+ audio = self.resample(audio)
151
+ audio_lengths = torch.ceil(
152
+ audio_lengths.float()
153
+ * (_SPEAKER_FBANK_SAMPLE_RATE / self.sample_rate)
154
+ ).long()
155
+
156
+ audio_cpu = audio.detach().cpu()
157
+ features = []
158
+
159
+ for index, valid_length_tensor in enumerate(audio_lengths):
160
+ valid_length = int(valid_length_tensor.item())
161
+ waveform = audio_cpu[index, :valid_length]
162
+ if waveform.numel() == 0:
163
+ waveform = audio_cpu.new_zeros(1)
164
+ features.append(
165
+ extract_speaker_fbank(
166
+ waveform,
167
+ sample_rate=_SPEAKER_FBANK_SAMPLE_RATE,
168
+ )
169
+ )
170
+
171
+ fbank_lengths = torch.tensor(
172
+ [feature.size(0) for feature in features],
173
+ device=audio.device,
174
+ dtype=torch.long,
175
+ )
176
+ fbank = pad_sequence(
177
+ features,
178
+ batch_first=True,
179
+ padding_value=0.0,
180
+ ).to(device=audio.device, dtype=audio.dtype)
181
+ return fbank, fbank_lengths
182
+
183
+ @torch.no_grad()
184
+ @torch.autocast(enabled=False, device_type="cuda")
185
+ def forward(
186
+ self, audio, audio_lengths=None, fbank=None, fbank_lengths=None, **_kwargs
187
+ ):
188
+ self.model.eval()
189
+ audio = audio.float()
190
+ if audio.dim() == 3:
191
+ if audio.size(1) != 1:
192
+ raise ValueError(
193
+ f"Speaker encoder expects mono audio, got shape {tuple(audio.shape)}."
194
+ )
195
+ audio = audio[:, 0]
196
+ elif audio.dim() != 2:
197
+ raise ValueError(
198
+ f"Speaker encoder expects a 2D or 3D audio tensor, got shape {tuple(audio.shape)}."
199
+ )
200
+
201
+ audio, original_audio_lengths, cropped_audio_lengths, starts = self._crop_audio(
202
+ audio,
203
+ audio_lengths=audio_lengths,
204
+ )
205
+
206
+ if fbank is None:
207
+ fbank, fbank_lengths = self._extract_fbank_batch(
208
+ audio,
209
+ cropped_audio_lengths,
210
+ )
211
+ else:
212
+ if not isinstance(fbank, torch.Tensor):
213
+ raise TypeError("Speaker encoder expects `fbank` to be a torch.Tensor.")
214
+ if fbank.dim() != 3 or fbank.size(0) != audio.size(0):
215
+ raise ValueError(
216
+ f"Speaker encoder expects `fbank` with shape (B, T, F) and matching batch size, got {tuple(fbank.shape)}."
217
+ )
218
+ fbank, fbank_lengths = self._crop_fbank(
219
+ fbank.to(device=audio.device, dtype=torch.float32),
220
+ fbank_lengths,
221
+ original_audio_lengths,
222
+ cropped_audio_lengths,
223
+ starts,
224
+ )
225
+
226
+ return self.model(fbank, lengths=fbank_lengths)
src/dots_tts/modules/speaker/fbank.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+
5
+ from dots_tts.utils.audio import extract_fbank, high_quality_resample
6
+
7
+ _SPEAKER_FBANK_SAMPLE_RATE = 16000
8
+ _SPEAKER_FBANK_N_MELS = 80
9
+ _SPEAKER_FBANK_MEAN_NORM = True
10
+ _SPEAKER_FBANK_DITHER = 0.0
11
+
12
+
13
+ def extract_speaker_fbank(
14
+ waveform: torch.Tensor,
15
+ *,
16
+ sample_rate: int,
17
+ ) -> torch.Tensor:
18
+ feature_input = waveform
19
+ if sample_rate != _SPEAKER_FBANK_SAMPLE_RATE:
20
+ feature_input = high_quality_resample(
21
+ waveform,
22
+ orig_sr=sample_rate,
23
+ target_sr=_SPEAKER_FBANK_SAMPLE_RATE,
24
+ )
25
+ return extract_fbank(
26
+ feature_input,
27
+ sample_rate=_SPEAKER_FBANK_SAMPLE_RATE,
28
+ n_mels=_SPEAKER_FBANK_N_MELS,
29
+ dither=_SPEAKER_FBANK_DITHER,
30
+ mean_norm=_SPEAKER_FBANK_MEAN_NORM,
31
+ )
src/dots_tts/modules/vocoder/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Vocoder modules."""