Hak5 commited on
Commit
7496177
·
verified ·
1 Parent(s): 9c4c740

Add bundled AVoice runtime for HF-only inference

Browse files
runtime/LICENSE ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ GNU GENERAL PUBLIC LICENSE
2
+ Version 2, June 1991
3
+
4
+ Copyright (C) 1989, 1991 Free Software Foundation, Inc.,
5
+ 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
6
+ Everyone is permitted to copy and distribute verbatim copies
7
+ of this license document, but changing it is not allowed.
8
+
9
+ Preamble
10
+
11
+ The licenses for most software are designed to take away your
12
+ freedom to share and change it. By contrast, the GNU General Public
13
+ License is intended to guarantee your freedom to share and change free
14
+ software--to make sure the software is free for all its users. This
15
+ General Public License applies to most of the Free Software
16
+ Foundation's software and to any other program whose authors commit to
17
+ using it. (Some other Free Software Foundation software is covered by
18
+ the GNU Lesser General Public License instead.) You can apply it to
19
+ your programs, too.
20
+
21
+ When we speak of free software, we are referring to freedom, not
22
+ price. Our General Public Licenses are designed to make sure that you
23
+ have the freedom to distribute copies of free software (and charge for
24
+ this service if you wish), that you receive source code or can get it
25
+ if you want it, that you can change the software or use pieces of it
26
+ in new free programs; and that you know you can do these things.
27
+
28
+ To protect your rights, we need to make restrictions that forbid
29
+ anyone to deny you these rights or to ask you to surrender the rights.
30
+ These restrictions translate to certain responsibilities for you if you
31
+ distribute copies of the software, or if you modify it.
32
+
33
+ For example, if you distribute copies of such a program, whether
34
+ gratis or for a fee, you must give the recipients all the rights that
35
+ you have. You must make sure that they, too, receive or can get the
36
+ source code. And you must show them these terms so they know their
37
+ rights.
38
+
39
+ We protect your rights with two steps: (1) copyright the software, and
40
+ (2) offer you this license which gives you legal permission to copy,
41
+ distribute and/or modify the software.
42
+
43
+ Also, for each author's protection and ours, we want to make certain
44
+ that everyone understands that there is no warranty for this free
45
+ software. If the software is modified by someone else and passed on, we
46
+ want its recipients to know that what they have is not the original, so
47
+ that any problems introduced by others will not reflect on the original
48
+ authors' reputations.
49
+
50
+ Finally, any free program is threatened constantly by software
51
+ patents. We wish to avoid the danger that redistributors of a free
52
+ program will individually obtain patent licenses, in effect making the
53
+ program proprietary. To prevent this, we have made it clear that any
54
+ patent must be licensed for everyone's free use or not licensed at all.
55
+
56
+ The precise terms and conditions for copying, distribution and
57
+ modification follow.
58
+
59
+ GNU GENERAL PUBLIC LICENSE
60
+ TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
61
+
62
+ 0. This License applies to any program or other work which contains
63
+ a notice placed by the copyright holder saying it may be distributed
64
+ under the terms of this General Public License. The "Program", below,
65
+ refers to any such program or work, and a "work based on the Program"
66
+ means either the Program or any derivative work under copyright law:
67
+ that is to say, a work containing the Program or a portion of it,
68
+ either verbatim or with modifications and/or translated into another
69
+ language. (Hereinafter, translation is included without limitation in
70
+ the term "modification".) Each licensee is addressed as "you".
71
+
72
+ Activities other than copying, distribution and modification are not
73
+ covered by this License; they are outside its scope. The act of
74
+ running the Program is not restricted, and the output from the Program
75
+ is covered only if its contents constitute a work based on the
76
+ Program (independent of having been made by running the Program).
77
+ Whether that is true depends on what the Program does.
78
+
79
+ 1. You may copy and distribute verbatim copies of the Program's
80
+ source code as you receive it, in any medium, provided that you
81
+ conspicuously and appropriately publish on each copy an appropriate
82
+ copyright notice and disclaimer of warranty; keep intact all the
83
+ notices that refer to this License and to the absence of any warranty;
84
+ and give any other recipients of the Program a copy of this License
85
+ along with the Program.
86
+
87
+ You may charge a fee for the physical act of transferring a copy, and
88
+ you may at your option offer warranty protection in exchange for a fee.
89
+
90
+ 2. You may modify your copy or copies of the Program or any portion
91
+ of it, thus forming a work based on the Program, and copy and
92
+ distribute such modifications or work under the terms of Section 1
93
+ above, provided that you also meet all of these conditions:
94
+
95
+ a) You must cause the modified files to carry prominent notices
96
+ stating that you changed the files and the date of any change.
97
+
98
+ b) You must cause any work that you distribute or publish, that in
99
+ whole or in part contains or is derived from the Program or any
100
+ part thereof, to be licensed as a whole at no charge to all third
101
+ parties under the terms of this License.
102
+
103
+ c) If the modified program normally reads commands interactively
104
+ when run, you must cause it, when started running for such
105
+ interactive use in the most ordinary way, to print or display an
106
+ announcement including an appropriate copyright notice and a
107
+ notice that there is no warranty (or else, saying that you provide
108
+ a warranty) and that users may redistribute the program under
109
+ these conditions, and telling the user how to view a copy of this
110
+ License. (Exception: if the Program itself is interactive but
111
+ does not normally print such an announcement, your work based on
112
+ the Program is not required to print an announcement.)
113
+
114
+ These requirements apply to the modified work as a whole. If
115
+ identifiable sections of that work are not derived from the Program,
116
+ and can be reasonably considered independent and separate works in
117
+ themselves, then this License, and its terms, do not apply to those
118
+ sections when you distribute them as separate works. But when you
119
+ distribute the same sections as part of a whole which is a work based
120
+ on the Program, the distribution of the whole must be on the terms of
121
+ this License, whose permissions for other licensees extend to the
122
+ entire whole, and thus to each and every part regardless of who wrote it.
123
+
124
+ Thus, it is not the intent of this section to claim rights or contest
125
+ your rights to work written entirely by you; rather, the intent is to
126
+ exercise the right to control the distribution of derivative or
127
+ collective works based on the Program.
128
+
129
+ In addition, mere aggregation of another work not based on the Program
130
+ with the Program (or with a work based on the Program) on a volume of
131
+ a storage or distribution medium does not bring the other work under
132
+ the scope of this License.
133
+
134
+ 3. You may copy and distribute the Program (or a work based on it,
135
+ under Section 2) in object code or executable form under the terms of
136
+ Sections 1 and 2 above provided that you also do one of the following:
137
+
138
+ a) Accompany it with the complete corresponding machine-readable
139
+ source code, which must be distributed under the terms of Sections
140
+ 1 and 2 above on a medium customarily used for software interchange; or,
141
+
142
+ b) Accompany it with a written offer, valid for at least three
143
+ years, to give any third party, for a charge no more than your
144
+ cost of physically performing source distribution, a complete
145
+ machine-readable copy of the corresponding source code, to be
146
+ distributed under the terms of Sections 1 and 2 above on a medium
147
+ customarily used for software interchange; or,
148
+
149
+ c) Accompany it with the information you received as to the offer
150
+ to distribute corresponding source code. (This alternative is
151
+ allowed only for noncommercial distribution and only if you
152
+ received the program in object code or executable form with such
153
+ an offer, in accord with Subsection b above.)
154
+
155
+ The source code for a work means the preferred form of the work for
156
+ making modifications to it. For an executable work, complete source
157
+ code means all the source code for all modules it contains, plus any
158
+ associated interface definition files, plus the scripts used to
159
+ control compilation and installation of the executable. However, as a
160
+ special exception, the source code distributed need not include
161
+ anything that is normally distributed (in either source or binary
162
+ form) with the major components (compiler, kernel, and so on) of the
163
+ operating system on which the executable runs, unless that component
164
+ itself accompanies the executable.
165
+
166
+ If distribution of executable or object code is made by offering
167
+ access to copy from a designated place, then offering equivalent
168
+ access to copy the source code from the same place counts as
169
+ distribution of the source code, even though third parties are not
170
+ compelled to copy the source along with the object code.
171
+
172
+ 4. You may not copy, modify, sublicense, or distribute the Program
173
+ except as expressly provided under this License. Any attempt
174
+ otherwise to copy, modify, sublicense or distribute the Program is
175
+ void, and will automatically terminate your rights under this License.
176
+ However, parties who have received copies, or rights, from you under
177
+ this License will not have their licenses terminated so long as such
178
+ parties remain in full compliance.
179
+
180
+ 5. You are not required to accept this License, since you have not
181
+ signed it. However, nothing else grants you permission to modify or
182
+ distribute the Program or its derivative works. These actions are
183
+ prohibited by law if you do not accept this License. Therefore, by
184
+ modifying or distributing the Program (or any work based on the
185
+ Program), you indicate your acceptance of this License to do so, and
186
+ all its terms and conditions for copying, distributing or modifying
187
+ the Program or works based on it.
188
+
189
+ 6. Each time you redistribute the Program (or any work based on the
190
+ Program), the recipient automatically receives a license from the
191
+ original licensor to copy, distribute or modify the Program subject to
192
+ these terms and conditions. You may not impose any further
193
+ restrictions on the recipients' exercise of the rights granted herein.
194
+ You are not responsible for enforcing compliance by third parties to
195
+ this License.
196
+
197
+ 7. If, as a consequence of a court judgment or allegation of patent
198
+ infringement or for any other reason (not limited to patent issues),
199
+ conditions are imposed on you (whether by court order, agreement or
200
+ otherwise) that contradict the conditions of this License, they do not
201
+ excuse you from the conditions of this License. If you cannot
202
+ distribute so as to satisfy simultaneously your obligations under this
203
+ License and any other pertinent obligations, then as a consequence you
204
+ may not distribute the Program at all. For example, if a patent
205
+ license would not permit royalty-free redistribution of the Program by
206
+ all those who receive copies directly or indirectly through you, then
207
+ the only way you could satisfy both it and this License would be to
208
+ refrain entirely from distribution of the Program.
209
+
210
+ If any portion of this section is held invalid or unenforceable under
211
+ any particular circumstance, the balance of the section is intended to
212
+ apply and the section as a whole is intended to apply in other
213
+ circumstances.
214
+
215
+ It is not the purpose of this section to induce you to infringe any
216
+ patents or other property right claims or to contest validity of any
217
+ such claims; this section has the sole purpose of protecting the
218
+ integrity of the free software distribution system, which is
219
+ implemented by public license practices. Many people have made
220
+ generous contributions to the wide range of software distributed
221
+ through that system in reliance on consistent application of that
222
+ system; it is up to the author/donor to decide if he or she is willing
223
+ to distribute software through any other system and a licensee cannot
224
+ impose that choice.
225
+
226
+ This section is intended to make thoroughly clear what is believed to
227
+ be a consequence of the rest of this License.
228
+
229
+ 8. If the distribution and/or use of the Program is restricted in
230
+ certain countries either by patents or by copyrighted interfaces, the
231
+ original copyright holder who places the Program under this License
232
+ may add an explicit geographical distribution limitation excluding
233
+ those countries, so that distribution is permitted only in or among
234
+ countries not thus excluded. In such case, this License incorporates
235
+ the limitation as if written in the body of this License.
236
+
237
+ 9. The Free Software Foundation may publish revised and/or new versions
238
+ of the General Public License from time to time. Such new versions will
239
+ be similar in spirit to the present version, but may differ in detail to
240
+ address new problems or concerns.
241
+
242
+ Each version is given a distinguishing version number. If the Program
243
+ specifies a version number of this License which applies to it and "any
244
+ later version", you have the option of following the terms and conditions
245
+ either of that version or of any later version published by the Free
246
+ Software Foundation. If the Program does not specify a version number of
247
+ this License, you may choose any version ever published by the Free Software
248
+ Foundation.
249
+
250
+ 10. If you wish to incorporate parts of the Program into other free
251
+ programs whose distribution conditions are different, write to the author
252
+ to ask for permission. For software which is copyrighted by the Free
253
+ Software Foundation, write to the Free Software Foundation; we sometimes
254
+ make exceptions for this. Our decision will be guided by the two goals
255
+ of preserving the free status of all derivatives of our free software and
256
+ of promoting the sharing and reuse of software generally.
257
+
258
+ NO WARRANTY
259
+
260
+ 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY
261
+ FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN
262
+ OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES
263
+ PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED
264
+ OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
265
+ MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS
266
+ TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE
267
+ PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING,
268
+ REPAIR OR CORRECTION.
269
+
270
+ 12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
271
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR
272
+ REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES,
273
+ INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING
274
+ OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED
275
+ TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY
276
+ YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER
277
+ PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE
278
+ POSSIBILITY OF SUCH DAMAGES.
279
+
280
+ END OF TERMS AND CONDITIONS
281
+
282
+ How to Apply These Terms to Your New Programs
283
+
284
+ If you develop a new program, and you want it to be of the greatest
285
+ possible use to the public, the best way to achieve this is to make it
286
+ free software which everyone can redistribute and change under these terms.
287
+
288
+ To do so, attach the following notices to the program. It is safest
289
+ to attach them to the start of each source file to most effectively
290
+ convey the exclusion of warranty; and each file should have at least
291
+ the "copyright" line and a pointer to where the full notice is found.
292
+
293
+ <one line to give the program's name and a brief idea of what it does.>
294
+ Copyright (C) <year> <name of author>
295
+
296
+ This program is free software; you can redistribute it and/or modify
297
+ it under the terms of the GNU General Public License as published by
298
+ the Free Software Foundation; either version 2 of the License, or
299
+ (at your option) any later version.
300
+
301
+ This program is distributed in the hope that it will be useful,
302
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
303
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
304
+ GNU General Public License for more details.
305
+
306
+ You should have received a copy of the GNU General Public License along
307
+ with this program; if not, write to the Free Software Foundation, Inc.,
308
+ 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
309
+
310
+ Also add information on how to contact you by electronic and paper mail.
311
+
312
+ If the program is interactive, make it output a short notice like this
313
+ when it starts in an interactive mode:
314
+
315
+ Gnomovision version 69, Copyright (C) year name of author
316
+ Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
317
+ This is free software, and you are welcome to redistribute it
318
+ under certain conditions; type `show c' for details.
319
+
320
+ The hypothetical commands `show w' and `show c' should show the appropriate
321
+ parts of the General Public License. Of course, the commands you use may
322
+ be called something other than `show w' and `show c'; they could even be
323
+ mouse-clicks or menu items--whatever suits your program.
324
+
325
+ You should also get your employer (if you work as a programmer) or your
326
+ school, if any, to sign a "copyright disclaimer" for the program, if
327
+ necessary. Here is a sample; alter the names:
328
+
329
+ Yoyodyne, Inc., hereby disclaims all copyright interest in the program
330
+ `Gnomovision' (which makes passes at compilers) written by James Hacker.
331
+
332
+ <signature of Ty Coon>, 1 April 1989
333
+ Ty Coon, President of Vice
334
+
335
+ This General Public License does not permit incorporating your program into
336
+ proprietary programs. If your program is a subroutine library, you may
337
+ consider it more useful to permit linking proprietary applications with the
338
+ library. If this is what you want to do, use the GNU Lesser General
339
+ Public License instead of this License.
runtime/README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # AVoice Runtime
2
+
3
+ Bundled runtime for HF-only AVoice-TTS inference.
runtime/THIRD_PARTY_NOTICES.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Third-Party Notices
2
+
3
+ ## OmniVoice
4
+
5
+ This repository vendors the PyTorch implementation from:
6
+
7
+ - Project: `k2-fsa/OmniVoice`
8
+ - Source: https://github.com/k2-fsa/OmniVoice
9
+ - Vendored commit: `7a68a5cffa71b904a862f4870b246966deebadf7`
10
+ - License: Apache License 2.0
11
+
12
+ The vendored code lives in `omnivoice/`. Local Armenian-specific changes are
13
+ kept in this repository so training, inference, tokenization, and model changes
14
+ can be edited without depending on an installed `omnivoice` wheel.
runtime/omnivoice/__init__.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from importlib.metadata import PackageNotFoundError, version
3
+
4
+ warnings.filterwarnings("ignore", module="torchaudio")
5
+ warnings.filterwarnings(
6
+ "ignore",
7
+ category=SyntaxWarning,
8
+ message="invalid escape sequence",
9
+ module="pydub.utils",
10
+ )
11
+ warnings.filterwarnings(
12
+ "ignore",
13
+ category=FutureWarning,
14
+ module="torch.distributed.algorithms.ddp_comm_hooks",
15
+ )
16
+
17
+ try:
18
+ __version__ = version("avoice")
19
+ except PackageNotFoundError:
20
+ __version__ = "0.0.0"
21
+
22
+ __all__ = ["OmniVoice", "OmniVoiceConfig", "OmniVoiceGenerationConfig"]
23
+
24
+
25
+ def __getattr__(name):
26
+ if name not in __all__:
27
+ raise AttributeError(f"module 'omnivoice' has no attribute {name!r}")
28
+
29
+ from omnivoice.models.omnivoice import (
30
+ OmniVoice,
31
+ OmniVoiceConfig,
32
+ OmniVoiceGenerationConfig,
33
+ )
34
+
35
+ values = {
36
+ "OmniVoice": OmniVoice,
37
+ "OmniVoiceConfig": OmniVoiceConfig,
38
+ "OmniVoiceGenerationConfig": OmniVoiceGenerationConfig,
39
+ }
40
+ return values[name]
runtime/omnivoice/cli/__init__.py ADDED
File without changes
runtime/omnivoice/cli/infer.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Single-item inference CLI for OmniVoice.
2
+
3
+ Generates audio from a single text input using voice cloning,
4
+ voice design, or auto voice.
5
+
6
+ Usage:
7
+ # Voice cloning
8
+ omnivoice-infer --model Hak5/AVoice \
9
+ --text "Hello, this is a text for text-to-speech." \
10
+ --ref_audio ref.wav --ref_text "Reference transcript." --output out.wav
11
+
12
+ # Voice design
13
+ omnivoice-infer --model Hak5/AVoice \
14
+ --text "Hello, this is a text for text-to-speech." \
15
+ --instruct "male, British accent" --output out.wav
16
+
17
+ # Auto voice
18
+ omnivoice-infer --model Hak5/AVoice \
19
+ --text "Hello, this is a text for text-to-speech." --output out.wav
20
+ """
21
+
22
+ import argparse
23
+ import logging
24
+
25
+ import torch
26
+
27
+ import soundfile as sf
28
+
29
+ from omnivoice.models.omnivoice import OmniVoice
30
+ from omnivoice.utils.common import str2bool
31
+
32
+
33
+ def get_best_device():
34
+ """Auto-detect the best available device: CUDA > MPS > CPU."""
35
+ if torch.cuda.is_available():
36
+ return "cuda"
37
+ if torch.backends.mps.is_available():
38
+ return "mps"
39
+ return "cpu"
40
+
41
+
42
+ def get_parser() -> argparse.ArgumentParser:
43
+ parser = argparse.ArgumentParser(
44
+ description="OmniVoice single-item inference",
45
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
46
+ )
47
+ parser.add_argument(
48
+ "--model",
49
+ type=str,
50
+ default="Hak5/AVoice",
51
+ help="Model checkpoint path or HuggingFace repo id.",
52
+ )
53
+ parser.add_argument(
54
+ "--text",
55
+ type=str,
56
+ required=True,
57
+ help="Text to synthesize.",
58
+ )
59
+ parser.add_argument(
60
+ "--output",
61
+ type=str,
62
+ required=True,
63
+ help="Output WAV file path.",
64
+ )
65
+ # Voice cloning
66
+ parser.add_argument(
67
+ "--ref_audio",
68
+ type=str,
69
+ default=None,
70
+ help="Reference audio file path for voice cloning.",
71
+ )
72
+ parser.add_argument(
73
+ "--ref_text",
74
+ type=str,
75
+ default=None,
76
+ help="Reference text describing the reference audio.",
77
+ )
78
+ # Voice design
79
+ parser.add_argument(
80
+ "--instruct",
81
+ type=str,
82
+ default=None,
83
+ help="Style instruction for voice design mode.",
84
+ )
85
+ parser.add_argument(
86
+ "--language",
87
+ type=str,
88
+ default=None,
89
+ help="Language name (e.g. 'English') or code (e.g. 'en').",
90
+ )
91
+ # Generation parameters
92
+ parser.add_argument("--num_step", type=int, default=32)
93
+ parser.add_argument("--guidance_scale", type=float, default=2.0)
94
+ parser.add_argument("--speed", type=float, default=1.0)
95
+ parser.add_argument(
96
+ "--duration",
97
+ type=float,
98
+ default=None,
99
+ help="Fixed output duration in seconds. If set, overrides the "
100
+ "model's duration estimation. The speed factor is automatically "
101
+ "adjusted to match while preserving language-aware pacing.",
102
+ )
103
+ parser.add_argument("--t_shift", type=float, default=0.1)
104
+ parser.add_argument("--denoise", type=str2bool, default=True)
105
+ parser.add_argument(
106
+ "--postprocess_output",
107
+ type=str2bool,
108
+ default=True,
109
+ )
110
+ parser.add_argument("--layer_penalty_factor", type=float, default=5.0)
111
+ parser.add_argument("--position_temperature", type=float, default=5.0)
112
+ parser.add_argument("--class_temperature", type=float, default=0.0)
113
+ parser.add_argument(
114
+ "--device",
115
+ type=str,
116
+ default=None,
117
+ help="Device to use for inference. Auto-detected if not specified.",
118
+ )
119
+ return parser
120
+
121
+
122
+ def main():
123
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
124
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
125
+
126
+ args = get_parser().parse_args()
127
+
128
+ device = args.device or get_best_device()
129
+ logging.info(f"Loading model from {args.model} on {device} ...")
130
+ dtype = torch.float16 if device == "cuda" else torch.float32
131
+ model = OmniVoice.from_pretrained(args.model, device_map=device, dtype=dtype)
132
+
133
+ logging.info(f"Generating audio for: {args.text[:80]}...")
134
+ audios = model.generate(
135
+ text=args.text,
136
+ language=args.language,
137
+ ref_audio=args.ref_audio,
138
+ ref_text=args.ref_text,
139
+ instruct=args.instruct,
140
+ duration=args.duration,
141
+ num_step=args.num_step,
142
+ guidance_scale=args.guidance_scale,
143
+ speed=args.speed,
144
+ t_shift=args.t_shift,
145
+ denoise=args.denoise,
146
+ postprocess_output=args.postprocess_output,
147
+ layer_penalty_factor=args.layer_penalty_factor,
148
+ position_temperature=args.position_temperature,
149
+ class_temperature=args.class_temperature,
150
+ )
151
+
152
+ sf.write(args.output, audios[0], model.sampling_rate)
153
+ logging.info(f"Saved to {args.output}")
154
+
155
+
156
+ if __name__ == "__main__":
157
+ main()
runtime/omnivoice/models/__init__.py ADDED
File without changes
runtime/omnivoice/models/omnivoice.py ADDED
@@ -0,0 +1,1610 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Core OmniVoice model implementation for AVoice inference.
19
+
20
+ ``OmniVoice.from_pretrained()`` loads a local or Hugging Face checkpoint, then
21
+ ``model.generate()`` synthesizes audio from text with optional voice cloning,
22
+ voice design, and Armenian text normalization.
23
+ """
24
+
25
+ import difflib
26
+ import logging
27
+ import math
28
+ import os
29
+ import re
30
+ from dataclasses import dataclass, fields
31
+ from functools import partial
32
+ from typing import Any, List, Optional, Union
33
+
34
+ import numpy as np
35
+ import torch
36
+ import torch.nn as nn
37
+ import torch.nn.functional as F
38
+ import torchaudio
39
+
40
+ try:
41
+ from torch.nn.attention.flex_attention import create_block_mask
42
+
43
+ _flex_attention_available = True
44
+ except ImportError:
45
+ _flex_attention_available = False
46
+ from transformers import (
47
+ AutoFeatureExtractor,
48
+ AutoModel,
49
+ AutoTokenizer,
50
+ HiggsAudioV2TokenizerModel,
51
+ PretrainedConfig,
52
+ PreTrainedModel,
53
+ )
54
+ from transformers.modeling_outputs import ModelOutput
55
+ from transformers.models.auto import CONFIG_MAPPING, AutoConfig
56
+
57
+ from omnivoice.utils.audio import (
58
+ cross_fade_chunks,
59
+ fade_and_pad_audio,
60
+ load_audio,
61
+ remove_silence,
62
+ trim_long_audio,
63
+ )
64
+ from omnivoice.utils.armenian_text import normalize_for_tts
65
+ from omnivoice.utils.duration import RuleDurationEstimator
66
+ from omnivoice.utils.lang_map import LANG_IDS, LANG_NAMES
67
+ from omnivoice.utils.text import add_punctuation, chunk_text_punctuation
68
+ from omnivoice.utils.voice_design import (
69
+ _INSTRUCT_ALL_VALID,
70
+ _INSTRUCT_EN_TO_ZH,
71
+ _INSTRUCT_MUTUALLY_EXCLUSIVE,
72
+ _INSTRUCT_VALID_EN,
73
+ _INSTRUCT_VALID_ZH,
74
+ _INSTRUCT_ZH_TO_EN,
75
+ _ZH_RE,
76
+ )
77
+
78
+ logger = logging.getLogger(__name__)
79
+
80
+
81
+ # ---------------------------------------------------------------------------
82
+ # Dataclasses
83
+ # ---------------------------------------------------------------------------
84
+
85
+
86
+ @dataclass
87
+ class VoiceClonePrompt:
88
+ ref_audio_tokens: torch.Tensor # (C, T)
89
+ ref_text: str
90
+ ref_rms: float
91
+
92
+
93
+ @dataclass
94
+ class OmniVoiceGenerationConfig:
95
+ num_step: int = 32
96
+ guidance_scale: float = 2.0
97
+ t_shift: float = 0.1
98
+ layer_penalty_factor: float = 5.0
99
+ position_temperature: float = 5.0
100
+ class_temperature: float = 0.0
101
+ denoise: bool = True
102
+ preprocess_prompt: bool = True
103
+ postprocess_output: bool = True
104
+ audio_chunk_duration: float = 15.0
105
+ audio_chunk_threshold: float = 30.0
106
+
107
+ @classmethod
108
+ def from_dict(cls, kwargs_dict):
109
+ valid_keys = {f.name for f in fields(cls)}
110
+ filtered = {k: v for k, v in kwargs_dict.items() if k in valid_keys}
111
+ return cls(**filtered)
112
+
113
+
114
+ @dataclass
115
+ class GenerationTask:
116
+ batch_size: int
117
+ texts: List[str]
118
+ target_lens: List[int]
119
+ langs: List[Optional[str]]
120
+ instructs: List[Optional[str]]
121
+ ref_texts: List[Optional[str]]
122
+ ref_audio_tokens: List[Optional[torch.Tensor]]
123
+ ref_rms: List[Optional[float]]
124
+ speed: Optional[List[float]] = None
125
+
126
+ def get_indices(self, config: OmniVoiceGenerationConfig, frame_rate: int):
127
+ threshold = int(config.audio_chunk_threshold * frame_rate)
128
+ short_idx = [i for i, l in enumerate(self.target_lens) if l <= threshold]
129
+ long_idx = [i for i, l in enumerate(self.target_lens) if l > threshold]
130
+ return short_idx, long_idx
131
+
132
+ def slice_task(self, indices: List[int]):
133
+ if not indices:
134
+ return None
135
+ return GenerationTask(
136
+ batch_size=len(indices),
137
+ texts=[self.texts[i] for i in indices],
138
+ target_lens=[self.target_lens[i] for i in indices],
139
+ langs=[self.langs[i] for i in indices],
140
+ instructs=[self.instructs[i] for i in indices],
141
+ ref_texts=[self.ref_texts[i] for i in indices],
142
+ ref_audio_tokens=[self.ref_audio_tokens[i] for i in indices],
143
+ ref_rms=[self.ref_rms[i] for i in indices],
144
+ speed=[self.speed[i] for i in indices] if self.speed else None,
145
+ )
146
+
147
+
148
+ @dataclass
149
+ class OmniVoiceModelOutput(ModelOutput):
150
+ loss: Optional[torch.Tensor] = None
151
+ logits: Optional[torch.Tensor] = None
152
+ layer_losses: Optional[torch.Tensor] = None
153
+ layer_token_counts: Optional[torch.Tensor] = None
154
+
155
+
156
+ # ---------------------------------------------------------------------------
157
+ # Config & Model
158
+ # ---------------------------------------------------------------------------
159
+
160
+
161
+ class OmniVoiceConfig(PretrainedConfig):
162
+ model_type = "omnivoice"
163
+ sub_configs = {"llm_config": AutoConfig}
164
+
165
+ def __init__(
166
+ self,
167
+ audio_vocab_size: int = 1025,
168
+ audio_mask_id: int = 1024,
169
+ num_audio_codebook: int = 8,
170
+ audio_codebook_weights: Optional[list[float]] = None,
171
+ llm_config: Optional[Union[dict, PretrainedConfig]] = None,
172
+ **kwargs,
173
+ ):
174
+
175
+ if isinstance(llm_config, dict):
176
+ llm_config = CONFIG_MAPPING[llm_config["model_type"]](**llm_config)
177
+
178
+ self.llm_config = llm_config
179
+
180
+ super().__init__(**kwargs)
181
+ self.audio_vocab_size = audio_vocab_size
182
+ self.audio_mask_id = audio_mask_id
183
+ self.num_audio_codebook = num_audio_codebook
184
+ if audio_codebook_weights is None:
185
+ audio_codebook_weights = [8, 8, 6, 6, 4, 4, 2, 2]
186
+ self.audio_codebook_weights = audio_codebook_weights
187
+
188
+
189
+ def _resolve_model_path(name_or_path: str) -> str:
190
+ if os.path.isdir(name_or_path):
191
+ return name_or_path
192
+ from huggingface_hub import snapshot_download
193
+
194
+ return snapshot_download(name_or_path)
195
+
196
+
197
+ class OmniVoice(PreTrainedModel):
198
+ _supports_flex_attn = True
199
+ _supports_flash_attn_2 = True
200
+ _supports_sdpa = True
201
+ config_class = OmniVoiceConfig
202
+
203
+ def __init__(self, config: OmniVoiceConfig, llm: Optional[PreTrainedModel] = None):
204
+ super().__init__(config)
205
+
206
+ if llm is not None:
207
+ # If an LLM instance is provided, use it directly
208
+ # (skipping config-based init).
209
+ self.llm = llm
210
+ else:
211
+ # Otherwise, initialize the LLM from the config.
212
+ self.llm = AutoModel.from_config(self.config.llm_config)
213
+
214
+ self.audio_embeddings = nn.Embedding(
215
+ config.num_audio_codebook * config.audio_vocab_size,
216
+ self.config.llm_config.hidden_size,
217
+ )
218
+ self.register_buffer(
219
+ "codebook_layer_offsets",
220
+ torch.arange(config.num_audio_codebook) * config.audio_vocab_size,
221
+ )
222
+
223
+ self.audio_heads = nn.Linear(
224
+ self.config.llm_config.hidden_size,
225
+ config.num_audio_codebook * config.audio_vocab_size,
226
+ bias=False,
227
+ )
228
+
229
+ self.normalized_audio_codebook_weights = [
230
+ w / sum(config.audio_codebook_weights)
231
+ for w in config.audio_codebook_weights
232
+ ]
233
+
234
+ self.post_init()
235
+
236
+ # Inference-only attributes (set by from_pretrained when not in train mode)
237
+ self.text_tokenizer = None
238
+ self.audio_tokenizer = None
239
+ self.duration_estimator = None
240
+ self.sampling_rate = None
241
+ self._asr_pipe = None
242
+
243
+ @classmethod
244
+ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
245
+ train_mode = kwargs.pop("train", False)
246
+ load_asr = kwargs.pop("load_asr", False)
247
+ asr_model_name = kwargs.pop("asr_model_name", "openai/whisper-large-v3-turbo")
248
+
249
+ # Suppress noisy INFO logs from transformers/huggingface_hub during loading
250
+ _prev_disable = logging.root.manager.disable
251
+ logging.disable(logging.INFO)
252
+
253
+ try:
254
+ # Resolve to local path first; download only if not already cached
255
+ resolved_path = _resolve_model_path(pretrained_model_name_or_path)
256
+
257
+ model = super().from_pretrained(resolved_path, *args, **kwargs)
258
+
259
+ if not train_mode:
260
+ model.text_tokenizer = AutoTokenizer.from_pretrained(resolved_path)
261
+
262
+ audio_tokenizer_path = os.path.join(resolved_path, "audio_tokenizer")
263
+
264
+ if not os.path.isdir(audio_tokenizer_path):
265
+ audio_tokenizer_path = _resolve_model_path(
266
+ "eustlb/higgs-audio-v2-tokenizer"
267
+ )
268
+
269
+ # higgs-audio-v2-tokenizer does not support MPS
270
+ # (output channels > 65536)
271
+ tokenizer_device = (
272
+ "cpu" if str(model.device).startswith("mps") else model.device
273
+ )
274
+ model.audio_tokenizer = HiggsAudioV2TokenizerModel.from_pretrained(
275
+ audio_tokenizer_path, device_map=tokenizer_device
276
+ )
277
+ model.feature_extractor = AutoFeatureExtractor.from_pretrained(
278
+ audio_tokenizer_path
279
+ )
280
+
281
+ model.sampling_rate = model.feature_extractor.sampling_rate
282
+
283
+ model.duration_estimator = RuleDurationEstimator()
284
+
285
+ if load_asr:
286
+ model.load_asr_model(model_name=asr_model_name)
287
+ finally:
288
+ logging.disable(_prev_disable)
289
+
290
+ return model
291
+
292
+ # -------------------------------------------------------------------
293
+ # ASR support (optional, for auto-transcription)
294
+ # -------------------------------------------------------------------
295
+
296
+ def load_asr_model(self, model_name: str = "openai/whisper-large-v3-turbo"):
297
+ """Load a Whisper ASR model for reference audio transcription.
298
+
299
+ Args:
300
+ model_name: HuggingFace model name or local path for the Whisper model.
301
+ """
302
+ from transformers import pipeline as hf_pipeline
303
+
304
+ logger.info("Loading ASR model %s ...", model_name)
305
+ asr_dtype = (
306
+ torch.float16 if str(self.device).startswith("cuda") else torch.float32
307
+ )
308
+
309
+ model_name = _resolve_model_path(model_name)
310
+
311
+ self._asr_pipe = hf_pipeline(
312
+ "automatic-speech-recognition",
313
+ model=model_name,
314
+ dtype=asr_dtype,
315
+ device_map=self.device,
316
+ )
317
+ logger.info("ASR model loaded on %s.", self.device)
318
+
319
+ @torch.inference_mode()
320
+ def transcribe(
321
+ self,
322
+ audio: Union[str, tuple],
323
+ ) -> str:
324
+ """Transcribe audio using the loaded Whisper ASR model.
325
+
326
+ Args:
327
+ audio: File path or ``(waveform, sample_rate)`` tuple.
328
+ Waveform can be a numpy array or torch.Tensor of shape
329
+ ``(1, T)`` or ``(T,)``.
330
+
331
+ Returns:
332
+ Transcribed text.
333
+ """
334
+ if self._asr_pipe is None:
335
+ raise RuntimeError(
336
+ "ASR model is not loaded. Call model.load_asr_model() first."
337
+ )
338
+
339
+ if isinstance(audio, str):
340
+ return self._asr_pipe(audio)["text"].strip()
341
+ else:
342
+ waveform, sr = audio
343
+ if isinstance(waveform, torch.Tensor):
344
+ waveform = waveform.cpu().numpy()
345
+ waveform = np.squeeze(waveform) # (1, T) or (T,) → (T,)
346
+ audio_input = {
347
+ "array": waveform,
348
+ "sampling_rate": sr,
349
+ }
350
+ return self._asr_pipe(audio_input)["text"].strip()
351
+
352
+ def get_input_embeddings(self):
353
+ return self.llm.get_input_embeddings()
354
+
355
+ def set_input_embeddings(self, value):
356
+ self.llm.set_input_embeddings(value)
357
+
358
+ def _prepare_embed_inputs(
359
+ self, input_ids: torch.Tensor, audio_mask: torch.Tensor
360
+ ) -> torch.Tensor:
361
+ """
362
+ Prepares embeddings from input_ids of shape (batch_size, layers, seq_length).
363
+ Embedding shape is (batch_size, seq_length, hidden_size).
364
+ """
365
+ text_embeds = self.get_input_embeddings()(input_ids[:, 0, :])
366
+
367
+ # Apply shift to audio IDs based on codebook layer
368
+ # audio_ids: [Batch, 8, Seq]
369
+ # codebook_layer_offsets: [1, 8, 1]
370
+ # Result: Layer 0 ID Layer 1 ID + Layer 2 ID + 2050...
371
+ shifted_ids = (
372
+ input_ids * audio_mask.unsqueeze(1)
373
+ ) + self.codebook_layer_offsets.view(1, -1, 1)
374
+
375
+ # input: [Batch, 8, Seq] -> output: [Batch, Seq, Hidden]
376
+ audio_embeds = self.audio_embeddings(shifted_ids).sum(dim=1)
377
+
378
+ return torch.where(audio_mask.unsqueeze(-1), audio_embeds, text_embeds)
379
+
380
+ def forward(
381
+ self,
382
+ input_ids: torch.LongTensor,
383
+ audio_mask: torch.Tensor,
384
+ labels: Optional[torch.LongTensor] = None,
385
+ attention_mask: Optional[torch.Tensor] = None,
386
+ document_ids: Optional[torch.Tensor] = None,
387
+ position_ids: Optional[torch.LongTensor] = None,
388
+ ):
389
+
390
+ inputs_embeds = self._prepare_embed_inputs(input_ids, audio_mask)
391
+
392
+ if attention_mask is None and document_ids is not None:
393
+ if not _flex_attention_available:
394
+ raise RuntimeError(
395
+ "flex_attention is not available in the current environment. "
396
+ "If you do not need flex_attention, set "
397
+ '"attn_implementation": "sdpa" in your training config.'
398
+ )
399
+ attention_mask = create_block_mask(
400
+ _get_packed_mask(
401
+ document_ids[0].to(inputs_embeds.device),
402
+ ),
403
+ B=None,
404
+ H=None,
405
+ Q_LEN=input_ids.size(-1),
406
+ KV_LEN=input_ids.size(-1),
407
+ _compile=True,
408
+ device=inputs_embeds.device,
409
+ )
410
+
411
+ llm_outputs = self.llm(
412
+ inputs_embeds=inputs_embeds,
413
+ attention_mask=attention_mask,
414
+ return_dict=True,
415
+ position_ids=position_ids,
416
+ )
417
+ hidden_states = llm_outputs[0]
418
+
419
+ loss = None
420
+ layer_losses = None
421
+ layer_token_counts = None
422
+
423
+ # Shape: [B, S, C * Vocab]
424
+ batch_size, seq_len, _ = hidden_states.shape
425
+ logits_flat = self.audio_heads(hidden_states)
426
+ # Shape: [B, S, C, Vocab] -> [B, C, S, Vocab]
427
+ audio_logits = logits_flat.view(
428
+ batch_size,
429
+ seq_len,
430
+ self.config.num_audio_codebook,
431
+ self.config.audio_vocab_size,
432
+ ).permute(0, 2, 1, 3)
433
+
434
+ if labels is not None:
435
+
436
+ # audio_logits.permute(0, 3, 1, 2):
437
+ # [Batch, Layer, Seq, Vocab] -> [Batch, Vocab, Layer, Seq]
438
+ # per_token_loss shape: [Batch, Layer, Seq],ignore -100
439
+ per_token_loss = torch.nn.functional.cross_entropy(
440
+ audio_logits.permute(0, 3, 1, 2),
441
+ labels,
442
+ reduction="none",
443
+ ignore_index=-100,
444
+ )
445
+ # valid_mask shape: [Batch, Layer, Seq]
446
+ valid_mask = (labels != -100).float()
447
+
448
+ layer_token_counts = valid_mask.sum(dim=(0, 2))
449
+ # layer_means shape: [num_layers]
450
+ layer_losses = (per_token_loss * valid_mask).sum(
451
+ dim=(0, 2)
452
+ ) / layer_token_counts.clamp(min=1.0)
453
+
454
+ weights = torch.tensor(
455
+ self.normalized_audio_codebook_weights, device=audio_logits.device
456
+ )
457
+ loss = (layer_losses * weights).sum()
458
+
459
+ return OmniVoiceModelOutput(
460
+ loss=loss,
461
+ logits=audio_logits,
462
+ layer_losses=layer_losses,
463
+ layer_token_counts=layer_token_counts,
464
+ )
465
+
466
+ def supported_language_ids(self) -> set[str]:
467
+ """Return a list of supported language IDs."""
468
+ return LANG_IDS
469
+
470
+ def supported_language_names(self) -> set[str]:
471
+ """Return a list of supported language names."""
472
+ return LANG_NAMES
473
+
474
+ # -------------------------------------------------------------------
475
+ # Inference API
476
+ # -------------------------------------------------------------------
477
+
478
+ @torch.inference_mode()
479
+ def generate(
480
+ self,
481
+ text: Union[str, list[str]],
482
+ language: Union[str, list[str], None] = None,
483
+ ref_text: Union[str, list[str], None] = None,
484
+ ref_audio: Union[
485
+ str,
486
+ list[str],
487
+ tuple[torch.Tensor, int],
488
+ list[tuple[torch.Tensor, int]],
489
+ None,
490
+ ] = None,
491
+ voice_clone_prompt: Union[
492
+ VoiceClonePrompt, list[VoiceClonePrompt], None
493
+ ] = None,
494
+ instruct: Union[str, list[str], None] = None,
495
+ duration: Union[float, list[Optional[float]], None] = None,
496
+ speed: Union[float, list[Optional[float]], None] = None,
497
+ generation_config: Optional[OmniVoiceGenerationConfig] = None,
498
+ **kwargs,
499
+ ) -> list[np.ndarray]:
500
+ """Generate speech audio given text in various modes.
501
+
502
+ Supports three modes:
503
+
504
+ 1. **Voice clone** — clone the voice style from the reference audio.
505
+ Should provide ``voice_clone_prompt`` (from
506
+ :meth:`create_voice_clone_prompt`) or ``ref_text`` + ``ref_audio``.
507
+ 2. **Voice design** — provide ``instruct`` text describing
508
+ the desired voice style; no reference audio needed.
509
+ 3. **Auto** — provide neither; the model picks a voice itself.
510
+
511
+ Args:
512
+ text: Target text (single string or list for batch).
513
+ language: Language name (e.g. ``"English"``) or code
514
+ (e.g. ``"en"``). ``None`` for language-agnostic mode.
515
+ Performance is slightly better if you specify the language.
516
+ ref_text: Optional reference text for voice cloning mode.
517
+ ref_audio: Optional reference audio for voice cloning mode.
518
+ Can be a file path or a (waveform, sample_rate) tuple.
519
+ voice_clone_prompt: Reusable prompt from :meth:`create_voice_clone_prompt`.
520
+ If provided, it overrides ``ref_text`` and ``ref_audio``.
521
+ instruct: Style instruction for voice design mode.
522
+ duration: Fixed output duration in seconds. If a single float,
523
+ applies to all items; if a list, one value per item.
524
+ ``None`` (default) lets the model estimate duration from text.
525
+ Overrides ``speed`` when both are provided.
526
+ speed: Speaking speed factor. ``> 1.0`` for faster, ``< 1.0`` for
527
+ slower. If a list, one value per item. ``None`` (default) uses
528
+ the model's default estimation.
529
+ generation_config: Explicit config object. If provided, takes
530
+ precedence over ``**kwargs``.
531
+ **kwargs: Generation config or its fields:
532
+ denoise: Whether to prepend the ``<|denoise|>`` token.
533
+ num_step: Number of iterative decoding steps.
534
+ guidance_scale: Classifier-free guidance scale.
535
+ t_shift: Time-step shift (smaller → emphasise low-SNR).
536
+ postprocess_output: Post-process output (remove silence, fade-in/out, pad edges).
537
+ layer_penalty_factor: Penalty encouraging earlier codebook
538
+ layers to unmask first.
539
+ position_temperature: Temperature for position selection.
540
+ class_temperature: Temperature for token sampling (0 = greedy).
541
+ audio_chunk_duration: If > 0, split long text into chunks of
542
+ this duration (seconds) and generate chunk by chunk.
543
+ audio_chunk_threshold: Only apply chunking if estimated audio
544
+ duration exceeds this threshold (seconds).
545
+ Returns:
546
+ ``audios`` a list of 1-D ``np.ndarray`` with shape ``(T,)`` and
547
+ sampling rate consistent with the model's audio tokenizer
548
+ (usually 24 000 Hz). Can be saved directly with
549
+ ``soundfile.write("out.wav", audios[0], model.sampling_rate)``.
550
+ """
551
+
552
+ if self.audio_tokenizer is None or self.text_tokenizer is None:
553
+ raise RuntimeError(
554
+ "Model is not loaded with audio/text tokenizers. Make sure you "
555
+ "loaded the model with OmniVoice.from_pretrained()."
556
+ )
557
+ gen_config = (
558
+ generation_config
559
+ if generation_config is not None
560
+ else OmniVoiceGenerationConfig.from_dict(kwargs)
561
+ )
562
+
563
+ self.eval()
564
+
565
+ full_task = self._preprocess_all(
566
+ text=text,
567
+ language=language,
568
+ ref_text=ref_text,
569
+ ref_audio=ref_audio,
570
+ voice_clone_prompt=voice_clone_prompt,
571
+ instruct=instruct,
572
+ preprocess_prompt=gen_config.preprocess_prompt,
573
+ speed=speed,
574
+ duration=duration,
575
+ )
576
+
577
+ short_idx, long_idx = full_task.get_indices(
578
+ gen_config, self.audio_tokenizer.config.frame_rate
579
+ )
580
+
581
+ results = [None] * full_task.batch_size
582
+
583
+ if short_idx:
584
+ short_task = full_task.slice_task(short_idx)
585
+ short_results = self._generate_iterative(short_task, gen_config)
586
+ for idx, res in zip(short_idx, short_results):
587
+ results[idx] = res
588
+
589
+ if long_idx:
590
+ long_task = full_task.slice_task(long_idx)
591
+ long_results = self._generate_chunked(long_task, gen_config)
592
+ for idx, res in zip(long_idx, long_results):
593
+ results[idx] = res
594
+
595
+ generated_audios = []
596
+ for i in range(full_task.batch_size):
597
+ assert results[i] is not None, f"Result {i} was not generated"
598
+ generated_audios.append(
599
+ self._decode_and_post_process(
600
+ results[i], full_task.ref_rms[i], gen_config # type: ignore[arg-type]
601
+ )
602
+ )
603
+
604
+ return generated_audios
605
+
606
+ def create_voice_clone_prompt(
607
+ self,
608
+ ref_audio: Union[str, tuple[torch.Tensor, int]],
609
+ ref_text: Optional[str] = None,
610
+ preprocess_prompt: bool = True,
611
+ ) -> VoiceClonePrompt:
612
+ """Create a reusable voice clone prompt from reference audio.
613
+
614
+ Args:
615
+ ref_audio: File path (str) or ``(waveform, sample_rate)`` tuple.
616
+ waveform should be a 1-D or 2-D torch.Tensor (channels x samples).
617
+ ref_text: Transcript of the reference audio. If ``None``, the
618
+ ASR model will be used to auto-transcribe (must call
619
+ :meth:`load_asr_model` first).
620
+ preprocess_prompt: If ``True`` (default), apply silence removal and
621
+ trimming to the reference audio, add punctuation in the end
622
+ of reference text (if not already)
623
+
624
+ Returns:
625
+ A :class:`VoiceClonePrompt` that can be passed to :meth:`generate`.
626
+ """
627
+ if self.audio_tokenizer is None:
628
+ raise RuntimeError(
629
+ "Audio tokenizer is not loaded. Make sure you loaded the model "
630
+ "with OmniVoice.from_pretrained()."
631
+ )
632
+
633
+ if isinstance(ref_audio, str):
634
+ ref_wav = load_audio(ref_audio, self.sampling_rate)
635
+ else:
636
+ waveform, sr = ref_audio
637
+ if isinstance(waveform, torch.Tensor):
638
+ waveform = waveform.cpu().numpy()
639
+ if waveform.ndim == 1:
640
+ waveform = waveform[np.newaxis, :]
641
+ if waveform.shape[0] > 1:
642
+ waveform = np.mean(waveform, axis=0, keepdims=True)
643
+ if sr != self.sampling_rate:
644
+ waveform = torchaudio.functional.resample(
645
+ torch.from_numpy(waveform),
646
+ orig_freq=sr,
647
+ new_freq=self.sampling_rate,
648
+ ).numpy()
649
+ ref_wav = waveform
650
+
651
+ ref_rms = float(np.sqrt(np.mean(ref_wav**2)))
652
+ if 0 < ref_rms < 0.1:
653
+ ref_wav = ref_wav * 0.1 / ref_rms
654
+
655
+ if preprocess_prompt:
656
+ # Trim long reference audio (>20s) by splitting at the largest silence gap.
657
+ # Skip trimming when ref_text is user-provided, otherwise the
658
+ # trimmed audio will no longer match the full transcript.
659
+ if ref_text is None:
660
+ ref_wav = trim_long_audio(
661
+ ref_wav, self.sampling_rate, trim_threshold=20.0
662
+ )
663
+ ref_wav = remove_silence(
664
+ ref_wav,
665
+ self.sampling_rate,
666
+ mid_sil=200,
667
+ lead_sil=100,
668
+ trail_sil=200,
669
+ )
670
+ if ref_wav.shape[-1] == 0:
671
+ raise ValueError(
672
+ "Reference audio is empty after silence removal. "
673
+ "Try setting preprocess_prompt=False."
674
+ )
675
+
676
+ ref_duration = ref_wav.shape[-1] / self.sampling_rate
677
+ if ref_duration > 20.0:
678
+ logger.warning(
679
+ "Reference audio is %.1fs long (>20s). This may cause slower "
680
+ "generation, higher memory usage, and degraded voice cloning "
681
+ "quality. We recommend trimming it to 3-10s.",
682
+ ref_duration,
683
+ )
684
+
685
+ # Auto-transcribe if ref_text not provided
686
+ if ref_text is None:
687
+ if self._asr_pipe is None:
688
+ logger.info("ASR model not loaded yet, loading on-the-fly ...")
689
+ self.load_asr_model()
690
+ ref_text = self.transcribe((ref_wav, self.sampling_rate))
691
+ logger.debug("Auto-transcribed ref_text: %s", ref_text)
692
+
693
+ chunk_size = self.audio_tokenizer.config.hop_length
694
+ clip_size = int(ref_wav.shape[-1] % chunk_size)
695
+ ref_wav = ref_wav[:, :-clip_size] if clip_size > 0 else ref_wav
696
+ # numpy → torch at tokenizer boundary
697
+ ref_wav_tensor = torch.from_numpy(ref_wav).to(self.audio_tokenizer.device)
698
+ ref_audio_tokens = self.audio_tokenizer.encode(
699
+ ref_wav_tensor.unsqueeze(0),
700
+ ).audio_codes.squeeze(
701
+ 0
702
+ ) # (C, T)
703
+
704
+ if preprocess_prompt:
705
+ ref_text = add_punctuation(ref_text)
706
+
707
+ return VoiceClonePrompt(
708
+ ref_audio_tokens=ref_audio_tokens,
709
+ ref_text=ref_text,
710
+ ref_rms=ref_rms,
711
+ )
712
+
713
+ def _decode_and_post_process(
714
+ self,
715
+ tokens: Union[torch.Tensor, List[torch.Tensor]],
716
+ rms: Union[float, None],
717
+ gen_config: OmniVoiceGenerationConfig,
718
+ ) -> np.ndarray:
719
+ """
720
+ Args:
721
+ tokens: Audio tokens — either a single tensor of shape
722
+ (num_codebooks, seq_len) or a list of chunk tensors.
723
+ rms: RMS of the reference audio for volume adjustment.
724
+ gen_config: Generation config for post-processing options.
725
+ Returns:
726
+ Decoded and post-processed audio array of shape (T,).
727
+ """
728
+ tokenizer_device = self.audio_tokenizer.device
729
+ if isinstance(tokens, list):
730
+ chunk_audios = [
731
+ self.audio_tokenizer.decode(t.to(tokenizer_device).unsqueeze(0))
732
+ .audio_values[0]
733
+ .cpu()
734
+ .numpy()
735
+ for t in tokens
736
+ ]
737
+ audio_waveform = cross_fade_chunks(chunk_audios, self.sampling_rate)
738
+ else:
739
+ audio_waveform = (
740
+ self.audio_tokenizer.decode(tokens.to(tokenizer_device).unsqueeze(0))
741
+ .audio_values[0]
742
+ .cpu()
743
+ .numpy()
744
+ )
745
+
746
+ audio_waveform = self._post_process_audio(
747
+ audio_waveform,
748
+ postprocess_output=gen_config.postprocess_output,
749
+ ref_rms=rms,
750
+ )
751
+ return audio_waveform.squeeze(0)
752
+
753
+ def _post_process_audio(
754
+ self,
755
+ generated_audio: np.ndarray,
756
+ postprocess_output: bool,
757
+ ref_rms: Union[float, None],
758
+ ) -> np.ndarray:
759
+ """Optionally remove long silences, adjust volume, and add edge padding.
760
+
761
+ Args:
762
+ generated_audio: Numpy array of shape (1, T).
763
+ postprocess_output: If True, remove long silences and apply fade/pad.
764
+ ref_rms: RMS of the reference audio for volume normalisation.
765
+ Returns:
766
+ Processed numpy array of shape (1, T).
767
+ """
768
+ if postprocess_output:
769
+ generated_audio = remove_silence(
770
+ generated_audio,
771
+ self.sampling_rate,
772
+ mid_sil=500,
773
+ lead_sil=100,
774
+ trail_sil=100,
775
+ )
776
+
777
+ if ref_rms is not None and ref_rms < 0.1:
778
+ generated_audio = generated_audio * ref_rms / 0.1
779
+ elif ref_rms is None:
780
+ peak = np.abs(generated_audio).max()
781
+ if peak > 1e-6:
782
+ generated_audio = generated_audio / peak * 0.5
783
+
784
+ generated_audio = fade_and_pad_audio(
785
+ generated_audio,
786
+ sample_rate=self.sampling_rate,
787
+ )
788
+ return generated_audio
789
+
790
+ def _generate_chunked(
791
+ self, task: GenerationTask, gen_config: OmniVoiceGenerationConfig
792
+ ) -> List[List[torch.Tensor]]:
793
+ """Generate long audio by splitting text into chunks and batching.
794
+
795
+ Each item in the returned list corresponds to one input and contains
796
+ a list of audio token tensors — one per text chunk.
797
+
798
+ Args:
799
+ task: A :class:`GenerationTask` with one or more items whose
800
+ estimated audio exceeds ``audio_chunk_threshold``.
801
+ gen_config: Generation config (``audio_chunk_duration`` controls
802
+ chunk size).
803
+ Returns:
804
+ Per-item list of chunk token-tensor lists.
805
+ """
806
+ # Chunk each item's text
807
+ all_chunks = []
808
+ for i in range(task.batch_size):
809
+ avg_tokens_per_char = task.target_lens[i] / len(task.texts[i])
810
+ text_chunk_len = int(
811
+ gen_config.audio_chunk_duration
812
+ * self.audio_tokenizer.config.frame_rate
813
+ / avg_tokens_per_char
814
+ )
815
+ chunks = chunk_text_punctuation(
816
+ text=task.texts[i],
817
+ chunk_len=text_chunk_len,
818
+ min_chunk_len=3,
819
+ )
820
+ logger.debug(f"Item {i} chunked into {len(chunks)} pieces: {chunks}")
821
+ all_chunks.append(chunks)
822
+
823
+ has_ref = [t is not None for t in task.ref_audio_tokens]
824
+ assert all(has_ref) or not any(has_ref), (
825
+ "Chunked inference requires all items to either have or not have "
826
+ "ref_audio. Mixed ref/non-ref is not supported."
827
+ )
828
+
829
+ max_num_chunks = max(len(c) for c in all_chunks)
830
+
831
+ # chunk_results[item_idx] = list of generated token tensors per chunk
832
+ chunk_results = [[] for _ in range(task.batch_size)]
833
+
834
+ def _run_batch(indices, texts, ref_audios, ref_texts):
835
+ speed_list = task.speed
836
+ target_lens = [
837
+ self._estimate_target_tokens(
838
+ texts[j],
839
+ ref_texts[j],
840
+ ref_audios[j].size(-1) if ref_audios[j] is not None else None,
841
+ speed=speed_list[i] if speed_list else 1.0,
842
+ )
843
+ for j, i in enumerate(indices)
844
+ ]
845
+ sub_task = GenerationTask(
846
+ batch_size=len(indices),
847
+ texts=texts,
848
+ target_lens=target_lens,
849
+ langs=[task.langs[i] for i in indices],
850
+ instructs=[task.instructs[i] for i in indices],
851
+ ref_texts=ref_texts,
852
+ ref_audio_tokens=ref_audios,
853
+ ref_rms=[task.ref_rms[i] for i in indices],
854
+ speed=[task.speed[i] for i in indices] if task.speed else None,
855
+ )
856
+ gen_tokens = self._generate_iterative(sub_task, gen_config)
857
+ for j, idx in enumerate(indices):
858
+ chunk_results[idx].append(gen_tokens[j])
859
+
860
+ if all(has_ref):
861
+ # All items have reference audio.
862
+ # We still sequentially generate chunks within each item, but we
863
+ # batch across items for the same chunk index. This allows to keep
864
+ # the VRAM usage manageable while still benefiting from batching.
865
+ for ci in range(max_num_chunks):
866
+ indices = [i for i in range(task.batch_size) if ci < len(all_chunks[i])]
867
+ if not indices:
868
+ continue
869
+ _run_batch(
870
+ indices,
871
+ texts=[all_chunks[i][ci] for i in indices],
872
+ ref_audios=[task.ref_audio_tokens[i] for i in indices],
873
+ ref_texts=[task.ref_texts[i] for i in indices],
874
+ )
875
+ else:
876
+ # No reference audio — generate chunk 0 for all items first,
877
+ # then use chunk 0 output as reference for all subsequent chunks.
878
+ indices_0 = [i for i in range(task.batch_size) if len(all_chunks[i]) > 0]
879
+ _run_batch(
880
+ indices_0,
881
+ texts=[all_chunks[i][0] for i in indices_0],
882
+ ref_audios=[None] * len(indices_0),
883
+ ref_texts=[None] * len(indices_0),
884
+ )
885
+ first_chunk_map = {idx: chunk_results[idx][0] for idx in indices_0}
886
+
887
+ # Batch all remaining chunks, using chunk 0 as fixed reference
888
+ for ci in range(1, max_num_chunks):
889
+ indices = [i for i in range(task.batch_size) if ci < len(all_chunks[i])]
890
+ if not indices:
891
+ continue
892
+ _run_batch(
893
+ indices,
894
+ texts=[all_chunks[i][ci] for i in indices],
895
+ ref_audios=[first_chunk_map[i] for i in indices],
896
+ ref_texts=[all_chunks[i][0] for i in indices],
897
+ )
898
+
899
+ return chunk_results
900
+
901
+ def _preprocess_all(
902
+ self,
903
+ text: Union[str, list[str]],
904
+ language: Union[str, list[str], None] = None,
905
+ ref_text: Union[str, list[str], None] = None,
906
+ ref_audio: Union[
907
+ str,
908
+ list[str],
909
+ tuple[torch.Tensor, int],
910
+ list[tuple[torch.Tensor, int]],
911
+ None,
912
+ ] = None,
913
+ voice_clone_prompt: Union[
914
+ VoiceClonePrompt, list[VoiceClonePrompt], None
915
+ ] = None,
916
+ instruct: Union[str, list[str], None] = None,
917
+ preprocess_prompt: bool = True,
918
+ speed: Union[float, list[Optional[float]], None] = None,
919
+ duration: Union[float, list[Optional[float]], None] = None,
920
+ ) -> GenerationTask:
921
+
922
+ if isinstance(text, str):
923
+ text_list = [text]
924
+ else:
925
+ assert isinstance(
926
+ text, list
927
+ ), "text should be a string or a list of strings"
928
+ text_list = text
929
+ batch_size = len(text_list)
930
+
931
+ language_list = self._ensure_list(language, batch_size)
932
+ language_list = [_resolve_language(lang) for lang in language_list]
933
+ text_list = [
934
+ normalize_for_tts(text_item, language=language_list[i])
935
+ for i, text_item in enumerate(text_list)
936
+ ]
937
+ instruct_list = self._ensure_list(instruct, batch_size)
938
+ for i, s in enumerate(instruct_list):
939
+ if s is None:
940
+ continue
941
+ use_zh = bool(text_list[i] and _ZH_RE.search(text_list[i]))
942
+ instruct_list[i] = _resolve_instruct(s, use_zh=use_zh)
943
+
944
+ if voice_clone_prompt is not None and (
945
+ ref_text is not None or ref_audio is not None
946
+ ):
947
+ logger.warning(
948
+ "Both voice_clone_prompt and ref_text/ref_audio are provided. "
949
+ "ref_text/ref_audio will be ignored."
950
+ )
951
+ if voice_clone_prompt is None and ref_audio is not None:
952
+ # If voice_clone_prompt is not provided, create it from
953
+ # ref_audio (ref_text will be auto-transcribed if not given).
954
+ ref_text_list = self._ensure_list(ref_text, batch_size, auto_repeat=False)
955
+ ref_audio_list = self._ensure_list(ref_audio, batch_size, auto_repeat=False)
956
+
957
+ voice_clone_prompt = []
958
+ for i in range(len(ref_text_list)):
959
+ if ref_text_list[i] is not None:
960
+ lang_idx = i if i < len(language_list) else 0
961
+ ref_text_list[i] = normalize_for_tts(
962
+ ref_text_list[i], language=language_list[lang_idx]
963
+ )
964
+ voice_clone_prompt.append(
965
+ self.create_voice_clone_prompt(
966
+ ref_audio=ref_audio_list[i],
967
+ ref_text=ref_text_list[i],
968
+ preprocess_prompt=preprocess_prompt,
969
+ )
970
+ )
971
+
972
+ voice_clone_prompt_list = self._ensure_list(voice_clone_prompt, batch_size)
973
+ if voice_clone_prompt_list[0] is not None:
974
+ ref_text_list = [vc.ref_text for vc in voice_clone_prompt_list]
975
+ ref_audio_tokens_list = [
976
+ vc.ref_audio_tokens for vc in voice_clone_prompt_list
977
+ ]
978
+ ref_rms_list = [vc.ref_rms for vc in voice_clone_prompt_list]
979
+ else:
980
+ ref_text_list = [None] * batch_size
981
+ ref_audio_tokens_list = [None] * batch_size
982
+ ref_rms_list = [None] * batch_size
983
+
984
+ # Normalize speed/duration to per-item lists (may contain None).
985
+ if speed is not None:
986
+ if isinstance(speed, (int, float)):
987
+ user_speed = [float(speed)] * batch_size
988
+ else:
989
+ user_speed = list(speed)
990
+ else:
991
+ user_speed = None
992
+
993
+ if duration is not None:
994
+ if isinstance(duration, (int, float)):
995
+ durations = [float(duration)] * batch_size
996
+ else:
997
+ durations = list(duration)
998
+ else:
999
+ durations = None
1000
+
1001
+ num_target_tokens_list = []
1002
+ for i in range(batch_size):
1003
+ # duration[i] overrides speed for estimation: use speed=1.0
1004
+ # to get the raw estimate, then override target_lens below.
1005
+ has_dur = durations is not None and durations[i] is not None
1006
+ item_speed = 1.0 if has_dur else (user_speed[i] if user_speed else 1.0)
1007
+ est = self._estimate_target_tokens(
1008
+ text_list[i],
1009
+ ref_text_list[i],
1010
+ ref_audio_tokens_list[i].size(-1)
1011
+ if ref_audio_tokens_list[i] is not None
1012
+ else None,
1013
+ speed=item_speed,
1014
+ )
1015
+ num_target_tokens_list.append(est)
1016
+
1017
+ # Per-item duration overrides: set target_lens to exact frame count
1018
+ # and compute speed ratio so chunked generation scales proportionally.
1019
+ speed_list: Optional[List[float]] = None
1020
+ if durations is not None:
1021
+ frame_rate = self.audio_tokenizer.config.frame_rate
1022
+ speed_list = []
1023
+ for i in range(batch_size):
1024
+ if durations[i] is not None:
1025
+ target_tokens = max(1, int(durations[i] * frame_rate))
1026
+ est = num_target_tokens_list[i]
1027
+ speed_list.append(est / target_tokens if target_tokens > 0 else 1.0)
1028
+ num_target_tokens_list[i] = target_tokens
1029
+ else:
1030
+ s = user_speed[i] if user_speed else None
1031
+ speed_list.append(s if s is not None else 1.0)
1032
+ elif user_speed is not None:
1033
+ speed_list = [s if s is not None else 1.0 for s in user_speed]
1034
+
1035
+ return GenerationTask(
1036
+ batch_size=batch_size,
1037
+ texts=text_list,
1038
+ target_lens=num_target_tokens_list,
1039
+ langs=language_list,
1040
+ instructs=instruct_list,
1041
+ ref_texts=ref_text_list,
1042
+ ref_audio_tokens=ref_audio_tokens_list,
1043
+ ref_rms=ref_rms_list,
1044
+ speed=speed_list,
1045
+ )
1046
+
1047
+ def _estimate_target_tokens(self, text, ref_text, num_ref_audio_tokens, speed=1.0):
1048
+ """Estimate number of target audio tokens."""
1049
+ if num_ref_audio_tokens is None or ref_text is None or len(ref_text) == 0:
1050
+ # Fall back to a simple heuristic
1051
+ ref_text = "Nice to meet you."
1052
+ num_ref_audio_tokens = 25
1053
+
1054
+ est = self.duration_estimator.estimate_duration(
1055
+ text, ref_text, num_ref_audio_tokens
1056
+ )
1057
+ if speed > 0 and speed != 1.0:
1058
+ est = est / speed
1059
+ return max(1, int(est))
1060
+
1061
+ def _ensure_list(
1062
+ self, x: Union[Any, List[Any]], batch_size: int, auto_repeat: bool = True
1063
+ ) -> List[Any]:
1064
+ x_list = x if isinstance(x, list) else [x]
1065
+ if len(x_list) not in (
1066
+ 1,
1067
+ batch_size,
1068
+ ):
1069
+ raise ValueError(
1070
+ f"should be either the number of the text or 1, but got {len(x_list)}"
1071
+ )
1072
+ if auto_repeat and len(x_list) == 1 and batch_size is not None:
1073
+ x_list = x_list * batch_size
1074
+ return x_list
1075
+
1076
+ def _prepare_inference_inputs(
1077
+ self,
1078
+ text: str,
1079
+ num_target_tokens: int,
1080
+ ref_text: Optional[str] = None,
1081
+ ref_audio_tokens: Optional[torch.Tensor] = None,
1082
+ lang: Optional[str] = None,
1083
+ instruct: Optional[str] = None,
1084
+ denoise: bool = True,
1085
+ ):
1086
+ """Prepare input_ids and audio masks for inference.
1087
+ Args:
1088
+ text: Target text to generate.
1089
+ num_target_tokens: Number of audio tokens to generate.
1090
+ ref_text: Optional reference text for voice cloning.
1091
+ ref_audio_tokens: Optional reference audio tokens for voice cloning.
1092
+ with shape (C, T).
1093
+ lang: Optional language ID.
1094
+ instruct: Optional style instruction for voice design.
1095
+ denoise: Whether to include the <|denoise|> token.
1096
+ """
1097
+
1098
+ # Build style tokens: <|denoise|> + <|lang_start|>...<|lang_end|>
1099
+ # + <|instruct_start|>...<|instruct_end|>
1100
+ style_text = ""
1101
+ if denoise and ref_audio_tokens is not None:
1102
+ style_text += "<|denoise|>"
1103
+ lang_str = lang if lang else "None"
1104
+ instruct_str = instruct if instruct else "None"
1105
+ style_text += f"<|lang_start|>{lang_str}<|lang_end|>"
1106
+ style_text += f"<|instruct_start|>{instruct_str}<|instruct_end|>"
1107
+
1108
+ style_tokens = (
1109
+ self.text_tokenizer(style_text, return_tensors="pt")
1110
+ .input_ids.repeat(self.config.num_audio_codebook, 1)
1111
+ .unsqueeze(0)
1112
+ ).to(
1113
+ self.device
1114
+ ) # [1, C, N1]
1115
+
1116
+ # Build text tokens
1117
+ full_text = _combine_text(ref_text=ref_text, text=text)
1118
+ wrapped_text = f"<|text_start|>{full_text}<|text_end|>"
1119
+ text_tokens = (
1120
+ _tokenize_with_nonverbal_tags(wrapped_text, self.text_tokenizer)
1121
+ .repeat(self.config.num_audio_codebook, 1)
1122
+ .unsqueeze(0)
1123
+ ).to(
1124
+ self.device
1125
+ ) # [1, C, N2]
1126
+
1127
+ # Target: all MASK
1128
+ target_audio_tokens = torch.full(
1129
+ (1, self.config.num_audio_codebook, num_target_tokens),
1130
+ self.config.audio_mask_id,
1131
+ dtype=torch.long,
1132
+ device=self.device,
1133
+ )
1134
+
1135
+ # Conditional input
1136
+ parts = [style_tokens, text_tokens]
1137
+ if ref_audio_tokens is not None:
1138
+ parts.append(ref_audio_tokens.unsqueeze(0).to(self.device))
1139
+ parts.append(target_audio_tokens)
1140
+ cond_input_ids = torch.cat(parts, dim=2)
1141
+
1142
+ cond_total_length = cond_input_ids.shape[2]
1143
+ cond_audio_start_idx = cond_total_length - num_target_tokens
1144
+ if ref_audio_tokens is not None:
1145
+ cond_audio_start_idx -= ref_audio_tokens.size(-1)
1146
+
1147
+ cond_audio_mask = torch.zeros(
1148
+ 1, cond_total_length, dtype=torch.bool, device=self.device
1149
+ )
1150
+ cond_audio_mask[0, cond_audio_start_idx:] = True
1151
+
1152
+ return {
1153
+ "input_ids": cond_input_ids,
1154
+ "audio_mask": cond_audio_mask,
1155
+ }
1156
+
1157
+ def _generate_iterative(
1158
+ self, task: GenerationTask, gen_config: OmniVoiceGenerationConfig
1159
+ ) -> List[torch.Tensor]:
1160
+ """N-step iterative unmasked decoding.
1161
+
1162
+ Args:
1163
+ task: A :class:`GenerationTask` containing batch texts, target
1164
+ lengths, languages, instructions, and optional reference data.
1165
+ gen_config: A :class:`OmniVoiceGenerationConfig` controlling
1166
+ decoding steps, guidance, temperatures, etc.
1167
+ Returns:
1168
+ List of generated audio token tensors of shape (C, T) (one per
1169
+ input text).
1170
+ """
1171
+
1172
+ B = task.batch_size
1173
+
1174
+ for i in range(B):
1175
+ logger.debug(
1176
+ "Item %d — text: %s | ref_text: %s | instruct: %s | lang: %s | target_tokens: %d",
1177
+ i,
1178
+ task.texts[i],
1179
+ task.ref_texts[i],
1180
+ task.instructs[i],
1181
+ task.langs[i],
1182
+ task.target_lens[i],
1183
+ )
1184
+
1185
+ inputs_list = [
1186
+ self._prepare_inference_inputs(
1187
+ task.texts[i],
1188
+ task.target_lens[i],
1189
+ task.ref_texts[i],
1190
+ task.ref_audio_tokens[i],
1191
+ task.langs[i],
1192
+ task.instructs[i],
1193
+ gen_config.denoise,
1194
+ )
1195
+ for i in range(B)
1196
+ ]
1197
+
1198
+ c_lens = [inp["input_ids"].size(2) for inp in inputs_list]
1199
+ max_c_len = max(c_lens)
1200
+ pad_id = self.config.audio_mask_id # Or any other tokens
1201
+
1202
+ batch_input_ids = torch.full(
1203
+ (2 * B, self.config.num_audio_codebook, max_c_len),
1204
+ pad_id,
1205
+ dtype=torch.long,
1206
+ device=self.device,
1207
+ )
1208
+ batch_audio_mask = torch.zeros(
1209
+ (2 * B, max_c_len), dtype=torch.bool, device=self.device
1210
+ )
1211
+ batch_attention_mask = torch.zeros(
1212
+ (2 * B, 1, max_c_len, max_c_len), dtype=torch.bool, device=self.device
1213
+ )
1214
+
1215
+ for i, inp in enumerate(inputs_list):
1216
+ c_len, u_len = c_lens[i], task.target_lens[i]
1217
+
1218
+ # Cond (0 ~ B-1)
1219
+ batch_input_ids[i, :, :c_len] = inp["input_ids"]
1220
+ batch_audio_mask[i, :c_len] = inp["audio_mask"]
1221
+ batch_attention_mask[i, :, :c_len, :c_len] = True
1222
+
1223
+ # Uncond (B ~ 2B-1)
1224
+ batch_input_ids[B + i, :, :u_len] = inp["input_ids"][..., -u_len:]
1225
+ batch_audio_mask[B + i, :u_len] = inp["audio_mask"][..., -u_len:]
1226
+ batch_attention_mask[B + i, :, :u_len, :u_len] = True
1227
+ if max_c_len > u_len:
1228
+ pad_diag = torch.arange(u_len, max_c_len, device=self.device)
1229
+ batch_attention_mask[B + i, :, pad_diag, pad_diag] = True
1230
+
1231
+ tokens = torch.full(
1232
+ (B, self.config.num_audio_codebook, max(task.target_lens)),
1233
+ self.config.audio_mask_id,
1234
+ dtype=torch.long,
1235
+ device=self.device,
1236
+ )
1237
+
1238
+ timesteps = _get_time_steps(
1239
+ t_start=0.0,
1240
+ t_end=1.0,
1241
+ num_step=gen_config.num_step,
1242
+ t_shift=gen_config.t_shift,
1243
+ ).tolist()
1244
+ schedules = []
1245
+ for t_len in task.target_lens:
1246
+ total_mask = t_len * self.config.num_audio_codebook
1247
+ rem = total_mask
1248
+ sched = []
1249
+ for step in range(gen_config.num_step):
1250
+ num = (
1251
+ rem
1252
+ if step == gen_config.num_step - 1
1253
+ else min(
1254
+ math.ceil(total_mask * (timesteps[step + 1] - timesteps[step])),
1255
+ rem,
1256
+ )
1257
+ )
1258
+ sched.append(int(num))
1259
+ rem -= int(num)
1260
+ schedules.append(sched)
1261
+
1262
+ layer_ids = torch.arange(
1263
+ self.config.num_audio_codebook, device=self.device
1264
+ ).view(1, -1, 1)
1265
+
1266
+ for step in range(gen_config.num_step):
1267
+ batch_logits = self(
1268
+ input_ids=batch_input_ids,
1269
+ audio_mask=batch_audio_mask,
1270
+ attention_mask=batch_attention_mask,
1271
+ ).logits.to(torch.float32)
1272
+
1273
+ for i in range(B):
1274
+ k = schedules[i][step]
1275
+ if k <= 0:
1276
+ continue
1277
+
1278
+ c_len, t_len = c_lens[i], task.target_lens[i]
1279
+
1280
+ # Extract real target Logits
1281
+ # [1, C, T, V]
1282
+ c_logits = batch_logits[i : i + 1, :, c_len - t_len : c_len, :]
1283
+ u_logits = batch_logits[B + i : B + i + 1, :, :t_len, :]
1284
+
1285
+ pred_tokens, scores = self._predict_tokens_with_scoring(
1286
+ c_logits, u_logits, gen_config
1287
+ )
1288
+
1289
+ scores = scores - (layer_ids * gen_config.layer_penalty_factor)
1290
+
1291
+ if gen_config.position_temperature > 0.0:
1292
+ scores = _gumbel_sample(scores, gen_config.position_temperature)
1293
+
1294
+ sample_tokens = tokens[i : i + 1, :, :t_len]
1295
+ scores.masked_fill_(
1296
+ sample_tokens != self.config.audio_mask_id, -float("inf")
1297
+ )
1298
+
1299
+ _, topk_idx = torch.topk(scores.flatten(), k)
1300
+ flat_tokens = sample_tokens.flatten()
1301
+ flat_tokens[topk_idx] = pred_tokens.flatten()[topk_idx]
1302
+ sample_tokens.copy_(flat_tokens.view_as(sample_tokens))
1303
+
1304
+ # Update individual slices into batched structure
1305
+ tokens[i : i + 1, :, :t_len] = sample_tokens
1306
+ batch_input_ids[i : i + 1, :, c_len - t_len : c_len] = sample_tokens
1307
+ batch_input_ids[B + i : B + i + 1, :, :t_len] = sample_tokens
1308
+
1309
+ return [tokens[i, :, : task.target_lens[i]] for i in range(B)]
1310
+
1311
+ def _predict_tokens_with_scoring(self, c_logits, u_logits, gen_config):
1312
+ if gen_config.guidance_scale != 0:
1313
+ c_log_probs = F.log_softmax(c_logits, dim=-1)
1314
+ u_log_probs = F.log_softmax(u_logits, dim=-1)
1315
+ log_probs = torch.log_softmax(
1316
+ c_log_probs + gen_config.guidance_scale * (c_log_probs - u_log_probs),
1317
+ dim=-1,
1318
+ )
1319
+ else:
1320
+ log_probs = F.log_softmax(c_logits, dim=-1)
1321
+
1322
+ log_probs[..., self.config.audio_mask_id] = -float("inf")
1323
+
1324
+ if gen_config.class_temperature > 0.0:
1325
+ filtered_probs = _filter_top_k(log_probs, ratio=0.1)
1326
+ pred_tokens = _gumbel_sample(
1327
+ filtered_probs, gen_config.class_temperature
1328
+ ).argmax(dim=-1)
1329
+ else:
1330
+ pred_tokens = log_probs.argmax(dim=-1)
1331
+
1332
+ confidence_scores = log_probs.max(dim=-1)[0]
1333
+
1334
+ return pred_tokens, confidence_scores
1335
+
1336
+
1337
+ # ---------------------------------------------------------------------------
1338
+ # Standalone helpers
1339
+ # ---------------------------------------------------------------------------
1340
+
1341
+
1342
+ def _get_packed_mask(document_ids):
1343
+ return partial(_mask_mod_packed, document_ids)
1344
+
1345
+
1346
+ def _mask_mod_packed(document_ids, b, h, q_idx, kv_idx):
1347
+ # 1. Sequence Packing Logic: Tokens must belong to the same document.
1348
+ # Note: The doc_id for padding tokens is -1, which will automatically not match
1349
+ # (if handled correctly) or be ignored.
1350
+ same_doc = document_ids[q_idx] == document_ids[kv_idx]
1351
+ return same_doc
1352
+
1353
+
1354
+ def _resolve_language(language: Optional[str]) -> Union[str, None]:
1355
+ from omnivoice.utils.lang_map import LANG_IDS, LANG_NAME_TO_ID
1356
+
1357
+ if language is None or language.lower() == "none":
1358
+ return None
1359
+ if language in LANG_IDS:
1360
+ return language
1361
+ key = language.lower()
1362
+ if key in LANG_NAME_TO_ID:
1363
+ return LANG_NAME_TO_ID[key]
1364
+ logger.warning(
1365
+ f"Language '{language}' is not recognized. "
1366
+ f"Please use a valid language ID (e.g., 'en', 'zh', 'ja', 'de') "
1367
+ f"or a full language name (e.g., 'English', 'Chinese', 'Japanese'). "
1368
+ f"See supported_language_ids() or supported_language_names() for details. "
1369
+ f"Falling back to None (language-agnostic mode)."
1370
+ )
1371
+ return None
1372
+
1373
+
1374
+ def _resolve_instruct(
1375
+ instruct: Optional[str], use_zh: bool = False
1376
+ ) -> Union[str, None]:
1377
+ """Validate and normalise a voice-design instruct string.
1378
+
1379
+ Supported instruct items (case-insensitive for English):
1380
+
1381
+ English (comma + space separated):
1382
+ gender: male, female
1383
+ age: child, teenager, young adult, middle-aged, elderly
1384
+ pitch: very low pitch, low pitch, moderate pitch,
1385
+ high pitch, very high pitch
1386
+ style: whisper
1387
+ accent: american accent, british accent, australian accent, ...
1388
+
1389
+ Chinese (full-width comma separated):
1390
+ gender: 男, 女
1391
+ age: 儿童, 少年, 青年, 中年, 老年
1392
+ pitch: 极低音调, 低音调, 中音调, 高音调, 极高音调
1393
+ style: 耳语
1394
+ dialect: 河南话, 陕西话, 四川话, 贵州话, 云南话,
1395
+ 桂林话, 济南话, 石家庄话, 甘肃话, 宁夏话,
1396
+ 青岛话, 东北话
1397
+
1398
+ Minor issues (auto-fixed):
1399
+ - Wrong separator (half-width comma in Chinese instruct or
1400
+ full-width comma in English instruct)
1401
+ - Leading / trailing commas
1402
+
1403
+ Major issues (raise ``ValueError``):
1404
+ - Unsupported or misspelled instruct items
1405
+ - Suggestions are offered for close matches
1406
+
1407
+ Args:
1408
+ instruct: Raw instruct string, or ``None``.
1409
+ use_zh: If True, normalise all items to Chinese (used when the
1410
+ synthesis text contains Chinese and no accent is specified).
1411
+
1412
+ Returns:
1413
+ Normalised instruct string, or ``None``.
1414
+
1415
+ Raises:
1416
+ ValueError: if any instruct item is unsupported or misspelled.
1417
+ """
1418
+ if instruct is None:
1419
+ return None
1420
+
1421
+ instruct_str = instruct.strip()
1422
+ if not instruct_str:
1423
+ return None
1424
+
1425
+ # Split on both half-width and full-width commas
1426
+ raw_items = re.split(r"\s*[,,]\s*", instruct_str)
1427
+ raw_items = [x for x in raw_items if x]
1428
+
1429
+ # Validate each item
1430
+ unknown = []
1431
+ normalised = []
1432
+ for raw in raw_items:
1433
+ n = raw.strip().lower()
1434
+ if n in _INSTRUCT_ALL_VALID:
1435
+ normalised.append(n)
1436
+ else:
1437
+ sug = difflib.get_close_matches(n, _INSTRUCT_ALL_VALID, n=1, cutoff=0.6)
1438
+ unknown.append((raw, n, sug[0] if sug else None))
1439
+
1440
+ if unknown:
1441
+ lines = []
1442
+ for raw, n, sug in unknown:
1443
+ if sug:
1444
+ lines.append(f" '{raw}' -> '{n}' (unsupported; did you mean '{sug}'?)")
1445
+ else:
1446
+ lines.append(f" '{raw}' -> '{n}' (unsupported)")
1447
+ err = (
1448
+ f"Unsupported instruct items found in {instruct_str}:\n"
1449
+ + "\n".join(lines)
1450
+ + "\n\nValid English items: "
1451
+ + ", ".join(sorted(_INSTRUCT_VALID_EN))
1452
+ + "\nValid Chinese items: "
1453
+ + ",".join(sorted(_INSTRUCT_VALID_ZH))
1454
+ + "\n\nTip: Use only English or only Chinese instructs. "
1455
+ "English instructs should use comma + space (e.g. "
1456
+ "'male, indian accent'),\nChinese instructs should use full-width "
1457
+ "comma (e.g. '男,河南话')."
1458
+ )
1459
+ raise ValueError(err)
1460
+
1461
+ # --- Language consistency: dialect forces Chinese, accent forces English ---
1462
+ has_dialect = any(n.endswith("话") for n in normalised)
1463
+ has_accent = any(" accent" in n for n in normalised)
1464
+
1465
+ if has_dialect and has_accent:
1466
+ raise ValueError(
1467
+ "Cannot mix Chinese dialect and English accent in a single instruct. "
1468
+ "Dialects are for Chinese speech, accents for English speech."
1469
+ )
1470
+
1471
+ if has_dialect:
1472
+ use_zh = True
1473
+ elif has_accent:
1474
+ use_zh = False
1475
+
1476
+ # --- Unify to single language ---
1477
+ if use_zh:
1478
+ normalised = [_INSTRUCT_EN_TO_ZH.get(n, n) for n in normalised]
1479
+ else:
1480
+ normalised = [_INSTRUCT_ZH_TO_EN.get(n, n) for n in normalised]
1481
+
1482
+ # --- Category conflict check ---
1483
+ conflicts = []
1484
+ for cat in _INSTRUCT_MUTUALLY_EXCLUSIVE:
1485
+ hits = [n for n in normalised if n in cat]
1486
+ if len(hits) > 1:
1487
+ conflicts.append(hits)
1488
+ if conflicts:
1489
+ parts = []
1490
+ for group in conflicts:
1491
+ parts.append(" vs ".join(f"'{x}'" for x in group))
1492
+ raise ValueError(
1493
+ "Conflicting instruct items within the same category: "
1494
+ + "; ".join(parts)
1495
+ + ". Each category (gender, age, pitch, style, accent, dialect) "
1496
+ "allows at most one item."
1497
+ )
1498
+
1499
+ # Determine separator based on language
1500
+ has_zh = any(any("\u4e00" <= c <= "\u9fff" for c in n) for n in normalised)
1501
+ separator = "," if has_zh else ", "
1502
+
1503
+ return separator.join(normalised)
1504
+
1505
+
1506
+ def _filter_top_k(logits: torch.Tensor, ratio: float = 0.1) -> torch.Tensor:
1507
+ k = math.ceil(ratio * logits.shape[-1])
1508
+ val, ind = logits.topk(k, dim=-1)
1509
+ probs = torch.full_like(logits, float("-inf"))
1510
+ probs.scatter_(-1, ind, val)
1511
+ return probs
1512
+
1513
+
1514
+ def _gumbel_sample(logits: torch.Tensor, temperature: float) -> torch.Tensor:
1515
+ scaled_logits = logits / temperature
1516
+ u = torch.rand_like(scaled_logits)
1517
+ gumbel_noise = -torch.log(-torch.log(u + 1e-10) + 1e-10)
1518
+ return scaled_logits + gumbel_noise
1519
+
1520
+
1521
+ def _get_time_steps(
1522
+ t_start: float = 0.0,
1523
+ t_end: float = 1.0,
1524
+ num_step: int = 10,
1525
+ t_shift: float = 1.0,
1526
+ device: torch.device = torch.device("cpu"),
1527
+ ) -> torch.Tensor:
1528
+ timesteps = torch.linspace(t_start, t_end, num_step + 1).to(device)
1529
+ timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps)
1530
+ return timesteps
1531
+
1532
+
1533
+ _NONVERBAL_PATTERN = re.compile(
1534
+ r"\[(laughter|sigh|confirmation-en|question-en|question-ah|question-oh|"
1535
+ r"question-ei|question-yi|surprise-ah|surprise-oh|surprise-wa|"
1536
+ r"surprise-yo|dissatisfaction-hnn)\]"
1537
+ )
1538
+
1539
+
1540
+ def _tokenize_with_nonverbal_tags(text: str, tokenizer) -> torch.Tensor:
1541
+ """Tokenize text containing non-verbal tags, handling each tag independently.
1542
+
1543
+ Non-verbal tags are tokenized standalone to guarantee consistent token
1544
+ IDs regardless of surrounding language context (Chinese, English, etc.).
1545
+
1546
+ Args:
1547
+ text: Full text string potentially containing non-verbal tags.
1548
+ tokenizer: HuggingFace text tokenizer instance.
1549
+ Returns:
1550
+ Token IDs tensor of shape (1, seq_len).
1551
+ """
1552
+ parts = []
1553
+ last_end = 0
1554
+ for m in _NONVERBAL_PATTERN.finditer(text):
1555
+ if m.start() > last_end:
1556
+ segment = text[last_end : m.start()]
1557
+ ids = tokenizer(segment, add_special_tokens=False).input_ids
1558
+ if ids:
1559
+ parts.append(ids)
1560
+ tag_ids = tokenizer(m.group(), add_special_tokens=False).input_ids
1561
+ if tag_ids:
1562
+ parts.append(tag_ids)
1563
+ last_end = m.end()
1564
+ if last_end < len(text):
1565
+ segment = text[last_end:]
1566
+ ids = tokenizer(segment, add_special_tokens=False).input_ids
1567
+ if ids:
1568
+ parts.append(ids)
1569
+
1570
+ if not parts:
1571
+ result = tokenizer(text, return_tensors="pt").input_ids
1572
+ else:
1573
+ combined = []
1574
+ for p in parts:
1575
+ combined.extend(p)
1576
+ result = torch.tensor([combined], dtype=torch.long)
1577
+ return result
1578
+
1579
+
1580
+ def _combine_text(text, ref_text: Optional[str] = None) -> str:
1581
+
1582
+ # combine with reference text if not None
1583
+ if ref_text:
1584
+ full_text = ref_text.strip() + " " + text.strip()
1585
+ else:
1586
+ full_text = text.strip()
1587
+
1588
+ # filter out newline / carriage-return characters
1589
+ full_text = re.sub(r"[\r\n]+", "", full_text)
1590
+
1591
+ # replace Chinese parentheses with English ones
1592
+ full_text = full_text.replace("\uff08", "(").replace("\uff09", ")")
1593
+
1594
+ # collapse consecutive spaces / tabs into a single space
1595
+ full_text = re.sub(r"[ \t]+", " ", full_text)
1596
+
1597
+ # remove spaces around chinese characters
1598
+ chinese_range = r"[\u4e00-\u9fff]"
1599
+ pattern = rf"(?<={chinese_range})\s+|\s+(?={chinese_range})"
1600
+ full_text = re.sub(pattern, "", full_text)
1601
+
1602
+ return full_text
1603
+
1604
+
1605
+ # ---------------------------------------------------------------------------
1606
+ # Register with HuggingFace Auto classes
1607
+ # ---------------------------------------------------------------------------
1608
+
1609
+ AutoConfig.register("omnivoice", OmniVoiceConfig)
1610
+ AutoModel.register(OmniVoiceConfig, OmniVoice)
runtime/omnivoice/server/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ """HTTP API server helpers for OmniVoice."""
2
+
runtime/omnivoice/server/app.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import base64
5
+ import binascii
6
+ import io
7
+ import logging
8
+ import os
9
+ import threading
10
+ from contextlib import asynccontextmanager
11
+ from dataclasses import dataclass
12
+ from importlib import import_module
13
+ from typing import Any, Literal, Protocol
14
+
15
+ from fastapi import FastAPI, HTTPException, Response
16
+ from fastapi.middleware.cors import CORSMiddleware
17
+ from pydantic import BaseModel, Field
18
+
19
+ from omnivoice import __version__
20
+ from omnivoice.utils.lang_map import LANG_NAME_TO_ID
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ DEFAULT_MODEL_PATH = "/app/model"
25
+ DEFAULT_HOST = "0.0.0.0"
26
+ DEFAULT_PORT = 8000
27
+
28
+
29
+ def _parse_bool(value: str | None, default: bool = False) -> bool:
30
+ if value is None:
31
+ return default
32
+ return value.strip().lower() in {"1", "true", "yes", "on"}
33
+
34
+
35
+ def _parse_origins(value: str | None) -> list[str]:
36
+ if not value or value.strip() == "*":
37
+ return ["*"]
38
+ return [item.strip() for item in value.split(",") if item.strip()]
39
+
40
+
41
+ def _clean_base64_audio(value: str) -> bytes:
42
+ payload = value.strip()
43
+ if "," in payload and payload.split(",", 1)[0].startswith("data:"):
44
+ payload = payload.split(",", 1)[1]
45
+ try:
46
+ return base64.b64decode(payload, validate=True)
47
+ except binascii.Error as exc:
48
+ raise ValueError("reference audio must be valid base64 or data URI") from exc
49
+
50
+
51
+ @dataclass(slots=True)
52
+ class ServerSettings:
53
+ model: str = DEFAULT_MODEL_PATH
54
+ device: str = "auto"
55
+ dtype: str = "auto"
56
+ host: str = DEFAULT_HOST
57
+ port: int = DEFAULT_PORT
58
+ log_level: str = "info"
59
+ cors_origins: list[str] | None = None
60
+ preload_model: bool = False
61
+ load_asr: bool = False
62
+ asr_model_name: str = "openai/whisper-large-v3-turbo"
63
+
64
+ @classmethod
65
+ def from_env(cls) -> "ServerSettings":
66
+ port = int(os.getenv("OMNIVOICE_PORT", str(DEFAULT_PORT)))
67
+ return cls(
68
+ model=os.getenv("OMNIVOICE_MODEL", DEFAULT_MODEL_PATH),
69
+ device=os.getenv("OMNIVOICE_DEVICE", "auto"),
70
+ dtype=os.getenv("OMNIVOICE_DTYPE", "auto"),
71
+ host=os.getenv("OMNIVOICE_HOST", DEFAULT_HOST),
72
+ port=port,
73
+ log_level=os.getenv("OMNIVOICE_LOG_LEVEL", "info"),
74
+ cors_origins=_parse_origins(os.getenv("OMNIVOICE_CORS_ORIGINS", "*")),
75
+ preload_model=_parse_bool(os.getenv("OMNIVOICE_PRELOAD", "0")),
76
+ load_asr=_parse_bool(os.getenv("OMNIVOICE_LOAD_ASR", "0")),
77
+ asr_model_name=os.getenv(
78
+ "OMNIVOICE_ASR_MODEL",
79
+ "openai/whisper-large-v3-turbo",
80
+ ),
81
+ )
82
+
83
+
84
+ class SynthesizeRequest(BaseModel):
85
+ text: str = Field(..., min_length=1, description="Text to synthesize.")
86
+ language: str | None = Field(
87
+ default=None,
88
+ description="Optional language code or language name.",
89
+ )
90
+ instruct: str | None = Field(
91
+ default=None,
92
+ description="Optional voice-design instruction.",
93
+ )
94
+ ref_text: str | None = Field(
95
+ default=None,
96
+ description="Transcript for the reference audio used in voice cloning.",
97
+ )
98
+ ref_audio_base64: str | None = Field(
99
+ default=None,
100
+ description="Optional reference audio encoded as base64 or data URI.",
101
+ )
102
+ num_step: int = Field(default=32, ge=1, le=128)
103
+ guidance_scale: float = Field(default=2.0, ge=0.0, le=20.0)
104
+ speed: float = Field(default=1.0, gt=0.0, le=4.0)
105
+ duration: float | None = Field(default=None, gt=0.0)
106
+ denoise: bool = True
107
+ preprocess_prompt: bool = True
108
+ postprocess_output: bool = True
109
+
110
+
111
+ class RuntimeStatus(BaseModel):
112
+ model_path: str
113
+ model_loaded: bool
114
+ device_preference: str
115
+ device_resolved: str | None = None
116
+ dtype_preference: str
117
+ dtype_resolved: str | None = None
118
+ load_asr: bool
119
+ last_error: str | None = None
120
+
121
+
122
+ class SynthesizeResponse(BaseModel):
123
+ audio_base64: str
124
+ sample_rate: int
125
+ duration_seconds: float
126
+ device: str
127
+ model_path: str
128
+
129
+
130
+ @dataclass(slots=True)
131
+ class SynthesisResult:
132
+ audio_bytes: bytes
133
+ sample_rate: int
134
+ duration_seconds: float
135
+ device: str
136
+ model_path: str
137
+
138
+
139
+ class RuntimeLike(Protocol):
140
+ settings: ServerSettings
141
+
142
+ def get_status(self) -> RuntimeStatus:
143
+ ...
144
+
145
+ def list_languages(self) -> list[dict[str, str]]:
146
+ ...
147
+
148
+ def synthesize(self, request: SynthesizeRequest) -> SynthesisResult:
149
+ ...
150
+
151
+ def maybe_preload(self) -> None:
152
+ ...
153
+
154
+
155
+ class OmniVoiceRuntime:
156
+ def __init__(self, settings: ServerSettings):
157
+ self.settings = settings
158
+ self._model: Any | None = None
159
+ self._torch: Any | None = None
160
+ self._config_cls: Any | None = None
161
+ self._device: str | None = None
162
+ self._dtype_name: str | None = None
163
+ self._last_error: str | None = None
164
+ self._load_lock = threading.Lock()
165
+ self._infer_lock = threading.Lock()
166
+
167
+ def get_status(self) -> RuntimeStatus:
168
+ return RuntimeStatus(
169
+ model_path=self.settings.model,
170
+ model_loaded=self._model is not None,
171
+ device_preference=self.settings.device,
172
+ device_resolved=self._device,
173
+ dtype_preference=self.settings.dtype,
174
+ dtype_resolved=self._dtype_name,
175
+ load_asr=self.settings.load_asr,
176
+ last_error=self._last_error,
177
+ )
178
+
179
+ def maybe_preload(self) -> None:
180
+ if self.settings.preload_model:
181
+ self._ensure_loaded()
182
+
183
+ def list_languages(self) -> list[dict[str, str]]:
184
+ languages = [
185
+ {"id": code, "name": name.title()}
186
+ for name, code in LANG_NAME_TO_ID.items()
187
+ ]
188
+ languages.sort(key=lambda item: (item["name"], item["id"]))
189
+ return languages
190
+
191
+ def synthesize(self, request: SynthesizeRequest) -> SynthesisResult:
192
+ if not request.text.strip():
193
+ raise ValueError("text must not be blank")
194
+ self._ensure_loaded()
195
+ assert self._model is not None
196
+ assert self._torch is not None
197
+ assert self._config_cls is not None
198
+
199
+ prompt = None
200
+ if request.ref_audio_base64:
201
+ prompt = self._build_voice_clone_prompt(
202
+ audio_base64=request.ref_audio_base64,
203
+ ref_text=request.ref_text,
204
+ preprocess_prompt=request.preprocess_prompt,
205
+ )
206
+ elif request.ref_text:
207
+ raise ValueError("ref_text requires ref_audio_base64 as well")
208
+
209
+ generation_config = self._config_cls(
210
+ num_step=request.num_step,
211
+ guidance_scale=request.guidance_scale,
212
+ denoise=request.denoise,
213
+ preprocess_prompt=request.preprocess_prompt,
214
+ postprocess_output=request.postprocess_output,
215
+ )
216
+
217
+ with self._infer_lock:
218
+ audios = self._model.generate(
219
+ text=request.text,
220
+ language=request.language,
221
+ voice_clone_prompt=prompt,
222
+ instruct=request.instruct,
223
+ duration=request.duration,
224
+ speed=request.speed,
225
+ generation_config=generation_config,
226
+ )
227
+
228
+ audio = audios[0]
229
+ wav_bytes = self._encode_wav(audio, self._model.sampling_rate)
230
+ duration_seconds = float(len(audio) / self._model.sampling_rate)
231
+ return SynthesisResult(
232
+ audio_bytes=wav_bytes,
233
+ sample_rate=int(self._model.sampling_rate),
234
+ duration_seconds=duration_seconds,
235
+ device=self._device or self.settings.device,
236
+ model_path=self.settings.model,
237
+ )
238
+
239
+ def _ensure_loaded(self) -> None:
240
+ if self._model is not None:
241
+ return
242
+ with self._load_lock:
243
+ if self._model is not None:
244
+ return
245
+ try:
246
+ torch_module = import_module("torch")
247
+ omnivoice_module = import_module("omnivoice")
248
+ model_cls = getattr(omnivoice_module, "OmniVoice")
249
+ config_cls = getattr(omnivoice_module, "OmniVoiceGenerationConfig")
250
+ device = self._resolve_device(torch_module)
251
+ dtype_name, dtype_value = self._resolve_dtype(torch_module, device)
252
+ logger.info(
253
+ "Loading OmniVoice model from %s on %s (%s)",
254
+ self.settings.model,
255
+ device,
256
+ dtype_name,
257
+ )
258
+ model = model_cls.from_pretrained(
259
+ self.settings.model,
260
+ device_map=device,
261
+ dtype=dtype_value,
262
+ load_asr=self.settings.load_asr,
263
+ asr_model_name=self.settings.asr_model_name,
264
+ )
265
+ except Exception as exc:
266
+ self._last_error = f"{type(exc).__name__}: {exc}"
267
+ raise
268
+
269
+ self._torch = torch_module
270
+ self._config_cls = config_cls
271
+ self._model = model
272
+ self._device = device
273
+ self._dtype_name = dtype_name
274
+ self._last_error = None
275
+
276
+ def _resolve_device(self, torch_module: Any) -> str:
277
+ choice = self.settings.device.strip().lower()
278
+ if choice == "auto":
279
+ if torch_module.cuda.is_available():
280
+ return "cuda"
281
+ mps_backend = getattr(getattr(torch_module, "backends", None), "mps", None)
282
+ if mps_backend is not None and mps_backend.is_available():
283
+ return "mps"
284
+ return "cpu"
285
+ if choice == "cuda":
286
+ if not torch_module.cuda.is_available():
287
+ raise RuntimeError("OMNIVOICE_DEVICE=cuda was requested but CUDA is unavailable")
288
+ return "cuda"
289
+ if choice == "mps":
290
+ mps_backend = getattr(getattr(torch_module, "backends", None), "mps", None)
291
+ if mps_backend is None or not mps_backend.is_available():
292
+ raise RuntimeError("OMNIVOICE_DEVICE=mps was requested but MPS is unavailable")
293
+ return "mps"
294
+ if choice == "cpu":
295
+ return "cpu"
296
+ raise RuntimeError(f"Unsupported device choice: {self.settings.device}")
297
+
298
+ def _resolve_dtype(self, torch_module: Any, device: str) -> tuple[str, Any]:
299
+ aliases = {
300
+ "fp16": "float16",
301
+ "half": "float16",
302
+ "fp32": "float32",
303
+ "float": "float32",
304
+ "bf16": "bfloat16",
305
+ }
306
+ choice = self.settings.dtype.strip().lower()
307
+ if choice == "auto":
308
+ choice = "float16" if device == "cuda" else "float32"
309
+ choice = aliases.get(choice, choice)
310
+ valid = {"float16", "float32", "bfloat16"}
311
+ if choice not in valid:
312
+ raise RuntimeError(f"Unsupported dtype choice: {self.settings.dtype}")
313
+ return choice, getattr(torch_module, choice)
314
+
315
+ def _build_voice_clone_prompt(
316
+ self,
317
+ audio_base64: str,
318
+ ref_text: str | None,
319
+ preprocess_prompt: bool,
320
+ ) -> Any:
321
+ assert self._model is not None
322
+ assert self._torch is not None
323
+ waveform, sample_rate = self._decode_audio(audio_base64)
324
+ return self._model.create_voice_clone_prompt(
325
+ ref_audio=(self._torch.from_numpy(waveform), sample_rate),
326
+ ref_text=ref_text,
327
+ preprocess_prompt=preprocess_prompt,
328
+ )
329
+
330
+ def _decode_audio(self, audio_base64: str) -> tuple[Any, int]:
331
+ import numpy as np
332
+ import soundfile as sf
333
+
334
+ raw_bytes = _clean_base64_audio(audio_base64)
335
+ audio_buffer = io.BytesIO(raw_bytes)
336
+ waveform, sample_rate = sf.read(audio_buffer, dtype="float32", always_2d=False)
337
+ if waveform.ndim == 2:
338
+ waveform = np.transpose(waveform)
339
+ return waveform, int(sample_rate)
340
+
341
+ def _encode_wav(self, audio: Any, sample_rate: int) -> bytes:
342
+ import soundfile as sf
343
+
344
+ buffer = io.BytesIO()
345
+ sf.write(buffer, audio, sample_rate, format="WAV")
346
+ return buffer.getvalue()
347
+
348
+
349
+ def create_app(
350
+ settings: ServerSettings | None = None,
351
+ runtime: RuntimeLike | None = None,
352
+ ) -> FastAPI:
353
+ if settings is None:
354
+ if runtime is not None and hasattr(runtime, "settings"):
355
+ settings = runtime.settings
356
+ else:
357
+ settings = ServerSettings.from_env()
358
+ runtime = runtime or OmniVoiceRuntime(settings)
359
+
360
+ @asynccontextmanager
361
+ async def lifespan(_: FastAPI):
362
+ runtime.maybe_preload()
363
+ yield
364
+
365
+ app = FastAPI(
366
+ title="AVoice OmniVoice API",
367
+ version=__version__,
368
+ summary="Local HTTP API for OmniVoice speech generation.",
369
+ lifespan=lifespan,
370
+ )
371
+
372
+ if settings.cors_origins:
373
+ app.add_middleware(
374
+ CORSMiddleware,
375
+ allow_origins=settings.cors_origins,
376
+ allow_credentials=True,
377
+ allow_methods=["*"],
378
+ allow_headers=["*"],
379
+ )
380
+
381
+ @app.get("/healthz", response_model=RuntimeStatus)
382
+ def healthz() -> RuntimeStatus:
383
+ return runtime.get_status()
384
+
385
+ @app.get("/v1/runtime", response_model=RuntimeStatus)
386
+ def runtime_status() -> RuntimeStatus:
387
+ return runtime.get_status()
388
+
389
+ @app.get("/v1/languages")
390
+ def languages() -> dict[str, list[dict[str, str]]]:
391
+ return {"languages": runtime.list_languages()}
392
+
393
+ @app.post("/v1/audio/speech")
394
+ def speech(request: SynthesizeRequest) -> Response:
395
+ try:
396
+ result = runtime.synthesize(request)
397
+ except ValueError as exc:
398
+ raise HTTPException(status_code=400, detail=str(exc)) from exc
399
+ except RuntimeError as exc:
400
+ raise HTTPException(status_code=503, detail=str(exc)) from exc
401
+ except Exception as exc:
402
+ logger.exception("speech generation failed")
403
+ raise HTTPException(status_code=500, detail=str(exc)) from exc
404
+
405
+ headers = {
406
+ "X-OmniVoice-Device": result.device,
407
+ "X-OmniVoice-Sample-Rate": str(result.sample_rate),
408
+ "X-OmniVoice-Model": result.model_path,
409
+ }
410
+ return Response(content=result.audio_bytes, media_type="audio/wav", headers=headers)
411
+
412
+ @app.post("/v1/audio/speech/json", response_model=SynthesizeResponse)
413
+ def speech_json(request: SynthesizeRequest) -> SynthesizeResponse:
414
+ try:
415
+ result = runtime.synthesize(request)
416
+ except ValueError as exc:
417
+ raise HTTPException(status_code=400, detail=str(exc)) from exc
418
+ except RuntimeError as exc:
419
+ raise HTTPException(status_code=503, detail=str(exc)) from exc
420
+ except Exception as exc:
421
+ logger.exception("speech generation failed")
422
+ raise HTTPException(status_code=500, detail=str(exc)) from exc
423
+
424
+ return SynthesizeResponse(
425
+ audio_base64=base64.b64encode(result.audio_bytes).decode("ascii"),
426
+ sample_rate=result.sample_rate,
427
+ duration_seconds=result.duration_seconds,
428
+ device=result.device,
429
+ model_path=result.model_path,
430
+ )
431
+
432
+ return app
433
+
434
+
435
+ def build_parser() -> argparse.ArgumentParser:
436
+ env = ServerSettings.from_env()
437
+ parser = argparse.ArgumentParser(
438
+ prog="omnivoice-api",
439
+ description="Serve OmniVoice inference as a local HTTP API.",
440
+ )
441
+ parser.add_argument("--model", default=env.model, help="Local model path or HuggingFace repo id.")
442
+ parser.add_argument(
443
+ "--device",
444
+ default=env.device,
445
+ help="Device selection: auto, cpu, cuda, or mps.",
446
+ )
447
+ parser.add_argument(
448
+ "--dtype",
449
+ default=env.dtype,
450
+ help="Precision selection: auto, float16, float32, or bfloat16.",
451
+ )
452
+ parser.add_argument("--host", default=env.host, help="Bind host.")
453
+ parser.add_argument("--port", type=int, default=env.port, help="Bind port.")
454
+ parser.add_argument("--log-level", default=env.log_level, help="Uvicorn log level.")
455
+ parser.add_argument(
456
+ "--preload-model",
457
+ action="store_true",
458
+ default=env.preload_model,
459
+ help="Load the model during server startup instead of on first request.",
460
+ )
461
+ parser.add_argument(
462
+ "--load-asr",
463
+ action="store_true",
464
+ default=env.load_asr,
465
+ help="Load the Whisper ASR helper at startup for reference-audio transcription.",
466
+ )
467
+ parser.add_argument(
468
+ "--asr-model-name",
469
+ default=env.asr_model_name,
470
+ help="Whisper model to use when ref_text is omitted.",
471
+ )
472
+ parser.add_argument(
473
+ "--cors-origins",
474
+ default=",".join(env.cors_origins or ["*"]),
475
+ help="Comma-separated CORS allowlist or *.",
476
+ )
477
+ return parser
478
+
479
+
480
+ def main() -> None:
481
+ parser = build_parser()
482
+ args = parser.parse_args()
483
+ settings = ServerSettings(
484
+ model=args.model,
485
+ device=args.device,
486
+ dtype=args.dtype,
487
+ host=args.host,
488
+ port=args.port,
489
+ log_level=args.log_level,
490
+ cors_origins=_parse_origins(args.cors_origins),
491
+ preload_model=args.preload_model,
492
+ load_asr=args.load_asr,
493
+ asr_model_name=args.asr_model_name,
494
+ )
495
+
496
+ import uvicorn
497
+
498
+ uvicorn.run(
499
+ create_app(settings=settings),
500
+ host=settings.host,
501
+ port=settings.port,
502
+ log_level=settings.log_level,
503
+ )
504
+
505
+
506
+ app = create_app()
runtime/omnivoice/server/prefetch.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import logging
5
+ import os
6
+ import shutil
7
+ from pathlib import Path
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ def _resolve_path(name_or_path: str) -> str:
13
+ if os.path.isdir(name_or_path):
14
+ return name_or_path
15
+
16
+ from huggingface_hub import snapshot_download
17
+
18
+ return snapshot_download(name_or_path)
19
+
20
+
21
+ def build_parser() -> argparse.ArgumentParser:
22
+ parser = argparse.ArgumentParser(
23
+ prog="omnivoice-prefetch",
24
+ description="Cache auxiliary OmniVoice assets for offline/container use.",
25
+ )
26
+ parser.add_argument(
27
+ "--model-dir",
28
+ default="/app/model",
29
+ help="Directory containing the OmniVoice model files.",
30
+ )
31
+ parser.add_argument(
32
+ "--audio-tokenizer",
33
+ default="eustlb/higgs-audio-v2-tokenizer",
34
+ help="Audio tokenizer repo id or local path.",
35
+ )
36
+ parser.add_argument(
37
+ "--asr-model",
38
+ default="openai/whisper-large-v3-turbo",
39
+ help="ASR model repo id or local path.",
40
+ )
41
+ parser.add_argument(
42
+ "--copy-audio-tokenizer",
43
+ action="store_true",
44
+ help="Copy the tokenizer into <model-dir>/audio_tokenizer.",
45
+ )
46
+ parser.add_argument(
47
+ "--prefetch-asr",
48
+ action="store_true",
49
+ help="Download the ASR model into the Hugging Face cache.",
50
+ )
51
+ return parser
52
+
53
+
54
+ def main() -> None:
55
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
56
+ args = build_parser().parse_args()
57
+
58
+ tokenizer_path = _resolve_path(args.audio_tokenizer)
59
+ if args.copy_audio_tokenizer:
60
+ model_dir = Path(args.model_dir)
61
+ target_dir = model_dir / "audio_tokenizer"
62
+ target_dir.mkdir(parents=True, exist_ok=True)
63
+ shutil.copytree(tokenizer_path, target_dir, dirs_exist_ok=True)
64
+ logger.info("Copied audio tokenizer to %s", target_dir)
65
+ else:
66
+ logger.info("Cached audio tokenizer at %s", tokenizer_path)
67
+
68
+ if args.prefetch_asr:
69
+ asr_path = _resolve_path(args.asr_model)
70
+ logger.info("Cached ASR model at %s", asr_path)
71
+
runtime/omnivoice/utils/__init__.py ADDED
File without changes
runtime/omnivoice/utils/armenian_text.py ADDED
@@ -0,0 +1,450 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Armenian text frontend for TTS.
3
+
4
+ The acoustic model should see pronounceable text, not written shortcuts such as
5
+ ``02.02.2026`` or ``25%``. This module keeps that logic in one place so the
6
+ same frontend can be used for manifest preparation and inference.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from dataclasses import asdict, dataclass
12
+ import re
13
+ from typing import Iterable, List, Optional
14
+
15
+
16
+ ARMENIAN_CHAR_RE = re.compile(r"[\u0530-\u058f]")
17
+ SPACE_RE = re.compile(r"\s+")
18
+
19
+ REPLACEMENTS = {
20
+ "\u00a0": " ",
21
+ "\u200b": "",
22
+ "\u200c": "",
23
+ "\u200d": "",
24
+ "\u2018": "'",
25
+ "\u2019": "'",
26
+ "\u201c": '"',
27
+ "\u201d": '"',
28
+ "\u2013": "-",
29
+ "\u2014": "-",
30
+ "\u2212": "-",
31
+ }
32
+
33
+ LANGUAGE_IDS_ARMENIAN = {
34
+ "hy",
35
+ "hye",
36
+ "hyw",
37
+ "hye-east",
38
+ "hyw-west",
39
+ "armenian",
40
+ "eastern armenian",
41
+ "western armenian",
42
+ }
43
+
44
+ ONES = {
45
+ 0: "զրո",
46
+ 1: "մեկ",
47
+ 2: "երկու",
48
+ 3: "երեք",
49
+ 4: "չորս",
50
+ 5: "հինգ",
51
+ 6: "վեց",
52
+ 7: "յոթ",
53
+ 8: "ութ",
54
+ 9: "ինը",
55
+ }
56
+
57
+ TEENS = {
58
+ 10: "տասը",
59
+ 11: "տասնմեկ",
60
+ 12: "տասներկու",
61
+ 13: "տասներեք",
62
+ 14: "տասնչորս",
63
+ 15: "տասնհինգ",
64
+ 16: "տասնվեց",
65
+ 17: "տասնյոթ",
66
+ 18: "տասնութ",
67
+ 19: "տասնինը",
68
+ }
69
+
70
+ TENS = {
71
+ 20: "քսան",
72
+ 30: "երեսուն",
73
+ 40: "քառասուն",
74
+ 50: "հիսուն",
75
+ 60: "վաթսուն",
76
+ 70: "յոթանասուն",
77
+ 80: "ութսուն",
78
+ 90: "իննսուն",
79
+ }
80
+
81
+ MONTHS_GENITIVE = {
82
+ 1: "հունվարի",
83
+ 2: "փետրվարի",
84
+ 3: "մարտի",
85
+ 4: "ապրիլի",
86
+ 5: "մայիսի",
87
+ 6: "հունիսի",
88
+ 7: "հուլիսի",
89
+ 8: "օգոստոսի",
90
+ 9: "սեպտեմբերի",
91
+ 10: "հոկտեմբերի",
92
+ 11: "նոյեմբերի",
93
+ 12: "դեկտեմբերի",
94
+ }
95
+
96
+ DAY_DATIVE = {
97
+ 1: "մեկին",
98
+ 2: "երկուսին",
99
+ 3: "երեքին",
100
+ 4: "չորսին",
101
+ 5: "հինգին",
102
+ 6: "վեցին",
103
+ 7: "յոթին",
104
+ 8: "ութին",
105
+ 9: "իննին",
106
+ 10: "տասին",
107
+ 11: "տասնմեկին",
108
+ 12: "տասներկուսին",
109
+ 13: "տասներեքին",
110
+ 14: "տասնչորսին",
111
+ 15: "տասնհինգին",
112
+ 16: "տասնվեցին",
113
+ 17: "տասնյոթին",
114
+ 18: "տասնութին",
115
+ 19: "տասնիննին",
116
+ 20: "քսանին",
117
+ 21: "քսանմեկին",
118
+ 22: "քսաներկուսին",
119
+ 23: "քսաներեքին",
120
+ 24: "քսանչորսին",
121
+ 25: "քսանհինգին",
122
+ 26: "քսանվեցին",
123
+ 27: "քսանյոթին",
124
+ 28: "քսանութին",
125
+ 29: "քսանիննին",
126
+ 30: "երեսունին",
127
+ 31: "երեսունմեկին",
128
+ }
129
+
130
+ ORDINALS = {
131
+ 1: "առաջին",
132
+ 2: "երկրորդ",
133
+ 3: "երրորդ",
134
+ 4: "չորրորդ",
135
+ 5: "հինգերորդ",
136
+ 6: "վեցերորդ",
137
+ 7: "յոթերորդ",
138
+ 8: "ութերորդ",
139
+ 9: "իններորդ",
140
+ 10: "տասներորդ",
141
+ 20: "քսաներորդ",
142
+ 30: "երեսուներորդ",
143
+ 40: "քառասուներորդ",
144
+ 50: "հիսուներորդ",
145
+ 60: "վաթսուներորդ",
146
+ 70: "յոթանասուներորդ",
147
+ 80: "ութսուներորդ",
148
+ 90: "իննսուներորդ",
149
+ 100: "հարյուրերորդ",
150
+ 1000: "հազարերորդ",
151
+ }
152
+
153
+ CURRENCY_NAMES = {
154
+ "֏": "դրամ",
155
+ "դրամ": "դրամ",
156
+ "դր": "դրամ",
157
+ "դր.": "դրամ",
158
+ "amd": "դրամ",
159
+ "$": "դոլար",
160
+ "usd": "դոլար",
161
+ "€": "եվրո",
162
+ "eur": "եվրո",
163
+ "£": "ֆունտ",
164
+ "gbp": "ֆունտ",
165
+ "₽": "ռուբլի",
166
+ "rub": "ռուբլի",
167
+ }
168
+
169
+ LETTER_DIGIT = r"A-Za-zԱ-Ֆա-ֆևԵՎՙ՚՛՜՝՞՟\d_"
170
+ LEFT_BOUNDARY = rf"(?<![{LETTER_DIGIT}])"
171
+ RIGHT_BOUNDARY = rf"(?![{LETTER_DIGIT}])"
172
+
173
+ DATE_SUFFIX = r"(?:\s*-?\s*(?:ին|թ\.?|թվականին))?"
174
+ DATE_DMY_RE = re.compile(
175
+ rf"{LEFT_BOUNDARY}([0-3]?\d)[./-]([01]?\d)[./-]((?:18|19|20|21)\d{{2}}){DATE_SUFFIX}{RIGHT_BOUNDARY}"
176
+ )
177
+ DATE_YMD_RE = re.compile(
178
+ rf"{LEFT_BOUNDARY}((?:18|19|20|21)\d{{2}})[./-]([01]?\d)[./-]([0-3]?\d){DATE_SUFFIX}{RIGHT_BOUNDARY}"
179
+ )
180
+ TIME_RE = re.compile(
181
+ rf"{LEFT_BOUNDARY}(?:ժամը\s*)?([0-2]?\d):([0-5]\d)(?::([0-5]\d))?{RIGHT_BOUNDARY}"
182
+ )
183
+ PERCENT_RE = re.compile(
184
+ rf"{LEFT_BOUNDARY}([+-]?(?:\d{{1,3}}(?:[ ,]\d{{3}})+|\d+)(?:[.,]\d+)?)\s*(?:%|տոկոս){RIGHT_BOUNDARY}",
185
+ re.IGNORECASE,
186
+ )
187
+ CURRENCY_PREFIX_RE = re.compile(
188
+ rf"{LEFT_BOUNDARY}([$€£₽֏])\s*([+-]?(?:\d{{1,3}}(?:[ ,]\d{{3}})+|\d+)(?:[.,]\d+)?)"
189
+ )
190
+ CURRENCY_SUFFIX_RE = re.compile(
191
+ rf"{LEFT_BOUNDARY}([+-]?(?:\d{{1,3}}(?:[ ,]\d{{3}})+|\d+)(?:[.,]\d+)?)\s*(֏|դրամ|դր\.?|AMD|USD|EUR|GBP|RUB|\$|€|£|₽){RIGHT_BOUNDARY}",
192
+ re.IGNORECASE,
193
+ )
194
+ ORDINAL_RE = re.compile(rf"{LEFT_BOUNDARY}(\d+)\s*[-\u2010-\u2015]?\s*(?:րդ|ին){RIGHT_BOUNDARY}")
195
+ RANGE_RE = re.compile(rf"{LEFT_BOUNDARY}(\d+)\s*[-\u2010-\u2015]\s*(\d+){RIGHT_BOUNDARY}")
196
+ PLAIN_NUMBER_RE = re.compile(
197
+ rf"{LEFT_BOUNDARY}([+-]?(?:\d{{1,3}}(?:[ ,]\d{{3}})+|\d+)(?:[.,]\d+)?)"
198
+ rf"{RIGHT_BOUNDARY}"
199
+ )
200
+ URL_RE = re.compile(r"\b(?:https?://|www\.)\S+", re.IGNORECASE)
201
+ EMAIL_RE = re.compile(r"\b[\w.+-]+@[\w.-]+\.[A-Za-z]{2,}\b")
202
+ SYMBOL_RE = re.compile(r"[@#&=+*/\\|<>_~^]")
203
+ LATIN_TOKEN_RE = re.compile(rf"{LEFT_BOUNDARY}[A-Za-z]{{2,}}{RIGHT_BOUNDARY}")
204
+
205
+
206
+ @dataclass(frozen=True)
207
+ class TextIssue:
208
+ kind: str
209
+ value: str
210
+ start: int
211
+ end: int
212
+
213
+ def as_dict(self) -> dict:
214
+ return asdict(self)
215
+
216
+
217
+ def looks_armenian_context(text: str, language: Optional[str] = None) -> bool:
218
+ if language and language.strip().lower() in LANGUAGE_IDS_ARMENIAN:
219
+ return True
220
+ return bool(ARMENIAN_CHAR_RE.search(text))
221
+
222
+
223
+ def normalize_unicode_text(text: str) -> str:
224
+ s = str(text)
225
+ for old, new in REPLACEMENTS.items():
226
+ s = s.replace(old, new)
227
+ s = SPACE_RE.sub(" ", s.strip())
228
+ return s
229
+
230
+
231
+ def integer_to_armenian(value: int) -> str:
232
+ if value < 0:
233
+ return "մինուս " + integer_to_armenian(abs(value))
234
+ if value < 10:
235
+ return ONES[value]
236
+ if value < 20:
237
+ return TEENS[value]
238
+ if value < 100:
239
+ tens = value // 10 * 10
240
+ ones = value % 10
241
+ return TENS[tens] + (ONES[ones] if ones else "")
242
+ if value < 1000:
243
+ hundreds = value // 100
244
+ rest = value % 100
245
+ prefix = "հարյուր" if hundreds == 1 else f"{integer_to_armenian(hundreds)} հարյուր"
246
+ return prefix if rest == 0 else f"{prefix} {integer_to_armenian(rest)}"
247
+
248
+ for scale, name in (
249
+ (1_000_000_000, "միլիարդ"),
250
+ (1_000_000, "միլիոն"),
251
+ (1000, "հազար"),
252
+ ):
253
+ if value >= scale:
254
+ head = value // scale
255
+ rest = value % scale
256
+ if scale == 1000 and head == 1:
257
+ prefix = name
258
+ else:
259
+ prefix = f"{integer_to_armenian(head)} {name}"
260
+ return prefix if rest == 0 else f"{prefix} {integer_to_armenian(rest)}"
261
+
262
+ raise ValueError(f"Unsupported integer value: {value}")
263
+
264
+
265
+ def _strip_group_separators(value: str) -> str:
266
+ return re.sub(r"(?<=\d)[ ,](?=\d{3}(?:\D|$))", "", value)
267
+
268
+
269
+ def number_to_armenian(value: str | int) -> str:
270
+ raw = str(value).strip()
271
+ if not raw:
272
+ return raw
273
+
274
+ sign = ""
275
+ if raw[0] in "+-":
276
+ sign = "մինուս " if raw[0] == "-" else ""
277
+ raw = raw[1:]
278
+
279
+ raw = _strip_group_separators(raw)
280
+
281
+ decimal_match = re.fullmatch(r"(\d+)[.,](\d+)", raw)
282
+ if decimal_match:
283
+ whole, frac = decimal_match.groups()
284
+ frac_words = " ".join(ONES[int(ch)] for ch in frac)
285
+ return f"{sign}{integer_to_armenian(int(whole))} ամբողջ {frac_words}".strip()
286
+
287
+ return f"{sign}{integer_to_armenian(int(raw))}".strip()
288
+
289
+
290
+ def ordinal_to_armenian(value: int) -> str:
291
+ if value in ORDINALS:
292
+ return ORDINALS[value]
293
+ if value < 100:
294
+ return integer_to_armenian(value) + "երորդ"
295
+
296
+ words = integer_to_armenian(value).split()
297
+ words[-1] = ordinal_to_armenian(int_to_last_component(value))
298
+ return " ".join(words)
299
+
300
+
301
+ def int_to_last_component(value: int) -> int:
302
+ if value % 100:
303
+ return value % 100
304
+ if value % 1000:
305
+ return value % 1000
306
+ if value % 1_000_000:
307
+ return value % 1_000_000
308
+ return value
309
+
310
+
311
+ def day_to_date_armenian(day: int) -> str:
312
+ if day in DAY_DATIVE:
313
+ return DAY_DATIVE[day]
314
+ if not 1 <= day <= 31:
315
+ raise ValueError(f"Invalid day: {day}")
316
+ raise ValueError(f"Invalid day: {day}")
317
+
318
+
319
+ def expand_numeric_date(day: int, month: int, year: int) -> str:
320
+ if month not in MONTHS_GENITIVE or not 1 <= day <= 31:
321
+ raise ValueError(f"Invalid date: {day}.{month}.{year}")
322
+ # Keep validation lightweight; this frontend is a normalizer, not a calendar.
323
+ return f"{MONTHS_GENITIVE[month]} {day_to_date_armenian(day)}, {integer_to_armenian(year)} թվականին"
324
+
325
+
326
+ def _replace_dmy(match: re.Match[str]) -> str:
327
+ day, month, year = (int(x) for x in match.groups())
328
+ try:
329
+ return expand_numeric_date(day, month, year)
330
+ except ValueError:
331
+ return match.group(0)
332
+
333
+
334
+ def _replace_ymd(match: re.Match[str]) -> str:
335
+ year, month, day = (int(x) for x in match.groups())
336
+ try:
337
+ return expand_numeric_date(day, month, year)
338
+ except ValueError:
339
+ return match.group(0)
340
+
341
+
342
+ def _replace_time(match: re.Match[str]) -> str:
343
+ hour = int(match.group(1))
344
+ minute = int(match.group(2))
345
+ second = int(match.group(3)) if match.group(3) is not None else None
346
+ if hour > 23:
347
+ return match.group(0)
348
+ text = f"ժամը {integer_to_armenian(hour)}"
349
+ if minute:
350
+ text += f" անց {integer_to_armenian(minute)}"
351
+ if second:
352
+ text += f" և {integer_to_armenian(second)} վայրկյան"
353
+ return text
354
+
355
+
356
+ def _replace_percent(match: re.Match[str]) -> str:
357
+ return f"{number_to_armenian(match.group(1))} տոկոս"
358
+
359
+
360
+ def _replace_currency_prefix(match: re.Match[str]) -> str:
361
+ currency = CURRENCY_NAMES[match.group(1).lower()]
362
+ return f"{number_to_armenian(match.group(2))} {currency}"
363
+
364
+
365
+ def _replace_currency_suffix(match: re.Match[str]) -> str:
366
+ currency = CURRENCY_NAMES[match.group(2).lower()]
367
+ return f"{number_to_armenian(match.group(1))} {currency}"
368
+
369
+
370
+ def _replace_ordinal(match: re.Match[str]) -> str:
371
+ return ordinal_to_armenian(int(match.group(1)))
372
+
373
+
374
+ def _replace_range(match: re.Match[str]) -> str:
375
+ return f"{number_to_armenian(match.group(1))}ից {number_to_armenian(match.group(2))}"
376
+
377
+
378
+ def _replace_number(match: re.Match[str]) -> str:
379
+ try:
380
+ return number_to_armenian(match.group(1))
381
+ except ValueError:
382
+ return match.group(0)
383
+
384
+
385
+ def expand_armenian_text(text: str) -> str:
386
+ s = normalize_unicode_text(text)
387
+ s = DATE_DMY_RE.sub(_replace_dmy, s)
388
+ s = DATE_YMD_RE.sub(_replace_ymd, s)
389
+ s = TIME_RE.sub(_replace_time, s)
390
+ s = CURRENCY_PREFIX_RE.sub(_replace_currency_prefix, s)
391
+ s = CURRENCY_SUFFIX_RE.sub(_replace_currency_suffix, s)
392
+ s = PERCENT_RE.sub(_replace_percent, s)
393
+ s = ORDINAL_RE.sub(_replace_ordinal, s)
394
+ s = RANGE_RE.sub(_replace_range, s)
395
+ s = PLAIN_NUMBER_RE.sub(_replace_number, s)
396
+ return cleanup_spacing(s)
397
+
398
+
399
+ def cleanup_spacing(text: str) -> str:
400
+ s = SPACE_RE.sub(" ", text.strip())
401
+ s = re.sub(r"\s+([,.;:!?։՝՞՜])", r"\1", s)
402
+ s = re.sub(r"([,;:!?։])(?=\S)", r"\1 ", s)
403
+ return SPACE_RE.sub(" ", s).strip()
404
+
405
+
406
+ def normalize_for_tts(text: str, language: Optional[str] = "hy") -> str:
407
+ s = normalize_unicode_text(text)
408
+ if looks_armenian_context(s, language=language):
409
+ return expand_armenian_text(s)
410
+ return cleanup_spacing(s)
411
+
412
+
413
+ def _iter_issue_matches(text: str) -> Iterable[TextIssue]:
414
+ patterns = [
415
+ ("url", URL_RE),
416
+ ("email", EMAIL_RE),
417
+ ("date", DATE_DMY_RE),
418
+ ("date", DATE_YMD_RE),
419
+ ("time", TIME_RE),
420
+ ("currency", CURRENCY_PREFIX_RE),
421
+ ("currency", CURRENCY_SUFFIX_RE),
422
+ ("percent", PERCENT_RE),
423
+ ("ordinal", ORDINAL_RE),
424
+ ("range", RANGE_RE),
425
+ ("number", PLAIN_NUMBER_RE),
426
+ ("symbol", SYMBOL_RE),
427
+ ("latin_token", LATIN_TOKEN_RE),
428
+ ]
429
+ occupied: list[tuple[int, int]] = []
430
+ for kind, pattern in patterns:
431
+ for match in pattern.finditer(text):
432
+ start, end = match.span()
433
+ if any(start < old_end and end > old_start for old_start, old_end in occupied):
434
+ continue
435
+ occupied.append((start, end))
436
+ yield TextIssue(kind=kind, value=match.group(0), start=start, end=end)
437
+
438
+
439
+ def find_text_frontend_issues(
440
+ text: str,
441
+ language: Optional[str] = "hy",
442
+ ) -> List[TextIssue]:
443
+ s = normalize_unicode_text(text)
444
+ if not looks_armenian_context(s, language=language):
445
+ return []
446
+ return sorted(_iter_issue_matches(s), key=lambda issue: (issue.start, issue.end))
447
+
448
+
449
+ def issues_as_dicts(issues: Iterable[TextIssue]) -> list[dict]:
450
+ return [issue.as_dict() for issue in issues]
runtime/omnivoice/utils/audio.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Audio I/O and processing utilities.
19
+
20
+ Provides functions for loading, resampling, silence removal,
21
+ chunking, cross-fading, and format conversion.
22
+
23
+ All public functions in this module operate on **numpy float32 arrays**
24
+ with shape ``(C, T)`` (channels-first).
25
+ """
26
+
27
+ import io
28
+ import logging
29
+
30
+ import numpy as np
31
+ import soundfile as sf
32
+ import torch
33
+ import torchaudio
34
+ from pydub import AudioSegment
35
+ from pydub.silence import detect_leading_silence, detect_nonsilent, split_on_silence
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
+
40
+ # ---------------------------------------------------------------------------
41
+ # Loading
42
+ # ---------------------------------------------------------------------------
43
+
44
+
45
+ def load_waveform(audio_path: str):
46
+ """Load audio from a file path, returning (data, sample_rate).
47
+
48
+ Tries two backends in order:
49
+ 1. soundfile — covers WAV/FLAC/OGG etc., no ffmpeg needed.
50
+ 2. librosa — covers MP3/M4A etc. via audioread + ffmpeg.
51
+
52
+ Returns:
53
+ (data, sample_rate) where data is a numpy float32 array of
54
+ shape (C, T).
55
+ """
56
+ try:
57
+ data, sr = sf.read(audio_path, dtype="float32", always_2d=True)
58
+ return data.T, sr # (T, C) → (C, T)
59
+ except Exception:
60
+ # soundfile cannot handle MP3/M4A etc., fall back to librosa.
61
+ import librosa
62
+
63
+ data, sr = librosa.load(audio_path, sr=None, mono=False)
64
+ if data.ndim == 1:
65
+ data = data[np.newaxis, :]
66
+ return data, sr
67
+
68
+
69
+ def load_audio(audio_path: str, sampling_rate: int) -> np.ndarray:
70
+ """Load a waveform from file and resample to the target rate.
71
+
72
+ Parameters:
73
+ audio_path: path of the audio.
74
+ sampling_rate: target sampling rate.
75
+
76
+ Returns:
77
+ Numpy float32 array of shape (1, T).
78
+ """
79
+ data, sr = load_waveform(audio_path)
80
+
81
+ if data.shape[0] > 1:
82
+ data = np.mean(data, axis=0, keepdims=True)
83
+ if sr != sampling_rate:
84
+ data = torchaudio.functional.resample(
85
+ torch.from_numpy(data), orig_freq=sr, new_freq=sampling_rate
86
+ ).numpy()
87
+
88
+ return data
89
+
90
+
91
+ def load_audio_bytes(raw: bytes, sampling_rate: int) -> np.ndarray:
92
+ """Load audio from in-memory bytes and resample.
93
+
94
+ Parameters:
95
+ raw: raw audio file bytes (e.g. from WebDataset).
96
+ sampling_rate: target sampling rate.
97
+
98
+ Returns:
99
+ Numpy float32 array of shape (1, T).
100
+ """
101
+ buf = io.BytesIO(raw)
102
+
103
+ try:
104
+ data, sr = sf.read(buf, dtype="float32", always_2d=True)
105
+ data = data.T # (T, C) → (C, T)
106
+ except Exception:
107
+ import librosa
108
+
109
+ buf.seek(0)
110
+ data, sr = librosa.load(buf, sr=None, mono=False)
111
+ if data.ndim == 1:
112
+ data = data[np.newaxis, :]
113
+
114
+ if data.shape[0] > 1:
115
+ data = np.mean(data, axis=0, keepdims=True)
116
+ if sr != sampling_rate:
117
+ data = torchaudio.functional.resample(
118
+ torch.from_numpy(data), orig_freq=sr, new_freq=sampling_rate
119
+ ).numpy()
120
+
121
+ return data
122
+
123
+
124
+ # ---------------------------------------------------------------------------
125
+ # Audio processing (all numpy in / numpy out)
126
+ # ---------------------------------------------------------------------------
127
+
128
+
129
+ def numpy_to_audiosegment(audio: np.ndarray, sample_rate: int) -> AudioSegment:
130
+ """Convert a numpy float32 array of shape (C, T) to a pydub AudioSegment."""
131
+ audio_int = (audio * 32768.0).clip(-32768, 32767).astype(np.int16)
132
+ if audio_int.shape[0] > 1:
133
+ audio_int = audio_int.T.flatten() # interleave channels
134
+ return AudioSegment(
135
+ data=audio_int.tobytes(),
136
+ sample_width=2,
137
+ frame_rate=sample_rate,
138
+ channels=audio.shape[0],
139
+ )
140
+
141
+
142
+ def audiosegment_to_numpy(aseg: AudioSegment) -> np.ndarray:
143
+ """Convert a pydub AudioSegment to a numpy float32 array of shape (C, T)."""
144
+ data = np.array(aseg.get_array_of_samples()).astype(np.float32) / 32768.0
145
+ if aseg.channels == 1:
146
+ return data[np.newaxis, :]
147
+ return data.reshape(-1, aseg.channels).T
148
+
149
+
150
+ def remove_silence(
151
+ audio: np.ndarray,
152
+ sampling_rate: int,
153
+ mid_sil: int = 300,
154
+ lead_sil: int = 100,
155
+ trail_sil: int = 300,
156
+ ) -> np.ndarray:
157
+ """Remove middle silences longer than *mid_sil* ms and trim edge silences.
158
+
159
+ Parameters:
160
+ audio: numpy array with shape (C, T).
161
+ sampling_rate: sampling rate of the audio.
162
+ mid_sil: middle-silence threshold in ms (0 to skip).
163
+ lead_sil: kept leading silence in ms.
164
+ trail_sil: kept trailing silence in ms.
165
+
166
+ Returns:
167
+ Numpy array with shape (C, T').
168
+ """
169
+ wave = numpy_to_audiosegment(audio, sampling_rate)
170
+
171
+ if mid_sil > 0:
172
+ non_silent_segs = split_on_silence(
173
+ wave,
174
+ min_silence_len=mid_sil,
175
+ silence_thresh=-50,
176
+ keep_silence=mid_sil,
177
+ seek_step=10,
178
+ )
179
+ wave = AudioSegment.silent(duration=0)
180
+ for seg in non_silent_segs:
181
+ wave += seg
182
+
183
+ wave = remove_silence_edges(wave, lead_sil, trail_sil, -50)
184
+
185
+ return audiosegment_to_numpy(wave)
186
+
187
+
188
+ def remove_silence_edges(
189
+ audio: AudioSegment,
190
+ lead_sil: int = 100,
191
+ trail_sil: int = 300,
192
+ silence_threshold: float = -50,
193
+ ) -> AudioSegment:
194
+ """Remove edge silences, keeping *lead_sil* / *trail_sil* ms."""
195
+ start_idx = detect_leading_silence(audio, silence_threshold=silence_threshold)
196
+ start_idx = max(0, start_idx - lead_sil)
197
+ audio = audio[start_idx:]
198
+
199
+ audio = audio.reverse()
200
+ start_idx = detect_leading_silence(audio, silence_threshold=silence_threshold)
201
+ start_idx = max(0, start_idx - trail_sil)
202
+ audio = audio[start_idx:]
203
+ audio = audio.reverse()
204
+
205
+ return audio
206
+
207
+
208
+ def fade_and_pad_audio(
209
+ audio: np.ndarray,
210
+ pad_duration: float = 0.1,
211
+ fade_duration: float = 0.1,
212
+ sample_rate: int = 24000,
213
+ ) -> np.ndarray:
214
+ """Apply fade-in/out and pad with silence to prevent clicks.
215
+
216
+ Args:
217
+ audio: numpy array of shape (C, T).
218
+ pad_duration: silence padding duration per side (seconds).
219
+ fade_duration: fade curve duration (seconds).
220
+ sample_rate: audio sampling rate.
221
+
222
+ Returns:
223
+ Processed numpy array of shape (C, T_new).
224
+ """
225
+ if audio.shape[-1] == 0:
226
+ return audio
227
+
228
+ fade_samples = int(fade_duration * sample_rate)
229
+ pad_samples = int(pad_duration * sample_rate)
230
+
231
+ processed = audio.copy()
232
+
233
+ if fade_samples > 0:
234
+ k = min(fade_samples, processed.shape[-1] // 2)
235
+ if k > 0:
236
+ fade_in = np.linspace(0, 1, k, dtype=np.float32)[np.newaxis, :]
237
+ processed[..., :k] *= fade_in
238
+
239
+ fade_out = np.linspace(1, 0, k, dtype=np.float32)[np.newaxis, :]
240
+ processed[..., -k:] *= fade_out
241
+
242
+ if pad_samples > 0:
243
+ silence = np.zeros(
244
+ (processed.shape[0], pad_samples),
245
+ dtype=processed.dtype,
246
+ )
247
+ processed = np.concatenate([silence, processed, silence], axis=-1)
248
+
249
+ return processed
250
+
251
+
252
+ def trim_long_audio(
253
+ audio: np.ndarray,
254
+ sampling_rate: int,
255
+ max_duration: float = 15.0,
256
+ min_duration: float = 3.0,
257
+ trim_threshold: float = 20.0,
258
+ ) -> np.ndarray:
259
+ """Trim audio to <= *max_duration* by splitting at the largest silence gap.
260
+
261
+ Only trims when the audio exceeds *trim_threshold* seconds.
262
+
263
+ Args:
264
+ audio: numpy array of shape (C, T).
265
+ sampling_rate: audio sampling rate.
266
+ max_duration: maximum duration in seconds.
267
+ min_duration: minimum duration in seconds.
268
+ trim_threshold: only trim if audio is longer than this (seconds).
269
+
270
+ Returns:
271
+ Trimmed numpy array.
272
+ """
273
+ duration = audio.shape[-1] / sampling_rate
274
+ if duration <= trim_threshold:
275
+ return audio
276
+
277
+ seg = numpy_to_audiosegment(audio, sampling_rate)
278
+ nonsilent = detect_nonsilent(
279
+ seg, min_silence_len=100, silence_thresh=-40, seek_step=10
280
+ )
281
+ if not nonsilent:
282
+ return audio
283
+
284
+ max_ms = int(max_duration * 1000)
285
+ min_ms = int(min_duration * 1000)
286
+
287
+ best_split = 0
288
+ for start, end in nonsilent:
289
+ if start > best_split and start <= max_ms:
290
+ best_split = start
291
+ if end > max_ms:
292
+ break
293
+
294
+ if best_split < min_ms:
295
+ best_split = min(max_ms, len(seg))
296
+
297
+ trimmed = seg[:best_split]
298
+ return audiosegment_to_numpy(trimmed)
299
+
300
+
301
+ def cross_fade_chunks(
302
+ chunks: list[np.ndarray],
303
+ sample_rate: int,
304
+ silence_duration: float = 0.3,
305
+ ) -> np.ndarray:
306
+ """Concatenate audio chunks with silence gaps and cross-fade at boundaries.
307
+
308
+ Args:
309
+ chunks: list of numpy arrays, each (C, T).
310
+ sample_rate: audio sample rate.
311
+ silence_duration: total silence gap duration in seconds.
312
+
313
+ Returns:
314
+ Merged numpy array (C, T_total).
315
+ """
316
+ if len(chunks) == 1:
317
+ return chunks[0]
318
+
319
+ total_n = int(silence_duration * sample_rate)
320
+ fade_n = total_n // 3
321
+ silence_n = fade_n
322
+ merged = chunks[0].copy()
323
+
324
+ for chunk in chunks[1:]:
325
+ parts = [merged]
326
+
327
+ fout_n = min(fade_n, merged.shape[-1])
328
+ if fout_n > 0:
329
+ w_out = np.linspace(1, 0, fout_n, dtype=np.float32)[np.newaxis, :]
330
+ parts[-1][..., -fout_n:] *= w_out
331
+
332
+ parts.append(np.zeros((chunks[0].shape[0], silence_n), dtype=np.float32))
333
+
334
+ fade_in = chunk.copy()
335
+ fin_n = min(fade_n, fade_in.shape[-1])
336
+ if fin_n > 0:
337
+ w_in = np.linspace(0, 1, fin_n, dtype=np.float32)[np.newaxis, :]
338
+ fade_in[..., :fin_n] *= w_in
339
+
340
+ parts.append(fade_in)
341
+ merged = np.concatenate(parts, axis=-1)
342
+
343
+ return merged
runtime/omnivoice/utils/common.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Shared utility functions."""
19
+
20
+ import argparse
21
+ import random
22
+
23
+ import numpy as np
24
+ import torch
25
+
26
+
27
+ def str2bool(v):
28
+ """Used in argparse.ArgumentParser.add_argument to indicate
29
+ that a type is a bool type and user can enter
30
+
31
+ - yes, true, t, y, 1, to represent True
32
+ - no, false, f, n, 0, to represent False
33
+
34
+ See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa
35
+ """
36
+ if isinstance(v, bool):
37
+ return v
38
+ if v.lower() in ("yes", "true", "t", "y", "1"):
39
+ return True
40
+ elif v.lower() in ("no", "false", "f", "n", "0"):
41
+ return False
42
+ else:
43
+ raise argparse.ArgumentTypeError("Boolean value expected.")
44
+
45
+
46
+ def fix_random_seed(random_seed: int):
47
+ """
48
+ Set the same random seed for the libraries and modules.
49
+ Includes the ``random`` module, numpy, and torch.
50
+ """
51
+ random.seed(random_seed)
52
+ np.random.seed(random_seed)
53
+ torch.random.manual_seed(random_seed)
54
+ # Ensure deterministic ID creation
55
+ rd = random.Random()
56
+ rd.seed(random_seed)
runtime/omnivoice/utils/duration.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Text duration estimation for TTS generation.
19
+
20
+ Provides ``RuleDurationEstimator``, which estimates audio duration from text
21
+ using character phonetic weights across 600+ languages. Used by
22
+ ``OmniVoice.generate()`` to determine output length when no duration is specified.
23
+ """
24
+
25
+ import bisect
26
+ import unicodedata
27
+ from functools import lru_cache
28
+ from typing import Optional
29
+
30
+
31
+ class RuleDurationEstimator:
32
+ def __init__(self):
33
+ # ==========================================
34
+ # 1. Phonetic Weights Table
35
+ # ==========================================
36
+ # The weight represents the relative speaking time compared to
37
+ # a standard Latin letter.
38
+ # Benchmark: 1.0 = One Latin Character (~40-50ms)
39
+ self.weights = {
40
+ # --- Logographic (1 char = full syllable/word) ---
41
+ "cjk": 3.0, # Chinese, Japanese Kanji, etc.
42
+ # --- Syllabic / Blocks
43
+ "hangul": 2.5, # Korean Hangul
44
+ "kana": 2.2, # Japanese Hiragana/Katakana
45
+ "ethiopic": 3.0, # Amharic/Ge'ez
46
+ "yi": 3.0, # Yi script
47
+ # --- Abugida (Consonant-Vowel complexes) ---
48
+ "indic": 1.8, # Hindi, Bengali, Tamil, etc.
49
+ "thai_lao": 1.5, # Thai, Lao
50
+ "khmer_myanmar": 1.8, # Khmer, Myanmar
51
+ # --- Abjad (Consonant-heavy) ---
52
+ "arabic": 1.5, # Arabic, Persian, Urdu
53
+ "hebrew": 1.5, # Hebrew
54
+ # --- Alphabet (Segmental) ---
55
+ "latin": 1.0, # English, Spanish, French, Vietnamese, etc. (Baseline)
56
+ "cyrillic": 1.0, # Russian, Ukrainian
57
+ "greek": 1.0, # Greek
58
+ "armenian": 1.0, # Armenian
59
+ "georgian": 1.0, # Georgian
60
+ # --- Symbols & Misc ---
61
+ "punctuation": 0.5, # Pause capability
62
+ "space": 0.2, # Word boundary/Breath (0.05 / 0.22)
63
+ "digit": 3.5, # Numbers
64
+ "mark": 0.0, # Diacritics/Accents (Silent modifiers)
65
+ "default": 1.0, # Fallback for unknown scripts
66
+ }
67
+
68
+ # ==========================================
69
+ # 2. Unicode Range Mapping
70
+ # ==========================================
71
+ # Format: (End_Codepoint, Type_Key)
72
+ # Used for fast binary search (bisect).
73
+ self.ranges = [
74
+ (0x02AF, "latin"), # Latin (Basic, Supplement, Ext, IPA)
75
+ (0x03FF, "greek"), # Greek & Coptic
76
+ (0x052F, "cyrillic"), # Cyrillic
77
+ (0x058F, "armenian"), # Armenian
78
+ (0x05FF, "hebrew"), # Hebrew
79
+ (0x077F, "arabic"), # Arabic, Syriac, Arabic Supplement
80
+ (0x089F, "arabic"), # Arabic Extended-B (+ Syriac Supp)
81
+ (0x08FF, "arabic"), # Arabic Extended-A
82
+ (0x097F, "indic"), # Devanagari
83
+ (0x09FF, "indic"), # Bengali
84
+ (0x0A7F, "indic"), # Gurmukhi
85
+ (0x0AFF, "indic"), # Gujarati
86
+ (0x0B7F, "indic"), # Oriya
87
+ (0x0BFF, "indic"), # Tamil
88
+ (0x0C7F, "indic"), # Telugu
89
+ (0x0CFF, "indic"), # Kannada
90
+ (0x0D7F, "indic"), # Malayalam
91
+ (0x0DFF, "indic"), # Sinhala
92
+ (0x0EFF, "thai_lao"), # Thai & Lao
93
+ (0x0FFF, "indic"), # Tibetan (Abugida)
94
+ (0x109F, "khmer_myanmar"), # Myanmar
95
+ (0x10FF, "georgian"), # Georgian
96
+ (0x11FF, "hangul"), # Hangul Jamo
97
+ (0x137F, "ethiopic"), # Ethiopic
98
+ (0x139F, "ethiopic"), # Ethiopic Supplement
99
+ (0x13FF, "default"), # Cherokee
100
+ (0x167F, "default"), # Canadian Aboriginal Syllabics
101
+ (0x169F, "default"), # Ogham
102
+ (0x16FF, "default"), # Runic
103
+ (0x171F, "default"), # Tagalog (Baybayin)
104
+ (0x173F, "default"), # Hanunoo
105
+ (0x175F, "default"), # Buhid
106
+ (0x177F, "default"), # Tagbanwa
107
+ (0x17FF, "khmer_myanmar"), # Khmer
108
+ (0x18AF, "default"), # Mongolian
109
+ (0x18FF, "default"), # Canadian Aboriginal Syllabics Ext
110
+ (0x194F, "indic"), # Limbu
111
+ (0x19DF, "indic"), # Tai Le & New Tai Lue
112
+ (0x19FF, "khmer_myanmar"), # Khmer Symbols
113
+ (0x1A1F, "indic"), # Buginese
114
+ (0x1AAF, "indic"), # Tai Tham
115
+ (0x1B7F, "indic"), # Balinese
116
+ (0x1BBF, "indic"), # Sundanese
117
+ (0x1BFF, "indic"), # Batak
118
+ (0x1C4F, "indic"), # Lepcha
119
+ (0x1C7F, "indic"), # Ol Chiki (Santali)
120
+ (0x1C8F, "cyrillic"), # Cyrillic Extended-C
121
+ (0x1CBF, "georgian"), # Georgian Extended
122
+ (0x1CCF, "indic"), # Sundanese Supplement
123
+ (0x1CFF, "indic"), # Vedic Extensions
124
+ (0x1D7F, "latin"), # Phonetic Extensions
125
+ (0x1DBF, "latin"), # Phonetic Extensions Supplement
126
+ (0x1DFF, "default"), # Combining Diacritical Marks Supplement
127
+ (0x1EFF, "latin"), # Latin Extended Additional (Vietnamese)
128
+ (0x309F, "kana"), # Hiragana
129
+ (0x30FF, "kana"), # Katakana
130
+ (0x312F, "cjk"), # Bopomofo (Pinyin)
131
+ (0x318F, "hangul"), # Hangul Compatibility Jamo
132
+ (0x9FFF, "cjk"), # CJK Unified Ideographs (Main)
133
+ (0xA4CF, "yi"), # Yi Syllables
134
+ (0xA4FF, "default"), # Lisu
135
+ (0xA63F, "default"), # Vai
136
+ (0xA69F, "cyrillic"), # Cyrillic Extended-B
137
+ (0xA6FF, "default"), # Bamum
138
+ (0xA7FF, "latin"), # Latin Extended-D
139
+ (0xA82F, "indic"), # Syloti Nagri
140
+ (0xA87F, "default"), # Phags-pa
141
+ (0xA8DF, "indic"), # Saurashtra
142
+ (0xA8FF, "indic"), # Devanagari Extended
143
+ (0xA92F, "indic"), # Kayah Li
144
+ (0xA95F, "indic"), # Rejang
145
+ (0xA97F, "hangul"), # Hangul Jamo Extended-A
146
+ (0xA9DF, "indic"), # Javanese
147
+ (0xA9FF, "khmer_myanmar"), # Myanmar Extended-B
148
+ (0xAA5F, "indic"), # Cham
149
+ (0xAA7F, "khmer_myanmar"), # Myanmar Extended-A
150
+ (0xAADF, "indic"), # Tai Viet
151
+ (0xAAFF, "indic"), # Meetei Mayek Extensions
152
+ (0xAB2F, "ethiopic"), # Ethiopic Extended-A
153
+ (0xAB6F, "latin"), # Latin Extended-E
154
+ (0xABBF, "default"), # Cherokee Supplement
155
+ (0xABFF, "indic"), # Meetei Mayek
156
+ (0xD7AF, "hangul"), # Hangul Syllables
157
+ (0xFAFF, "cjk"), # CJK Compatibility
158
+ (0xFDFF, "arabic"), # Arabic Presentation Forms-A
159
+ (0xFE6F, "default"), # Variation Selectors
160
+ (0xFEFF, "arabic"), # Arabic Presentation Forms-B
161
+ (0xFFEF, "latin"), # Fullwidth Latin
162
+ ]
163
+ self.breakpoints = [r[0] for r in self.ranges]
164
+
165
+ @lru_cache(maxsize=4096)
166
+ def _get_char_weight(self, char):
167
+ """Determines the weight of a single character."""
168
+ code = ord(char)
169
+ if (65 <= code <= 90) or (97 <= code <= 122):
170
+ return self.weights["latin"]
171
+ if code == 32:
172
+ return self.weights["space"]
173
+
174
+ # Ignore arabic Tatweel
175
+ if code == 0x0640:
176
+ return self.weights["mark"]
177
+
178
+ category = unicodedata.category(char)
179
+
180
+ if category.startswith("M"):
181
+ return self.weights["mark"]
182
+
183
+ if category.startswith("P") or category.startswith("S"):
184
+ return self.weights["punctuation"]
185
+
186
+ if category.startswith("Z"):
187
+ return self.weights["space"]
188
+
189
+ if category.startswith("N"):
190
+ return self.weights["digit"]
191
+
192
+ # 3. Binary search for Unicode Block (此时区间里绝不会再混进标点符号)
193
+ idx = bisect.bisect_left(self.breakpoints, code)
194
+ if idx < len(self.ranges):
195
+ script_type = self.ranges[idx][1]
196
+ return self.weights.get(script_type, self.weights["default"])
197
+
198
+ # 4. Handle upper planes (CJK Ext B/C/D, Historic scripts)
199
+ if code > 0x20000:
200
+ return self.weights["cjk"]
201
+
202
+ return self.weights["default"]
203
+
204
+ def calculate_total_weight(self, text):
205
+ """Sums up the normalized weights for a string."""
206
+ return sum(self._get_char_weight(c) for c in text)
207
+
208
+ def estimate_duration(
209
+ self,
210
+ target_text: str,
211
+ ref_text: str,
212
+ ref_duration: float,
213
+ low_threshold: Optional[float] = 50,
214
+ boost_strength: float = 3,
215
+ ) -> float:
216
+ """
217
+
218
+ Args:
219
+ target_text (str): The text for which we want to estimate the duration.
220
+ ref_text (str): The reference text that was used to measure
221
+ the ref_duration.
222
+ ref_duration (float): The actual duration it took
223
+ to speak the ref_text.
224
+ low_threshold (float): The minimum duration threshold below which the
225
+ estimation will be considered unreliable.
226
+ boost_strength (float): Controls the power-curve boost for short durations.
227
+ Higher values boost small durations more aggressively.
228
+ 1 = no boost (linear), 2 = sqrt-like
229
+
230
+ Returns:
231
+ float: The estimated duration for the target_text based
232
+ on the ref_text and ref_duration.
233
+ """
234
+ if ref_duration <= 0 or not ref_text:
235
+ return 0.0
236
+
237
+ ref_weight = self.calculate_total_weight(ref_text)
238
+ if ref_weight == 0:
239
+ return 0.0
240
+
241
+ speed_factor = ref_weight / ref_duration
242
+ target_weight = self.calculate_total_weight(target_text)
243
+
244
+ estimated_duration = target_weight / speed_factor
245
+ if low_threshold is not None and estimated_duration < low_threshold:
246
+ alpha = 1.0 / boost_strength
247
+ return low_threshold * (estimated_duration / low_threshold) ** alpha
248
+ else:
249
+ return estimated_duration
250
+
251
+
252
+ # ==========================================
253
+ # Example Usage
254
+ # ==========================================
255
+ if __name__ == "__main__":
256
+ estimator = RuleDurationEstimator()
257
+
258
+ ref_txt = "Hello, world."
259
+ ref_dur = 1.5
260
+
261
+ test_cases = [
262
+ ("Hindi (With complex marks)", "नमस्ते दुनिया"),
263
+ ("Arabic (With vowels)", "مَرْحَبًا بِالْعَالَم"),
264
+ ("Vietnamese (Lots of diacritics)", "Chào thế giới"),
265
+ ("Chinese", "你好,世界!"),
266
+ ("Mixed Emoji", "Hello 🌍! This is fun 🎉"),
267
+ ]
268
+
269
+ print("--- Reference ---")
270
+ print(f"Reference Text: '{ref_txt}'")
271
+ print(f"Reference Duration: {ref_dur}s")
272
+ print("-" * 30)
273
+
274
+ for lang, txt in test_cases:
275
+ est_time = estimator.estimate_duration(txt, ref_txt, ref_dur)
276
+ weight = estimator.calculate_total_weight(txt)
277
+
278
+ print(f"[{lang}]")
279
+ print(f"Text: {txt}")
280
+ print(f"Total Weight: {weight:.2f}")
281
+ print(f"Estimated Duration: {est_time:.2f} s")
282
+ print("-" * 30)
runtime/omnivoice/utils/lang_map.py ADDED
@@ -0,0 +1,698 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Language name to ISO 639-3 code mapping.
19
+
20
+ Auto-generated from ``docs/lang_id_name_map.tsv``. Provides ``LANG_NAME_TO_ID``
21
+ (for resolving language names to codes) and ``LANG_IDS`` (the set of supported
22
+ ISO 639-3 codes). Used by ``OmniVoice.generate()`` to resolve user-provided
23
+ language names.
24
+ """
25
+
26
+ # Auto-generated from docs/lang_id_name_map.tsv
27
+ # Maps lowercase language name -> language ID code
28
+
29
+ LANG_NAME_TO_ID = {
30
+ "abadi": "kbt",
31
+ "abkhazian": "ab",
32
+ "abron": "abr",
33
+ "abua": "abn",
34
+ "adamawa fulfulde": "fub",
35
+ "adyghe": "ady",
36
+ "afade": "aal",
37
+ "afrikaans": "af",
38
+ "agwagwune": "yay",
39
+ "aja (benin)": "ajg",
40
+ "akebu": "keu",
41
+ "alago": "ala",
42
+ "albanian": "sq",
43
+ "algerian arabic": "arq",
44
+ "algerian saharan arabic": "aao",
45
+ "ambo-pasco quechua": "qva",
46
+ "ambonese malay": "abs",
47
+ "amdo tibetan": "adx",
48
+ "amharic": "am",
49
+ "anaang": "anw",
50
+ "angika": "anp",
51
+ "antankarana malagasy": "xmv",
52
+ "aragonese": "an",
53
+ "arbëreshë albanian": "aae",
54
+ "arequipa-la unión quechua": "qxu",
55
+ "armenian": "hy",
56
+ "ashe": "ahs",
57
+ "ashéninka perené": "prq",
58
+ "askopan": "eiv",
59
+ "assamese": "as",
60
+ "asturian": "ast",
61
+ "atayal": "tay",
62
+ "awak": "awo",
63
+ "ayacucho quechua": "quy",
64
+ "azerbaijani": "az",
65
+ "baatonum": "bba",
66
+ "bacama": "bcy",
67
+ "bade": "bde",
68
+ "bafia": "ksf",
69
+ "bafut": "bfd",
70
+ "bagirmi fulfulde": "fui",
71
+ "bago-kusuntu": "bqg",
72
+ "baharna arabic": "abv",
73
+ "bakoko": "bkh",
74
+ "balanta-ganja": "bjt",
75
+ "balti": "bft",
76
+ "bamenyam": "bce",
77
+ "bamun": "bax",
78
+ "bangwinji": "bsj",
79
+ "banjar": "bjn",
80
+ "bankon": "abb",
81
+ "baoulé": "bci",
82
+ "bara malagasy": "bhr",
83
+ "barok": "bjk",
84
+ "basa (cameroon)": "bas",
85
+ "basa (nigeria)": "bzw",
86
+ "bashkir": "ba",
87
+ "basque": "eu",
88
+ "batak mandailing": "btm",
89
+ "batanga": "bnm",
90
+ "bateri": "btv",
91
+ "bats": "bbl",
92
+ "bayot": "bda",
93
+ "bebele": "beb",
94
+ "belarusian": "be",
95
+ "bengali": "bn",
96
+ "betawi": "bew",
97
+ "bhili": "bhb",
98
+ "bhojpuri": "bho",
99
+ "bilur": "bxf",
100
+ "bima": "bhp",
101
+ "bodo": "brx",
102
+ "boghom": "bux",
103
+ "bokyi": "bky",
104
+ "bomu": "bmq",
105
+ "bondei": "bou",
106
+ "borgu fulfulde": "fue",
107
+ "bosnian": "bs",
108
+ "brahui": "brh",
109
+ "braj": "bra",
110
+ "breton": "br",
111
+ "buduma": "bdm",
112
+ "buginese": "bug",
113
+ "bukharic": "bhh",
114
+ "bulgarian": "bg",
115
+ "bulu (cameroon)": "bum",
116
+ "bundeli": "bns",
117
+ "bunun": "bnn",
118
+ "bura-pabir": "bwr",
119
+ "burak": "bys",
120
+ "burmese": "my",
121
+ "burushaski": "bsk",
122
+ "cacaloxtepec mixtec": "miu",
123
+ "cajatambo north lima quechua": "qvl",
124
+ "cakfem-mushere": "cky",
125
+ "cameroon pidgin": "wes",
126
+ "campidanese sardinian": "sro",
127
+ "cantonese": "yue",
128
+ "catalan": "ca",
129
+ "cebuano": "ceb",
130
+ "cen": "cen",
131
+ "central kurdish": "ckb",
132
+ "central nahuatl": "nhn",
133
+ "central pame": "pbs",
134
+ "central pashto": "pst",
135
+ "central puebla nahuatl": "ncx",
136
+ "central tarahumara": "tar",
137
+ "central yupik": "esu",
138
+ "central-eastern niger fulfulde": "fuq",
139
+ "chadian arabic": "shu",
140
+ "chichewa": "ny",
141
+ "chichicapan zapotec": "zpv",
142
+ "chiga": "cgg",
143
+ "chimalapa zoque": "zoh",
144
+ "chimborazo highland quichua": "qug",
145
+ "chinese": "zh",
146
+ "chiquián ancash quechua": "qxa",
147
+ "chitwania tharu": "the",
148
+ "chokwe": "cjk",
149
+ "chuvash": "cv",
150
+ "cibak": "ckl",
151
+ "coastal konjo": "kjc",
152
+ "copainalá zoque": "zoc",
153
+ "cornish": "kw",
154
+ "corongo ancash quechua": "qwa",
155
+ "croatian": "hr",
156
+ "cross river mbembe": "mfn",
157
+ "cuyamecalco mixtec": "xtu",
158
+ "czech": "cs",
159
+ "dadiya": "dbd",
160
+ "dagbani": "dag",
161
+ "dameli": "dml",
162
+ "danish": "da",
163
+ "dargwa": "dar",
164
+ "dazaga": "dzg",
165
+ "deccan": "dcc",
166
+ "degema": "deg",
167
+ "dera (nigeria)": "kna",
168
+ "dghwede": "dgh",
169
+ "dhatki": "mki",
170
+ "dhivehi": "dv",
171
+ "dhofari arabic": "adf",
172
+ "dijim-bwilim": "cfa",
173
+ "dogri": "dgo",
174
+ "domaaki": "dmk",
175
+ "dotyali": "dty",
176
+ "duala": "dua",
177
+ "dutch": "nl",
178
+ "dũya": "ldb",
179
+ "dyula": "dyu",
180
+ "eastern balochi": "bgp",
181
+ "eastern bolivian guaraní": "gui",
182
+ "eastern egyptian bedawi arabic": "avl",
183
+ "eastern krahn": "kqo",
184
+ "eastern mari": "mhr",
185
+ "eastern yiddish": "ydd",
186
+ "ebrié": "ebr",
187
+ "eggon": "ego",
188
+ "egyptian arabic": "arz",
189
+ "ejagham": "etu",
190
+ "eleme": "elm",
191
+ "eloyi": "afo",
192
+ "embu": "ebu",
193
+ "english": "en",
194
+ "erzya": "myv",
195
+ "esan": "ish",
196
+ "esperanto": "eo",
197
+ "estonian": "et",
198
+ "eton (cameroon)": "eto",
199
+ "ewondo": "ewo",
200
+ "extremaduran": "ext",
201
+ "fang (equatorial guinea)": "fan",
202
+ "fanti": "fat",
203
+ "farefare": "gur",
204
+ "fe'fe'": "fmp",
205
+ "filipino": "fil",
206
+ "filomena mata-coahuitlán totonac": "tlp",
207
+ "finnish": "fi",
208
+ "fipa": "fip",
209
+ "french": "fr",
210
+ "fulah": "ff",
211
+ "galician": "gl",
212
+ "gambian wolof": "wof",
213
+ "ganda": "lg",
214
+ "garhwali": "gbm",
215
+ "gawar-bati": "gwt",
216
+ "gawri": "gwc",
217
+ "gbagyi": "gbr",
218
+ "gbari": "gby",
219
+ "geji": "gyz",
220
+ "gen": "gej",
221
+ "georgian": "ka",
222
+ "german": "de",
223
+ "geser-gorom": "ges",
224
+ "gheg albanian": "aln",
225
+ "ghomálá'": "bbj",
226
+ "gidar": "gid",
227
+ "glavda": "glw",
228
+ "goan konkani": "gom",
229
+ "goaria": "gig",
230
+ "goemai": "ank",
231
+ "gola": "gol",
232
+ "greek": "el",
233
+ "guarani": "gn",
234
+ "guduf-gava": "gdf",
235
+ "guerrero amuzgo": "amu",
236
+ "gujarati": "gu",
237
+ "gujari": "gju",
238
+ "gulf arabic": "afb",
239
+ "gurgula": "ggg",
240
+ "gusii": "guz",
241
+ "gusilay": "gsl",
242
+ "gweno": "gwe",
243
+ "güilá zapotec": "ztu",
244
+ "hadothi": "hoj",
245
+ "hahon": "hah",
246
+ "haitian": "ht",
247
+ "hakha chin": "cnh",
248
+ "hakö": "hao",
249
+ "halia": "hla",
250
+ "hausa": "ha",
251
+ "hawaiian": "haw",
252
+ "hazaragi": "haz",
253
+ "hebrew": "he",
254
+ "hemba": "hem",
255
+ "herero": "hz",
256
+ "highland konjo": "kjk",
257
+ "hijazi arabic": "acw",
258
+ "hindi": "hi",
259
+ "huarijio": "var",
260
+ "huautla mazatec": "mau",
261
+ "huaxcaleca nahuatl": "nhq",
262
+ "huba": "hbb",
263
+ "huitepec mixtec": "mxs",
264
+ "hula": "hul",
265
+ "hungarian": "hu",
266
+ "hunjara-kaina ke": "hkk",
267
+ "hwana": "hwo",
268
+ "ibibio": "ibb",
269
+ "icelandic": "is",
270
+ "idakho-isukha-tiriki": "ida",
271
+ "idoma": "idu",
272
+ "igbo": "ig",
273
+ "igo": "ahl",
274
+ "ikposo": "kpo",
275
+ "ikwere": "ikw",
276
+ "imbabura highland quichua": "qvi",
277
+ "indonesian": "id",
278
+ "indus kohistani": "mvy",
279
+ "interlingua (international auxiliary language association)": "ia",
280
+ "inupiaq": "ik",
281
+ "irish": "ga",
282
+ "iron ossetic": "os",
283
+ "isekiri": "its",
284
+ "isoko": "iso",
285
+ "italian": "it",
286
+ "ito": "itw",
287
+ "itzá": "itz",
288
+ "ixtayutla mixtec": "vmj",
289
+ "izon": "ijc",
290
+ "jambi malay": "jax",
291
+ "japanese": "ja",
292
+ "jaqaru": "jqr",
293
+ "jauja wanca quechua": "qxw",
294
+ "jaunsari": "jns",
295
+ "javanese": "jv",
296
+ "jiba": "juo",
297
+ "jju": "kaj",
298
+ "judeo-moroccan arabic": "aju",
299
+ "juxtlahuaca mixtec": "vmc",
300
+ "kabardian": "kbd",
301
+ "kabras": "lkb",
302
+ "kabuverdianu": "kea",
303
+ "kabyle": "kab",
304
+ "kachi koli": "gjk",
305
+ "kairak": "ckr",
306
+ "kalabari": "ijn",
307
+ "kalasha": "kls",
308
+ "kalenjin": "kln",
309
+ "kalkoti": "xka",
310
+ "kamba": "kam",
311
+ "kamo": "kcq",
312
+ "kanauji": "bjj",
313
+ "kanembu": "kbl",
314
+ "kannada": "kn",
315
+ "karekare": "kai",
316
+ "kashmiri": "ks",
317
+ "kathoriya tharu": "tkt",
318
+ "kati": "bsh",
319
+ "kazakh": "kk",
320
+ "keiyo": "eyo",
321
+ "khams tibetan": "khg",
322
+ "khana": "ogo",
323
+ "khetrani": "xhe",
324
+ "khmer": "km",
325
+ "khowar": "khw",
326
+ "kinga": "zga",
327
+ "kinnauri": "kfk",
328
+ "kinyarwanda": "rw",
329
+ "kirghiz": "ky",
330
+ "kirya-konzəl": "fkk",
331
+ "kochila tharu": "thq",
332
+ "kohistani shina": "plk",
333
+ "kohumono": "bcs",
334
+ "kok borok": "trp",
335
+ "kol (papua new guinea)": "kol",
336
+ "kom (cameroon)": "bkm",
337
+ "koma": "kmy",
338
+ "konkani": "knn",
339
+ "konzo": "koo",
340
+ "korean": "ko",
341
+ "korwa": "kfp",
342
+ "kota (india)": "kfe",
343
+ "koti": "eko",
344
+ "kuanua": "ksd",
345
+ "kuanyama": "kj",
346
+ "kui (india)": "uki",
347
+ "kulung (nigeria)": "bbu",
348
+ "kuot": "kto",
349
+ "kushi": "kuh",
350
+ "kwambi": "kwm",
351
+ "kwasio": "nmg",
352
+ "lala-roba": "lla",
353
+ "lamang": "hia",
354
+ "lao": "lo",
355
+ "larike-wakasihu": "alo",
356
+ "lasi": "lss",
357
+ "latgalian": "ltg",
358
+ "latvian": "lv",
359
+ "levantine arabic": "apc",
360
+ "liana-seti": "ste",
361
+ "liberia kpelle": "xpe",
362
+ "liberian english": "lir",
363
+ "libyan arabic": "ayl",
364
+ "ligurian": "lij",
365
+ "lijili": "mgi",
366
+ "lingala": "ln",
367
+ "lithuanian": "lt",
368
+ "loarki": "lrk",
369
+ "logooli": "rag",
370
+ "logudorese sardinian": "src",
371
+ "loja highland quichua": "qvj",
372
+ "loloda": "loa",
373
+ "longuda": "lnu",
374
+ "loxicha zapotec": "ztp",
375
+ "luba-lulua": "lua",
376
+ "luo": "luo",
377
+ "lushai": "lus",
378
+ "luxembourgish": "lb",
379
+ "maasina fulfulde": "ffm",
380
+ "maba (chad)": "mde",
381
+ "macedo-romanian": "rup",
382
+ "macedonian": "mk",
383
+ "mada (cameroon)": "mxu",
384
+ "mafa": "maf",
385
+ "maithili": "mai",
386
+ "malay": "ms",
387
+ "malayalam": "ml",
388
+ "mali": "gcc",
389
+ "malinaltepec me'phaa": "tcf",
390
+ "maltese": "mt",
391
+ "mandara": "tbf",
392
+ "mandjak": "mfv",
393
+ "manggarai": "mqy",
394
+ "manipuri": "mni",
395
+ "mansoanka": "msw",
396
+ "manx": "gv",
397
+ "maori": "mi",
398
+ "marathi": "mr",
399
+ "marghi central": "mrt",
400
+ "marghi south": "mfm",
401
+ "maria (india)": "mrr",
402
+ "marwari (pakistan)": "mve",
403
+ "masana": "mcn",
404
+ "masikoro malagasy": "msh",
405
+ "matsés": "mcf",
406
+ "mazaltepec zapotec": "zpy",
407
+ "mazatlán mazatec": "vmz",
408
+ "mazatlán mixe": "mzl",
409
+ "mbe": "mfo",
410
+ "mbo (cameroon)": "mbo",
411
+ "mbum": "mdd",
412
+ "medumba": "byv",
413
+ "mekeo": "mek",
414
+ "meru": "mer",
415
+ "mesopotamian arabic": "acm",
416
+ "mewari": "mtr",
417
+ "min nan chinese": "nan",
418
+ "mingrelian": "xmf",
419
+ "mitlatongo mixtec": "vmm",
420
+ "miya": "mkf",
421
+ "mokpwe": "bri",
422
+ "moksha": "mdf",
423
+ "mom jango": "ver",
424
+ "mongolian": "mn",
425
+ "moroccan arabic": "ary",
426
+ "motu": "meu",
427
+ "mpiemo": "mcx",
428
+ "mpumpong": "mgg",
429
+ "mundang": "mua",
430
+ "mungaka": "mhk",
431
+ "musey": "mse",
432
+ "musgu": "mug",
433
+ "musi": "mui",
434
+ "naba": "mne",
435
+ "najdi arabic": "ars",
436
+ "nalik": "nal",
437
+ "nawdm": "nmz",
438
+ "ndonga": "ng",
439
+ "neapolitan": "nap",
440
+ "nepali": "npi",
441
+ "ngamo": "nbh",
442
+ "ngas": "anc",
443
+ "ngiemboon": "nnh",
444
+ "ngizim": "ngi",
445
+ "ngomba": "jgo",
446
+ "ngombale": "nla",
447
+ "nigerian fulfulde": "fuv",
448
+ "nigerian pidgin": "pcm",
449
+ "nimadi": "noe",
450
+ "nobiin": "fia",
451
+ "north mesopotamian arabic": "ayp",
452
+ "north moluccan malay": "max",
453
+ "northern betsimisaraka malagasy": "bmm",
454
+ "northern hindko": "hno",
455
+ "northern kurdish": "kmr",
456
+ "northern pame": "pmq",
457
+ "northern pashto": "pbu",
458
+ "northern uzbek": "uzn",
459
+ "northwest gbaya": "gya",
460
+ "norwegian": "no",
461
+ "norwegian bokmål": "nb",
462
+ "norwegian nynorsk": "nn",
463
+ "notsi": "ncf",
464
+ "nyankpa": "yes",
465
+ "nyungwe": "nyu",
466
+ "nzanyi": "nja",
467
+ "nüpode huitoto": "hux",
468
+ "occitan": "oc",
469
+ "od": "odk",
470
+ "odia": "ory",
471
+ "odual": "odu",
472
+ "omani arabic": "acx",
473
+ "orizaba nahuatl": "nlv",
474
+ "orma": "orc",
475
+ "ormuri": "oru",
476
+ "oromo": "om",
477
+ "pahari-potwari": "phr",
478
+ "paiwan": "pwn",
479
+ "panjabi": "pa",
480
+ "papuan malay": "pmy",
481
+ "parkari koli": "kvx",
482
+ "pedi": "nso",
483
+ "pero": "pip",
484
+ "persian": "fa",
485
+ "petats": "pex",
486
+ "phalura": "phl",
487
+ "piemontese": "pms",
488
+ "piya-kwonci": "piy",
489
+ "plateau malagasy": "plt",
490
+ "polish": "pl",
491
+ "poqomam": "poc",
492
+ "portuguese": "pt",
493
+ "pulaar": "fuc",
494
+ "pular": "fuf",
495
+ "puno quechua": "qxp",
496
+ "pushto": "ps",
497
+ "pökoot": "pko",
498
+ "qaqet": "byx",
499
+ "quiotepec chinantec": "chq",
500
+ "rana tharu": "thr",
501
+ "rangi": "lag",
502
+ "rapoisi": "kyx",
503
+ "ratahan": "rth",
504
+ "rayón zoque": "zor",
505
+ "romanian": "ro",
506
+ "romansh": "rm",
507
+ "rombo": "rof",
508
+ "rotokas": "roo",
509
+ "rukai": "dru",
510
+ "russian": "ru",
511
+ "sacapulteco": "quv",
512
+ "saidi arabic": "aec",
513
+ "sakalava malagasy": "skg",
514
+ "sakizaya": "szy",
515
+ "saleman": "sau",
516
+ "samba daka": "ccg",
517
+ "samba leko": "ndi",
518
+ "san felipe otlaltepec popoloca": "pow",
519
+ "san francisco del mar huave": "hue",
520
+ "san juan atzingo popoloca": "poe",
521
+ "san martín itunyoso triqui": "trq",
522
+ "san miguel el grande mixtec": "mig",
523
+ "sansi": "ssi",
524
+ "sanskrit": "sa",
525
+ "santa ana de tusi pasco quechua": "qxt",
526
+ "santa catarina albarradas zapotec": "ztn",
527
+ "santali": "sat",
528
+ "santiago del estero quichua": "qus",
529
+ "saposa": "sps",
530
+ "saraiki": "skr",
531
+ "sardinian": "sc",
532
+ "saya": "say",
533
+ "sediq": "trv",
534
+ "serbian": "sr",
535
+ "seri": "sei",
536
+ "shina": "scl",
537
+ "shona": "sn",
538
+ "siar-lak": "sjr",
539
+ "sibe": "nco",
540
+ "sicilian": "scn",
541
+ "sihuas ancash quechua": "qws",
542
+ "sikkimese": "sip",
543
+ "sinaugoro": "snc",
544
+ "sindhi": "sd",
545
+ "sindhi bhil": "sbn",
546
+ "sinhala": "si",
547
+ "sinicahua mixtec": "xti",
548
+ "sipacapense": "qum",
549
+ "siwai": "siw",
550
+ "slovak": "sk",
551
+ "slovenian": "sl",
552
+ "solos": "sol",
553
+ "somali": "so",
554
+ "soninke": "snk",
555
+ "south giziga": "giz",
556
+ "south ucayali ashéninka": "cpy",
557
+ "southeastern nochixtlán mixtec": "mxy",
558
+ "southern betsimisaraka malagasy": "bzc",
559
+ "southern pashto": "pbt",
560
+ "southern pastaza quechua": "qup",
561
+ "soyaltepec mazatec": "vmp",
562
+ "spanish": "es",
563
+ "standard arabic": "arb",
564
+ "standard moroccan tamazight": "zgh",
565
+ "sudanese arabic": "apd",
566
+ "sulka": "sua",
567
+ "svan": "sva",
568
+ "swahili": "sw",
569
+ "swedish": "sv",
570
+ "tae'": "rob",
571
+ "tahaggart tamahaq": "thv",
572
+ "taita": "dav",
573
+ "tajik": "tg",
574
+ "tamil": "ta",
575
+ "tandroy-mahafaly malagasy": "tdx",
576
+ "tangale": "tan",
577
+ "tanosy malagasy": "txy",
578
+ "tarok": "yer",
579
+ "tatar": "tt",
580
+ "tedaga": "tuq",
581
+ "telugu": "te",
582
+ "tem": "kdh",
583
+ "teop": "tio",
584
+ "tepeuxila cuicatec": "cux",
585
+ "tepinapa chinantec": "cte",
586
+ "tera": "ttr",
587
+ "terei": "buo",
588
+ "termanu": "twu",
589
+ "tesaka malagasy": "tkg",
590
+ "tetelcingo nahuatl": "nhg",
591
+ "teutila cuicatec": "cut",
592
+ "thai": "th",
593
+ "tibetan": "bo",
594
+ "tidaá mixtec": "mtx",
595
+ "tidore": "tvo",
596
+ "tigak": "tgc",
597
+ "tigre": "tig",
598
+ "tigrinya": "ti",
599
+ "tilquiapan zapotec": "zts",
600
+ "tinputz": "tpz",
601
+ "tlacoapa me'phaa": "tpl",
602
+ "tlacoatzintepec chinantec": "ctl",
603
+ "tlingit": "tli",
604
+ "toki pona": "tok",
605
+ "tomoip": "tqp",
606
+ "tondano": "tdn",
607
+ "tonsea": "txs",
608
+ "tooro": "ttj",
609
+ "torau": "ttu",
610
+ "torwali": "trw",
611
+ "tsimihety malagasy": "xmw",
612
+ "tsotso": "lto",
613
+ "tswana": "tn",
614
+ "tugen": "tuy",
615
+ "tuki": "bag",
616
+ "tula": "tul",
617
+ "tulu": "tcy",
618
+ "tunen": "tvu",
619
+ "tungag": "lcm",
620
+ "tunisian arabic": "aeb",
621
+ "tupuri": "tui",
622
+ "turkana": "tuv",
623
+ "turkish": "tr",
624
+ "turkmen": "tk",
625
+ "tututepec mixtec": "mtu",
626
+ "twi": "tw",
627
+ "ubaghara": "byc",
628
+ "uighur": "ug",
629
+ "ukrainian": "uk",
630
+ "umbundu": "umb",
631
+ "upper sorbian": "hsb",
632
+ "urdu": "ur",
633
+ "ushojo": "ush",
634
+ "uzbek": "uz",
635
+ "vai": "vai",
636
+ "vietnamese": "vi",
637
+ "votic": "vot",
638
+ "võro": "vro",
639
+ "waci gbe": "wci",
640
+ "wadiyara koli": "kxp",
641
+ "waja": "wja",
642
+ "wakhi": "wbl",
643
+ "wanga": "lwg",
644
+ "wapan": "juk",
645
+ "warji": "wji",
646
+ "welsh": "cy",
647
+ "wemale": "weo",
648
+ "western frisian": "fy",
649
+ "western highland purepecha": "pua",
650
+ "western juxtlahuaca mixtec": "jmx",
651
+ "western maninkakan": "mlq",
652
+ "western mari": "mrj",
653
+ "western niger fulfulde": "fuh",
654
+ "western panjabi": "pnb",
655
+ "wolof": "wo",
656
+ "wuzlam": "udl",
657
+ "xanaguía zapotec": "ztg",
658
+ "xhosa": "xh",
659
+ "yace": "ekr",
660
+ "yakut": "sah",
661
+ "yalahatan": "jal",
662
+ "yanahuanca pasco quechua": "qur",
663
+ "yangben": "yav",
664
+ "yaqui": "yaq",
665
+ "yauyos quechua": "qux",
666
+ "yekhee": "ets",
667
+ "yiddish": "yi",
668
+ "yidgha": "ydg",
669
+ "yoruba": "yo",
670
+ "yutanduchi mixtec": "mab",
671
+ "zacatlán-ahuacatlán-tepetzintla nahuatl": "nhi",
672
+ "zarma": "dje",
673
+ "zaza": "zza",
674
+ "zulu": "zu",
675
+ "ömie": "aom",
676
+ }
677
+
678
+ LANG_NAMES = set(LANG_NAME_TO_ID.keys())
679
+ LANG_IDS = set(LANG_NAME_TO_ID.values())
680
+
681
+ # Exceptions where .title() doesn't match the canonical casing from the TSV.
682
+ _TITLE_EXCEPTIONS = {
683
+ "fe'fe'": "Fe'fe'",
684
+ "dũya": "Dũya",
685
+ "santiago del estero quichua": "Santiago del Estero Quichua",
686
+ "santa ana de tusi pasco quechua": "Santa Ana de Tusi Pasco Quechua",
687
+ "malinaltepec me'phaa": "Malinaltepec Me'phaa",
688
+ "tlacoapa me'phaa": "Tlacoapa Me'phaa",
689
+ }
690
+
691
+
692
+ def lang_display_name(name: str) -> str:
693
+ """Return a display-friendly version of a lowercase language name.
694
+
695
+ Uses .title() for most names, with manual exceptions for cases like
696
+ apostrophes and small words (de, del) that should stay lowercase.
697
+ """
698
+ return _TITLE_EXCEPTIONS.get(name, name.title())
runtime/omnivoice/utils/text.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Text processing utilities for TTS inference.
19
+
20
+ Provides:
21
+ - ``chunk_text_punctuation()``: Splits long text into model-friendly chunks at
22
+ sentence boundaries, with abbreviation-aware punctuation splitting.
23
+ - ``add_punctuation()``: Appends missing end punctuation (Chinese or English).
24
+ """
25
+
26
+ from typing import List, Optional
27
+
28
+
29
+ SPLIT_PUNCTUATION = set(".,;:!?。,;:!?")
30
+ CLOSING_MARKS = set("\"'""')]》》>」】")
31
+
32
+ END_PUNCTUATION = {
33
+ ";",
34
+ ":",
35
+ ",",
36
+ ".",
37
+ "!",
38
+ "?",
39
+ "…",
40
+ ")",
41
+ "]",
42
+ "}",
43
+ '"',
44
+ "'",
45
+ """,
46
+ "'",
47
+ ";",
48
+ ":",
49
+ ",",
50
+ "。",
51
+ "!",
52
+ "?",
53
+ "、",
54
+ "……",
55
+ ")",
56
+ "】",
57
+ """,
58
+ "'",
59
+ }
60
+
61
+
62
+ ABBREVIATIONS = {
63
+ "Mr.",
64
+ "Mrs.",
65
+ "Ms.",
66
+ "Dr.",
67
+ "Prof.",
68
+ "Sr.",
69
+ "Jr.",
70
+ "Rev.",
71
+ "Fr.",
72
+ "Hon.",
73
+ "Pres.",
74
+ "Gov.",
75
+ "Capt.",
76
+ "Gen.",
77
+ "Sen.",
78
+ "Rep.",
79
+ "Col.",
80
+ "Maj.",
81
+ "Lt.",
82
+ "Cmdr.",
83
+ "Sgt.",
84
+ "Cpl.",
85
+ "Co.",
86
+ "Corp.",
87
+ "Inc.",
88
+ "Ltd.",
89
+ "Est.",
90
+ "Dept.",
91
+ "St.",
92
+ "Ave.",
93
+ "Blvd.",
94
+ "Rd.",
95
+ "Mt.",
96
+ "Ft.",
97
+ "No.",
98
+ "Jan.",
99
+ "Feb.",
100
+ "Mar.",
101
+ "Apr.",
102
+ "Aug.",
103
+ "Sep.",
104
+ "Sept.",
105
+ "Oct.",
106
+ "Nov.",
107
+ "Dec.",
108
+ "i.e.",
109
+ "e.g.",
110
+ "vs.",
111
+ "Vs.",
112
+ "Etc.",
113
+ "approx.",
114
+ "fig.",
115
+ "def.",
116
+ }
117
+
118
+
119
+ def chunk_text_punctuation(
120
+ text: str,
121
+ chunk_len: int,
122
+ min_chunk_len: Optional[int] = None,
123
+ ) -> List[str]:
124
+ """
125
+ Splits the input tokens list into chunks according to punctuations,
126
+ avoiding splits on common abbreviations (e.g., Mr., No.).
127
+ """
128
+
129
+ # 1. Split the tokens according to punctuations.
130
+ sentences = []
131
+ current_sentence = []
132
+
133
+ tokens_list = list(text)
134
+
135
+ for token in tokens_list:
136
+ # If the first token of current sentence is punctuation,
137
+ # append it to the end of the previous sentence.
138
+ if (
139
+ len(current_sentence) == 0
140
+ and len(sentences) != 0
141
+ and (token in SPLIT_PUNCTUATION or token in CLOSING_MARKS)
142
+ ):
143
+ sentences[-1].append(token)
144
+ # Otherwise, append the current token to the current sentence.
145
+ else:
146
+ current_sentence.append(token)
147
+
148
+ # Split the sentence in positions of punctuations.
149
+ if token in SPLIT_PUNCTUATION:
150
+ is_abbreviation = False
151
+
152
+ if token == ".":
153
+ temp_str = "".join(current_sentence).strip()
154
+ if temp_str:
155
+ last_word = temp_str.split()[-1]
156
+ if last_word in ABBREVIATIONS:
157
+ is_abbreviation = True
158
+
159
+ if not is_abbreviation:
160
+ sentences.append(current_sentence)
161
+ current_sentence = []
162
+ # Assume the last few tokens are also a sentence
163
+ if len(current_sentence) != 0:
164
+ sentences.append(current_sentence)
165
+
166
+ # 2. Merge short sentences.
167
+ merged_chunks = []
168
+ current_chunk = []
169
+ for sentence in sentences:
170
+ if len(current_chunk) + len(sentence) <= chunk_len:
171
+ current_chunk.extend(sentence)
172
+ else:
173
+ if len(current_chunk) > 0:
174
+ merged_chunks.append(current_chunk)
175
+ current_chunk = sentence
176
+
177
+ if len(current_chunk) > 0:
178
+ merged_chunks.append(current_chunk)
179
+
180
+ # 4. Post-process: Check for undersized chunks and merge them
181
+ # with the previous chunk or next chunk (if it's the first chunk).
182
+ if min_chunk_len is not None:
183
+ first_chunk_short_flag = (
184
+ len(merged_chunks) > 0 and len(merged_chunks[0]) < min_chunk_len
185
+ )
186
+ final_chunks = []
187
+ for i, chunk in enumerate(merged_chunks):
188
+ if i == 1 and first_chunk_short_flag:
189
+ final_chunks[-1].extend(chunk)
190
+ else:
191
+ if len(chunk) >= min_chunk_len:
192
+ final_chunks.append(chunk)
193
+ else:
194
+ if len(final_chunks) == 0:
195
+ final_chunks.append(chunk)
196
+ else:
197
+ final_chunks[-1].extend(chunk)
198
+ else:
199
+ final_chunks = merged_chunks
200
+
201
+ chunk_strings = [
202
+ "".join(chunk).strip() for chunk in final_chunks if "".join(chunk).strip()
203
+ ]
204
+ return chunk_strings
205
+
206
+
207
+ def add_punctuation(text: str):
208
+ """Add punctuation if there is not in the end of text"""
209
+ text = text.strip()
210
+
211
+ if not text:
212
+ return text
213
+
214
+ if text[-1] not in END_PUNCTUATION:
215
+ is_chinese = any("\u4e00" <= char <= "\u9fff" for char in text)
216
+
217
+ text += "。" if is_chinese else "."
218
+
219
+ return text
runtime/omnivoice/utils/voice_design.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Voice-design instruct constants for TTS inference.
19
+
20
+ Defines speaker attribute tags (gender, age, pitch, accent, dialect) and
21
+ translation/validation utilities between English and Chinese. Used by
22
+ ``OmniVoice.generate()`` for voice design mode.
23
+ """
24
+
25
+ import re
26
+
27
+ _ZH_RE = re.compile(r'[\u4e00-\u9fff]')
28
+
29
+ # Category = set of {english: chinese, ...} items that are mutually exclusive.
30
+ # Accent (EN-only) and dialect (ZH-only) are stored as flat sets below.
31
+ _INSTRUCT_CATEGORIES = [
32
+ {"male": "男", "female": "女"},
33
+ {"child": "儿童", "teenager": "少年", "young adult": "青年",
34
+ "middle-aged": "中年", "elderly": "老年"},
35
+ {"very low pitch": "极低音调", "low pitch": "低音调",
36
+ "moderate pitch": "中音调", "high pitch": "高音调",
37
+ "very high pitch": "极高音调"},
38
+ {"whisper": "耳语"},
39
+ # Accent (English-only, no Chinese counterpart)
40
+ {"american accent", "british accent", "australian accent",
41
+ "chinese accent", "canadian accent", "indian accent",
42
+ "korean accent", "portuguese accent", "russian accent", "japanese accent",
43
+ "armenian accent", "eastern armenian accent", "western armenian accent",
44
+ "yerevan accent"},
45
+ # Dialect (Chinese-only, no English counterpart)
46
+ {"河南话", "陕西话", "四川话", "贵州话", "云南话", "桂林话",
47
+ "济南话", "石家庄话", "甘肃话", "宁夏话", "青岛话", "东北话"},
48
+ ]
49
+
50
+ _INSTRUCT_EN_TO_ZH = {}
51
+ _INSTRUCT_ZH_TO_EN = {}
52
+ _INSTRUCT_MUTUALLY_EXCLUSIVE = []
53
+ for _cat in _INSTRUCT_CATEGORIES:
54
+ if isinstance(_cat, dict):
55
+ _INSTRUCT_EN_TO_ZH.update(_cat)
56
+ _INSTRUCT_ZH_TO_EN.update({v: k for k, v in _cat.items()})
57
+ _INSTRUCT_MUTUALLY_EXCLUSIVE.append(set(_cat) | set(_cat.values()))
58
+ else:
59
+ _INSTRUCT_MUTUALLY_EXCLUSIVE.append(set(_cat))
60
+
61
+ _INSTRUCT_ALL_VALID = (
62
+ set(_INSTRUCT_EN_TO_ZH) | set(_INSTRUCT_ZH_TO_EN)
63
+ | _INSTRUCT_MUTUALLY_EXCLUSIVE[-2] # accents
64
+ | _INSTRUCT_MUTUALLY_EXCLUSIVE[-1] # dialects
65
+ )
66
+
67
+ _INSTRUCT_VALID_EN = frozenset(i for i in _INSTRUCT_ALL_VALID if not _ZH_RE.search(i))
68
+ _INSTRUCT_VALID_ZH = frozenset(i for i in _INSTRUCT_ALL_VALID if _ZH_RE.search(i))
runtime/pyproject.toml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=68", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "avoice-runtime"
7
+ version = "0.1.0"
8
+ description = "Runtime package for the AVoice Armenian TTS model."
9
+ readme = "README.md"
10
+ requires-python = ">=3.10"
11
+ license = { text = "GPL-2.0" }
12
+ dependencies = ["numpy>=1.26","torch>=2.4","torchaudio>=2.4","transformers>=5.5.0","huggingface_hub>=0.24","soundfile>=0.12","sentencepiece>=0.2","fastapi>=0.115","uvicorn>=0.30","pydub","librosa"]
13
+
14
+ [project.scripts]
15
+ avoice = "omnivoice.cli.infer:main"
16
+ omnivoice-infer = "omnivoice.cli.infer:main"
17
+ omnivoice-api = "omnivoice.server.app:main"
18
+ omnivoice-prefetch = "omnivoice.server.prefetch:main"
19
+
20
+ [tool.setuptools.packages.find]
21
+ include = ["omnivoice*"]
runtime/requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy>=1.26
2
+ torch>=2.4
3
+ torchaudio>=2.4
4
+ transformers>=5.5.0
5
+ huggingface_hub>=0.24
6
+ soundfile>=0.12
7
+ sentencepiece>=0.2
8
+ fastapi>=0.115
9
+ uvicorn>=0.30
10
+ pydub
11
+ librosa