ambassadia commited on
Commit
69dee76
·
verified ·
1 Parent(s): 7cc65c1

v0.1.0 — initial release (MLX-native Supertonic 3, ~x100 RTF on M4)

Browse files
Files changed (47) hide show
  1. .gitattributes +6 -0
  2. LICENSE +209 -0
  3. LICENSE-CODE +202 -0
  4. NOTICE +39 -0
  5. README.md +239 -0
  6. bench_results.csv +7 -0
  7. conversion_report.json +226 -0
  8. examples/quickstart.py +23 -0
  9. pyproject.toml +43 -0
  10. samples/de_M2.wav +3 -0
  11. samples/en_F1_short.wav +3 -0
  12. samples/en_M1_long.wav +3 -0
  13. samples/es_M3.wav +3 -0
  14. samples/fr_F2.wav +3 -0
  15. samples/ja_F3.wav +3 -0
  16. src/supertonic_3_mlx/__init__.py +51 -0
  17. src/supertonic_3_mlx/__pycache__/__init__.cpython-312.pyc +0 -0
  18. src/supertonic_3_mlx/__pycache__/_config.cpython-312.pyc +0 -0
  19. src/supertonic_3_mlx/__pycache__/_nn_wrappers.cpython-312.pyc +0 -0
  20. src/supertonic_3_mlx/__pycache__/duration_predictor.cpython-312.pyc +0 -0
  21. src/supertonic_3_mlx/__pycache__/pipeline.cpython-312.pyc +0 -0
  22. src/supertonic_3_mlx/__pycache__/text_encoder.cpython-312.pyc +0 -0
  23. src/supertonic_3_mlx/__pycache__/vector_estimator.cpython-312.pyc +0 -0
  24. src/supertonic_3_mlx/__pycache__/vocoder.cpython-312.pyc +0 -0
  25. src/supertonic_3_mlx/_config.py +58 -0
  26. src/supertonic_3_mlx/_nn_wrappers.py +50 -0
  27. src/supertonic_3_mlx/duration_predictor.py +347 -0
  28. src/supertonic_3_mlx/pipeline.py +545 -0
  29. src/supertonic_3_mlx/text_encoder.py +382 -0
  30. src/supertonic_3_mlx/vector_estimator.py +765 -0
  31. src/supertonic_3_mlx/vocoder.py +304 -0
  32. src/supertonic_3_mlx/weights.py +152 -0
  33. unicode_indexer.json +0 -0
  34. voice_styles/F1.json +0 -0
  35. voice_styles/F2.json +0 -0
  36. voice_styles/F3.json +0 -0
  37. voice_styles/F4.json +0 -0
  38. voice_styles/F5.json +0 -0
  39. voice_styles/M1.json +0 -0
  40. voice_styles/M2.json +0 -0
  41. voice_styles/M3.json +0 -0
  42. voice_styles/M4.json +0 -0
  43. voice_styles/M5.json +0 -0
  44. weights/duration_predictor.safetensors +3 -0
  45. weights/text_encoder.safetensors +3 -0
  46. weights/vector_estimator.safetensors +3 -0
  47. weights/vocoder.safetensors +3 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ samples/de_M2.wav filter=lfs diff=lfs merge=lfs -text
37
+ samples/en_F1_short.wav filter=lfs diff=lfs merge=lfs -text
38
+ samples/en_M1_long.wav filter=lfs diff=lfs merge=lfs -text
39
+ samples/es_M3.wav filter=lfs diff=lfs merge=lfs -text
40
+ samples/fr_F2.wav filter=lfs diff=lfs merge=lfs -text
41
+ samples/ja_F3.wav filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BigScience Open RAIL-M License
2
+ dated August 18, 2022
3
+
4
+ Section I: PREAMBLE
5
+
6
+ This Open RAIL-M License was created by BigScience, a collaborative open innovation project aimed at
7
+ the responsible development and use of large multilingual datasets and Large Language Models
8
+ (“LLMs”). While a similar license was originally designed for the BLOOM model, we decided to adapt it
9
+ and create this license in order to propose a general open and responsible license applicable to other
10
+ machine learning based AI models (e.g. multimodal generative models).
11
+ In short, this license strives for both the open and responsible downstream use of the accompanying
12
+ model. When it comes to the open character, we took inspiration from open source permissive licenses
13
+ regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based
14
+ restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be
15
+ able to enforce the license in case potential misuses of the Model may occur. Even though downstream
16
+ derivative versions of the model could be released under different licensing terms, the latter will always
17
+ have to include - at minimum - the same use-based restrictions as the ones in the original license (this
18
+ license).
19
+ The development and use of artificial intelligence (“AI”), does not come without concerns. The world has
20
+ witnessed how AI techniques may, in some instances, become risky for the public in general. These risks
21
+ come in many forms, from racial discrimination to the misuse of sensitive information.
22
+ BigScience believes in the intersection between open and responsible AI development; thus, this License
23
+ aims to strike a balance between both in order to enable responsible open-science in the field of AI.
24
+ This License governs the use of the model (and its derivatives) and is informed by the model card
25
+ associated with the model.
26
+
27
+ NOW THEREFORE, You and Licensor agree as follows:
28
+
29
+ 1. Definitions
30
+ (a) "License" means the terms and conditions for use, reproduction, and Distribution as defined in
31
+ this document.
32
+ (b) “Data” means a collection of information and/or content extracted from the dataset used with the
33
+ Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under
34
+ this License.
35
+ (c)“Output” means the results of operating a Model as embodied in informational content resulting
36
+ therefrom.
37
+ (d)“Model” means any accompanying machine-learning based assemblies (including checkpoints),
38
+ consisting of learnt weights, parameters (including optimizer states), corresponding to the model
39
+ architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or
40
+ in part on the Data, using the Complementary Material.
41
+ (e) “Derivatives of the Model” means all modifications to the Model, works based on the Model, or any
42
+ other model which is created or initialized by transfer of patterns of the weights, parameters,
43
+ activations or output of the Model, to the other model, in order to cause the other model to perform
44
+ similarly to the Model, including - but not limited to - distillation methods entailing the use of
45
+ intermediate data representations or methods based on the generation of synthetic data by the Model
46
+ for training the other model.
47
+ (f)“Complementary Material” means the accompanying source code and scripts used to define,
48
+ run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if
49
+ any. This includes any accompanying documentation, tutorials, examples, etc, if any.
50
+ (g) “Distribution” means any transmission, reproduction, publication or other sharing of the Model or
51
+ Derivatives of the Model to a third party, including providing the Model as a hosted service made
52
+ available by electronic or other remote means - e.g. API-based or web access.
53
+ (h) “Licensor” means the copyright owner or entity authorized by the copyright owner that is
54
+ granting the License, including the persons or entities that may have rights in the Model and/or
55
+ distributing the Model.
56
+ (i) "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this
57
+ License and/or making use of the Model for whichever purpose and in any field of use, including
58
+ usage of the Model in an end-use application - e.g. chatbot, translator, image generator.
59
+ (j) “Third Parties” means individuals or legal entities that are not under common control with
60
+ Licensor or You.
61
+ (k) "Contribution" means any work of authorship, including the original version of the Model and
62
+ any modifications or additions to that Model or Derivatives of the Model thereof, that is
63
+ intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an
64
+ individual or Legal Entity authorized to submit on behalf of the copyright owner. For the
65
+ purposes of this definition,
66
+ “submitted” means any form of electronic, verbal, or written
67
+ communication sent to the Licensor or its representatives, including but not limited to
68
+ communication on electronic mailing lists, source code control systems, and issue tracking
69
+ systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and
70
+ improving the Model, but excluding communication that is conspicuously marked or otherwise
71
+ designated in writing by the copyright owner as "Not a Contribution."
72
+ (l) "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a
73
+ Contribution has been received by Licensor and subsequently incorporated within the Model.
74
+
75
+
76
+ Section II: INTELLECTUAL PROPERTY RIGHTS
77
+
78
+ Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary
79
+ Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.
80
+
81
+ 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor
82
+ hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the
83
+ Complementary Material, the Model, and Derivatives of the Model.
84
+
85
+ 3. Grant of Patent License. Subject to the terms and conditions of this License and where and as
86
+ applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge,
87
+ royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer
88
+ to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such
89
+ license applies only to those patent claims licensable by such Contributor that are necessarily infringed by
90
+ their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such
91
+ Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim
92
+ or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution
93
+ incorporated within the Model and/or Complementary Material constitutes direct or contributory patent
94
+ infringement, then any patent licenses granted to You under this License for the Model and/or Work shall
95
+ terminate as of the date such litigation is asserted or filed.
96
+ Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
97
+
98
+ 4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g.
99
+ software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof
100
+ in any medium, with or without modifications, provided that You meet the following conditions:
101
+
102
+ a. Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision
103
+ by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the
104
+ Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to,
105
+ that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply
106
+ to the use of Complementary Material.
107
+
108
+ b. You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this
109
+ License;
110
+
111
+ c. You must cause any modified files to carry prominent notices stating that You changed the files;
112
+
113
+ d. You must retain all copyright, patent, trademark, and attribution notices excluding those notices
114
+ that do not pertain to any part of the Model, Derivatives of the Model.
115
+ You may add Your own copyright statement to Your modifications and may provide additional or
116
+ different license terms and conditions - respecting paragraph 4.a.
117
+ - for use, reproduction, or Distribution
118
+ of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use,
119
+ reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
120
+
121
+ 5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions.
122
+ Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You
123
+ may use the Model subject to this License, including only for lawful purposes and in accordance with the
124
+ License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or
125
+ reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model
126
+ to comply with the terms of this paragraph (paragraph 5).
127
+
128
+ 6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You
129
+ generate using the Model. You are accountable for the Output you generate and its subsequent uses. No
130
+ use of the output can contravene any provision as stated in the License.
131
+
132
+ Section IV: OTHER PROVISIONS
133
+
134
+ 7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the
135
+ right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model
136
+ through electronic means, or modify the Output of the Model based on updates. You shall undertake
137
+ reasonable efforts to use the latest version of the Model.
138
+
139
+ 8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks,
140
+ trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the
141
+ parties; and any rights not expressly granted herein are reserved by the Licensors.
142
+
143
+ 9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides
144
+ the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS
145
+ IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied,
146
+ including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT,
147
+ MERCHANTABILITY , or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for
148
+ determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the
149
+ Complementary Material and assume any risks associated with Your exercise of permissions under this
150
+ License.
151
+
152
+ 10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence),
153
+ contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or
154
+ agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect,
155
+ special, incidental, or consequential damages of any character arising as a result of this License or out of
156
+ the use or inability to use the Model and the Complementary Material (including but not limited to
157
+ damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other
158
+ commercial damages or losses), even if such Contributor has been advised of the possibility of such
159
+ damages.
160
+
161
+ 11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the
162
+ Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance
163
+ of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License.
164
+ However, in accepting such obligations, You may act only on Your own behalf and on Your sole
165
+ responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and
166
+ hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor
167
+ by reason of your accepting any such warranty or additional liability.
168
+
169
+ 12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining
170
+ provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
171
+
172
+ END OF TERMS AND CONDITIONS
173
+
174
+ Attachment A
175
+
176
+ Use Restrictions
177
+
178
+ You agree not to use the Model or Derivatives of the Model:
179
+ (a) In any way that violates any applicable national, federal, state, local or international law
180
+ or regulation;
181
+ (b) For the purpose of exploiting, harming or attempting to exploit or harm minors in any
182
+ way;
183
+ (c) To generate or disseminate verifiably false information and/or content with the purpose of
184
+ harming others;
185
+ (d) To generate or disseminate personal identifiable information that can be used to harm an
186
+ individual;
187
+ (e) To generate or disseminate information and/or content (e.g. images, code, posts, articles),
188
+ and place the information and/or content in any context (e.g. bot generating tweets)
189
+ without expressly and intelligibly disclaiming that the information and/or content is
190
+ machine generated;
191
+ (f) To defame, disparage or otherwise harass others;
192
+ (g) To impersonate or attempt to impersonate (e.g. deepfakes) others without their consent;
193
+ (h) For fully automated decision making that adversely impacts an individual’s legal rights or
194
+ otherwise creates or modifies a binding, enforceable obligation;
195
+ (i) For any use intended to or which has the effect of discriminating against or harming
196
+ individuals or groups based on online or offline social behavior or known or predicted
197
+ personal or personality characteristics;
198
+ (j) To exploit any of the vulnerabilities of a specific group of persons based on their age,
199
+ social, physical or mental characteristics, in order to materially distort the behavior of a
200
+ person pertaining to that group in a manner that causes or is likely to cause that person or
201
+ another person physical or psychological harm;
202
+ (k) For any use intended to or which has the effect of discriminating against individuals or
203
+ groups based on legally protected characteristics or categories;
204
+ (l) To provide medical advice and medical results interpretation;
205
+ (m) To generate or disseminate information for the purpose to be used for administration of
206
+ justice, law enforcement, immigration or asylum processes, such as predicting an
207
+ individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal
208
+ relationships between assertions made in documents, indiscriminate and
209
+ arbitrarily-targeted use).
LICENSE-CODE ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright [yyyy] [name of copyright owner]
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
NOTICE ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ supertonic-3-mlx
2
+ ================
3
+
4
+ This release is a derivative of the upstream Supertone Supertonic 3
5
+ text-to-speech model and consists of two artefact classes governed by
6
+ two different licenses:
7
+
8
+ 1. The model weights (under ./weights/*.safetensors) are released under
9
+ the BigScience Open RAIL-M License. The full text is in ./LICENSE and
10
+ was copied verbatim from
11
+ https://huggingface.co/Supertone/supertonic-3/blob/main/LICENSE
12
+ The Attachment A use restrictions (Section 5 + Attachment A clauses
13
+ (a)–(m)) apply to all downstream use of the model and of any output
14
+ generated by the model.
15
+
16
+ 2. The MLX port code (under ./src/supertonic_3_mlx/) is released under
17
+ the Apache License, Version 2.0. The full text is in ./LICENSE-CODE.
18
+
19
+ Attribution and modifications statement (BigScience Open RAIL-M Section 4.c):
20
+
21
+ Copyright (c) 2026 Supertone Inc. — original model weights and reference
22
+ Python/ONNX implementation. Distributed at
23
+ https://huggingface.co/Supertone/supertonic-3
24
+ Copyright (c) 2026 Olivier Dupont — MLX-native port code, weight format
25
+ conversion (ONNX → safetensors via the 3-stage extractor in
26
+ ``src/supertonic_3_mlx/pipeline.py:_convert_onnx``), and pipeline
27
+ optimisations (``mx.compile`` of the CFG Euler loop, cross-attention
28
+ K/V cache shared across the 5 Euler steps). Distributed at
29
+ https://huggingface.co/ambassadia/supertonic-3-mlx
30
+
31
+ The MLX port does not modify the model's learned parameters in any
32
+ semantic sense — the only weight-level transformation is a tensor-shape
33
+ re-layout to match the MLX memory model (e.g. depthwise Conv1d
34
+ ``(C, 1, K)`` → ``(C, K, 1)``). Bit-identical audio output to the
35
+ upstream ONNX Runtime reference is preserved up to FP32 accumulation
36
+ noise (cosine ≥ 0.98 on the full pipeline, cosine = 1.00 on the vocoder).
37
+
38
+ No use of the Supertone trademarks, logos, or trade dress is asserted or
39
+ permitted by this release (BigScience Open RAIL-M Section 8).
README.md ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: openrail
3
+ license_link: LICENSE
4
+ language:
5
+ - en
6
+ - fr
7
+ - de
8
+ - es
9
+ - it
10
+ - pt
11
+ - ja
12
+ - ko
13
+ - zh
14
+ - ru
15
+ - pl
16
+ - nl
17
+ - tr
18
+ - ar
19
+ - hi
20
+ - vi
21
+ - th
22
+ - id
23
+ - cs
24
+ - ro
25
+ - hu
26
+ - el
27
+ - da
28
+ - sv
29
+ - fi
30
+ - no
31
+ - he
32
+ - uk
33
+ - bg
34
+ - hr
35
+ - sk
36
+ pipeline_tag: text-to-speech
37
+ tags:
38
+ - mlx
39
+ - apple-silicon
40
+ - tts
41
+ - text-to-speech
42
+ - speech-synthesis
43
+ - supertonic
44
+ - multilingual
45
+ - flow-matching
46
+ library_name: supertonic-3-mlx
47
+ base_model: Supertone/supertonic-3
48
+ inference: false
49
+ ---
50
+
51
+ # Supertonic 3 — MLX-native
52
+
53
+ **31-language text-to-speech, ~x100 realtime on Apple Silicon.**
54
+ Native MLX port of [Supertone/supertonic-3](https://huggingface.co/Supertone/supertonic-3),
55
+ runs the full flow-matching + classifier-free-guidance pipeline (DurationPredictor →
56
+ TextEncoder → 24-block VectorEstimator (5 Euler steps) → 10-block Vocos vocoder)
57
+ without ONNX, CoreML or any C++ runtime — only MLX + NumPy.
58
+
59
+ ## Install
60
+
61
+ ```bash
62
+ pip install supertonic-3-mlx
63
+ ```
64
+
65
+ The package depends only on `mlx` and `numpy`. The optional `[hub]` extra adds
66
+ `huggingface_hub` for one-line weight downloads.
67
+
68
+ ## Quickstart
69
+
70
+ ```python
71
+ from supertonic_3_mlx import Pipeline
72
+
73
+ pipe = Pipeline.from_pretrained("ambassadia/supertonic-3-mlx")
74
+ wav = pipe.generate("Hello world from Apple Silicon.", voice="F1", lang="en")
75
+
76
+ # wav is a 1-D numpy.float32 array at 44.1 kHz
77
+ import soundfile as sf
78
+ sf.write("hello.wav", wav, pipe.sample_rate)
79
+ ```
80
+
81
+ The first call downloads the ~400 MB weight bundle into your Hugging Face cache.
82
+ Subsequent calls re-use the cached weights and cold-start in ~11 ms on M4.
83
+
84
+ ## Audio samples
85
+
86
+ Six languages, mix of male / female voices, mix of short and long utterances —
87
+ all generated by the MLX pipeline at the wall times reported below.
88
+
89
+ <audio controls src="samples/en_F1_short.wav"></audio> &nbsp; **EN · F1 · 2.79 s** —
90
+ "Hello world from Apple Silicon. Supertonic 3 runs at one hundred times real time."
91
+
92
+ <audio controls src="samples/en_M1_long.wav"></audio> &nbsp; **EN · M1 · 3.90 s** —
93
+ "A gentle breeze moved through the open window while the children, still half-asleep, listened to the distant sound of the harbour bells."
94
+
95
+ <audio controls src="samples/fr_F2.wav"></audio> &nbsp; **FR · F2 · 3.41 s** —
96
+ "Bonjour, ceci est un test de synthèse vocale en français. Le modèle gère trente-et-une langues sur une puce M4."
97
+
98
+ <audio controls src="samples/de_M2.wav"></audio> &nbsp; **DE · M2 · 3.69 s** —
99
+ "Guten Morgen. Dieses Modell läuft komplett auf Apple Silicon, ohne ONNX und ohne CoreML, in reinem MLX."
100
+
101
+ <audio controls src="samples/ja_F3.wav"></audio> &nbsp; **JA · F3 · 1.46 s** —
102
+ "こんにちは。これはアップルシリコン上でMLXを使ったテストです。"
103
+
104
+ <audio controls src="samples/es_M3.wav"></audio> &nbsp; **ES · M3 · 2.86 s** —
105
+ "Hola, esto es una prueba de síntesis de voz en español ejecutada en tiempo real sobre Apple Silicon."
106
+
107
+ ## Benchmarks (Apple M4, FP32, median of 3)
108
+
109
+ | Sample | Duration | MLX wall | RTF | ONNX SDK | Speedup |
110
+ |-----------------|---------:|----------:|----------:|---------:|--------:|
111
+ | EN · F1 · short | 2.79 s | 36.6 ms | **x76** | 1005 ms | **28 ×**|
112
+ | EN · M1 · long | 3.90 s | 38.4 ms | **x102** | 1356 ms | **35 ×**|
113
+ | FR · F2 | 3.41 s | 37.9 ms | **x90** | 1196 ms | **32 ×**|
114
+ | DE · M2 | 3.69 s | 38.1 ms | **x97** | 1314 ms | **35 ×**|
115
+ | JA · F3 | 1.46 s | 32.1 ms | **x46** | 848 ms | **26 ×**|
116
+ | ES · M3 | 2.86 s | 37.0 ms | **x77** | 1002 ms | **27 ×**|
117
+
118
+ Raw numbers are in [`bench_results.csv`](bench_results.csv) (regenerable via
119
+ the source repo's
120
+ [`tools/supertonic3_samples_and_bench.py`](https://gitea.tavportal.com/olivier/MLX_CONVERTOR/src/branch/feat/platform-abc/tools/supertonic3_samples_and_bench.py)).
121
+
122
+ Reference comparison: the CoreML build of the same model on the same hardware
123
+ runs at ~x27 realtime. The MLX port is **~2-4× faster** end-to-end while
124
+ remaining bit-identical to the ONNX Runtime reference on the vocoder
125
+ (cosine 1.00) and at cosine ≥ 0.98 on the full estimator output.
126
+
127
+ ## Voices
128
+
129
+ 10 preset voices — five female (`F1`–`F5`) and five male (`M1`–`M5`). The
130
+ `voice_styles/` directory contains both `style_ttl` (50×256 latent style for
131
+ the audio path) and `style_dp` (8×16 style for the duration head) for each
132
+ voice. Pass the voice name as the `voice=` kwarg to `Pipeline.generate`.
133
+
134
+ ## Languages
135
+
136
+ 31 languages supported. Pass the ISO 639-1 code as the `lang=` kwarg:
137
+ `en` `fr` `de` `es` `it` `pt` `ja` `ko` `zh` `ru` `pl` `nl` `tr` `ar` `hi`
138
+ `vi` `th` `id` `cs` `ro` `hu` `el` `da` `sv` `fi` `no` `he` `uk` `bg` `hr` `sk`.
139
+
140
+ ## Architecture (short)
141
+
142
+ Four sub-models, all in `weights/*.safetensors`:
143
+
144
+ | Sub-model | Role | Params | Size |
145
+ |----------------------|-------------------------------------|--------|---------|
146
+ | `vector_estimator` | 24-block CFG flow-matching velocity | ~64 M | 256 MB |
147
+ | `text_encoder` | Character → 256-D text embedding | ~9 M | 36 MB |
148
+ | `duration_predictor` | Text → seconds | ~1 M | 3.5 MB |
149
+ | `vocoder` | Latent (B,144,T) → 44.1 kHz wav | ~25 M | 101 MB |
150
+
151
+ The pipeline runs **exactly 5 Euler steps** with classifier-free guidance
152
+ (`4×cond − 3×uncond`). This schedule is trained-in: reducing the step count
153
+ or disabling CFG produces an essentially uncorrelated waveform (verified
154
+ empirically — see the `bench_n_steps.py` script in the source repo).
155
+
156
+ ## Loading from a local snapshot
157
+
158
+ Three layouts are auto-detected by `Pipeline.from_pretrained`:
159
+
160
+ 1. **Hugging Face repo id** (e.g. `"ambassadia/supertonic-3-mlx"`) — auto-download
161
+ 2. **Local path containing `weights/`** (this layout) — fastest cold-load
162
+ 3. **Local path containing `onnx/`** (upstream snapshot) — converts at load time
163
+
164
+ ## License
165
+
166
+ This release combines two artefact classes under two distinct licenses:
167
+
168
+ - **Model weights** (`weights/*.safetensors`) — **BigScience Open RAIL-M**.
169
+ See [`LICENSE`](LICENSE) for the full text. The Attachment A use
170
+ restrictions are reproduced below and apply to all downstream use of the
171
+ model and of generated audio.
172
+ - **Port code** (`src/supertonic_3_mlx/`) — **Apache License 2.0**. See
173
+ [`LICENSE-CODE`](LICENSE-CODE).
174
+
175
+ See [`NOTICE`](NOTICE) for the modifications statement and the upstream
176
+ attribution.
177
+
178
+ ### OpenRAIL-M Attachment A — use restrictions
179
+
180
+ You agree not to use the model or derivatives:
181
+
182
+ (a) In any way that violates any applicable national, federal, state, local or
183
+ international law or regulation.
184
+
185
+ (b) For the purpose of exploiting, harming or attempting to exploit or harm
186
+ minors in any way.
187
+
188
+ (c) To generate or disseminate verifiably false information and/or content
189
+ with the purpose of harming others.
190
+
191
+ (d) To generate or disseminate personal identifiable information that can be
192
+ used to harm an individual.
193
+
194
+ (e) To generate or disseminate information and/or content (e.g. images, code,
195
+ posts, articles), and place the information and/or content in any context
196
+ (e.g. bot generating tweets) **without expressly and intelligibly disclaiming
197
+ that the information and/or content is machine generated**.
198
+
199
+ (f) To defame, disparage or otherwise harass others.
200
+
201
+ (g) To impersonate or attempt to impersonate (e.g. **deepfakes**) others
202
+ without their consent.
203
+
204
+ (h) For fully automated decision making that adversely impacts an individual's
205
+ legal rights or otherwise creates or modifies a binding, enforceable obligation.
206
+
207
+ (i) For any use intended to or which has the effect of discriminating against
208
+ or harming individuals or groups based on online or offline social behavior or
209
+ known or predicted personal or personality characteristics.
210
+
211
+ (j) To exploit any of the vulnerabilities of a specific group of persons based
212
+ on their age, social, physical or mental characteristics, in order to materially
213
+ distort the behavior of a person pertaining to that group in a manner that
214
+ causes or is likely to cause that person or another person physical or
215
+ psychological harm.
216
+
217
+ (k) For any use intended to or which has the effect of discriminating against
218
+ individuals or groups based on legally protected characteristics or categories.
219
+
220
+ (l) **To provide medical advice and medical results interpretation.**
221
+
222
+ (m) To generate or disseminate information for the purpose to be used for
223
+ administration of justice, law enforcement, immigration or asylum processes,
224
+ such as predicting an individual will commit fraud/crime commitment.
225
+
226
+ ## Citation
227
+
228
+ ```bibtex
229
+ @misc{supertonic3-mlx,
230
+ title = {Supertonic 3 MLX: native Apple Silicon port of Supertone's multilingual TTS},
231
+ author = {Dupont, Olivier},
232
+ year = {2026},
233
+ url = {https://huggingface.co/ambassadia/supertonic-3-mlx},
234
+ note = {Derivative of Supertone/supertonic-3 (https://huggingface.co/Supertone/supertonic-3)}
235
+ }
236
+ ```
237
+
238
+ Please also cite the upstream Supertone Supertonic 3 model when using this
239
+ port.
bench_results.csv ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ filename,language,voice,text,duration_s,mlx_ms_median,rtf_mlx,onnx_ms_median,rtf_onnx,speedup_mlx_over_onnx
2
+ samples/en_F1_short.wav,en,F1,Hello world from Apple Silicon. Supertonic 3 runs at one hundred times real time.,2.786,36.6,76.2,1004.7,2.8,27.5
3
+ samples/en_M1_long.wav,en,M1,"A gentle breeze moved through the open window while the children, still half-asleep, listened to the distant sound of the harbour bells.",3.901,38.4,101.7,1356.0,2.9,35.3
4
+ samples/fr_F2.wav,fr,F2,"Bonjour, ceci est un test de synthèse vocale en français. Le modèle gère trente-et-une langues sur une puce M4.",3.413,37.9,90.1,1195.6,2.9,31.6
5
+ samples/de_M2.wav,de,M2,"Guten Morgen. Dieses Modell läuft komplett auf Apple Silicon, ohne ONNX und ohne CoreML, in reinem MLX.",3.692,38.1,96.9,1313.9,2.8,34.5
6
+ samples/ja_F3.wav,ja,F3,こんにちは。これはアップルシリコン上でMLXを使ったテストです。,1.463,32.1,45.6,848.4,1.7,26.4
7
+ samples/es_M3.wav,es,M3,"Hola, esto es una prueba de síntesis de voz en español ejecutada en tiempo real sobre Apple Silicon.",2.856,37.0,77.2,1002.1,2.9,27.1
conversion_report.json ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "models": [
3
+ {
4
+ "model": "VectorEstimator",
5
+ "onnx": "/tmp/supertonic3/model/onnx/vector_estimator.onnx",
6
+ "safetensors": "/Users/transcrilive/MLX_CONVERTOR/sub-projects/supertonic3-mlx/hf_release/weights/vector_estimator.safetensors",
7
+ "bytes": 256053073,
8
+ "sha256": "2359240f2dcaee03b4800102aa0bea00223d2867ab752ef01af2b1cfaf92f3a6",
9
+ "weights_kept": 351,
10
+ "weights_dropped": 120,
11
+ "dropped_detail": {
12
+ "tts.ae.vector_field.proj_in.net.weight": "not-in-model",
13
+ "tts.ae.vector_field.main_blocks.0.convnext.0.pwconv1.weight": "not-in-model",
14
+ "tts.ae.vector_field.main_blocks.0.convnext.0.pwconv1.bias": "not-in-model",
15
+ "tts.ae.vector_field.main_blocks.0.convnext.0.pwconv2.weight": "not-in-model",
16
+ "tts.ae.vector_field.main_blocks.0.convnext.0.pwconv2.bias": "not-in-model",
17
+ "tts.ae.vector_field.main_blocks.0.convnext.1.pwconv1.weight": "not-in-model",
18
+ "tts.ae.vector_field.main_blocks.0.convnext.1.pwconv1.bias": "not-in-model",
19
+ "tts.ae.vector_field.main_blocks.0.convnext.1.pwconv2.weight": "not-in-model",
20
+ "tts.ae.vector_field.main_blocks.0.convnext.1.pwconv2.bias": "not-in-model",
21
+ "tts.ae.vector_field.main_blocks.0.convnext.2.pwconv1.weight": "not-in-model",
22
+ "tts.ae.vector_field.main_blocks.0.convnext.2.pwconv1.bias": "not-in-model",
23
+ "tts.ae.vector_field.main_blocks.0.convnext.2.pwconv2.weight": "not-in-model",
24
+ "tts.ae.vector_field.main_blocks.0.convnext.2.pwconv2.bias": "not-in-model",
25
+ "tts.ae.vector_field.main_blocks.0.convnext.3.pwconv1.weight": "not-in-model",
26
+ "tts.ae.vector_field.main_blocks.0.convnext.3.pwconv1.bias": "not-in-model",
27
+ "tts.ae.vector_field.main_blocks.0.convnext.3.pwconv2.weight": "not-in-model",
28
+ "tts.ae.vector_field.main_blocks.0.convnext.3.pwconv2.bias": "not-in-model",
29
+ "tts.ae.vector_field.main_blocks.2.convnext.0.pwconv1.weight": "not-in-model",
30
+ "tts.ae.vector_field.main_blocks.2.convnext.0.pwconv1.bias": "not-in-model",
31
+ "tts.ae.vector_field.main_blocks.2.convnext.0.pwconv2.weight": "not-in-model",
32
+ "tts.ae.vector_field.main_blocks.2.convnext.0.pwconv2.bias": "not-in-model",
33
+ "tts.ae.vector_field.main_blocks.4.convnext.0.pwconv1.weight": "not-in-model",
34
+ "tts.ae.vector_field.main_blocks.4.convnext.0.pwconv1.bias": "not-in-model",
35
+ "tts.ae.vector_field.main_blocks.4.convnext.0.pwconv2.weight": "not-in-model",
36
+ "tts.ae.vector_field.main_blocks.4.convnext.0.pwconv2.bias": "not-in-model",
37
+ "tts.ae.vector_field.main_blocks.6.convnext.0.pwconv1.weight": "not-in-model",
38
+ "tts.ae.vector_field.main_blocks.6.convnext.0.pwconv1.bias": "not-in-model",
39
+ "tts.ae.vector_field.main_blocks.6.convnext.0.pwconv2.weight": "not-in-model",
40
+ "tts.ae.vector_field.main_blocks.6.convnext.0.pwconv2.bias": "not-in-model",
41
+ "tts.ae.vector_field.main_blocks.6.convnext.1.pwconv1.weight": "not-in-model",
42
+ "tts.ae.vector_field.main_blocks.6.convnext.1.pwconv1.bias": "not-in-model",
43
+ "tts.ae.vector_field.main_blocks.6.convnext.1.pwconv2.weight": "not-in-model",
44
+ "tts.ae.vector_field.main_blocks.6.convnext.1.pwconv2.bias": "not-in-model",
45
+ "tts.ae.vector_field.main_blocks.6.convnext.2.pwconv1.weight": "not-in-model",
46
+ "tts.ae.vector_field.main_blocks.6.convnext.2.pwconv1.bias": "not-in-model",
47
+ "tts.ae.vector_field.main_blocks.6.convnext.2.pwconv2.weight": "not-in-model",
48
+ "tts.ae.vector_field.main_blocks.6.convnext.2.pwconv2.bias": "not-in-model",
49
+ "tts.ae.vector_field.main_blocks.6.convnext.3.pwconv1.weight": "not-in-model",
50
+ "tts.ae.vector_field.main_blocks.6.convnext.3.pwconv1.bias": "not-in-model",
51
+ "tts.ae.vector_field.main_blocks.6.convnext.3.pwconv2.weight": "not-in-model",
52
+ "tts.ae.vector_field.main_blocks.6.convnext.3.pwconv2.bias": "not-in-model",
53
+ "tts.ae.vector_field.main_blocks.8.convnext.0.pwconv1.weight": "not-in-model",
54
+ "tts.ae.vector_field.main_blocks.8.convnext.0.pwconv1.bias": "not-in-model",
55
+ "tts.ae.vector_field.main_blocks.8.convnext.0.pwconv2.weight": "not-in-model",
56
+ "tts.ae.vector_field.main_blocks.8.convnext.0.pwconv2.bias": "not-in-model",
57
+ "tts.ae.vector_field.main_blocks.10.convnext.0.pwconv1.weight": "not-in-model",
58
+ "tts.ae.vector_field.main_blocks.10.convnext.0.pwconv1.bias": "not-in-model",
59
+ "tts.ae.vector_field.main_blocks.10.convnext.0.pwconv2.weight": "not-in-model",
60
+ "tts.ae.vector_field.main_blocks.10.convnext.0.pwconv2.bias": "not-in-model",
61
+ "tts.ae.vector_field.main_blocks.12.convnext.0.pwconv1.weight": "not-in-model",
62
+ "tts.ae.vector_field.main_blocks.12.convnext.0.pwconv1.bias": "not-in-model",
63
+ "tts.ae.vector_field.main_blocks.12.convnext.0.pwconv2.weight": "not-in-model",
64
+ "tts.ae.vector_field.main_blocks.12.convnext.0.pwconv2.bias": "not-in-model",
65
+ "tts.ae.vector_field.main_blocks.12.convnext.1.pwconv1.weight": "not-in-model",
66
+ "tts.ae.vector_field.main_blocks.12.convnext.1.pwconv1.bias": "not-in-model",
67
+ "tts.ae.vector_field.main_blocks.12.convnext.1.pwconv2.weight": "not-in-model",
68
+ "tts.ae.vector_field.main_blocks.12.convnext.1.pwconv2.bias": "not-in-model",
69
+ "tts.ae.vector_field.main_blocks.12.convnext.2.pwconv1.weight": "not-in-model",
70
+ "tts.ae.vector_field.main_blocks.12.convnext.2.pwconv1.bias": "not-in-model",
71
+ "tts.ae.vector_field.main_blocks.12.convnext.2.pwconv2.weight": "not-in-model",
72
+ "tts.ae.vector_field.main_blocks.12.convnext.2.pwconv2.bias": "not-in-model",
73
+ "tts.ae.vector_field.main_blocks.12.convnext.3.pwconv1.weight": "not-in-model",
74
+ "tts.ae.vector_field.main_blocks.12.convnext.3.pwconv1.bias": "not-in-model",
75
+ "tts.ae.vector_field.main_blocks.12.convnext.3.pwconv2.weight": "not-in-model",
76
+ "tts.ae.vector_field.main_blocks.12.convnext.3.pwconv2.bias": "not-in-model",
77
+ "tts.ae.vector_field.main_blocks.14.convnext.0.pwconv1.weight": "not-in-model",
78
+ "tts.ae.vector_field.main_blocks.14.convnext.0.pwconv1.bias": "not-in-model",
79
+ "tts.ae.vector_field.main_blocks.14.convnext.0.pwconv2.weight": "not-in-model",
80
+ "tts.ae.vector_field.main_blocks.14.convnext.0.pwconv2.bias": "not-in-model",
81
+ "tts.ae.vector_field.main_blocks.16.convnext.0.pwconv1.weight": "not-in-model",
82
+ "tts.ae.vector_field.main_blocks.16.convnext.0.pwconv1.bias": "not-in-model",
83
+ "tts.ae.vector_field.main_blocks.16.convnext.0.pwconv2.weight": "not-in-model",
84
+ "tts.ae.vector_field.main_blocks.16.convnext.0.pwconv2.bias": "not-in-model",
85
+ "tts.ae.vector_field.main_blocks.18.convnext.0.pwconv1.weight": "not-in-model",
86
+ "tts.ae.vector_field.main_blocks.18.convnext.0.pwconv1.bias": "not-in-model",
87
+ "tts.ae.vector_field.main_blocks.18.convnext.0.pwconv2.weight": "not-in-model",
88
+ "tts.ae.vector_field.main_blocks.18.convnext.0.pwconv2.bias": "not-in-model",
89
+ "tts.ae.vector_field.main_blocks.18.convnext.1.pwconv1.weight": "not-in-model",
90
+ "tts.ae.vector_field.main_blocks.18.convnext.1.pwconv1.bias": "not-in-model",
91
+ "tts.ae.vector_field.main_blocks.18.convnext.1.pwconv2.weight": "not-in-model",
92
+ "tts.ae.vector_field.main_blocks.18.convnext.1.pwconv2.bias": "not-in-model",
93
+ "tts.ae.vector_field.main_blocks.18.convnext.2.pwconv1.weight": "not-in-model",
94
+ "tts.ae.vector_field.main_blocks.18.convnext.2.pwconv1.bias": "not-in-model",
95
+ "tts.ae.vector_field.main_blocks.18.convnext.2.pwconv2.weight": "not-in-model",
96
+ "tts.ae.vector_field.main_blocks.18.convnext.2.pwconv2.bias": "not-in-model",
97
+ "tts.ae.vector_field.main_blocks.18.convnext.3.pwconv1.weight": "not-in-model",
98
+ "tts.ae.vector_field.main_blocks.18.convnext.3.pwconv1.bias": "not-in-model",
99
+ "tts.ae.vector_field.main_blocks.18.convnext.3.pwconv2.weight": "not-in-model",
100
+ "tts.ae.vector_field.main_blocks.18.convnext.3.pwconv2.bias": "not-in-model",
101
+ "tts.ae.vector_field.main_blocks.20.convnext.0.pwconv1.weight": "not-in-model",
102
+ "tts.ae.vector_field.main_blocks.20.convnext.0.pwconv1.bias": "not-in-model",
103
+ "tts.ae.vector_field.main_blocks.20.convnext.0.pwconv2.weight": "not-in-model",
104
+ "tts.ae.vector_field.main_blocks.20.convnext.0.pwconv2.bias": "not-in-model",
105
+ "tts.ae.vector_field.main_blocks.22.convnext.0.pwconv1.weight": "not-in-model",
106
+ "tts.ae.vector_field.main_blocks.22.convnext.0.pwconv1.bias": "not-in-model",
107
+ "tts.ae.vector_field.main_blocks.22.convnext.0.pwconv2.weight": "not-in-model",
108
+ "tts.ae.vector_field.main_blocks.22.convnext.0.pwconv2.bias": "not-in-model",
109
+ "tts.ae.vector_field.last_convnext.convnext.0.pwconv1.weight": "not-in-model",
110
+ "tts.ae.vector_field.last_convnext.convnext.0.pwconv1.bias": "not-in-model",
111
+ "tts.ae.vector_field.last_convnext.convnext.0.pwconv2.weight": "not-in-model",
112
+ "tts.ae.vector_field.last_convnext.convnext.0.pwconv2.bias": "not-in-model",
113
+ "tts.ae.vector_field.last_convnext.convnext.1.pwconv1.weight": "not-in-model",
114
+ "tts.ae.vector_field.last_convnext.convnext.1.pwconv1.bias": "not-in-model",
115
+ "tts.ae.vector_field.last_convnext.convnext.1.pwconv2.weight": "not-in-model",
116
+ "tts.ae.vector_field.last_convnext.convnext.1.pwconv2.bias": "not-in-model",
117
+ "tts.ae.vector_field.last_convnext.convnext.2.pwconv1.weight": "not-in-model",
118
+ "tts.ae.vector_field.last_convnext.convnext.2.pwconv1.bias": "not-in-model",
119
+ "tts.ae.vector_field.last_convnext.convnext.2.pwconv2.weight": "not-in-model",
120
+ "tts.ae.vector_field.last_convnext.convnext.2.pwconv2.bias": "not-in-model",
121
+ "tts.ae.vector_field.last_convnext.convnext.3.pwconv1.weight": "not-in-model",
122
+ "tts.ae.vector_field.last_convnext.convnext.3.pwconv1.bias": "not-in-model",
123
+ "tts.ae.vector_field.last_convnext.convnext.3.pwconv2.weight": "not-in-model",
124
+ "tts.ae.vector_field.last_convnext.convnext.3.pwconv2.bias": "not-in-model",
125
+ "tts.ae.vector_field.proj_out.net.weight": "not-in-model",
126
+ "<missing>.vector_field.main_blocks.9.attn.theta": "expected-but-not-extracted",
127
+ "<missing>.vector_field.main_blocks.9.attn.increments": "expected-but-not-extracted",
128
+ "<missing>.vector_field.main_blocks.15.attn.theta": "expected-but-not-extracted",
129
+ "<missing>.vector_field.main_blocks.15.attn.increments": "expected-but-not-extracted",
130
+ "<missing>.vector_field.main_blocks.21.attn.theta": "expected-but-not-extracted",
131
+ "<missing>.vector_field.main_blocks.21.attn.increments": "expected-but-not-extracted"
132
+ },
133
+ "elapsed_s": 0.289
134
+ },
135
+ {
136
+ "model": "TextEncoder",
137
+ "onnx": "/tmp/supertonic3/model/onnx/text_encoder.onnx",
138
+ "safetensors": "/Users/transcrilive/MLX_CONVERTOR/sub-projects/supertonic3-mlx/hf_release/weights/text_encoder.safetensors",
139
+ "bytes": 36022466,
140
+ "sha256": "9df20bb79496718b36d2c0fc37636d3f78d6ef751b2899ff6dfeb975ae737ada",
141
+ "weights_kept": 146,
142
+ "weights_dropped": 0,
143
+ "dropped_detail": {},
144
+ "elapsed_s": 0.035
145
+ },
146
+ {
147
+ "model": "DurationPredictor",
148
+ "onnx": "/tmp/supertonic3/model/onnx/duration_predictor.onnx",
149
+ "safetensors": "/Users/transcrilive/MLX_CONVERTOR/sub-projects/supertonic3-mlx/hf_release/weights/duration_predictor.safetensors",
150
+ "bytes": 3470807,
151
+ "sha256": "cd473acb6e0ac27426084488ccb3b3cc184e70d05db90897e2b892846db5dcb3",
152
+ "weights_kept": 98,
153
+ "weights_dropped": 0,
154
+ "dropped_detail": {},
155
+ "elapsed_s": 0.007
156
+ },
157
+ {
158
+ "model": "Vocoder",
159
+ "onnx": "/tmp/supertonic3/model/onnx/vocoder.onnx",
160
+ "safetensors": "/Users/transcrilive/MLX_CONVERTOR/sub-projects/supertonic3-mlx/hf_release/weights/vocoder.safetensors",
161
+ "bytes": 101364763,
162
+ "sha256": "b2ec31ab7c554f6e15b9a6780554b5d3502345de7848b310966bfb4e1ea4e526",
163
+ "weights_kept": 103,
164
+ "weights_dropped": 0,
165
+ "dropped_detail": {},
166
+ "elapsed_s": 0.079
167
+ }
168
+ ],
169
+ "ancillary": [
170
+ {
171
+ "name": "unicode_indexer.json",
172
+ "bytes": 277676,
173
+ "sha256": "9bf7346e43883a81f8645c81224f786d43c5b57f3641f6e7671a7d6c493cb24f"
174
+ },
175
+ {
176
+ "name": "voice_styles/F1.json",
177
+ "bytes": 292046,
178
+ "sha256": "bbdec6ee00231c2c742ad05483df5334cab3b52fda3ba38e6a07059c4563dbc2"
179
+ },
180
+ {
181
+ "name": "voice_styles/F2.json",
182
+ "bytes": 292423,
183
+ "sha256": "7c722c6a72707b1a77f035d67f0d1351ba187738e06f7683e8c72b1df3477fc6"
184
+ },
185
+ {
186
+ "name": "voice_styles/F3.json",
187
+ "bytes": 290794,
188
+ "sha256": "12f6ef2573baa2defa1128069cb59f203e3ab67c92af77b42df8a0e3a2f7c6ab"
189
+ },
190
+ {
191
+ "name": "voice_styles/F4.json",
192
+ "bytes": 291808,
193
+ "sha256": "c2fa764c1225a76dfc3e2c73e8aa4f70d9ee48793860eb34c295fff01c2e032b"
194
+ },
195
+ {
196
+ "name": "voice_styles/F5.json",
197
+ "bytes": 291479,
198
+ "sha256": "45966e73316415626cf41a7d1c6f3b4c70dbc1ba2bee5c1978ef0ce33244fc8d"
199
+ },
200
+ {
201
+ "name": "voice_styles/M1.json",
202
+ "bytes": 291748,
203
+ "sha256": "e35604687f5d23694b8e91593a93eec0e4eca6c0b02bb8ed69139ab2ea6b0a5b"
204
+ },
205
+ {
206
+ "name": "voice_styles/M2.json",
207
+ "bytes": 292055,
208
+ "sha256": "b76cbf62bac707c710cf0ae5aba5e31eea1a6339a9734bfae33ab98499534a50"
209
+ },
210
+ {
211
+ "name": "voice_styles/M3.json",
212
+ "bytes": 290198,
213
+ "sha256": "ea1ac35ccb91b0d7ecad533a2fbd0eec10c91513d8951e3b25fbba99954e159b"
214
+ },
215
+ {
216
+ "name": "voice_styles/M4.json",
217
+ "bytes": 291522,
218
+ "sha256": "ca8eefad4fcd989c9379032ff3e50738adc547eeb5e221b82593a6d7b3bac303"
219
+ },
220
+ {
221
+ "name": "voice_styles/M5.json",
222
+ "bytes": 291469,
223
+ "sha256": "dd22b92740314321f8ae11c5e87f8dd60d060f15dd3a632b5adf77f471f77af2"
224
+ }
225
+ ]
226
+ }
examples/quickstart.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Minimal Supertonic 3 MLX usage — 5 lines, no fluff.
2
+
3
+ Run from anywhere AFTER ``pip install supertonic-3-mlx`` (or from inside
4
+ this directory after ``pip install ./``):
5
+
6
+ python examples/quickstart.py
7
+ """
8
+ from supertonic_3_mlx import Pipeline
9
+ import soundfile as sf
10
+
11
+ # When the package has been pip-installed, this auto-downloads from the Hub
12
+ # (~ 400 MB) into the standard Hugging Face cache. After the first run, the
13
+ # weights are reused from cache and cold start is ~ 11 ms on M4.
14
+ pipe = Pipeline.from_pretrained("ambassadia/supertonic-3-mlx")
15
+
16
+ wav = pipe.generate(
17
+ "Hello world from Apple Silicon. Supertonic 3 runs at one hundred times realtime.",
18
+ voice="F1", # one of F1..F5, M1..M5
19
+ lang="en", # ISO 639-1
20
+ )
21
+
22
+ sf.write("hello.wav", wav, pipe.sample_rate)
23
+ print(f"wrote hello.wav — {len(wav) / pipe.sample_rate:.2f}s of audio")
pyproject.toml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "supertonic-3-mlx"
3
+ version = "0.1.0"
4
+ description = "MLX-native port of Supertone's Supertonic 3 multilingual TTS (31 languages, ~x100 realtime on Apple Silicon)"
5
+ readme = "README.md"
6
+ requires-python = ">=3.10"
7
+ authors = [{ name = "Olivier Dupont", email = "olivier.dupont@taviramonaco.com" }]
8
+ license = { text = "Apache-2.0 AND OpenRAIL-M" }
9
+ keywords = ["mlx", "tts", "speech-synthesis", "apple-silicon", "supertonic", "multilingual"]
10
+ classifiers = [
11
+ "Development Status :: 4 - Beta",
12
+ "Environment :: MacOS X",
13
+ "Intended Audience :: Developers",
14
+ "Intended Audience :: Science/Research",
15
+ "License :: OSI Approved :: Apache Software License",
16
+ "Operating System :: MacOS",
17
+ "Programming Language :: Python :: 3 :: Only",
18
+ "Programming Language :: Python :: 3.10",
19
+ "Programming Language :: Python :: 3.11",
20
+ "Programming Language :: Python :: 3.12",
21
+ "Topic :: Multimedia :: Sound/Audio :: Speech",
22
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
23
+ ]
24
+ dependencies = [
25
+ "mlx>=0.21.0",
26
+ "numpy>=1.24.0",
27
+ ]
28
+
29
+ [project.optional-dependencies]
30
+ hub = ["huggingface_hub>=0.26.0"]
31
+ dev = ["pytest>=8.3.0", "ruff>=0.7.0"]
32
+
33
+ [project.urls]
34
+ Homepage = "https://huggingface.co/ambassadia/supertonic-3-mlx"
35
+ Upstream = "https://huggingface.co/Supertone/supertonic-3"
36
+ Source = "https://gitea.tavportal.com/olivier/MLX_CONVERTOR"
37
+
38
+ [build-system]
39
+ requires = ["hatchling"]
40
+ build-backend = "hatchling.build"
41
+
42
+ [tool.hatch.build.targets.wheel]
43
+ packages = ["src/supertonic_3_mlx"]
samples/de_M2.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b363a0c1ca3f596e05ac001f65377d129613db67164dbedfbdf9b4c11d56e365
3
+ size 325676
samples/en_F1_short.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c78ec23d58e6befa6bbcd078631a11eb6c1b647dbb07750767bd29ed17205f6
3
+ size 245804
samples/en_M1_long.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2561326a9c4a28e7bb3cc5bdf6a36f23c97e454714c0bc5e63b8f8a981beac96
3
+ size 344108
samples/es_M3.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c19918dc6927d18825f1a8277b172f3482a5953526f19a3f0bbce3f911885822
3
+ size 251948
samples/fr_F2.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f37101850f032c2b59c20fc45b0d7fee794005c2bebb43737737427d09069d94
3
+ size 301100
samples/ja_F3.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c447d1d9899548fe646de2238e13ccf9b3c7ded9d16d39bd115a8e9d66d5ff1
3
+ size 129068
src/supertonic_3_mlx/__init__.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Supertonic 3 — MLX-native TTS for Apple Silicon.
2
+
3
+ 31-language text-to-speech, 5 Euler steps with classifier-free guidance, in
4
+ pure MLX. On M4 the full pipeline runs at ~x100 realtime.
5
+
6
+ Quickstart
7
+ ----------
8
+
9
+ from supertonic_3_mlx import Pipeline
10
+ pipe = Pipeline.from_pretrained("ambassadia/supertonic-3-mlx")
11
+ wav = pipe.generate("Hello world from Apple Silicon.", voice="F1", lang="en")
12
+ # wav is a 1-D ``numpy.float32`` array at 44.1 kHz.
13
+
14
+ The model weights are released under the BigScience OpenRAIL-M license
15
+ (see LICENSE in the Hugging Face repository). This MLX port code is
16
+ Apache-2.0. Together they form a dual-license package; Attachment A use
17
+ restrictions of OpenRAIL-M govern downstream use of the generated audio.
18
+
19
+ Public API:
20
+ Pipeline — end-to-end TTS, ``from_pretrained`` + ``generate``
21
+ VectorEstimator — the 24-block CFG flow-matching net (sub-model 1/4)
22
+ TextEncoder — character → text embedding (sub-model 2/4)
23
+ DurationPredictor — text → duration in seconds (sub-model 3/4)
24
+ Vocoder — latent → 44.1 kHz waveform (sub-model 4/4)
25
+ """
26
+ from supertonic_3_mlx._config import (
27
+ DIM, LATENT_CH, CONVNEXT_HIDDEN, CONVNEXT_K,
28
+ NUM_MAIN_BLOCKS, NUM_CYCLES, BLOCKS_PER_CYCLE, BLOCK_CYCLE, STACK4_DILATIONS,
29
+ TEXT_HEADS, TEXT_HEAD_DIM, TEXT_DIM, ROTARY_BASE, ROTARY_SCALE,
30
+ STYLE_HEADS, STYLE_HEAD_DIM, STYLE_LEN, STYLE_DIM,
31
+ TIME_EMB_DIM, TIME_MLP_HIDDEN,
32
+ EPS_LN, CHUNK_COMPRESS, LATENT_DIM, SAMPLE_RATE,
33
+ SUPERTONIC3_HF_REPO,
34
+ )
35
+ from supertonic_3_mlx.duration_predictor import DurationPredictor
36
+ from supertonic_3_mlx.text_encoder import TextEncoder
37
+ from supertonic_3_mlx.vector_estimator import VectorEstimator
38
+ from supertonic_3_mlx.vocoder import Vocoder
39
+ from supertonic_3_mlx.pipeline import SupertonicMLXPipeline as Pipeline
40
+
41
+ __all__ = [
42
+ "Pipeline",
43
+ "DurationPredictor", "TextEncoder", "VectorEstimator", "Vocoder",
44
+ "DIM", "LATENT_CH", "CONVNEXT_HIDDEN", "CONVNEXT_K",
45
+ "NUM_MAIN_BLOCKS", "NUM_CYCLES", "BLOCKS_PER_CYCLE", "BLOCK_CYCLE", "STACK4_DILATIONS",
46
+ "TEXT_HEADS", "TEXT_HEAD_DIM", "TEXT_DIM", "ROTARY_BASE", "ROTARY_SCALE",
47
+ "STYLE_HEADS", "STYLE_HEAD_DIM", "STYLE_LEN", "STYLE_DIM",
48
+ "TIME_EMB_DIM", "TIME_MLP_HIDDEN",
49
+ "EPS_LN", "CHUNK_COMPRESS", "LATENT_DIM", "SAMPLE_RATE",
50
+ "SUPERTONIC3_HF_REPO",
51
+ ]
src/supertonic_3_mlx/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (2.56 kB). View file
 
src/supertonic_3_mlx/__pycache__/_config.cpython-312.pyc ADDED
Binary file (2.21 kB). View file
 
src/supertonic_3_mlx/__pycache__/_nn_wrappers.cpython-312.pyc ADDED
Binary file (3.24 kB). View file
 
src/supertonic_3_mlx/__pycache__/duration_predictor.cpython-312.pyc ADDED
Binary file (22.4 kB). View file
 
src/supertonic_3_mlx/__pycache__/pipeline.cpython-312.pyc ADDED
Binary file (28 kB). View file
 
src/supertonic_3_mlx/__pycache__/text_encoder.cpython-312.pyc ADDED
Binary file (22.9 kB). View file
 
src/supertonic_3_mlx/__pycache__/vector_estimator.cpython-312.pyc ADDED
Binary file (38.4 kB). View file
 
src/supertonic_3_mlx/__pycache__/vocoder.cpython-312.pyc ADDED
Binary file (18.2 kB). View file
 
src/supertonic_3_mlx/_config.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Locked hyperparameters for Supertonic 3 MLX port.
2
+
3
+ Derived from the official ``Supertone/supertonic-3/onnx/tts.json``.
4
+ Changing these = re-running parity tests.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ # Vector estimator (the flow-matching denoiser)
9
+ DIM: int = 512 # backbone width
10
+ LATENT_CH: int = 144 # 24 * chunk_compress_factor (6)
11
+ CONVNEXT_HIDDEN: int = 2048 # main_blocks ConvNeXt intermediate dim (2× vs s2)
12
+ CONVNEXT_K: int = 5
13
+ LAST_CONVNEXT_NUM: int = 4 # last_convnext is a 4-layer stack (dilations [1,1,1,1])
14
+
15
+ # 24 main_blocks = 4 cycles × 6 sub-blocks (cycle: stack4, time, cn1, text_attn, cn1, style_attn)
16
+ NUM_CYCLES: int = 4
17
+ BLOCKS_PER_CYCLE: int = 6
18
+ NUM_MAIN_BLOCKS: int = NUM_CYCLES * BLOCKS_PER_CYCLE
19
+ BLOCK_CYCLE = ("stack4", "time", "cn1", "text_attn", "cn1", "style_attn")
20
+
21
+ # ConvNeXt stack 4 (in stack4 blocks) — dilation schedule
22
+ STACK4_DILATIONS = (1, 2, 4, 8)
23
+
24
+ # Text cross-attention (RoPE) — block type "text_attn"
25
+ TEXT_DIM: int = 256
26
+ TEXT_HEADS: int = 8 # 2× vs s2 (4)
27
+ TEXT_HEAD_DIM: int = DIM // TEXT_HEADS # 512/8 = 64
28
+ ROTARY_BASE: int = 10_000
29
+ ROTARY_SCALE: int = 10
30
+
31
+ # Style cross-attention — block type "style_attn"
32
+ STYLE_DIM: int = 256
33
+ STYLE_LEN: int = 50 # 50 style tokens (n_style)
34
+ STYLE_HEADS: int = 2
35
+ STYLE_HEAD_DIM: int = 128
36
+
37
+ # Time encoding (sinusoidal + MLP)
38
+ TIME_EMB_DIM: int = 64
39
+ TIME_MLP_HIDDEN: int = 256
40
+
41
+ # LayerNorm epsilon
42
+ EPS_LN: float = 1e-6
43
+
44
+ # Chunk compress factor (used by AE)
45
+ CHUNK_COMPRESS: int = 6
46
+ LATENT_DIM: int = 24 # ldim before chunk compression
47
+
48
+ # Sample rate
49
+ SAMPLE_RATE: int = 44_100
50
+
51
+ # HF references (will be pinned to SHA after first download)
52
+ SUPERTONIC3_HF_REPO: str = "Supertone/supertonic-3"
53
+ ONNX_VECTOR_ESTIMATOR: str = "onnx/vector_estimator.onnx"
54
+ ONNX_TEXT_ENCODER: str = "onnx/text_encoder.onnx"
55
+ ONNX_DURATION_PREDICTOR: str = "onnx/duration_predictor.onnx"
56
+ ONNX_VOCODER: str = "onnx/vocoder.onnx"
57
+ ONNX_TTS_JSON: str = "onnx/tts.json"
58
+ ONNX_UNICODE_INDEXER: str = "onnx/unicode_indexer.json"
src/supertonic_3_mlx/_nn_wrappers.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Small wrapper modules to match Supertonic 3 ONNX submodule nesting.
2
+
3
+ The s3 checkpoint nests primitives one level deeper than typical MLX modules:
4
+ - ``norm.norm.weight`` — LayerNorm wrapped in a Norm container
5
+ - ``linear.linear.weight`` — Linear wrapped in a Linear container
6
+ - ``W_query.linear.weight`` — attention projection wrapped
7
+
8
+ Mirroring this nesting lets us load the safetensors with ``model.load_weights(...)``
9
+ without any key remapping at load time.
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import mlx.core as mx
14
+ import mlx.nn as nn
15
+
16
+
17
+ class WrappedNorm(nn.Module):
18
+ """Container with a single nested LayerNorm — produces key ``X.norm.weight``."""
19
+
20
+ def __init__(self, dim: int, eps: float = 1e-6) -> None:
21
+ super().__init__()
22
+ self.norm = nn.LayerNorm(dim, eps=eps)
23
+
24
+ def __call__(self, x: mx.array) -> mx.array:
25
+ return self.norm(x)
26
+
27
+
28
+ class WrappedLinear(nn.Module):
29
+ """Container with a single nested Linear — produces keys ``X.linear.weight/bias``."""
30
+
31
+ def __init__(self, in_dim: int, out_dim: int, bias: bool = True) -> None:
32
+ super().__init__()
33
+ self.linear = nn.Linear(in_dim, out_dim, bias=bias)
34
+
35
+ def __call__(self, x: mx.array) -> mx.array:
36
+ return self.linear(x)
37
+
38
+
39
+ class ProjConv1x1(nn.Module):
40
+ """Conv1d k=1 expressed as ``self.net = Linear`` (matches ``proj_in.net.weight``)."""
41
+
42
+ def __init__(self, in_dim: int, out_dim: int, bias: bool = True) -> None:
43
+ super().__init__()
44
+ self.net = nn.Linear(in_dim, out_dim, bias=bias)
45
+
46
+ def __call__(self, x: mx.array) -> mx.array:
47
+ return self.net(x)
48
+
49
+
50
+ __all__ = ["WrappedNorm", "WrappedLinear", "ProjConv1x1"]
src/supertonic_3_mlx/duration_predictor.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Supertonic 3 duration predictor — predicts total audio duration in seconds.
2
+
3
+ Pipeline (channels-last NTC throughout):
4
+
5
+ text_ids [B, T] int64 character IDs
6
+ → char_embed (Embedding 8322→64) [B, T, 64]
7
+ → prepend sentence_token (1, 64, 1) [B, T+1, 64]
8
+ → 6× ConvNeXt (dim=64, hidden=256, k=5, all dilations=1)
9
+ → 2× RelPosSelfAttn (heads=2, head_dim=32, window=4) + norm + FFN + norm
10
+ → proj_out (Conv1d k=1: 64→64) applied to slot 0 (sentence token)
11
+ → concat with style_dp flattened (B, 8×16=128) [B, 192]
12
+ → Linear(192 → 128) → PReLU → Linear(128 → 1) → exp → duration [B]
13
+
14
+ Inputs:
15
+ text_ids: (B, T) int — character indices
16
+ style_dp: (B, 8, 16) — style summary tokens
17
+ text_mask: (B, 1, T) — 1.0 valid, 0.0 padded
18
+ """
19
+ from __future__ import annotations
20
+
21
+ import mlx.core as mx
22
+ import mlx.nn as nn
23
+
24
+ from supertonic_3_mlx._config import EPS_LN
25
+ from supertonic_3_mlx._nn_wrappers import WrappedNorm
26
+ from supertonic_3_mlx.vector_estimator import _pad_sym_edge, _gelu_exact
27
+
28
+
29
+ DP_VOCAB = 8322
30
+ DP_DIM = 64
31
+ DP_CONVNEXT_HIDDEN = 256
32
+ DP_CONVNEXT_K = 5
33
+ DP_CONVNEXT_NUM_LAYERS = 6
34
+ DP_ATTN_NUM_LAYERS = 2
35
+ DP_ATTN_HEADS = 2
36
+ DP_ATTN_HEAD_DIM = DP_DIM // DP_ATTN_HEADS # 32
37
+ DP_FFN_HIDDEN = 256
38
+ DP_REL_POS_WINDOW = 4
39
+ DP_N_STYLE = 8
40
+ DP_STYLE_DIM = 16
41
+ DP_MLP_IN = DP_DIM + DP_N_STYLE * DP_STYLE_DIM # 64 + 128 = 192
42
+ DP_MLP_HIDDEN = 128
43
+
44
+
45
+ class _DPConvNeXtBlock(nn.Module):
46
+ """ConvNeXt block (dim=64, hidden=256, dilation=1)."""
47
+
48
+ def __init__(self) -> None:
49
+ super().__init__()
50
+ self.dwconv = nn.Conv1d(
51
+ DP_DIM, DP_DIM, kernel_size=DP_CONVNEXT_K, padding=0,
52
+ dilation=1, groups=DP_DIM, bias=True,
53
+ )
54
+ self.norm = WrappedNorm(DP_DIM, eps=EPS_LN)
55
+ self.pwconv1 = nn.Linear(DP_DIM, DP_CONVNEXT_HIDDEN, bias=True)
56
+ self.pwconv2 = nn.Linear(DP_CONVNEXT_HIDDEN, DP_DIM, bias=True)
57
+ self.gamma = mx.zeros((DP_DIM,))
58
+ self.pad = (DP_CONVNEXT_K - 1) // 2
59
+
60
+ def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
61
+ residual = x
62
+ y = _pad_sym_edge(x, self.pad)
63
+ y = self.dwconv(y)
64
+ y = self.norm(y)
65
+ y = self.pwconv1(y)
66
+ y = _gelu_exact(y)
67
+ y = self.pwconv2(y)
68
+ y = y * self.gamma
69
+ out = residual + y
70
+ if mask is not None:
71
+ out = out * mask
72
+ return out
73
+
74
+
75
+ class _DPConvNeXtStack(nn.Module):
76
+ """``convnext.[0..5]`` — 6 ConvNeXt blocks."""
77
+
78
+ def __init__(self) -> None:
79
+ super().__init__()
80
+ self.convnext = [_DPConvNeXtBlock() for _ in range(DP_CONVNEXT_NUM_LAYERS)]
81
+
82
+ def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
83
+ for b in self.convnext:
84
+ x = b(x, mask)
85
+ return x
86
+
87
+
88
+ class _DPConvLayer(nn.Module):
89
+ """Conv1d k=1 with weight (out, 1, in) — matches ONNX storage."""
90
+
91
+ def __init__(self, in_dim: int, out_dim: int) -> None:
92
+ super().__init__()
93
+ self.weight = mx.zeros((out_dim, 1, in_dim))
94
+ self.bias = mx.zeros((out_dim,))
95
+
96
+ def __call__(self, x: mx.array) -> mx.array:
97
+ return mx.conv1d(x, self.weight, stride=1, padding=0) + self.bias
98
+
99
+
100
+ def _dp_rel_to_abs(x: mx.array) -> mx.array:
101
+ """(B, h, L, 2L-1) → (B, h, L, L) via VITS shifted-skew reshape."""
102
+ B, h, L, _ = x.shape
103
+ x = mx.concatenate([x, mx.zeros((B, h, L, 1), dtype=x.dtype)], axis=-1)
104
+ x_flat = x.reshape(B, h, L * 2 * L)
105
+ x_flat = mx.concatenate([x_flat, mx.zeros((B, h, L - 1), dtype=x.dtype)], axis=-1)
106
+ x_final = x_flat.reshape(B, h, L + 1, 2 * L - 1)
107
+ return x_final[:, :, :L, L - 1:]
108
+
109
+
110
+ def _dp_abs_to_rel(x: mx.array) -> mx.array:
111
+ """(B, h, L, L) → (B, h, L, 2L-1)."""
112
+ B, h, L, _ = x.shape
113
+ x = mx.concatenate([x, mx.zeros((B, h, L, L - 1), dtype=x.dtype)], axis=-1)
114
+ x_flat = x.reshape(B, h, L * (2 * L - 1))
115
+ x_flat = mx.concatenate([mx.zeros((B, h, L), dtype=x.dtype), x_flat], axis=-1)
116
+ x_final = x_flat.reshape(B, h, L, 2 * L)
117
+ return x_final[:, :, :, 1:]
118
+
119
+
120
+ def _dp_slice_rel(rel: mx.array, length: int, window: int) -> mx.array:
121
+ """(1, 2W+1, d) → (1, 2L-1, d) by zero-padding/slicing."""
122
+ pad_l = max(length - (window + 1), 0)
123
+ if pad_l > 0:
124
+ zero = mx.zeros((1, pad_l, rel.shape[-1]), dtype=rel.dtype)
125
+ padded = mx.concatenate([zero, rel, zero], axis=1)
126
+ else:
127
+ padded = rel
128
+ start = max(window + 1 - length, 0)
129
+ return padded[:, start: start + 2 * length - 1]
130
+
131
+
132
+ class _DPRelPosSelfAttn(nn.Module):
133
+ """VITS-style rel-pos self-attention (2 heads × 32 head_dim, window=4).
134
+
135
+ Includes both rel-pos contributions (q × rel_k → logits, abs_to_rel(attn) × rel_v → out).
136
+ """
137
+
138
+ def __init__(self) -> None:
139
+ super().__init__()
140
+ self.conv_q = _DPConvLayer(DP_DIM, DP_DIM)
141
+ self.conv_k = _DPConvLayer(DP_DIM, DP_DIM)
142
+ self.conv_v = _DPConvLayer(DP_DIM, DP_DIM)
143
+ self.conv_o = _DPConvLayer(DP_DIM, DP_DIM)
144
+ self.emb_rel_k = mx.zeros((1, 2 * DP_REL_POS_WINDOW + 1, DP_ATTN_HEAD_DIM))
145
+ self.emb_rel_v = mx.zeros((1, 2 * DP_REL_POS_WINDOW + 1, DP_ATTN_HEAD_DIM))
146
+
147
+ def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
148
+ B, T, _ = x.shape
149
+ H, D = DP_ATTN_HEADS, DP_ATTN_HEAD_DIM
150
+ q = self.conv_q(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
151
+ k = self.conv_k(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
152
+ v = self.conv_v(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
153
+ scale = D ** -0.5
154
+
155
+ logits = (q @ k.transpose(0, 1, 3, 2)) * scale
156
+
157
+ rel_k = _dp_slice_rel(self.emb_rel_k, T, DP_REL_POS_WINDOW)
158
+ rel_logits = q @ rel_k.transpose(0, 2, 1)[:, None, :, :]
159
+ rel_logits = _dp_rel_to_abs(rel_logits * scale)
160
+ logits = logits + rel_logits
161
+
162
+ if mask is not None:
163
+ key_mask = mask[:, :, 0][:, None, None, :]
164
+ neg_inf = mx.array(-1e4, dtype=logits.dtype)
165
+ logits = mx.where(key_mask.astype(mx.bool_), logits, neg_inf)
166
+
167
+ attn = mx.softmax(logits, axis=-1)
168
+ out = attn @ v
169
+
170
+ rel_v = _dp_slice_rel(self.emb_rel_v, T, DP_REL_POS_WINDOW)
171
+ rel_weights = _dp_abs_to_rel(attn)
172
+ out = out + rel_weights @ rel_v[:, None, :, :]
173
+
174
+ out = out.transpose(0, 2, 1, 3).reshape(B, T, H * D)
175
+ return self.conv_o(out)
176
+
177
+
178
+ class _DPFFN(nn.Module):
179
+ """FFN with two Conv1d k=1 — 64 → 256 → 64, ReLU + mask."""
180
+
181
+ def __init__(self) -> None:
182
+ super().__init__()
183
+ self.conv_1 = _DPConvLayer(DP_DIM, DP_FFN_HIDDEN)
184
+ self.conv_2 = _DPConvLayer(DP_FFN_HIDDEN, DP_DIM)
185
+
186
+ def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
187
+ if mask is not None:
188
+ x = x * mask
189
+ y = self.conv_1(x)
190
+ y = mx.maximum(y, mx.array(0.0, dtype=y.dtype))
191
+ if mask is not None:
192
+ y = y * mask
193
+ y = self.conv_2(y)
194
+ if mask is not None:
195
+ y = y * mask
196
+ return y
197
+
198
+
199
+ class _DPAttnEncoder(nn.Module):
200
+ """2× (attn + norm) + (ffn + norm)."""
201
+
202
+ def __init__(self) -> None:
203
+ super().__init__()
204
+ self.attn_layers = [_DPRelPosSelfAttn() for _ in range(DP_ATTN_NUM_LAYERS)]
205
+ self.norm_layers_1 = [WrappedNorm(DP_DIM, eps=EPS_LN) for _ in range(DP_ATTN_NUM_LAYERS)]
206
+ self.ffn_layers = [_DPFFN() for _ in range(DP_ATTN_NUM_LAYERS)]
207
+ self.norm_layers_2 = [WrappedNorm(DP_DIM, eps=EPS_LN) for _ in range(DP_ATTN_NUM_LAYERS)]
208
+
209
+ def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
210
+ for i in range(DP_ATTN_NUM_LAYERS):
211
+ x = self.norm_layers_1[i](x + self.attn_layers[i](x, mask=mask))
212
+ x = self.norm_layers_2[i](x + self.ffn_layers[i](x, mask))
213
+ return x
214
+
215
+
216
+ class _DPSentenceEncoder(nn.Module):
217
+ """Text → 64-d sentence vector via prepended ``sentence_token`` slot."""
218
+
219
+ def __init__(self) -> None:
220
+ super().__init__()
221
+ class _TextEmb(nn.Module):
222
+ def __init__(_):
223
+ super().__init__()
224
+ _.char_embedder = nn.Embedding(DP_VOCAB, DP_DIM)
225
+ def __call__(_, ids):
226
+ return _.char_embedder(ids)
227
+ self.text_embedder = _TextEmb()
228
+ self.convnext = _DPConvNeXtStack()
229
+ self.attn_encoder = _DPAttnEncoder()
230
+ # proj_out keeps the .net.weight (out, 1, in) Conv1d-k1 layout
231
+ self.proj_out = _DPProjOut()
232
+ # sentence_token (1, DIM, 1) — prepended as the first time slot
233
+ self.sentence_token = mx.zeros((1, DP_DIM, 1))
234
+
235
+ def __call__(self, text_ids: mx.array, text_mask: mx.array) -> mx.array:
236
+ x = self.text_embedder(text_ids) # (B, T, 64)
237
+ # Prepend sentence_token: shape (1, 64, 1) → (B, 1, 64)
238
+ B = x.shape[0]
239
+ sentence = self.sentence_token.transpose(0, 2, 1)
240
+ sentence = mx.broadcast_to(sentence, (B, 1, DP_DIM))
241
+ x = mx.concatenate([sentence, x], axis=1) # (B, T+1, 64)
242
+
243
+ # Extend mask with a leading 1 (sentence token always valid)
244
+ if text_mask is not None:
245
+ extra = mx.ones((B, 1, 1), dtype=text_mask.dtype)
246
+ mask_ntc = mx.concatenate([extra, text_mask.transpose(0, 2, 1)], axis=1)
247
+ else:
248
+ mask_ntc = None
249
+
250
+ x = self.convnext(x, mask_ntc)
251
+ x = self.attn_encoder(x, mask_ntc)
252
+
253
+ # Take slot 0 (sentence token output) → (B, 1, 64)
254
+ sentence_out = x[:, :1, :] # (B, 1, 64)
255
+ # proj_out (Conv1d k=1) — applied along time, output (B, 1, 64)
256
+ sentence_out = self.proj_out(sentence_out)
257
+ return sentence_out.reshape(B, DP_DIM) # (B, 64)
258
+
259
+
260
+ class _DPProjOut(nn.Module):
261
+ """Conv1d k=1 64→64. No bias in ONNX (confirmed via graph inspection)."""
262
+
263
+ def __init__(self) -> None:
264
+ super().__init__()
265
+ class _Net(nn.Module):
266
+ def __init__(_):
267
+ super().__init__()
268
+ _.weight = mx.zeros((DP_DIM, 1, DP_DIM))
269
+ def __call__(_, x):
270
+ return mx.conv1d(x, _.weight, stride=1, padding=0)
271
+ self.net = _Net()
272
+
273
+ def __call__(self, x: mx.array) -> mx.array:
274
+ return self.net(x)
275
+
276
+
277
+ class _DPPredictor(nn.Module):
278
+ """Linear(192 → 128) + PReLU + Linear(128 → 1).
279
+
280
+ PReLU is stored under ``activation.weight (1,)`` — a single learnable
281
+ negative-slope coefficient.
282
+ """
283
+
284
+ def __init__(self) -> None:
285
+ super().__init__()
286
+ self.layers = [
287
+ nn.Linear(DP_MLP_IN, DP_MLP_HIDDEN, bias=True),
288
+ nn.Linear(DP_MLP_HIDDEN, 1, bias=True),
289
+ ]
290
+ # PReLU: activation.weight shape (1,) — single scalar slope
291
+ class _Activation(nn.Module):
292
+ def __init__(_):
293
+ super().__init__()
294
+ _.weight = mx.zeros((1,))
295
+ def __call__(_, x):
296
+ # PReLU(x) = max(0, x) + slope * min(0, x)
297
+ neg = mx.minimum(x, mx.array(0.0, dtype=x.dtype))
298
+ pos = mx.maximum(x, mx.array(0.0, dtype=x.dtype))
299
+ return pos + _.weight * neg
300
+ self.activation = _Activation()
301
+
302
+ def __call__(self, x: mx.array) -> mx.array:
303
+ h = self.layers[0](x) # (B, 128)
304
+ h = self.activation(h)
305
+ h = self.layers[1](h) # (B, 1)
306
+ return h
307
+
308
+
309
+ class _DPRoot(nn.Module):
310
+ """``tts.dp.X`` namespace container."""
311
+
312
+ def __init__(self) -> None:
313
+ super().__init__()
314
+ self.sentence_encoder = _DPSentenceEncoder()
315
+ self.predictor = _DPPredictor()
316
+
317
+
318
+ class _DPContainer(nn.Module):
319
+ def __init__(self) -> None:
320
+ super().__init__()
321
+ self.dp = _DPRoot()
322
+
323
+
324
+ class DurationPredictor(nn.Module):
325
+ """Predicts total audio duration (seconds) for an utterance.
326
+
327
+ Submodule namespace matches ONNX keys ``tts.dp.X.Y`` exactly.
328
+ """
329
+
330
+ def __init__(self) -> None:
331
+ super().__init__()
332
+ self.tts = _DPContainer()
333
+
334
+ def __call__(
335
+ self,
336
+ text_ids: mx.array, # (B, T) int
337
+ style_dp: mx.array, # (B, 8, 16)
338
+ text_mask: mx.array, # (B, 1, T)
339
+ ) -> mx.array:
340
+ sentence = self.tts.dp.sentence_encoder(text_ids, text_mask) # (B, 64)
341
+ style_flat = style_dp.reshape(style_dp.shape[0], -1) # (B, 128)
342
+ joined = mx.concatenate([sentence, style_flat], axis=-1) # (B, 192)
343
+ log_dur = self.tts.dp.predictor(joined).reshape(-1) # (B,)
344
+ return mx.exp(log_dur) # duration in seconds
345
+
346
+
347
+ __all__ = ["DurationPredictor"]
src/supertonic_3_mlx/pipeline.py ADDED
@@ -0,0 +1,545 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Supertonic 3 end-to-end MLX pipeline.
2
+
3
+ Stitches the four MLX sub-models (DurationPredictor → TextEncoder →
4
+ VectorEstimator → Vocoder) into a single ``generate(text, voice, lang)`` call
5
+ that returns a 44.1 kHz mono numpy waveform.
6
+
7
+ Flow:
8
+
9
+ text ──tokenize(unicode_indexer)──▶ text_ids (B, T_text)
10
+
11
+ voice_style (.json) ──▶ style_ttl (B, 50, 256), style_dp (B, 8, 16)
12
+
13
+ duration_predictor(text_ids, style_dp, text_mask) ──▶ duration_s (B,)
14
+
15
+ text_encoder(text_ids, style_ttl, text_mask) ──▶ text_emb (B, 256, T_text)
16
+
17
+ noise ~ N(0, I) of shape (B, 144, T_lat)
18
+ where T_lat = ceil(duration_s × 44100 / (512 × 6))
19
+
20
+ vector_estimator 5-step Euler with CFG (4×cond − 3×uncond):
21
+ for step in [0..4]:
22
+ x ← VE(x, text_emb, style_ttl, masks, current_step=step+1, total_step=5)
23
+
24
+ vocoder(audio_latent) ──▶ wav (B, T_lat × 6 × 512)
25
+
26
+ Public API:
27
+
28
+ pipe = SupertonicMLXPipeline.from_pretrained("/tmp/supertonic3/model")
29
+ wav = pipe.generate("Hello world", voice="F1", lang="en")
30
+ import soundfile as sf
31
+ sf.write("out.wav", wav, pipe.sample_rate)
32
+ """
33
+ from __future__ import annotations
34
+
35
+ import json
36
+ import math
37
+ from pathlib import Path
38
+ from typing import Optional
39
+
40
+ import mlx.core as mx
41
+ import numpy as np
42
+
43
+ from supertonic_3_mlx._config import SAMPLE_RATE
44
+ from supertonic_3_mlx.duration_predictor import DurationPredictor
45
+ from supertonic_3_mlx.text_encoder import TextEncoder
46
+ from supertonic_3_mlx.vector_estimator import VectorEstimator
47
+ from supertonic_3_mlx.vocoder import Vocoder
48
+
49
+
50
+ # Latent rate: at 44.1 kHz with hop=512 and chunk_compress=6, one latent step
51
+ # covers 512 × 6 = 3072 samples = 69.7 ms.
52
+ SAMPLES_PER_LATENT_STEP = 512 * 6 # 3072
53
+
54
+
55
+ # ── Shared ONNX → MLX weight extraction ─────────────────────────────
56
+
57
+
58
+ def _convert_onnx(onnx_path: str | Path) -> dict:
59
+ """Return a dict of ``{clean_key: mx.array}`` for a Supertonic ONNX file.
60
+
61
+ Combines the three extraction stages discovered during the per-component
62
+ ports (T.3.1, T.3.2, T.3.3):
63
+
64
+ 1. Named ``tts.*`` initialisers with shape transforms (dwconv, gamma,
65
+ pwconv, head.layer2).
66
+ 2. Anonymous MatMul weights recovered via the MatMul output path.
67
+ 3. Anonymous Conv weights and PReLU slopes recovered the same way.
68
+ """
69
+ import onnx
70
+ import onnx.numpy_helper as nh
71
+
72
+ m = onnx.load(str(onnx_path))
73
+
74
+ def _matmul_clean(out_name: str) -> str:
75
+ p = out_name.lstrip("/")
76
+ if p.endswith("/MatMul_output_0"):
77
+ p = p[: -len("/MatMul_output_0")]
78
+ # Drop the leading model-name path (e.g. /text_encoder/, /duration_predictor/, /vector_estimator/)
79
+ for prefix in ("text_encoder/", "duration_predictor/", "vector_estimator/", "vocoder/"):
80
+ if p.startswith(prefix):
81
+ p = p[len(prefix):]
82
+ break
83
+ return p.replace("/", ".") + ".weight"
84
+
85
+ def _conv_clean(out_name: str) -> str:
86
+ p = out_name.lstrip("/")
87
+ if p.endswith("/Conv_output_0"):
88
+ p = p[: -len("/Conv_output_0")]
89
+ for prefix in ("vocoder/", "vector_estimator/", "text_encoder/", "duration_predictor/"):
90
+ if p.startswith(prefix):
91
+ p = p[len(prefix):]
92
+ break
93
+ return "tts.ae." + p.replace("/", ".")
94
+
95
+ def _prelu_clean(out_name: str) -> str:
96
+ p = out_name.lstrip("/")
97
+ if p.endswith("/PRelu_output_0"):
98
+ p = p[: -len("/PRelu_output_0")]
99
+ for prefix in ("vocoder/", "vector_estimator/"):
100
+ if p.startswith(prefix):
101
+ p = p[len(prefix):]
102
+ break
103
+ return "tts.ae." + p.replace("/", ".") + ".weight"
104
+
105
+ # Detect which model this file is — affects how we wrap named init keys
106
+ name_prefixes = {init.name.split(".")[0] for init in m.graph.initializer if "." in init.name}
107
+ is_text_encoder = "tts" in name_prefixes and any(
108
+ i.name.startswith("tts.ttl.text_encoder") for i in m.graph.initializer
109
+ )
110
+
111
+ weights: dict[str, mx.array] = {}
112
+
113
+ # Stage 1: named initialisers
114
+ for init in m.graph.initializer:
115
+ n = init.name
116
+ # Determine if this is a structured (named) weight or an anonymous graph const
117
+ if not (n.startswith("tts.") or "vector_estimator.tts.ttl." in n or "uncond_masker." in n):
118
+ continue
119
+
120
+ # Strip the vector_estimator-specific prefix so all 4 models share a name space.
121
+ if n.startswith("vector_estimator.tts.ttl."):
122
+ clean = n[len("vector_estimator.tts.ttl."):]
123
+ else:
124
+ clean = n
125
+
126
+ arr = nh.to_array(init)
127
+
128
+ # Shape transforms
129
+ if (clean.endswith(".dwconv.weight") and arr.ndim == 3
130
+ and arr.shape[1] == 1 and arr.shape[2] != 1):
131
+ arr = np.transpose(arr, (0, 2, 1))
132
+ if (clean.endswith(".dwconv.net.weight") and arr.ndim == 3
133
+ and arr.shape[1] == 1):
134
+ arr = np.transpose(arr, (0, 2, 1))
135
+ if (clean.endswith(".gamma") and arr.ndim == 3
136
+ and arr.shape[0] == 1 and arr.shape[2] == 1):
137
+ arr = arr.reshape(arr.shape[1])
138
+ if ((clean.endswith(".pwconv1.weight") or clean.endswith(".pwconv2.weight"))
139
+ and arr.ndim == 3 and arr.shape[-1] == 1):
140
+ arr = arr.squeeze(-1)
141
+ if clean.endswith(".net.weight") and arr.ndim == 3 and arr.shape[-1] == 1:
142
+ # Conv1d k=1 wrapped via .net (e.g. proj_in/proj_out)
143
+ arr = arr.squeeze(-1)
144
+ # vocoder head.layer2 (out, in, 1) → MLX Conv1d (out, K=1, in)
145
+ if clean == "tts.ae.decoder.head.layer2.weight" and arr.ndim == 3:
146
+ arr = np.transpose(arr, (0, 2, 1))
147
+ # vocoder head.layer1.net.weight (out, in, K) → MLX Conv1d (out, K, in)
148
+ if clean == "tts.ae.decoder.head.layer1.net.weight" and arr.ndim == 3:
149
+ arr = np.transpose(arr, (0, 2, 1))
150
+
151
+ weights[clean] = mx.array(arr)
152
+
153
+ # Stage 2: MatMul weight recovery
154
+ inits_map = {init.name: init for init in m.graph.initializer}
155
+ for node in m.graph.node:
156
+ if node.op_type != "MatMul" or len(node.input) < 2:
157
+ continue
158
+ winp = node.input[1]
159
+ if winp not in inits_map or winp.startswith("tts.") or "vector_estimator.tts" in winp:
160
+ continue
161
+ arr = nh.to_array(inits_map[winp])
162
+ if arr.ndim == 2:
163
+ arr = arr.T # ONNX (in, out) → MLX Linear (out, in)
164
+ clean = _matmul_clean(node.output[0])
165
+ # Build the leading namespace from the file context (already in tts.*)
166
+ if not clean.startswith(("tts.", "vector_field.", "uncond_masker.")):
167
+ clean = "tts.ttl." + clean if is_text_encoder else clean
168
+ weights[clean] = mx.array(arr)
169
+
170
+ # Stage 3: anonymous Conv + PReLU (vocoder embed / head)
171
+ for node in m.graph.node:
172
+ if node.op_type == "Conv":
173
+ for i, inp in enumerate(node.input[1:], 1):
174
+ if inp not in inits_map or inp.startswith("tts."):
175
+ continue
176
+ arr = nh.to_array(inits_map[inp])
177
+ base = _conv_clean(node.output[0])
178
+ if "dwconv" in base:
179
+ continue
180
+ if i == 1 and arr.ndim == 3:
181
+ arr = np.transpose(arr, (0, 2, 1)) # ONNX (out, in, K) → MLX (out, K, in)
182
+ key = base + (".weight" if i == 1 else ".bias")
183
+ weights[key] = mx.array(arr)
184
+ elif node.op_type == "PRelu":
185
+ for inp in node.input[1:]:
186
+ if inp in inits_map and not inp.startswith("tts."):
187
+ weights[_prelu_clean(node.output[0])] = mx.array(nh.to_array(inits_map[inp]))
188
+
189
+ return weights
190
+
191
+
192
+ def _load_into(model, weights: dict) -> int:
193
+ """Match converted weights to model params (shape-tolerant via reshape).
194
+
195
+ Returns the number of successfully matched tensors.
196
+ """
197
+ from mlx.utils import tree_flatten
198
+ expected = {k: tuple(v.shape) for k, v in tree_flatten(model.parameters())}
199
+ matched = {}
200
+ for k, exp_shape in expected.items():
201
+ if k not in weights:
202
+ continue
203
+ v = weights[k]
204
+ if tuple(v.shape) != exp_shape:
205
+ if v.size == np.prod(exp_shape):
206
+ v = v.reshape(exp_shape)
207
+ else:
208
+ continue
209
+ matched[k] = v
210
+ model.load_weights(list(matched.items()), strict=False)
211
+ return len(matched)
212
+
213
+
214
+ # ── Tokenization ────────────────────────────────────────────────────
215
+
216
+
217
+ def _encode_text(text: str, indexer: list[int], lang: str = "en") -> np.ndarray:
218
+ """Encode a text string into character IDs.
219
+
220
+ The unicode_indexer is a flat list of size 65536; ``indexer[ord(c)]`` gives
221
+ the token ID for character ``c`` (-1 = unknown). For Phase T.4 we wrap the
222
+ text with no special language tokens — the ONNX SDK uses language tags but
223
+ our pipeline currently runs unconditioned on language for the first WAV
224
+ emission (parity validation happens after).
225
+ """
226
+ ids = []
227
+ for c in text:
228
+ cp = ord(c)
229
+ if 0 <= cp < len(indexer):
230
+ tok = indexer[cp]
231
+ if tok >= 0:
232
+ ids.append(tok)
233
+ if not ids:
234
+ # fallback to a single space token to avoid empty input
235
+ ids = [indexer[ord(" ")]] if indexer[ord(" ")] >= 0 else [0]
236
+ return np.asarray(ids, dtype=np.int32)
237
+
238
+
239
+ # ── Pipeline ────────────────────────────────────────────────────────
240
+
241
+
242
+ class SupertonicMLXPipeline:
243
+ """End-to-end Supertonic 3 TTS pipeline in pure MLX.
244
+
245
+ Loads four sub-models (duration_predictor, text_encoder, vector_estimator,
246
+ vocoder), the unicode tokenizer, and exposes ``generate(text, voice, lang)``.
247
+ """
248
+
249
+ sample_rate: int = SAMPLE_RATE
250
+ # Locked by the model architecture: Supertonic 3 is a flow-matching + CFG
251
+ # model trained for exactly 5 Euler steps with t ∈ {0.2, 0.4, 0.6, 0.8, 1.0}
252
+ # and the combination 4×cond − 3×uncond. Any other step count or skipping
253
+ # CFG produces an essentially uncorrelated waveform (verified by
254
+ # ``sub-projects/supertonic3-mlx/bench_n_steps.py``: cosine drops to
255
+ # ≤ 0.5 for n∈{3,4,6} and ≈ 0.05 for cfg=False). Reducing inference
256
+ # latency further would require distilling a shorter-schedule model.
257
+ n_euler_steps: int = 5
258
+
259
+ def __init__(
260
+ self,
261
+ duration_predictor: DurationPredictor,
262
+ text_encoder: TextEncoder,
263
+ vector_estimator: VectorEstimator,
264
+ vocoder: Vocoder,
265
+ unicode_indexer: list[int],
266
+ voice_dir: Path,
267
+ ) -> None:
268
+ self.duration_predictor = duration_predictor
269
+ self.text_encoder = text_encoder
270
+ self.vector_estimator = vector_estimator
271
+ self.vocoder = vocoder
272
+ self.unicode_indexer = unicode_indexer
273
+ self.voice_dir = voice_dir
274
+
275
+ # T.5 — compile the hot loops. ``mx.compile`` caches a kernel graph keyed
276
+ # by input shapes; the 5× CFG Euler loop and the single vocoder pass
277
+ # both gain from fused kernel dispatch (~50–100 layer ops collapse into
278
+ # one dispatch per cached graph).
279
+
280
+ # T.5.3 — also pre-project text and style K/V outside the step. They
281
+ # are invariant across the 5 Euler steps, so the 4 text_attn + 4
282
+ # style_attn blocks no longer re-run their W_key / W_value / RoPE_K
283
+ # matmuls on every step (saves 40 matmuls per generate).
284
+ cond_scale = self.vector_estimator.CFG_COND_SCALE
285
+ uncond_scale = self.vector_estimator.CFG_UNCOND_SCALE
286
+
287
+ def _cached_step(
288
+ noisy, lat_mask_2, text_mask_2, t_norm_2, total_step, kv_flat,
289
+ ):
290
+ noisy_2 = mx.concatenate([noisy, noisy], axis=0)
291
+ text_kv = [(kv_flat[2 * i], kv_flat[2 * i + 1]) for i in range(4)]
292
+ style_kv = [(kv_flat[8 + 2 * i], kv_flat[8 + 2 * i + 1]) for i in range(4)]
293
+ v_2 = self.vector_estimator.velocity_cached(
294
+ noisy_2, lat_mask_2, text_mask_2, t_norm_2, text_kv, style_kv,
295
+ )
296
+ B = noisy.shape[0]
297
+ cond_v = v_2[:B]
298
+ uncond_v = v_2[B:]
299
+ combined = cond_scale * cond_v - uncond_scale * uncond_v
300
+ return noisy + combined / total_step.reshape(-1, 1, 1).astype(combined.dtype)
301
+
302
+ def _voc_step(latent):
303
+ return self.vocoder(latent)
304
+
305
+ self._cached_step_compiled = mx.compile(_cached_step)
306
+ self._voc_compiled = mx.compile(_voc_step)
307
+
308
+ # Pick the runtime dtype from any leaf weight of the vector estimator —
309
+ # ``from_pretrained(dtype=...)`` may have cast the model to ``bf16``,
310
+ # in which case all inputs to the compiled hot loops must be cast to
311
+ # match (mixed-dtype Conv/MatMul is not legal in MLX).
312
+ from mlx.utils import tree_flatten
313
+ leaves = [v for _, v in tree_flatten(vector_estimator.parameters())
314
+ if isinstance(v, mx.array)]
315
+ self.dtype = leaves[0].dtype if leaves else mx.float32
316
+
317
+ @classmethod
318
+ def from_pretrained(
319
+ cls,
320
+ model_id_or_path: str | Path,
321
+ dtype: mx.Dtype | None = None,
322
+ cache_dir: str | Path | None = None,
323
+ revision: str | None = None,
324
+ ) -> "SupertonicMLXPipeline":
325
+ """Construct the pipeline from a model snapshot.
326
+
327
+ Three sources are accepted, auto-detected:
328
+
329
+ 1. **Hugging Face Hub repo id** (e.g. ``"ambassadia/supertonic-3-mlx"``):
330
+ weights are downloaded via :func:`huggingface_hub.snapshot_download`
331
+ into ``cache_dir`` (defaults to the standard HF cache) and loaded
332
+ directly from the bundled ``weights/*.safetensors`` files.
333
+ 2. **Local path with a** ``weights/`` **subdir**: the MLX-native
334
+ layout (4 safetensors + ``unicode_indexer.json`` + ``voice_styles/``).
335
+ Fast path — no ONNX conversion at runtime.
336
+ 3. **Local path with an** ``onnx/`` **subdir**: the upstream
337
+ ``Supertone/supertonic-3`` snapshot layout. Weights are converted
338
+ from ONNX on the fly (~ 1 s per sub-model on M4). Useful for
339
+ development or when starting from the original upstream release.
340
+
341
+ Optional kwargs:
342
+ dtype — if non-None and not float32, cast all weights to the
343
+ given dtype after load (only ``mx.bfloat16`` is
344
+ currently meaningful; see README "BF16 note").
345
+ cache_dir — passed to ``huggingface_hub.snapshot_download``.
346
+ revision — branch / tag / commit sha on the Hub.
347
+ """
348
+ # 1. Resolve the local snapshot directory
349
+ if isinstance(model_id_or_path, str) and "/" in model_id_or_path \
350
+ and not Path(model_id_or_path).exists():
351
+ try:
352
+ from huggingface_hub import snapshot_download
353
+ except ImportError as e:
354
+ raise ImportError(
355
+ "Loading from the Hugging Face Hub requires "
356
+ "``huggingface_hub`` — install with ``pip install "
357
+ "supertonic-3-mlx[hub]`` or ``pip install huggingface_hub``."
358
+ ) from e
359
+ local_dir = Path(snapshot_download(
360
+ repo_id=model_id_or_path,
361
+ cache_dir=cache_dir,
362
+ revision=revision,
363
+ allow_patterns=[
364
+ "weights/*.safetensors",
365
+ "unicode_indexer.json",
366
+ "voice_styles/*.json",
367
+ ],
368
+ ))
369
+ else:
370
+ local_dir = Path(model_id_or_path)
371
+
372
+ # 2. Detect layout
373
+ weights_dir = local_dir / "weights"
374
+ onnx_dir = local_dir / "onnx"
375
+ if weights_dir.exists():
376
+ return cls._from_safetensors(local_dir, dtype=dtype)
377
+ if onnx_dir.exists():
378
+ return cls._from_onnx(local_dir, dtype=dtype)
379
+ raise FileNotFoundError(
380
+ f"{local_dir} contains neither ``weights/`` (safetensors layout) "
381
+ f"nor ``onnx/`` (upstream layout); cannot load."
382
+ )
383
+
384
+ @classmethod
385
+ def _from_safetensors(
386
+ cls, local_dir: Path, dtype: mx.Dtype | None = None,
387
+ ) -> "SupertonicMLXPipeline":
388
+ from mlx.utils import tree_flatten
389
+ weights_dir = local_dir / "weights"
390
+ voice_dir = local_dir / "voice_styles"
391
+ unicode_indexer = json.loads((local_dir / "unicode_indexer.json").read_text())
392
+
393
+ def _build(cls_, name):
394
+ model = cls_()
395
+ w = mx.load(str(weights_dir / f"{name}.safetensors"))
396
+ # Reshape any mismatched leaves (defensive; the converter already
397
+ # produced shape-correct tensors but a future re-export may not).
398
+ expected = {k: tuple(v.shape) for k, v in tree_flatten(model.parameters())}
399
+ for k in list(w.keys()):
400
+ if k in expected and tuple(w[k].shape) != expected[k]:
401
+ if w[k].size == int(np.prod(expected[k])):
402
+ w[k] = w[k].reshape(expected[k])
403
+ model.load_weights(list(w.items()), strict=False)
404
+ return model
405
+
406
+ ve = _build(VectorEstimator, "vector_estimator")
407
+ te = _build(TextEncoder, "text_encoder")
408
+ dp = _build(DurationPredictor, "duration_predictor")
409
+ voc = _build(Vocoder, "vocoder")
410
+
411
+ if dtype is not None and dtype != mx.float32:
412
+ cls._cast_all(dp, te, ve, voc, dtype=dtype)
413
+
414
+ return cls(dp, te, ve, voc, unicode_indexer, voice_dir)
415
+
416
+ @classmethod
417
+ def _from_onnx(
418
+ cls, local_dir: Path, dtype: mx.Dtype | None = None,
419
+ ) -> "SupertonicMLXPipeline":
420
+ onnx_dir = local_dir / "onnx"
421
+ voice_dir = local_dir / "voice_styles"
422
+ unicode_indexer = json.loads((onnx_dir / "unicode_indexer.json").read_text())
423
+
424
+ ve = VectorEstimator()
425
+ _load_into(ve, _convert_onnx(onnx_dir / "vector_estimator.onnx"))
426
+ te = TextEncoder()
427
+ _load_into(te, _convert_onnx(onnx_dir / "text_encoder.onnx"))
428
+ dp = DurationPredictor()
429
+ _load_into(dp, _convert_onnx(onnx_dir / "duration_predictor.onnx"))
430
+ voc = Vocoder()
431
+ _load_into(voc, _convert_onnx(onnx_dir / "vocoder.onnx"))
432
+
433
+ if dtype is not None and dtype != mx.float32:
434
+ cls._cast_all(dp, te, ve, voc, dtype=dtype)
435
+
436
+ return cls(dp, te, ve, voc, unicode_indexer, voice_dir)
437
+
438
+ @staticmethod
439
+ def _cast_all(*models, dtype: mx.Dtype) -> None:
440
+ """Cast all fp32 leaves of each model to ``dtype`` (in-place)."""
441
+ from mlx.utils import tree_map
442
+
443
+ def _cast(p):
444
+ if not isinstance(p, mx.array) or p.dtype != mx.float32:
445
+ return p
446
+ return p.astype(dtype)
447
+
448
+ for m_ in models:
449
+ m_.update(tree_map(_cast, m_.parameters()))
450
+
451
+ def _load_voice(self, voice: str) -> tuple[mx.array, mx.array]:
452
+ """Load ``voice_styles/<voice>.json`` and return (style_ttl, style_dp)."""
453
+ path = self.voice_dir / f"{voice}.json"
454
+ data = json.loads(path.read_text())
455
+ style_ttl = np.asarray(data["style_ttl"]["data"], dtype=np.float32) # (1, 50, 256)
456
+ style_dp = np.asarray(data["style_dp"]["data"], dtype=np.float32) # (1, 8, 16)
457
+ return mx.array(style_ttl), mx.array(style_dp)
458
+
459
+ def generate(
460
+ self,
461
+ text: str,
462
+ voice: str = "F1",
463
+ lang: str = "en",
464
+ seed: int = 42,
465
+ n_steps: Optional[int] = None,
466
+ ) -> np.ndarray:
467
+ """Synthesise a single utterance. Returns a 1D float32 numpy waveform."""
468
+ n_steps = n_steps if n_steps is not None else self.n_euler_steps
469
+
470
+ # Tokenize
471
+ text_ids_np = _encode_text(text, self.unicode_indexer, lang)
472
+ text_ids = mx.array(text_ids_np[None, :]) # (1, T_text)
473
+ T_text = text_ids.shape[1]
474
+ text_mask = mx.ones((1, 1, T_text), dtype=self.dtype)
475
+
476
+ # Style
477
+ style_ttl, style_dp = self._load_voice(voice)
478
+ if self.dtype != mx.float32:
479
+ style_ttl = style_ttl.astype(self.dtype)
480
+ style_dp = style_dp.astype(self.dtype)
481
+
482
+ # Duration → latent length
483
+ duration_s = self.duration_predictor(text_ids, style_dp, text_mask)
484
+ mx.eval(duration_s)
485
+ duration_val = max(float(duration_s[0].item()), 0.5) # clamp to ≥ 0.5 s
486
+ T_lat = max(int(math.ceil(duration_val * self.sample_rate / SAMPLES_PER_LATENT_STEP)), 1)
487
+
488
+ # Text embedding
489
+ text_emb = self.text_encoder(text_ids, style_ttl, text_mask) # (1, 256, T_text)
490
+
491
+ # Initial noise — fixed seed for reproducibility
492
+ key = mx.random.key(seed)
493
+ noise = mx.random.normal((1, 144, T_lat), key=key).astype(self.dtype)
494
+ latent_mask = mx.ones((1, 1, T_lat), dtype=self.dtype)
495
+
496
+ # T.5.3 — build the (2B) CFG conditioning tensors once and pre-project
497
+ # K/V for every text_attn / style_attn block. ``kv_flat`` is the 16
498
+ # ``(K, V)`` arrays flattened into a list for the compiled step.
499
+ B = noise.shape[0]
500
+ ve = self.vector_estimator
501
+ text_uncond = mx.broadcast_to(
502
+ ve.uncond_masker.text_special_token, (B, text_emb.shape[1], text_emb.shape[2])
503
+ ).astype(self.dtype)
504
+ style_k_uncond = mx.broadcast_to(
505
+ ve.uncond_masker.style_key_special_token, (B, style_ttl.shape[1], style_ttl.shape[2])
506
+ ).astype(self.dtype)
507
+ style_v_uncond = mx.broadcast_to(
508
+ ve.uncond_masker.style_value_special_token, (B, style_ttl.shape[1], style_ttl.shape[2])
509
+ ).astype(self.dtype)
510
+ text_emb_2 = mx.concatenate([text_emb, text_uncond], axis=0)
511
+ style_k_2 = mx.concatenate([style_ttl, style_k_uncond], axis=0)
512
+ style_v_2 = mx.concatenate([style_ttl, style_v_uncond], axis=0)
513
+ text_mask_2 = mx.concatenate([text_mask, text_mask], axis=0)
514
+ latent_mask_2 = mx.concatenate([latent_mask, latent_mask], axis=0)
515
+
516
+ text_kv, style_kv = ve.precompute_cross_kv(
517
+ text_emb_2, style_k_2, style_v_2, text_mask_2,
518
+ )
519
+ kv_flat = []
520
+ for k, v in text_kv:
521
+ kv_flat.extend([k, v])
522
+ for k, v in style_kv:
523
+ kv_flat.extend([k, v])
524
+
525
+ # Euler with CFG — 5 steps by default
526
+ x = noise
527
+ total_step = mx.array([float(n_steps)], dtype=self.dtype)
528
+ for step in range(n_steps):
529
+ current_step = mx.array([float(step + 1)], dtype=self.dtype)
530
+ t_norm = current_step / total_step
531
+ t_norm_2 = mx.concatenate([t_norm, t_norm], axis=0)
532
+ x = self._cached_step_compiled(
533
+ x, latent_mask_2, text_mask_2, t_norm_2, total_step, kv_flat,
534
+ )
535
+ mx.eval(x)
536
+
537
+ # Decode latent → waveform
538
+ wav = self._voc_compiled(x)
539
+ mx.eval(wav)
540
+ if wav.dtype != mx.float32:
541
+ wav = wav.astype(mx.float32)
542
+ return np.array(wav)[0] # (T_lat × 6 × 512,)
543
+
544
+
545
+ __all__ = ["SupertonicMLXPipeline"]
src/supertonic_3_mlx/text_encoder.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Supertonic 3 text encoder MLX port.
2
+
3
+ Pipeline (operating in channels-last NTC after the initial conv):
4
+
5
+ text_ids [B, T_text] int64 character IDs
6
+ → char_embedder (Embedding 8322→256) [B, T_text, 256]
7
+ → 6× ConvNeXt(dim=256, hidden=1024, k=5, dilations [1,1,2,2,4,4])
8
+ → 4× attn_encoder block:
9
+ RelPosSelfAttn (conv_q/k/v/o, 4 heads × 64) + norm_layers_1
10
+ FFN (conv_1: 256→1024, conv_2: 1024→256) + norm_layers_2
11
+ → speech_prompted_text_encoder:
12
+ cross-attn1: text (Q) × style_ttl (K, V) → text features
13
+ cross-attn2: text (Q) × style_ttl (K, V) → text features
14
+ norm
15
+ → output text_emb [B, 256, T_text] (channels-first to match vector_estimator)
16
+
17
+ Inputs:
18
+ text_ids: (B, T_text) int — character indices
19
+ style_ttl: (B, 50, 256) float — style token bank
20
+ text_mask: (B, 1, T_text) float — 1.0 where valid, 0.0 where padded
21
+
22
+ Submodule naming matches the ONNX initializer keys exactly so that
23
+ ``model.load_weights(...)`` succeeds with no remapping.
24
+ """
25
+ from __future__ import annotations
26
+
27
+ import mlx.core as mx
28
+ import mlx.nn as nn
29
+
30
+ from supertonic_3_mlx._config import EPS_LN
31
+ from supertonic_3_mlx._nn_wrappers import WrappedNorm, WrappedLinear
32
+ from supertonic_3_mlx.vector_estimator import (
33
+ ConvNeXtBlock, _pad_sym_edge, _gelu_exact,
34
+ )
35
+
36
+
37
+ # Vocab + dims (frozen by checkpoint)
38
+ VOCAB_SIZE = 8322
39
+ TE_DIM = 256
40
+ TE_CONVNEXT_HIDDEN = 1024
41
+ TE_CONVNEXT_K = 5
42
+ TE_CONVNEXT_NUM_LAYERS = 6
43
+ TE_CONVNEXT_DILATIONS = (1, 1, 2, 2, 4, 4)
44
+
45
+ TE_ATTN_NUM_LAYERS = 4
46
+ TE_ATTN_HEADS = 4
47
+ TE_ATTN_HEAD_DIM = TE_DIM // TE_ATTN_HEADS # 64
48
+ TE_FFN_HIDDEN = 1024
49
+
50
+
51
+ class TextConvNeXtBlock(nn.Module):
52
+ """ConvNeXt for the text encoder (dim=256, hidden=1024).
53
+
54
+ Shares the same architecture as ``vector_estimator.ConvNeXtBlock`` but is
55
+ redefined here with text-encoder-specific defaults to keep the modules
56
+ self-contained.
57
+ """
58
+
59
+ def __init__(self, dilation: int = 1) -> None:
60
+ super().__init__()
61
+ self.dim = TE_DIM
62
+ self.dilation = dilation
63
+ self.pad = dilation * (TE_CONVNEXT_K - 1) // 2
64
+ self.dwconv = nn.Conv1d(
65
+ TE_DIM, TE_DIM, kernel_size=TE_CONVNEXT_K, padding=0,
66
+ dilation=dilation, groups=TE_DIM, bias=True,
67
+ )
68
+ self.norm = WrappedNorm(TE_DIM, eps=EPS_LN)
69
+ self.pwconv1 = nn.Linear(TE_DIM, TE_CONVNEXT_HIDDEN, bias=True)
70
+ self.pwconv2 = nn.Linear(TE_CONVNEXT_HIDDEN, TE_DIM, bias=True)
71
+ self.gamma = mx.zeros((TE_DIM,))
72
+
73
+ def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
74
+ # x: (B, T_text, 256)
75
+ residual = x
76
+ y = _pad_sym_edge(x, self.pad)
77
+ y = self.dwconv(y)
78
+ y = self.norm(y)
79
+ y = self.pwconv1(y)
80
+ y = _gelu_exact(y)
81
+ y = self.pwconv2(y)
82
+ y = y * self.gamma
83
+ out = residual + y
84
+ if mask is not None:
85
+ out = out * mask
86
+ return out
87
+
88
+
89
+ class TextConvNeXtStack(nn.Module):
90
+ """6 stacked ConvNeXt blocks. Loaded as ``convnext.convnext.[0..5].X``."""
91
+
92
+ def __init__(self) -> None:
93
+ super().__init__()
94
+ self.convnext = [TextConvNeXtBlock(d) for d in TE_CONVNEXT_DILATIONS]
95
+
96
+ def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
97
+ for b in self.convnext:
98
+ x = b(x, mask)
99
+ return x
100
+
101
+
102
+ class _ConvLayer(nn.Module):
103
+ """Conv1d k=1 expressed via the ONNX-style ``X.weight (out, in, 1) + X.bias``.
104
+
105
+ The attn_encoder uses Conv1d k=1 instead of nn.Linear for its Q/K/V/O.
106
+ This wrapper keeps the weight shape (out, in, 1) intact and runs as a
107
+ Conv1d (the equivalent of a Linear when k=1).
108
+ """
109
+
110
+ def __init__(self, in_dim: int, out_dim: int) -> None:
111
+ super().__init__()
112
+ self.weight = mx.zeros((out_dim, 1, in_dim)) # (C_out, K=1, C_in)
113
+ self.bias = mx.zeros((out_dim,))
114
+
115
+ def __call__(self, x: mx.array) -> mx.array:
116
+ # x: (B, T, in_dim) — channels-last
117
+ # equivalent to nn.Conv1d(in_dim, out_dim, k=1) in NTC layout
118
+ return mx.conv1d(x, self.weight, stride=1, padding=0) + self.bias
119
+
120
+
121
+ REL_POS_WINDOW = 4 # rel_pos table size = 2*4 + 1 = 9
122
+
123
+
124
+ def _rel_to_abs(x: mx.array) -> mx.array:
125
+ """[B, h, L, 2L-1] → [B, h, L, L] via the VITS shifted-skew reshape."""
126
+ B, h, L, _ = x.shape
127
+ x = mx.concatenate([x, mx.zeros((B, h, L, 1), dtype=x.dtype)], axis=-1)
128
+ x_flat = x.reshape(B, h, L * 2 * L)
129
+ x_flat = mx.concatenate([x_flat, mx.zeros((B, h, L - 1), dtype=x.dtype)], axis=-1)
130
+ x_final = x_flat.reshape(B, h, L + 1, 2 * L - 1)
131
+ return x_final[:, :, :L, L - 1:]
132
+
133
+
134
+ def _abs_to_rel(x: mx.array) -> mx.array:
135
+ """[B, h, L, L] → [B, h, L, 2L-1] (inverse of _rel_to_abs)."""
136
+ B, h, L, _ = x.shape
137
+ x = mx.concatenate([x, mx.zeros((B, h, L, L - 1), dtype=x.dtype)], axis=-1)
138
+ x_flat = x.reshape(B, h, L * (2 * L - 1))
139
+ x_flat = mx.concatenate([mx.zeros((B, h, L), dtype=x.dtype), x_flat], axis=-1)
140
+ x_final = x_flat.reshape(B, h, L, 2 * L)
141
+ return x_final[:, :, :, 1:]
142
+
143
+
144
+ def _slice_rel_emb(rel: mx.array, length: int, window: int) -> mx.array:
145
+ """``rel`` (1, 2W+1, d) → (1, 2L-1, d) by zero-padding/slicing."""
146
+ pad_l = max(length - (window + 1), 0)
147
+ if pad_l > 0:
148
+ zero = mx.zeros((1, pad_l, rel.shape[-1]), dtype=rel.dtype)
149
+ padded = mx.concatenate([zero, rel, zero], axis=1)
150
+ else:
151
+ padded = rel
152
+ start = max(window + 1 - length, 0)
153
+ return padded[:, start: start + 2 * length - 1]
154
+
155
+
156
+ class RelPosSelfAttention(nn.Module):
157
+ """VITS-style relative-position self-attention with window=4.
158
+
159
+ Adds two contributions to vanilla MHA:
160
+ - ``rel_logits = q @ rel_k.T`` then ``_rel_to_abs`` and added to attention logits
161
+ - ``rel_attn = _abs_to_rel(softmax(logits))`` then ``@ rel_v`` and added to output
162
+
163
+ Loaded keys (per layer):
164
+ ``conv_q/k/v/o.weight`` (256, 256, 1) and ``.bias`` (256)
165
+ ``emb_rel_k`` (1, 9, 64), ``emb_rel_v`` (1, 9, 64)
166
+ """
167
+
168
+ def __init__(self) -> None:
169
+ super().__init__()
170
+ self.conv_q = _ConvLayer(TE_DIM, TE_DIM)
171
+ self.conv_k = _ConvLayer(TE_DIM, TE_DIM)
172
+ self.conv_v = _ConvLayer(TE_DIM, TE_DIM)
173
+ self.conv_o = _ConvLayer(TE_DIM, TE_DIM)
174
+ self.window = REL_POS_WINDOW
175
+ self.emb_rel_k = mx.zeros((1, 2 * REL_POS_WINDOW + 1, TE_ATTN_HEAD_DIM))
176
+ self.emb_rel_v = mx.zeros((1, 2 * REL_POS_WINDOW + 1, TE_ATTN_HEAD_DIM))
177
+
178
+ def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
179
+ B, T, _ = x.shape
180
+ H, D = TE_ATTN_HEADS, TE_ATTN_HEAD_DIM
181
+ q = self.conv_q(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
182
+ k = self.conv_k(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
183
+ v = self.conv_v(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
184
+ scale = D ** -0.5
185
+
186
+ # Standard attention logits
187
+ logits = (q @ k.transpose(0, 1, 3, 2)) * scale # (B, H, T, T)
188
+
189
+ # VITS relative-position contribution to logits
190
+ rel_k = _slice_rel_emb(self.emb_rel_k, T, self.window) # (1, 2T-1, D)
191
+ rel_logits = q @ rel_k.transpose(0, 2, 1)[:, None, :, :] # (B, H, T, 2T-1)
192
+ rel_logits = _rel_to_abs(rel_logits * scale) # (B, H, T, T)
193
+ logits = logits + rel_logits
194
+
195
+ if mask is not None:
196
+ key_mask = mask[:, :, 0][:, None, None, :]
197
+ neg_inf = mx.array(-1e4, dtype=logits.dtype)
198
+ logits = mx.where(key_mask.astype(mx.bool_), logits, neg_inf)
199
+
200
+ attn = mx.softmax(logits, axis=-1) # (B, H, T, T)
201
+ out = attn @ v # (B, H, T, D)
202
+
203
+ # VITS rel-pos value contribution
204
+ rel_v = _slice_rel_emb(self.emb_rel_v, T, self.window) # (1, 2T-1, D)
205
+ rel_weights = _abs_to_rel(attn) # (B, H, T, 2T-1)
206
+ out = out + rel_weights @ rel_v[:, None, :, :] # (B, H, T, D)
207
+
208
+ out = out.transpose(0, 2, 1, 3).reshape(B, T, H * D)
209
+ return self.conv_o(out)
210
+
211
+
212
+ class FFN(nn.Module):
213
+ """FFN with Conv1d k=1 wrappers: conv_1 (256→1024) + ReLU + conv_2 (1024→256).
214
+
215
+ Activation is ReLU (confirmed by ONNX graph node ``Relu`` in ``ffn_layers.N``),
216
+ not GELU. The mask is applied before each Conv to match the ONNX semantics.
217
+ """
218
+
219
+ def __init__(self) -> None:
220
+ super().__init__()
221
+ self.conv_1 = _ConvLayer(TE_DIM, TE_FFN_HIDDEN)
222
+ self.conv_2 = _ConvLayer(TE_FFN_HIDDEN, TE_DIM)
223
+
224
+ def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
225
+ if mask is not None:
226
+ x = x * mask
227
+ y = self.conv_1(x)
228
+ y = mx.maximum(y, mx.array(0.0, dtype=y.dtype))
229
+ if mask is not None:
230
+ y = y * mask
231
+ y = self.conv_2(y)
232
+ if mask is not None:
233
+ y = y * mask
234
+ return y
235
+
236
+
237
+ class AttnEncoder(nn.Module):
238
+ """Stack of (RelPosSelfAttn + norm1) + (FFN + norm2) × 4."""
239
+
240
+ def __init__(self) -> None:
241
+ super().__init__()
242
+ self.attn_layers = [RelPosSelfAttention() for _ in range(TE_ATTN_NUM_LAYERS)]
243
+ self.norm_layers_1 = [WrappedNorm(TE_DIM, eps=EPS_LN) for _ in range(TE_ATTN_NUM_LAYERS)]
244
+ self.ffn_layers = [FFN() for _ in range(TE_ATTN_NUM_LAYERS)]
245
+ self.norm_layers_2 = [WrappedNorm(TE_DIM, eps=EPS_LN) for _ in range(TE_ATTN_NUM_LAYERS)]
246
+
247
+ def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
248
+ for i in range(TE_ATTN_NUM_LAYERS):
249
+ y = self.attn_layers[i](x, mask=mask)
250
+ x = self.norm_layers_1[i](x + y)
251
+ y = self.ffn_layers[i](x, mask)
252
+ x = self.norm_layers_2[i](x + y)
253
+ return x
254
+
255
+
256
+ class _TextEmbedder(nn.Module):
257
+ """char_embedder: VOCAB → TE_DIM. Loaded as ``char_embedder.weight (8322, 256)``."""
258
+
259
+ def __init__(self) -> None:
260
+ super().__init__()
261
+ self.char_embedder = nn.Embedding(VOCAB_SIZE, TE_DIM)
262
+
263
+ def __call__(self, text_ids: mx.array) -> mx.array:
264
+ return self.char_embedder(text_ids)
265
+
266
+
267
+ class _InnerTextEncoder(nn.Module):
268
+ """Pure text encoder before speech prompting. Loaded as ``text_encoder.X.Y``."""
269
+
270
+ def __init__(self) -> None:
271
+ super().__init__()
272
+ self.text_embedder = _TextEmbedder()
273
+ self.convnext = TextConvNeXtStack()
274
+ self.attn_encoder = AttnEncoder()
275
+
276
+ def __call__(self, text_ids: mx.array, mask: mx.array) -> mx.array:
277
+ x = self.text_embedder(text_ids) # (B, T, 256)
278
+ if mask is not None:
279
+ x = x * mask
280
+ x = self.convnext(x, mask)
281
+ x = self.attn_encoder(x, mask)
282
+ return x
283
+
284
+
285
+ class _StyleEncoder(nn.Module):
286
+ """Holds ``style_token_layer.style_key`` (1, 50, 256)."""
287
+
288
+ def __init__(self) -> None:
289
+ super().__init__()
290
+ # Use a child module so the parameter path matches ``style_token_layer.style_key``
291
+ class _StyleTokenLayer(nn.Module):
292
+ def __init__(_):
293
+ super().__init__()
294
+ _.style_key = mx.zeros((1, 50, 256))
295
+ self.style_token_layer = _StyleTokenLayer()
296
+
297
+
298
+ class _SpeechPromptedAttn(nn.Module):
299
+ """Cross-attention from text (Q) to style_ttl (K, V). Single head, 256-d."""
300
+
301
+ def __init__(self) -> None:
302
+ super().__init__()
303
+ self.W_query = WrappedLinear(TE_DIM, TE_DIM, bias=True)
304
+ self.W_key = WrappedLinear(TE_DIM, TE_DIM, bias=True)
305
+ self.W_value = WrappedLinear(TE_DIM, TE_DIM, bias=True)
306
+ self.out_fc = WrappedLinear(TE_DIM, TE_DIM, bias=True)
307
+
308
+ def __call__(self, x: mx.array, style: mx.array) -> mx.array:
309
+ # x: (B, T_text, 256); style: (B, 50, 256)
310
+ # Single-head cross attention.
311
+ B, T, D = x.shape
312
+ q = self.W_query(x)
313
+ k = self.W_key(style)
314
+ v = self.W_value(style)
315
+ scale = D ** -0.5
316
+ logits = (q @ k.transpose(0, 2, 1)) * scale
317
+ attn = mx.softmax(logits, axis=-1)
318
+ out = attn @ v
319
+ return self.out_fc(out)
320
+
321
+
322
+ class _SpeechPromptedTextEncoder(nn.Module):
323
+ """Two cross-attention layers modulating text features with style_ttl."""
324
+
325
+ def __init__(self) -> None:
326
+ super().__init__()
327
+ self.attention1 = _SpeechPromptedAttn()
328
+ self.attention2 = _SpeechPromptedAttn()
329
+ self.norm = WrappedNorm(TE_DIM, eps=EPS_LN)
330
+
331
+ def __call__(self, x: mx.array, style: mx.array) -> mx.array:
332
+ x = x + self.attention1(x, style)
333
+ x = x + self.attention2(x, style)
334
+ return self.norm(x)
335
+
336
+
337
+ class _RootTextEncoder(nn.Module):
338
+ """Top-level container matching ONNX ``tts.ttl.*`` namespace."""
339
+
340
+ def __init__(self) -> None:
341
+ super().__init__()
342
+ self.text_encoder = _InnerTextEncoder()
343
+ self.style_encoder = _StyleEncoder()
344
+ self.speech_prompted_text_encoder = _SpeechPromptedTextEncoder()
345
+
346
+
347
+ class _TtsContainer(nn.Module):
348
+ """Outer container so weight keys ``tts.ttl.X.Y`` resolve."""
349
+
350
+ def __init__(self) -> None:
351
+ super().__init__()
352
+ self.ttl = _RootTextEncoder()
353
+
354
+
355
+ class TextEncoder(nn.Module):
356
+ """Top-level text encoder: ``text_ids + style_ttl + text_mask → text_emb (B, 256, T)``.
357
+
358
+ Submodule naming matches the ONNX initializer keys after a single
359
+ ``tts.ttl.`` prefix wrap (so weight keys look like
360
+ ``tts.ttl.text_encoder.convnext.convnext.0.dwconv.weight``).
361
+ """
362
+
363
+ def __init__(self) -> None:
364
+ super().__init__()
365
+ self.tts = _TtsContainer()
366
+
367
+ def __call__(
368
+ self,
369
+ text_ids: mx.array, # (B, T_text) int
370
+ style_ttl: mx.array, # (B, 50, 256)
371
+ text_mask: mx.array, # (B, 1, T_text)
372
+ ) -> mx.array:
373
+ mask_ntc = text_mask.transpose(0, 2, 1) # (B, T_text, 1)
374
+ x = self.tts.ttl.text_encoder(text_ids, mask_ntc)
375
+ x = self.tts.ttl.speech_prompted_text_encoder(x, style_ttl)
376
+ if mask_ntc is not None:
377
+ x = x * mask_ntc
378
+ # Return channels-first (B, 256, T_text) to match the vector_estimator input.
379
+ return x.transpose(0, 2, 1)
380
+
381
+
382
+ __all__ = ["TextEncoder", "VOCAB_SIZE", "TE_DIM"]
src/supertonic_3_mlx/vector_estimator.py ADDED
@@ -0,0 +1,765 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Supertonic 3 vector estimator (64 M params) — flow-matching denoiser, MLX port.
2
+
3
+ Pipeline (operating in channels-last NTC layout):
4
+
5
+ noisy_latent [B, 144, T_lat] (channels first from ONNX I/O)
6
+ → transpose [B, T_lat, 144]
7
+ → proj_in (Linear 144→512) [B, T_lat, 512]
8
+ → 24 main_blocks (4 cycles × 6 sub-types):
9
+ cycle = [stack4, time_film, cn1, text_attn, cn1, style_attn]
10
+ → last_convnext (4 ConvNeXt) [B, T_lat, 512]
11
+ → proj_out (Linear 512→144) [B, T_lat, 144]
12
+ → transpose [B, 144, T_lat]
13
+ → Euler step: denoised = noisy + velocity * (1 / total_step)
14
+ → output [B, 144, T_lat]
15
+
16
+ Submodule naming matches the s3 ONNX initializer keys exactly, so loading
17
+ the safetensors produced by ``weights.convert_onnx_to_mlx`` requires no
18
+ remapping.
19
+
20
+ The forward path is faithful to ONNX semantics in fp32; ``mx.compile``,
21
+ quantisation, and kernel fusion are layered on later in T.3.
22
+ """
23
+ from __future__ import annotations
24
+
25
+ import math
26
+
27
+ import mlx.core as mx
28
+ import mlx.nn as nn
29
+
30
+ from supertonic_3_mlx._config import (
31
+ DIM, LATENT_CH, CONVNEXT_HIDDEN, CONVNEXT_K, STACK4_DILATIONS,
32
+ NUM_MAIN_BLOCKS, BLOCKS_PER_CYCLE, BLOCK_CYCLE,
33
+ TEXT_DIM, TEXT_HEADS, TEXT_HEAD_DIM, ROTARY_BASE, ROTARY_SCALE,
34
+ STYLE_DIM, STYLE_LEN, STYLE_HEADS, STYLE_HEAD_DIM,
35
+ TIME_EMB_DIM, TIME_MLP_HIDDEN,
36
+ EPS_LN,
37
+ )
38
+ from supertonic_3_mlx._nn_wrappers import (
39
+ WrappedNorm, WrappedLinear, ProjConv1x1,
40
+ )
41
+
42
+
43
+ def _pad_sym_edge(x: mx.array, pad: int) -> mx.array:
44
+ """Symmetric replicate-edge pad on the time axis (axis=1 for [B, T, C])."""
45
+ if pad == 0:
46
+ return x
47
+ left = mx.broadcast_to(x[:, :1, :], (x.shape[0], pad, x.shape[2]))
48
+ right = mx.broadcast_to(x[:, -1:, :], (x.shape[0], pad, x.shape[2]))
49
+ return mx.concatenate([left, x, right], axis=1)
50
+
51
+
52
+ def _gelu_exact(x: mx.array) -> mx.array:
53
+ """Exact (non-tanh) GELU: x * 0.5 * (1 + erf(x / sqrt(2)))."""
54
+ return x * 0.5 * (1.0 + mx.erf(x * (2 ** -0.5)))
55
+
56
+
57
+ def _mish(x: mx.array) -> mx.array:
58
+ """Mish: x * tanh(softplus(x)) = x * tanh(log(1 + exp(x)))."""
59
+ return x * mx.tanh(mx.logaddexp(x, mx.array(0.0, dtype=x.dtype)))
60
+
61
+
62
+ # ──────────────────────────────────────────────────────────────────
63
+ # ConvNeXt building blocks
64
+ # ──────────────────────────────────────────────────────────────────
65
+
66
+
67
+ class ConvNeXtBlock(nn.Module):
68
+ """Single ConvNeXt block matching s3 keys: ``dwconv``, ``norm.norm``, ``pwconv1/2``, ``gamma``."""
69
+
70
+ def __init__(
71
+ self,
72
+ dim: int = DIM,
73
+ hidden: int = CONVNEXT_HIDDEN,
74
+ kernel: int = CONVNEXT_K,
75
+ dilation: int = 1,
76
+ ) -> None:
77
+ super().__init__()
78
+ self.dim = dim
79
+ self.dilation = dilation
80
+ self.pad = dilation * (kernel - 1) // 2
81
+ self.dwconv = nn.Conv1d(
82
+ dim, dim, kernel_size=kernel, padding=0, dilation=dilation,
83
+ groups=dim, bias=True,
84
+ )
85
+ self.norm = WrappedNorm(dim, eps=EPS_LN)
86
+ self.pwconv1 = nn.Linear(dim, hidden, bias=True)
87
+ self.pwconv2 = nn.Linear(hidden, dim, bias=True)
88
+ # Stored as shape (1, dim, 1) in the ONNX checkpoint — see weights.py for
89
+ # the load-time reshape that flattens it to (dim,) for broadcasting in NTC.
90
+ self.gamma = mx.zeros((dim,))
91
+
92
+ def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
93
+ # x: (B, T, C)
94
+ residual = x
95
+ y = _pad_sym_edge(x, self.pad)
96
+ y = self.dwconv(y) # (B, T, C)
97
+ y = self.norm(y) # LayerNorm last-dim
98
+ y = self.pwconv1(y) # (B, T, hidden)
99
+ y = _gelu_exact(y)
100
+ y = self.pwconv2(y) # (B, T, C)
101
+ y = y * self.gamma # broadcast over (B, T, .)
102
+ out = residual + y
103
+ if mask is not None:
104
+ out = out * mask
105
+ return out
106
+
107
+
108
+ class ConvNeXtStack(nn.Module):
109
+ """List of ConvNeXt blocks. Loaded as ``convnext.[0..N-1].X``."""
110
+
111
+ def __init__(self, dilations: tuple, dim: int = DIM, hidden: int = CONVNEXT_HIDDEN) -> None:
112
+ super().__init__()
113
+ self.convnext = [ConvNeXtBlock(dim, hidden, CONVNEXT_K, d) for d in dilations]
114
+
115
+ def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
116
+ for b in self.convnext:
117
+ x = b(x, mask)
118
+ return x
119
+
120
+
121
+ # ──────────────────────────────────────────────────────────────────
122
+ # 6 block types per cycle
123
+ # ───────────────────────────��──────────────────────────────────────
124
+
125
+
126
+ class Stack4Block(nn.Module):
127
+ """Cycle position 0 — 4 ConvNeXt with dilations [1, 2, 4, 8].
128
+
129
+ Loaded keys: ``convnext.[0..3].{dwconv,norm.norm,pwconv1,pwconv2,gamma}``.
130
+ """
131
+
132
+ def __init__(self) -> None:
133
+ super().__init__()
134
+ self.convnext = [ConvNeXtBlock(DIM, CONVNEXT_HIDDEN, CONVNEXT_K, d) for d in STACK4_DILATIONS]
135
+
136
+ def __call__(self, x: mx.array, mask: mx.array | None, **_) -> mx.array:
137
+ for b in self.convnext:
138
+ x = b(x, mask)
139
+ return x
140
+
141
+
142
+ class TimeFiLMBlock(nn.Module):
143
+ """Cycle position 1 — additive time conditioning: ``x + linear(t_emb)``.
144
+
145
+ Loaded keys: ``linear.linear.{weight,bias}``.
146
+ """
147
+
148
+ def __init__(self) -> None:
149
+ super().__init__()
150
+ self.linear = WrappedLinear(TIME_EMB_DIM, DIM, bias=True)
151
+
152
+ def __call__(self, x: mx.array, mask: mx.array | None, t_emb: mx.array, **_) -> mx.array:
153
+ # t_emb: (B, TIME_EMB_DIM) → broadcast across T
154
+ bias = self.linear(t_emb)[:, None, :] # (B, 1, DIM)
155
+ y = x + bias
156
+ if mask is not None:
157
+ y = y * mask
158
+ return y
159
+
160
+
161
+ class ConvNeXt1Block(nn.Module):
162
+ """Cycle positions 2 and 4 — a single ConvNeXt block.
163
+
164
+ Loaded keys: ``convnext.0.{dwconv,norm.norm,pwconv1,pwconv2,gamma}``.
165
+ """
166
+
167
+ def __init__(self) -> None:
168
+ super().__init__()
169
+ self.convnext = [ConvNeXtBlock(DIM, CONVNEXT_HIDDEN, CONVNEXT_K, 1)]
170
+
171
+ def __call__(self, x: mx.array, mask: mx.array | None, **_) -> mx.array:
172
+ return self.convnext[0](x, mask)
173
+
174
+
175
+ def _build_rope_freqs(head_dim: int, base: int, scale: int, max_len: int = 1024) -> mx.array:
176
+ """Pre-compute RoPE cos/sin table — (max_len, head_dim/2, 2)."""
177
+ half = head_dim // 2
178
+ inv_freq = 1.0 / (base ** (mx.arange(half, dtype=mx.float32) / half))
179
+ pos = mx.arange(max_len, dtype=mx.float32) * scale
180
+ angles = pos[:, None] * inv_freq[None, :] # (max_len, half)
181
+ return mx.stack([mx.cos(angles), mx.sin(angles)], axis=-1) # (max_len, half, 2)
182
+
183
+
184
+ def _apply_rope(x: mx.array, freqs: mx.array) -> mx.array:
185
+ """Apply RoPE rotation. ``x`` shape (B, H, T, head_dim); ``freqs`` (T, half, 2)."""
186
+ half = x.shape[-1] // 2
187
+ x_even, x_odd = x[..., :half], x[..., half:]
188
+ cos = freqs[..., 0] # (T, half)
189
+ sin = freqs[..., 1]
190
+ rot_even = x_even * cos[None, None, :, :] - x_odd * sin[None, None, :, :]
191
+ rot_odd = x_even * sin[None, None, :, :] + x_odd * cos[None, None, :, :]
192
+ return mx.concatenate([rot_even, rot_odd], axis=-1)
193
+
194
+
195
+ class TextCrossAttnBlock(nn.Module):
196
+ """Cycle position 3 — text cross-attention with RoPE on Q and K.
197
+
198
+ Loaded keys:
199
+ ``attn.W_query.linear.{weight,bias}``
200
+ ``attn.W_key.linear.{weight,bias}``
201
+ ``attn.W_value.linear.{weight,bias}``
202
+ ``attn.out_fc.linear.{weight,bias}``
203
+ ``attn.theta`` — frozen RoPE inv-freq table (1, 1, half)
204
+ ``attn.increments`` — frozen position table (1, 1000, 1) — 0..999
205
+ ``norm.norm.{weight,bias}``
206
+ """
207
+
208
+ def __init__(self) -> None:
209
+ super().__init__()
210
+ self.attn = _AttnInner(DIM, TEXT_DIM, TEXT_HEADS, TEXT_HEAD_DIM)
211
+ self.norm = WrappedNorm(DIM, eps=EPS_LN)
212
+
213
+ def __call__(
214
+ self,
215
+ x: mx.array,
216
+ mask: mx.array | None,
217
+ *,
218
+ text_emb: mx.array | None = None,
219
+ text_mask: mx.array | None = None,
220
+ latent_seq_len: mx.array | None = None,
221
+ text_seq_len: mx.array | None = None,
222
+ kv_cache: tuple[mx.array, mx.array] | None = None,
223
+ **_,
224
+ ) -> mx.array:
225
+ # x: (B, T_lat, DIM); text_emb: (B, T_text, TEXT_DIM) — unused when kv_cache supplied.
226
+ residual = x * mask if mask is not None else x
227
+ h = self.attn(
228
+ residual, text_emb, text_mask=text_mask,
229
+ latent_seq_len=latent_seq_len, text_seq_len=text_seq_len,
230
+ kv_cache=kv_cache,
231
+ )
232
+ if mask is not None:
233
+ h = h * mask
234
+ out = self.norm(residual + h)
235
+ if mask is not None:
236
+ out = out * mask
237
+ return out
238
+
239
+
240
+ class _AttnInner(nn.Module):
241
+ """Multi-head cross-attention with RoPE applied to query and key.
242
+
243
+ Holds parameters under ``W_query``, ``W_key``, ``W_value``, ``out_fc`` —
244
+ each is a :class:`WrappedLinear` so its weight is keyed
245
+ ``…W_query.linear.weight`` to match the ONNX checkpoint.
246
+
247
+ ``theta`` and ``increments`` come from the ONNX graph as frozen tensors
248
+ (precomputed RoPE table). We rebuild the equivalent table from the
249
+ Supertonic-3 config so the module is self-contained.
250
+ """
251
+
252
+ def __init__(
253
+ self,
254
+ in_dim: int,
255
+ ctx_dim: int,
256
+ num_heads: int,
257
+ head_dim: int,
258
+ ) -> None:
259
+ super().__init__()
260
+ self.num_heads = num_heads
261
+ self.head_dim = head_dim
262
+ # ONNX divides attention logits by 16.0 (= sqrt(TEXT_DIM)), not sqrt(head_dim).
263
+ self.scale = ctx_dim ** -0.5
264
+
265
+ kv_dim = num_heads * head_dim # = DIM = 512
266
+ self.W_query = WrappedLinear(in_dim, kv_dim, bias=True)
267
+ self.W_key = WrappedLinear(ctx_dim, kv_dim, bias=True)
268
+ self.W_value = WrappedLinear(ctx_dim, kv_dim, bias=True)
269
+ self.out_fc = WrappedLinear(kv_dim, in_dim, bias=True)
270
+
271
+ # Frozen RoPE tables — overwritten by checkpoint at load time.
272
+ # ONNX layout:
273
+ # ``increments`` (1, 1000, 1) holds positions 0..999 (no scale baked in)
274
+ # ``theta`` (1, 1, half) holds rotary_scale × base^(-i/half)
275
+ # Angle formula: ``angle = (pos / actual_seq_len) × theta``.
276
+ # The division by the actual seq length is critical — it normalises
277
+ # absolute positions into [0, 1] so audio and text are RoPE-aligned
278
+ # regardless of their respective lengths.
279
+ max_len = 1000
280
+ half = head_dim // 2
281
+ idx = mx.arange(half, dtype=mx.float32)
282
+ self.theta = (ROTARY_SCALE * mx.exp(-math.log(ROTARY_BASE) * idx / half))[None, None, :]
283
+ positions = mx.arange(max_len, dtype=mx.int64)
284
+ self.increments = positions[None, :, None] # (1, max_len, 1)
285
+
286
+ def _rope(self, x: mx.array, seq_len: mx.array | int | None = None) -> mx.array:
287
+ """Apply RoPE rotation. ``seq_len`` is the effective (unmasked) length.
288
+
289
+ Args:
290
+ x: (B, H, T, head_dim)
291
+ seq_len: scalar or (B,) — actual sequence length for position normalisation.
292
+ If None, defaults to T (no normalisation).
293
+ """
294
+ T = x.shape[-2]
295
+ positions = self.increments[:, :T, :] # (1, T, 1)
296
+ if seq_len is None:
297
+ seq_len = float(T)
298
+ if isinstance(seq_len, (int, float)):
299
+ divisor = float(seq_len)
300
+ else:
301
+ divisor = seq_len.astype(mx.float32).reshape(-1, 1, 1)
302
+ norm_pos = positions / divisor # broadcasts to (B, T, 1) if divisor is (B,1,1)
303
+ angles = norm_pos * self.theta # (B, T, half) or (1, T, half)
304
+ cos = mx.cos(angles)
305
+ sin = mx.sin(angles)
306
+ half = self.head_dim // 2
307
+ # Broadcast (?, T, half) → (?, 1, T, half) for head dim
308
+ cos_b = cos[..., None, :, :] if cos.ndim == 3 else cos[None, None, :, :]
309
+ sin_b = sin[..., None, :, :] if sin.ndim == 3 else sin[None, None, :, :]
310
+ # Make sure broadcasts properly
311
+ if cos_b.shape[0] == 1 and x.shape[0] > 1:
312
+ cos_b = mx.broadcast_to(cos_b, (x.shape[0], 1, T, half))
313
+ sin_b = mx.broadcast_to(sin_b, (x.shape[0], 1, T, half))
314
+ # Reshape if needed
315
+ cos_b = cos_b.reshape(-1, 1, T, half)
316
+ sin_b = sin_b.reshape(-1, 1, T, half)
317
+ x_first, x_second = x[..., :half], x[..., half:]
318
+ rot_first = x_first * cos_b - x_second * sin_b
319
+ rot_second = x_first * sin_b + x_second * cos_b
320
+ return mx.concatenate([rot_first, rot_second], axis=-1)
321
+
322
+ def project_kv(
323
+ self,
324
+ text_emb: mx.array,
325
+ text_seq_len: mx.array | None = None,
326
+ ) -> tuple[mx.array, mx.array]:
327
+ """Project text_emb → (K_rope, V) once. Both are constant across the
328
+ Euler steps in a TTS inference call (T.5.3 cache target)."""
329
+ B, T_text, _ = text_emb.shape
330
+ H, D = self.num_heads, self.head_dim
331
+ k = self.W_key(text_emb).reshape(B, T_text, H, D).transpose(0, 2, 1, 3)
332
+ v = self.W_value(text_emb).reshape(B, T_text, H, D).transpose(0, 2, 1, 3)
333
+ k = self._rope(k, seq_len=text_seq_len if text_seq_len is not None else T_text)
334
+ return k, v
335
+
336
+ def __call__(
337
+ self,
338
+ x: mx.array,
339
+ text_emb: mx.array | None = None,
340
+ text_mask: mx.array | None = None,
341
+ latent_seq_len: mx.array | None = None,
342
+ text_seq_len: mx.array | None = None,
343
+ kv_cache: tuple[mx.array, mx.array] | None = None,
344
+ ) -> mx.array:
345
+ B, T_lat, _ = x.shape
346
+ H, D = self.num_heads, self.head_dim
347
+
348
+ q = self.W_query(x).reshape(B, T_lat, H, D).transpose(0, 2, 1, 3) # (B, H, T_lat, D)
349
+ if kv_cache is not None:
350
+ k, v = kv_cache
351
+ else:
352
+ k, v = self.project_kv(text_emb, text_seq_len=text_seq_len)
353
+
354
+ # RoPE normalises positions by the effective (unmasked) sequence length.
355
+ q = self._rope(q, seq_len=latent_seq_len if latent_seq_len is not None else T_lat)
356
+
357
+ # Attention
358
+ logits = (q @ k.transpose(0, 1, 3, 2)) * self.scale # (B, H, T_lat, T_text)
359
+ if text_mask is not None:
360
+ neg_inf = mx.array(-1e4, dtype=logits.dtype)
361
+ logits = mx.where(text_mask[:, :, None, :].astype(mx.bool_), logits, neg_inf)
362
+ attn = mx.softmax(logits, axis=-1)
363
+ out = attn @ v # (B, H, T_lat, D)
364
+ out = out.transpose(0, 2, 1, 3).reshape(B, T_lat, H * D)
365
+ return self.out_fc(out)
366
+
367
+
368
+ class StyleCrossAttnBlock(nn.Module):
369
+ """Cycle position 5 — style cross-attention to 50 learned style tokens.
370
+
371
+ Loaded keys:
372
+ ``attention.W_query.linear.{weight,bias}``
373
+ ``attention.W_key.linear.{weight,bias}``
374
+ ``attention.W_value.linear.{weight,bias}``
375
+ ``attention.out_fc.linear.{weight,bias}``
376
+ ``norm.norm.{weight,bias}``
377
+ """
378
+
379
+ def __init__(self) -> None:
380
+ super().__init__()
381
+ self.attention = _StyleAttnInner(DIM, STYLE_DIM, STYLE_HEADS, STYLE_HEAD_DIM)
382
+ self.norm = WrappedNorm(DIM, eps=EPS_LN)
383
+
384
+ def __call__(
385
+ self,
386
+ x: mx.array,
387
+ mask: mx.array | None,
388
+ *,
389
+ style_k: mx.array | None = None,
390
+ style_v: mx.array | None = None,
391
+ kv_cache: tuple[mx.array, mx.array] | None = None,
392
+ **_,
393
+ ) -> mx.array:
394
+ # style_v defaults to style_k (same tensor for cond path); CFG path supplies
395
+ # different style_v to model the uncond branch.
396
+ if style_v is None and style_k is not None:
397
+ style_v = style_k
398
+ residual = x * mask if mask is not None else x
399
+ h = self.attention(residual, style_k, style_v, kv_cache=kv_cache)
400
+ if mask is not None:
401
+ h = h * mask
402
+ out = self.norm(residual + h)
403
+ if mask is not None:
404
+ out = out * mask
405
+ return out
406
+
407
+
408
+ class _StyleAttnInner(nn.Module):
409
+ def __init__(self, in_dim: int, ctx_dim: int, num_heads: int, head_dim: int) -> None:
410
+ super().__init__()
411
+ self.num_heads = num_heads
412
+ self.head_dim = head_dim
413
+ # ONNX divides attention logits by 16.0 (= sqrt(STYLE_DIM)), not sqrt(head_dim).
414
+ self.scale = ctx_dim ** -0.5
415
+ kv_dim = num_heads * head_dim # 2 * 128 = 256
416
+ # Q is on DIM (audio), K/V on ctx_dim (style 256)
417
+ self.W_query = WrappedLinear(in_dim, kv_dim, bias=True)
418
+ self.W_key = WrappedLinear(ctx_dim, kv_dim, bias=True)
419
+ self.W_value = WrappedLinear(ctx_dim, kv_dim, bias=True)
420
+ self.out_fc = WrappedLinear(kv_dim, in_dim, bias=True)
421
+
422
+ def project_kv(
423
+ self, style_k: mx.array, style_v: mx.array
424
+ ) -> tuple[mx.array, mx.array]:
425
+ """Project (style_k, style_v) → (K, V) once. T.5.3 cache target."""
426
+ B, T_style = style_k.shape[0], style_k.shape[1]
427
+ H, D = self.num_heads, self.head_dim
428
+ # Note: ONNX graph applies tanh to the K projection (``attention/tanh/Tanh``
429
+ # node) — the style key bank is bounded into [-1, 1] before softmax dot
430
+ # product, which acts as a soft attention temperature regulariser.
431
+ k = mx.tanh(self.W_key(style_k)).reshape(B, T_style, H, D).transpose(0, 2, 1, 3)
432
+ v = self.W_value(style_v).reshape(B, style_v.shape[1], H, D).transpose(0, 2, 1, 3)
433
+ return k, v
434
+
435
+ def __call__(
436
+ self,
437
+ x: mx.array,
438
+ style_k: mx.array | None = None,
439
+ style_v: mx.array | None = None,
440
+ kv_cache: tuple[mx.array, mx.array] | None = None,
441
+ ) -> mx.array:
442
+ # style_k and style_v can be the same tensor (cond) or distinct (uncond
443
+ # branch in CFG, where K comes from style_key_special_token and V from
444
+ # style_value_special_token).
445
+ B, T_lat, _ = x.shape
446
+ H, D = self.num_heads, self.head_dim
447
+ q = self.W_query(x).reshape(B, T_lat, H, D).transpose(0, 2, 1, 3)
448
+ if kv_cache is not None:
449
+ k, v = kv_cache
450
+ else:
451
+ k, v = self.project_kv(style_k, style_v)
452
+ logits = (q @ k.transpose(0, 1, 3, 2)) * self.scale
453
+ attn = mx.softmax(logits, axis=-1)
454
+ out = attn @ v
455
+ out = out.transpose(0, 2, 1, 3).reshape(B, T_lat, H * D)
456
+ return self.out_fc(out)
457
+
458
+
459
+ # ──────────────────────────────────────────────────────────────────
460
+ # Time encoder
461
+ # ──────────────────────────────────────────────────────────────────
462
+
463
+
464
+ class _MlpItem(nn.Module):
465
+ """A single MLP layer wrapped to produce keys ``mlp.N.linear.{weight,bias}``."""
466
+
467
+ def __init__(self, in_dim: int, out_dim: int) -> None:
468
+ super().__init__()
469
+ self.linear = nn.Linear(in_dim, out_dim, bias=True)
470
+
471
+ def __call__(self, x: mx.array) -> mx.array:
472
+ return self.linear(x)
473
+
474
+
475
+ class TimeEncoder(nn.Module):
476
+ """Sinusoidal time embedding + 2-layer MLP. Keys: ``mlp.0.linear``, ``mlp.2.linear``."""
477
+
478
+ def __init__(self) -> None:
479
+ super().__init__()
480
+ # ONNX: mlp.0.linear (64→256), mlp.2.linear (256→64). Index 1 is activation.
481
+ self.mlp = [
482
+ _MlpItem(TIME_EMB_DIM, TIME_MLP_HIDDEN), # mlp.0
483
+ nn.Identity(), # mlp.1 (activation; no weights)
484
+ _MlpItem(TIME_MLP_HIDDEN, TIME_EMB_DIM), # mlp.2
485
+ ]
486
+
487
+ def __call__(self, t: mx.array) -> mx.array:
488
+ # t: (B,) — produce sinusoidal embedding then run through MLP.
489
+ # Activation is Mish (not SiLU) to match the ONNX graph
490
+ # (Softplus → Tanh → Mul pattern == x * tanh(softplus(x))).
491
+ emb = self._sinusoidal(t, TIME_EMB_DIM)
492
+ h = self.mlp[0](emb)
493
+ h = _mish(h)
494
+ h = self.mlp[2](h)
495
+ return h
496
+
497
+ @staticmethod
498
+ def _sinusoidal(t: mx.array, dim: int) -> mx.array:
499
+ """Time embedding matching ``Supertonic-3`` ONNX exactly.
500
+
501
+ ONNX path: pos = t * 1000; freqs[i] = 10000^(-i/(half-1));
502
+ concat[sin(pos*freqs), cos(pos*freqs)].
503
+ """
504
+ half = dim // 2
505
+ denom = max(half - 1, 1)
506
+ freqs = mx.exp(-math.log(10_000) * mx.arange(half, dtype=mx.float32) / denom)
507
+ pos = t.astype(mx.float32)[:, None] * 1000.0
508
+ angles = pos * freqs[None, :]
509
+ return mx.concatenate([mx.sin(angles), mx.cos(angles)], axis=-1).astype(mx.float32)
510
+
511
+
512
+ # ──────────────────────────────────────────────────────────────────
513
+ # Top-level VectorEstimator
514
+ # ──────────────────────────────────────────────────────────────────
515
+
516
+
517
+ def _build_main_block(idx: int) -> nn.Module:
518
+ """Instantiate the appropriate block class for cycle position ``idx % 6``."""
519
+ pos = idx % BLOCKS_PER_CYCLE
520
+ name = BLOCK_CYCLE[pos]
521
+ if name == "stack4":
522
+ return Stack4Block()
523
+ if name == "time":
524
+ return TimeFiLMBlock()
525
+ if name == "cn1":
526
+ return ConvNeXt1Block()
527
+ if name == "text_attn":
528
+ return TextCrossAttnBlock()
529
+ if name == "style_attn":
530
+ return StyleCrossAttnBlock()
531
+ raise RuntimeError(f"unknown block type for index {idx}: {name}")
532
+
533
+
534
+ class _VectorField(nn.Module):
535
+ """Inner module mirroring ONNX ``vector_estimator.tts.ttl.vector_field.*``."""
536
+
537
+ def __init__(self) -> None:
538
+ super().__init__()
539
+ self.proj_in = ProjConv1x1(LATENT_CH, DIM, bias=False)
540
+ self.main_blocks = [_build_main_block(i) for i in range(NUM_MAIN_BLOCKS)]
541
+ self.last_convnext = ConvNeXtStack(dilations=(1, 1, 1, 1), dim=DIM, hidden=CONVNEXT_HIDDEN)
542
+ self.proj_out = ProjConv1x1(DIM, LATENT_CH, bias=False)
543
+ self.time_encoder = TimeEncoder()
544
+
545
+
546
+ class _UncondMasker(nn.Module):
547
+ """Holds the three unconditional-token tensors used by CFG.
548
+
549
+ Keys:
550
+ ``text_special_token`` (1, 256, 1)
551
+ ``style_key_special_token`` (1, 50, 256)
552
+ ``style_value_special_token`` (1, 50, 256)
553
+ """
554
+
555
+ def __init__(self) -> None:
556
+ super().__init__()
557
+ # Initialised to zero; checkpoint provides real values.
558
+ self.text_special_token = mx.zeros((1, TEXT_DIM, 1))
559
+ self.style_key_special_token = mx.zeros((1, STYLE_LEN, STYLE_DIM))
560
+ self.style_value_special_token = mx.zeros((1, STYLE_LEN, STYLE_DIM))
561
+
562
+
563
+ class VectorEstimator(nn.Module):
564
+ """Top-level module — matches ONNX root names ``vector_field.*`` and ``uncond_masker.*``.
565
+
566
+ Two inference paths:
567
+ - :meth:`velocity`: single forward pass; predicts the velocity from one set
568
+ of conditioning inputs. ``style_k``/``style_v`` may be the same tensor
569
+ (cond path) or different (uncond path of CFG).
570
+ - :meth:`__call__`: full ONNX-parity forward — applies CFG batch doubling
571
+ (cond + uncond) internally and combines via
572
+ ``final = noisy + (4*cond - 3*uncond) / total_step``.
573
+ """
574
+
575
+ # CFG guidance constants — baked into the ONNX graph as ``/Constant_3`` (=4.0)
576
+ # and ``/Constant_4`` (=3.0). Equivalent to guidance_scale = 4 with the
577
+ # standard formula ``v = uncond + g*(cond - uncond) = 4*cond - 3*uncond``.
578
+ CFG_COND_SCALE: float = 4.0
579
+ CFG_UNCOND_SCALE: float = 3.0
580
+
581
+ def __init__(self) -> None:
582
+ super().__init__()
583
+ self.vector_field = _VectorField()
584
+ self.uncond_masker = _UncondMasker()
585
+
586
+ # ── inference API ─────────────────────────────────────────────
587
+ def velocity(
588
+ self,
589
+ noisy_latent: mx.array, # (B, 144, T_lat)
590
+ text_emb: mx.array, # (B, 256, T_text)
591
+ style_k: mx.array, # (B, 50, 256) — K side of style attention
592
+ style_v: mx.array, # (B, 50, 256) — V side of style attention
593
+ latent_mask: mx.array, # (B, 1, T_lat)
594
+ text_mask: mx.array, # (B, 1, T_text)
595
+ t_norm: mx.array, # (B,) timestep in [0, 1]
596
+ ) -> mx.array:
597
+ """Predict velocity (B, 144, T_lat) without applying CFG or Euler step."""
598
+ x = noisy_latent.transpose(0, 2, 1) # (B, T_lat, 144)
599
+ text = text_emb.transpose(0, 2, 1) # (B, T_text, 256)
600
+ lat_mask_ntc = latent_mask.transpose(0, 2, 1) # (B, T_lat, 1)
601
+
602
+ x = self.vector_field.proj_in(x) # (B, T_lat, 512)
603
+ t_emb = self.vector_field.time_encoder(t_norm) # (B, TIME_EMB_DIM)
604
+
605
+ # Effective (unmasked) sequence lengths for RoPE normalisation —
606
+ # ONNX uses ``ReduceSum(mask)`` for this so that audio and text are
607
+ # rope-aligned regardless of padding.
608
+ latent_seq_len = mx.sum(latent_mask, axis=(1, 2)) # (B,)
609
+ text_seq_len = mx.sum(text_mask, axis=(1, 2)) # (B,)
610
+
611
+ for blk in self.vector_field.main_blocks:
612
+ x = blk(
613
+ x,
614
+ lat_mask_ntc,
615
+ t_emb=t_emb,
616
+ text_emb=text,
617
+ text_mask=text_mask,
618
+ style_k=style_k,
619
+ style_v=style_v,
620
+ latent_seq_len=latent_seq_len,
621
+ text_seq_len=text_seq_len,
622
+ )
623
+
624
+ x = self.vector_field.last_convnext(x, lat_mask_ntc)
625
+ v_ntc = self.vector_field.proj_out(x) # (B, T_lat, 144)
626
+ return v_ntc.transpose(0, 2, 1) # (B, 144, T_lat)
627
+
628
+ # ── T.5.3 — pre-projected K/V path ────────────────────────────
629
+ def precompute_cross_kv(
630
+ self,
631
+ text_emb: mx.array, # (B, 256, T_text) channels-first
632
+ style_k: mx.array, # (B, 50, 256)
633
+ style_v: mx.array, # (B, 50, 256)
634
+ text_mask: mx.array, # (B, 1, T_text)
635
+ ) -> tuple[list[tuple[mx.array, mx.array]], list[tuple[mx.array, mx.array]]]:
636
+ """Project K/V for every text_attn and style_attn block exactly once.
637
+
638
+ Returns ``(text_kv_list, style_kv_list)`` — both ordered to align with
639
+ the corresponding blocks encountered when iterating ``main_blocks``.
640
+ These tensors are invariant across the 5 Euler steps of one TTS
641
+ call; pre-projecting them once and feeding the result into
642
+ :meth:`velocity_cached` cuts ~ 4 × 2 × 5 = 40 redundant matmuls.
643
+ """
644
+ text_seq_len = mx.sum(text_mask, axis=(1, 2))
645
+ text_ntc = text_emb.transpose(0, 2, 1) # (B, T_text, 256)
646
+
647
+ text_kv: list[tuple[mx.array, mx.array]] = []
648
+ style_kv: list[tuple[mx.array, mx.array]] = []
649
+ for blk in self.vector_field.main_blocks:
650
+ if isinstance(blk, TextCrossAttnBlock):
651
+ text_kv.append(blk.attn.project_kv(text_ntc, text_seq_len=text_seq_len))
652
+ elif isinstance(blk, StyleCrossAttnBlock):
653
+ style_kv.append(blk.attention.project_kv(style_k, style_v))
654
+ return text_kv, style_kv
655
+
656
+ def velocity_cached(
657
+ self,
658
+ noisy_latent: mx.array,
659
+ latent_mask: mx.array,
660
+ text_mask: mx.array,
661
+ t_norm: mx.array,
662
+ text_kv: list[tuple[mx.array, mx.array]],
663
+ style_kv: list[tuple[mx.array, mx.array]],
664
+ ) -> mx.array:
665
+ """Same as :meth:`velocity` but reads K/V from pre-projected caches.
666
+
667
+ ``text_kv`` and ``style_kv`` must come from :meth:`precompute_cross_kv`
668
+ applied to the same (batched) conditioning tensors that will be
669
+ active for this call.
670
+ """
671
+ x = noisy_latent.transpose(0, 2, 1)
672
+ lat_mask_ntc = latent_mask.transpose(0, 2, 1)
673
+
674
+ x = self.vector_field.proj_in(x)
675
+ t_emb = self.vector_field.time_encoder(t_norm)
676
+ latent_seq_len = mx.sum(latent_mask, axis=(1, 2))
677
+
678
+ ti = 0
679
+ si = 0
680
+ for blk in self.vector_field.main_blocks:
681
+ if isinstance(blk, TextCrossAttnBlock):
682
+ x = blk(
683
+ x, lat_mask_ntc,
684
+ text_mask=text_mask,
685
+ latent_seq_len=latent_seq_len,
686
+ kv_cache=text_kv[ti],
687
+ )
688
+ ti += 1
689
+ elif isinstance(blk, StyleCrossAttnBlock):
690
+ x = blk(x, lat_mask_ntc, kv_cache=style_kv[si])
691
+ si += 1
692
+ else:
693
+ x = blk(x, lat_mask_ntc, t_emb=t_emb)
694
+
695
+ x = self.vector_field.last_convnext(x, lat_mask_ntc)
696
+ v_ntc = self.vector_field.proj_out(x)
697
+ return v_ntc.transpose(0, 2, 1)
698
+
699
+ def __call__(
700
+ self,
701
+ noisy_latent: mx.array, # (B, 144, T_lat) channels-first per ONNX I/O
702
+ text_emb: mx.array, # (B, 256, T_text) channels-first
703
+ style_ttl: mx.array, # (B, 50, 256) — used as both K and V for cond
704
+ latent_mask: mx.array, # (B, 1, T_lat)
705
+ text_mask: mx.array, # (B, 1, T_text)
706
+ current_step: mx.array, # (B,)
707
+ total_step: mx.array, # (B,)
708
+ cfg: bool = True,
709
+ ) -> mx.array:
710
+ """Run one Euler step with CFG (matches ONNX semantics).
711
+
712
+ With ``cfg=True`` (default) the model runs both conditional and
713
+ unconditional paths in a single batched forward and combines via
714
+ ``final = noisy + (4*cond_v - 3*uncond_v) / total_step``.
715
+
716
+ With ``cfg=False`` only the conditional path runs — half the work, but
717
+ produces a different (lower-quality) output. Useful for speed bench /
718
+ sanity tests.
719
+ """
720
+ B = noisy_latent.shape[0]
721
+ t_norm = current_step.astype(mx.float32) / total_step.astype(mx.float32)
722
+
723
+ if not cfg:
724
+ v = self.velocity(
725
+ noisy_latent, text_emb, style_ttl, style_ttl,
726
+ latent_mask, text_mask, t_norm,
727
+ )
728
+ return noisy_latent + v / total_step.reshape(-1, 1, 1).astype(noisy_latent.dtype)
729
+
730
+ # CFG branch — build (2B, ...) inputs by concatenating cond + uncond.
731
+ # uncond text_emb = text_special_token broadcast to (B, 256, T_text).
732
+ # uncond style_k = style_key_special_token broadcast, similarly style_v.
733
+ text_uncond = mx.broadcast_to(
734
+ self.uncond_masker.text_special_token, (B, TEXT_DIM, text_emb.shape[2])
735
+ )
736
+ style_k_uncond = mx.broadcast_to(
737
+ self.uncond_masker.style_key_special_token, (B, STYLE_LEN, STYLE_DIM)
738
+ )
739
+ style_v_uncond = mx.broadcast_to(
740
+ self.uncond_masker.style_value_special_token, (B, STYLE_LEN, STYLE_DIM)
741
+ )
742
+
743
+ noisy_2 = mx.concatenate([noisy_latent, noisy_latent], axis=0)
744
+ text_2 = mx.concatenate([text_emb, text_uncond], axis=0)
745
+ style_k_2 = mx.concatenate([style_ttl, style_k_uncond], axis=0)
746
+ style_v_2 = mx.concatenate([style_ttl, style_v_uncond], axis=0)
747
+ lm_2 = mx.concatenate([latent_mask, latent_mask], axis=0)
748
+ tm_2 = mx.concatenate([text_mask, text_mask], axis=0)
749
+ t_norm_2 = mx.concatenate([t_norm, t_norm], axis=0)
750
+
751
+ v_2 = self.velocity(
752
+ noisy_2, text_2, style_k_2, style_v_2, lm_2, tm_2, t_norm_2,
753
+ ) # (2B, 144, T_lat)
754
+ cond_v = v_2[:B]
755
+ uncond_v = v_2[B:2 * B]
756
+ combined_v = self.CFG_COND_SCALE * cond_v - self.CFG_UNCOND_SCALE * uncond_v
757
+ return noisy_latent + combined_v / total_step.reshape(-1, 1, 1).astype(noisy_latent.dtype)
758
+
759
+
760
+ __all__ = [
761
+ "ConvNeXtBlock", "ConvNeXtStack",
762
+ "Stack4Block", "TimeFiLMBlock", "ConvNeXt1Block",
763
+ "TextCrossAttnBlock", "StyleCrossAttnBlock",
764
+ "TimeEncoder", "VectorEstimator",
765
+ ]
src/supertonic_3_mlx/vocoder.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Supertonic 3 vocoder — latent → 44.1 kHz waveform, MLX port.
2
+
3
+ Pipeline (operating in channels-last NTC layout, then converted to channels-first
4
+ for output reshape):
5
+
6
+ latent [B, 144, T_lat] (output of vector_estimator)
7
+ → /= normalizer.scale (scalar)
8
+ → reshape [B, 24, T_lat*6] # de-compress
9
+ → (* latent_std + latent_mean) # de-normalise
10
+ → transpose to NTC [B, T_lat*6, 24]
11
+ → embed Conv1d(24→512, k=7, sym-edge pad) [B, T_lat*6, 512]
12
+ → 10× ConvNeXt(dim=512, hidden=2048, k=7,
13
+ dilations [1,2,4,1,2,4,1,1,1,1])
14
+ → final_norm: BatchNorm1d (eval-time: running stats only)
15
+ → head.layer1: Conv1d(512→2048, k=3, sym-edge pad)
16
+ → PReLU (with per-channel learnable slope)
17
+ → head.layer2: Conv1d(2048→512, k=1, no bias)
18
+ → transpose to (B, 512, T_lat*6) → flatten → wav (B, T_lat*6*512)
19
+
20
+ The 512 samples/step × 6 chunk × 44.1 kHz → T_lat steps of about 0.0697 s each.
21
+ """
22
+ from __future__ import annotations
23
+
24
+ import mlx.core as mx
25
+ import mlx.nn as nn
26
+
27
+ from supertonic_3_mlx._config import EPS_LN
28
+ from supertonic_3_mlx._nn_wrappers import WrappedNorm
29
+ from supertonic_3_mlx.vector_estimator import _gelu_exact
30
+
31
+
32
+ def _pad_left_edge(x: mx.array, pad: int) -> mx.array:
33
+ """Causal replicate-edge pad on the time axis (axis=1 for [B, T, C]).
34
+
35
+ Pads ``pad`` time-steps on the LEFT only by replicating the first frame.
36
+ Matches the ONNX vocoder pads spec ``[0, 0, pad, 0, 0, 0]``.
37
+ """
38
+ if pad == 0:
39
+ return x
40
+ left = mx.broadcast_to(x[:, :1, :], (x.shape[0], pad, x.shape[2]))
41
+ return mx.concatenate([left, x], axis=1)
42
+
43
+
44
+ VOC_DIM = 512
45
+ VOC_HIDDEN = 2048
46
+ VOC_K = 7
47
+ VOC_HEAD_K = 3
48
+ VOC_LDIM = 24 # de-compressed channels (24 × 6 = 144 input)
49
+ VOC_CHUNK_COMPRESS = 6
50
+ VOC_NUM_CONVNEXT_LAYERS = 10
51
+ VOC_DILATIONS = (1, 2, 4, 1, 2, 4, 1, 1, 1, 1)
52
+ EPS_BN = 1e-5
53
+
54
+
55
+ class _Conv1dNet(nn.Module):
56
+ """Conv1d wrapped under ``.net`` to match ONNX storage ``.net.weight/bias``."""
57
+
58
+ def __init__(self, in_dim: int, out_dim: int, kernel: int, dilation: int = 1,
59
+ groups: int = 1, bias: bool = True) -> None:
60
+ super().__init__()
61
+ class _Net(nn.Module):
62
+ def __init__(_):
63
+ super().__init__()
64
+ # MLX Conv1d weight: (out, K, in/groups)
65
+ _.weight = mx.zeros((out_dim, kernel, in_dim // groups))
66
+ if bias:
67
+ _.bias = mx.zeros((out_dim,))
68
+ else:
69
+ _.bias = None
70
+ def __call__(_, x, dilation=1):
71
+ y = mx.conv1d(x, _.weight, stride=1, padding=0, dilation=dilation,
72
+ groups=groups)
73
+ if _.bias is not None:
74
+ y = y + _.bias
75
+ return y
76
+ self.net = _Net()
77
+ self.dilation = dilation
78
+ self.groups = groups
79
+ self.kernel = kernel
80
+
81
+ def __call__(self, x: mx.array) -> mx.array:
82
+ return self.net(x, dilation=self.dilation)
83
+
84
+
85
+ class _VocConvNeXtBlock(nn.Module):
86
+ """ConvNeXt block matching keys ``convnext.N.{dwconv.net,norm.norm,pwconv1,pwconv2,gamma}``."""
87
+
88
+ def __init__(self, dilation: int) -> None:
89
+ super().__init__()
90
+ self.dilation = dilation
91
+ self.pad = dilation * (VOC_K - 1)
92
+ self.dwconv = _Conv1dNet(VOC_DIM, VOC_DIM, kernel=VOC_K, dilation=dilation,
93
+ groups=VOC_DIM, bias=True)
94
+ self.norm = WrappedNorm(VOC_DIM, eps=EPS_LN)
95
+ # pwconv1 / pwconv2 stored as Conv1d k=1 → loaded after squeeze to Linear.
96
+ self.pwconv1 = nn.Linear(VOC_DIM, VOC_HIDDEN, bias=True)
97
+ self.pwconv2 = nn.Linear(VOC_HIDDEN, VOC_DIM, bias=True)
98
+ self.gamma = mx.zeros((VOC_DIM,))
99
+
100
+ def __call__(self, x: mx.array) -> mx.array:
101
+ residual = x
102
+ y = _pad_left_edge(x, self.pad)
103
+ y = self.dwconv(y)
104
+ y = self.norm(y)
105
+ y = self.pwconv1(y)
106
+ y = _gelu_exact(y)
107
+ y = self.pwconv2(y)
108
+ y = y * self.gamma
109
+ return residual + y
110
+
111
+
112
+ class _BatchNorm1dEval(nn.Module):
113
+ """Eval-mode BatchNorm1d: applies stored running_mean/running_var only.
114
+
115
+ Loaded keys: ``norm.{weight,bias,running_mean,running_var}``.
116
+ """
117
+
118
+ def __init__(self) -> None:
119
+ super().__init__()
120
+ class _Norm(nn.Module):
121
+ def __init__(_):
122
+ super().__init__()
123
+ _.weight = mx.ones((VOC_DIM,))
124
+ _.bias = mx.zeros((VOC_DIM,))
125
+ _.running_mean = mx.zeros((VOC_DIM,))
126
+ _.running_var = mx.ones((VOC_DIM,))
127
+ def __call__(_, x):
128
+ # x: (B, T, C). BN1d normalises across batch+time per channel.
129
+ # Eval mode: use stored running stats.
130
+ norm = (x - _.running_mean) * mx.rsqrt(_.running_var + EPS_BN)
131
+ return norm * _.weight + _.bias
132
+ self.norm = _Norm()
133
+
134
+ def __call__(self, x: mx.array) -> mx.array:
135
+ return self.norm(x)
136
+
137
+
138
+ class _VocHeadActivation(nn.Module):
139
+ """PReLU with per-channel learnable slope (weight shape (C,))."""
140
+
141
+ def __init__(self) -> None:
142
+ super().__init__()
143
+ # ONNX anonymous PReLU stores slope of shape (1,) sometimes or (C,).
144
+ # We default to (1,) and reshape on load if needed.
145
+ self.weight = mx.zeros((1,))
146
+
147
+ def __call__(self, x: mx.array) -> mx.array:
148
+ # PReLU: max(0, x) + slope × min(0, x).
149
+ # slope broadcasts over (B, T, C) or (B, C, T) depending on layout.
150
+ zero = mx.array(0.0, dtype=x.dtype)
151
+ return mx.maximum(x, zero) + self.weight * mx.minimum(x, zero)
152
+
153
+
154
+ class _VocHead(nn.Module):
155
+ """``head.layer1`` (Conv1d 512→2048 k=3) + ``head.act`` (PReLU) + ``head.layer2`` (Conv1d k=1, no bias)."""
156
+
157
+ def __init__(self) -> None:
158
+ super().__init__()
159
+ self.layer1 = _Conv1dNet(VOC_DIM, VOC_HIDDEN, kernel=VOC_HEAD_K, bias=True)
160
+ self.act = _VocHeadActivation()
161
+ # layer2 has no .net wrapper in ONNX (different from layer1)
162
+ # ONNX: head.layer2.weight (512, 2048, 1) — Conv1d k=1, no bias.
163
+ # We represent it directly without .net wrap.
164
+ self.layer2 = _VocLayer2()
165
+
166
+ def __call__(self, x: mx.array) -> mx.array:
167
+ # x: (B, T, 512)
168
+ pad = VOC_HEAD_K - 1
169
+ y = _pad_left_edge(x, pad)
170
+ y = self.layer1(y) # (B, T, 2048)
171
+ y = self.act(y)
172
+ y = self.layer2(y) # (B, T, 512)
173
+ return y
174
+
175
+
176
+ class _VocLayer2(nn.Module):
177
+ """Conv1d k=1 (2048 → 512), no bias. Keys: ``layer2.weight (512, 2048, 1)``."""
178
+
179
+ def __init__(self) -> None:
180
+ super().__init__()
181
+ # MLX Conv1d weight shape: (out, K, in/groups) = (512, 1, 2048)
182
+ # ONNX storage: (out, in, 1) = (512, 2048, 1). Same size; reshape on load.
183
+ self.weight = mx.zeros((VOC_DIM, 1, VOC_HIDDEN))
184
+
185
+ def __call__(self, x: mx.array) -> mx.array:
186
+ return mx.conv1d(x, self.weight, stride=1, padding=0)
187
+
188
+
189
+ class _VocEmbed(nn.Module):
190
+ """Initial Conv1d(24→512, k=7) with sym-edge pad.
191
+
192
+ The weight + bias are anonymous in the ONNX graph (``onnx::Conv_1441`` and
193
+ ``onnx::Conv_1442``); the conversion recovers them via the Conv node path
194
+ ``/decoder/embed/net/Conv`` → structured name ``tts.ae.decoder.embed.net.{weight,bias}``.
195
+ """
196
+
197
+ def __init__(self) -> None:
198
+ super().__init__()
199
+ class _Net(nn.Module):
200
+ def __init__(_):
201
+ super().__init__()
202
+ _.weight = mx.zeros((VOC_DIM, VOC_K, VOC_LDIM))
203
+ _.bias = mx.zeros((VOC_DIM,))
204
+ def __call__(_, x):
205
+ return mx.conv1d(x, _.weight, stride=1, padding=0) + _.bias
206
+ self.net = _Net()
207
+
208
+ def __call__(self, x: mx.array) -> mx.array:
209
+ pad = VOC_K - 1
210
+ y = _pad_left_edge(x, pad)
211
+ return self.net(y)
212
+
213
+
214
+ class _VocDecoder(nn.Module):
215
+ """``tts.ae.decoder.X`` namespace."""
216
+
217
+ def __init__(self) -> None:
218
+ super().__init__()
219
+ self.embed = _VocEmbed()
220
+ self.convnext = [_VocConvNeXtBlock(d) for d in VOC_DILATIONS]
221
+ self.final_norm = _BatchNorm1dEval()
222
+ self.head = _VocHead()
223
+
224
+
225
+ class _AEContainer(nn.Module):
226
+ """``tts.ae.X`` — holds latent_mean, latent_std, decoder."""
227
+
228
+ def __init__(self) -> None:
229
+ super().__init__()
230
+ self.latent_mean = mx.zeros((1, VOC_LDIM, 1))
231
+ self.latent_std = mx.ones((1, VOC_LDIM, 1))
232
+ self.decoder = _VocDecoder()
233
+
234
+
235
+ class _TtlContainer(nn.Module):
236
+ """``tts.ttl.normalizer.scale`` (scalar) — divides the latent before de-norm."""
237
+
238
+ def __init__(self) -> None:
239
+ super().__init__()
240
+ class _Normalizer(nn.Module):
241
+ def __init__(_):
242
+ super().__init__()
243
+ _.scale = mx.array(1.0)
244
+ self.normalizer = _Normalizer()
245
+
246
+
247
+ class _TtsContainer(nn.Module):
248
+ def __init__(self) -> None:
249
+ super().__init__()
250
+ self.ttl = _TtlContainer()
251
+ self.ae = _AEContainer()
252
+
253
+
254
+ class Vocoder(nn.Module):
255
+ """Latent → waveform decoder (44.1 kHz mono).
256
+
257
+ Submodule namespace matches ONNX keys ``tts.X.Y`` exactly.
258
+ """
259
+
260
+ def __init__(self) -> None:
261
+ super().__init__()
262
+ self.tts = _TtsContainer()
263
+
264
+ def __call__(self, latent: mx.array) -> mx.array:
265
+ # latent: (B, 144, T_lat)
266
+ B = latent.shape[0]
267
+ T_lat = latent.shape[2]
268
+
269
+ # /= scale (scalar)
270
+ x = latent / self.tts.ttl.normalizer.scale
271
+
272
+ # reshape (B, 144, T_lat) → (B, 24, T_lat*6)
273
+ x = x.reshape(B, VOC_LDIM, VOC_CHUNK_COMPRESS, T_lat) # (B, 24, 6, T_lat)
274
+ x = x.transpose(0, 1, 3, 2) # (B, 24, T_lat, 6)
275
+ x = x.reshape(B, VOC_LDIM, T_lat * VOC_CHUNK_COMPRESS) # (B, 24, T_lat*6)
276
+
277
+ # De-normalise: (* std + mean)
278
+ x = x * self.tts.ae.latent_std + self.tts.ae.latent_mean
279
+
280
+ # Transpose to NTC for Conv1d layers
281
+ x = x.transpose(0, 2, 1) # (B, T_lat*6, 24)
282
+
283
+ # embed
284
+ x = self.tts.ae.decoder.embed(x) # (B, T_lat*6, 512)
285
+
286
+ # 10× ConvNeXt
287
+ for blk in self.tts.ae.decoder.convnext:
288
+ x = blk(x)
289
+
290
+ # final_norm (BatchNorm1d eval)
291
+ x = self.tts.ae.decoder.final_norm(x)
292
+
293
+ # head
294
+ x = self.tts.ae.decoder.head(x) # (B, T_lat*6, 512)
295
+
296
+ # Flatten time × channels row-major → waveform (matches ONNX:
297
+ # head.layer2 Conv (B, 512, T_lat*6) → Transpose to (B, T_lat*6, 512) →
298
+ # Reshape to (B, T_lat*6*512). Since the head already runs in NTC, we
299
+ # are already in the post-Transpose layout and only the Reshape remains).
300
+ wav = x.reshape(B, -1) # (B, T_lat*6*512)
301
+ return wav
302
+
303
+
304
+ __all__ = ["Vocoder", "VOC_DIM", "VOC_HIDDEN", "VOC_LDIM", "VOC_CHUNK_COMPRESS"]
src/supertonic_3_mlx/weights.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ONNX → MLX safetensors conversion for Supertonic 3.
2
+
3
+ Two-stage extraction:
4
+ 1. **Named initializers** (e.g. ``vector_estimator.tts.ttl.vector_field.main_blocks.0.convnext.0.dwconv.weight``)
5
+ — straight name strip + optional shape transformation.
6
+ 2. **Anonymous MatMul weights** (e.g. ``onnx::MatMul_3391``) — looked up via the
7
+ MatMul node graph: each MatMul output path is the human-readable name of the
8
+ weight (e.g. ``…/W_query/linear/MatMul_output_0``); we trace the second
9
+ operand initializer and rebind it to the structured name + transpose to
10
+ the MLX Linear layout ``(out, in)``.
11
+
12
+ Shape transformations:
13
+ - depthwise dwconv: ONNX ``(C, 1, K)`` → MLX ``(C, K, 1)``
14
+ - pwconv1/2 k=1: ONNX ``(out, in, 1)`` → MLX ``(out, in)``
15
+ - proj_in/out k=1: ONNX ``(out, in, 1)`` → MLX ``(out, in)``
16
+ - MatMul Linear: ONNX ``(in, out)`` → MLX ``(out, in)``
17
+ - gamma: ONNX ``(1, dim, 1)`` → MLX ``(dim,)``
18
+ """
19
+ from __future__ import annotations
20
+
21
+ from pathlib import Path
22
+ from typing import Dict, Tuple
23
+
24
+ import mlx.core as mx
25
+ import numpy as np
26
+
27
+
28
+ _ONNX_PREFIX = "vector_estimator.tts.ttl."
29
+
30
+ _DWCONV_SUFFIX = ".dwconv.weight"
31
+ _PWCONV_SUFFIXES = (".pwconv1.weight", ".pwconv2.weight")
32
+ _GAMMA_SUFFIX = ".gamma"
33
+
34
+
35
+ def _strip_prefix(name: str) -> str:
36
+ if name.startswith(_ONNX_PREFIX):
37
+ return name[len(_ONNX_PREFIX):]
38
+ return name
39
+
40
+
41
+ def _is_named_weight(name: str) -> bool:
42
+ """True if this is a structured weight (vs anonymous graph constant)."""
43
+ if name.startswith(_ONNX_PREFIX):
44
+ return True
45
+ if name.startswith("uncond_masker."):
46
+ return True
47
+ return False
48
+
49
+
50
+ def _convert_named(clean_name: str, arr: np.ndarray) -> np.ndarray:
51
+ """Apply shape transforms to a named initializer based on its key."""
52
+ # Depthwise Conv1d weight: (C, 1, K) → (C, K, 1)
53
+ if clean_name.endswith(_DWCONV_SUFFIX) and arr.ndim == 3 and arr.shape[1] == 1 and arr.shape[2] != 1:
54
+ arr = np.transpose(arr, (0, 2, 1))
55
+
56
+ # Pointwise k=1 / proj net weight: (out, in, 1) → (out, in)
57
+ if (any(clean_name.endswith(s) for s in _PWCONV_SUFFIXES) or clean_name.endswith(".net.weight")) \
58
+ and arr.ndim == 3 and arr.shape[-1] == 1:
59
+ arr = arr.squeeze(-1)
60
+
61
+ # gamma: (1, C, 1) → (C,)
62
+ if clean_name.endswith(_GAMMA_SUFFIX) and arr.ndim == 3 and arr.shape[0] == 1 and arr.shape[2] == 1:
63
+ arr = arr.reshape(arr.shape[1])
64
+
65
+ return arr
66
+
67
+
68
+ def _matmul_output_to_clean_name(matmul_output: str) -> str:
69
+ """Map a MatMul node output path to the structured ``.weight`` key.
70
+
71
+ Example::
72
+
73
+ /vector_estimator/vector_field/main_blocks.3/attn/W_query/linear/MatMul_output_0
74
+ → vector_field.main_blocks.3.attn.W_query.linear.weight
75
+ """
76
+ # Strip prefix slash and the trailing /MatMul_output_0
77
+ path = matmul_output.lstrip("/")
78
+ if path.endswith("/MatMul_output_0"):
79
+ path = path[: -len("/MatMul_output_0")]
80
+ # Drop leading "vector_estimator/" if present
81
+ if path.startswith("vector_estimator/"):
82
+ path = path[len("vector_estimator/"):]
83
+ return path.replace("/", ".") + ".weight"
84
+
85
+
86
+ def convert_onnx_to_mlx(onnx_path: str | Path) -> Dict[str, mx.array]:
87
+ """Load an ONNX model and return all weights as ``{clean_name: mx.array}``.
88
+
89
+ Combines named initializers and MatMul-only weights into a single dict ready
90
+ for ``model.load_weights(...)``.
91
+ """
92
+ import onnx
93
+ from onnx import numpy_helper
94
+
95
+ model = onnx.load(str(onnx_path))
96
+
97
+ # Build initializer name → numpy array map (in-memory once)
98
+ inits: Dict[str, np.ndarray] = {
99
+ init.name: numpy_helper.to_array(init) for init in model.graph.initializer
100
+ }
101
+
102
+ out: Dict[str, mx.array] = {}
103
+
104
+ # Stage 1: named initializers
105
+ for name, arr in inits.items():
106
+ if not _is_named_weight(name):
107
+ continue
108
+ clean = _strip_prefix(name)
109
+ arr = _convert_named(clean, arr)
110
+ out[clean] = mx.array(arr)
111
+
112
+ # Stage 2: anonymous MatMul weights, recovered via the graph
113
+ for node in model.graph.node:
114
+ if node.op_type != "MatMul":
115
+ continue
116
+ if len(node.input) < 2:
117
+ continue
118
+ # The weight is conventionally the second operand
119
+ weight_name = node.input[1]
120
+ if weight_name not in inits:
121
+ continue
122
+ # Skip if it's already named structurally (shouldn't happen here)
123
+ if _is_named_weight(weight_name):
124
+ continue
125
+ # Look up the structured name from the MatMul output path
126
+ if len(node.output) < 1:
127
+ continue
128
+ clean = _matmul_output_to_clean_name(node.output[0])
129
+ # ONNX MatMul stores W as (in, out); MLX Linear expects (out, in)
130
+ arr = inits[weight_name]
131
+ if arr.ndim == 2:
132
+ arr = arr.T
133
+ out[clean] = mx.array(arr)
134
+
135
+ if not out:
136
+ raise RuntimeError(f"no weights extracted from {onnx_path}")
137
+ return out
138
+
139
+
140
+ def save_safetensors(
141
+ onnx_path: str | Path,
142
+ output_path: str | Path,
143
+ ) -> Dict[str, Tuple[int, ...]]:
144
+ """Convert an ONNX file to MLX safetensors. Returns a {name: shape} map."""
145
+ weights = convert_onnx_to_mlx(onnx_path)
146
+ output_path = Path(output_path)
147
+ output_path.parent.mkdir(parents=True, exist_ok=True)
148
+ mx.save_safetensors(str(output_path), weights)
149
+ return {k: tuple(v.shape) for k, v in weights.items()}
150
+
151
+
152
+ __all__ = ["convert_onnx_to_mlx", "save_safetensors"]
unicode_indexer.json ADDED
The diff for this file is too large to render. See raw diff
 
voice_styles/F1.json ADDED
The diff for this file is too large to render. See raw diff
 
voice_styles/F2.json ADDED
The diff for this file is too large to render. See raw diff
 
voice_styles/F3.json ADDED
The diff for this file is too large to render. See raw diff
 
voice_styles/F4.json ADDED
The diff for this file is too large to render. See raw diff
 
voice_styles/F5.json ADDED
The diff for this file is too large to render. See raw diff
 
voice_styles/M1.json ADDED
The diff for this file is too large to render. See raw diff
 
voice_styles/M2.json ADDED
The diff for this file is too large to render. See raw diff
 
voice_styles/M3.json ADDED
The diff for this file is too large to render. See raw diff
 
voice_styles/M4.json ADDED
The diff for this file is too large to render. See raw diff
 
voice_styles/M5.json ADDED
The diff for this file is too large to render. See raw diff
 
weights/duration_predictor.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cd473acb6e0ac27426084488ccb3b3cc184e70d05db90897e2b892846db5dcb3
3
+ size 3470807
weights/text_encoder.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9df20bb79496718b36d2c0fc37636d3f78d6ef751b2899ff6dfeb975ae737ada
3
+ size 36022466
weights/vector_estimator.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2359240f2dcaee03b4800102aa0bea00223d2867ab752ef01af2b1cfaf92f3a6
3
+ size 256053073
weights/vocoder.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2ec31ab7c554f6e15b9a6780554b5d3502345de7848b310966bfb4e1ea4e526
3
+ size 101364763