Spaces:
Running on Zero
Running on Zero
add inference code with AOTI support for hf space
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- LICENSE +201 -0
- README.md +7 -1
- app.py +176 -0
- apps/__init__.py +1 -0
- apps/gradio/__init__.py +1 -0
- apps/gradio/app.py +663 -0
- apps/gradio/constants.py +26 -0
- apps/gradio/default_prompts/prompt_text +0 -0
- apps/gradio/languages.py +115 -0
- apps/gradio/service.py +773 -0
- configs/dots_tts.yaml +76 -0
- requirements.txt +19 -0
- src/dots_tts/__init__.py +1 -0
- src/dots_tts/cli.py +152 -0
- src/dots_tts/config/__init__.py +1 -0
- src/dots_tts/config/app.py +32 -0
- src/dots_tts/config/base.py +64 -0
- src/dots_tts/config/data.py +63 -0
- src/dots_tts/config/train.py +28 -0
- src/dots_tts/data/EXTENSION.md +124 -0
- src/dots_tts/data/__init__.py +1 -0
- src/dots_tts/data/batchers.py +188 -0
- src/dots_tts/data/builders.py +194 -0
- src/dots_tts/data/collator.py +87 -0
- src/dots_tts/data/pipelines/__init__.py +1 -0
- src/dots_tts/data/pipelines/base.py +32 -0
- src/dots_tts/data/pipelines/preprocessing.py +84 -0
- src/dots_tts/data/pipelines/tokenizing.py +339 -0
- src/dots_tts/data/pipelines/tts_pipeline.py +132 -0
- src/dots_tts/data/source_adapters/__init__.py +1 -0
- src/dots_tts/data/source_adapters/base_adapter.py +91 -0
- src/dots_tts/data/source_adapters/jsonl_manifest_adapter.py +132 -0
- src/dots_tts/data/source_adapters/multi_source_adapter.py +222 -0
- src/dots_tts/data/streaming.py +400 -0
- src/dots_tts/models/__init__.py +1 -0
- src/dots_tts/models/dots_tts/__init__.py +1 -0
- src/dots_tts/models/dots_tts/config.py +71 -0
- src/dots_tts/models/dots_tts/core.py +910 -0
- src/dots_tts/models/dots_tts/model.py +1958 -0
- src/dots_tts/modules/__init__.py +0 -0
- src/dots_tts/modules/backbone/__init__.py +1 -0
- src/dots_tts/modules/backbone/dit.py +205 -0
- src/dots_tts/modules/backbone/layers.py +333 -0
- src/dots_tts/modules/backbone/semantic_encoder.py +356 -0
- src/dots_tts/modules/speaker/__init__.py +1 -0
- src/dots_tts/modules/speaker/campplus.py +200 -0
- src/dots_tts/modules/speaker/campplus_layers.py +258 -0
- src/dots_tts/modules/speaker/encoder.py +226 -0
- src/dots_tts/modules/speaker/fbank.py +31 -0
- src/dots_tts/modules/vocoder/__init__.py +1 -0
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright 2026 OpenMOSS Team, Fudan University, SII and MOSI
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
CHANGED
|
@@ -9,6 +9,12 @@ python_version: '3.12'
|
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
| 11 |
license: apache-2.0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
---
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
| 11 |
license: apache-2.0
|
| 12 |
+
tags:
|
| 13 |
+
- zerogpu
|
| 14 |
+
- aoti
|
| 15 |
+
- text-to-speech
|
| 16 |
---
|
| 17 |
|
| 18 |
+
dots.tts Gradio Space for Hugging Face ZeroGPU with optional PyTorch AOTInductor startup compilation.
|
| 19 |
+
|
| 20 |
+
Set `DOTS_TTS_MODEL_NAME_OR_PATH` to a local model directory or Hugging Face model repo id. The app defaults to `rednote-hilab/dots.tts`.
|
app.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any, Callable
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
REPO_ROOT = Path(__file__).resolve().parent
|
| 10 |
+
SRC_ROOT = REPO_ROOT / "src"
|
| 11 |
+
|
| 12 |
+
for import_root in (REPO_ROOT, SRC_ROOT):
|
| 13 |
+
import_root_str = str(import_root)
|
| 14 |
+
if import_root_str not in sys.path:
|
| 15 |
+
sys.path.insert(0, import_root_str)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class _SpacesFallback:
|
| 19 |
+
@staticmethod
|
| 20 |
+
def GPU(*decorator_args, **_decorator_kwargs):
|
| 21 |
+
if decorator_args and callable(decorator_args[0]):
|
| 22 |
+
return decorator_args[0]
|
| 23 |
+
|
| 24 |
+
def decorate(fn: Callable[..., Any]) -> Callable[..., Any]:
|
| 25 |
+
return fn
|
| 26 |
+
|
| 27 |
+
return decorate
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
import spaces # type: ignore
|
| 32 |
+
except Exception: # pragma: no cover - only used outside Hugging Face Spaces.
|
| 33 |
+
spaces = _SpacesFallback() # type: ignore
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _env_bool(name: str, default: bool) -> bool:
|
| 37 |
+
value = os.environ.get(name)
|
| 38 |
+
if value is None:
|
| 39 |
+
return default
|
| 40 |
+
return value.strip().lower() in {"1", "true", "yes", "on"}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _env_int(name: str, default: int) -> int:
|
| 44 |
+
value = os.environ.get(name)
|
| 45 |
+
if value is None or not value.strip():
|
| 46 |
+
return default
|
| 47 |
+
return int(value)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _configure_zero_gpu_environment() -> None:
|
| 51 |
+
os.environ.setdefault("DOTS_TTS_COMPILE_BACKEND", "aoti")
|
| 52 |
+
os.environ.setdefault("DOTS_TTS_SKIP_INIT_WARMUP", "1")
|
| 53 |
+
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _preload_runtime(app_service, app_config, compile_backend: str):
|
| 57 |
+
runtime, resolved_model_name_or_path = app_service._get_runtime( # noqa: SLF001
|
| 58 |
+
app_config.default_model_name_or_path,
|
| 59 |
+
)
|
| 60 |
+
runtime.optimize = bool(app_config.optimize)
|
| 61 |
+
runtime.model.set_optimize(bool(app_config.optimize))
|
| 62 |
+
if hasattr(runtime.model, "set_compile_backend"):
|
| 63 |
+
runtime.model.set_compile_backend(compile_backend)
|
| 64 |
+
return runtime, resolved_model_name_or_path
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def main() -> None:
|
| 68 |
+
_configure_zero_gpu_environment()
|
| 69 |
+
|
| 70 |
+
import gradio as gr
|
| 71 |
+
from loguru import logger
|
| 72 |
+
|
| 73 |
+
from apps.gradio.app import PLAYGROUND_CSS, build_demo, build_playground_theme
|
| 74 |
+
from apps.gradio.service import GradioAppService, build_gradio_app_config
|
| 75 |
+
from dots_tts.utils.logging import configure_logging
|
| 76 |
+
|
| 77 |
+
host = os.environ.get("DOTS_TTS_HOST", "0.0.0.0")
|
| 78 |
+
port = _env_int("DOTS_TTS_PORT", 7860)
|
| 79 |
+
model_name_or_path = os.environ.get(
|
| 80 |
+
"DOTS_TTS_MODEL_NAME_OR_PATH",
|
| 81 |
+
"rednote-hilab/dots.tts",
|
| 82 |
+
)
|
| 83 |
+
precision = os.environ.get("DOTS_TTS_PRECISION", "bfloat16")
|
| 84 |
+
execution_mode = os.environ.get("DOTS_TTS_EXECUTION_MODE", "generate_stream")
|
| 85 |
+
max_generate_length = _env_int("DOTS_TTS_MAX_GENERATE_LENGTH", 500)
|
| 86 |
+
default_num_steps = _env_int("DOTS_TTS_DEFAULT_NUM_STEPS", 10)
|
| 87 |
+
compile_backend = os.environ.get("DOTS_TTS_COMPILE_BACKEND", "aoti").strip().lower()
|
| 88 |
+
enable_aoti = _env_bool("DOTS_TTS_ENABLE_AOTI", True)
|
| 89 |
+
startup_compile = _env_bool("DOTS_TTS_AOTI_COMPILE_ON_STARTUP", True)
|
| 90 |
+
optimize = _env_bool("DOTS_TTS_OPTIMIZE", True)
|
| 91 |
+
generation_duration = _env_int("DOTS_TTS_ZERO_GPU_DURATION", 600)
|
| 92 |
+
compile_duration = _env_int("DOTS_TTS_ZERO_GPU_COMPILE_DURATION", 1500)
|
| 93 |
+
output_dir = Path(os.environ.get("DOTS_TTS_OUTPUT_DIR", "/tmp/dots_tts_outputs"))
|
| 94 |
+
log_file = Path(os.environ.get("DOTS_TTS_LOG_FILE", "/tmp/dots_tts_gradio.log"))
|
| 95 |
+
|
| 96 |
+
configure_logging(log_file=log_file)
|
| 97 |
+
logger.info(
|
| 98 |
+
"Space app starting: model={} execution_mode={} precision={} optimize={} "
|
| 99 |
+
"compile_backend={} enable_aoti={} startup_compile={} max_generate_length={}",
|
| 100 |
+
model_name_or_path,
|
| 101 |
+
execution_mode,
|
| 102 |
+
precision,
|
| 103 |
+
optimize,
|
| 104 |
+
compile_backend,
|
| 105 |
+
enable_aoti,
|
| 106 |
+
startup_compile,
|
| 107 |
+
max_generate_length,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
app_config = build_gradio_app_config(
|
| 111 |
+
host=host,
|
| 112 |
+
port=port,
|
| 113 |
+
execution_mode=execution_mode,
|
| 114 |
+
precision=precision,
|
| 115 |
+
optimize=optimize,
|
| 116 |
+
model_name_or_path=model_name_or_path,
|
| 117 |
+
output_dir=output_dir,
|
| 118 |
+
max_generate_length=max_generate_length,
|
| 119 |
+
default_num_steps=default_num_steps,
|
| 120 |
+
default_max_generate_length=max_generate_length,
|
| 121 |
+
repo_root=REPO_ROOT,
|
| 122 |
+
)
|
| 123 |
+
app_service = GradioAppService(app_config)
|
| 124 |
+
runtime, resolved_model_name_or_path = _preload_runtime(
|
| 125 |
+
app_service,
|
| 126 |
+
app_config,
|
| 127 |
+
compile_backend if enable_aoti else "torch_compile",
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
if enable_aoti and startup_compile and optimize:
|
| 131 |
+
|
| 132 |
+
@spaces.GPU(duration=compile_duration)
|
| 133 |
+
def compile_aoti_cache():
|
| 134 |
+
child_runtime, _ = _preload_runtime(
|
| 135 |
+
app_service,
|
| 136 |
+
app_config,
|
| 137 |
+
compile_backend,
|
| 138 |
+
)
|
| 139 |
+
child_runtime.model.run_warmup(
|
| 140 |
+
max_generate_length=app_config.max_generate_length,
|
| 141 |
+
precision=app_config.precision,
|
| 142 |
+
num_steps=app_config.default_num_steps,
|
| 143 |
+
guidance_scale=app_config.default_guidance_scale,
|
| 144 |
+
)
|
| 145 |
+
return child_runtime.model.export_compiled_models()
|
| 146 |
+
|
| 147 |
+
compiled_models = compile_aoti_cache()
|
| 148 |
+
if compiled_models:
|
| 149 |
+
runtime.model.import_compiled_models(compiled_models)
|
| 150 |
+
logger.info(
|
| 151 |
+
"AOTI startup compile completed: compiled_target_count={}",
|
| 152 |
+
len(compiled_models or {}),
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
app_service.generate = spaces.GPU(duration=generation_duration)(app_service.generate)
|
| 156 |
+
|
| 157 |
+
demo = build_demo(gr, app_config, app_service)
|
| 158 |
+
logger.info(
|
| 159 |
+
"Space app ready: host={} port={} resolved_model={} compiled_target_count={}",
|
| 160 |
+
app_config.host,
|
| 161 |
+
app_config.port,
|
| 162 |
+
resolved_model_name_or_path,
|
| 163 |
+
len(runtime.model.export_compiled_models())
|
| 164 |
+
if hasattr(runtime.model, "export_compiled_models")
|
| 165 |
+
else 0,
|
| 166 |
+
)
|
| 167 |
+
demo.launch(
|
| 168 |
+
server_name=app_config.host,
|
| 169 |
+
server_port=app_config.port,
|
| 170 |
+
theme=build_playground_theme(gr),
|
| 171 |
+
css=PLAYGROUND_CSS,
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
if __name__ == "__main__":
|
| 176 |
+
main()
|
apps/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Application entrypoints for dots.tts."""
|
apps/gradio/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Gradio application for dots.tts."""
|
apps/gradio/app.py
ADDED
|
@@ -0,0 +1,663 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import TYPE_CHECKING
|
| 8 |
+
|
| 9 |
+
REPO_ROOT = Path(__file__).resolve().parents[2]
|
| 10 |
+
SRC_ROOT = REPO_ROOT / "src"
|
| 11 |
+
|
| 12 |
+
for import_root in (REPO_ROOT, SRC_ROOT):
|
| 13 |
+
import_root_str = str(import_root)
|
| 14 |
+
if import_root_str not in sys.path:
|
| 15 |
+
sys.path.insert(0, import_root_str)
|
| 16 |
+
|
| 17 |
+
from apps.gradio.constants import ( # noqa: E402
|
| 18 |
+
DEFAULT_EXECUTION_MODE,
|
| 19 |
+
DEFAULT_GUIDANCE_SCALE,
|
| 20 |
+
DEFAULT_HOST,
|
| 21 |
+
DEFAULT_INPUT_TEXT,
|
| 22 |
+
DEFAULT_LOG_FILE,
|
| 23 |
+
DEFAULT_MAX_GENERATE_LENGTH,
|
| 24 |
+
DEFAULT_NUM_STEPS,
|
| 25 |
+
DEFAULT_ODE_METHOD,
|
| 26 |
+
DEFAULT_OUTPUT_DIR,
|
| 27 |
+
DEFAULT_OUTPUT_RETENTION,
|
| 28 |
+
DEFAULT_PORT,
|
| 29 |
+
DEFAULT_PRECISION,
|
| 30 |
+
DEFAULT_PROMPT_NAME,
|
| 31 |
+
DEFAULT_SEED,
|
| 32 |
+
DEFAULT_SPEAKER_SCALE,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
if TYPE_CHECKING:
|
| 36 |
+
import gradio as gr
|
| 37 |
+
|
| 38 |
+
DEBUG_GRADIO_ENABLED = os.environ.get("DEBUG_GRADIO", "0") == "1"
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
PLAYGROUND_CSS = """
|
| 42 |
+
.gradio-container {
|
| 43 |
+
width: min(1600px, calc(100vw - 32px)) !important;
|
| 44 |
+
max-width: none !important;
|
| 45 |
+
margin: 0 auto !important;
|
| 46 |
+
padding-left: 0 !important;
|
| 47 |
+
padding-right: 0 !important;
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
.gradio-container,
|
| 51 |
+
.gradio-container .gradio-container {
|
| 52 |
+
--block-label-background-fill: #CCE5FF;
|
| 53 |
+
--block-label-text-color: #6666FF;
|
| 54 |
+
--block-label-border-color: #99c7ee;
|
| 55 |
+
--block-label-text-weight: 600;
|
| 56 |
+
--block-title-background-fill: #CCE5FF;
|
| 57 |
+
--block-title-text-color: #6666FF;
|
| 58 |
+
--block-title-border-color: #99c7ee;
|
| 59 |
+
--block-title-border-width: var(--block-label-border-width);
|
| 60 |
+
--block-title-radius: var(--block-label-radius);
|
| 61 |
+
--block-title-padding: var(--block-label-padding);
|
| 62 |
+
--block-title-text-size: var(--block-label-text-size);
|
| 63 |
+
--block-title-text-weight: 600;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
.gradio-container label[data-testid="block-label"],
|
| 67 |
+
.gradio-container label[data-testid="block-label"] *,
|
| 68 |
+
.gradio-container span[data-testid="block-info"],
|
| 69 |
+
.gradio-container span[data-testid="block-info"] * {
|
| 70 |
+
background: #CCE5FF !important;
|
| 71 |
+
border-color: #99c7ee !important;
|
| 72 |
+
color: #6666FF !important;
|
| 73 |
+
fill: #6666FF !important;
|
| 74 |
+
font-family: Verdana, Geneva, "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei", "Noto Sans CJK SC", sans-serif !important;
|
| 75 |
+
font-style: normal !important;
|
| 76 |
+
font-size: 0.78rem !important;
|
| 77 |
+
line-height: 1.2 !important;
|
| 78 |
+
letter-spacing: 0 !important;
|
| 79 |
+
text-transform: none !important;
|
| 80 |
+
}
|
| 81 |
+
.gradio-container label[data-testid="block-label"],
|
| 82 |
+
.gradio-container span[data-testid="block-info"],
|
| 83 |
+
.gradio-container [data-testid="block-title"],
|
| 84 |
+
.gradio-container .block-title {
|
| 85 |
+
border: var(--block-label-border-width) solid #99c7ee !important;
|
| 86 |
+
border-top: none !important;
|
| 87 |
+
border-left: none !important;
|
| 88 |
+
border-radius: var(--block-label-radius) !important;
|
| 89 |
+
box-shadow: var(--block-label-shadow) !important;
|
| 90 |
+
padding: var(--block-label-padding) !important;
|
| 91 |
+
}
|
| 92 |
+
.gradio-container label[data-testid="block-label"],
|
| 93 |
+
.gradio-container label[data-testid="block-label"] *,
|
| 94 |
+
.gradio-container span[data-testid="block-info"],
|
| 95 |
+
.gradio-container span[data-testid="block-info"] *,
|
| 96 |
+
.gradio-container [data-testid="block-title"],
|
| 97 |
+
.gradio-container [data-testid="block-title"] *,
|
| 98 |
+
.gradio-container .block-title,
|
| 99 |
+
.gradio-container .block-title * {
|
| 100 |
+
font-weight: 600 !important;
|
| 101 |
+
}
|
| 102 |
+
.gradio-container .block label > span,
|
| 103 |
+
.gradio-container .block label > span *,
|
| 104 |
+
.gradio-container .form label > span,
|
| 105 |
+
.gradio-container .form label > span *,
|
| 106 |
+
.gradio-container label > span:first-child,
|
| 107 |
+
.gradio-container label > span:first-child * {
|
| 108 |
+
font-weight: 600 !important;
|
| 109 |
+
}
|
| 110 |
+
.strong-label [data-testid="block-label"],
|
| 111 |
+
.strong-label [data-testid="block-label"] *,
|
| 112 |
+
.strong-label span[data-testid="block-info"],
|
| 113 |
+
.strong-label span[data-testid="block-info"] *,
|
| 114 |
+
.strong-label [data-testid="block-title"],
|
| 115 |
+
.strong-label [data-testid="block-title"] *,
|
| 116 |
+
.strong-label .block-label,
|
| 117 |
+
.strong-label .block-label *,
|
| 118 |
+
.strong-label .block-title,
|
| 119 |
+
.strong-label .block-title *,
|
| 120 |
+
.strong-label label > span:first-child,
|
| 121 |
+
.strong-label label > span:first-child * {
|
| 122 |
+
font-weight: 600 !important;
|
| 123 |
+
}
|
| 124 |
+
.gradio-container .info-text,
|
| 125 |
+
.gradio-container .info-text * {
|
| 126 |
+
font-weight: 400 !important;
|
| 127 |
+
}
|
| 128 |
+
.gradio-container input,
|
| 129 |
+
.gradio-container textarea,
|
| 130 |
+
.gradio-container select,
|
| 131 |
+
.gradio-container [role="textbox"],
|
| 132 |
+
.gradio-container [contenteditable="true"] {
|
| 133 |
+
font-weight: 400 !important;
|
| 134 |
+
}
|
| 135 |
+
.gradio-container label[data-testid="block-label"] > span:first-child {
|
| 136 |
+
display: none !important;
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
.generate-button {
|
| 140 |
+
background: #6666FF !important;
|
| 141 |
+
color: #ffffff !important;
|
| 142 |
+
border: 1px solid #5555ee !important;
|
| 143 |
+
font-family: Verdana, Geneva, sans-serif !important;
|
| 144 |
+
}
|
| 145 |
+
.generate-button:hover {
|
| 146 |
+
background: #5555ee !important;
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
#playground-banner {
|
| 150 |
+
padding: 0;
|
| 151 |
+
border-radius: 0;
|
| 152 |
+
margin-bottom: 18px;
|
| 153 |
+
background: transparent;
|
| 154 |
+
border: 0;
|
| 155 |
+
}
|
| 156 |
+
#playground-banner h1 {
|
| 157 |
+
margin: 0 0 4px 0;
|
| 158 |
+
font-size: 1.7rem;
|
| 159 |
+
font-weight: 700;
|
| 160 |
+
color: #0f172a;
|
| 161 |
+
letter-spacing: 0;
|
| 162 |
+
}
|
| 163 |
+
#playground-banner .subtitle {
|
| 164 |
+
margin: 0;
|
| 165 |
+
color: #1e293b;
|
| 166 |
+
font-size: 0.9rem;
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
.info-card {
|
| 170 |
+
padding: 14px 18px;
|
| 171 |
+
border-radius: 8px;
|
| 172 |
+
border: 1px solid #99c7ee;
|
| 173 |
+
border-left: 4px solid #2563eb;
|
| 174 |
+
background: transparent;
|
| 175 |
+
font-size: 0.86rem;
|
| 176 |
+
line-height: 1.55;
|
| 177 |
+
margin-bottom: 16px;
|
| 178 |
+
box-sizing: border-box;
|
| 179 |
+
color: #0f172a;
|
| 180 |
+
}
|
| 181 |
+
.info-card .card-title,
|
| 182 |
+
.info-card .notice-title {
|
| 183 |
+
display: block;
|
| 184 |
+
font-weight: 600;
|
| 185 |
+
font-size: 0.92rem;
|
| 186 |
+
color: #0f172a;
|
| 187 |
+
}
|
| 188 |
+
.info-card .card-title {
|
| 189 |
+
margin-bottom: 4px;
|
| 190 |
+
}
|
| 191 |
+
.info-card .notice-title {
|
| 192 |
+
margin-top: 8px;
|
| 193 |
+
margin-bottom: 4px;
|
| 194 |
+
}
|
| 195 |
+
.info-card ol,
|
| 196 |
+
.info-card ul {
|
| 197 |
+
margin: 0;
|
| 198 |
+
padding-left: 18px;
|
| 199 |
+
}
|
| 200 |
+
.info-card li {
|
| 201 |
+
margin: 2px 0;
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
.main-workspace {
|
| 205 |
+
gap: 18px !important;
|
| 206 |
+
align-items: stretch !important;
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
.prompt-column,
|
| 210 |
+
.synthesis-column {
|
| 211 |
+
gap: 14px !important;
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
.control-row,
|
| 215 |
+
.settings-slider-row {
|
| 216 |
+
gap: 14px !important;
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
.settings-card {
|
| 220 |
+
margin-top: 2px !important;
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
.generate-button {
|
| 224 |
+
margin-top: 2px !important;
|
| 225 |
+
width: 100% !important;
|
| 226 |
+
box-sizing: border-box !important;
|
| 227 |
+
flex: 0 0 auto !important;
|
| 228 |
+
min-height: 44px !important;
|
| 229 |
+
padding-top: 10px !important;
|
| 230 |
+
padding-bottom: 10px !important;
|
| 231 |
+
font-size: 1rem !important;
|
| 232 |
+
font-weight: 600 !important;
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
.output-audio {
|
| 236 |
+
flex: 0 0 auto !important;
|
| 237 |
+
min-height: 190px !important;
|
| 238 |
+
}
|
| 239 |
+
.output-audio audio {
|
| 240 |
+
width: 100% !important;
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
@media (max-width: 768px) {
|
| 244 |
+
.gradio-container {
|
| 245 |
+
width: calc(100vw - 20px) !important;
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
"""
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def build_playground_theme(gr):
|
| 254 |
+
return gr.themes.Soft(
|
| 255 |
+
primary_hue="slate",
|
| 256 |
+
secondary_hue="slate",
|
| 257 |
+
neutral_hue="slate",
|
| 258 |
+
radius_size="md",
|
| 259 |
+
text_size="md",
|
| 260 |
+
spacing_size="md",
|
| 261 |
+
font=[gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"],
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
| 266 |
+
parser = argparse.ArgumentParser(description="dots.tts Gradio app.")
|
| 267 |
+
parser.add_argument("--host", default=DEFAULT_HOST, help="Server host")
|
| 268 |
+
parser.add_argument("--port", type=int, default=DEFAULT_PORT, help="Server port")
|
| 269 |
+
parser.add_argument(
|
| 270 |
+
"--execution-mode",
|
| 271 |
+
choices=("generate", "generate_stream"),
|
| 272 |
+
default=DEFAULT_EXECUTION_MODE,
|
| 273 |
+
help="Runtime execution mode fixed for the app",
|
| 274 |
+
)
|
| 275 |
+
parser.add_argument(
|
| 276 |
+
"--precision",
|
| 277 |
+
default=DEFAULT_PRECISION,
|
| 278 |
+
help="Inference precision fixed for the app runtime",
|
| 279 |
+
)
|
| 280 |
+
parser.add_argument(
|
| 281 |
+
"--optimize",
|
| 282 |
+
action="store_true",
|
| 283 |
+
help="Enable runtime optimize acceleration",
|
| 284 |
+
)
|
| 285 |
+
parser.add_argument(
|
| 286 |
+
"--model-name-or-path",
|
| 287 |
+
default=None,
|
| 288 |
+
help="Default model directory or Hugging Face repo id",
|
| 289 |
+
)
|
| 290 |
+
parser.add_argument(
|
| 291 |
+
"--output-dir",
|
| 292 |
+
default=str(DEFAULT_OUTPUT_DIR),
|
| 293 |
+
help="Directory for generated wav outputs",
|
| 294 |
+
)
|
| 295 |
+
parser.add_argument(
|
| 296 |
+
"--log-file",
|
| 297 |
+
default=str(DEFAULT_LOG_FILE),
|
| 298 |
+
help="Path to the Gradio log file",
|
| 299 |
+
)
|
| 300 |
+
parser.add_argument(
|
| 301 |
+
"--output-retention-count",
|
| 302 |
+
type=int,
|
| 303 |
+
default=DEFAULT_OUTPUT_RETENTION,
|
| 304 |
+
help="Maximum number of generated wav files to keep",
|
| 305 |
+
)
|
| 306 |
+
parser.add_argument(
|
| 307 |
+
"--max-generate-length",
|
| 308 |
+
type=int,
|
| 309 |
+
default=DEFAULT_MAX_GENERATE_LENGTH,
|
| 310 |
+
help="Maximum generation schedule length fixed for the app runtime",
|
| 311 |
+
)
|
| 312 |
+
parser.add_argument(
|
| 313 |
+
"--default-prompt-name",
|
| 314 |
+
default=DEFAULT_PROMPT_NAME,
|
| 315 |
+
help="Default built-in voice preset name",
|
| 316 |
+
)
|
| 317 |
+
parser.add_argument(
|
| 318 |
+
"--default-precision",
|
| 319 |
+
default=DEFAULT_PRECISION,
|
| 320 |
+
choices=["bfloat16", "float32", "float16"],
|
| 321 |
+
help="Default precision selected in the UI",
|
| 322 |
+
)
|
| 323 |
+
parser.add_argument(
|
| 324 |
+
"--default-num-steps",
|
| 325 |
+
type=int,
|
| 326 |
+
default=DEFAULT_NUM_STEPS,
|
| 327 |
+
help="Default Num Steps selected in the UI",
|
| 328 |
+
)
|
| 329 |
+
parser.add_argument(
|
| 330 |
+
"--default-guidance-scale",
|
| 331 |
+
type=float,
|
| 332 |
+
default=DEFAULT_GUIDANCE_SCALE,
|
| 333 |
+
help="Default Guidance Scale selected in the UI",
|
| 334 |
+
)
|
| 335 |
+
parser.add_argument(
|
| 336 |
+
"--default-speaker-scale",
|
| 337 |
+
type=float,
|
| 338 |
+
default=DEFAULT_SPEAKER_SCALE,
|
| 339 |
+
help="Default Speaker Scale selected in the UI",
|
| 340 |
+
)
|
| 341 |
+
parser.add_argument(
|
| 342 |
+
"--default-max-generate-length",
|
| 343 |
+
type=int,
|
| 344 |
+
default=DEFAULT_MAX_GENERATE_LENGTH,
|
| 345 |
+
help="Default Max Generate Length selected in the UI",
|
| 346 |
+
)
|
| 347 |
+
parser.add_argument(
|
| 348 |
+
"--skip-warmup",
|
| 349 |
+
action="store_true",
|
| 350 |
+
help="Start the Gradio server without running an initial synthesis warmup.",
|
| 351 |
+
)
|
| 352 |
+
return parser.parse_args(argv)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def build_startup_config_panel(gr, app_config) -> None:
|
| 356 |
+
with gr.Accordion("启动固定参数", open=False):
|
| 357 |
+
gr.Markdown("只读。修改这部分需要重启服务并传入新的启动参数。")
|
| 358 |
+
gr.Textbox(
|
| 359 |
+
label="Model",
|
| 360 |
+
value=app_config.default_model_name_or_path,
|
| 361 |
+
interactive=False,
|
| 362 |
+
)
|
| 363 |
+
with gr.Row():
|
| 364 |
+
gr.Textbox(
|
| 365 |
+
label="Execution Mode",
|
| 366 |
+
value=app_config.execution_mode,
|
| 367 |
+
interactive=False,
|
| 368 |
+
)
|
| 369 |
+
gr.Textbox(
|
| 370 |
+
label="Precision",
|
| 371 |
+
value=app_config.precision,
|
| 372 |
+
interactive=False,
|
| 373 |
+
)
|
| 374 |
+
with gr.Row():
|
| 375 |
+
gr.Number(
|
| 376 |
+
label="Max Generate Length",
|
| 377 |
+
value=app_config.max_generate_length,
|
| 378 |
+
precision=0,
|
| 379 |
+
interactive=False,
|
| 380 |
+
)
|
| 381 |
+
gr.Checkbox(
|
| 382 |
+
label="Optimize",
|
| 383 |
+
value=app_config.optimize,
|
| 384 |
+
interactive=False,
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def build_demo(gr, app_config, app_service) -> "gr.Blocks":
|
| 389 |
+
from apps.gradio.service import (
|
| 390 |
+
GRADIO_SYNTHESIS_MODE_CHOICES,
|
| 391 |
+
SynthesisRequest,
|
| 392 |
+
build_prompt_choice_items,
|
| 393 |
+
resolve_prompt_selection,
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
def select_prompt_preset(prompt_name: str):
|
| 397 |
+
audio_path, prompt_text = resolve_prompt_selection(
|
| 398 |
+
prompt_name,
|
| 399 |
+
app_config.prompt_presets,
|
| 400 |
+
)
|
| 401 |
+
return audio_path, prompt_text
|
| 402 |
+
|
| 403 |
+
def run_synthesis(
|
| 404 |
+
text: str,
|
| 405 |
+
synthesis_mode: str,
|
| 406 |
+
prompt_audio_path: str | None,
|
| 407 |
+
prompt_text: str,
|
| 408 |
+
ode_method: str,
|
| 409 |
+
num_steps: float,
|
| 410 |
+
guidance_scale: float,
|
| 411 |
+
speaker_scale: float,
|
| 412 |
+
normalize_text: bool,
|
| 413 |
+
seed: float,
|
| 414 |
+
):
|
| 415 |
+
resolved_synthesis_mode = synthesis_mode if DEBUG_GRADIO_ENABLED else "tts"
|
| 416 |
+
request = SynthesisRequest(
|
| 417 |
+
model_name_or_path=app_config.default_model_name_or_path,
|
| 418 |
+
text=text,
|
| 419 |
+
prompt_audio_path=prompt_audio_path,
|
| 420 |
+
prompt_text=prompt_text,
|
| 421 |
+
execution_mode=app_config.execution_mode,
|
| 422 |
+
template_name=resolved_synthesis_mode,
|
| 423 |
+
ode_method=ode_method,
|
| 424 |
+
num_steps=int(num_steps),
|
| 425 |
+
guidance_scale=float(guidance_scale),
|
| 426 |
+
speaker_scale=float(speaker_scale),
|
| 427 |
+
normalize_text=normalize_text,
|
| 428 |
+
seed=int(seed),
|
| 429 |
+
)
|
| 430 |
+
result = app_service.generate(request)
|
| 431 |
+
return result.audio_path, result.metrics
|
| 432 |
+
|
| 433 |
+
show_prompt_preset = bool(app_config.prompt_presets)
|
| 434 |
+
|
| 435 |
+
with gr.Blocks(title="dots.tts") as demo:
|
| 436 |
+
gr.HTML(
|
| 437 |
+
"<style>\n"
|
| 438 |
+
+ PLAYGROUND_CSS
|
| 439 |
+
+ "\n</style>\n"
|
| 440 |
+
+ """
|
| 441 |
+
<div id="playground-banner">
|
| 442 |
+
<h1>dots.tts</h1>
|
| 443 |
+
<p class="subtitle">Fully-continuous Autoregressive TTS · 48 kHz · Voice Cloning</p>
|
| 444 |
+
</div>
|
| 445 |
+
""",
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
gr.HTML(
|
| 449 |
+
"""
|
| 450 |
+
<div class="info-card">
|
| 451 |
+
<span class="card-title">使用说明 · Instructions</span>
|
| 452 |
+
<ol>
|
| 453 |
+
<li>上传参考音频并填写对应转写文本 · Upload prompt audio and fill in its transcript.</li>
|
| 454 |
+
<li>在文本框中输入要合成的内容 · Enter the text to synthesize.</li>
|
| 455 |
+
<li>点击 <b>Generate</b> 合成声音 · Click <b>Generate</b> to synthesize speech.</li>
|
| 456 |
+
</ol>
|
| 457 |
+
</div>
|
| 458 |
+
""",
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
with gr.Row(equal_height=True, elem_classes="main-workspace"):
|
| 462 |
+
with gr.Column(scale=1, min_width=480, elem_classes="prompt-column"):
|
| 463 |
+
prompt_preset = gr.Dropdown(
|
| 464 |
+
label="音色 · Voice Preset",
|
| 465 |
+
choices=build_prompt_choice_items(app_config.prompt_presets),
|
| 466 |
+
value=app_config.default_prompt_name,
|
| 467 |
+
info="内置音色clone样本;选择后自动填入参考音频与转写。",
|
| 468 |
+
elem_id="voice-preset-dropdown",
|
| 469 |
+
elem_classes="strong-label",
|
| 470 |
+
visible=show_prompt_preset,
|
| 471 |
+
)
|
| 472 |
+
prompt_audio_path = gr.Audio(
|
| 473 |
+
label="参考音频 · Prompt Audio",
|
| 474 |
+
sources=["upload"],
|
| 475 |
+
type="filepath",
|
| 476 |
+
value=app_config.default_prompt_audio_path,
|
| 477 |
+
elem_classes="strong-label",
|
| 478 |
+
)
|
| 479 |
+
prompt_text = gr.Textbox(
|
| 480 |
+
label="参考音频转写 · Prompt Text",
|
| 481 |
+
lines=5,
|
| 482 |
+
value=app_config.default_prompt_text,
|
| 483 |
+
placeholder="Prompt audio 对应的文本转写(continuation cloning 必填)",
|
| 484 |
+
elem_classes="strong-label",
|
| 485 |
+
)
|
| 486 |
+
|
| 487 |
+
with gr.Column(scale=1, min_width=480, elem_classes="synthesis-column"):
|
| 488 |
+
text = gr.Textbox(
|
| 489 |
+
label="待合成文本 · Text",
|
| 490 |
+
lines=5,
|
| 491 |
+
max_lines=8,
|
| 492 |
+
value=DEFAULT_INPUT_TEXT,
|
| 493 |
+
placeholder="输入待合成的文本",
|
| 494 |
+
elem_classes="strong-label",
|
| 495 |
+
)
|
| 496 |
+
with gr.Accordion("⚙️ Settings", open=False, elem_classes="settings-card"):
|
| 497 |
+
with gr.Row(elem_classes="settings-slider-row"):
|
| 498 |
+
num_steps = gr.Slider(
|
| 499 |
+
label="Num Steps",
|
| 500 |
+
minimum=1,
|
| 501 |
+
maximum=32,
|
| 502 |
+
step=1,
|
| 503 |
+
value=app_config.default_num_steps,
|
| 504 |
+
)
|
| 505 |
+
with gr.Row(elem_classes="settings-slider-row"):
|
| 506 |
+
guidance_scale = gr.Slider(
|
| 507 |
+
label="Guidance Scale",
|
| 508 |
+
minimum=1.0,
|
| 509 |
+
maximum=3.0,
|
| 510 |
+
step=0.1,
|
| 511 |
+
value=app_config.default_guidance_scale,
|
| 512 |
+
)
|
| 513 |
+
with gr.Row(elem_classes="control-row"):
|
| 514 |
+
seed = gr.Number(
|
| 515 |
+
label="Seed",
|
| 516 |
+
value=DEFAULT_SEED,
|
| 517 |
+
precision=0,
|
| 518 |
+
scale=1,
|
| 519 |
+
min_width=180,
|
| 520 |
+
)
|
| 521 |
+
normalize_text = gr.Checkbox(
|
| 522 |
+
label="Normalize Text",
|
| 523 |
+
value=False,
|
| 524 |
+
scale=1,
|
| 525 |
+
min_width=180,
|
| 526 |
+
)
|
| 527 |
+
generate = gr.Button(
|
| 528 |
+
"Generate",
|
| 529 |
+
variant="primary",
|
| 530 |
+
size="lg",
|
| 531 |
+
elem_classes="generate-button",
|
| 532 |
+
)
|
| 533 |
+
audio_out = gr.Audio(
|
| 534 |
+
label="生成音频 · Output",
|
| 535 |
+
type="filepath",
|
| 536 |
+
elem_classes="output-audio",
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
if DEBUG_GRADIO_ENABLED:
|
| 540 |
+
with gr.Accordion("Debug", open=False):
|
| 541 |
+
synthesis_mode = gr.Dropdown(
|
| 542 |
+
label="SynthesisMode",
|
| 543 |
+
choices=list(GRADIO_SYNTHESIS_MODE_CHOICES),
|
| 544 |
+
value="tts",
|
| 545 |
+
info="选择合成模式;界面显示名会自动映射到 runtime 对应模板。",
|
| 546 |
+
)
|
| 547 |
+
ode_method = gr.Textbox(
|
| 548 |
+
label="ODE Method",
|
| 549 |
+
value=DEFAULT_ODE_METHOD,
|
| 550 |
+
lines=1,
|
| 551 |
+
)
|
| 552 |
+
speaker_scale = gr.Slider(
|
| 553 |
+
label="Speaker Scale",
|
| 554 |
+
minimum=0.0,
|
| 555 |
+
maximum=3.0,
|
| 556 |
+
step=0.1,
|
| 557 |
+
value=app_config.default_speaker_scale,
|
| 558 |
+
info="说话人 x-vector 强度",
|
| 559 |
+
)
|
| 560 |
+
metrics = gr.JSON(label="Metrics", value=app_service.metadata())
|
| 561 |
+
build_startup_config_panel(gr, app_config)
|
| 562 |
+
else:
|
| 563 |
+
synthesis_mode = gr.State(value="tts")
|
| 564 |
+
ode_method = gr.State(value=DEFAULT_ODE_METHOD)
|
| 565 |
+
speaker_scale = gr.State(value=app_config.default_speaker_scale)
|
| 566 |
+
metrics = gr.State(value={})
|
| 567 |
+
|
| 568 |
+
generate.click(
|
| 569 |
+
fn=run_synthesis,
|
| 570 |
+
inputs=[
|
| 571 |
+
text,
|
| 572 |
+
synthesis_mode,
|
| 573 |
+
prompt_audio_path,
|
| 574 |
+
prompt_text,
|
| 575 |
+
ode_method,
|
| 576 |
+
num_steps,
|
| 577 |
+
guidance_scale,
|
| 578 |
+
speaker_scale,
|
| 579 |
+
normalize_text,
|
| 580 |
+
seed,
|
| 581 |
+
],
|
| 582 |
+
outputs=[audio_out, metrics],
|
| 583 |
+
concurrency_limit=1,
|
| 584 |
+
)
|
| 585 |
+
prompt_preset.change(
|
| 586 |
+
fn=select_prompt_preset,
|
| 587 |
+
inputs=[prompt_preset],
|
| 588 |
+
outputs=[prompt_audio_path, prompt_text],
|
| 589 |
+
concurrency_limit=1,
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
return demo.queue(default_concurrency_limit=1, max_size=8)
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
def main() -> None:
|
| 596 |
+
args = parse_args()
|
| 597 |
+
import gradio as gr
|
| 598 |
+
from loguru import logger
|
| 599 |
+
|
| 600 |
+
from apps.gradio.service import GradioAppService, build_gradio_app_config
|
| 601 |
+
from dots_tts.utils.logging import configure_logging
|
| 602 |
+
|
| 603 |
+
configure_logging(log_file=args.log_file)
|
| 604 |
+
logger.info(
|
| 605 |
+
"Gradio app starting: host={} port={} model_name_or_path={} output_dir={} "
|
| 606 |
+
"log_file={} output_retention_count={} max_generate_length={} execution_mode={} precision={} optimize={} "
|
| 607 |
+
"default_prompt_name={} skip_warmup={}",
|
| 608 |
+
args.host,
|
| 609 |
+
args.port,
|
| 610 |
+
args.model_name_or_path,
|
| 611 |
+
args.output_dir,
|
| 612 |
+
args.log_file,
|
| 613 |
+
args.output_retention_count,
|
| 614 |
+
args.max_generate_length,
|
| 615 |
+
args.execution_mode,
|
| 616 |
+
args.precision,
|
| 617 |
+
args.optimize,
|
| 618 |
+
args.default_prompt_name,
|
| 619 |
+
args.skip_warmup,
|
| 620 |
+
)
|
| 621 |
+
app_config = build_gradio_app_config(
|
| 622 |
+
host=args.host,
|
| 623 |
+
port=args.port,
|
| 624 |
+
execution_mode=args.execution_mode,
|
| 625 |
+
precision=args.precision,
|
| 626 |
+
optimize=args.optimize,
|
| 627 |
+
model_name_or_path=args.model_name_or_path,
|
| 628 |
+
output_dir=Path(args.output_dir),
|
| 629 |
+
output_retention_count=args.output_retention_count,
|
| 630 |
+
max_generate_length=args.max_generate_length,
|
| 631 |
+
default_prompt_name=args.default_prompt_name,
|
| 632 |
+
default_precision=args.default_precision,
|
| 633 |
+
default_num_steps=args.default_num_steps,
|
| 634 |
+
default_guidance_scale=args.default_guidance_scale,
|
| 635 |
+
default_speaker_scale=args.default_speaker_scale,
|
| 636 |
+
default_max_generate_length=args.default_max_generate_length,
|
| 637 |
+
)
|
| 638 |
+
app_service = GradioAppService(app_config)
|
| 639 |
+
if args.skip_warmup:
|
| 640 |
+
logger.info("Gradio app warmup skipped by --skip-warmup.")
|
| 641 |
+
else:
|
| 642 |
+
warmup_metrics = app_service.warmup()
|
| 643 |
+
logger.info("Gradio app warmup metrics: {}", warmup_metrics)
|
| 644 |
+
demo = build_demo(gr, app_config, app_service)
|
| 645 |
+
logger.info(
|
| 646 |
+
"Gradio app ready: host={} port={} execution_mode={} precision={} optimize={} default_model_name_or_path={}",
|
| 647 |
+
app_config.host,
|
| 648 |
+
app_config.port,
|
| 649 |
+
app_config.execution_mode,
|
| 650 |
+
app_config.precision,
|
| 651 |
+
app_config.optimize,
|
| 652 |
+
app_config.default_model_name_or_path,
|
| 653 |
+
)
|
| 654 |
+
demo.launch(
|
| 655 |
+
server_name=app_config.host,
|
| 656 |
+
server_port=app_config.port,
|
| 657 |
+
theme=build_playground_theme(gr),
|
| 658 |
+
css=PLAYGROUND_CSS,
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
if __name__ == "__main__":
|
| 663 |
+
main()
|
apps/gradio/constants.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
REPO_ROOT = Path(__file__).resolve().parents[2]
|
| 6 |
+
DEFAULT_HOST = "0.0.0.0"
|
| 7 |
+
DEFAULT_PORT = 7860
|
| 8 |
+
DEFAULT_OUTPUT_DIR = REPO_ROOT / "apps" / "gradio" / "outputs"
|
| 9 |
+
DEFAULT_LOG_FILE = REPO_ROOT / "apps" / "gradio" / "gradio.log"
|
| 10 |
+
DEFAULT_PROMPTS_DIR = REPO_ROOT / "apps" / "gradio" / "default_prompts"
|
| 11 |
+
DEFAULT_PROMPT_SOURCE_DIR = DEFAULT_PROMPTS_DIR
|
| 12 |
+
DEFAULT_PROMPT_MAPPING_FILE = DEFAULT_PROMPTS_DIR / "prompt_text"
|
| 13 |
+
DEFAULT_OUTPUT_RETENTION = 20
|
| 14 |
+
DEFAULT_EXECUTION_MODE = "generate_stream"
|
| 15 |
+
DEFAULT_PRECISION = "bfloat16"
|
| 16 |
+
DEFAULT_ODE_METHOD = "euler"
|
| 17 |
+
DEFAULT_NUM_STEPS = 10
|
| 18 |
+
DEFAULT_GUIDANCE_SCALE = 1.2
|
| 19 |
+
DEFAULT_SPEAKER_SCALE = 1.5
|
| 20 |
+
DEFAULT_MAX_GENERATE_LENGTH = 500
|
| 21 |
+
DEFAULT_SEED = 42
|
| 22 |
+
DEFAULT_INPUT_TEXT = ""
|
| 23 |
+
DEFAULT_WARMUP_TEXT = "dots.tts is a 2B-parameter fully continuous, end-to-end autoregressive (AR) text-to-speech system. The backbone pairs a semantic encoder, an LLM, and an autoregressive flow-matching acoustic head over a 48 kHz AudioVAE"
|
| 24 |
+
DEFAULT_PROMPT_NAME = "male_zh"
|
| 25 |
+
DEFAULT_PROMPT_NONE = "__none__"
|
| 26 |
+
PROMPT_AUDIO_SUFFIXES = (".wav", ".mp3", ".flac", ".m4a", ".ogg")
|
apps/gradio/default_prompts/prompt_text
ADDED
|
File without changes
|
apps/gradio/languages.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
SUPPORTED_LANGUAGE_CODE_BY_NAME = {
|
| 4 |
+
"普通话": "ZH",
|
| 5 |
+
"粤语": "口音:粤语",
|
| 6 |
+
"北京话": "口音:北京官话",
|
| 7 |
+
"东北话": "口音:东北话",
|
| 8 |
+
"四川话": "口音:四川话",
|
| 9 |
+
"闽南话": "口音:闽南话",
|
| 10 |
+
"吴语": "口音:吴语",
|
| 11 |
+
"英语": "EN",
|
| 12 |
+
"西班牙语": "ES",
|
| 13 |
+
"印地语": "HI",
|
| 14 |
+
"阿拉伯语": "AR",
|
| 15 |
+
"孟加拉语": "BN",
|
| 16 |
+
"葡萄牙语": "PT",
|
| 17 |
+
"俄语": "RU",
|
| 18 |
+
"日语": "JA",
|
| 19 |
+
"法语": "FR",
|
| 20 |
+
"德语": "DE",
|
| 21 |
+
"韩语": "KO",
|
| 22 |
+
"意大利语": "IT",
|
| 23 |
+
"土耳其语": "TR",
|
| 24 |
+
"越南语": "VI",
|
| 25 |
+
"印尼语": "ID",
|
| 26 |
+
"乌尔都语": "UR",
|
| 27 |
+
"波斯语": "FA",
|
| 28 |
+
"泰米尔语": "TA",
|
| 29 |
+
"泰卢固语": "TE",
|
| 30 |
+
"菲律宾语": "FIL",
|
| 31 |
+
"马来语": "MS",
|
| 32 |
+
"旁遮普语": "PA",
|
| 33 |
+
"马拉地语": "MR",
|
| 34 |
+
"古吉拉特语": "GU",
|
| 35 |
+
"马拉雅拉姆语": "ML",
|
| 36 |
+
"卡纳达语": "KN",
|
| 37 |
+
"波兰语": "PL",
|
| 38 |
+
"乌克兰语": "UK",
|
| 39 |
+
"荷兰语": "NL",
|
| 40 |
+
"泰语": "TH",
|
| 41 |
+
"罗马尼亚语": "RO",
|
| 42 |
+
"斯瓦希里语": "SW",
|
| 43 |
+
"希伯来语": "HE",
|
| 44 |
+
"捷克语": "CS",
|
| 45 |
+
"希腊语": "EL",
|
| 46 |
+
"匈牙利语": "HU",
|
| 47 |
+
"瑞典语": "SV",
|
| 48 |
+
"丹麦语": "DA",
|
| 49 |
+
"芬兰语": "FI",
|
| 50 |
+
"书面挪威语": "NB",
|
| 51 |
+
"斯洛伐克语": "SK",
|
| 52 |
+
"斯洛文尼亚语": "SL",
|
| 53 |
+
"塞尔维亚语": "SR",
|
| 54 |
+
"波斯尼亚语": "BS",
|
| 55 |
+
"克罗地亚语": "HR",
|
| 56 |
+
"保加利亚语": "BG",
|
| 57 |
+
"马其顿语": "MK",
|
| 58 |
+
"立陶宛语": "LT",
|
| 59 |
+
"拉脱维亚语": "LV",
|
| 60 |
+
"爱沙尼亚语": "ET",
|
| 61 |
+
"冰岛语": "IS",
|
| 62 |
+
"爱尔兰语": "GA",
|
| 63 |
+
"威尔士语": "CY",
|
| 64 |
+
"加泰罗尼亚语": "CA",
|
| 65 |
+
"加利西亚语": "GL",
|
| 66 |
+
"奥克语": "OC",
|
| 67 |
+
"阿斯图里亚斯语": "AST",
|
| 68 |
+
"尼泊尔语": "NE",
|
| 69 |
+
"信德语": "SD",
|
| 70 |
+
"奥里亚语": "OR",
|
| 71 |
+
"阿萨姆语": "AS",
|
| 72 |
+
"普什图语": "PS",
|
| 73 |
+
"缅甸语": "MY",
|
| 74 |
+
"高棉语": "KM",
|
| 75 |
+
"老挝语": "LO",
|
| 76 |
+
"哈萨克语": "KK",
|
| 77 |
+
"乌兹别克语": "UZ",
|
| 78 |
+
"吉尔吉斯语": "KY",
|
| 79 |
+
"塔吉克语": "TG",
|
| 80 |
+
"阿塞拜疆语": "AZ",
|
| 81 |
+
"格鲁吉亚语": "KA",
|
| 82 |
+
"亚美尼亚语": "HY",
|
| 83 |
+
"白俄罗斯语": "BE",
|
| 84 |
+
"卢森堡语": "LB",
|
| 85 |
+
"马耳他语": "MT",
|
| 86 |
+
"毛利语": "MI",
|
| 87 |
+
"南非荷兰语": "AF",
|
| 88 |
+
"祖鲁语": "ZU",
|
| 89 |
+
"科萨语": "XH",
|
| 90 |
+
"约鲁巴语": "YO",
|
| 91 |
+
"豪萨语": "HA",
|
| 92 |
+
"伊博语": "IG",
|
| 93 |
+
"阿姆哈拉语": "AM",
|
| 94 |
+
"奥罗莫语": "OM",
|
| 95 |
+
"北索托语": "NSO",
|
| 96 |
+
"尼扬贾语": "NY",
|
| 97 |
+
"修纳语": "SN",
|
| 98 |
+
"索马里语": "SO",
|
| 99 |
+
"卢干达语": "LG",
|
| 100 |
+
"林加拉语": "LN",
|
| 101 |
+
"卢奥语": "LUO",
|
| 102 |
+
"坎巴语": "KAM",
|
| 103 |
+
"翁本杜语": "UMB",
|
| 104 |
+
"富拉语": "FF",
|
| 105 |
+
"沃洛夫语": "WO",
|
| 106 |
+
"中库尔德语": "CKB",
|
| 107 |
+
"宿务语": "CEB",
|
| 108 |
+
"佛得角克里奥尔语": "KEA",
|
| 109 |
+
"蒙古语": "MN",
|
| 110 |
+
"爪哇语": "JV",
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def build_language_choice_items() -> list[tuple[str, str]]:
|
| 115 |
+
return [("不指定", ""), *[(name, code) for name, code in SUPPORTED_LANGUAGE_CODE_BY_NAME.items()]]
|
apps/gradio/service.py
ADDED
|
@@ -0,0 +1,773 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import shutil
|
| 4 |
+
import sys
|
| 5 |
+
import threading
|
| 6 |
+
import time
|
| 7 |
+
import uuid
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Any, Literal
|
| 11 |
+
|
| 12 |
+
REPO_ROOT = Path(__file__).resolve().parents[2]
|
| 13 |
+
SRC_ROOT = REPO_ROOT / "src"
|
| 14 |
+
|
| 15 |
+
for import_root in (REPO_ROOT, SRC_ROOT):
|
| 16 |
+
import_root_str = str(import_root)
|
| 17 |
+
if import_root_str not in sys.path:
|
| 18 |
+
sys.path.insert(0, import_root_str)
|
| 19 |
+
|
| 20 |
+
import soundfile as sf # noqa: E402
|
| 21 |
+
import torch # noqa: E402
|
| 22 |
+
from loguru import logger # noqa: E402
|
| 23 |
+
|
| 24 |
+
from apps.gradio.constants import ( # noqa: E402
|
| 25 |
+
DEFAULT_EXECUTION_MODE,
|
| 26 |
+
DEFAULT_GUIDANCE_SCALE,
|
| 27 |
+
DEFAULT_HOST,
|
| 28 |
+
DEFAULT_MAX_GENERATE_LENGTH,
|
| 29 |
+
DEFAULT_NUM_STEPS,
|
| 30 |
+
DEFAULT_ODE_METHOD,
|
| 31 |
+
DEFAULT_OUTPUT_DIR,
|
| 32 |
+
DEFAULT_OUTPUT_RETENTION,
|
| 33 |
+
DEFAULT_PORT,
|
| 34 |
+
DEFAULT_PRECISION,
|
| 35 |
+
DEFAULT_PROMPT_MAPPING_FILE,
|
| 36 |
+
DEFAULT_PROMPT_NAME,
|
| 37 |
+
DEFAULT_PROMPT_NONE,
|
| 38 |
+
DEFAULT_PROMPT_SOURCE_DIR,
|
| 39 |
+
DEFAULT_PROMPTS_DIR,
|
| 40 |
+
DEFAULT_SEED,
|
| 41 |
+
DEFAULT_SPEAKER_SCALE,
|
| 42 |
+
DEFAULT_WARMUP_TEXT,
|
| 43 |
+
PROMPT_AUDIO_SUFFIXES,
|
| 44 |
+
)
|
| 45 |
+
from apps.gradio.languages import ( # noqa: E402
|
| 46 |
+
SUPPORTED_LANGUAGE_CODE_BY_NAME,
|
| 47 |
+
build_language_choice_items,
|
| 48 |
+
)
|
| 49 |
+
from dots_tts.runtime import DotsTtsRuntime # noqa: E402
|
| 50 |
+
from dots_tts.utils.util import seed_everything # noqa: E402
|
| 51 |
+
|
| 52 |
+
ExecutionMode = Literal["generate", "generate_stream"]
|
| 53 |
+
GRADIO_SYNTHESIS_MODE_CHOICES = (
|
| 54 |
+
("tts", "tts"),
|
| 55 |
+
("instruct_tts", "instruction_tts"),
|
| 56 |
+
("instruct_tts_general", "text_to_audio"),
|
| 57 |
+
)
|
| 58 |
+
GRADIO_SYNTHESIS_MODE_TEMPLATE_NAMES = tuple(
|
| 59 |
+
value for _, value in GRADIO_SYNTHESIS_MODE_CHOICES
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dataclass(frozen=True)
|
| 64 |
+
class PromptPreset:
|
| 65 |
+
name: str
|
| 66 |
+
audio_path: str
|
| 67 |
+
prompt_text: str
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _is_prompt_asset(path: Path) -> bool:
|
| 71 |
+
return path.is_file() and (
|
| 72 |
+
path.name == "prompt_text" or path.suffix.lower() in PROMPT_AUDIO_SUFFIXES
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def sync_default_prompt_library(
|
| 77 |
+
source_dir: Path = DEFAULT_PROMPT_SOURCE_DIR,
|
| 78 |
+
target_dir: Path = DEFAULT_PROMPTS_DIR,
|
| 79 |
+
) -> None:
|
| 80 |
+
source_dir = Path(source_dir)
|
| 81 |
+
if not source_dir.is_dir():
|
| 82 |
+
logger.info(
|
| 83 |
+
"Prompt library sync skipped: source_dir={} does not exist.",
|
| 84 |
+
source_dir,
|
| 85 |
+
)
|
| 86 |
+
return
|
| 87 |
+
|
| 88 |
+
target_dir = Path(target_dir)
|
| 89 |
+
target_dir.mkdir(parents=True, exist_ok=True)
|
| 90 |
+
logger.info(
|
| 91 |
+
"Prompt library sync started: source_dir={} target_dir={}",
|
| 92 |
+
source_dir,
|
| 93 |
+
target_dir,
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
source_assets = {
|
| 97 |
+
asset.name: asset for asset in sorted(source_dir.iterdir()) if _is_prompt_asset(asset)
|
| 98 |
+
}
|
| 99 |
+
copied_count = 0
|
| 100 |
+
for asset_name, source_asset in source_assets.items():
|
| 101 |
+
target_asset = target_dir / asset_name
|
| 102 |
+
if (
|
| 103 |
+
not target_asset.exists()
|
| 104 |
+
or target_asset.stat().st_size != source_asset.stat().st_size
|
| 105 |
+
or target_asset.stat().st_mtime_ns != source_asset.stat().st_mtime_ns
|
| 106 |
+
):
|
| 107 |
+
shutil.copy2(source_asset, target_asset)
|
| 108 |
+
copied_count += 1
|
| 109 |
+
|
| 110 |
+
removed_count = 0
|
| 111 |
+
for target_asset in sorted(target_dir.iterdir()):
|
| 112 |
+
if _is_prompt_asset(target_asset) and target_asset.name not in source_assets:
|
| 113 |
+
target_asset.unlink(missing_ok=True)
|
| 114 |
+
removed_count += 1
|
| 115 |
+
logger.info(
|
| 116 |
+
"Prompt library sync completed: copied_assets={} removed_assets={} "
|
| 117 |
+
"available_assets={}",
|
| 118 |
+
copied_count,
|
| 119 |
+
removed_count,
|
| 120 |
+
len(source_assets),
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def _load_prompt_text_map(mapping_file: Path) -> dict[str, str]:
|
| 125 |
+
if not mapping_file.is_file():
|
| 126 |
+
return {}
|
| 127 |
+
|
| 128 |
+
prompt_text_map: dict[str, str] = {}
|
| 129 |
+
with mapping_file.open(encoding="utf-8") as file_obj:
|
| 130 |
+
for raw_line in file_obj:
|
| 131 |
+
line = raw_line.strip()
|
| 132 |
+
if not line or line.startswith("#") or "|" not in line:
|
| 133 |
+
continue
|
| 134 |
+
name, text = line.split("|", 1)
|
| 135 |
+
prompt_text_map[name.strip()] = text.strip()
|
| 136 |
+
return prompt_text_map
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def discover_prompt_presets(
|
| 140 |
+
prompts_dir: Path = DEFAULT_PROMPTS_DIR,
|
| 141 |
+
mapping_file: Path = DEFAULT_PROMPT_MAPPING_FILE,
|
| 142 |
+
) -> tuple[PromptPreset, ...]:
|
| 143 |
+
prompts_dir = Path(prompts_dir)
|
| 144 |
+
if not prompts_dir.is_dir():
|
| 145 |
+
return ()
|
| 146 |
+
|
| 147 |
+
prompt_text_map = _load_prompt_text_map(Path(mapping_file))
|
| 148 |
+
prompt_audio_paths = [
|
| 149 |
+
audio_path
|
| 150 |
+
for audio_path in sorted(prompts_dir.iterdir(), key=lambda path: (path.stem == "child", path.stem))
|
| 151 |
+
if audio_path.is_file() and audio_path.suffix.lower() in PROMPT_AUDIO_SUFFIXES
|
| 152 |
+
]
|
| 153 |
+
return tuple(
|
| 154 |
+
PromptPreset(
|
| 155 |
+
name=audio_path.stem,
|
| 156 |
+
audio_path=str(audio_path.resolve()),
|
| 157 |
+
prompt_text=prompt_text_map.get(audio_path.stem, ""),
|
| 158 |
+
)
|
| 159 |
+
for audio_path in prompt_audio_paths
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def build_prompt_choice_items(
|
| 164 |
+
prompt_presets: tuple[PromptPreset, ...],
|
| 165 |
+
) -> list[tuple[str, str]]:
|
| 166 |
+
return [("No Preset", DEFAULT_PROMPT_NONE), *[(preset.name, preset.name) for preset in prompt_presets]]
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def resolve_default_prompt_selection(
|
| 170 |
+
prompt_presets: tuple[PromptPreset, ...],
|
| 171 |
+
default_prompt_name: str = DEFAULT_PROMPT_NAME,
|
| 172 |
+
) -> tuple[str, str | None, str]:
|
| 173 |
+
if not prompt_presets:
|
| 174 |
+
return DEFAULT_PROMPT_NONE, None, ""
|
| 175 |
+
|
| 176 |
+
preset_by_name = {preset.name: preset for preset in prompt_presets}
|
| 177 |
+
selected_name = default_prompt_name if default_prompt_name in preset_by_name else prompt_presets[0].name
|
| 178 |
+
selected_preset = preset_by_name[selected_name]
|
| 179 |
+
return selected_name, selected_preset.audio_path, selected_preset.prompt_text
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def resolve_prompt_selection(
|
| 183 |
+
prompt_name: str,
|
| 184 |
+
prompt_presets: tuple[PromptPreset, ...],
|
| 185 |
+
) -> tuple[str | None, str]:
|
| 186 |
+
if prompt_name == DEFAULT_PROMPT_NONE:
|
| 187 |
+
return None, ""
|
| 188 |
+
|
| 189 |
+
for preset in prompt_presets:
|
| 190 |
+
if preset.name == prompt_name:
|
| 191 |
+
return preset.audio_path, preset.prompt_text
|
| 192 |
+
return None, ""
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def discover_local_model_choices(repo_root: Path = REPO_ROOT) -> list[str]:
|
| 196 |
+
model_root = Path(repo_root) / "pretrained_models"
|
| 197 |
+
if not model_root.is_dir():
|
| 198 |
+
return []
|
| 199 |
+
return sorted(
|
| 200 |
+
path.relative_to(repo_root).as_posix()
|
| 201 |
+
for path in model_root.glob("**/model")
|
| 202 |
+
if path.is_dir()
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def resolve_model_name_or_path(model_name_or_path: str, repo_root: Path = REPO_ROOT) -> str:
|
| 207 |
+
normalized = model_name_or_path.strip()
|
| 208 |
+
if not normalized:
|
| 209 |
+
raise ValueError("model_name_or_path 不能为空。")
|
| 210 |
+
|
| 211 |
+
direct_path = Path(normalized).expanduser()
|
| 212 |
+
if direct_path.exists():
|
| 213 |
+
return str(direct_path.resolve())
|
| 214 |
+
|
| 215 |
+
repo_relative_path = Path(repo_root) / normalized
|
| 216 |
+
if repo_relative_path.exists():
|
| 217 |
+
return str(repo_relative_path.resolve())
|
| 218 |
+
|
| 219 |
+
return normalized
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def default_model_name_or_path(repo_root: Path = REPO_ROOT) -> str:
|
| 223 |
+
discovered = discover_local_model_choices(repo_root=repo_root)
|
| 224 |
+
if not discovered:
|
| 225 |
+
return ""
|
| 226 |
+
return discovered[0]
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
@dataclass(frozen=True)
|
| 230 |
+
class GradioAppConfig:
|
| 231 |
+
host: str
|
| 232 |
+
port: int
|
| 233 |
+
execution_mode: ExecutionMode
|
| 234 |
+
precision: str
|
| 235 |
+
optimize: bool
|
| 236 |
+
output_dir: Path
|
| 237 |
+
prompts_dir: Path
|
| 238 |
+
output_retention_count: int
|
| 239 |
+
max_generate_length: int
|
| 240 |
+
default_model_name_or_path: str
|
| 241 |
+
prompt_presets: tuple[PromptPreset, ...]
|
| 242 |
+
default_prompt_name: str
|
| 243 |
+
default_prompt_audio_path: str | None
|
| 244 |
+
default_prompt_text: str
|
| 245 |
+
default_precision: str
|
| 246 |
+
default_num_steps: int
|
| 247 |
+
default_guidance_scale: float
|
| 248 |
+
default_speaker_scale: float
|
| 249 |
+
default_max_generate_length: int
|
| 250 |
+
local_model_choices: tuple[str, ...]
|
| 251 |
+
repo_root: Path = REPO_ROOT
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def build_gradio_app_config(
|
| 255 |
+
*,
|
| 256 |
+
host: str = DEFAULT_HOST,
|
| 257 |
+
port: int = DEFAULT_PORT,
|
| 258 |
+
execution_mode: ExecutionMode = DEFAULT_EXECUTION_MODE,
|
| 259 |
+
precision: str = DEFAULT_PRECISION,
|
| 260 |
+
optimize: bool = False,
|
| 261 |
+
output_dir: Path = DEFAULT_OUTPUT_DIR,
|
| 262 |
+
output_retention_count: int = DEFAULT_OUTPUT_RETENTION,
|
| 263 |
+
max_generate_length: int = DEFAULT_MAX_GENERATE_LENGTH,
|
| 264 |
+
model_name_or_path: str | None = None,
|
| 265 |
+
default_prompt_name: str = DEFAULT_PROMPT_NAME,
|
| 266 |
+
default_precision: str = DEFAULT_PRECISION,
|
| 267 |
+
default_num_steps: int = DEFAULT_NUM_STEPS,
|
| 268 |
+
default_guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
|
| 269 |
+
default_speaker_scale: float = DEFAULT_SPEAKER_SCALE,
|
| 270 |
+
default_max_generate_length: int = DEFAULT_MAX_GENERATE_LENGTH,
|
| 271 |
+
repo_root: Path = REPO_ROOT,
|
| 272 |
+
prompts_dir: Path = DEFAULT_PROMPTS_DIR,
|
| 273 |
+
prompt_source_dir: Path = DEFAULT_PROMPT_SOURCE_DIR,
|
| 274 |
+
) -> GradioAppConfig:
|
| 275 |
+
sync_default_prompt_library(
|
| 276 |
+
source_dir=prompt_source_dir,
|
| 277 |
+
target_dir=prompts_dir,
|
| 278 |
+
)
|
| 279 |
+
discovered_models = discover_local_model_choices(repo_root=repo_root)
|
| 280 |
+
prompt_presets = discover_prompt_presets(
|
| 281 |
+
prompts_dir=prompts_dir,
|
| 282 |
+
mapping_file=prompts_dir / "prompt_text",
|
| 283 |
+
)
|
| 284 |
+
resolved_default_prompt_name, default_prompt_audio_path, default_prompt_text = (
|
| 285 |
+
resolve_default_prompt_selection(
|
| 286 |
+
prompt_presets,
|
| 287 |
+
default_prompt_name=default_prompt_name,
|
| 288 |
+
)
|
| 289 |
+
)
|
| 290 |
+
selected_model_name_or_path = (
|
| 291 |
+
model_name_or_path.strip()
|
| 292 |
+
if model_name_or_path is not None
|
| 293 |
+
else default_model_name_or_path(repo_root=repo_root)
|
| 294 |
+
)
|
| 295 |
+
if not selected_model_name_or_path:
|
| 296 |
+
raise ValueError("No default model found. Please pass --model-name-or-path.")
|
| 297 |
+
if execution_mode not in ("generate", "generate_stream"):
|
| 298 |
+
raise ValueError(f"Unsupported execution_mode: {execution_mode}")
|
| 299 |
+
resolved_max_generate_length = int(max_generate_length)
|
| 300 |
+
if resolved_max_generate_length <= 0:
|
| 301 |
+
raise ValueError("max_generate_length must be positive.")
|
| 302 |
+
resolved_precision = precision.strip() or DEFAULT_PRECISION
|
| 303 |
+
logger.info(
|
| 304 |
+
"Gradio app config prepared: host={} port={} output_dir={} "
|
| 305 |
+
"output_retention_count={} max_generate_length={} execution_mode={} precision={} optimize={} "
|
| 306 |
+
"default_model_name_or_path={} prompt_preset_count={} language_count={} local_model_choice_count={}",
|
| 307 |
+
host,
|
| 308 |
+
port,
|
| 309 |
+
output_dir,
|
| 310 |
+
output_retention_count,
|
| 311 |
+
resolved_max_generate_length,
|
| 312 |
+
execution_mode,
|
| 313 |
+
resolved_precision,
|
| 314 |
+
bool(optimize),
|
| 315 |
+
selected_model_name_or_path,
|
| 316 |
+
len(prompt_presets),
|
| 317 |
+
len(SUPPORTED_LANGUAGE_CODE_BY_NAME),
|
| 318 |
+
len(discovered_models),
|
| 319 |
+
)
|
| 320 |
+
return GradioAppConfig(
|
| 321 |
+
host=host,
|
| 322 |
+
port=int(port),
|
| 323 |
+
execution_mode=execution_mode,
|
| 324 |
+
precision=resolved_precision,
|
| 325 |
+
optimize=bool(optimize),
|
| 326 |
+
output_dir=Path(output_dir),
|
| 327 |
+
prompts_dir=Path(prompts_dir),
|
| 328 |
+
output_retention_count=int(output_retention_count),
|
| 329 |
+
max_generate_length=resolved_max_generate_length,
|
| 330 |
+
default_model_name_or_path=selected_model_name_or_path,
|
| 331 |
+
prompt_presets=prompt_presets,
|
| 332 |
+
default_prompt_name=resolved_default_prompt_name,
|
| 333 |
+
default_prompt_audio_path=default_prompt_audio_path,
|
| 334 |
+
default_prompt_text=default_prompt_text,
|
| 335 |
+
default_precision=default_precision,
|
| 336 |
+
default_num_steps=int(default_num_steps),
|
| 337 |
+
default_guidance_scale=float(default_guidance_scale),
|
| 338 |
+
default_speaker_scale=float(default_speaker_scale),
|
| 339 |
+
default_max_generate_length=int(default_max_generate_length),
|
| 340 |
+
local_model_choices=tuple(discovered_models),
|
| 341 |
+
repo_root=repo_root,
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
@dataclass(frozen=True)
|
| 346 |
+
class SynthesisRequest:
|
| 347 |
+
model_name_or_path: str
|
| 348 |
+
text: str
|
| 349 |
+
prompt_audio_path: str | None = None
|
| 350 |
+
prompt_text: str | None = None
|
| 351 |
+
execution_mode: ExecutionMode = DEFAULT_EXECUTION_MODE
|
| 352 |
+
template_name: str = "tts"
|
| 353 |
+
language: str | None = None
|
| 354 |
+
ode_method: str = DEFAULT_ODE_METHOD
|
| 355 |
+
num_steps: int = DEFAULT_NUM_STEPS
|
| 356 |
+
guidance_scale: float = DEFAULT_GUIDANCE_SCALE
|
| 357 |
+
speaker_scale: float = DEFAULT_SPEAKER_SCALE
|
| 358 |
+
normalize_text: bool = False
|
| 359 |
+
seed: int = DEFAULT_SEED
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
@dataclass(frozen=True)
|
| 363 |
+
class SynthesisResult:
|
| 364 |
+
audio_path: str
|
| 365 |
+
metrics: dict[str, Any]
|
| 366 |
+
status: str
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
class GradioAppService:
|
| 370 |
+
def __init__(self, config: GradioAppConfig):
|
| 371 |
+
self.config = config
|
| 372 |
+
self.config.output_dir.mkdir(parents=True, exist_ok=True)
|
| 373 |
+
self._lock = threading.Lock()
|
| 374 |
+
self._runtime: DotsTtsRuntime | None = None
|
| 375 |
+
self._runtime_model_name_or_path: str | None = None
|
| 376 |
+
logger.info(
|
| 377 |
+
"Gradio service initialized: output_dir={} default_model_name_or_path={} "
|
| 378 |
+
"output_retention_count={} max_generate_length={} execution_mode={} precision={} optimize={}",
|
| 379 |
+
self.config.output_dir,
|
| 380 |
+
self.config.default_model_name_or_path,
|
| 381 |
+
self.config.output_retention_count,
|
| 382 |
+
self.config.max_generate_length,
|
| 383 |
+
self.config.execution_mode,
|
| 384 |
+
self.config.precision,
|
| 385 |
+
self.config.optimize,
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
def metadata(self) -> dict[str, Any]:
|
| 389 |
+
return {
|
| 390 |
+
"repo_root": str(self.config.repo_root),
|
| 391 |
+
"default_model_name_or_path": self.config.default_model_name_or_path,
|
| 392 |
+
"local_model_choices": list(self.config.local_model_choices),
|
| 393 |
+
"prompts_dir": str(self.config.prompts_dir),
|
| 394 |
+
"prompt_preset_names": [preset.name for preset in self.config.prompt_presets],
|
| 395 |
+
"default_prompt_name": self.config.default_prompt_name,
|
| 396 |
+
"output_dir": str(self.config.output_dir),
|
| 397 |
+
"output_retention_count": self.config.output_retention_count,
|
| 398 |
+
"configured_max_generate_length": self.config.max_generate_length,
|
| 399 |
+
"configured_execution_mode": self.config.execution_mode,
|
| 400 |
+
"configured_precision": self.config.precision,
|
| 401 |
+
"optimize": self.config.optimize,
|
| 402 |
+
"loaded_model_name_or_path": self._runtime_model_name_or_path,
|
| 403 |
+
"loaded_max_generate_length": (
|
| 404 |
+
self.config.max_generate_length if self._runtime is not None else None
|
| 405 |
+
),
|
| 406 |
+
"loaded_precision": (
|
| 407 |
+
self.config.precision if self._runtime is not None else None
|
| 408 |
+
),
|
| 409 |
+
"model_loaded": self._runtime is not None,
|
| 410 |
+
"host": self.config.host,
|
| 411 |
+
"port": self.config.port,
|
| 412 |
+
"default_precision": self.config.default_precision,
|
| 413 |
+
"default_num_steps": self.config.default_num_steps,
|
| 414 |
+
"default_guidance_scale": self.config.default_guidance_scale,
|
| 415 |
+
"default_speaker_scale": self.config.default_speaker_scale,
|
| 416 |
+
"default_max_generate_length": self.config.default_max_generate_length,
|
| 417 |
+
"supported_languages": build_language_choice_items()[1:],
|
| 418 |
+
"supported_template_names": list(GRADIO_SYNTHESIS_MODE_TEMPLATE_NAMES),
|
| 419 |
+
}
|
| 420 |
+
|
| 421 |
+
def _get_runtime(
|
| 422 |
+
self,
|
| 423 |
+
model_name_or_path: str,
|
| 424 |
+
) -> tuple[DotsTtsRuntime, str]:
|
| 425 |
+
resolved_model_name_or_path = resolve_model_name_or_path(
|
| 426 |
+
model_name_or_path,
|
| 427 |
+
repo_root=self.config.repo_root,
|
| 428 |
+
)
|
| 429 |
+
if (
|
| 430 |
+
self._runtime is None
|
| 431 |
+
or self._runtime_model_name_or_path != resolved_model_name_or_path
|
| 432 |
+
):
|
| 433 |
+
logger.info(
|
| 434 |
+
"Gradio runtime cache miss: requested_model={} resolved_model={} "
|
| 435 |
+
"max_generate_length={} execution_mode={} precision={} optimize={}",
|
| 436 |
+
model_name_or_path,
|
| 437 |
+
resolved_model_name_or_path,
|
| 438 |
+
self.config.max_generate_length,
|
| 439 |
+
self.config.execution_mode,
|
| 440 |
+
self.config.precision,
|
| 441 |
+
self.config.optimize,
|
| 442 |
+
)
|
| 443 |
+
self._runtime = DotsTtsRuntime.from_pretrained(
|
| 444 |
+
resolved_model_name_or_path,
|
| 445 |
+
precision=self.config.precision,
|
| 446 |
+
optimize=self.config.optimize,
|
| 447 |
+
max_generate_length=self.config.max_generate_length,
|
| 448 |
+
)
|
| 449 |
+
self._runtime_model_name_or_path = resolved_model_name_or_path
|
| 450 |
+
else:
|
| 451 |
+
logger.info(
|
| 452 |
+
"Gradio runtime cache hit: requested_model={} resolved_model={} "
|
| 453 |
+
"max_generate_length={} execution_mode={} precision={} optimize={}",
|
| 454 |
+
model_name_or_path,
|
| 455 |
+
resolved_model_name_or_path,
|
| 456 |
+
self.config.max_generate_length,
|
| 457 |
+
self.config.execution_mode,
|
| 458 |
+
self.config.precision,
|
| 459 |
+
self.config.optimize,
|
| 460 |
+
)
|
| 461 |
+
return self._runtime, resolved_model_name_or_path
|
| 462 |
+
|
| 463 |
+
def _build_stream_request_id(
|
| 464 |
+
self,
|
| 465 |
+
runtime: DotsTtsRuntime,
|
| 466 |
+
request: SynthesisRequest,
|
| 467 |
+
) -> str:
|
| 468 |
+
normalized_text, normalized_language = runtime._process_text( # noqa: SLF001
|
| 469 |
+
request.text,
|
| 470 |
+
language=request.language,
|
| 471 |
+
normalize=request.normalize_text,
|
| 472 |
+
)
|
| 473 |
+
normalized_prompt_text = runtime._process_prompt_text( # noqa: SLF001
|
| 474 |
+
request.prompt_text,
|
| 475 |
+
language=normalized_language,
|
| 476 |
+
)
|
| 477 |
+
if normalized_language is not None and not normalized_prompt_text:
|
| 478 |
+
from dots_tts.utils.text import attach_language_tag # noqa: PLC0415
|
| 479 |
+
|
| 480 |
+
normalized_text = attach_language_tag(
|
| 481 |
+
normalized_text,
|
| 482 |
+
normalized_language,
|
| 483 |
+
)
|
| 484 |
+
request_id_kwargs = {
|
| 485 |
+
"text": normalized_text,
|
| 486 |
+
"prompt_audio_path": request.prompt_audio_path,
|
| 487 |
+
"prompt_text": normalized_prompt_text,
|
| 488 |
+
"template_name": request.template_name,
|
| 489 |
+
}
|
| 490 |
+
if normalized_language is not None:
|
| 491 |
+
request_id_kwargs["language"] = normalized_language
|
| 492 |
+
return runtime._build_request_id( # noqa: SLF001
|
| 493 |
+
**request_id_kwargs,
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
@staticmethod
|
| 497 |
+
def _build_runtime_generate_kwargs(request: SynthesisRequest) -> dict[str, Any]:
|
| 498 |
+
runtime_kwargs: dict[str, Any] = {
|
| 499 |
+
"text": request.text,
|
| 500 |
+
"prompt_audio_path": request.prompt_audio_path,
|
| 501 |
+
"prompt_text": request.prompt_text,
|
| 502 |
+
"template_name": request.template_name,
|
| 503 |
+
"ode_method": request.ode_method,
|
| 504 |
+
"num_steps": request.num_steps,
|
| 505 |
+
"guidance_scale": request.guidance_scale,
|
| 506 |
+
"speaker_scale": request.speaker_scale,
|
| 507 |
+
"normalize_text": request.normalize_text,
|
| 508 |
+
}
|
| 509 |
+
if request.language is not None:
|
| 510 |
+
runtime_kwargs["language"] = request.language
|
| 511 |
+
return runtime_kwargs
|
| 512 |
+
|
| 513 |
+
def _run_stream_generation(
|
| 514 |
+
self,
|
| 515 |
+
runtime: DotsTtsRuntime,
|
| 516 |
+
request: SynthesisRequest,
|
| 517 |
+
) -> dict[str, Any]:
|
| 518 |
+
start_time = time.time()
|
| 519 |
+
chunks = [
|
| 520 |
+
chunk.detach().float().cpu()
|
| 521 |
+
for chunk in runtime.generate_stream(
|
| 522 |
+
**self._build_runtime_generate_kwargs(request)
|
| 523 |
+
)
|
| 524 |
+
]
|
| 525 |
+
if not chunks:
|
| 526 |
+
raise ValueError("流式生成未返回任何音频块。")
|
| 527 |
+
|
| 528 |
+
audio = torch.cat(chunks, dim=-1)
|
| 529 |
+
elapsed_seconds = time.time() - start_time
|
| 530 |
+
audio_seconds = audio.shape[-1] / runtime.sample_rate
|
| 531 |
+
rtf = elapsed_seconds / audio_seconds if audio_seconds > 0 else float("inf")
|
| 532 |
+
return {
|
| 533 |
+
"fid": self._build_stream_request_id(runtime, request),
|
| 534 |
+
"audio": audio,
|
| 535 |
+
"sample_rate": runtime.sample_rate,
|
| 536 |
+
"time_used": elapsed_seconds,
|
| 537 |
+
"rtf": rtf,
|
| 538 |
+
"chunk_count": len(chunks),
|
| 539 |
+
}
|
| 540 |
+
|
| 541 |
+
def warmup(self, text: str | None = None) -> dict[str, Any]:
|
| 542 |
+
warmup_text = (text or "").strip() or DEFAULT_WARMUP_TEXT.strip()
|
| 543 |
+
if not warmup_text:
|
| 544 |
+
raise ValueError("DEFAULT_WARMUP_TEXT 不能为空。")
|
| 545 |
+
|
| 546 |
+
with self._lock:
|
| 547 |
+
logger.info(
|
| 548 |
+
"Gradio warmup requested: default_model_name_or_path={} execution_mode={} precision={} optimize={} seed={}",
|
| 549 |
+
self.config.default_model_name_or_path,
|
| 550 |
+
self.config.execution_mode,
|
| 551 |
+
self.config.precision,
|
| 552 |
+
self.config.optimize,
|
| 553 |
+
DEFAULT_SEED,
|
| 554 |
+
)
|
| 555 |
+
try:
|
| 556 |
+
seed_everything(DEFAULT_SEED)
|
| 557 |
+
runtime, resolved_model_name_or_path = self._get_runtime(
|
| 558 |
+
self.config.default_model_name_or_path,
|
| 559 |
+
)
|
| 560 |
+
warmup_request = SynthesisRequest(
|
| 561 |
+
model_name_or_path=self.config.default_model_name_or_path,
|
| 562 |
+
text=warmup_text,
|
| 563 |
+
execution_mode=self.config.execution_mode,
|
| 564 |
+
template_name="tts",
|
| 565 |
+
ode_method=DEFAULT_ODE_METHOD,
|
| 566 |
+
num_steps=self.config.default_num_steps,
|
| 567 |
+
guidance_scale=self.config.default_guidance_scale,
|
| 568 |
+
speaker_scale=self.config.default_speaker_scale,
|
| 569 |
+
normalize_text=False,
|
| 570 |
+
seed=DEFAULT_SEED,
|
| 571 |
+
)
|
| 572 |
+
request_id = self._build_stream_request_id(runtime, warmup_request)
|
| 573 |
+
if self.config.execution_mode == "generate_stream":
|
| 574 |
+
result = self._run_stream_generation(runtime, warmup_request)
|
| 575 |
+
else:
|
| 576 |
+
start_time = time.time()
|
| 577 |
+
result = runtime.generate(**self._build_runtime_generate_kwargs(warmup_request))
|
| 578 |
+
result["time_used"] = time.time() - start_time
|
| 579 |
+
result["chunk_count"] = 1
|
| 580 |
+
audio_samples = int(result["audio"].shape[-1])
|
| 581 |
+
except Exception:
|
| 582 |
+
logger.exception(
|
| 583 |
+
"Gradio warmup failed: default_model_name_or_path={}",
|
| 584 |
+
self.config.default_model_name_or_path,
|
| 585 |
+
)
|
| 586 |
+
raise
|
| 587 |
+
audio_seconds = audio_samples / runtime.sample_rate
|
| 588 |
+
metrics = {
|
| 589 |
+
"request_id": request_id,
|
| 590 |
+
"execution_mode": self.config.execution_mode,
|
| 591 |
+
"chunk_count": int(result["chunk_count"]),
|
| 592 |
+
"resolved_model_name_or_path": resolved_model_name_or_path,
|
| 593 |
+
"sample_rate": runtime.sample_rate,
|
| 594 |
+
"elapsed_seconds": round(float(result["time_used"]), 3),
|
| 595 |
+
"audio_seconds": round(float(audio_seconds), 3),
|
| 596 |
+
"rtf": round(float(result["rtf"]), 4),
|
| 597 |
+
"seed": DEFAULT_SEED,
|
| 598 |
+
"text": warmup_text,
|
| 599 |
+
}
|
| 600 |
+
logger.info(
|
| 601 |
+
"Gradio warmup ready: request_id={} execution_mode={} resolved_model_name_or_path={}",
|
| 602 |
+
metrics["request_id"],
|
| 603 |
+
metrics["execution_mode"],
|
| 604 |
+
metrics["resolved_model_name_or_path"],
|
| 605 |
+
)
|
| 606 |
+
return metrics
|
| 607 |
+
|
| 608 |
+
def _normalize_request(self, request: SynthesisRequest) -> SynthesisRequest:
|
| 609 |
+
normalized_text = request.text.strip()
|
| 610 |
+
if not normalized_text:
|
| 611 |
+
raise ValueError("text 不能为空。")
|
| 612 |
+
|
| 613 |
+
normalized_prompt_audio_path = request.prompt_audio_path or None
|
| 614 |
+
normalized_prompt_text = (request.prompt_text or "").strip() or None
|
| 615 |
+
if normalized_prompt_text and not normalized_prompt_audio_path:
|
| 616 |
+
raise ValueError("prompt_text requires prompt_audio_path.")
|
| 617 |
+
normalized_template_name = request.template_name.strip() or "tts"
|
| 618 |
+
if normalized_template_name not in GRADIO_SYNTHESIS_MODE_TEMPLATE_NAMES:
|
| 619 |
+
raise ValueError(
|
| 620 |
+
f"Unsupported template_name={normalized_template_name!r}. "
|
| 621 |
+
f"Expected one of {list(GRADIO_SYNTHESIS_MODE_TEMPLATE_NAMES)}."
|
| 622 |
+
)
|
| 623 |
+
normalized_language = (request.language or "").strip() or None
|
| 624 |
+
supported_language_codes = set(SUPPORTED_LANGUAGE_CODE_BY_NAME.values())
|
| 625 |
+
if (
|
| 626 |
+
normalized_language is not None
|
| 627 |
+
and normalized_language not in supported_language_codes
|
| 628 |
+
):
|
| 629 |
+
raise ValueError(
|
| 630 |
+
f"Unsupported language={normalized_language!r}. "
|
| 631 |
+
f"Expected one of {sorted(supported_language_codes)}."
|
| 632 |
+
)
|
| 633 |
+
|
| 634 |
+
resolved_seed = int(request.seed)
|
| 635 |
+
return SynthesisRequest(
|
| 636 |
+
model_name_or_path=request.model_name_or_path.strip(),
|
| 637 |
+
text=normalized_text,
|
| 638 |
+
prompt_audio_path=normalized_prompt_audio_path,
|
| 639 |
+
prompt_text=normalized_prompt_text,
|
| 640 |
+
execution_mode=request.execution_mode,
|
| 641 |
+
template_name=normalized_template_name,
|
| 642 |
+
language=normalized_language,
|
| 643 |
+
ode_method=request.ode_method.strip() or DEFAULT_ODE_METHOD,
|
| 644 |
+
num_steps=int(request.num_steps),
|
| 645 |
+
guidance_scale=float(request.guidance_scale),
|
| 646 |
+
speaker_scale=float(request.speaker_scale),
|
| 647 |
+
normalize_text=bool(request.normalize_text),
|
| 648 |
+
seed=resolved_seed,
|
| 649 |
+
)
|
| 650 |
+
|
| 651 |
+
def _build_output_path(self) -> Path:
|
| 652 |
+
output_name = f"{time.strftime('%Y%m%d-%H%M%S')}-{uuid.uuid4().hex[:8]}.wav"
|
| 653 |
+
return self.config.output_dir / output_name
|
| 654 |
+
|
| 655 |
+
def _cleanup_outputs(self) -> None:
|
| 656 |
+
if self.config.output_retention_count <= 0:
|
| 657 |
+
return
|
| 658 |
+
|
| 659 |
+
wav_files = sorted(
|
| 660 |
+
self.config.output_dir.glob("*.wav"),
|
| 661 |
+
key=lambda path: path.stat().st_mtime,
|
| 662 |
+
reverse=True,
|
| 663 |
+
)
|
| 664 |
+
removed_count = 0
|
| 665 |
+
for stale_file in wav_files[self.config.output_retention_count :]:
|
| 666 |
+
stale_file.unlink(missing_ok=True)
|
| 667 |
+
removed_count += 1
|
| 668 |
+
if removed_count > 0:
|
| 669 |
+
logger.info(
|
| 670 |
+
"Gradio output cleanup completed: removed_files={} retention_limit={}",
|
| 671 |
+
removed_count,
|
| 672 |
+
self.config.output_retention_count,
|
| 673 |
+
)
|
| 674 |
+
|
| 675 |
+
@staticmethod
|
| 676 |
+
def _waveform_to_numpy(audio: torch.Tensor):
|
| 677 |
+
waveform = audio.detach().float().cpu().squeeze()
|
| 678 |
+
if waveform.ndim == 0:
|
| 679 |
+
raise ValueError("生成音频为空。")
|
| 680 |
+
return waveform.numpy()
|
| 681 |
+
|
| 682 |
+
def _write_audio(self, audio: torch.Tensor, sample_rate: int) -> str:
|
| 683 |
+
output_path = self._build_output_path()
|
| 684 |
+
logger.info(
|
| 685 |
+
"Writing synthesized audio: output_path={} sample_rate={} samples={}",
|
| 686 |
+
output_path,
|
| 687 |
+
sample_rate,
|
| 688 |
+
audio.shape[-1],
|
| 689 |
+
)
|
| 690 |
+
sf.write(output_path, self._waveform_to_numpy(audio), sample_rate)
|
| 691 |
+
self._cleanup_outputs()
|
| 692 |
+
logger.info("Synthesized audio written: output_path={}", output_path)
|
| 693 |
+
return str(output_path)
|
| 694 |
+
|
| 695 |
+
def generate(self, request: SynthesisRequest) -> SynthesisResult:
|
| 696 |
+
normalized_request = self._normalize_request(request)
|
| 697 |
+
|
| 698 |
+
with self._lock:
|
| 699 |
+
try:
|
| 700 |
+
seed_everything(normalized_request.seed)
|
| 701 |
+
runtime, resolved_model_name_or_path = self._get_runtime(
|
| 702 |
+
normalized_request.model_name_or_path,
|
| 703 |
+
)
|
| 704 |
+
logger.info(
|
| 705 |
+
"Gradio request accepted: resolved_model_name_or_path={} execution_mode={} seed={}",
|
| 706 |
+
resolved_model_name_or_path,
|
| 707 |
+
normalized_request.execution_mode,
|
| 708 |
+
normalized_request.seed,
|
| 709 |
+
)
|
| 710 |
+
if normalized_request.execution_mode == "generate_stream":
|
| 711 |
+
result = self._run_stream_generation(runtime, normalized_request)
|
| 712 |
+
else:
|
| 713 |
+
result = runtime.generate(
|
| 714 |
+
**self._build_runtime_generate_kwargs(normalized_request)
|
| 715 |
+
)
|
| 716 |
+
result["chunk_count"] = 1
|
| 717 |
+
audio_path = self._write_audio(result["audio"], result["sample_rate"])
|
| 718 |
+
except Exception:
|
| 719 |
+
logger.exception(
|
| 720 |
+
"Gradio request failed: model_name_or_path={} execution_mode={} text_len={} has_prompt_audio={} has_prompt_text={} template_name={} language={} "
|
| 721 |
+
"precision={} ode_method={} num_steps={} guidance_scale={} speaker_scale={} max_generate_length={} "
|
| 722 |
+
"normalize_text={} seed={}",
|
| 723 |
+
normalized_request.model_name_or_path,
|
| 724 |
+
normalized_request.execution_mode,
|
| 725 |
+
len(normalized_request.text),
|
| 726 |
+
bool(normalized_request.prompt_audio_path),
|
| 727 |
+
bool(normalized_request.prompt_text),
|
| 728 |
+
normalized_request.template_name,
|
| 729 |
+
normalized_request.language,
|
| 730 |
+
self.config.precision,
|
| 731 |
+
normalized_request.ode_method,
|
| 732 |
+
normalized_request.num_steps,
|
| 733 |
+
normalized_request.guidance_scale,
|
| 734 |
+
normalized_request.speaker_scale,
|
| 735 |
+
self.config.max_generate_length,
|
| 736 |
+
normalized_request.normalize_text,
|
| 737 |
+
normalized_request.seed,
|
| 738 |
+
)
|
| 739 |
+
raise
|
| 740 |
+
audio_seconds = result["audio"].shape[-1] / result["sample_rate"]
|
| 741 |
+
metrics = {
|
| 742 |
+
"request_id": result["fid"],
|
| 743 |
+
"execution_mode": normalized_request.execution_mode,
|
| 744 |
+
"chunk_count": int(result["chunk_count"]),
|
| 745 |
+
"template_name": normalized_request.template_name,
|
| 746 |
+
"language": normalized_request.language,
|
| 747 |
+
"resolved_model_name_or_path": resolved_model_name_or_path,
|
| 748 |
+
"sample_rate": result["sample_rate"],
|
| 749 |
+
"elapsed_seconds": round(float(result["time_used"]), 3),
|
| 750 |
+
"audio_seconds": round(float(audio_seconds), 3),
|
| 751 |
+
"rtf": round(float(result["rtf"]), 4),
|
| 752 |
+
"seed": normalized_request.seed,
|
| 753 |
+
"output_path": audio_path,
|
| 754 |
+
}
|
| 755 |
+
logger.info(
|
| 756 |
+
"Gradio request output ready: request_id={} execution_mode={} resolved_model_name_or_path={} output_path={}",
|
| 757 |
+
metrics["request_id"],
|
| 758 |
+
metrics["execution_mode"],
|
| 759 |
+
metrics["resolved_model_name_or_path"],
|
| 760 |
+
metrics["output_path"],
|
| 761 |
+
)
|
| 762 |
+
status = (
|
| 763 |
+
f"完成:{Path(audio_path).name} | "
|
| 764 |
+
f"模式 {metrics['execution_mode']} | "
|
| 765 |
+
f"耗时 {metrics['elapsed_seconds']}s | "
|
| 766 |
+
f"音频 {metrics['audio_seconds']}s | "
|
| 767 |
+
f"RTF {metrics['rtf']}"
|
| 768 |
+
)
|
| 769 |
+
return SynthesisResult(
|
| 770 |
+
audio_path=audio_path,
|
| 771 |
+
metrics=metrics,
|
| 772 |
+
status=status,
|
| 773 |
+
)
|
configs/dots_tts.yaml
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
train_data:
|
| 2 |
+
train_audio_sample_rate: 48000
|
| 3 |
+
audio_samples_per_llm_token: 7680
|
| 4 |
+
sources:
|
| 5 |
+
- name: ljspeech_basic
|
| 6 |
+
weight: 1.0
|
| 7 |
+
pipeline: basic
|
| 8 |
+
adapter:
|
| 9 |
+
class_name: JsonlManifestSourceAdapter
|
| 10 |
+
params:
|
| 11 |
+
manifest_path: downloaded_data/ljspeech_48khz_manifest_train.jsonl
|
| 12 |
+
shuffle: true
|
| 13 |
+
- name: ljspeech_interleave
|
| 14 |
+
weight: 1.0
|
| 15 |
+
pipeline: interleave
|
| 16 |
+
adapter:
|
| 17 |
+
class_name: JsonlManifestSourceAdapter
|
| 18 |
+
params:
|
| 19 |
+
manifest_path: downloaded_data/ljspeech_48khz_manifest_train.jsonl
|
| 20 |
+
shuffle: true
|
| 21 |
+
# append other sources here if need
|
| 22 |
+
num_tokens_per_epoch: 2000000
|
| 23 |
+
num_workers: 20
|
| 24 |
+
pin_memory: true
|
| 25 |
+
max_audio_seconds_in_batch: 30.0
|
| 26 |
+
max_text_tokens_in_batch: 2048
|
| 27 |
+
max_samples_per_batch: null
|
| 28 |
+
bucketing_pool_size: 100
|
| 29 |
+
val_data:
|
| 30 |
+
train_audio_sample_rate: 48000
|
| 31 |
+
audio_samples_per_llm_token: 7680
|
| 32 |
+
sources:
|
| 33 |
+
- name: ljspeech_valid_basic
|
| 34 |
+
weight: 1.0
|
| 35 |
+
adapter:
|
| 36 |
+
class_name: JsonlManifestSourceAdapter
|
| 37 |
+
params:
|
| 38 |
+
manifest_path: downloaded_data/ljspeech_48khz_manifest_valid.jsonl
|
| 39 |
+
shuffle: false
|
| 40 |
+
pipeline: basic
|
| 41 |
+
- name: ljspeech_valid_interleave
|
| 42 |
+
weight: 1.0
|
| 43 |
+
pipeline: interleave
|
| 44 |
+
adapter:
|
| 45 |
+
class_name: JsonlManifestSourceAdapter
|
| 46 |
+
params:
|
| 47 |
+
manifest_path: downloaded_data/ljspeech_48khz_manifest_valid.jsonl
|
| 48 |
+
shuffle: false
|
| 49 |
+
pipeline: interleave
|
| 50 |
+
# append other sources here if need
|
| 51 |
+
num_workers: 4
|
| 52 |
+
pin_memory: true
|
| 53 |
+
max_audio_seconds_in_batch: 30.0
|
| 54 |
+
max_text_tokens_in_batch: 2048
|
| 55 |
+
max_samples_per_batch: null
|
| 56 |
+
bucketing_pool_size: 64
|
| 57 |
+
train:
|
| 58 |
+
pretrained_model_path: pretrained_models/pretrain_cpt_decay/latest/model/
|
| 59 |
+
output_dir: debug_train/run_003
|
| 60 |
+
seed: 42
|
| 61 |
+
learning_rate: 1.0e-05
|
| 62 |
+
weight_decay: 0.01
|
| 63 |
+
warmup_steps: 50
|
| 64 |
+
max_train_steps: 500
|
| 65 |
+
gradient_accumulation_steps: 2
|
| 66 |
+
grad_clip_norm: 1
|
| 67 |
+
save_interval: 500
|
| 68 |
+
max_checkpoints_to_keep: 40
|
| 69 |
+
log_interval: 10
|
| 70 |
+
eval_interval: 100
|
| 71 |
+
max_eval_batches: null
|
| 72 |
+
run_eval_on_start: false
|
| 73 |
+
loss:
|
| 74 |
+
ce_weight: 1.0
|
| 75 |
+
fm_weight: 1.0
|
| 76 |
+
eos_weight: 1.0
|
requirements.txt
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spaces>=0.40.1
|
| 2 |
+
torch>=2.8.0
|
| 3 |
+
torchaudio>=2.8.0
|
| 4 |
+
transformers>=4.57.0
|
| 5 |
+
huggingface-hub>=0.36.0
|
| 6 |
+
gradio>=6.16.0
|
| 7 |
+
loguru>=0.7.3
|
| 8 |
+
langcodes[data]>=3.5.0
|
| 9 |
+
einops>=0.8.1
|
| 10 |
+
librosa>=0.11.0
|
| 11 |
+
soundfile>=0.13.1
|
| 12 |
+
numpy>=2.2.6
|
| 13 |
+
pydantic>=2.12.5,<3
|
| 14 |
+
PyYAML>=6.0.3
|
| 15 |
+
safetensors>=0.8.0rc0
|
| 16 |
+
torchdiffeq>=0.2.5
|
| 17 |
+
tqdm>=4.67.1
|
| 18 |
+
lingua-language-detector>=2.1.1
|
| 19 |
+
WeTextProcessing>=1.0.4
|
src/dots_tts/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""dots.tts package."""
|
src/dots_tts/cli.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def parse_args(argv=None):
|
| 8 |
+
parser = argparse.ArgumentParser(description="dots.tts inference CLI.")
|
| 9 |
+
template_choices = ("tts", "instruction_tts", "text_to_audio", "tts_interleave")
|
| 10 |
+
parser.add_argument(
|
| 11 |
+
"--model-name-or-path",
|
| 12 |
+
required=True,
|
| 13 |
+
help="Local pretrained directory or Hugging Face repo id",
|
| 14 |
+
)
|
| 15 |
+
parser.add_argument(
|
| 16 |
+
"--revision", default=None, help="Optional Hugging Face revision"
|
| 17 |
+
)
|
| 18 |
+
parser.add_argument(
|
| 19 |
+
"--cache-dir", default=None, help="Optional Hugging Face cache dir"
|
| 20 |
+
)
|
| 21 |
+
parser.add_argument("--text", type=str, required=True, help="Input text")
|
| 22 |
+
parser.add_argument("--output", default="output.wav", help="Output wav file path")
|
| 23 |
+
parser.add_argument(
|
| 24 |
+
"--precision", type=str, default="bfloat16", help="Inference precision"
|
| 25 |
+
)
|
| 26 |
+
parser.add_argument(
|
| 27 |
+
"--seed",
|
| 28 |
+
type=int,
|
| 29 |
+
default=42,
|
| 30 |
+
help="Random seed for inference.",
|
| 31 |
+
)
|
| 32 |
+
parser.add_argument(
|
| 33 |
+
"--prompt-audio", type=str, default=None, help="Path to prompt audio"
|
| 34 |
+
)
|
| 35 |
+
parser.add_argument(
|
| 36 |
+
"--prompt-text", type=str, default=None, help="Transcript of prompt audio"
|
| 37 |
+
)
|
| 38 |
+
parser.add_argument(
|
| 39 |
+
"--language",
|
| 40 |
+
type=str,
|
| 41 |
+
default=None,
|
| 42 |
+
help="Language tag mode. Default: none. Supported values: none, auto_detect, or a language code/name such as EN/en/english/chinese.",
|
| 43 |
+
)
|
| 44 |
+
parser.add_argument(
|
| 45 |
+
"--template-name",
|
| 46 |
+
choices=template_choices,
|
| 47 |
+
default=None,
|
| 48 |
+
help="Named template preset for generation.",
|
| 49 |
+
)
|
| 50 |
+
parser.add_argument(
|
| 51 |
+
"--ode-method", type=str, default="euler", help="ODE solver method"
|
| 52 |
+
)
|
| 53 |
+
parser.add_argument(
|
| 54 |
+
"--num-steps", type=int, default=10, help="Diffusion sampling steps"
|
| 55 |
+
)
|
| 56 |
+
parser.add_argument(
|
| 57 |
+
"--guidance-scale",
|
| 58 |
+
type=float,
|
| 59 |
+
default=1.2,
|
| 60 |
+
help="Classifier-free guidance scale",
|
| 61 |
+
)
|
| 62 |
+
parser.add_argument(
|
| 63 |
+
"--speaker-scale",
|
| 64 |
+
type=float,
|
| 65 |
+
default=1.5,
|
| 66 |
+
help="Scale applied to the reference speaker embedding",
|
| 67 |
+
)
|
| 68 |
+
parser.add_argument(
|
| 69 |
+
"--max-generate-length",
|
| 70 |
+
type=int,
|
| 71 |
+
default=500,
|
| 72 |
+
help="Maximum total audio patch count (prompt + generated)",
|
| 73 |
+
)
|
| 74 |
+
parser.add_argument(
|
| 75 |
+
"--normalize-text",
|
| 76 |
+
action="store_true",
|
| 77 |
+
help="Whether to normalize text before inference",
|
| 78 |
+
)
|
| 79 |
+
parser.add_argument(
|
| 80 |
+
"--profile-inference",
|
| 81 |
+
action="store_true",
|
| 82 |
+
help="Collect per-module inference timing statistics",
|
| 83 |
+
)
|
| 84 |
+
return parser.parse_args(argv)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def main(argv=None):
|
| 88 |
+
args = parse_args(argv)
|
| 89 |
+
import soundfile as sf
|
| 90 |
+
from loguru import logger
|
| 91 |
+
|
| 92 |
+
from dots_tts.runtime import DotsTtsRuntime
|
| 93 |
+
from dots_tts.utils.logging import configure_logging
|
| 94 |
+
from dots_tts.utils.util import seed_everything
|
| 95 |
+
|
| 96 |
+
configure_logging()
|
| 97 |
+
seed_everything(args.seed)
|
| 98 |
+
output_path = Path(args.output)
|
| 99 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 100 |
+
|
| 101 |
+
logger.info(
|
| 102 |
+
"CLI command started: model={} output={} seed={}",
|
| 103 |
+
args.model_name_or_path,
|
| 104 |
+
output_path,
|
| 105 |
+
args.seed,
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
try:
|
| 109 |
+
runtime = DotsTtsRuntime.from_pretrained(
|
| 110 |
+
args.model_name_or_path,
|
| 111 |
+
revision=args.revision,
|
| 112 |
+
cache_dir=args.cache_dir,
|
| 113 |
+
precision=args.precision,
|
| 114 |
+
max_generate_length=args.max_generate_length,
|
| 115 |
+
)
|
| 116 |
+
result = runtime.generate(
|
| 117 |
+
text=args.text,
|
| 118 |
+
prompt_audio_path=args.prompt_audio,
|
| 119 |
+
prompt_text=args.prompt_text,
|
| 120 |
+
language=args.language,
|
| 121 |
+
template_name=args.template_name,
|
| 122 |
+
ode_method=args.ode_method,
|
| 123 |
+
num_steps=args.num_steps,
|
| 124 |
+
guidance_scale=args.guidance_scale,
|
| 125 |
+
speaker_scale=args.speaker_scale,
|
| 126 |
+
normalize_text=args.normalize_text,
|
| 127 |
+
profile_inference=args.profile_inference,
|
| 128 |
+
)
|
| 129 |
+
sf.write(
|
| 130 |
+
output_path,
|
| 131 |
+
result["audio"].float().cpu().squeeze().numpy(),
|
| 132 |
+
result["sample_rate"],
|
| 133 |
+
)
|
| 134 |
+
except Exception:
|
| 135 |
+
logger.exception(
|
| 136 |
+
"CLI inference failed: model={} output={}",
|
| 137 |
+
args.model_name_or_path,
|
| 138 |
+
output_path,
|
| 139 |
+
)
|
| 140 |
+
raise
|
| 141 |
+
|
| 142 |
+
logger.info(
|
| 143 |
+
"CLI output written: request_id={} output={} sample_rate={} samples={}",
|
| 144 |
+
result["fid"],
|
| 145 |
+
output_path,
|
| 146 |
+
result["sample_rate"],
|
| 147 |
+
int(result["audio"].shape[-1]),
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
if __name__ == "__main__":
|
| 152 |
+
raise SystemExit(main())
|
src/dots_tts/config/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Configuration package."""
|
src/dots_tts/config/app.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import yaml
|
| 6 |
+
|
| 7 |
+
from dots_tts.config.base import StrictConfigBase
|
| 8 |
+
from dots_tts.config.data import DataConfig
|
| 9 |
+
from dots_tts.config.train import TrainConfig
|
| 10 |
+
from dots_tts.models.dots_tts.config import LossConfig
|
| 11 |
+
|
| 12 |
+
DEFAULT_CONFIG_PATH = "configs/dots_tts.yaml"
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class AppConfig(StrictConfigBase):
|
| 16 |
+
train_data: DataConfig
|
| 17 |
+
val_data: DataConfig | None = None
|
| 18 |
+
loss: LossConfig
|
| 19 |
+
train: TrainConfig
|
| 20 |
+
|
| 21 |
+
@classmethod
|
| 22 |
+
def from_yaml(cls, config_path: str = DEFAULT_CONFIG_PATH) -> AppConfig:
|
| 23 |
+
with Path(config_path).open(encoding="utf-8") as fin:
|
| 24 |
+
raw_config = yaml.safe_load(fin)
|
| 25 |
+
return cls.model_validate(raw_config)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def load_config(config_path: str = DEFAULT_CONFIG_PATH) -> AppConfig:
|
| 29 |
+
return AppConfig.from_yaml(config_path)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
__all__ = ["AppConfig", "DEFAULT_CONFIG_PATH", "load_config"]
|
src/dots_tts/config/base.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
from pydantic import BaseModel, ConfigDict
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ConfigBase(BaseModel):
|
| 9 |
+
model_config = ConfigDict(
|
| 10 |
+
extra="allow",
|
| 11 |
+
validate_assignment=True,
|
| 12 |
+
arbitrary_types_allowed=True,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
def get(self, key: str, default=None):
|
| 16 |
+
value = getattr(self, key, default)
|
| 17 |
+
if value is default:
|
| 18 |
+
return value
|
| 19 |
+
|
| 20 |
+
fields_set = self.model_fields_set
|
| 21 |
+
if value is None and key not in fields_set:
|
| 22 |
+
return default
|
| 23 |
+
return value
|
| 24 |
+
|
| 25 |
+
def to_dict(self) -> dict[str, Any]:
|
| 26 |
+
return self.model_dump(exclude_none=True)
|
| 27 |
+
|
| 28 |
+
@classmethod
|
| 29 |
+
def _declared_field_names(cls) -> list[str]:
|
| 30 |
+
return [name for name in cls.model_fields if name != "model_config"]
|
| 31 |
+
|
| 32 |
+
@classmethod
|
| 33 |
+
def _serialize_declared_value(cls, value):
|
| 34 |
+
if isinstance(value, ConfigBase):
|
| 35 |
+
return value.to_declared_dict()
|
| 36 |
+
if isinstance(value, list):
|
| 37 |
+
return [cls._serialize_declared_value(item) for item in value]
|
| 38 |
+
if isinstance(value, tuple):
|
| 39 |
+
return [cls._serialize_declared_value(item) for item in value]
|
| 40 |
+
if isinstance(value, dict):
|
| 41 |
+
return {
|
| 42 |
+
key: cls._serialize_declared_value(item) for key, item in value.items()
|
| 43 |
+
}
|
| 44 |
+
return value
|
| 45 |
+
|
| 46 |
+
def to_declared_dict(self) -> dict[str, Any]:
|
| 47 |
+
data = {}
|
| 48 |
+
for name in self._declared_field_names():
|
| 49 |
+
value = getattr(self, name, None)
|
| 50 |
+
if value is None:
|
| 51 |
+
continue
|
| 52 |
+
data[name] = self._serialize_declared_value(value)
|
| 53 |
+
return data
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class StrictConfigBase(ConfigBase):
|
| 57 |
+
model_config = ConfigDict(
|
| 58 |
+
extra="forbid",
|
| 59 |
+
validate_assignment=True,
|
| 60 |
+
arbitrary_types_allowed=True,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
__all__ = ["ConfigBase", "StrictConfigBase"]
|
src/dots_tts/config/data.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any, Literal
|
| 4 |
+
|
| 5 |
+
from pydantic import Field, model_validator
|
| 6 |
+
|
| 7 |
+
from dots_tts.config.base import StrictConfigBase
|
| 8 |
+
|
| 9 |
+
DEFAULT_SOURCE_ADAPTER_CLASS_NAME = "JsonlManifestSourceAdapter"
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class SourceAdapterConfig(StrictConfigBase):
|
| 13 |
+
class_name: Literal["JsonlManifestSourceAdapter"] = (
|
| 14 |
+
DEFAULT_SOURCE_ADAPTER_CLASS_NAME
|
| 15 |
+
)
|
| 16 |
+
params: dict[str, Any] = Field(default_factory=dict)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class DataSourceConfig(StrictConfigBase):
|
| 20 |
+
name: str
|
| 21 |
+
weight: float = Field(default=1.0, gt=0.0)
|
| 22 |
+
pipeline: Literal["basic", "interleave"] = "basic"
|
| 23 |
+
adapter: SourceAdapterConfig = Field(default_factory=SourceAdapterConfig)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class DataConfig(StrictConfigBase):
|
| 27 |
+
sources: list[DataSourceConfig]
|
| 28 |
+
train_audio_sample_rate: int = Field(ge=1)
|
| 29 |
+
audio_samples_per_llm_token: int = Field(ge=1)
|
| 30 |
+
num_tokens_per_epoch: int | None = Field(
|
| 31 |
+
default=None,
|
| 32 |
+
ge=1,
|
| 33 |
+
description="Global token budget across all ranks for one training epoch.",
|
| 34 |
+
)
|
| 35 |
+
num_workers: int = Field(default=0, ge=0)
|
| 36 |
+
pin_memory: bool = False
|
| 37 |
+
prefetch_factor: int = Field(
|
| 38 |
+
default=2,
|
| 39 |
+
ge=1,
|
| 40 |
+
description="Samples prefetched by each DataLoader worker.",
|
| 41 |
+
)
|
| 42 |
+
max_audio_seconds_in_batch: float = Field(gt=0.0)
|
| 43 |
+
max_text_tokens_in_batch: int = Field(ge=1)
|
| 44 |
+
max_samples_per_batch: int | None = Field(default=None, ge=1)
|
| 45 |
+
bucketing_pool_size: int = Field(default=64, ge=1)
|
| 46 |
+
|
| 47 |
+
@model_validator(mode="after")
|
| 48 |
+
def _validate_unique_source_names(self) -> "DataConfig":
|
| 49 |
+
counts: dict[str, int] = {}
|
| 50 |
+
for source in self.sources:
|
| 51 |
+
counts[source.name] = counts.get(source.name, 0) + 1
|
| 52 |
+
duplicated = [name for name, count in counts.items() if count > 1]
|
| 53 |
+
if duplicated:
|
| 54 |
+
raise ValueError(f"Source names must be unique: {duplicated}")
|
| 55 |
+
return self
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
__all__ = [
|
| 59 |
+
"DEFAULT_SOURCE_ADAPTER_CLASS_NAME",
|
| 60 |
+
"DataConfig",
|
| 61 |
+
"DataSourceConfig",
|
| 62 |
+
"SourceAdapterConfig",
|
| 63 |
+
]
|
src/dots_tts/config/train.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pydantic import Field
|
| 4 |
+
|
| 5 |
+
from dots_tts.config.base import StrictConfigBase
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TrainConfig(StrictConfigBase):
|
| 9 |
+
pretrained_model_path: str
|
| 10 |
+
output_dir: str
|
| 11 |
+
seed: int = 42
|
| 12 |
+
learning_rate: float
|
| 13 |
+
cfg_droprate: float = 0.0
|
| 14 |
+
xvec_drop_rate: float = 0.5
|
| 15 |
+
weight_decay: float = 0.01
|
| 16 |
+
warmup_steps: int = 0
|
| 17 |
+
max_train_steps: int
|
| 18 |
+
gradient_accumulation_steps: int = Field(default=1, ge=1)
|
| 19 |
+
grad_clip_norm: float = 1.0
|
| 20 |
+
save_interval: int = Field(default=1000, ge=1)
|
| 21 |
+
max_checkpoints_to_keep: int = 10
|
| 22 |
+
log_interval: int = Field(default=10, ge=1)
|
| 23 |
+
eval_interval: int | None = Field(default=None, ge=1)
|
| 24 |
+
max_eval_batches: int | None = None
|
| 25 |
+
run_eval_on_start: bool = False
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
__all__ = ["TrainConfig"]
|
src/dots_tts/data/EXTENSION.md
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Data Source Extension Guide
|
| 2 |
+
|
| 3 |
+
This document answers exactly one question: how to plug a new training data source into the current `dots_tts` data pipeline.
|
| 4 |
+
|
| 5 |
+
If you only need to swap in a different JSONL manifest, no code changes are required. To support a new raw data format, you usually only need to add:
|
| 6 |
+
|
| 7 |
+
- one **source adapter**
|
| 8 |
+
- optionally one **sample pipeline**
|
| 9 |
+
|
| 10 |
+
## Data flow
|
| 11 |
+
|
| 12 |
+
1. An **adapter** reads from the raw data source and yields raw samples.
|
| 13 |
+
2. A **pipeline** turns each raw sample into a training sample (1:1).
|
| 14 |
+
3. A **multi-source wrapper** handles mixing across sources and resume state.
|
| 15 |
+
4. `StreamingSampleDataset` / `DataLoader` pulls samples.
|
| 16 |
+
5. `OnlineBatcher` assembles batches and `PadCollator` performs padding.
|
| 17 |
+
|
| 18 |
+
## What an adapter must implement
|
| 19 |
+
|
| 20 |
+
Subclass `BaseSourceAdapter`:
|
| 21 |
+
|
| 22 |
+
```python
|
| 23 |
+
class BaseSourceAdapter(ABC):
|
| 24 |
+
@abstractmethod
|
| 25 |
+
def initial_state(self) -> dict[str, Any]:
|
| 26 |
+
...
|
| 27 |
+
|
| 28 |
+
@abstractmethod
|
| 29 |
+
def iter_samples(
|
| 30 |
+
self,
|
| 31 |
+
context: SourceContext,
|
| 32 |
+
*,
|
| 33 |
+
state: dict[str, Any] | None = None,
|
| 34 |
+
) -> Iterable[dict[str, Any]]:
|
| 35 |
+
...
|
| 36 |
+
|
| 37 |
+
@abstractmethod
|
| 38 |
+
def is_cycle_start_state(self, state: dict[str, Any] | None) -> bool:
|
| 39 |
+
...
|
| 40 |
+
|
| 41 |
+
# Optional — only required when used under WeightedMultiSourceAdapter,
|
| 42 |
+
# which cycles each finite child source independently. The default
|
| 43 |
+
# implementation raises if your adapter never gets re-cycled.
|
| 44 |
+
def advance_cycle(self, state: dict[str, Any] | None) -> dict[str, Any]:
|
| 45 |
+
...
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
Each emitted sample **must** carry these fields:
|
| 49 |
+
|
| 50 |
+
- `fid`
|
| 51 |
+
- `text`
|
| 52 |
+
- `audio`
|
| 53 |
+
- `_adapter_state`
|
| 54 |
+
|
| 55 |
+
Key constraints:
|
| 56 |
+
|
| 57 |
+
- `_adapter_state` must describe **where to resume next**, not the position of the current item.
|
| 58 |
+
- The state must be plain Python data — serializable and recoverable after a restart.
|
| 59 |
+
- If your source needs to be split across workers, use `context.global_worker_id` and `context.global_worker_count` (or subclass `ShardableSourceAdapter` and use its `is_assigned_index` / `shard_items` helpers).
|
| 60 |
+
- If the source will participate in weighted cyclic sampling, you must implement `advance_cycle` and make `is_cycle_start_state` correct — otherwise `WeightedMultiSourceAdapter` cannot detect an empty cycle and will raise.
|
| 61 |
+
|
| 62 |
+
After implementing the adapter, register the class in `dots_tts/data/builders.py::_SOURCE_ADAPTER_CLASSES` so that the YAML config can resolve it by `class_name`.
|
| 63 |
+
|
| 64 |
+
## What a pipeline must implement
|
| 65 |
+
|
| 66 |
+
Pipelines must subclass `BaseSamplePipeline` and perform a strict **1:1** sample transform.
|
| 67 |
+
|
| 68 |
+
Minimum implementation:
|
| 69 |
+
|
| 70 |
+
```python
|
| 71 |
+
class MyPipeline(BaseSamplePipeline):
|
| 72 |
+
def process_sample(self, sample: dict) -> dict:
|
| 73 |
+
sample["text"] = str(sample["text"]).strip()
|
| 74 |
+
return sample
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
Do **not**:
|
| 78 |
+
|
| 79 |
+
- filter samples out
|
| 80 |
+
- expand a single sample into multiple samples
|
| 81 |
+
- assemble batches inside the pipeline
|
| 82 |
+
|
| 83 |
+
`BaseSamplePipeline.__call__` automatically merges the original raw sample (including `_adapter_state` and any extra fields the adapter attached) with whatever your `process_sample` returns. You do not need to copy these fields manually — just return the fields you produced or want to overwrite.
|
| 84 |
+
|
| 85 |
+
To wire a new pipeline into config, also extend `dots_tts/data/builders.py::_build_source_pipeline` so it can be selected by name in YAML.
|
| 86 |
+
|
| 87 |
+
## How multi-source wrappers affect you
|
| 88 |
+
|
| 89 |
+
There are two wrappers in the current codebase:
|
| 90 |
+
|
| 91 |
+
- `SequentialMultiSourceAdapter` — used for validation. Reads sources in the configured order, exhaustively, once.
|
| 92 |
+
- `WeightedMultiSourceAdapter` — used for training. Draws sources by weight, cycles each child source independently when exhausted.
|
| 93 |
+
|
| 94 |
+
Both wrappers **replace** the `_adapter_state` produced by your child adapter with their own resume state before yielding to the dataset. Even so, the child adapter must still emit its own `_adapter_state` — the wrapper reads it to track where each sub-source has read to.
|
| 95 |
+
|
| 96 |
+
## Config
|
| 97 |
+
|
| 98 |
+
Each source is configured independently:
|
| 99 |
+
|
| 100 |
+
```yaml
|
| 101 |
+
train_data:
|
| 102 |
+
sources:
|
| 103 |
+
- name: train_a
|
| 104 |
+
weight: 1.0
|
| 105 |
+
pipeline: basic
|
| 106 |
+
adapter:
|
| 107 |
+
class_name: JsonlManifestSourceAdapter
|
| 108 |
+
params:
|
| 109 |
+
manifest_path: train_a.jsonl
|
| 110 |
+
- name: train_b
|
| 111 |
+
weight: 2.0
|
| 112 |
+
pipeline: interleave
|
| 113 |
+
adapter:
|
| 114 |
+
class_name: JsonlManifestSourceAdapter
|
| 115 |
+
params:
|
| 116 |
+
manifest_path: train_b.jsonl
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
Constraints:
|
| 120 |
+
|
| 121 |
+
- `sources[].name` must be unique within the same `train_data` / `val_data` block (it is used as a dict key for resume state).
|
| 122 |
+
- `sources[].pipeline` is a per-source setting, not shared across the dataset.
|
| 123 |
+
- All sources must ultimately produce the same training-sample structure, since they feed into the same batcher and collator.
|
| 124 |
+
- `class_name` must match a key registered in `_SOURCE_ADAPTER_CLASSES`; `params` is forwarded verbatim as kwargs to the adapter constructor.
|
src/dots_tts/data/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Data package."""
|
src/dots_tts/data/batchers.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import warnings
|
| 4 |
+
from collections.abc import Iterable, Iterator
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
from dots_tts.utils.profiling import ensure_data_profiler
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass(slots=True)
|
| 11 |
+
class BatchDecision:
|
| 12 |
+
dropped_samples: list[dict]
|
| 13 |
+
batch_samples: list[dict]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass(slots=True)
|
| 17 |
+
class _PoolSample:
|
| 18 |
+
sample: dict
|
| 19 |
+
num_audio_tokens: int
|
| 20 |
+
num_text_tokens: int
|
| 21 |
+
arrival_step: int
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class OnlineBatcher:
|
| 25 |
+
def __init__(
|
| 26 |
+
self,
|
| 27 |
+
*,
|
| 28 |
+
max_audio_tokens_in_batch: int,
|
| 29 |
+
max_text_tokens_in_batch: int,
|
| 30 |
+
max_batch_size: int | None,
|
| 31 |
+
sample_pool_size: int,
|
| 32 |
+
profiler=None,
|
| 33 |
+
):
|
| 34 |
+
self.max_audio_tokens_in_batch = max(1, int(max_audio_tokens_in_batch))
|
| 35 |
+
self.max_text_tokens_in_batch = max(1, int(max_text_tokens_in_batch))
|
| 36 |
+
self.max_batch_size = max_batch_size
|
| 37 |
+
self.sample_pool_size = max(1, int(sample_pool_size))
|
| 38 |
+
self.profiler = ensure_data_profiler(profiler)
|
| 39 |
+
|
| 40 |
+
@staticmethod
|
| 41 |
+
def _sort_pool(pool: list[_PoolSample]) -> None:
|
| 42 |
+
pool.sort(
|
| 43 |
+
key=lambda item: (
|
| 44 |
+
item.num_audio_tokens,
|
| 45 |
+
item.num_text_tokens,
|
| 46 |
+
-item.arrival_step,
|
| 47 |
+
),
|
| 48 |
+
reverse=True,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
def _choose_anchor_index(
|
| 52 |
+
self,
|
| 53 |
+
pool: list[_PoolSample],
|
| 54 |
+
*,
|
| 55 |
+
decision_step: int,
|
| 56 |
+
) -> int:
|
| 57 |
+
oldest_waiting_index = -1
|
| 58 |
+
oldest_waiting_step = decision_step
|
| 59 |
+
|
| 60 |
+
for index, item in enumerate(pool):
|
| 61 |
+
waited_steps = decision_step - item.arrival_step
|
| 62 |
+
if waited_steps < self.sample_pool_size:
|
| 63 |
+
continue
|
| 64 |
+
if item.arrival_step <= oldest_waiting_step:
|
| 65 |
+
oldest_waiting_index = index
|
| 66 |
+
oldest_waiting_step = item.arrival_step
|
| 67 |
+
|
| 68 |
+
return 0 if oldest_waiting_index < 0 else oldest_waiting_index
|
| 69 |
+
|
| 70 |
+
def _build_next_decision(
|
| 71 |
+
self,
|
| 72 |
+
pool: list[_PoolSample],
|
| 73 |
+
*,
|
| 74 |
+
decision_step: int,
|
| 75 |
+
) -> BatchDecision:
|
| 76 |
+
dropped_samples: list[dict] = []
|
| 77 |
+
batch_samples: list[dict] = []
|
| 78 |
+
selected_indices: list[int] = []
|
| 79 |
+
anchor_index = self._choose_anchor_index(pool, decision_step=decision_step)
|
| 80 |
+
anchor = pool[anchor_index]
|
| 81 |
+
|
| 82 |
+
exceed_audio_budget = anchor.num_audio_tokens > self.max_audio_tokens_in_batch
|
| 83 |
+
exceed_text_budget = anchor.num_text_tokens > self.max_text_tokens_in_batch
|
| 84 |
+
exceed_batch_size = self.max_batch_size is not None and self.max_batch_size < 1
|
| 85 |
+
if exceed_audio_budget or exceed_text_budget or exceed_batch_size:
|
| 86 |
+
skipped = pool.pop(anchor_index).sample
|
| 87 |
+
dropped_samples.append(skipped)
|
| 88 |
+
warnings.warn(
|
| 89 |
+
"Skipping sample that exceeds batching limits on its own: "
|
| 90 |
+
f"fid={skipped.get('fid')!r}, "
|
| 91 |
+
f"num_audio_tokens={anchor.num_audio_tokens}, "
|
| 92 |
+
f"input_ids_length={anchor.num_text_tokens}, "
|
| 93 |
+
f"max_audio_tokens_in_batch={self.max_audio_tokens_in_batch}, "
|
| 94 |
+
f"max_text_tokens_in_batch={self.max_text_tokens_in_batch}, "
|
| 95 |
+
f"max_batch_size={self.max_batch_size}",
|
| 96 |
+
RuntimeWarning,
|
| 97 |
+
stacklevel=2,
|
| 98 |
+
)
|
| 99 |
+
return BatchDecision(
|
| 100 |
+
dropped_samples=dropped_samples,
|
| 101 |
+
batch_samples=batch_samples,
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
longest_audio_tokens = anchor.num_audio_tokens
|
| 105 |
+
longest_text_tokens = anchor.num_text_tokens
|
| 106 |
+
batch_samples.append(anchor.sample)
|
| 107 |
+
selected_indices.append(anchor_index)
|
| 108 |
+
|
| 109 |
+
for index, item in enumerate(pool):
|
| 110 |
+
if index == anchor_index:
|
| 111 |
+
continue
|
| 112 |
+
if (
|
| 113 |
+
self.max_batch_size is not None
|
| 114 |
+
and len(batch_samples) >= self.max_batch_size
|
| 115 |
+
):
|
| 116 |
+
break
|
| 117 |
+
|
| 118 |
+
proposed_batch_size = len(batch_samples) + 1
|
| 119 |
+
proposed_longest_audio_tokens = max(
|
| 120 |
+
longest_audio_tokens,
|
| 121 |
+
item.num_audio_tokens,
|
| 122 |
+
)
|
| 123 |
+
proposed_longest_text_tokens = max(
|
| 124 |
+
longest_text_tokens,
|
| 125 |
+
item.num_text_tokens,
|
| 126 |
+
)
|
| 127 |
+
if (
|
| 128 |
+
proposed_longest_audio_tokens * proposed_batch_size
|
| 129 |
+
> self.max_audio_tokens_in_batch
|
| 130 |
+
):
|
| 131 |
+
continue
|
| 132 |
+
if (
|
| 133 |
+
proposed_longest_text_tokens * proposed_batch_size
|
| 134 |
+
> self.max_text_tokens_in_batch
|
| 135 |
+
):
|
| 136 |
+
continue
|
| 137 |
+
|
| 138 |
+
batch_samples.append(item.sample)
|
| 139 |
+
selected_indices.append(index)
|
| 140 |
+
longest_audio_tokens = proposed_longest_audio_tokens
|
| 141 |
+
longest_text_tokens = proposed_longest_text_tokens
|
| 142 |
+
|
| 143 |
+
for index in sorted(set(selected_indices), reverse=True):
|
| 144 |
+
pool.pop(index)
|
| 145 |
+
|
| 146 |
+
return BatchDecision(
|
| 147 |
+
dropped_samples=dropped_samples,
|
| 148 |
+
batch_samples=batch_samples,
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
def build_decisions(self, sample_iter: Iterable[dict]) -> Iterator[BatchDecision]:
|
| 152 |
+
pool: list[_PoolSample] = []
|
| 153 |
+
source_exhausted = False
|
| 154 |
+
decision_step = 0
|
| 155 |
+
iterator = iter(sample_iter)
|
| 156 |
+
|
| 157 |
+
while not source_exhausted or pool:
|
| 158 |
+
while not source_exhausted and len(pool) < self.sample_pool_size:
|
| 159 |
+
try:
|
| 160 |
+
sample = next(iterator)
|
| 161 |
+
except StopIteration:
|
| 162 |
+
source_exhausted = True
|
| 163 |
+
break
|
| 164 |
+
pool.append(
|
| 165 |
+
_PoolSample(
|
| 166 |
+
sample=sample,
|
| 167 |
+
num_audio_tokens=int(sample.get("num_audio_tokens", 0)),
|
| 168 |
+
num_text_tokens=int(sample.get("input_ids_length", 0)),
|
| 169 |
+
arrival_step=decision_step,
|
| 170 |
+
)
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
if not pool:
|
| 174 |
+
break
|
| 175 |
+
|
| 176 |
+
profiler = self.profiler
|
| 177 |
+
with profiler.measure("main.sort_pool", count=len(pool)):
|
| 178 |
+
self._sort_pool(pool)
|
| 179 |
+
with profiler.measure("main.build_batch_decision"):
|
| 180 |
+
decision = self._build_next_decision(
|
| 181 |
+
pool,
|
| 182 |
+
decision_step=decision_step,
|
| 183 |
+
)
|
| 184 |
+
if decision.dropped_samples or decision.batch_samples:
|
| 185 |
+
decision_step += 1
|
| 186 |
+
yield decision
|
| 187 |
+
continue
|
| 188 |
+
raise RuntimeError("OnlineBatcher failed to make progress on a non-empty pool.")
|
src/dots_tts/data/builders.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from torch.utils.data import DataLoader
|
| 4 |
+
|
| 5 |
+
from dots_tts.config.data import DataConfig
|
| 6 |
+
from dots_tts.data.pipelines.base import BaseSamplePipeline
|
| 7 |
+
from dots_tts.data.pipelines.tts_pipeline import BasicTtsPipeline, InterleaveTtsPipeline
|
| 8 |
+
from dots_tts.data.source_adapters.jsonl_manifest_adapter import (
|
| 9 |
+
JsonlManifestSourceAdapter,
|
| 10 |
+
)
|
| 11 |
+
from dots_tts.data.source_adapters.multi_source_adapter import (
|
| 12 |
+
SequentialMultiSourceAdapter,
|
| 13 |
+
SourceSpec,
|
| 14 |
+
WeightedMultiSourceAdapter,
|
| 15 |
+
)
|
| 16 |
+
from dots_tts.data.streaming import (
|
| 17 |
+
BatchedDataStream,
|
| 18 |
+
StreamingSampleDataset,
|
| 19 |
+
identity_collate,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
_SOURCE_ADAPTER_CLASSES = {
|
| 23 |
+
"JsonlManifestSourceAdapter": JsonlManifestSourceAdapter,
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _build_source_pipeline(
|
| 28 |
+
tokenizer, data_cfg, pipeline_name: str, *, profiler=None
|
| 29 |
+
) -> BaseSamplePipeline:
|
| 30 |
+
if pipeline_name == "basic":
|
| 31 |
+
return BasicTtsPipeline(tokenizer, data_cfg, profiler=profiler)
|
| 32 |
+
if pipeline_name == "interleave":
|
| 33 |
+
return InterleaveTtsPipeline(tokenizer, data_cfg, profiler=profiler)
|
| 34 |
+
raise ValueError(f"Unsupported data pipeline: {pipeline_name!r}")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _build_source_specs(data_cfg, tokenizer, *, profiler=None) -> list[SourceSpec]:
|
| 38 |
+
specs = []
|
| 39 |
+
for source_cfg in data_cfg.sources:
|
| 40 |
+
adapter_cls = _SOURCE_ADAPTER_CLASSES[source_cfg.adapter.class_name]
|
| 41 |
+
adapter = adapter_cls(**source_cfg.adapter.params)
|
| 42 |
+
specs.append(
|
| 43 |
+
SourceSpec(
|
| 44 |
+
name=source_cfg.name,
|
| 45 |
+
weight=float(source_cfg.weight),
|
| 46 |
+
adapter=adapter,
|
| 47 |
+
pipeline=_build_source_pipeline(
|
| 48 |
+
tokenizer, data_cfg, source_cfg.pipeline, profiler=profiler
|
| 49 |
+
),
|
| 50 |
+
)
|
| 51 |
+
)
|
| 52 |
+
return specs
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _resolve_rank_info(accelerator=None) -> tuple[int, int]:
|
| 56 |
+
rank = (
|
| 57 |
+
int(getattr(accelerator, "process_index", 0)) if accelerator is not None else 0
|
| 58 |
+
)
|
| 59 |
+
world_size = (
|
| 60 |
+
int(getattr(accelerator, "num_processes", 1)) if accelerator is not None else 1
|
| 61 |
+
)
|
| 62 |
+
return rank, world_size
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _local_num_tokens_per_epoch(
|
| 66 |
+
global_num_tokens_per_epoch: int, *, rank: int, world_size: int
|
| 67 |
+
) -> int:
|
| 68 |
+
if world_size <= 0:
|
| 69 |
+
raise ValueError(f"world_size must be positive, but got {world_size}.")
|
| 70 |
+
if rank < 0 or rank >= world_size:
|
| 71 |
+
raise ValueError(
|
| 72 |
+
f"rank must be in [0, {world_size}), but got rank={rank}."
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
base, remainder = divmod(int(global_num_tokens_per_epoch), int(world_size))
|
| 76 |
+
return base + int(rank < remainder)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _build_dataset(
|
| 80 |
+
data_cfg: DataConfig,
|
| 81 |
+
*,
|
| 82 |
+
tokenizer,
|
| 83 |
+
seed: int,
|
| 84 |
+
accelerator=None,
|
| 85 |
+
sequential: bool,
|
| 86 |
+
profiler=None,
|
| 87 |
+
):
|
| 88 |
+
rank, world_size = _resolve_rank_info(accelerator)
|
| 89 |
+
source_cls = SequentialMultiSourceAdapter if sequential else WeightedMultiSourceAdapter
|
| 90 |
+
source = source_cls(
|
| 91 |
+
sources=_build_source_specs(data_cfg, tokenizer, profiler=profiler)
|
| 92 |
+
)
|
| 93 |
+
return StreamingSampleDataset(
|
| 94 |
+
source=source,
|
| 95 |
+
rank=rank,
|
| 96 |
+
world_size=world_size,
|
| 97 |
+
seed=int(seed),
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def build_training_dataset(
|
| 102 |
+
data_cfg: DataConfig,
|
| 103 |
+
tokenizer,
|
| 104 |
+
*,
|
| 105 |
+
seed: int,
|
| 106 |
+
accelerator=None,
|
| 107 |
+
profiler=None,
|
| 108 |
+
):
|
| 109 |
+
if data_cfg.num_tokens_per_epoch is None:
|
| 110 |
+
raise ValueError("Training data requires num_tokens_per_epoch.")
|
| 111 |
+
return _build_dataset(
|
| 112 |
+
data_cfg,
|
| 113 |
+
tokenizer=tokenizer,
|
| 114 |
+
seed=seed,
|
| 115 |
+
accelerator=accelerator,
|
| 116 |
+
sequential=False,
|
| 117 |
+
profiler=profiler,
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def build_validation_dataset(
|
| 122 |
+
data_cfg: DataConfig,
|
| 123 |
+
tokenizer,
|
| 124 |
+
*,
|
| 125 |
+
seed: int,
|
| 126 |
+
accelerator=None,
|
| 127 |
+
profiler=None,
|
| 128 |
+
):
|
| 129 |
+
return _build_dataset(
|
| 130 |
+
data_cfg,
|
| 131 |
+
tokenizer=tokenizer,
|
| 132 |
+
seed=seed,
|
| 133 |
+
accelerator=accelerator,
|
| 134 |
+
sequential=True,
|
| 135 |
+
profiler=profiler,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _build_sample_loader(dataset, data_cfg: DataConfig) -> DataLoader:
|
| 140 |
+
loader_kwargs = {
|
| 141 |
+
"dataset": dataset,
|
| 142 |
+
"batch_size": None,
|
| 143 |
+
"collate_fn": identity_collate,
|
| 144 |
+
"num_workers": data_cfg.num_workers,
|
| 145 |
+
"pin_memory": data_cfg.pin_memory,
|
| 146 |
+
"persistent_workers": data_cfg.num_workers > 0,
|
| 147 |
+
}
|
| 148 |
+
if data_cfg.num_workers > 0:
|
| 149 |
+
loader_kwargs["prefetch_factor"] = int(data_cfg.prefetch_factor)
|
| 150 |
+
sample_loader = DataLoader(**loader_kwargs)
|
| 151 |
+
return sample_loader
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def build_training_dataloader(
|
| 155 |
+
dataset, data_cfg: DataConfig, tokenizer, *, profiler=None
|
| 156 |
+
):
|
| 157 |
+
local_num_tokens_per_epoch = _local_num_tokens_per_epoch(
|
| 158 |
+
int(data_cfg.num_tokens_per_epoch),
|
| 159 |
+
rank=int(dataset.rank),
|
| 160 |
+
world_size=int(dataset.world_size),
|
| 161 |
+
)
|
| 162 |
+
sample_loader = _build_sample_loader(dataset, data_cfg)
|
| 163 |
+
batched_stream = BatchedDataStream(
|
| 164 |
+
sample_dataset=dataset,
|
| 165 |
+
data_cfg=data_cfg,
|
| 166 |
+
tokenizer=tokenizer,
|
| 167 |
+
num_tokens_per_epoch=local_num_tokens_per_epoch,
|
| 168 |
+
profiler=profiler,
|
| 169 |
+
)
|
| 170 |
+
batched_stream.attach_loader(sample_loader)
|
| 171 |
+
return batched_stream
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def build_validation_dataloader(
|
| 175 |
+
dataset, data_cfg: DataConfig, tokenizer, *, profiler=None
|
| 176 |
+
):
|
| 177 |
+
sample_loader = _build_sample_loader(dataset, data_cfg)
|
| 178 |
+
batched_stream = BatchedDataStream(
|
| 179 |
+
sample_dataset=dataset,
|
| 180 |
+
data_cfg=data_cfg,
|
| 181 |
+
tokenizer=tokenizer,
|
| 182 |
+
num_tokens_per_epoch=None,
|
| 183 |
+
profiler=profiler,
|
| 184 |
+
)
|
| 185 |
+
batched_stream.attach_loader(sample_loader)
|
| 186 |
+
return batched_stream
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
__all__ = [
|
| 190 |
+
"build_training_dataloader",
|
| 191 |
+
"build_training_dataset",
|
| 192 |
+
"build_validation_dataloader",
|
| 193 |
+
"build_validation_dataset",
|
| 194 |
+
]
|
src/dots_tts/data/collator.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class PadCollator:
|
| 10 |
+
def __init__(self, tokenizer):
|
| 11 |
+
self.tokenizer = tokenizer
|
| 12 |
+
self.pad_token_id = tokenizer.pad_token_id
|
| 13 |
+
if self.pad_token_id is None:
|
| 14 |
+
self.pad_token_id = tokenizer.eos_token_id or 0
|
| 15 |
+
|
| 16 |
+
def __call__(self, samples: list[dict[str, Any]]) -> dict[str, Any]:
|
| 17 |
+
if not samples:
|
| 18 |
+
raise ValueError("PadCollator received an empty sample list.")
|
| 19 |
+
|
| 20 |
+
order = sorted(
|
| 21 |
+
range(len(samples)),
|
| 22 |
+
key=lambda idx: samples[idx]["sample_length"],
|
| 23 |
+
reverse=True,
|
| 24 |
+
)
|
| 25 |
+
ordered = [samples[idx] for idx in order]
|
| 26 |
+
|
| 27 |
+
input_ids = [
|
| 28 |
+
torch.tensor(sample["input_ids"], dtype=torch.long) for sample in ordered
|
| 29 |
+
]
|
| 30 |
+
labels = [
|
| 31 |
+
torch.tensor(sample["labels"], dtype=torch.long) for sample in ordered
|
| 32 |
+
]
|
| 33 |
+
loss_masks = [
|
| 34 |
+
torch.tensor(sample["loss_mask"], dtype=torch.float32) for sample in ordered
|
| 35 |
+
]
|
| 36 |
+
waveforms = [sample["sample"].squeeze(0) for sample in ordered]
|
| 37 |
+
fbank = [sample["fbank"] for sample in ordered]
|
| 38 |
+
|
| 39 |
+
return {
|
| 40 |
+
"fids": [sample["fid"] for sample in ordered],
|
| 41 |
+
"source_names": [sample.get("source_name") for sample in ordered],
|
| 42 |
+
"input_ids": pad_sequence(
|
| 43 |
+
input_ids,
|
| 44 |
+
batch_first=True,
|
| 45 |
+
padding_value=self.pad_token_id,
|
| 46 |
+
),
|
| 47 |
+
"input_ids_lengths": torch.tensor(
|
| 48 |
+
[len(sample["input_ids"]) for sample in ordered],
|
| 49 |
+
dtype=torch.long,
|
| 50 |
+
),
|
| 51 |
+
"labels": pad_sequence(
|
| 52 |
+
labels,
|
| 53 |
+
batch_first=True,
|
| 54 |
+
padding_value=self.pad_token_id,
|
| 55 |
+
),
|
| 56 |
+
"loss_mask": pad_sequence(
|
| 57 |
+
loss_masks,
|
| 58 |
+
batch_first=True,
|
| 59 |
+
padding_value=0.0,
|
| 60 |
+
),
|
| 61 |
+
"sample": pad_sequence(
|
| 62 |
+
waveforms,
|
| 63 |
+
batch_first=True,
|
| 64 |
+
padding_value=0.0,
|
| 65 |
+
).unsqueeze(1),
|
| 66 |
+
"sample_lengths": torch.tensor(
|
| 67 |
+
[sample["sample_length"] for sample in ordered],
|
| 68 |
+
dtype=torch.long,
|
| 69 |
+
),
|
| 70 |
+
"num_text_tokens": torch.tensor(
|
| 71 |
+
[sample["num_text_tokens"] for sample in ordered],
|
| 72 |
+
dtype=torch.long,
|
| 73 |
+
),
|
| 74 |
+
"num_audio_tokens": torch.tensor(
|
| 75 |
+
[sample["num_audio_tokens"] for sample in ordered],
|
| 76 |
+
dtype=torch.long,
|
| 77 |
+
),
|
| 78 |
+
"fbank": pad_sequence(
|
| 79 |
+
fbank,
|
| 80 |
+
batch_first=True,
|
| 81 |
+
padding_value=0.0,
|
| 82 |
+
),
|
| 83 |
+
"fbank_lengths": torch.tensor(
|
| 84 |
+
[sample["fbank_length"] for sample in ordered],
|
| 85 |
+
dtype=torch.long,
|
| 86 |
+
),
|
| 87 |
+
}
|
src/dots_tts/data/pipelines/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Data pipelines package."""
|
src/dots_tts/data/pipelines/base.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from collections.abc import Iterable, Iterator
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class BaseSamplePipeline(ABC):
|
| 8 |
+
"""1:1 sample pipeline that preserves adapter resume metadata."""
|
| 9 |
+
|
| 10 |
+
@staticmethod
|
| 11 |
+
def _validate_input_sample(sample: dict) -> None:
|
| 12 |
+
if "_adapter_state" not in sample:
|
| 13 |
+
raise RuntimeError(
|
| 14 |
+
"Source sample is missing required '_adapter_state' for resume."
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
@abstractmethod
|
| 18 |
+
def process_sample(self, sample: dict) -> dict:
|
| 19 |
+
"""Transform one raw sample into one processed sample."""
|
| 20 |
+
|
| 21 |
+
def __call__(self, samples: Iterable[dict]) -> Iterator[dict]:
|
| 22 |
+
for raw_sample in samples:
|
| 23 |
+
self._validate_input_sample(raw_sample)
|
| 24 |
+
processed = self.process_sample(dict(raw_sample))
|
| 25 |
+
if not isinstance(processed, dict):
|
| 26 |
+
raise RuntimeError(
|
| 27 |
+
f"{self.__class__.__name__}.process_sample() must return a dict."
|
| 28 |
+
)
|
| 29 |
+
item = dict(raw_sample)
|
| 30 |
+
item.update(processed)
|
| 31 |
+
self._validate_input_sample(item)
|
| 32 |
+
yield item
|
src/dots_tts/data/pipelines/preprocessing.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
DEFAULT_EDGE_SILENCE_MS = 250.0
|
| 7 |
+
DEFAULT_EDGE_SILENCE_TOP_DB = 30.0
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def align_length(num_samples: int, multiple_of: int | None) -> int:
|
| 11 |
+
if multiple_of is None or multiple_of <= 0:
|
| 12 |
+
return int(num_samples)
|
| 13 |
+
if num_samples % multiple_of == 0:
|
| 14 |
+
return int(num_samples)
|
| 15 |
+
return int(((num_samples + multiple_of - 1) // multiple_of) * multiple_of)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def pad_waveform_align_only(
|
| 19 |
+
waveform: torch.Tensor,
|
| 20 |
+
*,
|
| 21 |
+
multiple_of: int | None,
|
| 22 |
+
) -> torch.Tensor:
|
| 23 |
+
if multiple_of is None or multiple_of <= 0:
|
| 24 |
+
return waveform
|
| 25 |
+
|
| 26 |
+
target_length = align_length(waveform.size(-1), multiple_of)
|
| 27 |
+
delta = target_length - waveform.size(-1)
|
| 28 |
+
if delta <= 0:
|
| 29 |
+
return waveform
|
| 30 |
+
|
| 31 |
+
return F.pad(waveform, (0, delta), "constant", 0.0)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def normalize_edge_silence_duration(
|
| 35 |
+
waveform: torch.Tensor,
|
| 36 |
+
*,
|
| 37 |
+
sample_rate: int,
|
| 38 |
+
target_silence_duration_ms: float = DEFAULT_EDGE_SILENCE_MS,
|
| 39 |
+
top_db: float = DEFAULT_EDGE_SILENCE_TOP_DB,
|
| 40 |
+
) -> torch.Tensor:
|
| 41 |
+
mono_waveform = waveform[0]
|
| 42 |
+
target_samples = int(round(float(sample_rate) * float(target_silence_duration_ms) / 1000.0))
|
| 43 |
+
amplitude = mono_waveform.abs()
|
| 44 |
+
peak = float(amplitude.max().item())
|
| 45 |
+
if peak <= 0.0:
|
| 46 |
+
waveform = waveform[..., :target_samples]
|
| 47 |
+
current_length = int(waveform.size(-1))
|
| 48 |
+
if current_length < target_samples:
|
| 49 |
+
waveform = F.pad(waveform, (0, target_samples - current_length), "constant", 0.0)
|
| 50 |
+
return waveform
|
| 51 |
+
|
| 52 |
+
threshold = peak * (10.0 ** (-float(top_db) / 20.0))
|
| 53 |
+
non_silent = torch.nonzero(amplitude > threshold, as_tuple=False).flatten()
|
| 54 |
+
first_non_silent = int(non_silent[0].item())
|
| 55 |
+
last_non_silent = int(non_silent[-1].item())
|
| 56 |
+
|
| 57 |
+
leading_silence_samples = first_non_silent
|
| 58 |
+
trailing_silence_samples = int(mono_waveform.numel()) - last_non_silent - 1
|
| 59 |
+
|
| 60 |
+
leading_delta = target_samples - leading_silence_samples
|
| 61 |
+
if leading_delta > 0:
|
| 62 |
+
waveform = F.pad(waveform, (leading_delta, 0), "constant", 0.0)
|
| 63 |
+
else:
|
| 64 |
+
trim_from_start = min(-leading_delta, int(waveform.size(-1)))
|
| 65 |
+
waveform = waveform[..., trim_from_start:]
|
| 66 |
+
|
| 67 |
+
trailing_delta = target_samples - trailing_silence_samples
|
| 68 |
+
if trailing_delta > 0:
|
| 69 |
+
return F.pad(waveform, (0, trailing_delta), "constant", 0.0)
|
| 70 |
+
|
| 71 |
+
trim_from_end = min(-trailing_delta, int(waveform.size(-1)))
|
| 72 |
+
if trim_from_end <= 0:
|
| 73 |
+
return waveform
|
| 74 |
+
return waveform[..., :-trim_from_end]
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def compute_num_audio_tokens(
|
| 78 |
+
num_samples: int, *, audio_samples_per_llm_token: int
|
| 79 |
+
) -> int:
|
| 80 |
+
if num_samples % audio_samples_per_llm_token != 0:
|
| 81 |
+
raise ValueError(
|
| 82 |
+
f"Waveform length {num_samples} is not aligned to token hop {audio_samples_per_llm_token}."
|
| 83 |
+
)
|
| 84 |
+
return num_samples // audio_samples_per_llm_token
|
src/dots_tts/data/pipelines/tokenizing.py
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
import re
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
from loguru import logger
|
| 8 |
+
|
| 9 |
+
from dots_tts.utils.tokenizer import (
|
| 10 |
+
AUDIO_GEN_END_TOKEN,
|
| 11 |
+
AUDIO_GEN_SPAN_TOKEN,
|
| 12 |
+
AUDIO_GEN_START_TOKEN,
|
| 13 |
+
TEXT_COND_END_TOKEN,
|
| 14 |
+
require_token_id,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
TEMPLATE_PATTERN = re.compile(r"\{text\}|\{audio\}|\{interleave\}|[^\{]+")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass(frozen=True)
|
| 21 |
+
class ParsedTemplate:
|
| 22 |
+
parts: tuple[str, ...]
|
| 23 |
+
has_audio_placeholder: bool
|
| 24 |
+
has_interleave_placeholder: bool
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass(frozen=True)
|
| 28 |
+
class TokenizedTemplatePart:
|
| 29 |
+
kind: str
|
| 30 |
+
token_ids: tuple[int, ...] = ()
|
| 31 |
+
raw_text: str | None = None
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def parse_template(template: str) -> ParsedTemplate:
|
| 35 |
+
parts = tuple(re.findall(TEMPLATE_PATTERN, template))
|
| 36 |
+
has_audio_placeholder = "{audio}" in parts
|
| 37 |
+
interleave_count = parts.count("{interleave}")
|
| 38 |
+
if has_audio_placeholder and interleave_count:
|
| 39 |
+
raise ValueError("Template cannot mix audio and interleave placeholders.")
|
| 40 |
+
if interleave_count > 1:
|
| 41 |
+
raise ValueError(
|
| 42 |
+
"Interleave generation template must contain exactly one interleave placeholder."
|
| 43 |
+
)
|
| 44 |
+
return ParsedTemplate(
|
| 45 |
+
parts=parts,
|
| 46 |
+
has_audio_placeholder=has_audio_placeholder,
|
| 47 |
+
has_interleave_placeholder=interleave_count == 1,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _prepare_template_tokens(
|
| 52 |
+
*, text: str, tokenizer, template: str
|
| 53 |
+
) -> tuple[ParsedTemplate, list[int]]:
|
| 54 |
+
return parse_template(template), tokenizer.encode(text, add_special_tokens=False)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _iter_tokenized_template_parts(
|
| 58 |
+
*,
|
| 59 |
+
parsed_template: ParsedTemplate,
|
| 60 |
+
tokenizer,
|
| 61 |
+
text_tokens: list[int],
|
| 62 |
+
):
|
| 63 |
+
for part in parsed_template.parts:
|
| 64 |
+
if part == "{text}":
|
| 65 |
+
yield TokenizedTemplatePart(kind="text", token_ids=tuple(text_tokens))
|
| 66 |
+
continue
|
| 67 |
+
if part == "{audio}":
|
| 68 |
+
yield TokenizedTemplatePart(kind="audio")
|
| 69 |
+
continue
|
| 70 |
+
if part == "{interleave}":
|
| 71 |
+
yield TokenizedTemplatePart(kind="interleave")
|
| 72 |
+
continue
|
| 73 |
+
yield TokenizedTemplatePart(
|
| 74 |
+
kind="literal",
|
| 75 |
+
token_ids=tuple(tokenizer.encode(part, add_special_tokens=False)),
|
| 76 |
+
raw_text=part,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _extend_tokens_with_loss(
|
| 81 |
+
*, full_ids: list[int], loss_mask: list[float], token_ids: tuple[int, ...], loss: float
|
| 82 |
+
) -> None:
|
| 83 |
+
full_ids.extend(token_ids)
|
| 84 |
+
loss_mask.extend([loss] * len(token_ids))
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def build_tokenized_example(
|
| 88 |
+
*, text: str, tokenizer, template: str, num_audio_tokens: int
|
| 89 |
+
) -> dict[str, Any]:
|
| 90 |
+
if tokenizer.eos_token_id is None:
|
| 91 |
+
raise ValueError("Tokenizer eos_token_id is required for generation targets.")
|
| 92 |
+
|
| 93 |
+
parsed_template, text_tokens = _prepare_template_tokens(
|
| 94 |
+
text=text,
|
| 95 |
+
tokenizer=tokenizer,
|
| 96 |
+
template=template,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
full_ids: list[int] = []
|
| 100 |
+
loss_mask: list[float] = []
|
| 101 |
+
audio_tokens: list[int] | None = None
|
| 102 |
+
if parsed_template.has_audio_placeholder:
|
| 103 |
+
audio_gen_start_id = require_token_id(tokenizer, AUDIO_GEN_START_TOKEN)
|
| 104 |
+
audio_gen_span_id = require_token_id(tokenizer, AUDIO_GEN_SPAN_TOKEN)
|
| 105 |
+
audio_gen_end_id = require_token_id(tokenizer, AUDIO_GEN_END_TOKEN)
|
| 106 |
+
audio_tokens = (
|
| 107 |
+
[audio_gen_start_id]
|
| 108 |
+
+ [audio_gen_span_id] * num_audio_tokens
|
| 109 |
+
+ [audio_gen_end_id]
|
| 110 |
+
)
|
| 111 |
+
elif parsed_template.has_interleave_placeholder:
|
| 112 |
+
audio_gen_span_id = require_token_id(tokenizer, AUDIO_GEN_SPAN_TOKEN)
|
| 113 |
+
audio_gen_end_id = require_token_id(tokenizer, AUDIO_GEN_END_TOKEN)
|
| 114 |
+
text_cond_end_id = require_token_id(tokenizer, TEXT_COND_END_TOKEN)
|
| 115 |
+
|
| 116 |
+
for part in _iter_tokenized_template_parts(
|
| 117 |
+
parsed_template=parsed_template,
|
| 118 |
+
tokenizer=tokenizer,
|
| 119 |
+
text_tokens=text_tokens,
|
| 120 |
+
):
|
| 121 |
+
if part.kind == "text":
|
| 122 |
+
_extend_tokens_with_loss(
|
| 123 |
+
full_ids=full_ids,
|
| 124 |
+
loss_mask=loss_mask,
|
| 125 |
+
token_ids=part.token_ids,
|
| 126 |
+
loss=0.0,
|
| 127 |
+
)
|
| 128 |
+
continue
|
| 129 |
+
|
| 130 |
+
if part.kind == "audio":
|
| 131 |
+
if audio_tokens is None:
|
| 132 |
+
raise RuntimeError("Audio placeholder tokens were not initialized.")
|
| 133 |
+
full_ids.extend(audio_tokens)
|
| 134 |
+
loss_mask.extend([0.0])
|
| 135 |
+
loss_mask.extend([1.0] * max(0, len(audio_tokens) - 2))
|
| 136 |
+
loss_mask.append(0.0)
|
| 137 |
+
continue
|
| 138 |
+
|
| 139 |
+
if part.kind == "interleave":
|
| 140 |
+
_append_interleave_generation_tokens(
|
| 141 |
+
full_ids=full_ids,
|
| 142 |
+
loss_mask=loss_mask,
|
| 143 |
+
text_tokens=text_tokens,
|
| 144 |
+
num_audio_tokens=num_audio_tokens,
|
| 145 |
+
audio_span_id=audio_gen_span_id,
|
| 146 |
+
audio_end_id=audio_gen_end_id,
|
| 147 |
+
text_cond_end_id=text_cond_end_id,
|
| 148 |
+
)
|
| 149 |
+
continue
|
| 150 |
+
|
| 151 |
+
_extend_tokens_with_loss(
|
| 152 |
+
full_ids=full_ids,
|
| 153 |
+
loss_mask=loss_mask,
|
| 154 |
+
token_ids=part.token_ids,
|
| 155 |
+
loss=0.0,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
full_ids.append(tokenizer.eos_token_id)
|
| 159 |
+
loss_mask.append(0.0)
|
| 160 |
+
|
| 161 |
+
return {
|
| 162 |
+
"input_ids": full_ids[:-1],
|
| 163 |
+
"labels": full_ids[1:],
|
| 164 |
+
"loss_mask": loss_mask[1:],
|
| 165 |
+
"text_token_count": len(text_tokens),
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def build_generation_schedule(
|
| 170 |
+
*,
|
| 171 |
+
text: str,
|
| 172 |
+
tokenizer,
|
| 173 |
+
template: str,
|
| 174 |
+
max_audio_tokens: int,
|
| 175 |
+
) -> dict[str, Any]:
|
| 176 |
+
if max_audio_tokens <= 0:
|
| 177 |
+
raise ValueError("max_audio_tokens must be positive for generation.")
|
| 178 |
+
|
| 179 |
+
parsed_template, text_tokens = _prepare_template_tokens(
|
| 180 |
+
text=text,
|
| 181 |
+
tokenizer=tokenizer,
|
| 182 |
+
template=template,
|
| 183 |
+
)
|
| 184 |
+
schedule_ids: list[int] = []
|
| 185 |
+
audio_gen_start_id = require_token_id(tokenizer, AUDIO_GEN_START_TOKEN)
|
| 186 |
+
audio_gen_span_id = require_token_id(tokenizer, AUDIO_GEN_SPAN_TOKEN)
|
| 187 |
+
|
| 188 |
+
if parsed_template.has_audio_placeholder:
|
| 189 |
+
for part in _iter_tokenized_template_parts(
|
| 190 |
+
parsed_template=parsed_template,
|
| 191 |
+
tokenizer=tokenizer,
|
| 192 |
+
text_tokens=text_tokens,
|
| 193 |
+
):
|
| 194 |
+
if part.kind == "audio":
|
| 195 |
+
schedule_ids.append(audio_gen_start_id)
|
| 196 |
+
schedule_ids.extend([audio_gen_span_id] * max_audio_tokens)
|
| 197 |
+
continue
|
| 198 |
+
schedule_ids.extend(part.token_ids)
|
| 199 |
+
visible_schedule_ids = [
|
| 200 |
+
token_id for token_id in schedule_ids if token_id != audio_gen_span_id
|
| 201 |
+
]
|
| 202 |
+
decoded_schedule = (
|
| 203 |
+
tokenizer.decode(
|
| 204 |
+
visible_schedule_ids,
|
| 205 |
+
skip_special_tokens=False,
|
| 206 |
+
clean_up_tokenization_spaces=False,
|
| 207 |
+
)
|
| 208 |
+
if hasattr(tokenizer, "decode")
|
| 209 |
+
else repr(visible_schedule_ids)
|
| 210 |
+
)
|
| 211 |
+
logger.info(
|
| 212 |
+
"Built generation schedule: interleave={} max_audio_tokens={} sequence={!r}",
|
| 213 |
+
False,
|
| 214 |
+
int(max_audio_tokens),
|
| 215 |
+
decoded_schedule,
|
| 216 |
+
)
|
| 217 |
+
return {
|
| 218 |
+
"schedule_ids": schedule_ids,
|
| 219 |
+
"interleave": False,
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
if not parsed_template.has_interleave_placeholder:
|
| 223 |
+
raise ValueError(
|
| 224 |
+
"Generation template must contain either {audio} or {interleave}."
|
| 225 |
+
)
|
| 226 |
+
text_cond_end_id = require_token_id(tokenizer, TEXT_COND_END_TOKEN)
|
| 227 |
+
if max_audio_tokens < len(text_tokens):
|
| 228 |
+
raise ValueError(
|
| 229 |
+
"Interleave generation requires at least one audio span per text token: "
|
| 230 |
+
f"text_token_count={len(text_tokens)} "
|
| 231 |
+
f"max_audio_patch_count={max_audio_tokens}."
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
interleave_started = False
|
| 235 |
+
for part in _iter_tokenized_template_parts(
|
| 236 |
+
parsed_template=parsed_template,
|
| 237 |
+
tokenizer=tokenizer,
|
| 238 |
+
text_tokens=text_tokens,
|
| 239 |
+
):
|
| 240 |
+
if part.kind == "interleave":
|
| 241 |
+
_append_interleave_schedule_tokens(
|
| 242 |
+
schedule_ids=schedule_ids,
|
| 243 |
+
text_tokens=text_tokens,
|
| 244 |
+
max_audio_tokens=max_audio_tokens,
|
| 245 |
+
audio_span_id=audio_gen_span_id,
|
| 246 |
+
text_cond_end_id=text_cond_end_id,
|
| 247 |
+
)
|
| 248 |
+
interleave_started = True
|
| 249 |
+
continue
|
| 250 |
+
if part.kind == "text":
|
| 251 |
+
raise ValueError(
|
| 252 |
+
"Generation schedule does not support {text} inside an interleave template."
|
| 253 |
+
)
|
| 254 |
+
if part.kind == "audio":
|
| 255 |
+
raise ValueError(
|
| 256 |
+
"Generation schedule does not support {audio} inside an interleave template."
|
| 257 |
+
)
|
| 258 |
+
if interleave_started:
|
| 259 |
+
if (part.raw_text or "").strip():
|
| 260 |
+
raise ValueError(
|
| 261 |
+
"Generation schedule does not support non-empty suffix text after the interleave placeholder."
|
| 262 |
+
)
|
| 263 |
+
continue
|
| 264 |
+
schedule_ids.extend(part.token_ids)
|
| 265 |
+
|
| 266 |
+
visible_schedule_ids = [
|
| 267 |
+
token_id for token_id in schedule_ids if token_id != audio_gen_span_id
|
| 268 |
+
]
|
| 269 |
+
decoded_schedule = (
|
| 270 |
+
tokenizer.decode(
|
| 271 |
+
visible_schedule_ids,
|
| 272 |
+
skip_special_tokens=False,
|
| 273 |
+
clean_up_tokenization_spaces=False,
|
| 274 |
+
)
|
| 275 |
+
if hasattr(tokenizer, "decode")
|
| 276 |
+
else repr(visible_schedule_ids)
|
| 277 |
+
)
|
| 278 |
+
logger.info(
|
| 279 |
+
"Built generation schedule: interleave={} max_audio_tokens={} sequence={!r}",
|
| 280 |
+
True,
|
| 281 |
+
int(max_audio_tokens),
|
| 282 |
+
decoded_schedule,
|
| 283 |
+
)
|
| 284 |
+
return {
|
| 285 |
+
"schedule_ids": schedule_ids,
|
| 286 |
+
"interleave": True,
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def _append_interleave_generation_tokens(
|
| 291 |
+
*,
|
| 292 |
+
full_ids: list[int],
|
| 293 |
+
loss_mask: list[float],
|
| 294 |
+
text_tokens: list[int],
|
| 295 |
+
num_audio_tokens: int,
|
| 296 |
+
audio_span_id: int,
|
| 297 |
+
audio_end_id: int,
|
| 298 |
+
text_cond_end_id: int,
|
| 299 |
+
) -> None:
|
| 300 |
+
audio_tokens = [audio_span_id] * num_audio_tokens + [audio_end_id]
|
| 301 |
+
text_index = 0
|
| 302 |
+
audio_index = 0
|
| 303 |
+
text_cond_end_added = False
|
| 304 |
+
|
| 305 |
+
while text_index < len(text_tokens) or audio_index < len(audio_tokens):
|
| 306 |
+
if text_index < len(text_tokens):
|
| 307 |
+
full_ids.append(text_tokens[text_index])
|
| 308 |
+
loss_mask.append(0.0)
|
| 309 |
+
text_index += 1
|
| 310 |
+
elif not text_cond_end_added:
|
| 311 |
+
full_ids.append(text_cond_end_id)
|
| 312 |
+
loss_mask.append(0.0)
|
| 313 |
+
text_cond_end_added = True
|
| 314 |
+
|
| 315 |
+
if audio_index < len(audio_tokens):
|
| 316 |
+
full_ids.append(audio_tokens[audio_index])
|
| 317 |
+
loss_mask.append(1.0 if audio_index < num_audio_tokens else 0.0)
|
| 318 |
+
audio_index += 1
|
| 319 |
+
|
| 320 |
+
if not text_cond_end_added:
|
| 321 |
+
full_ids.append(text_cond_end_id)
|
| 322 |
+
loss_mask.append(0.0)
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def _append_interleave_schedule_tokens(
|
| 326 |
+
*,
|
| 327 |
+
schedule_ids: list[int],
|
| 328 |
+
text_tokens: list[int],
|
| 329 |
+
max_audio_tokens: int,
|
| 330 |
+
audio_span_id: int,
|
| 331 |
+
text_cond_end_id: int,
|
| 332 |
+
) -> None:
|
| 333 |
+
for token_id in text_tokens:
|
| 334 |
+
schedule_ids.append(token_id)
|
| 335 |
+
schedule_ids.append(audio_span_id)
|
| 336 |
+
schedule_ids.append(text_cond_end_id)
|
| 337 |
+
remaining_audio_tokens = max_audio_tokens - len(text_tokens)
|
| 338 |
+
if remaining_audio_tokens > 0:
|
| 339 |
+
schedule_ids.extend([audio_span_id] * remaining_audio_tokens)
|
src/dots_tts/data/pipelines/tts_pipeline.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import soundfile as sf
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from dots_tts.utils.profiling import ensure_data_profiler
|
| 7 |
+
from dots_tts.data.pipelines.base import BaseSamplePipeline
|
| 8 |
+
from dots_tts.data.pipelines.preprocessing import (
|
| 9 |
+
compute_num_audio_tokens,
|
| 10 |
+
normalize_edge_silence_duration,
|
| 11 |
+
pad_waveform_align_only,
|
| 12 |
+
)
|
| 13 |
+
from dots_tts.data.pipelines.tokenizing import build_tokenized_example
|
| 14 |
+
from dots_tts.modules.speaker.fbank import extract_speaker_fbank
|
| 15 |
+
from dots_tts.utils.audio import high_quality_resample
|
| 16 |
+
|
| 17 |
+
TTS_TEXT_PREFIX = "[文本]"
|
| 18 |
+
TTS_AUDIO_PREFIX = "[文本对应语音]"
|
| 19 |
+
TTS_INSTRUCTION_TEXT_PREFIX = "[带指令文本]"
|
| 20 |
+
TTA_TEXT_PREFIX = "[声音描述]"
|
| 21 |
+
TTA_AUDIO_PREFIX = "[描述对应声音]"
|
| 22 |
+
TTS_INTERLEAVE_PREFIX = "[流式语音合成]"
|
| 23 |
+
DEFAULT_TRAIN_TEMPLATE = f"{TTS_TEXT_PREFIX}{{text}}{TTS_AUDIO_PREFIX}{{audio}}"
|
| 24 |
+
DEFAULT_INSTRUCTION_TTS_TEMPLATE = (
|
| 25 |
+
f"{TTS_INSTRUCTION_TEXT_PREFIX}{{text}}{TTS_AUDIO_PREFIX}{{audio}}"
|
| 26 |
+
)
|
| 27 |
+
DEFAULT_TEXT_TO_AUDIO_TEMPLATE = f"{TTA_TEXT_PREFIX}{{text}}{TTA_AUDIO_PREFIX}{{audio}}"
|
| 28 |
+
DEFAULT_INTERLEAVE_TRAIN_TEMPLATE = f"{TTS_INTERLEAVE_PREFIX}{{interleave}}"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class BasicTtsPipeline(BaseSamplePipeline):
|
| 32 |
+
"""Fixed internal training pipeline for adapter-emitted samples."""
|
| 33 |
+
|
| 34 |
+
template = DEFAULT_TRAIN_TEMPLATE
|
| 35 |
+
|
| 36 |
+
def __init__(self, tokenizer, data_cfg, *, profiler=None):
|
| 37 |
+
self.tokenizer = tokenizer
|
| 38 |
+
self.train_audio_sample_rate = int(data_cfg.train_audio_sample_rate)
|
| 39 |
+
self.audio_samples_per_llm_token = int(data_cfg.audio_samples_per_llm_token)
|
| 40 |
+
self.profiler = ensure_data_profiler(profiler)
|
| 41 |
+
|
| 42 |
+
@staticmethod
|
| 43 |
+
def _load_waveform(audio_path: str) -> tuple[torch.Tensor, int]:
|
| 44 |
+
if not isinstance(audio_path, str):
|
| 45 |
+
raise TypeError(
|
| 46 |
+
f"Training audio must be a filesystem path, got {type(audio_path)}."
|
| 47 |
+
)
|
| 48 |
+
audio_data, sample_rate = sf.read(
|
| 49 |
+
audio_path,
|
| 50 |
+
dtype="float32",
|
| 51 |
+
always_2d=True,
|
| 52 |
+
)
|
| 53 |
+
waveform = torch.from_numpy(audio_data.T)
|
| 54 |
+
if waveform.size(0) > 1:
|
| 55 |
+
waveform = waveform.mean(dim=0, keepdim=True)
|
| 56 |
+
return waveform.contiguous(), int(sample_rate)
|
| 57 |
+
|
| 58 |
+
@staticmethod
|
| 59 |
+
def _validate_source_sample(sample: dict) -> None:
|
| 60 |
+
missing = [field for field in ("fid", "text", "audio") if field not in sample]
|
| 61 |
+
if missing:
|
| 62 |
+
raise ValueError(
|
| 63 |
+
"Source adapter must emit fid/text/audio. "
|
| 64 |
+
f"Missing fields: {missing}. Sample keys: {sorted(sample.keys())}"
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
def process_sample(self, raw_sample: dict) -> dict:
|
| 68 |
+
sample = dict(raw_sample)
|
| 69 |
+
self._validate_source_sample(sample)
|
| 70 |
+
sample["fid"] = str(sample["fid"])
|
| 71 |
+
|
| 72 |
+
with self.profiler.measure("worker.process_sample_total"):
|
| 73 |
+
return self._process_sample_impl(sample)
|
| 74 |
+
|
| 75 |
+
def _process_sample_impl(self, sample: dict) -> dict:
|
| 76 |
+
profiler = self.profiler
|
| 77 |
+
with profiler.measure("worker.load_audio"):
|
| 78 |
+
waveform, sample_rate = self._load_waveform(sample["audio"])
|
| 79 |
+
with profiler.measure("worker.resample_audio"):
|
| 80 |
+
waveform = high_quality_resample(
|
| 81 |
+
waveform,
|
| 82 |
+
orig_sr=sample_rate,
|
| 83 |
+
target_sr=self.train_audio_sample_rate,
|
| 84 |
+
)
|
| 85 |
+
with profiler.measure("worker.normalize_edge_silence"):
|
| 86 |
+
waveform = normalize_edge_silence_duration(
|
| 87 |
+
waveform,
|
| 88 |
+
sample_rate=self.train_audio_sample_rate,
|
| 89 |
+
)
|
| 90 |
+
sample["sample"] = waveform
|
| 91 |
+
sample["sample_rate"] = self.train_audio_sample_rate
|
| 92 |
+
sample["unpadded_sample_length"] = int(waveform.size(-1))
|
| 93 |
+
|
| 94 |
+
with profiler.measure("worker.pad_audio"):
|
| 95 |
+
waveform = pad_waveform_align_only(
|
| 96 |
+
waveform,
|
| 97 |
+
multiple_of=self.audio_samples_per_llm_token,
|
| 98 |
+
)
|
| 99 |
+
sample["sample"] = waveform
|
| 100 |
+
sample["sample_length"] = int(waveform.size(-1))
|
| 101 |
+
|
| 102 |
+
num_audio_tokens = compute_num_audio_tokens(
|
| 103 |
+
sample["sample_length"],
|
| 104 |
+
audio_samples_per_llm_token=self.audio_samples_per_llm_token,
|
| 105 |
+
)
|
| 106 |
+
with profiler.measure("worker.tokenize"):
|
| 107 |
+
tokenized = build_tokenized_example(
|
| 108 |
+
text=sample["text"],
|
| 109 |
+
tokenizer=self.tokenizer,
|
| 110 |
+
template=self.template,
|
| 111 |
+
num_audio_tokens=num_audio_tokens,
|
| 112 |
+
)
|
| 113 |
+
sample["input_ids"] = tokenized["input_ids"]
|
| 114 |
+
sample["labels"] = tokenized["labels"]
|
| 115 |
+
sample["loss_mask"] = tokenized["loss_mask"]
|
| 116 |
+
sample["input_ids_length"] = len(tokenized["input_ids"])
|
| 117 |
+
sample["num_text_tokens"] = tokenized["text_token_count"]
|
| 118 |
+
sample["num_audio_tokens"] = num_audio_tokens
|
| 119 |
+
sample["num_total_tokens"] = sample["input_ids_length"]
|
| 120 |
+
|
| 121 |
+
with profiler.measure("worker.extract_fbank"):
|
| 122 |
+
fbank = extract_speaker_fbank(
|
| 123 |
+
sample["sample"],
|
| 124 |
+
sample_rate=sample["sample_rate"],
|
| 125 |
+
)
|
| 126 |
+
sample["fbank"] = fbank
|
| 127 |
+
sample["fbank_length"] = int(fbank.size(0))
|
| 128 |
+
return sample
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class InterleaveTtsPipeline(BasicTtsPipeline):
|
| 132 |
+
template = DEFAULT_INTERLEAVE_TRAIN_TEMPLATE
|
src/dots_tts/data/source_adapters/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Source adapter package."""
|
src/dots_tts/data/source_adapters/base_adapter.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import random
|
| 4 |
+
from abc import ABC, abstractmethod
|
| 5 |
+
from collections.abc import Iterable, Sequence
|
| 6 |
+
from copy import deepcopy
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import Any, TypeVar
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass(frozen=True)
|
| 12 |
+
class SourceContext:
|
| 13 |
+
"""Execution context for a single adapter iterator."""
|
| 14 |
+
|
| 15 |
+
epoch: int
|
| 16 |
+
rank: int
|
| 17 |
+
world_size: int
|
| 18 |
+
worker_id: int
|
| 19 |
+
num_workers: int
|
| 20 |
+
seed: int
|
| 21 |
+
|
| 22 |
+
@property
|
| 23 |
+
def global_worker_count(self) -> int:
|
| 24 |
+
return max(1, self.world_size * self.num_workers)
|
| 25 |
+
|
| 26 |
+
@property
|
| 27 |
+
def global_worker_id(self) -> int:
|
| 28 |
+
return self.rank * self.num_workers + self.worker_id
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class BaseSourceAdapter(ABC):
|
| 32 |
+
"""State-aware streaming source interface used by the training pipeline."""
|
| 33 |
+
|
| 34 |
+
@abstractmethod
|
| 35 |
+
def initial_state(self) -> dict[str, Any]:
|
| 36 |
+
"""Return the default iterator state for a new worker/epoch."""
|
| 37 |
+
|
| 38 |
+
@abstractmethod
|
| 39 |
+
def iter_samples(
|
| 40 |
+
self,
|
| 41 |
+
context: SourceContext,
|
| 42 |
+
*,
|
| 43 |
+
state: dict[str, Any] | None = None,
|
| 44 |
+
) -> Iterable[dict[str, Any]]:
|
| 45 |
+
"""Yield raw samples and attach the next adapter state to each item."""
|
| 46 |
+
|
| 47 |
+
@abstractmethod
|
| 48 |
+
def is_cycle_start_state(self, state: dict[str, Any] | None) -> bool:
|
| 49 |
+
"""Return whether ``state`` points at the beginning of a source cycle."""
|
| 50 |
+
|
| 51 |
+
def normalize_state(self, state: dict[str, Any] | None) -> dict[str, Any]:
|
| 52 |
+
merged = self.initial_state()
|
| 53 |
+
if state:
|
| 54 |
+
merged.update(deepcopy(state))
|
| 55 |
+
return merged
|
| 56 |
+
|
| 57 |
+
def clone_state(self, state: dict[str, Any] | None) -> dict[str, Any]:
|
| 58 |
+
return deepcopy(self.normalize_state(state))
|
| 59 |
+
|
| 60 |
+
def advance_cycle(self, state: dict[str, Any] | None) -> dict[str, Any]:
|
| 61 |
+
raise RuntimeError(
|
| 62 |
+
f"{self.__class__.__name__} does not support repeated cycling."
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
_T = TypeVar("_T")
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class ShardableSourceAdapter(BaseSourceAdapter):
|
| 70 |
+
"""Helper mixin for deterministic rank/worker sharding."""
|
| 71 |
+
|
| 72 |
+
@staticmethod
|
| 73 |
+
def is_assigned_index(index: int, context: SourceContext) -> bool:
|
| 74 |
+
return index % context.global_worker_count == context.global_worker_id
|
| 75 |
+
|
| 76 |
+
@staticmethod
|
| 77 |
+
def shard_items(
|
| 78 |
+
items: Sequence[_T],
|
| 79 |
+
context: SourceContext,
|
| 80 |
+
*,
|
| 81 |
+
shuffle: bool = False,
|
| 82 |
+
seed_offset: int = 0,
|
| 83 |
+
) -> list[_T]:
|
| 84 |
+
assigned = list(items)
|
| 85 |
+
if shuffle:
|
| 86 |
+
random.Random(context.seed + context.epoch + seed_offset).shuffle(assigned)
|
| 87 |
+
return [
|
| 88 |
+
item
|
| 89 |
+
for index, item in enumerate(assigned)
|
| 90 |
+
if ShardableSourceAdapter.is_assigned_index(index, context)
|
| 91 |
+
]
|
src/dots_tts/data/source_adapters/jsonl_manifest_adapter.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import random
|
| 5 |
+
from collections.abc import Iterable, Iterator
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
from dots_tts.data.source_adapters.base_adapter import (
|
| 10 |
+
BaseSourceAdapter,
|
| 11 |
+
ShardableSourceAdapter,
|
| 12 |
+
SourceContext,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class JsonlManifestSourceAdapter(ShardableSourceAdapter, BaseSourceAdapter):
|
| 17 |
+
"""Finite adapter for line-delimited JSON manifests."""
|
| 18 |
+
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
*,
|
| 22 |
+
manifest_path: str,
|
| 23 |
+
fid_key: str = "fid",
|
| 24 |
+
text_key: str = "text",
|
| 25 |
+
audio_key: str = "audio",
|
| 26 |
+
shuffle: bool = False,
|
| 27 |
+
encoding: str = "utf-8",
|
| 28 |
+
):
|
| 29 |
+
self.manifest_path = Path(manifest_path)
|
| 30 |
+
self.fid_key = fid_key
|
| 31 |
+
self.text_key = text_key
|
| 32 |
+
self.audio_key = audio_key
|
| 33 |
+
self.shuffle = shuffle
|
| 34 |
+
self.encoding = encoding
|
| 35 |
+
self._records: list[dict[str, Any]] | None = None
|
| 36 |
+
|
| 37 |
+
def initial_state(self) -> dict[str, Any]:
|
| 38 |
+
return {"cycle": 0, "cursor": 0}
|
| 39 |
+
|
| 40 |
+
def is_cycle_start_state(self, state: dict[str, Any] | None) -> bool:
|
| 41 |
+
normalized = self.normalize_state(state)
|
| 42 |
+
return int(normalized["cursor"]) == 0
|
| 43 |
+
|
| 44 |
+
def advance_cycle(self, state: dict[str, Any] | None) -> dict[str, Any]:
|
| 45 |
+
normalized = self.normalize_state(state)
|
| 46 |
+
return {"cycle": int(normalized["cycle"]) + 1, "cursor": 0}
|
| 47 |
+
|
| 48 |
+
def _iter_records(self) -> Iterator[dict[str, Any]]:
|
| 49 |
+
if not self.manifest_path.is_file():
|
| 50 |
+
raise FileNotFoundError(f"Manifest file not found: {self.manifest_path!s}")
|
| 51 |
+
with self.manifest_path.open("r", encoding=self.encoding) as fin:
|
| 52 |
+
for line_no, raw_line in enumerate(fin, start=1):
|
| 53 |
+
line = raw_line.strip()
|
| 54 |
+
if not line:
|
| 55 |
+
continue
|
| 56 |
+
try:
|
| 57 |
+
yield json.loads(line)
|
| 58 |
+
except json.JSONDecodeError as exc:
|
| 59 |
+
raise ValueError(
|
| 60 |
+
f"Invalid JSON at {self.manifest_path}:{line_no}"
|
| 61 |
+
) from exc
|
| 62 |
+
|
| 63 |
+
def _base_records(self) -> list[dict[str, Any]]:
|
| 64 |
+
if self._records is None:
|
| 65 |
+
self._records = list(self._iter_records())
|
| 66 |
+
return self._records
|
| 67 |
+
|
| 68 |
+
def _build_sample(self, record: dict[str, Any]) -> dict[str, Any]:
|
| 69 |
+
missing = [
|
| 70 |
+
key
|
| 71 |
+
for key in (self.fid_key, self.text_key, self.audio_key)
|
| 72 |
+
if key not in record
|
| 73 |
+
]
|
| 74 |
+
if missing:
|
| 75 |
+
raise KeyError(
|
| 76 |
+
f"Manifest record is missing required keys {missing}: {record}"
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
sample = {
|
| 80 |
+
"fid": str(record[self.fid_key]),
|
| 81 |
+
"text": record[self.text_key],
|
| 82 |
+
"audio": record[self.audio_key],
|
| 83 |
+
}
|
| 84 |
+
for key, value in record.items():
|
| 85 |
+
if key in {self.fid_key, self.text_key, self.audio_key}:
|
| 86 |
+
continue
|
| 87 |
+
sample[key] = value
|
| 88 |
+
return sample
|
| 89 |
+
|
| 90 |
+
def _indices_for_cycle(
|
| 91 |
+
self,
|
| 92 |
+
context: SourceContext,
|
| 93 |
+
*,
|
| 94 |
+
cycle: int,
|
| 95 |
+
) -> list[int]:
|
| 96 |
+
indices = list(range(len(self._base_records())))
|
| 97 |
+
if self.shuffle:
|
| 98 |
+
random.Random(context.seed + context.epoch + 1009 * int(cycle)).shuffle(
|
| 99 |
+
indices
|
| 100 |
+
)
|
| 101 |
+
indices = [
|
| 102 |
+
record_index
|
| 103 |
+
for shuffled_index, record_index in enumerate(indices)
|
| 104 |
+
if self.is_assigned_index(shuffled_index, context)
|
| 105 |
+
]
|
| 106 |
+
else:
|
| 107 |
+
indices = [
|
| 108 |
+
record_index
|
| 109 |
+
for record_index in indices
|
| 110 |
+
if self.is_assigned_index(record_index, context)
|
| 111 |
+
]
|
| 112 |
+
return indices
|
| 113 |
+
|
| 114 |
+
def iter_samples(
|
| 115 |
+
self,
|
| 116 |
+
context: SourceContext,
|
| 117 |
+
*,
|
| 118 |
+
state: dict[str, Any] | None = None,
|
| 119 |
+
) -> Iterable[dict[str, Any]]:
|
| 120 |
+
live_state = self.normalize_state(state)
|
| 121 |
+
cycle = int(live_state["cycle"])
|
| 122 |
+
cursor = int(live_state["cursor"])
|
| 123 |
+
records = self._base_records()
|
| 124 |
+
indices = self._indices_for_cycle(context, cycle=cycle)
|
| 125 |
+
|
| 126 |
+
for position in range(cursor, len(indices)):
|
| 127 |
+
sample = self._build_sample(records[indices[position]])
|
| 128 |
+
sample["_adapter_state"] = {
|
| 129 |
+
"cycle": cycle,
|
| 130 |
+
"cursor": position + 1,
|
| 131 |
+
}
|
| 132 |
+
yield sample
|
src/dots_tts/data/source_adapters/multi_source_adapter.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from collections.abc import Iterable
|
| 4 |
+
from copy import deepcopy
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
from dots_tts.data.pipelines.base import BaseSamplePipeline
|
| 8 |
+
from dots_tts.data.source_adapters.base_adapter import (
|
| 9 |
+
BaseSourceAdapter,
|
| 10 |
+
SourceContext,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass(frozen=True)
|
| 15 |
+
class SourceSpec:
|
| 16 |
+
name: str
|
| 17 |
+
weight: float
|
| 18 |
+
adapter: BaseSourceAdapter
|
| 19 |
+
pipeline: BaseSamplePipeline
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
_UINT64_MASK = 0xFFFFFFFFFFFFFFFF
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _mix_uint64(value: int) -> int:
|
| 26 |
+
value = (value ^ (value >> 30)) * 0xBF58476D1CE4E5B9
|
| 27 |
+
value &= _UINT64_MASK
|
| 28 |
+
value = (value ^ (value >> 27)) * 0x94D049BB133111EB
|
| 29 |
+
value &= _UINT64_MASK
|
| 30 |
+
return (value ^ (value >> 31)) & _UINT64_MASK
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _stable_seed(*parts: int) -> int:
|
| 34 |
+
value = 0x9E3779B97F4A7C15
|
| 35 |
+
for part in parts:
|
| 36 |
+
value = (value + int(part) + 0x9E3779B97F4A7C15) & _UINT64_MASK
|
| 37 |
+
value = _mix_uint64(value)
|
| 38 |
+
return value
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class SequentialMultiSourceAdapter(BaseSourceAdapter):
|
| 42 |
+
"""Finite adapter that concatenates sources in the configured order."""
|
| 43 |
+
|
| 44 |
+
def __init__(self, *, sources: list[SourceSpec]):
|
| 45 |
+
if not sources:
|
| 46 |
+
raise ValueError(
|
| 47 |
+
"SequentialMultiSourceAdapter requires at least one source."
|
| 48 |
+
)
|
| 49 |
+
self.sources = list(sources)
|
| 50 |
+
|
| 51 |
+
def initial_state(self) -> dict:
|
| 52 |
+
return {
|
| 53 |
+
"source_index": 0,
|
| 54 |
+
"sources": {
|
| 55 |
+
source.name: source.adapter.initial_state() for source in self.sources
|
| 56 |
+
},
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
def is_cycle_start_state(self, state: dict | None) -> bool:
|
| 60 |
+
normalized = self.normalize_state(state)
|
| 61 |
+
if int(normalized["source_index"]) != 0:
|
| 62 |
+
return False
|
| 63 |
+
return all(
|
| 64 |
+
source.adapter.is_cycle_start_state(normalized["sources"][source.name])
|
| 65 |
+
for source in self.sources
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
def normalize_state(self, state: dict | None) -> dict:
|
| 69 |
+
normalized = super().normalize_state(state)
|
| 70 |
+
source_states = normalized.get("sources") or {}
|
| 71 |
+
normalized["sources"] = {
|
| 72 |
+
source.name: source.adapter.clone_state(source_states.get(source.name))
|
| 73 |
+
for source in self.sources
|
| 74 |
+
}
|
| 75 |
+
normalized["source_index"] = int(normalized.get("source_index", 0))
|
| 76 |
+
return normalized
|
| 77 |
+
|
| 78 |
+
def clone_state(self, state: dict | None) -> dict:
|
| 79 |
+
return deepcopy(self.normalize_state(state))
|
| 80 |
+
|
| 81 |
+
def iter_samples(
|
| 82 |
+
self,
|
| 83 |
+
context: SourceContext,
|
| 84 |
+
*,
|
| 85 |
+
state: dict | None = None,
|
| 86 |
+
) -> Iterable[dict]:
|
| 87 |
+
live_state = self.normalize_state(state)
|
| 88 |
+
start_index = int(live_state["source_index"])
|
| 89 |
+
for index in range(start_index, len(self.sources)):
|
| 90 |
+
source = self.sources[index]
|
| 91 |
+
child_state = live_state["sources"][source.name]
|
| 92 |
+
raw_iter = source.adapter.iter_samples(context, state=child_state)
|
| 93 |
+
for sample in source.pipeline(raw_iter):
|
| 94 |
+
item = dict(sample)
|
| 95 |
+
next_child_state = item.pop("_adapter_state", None)
|
| 96 |
+
if next_child_state is None:
|
| 97 |
+
raise RuntimeError(
|
| 98 |
+
f"{source.adapter.__class__.__name__} must attach '_adapter_state' to samples."
|
| 99 |
+
)
|
| 100 |
+
live_state["source_index"] = index
|
| 101 |
+
live_state["sources"][source.name] = source.adapter.clone_state(
|
| 102 |
+
next_child_state
|
| 103 |
+
)
|
| 104 |
+
item["source_name"] = source.name
|
| 105 |
+
item["_adapter_state"] = self.clone_state(live_state)
|
| 106 |
+
yield item
|
| 107 |
+
live_state["source_index"] = index + 1
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class WeightedMultiSourceAdapter(BaseSourceAdapter):
|
| 111 |
+
"""Infinite weighted sampler that cycles each child source independently."""
|
| 112 |
+
|
| 113 |
+
def __init__(self, *, sources: list[SourceSpec]):
|
| 114 |
+
if not sources:
|
| 115 |
+
raise ValueError("WeightedMultiSourceAdapter requires at least one source.")
|
| 116 |
+
invalid = [source.name for source in sources if float(source.weight) <= 0.0]
|
| 117 |
+
if invalid:
|
| 118 |
+
raise ValueError(f"Source weights must be positive: {invalid}")
|
| 119 |
+
self.sources = list(sources)
|
| 120 |
+
self._cumulative_weights = []
|
| 121 |
+
total = 0.0
|
| 122 |
+
for source in self.sources:
|
| 123 |
+
total += float(source.weight)
|
| 124 |
+
self._cumulative_weights.append(total)
|
| 125 |
+
self._total_weight = total
|
| 126 |
+
|
| 127 |
+
def initial_state(self) -> dict:
|
| 128 |
+
return {
|
| 129 |
+
"draw_count": 0,
|
| 130 |
+
"sources": {
|
| 131 |
+
source.name: source.adapter.initial_state() for source in self.sources
|
| 132 |
+
},
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
def is_cycle_start_state(self, state: dict | None) -> bool:
|
| 136 |
+
normalized = self.normalize_state(state)
|
| 137 |
+
if int(normalized["draw_count"]) != 0:
|
| 138 |
+
return False
|
| 139 |
+
return all(
|
| 140 |
+
source.adapter.is_cycle_start_state(normalized["sources"][source.name])
|
| 141 |
+
for source in self.sources
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
def normalize_state(self, state: dict | None) -> dict:
|
| 145 |
+
normalized = super().normalize_state(state)
|
| 146 |
+
source_states = normalized.get("sources") or {}
|
| 147 |
+
normalized["sources"] = {
|
| 148 |
+
source.name: source.adapter.clone_state(source_states.get(source.name))
|
| 149 |
+
for source in self.sources
|
| 150 |
+
}
|
| 151 |
+
normalized["draw_count"] = int(normalized.get("draw_count", 0))
|
| 152 |
+
return normalized
|
| 153 |
+
|
| 154 |
+
def clone_state(self, state: dict | None) -> dict:
|
| 155 |
+
return deepcopy(self.normalize_state(state))
|
| 156 |
+
|
| 157 |
+
def _source_draw_value(self, context: SourceContext, draw_count: int) -> float:
|
| 158 |
+
raw = _stable_seed(
|
| 159 |
+
context.seed,
|
| 160 |
+
context.epoch,
|
| 161 |
+
context.rank,
|
| 162 |
+
context.worker_id,
|
| 163 |
+
draw_count,
|
| 164 |
+
)
|
| 165 |
+
return (raw / float(1 << 64)) * self._total_weight
|
| 166 |
+
|
| 167 |
+
def _pick_source(self, context: SourceContext, draw_count: int) -> SourceSpec:
|
| 168 |
+
draw_value = self._source_draw_value(context, draw_count)
|
| 169 |
+
for source, upper in zip(self.sources, self._cumulative_weights, strict=True):
|
| 170 |
+
if draw_value < upper:
|
| 171 |
+
return source
|
| 172 |
+
return self.sources[-1]
|
| 173 |
+
|
| 174 |
+
def iter_samples(
|
| 175 |
+
self,
|
| 176 |
+
context: SourceContext,
|
| 177 |
+
*,
|
| 178 |
+
state: dict | None = None,
|
| 179 |
+
) -> Iterable[dict]:
|
| 180 |
+
live_state = self.normalize_state(state)
|
| 181 |
+
iterators: dict[str, object] = {}
|
| 182 |
+
|
| 183 |
+
while True:
|
| 184 |
+
draw_count = int(live_state["draw_count"])
|
| 185 |
+
source = self._pick_source(context, draw_count)
|
| 186 |
+
|
| 187 |
+
while True:
|
| 188 |
+
child_state = live_state["sources"][source.name]
|
| 189 |
+
child_iter = iterators.get(source.name)
|
| 190 |
+
if child_iter is None:
|
| 191 |
+
raw_iter = source.adapter.iter_samples(context, state=child_state)
|
| 192 |
+
child_iter = iter(source.pipeline(raw_iter))
|
| 193 |
+
iterators[source.name] = child_iter
|
| 194 |
+
|
| 195 |
+
try:
|
| 196 |
+
sample = dict(next(child_iter))
|
| 197 |
+
except StopIteration:
|
| 198 |
+
if source.adapter.is_cycle_start_state(child_state):
|
| 199 |
+
raise RuntimeError(
|
| 200 |
+
"Weighted source yielded no samples for this worker. "
|
| 201 |
+
f"source={source.name!r}, worker={context.global_worker_id}, "
|
| 202 |
+
f"epoch={context.epoch}"
|
| 203 |
+
)
|
| 204 |
+
iterators.pop(source.name, None)
|
| 205 |
+
live_state["sources"][source.name] = source.adapter.advance_cycle(
|
| 206 |
+
child_state
|
| 207 |
+
)
|
| 208 |
+
continue
|
| 209 |
+
|
| 210 |
+
next_child_state = sample.pop("_adapter_state", None)
|
| 211 |
+
if next_child_state is None:
|
| 212 |
+
raise RuntimeError(
|
| 213 |
+
f"{source.adapter.__class__.__name__} must attach '_adapter_state' to samples."
|
| 214 |
+
)
|
| 215 |
+
live_state["sources"][source.name] = source.adapter.clone_state(
|
| 216 |
+
next_child_state
|
| 217 |
+
)
|
| 218 |
+
live_state["draw_count"] = draw_count + 1
|
| 219 |
+
sample["source_name"] = source.name
|
| 220 |
+
sample["_adapter_state"] = self.clone_state(live_state)
|
| 221 |
+
yield sample
|
| 222 |
+
break
|
src/dots_tts/data/streaming.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import multiprocessing as mp
|
| 5 |
+
from collections.abc import Iterable
|
| 6 |
+
from copy import deepcopy
|
| 7 |
+
|
| 8 |
+
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
|
| 9 |
+
|
| 10 |
+
from dots_tts.data.batchers import OnlineBatcher
|
| 11 |
+
from dots_tts.utils.profiling import ensure_data_profiler
|
| 12 |
+
from dots_tts.data.source_adapters.base_adapter import BaseSourceAdapter, SourceContext
|
| 13 |
+
|
| 14 |
+
_TRACKING_KEY = "__tracking_state__"
|
| 15 |
+
_RESUME_TOPOLOGY_KEY = "resume_topology"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def identity_collate(sample):
|
| 19 |
+
return sample
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class StreamingSampleDataset(IterableDataset):
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
*,
|
| 26 |
+
source: BaseSourceAdapter,
|
| 27 |
+
rank: int,
|
| 28 |
+
world_size: int,
|
| 29 |
+
seed: int,
|
| 30 |
+
):
|
| 31 |
+
self.source = source
|
| 32 |
+
self.rank = int(rank)
|
| 33 |
+
self.world_size = int(world_size)
|
| 34 |
+
self.seed = int(seed)
|
| 35 |
+
self._epoch = mp.Value("q", 0)
|
| 36 |
+
self._pending_resume_state: dict | None = None
|
| 37 |
+
|
| 38 |
+
def load_state_dict(self, state: dict | None) -> None:
|
| 39 |
+
self._pending_resume_state = deepcopy(state) if state else None
|
| 40 |
+
|
| 41 |
+
def set_epoch(self, epoch: int) -> None:
|
| 42 |
+
with self._epoch.get_lock():
|
| 43 |
+
self._epoch.value = int(epoch)
|
| 44 |
+
|
| 45 |
+
def _current_epoch(self) -> int:
|
| 46 |
+
with self._epoch.get_lock():
|
| 47 |
+
return int(self._epoch.value)
|
| 48 |
+
|
| 49 |
+
def _take_resume_state(self, epoch: int) -> dict | None:
|
| 50 |
+
if (
|
| 51 |
+
self._pending_resume_state is None
|
| 52 |
+
or int(self._pending_resume_state.get("epoch", -1)) != int(epoch)
|
| 53 |
+
):
|
| 54 |
+
return None
|
| 55 |
+
state = deepcopy(self._pending_resume_state)
|
| 56 |
+
self._pending_resume_state = None
|
| 57 |
+
return state
|
| 58 |
+
|
| 59 |
+
@staticmethod
|
| 60 |
+
def _validate_resume_topology(
|
| 61 |
+
resume_state: dict,
|
| 62 |
+
*,
|
| 63 |
+
context: SourceContext,
|
| 64 |
+
loader_num_workers: int,
|
| 65 |
+
) -> None:
|
| 66 |
+
resume_topology = resume_state.get(_RESUME_TOPOLOGY_KEY)
|
| 67 |
+
if not isinstance(resume_topology, dict):
|
| 68 |
+
raise RuntimeError(
|
| 69 |
+
"Resume state is missing required worker topology metadata."
|
| 70 |
+
)
|
| 71 |
+
expected_world_size = int(resume_topology["world_size"])
|
| 72 |
+
expected_num_workers = int(resume_topology["loader_num_workers"])
|
| 73 |
+
expected_global_worker_count = int(resume_topology["global_worker_count"])
|
| 74 |
+
current_num_workers = int(loader_num_workers)
|
| 75 |
+
current_global_worker_count = int(context.global_worker_count)
|
| 76 |
+
if (
|
| 77 |
+
expected_world_size != int(context.world_size)
|
| 78 |
+
or expected_num_workers != current_num_workers
|
| 79 |
+
or expected_global_worker_count != current_global_worker_count
|
| 80 |
+
):
|
| 81 |
+
raise RuntimeError(
|
| 82 |
+
"Resume requires the same data worker topology as the saved state. "
|
| 83 |
+
f"saved(world_size={expected_world_size}, "
|
| 84 |
+
f"num_workers_per_rank={expected_num_workers}, "
|
| 85 |
+
f"global_worker_count={expected_global_worker_count}), "
|
| 86 |
+
f"current(world_size={context.world_size}, "
|
| 87 |
+
f"num_workers_per_rank={current_num_workers}, "
|
| 88 |
+
f"global_worker_count={current_global_worker_count})."
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
def __iter__(self) -> Iterable[dict]:
|
| 92 |
+
worker_info = get_worker_info()
|
| 93 |
+
if worker_info is None:
|
| 94 |
+
worker_id = 0
|
| 95 |
+
loader_num_workers = 0
|
| 96 |
+
effective_num_workers = 1
|
| 97 |
+
else:
|
| 98 |
+
worker_id = worker_info.id
|
| 99 |
+
loader_num_workers = worker_info.num_workers
|
| 100 |
+
effective_num_workers = worker_info.num_workers
|
| 101 |
+
|
| 102 |
+
epoch = self._current_epoch()
|
| 103 |
+
context = SourceContext(
|
| 104 |
+
epoch=epoch,
|
| 105 |
+
rank=self.rank,
|
| 106 |
+
world_size=self.world_size,
|
| 107 |
+
worker_id=worker_id,
|
| 108 |
+
num_workers=effective_num_workers,
|
| 109 |
+
seed=self.seed,
|
| 110 |
+
)
|
| 111 |
+
resume_state = self._take_resume_state(epoch)
|
| 112 |
+
if resume_state is not None:
|
| 113 |
+
self._validate_resume_topology(
|
| 114 |
+
resume_state,
|
| 115 |
+
context=context,
|
| 116 |
+
loader_num_workers=loader_num_workers,
|
| 117 |
+
)
|
| 118 |
+
worker_state = (
|
| 119 |
+
None
|
| 120 |
+
if resume_state is None
|
| 121 |
+
else (resume_state.get("workers") or {}).get(str(context.global_worker_id))
|
| 122 |
+
)
|
| 123 |
+
sample_iter = self.source.iter_samples(
|
| 124 |
+
context,
|
| 125 |
+
state=None if worker_state is None else worker_state.get("adapter_state"),
|
| 126 |
+
)
|
| 127 |
+
for sample in sample_iter:
|
| 128 |
+
sample["data_worker_id"] = context.worker_id
|
| 129 |
+
sample["data_global_worker_id"] = context.global_worker_id
|
| 130 |
+
yield sample
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class _DataStateTracker:
|
| 134 |
+
def __init__(self, *, num_tokens_per_epoch: int | None):
|
| 135 |
+
self.num_tokens_per_epoch = (
|
| 136 |
+
None if num_tokens_per_epoch is None else int(num_tokens_per_epoch)
|
| 137 |
+
)
|
| 138 |
+
self._pending_state: dict | None = None
|
| 139 |
+
self._reset_for_epoch(epoch=0)
|
| 140 |
+
|
| 141 |
+
def _reset_for_epoch(self, *, epoch: int) -> None:
|
| 142 |
+
self.epoch = int(epoch)
|
| 143 |
+
self.samples_emitted = 0
|
| 144 |
+
self.num_text_tokens = 0
|
| 145 |
+
self.num_audio_tokens = 0
|
| 146 |
+
self.num_total_tokens = 0
|
| 147 |
+
self.workers: dict[str, dict] = {}
|
| 148 |
+
self._next_sample_order_by_worker: dict[str, int] = {}
|
| 149 |
+
|
| 150 |
+
def load_state_dict(self, state: dict | None) -> None:
|
| 151 |
+
self._pending_state = deepcopy(state) if state else None
|
| 152 |
+
|
| 153 |
+
def set_epoch(self, epoch: int) -> None:
|
| 154 |
+
if self._pending_state is not None and int(
|
| 155 |
+
self._pending_state.get("epoch", -1)
|
| 156 |
+
) == int(epoch):
|
| 157 |
+
state = deepcopy(self._pending_state)
|
| 158 |
+
self._pending_state = None
|
| 159 |
+
self.epoch = int(state.get("epoch", epoch))
|
| 160 |
+
self.samples_emitted = int(state.get("samples_emitted", 0))
|
| 161 |
+
self.num_text_tokens = int(state.get("num_text_tokens", 0))
|
| 162 |
+
self.num_audio_tokens = int(state.get("num_audio_tokens", 0))
|
| 163 |
+
self.num_total_tokens = int(state.get("num_total_tokens", 0))
|
| 164 |
+
self.workers = deepcopy(state.get("workers") or {})
|
| 165 |
+
self._next_sample_order_by_worker = {
|
| 166 |
+
worker_key: int((worker_state or {}).get("sample_order", -1)) + 1
|
| 167 |
+
for worker_key, worker_state in self.workers.items()
|
| 168 |
+
}
|
| 169 |
+
return
|
| 170 |
+
self._reset_for_epoch(epoch=int(epoch))
|
| 171 |
+
|
| 172 |
+
def should_stop(self) -> bool:
|
| 173 |
+
return (
|
| 174 |
+
self.num_tokens_per_epoch is not None
|
| 175 |
+
and self.num_total_tokens >= self.num_tokens_per_epoch
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
def stage_sample(self, sample: dict) -> dict:
|
| 179 |
+
item = dict(sample)
|
| 180 |
+
worker_key = str(item.pop("data_global_worker_id"))
|
| 181 |
+
item.pop("data_worker_id", None)
|
| 182 |
+
adapter_state = item.pop("_adapter_state", None)
|
| 183 |
+
sample_order = int(self._next_sample_order_by_worker.get(worker_key, 0))
|
| 184 |
+
self._next_sample_order_by_worker[worker_key] = sample_order + 1
|
| 185 |
+
item[_TRACKING_KEY] = {
|
| 186 |
+
"worker_key": worker_key,
|
| 187 |
+
"adapter_state": deepcopy(adapter_state),
|
| 188 |
+
"sample_order": sample_order,
|
| 189 |
+
"num_text_tokens": int(item["num_text_tokens"]),
|
| 190 |
+
"num_audio_tokens": int(item["num_audio_tokens"]),
|
| 191 |
+
"num_total_tokens": int(
|
| 192 |
+
item.get("num_total_tokens", item["input_ids_length"])
|
| 193 |
+
),
|
| 194 |
+
}
|
| 195 |
+
return item
|
| 196 |
+
|
| 197 |
+
def _pop_tracking(self, sample: dict) -> tuple[dict, dict]:
|
| 198 |
+
item = dict(sample)
|
| 199 |
+
tracking = item.pop(_TRACKING_KEY, None)
|
| 200 |
+
if not isinstance(tracking, dict):
|
| 201 |
+
raise RuntimeError("Tracked sample is missing internal resume metadata.")
|
| 202 |
+
return item, tracking
|
| 203 |
+
|
| 204 |
+
def _advance_worker(self, tracking: dict) -> None:
|
| 205 |
+
adapter_state = tracking.get("adapter_state")
|
| 206 |
+
if adapter_state is None:
|
| 207 |
+
return
|
| 208 |
+
worker_key = str(tracking["worker_key"])
|
| 209 |
+
sample_order = int(tracking.get("sample_order", -1))
|
| 210 |
+
current_state = self.workers.get(worker_key)
|
| 211 |
+
current_order = int((current_state or {}).get("sample_order", -1))
|
| 212 |
+
if current_order >= sample_order:
|
| 213 |
+
return
|
| 214 |
+
self.workers[worker_key] = {
|
| 215 |
+
"adapter_state": deepcopy(adapter_state),
|
| 216 |
+
"sample_order": sample_order,
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
def mark_samples_dropped(self, samples: list[dict]) -> None:
|
| 220 |
+
for sample in samples:
|
| 221 |
+
_, tracking = self._pop_tracking(sample)
|
| 222 |
+
self._advance_worker(tracking)
|
| 223 |
+
|
| 224 |
+
def commit_batch(self, samples: list[dict]) -> list[dict]:
|
| 225 |
+
committed: list[dict] = []
|
| 226 |
+
for sample in samples:
|
| 227 |
+
item, tracking = self._pop_tracking(sample)
|
| 228 |
+
self._advance_worker(tracking)
|
| 229 |
+
self.samples_emitted += 1
|
| 230 |
+
self.num_text_tokens += int(tracking["num_text_tokens"])
|
| 231 |
+
self.num_audio_tokens += int(tracking["num_audio_tokens"])
|
| 232 |
+
self.num_total_tokens += int(tracking["num_total_tokens"])
|
| 233 |
+
committed.append(item)
|
| 234 |
+
return committed
|
| 235 |
+
|
| 236 |
+
def state_dict(self) -> dict:
|
| 237 |
+
return {
|
| 238 |
+
"epoch": int(self.epoch),
|
| 239 |
+
"samples_emitted": int(self.samples_emitted),
|
| 240 |
+
"num_text_tokens": int(self.num_text_tokens),
|
| 241 |
+
"num_audio_tokens": int(self.num_audio_tokens),
|
| 242 |
+
"num_total_tokens": int(self.num_total_tokens),
|
| 243 |
+
"workers": deepcopy(self.workers),
|
| 244 |
+
"num_tokens_per_epoch": self.num_tokens_per_epoch,
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class BatchedDataStream:
|
| 249 |
+
def __init__(
|
| 250 |
+
self,
|
| 251 |
+
*,
|
| 252 |
+
sample_dataset: StreamingSampleDataset,
|
| 253 |
+
data_cfg,
|
| 254 |
+
tokenizer,
|
| 255 |
+
num_tokens_per_epoch: int | None,
|
| 256 |
+
profiler=None,
|
| 257 |
+
):
|
| 258 |
+
from dots_tts.data.collator import PadCollator
|
| 259 |
+
|
| 260 |
+
self.sample_dataset = sample_dataset
|
| 261 |
+
self.profiler = ensure_data_profiler(profiler)
|
| 262 |
+
llm_token_rate = (
|
| 263 |
+
float(data_cfg.train_audio_sample_rate)
|
| 264 |
+
/ float(data_cfg.audio_samples_per_llm_token)
|
| 265 |
+
)
|
| 266 |
+
self.batcher = OnlineBatcher(
|
| 267 |
+
max_audio_tokens_in_batch=max(
|
| 268 |
+
1,
|
| 269 |
+
math.ceil(float(data_cfg.max_audio_seconds_in_batch) * llm_token_rate),
|
| 270 |
+
),
|
| 271 |
+
max_text_tokens_in_batch=data_cfg.max_text_tokens_in_batch,
|
| 272 |
+
max_batch_size=data_cfg.max_samples_per_batch,
|
| 273 |
+
sample_pool_size=data_cfg.bucketing_pool_size,
|
| 274 |
+
profiler=self.profiler,
|
| 275 |
+
)
|
| 276 |
+
self.sample_loader = None
|
| 277 |
+
self.collator = PadCollator(tokenizer)
|
| 278 |
+
self.data_state = _DataStateTracker(
|
| 279 |
+
num_tokens_per_epoch=num_tokens_per_epoch
|
| 280 |
+
)
|
| 281 |
+
self._decision_iterator = None
|
| 282 |
+
self._sample_iterator = None
|
| 283 |
+
self._pending_batch = None
|
| 284 |
+
self._pending_samples = None
|
| 285 |
+
|
| 286 |
+
def attach_loader(self, loader: DataLoader) -> None:
|
| 287 |
+
self.sample_loader = loader
|
| 288 |
+
|
| 289 |
+
def close(self) -> None:
|
| 290 |
+
self._reset_iteration_state()
|
| 291 |
+
self.sample_loader = None
|
| 292 |
+
|
| 293 |
+
def load_state_dict(self, state: dict | None) -> None:
|
| 294 |
+
self.data_state.load_state_dict(state)
|
| 295 |
+
self.sample_dataset.load_state_dict(state)
|
| 296 |
+
self._reset_iteration_state()
|
| 297 |
+
|
| 298 |
+
def state_dict(self) -> dict:
|
| 299 |
+
if self.sample_loader is None:
|
| 300 |
+
raise RuntimeError("BatchedDataStream has no attached sample loader.")
|
| 301 |
+
if self._pending_batch is not None or self._pending_samples is not None:
|
| 302 |
+
raise RuntimeError(
|
| 303 |
+
"Cannot serialize BatchedDataStream while a batch is pending commit."
|
| 304 |
+
)
|
| 305 |
+
loader_num_workers = int(getattr(self.sample_loader, "num_workers", 0))
|
| 306 |
+
effective_num_workers = max(1, loader_num_workers)
|
| 307 |
+
state = self.data_state.state_dict()
|
| 308 |
+
state[_RESUME_TOPOLOGY_KEY] = {
|
| 309 |
+
"world_size": int(self.sample_dataset.world_size),
|
| 310 |
+
"loader_num_workers": loader_num_workers,
|
| 311 |
+
"global_worker_count": int(self.sample_dataset.world_size)
|
| 312 |
+
* effective_num_workers,
|
| 313 |
+
}
|
| 314 |
+
return state
|
| 315 |
+
|
| 316 |
+
def set_epoch(self, epoch: int) -> None:
|
| 317 |
+
self.sample_dataset.set_epoch(epoch)
|
| 318 |
+
self.data_state.set_epoch(epoch)
|
| 319 |
+
self._reset_iteration_state()
|
| 320 |
+
|
| 321 |
+
def _reset_iteration_state(self) -> None:
|
| 322 |
+
close_iterator = getattr(self._decision_iterator, "close", None)
|
| 323 |
+
if callable(close_iterator):
|
| 324 |
+
close_iterator()
|
| 325 |
+
self._decision_iterator = None
|
| 326 |
+
self._sample_iterator = None
|
| 327 |
+
self._pending_batch = None
|
| 328 |
+
self._pending_samples = None
|
| 329 |
+
|
| 330 |
+
def _iter_staged_samples(self):
|
| 331 |
+
if self.sample_loader is None:
|
| 332 |
+
raise RuntimeError("BatchedDataStream has no attached sample loader.")
|
| 333 |
+
self._sample_iterator = iter(self.sample_loader)
|
| 334 |
+
profiler = self.profiler
|
| 335 |
+
try:
|
| 336 |
+
while True:
|
| 337 |
+
if self.data_state.should_stop():
|
| 338 |
+
return
|
| 339 |
+
try:
|
| 340 |
+
with profiler.measure("main.loader_wait_next_sample"):
|
| 341 |
+
sample = next(self._sample_iterator)
|
| 342 |
+
except StopIteration:
|
| 343 |
+
return
|
| 344 |
+
if sample is None:
|
| 345 |
+
continue
|
| 346 |
+
with profiler.measure("main.stage_sample"):
|
| 347 |
+
staged = self.data_state.stage_sample(sample)
|
| 348 |
+
yield staged
|
| 349 |
+
finally:
|
| 350 |
+
self._sample_iterator = None
|
| 351 |
+
|
| 352 |
+
def _decision_stream(self):
|
| 353 |
+
if self._decision_iterator is None:
|
| 354 |
+
self._decision_iterator = iter(
|
| 355 |
+
self.batcher.build_decisions(self._iter_staged_samples())
|
| 356 |
+
)
|
| 357 |
+
return self._decision_iterator
|
| 358 |
+
|
| 359 |
+
def peek_batch(self) -> tuple[dict | None, bool]:
|
| 360 |
+
if self._pending_batch is not None:
|
| 361 |
+
return self._pending_batch, True
|
| 362 |
+
|
| 363 |
+
for decision in self._decision_stream():
|
| 364 |
+
if decision.dropped_samples:
|
| 365 |
+
self.data_state.mark_samples_dropped(decision.dropped_samples)
|
| 366 |
+
if not decision.batch_samples:
|
| 367 |
+
continue
|
| 368 |
+
self._pending_samples = decision.batch_samples
|
| 369 |
+
with self.profiler.measure(
|
| 370 |
+
"main.collate_batch",
|
| 371 |
+
count=len(decision.batch_samples),
|
| 372 |
+
):
|
| 373 |
+
self._pending_batch = self.collator(decision.batch_samples)
|
| 374 |
+
return self._pending_batch, True
|
| 375 |
+
return None, False
|
| 376 |
+
|
| 377 |
+
def commit_batch(self) -> dict:
|
| 378 |
+
if self._pending_batch is None or self._pending_samples is None:
|
| 379 |
+
raise RuntimeError("BatchedDataStream has no pending batch to commit.")
|
| 380 |
+
pending_batch = self._pending_batch
|
| 381 |
+
self.data_state.commit_batch(self._pending_samples)
|
| 382 |
+
self._pending_batch = None
|
| 383 |
+
self._pending_samples = None
|
| 384 |
+
return pending_batch
|
| 385 |
+
|
| 386 |
+
def discard_batch(self) -> None:
|
| 387 |
+
if self._pending_batch is None or self._pending_samples is None:
|
| 388 |
+
raise RuntimeError("BatchedDataStream has no pending batch to discard.")
|
| 389 |
+
self._pending_batch = None
|
| 390 |
+
self._pending_samples = None
|
| 391 |
+
|
| 392 |
+
def __iter__(self):
|
| 393 |
+
while True:
|
| 394 |
+
batch, has_batch = self.peek_batch()
|
| 395 |
+
if not has_batch:
|
| 396 |
+
return
|
| 397 |
+
self.commit_batch()
|
| 398 |
+
yield batch
|
| 399 |
+
if self.data_state.should_stop():
|
| 400 |
+
return
|
src/dots_tts/models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Model families."""
|
src/dots_tts/models/dots_tts/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""dots_tts model package."""
|
src/dots_tts/models/dots_tts/config.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dots_tts.config.base import ConfigBase, StrictConfigBase
|
| 4 |
+
from dots_tts.modules.vocoder.config import AudioVAEConfig
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class _EncoderConfig(ConfigBase):
|
| 8 |
+
num_layers: int = 6
|
| 9 |
+
num_heads: int = 16
|
| 10 |
+
hidden_size: int = 1024
|
| 11 |
+
ffn_hidden_size: int = 4096
|
| 12 |
+
modulation: bool = False
|
| 13 |
+
qkv_bias: bool = False
|
| 14 |
+
qk_norm: bool = False
|
| 15 |
+
attn_dropout: float = 0.0
|
| 16 |
+
dropout: float = 0.0
|
| 17 |
+
norm_layer: str = "LayerNorm"
|
| 18 |
+
alibi_bias: bool = False
|
| 19 |
+
rotary_bias: bool = False
|
| 20 |
+
rotary_theta: float | None = 10000
|
| 21 |
+
input_dim: int = 1024
|
| 22 |
+
causal: bool = True
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class _DiTConfig(ConfigBase):
|
| 26 |
+
num_layers: int = 18
|
| 27 |
+
num_heads: int = 16
|
| 28 |
+
hidden_size: int = 1024
|
| 29 |
+
ffn_hidden_size: int = 4096
|
| 30 |
+
modulation: bool = True
|
| 31 |
+
qkv_bias: bool = False
|
| 32 |
+
qk_norm: bool = False
|
| 33 |
+
attn_dropout: float = 0.0
|
| 34 |
+
dropout: float = 0.0
|
| 35 |
+
norm_layer: str = "LayerNorm"
|
| 36 |
+
alibi_bias: bool = False
|
| 37 |
+
rotary_bias: bool = True
|
| 38 |
+
rotary_theta: float | None = 10000
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class LossConfig(StrictConfigBase):
|
| 42 |
+
ce_weight: float = 1.0
|
| 43 |
+
fm_weight: float = 1.0
|
| 44 |
+
eos_weight: float = 1.0
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class MeanFlowConfig(ConfigBase):
|
| 48 |
+
enabled: bool = False
|
| 49 |
+
use_duration_embedding: bool = True
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class ModelConfig(ConfigBase):
|
| 53 |
+
model_type: str = "dots_tts"
|
| 54 |
+
latent_dim: int
|
| 55 |
+
patch_size: int
|
| 56 |
+
cfg_droprate: float = 0.2
|
| 57 |
+
PatchEncoder: _EncoderConfig
|
| 58 |
+
DiT: _DiTConfig
|
| 59 |
+
vocoder: AudioVAEConfig
|
| 60 |
+
fm_sigma: float = 0.0
|
| 61 |
+
xvec_drop_rate: float = 0.2
|
| 62 |
+
campplus_embedding_size: int | None = 512
|
| 63 |
+
xvec_max_audio_seconds: float = 10.0
|
| 64 |
+
meanflow: MeanFlowConfig | None = None
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
__all__ = [
|
| 68 |
+
"LossConfig",
|
| 69 |
+
"MeanFlowConfig",
|
| 70 |
+
"ModelConfig",
|
| 71 |
+
]
|
src/dots_tts/models/dots_tts/core.py
ADDED
|
@@ -0,0 +1,910 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Any, Callable
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
from loguru import logger
|
| 9 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 10 |
+
from torchdiffeq import odeint
|
| 11 |
+
from transformers import Qwen2Config, Qwen2ForCausalLM
|
| 12 |
+
|
| 13 |
+
from dots_tts.models.dots_tts.config import ModelConfig
|
| 14 |
+
from dots_tts.modules.backbone.dit import DiT
|
| 15 |
+
from dots_tts.modules.backbone.semantic_encoder import VAESemanticEncoder
|
| 16 |
+
from dots_tts.utils.tokenizer import (
|
| 17 |
+
AUDIO_COMP_SPAN_TOKEN,
|
| 18 |
+
AUDIO_GEN_SPAN_TOKEN,
|
| 19 |
+
TEXT_COND_END_TOKEN,
|
| 20 |
+
require_token_id,
|
| 21 |
+
)
|
| 22 |
+
from dots_tts.utils.util import get_mask_from_lengths, mask_data
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass(frozen=True)
|
| 26 |
+
class DotsTtsForwardOutput:
|
| 27 |
+
llm_logits: torch.Tensor
|
| 28 |
+
pred: torch.Tensor
|
| 29 |
+
target: torch.Tensor
|
| 30 |
+
eos_out: torch.Tensor
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class DotsTtsCore(nn.Module):
|
| 34 |
+
# region Module construction
|
| 35 |
+
def __init__(
|
| 36 |
+
self,
|
| 37 |
+
config: ModelConfig,
|
| 38 |
+
llm_config: Qwen2Config,
|
| 39 |
+
tokenizer=None,
|
| 40 |
+
*,
|
| 41 |
+
latent_stats_path,
|
| 42 |
+
):
|
| 43 |
+
super().__init__()
|
| 44 |
+
self.config = config
|
| 45 |
+
self.fm_hidden_size = config.DiT.hidden_size
|
| 46 |
+
self.hidden_patch_size = 1
|
| 47 |
+
self.cfg_droprate = config.get("cfg_droprate", 0.2)
|
| 48 |
+
self.latent_patch_size = config.patch_size
|
| 49 |
+
self.latent_dim = config.latent_dim
|
| 50 |
+
self.xvec_dim = config.campplus_embedding_size
|
| 51 |
+
self.xvec_drop_rate = config.get("xvec_drop_rate", 0.2)
|
| 52 |
+
|
| 53 |
+
# Setup tokenizer
|
| 54 |
+
self.tokenizer = tokenizer
|
| 55 |
+
if self.tokenizer is None:
|
| 56 |
+
raise RuntimeError("Tokenizer must be provided before building the model.")
|
| 57 |
+
if llm_config is None:
|
| 58 |
+
raise RuntimeError("LLM config must be provided before building the model.")
|
| 59 |
+
self.pad_token_id = getattr(self.tokenizer, "pad_token_id", None)
|
| 60 |
+
self.audio_gen_span_id = require_token_id(self.tokenizer, AUDIO_GEN_SPAN_TOKEN)
|
| 61 |
+
self.audio_comp_span_id = require_token_id(
|
| 62 |
+
self.tokenizer, AUDIO_COMP_SPAN_TOKEN
|
| 63 |
+
)
|
| 64 |
+
self.text_cond_end_id = require_token_id(self.tokenizer, TEXT_COND_END_TOKEN)
|
| 65 |
+
|
| 66 |
+
# Setup LLM with language modeling head so we can obtain logits directly
|
| 67 |
+
llm_config = copy.deepcopy(llm_config)
|
| 68 |
+
llm_config.vocab_size = len(self.tokenizer)
|
| 69 |
+
self.llm = Qwen2ForCausalLM._from_config(
|
| 70 |
+
llm_config,
|
| 71 |
+
dtype=torch.float32,
|
| 72 |
+
)
|
| 73 |
+
self.llm_hidden_size = self.llm.config.hidden_size
|
| 74 |
+
|
| 75 |
+
self.patch_encoder = VAESemanticEncoder(
|
| 76 |
+
in_dim=self.latent_dim,
|
| 77 |
+
out_dim=self.llm_hidden_size,
|
| 78 |
+
config=config,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
# Setup Flow matching related modules
|
| 82 |
+
self.hidden_proj = nn.Linear(self.llm_hidden_size, self.fm_hidden_size)
|
| 83 |
+
self.latent_proj = nn.Linear(self.latent_dim, self.fm_hidden_size)
|
| 84 |
+
self.coordinate_proj = nn.Linear(self.latent_dim, self.fm_hidden_size)
|
| 85 |
+
self.xvec_proj = nn.Sequential(
|
| 86 |
+
nn.Linear(self.xvec_dim, self.fm_hidden_size),
|
| 87 |
+
nn.LayerNorm(self.fm_hidden_size),
|
| 88 |
+
)
|
| 89 |
+
self.meanflow_config = config.meanflow if config.meanflow is not None else None
|
| 90 |
+
self.mode = (
|
| 91 |
+
"meanflow"
|
| 92 |
+
if self.meanflow_config is not None and self.meanflow_config.enabled
|
| 93 |
+
else "flow_matching"
|
| 94 |
+
)
|
| 95 |
+
dit_mode = (
|
| 96 |
+
"meanflow"
|
| 97 |
+
if self.mode == "meanflow"
|
| 98 |
+
and self.meanflow_config.use_duration_embedding
|
| 99 |
+
else "flow_matching"
|
| 100 |
+
)
|
| 101 |
+
self.velocity_field_predictor = DiT(
|
| 102 |
+
in_dim=self.fm_hidden_size,
|
| 103 |
+
out_dim=self.latent_dim,
|
| 104 |
+
transformer_config=config.DiT,
|
| 105 |
+
mode=dit_mode,
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# Setup eos predictor
|
| 109 |
+
self.eos_proj = nn.Sequential(
|
| 110 |
+
nn.Linear(self.llm_hidden_size, self.llm_hidden_size),
|
| 111 |
+
nn.SiLU(),
|
| 112 |
+
nn.Linear(self.llm_hidden_size, 2),
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
# Helpers
|
| 116 |
+
self.fm_helper = FlowMatchingHelper(sigma=config.get("fm_sigma", 0.0))
|
| 117 |
+
self.causal_helper = CausalHelper()
|
| 118 |
+
self.io_helper = IOHelper(latent_stats_path=latent_stats_path)
|
| 119 |
+
self.audio_span_token_ids: list[int] = [
|
| 120 |
+
self.audio_gen_span_id,
|
| 121 |
+
self.audio_comp_span_id,
|
| 122 |
+
]
|
| 123 |
+
# endregion Module construction
|
| 124 |
+
|
| 125 |
+
# region Training forward path
|
| 126 |
+
def forward(self, data: dict[str, Any]) -> DotsTtsForwardOutput:
|
| 127 |
+
input_ids: torch.Tensor = data["input_ids"]
|
| 128 |
+
input_ids_lengths: torch.Tensor = data["input_ids_lengths"]
|
| 129 |
+
input_span_mask: torch.Tensor = data["input_span_mask"]
|
| 130 |
+
output_span_mask: torch.Tensor = data["output_span_mask"]
|
| 131 |
+
batch_size = input_ids.size(0)
|
| 132 |
+
device = input_ids.device
|
| 133 |
+
|
| 134 |
+
latents: torch.Tensor | None = data.get("latents")
|
| 135 |
+
latents_sampled: torch.Tensor | None = data.get("latents_sampled")
|
| 136 |
+
latent_lengths: torch.Tensor | None = data.get("latent_lengths")
|
| 137 |
+
has_latents = latents is not None or latents_sampled is not None
|
| 138 |
+
|
| 139 |
+
patch_embeddings: torch.Tensor | None
|
| 140 |
+
valid_patch_counts: torch.Tensor | None
|
| 141 |
+
if has_latents:
|
| 142 |
+
if latents_sampled is None:
|
| 143 |
+
latents_sampled = self.io_helper.sample_from_latent(latents)
|
| 144 |
+
patch_embeddings = self.patch_encoder(
|
| 145 |
+
latents_sampled, x_lens=latent_lengths
|
| 146 |
+
)
|
| 147 |
+
valid_patch_counts = latent_lengths // self.latent_patch_size
|
| 148 |
+
latents_sampled = self.io_helper.normalize(latents_sampled)
|
| 149 |
+
else:
|
| 150 |
+
latents_sampled = None
|
| 151 |
+
patch_embeddings = None
|
| 152 |
+
valid_patch_counts = torch.zeros(
|
| 153 |
+
batch_size, dtype=torch.long, device=device
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
input_span_counts = input_span_mask.sum(dim=1)
|
| 157 |
+
if input_span_counts.sum() > 0 and patch_embeddings is None:
|
| 158 |
+
raise RuntimeError(
|
| 159 |
+
"Found audio span tokens but no latents provided to compute patch embeddings."
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
# Token embeddings with audio span replacement
|
| 163 |
+
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
|
| 164 |
+
if patch_embeddings is not None:
|
| 165 |
+
inputs_embeds = inputs_embeds.clone()
|
| 166 |
+
patch_embeddings = patch_embeddings.to(inputs_embeds.dtype)
|
| 167 |
+
for b in range(batch_size):
|
| 168 |
+
span_num = input_span_counts[b].item()
|
| 169 |
+
if span_num == 0:
|
| 170 |
+
continue
|
| 171 |
+
expected = valid_patch_counts[b].item()
|
| 172 |
+
if expected != span_num:
|
| 173 |
+
raise RuntimeError(
|
| 174 |
+
f"Mismatch between span tokens ({span_num}) and latent patches ({expected}) for sample {b}."
|
| 175 |
+
)
|
| 176 |
+
indices = input_span_mask[b].nonzero(as_tuple=False).squeeze(-1)
|
| 177 |
+
inputs_embeds[b, indices, :] = patch_embeddings[b, :span_num, :]
|
| 178 |
+
|
| 179 |
+
# LLM forward pass to obtain logits & hidden states
|
| 180 |
+
_llm_attn_mask, llm_seq_mask, _ = self.causal_helper.create_causal_mask_and_pos(
|
| 181 |
+
seq_lens=input_ids_lengths, max_len=input_ids.size(1)
|
| 182 |
+
)
|
| 183 |
+
llm_outputs = self.llm(
|
| 184 |
+
inputs_embeds=inputs_embeds,
|
| 185 |
+
attention_mask=llm_seq_mask.long(),
|
| 186 |
+
use_cache=False,
|
| 187 |
+
output_hidden_states=True,
|
| 188 |
+
return_dict=True,
|
| 189 |
+
)
|
| 190 |
+
llm_logits = llm_outputs.logits # [B, L, V]
|
| 191 |
+
llm_hidden = llm_outputs.hidden_states[-1] # [B, L, H]
|
| 192 |
+
|
| 193 |
+
# eos prediction, before cfg masking
|
| 194 |
+
eos = self.eos_proj(llm_hidden.detach())
|
| 195 |
+
|
| 196 |
+
# Flow matching forward
|
| 197 |
+
total_patches = int(output_span_mask.sum().item())
|
| 198 |
+
if total_patches > 0 and latents_sampled is None:
|
| 199 |
+
raise RuntimeError("Flow matching requested but latents are missing.")
|
| 200 |
+
if total_patches > 0:
|
| 201 |
+
xvec_cond = self.xvec_proj(data["xvector"])
|
| 202 |
+
vocal_mask = data.get("vocal_mask")
|
| 203 |
+
if vocal_mask is None:
|
| 204 |
+
vocal_mask = torch.ones((batch_size,), device=device, dtype=torch.bool)
|
| 205 |
+
xvec_drop_mask = (
|
| 206 |
+
torch.empty((batch_size,), device=device, dtype=torch.float32).uniform_(
|
| 207 |
+
0, 1
|
| 208 |
+
)
|
| 209 |
+
< self.xvec_drop_rate
|
| 210 |
+
)
|
| 211 |
+
xvec_drop_mask = xvec_drop_mask & vocal_mask
|
| 212 |
+
xvec_cond = mask_data(xvec_cond, xvec_drop_mask)
|
| 213 |
+
|
| 214 |
+
hiddens_for_fm = torch.where(
|
| 215 |
+
output_span_mask.unsqueeze(-1), llm_hidden, inputs_embeds
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
# Prepare DiT inputs
|
| 219 |
+
(
|
| 220 |
+
fm_seq,
|
| 221 |
+
target,
|
| 222 |
+
fm_attn_mask,
|
| 223 |
+
fm_seq_mask,
|
| 224 |
+
fm_pos_ids,
|
| 225 |
+
times,
|
| 226 |
+
fm_prefix_lengths,
|
| 227 |
+
fm_gen_lengths,
|
| 228 |
+
fm_gen_patch_size,
|
| 229 |
+
) = self.io_helper.prepare_inputs_for_dit(
|
| 230 |
+
hiddens=hiddens_for_fm,
|
| 231 |
+
hidden_lens=input_ids_lengths,
|
| 232 |
+
latents=latents_sampled,
|
| 233 |
+
latent_lens=latent_lengths,
|
| 234 |
+
hidden_proj=self.hidden_proj,
|
| 235 |
+
latent_proj=self.latent_proj,
|
| 236 |
+
noisy_proj=self.coordinate_proj,
|
| 237 |
+
span_mask=output_span_mask,
|
| 238 |
+
hidden_patch_size=self.hidden_patch_size,
|
| 239 |
+
latent_patch_size=self.latent_patch_size,
|
| 240 |
+
fm_helper=self.fm_helper,
|
| 241 |
+
cfg_droprate=self.cfg_droprate,
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
# Predict velocity field
|
| 245 |
+
vt = self.velocity_field_predictor(
|
| 246 |
+
x=fm_seq,
|
| 247 |
+
timesteps=times,
|
| 248 |
+
pos_ids=fm_pos_ids,
|
| 249 |
+
mask=fm_seq_mask,
|
| 250 |
+
attn_mask=fm_attn_mask,
|
| 251 |
+
return_hidden_stats=False,
|
| 252 |
+
g_cond=xvec_cond,
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
# Get predictions and targets
|
| 256 |
+
pred = self.io_helper.get_dit_outputs(
|
| 257 |
+
pred_v=vt,
|
| 258 |
+
fm_prefix_lengths=fm_prefix_lengths,
|
| 259 |
+
fm_gen_lengths=fm_gen_lengths,
|
| 260 |
+
fm_gen_patch_size=fm_gen_patch_size,
|
| 261 |
+
latent_patch_size=self.latent_patch_size,
|
| 262 |
+
)
|
| 263 |
+
else:
|
| 264 |
+
# Dummy forward for velocity_field_predictor to keep gradients connected in DDP
|
| 265 |
+
dummy_length = self.latent_patch_size
|
| 266 |
+
dummy_seq_h = llm_hidden.new_zeros((1, dummy_length, self.llm_hidden_size))
|
| 267 |
+
dummy_seq_h = self.hidden_proj(dummy_seq_h) * 0.0 # dummy op for ddp
|
| 268 |
+
dummy_seq_l = llm_hidden.new_zeros((1, dummy_length, self.latent_dim))
|
| 269 |
+
dummy_seq_l = self.latent_proj(dummy_seq_l) * 0.0 # dummy op for ddp
|
| 270 |
+
dummy_seq_c = llm_hidden.new_zeros((1, dummy_length, self.latent_dim))
|
| 271 |
+
dummy_seq_c = self.coordinate_proj(dummy_seq_c) * 0.0 # dummy op for ddp
|
| 272 |
+
dummy_seq = dummy_seq_h + dummy_seq_l + dummy_seq_c
|
| 273 |
+
dummy_times = torch.zeros((1,), device=device, dtype=torch.float32)
|
| 274 |
+
dummy_attn_mask = torch.ones(
|
| 275 |
+
(1, dummy_length, dummy_length), device=device, dtype=torch.bool
|
| 276 |
+
)
|
| 277 |
+
dummy_out = self.velocity_field_predictor(
|
| 278 |
+
x=dummy_seq,
|
| 279 |
+
timesteps=dummy_times,
|
| 280 |
+
attn_mask=dummy_attn_mask,
|
| 281 |
+
)
|
| 282 |
+
pred = dummy_out[:, -self.latent_patch_size :, :]
|
| 283 |
+
target = pred.detach()
|
| 284 |
+
|
| 285 |
+
return DotsTtsForwardOutput(
|
| 286 |
+
llm_logits=llm_logits,
|
| 287 |
+
pred=pred,
|
| 288 |
+
target=target,
|
| 289 |
+
eos_out=eos,
|
| 290 |
+
)
|
| 291 |
+
# endregion Training forward path
|
| 292 |
+
|
| 293 |
+
# region Autoregressive and flow-matching inference steps
|
| 294 |
+
@torch.no_grad()
|
| 295 |
+
def fm_solver_step(
|
| 296 |
+
self,
|
| 297 |
+
t: torch.Tensor,
|
| 298 |
+
z: torch.Tensor,
|
| 299 |
+
*,
|
| 300 |
+
input_sequence: torch.Tensor,
|
| 301 |
+
cfg_sequence: torch.Tensor,
|
| 302 |
+
attn_mask: torch.Tensor,
|
| 303 |
+
pos_ids: torch.Tensor | None,
|
| 304 |
+
hidden_size: int,
|
| 305 |
+
patch_size: int,
|
| 306 |
+
g_cond: torch.Tensor | None,
|
| 307 |
+
guidance_scale: torch.Tensor | float,
|
| 308 |
+
) -> torch.Tensor:
|
| 309 |
+
batch_size = input_sequence.size(0)
|
| 310 |
+
if input_sequence.shape != cfg_sequence.shape:
|
| 311 |
+
raise ValueError(
|
| 312 |
+
"FM input_sequence and cfg_sequence must share the same shape."
|
| 313 |
+
)
|
| 314 |
+
if input_sequence.size(1) < patch_size:
|
| 315 |
+
raise ValueError(
|
| 316 |
+
"FM input sequence must reserve at least one latent patch slot."
|
| 317 |
+
)
|
| 318 |
+
latent_start = input_sequence.size(1) - patch_size
|
| 319 |
+
z = self.coordinate_proj(z)
|
| 320 |
+
z_c = input_sequence.clone()
|
| 321 |
+
z_c[:, latent_start:] = z
|
| 322 |
+
z_branches = [z_c]
|
| 323 |
+
g_cond_t = (
|
| 324 |
+
None if g_cond is None else g_cond.to(device=z_c.device, dtype=z_c.dtype)
|
| 325 |
+
)
|
| 326 |
+
g_cond_branches = None if g_cond_t is None else [g_cond_t]
|
| 327 |
+
|
| 328 |
+
z_cfg = cfg_sequence.clone()
|
| 329 |
+
z_cfg[:, latent_start:] = z
|
| 330 |
+
z_branches.append(z_cfg)
|
| 331 |
+
if g_cond_branches is not None:
|
| 332 |
+
g_cond_branches.append(torch.zeros_like(g_cond_t))
|
| 333 |
+
|
| 334 |
+
z_z = torch.cat(z_branches, dim=0)
|
| 335 |
+
t_t = t.reshape(1).repeat(len(z_branches))
|
| 336 |
+
if g_cond_branches is not None:
|
| 337 |
+
g_cond_t = torch.cat(g_cond_branches, dim=0)
|
| 338 |
+
vt = self.velocity_field_predictor(
|
| 339 |
+
x=z_z,
|
| 340 |
+
timesteps=t_t,
|
| 341 |
+
attn_mask=attn_mask,
|
| 342 |
+
pos_ids=pos_ids,
|
| 343 |
+
g_cond=g_cond_t,
|
| 344 |
+
hidden_size=patch_size * 2 + hidden_size,
|
| 345 |
+
patch_size=patch_size + 1,
|
| 346 |
+
)
|
| 347 |
+
vt = vt[:, latent_start:]
|
| 348 |
+
vt_c = vt[:batch_size]
|
| 349 |
+
vt_u = vt[batch_size:]
|
| 350 |
+
if not torch.is_tensor(guidance_scale):
|
| 351 |
+
guidance_scale = vt_c.new_tensor(float(guidance_scale))
|
| 352 |
+
else:
|
| 353 |
+
guidance_scale = guidance_scale.to(device=vt_c.device, dtype=vt_c.dtype)
|
| 354 |
+
return vt_c + guidance_scale * (vt_c - vt_u)
|
| 355 |
+
|
| 356 |
+
@torch.no_grad()
|
| 357 |
+
def step_llm(
|
| 358 |
+
self,
|
| 359 |
+
inputs_embeds: torch.Tensor | None = None,
|
| 360 |
+
input_ids: torch.Tensor | None = None,
|
| 361 |
+
past_key_values: Any | None = None,
|
| 362 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any | None]:
|
| 363 |
+
provided = int(inputs_embeds is not None) + int(input_ids is not None)
|
| 364 |
+
if provided != 1:
|
| 365 |
+
raise ValueError(
|
| 366 |
+
"Exactly one of inputs_embeds or input_ids must be provided to step_llm()."
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
if inputs_embeds is not None:
|
| 370 |
+
pass
|
| 371 |
+
else:
|
| 372 |
+
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
|
| 373 |
+
|
| 374 |
+
outputs = self.llm(
|
| 375 |
+
inputs_embeds=inputs_embeds,
|
| 376 |
+
past_key_values=past_key_values,
|
| 377 |
+
use_cache=True,
|
| 378 |
+
output_hidden_states=True,
|
| 379 |
+
return_dict=True,
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
hidden = outputs.hidden_states[-1]
|
| 383 |
+
logits = outputs.logits
|
| 384 |
+
past_key_values = outputs.past_key_values
|
| 385 |
+
|
| 386 |
+
return inputs_embeds, hidden, logits, past_key_values
|
| 387 |
+
|
| 388 |
+
@torch.no_grad()
|
| 389 |
+
def _meanflow_step_fm(
|
| 390 |
+
self,
|
| 391 |
+
*,
|
| 392 |
+
input_sequence: torch.Tensor,
|
| 393 |
+
attn_mask: torch.Tensor,
|
| 394 |
+
pos_ids: torch.Tensor | None,
|
| 395 |
+
patch_size: int,
|
| 396 |
+
g_cond: torch.Tensor | None = None,
|
| 397 |
+
nfe: int = 2,
|
| 398 |
+
solver_step: Callable[..., torch.Tensor] | None = None,
|
| 399 |
+
) -> torch.Tensor:
|
| 400 |
+
if nfe <= 0:
|
| 401 |
+
raise ValueError(f"MeanFlow nfe must be positive, got {nfe}.")
|
| 402 |
+
batch_size = input_sequence.size(0)
|
| 403 |
+
device = input_sequence.device
|
| 404 |
+
dtype = input_sequence.dtype
|
| 405 |
+
solver_step = self.meanflow_solver_step if solver_step is None else solver_step
|
| 406 |
+
z = (
|
| 407 |
+
torch.randn(
|
| 408 |
+
(batch_size, patch_size, self.latent_dim),
|
| 409 |
+
device=device,
|
| 410 |
+
dtype=dtype,
|
| 411 |
+
)
|
| 412 |
+
)
|
| 413 |
+
times = torch.linspace(0.0, 1.0, nfe + 1, device=device, dtype=dtype)
|
| 414 |
+
|
| 415 |
+
for step in range(nfe):
|
| 416 |
+
t = times[step].expand(batch_size)
|
| 417 |
+
dt = (times[step + 1] - times[step]).expand(batch_size)
|
| 418 |
+
z = solver_step(
|
| 419 |
+
z,
|
| 420 |
+
t=t,
|
| 421 |
+
dt=dt,
|
| 422 |
+
input_sequence=input_sequence,
|
| 423 |
+
attn_mask=attn_mask,
|
| 424 |
+
pos_ids=pos_ids,
|
| 425 |
+
patch_size=patch_size,
|
| 426 |
+
g_cond=g_cond,
|
| 427 |
+
).clone()
|
| 428 |
+
return z
|
| 429 |
+
|
| 430 |
+
@torch.no_grad()
|
| 431 |
+
def meanflow_solver_step(
|
| 432 |
+
self,
|
| 433 |
+
z: torch.Tensor,
|
| 434 |
+
*,
|
| 435 |
+
t: torch.Tensor,
|
| 436 |
+
dt: torch.Tensor,
|
| 437 |
+
input_sequence: torch.Tensor,
|
| 438 |
+
attn_mask: torch.Tensor,
|
| 439 |
+
pos_ids: torch.Tensor | None,
|
| 440 |
+
patch_size: int,
|
| 441 |
+
g_cond: torch.Tensor | None,
|
| 442 |
+
) -> torch.Tensor:
|
| 443 |
+
if input_sequence.size(1) < patch_size:
|
| 444 |
+
raise ValueError(
|
| 445 |
+
"MeanFlow input sequence must reserve at least one latent patch slot."
|
| 446 |
+
)
|
| 447 |
+
latent_start = input_sequence.size(1) - patch_size
|
| 448 |
+
z_proj = self.coordinate_proj(z)
|
| 449 |
+
z_c = input_sequence.clone()
|
| 450 |
+
z_c[:, latent_start:] = z_proj
|
| 451 |
+
vt = self.velocity_field_predictor(
|
| 452 |
+
x=z_c,
|
| 453 |
+
timesteps=t,
|
| 454 |
+
duration=dt,
|
| 455 |
+
attn_mask=attn_mask,
|
| 456 |
+
pos_ids=pos_ids,
|
| 457 |
+
g_cond=g_cond,
|
| 458 |
+
)
|
| 459 |
+
velocity = vt[:, latent_start:]
|
| 460 |
+
return z + velocity * dt.view(-1, 1, 1)
|
| 461 |
+
|
| 462 |
+
@torch.no_grad()
|
| 463 |
+
def _flow_matching_step_fm(
|
| 464 |
+
self,
|
| 465 |
+
*,
|
| 466 |
+
input_sequence: torch.Tensor,
|
| 467 |
+
cfg_sequence: torch.Tensor,
|
| 468 |
+
attn_mask: torch.Tensor,
|
| 469 |
+
pos_ids: torch.Tensor | None,
|
| 470 |
+
hidden_size: int,
|
| 471 |
+
patch_size: int,
|
| 472 |
+
g_cond: torch.Tensor | None = None,
|
| 473 |
+
ode_method: str = "euler",
|
| 474 |
+
num_steps: int = 10,
|
| 475 |
+
guidance_scale: float = 3.0,
|
| 476 |
+
solver_step: Callable[..., torch.Tensor] | None = None,
|
| 477 |
+
) -> torch.Tensor:
|
| 478 |
+
batch_size = input_sequence.size(0)
|
| 479 |
+
num_evals = 0
|
| 480 |
+
solver_step = self.fm_solver_step if solver_step is None else solver_step
|
| 481 |
+
guidance_scale_tensor = input_sequence.new_tensor(float(guidance_scale))
|
| 482 |
+
|
| 483 |
+
# Prepare ODE solver
|
| 484 |
+
def solver(t, z):
|
| 485 |
+
nonlocal num_evals
|
| 486 |
+
num_evals += 1
|
| 487 |
+
return solver_step(
|
| 488 |
+
t,
|
| 489 |
+
z,
|
| 490 |
+
input_sequence=input_sequence,
|
| 491 |
+
cfg_sequence=cfg_sequence,
|
| 492 |
+
attn_mask=attn_mask,
|
| 493 |
+
pos_ids=pos_ids,
|
| 494 |
+
hidden_size=hidden_size,
|
| 495 |
+
patch_size=patch_size,
|
| 496 |
+
g_cond=g_cond,
|
| 497 |
+
guidance_scale=guidance_scale_tensor,
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
# Prepare noise as initial coordinate
|
| 501 |
+
noise = torch.randn(
|
| 502 |
+
(batch_size, patch_size, self.latent_dim),
|
| 503 |
+
dtype=input_sequence.dtype,
|
| 504 |
+
device=input_sequence.device,
|
| 505 |
+
)
|
| 506 |
+
# Solve
|
| 507 |
+
times = torch.tensor(
|
| 508 |
+
[0.0, 1.0], dtype=input_sequence.dtype, device=input_sequence.device
|
| 509 |
+
)
|
| 510 |
+
if ode_method in ["euler", "midpoint", "rk4"]: # fixed step size methods
|
| 511 |
+
options = {"step_size": 1.0 / num_steps}
|
| 512 |
+
else:
|
| 513 |
+
logger.warning(
|
| 514 |
+
"Using adaptive step size ODE solver for FM, NFE is not guaranteed: "
|
| 515 |
+
"ode_method={}",
|
| 516 |
+
ode_method,
|
| 517 |
+
)
|
| 518 |
+
options = {}
|
| 519 |
+
trajectory = odeint(
|
| 520 |
+
func=solver,
|
| 521 |
+
y0=noise,
|
| 522 |
+
t=times,
|
| 523 |
+
atol=1e-5,
|
| 524 |
+
rtol=1e-5,
|
| 525 |
+
method=ode_method,
|
| 526 |
+
options=options,
|
| 527 |
+
)
|
| 528 |
+
# print(f"Expected NFE: {num_steps}, Actual NFE: {num_evals}")
|
| 529 |
+
return trajectory[-1]
|
| 530 |
+
|
| 531 |
+
@torch.no_grad()
|
| 532 |
+
def step_fm(
|
| 533 |
+
self,
|
| 534 |
+
input_sequence: torch.Tensor,
|
| 535 |
+
cfg_sequence: torch.Tensor,
|
| 536 |
+
attn_mask: torch.Tensor,
|
| 537 |
+
pos_ids: torch.Tensor | None,
|
| 538 |
+
hidden_size: int,
|
| 539 |
+
patch_size: int,
|
| 540 |
+
g_cond: torch.Tensor | None = None,
|
| 541 |
+
ode_method: str = "euler",
|
| 542 |
+
num_steps: int = 10,
|
| 543 |
+
guidance_scale: float = 3.0,
|
| 544 |
+
solver_step: Callable[..., torch.Tensor] | None = None,
|
| 545 |
+
) -> torch.Tensor:
|
| 546 |
+
if self.mode == "meanflow":
|
| 547 |
+
return self._meanflow_step_fm(
|
| 548 |
+
input_sequence=input_sequence,
|
| 549 |
+
attn_mask=attn_mask,
|
| 550 |
+
pos_ids=pos_ids,
|
| 551 |
+
patch_size=patch_size,
|
| 552 |
+
g_cond=g_cond,
|
| 553 |
+
nfe=num_steps,
|
| 554 |
+
solver_step=solver_step,
|
| 555 |
+
)
|
| 556 |
+
|
| 557 |
+
return self._flow_matching_step_fm(
|
| 558 |
+
input_sequence=input_sequence,
|
| 559 |
+
cfg_sequence=cfg_sequence,
|
| 560 |
+
attn_mask=attn_mask,
|
| 561 |
+
pos_ids=pos_ids,
|
| 562 |
+
hidden_size=hidden_size,
|
| 563 |
+
patch_size=patch_size,
|
| 564 |
+
g_cond=g_cond,
|
| 565 |
+
ode_method=ode_method,
|
| 566 |
+
num_steps=num_steps,
|
| 567 |
+
guidance_scale=guidance_scale,
|
| 568 |
+
solver_step=solver_step,
|
| 569 |
+
)
|
| 570 |
+
# endregion Autoregressive and flow-matching inference steps
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
class FlowMatchingHelper:
|
| 574 |
+
"""
|
| 575 |
+
Base helper for computing x_t and u_t, given target x_1 and noise x_0
|
| 576 |
+
ref: Flow matching for generative modeling, Lipman
|
| 577 |
+
"""
|
| 578 |
+
|
| 579 |
+
def __init__(self, sigma=1e-5):
|
| 580 |
+
self.sigma = sigma
|
| 581 |
+
|
| 582 |
+
def compute_mu_t(self, x1, t):
|
| 583 |
+
return t * x1
|
| 584 |
+
|
| 585 |
+
def compute_sigma_t(self, t):
|
| 586 |
+
return 1 - (1 - self.sigma) * t
|
| 587 |
+
|
| 588 |
+
def sample_x_t(self, x0, x1, t):
|
| 589 |
+
mu_t = self.compute_mu_t(x1, t)
|
| 590 |
+
sigma_t = self.compute_sigma_t(t)
|
| 591 |
+
return mu_t + sigma_t * x0
|
| 592 |
+
|
| 593 |
+
def compute_u_t(self, x0, x1):
|
| 594 |
+
return x1 - (1 - self.sigma) * x0
|
| 595 |
+
|
| 596 |
+
def compute_xt_ut(self, x1, t=None, x0=None):
|
| 597 |
+
if x0 is None:
|
| 598 |
+
x0 = torch.randn_like(x1, device=x1.device)
|
| 599 |
+
if t is None:
|
| 600 |
+
t = torch.rand(x1.size(0), dtype=x1.dtype, device=x1.device)
|
| 601 |
+
times = t
|
| 602 |
+
t = t.reshape(-1, *([1] * (x1.dim() - 1)))
|
| 603 |
+
xt = self.sample_x_t(x0, x1, t)
|
| 604 |
+
ut = self.compute_u_t(x0, x1)
|
| 605 |
+
return xt, ut, times
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
class CausalHelper:
|
| 609 |
+
def create_causal_mask_and_pos(self, seq_lens, max_len):
|
| 610 |
+
seq_mask = get_mask_from_lengths(seq_lens, max_len=max_len).unsqueeze(1)
|
| 611 |
+
causal_mask = (
|
| 612 |
+
torch.ones((max_len, max_len), device=seq_lens.device).triu(1).bool()
|
| 613 |
+
)
|
| 614 |
+
causal_mask = ~causal_mask.unsqueeze(0)
|
| 615 |
+
attn_mask = seq_mask & causal_mask
|
| 616 |
+
return attn_mask, seq_mask.squeeze(1), None
|
| 617 |
+
|
| 618 |
+
def create_causal_chunk_mask_and_pos(
|
| 619 |
+
self,
|
| 620 |
+
batch_size,
|
| 621 |
+
C_lens,
|
| 622 |
+
Z_lens,
|
| 623 |
+
span_mask,
|
| 624 |
+
patch_size=8,
|
| 625 |
+
):
|
| 626 |
+
device = C_lens.device
|
| 627 |
+
total_lens = C_lens + Z_lens
|
| 628 |
+
attn_mask = torch.zeros(
|
| 629 |
+
(batch_size, total_lens.max(), total_lens.max()),
|
| 630 |
+
device=device,
|
| 631 |
+
dtype=torch.bool,
|
| 632 |
+
)
|
| 633 |
+
pos_ids = []
|
| 634 |
+
# | C2C | |
|
| 635 |
+
# | Z2C | Z2Z |
|
| 636 |
+
for i in range(batch_size):
|
| 637 |
+
C_len = C_lens[i]
|
| 638 |
+
Z_len = Z_lens[i]
|
| 639 |
+
|
| 640 |
+
# C2C parts are standard causal attention
|
| 641 |
+
attn_mask[i, :C_len, :C_len] = (
|
| 642 |
+
torch.ones((C_len, C_len), device=device, dtype=torch.bool)
|
| 643 |
+
.triu(1)
|
| 644 |
+
.logical_not()
|
| 645 |
+
)
|
| 646 |
+
# Position ids in C parts are 0, 1, 2, ..., n
|
| 647 |
+
c_pos = torch.arange(C_len, device=device, dtype=torch.float32)
|
| 648 |
+
|
| 649 |
+
# Z2Z parts are block diag attention
|
| 650 |
+
assert Z_len % patch_size == 0, "Z_len must be multiple of patch_size"
|
| 651 |
+
attn_mask[i, C_len : C_len + Z_len, C_len : C_len + Z_len] = (
|
| 652 |
+
torch.block_diag(
|
| 653 |
+
*[
|
| 654 |
+
torch.ones(
|
| 655 |
+
(patch_size, patch_size), device=device, dtype=torch.bool
|
| 656 |
+
)
|
| 657 |
+
]
|
| 658 |
+
* (Z_len // patch_size)
|
| 659 |
+
)
|
| 660 |
+
)
|
| 661 |
+
|
| 662 |
+
# Z2C parts is full attention before current patch latents
|
| 663 |
+
# build according to span_mask
|
| 664 |
+
j_indices = torch.arange(Z_len, device=device)
|
| 665 |
+
patch_indices = j_indices // patch_size
|
| 666 |
+
patch_in_c_indices = torch.where(span_mask[i])[0][patch_indices]
|
| 667 |
+
attn_mask[
|
| 668 |
+
i,
|
| 669 |
+
C_len + j_indices.unsqueeze(1),
|
| 670 |
+
torch.arange(C_len, device=device).unsqueeze(0),
|
| 671 |
+
] = torch.arange(C_len, device=device).unsqueeze(
|
| 672 |
+
0
|
| 673 |
+
) < patch_in_c_indices.unsqueeze(1)
|
| 674 |
+
# Position ids in Z parts start from current patch latents index in C parts
|
| 675 |
+
z_pos = (patch_in_c_indices + j_indices % patch_size).to(torch.float32)
|
| 676 |
+
pos_ids.append(torch.cat([c_pos, z_pos]))
|
| 677 |
+
seq_mask = get_mask_from_lengths(total_lens, max_len=total_lens.max().item())
|
| 678 |
+
pos_ids = pad_sequence(pos_ids, batch_first=True, padding_value=0.0).to(
|
| 679 |
+
C_lens.device
|
| 680 |
+
)
|
| 681 |
+
return attn_mask, seq_mask, pos_ids
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
class IOHelper:
|
| 685 |
+
def __init__(self, latent_stats_path=None):
|
| 686 |
+
if latent_stats_path is not None:
|
| 687 |
+
latent_stats = torch.load(latent_stats_path, weights_only=False)
|
| 688 |
+
self.global_mean = torch.as_tensor(latent_stats["mean"])
|
| 689 |
+
self.global_var = torch.as_tensor(latent_stats["var"])
|
| 690 |
+
else:
|
| 691 |
+
self.global_mean = None
|
| 692 |
+
self.global_var = None
|
| 693 |
+
|
| 694 |
+
def normalize(self, x):
|
| 695 |
+
if self.global_mean is not None and self.global_var is not None:
|
| 696 |
+
x = (x - self.global_mean.to(x.device)) / torch.sqrt(
|
| 697 |
+
self.global_var.to(x.device)
|
| 698 |
+
)
|
| 699 |
+
return x
|
| 700 |
+
|
| 701 |
+
def denormalize(self, x):
|
| 702 |
+
if self.global_mean is not None and self.global_var is not None:
|
| 703 |
+
x = x * torch.sqrt(self.global_var.to(x.device)) + self.global_mean.to(
|
| 704 |
+
x.device
|
| 705 |
+
)
|
| 706 |
+
return x
|
| 707 |
+
|
| 708 |
+
@staticmethod
|
| 709 |
+
def sample_from_latent(latent):
|
| 710 |
+
mean, log_std = latent.chunk(2, 1)
|
| 711 |
+
z = mean + torch.randn_like(mean) * torch.exp(log_std)
|
| 712 |
+
return z.transpose(1, 2)
|
| 713 |
+
|
| 714 |
+
@staticmethod
|
| 715 |
+
def prepare_inputs_for_dit(
|
| 716 |
+
hiddens,
|
| 717 |
+
hidden_lens,
|
| 718 |
+
latents,
|
| 719 |
+
latent_lens,
|
| 720 |
+
hidden_proj,
|
| 721 |
+
latent_proj,
|
| 722 |
+
noisy_proj,
|
| 723 |
+
span_mask,
|
| 724 |
+
hidden_patch_size,
|
| 725 |
+
latent_patch_size,
|
| 726 |
+
fm_helper,
|
| 727 |
+
cfg_droprate=-1,
|
| 728 |
+
):
|
| 729 |
+
assert hidden_patch_size == 1, "Hidden patch size > 1 is not supported."
|
| 730 |
+
|
| 731 |
+
B, _, _, device = *hiddens.shape, hiddens.device
|
| 732 |
+
|
| 733 |
+
# Gather span hidden states for flow matching using span_mask
|
| 734 |
+
span_hidden_list = []
|
| 735 |
+
for b in range(B):
|
| 736 |
+
indices = span_mask[b].nonzero(as_tuple=False).squeeze(-1)
|
| 737 |
+
span_hidden_list.append(hiddens[b, indices, :])
|
| 738 |
+
hiddens = pad_sequence(span_hidden_list, batch_first=True, padding_value=0.0)
|
| 739 |
+
hidden_lens = torch.tensor(
|
| 740 |
+
[t.size(0) for t in span_hidden_list], device=device, dtype=torch.long
|
| 741 |
+
)
|
| 742 |
+
|
| 743 |
+
# Update span_mask to be all True for the new lengths
|
| 744 |
+
max_len = hiddens.size(1)
|
| 745 |
+
span_mask = torch.arange(max_len, device=device).expand(
|
| 746 |
+
B, max_len
|
| 747 |
+
) < hidden_lens.unsqueeze(1)
|
| 748 |
+
|
| 749 |
+
# Prepare history latents
|
| 750 |
+
history_latents = latent_proj(latents)
|
| 751 |
+
fm_dim = history_latents.shape[-1]
|
| 752 |
+
assert (latent_patch_size * history_latents.size(1) % latents.size(1)) == 0
|
| 753 |
+
latent_history_patch_size = (
|
| 754 |
+
latent_patch_size * history_latents.size(1) // latents.size(1)
|
| 755 |
+
)
|
| 756 |
+
|
| 757 |
+
# Prepare llm hidden with cfg masking
|
| 758 |
+
cfg_mask = (
|
| 759 |
+
torch.empty((B,), dtype=torch.float, device=latents.device).uniform_(0, 1)
|
| 760 |
+
< cfg_droprate
|
| 761 |
+
)
|
| 762 |
+
hiddens = hidden_proj(mask_data(hiddens, cfg_mask))
|
| 763 |
+
|
| 764 |
+
# Prepare noise latents
|
| 765 |
+
xt, ut, times = fm_helper.compute_xt_ut(latents)
|
| 766 |
+
projected_noise = noisy_proj(xt)
|
| 767 |
+
|
| 768 |
+
# Initialize empty fm_seq
|
| 769 |
+
hist_chunk_size = hidden_patch_size + latent_history_patch_size
|
| 770 |
+
valid_patch_counts = latent_lens // latent_patch_size
|
| 771 |
+
fm_prefix_lengths = hidden_lens + valid_patch_counts * (
|
| 772 |
+
hist_chunk_size - hidden_patch_size
|
| 773 |
+
)
|
| 774 |
+
fm_gen_lengths = latent_lens + valid_patch_counts * hidden_patch_size
|
| 775 |
+
fm_gen_patch_size = hidden_patch_size + latent_patch_size
|
| 776 |
+
fm_seq_lengths = fm_prefix_lengths + fm_gen_lengths
|
| 777 |
+
fm_seq = torch.zeros(
|
| 778 |
+
(B, fm_seq_lengths.max().item(), fm_dim),
|
| 779 |
+
dtype=history_latents.dtype,
|
| 780 |
+
device=device,
|
| 781 |
+
)
|
| 782 |
+
fm_target = []
|
| 783 |
+
patch_context_lengths = []
|
| 784 |
+
history_latent_span_mask = torch.zeros(
|
| 785 |
+
(B, fm_seq_lengths.max().item()), dtype=torch.bool, device=device
|
| 786 |
+
) # to mark start positions of each history latents
|
| 787 |
+
|
| 788 |
+
# Fill fm_seq
|
| 789 |
+
for b in range(B):
|
| 790 |
+
# Step 1: Interleave hiddens at span positions with patched_latents
|
| 791 |
+
interleaved = []
|
| 792 |
+
span_mask_b = span_mask[b, : hidden_lens[b]]
|
| 793 |
+
interleaved.append(
|
| 794 |
+
hiddens[b, : hidden_lens[b]][span_mask_b].reshape(
|
| 795 |
+
valid_patch_counts[b], hidden_patch_size, fm_dim
|
| 796 |
+
)
|
| 797 |
+
)
|
| 798 |
+
interleaved.append(
|
| 799 |
+
history_latents[
|
| 800 |
+
b, : valid_patch_counts[b] * latent_history_patch_size, :
|
| 801 |
+
].reshape(valid_patch_counts[b], latent_history_patch_size, fm_dim)
|
| 802 |
+
)
|
| 803 |
+
interleaved = torch.cat(interleaved, dim=1)
|
| 804 |
+
interleaved = rearrange(
|
| 805 |
+
interleaved, "n h d -> (n h) d"
|
| 806 |
+
) # [num_spans*hist_chunk_size, D]
|
| 807 |
+
|
| 808 |
+
# Step 2: Build mapping from input positions to fm positions
|
| 809 |
+
position_increment = torch.where(
|
| 810 |
+
span_mask_b, hist_chunk_size, 1
|
| 811 |
+
) # span->hist_chunk_size, non-span->1
|
| 812 |
+
fm_seq_positions = (
|
| 813 |
+
torch.cumsum(position_increment, dim=0) - position_increment
|
| 814 |
+
)
|
| 815 |
+
|
| 816 |
+
# Step 3: Scatter non-span hiddens
|
| 817 |
+
non_span_mask = ~span_mask_b
|
| 818 |
+
non_span_indices = fm_seq_positions[non_span_mask] # [num_non_spans]
|
| 819 |
+
fm_seq[b, non_span_indices, :] = hiddens[b, : hidden_lens[b]][
|
| 820 |
+
non_span_mask, :
|
| 821 |
+
]
|
| 822 |
+
|
| 823 |
+
# Step 4: Scatter interleaved span tokens
|
| 824 |
+
span_indices = fm_seq_positions[span_mask_b] # [num_spans]
|
| 825 |
+
span_indices_expanded = torch.stack(
|
| 826 |
+
[span_indices + i for i in range(hist_chunk_size)], dim=1
|
| 827 |
+
) # [num_spans, hist_chunk_size]
|
| 828 |
+
span_indices_flat = span_indices_expanded.reshape(
|
| 829 |
+
-1
|
| 830 |
+
) # [num_spans*hist_chunk_size]
|
| 831 |
+
fm_seq[b, span_indices_flat, :] = interleaved
|
| 832 |
+
history_latent_span_mask[b, span_indices] = True
|
| 833 |
+
patch_context_lengths.append(span_indices.clone())
|
| 834 |
+
|
| 835 |
+
# Step 5: Fill with noise latents at the end
|
| 836 |
+
noise_part = []
|
| 837 |
+
span_mask_b = span_mask[b, : hidden_lens[b]]
|
| 838 |
+
noise_part.append(
|
| 839 |
+
hiddens[b, : hidden_lens[b]][span_mask_b].reshape(
|
| 840 |
+
valid_patch_counts[b], hidden_patch_size, fm_dim
|
| 841 |
+
)
|
| 842 |
+
)
|
| 843 |
+
noise_part.append(
|
| 844 |
+
projected_noise[b, : latent_lens[b], :].reshape(
|
| 845 |
+
valid_patch_counts[b], latent_patch_size, fm_dim
|
| 846 |
+
)
|
| 847 |
+
)
|
| 848 |
+
noise_part = torch.cat(noise_part, dim=1)
|
| 849 |
+
noise_part = rearrange(noise_part, "n h d -> (n h) d")
|
| 850 |
+
noise_start = fm_seq_positions[-1] + position_increment[-1]
|
| 851 |
+
noise_end = noise_start + fm_gen_lengths[b]
|
| 852 |
+
fm_seq[b, noise_start:noise_end, :] = noise_part
|
| 853 |
+
|
| 854 |
+
# Step 6: prepare fm_target
|
| 855 |
+
ut_b = ut[b, : latent_lens[b], :]
|
| 856 |
+
fm_target.append(rearrange(ut_b, "(n p) d -> n p d", p=latent_patch_size))
|
| 857 |
+
|
| 858 |
+
# Construct fm_attn_mask and fm_pos_ids
|
| 859 |
+
fm_attn_mask, fm_seq_mask, fm_pos_ids = (
|
| 860 |
+
CausalHelper().create_causal_chunk_mask_and_pos(
|
| 861 |
+
batch_size=B,
|
| 862 |
+
C_lens=fm_prefix_lengths,
|
| 863 |
+
Z_lens=fm_gen_lengths,
|
| 864 |
+
span_mask=history_latent_span_mask,
|
| 865 |
+
patch_size=fm_gen_patch_size,
|
| 866 |
+
)
|
| 867 |
+
)
|
| 868 |
+
fm_prefix_lengths = fm_prefix_lengths.unsqueeze(1)
|
| 869 |
+
fm_gen_lengths = fm_gen_lengths.unsqueeze(1)
|
| 870 |
+
fm_target = torch.cat(fm_target, dim=0)
|
| 871 |
+
results = [
|
| 872 |
+
fm_seq,
|
| 873 |
+
fm_target,
|
| 874 |
+
fm_attn_mask,
|
| 875 |
+
fm_seq_mask,
|
| 876 |
+
fm_pos_ids,
|
| 877 |
+
times,
|
| 878 |
+
fm_prefix_lengths,
|
| 879 |
+
fm_gen_lengths,
|
| 880 |
+
fm_gen_patch_size,
|
| 881 |
+
]
|
| 882 |
+
return tuple(results)
|
| 883 |
+
|
| 884 |
+
@staticmethod
|
| 885 |
+
def get_dit_outputs(
|
| 886 |
+
pred_v,
|
| 887 |
+
fm_prefix_lengths,
|
| 888 |
+
fm_gen_lengths,
|
| 889 |
+
fm_gen_patch_size,
|
| 890 |
+
latent_patch_size,
|
| 891 |
+
):
|
| 892 |
+
B, P = fm_prefix_lengths.shape
|
| 893 |
+
fm_pred = []
|
| 894 |
+
for b in range(B):
|
| 895 |
+
p_offset = 0
|
| 896 |
+
for p in range(P):
|
| 897 |
+
latents_b = pred_v[
|
| 898 |
+
b,
|
| 899 |
+
p_offset + fm_prefix_lengths[b][p] : p_offset
|
| 900 |
+
+ fm_prefix_lengths[b][p]
|
| 901 |
+
+ fm_gen_lengths[b][p],
|
| 902 |
+
]
|
| 903 |
+
latents_b = rearrange(
|
| 904 |
+
latents_b, "(n p) d -> n p d", p=fm_gen_patch_size
|
| 905 |
+
)
|
| 906 |
+
# extract only the latent parts
|
| 907 |
+
latents_b = latents_b[:, -latent_patch_size:, :]
|
| 908 |
+
fm_pred.append(latents_b)
|
| 909 |
+
p_offset += fm_prefix_lengths[b][p] + fm_gen_lengths[b][p]
|
| 910 |
+
return torch.cat(fm_pred, dim=0)
|
src/dots_tts/models/dots_tts/model.py
ADDED
|
@@ -0,0 +1,1958 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import shutil
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from functools import partial
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Any, Callable, Iterator
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from einops import rearrange
|
| 16 |
+
from loguru import logger
|
| 17 |
+
from safetensors.torch import load_file, save_file
|
| 18 |
+
from transformers import AutoTokenizer, Qwen2Config
|
| 19 |
+
|
| 20 |
+
from dots_tts.models.dots_tts.config import ModelConfig
|
| 21 |
+
from dots_tts.models.dots_tts.core import DotsTtsCore, DotsTtsForwardOutput
|
| 22 |
+
from dots_tts.modules.speaker.encoder import SpeakerXVectorFeatures
|
| 23 |
+
from dots_tts.modules.vocoder.bigvgan import AudioVAE
|
| 24 |
+
from dots_tts.training.losses import LossMasks, LossTerm, LossTerms
|
| 25 |
+
from dots_tts.utils.profiling import measure_inference
|
| 26 |
+
from dots_tts.utils.tokenizer import AUDIO_GEN_START_TOKEN, require_token_id
|
| 27 |
+
from dots_tts.utils.util import get_dtype
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
_AOTI_BACKENDS = {"aoti", "aot", "aotinductor", "aot_inductor"}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class _AotiMethodModule(nn.Module):
|
| 34 |
+
def __init__(self, owner: nn.Module, method_name: str):
|
| 35 |
+
super().__init__()
|
| 36 |
+
self.owner = owner
|
| 37 |
+
self.method_name = method_name
|
| 38 |
+
|
| 39 |
+
def forward(self, *args, **kwargs):
|
| 40 |
+
raw_method = getattr(type(self.owner), self.method_name, None)
|
| 41 |
+
if raw_method is None:
|
| 42 |
+
return getattr(self.owner, self.method_name)(*args, **kwargs)
|
| 43 |
+
raw_callable = getattr(raw_method, "__wrapped__", raw_method)
|
| 44 |
+
return raw_callable(self.owner, *args, **kwargs)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class _LazyAotiCompiledMethod:
|
| 48 |
+
def __init__(
|
| 49 |
+
self,
|
| 50 |
+
*,
|
| 51 |
+
key: str,
|
| 52 |
+
owner: nn.Module,
|
| 53 |
+
method_name: str,
|
| 54 |
+
signature: tuple[Any, ...] | None,
|
| 55 |
+
):
|
| 56 |
+
self.key = key
|
| 57 |
+
self.owner = owner
|
| 58 |
+
self.method_name = method_name
|
| 59 |
+
self.signature = signature
|
| 60 |
+
self.compiled: Callable[..., Any] | None = None
|
| 61 |
+
self.fallback: Callable[..., Any] | None = None
|
| 62 |
+
|
| 63 |
+
def __call__(self, *args, **kwargs):
|
| 64 |
+
if self.compiled is not None:
|
| 65 |
+
return self.compiled(*args, **kwargs)
|
| 66 |
+
if self.fallback is not None:
|
| 67 |
+
return self.fallback(*args, **kwargs)
|
| 68 |
+
|
| 69 |
+
try:
|
| 70 |
+
import spaces # noqa: PLC0415
|
| 71 |
+
|
| 72 |
+
if not hasattr(spaces, "aoti_compile"):
|
| 73 |
+
raise RuntimeError("spaces.aoti_compile is not available.")
|
| 74 |
+
exported = torch.export.export(
|
| 75 |
+
_AotiMethodModule(self.owner, self.method_name).eval(),
|
| 76 |
+
args=args,
|
| 77 |
+
kwargs=kwargs,
|
| 78 |
+
)
|
| 79 |
+
self.compiled = spaces.aoti_compile(exported)
|
| 80 |
+
logger.info(
|
| 81 |
+
"AOTI compiled inference target: key={} method={} signature={}",
|
| 82 |
+
self.key,
|
| 83 |
+
self.method_name,
|
| 84 |
+
self.signature,
|
| 85 |
+
)
|
| 86 |
+
return self.compiled(*args, **kwargs)
|
| 87 |
+
except Exception:
|
| 88 |
+
if os.environ.get("DOTS_TTS_AOTI_ALLOW_EAGER_FALLBACK", "0") != "1":
|
| 89 |
+
raise
|
| 90 |
+
logger.exception(
|
| 91 |
+
"AOTI compile failed; falling back to eager method: key={} method={} signature={}",
|
| 92 |
+
self.key,
|
| 93 |
+
self.method_name,
|
| 94 |
+
self.signature,
|
| 95 |
+
)
|
| 96 |
+
self.fallback = getattr(self.owner, self.method_name)
|
| 97 |
+
return self.fallback(*args, **kwargs)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
@dataclass
|
| 101 |
+
class _GenerateState:
|
| 102 |
+
llm_cache: Any | None = None
|
| 103 |
+
llm_hiddens: torch.Tensor | None = None
|
| 104 |
+
patch_encoder_state: Any | None = None
|
| 105 |
+
fm_seq_len: int = 0
|
| 106 |
+
fm_capacity: int = 0
|
| 107 |
+
fm_sequence: torch.Tensor | None = None
|
| 108 |
+
fm_cfg_sequence: torch.Tensor | None = None
|
| 109 |
+
fm_null_g_cond: torch.Tensor | None = None
|
| 110 |
+
end_flag: bool = False
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@dataclass(frozen=True)
|
| 114 |
+
class _PromptConditioning:
|
| 115 |
+
prompt_patches: torch.Tensor | None = None
|
| 116 |
+
prompt_latents: torch.Tensor | None = None
|
| 117 |
+
g_cond: torch.Tensor | None = None
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@dataclass(frozen=True)
|
| 121 |
+
class _GenerateLengthBucket:
|
| 122 |
+
size: int
|
| 123 |
+
|
| 124 |
+
def run_warmup(
|
| 125 |
+
self,
|
| 126 |
+
model: "DotsTtsModel",
|
| 127 |
+
*,
|
| 128 |
+
precision: str,
|
| 129 |
+
ode_method: str,
|
| 130 |
+
num_steps: int,
|
| 131 |
+
guidance_scale: float,
|
| 132 |
+
) -> None:
|
| 133 |
+
model._warmup_fm_bucket(
|
| 134 |
+
max_audio_patch_count=self.size,
|
| 135 |
+
precision=precision,
|
| 136 |
+
ode_method=ode_method,
|
| 137 |
+
num_steps=num_steps,
|
| 138 |
+
guidance_scale=guidance_scale,
|
| 139 |
+
)
|
| 140 |
+
model._warmup_patch_encoder_bucket(
|
| 141 |
+
max_audio_patch_count=self.size,
|
| 142 |
+
precision=precision,
|
| 143 |
+
)
|
| 144 |
+
device = next(model.core.parameters()).device
|
| 145 |
+
generation_schedule = torch.full(
|
| 146 |
+
(1, self.size + 1),
|
| 147 |
+
fill_value=model.core.audio_gen_span_id,
|
| 148 |
+
dtype=torch.long,
|
| 149 |
+
device=device,
|
| 150 |
+
)
|
| 151 |
+
generation_schedule[0, 0] = model.audio_gen_start_id
|
| 152 |
+
warmup_inputs = {"generation_schedule": generation_schedule}
|
| 153 |
+
|
| 154 |
+
for _ in model.generate_audio_stream(
|
| 155 |
+
warmup_inputs,
|
| 156 |
+
precision=precision,
|
| 157 |
+
ode_method=ode_method,
|
| 158 |
+
num_steps=num_steps,
|
| 159 |
+
guidance_scale=guidance_scale,
|
| 160 |
+
):
|
| 161 |
+
return
|
| 162 |
+
raise RuntimeError(
|
| 163 |
+
f"Warmup produced no audio chunk for generate bucket {self.size}."
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class DotsTtsModel(nn.Module):
|
| 168 |
+
"""Full train/infer model assembly around the dots.tts core network."""
|
| 169 |
+
|
| 170 |
+
_GENERATE_LENGTH_BUCKETS = (
|
| 171 |
+
_GenerateLengthBucket(32),
|
| 172 |
+
_GenerateLengthBucket(64),
|
| 173 |
+
_GenerateLengthBucket(128),
|
| 174 |
+
_GenerateLengthBucket(256),
|
| 175 |
+
_GenerateLengthBucket(512),
|
| 176 |
+
_GenerateLengthBucket(1024),
|
| 177 |
+
)
|
| 178 |
+
_COMPILE_TARGETS = frozenset(
|
| 179 |
+
{
|
| 180 |
+
"FM",
|
| 181 |
+
"patch_encoder",
|
| 182 |
+
"vocoder",
|
| 183 |
+
}
|
| 184 |
+
)
|
| 185 |
+
_optimize_enabled = True
|
| 186 |
+
CONFIG_FILENAME = "config.json"
|
| 187 |
+
HF_MODEL_TYPE = "dots_tts"
|
| 188 |
+
HF_ARCHITECTURES = ["DotsTTSForConditionalGeneration"]
|
| 189 |
+
LATENT_STATS_FILENAME = "latent_stats.pt"
|
| 190 |
+
LLM_CONFIG_FILENAME = "llm_config.json"
|
| 191 |
+
MODEL_FILENAME = "model.safetensors"
|
| 192 |
+
VOCODER_FILENAME = "vocoder.safetensors"
|
| 193 |
+
SPEAKER_ENCODER_FILENAME = "speaker_encoder.safetensors"
|
| 194 |
+
_ARTIFACT_ALIASES = (("llm.lm_head.weight", "llm.model.embed_tokens.weight"),)
|
| 195 |
+
REQUIRED_ARTIFACT_FILES = (
|
| 196 |
+
CONFIG_FILENAME,
|
| 197 |
+
LATENT_STATS_FILENAME,
|
| 198 |
+
LLM_CONFIG_FILENAME,
|
| 199 |
+
MODEL_FILENAME,
|
| 200 |
+
VOCODER_FILENAME,
|
| 201 |
+
SPEAKER_ENCODER_FILENAME,
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
# region Module assembly and checkpoint IO
|
| 205 |
+
def __init__(
|
| 206 |
+
self,
|
| 207 |
+
config: ModelConfig,
|
| 208 |
+
tokenizer,
|
| 209 |
+
latent_stats_path: str | Path,
|
| 210 |
+
llm_config: Qwen2Config,
|
| 211 |
+
):
|
| 212 |
+
super().__init__()
|
| 213 |
+
self.config = config
|
| 214 |
+
self.tokenizer = tokenizer
|
| 215 |
+
self.latent_stats_path = Path(latent_stats_path)
|
| 216 |
+
self.audio_gen_start_id = require_token_id(
|
| 217 |
+
self.tokenizer, AUDIO_GEN_START_TOKEN
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
self.core = DotsTtsCore(
|
| 221 |
+
config,
|
| 222 |
+
llm_config=llm_config,
|
| 223 |
+
tokenizer=tokenizer,
|
| 224 |
+
latent_stats_path=self.latent_stats_path,
|
| 225 |
+
)
|
| 226 |
+
self.vocoder = AudioVAE(config.vocoder).eval()
|
| 227 |
+
self.vocoder.remove_weight_norm()
|
| 228 |
+
self.hop_size = self.vocoder.hop_size
|
| 229 |
+
self.xvector_extractor = SpeakerXVectorFeatures(
|
| 230 |
+
sample_rate=self.vocoder.sample_rate,
|
| 231 |
+
campplus_embedding_size=config.campplus_embedding_size,
|
| 232 |
+
max_audio_seconds=config.xvec_max_audio_seconds,
|
| 233 |
+
).eval()
|
| 234 |
+
|
| 235 |
+
for param in self.vocoder.parameters():
|
| 236 |
+
param.requires_grad = False
|
| 237 |
+
for param in self.xvector_extractor.parameters():
|
| 238 |
+
param.requires_grad = False
|
| 239 |
+
self._optimize_enabled = True
|
| 240 |
+
self._compiled_models: dict[
|
| 241 |
+
tuple[str, tuple[Any, ...] | None], Callable[..., Any]
|
| 242 |
+
] = {}
|
| 243 |
+
self._compile_backend = os.environ.get(
|
| 244 |
+
"DOTS_TTS_COMPILE_BACKEND",
|
| 245 |
+
"torch_compile",
|
| 246 |
+
).strip().lower()
|
| 247 |
+
self._static_generate_workspaces: dict[tuple[Any, ...], dict[str, Any]] = {}
|
| 248 |
+
self._fm_decode_workspaces: dict[tuple[Any, ...], dict[str, torch.Tensor]] = {}
|
| 249 |
+
|
| 250 |
+
def set_optimize(self, optimize: bool) -> None:
|
| 251 |
+
self._optimize_enabled = bool(optimize)
|
| 252 |
+
if not self._optimize_enabled:
|
| 253 |
+
self._compiled_models.clear()
|
| 254 |
+
|
| 255 |
+
def set_compile_backend(self, backend: str) -> None:
|
| 256 |
+
normalized_backend = (backend or "torch_compile").strip().lower()
|
| 257 |
+
if normalized_backend != self._compile_backend:
|
| 258 |
+
self._compiled_models.clear()
|
| 259 |
+
self._compile_backend = normalized_backend
|
| 260 |
+
|
| 261 |
+
def export_compiled_models(
|
| 262 |
+
self,
|
| 263 |
+
) -> dict[tuple[str, tuple[Any, ...] | None], Callable[..., Any]]:
|
| 264 |
+
exported: dict[tuple[str, tuple[Any, ...] | None], Callable[..., Any]] = {}
|
| 265 |
+
for cache_key, compiled in self._compiled_models.items():
|
| 266 |
+
if isinstance(compiled, _LazyAotiCompiledMethod):
|
| 267 |
+
if compiled.compiled is not None:
|
| 268 |
+
exported[cache_key] = compiled.compiled
|
| 269 |
+
continue
|
| 270 |
+
exported[cache_key] = compiled
|
| 271 |
+
return exported
|
| 272 |
+
|
| 273 |
+
def import_compiled_models(
|
| 274 |
+
self,
|
| 275 |
+
compiled_models: dict[tuple[str, tuple[Any, ...] | None], Callable[..., Any]],
|
| 276 |
+
) -> None:
|
| 277 |
+
self._compiled_models.update(compiled_models)
|
| 278 |
+
|
| 279 |
+
def set_cfg_droprate(
|
| 280 |
+
self,
|
| 281 |
+
cfg_droprate: float | None = None,
|
| 282 |
+
xvec_drop_rate: float | None = None,
|
| 283 |
+
) -> None:
|
| 284 |
+
if cfg_droprate is not None:
|
| 285 |
+
self.config.cfg_droprate = cfg_droprate
|
| 286 |
+
self.core.config.cfg_droprate = cfg_droprate
|
| 287 |
+
self.core.cfg_droprate = cfg_droprate
|
| 288 |
+
|
| 289 |
+
if xvec_drop_rate is not None:
|
| 290 |
+
self.config.xvec_drop_rate = xvec_drop_rate
|
| 291 |
+
self.core.config.xvec_drop_rate = xvec_drop_rate
|
| 292 |
+
self.core.xvec_drop_rate = xvec_drop_rate
|
| 293 |
+
|
| 294 |
+
@classmethod
|
| 295 |
+
def _resolve_generate_length_bucket(
|
| 296 |
+
cls,
|
| 297 |
+
max_generate_length: int,
|
| 298 |
+
) -> _GenerateLengthBucket:
|
| 299 |
+
requested = int(max_generate_length)
|
| 300 |
+
if requested <= 0:
|
| 301 |
+
raise ValueError("max_generate_length must be positive.")
|
| 302 |
+
for bucket in cls._GENERATE_LENGTH_BUCKETS:
|
| 303 |
+
if requested <= bucket.size:
|
| 304 |
+
return bucket
|
| 305 |
+
raise ValueError(
|
| 306 |
+
"max_generate_length exceeds the largest supported compile bucket: "
|
| 307 |
+
f"max_generate_length={requested} "
|
| 308 |
+
f"max_supported={cls._GENERATE_LENGTH_BUCKETS[-1].size}."
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
@torch.no_grad()
|
| 312 |
+
def run_warmup(
|
| 313 |
+
self,
|
| 314 |
+
*,
|
| 315 |
+
max_generate_length: int,
|
| 316 |
+
precision: str = "bfloat16",
|
| 317 |
+
ode_method: str = "euler",
|
| 318 |
+
num_steps: int = 10,
|
| 319 |
+
guidance_scale: float = 1.2,
|
| 320 |
+
) -> None:
|
| 321 |
+
ceiling_bucket = self._resolve_generate_length_bucket(max_generate_length)
|
| 322 |
+
warmup_buckets = tuple(
|
| 323 |
+
bucket
|
| 324 |
+
for bucket in self._GENERATE_LENGTH_BUCKETS
|
| 325 |
+
if bucket.size <= ceiling_bucket.size
|
| 326 |
+
)
|
| 327 |
+
bucket_sizes = [bucket.size for bucket in warmup_buckets]
|
| 328 |
+
logger.info(
|
| 329 |
+
"Inference warmup started: requested_max_generate_length={} bucket_sizes={}",
|
| 330 |
+
int(max_generate_length),
|
| 331 |
+
bucket_sizes,
|
| 332 |
+
)
|
| 333 |
+
for bucket in warmup_buckets:
|
| 334 |
+
bucket.run_warmup(
|
| 335 |
+
self,
|
| 336 |
+
precision=precision,
|
| 337 |
+
ode_method=ode_method,
|
| 338 |
+
num_steps=num_steps,
|
| 339 |
+
guidance_scale=guidance_scale,
|
| 340 |
+
)
|
| 341 |
+
logger.info(
|
| 342 |
+
"Inference warmup completed: requested_max_generate_length={} bucket_sizes={}",
|
| 343 |
+
int(max_generate_length),
|
| 344 |
+
bucket_sizes,
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
def _resolve_state_audio_patch_count(self, max_audio_patch_count: int) -> int:
|
| 348 |
+
requested = int(max_audio_patch_count)
|
| 349 |
+
if requested <= 0:
|
| 350 |
+
raise ValueError("max_audio_patch_count must be positive.")
|
| 351 |
+
if not self._optimize_enabled:
|
| 352 |
+
return requested
|
| 353 |
+
return self._resolve_generate_length_bucket(requested).size
|
| 354 |
+
|
| 355 |
+
def _warmup_fm_bucket(
|
| 356 |
+
self,
|
| 357 |
+
*,
|
| 358 |
+
max_audio_patch_count: int,
|
| 359 |
+
precision: str,
|
| 360 |
+
ode_method: str,
|
| 361 |
+
num_steps: int,
|
| 362 |
+
guidance_scale: float,
|
| 363 |
+
) -> None:
|
| 364 |
+
dtype = get_dtype(precision)
|
| 365 |
+
device = next(self.core.parameters()).device
|
| 366 |
+
use_amp = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}
|
| 367 |
+
with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_amp):
|
| 368 |
+
state = self._allocate_generate_state(
|
| 369 |
+
max_audio_patch_count=max_audio_patch_count,
|
| 370 |
+
device=device,
|
| 371 |
+
dtype=dtype,
|
| 372 |
+
)
|
| 373 |
+
state.fm_seq_len = state.fm_capacity
|
| 374 |
+
self._decode_next_audio(
|
| 375 |
+
state,
|
| 376 |
+
device=device,
|
| 377 |
+
g_cond=None,
|
| 378 |
+
ode_method=ode_method,
|
| 379 |
+
num_steps=num_steps,
|
| 380 |
+
guidance_scale=guidance_scale,
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
def _warmup_patch_encoder_bucket(
|
| 384 |
+
self,
|
| 385 |
+
*,
|
| 386 |
+
max_audio_patch_count: int,
|
| 387 |
+
precision: str,
|
| 388 |
+
) -> None:
|
| 389 |
+
dtype = get_dtype(precision)
|
| 390 |
+
device = next(self.core.parameters()).device
|
| 391 |
+
state_dtype = dtype if device.type == "cuda" else torch.float32
|
| 392 |
+
use_amp = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}
|
| 393 |
+
with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_amp):
|
| 394 |
+
state_audio_patch_count = self._resolve_state_audio_patch_count(
|
| 395 |
+
max_audio_patch_count
|
| 396 |
+
)
|
| 397 |
+
patch_encoder_state = self.core.patch_encoder.init_decode_state(
|
| 398 |
+
max_audio_patch_count=state_audio_patch_count,
|
| 399 |
+
batch_size=1,
|
| 400 |
+
device=device,
|
| 401 |
+
dtype=state_dtype,
|
| 402 |
+
)
|
| 403 |
+
audio_patch = torch.zeros(
|
| 404 |
+
(
|
| 405 |
+
1,
|
| 406 |
+
self.core.patch_encoder.patch_size,
|
| 407 |
+
self.core.latent_dim,
|
| 408 |
+
),
|
| 409 |
+
dtype=state_dtype,
|
| 410 |
+
device=device,
|
| 411 |
+
)
|
| 412 |
+
audio_patch = self.core.io_helper.denormalize(audio_patch)
|
| 413 |
+
patch_encoder_decode = self._get_compiled_method(
|
| 414 |
+
"patch_encoder.decode_patch",
|
| 415 |
+
self.core.patch_encoder,
|
| 416 |
+
"decode_patch",
|
| 417 |
+
signature=self._patch_encoder_compile_signature(patch_encoder_state),
|
| 418 |
+
)
|
| 419 |
+
positions = torch.arange(
|
| 420 |
+
self.core.patch_encoder.out_ds_rate,
|
| 421 |
+
device=device,
|
| 422 |
+
dtype=torch.long,
|
| 423 |
+
)
|
| 424 |
+
with measure_inference("patch_encoder"):
|
| 425 |
+
patch_encoder_decode(
|
| 426 |
+
audio_patch,
|
| 427 |
+
patch_encoder_state.conv_tail,
|
| 428 |
+
patch_encoder_state.layer_caches,
|
| 429 |
+
positions,
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
def _compile_callable(
|
| 433 |
+
self,
|
| 434 |
+
key: str,
|
| 435 |
+
model: Callable[..., Any],
|
| 436 |
+
*,
|
| 437 |
+
signature: tuple[Any, ...] | None = None,
|
| 438 |
+
) -> Callable[..., Any]:
|
| 439 |
+
compile_target = key.split(".", maxsplit=1)[0]
|
| 440 |
+
cache_key = (key, signature)
|
| 441 |
+
compiled = self._compiled_models.get(cache_key)
|
| 442 |
+
if compiled is None:
|
| 443 |
+
mode = (
|
| 444 |
+
"default"
|
| 445 |
+
if key == "patch_encoder.decode_patch"
|
| 446 |
+
else "reduce-overhead"
|
| 447 |
+
)
|
| 448 |
+
compiled = torch.compile(
|
| 449 |
+
model,
|
| 450 |
+
mode=mode,
|
| 451 |
+
fullgraph=True,
|
| 452 |
+
dynamic=False,
|
| 453 |
+
)
|
| 454 |
+
self._compiled_models[cache_key] = compiled
|
| 455 |
+
logger.info(
|
| 456 |
+
"Compiled inference target: key={} target={} signature={}",
|
| 457 |
+
key,
|
| 458 |
+
compile_target,
|
| 459 |
+
signature,
|
| 460 |
+
)
|
| 461 |
+
return compiled
|
| 462 |
+
|
| 463 |
+
def _get_compiled_model(
|
| 464 |
+
self,
|
| 465 |
+
key: str,
|
| 466 |
+
model: Callable[..., Any],
|
| 467 |
+
*,
|
| 468 |
+
signature: tuple[Any, ...] | None = None,
|
| 469 |
+
) -> Callable[..., Any]:
|
| 470 |
+
compile_target = key.split(".", maxsplit=1)[0]
|
| 471 |
+
if not self._optimize_enabled or compile_target not in self._COMPILE_TARGETS:
|
| 472 |
+
return model
|
| 473 |
+
return self._compile_callable(
|
| 474 |
+
key,
|
| 475 |
+
model,
|
| 476 |
+
signature=signature,
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
def _get_compiled_method(
|
| 480 |
+
self,
|
| 481 |
+
key: str,
|
| 482 |
+
owner: Any,
|
| 483 |
+
method_name: str,
|
| 484 |
+
*,
|
| 485 |
+
signature: tuple[Any, ...] | None = None,
|
| 486 |
+
) -> Callable[..., Any]:
|
| 487 |
+
bound_method = getattr(owner, method_name)
|
| 488 |
+
compile_target = key.split(".", maxsplit=1)[0]
|
| 489 |
+
if not self._optimize_enabled or compile_target not in self._COMPILE_TARGETS:
|
| 490 |
+
return bound_method
|
| 491 |
+
|
| 492 |
+
cache_key = (key, signature)
|
| 493 |
+
if self._compile_backend in _AOTI_BACKENDS:
|
| 494 |
+
compiled = self._compiled_models.get(cache_key)
|
| 495 |
+
if compiled is None:
|
| 496 |
+
compiled = _LazyAotiCompiledMethod(
|
| 497 |
+
key=key,
|
| 498 |
+
owner=owner,
|
| 499 |
+
method_name=method_name,
|
| 500 |
+
signature=signature,
|
| 501 |
+
)
|
| 502 |
+
self._compiled_models[cache_key] = compiled
|
| 503 |
+
return compiled
|
| 504 |
+
|
| 505 |
+
raw_method = getattr(type(owner), method_name)
|
| 506 |
+
raw_callable = getattr(raw_method, "__wrapped__", raw_method)
|
| 507 |
+
compiled = self._compile_callable(
|
| 508 |
+
key,
|
| 509 |
+
raw_callable,
|
| 510 |
+
signature=signature,
|
| 511 |
+
)
|
| 512 |
+
return partial(compiled, owner)
|
| 513 |
+
|
| 514 |
+
def _allocate_generate_state(
|
| 515 |
+
self,
|
| 516 |
+
*,
|
| 517 |
+
max_audio_patch_count: int,
|
| 518 |
+
device: torch.device,
|
| 519 |
+
dtype: torch.dtype,
|
| 520 |
+
) -> _GenerateState:
|
| 521 |
+
state_dtype = dtype if device.type == "cuda" else torch.float32
|
| 522 |
+
state_audio_patch_count = self._resolve_state_audio_patch_count(
|
| 523 |
+
max_audio_patch_count
|
| 524 |
+
)
|
| 525 |
+
fm_capacity = state_audio_patch_count * (
|
| 526 |
+
self.core.hidden_patch_size + self.core.latent_patch_size
|
| 527 |
+
)
|
| 528 |
+
workspace_key = (
|
| 529 |
+
state_audio_patch_count,
|
| 530 |
+
str(device),
|
| 531 |
+
state_dtype,
|
| 532 |
+
)
|
| 533 |
+
workspace = self._static_generate_workspaces.get(workspace_key)
|
| 534 |
+
if workspace is None:
|
| 535 |
+
workspace = {
|
| 536 |
+
"fm_sequence": torch.zeros(
|
| 537 |
+
(1, fm_capacity, self.core.fm_hidden_size),
|
| 538 |
+
dtype=state_dtype,
|
| 539 |
+
device=device,
|
| 540 |
+
),
|
| 541 |
+
"fm_cfg_sequence": torch.zeros(
|
| 542 |
+
(1, fm_capacity, self.core.fm_hidden_size),
|
| 543 |
+
dtype=state_dtype,
|
| 544 |
+
device=device,
|
| 545 |
+
),
|
| 546 |
+
"fm_null_g_cond": torch.zeros(
|
| 547 |
+
(1, self.core.fm_hidden_size),
|
| 548 |
+
dtype=state_dtype,
|
| 549 |
+
device=device,
|
| 550 |
+
),
|
| 551 |
+
}
|
| 552 |
+
self._static_generate_workspaces[workspace_key] = workspace
|
| 553 |
+
else:
|
| 554 |
+
workspace["fm_sequence"].zero_()
|
| 555 |
+
workspace["fm_cfg_sequence"].zero_()
|
| 556 |
+
|
| 557 |
+
patch_encoder_state = None
|
| 558 |
+
if not self._optimize_enabled:
|
| 559 |
+
patch_encoder_state = self.core.patch_encoder.init_decode_state(
|
| 560 |
+
max_audio_patch_count=state_audio_patch_count,
|
| 561 |
+
batch_size=1,
|
| 562 |
+
device=device,
|
| 563 |
+
dtype=state_dtype,
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
return _GenerateState(
|
| 567 |
+
patch_encoder_state=patch_encoder_state,
|
| 568 |
+
fm_seq_len=0,
|
| 569 |
+
fm_capacity=fm_capacity,
|
| 570 |
+
fm_sequence=workspace["fm_sequence"],
|
| 571 |
+
fm_cfg_sequence=workspace["fm_cfg_sequence"],
|
| 572 |
+
fm_null_g_cond=workspace["fm_null_g_cond"],
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
@staticmethod
|
| 576 |
+
def _tensor_storage_signature(tensor: torch.Tensor) -> tuple:
|
| 577 |
+
return (
|
| 578 |
+
tensor.untyped_storage().data_ptr(),
|
| 579 |
+
tensor.storage_offset(),
|
| 580 |
+
tuple(tensor.size()),
|
| 581 |
+
tuple(tensor.stride()),
|
| 582 |
+
tensor.dtype,
|
| 583 |
+
)
|
| 584 |
+
|
| 585 |
+
@classmethod
|
| 586 |
+
def _build_artifact_state_dict(cls, module) -> dict[str, torch.Tensor]:
|
| 587 |
+
state_dict = module.state_dict()
|
| 588 |
+
skip_keys = set()
|
| 589 |
+
|
| 590 |
+
for redundant_key, canonical_key in cls._ARTIFACT_ALIASES:
|
| 591 |
+
redundant_tensor = state_dict.get(redundant_key)
|
| 592 |
+
canonical_tensor = state_dict.get(canonical_key)
|
| 593 |
+
if (
|
| 594 |
+
redundant_tensor is not None
|
| 595 |
+
and canonical_tensor is not None
|
| 596 |
+
and cls._tensor_storage_signature(redundant_tensor)
|
| 597 |
+
== cls._tensor_storage_signature(canonical_tensor)
|
| 598 |
+
):
|
| 599 |
+
skip_keys.add(redundant_key)
|
| 600 |
+
|
| 601 |
+
cleaned_state_dict = {}
|
| 602 |
+
seen_storage = set()
|
| 603 |
+
for key, value in state_dict.items():
|
| 604 |
+
if key in skip_keys:
|
| 605 |
+
continue
|
| 606 |
+
|
| 607 |
+
storage_signature = cls._tensor_storage_signature(value)
|
| 608 |
+
if storage_signature in seen_storage:
|
| 609 |
+
continue
|
| 610 |
+
|
| 611 |
+
seen_storage.add(storage_signature)
|
| 612 |
+
cleaned_state_dict[key] = value.detach().cpu().contiguous()
|
| 613 |
+
|
| 614 |
+
return cleaned_state_dict
|
| 615 |
+
|
| 616 |
+
@classmethod
|
| 617 |
+
def _restore_artifact_state_dict(cls, state_dict: dict, module) -> dict:
|
| 618 |
+
restored_state_dict = dict(state_dict)
|
| 619 |
+
for redundant_key, canonical_key in cls._ARTIFACT_ALIASES:
|
| 620 |
+
if (
|
| 621 |
+
canonical_key in restored_state_dict
|
| 622 |
+
and redundant_key not in restored_state_dict
|
| 623 |
+
and redundant_key in module.state_dict()
|
| 624 |
+
):
|
| 625 |
+
restored_state_dict[redundant_key] = restored_state_dict[canonical_key]
|
| 626 |
+
return restored_state_dict
|
| 627 |
+
|
| 628 |
+
@classmethod
|
| 629 |
+
def _save_artifact_module(cls, module, path: Path) -> None:
|
| 630 |
+
save_file(cls._build_artifact_state_dict(module), path)
|
| 631 |
+
|
| 632 |
+
@classmethod
|
| 633 |
+
def _load_artifact_module(cls, module, path: Path):
|
| 634 |
+
state_dict = load_file(path, device="cpu")
|
| 635 |
+
restored_state_dict = cls._restore_artifact_state_dict(state_dict, module)
|
| 636 |
+
mismatch = module.load_state_dict(restored_state_dict, strict=False)
|
| 637 |
+
if mismatch.missing_keys or mismatch.unexpected_keys:
|
| 638 |
+
raise RuntimeError(f"Failed to load {path}: {mismatch}")
|
| 639 |
+
return module
|
| 640 |
+
|
| 641 |
+
@classmethod
|
| 642 |
+
def _validate_pretrained_directory(
|
| 643 |
+
cls, pretrained_model_name_or_path: str | Path
|
| 644 |
+
) -> Path:
|
| 645 |
+
pretrained_path = Path(pretrained_model_name_or_path).expanduser().resolve()
|
| 646 |
+
missing_files = [
|
| 647 |
+
name
|
| 648 |
+
for name in cls.REQUIRED_ARTIFACT_FILES
|
| 649 |
+
if not (pretrained_path / name).is_file()
|
| 650 |
+
]
|
| 651 |
+
if missing_files:
|
| 652 |
+
raise FileNotFoundError(
|
| 653 |
+
f"Pretrained path {pretrained_path} is missing required files: {missing_files}"
|
| 654 |
+
)
|
| 655 |
+
return pretrained_path
|
| 656 |
+
|
| 657 |
+
@classmethod
|
| 658 |
+
def _load_pretrained_config(cls, pretrained_path: Path) -> ModelConfig:
|
| 659 |
+
return ModelConfig.model_validate(
|
| 660 |
+
json.loads(
|
| 661 |
+
(pretrained_path / cls.CONFIG_FILENAME).read_text(encoding="utf-8")
|
| 662 |
+
)
|
| 663 |
+
)
|
| 664 |
+
|
| 665 |
+
@staticmethod
|
| 666 |
+
def _save_llm_config(llm_config: Qwen2Config, path: Path) -> None:
|
| 667 |
+
path.write_text(
|
| 668 |
+
json.dumps(llm_config.to_dict(), ensure_ascii=True, indent=2),
|
| 669 |
+
encoding="utf-8",
|
| 670 |
+
)
|
| 671 |
+
|
| 672 |
+
@staticmethod
|
| 673 |
+
def _load_llm_config(path: Path) -> Qwen2Config:
|
| 674 |
+
return Qwen2Config.from_dict(json.loads(path.read_text(encoding="utf-8")))
|
| 675 |
+
|
| 676 |
+
def _tie_llm_weights(self) -> None:
|
| 677 |
+
if hasattr(self.core.llm, "tie_weights"):
|
| 678 |
+
self.core.llm.tie_weights()
|
| 679 |
+
|
| 680 |
+
def save_pretrained(self, save_directory: str | Path) -> Path:
|
| 681 |
+
save_directory = Path(save_directory)
|
| 682 |
+
save_directory.mkdir(parents=True, exist_ok=True)
|
| 683 |
+
|
| 684 |
+
config_payload = self.config.to_declared_dict()
|
| 685 |
+
config_payload["model_type"] = self.HF_MODEL_TYPE
|
| 686 |
+
config_payload["architectures"] = list(self.HF_ARCHITECTURES)
|
| 687 |
+
(save_directory / self.CONFIG_FILENAME).write_text(
|
| 688 |
+
json.dumps(config_payload, ensure_ascii=True, indent=2),
|
| 689 |
+
encoding="utf-8",
|
| 690 |
+
)
|
| 691 |
+
self._save_llm_config(
|
| 692 |
+
self.core.llm.config,
|
| 693 |
+
save_directory / self.LLM_CONFIG_FILENAME,
|
| 694 |
+
)
|
| 695 |
+
self.tokenizer.save_pretrained(save_directory)
|
| 696 |
+
shutil.copy2(
|
| 697 |
+
self.latent_stats_path,
|
| 698 |
+
save_directory / self.LATENT_STATS_FILENAME,
|
| 699 |
+
)
|
| 700 |
+
self._save_artifact_module(self.core, save_directory / self.MODEL_FILENAME)
|
| 701 |
+
self._save_artifact_module(self.vocoder, save_directory / self.VOCODER_FILENAME)
|
| 702 |
+
self._save_artifact_module(
|
| 703 |
+
self.xvector_extractor,
|
| 704 |
+
save_directory / self.SPEAKER_ENCODER_FILENAME,
|
| 705 |
+
)
|
| 706 |
+
return save_directory
|
| 707 |
+
|
| 708 |
+
def _load_pretrained_artifacts(self, pretrained_path: Path) -> None:
|
| 709 |
+
self.latent_stats_path = pretrained_path / self.LATENT_STATS_FILENAME
|
| 710 |
+
self.core.io_helper = type(self.core.io_helper)(
|
| 711 |
+
latent_stats_path=self.latent_stats_path
|
| 712 |
+
)
|
| 713 |
+
self._load_artifact_module(self.core, pretrained_path / self.MODEL_FILENAME)
|
| 714 |
+
self._tie_llm_weights()
|
| 715 |
+
self._load_artifact_module(
|
| 716 |
+
self.vocoder, pretrained_path / self.VOCODER_FILENAME
|
| 717 |
+
)
|
| 718 |
+
self._load_artifact_module(
|
| 719 |
+
self.xvector_extractor,
|
| 720 |
+
pretrained_path / self.SPEAKER_ENCODER_FILENAME,
|
| 721 |
+
)
|
| 722 |
+
self.core.eval()
|
| 723 |
+
self.vocoder.eval()
|
| 724 |
+
self.xvector_extractor.eval()
|
| 725 |
+
|
| 726 |
+
def load_pretrained_weights(
|
| 727 |
+
self, pretrained_model_name_or_path: str | Path
|
| 728 |
+
) -> None:
|
| 729 |
+
pretrained_path = self._validate_pretrained_directory(
|
| 730 |
+
pretrained_model_name_or_path
|
| 731 |
+
)
|
| 732 |
+
saved_config = self._load_pretrained_config(pretrained_path)
|
| 733 |
+
if saved_config.to_declared_dict() != self.config.to_declared_dict():
|
| 734 |
+
raise ValueError(
|
| 735 |
+
f"Pretrained config at {pretrained_path} does not match the current model."
|
| 736 |
+
)
|
| 737 |
+
saved_llm_config = self._load_llm_config(
|
| 738 |
+
pretrained_path / self.LLM_CONFIG_FILENAME
|
| 739 |
+
)
|
| 740 |
+
if saved_llm_config.to_dict() != self.core.llm.config.to_dict():
|
| 741 |
+
raise ValueError(
|
| 742 |
+
f"Pretrained LLM config at {pretrained_path} does not match the current model."
|
| 743 |
+
)
|
| 744 |
+
self._load_pretrained_artifacts(pretrained_path)
|
| 745 |
+
|
| 746 |
+
@classmethod
|
| 747 |
+
def from_pretrained(cls, pretrained_model_name_or_path: str | Path):
|
| 748 |
+
logger.info(
|
| 749 |
+
"DotsTtsModel load started: pretrained_path={}",
|
| 750 |
+
pretrained_model_name_or_path,
|
| 751 |
+
)
|
| 752 |
+
pretrained_model_name_or_path = cls._validate_pretrained_directory(
|
| 753 |
+
pretrained_model_name_or_path
|
| 754 |
+
)
|
| 755 |
+
config = cls._load_pretrained_config(pretrained_model_name_or_path)
|
| 756 |
+
llm_config = cls._load_llm_config(
|
| 757 |
+
pretrained_model_name_or_path / cls.LLM_CONFIG_FILENAME
|
| 758 |
+
)
|
| 759 |
+
logger.info(
|
| 760 |
+
"DotsTtsModel config loaded: pretrained_path={} sample_rate={} patch_size={}",
|
| 761 |
+
pretrained_model_name_or_path,
|
| 762 |
+
config.vocoder.sample_rate,
|
| 763 |
+
config.patch_size,
|
| 764 |
+
)
|
| 765 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 766 |
+
str(pretrained_model_name_or_path),
|
| 767 |
+
local_files_only=True,
|
| 768 |
+
)
|
| 769 |
+
model = cls(
|
| 770 |
+
config,
|
| 771 |
+
tokenizer=tokenizer,
|
| 772 |
+
latent_stats_path=pretrained_model_name_or_path / cls.LATENT_STATS_FILENAME,
|
| 773 |
+
llm_config=llm_config,
|
| 774 |
+
)
|
| 775 |
+
model._load_pretrained_artifacts(pretrained_model_name_or_path)
|
| 776 |
+
logger.info(
|
| 777 |
+
"DotsTtsModel load completed: pretrained_path={}",
|
| 778 |
+
pretrained_model_name_or_path,
|
| 779 |
+
)
|
| 780 |
+
return model.eval()
|
| 781 |
+
|
| 782 |
+
# endregion Module assembly and checkpoint IO
|
| 783 |
+
|
| 784 |
+
# region Training batch preparation
|
| 785 |
+
@torch.no_grad()
|
| 786 |
+
def prepare_training_inputs(self, data: dict[str, Any]) -> dict[str, Any]:
|
| 787 |
+
self.vocoder.eval()
|
| 788 |
+
self.xvector_extractor.eval()
|
| 789 |
+
processed = dict(data)
|
| 790 |
+
sample: torch.Tensor | None = data.get("sample")
|
| 791 |
+
sample_lengths: torch.Tensor | None = data.get("sample_lengths")
|
| 792 |
+
|
| 793 |
+
if sample is not None:
|
| 794 |
+
latents = self.vocoder.extract_latents(sample)
|
| 795 |
+
processed["latents"] = latents
|
| 796 |
+
if sample_lengths is not None:
|
| 797 |
+
processed["latent_lengths"] = sample_lengths // self.hop_size
|
| 798 |
+
else:
|
| 799 |
+
processed["latent_lengths"] = torch.full(
|
| 800 |
+
(latents.size(0),),
|
| 801 |
+
latents.size(-1),
|
| 802 |
+
dtype=torch.long,
|
| 803 |
+
device=latents.device,
|
| 804 |
+
)
|
| 805 |
+
processed["latents_sampled"] = self.core.io_helper.sample_from_latent(
|
| 806 |
+
latents
|
| 807 |
+
)
|
| 808 |
+
fbank = data.get("fbank")
|
| 809 |
+
fbank_lengths = data.get("fbank_lengths")
|
| 810 |
+
processed["xvector"] = self.xvector_extractor(
|
| 811 |
+
sample,
|
| 812 |
+
audio_lengths=sample_lengths,
|
| 813 |
+
fbank=fbank,
|
| 814 |
+
fbank_lengths=fbank_lengths,
|
| 815 |
+
)
|
| 816 |
+
else:
|
| 817 |
+
processed["latents"] = None
|
| 818 |
+
processed["latent_lengths"] = None
|
| 819 |
+
|
| 820 |
+
return processed
|
| 821 |
+
|
| 822 |
+
def _build_audio_span_mask(self, token_ids: torch.Tensor) -> torch.Tensor:
|
| 823 |
+
span_mask = torch.zeros_like(token_ids, dtype=torch.bool)
|
| 824 |
+
for token_id in self.core.audio_span_token_ids:
|
| 825 |
+
span_mask = span_mask | (token_ids == token_id)
|
| 826 |
+
return span_mask
|
| 827 |
+
|
| 828 |
+
def _prepare_loss_metadata(self, data: dict[str, Any]) -> dict[str, Any]:
|
| 829 |
+
input_ids: torch.Tensor = data["input_ids"]
|
| 830 |
+
labels: torch.Tensor = data["labels"]
|
| 831 |
+
loss_mask: torch.Tensor = data["loss_mask"]
|
| 832 |
+
input_span_mask = self._build_audio_span_mask(input_ids)
|
| 833 |
+
output_span_mask = self._build_audio_span_mask(labels)
|
| 834 |
+
output_span_mask_float = output_span_mask.to(loss_mask.dtype)
|
| 835 |
+
llm_loss_mask = loss_mask * (1.0 - output_span_mask_float)
|
| 836 |
+
fm_loss_mask = loss_mask * output_span_mask_float
|
| 837 |
+
patch_counts = output_span_mask.sum(dim=1)
|
| 838 |
+
max_patch_count = max(1, int(patch_counts.max().item()))
|
| 839 |
+
fm_patch_mask = loss_mask.new_zeros((loss_mask.size(0), max_patch_count))
|
| 840 |
+
for batch_idx in range(loss_mask.size(0)):
|
| 841 |
+
count = int(patch_counts[batch_idx].item())
|
| 842 |
+
if count <= 0:
|
| 843 |
+
continue
|
| 844 |
+
fm_patch_mask[batch_idx, :count] = fm_loss_mask[batch_idx].masked_select(
|
| 845 |
+
output_span_mask[batch_idx]
|
| 846 |
+
)
|
| 847 |
+
|
| 848 |
+
return {
|
| 849 |
+
"input_span_mask": input_span_mask,
|
| 850 |
+
"output_span_mask": output_span_mask,
|
| 851 |
+
"loss_masks": {
|
| 852 |
+
"ce_loss": llm_loss_mask,
|
| 853 |
+
"fm_loss": fm_patch_mask,
|
| 854 |
+
"eos_loss": self._build_eos_loss_mask(fm_loss_mask),
|
| 855 |
+
},
|
| 856 |
+
}
|
| 857 |
+
|
| 858 |
+
@staticmethod
|
| 859 |
+
def _build_eos_loss_mask(eos_loss_mask: torch.Tensor) -> torch.Tensor:
|
| 860 |
+
batch_size, seq_len = eos_loss_mask.shape
|
| 861 |
+
mask = eos_loss_mask.to(dtype=torch.bool)
|
| 862 |
+
target = torch.zeros((batch_size, seq_len), dtype=torch.bool, device=mask.device)
|
| 863 |
+
mask_counts = mask.sum(dim=1, keepdim=True)
|
| 864 |
+
cumulative = mask.long().cumsum(dim=1)
|
| 865 |
+
target[mask & (cumulative == mask_counts)] = True
|
| 866 |
+
|
| 867 |
+
mask_counts_flat = mask_counts.squeeze(1)
|
| 868 |
+
neg_counts = (mask_counts_flat - 1).clamp_min(0).to(eos_loss_mask.dtype)
|
| 869 |
+
pos_weight = torch.where(
|
| 870 |
+
neg_counts > 0,
|
| 871 |
+
torch.full_like(neg_counts, 0.5),
|
| 872 |
+
torch.ones_like(neg_counts),
|
| 873 |
+
).unsqueeze(1)
|
| 874 |
+
neg_weight = torch.where(
|
| 875 |
+
neg_counts > 0,
|
| 876 |
+
0.5 / neg_counts,
|
| 877 |
+
torch.zeros_like(neg_counts),
|
| 878 |
+
).unsqueeze(1)
|
| 879 |
+
|
| 880 |
+
positive_mask = target & mask
|
| 881 |
+
negative_mask = mask & ~positive_mask
|
| 882 |
+
return torch.where(
|
| 883 |
+
positive_mask,
|
| 884 |
+
pos_weight,
|
| 885 |
+
negative_mask.to(eos_loss_mask.dtype) * neg_weight,
|
| 886 |
+
)
|
| 887 |
+
# endregion Training batch preparation
|
| 888 |
+
|
| 889 |
+
# region Training loss assembly and forward
|
| 890 |
+
@staticmethod
|
| 891 |
+
def _compute_ce_loss_term(
|
| 892 |
+
llm_logits: torch.Tensor,
|
| 893 |
+
llm_labels: torch.Tensor,
|
| 894 |
+
llm_loss_mask: torch.Tensor,
|
| 895 |
+
) -> LossTerm:
|
| 896 |
+
vocab_size = llm_logits.size(-1)
|
| 897 |
+
ce_loss = F.cross_entropy(
|
| 898 |
+
llm_logits.view(-1, vocab_size),
|
| 899 |
+
llm_labels.view(-1),
|
| 900 |
+
reduction="none",
|
| 901 |
+
).view_as(llm_labels)
|
| 902 |
+
return LossTerm(loss=ce_loss, mask=llm_loss_mask.to(ce_loss.dtype))
|
| 903 |
+
|
| 904 |
+
@staticmethod
|
| 905 |
+
def _compute_fm_loss_term(
|
| 906 |
+
pred: torch.Tensor,
|
| 907 |
+
target: torch.Tensor,
|
| 908 |
+
fm_patch_mask: torch.Tensor,
|
| 909 |
+
) -> LossTerm:
|
| 910 |
+
batch_size, max_patch_count = fm_patch_mask.shape
|
| 911 |
+
fm_loss = (pred - target).pow(2).mean(dim=2).mean(dim=1)
|
| 912 |
+
loss = fm_loss.new_zeros((batch_size, max_patch_count))
|
| 913 |
+
patch_counts = fm_patch_mask.gt(0).sum(dim=1).tolist()
|
| 914 |
+
expected_count = int(sum(patch_counts))
|
| 915 |
+
if expected_count > 0 and int(fm_loss.numel()) != expected_count:
|
| 916 |
+
raise RuntimeError(
|
| 917 |
+
"Flow-matching loss count mismatch: "
|
| 918 |
+
f"expected {expected_count}, got {int(fm_loss.numel())}."
|
| 919 |
+
)
|
| 920 |
+
|
| 921 |
+
offset = 0
|
| 922 |
+
for batch_idx, patch_count in enumerate(patch_counts):
|
| 923 |
+
if patch_count <= 0:
|
| 924 |
+
continue
|
| 925 |
+
next_offset = offset + int(patch_count)
|
| 926 |
+
loss[batch_idx, :patch_count] = fm_loss[offset:next_offset]
|
| 927 |
+
offset = next_offset
|
| 928 |
+
return LossTerm(loss=loss, mask=fm_patch_mask.to(loss.dtype))
|
| 929 |
+
|
| 930 |
+
@staticmethod
|
| 931 |
+
def _compute_eos_loss_term(
|
| 932 |
+
eos_out: torch.Tensor,
|
| 933 |
+
eos_loss_mask: torch.Tensor,
|
| 934 |
+
) -> LossTerm:
|
| 935 |
+
batch_size, seq_len, _ = eos_out.shape
|
| 936 |
+
weights = eos_loss_mask.to(device=eos_out.device)
|
| 937 |
+
mask = weights.gt(0)
|
| 938 |
+
target = torch.zeros(
|
| 939 |
+
(batch_size, seq_len),
|
| 940 |
+
dtype=torch.long,
|
| 941 |
+
device=eos_out.device,
|
| 942 |
+
)
|
| 943 |
+
mask_counts = mask.sum(dim=1, keepdim=True)
|
| 944 |
+
cumulative = mask.long().cumsum(dim=1)
|
| 945 |
+
target[mask & (cumulative == mask_counts)] = 1
|
| 946 |
+
|
| 947 |
+
logits = rearrange(eos_out, "b n c -> b c n")
|
| 948 |
+
ce_per_token = F.cross_entropy(logits, target, reduction="none")
|
| 949 |
+
return LossTerm(loss=ce_per_token, mask=weights.to(ce_per_token.dtype))
|
| 950 |
+
|
| 951 |
+
@staticmethod
|
| 952 |
+
def _compute_eos_loss_stats(
|
| 953 |
+
eos_out: torch.Tensor,
|
| 954 |
+
eos_loss_mask: torch.Tensor,
|
| 955 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 956 |
+
weights = DotsTtsModel._build_eos_loss_mask(eos_loss_mask)
|
| 957 |
+
term = DotsTtsModel._compute_eos_loss_term(eos_out, weights)
|
| 958 |
+
mask = term.mask.to(device=term.loss.device, dtype=term.loss.dtype)
|
| 959 |
+
eos_loss_sum = (term.loss * mask).sum(dim=1)
|
| 960 |
+
eos_sample_count = eos_loss_mask.to(device=term.loss.device).gt(0).any(
|
| 961 |
+
dim=1
|
| 962 |
+
).to(term.loss.dtype)
|
| 963 |
+
return eos_loss_sum, eos_sample_count
|
| 964 |
+
|
| 965 |
+
def _compute_loss_terms(
|
| 966 |
+
self,
|
| 967 |
+
outputs: DotsTtsForwardOutput,
|
| 968 |
+
*,
|
| 969 |
+
labels: torch.Tensor,
|
| 970 |
+
loss_masks: LossMasks,
|
| 971 |
+
) -> LossTerms:
|
| 972 |
+
return {
|
| 973 |
+
"ce_loss": self._compute_ce_loss_term(
|
| 974 |
+
outputs.llm_logits,
|
| 975 |
+
labels,
|
| 976 |
+
loss_masks["ce_loss"],
|
| 977 |
+
),
|
| 978 |
+
"fm_loss": self._compute_fm_loss_term(
|
| 979 |
+
outputs.pred,
|
| 980 |
+
outputs.target,
|
| 981 |
+
loss_masks["fm_loss"],
|
| 982 |
+
),
|
| 983 |
+
"eos_loss": self._compute_eos_loss_term(
|
| 984 |
+
outputs.eos_out,
|
| 985 |
+
loss_masks["eos_loss"],
|
| 986 |
+
),
|
| 987 |
+
}
|
| 988 |
+
|
| 989 |
+
def prepare_training_batch(self, data: dict[str, Any]) -> dict[str, Any]:
|
| 990 |
+
prepared = dict(data)
|
| 991 |
+
prepared.update(self._prepare_loss_metadata(prepared))
|
| 992 |
+
return prepared
|
| 993 |
+
|
| 994 |
+
def forward(self, data: dict[str, Any]) -> LossTerms:
|
| 995 |
+
loss_masks: LossMasks = data["loss_masks"]
|
| 996 |
+
processed = self.prepare_training_inputs(data)
|
| 997 |
+
processed["input_span_mask"] = data["input_span_mask"]
|
| 998 |
+
processed["output_span_mask"] = data["output_span_mask"]
|
| 999 |
+
return self._compute_loss_terms(
|
| 1000 |
+
self.core(processed),
|
| 1001 |
+
labels=processed["labels"],
|
| 1002 |
+
loss_masks=loss_masks,
|
| 1003 |
+
)
|
| 1004 |
+
# endregion Training loss assembly and forward
|
| 1005 |
+
|
| 1006 |
+
# region Prompt conditioning and decode state helpers
|
| 1007 |
+
@torch.no_grad()
|
| 1008 |
+
def _prepare_prompt_conditioning(
|
| 1009 |
+
self,
|
| 1010 |
+
prompt_audio: torch.Tensor | None,
|
| 1011 |
+
*,
|
| 1012 |
+
use_prompt_prefill: bool,
|
| 1013 |
+
speaker_scale: float = 1.5,
|
| 1014 |
+
) -> _PromptConditioning:
|
| 1015 |
+
if prompt_audio is None:
|
| 1016 |
+
logger.info("Prompt conditioning skipped: no prompt audio provided.")
|
| 1017 |
+
return _PromptConditioning()
|
| 1018 |
+
|
| 1019 |
+
self.vocoder.eval()
|
| 1020 |
+
self.xvector_extractor.eval()
|
| 1021 |
+
device = next(self.core.parameters()).device
|
| 1022 |
+
if prompt_audio.ndim == 1:
|
| 1023 |
+
prompt_audio = prompt_audio.unsqueeze(0)
|
| 1024 |
+
prompt_audio = prompt_audio.to(device=device)
|
| 1025 |
+
|
| 1026 |
+
target_len = math.ceil(
|
| 1027 |
+
prompt_audio.size(1) / (self.config.patch_size * self.hop_size)
|
| 1028 |
+
) * (self.config.patch_size * self.hop_size)
|
| 1029 |
+
pad_len = target_len - prompt_audio.size(1)
|
| 1030 |
+
if pad_len > 0:
|
| 1031 |
+
prompt_audio = F.pad(prompt_audio, (0, pad_len))
|
| 1032 |
+
|
| 1033 |
+
speaker_encoder = self._get_compiled_model(
|
| 1034 |
+
"speaker_encoder",
|
| 1035 |
+
self.xvector_extractor,
|
| 1036 |
+
)
|
| 1037 |
+
with measure_inference("speaker_encoder"):
|
| 1038 |
+
speaker_embedding = (
|
| 1039 |
+
speaker_encoder(prompt_audio[None, :]) * float(speaker_scale)
|
| 1040 |
+
)
|
| 1041 |
+
g_cond = self.core.xvec_proj(speaker_embedding)
|
| 1042 |
+
if not use_prompt_prefill:
|
| 1043 |
+
logger.info(
|
| 1044 |
+
"Reference-audio-only conditioning prepared: prompt_samples={} speaker_scale={} device={}",
|
| 1045 |
+
prompt_audio.shape[-1],
|
| 1046 |
+
speaker_scale,
|
| 1047 |
+
device,
|
| 1048 |
+
)
|
| 1049 |
+
return _PromptConditioning(g_cond=g_cond)
|
| 1050 |
+
|
| 1051 |
+
latent_encoder = self._get_compiled_model(
|
| 1052 |
+
"latent_encoder",
|
| 1053 |
+
self.vocoder.extract_latents,
|
| 1054 |
+
)
|
| 1055 |
+
with measure_inference("latent_encoder"):
|
| 1056 |
+
prompt_latents = latent_encoder(prompt_audio[None, :])
|
| 1057 |
+
prompt_latents_sampled = self.core.io_helper.sample_from_latent(prompt_latents)
|
| 1058 |
+
prompt_latents_sampled = prompt_latents_sampled[:, : -self.config.patch_size]
|
| 1059 |
+
prompt_patches = rearrange(
|
| 1060 |
+
self.core.io_helper.normalize(prompt_latents_sampled),
|
| 1061 |
+
"b (s p) d -> b s p d",
|
| 1062 |
+
p=self.config.patch_size,
|
| 1063 |
+
)
|
| 1064 |
+
logger.info(
|
| 1065 |
+
"Prompt conditioning prepared: prompt_samples={} prompt_patch_count={} "
|
| 1066 |
+
"speaker_scale={} device={}",
|
| 1067 |
+
prompt_audio.shape[-1],
|
| 1068 |
+
prompt_patches.size(1),
|
| 1069 |
+
speaker_scale,
|
| 1070 |
+
device,
|
| 1071 |
+
)
|
| 1072 |
+
return _PromptConditioning(
|
| 1073 |
+
prompt_patches=prompt_patches,
|
| 1074 |
+
prompt_latents=prompt_latents_sampled,
|
| 1075 |
+
g_cond=g_cond,
|
| 1076 |
+
)
|
| 1077 |
+
|
| 1078 |
+
@staticmethod
|
| 1079 |
+
def _patch_encoder_compile_signature(
|
| 1080 |
+
patch_encoder_state: Any,
|
| 1081 |
+
) -> tuple[int, torch.dtype]:
|
| 1082 |
+
key_cache, _ = patch_encoder_state.layer_caches[0]
|
| 1083 |
+
return int(key_cache.size(2)), key_cache.dtype
|
| 1084 |
+
|
| 1085 |
+
def _resolve_patch_encoder_audio_bucket(self, required_seq_len: int) -> int:
|
| 1086 |
+
requested = int(required_seq_len)
|
| 1087 |
+
if requested <= 0:
|
| 1088 |
+
raise ValueError("required_seq_len must be positive.")
|
| 1089 |
+
requested_patch_count = math.ceil(
|
| 1090 |
+
requested / self.core.patch_encoder.out_ds_rate
|
| 1091 |
+
)
|
| 1092 |
+
if not self._optimize_enabled:
|
| 1093 |
+
return requested_patch_count
|
| 1094 |
+
return self._resolve_generate_length_bucket(requested_patch_count).size
|
| 1095 |
+
|
| 1096 |
+
def _copy_patch_encoder_state(self, source: Any, target: Any) -> None:
|
| 1097 |
+
seq_len = source.seq_len
|
| 1098 |
+
target_capacity = int(target.layer_caches[0][0].size(2))
|
| 1099 |
+
if seq_len > target_capacity:
|
| 1100 |
+
raise ValueError(
|
| 1101 |
+
"Patch encoder state copy exceeds target capacity: "
|
| 1102 |
+
f"seq_len={seq_len} capacity={target_capacity}."
|
| 1103 |
+
)
|
| 1104 |
+
|
| 1105 |
+
target.conv_tail.copy_(source.conv_tail)
|
| 1106 |
+
target.seq_len = seq_len
|
| 1107 |
+
for (source_key, source_value), (target_key, target_value) in zip(
|
| 1108 |
+
source.layer_caches,
|
| 1109 |
+
target.layer_caches,
|
| 1110 |
+
strict=True,
|
| 1111 |
+
):
|
| 1112 |
+
if seq_len > 0:
|
| 1113 |
+
target_key[:, :, :seq_len, :].copy_(source_key[:, :, :seq_len, :])
|
| 1114 |
+
target_value[:, :, :seq_len, :].copy_(source_value[:, :, :seq_len, :])
|
| 1115 |
+
|
| 1116 |
+
def _ensure_patch_encoder_state_capacity(
|
| 1117 |
+
self,
|
| 1118 |
+
state: _GenerateState,
|
| 1119 |
+
*,
|
| 1120 |
+
required_seq_len: int,
|
| 1121 |
+
device: torch.device,
|
| 1122 |
+
dtype: torch.dtype,
|
| 1123 |
+
) -> None:
|
| 1124 |
+
current_state = state.patch_encoder_state
|
| 1125 |
+
if current_state is not None:
|
| 1126 |
+
current_capacity = int(current_state.layer_caches[0][0].size(2))
|
| 1127 |
+
if current_capacity >= required_seq_len:
|
| 1128 |
+
return
|
| 1129 |
+
|
| 1130 |
+
target_audio_patch_count = self._resolve_patch_encoder_audio_bucket(
|
| 1131 |
+
required_seq_len
|
| 1132 |
+
)
|
| 1133 |
+
next_state = self.core.patch_encoder.init_decode_state(
|
| 1134 |
+
max_audio_patch_count=target_audio_patch_count,
|
| 1135 |
+
batch_size=1,
|
| 1136 |
+
device=device,
|
| 1137 |
+
dtype=dtype,
|
| 1138 |
+
)
|
| 1139 |
+
if current_state is not None:
|
| 1140 |
+
self._copy_patch_encoder_state(current_state, next_state)
|
| 1141 |
+
state.patch_encoder_state = next_state
|
| 1142 |
+
|
| 1143 |
+
def _prefill_prompt_latents(
|
| 1144 |
+
self,
|
| 1145 |
+
prompt_latents: torch.Tensor | None,
|
| 1146 |
+
*,
|
| 1147 |
+
state: _GenerateState,
|
| 1148 |
+
) -> torch.Tensor | None:
|
| 1149 |
+
if prompt_latents is None:
|
| 1150 |
+
return None
|
| 1151 |
+
if prompt_latents.size(1) == 0:
|
| 1152 |
+
return prompt_latents.new_zeros(
|
| 1153 |
+
(prompt_latents.size(0), 0, self.core.llm_hidden_size)
|
| 1154 |
+
)
|
| 1155 |
+
self._ensure_patch_encoder_state_capacity(
|
| 1156 |
+
state,
|
| 1157 |
+
required_seq_len=(
|
| 1158 |
+
(prompt_latents.size(1) // self.core.patch_encoder.patch_size)
|
| 1159 |
+
* self.core.patch_encoder.out_ds_rate
|
| 1160 |
+
),
|
| 1161 |
+
device=prompt_latents.device,
|
| 1162 |
+
dtype=(
|
| 1163 |
+
state.fm_sequence.dtype
|
| 1164 |
+
if state.fm_sequence is not None
|
| 1165 |
+
else prompt_latents.dtype
|
| 1166 |
+
),
|
| 1167 |
+
)
|
| 1168 |
+
with measure_inference("patch_encoder"):
|
| 1169 |
+
prompt_patch_embeddings, state.patch_encoder_state = (
|
| 1170 |
+
self.core.patch_encoder.prefill(
|
| 1171 |
+
prompt_latents,
|
| 1172 |
+
state.patch_encoder_state,
|
| 1173 |
+
)
|
| 1174 |
+
)
|
| 1175 |
+
return prompt_patch_embeddings
|
| 1176 |
+
|
| 1177 |
+
def _get_fm_decode_workspace(
|
| 1178 |
+
self,
|
| 1179 |
+
*,
|
| 1180 |
+
total_len: int,
|
| 1181 |
+
device: torch.device,
|
| 1182 |
+
dtype: torch.dtype,
|
| 1183 |
+
) -> dict[str, torch.Tensor]:
|
| 1184 |
+
workspace_key = (total_len, str(device), dtype)
|
| 1185 |
+
workspace = self._fm_decode_workspaces.get(workspace_key)
|
| 1186 |
+
if workspace is None:
|
| 1187 |
+
workspace = {
|
| 1188 |
+
"input_sequence": torch.zeros(
|
| 1189 |
+
(1, total_len, self.core.fm_hidden_size),
|
| 1190 |
+
dtype=dtype,
|
| 1191 |
+
device=device,
|
| 1192 |
+
),
|
| 1193 |
+
"cfg_sequence": torch.zeros(
|
| 1194 |
+
(1, total_len, self.core.fm_hidden_size),
|
| 1195 |
+
dtype=dtype,
|
| 1196 |
+
device=device,
|
| 1197 |
+
),
|
| 1198 |
+
"attn_mask": torch.zeros(
|
| 1199 |
+
(1, total_len, total_len),
|
| 1200 |
+
dtype=torch.bool,
|
| 1201 |
+
device=device,
|
| 1202 |
+
),
|
| 1203 |
+
"pos_ids": torch.zeros(
|
| 1204 |
+
(1, total_len),
|
| 1205 |
+
dtype=torch.float32,
|
| 1206 |
+
device=device,
|
| 1207 |
+
),
|
| 1208 |
+
}
|
| 1209 |
+
self._fm_decode_workspaces[workspace_key] = workspace
|
| 1210 |
+
else:
|
| 1211 |
+
workspace["input_sequence"].zero_()
|
| 1212 |
+
workspace["cfg_sequence"].zero_()
|
| 1213 |
+
return workspace
|
| 1214 |
+
|
| 1215 |
+
def _resolve_fm_history_bucket_capacity(self, fm_seq_len: int) -> int:
|
| 1216 |
+
requested = int(fm_seq_len)
|
| 1217 |
+
if requested <= 0:
|
| 1218 |
+
raise ValueError("fm_seq_len must be positive.")
|
| 1219 |
+
if not self._optimize_enabled:
|
| 1220 |
+
return requested
|
| 1221 |
+
history_stride = self.core.hidden_patch_size + self.core.latent_patch_size
|
| 1222 |
+
requested_patch_count = math.ceil(requested / history_stride)
|
| 1223 |
+
return self._resolve_generate_length_bucket(
|
| 1224 |
+
requested_patch_count
|
| 1225 |
+
).size * history_stride
|
| 1226 |
+
|
| 1227 |
+
def _build_fm_attn_mask(
|
| 1228 |
+
self,
|
| 1229 |
+
*,
|
| 1230 |
+
state: _GenerateState,
|
| 1231 |
+
attn_mask: torch.Tensor,
|
| 1232 |
+
) -> torch.Tensor:
|
| 1233 |
+
if state.fm_seq_len <= 0:
|
| 1234 |
+
raise RuntimeError("FM sequence length must be positive before decode.")
|
| 1235 |
+
hidden_patch_size = self.core.hidden_patch_size
|
| 1236 |
+
latent_start = attn_mask.size(-1) - self.core.latent_patch_size
|
| 1237 |
+
attn_mask.zero_()
|
| 1238 |
+
block_start = state.fm_seq_len - hidden_patch_size
|
| 1239 |
+
if block_start > 0:
|
| 1240 |
+
causal_mask = torch.ones(
|
| 1241 |
+
(block_start, block_start),
|
| 1242 |
+
device=attn_mask.device,
|
| 1243 |
+
dtype=torch.bool,
|
| 1244 |
+
).triu(1).logical_not()
|
| 1245 |
+
attn_mask[:, :block_start, :block_start] = causal_mask
|
| 1246 |
+
|
| 1247 |
+
attn_mask[:, block_start : state.fm_seq_len, : state.fm_seq_len] = True
|
| 1248 |
+
attn_mask[:, block_start : state.fm_seq_len, latent_start:] = True
|
| 1249 |
+
attn_mask[:, latent_start:, : state.fm_seq_len] = True
|
| 1250 |
+
attn_mask[:, latent_start:, latent_start:] = True
|
| 1251 |
+
if latent_start > state.fm_seq_len:
|
| 1252 |
+
padding_indices = torch.arange(
|
| 1253 |
+
state.fm_seq_len,
|
| 1254 |
+
latent_start,
|
| 1255 |
+
device=attn_mask.device,
|
| 1256 |
+
)
|
| 1257 |
+
attn_mask[:, padding_indices, padding_indices] = True
|
| 1258 |
+
return attn_mask
|
| 1259 |
+
|
| 1260 |
+
def _build_fm_pos_ids(
|
| 1261 |
+
self,
|
| 1262 |
+
*,
|
| 1263 |
+
state: _GenerateState,
|
| 1264 |
+
pos_ids: torch.Tensor,
|
| 1265 |
+
) -> torch.Tensor:
|
| 1266 |
+
if state.fm_seq_len <= 0:
|
| 1267 |
+
raise RuntimeError("FM sequence length must be positive before decode.")
|
| 1268 |
+
pos_ids.zero_()
|
| 1269 |
+
latent_start = pos_ids.size(-1) - self.core.latent_patch_size
|
| 1270 |
+
pos_ids[:, : state.fm_seq_len] = torch.arange(
|
| 1271 |
+
state.fm_seq_len,
|
| 1272 |
+
device=pos_ids.device,
|
| 1273 |
+
dtype=pos_ids.dtype,
|
| 1274 |
+
)
|
| 1275 |
+
pos_ids[:, latent_start:] = torch.arange(
|
| 1276 |
+
state.fm_seq_len,
|
| 1277 |
+
state.fm_seq_len + self.core.latent_patch_size,
|
| 1278 |
+
device=pos_ids.device,
|
| 1279 |
+
dtype=pos_ids.dtype,
|
| 1280 |
+
)
|
| 1281 |
+
return pos_ids
|
| 1282 |
+
|
| 1283 |
+
def _prepare_fm_decode_inputs(
|
| 1284 |
+
self,
|
| 1285 |
+
state: _GenerateState,
|
| 1286 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
|
| 1287 |
+
sequence = state.fm_sequence
|
| 1288 |
+
cfg_sequence = state.fm_cfg_sequence
|
| 1289 |
+
if sequence is None or cfg_sequence is None:
|
| 1290 |
+
raise RuntimeError("FM static buffers are not initialized.")
|
| 1291 |
+
history_bucket_capacity = self._resolve_fm_history_bucket_capacity(
|
| 1292 |
+
state.fm_seq_len
|
| 1293 |
+
)
|
| 1294 |
+
total_len = history_bucket_capacity + self.core.latent_patch_size
|
| 1295 |
+
workspace = self._get_fm_decode_workspace(
|
| 1296 |
+
total_len=total_len,
|
| 1297 |
+
device=sequence.device,
|
| 1298 |
+
dtype=sequence.dtype,
|
| 1299 |
+
)
|
| 1300 |
+
workspace["input_sequence"][:, : state.fm_seq_len].copy_(
|
| 1301 |
+
sequence[:, : state.fm_seq_len]
|
| 1302 |
+
)
|
| 1303 |
+
workspace["cfg_sequence"][:, : state.fm_seq_len].copy_(
|
| 1304 |
+
cfg_sequence[:, : state.fm_seq_len]
|
| 1305 |
+
)
|
| 1306 |
+
return (
|
| 1307 |
+
workspace["input_sequence"],
|
| 1308 |
+
workspace["cfg_sequence"],
|
| 1309 |
+
workspace["attn_mask"],
|
| 1310 |
+
workspace["pos_ids"],
|
| 1311 |
+
history_bucket_capacity,
|
| 1312 |
+
)
|
| 1313 |
+
|
| 1314 |
+
def _append_to_fm_buffer(
|
| 1315 |
+
self,
|
| 1316 |
+
buffer: torch.Tensor | None,
|
| 1317 |
+
state: _GenerateState,
|
| 1318 |
+
chunk: torch.Tensor,
|
| 1319 |
+
) -> tuple[int, int]:
|
| 1320 |
+
if buffer is None:
|
| 1321 |
+
raise RuntimeError("FM static buffer is not initialized.")
|
| 1322 |
+
start = state.fm_seq_len
|
| 1323 |
+
end = start + chunk.size(1)
|
| 1324 |
+
if end > state.fm_capacity:
|
| 1325 |
+
raise RuntimeError(
|
| 1326 |
+
"FM StaticBuffer capacity exceeded: "
|
| 1327 |
+
f"next_length={end} capacity={state.fm_capacity}."
|
| 1328 |
+
)
|
| 1329 |
+
buffer[:, start:end].copy_(chunk.to(buffer.dtype))
|
| 1330 |
+
return start, end
|
| 1331 |
+
|
| 1332 |
+
def _append_hidden_chunk(
|
| 1333 |
+
self, state: _GenerateState, hidden_chunk: torch.Tensor
|
| 1334 |
+
) -> None:
|
| 1335 |
+
last_hidden = hidden_chunk[:, -self.core.hidden_patch_size :, :]
|
| 1336 |
+
projected = self.core.hidden_proj(last_hidden)
|
| 1337 |
+
null_projected = self.core.hidden_proj(torch.zeros_like(last_hidden))
|
| 1338 |
+
_start, end = self._append_to_fm_buffer(
|
| 1339 |
+
state.fm_sequence,
|
| 1340 |
+
state,
|
| 1341 |
+
projected,
|
| 1342 |
+
)
|
| 1343 |
+
cfg_buffer = state.fm_cfg_sequence
|
| 1344 |
+
if cfg_buffer is None:
|
| 1345 |
+
raise RuntimeError("FM cfg static buffer is not initialized.")
|
| 1346 |
+
cfg_buffer[:, state.fm_seq_len : end].copy_(null_projected.to(cfg_buffer.dtype))
|
| 1347 |
+
state.fm_seq_len = end
|
| 1348 |
+
|
| 1349 |
+
def _append_history_chunk(
|
| 1350 |
+
self, state: _GenerateState, latent_chunk: torch.Tensor
|
| 1351 |
+
) -> None:
|
| 1352 |
+
history_latent = self.core.latent_proj(latent_chunk)
|
| 1353 |
+
_start, end = self._append_to_fm_buffer(
|
| 1354 |
+
state.fm_sequence,
|
| 1355 |
+
state,
|
| 1356 |
+
history_latent,
|
| 1357 |
+
)
|
| 1358 |
+
cfg_buffer = state.fm_cfg_sequence
|
| 1359 |
+
if cfg_buffer is None:
|
| 1360 |
+
raise RuntimeError("FM cfg static buffer is not initialized.")
|
| 1361 |
+
cfg_buffer[:, state.fm_seq_len : end].copy_(history_latent.to(cfg_buffer.dtype))
|
| 1362 |
+
state.fm_seq_len = end
|
| 1363 |
+
|
| 1364 |
+
def _consume_text_schedule(
|
| 1365 |
+
self,
|
| 1366 |
+
generation_schedule: torch.Tensor,
|
| 1367 |
+
*,
|
| 1368 |
+
position: int,
|
| 1369 |
+
next_audio_position: int,
|
| 1370 |
+
state: _GenerateState,
|
| 1371 |
+
) -> int:
|
| 1372 |
+
with measure_inference("LLM"):
|
| 1373 |
+
text_chunk = generation_schedule[:, position:next_audio_position]
|
| 1374 |
+
_, state.llm_hiddens, _, state.llm_cache = self.core.step_llm(
|
| 1375 |
+
input_ids=text_chunk,
|
| 1376 |
+
past_key_values=state.llm_cache,
|
| 1377 |
+
)
|
| 1378 |
+
self._append_hidden_chunk(state, state.llm_hiddens)
|
| 1379 |
+
return next_audio_position
|
| 1380 |
+
|
| 1381 |
+
def _locate_prefill_boundary(
|
| 1382 |
+
self,
|
| 1383 |
+
*,
|
| 1384 |
+
span_positions: torch.Tensor,
|
| 1385 |
+
prompt_patch_count: int,
|
| 1386 |
+
) -> tuple[int, torch.Tensor]:
|
| 1387 |
+
if span_positions.numel() > prompt_patch_count:
|
| 1388 |
+
return int(span_positions[prompt_patch_count].item()), span_positions[
|
| 1389 |
+
:prompt_patch_count
|
| 1390 |
+
]
|
| 1391 |
+
raise RuntimeError(
|
| 1392 |
+
"Prefill boundary discovery failed despite prior schedule validation."
|
| 1393 |
+
)
|
| 1394 |
+
|
| 1395 |
+
@staticmethod
|
| 1396 |
+
def _find_audio_span_positions(
|
| 1397 |
+
generation_schedule: torch.Tensor,
|
| 1398 |
+
*,
|
| 1399 |
+
audio_placeholder_ids: set[int],
|
| 1400 |
+
) -> torch.Tensor:
|
| 1401 |
+
schedule = generation_schedule[0]
|
| 1402 |
+
placeholder_ids = torch.tensor(
|
| 1403 |
+
sorted(audio_placeholder_ids),
|
| 1404 |
+
device=schedule.device,
|
| 1405 |
+
dtype=schedule.dtype,
|
| 1406 |
+
)
|
| 1407 |
+
return torch.nonzero(
|
| 1408 |
+
torch.isin(schedule, placeholder_ids),
|
| 1409 |
+
as_tuple=False,
|
| 1410 |
+
).squeeze(-1)
|
| 1411 |
+
|
| 1412 |
+
@staticmethod
|
| 1413 |
+
def _next_token_is_audio_span(
|
| 1414 |
+
generation_schedule: torch.Tensor,
|
| 1415 |
+
*,
|
| 1416 |
+
position: int,
|
| 1417 |
+
audio_placeholder_ids: set[int],
|
| 1418 |
+
) -> bool:
|
| 1419 |
+
next_position = position + 1
|
| 1420 |
+
if next_position >= generation_schedule.size(1):
|
| 1421 |
+
return False
|
| 1422 |
+
return int(generation_schedule[0, next_position].item()) in audio_placeholder_ids
|
| 1423 |
+
|
| 1424 |
+
def _build_prefill_inputs_embeds(
|
| 1425 |
+
self,
|
| 1426 |
+
generation_schedule: torch.Tensor,
|
| 1427 |
+
*,
|
| 1428 |
+
prompt_patch_embeddings: torch.Tensor | None,
|
| 1429 |
+
prompt_span_positions: torch.Tensor,
|
| 1430 |
+
) -> torch.Tensor:
|
| 1431 |
+
inputs_embeds = self.core.llm.get_input_embeddings()(
|
| 1432 |
+
generation_schedule
|
| 1433 |
+
).clone()
|
| 1434 |
+
if prompt_span_positions.numel() > 0:
|
| 1435 |
+
if prompt_patch_embeddings is None:
|
| 1436 |
+
raise RuntimeError(
|
| 1437 |
+
"Prompt patch embeddings are required when prefill includes prompt audio spans."
|
| 1438 |
+
)
|
| 1439 |
+
patch_embeddings = prompt_patch_embeddings[
|
| 1440 |
+
:, : prompt_span_positions.numel()
|
| 1441 |
+
].to(inputs_embeds.dtype)
|
| 1442 |
+
if patch_embeddings.size(1) != prompt_span_positions.numel():
|
| 1443 |
+
raise RuntimeError(
|
| 1444 |
+
f"Prompt patch embeddings ({patch_embeddings.size(1)}) do not match prompt span count ({prompt_span_positions.numel()})."
|
| 1445 |
+
)
|
| 1446 |
+
inputs_embeds[:, prompt_span_positions, :] = patch_embeddings
|
| 1447 |
+
return inputs_embeds
|
| 1448 |
+
|
| 1449 |
+
def _prefill(
|
| 1450 |
+
self,
|
| 1451 |
+
generation_schedule: torch.Tensor,
|
| 1452 |
+
*,
|
| 1453 |
+
state: _GenerateState,
|
| 1454 |
+
span_positions: torch.Tensor,
|
| 1455 |
+
prompt_patches: torch.Tensor | None,
|
| 1456 |
+
prompt_patch_embeddings: torch.Tensor | None,
|
| 1457 |
+
audio_placeholder_ids: set[int],
|
| 1458 |
+
) -> int:
|
| 1459 |
+
prompt_patch_count = (
|
| 1460 |
+
0 if prompt_patches is None else int(prompt_patches.size(1))
|
| 1461 |
+
)
|
| 1462 |
+
prefill_end, prompt_span_positions = self._locate_prefill_boundary(
|
| 1463 |
+
span_positions=span_positions,
|
| 1464 |
+
prompt_patch_count=prompt_patch_count,
|
| 1465 |
+
)
|
| 1466 |
+
if prefill_end == 0:
|
| 1467 |
+
return 0
|
| 1468 |
+
inputs_embeds = self._build_prefill_inputs_embeds(
|
| 1469 |
+
generation_schedule[:, :prefill_end],
|
| 1470 |
+
prompt_patch_embeddings=prompt_patch_embeddings,
|
| 1471 |
+
prompt_span_positions=prompt_span_positions,
|
| 1472 |
+
)
|
| 1473 |
+
with measure_inference("LLM"):
|
| 1474 |
+
_, llm_hiddens, _, state.llm_cache = self.core.step_llm(
|
| 1475 |
+
inputs_embeds=inputs_embeds,
|
| 1476 |
+
past_key_values=state.llm_cache,
|
| 1477 |
+
)
|
| 1478 |
+
state.llm_hiddens = llm_hiddens[:, -1:, :]
|
| 1479 |
+
|
| 1480 |
+
cursor = 0
|
| 1481 |
+
for prompt_index, span_position in enumerate(prompt_span_positions.tolist()):
|
| 1482 |
+
if span_position > cursor:
|
| 1483 |
+
self._append_hidden_chunk(
|
| 1484 |
+
state, llm_hiddens[:, span_position - 1 : span_position, :]
|
| 1485 |
+
)
|
| 1486 |
+
self._append_history_chunk(state, prompt_patches[:, prompt_index])
|
| 1487 |
+
if self._next_token_is_audio_span(
|
| 1488 |
+
generation_schedule,
|
| 1489 |
+
position=span_position,
|
| 1490 |
+
audio_placeholder_ids=audio_placeholder_ids,
|
| 1491 |
+
):
|
| 1492 |
+
self._append_hidden_chunk(
|
| 1493 |
+
state, llm_hiddens[:, span_position : span_position + 1, :]
|
| 1494 |
+
)
|
| 1495 |
+
cursor = span_position + 1
|
| 1496 |
+
if prefill_end > cursor:
|
| 1497 |
+
self._append_hidden_chunk(
|
| 1498 |
+
state, llm_hiddens[:, prefill_end - 1 : prefill_end, :]
|
| 1499 |
+
)
|
| 1500 |
+
return prefill_end
|
| 1501 |
+
|
| 1502 |
+
def _decode_next_audio(
|
| 1503 |
+
self,
|
| 1504 |
+
state: _GenerateState,
|
| 1505 |
+
*,
|
| 1506 |
+
device: torch.device,
|
| 1507 |
+
g_cond: torch.Tensor | None,
|
| 1508 |
+
ode_method: str,
|
| 1509 |
+
num_steps: int,
|
| 1510 |
+
guidance_scale: float,
|
| 1511 |
+
) -> torch.Tensor:
|
| 1512 |
+
if state.fm_seq_len <= 0:
|
| 1513 |
+
raise RuntimeError(
|
| 1514 |
+
"Cannot decode audio before any conditioning state has been prefetched."
|
| 1515 |
+
)
|
| 1516 |
+
if state.fm_sequence is None or state.fm_cfg_sequence is None:
|
| 1517 |
+
raise RuntimeError("FM static buffers are not initialized.")
|
| 1518 |
+
if state.fm_null_g_cond is None:
|
| 1519 |
+
raise RuntimeError("FM null conditioning buffer is not initialized.")
|
| 1520 |
+
fm_sequence, fm_cfg_sequence, fm_attn_mask, fm_pos_ids, history_bucket_capacity = (
|
| 1521 |
+
self._prepare_fm_decode_inputs(state)
|
| 1522 |
+
)
|
| 1523 |
+
compile_signature = (
|
| 1524 |
+
(history_bucket_capacity, state.fm_sequence.dtype)
|
| 1525 |
+
if self._optimize_enabled
|
| 1526 |
+
else (state.fm_seq_len, state.fm_sequence.dtype)
|
| 1527 |
+
)
|
| 1528 |
+
if g_cond is None:
|
| 1529 |
+
g_cond = state.fm_null_g_cond
|
| 1530 |
+
else:
|
| 1531 |
+
g_cond = g_cond.to(
|
| 1532 |
+
device=state.fm_null_g_cond.device,
|
| 1533 |
+
dtype=state.fm_null_g_cond.dtype,
|
| 1534 |
+
)
|
| 1535 |
+
with measure_inference("FM"):
|
| 1536 |
+
attn_mask = self._build_fm_attn_mask(
|
| 1537 |
+
state=state,
|
| 1538 |
+
attn_mask=fm_attn_mask,
|
| 1539 |
+
)
|
| 1540 |
+
pos_ids = self._build_fm_pos_ids(
|
| 1541 |
+
state=state,
|
| 1542 |
+
pos_ids=fm_pos_ids,
|
| 1543 |
+
)
|
| 1544 |
+
if self.core.mode == "meanflow":
|
| 1545 |
+
fm_solver_step = self._get_compiled_method(
|
| 1546 |
+
"FM.meanflow.solver_step",
|
| 1547 |
+
self.core,
|
| 1548 |
+
"meanflow_solver_step",
|
| 1549 |
+
signature=compile_signature,
|
| 1550 |
+
)
|
| 1551 |
+
return self.core._meanflow_step_fm(
|
| 1552 |
+
input_sequence=fm_sequence,
|
| 1553 |
+
attn_mask=attn_mask,
|
| 1554 |
+
pos_ids=pos_ids,
|
| 1555 |
+
patch_size=self.core.latent_patch_size,
|
| 1556 |
+
g_cond=g_cond,
|
| 1557 |
+
nfe=num_steps,
|
| 1558 |
+
solver_step=fm_solver_step,
|
| 1559 |
+
)
|
| 1560 |
+
|
| 1561 |
+
fm_solver_step = self._get_compiled_method(
|
| 1562 |
+
"FM.flow_matching.solver_step",
|
| 1563 |
+
self.core,
|
| 1564 |
+
"fm_solver_step",
|
| 1565 |
+
signature=compile_signature,
|
| 1566 |
+
)
|
| 1567 |
+
return self.core._flow_matching_step_fm(
|
| 1568 |
+
input_sequence=fm_sequence,
|
| 1569 |
+
cfg_sequence=fm_cfg_sequence,
|
| 1570 |
+
attn_mask=attn_mask,
|
| 1571 |
+
pos_ids=pos_ids,
|
| 1572 |
+
hidden_size=self.core.hidden_patch_size,
|
| 1573 |
+
patch_size=self.core.latent_patch_size,
|
| 1574 |
+
g_cond=g_cond,
|
| 1575 |
+
ode_method=ode_method,
|
| 1576 |
+
num_steps=num_steps,
|
| 1577 |
+
guidance_scale=guidance_scale,
|
| 1578 |
+
solver_step=fm_solver_step,
|
| 1579 |
+
)
|
| 1580 |
+
|
| 1581 |
+
def _consume_audio_patch(
|
| 1582 |
+
self,
|
| 1583 |
+
state: _GenerateState,
|
| 1584 |
+
*,
|
| 1585 |
+
audio_patch: torch.Tensor,
|
| 1586 |
+
) -> None:
|
| 1587 |
+
audio_patch_for_llm = self.core.io_helper.denormalize(audio_patch)
|
| 1588 |
+
self._append_history_chunk(state, audio_patch)
|
| 1589 |
+
current_seq_len = (
|
| 1590 |
+
0
|
| 1591 |
+
if state.patch_encoder_state is None
|
| 1592 |
+
else state.patch_encoder_state.seq_len
|
| 1593 |
+
)
|
| 1594 |
+
self._ensure_patch_encoder_state_capacity(
|
| 1595 |
+
state,
|
| 1596 |
+
required_seq_len=current_seq_len + self.core.patch_encoder.out_ds_rate,
|
| 1597 |
+
device=audio_patch_for_llm.device,
|
| 1598 |
+
dtype=(
|
| 1599 |
+
state.fm_sequence.dtype
|
| 1600 |
+
if state.fm_sequence is not None
|
| 1601 |
+
else audio_patch_for_llm.dtype
|
| 1602 |
+
),
|
| 1603 |
+
)
|
| 1604 |
+
patch_encoder_decode = self._get_compiled_method(
|
| 1605 |
+
"patch_encoder.decode_patch",
|
| 1606 |
+
self.core.patch_encoder,
|
| 1607 |
+
"decode_patch",
|
| 1608 |
+
signature=self._patch_encoder_compile_signature(state.patch_encoder_state),
|
| 1609 |
+
)
|
| 1610 |
+
patch_positions = (
|
| 1611 |
+
torch.arange(
|
| 1612 |
+
self.core.patch_encoder.out_ds_rate,
|
| 1613 |
+
device=audio_patch_for_llm.device,
|
| 1614 |
+
dtype=torch.long,
|
| 1615 |
+
)
|
| 1616 |
+
+ state.patch_encoder_state.seq_len
|
| 1617 |
+
)
|
| 1618 |
+
with measure_inference("patch_encoder"):
|
| 1619 |
+
llm_embedding, conv_tail = patch_encoder_decode(
|
| 1620 |
+
audio_patch_for_llm,
|
| 1621 |
+
state.patch_encoder_state.conv_tail,
|
| 1622 |
+
state.patch_encoder_state.layer_caches,
|
| 1623 |
+
patch_positions,
|
| 1624 |
+
)
|
| 1625 |
+
state.patch_encoder_state.conv_tail.copy_(conv_tail)
|
| 1626 |
+
state.patch_encoder_state.seq_len += self.core.patch_encoder.out_ds_rate
|
| 1627 |
+
with measure_inference("LLM"):
|
| 1628 |
+
_, state.llm_hiddens, _, state.llm_cache = self.core.step_llm(
|
| 1629 |
+
inputs_embeds=llm_embedding,
|
| 1630 |
+
past_key_values=state.llm_cache,
|
| 1631 |
+
)
|
| 1632 |
+
|
| 1633 |
+
def _decode(
|
| 1634 |
+
self,
|
| 1635 |
+
generation_schedule: torch.Tensor,
|
| 1636 |
+
*,
|
| 1637 |
+
position: int,
|
| 1638 |
+
state: _GenerateState,
|
| 1639 |
+
audio_placeholder_ids: set[int],
|
| 1640 |
+
span_positions: torch.Tensor,
|
| 1641 |
+
device: torch.device,
|
| 1642 |
+
g_cond: torch.Tensor | None,
|
| 1643 |
+
ode_method: str,
|
| 1644 |
+
num_steps: int,
|
| 1645 |
+
guidance_scale: float,
|
| 1646 |
+
eos_threshold: float,
|
| 1647 |
+
) -> Iterator[torch.Tensor]:
|
| 1648 |
+
span_cursor = torch.searchsorted(
|
| 1649 |
+
span_positions,
|
| 1650 |
+
torch.tensor(
|
| 1651 |
+
position,
|
| 1652 |
+
device=span_positions.device,
|
| 1653 |
+
dtype=span_positions.dtype,
|
| 1654 |
+
),
|
| 1655 |
+
).item()
|
| 1656 |
+
while position < generation_schedule.size(1):
|
| 1657 |
+
token_id = int(generation_schedule[0, position].item())
|
| 1658 |
+
if token_id in audio_placeholder_ids:
|
| 1659 |
+
stop_after_current_audio = self._should_stop_after_current_audio(
|
| 1660 |
+
state,
|
| 1661 |
+
eos_threshold=eos_threshold,
|
| 1662 |
+
)
|
| 1663 |
+
audio_patch = self._decode_next_audio(
|
| 1664 |
+
state,
|
| 1665 |
+
device=device,
|
| 1666 |
+
g_cond=g_cond,
|
| 1667 |
+
ode_method=ode_method,
|
| 1668 |
+
num_steps=num_steps,
|
| 1669 |
+
guidance_scale=guidance_scale,
|
| 1670 |
+
)
|
| 1671 |
+
self._consume_audio_patch(
|
| 1672 |
+
state,
|
| 1673 |
+
audio_patch=audio_patch,
|
| 1674 |
+
)
|
| 1675 |
+
if self._next_token_is_audio_span(
|
| 1676 |
+
generation_schedule,
|
| 1677 |
+
position=position,
|
| 1678 |
+
audio_placeholder_ids=audio_placeholder_ids,
|
| 1679 |
+
):
|
| 1680 |
+
self._append_hidden_chunk(state, state.llm_hiddens)
|
| 1681 |
+
position += 1
|
| 1682 |
+
span_cursor += 1
|
| 1683 |
+
yield audio_patch
|
| 1684 |
+
if stop_after_current_audio:
|
| 1685 |
+
state.end_flag = True
|
| 1686 |
+
return
|
| 1687 |
+
continue
|
| 1688 |
+
next_audio_position = (
|
| 1689 |
+
int(span_positions[span_cursor].item())
|
| 1690 |
+
if span_cursor < span_positions.numel()
|
| 1691 |
+
else generation_schedule.size(1)
|
| 1692 |
+
)
|
| 1693 |
+
position = self._consume_text_schedule(
|
| 1694 |
+
generation_schedule,
|
| 1695 |
+
position=position,
|
| 1696 |
+
next_audio_position=next_audio_position,
|
| 1697 |
+
state=state,
|
| 1698 |
+
)
|
| 1699 |
+
|
| 1700 |
+
def _should_stop_after_current_audio(
|
| 1701 |
+
self, state: _GenerateState, *, eos_threshold: float
|
| 1702 |
+
) -> bool:
|
| 1703 |
+
if state.llm_hiddens is None:
|
| 1704 |
+
return False
|
| 1705 |
+
eos = (
|
| 1706 |
+
self.core.eos_proj(state.llm_hiddens).softmax(dim=-1)[:, -1, 1]
|
| 1707 |
+
> eos_threshold
|
| 1708 |
+
)
|
| 1709 |
+
return state.end_flag or bool(eos.item())
|
| 1710 |
+
|
| 1711 |
+
# endregion Prompt conditioning and decode state helpers
|
| 1712 |
+
|
| 1713 |
+
# region Public generation APIs
|
| 1714 |
+
@torch.no_grad()
|
| 1715 |
+
def _generate_latents_stream(
|
| 1716 |
+
self,
|
| 1717 |
+
data: dict[str, Any],
|
| 1718 |
+
*,
|
| 1719 |
+
precision: str,
|
| 1720 |
+
ode_method: str,
|
| 1721 |
+
num_steps: int,
|
| 1722 |
+
guidance_scale: float,
|
| 1723 |
+
speaker_scale: float = 1.5,
|
| 1724 |
+
eos_threshold: float = 0.8,
|
| 1725 |
+
) -> Iterator[torch.Tensor]:
|
| 1726 |
+
dtype = get_dtype(precision)
|
| 1727 |
+
device = next(self.core.parameters()).device
|
| 1728 |
+
use_amp = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}
|
| 1729 |
+
with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_amp):
|
| 1730 |
+
generation_schedule: torch.Tensor = data["generation_schedule"]
|
| 1731 |
+
if generation_schedule.size(0) != 1:
|
| 1732 |
+
raise ValueError(
|
| 1733 |
+
"DotsTtsModel.generate expects batch size 1 for generation_schedule."
|
| 1734 |
+
)
|
| 1735 |
+
|
| 1736 |
+
use_prompt_prefill = data.get("prompt_audio") is not None and bool(
|
| 1737 |
+
data.get("prompt_text")
|
| 1738 |
+
)
|
| 1739 |
+
prompt_conditioning = self._prepare_prompt_conditioning(
|
| 1740 |
+
data.get("prompt_audio"),
|
| 1741 |
+
use_prompt_prefill=use_prompt_prefill,
|
| 1742 |
+
speaker_scale=speaker_scale,
|
| 1743 |
+
)
|
| 1744 |
+
has_prompt_prefill = prompt_conditioning.prompt_patches is not None
|
| 1745 |
+
prompt_patch_count = (
|
| 1746 |
+
0
|
| 1747 |
+
if not has_prompt_prefill
|
| 1748 |
+
else int(prompt_conditioning.prompt_patches.size(1))
|
| 1749 |
+
)
|
| 1750 |
+
audio_placeholder_ids = set(self.core.audio_span_token_ids)
|
| 1751 |
+
span_positions = self._find_audio_span_positions(
|
| 1752 |
+
generation_schedule,
|
| 1753 |
+
audio_placeholder_ids=audio_placeholder_ids,
|
| 1754 |
+
)
|
| 1755 |
+
span_count = int(span_positions.numel())
|
| 1756 |
+
minimum_required_spans = prompt_patch_count + 1
|
| 1757 |
+
if span_count < minimum_required_spans:
|
| 1758 |
+
raise ValueError(
|
| 1759 |
+
f"generation_schedule provides {span_count} audio spans, but prompt prefill requires "
|
| 1760 |
+
f"{prompt_patch_count} spans and generation requires at least one additional decode span."
|
| 1761 |
+
)
|
| 1762 |
+
logger.info(
|
| 1763 |
+
"Latent generation prepared: schedule_audio_spans={} prompt_patch_count={} "
|
| 1764 |
+
"minimum_required_spans={}",
|
| 1765 |
+
span_count,
|
| 1766 |
+
prompt_patch_count,
|
| 1767 |
+
minimum_required_spans,
|
| 1768 |
+
)
|
| 1769 |
+
|
| 1770 |
+
state = self._allocate_generate_state(
|
| 1771 |
+
max_audio_patch_count=span_count,
|
| 1772 |
+
device=device,
|
| 1773 |
+
dtype=dtype,
|
| 1774 |
+
)
|
| 1775 |
+
prompt_patch_embeddings = self._prefill_prompt_latents(
|
| 1776 |
+
prompt_conditioning.prompt_latents,
|
| 1777 |
+
state=state,
|
| 1778 |
+
)
|
| 1779 |
+
position = self._prefill(
|
| 1780 |
+
generation_schedule,
|
| 1781 |
+
state=state,
|
| 1782 |
+
span_positions=span_positions,
|
| 1783 |
+
prompt_patches=prompt_conditioning.prompt_patches,
|
| 1784 |
+
prompt_patch_embeddings=prompt_patch_embeddings,
|
| 1785 |
+
audio_placeholder_ids=audio_placeholder_ids,
|
| 1786 |
+
)
|
| 1787 |
+
|
| 1788 |
+
payload_patch_count = 0
|
| 1789 |
+
should_drop_regenerated_prompt_patch = has_prompt_prefill
|
| 1790 |
+
for audio_patch in self._decode(
|
| 1791 |
+
generation_schedule,
|
| 1792 |
+
position=position,
|
| 1793 |
+
state=state,
|
| 1794 |
+
audio_placeholder_ids=audio_placeholder_ids,
|
| 1795 |
+
span_positions=span_positions,
|
| 1796 |
+
device=device,
|
| 1797 |
+
g_cond=prompt_conditioning.g_cond,
|
| 1798 |
+
ode_method=ode_method,
|
| 1799 |
+
num_steps=num_steps,
|
| 1800 |
+
guidance_scale=guidance_scale,
|
| 1801 |
+
eos_threshold=eos_threshold,
|
| 1802 |
+
):
|
| 1803 |
+
if should_drop_regenerated_prompt_patch:
|
| 1804 |
+
should_drop_regenerated_prompt_patch = False
|
| 1805 |
+
continue
|
| 1806 |
+
payload_patch_count += 1
|
| 1807 |
+
if payload_patch_count == 1 or payload_patch_count % 10 == 0:
|
| 1808 |
+
logger.info(
|
| 1809 |
+
"Latent generation progress: payload_audio_patches={}",
|
| 1810 |
+
payload_patch_count,
|
| 1811 |
+
)
|
| 1812 |
+
yield self.core.io_helper.denormalize(audio_patch)
|
| 1813 |
+
|
| 1814 |
+
if payload_patch_count == 0:
|
| 1815 |
+
if has_prompt_prefill:
|
| 1816 |
+
raise RuntimeError(
|
| 1817 |
+
"Generation produced no payload latents after discarding the regenerated prompt-tail patch. "
|
| 1818 |
+
"This usually means EOS triggered immediately after prompt continuation "
|
| 1819 |
+
"or the generation schedule did not provide an effective decode span."
|
| 1820 |
+
)
|
| 1821 |
+
raise RuntimeError(
|
| 1822 |
+
"Generation produced no decodable latents. "
|
| 1823 |
+
"This usually means EOS triggered before the first decode patch "
|
| 1824 |
+
"or the generation schedule did not provide an effective decode span."
|
| 1825 |
+
)
|
| 1826 |
+
logger.info(
|
| 1827 |
+
"Latent generation completed: payload_audio_patches={}",
|
| 1828 |
+
payload_patch_count,
|
| 1829 |
+
)
|
| 1830 |
+
|
| 1831 |
+
@torch.no_grad()
|
| 1832 |
+
def _decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
| 1833 |
+
with measure_inference("latent_decoder"):
|
| 1834 |
+
return self.vocoder.inference_from_latents(
|
| 1835 |
+
latents.transpose(1, 2).float(),
|
| 1836 |
+
do_sample=False,
|
| 1837 |
+
)
|
| 1838 |
+
|
| 1839 |
+
@torch.no_grad()
|
| 1840 |
+
def _init_vocoder_stream_state(self) -> Any:
|
| 1841 |
+
return self.vocoder.init_stream_state(
|
| 1842 |
+
batch_size=1,
|
| 1843 |
+
chunk_size=self.core.latent_patch_size,
|
| 1844 |
+
)
|
| 1845 |
+
|
| 1846 |
+
@torch.no_grad()
|
| 1847 |
+
def _stream_vocoder_patch(
|
| 1848 |
+
self,
|
| 1849 |
+
latent_patch: torch.Tensor,
|
| 1850 |
+
*,
|
| 1851 |
+
stream_state: Any,
|
| 1852 |
+
) -> torch.Tensor:
|
| 1853 |
+
latents = latent_patch.transpose(1, 2)
|
| 1854 |
+
if not self._optimize_enabled:
|
| 1855 |
+
with measure_inference("vocoder"):
|
| 1856 |
+
return self.vocoder.stream_step(latents, stream_state)
|
| 1857 |
+
|
| 1858 |
+
valid_frames = min(
|
| 1859 |
+
stream_state.decoder.total_frames,
|
| 1860 |
+
stream_state.decoder.window.size(-1),
|
| 1861 |
+
)
|
| 1862 |
+
valid_frames_tensor = stream_state.decoder.window.new_tensor(
|
| 1863 |
+
valid_frames,
|
| 1864 |
+
dtype=torch.int64,
|
| 1865 |
+
)
|
| 1866 |
+
vocoder_step = self._get_compiled_method(
|
| 1867 |
+
"vocoder.step",
|
| 1868 |
+
self.vocoder,
|
| 1869 |
+
"compiled_stream_step",
|
| 1870 |
+
)
|
| 1871 |
+
with measure_inference("vocoder"):
|
| 1872 |
+
audio_window, hidden_h, hidden_c, new_window = vocoder_step(
|
| 1873 |
+
latents,
|
| 1874 |
+
stream_state.lstm_hidden[0],
|
| 1875 |
+
stream_state.lstm_hidden[1],
|
| 1876 |
+
stream_state.decoder.window,
|
| 1877 |
+
valid_frames_tensor,
|
| 1878 |
+
)
|
| 1879 |
+
stream_state.lstm_hidden = (hidden_h.clone(), hidden_c.clone())
|
| 1880 |
+
stream_state.decoder.window = new_window.clone()
|
| 1881 |
+
stream_state.decoder.total_frames += int(latents.size(-1))
|
| 1882 |
+
audio_chunk = self.vocoder._slice_stream_audio_window(
|
| 1883 |
+
audio_window,
|
| 1884 |
+
stream_state,
|
| 1885 |
+
final=False,
|
| 1886 |
+
)
|
| 1887 |
+
return audio_chunk.clone()
|
| 1888 |
+
|
| 1889 |
+
@torch.no_grad()
|
| 1890 |
+
def _flush_vocoder_stream(self, stream_state: Any) -> torch.Tensor:
|
| 1891 |
+
with measure_inference("vocoder"):
|
| 1892 |
+
return self.vocoder.stream_flush(stream_state)
|
| 1893 |
+
|
| 1894 |
+
@torch.no_grad()
|
| 1895 |
+
def generate_audio_stream(
|
| 1896 |
+
self,
|
| 1897 |
+
data: dict[str, Any],
|
| 1898 |
+
*,
|
| 1899 |
+
precision: str,
|
| 1900 |
+
ode_method: str,
|
| 1901 |
+
num_steps: int,
|
| 1902 |
+
guidance_scale: float,
|
| 1903 |
+
speaker_scale: float = 1.5,
|
| 1904 |
+
eos_threshold: float = 0.8,
|
| 1905 |
+
) -> Iterator[torch.Tensor]:
|
| 1906 |
+
stream_state = self._init_vocoder_stream_state()
|
| 1907 |
+
for latent_patch in self._generate_latents_stream(
|
| 1908 |
+
data,
|
| 1909 |
+
precision=precision,
|
| 1910 |
+
ode_method=ode_method,
|
| 1911 |
+
num_steps=num_steps,
|
| 1912 |
+
guidance_scale=guidance_scale,
|
| 1913 |
+
speaker_scale=speaker_scale,
|
| 1914 |
+
eos_threshold=eos_threshold,
|
| 1915 |
+
):
|
| 1916 |
+
audio_chunk = self._stream_vocoder_patch(
|
| 1917 |
+
latent_patch,
|
| 1918 |
+
stream_state=stream_state,
|
| 1919 |
+
)
|
| 1920 |
+
if audio_chunk.size(-1) > 0:
|
| 1921 |
+
yield audio_chunk
|
| 1922 |
+
|
| 1923 |
+
final_chunk = self._flush_vocoder_stream(stream_state)
|
| 1924 |
+
if final_chunk.size(-1) > 0:
|
| 1925 |
+
yield final_chunk
|
| 1926 |
+
|
| 1927 |
+
@torch.no_grad()
|
| 1928 |
+
def generate_audio(
|
| 1929 |
+
self,
|
| 1930 |
+
data: dict[str, Any],
|
| 1931 |
+
*,
|
| 1932 |
+
precision: str,
|
| 1933 |
+
ode_method: str,
|
| 1934 |
+
num_steps: int,
|
| 1935 |
+
guidance_scale: float,
|
| 1936 |
+
speaker_scale: float = 1.5,
|
| 1937 |
+
) -> torch.Tensor:
|
| 1938 |
+
latent_patches = list(
|
| 1939 |
+
self._generate_latents_stream(
|
| 1940 |
+
data,
|
| 1941 |
+
precision=precision,
|
| 1942 |
+
ode_method=ode_method,
|
| 1943 |
+
num_steps=num_steps,
|
| 1944 |
+
guidance_scale=guidance_scale,
|
| 1945 |
+
speaker_scale=speaker_scale,
|
| 1946 |
+
)
|
| 1947 |
+
)
|
| 1948 |
+
logger.info(
|
| 1949 |
+
"Vocoder decode started: latent_patch_count={}",
|
| 1950 |
+
len(latent_patches),
|
| 1951 |
+
)
|
| 1952 |
+
audio = self._decode_latents(torch.cat(latent_patches, dim=1))
|
| 1953 |
+
logger.info(
|
| 1954 |
+
"Vocoder decode completed: waveform_samples={}",
|
| 1955 |
+
audio.shape[-1],
|
| 1956 |
+
)
|
| 1957 |
+
return audio
|
| 1958 |
+
# endregion Public generation APIs
|
src/dots_tts/modules/__init__.py
ADDED
|
File without changes
|
src/dots_tts/modules/backbone/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Backbone modules."""
|
src/dots_tts/modules/backbone/dit.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
from dots_tts.modules.backbone.layers import Mlp, MultiHeadAttention
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def modulate(x, shift, scale, **_kwargs):
|
| 10 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class TimestepEmbedder(nn.Module):
|
| 14 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
| 15 |
+
super().__init__()
|
| 16 |
+
self.mlp = nn.Sequential(
|
| 17 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
| 18 |
+
nn.SiLU(),
|
| 19 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
| 20 |
+
)
|
| 21 |
+
self.frequency_embedding_size = frequency_embedding_size
|
| 22 |
+
|
| 23 |
+
@staticmethod
|
| 24 |
+
def timestep_embedding(t, dim, max_period=10000):
|
| 25 |
+
half = dim // 2
|
| 26 |
+
freqs = torch.exp(
|
| 27 |
+
-math.log(max_period)
|
| 28 |
+
* torch.arange(start=0, end=half, dtype=torch.float32)
|
| 29 |
+
/ half
|
| 30 |
+
).to(device=t.device)
|
| 31 |
+
args = t[:, None].float() * freqs[None]
|
| 32 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 33 |
+
if dim % 2:
|
| 34 |
+
embedding = torch.cat(
|
| 35 |
+
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
| 36 |
+
)
|
| 37 |
+
return embedding
|
| 38 |
+
|
| 39 |
+
def forward(self, t):
|
| 40 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
| 41 |
+
return self.mlp(t_freq)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class FinalLayer(nn.Module):
|
| 45 |
+
def __init__(self, hidden_size, output_size):
|
| 46 |
+
super().__init__()
|
| 47 |
+
self.adaLN_modulation = nn.Sequential(
|
| 48 |
+
nn.SiLU(),
|
| 49 |
+
nn.Linear(hidden_size, 2 * hidden_size, bias=True),
|
| 50 |
+
)
|
| 51 |
+
self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-5)
|
| 52 |
+
self.linear = nn.Linear(hidden_size, output_size, bias=True)
|
| 53 |
+
|
| 54 |
+
def forward(self, x, c, **_kwargs):
|
| 55 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
| 56 |
+
x = modulate(self.norm(x), shift, scale)
|
| 57 |
+
return self.linear(x)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class DiTBlock(nn.Module):
|
| 61 |
+
def __init__(
|
| 62 |
+
self,
|
| 63 |
+
attention: nn.Module,
|
| 64 |
+
ffn: nn.Module,
|
| 65 |
+
hidden_size: int = 1024,
|
| 66 |
+
modulation: bool = False,
|
| 67 |
+
eps: float = 1e-5,
|
| 68 |
+
**_kwargs,
|
| 69 |
+
):
|
| 70 |
+
super().__init__()
|
| 71 |
+
self.norm1 = nn.LayerNorm(
|
| 72 |
+
hidden_size, elementwise_affine=not modulation, eps=eps
|
| 73 |
+
)
|
| 74 |
+
self.norm2 = nn.LayerNorm(
|
| 75 |
+
hidden_size, elementwise_affine=not modulation, eps=eps
|
| 76 |
+
)
|
| 77 |
+
self.attn = attention
|
| 78 |
+
self.ffn = ffn
|
| 79 |
+
self.modulation = modulation
|
| 80 |
+
if modulation:
|
| 81 |
+
self.adaLN_modulation = nn.Sequential(
|
| 82 |
+
nn.SiLU(),
|
| 83 |
+
nn.Linear(hidden_size, 6 * hidden_size, bias=True),
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
def forward(self, x, condition=None, mask=None, **kwargs):
|
| 87 |
+
if condition is None:
|
| 88 |
+
assert not self.modulation, (
|
| 89 |
+
"Without global condition, must set modulation to False"
|
| 90 |
+
)
|
| 91 |
+
else:
|
| 92 |
+
assert self.modulation, "With global condition, must set modulation to True"
|
| 93 |
+
shift_attn, scale_attn, gate_attn, shift_ffn, scale_ffn, gate_ffn = (
|
| 94 |
+
self.adaLN_modulation(condition).chunk(6, dim=1)
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
if condition is not None:
|
| 98 |
+
pack_indices = kwargs.get("pack_indices")
|
| 99 |
+
if pack_indices is not None:
|
| 100 |
+
gate_attn = gate_attn[pack_indices]
|
| 101 |
+
gate_ffn = gate_ffn[pack_indices]
|
| 102 |
+
else:
|
| 103 |
+
gate_attn = gate_attn.unsqueeze(1)
|
| 104 |
+
gate_ffn = gate_ffn.unsqueeze(1)
|
| 105 |
+
|
| 106 |
+
if condition is not None:
|
| 107 |
+
x = x + gate_attn * self.attn(
|
| 108 |
+
modulate(self.norm1(x), shift_attn, scale_attn, **kwargs),
|
| 109 |
+
mask=mask,
|
| 110 |
+
**kwargs,
|
| 111 |
+
)
|
| 112 |
+
else:
|
| 113 |
+
x = x + self.attn(self.norm1(x), mask=mask, **kwargs)
|
| 114 |
+
|
| 115 |
+
if condition is not None:
|
| 116 |
+
x = x + gate_ffn * self.ffn(
|
| 117 |
+
modulate(self.norm2(x), shift_ffn, scale_ffn, **kwargs)
|
| 118 |
+
)
|
| 119 |
+
else:
|
| 120 |
+
x = x + self.ffn(self.norm2(x), mask=mask)
|
| 121 |
+
return x
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class DiT(nn.Module):
|
| 125 |
+
def __init__(
|
| 126 |
+
self,
|
| 127 |
+
in_dim,
|
| 128 |
+
out_dim,
|
| 129 |
+
transformer_config,
|
| 130 |
+
*,
|
| 131 |
+
mode: str = "flow_matching",
|
| 132 |
+
):
|
| 133 |
+
super().__init__()
|
| 134 |
+
if mode not in {"flow_matching", "meanflow"}:
|
| 135 |
+
raise ValueError(
|
| 136 |
+
f"DiT mode must be 'flow_matching' or 'meanflow', got {mode!r}."
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
transformer_kwargs = transformer_config.to_dict()
|
| 140 |
+
model_dim = transformer_config.hidden_size
|
| 141 |
+
self.mode = mode
|
| 142 |
+
self.num_layers = transformer_config.num_layers
|
| 143 |
+
|
| 144 |
+
self.input_layer = nn.Linear(in_dim, model_dim)
|
| 145 |
+
self.time_embedder = TimestepEmbedder(model_dim)
|
| 146 |
+
if mode == "meanflow":
|
| 147 |
+
self.duration_embedder = TimestepEmbedder(model_dim)
|
| 148 |
+
|
| 149 |
+
self.blocks = nn.ModuleList()
|
| 150 |
+
for i in range(self.num_layers):
|
| 151 |
+
attn_block = MultiHeadAttention(**transformer_kwargs, name=f"layer_{i}")
|
| 152 |
+
ffn_block = Mlp(
|
| 153 |
+
act_layer=lambda: nn.GELU(approximate="tanh"), **transformer_kwargs
|
| 154 |
+
)
|
| 155 |
+
self.blocks.append(
|
| 156 |
+
DiTBlock(attention=attn_block, ffn=ffn_block, **transformer_kwargs)
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
self.output_layer = FinalLayer(model_dim, out_dim)
|
| 160 |
+
self.initialize_weights()
|
| 161 |
+
|
| 162 |
+
def initialize_weights(self):
|
| 163 |
+
def _basic_init(module):
|
| 164 |
+
if isinstance(module, nn.Linear):
|
| 165 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 166 |
+
if module.bias is not None:
|
| 167 |
+
nn.init.constant_(module.bias, 0)
|
| 168 |
+
|
| 169 |
+
self.apply(_basic_init)
|
| 170 |
+
|
| 171 |
+
nn.init.normal_(self.time_embedder.mlp[0].weight, std=0.02)
|
| 172 |
+
nn.init.normal_(self.time_embedder.mlp[2].weight, std=0.02)
|
| 173 |
+
|
| 174 |
+
for block in self.blocks:
|
| 175 |
+
if hasattr(block, "adaLN_modulation"):
|
| 176 |
+
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
| 177 |
+
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
| 178 |
+
|
| 179 |
+
nn.init.constant_(self.output_layer.adaLN_modulation[-1].weight, 0)
|
| 180 |
+
nn.init.constant_(self.output_layer.adaLN_modulation[-1].bias, 0)
|
| 181 |
+
nn.init.constant_(self.output_layer.linear.weight, 0)
|
| 182 |
+
nn.init.constant_(self.output_layer.linear.bias, 0)
|
| 183 |
+
|
| 184 |
+
def forward(
|
| 185 |
+
self,
|
| 186 |
+
x,
|
| 187 |
+
timesteps,
|
| 188 |
+
duration: torch.Tensor | None = None,
|
| 189 |
+
mask=None,
|
| 190 |
+
attn_mask=None,
|
| 191 |
+
g_cond: torch.Tensor | None = None,
|
| 192 |
+
**kwargs,
|
| 193 |
+
):
|
| 194 |
+
t = self.time_embedder(timesteps)
|
| 195 |
+
c = t
|
| 196 |
+
duration_embedder = getattr(self, "duration_embedder", None)
|
| 197 |
+
if duration_embedder is not None and duration is not None:
|
| 198 |
+
c = c + duration_embedder(duration)
|
| 199 |
+
if g_cond is not None:
|
| 200 |
+
c = c + g_cond
|
| 201 |
+
|
| 202 |
+
x = self.input_layer(x)
|
| 203 |
+
for block in self.blocks:
|
| 204 |
+
x = block(x, c, mask=attn_mask, **kwargs)
|
| 205 |
+
return self.output_layer(x, c, **kwargs)
|
src/dots_tts/modules/backbone/layers.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Dropout(nn.Module):
|
| 8 |
+
def __init__(
|
| 9 |
+
self, p: float = 0.5, inplace: bool = False, force_drop: bool = False, **_kwargs
|
| 10 |
+
):
|
| 11 |
+
super().__init__()
|
| 12 |
+
if p < 0.0 or p > 1.0:
|
| 13 |
+
raise ValueError(
|
| 14 |
+
f"dropout probability has to be between 0 and 1, but got {p}"
|
| 15 |
+
)
|
| 16 |
+
self.p = p
|
| 17 |
+
self.inplace = inplace
|
| 18 |
+
self.force_drop = force_drop
|
| 19 |
+
|
| 20 |
+
def forward(self, x, **_kwargs):
|
| 21 |
+
return F.dropout(
|
| 22 |
+
x,
|
| 23 |
+
p=self.p,
|
| 24 |
+
training=True if self.force_drop else self.training,
|
| 25 |
+
inplace=self.inplace,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class Conv1d(nn.Conv1d):
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
in_channels: int,
|
| 33 |
+
out_channels: int,
|
| 34 |
+
kernel_size: int = 1,
|
| 35 |
+
stride: int = 1,
|
| 36 |
+
dilation: int = 1,
|
| 37 |
+
groups: int = 1,
|
| 38 |
+
padding_mode: str = "zeros",
|
| 39 |
+
bias: bool = True,
|
| 40 |
+
padding=None,
|
| 41 |
+
causal: bool = False,
|
| 42 |
+
**_kwargs,
|
| 43 |
+
):
|
| 44 |
+
self.causal = causal
|
| 45 |
+
if padding is None:
|
| 46 |
+
if causal:
|
| 47 |
+
padding = 0
|
| 48 |
+
self.left_padding = dilation * (kernel_size - 1)
|
| 49 |
+
else:
|
| 50 |
+
padding = int((kernel_size * dilation - dilation) / 2)
|
| 51 |
+
|
| 52 |
+
super().__init__(
|
| 53 |
+
in_channels,
|
| 54 |
+
out_channels,
|
| 55 |
+
kernel_size,
|
| 56 |
+
stride=stride,
|
| 57 |
+
padding=padding,
|
| 58 |
+
dilation=dilation,
|
| 59 |
+
groups=groups,
|
| 60 |
+
padding_mode=padding_mode,
|
| 61 |
+
bias=bias,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
self.in_channels = in_channels
|
| 65 |
+
|
| 66 |
+
def forward(self, x):
|
| 67 |
+
if self.causal:
|
| 68 |
+
x = F.pad(x.unsqueeze(2), (self.left_padding, 0, 0, 0)).squeeze(2)
|
| 69 |
+
return super().forward(x)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class ConvTranspose1d(nn.ConvTranspose1d):
|
| 73 |
+
def __init__(
|
| 74 |
+
self,
|
| 75 |
+
in_channels: int,
|
| 76 |
+
out_channels: int,
|
| 77 |
+
kernel_size: int,
|
| 78 |
+
stride: int = 1,
|
| 79 |
+
output_padding: int = 0,
|
| 80 |
+
groups: int = 1,
|
| 81 |
+
bias: bool = True,
|
| 82 |
+
dilation: int = 1,
|
| 83 |
+
padding=None,
|
| 84 |
+
padding_mode: str = "zeros",
|
| 85 |
+
causal: bool = False,
|
| 86 |
+
**_kwargs,
|
| 87 |
+
):
|
| 88 |
+
if padding is None:
|
| 89 |
+
padding = 0 if causal else (kernel_size - stride) // 2
|
| 90 |
+
if causal:
|
| 91 |
+
assert padding == 0, "padding is not allowed in causal ConvTranspose1d."
|
| 92 |
+
assert kernel_size == 2 * stride, (
|
| 93 |
+
"kernel_size must be equal to 2*stride in Causal ConvTranspose1d."
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
super().__init__(
|
| 97 |
+
in_channels,
|
| 98 |
+
out_channels,
|
| 99 |
+
kernel_size,
|
| 100 |
+
stride=stride,
|
| 101 |
+
padding=padding,
|
| 102 |
+
output_padding=output_padding,
|
| 103 |
+
groups=groups,
|
| 104 |
+
bias=bias,
|
| 105 |
+
dilation=dilation,
|
| 106 |
+
padding_mode=padding_mode,
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
self.causal = causal
|
| 110 |
+
self.stride = stride
|
| 111 |
+
|
| 112 |
+
def forward(self, x):
|
| 113 |
+
x = super().forward(x)
|
| 114 |
+
if self.causal:
|
| 115 |
+
x = x[:, :, : -self.stride]
|
| 116 |
+
return x
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class Mlp(nn.Module):
|
| 120 |
+
def __init__(
|
| 121 |
+
self,
|
| 122 |
+
hidden_size,
|
| 123 |
+
ffn_hidden_size=4096,
|
| 124 |
+
act_layer=nn.GELU,
|
| 125 |
+
dropout=0.0,
|
| 126 |
+
**_kwargs,
|
| 127 |
+
):
|
| 128 |
+
super().__init__()
|
| 129 |
+
self.fc1 = nn.Linear(hidden_size, ffn_hidden_size)
|
| 130 |
+
self.act = act_layer()
|
| 131 |
+
self.fc2 = nn.Linear(ffn_hidden_size, hidden_size)
|
| 132 |
+
self.drop = Dropout(dropout)
|
| 133 |
+
|
| 134 |
+
def forward(self, x, _mask=None):
|
| 135 |
+
x = self.fc1(x)
|
| 136 |
+
x = self.act(x)
|
| 137 |
+
x = self.drop(x)
|
| 138 |
+
x = self.fc2(x)
|
| 139 |
+
return self.drop(x)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def rotate_half(x):
|
| 143 |
+
x1, x2 = x.chunk(2, dim=-1)
|
| 144 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
@torch.autocast(enabled=False, device_type="cuda")
|
| 148 |
+
def apply_rotary_pos_emb(pos, t):
|
| 149 |
+
if pos.dim() == 3:
|
| 150 |
+
pos = pos.unsqueeze(1)
|
| 151 |
+
return t * pos.cos() + rotate_half(t) * pos.sin()
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class RotaryEmbedding(nn.Module):
|
| 155 |
+
def __init__(self, dim, theta=50000):
|
| 156 |
+
super().__init__()
|
| 157 |
+
self.register_buffer(
|
| 158 |
+
"inv_freq",
|
| 159 |
+
1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)),
|
| 160 |
+
persistent=False,
|
| 161 |
+
)
|
| 162 |
+
self._theta = float(theta)
|
| 163 |
+
|
| 164 |
+
def _apply(self, fn):
|
| 165 |
+
inv_freq = self.inv_freq
|
| 166 |
+
super()._apply(fn)
|
| 167 |
+
self.inv_freq = inv_freq.to(device=self.inv_freq.device, dtype=torch.float32)
|
| 168 |
+
return self
|
| 169 |
+
|
| 170 |
+
@torch.autocast(enabled=False, device_type="cuda")
|
| 171 |
+
def forward(self, t):
|
| 172 |
+
inv_freq = self.inv_freq
|
| 173 |
+
if inv_freq.device != t.device:
|
| 174 |
+
raise RuntimeError(
|
| 175 |
+
"RotaryEmbedding buffer device mismatch: "
|
| 176 |
+
f"inv_freq={inv_freq.device} input={t.device}."
|
| 177 |
+
)
|
| 178 |
+
t = t.to(dtype=inv_freq.dtype)
|
| 179 |
+
if t.dim() == 1:
|
| 180 |
+
freqs = torch.einsum("i , j -> i j", t, inv_freq)
|
| 181 |
+
else:
|
| 182 |
+
freqs = torch.einsum("bi, j -> bij", t, inv_freq)
|
| 183 |
+
return torch.cat((freqs, freqs), dim=-1)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class MultiHeadAttention(nn.Module):
|
| 187 |
+
"""Multi-head attention"""
|
| 188 |
+
|
| 189 |
+
def __init__(
|
| 190 |
+
self,
|
| 191 |
+
hidden_size: int,
|
| 192 |
+
num_heads: int = 8,
|
| 193 |
+
qkv_bias: bool = False,
|
| 194 |
+
qk_norm: bool = False,
|
| 195 |
+
attn_drop: float = 0.0,
|
| 196 |
+
dropout: float = 0.0,
|
| 197 |
+
norm_layer: str = "LayerNorm",
|
| 198 |
+
rotary_bias: bool = False,
|
| 199 |
+
rotary_theta: float | None = 50000,
|
| 200 |
+
**_kwargs,
|
| 201 |
+
):
|
| 202 |
+
super().__init__()
|
| 203 |
+
assert hidden_size % num_heads == 0, (
|
| 204 |
+
"hidden_size should be divisible by num_heads"
|
| 205 |
+
)
|
| 206 |
+
self.num_heads = num_heads
|
| 207 |
+
self.head_dim = hidden_size // num_heads
|
| 208 |
+
self.scale = self.head_dim**-0.5
|
| 209 |
+
self.rotary_bias = rotary_bias
|
| 210 |
+
|
| 211 |
+
self.q_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
|
| 212 |
+
self.k_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
|
| 213 |
+
self.v_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
|
| 214 |
+
|
| 215 |
+
norm_layer = getattr(nn, norm_layer)
|
| 216 |
+
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 217 |
+
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 218 |
+
|
| 219 |
+
self.attn_drop = Dropout(attn_drop)
|
| 220 |
+
self.o_proj = nn.Linear(hidden_size, hidden_size)
|
| 221 |
+
self.o_dropout = Dropout(dropout)
|
| 222 |
+
|
| 223 |
+
if self.rotary_bias:
|
| 224 |
+
self.rotary = RotaryEmbedding(self.head_dim, theta=rotary_theta)
|
| 225 |
+
|
| 226 |
+
def forward(self, q, k=None, v=None, mask=None, pos_ids=None, **_kwargs):
|
| 227 |
+
k = k or q
|
| 228 |
+
v = v or q
|
| 229 |
+
B, L, _ = q.shape
|
| 230 |
+
_, S, _ = v.shape
|
| 231 |
+
if mask is not None:
|
| 232 |
+
if mask.ndim == 2: # [B, L]
|
| 233 |
+
assert L == S
|
| 234 |
+
mask = rearrange(mask, "b j -> b 1 1 j")
|
| 235 |
+
mask = mask.expand(-1, self.num_heads, L, -1)
|
| 236 |
+
elif mask.ndim == 3: # [B, L, S]
|
| 237 |
+
assert mask.size(1) == L and mask.size(2) == S
|
| 238 |
+
mask = mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
|
| 239 |
+
|
| 240 |
+
q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
|
| 241 |
+
q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
|
| 242 |
+
k = rearrange(k, "b n (h d) -> b h n d", h=self.num_heads)
|
| 243 |
+
v = rearrange(v, "b n (h d) -> b h n d", h=self.num_heads)
|
| 244 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
| 245 |
+
|
| 246 |
+
# Apply rotary
|
| 247 |
+
if self.rotary_bias:
|
| 248 |
+
if L == S:
|
| 249 |
+
if pos_ids is None:
|
| 250 |
+
rotary_emb = self.rotary(torch.arange(L, device=q.device))
|
| 251 |
+
else:
|
| 252 |
+
rotary_emb = self.rotary(pos_ids)
|
| 253 |
+
q, k = (apply_rotary_pos_emb(rotary_emb, tensor) for tensor in (q, k))
|
| 254 |
+
else:
|
| 255 |
+
q_rotary_emb = self.rotary(torch.arange(L, device=q.device))
|
| 256 |
+
k_rotary_emb = self.rotary(torch.arange(S, device=k.device))
|
| 257 |
+
q = apply_rotary_pos_emb(q_rotary_emb, q)
|
| 258 |
+
k = apply_rotary_pos_emb(k_rotary_emb, k)
|
| 259 |
+
|
| 260 |
+
attn_bias = torch.zeros(B, self.num_heads, L, S, dtype=q.dtype, device=q.device)
|
| 261 |
+
|
| 262 |
+
if mask is not None:
|
| 263 |
+
attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
|
| 264 |
+
|
| 265 |
+
out = F.scaled_dot_product_attention(
|
| 266 |
+
q,
|
| 267 |
+
k,
|
| 268 |
+
v,
|
| 269 |
+
attn_mask=attn_bias,
|
| 270 |
+
dropout_p=self.attn_drop.p if self.training else 0.0,
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
| 274 |
+
return self.o_dropout(self.o_proj(out))
|
| 275 |
+
|
| 276 |
+
def decode_step(self, x, *, cache, positions: torch.Tensor):
|
| 277 |
+
if x.size(1) <= 0:
|
| 278 |
+
raise ValueError("MultiHeadAttention.decode_step expects a non-empty input.")
|
| 279 |
+
if positions.ndim != 1 or positions.size(0) != x.size(1):
|
| 280 |
+
raise ValueError(
|
| 281 |
+
"MultiHeadAttention.decode_step positions must match the decode block length."
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
q = self.q_proj(x)
|
| 285 |
+
k = self.k_proj(x)
|
| 286 |
+
v = self.v_proj(x)
|
| 287 |
+
|
| 288 |
+
q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
|
| 289 |
+
k = rearrange(k, "b n (h d) -> b h n d", h=self.num_heads)
|
| 290 |
+
v = rearrange(v, "b n (h d) -> b h n d", h=self.num_heads)
|
| 291 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
| 292 |
+
block_len = q.size(2)
|
| 293 |
+
|
| 294 |
+
if self.rotary_bias:
|
| 295 |
+
rotary_emb = self.rotary(positions)
|
| 296 |
+
q = apply_rotary_pos_emb(rotary_emb, q)
|
| 297 |
+
k = apply_rotary_pos_emb(rotary_emb, k)
|
| 298 |
+
|
| 299 |
+
cached_k, cached_v = cache
|
| 300 |
+
cached_k.index_copy_(2, positions, k)
|
| 301 |
+
cached_v.index_copy_(2, positions, v)
|
| 302 |
+
|
| 303 |
+
cache_capacity = cached_k.size(2)
|
| 304 |
+
key_positions = torch.arange(
|
| 305 |
+
cache_capacity,
|
| 306 |
+
device=x.device,
|
| 307 |
+
dtype=torch.long,
|
| 308 |
+
).unsqueeze(0)
|
| 309 |
+
query_positions = positions.unsqueeze(1)
|
| 310 |
+
causal_mask = key_positions <= query_positions
|
| 311 |
+
valid_mask = key_positions <= positions[-1]
|
| 312 |
+
attn_bias = torch.zeros(
|
| 313 |
+
q.size(0),
|
| 314 |
+
self.num_heads,
|
| 315 |
+
block_len,
|
| 316 |
+
cache_capacity,
|
| 317 |
+
dtype=q.dtype,
|
| 318 |
+
device=q.device,
|
| 319 |
+
)
|
| 320 |
+
attn_bias.masked_fill_(
|
| 321 |
+
(causal_mask & valid_mask).unsqueeze(0).unsqueeze(0).logical_not(),
|
| 322 |
+
float("-inf"),
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
out = F.scaled_dot_product_attention(
|
| 326 |
+
q,
|
| 327 |
+
cached_k,
|
| 328 |
+
cached_v,
|
| 329 |
+
attn_mask=attn_bias,
|
| 330 |
+
dropout_p=self.attn_drop.p if self.training else 0.0,
|
| 331 |
+
)
|
| 332 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
| 333 |
+
return self.o_dropout(self.o_proj(out)), cache
|
src/dots_tts/modules/backbone/semantic_encoder.py
ADDED
|
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
|
| 10 |
+
from dots_tts.modules.backbone.layers import Conv1d, Mlp, MultiHeadAttention
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
|
| 14 |
+
class SemanticEncoderDecodeState:
|
| 15 |
+
conv_tail: torch.Tensor
|
| 16 |
+
layer_caches: tuple[tuple[torch.Tensor, torch.Tensor], ...]
|
| 17 |
+
seq_len: int
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class TransformerEncoderLayer(nn.Module):
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
hidden_size,
|
| 24 |
+
num_heads=16,
|
| 25 |
+
ffn_hidden_size=4096,
|
| 26 |
+
attn_dropout=0.0,
|
| 27 |
+
ffn_dropout=0.0,
|
| 28 |
+
norm_layer="LayerNorm",
|
| 29 |
+
**kwargs,
|
| 30 |
+
):
|
| 31 |
+
super().__init__()
|
| 32 |
+
self.attn = MultiHeadAttention(
|
| 33 |
+
hidden_size,
|
| 34 |
+
num_heads,
|
| 35 |
+
attn_drop=attn_dropout,
|
| 36 |
+
norm_layer=norm_layer,
|
| 37 |
+
**kwargs,
|
| 38 |
+
)
|
| 39 |
+
norm_cls = getattr(nn, norm_layer)
|
| 40 |
+
self.attn_norm = norm_cls(hidden_size)
|
| 41 |
+
self.ffn = Mlp(
|
| 42 |
+
hidden_size, ffn_hidden_size, dropout=ffn_dropout, act_layer=nn.SiLU
|
| 43 |
+
)
|
| 44 |
+
self.ffn_norm = norm_cls(hidden_size)
|
| 45 |
+
self.hidden_size = hidden_size
|
| 46 |
+
|
| 47 |
+
def _build_causal_mask(self, T: int, device):
|
| 48 |
+
return torch.tril(torch.ones(T, T, dtype=torch.bool, device=device))
|
| 49 |
+
|
| 50 |
+
def _build_padding_mask(self, x_lens, max_len: int, device):
|
| 51 |
+
B = x_lens.size(0)
|
| 52 |
+
positions = torch.arange(max_len, device=device).unsqueeze(0).expand(B, -1)
|
| 53 |
+
return positions < x_lens.unsqueeze(1)
|
| 54 |
+
|
| 55 |
+
def _fuse_attn_mask(self, causal_mask, padding_mask):
|
| 56 |
+
if causal_mask is None and padding_mask is None:
|
| 57 |
+
return None
|
| 58 |
+
if causal_mask is None:
|
| 59 |
+
row = padding_mask.unsqueeze(2)
|
| 60 |
+
col = padding_mask.unsqueeze(1)
|
| 61 |
+
return row & col
|
| 62 |
+
if padding_mask is None:
|
| 63 |
+
return causal_mask.unsqueeze(0)
|
| 64 |
+
|
| 65 |
+
_B, _T = padding_mask.shape
|
| 66 |
+
causal = causal_mask.unsqueeze(0)
|
| 67 |
+
row = padding_mask.unsqueeze(2)
|
| 68 |
+
col = padding_mask.unsqueeze(1)
|
| 69 |
+
pad_2d = row & col
|
| 70 |
+
return causal & pad_2d
|
| 71 |
+
|
| 72 |
+
def forward(
|
| 73 |
+
self,
|
| 74 |
+
x,
|
| 75 |
+
x_lens=None,
|
| 76 |
+
causal=True,
|
| 77 |
+
):
|
| 78 |
+
_B, T, C = x.shape
|
| 79 |
+
assert self.hidden_size == C
|
| 80 |
+
device = x.device
|
| 81 |
+
|
| 82 |
+
causal_mask = self._build_causal_mask(T, device) if causal else None
|
| 83 |
+
if x_lens is not None:
|
| 84 |
+
padding_mask = self._build_padding_mask(x_lens, T, device)
|
| 85 |
+
else:
|
| 86 |
+
padding_mask = None
|
| 87 |
+
fused_mask = self._fuse_attn_mask(causal_mask, padding_mask)
|
| 88 |
+
|
| 89 |
+
h = self.attn_norm(x)
|
| 90 |
+
h = self.attn(
|
| 91 |
+
q=h,
|
| 92 |
+
mask=fused_mask,
|
| 93 |
+
)
|
| 94 |
+
x = x + h
|
| 95 |
+
|
| 96 |
+
h = self.ffn_norm(x)
|
| 97 |
+
h = self.ffn(h)
|
| 98 |
+
return x + h
|
| 99 |
+
|
| 100 |
+
def decode_step(
|
| 101 |
+
self,
|
| 102 |
+
x,
|
| 103 |
+
*,
|
| 104 |
+
cache: tuple[torch.Tensor, torch.Tensor],
|
| 105 |
+
positions: torch.Tensor,
|
| 106 |
+
):
|
| 107 |
+
if x.size(1) <= 0:
|
| 108 |
+
raise ValueError(
|
| 109 |
+
"TransformerEncoderLayer.decode_step expects a non-empty input."
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
h = self.attn_norm(x)
|
| 113 |
+
h, cache = self.attn.decode_step(h, cache=cache, positions=positions)
|
| 114 |
+
x = x + h
|
| 115 |
+
|
| 116 |
+
h = self.ffn_norm(x)
|
| 117 |
+
h = self.ffn(h)
|
| 118 |
+
return x + h, cache
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class SuperviseEncoder(nn.Module):
|
| 122 |
+
def __init__(self, config):
|
| 123 |
+
super().__init__()
|
| 124 |
+
self.hidden_size = config.get("hidden_size", 1024)
|
| 125 |
+
self.layers = nn.ModuleList(
|
| 126 |
+
[
|
| 127 |
+
TransformerEncoderLayer(
|
| 128 |
+
hidden_size=self.hidden_size,
|
| 129 |
+
num_heads=config.get("num_heads", 16),
|
| 130 |
+
ffn_hidden_size=config.get("ffn_hidden_size", 4096),
|
| 131 |
+
norm_layer=config.get("norm_layer", "LayerNorm"),
|
| 132 |
+
)
|
| 133 |
+
for _ in range(config.get("num_layers", 6))
|
| 134 |
+
]
|
| 135 |
+
)
|
| 136 |
+
self.causal = config.get("causal", False)
|
| 137 |
+
|
| 138 |
+
def forward(self, x, x_lens=None):
|
| 139 |
+
batch_size, seq_len, _ = x.shape
|
| 140 |
+
if x_lens is None:
|
| 141 |
+
x_lens = torch.full(
|
| 142 |
+
(batch_size,), seq_len, device=x.device, dtype=torch.long
|
| 143 |
+
)
|
| 144 |
+
for layer in self.layers:
|
| 145 |
+
x = layer(x, x_lens=x_lens, causal=self.causal)
|
| 146 |
+
return x
|
| 147 |
+
|
| 148 |
+
def init_decode_state(
|
| 149 |
+
self,
|
| 150 |
+
*,
|
| 151 |
+
batch_size: int,
|
| 152 |
+
max_seq_len: int,
|
| 153 |
+
device: torch.device,
|
| 154 |
+
dtype: torch.dtype,
|
| 155 |
+
):
|
| 156 |
+
layer_caches = []
|
| 157 |
+
for layer in self.layers:
|
| 158 |
+
cache_shape = (
|
| 159 |
+
batch_size,
|
| 160 |
+
layer.attn.num_heads,
|
| 161 |
+
max_seq_len,
|
| 162 |
+
layer.attn.head_dim,
|
| 163 |
+
)
|
| 164 |
+
layer_caches.append(
|
| 165 |
+
(
|
| 166 |
+
torch.zeros(cache_shape, dtype=dtype, device=device),
|
| 167 |
+
torch.zeros(cache_shape, dtype=dtype, device=device),
|
| 168 |
+
)
|
| 169 |
+
)
|
| 170 |
+
return tuple(layer_caches)
|
| 171 |
+
|
| 172 |
+
def reset_decode_state(
|
| 173 |
+
self,
|
| 174 |
+
layer_caches: tuple[tuple[torch.Tensor, torch.Tensor], ...],
|
| 175 |
+
) -> None:
|
| 176 |
+
if len(layer_caches) != len(self.layers):
|
| 177 |
+
raise ValueError("Layer cache count does not match encoder depth.")
|
| 178 |
+
for key_cache, value_cache in layer_caches:
|
| 179 |
+
key_cache.zero_()
|
| 180 |
+
value_cache.zero_()
|
| 181 |
+
|
| 182 |
+
def decode_step(self, x, *, layer_caches, positions: torch.Tensor):
|
| 183 |
+
if len(layer_caches) != len(self.layers):
|
| 184 |
+
raise ValueError("Layer cache count does not match encoder depth.")
|
| 185 |
+
|
| 186 |
+
for layer, cache in zip(self.layers, layer_caches, strict=True):
|
| 187 |
+
x, _ = layer.decode_step(x, cache=cache, positions=positions)
|
| 188 |
+
return x
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class VAESemanticEncoder(nn.Module):
|
| 192 |
+
def __init__(self, in_dim, out_dim, config):
|
| 193 |
+
super().__init__()
|
| 194 |
+
in_ds_rate = 2
|
| 195 |
+
self.patch_size = int(config.patch_size)
|
| 196 |
+
self.in_ds_rate = in_ds_rate
|
| 197 |
+
self.ds_proj = Conv1d(
|
| 198 |
+
in_dim, in_dim, kernel_size=in_ds_rate, stride=in_ds_rate, causal=True
|
| 199 |
+
)
|
| 200 |
+
self.in_proj = nn.Linear(in_dim, config.PatchEncoder.hidden_size)
|
| 201 |
+
self.encoder = SuperviseEncoder(config.PatchEncoder)
|
| 202 |
+
self.out_ds_rate = self.patch_size // in_ds_rate
|
| 203 |
+
self.out_proj = nn.Linear(
|
| 204 |
+
config.PatchEncoder.hidden_size * self.out_ds_rate, out_dim
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
def forward(self, x, x_lens=None):
|
| 208 |
+
x = self._downsample(x)
|
| 209 |
+
x = self.in_proj(x)
|
| 210 |
+
z = self.encoder(x, x_lens=x_lens)
|
| 211 |
+
return self._project_embeddings(z)
|
| 212 |
+
|
| 213 |
+
def init_decode_state(
|
| 214 |
+
self,
|
| 215 |
+
*,
|
| 216 |
+
max_audio_patch_count: int,
|
| 217 |
+
batch_size: int,
|
| 218 |
+
device: torch.device,
|
| 219 |
+
dtype: torch.dtype,
|
| 220 |
+
) -> SemanticEncoderDecodeState:
|
| 221 |
+
return SemanticEncoderDecodeState(
|
| 222 |
+
conv_tail=torch.zeros(
|
| 223 |
+
(batch_size, self.ds_proj.in_channels, self.ds_proj.left_padding),
|
| 224 |
+
dtype=dtype,
|
| 225 |
+
device=device,
|
| 226 |
+
),
|
| 227 |
+
layer_caches=self.encoder.init_decode_state(
|
| 228 |
+
batch_size=batch_size,
|
| 229 |
+
max_seq_len=max_audio_patch_count * self.out_ds_rate,
|
| 230 |
+
device=device,
|
| 231 |
+
dtype=dtype,
|
| 232 |
+
),
|
| 233 |
+
seq_len=0,
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
def reset_decode_state(self, state: SemanticEncoderDecodeState) -> None:
|
| 237 |
+
state.conv_tail.zero_()
|
| 238 |
+
self.encoder.reset_decode_state(state.layer_caches)
|
| 239 |
+
state.seq_len = 0
|
| 240 |
+
|
| 241 |
+
def prefill(
|
| 242 |
+
self,
|
| 243 |
+
x,
|
| 244 |
+
state: SemanticEncoderDecodeState,
|
| 245 |
+
) -> tuple[torch.Tensor, SemanticEncoderDecodeState]:
|
| 246 |
+
if x.ndim != 3:
|
| 247 |
+
raise ValueError(
|
| 248 |
+
f"VAESemanticEncoder.prefill expects rank-3 input, got {tuple(x.shape)}."
|
| 249 |
+
)
|
| 250 |
+
if x.size(1) % self.patch_size != 0:
|
| 251 |
+
raise ValueError(
|
| 252 |
+
f"Prompt latent length {x.size(1)} must be divisible by patch_size={self.patch_size}."
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
if x.size(1) == 0:
|
| 256 |
+
return (
|
| 257 |
+
x.new_zeros((x.size(0), 0, self.out_proj.out_features)),
|
| 258 |
+
state,
|
| 259 |
+
)
|
| 260 |
+
if state.conv_tail.size(0) != x.size(0):
|
| 261 |
+
raise ValueError(
|
| 262 |
+
"VAESemanticEncoder.prefill batch size does not match decode state."
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
step_inputs = self.in_proj(self._downsample(x))
|
| 266 |
+
expected_token_count = (x.size(1) // self.patch_size) * self.out_ds_rate
|
| 267 |
+
if step_inputs.size(1) != expected_token_count:
|
| 268 |
+
raise RuntimeError(
|
| 269 |
+
"Patch encoder prefill produced an unexpected token count: "
|
| 270 |
+
f"expected={expected_token_count} actual={step_inputs.size(1)}."
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
current_seq_len = state.seq_len
|
| 274 |
+
next_seq_len = current_seq_len + step_inputs.size(1)
|
| 275 |
+
cache_capacity = state.layer_caches[0][0].size(2)
|
| 276 |
+
if next_seq_len > cache_capacity:
|
| 277 |
+
raise ValueError(
|
| 278 |
+
"Patch encoder prefill exceeds decode-state capacity: "
|
| 279 |
+
f"required={next_seq_len} capacity={cache_capacity}."
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
positions = (
|
| 283 |
+
torch.arange(step_inputs.size(1), device=x.device, dtype=torch.long)
|
| 284 |
+
+ current_seq_len
|
| 285 |
+
)
|
| 286 |
+
encoded = self.encoder.decode_step(
|
| 287 |
+
step_inputs,
|
| 288 |
+
layer_caches=state.layer_caches,
|
| 289 |
+
positions=positions,
|
| 290 |
+
)
|
| 291 |
+
embedding = self._project_embeddings(encoded)
|
| 292 |
+
raw = x.transpose(1, 2)
|
| 293 |
+
state.conv_tail.copy_(raw[..., -self.ds_proj.left_padding :])
|
| 294 |
+
state.seq_len = next_seq_len
|
| 295 |
+
return embedding, state
|
| 296 |
+
|
| 297 |
+
def decode_patch(
|
| 298 |
+
self,
|
| 299 |
+
latent_patch,
|
| 300 |
+
conv_tail: torch.Tensor,
|
| 301 |
+
layer_caches: tuple[tuple[torch.Tensor, torch.Tensor], ...],
|
| 302 |
+
positions: torch.Tensor,
|
| 303 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 304 |
+
if latent_patch.ndim != 3:
|
| 305 |
+
raise ValueError(
|
| 306 |
+
f"VAESemanticEncoder.decode_patch expects rank-3 input, got {tuple(latent_patch.shape)}."
|
| 307 |
+
)
|
| 308 |
+
if latent_patch.size(1) != self.patch_size:
|
| 309 |
+
raise ValueError(
|
| 310 |
+
f"decode_patch expects patch length {self.patch_size}, got {latent_patch.size(1)}."
|
| 311 |
+
)
|
| 312 |
+
if positions.ndim != 1 or positions.size(0) != self.out_ds_rate:
|
| 313 |
+
raise ValueError(
|
| 314 |
+
"decode_patch positions must be a rank-1 tensor matching out_ds_rate."
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
step_inputs, conv_tail = self._downsample_step(
|
| 318 |
+
latent_patch,
|
| 319 |
+
conv_tail=conv_tail,
|
| 320 |
+
)
|
| 321 |
+
if step_inputs.size(1) != self.out_ds_rate:
|
| 322 |
+
raise RuntimeError(
|
| 323 |
+
f"Downsample step produced {step_inputs.size(1)} tokens, expected {self.out_ds_rate}."
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
encoded = self.encoder.decode_step(
|
| 327 |
+
step_inputs,
|
| 328 |
+
layer_caches=layer_caches,
|
| 329 |
+
positions=positions,
|
| 330 |
+
)
|
| 331 |
+
embedding = self._project_embeddings(encoded)
|
| 332 |
+
return embedding, conv_tail
|
| 333 |
+
|
| 334 |
+
def _downsample(self, x):
|
| 335 |
+
return self.ds_proj(x.transpose(1, 2)).transpose(1, 2)
|
| 336 |
+
|
| 337 |
+
def _project_embeddings(self, z):
|
| 338 |
+
if self.out_ds_rate > 1:
|
| 339 |
+
z = rearrange(z, "b (s d) h -> b s (d h)", d=self.out_ds_rate)
|
| 340 |
+
return self.out_proj(z)
|
| 341 |
+
|
| 342 |
+
def _downsample_step(self, latent_patch, *, conv_tail):
|
| 343 |
+
raw = latent_patch.transpose(1, 2)
|
| 344 |
+
conv_input = torch.cat([conv_tail, raw], dim=-1)
|
| 345 |
+
|
| 346 |
+
projected = F.conv1d(
|
| 347 |
+
conv_input,
|
| 348 |
+
self.ds_proj.weight,
|
| 349 |
+
self.ds_proj.bias,
|
| 350 |
+
stride=self.ds_proj.stride[0],
|
| 351 |
+
padding=0,
|
| 352 |
+
dilation=self.ds_proj.dilation[0],
|
| 353 |
+
groups=self.ds_proj.groups,
|
| 354 |
+
).transpose(1, 2)
|
| 355 |
+
new_conv_tail = raw[..., -self.ds_proj.left_padding :]
|
| 356 |
+
return self.in_proj(projected), new_conv_tail
|
src/dots_tts/modules/speaker/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Speaker modules."""
|
src/dots_tts/modules/speaker/campplus.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
| 3 |
+
|
| 4 |
+
from collections import OrderedDict
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch import nn
|
| 9 |
+
|
| 10 |
+
from dots_tts.modules.speaker.campplus_layers import (
|
| 11 |
+
BasicResBlock,
|
| 12 |
+
CAMDenseTDNNBlock,
|
| 13 |
+
DenseLayer,
|
| 14 |
+
StatsPool,
|
| 15 |
+
TDNNLayer,
|
| 16 |
+
TransitLayer,
|
| 17 |
+
get_nonlinear,
|
| 18 |
+
)
|
| 19 |
+
from dots_tts.modules.speaker.fbank import _SPEAKER_FBANK_N_MELS
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class FCM(nn.Module):
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
block=BasicResBlock,
|
| 26 |
+
num_blocks=(2, 2),
|
| 27 |
+
m_channels=32,
|
| 28 |
+
feat_dim=_SPEAKER_FBANK_N_MELS,
|
| 29 |
+
):
|
| 30 |
+
super().__init__()
|
| 31 |
+
self.in_planes = m_channels
|
| 32 |
+
self.conv1 = nn.Conv2d(
|
| 33 |
+
1, m_channels, kernel_size=3, stride=1, padding=1, bias=False
|
| 34 |
+
)
|
| 35 |
+
self.bn1 = nn.BatchNorm2d(m_channels)
|
| 36 |
+
|
| 37 |
+
self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
|
| 38 |
+
self.layer2 = self._make_layer(block, m_channels, num_blocks[1], stride=2)
|
| 39 |
+
|
| 40 |
+
self.conv2 = nn.Conv2d(
|
| 41 |
+
m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False
|
| 42 |
+
)
|
| 43 |
+
self.bn2 = nn.BatchNorm2d(m_channels)
|
| 44 |
+
self.out_channels = m_channels * (feat_dim // 8)
|
| 45 |
+
|
| 46 |
+
def _make_layer(self, block, planes, num_blocks, stride):
|
| 47 |
+
strides = [stride] + [1] * (num_blocks - 1)
|
| 48 |
+
layers = []
|
| 49 |
+
for stride in strides:
|
| 50 |
+
layers.append(block(self.in_planes, planes, stride))
|
| 51 |
+
self.in_planes = planes * block.expansion
|
| 52 |
+
return nn.Sequential(*layers)
|
| 53 |
+
|
| 54 |
+
def forward(self, x):
|
| 55 |
+
x = x.unsqueeze(1)
|
| 56 |
+
out = F.relu(self.bn1(self.conv1(x)))
|
| 57 |
+
out = self.layer1(out)
|
| 58 |
+
out = self.layer2(out)
|
| 59 |
+
out = F.relu(self.bn2(self.conv2(out)))
|
| 60 |
+
|
| 61 |
+
shape = out.shape
|
| 62 |
+
return out.reshape(shape[0], shape[1] * shape[2], shape[3])
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class CAMPPlus(nn.Module):
|
| 66 |
+
_TDNN_KERNEL_SIZE = 5
|
| 67 |
+
_TDNN_STRIDE = 2
|
| 68 |
+
_TDNN_PADDING = 2
|
| 69 |
+
|
| 70 |
+
def __init__(
|
| 71 |
+
self,
|
| 72 |
+
feat_dim=_SPEAKER_FBANK_N_MELS,
|
| 73 |
+
embedding_size=512,
|
| 74 |
+
growth_rate=32,
|
| 75 |
+
bn_size=4,
|
| 76 |
+
init_channels=128,
|
| 77 |
+
config_str="batchnorm-relu",
|
| 78 |
+
memory_efficient=True,
|
| 79 |
+
):
|
| 80 |
+
super().__init__()
|
| 81 |
+
|
| 82 |
+
self.head = FCM(feat_dim=feat_dim)
|
| 83 |
+
channels = self.head.out_channels
|
| 84 |
+
|
| 85 |
+
self.xvector = nn.Sequential(
|
| 86 |
+
OrderedDict(
|
| 87 |
+
[
|
| 88 |
+
(
|
| 89 |
+
"tdnn",
|
| 90 |
+
TDNNLayer(
|
| 91 |
+
channels,
|
| 92 |
+
init_channels,
|
| 93 |
+
self._TDNN_KERNEL_SIZE,
|
| 94 |
+
stride=self._TDNN_STRIDE,
|
| 95 |
+
dilation=1,
|
| 96 |
+
padding=-1,
|
| 97 |
+
config_str=config_str,
|
| 98 |
+
),
|
| 99 |
+
),
|
| 100 |
+
]
|
| 101 |
+
)
|
| 102 |
+
)
|
| 103 |
+
channels = init_channels
|
| 104 |
+
for i, (num_layers, kernel_size, dilation) in enumerate(
|
| 105 |
+
zip((12, 24, 16), (3, 3, 3), (1, 2, 2), strict=True)
|
| 106 |
+
):
|
| 107 |
+
block = CAMDenseTDNNBlock(
|
| 108 |
+
num_layers=num_layers,
|
| 109 |
+
in_channels=channels,
|
| 110 |
+
out_channels=growth_rate,
|
| 111 |
+
bn_channels=bn_size * growth_rate,
|
| 112 |
+
kernel_size=kernel_size,
|
| 113 |
+
dilation=dilation,
|
| 114 |
+
config_str=config_str,
|
| 115 |
+
memory_efficient=memory_efficient,
|
| 116 |
+
)
|
| 117 |
+
self.xvector.add_module(f"block{i + 1}", block)
|
| 118 |
+
channels = channels + num_layers * growth_rate
|
| 119 |
+
self.xvector.add_module(
|
| 120 |
+
f"transit{i + 1}",
|
| 121 |
+
TransitLayer(
|
| 122 |
+
channels, channels // 2, bias=False, config_str=config_str
|
| 123 |
+
),
|
| 124 |
+
)
|
| 125 |
+
channels //= 2
|
| 126 |
+
|
| 127 |
+
self.xvector.add_module("out_nonlinear", get_nonlinear(config_str, channels))
|
| 128 |
+
|
| 129 |
+
self.xvector.add_module("stats", StatsPool())
|
| 130 |
+
self.xvector.add_module(
|
| 131 |
+
"dense", DenseLayer(channels * 2, embedding_size, config_str="batchnorm_")
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
for m in self.modules():
|
| 135 |
+
if isinstance(m, (nn.Conv1d, nn.Linear)):
|
| 136 |
+
nn.init.kaiming_normal_(m.weight.data)
|
| 137 |
+
if m.bias is not None:
|
| 138 |
+
nn.init.zeros_(m.bias)
|
| 139 |
+
|
| 140 |
+
@staticmethod
|
| 141 |
+
def _conv_output_lengths(lengths, kernel_size, stride=1, padding=0, dilation=1):
|
| 142 |
+
return (
|
| 143 |
+
torch.div(
|
| 144 |
+
lengths + 2 * padding - dilation * (kernel_size - 1) - 1,
|
| 145 |
+
stride,
|
| 146 |
+
rounding_mode="floor",
|
| 147 |
+
)
|
| 148 |
+
+ 1
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
@staticmethod
|
| 152 |
+
def _make_length_mask(lengths, max_len, device):
|
| 153 |
+
lengths = lengths.to(device=device, dtype=torch.long).clamp(min=0, max=max_len)
|
| 154 |
+
return torch.arange(max_len, device=device).unsqueeze(0) < lengths.unsqueeze(1)
|
| 155 |
+
|
| 156 |
+
def _masked_stats_pooling(self, x, lengths, unbiased=True, eps=1e-2):
|
| 157 |
+
lengths = lengths.to(device=x.device, dtype=torch.long).clamp(
|
| 158 |
+
min=1, max=x.size(-1)
|
| 159 |
+
)
|
| 160 |
+
mask = self._make_length_mask(lengths, x.size(-1), x.device).unsqueeze(1)
|
| 161 |
+
mask = mask.to(dtype=x.dtype)
|
| 162 |
+
|
| 163 |
+
denom = lengths.to(dtype=x.dtype).view(-1, 1).clamp_min(1.0)
|
| 164 |
+
mean = (x * mask).sum(dim=-1) / denom
|
| 165 |
+
|
| 166 |
+
centered = (x - mean.unsqueeze(-1)) * mask
|
| 167 |
+
var_denom = (
|
| 168 |
+
(lengths - 1).clamp_min(1).to(dtype=x.dtype).view(-1, 1)
|
| 169 |
+
if unbiased
|
| 170 |
+
else denom
|
| 171 |
+
)
|
| 172 |
+
var = centered.pow(2).sum(dim=-1) / var_denom
|
| 173 |
+
std = torch.sqrt(var.clamp_min(eps))
|
| 174 |
+
return torch.cat([mean, std], dim=1)
|
| 175 |
+
|
| 176 |
+
def forward(self, x, lengths=None):
|
| 177 |
+
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
|
| 178 |
+
x = self.head(x)
|
| 179 |
+
if lengths is not None:
|
| 180 |
+
lengths = lengths.to(device=x.device, dtype=torch.long).clamp(min=1)
|
| 181 |
+
|
| 182 |
+
for name, module in self.xvector.named_children():
|
| 183 |
+
if name == "stats":
|
| 184 |
+
x = (
|
| 185 |
+
self._masked_stats_pooling(x, lengths)
|
| 186 |
+
if lengths is not None
|
| 187 |
+
else module(x)
|
| 188 |
+
)
|
| 189 |
+
continue
|
| 190 |
+
|
| 191 |
+
x = module(x)
|
| 192 |
+
if name == "tdnn" and lengths is not None:
|
| 193 |
+
lengths = self._conv_output_lengths(
|
| 194 |
+
lengths,
|
| 195 |
+
kernel_size=self._TDNN_KERNEL_SIZE,
|
| 196 |
+
stride=self._TDNN_STRIDE,
|
| 197 |
+
padding=self._TDNN_PADDING,
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
return x
|
src/dots_tts/modules/speaker/campplus_layers.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import torch.utils.checkpoint as cp
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_nonlinear(config_str, channels):
|
| 11 |
+
nonlinear = nn.Sequential()
|
| 12 |
+
for name in config_str.split("-"):
|
| 13 |
+
if name == "relu":
|
| 14 |
+
nonlinear.add_module("relu", nn.ReLU(inplace=True))
|
| 15 |
+
elif name == "prelu":
|
| 16 |
+
nonlinear.add_module("prelu", nn.PReLU(channels))
|
| 17 |
+
elif name == "batchnorm":
|
| 18 |
+
nonlinear.add_module("batchnorm", nn.BatchNorm1d(channels))
|
| 19 |
+
elif name == "batchnorm_":
|
| 20 |
+
nonlinear.add_module("batchnorm", nn.BatchNorm1d(channels, affine=False))
|
| 21 |
+
else:
|
| 22 |
+
raise ValueError(f"Unexpected module ({name}).")
|
| 23 |
+
return nonlinear
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, _eps=1e-2):
|
| 27 |
+
mean = x.mean(dim=dim)
|
| 28 |
+
std = x.std(dim=dim, unbiased=unbiased)
|
| 29 |
+
stats = torch.cat([mean, std], dim=-1)
|
| 30 |
+
if keepdim:
|
| 31 |
+
stats = stats.unsqueeze(dim=dim)
|
| 32 |
+
return stats
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class StatsPool(nn.Module):
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
return statistics_pooling(x)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class TDNNLayer(nn.Module):
|
| 41 |
+
def __init__(
|
| 42 |
+
self,
|
| 43 |
+
in_channels,
|
| 44 |
+
out_channels,
|
| 45 |
+
kernel_size,
|
| 46 |
+
stride=1,
|
| 47 |
+
padding=0,
|
| 48 |
+
dilation=1,
|
| 49 |
+
bias=False,
|
| 50 |
+
config_str="batchnorm-relu",
|
| 51 |
+
):
|
| 52 |
+
super().__init__()
|
| 53 |
+
if padding < 0:
|
| 54 |
+
assert kernel_size % 2 == 1, (
|
| 55 |
+
f"Expect equal paddings, but got even kernel size ({kernel_size})"
|
| 56 |
+
)
|
| 57 |
+
padding = (kernel_size - 1) // 2 * dilation
|
| 58 |
+
self.linear = nn.Conv1d(
|
| 59 |
+
in_channels,
|
| 60 |
+
out_channels,
|
| 61 |
+
kernel_size,
|
| 62 |
+
stride=stride,
|
| 63 |
+
padding=padding,
|
| 64 |
+
dilation=dilation,
|
| 65 |
+
bias=bias,
|
| 66 |
+
)
|
| 67 |
+
self.nonlinear = get_nonlinear(config_str, out_channels)
|
| 68 |
+
|
| 69 |
+
def forward(self, x):
|
| 70 |
+
x = self.linear(x)
|
| 71 |
+
return self.nonlinear(x)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class CAMLayer(nn.Module):
|
| 75 |
+
def __init__(
|
| 76 |
+
self,
|
| 77 |
+
bn_channels,
|
| 78 |
+
out_channels,
|
| 79 |
+
kernel_size,
|
| 80 |
+
stride,
|
| 81 |
+
padding,
|
| 82 |
+
dilation,
|
| 83 |
+
bias,
|
| 84 |
+
reduction=2,
|
| 85 |
+
):
|
| 86 |
+
super().__init__()
|
| 87 |
+
self.linear_local = nn.Conv1d(
|
| 88 |
+
bn_channels,
|
| 89 |
+
out_channels,
|
| 90 |
+
kernel_size,
|
| 91 |
+
stride=stride,
|
| 92 |
+
padding=padding,
|
| 93 |
+
dilation=dilation,
|
| 94 |
+
bias=bias,
|
| 95 |
+
)
|
| 96 |
+
self.linear1 = nn.Conv1d(bn_channels, bn_channels // reduction, 1)
|
| 97 |
+
self.relu = nn.ReLU(inplace=True)
|
| 98 |
+
self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1)
|
| 99 |
+
self.sigmoid = nn.Sigmoid()
|
| 100 |
+
|
| 101 |
+
def forward(self, x):
|
| 102 |
+
y = self.linear_local(x)
|
| 103 |
+
context = x.mean(-1, keepdim=True) + self.seg_pooling(x)
|
| 104 |
+
context = self.relu(self.linear1(context))
|
| 105 |
+
m = self.sigmoid(self.linear2(context))
|
| 106 |
+
return y * m
|
| 107 |
+
|
| 108 |
+
def seg_pooling(self, x, seg_len=100, stype="avg"):
|
| 109 |
+
if stype == "avg":
|
| 110 |
+
seg = F.avg_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
|
| 111 |
+
elif stype == "max":
|
| 112 |
+
seg = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
|
| 113 |
+
else:
|
| 114 |
+
raise ValueError("Wrong segment pooling type.")
|
| 115 |
+
shape = seg.shape
|
| 116 |
+
seg = seg.unsqueeze(-1).expand(*shape, seg_len).reshape(*shape[:-1], -1)
|
| 117 |
+
return seg[..., : x.shape[-1]]
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class CAMDenseTDNNLayer(nn.Module):
|
| 121 |
+
def __init__(
|
| 122 |
+
self,
|
| 123 |
+
in_channels,
|
| 124 |
+
out_channels,
|
| 125 |
+
bn_channels,
|
| 126 |
+
kernel_size,
|
| 127 |
+
stride=1,
|
| 128 |
+
dilation=1,
|
| 129 |
+
bias=False,
|
| 130 |
+
config_str="batchnorm-relu",
|
| 131 |
+
memory_efficient=False,
|
| 132 |
+
):
|
| 133 |
+
super().__init__()
|
| 134 |
+
assert kernel_size % 2 == 1, (
|
| 135 |
+
f"Expect equal paddings, but got even kernel size ({kernel_size})"
|
| 136 |
+
)
|
| 137 |
+
padding = (kernel_size - 1) // 2 * dilation
|
| 138 |
+
self.memory_efficient = memory_efficient
|
| 139 |
+
self.nonlinear1 = get_nonlinear(config_str, in_channels)
|
| 140 |
+
self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
|
| 141 |
+
self.nonlinear2 = get_nonlinear(config_str, bn_channels)
|
| 142 |
+
self.cam_layer = CAMLayer(
|
| 143 |
+
bn_channels,
|
| 144 |
+
out_channels,
|
| 145 |
+
kernel_size,
|
| 146 |
+
stride=stride,
|
| 147 |
+
padding=padding,
|
| 148 |
+
dilation=dilation,
|
| 149 |
+
bias=bias,
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
def bn_function(self, x):
|
| 153 |
+
return self.linear1(self.nonlinear1(x))
|
| 154 |
+
|
| 155 |
+
def forward(self, x):
|
| 156 |
+
if self.training and self.memory_efficient:
|
| 157 |
+
x = cp.checkpoint(self.bn_function, x)
|
| 158 |
+
else:
|
| 159 |
+
x = self.bn_function(x)
|
| 160 |
+
return self.cam_layer(self.nonlinear2(x))
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class CAMDenseTDNNBlock(nn.ModuleList):
|
| 164 |
+
def __init__(
|
| 165 |
+
self,
|
| 166 |
+
num_layers,
|
| 167 |
+
in_channels,
|
| 168 |
+
out_channels,
|
| 169 |
+
bn_channels,
|
| 170 |
+
kernel_size,
|
| 171 |
+
stride=1,
|
| 172 |
+
dilation=1,
|
| 173 |
+
bias=False,
|
| 174 |
+
config_str="batchnorm-relu",
|
| 175 |
+
memory_efficient=False,
|
| 176 |
+
):
|
| 177 |
+
super().__init__()
|
| 178 |
+
for i in range(num_layers):
|
| 179 |
+
layer = CAMDenseTDNNLayer(
|
| 180 |
+
in_channels=in_channels + i * out_channels,
|
| 181 |
+
out_channels=out_channels,
|
| 182 |
+
bn_channels=bn_channels,
|
| 183 |
+
kernel_size=kernel_size,
|
| 184 |
+
stride=stride,
|
| 185 |
+
dilation=dilation,
|
| 186 |
+
bias=bias,
|
| 187 |
+
config_str=config_str,
|
| 188 |
+
memory_efficient=memory_efficient,
|
| 189 |
+
)
|
| 190 |
+
self.add_module(f"tdnnd{i + 1}", layer)
|
| 191 |
+
|
| 192 |
+
def forward(self, x):
|
| 193 |
+
for layer in self:
|
| 194 |
+
x = torch.cat([x, layer(x)], dim=1)
|
| 195 |
+
return x
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class TransitLayer(nn.Module):
|
| 199 |
+
def __init__(
|
| 200 |
+
self, in_channels, out_channels, bias=True, config_str="batchnorm-relu"
|
| 201 |
+
):
|
| 202 |
+
super().__init__()
|
| 203 |
+
self.nonlinear = get_nonlinear(config_str, in_channels)
|
| 204 |
+
self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
|
| 205 |
+
|
| 206 |
+
def forward(self, x):
|
| 207 |
+
x = self.nonlinear(x)
|
| 208 |
+
return self.linear(x)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class DenseLayer(nn.Module):
|
| 212 |
+
def __init__(
|
| 213 |
+
self, in_channels, out_channels, bias=False, config_str="batchnorm-relu"
|
| 214 |
+
):
|
| 215 |
+
super().__init__()
|
| 216 |
+
self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
|
| 217 |
+
self.nonlinear = get_nonlinear(config_str, out_channels)
|
| 218 |
+
|
| 219 |
+
def forward(self, x):
|
| 220 |
+
if len(x.shape) == 2:
|
| 221 |
+
x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
|
| 222 |
+
else:
|
| 223 |
+
x = self.linear(x)
|
| 224 |
+
return self.nonlinear(x)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class BasicResBlock(nn.Module):
|
| 228 |
+
expansion = 1
|
| 229 |
+
|
| 230 |
+
def __init__(self, in_planes, planes, stride=1):
|
| 231 |
+
super().__init__()
|
| 232 |
+
self.conv1 = nn.Conv2d(
|
| 233 |
+
in_planes, planes, kernel_size=3, stride=(stride, 1), padding=1, bias=False
|
| 234 |
+
)
|
| 235 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 236 |
+
self.conv2 = nn.Conv2d(
|
| 237 |
+
planes, planes, kernel_size=3, stride=1, padding=1, bias=False
|
| 238 |
+
)
|
| 239 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
| 240 |
+
|
| 241 |
+
self.shortcut = nn.Sequential()
|
| 242 |
+
if stride != 1 or in_planes != self.expansion * planes:
|
| 243 |
+
self.shortcut = nn.Sequential(
|
| 244 |
+
nn.Conv2d(
|
| 245 |
+
in_planes,
|
| 246 |
+
self.expansion * planes,
|
| 247 |
+
kernel_size=1,
|
| 248 |
+
stride=(stride, 1),
|
| 249 |
+
bias=False,
|
| 250 |
+
),
|
| 251 |
+
nn.BatchNorm2d(self.expansion * planes),
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
def forward(self, x):
|
| 255 |
+
out = F.relu(self.bn1(self.conv1(x)))
|
| 256 |
+
out = self.bn2(self.conv2(out))
|
| 257 |
+
out += self.shortcut(x)
|
| 258 |
+
return F.relu(out)
|
src/dots_tts/modules/speaker/encoder.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torchaudio
|
| 7 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 8 |
+
|
| 9 |
+
from dots_tts.modules.speaker.campplus import CAMPPlus
|
| 10 |
+
from dots_tts.modules.speaker.fbank import (
|
| 11 |
+
_SPEAKER_FBANK_N_MELS,
|
| 12 |
+
_SPEAKER_FBANK_SAMPLE_RATE,
|
| 13 |
+
extract_speaker_fbank,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class SpeakerXVectorFeatures(nn.Module):
|
| 18 |
+
"""
|
| 19 |
+
Speaker embedding extractor based on 3D-Speaker CAM++.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
sample_rate=_SPEAKER_FBANK_SAMPLE_RATE,
|
| 25 |
+
campplus_embedding_size=512,
|
| 26 |
+
max_audio_seconds=10.0,
|
| 27 |
+
):
|
| 28 |
+
super().__init__()
|
| 29 |
+
|
| 30 |
+
self.sample_rate = sample_rate
|
| 31 |
+
self.max_audio_seconds = float(max_audio_seconds)
|
| 32 |
+
self.model = CAMPPlus(
|
| 33 |
+
feat_dim=_SPEAKER_FBANK_N_MELS,
|
| 34 |
+
embedding_size=campplus_embedding_size,
|
| 35 |
+
)
|
| 36 |
+
self.resample = None
|
| 37 |
+
if self.sample_rate != _SPEAKER_FBANK_SAMPLE_RATE:
|
| 38 |
+
self.resample = torchaudio.transforms.Resample(
|
| 39 |
+
orig_freq=sample_rate,
|
| 40 |
+
new_freq=_SPEAKER_FBANK_SAMPLE_RATE,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
for param in self.model.parameters():
|
| 44 |
+
param.requires_grad = False
|
| 45 |
+
|
| 46 |
+
@staticmethod
|
| 47 |
+
def _normalize_lengths(lengths, batch_size, max_length, device, *, min_length):
|
| 48 |
+
if lengths is None:
|
| 49 |
+
return torch.full(
|
| 50 |
+
(batch_size,),
|
| 51 |
+
max_length,
|
| 52 |
+
device=device,
|
| 53 |
+
dtype=torch.long,
|
| 54 |
+
)
|
| 55 |
+
return lengths.to(device=device, dtype=torch.long).clamp(
|
| 56 |
+
min=min_length,
|
| 57 |
+
max=max_length,
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
def _crop_audio(self, audio, audio_lengths=None):
|
| 61 |
+
original_lengths = self._normalize_lengths(
|
| 62 |
+
audio_lengths,
|
| 63 |
+
audio.size(0),
|
| 64 |
+
audio.size(-1),
|
| 65 |
+
audio.device,
|
| 66 |
+
min_length=0,
|
| 67 |
+
)
|
| 68 |
+
if self.max_audio_seconds <= 0:
|
| 69 |
+
return audio, original_lengths, original_lengths, torch.zeros_like(
|
| 70 |
+
original_lengths
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
max_input_length = round(self.sample_rate * self.max_audio_seconds)
|
| 74 |
+
cropped_audio = []
|
| 75 |
+
cropped_lengths = []
|
| 76 |
+
starts = []
|
| 77 |
+
|
| 78 |
+
for index, total_length_tensor in enumerate(original_lengths):
|
| 79 |
+
total_length = int(total_length_tensor.item())
|
| 80 |
+
cropped_length = min(total_length, max_input_length)
|
| 81 |
+
start = (
|
| 82 |
+
random.randint(0, total_length - cropped_length)
|
| 83 |
+
if total_length > cropped_length
|
| 84 |
+
else 0
|
| 85 |
+
)
|
| 86 |
+
cropped_audio.append(audio[index, start : start + cropped_length])
|
| 87 |
+
cropped_lengths.append(cropped_length)
|
| 88 |
+
starts.append(start)
|
| 89 |
+
|
| 90 |
+
return pad_sequence(
|
| 91 |
+
cropped_audio,
|
| 92 |
+
batch_first=True,
|
| 93 |
+
padding_value=0.0,
|
| 94 |
+
), original_lengths, torch.tensor(
|
| 95 |
+
cropped_lengths,
|
| 96 |
+
device=audio.device,
|
| 97 |
+
dtype=torch.long,
|
| 98 |
+
), torch.tensor(starts, device=audio.device, dtype=torch.long)
|
| 99 |
+
|
| 100 |
+
def _crop_fbank(
|
| 101 |
+
self,
|
| 102 |
+
fbank,
|
| 103 |
+
fbank_lengths,
|
| 104 |
+
original_audio_lengths,
|
| 105 |
+
cropped_audio_lengths,
|
| 106 |
+
starts,
|
| 107 |
+
):
|
| 108 |
+
original_fbank_lengths = self._normalize_lengths(
|
| 109 |
+
fbank_lengths,
|
| 110 |
+
fbank.size(0),
|
| 111 |
+
fbank.size(1),
|
| 112 |
+
fbank.device,
|
| 113 |
+
min_length=1,
|
| 114 |
+
)
|
| 115 |
+
cropped_fbank = []
|
| 116 |
+
cropped_fbank_lengths = []
|
| 117 |
+
|
| 118 |
+
for index, total_feat_length_tensor in enumerate(original_fbank_lengths):
|
| 119 |
+
total_audio_length = int(original_audio_lengths[index].item())
|
| 120 |
+
total_feat_length = int(total_feat_length_tensor.item())
|
| 121 |
+
start_audio = int(starts[index].item())
|
| 122 |
+
end_audio = start_audio + int(cropped_audio_lengths[index].item())
|
| 123 |
+
|
| 124 |
+
if total_audio_length > 0:
|
| 125 |
+
start_feat = math.floor(
|
| 126 |
+
start_audio * total_feat_length / total_audio_length
|
| 127 |
+
)
|
| 128 |
+
end_feat = math.ceil(end_audio * total_feat_length / total_audio_length)
|
| 129 |
+
else:
|
| 130 |
+
start_feat = 0
|
| 131 |
+
end_feat = 1
|
| 132 |
+
|
| 133 |
+
start_feat = min(start_feat, total_feat_length - 1)
|
| 134 |
+
end_feat = min(max(end_feat, start_feat + 1), total_feat_length)
|
| 135 |
+
cropped_fbank.append(fbank[index, start_feat:end_feat])
|
| 136 |
+
cropped_fbank_lengths.append(end_feat - start_feat)
|
| 137 |
+
|
| 138 |
+
return pad_sequence(
|
| 139 |
+
cropped_fbank,
|
| 140 |
+
batch_first=True,
|
| 141 |
+
padding_value=0.0,
|
| 142 |
+
), torch.tensor(
|
| 143 |
+
cropped_fbank_lengths,
|
| 144 |
+
device=fbank.device,
|
| 145 |
+
dtype=torch.long,
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
def _extract_fbank_batch(self, audio, audio_lengths):
|
| 149 |
+
if self.resample is not None:
|
| 150 |
+
audio = self.resample(audio)
|
| 151 |
+
audio_lengths = torch.ceil(
|
| 152 |
+
audio_lengths.float()
|
| 153 |
+
* (_SPEAKER_FBANK_SAMPLE_RATE / self.sample_rate)
|
| 154 |
+
).long()
|
| 155 |
+
|
| 156 |
+
audio_cpu = audio.detach().cpu()
|
| 157 |
+
features = []
|
| 158 |
+
|
| 159 |
+
for index, valid_length_tensor in enumerate(audio_lengths):
|
| 160 |
+
valid_length = int(valid_length_tensor.item())
|
| 161 |
+
waveform = audio_cpu[index, :valid_length]
|
| 162 |
+
if waveform.numel() == 0:
|
| 163 |
+
waveform = audio_cpu.new_zeros(1)
|
| 164 |
+
features.append(
|
| 165 |
+
extract_speaker_fbank(
|
| 166 |
+
waveform,
|
| 167 |
+
sample_rate=_SPEAKER_FBANK_SAMPLE_RATE,
|
| 168 |
+
)
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
fbank_lengths = torch.tensor(
|
| 172 |
+
[feature.size(0) for feature in features],
|
| 173 |
+
device=audio.device,
|
| 174 |
+
dtype=torch.long,
|
| 175 |
+
)
|
| 176 |
+
fbank = pad_sequence(
|
| 177 |
+
features,
|
| 178 |
+
batch_first=True,
|
| 179 |
+
padding_value=0.0,
|
| 180 |
+
).to(device=audio.device, dtype=audio.dtype)
|
| 181 |
+
return fbank, fbank_lengths
|
| 182 |
+
|
| 183 |
+
@torch.no_grad()
|
| 184 |
+
@torch.autocast(enabled=False, device_type="cuda")
|
| 185 |
+
def forward(
|
| 186 |
+
self, audio, audio_lengths=None, fbank=None, fbank_lengths=None, **_kwargs
|
| 187 |
+
):
|
| 188 |
+
self.model.eval()
|
| 189 |
+
audio = audio.float()
|
| 190 |
+
if audio.dim() == 3:
|
| 191 |
+
if audio.size(1) != 1:
|
| 192 |
+
raise ValueError(
|
| 193 |
+
f"Speaker encoder expects mono audio, got shape {tuple(audio.shape)}."
|
| 194 |
+
)
|
| 195 |
+
audio = audio[:, 0]
|
| 196 |
+
elif audio.dim() != 2:
|
| 197 |
+
raise ValueError(
|
| 198 |
+
f"Speaker encoder expects a 2D or 3D audio tensor, got shape {tuple(audio.shape)}."
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
audio, original_audio_lengths, cropped_audio_lengths, starts = self._crop_audio(
|
| 202 |
+
audio,
|
| 203 |
+
audio_lengths=audio_lengths,
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
if fbank is None:
|
| 207 |
+
fbank, fbank_lengths = self._extract_fbank_batch(
|
| 208 |
+
audio,
|
| 209 |
+
cropped_audio_lengths,
|
| 210 |
+
)
|
| 211 |
+
else:
|
| 212 |
+
if not isinstance(fbank, torch.Tensor):
|
| 213 |
+
raise TypeError("Speaker encoder expects `fbank` to be a torch.Tensor.")
|
| 214 |
+
if fbank.dim() != 3 or fbank.size(0) != audio.size(0):
|
| 215 |
+
raise ValueError(
|
| 216 |
+
f"Speaker encoder expects `fbank` with shape (B, T, F) and matching batch size, got {tuple(fbank.shape)}."
|
| 217 |
+
)
|
| 218 |
+
fbank, fbank_lengths = self._crop_fbank(
|
| 219 |
+
fbank.to(device=audio.device, dtype=torch.float32),
|
| 220 |
+
fbank_lengths,
|
| 221 |
+
original_audio_lengths,
|
| 222 |
+
cropped_audio_lengths,
|
| 223 |
+
starts,
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
return self.model(fbank, lengths=fbank_lengths)
|
src/dots_tts/modules/speaker/fbank.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from dots_tts.utils.audio import extract_fbank, high_quality_resample
|
| 6 |
+
|
| 7 |
+
_SPEAKER_FBANK_SAMPLE_RATE = 16000
|
| 8 |
+
_SPEAKER_FBANK_N_MELS = 80
|
| 9 |
+
_SPEAKER_FBANK_MEAN_NORM = True
|
| 10 |
+
_SPEAKER_FBANK_DITHER = 0.0
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def extract_speaker_fbank(
|
| 14 |
+
waveform: torch.Tensor,
|
| 15 |
+
*,
|
| 16 |
+
sample_rate: int,
|
| 17 |
+
) -> torch.Tensor:
|
| 18 |
+
feature_input = waveform
|
| 19 |
+
if sample_rate != _SPEAKER_FBANK_SAMPLE_RATE:
|
| 20 |
+
feature_input = high_quality_resample(
|
| 21 |
+
waveform,
|
| 22 |
+
orig_sr=sample_rate,
|
| 23 |
+
target_sr=_SPEAKER_FBANK_SAMPLE_RATE,
|
| 24 |
+
)
|
| 25 |
+
return extract_fbank(
|
| 26 |
+
feature_input,
|
| 27 |
+
sample_rate=_SPEAKER_FBANK_SAMPLE_RATE,
|
| 28 |
+
n_mels=_SPEAKER_FBANK_N_MELS,
|
| 29 |
+
dither=_SPEAKER_FBANK_DITHER,
|
| 30 |
+
mean_norm=_SPEAKER_FBANK_MEAN_NORM,
|
| 31 |
+
)
|
src/dots_tts/modules/vocoder/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Vocoder modules."""
|