Text-to-Speech
MLX
Supertonic
supertonic-3-mlx
supertonic-3
apple-silicon
tts
speech-synthesis
multilingual
flow-matching
Instructions to use ambassadia/supertonic-3-mlx with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use ambassadia/supertonic-3-mlx with MLX:
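    # Download the model from the Hub
    # (the extras specifier is quoted so that zsh, the default macOS shell, does not treat the brackets as a glob)
    pip install "huggingface_hub[hf_xet]"
    huggingface-cli download --local-dir supertonic-3-mlx ambassadia/supertonic-3-mlx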
- Supertonic
How to use ambassadia/supertonic-3-mlx with Supertonic:
    from supertonic import TTS

    tts = TTS(auto_download=True)
    style = tts.get_voice_style(voice_name="M1")
    text = "The train delay was announced at 4:45 PM on Wed, Apr 3, 2024 due to track maintenance."
    wav, duration = tts.synthesize(text, voice_style=style)
    tts.save_audio(wav, "output.wav")
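The download step above can also be scripted from Python. The following is a minimal sketch, assuming the `huggingface_hub` package is installed; `download_model`, `REPO_ID`, and `LOCAL_DIR` are illustrative names that do not come from the repository, while `snapshot_download` is the library's function for fetching a whole repository.

```python
from pathlib import Path

REPO_ID = "ambassadia/supertonic-3-mlx"  # repository named on this page
LOCAL_DIR = "supertonic-3-mlx"           # same target directory as the CLI command


def download_model(repo_id: str = REPO_ID, local_dir: str = LOCAL_DIR) -> Path:
    """Fetch every file of the repository into local_dir and return that path."""
    # Imported lazily so this module loads even where huggingface_hub is absent.
    from huggingface_hub import snapshot_download

    return Path(snapshot_download(repo_id=repo_id, local_dir=local_dir))


if __name__ == "__main__":
    # Requires network access; prints the directory that now holds the files.
    print(download_model())
```

Calling `download_model()` with no arguments mirrors the `huggingface-cli download` command shown earlier.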
v0.1.0 — initial release (MLX-native Supertonic 3, ~x100 RTF on M4)
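The "~x100 RTF" figure in the commit title presumably means that synthesis runs about 100 times faster than real time. The sketch below shows the arithmetic under that reading. The timing values are made up for illustration and are not taken from `bench_results.csv`, and both function names are hypothetical.

```python
def realtime_speedup(audio_seconds: float, wall_seconds: float) -> float:
    """Seconds of audio produced per second of compute (higher is faster)."""
    if audio_seconds <= 0 or wall_seconds <= 0:
        raise ValueError("durations must be positive")
    return audio_seconds / wall_seconds


def rtf(audio_seconds: float, wall_seconds: float) -> float:
    """Real-time factor in the common 'compute time / audio time' convention."""
    return 1.0 / realtime_speedup(audio_seconds, wall_seconds)


if __name__ == "__main__":
    # Hypothetical measurement: 12.5 s of speech synthesized in 0.125 s.
    print(realtime_speedup(12.5, 0.125))  # 100.0
    print(rtf(12.5, 0.125))
```

Many papers report RTF in the second convention, where lower is better, so a "x100" speedup corresponds to an RTF of about 0.01.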
- .gitattributes +6 -0
- LICENSE +209 -0
- LICENSE-CODE +202 -0
- NOTICE +39 -0
- README.md +239 -0
- bench_results.csv +7 -0
- conversion_report.json +226 -0
- examples/quickstart.py +23 -0
- pyproject.toml +43 -0
- samples/de_M2.wav +3 -0
- samples/en_F1_short.wav +3 -0
- samples/en_M1_long.wav +3 -0
- samples/es_M3.wav +3 -0
- samples/fr_F2.wav +3 -0
- samples/ja_F3.wav +3 -0
- src/supertonic_3_mlx/__init__.py +51 -0
- src/supertonic_3_mlx/__pycache__/__init__.cpython-312.pyc +0 -0
- src/supertonic_3_mlx/__pycache__/_config.cpython-312.pyc +0 -0
- src/supertonic_3_mlx/__pycache__/_nn_wrappers.cpython-312.pyc +0 -0
- src/supertonic_3_mlx/__pycache__/duration_predictor.cpython-312.pyc +0 -0
- src/supertonic_3_mlx/__pycache__/pipeline.cpython-312.pyc +0 -0
- src/supertonic_3_mlx/__pycache__/text_encoder.cpython-312.pyc +0 -0
- src/supertonic_3_mlx/__pycache__/vector_estimator.cpython-312.pyc +0 -0
- src/supertonic_3_mlx/__pycache__/vocoder.cpython-312.pyc +0 -0
- src/supertonic_3_mlx/_config.py +58 -0
- src/supertonic_3_mlx/_nn_wrappers.py +50 -0
- src/supertonic_3_mlx/duration_predictor.py +347 -0
- src/supertonic_3_mlx/pipeline.py +545 -0
- src/supertonic_3_mlx/text_encoder.py +382 -0
- src/supertonic_3_mlx/vector_estimator.py +765 -0
- src/supertonic_3_mlx/vocoder.py +304 -0
- src/supertonic_3_mlx/weights.py +152 -0
- unicode_indexer.json +0 -0
- voice_styles/F1.json +0 -0
- voice_styles/F2.json +0 -0
- voice_styles/F3.json +0 -0
- voice_styles/F4.json +0 -0
- voice_styles/F5.json +0 -0
- voice_styles/M1.json +0 -0
- voice_styles/M2.json +0 -0
- voice_styles/M3.json +0 -0
- voice_styles/M4.json +0 -0
- voice_styles/M5.json +0 -0
- weights/duration_predictor.safetensors +3 -0
- weights/text_encoder.safetensors +3 -0
- weights/vector_estimator.safetensors +3 -0
- weights/vocoder.safetensors +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+samples/de_M2.wav filter=lfs diff=lfs merge=lfs -text
+samples/en_F1_short.wav filter=lfs diff=lfs merge=lfs -text
+samples/en_M1_long.wav filter=lfs diff=lfs merge=lfs -text
+samples/es_M3.wav filter=lfs diff=lfs merge=lfs -text
+samples/fr_F2.wav filter=lfs diff=lfs merge=lfs -text
+samples/ja_F3.wav filter=lfs diff=lfs merge=lfs -text
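The `filter=lfs` attributes mean that Git stores a small text pointer for each sample instead of the audio itself, which is likely why the WAV and safetensors entries in the file list show `+3 -0`. As a rough sketch, a pointer file can be recognized as below. The three-key layout (`version`, `oid`, `size`) follows the Git LFS pointer format; the function name and the sample values are illustrative only.

```python
from typing import Dict, Optional

LFS_SPEC = "https://git-lfs.github.com/spec/v1"


def parse_lfs_pointer(text: str) -> Optional[Dict[str, str]]:
    """Return the pointer fields, or None if text is not an LFS pointer."""
    fields: Dict[str, str] = {}
    for line in text.splitlines():
        key, sep, value = line.partition(" ")
        if not sep:  # every pointer line has the form "key value"
            return None
        fields[key] = value
    if fields.get("version") != LFS_SPEC:
        return None
    if "oid" not in fields or not fields.get("size", "").isdigit():
        return None
    return fields


if __name__ == "__main__":
    # Made-up pointer contents, not the real hash or size of any sample.
    sample = (
        "version https://git-lfs.github.com/spec/v1\n"
        "oid sha256:" + "0" * 64 + "\n"
        "size 123456\n"
    )
    print(parse_lfs_pointer(sample))
```

If a checkout contains pointers rather than audio, `git lfs pull` fetches the actual files.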
LICENSE
ADDED
@@ -0,0 +1,209 @@
+BigScience Open RAIL-M License
+dated August 18, 2022
+
+Section I: PREAMBLE
+
+This Open RAIL-M License was created by BigScience, a collaborative open innovation project aimed at
+the responsible development and use of large multilingual datasets and Large Language Models
+(“LLMs”). While a similar license was originally designed for the BLOOM model, we decided to adapt it
+and create this license in order to propose a general open and responsible license applicable to other
+machine learning based AI models (e.g. multimodal generative models).
+In short, this license strives for both the open and responsible downstream use of the accompanying
+model. When it comes to the open character, we took inspiration from open source permissive licenses
+regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based
+restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be
+able to enforce the license in case potential misuses of the Model may occur. Even though downstream
+derivative versions of the model could be released under different licensing terms, the latter will always
+have to include - at minimum - the same use-based restrictions as the ones in the original license (this
+license).
+The development and use of artificial intelligence (“AI”), does not come without concerns. The world has
+witnessed how AI techniques may, in some instances, become risky for the public in general. These risks
+come in many forms, from racial discrimination to the misuse of sensitive information.
+BigScience believes in the intersection between open and responsible AI development; thus, this License
+aims to strike a balance between both in order to enable responsible open-science in the field of AI.
+This License governs the use of the model (and its derivatives) and is informed by the model card
+associated with the model.
+
+NOW THEREFORE, You and Licensor agree as follows:
+
+1. Definitions
+(a) "License" means the terms and conditions for use, reproduction, and Distribution as defined in
+this document.
+(b) “Data” means a collection of information and/or content extracted from the dataset used with the
+Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under
+this License.
+(c)“Output” means the results of operating a Model as embodied in informational content resulting
+therefrom.
+(d)“Model” means any accompanying machine-learning based assemblies (including checkpoints),
+consisting of learnt weights, parameters (including optimizer states), corresponding to the model
+architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or
+in part on the Data, using the Complementary Material.
+(e) “Derivatives of the Model” means all modifications to the Model, works based on the Model, or any
+other model which is created or initialized by transfer of patterns of the weights, parameters,
+activations or output of the Model, to the other model, in order to cause the other model to perform
+similarly to the Model, including - but not limited to - distillation methods entailing the use of
+intermediate data representations or methods based on the generation of synthetic data by the Model
+for training the other model.
+(f)“Complementary Material” means the accompanying source code and scripts used to define,
+run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if
+any. This includes any accompanying documentation, tutorials, examples, etc, if any.
+(g) “Distribution” means any transmission, reproduction, publication or other sharing of the Model or
+Derivatives of the Model to a third party, including providing the Model as a hosted service made
+available by electronic or other remote means - e.g. API-based or web access.
+(h) “Licensor” means the copyright owner or entity authorized by the copyright owner that is
+granting the License, including the persons or entities that may have rights in the Model and/or
+distributing the Model.
+(i) "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this
+License and/or making use of the Model for whichever purpose and in any field of use, including
+usage of the Model in an end-use application - e.g. chatbot, translator, image generator.
+(j) “Third Parties” means individuals or legal entities that are not under common control with
+Licensor or You.
+(k) "Contribution" means any work of authorship, including the original version of the Model and
+any modifications or additions to that Model or Derivatives of the Model thereof, that is
+intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an
+individual or Legal Entity authorized to submit on behalf of the copyright owner. For the
+purposes of this definition,
+“submitted” means any form of electronic, verbal, or written
+communication sent to the Licensor or its representatives, including but not limited to
+communication on electronic mailing lists, source code control systems, and issue tracking
+systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and
+improving the Model, but excluding communication that is conspicuously marked or otherwise
+designated in writing by the copyright owner as "Not a Contribution."
+(l) "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a
+Contribution has been received by Licensor and subsequently incorporated within the Model.
+
+
+Section II: INTELLECTUAL PROPERTY RIGHTS
+
+Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary
+Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.
+
+2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor
+hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the
+Complementary Material, the Model, and Derivatives of the Model.
+
+3. Grant of Patent License. Subject to the terms and conditions of this License and where and as
+applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge,
+royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer
+to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such
+license applies only to those patent claims licensable by such Contributor that are necessarily infringed by
+their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such
+Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim
+or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution
+incorporated within the Model and/or Complementary Material constitutes direct or contributory patent
+infringement, then any patent licenses granted to You under this License for the Model and/or Work shall
+terminate as of the date such litigation is asserted or filed.
+Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
+
+4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g.
+software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof
+in any medium, with or without modifications, provided that You meet the following conditions:
+
+a. Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision
+by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the
+Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to,
+that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply
+to the use of Complementary Material.
+
+b. You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this
+License;
+
+c. You must cause any modified files to carry prominent notices stating that You changed the files;
+
+d. You must retain all copyright, patent, trademark, and attribution notices excluding those notices
+that do not pertain to any part of the Model, Derivatives of the Model.
+You may add Your own copyright statement to Your modifications and may provide additional or
+different license terms and conditions - respecting paragraph 4.a.
+- for use, reproduction, or Distribution
+of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use,
+reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
+
+5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions.
+Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You
+may use the Model subject to this License, including only for lawful purposes and in accordance with the
+License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or
+reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model
+to comply with the terms of this paragraph (paragraph 5).
+
+6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You
+generate using the Model. You are accountable for the Output you generate and its subsequent uses. No
+use of the output can contravene any provision as stated in the License.
+
+Section IV: OTHER PROVISIONS
+
+7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the
+right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model
+through electronic means, or modify the Output of the Model based on updates. You shall undertake
+reasonable efforts to use the latest version of the Model.
+
+8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks,
+trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the
+parties; and any rights not expressly granted herein are reserved by the Licensors.
+
+9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides
+the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS
+IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied,
+including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT,
+MERCHANTABILITY , or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for
+determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the
+Complementary Material and assume any risks associated with Your exercise of permissions under this
+License.
+
+10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence),
+contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or
+agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect,
+special, incidental, or consequential damages of any character arising as a result of this License or out of
+the use or inability to use the Model and the Complementary Material (including but not limited to
+damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other
+commercial damages or losses), even if such Contributor has been advised of the possibility of such
+damages.
+
+11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the
+Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance
+of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License.
+However, in accepting such obligations, You may act only on Your own behalf and on Your sole
+responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and
+hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor
+by reason of your accepting any such warranty or additional liability.
+
+12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining
+provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
+
+END OF TERMS AND CONDITIONS
+
+Attachment A
+
+Use Restrictions
+
+You agree not to use the Model or Derivatives of the Model:
+(a) In any way that violates any applicable national, federal, state, local or international law
+or regulation;
+(b) For the purpose of exploiting, harming or attempting to exploit or harm minors in any
+way;
+(c) To generate or disseminate verifiably false information and/or content with the purpose of
+harming others;
+(d) To generate or disseminate personal identifiable information that can be used to harm an
+individual;
+(e) To generate or disseminate information and/or content (e.g. images, code, posts, articles),
+and place the information and/or content in any context (e.g. bot generating tweets)
+without expressly and intelligibly disclaiming that the information and/or content is
+machine generated;
+(f) To defame, disparage or otherwise harass others;
+(g) To impersonate or attempt to impersonate (e.g. deepfakes) others without their consent;
+(h) For fully automated decision making that adversely impacts an individual’s legal rights or
+otherwise creates or modifies a binding, enforceable obligation;
+(i) For any use intended to or which has the effect of discriminating against or harming
+individuals or groups based on online or offline social behavior or known or predicted
+personal or personality characteristics;
+(j) To exploit any of the vulnerabilities of a specific group of persons based on their age,
+social, physical or mental characteristics, in order to materially distort the behavior of a
+person pertaining to that group in a manner that causes or is likely to cause that person or
+another person physical or psychological harm;
+(k) For any use intended to or which has the effect of discriminating against individuals or
+groups based on legally protected characteristics or categories;
+(l) To provide medical advice and medical results interpretation;
+(m) To generate or disseminate information for the purpose to be used for administration of
+justice, law enforcement, immigration or asylum processes, such as predicting an
+individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal
+relationships between assertions made in documents, indiscriminate and
+arbitrarily-targeted use).
LICENSE-CODE
ADDED
@@ -0,0 +1,202 @@
+
+Apache License
+Version 2.0, January 2004
+http://www.apache.org/licenses/
+
+TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+1. Definitions.
+
+"License" shall mean the terms and conditions for use, reproduction,
+and distribution as defined by Sections 1 through 9 of this document.
+
+"Licensor" shall mean the copyright owner or entity authorized by
+the copyright owner that is granting the License.
+
+"Legal Entity" shall mean the union of the acting entity and all
+other entities that control, are controlled by, or are under common
+control with that entity. For the purposes of this definition,
+"control" means (i) the power, direct or indirect, to cause the
+direction or management of such entity, whether by contract or
+otherwise, or (ii) ownership of fifty percent (50%) or more of the
+outstanding shares, or (iii) beneficial ownership of such entity.
+
+"You" (or "Your") shall mean an individual or Legal Entity
+exercising permissions granted by this License.
+
+"Source" form shall mean the preferred form for making modifications,
+including but not limited to software source code, documentation
+source, and configuration files.
+
+"Object" form shall mean any form resulting from mechanical
+transformation or translation of a Source form, including but
+not limited to compiled object code, generated documentation,
+and conversions to other media types.
+
+"Work" shall mean the work of authorship, whether in Source or
+Object form, made available under the License, as indicated by a
+copyright notice that is included in or attached to the work
+(an example is provided in the Appendix below).
+
+"Derivative Works" shall mean any work, whether in Source or Object
+form, that is based on (or derived from) the Work and for which the
+editorial revisions, annotations, elaborations, or other modifications
+represent, as a whole, an original work of authorship. For the purposes
+of this License, Derivative Works shall not include works that remain
+separable from, or merely link (or bind by name) to the interfaces of,
+the Work and Derivative Works thereof.
+
+"Contribution" shall mean any work of authorship, including
+the original version of the Work and any modifications or additions
+to that Work or Derivative Works thereof, that is intentionally
+submitted to Licensor for inclusion in the Work by the copyright owner
+or by an individual or Legal Entity authorized to submit on behalf of
+the copyright owner. For the purposes of this definition, "submitted"
+means any form of electronic, verbal, or written communication sent
+to the Licensor or its representatives, including but not limited to
+communication on electronic mailing lists, source code control systems,
+and issue tracking systems that are managed by, or on behalf of, the
+Licensor for the purpose of discussing and improving the Work, but
+excluding communication that is conspicuously marked or otherwise
+designated in writing by the copyright owner as "Not a Contribution."
+
+"Contributor" shall mean Licensor and any individual or Legal Entity
+on behalf of whom a Contribution has been received by Licensor and
+subsequently incorporated within the Work.
+
+2. Grant of Copyright License. Subject to the terms and conditions of
+this License, each Contributor hereby grants to You a perpetual,
+worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+copyright license to reproduce, prepare Derivative Works of,
+publicly display, publicly perform, sublicense, and distribute the
+Work and such Derivative Works in Source or Object form.
+
+3. Grant of Patent License. Subject to the terms and conditions of
+this License, each Contributor hereby grants to You a perpetual,
+worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+(except as stated in this section) patent license to make, have made,
+use, offer to sell, sell, import, and otherwise transfer the Work,
+where such license applies only to those patent claims licensable
+by such Contributor that are necessarily infringed by their
+Contribution(s) alone or by combination of their Contribution(s)
+with the Work to which such Contribution(s) was submitted. If You
+institute patent litigation against any entity (including a
+cross-claim or counterclaim in a lawsuit) alleging that the Work
+or a Contribution incorporated within the Work constitutes direct
+or contributory patent infringement, then any patent licenses
+granted to You under this License for that Work shall terminate
+as of the date such litigation is filed.
+
+4. Redistribution. You may reproduce and distribute copies of the
+Work or Derivative Works thereof in any medium, with or without
+modifications, and in Source or Object form, provided that You
+meet the following conditions:
+
+(a) You must give any other recipients of the Work or
+Derivative Works a copy of this License; and
+
+(b) You must cause any modified files to carry prominent notices
+stating that You changed the files; and
+
+(c) You must retain, in the Source form of any Derivative Works
+that You distribute, all copyright, patent, trademark, and
+attribution notices from the Source form of the Work,
+excluding those notices that do not pertain to any part of
+the Derivative Works; and
+
+(d) If the Work includes a "NOTICE" text file as part of its
+distribution, then any Derivative Works that You distribute must
+include a readable copy of the attribution notices contained
+within such NOTICE file, excluding those notices that do not
+pertain to any part of the Derivative Works, in at least one
+of the following places: within a NOTICE text file distributed
+as part of the Derivative Works; within the Source form or
+documentation, if provided along with the Derivative Works; or,
+within a display generated by the Derivative Works, if and
+wherever such third-party notices normally appear. The contents
+of the NOTICE file are for informational purposes only and
+do not modify the License. You may add Your own attribution
+notices within Derivative Works that You distribute, alongside
+or as an addendum to the NOTICE text from the Work, provided
+that such additional attribution notices cannot be construed
+as modifying the License.
+
+You may add Your own copyright statement to Your modifications and
+may provide additional or different license terms and conditions
+for use, reproduction, or distribution of Your modifications, or
+for any such Derivative Works as a whole, provided Your use,
+reproduction, and distribution of the Work otherwise complies with
+the conditions stated in this License.
+
+5. Submission of Contributions. Unless You explicitly state otherwise,
+any Contribution intentionally submitted for inclusion in the Work
+by You to the Licensor shall be under the terms and conditions of
+this License, without any additional terms or conditions.
+Notwithstanding the above, nothing herein shall supersede or modify
+the terms of any separate license agreement you may have executed
+with Licensor regarding such Contributions.
+
+6. Trademarks. This License does not grant permission to use the trade
+names, trademarks, service marks, or product names of the Licensor,
+except as required for reasonable and customary use in describing the
+origin of the Work and reproducing the content of the NOTICE file.
+
+7. Disclaimer of Warranty. Unless required by applicable law or
+agreed to in writing, Licensor provides the Work (and each
+Contributor provides its Contributions) on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+implied, including, without limitation, any warranties or conditions
+of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+PARTICULAR PURPOSE. You are solely responsible for determining the
+appropriateness of using or redistributing the Work and assume any
+risks associated with Your exercise of permissions under this License.
+
+8. Limitation of Liability. In no event and under no legal theory,
+whether in tort (including negligence), contract, or otherwise,
+unless required by applicable law (such as deliberate and grossly
+negligent acts) or agreed to in writing, shall any Contributor be
+liable to You for damages, including any direct, indirect, special,
+incidental, or consequential damages of any character arising as a
+result of this License or out of the use or inability to use the
+Work (including but not limited to damages for loss of goodwill,
+work stoppage, computer failure or malfunction, or any and all
+other commercial damages or losses), even if such Contributor
+has been advised of the possibility of such damages.
+
+9. Accepting Warranty or Additional Liability. While redistributing
+the Work or Derivative Works thereof, You may choose to offer,
|
| 168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 169 |
+
or other liability obligations and/or rights consistent with this
|
| 170 |
+
License. However, in accepting such obligations, You may act only
|
| 171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 172 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 173 |
+
defend, and hold each Contributor harmless for any liability
|
| 174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 175 |
+
of your accepting any such warranty or additional liability.
|
| 176 |
+
|
| 177 |
+
END OF TERMS AND CONDITIONS
|
| 178 |
+
|
| 179 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 180 |
+
|
| 181 |
+
To apply the Apache License to your work, attach the following
|
| 182 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 183 |
+
replaced with your own identifying information. (Don't include
|
| 184 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 185 |
+
comment syntax for the file format. We also recommend that a
|
| 186 |
+
file or class name and description of purpose be included on the
|
| 187 |
+
same "printed page" as the copyright notice for easier
|
| 188 |
+
identification within third-party archives.
|
| 189 |
+
|
| 190 |
+
Copyright [yyyy] [name of copyright owner]
|
| 191 |
+
|
| 192 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 193 |
+
you may not use this file except in compliance with the License.
|
| 194 |
+
You may obtain a copy of the License at
|
| 195 |
+
|
| 196 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 197 |
+
|
| 198 |
+
Unless required by applicable law or agreed to in writing, software
|
| 199 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 200 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 201 |
+
See the License for the specific language governing permissions and
|
| 202 |
+
limitations under the License.
|
NOTICE
ADDED
|
@@ -0,0 +1,39 @@
|
| 1 |
+
supertonic-3-mlx
|
| 2 |
+
================
|
| 3 |
+
|
| 4 |
+
This release is a derivative of the upstream Supertone Supertonic 3
|
| 5 |
+
text-to-speech model and consists of two artefact classes governed by
|
| 6 |
+
two different licenses:
|
| 7 |
+
|
| 8 |
+
1. The model weights (under ./weights/*.safetensors) are released under
|
| 9 |
+
the BigScience Open RAIL-M License. The full text is in ./LICENSE and
|
| 10 |
+
was copied verbatim from
|
| 11 |
+
https://huggingface.co/Supertone/supertonic-3/blob/main/LICENSE
|
| 12 |
+
The Attachment A use restrictions (Section 5 + Attachment A clauses
|
| 13 |
+
(a)–(m)) apply to all downstream use of the model and of any output
|
| 14 |
+
generated by the model.
|
| 15 |
+
|
| 16 |
+
2. The MLX port code (under ./src/supertonic_3_mlx/) is released under
|
| 17 |
+
the Apache License, Version 2.0. The full text is in ./LICENSE-CODE.
|
| 18 |
+
|
| 19 |
+
Attribution and modifications statement (BigScience Open RAIL-M Section 4.c):
|
| 20 |
+
|
| 21 |
+
Copyright (c) 2026 Supertone Inc. — original model weights and reference
|
| 22 |
+
Python/ONNX implementation. Distributed at
|
| 23 |
+
https://huggingface.co/Supertone/supertonic-3
|
| 24 |
+
Copyright (c) 2026 Olivier Dupont — MLX-native port code, weight format
|
| 25 |
+
conversion (ONNX → safetensors via the 3-stage extractor in
|
| 26 |
+
``src/supertonic_3_mlx/pipeline.py:_convert_onnx``), and pipeline
|
| 27 |
+
optimisations (``mx.compile`` of the CFG Euler loop, cross-attention
|
| 28 |
+
K/V cache shared across the 5 Euler steps). Distributed at
|
| 29 |
+
https://huggingface.co/ambassadia/supertonic-3-mlx
|
| 30 |
+
|
| 31 |
+
The MLX port does not modify the model's learned parameters in any
|
| 32 |
+
semantic sense — the only weight-level transformation is a tensor-shape
|
| 33 |
+
re-layout to match the MLX memory model (e.g. depthwise Conv1d
|
| 34 |
+
``(C, 1, K)`` → ``(C, K, 1)``). Audio output matches the upstream ONNX
|
| 35 |
+
Runtime reference up to FP32 accumulation noise (cosine ≥ 0.98 on the
|
| 36 |
+
full pipeline, cosine = 1.00 on the vocoder).
|
| 37 |
+
|
| 38 |
+
No use of the Supertone trademarks, logos, or trade dress is asserted or
|
| 39 |
+
permitted by this release (BigScience Open RAIL-M Section 8).
|
README.md
ADDED
|
@@ -0,0 +1,239 @@
|
| 1 |
+
---
|
| 2 |
+
license: openrail
|
| 3 |
+
license_link: LICENSE
|
| 4 |
+
language:
|
| 5 |
+
- en
|
| 6 |
+
- fr
|
| 7 |
+
- de
|
| 8 |
+
- es
|
| 9 |
+
- it
|
| 10 |
+
- pt
|
| 11 |
+
- ja
|
| 12 |
+
- ko
|
| 13 |
+
- zh
|
| 14 |
+
- ru
|
| 15 |
+
- pl
|
| 16 |
+
- nl
|
| 17 |
+
- tr
|
| 18 |
+
- ar
|
| 19 |
+
- hi
|
| 20 |
+
- vi
|
| 21 |
+
- th
|
| 22 |
+
- id
|
| 23 |
+
- cs
|
| 24 |
+
- ro
|
| 25 |
+
- hu
|
| 26 |
+
- el
|
| 27 |
+
- da
|
| 28 |
+
- sv
|
| 29 |
+
- fi
|
| 30 |
+
- no
|
| 31 |
+
- he
|
| 32 |
+
- uk
|
| 33 |
+
- bg
|
| 34 |
+
- hr
|
| 35 |
+
- sk
|
| 36 |
+
pipeline_tag: text-to-speech
|
| 37 |
+
tags:
|
| 38 |
+
- mlx
|
| 39 |
+
- apple-silicon
|
| 40 |
+
- tts
|
| 41 |
+
- text-to-speech
|
| 42 |
+
- speech-synthesis
|
| 43 |
+
- supertonic
|
| 44 |
+
- multilingual
|
| 45 |
+
- flow-matching
|
| 46 |
+
library_name: supertonic-3-mlx
|
| 47 |
+
base_model: Supertone/supertonic-3
|
| 48 |
+
inference: false
|
| 49 |
+
---
|
| 50 |
+
|
| 51 |
+
# Supertonic 3 — MLX-native
|
| 52 |
+
|
| 53 |
+
**31-language text-to-speech, up to ~x100 realtime on Apple Silicon (x46 to x102 on M4 in the benchmarks below).**
|
| 54 |
+
This is a native MLX port of [Supertone/supertonic-3](https://huggingface.co/Supertone/supertonic-3). It
|
| 55 |
+
runs the full flow-matching + classifier-free-guidance pipeline (DurationPredictor →
|
| 56 |
+
TextEncoder → 24-block VectorEstimator (5 Euler steps) → 10-block Vocos vocoder)
|
| 57 |
+
without ONNX, CoreML or any C++ runtime — only MLX + NumPy.
|
| 58 |
+
|
| 59 |
+
## Install
|
| 60 |
+
|
| 61 |
+
```bash
|
| 62 |
+
pip install supertonic-3-mlx
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
The package depends only on `mlx` and `numpy`. The optional `[hub]` extra adds
|
| 66 |
+
`huggingface_hub` for one-line weight downloads.
|
| 67 |
+
|
| 68 |
+
## Quickstart
|
| 69 |
+
|
| 70 |
+
```python
|
| 71 |
+
from supertonic_3_mlx import Pipeline
|
| 72 |
+
|
| 73 |
+
pipe = Pipeline.from_pretrained("ambassadia/supertonic-3-mlx")
|
| 74 |
+
wav = pipe.generate("Hello world from Apple Silicon.", voice="F1", lang="en")
|
| 75 |
+
|
| 76 |
+
# wav is a 1-D numpy.float32 array at 44.1 kHz
|
| 77 |
+
import soundfile as sf  # separate install (pip install soundfile); not a dependency of this package
|
| 78 |
+
sf.write("hello.wav", wav, pipe.sample_rate)
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
The first call downloads the ~400 MB weight bundle into your Hugging Face cache (downloading uses `huggingface_hub`, which the `[hub]` extra installs).
|
| 82 |
+
Subsequent calls re-use the cached weights and cold-start in ~11 ms on M4.
|
| 83 |
+
|
| 84 |
+
## Audio samples
|
| 85 |
+
|
| 86 |
+
Six samples in five languages, mixing male and female voices and short and long utterances,
|
| 87 |
+
all generated by the MLX pipeline at the wall times reported below.
|
| 88 |
+
|
| 89 |
+
<audio controls src="samples/en_F1_short.wav"></audio> **EN · F1 · 2.79 s** —
|
| 90 |
+
"Hello world from Apple Silicon. Supertonic 3 runs at one hundred times real time."
|
| 91 |
+
|
| 92 |
+
<audio controls src="samples/en_M1_long.wav"></audio> **EN · M1 · 3.90 s** —
|
| 93 |
+
"A gentle breeze moved through the open window while the children, still half-asleep, listened to the distant sound of the harbour bells."
|
| 94 |
+
|
| 95 |
+
<audio controls src="samples/fr_F2.wav"></audio> **FR · F2 · 3.41 s** —
|
| 96 |
+
"Bonjour, ceci est un test de synthèse vocale en français. Le modèle gère trente-et-une langues sur une puce M4."
|
| 97 |
+
|
| 98 |
+
<audio controls src="samples/de_M2.wav"></audio> **DE · M2 · 3.69 s** —
|
| 99 |
+
"Guten Morgen. Dieses Modell läuft komplett auf Apple Silicon, ohne ONNX und ohne CoreML, in reinem MLX."
|
| 100 |
+
|
| 101 |
+
<audio controls src="samples/ja_F3.wav"></audio> **JA · F3 · 1.46 s** —
|
| 102 |
+
"こんにちは。これはアップルシリコン上でMLXを使ったテストです。"
|
| 103 |
+
|
| 104 |
+
<audio controls src="samples/es_M3.wav"></audio> **ES · M3 · 2.86 s** —
|
| 105 |
+
"Hola, esto es una prueba de síntesis de voz en español ejecutada en tiempo real sobre Apple Silicon."
|
| 106 |
+
|
| 107 |
+
## Benchmarks (Apple M4, FP32, median of 3)
|
| 108 |
+
|
| 109 |
+
| Sample | Duration | MLX wall | MLX RTF | ONNX SDK wall | Speedup (MLX vs ONNX) |
|
| 110 |
+
|-----------------|---------:|----------:|----------:|---------:|--------:|
|
| 111 |
+
| EN · F1 · short | 2.79 s | 36.6 ms | **x76** | 1005 ms | **28 ×**|
|
| 112 |
+
| EN · M1 · long | 3.90 s | 38.4 ms | **x102** | 1356 ms | **35 ×**|
|
| 113 |
+
| FR · F2 | 3.41 s | 37.9 ms | **x90** | 1196 ms | **32 ×**|
|
| 114 |
+
| DE · M2 | 3.69 s | 38.1 ms | **x97** | 1314 ms | **35 ×**|
|
| 115 |
+
| JA · F3 | 1.46 s | 32.1 ms | **x46** | 848 ms | **26 ×**|
|
| 116 |
+
| ES · M3 | 2.86 s | 37.0 ms | **x77** | 1002 ms | **27 ×**|
|
| 117 |
+
|
| 118 |
+
Raw numbers are in [`bench_results.csv`](bench_results.csv) (regenerable via
|
| 119 |
+
the source repo's
|
| 120 |
+
[`tools/supertonic3_samples_and_bench.py`](https://gitea.tavportal.com/olivier/MLX_CONVERTOR/src/branch/feat/platform-abc/tools/supertonic3_samples_and_bench.py)).
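
As a hedged illustration of how the table columns relate, the sketch below recomputes the real-time factor (audio duration divided by wall time) and the speedup (ONNX wall time divided by MLX wall time) for the first row. The function names `rtf` and `speedup` are hypothetical and not part of the package. The inputs are the rounded values from `bench_results.csv`, so the results can differ from the published figures in the last digit.

```python
def rtf(duration_s: float, wall_ms: float) -> float:
    """Real-time factor: seconds of audio produced per second of wall time."""
    return duration_s / (wall_ms / 1000.0)


def speedup(onnx_ms: float, mlx_ms: float) -> float:
    """How many times shorter the MLX wall time is than the ONNX wall time."""
    return onnx_ms / mlx_ms


# First row of the table (EN, F1, short), using the rounded CSV values.
row_rtf = rtf(2.786, 36.6)
row_speedup = speedup(1004.7, 36.6)
print(f"RTF x{row_rtf:.1f}, speedup {row_speedup:.1f}x")
```

Both ratios are dimensionless, which is why short utterances such as the Japanese sample show a lower RTF even though their wall time is the smallest in the table.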
|
| 121 |
+
|
| 122 |
+
Reference comparison: the CoreML build of the same model on the same hardware
|
| 123 |
+
runs at ~x27 realtime. The MLX port is **~2-4× faster** end-to-end while
|
| 124 |
+
remaining bit-identical to the ONNX Runtime reference on the vocoder
|
| 125 |
+
(cosine 1.00) and at cosine ≥ 0.98 on the full estimator output.
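
The cosine figures quoted above compare an MLX output tensor with the ONNX Runtime reference. A minimal NumPy sketch of such a check follows. The helper name `cosine_similarity` and the synthetic arrays are assumptions for illustration and do not come from the source repo's test code.

```python
import numpy as np


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine of the angle between two arrays, flattened to 1-D."""
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


# Synthetic stand-ins for a reference waveform and a slightly perturbed port output.
rng = np.random.default_rng(0)
reference = rng.standard_normal(44_100).astype(np.float32)
ported = reference + 1e-6 * rng.standard_normal(44_100).astype(np.float32)

print(cosine_similarity(reference, reference))
print(cosine_similarity(reference, ported))
```

Cosine similarity ignores overall scale, so a check like this is usually paired with a maximum absolute difference when exact agreement matters.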
|
| 126 |
+
|
| 127 |
+
## Voices
|
| 128 |
+
|
| 129 |
+
10 preset voices — five female (`F1`–`F5`) and five male (`M1`–`M5`). The
|
| 130 |
+
`voice_styles/` directory contains both `style_ttl` (50×256 latent style for
|
| 131 |
+
the audio path) and `style_dp` (8×16 style for the duration head) for each
|
| 132 |
+
voice. Pass the voice name as the `voice=` kwarg to `Pipeline.generate`.
|
| 133 |
+
|
| 134 |
+
## Languages
|
| 135 |
+
|
| 136 |
+
31 languages supported. Pass the ISO 639-1 code as the `lang=` kwarg:
|
| 137 |
+
`en` `fr` `de` `es` `it` `pt` `ja` `ko` `zh` `ru` `pl` `nl` `tr` `ar` `hi`
|
| 138 |
+
`vi` `th` `id` `cs` `ro` `hu` `el` `da` `sv` `fi` `no` `he` `uk` `bg` `hr` `sk`.
|
| 139 |
+
|
| 140 |
+
## Architecture (short)
|
| 141 |
+
|
| 142 |
+
Four sub-models, all in `weights/*.safetensors`:
|
| 143 |
+
|
| 144 |
+
| Sub-model | Role | Params | Size |
|
| 145 |
+
|----------------------|-------------------------------------|--------|---------|
|
| 146 |
+
| `vector_estimator` | 24-block CFG flow-matching velocity | ~64 M | 256 MB |
|
| 147 |
+
| `text_encoder` | Character → 256-D text embedding | ~9 M | 36 MB |
|
| 148 |
+
| `duration_predictor` | Text → seconds | ~1 M | 3.5 MB |
|
| 149 |
+
| `vocoder` | Latent (B,144,T) → 44.1 kHz wav | ~25 M | 101 MB |
|
| 150 |
+
|
| 151 |
+
The pipeline runs **exactly 5 Euler steps** with classifier-free guidance
|
| 152 |
+
(`4×cond − 3×uncond`). This schedule is trained-in: reducing the step count
|
| 153 |
+
or disabling CFG produces an essentially uncorrelated waveform (verified
|
| 154 |
+
empirically — see the `bench_n_steps.py` script in the source repo).
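
The sampling loop described above can be sketched in a few lines of NumPy. This is an illustration under stated assumptions, not the code in `pipeline.py`: the `velocity` callable is a hypothetical stand-in for the VectorEstimator, the uniform time grid from 0 to 1 is assumed, and `toy_velocity` exists only to make the example runnable.

```python
import numpy as np


def cfg_euler(velocity, x0, n_steps=5, guidance=4.0):
    """Integrate dx/dt = v(x, t) from t=0 to t=1 with CFG-combined velocities.

    `velocity(x, t, conditional)` is a hypothetical callable standing in for
    the VectorEstimator; the uniform time grid is an assumption.
    """
    x = np.asarray(x0, dtype=np.float32)
    dt = 1.0 / n_steps
    for step in range(n_steps):
        t = step * dt
        v_cond = velocity(x, t, True)
        v_uncond = velocity(x, t, False)
        # guidance=4 gives 4*cond - 3*uncond, the combination quoted above.
        v = guidance * v_cond - (guidance - 1.0) * v_uncond
        x = x + dt * v  # one explicit Euler step
    return x


def toy_velocity(x, t, conditional):
    """Toy field: constant 1 when conditional, 0 otherwise."""
    return np.ones_like(x) if conditional else np.zeros_like(x)


# Latent shaped (B, 144, T), matching the vocoder input listed in the table.
latent = cfg_euler(toy_velocity, np.zeros((1, 144, 8), dtype=np.float32))
print(latent.shape, float(latent.mean()))
```

Each step calls the estimator twice, once with and once without conditioning, so five steps cost ten estimator evaluations.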
|
| 155 |
+
|
| 156 |
+
## Loading from a local snapshot
|
| 157 |
+
|
| 158 |
+
Three layouts are auto-detected by `Pipeline.from_pretrained`:
|
| 159 |
+
|
| 160 |
+
1. **Hugging Face repo id** (e.g. `"ambassadia/supertonic-3-mlx"`) — auto-download
|
| 161 |
+
2. **Local path containing `weights/`** (this layout) — fastest cold-load
|
| 162 |
+
3. **Local path containing `onnx/`** (upstream snapshot) — converts at load time
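
One way such auto-detection could work is sketched below. The function `detect_layout`, its return labels, and the precedence of `weights/` over `onnx/` are all assumptions for illustration; the actual logic inside `Pipeline.from_pretrained` may differ.

```python
import tempfile
from pathlib import Path


def detect_layout(source: str) -> str:
    """Classify `source` as one of the three layouts listed above."""
    path = Path(source)
    if path.is_dir():
        if (path / "weights").is_dir():
            return "local-mlx-weights"
        if (path / "onnx").is_dir():
            return "local-onnx-snapshot"
        raise FileNotFoundError(f"{source!r} has neither weights/ nor onnx/")
    # Anything that is not an existing directory is treated as a Hub repo id.
    return "hub-repo-id"


# Build two throwaway snapshots to exercise the local branches.
with tempfile.TemporaryDirectory() as tmp:
    mlx_dir = Path(tmp) / "mlx_snapshot"
    onnx_dir = Path(tmp) / "onnx_snapshot"
    (mlx_dir / "weights").mkdir(parents=True)
    (onnx_dir / "onnx").mkdir(parents=True)
    layouts = [detect_layout(str(mlx_dir)), detect_layout(str(onnx_dir))]

print(layouts)
```

Checking the filesystem first means that a local directory takes priority if its name happens to look like a repo id.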
|
| 163 |
+
|
| 164 |
+
## License
|
| 165 |
+
|
| 166 |
+
This release combines two artefact classes under two distinct licenses:
|
| 167 |
+
|
| 168 |
+
- **Model weights** (`weights/*.safetensors`) — **BigScience Open RAIL-M**.
|
| 169 |
+
See [`LICENSE`](LICENSE) for the full text. The Attachment A use
|
| 170 |
+
restrictions are reproduced below and apply to all downstream use of the
|
| 171 |
+
model and of generated audio.
|
| 172 |
+
- **Port code** (`src/supertonic_3_mlx/`) — **Apache License 2.0**. See
|
| 173 |
+
[`LICENSE-CODE`](LICENSE-CODE).
|
| 174 |
+
|
| 175 |
+
See [`NOTICE`](NOTICE) for the modifications statement and the upstream
|
| 176 |
+
attribution.
|
| 177 |
+
|
| 178 |
+
### OpenRAIL-M Attachment A — use restrictions
|
| 179 |
+
|
| 180 |
+
You agree not to use the model or derivatives:
|
| 181 |
+
|
| 182 |
+
(a) In any way that violates any applicable national, federal, state, local or
|
| 183 |
+
international law or regulation.
|
| 184 |
+
|
| 185 |
+
(b) For the purpose of exploiting, harming or attempting to exploit or harm
|
| 186 |
+
minors in any way.
|
| 187 |
+
|
| 188 |
+
(c) To generate or disseminate verifiably false information and/or content
|
| 189 |
+
with the purpose of harming others.
|
| 190 |
+
|
| 191 |
+
(d) To generate or disseminate personal identifiable information that can be
|
| 192 |
+
used to harm an individual.
|
| 193 |
+
|
| 194 |
+
(e) To generate or disseminate information and/or content (e.g. images, code,
|
| 195 |
+
posts, articles), and place the information and/or content in any context
|
| 196 |
+
(e.g. bot generating tweets) **without expressly and intelligibly disclaiming
|
| 197 |
+
that the information and/or content is machine generated**.
|
| 198 |
+
|
| 199 |
+
(f) To defame, disparage or otherwise harass others.
|
| 200 |
+
|
| 201 |
+
(g) To impersonate or attempt to impersonate (e.g. **deepfakes**) others
|
| 202 |
+
without their consent.
|
| 203 |
+
|
| 204 |
+
(h) For fully automated decision making that adversely impacts an individual's
|
| 205 |
+
legal rights or otherwise creates or modifies a binding, enforceable obligation.
|
| 206 |
+
|
| 207 |
+
(i) For any use intended to or which has the effect of discriminating against
|
| 208 |
+
or harming individuals or groups based on online or offline social behavior or
|
| 209 |
+
known or predicted personal or personality characteristics.
|
| 210 |
+
|
| 211 |
+
(j) To exploit any of the vulnerabilities of a specific group of persons based
|
| 212 |
+
on their age, social, physical or mental characteristics, in order to materially
|
| 213 |
+
distort the behavior of a person pertaining to that group in a manner that
|
| 214 |
+
causes or is likely to cause that person or another person physical or
|
| 215 |
+
psychological harm.
|
| 216 |
+
|
| 217 |
+
(k) For any use intended to or which has the effect of discriminating against
|
| 218 |
+
individuals or groups based on legally protected characteristics or categories.
|
| 219 |
+
|
| 220 |
+
(l) **To provide medical advice and medical results interpretation.**
|
| 221 |
+
|
| 222 |
+
(m) To generate or disseminate information for the purpose to be used for
|
| 223 |
+
administration of justice, law enforcement, immigration or asylum processes,
|
| 224 |
+
such as predicting an individual will commit fraud/crime commitment.
|
| 225 |
+
|
| 226 |
+
## Citation
|
| 227 |
+
|
| 228 |
+
```bibtex
|
| 229 |
+
@misc{supertonic3-mlx,
|
| 230 |
+
title = {Supertonic 3 MLX: native Apple Silicon port of Supertone's multilingual TTS},
|
| 231 |
+
author = {Dupont, Olivier},
|
| 232 |
+
year = {2026},
|
| 233 |
+
url = {https://huggingface.co/ambassadia/supertonic-3-mlx},
|
| 234 |
+
note = {Derivative of Supertone/supertonic-3 (https://huggingface.co/Supertone/supertonic-3)}
|
| 235 |
+
}
|
| 236 |
+
```
|
| 237 |
+
|
| 238 |
+
Please also cite the upstream Supertone Supertonic 3 model when using this
|
| 239 |
+
port.
|
bench_results.csv
ADDED
|
@@ -0,0 +1,7 @@
|
| 1 |
+
filename,language,voice,text,duration_s,mlx_ms_median,rtf_mlx,onnx_ms_median,rtf_onnx,speedup_mlx_over_onnx
|
| 2 |
+
samples/en_F1_short.wav,en,F1,Hello world from Apple Silicon. Supertonic 3 runs at one hundred times real time.,2.786,36.6,76.2,1004.7,2.8,27.5
|
| 3 |
+
samples/en_M1_long.wav,en,M1,"A gentle breeze moved through the open window while the children, still half-asleep, listened to the distant sound of the harbour bells.",3.901,38.4,101.7,1356.0,2.9,35.3
|
| 4 |
+
samples/fr_F2.wav,fr,F2,"Bonjour, ceci est un test de synthèse vocale en français. Le modèle gère trente-et-une langues sur une puce M4.",3.413,37.9,90.1,1195.6,2.9,31.6
|
| 5 |
+
samples/de_M2.wav,de,M2,"Guten Morgen. Dieses Modell läuft komplett auf Apple Silicon, ohne ONNX und ohne CoreML, in reinem MLX.",3.692,38.1,96.9,1313.9,2.8,34.5
|
| 6 |
+
samples/ja_F3.wav,ja,F3,こんにちは。これはアップルシリコン上でMLXを使ったテストです。,1.463,32.1,45.6,848.4,1.7,26.4
|
| 7 |
+
samples/es_M3.wav,es,M3,"Hola, esto es una prueba de síntesis de voz en español ejecutada en tiempo real sobre Apple Silicon.",2.856,37.0,77.2,1002.1,2.9,27.1
|
conversion_report.json
ADDED
|
@@ -0,0 +1,226 @@
|
| 1 |
+
{
|
| 2 |
+
"models": [
|
| 3 |
+
{
|
| 4 |
+
"model": "VectorEstimator",
|
| 5 |
+
"onnx": "/tmp/supertonic3/model/onnx/vector_estimator.onnx",
|
| 6 |
+
"safetensors": "/Users/transcrilive/MLX_CONVERTOR/sub-projects/supertonic3-mlx/hf_release/weights/vector_estimator.safetensors",
|
| 7 |
+
"bytes": 256053073,
|
| 8 |
+
"sha256": "2359240f2dcaee03b4800102aa0bea00223d2867ab752ef01af2b1cfaf92f3a6",
|
| 9 |
+
"weights_kept": 351,
|
| 10 |
+
"weights_dropped": 120,
|
| 11 |
+
"dropped_detail": {
|
| 12 |
+
"tts.ae.vector_field.proj_in.net.weight": "not-in-model",
|
| 13 |
+
"tts.ae.vector_field.main_blocks.0.convnext.0.pwconv1.weight": "not-in-model",
|
| 14 |
+
"tts.ae.vector_field.main_blocks.0.convnext.0.pwconv1.bias": "not-in-model",
|
| 15 |
+
"tts.ae.vector_field.main_blocks.0.convnext.0.pwconv2.weight": "not-in-model",
|
| 16 |
+
"tts.ae.vector_field.main_blocks.0.convnext.0.pwconv2.bias": "not-in-model",
|
| 17 |
+
"tts.ae.vector_field.main_blocks.0.convnext.1.pwconv1.weight": "not-in-model",
|
| 18 |
+
"tts.ae.vector_field.main_blocks.0.convnext.1.pwconv1.bias": "not-in-model",
|
| 19 |
+
"tts.ae.vector_field.main_blocks.0.convnext.1.pwconv2.weight": "not-in-model",
|
| 20 |
+
"tts.ae.vector_field.main_blocks.0.convnext.1.pwconv2.bias": "not-in-model",
|
| 21 |
+
"tts.ae.vector_field.main_blocks.0.convnext.2.pwconv1.weight": "not-in-model",
|
| 22 |
+
"tts.ae.vector_field.main_blocks.0.convnext.2.pwconv1.bias": "not-in-model",
|
| 23 |
+
"tts.ae.vector_field.main_blocks.0.convnext.2.pwconv2.weight": "not-in-model",
|
| 24 |
+
"tts.ae.vector_field.main_blocks.0.convnext.2.pwconv2.bias": "not-in-model",
|
| 25 |
+
"tts.ae.vector_field.main_blocks.0.convnext.3.pwconv1.weight": "not-in-model",
|
| 26 |
+
"tts.ae.vector_field.main_blocks.0.convnext.3.pwconv1.bias": "not-in-model",
|
| 27 |
+
"tts.ae.vector_field.main_blocks.0.convnext.3.pwconv2.weight": "not-in-model",
|
| 28 |
+
"tts.ae.vector_field.main_blocks.0.convnext.3.pwconv2.bias": "not-in-model",
|
| 29 |
+
"tts.ae.vector_field.main_blocks.2.convnext.0.pwconv1.weight": "not-in-model",
|
| 30 |
+
"tts.ae.vector_field.main_blocks.2.convnext.0.pwconv1.bias": "not-in-model",
|
| 31 |
+
"tts.ae.vector_field.main_blocks.2.convnext.0.pwconv2.weight": "not-in-model",
|
| 32 |
+
"tts.ae.vector_field.main_blocks.2.convnext.0.pwconv2.bias": "not-in-model",
|
| 33 |
+
"tts.ae.vector_field.main_blocks.4.convnext.0.pwconv1.weight": "not-in-model",
|
| 34 |
+
"tts.ae.vector_field.main_blocks.4.convnext.0.pwconv1.bias": "not-in-model",
|
| 35 |
+
"tts.ae.vector_field.main_blocks.4.convnext.0.pwconv2.weight": "not-in-model",
|
| 36 |
+
"tts.ae.vector_field.main_blocks.4.convnext.0.pwconv2.bias": "not-in-model",
|
| 37 |
+
"tts.ae.vector_field.main_blocks.6.convnext.0.pwconv1.weight": "not-in-model",
|
| 38 |
+
"tts.ae.vector_field.main_blocks.6.convnext.0.pwconv1.bias": "not-in-model",
|
| 39 |
+
"tts.ae.vector_field.main_blocks.6.convnext.0.pwconv2.weight": "not-in-model",
|
| 40 |
+
"tts.ae.vector_field.main_blocks.6.convnext.0.pwconv2.bias": "not-in-model",
|
| 41 |
+
"tts.ae.vector_field.main_blocks.6.convnext.1.pwconv1.weight": "not-in-model",
|
| 42 |
+
"tts.ae.vector_field.main_blocks.6.convnext.1.pwconv1.bias": "not-in-model",
|
| 43 |
+
"tts.ae.vector_field.main_blocks.6.convnext.1.pwconv2.weight": "not-in-model",
|
| 44 |
+
"tts.ae.vector_field.main_blocks.6.convnext.1.pwconv2.bias": "not-in-model",
|
| 45 |
+
"tts.ae.vector_field.main_blocks.6.convnext.2.pwconv1.weight": "not-in-model",
|
| 46 |
+
"tts.ae.vector_field.main_blocks.6.convnext.2.pwconv1.bias": "not-in-model",
|
| 47 |
+
"tts.ae.vector_field.main_blocks.6.convnext.2.pwconv2.weight": "not-in-model",
|
| 48 |
+
"tts.ae.vector_field.main_blocks.6.convnext.2.pwconv2.bias": "not-in-model",
|
| 49 |
+
"tts.ae.vector_field.main_blocks.6.convnext.3.pwconv1.weight": "not-in-model",
|
| 50 |
+
"tts.ae.vector_field.main_blocks.6.convnext.3.pwconv1.bias": "not-in-model",
|
| 51 |
+
"tts.ae.vector_field.main_blocks.6.convnext.3.pwconv2.weight": "not-in-model",
|
| 52 |
+
"tts.ae.vector_field.main_blocks.6.convnext.3.pwconv2.bias": "not-in-model",
|
| 53 |
+
"tts.ae.vector_field.main_blocks.8.convnext.0.pwconv1.weight": "not-in-model",
|
| 54 |
+
"tts.ae.vector_field.main_blocks.8.convnext.0.pwconv1.bias": "not-in-model",
|
| 55 |
+
"tts.ae.vector_field.main_blocks.8.convnext.0.pwconv2.weight": "not-in-model",
|
| 56 |
+
"tts.ae.vector_field.main_blocks.8.convnext.0.pwconv2.bias": "not-in-model",
|
| 57 |
+
"tts.ae.vector_field.main_blocks.10.convnext.0.pwconv1.weight": "not-in-model",
|
| 58 |
+
"tts.ae.vector_field.main_blocks.10.convnext.0.pwconv1.bias": "not-in-model",
|
| 59 |
+
"tts.ae.vector_field.main_blocks.10.convnext.0.pwconv2.weight": "not-in-model",
|
| 60 |
+
"tts.ae.vector_field.main_blocks.10.convnext.0.pwconv2.bias": "not-in-model",
|
| 61 |
+
"tts.ae.vector_field.main_blocks.12.convnext.0.pwconv1.weight": "not-in-model",
|
| 62 |
+
"tts.ae.vector_field.main_blocks.12.convnext.0.pwconv1.bias": "not-in-model",
|
| 63 |
+
"tts.ae.vector_field.main_blocks.12.convnext.0.pwconv2.weight": "not-in-model",
|
| 64 |
+
"tts.ae.vector_field.main_blocks.12.convnext.0.pwconv2.bias": "not-in-model",
|
| 65 |
+
"tts.ae.vector_field.main_blocks.12.convnext.1.pwconv1.weight": "not-in-model",
|
| 66 |
+
"tts.ae.vector_field.main_blocks.12.convnext.1.pwconv1.bias": "not-in-model",
|
| 67 |
+
"tts.ae.vector_field.main_blocks.12.convnext.1.pwconv2.weight": "not-in-model",
|
| 68 |
+
"tts.ae.vector_field.main_blocks.12.convnext.1.pwconv2.bias": "not-in-model",
|
| 69 |
+
"tts.ae.vector_field.main_blocks.12.convnext.2.pwconv1.weight": "not-in-model",
|
| 70 |
+
"tts.ae.vector_field.main_blocks.12.convnext.2.pwconv1.bias": "not-in-model",
|
| 71 |
+
"tts.ae.vector_field.main_blocks.12.convnext.2.pwconv2.weight": "not-in-model",
|
| 72 |
+
"tts.ae.vector_field.main_blocks.12.convnext.2.pwconv2.bias": "not-in-model",
|
| 73 |
+
"tts.ae.vector_field.main_blocks.12.convnext.3.pwconv1.weight": "not-in-model",
|
| 74 |
+
"tts.ae.vector_field.main_blocks.12.convnext.3.pwconv1.bias": "not-in-model",
|
| 75 |
+
"tts.ae.vector_field.main_blocks.12.convnext.3.pwconv2.weight": "not-in-model",
|
| 76 |
+
"tts.ae.vector_field.main_blocks.12.convnext.3.pwconv2.bias": "not-in-model",
|
| 77 |
+
"tts.ae.vector_field.main_blocks.14.convnext.0.pwconv1.weight": "not-in-model",
|
| 78 |
+
"tts.ae.vector_field.main_blocks.14.convnext.0.pwconv1.bias": "not-in-model",
|
| 79 |
+
"tts.ae.vector_field.main_blocks.14.convnext.0.pwconv2.weight": "not-in-model",
|
| 80 |
+
"tts.ae.vector_field.main_blocks.14.convnext.0.pwconv2.bias": "not-in-model",
|
| 81 |
+
"tts.ae.vector_field.main_blocks.16.convnext.0.pwconv1.weight": "not-in-model",
|
| 82 |
+
"tts.ae.vector_field.main_blocks.16.convnext.0.pwconv1.bias": "not-in-model",
|
| 83 |
+
"tts.ae.vector_field.main_blocks.16.convnext.0.pwconv2.weight": "not-in-model",
|
| 84 |
+
"tts.ae.vector_field.main_blocks.16.convnext.0.pwconv2.bias": "not-in-model",
|
| 85 |
+
"tts.ae.vector_field.main_blocks.18.convnext.0.pwconv1.weight": "not-in-model",
|
| 86 |
+
"tts.ae.vector_field.main_blocks.18.convnext.0.pwconv1.bias": "not-in-model",
|
| 87 |
+
"tts.ae.vector_field.main_blocks.18.convnext.0.pwconv2.weight": "not-in-model",
|
| 88 |
+
"tts.ae.vector_field.main_blocks.18.convnext.0.pwconv2.bias": "not-in-model",
|
| 89 |
+
"tts.ae.vector_field.main_blocks.18.convnext.1.pwconv1.weight": "not-in-model",
|
| 90 |
+
"tts.ae.vector_field.main_blocks.18.convnext.1.pwconv1.bias": "not-in-model",
|
| 91 |
+
"tts.ae.vector_field.main_blocks.18.convnext.1.pwconv2.weight": "not-in-model",
|
| 92 |
+
"tts.ae.vector_field.main_blocks.18.convnext.1.pwconv2.bias": "not-in-model",
|
| 93 |
+
"tts.ae.vector_field.main_blocks.18.convnext.2.pwconv1.weight": "not-in-model",
|
| 94 |
+
"tts.ae.vector_field.main_blocks.18.convnext.2.pwconv1.bias": "not-in-model",
|
| 95 |
+
"tts.ae.vector_field.main_blocks.18.convnext.2.pwconv2.weight": "not-in-model",
+        "tts.ae.vector_field.main_blocks.18.convnext.2.pwconv2.bias": "not-in-model",
+        "tts.ae.vector_field.main_blocks.18.convnext.3.pwconv1.weight": "not-in-model",
+        "tts.ae.vector_field.main_blocks.18.convnext.3.pwconv1.bias": "not-in-model",
+        "tts.ae.vector_field.main_blocks.18.convnext.3.pwconv2.weight": "not-in-model",
+        "tts.ae.vector_field.main_blocks.18.convnext.3.pwconv2.bias": "not-in-model",
+        "tts.ae.vector_field.main_blocks.20.convnext.0.pwconv1.weight": "not-in-model",
+        "tts.ae.vector_field.main_blocks.20.convnext.0.pwconv1.bias": "not-in-model",
+        "tts.ae.vector_field.main_blocks.20.convnext.0.pwconv2.weight": "not-in-model",
+        "tts.ae.vector_field.main_blocks.20.convnext.0.pwconv2.bias": "not-in-model",
+        "tts.ae.vector_field.main_blocks.22.convnext.0.pwconv1.weight": "not-in-model",
+        "tts.ae.vector_field.main_blocks.22.convnext.0.pwconv1.bias": "not-in-model",
+        "tts.ae.vector_field.main_blocks.22.convnext.0.pwconv2.weight": "not-in-model",
+        "tts.ae.vector_field.main_blocks.22.convnext.0.pwconv2.bias": "not-in-model",
+        "tts.ae.vector_field.last_convnext.convnext.0.pwconv1.weight": "not-in-model",
+        "tts.ae.vector_field.last_convnext.convnext.0.pwconv1.bias": "not-in-model",
+        "tts.ae.vector_field.last_convnext.convnext.0.pwconv2.weight": "not-in-model",
+        "tts.ae.vector_field.last_convnext.convnext.0.pwconv2.bias": "not-in-model",
+        "tts.ae.vector_field.last_convnext.convnext.1.pwconv1.weight": "not-in-model",
+        "tts.ae.vector_field.last_convnext.convnext.1.pwconv1.bias": "not-in-model",
+        "tts.ae.vector_field.last_convnext.convnext.1.pwconv2.weight": "not-in-model",
+        "tts.ae.vector_field.last_convnext.convnext.1.pwconv2.bias": "not-in-model",
+        "tts.ae.vector_field.last_convnext.convnext.2.pwconv1.weight": "not-in-model",
+        "tts.ae.vector_field.last_convnext.convnext.2.pwconv1.bias": "not-in-model",
+        "tts.ae.vector_field.last_convnext.convnext.2.pwconv2.weight": "not-in-model",
+        "tts.ae.vector_field.last_convnext.convnext.2.pwconv2.bias": "not-in-model",
+        "tts.ae.vector_field.last_convnext.convnext.3.pwconv1.weight": "not-in-model",
+        "tts.ae.vector_field.last_convnext.convnext.3.pwconv1.bias": "not-in-model",
+        "tts.ae.vector_field.last_convnext.convnext.3.pwconv2.weight": "not-in-model",
+        "tts.ae.vector_field.last_convnext.convnext.3.pwconv2.bias": "not-in-model",
+        "tts.ae.vector_field.proj_out.net.weight": "not-in-model",
+        "<missing>.vector_field.main_blocks.9.attn.theta": "expected-but-not-extracted",
+        "<missing>.vector_field.main_blocks.9.attn.increments": "expected-but-not-extracted",
+        "<missing>.vector_field.main_blocks.15.attn.theta": "expected-but-not-extracted",
+        "<missing>.vector_field.main_blocks.15.attn.increments": "expected-but-not-extracted",
+        "<missing>.vector_field.main_blocks.21.attn.theta": "expected-but-not-extracted",
+        "<missing>.vector_field.main_blocks.21.attn.increments": "expected-but-not-extracted"
+      },
+      "elapsed_s": 0.289
+    },
+    {
+      "model": "TextEncoder",
+      "onnx": "/tmp/supertonic3/model/onnx/text_encoder.onnx",
+      "safetensors": "/Users/transcrilive/MLX_CONVERTOR/sub-projects/supertonic3-mlx/hf_release/weights/text_encoder.safetensors",
+      "bytes": 36022466,
+      "sha256": "9df20bb79496718b36d2c0fc37636d3f78d6ef751b2899ff6dfeb975ae737ada",
+      "weights_kept": 146,
+      "weights_dropped": 0,
+      "dropped_detail": {},
+      "elapsed_s": 0.035
+    },
+    {
+      "model": "DurationPredictor",
+      "onnx": "/tmp/supertonic3/model/onnx/duration_predictor.onnx",
+      "safetensors": "/Users/transcrilive/MLX_CONVERTOR/sub-projects/supertonic3-mlx/hf_release/weights/duration_predictor.safetensors",
+      "bytes": 3470807,
+      "sha256": "cd473acb6e0ac27426084488ccb3b3cc184e70d05db90897e2b892846db5dcb3",
+      "weights_kept": 98,
+      "weights_dropped": 0,
+      "dropped_detail": {},
+      "elapsed_s": 0.007
+    },
+    {
+      "model": "Vocoder",
+      "onnx": "/tmp/supertonic3/model/onnx/vocoder.onnx",
+      "safetensors": "/Users/transcrilive/MLX_CONVERTOR/sub-projects/supertonic3-mlx/hf_release/weights/vocoder.safetensors",
+      "bytes": 101364763,
+      "sha256": "b2ec31ab7c554f6e15b9a6780554b5d3502345de7848b310966bfb4e1ea4e526",
+      "weights_kept": 103,
+      "weights_dropped": 0,
+      "dropped_detail": {},
+      "elapsed_s": 0.079
+    }
+  ],
+  "ancillary": [
+    {
+      "name": "unicode_indexer.json",
+      "bytes": 277676,
+      "sha256": "9bf7346e43883a81f8645c81224f786d43c5b57f3641f6e7671a7d6c493cb24f"
+    },
+    {
+      "name": "voice_styles/F1.json",
+      "bytes": 292046,
+      "sha256": "bbdec6ee00231c2c742ad05483df5334cab3b52fda3ba38e6a07059c4563dbc2"
+    },
+    {
+      "name": "voice_styles/F2.json",
+      "bytes": 292423,
+      "sha256": "7c722c6a72707b1a77f035d67f0d1351ba187738e06f7683e8c72b1df3477fc6"
+    },
+    {
+      "name": "voice_styles/F3.json",
+      "bytes": 290794,
+      "sha256": "12f6ef2573baa2defa1128069cb59f203e3ab67c92af77b42df8a0e3a2f7c6ab"
+    },
+    {
+      "name": "voice_styles/F4.json",
+      "bytes": 291808,
+      "sha256": "c2fa764c1225a76dfc3e2c73e8aa4f70d9ee48793860eb34c295fff01c2e032b"
+    },
+    {
+      "name": "voice_styles/F5.json",
+      "bytes": 291479,
+      "sha256": "45966e73316415626cf41a7d1c6f3b4c70dbc1ba2bee5c1978ef0ce33244fc8d"
+    },
+    {
+      "name": "voice_styles/M1.json",
+      "bytes": 291748,
+      "sha256": "e35604687f5d23694b8e91593a93eec0e4eca6c0b02bb8ed69139ab2ea6b0a5b"
+    },
+    {
+      "name": "voice_styles/M2.json",
+      "bytes": 292055,
+      "sha256": "b76cbf62bac707c710cf0ae5aba5e31eea1a6339a9734bfae33ab98499534a50"
+    },
+    {
+      "name": "voice_styles/M3.json",
+      "bytes": 290198,
+      "sha256": "ea1ac35ccb91b0d7ecad533a2fbd0eec10c91513d8951e3b25fbba99954e159b"
+    },
+    {
+      "name": "voice_styles/M4.json",
+      "bytes": 291522,
+      "sha256": "ca8eefad4fcd989c9379032ff3e50738adc547eeb5e221b82593a6d7b3bac303"
+    },
+    {
+      "name": "voice_styles/M5.json",
+      "bytes": 291469,
+      "sha256": "dd22b92740314321f8ae11c5e87f8dd60d060f15dd3a632b5adf77f471f77af2"
+    }
+  ]
+}
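The report records a byte count and a SHA-256 digest for every converted or copied file. A minimal sketch of checking a local file against one such entry, assuming entries shaped like the `ancillary` items above (`name`, `bytes`, `sha256`). The helper name `verify_entry` and the `root` directory argument are illustrative and not part of the package:

```python
import hashlib
from pathlib import Path


def verify_entry(entry: dict, root: Path) -> bool:
    """Return True if root/entry['name'] matches the recorded size and digest.

    `entry` is assumed to carry the keys "name", "bytes" and "sha256",
    as in the "ancillary" list of conversion_report.json.
    """
    data = (root / entry["name"]).read_bytes()
    # Cheap check first: a size mismatch means the digest cannot match either.
    if len(data) != entry["bytes"]:
        return False
    return hashlib.sha256(data).hexdigest() == entry["sha256"]
```

Looping this over the `ancillary` list (and over the per-model `safetensors` entries, using the file's basename) would give a quick integrity check after download.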
examples/quickstart.py
ADDED
@@ -0,0 +1,23 @@
+"""Minimal Supertonic 3 MLX usage — 5 lines, no fluff.
+
+Run from anywhere AFTER ``pip install supertonic-3-mlx`` (or from inside
+this directory after ``pip install ./``):
+
+    python examples/quickstart.py
+"""
+from supertonic_3_mlx import Pipeline
+import soundfile as sf
+
+# When the package has been pip-installed, this auto-downloads from the Hub
+# (~ 400 MB) into the standard Hugging Face cache. After the first run, the
+# weights are reused from cache and cold start is ~ 11 ms on M4.
+pipe = Pipeline.from_pretrained("ambassadia/supertonic-3-mlx")
+
+wav = pipe.generate(
+    "Hello world from Apple Silicon. Supertonic 3 runs at one hundred times realtime.",
+    voice="F1",  # one of F1..F5, M1..M5
+    lang="en",  # ISO 639-1
+)
+
+sf.write("hello.wav", wav, pipe.sample_rate)
+print(f"wrote hello.wav — {len(wav) / pipe.sample_rate:.2f}s of audio")
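Note that the quickstart imports `soundfile`, which is not among the `dependencies` declared in pyproject.toml (`mlx`, `numpy`), so it has to be installed separately. The last line of the script derives the clip length from the sample count. The same arithmetic gives a speed-up factor once wall-clock time is known. A small sketch of that calculation: `realtime_factor` is a hypothetical helper, and the timing in the usage comment is made up for illustration, not a measured result:

```python
def audio_seconds(num_samples: int, sample_rate: int = 44_100) -> float:
    """Length of a mono clip in seconds (44.1 kHz is the port's sample rate)."""
    return num_samples / sample_rate


def realtime_factor(num_samples: int, elapsed_s: float, sample_rate: int = 44_100) -> float:
    """Seconds of audio produced per second of wall-clock time."""
    if elapsed_s <= 0:
        raise ValueError("elapsed_s must be positive")
    return audio_seconds(num_samples, sample_rate) / elapsed_s


# Illustrative numbers only: 5 s of audio generated in 0.05 s is a factor of about 100.
factor = realtime_factor(num_samples=220_500, elapsed_s=0.05)
```

In practice `elapsed_s` would come from wrapping `pipe.generate(...)` in `time.perf_counter()` calls, ideally after a warm-up run.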
pyproject.toml
ADDED
@@ -0,0 +1,43 @@
+[project]
+name = "supertonic-3-mlx"
+version = "0.1.0"
+description = "MLX-native port of Supertone's Supertonic 3 multilingual TTS (31 languages, ~x100 realtime on Apple Silicon)"
+readme = "README.md"
+requires-python = ">=3.10"
+authors = [{ name = "Olivier Dupont", email = "olivier.dupont@taviramonaco.com" }]
+license = { text = "Apache-2.0 AND OpenRAIL-M" }
+keywords = ["mlx", "tts", "speech-synthesis", "apple-silicon", "supertonic", "multilingual"]
+classifiers = [
+    "Development Status :: 4 - Beta",
+    "Environment :: MacOS X",
+    "Intended Audience :: Developers",
+    "Intended Audience :: Science/Research",
+    "License :: OSI Approved :: Apache Software License",
+    "Operating System :: MacOS",
+    "Programming Language :: Python :: 3 :: Only",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+    "Topic :: Multimedia :: Sound/Audio :: Speech",
+    "Topic :: Scientific/Engineering :: Artificial Intelligence",
+]
+dependencies = [
+    "mlx>=0.21.0",
+    "numpy>=1.24.0",
+]
+
+[project.optional-dependencies]
+hub = ["huggingface_hub>=0.26.0"]
+dev = ["pytest>=8.3.0", "ruff>=0.7.0"]
+
+[project.urls]
+Homepage = "https://huggingface.co/ambassadia/supertonic-3-mlx"
+Upstream = "https://huggingface.co/Supertone/supertonic-3"
+Source = "https://gitea.tavportal.com/olivier/MLX_CONVERTOR"
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/supertonic_3_mlx"]
samples/de_M2.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b363a0c1ca3f596e05ac001f65377d129613db67164dbedfbdf9b4c11d56e365
+size 325676
samples/en_F1_short.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4c78ec23d58e6befa6bbcd078631a11eb6c1b647dbb07750767bd29ed17205f6
+size 245804
samples/en_M1_long.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2561326a9c4a28e7bb3cc5bdf6a36f23c97e454714c0bc5e63b8f8a981beac96
+size 344108
samples/es_M3.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c19918dc6927d18825f1a8277b172f3482a5953526f19a3f0bbce3f911885822
+size 251948
samples/fr_F2.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f37101850f032c2b59c20fc45b0d7fee794005c2bebb43737737427d09069d94
+size 301100
samples/ja_F3.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1c447d1d9899548fe646de2238e13ccf9b3c7ded9d16d39bd115a8e9d66d5ff1
+size 129068
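Each sample above is committed as a Git LFS pointer, three `key value` lines, rather than as audio bytes. A short sketch of reading such a pointer into a dict, which can be handy for checking a downloaded WAV against the recorded size and digest. The function name `parse_lfs_pointer` is illustrative, and the sketch only handles the three keys that appear in these pointers:

```python
def parse_lfs_pointer(text: str) -> dict:
    """Parse the 'key value' lines of a Git LFS pointer file."""
    fields = {}
    for line in text.splitlines():
        if not line.strip():
            continue
        # Each line is "<key> <value>", split on the first space only.
        key, _, value = line.partition(" ")
        fields[key] = value
    # The oid line has the form "<hash-algorithm>:<hex digest>".
    algo, _, digest = fields["oid"].partition(":")
    return {
        "version": fields["version"],
        "hash_algo": algo,
        "oid": digest,
        "size": int(fields["size"]),
    }


# Pointer for samples/ja_F3.wav, copied from the diff above.
pointer = parse_lfs_pointer(
    "version https://git-lfs.github.com/spec/v1\n"
    "oid sha256:1c447d1d9899548fe646de2238e13ccf9b3c7ded9d16d39bd115a8e9d66d5ff1\n"
    "size 129068\n"
)
```

Here `pointer["size"]` is the byte length of the real WAV file and `pointer["oid"]` is its SHA-256 digest.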
src/supertonic_3_mlx/__init__.py
ADDED
@@ -0,0 +1,51 @@
+"""Supertonic 3 — MLX-native TTS for Apple Silicon.
+
+31-language text-to-speech, 5 Euler steps with classifier-free guidance, in
+pure MLX. On M4 the full pipeline runs at ~x100 realtime.
+
+Quickstart
+----------
+
+    from supertonic_3_mlx import Pipeline
+    pipe = Pipeline.from_pretrained("ambassadia/supertonic-3-mlx")
+    wav = pipe.generate("Hello world from Apple Silicon.", voice="F1", lang="en")
+    # wav is a 1-D ``numpy.float32`` array at 44.1 kHz.
+
+The model weights are released under the BigScience OpenRAIL-M license
+(see LICENSE in the Hugging Face repository). This MLX port code is
+Apache-2.0. Together they form a dual-license package; Attachment A use
+restrictions of OpenRAIL-M govern downstream use of the generated audio.
+
+Public API:
+    Pipeline           — end-to-end TTS, ``from_pretrained`` + ``generate``
+    VectorEstimator    — the 24-block CFG flow-matching net (sub-model 1/4)
+    TextEncoder        — character → text embedding (sub-model 2/4)
+    DurationPredictor  — text → duration in seconds (sub-model 3/4)
+    Vocoder            — latent → 44.1 kHz waveform (sub-model 4/4)
+"""
+from supertonic_3_mlx._config import (
+    DIM, LATENT_CH, CONVNEXT_HIDDEN, CONVNEXT_K,
+    NUM_MAIN_BLOCKS, NUM_CYCLES, BLOCKS_PER_CYCLE, BLOCK_CYCLE, STACK4_DILATIONS,
+    TEXT_HEADS, TEXT_HEAD_DIM, TEXT_DIM, ROTARY_BASE, ROTARY_SCALE,
+    STYLE_HEADS, STYLE_HEAD_DIM, STYLE_LEN, STYLE_DIM,
+    TIME_EMB_DIM, TIME_MLP_HIDDEN,
+    EPS_LN, CHUNK_COMPRESS, LATENT_DIM, SAMPLE_RATE,
+    SUPERTONIC3_HF_REPO,
+)
+from supertonic_3_mlx.duration_predictor import DurationPredictor
+from supertonic_3_mlx.text_encoder import TextEncoder
+from supertonic_3_mlx.vector_estimator import VectorEstimator
+from supertonic_3_mlx.vocoder import Vocoder
+from supertonic_3_mlx.pipeline import SupertonicMLXPipeline as Pipeline
+
+__all__ = [
+    "Pipeline",
+    "DurationPredictor", "TextEncoder", "VectorEstimator", "Vocoder",
+    "DIM", "LATENT_CH", "CONVNEXT_HIDDEN", "CONVNEXT_K",
+    "NUM_MAIN_BLOCKS", "NUM_CYCLES", "BLOCKS_PER_CYCLE", "BLOCK_CYCLE", "STACK4_DILATIONS",
+    "TEXT_HEADS", "TEXT_HEAD_DIM", "TEXT_DIM", "ROTARY_BASE", "ROTARY_SCALE",
+    "STYLE_HEADS", "STYLE_HEAD_DIM", "STYLE_LEN", "STYLE_DIM",
+    "TIME_EMB_DIM", "TIME_MLP_HIDDEN",
+    "EPS_LN", "CHUNK_COMPRESS", "LATENT_DIM", "SAMPLE_RATE",
+    "SUPERTONIC3_HF_REPO",
+]
src/supertonic_3_mlx/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (2.56 kB)
src/supertonic_3_mlx/__pycache__/_config.cpython-312.pyc
ADDED
Binary file (2.21 kB)
src/supertonic_3_mlx/__pycache__/_nn_wrappers.cpython-312.pyc
ADDED
Binary file (3.24 kB)
src/supertonic_3_mlx/__pycache__/duration_predictor.cpython-312.pyc
ADDED
Binary file (22.4 kB)
src/supertonic_3_mlx/__pycache__/pipeline.cpython-312.pyc
ADDED
Binary file (28 kB)
src/supertonic_3_mlx/__pycache__/text_encoder.cpython-312.pyc
ADDED
Binary file (22.9 kB)
src/supertonic_3_mlx/__pycache__/vector_estimator.cpython-312.pyc
ADDED
Binary file (38.4 kB)
src/supertonic_3_mlx/__pycache__/vocoder.cpython-312.pyc
ADDED
Binary file (18.2 kB)
src/supertonic_3_mlx/_config.py
ADDED
@@ -0,0 +1,58 @@
+"""Locked hyperparameters for Supertonic 3 MLX port.
+
+Derived from the official ``Supertone/supertonic-3/onnx/tts.json``.
+Changing these = re-running parity tests.
+"""
+from __future__ import annotations
+
+# Vector estimator (the flow-matching denoiser)
+DIM: int = 512  # backbone width
+LATENT_CH: int = 144  # 24 * chunk_compress_factor (6)
+CONVNEXT_HIDDEN: int = 2048  # main_blocks ConvNeXt intermediate dim (2× vs s2)
+CONVNEXT_K: int = 5
+LAST_CONVNEXT_NUM: int = 4  # last_convnext is a 4-layer stack (dilations [1,1,1,1])
+
+# 24 main_blocks = 4 cycles × 6 sub-blocks (cycle: stack4, time, cn1, text_attn, cn1, style_attn)
+NUM_CYCLES: int = 4
+BLOCKS_PER_CYCLE: int = 6
+NUM_MAIN_BLOCKS: int = NUM_CYCLES * BLOCKS_PER_CYCLE
+BLOCK_CYCLE = ("stack4", "time", "cn1", "text_attn", "cn1", "style_attn")
+
+# ConvNeXt stack 4 (in stack4 blocks) — dilation schedule
+STACK4_DILATIONS = (1, 2, 4, 8)
+
+# Text cross-attention (RoPE) — block type "text_attn"
+TEXT_DIM: int = 256
+TEXT_HEADS: int = 8  # 2× vs s2 (4)
+TEXT_HEAD_DIM: int = DIM // TEXT_HEADS  # 512/8 = 64
+ROTARY_BASE: int = 10_000
+ROTARY_SCALE: int = 10
+
+# Style cross-attention — block type "style_attn"
+STYLE_DIM: int = 256
+STYLE_LEN: int = 50  # 50 style tokens (n_style)
+STYLE_HEADS: int = 2
+STYLE_HEAD_DIM: int = 128
+
+# Time encoding (sinusoidal + MLP)
+TIME_EMB_DIM: int = 64
+TIME_MLP_HIDDEN: int = 256
+
+# LayerNorm epsilon
+EPS_LN: float = 1e-6
+
+# Chunk compress factor (used by AE)
+CHUNK_COMPRESS: int = 6
+LATENT_DIM: int = 24  # ldim before chunk compression
+
+# Sample rate
+SAMPLE_RATE: int = 44_100
+
+# HF references (will be pinned to SHA after first download)
+SUPERTONIC3_HF_REPO: str = "Supertone/supertonic-3"
+ONNX_VECTOR_ESTIMATOR: str = "onnx/vector_estimator.onnx"
+ONNX_TEXT_ENCODER: str = "onnx/text_encoder.onnx"
+ONNX_DURATION_PREDICTOR: str = "onnx/duration_predictor.onnx"
+ONNX_VOCODER: str = "onnx/vocoder.onnx"
+ONNX_TTS_JSON: str = "onnx/tts.json"
+ONNX_UNICODE_INDEXER: str = "onnx/unicode_indexer.json"
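The block-layout comment in `_config.py` implies a simple mapping from a `main_blocks` index to its block type. The sketch below spells that out, assuming block `i` takes the type at position `i % 6` of `BLOCK_CYCLE`. That mapping is an inference from the comment, not something the file states, although it is consistent with conversion_report.json listing `attn.theta` entries for `main_blocks` 9, 15 and 21. The constants are copied from the file and `block_type` is a hypothetical helper:

```python
# Constants copied from _config.py.
NUM_CYCLES = 4
BLOCK_CYCLE = ("stack4", "time", "cn1", "text_attn", "cn1", "style_attn")
BLOCKS_PER_CYCLE = len(BLOCK_CYCLE)
NUM_MAIN_BLOCKS = NUM_CYCLES * BLOCKS_PER_CYCLE  # 24
LATENT_DIM = 24
CHUNK_COMPRESS = 6
LATENT_CH = LATENT_DIM * CHUNK_COMPRESS  # 144, matching the hard-coded value


def block_type(index: int) -> str:
    """Type of main_blocks[index], assuming the 6-block cycle simply repeats."""
    if not 0 <= index < NUM_MAIN_BLOCKS:
        raise IndexError(f"main block index out of range: {index}")
    return BLOCK_CYCLE[index % BLOCKS_PER_CYCLE]


layout = [block_type(i) for i in range(NUM_MAIN_BLOCKS)]
text_attn_blocks = [i for i, kind in enumerate(layout) if kind == "text_attn"]
print(text_attn_blocks)  # [3, 9, 15, 21]
```

Under the same assumption, block 18 is a `stack4` block (four ConvNeXt layers, matching the `convnext.0` through `convnext.3` keys in the report) and blocks 20 and 22 are single-layer `cn1` blocks.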
src/supertonic_3_mlx/_nn_wrappers.py
ADDED
@@ -0,0 +1,50 @@
+"""Small wrapper modules to match Supertonic 3 ONNX submodule nesting.
+
+The s3 checkpoint nests primitives one level deeper than typical MLX modules:
+  - ``norm.norm.weight`` — LayerNorm wrapped in a Norm container
+  - ``linear.linear.weight`` — Linear wrapped in a Linear container
+  - ``W_query.linear.weight`` — attention projection wrapped
+
+Mirroring this nesting lets us load the safetensors with ``model.load_weights(...)``
+without any key remapping at load time.
+"""
+from __future__ import annotations
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+class WrappedNorm(nn.Module):
+    """Container with a single nested LayerNorm — produces key ``X.norm.weight``."""
+
+    def __init__(self, dim: int, eps: float = 1e-6) -> None:
+        super().__init__()
+        self.norm = nn.LayerNorm(dim, eps=eps)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return self.norm(x)
+
+
+class WrappedLinear(nn.Module):
+    """Container with a single nested Linear — produces keys ``X.linear.weight/bias``."""
+
+    def __init__(self, in_dim: int, out_dim: int, bias: bool = True) -> None:
+        super().__init__()
+        self.linear = nn.Linear(in_dim, out_dim, bias=bias)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return self.linear(x)
+
+
+class ProjConv1x1(nn.Module):
+    """Conv1d k=1 expressed as ``self.net = Linear`` (matches ``proj_in.net.weight``)."""
+
+    def __init__(self, in_dim: int, out_dim: int, bias: bool = True) -> None:
+        super().__init__()
+        self.net = nn.Linear(in_dim, out_dim, bias=bias)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return self.net(x)
+
+
+__all__ = ["WrappedNorm", "WrappedLinear", "ProjConv1x1"]
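The wrappers exist so that parameter key paths line up with the checkpoint. How one extra level of nesting turns into one extra dotted segment can be shown without MLX: the sketch below flattens a nested dict into dotted names, in the way a module tree is commonly flattened into parameter keys. The `flatten_keys` helper and the toy parameter trees are illustrative assumptions, not code from this package:

```python
def flatten_keys(tree: dict, prefix: str = "") -> list[str]:
    """Flatten a nested dict of parameters into dotted key paths."""
    keys = []
    for name, value in tree.items():
        path = f"{prefix}.{name}" if prefix else name
        if isinstance(value, dict):
            keys.extend(flatten_keys(value, path))
        else:
            keys.append(path)
    return keys


# A block whose "norm" attribute is a WrappedNorm-like container:
# outer attribute "norm", inner attribute "norm", then the LayerNorm parameters.
wrapped = {"norm": {"norm": {"weight": None, "bias": None}}}
# The same block holding a bare LayerNorm directly.
bare = {"norm": {"weight": None, "bias": None}}

print(flatten_keys(wrapped))  # ['norm.norm.weight', 'norm.norm.bias']
print(flatten_keys(bare))     # ['norm.weight', 'norm.bias']
```

Only the first form matches checkpoint keys such as `norm.norm.weight`, which is why the port adds the container instead of renaming keys at load time.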
src/supertonic_3_mlx/duration_predictor.py
ADDED
@@ -0,0 +1,347 @@
+"""Supertonic 3 duration predictor — predicts total audio duration in seconds.
+
+Pipeline (channels-last NTC throughout):
+
+    text_ids [B, T] int64 character IDs
+      → char_embed (Embedding 8322→64)  [B, T, 64]
+      → prepend sentence_token (1, 64, 1)  [B, T+1, 64]
+      → 6× ConvNeXt (dim=64, hidden=256, k=5, all dilations=1)
+      → 2× RelPosSelfAttn (heads=2, head_dim=32, window=4) + norm + FFN + norm
+      → proj_out (Conv1d k=1: 64→64) applied to slot 0 (sentence token)
+      → concat with style_dp flattened (B, 8×16=128)  [B, 192]
+      → Linear(192 → 128) → PReLU → Linear(128 → 1) → exp → duration [B]
+
+Inputs:
+    text_ids: (B, T) int — character indices
+    style_dp: (B, 8, 16) — style summary tokens
+    text_mask: (B, 1, T) — 1.0 valid, 0.0 padded
+"""
+from __future__ import annotations
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from supertonic_3_mlx._config import EPS_LN
+from supertonic_3_mlx._nn_wrappers import WrappedNorm
+from supertonic_3_mlx.vector_estimator import _pad_sym_edge, _gelu_exact
+
+
+DP_VOCAB = 8322
+DP_DIM = 64
+DP_CONVNEXT_HIDDEN = 256
+DP_CONVNEXT_K = 5
+DP_CONVNEXT_NUM_LAYERS = 6
+DP_ATTN_NUM_LAYERS = 2
+DP_ATTN_HEADS = 2
+DP_ATTN_HEAD_DIM = DP_DIM // DP_ATTN_HEADS  # 32
+DP_FFN_HIDDEN = 256
+DP_REL_POS_WINDOW = 4
+DP_N_STYLE = 8
+DP_STYLE_DIM = 16
+DP_MLP_IN = DP_DIM + DP_N_STYLE * DP_STYLE_DIM  # 64 + 128 = 192
+DP_MLP_HIDDEN = 128
+
+
+class _DPConvNeXtBlock(nn.Module):
+    """ConvNeXt block (dim=64, hidden=256, dilation=1)."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.dwconv = nn.Conv1d(
+            DP_DIM, DP_DIM, kernel_size=DP_CONVNEXT_K, padding=0,
+            dilation=1, groups=DP_DIM, bias=True,
+        )
+        self.norm = WrappedNorm(DP_DIM, eps=EPS_LN)
+        self.pwconv1 = nn.Linear(DP_DIM, DP_CONVNEXT_HIDDEN, bias=True)
+        self.pwconv2 = nn.Linear(DP_CONVNEXT_HIDDEN, DP_DIM, bias=True)
+        self.gamma = mx.zeros((DP_DIM,))
+        self.pad = (DP_CONVNEXT_K - 1) // 2
+
+    def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
+        residual = x
+        y = _pad_sym_edge(x, self.pad)
+        y = self.dwconv(y)
+        y = self.norm(y)
+        y = self.pwconv1(y)
+        y = _gelu_exact(y)
+        y = self.pwconv2(y)
+        y = y * self.gamma
+        out = residual + y
+        if mask is not None:
+            out = out * mask
+        return out
+
+
+class _DPConvNeXtStack(nn.Module):
+    """``convnext.[0..5]`` — 6 ConvNeXt blocks."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.convnext = [_DPConvNeXtBlock() for _ in range(DP_CONVNEXT_NUM_LAYERS)]
+
+    def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
+        for b in self.convnext:
+            x = b(x, mask)
+        return x
+
+
+class _DPConvLayer(nn.Module):
+    """Conv1d k=1 with weight (out, 1, in) — matches ONNX storage."""
+
+    def __init__(self, in_dim: int, out_dim: int) -> None:
+        super().__init__()
+        self.weight = mx.zeros((out_dim, 1, in_dim))
+        self.bias = mx.zeros((out_dim,))
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return mx.conv1d(x, self.weight, stride=1, padding=0) + self.bias
+
+
+def _dp_rel_to_abs(x: mx.array) -> mx.array:
+    """(B, h, L, 2L-1) → (B, h, L, L) via VITS shifted-skew reshape."""
+    B, h, L, _ = x.shape
+    x = mx.concatenate([x, mx.zeros((B, h, L, 1), dtype=x.dtype)], axis=-1)
+    x_flat = x.reshape(B, h, L * 2 * L)
+    x_flat = mx.concatenate([x_flat, mx.zeros((B, h, L - 1), dtype=x.dtype)], axis=-1)
+    x_final = x_flat.reshape(B, h, L + 1, 2 * L - 1)
+    return x_final[:, :, :L, L - 1:]
+
+
+def _dp_abs_to_rel(x: mx.array) -> mx.array:
+    """(B, h, L, L) → (B, h, L, 2L-1)."""
+    B, h, L, _ = x.shape
+    x = mx.concatenate([x, mx.zeros((B, h, L, L - 1), dtype=x.dtype)], axis=-1)
+    x_flat = x.reshape(B, h, L * (2 * L - 1))
+    x_flat = mx.concatenate([mx.zeros((B, h, L), dtype=x.dtype), x_flat], axis=-1)
+    x_final = x_flat.reshape(B, h, L, 2 * L)
+    return x_final[:, :, :, 1:]
+
+
+def _dp_slice_rel(rel: mx.array, length: int, window: int) -> mx.array:
+    """(1, 2W+1, d) → (1, 2L-1, d) by zero-padding/slicing."""
+    pad_l = max(length - (window + 1), 0)
+    if pad_l > 0:
+        zero = mx.zeros((1, pad_l, rel.shape[-1]), dtype=rel.dtype)
+        padded = mx.concatenate([zero, rel, zero], axis=1)
+    else:
+        padded = rel
+    start = max(window + 1 - length, 0)
+    return padded[:, start: start + 2 * length - 1]
+
+
+class _DPRelPosSelfAttn(nn.Module):
+    """VITS-style rel-pos self-attention (2 heads × 32 head_dim, window=4).
+
+    Includes both rel-pos contributions (q × rel_k → logits, abs_to_rel(attn) × rel_v → out).
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv_q = _DPConvLayer(DP_DIM, DP_DIM)
+        self.conv_k = _DPConvLayer(DP_DIM, DP_DIM)
+        self.conv_v = _DPConvLayer(DP_DIM, DP_DIM)
+        self.conv_o = _DPConvLayer(DP_DIM, DP_DIM)
+        self.emb_rel_k = mx.zeros((1, 2 * DP_REL_POS_WINDOW + 1, DP_ATTN_HEAD_DIM))
+        self.emb_rel_v = mx.zeros((1, 2 * DP_REL_POS_WINDOW + 1, DP_ATTN_HEAD_DIM))
+
+    def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
+        B, T, _ = x.shape
+        H, D = DP_ATTN_HEADS, DP_ATTN_HEAD_DIM
+        q = self.conv_q(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
+        k = self.conv_k(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
+        v = self.conv_v(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
+        scale = D ** -0.5
+
+        logits = (q @ k.transpose(0, 1, 3, 2)) * scale
+
+        rel_k = _dp_slice_rel(self.emb_rel_k, T, DP_REL_POS_WINDOW)
+        rel_logits = q @ rel_k.transpose(0, 2, 1)[:, None, :, :]
+        rel_logits = _dp_rel_to_abs(rel_logits * scale)
+        logits = logits + rel_logits
+
+        if mask is not None:
+            key_mask = mask[:, :, 0][:, None, None, :]
+            neg_inf = mx.array(-1e4, dtype=logits.dtype)
+            logits = mx.where(key_mask.astype(mx.bool_), logits, neg_inf)
+
+        attn = mx.softmax(logits, axis=-1)
+        out = attn @ v
+
+        rel_v = _dp_slice_rel(self.emb_rel_v, T, DP_REL_POS_WINDOW)
+        rel_weights = _dp_abs_to_rel(attn)
+        out = out + rel_weights @ rel_v[:, None, :, :]
+
+        out = out.transpose(0, 2, 1, 3).reshape(B, T, H * D)
+        return self.conv_o(out)
+
+
+class _DPFFN(nn.Module):
+    """FFN with two Conv1d k=1 — 64 → 256 → 64, ReLU + mask."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv_1 = _DPConvLayer(DP_DIM, DP_FFN_HIDDEN)
+        self.conv_2 = _DPConvLayer(DP_FFN_HIDDEN, DP_DIM)
+
+    def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
+        if mask is not None:
+            x = x * mask
+        y = self.conv_1(x)
+        y = mx.maximum(y, mx.array(0.0, dtype=y.dtype))
+        if mask is not None:
+            y = y * mask
+        y = self.conv_2(y)
+        if mask is not None:
+            y = y * mask
+        return y
+
+
+class _DPAttnEncoder(nn.Module):
+    """2× (attn + norm) + (ffn + norm)."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.attn_layers = [_DPRelPosSelfAttn() for _ in range(DP_ATTN_NUM_LAYERS)]
+        self.norm_layers_1 = [WrappedNorm(DP_DIM, eps=EPS_LN) for _ in range(DP_ATTN_NUM_LAYERS)]
+        self.ffn_layers = [_DPFFN() for _ in range(DP_ATTN_NUM_LAYERS)]
+        self.norm_layers_2 = [WrappedNorm(DP_DIM, eps=EPS_LN) for _ in range(DP_ATTN_NUM_LAYERS)]
+
+    def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
+        for i in range(DP_ATTN_NUM_LAYERS):
+            x = self.norm_layers_1[i](x + self.attn_layers[i](x, mask=mask))
+            x = self.norm_layers_2[i](x + self.ffn_layers[i](x, mask))
+        return x
+
+
+class _DPSentenceEncoder(nn.Module):
+    """Text → 64-d sentence vector via prepended ``sentence_token`` slot."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        class _TextEmb(nn.Module):
+            def __init__(_):
+                super().__init__()
+                _.char_embedder = nn.Embedding(DP_VOCAB, DP_DIM)
+            def __call__(_, ids):
+                return _.char_embedder(ids)
+        self.text_embedder = _TextEmb()
+        self.convnext = _DPConvNeXtStack()
+        self.attn_encoder = _DPAttnEncoder()
+        # proj_out keeps the .net.weight (out, 1, in) Conv1d-k1 layout
+        self.proj_out = _DPProjOut()
+        # sentence_token (1, DIM, 1) — prepended as the first time slot
+        self.sentence_token = mx.zeros((1, DP_DIM, 1))
+
+    def __call__(self, text_ids: mx.array, text_mask: mx.array) -> mx.array:
+        x = self.text_embedder(text_ids)  # (B, T, 64)
+        # Prepend sentence_token: shape (1, 64, 1) → (B, 1, 64)
+        B = x.shape[0]
+        sentence = self.sentence_token.transpose(0, 2, 1)
+        sentence = mx.broadcast_to(sentence, (B, 1, DP_DIM))
+        x = mx.concatenate([sentence, x], axis=1)  # (B, T+1, 64)
+
+        # Extend mask with a leading 1 (sentence token always valid)
+        if text_mask is not None:
+            extra = mx.ones((B, 1, 1), dtype=text_mask.dtype)
+            mask_ntc = mx.concatenate([extra, text_mask.transpose(0, 2, 1)], axis=1)
+        else:
+            mask_ntc = None
+
+        x = self.convnext(x, mask_ntc)
+        x = self.attn_encoder(x, mask_ntc)
+
+        # Take slot 0 (sentence token output) → (B, 1, 64)
+        sentence_out = x[:, :1, :]  # (B, 1, 64)
+        # proj_out (Conv1d k=1) — applied along time, output (B, 1, 64)
+        sentence_out = self.proj_out(sentence_out)
+        return sentence_out.reshape(B, DP_DIM)  # (B, 64)
+
+
+class _DPProjOut(nn.Module):
+    """Conv1d k=1 64→64. No bias in ONNX (confirmed via graph inspection)."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        class _Net(nn.Module):
+            def __init__(_):
+                super().__init__()
+                _.weight = mx.zeros((DP_DIM, 1, DP_DIM))
+            def __call__(_, x):
+                return mx.conv1d(x, _.weight, stride=1, padding=0)
+        self.net = _Net()
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return self.net(x)
+
+
+class _DPPredictor(nn.Module):
+    """Linear(192 → 128) + PReLU + Linear(128 → 1).
+
+    PReLU is stored under ``activation.weight (1,)`` — a single learnable
+    negative-slope coefficient.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.layers = [
+            nn.Linear(DP_MLP_IN, DP_MLP_HIDDEN, bias=True),
+            nn.Linear(DP_MLP_HIDDEN, 1, bias=True),
+        ]
+        # PReLU: activation.weight shape (1,) — single scalar slope
+        class _Activation(nn.Module):
|
| 292 |
+
def __init__(_):
|
| 293 |
+
super().__init__()
|
| 294 |
+
_.weight = mx.zeros((1,))
|
| 295 |
+
def __call__(_, x):
|
| 296 |
+
# PReLU(x) = max(0, x) + slope * min(0, x)
|
| 297 |
+
neg = mx.minimum(x, mx.array(0.0, dtype=x.dtype))
|
| 298 |
+
pos = mx.maximum(x, mx.array(0.0, dtype=x.dtype))
|
| 299 |
+
return pos + _.weight * neg
|
| 300 |
+
self.activation = _Activation()
|
| 301 |
+
|
| 302 |
+
def __call__(self, x: mx.array) -> mx.array:
|
| 303 |
+
h = self.layers[0](x) # (B, 128)
|
| 304 |
+
h = self.activation(h)
|
| 305 |
+
h = self.layers[1](h) # (B, 1)
|
| 306 |
+
return h
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class _DPRoot(nn.Module):
|
| 310 |
+
"""``tts.dp.X`` namespace container."""
|
| 311 |
+
|
| 312 |
+
def __init__(self) -> None:
|
| 313 |
+
super().__init__()
|
| 314 |
+
self.sentence_encoder = _DPSentenceEncoder()
|
| 315 |
+
self.predictor = _DPPredictor()
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
class _DPContainer(nn.Module):
|
| 319 |
+
def __init__(self) -> None:
|
| 320 |
+
super().__init__()
|
| 321 |
+
self.dp = _DPRoot()
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
class DurationPredictor(nn.Module):
|
| 325 |
+
"""Predicts total audio duration (seconds) for an utterance.
|
| 326 |
+
|
| 327 |
+
Submodule namespace matches ONNX keys ``tts.dp.X.Y`` exactly.
|
| 328 |
+
"""
|
| 329 |
+
|
| 330 |
+
def __init__(self) -> None:
|
| 331 |
+
super().__init__()
|
| 332 |
+
self.tts = _DPContainer()
|
| 333 |
+
|
| 334 |
+
def __call__(
|
| 335 |
+
self,
|
| 336 |
+
text_ids: mx.array, # (B, T) int
|
| 337 |
+
style_dp: mx.array, # (B, 8, 16)
|
| 338 |
+
text_mask: mx.array, # (B, 1, T)
|
| 339 |
+
) -> mx.array:
|
| 340 |
+
sentence = self.tts.dp.sentence_encoder(text_ids, text_mask) # (B, 64)
|
| 341 |
+
style_flat = style_dp.reshape(style_dp.shape[0], -1) # (B, 128)
|
| 342 |
+
joined = mx.concatenate([sentence, style_flat], axis=-1) # (B, 192)
|
| 343 |
+
log_dur = self.tts.dp.predictor(joined).reshape(-1) # (B,)
|
| 344 |
+
return mx.exp(log_dur) # duration in seconds
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
__all__ = ["DurationPredictor"]
|
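The two scalar operations at the end of the duration predictor (the PReLU activation and the final `exp` that turns a log-duration into seconds) can be checked by hand. The following is a minimal pure-Python sketch, not part of the package: the helper names `prelu` and `seconds_from_log_duration` are hypothetical, and the slope of 0.25 is an arbitrary illustrative value (the real slope is loaded from `activation.weight`).

```python
import math


def prelu(x: float, slope: float) -> float:
    """Scalar PReLU: max(0, x) + slope * min(0, x)."""
    return max(0.0, x) + slope * min(0.0, x)


def seconds_from_log_duration(log_dur: float) -> float:
    """The predictor head emits a log-duration; exp() maps it to seconds."""
    return math.exp(log_dur)


# Hypothetical slope; the trained value lives in activation.weight.
SLOPE = 0.25

print(prelu(2.0, SLOPE))               # positive inputs pass through unchanged
print(prelu(-2.0, SLOPE))              # negative inputs are scaled by the slope
print(seconds_from_log_duration(0.0))  # a log-duration of 0 maps to 1 second
```

Because the output goes through `exp`, the predicted duration is always strictly positive, whatever the sign of the final linear layer's output.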
src/supertonic_3_mlx/pipeline.py
ADDED
|
@@ -0,0 +1,545 @@
"""Supertonic 3 end-to-end MLX pipeline.

Stitches the four MLX sub-models (DurationPredictor → TextEncoder →
VectorEstimator → Vocoder) into a single ``generate(text, voice, lang)`` call
that returns a 44.1 kHz mono numpy waveform.

Flow:

    text ──tokenize(unicode_indexer)──▶ text_ids (B, T_text)
        │
    voice_style (.json) ──▶ style_ttl (B, 50, 256), style_dp (B, 8, 16)
        │
    duration_predictor(text_ids, style_dp, text_mask) ──▶ duration_s (B,)
        │
    text_encoder(text_ids, style_ttl, text_mask) ──▶ text_emb (B, 256, T_text)
        │
    noise ~ N(0, I) of shape (B, 144, T_lat)
        where T_lat = ceil(duration_s × 44100 / (512 × 6))
        │
    vector_estimator 5-step Euler with CFG (4×cond − 3×uncond):
        for step in [0..4]:
            x ← VE(x, text_emb, style_ttl, masks, current_step=step+1, total_step=5)
        │
    vocoder(audio_latent) ──▶ wav (B, T_lat × 6 × 512)

Public API:

    pipe = SupertonicMLXPipeline.from_pretrained("/tmp/supertonic3/model")
    wav = pipe.generate("Hello world", voice="F1", lang="en")
    import soundfile as sf
    sf.write("out.wav", wav, pipe.sample_rate)
"""
from __future__ import annotations

import json
import math
from pathlib import Path
from typing import Optional

import mlx.core as mx
import numpy as np

from supertonic_3_mlx._config import SAMPLE_RATE
from supertonic_3_mlx.duration_predictor import DurationPredictor
from supertonic_3_mlx.text_encoder import TextEncoder
from supertonic_3_mlx.vector_estimator import VectorEstimator
from supertonic_3_mlx.vocoder import Vocoder


# Latent rate: at 44.1 kHz with hop=512 and chunk_compress=6, one latent step
# covers 512 × 6 = 3072 samples = 69.7 ms.
SAMPLES_PER_LATENT_STEP = 512 * 6  # 3072

# ── Shared ONNX → MLX weight extraction ─────────────────────────────


def _convert_onnx(onnx_path: str | Path) -> dict:
    """Return a dict of ``{clean_key: mx.array}`` for a Supertonic ONNX file.

    Combines the three extraction stages discovered during the per-component
    ports (T.3.1, T.3.2, T.3.3):

    1. Named ``tts.*`` initialisers with shape transforms (dwconv, gamma,
       pwconv, head.layer2).
    2. Anonymous MatMul weights recovered via the MatMul output path.
    3. Anonymous Conv weights and PReLU slopes recovered the same way.
    """
    import onnx
    import onnx.numpy_helper as nh

    m = onnx.load(str(onnx_path))

    def _matmul_clean(out_name: str) -> str:
        p = out_name.lstrip("/")
        if p.endswith("/MatMul_output_0"):
            p = p[: -len("/MatMul_output_0")]
        # Drop the leading model-name path (e.g. /text_encoder/, /duration_predictor/, /vector_estimator/)
        for prefix in ("text_encoder/", "duration_predictor/", "vector_estimator/", "vocoder/"):
            if p.startswith(prefix):
                p = p[len(prefix):]
                break
        return p.replace("/", ".") + ".weight"

    def _conv_clean(out_name: str) -> str:
        p = out_name.lstrip("/")
        if p.endswith("/Conv_output_0"):
            p = p[: -len("/Conv_output_0")]
        for prefix in ("vocoder/", "vector_estimator/", "text_encoder/", "duration_predictor/"):
            if p.startswith(prefix):
                p = p[len(prefix):]
                break
        return "tts.ae." + p.replace("/", ".")

    def _prelu_clean(out_name: str) -> str:
        p = out_name.lstrip("/")
        if p.endswith("/PRelu_output_0"):
            p = p[: -len("/PRelu_output_0")]
        for prefix in ("vocoder/", "vector_estimator/"):
            if p.startswith(prefix):
                p = p[len(prefix):]
                break
        return "tts.ae." + p.replace("/", ".") + ".weight"

    # Detect which model this file is — affects how we wrap named init keys
    name_prefixes = {init.name.split(".")[0] for init in m.graph.initializer if "." in init.name}
    is_text_encoder = "tts" in name_prefixes and any(
        i.name.startswith("tts.ttl.text_encoder") for i in m.graph.initializer
    )

    weights: dict[str, mx.array] = {}

    # Stage 1: named initialisers
    for init in m.graph.initializer:
        n = init.name
        # Determine if this is a structured (named) weight or an anonymous graph const
        if not (n.startswith("tts.") or "vector_estimator.tts.ttl." in n or "uncond_masker." in n):
            continue

        # Strip the vector_estimator-specific prefix so all 4 models share a name space.
        if n.startswith("vector_estimator.tts.ttl."):
            clean = n[len("vector_estimator.tts.ttl."):]
        else:
            clean = n

        arr = nh.to_array(init)

        # Shape transforms
        if (clean.endswith(".dwconv.weight") and arr.ndim == 3
                and arr.shape[1] == 1 and arr.shape[2] != 1):
            arr = np.transpose(arr, (0, 2, 1))
        if (clean.endswith(".dwconv.net.weight") and arr.ndim == 3
                and arr.shape[1] == 1):
            arr = np.transpose(arr, (0, 2, 1))
        if (clean.endswith(".gamma") and arr.ndim == 3
                and arr.shape[0] == 1 and arr.shape[2] == 1):
            arr = arr.reshape(arr.shape[1])
        if ((clean.endswith(".pwconv1.weight") or clean.endswith(".pwconv2.weight"))
                and arr.ndim == 3 and arr.shape[-1] == 1):
            arr = arr.squeeze(-1)
        if clean.endswith(".net.weight") and arr.ndim == 3 and arr.shape[-1] == 1:
            # Conv1d k=1 wrapped via .net (e.g. proj_in/proj_out)
            arr = arr.squeeze(-1)
        # vocoder head.layer2 (out, in, 1) → MLX Conv1d (out, K=1, in)
        if clean == "tts.ae.decoder.head.layer2.weight" and arr.ndim == 3:
            arr = np.transpose(arr, (0, 2, 1))
        # vocoder head.layer1.net.weight (out, in, K) → MLX Conv1d (out, K, in)
        if clean == "tts.ae.decoder.head.layer1.net.weight" and arr.ndim == 3:
            arr = np.transpose(arr, (0, 2, 1))

        weights[clean] = mx.array(arr)

    # Stage 2: MatMul weight recovery
    inits_map = {init.name: init for init in m.graph.initializer}
    for node in m.graph.node:
        if node.op_type != "MatMul" or len(node.input) < 2:
            continue
        winp = node.input[1]
        if winp not in inits_map or winp.startswith("tts.") or "vector_estimator.tts" in winp:
            continue
        arr = nh.to_array(inits_map[winp])
        if arr.ndim == 2:
            arr = arr.T  # ONNX (in, out) → MLX Linear (out, in)
        clean = _matmul_clean(node.output[0])
        # Build the leading namespace from the file context (already in tts.*)
        if not clean.startswith(("tts.", "vector_field.", "uncond_masker.")):
            clean = "tts.ttl." + clean if is_text_encoder else clean
        weights[clean] = mx.array(arr)

    # Stage 3: anonymous Conv + PReLU (vocoder embed / head)
    for node in m.graph.node:
        if node.op_type == "Conv":
            for i, inp in enumerate(node.input[1:], 1):
                if inp not in inits_map or inp.startswith("tts."):
                    continue
                arr = nh.to_array(inits_map[inp])
                base = _conv_clean(node.output[0])
                if "dwconv" in base:
                    continue
                if i == 1 and arr.ndim == 3:
                    arr = np.transpose(arr, (0, 2, 1))  # ONNX (out, in, K) → MLX (out, K, in)
                key = base + (".weight" if i == 1 else ".bias")
                weights[key] = mx.array(arr)
        elif node.op_type == "PRelu":
            for inp in node.input[1:]:
                if inp in inits_map and not inp.startswith("tts."):
                    weights[_prelu_clean(node.output[0])] = mx.array(nh.to_array(inits_map[inp]))

    return weights


def _load_into(model, weights: dict) -> int:
    """Match converted weights to model params (shape-tolerant via reshape).

    Returns the number of successfully matched tensors.
    """
    from mlx.utils import tree_flatten
    expected = {k: tuple(v.shape) for k, v in tree_flatten(model.parameters())}
    matched = {}
    for k, exp_shape in expected.items():
        if k not in weights:
            continue
        v = weights[k]
        if tuple(v.shape) != exp_shape:
            if v.size == np.prod(exp_shape):
                v = v.reshape(exp_shape)
            else:
                continue
        matched[k] = v
    model.load_weights(list(matched.items()), strict=False)
    return len(matched)

# ── Tokenization ────────────────────────────────────────────────────


def _encode_text(text: str, indexer: list[int], lang: str = "en") -> np.ndarray:
    """Encode a text string into character IDs.

    The unicode_indexer is a flat list of size 65536; ``indexer[ord(c)]`` gives
    the token ID for character ``c`` (-1 = unknown). For Phase T.4 we wrap the
    text with no special language tokens — the ONNX SDK uses language tags but
    our pipeline currently runs unconditioned on language for the first WAV
    emission (parity validation happens after).
    """
    ids = []
    for c in text:
        cp = ord(c)
        if 0 <= cp < len(indexer):
            tok = indexer[cp]
            if tok >= 0:
                ids.append(tok)
    if not ids:
        # fallback to a single space token to avoid empty input
        ids = [indexer[ord(" ")]] if indexer[ord(" ")] >= 0 else [0]
    return np.asarray(ids, dtype=np.int32)

| 239 |
+
# ── Pipeline ────────────────────────────────────────────────────────
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
class SupertonicMLXPipeline:
|
| 243 |
+
"""End-to-end Supertonic 3 TTS pipeline in pure MLX.
|
| 244 |
+
|
| 245 |
+
Loads four sub-models (duration_predictor, text_encoder, vector_estimator,
|
| 246 |
+
vocoder), the unicode tokenizer, and exposes ``generate(text, voice, lang)``.
|
| 247 |
+
"""
|
| 248 |
+
|
| 249 |
+
sample_rate: int = SAMPLE_RATE
|
| 250 |
+
# Locked by the model architecture: Supertonic 3 is a flow-matching + CFG
|
| 251 |
+
# model trained for exactly 5 Euler steps with t ∈ {0.2, 0.4, 0.6, 0.8, 1.0}
|
| 252 |
+
# and the combination 4×cond − 3×uncond. Any other step count or skipping
|
| 253 |
+
# CFG produces an essentially uncorrelated waveform (verified by
|
| 254 |
+
# ``sub-projects/supertonic3-mlx/bench_n_steps.py``: cosine drops to
|
| 255 |
+
# ≤ 0.5 for n∈{3,4,6} and ≈ 0.05 for cfg=False). Reducing inference
|
| 256 |
+
# latency further would require distilling a shorter-schedule model.
|
| 257 |
+
n_euler_steps: int = 5
|
| 258 |
+
|
| 259 |
+
def __init__(
|
| 260 |
+
self,
|
| 261 |
+
duration_predictor: DurationPredictor,
|
| 262 |
+
text_encoder: TextEncoder,
|
| 263 |
+
vector_estimator: VectorEstimator,
|
| 264 |
+
vocoder: Vocoder,
|
| 265 |
+
unicode_indexer: list[int],
|
| 266 |
+
voice_dir: Path,
|
| 267 |
+
) -> None:
|
| 268 |
+
self.duration_predictor = duration_predictor
|
| 269 |
+
self.text_encoder = text_encoder
|
| 270 |
+
self.vector_estimator = vector_estimator
|
| 271 |
+
self.vocoder = vocoder
|
| 272 |
+
self.unicode_indexer = unicode_indexer
|
| 273 |
+
self.voice_dir = voice_dir
|
| 274 |
+
|
| 275 |
+
# T.5 — compile the hot loops. ``mx.compile`` caches a kernel graph keyed
|
| 276 |
+
# by input shapes; the 5× CFG Euler loop and the single vocoder pass
|
| 277 |
+
# both gain from fused kernel dispatch (~50–100 layer ops collapse into
|
| 278 |
+
# one dispatch per cached graph).
|
| 279 |
+
|
| 280 |
+
# T.5.3 — also pre-project text and style K/V outside the step. They
|
| 281 |
+
# are invariant across the 5 Euler steps, so the 4 text_attn + 4
|
| 282 |
+
# style_attn blocks no longer re-run their W_key / W_value / RoPE_K
|
| 283 |
+
# matmuls on every step (saves 40 matmuls per generate).
|
| 284 |
+
cond_scale = self.vector_estimator.CFG_COND_SCALE
|
| 285 |
+
uncond_scale = self.vector_estimator.CFG_UNCOND_SCALE
|
| 286 |
+
|
| 287 |
+
def _cached_step(
|
| 288 |
+
noisy, lat_mask_2, text_mask_2, t_norm_2, total_step, kv_flat,
|
| 289 |
+
):
|
| 290 |
+
noisy_2 = mx.concatenate([noisy, noisy], axis=0)
|
| 291 |
+
text_kv = [(kv_flat[2 * i], kv_flat[2 * i + 1]) for i in range(4)]
|
| 292 |
+
style_kv = [(kv_flat[8 + 2 * i], kv_flat[8 + 2 * i + 1]) for i in range(4)]
|
| 293 |
+
v_2 = self.vector_estimator.velocity_cached(
|
| 294 |
+
noisy_2, lat_mask_2, text_mask_2, t_norm_2, text_kv, style_kv,
|
| 295 |
+
)
|
| 296 |
+
B = noisy.shape[0]
|
| 297 |
+
cond_v = v_2[:B]
|
| 298 |
+
uncond_v = v_2[B:]
|
| 299 |
+
combined = cond_scale * cond_v - uncond_scale * uncond_v
|
| 300 |
+
return noisy + combined / total_step.reshape(-1, 1, 1).astype(combined.dtype)
|
| 301 |
+
|
| 302 |
+
def _voc_step(latent):
|
| 303 |
+
return self.vocoder(latent)
|
| 304 |
+
|
| 305 |
+
self._cached_step_compiled = mx.compile(_cached_step)
|
| 306 |
+
self._voc_compiled = mx.compile(_voc_step)
|
| 307 |
+
|
| 308 |
+
# Pick the runtime dtype from any leaf weight of the vector estimator —
|
| 309 |
+
# ``from_pretrained(dtype=...)`` may have cast the model to ``bf16``,
|
| 310 |
+
# in which case all inputs to the compiled hot loops must be cast to
|
| 311 |
+
# match (mixed-dtype Conv/MatMul is not legal in MLX).
|
| 312 |
+
from mlx.utils import tree_flatten
|
| 313 |
+
leaves = [v for _, v in tree_flatten(vector_estimator.parameters())
|
| 314 |
+
if isinstance(v, mx.array)]
|
| 315 |
+
self.dtype = leaves[0].dtype if leaves else mx.float32
|
| 316 |
+
|
| 317 |
+
@classmethod
|
| 318 |
+
def from_pretrained(
|
| 319 |
+
cls,
|
| 320 |
+
model_id_or_path: str | Path,
|
| 321 |
+
dtype: mx.Dtype | None = None,
|
| 322 |
+
cache_dir: str | Path | None = None,
|
| 323 |
+
revision: str | None = None,
|
| 324 |
+
) -> "SupertonicMLXPipeline":
|
| 325 |
+
"""Construct the pipeline from a model snapshot.
|
| 326 |
+
|
| 327 |
+
Three sources are accepted, auto-detected:
|
| 328 |
+
|
| 329 |
+
1. **Hugging Face Hub repo id** (e.g. ``"ambassadia/supertonic-3-mlx"``):
|
| 330 |
+
weights are downloaded via :func:`huggingface_hub.snapshot_download`
|
| 331 |
+
into ``cache_dir`` (defaults to the standard HF cache) and loaded
|
| 332 |
+
directly from the bundled ``weights/*.safetensors`` files.
|
| 333 |
+
2. **Local path with a** ``weights/`` **subdir**: the MLX-native
|
| 334 |
+
layout (4 safetensors + ``unicode_indexer.json`` + ``voice_styles/``).
|
| 335 |
+
Fast path — no ONNX conversion at runtime.
|
| 336 |
+
3. **Local path with an** ``onnx/`` **subdir**: the upstream
|
| 337 |
+
``Supertone/supertonic-3`` snapshot layout. Weights are converted
|
| 338 |
+
from ONNX on the fly (~ 1 s per sub-model on M4). Useful for
|
| 339 |
+
development or when starting from the original upstream release.
|
| 340 |
+
|
| 341 |
+
Optional kwargs:
|
| 342 |
+
dtype — if non-None and not float32, cast all weights to the
|
| 343 |
+
given dtype after load (only ``mx.bfloat16`` is
|
| 344 |
+
currently meaningful; see README "BF16 note").
|
| 345 |
+
cache_dir — passed to ``huggingface_hub.snapshot_download``.
|
| 346 |
+
revision — branch / tag / commit sha on the Hub.
|
| 347 |
+
"""
|
| 348 |
+
# 1. Resolve the local snapshot directory
|
| 349 |
+
if isinstance(model_id_or_path, str) and "/" in model_id_or_path \
|
| 350 |
+
and not Path(model_id_or_path).exists():
|
| 351 |
+
try:
|
| 352 |
+
from huggingface_hub import snapshot_download
|
| 353 |
+
except ImportError as e:
|
| 354 |
+
raise ImportError(
|
| 355 |
+
"Loading from the Hugging Face Hub requires "
|
| 356 |
+
"``huggingface_hub`` — install with ``pip install "
|
| 357 |
+
"supertonic-3-mlx[hub]`` or ``pip install huggingface_hub``."
|
| 358 |
+
) from e
|
| 359 |
+
local_dir = Path(snapshot_download(
|
| 360 |
+
repo_id=model_id_or_path,
|
| 361 |
+
cache_dir=cache_dir,
|
| 362 |
+
revision=revision,
|
| 363 |
+
allow_patterns=[
|
| 364 |
+
"weights/*.safetensors",
|
| 365 |
+
"unicode_indexer.json",
|
| 366 |
+
"voice_styles/*.json",
|
| 367 |
+
],
|
| 368 |
+
))
|
| 369 |
+
else:
|
| 370 |
+
local_dir = Path(model_id_or_path)
|
| 371 |
+
|
| 372 |
+
# 2. Detect layout
|
| 373 |
+
weights_dir = local_dir / "weights"
|
| 374 |
+
onnx_dir = local_dir / "onnx"
|
| 375 |
+
if weights_dir.exists():
|
| 376 |
+
return cls._from_safetensors(local_dir, dtype=dtype)
|
| 377 |
+
if onnx_dir.exists():
|
| 378 |
+
return cls._from_onnx(local_dir, dtype=dtype)
|
| 379 |
+
raise FileNotFoundError(
|
| 380 |
+
f"{local_dir} contains neither ``weights/`` (safetensors layout) "
|
| 381 |
+
f"nor ``onnx/`` (upstream layout); cannot load."
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
@classmethod
|
| 385 |
+
def _from_safetensors(
|
| 386 |
+
cls, local_dir: Path, dtype: mx.Dtype | None = None,
|
| 387 |
+
) -> "SupertonicMLXPipeline":
|
| 388 |
+
from mlx.utils import tree_flatten
|
| 389 |
+
weights_dir = local_dir / "weights"
|
| 390 |
+
voice_dir = local_dir / "voice_styles"
|
| 391 |
+
unicode_indexer = json.loads((local_dir / "unicode_indexer.json").read_text())
|
| 392 |
+
|
| 393 |
+
def _build(cls_, name):
|
| 394 |
+
model = cls_()
|
| 395 |
+
w = mx.load(str(weights_dir / f"{name}.safetensors"))
|
| 396 |
+
# Reshape any mismatched leaves (defensive; the converter already
|
| 397 |
+
# produced shape-correct tensors but a future re-export may not).
|
| 398 |
+
expected = {k: tuple(v.shape) for k, v in tree_flatten(model.parameters())}
|
| 399 |
+
for k in list(w.keys()):
|
| 400 |
+
if k in expected and tuple(w[k].shape) != expected[k]:
|
| 401 |
+
if w[k].size == int(np.prod(expected[k])):
|
| 402 |
+
w[k] = w[k].reshape(expected[k])
|
| 403 |
+
model.load_weights(list(w.items()), strict=False)
|
| 404 |
+
return model
|
| 405 |
+
|
| 406 |
+
ve = _build(VectorEstimator, "vector_estimator")
|
| 407 |
+
te = _build(TextEncoder, "text_encoder")
|
| 408 |
+
dp = _build(DurationPredictor, "duration_predictor")
|
| 409 |
+
voc = _build(Vocoder, "vocoder")
|
| 410 |
+
|
| 411 |
+
if dtype is not None and dtype != mx.float32:
|
| 412 |
+
cls._cast_all(dp, te, ve, voc, dtype=dtype)
|
| 413 |
+
|
| 414 |
+
return cls(dp, te, ve, voc, unicode_indexer, voice_dir)
|
| 415 |
+
|
| 416 |
+
@classmethod
|
| 417 |
+
def _from_onnx(
|
| 418 |
+
cls, local_dir: Path, dtype: mx.Dtype | None = None,
|
| 419 |
+
) -> "SupertonicMLXPipeline":
|
| 420 |
+
onnx_dir = local_dir / "onnx"
|
| 421 |
+
voice_dir = local_dir / "voice_styles"
|
| 422 |
+
unicode_indexer = json.loads((onnx_dir / "unicode_indexer.json").read_text())
|
| 423 |
+
|
| 424 |
+
ve = VectorEstimator()
|
| 425 |
+
_load_into(ve, _convert_onnx(onnx_dir / "vector_estimator.onnx"))
|
| 426 |
+
te = TextEncoder()
|
| 427 |
+
_load_into(te, _convert_onnx(onnx_dir / "text_encoder.onnx"))
|
| 428 |
+
dp = DurationPredictor()
|
| 429 |
+
_load_into(dp, _convert_onnx(onnx_dir / "duration_predictor.onnx"))
|
| 430 |
+
voc = Vocoder()
|
| 431 |
+
_load_into(voc, _convert_onnx(onnx_dir / "vocoder.onnx"))
|
| 432 |
+
|
| 433 |
+
if dtype is not None and dtype != mx.float32:
|
| 434 |
+
cls._cast_all(dp, te, ve, voc, dtype=dtype)
|
| 435 |
+
|
| 436 |
+
return cls(dp, te, ve, voc, unicode_indexer, voice_dir)
|
| 437 |
+
|
| 438 |
+
@staticmethod
|
| 439 |
+
def _cast_all(*models, dtype: mx.Dtype) -> None:
|
| 440 |
+
"""Cast all fp32 leaves of each model to ``dtype`` (in-place)."""
|
| 441 |
+
from mlx.utils import tree_map
|
| 442 |
+
|
| 443 |
+
def _cast(p):
|
| 444 |
+
if not isinstance(p, mx.array) or p.dtype != mx.float32:
|
| 445 |
+
return p
|
| 446 |
+
return p.astype(dtype)
|
| 447 |
+
|
| 448 |
+
for m_ in models:
|
| 449 |
+
m_.update(tree_map(_cast, m_.parameters()))
|
| 450 |
+
|
| 451 |
+
def _load_voice(self, voice: str) -> tuple[mx.array, mx.array]:
|
| 452 |
+
"""Load ``voice_styles/<voice>.json`` and return (style_ttl, style_dp)."""
|
| 453 |
+
path = self.voice_dir / f"{voice}.json"
|
| 454 |
+
data = json.loads(path.read_text())
|
| 455 |
+
style_ttl = np.asarray(data["style_ttl"]["data"], dtype=np.float32) # (1, 50, 256)
|
| 456 |
+
style_dp = np.asarray(data["style_dp"]["data"], dtype=np.float32) # (1, 8, 16)
|
| 457 |
+
return mx.array(style_ttl), mx.array(style_dp)
|
| 458 |
+
|
| 459 |
+
def generate(
|
| 460 |
+
self,
|
| 461 |
+
text: str,
|
| 462 |
+
voice: str = "F1",
|
| 463 |
+
lang: str = "en",
|
| 464 |
+
seed: int = 42,
|
| 465 |
+
n_steps: Optional[int] = None,
|
| 466 |
+
) -> np.ndarray:
|
| 467 |
+
"""Synthesise a single utterance. Returns a 1D float32 numpy waveform."""
|
| 468 |
+
n_steps = n_steps if n_steps is not None else self.n_euler_steps
|
| 469 |
+
|
| 470 |
+
# Tokenize
|
| 471 |
+
text_ids_np = _encode_text(text, self.unicode_indexer, lang)
|
| 472 |
+
text_ids = mx.array(text_ids_np[None, :]) # (1, T_text)
|
| 473 |
+
T_text = text_ids.shape[1]
|
| 474 |
+
text_mask = mx.ones((1, 1, T_text), dtype=self.dtype)
|
| 475 |
+
|
| 476 |
+
# Style
|
| 477 |
+
style_ttl, style_dp = self._load_voice(voice)
|
| 478 |
+
if self.dtype != mx.float32:
|
| 479 |
+
style_ttl = style_ttl.astype(self.dtype)
|
| 480 |
+
style_dp = style_dp.astype(self.dtype)
|
| 481 |
+
|
| 482 |
+
# Duration → latent length
|
| 483 |
+
duration_s = self.duration_predictor(text_ids, style_dp, text_mask)
|
| 484 |
+
mx.eval(duration_s)
|
| 485 |
+
duration_val = max(float(duration_s[0].item()), 0.5) # clamp to ≥ 0.5 s
|
| 486 |
+
T_lat = max(int(math.ceil(duration_val * self.sample_rate / SAMPLES_PER_LATENT_STEP)), 1)
|
| 487 |
+
|
| 488 |
+
# Text embedding
|
| 489 |
+
        text_emb = self.text_encoder(text_ids, style_ttl, text_mask)  # (1, 256, T_text)

        # Initial noise: fixed seed for reproducibility
        key = mx.random.key(seed)
        noise = mx.random.normal((1, 144, T_lat), key=key).astype(self.dtype)
        latent_mask = mx.ones((1, 1, T_lat), dtype=self.dtype)

        # T.5.3: build the (2B) CFG conditioning tensors once and pre-project
        # K/V for every text_attn / style_attn block. ``kv_flat`` is the 16
        # ``(K, V)`` arrays flattened into a list for the compiled step.
        B = noise.shape[0]
        ve = self.vector_estimator
        text_uncond = mx.broadcast_to(
            ve.uncond_masker.text_special_token, (B, text_emb.shape[1], text_emb.shape[2])
        ).astype(self.dtype)
        style_k_uncond = mx.broadcast_to(
            ve.uncond_masker.style_key_special_token, (B, style_ttl.shape[1], style_ttl.shape[2])
        ).astype(self.dtype)
        style_v_uncond = mx.broadcast_to(
            ve.uncond_masker.style_value_special_token, (B, style_ttl.shape[1], style_ttl.shape[2])
        ).astype(self.dtype)
        text_emb_2 = mx.concatenate([text_emb, text_uncond], axis=0)
        style_k_2 = mx.concatenate([style_ttl, style_k_uncond], axis=0)
        style_v_2 = mx.concatenate([style_ttl, style_v_uncond], axis=0)
        text_mask_2 = mx.concatenate([text_mask, text_mask], axis=0)
        latent_mask_2 = mx.concatenate([latent_mask, latent_mask], axis=0)

        text_kv, style_kv = ve.precompute_cross_kv(
            text_emb_2, style_k_2, style_v_2, text_mask_2,
        )
        kv_flat = []
        for k, v in text_kv:
            kv_flat.extend([k, v])
        for k, v in style_kv:
            kv_flat.extend([k, v])

        # Euler with CFG: 5 steps by default
        x = noise
        total_step = mx.array([float(n_steps)], dtype=self.dtype)
        for step in range(n_steps):
            current_step = mx.array([float(step + 1)], dtype=self.dtype)
            t_norm = current_step / total_step
            t_norm_2 = mx.concatenate([t_norm, t_norm], axis=0)
            x = self._cached_step_compiled(
                x, latent_mask_2, text_mask_2, t_norm_2, total_step, kv_flat,
            )
            mx.eval(x)

        # Decode latent → waveform
        wav = self._voc_compiled(x)
        mx.eval(wav)
        if wav.dtype != mx.float32:
            wav = wav.astype(mx.float32)
        return np.array(wav)[0]  # (T_lat × 6 × 512,)


__all__ = ["SupertonicMLXPipeline"]
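
The sampling loop above takes `n_steps` Euler steps of size `1 / total_step` and evaluates the conditional and unconditional branches as one batch of two. The NumPy sketch below illustrates that pattern only. `toy_velocity` is a hypothetical stand-in for the vector estimator, and the guidance rule `v_uncond + cfg_scale * (v_cond - v_uncond)` is the standard classifier-free-guidance combination, assumed here because the real combination lives inside `_cached_step_compiled`, which this listing does not show.

```python
import numpy as np


def toy_velocity(x, cond):
    """Hypothetical velocity field that pulls x toward `cond` (stand-in only)."""
    return cond - x


def euler_cfg_sample(noise, cond, uncond, n_steps=5, cfg_scale=2.0):
    """Integrate dx/dt = v(x) with n_steps Euler steps of size 1 / n_steps."""
    x = noise
    for _ in range(n_steps):
        # Stack conditional and unconditional inputs on the batch axis (2B).
        x2 = np.concatenate([x, x], axis=0)
        c2 = np.concatenate([cond, uncond], axis=0)
        v_cond, v_uncond = np.split(toy_velocity(x2, c2), 2, axis=0)
        # Assumed guidance rule (standard CFG); not taken from the source.
        v = v_uncond + cfg_scale * (v_cond - v_uncond)
        x = x + v / n_steps
    return x


rng = np.random.default_rng(0)
noise = rng.standard_normal((1, 144, 8)).astype(np.float32)
cond = np.ones_like(noise)
uncond = np.zeros_like(noise)
# With cfg_scale=1 the guided velocity reduces to the conditional one.
out = euler_cfg_sample(noise, cond, uncond, n_steps=5, cfg_scale=1.0)
```

For this toy field each step shrinks the distance to `cond` by a factor of `1 - 1/n_steps`, which makes the effect of the step count easy to reason about.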
src/supertonic_3_mlx/text_encoder.py
ADDED
@@ -0,0 +1,382 @@
"""Supertonic 3 text encoder MLX port.

Pipeline (operating in channels-last NTC after the initial conv):

    text_ids [B, T_text]                          int64 character IDs
      → char_embedder (Embedding 8322→256)        [B, T_text, 256]
      → 6× ConvNeXt(dim=256, hidden=1024, k=5, dilations [1,1,2,2,4,4])
      → 4× attn_encoder block:
          RelPosSelfAttn (conv_q/k/v/o, 4 heads × 64) + norm_layers_1
          FFN (conv_1: 256→1024, conv_2: 1024→256) + norm_layers_2
      → speech_prompted_text_encoder:
          cross-attn1: text (Q) × style_ttl (K, V) → text features
          cross-attn2: text (Q) × style_ttl (K, V) → text features
          norm
      → output text_emb [B, 256, T_text] (channels-first to match vector_estimator)

Inputs:
    text_ids:  (B, T_text) int, character indices
    style_ttl: (B, 50, 256) float, style token bank
    text_mask: (B, 1, T_text) float, 1.0 where valid, 0.0 where padded

Submodule naming matches the ONNX initializer keys exactly so that
``model.load_weights(...)`` succeeds with no remapping.
"""
from __future__ import annotations

import mlx.core as mx
import mlx.nn as nn

from supertonic_3_mlx._config import EPS_LN
from supertonic_3_mlx._nn_wrappers import WrappedNorm, WrappedLinear
from supertonic_3_mlx.vector_estimator import (
    ConvNeXtBlock, _pad_sym_edge, _gelu_exact,
)


# Vocab + dims (frozen by checkpoint)
VOCAB_SIZE = 8322
TE_DIM = 256
TE_CONVNEXT_HIDDEN = 1024
TE_CONVNEXT_K = 5
TE_CONVNEXT_NUM_LAYERS = 6
TE_CONVNEXT_DILATIONS = (1, 1, 2, 2, 4, 4)

TE_ATTN_NUM_LAYERS = 4
TE_ATTN_HEADS = 4
TE_ATTN_HEAD_DIM = TE_DIM // TE_ATTN_HEADS  # 64
TE_FFN_HIDDEN = 1024


class TextConvNeXtBlock(nn.Module):
    """ConvNeXt for the text encoder (dim=256, hidden=1024).

    Shares the same architecture as ``vector_estimator.ConvNeXtBlock`` but is
    redefined here with text-encoder-specific defaults to keep the modules
    self-contained.
    """

    def __init__(self, dilation: int = 1) -> None:
        super().__init__()
        self.dim = TE_DIM
        self.dilation = dilation
        self.pad = dilation * (TE_CONVNEXT_K - 1) // 2
        self.dwconv = nn.Conv1d(
            TE_DIM, TE_DIM, kernel_size=TE_CONVNEXT_K, padding=0,
            dilation=dilation, groups=TE_DIM, bias=True,
        )
        self.norm = WrappedNorm(TE_DIM, eps=EPS_LN)
        self.pwconv1 = nn.Linear(TE_DIM, TE_CONVNEXT_HIDDEN, bias=True)
        self.pwconv2 = nn.Linear(TE_CONVNEXT_HIDDEN, TE_DIM, bias=True)
        self.gamma = mx.zeros((TE_DIM,))

    def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
        # x: (B, T_text, 256)
        residual = x
        y = _pad_sym_edge(x, self.pad)
        y = self.dwconv(y)
        y = self.norm(y)
        y = self.pwconv1(y)
        y = _gelu_exact(y)
        y = self.pwconv2(y)
        y = y * self.gamma
        out = residual + y
        if mask is not None:
            out = out * mask
        return out


class TextConvNeXtStack(nn.Module):
    """6 stacked ConvNeXt blocks. Loaded as ``convnext.convnext.[0..5].X``."""

    def __init__(self) -> None:
        super().__init__()
        self.convnext = [TextConvNeXtBlock(d) for d in TE_CONVNEXT_DILATIONS]

    def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
        for b in self.convnext:
            x = b(x, mask)
        return x


class _ConvLayer(nn.Module):
    """Conv1d k=1 for ONNX-style ``X.weight (out, in, 1) + X.bias`` parameters.

    The attn_encoder uses Conv1d k=1 instead of nn.Linear for its Q/K/V/O.
    This wrapper holds the weight in MLX's conv layout (out, 1, in), i.e. the
    ONNX (out, in, 1) tensor with its last two axes swapped, and runs it as a
    Conv1d (the equivalent of a Linear when k=1).
    """

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.weight = mx.zeros((out_dim, 1, in_dim))  # (C_out, K=1, C_in)
        self.bias = mx.zeros((out_dim,))

    def __call__(self, x: mx.array) -> mx.array:
        # x: (B, T, in_dim), channels-last
        # equivalent to nn.Conv1d(in_dim, out_dim, k=1) in NTC layout
        return mx.conv1d(x, self.weight, stride=1, padding=0) + self.bias
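
The claim that a kernel-size-1 convolution is equivalent to a linear layer can be checked numerically. The sketch below uses NumPy rather than MLX, and `pointwise_conv_ntc` is a hypothetical helper written for this illustration, not part of the package. It assumes the two weight layouts mentioned in the docstring: ONNX-style `(out, in, 1)` and MLX-style `(out, 1, in)`.

```python
import numpy as np


def pointwise_conv_ntc(x, weight, bias):
    """k=1 convolution in NTC layout; weight is (C_out, 1, C_in), bias is (C_out,)."""
    # With a single tap, the convolution is a per-timestep matrix product.
    return np.einsum("btc,oc->bto", x, weight[:, 0, :]) + bias


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 7, 16))
w_onnx = rng.standard_normal((32, 16, 1))  # ONNX-style (out, in, 1)
w_mlx = w_onnx.transpose(0, 2, 1)          # MLX-style (out, 1, in)
b = rng.standard_normal(32)

y_conv = pointwise_conv_ntc(x, w_mlx, b)
y_linear = x @ w_onnx[:, :, 0].T + b       # the same map written as a Linear
```

Because the kernel axis has length 1, moving between the two layouts only permutes axes and never mixes weights.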


REL_POS_WINDOW = 4  # rel_pos table size = 2*4 + 1 = 9


def _rel_to_abs(x: mx.array) -> mx.array:
    """[B, h, L, 2L-1] → [B, h, L, L] via the VITS shifted-skew reshape."""
    B, h, L, _ = x.shape
    x = mx.concatenate([x, mx.zeros((B, h, L, 1), dtype=x.dtype)], axis=-1)
    x_flat = x.reshape(B, h, L * 2 * L)
    x_flat = mx.concatenate([x_flat, mx.zeros((B, h, L - 1), dtype=x.dtype)], axis=-1)
    x_final = x_flat.reshape(B, h, L + 1, 2 * L - 1)
    return x_final[:, :, :L, L - 1:]


def _abs_to_rel(x: mx.array) -> mx.array:
    """[B, h, L, L] → [B, h, L, 2L-1] (inverse of _rel_to_abs)."""
    B, h, L, _ = x.shape
    x = mx.concatenate([x, mx.zeros((B, h, L, L - 1), dtype=x.dtype)], axis=-1)
    x_flat = x.reshape(B, h, L * (2 * L - 1))
    x_flat = mx.concatenate([mx.zeros((B, h, L), dtype=x.dtype), x_flat], axis=-1)
    x_final = x_flat.reshape(B, h, L, 2 * L)
    return x_final[:, :, :, 1:]
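
The pad-and-reshape trick in the two helpers above is compact but hard to read. The NumPy sketch below mirrors the same operations and compares `rel_to_abs` against an explicit index loop, where entry `(i, j)` of the absolute matrix is taken from relative offset `j - i`, stored at column `(j - i) + (L - 1)`. The function names are local to this example.

```python
import numpy as np


def rel_to_abs(x):
    """[B, h, L, 2L-1] -> [B, h, L, L], same pad/reshape steps as above."""
    B, h, L, _ = x.shape
    x = np.concatenate([x, np.zeros((B, h, L, 1), x.dtype)], axis=-1)
    flat = x.reshape(B, h, L * 2 * L)
    flat = np.concatenate([flat, np.zeros((B, h, L - 1), x.dtype)], axis=-1)
    return flat.reshape(B, h, L + 1, 2 * L - 1)[:, :, :L, L - 1:]


def abs_to_rel(x):
    """[B, h, L, L] -> [B, h, L, 2L-1], same pad/reshape steps as above."""
    B, h, L, _ = x.shape
    x = np.concatenate([x, np.zeros((B, h, L, L - 1), x.dtype)], axis=-1)
    flat = x.reshape(B, h, L * (2 * L - 1))
    flat = np.concatenate([np.zeros((B, h, L), x.dtype), flat], axis=-1)
    return flat.reshape(B, h, L, 2 * L)[:, :, :, 1:]


def rel_to_abs_naive(x):
    """Reference loop: out[i, j] = x[i, (j - i) + (L - 1)]."""
    B, h, L, _ = x.shape
    out = np.zeros((B, h, L, L), x.dtype)
    for i in range(L):
        for j in range(L):
            out[:, :, i, j] = x[:, :, i, j - i + L - 1]
    return out


rng = np.random.default_rng(0)
L = 5
rel = rng.standard_normal((2, 4, L, 2 * L - 1))
absolute = rng.standard_normal((2, 4, L, L))

skewed = rel_to_abs(rel)
round_trip = rel_to_abs(abs_to_rel(absolute))
```

Going from absolute to relative and back recovers the input exactly. The reverse order does not, because relative columns that point outside `[0, L)` are discarded by `rel_to_abs`.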


def _slice_rel_emb(rel: mx.array, length: int, window: int) -> mx.array:
    """``rel`` (1, 2W+1, d) → (1, 2L-1, d) by zero-padding/slicing."""
    pad_l = max(length - (window + 1), 0)
    if pad_l > 0:
        zero = mx.zeros((1, pad_l, rel.shape[-1]), dtype=rel.dtype)
        padded = mx.concatenate([zero, rel, zero], axis=1)
    else:
        padded = rel
    start = max(window + 1 - length, 0)
    return padded[:, start: start + 2 * length - 1]
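
A small worked example makes the two regimes of this helper concrete: sequences shorter than the window slice the centre of the table, and longer sequences zero-pad it on both sides. The sketch below is a NumPy transcription with a table whose entries are simply 1..9, so the selected offsets are visible. The values are illustrative, not real embeddings.

```python
import numpy as np


def slice_rel_emb(rel, length, window):
    """(1, 2W+1, d) -> (1, 2L-1, d) by zero-padding or slicing."""
    pad_l = max(length - (window + 1), 0)
    if pad_l > 0:
        zero = np.zeros((1, pad_l, rel.shape[-1]), rel.dtype)
        rel = np.concatenate([zero, rel, zero], axis=1)
    start = max(window + 1 - length, 0)
    return rel[:, start:start + 2 * length - 1]


W = 4
# Entry 5 (index 4) corresponds to relative offset 0.
table = np.arange(1, 2 * W + 2, dtype=np.float32).reshape(1, 2 * W + 1, 1)

short = slice_rel_emb(table, 3, W)[0, :, 0]  # L=3: offsets -2..2 are kept
long = slice_rel_emb(table, 7, W)[0, :, 0]   # L=7: offsets beyond ±4 become zero
```

Offsets beyond the window contribute nothing to the relative-position terms, so those terms see at most `window` positions on either side of a token.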


class RelPosSelfAttention(nn.Module):
    """VITS-style relative-position self-attention with window=4.

    Adds two contributions to vanilla MHA:
      - ``rel_logits = q @ rel_k.T`` then ``_rel_to_abs`` and added to attention logits
      - ``rel_attn = _abs_to_rel(softmax(logits))`` then ``@ rel_v`` and added to output

    Loaded keys (per layer):
      ``conv_q/k/v/o.weight`` (256, 256, 1) and ``.bias`` (256)
      ``emb_rel_k`` (1, 9, 64), ``emb_rel_v`` (1, 9, 64)
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv_q = _ConvLayer(TE_DIM, TE_DIM)
        self.conv_k = _ConvLayer(TE_DIM, TE_DIM)
        self.conv_v = _ConvLayer(TE_DIM, TE_DIM)
        self.conv_o = _ConvLayer(TE_DIM, TE_DIM)
        self.window = REL_POS_WINDOW
        self.emb_rel_k = mx.zeros((1, 2 * REL_POS_WINDOW + 1, TE_ATTN_HEAD_DIM))
        self.emb_rel_v = mx.zeros((1, 2 * REL_POS_WINDOW + 1, TE_ATTN_HEAD_DIM))

    def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
        B, T, _ = x.shape
        H, D = TE_ATTN_HEADS, TE_ATTN_HEAD_DIM
        q = self.conv_q(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
        k = self.conv_k(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
        v = self.conv_v(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
        scale = D ** -0.5

        # Standard attention logits
        logits = (q @ k.transpose(0, 1, 3, 2)) * scale  # (B, H, T, T)

        # VITS relative-position contribution to logits
        rel_k = _slice_rel_emb(self.emb_rel_k, T, self.window)  # (1, 2T-1, D)
        rel_logits = q @ rel_k.transpose(0, 2, 1)[:, None, :, :]  # (B, H, T, 2T-1)
        rel_logits = _rel_to_abs(rel_logits * scale)  # (B, H, T, T)
        logits = logits + rel_logits

        if mask is not None:
            key_mask = mask[:, :, 0][:, None, None, :]
            neg_inf = mx.array(-1e4, dtype=logits.dtype)
            logits = mx.where(key_mask.astype(mx.bool_), logits, neg_inf)

        attn = mx.softmax(logits, axis=-1)  # (B, H, T, T)
        out = attn @ v  # (B, H, T, D)

        # VITS rel-pos value contribution
        rel_v = _slice_rel_emb(self.emb_rel_v, T, self.window)  # (1, 2T-1, D)
        rel_weights = _abs_to_rel(attn)  # (B, H, T, 2T-1)
        out = out + rel_weights @ rel_v[:, None, :, :]  # (B, H, T, D)

        out = out.transpose(0, 2, 1, 3).reshape(B, T, H * D)
        return self.conv_o(out)


class FFN(nn.Module):
    """FFN with Conv1d k=1 wrappers: conv_1 (256→1024) + ReLU + conv_2 (1024→256).

    Activation is ReLU (confirmed by ONNX graph node ``Relu`` in ``ffn_layers.N``),
    not GELU. The mask is applied before each Conv to match the ONNX semantics.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv_1 = _ConvLayer(TE_DIM, TE_FFN_HIDDEN)
        self.conv_2 = _ConvLayer(TE_FFN_HIDDEN, TE_DIM)

    def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
        if mask is not None:
            x = x * mask
        y = self.conv_1(x)
        y = mx.maximum(y, mx.array(0.0, dtype=y.dtype))
        if mask is not None:
            y = y * mask
        y = self.conv_2(y)
        if mask is not None:
            y = y * mask
        return y


class AttnEncoder(nn.Module):
    """Stack of (RelPosSelfAttn + norm1) + (FFN + norm2) × 4."""

    def __init__(self) -> None:
        super().__init__()
        self.attn_layers = [RelPosSelfAttention() for _ in range(TE_ATTN_NUM_LAYERS)]
        self.norm_layers_1 = [WrappedNorm(TE_DIM, eps=EPS_LN) for _ in range(TE_ATTN_NUM_LAYERS)]
        self.ffn_layers = [FFN() for _ in range(TE_ATTN_NUM_LAYERS)]
        self.norm_layers_2 = [WrappedNorm(TE_DIM, eps=EPS_LN) for _ in range(TE_ATTN_NUM_LAYERS)]

    def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
        for i in range(TE_ATTN_NUM_LAYERS):
            y = self.attn_layers[i](x, mask=mask)
            x = self.norm_layers_1[i](x + y)
            y = self.ffn_layers[i](x, mask)
            x = self.norm_layers_2[i](x + y)
        return x


class _TextEmbedder(nn.Module):
    """char_embedder: VOCAB → TE_DIM. Loaded as ``char_embedder.weight (8322, 256)``."""

    def __init__(self) -> None:
        super().__init__()
        self.char_embedder = nn.Embedding(VOCAB_SIZE, TE_DIM)

    def __call__(self, text_ids: mx.array) -> mx.array:
        return self.char_embedder(text_ids)


class _InnerTextEncoder(nn.Module):
    """Pure text encoder before speech prompting. Loaded as ``text_encoder.X.Y``."""

    def __init__(self) -> None:
        super().__init__()
        self.text_embedder = _TextEmbedder()
        self.convnext = TextConvNeXtStack()
        self.attn_encoder = AttnEncoder()

    def __call__(self, text_ids: mx.array, mask: mx.array) -> mx.array:
        x = self.text_embedder(text_ids)  # (B, T, 256)
        if mask is not None:
            x = x * mask
        x = self.convnext(x, mask)
        x = self.attn_encoder(x, mask)
        return x


class _StyleEncoder(nn.Module):
    """Holds ``style_token_layer.style_key`` (1, 50, 256)."""

    def __init__(self) -> None:
        super().__init__()
        # Use a child module so the parameter path matches ``style_token_layer.style_key``
        class _StyleTokenLayer(nn.Module):
            def __init__(_):
                super().__init__()
                _.style_key = mx.zeros((1, 50, 256))
        self.style_token_layer = _StyleTokenLayer()


class _SpeechPromptedAttn(nn.Module):
    """Cross-attention from text (Q) to style_ttl (K, V). Single head, 256-d."""

    def __init__(self) -> None:
        super().__init__()
        self.W_query = WrappedLinear(TE_DIM, TE_DIM, bias=True)
        self.W_key = WrappedLinear(TE_DIM, TE_DIM, bias=True)
        self.W_value = WrappedLinear(TE_DIM, TE_DIM, bias=True)
        self.out_fc = WrappedLinear(TE_DIM, TE_DIM, bias=True)

    def __call__(self, x: mx.array, style: mx.array) -> mx.array:
        # x: (B, T_text, 256); style: (B, 50, 256)
        # Single-head cross attention.
        B, T, D = x.shape
        q = self.W_query(x)
        k = self.W_key(style)
        v = self.W_value(style)
        scale = D ** -0.5
        logits = (q @ k.transpose(0, 2, 1)) * scale
        attn = mx.softmax(logits, axis=-1)
        out = attn @ v
        return self.out_fc(out)
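
The shapes in this single-head cross-attention are easy to lose track of, so here is a NumPy sketch of the same computation. It makes simplifying assumptions: biases are omitted, the projection matrices are random placeholders rather than loaded weights, and the feature width is reduced to 16 to keep the example small. `cross_attention` and `softmax` are names local to this example.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attention(x, style, Wq, Wk, Wv, Wo):
    """Single-head attention: queries from text, keys/values from style tokens."""
    q, k, v = x @ Wq, style @ Wk, style @ Wv
    logits = (q @ k.transpose(0, 2, 1)) * x.shape[-1] ** -0.5
    attn = softmax(logits, axis=-1)          # (B, T_text, N_style)
    return (attn @ v) @ Wo, attn


rng = np.random.default_rng(0)
D = 16                                       # 256 in the real model
x = rng.standard_normal((2, 6, D))           # text features
style = rng.standard_normal((2, 50, D))      # style token bank
Wq, Wk, Wv, Wo = (rng.standard_normal((D, D)) for _ in range(4))

out, attn = cross_attention(x, style, Wq, Wk, Wv, Wo)
```

Each text position receives a convex combination of the 50 projected style tokens. The style bank has a fixed length, which is presumably why the module takes no key mask.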


class _SpeechPromptedTextEncoder(nn.Module):
    """Two cross-attention layers modulating text features with style_ttl."""

    def __init__(self) -> None:
        super().__init__()
        self.attention1 = _SpeechPromptedAttn()
        self.attention2 = _SpeechPromptedAttn()
        self.norm = WrappedNorm(TE_DIM, eps=EPS_LN)

    def __call__(self, x: mx.array, style: mx.array) -> mx.array:
        x = x + self.attention1(x, style)
        x = x + self.attention2(x, style)
        return self.norm(x)


class _RootTextEncoder(nn.Module):
    """Top-level container matching ONNX ``tts.ttl.*`` namespace."""

    def __init__(self) -> None:
        super().__init__()
        self.text_encoder = _InnerTextEncoder()
        self.style_encoder = _StyleEncoder()
        self.speech_prompted_text_encoder = _SpeechPromptedTextEncoder()


class _TtsContainer(nn.Module):
    """Outer container so weight keys ``tts.ttl.X.Y`` resolve."""

    def __init__(self) -> None:
        super().__init__()
        self.ttl = _RootTextEncoder()


class TextEncoder(nn.Module):
    """Top-level text encoder: ``text_ids + style_ttl + text_mask → text_emb (B, 256, T)``.

    Submodule naming matches the ONNX initializer keys after a single
    ``tts.ttl.`` prefix wrap (so weight keys look like
    ``tts.ttl.text_encoder.convnext.convnext.0.dwconv.weight``).
    """

    def __init__(self) -> None:
        super().__init__()
        self.tts = _TtsContainer()

    def __call__(
        self,
        text_ids: mx.array,  # (B, T_text) int
        style_ttl: mx.array,  # (B, 50, 256)
        text_mask: mx.array,  # (B, 1, T_text)
    ) -> mx.array:
        mask_ntc = text_mask.transpose(0, 2, 1)  # (B, T_text, 1)
        x = self.tts.ttl.text_encoder(text_ids, mask_ntc)
        x = self.tts.ttl.speech_prompted_text_encoder(x, style_ttl)
        if mask_ntc is not None:
            x = x * mask_ntc
        # Return channels-first (B, 256, T_text) to match the vector_estimator input.
        return x.transpose(0, 2, 1)


__all__ = ["TextEncoder", "VOCAB_SIZE", "TE_DIM"]
src/supertonic_3_mlx/vector_estimator.py
ADDED
@@ -0,0 +1,765 @@
"""Supertonic 3 vector estimator (64 M params): flow-matching denoiser, MLX port.

Pipeline (operating in channels-last NTC layout):

    noisy_latent [B, 144, T_lat]                  (channels first from ONNX I/O)
      → transpose                                 [B, T_lat, 144]
      → proj_in (Linear 144→512)                  [B, T_lat, 512]
      → 24 main_blocks (4 cycles × 6 sub-types):
          cycle = [stack4, time_film, cn1, text_attn, cn1, style_attn]
      → last_convnext (4 ConvNeXt)                [B, T_lat, 512]
      → proj_out (Linear 512→144)                 [B, T_lat, 144]
      → transpose                                 [B, 144, T_lat]
      → Euler step: denoised = noisy + velocity * (1 / total_step)
      → output                                    [B, 144, T_lat]

Submodule naming matches the s3 ONNX initializer keys exactly, so loading
the safetensors produced by ``weights.convert_onnx_to_mlx`` requires no
remapping.

The forward path is faithful to ONNX semantics in fp32; ``mx.compile``,
quantisation, and kernel fusion are layered on later in T.3.
"""
from __future__ import annotations

import math

import mlx.core as mx
import mlx.nn as nn

from supertonic_3_mlx._config import (
    DIM, LATENT_CH, CONVNEXT_HIDDEN, CONVNEXT_K, STACK4_DILATIONS,
    NUM_MAIN_BLOCKS, BLOCKS_PER_CYCLE, BLOCK_CYCLE,
    TEXT_DIM, TEXT_HEADS, TEXT_HEAD_DIM, ROTARY_BASE, ROTARY_SCALE,
    STYLE_DIM, STYLE_LEN, STYLE_HEADS, STYLE_HEAD_DIM,
    TIME_EMB_DIM, TIME_MLP_HIDDEN,
    EPS_LN,
)
from supertonic_3_mlx._nn_wrappers import (
    WrappedNorm, WrappedLinear, ProjConv1x1,
)


def _pad_sym_edge(x: mx.array, pad: int) -> mx.array:
    """Symmetric replicate-edge pad on the time axis (axis=1 for [B, T, C])."""
    if pad == 0:
        return x
    left = mx.broadcast_to(x[:, :1, :], (x.shape[0], pad, x.shape[2]))
    right = mx.broadcast_to(x[:, -1:, :], (x.shape[0], pad, x.shape[2]))
    return mx.concatenate([left, x, right], axis=1)
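
The ConvNeXt blocks pair this helper with `pad = dilation * (kernel - 1) // 2` and a convolution that has `padding=0`, which keeps the sequence length unchanged. The NumPy sketch below transcribes the padding and checks the length arithmetic. The kernel size 5 and dilation 4 are taken from the text-encoder constants, and `pad_sym_edge` is a local re-implementation for illustration.

```python
import numpy as np


def pad_sym_edge(x, pad):
    """Replicate the first and last frame `pad` times along axis 1 of [B, T, C]."""
    if pad == 0:
        return x
    left = np.broadcast_to(x[:, :1, :], (x.shape[0], pad, x.shape[2]))
    right = np.broadcast_to(x[:, -1:, :], (x.shape[0], pad, x.shape[2]))
    return np.concatenate([left, x, right], axis=1)


x = np.arange(2 * 5 * 3, dtype=np.float32).reshape(2, 5, 3)
kernel, dilation = 5, 4
pad = dilation * (kernel - 1) // 2   # frames added on each side
padded = pad_sym_edge(x, pad)
# Length of a stride-1 convolution without padding over the padded input.
out_len = padded.shape[1] - dilation * (kernel - 1)
```

The result matches `np.pad(..., mode="edge")` on the time axis, so boundary frames see repeated edge values rather than zeros.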


def _gelu_exact(x: mx.array) -> mx.array:
    """Exact (non-tanh) GELU: x * 0.5 * (1 + erf(x / sqrt(2)))."""
    return x * 0.5 * (1.0 + mx.erf(x * (2 ** -0.5)))


def _mish(x: mx.array) -> mx.array:
    """Mish: x * tanh(softplus(x)) = x * tanh(log(1 + exp(x)))."""
    return x * mx.tanh(mx.logaddexp(x, mx.array(0.0, dtype=x.dtype)))
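
Both activations can be cross-checked outside MLX. The sketch below is a NumPy and `math` transcription. It evaluates the exact GELU through `erf`, and writes softplus as `logaddexp(x, 0)`, which equals `log(1 + exp(x))` but does not overflow for large inputs. The function names are local to this example.

```python
import math

import numpy as np


def gelu_exact(x):
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    erf = np.vectorize(math.erf)
    return x * 0.5 * (1.0 + erf(x / math.sqrt(2.0)))


def mish(x):
    """Mish with softplus computed as logaddexp(x, 0)."""
    return x * np.tanh(np.logaddexp(x, 0.0))


xs = np.linspace(-5.0, 5.0, 11)
gelu_vals = gelu_exact(xs)
mish_vals = mish(xs)
mish_naive = xs * np.tanh(np.log1p(np.exp(xs)))  # direct formula, fine for small |x|
mish_large = mish(np.array([1000.0]))            # stays finite
```

For moderate inputs the two Mish formulas agree to floating-point precision. The `logaddexp` form matters only once `exp(x)` would overflow.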


# ──────────────────────────────────────────────────────────────────
# ConvNeXt building blocks
# ──────────────────────────────────────────────────────────────────


class ConvNeXtBlock(nn.Module):
    """Single ConvNeXt block matching s3 keys: ``dwconv``, ``norm.norm``, ``pwconv1/2``, ``gamma``."""

    def __init__(
        self,
        dim: int = DIM,
        hidden: int = CONVNEXT_HIDDEN,
        kernel: int = CONVNEXT_K,
        dilation: int = 1,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.dilation = dilation
        self.pad = dilation * (kernel - 1) // 2
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=kernel, padding=0, dilation=dilation,
            groups=dim, bias=True,
        )
        self.norm = WrappedNorm(dim, eps=EPS_LN)
        self.pwconv1 = nn.Linear(dim, hidden, bias=True)
        self.pwconv2 = nn.Linear(hidden, dim, bias=True)
        # Stored as shape (1, dim, 1) in the ONNX checkpoint; see weights.py for
        # the load-time reshape that flattens it to (dim,) for broadcasting in NTC.
        self.gamma = mx.zeros((dim,))

    def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
        # x: (B, T, C)
        residual = x
        y = _pad_sym_edge(x, self.pad)
        y = self.dwconv(y)  # (B, T, C)
        y = self.norm(y)  # LayerNorm last-dim
        y = self.pwconv1(y)  # (B, T, hidden)
        y = _gelu_exact(y)
        y = self.pwconv2(y)  # (B, T, C)
        y = y * self.gamma  # broadcast over (B, T, .)
        out = residual + y
        if mask is not None:
            out = out * mask
        return out


class ConvNeXtStack(nn.Module):
    """List of ConvNeXt blocks. Loaded as ``convnext.[0..N-1].X``."""

    def __init__(self, dilations: tuple, dim: int = DIM, hidden: int = CONVNEXT_HIDDEN) -> None:
        super().__init__()
        self.convnext = [ConvNeXtBlock(dim, hidden, CONVNEXT_K, d) for d in dilations]

    def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
        for b in self.convnext:
            x = b(x, mask)
        return x


# ──────────────────────────────────────────────────────────────────
# 6 block types per cycle
# ──────────────────────────────────────────────────────────────────


class Stack4Block(nn.Module):
    """Cycle position 0: 4 ConvNeXt with dilations [1, 2, 4, 8].

    Loaded keys: ``convnext.[0..3].{dwconv,norm.norm,pwconv1,pwconv2,gamma}``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.convnext = [ConvNeXtBlock(DIM, CONVNEXT_HIDDEN, CONVNEXT_K, d) for d in STACK4_DILATIONS]

    def __call__(self, x: mx.array, mask: mx.array | None, **_) -> mx.array:
        for b in self.convnext:
            x = b(x, mask)
        return x


class TimeFiLMBlock(nn.Module):
    """Cycle position 1: additive time conditioning, ``x + linear(t_emb)``.

    Loaded keys: ``linear.linear.{weight,bias}``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.linear = WrappedLinear(TIME_EMB_DIM, DIM, bias=True)

    def __call__(self, x: mx.array, mask: mx.array | None, t_emb: mx.array, **_) -> mx.array:
        # t_emb: (B, TIME_EMB_DIM) → broadcast across T
        bias = self.linear(t_emb)[:, None, :]  # (B, 1, DIM)
        y = x + bias
        if mask is not None:
            y = y * mask
        return y


class ConvNeXt1Block(nn.Module):
    """Cycle positions 2 and 4: a single ConvNeXt block.

    Loaded keys: ``convnext.0.{dwconv,norm.norm,pwconv1,pwconv2,gamma}``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.convnext = [ConvNeXtBlock(DIM, CONVNEXT_HIDDEN, CONVNEXT_K, 1)]

    def __call__(self, x: mx.array, mask: mx.array | None, **_) -> mx.array:
        return self.convnext[0](x, mask)


def _build_rope_freqs(head_dim: int, base: int, scale: int, max_len: int = 1024) -> mx.array:
    """Pre-compute RoPE cos/sin table, shape (max_len, head_dim/2, 2)."""
    half = head_dim // 2
    inv_freq = 1.0 / (base ** (mx.arange(half, dtype=mx.float32) / half))
    pos = mx.arange(max_len, dtype=mx.float32) * scale
    angles = pos[:, None] * inv_freq[None, :]  # (max_len, half)
    return mx.stack([mx.cos(angles), mx.sin(angles)], axis=-1)  # (max_len, half, 2)


def _apply_rope(x: mx.array, freqs: mx.array) -> mx.array:
    """Apply RoPE rotation. ``x`` shape (B, H, T, head_dim); ``freqs`` (T, half, 2)."""
    half = x.shape[-1] // 2
    x_even, x_odd = x[..., :half], x[..., half:]
    cos = freqs[..., 0]  # (T, half)
    sin = freqs[..., 1]
    rot_even = x_even * cos[None, None, :, :] - x_odd * sin[None, None, :, :]
    rot_odd = x_even * sin[None, None, :, :] + x_odd * cos[None, None, :, :]
    return mx.concatenate([rot_even, rot_odd], axis=-1)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class TextCrossAttnBlock(nn.Module):
|
| 196 |
+
"""Cycle position 3 — text cross-attention with RoPE on Q and K.
|
| 197 |
+
|
| 198 |
+
Loaded keys:
|
| 199 |
+
``attn.W_query.linear.{weight,bias}``
|
| 200 |
+
``attn.W_key.linear.{weight,bias}``
|
| 201 |
+
``attn.W_value.linear.{weight,bias}``
|
| 202 |
+
``attn.out_fc.linear.{weight,bias}``
|
| 203 |
+
``attn.theta`` — frozen RoPE inv-freq table (1, 1, half)
|
| 204 |
+
``attn.increments`` — frozen position table (1, 1000, 1) — 0..999
|
| 205 |
+
``norm.norm.{weight,bias}``
|
| 206 |
+
"""
|
| 207 |
+
|
| 208 |
+
def __init__(self) -> None:
|
| 209 |
+
super().__init__()
|
| 210 |
+
self.attn = _AttnInner(DIM, TEXT_DIM, TEXT_HEADS, TEXT_HEAD_DIM)
|
| 211 |
+
self.norm = WrappedNorm(DIM, eps=EPS_LN)
|
| 212 |
+
|
| 213 |
+
def __call__(
|
| 214 |
+
self,
|
| 215 |
+
x: mx.array,
|
| 216 |
+
mask: mx.array | None,
|
| 217 |
+
*,
|
| 218 |
+
text_emb: mx.array | None = None,
|
| 219 |
+
text_mask: mx.array | None = None,
|
| 220 |
+
latent_seq_len: mx.array | None = None,
|
| 221 |
+
text_seq_len: mx.array | None = None,
|
| 222 |
+
kv_cache: tuple[mx.array, mx.array] | None = None,
|
| 223 |
+
**_,
|
| 224 |
+
) -> mx.array:
|
| 225 |
+
# x: (B, T_lat, DIM); text_emb: (B, T_text, TEXT_DIM) — unused when kv_cache supplied.
|
| 226 |
+
residual = x * mask if mask is not None else x
|
| 227 |
+
h = self.attn(
|
| 228 |
+
residual, text_emb, text_mask=text_mask,
|
| 229 |
+
latent_seq_len=latent_seq_len, text_seq_len=text_seq_len,
|
| 230 |
+
kv_cache=kv_cache,
|
| 231 |
+
)
|
| 232 |
+
if mask is not None:
|
| 233 |
+
h = h * mask
|
| 234 |
+
out = self.norm(residual + h)
|
| 235 |
+
if mask is not None:
|
| 236 |
+
out = out * mask
|
| 237 |
+
return out
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
class _AttnInner(nn.Module):
|
| 241 |
+
"""Multi-head cross-attention with RoPE applied to query and key.
|
| 242 |
+
|
| 243 |
+
Holds parameters under ``W_query``, ``W_key``, ``W_value``, ``out_fc`` —
|
| 244 |
+
each is a :class:`WrappedLinear` so its weight is keyed
|
| 245 |
+
``…W_query.linear.weight`` to match the ONNX checkpoint.
|
| 246 |
+
|
| 247 |
+
``theta`` and ``increments`` come from the ONNX graph as frozen tensors
|
| 248 |
+
(precomputed RoPE table). We rebuild the equivalent table from the
|
| 249 |
+
Supertonic-3 config so the module is self-contained.
|
| 250 |
+
"""
|
| 251 |
+
|
| 252 |
+
def __init__(
|
| 253 |
+
self,
|
| 254 |
+
in_dim: int,
|
| 255 |
+
ctx_dim: int,
|
| 256 |
+
num_heads: int,
|
| 257 |
+
head_dim: int,
|
| 258 |
+
) -> None:
|
| 259 |
+
super().__init__()
|
| 260 |
+
self.num_heads = num_heads
|
| 261 |
+
self.head_dim = head_dim
|
| 262 |
+
# ONNX divides attention logits by 16.0 (= sqrt(TEXT_DIM)), not sqrt(head_dim).
|
| 263 |
+
self.scale = ctx_dim ** -0.5
|
| 264 |
+
|
| 265 |
+
kv_dim = num_heads * head_dim # = DIM = 512
|
| 266 |
+
self.W_query = WrappedLinear(in_dim, kv_dim, bias=True)
|
| 267 |
+
self.W_key = WrappedLinear(ctx_dim, kv_dim, bias=True)
|
| 268 |
+
self.W_value = WrappedLinear(ctx_dim, kv_dim, bias=True)
|
| 269 |
+
self.out_fc = WrappedLinear(kv_dim, in_dim, bias=True)
|
| 270 |
+
|
| 271 |
+
# Frozen RoPE tables — overwritten by checkpoint at load time.
|
| 272 |
+
# ONNX layout:
|
| 273 |
+
# ``increments`` (1, 1000, 1) holds positions 0..999 (no scale baked in)
|
| 274 |
+
# ``theta`` (1, 1, half) holds rotary_scale × base^(-i/half)
|
| 275 |
+
# Angle formula: ``angle = (pos / actual_seq_len) × theta``.
|
| 276 |
+
# The division by the actual seq length is critical — it normalises
|
| 277 |
+
# absolute positions into [0, 1] so audio and text are RoPE-aligned
|
| 278 |
+
# regardless of their respective lengths.
|
| 279 |
+
max_len = 1000
|
| 280 |
+
half = head_dim // 2
|
| 281 |
+
idx = mx.arange(half, dtype=mx.float32)
|
| 282 |
+
self.theta = (ROTARY_SCALE * mx.exp(-math.log(ROTARY_BASE) * idx / half))[None, None, :]
|
| 283 |
+
positions = mx.arange(max_len, dtype=mx.int64)
|
| 284 |
+
self.increments = positions[None, :, None] # (1, max_len, 1)
|
| 285 |
+
|
| 286 |
+
def _rope(self, x: mx.array, seq_len: mx.array | int | None = None) -> mx.array:
|
| 287 |
+
"""Apply RoPE rotation. ``seq_len`` is the effective (unmasked) length.
|
| 288 |
+
|
| 289 |
+
Args:
|
| 290 |
+
x: (B, H, T, head_dim)
|
| 291 |
+
seq_len: scalar or (B,) — actual sequence length for position normalisation.
|
| 292 |
+
If None, defaults to T (positions are normalised by the padded length).
|
| 293 |
+
"""
|
| 294 |
+
T = x.shape[-2]
|
| 295 |
+
positions = self.increments[:, :T, :] # (1, T, 1)
|
| 296 |
+
if seq_len is None:
|
| 297 |
+
seq_len = float(T)
|
| 298 |
+
if isinstance(seq_len, (int, float)):
|
| 299 |
+
divisor = float(seq_len)
|
| 300 |
+
else:
|
| 301 |
+
divisor = seq_len.astype(mx.float32).reshape(-1, 1, 1)
|
| 302 |
+
norm_pos = positions / divisor # broadcasts to (B, T, 1) if divisor is (B,1,1)
|
| 303 |
+
angles = norm_pos * self.theta # (B, T, half) or (1, T, half)
|
| 304 |
+
cos = mx.cos(angles)
|
| 305 |
+
sin = mx.sin(angles)
|
| 306 |
+
half = self.head_dim // 2
|
| 307 |
+
# Broadcast (?, T, half) → (?, 1, T, half) for head dim
|
| 308 |
+
cos_b = cos[..., None, :, :] if cos.ndim == 3 else cos[None, None, :, :]
|
| 309 |
+
sin_b = sin[..., None, :, :] if sin.ndim == 3 else sin[None, None, :, :]
|
| 310 |
+
# Expand a shared (1, 1, T, half) table to the full batch size
|
| 311 |
+
if cos_b.shape[0] == 1 and x.shape[0] > 1:
|
| 312 |
+
cos_b = mx.broadcast_to(cos_b, (x.shape[0], 1, T, half))
|
| 313 |
+
sin_b = mx.broadcast_to(sin_b, (x.shape[0], 1, T, half))
|
| 314 |
+
# Reshape if needed
|
| 315 |
+
cos_b = cos_b.reshape(-1, 1, T, half)
|
| 316 |
+
sin_b = sin_b.reshape(-1, 1, T, half)
|
| 317 |
+
x_first, x_second = x[..., :half], x[..., half:]
|
| 318 |
+
rot_first = x_first * cos_b - x_second * sin_b
|
| 319 |
+
rot_second = x_first * sin_b + x_second * cos_b
|
| 320 |
+
return mx.concatenate([rot_first, rot_second], axis=-1)
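The length-normalised angle formula from the comments above, `angle = (pos / seq_len) × theta`, can be sketched in plain Python together with the rotate-half pairing that `_rope` uses. This is an illustrative sketch, not the module's code path: `rope_angles` and `rotate_half` are hypothetical helper names, and `rotary_scale=10.0` and `base=10000.0` are placeholder values rather than the real `ROTARY_SCALE` and `ROTARY_BASE` from `_config.py`.

```python
import math

def rope_angles(seq_len, head_dim, rotary_scale=10.0, base=10000.0):
    """Return angles[pos][i] = (pos / seq_len) * rotary_scale * base**(-i / half).

    rotary_scale and base are placeholder values (assumptions), not the
    constants shipped in _config.py.
    """
    half = head_dim // 2
    theta = [rotary_scale * math.exp(-math.log(base) * i / half) for i in range(half)]
    return [[(pos / seq_len) * th for th in theta] for pos in range(seq_len)]

def rotate_half(vec, angles):
    """Rotate each (vec[i], vec[i + half]) pair by angles[i]."""
    half = len(vec) // 2
    first, second = vec[:half], vec[half:]
    rot_first = [a * math.cos(t) - b * math.sin(t) for a, b, t in zip(first, second, angles)]
    rot_second = [a * math.sin(t) + b * math.cos(t) for a, b, t in zip(first, second, angles)]
    return rot_first + rot_second

# The same relative position (halfway through) receives the same angles in a
# short text sequence and in a much longer latent sequence.
text_angles = rope_angles(seq_len=10, head_dim=8)
latent_angles = rope_angles(seq_len=100, head_dim=8)
```

Because positions are divided by the effective length, token 5 of 10 and frame 50 of 100 are rotated identically, which is the alignment property the comments describe.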
|
| 321 |
+
|
| 322 |
+
def project_kv(
|
| 323 |
+
self,
|
| 324 |
+
text_emb: mx.array,
|
| 325 |
+
text_seq_len: mx.array | None = None,
|
| 326 |
+
) -> tuple[mx.array, mx.array]:
|
| 327 |
+
"""Project text_emb → (K_rope, V) once. Both are constant across the
|
| 328 |
+
Euler steps in a TTS inference call (T.5.3 cache target)."""
|
| 329 |
+
B, T_text, _ = text_emb.shape
|
| 330 |
+
H, D = self.num_heads, self.head_dim
|
| 331 |
+
k = self.W_key(text_emb).reshape(B, T_text, H, D).transpose(0, 2, 1, 3)
|
| 332 |
+
v = self.W_value(text_emb).reshape(B, T_text, H, D).transpose(0, 2, 1, 3)
|
| 333 |
+
k = self._rope(k, seq_len=text_seq_len if text_seq_len is not None else T_text)
|
| 334 |
+
return k, v
|
| 335 |
+
|
| 336 |
+
def __call__(
|
| 337 |
+
self,
|
| 338 |
+
x: mx.array,
|
| 339 |
+
text_emb: mx.array | None = None,
|
| 340 |
+
text_mask: mx.array | None = None,
|
| 341 |
+
latent_seq_len: mx.array | None = None,
|
| 342 |
+
text_seq_len: mx.array | None = None,
|
| 343 |
+
kv_cache: tuple[mx.array, mx.array] | None = None,
|
| 344 |
+
) -> mx.array:
|
| 345 |
+
B, T_lat, _ = x.shape
|
| 346 |
+
H, D = self.num_heads, self.head_dim
|
| 347 |
+
|
| 348 |
+
q = self.W_query(x).reshape(B, T_lat, H, D).transpose(0, 2, 1, 3) # (B, H, T_lat, D)
|
| 349 |
+
if kv_cache is not None:
|
| 350 |
+
k, v = kv_cache
|
| 351 |
+
else:
|
| 352 |
+
k, v = self.project_kv(text_emb, text_seq_len=text_seq_len)
|
| 353 |
+
|
| 354 |
+
# RoPE normalises positions by the effective (unmasked) sequence length.
|
| 355 |
+
q = self._rope(q, seq_len=latent_seq_len if latent_seq_len is not None else T_lat)
|
| 356 |
+
|
| 357 |
+
# Attention
|
| 358 |
+
logits = (q @ k.transpose(0, 1, 3, 2)) * self.scale # (B, H, T_lat, T_text)
|
| 359 |
+
if text_mask is not None:
|
| 360 |
+
neg_inf = mx.array(-1e4, dtype=logits.dtype)
|
| 361 |
+
logits = mx.where(text_mask[:, :, None, :].astype(mx.bool_), logits, neg_inf)
|
| 362 |
+
attn = mx.softmax(logits, axis=-1)
|
| 363 |
+
out = attn @ v # (B, H, T_lat, D)
|
| 364 |
+
out = out.transpose(0, 2, 1, 3).reshape(B, T_lat, H * D)
|
| 365 |
+
return self.out_fc(out)
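The masking step above fills padded key positions with a large negative constant (`-1e4`) rather than negative infinity before the softmax. A minimal stdlib sketch of that behaviour for a single query row follows; `masked_softmax` is a hypothetical helper, and `ctx_dim = 256` is assumed from the "divides attention logits by 16.0" comment.

```python
import math

def masked_softmax(logits, mask, fill=-1e4):
    """Softmax over one row of logits; positions where mask is 0 get `fill`."""
    filled = [x if m else fill for x, m in zip(logits, mask)]
    peak = max(filled)                          # subtract the max for stability
    exps = [math.exp(x - peak) for x in filled]
    total = sum(exps)
    return [e / total for e in exps]

ctx_dim = 256                                   # assumed: sqrt(256) = 16.0
scale = ctx_dim ** -0.5
raw = [32.0, 16.0, 48.0, 0.0]                   # q·k dot products for one query
weights = masked_softmax([x * scale for x in raw], mask=[1, 1, 0, 1])
```

The masked position ends up with a weight that underflows to zero, while the remaining weights still sum to one. A finite fill value also avoids NaNs if a whole row happens to be masked.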
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
class StyleCrossAttnBlock(nn.Module):
|
| 369 |
+
"""Cycle position 5 — style cross-attention to 50 learned style tokens.
|
| 370 |
+
|
| 371 |
+
Loaded keys:
|
| 372 |
+
``attention.W_query.linear.{weight,bias}``
|
| 373 |
+
``attention.W_key.linear.{weight,bias}``
|
| 374 |
+
``attention.W_value.linear.{weight,bias}``
|
| 375 |
+
``attention.out_fc.linear.{weight,bias}``
|
| 376 |
+
``norm.norm.{weight,bias}``
|
| 377 |
+
"""
|
| 378 |
+
|
| 379 |
+
def __init__(self) -> None:
|
| 380 |
+
super().__init__()
|
| 381 |
+
self.attention = _StyleAttnInner(DIM, STYLE_DIM, STYLE_HEADS, STYLE_HEAD_DIM)
|
| 382 |
+
self.norm = WrappedNorm(DIM, eps=EPS_LN)
|
| 383 |
+
|
| 384 |
+
def __call__(
|
| 385 |
+
self,
|
| 386 |
+
x: mx.array,
|
| 387 |
+
mask: mx.array | None,
|
| 388 |
+
*,
|
| 389 |
+
style_k: mx.array | None = None,
|
| 390 |
+
style_v: mx.array | None = None,
|
| 391 |
+
kv_cache: tuple[mx.array, mx.array] | None = None,
|
| 392 |
+
**_,
|
| 393 |
+
) -> mx.array:
|
| 394 |
+
# style_v defaults to style_k (same tensor for cond path); CFG path supplies
|
| 395 |
+
# different style_v to model the uncond branch.
|
| 396 |
+
if style_v is None and style_k is not None:
|
| 397 |
+
style_v = style_k
|
| 398 |
+
residual = x * mask if mask is not None else x
|
| 399 |
+
h = self.attention(residual, style_k, style_v, kv_cache=kv_cache)
|
| 400 |
+
if mask is not None:
|
| 401 |
+
h = h * mask
|
| 402 |
+
out = self.norm(residual + h)
|
| 403 |
+
if mask is not None:
|
| 404 |
+
out = out * mask
|
| 405 |
+
return out
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
class _StyleAttnInner(nn.Module):
|
| 409 |
+
def __init__(self, in_dim: int, ctx_dim: int, num_heads: int, head_dim: int) -> None:
|
| 410 |
+
super().__init__()
|
| 411 |
+
self.num_heads = num_heads
|
| 412 |
+
self.head_dim = head_dim
|
| 413 |
+
# ONNX divides attention logits by 16.0 (= sqrt(STYLE_DIM)), not sqrt(head_dim).
|
| 414 |
+
self.scale = ctx_dim ** -0.5
|
| 415 |
+
kv_dim = num_heads * head_dim # 2 * 128 = 256
|
| 416 |
+
# Q is on DIM (audio), K/V on ctx_dim (style 256)
|
| 417 |
+
self.W_query = WrappedLinear(in_dim, kv_dim, bias=True)
|
| 418 |
+
self.W_key = WrappedLinear(ctx_dim, kv_dim, bias=True)
|
| 419 |
+
self.W_value = WrappedLinear(ctx_dim, kv_dim, bias=True)
|
| 420 |
+
self.out_fc = WrappedLinear(kv_dim, in_dim, bias=True)
|
| 421 |
+
|
| 422 |
+
def project_kv(
|
| 423 |
+
self, style_k: mx.array, style_v: mx.array
|
| 424 |
+
) -> tuple[mx.array, mx.array]:
|
| 425 |
+
"""Project (style_k, style_v) → (K, V) once. T.5.3 cache target."""
|
| 426 |
+
B, T_style = style_k.shape[0], style_k.shape[1]
|
| 427 |
+
H, D = self.num_heads, self.head_dim
|
| 428 |
+
# Note: ONNX graph applies tanh to the K projection (``attention/tanh/Tanh``
|
| 429 |
+
# node) — the style key bank is bounded into [-1, 1] before softmax dot
|
| 430 |
+
# product, which acts as a soft attention temperature regulariser.
|
| 431 |
+
k = mx.tanh(self.W_key(style_k)).reshape(B, T_style, H, D).transpose(0, 2, 1, 3)
|
| 432 |
+
v = self.W_value(style_v).reshape(B, style_v.shape[1], H, D).transpose(0, 2, 1, 3)
|
| 433 |
+
return k, v
|
| 434 |
+
|
| 435 |
+
def __call__(
|
| 436 |
+
self,
|
| 437 |
+
x: mx.array,
|
| 438 |
+
style_k: mx.array | None = None,
|
| 439 |
+
style_v: mx.array | None = None,
|
| 440 |
+
kv_cache: tuple[mx.array, mx.array] | None = None,
|
| 441 |
+
) -> mx.array:
|
| 442 |
+
# style_k and style_v can be the same tensor (cond) or distinct (uncond
|
| 443 |
+
# branch in CFG, where K comes from style_key_special_token and V from
|
| 444 |
+
# style_value_special_token).
|
| 445 |
+
B, T_lat, _ = x.shape
|
| 446 |
+
H, D = self.num_heads, self.head_dim
|
| 447 |
+
q = self.W_query(x).reshape(B, T_lat, H, D).transpose(0, 2, 1, 3)
|
| 448 |
+
if kv_cache is not None:
|
| 449 |
+
k, v = kv_cache
|
| 450 |
+
else:
|
| 451 |
+
k, v = self.project_kv(style_k, style_v)
|
| 452 |
+
logits = (q @ k.transpose(0, 1, 3, 2)) * self.scale
|
| 453 |
+
attn = mx.softmax(logits, axis=-1)
|
| 454 |
+
out = attn @ v
|
| 455 |
+
out = out.transpose(0, 2, 1, 3).reshape(B, T_lat, H * D)
|
| 456 |
+
return self.out_fc(out)
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
# ──────────────────────────────────────────────────────────────────
|
| 460 |
+
# Time encoder
|
| 461 |
+
# ──────────────────────────────────────────────────────────────────
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
class _MlpItem(nn.Module):
|
| 465 |
+
"""A single MLP layer wrapped to produce keys ``mlp.N.linear.{weight,bias}``."""
|
| 466 |
+
|
| 467 |
+
def __init__(self, in_dim: int, out_dim: int) -> None:
|
| 468 |
+
super().__init__()
|
| 469 |
+
self.linear = nn.Linear(in_dim, out_dim, bias=True)
|
| 470 |
+
|
| 471 |
+
def __call__(self, x: mx.array) -> mx.array:
|
| 472 |
+
return self.linear(x)
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
class TimeEncoder(nn.Module):
|
| 476 |
+
"""Sinusoidal time embedding + 2-layer MLP. Keys: ``mlp.0.linear``, ``mlp.2.linear``."""
|
| 477 |
+
|
| 478 |
+
def __init__(self) -> None:
|
| 479 |
+
super().__init__()
|
| 480 |
+
# ONNX: mlp.0.linear (64→256), mlp.2.linear (256→64). Index 1 is activation.
|
| 481 |
+
self.mlp = [
|
| 482 |
+
_MlpItem(TIME_EMB_DIM, TIME_MLP_HIDDEN), # mlp.0
|
| 483 |
+
nn.Identity(), # mlp.1 placeholder (no weights; Mish is applied in __call__)
|
| 484 |
+
_MlpItem(TIME_MLP_HIDDEN, TIME_EMB_DIM), # mlp.2
|
| 485 |
+
]
|
| 486 |
+
|
| 487 |
+
def __call__(self, t: mx.array) -> mx.array:
|
| 488 |
+
# t: (B,) — produce sinusoidal embedding then run through MLP.
|
| 489 |
+
# Activation is Mish (not SiLU) to match the ONNX graph
|
| 490 |
+
# (Softplus → Tanh → Mul pattern == x * tanh(softplus(x))).
|
| 491 |
+
emb = self._sinusoidal(t, TIME_EMB_DIM)
|
| 492 |
+
h = self.mlp[0](emb)
|
| 493 |
+
h = _mish(h)
|
| 494 |
+
h = self.mlp[2](h)
|
| 495 |
+
return h
|
| 496 |
+
|
| 497 |
+
@staticmethod
|
| 498 |
+
def _sinusoidal(t: mx.array, dim: int) -> mx.array:
|
| 499 |
+
"""Time embedding matching ``Supertonic-3`` ONNX exactly.
|
| 500 |
+
|
| 501 |
+
ONNX path: pos = t * 1000; freqs[i] = 10000^(-i/(half-1));
|
| 502 |
+
concat[sin(pos*freqs), cos(pos*freqs)].
|
| 503 |
+
"""
|
| 504 |
+
half = dim // 2
|
| 505 |
+
denom = max(half - 1, 1)
|
| 506 |
+
freqs = mx.exp(-math.log(10_000) * mx.arange(half, dtype=mx.float32) / denom)
|
| 507 |
+
pos = t.astype(mx.float32)[:, None] * 1000.0
|
| 508 |
+
angles = pos * freqs[None, :]
|
| 509 |
+
return mx.concatenate([mx.sin(angles), mx.cos(angles)], axis=-1).astype(mx.float32)
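The time-embedding recipe in the docstring above (`pos = t * 1000`, `freqs[i] = 10000^(-i/(half-1))`, sine half then cosine half) and the Mish activation can be checked with a small stdlib-only sketch. The function names are hypothetical, and `dim=64` is assumed from the `mlp.0.linear (64→256)` comment rather than read from `_config.py`.

```python
import math

def sinusoidal_embedding(t, dim=64):
    """concat[sin(pos * freqs), cos(pos * freqs)] with pos = t * 1000."""
    half = dim // 2
    denom = max(half - 1, 1)
    freqs = [math.exp(-math.log(10_000) * i / denom) for i in range(half)]
    pos = t * 1000.0
    return [math.sin(pos * f) for f in freqs] + [math.cos(pos * f) for f in freqs]

def mish(x):
    """Mish activation: x * tanh(softplus(x)). Not guarded against overflow."""
    softplus = math.log1p(math.exp(x))
    return x * math.tanh(softplus)

emb = sinusoidal_embedding(0.0)   # at t = 0 the sine half is 0 and the cosine half is 1
```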
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
# ──────────────────────────────────────────────────────────────────
|
| 513 |
+
# Top-level VectorEstimator
|
| 514 |
+
# ──────────────────────────────────────────────────────────────────
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
def _build_main_block(idx: int) -> nn.Module:
|
| 518 |
+
"""Instantiate the appropriate block class for cycle position ``idx % 6``."""
|
| 519 |
+
pos = idx % BLOCKS_PER_CYCLE
|
| 520 |
+
name = BLOCK_CYCLE[pos]
|
| 521 |
+
if name == "stack4":
|
| 522 |
+
return Stack4Block()
|
| 523 |
+
if name == "time":
|
| 524 |
+
return TimeFiLMBlock()
|
| 525 |
+
if name == "cn1":
|
| 526 |
+
return ConvNeXt1Block()
|
| 527 |
+
if name == "text_attn":
|
| 528 |
+
return TextCrossAttnBlock()
|
| 529 |
+
if name == "style_attn":
|
| 530 |
+
return StyleCrossAttnBlock()
|
| 531 |
+
raise RuntimeError(f"unknown block type for index {idx}: {name}")
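As a rough illustration of the dispatch above, the sketch below maps flat block indices to block names. The tuple is an assumption reconstructed from the block docstrings (time at position 1, single ConvNeXt at positions 2 and 4, text attention at 3, style attention at 5, and `stack4` presumed at position 0); the real `BLOCK_CYCLE` and `NUM_MAIN_BLOCKS` live in `_config.py` and may differ.

```python
# Assumed cycle layout, inferred from the block docstrings; the real
# BLOCK_CYCLE and NUM_MAIN_BLOCKS in _config.py may differ.
BLOCK_CYCLE = ("stack4", "time", "cn1", "text_attn", "cn1", "style_attn")
BLOCKS_PER_CYCLE = len(BLOCK_CYCLE)

def block_name(idx):
    """Name of the block type at flat index idx."""
    return BLOCK_CYCLE[idx % BLOCKS_PER_CYCLE]

# Two full cycles, chosen only for illustration.
layout = [block_name(i) for i in range(12)]
```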
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
class _VectorField(nn.Module):
|
| 535 |
+
"""Inner module mirroring ONNX ``vector_estimator.tts.ttl.vector_field.*``."""
|
| 536 |
+
|
| 537 |
+
def __init__(self) -> None:
|
| 538 |
+
super().__init__()
|
| 539 |
+
self.proj_in = ProjConv1x1(LATENT_CH, DIM, bias=False)
|
| 540 |
+
self.main_blocks = [_build_main_block(i) for i in range(NUM_MAIN_BLOCKS)]
|
| 541 |
+
self.last_convnext = ConvNeXtStack(dilations=(1, 1, 1, 1), dim=DIM, hidden=CONVNEXT_HIDDEN)
|
| 542 |
+
self.proj_out = ProjConv1x1(DIM, LATENT_CH, bias=False)
|
| 543 |
+
self.time_encoder = TimeEncoder()
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
class _UncondMasker(nn.Module):
|
| 547 |
+
"""Holds the three unconditional-token tensors used by CFG.
|
| 548 |
+
|
| 549 |
+
Keys:
|
| 550 |
+
``text_special_token`` (1, 256, 1)
|
| 551 |
+
``style_key_special_token`` (1, 50, 256)
|
| 552 |
+
``style_value_special_token`` (1, 50, 256)
|
| 553 |
+
"""
|
| 554 |
+
|
| 555 |
+
def __init__(self) -> None:
|
| 556 |
+
super().__init__()
|
| 557 |
+
# Initialised to zero; checkpoint provides real values.
|
| 558 |
+
self.text_special_token = mx.zeros((1, TEXT_DIM, 1))
|
| 559 |
+
self.style_key_special_token = mx.zeros((1, STYLE_LEN, STYLE_DIM))
|
| 560 |
+
self.style_value_special_token = mx.zeros((1, STYLE_LEN, STYLE_DIM))
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
class VectorEstimator(nn.Module):
|
| 564 |
+
"""Top-level module — matches ONNX root names ``vector_field.*`` and ``uncond_masker.*``.
|
| 565 |
+
|
| 566 |
+
Two inference paths:
|
| 567 |
+
- :meth:`velocity`: single forward pass; predicts the velocity from one set
|
| 568 |
+
of conditioning inputs. ``style_k``/``style_v`` may be the same tensor
|
| 569 |
+
(cond path) or different (uncond path of CFG).
|
| 570 |
+
- :meth:`__call__`: full ONNX-parity forward — applies CFG batch doubling
|
| 571 |
+
(cond + uncond) internally and combines via
|
| 572 |
+
``final = noisy + (4*cond - 3*uncond) / total_step``.
|
| 573 |
+
"""
|
| 574 |
+
|
| 575 |
+
# CFG guidance constants — baked into the ONNX graph as ``/Constant_3`` (=4.0)
|
| 576 |
+
# and ``/Constant_4`` (=3.0). Equivalent to guidance_scale = 4 with the
|
| 577 |
+
# standard formula ``v = uncond + g*(cond - uncond) = 4*cond - 3*uncond``.
|
| 578 |
+
CFG_COND_SCALE: float = 4.0
|
| 579 |
+
CFG_UNCOND_SCALE: float = 3.0
|
| 580 |
+
|
| 581 |
+
def __init__(self) -> None:
|
| 582 |
+
super().__init__()
|
| 583 |
+
self.vector_field = _VectorField()
|
| 584 |
+
self.uncond_masker = _UncondMasker()
|
| 585 |
+
|
| 586 |
+
# ── inference API ─────────────────────────────────────────────
|
| 587 |
+
def velocity(
|
| 588 |
+
self,
|
| 589 |
+
noisy_latent: mx.array, # (B, 144, T_lat)
|
| 590 |
+
text_emb: mx.array, # (B, 256, T_text)
|
| 591 |
+
style_k: mx.array, # (B, 50, 256) — K side of style attention
|
| 592 |
+
style_v: mx.array, # (B, 50, 256) — V side of style attention
|
| 593 |
+
latent_mask: mx.array, # (B, 1, T_lat)
|
| 594 |
+
text_mask: mx.array, # (B, 1, T_text)
|
| 595 |
+
t_norm: mx.array, # (B,) timestep in [0, 1]
|
| 596 |
+
) -> mx.array:
|
| 597 |
+
"""Predict velocity (B, 144, T_lat) without applying CFG or Euler step."""
|
| 598 |
+
x = noisy_latent.transpose(0, 2, 1) # (B, T_lat, 144)
|
| 599 |
+
text = text_emb.transpose(0, 2, 1) # (B, T_text, 256)
|
| 600 |
+
lat_mask_ntc = latent_mask.transpose(0, 2, 1) # (B, T_lat, 1)
|
| 601 |
+
|
| 602 |
+
x = self.vector_field.proj_in(x) # (B, T_lat, 512)
|
| 603 |
+
t_emb = self.vector_field.time_encoder(t_norm) # (B, TIME_EMB_DIM)
|
| 604 |
+
|
| 605 |
+
# Effective (unmasked) sequence lengths for RoPE normalisation —
|
| 606 |
+
# ONNX uses ``ReduceSum(mask)`` for this so that audio and text are
|
| 607 |
+
# rope-aligned regardless of padding.
|
| 608 |
+
latent_seq_len = mx.sum(latent_mask, axis=(1, 2)) # (B,)
|
| 609 |
+
text_seq_len = mx.sum(text_mask, axis=(1, 2)) # (B,)
|
| 610 |
+
|
| 611 |
+
for blk in self.vector_field.main_blocks:
|
| 612 |
+
x = blk(
|
| 613 |
+
x,
|
| 614 |
+
lat_mask_ntc,
|
| 615 |
+
t_emb=t_emb,
|
| 616 |
+
text_emb=text,
|
| 617 |
+
text_mask=text_mask,
|
| 618 |
+
style_k=style_k,
|
| 619 |
+
style_v=style_v,
|
| 620 |
+
latent_seq_len=latent_seq_len,
|
| 621 |
+
text_seq_len=text_seq_len,
|
| 622 |
+
)
|
| 623 |
+
|
| 624 |
+
x = self.vector_field.last_convnext(x, lat_mask_ntc)
|
| 625 |
+
v_ntc = self.vector_field.proj_out(x) # (B, T_lat, 144)
|
| 626 |
+
return v_ntc.transpose(0, 2, 1) # (B, 144, T_lat)
|
| 627 |
+
|
| 628 |
+
# ── T.5.3 — pre-projected K/V path ────────────────────────────
|
| 629 |
+
def precompute_cross_kv(
|
| 630 |
+
self,
|
| 631 |
+
text_emb: mx.array, # (B, 256, T_text) channels-first
|
| 632 |
+
style_k: mx.array, # (B, 50, 256)
|
| 633 |
+
style_v: mx.array, # (B, 50, 256)
|
| 634 |
+
text_mask: mx.array, # (B, 1, T_text)
|
| 635 |
+
) -> tuple[list[tuple[mx.array, mx.array]], list[tuple[mx.array, mx.array]]]:
|
| 636 |
+
"""Project K/V for every text_attn and style_attn block exactly once.
|
| 637 |
+
|
| 638 |
+
Returns ``(text_kv_list, style_kv_list)`` — both ordered to align with
|
| 639 |
+
the corresponding blocks encountered when iterating ``main_blocks``.
|
| 640 |
+
These tensors are invariant across the 5 Euler steps of one TTS
|
| 641 |
+
call; pre-projecting them once and feeding the result into
|
| 642 |
+
:meth:`velocity_cached` cuts ~ 4 × 2 × 5 = 40 redundant matmuls.
|
| 643 |
+
"""
|
| 644 |
+
text_seq_len = mx.sum(text_mask, axis=(1, 2))
|
| 645 |
+
text_ntc = text_emb.transpose(0, 2, 1) # (B, T_text, 256)
|
| 646 |
+
|
| 647 |
+
text_kv: list[tuple[mx.array, mx.array]] = []
|
| 648 |
+
style_kv: list[tuple[mx.array, mx.array]] = []
|
| 649 |
+
for blk in self.vector_field.main_blocks:
|
| 650 |
+
if isinstance(blk, TextCrossAttnBlock):
|
| 651 |
+
text_kv.append(blk.attn.project_kv(text_ntc, text_seq_len=text_seq_len))
|
| 652 |
+
elif isinstance(blk, StyleCrossAttnBlock):
|
| 653 |
+
style_kv.append(blk.attention.project_kv(style_k, style_v))
|
| 654 |
+
return text_kv, style_kv
|
| 655 |
+
|
| 656 |
+
def velocity_cached(
|
| 657 |
+
self,
|
| 658 |
+
noisy_latent: mx.array,
|
| 659 |
+
latent_mask: mx.array,
|
| 660 |
+
text_mask: mx.array,
|
| 661 |
+
t_norm: mx.array,
|
| 662 |
+
text_kv: list[tuple[mx.array, mx.array]],
|
| 663 |
+
style_kv: list[tuple[mx.array, mx.array]],
|
| 664 |
+
) -> mx.array:
|
| 665 |
+
"""Same as :meth:`velocity` but reads K/V from pre-projected caches.
|
| 666 |
+
|
| 667 |
+
``text_kv`` and ``style_kv`` must come from :meth:`precompute_cross_kv`
|
| 668 |
+
applied to the same (batched) conditioning tensors that will be
|
| 669 |
+
active for this call.
|
| 670 |
+
"""
|
| 671 |
+
x = noisy_latent.transpose(0, 2, 1)
|
| 672 |
+
lat_mask_ntc = latent_mask.transpose(0, 2, 1)
|
| 673 |
+
|
| 674 |
+
x = self.vector_field.proj_in(x)
|
| 675 |
+
t_emb = self.vector_field.time_encoder(t_norm)
|
| 676 |
+
latent_seq_len = mx.sum(latent_mask, axis=(1, 2))
|
| 677 |
+
|
| 678 |
+
ti = 0
|
| 679 |
+
si = 0
|
| 680 |
+
for blk in self.vector_field.main_blocks:
|
| 681 |
+
if isinstance(blk, TextCrossAttnBlock):
|
| 682 |
+
x = blk(
|
| 683 |
+
x, lat_mask_ntc,
|
| 684 |
+
text_mask=text_mask,
|
| 685 |
+
latent_seq_len=latent_seq_len,
|
| 686 |
+
kv_cache=text_kv[ti],
|
| 687 |
+
)
|
| 688 |
+
ti += 1
|
| 689 |
+
elif isinstance(blk, StyleCrossAttnBlock):
|
| 690 |
+
x = blk(x, lat_mask_ntc, kv_cache=style_kv[si])
|
| 691 |
+
si += 1
|
| 692 |
+
else:
|
| 693 |
+
x = blk(x, lat_mask_ntc, t_emb=t_emb)
|
| 694 |
+
|
| 695 |
+
x = self.vector_field.last_convnext(x, lat_mask_ntc)
|
| 696 |
+
v_ntc = self.vector_field.proj_out(x)
|
| 697 |
+
return v_ntc.transpose(0, 2, 1)
|
| 698 |
+
|
| 699 |
+
def __call__(
|
| 700 |
+
self,
|
| 701 |
+
noisy_latent: mx.array, # (B, 144, T_lat) channels-first per ONNX I/O
|
| 702 |
+
text_emb: mx.array, # (B, 256, T_text) channels-first
|
| 703 |
+
style_ttl: mx.array, # (B, 50, 256) — used as both K and V for cond
|
| 704 |
+
latent_mask: mx.array, # (B, 1, T_lat)
|
| 705 |
+
text_mask: mx.array, # (B, 1, T_text)
|
| 706 |
+
current_step: mx.array, # (B,)
|
| 707 |
+
total_step: mx.array, # (B,)
|
| 708 |
+
cfg: bool = True,
|
| 709 |
+
) -> mx.array:
|
| 710 |
+
"""Run one Euler step with CFG (matches ONNX semantics).
|
| 711 |
+
|
| 712 |
+
With ``cfg=True`` (default) the model runs both conditional and
|
| 713 |
+
unconditional paths in a single batched forward and combines via
|
| 714 |
+
``final = noisy + (4*cond_v - 3*uncond_v) / total_step``.
|
| 715 |
+
|
| 716 |
+
With ``cfg=False`` only the conditional path runs — half the work, but
|
| 717 |
+
produces a different (lower-quality) output. Useful for speed benchmarks /
|
| 718 |
+
sanity tests.
|
| 719 |
+
"""
|
| 720 |
+
B = noisy_latent.shape[0]
|
| 721 |
+
t_norm = current_step.astype(mx.float32) / total_step.astype(mx.float32)
|
| 722 |
+
|
| 723 |
+
if not cfg:
|
| 724 |
+
v = self.velocity(
|
| 725 |
+
noisy_latent, text_emb, style_ttl, style_ttl,
|
| 726 |
+
latent_mask, text_mask, t_norm,
|
| 727 |
+
)
|
| 728 |
+
return noisy_latent + v / total_step.reshape(-1, 1, 1).astype(noisy_latent.dtype)
|
| 729 |
+
|
| 730 |
+
# CFG branch — build (2B, ...) inputs by concatenating cond + uncond.
|
| 731 |
+
# uncond text_emb = text_special_token broadcast to (B, 256, T_text).
|
| 732 |
+
# uncond style_k = style_key_special_token broadcast, similarly style_v.
|
| 733 |
+
text_uncond = mx.broadcast_to(
|
| 734 |
+
self.uncond_masker.text_special_token, (B, TEXT_DIM, text_emb.shape[2])
|
| 735 |
+
)
|
| 736 |
+
style_k_uncond = mx.broadcast_to(
|
| 737 |
+
self.uncond_masker.style_key_special_token, (B, STYLE_LEN, STYLE_DIM)
|
| 738 |
+
)
|
| 739 |
+
style_v_uncond = mx.broadcast_to(
|
| 740 |
+
self.uncond_masker.style_value_special_token, (B, STYLE_LEN, STYLE_DIM)
|
| 741 |
+
)
|
| 742 |
+
|
| 743 |
+
noisy_2 = mx.concatenate([noisy_latent, noisy_latent], axis=0)
|
| 744 |
+
text_2 = mx.concatenate([text_emb, text_uncond], axis=0)
|
| 745 |
+
style_k_2 = mx.concatenate([style_ttl, style_k_uncond], axis=0)
|
| 746 |
+
style_v_2 = mx.concatenate([style_ttl, style_v_uncond], axis=0)
|
| 747 |
+
lm_2 = mx.concatenate([latent_mask, latent_mask], axis=0)
|
| 748 |
+
tm_2 = mx.concatenate([text_mask, text_mask], axis=0)
|
| 749 |
+
t_norm_2 = mx.concatenate([t_norm, t_norm], axis=0)
|
| 750 |
+
|
| 751 |
+
v_2 = self.velocity(
|
| 752 |
+
noisy_2, text_2, style_k_2, style_v_2, lm_2, tm_2, t_norm_2,
|
| 753 |
+
) # (2B, 144, T_lat)
|
| 754 |
+
cond_v = v_2[:B]
|
| 755 |
+
uncond_v = v_2[B:2 * B]
|
| 756 |
+
combined_v = self.CFG_COND_SCALE * cond_v - self.CFG_UNCOND_SCALE * uncond_v
|
| 757 |
+
return noisy_latent + combined_v / total_step.reshape(-1, 1, 1).astype(noisy_latent.dtype)
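The guided update above can be worked through on scalars. The sketch below assumes flat lists of floats instead of MLX arrays, and `cfg_euler_step` is a hypothetical helper name. It uses the identity `uncond + g*(cond - uncond) = g*cond - (g-1)*uncond`, which for `g = 4` gives the `4*cond - 3*uncond` form.

```python
def cfg_euler_step(noisy, cond_v, uncond_v, total_step, guidance=4.0):
    """One Euler update with classifier-free guidance on flat lists of floats."""
    out = []
    for x, c, u in zip(noisy, cond_v, uncond_v):
        # Same value as u + guidance * (c - u), written in the two-constant form.
        combined = guidance * c - (guidance - 1.0) * u
        out.append(x + combined / total_step)
    return out

stepped = cfg_euler_step([0.0, 1.0], [0.5, -0.25], [0.25, 0.25], total_step=5)
```

For the first element, `4 * 0.5 - 3 * 0.25 = 1.25`, and dividing by the 5 steps moves the latent by 0.25.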
|
| 758 |
+
|
| 759 |
+
|
| 760 |
+
__all__ = [
|
| 761 |
+
"ConvNeXtBlock", "ConvNeXtStack",
|
| 762 |
+
"Stack4Block", "TimeFiLMBlock", "ConvNeXt1Block",
|
| 763 |
+
"TextCrossAttnBlock", "StyleCrossAttnBlock",
|
| 764 |
+
"TimeEncoder", "VectorEstimator",
|
| 765 |
+
]
|
src/supertonic_3_mlx/vocoder.py
ADDED
|
@@ -0,0 +1,304 @@
| 1 |
+
"""Supertonic 3 vocoder — latent → 44.1 kHz waveform, MLX port.
|
| 2 |
+
|
| 3 |
+
Pipeline (operating in channels-last NTC layout, then converted to channels-first
|
| 4 |
+
for output reshape):
|
| 5 |
+
|
| 6 |
+
latent [B, 144, T_lat] (output of vector_estimator)
|
| 7 |
+
→ /= normalizer.scale (scalar)
|
| 8 |
+
→ reshape [B, 24, T_lat*6] # de-compress
|
| 9 |
+
→ (* latent_std + latent_mean) # de-normalise
|
| 10 |
+
→ transpose to NTC [B, T_lat*6, 24]
|
| 11 |
+
→ embed Conv1d(24→512, k=7, sym-edge pad) [B, T_lat*6, 512]
|
| 12 |
+
→ 10× ConvNeXt(dim=512, hidden=2048, k=7,
|
| 13 |
+
dilations [1,2,4,1,2,4,1,1,1,1])
|
| 14 |
+
→ final_norm: BatchNorm1d (eval-time: running stats only)
|
| 15 |
+
→ head.layer1: Conv1d(512→2048, k=3, sym-edge pad)
|
| 16 |
+
→ PReLU (with per-channel learnable slope)
|
| 17 |
+
→ head.layer2: Conv1d(2048→512, k=1, no bias)
|
| 18 |
+
→ transpose to (B, 512, T_lat*6) → flatten → wav (B, T_lat*6*512)
|
| 19 |
+
|
| 20 |
+
Each latent step yields 6 × 512 = 3072 samples, i.e. about 0.0697 s of audio at 44.1 kHz.
|
| 21 |
+
"""
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
|
| 24 |
+
import mlx.core as mx
|
| 25 |
+
import mlx.nn as nn
|
| 26 |
+
|
| 27 |
+
from supertonic_3_mlx._config import EPS_LN
|
| 28 |
+
from supertonic_3_mlx._nn_wrappers import WrappedNorm
|
| 29 |
+
from supertonic_3_mlx.vector_estimator import _gelu_exact
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _pad_left_edge(x: mx.array, pad: int) -> mx.array:
|
| 33 |
+
"""Causal replicate-edge pad on the time axis (axis=1 for [B, T, C]).
|
| 34 |
+
|
| 35 |
+
Pads ``pad`` time-steps on the LEFT only by replicating the first frame.
|
| 36 |
+
Matches the ONNX vocoder pads spec ``[0, 0, pad, 0, 0, 0]``.
|
| 37 |
+
"""
|
| 38 |
+
if pad == 0:
|
| 39 |
+
return x
|
| 40 |
+
left = mx.broadcast_to(x[:, :1, :], (x.shape[0], pad, x.shape[2]))
|
| 41 |
+
return mx.concatenate([left, x], axis=1)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
VOC_DIM = 512
|
| 45 |
+
VOC_HIDDEN = 2048
|
| 46 |
+
VOC_K = 7
|
| 47 |
+
VOC_HEAD_K = 3
|
| 48 |
+
VOC_LDIM = 24 # de-compressed channels (24 × 6 = 144 input)
|
| 49 |
+
VOC_CHUNK_COMPRESS = 6
|
| 50 |
+
VOC_NUM_CONVNEXT_LAYERS = 10
|
| 51 |
+
VOC_DILATIONS = (1, 2, 4, 1, 2, 4, 1, 1, 1, 1)
|
| 52 |
+
EPS_BN = 1e-5
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class _Conv1dNet(nn.Module):
|
| 56 |
+
"""Conv1d wrapped under ``.net`` to match ONNX storage ``.net.weight/bias``."""
|
| 57 |
+
|
| 58 |
+
def __init__(self, in_dim: int, out_dim: int, kernel: int, dilation: int = 1,
|
| 59 |
+
groups: int = 1, bias: bool = True) -> None:
|
| 60 |
+
super().__init__()
|
| 61 |
+
class _Net(nn.Module):
|
| 62 |
+
def __init__(_):
|
| 63 |
+
super().__init__()
|
| 64 |
+
# MLX Conv1d weight: (out, K, in/groups)
|
| 65 |
+
_.weight = mx.zeros((out_dim, kernel, in_dim // groups))
|
| 66 |
+
if bias:
|
| 67 |
+
_.bias = mx.zeros((out_dim,))
|
| 68 |
+
else:
|
| 69 |
+
_.bias = None
|
| 70 |
+
def __call__(_, x, dilation=1):
|
| 71 |
+
y = mx.conv1d(x, _.weight, stride=1, padding=0, dilation=dilation,
|
| 72 |
+
groups=groups)
|
| 73 |
+
if _.bias is not None:
|
| 74 |
+
y = y + _.bias
|
| 75 |
+
return y
|
| 76 |
+
self.net = _Net()
|
| 77 |
+
self.dilation = dilation
|
| 78 |
+
self.groups = groups
|
| 79 |
+
self.kernel = kernel
|
| 80 |
+
|
| 81 |
+
def __call__(self, x: mx.array) -> mx.array:
|
| 82 |
+
return self.net(x, dilation=self.dilation)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class _VocConvNeXtBlock(nn.Module):
|
| 86 |
+
"""ConvNeXt block matching keys ``convnext.N.{dwconv.net,norm.norm,pwconv1,pwconv2,gamma}``."""
|
| 87 |
+
|
| 88 |
+
def __init__(self, dilation: int) -> None:
|
| 89 |
+
super().__init__()
|
| 90 |
+
self.dilation = dilation
|
| 91 |
+
self.pad = dilation * (VOC_K - 1)
|
| 92 |
+
self.dwconv = _Conv1dNet(VOC_DIM, VOC_DIM, kernel=VOC_K, dilation=dilation,
|
| 93 |
+
groups=VOC_DIM, bias=True)
|
| 94 |
+
self.norm = WrappedNorm(VOC_DIM, eps=EPS_LN)
|
| 95 |
+
# pwconv1 / pwconv2 are stored in ONNX as Conv1d with k=1; the weights are squeezed to (out, in) and loaded as Linear.
|
| 96 |
+
self.pwconv1 = nn.Linear(VOC_DIM, VOC_HIDDEN, bias=True)
|
| 97 |
+
self.pwconv2 = nn.Linear(VOC_HIDDEN, VOC_DIM, bias=True)
|
| 98 |
+
self.gamma = mx.zeros((VOC_DIM,))
|
| 99 |
+
|
| 100 |
+
def __call__(self, x: mx.array) -> mx.array:
|
| 101 |
+
residual = x
|
| 102 |
+
y = _pad_left_edge(x, self.pad)
|
| 103 |
+
y = self.dwconv(y)
|
| 104 |
+
y = self.norm(y)
|
| 105 |
+
y = self.pwconv1(y)
|
| 106 |
+
y = _gelu_exact(y)
|
| 107 |
+
y = self.pwconv2(y)
|
| 108 |
+
y = y * self.gamma
|
| 109 |
+
return residual + y
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class _BatchNorm1dEval(nn.Module):
|
| 113 |
+
"""Eval-mode BatchNorm1d: applies stored running_mean/running_var only.
|
| 114 |
+
|
| 115 |
+
Loaded keys: ``norm.{weight,bias,running_mean,running_var}``.
|
| 116 |
+
"""
|
| 117 |
+
|
| 118 |
+
def __init__(self) -> None:
|
| 119 |
+
super().__init__()
|
| 120 |
+
class _Norm(nn.Module):
|
| 121 |
+
def __init__(_):
|
| 122 |
+
super().__init__()
|
| 123 |
+
_.weight = mx.ones((VOC_DIM,))
|
| 124 |
+
_.bias = mx.zeros((VOC_DIM,))
|
| 125 |
+
_.running_mean = mx.zeros((VOC_DIM,))
|
| 126 |
+
_.running_var = mx.ones((VOC_DIM,))
|
| 127 |
+
def __call__(_, x):
|
| 128 |
+
# x: (B, T, C). BN1d normalises across batch+time per channel.
|
| 129 |
+
# Eval mode: use stored running stats.
|
| 130 |
+
norm = (x - _.running_mean) * mx.rsqrt(_.running_var + EPS_BN)
|
| 131 |
+
return norm * _.weight + _.bias
|
| 132 |
+
self.norm = _Norm()
|
| 133 |
+
|
| 134 |
+
def __call__(self, x: mx.array) -> mx.array:
|
| 135 |
+
return self.norm(x)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class _VocHeadActivation(nn.Module):
|
| 139 |
+
"""PReLU with a learnable slope (weight shape (1,) by default, (C,) if the checkpoint stores one slope per channel)."""
|
| 140 |
+
|
| 141 |
+
def __init__(self) -> None:
|
| 142 |
+
super().__init__()
|
| 143 |
+
# The anonymous ONNX PReLU slope has shape (1,) in some exports and (C,) in others.
|
| 144 |
+
# We default to (1,) and reshape on load if needed.
|
| 145 |
+
self.weight = mx.zeros((1,))
|
| 146 |
+
|
| 147 |
+
def __call__(self, x: mx.array) -> mx.array:
|
| 148 |
+
# PReLU: max(0, x) + slope × min(0, x).
|
| 149 |
+
# slope broadcasts over (B, T, C) or (B, C, T) depending on layout.
|
| 150 |
+
zero = mx.array(0.0, dtype=x.dtype)
|
| 151 |
+
return mx.maximum(x, zero) + self.weight * mx.minimum(x, zero)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class _VocHead(nn.Module):
|
| 155 |
+
"""``head.layer1`` (Conv1d 512→2048 k=3) + ``head.act`` (PReLU) + ``head.layer2`` (Conv1d k=1, no bias)."""
|
| 156 |
+
|
| 157 |
+
def __init__(self) -> None:
|
| 158 |
+
super().__init__()
|
| 159 |
+
self.layer1 = _Conv1dNet(VOC_DIM, VOC_HIDDEN, kernel=VOC_HEAD_K, bias=True)
|
| 160 |
+
self.act = _VocHeadActivation()
|
| 161 |
+
# layer2 has no .net wrapper in ONNX (different from layer1)
|
| 162 |
+
# ONNX: head.layer2.weight (512, 2048, 1) — Conv1d k=1, no bias.
|
| 163 |
+
# We represent it directly without .net wrap.
|
| 164 |
+
self.layer2 = _VocLayer2()
|
| 165 |
+
|
| 166 |
+
def __call__(self, x: mx.array) -> mx.array:
|
| 167 |
+
# x: (B, T, 512)
|
| 168 |
+
pad = VOC_HEAD_K - 1
|
| 169 |
+
y = _pad_left_edge(x, pad)
|
| 170 |
+
y = self.layer1(y) # (B, T, 2048)
|
| 171 |
+
y = self.act(y)
|
| 172 |
+
y = self.layer2(y) # (B, T, 512)
|
| 173 |
+
return y
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class _VocLayer2(nn.Module):
|
| 177 |
+
"""Conv1d k=1 (2048 → 512), no bias. Keys: ``layer2.weight (512, 2048, 1)``."""
|
| 178 |
+
|
| 179 |
+
def __init__(self) -> None:
|
| 180 |
+
super().__init__()
|
| 181 |
+
# MLX Conv1d weight shape: (out, K, in/groups) = (512, 1, 2048)
|
| 182 |
+
# ONNX storage: (out, in, 1) = (512, 2048, 1). Same size; reshape on load.
|
| 183 |
+
self.weight = mx.zeros((VOC_DIM, 1, VOC_HIDDEN))
|
| 184 |
+
|
| 185 |
+
def __call__(self, x: mx.array) -> mx.array:
|
| 186 |
+
return mx.conv1d(x, self.weight, stride=1, padding=0)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class _VocEmbed(nn.Module):
|
| 190 |
+
"""Initial Conv1d(24→512, k=7) with causal replicate-edge pad (left only, ``VOC_K - 1`` frames).
|
| 191 |
+
|
| 192 |
+
The weight + bias are anonymous in the ONNX graph (``onnx::Conv_1441`` and
|
| 193 |
+
``onnx::Conv_1442``); the conversion recovers them via the Conv node path
|
| 194 |
+
``/decoder/embed/net/Conv`` → structured name ``tts.ae.decoder.embed.net.{weight,bias}``.
|
| 195 |
+
"""
|
| 196 |
+
|
| 197 |
+
def __init__(self) -> None:
|
| 198 |
+
super().__init__()
|
| 199 |
+
class _Net(nn.Module):
|
| 200 |
+
def __init__(_):
|
| 201 |
+
super().__init__()
|
| 202 |
+
_.weight = mx.zeros((VOC_DIM, VOC_K, VOC_LDIM))
|
| 203 |
+
_.bias = mx.zeros((VOC_DIM,))
|
| 204 |
+
def __call__(_, x):
|
| 205 |
+
return mx.conv1d(x, _.weight, stride=1, padding=0) + _.bias
|
| 206 |
+
self.net = _Net()
|
| 207 |
+
|
| 208 |
+
def __call__(self, x: mx.array) -> mx.array:
|
| 209 |
+
pad = VOC_K - 1
|
| 210 |
+
y = _pad_left_edge(x, pad)
|
| 211 |
+
return self.net(y)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class _VocDecoder(nn.Module):
|
| 215 |
+
"""``tts.ae.decoder.X`` namespace."""
|
| 216 |
+
|
| 217 |
+
def __init__(self) -> None:
|
| 218 |
+
super().__init__()
|
| 219 |
+
self.embed = _VocEmbed()
|
| 220 |
+
self.convnext = [_VocConvNeXtBlock(d) for d in VOC_DILATIONS]
|
| 221 |
+
self.final_norm = _BatchNorm1dEval()
|
| 222 |
+
self.head = _VocHead()
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class _AEContainer(nn.Module):
|
| 226 |
+
"""``tts.ae.X`` — holds latent_mean, latent_std, decoder."""
|
| 227 |
+
|
| 228 |
+
def __init__(self) -> None:
|
| 229 |
+
super().__init__()
|
| 230 |
+
self.latent_mean = mx.zeros((1, VOC_LDIM, 1))
|
| 231 |
+
self.latent_std = mx.ones((1, VOC_LDIM, 1))
|
| 232 |
+
self.decoder = _VocDecoder()
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class _TtlContainer(nn.Module):
|
| 236 |
+
"""``tts.ttl.normalizer.scale`` (scalar) — divides the latent before de-norm."""
|
| 237 |
+
|
| 238 |
+
def __init__(self) -> None:
|
| 239 |
+
super().__init__()
|
| 240 |
+
class _Normalizer(nn.Module):
|
| 241 |
+
def __init__(_):
|
| 242 |
+
super().__init__()
|
| 243 |
+
_.scale = mx.array(1.0)
|
| 244 |
+
self.normalizer = _Normalizer()
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class _TtsContainer(nn.Module):
|
| 248 |
+
def __init__(self) -> None:
|
| 249 |
+
super().__init__()
|
| 250 |
+
self.ttl = _TtlContainer()
|
| 251 |
+
self.ae = _AEContainer()
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class Vocoder(nn.Module):
|
| 255 |
+
"""Latent → waveform decoder (44.1 kHz mono).
|
| 256 |
+
|
| 257 |
+
Submodule namespace matches ONNX keys ``tts.X.Y`` exactly.
|
| 258 |
+
"""
|
| 259 |
+
|
| 260 |
+
def __init__(self) -> None:
|
| 261 |
+
super().__init__()
|
| 262 |
+
self.tts = _TtsContainer()
|
| 263 |
+
|
| 264 |
+
def __call__(self, latent: mx.array) -> mx.array:
|
| 265 |
+
# latent: (B, 144, T_lat)
|
| 266 |
+
B = latent.shape[0]
|
| 267 |
+
T_lat = latent.shape[2]
|
| 268 |
+
|
| 269 |
+
# /= scale (scalar)
|
| 270 |
+
x = latent / self.tts.ttl.normalizer.scale
|
| 271 |
+
|
| 272 |
+
# reshape (B, 144, T_lat) → (B, 24, T_lat*6)
|
| 273 |
+
x = x.reshape(B, VOC_LDIM, VOC_CHUNK_COMPRESS, T_lat) # (B, 24, 6, T_lat)
|
| 274 |
+
x = x.transpose(0, 1, 3, 2) # (B, 24, T_lat, 6)
|
| 275 |
+
x = x.reshape(B, VOC_LDIM, T_lat * VOC_CHUNK_COMPRESS) # (B, 24, T_lat*6)
|
| 276 |
+
|
| 277 |
+
# De-normalise: (* std + mean)
|
| 278 |
+
x = x * self.tts.ae.latent_std + self.tts.ae.latent_mean
|
| 279 |
+
|
| 280 |
+
# Transpose to NTC for Conv1d layers
|
| 281 |
+
x = x.transpose(0, 2, 1) # (B, T_lat*6, 24)
|
| 282 |
+
|
| 283 |
+
# embed
|
| 284 |
+
x = self.tts.ae.decoder.embed(x) # (B, T_lat*6, 512)
|
| 285 |
+
|
| 286 |
+
# 10× ConvNeXt
|
| 287 |
+
for blk in self.tts.ae.decoder.convnext:
|
| 288 |
+
x = blk(x)
|
| 289 |
+
|
| 290 |
+
# final_norm (BatchNorm1d eval)
|
| 291 |
+
x = self.tts.ae.decoder.final_norm(x)
|
| 292 |
+
|
| 293 |
+
# head
|
| 294 |
+
x = self.tts.ae.decoder.head(x) # (B, T_lat*6, 512)
|
| 295 |
+
|
| 296 |
+
# Flatten time × channels row-major → waveform (matches ONNX:
|
| 297 |
+
# head.layer2 Conv (B, 512, T_lat*6) → Transpose to (B, T_lat*6, 512) →
|
| 298 |
+
# Reshape to (B, T_lat*6*512). Since the head already runs in NTC, we
|
| 299 |
+
# are already in the post-Transpose layout and only the Reshape remains).
|
| 300 |
+
wav = x.reshape(B, -1) # (B, T_lat*6*512)
|
| 301 |
+
return wav
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
__all__ = ["Vocoder", "VOC_DIM", "VOC_HIDDEN", "VOC_LDIM", "VOC_CHUNK_COMPRESS"]
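The ConvNeXt blocks in this file pad `dilation * (VOC_K - 1)` frames on the left and then run an unpadded dilated convolution. That keeps the output as long as the input and makes each output frame depend only on the current and earlier frames. A minimal pure-Python sketch of that arithmetic follows, assuming a single channel and plain lists instead of `mx.array`; the helper names `pad_left_edge`, `dilated_conv_valid`, and `causal_block_conv` are illustrative and not part of the package.

```python
VOC_K = 7  # kernel size of the vocoder's depthwise convolutions


def pad_left_edge(frames, pad):
    """Replicate the first frame `pad` times on the left (causal edge pad)."""
    if pad == 0 or not frames:
        return list(frames)
    return [frames[0]] * pad + list(frames)


def dilated_conv_valid(frames, kernel, dilation):
    """Unpadded 1-D correlation with a dilated kernel, single channel."""
    span = dilation * (len(kernel) - 1)
    return [
        sum(w * frames[t + k * dilation] for k, w in enumerate(kernel))
        for t in range(len(frames) - span)
    ]


def causal_block_conv(frames, kernel, dilation):
    """Left-pad by dilation * (K - 1), then convolve without padding."""
    pad = dilation * (len(kernel) - 1)
    return dilated_conv_valid(pad_left_edge(frames, pad), kernel, dilation)


if __name__ == "__main__":
    signal = [float(i) for i in range(10)]
    kernel = [1.0 / VOC_K] * VOC_K  # toy averaging kernel, not a trained weight
    for dilation in (1, 2, 4):
        out = causal_block_conv(signal, kernel, dilation)
        print(dilation, len(out) == len(signal))
```

Because the padding equals the receptive span of the dilated kernel, no trimming is needed after the convolution, which is why the blocks can be stacked with different dilations without changing the time axis.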
|
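The reshape, transpose, reshape sequence in `Vocoder.__call__` moves six packed time steps out of the channel axis, and the final flatten turns each 512-channel frame into 512 audio samples. The sketch below spells out that index mapping and the resulting waveform length with nested lists, for a single batch item. It is an illustration under those assumptions, not the package's code path; `unpack_latent`, `waveform_samples`, and `waveform_seconds` are hypothetical helper names.

```python
VOC_LDIM = 24            # latent channels after unpacking
VOC_CHUNK_COMPRESS = 6   # time steps packed into the channel axis
SAMPLES_PER_FRAME = 512  # head output channels, flattened into audio samples
SAMPLE_RATE = 44_100


def unpack_latent(latent):
    """(144, T_lat) nested lists -> (24, T_lat * 6) nested lists."""
    t_lat = len(latent[0])
    unpacked = []
    for c in range(VOC_LDIM):
        row = []
        for t in range(t_lat):
            for j in range(VOC_CHUNK_COMPRESS):
                # Packed channel c*6 + j at step t becomes channel c at step t*6 + j.
                row.append(latent[c * VOC_CHUNK_COMPRESS + j][t])
        unpacked.append(row)
    return unpacked


def waveform_samples(t_lat):
    """Number of audio samples produced for `t_lat` latent steps."""
    return t_lat * VOC_CHUNK_COMPRESS * SAMPLES_PER_FRAME


def waveform_seconds(t_lat):
    """Duration in seconds of the audio produced for `t_lat` latent steps."""
    return waveform_samples(t_lat) / SAMPLE_RATE


if __name__ == "__main__":
    print(waveform_samples(1), round(waveform_seconds(1), 4))
```

One latent step therefore accounts for 3072 samples, which is where the roughly 0.0697 s per step figure comes from.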
src/supertonic_3_mlx/weights.py
ADDED
|
@@ -0,0 +1,152 @@
|
| 1 |
+
"""ONNX → MLX safetensors conversion for Supertonic 3.
|
| 2 |
+
|
| 3 |
+
Two-stage extraction:
|
| 4 |
+
1. **Named initializers** (e.g. ``vector_estimator.tts.ttl.vector_field.main_blocks.0.convnext.0.dwconv.weight``)
|
| 5 |
+
— straight name strip + optional shape transformation.
|
| 6 |
+
2. **Anonymous MatMul weights** (e.g. ``onnx::MatMul_3391``) — looked up via the
|
| 7 |
+
MatMul node graph: each MatMul output path is the human-readable name of the
|
| 8 |
+
weight (e.g. ``…/W_query/linear/MatMul_output_0``); we trace the second
|
| 9 |
+
operand initializer and rebind it to the structured name + transpose to
|
| 10 |
+
the MLX Linear layout ``(out, in)``.
|
| 11 |
+
|
| 12 |
+
Shape transformations:
|
| 13 |
+
- depthwise dwconv: ONNX ``(C, 1, K)`` → MLX ``(C, K, 1)``
|
| 14 |
+
- pwconv1/2 k=1: ONNX ``(out, in, 1)`` → MLX ``(out, in)``
|
| 15 |
+
- proj_in/out k=1: ONNX ``(out, in, 1)`` → MLX ``(out, in)``
|
| 16 |
+
- MatMul Linear: ONNX ``(in, out)`` → MLX ``(out, in)``
|
| 17 |
+
- gamma: ONNX ``(1, dim, 1)`` → MLX ``(dim,)``
|
| 18 |
+
"""
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import Dict, Tuple
|
| 23 |
+
|
| 24 |
+
import mlx.core as mx
|
| 25 |
+
import numpy as np
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
_ONNX_PREFIX = "vector_estimator.tts.ttl."
|
| 29 |
+
|
| 30 |
+
_DWCONV_SUFFIX = ".dwconv.weight"
|
| 31 |
+
_PWCONV_SUFFIXES = (".pwconv1.weight", ".pwconv2.weight")
|
| 32 |
+
_GAMMA_SUFFIX = ".gamma"
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _strip_prefix(name: str) -> str:
|
| 36 |
+
if name.startswith(_ONNX_PREFIX):
|
| 37 |
+
return name[len(_ONNX_PREFIX):]
|
| 38 |
+
return name
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _is_named_weight(name: str) -> bool:
|
| 42 |
+
"""True if this is a structured weight (vs anonymous graph constant)."""
|
| 43 |
+
if name.startswith(_ONNX_PREFIX):
|
| 44 |
+
return True
|
| 45 |
+
if name.startswith("uncond_masker."):
|
| 46 |
+
return True
|
| 47 |
+
return False
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _convert_named(clean_name: str, arr: np.ndarray) -> np.ndarray:
|
| 51 |
+
"""Apply shape transforms to a named initializer based on its key."""
|
| 52 |
+
# Depthwise Conv1d weight: (C, 1, K) → (C, K, 1)
|
| 53 |
+
if clean_name.endswith(_DWCONV_SUFFIX) and arr.ndim == 3 and arr.shape[1] == 1 and arr.shape[2] != 1:
|
| 54 |
+
arr = np.transpose(arr, (0, 2, 1))
|
| 55 |
+
|
| 56 |
+
# Pointwise k=1 / proj net weight: (out, in, 1) → (out, in)
|
| 57 |
+
if (any(clean_name.endswith(s) for s in _PWCONV_SUFFIXES) or clean_name.endswith(".net.weight")) \
|
| 58 |
+
and arr.ndim == 3 and arr.shape[-1] == 1:
|
| 59 |
+
arr = arr.squeeze(-1)
|
| 60 |
+
|
| 61 |
+
# gamma: (1, C, 1) → (C,)
|
| 62 |
+
if clean_name.endswith(_GAMMA_SUFFIX) and arr.ndim == 3 and arr.shape[0] == 1 and arr.shape[2] == 1:
|
| 63 |
+
arr = arr.reshape(arr.shape[1])
|
| 64 |
+
|
| 65 |
+
return arr
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _matmul_output_to_clean_name(matmul_output: str) -> str:
|
| 69 |
+
"""Map a MatMul node output path to the structured ``.weight`` key.
|
| 70 |
+
|
| 71 |
+
Example::
|
| 72 |
+
|
| 73 |
+
/vector_estimator/vector_field/main_blocks.3/attn/W_query/linear/MatMul_output_0
|
| 74 |
+
→ vector_field.main_blocks.3.attn.W_query.linear.weight
|
| 75 |
+
"""
|
| 76 |
+
# Strip the leading slash and the trailing /MatMul_output_0
|
| 77 |
+
path = matmul_output.lstrip("/")
|
| 78 |
+
if path.endswith("/MatMul_output_0"):
|
| 79 |
+
path = path[: -len("/MatMul_output_0")]
|
| 80 |
+
# Drop leading "vector_estimator/" if present
|
| 81 |
+
if path.startswith("vector_estimator/"):
|
| 82 |
+
path = path[len("vector_estimator/"):]
|
| 83 |
+
return path.replace("/", ".") + ".weight"
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def convert_onnx_to_mlx(onnx_path: str | Path) -> Dict[str, mx.array]:
|
| 87 |
+
"""Load an ONNX model and return all weights as ``{clean_name: mx.array}``.
|
| 88 |
+
|
| 89 |
+
Combines named initializers and MatMul-only weights into a single dict ready
|
| 90 |
+
for ``model.load_weights(...)``.
|
| 91 |
+
"""
|
| 92 |
+
import onnx
|
| 93 |
+
from onnx import numpy_helper
|
| 94 |
+
|
| 95 |
+
model = onnx.load(str(onnx_path))
|
| 96 |
+
|
| 97 |
+
# Build initializer name → numpy array map (in-memory once)
|
| 98 |
+
inits: Dict[str, np.ndarray] = {
|
| 99 |
+
init.name: numpy_helper.to_array(init) for init in model.graph.initializer
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
out: Dict[str, mx.array] = {}
|
| 103 |
+
|
| 104 |
+
# Stage 1: named initializers
|
| 105 |
+
for name, arr in inits.items():
|
| 106 |
+
if not _is_named_weight(name):
|
| 107 |
+
continue
|
| 108 |
+
clean = _strip_prefix(name)
|
| 109 |
+
arr = _convert_named(clean, arr)
|
| 110 |
+
out[clean] = mx.array(arr)
|
| 111 |
+
|
| 112 |
+
# Stage 2: anonymous MatMul weights, recovered via the graph
|
| 113 |
+
for node in model.graph.node:
|
| 114 |
+
if node.op_type != "MatMul":
|
| 115 |
+
continue
|
| 116 |
+
if len(node.input) < 2:
|
| 117 |
+
continue
|
| 118 |
+
# The weight is conventionally the second operand
|
| 119 |
+
weight_name = node.input[1]
|
| 120 |
+
if weight_name not in inits:
|
| 121 |
+
continue
|
| 122 |
+
# Skip if it's already named structurally (shouldn't happen here)
|
| 123 |
+
if _is_named_weight(weight_name):
|
| 124 |
+
continue
|
| 125 |
+
# Look up the structured name from the MatMul output path
|
| 126 |
+
if len(node.output) < 1:
|
| 127 |
+
continue
|
| 128 |
+
clean = _matmul_output_to_clean_name(node.output[0])
|
| 129 |
+
# ONNX MatMul stores W as (in, out); MLX Linear expects (out, in)
|
| 130 |
+
arr = inits[weight_name]
|
| 131 |
+
if arr.ndim == 2:
|
| 132 |
+
arr = arr.T
|
| 133 |
+
out[clean] = mx.array(arr)
|
| 134 |
+
|
| 135 |
+
if not out:
|
| 136 |
+
raise RuntimeError(f"no weights extracted from {onnx_path}")
|
| 137 |
+
return out
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def save_safetensors(
|
| 141 |
+
onnx_path: str | Path,
|
| 142 |
+
output_path: str | Path,
|
| 143 |
+
) -> Dict[str, Tuple[int, ...]]:
|
| 144 |
+
"""Convert an ONNX file to MLX safetensors. Returns a {name: shape} map."""
|
| 145 |
+
weights = convert_onnx_to_mlx(onnx_path)
|
| 146 |
+
output_path = Path(output_path)
|
| 147 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 148 |
+
mx.save_safetensors(str(output_path), weights)
|
| 149 |
+
return {k: tuple(v.shape) for k, v in weights.items()}
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
__all__ = ["convert_onnx_to_mlx", "save_safetensors"]
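The shape rules listed in the module docstring can be checked without loading any model by applying them to shapes alone. The sketch below mirrors the conditions in `_convert_named` and the MatMul transpose on plain tuples. `converted_shape` is a hypothetical helper, and the concrete shapes in the demo are illustrative values rather than entries read from a checkpoint.

```python
def converted_shape(name, shape, from_matmul=False):
    """Return the MLX-side shape for an ONNX initializer of the given shape."""
    shape = tuple(shape)

    # Anonymous MatMul weight: ONNX (in, out) -> MLX Linear (out, in).
    if from_matmul:
        return shape[::-1] if len(shape) == 2 else shape

    # Depthwise Conv1d: ONNX (C, 1, K) -> MLX (C, K, 1).
    if (name.endswith(".dwconv.weight") and len(shape) == 3
            and shape[1] == 1 and shape[2] != 1):
        shape = (shape[0], shape[2], shape[1])

    # Pointwise k=1 conv or proj net: ONNX (out, in, 1) -> MLX (out, in).
    pointwise = (".pwconv1.weight", ".pwconv2.weight", ".net.weight")
    if name.endswith(pointwise) and len(shape) == 3 and shape[-1] == 1:
        shape = shape[:-1]

    # Layer-scale gamma: ONNX (1, C, 1) -> MLX (C,).
    if (name.endswith(".gamma") and len(shape) == 3
            and shape[0] == 1 and shape[2] == 1):
        shape = (shape[1],)

    return shape


if __name__ == "__main__":
    print(converted_shape("blk.dwconv.weight", (512, 1, 7)))
    print(converted_shape("blk.pwconv1.weight", (2048, 512, 1)))
    print(converted_shape("blk.gamma", (1, 512, 1)))
    print(converted_shape("attn.W_query.linear.weight", (256, 1024), from_matmul=True))
```

A shape-only check like this is a cheap way to compare the converter's output against the parameter shapes a module declares before calling `load_weights`.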
|
unicode_indexer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
voice_styles/F1.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
voice_styles/F2.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
voice_styles/F3.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
voice_styles/F4.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
voice_styles/F5.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
voice_styles/M1.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
voice_styles/M2.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
voice_styles/M3.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
voice_styles/M4.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
voice_styles/M5.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
weights/duration_predictor.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cd473acb6e0ac27426084488ccb3b3cc184e70d05db90897e2b892846db5dcb3
|
| 3 |
+
size 3470807
|
weights/text_encoder.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9df20bb79496718b36d2c0fc37636d3f78d6ef751b2899ff6dfeb975ae737ada
|
| 3 |
+
size 36022466
|
weights/vector_estimator.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2359240f2dcaee03b4800102aa0bea00223d2867ab752ef01af2b1cfaf92f3a6
|
| 3 |
+
size 256053073
|
weights/vocoder.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b2ec31ab7c554f6e15b9a6780554b5d3502345de7848b310966bfb4e1ea4e526
|
| 3 |
+
size 101364763
|
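The `weights/*.safetensors` entries in this commit are Git LFS pointer files: three lines giving the spec version, a SHA-256 object id, and the byte size of the real file. A downloaded weight file can be checked against its pointer with the standard library alone. The following is a minimal sketch, assuming the three-line pointer layout shown above; `parse_lfs_pointer`, `matches_pointer`, and `file_matches_pointer` are hypothetical helper names, not part of the package.

```python
import hashlib


def parse_lfs_pointer(text):
    """Parse a Git LFS pointer file into {'version', 'oid', 'size'}."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.strip().partition(" ")
        fields[key] = value
    algo, _, digest = fields["oid"].partition(":")
    if algo != "sha256":
        raise ValueError(f"unsupported oid algorithm: {algo}")
    return {"version": fields["version"], "oid": digest, "size": int(fields["size"])}


def matches_pointer(data, pointer):
    """True if `data` (bytes) has the size and SHA-256 recorded in the pointer."""
    return (len(data) == pointer["size"]
            and hashlib.sha256(data).hexdigest() == pointer["oid"])


def file_matches_pointer(path, pointer, chunk_size=1 << 20):
    """Streaming variant for large files: hash the file in fixed-size chunks."""
    digest = hashlib.sha256()
    size = 0
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
            size += len(chunk)
    return size == pointer["size"] and digest.hexdigest() == pointer["oid"]


if __name__ == "__main__":
    pointer_text = """version https://git-lfs.github.com/spec/v1
oid sha256:b2ec31ab7c554f6e15b9a6780554b5d3502345de7848b310966bfb4e1ea4e526
size 101364763
"""
    print(parse_lfs_pointer(pointer_text)["size"])
```

The streaming variant avoids reading a file of a few hundred megabytes into memory at once, which matters for the larger weight files.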