zhu-han committed on
Commit
aa79b9c
·
verified ·
1 Parent(s): 72536e4

Upload 48 files

Browse files
Files changed (48) hide show
  1. app.py +128 -0
  2. omnivoice/__init__.py +28 -0
  3. omnivoice/cli/__init__.py +0 -0
  4. omnivoice/cli/demo.py +533 -0
  5. omnivoice/cli/infer.py +157 -0
  6. omnivoice/cli/infer_batch.py +523 -0
  7. omnivoice/cli/train.py +74 -0
  8. omnivoice/data/__init__.py +0 -0
  9. omnivoice/data/batching.py +166 -0
  10. omnivoice/data/collator.py +92 -0
  11. omnivoice/data/dataset.py +551 -0
  12. omnivoice/data/processor.py +258 -0
  13. omnivoice/eval/__init__.py +4 -0
  14. omnivoice/eval/models/ecapa_tdnn_wavlm.py +374 -0
  15. omnivoice/eval/models/utmos.py +370 -0
  16. omnivoice/eval/mos/utmos.py +299 -0
  17. omnivoice/eval/speaker_similarity/sim.py +321 -0
  18. omnivoice/eval/utils.py +80 -0
  19. omnivoice/eval/wer/common.py +88 -0
  20. omnivoice/eval/wer/fleurs.py +517 -0
  21. omnivoice/eval/wer/hubert.py +318 -0
  22. omnivoice/eval/wer/minimax.py +596 -0
  23. omnivoice/eval/wer/norm_config_module.py +291 -0
  24. omnivoice/eval/wer/punctuations.lst +188 -0
  25. omnivoice/eval/wer/seedtts.py +413 -0
  26. omnivoice/eval/wer/sensevoice.py +344 -0
  27. omnivoice/eval/wer/text_norm_omni.py +113 -0
  28. omnivoice/models/__init__.py +0 -0
  29. omnivoice/models/omnivoice.py +1502 -0
  30. omnivoice/scripts/__init__.py +0 -0
  31. omnivoice/scripts/denoise_audio.py +1048 -0
  32. omnivoice/scripts/extract_audio_tokens.py +625 -0
  33. omnivoice/scripts/extract_audio_tokens_add_noise.py +825 -0
  34. omnivoice/scripts/jsonl_to_webdataset.py +439 -0
  35. omnivoice/training/__init__.py +0 -0
  36. omnivoice/training/builder.py +180 -0
  37. omnivoice/training/checkpoint.py +180 -0
  38. omnivoice/training/config.py +98 -0
  39. omnivoice/training/trainer.py +342 -0
  40. omnivoice/utils/__init__.py +0 -0
  41. omnivoice/utils/audio.py +355 -0
  42. omnivoice/utils/common.py +56 -0
  43. omnivoice/utils/data_utils.py +63 -0
  44. omnivoice/utils/duration.py +282 -0
  45. omnivoice/utils/lang_map.py +698 -0
  46. omnivoice/utils/text.py +219 -0
  47. omnivoice/utils/voice_design.py +66 -0
  48. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
HuggingFace Space entry point for OmniVoice demo.

The Spaces runtime imports this module to find the module-level ``demo``
object, so the model is loaded once at import time and the UI is built
from the shared implementation in ``omnivoice.cli.demo``.
"""

import logging
import os
import tempfile
from typing import Any, Dict

import torch
import torchaudio

from omnivoice import OmniVoice, OmniVoiceGenerationConfig
from omnivoice.cli.demo import build_demo

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# ---------------------------------------------------------------------------
# Hardware detection
# ---------------------------------------------------------------------------
# NOTE(review): unlike omnivoice.cli.demo.get_best_device, Apple MPS is not
# considered here — presumably because HF Spaces only offer CUDA or CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {DEVICE}")

# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
# The checkpoint can be overridden via the OMNIVOICE_MODEL env variable.
CHECKPOINT = os.environ.get("OMNIVOICE_MODEL", "k2-fsa/OmniVoice")

logger.info(f"Loading model from {CHECKPOINT} on {DEVICE} ...")
model = OmniVoice.from_pretrained(
    CHECKPOINT,
    device_map=DEVICE,
    dtype=torch.float16,
    # ASR enables auto-transcription of reference audio when the user
    # leaves the reference text empty (see the demo UI help text).
    load_asr=True,
)
logger.info("Model loaded on %s.", DEVICE)
# Output sampling rate of the model; reused when saving generated audio.
sampling_rate = model.sampling_rate

47
def _gen_core(
    text,
    language,
    ref_audio,
    instruct,
    num_step,
    guidance_scale,
    denoise,
    speed,
    duration,
    preprocess_prompt,
    postprocess_output,
    mode,
    ref_text=None,
):
    """Shared generation core for the Space UI.

    Builds an ``OmniVoiceGenerationConfig`` from raw widget values, assembles
    the keyword arguments for ``model.generate`` according to ``mode``
    ("clone" uses ``ref_audio``/``ref_text``, "design" uses ``instruct``),
    and writes the result to a temporary WAV file.

    Returns:
        Tuple ``(out_path, status)``: the path of the generated WAV file
        (or ``None`` on failure) and a human-readable status message.
    """
    if not text or not text.strip():
        return None, "Please enter the text to synthesize."

    # Fall back to sensible defaults when widgets hand us None.
    gen_config = OmniVoiceGenerationConfig(
        num_step=int(num_step or 32),
        guidance_scale=float(guidance_scale) if guidance_scale is not None else 2.0,
        denoise=bool(denoise) if denoise is not None else True,
        preprocess_prompt=bool(preprocess_prompt),
        postprocess_output=bool(postprocess_output),
    )

    # "Auto" in the dropdown means: let the model detect the language.
    lang = language if (language and language != "Auto") else None

    kw: Dict[str, Any] = dict(
        text=text.strip(), language=lang, generation_config=gen_config
    )

    # Speed/duration are only forwarded when they deviate from defaults;
    # a positive duration overrides speed inside the model.
    if speed is not None and float(speed) != 1.0:
        kw["speed"] = float(speed)
    if duration is not None and float(duration) > 0:
        kw["duration"] = float(duration)

    if mode == "clone":
        if not ref_audio:
            return None, "Please upload a reference audio."
        kw["voice_clone_prompt"] = model.create_voice_clone_prompt(
            ref_audio=ref_audio,
            ref_text=ref_text,
        )

    if mode == "design":
        if instruct and instruct.strip():
            kw["instruct"] = instruct.strip()

    try:
        audio = model.generate(**kw)
        # Create the output file only after generation succeeded, so a
        # failed run does not leave an orphan temp file behind. The context
        # manager closes the handle immediately; delete=False keeps the
        # file on disk for Gradio to serve.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            out_path = f.name
        torchaudio.save(out_path, audio[0], sampling_rate)
    except Exception as e:
        # Surface the failure to the UI instead of crashing the worker.
        return None, f"Error: {type(e).__name__}: {e}"

    return out_path, "Done."
104
+
105
+
106
# ---------------------------------------------------------------------------
# ZeroGPU wrapper
# ---------------------------------------------------------------------------
# On ZeroGPU Spaces the `spaces` package is importable and GPU work must run
# inside a @spaces.GPU()-decorated function. Elsewhere the import fails and
# generate_fn stays None, in which case build_demo falls back to its own
# internal generation core.
generate_fn = None
try:
    import spaces

    @spaces.GPU()
    def _gen_gpu(*args, **kwargs):
        # Thin pass-through so the decorator wraps the real core above.
        return _gen_core(*args, **kwargs)

    generate_fn = _gen_gpu
    logger.info("Using spaces.GPU() wrapper.")
except ImportError:
    logger.info("spaces module not found, running without GPU wrapper.")

# ---------------------------------------------------------------------------
# Build and launch demo — reuses the full UI from omnivoice.cli.demo
# ---------------------------------------------------------------------------
demo = build_demo(model, CHECKPOINT, generate_fn=generate_fn)

if __name__ == "__main__":
    demo.queue().launch()
omnivoice/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import warnings
from importlib.metadata import PackageNotFoundError, version

# Silence known-noisy warnings from third-party dependencies before the
# heavy imports below trigger them.
warnings.filterwarnings("ignore", module="torchaudio")
warnings.filterwarnings(
    "ignore",
    category=SyntaxWarning,
    message="invalid escape sequence",
    module="pydub.utils",
)
warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    module="torch.distributed.algorithms.ddp_comm_hooks",
)

# Resolve the installed package version; fall back when running from a
# source tree that was never pip-installed.
try:
    __version__ = version("omnivoice")
except PackageNotFoundError:
    __version__ = "0.0.0"

# Re-export the public API at package level (import kept after the
# warning filters on purpose).
from omnivoice.models.omnivoice import (  # noqa: E402
    OmniVoice,
    OmniVoiceConfig,
    OmniVoiceGenerationConfig,
)

__all__ = ["OmniVoice", "OmniVoiceConfig", "OmniVoiceGenerationConfig"]
omnivoice/cli/__init__.py ADDED
File without changes
omnivoice/cli/demo.py ADDED
@@ -0,0 +1,533 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """
18
+ Gradio demo for OmniVoice.
19
+
20
+ Supports voice cloning and voice design.
21
+
22
+ Usage:
23
+ omnivoice-demo --model /path/to/checkpoint --port 8000
24
+ """
25
+
26
+ import argparse
27
+ import logging
28
+ from typing import Any, Dict
29
+
30
+ import gradio as gr
31
+ import numpy as np
32
+ import torch
33
+
34
+ from omnivoice import OmniVoice, OmniVoiceGenerationConfig
35
+ from omnivoice.utils.lang_map import LANG_NAMES, lang_display_name
36
+
37
+
38
def get_best_device():
    """Pick the preferred torch device, in order: CUDA, Apple MPS, CPU."""
    candidates = (
        ("cuda", torch.cuda.is_available),
        ("mps", torch.backends.mps.is_available),
    )
    for name, is_available in candidates:
        if is_available():
            return name
    return "cpu"
45
+
46
+
47
# ---------------------------------------------------------------------------
# Language list — all 600+ supported languages
# ---------------------------------------------------------------------------
# "Auto" is prepended so the UI can offer automatic language detection;
# the remaining entries are human-readable display names, sorted.
_ALL_LANGUAGES = ["Auto"] + sorted(lang_display_name(n) for n in LANG_NAMES)


# ---------------------------------------------------------------------------
# Voice Design instruction templates
# ---------------------------------------------------------------------------
# Each option is displayed as "English / 中文".
# The model expects English for accents and Chinese for dialects.
# These labels are split on " / " by _build_instruct in build_demo, so the
# separator must stay exactly " / ".
_CATEGORIES = {
    "Gender / 性别": ["Male / 男", "Female / 女"],
    "Age / 年龄": [
        "Child / 儿童",
        "Teenager / 少年",
        "Young Adult / 青年",
        "Middle-aged / 中年",
        "Elderly / 老年",
    ],
    "Pitch / 音调": [
        "Very Low Pitch / 极低音调",
        "Low Pitch / 低音调",
        "Moderate Pitch / 中音调",
        "High Pitch / 高音调",
        "Very High Pitch / 极高音调",
    ],
    "Style / 风格": ["Whisper / 耳语"],
    "English Accent / 英文口音": [
        "American Accent / 美式口音",
        "Australian Accent / 澳大利亚口音",
        "British Accent / 英国口音",
        "Chinese Accent / 中国口音",
        "Canadian Accent / 加拿大口音",
        "Indian Accent / 印度口音",
        "Korean Accent / 韩国口音",
        "Portuguese Accent / 葡萄牙口音",
        "Russian Accent / 俄罗斯口音",
        "Japanese Accent / 日本口音",
    ],
    "Chinese Dialect / 中文方言": [
        "Henan Dialect / 河南话",
        "Shaanxi Dialect / 陕西话",
        "Sichuan Dialect / 四川话",
        "Guizhou Dialect / 贵州话",
        "Yunnan Dialect / 云南话",
        "Guilin Dialect / 桂林话",
        "Jinan Dialect / 济南话",
        "Shijiazhuang Dialect / 石家庄话",
        "Gansu Dialect / 甘肃话",
        "Ningxia Dialect / 宁夏话",
        "Qingdao Dialect / 青岛话",
        "Northeast Dialect / 东北话",
    ],
}

# Extra per-category hints rendered as the dropdown `info` text in the UI.
_ATTR_INFO = {
    "English Accent / 英文口音": "Only effective for English speech.",
    "Chinese Dialect / 中文方言": "Only effective for Chinese speech.",
}
107
+
108
+ # ---------------------------------------------------------------------------
109
+ # Argument parser
110
+ # ---------------------------------------------------------------------------
111
+
112
+
113
def build_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the ``omnivoice-demo`` CLI."""
    parser = argparse.ArgumentParser(
        prog="omnivoice-demo",
        description="Launch a Gradio demo for OmniVoice.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    add = parser.add_argument
    add(
        "--model",
        default="k2-fsa/OmniVoice",
        help="Model checkpoint path or HuggingFace repo id.",
    )
    add(
        "--device",
        default=None,
        help="Device to use. Auto-detected if not specified.",
    )
    add("--ip", default="0.0.0.0", help="Server IP (default: 0.0.0.0).")
    add("--port", type=int, default=7860, help="Server port (default: 7860).")
    add("--root-path", default=None, help="Root path for reverse proxy.")
    add("--share", action="store_true", default=False, help="Create public link.")
    return parser
140
+
141
+
142
+ # ---------------------------------------------------------------------------
143
+ # Build demo
144
+ # ---------------------------------------------------------------------------
145
+
146
+
147
def build_demo(
    model: OmniVoice,
    checkpoint: str,
    generate_fn=None,
) -> gr.Blocks:
    """Construct the Gradio Blocks UI (Voice Clone + Voice Design tabs).

    Args:
        model: A loaded OmniVoice model used for generation.
        checkpoint: Checkpoint path / repo id; currently unused by the UI.
        generate_fn: Optional replacement for the internal generation core
            (e.g. a spaces.GPU()-wrapped function on ZeroGPU Spaces). It
            must accept the same positional/keyword arguments as the
            internal ``_gen_core`` below.

    Returns:
        The assembled ``gr.Blocks`` demo; the caller launches it.
    """

    sampling_rate = model.sampling_rate

    # -- shared generation core --
    def _gen_core(
        text,
        language,
        ref_audio,
        instruct,
        num_step,
        guidance_scale,
        denoise,
        speed,
        duration,
        preprocess_prompt,
        postprocess_output,
        mode,
        ref_text=None,
    ):
        # Returns (audio, status): audio is a (sampling_rate, int16 ndarray)
        # tuple for gr.Audio(type="numpy"), or None on failure.
        if not text or not text.strip():
            return None, "Please enter the text to synthesize."

        # Fall back to defaults when widgets hand us None.
        gen_config = OmniVoiceGenerationConfig(
            num_step=int(num_step or 32),
            guidance_scale=float(guidance_scale) if guidance_scale is not None else 2.0,
            denoise=bool(denoise) if denoise is not None else True,
            preprocess_prompt=bool(preprocess_prompt),
            postprocess_output=bool(postprocess_output),
        )

        # "Auto" in the dropdown means: let the model detect the language.
        lang = language if (language and language != "Auto") else None

        kw: Dict[str, Any] = dict(
            text=text.strip(), language=lang, generation_config=gen_config
        )

        # Only forward speed/duration when they deviate from the defaults;
        # a positive duration overrides speed inside the model.
        if speed is not None and float(speed) != 1.0:
            kw["speed"] = float(speed)
        if duration is not None and float(duration) > 0:
            kw["duration"] = float(duration)

        if mode == "clone":
            if not ref_audio:
                return None, "Please upload a reference audio."
            kw["voice_clone_prompt"] = model.create_voice_clone_prompt(
                ref_audio=ref_audio,
                ref_text=ref_text,
            )

        if mode == "design":
            if instruct and instruct.strip():
                kw["instruct"] = instruct.strip()

        try:
            audio = model.generate(**kw)
        except Exception as e:
            # Surface the failure in the Status textbox instead of crashing.
            return None, f"Error: {type(e).__name__}: {e}"

        waveform = audio[0].squeeze(0).numpy()  # (T,)
        # Scale float waveform to int16 PCM for the numpy Audio component.
        waveform = (waveform * 32767).astype(np.int16)
        return (sampling_rate, waveform), "Done."

    # Allow external wrappers (e.g. spaces.GPU for ZeroGPU Spaces)
    _gen = generate_fn if generate_fn is not None else _gen_core

    # =====================================================================
    # UI
    # =====================================================================
    theme = gr.themes.Soft(
        font=["Inter", "Arial", "sans-serif"],
    )
    css = """
    .gradio-container {max-width: 100% !important; font-size: 16px !important;}
    .gradio-container h1 {font-size: 1.5em !important;}
    .gradio-container .prose {font-size: 1.1em !important;}
    .compact-audio audio {height: 60px !important;}
    .compact-audio .waveform {min-height: 80px !important;}
    """

    # Reusable: language dropdown component
    def _lang_dropdown(label="Language (optional) / 语种 (可选)", value="Auto"):
        return gr.Dropdown(
            label=label,
            choices=_ALL_LANGUAGES,
            value=value,
            allow_custom_value=False,
            interactive=True,
            info="Keep as Auto to auto-detect the language.",
        )

    # Reusable: optional generation settings accordion.
    # Returns the widgets in the order (ns, gs, dn, sp, du, pp, po), which
    # the click handlers below rely on.
    def _gen_settings():
        with gr.Accordion("Generation Settings (optional)", open=False):
            sp = gr.Slider(
                0.7,
                1.3,
                value=1.0,
                step=0.05,
                label="Speed",
                info="1.0 = normal. >1 faster, <1 slower. Ignored if Duration is set.",
            )
            du = gr.Number(
                value=None,
                label="Duration (seconds)",
                info=(
                    "Leave empty to use speed."
                    " Set a fixed duration to override speed."
                ),
            )
            ns = gr.Slider(
                4,
                64,
                value=32,
                step=1,
                label="Inference Steps",
                info="Default: 32. Lower = faster, higher = better quality.",
            )
            dn = gr.Checkbox(
                label="Denoise",
                value=True,
                info="Default: enabled. Uncheck to disable denoising.",
            )
            gs = gr.Slider(
                0.0,
                4.0,
                value=2.0,
                step=0.1,
                label="Guidance Scale (CFG)",
                info="Default: 2.0.",
            )
            pp = gr.Checkbox(
                label="Preprocess Prompt",
                value=True,
                info="apply silence removal and trimming to the reference "
                "audio, add punctuation in the end of reference text (if not already)",
            )
            po = gr.Checkbox(
                label="Postprocess Output",
                value=True,
                info="Remove long silences from generated audio.",
            )
        return ns, gs, dn, sp, du, pp, po

    with gr.Blocks(theme=theme, css=css, title="OmniVoice Demo") as demo:
        gr.Markdown(
            """
            # OmniVoice Demo

            State-of-the-art text-to-speech model for **600+ languages**, supporting:

            - **Voice Clone** — Clone any voice from a reference audio
            - **Voice Design** — Create custom voices with speaker attributes

            Built with [OmniVoice](https://github.com/k2-fsa/OmniVoice)
            by Xiaomi Next-gen Kaldi team.
            """
        )

        with gr.Tabs():
            # ==============================================================
            # Voice Clone
            # ==============================================================
            with gr.TabItem("Voice Clone"):
                with gr.Row():
                    with gr.Column(scale=1):
                        vc_text = gr.Textbox(
                            label="Text to Synthesize / 待合成文本",
                            lines=4,
                            placeholder="Enter the text you want to synthesize...",
                        )
                        vc_ref_audio = gr.Audio(
                            label="Reference Audio / 参考音频",
                            type="filepath",
                            elem_classes="compact-audio",
                        )
                        gr.Markdown(
                            "<span style='font-size:0.85em;color:#888;'>"
                            "Recommended: 3–10 seconds audio. "
                            "</span>"
                        )
                        vc_ref_text = gr.Textbox(
                            label=("Reference Text (optional)" " / 参考音频文本(可选)"),
                            lines=2,
                            placeholder="Transcript of the reference audio. Leave empty"
                            " to auto-transcribe via ASR models.",
                        )
                        vc_lang = _lang_dropdown("Language (optional) / 语种 (可选)")
                        (
                            vc_ns,
                            vc_gs,
                            vc_dn,
                            vc_sp,
                            vc_du,
                            vc_pp,
                            vc_po,
                        ) = _gen_settings()
                        vc_btn = gr.Button("Generate / 生成", variant="primary")
                    with gr.Column(scale=1):
                        vc_audio = gr.Audio(
                            label="Output Audio / 合成结果",
                            type="numpy",
                        )
                        vc_status = gr.Textbox(label="Status / 状态", lines=2)

                # Adapter: maps the clone tab's widgets onto _gen's signature
                # (instruct is None in clone mode).
                def _clone_fn(
                    text, lang, ref_aud, ref_text, ns, gs, dn, sp, du, pp, po
                ):
                    return _gen(
                        text,
                        lang,
                        ref_aud,
                        None,
                        ns,
                        gs,
                        dn,
                        sp,
                        du,
                        pp,
                        po,
                        mode="clone",
                        ref_text=ref_text or None,
                    )

                vc_btn.click(
                    _clone_fn,
                    inputs=[
                        vc_text,
                        vc_lang,
                        vc_ref_audio,
                        vc_ref_text,
                        vc_ns,
                        vc_gs,
                        vc_dn,
                        vc_sp,
                        vc_du,
                        vc_pp,
                        vc_po,
                    ],
                    outputs=[vc_audio, vc_status],
                )

            # ==============================================================
            # Voice Design
            # ==============================================================
            with gr.TabItem("Voice Design"):
                with gr.Row():
                    with gr.Column(scale=1):
                        vd_text = gr.Textbox(
                            label="Text to Synthesize / 待合成文本",
                            lines=4,
                            placeholder="Enter the text you want to synthesize...",
                        )
                        vd_lang = _lang_dropdown()

                        # One dropdown per attribute category; "Auto" means
                        # the attribute is left unspecified.
                        _AUTO = "Auto"
                        vd_groups = []
                        for _cat, _choices in _CATEGORIES.items():
                            vd_groups.append(
                                gr.Dropdown(
                                    label=_cat,
                                    choices=[_AUTO] + _choices,
                                    value=_AUTO,
                                    info=_ATTR_INFO.get(_cat),
                                )
                            )

                        (
                            vd_ns,
                            vd_gs,
                            vd_dn,
                            vd_sp,
                            vd_du,
                            vd_pp,
                            vd_po,
                        ) = _gen_settings()
                        vd_btn = gr.Button("Generate / 生成", variant="primary")
                    with gr.Column(scale=1):
                        vd_audio = gr.Audio(
                            label="Output Audio / 合成结果",
                            type="numpy",
                        )
                        vd_status = gr.Textbox(label="Status / 状态", lines=2)

                def _build_instruct(groups):
                    """Extract instruct text from UI dropdowns.

                    Language unification and validation is handled by
                    _resolve_instruct inside _preprocess_all.
                    """
                    selected = [g for g in groups if g and g != "Auto"]
                    if not selected:
                        return None
                    parts = []
                    for v in selected:
                        if " / " in v:
                            en, zh = v.split(" / ", 1)
                            # Dialects have no English equivalent
                            if "Dialect" in v.split(" / ")[0]:
                                parts.append(zh.strip())
                            else:
                                parts.append(en.strip())
                        else:
                            parts.append(v)
                    return ", ".join(parts)

                # Adapter: design mode has no reference audio; the instruct
                # string is assembled from the attribute dropdowns.
                def _design_fn(text, lang, ns, gs, dn, sp, du, pp, po, *groups):
                    return _gen(
                        text,
                        lang,
                        None,
                        _build_instruct(groups),
                        ns,
                        gs,
                        dn,
                        sp,
                        du,
                        pp,
                        po,
                        mode="design",
                    )

                vd_btn.click(
                    _design_fn,
                    inputs=[
                        vd_text,
                        vd_lang,
                        vd_ns,
                        vd_gs,
                        vd_dn,
                        vd_sp,
                        vd_du,
                        vd_pp,
                        vd_po,
                    ]
                    + vd_groups,
                    outputs=[vd_audio, vd_status],
                )

    return demo
491
+
492
+
493
+ # ---------------------------------------------------------------------------
494
+ # Main
495
+ # ---------------------------------------------------------------------------
496
+
497
+
498
def main(argv=None) -> int:
    """CLI entry point: parse arguments, load the model, launch the demo.

    Args:
        argv: Optional argument list (defaults to ``sys.argv[1:]``);
            useful for testing.

    Returns:
        Process exit code (0 on success).
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(name)s %(levelname)s: %(message)s",
    )
    parser = build_parser()
    args = parser.parse_args(argv)

    device = args.device or get_best_device()

    checkpoint = args.model
    if not checkpoint:
        # Defensive: --model has a default, but an explicit empty value
        # (e.g. --model "") should show usage instead of crashing later.
        parser.print_help()
        return 0
    logging.info(f"Loading model from {checkpoint}, device={device} ...")
    model = OmniVoice.from_pretrained(
        checkpoint,
        device_map=device,
        dtype=torch.float16,
        load_asr=True,
    )
    # Use logging (not print) for consistency with the rest of this module.
    logging.info("Model loaded.")

    demo = build_demo(model, checkpoint)

    demo.queue().launch(
        server_name=args.ip,
        server_port=args.port,
        share=args.share,
        root_path=args.root_path,
    )
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
omnivoice/cli/infer.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Single-item inference CLI for OmniVoice.
2
+
3
+ Generates audio from a single text input using voice cloning,
4
+ voice design, or auto voice.
5
+
6
+ Usage:
7
+ # Voice cloning
8
+ omnivoice-infer --model k2-fsa/OmniVoice \
9
+ --text "Hello, this is a text for text-to-speech." \
10
+ --ref_audio ref.wav --ref_text "Reference transcript." --output out.wav
11
+
12
+ # Voice design
13
+ omnivoice-infer --model k2-fsa/OmniVoice \
14
+ --text "Hello, this is a text for text-to-speech." \
15
+ --instruct "male, British accent" --output out.wav
16
+
17
+ # Auto voice
18
+ omnivoice-infer --model k2-fsa/OmniVoice \
19
+ --text "Hello, this is a text for text-to-speech." --output out.wav
20
+ """
21
+
22
+ import argparse
23
+ import logging
24
+
25
+ import torch
26
+ import torchaudio
27
+
28
+ from omnivoice.models.omnivoice import OmniVoice
29
+ from omnivoice.utils.common import str2bool
30
+
31
+
32
def get_best_device():
    """Pick the best available inference device: CUDA > Apple MPS > CPU."""
    if torch.cuda.is_available():
        return "cuda"
    return "mps" if torch.backends.mps.is_available() else "cpu"
39
+
40
+
41
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for single-item OmniVoice inference."""
    parser = argparse.ArgumentParser(
        description="OmniVoice single-item inference",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--model",
        type=str,
        default="k2-fsa/OmniVoice",
        help="Model checkpoint path or HuggingFace repo id.",
    )
    parser.add_argument(
        "--text",
        type=str,
        required=True,
        help="Text to synthesize.",
    )
    parser.add_argument(
        "--output",
        type=str,
        required=True,
        help="Output WAV file path.",
    )
    # Voice cloning
    parser.add_argument(
        "--ref_audio",
        type=str,
        default=None,
        help="Reference audio file path for voice cloning.",
    )
    parser.add_argument(
        "--ref_text",
        type=str,
        default=None,
        help="Reference text describing the reference audio.",
    )
    # Voice design
    parser.add_argument(
        "--instruct",
        type=str,
        default=None,
        help="Style instruction for voice design mode.",
    )
    parser.add_argument(
        "--language",
        type=str,
        default=None,
        help="Language name (e.g. 'English') or code (e.g. 'en').",
    )
    # Generation parameters
    parser.add_argument("--num_step", type=int, default=32)
    parser.add_argument("--guidance_scale", type=float, default=2.0)
    parser.add_argument("--speed", type=float, default=1.0)
    parser.add_argument(
        "--duration",
        type=float,
        default=None,
        help="Fixed output duration in seconds. If set, overrides the "
        "model's duration estimation. The speed factor is automatically "
        "adjusted to match while preserving language-aware pacing.",
    )
    # NOTE(review): the knobs below are forwarded verbatim to
    # model.generate in main(); their exact semantics are defined in
    # omnivoice/models/omnivoice.py — not documented here.
    parser.add_argument("--t_shift", type=float, default=0.1)
    parser.add_argument("--denoise", type=str2bool, default=True)
    parser.add_argument(
        "--postprocess_output",
        type=str2bool,
        default=True,
    )
    parser.add_argument("--layer_penalty_factor", type=float, default=5.0)
    parser.add_argument("--position_temperature", type=float, default=5.0)
    parser.add_argument("--class_temperature", type=float, default=0.0)
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use for inference. Auto-detected if not specified.",
    )
    return parser
119
+
120
+
121
def main():
    """Run one TTS generation from the command line and save the WAV."""
    log_format = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO, force=True)

    args = get_parser().parse_args()
    device = args.device or get_best_device()

    logging.info(f"Loading model from {args.model} on {device} ...")
    model = OmniVoice.from_pretrained(
        args.model, device_map=device, dtype=torch.float16
    )

    logging.info(f"Generating audio for: {args.text[:80]}...")
    # Every generation option maps 1:1 onto an argparse attribute of the
    # same name, so the call is assembled from a single tuple of names.
    option_names = (
        "text",
        "language",
        "ref_audio",
        "ref_text",
        "instruct",
        "duration",
        "num_step",
        "guidance_scale",
        "speed",
        "t_shift",
        "denoise",
        "postprocess_output",
        "layer_penalty_factor",
        "position_temperature",
        "class_temperature",
    )
    audios = model.generate(**{name: getattr(args, name) for name in option_names})

    torchaudio.save(args.output, audios[0], model.sampling_rate)
    logging.info(f"Saved to {args.output}")


if __name__ == "__main__":
    main()
omnivoice/cli/infer_batch.py ADDED
@@ -0,0 +1,523 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Batch inference CLI for OmniVoice.
19
+
20
+ Distributes TTS generation across multiple GPUs for large-scale tasks.
21
+ Reads a JSONL test list, generates audio in parallel, and saves results.
22
+
23
+ Usage:
24
+ omnivoice-infer-batch --model k2-fsa/OmniVoice \
25
+ --test_list test.jsonl --res_dir results/
26
+
27
+ Test list format (JSONL, one JSON object per line):
28
+ Required fields: "id", "text"
29
+ Voice cloning: "ref_audio", "ref_text"
30
+ Voice design: "instruct"
31
+ Optional: "language_id", "language_name", "duration", "speed"
32
+ """
33
+
34
+ import argparse
35
+ import logging
36
+ import multiprocessing as mp
37
+ import os
38
+ import signal
39
+ import time
40
+ import traceback
41
+ from concurrent.futures import ProcessPoolExecutor, as_completed
42
+ from typing import List, Optional, Tuple
43
+
44
+ import torch
45
+ import torchaudio
46
+ from tqdm import tqdm
47
+
48
+ from omnivoice.models.omnivoice import OmniVoice
49
+ from omnivoice.utils.audio import load_audio
50
+ from omnivoice.utils.common import str2bool
51
+ from omnivoice.utils.data_utils import read_test_list
52
+ from omnivoice.utils.duration import RuleDurationEstimator
53
+
54
+
55
def get_best_device():
    """Auto-detect the best available device and its usable device count.

    Returns:
        A ``(device_name, device_count)`` tuple. Only CUDA can report more
        than one device; MPS and CPU always report 1.
    """
    if torch.cuda.is_available():
        return "cuda", torch.cuda.device_count()
    fallback = "mps" if torch.backends.mps.is_available() else "cpu"
    return fallback, 1
62
+
63
+
64
# Per-process model handle; populated by ``process_init`` inside each worker.
worker_model = None
# Sample rate (Hz) used when loading reference audio for duration estimates
# and for the warmup dummy clip.
SAMPLING_RATE = 24000
66
+
67
+
68
def get_parser():
    """Build the CLI argument parser for OmniVoice batch inference.

    Flags cover the model location, the JSONL test manifest, decoding
    hyper-parameters, per-GPU parallelism, batching strategy, and audio
    pre/post-processing toggles.
    """
    parser = argparse.ArgumentParser(description="Infer OmniVoice Model")
    # (flag, add_argument kwargs) pairs, registered in order below.
    arg_table = [
        (
            "--model",
            dict(
                type=str,
                default="k2-fsa/OmniVoice",
                help="Path to the model checkpoint (local dir or HF repo id). "
                "Audio tokenizer is expected at <checkpoint>/audio_tokenizer/.",
            ),
        ),
        (
            "--test_list",
            dict(
                type=str,
                required=True,
                help="Path to the JSONL file containing test samples. "
                'Each line is a JSON object: {"id": "name", "text": "...", '
                '"ref_audio": "/path.wav", "ref_text": "...", '
                '"language_id": "en", "language_name": "English", '
                '"duration": 10.0, "speed": 1.2}. '
                "language_id, language_name, duration, and speed are optional.",
            ),
        ),
        (
            "--res_dir",
            dict(
                type=str,
                required=True,
                help="Directory to save the generated audio files.",
            ),
        ),
        (
            "--num_step",
            dict(
                type=int,
                default=32,
                help="Number of steps for iterative decoding.",
            ),
        ),
        (
            "--guidance_scale",
            dict(
                type=float,
                default=2.0,
                help="Scale for Classifier-Free Guidance.",
            ),
        ),
        (
            "--t_shift",
            dict(
                type=float,
                default=0.1,
                help="Shift t to smaller ones if t_shift < 1.0",
            ),
        ),
        (
            "--nj_per_gpu",
            dict(
                type=int,
                default=1,
                help="Number of worker processes to spawn per GPU.",
            ),
        ),
        (
            "--audio_chunk_duration",
            dict(
                type=float,
                default=15.0,
                help="Maximum duration of audio chunk (in seconds) for splitting. "
                '"Not split" if <= 0.',
            ),
        ),
        (
            "--audio_chunk_threshold",
            dict(
                type=float,
                default=30.0,
                help=(
                    "The duration threshold (in seconds) to decide"
                    " whether to split audio into chunks."
                ),
            ),
        ),
        (
            "--batch_duration",
            dict(
                type=float,
                default=1000.0,
                help="Maximum total duration (reference + generated) per batch (seconds). "
                "Only effective for parallel_chunk / no chunk mode.",
            ),
        ),
        (
            "--batch_size",
            dict(
                type=int,
                default=0,
                help="Fixed batch size (number of samples per batch). "
                "If > 0, use fixed-size batching instead of duration-based batching.",
            ),
        ),
        (
            "--warmup",
            dict(
                type=int,
                default=0,
                help="Number of dummy inference runs per worker before real inference "
                "starts, to warm up CUDA kernels and caches.",
            ),
        ),
        (
            "--preprocess_prompt",
            dict(
                type=str2bool,
                default=True,
                help="Whether to preprocess reference audio (silence removal, trimming). "
                "Set to False to keep raw audio.",
            ),
        ),
        (
            "--postprocess_output",
            dict(
                type=str2bool,
                default=True,
                help="Whether to post-process generated audio (remove silence).",
            ),
        ),
        (
            "--layer_penalty_factor",
            dict(
                type=float,
                default=5.0,
                help="The penalty factor for layer-wise sampling.",
            ),
        ),
        (
            "--position_temperature",
            dict(
                type=float,
                default=5.0,
                help="The temperature for position selection.",
            ),
        ),
        (
            "--class_temperature",
            dict(
                type=float,
                default=0.0,
                help="The temperature for class token sampling.",
            ),
        ),
        (
            "--denoise",
            dict(
                type=str2bool,
                default=True,
                help="Whether to add <|denoise|> token in the reference.",
            ),
        ),
        (
            "--lang_id",
            dict(
                type=str,
                default=None,
                help="Language id to use when test_list JSONL entries do not contain "
                "language_id/language_name fields. If provided, both language_id and "
                "language_name will be set to this value.",
            ),
        ),
    ]
    for flag, kwargs in arg_table:
        parser.add_argument(flag, **kwargs)
    return parser
201
+
202
+
203
def process_init(rank_queue, model_checkpoint, warmup=0):
    """Initializer for each worker process.

    Loads model (with tokenizers and duration estimator) onto a specific GPU
    via ``OmniVoice.from_pretrained()``.

    Args:
        rank_queue: Shared queue of ``(device_type, device_id)`` tuples; each
            worker pops exactly one entry to claim its device.
        model_checkpoint: Local directory or HF repo id of the checkpoint.
        warmup: Number of dummy ``generate()`` calls to run before real work.
    """
    global worker_model

    # Cap thread counts so multiple workers per node do not oversubscribe CPUs.
    torch.set_num_threads(2)
    torch.set_num_interop_threads(2)

    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] "
        "[Worker %(process)d] %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    # Claim a device assignment from the shared queue.
    rank = rank_queue.get()
    device_type, device_id = rank
    if device_type == "cpu":
        worker_device = "cpu"
    elif device_type == "mps":
        worker_device = "mps"
    else:
        worker_device = f"cuda:{device_id}"

    logging.info(f"Initializing worker on device: {worker_device}")

    worker_model = OmniVoice.from_pretrained(
        model_checkpoint,
        device_map=worker_device,
        dtype=torch.float16,
    )

    if warmup > 0:
        # Tiny dummy generations to compile CUDA kernels and fill caches
        # before any timed inference runs.
        logging.info(f"Running {warmup} warmup iterations on {worker_device}")
        dummy_ref_audio = (
            torch.randn(1, SAMPLING_RATE),
            SAMPLING_RATE,
        )  # 1s of random noise used as the reference clip
        for i in range(warmup):
            worker_model.generate(
                text=["hello"],
                language=["en"],
                ref_audio=[dummy_ref_audio],
                ref_text=["hello"],
            )
        logging.info(f"Warmup complete on {worker_device}")

    logging.info(f"Worker on {worker_device} initialized successfully.")
253
+
254
+
255
def estimate_sample_total_duration(
    duration_estimator: RuleDurationEstimator,
    text: str,
    ref_text: str,
    ref_audio_path: str,
    gen_duration: Optional[float] = None,
) -> float:
    """Estimate the combined reference + generated duration of one sample.

    Args:
        duration_estimator: Rule-based estimator used when no explicit
            generation duration is supplied.
        text: Target text to synthesize.
        ref_text: Transcript of the reference audio.
        ref_audio_path: Path to the reference audio file.
        gen_duration: Optional explicit generation duration (seconds).

    Returns:
        Reference duration plus (given or estimated) generation duration,
        in seconds.
    """
    reference = load_audio(ref_audio_path, SAMPLING_RATE)
    reference_seconds = reference.shape[-1] / SAMPLING_RATE

    if gen_duration is None:
        # Predict the synthesis length from text lengths and reference pace.
        gen_duration = duration_estimator.estimate_duration(
            text, ref_text, reference_seconds, low_threshold=2.0
        )

    return reference_seconds + gen_duration
272
+
273
+
274
def cluster_samples_by_duration(
    samples: List[Tuple],
    duration_estimator: RuleDurationEstimator,
    batch_duration: float,
) -> List[List[Tuple]]:
    """Greedily group samples into batches capped by total estimated duration.

    Samples are sorted by estimated (reference + generated) duration,
    longest first, then packed sequentially; any sample exceeding
    ``batch_duration`` on its own becomes a singleton batch.
    """
    measured = []
    for sample in samples:
        _, ref_text, ref_audio_path, text, _, _, dur, _ = sample
        est = estimate_sample_total_duration(
            duration_estimator,
            text,
            ref_text,
            ref_audio_path,
            gen_duration=dur,
        )
        measured.append((sample, est))

    measured.sort(key=lambda item: item[1], reverse=True)

    batches: List[List[Tuple]] = []
    pending: List[Tuple] = []
    pending_total = 0.0

    for sample, est in measured:
        if est > batch_duration:
            # Oversized samples always run alone.
            batches.append([sample])
        elif pending_total + est <= batch_duration:
            pending.append(sample)
            pending_total += est
        else:
            # Current batch is full; start a new one with this sample.
            batches.append(pending)
            pending = [sample]
            pending_total = est

    if pending:
        batches.append(pending)

    logging.info(f"Clustered {len(samples)} samples into {len(batches)} batches")
    return batches
314
+
315
+
316
def cluster_samples_by_batch_size(
    samples: List[Tuple],
    duration_estimator: RuleDurationEstimator,
    batch_size: int,
) -> List[List[Tuple]]:
    """Split samples into fixed-size batches, sorted by duration to minimize padding."""

    def _estimated_total(sample: Tuple) -> float:
        # Estimated reference + generated duration for sorting.
        _, ref_text, ref_audio_path, text, _, _, dur, _ = sample
        return estimate_sample_total_duration(
            duration_estimator,
            text,
            ref_text,
            ref_audio_path,
            gen_duration=dur,
        )

    # Longest-first ordering keeps similar-length samples adjacent, so each
    # fixed-size slice pads as little as possible.
    ordered = sorted(samples, key=_estimated_total, reverse=True)

    batches = [
        ordered[start : start + batch_size]
        for start in range(0, len(ordered), batch_size)
    ]
    logging.info(
        f"Split {len(samples)} samples into {len(batches)} batches "
        f"(fixed batch_size={batch_size}, sorted by duration)"
    )
    return batches
346
+
347
+
348
def run_inference_batch(
    batch_samples: List[Tuple],
    res_dir: str,
    **gen_kwargs,
) -> List[Tuple]:
    """Generate audio for one batch on this worker and save the WAV files.

    Args:
        batch_samples: Tuples of (save_name, ref_text, ref_audio_path, text,
            language_id, language_name, duration, speed).
        res_dir: Directory where generated ``<save_name>.wav`` files are written.
        **gen_kwargs: Extra keyword arguments forwarded to
            ``worker_model.generate()``.

    Returns:
        A list of (save_name, per-sample synthesis time, audio duration,
        "success") tuples.
    """
    global worker_model

    # Column-wise views of the batch tuples.
    names = [s[0] for s in batch_samples]
    ref_texts = [s[1] for s in batch_samples]
    ref_paths = [s[2] for s in batch_samples]
    texts = [s[3] for s in batch_samples]
    lang_ids = [s[4] for s in batch_samples]
    durations = [s[6] for s in batch_samples]
    speeds = [s[7] for s in batch_samples]

    tic = time.time()
    audios = worker_model.generate(
        text=texts,
        language=lang_ids,
        ref_audio=ref_paths,
        ref_text=ref_texts,
        # Pass the lists only if at least one entry is set.
        duration=durations if any(d is not None for d in durations) else None,
        speed=speeds if any(s is not None for s in speeds) else None,
        **gen_kwargs,
    )
    # Synthesis time is amortized evenly across the batch.
    per_sample_time = (time.time() - tic) / len(batch_samples)

    results = []
    for name, audio in zip(names, audios):
        out_path = os.path.join(res_dir, name + ".wav")
        torchaudio.save(out_path, audio, worker_model.sampling_rate)
        seconds = audio.shape[-1] / worker_model.sampling_rate
        results.append((name, per_sample_time, seconds, "success"))

    return results
400
+
401
+
402
def main():
    """Driver: parse args, spawn per-GPU workers, dispatch batches, report RTF."""
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)
    # "spawn" gives each worker a clean interpreter (and CUDA context).
    mp.set_start_method("spawn", force=True)

    args = get_parser().parse_args()
    os.makedirs(args.res_dir, exist_ok=True)

    device_type, num_devices = get_best_device()
    if device_type == "cpu":
        logging.warning(
            "No GPU found. Falling back to CPU inference. This might be slow."
        )

    num_processes = num_devices * args.nj_per_gpu
    logging.info(
        f"Using {device_type} ({num_devices} device(s))."
        f" Spawning {num_processes} worker processes."
    )

    # Each worker pops one (device_type, device_id) pair from this queue in
    # process_init to claim its device; nj_per_gpu workers share each device.
    manager = mp.Manager()
    rank_queue = manager.Queue()
    for rank in list(range(num_devices)) * args.nj_per_gpu:
        rank_queue.put((device_type, rank))

    samples_raw = read_test_list(args.test_list)
    samples = []
    for s in samples_raw:
        # --lang_id, when given, overrides per-sample language fields.
        if args.lang_id is not None:
            lang_id = args.lang_id
            lang_name = args.lang_id
        else:
            lang_id = s.get("language_id")
            lang_name = s.get("language_name")
        # Tuple layout consumed by the cluster_* and run_inference_batch fns.
        samples.append(
            (
                s["id"],
                s["ref_text"],
                s["ref_audio"],
                s["text"],
                lang_id,
                lang_name,
                s.get("duration"),
                s.get("speed"),
            )
        )

    total_synthesis_time = []
    total_audio_duration = []

    try:
        with ProcessPoolExecutor(
            max_workers=num_processes,
            initializer=process_init,
            initargs=(rank_queue, args.model, args.warmup),
        ) as executor:
            futures = []

            # parallel_chunk / no chunk
            logging.info("Running batch inference")

            duration_estimator = RuleDurationEstimator()
            if args.batch_size > 0:
                batches = cluster_samples_by_batch_size(
                    samples, duration_estimator, args.batch_size
                )
            else:
                batches = cluster_samples_by_duration(
                    samples, duration_estimator, args.batch_duration
                )

            # NOTE(review): vars(args) forwards EVERY CLI flag (model,
            # test_list, nj_per_gpu, batch_size, ...) into
            # run_inference_batch's **gen_kwargs and from there into
            # worker_model.generate(); this assumes generate() accepts or
            # ignores those extra keys -- confirm against its signature.
            args_dict = vars(args)

            for batch in batches:
                futures.append(
                    executor.submit(
                        run_inference_batch, batch_samples=batch, **args_dict
                    )
                )

            for future in tqdm(
                as_completed(futures), total=len(futures), desc="Processing samples"
            ):
                try:
                    result = future.result()
                    for s_name, synth_time, audio_dur, status in result:
                        total_synthesis_time.append(synth_time)
                        total_audio_duration.append(audio_dur)
                        rtf = synth_time / audio_dur if audio_dur > 0 else float("inf")
                        logging.debug(
                            f"Processed {s_name}: Audio Duration={audio_dur:.2f}s, "
                            f"Synthesis Time={synth_time:.2f}s, RTF={rtf:.4f}"
                        )
                except Exception as e:
                    # A failed batch is logged but does not abort the run.
                    logging.error(f"Failed to process sample: {e}")
                    detailed_error = traceback.format_exc()
                    logging.error(f"Detailed error: {detailed_error}")

    except (Exception, KeyboardInterrupt) as e:
        logging.critical(
            f"An unrecoverable error occurred: {e}. Terminating all processes."
        )
        detailed_error_info = traceback.format_exc()
        logging.error(f"--- DETAILED TRACEBACK ---\n{detailed_error_info}")
        # Hard-kill the whole process group so no orphaned GPU workers survive.
        os.killpg(os.getpgid(os.getpid()), signal.SIGKILL)

    # Aggregate per-sample stats into run totals for the summary.
    total_synthesis_time = sum(total_synthesis_time)
    total_audio_duration = sum(total_audio_duration)
    logging.info("--- Summary ---")
    logging.info(f"Total audio duration: {total_audio_duration:.2f}s")
    logging.info(f"Total synthesis time: {total_synthesis_time:.2f}s")
    if total_audio_duration > 0:
        average_rtf = total_synthesis_time / total_audio_duration
        logging.info(f"Average RTF: {average_rtf:.4f}")
    else:
        logging.warning("No speech was generated. RTF cannot be computed.")

    logging.info("Done!")
520
+
521
+
522
# Script entry point for ``omnivoice-infer-batch``.
if __name__ == "__main__":
    main()
omnivoice/cli/train.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Training CLI for OmniVoice.
19
+
20
+ Launches distributed training via HuggingFace Accelerate.
21
+ Supports pre-training on Emilia data and finetuning on custom data.
22
+
23
+ Usage:
24
+ accelerate launch --gpu_ids 0,1,2,3 --num_processes 4 \\
25
+ -m omnivoice.cli.train \\
26
+ --train_config train_config.json \\
27
+ --data_config data_config.json \\
28
+ --output_dir output/
29
+
30
+ See examples/run_emilia.sh and examples/run_finetune.sh for full pipelines.
31
+ """
32
+
33
+ import argparse
34
+
35
+ from omnivoice.training.builder import build_dataloaders, build_model_and_tokenizer
36
+ from omnivoice.training.config import TrainingConfig
37
+ from omnivoice.training.trainer import OmniTrainer
38
+
39
+
40
def main():
    """Parse CLI arguments, build model and data pipelines, and start training."""
    parser = argparse.ArgumentParser(description="OmniVoice Training Entry Point")
    parser.add_argument(
        "--train_config", type=str, required=True, help="Path to config JSON"
    )
    parser.add_argument(
        "--output_dir", type=str, required=True, help="Where to save checkpoints"
    )
    parser.add_argument(
        "--data_config", type=str, required=True, help="Path to data config JSON"
    )
    args = parser.parse_args()

    # Load the training configuration and record CLI overrides on it.
    config = TrainingConfig.from_json(args.train_config)
    config.output_dir = args.output_dir
    config.data_config = args.data_config

    # Assemble model, tokenizer, and data loaders.
    model, tokenizer = build_model_and_tokenizer(config)
    train_loader, eval_loader = build_dataloaders(config, tokenizer)

    # Hand everything to the trainer and run the loop.
    trainer = OmniTrainer(
        model=model,
        config=config,
        train_dataloader=train_loader,
        eval_dataloader=eval_loader,
        tokenizer=tokenizer,
    )
    trainer.train()
+
72
+
73
+ if __name__ == "__main__":
74
+ main()
omnivoice/data/__init__.py ADDED
File without changes
omnivoice/data/batching.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Batching strategies for streaming/iterable datasets.
19
+
20
+ Provides length-based grouping and packing for efficient training with
21
+ variable-length audio.
22
+
23
+ Key classes:
24
+ - ``PackingIterableDataset``: Packs multiple samples into fixed-length sequences
25
+ for training. Used by ``omnivoice.training.builder``.
26
+ - ``StreamLengthGroupDataset``: Groups samples by length into buckets. Used by
27
+ data processing scripts (e.g. ``omnivoice/scripts/``).
28
+ """
29
+
30
+ import bisect
31
+ import logging
32
+ from typing import Any, Dict, Iterator, List, Optional
33
+
34
+ import numpy as np
35
+
36
+ from omnivoice.data.dataset import IterableDataReader, WrappedIterableDataset
37
+
38
+
39
class StreamLengthGroupDataset(WrappedIterableDataset):
    """A streaming dataset that groups samples by their lengths into buckets.
    Only support audio data for now.

    Samples are binned by audio duration into equal-width buckets so each
    yielded batch contains similar-length clips, keeping padding waste low.
    """

    def __init__(
        self,
        dataset: IterableDataReader,
        batch_duration: float,
        min_length: float = 0.5,
        max_length: float = 30.0,
        num_buckets: int = 20,
        audio_key: str = "audio",
        drop_last: bool = False,
        max_sample: Optional[int] = None,
    ):
        """
        Args:
            dataset: Source reader; must expose ``sample_rate`` and yield
                dicts containing ``audio_key``.
            batch_duration: Padded-duration budget per batch (seconds),
                measured as longest-sample length times batch size.
            min_length: Samples shorter than this (seconds) are skipped.
            max_length: Samples longer than this (seconds) are skipped.
            num_buckets: Number of equal-width duration buckets spanning
                [min_length, max_length].
            audio_key: Dict key holding the audio tensor in each sample.
            drop_last: If True, discard partially-filled buckets at stream end.
            max_sample: Optional hard cap on samples per batch.
        """
        self.dataset = dataset
        self.batch_duration = batch_duration
        self.min_length = min_length
        self.max_length = max_length
        self.num_buckets = num_buckets
        self.audio_key = audio_key
        self.drop_last = drop_last
        # float("inf") disables the per-batch sample cap.
        self.max_sample = max_sample if max_sample is not None else float("inf")

        # Upper duration edge of each bucket (left edge min_length excluded).
        self.boundaries = np.linspace(min_length, max_length, num_buckets + 1)[1:]

    def set_epoch(self, epoch: int):
        """
        Set the epoch for shuffling.
        """
        self.dataset.set_epoch(epoch)

    def _get_bucket_id(self, length: float) -> int:
        # Index of the first bucket whose upper edge is >= length.
        return bisect.bisect_left(self.boundaries, length)

    def __iter__(self) -> Iterator[List[Dict[str, Any]]]:
        """Yield batches of samples grouped by duration bucket."""
        # One open batch per bucket, plus the longest duration seen in it
        # (the padded batch cost is driven by that maximum).
        buckets = [[] for _ in range(self.num_buckets)]
        bucket_max_len = [0.0] * self.num_buckets

        for sample in self.dataset:
            audio = sample[self.audio_key]
            duration = audio.size(-1) / self.dataset.sample_rate

            # Discard clips outside the accepted duration range.
            if duration < self.min_length or duration > self.max_length:
                # logging.warning(f"Skipping sample with duration {duration:.2f}s")
                continue

            b_id = self._get_bucket_id(duration)
            buckets[b_id].append(sample)

            if duration > bucket_max_len[b_id]:
                bucket_max_len[b_id] = duration

            # Flush once the padded size (max length x count, with a +1
            # lookahead for the next sample) would meet the duration budget,
            # or the per-batch sample cap is reached.
            if (
                bucket_max_len[b_id] * (len(buckets[b_id]) + 1) >= self.batch_duration
                or len(buckets[b_id]) >= self.max_sample
            ):
                yield buckets[b_id]
                buckets[b_id] = []
                bucket_max_len[b_id] = 0.0

        if not self.drop_last:
            # Emit whatever remains in each bucket at end of stream.
            for b_idx, bucket in enumerate(buckets):
                if bucket:
                    yield bucket
                    buckets[b_idx] = []
107
+
108
class PackingIterableDataset(WrappedIterableDataset):
    """
    An IterableDataset that dynamically processes samples using a processor
    and packs them into batches based on the real token count.

    Args:
        dataset (Iterable): The raw dataset to process.
        processor (Callable): A processor to process each sample.
        batch_tokens (int): Maximum number of tokens per batch.
    """

    def __init__(
        self,
        dataset: IterableDataReader,
        processor: Any,
        batch_tokens: int,
    ):
        self.dataset = dataset
        self.processor = processor
        self.batch_tokens = batch_tokens
        self.skip_batches = 0

    def set_epoch(self, epoch: int):
        """
        Set the epoch for shuffling.
        """
        self.dataset.set_epoch(epoch)

    def __iter__(self) -> Iterator[List[Dict[str, Any]]]:
        """Yield lists of processed samples whose lengths sum to at most
        ``batch_tokens``."""
        batch: List[Dict[str, Any]] = []
        tokens_in_batch = 0

        for raw_sample in self.dataset:
            # Run the processor; a failing sample is logged and skipped.
            try:
                item = self.processor(raw_sample)
            except Exception as e:
                logging.warning(f"Error processing sample {raw_sample}: {e}")
                continue

            n_tokens = item["length"]

            # A single over-long sample can never fit in any batch; drop it.
            if n_tokens > self.batch_tokens:
                continue

            # Flush the open batch when this sample would overflow the budget.
            if tokens_in_batch + n_tokens > self.batch_tokens:
                yield batch
                batch = []
                tokens_in_batch = 0

            batch.append(item)
            tokens_in_batch += n_tokens

        # Emit the final partial batch, if any.
        if batch:
            yield batch
omnivoice/data/collator.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Data collator with packing for efficient training.
19
+
20
+ Packs multiple samples into a single sequence of fixed length (``batch_tokens``)
21
+ to maximize GPU utilization, instead of padding each sample individually.
22
+ Used by ``omnivoice.training.builder`` to create the collate function.
23
+ """
24
+
25
+ from typing import Any, Dict, List
26
+
27
+ import torch
28
+
29
+
30
class PackingDataCollator:
    """Concatenates pre-processed samples into one packed sequence.

    All samples in a list are joined along the time axis and right-padded to
    exactly ``batch_tokens`` positions, producing tensors with a leading
    batch dimension of 1.
    """

    def __init__(self, processor, batch_tokens: int):
        self.batch_tokens = batch_tokens
        # Provides the pad token id via ``processor.text_tokenizer``.
        self.processor = processor

    def __call__(self, processed_samples: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Pack ``processed_samples`` into a single fixed-length sequence.

        Returns a dict with "input_ids"/"labels" of shape [1, C, L],
        "audio_mask"/"position_ids"/"document_ids" of shape [1, L], where
        L == ``batch_tokens`` and C is the number of codebook layers.
        """
        # Concatenate every field along the time axis.
        packed_ids = torch.cat([s["input_ids"] for s in processed_samples], dim=1)
        packed_labels = torch.cat([s["labels"] for s in processed_samples], dim=1)
        packed_mask = torch.cat([s["audio_mask"] for s in processed_samples], dim=0)
        # Positions restart at 0 for each packed sample.
        packed_positions = torch.cat(
            [torch.arange(s["length"], dtype=torch.long) for s in processed_samples],
            dim=0,
        )
        # Sample index per position, so attention can stay within documents.
        doc_ids = torch.cat(
            [
                torch.full((s["length"],), idx, dtype=torch.int32)
                for idx, s in enumerate(processed_samples)
            ],
            dim=0,
        )

        pad_fn = torch.nn.functional.pad
        pad = self.batch_tokens - packed_ids.shape[1]

        packed_ids = pad_fn(
            packed_ids,
            pad=(0, pad),
            value=self.processor.text_tokenizer.pad_token_id,
        )
        packed_labels = pad_fn(packed_labels, pad=(0, pad), value=-100)
        packed_mask = pad_fn(packed_mask, pad=(0, pad), value=False)
        packed_positions = pad_fn(packed_positions, pad=(0, pad), value=0)
        doc_ids = pad_fn(doc_ids, pad=(0, pad), value=-1)

        return {
            "input_ids": packed_ids.unsqueeze(0),  # [1, C, L]
            "labels": packed_labels.unsqueeze(0),  # [1, C, L]
            "audio_mask": packed_mask.unsqueeze(0),  # [1, L]
            "position_ids": packed_positions.unsqueeze(0),  # [1, L]
            "document_ids": doc_ids.unsqueeze(0),  # [1, L]
        }
omnivoice/data/dataset.py ADDED
@@ -0,0 +1,551 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Dataset and data-loading utilities for training and evaluation.
19
+
20
+ Provides WebDataset-based iterable datasets, manifest parsing, and audio/token
21
+ loading. Used by ``omnivoice.training.builder.build_dataloaders()`` to construct
22
+ train and eval data loaders.
23
+
24
+ Key functions:
25
+ - ``prepare_data_manifests_from_json()``: Parses a data config JSON into train/dev
26
+ manifests.
27
+
28
+ Key classes:
29
+ - ``WebDatasetReader``: Reads audio/text pairs from WebDataset tar shards as an
30
+ iterable dataset.
31
+ - ``MuxWebDatasetReader``: Multiplexes multiple WebDataset readers for
32
+ multilingual data.
33
+ - ``JsonlDatasetReader``: Reads audio/text pairs from a JSONL manifest file.
34
+ Used by data processing scripts (e.g. ``omnivoice/scripts/``).
35
+ - ``SampleDecoder``: Decodes individual samples (audio or tokens + labels).
36
+ """
37
+
38
+ import io
39
+ import json
40
+ import logging
41
+ import os
42
+ import random
43
+ from typing import Any, Dict, Iterator, List, Optional, Tuple
44
+
45
+ import torch
46
+ import torch.distributed as dist
47
+ import torchaudio
48
+ import webdataset as wds
49
+ from torch.utils.data import IterableDataset
50
+
51
+
52
+ def load_audio_webdataset(data, sample_rate: int = 24000, device="cpu"):
53
+ """
54
+ Load audio from bytes data and resample to the target sample rate if needed.
55
+ Return a tensor of shape (1, num_samples)
56
+ """
57
+ audio, sr = torchaudio.load(io.BytesIO(data))
58
+ audio = audio.to(device)
59
+ if audio.size(dim=0) > 1:
60
+ audio = torch.mean(audio, dim=0)
61
+ if sr != sample_rate:
62
+ audio = torchaudio.functional.resample(audio, sr, sample_rate)
63
+ return audio
64
+
65
+
66
+ def prepare_data_manifests_from_json(
67
+ data_config: str,
68
+ ) -> Tuple[List[Tuple[str, str, int, float]], List[Tuple[str, str, int, float]]]:
69
+ """
70
+ Prepare data manifests from a json file.
71
+ A typical multilingual json file is in the following format:
72
+ {
73
+ "train":
74
+ [
75
+ {
76
+ "language_id": "en",
77
+ "manifest_path": [
78
+ "/Emilia/EN/data.lst"
79
+ ],
80
+ "repeat": 1
81
+ },
82
+ {
83
+ "language_id": "zh",
84
+ "manifest_path": [
85
+ "/Emilia/ZH/data.lst"
86
+ ],
87
+ "repeat": 1
88
+ }
89
+ ],
90
+ "dev":
91
+ [
92
+ {
93
+ "language_id": "en",
94
+ "manifest_path": [
95
+ "/Emilia/EN-dev/data.lst"
96
+ ],
97
+ "repeat": 1
98
+ },
99
+ {
100
+ "language_id": "zh",
101
+ "manifest_path": [
102
+ "/Emilia/ZH-dev/data.lst"
103
+ ],
104
+ "repeat": 1
105
+ }
106
+ ]
107
+ }
108
+
109
+ "language_id" is not used, just for better organization of multilingual data.
110
+ "repeat" is an optional field, default to 1, which indicates how many times
111
+ the manifest should be repeated.
112
+
113
+ The simplist format is like:
114
+ {
115
+ "train":
116
+ [
117
+ {
118
+ "manifest_path": [
119
+ "/Emilia/EN/data.lst",
120
+ "/Emilia/ZH/data.lst"
121
+ ],
122
+ }
123
+ ],
124
+ "dev":
125
+ [
126
+ {
127
+ "manifest_path": [
128
+ "/Emilia/EN-dev/data.lst",
129
+ "/Emilia/ZH-dev/data.lst"
130
+ ],
131
+ }
132
+ ]
133
+
134
+ data.lst format (items separated by space):
135
+ /path/to/data.tar /path/to/label.jsonl num_items num_seconds
136
+ """
137
+ train_manifests = []
138
+ dev_manifests = []
139
+ with open(data_config, "r", encoding="utf-8") as f:
140
+ data = json.load(f)
141
+ for item in data["train"]:
142
+ manifest_paths = item["manifest_path"]
143
+ repeat = item.get("repeat", 1)
144
+ for manifest_path in manifest_paths:
145
+ # assert manifest_path is a file
146
+ assert os.path.isfile(manifest_path), f"{manifest_path} is not a file."
147
+ train_manifests.extend(
148
+ webdataset_manifest_reader(manifest_path) * repeat
149
+ )
150
+ if "dev" in data:
151
+ for item in data["dev"]:
152
+ manifest_paths = item["manifest_path"]
153
+ repeat = item.get("repeat", 1)
154
+ for manifest_path in manifest_paths:
155
+ dev_manifests.extend(
156
+ webdataset_manifest_reader(manifest_path) * repeat
157
+ )
158
+ return train_manifests, dev_manifests
159
+
160
+
161
def webdataset_manifest_reader(
    manifest_path: str,
) -> List[Tuple[str, str, int, float]]:
    """
    Read a manifest file containing webdataset tar paths and label jsonl paths.

    Each non-empty line in the manifest file is in the format of:
        /path/to/data.tar /path/to/label.jsonl num_items num_seconds

    Args:
        manifest_path: Path to the manifest file.

    Returns:
        A list of ``(tar_path, label_jsonl_path, num_items, num_seconds)``
        tuples, with ``num_items`` converted to int and ``num_seconds``
        converted to float.

    Raises:
        ValueError: If a non-empty line does not have exactly four fields.
    """
    manifests = []
    with open(manifest_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            # Blank lines are allowed and skipped.
            if not line:
                continue
            parts = line.split()
            if len(parts) != 4:
                raise ValueError(
                    f"Invalid manifest line: {line}. "
                    f"Each line must contain "
                    "tar_path, label_jsonl_path, num_items, num_seconds."
                )
            tar_path, label_jsonl_path, num_items, num_seconds = (
                parts[0],
                parts[1],
                int(parts[2]),
                float(parts[3]),
            )
            manifests.append((tar_path, label_jsonl_path, num_items, num_seconds))
    return manifests
190
+
191
+
192
class SampleDecoder:
    """
    Decode a sample from webdataset, including loading audio/tokens and fetching label.
    """

    def __init__(
        self,
        tar_to_label: Dict,
        sample_rate: int = 24000,
        audio_format: Optional[Tuple[str]] = None,
        normalize_audio: bool = True,
    ):
        """
        Args:
            tar_to_label:
                A dict mapping from audio tar file to label tar file.
            sample_rate:
                Target sample rate for audio. Required if audio is loaded.
            audio_format:
                Tuple of audio file extensions to look for in the sample;
                defaults to ("flac", "wav", "mp3").
            normalize_audio:
                When True, peak-normalize loaded audio to 0.9.
        """
        self.tar_to_label = tar_to_label
        self.sample_rate = sample_rate
        self.normalize_audio = normalize_audio
        # Per-shard label lookup, created lazily and cached across calls.
        self.label_dataset = None
        self.audio_format = (
            ("flac", "wav", "mp3") if audio_format is None else audio_format
        )

    def __call__(self, sample):
        """Decode one webdataset sample into audio (or tokens) and its label."""
        decoded = {}
        shard_url = sample["__url__"]
        sample_key = sample["__key__"]

        label_path = self.tar_to_label[shard_url]
        # Shards are read sequentially, so reload the label file only when
        # the current shard changes.
        if self.label_dataset is None or self.label_dataset.path != label_path:
            self.label_dataset = LabelDataset(label_path)

        audio = torch.empty(0)
        if "npy" in sample:
            # Pre-extracted audio tokens take precedence over raw audio.
            decoded["audio_tokens"] = torch.from_numpy(sample["npy"])
        else:
            for ext in self.audio_format:
                if ext not in sample:
                    continue
                # load audio (1, num_samples)
                audio = load_audio_webdataset(
                    sample[ext], sample_rate=self.sample_rate
                )
                if self.normalize_audio:
                    audio = (audio / (audio.abs().max() + 1e-7)) * 0.9
                break
        decoded["audio"] = audio
        decoded["audio_duration"] = audio.size(-1) / self.sample_rate

        decoded["label"] = self.label_dataset[sample_key]
        return decoded
253
+
254
+
255
class LabelDataset:
    def __init__(self, jsonl_path: str):
        """
        In-memory lookup of labels keyed by sample id.

        Args:
            jsonl_path:
                Path to the jsonl file containing labels.
                Each non-empty line is a JSON object of the form:
                    {"id": "sample id", "text": "transcription text"}
                Lines without an "id" field are ignored.
        """
        self.path = jsonl_path
        self._labels = {}
        if not os.path.exists(jsonl_path):
            raise FileNotFoundError(f"Label jsonl file {jsonl_path} does not exist.")
        with open(jsonl_path, "r", encoding="utf-8") as f:
            for raw in f:
                raw = raw.strip()
                if not raw:
                    continue
                record = json.loads(raw)
                if "id" in record:
                    self._labels[record["id"]] = record

    def __getitem__(self, key):
        return self._labels[key]
280
+
281
+
282
class IterableDataReader:
    "Interfaces for classes reading data."

    # Sample rate (Hz) of the audio yielded by this reader.
    sample_rate: int

    def set_epoch(self, epoch: int):
        """Notify the reader of the current epoch (e.g. to reseed shuffling)."""
        raise NotImplementedError

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """Yield one decoded sample dict at a time."""
        raise NotImplementedError

    def __len__(self) -> int:
        """Return the total number of samples available from this reader."""
        raise NotImplementedError
295
+
296
+
297
class WrappedIterableDataset(IterableDataset):
    "IterableDataset interfaces in this project."

    def set_epoch(self, epoch: int):
        """Notify the dataset of the current epoch (e.g. to reseed shuffling)."""
        raise NotImplementedError

    def __iter__(self) -> Iterator[List[Dict[str, Any]]]:
        """Yield batches; each batch is a list of per-sample dicts."""
        raise NotImplementedError
305
+
306
+
307
class WebDatasetReader(IterableDataReader):
    """Stream decoded samples from WebDataset tar shards with paired labels."""

    def __init__(
        self,
        manifests: List[Tuple[str, str, int, float]],
        evaluation: bool = False,
        shuffle_buffer_size: int = 20000,
        sample_rate: int = 24000,
    ):
        """
        Args:
            manifests:
                ``(tar_path, label_jsonl_path, num_items, num_seconds)``
                tuples describing each shard.
            evaluation:
                When True, disable shard and sample shuffling.
            shuffle_buffer_size:
                Size of the in-memory sample shuffle buffer.
            sample_rate:
                Target audio sample rate, forwarded to the decoder.
        """
        self.shuffle_buffer_size = shuffle_buffer_size
        self.evaluation = evaluation
        self.epoch = 0

        self.orig_urls = [tar for tar, _, _, _ in manifests]
        self.tar_to_label = {tar: label for tar, label, _, _ in manifests}
        self.num_items = sum(count for _, _, count, _ in manifests)
        self.num_seconds = float(sum(secs for _, _, _, secs in manifests))
        self.urls = list(self.orig_urls)
        self.sample_decoder = SampleDecoder(
            tar_to_label=self.tar_to_label,
            sample_rate=sample_rate,
        )
        self.sample_rate = sample_rate

    def set_epoch(self, epoch: int):
        """
        Set the epoch for shuffling.
        """
        self.epoch = epoch
        self.urls = list(self.orig_urls)
        if not self.evaluation:
            # Deterministic shard order per epoch, same across ranks.
            random.Random(epoch).shuffle(self.urls)

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        source = wds.WebDataset(
            self.urls,
            shardshuffle=False,  # shard order is controlled by set_epoch
            workersplitter=wds.split_by_worker,
            nodesplitter=wds.split_by_node,
        )

        stream = source.decode().map(self.sample_decoder)
        if not self.evaluation:
            stream = stream.shuffle(self.shuffle_buffer_size, seed=self.epoch)
        return iter(stream)

    def __len__(self) -> int:
        """Total sample count across all shards, as declared by the manifests."""
        return self.num_items
360
+
361
+
362
class JsonlDatasetReader(IterableDataReader):
    """Read raw JSONL and load audio files, matching WebDatasetReader output format.

    Each JSONL line should be a JSON object with at least:
        {"id": "...", "audio_path": "/path/to/audio.wav", ...}

    Yields dicts of the form: {"audio": Tensor(1, T), "label": dict}
    """

    def __init__(
        self,
        jsonl_path: str,
        sample_rate: int = 24_000,
        shuffle: bool = True,
        shuffle_seed: int = 42,
        normalize_audio: bool = True,
    ):
        """
        Args:
            jsonl_path: Path to the jsonl metadata file.
            sample_rate: Target sample rate; audio is resampled if needed.
            shuffle: Shuffle all entries (reads the whole file into memory);
                otherwise entries are streamed lazily in file order.
            shuffle_seed: Seed for the shuffle; replaced by ``set_epoch``.
            normalize_audio: Peak-normalize each waveform to 0.9.
        """
        self.jsonl_path = jsonl_path
        self.sample_rate = sample_rate
        self.shuffle = shuffle
        self.shuffle_seed = shuffle_seed
        self.normalize_audio = normalize_audio

    def set_epoch(self, epoch: int):
        # Reseed with the epoch so each epoch sees a different ordering.
        self.shuffle_seed = epoch

    def _read_lines(self) -> list[dict]:
        """Load all entries into memory and (optionally) shuffle them."""
        entries = []
        with open(self.jsonl_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    entries.append(json.loads(line))
        if self.shuffle:
            # Use a private RNG instead of random.seed() so we do not
            # clobber the process-global random state shared by other
            # components (the resulting permutation is identical).
            random.Random(self.shuffle_seed).shuffle(entries)
            logging.info(
                f"Shuffled {len(entries)} JSONL entries (seed={self.shuffle_seed})"
            )
        return entries

    def _stream_lines(self):
        """Lazily yield entries in file order (no shuffling)."""
        with open(self.jsonl_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    yield json.loads(line)

    def __iter__(self):
        source = self._read_lines() if self.shuffle else self._stream_lines()

        # Split data across distributed ranks (multi-GPU / DDP)
        if dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            source = [item for i, item in enumerate(source) if i % world_size == rank]

        # Split data across DataLoader workers to avoid duplication
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            source = (
                item
                for i, item in enumerate(source)
                if i % worker_info.num_workers == worker_info.id
            )

        for meta in source:
            audio_path = meta.get("audio_path")
            if not audio_path or not os.path.exists(audio_path):
                logging.warning(
                    f"Skipping {meta.get('id', '?')}: audio_path missing or not found"
                )
                continue
            try:
                waveform, sr = torchaudio.load(audio_path)
                # Downmix multi-channel audio to mono.
                if waveform.shape[0] > 1:
                    waveform = waveform.mean(dim=0, keepdim=True)
                if sr != self.sample_rate:
                    waveform = torchaudio.functional.resample(
                        waveform, sr, self.sample_rate
                    )
                if self.normalize_audio:
                    waveform = (waveform / (waveform.abs().max() + 1e-7)) * 0.9
                meta["audio_duration"] = waveform.shape[1] / self.sample_rate
                yield {"audio": waveform, "label": meta}
            except Exception as e:
                # Best-effort: corrupt/unreadable audio is skipped, not fatal.
                logging.warning(f"Skipping {meta.get('id', '?')}: {e}")
449
+
450
+
451
class MuxWebDatasetReader(IterableDataReader):
    """Multiplex several readers into one weighted, interleaved sample stream."""

    def __init__(
        self,
        readers: List[WebDatasetReader],
        weights: Optional[List[float]] = None,
        stop_early: bool = False,
        seed: int = 0,
    ):
        """
        Args:
            readers: Source readers to interleave.
            weights: Per-reader sampling weights; when None, readers are
                weighted by their lengths.
            stop_early: Stop as soon as the first reader is exhausted.
            seed: Seed for the multiplexing RNG.
        """
        self.readers = readers
        self.stop_early = stop_early
        self.mux_iterator = LazyIteratorMultiplexer(
            *readers,
            stop_early=stop_early,
            weights=weights,
            seed=seed,
        )

    def set_epoch(self, epoch: int):
        """
        Set the epoch for shuffling.
        """
        # Propagate the epoch to every underlying reader.
        for rdr in self.readers:
            rdr.set_epoch(epoch)

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        return iter(self.mux_iterator)
477
+
478
+
479
class LazyIteratorMultiplexer:
    """
    A wrapper over multiple iterators that enables to combine
    lazy manifests in Lhotse. During iteration, unlike
    :class:`.LazyIteratorChain`,
    :class:`.LazyIteratorMultiplexer` at each step randomly
    selects the iterable used to yield an item.

    Since the iterables might be of different length, we provide
    a ``weights`` parameter to let the user decide which iterables
    should be sampled more frequently than others.
    When an iterable is exhausted, we will keep sampling from the other iterables, until
    we exhaust them all, unless ``stop_early`` is set to ``True``.
    """

    def __init__(
        self,
        # String (forward-reference) annotation so this class does not depend
        # on IterableDataReader being defined first; any sized iterable works.
        *iterators: "IterableDataReader",
        stop_early: bool = False,
        weights: Optional[List[float]] = None,
        seed: int = 0,
    ) -> None:
        """
        Args:
            iterators: Two or more iterables to multiplex.
            stop_early: Stop when ANY iterable is exhausted (instead of ALL).
            weights: Relative sampling weights, one per iterable. When None,
                weights default to the iterables' lengths (or uniform when
                lengths are unavailable).
            seed: Seed for the selection RNG; a fresh RNG is created on every
                iteration, so re-iterating yields the same order.
        """
        self.iterators = list(iterators)
        self.stop_early = stop_early
        self.seed = seed

        assert (
            len(self.iterators) > 1
        ), "There have to be at least two iterables to multiplex."

        if weights is None:
            # Default: sample each iterable proportionally to its length when
            # all lengths are known, otherwise uniformly.
            if all(hasattr(it, "__len__") for it in self.iterators):
                lengths = [len(it) for it in self.iterators]
                total_length = sum(lengths)
                self.weights = [length / total_length for length in lengths]
            else:
                self.weights = [1] * len(self.iterators)
        else:
            self.weights = weights

        assert len(self.iterators) == len(self.weights)

    def __iter__(self):
        """Yield items by randomly alternating between the source iterables."""
        rng = random.Random(self.seed)
        iters = [iter(it) for it in self.iterators]
        exhausted = [False for _ in range(len(iters))]

        def should_continue():
            # stop_early: stop once ANY source runs dry;
            # otherwise keep going until ALL sources run dry.
            if self.stop_early:
                return not any(exhausted)
            else:
                return not all(exhausted)

        while should_continue():
            # Restrict the draw to the iterables that still have items.
            active_indexes, active_weights = zip(
                *[
                    (i, w)
                    for i, (is_exhausted, w) in enumerate(zip(exhausted, self.weights))
                    if not is_exhausted
                ]
            )
            idx = rng.choices(active_indexes, weights=active_weights, k=1)[0]
            selected = iters[idx]
            try:
                item = next(selected)
                yield item
            except StopIteration:
                exhausted[idx] = True
                continue

    def __len__(self) -> int:
        """Combined length of all source iterables."""
        return sum(len(iterator) for iterator in self.iterators)
omnivoice/data/processor.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Training sample processor for OmniVoice.
19
+
20
+ Converts raw audio/text samples into model-ready tensors: applies prompt/mask
21
+ tokenization, randomly drops conditioning, and injects language/instruct tokens.
22
+ Used by ``omnivoice.training.builder`` to build the data pipeline.
23
+
24
+ Contains two processor classes:
25
+ - ``OmniVoiceSampleProcessor``: Full processor used for training.
26
+ - ``OmniVoiceSimpleSampleProcessor``: Simplified processor (not used for training).
27
+ """
28
+
29
+ import random
30
+ from typing import Any, Dict
31
+
32
+ import torch
33
+
34
+
35
class OmniVoiceSampleProcessor:
    """
    Handles the logic of processing a raw sample into tensors
    (masking, tokenization, etc.).

    Builds the [style | text | audio] token sequence for one training
    sample, masks a random portion of the audio tokens, and produces loss
    labels (-100 everywhere the loss must be skipped).
    """

    def __init__(
        self,
        text_tokenizer: Any,
        num_channels: int,
        audio_mask_id: int,
        prompt_ratio_range: tuple,
        mask_ratio_range: tuple,
        drop_cond_ratio: float,
        language_ratio: float,
        use_pinyin_ratio: float,
        instruct_ratio: float,
        only_instruct_ratio: float,
    ):
        """
        Args:
            text_tokenizer: HF-style tokenizer; called with
                ``return_tensors="pt"`` and read via ``.input_ids``.
            num_channels: Number of parallel token channels (text rows are
                repeated across channels).
            audio_mask_id: Token id written into masked audio positions.
            prompt_ratio_range: (lo, hi) range of the audio-prefix prompt ratio.
            mask_ratio_range: (lo, hi) range of the per-token mask probability.
            drop_cond_ratio: Probability of dropping ALL conditioning
                (prompt + text + style), for classifier-free guidance style
                training.
            language_ratio: Probability of including the language id in the
                style prompt.
            use_pinyin_ratio: Probability of using the pinyin transcript when
                one is available.
            instruct_ratio: Probability of including the instruction text.
            only_instruct_ratio: Given an instruction is used, probability of
                dropping the audio prompt (instruction-only conditioning).
        """
        self.text_tokenizer = text_tokenizer
        self.num_channels = num_channels
        self.audio_mask_id = audio_mask_id
        self.prompt_ratio_range = prompt_ratio_range
        self.mask_ratio_range = mask_ratio_range
        self.drop_cond_ratio = drop_cond_ratio

        self.language_ratio = language_ratio
        self.use_pinyin_ratio = use_pinyin_ratio
        self.instruct_ratio = instruct_ratio
        self.only_instruct_ratio = only_instruct_ratio

    def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Convert one sample (with "label" and "audio_tokens") into
        ``{"input_ids": [C, L], "labels": [C, L], "audio_mask": [L],
        "length": int}``.

        Uses the global ``random`` state; call order of the draws below is
        part of the reproducibility contract.
        """

        # clean_start_token_idx is only used for prompt denoising training,
        # where the prompt region is augmented with noises and the model
        # needs to learn to recover the clean prompt.
        # clean_start_token_idx indicates the start index of the clean generated token.
        if "clean_start_token_idx" in sample["label"]:
            drop_cond = False
        else:
            drop_cond = random.uniform(0, 1) < self.drop_cond_ratio

        if drop_cond:
            # Unconditional branch: no prompt, no text, no style info.
            prompt_ratio = 0.0
            drop_text = True
            use_language = False
            use_instruct = False
        else:
            prompt_ratio = random.uniform(*self.prompt_ratio_range)
            drop_text = False
            use_language = random.uniform(0, 1) < self.language_ratio
            use_instruct = random.uniform(0, 1) < self.instruct_ratio
            # Occasionally condition on the instruction alone (no audio prompt).
            if use_instruct and random.uniform(0, 1) < self.only_instruct_ratio:
                prompt_ratio = 0.0

        mask_ratio = random.uniform(*self.mask_ratio_range)

        # --- Style ---
        # Dropped fields are encoded literally as the string "None".
        style = ""
        if use_language:
            language = sample["label"].get("language_id", "None")
        else:
            language = "None"
        if use_instruct:
            instruct = sample["label"].get("instruct", "None")
        else:
            instruct = "None"

        if "clean_start_token_idx" in sample["label"]:
            style += "<|denoise|>"

        style += f"<|lang_start|>{language}<|lang_end|>"
        style += f"<|instruct_start|>{instruct}<|instruct_end|>"

        # Repeat the (1, L) token row across all channels -> (C, L).
        style_inputs = self.text_tokenizer(style, return_tensors="pt").input_ids.repeat(
            self.num_channels, 1
        )
        style_labels = torch.full(
            style_inputs.shape, -100
        )  # Style prompt does not compute loss

        # --- Text ---
        if (
            "text_pinyin" in sample["label"]
            and random.uniform(0, 1) < self.use_pinyin_ratio
        ):
            text = sample["label"]["text_pinyin"]
        else:
            text = sample["label"]["text"]
        text_inputs = self.text_tokenizer(
            f"<|text_start|>{text}<|text_end|>", return_tensors="pt"
        ).input_ids.repeat(self.num_channels, 1)
        text_labels = torch.full(text_inputs.shape, -100)  # Text does not compute loss

        # --- Audio ---
        audio_tokens = sample["audio_tokens"].long()

        # Masking Logic
        # The first `prompt_length` audio tokens act as the (unmasked) prompt.
        if "clean_start_token_idx" in sample["label"]:
            prompt_length = sample["label"]["clean_start_token_idx"]
        else:
            prompt_length = int(audio_tokens.shape[1] * prompt_ratio)

        audio_inputs = audio_tokens.clone()
        audio_labels = audio_tokens.clone()

        # Apply masking
        # maskable_region is only used for its shape when drawing the mask.
        maskable_region = audio_tokens[:, prompt_length:]
        token_mask = torch.rand(maskable_region.shape) < mask_ratio
        audio_inputs[:, prompt_length:][token_mask] = self.audio_mask_id
        audio_labels[:, prompt_length:][
            ~token_mask
        ] = -100  # Only compute loss on masked tokens
        if not drop_cond:
            audio_labels[:, :prompt_length] = -100  # No loss on prompt region

        # --- Concatenation ---
        if drop_text:
            # Unconditional sample: the sequence is audio only.
            input_ids = audio_inputs
            labels = audio_labels
            total_length = input_ids.shape[1]
            audio_mask = torch.ones(total_length, dtype=torch.bool)
        else:
            input_ids = torch.cat([style_inputs, text_inputs, audio_inputs], dim=1)
            labels = torch.cat([style_labels, text_labels, audio_labels], dim=1)
            total_length = input_ids.shape[1]
            audio_start_idx = style_inputs.shape[1] + text_inputs.shape[1]
            audio_mask = torch.zeros(total_length, dtype=torch.bool)
            audio_mask[audio_start_idx:] = True

        return_dict = {
            "input_ids": input_ids,  # [C, L]
            "labels": labels,  # [C, L]
            "audio_mask": audio_mask,  # [L]
            "length": total_length,
        }

        return return_dict
173
+
174
+
175
class OmniVoiceSimpleSampleProcessor:
    """
    Handles the logic of processing a raw sample into tensors
    (masking, tokenization, etc.).
    This is a simpler version that does not include language, instructions,
    or denoising prompts.
    We do not use it for training as OmniVoiceSampleProcessor can cover this case.
    We keep it as a reference implementation for users to understand the basic logics.
    """

    def __init__(
        self,
        text_tokenizer: Any,
        num_channels: int,
        audio_mask_id: int,
        prompt_ratio_range: tuple,
        mask_ratio_range: tuple,
        drop_cond_ratio: float,
    ):
        """
        Args:
            text_tokenizer: HF-style tokenizer; called with
                ``return_tensors="pt"`` and read via ``.input_ids``.
            num_channels: Number of parallel token channels.
            audio_mask_id: Token id written into masked audio positions.
            prompt_ratio_range: (lo, hi) range of the audio-prefix prompt ratio.
            mask_ratio_range: (lo, hi) range of the per-token mask probability.
            drop_cond_ratio: Probability of dropping all conditioning.
        """
        self.text_tokenizer = text_tokenizer
        self.num_channels = num_channels
        self.audio_mask_id = audio_mask_id
        self.prompt_ratio_range = prompt_ratio_range
        self.mask_ratio_range = mask_ratio_range
        self.drop_cond_ratio = drop_cond_ratio

    def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Convert one sample (with "label" and "audio_tokens") into
        ``{"input_ids": [C, L], "labels": [C, L], "audio_mask": [L],
        "length": int}``. Uses the global ``random`` state.
        """
        drop_cond = random.uniform(0, 1) < self.drop_cond_ratio
        mask_ratio = random.uniform(*self.mask_ratio_range)

        if drop_cond:
            # Unconditional branch: no audio prompt (text is dropped below).
            prompt_ratio = 0.0
        else:
            prompt_ratio = random.uniform(*self.prompt_ratio_range)

        # --- Text ---
        text = sample["label"]["text"]
        text_inputs = self.text_tokenizer(
            f"<|text_start|>{text}<|text_end|>", return_tensors="pt"
        ).input_ids.repeat(self.num_channels, 1)
        text_labels = torch.full(text_inputs.shape, -100)  # Text does not compute loss

        # --- Audio ---
        audio_tokens = sample["audio_tokens"].long()

        # Masking Logic
        # The first `prompt_length` audio tokens act as the (unmasked) prompt.
        prompt_length = int(audio_tokens.shape[1] * prompt_ratio)
        audio_inputs = audio_tokens.clone()
        audio_labels = audio_tokens.clone()

        # Apply masking
        # maskable_region is only used for its shape when drawing the mask.
        maskable_region = audio_tokens[:, prompt_length:]
        token_mask = torch.rand(maskable_region.shape) < mask_ratio
        audio_inputs[:, prompt_length:][token_mask] = self.audio_mask_id
        audio_labels[:, prompt_length:][
            ~token_mask
        ] = -100  # Only compute loss on masked tokens

        if not drop_cond:
            # No loss on prompt region
            audio_labels[:, :prompt_length] = -100

        # --- Concatenation ---
        if drop_cond:
            # Unconditional sample: the sequence is audio only.
            input_ids = audio_inputs
            labels = audio_labels
            total_length = input_ids.shape[1]
            audio_mask = torch.ones(total_length, dtype=torch.bool)
        else:
            input_ids = torch.cat([text_inputs, audio_inputs], dim=1)
            labels = torch.cat([text_labels, audio_labels], dim=1)
            total_length = input_ids.shape[1]
            audio_start_idx = text_inputs.shape[1]
            audio_mask = torch.zeros(total_length, dtype=torch.bool)
            audio_mask[audio_start_idx:] = True

        return_dict = {
            "input_ids": input_ids,  # [C, L]
            "labels": labels,  # [C, L]
            "audio_mask": audio_mask,  # [L]
            "length": total_length,
        }

        return return_dict
omnivoice/eval/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
import warnings

# NOTE(review): this filter silences ALL UserWarnings process-wide, not only
# the zhconv ones it targets (which are irrelevant to WER calculation);
# scope it with `module="zhconv"` if other warnings should remain visible.
warnings.filterwarnings("ignore", category=UserWarning)
omnivoice/eval/models/ecapa_tdnn_wavlm.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import os
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+
24
+
25
class ECAPA_TDNN_WAVLM(nn.Module):
    """ECAPA-TDNN speaker-embedding network on top of WavLM SSL features.

    A WavLM model (loaded via torch hub, from s3prl or a local checkpoint)
    turns raw waveforms into per-layer hidden states; a learned
    softmax-weighted sum of those layers is instance-normalized and fed to an
    ECAPA-TDNN backbone (Conv1d stem, three SE-Res2Blocks, attentive stats
    pooling) that outputs a fixed-size speaker embedding.
    """

    def __init__(
        self,
        feat_dim=80,
        channels=512,
        emb_dim=192,
        global_context_att=False,
        sr=16000,
        ssl_model_path=None,
    ):
        # feat_dim: dimension of the (weighted) SSL feature fed to layer1.
        # channels: width of the TDNN layers.
        # emb_dim: dimension of the output speaker embedding.
        # global_context_att: use global context stats in attentive pooling.
        # sr: expected input sample rate in Hz (used for the probe in
        #     get_feat_num).
        # ssl_model_path: directory containing a local `wavlm_large.pt`
        #     checkpoint; when None, the model is fetched from the s3prl hub.
        super().__init__()
        self.sr = sr

        if ssl_model_path is None:
            self.feature_extract = torch.hub.load("s3prl/s3prl", "wavlm_large")
        else:
            self.feature_extract = torch.hub.load(
                os.path.dirname(ssl_model_path),
                "wavlm_local",
                source="local",
                ckpt=os.path.join(ssl_model_path, "wavlm_large.pt"),
            )

        # Disable the `fp32_attention` flag on layers 23 and 11 of a 24-layer
        # encoder when the attribute exists.
        # NOTE(review): presumably for dtype consistency / speed under mixed
        # precision — confirm against the s3prl WavLM implementation.
        if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
            self.feature_extract.model.encoder.layers[23].self_attn,
            "fp32_attention",
        ):
            self.feature_extract.model.encoder.layers[
                23
            ].self_attn.fp32_attention = False
        if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
            self.feature_extract.model.encoder.layers[11].self_attn,
            "fp32_attention",
        ):
            self.feature_extract.model.encoder.layers[
                11
            ].self_attn.fp32_attention = False

        # One learnable weight per SSL hidden-state layer; softmax-normalized
        # in get_feat to form a convex combination of layers.
        self.feat_num = self.get_feat_num()
        self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))

        self.instance_norm = nn.InstanceNorm1d(feat_dim)
        # self.channels = [channels] * 4 + [channels * 3]
        self.channels = [channels] * 4 + [1536]

        self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
        self.layer2 = SE_Res2Block(
            self.channels[0],
            self.channels[1],
            kernel_size=3,
            stride=1,
            padding=2,
            dilation=2,
            scale=8,
            se_bottleneck_dim=128,
        )
        self.layer3 = SE_Res2Block(
            self.channels[1],
            self.channels[2],
            kernel_size=3,
            stride=1,
            padding=3,
            dilation=3,
            scale=8,
            se_bottleneck_dim=128,
        )
        self.layer4 = SE_Res2Block(
            self.channels[2],
            self.channels[3],
            kernel_size=3,
            stride=1,
            padding=4,
            dilation=4,
            scale=8,
            se_bottleneck_dim=128,
        )

        # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
        # The 1x1 conv takes the concatenation of layer2/3/4 outputs.
        cat_channels = channels * 3
        self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
        self.pooling = AttentiveStatsPool(
            self.channels[-1],
            attention_channels=128,
            global_context_att=global_context_att,
        )
        self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
        self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)

    def get_feat_num(self):
        """Probe the SSL model with 1 s of random audio and return the number
        of hidden-state tensors it produces (1 when a single tensor)."""
        self.feature_extract.eval()
        wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
        with torch.no_grad():
            features = self.feature_extract(wav)
        select_feature = features["hidden_states"]
        if isinstance(select_feature, (list, tuple)):
            return len(select_feature)
        else:
            return 1

    def get_feat(self, x):
        """Return softmax-weighted, instance-normalized SSL features.

        Args:
            x: batch of waveforms (iterable of per-utterance tensors).
        Returns:
            Feature tensor of shape (batch, feat_dim, frames) — assuming the
            SSL hidden states are (batch, frames, feat_dim); confirm.
        """
        # The SSL extractor itself is not fine-tuned (no gradients).
        with torch.no_grad():
            x = self.feature_extract([sample for sample in x])

        x = x["hidden_states"]
        if isinstance(x, (list, tuple)):
            x = torch.stack(x, dim=0)
        else:
            x = x.unsqueeze(0)
        # Softmax over the layer axis -> convex combination of layers.
        norm_weights = (
            F.softmax(self.feature_weight, dim=-1)
            .unsqueeze(-1)
            .unsqueeze(-1)
            .unsqueeze(-1)
        )
        x = (norm_weights * x).sum(dim=0)
        # NOTE(review): the +1e-6 presumably avoids exactly-zero inputs to
        # InstanceNorm — confirm.
        x = torch.transpose(x, 1, 2) + 1e-6

        x = self.instance_norm(x)
        return x

    def forward(self, x):
        """Compute speaker embeddings (batch, emb_dim) for waveforms ``x``."""
        x = self.get_feat(x)

        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)

        # Multi-layer feature aggregation: concat the SE-Res2Block outputs.
        out = torch.cat([out2, out3, out4], dim=1)
        out = F.relu(self.conv(out))
        out = self.bn(self.pooling(out))
        out = self.linear(out)

        return out
159
+
160
+
161
+ # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
162
+
163
+ """ Res2Conv1d + BatchNorm1d + ReLU
164
+ """
165
+
166
+
167
class Res2Conv1dReluBn(nn.Module):
    """Res2Net-style multi-scale Conv1d stack (conv -> ReLU -> BN per split).

    The input channels are split into ``scale`` groups of equal width; each
    group (except, for scale > 1, the last one which passes through
    unchanged) is convolved after being summed with the previous group's
    output. Requires in_channels == out_channels == ``channels``.
    """

    def __init__(
        self,
        channels,
        kernel_size=1,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
        scale=4,
    ):
        super().__init__()
        assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
        self.scale = scale
        self.width = channels // scale
        self.nums = scale if scale == 1 else scale - 1

        self.convs = nn.ModuleList(
            nn.Conv1d(
                self.width,
                self.width,
                kernel_size,
                stride,
                padding,
                dilation,
                bias=bias,
            )
            for _ in range(self.nums)
        )
        self.bns = nn.ModuleList(
            nn.BatchNorm1d(self.width) for _ in range(self.nums)
        )

    def forward(self, x):
        chunks = torch.split(x, self.width, 1)
        pieces = []
        running = None
        for branch, (conv, bn) in enumerate(zip(self.convs, self.bns)):
            # Each branch processes its chunk plus the previous branch output.
            running = chunks[branch] if branch == 0 else running + chunks[branch]
            # Order: conv -> relu -> bn
            running = bn(F.relu(conv(running)))
            pieces.append(running)
        if self.scale != 1:
            # The last chunk bypasses convolution entirely.
            pieces.append(chunks[self.nums])
        return torch.cat(pieces, dim=1)
223
+
224
+
225
+ """ Conv1d + BatchNorm1d + ReLU
226
+ """
227
+
228
+
229
class Conv1dReluBn(nn.Module):
    """Conv1d followed by ReLU and then BatchNorm1d."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
    ):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            bias=bias,
        )
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        activated = F.relu(self.conv(x))
        return self.bn(activated)
254
+
255
+
256
+ """ The SE connection of 1D case.
257
+ """
258
+
259
+
260
class SE_Connect(nn.Module):
    """Squeeze-and-Excitation gate for 1D feature maps.

    Squeezes over time (mean), passes through a bottleneck MLP with a
    sigmoid, and rescales the input channels by the resulting gate.
    """

    def __init__(self, channels, se_bottleneck_dim=128):
        super().__init__()
        self.linear1 = nn.Linear(channels, se_bottleneck_dim)
        self.linear2 = nn.Linear(se_bottleneck_dim, channels)

    def forward(self, x):
        # Global average over the time axis -> (batch, channels).
        squeezed = x.mean(dim=2)
        gate = torch.sigmoid(self.linear2(F.relu(self.linear1(squeezed))))
        # Broadcast the per-channel gate back over time.
        return x * gate.unsqueeze(2)
273
+
274
+
275
+ """ SE-Res2Block of the ECAPA-TDNN architecture.
276
+ """
277
+
278
+
279
+ # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
280
+ # return nn.Sequential(
281
+ # Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
282
+ # Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
283
+ # Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
284
+ # SE_Connect(channels)
285
+ # )
286
+
287
+
288
class SE_Res2Block(nn.Module):
    """SE-Res2Block of the ECAPA-TDNN architecture.

    1x1 Conv -> Res2 dilated conv -> 1x1 Conv -> SE gate, plus a residual
    connection (projected through a 1x1 conv when channel counts differ).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        scale,
        se_bottleneck_dim,
    ):
        super().__init__()
        self.Conv1dReluBn1 = Conv1dReluBn(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0
        )
        self.Res2Conv1dReluBn = Res2Conv1dReluBn(
            out_channels, kernel_size, stride, padding, dilation, scale=scale
        )
        self.Conv1dReluBn2 = Conv1dReluBn(
            out_channels, out_channels, kernel_size=1, stride=1, padding=0
        )
        self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)

        # 1x1 projection for the skip path when the channel count changes.
        self.shortcut = (
            nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
            )
            if in_channels != out_channels
            else None
        )

    def forward(self, x):
        identity = x if self.shortcut is None else self.shortcut(x)
        out = self.Conv1dReluBn1(x)
        out = self.Res2Conv1dReluBn(out)
        out = self.Conv1dReluBn2(out)
        out = self.SE_Connect(out)
        return out + identity
331
+
332
+
333
+ """ Attentive weighted mean and standard deviation pooling.
334
+ """
335
+
336
+
337
class AttentiveStatsPool(nn.Module):
    """Attentive weighted mean and standard deviation pooling.

    Computes per-frame attention weights, then returns the attention-weighted
    mean and standard deviation concatenated along the channel axis,
    producing a (batch, 2 * in_dim) output.
    """

    def __init__(self, in_dim, attention_channels=128, global_context_att=False):
        super().__init__()
        self.global_context_att = global_context_att

        # Use Conv1d with stride == 1 rather than Linear,
        # then we don't need to transpose inputs.
        if global_context_att:
            # With global context the attention input is [x; mean; std].
            self.linear1 = nn.Conv1d(
                in_dim * 3, attention_channels, kernel_size=1
            )  # equals W and b in the paper
        else:
            self.linear1 = nn.Conv1d(
                in_dim, attention_channels, kernel_size=1
            )  # equals W and b in the paper
        self.linear2 = nn.Conv1d(
            attention_channels, in_dim, kernel_size=1
        )  # equals V and k in the paper

    def forward(self, x):
        if self.global_context_att:
            ctx_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
            ctx_std = torch.sqrt(
                torch.var(x, dim=-1, keepdim=True) + 1e-10
            ).expand_as(x)
            attn_in = torch.cat((x, ctx_mean, ctx_std), dim=1)
        else:
            attn_in = x

        # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
        scores = torch.tanh(self.linear1(attn_in))
        alpha = torch.softmax(self.linear2(scores), dim=2)
        mean = (alpha * x).sum(dim=2)
        # Weighted E[x^2] - (E[x])^2; clamp guards against tiny negatives.
        residuals = (alpha * (x**2)).sum(dim=2) - mean**2
        std = torch.sqrt(residuals.clamp(min=1e-9))
        return torch.cat([mean, std], dim=1)
omnivoice/eval/models/utmos.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ UTMOS strong model.
20
+ Implementation from https://github.com/tarepan/SpeechMOS
21
+
22
+ """
23
+
24
+ import math
25
+ from typing import List, Optional, Tuple
26
+
27
+ import torch
28
+ import torch.nn.functional as F
29
+ from torch import Tensor, nn
30
+
31
+
32
class UTMOS22Strong(nn.Module):
    """Saeki 2022 `UTMOS strong learner` inference model
    (phoneme encoder omitted)."""

    # Fixed feature sizes of the released checkpoint.
    _SSL_DIM = 768
    _DOMAIN_DIM = 128
    _JUDGE_DIM = 128
    _RNN_DIM = 512
    _PROJ_DIM = 2048

    def __init__(self):
        """Build the SSL backbone, frozen domain/judge embeddings, BLSTM and
        projection head."""
        super().__init__()  # pyright: ignore [reportUnknownMemberType]

        concat_dim = self._SSL_DIM + self._DOMAIN_DIM + self._JUDGE_DIM

        # SSL feature extractor.
        self.wav2vec2 = Wav2Vec2Model()
        # Data-domain / judge-id embeddings: values come from the checkpoint
        # and are kept frozen at inference.
        self.domain_emb = nn.Parameter(
            data=torch.empty(1, self._DOMAIN_DIM), requires_grad=False
        )
        self.judge_emb = nn.Parameter(
            data=torch.empty(1, self._JUDGE_DIM), requires_grad=False
        )
        self.blstm = nn.LSTM(
            input_size=concat_dim,
            hidden_size=self._RNN_DIM,
            batch_first=True,
            bidirectional=True,
        )
        self.projection = nn.Sequential(
            nn.Linear(self._RNN_DIM * 2, self._PROJ_DIM),
            nn.ReLU(),
            nn.Linear(self._PROJ_DIM, 1),
        )

    def forward(self, wave: Tensor, sr: int) -> Tensor:  # pylint: disable=invalid-name
        """wave-to-score :: (B, T) -> (B,)"""
        # SSL feature extraction :: (B, T) -> (B, Frame, Feat)
        units = self.wav2vec2(wave)
        bsz, n_frames, _ = units.size()

        # Expand the single domain/judge embedding over batch and time
        # :: (1, Feat) -> (B, Frame, Feat).
        domain = self.domain_emb.unsqueeze(1).expand(bsz, n_frames, -1)
        judge = self.judge_emb.unsqueeze(1).expand(bsz, n_frames, -1)

        # Concatenate along the feature axis, then score each frame with the
        # BLSTM followed by the MLP head :: (B, Frame, Feat) -> (B, Frame, 1).
        features = torch.cat([units, domain, judge], dim=2)
        rnn_out, _ = self.blstm(features)
        frame_scores = self.projection(rnn_out)

        # Time-average to an utterance score, then map the normalized output
        # onto the 1-5 MOS scale :: (B, Frame, 1) -> (B,).
        return frame_scores.mean(dim=1).squeeze(1) * 2 + 3
94
+
95
+
96
class Wav2Vec2Model(nn.Module):
    """Wav2Vec2 backbone: conv feature encoder + transformer context network."""

    def __init__(self):
        super().__init__()  # pyright: ignore [reportUnknownMemberType]

        conv_dim, model_dim = 512, 768
        # Seven conv layers as (channels, kernel, stride):
        # one (512, 10, 5), four (512, 3, 2), two (512, 2, 2).
        conv_spec = [(conv_dim, 10, 5)]
        conv_spec += [(conv_dim, 3, 2)] * 4
        conv_spec += [(conv_dim, 2, 2)] * 2

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=conv_spec
        )  # pyright: ignore [reportGeneralTypeIssues]
        self.layer_norm = nn.LayerNorm(conv_dim)
        self.post_extract_proj = nn.Linear(conv_dim, model_dim)
        self.dropout_input = nn.Dropout(0.1)
        self.encoder = TransformerEncoder(model_dim)

        # Remnant of the training-time code; kept so checkpoint keys resolve.
        self.mask_emb = nn.Parameter(torch.FloatTensor(model_dim))

    def forward(self, source: Tensor):
        """FeatureEncoder + ContextTransformer :: (B, T) -> (B, Frame, Feat)."""
        # Feature encoding :: (B, T) -> (B, Feat, Frame) -> (B, Frame, Feat)
        feats = self.feature_extractor(source).transpose(1, 2)
        feats = self.post_extract_proj(self.layer_norm(feats))

        # Context transformer over the normalized, projected features.
        return self.encoder(feats)
131
+
132
+
133
class ConvFeatureExtractionModel(nn.Module):
    """Feature Encoder: a stack of strided 1-D convolutions over the raw
    waveform.

    Args:
        conv_layers: one (out_channels, kernel_size, stride) tuple per layer.
    """

    def __init__(self, conv_layers: List[Tuple[int, int, int]]):
        super().__init__()  # pyright: ignore [reportUnknownMemberType]

        def block(
            n_in: int, n_out: int, k: int, stride: int, is_group_norm: bool = False
        ):
            # Conv -> (optional GroupNorm) -> GELU. The p=0.0 Dropout is kept
            # only to preserve the original module layout / checkpoint keys.
            if is_group_norm:
                return nn.Sequential(
                    nn.Conv1d(n_in, n_out, k, stride=stride, bias=False),
                    nn.Dropout(p=0.0),
                    # Fix: normalize over this block's own output channels
                    # (`n_out`) rather than the enclosing loop's `dim`, which
                    # previously worked only by accident of closure evaluation
                    # order (n_out == dim at the single call site).
                    nn.GroupNorm(n_out, n_out, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(
                    nn.Conv1d(n_in, n_out, k, stride=stride, bias=False),
                    nn.Dropout(p=0.0),
                    nn.GELU(),
                )

        in_d = 1
        self.conv_layers = nn.ModuleList()
        for i, params in enumerate(conv_layers):
            (dim, k, stride) = params
            # Only the first layer uses GroupNorm (wav2vec2 convention).
            self.conv_layers.append(block(in_d, dim, k, stride, is_group_norm=i == 0))
            in_d = dim

    def forward(self, series: Tensor) -> Tensor:
        """:: (B, T) -> (B, Feat, Frame)"""
        series = series.unsqueeze(1)  # add the channel dim: (B, 1, T)
        for conv in self.conv_layers:
            series = conv(series)

        return series
171
+
172
+
173
class TransformerEncoder(nn.Module):
    """Transformer context network with a convolutional relative positional
    embedding (wav2vec2-style)."""

    def build_encoder_layer(self, feat: int):
        """Build one post-LN self-attention layer with the fixed
        wav2vec2-base hyper-parameters used by the UTMOS checkpoint."""
        return TransformerSentenceEncoderLayer(
            embedding_dim=feat,
            ffn_embedding_dim=3072,
            num_attention_heads=12,
            activation_fn="gelu",
            dropout=0.1,
            attention_dropout=0.1,
            activation_dropout=0.0,
            layer_norm_first=False,
        )

    def __init__(self, feat: int):
        super().__init__()  # pyright: ignore [reportUnknownMemberType]

        # Sequence length is padded to a multiple of this before the layers run.
        self.required_seq_len_multiple = 2

        # Grouped conv over time provides positional information; SamePad
        # trims the extra trailing frame produced by the even kernel.
        self.pos_conv = nn.Sequential(
            *[
                nn.utils.weight_norm(
                    nn.Conv1d(feat, feat, kernel_size=128, padding=128 // 2, groups=16),
                    name="weight",
                    dim=2,
                ),
                SamePad(128),
                nn.GELU(),
            ]
        )
        self.layer_norm = nn.LayerNorm(feat)
        self.layers = nn.ModuleList([self.build_encoder_layer(feat) for _ in range(12)])

    def forward(self, x: Tensor) -> Tensor:
        """:: (B, T, Feat) -> (B, T, Feat)"""
        # Add the conv positional embedding (computed in (B, Feat, T) layout).
        x_conv = self.pos_conv(x.transpose(1, 2)).transpose(1, 2)
        x = x + x_conv

        x = self.layer_norm(x)

        # pad to the sequence length dimension
        x, pad_length = pad_to_multiple(
            x, self.required_seq_len_multiple, dim=-2, value=0
        )
        if pad_length > 0:
            # Mask the appended frames so attention ignores them.
            padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
            padding_mask[:, -pad_length:] = True
        else:
            # No padding added: pad_to_multiple(None, ...) returns (None, 0),
            # i.e. no mask is used.
            padding_mask, _ = pad_to_multiple(
                None, self.required_seq_len_multiple, dim=-1, value=True
            )

        # :: (B, T, Feat) -> (T, B, Feat)
        x = x.transpose(0, 1)
        for layer in self.layers:
            x = layer(x, padding_mask)
        # :: (T, B, Feat) -> (B, T, Feat)
        x = x.transpose(0, 1)

        # undo padding
        if pad_length > 0:
            x = x[:, :-pad_length]

        return x
239
+
240
+
241
class SamePad(nn.Module):
    """Trim the extra trailing frame left by symmetric padding of an
    even-sized convolution kernel, restoring the original length."""

    def __init__(self, kernel_size: int):
        super().__init__()  # pyright: ignore [reportUnknownMemberType]
        # Odd kernels would need no trimming; this module only handles even.
        assert kernel_size % 2 == 0, "`SamePad` now support only even kernel."

    def forward(self, x: Tensor) -> Tensor:
        """:: (B, C, T+1) -> (B, C, T) — drop the last time step."""
        trimmed = x[:, :, :-1]
        return trimmed
250
+
251
+
252
def pad_to_multiple(
    x: Optional[Tensor], multiple: int, dim: int = -1, value: float = 0
) -> Tuple[Optional[Tensor], int]:
    """Pad tensor ``x`` at the tail of dimension ``dim`` up to the next
    multiple of ``multiple``.

    Args:
        x: tensor to pad, or None (passed through unchanged).
        multiple: target divisor of the padded size.
        dim: dimension to pad, given as a negative index (the pad-offset
            computation below relies on that).
        value: fill value for the padding.

    Returns:
        ``(padded_or_original_tensor_or_None, number_of_padded_elements)``.
    """
    if x is None:
        return None, 0
    tsz = x.size(dim)
    # Exact integer arithmetic; the previous float `ceil(tsz / multiple)`
    # could lose precision for very long sequences.
    remainder = -tsz % multiple
    if remainder == 0:
        return x, 0
    # F.pad takes pads from the last dimension backwards, two entries
    # (left, right) per dimension; dimensions after `dim` get (0, 0).
    pad_offset = (0,) * (-1 - dim) * 2

    return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
266
+
267
+
268
class TransformerSentenceEncoderLayer(nn.Module):
    """Transformer Encoder Layer used in BERT/XLM style pre-trained models.

    Post-LN variant: LayerNorm is applied after each residual connection
    (``layer_norm_first`` must be False) and the FFN activation is GELU.
    """

    def __init__(
        self,
        embedding_dim: int,
        ffn_embedding_dim: int,
        num_attention_heads: int,
        activation_fn: str,
        dropout: float,
        attention_dropout: float,
        activation_dropout: float,
        layer_norm_first: bool,
    ) -> None:
        super().__init__()  # pyright: ignore [reportUnknownMemberType]

        # Only the configuration used by the UTMOS checkpoint is supported.
        assert layer_norm_first is False, "`layer_norm_first` is fixed to `False`"
        assert activation_fn == "gelu", "`activation_fn` is fixed to `gelu`"

        feat = embedding_dim

        self.self_attn = MultiheadAttention(
            feat, num_attention_heads, attention_dropout
        )
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(activation_dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.fc1 = nn.Linear(feat, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, feat)
        self.self_attn_layer_norm = nn.LayerNorm(feat)
        self.final_layer_norm = nn.LayerNorm(feat)

    def forward(self, x: Tensor, self_attn_padding_mask: Optional[Tensor]):
        """:: (T, B, Feat) -> (T, B, Feat)"""
        # Res[Attn-Do]-LN
        residual = x
        x = self.self_attn(x, x, x, self_attn_padding_mask)
        x = self.dropout1(x)
        x = residual + x
        x = self.self_attn_layer_norm(x)

        # Res[SegFC-GELU-Do-SegFC-Do]-LN
        residual = x
        x = F.gelu(self.fc1(x))  # pyright: ignore [reportUnknownMemberType]
        x = self.dropout2(x)
        x = self.fc2(x)
        x = self.dropout3(x)
        x = residual + x
        x = self.final_layer_norm(x)

        return x
318
+
319
+
320
class MultiheadAttention(nn.Module):
    """Multi-headed attention built on separate Q/K/V projection weights,
    delegating the actual computation to the functional fused kernel."""

    def __init__(self, embed_dim: int, num_heads: int, dropout: float):
        super().__init__()  # pyright: ignore [reportUnknownMemberType]

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.p_dropout = dropout
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor],
    ) -> Tensor:
        """
        Args:
            query :: (T, B, Feat)
            key_padding_mask :: (B, src_len) - mask to exclude keys that are pads
            , where padding elements are indicated by 1s.
        """
        # The fused kernel wants one packed bias vector even with separate
        # projection weights.
        packed_bias = torch.cat(
            (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)
        )
        mask = key_padding_mask.bool() if key_padding_mask is not None else None

        # `use_separate_proj_weight=True` means `in_proj_weight` is ignored,
        # so an empty placeholder tensor is passed; `training=False` disables
        # attention dropout at inference.
        attn_out, _ = F.multi_head_attention_forward(
            query=query,
            key=key,
            value=value,
            embed_dim_to_check=self.embed_dim,
            num_heads=self.num_heads,
            in_proj_weight=torch.empty([0]),
            in_proj_bias=packed_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=self.p_dropout,
            out_proj_weight=self.out_proj.weight,
            out_proj_bias=self.out_proj.bias,
            training=False,
            key_padding_mask=mask,
            need_weights=False,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
        )
        return attn_out
omnivoice/eval/mos/utmos.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ Calculate UTMOS score with automatic Mean Opinion Score (MOS) prediction system
20
+ """
21
+ import argparse
22
+ import logging
23
+ import multiprocessing as mp
24
+ import os
25
+ import sys
26
+ import traceback
27
+ import warnings
28
+ from concurrent.futures import ProcessPoolExecutor, as_completed
29
+
30
+ import numpy as np
31
+ import torch
32
+ from tqdm import tqdm
33
+
34
+ from omnivoice.eval.models.utmos import UTMOS22Strong
35
+ from omnivoice.eval.utils import load_waveform
36
+ from omnivoice.utils.data_utils import read_test_list
37
+
38
+ warnings.filterwarnings("ignore")
39
+
40
+ # Global variables for workers
41
+ worker_model = None
42
+ worker_device = None
43
+ worker_sr = 16000
44
+
45
+
46
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser for the UTMOS evaluation script."""
    parser = argparse.ArgumentParser(
        description="Calculate UTMOS score using UTMOS22Strong model."
    )
    parser.add_argument(
        "--wav-path",
        type=str,
        required=True,
        help="Path to the directory containing evaluated speech files.",
    )
    parser.add_argument(
        "--test-list",
        type=str,
        required=True,
        help="Path to the JSONL test list. Each line is a JSON object "
        "with fields: id, text, ref_audio, ref_text, language_id, language_name.",
    )
    parser.add_argument(
        "--model-dir",
        type=str,
        required=True,
        help="Local path of our evaluation model repository."
        "Download from https://huggingface.co/k2-fsa/TTS_eval_models."
        "Will use 'tts_eval_models/mos/utmos22_strong_step7459_v1.pt'"
        " in this script",
    )
    parser.add_argument(
        "--extension",
        type=str,
        default="wav",
        help="Extension of the speech files. Default: wav",
    )
    parser.add_argument(
        "--decode-path",
        type=str,
        default=None,
        help="Path to the output file where UTMOS information will be saved. "
        "If not provided, results are only printed to console.",
    )
    parser.add_argument(
        "--nj-per-gpu",
        type=int,
        default=1,
        help="Number of worker processes to spawn per GPU.",
    )
    return parser
92
+
93
+
94
def get_device(rank: int = 0) -> torch.device:
    """Bind the current process to CUDA device ``rank`` and return it.

    Raises:
        AssertionError: if no CUDA device is available.
    """
    assert torch.cuda.is_available(), "CUDA is required but not available."
    chosen = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(rank)
    return chosen
99
+
100
+
101
def worker_init(
    rank_queue,
    model_path,
):
    """Initialize worker process with model and device.

    Pops a GPU rank from the shared queue, loads the UTMOS22Strong checkpoint
    and populates the module-level globals used by ``run_utmos_worker``.

    Args:
        rank_queue: multiprocessing queue of pre-assigned GPU ranks; a falsy
            value means rank -1 is used.
        model_path: path to the UTMOS22Strong state-dict checkpoint (.pt).
    """
    global worker_model, worker_device, worker_sr

    # Limit CPU threads per worker
    torch.set_num_threads(2)

    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] [Worker %(process)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    # Each worker takes its own GPU rank from the shared queue.
    rank = rank_queue.get() if rank_queue else -1

    worker_device = get_device(rank)
    worker_sr = 16000

    logging.debug(f"Initializing UTMOS worker on {worker_device}")

    # Initialize Model
    worker_model = UTMOS22Strong()
    try:
        # Load weights to CPU first, then move to device
        state_dict = torch.load(model_path, map_location="cpu")
        worker_model.load_state_dict(state_dict)
    except Exception as e:
        logging.error(f"Failed to load model from {model_path}: {e}")
        raise

    worker_model.to(worker_device)
    worker_model.eval()
133
+
134
+
135
@torch.no_grad()
def run_utmos_worker(file_idx, wav_path, language_name):
    """Score one audio file with the worker's UTMOS model.

    Returns:
        ``(file_idx, wav_path, language_name, result, status)`` where
        ``result`` is the UTMOS score on success or an error message on
        failure, and ``status`` is "success" or "error".
    """
    try:
        if not os.path.exists(wav_path):
            return file_idx, wav_path, language_name, f"File not found: {wav_path}", "error"

        # Load and preprocess waveform (resampled to the worker's 16 kHz rate).
        speech = load_waveform(wav_path, worker_sr, device=worker_device)

        # Compute score
        # UTMOS expects input shape (Batch, Time)
        score = worker_model(speech.unsqueeze(0), worker_sr)

        return file_idx, wav_path, language_name, score.item(), "success"

    except Exception as e:
        # Return the error instead of raising so the pool keeps running.
        error_detail = (
            f"Error processing {wav_path}: {str(e)}\n"
            f"Traceback:\n{traceback.format_exc()}"
        )
        return file_idx, wav_path, language_name, error_detail, "error"
157
+
158
+
159
def main():
    """Entry point: score every file in the test list with UTMOS in parallel
    across GPUs, then report per-language and overall averages."""
    parser = get_parser()
    args = parser.parse_args()

    # Main process thread setting
    torch.set_num_threads(2)

    # "spawn" is required so each worker initializes CUDA independently.
    mp.set_start_method("spawn", force=True)

    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    # Validate inputs
    if not os.path.isdir(args.wav_path):
        logging.error(f"Invalid directory: {args.wav_path}")
        sys.exit(1)

    model_path = os.path.join(args.model_dir, "mos/utmos22_strong_step7459_v1.pt")
    if not os.path.exists(model_path):
        logging.error(f"Model file not found at {model_path}")
        sys.exit(1)

    # Scan directory for files
    logging.info(f"Calculating UTMOS for {args.wav_path}")

    # Build (wav_path, language_name) pairs from the JSONL test list.
    wav_files = []
    try:
        samples = read_test_list(args.test_list)
        for s in samples:
            language_name = s.get("language_name") or "unknown"
            eval_wav_path = os.path.join(args.wav_path, f"{s['id']}.{args.extension}")
            wav_files.append((eval_wav_path, language_name))
    except Exception as e:
        raise ValueError(f"Error reading test list {args.test_list}: {e}")

    # Setup Parallel Processing
    num_gpus = torch.cuda.device_count()
    assert num_gpus > 0, "No GPU found. GPU is required."
    total_procs = num_gpus * args.nj_per_gpu

    logging.info(
        f"Starting evaluation with {total_procs} processes on {num_gpus} GPUs."
    )

    # Each worker pops one rank, so ranks are spread evenly over the GPUs.
    manager = mp.Manager()
    rank_queue = manager.Queue()

    for rank in list(range(num_gpus)) * args.nj_per_gpu:
        rank_queue.put(rank)

    scores = []

    # NOTE(review): fout is not closed on the sys.exit(1) error path below;
    # process exit releases it, but a try/finally would be cleaner.
    fout = None
    if args.decode_path:
        os.makedirs(os.path.dirname(args.decode_path), exist_ok=True)
        fout = open(args.decode_path, "w", encoding="utf8")
        logging.info(f"Saving detailed UTMOS results to: {args.decode_path}")
        fout.write("Name\tUTMOS\n")

    try:
        with ProcessPoolExecutor(
            max_workers=total_procs,
            initializer=worker_init,
            initargs=(
                rank_queue,
                model_path,
            ),
        ) as executor:
            futures = []
            for i, (wav_path, language_name) in enumerate(wav_files):
                futures.append(
                    executor.submit(run_utmos_worker, i, wav_path, language_name)
                )

            pbar = tqdm(
                as_completed(futures), total=len(wav_files), desc="Evaluating UTMOS"
            )
            # NOTE(review): lang_stats is bound inside this try-block; code
            # after it relies on the except-branch always exiting the process.
            lang_stats = {}
            for future in pbar:
                idx, path, language_name, result, status = future.result()
                if status == "success":
                    if language_name not in lang_stats:
                        lang_stats[language_name] = []
                    lang_stats[language_name].append(result)
                    scores.append(result)
                    if fout:
                        if language_name == "unknown":
                            fout.write(f"{os.path.basename(path)}\t{result:.2f}\n")
                        else:
                            fout.write(
                                f"{language_name}\t{os.path.basename(path)}\t{result:.2f}\n"
                            )
                else:
                    pbar.write(f"!!! FAILED [File {idx}]: {path} | {result}")

    except (Exception, KeyboardInterrupt) as e:
        logging.critical(
            f"An unrecoverable error occurred: {e}. Terminating all processes."
        )
        detailed_error_info = traceback.format_exc()
        logging.error(f"--- DETAILED TRACEBACK ---\n{detailed_error_info}")
        sys.exit(1)

    print("-" * 50)

    # Per-language breakdown plus macro-average, only when multilingual.
    if len(lang_stats) > 1:
        lang_scores = []
        for lang in sorted(lang_stats.keys()):
            l_scores = lang_stats[lang]
            l_avg = np.mean(l_scores)
            lang_scores.append(l_scores)
            l_count = len(l_scores)
            logging.info(f"[{lang}] UTMOS score: {l_avg:.3f} ({l_count} samples)")
            if fout:
                fout.write(f"[{lang}] UTMOS: {l_avg:.3f} ({l_count} samples)\n")
        logging.info(
            f"Macro-average UTMOS over {len(lang_stats)} languages: "
            f"{np.mean([np.mean(ls) for ls in lang_scores]):.3f}"
        )
        if fout:
            fout.write(
                f"\nMacro-average UTMOS over {len(lang_stats)} languages: "
                f"{np.mean([np.mean(ls) for ls in lang_scores]):.3f}\n"
            )

    # Overall (micro) average across every successfully scored file.
    if scores:
        avg_score = np.mean(scores)
        logging.info(f"Processed {len(scores)}/{len(wav_files)} files.")
        logging.info(f"UTMOS score: {avg_score:.2f}")
        if fout:
            fout.write(f"\nAverage UTMOS: {avg_score:.2f}\n")
    else:
        logging.error("No valid scores computed.")
    print("-" * 50)

    if fout:
        fout.close()


if __name__ == "__main__":
    main()
omnivoice/eval/speaker_similarity/sim.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ Computes speaker similarity (SIM-o) using a WavLM-based
20
+ ECAPA-TDNN speaker verification model.
21
+ """
22
+ import argparse
23
+ import logging
24
+ import multiprocessing as mp
25
+ import os
26
+ import sys
27
+ import traceback
28
+ import warnings
29
+ from concurrent.futures import ProcessPoolExecutor, as_completed
30
+
31
+ import numpy as np
32
+ import torch
33
+ from tqdm import tqdm
34
+
35
+ from omnivoice.eval.models.ecapa_tdnn_wavlm import ECAPA_TDNN_WAVLM
36
+ from omnivoice.eval.utils import load_waveform
37
+ from omnivoice.utils.data_utils import read_test_list
38
+
39
+ warnings.filterwarnings("ignore")
40
+
41
+ # Global variables for workers
42
+ worker_model = None
43
+ worker_device = None
44
+ worker_sr = 16000
45
+
46
+
47
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser for the SIM-o evaluation script."""
    parser = argparse.ArgumentParser(
        description="Calculate speaker similarity (SIM-o) score."
    )
    parser.add_argument(
        "--wav-path",
        type=str,
        required=True,
        help="Path to the directory containing evaluated speech files.",
    )
    parser.add_argument(
        "--test-list",
        type=str,
        required=True,
        help="Path to the JSONL test list. Each line is a JSON object "
        "with fields: id, text, ref_audio, ref_text, language_id, language_name.",
    )
    parser.add_argument(
        "--model-dir",
        type=str,
        required=True,
        help="Local path of our evaluation model repository."
        "Download from https://huggingface.co/k2-fsa/TTS_eval_models."
        "Will use 'tts_eval_models/speaker_similarity/wavlm_large_finetune.pth'"
        "and 'tts_eval_models/speaker_similarity/wavlm_large/' in this script",
    )
    parser.add_argument(
        "--extension",
        type=str,
        default="wav",
        help="Extension of the speech files.",
    )
    parser.add_argument(
        "--decode-path",
        type=str,
        default=None,
        help="Path to the output file where SIM-o information will be saved. "
        "If not provided, results are only printed to console.",
    )
    parser.add_argument(
        "--nj-per-gpu",
        type=int,
        default=1,
        help="Number of worker processes to spawn per GPU.",
    )
    return parser
93
+
94
+
95
def get_device(rank: int = 0) -> torch.device:
    """Bind the current process to CUDA device ``rank`` and return it.

    Raises:
        AssertionError: if no CUDA device is available.
    """
    assert torch.cuda.is_available(), "CUDA is required but not available."
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(rank)
    return device
100
+
101
+
102
def worker_init(
    rank_queue,
    sv_model_path,
    ssl_model_path,
):
    """Initialize worker process with model and device.

    Pops a GPU rank from the shared queue, builds the WavLM-based ECAPA-TDNN
    speaker-verification model and populates the module-level globals used by
    ``get_embedding`` / ``run_similarity_worker``.

    Args:
        rank_queue: multiprocessing queue of pre-assigned GPU ranks; a falsy
            value means rank -1 is used.
        sv_model_path: path to the fine-tuned speaker-verification checkpoint.
        ssl_model_path: directory of the WavLM-large SSL model.
    """
    global worker_model, worker_device, worker_sr

    # Limit CPU threads per worker.
    torch.set_num_threads(2)

    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] [Worker %(process)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    rank = rank_queue.get() if rank_queue else -1

    worker_device = get_device(rank)
    worker_sr = 16000

    logging.debug(f"Initializing SIM-o worker on {worker_device}")
    # Temporarily suppress INFO logs to hide verbose WavLM config
    logging.disable(logging.INFO)

    # Initialize Model
    try:
        worker_model = ECAPA_TDNN_WAVLM(
            feat_dim=1024,
            channels=512,
            emb_dim=256,
            sr=worker_sr,
            ssl_model_path=ssl_model_path,
        )
        # Load the checkpoint on CPU; strict=False tolerates extra/missing keys.
        state_dict = torch.load(
            sv_model_path, map_location=lambda storage, loc: storage
        )
        worker_model.load_state_dict(state_dict["model"], strict=False)
        worker_model.to(worker_device)
        worker_model.eval()
    finally:
        # Restore normal logging
        logging.disable(logging.NOTSET)
143
+
144
@torch.no_grad()
def get_embedding(wav_path: str) -> torch.Tensor:
    """Extract the speaker embedding for a single file using the worker's
    model (audio resampled to the worker's 16 kHz rate, truncated to 120 s)."""
    speech = load_waveform(wav_path, worker_sr, device=worker_device, max_seconds=120)
    return worker_model([speech])
149
+
150
+
151
def run_similarity_worker(line_idx, sample, wav_dir, extension):
    """Compute SIM-o for one (reference, synthesized) pair.

    Args:
        line_idx: index of the sample in the test list.
        sample: dict with at least "id" and "ref_audio" (optionally
            "language_name").
        wav_dir: directory containing the synthesized wavs.
        extension: filename extension of the synthesized wavs.

    Returns:
        ``(line_idx, context, result, status)``: on success ``context`` is
        ``(ref_path, eval_path, language_name)`` and ``result`` is the cosine
        similarity; on error ``context`` describes the sample and ``result``
        carries the error message.
    """
    try:
        wav_name = sample["id"]
        ref_wav_path = sample["ref_audio"]
        language_name = sample.get("language_name") or "unknown"
        eval_wav_path = os.path.join(wav_dir, f"{wav_name}.{extension}")

        # Fix: put the error message in the *result* slot and the sample in
        # the *context* slot (matching the exception branch below), so that
        # main()'s failure log no longer prints "Error: None" for missing
        # files.
        if not os.path.exists(ref_wav_path):
            return line_idx, str(sample), f"Reference not found: {ref_wav_path}", "error"
        if not os.path.exists(eval_wav_path):
            return line_idx, str(sample), f"Eval wav not found: {eval_wav_path}", "error"

        # Compute embeddings pair-wise
        ref_emb = get_embedding(ref_wav_path)
        eval_emb = get_embedding(eval_wav_path)

        # Cosine Similarity between the two speaker embeddings (SIM-o).
        similarity = torch.nn.functional.cosine_similarity(ref_emb, eval_emb, dim=-1)

        return (
            line_idx,
            (ref_wav_path, eval_wav_path, language_name),
            similarity.item(),
            "success",
        )

    except Exception as e:
        # Return the error instead of raising so the pool keeps running.
        error_detail = f"Error: {str(e)}\nTraceback:\n{traceback.format_exc()}"
        return line_idx, str(sample), error_detail, "error"
181
+
182
+
183
def main():
    """Entry point: compute SIM-o for every (reference, synthesized) pair in
    the test list in parallel across GPUs, then report per-language and
    overall averages."""
    parser = get_parser()
    args = parser.parse_args()

    # Main process thread setting
    torch.set_num_threads(2)

    # "spawn" is required so each worker initializes CUDA independently.
    mp.set_start_method("spawn", force=True)

    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    # Prepare paths
    sv_model_path = os.path.join(
        args.model_dir, "speaker_similarity/wavlm_large_finetune.pth"
    )
    ssl_model_path = os.path.join(args.model_dir, "speaker_similarity/wavlm_large/")

    if not os.path.exists(sv_model_path) or not os.path.exists(ssl_model_path):
        logging.error("Model files not found. Please check --model-dir.")
        sys.exit(1)

    logging.info(f"Calculating SIM-o for {args.wav_path}")
    # Read list
    samples = read_test_list(args.test_list)

    # Setup Parallel Processing
    num_gpus = torch.cuda.device_count()
    assert num_gpus > 0, "No GPU found. GPU is required."
    total_procs = num_gpus * args.nj_per_gpu

    logging.info(
        f"Starting evaluation with {total_procs} processes " f"on {num_gpus} GPUs."
    )

    # Each worker pops one rank, so ranks are spread evenly over the GPUs.
    manager = mp.Manager()
    rank_queue = manager.Queue()

    for rank in list(range(num_gpus)) * args.nj_per_gpu:
        rank_queue.put(rank)

    scores = []

    # NOTE(review): fout is not closed on the sys.exit(1) error path below;
    # process exit releases it, but a try/finally would be cleaner.
    fout = None
    if args.decode_path:
        os.makedirs(os.path.dirname(args.decode_path), exist_ok=True)
        fout = open(args.decode_path, "w", encoding="utf8")
        logging.info(f"Saving detailed SIM-o results to: {args.decode_path}")
        fout.write("Prompt-path\tEval-path\tSIM-o\n")

    try:
        with ProcessPoolExecutor(
            max_workers=total_procs,
            initializer=worker_init,
            initargs=(
                rank_queue,
                sv_model_path,
                ssl_model_path,
            ),
        ) as executor:
            futures = []
            for i, sample in enumerate(samples):
                futures.append(
                    executor.submit(
                        run_similarity_worker, i, sample, args.wav_path, args.extension
                    )
                )

            pbar = tqdm(
                as_completed(futures), total=len(samples), desc="Evaluating SIM-o"
            )

            # NOTE(review): lang_stats is bound inside this try-block; code
            # after it relies on the except-branch always exiting the process.
            lang_stats = {}

            for future in pbar:
                idx, context, result, status = future.result()
                if status == "success":
                    prompt_path, eval_path, lang = context
                    scores.append(result)

                    # Accumulate per-language
                    if lang not in lang_stats:
                        lang_stats[lang] = []
                    lang_stats[lang].append(result)

                    if fout:
                        if lang == "unknown":
                            fout.write(f"{prompt_path}\t{eval_path}\t{result:.2f}\n")
                        else:
                            # context[0]/context[1] are the same values as
                            # prompt_path/eval_path unpacked above.
                            fout.write(
                                f"{lang}\t{context[0]}\t{context[1]}\t{result:.2f}\n"
                            )
                else:
                    pbar.write(f"!!! FAILED [Line {idx}]: {context} | Error: {result}")

    except (Exception, KeyboardInterrupt) as e:
        logging.critical(
            f"An unrecoverable error occurred: {e}. " f"Terminating all processes."
        )
        detailed_error_info = traceback.format_exc()
        logging.error(f"--- DETAILED TRACEBACK ---\n{detailed_error_info}")
        sys.exit(1)

    print("-" * 50)
    # Per-language breakdown plus macro-average, only when multilingual.
    if len(lang_stats) > 1:
        lang_scores = []
        for lang in sorted(lang_stats.keys()):
            l_scores = lang_stats[lang]
            l_avg = np.mean(l_scores)
            lang_scores.append(l_scores)
            l_count = len(l_scores)
            logging.info(f"[{lang}] SIM-o score: {l_avg:.3f} ({l_count} pairs)")
            if fout:
                fout.write(f"[{lang}] SIM-o: {l_avg:.3f} ({l_count} pairs)\n")
        logging.info(
            f"Macro-average SIM-o over {len(lang_stats)} languages: "
            f"{np.mean([np.mean(ls) for ls in lang_scores]):.3f}"
        )
        if fout:
            fout.write(
                f"\nMacro-average SIM-o over {len(lang_stats)} languages: "
                f"{np.mean([np.mean(ls) for ls in lang_scores]):.3f}\n"
            )

    # Overall (micro) average across every successfully scored pair.
    if scores:
        avg_score = np.mean(scores)
        logging.info(f"Processed {len(scores)}/{len(samples)} pairs.")
        logging.info(f"SIM-o score: {avg_score:.3f}")
        if fout:
            fout.write(f"\nAverage SIM-o: {avg_score:.3f}\n")
    else:
        logging.error("No valid scores computed.")
    if fout:
        fout.close()
    print("-" * 50)


if __name__ == "__main__":
    main()
omnivoice/eval/utils.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import logging
19
+ from typing import Optional
20
+
21
+ import librosa
22
+ import soundfile as sf
23
+ import torch
24
+
25
+
26
def load_waveform(
    fname: str,
    sample_rate: int,
    dtype: str = "float32",
    device: torch.device = torch.device("cpu"),
    return_numpy: bool = False,
    max_seconds: Optional[float] = None,
) -> torch.Tensor:
    """Read an audio file and return a mono waveform at ``sample_rate``.

    Args:
        fname: Path to the audio file.
        sample_rate: Target sample rate; the audio is resampled when the
            file's native rate differs.
        dtype: Numeric dtype passed to ``soundfile.read``.
        device: Device the returned tensor is moved to (unused when
            ``return_numpy`` is True).
        return_numpy: If True, return a NumPy array instead of a tensor.
        max_seconds: Optional cap on the audio length in seconds; longer
            audio is truncated (with a warning).

    Returns:
        1-D waveform of shape ``(num_samples,)`` — a NumPy array when
        ``return_numpy`` is True, otherwise a torch tensor on ``device``.

    Notes:
        - Stereo input is down-mixed to mono by averaging the channels.
    """
    audio, native_sr = sf.read(fname, dtype=dtype)

    # Down-mix multi-channel audio to a single channel.
    if len(audio.shape) == 2:
        audio = audio.mean(1)

    # Bring the waveform to the requested sample rate.
    if native_sr != sample_rate:
        audio = librosa.resample(audio, orig_sr=native_sr, target_sr=sample_rate)

    # Optionally cap the length to protect downstream models from OOM.
    if max_seconds is not None:
        limit = int(sample_rate * max_seconds)
        if len(audio) > limit:
            audio = audio[:limit]
            logging.warning(
                f"Wav file {fname} is longer than {max_seconds}s, "
                f"truncated to {max_seconds}s to avoid OOM."
            )

    if return_numpy:
        return audio

    return torch.from_numpy(audio).to(device)
omnivoice/eval/wer/common.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ Shared utilities for WER evaluation scripts.
20
+ """
21
+ import logging
22
+
23
+ import numpy as np
24
+ from jiwer import compute_measures
25
+
26
+
27
def process_one(hypothesis: str, truth: str, post_process, lang: str = None) -> dict:
    """Compute WER statistics for a single hypothesis/reference pair.

    Args:
        hypothesis: The transcribed text from the ASR model.
        truth: The ground-truth transcript.
        post_process: Normalization callable supplied by the caller.
            Invoked as ``post_process(text, lang)`` when ``lang`` is given,
            otherwise as ``post_process(text)``.
        lang: Optional language code forwarded to ``post_process``.

    Returns:
        dict with keys:
            - ``truth`` / ``hypo``: normalized reference / hypothesis text.
            - ``wer``: word error rate.
            - ``substitutions`` / ``deletions`` / ``insertions``: edit counts.
            - ``word_num``: word count of the normalized reference.
    """
    if lang is None:
        ref = post_process(truth)
        hyp = post_process(hypothesis)
    else:
        ref = post_process(truth, lang)
        hyp = post_process(hypothesis, lang)

    stats = compute_measures(ref, hyp)
    return {
        "truth": ref,
        "hypo": hyp,
        "wer": stats["wer"],
        "substitutions": stats["substitutions"],
        "deletions": stats["deletions"],
        "insertions": stats["insertions"],
        "word_num": len(ref.split(" ")),
    }
66
+
67
+
68
def log_metrics(fout, prefix, i_list, d_list, s_list, w_total, ndigits=2):
    """Log weighted WER for one subset of results and return the WER (%).

    Args:
        fout: Open file handle for the decode log, or a falsy value to
            skip file output.
        prefix: Label printed in front of each line (e.g. "[en]").
        i_list / d_list / s_list: Per-utterance insertion / deletion /
            substitution counts.
        w_total: Total reference word count for the subset.
        ndigits: Rounding precision for the WER percentage.
    """
    total_ins = np.sum(i_list)
    total_del = np.sum(d_list)
    total_sub = np.sum(s_list)
    # Weighted WER: total edits over total reference words.
    wer_pct = round((total_sub + total_del + total_ins) / w_total * 100, ndigits)

    logging.info(f"{prefix} WER: {wer_pct}%")
    logging.info(
        f"{prefix} Errors: {total_ins} ins, {total_del} del, "
        f"{total_sub} sub / {w_total} words"
    )
    if fout:
        fout.write(f"{prefix} WER: {wer_pct}%\n")
        fout.write(
            f"{prefix} Errors: {total_ins} ins, {total_del} del, "
            f"{total_sub} sub / {w_total} words\n"
        )
    return wer_pct
omnivoice/eval/wer/fleurs.py ADDED
@@ -0,0 +1,517 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Computes word error rate (WER) for FLEURS multilingual evaluation.
19
+
20
+ Uses omnilingual-asr for ASR transcription across 100+ languages.
21
+ Requires a separate environment with ``omnilingual_asr`` installed.
22
+
23
+ Usage:
24
+ python3 omnivoice/eval/wer/fleurs.py \\
25
+ --wav-path results/fleurs \\
26
+ --test-list test.jsonl \\
27
+ --decode-path results/fleurs.wer.log \\
28
+ --model-card omniASR_LLM_Unlimited_7B_v2 \\
29
+ --chunk-size 100 --batch-size 50
30
+ """
31
+ import argparse
32
+ import logging
33
+ import multiprocessing as mp
34
+ import os
35
+ import re
36
+ import sys
37
+ import traceback
38
+ import types
39
+ from collections import defaultdict
40
+ from concurrent.futures import ProcessPoolExecutor, as_completed
41
+ from pathlib import Path
42
+ from typing import List, Union
43
+
44
+ import numpy as np
45
+ import torch
46
+ from tqdm import tqdm
47
+
48
+ try:
49
+ from omnilingual_asr.models.inference.pipeline import ASRInferencePipeline
50
+ from omnilingual_asr.models.wav2vec2_llama.lang_ids import supported_langs
51
+ except ImportError:
52
+ logging.error("Please install omnilingual_asr first.")
53
+ exit(1)
54
+
55
+ # omnilingual-asr may pull a transformers version that lacks
56
+ # HiggsAudioV2TokenizerModel. Pre-register stubs to bypass
57
+ # omnivoice/__init__.py heavy imports.
58
+ if "omnivoice" not in sys.modules:
59
+ _root = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
60
+ for _name in (
61
+ "omnivoice",
62
+ "omnivoice.eval",
63
+ "omnivoice.eval.wer",
64
+ "omnivoice.utils",
65
+ ):
66
+ if _name not in sys.modules:
67
+ _m = types.ModuleType(_name)
68
+ _m.__path__ = [os.path.join(_root, *_name.split(".")[1:])]
69
+ _m.__package__ = _name
70
+ sys.modules[_name] = _m
71
+
72
+ from omnivoice.eval.wer.common import log_metrics, process_one
73
+ from omnivoice.eval.wer.text_norm_omni import text_normalize
74
+ from omnivoice.utils.data_utils import read_test_list
75
+
76
+ # --- Global variables for worker processes ---
77
+ worker_pipe = None
78
+ worker_device = None
79
+
80
+
81
# fix mismatched language codes between OmniVoice and the Omnilingual-ASR
# model: keys are OmniVoice language ids whose code differs from the ASR
# model's convention; values are the ids to look up instead.
rename = {
    "et": "ekk",
    "ms": "zsm",
    "sw": "swh",
    "npi": "nep",
}
88
+
89
+
90
def read_language_mapping_from_tsv(
    mapping_path: Path,
) -> dict[str, Union[str, List[str]]]:
    """Build an ISO 639-3 -> mixed-id mapping from a 4-column TSV file.

    The file must contain a header row followed by rows of
    ``mixed_id<TAB>language_name<TAB>iso_639_3_id<TAB>duration``.
    """
    mapping = {}
    with open(mapping_path, "r", encoding="utf-8") as fin:
        fin.readline()  # discard the header row
        for row in fin:
            mixed_id, _language_name, iso_id, _duration = row.strip().split("\t")
            mapping[iso_id] = mixed_id
    return mapping
101
+
102
+
103
# Repo-level TSV mapping ISO 639-3 codes to OmniVoice "mixed" language ids
# (path resolved relative to this file).
iso_639_3_id_to_mixed_id = read_language_mapping_from_tsv(
    Path(f"{os.path.dirname(__file__)}/../../../docs/lang_id_name_map.tsv")
)

# Map each OmniVoice mixed id (or bare ISO code, when no mixed id is known)
# to the omnilingual-asr language token, e.g. "eng_Latn".
mixed_id_to_omnilingual_asr_lang = {}

for lang in supported_langs:
    # Skip the "cmn_Hant" token so only one Mandarin entry is kept
    # (presumably the Simplified-script one — TODO confirm intent).
    if lang in ("cmn_Hant",):
        continue
    iso_639_3_lang_code = lang.split("_")[0]
    if iso_639_3_lang_code in iso_639_3_id_to_mixed_id:
        mixed_id = iso_639_3_id_to_mixed_id[iso_639_3_lang_code]
        mixed_id_to_omnilingual_asr_lang[mixed_id] = lang
    else:
        # No mixed id known: fall back to keying by the bare ISO code.
        mixed_id_to_omnilingual_asr_lang[iso_639_3_lang_code] = lang
118
+
119
+
120
def clean_cjk_spaces(text):
    """Strip spaces adjacent to Chinese/Japanese characters.

    Spaces between non-CJK tokens (English words, Korean, numbers) are
    preserved but collapsed to single spaces.
    """
    # Unicode ranges covered:
    #   \u4e00-\u9fff  CJK Unified Ideographs (Chinese)
    #   \u3040-\u309f  Hiragana (Japanese)
    #   \u30a0-\u30ff  Katakana (Japanese)
    #   \u3000-\u303f  CJK Symbols and Punctuation
    cjk = r"\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\u3000-\u303f"

    # Drop whitespace sitting between two CJK characters: "我 爱 你" -> "我爱你".
    text = re.sub(f"([{cjk}])\\s+([{cjk}])", r"\1\2", text)

    # Drop whitespace on either side of a CJK character that borders
    # non-CJK text: "我 爱 you" -> "我爱you".
    text = re.sub(f"([{cjk}])\\s+", r"\1", text)
    text = re.sub(f"\\s+([{cjk}])", r"\1", text)

    # Collapse remaining whitespace runs (e.g. between English words).
    return re.sub(r"\s+", " ", text).strip()
146
+
147
+
148
def get_parser():
    """Build the CLI argument parser for FLEURS WER evaluation.

    Returns:
        argparse.ArgumentParser: Parser exposing wav/decode paths, the
        OmniASR model card, test list, language filter, and worker/batch
        options.
    """
    parser = argparse.ArgumentParser(
        # Fixed description: this script decodes with OmniASR; the old text
        # ("Computes WER with Whisper.") was copy-pasted from another script.
        description="Computes WER with OmniASR.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--wav-path",
        type=str,
        required=True,
        help="Path to the directory containing speech files.",
    )

    parser.add_argument(
        "--extension",
        type=str,
        default="wav",
        help="Extension of the speech files. Default: wav",
    )

    parser.add_argument(
        "--decode-path",
        type=str,
        default=None,
        help="Path to the output file where WER information will be saved. "
        "If not provided, results are only printed to console.",
    )
    parser.add_argument(
        "--model-card",
        type=str,
        default="omniASR_LLM_7B",
        help="Model card name for OmniASR (e.g., omniASR_LLM_7B) or local path.",
    )
    parser.add_argument(
        "--test-list",
        type=str,
        default="test.jsonl",
        help="path of the JSONL test list. Each line is a JSON object "
        "with fields: id, text, ref_audio, ref_text, language_id, language_name.",
    )
    parser.add_argument(
        "--lang",
        type=str,
        default=None,
        help="""Language code to evaluate (e.g., 'en' for English, 'zh' for Chinese).
    If not provided, the script will evaluate all languages found in the test list.
    If specified, only samples of the given language will be evaluated.
    """,
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="Batch size for decoding with the Hugging Face pipeline.",
    )
    parser.add_argument(
        "--nj-per-gpu", type=int, default=1, help="Number of workers per GPU."
    )
    parser.add_argument(
        "--chunk-size",
        type=int,
        default=300,
        help="Number of samples per task chunk sent to workers.",
    )
    return parser
213
+
214
+
215
def load_omni_model(model_card, device):
    """Instantiate the OmniASR inference pipeline.

    Returns the pipeline on success, or None (after logging the error)
    when construction fails.
    """
    logging.info(f"Loading OmniASR model ({model_card}) on {device}...")
    try:
        return ASRInferencePipeline(model_card=model_card, device=str(device))
    except Exception as e:
        logging.error(f"Failed to load OmniASR pipeline: {e}")
        return None
223
+
224
+
225
def process_init(rank_queue, model_card):
    """Worker initializer: claim a GPU rank and load the OmniASR pipeline.

    The pipeline and device are stored in module globals so that
    ``run_eval_worker`` can reuse them within the same process.
    """
    global worker_pipe, worker_device

    # Keep per-worker CPU threading bounded; many workers share the host.
    torch.set_num_threads(2)

    try:
        gpu_rank = rank_queue.get(timeout=10)
    except Exception:
        raise RuntimeError("Failed to get GPU rank from queue.")

    assert torch.cuda.is_available(), "CUDA is required but not available."
    worker_device = torch.device(f"cuda:{gpu_rank}")
    torch.cuda.set_device(gpu_rank)

    logging.info(f"Initializing worker on device: {worker_device}")

    try:
        worker_pipe = load_omni_model(model_card, worker_device)
        if worker_pipe is None:
            # A None pipeline is treated as a fatal init failure.
            raise RuntimeError("Model loading failed.")
    except Exception as e:
        logging.critical(f"Failed to load model on {worker_device}: {e}")
        raise e
253
+
254
+
255
def post_process(text: str, lang: str) -> str:
    """Normalize text into space-separated characters for scoring.

    Args:
        text: The input text to be processed.
        lang: Language token; only the first three characters (the ISO
            639-3 code, e.g. 'eng' from 'eng_Latn') are used.

    Returns:
        Normalized text in which every character is its own token and
        original word boundaries are marked with '|'.
    """
    iso3 = lang[:3]
    normalized = text_normalize(
        text,
        iso_code=iso3,
        lower_case=True,
        remove_numbers=False,
        remove_brackets=False,
    )
    normalized = clean_cjk_spaces(normalized)
    # Character-level tokenization: keep word boundaries visible as '|'.
    return " ".join(normalized.replace(" ", "|"))
277
+
278
+
279
def run_eval_worker(data_chunk, language, batch_size):
    """Transcribe one chunk with the worker's OmniASR pipeline and score it.

    Relies on the module-global ``worker_pipe`` populated by
    ``process_init``. Returns a list of per-utterance metric dicts, or an
    empty list when the pipeline is missing or an error occurs.
    """
    global worker_pipe
    if worker_pipe is None:
        logging.error("Worker pipeline is not initialized!")
        return []

    chunk_metrics = []
    try:
        wav_files = [item["wav_path"] for item in data_chunk]

        # Per-item language codes, falling back to the chunk-level language
        # when a sample does not carry its own.
        langs = [item.get("lang_id", language) for item in data_chunk]

        # The pipeline returns one transcription string per input file.
        hyps = worker_pipe.transcribe(
            wav_files, lang=langs, batch_size=batch_size
        )

        for idx, hyp in enumerate(hyps):
            item = data_chunk[idx]
            metrics = process_one(
                hyp, item["truth_text"], post_process, item.get("lang_id")
            )
            metrics["wav_path"] = item["wav_path"]
            metrics["lang_name"] = item.get("lang_name")
            chunk_metrics.append(metrics)

    except Exception:
        logging.error(
            f"Worker failed on chunk (Lang: {language}):\n{traceback.format_exc()}"
        )
        return []

    return chunk_metrics
325
+
326
+
327
def main():
    """Run multi-GPU FLEURS WER evaluation and aggregate per-language stats."""
    parser = get_parser()
    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
        force=True,
    )

    # 1. Prepare Data: group test samples by language name.
    logging.info("Reading test list...")
    data_by_lang = defaultdict(list)
    total_files = 0
    wav_root = Path(args.wav_path)

    samples = read_test_list(args.test_list)
    for s in samples:
        wav_path = str(wav_root / f"{s['id']}.{args.extension}")
        if not os.path.exists(wav_path):
            logging.warning(f"File missing: {wav_path}")
            continue

        # Translate the OmniVoice language id into the omnilingual-asr token.
        # NOTE(review): a language_id absent from the mapping (including the
        # "unknown" fallback) raises KeyError here — confirm test lists
        # always carry mapped ids.
        lang_id = s.get("language_id") or "unknown"
        if lang_id in rename:
            lang_id = mixed_id_to_omnilingual_asr_lang[rename[lang_id]]
        else:
            lang_id = mixed_id_to_omnilingual_asr_lang[lang_id]
        item = {
            "wav_path": wav_path,
            "truth_text": s["text"],
            "lang_id": lang_id,
            "lang_name": s.get("language_name") or "unknown",
        }
        # Optional single-language filter (matches the original id, not the
        # translated ASR token).
        if args.lang and s.get("language_id") != args.lang:
            continue

        data_by_lang[s.get("language_name") or "unknown"].append(item)

        total_files += 1

    logging.info(f"Total files: {total_files} in {len(data_by_lang)} languages.")

    # 2. Worker config: one queue entry per (GPU, slot) pair.
    num_gpus = torch.cuda.device_count()
    assert num_gpus > 0, "No GPU found. GPU is required."
    total_workers = num_gpus * args.nj_per_gpu

    mp.set_start_method("spawn", force=True)
    manager = mp.Manager()
    rank_queue = manager.Queue()

    for _ in range(args.nj_per_gpu):
        for rank in range(num_gpus):
            rank_queue.put(rank)

    # 3. Scheduling: Split languages into chunks
    # This prevents one huge language from blocking a worker for too long,
    # allows better load balancing across the pool.
    tasks = []
    chunk_size = args.chunk_size

    for lang_name, items in data_by_lang.items():
        # Slicing the list into chunks
        for i in range(0, len(items), chunk_size):
            chunk = items[i : i + chunk_size]
            tasks.append({"chunk": chunk, "lang": lang_name})

    logging.info(
        f"Split data into {len(tasks)} chunks (size ~{chunk_size}). Spawning {total_workers} workers."
    )

    # 4. Execution: fan chunks out to the process pool.
    results = []

    with ProcessPoolExecutor(
        max_workers=total_workers,
        initializer=process_init,
        initargs=(rank_queue, args.model_card),
    ) as executor:

        futures = []
        for task in tasks:
            futures.append(
                executor.submit(
                    run_eval_worker, task["chunk"], task["lang"], args.batch_size
                )
            )

        # Unified progress bar (advances by completed utterances, not chunks).
        with tqdm(total=total_files, desc="Eval Progress", dynamic_ncols=True) as pbar:
            for future in as_completed(futures):
                try:
                    chunk_metrics = future.result()
                    results.extend(chunk_metrics)
                    pbar.update(len(chunk_metrics))
                except Exception as e:
                    logging.error(f"Task failed: {e}")

    # 5. Metrics Aggregation
    wers, inses, deles, subses = [], [], [], []
    word_nums = 0

    # Store metrics per language
    lang_stats = {}

    fout = None
    if args.decode_path:
        # NOTE(review): fails if decode_path has no directory component
        # (os.path.dirname returns "").
        os.makedirs(os.path.dirname(args.decode_path), exist_ok=True)
        logging.info(f"Saving detailed WER results to: {args.decode_path}")
        fout = open(args.decode_path, "w", encoding="utf-8")

    for res in results:
        wers.append(float(res["wer"]))
        inses.append(float(res["insertions"]))
        deles.append(float(res["deletions"]))
        subses.append(float(res["substitutions"]))
        word_nums += res["word_num"]

        if fout:
            fout.write(
                f"{res['wav_path']}\t{res['wer']}\t{res['truth']}\t"
                f"{res['hypo']}\t{res['insertions']}\t{res['deletions']}\t"
                f"{res['substitutions']}\n"
            )
        lang_name = res["lang_name"]

        # Per language stats
        if lang_name not in lang_stats:
            lang_stats[lang_name] = {
                "inses": [],
                "deles": [],
                "subses": [],
                "word_nums": 0,
            }
        lang_stats[lang_name]["inses"].append(float(res["insertions"]))
        lang_stats[lang_name]["deles"].append(float(res["deletions"]))
        lang_stats[lang_name]["subses"].append(float(res["substitutions"]))
        lang_stats[lang_name]["word_nums"] += res["word_num"]

    print("-" * 50)
    # Log per-language stats
    per_lang_wers = []
    for lang in sorted(lang_stats.keys()):
        stats = lang_stats[lang]
        if stats["word_nums"] > 0:
            lang_wer = log_metrics(
                fout,
                f"[{lang}]",
                stats["inses"],
                stats["deles"],
                stats["subses"],
                stats["word_nums"],
            )
            per_lang_wers.append(lang_wer)
    print("-" * 50)

    # Log Macro-average WER (unweighted mean of per-language WERs).
    if len(per_lang_wers) > 1:
        macro_wer = np.mean(per_lang_wers)
        logging.info(
            f"Macro-average WER over {len(per_lang_wers)} languages: {macro_wer:.2f}%"
        )
        if fout:
            fout.write(
                f"Macro-average WER over {len(per_lang_wers)} languages: {macro_wer:.2f}%\n"
            )
        # Summary buckets: how many languages fall under each WER threshold.
        count_le_5 = sum(1 for w in per_lang_wers if w <= 5.0)
        count_le_10 = sum(1 for w in per_lang_wers if w <= 10.0)
        count_le_20 = sum(1 for w in per_lang_wers if w <= 20.0)

        stats_msg = (
            f"Languages with WER/CER <= 5%: {count_le_5}/{len(per_lang_wers)}\n"
            f"Languages with WER/CER <= 10%: {count_le_10}/{len(per_lang_wers)}\n"
            f"Languages with WER/CER <= 20%: {count_le_20}/{len(per_lang_wers)}"
        )

        logging.info("\n" + stats_msg)
        if fout:
            fout.write(stats_msg + "\n")

    # Log overall stats (word-count-weighted over all utterances).
    if word_nums > 0:
        log_metrics(fout, "Overall", inses, deles, subses, word_nums)

    if fout:
        fout.close()
514
+
515
+
516
+ if __name__ == "__main__":
517
+ main()
omnivoice/eval/wer/hubert.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ Computes word error rate (WER) with Hubert models for LibriSpeech test sets.
20
+ """
21
+ import argparse
22
+ import logging
23
+ import multiprocessing as mp
24
+ import os
25
+ import re
26
+ import traceback
27
+ from concurrent.futures import ProcessPoolExecutor, as_completed
28
+ from pathlib import Path
29
+
30
+ import numpy as np
31
+ import torch
32
+ from tqdm import tqdm
33
+
34
+ from omnivoice.eval.utils import load_waveform
35
+ from omnivoice.eval.wer.common import process_one
36
+ from omnivoice.utils.data_utils import read_test_list
37
+
38
+ # --- Global variables for worker processes ---
39
+ worker_pipe = None
40
+ worker_device = None
41
+
42
+
43
def get_parser():
    """Build the CLI argument parser for Hubert-based WER evaluation."""
    p = argparse.ArgumentParser(
        description="Computes WER with Hubert-based ASR model.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument(
        "--wav-path",
        type=str,
        required=True,
        help="Path to the directory containing speech files.",
    )
    p.add_argument(
        "--extension",
        type=str,
        default="wav",
        help="Extension of the speech files. Default: wav",
    )
    p.add_argument(
        "--decode-path",
        type=str,
        default=None,
        help="Path to the output file where WER information will be saved. "
        "If not provided, results are only printed to console.",
    )
    p.add_argument(
        "--model-dir",
        type=str,
        required=True,
        help="Local path of our evaluation model repository."
        "Download from https://huggingface.co/k2-fsa/TTS_eval_models."
        "Will use 'tts_eval_models/wer/hubert-large-ls960-ft/'"
        " in this script",
    )
    p.add_argument(
        "--test-list",
        type=str,
        default="transcript.jsonl",
        help="path of the JSONL test list. Each line is a JSON object "
        "with fields: id, text, ref_audio, ref_text, language_id, language_name.",
    )
    p.add_argument(
        "--batch-size",
        type=int,
        default=16,
        help="Batch size for decoding with the Hugging Face pipeline.",
    )
    p.add_argument(
        "--nj-per-gpu", type=int, default=1, help="Number of workers per GPU."
    )
    return p
93
+
94
+
95
def process_init(rank_queue, model_dir):
    """Worker initializer: claim a GPU rank and load the Hubert ASR pipeline.

    The pipeline and device are stored in module globals so that
    ``run_eval_worker`` can reuse them within the same process.
    """
    global worker_pipe, worker_device

    # Keep per-worker CPU threading bounded; many workers share the host.
    torch.set_num_threads(2)

    try:
        gpu_rank = rank_queue.get(timeout=10)
    except Exception:
        raise RuntimeError("Failed to get GPU rank from queue.")

    assert torch.cuda.is_available(), "CUDA is required but not available."
    worker_device = torch.device(f"cuda:{gpu_rank}")
    torch.cuda.set_device(gpu_rank)

    logging.info(f"Initializing worker on device: {worker_device}")

    try:
        worker_pipe = load_hubert_model(model_dir, worker_device)
        if worker_pipe is None:
            # A None pipeline is treated as a fatal init failure.
            raise RuntimeError("Model loading failed.")
    except Exception as e:
        logging.critical(f"Failed to load model on {worker_device}: {e}")
        raise e
118
+
119
+
120
def load_hubert_model(model_dir, device):
    """Create a transformers ASR pipeline from the local Hubert checkpoint.

    Args:
        model_dir: Root of the TTS_eval_models repository; the checkpoint is
            expected under ``wer/hubert-large-ls960-ft/``.
        device: Device for the pipeline.

    Returns:
        The pipeline, or None (after logging an error) when the checkpoint
        directory does not exist.
    """
    model_path = os.path.join(model_dir, "wer/hubert-large-ls960-ft/")
    if not os.path.exists(model_path):
        logging.error(
            f"Hubert model not found at {model_path}. "
            "Please download from https://huggingface.co/k2-fsa/TTS_eval_models"
        )
        return None

    logging.debug(f"Loading Hubert-based ASR model on {device}...")
    # Imported lazily so the existence check above stays cheap.
    import transformers

    # Suppress transformers logging
    transformers.logging.set_verbosity_error()

    return transformers.pipeline(
        "automatic-speech-recognition",
        model=model_path,
        device=device,
        tokenizer=model_path,
    )
142
+
143
+
144
def post_process(text: str) -> str:
    """Normalize English text for WER scoring.

    Lowercases, unifies curly apostrophes, replaces everything except
    alphanumerics and apostrophes with spaces, and collapses whitespace.
    """
    # Unify curly apostrophes so contractions survive the punctuation strip.
    normalized = text.replace("‘", "'").replace("’", "'")
    normalized = re.sub(r"[^a-zA-Z0-9']", " ", normalized.lower())
    return re.sub(r"\s+", " ", normalized).strip()
157
+
158
+
159
def run_eval_worker(data_chunk, batch_size):
    """Transcribe one chunk of wavs with the worker's pipeline and score WER.

    Uses the module-global ``worker_pipe`` set up by ``process_init``.
    Returns a list of per-utterance metric dicts; an empty list when the
    pipeline is missing or any error occurs.
    """
    global worker_pipe
    if worker_pipe is None:
        logging.error("Worker pipeline is not initialized!")
        return []

    metrics_buffer = []
    try:
        # Decode audio up front at 16 kHz; the HF pipeline accepts raw arrays.
        dataset = [
            {
                "array": load_waveform(
                    item["wav_path"], sample_rate=16000, return_numpy=True
                ),
                "sampling_rate": 16000,
            }
            for item in data_chunk
        ]
        # NOTE(review): these look like Whisper-style generation kwargs; a
        # Hubert CTC pipeline may ignore or reject them — confirm.
        generate_kwargs = {"language": "english", "task": "transcribe"}

        iterator = worker_pipe(
            dataset, generate_kwargs=generate_kwargs, batch_size=batch_size
        )

        # Pipeline outputs arrive in input order, so index back into the chunk.
        for i, out in enumerate(iterator):
            hypothesis = out["text"].strip()
            ref_item = data_chunk[i]
            truth = ref_item["truth_text"]
            wav_path = ref_item["wav_path"]

            m = process_one(hypothesis, truth, post_process)
            m["wav_path"] = wav_path
            metrics_buffer.append(m)

    except Exception:
        # Fail soft: log the traceback and drop the whole chunk.
        logging.error(f"Worker failed on chunk:\n{traceback.format_exc()}")
        return []

    return metrics_buffer
197
+
198
+
199
def main():
    """Run multi-GPU Hubert WER evaluation over a JSONL test list."""
    parser = get_parser()
    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
        force=True,
    )

    logging.info(f"Calculating WER for {args.wav_path}")

    # Collect (wav, reference) pairs, skipping missing files.
    data_list = []
    samples = read_test_list(args.test_list)
    for s in samples:
        wav_full_path = str(Path(args.wav_path) / (s["id"] + "." + args.extension))
        if not os.path.exists(wav_full_path):
            logging.warning(f"File missing: {wav_full_path}")
            continue
        data_list.append(
            {
                "wav_path": wav_full_path,
                "truth_text": s["text"],
            }
        )
    total_files = len(data_list)

    # Worker config: one queue entry per (GPU, slot) pair.
    num_gpus = torch.cuda.device_count()
    assert num_gpus > 0, "No GPU found. GPU is required."
    total_workers = num_gpus * args.nj_per_gpu

    mp.set_start_method("spawn", force=True)
    manager = mp.Manager()
    rank_queue = manager.Queue()

    for _ in range(args.nj_per_gpu):
        for rank in range(num_gpus):
            rank_queue.put(rank)

    # One chunk per batch so each task maps to a single pipeline call.
    chunk_size = max(1, args.batch_size)
    tasks = [data_list[i : i + chunk_size] for i in range(0, total_files, chunk_size)]

    logging.info(
        f"Split data into {len(tasks)} chunks (size ~{chunk_size}). "
        f"Spawning {total_workers} workers."
    )

    results = []

    with ProcessPoolExecutor(
        max_workers=total_workers,
        initializer=process_init,
        initargs=(rank_queue, args.model_dir),
    ) as executor:

        futures = []
        for chunk in tasks:
            futures.append(executor.submit(run_eval_worker, chunk, args.batch_size))

        # Progress advances by completed utterances, not chunks.
        with tqdm(total=total_files, desc="Eval Progress", dynamic_ncols=True) as pbar:
            for future in as_completed(futures):
                chunk_metrics = future.result()
                results.extend(chunk_metrics)
                pbar.update(len(chunk_metrics))

    wers, inses, deles, subses = [], [], [], []
    word_nums = 0

    fout = None
    if args.decode_path:
        # NOTE(review): fails if decode_path has no directory component
        # (os.path.dirname returns "").
        os.makedirs(os.path.dirname(args.decode_path), exist_ok=True)
        fout = open(args.decode_path, "w", encoding="utf8")
        logging.info(f"Saving detailed WER results to: {args.decode_path}")
        fout.write(
            "Name\tWER\tTruth\tHypothesis\tInsertions\tDeletions\tSubstitutions\n"
        )

    # Accumulate error counts; write one TSV row per utterance if requested.
    for res in results:
        wers.append(float(res["wer"]))
        inses.append(float(res["insertions"]))
        deles.append(float(res["deletions"]))
        subses.append(float(res["substitutions"]))
        word_nums += res["word_num"]

        if fout:
            fout.write(
                f"{res['wav_path']}\t{res['wer']}\t{res['truth']}\t"
                f"{res['hypo']}\t{res['insertions']}\t{res['deletions']}\t"
                f"{res['substitutions']}\n"
            )

    # Word-count-weighted WER over the whole set (NaN when nothing scored).
    wer_weighted = (
        round(
            (np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 2
        )
        if word_nums > 0
        else float("nan")
    )

    inse_sum = np.sum(inses)
    dele_sum = np.sum(deles)
    subs_sum = np.sum(subses)

    print("-" * 50)
    logging.info(f"Processed {len(results)}/{total_files} files.")
    wer_info = f"WER: {wer_weighted}%"
    detailed_info = (
        f"Errors: {inse_sum} ins, {dele_sum} del, {subs_sum} sub / {word_nums} words"
    )
    logging.info(wer_info)
    logging.info(detailed_info)
    print("-" * 50)

    if fout:
        fout.write(wer_info + "\n" + detailed_info + "\n")
        fout.close()
315
+
316
+
317
+ if __name__ == "__main__":
318
+ main()
omnivoice/eval/wer/minimax.py ADDED
@@ -0,0 +1,596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ Computes word error rate (WER) with Whisper-large-v3 for English and
20
+ Paraformer for Chinese. Intended to evaluate WERs on Seed-TTS test sets.
21
+ """
22
+ import argparse
23
+ import logging
24
+ import multiprocessing as mp
25
+ import os
26
+ import traceback
27
+ from collections import defaultdict
28
+ from concurrent.futures import ProcessPoolExecutor, as_completed
29
+ from pathlib import Path
30
+ from typing import List, Union
31
+
32
+ import numpy as np
33
+ import torch
34
+ import zhconv
35
+ from tqdm import tqdm
36
+
37
+ from omnivoice.eval.utils import load_waveform
38
+ from omnivoice.eval.wer.common import log_metrics, process_one
39
+ from omnivoice.eval.wer.text_norm_omni import text_normalize
40
+ from omnivoice.utils.data_utils import read_test_list
41
+
42
# --- Global variables for worker processes ---
# Populated independently inside each worker process by the
# ProcessPoolExecutor initializers (process_init / process_init_paraformer);
# they are per-process state, never shared between processes.
worker_pipe = None  # Whisper ASR pipeline, set by process_init
worker_paraformer = None  # FunASR Paraformer model, set by process_init_paraformer
worker_device = None  # torch.device claimed in _worker_setup
46
+
47
+
48
def read_language_mapping_from_tsv(
    mapping_path: Path,
) -> dict[str, Union[str, List[str]]]:
    """Read the language-ID mapping table from a tab-separated file.

    The file is expected to contain a header row followed by rows of
    ``mixed_id<TAB>language_name<TAB>iso_639_3_id<TAB>duration``.

    Args:
        mapping_path: Path to the TSV mapping file.

    Returns:
        Mapping from a "mixed" language id (e.g. ``zh``) to its
        ISO 639-3 code (e.g. ``zho``).
    """
    language_mapping: dict[str, str] = {}
    with open(mapping_path, "r", encoding="utf-8") as f:
        _ = f.readline()  # Skip header
        for line_no, line in enumerate(f, start=2):
            line = line.strip()
            if not line:
                # Tolerate blank (e.g. trailing) lines instead of raising
                # ValueError on the strict tuple unpacking.
                continue
            parts = line.split("\t")
            if len(parts) < 3:
                # Malformed row: warn and skip rather than abort the whole load.
                logging.warning(
                    "Skipping malformed line %d in %s: %r",
                    line_no,
                    mapping_path,
                    line,
                )
                continue
            mixed_id, _language_name, iso_639_3_id = parts[:3]
            language_mapping[mixed_id] = iso_639_3_id
    return language_mapping
59
+
60
+
61
# Module-level lookup table: "mixed" language id -> ISO 639-3 code,
# loaded once at import time from the repository's docs directory.
mixed_id_to_iso_639_3_id = read_language_mapping_from_tsv(
    Path(f"{os.path.dirname(__file__)}/../../../docs/lang_id_name_map.tsv")
)
64
+
65
+
66
def get_parser():
    """Build the command-line argument parser for this WER evaluation script."""
    parser = argparse.ArgumentParser(
        description="Computes WER with Whisper.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    add = parser.add_argument

    add(
        "--wav-path",
        type=str,
        required=True,
        help="Path to the directory containing speech files.",
    )
    add(
        "--extension",
        type=str,
        default="wav",
        help="Extension of the speech files. Default: wav",
    )
    add(
        "--decode-path",
        type=str,
        default=None,
        help=(
            "Path to the output file where WER information will be saved. "
            "If not provided, results are only printed to console."
        ),
    )
    add(
        "--model-dir",
        type=str,
        required=True,
        help=(
            "Local path of evaluation models repository. "
            "Download from https://huggingface.co/k2-fsa/TTS_eval_models. "
        ),
    )
    add(
        "--test-list",
        type=str,
        default="test.jsonl",
        help=(
            "path of the JSONL test list. Each line is a JSON object "
            "with fields: id, text, ref_audio, ref_text, language_id, language_name."
        ),
    )
    add(
        "--lang",
        type=str,
        default=None,
        help="""Language code to evaluate (e.g., 'en' for English, 'zh' for Chinese).
If not provided, the script will evaluate all languages found in the test list.
If specified, only samples of the given language will be evaluated.
""",
    )
    add(
        "--batch-size",
        type=int,
        default=16,
        help="Batch size for decoding with the Hugging Face pipeline.",
    )
    add("--nj-per-gpu", type=int, default=1, help="Number of workers per GPU.")
    add(
        "--chunk-size",
        type=int,
        default=10,
        help="Number of samples per task chunk sent to workers.",
    )
    return parser
132
+
133
+
134
def load_whisper_model(model_dir, device):
    """Create a Whisper ASR pipeline from a local checkpoint directory.

    Returns ``None`` (after logging an error) when the checkpoint directory
    is missing; otherwise a Hugging Face ASR pipeline placed on ``device``.
    """
    model_path = os.path.join(model_dir, "wer/whisper-large-v3/")
    if not os.path.exists(model_path):
        logging.error(f"Whisper model not found at {model_path}.")
        return None

    import transformers

    # Suppress transformers logging
    transformers.logging.set_verbosity_error()

    logging.info(f"Loading Whisper model on {device}...")
    # Half precision only pays off on GPU; stay in fp32 on CPU.
    use_fp16 = "cuda" in str(device)
    return transformers.pipeline(
        "automatic-speech-recognition",
        model=model_path,
        chunk_length_s=30,
        dtype=torch.float16 if use_fp16 else torch.float32,
        device=device,
    )
154
+
155
+
156
def load_paraformer_model(model_dir, device):
    """Load the FunASR Paraformer model used for Chinese transcription.

    Returns ``None`` (after logging an error) if the checkpoint directory is
    missing. Global logging is temporarily disabled while FunASR initializes
    because it is very chatty at load time.
    """
    model_path = os.path.join(model_dir, "wer/paraformer-zh/")
    if not os.path.exists(model_path):
        logging.error(f"Paraformer model not found at {model_path}.")
        return None

    logging.info(f"Loading Paraformer model on {device}...")

    # Remember the current global "disable" level so it can be restored.
    previous_level = logging.root.manager.disable
    logging.disable(logging.CRITICAL)
    try:
        from funasr import AutoModel

        model = AutoModel(
            model=model_path,
            device=str(device),
            disable_update=True,
            disable_pbar=True,
            verbose=False,
        )
    finally:
        # Always restore logging, even if model construction fails.
        logging.disable(previous_level)

    return model
181
+
182
+
183
+ def _worker_setup(rank_queue):
184
+ """Common worker setup: get rank, configure device and threads."""
185
+ global worker_device
186
+
187
+ torch.set_num_threads(2)
188
+
189
+ try:
190
+ rank = rank_queue.get(timeout=10)
191
+ except Exception:
192
+ raise RuntimeError("Failed to get GPU rank from queue.")
193
+
194
+ assert torch.cuda.is_available(), "CUDA is required but not available."
195
+ worker_device = torch.device(f"cuda:{rank}")
196
+ torch.cuda.set_device(rank)
197
+
198
+ logging.info(f"Initializing worker on device: {worker_device}")
199
+
200
+
201
def process_init(rank_queue, model_dir):
    """ProcessPoolExecutor initializer for Whisper workers.

    Binds the worker to a GPU via ``_worker_setup`` and populates the
    module-global ``worker_pipe`` with a loaded Whisper pipeline.
    """
    global worker_pipe

    _worker_setup(rank_queue)

    try:
        worker_pipe = load_whisper_model(model_dir, worker_device)
        if worker_pipe is None:
            raise RuntimeError("Whisper model loading failed.")
    except Exception as err:
        # Log loudly: a failed initializer takes down the whole pool.
        logging.critical(f"Failed to load Whisper model on {worker_device}: {err}")
        raise err
214
+
215
+
216
def process_init_paraformer(rank_queue, model_dir):
    """ProcessPoolExecutor initializer for Paraformer (Chinese) workers.

    Binds the worker to a GPU via ``_worker_setup`` and populates the
    module-global ``worker_paraformer`` with a loaded FunASR model.
    """
    global worker_paraformer

    _worker_setup(rank_queue)

    try:
        worker_paraformer = load_paraformer_model(model_dir, worker_device)
        if worker_paraformer is None:
            raise RuntimeError("Paraformer model loading failed.")
    except Exception as err:
        # Log loudly: a failed initializer takes down the whole pool.
        logging.critical(f"Failed to load Paraformer model on {worker_device}: {err}")
        raise err
229
+
230
+
231
def post_process(text: str, lang: str) -> str:
    """
    Cleans and normalizes text for WER calculation.

    Args:
        text (str): The input text to be processed.
        lang (str): The language of the input text ("unknown" skips
            language-specific normalization).

    Returns:
        str: The cleaned and normalized text.
    """
    if lang != "unknown":
        # Apply the omnilingual normalizer under the mapped ISO 639-3 code.
        text = text_normalize(
            text,
            iso_code=mixed_id_to_iso_639_3_id[lang],
            lower_case=True,
            remove_numbers=False,
            remove_brackets=False,
        )

    # Fold Traditional into Simplified so references and hypotheses agree.
    if lang in ("zh", "yue"):
        text = zhconv.convert(text, "zh-cn")

    # Processing spaces for languages using CER (consistent with the practice
    # in paper Minimax-Speech), specifically: zh, yue, ja, ko, th, arb, vi, hi, el.
    if lang in ("zh", "yue", "ja"):
        # Spaces carry no meaning: drop them, then split into characters.
        text = " ".join(text.replace(" ", ""))
    elif lang in ("ko", "th", "arb", "vi", "hi", "el"):
        # Spaces are meaningful: keep them as '|', then split into characters.
        text = " ".join(text.replace(" ", "|"))

    return text.lower().strip()
267
+
268
+
269
class SpeechEvalDataset(torch.utils.data.Dataset):
    """Map-style dataset that lazily loads 16 kHz waveforms for ASR eval.

    Each element is a dict with the decoded audio (``array``), its sample
    rate, and the reference transcript (``truth_text``).
    """

    SAMPLE_RATE = 16000  # target sample rate expected by the ASR models

    def __init__(self, data_list):
        # Each entry must provide "wav_path" and "truth_text".
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        entry = self.data_list[index]
        audio = load_waveform(
            entry["wav_path"], sample_rate=self.SAMPLE_RATE, return_numpy=True
        )
        return {
            "array": audio,
            "sampling_rate": self.SAMPLE_RATE,
            "truth_text": entry["truth_text"],
        }
284
+
285
+
286
def run_eval_worker(data_chunk, language, batch_size):
    """
    Worker function to process a chunk of data.
    Uses the global worker_pipe initialized by process_init.

    Returns a list of per-utterance metric dicts; an empty list on any
    failure (the error is logged, not raised, so the pool keeps going).
    """
    global worker_pipe
    if worker_pipe is None:
        logging.error("Worker pipeline is not initialized!")
        return []

    metrics_buffer = []
    try:
        dataset = SpeechEvalDataset(data_chunk)

        # Only force a decoding language when we actually know it.
        if language != "unknown":
            generate_kwargs = {"language": language, "task": "transcribe"}
        else:
            generate_kwargs = {"task": "transcribe"}

        # The pipeline yields one result per dataset item, in order, so the
        # idx-th output lines up with data_chunk[idx].
        outputs = worker_pipe(
            dataset, generate_kwargs=generate_kwargs, batch_size=batch_size
        )

        for idx, out in enumerate(outputs):
            ref_item = data_chunk[idx]
            metric = process_one(
                out["text"].strip(),
                ref_item["truth_text"],
                post_process,
                ref_item.get("lang_id"),
            )
            metric["wav_path"] = ref_item["wav_path"]
            metric["lang_name"] = ref_item.get("lang_name")
            metrics_buffer.append(metric)

    except Exception:
        logging.error(
            f"Worker failed on chunk (Lang: {language}):\n{traceback.format_exc()}"
        )
        return []

    return metrics_buffer
331
+
332
+
333
def run_eval_worker_paraformer(data_chunk, batch_size):
    """
    Worker function for Chinese evaluation using Paraformer.
    Uses the global worker_paraformer initialized by process_init_paraformer.

    Returns a list of per-utterance metric dicts; an empty list on any
    failure (the error is logged, not raised, so the pool keeps going).
    """
    global worker_paraformer
    if worker_paraformer is None:
        logging.error("Paraformer worker pipeline is not initialized!")
        return []

    metrics_buffer = []
    try:
        wav_paths = [item["wav_path"] for item in data_chunk]

        # Decode in mini-batches; funasr returns results in input order.
        for start in range(0, len(wav_paths), batch_size):
            batch_paths = wav_paths[start : start + batch_size]
            outputs = worker_paraformer.generate(
                input=batch_paths, batch_size=batch_size, disable_pbar=True
            )

            for offset, out in enumerate(outputs):
                ref_item = data_chunk[start + offset]
                metric = process_one(
                    out["text"], ref_item["truth_text"], post_process, "zh"
                )
                metric["wav_path"] = ref_item["wav_path"]
                metric["lang_name"] = ref_item.get("lang_name")
                metrics_buffer.append(metric)

    except Exception:
        logging.error(f"Paraformer worker failed on chunk:\n{traceback.format_exc()}")
        return []

    return metrics_buffer
370
+
371
+
372
def main():
    """Entry point: multi-GPU WER evaluation over a JSONL test list.

    Routes Chinese samples to Paraformer and everything else to Whisper,
    runs each model in its own process pool (one sequential pool at a
    time), then aggregates per-language and overall WER.
    """
    parser = get_parser()
    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
        force=True,
    )

    # 1. Prepare Data: group samples by language name, skipping synthesized
    #    wavs that are missing on disk (and, if --lang is set, other languages).
    logging.info("Reading test list...")
    data_by_lang = defaultdict(list)
    total_files = 0
    wav_root = Path(args.wav_path)

    samples = read_test_list(args.test_list)
    for s in samples:
        wav_path = str(wav_root / f"{s['id']}.{args.extension}")
        if not os.path.exists(wav_path):
            logging.warning(f"File missing: {wav_path}")
            continue

        # Fall back to "unknown" when the test list omits language fields.
        lang_id = s.get("language_id") or "unknown"
        lang_name = s.get("language_name") or "unknown"

        item = {
            "wav_path": wav_path,
            "truth_text": s["text"],
            "lang_id": lang_id,
            "lang_name": lang_name,
        }
        if args.lang and s.get("language_id") != args.lang:
            continue

        data_by_lang[lang_name].append(item)
        total_files += 1

    logging.info(f"Total files: {total_files} in {len(data_by_lang)} languages.")

    # 2. Worker config: nj-per-gpu workers on each visible GPU.
    num_gpus = torch.cuda.device_count()
    assert num_gpus > 0, "No GPU found. GPU is required."
    total_workers = num_gpus * args.nj_per_gpu

    # "spawn" is required so each worker gets a clean CUDA context.
    mp.set_start_method("spawn", force=True)
    manager = mp.Manager()

    # 3. Scheduling: Split data into Chinese (Paraformer) and non-Chinese (Whisper)
    zh_items = []
    non_zh_items = []
    for lang_name, items in data_by_lang.items():
        lang_id = items[0].get("lang_id", "") if items else ""
        if lang_name == "Chinese" or (lang_id and lang_id.startswith("zh")):
            zh_items.extend(items)
        else:
            non_zh_items.extend(items)

    chunk_size = args.chunk_size

    # Items were grouped per language above, so each Whisper chunk is
    # (at most, up to chunk boundaries) single-language; the chunk's first
    # item determines the decoding language passed to the worker.
    whisper_tasks = []
    for i in range(0, len(non_zh_items), chunk_size):
        chunk = non_zh_items[i : i + chunk_size]
        lang_name = chunk[0].get("lang_name", "unknown")
        whisper_tasks.append({"chunk": chunk, "lang": lang_name})

    paraformer_tasks = []
    for i in range(0, len(zh_items), chunk_size):
        paraformer_tasks.append(zh_items[i : i + chunk_size])

    logging.info(
        f"Whisper tasks: {len(whisper_tasks)} chunks ({len(non_zh_items)} files). "
        f"Paraformer tasks: {len(paraformer_tasks)} chunks ({len(zh_items)} files). "
        f"Spawning {total_workers} workers per pool."
    )

    # 4. Execution — run Whisper and Paraformer pools sequentially
    #    (sequential pools avoid both models competing for GPU memory).
    results = []

    # 4a. Whisper pool for non-Chinese languages
    if whisper_tasks:
        # Pre-fill the rank queue; each worker initializer pops one rank.
        whisper_rank_queue = manager.Queue()
        for _ in range(args.nj_per_gpu):
            for rank in range(num_gpus):
                whisper_rank_queue.put(rank)

        with ProcessPoolExecutor(
            max_workers=total_workers,
            initializer=process_init,
            initargs=(whisper_rank_queue, args.model_dir),
        ) as executor:

            futures = []
            for task in whisper_tasks:
                futures.append(
                    executor.submit(
                        run_eval_worker, task["chunk"], task["lang"], args.batch_size
                    )
                )

            with tqdm(
                total=len(non_zh_items),
                desc="Whisper Eval",
                dynamic_ncols=True,
            ) as pbar:
                for future in as_completed(futures):
                    try:
                        chunk_metrics = future.result()
                        results.extend(chunk_metrics)
                        pbar.update(len(chunk_metrics))
                    except Exception as e:
                        logging.error(f"Whisper task failed: {e}")

    # 4b. Paraformer pool for Chinese
    if paraformer_tasks:
        para_rank_queue = manager.Queue()
        for _ in range(args.nj_per_gpu):
            for rank in range(num_gpus):
                para_rank_queue.put(rank)

        with ProcessPoolExecutor(
            max_workers=total_workers,
            initializer=process_init_paraformer,
            initargs=(para_rank_queue, args.model_dir),
        ) as executor:

            futures = []
            for chunk in paraformer_tasks:
                futures.append(
                    executor.submit(run_eval_worker_paraformer, chunk, args.batch_size)
                )

            with tqdm(
                total=len(zh_items),
                desc="Paraformer Eval",
                dynamic_ncols=True,
            ) as pbar:
                for future in as_completed(futures):
                    try:
                        chunk_metrics = future.result()
                        results.extend(chunk_metrics)
                        pbar.update(len(chunk_metrics))
                    except Exception as e:
                        logging.error(f"Paraformer task failed: {e}")

    # 5. Metrics Aggregation
    # NOTE(review): `wers` is collected but never read below — the reported
    # numbers are recomputed from ins/del/sub counts and word totals.
    wers, inses, deles, subses = [], [], [], []
    word_nums = 0

    # Store metrics per language
    lang_stats = {}

    fout = None
    if args.decode_path:
        os.makedirs(os.path.dirname(args.decode_path), exist_ok=True)
        logging.info(f"Saving detailed WER results to: {args.decode_path}")
        fout = open(args.decode_path, "w", encoding="utf-8")

    for res in results:
        wers.append(float(res["wer"]))
        inses.append(float(res["insertions"]))
        deles.append(float(res["deletions"]))
        subses.append(float(res["substitutions"]))
        word_nums += res["word_num"]

        if fout:
            fout.write(
                f"{res['wav_path']}\t{res['wer']}\t{res['truth']}\t"
                f"{res['hypo']}\t{res['insertions']}\t{res['deletions']}\t"
                f"{res['substitutions']}\n"
            )
        lang_name = res["lang_name"]

        # Per language stats
        if lang_name not in lang_stats:
            lang_stats[lang_name] = {
                "inses": [],
                "deles": [],
                "subses": [],
                "word_nums": 0,
            }
        lang_stats[lang_name]["inses"].append(float(res["insertions"]))
        lang_stats[lang_name]["deles"].append(float(res["deletions"]))
        lang_stats[lang_name]["subses"].append(float(res["substitutions"]))
        lang_stats[lang_name]["word_nums"] += res["word_num"]

    print("-" * 50)
    # Log per-language stats
    per_lang_wers = []
    for lang in sorted(lang_stats.keys()):
        stats = lang_stats[lang]
        if stats["word_nums"] > 0:
            lang_wer = log_metrics(
                fout,
                f"[{lang}]",
                stats["inses"],
                stats["deles"],
                stats["subses"],
                stats["word_nums"],
                ndigits=3,
            )
            per_lang_wers.append(lang_wer)
    print("-" * 50)

    # Log Macro-average WER (unweighted mean of per-language WERs)
    if len(per_lang_wers) > 1:
        macro_wer = np.mean(per_lang_wers)
        logging.info(
            f"Macro-average WER over {len(per_lang_wers)} languages: {macro_wer:.2f}%"
        )
        if fout:
            fout.write(
                f"Macro-average WER over {len(per_lang_wers)} languages: {macro_wer:.2f}%\n"
            )

    # Log overall stats (micro-average, weighted by word counts)
    if word_nums > 0:
        log_metrics(fout, "Overall", inses, deles, subses, word_nums)

    if fout:
        fout.close()
593
+
594
+
595
+ if __name__ == "__main__":
596
+ main()
omnivoice/eval/wer/norm_config_module.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the BSD-style license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ """
9
+ This module defines the normalization configuration for WER evaluation.
10
+ Copied from https://github.com/facebookresearch/omnilingual-asr/blob/81f51e224ce9e74b02cc2a3eaf21b2d91d743455/workflows/dataprep/norm_config_module.py
11
+ """
12
+
13
+ # type: ignore
14
+ import os
15
+ import re
16
+
17
+ colon = ":"
18
+ comma = ","
19
+ exclamation_mark = "!"
20
+ period = re.escape(".")
21
+ question_mark = re.escape("?")
22
+ semicolon = ";"
23
+
24
+ left_curly_bracket = "{"
25
+ right_curly_bracket = "}"
26
+ quotation_mark = '"'
27
+
28
+ basic_punc = (
29
+ period
30
+ + question_mark
31
+ + comma
32
+ + colon
33
+ + exclamation_mark
34
+ + left_curly_bracket
35
+ + right_curly_bracket
36
+ )
37
+
38
+ # General punc unicode block (0x2000-0x206F)
39
+ zero_width_space = r"\u200B"
40
+ zero_width_nonjoiner = r"\u200C"
41
+ left_to_right_mark = r"\u200E"
42
+ right_to_left_mark = r"\u200F"
43
+ left_to_right_embedding = r"\u202A"
44
+ pop_directional_formatting = r"\u202C"
45
+
46
+ # Here are some commonly ill-typed versions of apostrophe
47
+ right_single_quotation_mark = r"\u2019"
48
+ left_single_quotation_mark = r"\u2018"
49
+
50
+ # Language specific definitions
51
+ # Spanish
52
+ inverted_exclamation_mark = r"\u00A1"
53
+ inverted_question_mark = r"\u00BF"
54
+
55
+
56
+ # Hindi
57
+ hindi_danda = "\u0964"
58
+
59
+ # Egyptian Arabic
60
+ # arabic_percent = r"\u066A"
61
+ arabic_comma = r"\u060C"
62
+ arabic_question_mark = r"\u061F"
63
+ arabic_semicolon = r"\u061B"
64
+ arabic_diacritics = r"\u064B-\u0652"
65
+
66
+
67
+ arabic_subscript_alef_and_inverted_damma = r"\u0656-\u0657"
68
+
69
+
70
+ # Chinese
71
+ full_stop = r"\u3002"
72
+ full_comma = r"\uFF0C"
73
+ full_exclamation_mark = r"\uFF01"
74
+ full_question_mark = r"\uFF1F"
75
+ full_semicolon = r"\uFF1B"
76
+ full_colon = r"\uFF1A"
77
+ full_parentheses = r"\uFF08\uFF09"
78
+ quotation_mark_horizontal = r"\u300C-\u300F"
79
+ quotation_mark_vertical = r"\uFF41-\uFF44"
80
+ title_marks = r"\u3008-\u300B"
81
+ wavy_low_line = r"\uFE4F"
82
+ ellipsis = r"\u22EF"
83
+ enumeration_comma = r"\u3001"
84
+ hyphenation_point = r"\u2027"
85
+ forward_slash = r"\uFF0F"
86
+ wavy_dash = r"\uFF5E"
87
+ box_drawings_light_horizontal = r"\u2500"
88
+ fullwidth_low_line = r"\uFF3F"
89
+ chinese_punc = (
90
+ full_stop
91
+ + full_comma
92
+ + full_exclamation_mark
93
+ + full_question_mark
94
+ + full_semicolon
95
+ + full_colon
96
+ + full_parentheses
97
+ + quotation_mark_horizontal
98
+ + quotation_mark_vertical
99
+ + title_marks
100
+ + wavy_low_line
101
+ + ellipsis
102
+ + enumeration_comma
103
+ + hyphenation_point
104
+ + forward_slash
105
+ + wavy_dash
106
+ + box_drawings_light_horizontal
107
+ + fullwidth_low_line
108
+ )
109
+
110
+ # Armenian
111
+ armenian_apostrophe = r"\u055A"
112
+ emphasis_mark = r"\u055B"
113
+ exclamation_mark = r"\u055C"
114
+ armenian_comma = r"\u055D"
115
+ armenian_question_mark = r"\u055E"
116
+ abbreviation_mark = r"\u055F"
117
+ armenian_full_stop = r"\u0589"
118
+ armenian_punc = (
119
+ armenian_apostrophe
120
+ + emphasis_mark
121
+ + exclamation_mark
122
+ + armenian_comma
123
+ + armenian_question_mark
124
+ + abbreviation_mark
125
+ + armenian_full_stop
126
+ )
127
+
128
+ lesser_than_symbol = r"&lt;"
129
+ greater_than_symbol = r"&gt;"
130
+
131
+ lesser_than_sign = r"\u003c"
132
+ greater_than_sign = r"\u003e"
133
+
134
+ nbsp_written_form = r"&nbsp"
135
+
136
+ # Quotation marks
137
+ left_double_quotes = r"\u201c"
138
+ right_double_quotes = r"\u201d"
139
+ left_double_angle = r"\u00ab"
140
+ right_double_angle = r"\u00bb"
141
+ left_single_angle = r"\u2039"
142
+ right_single_angle = r"\u203a"
143
+ low_double_quotes = r"\u201e"
144
+ low_single_quotes = r"\u201a"
145
+ high_double_quotes = r"\u201f"
146
+ high_single_quotes = r"\u201b"
147
+
148
+ all_punct_quotes = (
149
+ left_double_quotes
150
+ + right_double_quotes
151
+ + left_double_angle
152
+ + right_double_angle
153
+ + left_single_angle
154
+ + right_single_angle
155
+ + low_double_quotes
156
+ + low_single_quotes
157
+ + high_double_quotes
158
+ + high_single_quotes
159
+ + right_single_quotation_mark
160
+ + left_single_quotation_mark
161
+ )
162
+ mapping_quotes = (
163
+ "["
164
+ + high_single_quotes
165
+ + right_single_quotation_mark
166
+ + left_single_quotation_mark
167
+ + "]"
168
+ )
169
+
170
+
171
+ # Digits
172
+
173
+ english_digits = r"\u0030-\u0039"
174
+ bengali_digits = r"\u09e6-\u09ef"
175
+ khmer_digits = r"\u17e0-\u17e9"
176
+ devanagari_digits = r"\u0966-\u096f"
177
+ oriya_digits = r"\u0b66-\u0b6f"
178
+ extended_arabic_indic_digits = r"\u06f0-\u06f9"
179
+ kayah_li_digits = r"\ua900-\ua909"
180
+ fullwidth_digits = r"\uff10-\uff19"
181
+ malayam_digits = r"\u0d66-\u0d6f"
182
+ myanmar_digits = r"\u1040-\u1049"
183
+ roman_numeral = r"\u2170-\u2179"
184
+ nominal_digit_shapes = r"\u206f"
185
+
186
# Load punctuations
# Each non-comment line of punctuations.lst is "<char>\t<count>\t<name>...";
# only the first tab field (the character to strip) is used. The file
# contains non-ASCII punctuation, so the encoding must be pinned to UTF-8
# rather than depending on the locale default.
with open(
    f"{os.path.dirname(__file__)}/punctuations.lst", "r", encoding="utf-8"
) as punc_f:
    punc_list = [
        line
        for line in punc_f  # iterate lazily; readlines() is unnecessary
        if line.strip() and not line.strip().startswith("#")
    ]

# the first character in the tab separated line is the punc to be removed;
# concatenate the escaped characters into a character-class fragment.
punct_pattern = "".join(re.escape(punc.split("\t")[0]) for punc in punc_list)
198
+
199
+ shared_digits = (
200
+ english_digits
201
+ + bengali_digits
202
+ + khmer_digits
203
+ + devanagari_digits
204
+ + oriya_digits
205
+ + extended_arabic_indic_digits
206
+ + kayah_li_digits
207
+ + fullwidth_digits
208
+ + malayam_digits
209
+ + myanmar_digits
210
+ + roman_numeral
211
+ + nominal_digit_shapes
212
+ )
213
+
214
+ shared_punc_list = (
215
+ basic_punc
216
+ + all_punct_quotes
217
+ + greater_than_sign
218
+ + lesser_than_sign
219
+ + inverted_question_mark
220
+ + full_stop
221
+ + semicolon
222
+ + armenian_punc
223
+ + inverted_exclamation_mark
224
+ + arabic_comma
225
+ + enumeration_comma
226
+ + hindi_danda
227
+ + quotation_mark
228
+ + arabic_semicolon
229
+ + arabic_question_mark
230
+ + chinese_punc
231
+ + punct_pattern
232
+ )
233
+
234
+ shared_mappping = {
235
+ lesser_than_symbol: "",
236
+ greater_than_symbol: "",
237
+ nbsp_written_form: "",
238
+ r"(\S+)" + mapping_quotes + r"(\S+)": r"\1'\2",
239
+ }
240
+
241
+ shared_deletion_list = (
242
+ left_to_right_mark
243
+ + zero_width_nonjoiner
244
+ + arabic_subscript_alef_and_inverted_damma
245
+ + zero_width_space
246
+ + arabic_diacritics
247
+ + pop_directional_formatting
248
+ + right_to_left_mark
249
+ + left_to_right_embedding
250
+ )
251
+
252
# Per-language normalization configs. "*" is the default; language entries
# are derived from it and then customized.
norm_config = {
    "*": {
        "lower_case": True,
        "punc_set": shared_punc_list,
        "del_set": shared_deletion_list,
        "mapping": shared_mappping,
        "digit_set": shared_digits,
        "unicode_norm": "NFKC",
        "rm_diacritics": False,
    }
}


def _derive_config(base):
    """Return a copy of *base* that is safe for per-language customization.

    ``dict.copy`` is shallow, so without also cloning the nested
    ``"mapping"`` dict, a per-language tweak such as the Arabic
    alef-wasla mapping below would silently leak into every other
    config (including the ``"*"`` default).
    """
    derived = base.copy()
    derived["mapping"] = dict(base["mapping"])
    return derived


# =============== Mongolian ===============#

norm_config["mon"] = _derive_config(norm_config["*"])
# add soft hyphen to punc list to match with fleurs
norm_config["mon"]["del_set"] += r"\u00AD"

norm_config["khk"] = _derive_config(norm_config["mon"])

# =============== Hebrew ===============#

norm_config["heb"] = _derive_config(norm_config["*"])
# add "HEBREW POINT" symbols to match with fleurs
norm_config["heb"]["del_set"] += r"\u05B0-\u05BF\u05C0-\u05CF"

# =============== Thai ===============#

norm_config["tha"] = _derive_config(norm_config["*"])
# add "Zero width joiner" symbols to match with fleurs
norm_config["tha"]["punc_set"] += r"\u200D"

# =============== Arabic ===============#
norm_config["ara"] = _derive_config(norm_config["*"])
# Map alef-wasla to plain alef for Arabic only; with the cloned mapping
# dict this no longer mutates the shared "*" mapping.
norm_config["ara"]["mapping"]["ٱ"] = "ا"
norm_config["arb"] = _derive_config(norm_config["ara"])

# =============== Javanese ===============#
norm_config["jav"] = _derive_config(norm_config["*"])
norm_config["jav"]["rm_diacritics"] = True
omnivoice/eval/wer/punctuations.lst ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+  7355 INVALID UNICODE 0x81
2
+  5265 INVALID UNICODE 0x90
3
+  75 INVALID UNICODE 0x8
4
+  31 INVALID UNICODE 0x8d
5
+ ” 3 INVALID UNICODE 0x94
6
+  2 INVALID UNICODE 0x8f
7
+  2 INVALID UNICODE 0x1a
8
+  1 INVALID UNICODE 0x9d
9
+ “ 1 INVALID UNICODE 0x93
10
+ ’ 1 INVALID UNICODE 0x92
11
+  8647 INVALID UNICODE 0xe295
12
+  6650 INVALID UNICODE 0xf21d
13
+  6234 INVALID UNICODE 0xf62d
14
+  4815 INVALID UNICODE 0xf173
15
+  4789 INVALID UNICODE 0xe514
16
+  4409 INVALID UNICODE 0xe293
17
+  3881 INVALID UNICODE 0xf523
18
+  3788 INVALID UNICODE 0xe233
19
+  2448 INVALID UNICODE 0xf50f
20
+  2177 INVALID UNICODE 0xe232
21
+  1955 INVALID UNICODE 0xea7b
22
+  1926 INVALID UNICODE 0xf172
23
+  973 INVALID UNICODE 0xe290
24
+  972 INVALID UNICODE 0xf519
25
+  661 INVALID UNICODE 0xe292
26
+  591 INVALID UNICODE 0xe328
27
+  509 INVALID UNICODE 0xe2fa
28
+  458 INVALID UNICODE 0xe234
29
+  446 INVALID UNICODE 0xe043
30
+  419 INVALID UNICODE 0xe040
31
+  399 INVALID UNICODE 0xe2fb
32
+  387 INVALID UNICODE 0xe32b
33
+  381 INVALID UNICODE 0xe236
34
+  374 INVALID UNICODE 0xf511
35
+  314 INVALID UNICODE 0xe517
36
+  296 INVALID UNICODE 0xe2fe
37
+  293 INVALID UNICODE 0xe492
38
+  291 INVALID UNICODE 0xf52d
39
+  289 INVALID UNICODE 0xe2fc
40
+  195 INVALID UNICODE 0xf521
41
+  190 INVALID UNICODE 0xe516
42
+  182 INVALID UNICODE 0xe041
43
+  178 INVALID UNICODE 0xf529
44
+  113 INVALID UNICODE 0xe2f9
45
+  87 INVALID UNICODE 0xe2d9
46
+  78 INVALID UNICODE 0xe32a
47
+  76 INVALID UNICODE 0xe291
48
+  74 INVALID UNICODE 0xe296
49
+  66 INVALID UNICODE 0xe518
50
+  52 INVALID UNICODE 0xe32c
51
+  46 INVALID UNICODE 0xe2db
52
+  41 INVALID UNICODE 0xe231
53
+  34 INVALID UNICODE 0xf522
54
+  33 INVALID UNICODE 0xf518
55
+  32 INVALID UNICODE 0xf513
56
+  27 INVALID UNICODE 0xe32d
57
+  25 INVALID UNICODE 0xe32e
58
+  23 INVALID UNICODE 0xe06b
59
+  15 INVALID UNICODE 0xea01
60
+  12 INVALID UNICODE 0xe294
61
+  11 INVALID UNICODE 0xe203
62
+  8 INVALID UNICODE 0xf218
63
+  7 INVALID UNICODE 0xe070
64
+  7 INVALID UNICODE 0xe013
65
+  5 INVALID UNICODE 0xe2de
66
+  4 INVALID UNICODE 0xe493
67
+  3 INVALID UNICODE 0xf7e8
68
+  3 INVALID UNICODE 0xf7d0
69
+  3 INVALID UNICODE 0xe313
70
+  2 INVALID UNICODE 0xe329
71
+  2 INVALID UNICODE 0xe06d
72
+  2 INVALID UNICODE 0xe003
73
+  1 INVALID UNICODE 0xf50e
74
+  1 INVALID UNICODE 0xf171
75
+  1 INVALID UNICODE 0xe01d
76
+  71 NOMINAL DIGIT SHAPES 0x206f
77
+ ⁠ 3 WORD JOINER 0x2060
78
+ ― 126545 HORIZONTAL BAR 0x2015
79
+ ־ 1028 HEBREW PUNCTUATION MAQAF 0x5be
80
+ ) 98429 RIGHT PARENTHESIS 0x29
81
+ ] 27108 RIGHT SQUARE BRACKET 0x5d
82
+ ⌋ 1567 RIGHT FLOOR 0x230b
83
+ 〕 97 RIGHT TORTOISE SHELL BRACKET 0x3015
84
+ 】 36 RIGHT BLACK LENTICULAR BRACKET 0x3011
85
+ ﴾ 14 ORNATE LEFT PARENTHESIS 0xfd3e
86
+ & 170517 AMPERSAND 0x26
87
+ ། 106330 TIBETAN MARK SHAD 0xf0d
88
+ ። 90203 ETHIOPIC FULL STOP 0x1362
89
+ ፥ 60484 ETHIOPIC COLON 0x1365
90
+ ༌ 60464 TIBETAN MARK DELIMITER TSHEG BSTAR 0xf0c
91
+ ။ 51567 MYANMAR SIGN SECTION 0x104b
92
+ / 46929 SOLIDUS 0x2f
93
+ ၊ 38042 MYANMAR SIGN LITTLE SECTION 0x104a
94
+ · 37985 MIDDLE DOT 0xb7
95
+ ‸ 36310 CARET 0x2038
96
+ * 34793 ASTERISK 0x2a
97
+ ۔ 32432 ARABIC FULL STOP 0x6d4
98
+ ፤ 31906 ETHIOPIC SEMICOLON 0x1364
99
+ ၏ 21519 MYANMAR SYMBOL GENITIVE 0x104f
100
+ ។ 20834 KHMER SIGN KHAN 0x17d4
101
+ ꓾ 15773 LISU PUNCTUATION COMMA 0xa4fe
102
+ ᙮ 13473 CANADIAN SYLLABICS FULL STOP 0x166e
103
+ ꤯ 12892 KAYAH LI SIGN SHYA 0xa92f
104
+ ⵰ 11478 TIFINAGH SEPARATOR MARK 0x2d70
105
+ ꓿ 11118 LISU PUNCTUATION FULL STOP 0xa4ff
106
+ ॥ 10763 DEVANAGARI DOUBLE DANDA 0x965
107
+ ؞ 10403 ARABIC TRIPLE DOT PUNCTUATION MARK 0x61e
108
+ ၍ 8936 MYANMAR SYMBOL COMPLETED 0x104d
109
+ · 8431 GREEK ANO TELEIA 0x387
110
+ † 7477 DAGGER 0x2020
111
+ ၌ 6632 MYANMAR SYMBOL LOCATIVE 0x104c
112
+ ፣ 5719 ETHIOPIC COMMA 0x1363
113
+ ៖ 5528 KHMER SIGN CAMNUC PII KUUH 0x17d6
114
+ ꤮ 4791 KAYAH LI SIGN CWI 0xa92e
115
+ ※ 3439 REFERENCE MARK 0x203b
116
+ ፦ 2727 ETHIOPIC PREFACE COLON 0x1366
117
+ • 1749 BULLET 0x2022
118
+ ¶ 1507 PILCROW SIGN 0xb6
119
+ ၎ 1386 MYANMAR SYMBOL AFOREMENTIONED 0x104e
120
+ ﹖ 1224 SMALL QUESTION MARK 0xfe56
121
+ ; 975 GREEK QUESTION MARK 0x37e
122
+ … 827 HORIZONTAL ELLIPSIS 0x2026
123
+ % 617 PERCENT SIGN 0x25
124
+ ・ 468 KATAKANA MIDDLE DOT 0x30fb
125
+ ༎ 306 TIBETAN MARK NYIS SHAD 0xf0e
126
+ ‡ 140 DOUBLE DAGGER 0x2021
127
+ # 137 NUMBER SIGN 0x23
128
+ @ 125 COMMERCIAL AT 0x40
129
+ ፡ 121 ETHIOPIC WORDSPACE 0x1361
130
+ ៚ 55 KHMER SIGN KOOMUUT 0x17da
131
+ ៕ 49 KHMER SIGN BARIYOOSAN 0x17d5
132
+ ﹐ 10 SMALL COMMA 0xfe50
133
+ ༅ 6 TIBETAN MARK CLOSING YIG MGO SGAB MA 0xf05
134
+ ༄ 6 TIBETAN MARK INITIAL YIG MGO MDUN MA 0xf04
135
+ . 2 FULLWIDTH FULL STOP 0xff0e
136
+ ﹗ 2 SMALL EXCLAMATION MARK 0xfe57
137
+ ﹕ 2 SMALL COLON 0xfe55
138
+ ‰ 2 PER MILLE SIGN 0x2030
139
+ ・ 1 HALFWIDTH KATAKANA MIDDLE DOT 0xff65
140
+ ( 98504 LEFT PARENTHESIS 0x28
141
+ [ 27245 LEFT SQUARE BRACKET 0x5b
142
+ ⌊ 1567 LEFT FLOOR 0x230a
143
+ 〔 95 LEFT TORTOISE SHELL BRACKET 0x3014
144
+ 【 36 LEFT BLACK LENTICULAR BRACKET 0x3010
145
+ ﴿ 14 ORNATE RIGHT PARENTHESIS 0xfd3f
146
+ _ 4851 LOW LINE 0x5f
147
+ $ 72 DOLLAR SIGN 0x24
148
+ € 14 EURO SIGN 0x20ac
149
+ £ 2 POUND SIGN 0xa3
150
+ ~ 27462 TILDE 0x7e
151
+ = 11450 EQUALS SIGN 0x3d
152
+ | 8430 VERTICAL LINE 0x7c
153
+ − 3971 MINUS SIGN 0x2212
154
+ ≫ 1904 MUCH GREATER-THAN 0x226b
155
+ ≪ 1903 MUCH LESS-THAN 0x226a
156
+ + 1450 PLUS SIGN 0x2b
157
+ < 345 FULLWIDTH LESS-THAN SIGN 0xff1c
158
+ > 344 FULLWIDTH GREATER-THAN SIGN 0xff1e
159
+ ¬ 5 NOT SIGN 0xac
160
+ × 4 MULTIPLICATION SIGN 0xd7
161
+ → 2 RIGHTWARDS ARROW 0x2192
162
+ ᙭ 537 CANADIAN SYLLABICS CHI SIGN 0x166d
163
+ ° 499 DEGREE SIGN 0xb0
164
+ ႟ 421 MYANMAR SYMBOL SHAN EXCLAMATION 0x109f
165
+ � 192 REPLACEMENT CHARACTER 0xfffd
166
+ ⌟ 54 BOTTOM RIGHT CORNER 0x231f
167
+ ⌞ 54 BOTTOM LEFT CORNER 0x231e
168
+ © 2 COPYRIGHT SIGN 0xa9
169
+   40 NARROW NO-BREAK SPACE 0x202f
170
+   1 SIX-PER-EM SPACE 0x2006
171
+ ˜ 40261 SMALL TILDE 0x2dc
172
+ ^ 6469 CIRCUMFLEX ACCENT 0x5e
173
+ ¯ 20 MACRON 0xaf
174
+ ˇ 191442 CARON 0x2c7
175
+ ⁿ 38144 SUPERSCRIPT LATIN SMALL LETTER N 0x207f
176
+ ـ 9440 ARABIC TATWEEL 0x640
177
+ ๆ 6766 THAI CHARACTER MAIYAMOK 0xe46
178
+ ៗ 3310 KHMER SIGN LEK TOO 0x17d7
179
+ 々 678 IDEOGRAPHIC ITERATION MARK 0x3005
180
+ ໆ 430 LAO KO LA 0xec6
181
+ ー 319 KATAKANA-HIRAGANA PROLONGED SOUND MARK 0x30fc
182
+ ⁱ 137 SUPERSCRIPT LATIN SMALL LETTER I 0x2071
183
+ ৷ 11056 BENGALI CURRENCY NUMERATOR FOUR 0x9f7
184
+ ⅓ 26 VULGAR FRACTION ONE THIRD 0x2153
185
+ ½ 26 VULGAR FRACTION ONE HALF 0xbd
186
+ ¼ 4 VULGAR FRACTION ONE QUARTER 0xbc
187
+ ⅟ 1 FRACTION NUMERATOR ONE 0x215f
188
+ ⁄ 57 FRACTION SLASH 0x2044
omnivoice/eval/wer/seedtts.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ Computes word error rate (WER) with Whisper-large-v3 for English and
20
+ Paraformer for Chinese. Intended to evaluate WERs on Seed-TTS test sets.
21
+ """
22
+ import argparse
23
+ import logging
24
+ import multiprocessing as mp
25
+ import os
26
+ import string
27
+ import traceback
28
+ from concurrent.futures import ProcessPoolExecutor, as_completed
29
+ from pathlib import Path
30
+
31
+ import numpy as np
32
+ import torch
33
+ import zhconv
34
+ from tqdm import tqdm
35
+ from zhon.hanzi import punctuation
36
+
37
+ from omnivoice.eval.utils import load_waveform
38
+ from omnivoice.eval.wer.common import process_one
39
+ from omnivoice.utils.data_utils import read_test_list
40
+
41
+ # --- Global variables for worker processes ---
42
+ worker_pipe = None
43
+ worker_device = None
44
+
45
+
46
def get_parser():
    """Build the command-line interface for the Seed-TTS WER evaluation.

    Returns:
        argparse.ArgumentParser: parser exposing wav/model paths, language,
        batching, and worker-count options.
    """
    parser = argparse.ArgumentParser(
        description="Computes WER with Whisper/Paraformer.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Bind once; every option below goes through the same entry point.
    add = parser.add_argument
    add(
        "--wav-path",
        type=str,
        required=True,
        help="Path to the directory containing speech files.",
    )
    add(
        "--extension",
        type=str,
        default="wav",
        help="Extension of the speech files. Default: wav",
    )
    add(
        "--decode-path",
        type=str,
        default=None,
        help="Path to the output file where WER information will be saved. "
        "If not provided, results are only printed to console.",
    )
    add(
        "--model-dir",
        type=str,
        required=True,
        help="Local path of evaluation models repository. "
        "Download from https://huggingface.co/k2-fsa/TTS_eval_models. "
        "This script expects 'tts_eval_models/wer/whisper-large-v3/' for English "
        "and 'tts_eval_models/wer/paraformer-zh/' for Chinese within this directory.",
    )
    add(
        "--test-list",
        type=str,
        default="test.jsonl",
        help="path of the JSONL test list. Each line is a JSON object "
        "with fields: id, text, ref_audio, ref_text, language_id, language_name.",
    )
    add(
        "--lang",
        type=str,
        choices=["zh", "en"],
        required=True,
        help="Language of the audio and transcripts for "
        "decoding ('zh' for Chinese or 'en' for English).",
    )
    add(
        "--batch-size",
        type=int,
        default=16,
        help="Batch size for decoding with the Hugging Face pipeline.",
    )
    add("--nj-per-gpu", type=int, default=1, help="Number of workers per GPU.")
    return parser
104
+
105
+
106
def load_whisper_model(model_dir, device):
    """Load a local Whisper-large-v3 ASR pipeline for English decoding.

    Args:
        model_dir: Root directory of the evaluation-models repository;
            the model is expected under ``wer/whisper-large-v3/``.
        device: Target device (torch.device or string such as "cuda:0").

    Returns:
        A transformers automatic-speech-recognition pipeline, or ``None``
        if the model directory does not exist.
    """
    model_path = os.path.join(model_dir, "wer/whisper-large-v3/")
    if not os.path.exists(model_path):
        logging.error(f"Whisper model not found at {model_path}.")
        return None

    logging.debug(f"Loading Whisper model on {device}...")

    # Imported lazily so CPU-only tooling does not pay the import cost.
    import transformers

    # Suppress transformers logging
    transformers.logging.set_verbosity_error()

    pipe = transformers.pipeline(
        "automatic-speech-recognition",
        model=model_path,
        # fp16 on GPU for speed, fp32 on CPU.
        # NOTE(review): the `dtype` kwarg assumes a recent transformers
        # release — older versions call this `torch_dtype`; confirm against
        # the pinned transformers version.
        dtype=torch.float16 if "cuda" in str(device) else torch.float32,
        device=device,
    )
    return pipe
126
+
127
+
128
def load_paraformer_model(model_dir, device):
    """Load a local Paraformer ASR model (FunASR) for Chinese decoding.

    Args:
        model_dir: Root directory of the evaluation-models repository;
            the model is expected under ``wer/paraformer-zh/``.
        device: Target device (torch.device or string).

    Returns:
        A FunASR ``AutoModel`` instance, or ``None`` when the model
        directory does not exist.
    """
    model_path = os.path.join(model_dir, "wer/paraformer-zh/")
    if not os.path.exists(model_path):
        logging.error(f"Paraformer model not found at {model_path}.")
        return None

    logging.debug(f"Loading Paraformer model on {device}...")

    # FunASR logs noisily at load time; silence everything while it
    # initializes and restore the previous disable level afterwards.
    previous_level = logging.root.manager.disable
    logging.disable(logging.CRITICAL)

    try:
        from funasr import AutoModel

        # FunASR AutoModel accepts "cuda:0" string or torch.device
        asr_model = AutoModel(
            model=model_path,
            device=str(device),
            disable_update=True,
            disable_pbar=True,
            verbose=False,
        )
    finally:
        logging.disable(previous_level)

    return asr_model
154
+
155
+
156
def post_process(text: str, lang: str) -> str:
    """
    Cleans and normalizes text for WER calculation.

    Strips Chinese and ASCII punctuation (apostrophes are kept, since they
    carry meaning in English contractions), normalizes whitespace, then
    applies per-language casing/segmentation.

    Args:
        text (str): The input text to be processed.
        lang (str): The language of the input text ("zh" or "en").

    Returns:
        str: The cleaned and normalized text.

    Raises:
        NotImplementedError: For any language other than "zh" or "en".
    """
    punctuation_all = punctuation + string.punctuation
    for x in punctuation_all:
        if x == "'":
            continue
        text = text.replace(x, "")

    # Normalize ideographic (full-width) spaces, then collapse the runs of
    # consecutive spaces left behind by punctuation removal. The previous
    # single replace() could leave double spaces, which produce empty
    # tokens when the scorer splits on whitespace.
    text = text.replace("\u3000", " ")
    while "  " in text:
        text = text.replace("  ", " ")

    if lang == "zh":
        # Chinese WER is computed per character: space-separate every char.
        text = " ".join([x for x in text])
    elif lang == "en":
        text = text.lower()
    else:
        raise NotImplementedError
    return text
181
+
182
+
183
def process_init(rank_queue, model_dir, lang):
    """
    Initializer for each worker process.
    Loads model onto a specific GPU, once per process.

    Args:
        rank_queue: Multiprocessing queue of GPU ranks; each worker pops one
            so workers spread evenly across the available GPUs.
        model_dir: Root directory of the evaluation-models repository.
        lang: "en" loads Whisper, "zh" loads Paraformer.
    """
    global worker_pipe, worker_device

    # Parallelism comes from processes; keep per-process threading small.
    torch.set_num_threads(2)

    try:
        rank = rank_queue.get(timeout=10)
    except Exception:
        raise RuntimeError("Failed to get GPU rank from queue.")

    assert torch.cuda.is_available(), "CUDA is required but not available."
    worker_device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(rank)

    logging.info(f"Initializing worker on device: {worker_device}")

    try:
        if lang == "en":
            worker_pipe = load_whisper_model(model_dir, worker_device)
        elif lang == "zh":
            worker_pipe = load_paraformer_model(model_dir, worker_device)
        # Either loader returns None when model files are missing; fail the
        # worker loudly instead of decoding with a null pipeline.
        if worker_pipe is None:
            raise RuntimeError("Model loading failed.")
    except Exception as e:
        logging.critical(f"Failed to load model on {worker_device}: {e}")
        raise e
213
+
214
+
215
def run_eval_worker(data_chunk, lang, batch_size):
    """
    Worker function to process a chunk of data.
    Uses the global worker_pipe initialized by process_init.

    Args:
        data_chunk: List of dicts with "wav_path" and "truth_text" keys.
        lang: "en" (Whisper pipeline path) or "zh" (Paraformer path).
        batch_size: Decoding batch size.

    Returns:
        List of per-utterance metric dicts from ``process_one`` (each
        augmented with "wav_path"); an empty list on any failure so the
        pool can keep processing other chunks.
    """
    global worker_pipe
    if worker_pipe is None:
        logging.error("Worker pipeline is not initialized!")
        return []

    metrics_buffer = []
    try:
        if lang == "en":
            # Load waveforms as arrays, truncating to 30s
            # (Whisper's maximum input window).
            dataset = [
                {
                    "array": load_waveform(
                        item["wav_path"], sample_rate=16000, return_numpy=True
                    )[: 16000 * 30],
                    "sampling_rate": 16000,
                }
                for item in data_chunk
            ]
            generate_kwargs = {"language": "english", "task": "transcribe"}

            iterator = worker_pipe(
                dataset, generate_kwargs=generate_kwargs, batch_size=batch_size
            )

            # The pipeline yields outputs in input order, so index i maps
            # each hypothesis back to its reference item.
            for i, out in enumerate(iterator):
                hypothesis = out["text"].strip()
                ref_item = data_chunk[i]
                truth = ref_item["truth_text"]
                wav_path = ref_item["wav_path"]

                m = process_one(hypothesis, truth, post_process, lang)
                m["wav_path"] = wav_path
                metrics_buffer.append(m)

        elif lang == "zh":
            wav_paths = [item["wav_path"] for item in data_chunk]

            # FunASR takes file paths directly; batch manually.
            for i in range(0, len(wav_paths), batch_size):
                batch_paths = wav_paths[i : i + batch_size]
                res_batch = worker_pipe.generate(
                    input=batch_paths, batch_size=batch_size, disable_pbar=True
                )

                for j, res in enumerate(res_batch):
                    # Normalize traditional characters to simplified so the
                    # hypothesis matches the reference convention.
                    hypothesis = zhconv.convert(res["text"], "zh-cn")
                    ref_item = data_chunk[i + j]
                    truth = ref_item["truth_text"]
                    wav_path = ref_item["wav_path"]

                    m = process_one(hypothesis, truth, post_process, lang)
                    m["wav_path"] = wav_path
                    metrics_buffer.append(m)

    except Exception:
        logging.error(
            f"Worker failed on chunk (Lang: {lang}):\n{traceback.format_exc()}"
        )
        return []

    return metrics_buffer
280
+
281
+
282
def main():
    """Entry point: fan WAV files out to GPU workers and aggregate WER.

    Pipeline: read the JSONL test list, split existing files into chunks,
    decode them in a ProcessPoolExecutor (one model per worker, GPUs shared
    round-robin via a rank queue), then report both the Seed-TTS-style
    average-of-per-utterance-WERs and the corpus-weighted WER.
    """
    parser = get_parser()
    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
        force=True,
    )

    logging.info(f"Calculating WER for {args.wav_path}")

    # 1. Prepare Data: keep only samples whose audio actually exists.
    logging.info("Reading test list...")
    data_list = []
    samples = read_test_list(args.test_list)
    for s in samples:
        wav_path = str(Path(args.wav_path) / f"{s['id']}.{args.extension}")
        if not os.path.exists(wav_path):
            logging.warning(f"File missing: {wav_path}")
            continue
        data_list.append({"wav_path": wav_path, "truth_text": s["text"]})
    total_files = len(data_list)
    logging.info(f"Total files: {total_files}.")

    # 2. Worker config
    num_gpus = torch.cuda.device_count()
    assert num_gpus > 0, "No GPU found. GPU is required."
    total_workers = num_gpus * args.nj_per_gpu

    mp.set_start_method("spawn", force=True)
    manager = mp.Manager()
    rank_queue = manager.Queue()

    # Each worker pops one rank: nj_per_gpu workers land on every GPU.
    for _ in range(args.nj_per_gpu):
        for rank in range(num_gpus):
            rank_queue.put(rank)

    # 3. Scheduling: Split data into chunks for better load balancing
    chunk_size = max(1, args.batch_size)
    tasks = []
    for i in range(0, total_files, chunk_size):
        tasks.append(data_list[i : i + chunk_size])

    logging.info(
        f"Split data into {len(tasks)} chunks (size ~{chunk_size}). "
        f"Spawning {total_workers} workers."
    )

    # 4. Execution
    results = []

    with ProcessPoolExecutor(
        max_workers=total_workers,
        initializer=process_init,
        initargs=(rank_queue, args.model_dir, args.lang),
    ) as executor:

        futures = []
        for chunk in tasks:
            futures.append(
                executor.submit(run_eval_worker, chunk, args.lang, args.batch_size)
            )

        # Unified progress bar
        with tqdm(total=total_files, desc="Eval Progress", dynamic_ncols=True) as pbar:
            for future in as_completed(futures):
                try:
                    chunk_metrics = future.result()
                    results.extend(chunk_metrics)
                    pbar.update(len(chunk_metrics))
                except Exception as e:
                    logging.error(f"Task failed: {e}")

    # 5. Aggregation
    wers, inses, deles, subses = [], [], [], []
    word_nums = 0

    fout = None
    if args.decode_path:
        # FIX: os.makedirs("") raises FileNotFoundError when decode_path
        # is a bare filename with no directory component.
        out_dir = os.path.dirname(args.decode_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fout = open(args.decode_path, "w", encoding="utf8")
        logging.info(f"Saving detailed WER results to: {args.decode_path}")
        fout.write(
            "Name\tWER\tTruth\tHypothesis\tInsertions\tDeletions\tSubstitutions\n"
        )

    for res in results:
        wers.append(float(res["wer"]))
        inses.append(float(res["insertions"]))
        deles.append(float(res["deletions"]))
        subses.append(float(res["substitutions"]))
        word_nums += res["word_num"]

        if fout:
            fout.write(
                f"{res['wav_path']}\t{res['wer']}\t{res['truth']}\t"
                f"{res['hypo']}\t{res['insertions']}\t{res['deletions']}\t"
                f"{res['substitutions']}\n"
            )

    # Seed-TTS convention: mean of per-utterance WERs.
    wer_avg = round(np.mean(wers) * 100, 2) if wers else float("nan")
    # Conventional corpus-level WER: total errors over total words.
    wer_weighted = (
        round(
            (np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 2
        )
        if word_nums > 0
        else float("nan")
    )

    inse_sum = np.sum(inses)
    dele_sum = np.sum(deles)
    subs_sum = np.sum(subses)

    print("-" * 50)
    logging.info(f"Processed {len(results)}/{total_files} files.")
    seedtts_wer_info = f"Seed-TTS WER (Avg of WERs): {wer_avg}%"
    wer_info = f"WER (Weighted): {wer_weighted}%"
    detailed_info = (
        f"Errors: {inse_sum} ins, {dele_sum} del, {subs_sum} sub / {word_nums} words"
    )
    logging.info(seedtts_wer_info)
    logging.info(wer_info)
    logging.info(detailed_info)
    print("-" * 50)

    if fout:
        fout.write(seedtts_wer_info + "\n" + wer_info + "\n" + detailed_info + "\n")
        fout.close()
410
+
411
+
412
# Script entry point.
if __name__ == "__main__":
    main()
omnivoice/eval/wer/sensevoice.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ Computes Character Error Rate (CER) for Cantonese (yue) using SenseVoiceSmall.
20
+ """
21
+
22
+ import argparse
23
+ import logging
24
+ import multiprocessing as mp
25
+ import os
26
+ import re
27
+ import traceback
28
+ from concurrent.futures import ProcessPoolExecutor, as_completed
29
+ from pathlib import Path
30
+
31
+ import cn2an
32
+ import torch
33
+ import zhconv
34
+ from tqdm import tqdm
35
+
36
+ from omnivoice.eval.wer.common import log_metrics, process_one
37
+ from omnivoice.eval.wer.text_norm_omni import text_normalize
38
+ from omnivoice.utils.data_utils import read_test_list
39
+
40
+ # --- Global variables for worker processes ---
41
+ worker_sensevoice = None
42
+ worker_device = None
43
+
44
+
45
def get_parser():
    """Build the command-line interface for the Cantonese CER evaluation.

    Returns:
        argparse.ArgumentParser: parser exposing wav/model paths, batching,
        chunking, and worker-count options.
    """
    parser = argparse.ArgumentParser(
        description="Computes CER for Cantonese using SenseVoiceSmall.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Bind once; every option below goes through the same entry point.
    add = parser.add_argument
    add(
        "--wav-path",
        type=str,
        required=True,
        help="Path to the directory containing speech files.",
    )
    add(
        "--extension",
        type=str,
        default="wav",
        help="Extension of the speech files. Default: wav",
    )
    add(
        "--decode-path",
        type=str,
        default=None,
        help="Path to the output file where CER information will be saved. ",
    )
    add(
        "--model-dir",
        type=str,
        required=True,
        help="Local path of evaluation models repository. ",
    )
    add(
        "--test-list",
        type=str,
        default="test.jsonl",
        help="path of the JSONL test list.",
    )
    add(
        "--batch-size",
        type=int,
        default=16,
        help="Batch size for decoding.",
    )
    add("--nj-per-gpu", type=int, default=1, help="Number of workers per GPU.")
    add(
        "--chunk-size",
        type=int,
        default=10,
        help="Number of samples per task chunk sent to workers.",
    )
    return parser
99
+
100
+
101
def load_sensevoice_model(model_dir, device):
    """Load the SenseVoiceSmall ASR model, preferring the local copy.

    FIX: the previous implementation checked the local ``model_path`` and
    warned when missing, but then always loaded the hub identifier
    ``iic/SenseVoiceSmall`` — ignoring the directory that ``--model-dir``
    points at. Now the local model is used when present, and the hub id is
    only a fallback.

    Args:
        model_dir: Root directory of the evaluation-models repository;
            the model is expected under ``wer/SenseVoiceSmall``.
        device: Target device (torch.device or string).

    Returns:
        A FunASR ``AutoModel`` instance.
    """
    model_path = os.path.join(model_dir, "wer/SenseVoiceSmall")
    if not os.path.exists(model_path):
        logging.warning(
            f"SenseVoiceSmall not found at {model_path}. "
            f"Please ensure it is present in eval models."
        )
        # Fall back to the model-hub identifier so evaluation can still run.
        model_path = "iic/SenseVoiceSmall"

    logging.info(f"Loading SenseVoice model on {device}...")

    # FunASR logs noisily at load time; silence and restore afterwards.
    previous_level = logging.root.manager.disable
    logging.disable(logging.CRITICAL)

    try:
        from funasr import AutoModel

        model = AutoModel(
            model=model_path,
            device=str(device),
            disable_update=True,
            disable_pbar=True,
            verbose=False,
        )
    finally:
        logging.disable(previous_level)

    return model
129
+
130
+
131
def _worker_setup(rank_queue):
    """Pin the current worker process to one GPU drawn from the shared queue.

    Sets the module-global ``worker_device`` as a side effect.
    """
    global worker_device

    # Parallelism comes from processes; keep per-process threading small.
    torch.set_num_threads(2)

    try:
        gpu_index = rank_queue.get(timeout=10)
    except Exception:
        raise RuntimeError("Failed to get GPU rank from queue.")

    assert torch.cuda.is_available(), "CUDA is required but not available."
    worker_device = torch.device(f"cuda:{gpu_index}")
    torch.cuda.set_device(gpu_index)

    logging.info(f"Initializing worker on device: {worker_device}")
146
+
147
+
148
def process_init_sensevoice(rank_queue, model_dir):
    """ProcessPoolExecutor initializer: bind a GPU, then load SenseVoice.

    Stores the loaded model in the module-global ``worker_sensevoice`` so
    ``run_eval_worker_sensevoice`` can reuse it for every chunk handled by
    this process.
    """
    global worker_sensevoice

    _worker_setup(rank_queue)

    try:
        worker_sensevoice = load_sensevoice_model(model_dir, worker_device)
        # Guard against a loader that returned nothing usable.
        if worker_sensevoice is None:
            raise RuntimeError("SenseVoice model loading failed.")
    except Exception as e:
        logging.critical(f"Failed to load SenseVoice model on {worker_device}: {e}")
        raise e
160
+
161
+
162
def post_process(text: str, lang: str) -> str:
    """Normalize a transcript for Cantonese CER scoring.

    Applies the shared multilingual normalizer, converts traditional to
    simplified characters, spells arabic numerals as Chinese characters,
    then space-separates every character for per-character scoring.
    """
    assert lang == "yue", "this script is designed for Cantonese (yue) evaluation only."

    normalized = text_normalize(
        text,
        iso_code="yue",
        lower_case=True,
        remove_numbers=False,
        remove_brackets=False,
    )
    simplified = zhconv.convert(normalized, "zh-cn")
    spelled = cn2an.transform(simplified, "an2cn")

    # Drop existing spaces, then reinsert one between every character.
    characters = [ch for ch in spelled.replace(" ", "")]
    return " ".join(characters).lower().strip()
183
+
184
+
185
def run_eval_worker_sensevoice(data_chunk, batch_size):
    """Transcribe one chunk of wav files with SenseVoice and score CER.

    Runs inside a worker process and relies on the module-global model
    created by ``process_init_sensevoice``.

    Args:
        data_chunk: List of dicts with "wav_path", "truth_text", and
            optional "lang_name" keys.
        batch_size: Decoding batch size.

    Returns:
        List of per-utterance metric dicts (augmented with "wav_path" and
        "lang_name"); an empty list on any failure so the pool keeps going.
    """
    global worker_sensevoice
    if worker_sensevoice is None:
        logging.error("SenseVoice worker pipeline is not initialized!")
        return []

    metrics_buffer = []
    try:
        wav_paths = [item["wav_path"] for item in data_chunk]

        for i in range(0, len(wav_paths), batch_size):
            batch_paths = wav_paths[i : i + batch_size]

            # SenseVoice generate call, target lang mapped to yue
            res_batch = worker_sensevoice.generate(
                input=batch_paths,
                batch_size=batch_size,
                language="yue",
                use_itn=False,
                disable_pbar=True,
            )

            for j, res in enumerate(res_batch):
                hypothesis = res["text"]
                # SenseVoice may format output with language tags,
                # cleaning basic tags if any (e.g. "<|yue|>").
                hypothesis = re.sub(r"<\|[^|]*\|>", "", hypothesis).strip()

                # i + j maps each result back to its reference item.
                ref_item = data_chunk[i + j]
                truth = ref_item["truth_text"]
                wav_path = ref_item["wav_path"]
                lang_name = ref_item.get("lang_name")

                m = process_one(hypothesis, truth, post_process, "yue")
                m["wav_path"] = wav_path
                m["lang_name"] = lang_name
                metrics_buffer.append(m)

    except Exception:
        logging.error(f"SenseVoice worker failed on chunk:\n{traceback.format_exc()}")
        return []

    return metrics_buffer
228
+
229
+
230
def main():
    """Entry point: evaluate Cantonese CER with SenseVoice over GPU workers.

    Pipeline: read the JSONL test list, keep only ``language_id == "yue"``
    samples whose audio exists, split them into chunks, decode in a
    ProcessPoolExecutor (one SenseVoice model per worker, GPUs shared
    round-robin via a rank queue), then aggregate and report CER.
    """
    parser = get_parser()
    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
        force=True,
    )

    logging.info("Reading test list and filtering for Cantonese (yue)...")
    yue_items = []
    wav_root = Path(args.wav_path)

    samples = read_test_list(args.test_list)
    for s in samples:
        lang_id = s.get("language_id", "")
        if lang_id != "yue":
            continue

        wav_path = str(wav_root / f"{s['id']}.{args.extension}")
        if not os.path.exists(wav_path):
            logging.warning(f"File missing: {wav_path}")
            continue

        yue_items.append(
            {
                "wav_path": wav_path,
                "truth_text": s["text"],
                "lang_id": "yue",
                "lang_name": s.get("language_name", "Cantonese"),
            }
        )

    logging.info(f"Total Cantonese files found: {len(yue_items)}.")
    if len(yue_items) == 0:
        logging.warning("No files to evaluate. Exiting.")
        return

    num_gpus = torch.cuda.device_count()
    assert num_gpus > 0, "No GPU found. GPU is required."
    total_workers = num_gpus * args.nj_per_gpu

    mp.set_start_method("spawn", force=True)
    manager = mp.Manager()

    # Split into fixed-size chunks for load balancing across workers.
    chunk_size = args.chunk_size
    tasks = []
    for i in range(0, len(yue_items), chunk_size):
        tasks.append(yue_items[i : i + chunk_size])

    results = []
    rank_queue = manager.Queue()
    # Each worker pops one rank: nj_per_gpu workers land on every GPU.
    for _ in range(args.nj_per_gpu):
        for rank in range(num_gpus):
            rank_queue.put(rank)

    with ProcessPoolExecutor(
        max_workers=total_workers,
        initializer=process_init_sensevoice,
        initargs=(rank_queue, args.model_dir),
    ) as executor:

        futures = []
        for chunk in tasks:
            futures.append(
                executor.submit(run_eval_worker_sensevoice, chunk, args.batch_size)
            )

        with tqdm(
            total=len(yue_items),
            desc="SenseVoice Eval (Cantonese)",
            dynamic_ncols=True,
        ) as pbar:
            for future in as_completed(futures):
                try:
                    chunk_metrics = future.result()
                    results.extend(chunk_metrics)
                    pbar.update(len(chunk_metrics))
                except Exception as e:
                    logging.error(f"Task failed: {e}")

    # Metrics Aggregation
    inses, deles, subses = [], [], []
    word_nums = 0

    fout = None
    if args.decode_path:
        # FIX: os.makedirs("") raises FileNotFoundError when decode_path
        # is a bare filename with no directory component.
        out_dir = os.path.dirname(args.decode_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        logging.info(f"Saving detailed CER results to: {args.decode_path}")
        fout = open(args.decode_path, "w", encoding="utf-8")

    for res in results:
        inses.append(float(res["insertions"]))
        deles.append(float(res["deletions"]))
        subses.append(float(res["substitutions"]))
        word_nums += res["word_num"]

        if fout:
            fout.write(
                f"{res['wav_path']}\t{res['wer']}\t{res['truth']}\t"
                f"{res['hypo']}\t{res['insertions']}\t{res['deletions']}\t"
                f"{res['substitutions']}\n"
            )

    print("-" * 50)
    if word_nums > 0:
        log_metrics(fout, "[yue] Cantonese", inses, deles, subses, word_nums)

    if fout:
        fout.close()
341
+
342
+
# Script entry point.
if __name__ == "__main__":
    main()
omnivoice/eval/wer/text_norm_omni.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the BSD-style license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ """
9
+ This module contains the text normalization function for WER evaluation.
10
+ Copied from https://github.com/facebookresearch/omnilingual-asr/blob/81f51e224ce9e74b02cc2a3eaf21b2d91d743455/workflows/dataprep/text_tools.py
11
+ """
12
+
13
+ import re
14
+ import unicodedata
15
+
16
+ from unidecode import unidecode
17
+
18
+ import omnivoice.eval.wer.norm_config_module as norm_config_module
19
+
20
+ norm_config = norm_config_module.norm_config # type: ignore
21
+
22
+
23
def text_normalize(
    text, iso_code, lower_case=True, remove_numbers=True, remove_brackets=False
):
    """Given a text, normalize it by changing to lower case, removing
    punctuations, removing words that only contain digits and removing
    extra spaces.

    Args:
        text: The string to be normalized.
        iso_code: Language ISO code used to select a normalization config;
            unknown codes fall back to the "*" default config.
        lower_case: Whether lower-casing is applied (only if the selected
            config also enables it).
        remove_numbers: Boolean flag to specify if words containing only
            digits should be removed.
        remove_brackets: Whether all parenthesized spans are removed (spans
            containing digits are always removed).

    Returns:
        normalized_text: the string after all normalization.
    """
    # Merge the language-specific entries over the "*" defaults into a
    # fresh dict. FIX: the previous field-by-field back-fill mutated the
    # shared module-level `norm_config` entry in place, and omitted
    # "rm_diacritics" from the fill list, so a per-language config lacking
    # that key raised KeyError below.
    config = dict(norm_config["*"])
    config.update(norm_config.get(iso_code, {}))

    text = unicodedata.normalize(config["unicode_norm"], text)

    # Convert to lower case (only when both caller and config allow it).
    if config["lower_case"] and lower_case:
        text = text.lower()

    # brackets

    # always remove text inside brackets with numbers in them.
    # Usually corresponds to "(Sam 23:17)"
    text = re.sub(r"\([^\)]*\d[^\)]*\)", " ", text)
    if remove_brackets:
        text = re.sub(r"\([^\)]*\)", " ", text)

    # Apply per-language regex mappings.
    for old, new in config["mapping"].items():
        text = re.sub(old, new, text)

    # Replace punctuations with space.
    punct_pattern = r"[" + config["punc_set"] + "]"
    normalized_text = re.sub(punct_pattern, " ", text)

    # Remove characters in delete list outright.
    delete_pattern = r"[" + config["del_set"] + "]"
    normalized_text = re.sub(delete_pattern, "", normalized_text)

    # Remove words containing only digits.
    # We check for 3 cases: a) text starts with a number b) a number is
    # present somewhere in the middle c) the text ends with a number.
    # For each case we use lookaround patterns to require surrounding
    # whitespace; lookarounds enable overlapping matches to be replaced.
    if remove_numbers:
        digits_pattern = "[" + config["digit_set"] + "]+"
        complete_digit_pattern = (
            r"^"
            + digits_pattern
            + r"(?=\s)|(?<=\s)"
            + digits_pattern
            + r"(?=\s)|(?<=\s)"
            + digits_pattern
            + "$"
        )
        normalized_text = re.sub(complete_digit_pattern, " ", normalized_text)

    # Optionally strip diacritics via ASCII transliteration.
    if config.get("rm_diacritics"):
        normalized_text = unidecode(normalized_text)

    # Remove extra spaces
    normalized_text = re.sub(r"\s+", " ", normalized_text).strip()

    return normalized_text
omnivoice/models/__init__.py ADDED
File without changes
omnivoice/models/omnivoice.py ADDED
@@ -0,0 +1,1502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Core OmniVoice model implementation.
19
+
20
+ Defines the ``OmniVoice`` model class, generation config, and inference pipeline.
21
+ This is the main entry point for both inference and training:
22
+
23
+ - **Inference**: ``OmniVoice.from_pretrained()`` loads the model, then
24
+ ``model.generate()`` supports voice cloning, voice design, and auto voice.
25
+ - **Training**: ``model.forward()`` computes the training loss; the model is
26
+ built and used by ``omnivoice.training.builder`` and ``omnivoice.training.trainer``.
27
+
28
+ """
29
+
30
+ import difflib
31
+ import logging
32
+ import math
33
+ import os
34
+ import re
35
+ from dataclasses import dataclass, fields
36
+ from functools import partial
37
+ from typing import Any, List, Optional, Union
38
+
39
+ import torch
40
+ import torch.nn as nn
41
+ import torch.nn.functional as F
42
+ import torchaudio
43
+ from torch.nn.attention.flex_attention import create_block_mask
44
+ from transformers import (
45
+ AutoFeatureExtractor,
46
+ AutoModel,
47
+ AutoTokenizer,
48
+ HiggsAudioV2TokenizerModel,
49
+ PretrainedConfig,
50
+ PreTrainedModel,
51
+ )
52
+ from transformers.modeling_outputs import ModelOutput
53
+ from transformers.models.auto import CONFIG_MAPPING, AutoConfig
54
+
55
+ from omnivoice.utils.audio import (
56
+ cross_fade_chunks,
57
+ fade_and_pad_audio,
58
+ load_audio,
59
+ remove_silence,
60
+ trim_long_audio,
61
+ )
62
+ from omnivoice.utils.duration import RuleDurationEstimator
63
+ from omnivoice.utils.lang_map import LANG_IDS, LANG_NAMES
64
+ from omnivoice.utils.text import add_punctuation, chunk_text_punctuation
65
+ from omnivoice.utils.voice_design import (
66
+ _INSTRUCT_ALL_VALID,
67
+ _INSTRUCT_EN_TO_ZH,
68
+ _INSTRUCT_MUTUALLY_EXCLUSIVE,
69
+ _INSTRUCT_VALID_EN,
70
+ _INSTRUCT_VALID_ZH,
71
+ _INSTRUCT_ZH_TO_EN,
72
+ _ZH_RE,
73
+ )
74
+
75
+ logger = logging.getLogger(__name__)
76
+
77
+
78
+ # ---------------------------------------------------------------------------
79
+ # Dataclasses
80
+ # ---------------------------------------------------------------------------
81
+
82
+
83
@dataclass
class VoiceClonePrompt:
    """Pre-encoded voice-cloning prompt, reusable across `generate` calls.

    Built by ``OmniVoice.create_voice_clone_prompt`` from a reference audio
    clip and its transcript.
    """

    # Discrete audio codes of the reference clip, shape (num_codebooks, num_frames).
    ref_audio_tokens: torch.Tensor  # (C, T)
    # Transcript of the reference audio (auto-transcribed if not user-provided).
    ref_text: str
    # RMS amplitude of the original reference waveform, measured before any
    # loudness boost; used later to scale the generated audio back down.
    ref_rms: float
88
+
89
+
90
@dataclass
class OmniVoiceGenerationConfig:
    """Hyper-parameters controlling OmniVoice inference.

    The fields mirror the keyword arguments accepted by
    ``OmniVoice.generate`` — see its docstring for per-field semantics.
    """

    num_step: int = 32  # iterative decoding steps
    guidance_scale: float = 2.0  # classifier-free guidance scale
    t_shift: float = 0.1  # time-step shift (smaller -> emphasise low-SNR steps)
    layer_penalty_factor: float = 5.0  # bias earlier codebook layers to unmask first
    position_temperature: float = 5.0  # temperature for position selection
    class_temperature: float = 0.0  # temperature for token sampling (0 = greedy)
    denoise: bool = True  # prepend the <|denoise|> token
    preprocess_prompt: bool = True  # silence-trim / punctuate the reference prompt
    postprocess_output: bool = True  # remove silence, fade and pad the output
    audio_chunk_duration: float = 15.0  # seconds of audio per text chunk
    audio_chunk_threshold: float = 30.0  # chunk only when estimate exceeds this

    @classmethod
    def from_dict(cls, kwargs_dict):
        """Build a config from a dict, keeping only known field names.

        Unknown keys are ignored so that callers may forward arbitrary
        ``**kwargs``, but they are logged — otherwise a typo such as
        ``num_steps`` (instead of ``num_step``) would be silently dropped.
        """
        valid_keys = {f.name for f in fields(cls)}
        unknown = set(kwargs_dict) - valid_keys
        if unknown:
            # Surface likely typos instead of swallowing them silently.
            logging.getLogger(__name__).warning(
                "Ignoring unknown generation options: %s", sorted(unknown)
            )
        filtered = {k: v for k, v in kwargs_dict.items() if k in valid_keys}
        return cls(**filtered)
109
+
110
+
111
@dataclass
class GenerationTask:
    """A (possibly batched) generation request: one entry per output item.

    All list fields are parallel, indexed by item; ``speed`` is optional and,
    when present, also parallel.
    """

    batch_size: int
    texts: List[str]
    target_lens: List[int]
    langs: List[Optional[str]]
    instructs: List[Optional[str]]
    ref_texts: List[Optional[str]]
    ref_audio_tokens: List[Optional[torch.Tensor]]
    ref_rms: List[Optional[float]]
    speed: Optional[List[float]] = None

    def get_indices(self, config: OmniVoiceGenerationConfig, frame_rate: int):
        """Partition item indices into (short, long) by estimated token count.

        Items at or below ``audio_chunk_threshold`` seconds worth of frames
        are "short"; the rest are "long" and will be chunked.
        """
        limit = int(config.audio_chunk_threshold * frame_rate)
        short_idx: List[int] = []
        long_idx: List[int] = []
        for idx, length in enumerate(self.target_lens):
            (short_idx if length <= limit else long_idx).append(idx)
        return short_idx, long_idx

    def slice_task(self, indices: List[int]):
        """Return a sub-task with only the given items, or None if empty."""
        if not indices:
            return None

        def pick(seq):
            # Gather the selected items, preserving the order of `indices`.
            return [seq[i] for i in indices]

        return GenerationTask(
            batch_size=len(indices),
            texts=pick(self.texts),
            target_lens=pick(self.target_lens),
            langs=pick(self.langs),
            instructs=pick(self.instructs),
            ref_texts=pick(self.ref_texts),
            ref_audio_tokens=pick(self.ref_audio_tokens),
            ref_rms=pick(self.ref_rms),
            speed=pick(self.speed) if self.speed else None,
        )
143
+
144
+
145
@dataclass
class OmniVoiceModelOutput(ModelOutput):
    """Output of :meth:`OmniVoice.forward`."""

    # Weighted cross-entropy training loss (None when no labels are given).
    loss: Optional[torch.Tensor] = None
    # Audio-token logits of shape (batch, num_codebooks, seq_len, vocab).
    logits: Optional[torch.Tensor] = None
149
+
150
+
151
+ # ---------------------------------------------------------------------------
152
+ # Config & Model
153
+ # ---------------------------------------------------------------------------
154
+
155
+
156
class OmniVoiceConfig(PretrainedConfig):
    """Configuration for :class:`OmniVoice`.

    Args:
        audio_vocab_size: Size of each audio codebook vocabulary, including
            the mask token (1024 codes + 1 mask id by default).
        audio_mask_id: ID of the mask token inside each codebook.
        num_audio_codebook: Number of audio codebook layers.
        audio_codebook_weights: Per-codebook loss weights; the default
            decreases with codebook depth.
        llm_config: Config (or plain dict) of the backbone LLM.
    """

    model_type = "omnivoice"
    sub_configs = {"llm_config": AutoConfig}

    def __init__(
        self,
        audio_vocab_size: int = 1025,
        audio_mask_id: int = 1024,
        num_audio_codebook: int = 8,
        audio_codebook_weights: Optional[list[float]] = None,
        llm_config: Optional[Union[dict, PretrainedConfig]] = None,
        **kwargs,
    ):
        # Re-hydrate a serialized (dict) llm_config into its concrete
        # config class via the transformers model-type registry.
        if isinstance(llm_config, dict):
            llm_config = CONFIG_MAPPING[llm_config["model_type"]](**llm_config)

        self.llm_config = llm_config

        # PretrainedConfig consumes the remaining generic kwargs.
        super().__init__(**kwargs)
        self.audio_vocab_size = audio_vocab_size
        self.audio_mask_id = audio_mask_id
        self.num_audio_codebook = num_audio_codebook
        # None default avoids a shared mutable default argument.
        if audio_codebook_weights is None:
            audio_codebook_weights = [8, 8, 6, 6, 4, 4, 2, 2]
        self.audio_codebook_weights = audio_codebook_weights
182
+
183
+
184
class OmniVoice(PreTrainedModel):
    """OmniVoice speech-generation model.

    Wraps an LLM backbone with multi-codebook audio embeddings and
    per-codebook prediction heads; supports voice cloning, voice design,
    and auto-voice generation (see :meth:`generate`).
    """

    # Attention back-end capabilities advertised to transformers.
    _supports_flex_attn = True
    _supports_flash_attn_2 = True
    config_class = OmniVoiceConfig
188
+
189
    def __init__(self, config: OmniVoiceConfig, llm: Optional[PreTrainedModel] = None):
        """Build the model: LLM backbone + audio embedding table + audio heads.

        Args:
            config: Model configuration.
            llm: Optional pre-built backbone; if None, one is created from
                ``config.llm_config``.
        """
        super().__init__(config)

        if llm is not None:
            # If an LLM instance is provided, use it directly
            # (skipping config-based init).
            self.llm = llm
        else:
            # Otherwise, initialize the LLM from the config.
            self.llm = AutoModel.from_config(self.config.llm_config)

        # One shared embedding table for all codebooks: layer k occupies the
        # ID range [k * audio_vocab_size, (k + 1) * audio_vocab_size).
        self.audio_embeddings = nn.Embedding(
            config.num_audio_codebook * config.audio_vocab_size,
            self.config.llm_config.hidden_size,
        )
        # Per-layer ID offsets; registered as a buffer so they follow the
        # model's device/dtype moves and checkpointing.
        self.register_buffer(
            "codebook_layer_offsets",
            torch.arange(config.num_audio_codebook) * config.audio_vocab_size,
        )

        # Single projection producing logits for every codebook at once.
        self.audio_heads = nn.Linear(
            self.config.llm_config.hidden_size,
            config.num_audio_codebook * config.audio_vocab_size,
            bias=False,
        )

        # Normalise the per-codebook loss weights to sum to 1.
        self.normalized_audio_codebook_weights = [
            w / sum(config.audio_codebook_weights)
            for w in config.audio_codebook_weights
        ]

        self.post_init()

        # Inference-only attributes (set by from_pretrained when not in train mode)
        self.text_tokenizer = None
        self.audio_tokenizer = None
        self.duration_estimator = None
        self.sampling_rate = None
        self._asr_pipe = None
228
+
229
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load OmniVoice plus (unless training) its inference components.

        Extra keyword arguments, popped before delegating to
        ``PreTrainedModel.from_pretrained``:
            train: If True, load only the core model (no tokenizers / ASR).
            load_asr: If True, also load the Whisper ASR model.
            asr_model_name: HF model name for the ASR model.
        """
        train_mode = kwargs.pop("train", False)
        load_asr = kwargs.pop("load_asr", False)
        asr_model_name = kwargs.pop("asr_model_name", "openai/whisper-large-v3-turbo")

        # Suppress noisy INFO logs from transformers/huggingface_hub during loading
        _prev_disable = logging.root.manager.disable
        logging.disable(logging.INFO)

        try:
            model = super().from_pretrained(
                pretrained_model_name_or_path, *args, **kwargs
            )

            if not train_mode:
                # Resolve local path for audio tokenizer subdirectory
                if os.path.isdir(pretrained_model_name_or_path):
                    resolved_path = pretrained_model_name_or_path
                else:
                    from huggingface_hub import snapshot_download

                    resolved_path = snapshot_download(pretrained_model_name_or_path)

                model.text_tokenizer = AutoTokenizer.from_pretrained(
                    pretrained_model_name_or_path
                )

                audio_tokenizer_path = os.path.join(resolved_path, "audio_tokenizer")

                if not os.path.isdir(audio_tokenizer_path):
                    # Fallback to the HuggingFace Hub path of transformers'
                    # HiggsAudioV2Tokenizer if the local subdirectory doesn't exist.
                    audio_tokenizer_path = "eustlb/higgs-audio-v2-tokenizer"

                # higgs-audio-v2-tokenizer does not support MPS (output channels > 65536)
                tokenizer_device = (
                    "cpu" if str(model.device).startswith("mps") else model.device
                )
                model.audio_tokenizer = HiggsAudioV2TokenizerModel.from_pretrained(
                    audio_tokenizer_path, device_map=tokenizer_device
                )
                model.feature_extractor = AutoFeatureExtractor.from_pretrained(
                    audio_tokenizer_path
                )

                model.sampling_rate = model.feature_extractor.sampling_rate

                model.duration_estimator = RuleDurationEstimator()

                if load_asr:
                    model.load_asr_model(model_name=asr_model_name)
        finally:
            # Restore the caller's logging disable level even on failure.
            logging.disable(_prev_disable)

        return model
285
+
286
+ # -------------------------------------------------------------------
287
+ # ASR support (optional, for auto-transcription)
288
+ # -------------------------------------------------------------------
289
+
290
+ def load_asr_model(self, model_name: str = "openai/whisper-large-v3-turbo"):
291
+ """Load a Whisper ASR model for reference audio transcription.
292
+
293
+ Args:
294
+ model_name: HuggingFace model name for the Whisper model.
295
+ """
296
+ from transformers import pipeline as hf_pipeline
297
+
298
+ logger.info("Loading ASR model %s ...", model_name)
299
+ asr_dtype = (
300
+ torch.float16 if str(self.device).startswith("cuda") else torch.float32
301
+ )
302
+ self._asr_pipe = hf_pipeline(
303
+ "automatic-speech-recognition",
304
+ model=model_name,
305
+ dtype=asr_dtype,
306
+ device_map=self.device,
307
+ )
308
+ logger.info("ASR model loaded on %s.", self.device)
309
+
310
+ @torch.inference_mode()
311
+ def transcribe(
312
+ self,
313
+ audio: Union[str, tuple[torch.Tensor, int]],
314
+ ) -> str:
315
+ """Transcribe audio using the loaded Whisper ASR model.
316
+
317
+ Args:
318
+ audio: File path or (waveform, sample_rate) tuple.
319
+
320
+ Returns:
321
+ Transcribed text.
322
+ """
323
+ if self._asr_pipe is None:
324
+ raise RuntimeError(
325
+ "ASR model is not loaded. Call model.load_asr_model() first."
326
+ )
327
+
328
+ if isinstance(audio, str):
329
+ return self._asr_pipe(audio)["text"].strip()
330
+ else:
331
+ waveform, sr = audio
332
+ if waveform.dim() == 1:
333
+ waveform = waveform.unsqueeze(0)
334
+ if waveform.size(0) > 1:
335
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
336
+ audio_input = {
337
+ "array": waveform.squeeze(0).cpu().numpy(),
338
+ "sampling_rate": sr,
339
+ }
340
+ return self._asr_pipe(audio_input)["text"].strip()
341
+
342
    def get_input_embeddings(self):
        """Return the text token embedding table of the backbone LLM."""
        return self.llm.get_input_embeddings()
344
+
345
    def set_input_embeddings(self, value):
        """Replace the text token embedding table of the backbone LLM."""
        self.llm.set_input_embeddings(value)
347
+
348
+ def _prepare_embed_inputs(
349
+ self, input_ids: torch.Tensor, audio_mask: torch.Tensor
350
+ ) -> torch.Tensor:
351
+ """
352
+ Prepares embeddings from input_ids of shape (batch_size, layers, seq_length).
353
+ Embedding shape is (batch_size, seq_length, hidden_size).
354
+ """
355
+ text_embeds = self.get_input_embeddings()(input_ids[:, 0, :])
356
+
357
+ # Apply shift to audio IDs based on codebook layer
358
+ # audio_ids: [Batch, 8, Seq]
359
+ # codebook_layer_offsets: [1, 8, 1]
360
+ # Result: Layer 0 ID Layer 1 ID + Layer 2 ID + 2050...
361
+ shifted_ids = (
362
+ input_ids * audio_mask.unsqueeze(1)
363
+ ) + self.codebook_layer_offsets.view(1, -1, 1)
364
+
365
+ # input: [Batch, 8, Seq] -> output: [Batch, Seq, Hidden]
366
+ audio_embeds = self.audio_embeddings(shifted_ids).sum(dim=1)
367
+
368
+ return torch.where(audio_mask.unsqueeze(-1), audio_embeds, text_embeds)
369
+
370
    def forward(
        self,
        input_ids: torch.LongTensor,
        audio_mask: torch.Tensor,
        labels: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        document_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        """Run the backbone and audio heads; compute the loss if labels given.

        Args:
            input_ids: (batch, num_codebooks, seq_len) token IDs; row 0 also
                holds the text tokens for non-audio positions.
            audio_mask: (batch, seq_len) mask marking audio positions.
            labels: (batch, num_codebooks, seq_len) targets; -100 is ignored.
            attention_mask: Optional precomputed attention mask; if None and
                ``document_ids`` is given, a packed block mask is built.
            document_ids: Per-position document IDs for packed sequences.
            position_ids: Optional position IDs forwarded to the LLM.

        Returns:
            :class:`OmniVoiceModelOutput` with ``logits`` of shape
            (batch, num_codebooks, seq_len, vocab) and an optional ``loss``.
        """

        inputs_embeds = self._prepare_embed_inputs(input_ids, audio_mask)

        if attention_mask is None and document_ids is not None:
            # Packed-sequence training: build a FlexAttention block mask so
            # tokens attend only within their own document.
            # NOTE(review): _get_packed_mask is defined elsewhere in this
            # file; using document_ids[0] assumes one shared packing layout
            # for the whole batch — confirm against the collator.
            attention_mask = create_block_mask(
                _get_packed_mask(
                    document_ids[0].to(inputs_embeds.device),
                ),
                B=None,
                H=None,
                Q_LEN=input_ids.size(-1),
                KV_LEN=input_ids.size(-1),
                _compile=True,
                device=inputs_embeds.device,
            )

        llm_outputs = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            position_ids=position_ids,
        )
        hidden_states = llm_outputs[0]

        loss = None

        # Project to all codebooks at once. Shape: [B, S, C * Vocab]
        batch_size, seq_len, _ = hidden_states.shape
        logits_flat = self.audio_heads(hidden_states)
        # Shape: [B, S, C, Vocab] -> [B, C, S, Vocab]
        audio_logits = logits_flat.view(
            batch_size,
            seq_len,
            self.config.num_audio_codebook,
            self.config.audio_vocab_size,
        ).permute(0, 2, 1, 3)

        if labels is not None:

            # audio_logits.permute(0, 3, 1, 2):
            #   [Batch, Layer, Seq, Vocab] -> [Batch, Vocab, Layer, Seq]
            # per_token_loss shape: [Batch, Layer, Seq]; -100 entries contribute 0
            per_token_loss = torch.nn.functional.cross_entropy(
                audio_logits.permute(0, 3, 1, 2),
                labels,
                reduction="none",
                ignore_index=-100,
            )
            # valid_mask shape: [Batch, Layer, Seq]
            valid_mask = (labels != -100).float()

            # Mean loss per codebook layer; the clamp avoids 0/0 when a layer
            # has no valid targets.
            # layer_means shape: [num_layers]
            layer_means = (per_token_loss * valid_mask).sum(
                dim=(0, 2)
            ) / valid_mask.sum(dim=(0, 2)).clamp(min=1.0)

            # Weighted sum over layers (weights normalised in __init__).
            weights = torch.tensor(
                self.normalized_audio_codebook_weights, device=audio_logits.device
            )
            loss = (layer_means * weights).sum()

        return OmniVoiceModelOutput(
            loss=loss,
            logits=audio_logits,
        )
444
+
445
    def supported_language_ids(self) -> set[str]:
        """Return the set of supported language IDs (e.g. ``"en"``)."""
        return LANG_IDS
448
+
449
    def supported_language_names(self) -> set[str]:
        """Return the set of supported language names (e.g. ``"English"``)."""
        return LANG_NAMES
452
+
453
+ # -------------------------------------------------------------------
454
+ # Inference API
455
+ # -------------------------------------------------------------------
456
+
457
    @torch.inference_mode()
    def generate(
        self,
        text: Union[str, list[str]],
        language: Union[str, list[str], None] = None,
        ref_text: Union[str, list[str], None] = None,
        ref_audio: Union[
            str,
            list[str],
            tuple[torch.Tensor, int],
            list[tuple[torch.Tensor, int]],
            None,
        ] = None,
        voice_clone_prompt: Union[
            VoiceClonePrompt, list[VoiceClonePrompt], None
        ] = None,
        instruct: Union[str, list[str], None] = None,
        duration: Union[float, list[Optional[float]], None] = None,
        speed: Union[float, list[Optional[float]], None] = None,
        generation_config: Optional[OmniVoiceGenerationConfig] = None,
        **kwargs,
    ) -> list[torch.Tensor]:
        """Generate speech audio given text in various modes.

        Supports three modes:

        1. **Voice clone** — clone the voice style from the reference audio.
           Should provide ``voice_clone_prompt`` (from
           :meth:`create_voice_clone_prompt`) or ``ref_text`` + ``ref_audio``.
        2. **Voice design** — provide ``instruct`` text describing
           the desired voice style; no reference audio needed.
        3. **Auto** — provide neither; the model picks a voice itself.

        Args:
            text: Target text (single string or list for batch).
            language: Language name (e.g. ``"English"``) or code
                (e.g. ``"en"``). ``None`` for language-agnostic mode.
                Performance is slightly better if you specify the language.
            ref_text: Optional reference text for voice cloning mode.
            ref_audio: Optional reference audio for voice cloning mode.
                Can be a file path or a (waveform, sample_rate) tuple.
            voice_clone_prompt: Reusable prompt from :meth:`create_voice_clone_prompt`.
                If provided, it overrides ``ref_text`` and ``ref_audio``.
            instruct: Style instruction for voice design mode.
            duration: Fixed output duration in seconds. If a single float,
                applies to all items; if a list, one value per item.
                ``None`` (default) lets the model estimate duration from text.
                Overrides ``speed`` when both are provided.
            speed: Speaking speed factor. ``> 1.0`` for faster, ``< 1.0`` for
                slower. If a list, one value per item. ``None`` (default) uses
                the model's default estimation.
            generation_config: Explicit config object. If provided, takes
                precedence over ``**kwargs``.
            **kwargs: Generation config or its fields:
                denoise: Whether to prepend the ``<|denoise|>`` token.
                num_step: Number of iterative decoding steps.
                guidance_scale: Classifier-free guidance scale.
                t_shift: Time-step shift (smaller → emphasise low-SNR).
                postprocess_output: Post-process output (remove silence, fade-in/out, pad edges).
                layer_penalty_factor: Penalty encouraging earlier codebook
                    layers to unmask first.
                position_temperature: Temperature for position selection.
                class_temperature: Temperature for token sampling (0 = greedy).
                audio_chunk_duration: If > 0, split long text into chunks of
                    this duration (seconds) and generate chunk by chunk.
                audio_chunk_threshold: Only apply chunking if estimated audio
                    duration exceeds this threshold (seconds).
        Returns:
            ``audios`` a list of 2-D ``torch.Tensor``, with the shape (1, T) and sampling rate
            consistent with the model's audio tokenizer (usually 24000 Hz).
        """

        if self.audio_tokenizer is None or self.text_tokenizer is None:
            raise RuntimeError(
                "Model is not loaded with audio/text tokenizers. Make sure you "
                "loaded the model with OmniVoice.from_pretrained()."
            )
        # An explicit generation_config wins; otherwise build one from kwargs.
        gen_config = (
            generation_config
            if generation_config is not None
            else OmniVoiceGenerationConfig.from_dict(kwargs)
        )

        self.eval()

        # Normalise all per-call arguments into one batched task
        # (creates/encodes voice-clone prompts, estimates target lengths).
        full_task = self._preprocess_all(
            text=text,
            language=language,
            ref_text=ref_text,
            ref_audio=ref_audio,
            voice_clone_prompt=voice_clone_prompt,
            instruct=instruct,
            preprocess_prompt=gen_config.preprocess_prompt,
            speed=speed,
            duration=duration,
        )

        # Split items by estimated length: short ones decode in one pass,
        # long ones are chunked.
        short_idx, long_idx = full_task.get_indices(
            gen_config, self.audio_tokenizer.config.frame_rate
        )

        results = [None] * full_task.batch_size

        if short_idx:
            short_task = full_task.slice_task(short_idx)
            short_results = self._generate_iterative(short_task, gen_config)
            for idx, res in zip(short_idx, short_results):
                results[idx] = res

        if long_idx:
            long_task = full_task.slice_task(long_idx)
            long_results = self._generate_chunked(long_task, gen_config)
            for idx, res in zip(long_idx, long_results):
                results[idx] = res

        # Decode tokens to waveforms in the original input order.
        generated_audios = []
        for i in range(full_task.batch_size):
            assert results[i] is not None, f"Result {i} was not generated"
            generated_audios.append(
                self._decode_and_post_process(
                    results[i], full_task.ref_rms[i], gen_config  # type: ignore[arg-type]
                )
            )

        return generated_audios
582
+
583
    def create_voice_clone_prompt(
        self,
        ref_audio: Union[str, tuple[torch.Tensor, int]],
        ref_text: Optional[str] = None,
        preprocess_prompt: bool = True,
    ) -> VoiceClonePrompt:
        """Create a reusable voice clone prompt from reference audio.

        Args:
            ref_audio: File path (str) or ``(waveform, sample_rate)`` tuple.
                waveform should be a 1-D or 2-D torch.Tensor (channels x samples).
            ref_text: Transcript of the reference audio. If ``None``, the
                ASR model will be used to auto-transcribe (must call
                :meth:`load_asr_model` first).
            preprocess_prompt: If ``True`` (default), apply silence removal and
                trimming to the reference audio, add punctuation in the end
                of reference text (if not already)

        Returns:
            A :class:`VoiceClonePrompt` that can be passed to :meth:`generate`.
        """
        if self.audio_tokenizer is None:
            raise RuntimeError(
                "Audio tokenizer is not loaded. Make sure you loaded the model "
                "with OmniVoice.from_pretrained()."
            )

        # Load / normalise the reference to mono at the model's sampling rate.
        if isinstance(ref_audio, str):
            ref_wav = load_audio(ref_audio, self.sampling_rate)
        else:
            waveform, sr = ref_audio
            if waveform.dim() == 1:
                waveform = waveform.unsqueeze(0)
            if waveform.size(0) > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)
            if sr != self.sampling_rate:
                waveform = torchaudio.functional.resample(
                    waveform, sr, self.sampling_rate
                )
            ref_wav = waveform

        # Boost very quiet references up to RMS 0.1; the original ref_rms is
        # kept so _post_process_audio can scale the output back down.
        ref_rms = torch.sqrt(torch.mean(torch.square(ref_wav))).item()
        if 0 < ref_rms < 0.1:
            ref_wav = ref_wav * 0.1 / ref_rms

        if preprocess_prompt:
            # Trim long reference audio (>20s) by splitting at the largest silence gap.
            # Skip trimming when ref_text is user-provided, otherwise the
            # trimmed audio will no longer match the full transcript.
            if ref_text is None:
                ref_wav = trim_long_audio(ref_wav, self.sampling_rate)
            elif ref_wav.size(-1) / self.sampling_rate > 20.0:
                logger.warning(
                    "Reference audio is %.1fs long (>20s) and ref_text was "
                    "provided, so automatic trimming is skipped. A long reference "
                    "may cause slower generation and degraded quality.",
                    ref_wav.size(-1) / self.sampling_rate,
                )

            ref_wav = remove_silence(
                ref_wav,
                self.sampling_rate,
                mid_sil=200,
                lead_sil=100,
                trail_sil=200,
            )
            if ref_wav.size(-1) == 0:
                raise ValueError(
                    "Reference audio is empty after silence removal. "
                    "Try setting preprocess_prompt=False."
                )

        # Auto-transcribe if ref_text not provided
        if ref_text is None:
            if self._asr_pipe is None:
                logger.info("ASR model not loaded yet, loading on-the-fly ...")
                self.load_asr_model()
            ref_text = self.transcribe((ref_wav, self.sampling_rate))
            logger.debug("Auto-transcribed ref_text: %s", ref_text)

        # Clip to a whole number of tokenizer hops before encoding.
        chunk_size = self.audio_tokenizer.config.hop_length
        clip_size = int(ref_wav.size(-1) % chunk_size)
        ref_wav = ref_wav[:, :-clip_size] if clip_size > 0 else ref_wav
        ref_audio_tokens = self.audio_tokenizer.encode(
            ref_wav.unsqueeze(0).to(self.audio_tokenizer.device),
        ).audio_codes.squeeze(0)  # (C, T)

        if preprocess_prompt:
            ref_text = add_punctuation(ref_text)

        return VoiceClonePrompt(
            ref_audio_tokens=ref_audio_tokens,
            ref_text=ref_text,
            ref_rms=ref_rms,
        )
680
+
681
+ def _decode_and_post_process(
682
+ self,
683
+ tokens: Union[torch.Tensor, List[torch.Tensor]],
684
+ rms: Union[float, None],
685
+ gen_config: OmniVoiceGenerationConfig,
686
+ ) -> torch.Tensor:
687
+ """
688
+ Args:
689
+ tokens: Audio tokens — either a single tensor of shape
690
+ (num_codebooks, seq_len) or a list of chunk tensors.
691
+ rms: RMS of the reference audio for volume adjustment.
692
+ gen_config: Generation config for post-processing options.
693
+ Returns:
694
+ Decoded and post-processed audio tensor of shape (1, T).
695
+ """
696
+ tokenizer_device = self.audio_tokenizer.device
697
+ if isinstance(tokens, list):
698
+ chunk_audios = [
699
+ self.audio_tokenizer.decode(t.to(tokenizer_device).unsqueeze(0))
700
+ .audio_values[0]
701
+ .cpu()
702
+ for t in tokens
703
+ ]
704
+ audio_waveform = cross_fade_chunks(chunk_audios, self.sampling_rate)
705
+ else:
706
+ audio_waveform = (
707
+ self.audio_tokenizer.decode(tokens.to(tokenizer_device).unsqueeze(0))
708
+ .audio_values[0]
709
+ .cpu()
710
+ )
711
+
712
+ return self._post_process_audio(
713
+ audio_waveform,
714
+ postprocess_output=gen_config.postprocess_output,
715
+ ref_rms=rms,
716
+ )
717
+
718
    def _post_process_audio(
        self,
        generated_audio: torch.Tensor,
        postprocess_output: bool,
        ref_rms: Union[float, None],
    ) -> torch.Tensor:
        """Optionally remove long silences, adjust volume, and add edge padding.

        Args:
            generated_audio: Audio tensor of shape (1, T).
            postprocess_output: If True, remove long silences and apply fade/pad.
            ref_rms: RMS of the reference audio for volume normalisation.
        Returns:
            Processed audio tensor of shape (1, T).
        """
        # NOTE(review): the volume and fade/pad steps below are gated by
        # postprocess_output together with silence removal, matching the
        # generate() docstring for this flag — confirm intended nesting.
        if postprocess_output:
            generated_audio = remove_silence(
                generated_audio,
                self.sampling_rate,
                mid_sil=500,
                lead_sil=100,
                trail_sil=100,
            )

            # A quiet reference was boosted to RMS 0.1 before encoding;
            # scale the output back to the original loudness.
            if ref_rms is not None and ref_rms < 0.1:
                generated_audio = generated_audio * ref_rms / 0.1
            elif ref_rms is None:
                # No reference audio (voice design): peak-normalize to 0.5
                # to avoid clipping while keeping a comfortable volume level.
                peak = generated_audio.abs().max()
                if peak > 1e-6:
                    generated_audio = generated_audio / peak * 0.5

            generated_audio = fade_and_pad_audio(
                generated_audio,
                sample_rate=self.sampling_rate,
            )
        return generated_audio
756
+
757
    def _generate_chunked(
        self, task: GenerationTask, gen_config: OmniVoiceGenerationConfig
    ) -> List[List[torch.Tensor]]:
        """Generate long audio by splitting text into chunks and batching.

        Each item in the returned list corresponds to one input and contains
        a list of audio token tensors — one per text chunk.

        Args:
            task: A :class:`GenerationTask` with one or more items whose
                estimated audio exceeds ``audio_chunk_threshold``.
            gen_config: Generation config (``audio_chunk_duration`` controls
                chunk size).
        Returns:
            Per-item list of chunk token-tensor lists.
        """
        # Chunk each item's text; chunk length in characters is derived from
        # the per-item tokens-per-character ratio so that each chunk yields
        # roughly audio_chunk_duration seconds of audio.
        all_chunks = []
        for i in range(task.batch_size):
            avg_tokens_per_char = task.target_lens[i] / len(task.texts[i])
            text_chunk_len = int(
                gen_config.audio_chunk_duration
                * self.audio_tokenizer.config.frame_rate
                / avg_tokens_per_char
            )
            chunks = chunk_text_punctuation(
                text=task.texts[i],
                chunk_len=text_chunk_len,
                min_chunk_len=3,
            )
            logger.debug(f"Item {i} chunked into {len(chunks)} pieces: {chunks}")
            all_chunks.append(chunks)

        has_ref = [t is not None for t in task.ref_audio_tokens]
        assert all(has_ref) or not any(has_ref), (
            "Chunked inference requires all items to either have or not have "
            "ref_audio. Mixed ref/non-ref is not supported."
        )

        max_num_chunks = max(len(c) for c in all_chunks)

        # chunk_results[item_idx] = list of generated token tensors per chunk
        chunk_results = [[] for _ in range(task.batch_size)]

        def _run_batch(indices, texts, ref_audios, ref_texts):
            # Generate one chunk for each listed item as a single batch and
            # append the results to chunk_results (closure over the outer list).
            # Note the index pairing below: j indexes the chunk-local lists
            # (texts/ref_texts/ref_audios), i indexes the original task.
            speed_list = task.speed
            target_lens = [
                self._estimate_target_tokens(
                    texts[j],
                    ref_texts[j],
                    ref_audios[j].size(-1) if ref_audios[j] is not None else None,
                    speed=speed_list[i] if speed_list else 1.0,
                )
                for j, i in enumerate(indices)
            ]
            sub_task = GenerationTask(
                batch_size=len(indices),
                texts=texts,
                target_lens=target_lens,
                langs=[task.langs[i] for i in indices],
                instructs=[task.instructs[i] for i in indices],
                ref_texts=ref_texts,
                ref_audio_tokens=ref_audios,
                ref_rms=[task.ref_rms[i] for i in indices],
                speed=[task.speed[i] for i in indices] if task.speed else None,
            )
            gen_tokens = self._generate_iterative(sub_task, gen_config)
            for j, idx in enumerate(indices):
                chunk_results[idx].append(gen_tokens[j])

        if all(has_ref):
            # All items have reference audio.
            # We still sequentially generate chunks within each item, but we
            # batch across items for the same chunk index. This allows to keep
            # the VRAM usage manageable while still benefiting from batching.
            for ci in range(max_num_chunks):
                indices = [i for i in range(task.batch_size) if ci < len(all_chunks[i])]
                if not indices:
                    continue
                _run_batch(
                    indices,
                    texts=[all_chunks[i][ci] for i in indices],
                    ref_audios=[task.ref_audio_tokens[i] for i in indices],
                    ref_texts=[task.ref_texts[i] for i in indices],
                )
        else:
            # No reference audio — generate chunk 0 for all items first,
            # then use chunk 0 output as reference for all subsequent chunks.
            indices_0 = [i for i in range(task.batch_size) if len(all_chunks[i]) > 0]
            _run_batch(
                indices_0,
                texts=[all_chunks[i][0] for i in indices_0],
                ref_audios=[None] * len(indices_0),
                ref_texts=[None] * len(indices_0),
            )
            first_chunk_map = {idx: chunk_results[idx][0] for idx in indices_0}

            # Batch all remaining chunks, using chunk 0 as fixed reference
            for ci in range(1, max_num_chunks):
                indices = [i for i in range(task.batch_size) if ci < len(all_chunks[i])]
                if not indices:
                    continue
                _run_batch(
                    indices,
                    texts=[all_chunks[i][ci] for i in indices],
                    ref_audios=[first_chunk_map[i] for i in indices],
                    ref_texts=[all_chunks[i][0] for i in indices],
                )

        return chunk_results
867
+
868
    def _preprocess_all(
        self,
        text: Union[str, list[str]],
        language: Union[str, list[str], None] = None,
        ref_text: Union[str, list[str], None] = None,
        ref_audio: Union[
            str,
            list[str],
            tuple[torch.Tensor, int],
            list[tuple[torch.Tensor, int]],
            None,
        ] = None,
        voice_clone_prompt: Union[
            VoiceClonePrompt, list[VoiceClonePrompt], None
        ] = None,
        instruct: Union[str, list[str], None] = None,
        preprocess_prompt: bool = True,
        speed: Union[float, list[Optional[float]], None] = None,
        duration: Union[float, list[Optional[float]], None] = None,
    ) -> GenerationTask:
        """Normalise all user-facing generation arguments into a GenerationTask.

        Scalars are broadcast to per-item lists of length ``batch_size``
        (derived from ``text``), languages and instructs are validated,
        voice-clone prompts are built from raw reference audio when needed,
        and per-item target token counts are estimated (optionally overridden
        by an exact ``duration``).

        Args:
            text: Target text(s); a single string is treated as batch size 1.
            language: Language ID(s)/name(s), resolved via _resolve_language.
            ref_text: Reference transcript(s); ignored if voice_clone_prompt
                is given.
            ref_audio: Reference audio path(s) or (waveform, sample_rate)
                tuple(s); ignored if voice_clone_prompt is given.
            voice_clone_prompt: Pre-built prompt(s); takes precedence over
                ref_text/ref_audio.
            instruct: Voice-design instruction(s), validated per item.
            preprocess_prompt: Forwarded to create_voice_clone_prompt.
            speed: Global or per-item speed factor(s); entries may be None.
            duration: Global or per-item exact duration(s) in seconds;
                overrides speed for items where it is set.

        Returns:
            A GenerationTask with per-item texts, target lengths, languages,
            instructions, reference data, and effective speed factors.
        """

        if isinstance(text, str):
            text_list = [text]
        else:
            assert isinstance(
                text, list
            ), "text should be a string or a list of strings"
            text_list = text
        batch_size = len(text_list)

        language_list = self._ensure_list(language, batch_size)
        language_list = [_resolve_language(lang) for lang in language_list]
        instruct_list = self._ensure_list(instruct, batch_size)
        for i, s in enumerate(instruct_list):
            if s is None:
                continue
            # Instructs are rendered in Chinese when the target text itself
            # contains CJK characters (unless overridden by dialect/accent).
            use_zh = bool(text_list[i] and _ZH_RE.search(text_list[i]))
            instruct_list[i] = _resolve_instruct(s, use_zh=use_zh)

        if voice_clone_prompt is not None and (
            ref_text is not None or ref_audio is not None
        ):
            logger.warning(
                "Both voice_clone_prompt and ref_text/ref_audio are provided. "
                "ref_text/ref_audio will be ignored."
            )
        if voice_clone_prompt is None and ref_audio is not None:
            # If voice_clone_prompt is not provided, create it from
            # ref_audio (ref_text will be auto-transcribed if not given).
            ref_text_list = self._ensure_list(ref_text, batch_size, auto_repeat=False)
            ref_audio_list = self._ensure_list(ref_audio, batch_size, auto_repeat=False)

            # NOTE(review): with auto_repeat=False these two lists may have
            # different lengths (1 vs batch_size); the loop below indexes
            # ref_audio_list by the length of ref_text_list — confirm callers
            # always pass matching shapes.
            voice_clone_prompt = []
            for i in range(len(ref_text_list)):
                voice_clone_prompt.append(
                    self.create_voice_clone_prompt(
                        ref_audio=ref_audio_list[i],
                        ref_text=ref_text_list[i],
                        preprocess_prompt=preprocess_prompt,
                    )
                )

        voice_clone_prompt_list = self._ensure_list(voice_clone_prompt, batch_size)
        if voice_clone_prompt_list[0] is not None:
            ref_text_list = [vc.ref_text for vc in voice_clone_prompt_list]
            ref_audio_tokens_list = [
                vc.ref_audio_tokens for vc in voice_clone_prompt_list
            ]
            ref_rms_list = [vc.ref_rms for vc in voice_clone_prompt_list]
        else:
            ref_text_list = [None] * batch_size
            ref_audio_tokens_list = [None] * batch_size
            ref_rms_list = [None] * batch_size

        # Normalize speed/duration to per-item lists (may contain None).
        if speed is not None:
            if isinstance(speed, (int, float)):
                user_speed = [float(speed)] * batch_size
            else:
                user_speed = list(speed)
        else:
            user_speed = None

        if duration is not None:
            if isinstance(duration, (int, float)):
                durations = [float(duration)] * batch_size
            else:
                durations = list(duration)
        else:
            durations = None

        num_target_tokens_list = []
        for i in range(batch_size):
            # duration[i] overrides speed for estimation: use speed=1.0
            # to get the raw estimate, then override target_lens below.
            has_dur = durations is not None and durations[i] is not None
            item_speed = 1.0 if has_dur else (user_speed[i] if user_speed else 1.0)
            est = self._estimate_target_tokens(
                text_list[i],
                ref_text_list[i],
                ref_audio_tokens_list[i].size(-1)
                if ref_audio_tokens_list[i] is not None
                else None,
                speed=item_speed,
            )
            num_target_tokens_list.append(est)

        # Per-item duration overrides: set target_lens to exact frame count
        # and compute speed ratio so chunked generation scales proportionally.
        speed_list: Optional[List[float]] = None
        if durations is not None:
            frame_rate = self.audio_tokenizer.config.frame_rate
            speed_list = []
            for i in range(batch_size):
                if durations[i] is not None:
                    target_tokens = max(1, int(durations[i] * frame_rate))
                    est = num_target_tokens_list[i]
                    # speed > 1 means "faster than estimated" (fewer frames).
                    speed_list.append(est / target_tokens if target_tokens > 0 else 1.0)
                    num_target_tokens_list[i] = target_tokens
                else:
                    s = user_speed[i] if user_speed else None
                    speed_list.append(s if s is not None else 1.0)
        elif user_speed is not None:
            speed_list = [s if s is not None else 1.0 for s in user_speed]

        return GenerationTask(
            batch_size=batch_size,
            texts=text_list,
            target_lens=num_target_tokens_list,
            langs=language_list,
            instructs=instruct_list,
            ref_texts=ref_text_list,
            ref_audio_tokens=ref_audio_tokens_list,
            ref_rms=ref_rms_list,
            speed=speed_list,
        )
1004
+
1005
+ def _estimate_target_tokens(self, text, ref_text, num_ref_audio_tokens, speed=1.0):
1006
+ """Estimate number of target audio tokens."""
1007
+ if num_ref_audio_tokens is None or ref_text is None or len(ref_text) == 0:
1008
+ # Fall back to a simple heuristic
1009
+ ref_text = "Nice to meet you."
1010
+ num_ref_audio_tokens = 25
1011
+
1012
+ est = self.duration_estimator.estimate_duration(
1013
+ text, ref_text, num_ref_audio_tokens
1014
+ )
1015
+ if speed > 0 and speed != 1.0:
1016
+ est = est / speed
1017
+ return max(1, int(est))
1018
+
1019
+ def _ensure_list(
1020
+ self, x: Union[Any, List[Any]], batch_size: int, auto_repeat: bool = True
1021
+ ) -> List[Any]:
1022
+ x_list = x if isinstance(x, list) else [x]
1023
+ if len(x_list) not in (
1024
+ 1,
1025
+ batch_size,
1026
+ ):
1027
+ raise ValueError(
1028
+ f"should be either the number of the text or 1, but got {len(x_list)}"
1029
+ )
1030
+ if auto_repeat and len(x_list) == 1 and batch_size is not None:
1031
+ x_list = x_list * batch_size
1032
+ return x_list
1033
+
1034
    def _prepare_inference_inputs(
        self,
        text: str,
        num_target_tokens: int,
        ref_text: Optional[str] = None,
        ref_audio_tokens: Optional[torch.Tensor] = None,
        lang: Optional[str] = None,
        instruct: Optional[str] = None,
        denoise: bool = True,
    ):
        """Prepare input_ids and audio masks for inference.

        The conditional sequence is laid out as:
        ``[style tokens][text tokens][optional ref audio tokens][MASK target]``
        with every text/style token replicated across all codebook rows.

        Args:
            text: Target text to generate.
            num_target_tokens: Number of audio tokens to generate.
            ref_text: Optional reference text for voice cloning.
            ref_audio_tokens: Optional reference audio tokens for voice cloning.
                with shape (C, T).
            lang: Optional language ID.
            instruct: Optional style instruction for voice design.
            denoise: Whether to include the <|denoise|> token.

        Returns:
            Dict with "input_ids" of shape (1, C, N_total) and a boolean
            "audio_mask" of shape (1, N_total) that is True over the audio
            region (reference audio + masked target).
        """

        # Build style tokens: <|denoise|> + <|lang_start|>...<|lang_end|>
        # + <|instruct_start|>...<|instruct_end|>
        style_text = ""
        if denoise:
            style_text += "<|denoise|>"
        # Missing language/instruct are encoded as the literal string "None".
        lang_str = lang if lang else "None"
        instruct_str = instruct if instruct else "None"
        style_text += f"<|lang_start|>{lang_str}<|lang_end|>"
        style_text += f"<|instruct_start|>{instruct_str}<|instruct_end|>"

        style_tokens = (
            self.text_tokenizer(style_text, return_tensors="pt")
            .input_ids.repeat(self.config.num_audio_codebook, 1)
            .unsqueeze(0)
        ).to(
            self.device
        )  # [1, C, N1]

        # Build text tokens (reference transcript prepended when cloning).
        full_text = _combine_text(ref_text=ref_text, text=text)
        text_tokens = (
            self.text_tokenizer(
                f"<|text_start|>{full_text}<|text_end|>",
                return_tensors="pt",
            )
            .input_ids.repeat(self.config.num_audio_codebook, 1)
            .unsqueeze(0)
        ).to(
            self.device
        )  # [1, C, N2]

        # Target: all MASK
        target_audio_tokens = torch.full(
            (1, self.config.num_audio_codebook, num_target_tokens),
            self.config.audio_mask_id,
            dtype=torch.long,
            device=self.device,
        )

        # Conditional input
        parts = [style_tokens, text_tokens]
        if ref_audio_tokens is not None:
            parts.append(ref_audio_tokens.unsqueeze(0).to(self.device))
        parts.append(target_audio_tokens)
        cond_input_ids = torch.cat(parts, dim=2)

        # Audio region starts at the reference audio (if any), otherwise at
        # the masked target.
        cond_total_length = cond_input_ids.shape[2]
        cond_audio_start_idx = cond_total_length - num_target_tokens
        if ref_audio_tokens is not None:
            cond_audio_start_idx -= ref_audio_tokens.size(-1)

        cond_audio_mask = torch.zeros(
            1, cond_total_length, dtype=torch.bool, device=self.device
        )
        cond_audio_mask[0, cond_audio_start_idx:] = True

        return {
            "input_ids": cond_input_ids,
            "audio_mask": cond_audio_mask,
        }
1116
+
1117
    def _generate_iterative(
        self, task: GenerationTask, gen_config: OmniVoiceGenerationConfig
    ) -> List[torch.Tensor]:
        """N-step iterative unmasked decoding.

        Runs classifier-free guidance by stacking conditional rows (0..B-1)
        and unconditional rows (B..2B-1, target region only) into one padded
        forward pass per step, then unmasks a scheduled number of positions
        per item by confidence.

        Args:
            task: A :class:`GenerationTask` containing batch texts, target
                lengths, languages, instructions, and optional reference data.
            gen_config: A :class:`OmniVoiceGenerationConfig` controlling
                decoding steps, guidance, temperatures, etc.
        Returns:
            List of generated audio token tensors of shape (C, T) (one per
            input text).
        """

        B = task.batch_size

        inputs_list = [
            self._prepare_inference_inputs(
                task.texts[i],
                task.target_lens[i],
                task.ref_texts[i],
                task.ref_audio_tokens[i],
                task.langs[i],
                task.instructs[i],
                gen_config.denoise,
            )
            for i in range(B)
        ]

        c_lens = [inp["input_ids"].size(2) for inp in inputs_list]
        max_c_len = max(c_lens)
        pad_id = self.config.audio_mask_id  # Or any other tokens

        # Batched buffers hold cond + uncond rows; padding positions are
        # excluded via the per-item block-diagonal attention mask below.
        batch_input_ids = torch.full(
            (2 * B, self.config.num_audio_codebook, max_c_len),
            pad_id,
            dtype=torch.long,
            device=self.device,
        )
        batch_audio_mask = torch.zeros(
            (2 * B, max_c_len), dtype=torch.bool, device=self.device
        )
        batch_attention_mask = torch.zeros(
            (2 * B, 1, max_c_len, max_c_len), dtype=torch.bool, device=self.device
        )

        for i, inp in enumerate(inputs_list):
            c_len, u_len = c_lens[i], task.target_lens[i]

            # Cond (0 ~ B-1)
            batch_input_ids[i, :, :c_len] = inp["input_ids"]
            batch_audio_mask[i, :c_len] = inp["audio_mask"]
            batch_attention_mask[i, :, :c_len, :c_len] = True

            # Uncond (B ~ 2B-1): only the trailing target region, no text/style.
            batch_input_ids[B + i, :, :u_len] = inp["input_ids"][..., -u_len:]
            batch_audio_mask[B + i, :u_len] = inp["audio_mask"][..., -u_len:]
            batch_attention_mask[B + i, :, :u_len, :u_len] = True

        # Working copy of target tokens, all MASK at the start.
        tokens = torch.full(
            (B, self.config.num_audio_codebook, max(task.target_lens)),
            self.config.audio_mask_id,
            dtype=torch.long,
            device=self.device,
        )

        timesteps = _get_time_steps(
            t_start=0.0,
            t_end=1.0,
            num_step=gen_config.num_step + 1,
            t_shift=gen_config.t_shift,
        ).tolist()
        # Per-item unmask schedule: number of positions revealed per step,
        # proportional to timestep deltas; the last step takes the remainder.
        schedules = []
        for t_len in task.target_lens:
            total_mask = t_len * self.config.num_audio_codebook
            rem = total_mask
            sched = []
            for step in range(gen_config.num_step):
                num = (
                    rem
                    if step == gen_config.num_step - 1
                    else min(
                        math.ceil(total_mask * (timesteps[step + 1] - timesteps[step])),
                        rem,
                    )
                )
                sched.append(int(num))
                rem -= int(num)
            schedules.append(sched)

        # Codebook index per row, used to bias selection toward lower layers.
        layer_ids = torch.arange(
            self.config.num_audio_codebook, device=self.device
        ).view(1, -1, 1)

        for step in range(gen_config.num_step):
            batch_logits = self(
                input_ids=batch_input_ids,
                audio_mask=batch_audio_mask,
                attention_mask=batch_attention_mask,
            ).logits.to(torch.float32)

            for i in range(B):
                k = schedules[i][step]
                if k <= 0:
                    continue

                c_len, t_len = c_lens[i], task.target_lens[i]

                # Extract real target Logits
                # [1, C, T, V]
                c_logits = batch_logits[i : i + 1, :, c_len - t_len : c_len, :]
                u_logits = batch_logits[B + i : B + i + 1, :, :t_len, :]

                pred_tokens, scores = self._predict_tokens_with_scoring(
                    c_logits, u_logits, gen_config
                )

                # Penalise higher codebook layers so they unmask later.
                scores = scores - (layer_ids * gen_config.layer_penalty_factor)

                if gen_config.position_temperature > 0.0:
                    scores = _gumbel_sample(scores, gen_config.position_temperature)

                # Only still-masked positions compete for unmasking.
                sample_tokens = tokens[i : i + 1, :, :t_len]
                scores.masked_fill_(
                    sample_tokens != self.config.audio_mask_id, -float("inf")
                )

                _, topk_idx = torch.topk(scores.flatten(), k)
                flat_tokens = sample_tokens.flatten()
                flat_tokens[topk_idx] = pred_tokens.flatten()[topk_idx]
                sample_tokens.copy_(flat_tokens.view_as(sample_tokens))

                # Update individual slices into batched structure
                tokens[i : i + 1, :, :t_len] = sample_tokens
                batch_input_ids[i : i + 1, :, c_len - t_len : c_len] = sample_tokens
                batch_input_ids[B + i : B + i + 1, :, :t_len] = sample_tokens

        return [tokens[i, :, : task.target_lens[i]] for i in range(B)]
1256
+
1257
    def _predict_tokens_with_scoring(self, c_logits, u_logits, gen_config):
        """Predict tokens from conditional/unconditional logits with CFG.

        Applies classifier-free guidance (when guidance_scale != 0), forbids
        the MASK id, then either samples with Gumbel noise over top-k-filtered
        logits (class_temperature > 0) or takes the greedy argmax.

        Args:
            c_logits: Conditional logits, shape (1, C, T, V).
            u_logits: Unconditional logits, same shape.
            gen_config: Generation config (guidance_scale, class_temperature).

        Returns:
            Tuple of (pred_tokens, confidence_scores); confidence is the max
            log-probability per position (pre-sampling).
        """
        if gen_config.guidance_scale != 0:
            c_log_probs = F.log_softmax(c_logits, dim=-1)
            u_log_probs = F.log_softmax(u_logits, dim=-1)
            # Guided distribution: cond + scale * (cond - uncond), renormalised.
            log_probs = torch.log_softmax(
                c_log_probs + gen_config.guidance_scale * (c_log_probs - u_log_probs),
                dim=-1,
            )
        else:
            log_probs = F.log_softmax(c_logits, dim=-1)

        # Never predict the MASK token itself.
        log_probs[..., self.config.audio_mask_id] = -float("inf")

        if gen_config.class_temperature > 0.0:
            # Sample from the top 10% of the vocabulary via Gumbel-argmax.
            filtered_probs = _filter_top_k(log_probs, ratio=0.1)
            pred_tokens = _gumbel_sample(
                filtered_probs, gen_config.class_temperature
            ).argmax(dim=-1)
        else:
            pred_tokens = log_probs.argmax(dim=-1)

        confidence_scores = log_probs.max(dim=-1)[0]

        return pred_tokens, confidence_scores
1281
+
1282
+
1283
+ # ---------------------------------------------------------------------------
1284
+ # Standalone helpers
1285
+ # ---------------------------------------------------------------------------
1286
+
1287
+
1288
def _get_packed_mask(document_ids):
    # Bind ``document_ids`` so the returned callable matches the
    # (b, h, q_idx, kv_idx) mask-mod signature used by the attention code.
    return partial(_mask_mod_packed, document_ids)
1290
+
1291
+
1292
+ def _mask_mod_packed(document_ids, b, h, q_idx, kv_idx):
1293
+ # 1. Sequence Packing Logic: Tokens must belong to the same document.
1294
+ # Note: The doc_id for padding tokens is -1, which will automatically not match
1295
+ # (if handled correctly) or be ignored.
1296
+ same_doc = document_ids[q_idx] == document_ids[kv_idx]
1297
+ return same_doc
1298
+
1299
+
1300
def _resolve_language(language: Optional[str]) -> Union[str, None]:
    """Resolve a language ID or full language name to a canonical language ID.

    Returns None (language-agnostic mode) when the input is None, the string
    "none" (any case), or an unrecognised value; the last case logs a warning.
    """
    from omnivoice.utils.lang_map import LANG_IDS, LANG_NAME_TO_ID

    if language is None:
        return None
    if language.lower() == "none":
        return None
    # Exact language ID match (case-sensitive).
    if language in LANG_IDS:
        return language
    # Full language name match (case-insensitive).
    lowered = language.lower()
    if lowered in LANG_NAME_TO_ID:
        return LANG_NAME_TO_ID[lowered]
    logger.warning(
        f"Language '{language}' is not recognized. "
        f"Please use a valid language ID (e.g., 'en', 'zh', 'ja', 'de') "
        f"or a full language name (e.g., 'English', 'Chinese', 'Japanese'). "
        f"See supported_language_ids() or supported_language_names() for details. "
        f"Falling back to None (language-agnostic mode)."
    )
    return None
1318
+
1319
+
1320
def _resolve_instruct(
    instruct: Optional[str], use_zh: bool = False
) -> Union[str, None]:
    """Validate and normalise a voice-design instruct string.

    Supported instruct items (case-insensitive for English):

    English (comma + space separated):
        gender: male, female
        age: child, teenager, young adult, middle-aged, elderly
        pitch: very low pitch, low pitch, moderate pitch,
            high pitch, very high pitch
        style: whisper
        accent: american accent, british accent, australian accent, ...

    Chinese (full-width comma separated):
        gender: 男, 女
        age: 儿童, 少年, 青年, 中年, 老年
        pitch: 极低音调, 低音调, 中音调, 高音调, 极高音调
        style: 耳语
        dialect: 河南话, 陕西话, 四川话, 贵州话, 云南话,
            桂林话, 济南话, 石家庄话, 甘肃话, 宁夏话,
            青岛话, 东北话

    Minor issues (auto-fixed):
        - Wrong separator (half-width comma in Chinese instruct or
          full-width comma in English instruct)
        - Leading / trailing commas

    Major issues (raise ``ValueError``):
        - Unsupported or misspelled instruct items
        - Suggestions are offered for close matches

    Args:
        instruct: Raw instruct string, or ``None``.
        use_zh: If True, normalise all items to Chinese (used when the
            synthesis text contains Chinese and no accent is specified).

    Returns:
        Normalised instruct string, or ``None``.

    Raises:
        ValueError: if any instruct item is unsupported or misspelled.
    """
    if instruct is None:
        return None

    instruct_str = instruct.strip()
    if not instruct_str:
        return None

    # Split on both half-width and full-width commas
    raw_items = re.split(r"\s*[,,]\s*", instruct_str)
    raw_items = [x for x in raw_items if x]

    # Validate each item against the combined EN+ZH vocabulary.
    unknown = []
    normalised = []
    for raw in raw_items:
        n = raw.strip().lower()
        if n in _INSTRUCT_ALL_VALID:
            normalised.append(n)
        else:
            # Offer a fuzzy suggestion for likely misspellings.
            sug = difflib.get_close_matches(n, _INSTRUCT_ALL_VALID, n=1, cutoff=0.6)
            unknown.append((raw, n, sug[0] if sug else None))

    if unknown:
        lines = []
        for raw, n, sug in unknown:
            if sug:
                lines.append(f" '{raw}' -> '{n}' (unsupported; did you mean '{sug}'?)")
            else:
                lines.append(f" '{raw}' -> '{n}' (unsupported)")
        err = (
            f"Unsupported instruct items found in {instruct_str}:\n"
            + "\n".join(lines)
            + "\n\nValid English items: "
            + ", ".join(sorted(_INSTRUCT_VALID_EN))
            + "\nValid Chinese items: "
            + ",".join(sorted(_INSTRUCT_VALID_ZH))
            + "\n\nTip: Use only English or only Chinese instructs. "
            "English instructs should use comma + space (e.g. "
            "'male, indian accent'),\nChinese instructs should use full-width "
            "comma (e.g. '男,河南话')."
        )
        raise ValueError(err)

    # --- Language consistency: dialect forces Chinese, accent forces English ---
    has_dialect = any(n.endswith("话") for n in normalised)
    has_accent = any(" accent" in n for n in normalised)

    if has_dialect and has_accent:
        raise ValueError(
            "Cannot mix Chinese dialect and English accent in a single instruct. "
            "Dialects are for Chinese speech, accents for English speech."
        )

    if has_dialect:
        use_zh = True
    elif has_accent:
        use_zh = False

    # --- Unify to single language ---
    if use_zh:
        normalised = [_INSTRUCT_EN_TO_ZH.get(n, n) for n in normalised]
    else:
        normalised = [_INSTRUCT_ZH_TO_EN.get(n, n) for n in normalised]

    # --- Category conflict check ---
    conflicts = []
    for cat in _INSTRUCT_MUTUALLY_EXCLUSIVE:
        hits = [n for n in normalised if n in cat]
        if len(hits) > 1:
            conflicts.append(hits)
    if conflicts:
        parts = []
        for group in conflicts:
            parts.append(" vs ".join(f"'{x}'" for x in group))
        raise ValueError(
            "Conflicting instruct items within the same category: "
            + "; ".join(parts)
            + ". Each category (gender, age, pitch, style, accent, dialect) "
            "allows at most one item."
        )

    # Determine separator based on language (full-width comma for Chinese).
    has_zh = any(any("\u4e00" <= c <= "\u9fff" for c in n) for n in normalised)
    separator = "," if has_zh else ", "

    return separator.join(normalised)
1450
+
1451
+
1452
+ def _filter_top_k(logits: torch.Tensor, ratio: float = 0.1) -> torch.Tensor:
1453
+ k = math.ceil(ratio * logits.shape[-1])
1454
+ val, ind = logits.topk(k, dim=-1)
1455
+ probs = torch.full_like(logits, float("-inf"))
1456
+ probs.scatter_(-1, ind, val)
1457
+ return probs
1458
+
1459
+
1460
+ def _gumbel_sample(logits: torch.Tensor, temperature: float) -> torch.Tensor:
1461
+ scaled_logits = logits / temperature
1462
+ u = torch.rand_like(scaled_logits)
1463
+ gumbel_noise = -torch.log(-torch.log(u + 1e-10) + 1e-10)
1464
+ return scaled_logits + gumbel_noise
1465
+
1466
+
1467
+ def _get_time_steps(
1468
+ t_start: float = 0.0,
1469
+ t_end: float = 1.0,
1470
+ num_step: int = 10,
1471
+ t_shift: float = 1.0,
1472
+ device: torch.device = torch.device("cpu"),
1473
+ ) -> torch.Tensor:
1474
+ timesteps = torch.linspace(t_start, t_end, num_step + 1).to(device)
1475
+ timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps)
1476
+ return timesteps
1477
+
1478
+
1479
+ def _combine_text(text, ref_text: Optional[str] = None) -> str:
1480
+
1481
+ # combine with reference text if not None
1482
+ if ref_text:
1483
+ full_text = ref_text.strip() + " " + text.strip()
1484
+ else:
1485
+ full_text = text.strip()
1486
+
1487
+ # replace \n with .
1488
+ full_text = re.sub(r"[ \t]*\r?\n[\s]*", ".", full_text)
1489
+
1490
+ # remove spaces around chinese characters
1491
+ chinese_range = r"[\u4e00-\u9fff]"
1492
+ pattern = rf"(?<={chinese_range})\s+|\s+(?={chinese_range})"
1493
+ full_text = re.sub(pattern, "", full_text)
1494
+ return full_text
1495
+
1496
+
1497
+ # ---------------------------------------------------------------------------
1498
+ # Register with HuggingFace Auto classes
1499
+ # ---------------------------------------------------------------------------
1500
+
1501
+ AutoConfig.register("omnivoice", OmniVoiceConfig)
1502
+ AutoModel.register(OmniVoiceConfig, OmniVoice)
omnivoice/scripts/__init__.py ADDED
File without changes
omnivoice/scripts/denoise_audio.py ADDED
@@ -0,0 +1,1048 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Denoise audio with Sidon and pack results into WebDataset shards.
19
+
20
+ Supports two input modes:
21
+
22
+ 1. WebDataset manifest (data.lst):
23
+ python denoise_audio.py \
24
+ --input_manifest data.lst \
25
+ --tar_output_pattern output/audios/shard-%06d.tar \
26
+ --jsonl_output_pattern output/txts/shard-%06d.jsonl \
27
+ --feature_extractor_path sidon-v0.1/feature_extractor_cuda.pt \
28
+ --decoder_path sidon-v0.1/decoder_cuda.pt
29
+
30
+ 2. Raw JSONL (each line: {"id": "...", "audio_path": "...", ...}):
31
+ python denoise_audio.py \
32
+ --input_jsonl data.jsonl \
33
+ --tar_output_pattern output/audios/shard-%06d.tar \
34
+ --jsonl_output_pattern output/txts/shard-%06d.jsonl \
35
+ --feature_extractor_path sidon-v0.1/feature_extractor_cuda.pt \
36
+ --decoder_path sidon-v0.1/decoder_cuda.pt
37
+
38
+ Output structure:
39
+ output_dir/
40
+ ├── audios/ # WebDataset tar shards (.flac audio + .json metadata)
41
+ │ ├── shard_000000.tar
42
+ │ └── ...
43
+ ├── txts/ # Per-shard JSONL metadata
44
+ │ ├── shard_000000.jsonl
45
+ │ └── ...
46
+ ├── data.lst # Manifest: <tar_path> <jsonl_path> <sample_count> <total_duration>
47
+ └── errors.jsonl # Failed samples with error details
48
+ """
49
+
50
+ from __future__ import annotations
51
+
52
+ import argparse
53
+ import io
54
+ import json
55
+ import logging
56
+ import os
57
+ import pickle
58
+ import struct
59
+ import subprocess
60
+ import sys
61
+ import threading
62
+ from concurrent.futures import FIRST_COMPLETED, Future, wait
63
+ from dataclasses import dataclass
64
+ from pathlib import Path
65
+ from typing import Any, Dict, List, Optional, Sequence, Union
66
+
67
+ import numpy as np
68
+ import torch
69
+ import torchaudio
70
+ import webdataset as wds
71
+ from torch.utils.data import DataLoader
72
+ from tqdm.auto import tqdm
73
+
74
+ from omnivoice.data.batching import StreamLengthGroupDataset
75
+ from omnivoice.data.dataset import JsonlDatasetReader, WebDatasetReader
76
+ from omnivoice.utils.common import str2bool
77
+
78
+ SIDON_INPUT_SAMPLE_RATE = 16_000
79
+ SIDON_OUTPUT_SAMPLE_RATE = 48_000
80
+
81
+
82
def build_parser() -> argparse.ArgumentParser:
    """Build the CLI argument parser for the denoise-and-shard script.

    Groups: input selection (manifest vs raw JSONL), output shard patterns,
    Sidon model paths, duration filtering, dynamic batching, multi-machine
    distribution, per-GPU parallelism, JSONL shuffling, and error handling.
    The module docstring is reused as the program description.
    """
    parser = argparse.ArgumentParser(description=__doc__)

    # ── Input (mutually exclusive) ──
    parser.add_argument(
        "--input_manifest",
        default=None,
        help="WebDataset manifest (data.lst). Each line: "
        "<tar_path> <jsonl_path> <num_items> <duration>",
    )
    parser.add_argument(
        "--input_jsonl",
        default=None,
        help="Raw JSONL file. Each line: " '{"id": "...", "audio_path": "...", ...}',
    )

    # ── Output ──
    parser.add_argument(
        "--tar_output_pattern",
        default=None,
        help="Tar shard pattern, e.g. output/audios/shard_%%06d.tar",
    )
    parser.add_argument(
        "--jsonl_output_pattern",
        default=None,
        help="JSONL shard pattern, e.g. output/txts/shard_%%06d.jsonl",
    )
    parser.add_argument(
        "--samples_per_shard",
        type=int,
        default=1_000,
        help="Maximum records per output shard",
    )

    # ── Model ──
    parser.add_argument(
        "--feature_extractor_path",
        default=None,
        help="Path to feature_extractor_cuda.pt",
    )
    parser.add_argument(
        "--decoder_path",
        default=None,
        help="Path to decoder_cuda.pt",
    )
    parser.add_argument(
        "--target_sample_rate",
        type=int,
        default=24_000,
        help="Sample rate of the denoised output audio",
    )

    # ── Filtering ──
    parser.add_argument(
        "--min_length",
        type=float,
        default=0.0,
        help="Minimum audio duration in seconds",
    )
    parser.add_argument(
        "--max_length",
        type=float,
        default=80.0,
        help="Maximum audio duration in seconds",
    )

    # ── Batching ──
    parser.add_argument(
        "--batch_duration",
        type=float,
        default=200.0,
        help="Target batch duration in seconds for dynamic batching",
    )
    parser.add_argument(
        "--max_sample",
        type=int,
        default=32,
        help="Maximum samples per batch for dynamic batching",
    )

    # ── Distributed ──
    parser.add_argument(
        "--num_machines",
        type=int,
        default=1,
        help="Total number of machines for distributed runs",
    )
    parser.add_argument(
        "--machine_index",
        type=int,
        default=0,
        help="Zero-based machine index when distributing across multiple "
        "machines (e.g. 0, 1, ... num_machines-1)",
    )

    # ── Parallelism ──
    parser.add_argument(
        "--nj_per_gpu",
        type=int,
        default=1,
        help="Worker processes per GPU (default 1)",
    )
    parser.add_argument(
        "--loader_workers",
        type=int,
        default=16,
        help="PyTorch DataLoader worker threads",
    )

    # ── Data order (JSONL mode) ──
    parser.add_argument(
        "--shuffle",
        type=str2bool,
        default=True,
        help="Shuffle JSONL entries",
    )
    parser.add_argument(
        "--shuffle_seed",
        type=int,
        default=42,
        help="Seed for JSONL shuffle",
    )

    # ── Error handling ──
    parser.add_argument(
        "--skip_errors",
        action="store_true",
        help="Skip items that fail to denoise instead of aborting",
    )
    # Internal flag set when this script re-invokes itself as a worker
    # subprocess; hidden from --help via SUPPRESS.
    parser.add_argument(
        "--_subprocess_worker",
        action="store_true",
        help=argparse.SUPPRESS,
    )
    return parser
217
+
218
+
219
+ # ---------------------------------------------------------------------------
220
+ # Utilities
221
+ # ---------------------------------------------------------------------------
222
+
223
+
224
def count_lines(path: str) -> int:
    """Count newlines efficiently by reading binary chunks."""
    total = 0
    with open(path, "rb") as handle:
        while True:
            block = handle.read(1 << 20)  # 1 MiB per read
            if not block:
                break
            total += block.count(b"\n")
    return total
231
+
232
+
233
+ PaddingStrategy = Union[bool, str]
234
+ ReturnType = Union[torch.Tensor, np.ndarray]
235
+
236
+
237
def extract_seamless_m4t_features(
    raw_speech: Union[torch.Tensor, List[float], List[torch.Tensor], List[List[float]]],
    sampling_rate: int = 16000,
    num_mel_bins: int = 80,
    frame_length: int = 25,
    frame_shift: int = 10,
    preemphasis_coefficient: float = 0.97,
    dither: float = 0.0,
    window_type: str = "povey",
    do_normalize_per_mel_bins: bool = True,
    stride: int = 2,
    padding: PaddingStrategy = "longest",
    max_length: Optional[int] = None,
    pad_to_multiple_of: Optional[int] = 2,
    return_tensors: Optional[str] = "pt",
    return_attention_mask: bool = True,
    padding_value: float = 0.0,
    device: torch.device = torch.device("cpu"),
) -> Dict[str, ReturnType]:
    """Extract SeamlessM4T features using Torch-only operators.

    Computes Kaldi-style fbank features per waveform, optionally normalises
    each mel bin (zero mean, unit variance over time), pads the batch, then
    stacks every ``stride`` consecutive frames into one feature vector
    (halving the time axis for stride=2).

    Args:
        raw_speech: One waveform or a list of waveforms (tensors or float
            lists); multi-channel input uses only the first channel.
        sampling_rate: Input sample rate in Hz.
        num_mel_bins / frame_length / frame_shift / preemphasis_coefficient /
            dither / window_type: Passed through to torchaudio's Kaldi fbank.
        do_normalize_per_mel_bins: Per-bin mean/variance normalisation.
        stride: Number of frames stacked per output step.
        padding: "longest" pads to the batch max; otherwise ``max_length``
            must be given.
        max_length / pad_to_multiple_of / padding_value: Padding controls.
        return_tensors: "pt" (default) or "np" to convert outputs to NumPy.
        return_attention_mask: Whether to include the (strided) mask.
        device: Device on which to build tensors.

    Returns:
        Dict with "input_features" of shape (B, T//stride, mel*stride) and,
        if requested, "attention_mask" subsampled to the strided frames.
    """
    if not isinstance(raw_speech, list):
        raw_speech = [raw_speech]

    processed_speech = [
        torch.as_tensor(sample, dtype=torch.float32, device=device)
        for sample in raw_speech
    ]

    features: List[torch.Tensor] = []
    for waveform in processed_speech:
        # Keep only the first channel of multi-channel audio.
        if waveform.ndim > 1:
            waveform = waveform[0]
        waveform_tensor = waveform.unsqueeze(0)
        feature = torchaudio.compliance.kaldi.fbank(
            waveform=waveform_tensor,
            sample_frequency=sampling_rate,
            num_mel_bins=num_mel_bins,
            frame_length=frame_length,
            frame_shift=frame_shift,
            dither=dither,
            preemphasis_coefficient=preemphasis_coefficient,
            remove_dc_offset=True,
            window_type=window_type,
            use_energy=False,
            energy_floor=1.192092955078125e-07,
        )
        features.append(feature.squeeze(0))

    if do_normalize_per_mel_bins:
        normalised: List[torch.Tensor] = []
        for feature in features:
            mean = feature.mean(0, keepdim=True)
            var = feature.var(0, keepdim=True)
            normalised.append((feature - mean) / torch.sqrt(var + 1e-5))
        features = normalised

    def _pad_batch(
        features: List[torch.Tensor],
        padding_strategy: PaddingStrategy = "longest",
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        padding_value: float = 0.0,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Pad a list of (T_i, D) features to a common (B, T, D) tensor with
        # an int64 attention mask marking real frames.
        if padding_strategy == "longest":
            target_length = max(f.shape[0] for f in features)
        elif max_length is not None:
            target_length = max_length
        else:
            raise ValueError(
                "max_length must be provided when padding_strategy is not 'longest'"
            )

        if pad_to_multiple_of is not None:
            # Round up so target_length is divisible by pad_to_multiple_of.
            target_length = (
                (target_length + pad_to_multiple_of - 1)
                // pad_to_multiple_of
                * pad_to_multiple_of
            )

        batch_size = len(features)
        feature_dim = features[0].shape[1]
        device = features[0].device

        padded_features = torch.full(
            (batch_size, target_length, feature_dim),
            padding_value,
            dtype=torch.float32,
            device=device,
        )
        attention_mask = torch.zeros(
            (batch_size, target_length),
            dtype=torch.int64,
            device=device,
        )

        for index, feature_tensor in enumerate(features):
            seq_len = feature_tensor.shape[0]
            padded_features[index, :seq_len] = feature_tensor
            attention_mask[index, :seq_len] = 1

        return padded_features, attention_mask

    input_features, attention_mask = _pad_batch(
        features,
        padding_strategy=padding,
        max_length=max_length,
        pad_to_multiple_of=pad_to_multiple_of,
        padding_value=padding_value,
    )

    # Truncate so the frame count is divisible by stride, then stack every
    # `stride` frames into a single, wider feature vector.
    batch_size, num_frames, num_channels = input_features.shape
    new_num_frames = (num_frames // stride) * stride
    input_features = input_features[:, :new_num_frames, :]
    if return_attention_mask:
        attention_mask = attention_mask[:, :new_num_frames]

    input_features = input_features.reshape(
        batch_size, new_num_frames // stride, num_channels * stride
    )

    output: Dict[str, ReturnType] = {"input_features": input_features}
    if return_attention_mask:
        # Subsample the mask to one entry per strided frame.
        output["attention_mask"] = attention_mask[:, 1::stride]

    if return_tensors == "np":
        for key, value in output.items():
            output[key] = value.cpu().numpy()  # type: ignore[assignment]

    return output
366
+
367
+
368
def serialise_flac(key: str, waveform: torch.Tensor, sample_rate: int) -> dict:
    """Encode *waveform* as 16-bit FLAC bytes in a WebDataset record.

    Returns a dict with the sample ``__key__`` and the encoded ``flac`` payload.
    """
    pcm = waveform.to(dtype=torch.float32).cpu()
    # torchaudio.save expects a (channels, samples) tensor.
    if pcm.ndim == 1:
        pcm = pcm.unsqueeze(0)
    sink = io.BytesIO()
    torchaudio.save(sink, pcm, sample_rate, format="flac", bits_per_sample=16)
    return {"__key__": key, "flac": sink.getvalue()}
375
+
376
+
377
+ def _normalise_value(value: Any) -> Any:
378
+ """Convert tensors and NumPy scalars to serialisable Python objects."""
379
+ if isinstance(value, torch.Tensor):
380
+ if value.ndim == 0:
381
+ return value.item()
382
+ return value.cpu().tolist()
383
+ if isinstance(value, np.generic):
384
+ return value.item()
385
+ if isinstance(value, np.ndarray):
386
+ return value.tolist()
387
+ return value
388
+
389
+
390
def _encode_metadata(metadata: dict[str, Any]) -> bytes:
    """Serialise *metadata* to UTF-8 JSON bytes, dropping None-valued keys."""
    cleaned: dict[str, Any] = {
        key: _normalise_value(value)
        for key, value in metadata.items()
        if value is not None
    }
    return json.dumps(cleaned, ensure_ascii=False).encode("utf-8")
397
+
398
+
399
+ # ---------------------------------------------------------------------------
400
+ # Denoising model
401
+ # ---------------------------------------------------------------------------
402
+
403
+
404
class SpeechDenoisingProcessor:
    """Run the TorchScripted feature extractor and decoder.

    Loads two ``torch.jit`` artefacts: a feature encoder that maps
    SeamlessM4T fbank features to hidden states, and a decoder that maps
    hidden states back to a denoised waveform (produced at
    SIDON_OUTPUT_SAMPLE_RATE, judging by the length computation below).
    """

    def __init__(
        self,
        feature_extractor_path: str,
        decoder_path: str,
        device: str,
    ) -> None:
        # map_location places the scripted modules directly on the target
        # device at load time; both run in eval mode only.
        self.device = torch.device(device)
        self.feature_extractor = torch.jit.load(
            feature_extractor_path, map_location=self.device
        )
        self.decoder = torch.jit.load(decoder_path, map_location=self.device)
        self.feature_extractor.eval()
        self.decoder.eval()

    @torch.inference_mode()
    def process(self, waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
        # Single-sample convenience wrapper around process_batch.
        return self.process_batch([waveform], [sample_rate])[0]

    @torch.inference_mode()
    def process_batch(
        self,
        waveforms: Sequence[torch.Tensor] | torch.Tensor,
        sample_rates: Optional[Sequence[int]] = None,
        expected_lengths: Optional[Sequence[int]] = None,
    ) -> List[torch.Tensor]:
        """Denoise a batch of waveforms.

        Args:
            waveforms: batched waveform tensor (or sequence of 1-D tensors).
                NOTE(review): the F.pad call below requires a single tensor —
                confirm callers always pass the padded batch tensor.
            sample_rates: per-sample input rates; required only when
                ``expected_lengths`` is not given.
            expected_lengths: desired output length (in samples at the output
                rate) for each item; outputs are padded/trimmed to match.

        Returns:
            One denoised 1-D waveform tensor per input sample.
        """
        if expected_lengths is None:
            # Derive target output lengths from input durations; this branch
            # requires sample_rates to be provided.
            expected_lengths: list[int] = []
            for waveform, sample_rate in zip(waveforms, sample_rates):
                duration_seconds = waveform.shape[-1] / float(sample_rate)
                expected_lengths.append(
                    int(round(duration_seconds * SIDON_OUTPUT_SAMPLE_RATE))
                )
        # Append 24 000 trailing zero samples (presumably 1 s at the input
        # rate — TODO confirm) so the model sees trailing context; the output
        # is trimmed back to expected_lengths below.
        waveforms = torch.nn.functional.pad(waveforms, (0, 24000))

        features = extract_seamless_m4t_features(
            [x for x in waveforms],
            return_tensors="pt",
            padding_value=1.0,
            device=self.device,
        )
        feature_tensor = self.feature_extractor(
            features["input_features"].to(self.device)
        )["last_hidden_state"]
        # Decoder consumes (batch, hidden, frames); results come back on CPU.
        restored_waveforms = self.decoder(feature_tensor.transpose(1, 2)).cpu()

        results: List[torch.Tensor] = []
        for sample_idx, sample in enumerate(restored_waveforms):
            restored_waveform = sample.view(-1)
            target_length = expected_lengths[sample_idx]
            current_length = restored_waveform.shape[-1]
            # Pad with zeros or trim so each output matches its expected
            # length exactly; target_length <= 0 disables the adjustment.
            if target_length > 0 and current_length != target_length:
                diff = target_length - current_length
                if diff > 0:
                    restored_waveform = torch.nn.functional.pad(
                        restored_waveform, (0, diff)
                    )
                elif diff < 0:
                    restored_waveform = restored_waveform[:target_length]
            results.append(restored_waveform.contiguous())

        return results
468
+
469
+
470
+ # ---------------------------------------------------------------------------
471
+ # Batch collation
472
+ # ---------------------------------------------------------------------------
473
+
474
+
475
class CollateFunction:
    """Collate a list of samples into a padded batch."""

    def __init__(
        self,
        sample_rate: int,
        skip_errors: bool,
    ) -> None:
        # Sample rate used to convert waveform lengths into durations.
        self.sample_rate = sample_rate
        # Retained for interface parity with the pipeline configuration.
        self.skip_errors = skip_errors

    def __call__(self, samples: Sequence[dict[str, Any]]) -> CollatedBatch:
        audios = [item["audio"] for item in samples]
        labels = [item["label"] for item in samples]
        # Stack variable-length mono waveforms into a zero-padded batch.
        padded = torch.nn.utils.rnn.pad_sequence(
            [audio.squeeze(0) for audio in audios], batch_first=True
        )
        return CollatedBatch(
            keys=[label["id"] for label in labels],
            waveforms=padded,
            durations=[audio.size(-1) / self.sample_rate for audio in audios],
            metadata=labels,
        )
502
+
503
+
504
@dataclass
class CollatedBatch:
    """Batch payload returned by the DataLoader collate function."""

    # Sample identifiers, one per item in the batch.
    keys: list[str]
    # Zero-padded waveform batch of shape (batch, max_samples). The collate
    # function assigns the output of torch.nn.utils.rnn.pad_sequence here, so
    # this is a single stacked tensor — the previous list[torch.Tensor]
    # annotation did not match the value actually stored.
    waveforms: torch.Tensor
    # Original per-sample durations in seconds (computed before padding).
    durations: list[float]
    # Raw metadata ("label") dict for each sample.
    metadata: list[dict[str, Any]]

    @property
    def size(self) -> int:
        """Number of samples in the batch."""
        return len(self.keys)
516
+
517
+
518
+ # ---------------------------------------------------------------------------
519
+ # Subprocess-based GPU worker pool
520
+ # ---------------------------------------------------------------------------
521
+ #
522
+ # Problem: PyTorch ≥2.8 caches CUDA device state at import time. Neither
523
+ # forkserver nor spawn lets us change CUDA_VISIBLE_DEVICES *before* the CUDA
524
+ # runtime captures the device list. The only reliable approach is to launch
525
+ # each worker as a **subprocess** with CUDA_VISIBLE_DEVICES set in the
526
+ # subprocess environment, guaranteeing it takes effect before `import torch`.
527
+ #
528
+ # Protocol (parent ↔ child, length-prefixed pickle over stdin/stdout):
529
+ # Parent → child: 4-byte LE uint32 length + pickle(CollatedBatch)
530
+ # Child → parent: 4-byte LE uint32 length + pickle(result dict)
531
+ # Shutdown signal: 4 zero bytes (length == 0)
532
+
533
+
534
+ def _subprocess_recv():
535
+ """Read a length-prefixed pickled object from stdin. Returns None on shutdown."""
536
+ raw = sys.stdin.buffer.read(4)
537
+ if len(raw) < 4:
538
+ return None
539
+ (length,) = struct.unpack("<I", raw)
540
+ if length == 0:
541
+ return None
542
+ data = sys.stdin.buffer.read(length)
543
+ return pickle.loads(data)
544
+
545
+
546
+ def _subprocess_send(obj):
547
+ """Send a pickled object with a 4-byte length prefix to stdout."""
548
+ data = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
549
+ sys.stdout.buffer.write(struct.pack("<I", len(data)))
550
+ sys.stdout.buffer.write(data)
551
+ sys.stdout.buffer.flush()
552
+
553
+
554
def subprocess_worker_main():
    """Entry point for a GPU worker subprocess.

    Expected environment: CUDA_VISIBLE_DEVICES already set by the parent.
    Receives initargs via stdin, then processes batches in a loop.
    Each request/response is a length-prefixed pickle (see the protocol
    comment above); a zero-length frame shuts the worker down.
    """
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] [Worker PID %(process)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    # The first stdin message carries the model paths (sent by _GPUWorker).
    initargs = _subprocess_recv()
    feature_extractor_path, decoder_path = initargs

    # With CUDA_VISIBLE_DEVICES restricted by the parent, device index 0 is
    # the single physical GPU assigned to this worker.
    device = "cpu"
    if torch.cuda.is_available():
        torch.cuda.set_device(0)
        device = "cuda:0"
    else:
        logging.warning("CUDA not available in worker subprocess.")

    logging.info(
        f"Worker PID={os.getpid()}, "
        f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}, device={device}"
    )

    processor = SpeechDenoisingProcessor(
        feature_extractor_path=feature_extractor_path,
        decoder_path=decoder_path,
        device=device,
    )

    # Process batches until shutdown signal
    while True:
        msg = _subprocess_recv()
        if msg is None:
            break
        req_id = msg["_req_id"]
        batch = msg["_batch"]
        try:
            # Expected output lengths are derived from original durations at
            # the model's output rate, so padding added in collation is cut.
            cleaned_waveforms = processor.process_batch(
                batch.waveforms,
                expected_lengths=[
                    round(d * SIDON_OUTPUT_SAMPLE_RATE) for d in batch.durations
                ],
            )
            # Move to CPU so the tensors can be pickled back to the parent.
            cleaned_cpu = [w.cpu() for w in cleaned_waveforms]
            result = {
                "_req_id": req_id,
                "status": "success",
                "keys": batch.keys,
                "results": cleaned_cpu,
                "metadata": batch.metadata,
                "size": batch.size,
            }
        except Exception as e:
            # Whole-batch failure: report it and let the parent decide
            # whether to skip (--skip_errors) or abort the run.
            result = {
                "_req_id": req_id,
                "status": "error",
                "keys": batch.keys,
                "error": str(e),
                "size": batch.size,
            }
        _subprocess_send(result)
616
+
617
+
618
class _GPUWorker:
    """Handle to a single GPU worker subprocess.

    Spawns ``python -m omnivoice.scripts.denoise_audio --_subprocess_worker``
    with CUDA_VISIBLE_DEVICES pinned in the child's environment (so it takes
    effect before the child imports torch), then exchanges length-prefixed
    pickles over the child's stdin/stdout.
    """

    def __init__(self, physical_gpu_id, feature_extractor_path, decoder_path):
        # physical_gpu_id None means no restriction (CPU fallback pool).
        env = os.environ.copy()
        if physical_gpu_id is not None:
            env["CUDA_VISIBLE_DEVICES"] = str(physical_gpu_id)
        self.proc = subprocess.Popen(
            [
                sys.executable,
                "-m",
                "omnivoice.scripts.denoise_audio",
                "--_subprocess_worker",
            ],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            env=env,
        )
        # Send init args
        init_data = pickle.dumps(
            (feature_extractor_path, decoder_path), protocol=pickle.HIGHEST_PROTOCOL
        )
        self.proc.stdin.write(struct.pack("<I", len(init_data)))
        self.proc.stdin.write(init_data)
        self.proc.stdin.flush()
        # Serialises writes to the child's stdin pipe across threads.
        self._lock = threading.Lock()

    def submit(self, batch_with_id):
        """Send a batch dict (containing _req_id + _batch) for processing."""
        with self._lock:
            data = pickle.dumps(batch_with_id, protocol=pickle.HIGHEST_PROTOCOL)
            self.proc.stdin.write(struct.pack("<I", len(data)))
            self.proc.stdin.write(data)
            self.proc.stdin.flush()

    def read_result(self):
        """Blocking read for one result.

        Returns None on EOF (child exited) or a zero-length frame.
        """
        raw = self.proc.stdout.read(4)
        if len(raw) < 4:
            return None
        (length,) = struct.unpack("<I", raw)
        if length == 0:
            return None
        data = self.proc.stdout.read(length)
        return pickle.loads(data)

    def shutdown(self):
        """Send shutdown signal and wait for process."""
        try:
            with self._lock:
                # A zero-length frame is the agreed shutdown signal.
                self.proc.stdin.write(struct.pack("<I", 0))
                self.proc.stdin.flush()
        except Exception:
            # Child may already have exited and closed its stdin; best-effort.
            pass
        self.proc.wait(timeout=30)
673
+
674
+
675
class GPUWorkerPool:
    """Pool of GPU worker subprocesses with round-robin task submission.

    One daemon reader thread per worker resolves Futures as results arrive,
    so submission order and completion order may differ.
    """

    def __init__(self, pool_specs, feature_extractor_path, decoder_path):
        """
        Args:
            pool_specs: list of (physical_gpu_id, num_workers) tuples.
            feature_extractor_path: path to JIT feature extractor.
            decoder_path: path to JIT decoder.
        """
        self.workers: list[_GPUWorker] = []
        for physical_gpu_id, num_workers in pool_specs:
            for _ in range(num_workers):
                self.workers.append(
                    _GPUWorker(physical_gpu_id, feature_extractor_path, decoder_path)
                )
        # Round-robin cursor over self.workers.
        self._rr = 0
        # In-flight requests: req_id -> Future, guarded by _futures_lock.
        self._futures: dict[int, Future] = {}
        self._futures_lock = threading.Lock()
        self._next_id = 0
        # Start reader threads for each worker
        self._reader_threads = []
        for worker in self.workers:
            t = threading.Thread(target=self._reader_loop, args=(worker,), daemon=True)
            t.start()
            self._reader_threads.append(t)

    def _reader_loop(self, worker):
        # Drain results from one worker until its stdout hits EOF.
        while True:
            result = worker.read_result()
            if result is None:
                break
            req_id = result.pop("_req_id", None)
            with self._futures_lock:
                fut = self._futures.pop(req_id, None)
                if fut is not None:
                    fut.set_result(result)

    def submit(self, batch) -> Future:
        """Dispatch *batch* to the next worker; returns a Future for its result."""
        worker = self.workers[self._rr % len(self.workers)]
        self._rr += 1
        with self._futures_lock:
            req_id = self._next_id
            self._next_id += 1
            fut = Future()
            self._futures[req_id] = fut
        batch_dict = {
            "_req_id": req_id,
            "_batch": batch,
        }
        worker.submit(batch_dict)
        return fut

    def shutdown(self):
        # Signal every worker, then give the reader threads a bounded wait;
        # they are daemons, so a stuck thread cannot block interpreter exit.
        for worker in self.workers:
            worker.shutdown()
        for t in self._reader_threads:
            t.join(timeout=5)
733
+
734
+
735
+ # ---------------------------------------------------------------------------
736
+ # Main
737
+ # ---------------------------------------------------------------------------
738
+
739
+
740
def main() -> None:
    """CLI entry point for the denoising pipeline.

    Streams audio from a WebDataset manifest or raw JSONL, denoises batches
    on a pool of GPU worker subprocesses, and writes the cleaned audio back
    out as WebDataset tar shards plus per-shard JSONL metadata and a
    ``data.lst`` manifest. With ``--_subprocess_worker`` it instead runs the
    worker loop (see subprocess_worker_main).
    """
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)
    parser = build_parser()
    args = parser.parse_args()

    # ── Subprocess worker mode ──
    if args._subprocess_worker:
        subprocess_worker_main()
        return

    # Validate input arguments
    assert args.tar_output_pattern is not None, "--tar_output_pattern is required."
    assert args.jsonl_output_pattern is not None, "--jsonl_output_pattern is required."
    assert bool(args.input_manifest) != bool(
        args.input_jsonl
    ), "Exactly one of --input_manifest or --input_jsonl must be provided."

    if args.num_machines > 1:
        assert (
            0 <= args.machine_index < args.num_machines
        ), f"machine_index {args.machine_index} must be in [0, {args.num_machines})"

    # ── Build base dataset and count total samples ──
    if args.input_jsonl:
        logging.info(f"Input mode: raw JSONL ({args.input_jsonl})")
        total_samples = count_lines(args.input_jsonl)
        base_dataset = JsonlDatasetReader(
            args.input_jsonl,
            sample_rate=SIDON_INPUT_SAMPLE_RATE,
            shuffle=args.shuffle,
            shuffle_seed=args.shuffle_seed,
        )
        loader_workers = args.loader_workers
    else:
        logging.info(f"Input mode: WebDataset manifest ({args.input_manifest})")
        manifest_num_lines = count_lines(args.input_manifest)
        loader_workers = min(args.loader_workers, manifest_num_lines)
        total_samples = 0
        manifests = []
        with open(args.input_manifest, "r", encoding="utf-8") as f:
            # Manifest line format: <tar_path> <jsonl_path> <count> <duration>
            for line_id, line in tqdm(
                enumerate(f),
                total=manifest_num_lines,
                desc="Calculating dataset length",
            ):
                items = line.strip().split(" ")
                tar_path, jsonl_path, num_items, duration = (
                    items[0],
                    items[1],
                    int(items[2]),
                    float(items[3]),
                )
                assert os.path.exists(tar_path), f"File {tar_path} does not exist."
                assert os.path.exists(jsonl_path), f"File {jsonl_path} does not exist."
                assert jsonl_path.endswith(
                    ".jsonl"
                ), f"File {jsonl_path} is not a .jsonl file."
                # Shard-level round-robin split across machines.
                if (
                    args.num_machines > 1
                    and line_id % args.num_machines != args.machine_index
                ):
                    continue
                total_samples += num_items
                manifests.append((tar_path, jsonl_path, num_items, duration))
        logging.info(
            f"Total shards: {manifest_num_lines}, "
            f"Shards for current index: {len(manifests)}"
        )
        base_dataset = WebDatasetReader(
            manifests=manifests,
            sample_rate=SIDON_INPUT_SAMPLE_RATE,
            evaluation=True,
        )

    # ── Dynamic batching + DataLoader ──
    batched_dataset = StreamLengthGroupDataset(
        dataset=base_dataset,
        batch_duration=args.batch_duration,
        max_sample=args.max_sample,
        min_length=args.min_length,
        max_length=args.max_length,
    )

    collate_fn = CollateFunction(
        skip_errors=args.skip_errors,
        sample_rate=SIDON_INPUT_SAMPLE_RATE,
    )

    # batch_size=None: the dataset already yields pre-batched lists.
    dataloader = DataLoader(
        dataset=batched_dataset,
        batch_size=None,
        collate_fn=collate_fn,
        num_workers=loader_workers,
        prefetch_factor=10 if loader_workers > 0 else None,
        pin_memory=True,
        persistent_workers=loader_workers > 0,
    )

    # ── Multi-GPU process pool ──
    num_devices = torch.cuda.device_count()
    if num_devices == 0:
        logging.warning("No GPUs detected - using CPU for processing")
        num_processes = args.nj_per_gpu
    else:
        num_processes = num_devices * args.nj_per_gpu
    logging.info(
        f"GPU count: {num_devices}, Processes per GPU: {args.nj_per_gpu}, "
        f"Total processes: {num_processes}"
    )

    # Build a list of (physical_gpu_id, num_workers) for each pool.
    # When num_devices == 0 we use a single CPU pool.
    if num_devices == 0:
        pool_specs = [(None, num_processes)]
    else:
        pool_specs = [(gpu_id, args.nj_per_gpu) for gpu_id in range(num_devices)]

    # ── Output paths ──
    tar_output_pattern = str(Path(args.tar_output_pattern).expanduser())
    jsonl_output_pattern = str(Path(args.jsonl_output_pattern).expanduser())
    Path(tar_output_pattern).parent.mkdir(parents=True, exist_ok=True)
    Path(jsonl_output_pattern).parent.mkdir(parents=True, exist_ok=True)

    # errors.jsonl and data.lst live two levels above the shard patterns
    # (alongside the audios/ and txts/ directories).
    output_dir = Path(tar_output_pattern).parent.parent
    error_log_path = str(output_dir / "errors.jsonl")
    manifest_path = str(output_dir / "data.lst")

    # Dedicated logger so per-sample failures go to errors.jsonl as raw JSON.
    error_logger = logging.getLogger("error_log")
    error_logger.setLevel(logging.ERROR)
    error_logger.handlers.clear()
    error_fh = logging.FileHandler(error_log_path, mode="w", encoding="utf-8")
    error_fh.setFormatter(logging.Formatter("%(message)s"))
    error_logger.addHandler(error_fh)

    # ── Progress and shard tracking ──
    processed_count = 0
    error_count = 0
    write_error_count = 0
    failed_ids = []
    shard_idx = 0
    shard_sample_count = 0
    shard_duration = 0.0
    samples_per_shard = args.samples_per_shard
    # shard_manifest: shard index -> (tar_path, jsonl_path, count, duration)
    shard_manifest = {}
    target_sample_rate = args.target_sample_rate

    tar_writer = None
    jsonl_file = None

    def open_new_shard():
        # Close the current shard (if any), record it in shard_manifest,
        # then open the next tar + jsonl pair.
        nonlocal tar_writer, jsonl_file, shard_idx, shard_sample_count, shard_duration
        if tar_writer is not None:
            tar_writer.close()
        if jsonl_file is not None:
            jsonl_file.close()
        if shard_idx > 0 and shard_sample_count > 0:
            prev_idx = shard_idx - 1
            shard_manifest[prev_idx] = (
                os.path.abspath(tar_output_pattern % prev_idx),
                os.path.abspath(jsonl_output_pattern % prev_idx),
                shard_sample_count,
                shard_duration,
            )
        tar_fname = tar_output_pattern % shard_idx
        jsonl_fname = jsonl_output_pattern % shard_idx
        tar_writer = wds.TarWriter(tar_fname)
        jsonl_file = open(jsonl_fname, "w", encoding="utf-8")
        shard_idx += 1
        shard_sample_count = 0
        shard_duration = 0.0

    def write_sample(key, waveform, metadata):
        # Resample (if requested), peak-normalise, and append one cleaned
        # sample to the open shard; failures are logged, not raised.
        nonlocal shard_sample_count, write_error_count, shard_duration
        assert tar_writer is not None and jsonl_file is not None
        try:
            if target_sample_rate != SIDON_OUTPUT_SAMPLE_RATE:
                waveform = torchaudio.functional.resample(
                    waveform,
                    orig_freq=SIDON_OUTPUT_SAMPLE_RATE,
                    new_freq=target_sample_rate,
                )
            # Peak-normalise to 0.6 full scale; epsilon guards silent audio.
            waveform = (waveform / (waveform.abs().max() + 1e-7)) * 0.6

            record = serialise_flac(key, waveform, target_sample_rate)
            jsonl_record = _encode_metadata(metadata)
            tar_writer.write(record)
            jsonl_file.write(jsonl_record.decode("utf-8") + "\n")
            shard_sample_count += 1
            shard_duration += metadata.get("audio_duration", 0.0)
        except Exception as exc:
            write_error_count += 1
            failed_ids.append(key)
            error_logger.error(
                json.dumps({"id": key, "reason": str(exc)}, ensure_ascii=False)
            )
            logging.error(f"Write failed for sample {key}: {exc}")

    def handle_result(result):
        # Route one worker result: write successes sample-by-sample (rolling
        # shards), log failures, and abort unless --skip_errors is set.
        nonlocal processed_count, error_count
        if result["status"] == "success":
            for key, cleaned, metadata in zip(
                result["keys"], result["results"], result["metadata"]
            ):
                if tar_writer is None or shard_sample_count >= samples_per_shard:
                    open_new_shard()
                write_sample(key, cleaned, metadata)
                processed_count += 1
        else:
            error_count += result["size"]
            failed_ids.extend(result["keys"])
            for key in result["keys"]:
                error_logger.error(
                    json.dumps(
                        {"id": key, "reason": result["error"]},
                        ensure_ascii=False,
                    )
                )
            if not args.skip_errors:
                raise RuntimeError(
                    f"Batch starting with {result['keys'][0]} failed - terminating"
                )
            logging.warning(
                f"Skipping failed batch starting with {result['keys'][0]}: "
                f"{result['error']}"
            )

    # ── Main processing loop ──
    main_progress = tqdm(total=total_samples, desc="Denoising Audio")

    # Launch subprocess-based GPU workers. CUDA_VISIBLE_DEVICES is set in the
    # subprocess Popen environment so it takes effect before import torch.
    pool = GPUWorkerPool(pool_specs, args.feature_extractor_path, args.decoder_path)
    logging.info(f"Submitting tasks... ({num_processes} subprocess workers)")
    try:
        futures = set()
        # Backpressure: keep at most 2 batches in flight per worker.
        max_pending = num_processes * 2

        def drain_completed():
            # Block for at least one completed future, write its results,
            # and refresh the progress bar.
            nonlocal futures
            done, _ = wait(futures, return_when=FIRST_COMPLETED)
            for f in done:
                futures.discard(f)
                result = f.result()
                main_progress.update(result["size"])
                handle_result(result)
            main_progress.set_postfix(
                OK=processed_count,
                Err=error_count,
            )

        for batch in dataloader:
            if batch.size == 0:
                continue
            if len(futures) >= max_pending:
                drain_completed()
            futures.add(pool.submit(batch))

        logging.info("Processing remaining pending batches...")
        while futures:
            drain_completed()

    except Exception:
        logging.error("Critical error during processing", exc_info=True)
        raise
    finally:
        # Always stop workers and flush/record the final open shard.
        pool.shutdown()
        main_progress.close()
        if tar_writer is not None:
            tar_writer.close()
        if jsonl_file is not None:
            jsonl_file.close()
        if shard_idx > 0 and shard_sample_count > 0:
            last_idx = shard_idx - 1
            shard_manifest[last_idx] = (
                os.path.abspath(tar_output_pattern % last_idx),
                os.path.abspath(jsonl_output_pattern % last_idx),
                shard_sample_count,
                shard_duration,
            )

    # ── Write manifest (data.lst) ──
    with open(manifest_path, "w", encoding="utf-8") as mf:
        for idx in sorted(shard_manifest.keys()):
            tar_path, jsonl_path, count, duration = shard_manifest[idx]
            mf.write(f"{tar_path} {jsonl_path} {count} {duration:.3f}\n")

    # ── Summary ──
    total_failed = error_count + write_error_count
    filtered_and_skipped = total_samples - processed_count - total_failed
    logging.info(
        f"Processing Complete - Successful: {processed_count}, Failed: {total_failed}, "
        f"Filtered/Skipped: {filtered_and_skipped}, Shards written: {shard_idx}"
    )
    logging.info(f"Manifest written to: {manifest_path} ({len(shard_manifest)} shards)")
    if total_failed > 0:
        logging.info(f"Error details: {error_log_path}")
    if failed_ids and args.skip_errors:
        logging.warning(
            f"Failed sample IDs (count: {len(failed_ids)}): {failed_ids[:100]}..."
        )
    if write_error_count > 0 and not args.skip_errors:
        raise RuntimeError(
            f"{write_error_count} samples failed to write - check logs for details"
        )
1045
+
1046
+
1047
+ if __name__ == "__main__":
1048
+ main()
omnivoice/scripts/extract_audio_tokens.py ADDED
@@ -0,0 +1,625 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ Extract audio tokens from audio data and pack them into WebDataset shards.
20
+
21
+ Supports two input modes:
22
+
23
+ 1. WebDataset manifest (data.lst):
24
+ python extract_audio_tokens.py \
25
+ --input_manifest data.lst \
26
+ --tar_output_pattern output/audios/shard-%06d.tar \
27
+ --jsonl_output_pattern output/txts/shard-%06d.jsonl
28
+
29
+ 2. Raw JSONL (each line: {"id": "...", "audio_path": "...", "text": "...", ...}):
30
+ python extract_audio_tokens.py \
31
+ --input_jsonl data.jsonl \
32
+ --tar_output_pattern output/audios/shard-%06d.tar \
33
+ --jsonl_output_pattern output/txts/shard-%06d.jsonl
34
+
35
+ Output structure:
36
+ output_dir/
37
+ ├── audios/ # WebDataset tar shards (.npy audio tokens + .json metadata)
38
+ │ ├── shard_000000.tar
39
+ │ └── ...
40
+ ├── txts/ # Per-shard JSONL metadata
41
+ │ ├── shard_000000.jsonl
42
+ │ └── ...
43
+ ├── data.lst # Manifest: <tar_path> <jsonl_path> <sample_count> <total_duration>
44
+ └── errors.jsonl # Failed samples with error details
45
+ """
46
+
47
+ import argparse
48
+ import io
49
+ import json
50
+ import logging
51
+ import multiprocessing as mp
52
+ import os
53
+ import warnings
54
+ from concurrent.futures import FIRST_COMPLETED, ProcessPoolExecutor, wait
55
+ from pathlib import Path
56
+ from typing import Any
57
+
58
+ import numpy as np
59
+ import torch
60
+ import webdataset as wds
61
+ from torch.utils.data import DataLoader, IterableDataset
62
+ from tqdm.auto import tqdm
63
+ from transformers import AutoFeatureExtractor, HiggsAudioV2TokenizerModel
64
+
65
+ from omnivoice.data.dataset import JsonlDatasetReader, WebDatasetReader
66
+ from omnivoice.utils.common import str2bool
67
+
68
+ warnings.filterwarnings(
69
+ "ignore", category=FutureWarning, module="torch.nn.utils.weight_norm"
70
+ )
71
+
72
+ HIGGS_INPUT_SAMPLE_RATE = 24_000
73
+
74
+
75
+ # Global variables: Store tokenizer and device for each worker process
76
+ worker_tokenizer = None
77
+ worker_feature_extractor = None
78
+
79
+
80
def build_parser() -> argparse.ArgumentParser:
    """Build the CLI argument parser for audio-token extraction.

    Exactly one of --input_manifest / --input_jsonl is expected by the
    caller; the module docstring is reused as the --help description.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--input_manifest",
        default=None,
        help="Path to input dataset manifest (data.lst).",
    )
    parser.add_argument(
        "--input_jsonl",
        default=None,
        help="Path to raw JSONL file (alternative to --input_manifest).",
    )
    parser.add_argument(
        "--tar_output_pattern",
        required=True,
        help="Tar shard pattern passed to WebDataset",
    )
    parser.add_argument(
        "--jsonl_output_pattern",
        required=True,
        help="Jsonl shard pattern passed to WebDataset",
    )
    parser.add_argument(
        "--samples_per_shard",
        type=int,
        default=1000,
        help="Maximum records per shard",
    )
    parser.add_argument(
        "--min_num_shards",
        type=int,
        default=32,
        help="Minimum number of output shards (use to ensure "
        "shard count >= num_gpu * num_workers)",
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default="eustlb/higgs-audio-v2-tokenizer",
        help="Path to audio tokenizer.",
    )
    parser.add_argument(
        "--skip_errors", action="store_true", help="Skip items that fail to process"
    )
    parser.add_argument(
        "--min_length",
        type=float,
        default=0.0,
        help="Minimum audio duration in seconds (e.g. 2.0)",
    )
    parser.add_argument(
        "--max_length",
        type=float,
        default=float("inf"),
        help="Maximum audio duration in seconds (e.g. 15.0)",
    )
    parser.add_argument(
        "--num_machines",
        type=int,
        default=1,
        help="Total number of machines for distributed runs",
    )
    parser.add_argument(
        "--machine_index",
        type=int,
        default=0,
        help="Zero-based machine index when distributing across multiple "
        "machines (e.g. 0, 1, ... num_machines-1)",
    )
    parser.add_argument(
        "--nj_per_gpu",
        type=int,
        default=3,
        help="Number of worker processes to spawn per GPU.",
    )
    parser.add_argument(
        "--loader_workers",
        type=int,
        default=24,
        help="Number of DataLoader workers for streaming IterableDataset.",
    )
    parser.add_argument(
        "--shuffle",
        type=str2bool,
        default=True,
        help="Shuffle data by default.",
    )
    parser.add_argument(
        "--shuffle-seed",
        type=int,
        default=42,
        help="Random seed for shuffle (default: 42).",
    )
    return parser
174
+
175
+
176
def count_lines(path):
    """Count the number of lines in the file at *path*.

    Reads in 1 MiB binary chunks so very large manifests are counted
    without loading them into memory. Unlike a plain newline count, a
    final line that lacks a trailing newline is still counted (the old
    implementation silently dropped it, skewing totals used for the
    progress bar and shard estimation).
    """
    total = 0
    last_chunk = b""
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            total += chunk.count(b"\n")
            last_chunk = chunk
    # Count a trailing, unterminated line that has no "\n" of its own.
    if last_chunk and not last_chunk.endswith(b"\n"):
        total += 1
    return total
179
+
180
+
181
def serialise_numpy(key: str, tokens: np.ndarray) -> dict:
    """Pack *tokens* into a WebDataset record with an ``.npy``-encoded payload."""
    with io.BytesIO() as sink:
        np.save(sink, tokens)
        payload = sink.getvalue()
    return {"__key__": key, "npy": payload}
185
+
186
+
187
def process_init(rank_queue, tokenizer_path):
    """
    Initialization function for each worker process.
    Assigns a specific GPU to the process and loads the tokenizer.

    Args:
        rank_queue: Shared queue of GPU indices; a worker blocks until it
            receives one. A rank of -1 means "use CPU".
        tokenizer_path: HuggingFace model path/name of the audio tokenizer.

    Side effects:
        Populates the module-level ``worker_tokenizer`` and
        ``worker_feature_extractor`` globals used by ``process_single_sample``.
    """
    global worker_tokenizer, worker_feature_extractor

    # Configure worker process logging
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d]"
        " [Worker %(process)d] %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    # Get assigned GPU rank
    rank = rank_queue.get()
    # Determine device
    if rank != -1 and torch.cuda.is_available():
        worker_device = torch.device(f"cuda:{rank}")
    else:
        worker_device = torch.device("cpu")

    logging.debug(f"Worker process initialized with device: {worker_device}")
    # Load tokenizer onto the specified device
    worker_feature_extractor = AutoFeatureExtractor.from_pretrained(tokenizer_path)
    worker_tokenizer = HiggsAudioV2TokenizerModel.from_pretrained(
        tokenizer_path, device_map=worker_device
    )
    logging.debug(f"Tokenizer loaded successfully on device {worker_device}")
216
+
217
+
218
def process_single_sample(sample: dict[str, Any]) -> dict[str, Any]:
    """
    Single-sample processing function executed in worker processes.
    Skips invalid samples during streaming processing.

    Args:
        sample: Dict with an ``audio`` waveform tensor of shape (1, T) and a
            ``label`` dict that must carry an ``id`` field.

    Returns:
        A result dict with ``status`` ("success"/"error"), the sample ``key``,
        the extracted ``audio_tokens`` (int16 ndarray on success, else None),
        the updated ``metadata`` (``label`` with ``num_tokens`` added), and
        ``error_msg``.

    Requires the module-level ``worker_feature_extractor`` /
    ``worker_tokenizer`` globals initialized by ``process_init``.
    """
    try:
        audio_tensor = sample.get("audio", None)  # shape (1, T)
        if audio_tensor is None:
            raise ValueError("Sample missing 'audio' field")

        with torch.inference_mode():
            key = sample["label"]["id"]
            inputs = worker_feature_extractor(
                raw_audio=audio_tensor.squeeze(0).numpy(),
                sampling_rate=HIGGS_INPUT_SAMPLE_RATE,
                return_tensors="pt",
            ).to(worker_tokenizer.device)
            audio_tokens = worker_tokenizer.encode(
                inputs["input_values"],
            ).audio_codes.squeeze(0)

        # The tokenizer is expected to emit 8 codebooks of equal length.
        assert len(audio_tokens.shape) == 2
        assert audio_tokens.size(0) == 8

        num_tokens = audio_tokens.size(1)
        metadata = sample["label"]
        metadata["num_tokens"] = num_tokens

        # Convert to numpy format for subsequent serialization (int16 to save space)
        audio_tokens_np = audio_tokens.to(torch.int16).cpu().numpy()

        return {
            "status": "success",
            "key": key,
            "audio_tokens": audio_tokens_np,
            "metadata": metadata,
            "error_msg": None,
        }
    except Exception as e:
        sample_id = sample.get("label", {}).get("id", "unknown")
        logging.error(f"Failed to process sample {sample_id}: {e}")
        return {
            "status": "error",
            "key": sample_id,
            "audio_tokens": None,
            "metadata": None,
            "error_msg": str(e),
        }
266
+
267
+
268
+ def _normalise_value(value: Any) -> Any:
269
+ """Convert tensors and NumPy scalars to serialisable Python objects."""
270
+ if isinstance(value, torch.Tensor):
271
+ if value.ndim == 0:
272
+ return value.item()
273
+ return value.cpu().tolist()
274
+ if isinstance(value, np.generic):
275
+ return value.item()
276
+ if isinstance(value, np.ndarray):
277
+ return value.tolist()
278
+ return value
279
+
280
+
281
def _encode_metadata(metadata: dict[str, Any]) -> bytes:
    """Serialise *metadata* to UTF-8 JSON bytes, dropping None-valued entries."""
    cleaned: dict[str, Any] = {
        key: _normalise_value(value)
        for key, value in metadata.items()
        if value is not None
    }
    return json.dumps(cleaned, ensure_ascii=False).encode("utf-8")
288
+
289
+
290
class StreamingLengthFilteredDataset(IterableDataset):
    """Streaming wrapper that drops samples whose duration is out of range.

    Durations are computed as ``audio_samples / sr``; samples that raise
    during inspection are skipped with a warning instead of aborting the
    stream.
    """

    def __init__(
        self,
        base_iterable,
        min_len: float,
        max_len: float,
        sr: int,
    ):
        self.base_iterable = base_iterable
        self.min_len = min_len
        self.max_len = max_len
        self.sr = sr
        # NOTE: with multiple DataLoader workers this counter is per-worker.
        self.filtered_count = 0

    def __iter__(self):
        """Yield only samples whose duration lies within [min_len, max_len]."""
        for item in self.base_iterable:
            try:
                seconds = item["audio"].size(-1) / self.sr
                if not (self.min_len <= seconds <= self.max_len):
                    self.filtered_count += 1
                    logging.warning(
                        f"Filtered sample (duration out of range): "
                        f"{item['label']['id']} ({seconds:.2f}s)"
                    )
                    continue
                yield item
            except Exception as e:
                logging.warning(f"Skipped invalid sample during streaming: {e}")
                continue
320
+
321
+
322
def main() -> None:
    """Entry point: stream audio, tokenize it in a GPU worker pool, write shards.

    Pipeline:
      1. Parse args and build a streaming dataset from either a raw JSONL
         file or a WebDataset manifest (optionally split across machines).
      2. Filter samples by duration and feed them through a DataLoader.
      3. Submit each sample to a ProcessPoolExecutor whose workers each own
         one GPU (assigned via a shared rank queue) and run the tokenizer.
      4. Write results into rotating tar/jsonl shards, then emit a data.lst
         manifest and an errors.jsonl log.
    """
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)
    parser = build_parser()
    args = parser.parse_args()
    # Spawn (not fork) so CUDA can be initialized safely in workers.
    mp.set_start_method("spawn", force=True)

    # Validate input arguments
    assert bool(args.input_manifest) != bool(
        args.input_jsonl
    ), "Exactly one of --input_manifest or --input_jsonl must be provided."

    if args.num_machines > 1:
        assert (
            0 <= args.machine_index < args.num_machines
        ), f"machine_index {args.machine_index} must be in [0, {args.num_machines})"

    # Build base dataset and count total samples based on input mode
    if args.input_jsonl:
        logging.info(f"Input mode: raw JSONL ({args.input_jsonl})")
        total_samples = count_lines(args.input_jsonl)
        base_dataset = JsonlDatasetReader(
            args.input_jsonl,
            sample_rate=HIGGS_INPUT_SAMPLE_RATE,
            shuffle=args.shuffle,
            shuffle_seed=args.shuffle_seed,
        )
        loader_workers = args.loader_workers
    else:
        logging.info(f"Input mode: WebDataset manifest ({args.input_manifest})")
        manifest_num_lines = count_lines(args.input_manifest)
        # One loader worker per shard at most.
        loader_workers = min(args.loader_workers, manifest_num_lines)
        total_samples = 0
        manifests = []
        with open(args.input_manifest, "r", encoding="utf-8") as f:
            for line_id, line in tqdm(
                enumerate(f),
                total=manifest_num_lines,
                desc="Calculating dataset length",
            ):
                # Manifest line format: <tar> <jsonl> <num_items> <duration>
                items = line.strip().split(" ")
                tar_path, jsonl_path, num_items, duration = (
                    items[0],
                    items[1],
                    int(items[2]),
                    float(items[3]),
                )
                assert os.path.exists(tar_path), f"File {tar_path} does not exist."
                assert os.path.exists(jsonl_path), f"File {jsonl_path} does not exist."
                assert jsonl_path.endswith(
                    ".jsonl"
                ), f"File {jsonl_path} is not a .jsonl file."
                # Round-robin shard assignment across machines.
                if (
                    args.num_machines > 1
                    and line_id % args.num_machines != args.machine_index
                ):
                    continue
                total_samples += num_items
                manifests.append((tar_path, jsonl_path, num_items, duration))
        logging.info(
            f"Total shards: {manifest_num_lines}, "
            f"Shards for current index: {len(manifests)}"
        )
        base_dataset = WebDatasetReader(
            manifests=manifests,
            sample_rate=HIGGS_INPUT_SAMPLE_RATE,
            evaluation=True,
        )

    # Adjust samples_per_shard if min_num_shards would be violated
    samples_per_shard = args.samples_per_shard
    if total_samples > 0:
        estimated_shards = max(
            1, (total_samples + samples_per_shard - 1) // samples_per_shard
        )
        if estimated_shards < args.min_num_shards:
            samples_per_shard = max(1, total_samples // args.min_num_shards)
            logging.info(
                f"Adjusted samples_per_shard from {args.samples_per_shard} to "
                f"{samples_per_shard} to meet min_num_shards={args.min_num_shards} "
                f"(total_samples={total_samples})"
            )

    # Apply length filter and create DataLoader
    filtered_dataset = StreamingLengthFilteredDataset(
        base_iterable=base_dataset,
        min_len=args.min_length,
        max_len=args.max_length,
        sr=HIGGS_INPUT_SAMPLE_RATE,
    )
    # batch_size=None: pass individual samples straight through.
    dataloader = DataLoader(
        dataset=filtered_dataset,
        batch_size=None,
        num_workers=loader_workers,
        persistent_workers=loader_workers > 0,
        pin_memory=False,
    )

    # Configure multi-GPU multi-process setup
    num_devices = torch.cuda.device_count()
    if num_devices == 0:
        logging.warning("No GPUs detected - using CPU for processing")
        num_processes = args.nj_per_gpu
    else:
        num_processes = num_devices * args.nj_per_gpu
    logging.info(
        f"GPU count: {num_devices}, Processes per GPU: {args.nj_per_gpu}, "
        f"Total processes: {num_processes}"
    )

    # Shared GPU rank queue for process assignment
    manager = mp.Manager()
    rank_queue = manager.Queue()
    for rank in list(range(num_devices)) * args.nj_per_gpu:
        rank_queue.put(rank)
    if num_devices == 0:
        # -1 tells process_init to fall back to CPU.
        for _ in range(num_processes):
            rank_queue.put(-1)

    # Prepare output paths
    tar_output_pattern = str(Path(args.tar_output_pattern).expanduser())
    jsonl_output_pattern = str(Path(args.jsonl_output_pattern).expanduser())
    Path(tar_output_pattern).parent.mkdir(parents=True, exist_ok=True)
    Path(jsonl_output_pattern).parent.mkdir(parents=True, exist_ok=True)

    # Determine output directory from tar_output_pattern
    output_dir = Path(tar_output_pattern).parent.parent
    error_log_path = str(output_dir / "errors.jsonl")
    manifest_path = str(output_dir / "data.lst")

    # Setup error logger (writes to errors.jsonl)
    error_logger = logging.getLogger("error_log")
    error_logger.setLevel(logging.ERROR)
    error_logger.handlers.clear()
    error_fh = logging.FileHandler(error_log_path, mode="w", encoding="utf-8")
    error_fh.setFormatter(logging.Formatter("%(message)s"))
    error_logger.addHandler(error_fh)

    # Progress and error tracking
    processed_count = 0
    error_count = 0
    write_error_count = 0
    failed_ids = []
    shard_idx = 0
    shard_sample_count = 0
    shard_duration = 0.0
    shard_manifest = {}  # shard_idx -> (tar_path, jsonl_path, count, duration)

    tar_writer = None
    jsonl_file = None

    def open_new_shard():
        # Close the current shard (if any), record its manifest entry, and
        # open the next tar/jsonl pair.
        nonlocal tar_writer, jsonl_file, shard_idx, shard_sample_count, shard_duration
        if tar_writer is not None:
            tar_writer.close()
        if jsonl_file is not None:
            jsonl_file.close()
        # Record manifest for the previous shard
        if shard_idx > 0 and shard_sample_count > 0:
            prev_idx = shard_idx - 1
            shard_manifest[prev_idx] = (
                os.path.abspath(tar_output_pattern % prev_idx),
                os.path.abspath(jsonl_output_pattern % prev_idx),
                shard_sample_count,
                shard_duration,
            )
        tar_fname = tar_output_pattern % shard_idx
        jsonl_fname = jsonl_output_pattern % shard_idx
        tar_writer = wds.TarWriter(tar_fname)
        jsonl_file = open(jsonl_fname, "w", encoding="utf-8")
        shard_idx += 1
        shard_sample_count = 0
        shard_duration = 0.0

    def write_sample(key, audio_tokens_np, metadata):
        # Append one record to the current shard; write failures are logged
        # to errors.jsonl and counted, not raised here.
        nonlocal shard_sample_count, write_error_count, shard_duration
        assert tar_writer is not None and jsonl_file is not None
        try:
            token_record = serialise_numpy(key, audio_tokens_np)
            json_record = _encode_metadata(metadata)
            tar_writer.write(token_record)
            jsonl_file.write(json_record.decode("utf-8") + "\n")
            shard_sample_count += 1
            shard_duration += metadata.get("audio_duration", 0.0)
        except Exception as exc:
            write_error_count += 1
            failed_ids.append(key)
            error_logger.error(
                json.dumps({"id": key, "reason": str(exc)}, ensure_ascii=False)
            )
            logging.error(f"Write failed for sample {key}: {exc}")

    def handle_result(result):
        # Route a worker result to the shard writer or the error log.
        nonlocal processed_count, error_count
        if result["status"] == "success":
            # Rotate shard if needed
            if tar_writer is None or shard_sample_count >= samples_per_shard:
                open_new_shard()
            write_sample(result["key"], result["audio_tokens"], result["metadata"])
            processed_count += 1
        else:
            error_count += 1
            failed_ids.append(result["key"])
            error_logger.error(
                json.dumps(
                    {"id": result["key"], "reason": result["error_msg"]},
                    ensure_ascii=False,
                )
            )
            if not args.skip_errors:
                raise RuntimeError(
                    f"Sample {result['key']} processing failed due "
                    f"to {result['error_msg']} - terminating"
                )
            logging.warning(
                f"Skipping failed sample {result['key']}: {result['error_msg']}"
            )

    main_progress = tqdm(total=total_samples, desc="Extracting Audio Tokens")

    try:
        with ProcessPoolExecutor(
            max_workers=num_processes,
            initializer=process_init,
            initargs=(rank_queue, args.tokenizer_path),
        ) as executor:
            logging.info(f"Submitting tasks... ({num_processes} workers)")
            futures = set()
            # Back-pressure: cap in-flight futures to bound memory use.
            max_pending = num_processes * 10

            def drain_completed():
                """Wait for at least one future to complete, process all done."""
                nonlocal futures
                done, _ = wait(futures, return_when=FIRST_COMPLETED)
                for f in done:
                    futures.discard(f)
                    result = f.result()
                    main_progress.update(1)
                    handle_result(result)
                main_progress.set_postfix(
                    Samples=processed_count,
                    Errors=error_count,
                )

            # Stream samples from DataLoader
            for sample in dataloader:
                if len(futures) >= max_pending:
                    drain_completed()

                future = executor.submit(process_single_sample, sample)
                futures.add(future)

            # Process remaining futures
            logging.info("Processing remaining pending samples...")
            while futures:
                drain_completed()

    except Exception:
        logging.error("Critical error during processing", exc_info=True)
        raise
    finally:
        main_progress.close()
        if tar_writer is not None:
            tar_writer.close()
        if jsonl_file is not None:
            jsonl_file.close()
        # Record the last shard in the manifest
        if shard_idx > 0 and shard_sample_count > 0:
            last_idx = shard_idx - 1
            shard_manifest[last_idx] = (
                os.path.abspath(tar_output_pattern % last_idx),
                os.path.abspath(jsonl_output_pattern % last_idx),
                shard_sample_count,
                shard_duration,
            )

    # Write manifest file (data.lst)
    with open(manifest_path, "w", encoding="utf-8") as mf:
        for idx in sorted(shard_manifest.keys()):
            tar_path, jsonl_path, count, duration = shard_manifest[idx]
            mf.write(f"{tar_path} {jsonl_path} {count} {duration:.3f}\n")

    # Output final statistics
    total_failed = error_count + write_error_count
    filtered_and_skipped = total_samples - processed_count - total_failed
    logging.info(
        f"Processing Complete - Successful: {processed_count}, Failed: {total_failed}, "
        f"Filtered/Skipped: {filtered_and_skipped}, Shards written: {shard_idx}"
    )
    logging.info(f"Manifest written to: {manifest_path} ({len(shard_manifest)} shards)")
    if total_failed > 0:
        logging.info(f"Error details: {error_log_path}")
    if failed_ids and args.skip_errors:
        logging.warning(
            f"Failed sample IDs (count: {len(failed_ids)}): {failed_ids[:100]}..."
        )
    if write_error_count > 0 and not args.skip_errors:
        raise RuntimeError(
            f"{write_error_count} samples failed to write - check logs for details"
        )
621
+ )
622
+
623
+
624
# Script entry point: run token extraction when executed directly.
if __name__ == "__main__":
    main()
omnivoice/scripts/extract_audio_tokens_add_noise.py ADDED
@@ -0,0 +1,825 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ Extract audio tokens from audio data and pack them into WebDataset shards.
20
+
21
+ Extends ``extract_audio_tokens.py`` with optional noise and reverberation
22
+ augmentation on the prompt (reference) portion of the audio. Requires a
23
+ noise manifest and/or RIR manifest.
24
+
25
+ Supports two input modes:
26
+
27
+ 1. WebDataset manifest (data.lst):
28
+ python extract_audio_tokens_add_noise.py \\
29
+ --input_manifest data.lst \\
30
+ --noise_manifest noise.lst \\
31
+ --tar_output_pattern output/audios/shard-%06d.tar \\
32
+ --jsonl_output_pattern output/txts/shard-%06d.jsonl
33
+
34
+ 2. Raw JSONL (each line: {"id": "...", "audio_path": "...", "text": "...", ...}):
35
+ python extract_audio_tokens_add_noise.py \\
36
+ --input_jsonl data.jsonl \\
37
+ --noise_manifest noise.lst \\
38
+ --tar_output_pattern output/audios/shard-%06d.tar \\
39
+ --jsonl_output_pattern output/txts/shard-%06d.jsonl
40
+
41
+ Output structure:
42
+ output_dir/
43
+ ├── audios/ # WebDataset tar shards (.npy audio tokens + .json metadata)
44
+ │ ├── shard_000000.tar
45
+ │ └── ...
46
+ ├── txts/ # Per-shard JSONL metadata
47
+ │ ├── shard_000000.jsonl
48
+ │ └── ...
49
+ ├── data.lst # Manifest: <tar_path> <jsonl_path> <sample_count> <total_duration>
50
+ └── errors.jsonl # Failed samples with error details
51
+ """
52
+
53
+ import argparse
54
+ import io
55
+ import json
56
+ import logging
57
+ import math
58
+ import multiprocessing as mp
59
+ import os
60
+ import random
61
+ import warnings
62
+ from concurrent.futures import FIRST_COMPLETED, ProcessPoolExecutor, wait
63
+ from pathlib import Path
64
+ from typing import Any
65
+
66
+ import numpy as np
67
+ import torch
68
+ import torch.nn.functional as F
69
+ import torchaudio
70
+ import webdataset as wds
71
+ from torch.utils.data import DataLoader, IterableDataset
72
+ from tqdm.auto import tqdm
73
+ from transformers import AutoFeatureExtractor, HiggsAudioV2TokenizerModel
74
+
75
+ from omnivoice.data.dataset import JsonlDatasetReader, WebDatasetReader
76
+ from omnivoice.utils.common import str2bool
77
+
78
+ warnings.filterwarnings(
79
+ "ignore", category=FutureWarning, module="torch.nn.utils.weight_norm"
80
+ )
81
+
82
+ HIGGS_INPUT_SAMPLE_RATE = 24_000
83
+
84
+ # Global variables: Store tokenizer and device for each worker process
85
+ worker_tokenizer = None
86
+ worker_feature_extractor = None
87
+ worker_noise_sampler = None
88
+ worker_rir_sampler = None
89
+
90
+
91
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser.

    Exactly one of ``--input_manifest`` or ``--input_jsonl`` must be
    supplied (validated in ``main``). ``--noise_manifest`` and
    ``--rir_manifest`` optionally enable noise / reverberation
    augmentation on the prompt portion of each audio sample.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--input_manifest",
        default=None,
        help="Path to input dataset manifest (data.lst).",
    )
    parser.add_argument(
        "--input_jsonl",
        default=None,
        help="Path to raw JSONL file (alternative to --input_manifest).",
    )
    parser.add_argument(
        "--tar_output_pattern",
        required=True,
        help="Tar shard pattern passed to WebDataset",
    )
    parser.add_argument(
        "--jsonl_output_pattern",
        required=True,
        help="Jsonl shard pattern passed to WebDataset",
    )
    parser.add_argument(
        "--samples_per_shard",
        type=int,
        default=1000,
        help="Maximum records per shard",
    )
    parser.add_argument(
        "--min_num_shards",
        type=int,
        default=32,
        help="Minimum number of output shards (use to ensure "
        "shard count >= num_gpu * num_workers)",
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default="eustlb/higgs-audio-v2-tokenizer",
        help="Path to audio tokenizer.",
    )
    parser.add_argument(
        "--skip_errors", action="store_true", help="Skip items that fail to process"
    )
    parser.add_argument(
        "--min_length",
        type=float,
        default=0.0,
        help="Minimum audio duration in seconds (e.g. 2.0)",
    )
    parser.add_argument(
        "--max_length",
        type=float,
        default=float("inf"),
        help="Maximum audio duration in seconds (e.g. 15.0)",
    )
    parser.add_argument(
        "--num_machines",
        type=int,
        default=1,
        help="Total number of machines for distributed runs",
    )
    parser.add_argument(
        "--machine_index",
        type=int,
        default=0,
        help="Zero-based machine index when distributing across multiple "
        "machines (e.g. 0, 1, ... num_machines-1)",
    )
    parser.add_argument(
        "--nj_per_gpu",
        type=int,
        default=3,
        help="Number of worker processes to spawn per GPU.",
    )
    parser.add_argument(
        "--loader_workers",
        type=int,
        default=24,
        help="Number of DataLoader workers for streaming IterableDataset.",
    )
    parser.add_argument(
        "--shuffle",
        type=str2bool,
        default=True,
        help="Shuffle data by default.",
    )
    parser.add_argument(
        "--shuffle-seed",
        type=int,
        default=42,
        help="Random seed for shuffle (default: 42).",
    )
    parser.add_argument(
        "--noise_manifest",
        default=None,
        help="Path to noise manifest (list of tar files). Enables prompt noise augmentation.",
    )
    parser.add_argument(
        "--rir_manifest",
        default=None,
        help="Path to RIR manifest (list of tar files). Enables prompt reverb augmentation.",
    )
    return parser
195
+
196
+
197
def count_lines(path):
    """Count the number of lines in the file at *path*.

    Reads in 1 MiB binary chunks so very large manifests are counted
    without loading them into memory. Unlike a plain newline count, a
    final line that lacks a trailing newline is still counted (the old
    implementation silently dropped it, skewing totals used for the
    progress bar and shard estimation).
    """
    total = 0
    last_chunk = b""
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            total += chunk.count(b"\n")
            last_chunk = chunk
    # Count a trailing, unterminated line that has no "\n" of its own.
    if last_chunk and not last_chunk.endswith(b"\n"):
        total += 1
    return total
200
+
201
+
202
def serialise_numpy(key: str, tokens: np.ndarray) -> dict:
    """Wrap *tokens* as a WebDataset record whose payload is ``.npy`` bytes."""
    with io.BytesIO() as sink:
        np.save(sink, tokens)
        encoded = sink.getvalue()
    return {"__key__": key, "npy": encoded}
206
+
207
+
208
def _load_aug_audio(data, sample_rate=24000):
    """Simple audio loader for augmentation files.

    Decodes raw encoded audio bytes (wav/flac/mp3 payloads from a
    WebDataset shard), downmixes multi-channel input to mono by averaging,
    and resamples to *sample_rate* if needed.

    Args:
        data: Raw audio file bytes.
        sample_rate: Target sampling rate in Hz.

    Returns:
        Waveform tensor of shape (1, T).
    """
    with io.BytesIO(data) as b:
        wav, sr = torchaudio.load(b)
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)
    if sr != sample_rate:
        wav = torchaudio.functional.resample(wav, sr, sample_rate)
    return wav
217
+
218
+
219
class SimpleWorkerSampler:
    """A lightweight infinite sampler for noise/RIR within a worker process.

    Builds a shuffled, endlessly repeating WebDataset pipeline over the
    given tar shards and hands out fixed-length waveform segments.
    """

    def __init__(self, tar_paths, sample_rate=24000):
        # NOTE(review): `.decode()` may already decode known audio
        # extensions, while `_decode` passes sample[ext] to a bytes-based
        # loader — confirm the payload is still raw bytes at that point.
        self.dataset = (
            wds.WebDataset(
                tar_paths, shardshuffle=True, nodesplitter=None, workersplitter=None
            )
            .decode()
            .map(lambda s: self._decode(s, sample_rate))
            .select(lambda x: x is not None)
            .shuffle(100)
            .repeat()
        )
        self.iterator = iter(self.dataset)

    def _decode(self, sample, sample_rate):
        # Return the first recognised audio payload; None drops the sample.
        for ext in ["wav", "flac", "mp3"]:
            if ext in sample:
                return _load_aug_audio(sample[ext], sample_rate)
        return None

    def sample_segment(self, target_len, allow_repeat=True):
        """Get a random segment of noise matching the target length.

        Args:
            target_len: Desired segment length in samples.
            allow_repeat: Tile short clips up to *target_len*; when False,
                a clip shorter than *target_len* is returned as-is.

        Returns:
            Waveform tensor of shape (1, target_len) (or shorter when
            ``allow_repeat`` is False).
        """
        try:
            audio = next(self.iterator)
        except StopIteration:
            # .repeat() should make the stream infinite; restart defensively.
            self.iterator = iter(self.dataset)
            audio = next(self.iterator)

        cur_len = audio.size(-1)
        if cur_len < target_len and allow_repeat:
            if cur_len > 0:
                num_repeats = math.ceil(target_len / cur_len)
                audio = audio.repeat(1, num_repeats)
            else:
                # Degenerate empty clip: fall back to silence.
                audio = F.pad(audio, (0, target_len), mode="constant")
            cur_len = audio.size(-1)

        if cur_len > target_len:
            # Random crop to the requested length.
            start = random.randint(0, cur_len - target_len)
            audio = audio[..., start : start + target_len]

        return audio
263
+
264
+
265
+ def _convolve1d(signal: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
266
+ m = signal.size(-1)
267
+ n = kernel.size(-1)
268
+ padded_size = m + n - 1
269
+ f_signal = torch.fft.rfft(signal, n=padded_size)
270
+ f_kernel = torch.fft.rfft(kernel, n=padded_size)
271
+ f_result = f_signal * f_kernel
272
+ result = torch.fft.irfft(f_result, n=padded_size)
273
+ return result[:padded_size]
274
+
275
+
276
def _apply_rir(audio, rir, mix_ratio=0.5):
    """Convolve *audio* with a room impulse response and mix wet/dry signals.

    Args:
        audio: Dry waveform tensor of shape (1, T).
        rir: Impulse-response tensor of shape (1, L).
        mix_ratio: Weight of the reverberant (wet) signal in the output.

    Returns:
        Tensor of shape (1, T) with reverberation applied.
    """
    # Attenuation applied to the RIR before convolution — presumably
    # compensating for int16-scaled RIR recordings; TODO confirm.
    rir_scaling_factor = 0.5**15
    N_in = audio.shape[-1]
    rir_d = rir[0, :] * rir_scaling_factor
    aug_d = _convolve1d(audio[0], rir_d)
    # Align the output with the RIR's direct-path peak so the reverberant
    # signal is not delayed relative to the dry one.
    shift_index = torch.argmax(torch.abs(rir_d))
    end_index = shift_index + N_in
    if end_index > aug_d.shape[0]:
        augmented = F.pad(aug_d[shift_index:], (0, end_index - aug_d.shape[0]))
    else:
        augmented = aug_d[shift_index:end_index]
    # Rescale the wet signal so its energy matches the dry signal.
    power_before = torch.sum(audio[0] ** 2)
    power_after = torch.sum(augmented**2)
    if power_after > 0:
        augmented *= torch.sqrt(power_before / power_after)
    mixed = (1 - mix_ratio) * audio[0] + mix_ratio * augmented
    return mixed.unsqueeze(0)
293
+
294
+
295
def process_init(rank_queue, tokenizer_path, noise_manifest=None, rir_manifest=None):
    """
    Initialization function for each worker process.
    Assigns a specific GPU to the process and loads the tokenizer.

    Args:
        rank_queue: Shared queue of GPU indices; -1 means "use CPU".
        tokenizer_path: HuggingFace model path/name of the audio tokenizer.
        noise_manifest: Optional manifest whose first whitespace-separated
            column per line is a noise tar path; enables the module-level
            ``worker_noise_sampler``.
        rir_manifest: Same format for RIR tars; enables the module-level
            ``worker_rir_sampler``.

    Side effects:
        Populates the module-level ``worker_tokenizer``,
        ``worker_feature_extractor``, ``worker_noise_sampler`` and
        ``worker_rir_sampler`` globals. Sampler failures are non-fatal and
        simply leave the corresponding augmentation disabled.
    """
    global worker_tokenizer, worker_feature_extractor, worker_noise_sampler, worker_rir_sampler

    # Configure worker process logging
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d]"
        " [Worker %(process)d] %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    # Get assigned GPU rank
    rank = rank_queue.get()
    # Determine device
    if rank != -1 and torch.cuda.is_available():
        worker_device = torch.device(f"cuda:{rank}")
    else:
        worker_device = torch.device("cpu")

    logging.debug(f"Worker process initialized with device: {worker_device}")
    # Load tokenizer onto the specified device
    worker_feature_extractor = AutoFeatureExtractor.from_pretrained(tokenizer_path)
    worker_tokenizer = HiggsAudioV2TokenizerModel.from_pretrained(
        tokenizer_path, device_map=worker_device
    )
    logging.debug(f"Tokenizer loaded successfully on device {worker_device}")

    # Initialize augmentation samplers (optional)
    if noise_manifest:
        try:
            with open(noise_manifest, "r") as f:
                tars = [l.strip().split()[0] for l in f if l.strip()]
                worker_noise_sampler = SimpleWorkerSampler(
                    tars, sample_rate=HIGGS_INPUT_SAMPLE_RATE
                )
            logging.debug("Noise sampler initialized.")
        except Exception as e:
            logging.warning(f"Failed to load noise manifest: {e}")

    if rir_manifest:
        try:
            with open(rir_manifest, "r") as f:
                tars = [l.strip().split()[0] for l in f if l.strip()]
                worker_rir_sampler = SimpleWorkerSampler(
                    tars, sample_rate=HIGGS_INPUT_SAMPLE_RATE
                )
            logging.debug("RIR sampler initialized.")
        except Exception as e:
            logging.warning(f"Failed to load RIR manifest: {e}")
347
+
348
+
349
def _augment_prompt(audio_tensor: torch.Tensor) -> tuple[torch.Tensor, int]:
    """Apply noise/reverb augmentation to the front portion of audio.

    Returns the augmented audio and the sample index where clean audio starts.

    A random 10%-30% leading slice (the "prompt" region) is corrupted using
    the module-level ``worker_noise_sampler`` / ``worker_rir_sampler``
    globals set up by ``process_init``; either may be None.
    """
    # Pre-normalization
    max_val = audio_tensor.abs().max() + 1e-7
    audio_tensor = (audio_tensor / max_val) * 0.6

    total_len = audio_tensor.size(-1)
    ratio = random.uniform(0.1, 0.3)
    # NOTE(review): split_idx can be 0 for extremely short inputs, making
    # the RMS divisions below degenerate — presumably prevented upstream by
    # the min_length filter; confirm.
    split_idx = int(total_len * ratio)
    front_part = audio_tensor[:, :split_idx].clone()

    # Apply noise
    if worker_noise_sampler is not None:
        noise = worker_noise_sampler.sample_segment(split_idx)
        snr_db = random.uniform(5, 15)
        sig_rms = front_part.norm(p=2) / (split_idx**0.5)
        noise_rms = noise.norm(p=2) / (split_idx**0.5)
        if noise_rms > 1e-9:
            # Amplitude-domain SNR: scale noise so 20*log10(sig/noise) == snr_db.
            snr = 10 ** (snr_db / 20)
            scale = sig_rms / (snr * noise_rms + 1e-8)
            front_part = front_part + noise * scale

    # Apply RIR (30% probability)
    if worker_rir_sampler is not None and random.random() < 0.3:
        rir = worker_rir_sampler.sample_segment(split_idx, allow_repeat=False)
        reverb_amt = random.uniform(0.3, 1.0)
        try:
            front_part = _apply_rir(front_part, rir, reverb_amt)
        except Exception as e:
            # Reverb is best-effort; keep the noise-only prompt on failure.
            logging.warning(f"RIR failed: {e}")

    # Merge back
    if front_part.device != audio_tensor.device:
        front_part = front_part.to(audio_tensor.device)
    audio_tensor[:, :split_idx] = front_part

    # Post-normalization
    max_val = audio_tensor.abs().max() + 1e-7
    audio_tensor = (audio_tensor / max_val) * 0.9

    return audio_tensor, split_idx
393
+
394
+
395
def process_single_sample(sample: dict[str, Any]) -> dict[str, Any]:
    """
    Single-sample processing function executed in worker processes.
    Skips invalid samples during streaming processing.

    When a noise and/or RIR sampler was initialized by ``process_init``,
    the prompt portion of the audio is corrupted first and the metadata
    gains ``clean_start_token_idx`` marking where clean tokens begin.

    Args:
        sample: Dict with an ``audio`` waveform tensor of shape (1, T) and a
            ``label`` dict that must carry an ``id`` field.

    Returns:
        A result dict with ``status`` ("success"/"error"), the sample ``key``,
        the extracted ``audio_tokens`` (int16 ndarray on success, else None),
        the updated ``metadata``, and ``error_msg``.
    """
    try:
        audio_tensor = sample.get("audio", None)  # shape (1, T)
        if audio_tensor is None:
            raise ValueError("Sample missing 'audio' field")

        # Apply prompt augmentation if noise/rir samplers are available
        enable_aug = worker_noise_sampler is not None or worker_rir_sampler is not None
        clean_sample_idx = 0
        if enable_aug:
            audio_tensor, clean_sample_idx = _augment_prompt(audio_tensor)

        with torch.inference_mode():
            key = sample["label"]["id"]

            inputs = worker_feature_extractor(
                raw_audio=audio_tensor.squeeze(0).numpy(),
                sampling_rate=HIGGS_INPUT_SAMPLE_RATE,
                return_tensors="pt",
            ).to(worker_tokenizer.device)
            audio_tokens = worker_tokenizer.encode(
                inputs["input_values"],
            ).audio_codes.squeeze(0)

        # The tokenizer is expected to emit 8 codebooks of equal length.
        assert len(audio_tokens.shape) == 2
        assert audio_tokens.size(0) == 8

        num_tokens = audio_tokens.size(1)
        metadata = sample["label"]
        metadata["num_tokens"] = num_tokens

        if enable_aug:
            # Convert the waveform split point to a token index using the
            # tokenizer's hop length (samples per token).
            clean_token_idx = math.ceil(
                clean_sample_idx / worker_tokenizer.config.hop_length
            )
            metadata["clean_start_token_idx"] = clean_token_idx

        # Convert to numpy format for subsequent serialization (int16 to save space)
        audio_tokens_np = audio_tokens.to(torch.int16).cpu().numpy()

        return {
            "status": "success",
            "key": key,
            "audio_tokens": audio_tokens_np,
            "metadata": metadata,
            "error_msg": None,
        }
    except Exception as e:
        sample_id = sample.get("label", {}).get("id", "unknown")
        logging.error(f"Failed to process sample {sample_id}: {e}")
        return {
            "status": "error",
            "key": sample_id,
            "audio_tokens": None,
            "metadata": None,
            "error_msg": str(e),
        }
456
+
457
+
458
+ def _normalise_value(value: Any) -> Any:
459
+ """Convert tensors and NumPy scalars to serialisable Python objects."""
460
+ if isinstance(value, torch.Tensor):
461
+ if value.ndim == 0:
462
+ return value.item()
463
+ return value.cpu().tolist()
464
+ if isinstance(value, np.generic):
465
+ return value.item()
466
+ if isinstance(value, np.ndarray):
467
+ return value.tolist()
468
+ return value
469
+
470
+
471
def _encode_metadata(metadata: dict[str, Any]) -> bytes:
    """Serialise *metadata* as UTF-8 JSON bytes, dropping None-valued keys."""
    cleaned: dict[str, Any] = {
        key: _normalise_value(value)
        for key, value in metadata.items()
        if value is not None
    }
    return json.dumps(cleaned, ensure_ascii=False).encode("utf-8")
478
+
479
+
480
class StreamingLengthFilteredDataset(IterableDataset):
    """Iterable wrapper that drops samples whose duration is out of range.

    Filtering happens lazily while streaming; out-of-range samples are
    counted and logged, and malformed samples are skipped with a warning
    instead of raising.
    """

    def __init__(
        self,
        base_iterable,
        min_len: float,
        max_len: float,
        sr: int,
    ):
        self.base_iterable = base_iterable
        self.min_len = min_len
        self.max_len = max_len
        self.sr = sr
        # Number of samples rejected by the duration filter so far.
        self.filtered_count = 0

    def __iter__(self):
        """Stream samples one by one and filter on the fly."""
        for item in self.base_iterable:
            try:
                seconds = item["audio"].size(-1) / self.sr
                if not (self.min_len <= seconds <= self.max_len):
                    self.filtered_count += 1
                    logging.warning(
                        f"Filtered sample (duration out of range): "
                        f"{item['label']['id']} ({seconds:.2f}s)"
                    )
                    continue
            except Exception as e:
                logging.warning(f"Skipped invalid sample during streaming: {e}")
                continue
            yield item
510
+
511
+
512
def main() -> None:
    """Entry point: stream audio, tokenize it in a worker pool, write shards.

    Pipeline: input (raw JSONL or WebDataset manifest) -> duration filter ->
    DataLoader -> ProcessPoolExecutor workers (tokenization, optional prompt
    augmentation) -> rotating tar/jsonl shards -> ``data.lst`` manifest and
    ``errors.jsonl`` error log.
    """
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)
    parser = build_parser()
    args = parser.parse_args()
    # NOTE(review): "spawn" presumably required so workers can initialise
    # CUDA safely — confirm before changing the start method.
    mp.set_start_method("spawn", force=True)

    # Validate input arguments
    assert bool(args.input_manifest) != bool(
        args.input_jsonl
    ), "Exactly one of --input_manifest or --input_jsonl must be provided."

    if args.num_machines > 1:
        assert (
            0 <= args.machine_index < args.num_machines
        ), f"machine_index {args.machine_index} must be in [0, {args.num_machines})"

    # Build base dataset and count total samples based on input mode
    if args.input_jsonl:
        logging.info(f"Input mode: raw JSONL ({args.input_jsonl})")
        total_samples = count_lines(args.input_jsonl)
        base_dataset = JsonlDatasetReader(
            args.input_jsonl,
            sample_rate=HIGGS_INPUT_SAMPLE_RATE,
            shuffle=args.shuffle,
            shuffle_seed=args.shuffle_seed,
        )
        loader_workers = args.loader_workers
    else:
        logging.info(f"Input mode: WebDataset manifest ({args.input_manifest})")
        manifest_num_lines = count_lines(args.input_manifest)
        loader_workers = min(args.loader_workers, manifest_num_lines)
        total_samples = 0
        manifests = []
        with open(args.input_manifest, "r", encoding="utf-8") as f:
            for line_id, line in tqdm(
                enumerate(f),
                total=manifest_num_lines,
                desc="Calculating dataset length",
            ):
                # Manifest line format: <tar_path> <jsonl_path> <count> <dur>
                items = line.strip().split(" ")
                tar_path, jsonl_path, num_items, duration = (
                    items[0],
                    items[1],
                    int(items[2]),
                    float(items[3]),
                )
                assert os.path.exists(tar_path), f"File {tar_path} does not exist."
                assert os.path.exists(jsonl_path), f"File {jsonl_path} does not exist."
                assert jsonl_path.endswith(
                    ".jsonl"
                ), f"File {jsonl_path} is not a .jsonl file."
                # Round-robin shard assignment across machines.
                if (
                    args.num_machines > 1
                    and line_id % args.num_machines != args.machine_index
                ):
                    continue
                total_samples += num_items
                manifests.append((tar_path, jsonl_path, num_items, duration))
        logging.info(
            f"Total shards: {manifest_num_lines}, "
            f"Shards for current index: {len(manifests)}"
        )
        base_dataset = WebDatasetReader(
            manifests=manifests,
            sample_rate=HIGGS_INPUT_SAMPLE_RATE,
            evaluation=True,
        )

    # Apply length filter and create DataLoader
    filtered_dataset = StreamingLengthFilteredDataset(
        base_iterable=base_dataset,
        min_len=args.min_length,
        max_len=args.max_length,
        sr=HIGGS_INPUT_SAMPLE_RATE,
    )
    dataloader = DataLoader(
        dataset=filtered_dataset,
        batch_size=None,
        num_workers=loader_workers,
        persistent_workers=loader_workers > 0,
        pin_memory=False,
    )

    # Adjust samples_per_shard if min_num_shards would be violated
    samples_per_shard = args.samples_per_shard
    if total_samples > 0:
        estimated_shards = max(
            1, (total_samples + samples_per_shard - 1) // samples_per_shard
        )
        if estimated_shards < args.min_num_shards:
            samples_per_shard = max(1, total_samples // args.min_num_shards)
            logging.info(
                f"Adjusted samples_per_shard from {args.samples_per_shard} to "
                f"{samples_per_shard} to meet min_num_shards={args.min_num_shards} "
                f"(total_samples={total_samples})"
            )

    # Configure multi-GPU multi-process setup
    num_devices = torch.cuda.device_count()
    if num_devices == 0:
        logging.warning("No GPUs detected - using CPU for processing")
        num_processes = args.nj_per_gpu
    else:
        num_processes = num_devices * args.nj_per_gpu
    logging.info(
        f"GPU count: {num_devices}, Processes per GPU: {args.nj_per_gpu}, "
        f"Total processes: {num_processes}"
    )
    if args.noise_manifest or args.rir_manifest:
        logging.info(
            f"Prompt augmentation enabled - "
            f"noise: {args.noise_manifest or 'off'}, rir: {args.rir_manifest or 'off'}"
        )

    # Shared GPU rank queue for process assignment; each worker pops one
    # rank in its initializer.
    manager = mp.Manager()
    rank_queue = manager.Queue()
    for rank in list(range(num_devices)) * args.nj_per_gpu:
        rank_queue.put(rank)
    if num_devices == 0:
        # CPU-only: -1 signals the worker to skip CUDA device placement.
        for _ in range(num_processes):
            rank_queue.put(-1)

    # Prepare output paths
    tar_output_pattern = str(Path(args.tar_output_pattern).expanduser())
    jsonl_output_pattern = str(Path(args.jsonl_output_pattern).expanduser())
    Path(tar_output_pattern).parent.mkdir(parents=True, exist_ok=True)
    Path(jsonl_output_pattern).parent.mkdir(parents=True, exist_ok=True)

    # Determine output directory from tar_output_pattern
    output_dir = Path(tar_output_pattern).parent.parent
    error_log_path = str(output_dir / "errors.jsonl")
    manifest_path = str(output_dir / "data.lst")

    # Setup error logger (writes to errors.jsonl)
    error_logger = logging.getLogger("error_log")
    error_logger.setLevel(logging.ERROR)
    error_logger.handlers.clear()
    error_fh = logging.FileHandler(error_log_path, mode="w", encoding="utf-8")
    error_fh.setFormatter(logging.Formatter("%(message)s"))
    error_logger.addHandler(error_fh)

    # Progress and error tracking
    processed_count = 0
    error_count = 0
    write_error_count = 0
    failed_ids = []
    shard_idx = 0
    shard_sample_count = 0
    shard_duration = 0.0
    shard_manifest = {}  # shard_idx -> (tar_path, jsonl_path, count, duration)

    tar_writer = None
    jsonl_file = None

    def open_new_shard():
        # Close the current shard (if any), record it in the manifest,
        # and open the next tar/jsonl pair.
        nonlocal tar_writer, jsonl_file, shard_idx, shard_sample_count, shard_duration
        if tar_writer is not None:
            tar_writer.close()
        if jsonl_file is not None:
            jsonl_file.close()
        # Record manifest for the previous shard
        if shard_idx > 0 and shard_sample_count > 0:
            prev_idx = shard_idx - 1
            shard_manifest[prev_idx] = (
                os.path.abspath(tar_output_pattern % prev_idx),
                os.path.abspath(jsonl_output_pattern % prev_idx),
                shard_sample_count,
                shard_duration,
            )
        tar_fname = tar_output_pattern % shard_idx
        jsonl_fname = jsonl_output_pattern % shard_idx
        tar_writer = wds.TarWriter(tar_fname)
        jsonl_file = open(jsonl_fname, "w", encoding="utf-8")
        shard_idx += 1
        shard_sample_count = 0
        shard_duration = 0.0

    def write_sample(key, audio_tokens_np, metadata):
        # Append one token record + metadata line to the current shard;
        # write failures are logged and counted, not raised here.
        nonlocal shard_sample_count, write_error_count, shard_duration
        assert tar_writer is not None and jsonl_file is not None
        try:
            token_record = serialise_numpy(key, audio_tokens_np)
            json_record = _encode_metadata(metadata)
            tar_writer.write(token_record)
            jsonl_file.write(json_record.decode("utf-8") + "\n")
            shard_sample_count += 1
            shard_duration += metadata.get("audio_duration", 0.0)
        except Exception as exc:
            write_error_count += 1
            failed_ids.append(key)
            error_logger.error(
                json.dumps({"id": key, "reason": str(exc)}, ensure_ascii=False)
            )
            logging.error(f"Write failed for sample {key}: {exc}")

    def handle_result(result):
        # Route one worker result to the shard writer or the error log.
        nonlocal processed_count, error_count
        if result["status"] == "success":
            # Rotate shard if needed
            if tar_writer is None or shard_sample_count >= samples_per_shard:
                open_new_shard()
            write_sample(result["key"], result["audio_tokens"], result["metadata"])
            processed_count += 1
        else:
            error_count += 1
            failed_ids.append(result["key"])
            error_logger.error(
                json.dumps(
                    {"id": result["key"], "reason": result["error_msg"]},
                    ensure_ascii=False,
                )
            )
            if not args.skip_errors:
                raise RuntimeError(
                    f"Sample {result['key']} processing failed due "
                    f"to {result['error_msg']} - terminating"
                )
            logging.warning(
                f"Skipping failed sample {result['key']}: {result['error_msg']}"
            )

    main_progress = tqdm(total=total_samples, desc="Extracting Audio Tokens")

    try:
        with ProcessPoolExecutor(
            max_workers=num_processes,
            initializer=process_init,
            initargs=(
                rank_queue,
                args.tokenizer_path,
                args.noise_manifest,
                args.rir_manifest,
            ),
        ) as executor:
            logging.info(f"Submitting tasks... ({num_processes} workers)")
            futures = set()
            # Bound the in-flight queue to cap memory held by pending samples.
            max_pending = num_processes * 10

            def drain_completed():
                """Wait for at least one future to complete, process all done."""
                nonlocal futures
                done, _ = wait(futures, return_when=FIRST_COMPLETED)
                for f in done:
                    futures.discard(f)
                    result = f.result()
                    main_progress.update(1)
                    handle_result(result)
                    main_progress.set_postfix(
                        Samples=processed_count,
                        Errors=error_count,
                    )

            # Stream samples from DataLoader
            for sample in dataloader:
                if len(futures) >= max_pending:
                    drain_completed()

                future = executor.submit(process_single_sample, sample)
                futures.add(future)

            # Process remaining futures
            logging.info("Processing remaining pending samples...")
            while futures:
                drain_completed()

    except Exception:
        logging.error("Critical error during processing", exc_info=True)
        raise
    finally:
        main_progress.close()
        if tar_writer is not None:
            tar_writer.close()
        if jsonl_file is not None:
            jsonl_file.close()
        # Record the last shard in the manifest
        if shard_idx > 0 and shard_sample_count > 0:
            last_idx = shard_idx - 1
            shard_manifest[last_idx] = (
                os.path.abspath(tar_output_pattern % last_idx),
                os.path.abspath(jsonl_output_pattern % last_idx),
                shard_sample_count,
                shard_duration,
            )

    # Write manifest file (data.lst)
    with open(manifest_path, "w", encoding="utf-8") as mf:
        for idx in sorted(shard_manifest.keys()):
            tar_path, jsonl_path, count, duration = shard_manifest[idx]
            mf.write(f"{tar_path} {jsonl_path} {count} {duration:.3f}\n")

    # Output final statistics
    total_failed = error_count + write_error_count
    filtered_and_skipped = total_samples - processed_count - total_failed
    logging.info(
        f"Processing Complete - Successful: {processed_count}, Failed: {total_failed}, "
        f"Filtered/Skipped: {filtered_and_skipped}, Shards written: {shard_idx}"
    )
    logging.info(f"Manifest written to: {manifest_path} ({len(shard_manifest)} shards)")
    if total_failed > 0:
        logging.info(f"Error details: {error_log_path}")
    if failed_ids and args.skip_errors:
        logging.warning(
            f"Failed sample IDs (count: {len(failed_ids)}): {failed_ids[:100]}..."
        )
    if write_error_count > 0 and not args.skip_errors:
        raise RuntimeError(
            f"{write_error_count} samples failed to write - check logs for details"
        )
822
+
823
+
824
# Script entry point.
if __name__ == "__main__":
    main()
omnivoice/scripts/jsonl_to_webdataset.py ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ Pack a JSONL audio dataset into a customed WebDataset shards
20
+ (paired .tar and .jsonl files).
21
+
22
+ Usage:
23
+ python jsonl_to_webdataset.py \
24
+ --input data.jsonl \
25
+ --output output_dir/ \
26
+ --workers 16 \
27
+ --threads 4 \
28
+ --shard-size 1000 \
29
+ --sr 24000
30
+
31
+ Input JSONL format (one JSON object per line):
32
+ {"id": "utt_001", "audio_path": "/data/wavs/001.wav", "text": "hello world", ...}
33
+
34
+ Required fields: "id", "audio_path", "text"
35
+ All other fields are preserved in the output metadata.
36
+
37
+ Output structure:
38
+ output_dir/
39
+ ├── audios/ # WebDataset tar shards
40
+ │ ├── shard_000000.tar
41
+ │ ├── shard_000001.tar
42
+ │ └── ...
43
+ ├── txts/ # Per-shard JSONL metadata (with audio_duration added)
44
+ │ ├── shard_000000.jsonl
45
+ │ ├── shard_000001.jsonl
46
+ │ └── ...
47
+ ├── data.lst # Manifest: <tar_path> <jsonl_path> <sample_count> <total_duration>
48
+ └── errors.jsonl # Failed samples with error details
49
+ """
50
+
51
+ import argparse
52
+ import io
53
+ import json
54
+ import logging
55
+ import multiprocessing as mp
56
+ import os
57
+ import random
58
+ from concurrent.futures import (
59
+ FIRST_COMPLETED,
60
+ ProcessPoolExecutor,
61
+ ThreadPoolExecutor,
62
+ as_completed,
63
+ wait,
64
+ )
65
+ from itertools import islice
66
+ from pathlib import Path
67
+
68
+ import torchaudio
69
+ import webdataset as wds
70
+ from tqdm import tqdm
71
+
72
+ from omnivoice.utils.common import str2bool
73
+
74
+
75
def build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for the JSONL -> WebDataset packing script."""
    parser = argparse.ArgumentParser(
        description="Pack JSONL audio dataset into WebDataset shards."
    )
    add = parser.add_argument
    add("--input", type=str, default="data.jsonl", help="Path to input JSONL file")
    add("--output", type=str, default="emilia", help="Path to output directory")
    add(
        "--workers",
        type=int,
        default=16,
        help="Number of worker processes (default: 16)",
    )
    add("--threads", type=int, default=4, help="Number of threads per worker process.")
    add(
        "--shard-size",
        type=int,
        default=1000,
        help="Number of samples per shard (default: 1000)",
    )
    add("--sr", type=int, default=24000, help="Target sample rate (default: 24000)")
    add("--shuffle", type=str2bool, default=True, help="Shuffle data by default.")
    add(
        "--shuffle-seed",
        type=int,
        default=42,
        help="Random seed for shuffle (default: 42)",
    )
    add(
        "--min-duration",
        type=float,
        default=None,
        help="Filter out samples shorter than this (seconds).",
    )
    add(
        "--max-duration",
        type=float,
        default=None,
        help="Filter out samples >= this duration (seconds).",
    )
    return parser
134
+
135
+
136
def read_jsonl(file_path):
    """Yield one parsed JSON object per non-empty line of *file_path*."""
    with open(file_path, "r", encoding="utf-8") as handle:
        for raw in handle:
            stripped = raw.strip()
            if not stripped:
                continue
            yield json.loads(stripped)
142
+
143
+
144
def chunked_reader(iterator, chunk_size):
    """Yield successive lists of at most *chunk_size* items from *iterator*."""
    stream = iter(iterator)
    while True:
        batch = list(islice(stream, chunk_size))
        if not batch:
            return
        yield batch
148
+
149
+
150
def process_audio_item(meta, target_sr):
    """Load, resample, and FLAC-encode a single audio record.

    Returns ``{"ok": (wds_sample, meta)}`` on success — *meta* gains an
    ``audio_duration`` field — or ``{"error": {...}}`` describing the
    failure, so callers never see an exception from this function.
    """
    key = meta.get("id")
    audio_path = meta.get("audio_path")

    if not key or not audio_path:
        return {
            "error": {
                "id": key,
                "audio_path": audio_path,
                "reason": "missing id or audio_path",
            }
        }

    try:
        if not os.path.exists(audio_path):
            raise FileNotFoundError(f"{audio_path} not found")

        waveform, sr = torchaudio.load(audio_path)
        # Duration is computed from the original waveform, before resampling
        # (the value is identical either way).
        meta["audio_duration"] = waveform.shape[1] / sr

        if target_sr and sr != target_sr:
            waveform = torchaudio.functional.resample(waveform, sr, target_sr)
            sr = target_sr

        # Encode as 16-bit FLAC in memory.
        buf = io.BytesIO()
        torchaudio.save(buf, waveform, sr, format="flac", bits_per_sample=16)

        return {"ok": ({"__key__": key, "flac": buf.getvalue()}, meta)}
    except Exception as e:
        return {"error": {"id": key, "audio_path": audio_path, "reason": str(e)}}
188
+
189
+
190
def process_single_shard(
    shard_idx,
    records,
    output_tar_pattern,
    output_jsonl_pattern,
    target_sr,
    num_threads=4,
    min_duration=None,
    max_duration=None,
):
    """Encode one shard's records into a paired tar + jsonl file.

    Audio decoding/encoding is fanned out over a thread pool; the duration
    filter is applied after decoding (using the measured audio_duration).
    Shards that end up empty are deleted from disk.

    Returns:
        (shard_idx, processed_count, error_count, filtered_count,
         total_duration, errors) where ``errors`` is a list of error dicts
        from ``process_audio_item``.
    """
    tar_fname = output_tar_pattern % shard_idx
    jsonl_fname = output_jsonl_pattern % shard_idx

    processed_count = 0
    filtered_count = 0
    error_count = 0
    total_duration = 0.0
    errors = []

    with wds.TarWriter(tar_fname) as sink, open(
        jsonl_fname, "w", encoding="utf-8"
    ) as jsonl_f:

        with ThreadPoolExecutor(max_workers=num_threads) as thread_pool:
            futures = []

            for meta in records:
                f = thread_pool.submit(process_audio_item, meta, target_sr)
                futures.append(f)

            # NOTE: as_completed yields in completion order, so samples may
            # be written in a different order than they appear in `records`.
            for f in as_completed(futures):
                result = f.result()

                if "error" in result:
                    error_count += 1
                    errors.append(result["error"])
                    continue

                sample, meta = result["ok"]
                dur = meta.get("audio_duration", 0.0)

                # Duration filtering (based on actual audio_duration computed above)
                if min_duration is not None and dur < min_duration:
                    filtered_count += 1
                    continue
                if max_duration is not None and dur >= max_duration:
                    filtered_count += 1
                    continue

                sink.write(sample)

                jsonl_f.write(json.dumps(meta, ensure_ascii=False) + "\n")

                total_duration += dur
                processed_count += 1

    # Clean up empty shard files
    if processed_count == 0:
        for p in (tar_fname, jsonl_fname):
            if os.path.exists(p):
                os.remove(p)

    return (
        shard_idx,
        processed_count,
        error_count,
        filtered_count,
        total_duration,
        errors,
    )
260
+
261
+
262
def count_lines(path):
    """Count newline characters in *path*, scanning it in 1 MiB binary chunks."""
    total = 0
    with open(path, "rb") as handle:
        while chunk := handle.read(1 << 20):
            total += chunk.count(b"\n")
    return total
265
+
266
+
267
def pack_dataset(
    input_jsonl,
    output_dir,
    samples_per_shard=5000,
    num_workers=16,
    target_sr=24000,
    threads_per_worker=4,
    shuffle=False,
    shuffle_seed=None,
    min_duration=None,
    max_duration=None,
):
    """Pack a JSONL audio dataset into paired WebDataset tar/jsonl shards.

    Shards are processed in parallel by a process pool (each worker fans out
    per-sample audio work over its own thread pool). Writes ``audios/*.tar``
    and ``txts/*.jsonl`` under *output_dir*, plus a ``data.lst`` manifest
    and an ``errors.jsonl`` error log.
    """
    input_path = Path(input_jsonl)
    output_dir = Path(output_dir)
    output_tar_dir = output_dir / "audios"
    output_tar_dir.mkdir(parents=True, exist_ok=True)
    output_jsonl_dir = output_dir / "txts"
    output_jsonl_dir.mkdir(parents=True, exist_ok=True)

    output_tar_pattern = str(output_tar_dir / "shard-%06d.tar")
    output_jsonl_pattern = str(output_jsonl_dir / "shard-%06d.jsonl")

    error_log_path = str(output_dir / "errors.jsonl")

    # Setup error logger
    error_logger = logging.getLogger("error_log")
    error_logger.setLevel(logging.ERROR)
    error_logger.handlers.clear()
    fh = logging.FileHandler(error_log_path, mode="w", encoding="utf-8")
    fh.setFormatter(logging.Formatter("%(message)s"))
    error_logger.addHandler(fh)

    shard_manifest = {}

    print(f"Reading input: {input_path}")
    print(f"Output dir: {output_dir}")
    print(f"Strategy: {num_workers} Processes x {threads_per_worker} Threads")

    if shuffle:
        # Shuffling requires materialising the whole JSONL in memory first.
        print("Load input dataset...")
        entries = list(read_jsonl(input_path))
        random.seed(shuffle_seed)
        random.shuffle(entries)
        print(f"Shuffled {len(entries)} entries (seed={shuffle_seed})")
        total_lines = len(entries)
        chunk_gen = chunked_reader(iter(entries), samples_per_shard)
    else:
        print("Calculating total lines...")
        total_lines = count_lines(input_path)
        chunk_gen = chunked_reader(read_jsonl(input_path), samples_per_shard)

    if min_duration is not None or max_duration is not None:
        print(
            f"Duration filter: [{min_duration or 0:.2f}s"
            f", {max_duration or float('inf'):.1f}s) (applied after audio decoding)"
        )

    total_shards_est = (total_lines + samples_per_shard - 1) // samples_per_shard
    print(f"Total samples: {total_lines}, Estimated shards: {total_shards_est}")

    with ProcessPoolExecutor(max_workers=num_workers) as executor:

        futures = set()

        shard_idx = 0
        total_processed = 0
        total_errors = 0
        total_filtered = 0

        pbar = tqdm(
            total=total_shards_est,
            desc="Shards Processed",
            unit="shard",
        )

        def submit_next_chunks(limit):
            """Pull up to `limit` chunks from generator, submit them."""
            nonlocal shard_idx
            submitted = 0
            for chunk in chunk_gen:
                f = executor.submit(
                    process_single_shard,
                    shard_idx,
                    chunk,
                    output_tar_pattern,
                    output_jsonl_pattern,
                    target_sr,
                    threads_per_worker,
                    min_duration,
                    max_duration,
                )
                futures.add(f)
                shard_idx += 1
                submitted += 1
                if submitted >= limit:
                    break

        # Prime the pool with two chunks per worker, then refill one chunk
        # per completed shard to bound memory usage.
        submit_next_chunks(num_workers * 2)

        while futures:
            done, _ = wait(futures, return_when=FIRST_COMPLETED)

            for f in done:
                futures.remove(f)

                try:
                    s_idx, p_count, e_count, f_count, s_duration, errors = f.result()
                    total_processed += p_count
                    total_errors += e_count
                    total_filtered += f_count

                    # Write error log
                    for err in errors:
                        err["shard_idx"] = s_idx
                        error_logger.error(json.dumps(err, ensure_ascii=False))

                    # Only non-empty shards enter the manifest (empty ones
                    # were already deleted by the worker).
                    if p_count > 0:
                        tar_abs = os.path.abspath(output_tar_pattern % s_idx)
                        jsonl_abs = os.path.abspath(output_jsonl_pattern % s_idx)
                        shard_manifest[s_idx] = (
                            tar_abs,
                            jsonl_abs,
                            p_count,
                            s_duration,
                        )

                    pbar.set_postfix(
                        {
                            "Samples": total_processed,
                            "Filtered": total_filtered,
                            "Errors": total_errors,
                        }
                    )
                    pbar.update(1)
                except Exception as e:
                    print(f"Shard task failed: {e}")

                submit_next_chunks(1)

        pbar.close()

    # Write final manifest file (data.lst)
    manifest_path = str(output_dir / "data.lst")
    with open(manifest_path, "w", encoding="utf-8") as mf:
        for idx in sorted(shard_manifest.keys()):
            tar_path, jsonl_path, count, duration = shard_manifest[idx]
            mf.write(f"{tar_path} {jsonl_path} {count} {duration:.3f}\n")

    print(f"\nDone! Output saved to {output_dir}")
    print(f"Successfully packed: {total_processed}")
    print(f"Filtered by duration: {total_filtered}")
    print(f"Failed: {total_errors}")
    print(f"Manifest written to: {manifest_path} ({len(shard_manifest)} shards)")
    if total_errors > 0:
        print(f"Error details: {error_log_path}")
422
+
423
+
424
# Script entry point: parse CLI args and run the packing pipeline.
if __name__ == "__main__":
    # NOTE(review): "spawn" presumably chosen because workers import
    # torch/torchaudio, which are not fork-safe everywhere — confirm
    # before changing the start method.
    mp.set_start_method("spawn", force=True)

    args = build_parser().parse_args()
    pack_dataset(
        input_jsonl=args.input,
        output_dir=args.output,
        samples_per_shard=args.shard_size,
        num_workers=args.workers,
        target_sr=args.sr,
        threads_per_worker=args.threads,
        shuffle=args.shuffle,
        shuffle_seed=args.shuffle_seed,
        min_duration=args.min_duration,
        max_duration=args.max_duration,
    )
omnivoice/training/__init__.py ADDED
File without changes
omnivoice/training/builder.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Builders for constructing training components.
19
+
20
+ Provides factory functions to assemble the model, tokenizer, and data loaders
21
+ from a ``TrainingConfig``. Called by ``omnivoice.cli.train`` to set up training.
22
+
23
+ Key functions:
24
+ - ``build_model_and_tokenizer()``: Loads the model and text tokenizer.
25
+ - ``build_dataloaders()``: Builds packed train/eval data loaders
26
+ from a data config JSON.
27
+ """
28
+
29
+ import logging
30
+ from functools import partial
31
+ from typing import Tuple
32
+
33
+ import torch
34
+ from torch.utils.data import DataLoader
35
+ from transformers import AutoConfig, AutoModel, AutoTokenizer
36
+ from transformers import logging as hf_logging
37
+ from transformers.trainer_utils import seed_worker
38
+
39
+ from omnivoice.data.batching import PackingIterableDataset
40
+ from omnivoice.data.collator import PackingDataCollator
41
+ from omnivoice.data.dataset import WebDatasetReader, prepare_data_manifests_from_json
42
+ from omnivoice.data.processor import OmniVoiceSampleProcessor
43
+ from omnivoice.models.omnivoice import OmniVoice, OmniVoiceConfig
44
+ from omnivoice.training.config import TrainingConfig
45
+
46
+ logger = logging.getLogger(__name__)
47
+
48
+
49
def build_model_and_tokenizer(
    config: TrainingConfig,
) -> Tuple[OmniVoice, AutoTokenizer]:
    """Load Tokenizer and Model, handle resizing and special tokens.

    When ``config.init_from_checkpoint`` is set, both tokenizer and model
    weights come from that checkpoint; otherwise the tokenizer and a fresh
    LLM backbone are loaded from ``config.llm_name_or_path`` and wrapped in
    a new ``OmniVoice``. Embeddings are resized if special tokens were added.
    """
    logger.info("Initializing Model & Tokenizer...")

    # 1. Tokenizer
    tokenizer_path = (
        config.init_from_checkpoint
        if config.init_from_checkpoint
        else config.llm_name_or_path
    )
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Control tokens used by the OmniVoice prompt format.
    new_tokens = [
        "<|denoise|>",
        "<|lang_start|>",
        "<|lang_end|>",
        "<|instruct_start|>",
        "<|instruct_end|>",
        "<|text_start|>",
        "<|text_end|>",
    ]

    # Only add tokens not already present (idempotent when resuming from a
    # checkpoint whose tokenizer already contains them).
    tokens_to_add = [t for t in new_tokens if t not in tokenizer.get_vocab()]
    if tokens_to_add:
        tokenizer.add_special_tokens({"additional_special_tokens": tokens_to_add})

    # 2. Model
    if config.init_from_checkpoint:
        logger.info(f"Loading weights from {config.init_from_checkpoint}")
        model = OmniVoice.from_pretrained(
            config.init_from_checkpoint,
            attn_implementation="flex_attention",
            dtype=torch.float32,
            train=True,
        )
    else:
        llm_config = AutoConfig.from_pretrained(config.llm_name_or_path)

        ov_config = OmniVoiceConfig(
            audio_vocab_size=config.audio_vocab_size,
            audio_mask_id=config.audio_mask_id,
            num_audio_codebook=config.num_audio_codebook,
            audio_codebook_weights=config.audio_codebook_weights,
            llm_config=llm_config,
        )

        original_level = hf_logging.get_verbosity()
        hf_logging.set_verbosity_error()  # suppress expected lm_head.weight warnings

        llm = AutoModel.from_pretrained(
            config.llm_name_or_path,
            attn_implementation="flex_attention",
            dtype=torch.float32,
        )

        hf_logging.set_verbosity(original_level)
        model = OmniVoice(config=ov_config, llm=llm)

    # 3. Resize Embeddings (keep the model vocab in sync with the tokenizer)
    if len(tokenizer) != model.config.llm_config.vocab_size:
        model.llm.resize_token_embeddings(len(tokenizer))
        model.config.llm_config.vocab_size = len(tokenizer)

    # 4. Config IDs
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.eos_token_id = tokenizer.eos_token_id

    return model, tokenizer
121
+
122
+
123
def build_dataloaders(
    config: TrainingConfig, tokenizer: AutoTokenizer
) -> Tuple[DataLoader, DataLoader]:
    """Setup Data Pipeline: Manifests -> WDS -> Packing -> Loaders.

    Returns ``(train_loader, eval_loader)``. Note: despite the annotation,
    ``eval_loader`` is ``None`` when the data config declares no dev
    manifests — callers must handle that case.
    """
    logger.info("Initializing Data Readers...")

    # Per-sample preprocessing (text tokenisation, masking/prompt ratios).
    processor = OmniVoiceSampleProcessor(
        text_tokenizer=tokenizer,
        num_channels=config.num_audio_codebook,
        audio_mask_id=config.audio_mask_id,
        prompt_ratio_range=config.prompt_ratio_range,
        mask_ratio_range=config.mask_ratio_range,
        drop_cond_ratio=config.drop_cond_ratio,
        language_ratio=config.language_ratio,
        use_pinyin_ratio=config.use_pinyin_ratio,
        instruct_ratio=config.instruct_ratio,
        only_instruct_ratio=config.only_instruct_ratio,
    )

    train_manifests, dev_manifests = prepare_data_manifests_from_json(
        config.data_config
    )
    raw_train_ds = WebDatasetReader(manifests=train_manifests, evaluation=False)

    train_dataset = PackingIterableDataset(raw_train_ds, processor, config.batch_tokens)

    collate_fn = PackingDataCollator(processor, config.batch_tokens)

    # Seed each loader worker deterministically, offset by the DDP rank
    # (rank 0 when torch.distributed is not initialised).
    init_fn = partial(
        seed_worker,
        num_workers=config.num_workers,
        rank=torch.distributed.get_rank() if torch.distributed.is_initialized() else 0,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=None,  # Each item is a batch packed to the target batch_tokens
        num_workers=config.num_workers,
        collate_fn=collate_fn,
        worker_init_fn=init_fn,
        pin_memory=True,
        prefetch_factor=4,
    )

    eval_loader = None
    if dev_manifests:
        raw_dev_ds = WebDatasetReader(manifests=dev_manifests, evaluation=True)
        dev_dataset = PackingIterableDataset(raw_dev_ds, processor, config.batch_tokens)
        eval_loader = DataLoader(
            dev_dataset,
            batch_size=None,  # Each item is a batch packed to the target batch_tokens
            num_workers=1,
            collate_fn=collate_fn,
            pin_memory=True,
            prefetch_factor=2,
        )

    return train_loader, eval_loader
omnivoice/training/checkpoint.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Checkpoint saving, resuming, and training logging.
19
+
20
+ Provides utilities for saving/loading training checkpoints and logging metrics
21
+ to console and trackers (TensorBoard/WandB). Used by ``OmniTrainer``.
22
+
23
+ Key components:
24
+ - ``TrainLogger``: Logs training metrics to console and Accelerate trackers.
25
+ - ``save_checkpoint()``: Saves model, optimizer, and scheduler state.
26
+ - ``load_checkpoint()``: Restores training state from a checkpoint directory.
27
+ """
28
+
29
+ import logging
30
+ import os
31
+ import shutil
32
+ import time
33
+ from typing import Any, Dict, Optional
34
+
35
+ import torch
36
+ from accelerate import Accelerator
37
+ from tqdm.auto import tqdm
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+
42
class TrainLogger:
    """Console and tracker logging helper for the training loop.

    Owns a tqdm progress bar (created on the main process only) and forwards
    metric dicts to the Accelerate trackers (TensorBoard/WandB).
    """

    def __init__(self, accelerator: Accelerator, total_steps: int, logging_steps: int):
        self.accelerator = accelerator
        self.total_steps = total_steps
        # Stored for reference; the periodic logging cadence itself is driven
        # by the trainer calling log_metrics(), not by this class.
        self.logging_steps = logging_steps
        self.start_time = None
        self.progress_bar = None

    def start(self, start_step: int = 0):
        """Record the wall-clock start time and create the progress bar."""
        self.start_time = time.time()

        if not self.accelerator.is_main_process:
            return
        self.progress_bar = tqdm(
            total=self.total_steps,
            initial=start_step,
            desc="Training",
            dynamic_ncols=True,
            disable=not self.accelerator.is_local_main_process,
        )

    def update(
        self, step: int, loss: Optional[float] = None, lr: Optional[float] = None
    ):
        """Advance the bar by one step and refresh its live loss/lr postfix."""
        if not self.progress_bar:
            return

        self.progress_bar.update(1)

        live_stats = {}
        if loss is not None:
            live_stats["loss"] = f"{loss:.4f}"
        if lr is not None:
            live_stats["lr"] = f"{lr:.2e}"

        if live_stats:
            self.progress_bar.set_postfix(live_stats)

    def log_metrics(self, step: int, metrics: Dict[str, Any]):
        """Send *metrics* to the trackers and echo a one-line console summary."""
        # Always log to trackers (TensorBoard, etc.), on every rank.
        self.accelerator.log(metrics, step=step)

        if not self.accelerator.is_main_process:
            return

        pieces = []
        for name, value in metrics.items():
            if isinstance(value, float):
                fixed = f"{value:.4f}"
                # Tiny-but-nonzero floats would render as 0.0000; switch to
                # scientific notation so they stay readable.
                if fixed == "0.0000" and value != 0:
                    pieces.append(f"{name}: {value:.2e}")
                else:
                    pieces.append(f"{name}: {fixed}")
            else:
                pieces.append(f"{name}: {value}")

        msg = f"Step {step} | " + " | ".join(pieces)
        # tqdm.write keeps the progress bar intact; fall back to the module
        # logger when no bar exists.
        if self.progress_bar:
            self.progress_bar.write(msg)
        else:
            logger.info(msg)

    def close(self):
        """Dispose of the progress bar, if one was created."""
        if self.progress_bar:
            self.progress_bar.close()
116
+
117
+
118
def save_checkpoint(
    accelerator: Accelerator,
    model: torch.nn.Module,
    tokenizer: Any,
    output_dir: str,
    step: int,
    keep_last_n: int = 3,
):
    """Persist the full training state under ``output_dir/checkpoint-{step}``.

    Writes three things: the Accelerate state (optimizer, scheduler, RNG,
    scaler), the unwrapped model in HuggingFace format, and the tokenizer.
    Afterwards, old checkpoints are rotated so at most *keep_last_n* remain.

    Args:
        accelerator: The active ``Accelerator`` instance.
        model: The (possibly wrapped) model being trained.
        tokenizer: Tokenizer with a ``save_pretrained`` method.
        output_dir: Root directory holding all checkpoints.
        step: Current global step; used in the checkpoint folder name.
        keep_last_n: How many checkpoints to keep; <= 0 disables rotation.
    """
    checkpoint_dir = os.path.join(output_dir, f"checkpoint-{step}")

    # 1. Accelerate state: optimizer, scheduler, RNG, grad scaler.
    accelerator.save_state(checkpoint_dir)

    # 2. Model in HF format (config.json + weights), saved from the
    #    unwrapped module so DDP/DeepSpeed wrappers don't leak into the files.
    accelerator.unwrap_model(model).save_pretrained(
        checkpoint_dir,
        is_main_process=accelerator.is_main_process,
        save_function=accelerator.save,
    )

    # 3. Tokenizer (main process only — it's identical on every rank).
    if accelerator.is_main_process:
        tokenizer.save_pretrained(checkpoint_dir)

    logger.info(f"Saved checkpoint to {checkpoint_dir}")

    # 4. Rotation: drop the oldest checkpoints beyond keep_last_n.
    if not (accelerator.is_main_process and keep_last_n > 0):
        return

    existing = [
        name
        for name in os.listdir(output_dir)
        if name.startswith("checkpoint-")
        and os.path.isdir(os.path.join(output_dir, name))
    ]
    # Order by the numeric step suffix, oldest first.
    existing.sort(key=lambda name: int(name.split("-")[-1]))

    if len(existing) > keep_last_n:
        for stale in existing[:-keep_last_n]:
            shutil.rmtree(os.path.join(output_dir, stale))
            logger.info(f"Removed old checkpoint {stale}")
165
+
166
+
167
def load_checkpoint(accelerator: Accelerator, checkpoint_path: str):
    """Restore training state previously written by ``accelerator.save_state``.

    Args:
        accelerator: The active ``Accelerator`` instance.
        checkpoint_path: Directory of the checkpoint (e.g. ``.../checkpoint-500``).

    Returns:
        The step number parsed from the directory name's ``-N`` suffix, or 0
        when the name carries no numeric suffix.
    """
    logger.info(f"Resuming from {checkpoint_path}")
    accelerator.load_state(checkpoint_path)

    # Infer the global step from the "checkpoint-N" naming convention.
    try:
        tail = os.path.basename(os.path.normpath(checkpoint_path))
        return int(tail.split("-")[-1])
    except ValueError:
        return 0
omnivoice/training/config.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Training configuration dataclass.
19
+
20
+ Defines ``TrainingConfig``, a dataclass that holds all hyperparameters and paths
21
+ for training. Loaded from a JSON config file via ``TrainingConfig.from_json()``
22
+ in ``omnivoice.cli.train``.
23
+ """
24
+
25
+ import json
26
+ from dataclasses import asdict, dataclass, field
27
+ from typing import List, Optional, Tuple
28
+
29
+
30
+ @dataclass
31
+ class TrainingConfig:
32
+ # Key Paths
33
+ output_dir: Optional[str] = None
34
+ data_config: Optional[str] = None
35
+
36
+ # Model Specific
37
+ llm_name_or_path: str = "Qwen/Qwen3-0.6B"
38
+ audio_vocab_size: int = 1025 # valid vocab size + 1 (mask token)
39
+ audio_mask_id: int = 1024 # 1024 is the 1025-th token
40
+ num_audio_codebook: int = 8
41
+
42
+ # Model Training Specific
43
+ audio_codebook_weights: List[float | int] = field(
44
+ default_factory=lambda: [8, 8, 6, 6, 4, 4, 2, 2]
45
+ )
46
+ drop_cond_ratio: float = 0.1
47
+ prompt_ratio_range: Tuple[float, float] = field(default_factory=lambda: (0.0, 0.3))
48
+ mask_ratio_range: Tuple[float, float] = field(default_factory=lambda: (0.0, 1.0))
49
+ language_ratio: float = 0.8
50
+ use_pinyin_ratio: float = 0.3
51
+ instruct_ratio: float = 1.0
52
+ only_instruct_ratio: float = 0.5
53
+
54
+ # Init settings
55
+ resume_from_checkpoint: Optional[str] = None
56
+ init_from_checkpoint: Optional[str] = None
57
+
58
+ # Training Hyperparams
59
+ learning_rate: float = 1e-4
60
+ weight_decay: float = 0.01
61
+ max_grad_norm: float = 1.0
62
+ steps: int = 300000
63
+ seed: int = 42
64
+ lr_scheduler_type: str = "cosine"
65
+ warmup_type: str = "ratio"
66
+ warmup_ratio: float = 0.03
67
+ warmup_steps: int = 2000
68
+
69
+ # Data
70
+ batch_tokens: int = 8192
71
+ gradient_accumulation_steps: int = 1
72
+ num_workers: int = 8
73
+
74
+ # System
75
+ mixed_precision: str = "bf16"
76
+ allow_tf32: bool = True
77
+ use_deepspeed: bool = False
78
+ deepspeed_config: Optional[str] = None
79
+
80
+ # Logging
81
+ logging_steps: int = 100
82
+ eval_steps: int = 1000
83
+ save_steps: int = 10000
84
+ keep_last_n_checkpoints: int = -1
85
+
86
+ @classmethod
87
+ def from_json(cls, json_path: str):
88
+ with open(json_path, "r") as f:
89
+ cfg_dict = json.load(f)
90
+ valid_keys = cls.__annotations__.keys()
91
+ filtered_dict = {k: v for k, v in cfg_dict.items() if k in valid_keys}
92
+ instance = cls(**filtered_dict)
93
+ return instance
94
+
95
+ def save_to_json(self, json_path: str):
96
+ data = asdict(self)
97
+ with open(json_path, "w") as f:
98
+ json.dump(data, f, indent=4)
omnivoice/training/trainer.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Training loop for OmniVoice.
19
+
20
+ Wraps the HuggingFace Accelerate training loop with checkpoint saving/resuming,
21
+ evaluation, gradient accumulation, and learning rate scheduling.
22
+ Launched via ``omnivoice.cli.train``.
23
+ """
24
+
25
+ import logging
26
+ import math
27
+ import os
28
+ import sys
29
+ import time
30
+ from datetime import timedelta
31
+ from typing import Any, Optional
32
+
33
+ import torch
34
+ from accelerate import Accelerator, DistributedDataParallelKwargs
35
+ from accelerate.utils import DeepSpeedPlugin, InitProcessGroupKwargs, set_seed
36
+ from torch.utils.data import DataLoader
37
+ from transformers import (
38
+ get_cosine_schedule_with_warmup,
39
+ get_constant_schedule_with_warmup,
40
+ )
41
+
42
+ from omnivoice.training.checkpoint import TrainLogger, load_checkpoint
43
+ from omnivoice.training.checkpoint import save_checkpoint as engine_save_checkpoint
44
+
45
+ logger = logging.getLogger(__name__)
46
+
47
+
48
class OmniTrainer:
    """Step-based training loop for OmniVoice built on HuggingFace Accelerate.

    Handles accelerator/DeepSpeed setup, optimizer and LR-scheduler creation,
    gradient accumulation, periodic logging/evaluation, and checkpoint
    save/resume. The loop is bounded by optimizer steps (``config.steps``),
    not epochs; the train dataloader is re-iterated whenever exhausted.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        config: Any,  # TrainingConfig
        train_dataloader: DataLoader,
        eval_dataloader: Optional[DataLoader] = None,
        tokenizer: Optional[Any] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        lr_scheduler: Optional[Any] = None,
    ):
        self.config = config
        self.model = model
        self.tokenizer = tokenizer
        # NOTE(review): the dataloaders are NOT passed through
        # accelerator.prepare() below — presumably the packed iterable
        # dataset already shards per rank; confirm with the data pipeline.
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader

        # 1. Initialize Accelerator
        self.accelerator = self._init_accelerator()

        # 2. Setup Optimizer & Scheduler if not provided
        if optimizer is None:
            self.optimizer, self.lr_scheduler = self.create_optimizer_and_scheduler()
        else:
            self.optimizer = optimizer
            self.lr_scheduler = lr_scheduler

        # 3. DeepSpeed Hack (Batch Size fix)
        # The DataLoader uses batch_size=None (pre-packed batches), which
        # DeepSpeed cannot infer; force the micro batch size to 1.
        if self.accelerator.distributed_type == "DEEPSPEED":
            self.accelerator.state.deepspeed_plugin.deepspeed_config[
                "train_micro_batch_size_per_gpu"
            ] = 1

        # 4. Prepare with Accelerator
        (self.model, self.optimizer, self.lr_scheduler,) = self.accelerator.prepare(
            self.model,
            self.optimizer,
            self.lr_scheduler,
        )

        # Number of completed optimizer steps (not micro-batches).
        self.global_step = 0
        self.epoch = 0

    def _init_accelerator(self) -> Accelerator:
        """Initialize Accelerator, DeepSpeed, and Logging."""
        # TF32 setup
        if getattr(self.config, "allow_tf32", False):
            torch.set_float32_matmul_precision("high")

        # Init handlers
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
        # Long NCCL timeout: packed-data ranks may finish steps at uneven pace.
        init_kwargs = InitProcessGroupKwargs(timeout=timedelta(minutes=60))

        # DeepSpeed setup
        deepspeed_plugin = None
        if self.config.use_deepspeed and self.config.deepspeed_config:
            if not os.path.exists(self.config.deepspeed_config):
                raise FileNotFoundError(
                    f"DeepSpeed config not found: {self.config.deepspeed_config}"
                )
            deepspeed_plugin = DeepSpeedPlugin(
                hf_ds_config=self.config.deepspeed_config,
                gradient_accumulation_steps=self.config.gradient_accumulation_steps,
                gradient_clipping=self.config.max_grad_norm,
            )

        accelerator = Accelerator(
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            mixed_precision=self.config.mixed_precision,
            log_with="tensorboard",
            project_dir=self.config.output_dir,
            # The scheduler is stepped manually once per optimizer step.
            step_scheduler_with_optimizer=False,
            kwargs_handlers=[ddp_kwargs, init_kwargs],
            deepspeed_plugin=deepspeed_plugin,
            split_batches=False,
        )

        # Logging setup: main rank logs to stdout + train.log; other ranks
        # only surface errors.
        if accelerator.is_main_process:
            os.makedirs(self.config.output_dir, exist_ok=True)
            # Try to save config if it has the method
            if hasattr(self.config, "save_to_json"):
                self.config.save_to_json(
                    os.path.join(self.config.output_dir, "initial_config.json")
                )

            logging.basicConfig(
                format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
                datefmt="%m/%d/%Y %H:%M:%S",
                level=logging.INFO,
                handlers=[
                    logging.StreamHandler(sys.stdout),
                    logging.FileHandler(
                        os.path.join(self.config.output_dir, "train.log")
                    ),
                ],
            )
        else:
            logging.basicConfig(level=logging.ERROR)

        logger.info(f"Loaded Config: {self.config}")
        set_seed(self.config.seed)
        accelerator.init_trackers("tensorboard")
        return accelerator

    def create_optimizer_and_scheduler(self):
        """Default AdamW + configurable LR Scheduler.

        Warmup length comes from ``warmup_ratio`` (of total steps) when
        ``warmup_type == "ratio"``, otherwise from ``warmup_steps``.
        Scheduler is constant-with-warmup when ``lr_scheduler_type`` is
        "constant", cosine-with-warmup otherwise.
        """
        optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
        )

        if self.config.warmup_type == "ratio":
            final_warmup_steps = math.ceil(self.config.steps * self.config.warmup_ratio)
        else:
            final_warmup_steps = self.config.warmup_steps

        if self.config.lr_scheduler_type == "constant":
            lr_scheduler = get_constant_schedule_with_warmup(
                optimizer=optimizer,
                num_warmup_steps=final_warmup_steps,
            )
        else:
            lr_scheduler = get_cosine_schedule_with_warmup(
                optimizer=optimizer,
                num_warmup_steps=final_warmup_steps,
                num_training_steps=self.config.steps,
            )
        return optimizer, lr_scheduler

    def save_checkpoint(self, step):
        """Wrapper for engine save_checkpoint.

        Delegates to ``omnivoice.training.checkpoint.save_checkpoint`` and
        additionally drops a copy of the training config into the
        checkpoint directory.
        """
        engine_save_checkpoint(
            self.accelerator,
            self.model,
            self.tokenizer,
            self.config.output_dir,
            step,
            self.config.keep_last_n_checkpoints,
        )
        # Save config copy for convenience
        if self.accelerator.is_main_process and hasattr(self.config, "save_to_json"):
            checkpoint_dir = os.path.join(self.config.output_dir, f"checkpoint-{step}")
            self.config.save_to_json(os.path.join(checkpoint_dir, "train_config.json"))

    def load_checkpoint(self, checkpoint_path):
        """Wrapper for loading.

        Restores the Accelerate state and sets ``self.global_step`` from the
        checkpoint directory's ``-N`` suffix (0 when it cannot be parsed).
        """
        step = load_checkpoint(self.accelerator, checkpoint_path)
        self.global_step = step
        logger.info(f"Resumed from step {self.global_step}")
        return step

    def evaluate(self):
        """Evaluation loop.

        Computes the mean eval loss per rank, then averages those per-rank
        means across processes. Note this weights every rank equally
        regardless of how many eval batches it saw. Returns a metrics dict
        (empty when no eval dataloader is configured).
        """
        if self.eval_dataloader is None:
            return {}

        self.model.eval()
        logger.info(f"Running evaluation at step {self.global_step}...")

        local_loss_sum = torch.tensor(0.0, device=self.accelerator.device)
        eval_count = 0

        with torch.no_grad():
            for eval_batch in self.eval_dataloader:
                outputs = self.model(**eval_batch)
                local_loss_sum += outputs.loss.detach()
                eval_count += 1

        # Guard against a rank that received zero eval batches.
        if eval_count > 0:
            local_mean = local_loss_sum / eval_count
        else:
            local_mean = torch.tensor(0.0, device=self.accelerator.device)

        all_means = self.accelerator.gather(local_mean)
        final_eval_loss = all_means.mean().item()

        eval_metrics = {"eval/loss": final_eval_loss}
        self.accelerator.log(eval_metrics, step=self.global_step)
        logger.info(f"Eval Loss: {final_eval_loss:.4f}")

        self.accelerator.wait_for_everyone()
        self.model.train()
        return eval_metrics

    def train(self):
        """Main training loop.

        Runs until ``config.steps`` optimizer steps have completed, cycling
        the dataloader across epochs as needed. Logging, evaluation, and
        checkpointing fire on their configured step intervals; a final
        checkpoint is always written at the end.
        """
        logger.info("Starting Training Loop...")

        # Resume if configured
        if self.config.resume_from_checkpoint:
            self.load_checkpoint(self.config.resume_from_checkpoint)

        # Handle IterableDataset Epochs
        if hasattr(self.train_dataloader.dataset, "set_epoch"):
            self.train_dataloader.dataset.set_epoch(self.epoch)

        # Logger
        train_logger = TrainLogger(
            self.accelerator, self.config.steps, self.config.logging_steps
        )
        train_logger.start(self.global_step)

        self.model.train()
        train_iterator = iter(self.train_dataloader)

        logging_start_time = time.time()
        logging_start_step = self.global_step
        # Cumulative (un-averaged) training loss; interval averages are
        # derived by differencing against logging_loss_scalar below.
        tr_loss = torch.tensor(0.0).to(self.accelerator.device)
        logging_loss_scalar = 0.0

        while self.global_step < self.config.steps:
            try:
                batch = next(train_iterator)
            except StopIteration:
                # Dataloader exhausted: advance the epoch and start over.
                self.epoch += 1
                logger.info(f"Epoch {self.epoch} starting. Resetting dataloader...")
                if hasattr(self.train_dataloader.dataset, "set_epoch"):
                    self.train_dataloader.dataset.set_epoch(self.epoch)

                train_iterator = iter(self.train_dataloader)
                batch = next(train_iterator)

            # accumulate() no-syncs gradients until the accumulation
            # boundary; sync_gradients below is True only on that boundary.
            with self.accelerator.accumulate(self.model):
                outputs = self.model(**batch)
                loss = outputs.loss
                tr_loss += loss.detach()
                self.accelerator.backward(loss)

                if self.accelerator.sync_gradients:
                    # Clipping
                    grad_norm = 0.0
                    if self.config.max_grad_norm > 0:
                        grad_norm = self.accelerator.clip_grad_norm_(
                            self.model.parameters(), self.config.max_grad_norm
                        )
                        # clip_grad_norm_ may return None (e.g. under
                        # DeepSpeed, which clips internally).
                        grad_norm = (
                            grad_norm.item() if grad_norm is not None else 0.0
                        )

                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad()
                    self.global_step += 1

                    # Logging
                    current_lr = self.lr_scheduler.get_last_lr()[0]
                    train_logger.update(
                        step=self.global_step, loss=loss.item(), lr=current_lr
                    )

                    if self.global_step % self.config.logging_steps == 0:
                        elapsed = time.time() - logging_start_time
                        steps_per_sec = (
                            (self.global_step - logging_start_step) / elapsed
                            if elapsed > 0
                            else 0
                        )

                        # Average loss over the interval: cumulative delta
                        # divided by micro-batches in the interval.
                        tr_loss_scalar = self.accelerator.gather(tr_loss).mean().item()
                        current_interval_loss = tr_loss_scalar - logging_loss_scalar
                        avg_loss = current_interval_loss / (
                            self.config.logging_steps
                            * self.config.gradient_accumulation_steps
                        )
                        logging_loss_scalar = tr_loss_scalar

                        logs = {
                            "train/loss": avg_loss,
                            "train/learning_rate": current_lr,
                            "train/grad_norm": grad_norm,
                            "train/epoch": self.epoch,
                            "train/steps_per_sec": steps_per_sec,
                        }
                        train_logger.log_metrics(step=self.global_step, metrics=logs)

                        logging_start_time = time.time()
                        logging_start_step = self.global_step

                    # Evaluate
                    if (
                        self.eval_dataloader is not None
                        and self.global_step % self.config.eval_steps == 0
                    ):
                        self.evaluate()

                    # Save
                    if self.global_step % self.config.save_steps == 0:
                        self.save_checkpoint(self.global_step)

        # Final Save
        self.save_checkpoint(self.global_step)
        train_logger.close()
        self.accelerator.end_training()
omnivoice/utils/__init__.py ADDED
File without changes
omnivoice/utils/audio.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Audio I/O and processing utilities.
19
+
20
+ Provides functions for loading, resampling, silence removal, chunking,
21
+ cross-fading, and format conversion. Used by ``OmniVoice.generate()`` during
22
+ inference post-processing.
23
+ """
24
+
25
+ import numpy as np
26
+ import torch
27
+ import torchaudio
28
+ from pydub import AudioSegment
29
+ from pydub.silence import detect_leading_silence, detect_nonsilent, split_on_silence
30
+
31
+
32
def load_audio(audio_path: str, sampling_rate: int):
    """Read an audio file, resample to *sampling_rate*, and downmix to mono.

    Decoding first goes through torchaudio; when that fails (unsupported
    container/codec), it falls back to pydub/ffmpeg.

    Args:
        audio_path: Path of the audio file.
        sampling_rate: Target sampling rate in Hz.

    Returns:
        Waveform tensor of shape (1, T) at the target sampling rate.
    """
    try:
        waveform, source_rate = torchaudio.load(audio_path)
    except (RuntimeError, OSError):
        # torchaudio could not decode; go through pydub/ffmpeg instead.
        segment = AudioSegment.from_file(audio_path)
        # NOTE(review): the 32768 divisor assumes 16-bit samples — confirm
        # for sources with a different sample width.
        samples = np.array(segment.get_array_of_samples()).astype(np.float32) / 32768.0
        if segment.channels == 1:
            waveform = torch.from_numpy(samples).unsqueeze(0)
        else:
            # Interleaved multi-channel data -> (C, T)
            waveform = torch.from_numpy(samples.reshape(-1, segment.channels).T)
        source_rate = segment.frame_rate

    if source_rate != sampling_rate:
        waveform = torchaudio.functional.resample(
            waveform,
            orig_freq=source_rate,
            new_freq=sampling_rate,
        )
    # Downmix to mono by averaging channels.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    return waveform
66
+
67
+
68
def remove_silence(
    audio: torch.Tensor,
    sampling_rate: int,
    mid_sil: int = 300,
    lead_sil: int = 100,
    trail_sil: int = 300,
):
    """Collapse long internal silences and trim silence at both edges.

    Args:
        audio: Tensor of shape (C, T).
        sampling_rate: Sampling rate of *audio*.
        mid_sil: Internal silences longer than this many ms are collapsed;
            values <= 0 disable middle-silence removal.
        lead_sil: Leading silence kept at the start, in ms.
        trail_sil: Trailing silence kept at the end, in ms.

    Returns:
        Tensor of shape (C, T') with silences removed, where C is the number
        of channels and T' the remaining number of samples.
    """
    segment = tensor_to_audiosegment(audio, sampling_rate)

    if mid_sil > 0:
        # Break the audio at silences longer than mid_sil ms, keeping up to
        # mid_sil ms of silence padding around each voiced span.
        voiced_spans = split_on_silence(
            segment,
            min_silence_len=mid_sil,
            silence_thresh=-50,
            keep_silence=mid_sil,
            seek_step=10,
        )

        # Stitch the voiced spans back together.
        segment = AudioSegment.silent(duration=0)
        for span in voiced_spans:
            segment += span

    # Trim edge silence, keeping lead_sil / trail_sil ms of margin.
    segment = remove_silence_edges(segment, lead_sil, trail_sil, -50)

    return audiosegment_to_tensor(segment)
113
+
114
+
115
def remove_silence_edges(
    audio: AudioSegment,
    lead_sil: int = 100,
    trail_sil: int = 300,
    silence_threshold: float = -50,
):
    """Trim silence from both ends of *audio*, keeping short margins.

    Args:
        audio: Input AudioSegment.
        lead_sil: Milliseconds of leading silence kept before the first sound.
        trail_sil: Milliseconds of trailing silence kept after the last sound.
        silence_threshold: Level (dBFS) below which audio counts as silence.

    Returns:
        The trimmed AudioSegment.
    """
    # Leading edge: drop everything before (first sound - lead_sil) ms.
    cut = detect_leading_silence(audio, silence_threshold=silence_threshold)
    audio = audio[max(0, cut - lead_sil):]

    # Trailing edge: reverse, trim the now-leading silence, reverse back.
    audio = audio.reverse()
    cut = detect_leading_silence(audio, silence_threshold=silence_threshold)
    audio = audio[max(0, cut - trail_sil):]
    return audio.reverse()
146
+
147
+
148
def audiosegment_to_tensor(aseg):
    """Convert a pydub ``AudioSegment`` into a float32 tensor of shape (C, T).

    Sample values are scaled into the [-1, 1] range.
    """
    # NOTE(review): the 32768 divisor assumes 16-bit samples
    # (sample_width == 2) — confirm for other widths.
    samples = np.array(aseg.get_array_of_samples()).astype(np.float32) / 32768.0

    if aseg.channels == 1:
        # Mono: (T,) -> (1, T)
        return torch.from_numpy(samples).unsqueeze(0)

    # Multi-channel: interleaved samples -> (C, T)
    return torch.from_numpy(samples.reshape(-1, aseg.channels).T)
166
+
167
+
168
def tensor_to_audiosegment(tensor, sample_rate):
    """Convert a float tensor in [-1, 1] into a 16-bit pydub ``AudioSegment``.

    Args:
        tensor: Tensor of shape (C, T), where C is the number of channels
            and T the number of samples.
        sample_rate: Audio sample rate in Hz.
    """
    assert isinstance(tensor, torch.Tensor)

    # Scale to int16 PCM (pydub's common format), clipping out-of-range values.
    pcm = (tensor.cpu().numpy() * 32768.0).clip(-32768, 32767).astype(np.int16)

    # pydub expects interleaved frames (L R L R ...) for multi-channel audio.
    if pcm.shape[0] > 1:
        pcm = pcm.transpose(1, 0).flatten()

    return AudioSegment(
        data=pcm.tobytes(),
        sample_width=2,  # 16-bit PCM
        frame_rate=sample_rate,
        channels=tensor.shape[0],
    )
202
+
203
+
204
def fade_and_pad_audio(
    audio: torch.Tensor,
    pad_duration: float = 0.1,
    fade_duration: float = 0.1,
    sample_rate: int = 24000,
) -> torch.Tensor:
    """Apply fade-in/out ramps and wrap the signal in silence on both sides.

    Prevents click/pop artifacts caused by abrupt starts and ends. The input
    tensor is never modified; a processed copy is returned.

    Args:
        audio: Tensor of shape (C, T) containing audio data.
        pad_duration: Pure silence added to each end, in seconds.
        fade_duration: Length of the fade-in/out ramps, in seconds.
        sample_rate: Audio sampling rate in Hz.

    Returns:
        Processed tensor of shape (C, T_new).
    """
    if audio.shape[-1] == 0:
        return audio

    n_fade = int(fade_duration * sample_rate)
    n_pad = int(pad_duration * sample_rate)
    out = audio.clone()

    if n_fade > 0:
        # Never let the two ramps overlap: cap at half the signal length.
        ramp_len = min(n_fade, out.shape[-1] // 2)
        if ramp_len > 0:
            rise = torch.linspace(
                0, 1, ramp_len, device=out.device, dtype=out.dtype
            )[None, :]
            fall = torch.linspace(
                1, 0, ramp_len, device=out.device, dtype=out.dtype
            )[None, :]
            out[..., :ramp_len] = out[..., :ramp_len] * rise
            out[..., -ramp_len:] = out[..., -ramp_len:] * fall

    if n_pad > 0:
        pad = torch.zeros(
            (out.shape[0], n_pad),
            dtype=out.dtype,
            device=out.device,
        )
        out = torch.cat([pad, out, pad], dim=-1)

    return out
254
+
255
+
256
def trim_long_audio(
    audio: torch.Tensor,
    sampling_rate: int,
    max_duration: float = 15.0,
    min_duration: float = 3.0,
    trim_threshold: float = 20.0,
) -> torch.Tensor:
    """Shorten over-long audio by cutting at a silence gap near *max_duration*.

    Audio at or below *trim_threshold* seconds is returned untouched.

    Args:
        audio: Tensor of shape (C, T).
        sampling_rate: Sampling rate of *audio*.
        max_duration: Upper bound for the kept portion, in seconds.
        min_duration: Lower bound for the kept portion, in seconds.
        trim_threshold: Only trim audio longer than this (seconds).

    Returns:
        The (possibly trimmed) audio tensor.
    """
    if audio.size(-1) / sampling_rate <= trim_threshold:
        return audio

    segment = tensor_to_audiosegment(audio, sampling_rate)
    speech_regions = detect_nonsilent(
        segment, min_silence_len=100, silence_thresh=-40, seek_step=10
    )
    if not speech_regions:
        return audio

    limit_ms = int(max_duration * 1000)
    floor_ms = int(min_duration * 1000)

    # Pick the latest speech-region start that still fits under the limit,
    # so the cut lands in a silence gap rather than mid-speech.
    split_ms = 0
    for region_start, region_end in speech_regions:
        if split_ms < region_start <= limit_ms:
            split_ms = region_start
        if region_end > limit_ms:
            break

    # No suitable gap (or split is too early): hard-cut at the limit.
    if split_ms < floor_ms:
        split_ms = min(limit_ms, len(segment))

    return audiosegment_to_tensor(segment[:split_ms])
304
+
305
+
306
def cross_fade_chunks(
    chunks: list[torch.Tensor],
    sample_rate: int,
    silence_duration: float = 0.3,
) -> torch.Tensor:
    """Join audio chunks with faded edges around a short silent gap.

    Every boundary becomes: fade-out tail -> silence buffer -> fade-in head,
    which avoids click artifacts from hard concatenation. Input chunks are
    not modified.

    Args:
        chunks: List of audio tensors, each of shape (C, T).
        sample_rate: Audio sample rate in Hz.
        silence_duration: Total gap duration at each boundary, in seconds.

    Returns:
        Merged audio tensor of shape (C, T_total).
    """
    if len(chunks) == 1:
        return chunks[0]

    # Split the gap budget into thirds: fade-out, silence, fade-in.
    gap_total = int(silence_duration * sample_rate)
    ramp_len = gap_total // 3
    gap_len = ramp_len  # length of the silent middle section
    out = chunks[0].clone()

    for incoming in chunks[1:]:
        device, dtype = out.device, out.dtype

        # Taper the tail of everything merged so far (in place on our copy).
        tail = min(ramp_len, out.size(-1))
        if tail > 0:
            fall = torch.linspace(1, 0, tail, device=device, dtype=dtype)[None, :]
            out[..., -tail:] = out[..., -tail:] * fall

        # Taper the head of the incoming chunk on a private copy.
        head_chunk = incoming.clone()
        head = min(ramp_len, head_chunk.size(-1))
        if head > 0:
            rise = torch.linspace(0, 1, head, device=device, dtype=dtype)[None, :]
            head_chunk[..., :head] = head_chunk[..., :head] * rise

        gap = torch.zeros(chunks[0].shape[0], gap_len, device=device, dtype=dtype)
        out = torch.cat([out, gap, head_chunk], dim=-1)

    return out
omnivoice/utils/common.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Shared utility functions."""
19
+
20
+ import argparse
21
+ import random
22
+
23
+ import numpy as np
24
+ import torch
25
+
26
+
27
def str2bool(v):
    """argparse ``type=`` helper that parses a boolean flag value.

    Accepted spellings (case-insensitive):

    - yes, true, t, y, 1  -> True
    - no, false, f, n, 0  -> False

    A real ``bool`` is passed through unchanged.  Any other value raises
    ``argparse.ArgumentTypeError``.

    See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
44
+
45
+
46
def fix_random_seed(random_seed: int):
    """Seed all relevant RNGs for reproducibility.

    Seeds the ``random`` module, NumPy's global generator, and torch's
    default generator with the same value.
    """
    for seeder in (random.seed, np.random.seed, torch.random.manual_seed):
        seeder(random_seed)
    # Ensure deterministic ID creation
    local_rng = random.Random()
    local_rng.seed(random_seed)
omnivoice/utils/data_utils.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Data utilities for batch inference and evaluation.
19
+
20
+ Provides ``read_test_list()`` to parse JSONL test list files used by
21
+ ``omnivoice.cli.infer_batch`` and evaluation scripts.
22
+ """
23
+
24
+ import json
25
+ import logging
26
+ from pathlib import Path
27
+
28
+
29
def read_test_list(path):
    """Parse a JSONL test list file into a list of sample dicts.

    Every non-empty line must be a JSON object; malformed lines are logged
    and skipped.  Each returned dict carries the fields

        id, text, ref_audio, ref_text, language_id, language_name,
        duration, speed

    with ``None`` substituted for anything missing from the source line.
    """
    fields = (
        "id",
        "text",
        "ref_audio",
        "ref_text",
        "language_id",
        "language_name",
        "duration",
        "speed",
    )
    samples = []
    with Path(path).open("r", encoding="utf-8") as fin:
        for line_no, raw in enumerate(fin, 1):
            line = raw.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                logging.warning(f"Skipping malformed JSON at line {line_no}: {line}")
                continue
            samples.append({key: obj.get(key) for key in fields})
    return samples
omnivoice/utils/duration.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Text duration estimation for TTS generation.
19
+
20
+ Provides ``RuleDurationEstimator``, which estimates audio duration from text
21
+ using character phonetic weights across 600+ languages. Used by
22
+ ``OmniVoice.generate()`` to determine output length when no duration is specified.
23
+ """
24
+
25
+ import bisect
26
+ import unicodedata
27
+ from functools import lru_cache
28
+ from typing import Optional
29
+
30
+
31
class RuleDurationEstimator:
    """Estimate speech duration from raw text via per-script character weights.

    Each character is mapped to a phonetic weight relative to a Latin letter
    (``self.weights``), with the character's script resolved by Unicode
    category checks plus a binary search over ``self.ranges``.  A reference
    (text, duration) pair calibrates the speaking rate, which is then applied
    to the target text in :meth:`estimate_duration`.
    """

    def __init__(self) -> None:
        # ==========================================
        # 1. Phonetic Weights Table
        # ==========================================
        # The weight represents the relative speaking time compared to
        # a standard Latin letter.
        # Benchmark: 1.0 = One Latin Character (~40-50ms)
        self.weights = {
            # --- Logographic (1 char = full syllable/word) ---
            "cjk": 3.0,  # Chinese, Japanese Kanji, etc.
            # --- Syllabic / Blocks
            "hangul": 2.5,  # Korean Hangul
            "kana": 2.2,  # Japanese Hiragana/Katakana
            "ethiopic": 3.0,  # Amharic/Ge'ez
            "yi": 3.0,  # Yi script
            # --- Abugida (Consonant-Vowel complexes) ---
            "indic": 1.8,  # Hindi, Bengali, Tamil, etc.
            "thai_lao": 1.5,  # Thai, Lao
            "khmer_myanmar": 1.8,  # Khmer, Myanmar
            # --- Abjad (Consonant-heavy) ---
            "arabic": 1.5,  # Arabic, Persian, Urdu
            "hebrew": 1.5,  # Hebrew
            # --- Alphabet (Segmental) ---
            "latin": 1.0,  # English, Spanish, French, Vietnamese, etc. (Baseline)
            "cyrillic": 1.0,  # Russian, Ukrainian
            "greek": 1.0,  # Greek
            "armenian": 1.0,  # Armenian
            "georgian": 1.0,  # Georgian
            # --- Symbols & Misc ---
            "punctuation": 0.5,  # Pause capability
            "space": 0.2,  # Word boundary/Breath (0.05 / 0.22)
            "digit": 3.5,  # Numbers
            "mark": 0.0,  # Diacritics/Accents (Silent modifiers)
            "default": 1.0,  # Fallback for unknown scripts
        }

        # ==========================================
        # 2. Unicode Range Mapping
        # ==========================================
        # Format: (End_Codepoint, Type_Key)
        # Each entry covers codepoints up to and including End_Codepoint,
        # starting just past the previous entry's end.
        # Used for fast binary search (bisect).
        self.ranges = [
            (0x02AF, "latin"),  # Latin (Basic, Supplement, Ext, IPA)
            (0x03FF, "greek"),  # Greek & Coptic
            (0x052F, "cyrillic"),  # Cyrillic
            (0x058F, "armenian"),  # Armenian
            (0x05FF, "hebrew"),  # Hebrew
            (0x077F, "arabic"),  # Arabic, Syriac, Arabic Supplement
            (0x089F, "arabic"),  # Arabic Extended-B (+ Syriac Supp)
            (0x08FF, "arabic"),  # Arabic Extended-A
            (0x097F, "indic"),  # Devanagari
            (0x09FF, "indic"),  # Bengali
            (0x0A7F, "indic"),  # Gurmukhi
            (0x0AFF, "indic"),  # Gujarati
            (0x0B7F, "indic"),  # Oriya
            (0x0BFF, "indic"),  # Tamil
            (0x0C7F, "indic"),  # Telugu
            (0x0CFF, "indic"),  # Kannada
            (0x0D7F, "indic"),  # Malayalam
            (0x0DFF, "indic"),  # Sinhala
            (0x0EFF, "thai_lao"),  # Thai & Lao
            (0x0FFF, "indic"),  # Tibetan (Abugida)
            (0x109F, "khmer_myanmar"),  # Myanmar
            (0x10FF, "georgian"),  # Georgian
            (0x11FF, "hangul"),  # Hangul Jamo
            (0x137F, "ethiopic"),  # Ethiopic
            (0x139F, "ethiopic"),  # Ethiopic Supplement
            (0x13FF, "default"),  # Cherokee
            (0x167F, "default"),  # Canadian Aboriginal Syllabics
            (0x169F, "default"),  # Ogham
            (0x16FF, "default"),  # Runic
            (0x171F, "default"),  # Tagalog (Baybayin)
            (0x173F, "default"),  # Hanunoo
            (0x175F, "default"),  # Buhid
            (0x177F, "default"),  # Tagbanwa
            (0x17FF, "khmer_myanmar"),  # Khmer
            (0x18AF, "default"),  # Mongolian
            (0x18FF, "default"),  # Canadian Aboriginal Syllabics Ext
            (0x194F, "indic"),  # Limbu
            (0x19DF, "indic"),  # Tai Le & New Tai Lue
            (0x19FF, "khmer_myanmar"),  # Khmer Symbols
            (0x1A1F, "indic"),  # Buginese
            (0x1AAF, "indic"),  # Tai Tham
            (0x1B7F, "indic"),  # Balinese
            (0x1BBF, "indic"),  # Sundanese
            (0x1BFF, "indic"),  # Batak
            (0x1C4F, "indic"),  # Lepcha
            (0x1C7F, "indic"),  # Ol Chiki (Santali)
            (0x1C8F, "cyrillic"),  # Cyrillic Extended-C
            (0x1CBF, "georgian"),  # Georgian Extended
            (0x1CCF, "indic"),  # Sundanese Supplement
            (0x1CFF, "indic"),  # Vedic Extensions
            (0x1D7F, "latin"),  # Phonetic Extensions
            (0x1DBF, "latin"),  # Phonetic Extensions Supplement
            (0x1DFF, "default"),  # Combining Diacritical Marks Supplement
            (0x1EFF, "latin"),  # Latin Extended Additional (Vietnamese)
            (0x309F, "kana"),  # Hiragana
            (0x30FF, "kana"),  # Katakana
            (0x312F, "cjk"),  # Bopomofo (Pinyin)
            (0x318F, "hangul"),  # Hangul Compatibility Jamo
            (0x9FFF, "cjk"),  # CJK Unified Ideographs (Main)
            (0xA4CF, "yi"),  # Yi Syllables
            (0xA4FF, "default"),  # Lisu
            (0xA63F, "default"),  # Vai
            (0xA69F, "cyrillic"),  # Cyrillic Extended-B
            (0xA6FF, "default"),  # Bamum
            (0xA7FF, "latin"),  # Latin Extended-D
            (0xA82F, "indic"),  # Syloti Nagri
            (0xA87F, "default"),  # Phags-pa
            (0xA8DF, "indic"),  # Saurashtra
            (0xA8FF, "indic"),  # Devanagari Extended
            (0xA92F, "indic"),  # Kayah Li
            (0xA95F, "indic"),  # Rejang
            (0xA97F, "hangul"),  # Hangul Jamo Extended-A
            (0xA9DF, "indic"),  # Javanese
            (0xA9FF, "khmer_myanmar"),  # Myanmar Extended-B
            (0xAA5F, "indic"),  # Cham
            (0xAA7F, "khmer_myanmar"),  # Myanmar Extended-A
            (0xAADF, "indic"),  # Tai Viet
            (0xAAFF, "indic"),  # Meetei Mayek Extensions
            (0xAB2F, "ethiopic"),  # Ethiopic Extended-A
            (0xAB6F, "latin"),  # Latin Extended-E
            (0xABBF, "default"),  # Cherokee Supplement
            (0xABFF, "indic"),  # Meetei Mayek
            (0xD7AF, "hangul"),  # Hangul Syllables
            (0xFAFF, "cjk"),  # CJK Compatibility
            (0xFDFF, "arabic"),  # Arabic Presentation Forms-A
            (0xFE6F, "default"),  # Variation Selectors
            (0xFEFF, "arabic"),  # Arabic Presentation Forms-B
            (0xFFEF, "latin"),  # Fullwidth Latin
        ]
        # Sorted upper bounds extracted for bisect in _get_char_weight.
        self.breakpoints = [r[0] for r in self.ranges]

    # NOTE(review): lru_cache on an instance method keys on ``self`` and keeps
    # the instance alive for the cache's lifetime; acceptable if the estimator
    # is a long-lived singleton — confirm usage pattern.
    @lru_cache(maxsize=4096)
    def _get_char_weight(self, char: str) -> float:
        """Determines the weight of a single character."""
        code = ord(char)
        # Fast path: plain ASCII letters and space, no category lookup needed.
        if (65 <= code <= 90) or (97 <= code <= 122):
            return self.weights["latin"]
        if code == 32:
            return self.weights["space"]

        # Ignore arabic Tatweel (a purely visual elongation character)
        if code == 0x0640:
            return self.weights["mark"]

        category = unicodedata.category(char)

        # Combining marks (Mn/Mc/Me) are silent modifiers.
        if category.startswith("M"):
            return self.weights["mark"]

        # Punctuation and symbols (includes emoji, category "So").
        if category.startswith("P") or category.startswith("S"):
            return self.weights["punctuation"]

        # Separators (spaces other than ASCII 0x20).
        if category.startswith("Z"):
            return self.weights["space"]

        # Digits and other numeric characters are spoken as full words.
        if category.startswith("N"):
            return self.weights["digit"]

        # 3. Binary search for the Unicode block.  Punctuation/symbols were
        # already handled above, so the ranges only ever see script letters.
        idx = bisect.bisect_left(self.breakpoints, code)
        if idx < len(self.ranges):
            script_type = self.ranges[idx][1]
            return self.weights.get(script_type, self.weights["default"])

        # 4. Handle upper planes (CJK Ext B/C/D, Historic scripts)
        if code > 0x20000:
            return self.weights["cjk"]

        return self.weights["default"]

    def calculate_total_weight(self, text: str) -> float:
        """Sums up the normalized weights for a string."""
        return sum(self._get_char_weight(c) for c in text)

    def estimate_duration(
        self,
        target_text: str,
        ref_text: str,
        ref_duration: float,
        low_threshold: Optional[float] = 50,
        boost_strength: float = 3,
    ) -> float:
        """Estimate how long ``target_text`` takes to speak, calibrated on a
        reference utterance.

        The reference pair fixes a weight-per-duration speaking rate, which
        is applied to the target's total weight.  Short estimates are boosted
        toward ``low_threshold`` with a power curve.

        Args:
            target_text (str): The text for which we want to estimate the duration.
            ref_text (str): The reference text that was used to measure
                the ref_duration.
            ref_duration (float): The actual duration it took
                to speak the ref_text.
            low_threshold (float): The minimum duration threshold below which the
                estimation will be considered unreliable.
            boost_strength (float): Controls the power-curve boost for short
                durations. Higher values boost small durations more aggressively.
                1 = no boost (linear), 2 = sqrt-like

        Returns:
            float: The estimated duration for the target_text based
            on the ref_text and ref_duration.  Returns 0.0 when the
            reference is unusable (non-positive duration, empty text,
            or zero total weight).
        """
        # NOTE(review): low_threshold=50 shares ref_duration's unit — looks
        # tuned for frame-like units rather than seconds; confirm with callers.
        if ref_duration <= 0 or not ref_text:
            return 0.0

        ref_weight = self.calculate_total_weight(ref_text)
        if ref_weight == 0:
            return 0.0

        # Speaking rate of the reference: weight units per duration unit.
        speed_factor = ref_weight / ref_duration
        target_weight = self.calculate_total_weight(target_text)

        estimated_duration = target_weight / speed_factor
        # Power-curve boost: pulls short estimates up toward low_threshold
        # (x/T)**(1/k) * T is identity at x == T and > x for x < T, k > 1.
        if low_threshold is not None and estimated_duration < low_threshold:
            alpha = 1.0 / boost_strength
            return low_threshold * (estimated_duration / low_threshold) ** alpha
        else:
            return estimated_duration
250
+
251
+
252
# ==========================================
# Example Usage
# ==========================================
if __name__ == "__main__":
    duration_estimator = RuleDurationEstimator()

    # A short English utterance anchors the speaking rate.
    reference_text = "Hello, world."
    reference_duration = 1.5

    samples = [
        ("Hindi (With complex marks)", "नमस्ते दुनिया"),
        ("Arabic (With vowels)", "مَرْحَبًا بِالْعَالَم"),
        ("Vietnamese (Lots of diacritics)", "Chào thế giới"),
        ("Chinese", "你好,世界!"),
        ("Mixed Emoji", "Hello 🌍! This is fun 🎉"),
    ]

    separator = "-" * 30
    print("--- Reference ---")
    print(f"Reference Text: '{reference_text}'")
    print(f"Reference Duration: {reference_duration}s")
    print(separator)

    for label, sample_text in samples:
        estimated = duration_estimator.estimate_duration(
            sample_text, reference_text, reference_duration
        )
        total_weight = duration_estimator.calculate_total_weight(sample_text)

        print(f"[{label}]")
        print(f"Text: {sample_text}")
        print(f"Total Weight: {total_weight:.2f}")
        print(f"Estimated Duration: {estimated:.2f} s")
        print(separator)
omnivoice/utils/lang_map.py ADDED
@@ -0,0 +1,698 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Language name to ISO 639-3 code mapping.
19
+
20
+ Auto-generated from ``docs/lang_id_name_map.tsv``. Provides ``LANG_NAME_TO_ID``
21
+ (for resolving language names to codes) and ``LANG_IDS`` (the set of supported
22
+ ISO 639-3 codes). Used by ``OmniVoice.generate()`` to resolve user-provided
23
+ language names.
24
+ """
25
+
26
+ # Auto-generated from docs/lang_id_name_map.tsv
27
+ # Maps lowercase language name -> language ID code
28
+
29
+ LANG_NAME_TO_ID = {
30
+ "abadi": "kbt",
31
+ "abkhazian": "ab",
32
+ "abron": "abr",
33
+ "abua": "abn",
34
+ "adamawa fulfulde": "fub",
35
+ "adyghe": "ady",
36
+ "afade": "aal",
37
+ "afrikaans": "af",
38
+ "agwagwune": "yay",
39
+ "aja (benin)": "ajg",
40
+ "akebu": "keu",
41
+ "alago": "ala",
42
+ "albanian": "sq",
43
+ "algerian arabic": "arq",
44
+ "algerian saharan arabic": "aao",
45
+ "ambo-pasco quechua": "qva",
46
+ "ambonese malay": "abs",
47
+ "amdo tibetan": "adx",
48
+ "amharic": "am",
49
+ "anaang": "anw",
50
+ "angika": "anp",
51
+ "antankarana malagasy": "xmv",
52
+ "aragonese": "an",
53
+ "arbëreshë albanian": "aae",
54
+ "arequipa-la unión quechua": "qxu",
55
+ "armenian": "hy",
56
+ "ashe": "ahs",
57
+ "ashéninka perené": "prq",
58
+ "askopan": "eiv",
59
+ "assamese": "as",
60
+ "asturian": "ast",
61
+ "atayal": "tay",
62
+ "awak": "awo",
63
+ "ayacucho quechua": "quy",
64
+ "azerbaijani": "az",
65
+ "baatonum": "bba",
66
+ "bacama": "bcy",
67
+ "bade": "bde",
68
+ "bafia": "ksf",
69
+ "bafut": "bfd",
70
+ "bagirmi fulfulde": "fui",
71
+ "bago-kusuntu": "bqg",
72
+ "baharna arabic": "abv",
73
+ "bakoko": "bkh",
74
+ "balanta-ganja": "bjt",
75
+ "balti": "bft",
76
+ "bamenyam": "bce",
77
+ "bamun": "bax",
78
+ "bangwinji": "bsj",
79
+ "banjar": "bjn",
80
+ "bankon": "abb",
81
+ "baoulé": "bci",
82
+ "bara malagasy": "bhr",
83
+ "barok": "bjk",
84
+ "basa (cameroon)": "bas",
85
+ "basa (nigeria)": "bzw",
86
+ "bashkir": "ba",
87
+ "basque": "eu",
88
+ "batak mandailing": "btm",
89
+ "batanga": "bnm",
90
+ "bateri": "btv",
91
+ "bats": "bbl",
92
+ "bayot": "bda",
93
+ "bebele": "beb",
94
+ "belarusian": "be",
95
+ "bengali": "bn",
96
+ "betawi": "bew",
97
+ "bhili": "bhb",
98
+ "bhojpuri": "bho",
99
+ "bilur": "bxf",
100
+ "bima": "bhp",
101
+ "bodo": "brx",
102
+ "boghom": "bux",
103
+ "bokyi": "bky",
104
+ "bomu": "bmq",
105
+ "bondei": "bou",
106
+ "borgu fulfulde": "fue",
107
+ "bosnian": "bs",
108
+ "brahui": "brh",
109
+ "braj": "bra",
110
+ "breton": "br",
111
+ "buduma": "bdm",
112
+ "buginese": "bug",
113
+ "bukharic": "bhh",
114
+ "bulgarian": "bg",
115
+ "bulu (cameroon)": "bum",
116
+ "bundeli": "bns",
117
+ "bunun": "bnn",
118
+ "bura-pabir": "bwr",
119
+ "burak": "bys",
120
+ "burmese": "my",
121
+ "burushaski": "bsk",
122
+ "cacaloxtepec mixtec": "miu",
123
+ "cajatambo north lima quechua": "qvl",
124
+ "cakfem-mushere": "cky",
125
+ "cameroon pidgin": "wes",
126
+ "campidanese sardinian": "sro",
127
+ "cantonese": "yue",
128
+ "catalan": "ca",
129
+ "cebuano": "ceb",
130
+ "cen": "cen",
131
+ "central kurdish": "ckb",
132
+ "central nahuatl": "nhn",
133
+ "central pame": "pbs",
134
+ "central pashto": "pst",
135
+ "central puebla nahuatl": "ncx",
136
+ "central tarahumara": "tar",
137
+ "central yupik": "esu",
138
+ "central-eastern niger fulfulde": "fuq",
139
+ "chadian arabic": "shu",
140
+ "chichewa": "ny",
141
+ "chichicapan zapotec": "zpv",
142
+ "chiga": "cgg",
143
+ "chimalapa zoque": "zoh",
144
+ "chimborazo highland quichua": "qug",
145
+ "chinese": "zh",
146
+ "chiquián ancash quechua": "qxa",
147
+ "chitwania tharu": "the",
148
+ "chokwe": "cjk",
149
+ "chuvash": "cv",
150
+ "cibak": "ckl",
151
+ "coastal konjo": "kjc",
152
+ "copainalá zoque": "zoc",
153
+ "cornish": "kw",
154
+ "corongo ancash quechua": "qwa",
155
+ "croatian": "hr",
156
+ "cross river mbembe": "mfn",
157
+ "cuyamecalco mixtec": "xtu",
158
+ "czech": "cs",
159
+ "dadiya": "dbd",
160
+ "dagbani": "dag",
161
+ "dameli": "dml",
162
+ "danish": "da",
163
+ "dargwa": "dar",
164
+ "dazaga": "dzg",
165
+ "deccan": "dcc",
166
+ "degema": "deg",
167
+ "dera (nigeria)": "kna",
168
+ "dghwede": "dgh",
169
+ "dhatki": "mki",
170
+ "dhivehi": "dv",
171
+ "dhofari arabic": "adf",
172
+ "dijim-bwilim": "cfa",
173
+ "dogri": "dgo",
174
+ "domaaki": "dmk",
175
+ "dotyali": "dty",
176
+ "duala": "dua",
177
+ "dutch": "nl",
178
+ "dũya": "ldb",
179
+ "dyula": "dyu",
180
+ "eastern balochi": "bgp",
181
+ "eastern bolivian guaraní": "gui",
182
+ "eastern egyptian bedawi arabic": "avl",
183
+ "eastern krahn": "kqo",
184
+ "eastern mari": "mhr",
185
+ "eastern yiddish": "ydd",
186
+ "ebrié": "ebr",
187
+ "eggon": "ego",
188
+ "egyptian arabic": "arz",
189
+ "ejagham": "etu",
190
+ "eleme": "elm",
191
+ "eloyi": "afo",
192
+ "embu": "ebu",
193
+ "english": "en",
194
+ "erzya": "myv",
195
+ "esan": "ish",
196
+ "esperanto": "eo",
197
+ "estonian": "et",
198
+ "eton (cameroon)": "eto",
199
+ "ewondo": "ewo",
200
+ "extremaduran": "ext",
201
+ "fang (equatorial guinea)": "fan",
202
+ "fanti": "fat",
203
+ "farefare": "gur",
204
+ "fe'fe'": "fmp",
205
+ "filipino": "fil",
206
+ "filomena mata-coahuitlán totonac": "tlp",
207
+ "finnish": "fi",
208
+ "fipa": "fip",
209
+ "french": "fr",
210
+ "fulah": "ff",
211
+ "galician": "gl",
212
+ "gambian wolof": "wof",
213
+ "ganda": "lg",
214
+ "garhwali": "gbm",
215
+ "gawar-bati": "gwt",
216
+ "gawri": "gwc",
217
+ "gbagyi": "gbr",
218
+ "gbari": "gby",
219
+ "geji": "gyz",
220
+ "gen": "gej",
221
+ "georgian": "ka",
222
+ "german": "de",
223
+ "geser-gorom": "ges",
224
+ "gheg albanian": "aln",
225
+ "ghomálá'": "bbj",
226
+ "gidar": "gid",
227
+ "glavda": "glw",
228
+ "goan konkani": "gom",
229
+ "goaria": "gig",
230
+ "goemai": "ank",
231
+ "gola": "gol",
232
+ "greek": "el",
233
+ "guarani": "gn",
234
+ "guduf-gava": "gdf",
235
+ "guerrero amuzgo": "amu",
236
+ "gujarati": "gu",
237
+ "gujari": "gju",
238
+ "gulf arabic": "afb",
239
+ "gurgula": "ggg",
240
+ "gusii": "guz",
241
+ "gusilay": "gsl",
242
+ "gweno": "gwe",
243
+ "güilá zapotec": "ztu",
244
+ "hadothi": "hoj",
245
+ "hahon": "hah",
246
+ "haitian": "ht",
247
+ "hakha chin": "cnh",
248
+ "hakö": "hao",
249
+ "halia": "hla",
250
+ "hausa": "ha",
251
+ "hawaiian": "haw",
252
+ "hazaragi": "haz",
253
+ "hebrew": "he",
254
+ "hemba": "hem",
255
+ "herero": "hz",
256
+ "highland konjo": "kjk",
257
+ "hijazi arabic": "acw",
258
+ "hindi": "hi",
259
+ "huarijio": "var",
260
+ "huautla mazatec": "mau",
261
+ "huaxcaleca nahuatl": "nhq",
262
+ "huba": "hbb",
263
+ "huitepec mixtec": "mxs",
264
+ "hula": "hul",
265
+ "hungarian": "hu",
266
+ "hunjara-kaina ke": "hkk",
267
+ "hwana": "hwo",
268
+ "ibibio": "ibb",
269
+ "icelandic": "is",
270
+ "idakho-isukha-tiriki": "ida",
271
+ "idoma": "idu",
272
+ "igbo": "ig",
273
+ "igo": "ahl",
274
+ "ikposo": "kpo",
275
+ "ikwere": "ikw",
276
+ "imbabura highland quichua": "qvi",
277
+ "indonesian": "id",
278
+ "indus kohistani": "mvy",
279
+ "interlingua (international auxiliary language association)": "ia",
280
+ "inupiaq": "ik",
281
+ "irish": "ga",
282
+ "iron ossetic": "os",
283
+ "isekiri": "its",
284
+ "isoko": "iso",
285
+ "italian": "it",
286
+ "ito": "itw",
287
+ "itzá": "itz",
288
+ "ixtayutla mixtec": "vmj",
289
+ "izon": "ijc",
290
+ "jambi malay": "jax",
291
+ "japanese": "ja",
292
+ "jaqaru": "jqr",
293
+ "jauja wanca quechua": "qxw",
294
+ "jaunsari": "jns",
295
+ "javanese": "jv",
296
+ "jiba": "juo",
297
+ "jju": "kaj",
298
+ "judeo-moroccan arabic": "aju",
299
+ "juxtlahuaca mixtec": "vmc",
300
+ "kabardian": "kbd",
301
+ "kabras": "lkb",
302
+ "kabuverdianu": "kea",
303
+ "kabyle": "kab",
304
+ "kachi koli": "gjk",
305
+ "kairak": "ckr",
306
+ "kalabari": "ijn",
307
+ "kalasha": "kls",
308
+ "kalenjin": "kln",
309
+ "kalkoti": "xka",
310
+ "kamba": "kam",
311
+ "kamo": "kcq",
312
+ "kanauji": "bjj",
313
+ "kanembu": "kbl",
314
+ "kannada": "kn",
315
+ "karekare": "kai",
316
+ "kashmiri": "ks",
317
+ "kathoriya tharu": "tkt",
318
+ "kati": "bsh",
319
+ "kazakh": "kk",
320
+ "keiyo": "eyo",
321
+ "khams tibetan": "khg",
322
+ "khana": "ogo",
323
+ "khetrani": "xhe",
324
+ "khmer": "km",
325
+ "khowar": "khw",
326
+ "kinga": "zga",
327
+ "kinnauri": "kfk",
328
+ "kinyarwanda": "rw",
329
+ "kirghiz": "ky",
330
+ "kirya-konzəl": "fkk",
331
+ "kochila tharu": "thq",
332
+ "kohistani shina": "plk",
333
+ "kohumono": "bcs",
334
+ "kok borok": "trp",
335
+ "kol (papua new guinea)": "kol",
336
+ "kom (cameroon)": "bkm",
337
+ "koma": "kmy",
338
+ "konkani": "knn",
339
+ "konzo": "koo",
340
+ "korean": "ko",
341
+ "korwa": "kfp",
342
+ "kota (india)": "kfe",
343
+ "koti": "eko",
344
+ "kuanua": "ksd",
345
+ "kuanyama": "kj",
346
+ "kui (india)": "uki",
347
+ "kulung (nigeria)": "bbu",
348
+ "kuot": "kto",
349
+ "kushi": "kuh",
350
+ "kwambi": "kwm",
351
+ "kwasio": "nmg",
352
+ "lala-roba": "lla",
353
+ "lamang": "hia",
354
+ "lao": "lo",
355
+ "larike-wakasihu": "alo",
356
+ "lasi": "lss",
357
+ "latgalian": "ltg",
358
+ "latvian": "lv",
359
+ "levantine arabic": "apc",
360
+ "liana-seti": "ste",
361
+ "liberia kpelle": "xpe",
362
+ "liberian english": "lir",
363
+ "libyan arabic": "ayl",
364
+ "ligurian": "lij",
365
+ "lijili": "mgi",
366
+ "lingala": "ln",
367
+ "lithuanian": "lt",
368
+ "loarki": "lrk",
369
+ "logooli": "rag",
370
+ "logudorese sardinian": "src",
371
+ "loja highland quichua": "qvj",
372
+ "loloda": "loa",
373
+ "longuda": "lnu",
374
+ "loxicha zapotec": "ztp",
375
+ "luba-lulua": "lua",
376
+ "luo": "luo",
377
+ "lushai": "lus",
378
+ "luxembourgish": "lb",
379
+ "maasina fulfulde": "ffm",
380
+ "maba (chad)": "mde",
381
+ "macedo-romanian": "rup",
382
+ "macedonian": "mk",
383
+ "mada (cameroon)": "mxu",
384
+ "mafa": "maf",
385
+ "maithili": "mai",
386
+ "malay": "ms",
387
+ "malayalam": "ml",
388
+ "mali": "gcc",
389
+ "malinaltepec me'phaa": "tcf",
390
+ "maltese": "mt",
391
+ "mandara": "tbf",
392
+ "mandjak": "mfv",
393
+ "manggarai": "mqy",
394
+ "manipuri": "mni",
395
+ "mansoanka": "msw",
396
+ "manx": "gv",
397
+ "maori": "mi",
398
+ "marathi": "mr",
399
+ "marghi central": "mrt",
400
+ "marghi south": "mfm",
401
+ "maria (india)": "mrr",
402
+ "marwari (pakistan)": "mve",
403
+ "masana": "mcn",
404
+ "masikoro malagasy": "msh",
405
+ "matsés": "mcf",
406
+ "mazaltepec zapotec": "zpy",
407
+ "mazatlán mazatec": "vmz",
408
+ "mazatlán mixe": "mzl",
409
+ "mbe": "mfo",
410
+ "mbo (cameroon)": "mbo",
411
+ "mbum": "mdd",
412
+ "medumba": "byv",
413
+ "mekeo": "mek",
414
+ "meru": "mer",
415
+ "mesopotamian arabic": "acm",
416
+ "mewari": "mtr",
417
+ "min nan chinese": "nan",
418
+ "mingrelian": "xmf",
419
+ "mitlatongo mixtec": "vmm",
420
+ "miya": "mkf",
421
+ "mokpwe": "bri",
422
+ "moksha": "mdf",
423
+ "mom jango": "ver",
424
+ "mongolian": "mn",
425
+ "moroccan arabic": "ary",
426
+ "motu": "meu",
427
+ "mpiemo": "mcx",
428
+ "mpumpong": "mgg",
429
+ "mundang": "mua",
430
+ "mungaka": "mhk",
431
+ "musey": "mse",
432
+ "musgu": "mug",
433
+ "musi": "mui",
434
+ "naba": "mne",
435
+ "najdi arabic": "ars",
436
+ "nalik": "nal",
437
+ "nawdm": "nmz",
438
+ "ndonga": "ng",
439
+ "neapolitan": "nap",
440
+ "nepali": "npi",
441
+ "ngamo": "nbh",
442
+ "ngas": "anc",
443
+ "ngiemboon": "nnh",
444
+ "ngizim": "ngi",
445
+ "ngomba": "jgo",
446
+ "ngombale": "nla",
447
+ "nigerian fulfulde": "fuv",
448
+ "nigerian pidgin": "pcm",
449
+ "nimadi": "noe",
450
+ "nobiin": "fia",
451
+ "north mesopotamian arabic": "ayp",
452
+ "north moluccan malay": "max",
453
+ "northern betsimisaraka malagasy": "bmm",
454
+ "northern hindko": "hno",
455
+ "northern kurdish": "kmr",
456
+ "northern pame": "pmq",
457
+ "northern pashto": "pbu",
458
+ "northern uzbek": "uzn",
459
+ "northwest gbaya": "gya",
460
+ "norwegian": "no",
461
+ "norwegian bokmål": "nb",
462
+ "norwegian nynorsk": "nn",
463
+ "notsi": "ncf",
464
+ "nyankpa": "yes",
465
+ "nyungwe": "nyu",
466
+ "nzanyi": "nja",
467
+ "nüpode huitoto": "hux",
468
+ "occitan": "oc",
469
+ "od": "odk",
470
+ "odia": "ory",
471
+ "odual": "odu",
472
+ "omani arabic": "acx",
473
+ "orizaba nahuatl": "nlv",
474
+ "orma": "orc",
475
+ "ormuri": "oru",
476
+ "oromo": "om",
477
+ "pahari-potwari": "phr",
478
+ "paiwan": "pwn",
479
+ "panjabi": "pa",
480
+ "papuan malay": "pmy",
481
+ "parkari koli": "kvx",
482
+ "pedi": "nso",
483
+ "pero": "pip",
484
+ "persian": "fa",
485
+ "petats": "pex",
486
+ "phalura": "phl",
487
+ "piemontese": "pms",
488
+ "piya-kwonci": "piy",
489
+ "plateau malagasy": "plt",
490
+ "polish": "pl",
491
+ "poqomam": "poc",
492
+ "portuguese": "pt",
493
+ "pulaar": "fuc",
494
+ "pular": "fuf",
495
+ "puno quechua": "qxp",
496
+ "pushto": "ps",
497
+ "pökoot": "pko",
498
+ "qaqet": "byx",
499
+ "quiotepec chinantec": "chq",
500
+ "rana tharu": "thr",
501
+ "rangi": "lag",
502
+ "rapoisi": "kyx",
503
+ "ratahan": "rth",
504
+ "rayón zoque": "zor",
505
+ "romanian": "ro",
506
+ "romansh": "rm",
507
+ "rombo": "rof",
508
+ "rotokas": "roo",
509
+ "rukai": "dru",
510
+ "russian": "ru",
511
+ "sacapulteco": "quv",
512
+ "saidi arabic": "aec",
513
+ "sakalava malagasy": "skg",
514
+ "sakizaya": "szy",
515
+ "saleman": "sau",
516
+ "samba daka": "ccg",
517
+ "samba leko": "ndi",
518
+ "san felipe otlaltepec popoloca": "pow",
519
+ "san francisco del mar huave": "hue",
520
+ "san juan atzingo popoloca": "poe",
521
+ "san martín itunyoso triqui": "trq",
522
+ "san miguel el grande mixtec": "mig",
523
+ "sansi": "ssi",
524
+ "sanskrit": "sa",
525
+ "santa ana de tusi pasco quechua": "qxt",
526
+ "santa catarina albarradas zapotec": "ztn",
527
+ "santali": "sat",
528
+ "santiago del estero quichua": "qus",
529
+ "saposa": "sps",
530
+ "saraiki": "skr",
531
+ "sardinian": "sc",
532
+ "saya": "say",
533
+ "sediq": "trv",
534
+ "serbian": "sr",
535
+ "seri": "sei",
536
+ "shina": "scl",
537
+ "shona": "sn",
538
+ "siar-lak": "sjr",
539
+ "sibe": "nco",
540
+ "sicilian": "scn",
541
+ "sihuas ancash quechua": "qws",
542
+ "sikkimese": "sip",
543
+ "sinaugoro": "snc",
544
+ "sindhi": "sd",
545
+ "sindhi bhil": "sbn",
546
+ "sinhala": "si",
547
+ "sinicahua mixtec": "xti",
548
+ "sipacapense": "qum",
549
+ "siwai": "siw",
550
+ "slovak": "sk",
551
+ "slovenian": "sl",
552
+ "solos": "sol",
553
+ "somali": "so",
554
+ "soninke": "snk",
555
+ "south giziga": "giz",
556
+ "south ucayali ashéninka": "cpy",
557
+ "southeastern nochixtlán mixtec": "mxy",
558
+ "southern betsimisaraka malagasy": "bzc",
559
+ "southern pashto": "pbt",
560
+ "southern pastaza quechua": "qup",
561
+ "soyaltepec mazatec": "vmp",
562
+ "spanish": "es",
563
+ "standard arabic": "arb",
564
+ "standard moroccan tamazight": "zgh",
565
+ "sudanese arabic": "apd",
566
+ "sulka": "sua",
567
+ "svan": "sva",
568
+ "swahili": "sw",
569
+ "swedish": "sv",
570
+ "tae'": "rob",
571
+ "tahaggart tamahaq": "thv",
572
+ "taita": "dav",
573
+ "tajik": "tg",
574
+ "tamil": "ta",
575
+ "tandroy-mahafaly malagasy": "tdx",
576
+ "tangale": "tan",
577
+ "tanosy malagasy": "txy",
578
+ "tarok": "yer",
579
+ "tatar": "tt",
580
+ "tedaga": "tuq",
581
+ "telugu": "te",
582
+ "tem": "kdh",
583
+ "teop": "tio",
584
+ "tepeuxila cuicatec": "cux",
585
+ "tepinapa chinantec": "cte",
586
+ "tera": "ttr",
587
+ "terei": "buo",
588
+ "termanu": "twu",
589
+ "tesaka malagasy": "tkg",
590
+ "tetelcingo nahuatl": "nhg",
591
+ "teutila cuicatec": "cut",
592
+ "thai": "th",
593
+ "tibetan": "bo",
594
+ "tidaá mixtec": "mtx",
595
+ "tidore": "tvo",
596
+ "tigak": "tgc",
597
+ "tigre": "tig",
598
+ "tigrinya": "ti",
599
+ "tilquiapan zapotec": "zts",
600
+ "tinputz": "tpz",
601
+ "tlacoapa me'phaa": "tpl",
602
+ "tlacoatzintepec chinantec": "ctl",
603
+ "tlingit": "tli",
604
+ "toki pona": "tok",
605
+ "tomoip": "tqp",
606
+ "tondano": "tdn",
607
+ "tonsea": "txs",
608
+ "tooro": "ttj",
609
+ "torau": "ttu",
610
+ "torwali": "trw",
611
+ "tsimihety malagasy": "xmw",
612
+ "tsotso": "lto",
613
+ "tswana": "tn",
614
+ "tugen": "tuy",
615
+ "tuki": "bag",
616
+ "tula": "tul",
617
+ "tulu": "tcy",
618
+ "tunen": "tvu",
619
+ "tungag": "lcm",
620
+ "tunisian arabic": "aeb",
621
+ "tupuri": "tui",
622
+ "turkana": "tuv",
623
+ "turkish": "tr",
624
+ "turkmen": "tk",
625
+ "tututepec mixtec": "mtu",
626
+ "twi": "tw",
627
+ "ubaghara": "byc",
628
+ "uighur": "ug",
629
+ "ukrainian": "uk",
630
+ "umbundu": "umb",
631
+ "upper sorbian": "hsb",
632
+ "urdu": "ur",
633
+ "ushojo": "ush",
634
+ "uzbek": "uz",
635
+ "vai": "vai",
636
+ "vietnamese": "vi",
637
+ "votic": "vot",
638
+ "võro": "vro",
639
+ "waci gbe": "wci",
640
+ "wadiyara koli": "kxp",
641
+ "waja": "wja",
642
+ "wakhi": "wbl",
643
+ "wanga": "lwg",
644
+ "wapan": "juk",
645
+ "warji": "wji",
646
+ "welsh": "cy",
647
+ "wemale": "weo",
648
+ "western frisian": "fy",
649
+ "western highland purepecha": "pua",
650
+ "western juxtlahuaca mixtec": "jmx",
651
+ "western maninkakan": "mlq",
652
+ "western mari": "mrj",
653
+ "western niger fulfulde": "fuh",
654
+ "western panjabi": "pnb",
655
+ "wolof": "wo",
656
+ "wuzlam": "udl",
657
+ "xanaguía zapotec": "ztg",
658
+ "xhosa": "xh",
659
+ "yace": "ekr",
660
+ "yakut": "sah",
661
+ "yalahatan": "jal",
662
+ "yanahuanca pasco quechua": "qur",
663
+ "yangben": "yav",
664
+ "yaqui": "yaq",
665
+ "yauyos quechua": "qux",
666
+ "yekhee": "ets",
667
+ "yiddish": "yi",
668
+ "yidgha": "ydg",
669
+ "yoruba": "yo",
670
+ "yutanduchi mixtec": "mab",
671
+ "zacatlán-ahuacatlán-tepetzintla nahuatl": "nhi",
672
+ "zarma": "dje",
673
+ "zaza": "zza",
674
+ "zulu": "zu",
675
+ "ömie": "aom",
676
+ }
677
+
678
# Membership-test views over the name<->ID table: iterating a dict yields its
# keys, so set(LANG_NAME_TO_ID) == set(LANG_NAME_TO_ID.keys()).
LANG_NAMES = set(LANG_NAME_TO_ID)
LANG_IDS = set(LANG_NAME_TO_ID.values())
681
# Canonical display casings that plain str.title() would get wrong:
# apostrophes (title() capitalizes after them), non-ASCII letters, and small
# connective words ("de", "del") that must stay lowercase.
_TITLE_EXCEPTIONS = {
    "fe'fe'": "Fe'fe'",
    "dũya": "Dũya",
    "santiago del estero quichua": "Santiago del Estero Quichua",
    "santa ana de tusi pasco quechua": "Santa Ana de Tusi Pasco Quechua",
    "malinaltepec me'phaa": "Malinaltepec Me'phaa",
    "tlacoapa me'phaa": "Tlacoapa Me'phaa",
}


def lang_display_name(name: str) -> str:
    """Return a display-friendly version of a lowercase language name.

    Looks the name up in ``_TITLE_EXCEPTIONS`` first (apostrophes and small
    words like "de"/"del" that ``str.title()`` mishandles); otherwise falls
    back to ``str.title()``.
    """
    try:
        return _TITLE_EXCEPTIONS[name]
    except KeyError:
        return name.title()
omnivoice/utils/text.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Text processing utilities for TTS inference.
19
+
20
+ Provides:
21
+ - ``chunk_text_punctuation()``: Splits long text into model-friendly chunks at
22
+ sentence boundaries, with abbreviation-aware punctuation splitting.
23
+ - ``add_punctuation()``: Appends missing end punctuation (Chinese or English).
24
+ """
25
+
26
+ from typing import List, Optional
27
+
28
+
29
# Characters at which chunk_text_punctuation() may end a sentence
# (ASCII plus fullwidth CJK variants).
SPLIT_PUNCTUATION = set(".,;:!?。,;:!?")
# Quote/bracket closers that must stay attached to the sentence they close
# rather than begin a new chunk.
CLOSING_MARKS = set("\"'""')]》》>」】")
31
+
32
# Characters accepted by add_punctuation() as an existing sentence-final mark.
# Covers ASCII punctuation, typographic quotes, closing brackets, and
# fullwidth CJK punctuation.
END_PUNCTUATION = {
    ";",
    ":",
    ",",
    ".",
    "!",
    "?",
    "…",
    ")",
    "]",
    "}",
    '"',
    "'",
    """,
    "'",
    ";",
    ":",
    ",",
    "。",
    "!",
    "?",
    "、",
    "……",
    ")",
    "】",
    """,
    "'",
}
60
+
61
+
62
# Words ending in "." that do NOT terminate a sentence; the chunker checks
# the word preceding a period against this set before splitting.
ABBREVIATIONS = set(
    "Mr. Mrs. Ms. Dr. Prof. Sr. Jr. Rev. Fr. Hon. Pres. Gov. Capt. Gen. "
    "Sen. Rep. Col. Maj. Lt. Cmdr. Sgt. Cpl. Co. Corp. Inc. Ltd. Est. "
    "Dept. St. Ave. Blvd. Rd. Mt. Ft. No. Jan. Feb. Mar. Apr. Aug. Sep. "
    "Sept. Oct. Nov. Dec. i.e. e.g. vs. Vs. Etc. approx. fig. def.".split()
)
117
+
118
+
119
def chunk_text_punctuation(
    text: str,
    chunk_len: int,
    min_chunk_len: Optional[int] = None,
) -> List[str]:
    """Split ``text`` into chunks of roughly ``chunk_len`` characters.

    The text is first cut into sentences at punctuation marks — skipping
    periods that end a known abbreviation (e.g. "Mr.", "No.") — then
    consecutive sentences are packed greedily into chunks of at most
    ``chunk_len`` characters. When ``min_chunk_len`` is given, undersized
    chunks are folded into a neighbouring chunk afterwards.
    """
    # Phase 1: cut the character stream into sentences at punctuation.
    sentences: List[List[str]] = []
    pending: List[str] = []

    for ch in text:
        # Punctuation cannot start a sentence: glue it onto the previous one.
        glue = (
            not pending
            and sentences
            and (ch in SPLIT_PUNCTUATION or ch in CLOSING_MARKS)
        )
        if glue:
            sentences[-1].append(ch)
        else:
            pending.append(ch)

        # End the pending sentence at split punctuation. NOTE: this runs even
        # when the character was glued above (pending is then empty), which
        # intentionally appends an empty sentinel sentence so that runs of
        # punctuation partition the same way as in the original behaviour.
        if ch in SPLIT_PUNCTUATION:
            ends_abbreviation = False
            if ch == ".":
                joined = "".join(pending).strip()
                # A period that finishes an abbreviation keeps the sentence open.
                if joined and joined.split()[-1] in ABBREVIATIONS:
                    ends_abbreviation = True
            if not ends_abbreviation:
                sentences.append(pending)
                pending = []

    # Trailing characters without end punctuation still form a sentence.
    if pending:
        sentences.append(pending)

    # Phase 2: greedily pack sentences into chunks of <= chunk_len characters.
    # A single sentence longer than chunk_len becomes a chunk on its own.
    packed: List[List[str]] = []
    buf: List[str] = []
    for sentence in sentences:
        if len(buf) + len(sentence) <= chunk_len:
            buf.extend(sentence)
        else:
            if buf:
                packed.append(buf)
            buf = sentence
    if buf:
        packed.append(buf)

    # Phase 3: fold undersized chunks into a neighbour. A short FIRST chunk
    # absorbs the second chunk; any other short chunk is appended to the
    # chunk before it.
    if min_chunk_len is None:
        sized = packed
    else:
        first_is_short = bool(packed) and len(packed[0]) < min_chunk_len
        sized = []
        for idx, chunk in enumerate(packed):
            if idx == 1 and first_is_short:
                sized[-1].extend(chunk)
            elif len(chunk) >= min_chunk_len or not sized:
                sized.append(chunk)
            else:
                sized[-1].extend(chunk)

    # Join, strip, and drop whitespace-only chunks.
    result = []
    for chunk in sized:
        joined = "".join(chunk).strip()
        if joined:
            result.append(joined)
    return result
205
+
206
+
207
def add_punctuation(text: str):
    """Append a sentence-final mark to ``text`` if it does not already end with one.

    Returns the stripped text; a fullwidth "。" is used when the text contains
    CJK ideographs, an ASCII "." otherwise.
    """
    stripped = text.strip()

    if not stripped:
        return stripped

    if stripped[-1] in END_PUNCTUATION:
        return stripped

    # Choose the mark by script: CJK ideograph range U+4E00..U+9FFF.
    has_chinese = any("\u4e00" <= ch <= "\u9fff" for ch in stripped)
    return stripped + ("。" if has_chinese else ".")
omnivoice/utils/voice_design.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Voice-design instruct constants for TTS inference.
19
+
20
+ Defines speaker attribute tags (gender, age, pitch, accent, dialect) and
21
+ translation/validation utilities between English and Chinese. Used by
22
+ ``OmniVoice.generate()`` for voice design mode.
23
+ """
24
+
25
+ import re
26
+
27
# Matches any CJK unified ideograph; used to classify a tag as Chinese or English.
_ZH_RE = re.compile(r'[\u4e00-\u9fff]')

# Category = set of {english: chinese, ...} items that are mutually exclusive.
# Accent (EN-only) and dialect (ZH-only) are stored as flat sets below.
# NOTE: ordering matters — the accent and dialect sets must remain the last
# two entries, because they are retrieved via [-2] / [-1] indexing below.
_INSTRUCT_CATEGORIES = [
    {"male": "男", "female": "女"},
    {"child": "儿童", "teenager": "少年", "young adult": "青年",
     "middle-aged": "中年", "elderly": "老年"},
    {"very low pitch": "极低音调", "low pitch": "低音调",
     "moderate pitch": "中音调", "high pitch": "高音调",
     "very high pitch": "极高音调"},
    {"whisper": "耳语"},
    # Accent (English-only, no Chinese counterpart)
    {"american accent", "british accent", "australian accent",
     "chinese accent", "canadian accent", "indian accent",
     "korean accent", "portuguese accent", "russian accent", "japanese accent"},
    # Dialect (Chinese-only, no English counterpart)
    {"河南话", "陕西话", "四川话", "贵州话", "云南话", "桂林话",
     "济南话", "石家庄话", "甘肃话", "宁夏话", "青岛话", "东北话"},
]

# English->Chinese tag translations and the reverse, built from the dict
# categories above.
_INSTRUCT_EN_TO_ZH = {}
_INSTRUCT_ZH_TO_EN = {}
# One set per category holding every surface form (EN keys and ZH values);
# tags within the same set are mutually exclusive.
_INSTRUCT_MUTUALLY_EXCLUSIVE = []
for _cat in _INSTRUCT_CATEGORIES:
    if isinstance(_cat, dict):
        _INSTRUCT_EN_TO_ZH.update(_cat)
        _INSTRUCT_ZH_TO_EN.update({v: k for k, v in _cat.items()})
        # Iterating a dict yields its keys, so this is EN keys | ZH values.
        _INSTRUCT_MUTUALLY_EXCLUSIVE.append(set(_cat) | set(_cat.values()))
    else:
        # Flat accent/dialect sets are used as-is.
        _INSTRUCT_MUTUALLY_EXCLUSIVE.append(set(_cat))

# Every accepted tag in either language (translatable pairs plus the
# untranslatable accent and dialect tags).
_INSTRUCT_ALL_VALID = (
    set(_INSTRUCT_EN_TO_ZH) | set(_INSTRUCT_ZH_TO_EN)
    | _INSTRUCT_MUTUALLY_EXCLUSIVE[-2]  # accents
    | _INSTRUCT_MUTUALLY_EXCLUSIVE[-1]  # dialects
)

# Valid tags partitioned by script: no CJK characters -> English, else Chinese.
_INSTRUCT_VALID_EN = frozenset(i for i in _INSTRUCT_ALL_VALID if not _ZH_RE.search(i))
_INSTRUCT_VALID_ZH = frozenset(i for i in _INSTRUCT_ALL_VALID if _ZH_RE.search(i))
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu128
2
+ torch==2.8.0
3
+ torchaudio==2.8.0
4
+ transformers==5.3
5
+ accelerate
6
+ pydub
7
+ soundfile
8
+ numpy
9
+ gradio
10
+ hf_transfer