Spaces:
Running on Zero
Running on Zero
Upload 48 files
Browse files- app.py +128 -0
- omnivoice/__init__.py +28 -0
- omnivoice/cli/__init__.py +0 -0
- omnivoice/cli/demo.py +533 -0
- omnivoice/cli/infer.py +157 -0
- omnivoice/cli/infer_batch.py +523 -0
- omnivoice/cli/train.py +74 -0
- omnivoice/data/__init__.py +0 -0
- omnivoice/data/batching.py +166 -0
- omnivoice/data/collator.py +92 -0
- omnivoice/data/dataset.py +551 -0
- omnivoice/data/processor.py +258 -0
- omnivoice/eval/__init__.py +4 -0
- omnivoice/eval/models/ecapa_tdnn_wavlm.py +374 -0
- omnivoice/eval/models/utmos.py +370 -0
- omnivoice/eval/mos/utmos.py +299 -0
- omnivoice/eval/speaker_similarity/sim.py +321 -0
- omnivoice/eval/utils.py +80 -0
- omnivoice/eval/wer/common.py +88 -0
- omnivoice/eval/wer/fleurs.py +517 -0
- omnivoice/eval/wer/hubert.py +318 -0
- omnivoice/eval/wer/minimax.py +596 -0
- omnivoice/eval/wer/norm_config_module.py +291 -0
- omnivoice/eval/wer/punctuations.lst +188 -0
- omnivoice/eval/wer/seedtts.py +413 -0
- omnivoice/eval/wer/sensevoice.py +344 -0
- omnivoice/eval/wer/text_norm_omni.py +113 -0
- omnivoice/models/__init__.py +0 -0
- omnivoice/models/omnivoice.py +1502 -0
- omnivoice/scripts/__init__.py +0 -0
- omnivoice/scripts/denoise_audio.py +1048 -0
- omnivoice/scripts/extract_audio_tokens.py +625 -0
- omnivoice/scripts/extract_audio_tokens_add_noise.py +825 -0
- omnivoice/scripts/jsonl_to_webdataset.py +439 -0
- omnivoice/training/__init__.py +0 -0
- omnivoice/training/builder.py +180 -0
- omnivoice/training/checkpoint.py +180 -0
- omnivoice/training/config.py +98 -0
- omnivoice/training/trainer.py +342 -0
- omnivoice/utils/__init__.py +0 -0
- omnivoice/utils/audio.py +355 -0
- omnivoice/utils/common.py +56 -0
- omnivoice/utils/data_utils.py +63 -0
- omnivoice/utils/duration.py +282 -0
- omnivoice/utils/lang_map.py +698 -0
- omnivoice/utils/text.py +219 -0
- omnivoice/utils/voice_design.py +66 -0
- requirements.txt +10 -0
app.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
HuggingFace Space entry point for OmniVoice demo.
|
| 4 |
+
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
import tempfile
|
| 10 |
+
from typing import Any, Dict
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torchaudio
|
| 14 |
+
|
| 15 |
+
from omnivoice import OmniVoice, OmniVoiceGenerationConfig
|
| 16 |
+
from omnivoice.cli.demo import build_demo
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# ---------------------------------------------------------------------------
# Hardware detection
# ---------------------------------------------------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {DEVICE}")

# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
# Checkpoint is overridable via the OMNIVOICE_MODEL environment variable;
# defaults to the public HuggingFace repo id.
CHECKPOINT = os.environ.get("OMNIVOICE_MODEL", "k2-fsa/OmniVoice")

logger.info(f"Loading model from {CHECKPOINT} on {DEVICE} ...")
# Loaded once at import time so every Space request reuses the same instance.
model = OmniVoice.from_pretrained(
    CHECKPOINT,
    device_map=DEVICE,
    dtype=torch.float16,
    load_asr=True,  # NOTE(review): presumably for auto-transcribing reference audio — confirm
)
logger.info("Model loaded on %s.", DEVICE)
sampling_rate = model.sampling_rate
|
| 41 |
+
|
| 42 |
+
# ---------------------------------------------------------------------------
|
| 43 |
+
# Generation logic (outside build_demo so we can wrap with spaces.GPU)
|
| 44 |
+
# ---------------------------------------------------------------------------
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _gen_core(
    text,
    language,
    ref_audio,
    instruct,
    num_step,
    guidance_scale,
    denoise,
    speed,
    duration,
    preprocess_prompt,
    postprocess_output,
    mode,
    ref_text=None,
):
    """Run one TTS generation request and write the result to a temp WAV file.

    Args:
        text: Text to synthesize (required, non-blank).
        language: Display language name, or "Auto"/None for auto-detection.
        ref_audio: Reference audio path (required when mode == "clone").
        instruct: Style instruction (used when mode == "design").
        num_step / guidance_scale / denoise / preprocess_prompt /
        postprocess_output: Forwarded into OmniVoiceGenerationConfig.
        speed: Playback-speed factor; only forwarded when != 1.0.
        duration: Fixed output duration in seconds; only forwarded when > 0.
        mode: "clone" or "design".
        ref_text: Optional transcript of the reference audio.

    Returns:
        (path, status): path to the generated WAV file (or None on failure)
        and a human-readable status message.
    """
    if not text or not text.strip():
        return None, "Please enter the text to synthesize."

    gen_config = OmniVoiceGenerationConfig(
        num_step=int(num_step or 32),
        guidance_scale=float(guidance_scale) if guidance_scale is not None else 2.0,
        denoise=bool(denoise) if denoise is not None else True,
        preprocess_prompt=bool(preprocess_prompt),
        postprocess_output=bool(postprocess_output),
    )

    # "Auto" in the UI dropdown means: let the model detect the language.
    lang = language if (language and language != "Auto") else None

    kw: Dict[str, Any] = dict(
        text=text.strip(), language=lang, generation_config=gen_config
    )

    # Only forward non-default values; a fixed duration overrides speed.
    if speed is not None and float(speed) != 1.0:
        kw["speed"] = float(speed)
    if duration is not None and float(duration) > 0:
        kw["duration"] = float(duration)

    if mode == "clone":
        if not ref_audio:
            return None, "Please upload a reference audio."
        kw["voice_clone_prompt"] = model.create_voice_clone_prompt(
            ref_audio=ref_audio,
            ref_text=ref_text,
        )

    if mode == "design":
        if instruct and instruct.strip():
            kw["instruct"] = instruct.strip()

    out_path = None
    try:
        audio = model.generate(**kw)
        # Create the temp file only after generation succeeds; creating it
        # first (as before) leaked an empty .wav on every failed request.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            out_path = f.name
        torchaudio.save(out_path, audio[0], sampling_rate)
    except Exception as e:
        # Remove a partially written file if saving (not generation) failed,
        # then surface the error in the status box instead of crashing the UI.
        if out_path is not None and os.path.exists(out_path):
            os.remove(out_path)
        return None, f"Error: {type(e).__name__}: {e}"

    return out_path, "Done."
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# ---------------------------------------------------------------------------
|
| 107 |
+
# ZeroGPU wrapper
|
| 108 |
+
# ---------------------------------------------------------------------------
|
| 109 |
+
# Optional ZeroGPU support: on HuggingFace "Zero" Spaces the `spaces`
# package is available and GPU work must run inside a @spaces.GPU()
# decorated function to be scheduled on a GPU worker.
generate_fn = None
try:
    import spaces

    @spaces.GPU()
    def _gen_gpu(*args, **kwargs):
        # Thin pass-through so _gen_core executes inside the GPU allocation.
        return _gen_core(*args, **kwargs)

    generate_fn = _gen_gpu
    logger.info("Using spaces.GPU() wrapper.")
except ImportError:
    # Not on a ZeroGPU Space; build_demo falls back to its internal core.
    logger.info("spaces module not found, running without GPU wrapper.")
|
| 121 |
+
|
| 122 |
+
# ---------------------------------------------------------------------------
|
| 123 |
+
# Build and launch demo — reuses the full UI from omnivoice.cli.demo
|
| 124 |
+
# ---------------------------------------------------------------------------
|
| 125 |
+
# Reuse the full UI from omnivoice.cli.demo, injecting the (possibly
# GPU-wrapped) generation function built above.
demo = build_demo(model, CHECKPOINT, generate_fn=generate_fn)

if __name__ == "__main__":
    # queue() enables request queuing, which Spaces expect for long jobs.
    demo.queue().launch()
|
omnivoice/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
from importlib.metadata import PackageNotFoundError, version

# Silence known-noisy third-party warnings; these filters must be installed
# before the model import below triggers the offending modules.
warnings.filterwarnings("ignore", module="torchaudio")
warnings.filterwarnings(
    "ignore",
    category=SyntaxWarning,
    message="invalid escape sequence",
    module="pydub.utils",
)
warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    module="torch.distributed.algorithms.ddp_comm_hooks",
)

# Resolve the installed distribution version; fall back to a placeholder
# when running from a source tree that has not been pip-installed.
try:
    __version__ = version("omnivoice")
except PackageNotFoundError:
    __version__ = "0.0.0"

from omnivoice.models.omnivoice import (
    OmniVoice,
    OmniVoiceConfig,
    OmniVoiceGenerationConfig,
)

# Public API of the package.
__all__ = ["OmniVoice", "OmniVoiceConfig", "OmniVoiceGenerationConfig"]
|
omnivoice/cli/__init__.py
ADDED
|
File without changes
|
omnivoice/cli/demo.py
ADDED
|
@@ -0,0 +1,533 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
"""
|
| 18 |
+
Gradio demo for OmniVoice.
|
| 19 |
+
|
| 20 |
+
Supports voice cloning and voice design.
|
| 21 |
+
|
| 22 |
+
Usage:
|
| 23 |
+
omnivoice-demo --model /path/to/checkpoint --port 8000
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import argparse
|
| 27 |
+
import logging
|
| 28 |
+
from typing import Any, Dict
|
| 29 |
+
|
| 30 |
+
import gradio as gr
|
| 31 |
+
import numpy as np
|
| 32 |
+
import torch
|
| 33 |
+
|
| 34 |
+
from omnivoice import OmniVoice, OmniVoiceGenerationConfig
|
| 35 |
+
from omnivoice.utils.lang_map import LANG_NAMES, lang_display_name
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_best_device():
    """Auto-detect the best available device: CUDA > MPS > CPU."""
    probes = (
        ("cuda", torch.cuda.is_available),
        ("mps", torch.backends.mps.is_available),
    )
    for device_name, is_available in probes:
        if is_available():
            return device_name
    return "cpu"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# ---------------------------------------------------------------------------
# Language list — all 600+ supported languages
# ---------------------------------------------------------------------------
# "Auto" sentinel first, then every supported language display name, sorted.
_ALL_LANGUAGES = ["Auto"] + sorted(lang_display_name(n) for n in LANG_NAMES)


# ---------------------------------------------------------------------------
# Voice Design instruction templates
# ---------------------------------------------------------------------------
# Each option is displayed as "English / 中文".
# The model expects English for accents and Chinese for dialects.
_CATEGORIES = {
    "Gender / 性别": ["Male / 男", "Female / 女"],
    "Age / 年龄": [
        "Child / 儿童",
        "Teenager / 少年",
        "Young Adult / 青年",
        "Middle-aged / 中年",
        "Elderly / 老年",
    ],
    "Pitch / 音调": [
        "Very Low Pitch / 极低音调",
        "Low Pitch / 低音调",
        "Moderate Pitch / 中音调",
        "High Pitch / 高音调",
        "Very High Pitch / 极高音调",
    ],
    "Style / 风格": ["Whisper / 耳语"],
    "English Accent / 英文口音": [
        "American Accent / 美式口音",
        "Australian Accent / 澳大利亚口音",
        "British Accent / 英国口音",
        "Chinese Accent / 中国口音",
        "Canadian Accent / 加拿大口音",
        "Indian Accent / 印度口音",
        "Korean Accent / 韩国口音",
        "Portuguese Accent / 葡萄牙口音",
        "Russian Accent / 俄罗斯口音",
        "Japanese Accent / 日本口音",
    ],
    "Chinese Dialect / 中文方言": [
        "Henan Dialect / 河南话",
        "Shaanxi Dialect / 陕西话",
        "Sichuan Dialect / 四川话",
        "Guizhou Dialect / 贵州话",
        "Yunnan Dialect / 云南话",
        "Guilin Dialect / 桂林话",
        "Jinan Dialect / 济南话",
        "Shijiazhuang Dialect / 石家庄话",
        "Gansu Dialect / 甘肃话",
        "Ningxia Dialect / 宁夏话",
        "Qingdao Dialect / 青岛话",
        "Northeast Dialect / 东北话",
    ],
}

# Per-category hint text shown under the corresponding dropdown in the UI.
_ATTR_INFO = {
    "English Accent / 英文口音": "Only effective for English speech.",
    "Chinese Dialect / 中文方言": "Only effective for Chinese speech.",
}
|
| 107 |
+
|
| 108 |
+
# ---------------------------------------------------------------------------
|
| 109 |
+
# Argument parser
|
| 110 |
+
# ---------------------------------------------------------------------------
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the ``omnivoice-demo`` launcher."""
    p = argparse.ArgumentParser(
        prog="omnivoice-demo",
        description="Launch a Gradio demo for OmniVoice.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    p.add_argument(
        "--model",
        default="k2-fsa/OmniVoice",
        help="Model checkpoint path or HuggingFace repo id.",
    )
    p.add_argument(
        "--device",
        default=None,
        help="Device to use. Auto-detected if not specified.",
    )
    p.add_argument("--ip", default="0.0.0.0", help="Server IP (default: 0.0.0.0).")
    p.add_argument("--port", type=int, default=7860, help="Server port (default: 7860).")
    p.add_argument("--root-path", default=None, help="Root path for reverse proxy.")
    p.add_argument(
        "--share", action="store_true", default=False, help="Create public link."
    )
    return p
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# ---------------------------------------------------------------------------
|
| 143 |
+
# Build demo
|
| 144 |
+
# ---------------------------------------------------------------------------
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def build_demo(
    model: OmniVoice,
    checkpoint: str,
    generate_fn=None,
) -> gr.Blocks:
    """Assemble the Gradio UI with "Voice Clone" and "Voice Design" tabs.

    Args:
        model: Loaded OmniVoice model used for generation.
        checkpoint: Checkpoint path / repo id (not referenced by the UI body
            below; kept for interface compatibility with callers).
        generate_fn: Optional replacement for the internal generation core —
            e.g. a ``spaces.GPU()``-wrapped version on ZeroGPU Spaces. Must
            accept the same positional/keyword arguments as ``_gen_core``.

    Returns:
        The assembled ``gr.Blocks`` demo (not yet launched).
    """

    sampling_rate = model.sampling_rate

    # -- shared generation core --
    def _gen_core(
        text,
        language,
        ref_audio,
        instruct,
        num_step,
        guidance_scale,
        denoise,
        speed,
        duration,
        preprocess_prompt,
        postprocess_output,
        mode,
        ref_text=None,
    ):
        # Returns (audio_or_None, status_message) matching the two output
        # widgets wired up in the tabs below.
        if not text or not text.strip():
            return None, "Please enter the text to synthesize."

        gen_config = OmniVoiceGenerationConfig(
            num_step=int(num_step or 32),
            guidance_scale=float(guidance_scale) if guidance_scale is not None else 2.0,
            denoise=bool(denoise) if denoise is not None else True,
            preprocess_prompt=bool(preprocess_prompt),
            postprocess_output=bool(postprocess_output),
        )

        # "Auto" in the dropdown means: let the model detect the language.
        lang = language if (language and language != "Auto") else None

        kw: Dict[str, Any] = dict(
            text=text.strip(), language=lang, generation_config=gen_config
        )

        # Only forward non-default values; duration overrides speed downstream.
        if speed is not None and float(speed) != 1.0:
            kw["speed"] = float(speed)
        if duration is not None and float(duration) > 0:
            kw["duration"] = float(duration)

        if mode == "clone":
            if not ref_audio:
                return None, "Please upload a reference audio."
            kw["voice_clone_prompt"] = model.create_voice_clone_prompt(
                ref_audio=ref_audio,
                ref_text=ref_text,
            )

        if mode == "design":
            if instruct and instruct.strip():
                kw["instruct"] = instruct.strip()

        try:
            audio = model.generate(**kw)
        except Exception as e:
            # Surface the failure in the status box instead of crashing the UI.
            return None, f"Error: {type(e).__name__}: {e}"

        waveform = audio[0].squeeze(0).numpy()  # (T,)
        # Scale float samples to int16 PCM for the gr.Audio "numpy" output.
        waveform = (waveform * 32767).astype(np.int16)
        return (sampling_rate, waveform), "Done."

    # Allow external wrappers (e.g. spaces.GPU for ZeroGPU Spaces)
    _gen = generate_fn if generate_fn is not None else _gen_core

    # =====================================================================
    # UI
    # =====================================================================
    theme = gr.themes.Soft(
        font=["Inter", "Arial", "sans-serif"],
    )
    css = """
    .gradio-container {max-width: 100% !important; font-size: 16px !important;}
    .gradio-container h1 {font-size: 1.5em !important;}
    .gradio-container .prose {font-size: 1.1em !important;}
    .compact-audio audio {height: 60px !important;}
    .compact-audio .waveform {min-height: 80px !important;}
    """

    # Reusable: language dropdown component
    def _lang_dropdown(label="Language (optional) / 语种 (可选)", value="Auto"):
        return gr.Dropdown(
            label=label,
            choices=_ALL_LANGUAGES,
            value=value,
            allow_custom_value=False,
            interactive=True,
            info="Keep as Auto to auto-detect the language.",
        )

    # Reusable: optional generation settings accordion
    def _gen_settings():
        # Returns the components in the order consumed by _gen_core:
        # (num_step, guidance_scale, denoise, speed, duration,
        #  preprocess_prompt, postprocess_output).
        with gr.Accordion("Generation Settings (optional)", open=False):
            sp = gr.Slider(
                0.7,
                1.3,
                value=1.0,
                step=0.05,
                label="Speed",
                info="1.0 = normal. >1 faster, <1 slower. Ignored if Duration is set.",
            )
            du = gr.Number(
                value=None,
                label="Duration (seconds)",
                info=(
                    "Leave empty to use speed."
                    " Set a fixed duration to override speed."
                ),
            )
            ns = gr.Slider(
                4,
                64,
                value=32,
                step=1,
                label="Inference Steps",
                info="Default: 32. Lower = faster, higher = better quality.",
            )
            dn = gr.Checkbox(
                label="Denoise",
                value=True,
                info="Default: enabled. Uncheck to disable denoising.",
            )
            gs = gr.Slider(
                0.0,
                4.0,
                value=2.0,
                step=0.1,
                label="Guidance Scale (CFG)",
                info="Default: 2.0.",
            )
            pp = gr.Checkbox(
                label="Preprocess Prompt",
                value=True,
                info="apply silence removal and trimming to the reference "
                "audio, add punctuation in the end of reference text (if not already)",
            )
            po = gr.Checkbox(
                label="Postprocess Output",
                value=True,
                info="Remove long silences from generated audio.",
            )
        return ns, gs, dn, sp, du, pp, po

    with gr.Blocks(theme=theme, css=css, title="OmniVoice Demo") as demo:
        gr.Markdown(
            """
# OmniVoice Demo

State-of-the-art text-to-speech model for **600+ languages**, supporting:

- **Voice Clone** — Clone any voice from a reference audio
- **Voice Design** — Create custom voices with speaker attributes

Built with [OmniVoice](https://github.com/k2-fsa/OmniVoice)
by Xiaomi Next-gen Kaldi team.
"""
        )

        with gr.Tabs():
            # ==============================================================
            # Voice Clone
            # ==============================================================
            with gr.TabItem("Voice Clone"):
                with gr.Row():
                    with gr.Column(scale=1):
                        vc_text = gr.Textbox(
                            label="Text to Synthesize / 待合成文本",
                            lines=4,
                            placeholder="Enter the text you want to synthesize...",
                        )
                        vc_ref_audio = gr.Audio(
                            label="Reference Audio / 参考音频",
                            type="filepath",
                            elem_classes="compact-audio",
                        )
                        gr.Markdown(
                            "<span style='font-size:0.85em;color:#888;'>"
                            "Recommended: 3–10 seconds audio. "
                            "</span>"
                        )
                        vc_ref_text = gr.Textbox(
                            label=("Reference Text (optional)" " / 参考音频文本(可选)"),
                            lines=2,
                            placeholder="Transcript of the reference audio. Leave empty"
                            " to auto-transcribe via ASR models.",
                        )
                        vc_lang = _lang_dropdown("Language (optional) / 语种 (可选)")
                        (
                            vc_ns,
                            vc_gs,
                            vc_dn,
                            vc_sp,
                            vc_du,
                            vc_pp,
                            vc_po,
                        ) = _gen_settings()
                        vc_btn = gr.Button("Generate / 生成", variant="primary")
                    with gr.Column(scale=1):
                        vc_audio = gr.Audio(
                            label="Output Audio / 合成结果",
                            type="numpy",
                        )
                        vc_status = gr.Textbox(label="Status / 状态", lines=2)

                # Adapter: map clone-tab inputs onto the shared _gen signature
                # (instruct is always None in clone mode).
                def _clone_fn(
                    text, lang, ref_aud, ref_text, ns, gs, dn, sp, du, pp, po
                ):
                    return _gen(
                        text,
                        lang,
                        ref_aud,
                        None,
                        ns,
                        gs,
                        dn,
                        sp,
                        du,
                        pp,
                        po,
                        mode="clone",
                        ref_text=ref_text or None,
                    )

                vc_btn.click(
                    _clone_fn,
                    inputs=[
                        vc_text,
                        vc_lang,
                        vc_ref_audio,
                        vc_ref_text,
                        vc_ns,
                        vc_gs,
                        vc_dn,
                        vc_sp,
                        vc_du,
                        vc_pp,
                        vc_po,
                    ],
                    outputs=[vc_audio, vc_status],
                )

            # ==============================================================
            # Voice Design
            # ==============================================================
            with gr.TabItem("Voice Design"):
                with gr.Row():
                    with gr.Column(scale=1):
                        vd_text = gr.Textbox(
                            label="Text to Synthesize / 待合成文本",
                            lines=4,
                            placeholder="Enter the text you want to synthesize...",
                        )
                        vd_lang = _lang_dropdown()

                        # One dropdown per attribute category, each with an
                        # "Auto" (no preference) sentinel as the default.
                        _AUTO = "Auto"
                        vd_groups = []
                        for _cat, _choices in _CATEGORIES.items():
                            vd_groups.append(
                                gr.Dropdown(
                                    label=_cat,
                                    choices=[_AUTO] + _choices,
                                    value=_AUTO,
                                    info=_ATTR_INFO.get(_cat),
                                )
                            )

                        (
                            vd_ns,
                            vd_gs,
                            vd_dn,
                            vd_sp,
                            vd_du,
                            vd_pp,
                            vd_po,
                        ) = _gen_settings()
                        vd_btn = gr.Button("Generate / 生成", variant="primary")
                    with gr.Column(scale=1):
                        vd_audio = gr.Audio(
                            label="Output Audio / 合成结果",
                            type="numpy",
                        )
                        vd_status = gr.Textbox(label="Status / 状态", lines=2)

                def _build_instruct(groups):
                    """Extract instruct text from UI dropdowns.

                    Language unification and validation is handled by
                    _resolve_instruct inside _preprocess_all.
                    """
                    selected = [g for g in groups if g and g != "Auto"]
                    if not selected:
                        return None
                    parts = []
                    for v in selected:
                        if " / " in v:
                            en, zh = v.split(" / ", 1)
                            # Dialects have no English equivalent
                            if "Dialect" in v.split(" / ")[0]:
                                parts.append(zh.strip())
                            else:
                                parts.append(en.strip())
                        else:
                            parts.append(v)
                    return ", ".join(parts)

                # Adapter: map design-tab inputs onto the shared _gen
                # signature (ref_audio is always None in design mode).
                def _design_fn(text, lang, ns, gs, dn, sp, du, pp, po, *groups):
                    return _gen(
                        text,
                        lang,
                        None,
                        _build_instruct(groups),
                        ns,
                        gs,
                        dn,
                        sp,
                        du,
                        pp,
                        po,
                        mode="design",
                    )

                vd_btn.click(
                    _design_fn,
                    inputs=[
                        vd_text,
                        vd_lang,
                        vd_ns,
                        vd_gs,
                        vd_dn,
                        vd_sp,
                        vd_du,
                        vd_pp,
                        vd_po,
                    ]
                    + vd_groups,
                    outputs=[vd_audio, vd_status],
                )

    return demo
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
# ---------------------------------------------------------------------------
|
| 494 |
+
# Main
|
| 495 |
+
# ---------------------------------------------------------------------------
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
def main(argv=None) -> int:
    """CLI entry point: parse args, load the model, launch the Gradio demo.

    Args:
        argv: Optional argument list; defaults to ``sys.argv[1:]``.

    Returns:
        Process exit code (0 on success; launch blocks until shutdown).
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(name)s %(levelname)s: %(message)s",
    )
    parser = build_parser()
    args = parser.parse_args(argv)

    device = args.device or get_best_device()

    checkpoint = args.model
    if not checkpoint:
        parser.print_help()
        return 0
    # Lazy %-args keep formatting out of the hot path and match logging style.
    logging.info("Loading model from %s, device=%s ...", checkpoint, device)
    model = OmniVoice.from_pretrained(
        checkpoint,
        device_map=device,
        dtype=torch.float16,
        load_asr=True,
    )
    # Consistency fix: use logging (as everywhere else in this module)
    # instead of a bare print(), so the message honors the configured
    # format, level and handlers.
    logging.info("Model loaded.")

    demo = build_demo(model, checkpoint)

    demo.queue().launch(
        server_name=args.ip,
        server_port=args.port,
        share=args.share,
        root_path=args.root_path,
    )
    return 0
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
if __name__ == "__main__":
|
| 533 |
+
raise SystemExit(main())
|
omnivoice/cli/infer.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Single-item inference CLI for OmniVoice.
|
| 2 |
+
|
| 3 |
+
Generates audio from a single text input using voice cloning,
|
| 4 |
+
voice design, or auto voice.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
# Voice cloning
|
| 8 |
+
omnivoice-infer --model k2-fsa/OmniVoice \
|
| 9 |
+
--text "Hello, this is a text for text-to-speech." \
|
| 10 |
+
--ref_audio ref.wav --ref_text "Reference transcript." --output out.wav
|
| 11 |
+
|
| 12 |
+
# Voice design
|
| 13 |
+
omnivoice-infer --model k2-fsa/OmniVoice \
|
| 14 |
+
--text "Hello, this is a text for text-to-speech." \
|
| 15 |
+
--instruct "male, British accent" --output out.wav
|
| 16 |
+
|
| 17 |
+
# Auto voice
|
| 18 |
+
omnivoice-infer --model k2-fsa/OmniVoice \
|
| 19 |
+
--text "Hello, this is a text for text-to-speech." --output out.wav
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import argparse
|
| 23 |
+
import logging
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
import torchaudio
|
| 27 |
+
|
| 28 |
+
from omnivoice.models.omnivoice import OmniVoice
|
| 29 |
+
from omnivoice.utils.common import str2bool
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_best_device():
    """Pick the most capable available backend: CUDA, then MPS, then CPU."""
    for backend, probe in (
        ("cuda", torch.cuda.is_available),
        ("mps", torch.backends.mps.is_available),
    ):
        if probe():
            return backend
    return "cpu"
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def get_parser() -> argparse.ArgumentParser:
    """Build the argument parser for single-item inference.

    Three modes are selected implicitly by which optional flags are given:
    voice cloning (--ref_audio/--ref_text), voice design (--instruct), or
    auto voice (neither).
    """
    parser = argparse.ArgumentParser(
        description="OmniVoice single-item inference",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--model",
        type=str,
        default="k2-fsa/OmniVoice",
        help="Model checkpoint path or HuggingFace repo id.",
    )
    parser.add_argument(
        "--text",
        type=str,
        required=True,
        help="Text to synthesize.",
    )
    parser.add_argument(
        "--output",
        type=str,
        required=True,
        help="Output WAV file path.",
    )
    # Voice cloning: both flags are expected together (clone mode).
    parser.add_argument(
        "--ref_audio",
        type=str,
        default=None,
        help="Reference audio file path for voice cloning.",
    )
    parser.add_argument(
        "--ref_text",
        type=str,
        default=None,
        help="Reference text describing the reference audio.",
    )
    # Voice design: a free-form style description instead of a reference clip.
    parser.add_argument(
        "--instruct",
        type=str,
        default=None,
        help="Style instruction for voice design mode.",
    )
    parser.add_argument(
        "--language",
        type=str,
        default=None,
        help="Language name (e.g. 'English') or code (e.g. 'en').",
    )
    # Generation parameters (forwarded verbatim to OmniVoice.generate in main()).
    parser.add_argument("--num_step", type=int, default=32)
    parser.add_argument("--guidance_scale", type=float, default=2.0)
    parser.add_argument("--speed", type=float, default=1.0)
    parser.add_argument(
        "--duration",
        type=float,
        default=None,
        help="Fixed output duration in seconds. If set, overrides the "
        "model's duration estimation. The speed factor is automatically "
        "adjusted to match while preserving language-aware pacing.",
    )
    parser.add_argument("--t_shift", type=float, default=0.1)
    parser.add_argument("--denoise", type=str2bool, default=True)
    parser.add_argument(
        "--postprocess_output",
        type=str2bool,
        default=True,
    )
    parser.add_argument("--layer_penalty_factor", type=float, default=5.0)
    parser.add_argument("--position_temperature", type=float, default=5.0)
    parser.add_argument("--class_temperature", type=float, default=0.0)
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use for inference. Auto-detected if not specified.",
    )
    return parser
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def main():
    """CLI entry point: load the model, synthesize one utterance, save a WAV."""
    log_format = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO, force=True)

    args = get_parser().parse_args()

    # Explicit --device wins; otherwise probe CUDA/MPS/CPU.
    device = args.device if args.device else get_best_device()
    logging.info("Loading model from %s on %s ...", args.model, device)
    model = OmniVoice.from_pretrained(
        args.model, device_map=device, dtype=torch.float16
    )

    logging.info("Generating audio for: %s...", args.text[:80])
    generation_kwargs = {
        "text": args.text,
        "language": args.language,
        "ref_audio": args.ref_audio,
        "ref_text": args.ref_text,
        "instruct": args.instruct,
        "duration": args.duration,
        "num_step": args.num_step,
        "guidance_scale": args.guidance_scale,
        "speed": args.speed,
        "t_shift": args.t_shift,
        "denoise": args.denoise,
        "postprocess_output": args.postprocess_output,
        "layer_penalty_factor": args.layer_penalty_factor,
        "position_temperature": args.position_temperature,
        "class_temperature": args.class_temperature,
    }
    audios = model.generate(**generation_kwargs)

    # generate() returns a batch; a single input yields a single waveform.
    torchaudio.save(args.output, audios[0], model.sampling_rate)
    logging.info("Saved to %s", args.output)


if __name__ == "__main__":
    main()
|
omnivoice/cli/infer_batch.py
ADDED
|
@@ -0,0 +1,523 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""Batch inference CLI for OmniVoice.
|
| 19 |
+
|
| 20 |
+
Distributes TTS generation across multiple GPUs for large-scale tasks.
|
| 21 |
+
Reads a JSONL test list, generates audio in parallel, and saves results.
|
| 22 |
+
|
| 23 |
+
Usage:
|
| 24 |
+
omnivoice-infer-batch --model k2-fsa/OmniVoice \
|
| 25 |
+
--test_list test.jsonl --res_dir results/
|
| 26 |
+
|
| 27 |
+
Test list format (JSONL, one JSON object per line):
|
| 28 |
+
Required fields: "id", "text"
|
| 29 |
+
Voice cloning: "ref_audio", "ref_text"
|
| 30 |
+
Voice design: "instruct"
|
| 31 |
+
Optional: "language_id", "language_name", "duration", "speed"
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
import argparse
|
| 35 |
+
import logging
|
| 36 |
+
import multiprocessing as mp
|
| 37 |
+
import os
|
| 38 |
+
import signal
|
| 39 |
+
import time
|
| 40 |
+
import traceback
|
| 41 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 42 |
+
from typing import List, Optional, Tuple
|
| 43 |
+
|
| 44 |
+
import torch
|
| 45 |
+
import torchaudio
|
| 46 |
+
from tqdm import tqdm
|
| 47 |
+
|
| 48 |
+
from omnivoice.models.omnivoice import OmniVoice
|
| 49 |
+
from omnivoice.utils.audio import load_audio
|
| 50 |
+
from omnivoice.utils.common import str2bool
|
| 51 |
+
from omnivoice.utils.data_utils import read_test_list
|
| 52 |
+
from omnivoice.utils.duration import RuleDurationEstimator
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def get_best_device():
    """Return ``(device_type, device_count)``, preferring CUDA, then MPS, then CPU."""
    if torch.cuda.is_available():
        return "cuda", torch.cuda.device_count()
    has_mps = torch.backends.mps.is_available()
    return ("mps", 1) if has_mps else ("cpu", 1)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# Per-process model handle; populated by ``process_init`` inside each worker.
worker_model = None
# Sample rate (Hz) used when loading reference audio for duration estimates.
SAMPLING_RATE = 24000
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def get_parser() -> argparse.ArgumentParser:
    """Build the argument parser for multi-GPU batch inference."""
    parser = argparse.ArgumentParser(description="Infer OmniVoice Model")
    parser.add_argument(
        "--model",
        type=str,
        default="k2-fsa/OmniVoice",
        help="Path to the model checkpoint (local dir or HF repo id). "
        "Audio tokenizer is expected at <checkpoint>/audio_tokenizer/.",
    )
    parser.add_argument(
        "--test_list",
        type=str,
        required=True,
        help="Path to the JSONL file containing test samples. "
        'Each line is a JSON object: {"id": "name", "text": "...", '
        '"ref_audio": "/path.wav", "ref_text": "...", '
        '"language_id": "en", "language_name": "English", '
        '"duration": 10.0, "speed": 1.2}. '
        "language_id, language_name, duration, and speed are optional.",
    )
    parser.add_argument(
        "--res_dir",
        type=str,
        required=True,
        help="Directory to save the generated audio files.",
    )
    parser.add_argument(
        "--num_step",
        type=int,
        default=32,
        help="Number of steps for iterative decoding.",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=2.0,
        help="Scale for Classifier-Free Guidance.",
    )
    parser.add_argument(
        "--t_shift",
        type=float,
        default=0.1,
        help="Shift t to smaller ones if t_shift < 1.0",
    )
    parser.add_argument(
        "--nj_per_gpu",
        type=int,
        default=1,
        help="Number of worker processes to spawn per GPU.",
    )
    parser.add_argument(
        "--audio_chunk_duration",
        type=float,
        default=15.0,
        help="Maximum duration of audio chunk (in seconds) for splitting. "
        '"Not split" if <= 0.',
    )
    parser.add_argument(
        "--audio_chunk_threshold",
        type=float,
        default=30.0,
        help=(
            "The duration threshold (in seconds) to decide"
            " whether to split audio into chunks."
        ),
    )
    # Batching: either duration-budget batching (default) or fixed batch_size.
    parser.add_argument(
        "--batch_duration",
        type=float,
        default=1000.0,
        help="Maximum total duration (reference + generated) per batch (seconds). "
        "Only effective for parallel_chunk / no chunk mode.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=0,
        help="Fixed batch size (number of samples per batch). "
        "If > 0, use fixed-size batching instead of duration-based batching.",
    )
    parser.add_argument(
        "--warmup",
        type=int,
        default=0,
        help="Number of dummy inference runs per worker before real inference "
        "starts, to warm up CUDA kernels and caches.",
    )
    parser.add_argument(
        "--preprocess_prompt",
        type=str2bool,
        default=True,
        help="Whether to preprocess reference audio (silence removal, trimming). "
        "Set to False to keep raw audio.",
    )
    parser.add_argument(
        "--postprocess_output",
        type=str2bool,
        default=True,
        help="Whether to post-process generated audio (remove silence).",
    )
    parser.add_argument(
        "--layer_penalty_factor",
        type=float,
        default=5.0,
        help="The penalty factor for layer-wise sampling.",
    )
    parser.add_argument(
        "--position_temperature",
        type=float,
        default=5.0,
        help="The temperature for position selection.",
    )
    parser.add_argument(
        "--class_temperature",
        type=float,
        default=0.0,
        help="The temperature for class token sampling.",
    )
    parser.add_argument(
        "--denoise",
        type=str2bool,
        default=True,
        help="Whether to add <|denoise|> token in the reference.",
    )
    parser.add_argument(
        "--lang_id",
        type=str,
        default=None,
        help="Language id to use when test_list JSONL entries do not contain "
        "language_id/language_name fields. If provided, both language_id and "
        "language_name will be set to this value.",
    )
    return parser
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def process_init(rank_queue, model_checkpoint, warmup=0):
    """Initializer run once in each worker process.

    Pops one device assignment from ``rank_queue``, loads the model (with
    tokenizers and duration estimator) onto that device via
    ``OmniVoice.from_pretrained()``, and optionally runs a few warmup
    generations so CUDA kernels and caches are primed before real work.

    Args:
        rank_queue: Shared queue of ``(device_type, device_id)`` tuples;
            each worker consumes exactly one entry.
        model_checkpoint: Local path or HuggingFace repo id of the model.
        warmup: Number of dummy generation runs to perform after loading.
    """
    global worker_model

    # Keep per-worker CPU thread usage modest; many workers share the host.
    torch.set_num_threads(2)
    torch.set_num_interop_threads(2)

    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] "
        "[Worker %(process)d] %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    device_type, device_id = rank_queue.get()
    if device_type in ("cpu", "mps"):
        # CPU/MPS have no per-index device naming.
        worker_device = device_type
    else:
        worker_device = f"cuda:{device_id}"

    logging.info(f"Initializing worker on device: {worker_device}")

    worker_model = OmniVoice.from_pretrained(
        model_checkpoint,
        device_map=worker_device,
        dtype=torch.float16,
    )

    if warmup > 0:
        logging.info(f"Running {warmup} warmup iterations on {worker_device}")
        # One second of random noise as a stand-in (waveform, sample_rate)
        # reference pair; content is irrelevant for warming up kernels.
        dummy_ref_audio = (
            torch.randn(1, SAMPLING_RATE),
            SAMPLING_RATE,
        )
        for _ in range(warmup):
            worker_model.generate(
                text=["hello"],
                language=["en"],
                ref_audio=[dummy_ref_audio],
                ref_text=["hello"],
            )
        logging.info(f"Warmup complete on {worker_device}")

    logging.info(f"Worker on {worker_device} initialized successfully.")
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def estimate_sample_total_duration(
    duration_estimator: RuleDurationEstimator,
    text: str,
    ref_text: str,
    ref_audio_path: str,
    gen_duration: Optional[float] = None,
) -> float:
    """Estimate the total (reference + generated) duration of one sample.

    When ``gen_duration`` is not given, the generated part is predicted by
    the rule-based estimator from the text lengths and reference duration.
    """
    waveform = load_audio(ref_audio_path, SAMPLING_RATE)
    ref_seconds = waveform.shape[-1] / SAMPLING_RATE

    if gen_duration is not None:
        generated_seconds = gen_duration
    else:
        generated_seconds = duration_estimator.estimate_duration(
            text, ref_text, ref_seconds, low_threshold=2.0
        )

    return ref_seconds + generated_seconds
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def cluster_samples_by_duration(
    samples: List[Tuple],
    duration_estimator: RuleDurationEstimator,
    batch_duration: float,
) -> List[List[Tuple]]:
    """Greedily pack samples into batches within a total-duration budget.

    Samples are sorted longest-first so each batch holds clips of similar
    length (less padding); any sample longer than the budget runs alone.
    """
    measured = []
    for sample in samples:
        _, ref_text, ref_audio_path, text, _, _, dur, _ = sample
        seconds = estimate_sample_total_duration(
            duration_estimator,
            text,
            ref_text,
            ref_audio_path,
            gen_duration=dur,
        )
        measured.append((sample, seconds))

    measured.sort(key=lambda item: item[1], reverse=True)

    batches: List[List[Tuple]] = []
    pending: List[Tuple] = []
    pending_seconds = 0.0

    for sample, seconds in measured:
        if seconds > batch_duration:
            # Oversized sample: it always forms a batch of its own.
            batches.append([sample])
        elif pending_seconds + seconds <= batch_duration:
            pending.append(sample)
            pending_seconds += seconds
        else:
            batches.append(pending)
            pending = [sample]
            pending_seconds = seconds

    if pending:
        batches.append(pending)

    logging.info(f"Clustered {len(samples)} samples into {len(batches)} batches")
    return batches
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def cluster_samples_by_batch_size(
    samples: List[Tuple],
    duration_estimator: RuleDurationEstimator,
    batch_size: int,
) -> List[List[Tuple]]:
    """Split samples into fixed-size batches, sorted by duration to minimize padding."""

    def total_seconds(sample: Tuple) -> float:
        # Estimate reference + generated length for ordering purposes.
        _, ref_text, ref_audio_path, text, _, _, dur, _ = sample
        return estimate_sample_total_duration(
            duration_estimator,
            text,
            ref_text,
            ref_audio_path,
            gen_duration=dur,
        )

    ordered = sorted(samples, key=total_seconds, reverse=True)
    batches = [
        ordered[start : start + batch_size]
        for start in range(0, len(ordered), batch_size)
    ]
    logging.info(
        f"Split {len(samples)} samples into {len(batches)} batches "
        f"(fixed batch_size={batch_size}, sorted by duration)"
    )
    return batches
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def run_inference_batch(
    batch_samples: List[Tuple],
    res_dir: str,
    **gen_kwargs,
) -> List[Tuple]:
    """Run one generation call for a batch of samples inside a worker.

    Uses the process-global ``worker_model`` loaded by ``process_init``.
    Saves each generated waveform to ``<res_dir>/<id>.wav`` and returns
    per-sample ``(id, per_sample_synth_time, audio_duration, "success")``
    tuples for the parent's summary statistics.

    NOTE(review): ``gen_kwargs`` is forwarded verbatim to
    ``worker_model.generate``. When called from ``main()`` with ``vars(args)``
    it also carries CLI-only keys (model, test_list, nj_per_gpu, ...) —
    presumably ``generate`` tolerates/ignores extras; confirm its signature.
    """
    global worker_model  # read-only access; global statement kept for clarity

    save_names = []
    ref_texts = []
    ref_audio_paths = []
    texts = []
    langs = []
    durations = []
    speeds = []

    # Transpose the 8-field sample tuples into parallel per-field lists.
    # lang_name is intentionally unused here; generate() takes language ids.
    for sample in batch_samples:
        save_name, ref_text, ref_audio_path, text, lang_id, lang_name, dur, spd = sample
        save_names.append(save_name)
        ref_texts.append(ref_text)
        ref_audio_paths.append(ref_audio_path)
        texts.append(text)
        langs.append(lang_id)
        durations.append(dur)
        speeds.append(spd)

    start_time = time.time()
    audios = worker_model.generate(
        text=texts,
        language=langs,
        ref_audio=ref_audio_paths,
        ref_text=ref_texts,
        # Pass per-sample overrides only if at least one sample provides them.
        duration=durations if any(d is not None for d in durations) else None,
        speed=speeds if any(s is not None for s in speeds) else None,
        **gen_kwargs,
    )
    batch_synth_time = time.time() - start_time

    results = []
    for save_name, audio in zip(save_names, audios):
        save_path = os.path.join(res_dir, save_name + ".wav")
        torchaudio.save(save_path, audio, worker_model.sampling_rate)
        audio_duration = audio.shape[-1] / worker_model.sampling_rate
        results.append(
            (
                save_name,
                # Batch time is split evenly; per-sample timing is approximate.
                batch_synth_time / len(batch_samples),
                audio_duration,
                "success",
            )
        )

    return results
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def main():
    """Entry point: fan TTS batches out over all available devices.

    Reads the JSONL test list, groups samples into batches, and dispatches
    each batch to a pool of worker processes (one or more per device).
    """
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)
    # "spawn" is required so each worker gets a clean CUDA context.
    mp.set_start_method("spawn", force=True)

    args = get_parser().parse_args()
    os.makedirs(args.res_dir, exist_ok=True)

    device_type, num_devices = get_best_device()
    if device_type == "cpu":
        logging.warning(
            "No GPU found. Falling back to CPU inference. This might be slow."
        )

    num_processes = num_devices * args.nj_per_gpu
    logging.info(
        f"Using {device_type} ({num_devices} device(s))."
        f" Spawning {num_processes} worker processes."
    )

    # Each worker pops one (device_type, device_index) assignment at init.
    manager = mp.Manager()
    rank_queue = manager.Queue()
    for rank in list(range(num_devices)) * args.nj_per_gpu:
        rank_queue.put((device_type, rank))

    # Normalize test-list entries into 8-field tuples:
    # (id, ref_text, ref_audio, text, lang_id, lang_name, duration, speed).
    samples_raw = read_test_list(args.test_list)
    samples = []
    for s in samples_raw:
        if args.lang_id is not None:
            # CLI override applies the same id to both fields.
            lang_id = args.lang_id
            lang_name = args.lang_id
        else:
            lang_id = s.get("language_id")
            lang_name = s.get("language_name")
        samples.append(
            (
                s["id"],
                s["ref_text"],
                s["ref_audio"],
                s["text"],
                lang_id,
                lang_name,
                s.get("duration"),
                s.get("speed"),
            )
        )

    total_synthesis_time = []
    total_audio_duration = []

    try:
        with ProcessPoolExecutor(
            max_workers=num_processes,
            initializer=process_init,
            initargs=(rank_queue, args.model, args.warmup),
        ) as executor:
            futures = []

            # parallel_chunk / no chunk
            logging.info("Running batch inference")

            duration_estimator = RuleDurationEstimator()
            if args.batch_size > 0:
                batches = cluster_samples_by_batch_size(
                    samples, duration_estimator, args.batch_size
                )
            else:
                batches = cluster_samples_by_duration(
                    samples, duration_estimator, args.batch_duration
                )

            # NOTE(review): vars(args) includes every CLI flag (model,
            # test_list, nj_per_gpu, batch_size, ...). run_inference_batch
            # consumes res_dir and forwards the rest to generate() via
            # **gen_kwargs — confirm generate() tolerates the extra keys.
            args_dict = vars(args)

            for batch in batches:
                futures.append(
                    executor.submit(
                        run_inference_batch, batch_samples=batch, **args_dict
                    )
                )

            for future in tqdm(
                as_completed(futures), total=len(futures), desc="Processing samples"
            ):
                try:
                    result = future.result()
                    for s_name, synth_time, audio_dur, status in result:
                        total_synthesis_time.append(synth_time)
                        total_audio_duration.append(audio_dur)
                        rtf = synth_time / audio_dur if audio_dur > 0 else float("inf")
                        logging.debug(
                            f"Processed {s_name}: Audio Duration={audio_dur:.2f}s, "
                            f"Synthesis Time={synth_time:.2f}s, RTF={rtf:.4f}"
                        )
                except Exception as e:
                    # A failed batch is logged and skipped; remaining batches
                    # continue to completion.
                    logging.error(f"Failed to process sample: {e}")
                    detailed_error = traceback.format_exc()
                    logging.error(f"Detailed error: {detailed_error}")

    except (Exception, KeyboardInterrupt) as e:
        logging.critical(
            f"An unrecoverable error occurred: {e}. Terminating all processes."
        )
        detailed_error_info = traceback.format_exc()
        logging.error(f"--- DETAILED TRACEBACK ---\n{detailed_error_info}")
        # NOTE(review): SIGKILL to the whole process group also kills this
        # parent process immediately — nothing after this line runs.
        os.killpg(os.getpgid(os.getpid()), signal.SIGKILL)

    # Rebind the per-sample lists to their scalar totals for the summary.
    total_synthesis_time = sum(total_synthesis_time)
    total_audio_duration = sum(total_audio_duration)
    logging.info("--- Summary ---")
    logging.info(f"Total audio duration: {total_audio_duration:.2f}s")
    logging.info(f"Total synthesis time: {total_synthesis_time:.2f}s")
    if total_audio_duration > 0:
        average_rtf = total_synthesis_time / total_audio_duration
        logging.info(f"Average RTF: {average_rtf:.4f}")
    else:
        logging.warning("No speech was generated. RTF cannot be computed.")

    logging.info("Done!")


if __name__ == "__main__":
    main()
|
omnivoice/cli/train.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""Training CLI for OmniVoice.
|
| 19 |
+
|
| 20 |
+
Launches distributed training via HuggingFace Accelerate.
|
| 21 |
+
Supports pre-training on Emilia data and finetuning on custom data.
|
| 22 |
+
|
| 23 |
+
Usage:
|
| 24 |
+
accelerate launch --gpu_ids 0,1,2,3 --num_processes 4 \\
|
| 25 |
+
-m omnivoice.cli.train \\
|
| 26 |
+
--train_config train_config.json \\
|
| 27 |
+
--data_config data_config.json \\
|
| 28 |
+
--output_dir output/
|
| 29 |
+
|
| 30 |
+
See examples/run_emilia.sh and examples/run_finetune.sh for full pipelines.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
import argparse
|
| 34 |
+
|
| 35 |
+
from omnivoice.training.builder import build_dataloaders, build_model_and_tokenizer
|
| 36 |
+
from omnivoice.training.config import TrainingConfig
|
| 37 |
+
from omnivoice.training.trainer import OmniTrainer
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def main():
    """Parse CLI arguments, assemble training components, and launch training."""
    parser = argparse.ArgumentParser(description="OmniVoice Training Entry Point")
    for flag, help_text in (
        ("--train_config", "Path to config JSON"),
        ("--output_dir", "Where to save checkpoints"),
        ("--data_config", "Path to data config JSON"),
    ):
        parser.add_argument(flag, type=str, required=True, help=help_text)
    args = parser.parse_args()

    # Load the training configuration and record CLI-provided paths on it.
    config = TrainingConfig.from_json(args.train_config)
    config.output_dir = args.output_dir
    config.data_config = args.data_config

    # Construct the model/tokenizer pair, then the data pipelines feeding it.
    model, tokenizer = build_model_and_tokenizer(config)
    train_loader, eval_loader = build_dataloaders(config, tokenizer)

    # Hand everything to the trainer and start the run.
    trainer = OmniTrainer(
        model=model,
        config=config,
        train_dataloader=train_loader,
        eval_dataloader=eval_loader,
        tokenizer=tokenizer,
    )
    trainer.train()


if __name__ == "__main__":
    main()
|
omnivoice/data/__init__.py
ADDED
|
File without changes
|
omnivoice/data/batching.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""Batching strategies for streaming/iterable datasets.
|
| 19 |
+
|
| 20 |
+
Provides length-based grouping and packing for efficient training with
|
| 21 |
+
variable-length audio.
|
| 22 |
+
|
| 23 |
+
Key classes:
|
| 24 |
+
- ``PackingIterableDataset``: Packs multiple samples into fixed-length sequences
|
| 25 |
+
for training. Used by ``omnivoice.training.builder``.
|
| 26 |
+
- ``StreamLengthGroupDataset``: Groups samples by length into buckets. Used by
|
| 27 |
+
data processing scripts (e.g. ``omnivoice/scripts/``).
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
import bisect
|
| 31 |
+
import logging
|
| 32 |
+
from typing import Any, Dict, Iterator, List, Optional
|
| 33 |
+
|
| 34 |
+
import numpy as np
|
| 35 |
+
|
| 36 |
+
from omnivoice.data.dataset import IterableDataReader, WrappedIterableDataset
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class StreamLengthGroupDataset(WrappedIterableDataset):
    """Stream samples grouped into duration buckets.

    Incoming audio samples are routed into ``num_buckets`` buckets whose upper
    boundaries are evenly spaced between ``min_length`` and ``max_length``
    seconds.  A bucket is emitted as a batch once padding one more sample to
    the bucket's longest member would reach ``batch_duration`` seconds, or
    once the bucket holds ``max_sample`` samples.  Only audio data is
    supported for now.
    """

    def __init__(
        self,
        dataset: IterableDataReader,
        batch_duration: float,
        min_length: float = 0.5,
        max_length: float = 30.0,
        num_buckets: int = 20,
        audio_key: str = "audio",
        drop_last: bool = False,
        max_sample: Optional[int] = None,
    ):
        self.dataset = dataset
        self.batch_duration = batch_duration
        self.min_length = min_length
        self.max_length = max_length
        self.num_buckets = num_buckets
        self.audio_key = audio_key
        self.drop_last = drop_last
        # ``inf`` disables the per-batch sample-count cap.
        self.max_sample = float("inf") if max_sample is None else max_sample

        # Upper edges of each duration bucket; the lower edge of the first
        # bucket is ``min_length`` itself.
        self.boundaries = np.linspace(min_length, max_length, num_buckets + 1)[1:]

    def set_epoch(self, epoch: int):
        """Forward the epoch to the underlying reader (used for shuffling)."""
        self.dataset.set_epoch(epoch)

    def _get_bucket_id(self, length: float) -> int:
        # Index of the first boundary >= length.
        return bisect.bisect_left(self.boundaries, length)

    def __iter__(self) -> Iterator[List[Dict[str, Any]]]:
        pending: List[List[Dict[str, Any]]] = [[] for _ in range(self.num_buckets)]
        longest = [0.0] * self.num_buckets

        for sample in self.dataset:
            waveform = sample[self.audio_key]
            seconds = waveform.size(-1) / self.dataset.sample_rate

            # Discard samples outside the configured duration window.
            if not (self.min_length <= seconds <= self.max_length):
                continue

            bucket = self._get_bucket_id(seconds)
            pending[bucket].append(sample)
            if seconds > longest[bucket]:
                longest[bucket] = seconds

            # Emit the bucket once padding one more sample up to the longest
            # duration would reach the budget, or once the sample cap is hit.
            padded_total = longest[bucket] * (len(pending[bucket]) + 1)
            if (
                padded_total >= self.batch_duration
                or len(pending[bucket]) >= self.max_sample
            ):
                yield pending[bucket]
                pending[bucket] = []
                longest[bucket] = 0.0

        # Flush whatever is left in the buckets unless the caller asked to
        # drop incomplete batches.
        if not self.drop_last:
            for idx in range(self.num_buckets):
                if pending[idx]:
                    yield pending[idx]
                    pending[idx] = []
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class PackingIterableDataset(WrappedIterableDataset):
    """Process raw samples on the fly and pack them into token-bounded batches.

    Each raw sample is run through ``processor`` (which must report a
    ``"length"`` in tokens); processed samples are then accumulated into
    batches whose total token count never exceeds ``batch_tokens``.

    Args:
        dataset (Iterable): The raw dataset to process.
        processor (Callable): A processor to process each sample.
        batch_tokens (int): Maximum number of tokens per batch.
    """

    def __init__(
        self,
        dataset: IterableDataReader,
        processor: Any,
        batch_tokens: int,
    ):
        self.dataset = dataset
        self.processor = processor
        self.batch_tokens = batch_tokens
        self.skip_batches = 0

    def set_epoch(self, epoch: int):
        """Forward the epoch to the underlying reader (used for shuffling)."""
        self.dataset.set_epoch(epoch)

    def __iter__(self) -> Iterator[List[Dict[str, Any]]]:
        batch: List[Dict[str, Any]] = []
        used_tokens = 0

        for raw_sample in self.dataset:
            # Processing failures are logged and skipped rather than aborting
            # the whole epoch.
            try:
                processed = self.processor(raw_sample)
            except Exception as e:
                logging.warning(f"Error processing sample {raw_sample}: {e}")
                continue

            n_tokens = processed["length"]

            # A single sample larger than the budget can never fit.
            if n_tokens > self.batch_tokens:
                continue

            # Flush the current batch if this sample would overflow it.
            if used_tokens + n_tokens > self.batch_tokens:
                yield batch
                batch = []
                used_tokens = 0

            batch.append(processed)
            used_tokens += n_tokens

        # Emit the trailing partial batch, if any.
        if batch:
            yield batch
|
omnivoice/data/collator.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""Data collator with packing for efficient training.
|
| 19 |
+
|
| 20 |
+
Packs multiple samples into a single sequence of fixed length (``batch_tokens``)
|
| 21 |
+
to maximize GPU utilization, instead of padding each sample individually.
|
| 22 |
+
Used by ``omnivoice.training.builder`` to create the collate function.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
from typing import Any, Dict, List
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class PackingDataCollator:
    """Collate pre-processed samples into one packed, fixed-length batch.

    All samples are concatenated along the time axis and right-padded up to
    ``batch_tokens`` so the model always sees a sequence of the same length.
    ``document_ids`` records which sample each position belongs to (``-1``
    marks padding), and ``position_ids`` restarts from 0 at every sample.
    """

    def __init__(self, processor, batch_tokens: int):
        self.batch_tokens = batch_tokens
        self.processor = processor

    def __call__(self, processed_samples: List[Dict[str, Any]]) -> Dict[str, Any]:
        pad = torch.nn.functional.pad

        # Concatenate every field along the time dimension.
        # input_ids / labels: [C, total], where C is the number of codebook
        # layers of the audio tokenizer; audio_mask / position_ids: [total].
        input_ids = torch.cat([s["input_ids"] for s in processed_samples], dim=1)
        labels = torch.cat([s["labels"] for s in processed_samples], dim=1)
        audio_mask = torch.cat([s["audio_mask"] for s in processed_samples], dim=0)
        position_ids = torch.cat(
            [torch.arange(s["length"], dtype=torch.long) for s in processed_samples],
            dim=0,
        )
        # Sample index per position; used for packed/document attention.
        document_ids = torch.cat(
            [
                torch.full((s["length"],), i, dtype=torch.int32)
                for i, s in enumerate(processed_samples)
            ],
            dim=0,
        )

        remaining = self.batch_tokens - input_ids.shape[1]
        pad_id = self.processor.text_tokenizer.pad_token_id

        # Right-pad everything to the fixed batch length and add the leading
        # batch dimension of size 1 (packing yields a single long sequence).
        return {
            "input_ids": pad(input_ids, (0, remaining), value=pad_id).unsqueeze(0),  # [1, C, L]
            "labels": pad(labels, (0, remaining), value=-100).unsqueeze(0),  # [1, C, L]
            "audio_mask": pad(audio_mask, (0, remaining), value=False).unsqueeze(0),  # [1, L]
            "position_ids": pad(position_ids, (0, remaining), value=0).unsqueeze(0),  # [1, L]
            "document_ids": pad(document_ids, (0, remaining), value=-1).unsqueeze(0),  # [1, L]
        }
|
omnivoice/data/dataset.py
ADDED
|
@@ -0,0 +1,551 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""Dataset and data-loading utilities for training and evaluation.
|
| 19 |
+
|
| 20 |
+
Provides WebDataset-based iterable datasets, manifest parsing, and audio/token
|
| 21 |
+
loading. Used by ``omnivoice.training.builder.build_dataloaders()`` to construct
|
| 22 |
+
train and eval data loaders.
|
| 23 |
+
|
| 24 |
+
Key functions:
|
| 25 |
+
- ``prepare_data_manifests_from_json()``: Parses a data config JSON into train/dev
|
| 26 |
+
manifests.
|
| 27 |
+
|
| 28 |
+
Key classes:
|
| 29 |
+
- ``WebDatasetReader``: Reads audio/text pairs from WebDataset tar shards as an
|
| 30 |
+
iterable dataset.
|
| 31 |
+
- ``MuxWebDatasetReader``: Multiplexes multiple WebDataset readers for
|
| 32 |
+
multilingual data.
|
| 33 |
+
- ``JsonlDatasetReader``: Reads audio/text pairs from a JSONL manifest file.
|
| 34 |
+
Used by data processing scripts (e.g. ``omnivoice/scripts/``).
|
| 35 |
+
- ``SampleDecoder``: Decodes individual samples (audio or tokens + labels).
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
import io
|
| 39 |
+
import json
|
| 40 |
+
import logging
|
| 41 |
+
import os
|
| 42 |
+
import random
|
| 43 |
+
from typing import Any, Dict, Iterator, List, Optional, Tuple
|
| 44 |
+
|
| 45 |
+
import torch
|
| 46 |
+
import torch.distributed as dist
|
| 47 |
+
import torchaudio
|
| 48 |
+
import webdataset as wds
|
| 49 |
+
from torch.utils.data import IterableDataset
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def load_audio_webdataset(data, sample_rate: int = 24000, device="cpu"):
    """Decode audio bytes and resample to the target sample rate if needed.

    Args:
        data: Raw encoded audio bytes (e.g. a flac/wav/mp3 payload from a
            WebDataset sample).
        sample_rate: Target sample rate in Hz.
        device: Device the decoded waveform is moved to.

    Returns:
        A tensor of shape (1, num_samples).
    """
    audio, sr = torchaudio.load(io.BytesIO(data))
    audio = audio.to(device)
    if audio.size(dim=0) > 1:
        # keepdim=True preserves the documented (1, num_samples) shape
        # (matching JsonlDatasetReader); without it a multi-channel input
        # collapsed to (num_samples,).
        audio = torch.mean(audio, dim=0, keepdim=True)
    if sr != sample_rate:
        audio = torchaudio.functional.resample(audio, sr, sample_rate)
    return audio
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def prepare_data_manifests_from_json(
    data_config: str,
) -> Tuple[List[Tuple[str, str, int, float]], List[Tuple[str, str, int, float]]]:
    """
    Prepare data manifests from a json file.
    A typical multilingual json file is in the following format:
    {
        "train":
        [
            {
                "language_id": "en",
                "manifest_path": [
                    "/Emilia/EN/data.lst"
                ],
                "repeat": 1
            },
            {
                "language_id": "zh",
                "manifest_path": [
                    "/Emilia/ZH/data.lst"
                ],
                "repeat": 1
            }
        ],
        "dev":
        [
            ... same structure ...
        ]
    }

    "language_id" is not used, just for better organization of multilingual data.
    "repeat" is an optional field, default to 1, which indicates how many times
    the manifest should be repeated.

    The simplest format is like:
    {
        "train":
        [
            {
                "manifest_path": [
                    "/Emilia/EN/data.lst",
                    "/Emilia/ZH/data.lst"
                ],
            }
        ],
        "dev":
        [
            {
                "manifest_path": [
                    "/Emilia/EN-dev/data.lst",
                    "/Emilia/ZH-dev/data.lst"
                ],
            }
        ]
    }

    data.lst format (items separated by space):
    /path/to/data.tar /path/to/label.jsonl num_items num_seconds

    Returns:
        A tuple (train_manifests, dev_manifests); ``dev_manifests`` is empty
        when the config has no "dev" section.
    """

    def _collect(items) -> List[Tuple[str, str, int, float]]:
        # Expand every manifest path, honoring the optional "repeat" field.
        manifests: List[Tuple[str, str, int, float]] = []
        for item in items:
            repeat = item.get("repeat", 1)
            for manifest_path in item["manifest_path"]:
                # Validate existence for every split (previously only the
                # train split was checked).
                assert os.path.isfile(manifest_path), f"{manifest_path} is not a file."
                manifests.extend(webdataset_manifest_reader(manifest_path) * repeat)
        return manifests

    with open(data_config, "r", encoding="utf-8") as f:
        data = json.load(f)
    train_manifests = _collect(data["train"])
    dev_manifests = _collect(data["dev"]) if "dev" in data else []
    return train_manifests, dev_manifests
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def webdataset_manifest_reader(
    manifest_path: str,
) -> List[Tuple[str, str, int, float]]:
    """Read a manifest of webdataset tar shards and their label files.

    Each non-empty line in the manifest file has four whitespace-separated
    fields:
        /path/to/data.tar /path/to/label.jsonl num_items num_seconds

    Args:
        manifest_path: Path to the manifest (.lst) file.

    Returns:
        A list of (tar_path, label_jsonl_path, num_items, num_seconds)
        tuples.  (Return annotation fixed: the function returns 4-tuples,
        matching the caller ``prepare_data_manifests_from_json``.)

    Raises:
        ValueError: If a non-empty line does not contain exactly 4 fields.
    """
    manifests = []
    with open(manifest_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = line.split()
            if len(parts) != 4:
                raise ValueError(
                    f"Invalid manifest line: {line}. "
                    f"Each line must contain "
                    "tar_path, label_jsonl_path, num_items, num_seconds."
                )
            tar_path, label_jsonl_path, num_items, num_seconds = (
                parts[0],
                parts[1],
                int(parts[2]),
                float(parts[3]),
            )
            manifests.append((tar_path, label_jsonl_path, num_items, num_seconds))
    return manifests
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class SampleDecoder:
    """
    Decode a sample from webdataset, including loading audio/tokens and fetching label.

    Pre-extracted audio tokens (an ``npy`` entry) take precedence over raw
    audio; otherwise the first matching audio extension in ``audio_format``
    is decoded.  Labels are looked up by the sample key in the JSONL file
    paired with the sample's source tar.
    """

    def __init__(
        self,
        tar_to_label: Dict,
        sample_rate: int = 24000,
        audio_format: Optional[Tuple[str]] = None,
        normalize_audio: bool = True,
    ):
        """
        Args:
            tar_to_label:
                A dict mapping from audio tar file to label tar file.
            sample_rate:
                Target sample rate for audio. Required if audio is loaded.
            audio_format:
                Tuple of audio file extensions to look for in the sample.
                Defaults to ("flac", "wav", "mp3").
            normalize_audio:
                If True, peak-normalize decoded audio to 0.9.
        """
        self.tar_to_label = tar_to_label
        self.sample_rate = sample_rate
        # Lazily-opened label file for the current tar; see __call__.
        self.label_dataset = None
        if audio_format is None:
            self.audio_format = ("flac", "wav", "mp3")
        else:
            self.audio_format = audio_format
        self.normalize_audio = normalize_audio

    def __call__(self, sample):
        """Decode one webdataset sample dict into audio/tokens + label.

        Returns a dict with "label" plus either "audio_tokens" (when an
        ``npy`` field is present) or "audio" and "audio_duration".
        """
        return_dict = {}
        src = sample["__url__"]
        key = sample["__key__"]
        # Keep only one label file open at a time, reloading when the source
        # tar changes.  NOTE(review): this assumes samples from the same tar
        # arrive contiguously (true for sequential shard reading) — confirm
        # if the upstream pipeline ever interleaves tars.
        if (
            self.label_dataset is None
            or self.label_dataset.path != self.tar_to_label[src]
        ):
            self.label_dataset = LabelDataset(self.tar_to_label[src])

        audio = torch.empty(0)
        if "npy" in sample:
            # Pre-extracted audio tokens; no waveform decoding needed.
            audio_tokens = torch.from_numpy(sample["npy"])
            return_dict["audio_tokens"] = audio_tokens
        else:
            for ext in self.audio_format:
                if ext in sample:
                    # load audio (1, num_samples)
                    audio = load_audio_webdataset(
                        sample[ext], sample_rate=self.sample_rate
                    )
                    if self.normalize_audio:
                        # Peak-normalize to 0.9; epsilon guards against
                        # all-zero audio.
                        audio = (audio / (audio.abs().max() + 1e-7)) * 0.9
                    break
            # If no known extension matched, "audio" stays empty with
            # duration 0.
            return_dict["audio"] = audio
            return_dict["audio_duration"] = audio.size(-1) / self.sample_rate

        label = self.label_dataset[key]

        return_dict["label"] = label
        return return_dict
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class LabelDataset:
    def __init__(self, jsonl_path: str):
        """Load labels from a JSONL file into an in-memory id -> record map.

        Args:
            jsonl_path:
                Path to the JSONL file containing labels.  Each non-empty
                line is a JSON object such as:
                    {"id": "sample-id", "text": "transcription text"}
                Lines without an "id" field are skipped.

        Raises:
            FileNotFoundError: If ``jsonl_path`` does not exist.
        """
        self._labels = {}
        self.path = jsonl_path
        if not os.path.exists(jsonl_path):
            raise FileNotFoundError(f"Label jsonl file {jsonl_path} does not exist.")
        with open(jsonl_path, "r", encoding="utf-8") as f:
            for raw_line in f:
                stripped = raw_line.strip()
                if not stripped:
                    continue
                record = json.loads(stripped)
                if "id" in record:
                    self._labels[record["id"]] = record

    def __getitem__(self, key):
        # Raises KeyError for unknown ids, like a plain dict.
        return self._labels[key]
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
class IterableDataReader:
    """Interface for classes reading data (e.g. WebDatasetReader,
    JsonlDatasetReader, MuxWebDatasetReader)."""

    # Target sample rate (Hz) that concrete readers decode audio to.
    sample_rate: int

    def set_epoch(self, epoch: int):
        # Concrete readers derive shuffling order/seeds from the epoch.
        raise NotImplementedError

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        # Yields one decoded sample dict at a time.
        raise NotImplementedError

    def __len__(self) -> int:
        # Number of samples as reported by the concrete reader.
        raise NotImplementedError
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
class WrappedIterableDataset(IterableDataset):
    """IterableDataset interface used in this project: yields *lists* of
    sample dicts (batches), unlike IterableDataReader which yields single
    samples."""

    def set_epoch(self, epoch: int):
        # Subclasses forward the epoch to their underlying reader.
        raise NotImplementedError

    def __iter__(self) -> Iterator[List[Dict[str, Any]]]:
        # Yields one batch (list of sample dicts) at a time.
        raise NotImplementedError
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
class WebDatasetReader(IterableDataReader):
    """Iterable reader over WebDataset tar shards listed in manifests.

    Shard order is shuffled per epoch via :meth:`set_epoch` (unless
    ``evaluation`` is True), and samples are additionally shuffled through
    an in-memory buffer during iteration.
    """

    def __init__(
        self,
        manifests: List[Tuple[str, str, int, float]],
        evaluation: bool = False,
        shuffle_buffer_size: int = 20000,
        sample_rate: int = 24000,
    ):
        """
        Args:
            manifests: Tuples of
                (tar_path, label_jsonl_path, num_items, num_seconds),
                as produced by ``webdataset_manifest_reader``.
            evaluation: If True, disable all shuffling.
            shuffle_buffer_size: Size of the sample-level shuffle buffer.
            sample_rate: Target sample rate for decoded audio.
        """
        self.shuffle_buffer_size = shuffle_buffer_size
        self.evaluation = evaluation
        self.epoch = 0

        # Aggregate manifest metadata; num_items/num_seconds are the counts
        # announced by the manifests, not measured from the tars.
        self.orig_urls = []
        self.tar_to_label = {}
        self.num_items = 0
        self.num_seconds = 0.0
        for tar_path, label_jsonl_path, num_items, num_seconds in manifests:
            self.orig_urls.append(tar_path)
            self.tar_to_label[tar_path] = label_jsonl_path
            self.num_items += num_items
            self.num_seconds += num_seconds
        # Working copy of the shard list; reshuffled by set_epoch.
        self.urls = self.orig_urls.copy()
        self.sample_decoder = SampleDecoder(
            tar_to_label=self.tar_to_label,
            sample_rate=sample_rate,
        )
        self.sample_rate = sample_rate

    def set_epoch(self, epoch: int):
        """
        Set the epoch for shuffling.

        Reshuffles the shard list deterministically from the epoch so all
        workers/ranks agree on the order; no-op ordering in evaluation mode.
        """
        self.epoch = epoch
        self.urls = self.orig_urls.copy()
        if not self.evaluation:
            random.Random(epoch).shuffle(self.urls)

    def __iter__(self) -> Iterator[Dict[str, Any]]:

        # shardshuffle=False because shard order is already shuffled in
        # set_epoch; shards are split across dataloader workers and DDP
        # nodes by the webdataset splitters.
        dataset = wds.WebDataset(
            self.urls,
            shardshuffle=False,
            workersplitter=wds.split_by_worker,
            nodesplitter=wds.split_by_node,
        )

        # Decode raw fields, then map each sample through SampleDecoder
        # (audio/token loading + label lookup).
        pipeline = dataset.decode().map(self.sample_decoder)
        if not self.evaluation:
            pipeline = pipeline.shuffle(self.shuffle_buffer_size, seed=self.epoch)
        return iter(pipeline)

    def __len__(self) -> int:
        # Manifest-announced total; may differ from the number of samples
        # actually yielded per worker after shard splitting.
        return self.num_items
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
class JsonlDatasetReader(IterableDataReader):
    """Read raw JSONL and load audio files, matching WebDatasetReader output format.

    Each JSONL line should be a JSON object with at least:
        {"id": "...", "audio_path": "/path/to/audio.wav", ...}

    Yields dicts of the form: {"audio": Tensor(1, T), "label": dict}
    """

    def __init__(
        self,
        jsonl_path: str,
        sample_rate: int = 24_000,
        shuffle: bool = True,
        shuffle_seed: int = 42,
        normalize_audio: bool = True,
    ):
        """
        Args:
            jsonl_path: Path to the JSONL manifest.
            sample_rate: Target sample rate for decoded audio.
            shuffle: If True, load all entries into memory and shuffle them;
                otherwise stream the file lazily.
            shuffle_seed: Seed for the shuffle; overwritten by set_epoch.
            normalize_audio: If True, peak-normalize decoded audio to 0.9.
        """
        self.jsonl_path = jsonl_path
        self.sample_rate = sample_rate
        self.shuffle = shuffle
        self.shuffle_seed = shuffle_seed
        self.normalize_audio = normalize_audio

    def set_epoch(self, epoch: int):
        # Re-seed per epoch so each epoch sees a different order.
        self.shuffle_seed = epoch

    def _read_lines(self) -> List[dict]:
        """Read all JSONL entries into memory and optionally shuffle them."""
        entries = []
        with open(self.jsonl_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    entries.append(json.loads(line))
        if self.shuffle:
            # Use a local RNG instead of random.seed(): seeding the global
            # random module clobbered shared RNG state for the whole process.
            random.Random(self.shuffle_seed).shuffle(entries)
            logging.info(
                f"Shuffled {len(entries)} JSONL entries (seed={self.shuffle_seed})"
            )
        return entries

    def _stream_lines(self):
        """Lazily yield JSONL entries without loading the whole file."""
        with open(self.jsonl_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    yield json.loads(line)

    def __iter__(self):
        source = self._read_lines() if self.shuffle else self._stream_lines()

        # Split data across distributed ranks (multi-GPU / DDP)
        if dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            # NOTE(review): this materializes the (possibly streaming) source
            # into a list; fine for manifest-sized JSONL files.
            source = [item for i, item in enumerate(source) if i % world_size == rank]

        # Split data across DataLoader workers to avoid duplication
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            source = (
                item
                for i, item in enumerate(source)
                if i % worker_info.num_workers == worker_info.id
            )

        for meta in source:
            audio_path = meta.get("audio_path")
            if not audio_path or not os.path.exists(audio_path):
                logging.warning(
                    f"Skipping {meta.get('id', '?')}: audio_path missing or not found"
                )
                continue
            try:
                waveform, sr = torchaudio.load(audio_path)
                # Downmix to mono, keeping the (1, T) shape.
                if waveform.shape[0] > 1:
                    waveform = waveform.mean(dim=0, keepdim=True)
                if sr != self.sample_rate:
                    waveform = torchaudio.functional.resample(
                        waveform, sr, self.sample_rate
                    )
                if self.normalize_audio:
                    # Peak-normalize to 0.9; epsilon guards all-zero audio.
                    waveform = (waveform / (waveform.abs().max() + 1e-7)) * 0.9
                meta["audio_duration"] = waveform.shape[1] / self.sample_rate
                yield {"audio": waveform, "label": meta}
            except Exception as e:
                # Best-effort: skip unreadable audio rather than aborting.
                logging.warning(f"Skipping {meta.get('id', '?')}: {e}")
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
class MuxWebDatasetReader(IterableDataReader):
    """Multiplex several WebDatasetReaders into one sample stream.

    Items are drawn from the child readers at random via
    :class:`LazyIteratorMultiplexer`, weighted by reader length unless
    explicit ``weights`` are given.
    """

    def __init__(
        self,
        readers: List[WebDatasetReader],
        weights: Optional[List[float]] = None,
        stop_early: bool = False,
        seed: int = 0,
    ):
        """
        Args:
            readers: The child readers to interleave.
            weights: Optional sampling weight per reader.
            stop_early: Stop as soon as the first reader is exhausted.
            seed: Seed for the multiplexer's selection RNG.
        """
        self.readers = readers
        self.stop_early = stop_early
        # NOTE(review): the multiplexer's selection seed is fixed at
        # construction and not re-derived in set_epoch, so the reader
        # *selection* sequence repeats every epoch (sample order within each
        # reader still changes) — confirm this is intended.
        self.mux_iterator = LazyIteratorMultiplexer(
            *readers,
            stop_early=stop_early,
            weights=weights,
            seed=seed,
        )

    def set_epoch(self, epoch: int):
        """
        Set the epoch for shuffling.

        Forwards the epoch to every child reader.
        """
        for reader in self.readers:
            reader.set_epoch(epoch)

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        return iter(self.mux_iterator)
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
class LazyIteratorMultiplexer:
    """Randomly interleave several iterables into a single stream.

    At every step one of the still-active iterables is chosen at random and
    its next item is yielded.  Since the iterables might be of different
    length, a ``weights`` parameter lets the user make some iterables be
    sampled more frequently than others; by default weights are proportional
    to each iterable's length when available, otherwise uniform.  When an
    iterable is exhausted, sampling continues from the remaining ones until
    all are exhausted — unless ``stop_early`` is ``True``, in which case
    iteration stops at the first exhaustion.
    """

    def __init__(
        self,
        *iterators: IterableDataReader,
        stop_early: bool = False,
        weights: Optional[List[float]] = None,
        seed: int = 0,
    ) -> None:
        self.iterators = list(iterators)
        self.stop_early = stop_early
        self.seed = seed

        assert (
            len(self.iterators) > 1
        ), "There have to be at least two iterables to multiplex."

        if weights is not None:
            self.weights = weights
        elif all(hasattr(source, "__len__") for source in self.iterators):
            # Default: weight each source by its share of the total length.
            sizes = [len(source) for source in self.iterators]
            total = sum(sizes)
            self.weights = [size / total for size in sizes]
        else:
            # No lengths available: sample uniformly.
            self.weights = [1] * len(self.iterators)

        assert len(self.iterators) == len(self.weights)

    def __iter__(self):

        rng = random.Random(self.seed)
        streams = [iter(source) for source in self.iterators]
        done = [False] * len(streams)

        while True:
            # stop_early: quit as soon as any source dries up; otherwise
            # keep drawing until all of them do.
            if any(done) if self.stop_early else all(done):
                break
            candidates = [i for i, finished in enumerate(done) if not finished]
            candidate_weights = [self.weights[i] for i in candidates]
            pick = rng.choices(candidates, weights=candidate_weights, k=1)[0]
            try:
                item = next(streams[pick])
            except StopIteration:
                done[pick] = True
                continue
            yield item

    def __len__(self) -> int:
        return sum(len(iterator) for iterator in self.iterators)
|
omnivoice/data/processor.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""Training sample processor for OmniVoice.
|
| 19 |
+
|
| 20 |
+
Converts raw audio/text samples into model-ready tensors: applies prompt/mask
|
| 21 |
+
tokenization, randomly drops conditioning, and injects language/instruct tokens.
|
| 22 |
+
Used by ``omnivoice.training.builder`` to build the data pipeline.
|
| 23 |
+
|
| 24 |
+
Contains two processor classes:
|
| 25 |
+
- ``OmniVoiceSampleProcessor``: Full processor used for training.
|
| 26 |
+
- ``OmniVoiceSimpleSampleProcessor``: Simplified processor (not used for training).
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
import random
|
| 30 |
+
from typing import Any, Dict
|
| 31 |
+
|
| 32 |
+
import torch
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class OmniVoiceSampleProcessor:
    """
    Turns a raw training sample into model-ready tensors.

    Responsibilities: build the style prompt (language / instruct tags),
    tokenize the transcript (optionally its pinyin form), randomly mask the
    audio tokens, and assemble ``input_ids`` / ``labels`` / ``audio_mask``.
    """

    def __init__(
        self,
        text_tokenizer: Any,
        num_channels: int,
        audio_mask_id: int,
        prompt_ratio_range: tuple,
        mask_ratio_range: tuple,
        drop_cond_ratio: float,
        language_ratio: float,
        use_pinyin_ratio: float,
        instruct_ratio: float,
        only_instruct_ratio: float,
    ):
        self.text_tokenizer = text_tokenizer
        self.num_channels = num_channels
        self.audio_mask_id = audio_mask_id
        self.prompt_ratio_range = prompt_ratio_range
        self.mask_ratio_range = mask_ratio_range
        self.drop_cond_ratio = drop_cond_ratio

        self.language_ratio = language_ratio
        self.use_pinyin_ratio = use_pinyin_ratio
        self.instruct_ratio = instruct_ratio
        self.only_instruct_ratio = only_instruct_ratio

    def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        label = sample["label"]

        # Prompt-denoising samples carry ``clean_start_token_idx`` — the index
        # of the first clean generated token after the noise-augmented prompt.
        # Conditioning is never dropped for them, since recovering the clean
        # prompt IS the training signal.
        is_denoise = "clean_start_token_idx" in label
        drop_cond = (
            False if is_denoise else random.uniform(0, 1) < self.drop_cond_ratio
        )

        if drop_cond:
            prompt_ratio = 0.0
            drop_text = True
            use_language = False
            use_instruct = False
        else:
            prompt_ratio = random.uniform(*self.prompt_ratio_range)
            drop_text = False
            use_language = random.uniform(0, 1) < self.language_ratio
            use_instruct = random.uniform(0, 1) < self.instruct_ratio
            # Occasionally train instruct-only (no audio prompt).
            if use_instruct and random.uniform(0, 1) < self.only_instruct_ratio:
                prompt_ratio = 0.0

        mask_ratio = random.uniform(*self.mask_ratio_range)

        # --- Style prompt ---
        language = label.get("language_id", "None") if use_language else "None"
        instruct = label.get("instruct", "None") if use_instruct else "None"

        style = "<|denoise|>" if is_denoise else ""
        style += f"<|lang_start|>{language}<|lang_end|>"
        style += f"<|instruct_start|>{instruct}<|instruct_end|>"

        style_inputs = self.text_tokenizer(style, return_tensors="pt").input_ids.repeat(
            self.num_channels, 1
        )
        # The style prompt never contributes to the loss.
        style_labels = torch.full(style_inputs.shape, -100)

        # --- Text ---
        if (
            "text_pinyin" in label
            and random.uniform(0, 1) < self.use_pinyin_ratio
        ):
            text = label["text_pinyin"]
        else:
            text = label["text"]
        text_inputs = self.text_tokenizer(
            f"<|text_start|>{text}<|text_end|>", return_tensors="pt"
        ).input_ids.repeat(self.num_channels, 1)
        # Text tokens never contribute to the loss either.
        text_labels = torch.full(text_inputs.shape, -100)

        # --- Audio ---
        audio_tokens = sample["audio_tokens"].long()

        if is_denoise:
            prompt_length = label["clean_start_token_idx"]
        else:
            prompt_length = int(audio_tokens.shape[1] * prompt_ratio)

        audio_inputs = audio_tokens.clone()
        audio_labels = audio_tokens.clone()

        # Randomly replace post-prompt tokens with the mask id; only the
        # masked positions carry a training target.
        token_mask = torch.rand(audio_tokens[:, prompt_length:].shape) < mask_ratio
        audio_inputs[:, prompt_length:][token_mask] = self.audio_mask_id
        audio_labels[:, prompt_length:][~token_mask] = -100
        if not drop_cond:
            audio_labels[:, :prompt_length] = -100  # no loss on the prompt region

        # --- Assemble ---
        if drop_text:
            input_ids = audio_inputs
            labels = audio_labels
            audio_mask = torch.ones(input_ids.shape[1], dtype=torch.bool)
        else:
            input_ids = torch.cat([style_inputs, text_inputs, audio_inputs], dim=1)
            labels = torch.cat([style_labels, text_labels, audio_labels], dim=1)
            audio_mask = torch.zeros(input_ids.shape[1], dtype=torch.bool)
            audio_mask[style_inputs.shape[1] + text_inputs.shape[1]:] = True

        return {
            "input_ids": input_ids,  # [C, L]
            "labels": labels,  # [C, L]
            "audio_mask": audio_mask,  # [L]
            "length": input_ids.shape[1],
        }
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class OmniVoiceSimpleSampleProcessor:
    """
    Minimal reference processor: tokenize the transcript, randomly mask audio
    tokens, and assemble the model inputs.

    It omits language tags, instructions and denoising prompts.  It is NOT
    used for training (``OmniVoiceSampleProcessor`` covers this case); it is
    kept as an easy-to-read illustration of the basic logic.
    """

    def __init__(
        self,
        text_tokenizer: Any,
        num_channels: int,
        audio_mask_id: int,
        prompt_ratio_range: tuple,
        mask_ratio_range: tuple,
        drop_cond_ratio: float,
    ):
        self.text_tokenizer = text_tokenizer
        self.num_channels = num_channels
        self.audio_mask_id = audio_mask_id
        self.prompt_ratio_range = prompt_ratio_range
        self.mask_ratio_range = mask_ratio_range
        self.drop_cond_ratio = drop_cond_ratio

    def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        # Draw the per-sample randomness up front (drop conditioning, how much
        # to mask, how long the audio prompt is).
        drop_cond = random.uniform(0, 1) < self.drop_cond_ratio
        mask_ratio = random.uniform(*self.mask_ratio_range)
        prompt_ratio = 0.0 if drop_cond else random.uniform(*self.prompt_ratio_range)

        # --- Text ---
        text = sample["label"]["text"]
        text_inputs = self.text_tokenizer(
            f"<|text_start|>{text}<|text_end|>", return_tensors="pt"
        ).input_ids.repeat(self.num_channels, 1)
        text_labels = torch.full(text_inputs.shape, -100)  # no loss on text

        # --- Audio ---
        audio_tokens = sample["audio_tokens"].long()
        prompt_length = int(audio_tokens.shape[1] * prompt_ratio)
        audio_inputs = audio_tokens.clone()
        audio_labels = audio_tokens.clone()

        # Randomly replace post-prompt tokens with the mask id; only the
        # masked positions carry a training target.
        token_mask = torch.rand(audio_tokens[:, prompt_length:].shape) < mask_ratio
        audio_inputs[:, prompt_length:][token_mask] = self.audio_mask_id
        audio_labels[:, prompt_length:][~token_mask] = -100

        if not drop_cond:
            audio_labels[:, :prompt_length] = -100  # no loss on the prompt region

        # --- Assemble ---
        if drop_cond:
            input_ids = audio_inputs
            labels = audio_labels
            audio_mask = torch.ones(input_ids.shape[1], dtype=torch.bool)
        else:
            input_ids = torch.cat([text_inputs, audio_inputs], dim=1)
            labels = torch.cat([text_labels, audio_labels], dim=1)
            audio_mask = torch.zeros(input_ids.shape[1], dtype=torch.bool)
            audio_mask[text_inputs.shape[1]:] = True

        return {
            "input_ids": input_ids,  # [C, L]
            "labels": labels,  # [C, L]
            "audio_mask": audio_mask,  # [L]
            "length": input_ids.shape[1],
        }
|
omnivoice/eval/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
|
| 3 |
+
# Suppress specific warnings from zhconv that are not relevant to WER calculation
|
| 4 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
omnivoice/eval/models/ecapa_tdnn_wavlm.py
ADDED
|
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ECAPA_TDNN_WAVLM(nn.Module):
    """ECAPA-TDNN speaker-embedding model on top of WavLM SSL features.

    Waveforms are fed through a WavLM model (loaded via torch.hub / s3prl);
    the hidden states of all layers are combined with learned softmax weights
    and passed through the ECAPA-TDNN trunk to produce a fixed-size speaker
    embedding of dimension ``emb_dim``.
    """

    def __init__(
        self,
        feat_dim=80,  # channel count of the SSL features fed to layer1
        channels=512,
        emb_dim=192,
        global_context_att=False,
        sr=16000,  # sample rate used only to size the probe signal in get_feat_num
        ssl_model_path=None,  # directory with a local hubconf + wavlm_large.pt; None = fetch s3prl hub
    ):
        super().__init__()
        self.sr = sr

        # Load WavLM either from the s3prl torch.hub repo or from a local
        # checkpoint directory.
        if ssl_model_path is None:
            self.feature_extract = torch.hub.load("s3prl/s3prl", "wavlm_large")
        else:
            self.feature_extract = torch.hub.load(
                os.path.dirname(ssl_model_path),
                "wavlm_local",
                source="local",
                ckpt=os.path.join(ssl_model_path, "wavlm_large.pt"),
            )

        # Disable fp32 attention in layers 23 and 11 of the 24-layer model.
        # NOTE(review): presumably to keep dtype consistent under mixed
        # precision — confirm why only these two layers are patched.
        if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
            self.feature_extract.model.encoder.layers[23].self_attn,
            "fp32_attention",
        ):
            self.feature_extract.model.encoder.layers[
                23
            ].self_attn.fp32_attention = False
        if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
            self.feature_extract.model.encoder.layers[11].self_attn,
            "fp32_attention",
        ):
            self.feature_extract.model.encoder.layers[
                11
            ].self_attn.fp32_attention = False

        # One learned weight per SSL hidden-state layer (softmax-normalized
        # in get_feat).
        self.feat_num = self.get_feat_num()
        self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))

        self.instance_norm = nn.InstanceNorm1d(feat_dim)
        # self.channels = [channels] * 4 + [channels * 3]
        self.channels = [channels] * 4 + [1536]

        self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
        self.layer2 = SE_Res2Block(
            self.channels[0],
            self.channels[1],
            kernel_size=3,
            stride=1,
            padding=2,
            dilation=2,
            scale=8,
            se_bottleneck_dim=128,
        )
        self.layer3 = SE_Res2Block(
            self.channels[1],
            self.channels[2],
            kernel_size=3,
            stride=1,
            padding=3,
            dilation=3,
            scale=8,
            se_bottleneck_dim=128,
        )
        self.layer4 = SE_Res2Block(
            self.channels[2],
            self.channels[3],
            kernel_size=3,
            stride=1,
            padding=4,
            dilation=4,
            scale=8,
            se_bottleneck_dim=128,
        )

        # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
        # Outputs of layers 2-4 are concatenated before the 1x1 conv.
        cat_channels = channels * 3
        self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
        self.pooling = AttentiveStatsPool(
            self.channels[-1],
            attention_channels=128,
            global_context_att=global_context_att,
        )
        # Pooling returns mean+std, hence the doubled width.
        self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
        self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)

    def get_feat_num(self):
        """Probe the SSL model with 1 s of noise to count hidden-state layers."""
        self.feature_extract.eval()
        wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
        with torch.no_grad():
            features = self.feature_extract(wav)
        select_feature = features["hidden_states"]
        if isinstance(select_feature, (list, tuple)):
            return len(select_feature)
        else:
            return 1

    def get_feat(self, x):
        """Extract SSL features and collapse layers with learned weights.

        ``x`` is a batch of raw waveforms (the s3prl extractor takes a list of
        1-D tensors).  Returns instance-normalized features of shape
        (batch, feat, time).  The SSL forward runs under ``no_grad``; only the
        layer weights are trainable here.
        """
        with torch.no_grad():
            x = self.feature_extract([sample for sample in x])

        x = x["hidden_states"]
        # Stack per-layer states into (layers, batch, time, feat).
        if isinstance(x, (list, tuple)):
            x = torch.stack(x, dim=0)
        else:
            x = x.unsqueeze(0)
        # Softmax over layers, broadcast over (batch, time, feat).
        norm_weights = (
            F.softmax(self.feature_weight, dim=-1)
            .unsqueeze(-1)
            .unsqueeze(-1)
            .unsqueeze(-1)
        )
        x = (norm_weights * x).sum(dim=0)
        # (batch, time, feat) -> (batch, feat, time); small epsilon added
        # before instance norm (NOTE(review): presumably for numerical
        # stability — confirm intent).
        x = torch.transpose(x, 1, 2) + 1e-6

        x = self.instance_norm(x)
        return x

    def forward(self, x):
        """Waveforms -> speaker embeddings of shape (batch, emb_dim)."""
        x = self.get_feat(x)

        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)

        # Multi-layer feature aggregation (ECAPA-TDNN style).
        out = torch.cat([out2, out3, out4], dim=1)
        out = F.relu(self.conv(out))
        out = self.bn(self.pooling(out))
        out = self.linear(out)

        return out
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
|
| 162 |
+
|
| 163 |
+
""" Res2Conv1d + BatchNorm1d + ReLU
|
| 164 |
+
"""
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class Res2Conv1dReluBn(nn.Module):
    """Res2Net-style grouped Conv1d block (per split: conv -> ReLU -> BN).

    Requires ``in_channels == out_channels == channels``.  The channel axis is
    cut into ``scale`` equal splits; each convolved split is fed the running
    sum of the previous split outputs, and (for ``scale > 1``) the last split
    bypasses the convs entirely.
    """

    def __init__(
        self,
        channels,
        kernel_size=1,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
        scale=4,
    ):
        super().__init__()
        assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
        self.scale = scale
        self.width = channels // scale
        # scale == 1: the single split is convolved; otherwise the last split
        # is passed through untouched.
        self.nums = scale if scale == 1 else scale - 1

        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        for _ in range(self.nums):
            self.convs.append(
                nn.Conv1d(
                    self.width,
                    self.width,
                    kernel_size,
                    stride,
                    padding,
                    dilation,
                    bias=bias,
                )
            )
            self.bns.append(nn.BatchNorm1d(self.width))

    def forward(self, x):
        splits = torch.split(x, self.width, 1)
        pieces = []
        running = None
        for idx, (conv, bn) in enumerate(zip(self.convs, self.bns)):
            running = splits[idx] if idx == 0 else running + splits[idx]
            # Order: conv -> relu -> bn
            running = bn(F.relu(conv(running)))
            pieces.append(running)
        if self.scale != 1:
            pieces.append(splits[self.nums])
        return torch.cat(pieces, dim=1)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
""" Conv1d + BatchNorm1d + ReLU
|
| 226 |
+
"""
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class Conv1dReluBn(nn.Module):
    """Conv1d followed by ReLU and BatchNorm1d."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
    ):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            bias=bias,
        )
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        # conv -> relu -> bn
        hidden = self.conv(x)
        hidden = F.relu(hidden)
        return self.bn(hidden)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
""" The SE connection of 1D case.
|
| 257 |
+
"""
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
class SE_Connect(nn.Module):
    """Squeeze-and-Excitation gate for 1D feature maps.

    Squeeze: mean over time; excite: bottleneck MLP + sigmoid; the resulting
    per-channel gate rescales the input.
    """

    def __init__(self, channels, se_bottleneck_dim=128):
        super().__init__()
        self.linear1 = nn.Linear(channels, se_bottleneck_dim)
        self.linear2 = nn.Linear(se_bottleneck_dim, channels)

    def forward(self, x):
        squeezed = x.mean(dim=2)
        gate = torch.sigmoid(self.linear2(F.relu(self.linear1(squeezed))))
        return x * gate.unsqueeze(2)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
""" SE-Res2Block of the ECAPA-TDNN architecture.
|
| 276 |
+
"""
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
|
| 280 |
+
# return nn.Sequential(
|
| 281 |
+
# Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
|
| 282 |
+
# Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
|
| 283 |
+
# Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
|
| 284 |
+
# SE_Connect(channels)
|
| 285 |
+
# )
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
class SE_Res2Block(nn.Module):
    """SE-Res2Block of ECAPA-TDNN.

    1x1 conv -> Res2 dilated conv -> 1x1 conv -> SE gate, with a residual
    connection (projected by a 1x1 conv when the channel count changes).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        scale,
        se_bottleneck_dim,
    ):
        super().__init__()
        self.Conv1dReluBn1 = Conv1dReluBn(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0
        )
        self.Res2Conv1dReluBn = Res2Conv1dReluBn(
            out_channels, kernel_size, stride, padding, dilation, scale=scale
        )
        self.Conv1dReluBn2 = Conv1dReluBn(
            out_channels, out_channels, kernel_size=1, stride=1, padding=0
        )
        self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)

        # Channel-matching projection for the residual path.
        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
            )

    def forward(self, x):
        residual = self.shortcut(x) if self.shortcut else x
        out = self.Conv1dReluBn1(x)
        out = self.Res2Conv1dReluBn(out)
        out = self.Conv1dReluBn2(out)
        out = self.SE_Connect(out)
        return out + residual
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
""" Attentive weighted mean and standard deviation pooling.
|
| 334 |
+
"""
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
class AttentiveStatsPool(nn.Module):
    """Attentive statistics pooling: attention-weighted mean and std over time."""

    def __init__(self, in_dim, attention_channels=128, global_context_att=False):
        super().__init__()
        self.global_context_att = global_context_att

        # Conv1d with kernel size 1 behaves like a Linear but avoids
        # transposing the (batch, feat, time) inputs.
        if global_context_att:
            # Input gains global mean/std context channels (3x width).
            self.linear1 = nn.Conv1d(
                in_dim * 3, attention_channels, kernel_size=1
            )  # equals W and b in the paper
        else:
            self.linear1 = nn.Conv1d(
                in_dim, attention_channels, kernel_size=1
            )  # equals W and b in the paper
        self.linear2 = nn.Conv1d(
            attention_channels, in_dim, kernel_size=1
        )  # equals V and k in the paper

    def forward(self, x):
        if self.global_context_att:
            mean_ctx = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
            std_ctx = torch.sqrt(
                torch.var(x, dim=-1, keepdim=True) + 1e-10
            ).expand_as(x)
            attended_in = torch.cat((x, mean_ctx, std_ctx), dim=1)
        else:
            attended_in = x

        # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
        alpha = torch.softmax(self.linear2(torch.tanh(self.linear1(attended_in))), dim=2)
        mean = torch.sum(alpha * x, dim=2)
        var = torch.sum(alpha * (x**2), dim=2) - mean**2
        std = torch.sqrt(var.clamp(min=1e-9))
        return torch.cat([mean, std], dim=1)
|
omnivoice/eval/models/utmos.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""
|
| 19 |
+
UTMOS strong model.
|
| 20 |
+
Implementation from https://github.com/tarepan/SpeechMOS
|
| 21 |
+
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import math
|
| 25 |
+
from typing import List, Optional, Tuple
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
from torch import Tensor, nn
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class UTMOS22Strong(nn.Module):
    """Saeki_2022 paper's `UTMOS strong learner` inference model
    (w/o Phoneme encoder)."""

    def __init__(self):
        """Init."""

        super().__init__()  # pyright: ignore [reportUnknownMemberType]

        # Feature widths: SSL features, data-domain embedding, judge-id
        # embedding, BLSTM hidden size, projection hidden size.
        feat_ssl, feat_domain_emb, feat_judge_emb, feat_rnn_h, feat_proj_h = (
            768,
            128,
            128,
            512,
            2048,
        )
        feat_cat = feat_ssl + feat_domain_emb + feat_judge_emb

        # SSL/DataDomainEmb/JudgeIdEmb/BLSTM/Projection
        self.wav2vec2 = Wav2Vec2Model()
        # Frozen embeddings. torch.empty leaves them uninitialized, so their
        # values must come from a loaded checkpoint before inference.
        self.domain_emb = nn.Parameter(
            data=torch.empty(1, feat_domain_emb), requires_grad=False
        )
        self.judge_emb = nn.Parameter(
            data=torch.empty(1, feat_judge_emb), requires_grad=False
        )
        self.blstm = nn.LSTM(
            input_size=feat_cat,
            hidden_size=feat_rnn_h,
            batch_first=True,
            bidirectional=True,
        )
        self.projection = nn.Sequential(
            nn.Linear(feat_rnn_h * 2, feat_proj_h), nn.ReLU(), nn.Linear(feat_proj_h, 1)
        )

    def forward(self, wave: Tensor, sr: int) -> Tensor:  # pylint: disable=invalid-name
        """wave-to-score :: (B, T) -> (B,)

        Args:
            wave: batch of raw waveforms, shape (B, T).
            sr: sample rate; not used in this method (kept for interface
                compatibility), so the input must already be at the model's
                expected rate.

        Returns:
            Per-utterance predicted score, shape (B,).
        """

        # Feature extraction :: (B, T) -> (B, Frame, Feat)
        unit_series = self.wav2vec2(wave)
        bsz, frm, _ = unit_series.size()

        # DataDomain/JudgeId Embedding's Batch/Time expansion ::
        # (B=1, Feat) -> (B=bsz, Frame=frm, Feat)
        domain_series = self.domain_emb.unsqueeze(1).expand(bsz, frm, -1)
        judge_series = self.judge_emb.unsqueeze(1).expand(bsz, frm, -1)

        # Feature concatenation :: (B, Frame, Feat=f1) + (B, Frame, Feat=f2) +
        # (B, Frame, Feat=f3) -> (B, Frame, Feat=f1+f2+f3)
        cat_series = torch.cat([unit_series, domain_series, judge_series], dim=2)

        # Frame-scale score estimation :: (B, Frame, Feat) -> (B, Frame, Feat)
        # -> (B, Frame, Feat=1) - BLSTM/Projection
        feat_series = self.blstm(cat_series)[0]
        score_series = self.projection(feat_series)

        # Utterance-scale score :: (B, Frame, Feat=1) -> (B, Feat=1)
        # -> (B,) - Time averaging
        # `* 2 + 3` rescales the raw output; presumably it undoes a
        # training-time normalization onto the 1-5 MOS scale — confirm
        # against the training setup.
        utter_score = score_series.mean(dim=1).squeeze(1) * 2 + 3

        return utter_score
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class Wav2Vec2Model(nn.Module):
    """Wav2Vev2."""

    def __init__(self):
        super().__init__()  # pyright: ignore [reportUnknownMemberType]

        # Conv-stack channel width / transformer embedding width.
        feat_h1, feat_h2 = 512, 768
        # Per-layer (channels, kernel, stride); the strides multiply to 320,
        # i.e. one output frame per 320 input samples.
        feature_enc_layers = (
            [(feat_h1, 10, 5)] + [(feat_h1, 3, 2)] * 4 + [(feat_h1, 2, 2)] * 2
        )

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers
        )  # pyright: ignore [reportGeneralTypeIssues]
        self.layer_norm = nn.LayerNorm(feat_h1)
        self.post_extract_proj = nn.Linear(feat_h1, feat_h2)
        # Not applied in forward() below (inference path).
        self.dropout_input = nn.Dropout(0.1)
        self.encoder = TransformerEncoder(feat_h2)

        # Remnants
        # mask_emb is never used in forward(); presumably kept so pretrained
        # checkpoint keys still match — confirm before removing.
        self.mask_emb = nn.Parameter(torch.FloatTensor(feat_h2))

    def forward(self, source: Tensor):
        """FeatureEncoder + ContextTransformer

        Args:
            source: raw waveform batch, shape (B, T).

        Returns:
            Contextual feature sequence, shape (B, Frame, Feat).
        """

        # Feature encoding
        features = self.feature_extractor(source)
        # (B, Feat, Frame) -> (B, Frame, Feat)
        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        features = self.post_extract_proj(features)

        # Context transformer
        x = self.encoder(features)

        return x
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class ConvFeatureExtractionModel(nn.Module):
    """Feature Encoder.

    Stack of strided 1-D convolutions that maps a raw waveform to a
    frame-level feature sequence.
    """

    def __init__(self, conv_layers: List[Tuple[int, int, int]]):
        """
        Args:
            conv_layers: per-layer (out_channels, kernel_size, stride) specs.
        """
        super().__init__()  # pyright: ignore [reportUnknownMemberType]

        def block(
            n_in: int, n_out: int, k: int, stride: int, is_group_norm: bool = False
        ):
            # Fix: normalize over `n_out` channels directly. The original
            # read the enclosing loop variable `dim` inside this closure
            # (late-binding closure pitfall); it only worked because
            # `dim == n_out` at every call site.
            if is_group_norm:
                return nn.Sequential(
                    nn.Conv1d(n_in, n_out, k, stride=stride, bias=False),
                    nn.Dropout(p=0.0),
                    nn.GroupNorm(n_out, n_out, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(
                    nn.Conv1d(n_in, n_out, k, stride=stride, bias=False),
                    nn.Dropout(p=0.0),
                    nn.GELU(),
                )

        in_d = 1  # raw waveform has a single channel
        self.conv_layers = nn.ModuleList()
        for i, params in enumerate(conv_layers):
            (dim, k, stride) = params
            # GroupNorm only on the first conv layer (i == 0).
            self.conv_layers.append(block(in_d, dim, k, stride, is_group_norm=i == 0))
            in_d = dim

    def forward(self, series: Tensor) -> Tensor:
        """:: (B, T) -> (B, Feat, Frame)"""

        # Add the channel dim expected by Conv1d: (B, T) -> (B, 1, T).
        series = series.unsqueeze(1)
        for conv in self.conv_layers:
            series = conv(series)

        return series
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class TransformerEncoder(nn.Module):
    """Transformer."""

    def build_encoder_layer(self, feat: int):
        """Layer builder."""
        # Post-LN GELU layers with fixed wav2vec2-base hyperparameters.
        return TransformerSentenceEncoderLayer(
            embedding_dim=feat,
            ffn_embedding_dim=3072,
            num_attention_heads=12,
            activation_fn="gelu",
            dropout=0.1,
            attention_dropout=0.1,
            activation_dropout=0.0,
            layer_norm_first=False,
        )

    def __init__(self, feat: int):
        super().__init__()  # pyright: ignore [reportUnknownMemberType]

        # forward() pads the frame axis to a multiple of this value.
        self.required_seq_len_multiple = 2

        # Convolutional positional embedding, weight-normalized along the
        # kernel dim. NOTE(review): nn.utils.weight_norm is deprecated in
        # newer PyTorch in favor of parametrizations — presumably kept for
        # checkpoint-key compatibility; confirm before migrating.
        self.pos_conv = nn.Sequential(
            *[
                nn.utils.weight_norm(
                    nn.Conv1d(feat, feat, kernel_size=128, padding=128 // 2, groups=16),
                    name="weight",
                    dim=2,
                ),
                SamePad(128),
                nn.GELU(),
            ]
        )
        self.layer_norm = nn.LayerNorm(feat)
        self.layers = nn.ModuleList([self.build_encoder_layer(feat) for _ in range(12)])

    def forward(self, x: Tensor) -> Tensor:
        # :: (B, T, Feat) -> (B, T, Feat)

        # Add convolutional positional embedding (computed channels-first).
        x_conv = self.pos_conv(x.transpose(1, 2)).transpose(1, 2)
        x = x + x_conv

        x = self.layer_norm(x)

        # pad to the sequence length dimension
        x, pad_length = pad_to_multiple(
            x, self.required_seq_len_multiple, dim=-2, value=0
        )
        if pad_length > 0:
            # Mark the appended frames as padding (True = masked).
            padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
            padding_mask[:, -pad_length:] = True
        else:
            # pad_to_multiple(None, ...) returns (None, 0): no mask needed.
            padding_mask, _ = pad_to_multiple(
                None, self.required_seq_len_multiple, dim=-1, value=True
            )

        # :: (B, T, Feat) -> (T, B, Feat)
        x = x.transpose(0, 1)
        for layer in self.layers:
            x = layer(x, padding_mask)
        # :: (T, B, Feat) -> (B, T, Feat)
        x = x.transpose(0, 1)

        # undo paddding
        if pad_length > 0:
            x = x[:, :-pad_length]

        return x
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
class SamePad(nn.Module):
    """Tail inverse padding: drops the one extra frame that symmetric
    padding of an even-sized kernel produces."""

    def __init__(self, kernel_size: int):
        super().__init__()  # pyright: ignore [reportUnknownMemberType]
        is_even = kernel_size % 2 == 0
        assert is_even, "`SamePad` now support only even kernel."

    def forward(self, x: Tensor) -> Tensor:
        # (B, Feat, T) -> (B, Feat, T-1): trim the trailing time step.
        return x[:, :, 0:-1]
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def pad_to_multiple(
    x: Optional[Tensor], multiple: int, dim: int = -1, value: float = 0
) -> Tuple[Optional[Tensor], int]:
    """Tail padding.

    Pads `x` along negative axis `dim` with `value` so that its size on
    that axis becomes a multiple of `multiple`.

    Args:
        x: tensor to pad, or None (passed through unchanged).
        multiple: target divisor of the padded size.
        dim: negative dimension index to pad (e.g. -1, -2).
        value: fill value for the appended elements.

    Returns:
        (padded tensor or None, number of elements appended).
    """
    if x is None:
        return None, 0
    tsz = x.size(dim)
    # Fix: exact integer arithmetic instead of math.ceil(tsz / multiple),
    # which can round off for very large sizes due to float division.
    remainder = -tsz % multiple
    if remainder == 0:
        return x, 0
    # F.pad's pad tuple runs from the last dim backwards; (0, 0) pairs skip
    # the dims after `dim`.
    pad_offset = (0,) * (-1 - dim) * 2
    return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class TransformerSentenceEncoderLayer(nn.Module):
    """Transformer Encoder Layer used in BERT/XLM style pre-trained models."""

    def __init__(
        self,
        embedding_dim: int,
        ffn_embedding_dim: int,
        num_attention_heads: int,
        activation_fn: str,
        dropout: float,
        attention_dropout: float,
        activation_dropout: float,
        layer_norm_first: bool,
    ) -> None:
        """
        Args:
            embedding_dim: residual-stream (model) width.
            ffn_embedding_dim: hidden width of the position-wise FFN.
            num_attention_heads: number of self-attention heads.
            activation_fn: must be "gelu" (the only supported value).
            dropout: dropout after attention and after the FFN.
            attention_dropout: dropout inside the attention weights.
            activation_dropout: dropout after the FFN activation.
            layer_norm_first: must be False (post-LN layout only).
        """
        super().__init__()  # pyright: ignore [reportUnknownMemberType]

        assert layer_norm_first is False, "`layer_norm_first` is fixed to `False`"
        assert activation_fn == "gelu", "`activation_fn` is fixed to `gelu`"

        feat = embedding_dim

        self.self_attn = MultiheadAttention(
            feat, num_attention_heads, attention_dropout
        )
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(activation_dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.fc1 = nn.Linear(feat, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, feat)
        self.self_attn_layer_norm = nn.LayerNorm(feat)
        self.final_layer_norm = nn.LayerNorm(feat)

    def forward(self, x: Tensor, self_attn_padding_mask: Optional[Tensor]):
        """Post-LN transformer block :: (T, B, Feat) -> (T, B, Feat)."""
        # Res[Attn-Do]-LN
        residual = x
        x = self.self_attn(x, x, x, self_attn_padding_mask)
        x = self.dropout1(x)
        x = residual + x
        x = self.self_attn_layer_norm(x)

        # Res[SegFC-GELU-Do-SegFC-Do]-LN
        residual = x
        x = F.gelu(self.fc1(x))  # pyright: ignore [reportUnknownMemberType]
        x = self.dropout2(x)
        x = self.fc2(x)
        x = self.dropout3(x)
        x = residual + x
        x = self.final_layer_norm(x)

        return x
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class MultiheadAttention(nn.Module):
    """Multi-headed attention."""

    def __init__(self, embed_dim: int, num_heads: int, dropout: float):
        super().__init__()  # pyright: ignore [reportUnknownMemberType]

        self.embed_dim, self.num_heads, self.p_dropout = embed_dim, num_heads, dropout
        # Separate Q/K/V projections (matches the checkpoint's key layout).
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor],
    ) -> Tensor:
        """
        Args:
            query :: (T, B, Feat)
            key_padding_mask :: (B, src_len) - mask to exclude keys that are pads
            , where padding elements are indicated by 1s.

        Returns:
            Attention output :: (T, B, Feat).
        """
        # Inference-only wrapper: training=False disables attention dropout
        # and need_weights=False means only the output ([0]) is returned.
        return F.multi_head_attention_forward(
            query=query,
            key=key,
            value=value,
            embed_dim_to_check=self.embed_dim,
            num_heads=self.num_heads,
            # use_separate_proj_weight=True below, so the packed in-proj
            # weight is unused; an empty tensor is passed as a placeholder.
            in_proj_weight=torch.empty([0]),
            in_proj_bias=torch.cat(
                (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)
            ),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=self.p_dropout,
            out_proj_weight=self.out_proj.weight,
            out_proj_bias=self.out_proj.bias,
            training=False,
            key_padding_mask=key_padding_mask.bool()
            if key_padding_mask is not None
            else None,
            need_weights=False,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
        )[0]
|
omnivoice/eval/mos/utmos.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""
|
| 19 |
+
Calculate UTMOS score with automatic Mean Opinion Score (MOS) prediction system
|
| 20 |
+
"""
|
| 21 |
+
import argparse
|
| 22 |
+
import logging
|
| 23 |
+
import multiprocessing as mp
|
| 24 |
+
import os
|
| 25 |
+
import sys
|
| 26 |
+
import traceback
|
| 27 |
+
import warnings
|
| 28 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 29 |
+
|
| 30 |
+
import numpy as np
|
| 31 |
+
import torch
|
| 32 |
+
from tqdm import tqdm
|
| 33 |
+
|
| 34 |
+
from omnivoice.eval.models.utmos import UTMOS22Strong
|
| 35 |
+
from omnivoice.eval.utils import load_waveform
|
| 36 |
+
from omnivoice.utils.data_utils import read_test_list
|
| 37 |
+
|
| 38 |
+
warnings.filterwarnings("ignore")
|
| 39 |
+
|
| 40 |
+
# Global variables for workers
|
| 41 |
+
worker_model = None
|
| 42 |
+
worker_device = None
|
| 43 |
+
worker_sr = 16000
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for UTMOS evaluation."""
    parser = argparse.ArgumentParser(
        description="Calculate UTMOS score using UTMOS22Strong model."
    )
    add = parser.add_argument
    add(
        "--wav-path",
        required=True,
        type=str,
        help="Path to the directory containing evaluated speech files.",
    )
    add(
        "--test-list",
        required=True,
        type=str,
        help="Path to the JSONL test list. Each line is a JSON object "
        "with fields: id, text, ref_audio, ref_text, language_id, language_name.",
    )
    add(
        "--model-dir",
        required=True,
        type=str,
        help="Local path of our evaluation model repository."
        "Download from https://huggingface.co/k2-fsa/TTS_eval_models."
        "Will use 'tts_eval_models/mos/utmos22_strong_step7459_v1.pt'"
        " in this script",
    )
    add(
        "--extension",
        type=str,
        default="wav",
        help="Extension of the speech files. Default: wav",
    )
    add(
        "--decode-path",
        type=str,
        default=None,
        help="Path to the output file where UTMOS information will be saved. "
        "If not provided, results are only printed to console.",
    )
    add(
        "--nj-per-gpu",
        type=int,
        default=1,
        help="Number of worker processes to spawn per GPU.",
    )
    return parser
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def get_device(rank: int = 0) -> torch.device:
    """Select and pin CUDA device `rank` for the current process."""
    cuda_ready = torch.cuda.is_available()
    assert cuda_ready, "CUDA is required but not available."
    selected = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(rank)
    return selected
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def worker_init(
    rank_queue,
    model_path,
):
    """Initialize worker process with model and device.

    Sets the module-level `worker_model`, `worker_device` and `worker_sr`
    globals that `run_utmos_worker` reads.

    Args:
        rank_queue: multiprocessing queue of GPU ranks; each worker pops one
            to choose its CUDA device.
        model_path: path to the UTMOS22Strong checkpoint (state dict).
    """
    global worker_model, worker_device, worker_sr

    # Limit CPU threads per worker
    torch.set_num_threads(2)

    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] [Worker %(process)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    # NOTE(review): if rank_queue is falsy this yields rank -1, and
    # get_device(-1) would build an invalid "cuda:-1" device — confirm
    # workers are always handed a queue.
    rank = rank_queue.get() if rank_queue else -1

    worker_device = get_device(rank)
    worker_sr = 16000

    logging.debug(f"Initializing UTMOS worker on {worker_device}")

    # Initialize Model
    worker_model = UTMOS22Strong()
    try:
        # Load weights to CPU first, then move to device
        state_dict = torch.load(model_path, map_location="cpu")
        worker_model.load_state_dict(state_dict)
    except Exception as e:
        logging.error(f"Failed to load model from {model_path}: {e}")
        raise

    worker_model.to(worker_device)
    worker_model.eval()
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
@torch.no_grad()
def run_utmos_worker(file_idx, wav_path, language_name):
    """Score one audio file with the worker's UTMOS model.

    Returns a (file_idx, wav_path, language_name, payload, status) tuple:
    payload is the float score when status is "success", otherwise an
    error description string.
    """
    try:
        if not os.path.exists(wav_path):
            missing = f"File not found: {wav_path}"
            return file_idx, wav_path, language_name, missing, "error"

        # Load and resample to the worker's rate; UTMOS takes (Batch, Time).
        waveform = load_waveform(wav_path, worker_sr, device=worker_device)
        score = worker_model(waveform.unsqueeze(0), worker_sr)

        return file_idx, wav_path, language_name, score.item(), "success"

    except Exception as exc:
        detail = (
            f"Error processing {wav_path}: {str(exc)}\n"
            f"Traceback:\n{traceback.format_exc()}"
        )
        return file_idx, wav_path, language_name, detail, "error"
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def main():
    """Entry point: score every test-list file with UTMOS in parallel.

    Spawns `num_gpus * nj_per_gpu` worker processes (each pinned to one GPU
    via a shared rank queue), scores every `<id>.<extension>` file under
    --wav-path, then reports per-language and overall averages, optionally
    writing details to --decode-path.
    """
    parser = get_parser()
    args = parser.parse_args()

    # Main process thread setting
    torch.set_num_threads(2)

    # "spawn" so each worker gets its own clean CUDA context.
    mp.set_start_method("spawn", force=True)

    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    # Validate inputs
    if not os.path.isdir(args.wav_path):
        logging.error(f"Invalid directory: {args.wav_path}")
        sys.exit(1)

    model_path = os.path.join(args.model_dir, "mos/utmos22_strong_step7459_v1.pt")
    if not os.path.exists(model_path):
        logging.error(f"Model file not found at {model_path}")
        sys.exit(1)

    # Scan directory for files
    logging.info(f"Calculating UTMOS for {args.wav_path}")

    wav_files = []
    try:
        samples = read_test_list(args.test_list)
        for s in samples:
            language_name = s.get("language_name") or "unknown"
            eval_wav_path = os.path.join(args.wav_path, f"{s['id']}.{args.extension}")
            wav_files.append((eval_wav_path, language_name))
    except Exception as e:
        # Chain the original exception for easier debugging.
        raise ValueError(f"Error reading test list {args.test_list}: {e}") from e

    # Setup Parallel Processing
    num_gpus = torch.cuda.device_count()
    assert num_gpus > 0, "No GPU found. GPU is required."
    total_procs = num_gpus * args.nj_per_gpu

    logging.info(
        f"Starting evaluation with {total_procs} processes on {num_gpus} GPUs."
    )

    manager = mp.Manager()
    rank_queue = manager.Queue()

    # Each worker pops one rank, so GPUs are shared evenly (nj_per_gpu each).
    for rank in list(range(num_gpus)) * args.nj_per_gpu:
        rank_queue.put(rank)

    scores = []

    fout = None
    if args.decode_path:
        # Fix: only create the parent directory when there is one;
        # os.makedirs("") raises FileNotFoundError for bare filenames.
        decode_dir = os.path.dirname(args.decode_path)
        if decode_dir:
            os.makedirs(decode_dir, exist_ok=True)
        fout = open(args.decode_path, "w", encoding="utf8")
        logging.info(f"Saving detailed UTMOS results to: {args.decode_path}")
        fout.write("Name\tUTMOS\n")

    # Per-language score lists; initialized before the try block so the
    # reporting below can never hit a NameError.
    lang_stats = {}
    try:
        with ProcessPoolExecutor(
            max_workers=total_procs,
            initializer=worker_init,
            initargs=(
                rank_queue,
                model_path,
            ),
        ) as executor:
            futures = []
            for i, (wav_path, language_name) in enumerate(wav_files):
                futures.append(
                    executor.submit(run_utmos_worker, i, wav_path, language_name)
                )

            pbar = tqdm(
                as_completed(futures), total=len(wav_files), desc="Evaluating UTMOS"
            )
            for future in pbar:
                idx, path, language_name, result, status = future.result()
                if status == "success":
                    if language_name not in lang_stats:
                        lang_stats[language_name] = []
                    lang_stats[language_name].append(result)
                    scores.append(result)
                    if fout:
                        if language_name == "unknown":
                            fout.write(f"{os.path.basename(path)}\t{result:.2f}\n")
                        else:
                            fout.write(
                                f"{language_name}\t{os.path.basename(path)}\t{result:.2f}\n"
                            )
                else:
                    pbar.write(f"!!! FAILED [File {idx}]: {path} | {result}")

    except (Exception, KeyboardInterrupt) as e:
        logging.critical(
            f"An unrecoverable error occurred: {e}. Terminating all processes."
        )
        detailed_error_info = traceback.format_exc()
        logging.error(f"--- DETAILED TRACEBACK ---\n{detailed_error_info}")
        sys.exit(1)

    print("-" * 50)

    # Per-language breakdown (only shown for multilingual test lists).
    if len(lang_stats) > 1:
        lang_scores = []
        for lang in sorted(lang_stats.keys()):
            l_scores = lang_stats[lang]
            l_avg = np.mean(l_scores)
            lang_scores.append(l_scores)
            l_count = len(l_scores)
            logging.info(f"[{lang}] UTMOS score: {l_avg:.3f} ({l_count} samples)")
            if fout:
                fout.write(f"[{lang}] UTMOS: {l_avg:.3f} ({l_count} samples)\n")
        logging.info(
            f"Macro-average UTMOS over {len(lang_stats)} languages: "
            f"{np.mean([np.mean(ls) for ls in lang_scores]):.3f}"
        )
        if fout:
            fout.write(
                f"\nMacro-average UTMOS over {len(lang_stats)} languages: "
                f"{np.mean([np.mean(ls) for ls in lang_scores]):.3f}\n"
            )

    if scores:
        avg_score = np.mean(scores)
        logging.info(f"Processed {len(scores)}/{len(wav_files)} files.")
        logging.info(f"UTMOS score: {avg_score:.2f}")
        if fout:
            fout.write(f"\nAverage UTMOS: {avg_score:.2f}\n")
    else:
        logging.error("No valid scores computed.")
    print("-" * 50)

    if fout:
        fout.close()
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
# Script entry point: parse CLI arguments and launch the evaluation.
if __name__ == "__main__":
    main()
|
omnivoice/eval/speaker_similarity/sim.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""
|
| 19 |
+
Computes speaker similarity (SIM-o) using a WavLM-based
|
| 20 |
+
ECAPA-TDNN speaker verification model.
|
| 21 |
+
"""
|
| 22 |
+
import argparse
|
| 23 |
+
import logging
|
| 24 |
+
import multiprocessing as mp
|
| 25 |
+
import os
|
| 26 |
+
import sys
|
| 27 |
+
import traceback
|
| 28 |
+
import warnings
|
| 29 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 30 |
+
|
| 31 |
+
import numpy as np
|
| 32 |
+
import torch
|
| 33 |
+
from tqdm import tqdm
|
| 34 |
+
|
| 35 |
+
from omnivoice.eval.models.ecapa_tdnn_wavlm import ECAPA_TDNN_WAVLM
|
| 36 |
+
from omnivoice.eval.utils import load_waveform
|
| 37 |
+
from omnivoice.utils.data_utils import read_test_list
|
| 38 |
+
|
| 39 |
+
warnings.filterwarnings("ignore")
|
| 40 |
+
|
| 41 |
+
# Global variables for workers
|
| 42 |
+
worker_model = None
|
| 43 |
+
worker_device = None
|
| 44 |
+
worker_sr = 16000
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for SIM-o evaluation."""
    parser = argparse.ArgumentParser(
        description="Calculate speaker similarity (SIM-o) score."
    )
    add = parser.add_argument
    add(
        "--wav-path",
        required=True,
        type=str,
        help="Path to the directory containing evaluated speech files.",
    )
    add(
        "--test-list",
        required=True,
        type=str,
        help="Path to the JSONL test list. Each line is a JSON object "
        "with fields: id, text, ref_audio, ref_text, language_id, language_name.",
    )
    add(
        "--model-dir",
        required=True,
        type=str,
        help="Local path of our evaluation model repository."
        "Download from https://huggingface.co/k2-fsa/TTS_eval_models."
        "Will use 'tts_eval_models/speaker_similarity/wavlm_large_finetune.pth'"
        "and 'tts_eval_models/speaker_similarity/wavlm_large/' in this script",
    )
    add(
        "--extension",
        type=str,
        default="wav",
        help="Extension of the speech files.",
    )
    add(
        "--decode-path",
        type=str,
        default=None,
        help="Path to the output file where SIM-o information will be saved. "
        "If not provided, results are only printed to console.",
    )
    add(
        "--nj-per-gpu",
        type=int,
        default=1,
        help="Number of worker processes to spawn per GPU.",
    )
    return parser
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def get_device(rank: int = 0) -> torch.device:
    """Bind the current process to CUDA device `rank` and return it."""
    has_cuda = torch.cuda.is_available()
    assert has_cuda, "CUDA is required but not available."
    chosen = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(rank)
    return chosen
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def worker_init(
    rank_queue,
    sv_model_path,
    ssl_model_path,
):
    """Initialize worker process with model and device.

    Runs once per ProcessPoolExecutor worker: pops a GPU rank off
    ``rank_queue``, pins the process to that device, and loads the
    ECAPA-TDNN/WavLM speaker-verification model into module-level
    globals consumed by ``get_embedding``.

    Args:
        rank_queue: Multiprocessing queue of GPU ranks; each worker
            consumes one entry. NOTE(review): if the queue is falsy,
            rank -1 is used, which would make ``cuda:-1`` — confirm
            that path is never taken in practice.
        sv_model_path: Path to the fine-tuned speaker-verification
            checkpoint (``wavlm_large_finetune.pth``).
        ssl_model_path: Path to the WavLM SSL model directory.
    """
    global worker_model, worker_device, worker_sr

    # Limit intra-op threads so many workers don't oversubscribe CPUs.
    torch.set_num_threads(2)

    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] [Worker %(process)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    rank = rank_queue.get() if rank_queue else -1

    worker_device = get_device(rank)
    worker_sr = 16000  # target sample rate fed to the model

    logging.debug(f"Initializing SIM-o worker on {worker_device}")
    # Temporarily suppress INFO logs to hide verbose WavLM config
    logging.disable(logging.INFO)

    # Initialize Model
    try:
        worker_model = ECAPA_TDNN_WAVLM(
            feat_dim=1024,
            channels=512,
            emb_dim=256,
            sr=worker_sr,
            ssl_model_path=ssl_model_path,
        )
        # map_location keeps tensors on CPU until the explicit .to() below.
        state_dict = torch.load(
            sv_model_path, map_location=lambda storage, loc: storage
        )
        # strict=False: checkpoint may carry keys not present in the module.
        worker_model.load_state_dict(state_dict["model"], strict=False)
        worker_model.to(worker_device)
        worker_model.eval()
    finally:
        # Restore normal logging
        logging.disable(logging.NOTSET)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
@torch.no_grad()
def get_embedding(wav_path: str) -> torch.Tensor:
    """Extract embedding for a single file.

    Loads the waveform at the worker sample rate (truncated at 120 s to
    avoid OOM) and runs it through the module-global speaker model.
    Requires ``worker_init`` to have run in this process first.

    Args:
        wav_path: Path to the audio file.

    Returns:
        torch.Tensor: Speaker embedding produced by the model.
    """
    speech = load_waveform(wav_path, worker_sr, device=worker_device, max_seconds=120)
    return worker_model([speech])
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def run_similarity_worker(line_idx, sample, wav_dir, extension):
    """Compute the SIM-o cosine similarity for one reference/eval pair.

    Args:
        line_idx: Index of the sample in the test list (for reporting).
        sample: Dict with at least ``id`` and ``ref_audio``, optionally
            ``language_name``.
        wav_dir: Directory holding the synthesized speech files.
        extension: File extension of the synthesized files.

    Returns:
        Tuple ``(line_idx, context, result, status)`` where status is
        "success" (context is a (ref, eval, lang) triple, result the
        similarity) or "error" (context/result describe the failure).
    """
    try:
        ref_path = sample["ref_audio"]
        lang = sample.get("language_name") or "unknown"
        eval_path = os.path.join(wav_dir, f"{sample['id']}.{extension}")

        # Guard clauses: report missing files instead of raising.
        if not os.path.exists(ref_path):
            return line_idx, f"Reference not found: {ref_path}", None, "error"
        if not os.path.exists(eval_path):
            return line_idx, f"Eval wav not found: {eval_path}", None, "error"

        # Embed both utterances, then score them with cosine similarity.
        score = torch.nn.functional.cosine_similarity(
            get_embedding(ref_path), get_embedding(eval_path), dim=-1
        )

        return line_idx, (ref_path, eval_path, lang), score.item(), "success"

    except Exception as exc:
        detail = f"Error: {str(exc)}\nTraceback:\n{traceback.format_exc()}"
        return line_idx, str(sample), detail, "error"
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def main():
    """Entry point: compute SIM-o speaker similarity over a test list.

    Spawns one worker per (GPU x nj-per-gpu) slot, each holding its own
    copy of the speaker-verification model, scores every sample pair,
    and reports overall plus per-language averages.
    """
    parser = get_parser()
    args = parser.parse_args()

    # Main process thread setting
    torch.set_num_threads(2)

    # CUDA in child processes requires the "spawn" start method.
    mp.set_start_method("spawn", force=True)

    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    # Prepare paths
    sv_model_path = os.path.join(
        args.model_dir, "speaker_similarity/wavlm_large_finetune.pth"
    )
    ssl_model_path = os.path.join(args.model_dir, "speaker_similarity/wavlm_large/")

    if not os.path.exists(sv_model_path) or not os.path.exists(ssl_model_path):
        logging.error("Model files not found. Please check --model-dir.")
        sys.exit(1)

    logging.info(f"Calculating SIM-o for {args.wav_path}")
    # Read list
    samples = read_test_list(args.test_list)

    # Setup Parallel Processing
    num_gpus = torch.cuda.device_count()
    assert num_gpus > 0, "No GPU found. GPU is required."
    total_procs = num_gpus * args.nj_per_gpu

    logging.info(
        f"Starting evaluation with {total_procs} processes " f"on {num_gpus} GPUs."
    )

    manager = mp.Manager()
    rank_queue = manager.Queue()

    # Each worker pops exactly one rank; each GPU rank appears nj_per_gpu times.
    for rank in list(range(num_gpus)) * args.nj_per_gpu:
        rank_queue.put(rank)

    scores = []

    fout = None
    if args.decode_path:
        os.makedirs(os.path.dirname(args.decode_path), exist_ok=True)
        fout = open(args.decode_path, "w", encoding="utf8")
        logging.info(f"Saving detailed SIM-o results to: {args.decode_path}")
        fout.write("Prompt-path\tEval-path\tSIM-o\n")

    try:
        with ProcessPoolExecutor(
            max_workers=total_procs,
            initializer=worker_init,
            initargs=(
                rank_queue,
                sv_model_path,
                ssl_model_path,
            ),
        ) as executor:
            futures = []
            for i, sample in enumerate(samples):
                futures.append(
                    executor.submit(
                        run_similarity_worker, i, sample, args.wav_path, args.extension
                    )
                )

            pbar = tqdm(
                as_completed(futures), total=len(samples), desc="Evaluating SIM-o"
            )

            # language_name -> list of similarity scores
            lang_stats = {}

            for future in pbar:
                idx, context, result, status = future.result()
                if status == "success":
                    prompt_path, eval_path, lang = context
                    scores.append(result)

                    # Accumulate per-language
                    if lang not in lang_stats:
                        lang_stats[lang] = []
                    lang_stats[lang].append(result)

                    if fout:
                        # NOTE(review): the "unknown" branch writes 3 columns,
                        # known languages write 4 — confirm downstream parsers
                        # of the decode file handle both layouts.
                        if lang == "unknown":
                            fout.write(f"{prompt_path}\t{eval_path}\t{result:.2f}\n")
                        else:
                            fout.write(
                                f"{lang}\t{context[0]}\t{context[1]}\t{result:.2f}\n"
                            )
                else:
                    pbar.write(f"!!! FAILED [Line {idx}]: {context} | Error: {result}")

    except (Exception, KeyboardInterrupt) as e:
        logging.critical(
            f"An unrecoverable error occurred: {e}. " f"Terminating all processes."
        )
        detailed_error_info = traceback.format_exc()
        logging.error(f"--- DETAILED TRACEBACK ---\n{detailed_error_info}")
        sys.exit(1)

    print("-" * 50)
    # Per-language breakdown only when more than one language is present.
    if len(lang_stats) > 1:
        lang_scores = []
        for lang in sorted(lang_stats.keys()):
            l_scores = lang_stats[lang]
            l_avg = np.mean(l_scores)
            lang_scores.append(l_scores)
            l_count = len(l_scores)
            logging.info(f"[{lang}] SIM-o score: {l_avg:.3f} ({l_count} pairs)")
            if fout:
                fout.write(f"[{lang}] SIM-o: {l_avg:.3f} ({l_count} pairs)\n")
        logging.info(
            f"Macro-average SIM-o over {len(lang_stats)} languages: "
            f"{np.mean([np.mean(ls) for ls in lang_scores]):.3f}"
        )
        if fout:
            fout.write(
                f"\nMacro-average SIM-o over {len(lang_stats)} languages: "
                f"{np.mean([np.mean(ls) for ls in lang_scores]):.3f}\n"
            )

    if scores:
        avg_score = np.mean(scores)
        logging.info(f"Processed {len(scores)}/{len(samples)} pairs.")
        logging.info(f"SIM-o score: {avg_score:.3f}")
        if fout:
            fout.write(f"\nAverage SIM-o: {avg_score:.3f}\n")
    else:
        logging.error("No valid scores computed.")
    if fout:
        fout.close()
    print("-" * 50)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
omnivoice/eval/utils.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
import logging
|
| 19 |
+
from typing import Optional
|
| 20 |
+
|
| 21 |
+
import librosa
|
| 22 |
+
import soundfile as sf
|
| 23 |
+
import torch
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def load_waveform(
    fname: str,
    sample_rate: int,
    dtype: str = "float32",
    device: torch.device = torch.device("cpu"),
    return_numpy: bool = False,
    max_seconds: Optional[float] = None,
) -> torch.Tensor:
    """Read an audio file and return a mono waveform at ``sample_rate``.

    Args:
        fname: Path to the audio file.
        sample_rate: Target sample rate; audio is resampled when needed.
        dtype: Sample dtype passed to soundfile (default "float32").
        device: Device for the returned tensor (default CPU).
        return_numpy: If True, return a NumPy array instead of a tensor.
        max_seconds: If given, truncate audio longer than this many seconds.

    Returns:
        1-D waveform of shape (num_samples,) — a tensor on ``device``,
        or a NumPy array when ``return_numpy`` is True.

    Notes:
        - Stereo input is down-mixed to mono by averaging channels.
        - Resampling only happens when the file's rate differs from the target.
    """
    audio, native_sr = sf.read(fname, dtype=dtype)

    # Down-mix multi-channel audio to a single channel.
    if len(audio.shape) == 2:
        audio = audio.mean(1)

    # Bring the signal to the requested sample rate.
    if native_sr != sample_rate:
        audio = librosa.resample(audio, orig_sr=native_sr, target_sr=sample_rate)

    # Optionally cap the duration to bound downstream memory usage.
    if max_seconds is not None:
        limit = int(sample_rate * max_seconds)
        if len(audio) > limit:
            audio = audio[:limit]
            logging.warning(
                f"Wav file {fname} is longer than {max_seconds}s, "
                f"truncated to {max_seconds}s to avoid OOM."
            )
    if return_numpy:
        return audio
    else:
        return torch.from_numpy(audio).to(device)
|
omnivoice/eval/wer/common.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""
|
| 19 |
+
Shared utilities for WER evaluation scripts.
|
| 20 |
+
"""
|
| 21 |
+
import logging
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
from jiwer import compute_measures
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def process_one(hypothesis: str, truth: str, post_process, lang: str = None) -> dict:
    """Score one hypothesis/reference pair with jiwer after normalization.

    Args:
        hypothesis (str): The transcribed text from the ASR model.
        truth (str): The ground truth transcript.
        post_process (callable): Normalization function; signature is
            either ``post_process(text, lang)`` or ``post_process(text)``.
        lang (str): Language code forwarded to ``post_process``; pass
            None when the function takes no language argument.

    Returns:
        dict: ``truth``/``hypo`` (normalized texts), ``wer``,
        ``substitutions``, ``deletions``, ``insertions``, and
        ``word_num`` (word count of the normalized reference).
    """
    # Normalize both sides identically before scoring.
    if lang is None:
        norm_truth = post_process(truth)
        norm_hypo = post_process(hypothesis)
    else:
        norm_truth = post_process(truth, lang)
        norm_hypo = post_process(hypothesis, lang)

    stats = compute_measures(norm_truth, norm_hypo)
    return {
        "truth": norm_truth,
        "hypo": norm_hypo,
        "wer": stats["wer"],
        "substitutions": stats["substitutions"],
        "deletions": stats["deletions"],
        "insertions": stats["insertions"],
        "word_num": len(norm_truth.split(" ")),
    }
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def log_metrics(fout, prefix, i_list, d_list, s_list, w_total, ndigits=2):
    """Log weighted WER metrics for a subset of results.

    Args:
        fout: Optional open file handle; metrics are also written there.
        prefix: Label prepended to each reported line (e.g. a language).
        i_list: Per-utterance insertion counts.
        d_list: Per-utterance deletion counts.
        s_list: Per-utterance substitution counts.
        w_total: Total reference word count for the subset.
        ndigits: Decimal places for the reported WER percentage.

    Returns:
        The rounded WER percentage.
    """
    n_ins = np.sum(i_list)
    n_del = np.sum(d_list)
    n_sub = np.sum(s_list)
    # WER = (S + D + I) / total reference words, as a percentage.
    metrics_wer = round((n_sub + n_del + n_ins) / w_total * 100, ndigits)

    summary = f"{prefix} WER: {metrics_wer}%"
    detail = (
        f"{prefix} Errors: {n_ins} ins, {n_del} del, "
        f"{n_sub} sub / {w_total} words"
    )
    logging.info(summary)
    logging.info(detail)
    if fout:
        fout.write(summary + "\n")
        fout.write(detail + "\n")
    return metrics_wer
|
omnivoice/eval/wer/fleurs.py
ADDED
|
@@ -0,0 +1,517 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""Computes word error rate (WER) for FLEURS multilingual evaluation.
|
| 19 |
+
|
| 20 |
+
Uses omnilingual-asr for ASR transcription across 100+ languages.
|
| 21 |
+
Requires a separate environment with ``omnilingual_asr`` installed.
|
| 22 |
+
|
| 23 |
+
Usage:
|
| 24 |
+
python3 omnivoice/eval/wer/fleurs.py \\
|
| 25 |
+
--wav-path results/fleurs \\
|
| 26 |
+
--test-list test.jsonl \\
|
| 27 |
+
--decode-path results/fleurs.wer.log \\
|
| 28 |
+
--model-card omniASR_LLM_Unlimited_7B_v2 \\
|
| 29 |
+
--chunk-size 100 --batch-size 50
|
| 30 |
+
"""
|
| 31 |
+
import argparse
|
| 32 |
+
import logging
|
| 33 |
+
import multiprocessing as mp
|
| 34 |
+
import os
|
| 35 |
+
import re
|
| 36 |
+
import sys
|
| 37 |
+
import traceback
|
| 38 |
+
import types
|
| 39 |
+
from collections import defaultdict
|
| 40 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 41 |
+
from pathlib import Path
|
| 42 |
+
from typing import List, Union
|
| 43 |
+
|
| 44 |
+
import numpy as np
|
| 45 |
+
import torch
|
| 46 |
+
from tqdm import tqdm
|
| 47 |
+
|
| 48 |
+
try:
|
| 49 |
+
from omnilingual_asr.models.inference.pipeline import ASRInferencePipeline
|
| 50 |
+
from omnilingual_asr.models.wav2vec2_llama.lang_ids import supported_langs
|
| 51 |
+
except ImportError:
|
| 52 |
+
logging.error("Please install omnilingual_asr first.")
|
| 53 |
+
exit(1)
|
| 54 |
+
|
| 55 |
+
# omnilingual-asr may pull a transformers version that lacks
# HiggsAudioV2TokenizerModel. Pre-register stubs to bypass
# omnivoice/__init__.py heavy imports.
if "omnivoice" not in sys.modules:
    # Path of the `omnivoice` package directory, two levels up from this file.
    _root = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
    for _name in (
        "omnivoice",
        "omnivoice.eval",
        "omnivoice.eval.wer",
        "omnivoice.utils",
    ):
        if _name not in sys.modules:
            # Register an empty package whose __path__ points at the real
            # directory, so the submodule imports below still resolve on
            # disk without executing the package __init__ side effects.
            _m = types.ModuleType(_name)
            _m.__path__ = [os.path.join(_root, *_name.split(".")[1:])]
            _m.__package__ = _name
            sys.modules[_name] = _m
|
| 71 |
+
|
| 72 |
+
from omnivoice.eval.wer.common import log_metrics, process_one
|
| 73 |
+
from omnivoice.eval.wer.text_norm_omni import text_normalize
|
| 74 |
+
from omnivoice.utils.data_utils import read_test_list
|
| 75 |
+
|
| 76 |
+
# --- Global variables for worker processes ---
|
| 77 |
+
worker_pipe = None
|
| 78 |
+
worker_device = None
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# fix mismatched language codes between OmniVoice and Omnilingual-ASR model
# Maps an OmniVoice language id to the substitute id used when looking up
# the ASR language token below.
rename = {
    "et": "ekk",
    "ms": "zsm",
    "sw": "swh",
    "npi": "nep",
}
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def read_language_mapping_from_tsv(
    mapping_path: Path,
) -> dict[str, Union[str, List[str]]]:
    """Parse the language-id mapping TSV into {iso_639_3_id: mixed_id}.

    Expects a header row followed by rows of
    ``mixed_id<TAB>language_name<TAB>iso_639_3_id<TAB>duration``.

    Args:
        mapping_path: Path to the TSV mapping file.

    Returns:
        Dict mapping each ISO 639-3 id to its OmniVoice "mixed" id.
    """
    mapping: dict[str, Union[str, List[str]]] = {}
    with open(mapping_path, "r", encoding="utf-8") as handle:
        handle.readline()  # discard the header row
        for row in handle:
            mixed_id, _name, iso_code, _duration = row.strip().split("\t")
            mapping[iso_code] = mixed_id
    return mapping
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# ISO-639-3 -> OmniVoice mixed id, read from the repo's language map TSV.
iso_639_3_id_to_mixed_id = read_language_mapping_from_tsv(
    Path(f"{os.path.dirname(__file__)}/../../../docs/lang_id_name_map.tsv")
)

# OmniVoice mixed id -> full Omnilingual-ASR language token (e.g. "eng_Latn").
mixed_id_to_omnilingual_asr_lang = {}

for lang in supported_langs:
    # "cmn_Hant" is excluded explicitly — presumably so it does not
    # overwrite the other "cmn" entry's mapping; TODO confirm.
    if lang in ("cmn_Hant",):
        continue
    iso_639_3_lang_code = lang.split("_")[0]
    if iso_639_3_lang_code in iso_639_3_id_to_mixed_id:
        mixed_id = iso_639_3_id_to_mixed_id[iso_639_3_lang_code]
        mixed_id_to_omnilingual_asr_lang[mixed_id] = lang
    else:
        # Fall back to keying by the bare ISO code when the TSV has no
        # mixed id for this language.
        mixed_id_to_omnilingual_asr_lang[iso_639_3_lang_code] = lang
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def clean_cjk_spaces(text):
    """Strip spaces adjacent to Chinese/Japanese characters.

    Spaces touching CJK characters are removed, while meaningful spaces
    elsewhere (English words, Korean, ...) are collapsed to one space.
    """
    # Character classes covered:
    #   \u4e00-\u9fff  CJK Unified Ideographs (Chinese)
    #   \u3040-\u309f  Hiragana (Japanese)
    #   \u30a0-\u30ff  Katakana (Japanese)
    #   \u3000-\u303f  CJK Symbols and Punctuation
    cjk = "\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\u3000-\u303f"

    # Applied in order: CJK-CJK gaps first, then spaces after/before a
    # CJK char, finally a whitespace collapse for everything remaining.
    substitutions = (
        (f"([{cjk}])\\s+([{cjk}])", r"\1\2"),  # "我 爱 你" -> "我爱你"
        (f"([{cjk}])\\s+", r"\1"),             # "我 you"   -> "我you"
        (f"\\s+([{cjk}])", r"\1"),             # "you 我"   -> "you我"
        (r"\s+", " "),                         # collapse remaining runs
    )
    for pattern, repl in substitutions:
        text = re.sub(pattern, repl, text)

    return text.strip()
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def get_parser():
    """Build the command-line parser for FLEURS WER evaluation.

    Returns:
        argparse.ArgumentParser: Parser whose --help shows all defaults.
    """
    parser = argparse.ArgumentParser(
        # Fixed: this script decodes with OmniASR (see --model-card),
        # not Whisper as the previous description claimed.
        description="Computes WER with OmniASR.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--wav-path",
        type=str,
        required=True,
        help="Path to the directory containing speech files.",
    )

    parser.add_argument(
        "--extension",
        type=str,
        default="wav",
        help="Extension of the speech files. Default: wav",
    )

    parser.add_argument(
        "--decode-path",
        type=str,
        default=None,
        help="Path to the output file where WER information will be saved. "
        "If not provided, results are only printed to console.",
    )
    parser.add_argument(
        "--model-card",
        type=str,
        default="omniASR_LLM_7B",
        help="Model card name for OmniASR (e.g., omniASR_LLM_7B) or local path.",
    )
    parser.add_argument(
        "--test-list",
        type=str,
        default="test.jsonl",
        help="path of the JSONL test list. Each line is a JSON object "
        "with fields: id, text, ref_audio, ref_text, language_id, language_name.",
    )
    parser.add_argument(
        "--lang",
        type=str,
        default=None,
        help="""Language code to evaluate (e.g., 'en' for English, 'zh' for Chinese).
If not provided, the script will evaluate all languages found in the test list.
If specified, only samples of the given language will be evaluated.
""",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="Batch size for decoding with the Hugging Face pipeline.",
    )
    parser.add_argument(
        "--nj-per-gpu", type=int, default=1, help="Number of workers per GPU."
    )
    parser.add_argument(
        "--chunk-size",
        type=int,
        default=300,
        help="Number of samples per task chunk sent to workers.",
    )
    return parser
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def load_omni_model(model_card, device):
    """Create an OmniASR inference pipeline on the given device.

    Args:
        model_card: OmniASR model card name or local path.
        device: Target device (converted to string for the pipeline).

    Returns:
        The pipeline instance, or None when construction fails — the
        caller treats None as a fatal initialization error.
    """
    logging.info(f"Loading OmniASR model ({model_card}) on {device}...")
    try:
        return ASRInferencePipeline(model_card=model_card, device=str(device))
    except Exception as e:
        logging.error(f"Failed to load OmniASR pipeline: {e}")
        return None
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def process_init(rank_queue, model_card):
    """
    Initializer for each worker process.

    Pops a GPU rank from ``rank_queue``, pins this process to that CUDA
    device, and loads the OmniASR pipeline into the module-global
    ``worker_pipe`` consumed by ``run_eval_worker``.

    Args:
        rank_queue: Multiprocessing queue of GPU ranks (one per worker).
        model_card: OmniASR model card name or local path.

    Raises:
        RuntimeError: If no rank is available within 10 s or the model
            fails to load.
    """
    global worker_pipe, worker_device

    # Configure threads constraint
    torch.set_num_threads(2)

    try:
        rank = rank_queue.get(timeout=10)
    except Exception:
        raise RuntimeError("Failed to get GPU rank from queue.")

    assert torch.cuda.is_available(), "CUDA is required but not available."
    worker_device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(rank)

    logging.info(f"Initializing worker on device: {worker_device}")

    try:
        # Using the model_card argument
        worker_pipe = load_omni_model(model_card, worker_device)
        if worker_pipe is None:
            raise RuntimeError("Model loading failed.")
    except Exception as e:
        # Propagate so the pool surfaces the broken worker instead of
        # silently running without a model.
        logging.critical(f"Failed to load model on {worker_device}: {e}")
        raise e
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def post_process(text: str, lang: str) -> str:
    """Normalize text into space-separated characters for WER scoring.

    The text is run through the Omni text normalizer, CJK spacing is
    cleaned, word boundaries are replaced with '|', and every character
    becomes its own whitespace-separated token.

    Args:
        text (str): The input text to be processed.
        lang (str): Language token; the first three characters are taken
            as the ISO 639-3 code (e.g. 'eng' from 'eng_Latn').

    Returns:
        str: The cleaned and normalized text, one character per token.
    """
    iso = lang[:3]  # ISO 639-3 prefix of tokens like 'eng_Latn'
    normalized = text_normalize(
        text,
        iso_code=iso,
        lower_case=True,
        remove_numbers=False,
        remove_brackets=False,
    )
    normalized = clean_cjk_spaces(normalized)
    # '|' preserves original word boundaries once characters are split.
    return " ".join(normalized.replace(" ", "|"))
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def run_eval_worker(data_chunk, language, batch_size):
    """
    Worker function to process a chunk of data.
    Uses the global worker_pipe initialized by process_init.

    Args:
        data_chunk: List of dicts with keys ``wav_path``, ``truth_text``
            and optionally ``lang_id`` / ``lang_name``.
        language: Fallback language code when an item lacks ``lang_id``.
        batch_size: Batch size forwarded to the ASR pipeline.

    Returns:
        List of per-utterance metric dicts (see ``process_one``), each
        augmented with ``wav_path`` and ``lang_name``; empty on failure.
    """
    global worker_pipe
    if worker_pipe is None:
        logging.error("Worker pipeline is not initialized!")
        return []

    metrics_buffer = []
    try:
        # Prepare batch lists for OmniASR
        audio_paths = [item["wav_path"] for item in data_chunk]

        # OmniASR expects explicit language codes for each file if not
        # auto-detected. Uses the language passed to the worker, or the
        # item-specific language. Assumes item['lang_id'] is compatible
        # (e.g., 'en', 'zh', 'arb_Arab'); if the model needs full tokens
        # like 'en_Latn', conversion may be needed depending on input data.
        lang_list = [item.get("lang_id", language) for item in data_chunk]

        # Use the pipeline to infer batch
        # OmniASR pipeline.transcribe returns a list of strings
        transcriptions = worker_pipe.transcribe(
            audio_paths, lang=lang_list, batch_size=batch_size
        )

        for i, hypo_text in enumerate(transcriptions):
            ref_item = data_chunk[i]
            truth = ref_item["truth_text"]
            wav_path = ref_item["wav_path"]
            lang_id = ref_item.get("lang_id")
            lang_name = ref_item.get("lang_name")

            # Score this hypothesis against its reference transcript.
            m = process_one(hypo_text, truth, post_process, lang_id)
            m["wav_path"] = wav_path
            m["lang_name"] = lang_name
            metrics_buffer.append(m)

    except Exception:
        # A failed chunk is dropped entirely; the caller sees fewer results.
        logging.error(
            f"Worker failed on chunk (Lang: {language}):\n{traceback.format_exc()}"
        )
        return []

    return metrics_buffer
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def main():
    """Entry point: multilingual WER evaluation with the OmniASR pipeline.

    Reads the JSONL test list, groups samples by language, fans the work out
    to one worker process per GPU slot (each worker holds its own copy of the
    ASR pipeline via ``process_init``), then aggregates and reports
    per-language, macro-average, and overall WER.
    """
    parser = get_parser()
    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
        force=True,
    )

    # 1. Prepare Data
    logging.info("Reading test list...")
    data_by_lang = defaultdict(list)
    total_files = 0
    wav_root = Path(args.wav_path)

    samples = read_test_list(args.test_list)
    for s in samples:
        wav_path = str(wav_root / f"{s['id']}.{args.extension}")
        if not os.path.exists(wav_path):
            logging.warning(f"File missing: {wav_path}")
            continue

        lang_id = s.get("language_id") or "unknown"
        # Map mixed IDs (with optional renaming) to OmniASR language codes.
        if lang_id in rename:
            lang_id = mixed_id_to_omnilingual_asr_lang[rename[lang_id]]
        else:
            lang_id = mixed_id_to_omnilingual_asr_lang[lang_id]
        item = {
            "wav_path": wav_path,
            "truth_text": s["text"],
            "lang_id": lang_id,
            "lang_name": s.get("language_name") or "unknown",
        }
        # Optional single-language filter.
        if args.lang and s.get("language_id") != args.lang:
            continue

        data_by_lang[s.get("language_name") or "unknown"].append(item)

        total_files += 1

    logging.info(f"Total files: {total_files} in {len(data_by_lang)} languages.")

    # 2. Worker config
    num_gpus = torch.cuda.device_count()
    assert num_gpus > 0, "No GPU found. GPU is required."
    total_workers = num_gpus * args.nj_per_gpu

    mp.set_start_method("spawn", force=True)
    manager = mp.Manager()
    rank_queue = manager.Queue()

    # Each worker pops one rank, so nj_per_gpu workers end up on every GPU.
    for _ in range(args.nj_per_gpu):
        for rank in range(num_gpus):
            rank_queue.put(rank)

    # 3. Scheduling: Split languages into chunks
    # This prevents one huge language from blocking a worker for too long,
    # allows better load balancing across the pool.
    tasks = []
    chunk_size = args.chunk_size

    for lang_name, items in data_by_lang.items():
        # Slicing the list into chunks
        for i in range(0, len(items), chunk_size):
            tasks.append({"chunk": items[i : i + chunk_size], "lang": lang_name})

    logging.info(
        f"Split data into {len(tasks)} chunks (size ~{chunk_size}). Spawning {total_workers} workers."
    )

    # 4. Execution
    results = []

    with ProcessPoolExecutor(
        max_workers=total_workers,
        initializer=process_init,
        initargs=(rank_queue, args.model_card),
    ) as executor:

        futures = []
        for task in tasks:
            futures.append(
                executor.submit(
                    run_eval_worker, task["chunk"], task["lang"], args.batch_size
                )
            )

        # Unified progress bar across all chunks.
        with tqdm(total=total_files, desc="Eval Progress", dynamic_ncols=True) as pbar:
            for future in as_completed(futures):
                try:
                    chunk_metrics = future.result()
                    results.extend(chunk_metrics)
                    pbar.update(len(chunk_metrics))
                except Exception as e:
                    logging.error(f"Task failed: {e}")

    # 5. Metrics Aggregation
    inses, deles, subses = [], [], []
    word_nums = 0

    # Store metrics per language
    lang_stats = {}

    fout = None
    if args.decode_path:
        # `or "."` guards against a bare filename: os.makedirs("") raises
        # FileNotFoundError when decode_path has no directory component.
        os.makedirs(os.path.dirname(args.decode_path) or ".", exist_ok=True)
        logging.info(f"Saving detailed WER results to: {args.decode_path}")
        fout = open(args.decode_path, "w", encoding="utf-8")

    for res in results:
        inses.append(float(res["insertions"]))
        deles.append(float(res["deletions"]))
        subses.append(float(res["substitutions"]))
        word_nums += res["word_num"]

        if fout:
            fout.write(
                f"{res['wav_path']}\t{res['wer']}\t{res['truth']}\t"
                f"{res['hypo']}\t{res['insertions']}\t{res['deletions']}\t"
                f"{res['substitutions']}\n"
            )
        lang_name = res["lang_name"]

        # Per language stats
        if lang_name not in lang_stats:
            lang_stats[lang_name] = {
                "inses": [],
                "deles": [],
                "subses": [],
                "word_nums": 0,
            }
        lang_stats[lang_name]["inses"].append(float(res["insertions"]))
        lang_stats[lang_name]["deles"].append(float(res["deletions"]))
        lang_stats[lang_name]["subses"].append(float(res["substitutions"]))
        lang_stats[lang_name]["word_nums"] += res["word_num"]

    print("-" * 50)
    # Log per-language stats
    per_lang_wers = []
    for lang in sorted(lang_stats.keys()):
        stats = lang_stats[lang]
        if stats["word_nums"] > 0:
            lang_wer = log_metrics(
                fout,
                f"[{lang}]",
                stats["inses"],
                stats["deles"],
                stats["subses"],
                stats["word_nums"],
            )
            per_lang_wers.append(lang_wer)
    print("-" * 50)

    # Log Macro-average WER (only meaningful with >1 language).
    if len(per_lang_wers) > 1:
        macro_wer = np.mean(per_lang_wers)
        logging.info(
            f"Macro-average WER over {len(per_lang_wers)} languages: {macro_wer:.2f}%"
        )
        if fout:
            fout.write(
                f"Macro-average WER over {len(per_lang_wers)} languages: {macro_wer:.2f}%\n"
            )
        count_le_5 = sum(1 for w in per_lang_wers if w <= 5.0)
        count_le_10 = sum(1 for w in per_lang_wers if w <= 10.0)
        count_le_20 = sum(1 for w in per_lang_wers if w <= 20.0)

        stats_msg = (
            f"Languages with WER/CER <= 5%: {count_le_5}/{len(per_lang_wers)}\n"
            f"Languages with WER/CER <= 10%: {count_le_10}/{len(per_lang_wers)}\n"
            f"Languages with WER/CER <= 20%: {count_le_20}/{len(per_lang_wers)}"
        )

        logging.info("\n" + stats_msg)
        if fout:
            fout.write(stats_msg + "\n")

    # Log overall stats
    if word_nums > 0:
        log_metrics(fout, "Overall", inses, deles, subses, word_nums)

    if fout:
        fout.close()


if __name__ == "__main__":
    main()
omnivoice/eval/wer/hubert.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""
|
| 19 |
+
Computes word error rate (WER) with Hubert models for LibriSpeech test sets.
|
| 20 |
+
"""
|
| 21 |
+
import argparse
|
| 22 |
+
import logging
|
| 23 |
+
import multiprocessing as mp
|
| 24 |
+
import os
|
| 25 |
+
import re
|
| 26 |
+
import traceback
|
| 27 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 28 |
+
from pathlib import Path
|
| 29 |
+
|
| 30 |
+
import numpy as np
|
| 31 |
+
import torch
|
| 32 |
+
from tqdm import tqdm
|
| 33 |
+
|
| 34 |
+
from omnivoice.eval.utils import load_waveform
|
| 35 |
+
from omnivoice.eval.wer.common import process_one
|
| 36 |
+
from omnivoice.utils.data_utils import read_test_list
|
| 37 |
+
|
| 38 |
+
# --- Global variables for worker processes ---
|
| 39 |
+
worker_pipe = None
|
| 40 |
+
worker_device = None
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_parser():
    """Build the command-line argument parser for the Hubert WER evaluation."""
    parser = argparse.ArgumentParser(
        description="Computes WER with Hubert-based ASR model.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--wav-path",
        type=str,
        required=True,
        help="Path to the directory containing speech files.",
    )
    parser.add_argument(
        "--extension",
        type=str,
        default="wav",
        help="Extension of the speech files. Default: wav",
    )
    parser.add_argument(
        "--decode-path",
        type=str,
        default=None,
        help="Path to the output file where WER information will be saved. "
        "If not provided, results are only printed to console.",
    )
    parser.add_argument(
        "--model-dir",
        type=str,
        required=True,
        help="Local path of our evaluation model repository."
        "Download from https://huggingface.co/k2-fsa/TTS_eval_models."
        "Will use 'tts_eval_models/wer/hubert-large-ls960-ft/'"
        " in this script",
    )
    parser.add_argument(
        "--test-list",
        type=str,
        default="transcript.jsonl",
        help="path of the JSONL test list. Each line is a JSON object "
        "with fields: id, text, ref_audio, ref_text, language_id, language_name.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=16,
        help="Batch size for decoding with the Hugging Face pipeline.",
    )
    parser.add_argument(
        "--nj-per-gpu", type=int, default=1, help="Number of workers per GPU."
    )
    return parser
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def process_init(rank_queue, model_dir):
    """Per-worker initializer: claim a GPU rank and load the Hubert pipeline.

    Stores the pipeline and device in module globals so that
    ``run_eval_worker`` can reuse them for every chunk this process handles.
    Raises on any failure so the pool surfaces a broken worker immediately.
    """
    global worker_pipe, worker_device

    # Keep CPU thread usage low; each process mainly feeds one GPU.
    torch.set_num_threads(2)

    try:
        gpu_rank = rank_queue.get(timeout=10)
    except Exception:
        raise RuntimeError("Failed to get GPU rank from queue.")

    assert torch.cuda.is_available(), "CUDA is required but not available."
    worker_device = torch.device(f"cuda:{gpu_rank}")
    torch.cuda.set_device(gpu_rank)

    logging.info(f"Initializing worker on device: {worker_device}")

    try:
        worker_pipe = load_hubert_model(model_dir, worker_device)
        if worker_pipe is None:
            raise RuntimeError("Model loading failed.")
    except Exception as e:
        logging.critical(f"Failed to load model on {worker_device}: {e}")
        raise e
+
|
| 120 |
+
def load_hubert_model(model_dir, device):
    """Load the Hubert ASR pipeline from a local checkpoint directory.

    Args:
        model_dir: Root of the TTS_eval_models repository checkout.
        device: Torch device to place the model on.

    Returns:
        A transformers ASR pipeline, or None (after logging an error) when
        the expected checkpoint directory does not exist.
    """
    model_path = os.path.join(model_dir, "wer/hubert-large-ls960-ft/")
    if not os.path.exists(model_path):
        logging.error(
            f"Hubert model not found at {model_path}. "
            "Please download from https://huggingface.co/k2-fsa/TTS_eval_models"
        )
        return None

    logging.debug(f"Loading Hubert-based ASR model on {device}...")
    # Imported lazily so the missing-checkpoint path needs no transformers.
    import transformers

    # Suppress transformers logging
    transformers.logging.set_verbosity_error()

    return transformers.pipeline(
        "automatic-speech-recognition",
        model=model_path,
        device=device,
        tokenizer=model_path,
    )
+
|
| 144 |
+
def post_process(text: str) -> str:
    """
    Cleans and normalizes text for WER calculation.
    Args:
        text (str): The input text to be processed.

    Returns:
        str: The cleaned and normalized text.
    """
    # Normalize curly apostrophes to the ASCII one in a single pass.
    normalized = text.translate(str.maketrans({"‘": "'", "’": "'"}))
    # Lowercase, then blank out everything except letters, digits, and "'".
    cleaned = re.sub(r"[^a-zA-Z0-9']", " ", normalized.lower())
    # Collapse whitespace runs and trim the ends.
    return " ".join(cleaned.split())
| 158 |
+
|
| 159 |
+
def run_eval_worker(data_chunk, batch_size):
    """Transcribe one chunk of utterances with the per-process Hubert pipeline.

    Returns a list of per-utterance metric dicts; returns an empty list on
    any failure (errors are logged, not raised, so one bad chunk cannot
    kill the whole pool).
    """
    global worker_pipe
    if worker_pipe is None:
        logging.error("Worker pipeline is not initialized!")
        return []

    collected = []
    try:
        # Decode every file up front; the pipeline batches internally.
        inputs = [
            {
                "array": load_waveform(
                    entry["wav_path"], sample_rate=16000, return_numpy=True
                ),
                "sampling_rate": 16000,
            }
            for entry in data_chunk
        ]
        # NOTE(review): language/task are Whisper-style generate options and
        # Hubert is a CTC model — confirm the installed transformers version
        # accepts (or silently ignores) these kwargs.
        generate_kwargs = {"language": "english", "task": "transcribe"}

        outputs = worker_pipe(
            inputs, generate_kwargs=generate_kwargs, batch_size=batch_size
        )

        # The pipeline yields results in input order, so zip is safe.
        for entry, out in zip(data_chunk, outputs):
            metric = process_one(out["text"].strip(), entry["truth_text"], post_process)
            metric["wav_path"] = entry["wav_path"]
            collected.append(metric)

    except Exception:
        logging.error(f"Worker failed on chunk:\n{traceback.format_exc()}")
        return []

    return collected
| 198 |
+
|
| 199 |
+
def main():
    """Entry point: English WER evaluation with a Hubert CTC model.

    Reads the JSONL test list, batches the audio into chunks, fans the
    chunks out to one worker process per GPU slot, then reports the
    word-count-weighted WER and the raw error breakdown.
    """
    parser = get_parser()
    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
        force=True,
    )

    logging.info(f"Calculating WER for {args.wav_path}")

    data_list = []
    samples = read_test_list(args.test_list)
    for s in samples:
        wav_full_path = str(Path(args.wav_path) / (s["id"] + "." + args.extension))
        if not os.path.exists(wav_full_path):
            logging.warning(f"File missing: {wav_full_path}")
            continue
        data_list.append(
            {
                "wav_path": wav_full_path,
                "truth_text": s["text"],
            }
        )
    total_files = len(data_list)

    num_gpus = torch.cuda.device_count()
    assert num_gpus > 0, "No GPU found. GPU is required."
    total_workers = num_gpus * args.nj_per_gpu

    mp.set_start_method("spawn", force=True)
    manager = mp.Manager()
    rank_queue = manager.Queue()

    # Each worker pops one rank, so nj_per_gpu workers end up on every GPU.
    for _ in range(args.nj_per_gpu):
        for rank in range(num_gpus):
            rank_queue.put(rank)

    chunk_size = max(1, args.batch_size)
    tasks = [data_list[i : i + chunk_size] for i in range(0, total_files, chunk_size)]

    logging.info(
        f"Split data into {len(tasks)} chunks (size ~{chunk_size}). "
        f"Spawning {total_workers} workers."
    )

    results = []

    with ProcessPoolExecutor(
        max_workers=total_workers,
        initializer=process_init,
        initargs=(rank_queue, args.model_dir),
    ) as executor:

        futures = []
        for chunk in tasks:
            futures.append(executor.submit(run_eval_worker, chunk, args.batch_size))

        with tqdm(total=total_files, desc="Eval Progress", dynamic_ncols=True) as pbar:
            for future in as_completed(futures):
                chunk_metrics = future.result()
                results.extend(chunk_metrics)
                pbar.update(len(chunk_metrics))

    inses, deles, subses = [], [], []
    word_nums = 0

    fout = None
    if args.decode_path:
        # `or "."` guards against a bare filename: os.makedirs("") raises
        # FileNotFoundError when decode_path has no directory component.
        os.makedirs(os.path.dirname(args.decode_path) or ".", exist_ok=True)
        fout = open(args.decode_path, "w", encoding="utf8")
        logging.info(f"Saving detailed WER results to: {args.decode_path}")
        fout.write(
            "Name\tWER\tTruth\tHypothesis\tInsertions\tDeletions\tSubstitutions\n"
        )

    for res in results:
        inses.append(float(res["insertions"]))
        deles.append(float(res["deletions"]))
        subses.append(float(res["substitutions"]))
        word_nums += res["word_num"]

        if fout:
            fout.write(
                f"{res['wav_path']}\t{res['wer']}\t{res['truth']}\t"
                f"{res['hypo']}\t{res['insertions']}\t{res['deletions']}\t"
                f"{res['substitutions']}\n"
            )

    # Word-count-weighted WER over the whole set (NaN when nothing decoded).
    wer_weighted = (
        round(
            (np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 2
        )
        if word_nums > 0
        else float("nan")
    )

    inse_sum = np.sum(inses)
    dele_sum = np.sum(deles)
    subs_sum = np.sum(subses)

    print("-" * 50)
    logging.info(f"Processed {len(results)}/{total_files} files.")
    wer_info = f"WER: {wer_weighted}%"
    detailed_info = (
        f"Errors: {inse_sum} ins, {dele_sum} del, {subs_sum} sub / {word_nums} words"
    )
    logging.info(wer_info)
    logging.info(detailed_info)
    print("-" * 50)

    if fout:
        fout.write(wer_info + "\n" + detailed_info + "\n")
        fout.close()


if __name__ == "__main__":
    main()
|
omnivoice/eval/wer/minimax.py
ADDED
|
@@ -0,0 +1,596 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""
|
| 19 |
+
Computes word error rate (WER) with Whisper-large-v3 for English and
|
| 20 |
+
Paraformer for Chinese. Intended to evaluate WERs on Seed-TTS test sets.
|
| 21 |
+
"""
|
| 22 |
+
import argparse
|
| 23 |
+
import logging
|
| 24 |
+
import multiprocessing as mp
|
| 25 |
+
import os
|
| 26 |
+
import traceback
|
| 27 |
+
from collections import defaultdict
|
| 28 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
from typing import List, Union
|
| 31 |
+
|
| 32 |
+
import numpy as np
|
| 33 |
+
import torch
|
| 34 |
+
import zhconv
|
| 35 |
+
from tqdm import tqdm
|
| 36 |
+
|
| 37 |
+
from omnivoice.eval.utils import load_waveform
|
| 38 |
+
from omnivoice.eval.wer.common import log_metrics, process_one
|
| 39 |
+
from omnivoice.eval.wer.text_norm_omni import text_normalize
|
| 40 |
+
from omnivoice.utils.data_utils import read_test_list
|
| 41 |
+
|
| 42 |
+
# --- Global variables for worker processes ---
|
| 43 |
+
worker_pipe = None
|
| 44 |
+
worker_paraformer = None
|
| 45 |
+
worker_device = None
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def read_language_mapping_from_tsv(
    mapping_path: Path,
) -> dict[str, Union[str, List[str]]]:
    """Read the language-ID mapping TSV and return {mixed_id: iso_639_3_id}.

    The file is expected to have a header row followed by rows of
    ``mixed_id<TAB>language_name<TAB>iso_639_3_id<TAB>duration``.

    Args:
        mapping_path: Path to the TSV mapping file.

    Returns:
        Mapping from mixed language ID to ISO 639-3 code.

    Raises:
        ValueError: If a non-empty row does not have exactly four columns.
    """
    language_mapping: dict[str, str] = {}
    with open(mapping_path, "r", encoding="utf-8") as f:
        _ = f.readline()  # Skip header
        for line in f:
            row = line.strip()
            if not row:
                # Tolerate blank/trailing lines instead of crashing on unpack.
                continue
            mixed_id, _language_name, iso_639_3_id, _duration = row.split("\t")
            language_mapping[mixed_id] = iso_639_3_id
    return language_mapping
| 60 |
+
|
| 61 |
+
# Module-level lookup from mixed language IDs to ISO 639-3 codes.
# The TSV lives in the repo's docs/ folder; the path is resolved relative to
# this file so imports work regardless of the current working directory.
_LANG_MAP_TSV = Path(f"{os.path.dirname(__file__)}/../../../docs/lang_id_name_map.tsv")
mixed_id_to_iso_639_3_id = read_language_mapping_from_tsv(_LANG_MAP_TSV)
| 65 |
+
|
| 66 |
+
def get_parser():
    """Build the command-line argument parser for the Whisper/Paraformer WER eval."""
    parser = argparse.ArgumentParser(
        description="Computes WER with Whisper.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--wav-path",
        type=str,
        required=True,
        help="Path to the directory containing speech files.",
    )

    parser.add_argument(
        "--extension",
        type=str,
        default="wav",
        help="Extension of the speech files. Default: wav",
    )

    parser.add_argument(
        "--decode-path",
        type=str,
        default=None,
        help="Path to the output file where WER information will be saved. "
        "If not provided, results are only printed to console.",
    )
    parser.add_argument(
        "--model-dir",
        type=str,
        required=True,
        help="Local path of evaluation models repository. "
        "Download from https://huggingface.co/k2-fsa/TTS_eval_models. ",
    )
    parser.add_argument(
        "--test-list",
        type=str,
        default="test.jsonl",
        help="path of the JSONL test list. Each line is a JSON object "
        "with fields: id, text, ref_audio, ref_text, language_id, language_name.",
    )
    parser.add_argument(
        "--lang",
        type=str,
        default=None,
        help="""Language code to evaluate (e.g., 'en' for English, 'zh' for Chinese).
        If not provided, the script will evaluate all languages found in the test list.
        If specified, only samples of the given language will be evaluated.
        """,
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=16,
        help="Batch size for decoding with the Hugging Face pipeline.",
    )
    parser.add_argument(
        "--nj-per-gpu", type=int, default=1, help="Number of workers per GPU."
    )
    parser.add_argument(
        "--chunk-size",
        type=int,
        default=10,
        help="Number of samples per task chunk sent to workers.",
    )
    return parser
| 133 |
+
|
| 134 |
+
def load_whisper_model(model_dir, device):
    """Build a Whisper-large-v3 ASR pipeline from a local checkpoint.

    Returns None (after logging an error) if the checkpoint directory does
    not exist. fp16 is used on CUDA devices, fp32 otherwise.
    """
    model_path = os.path.join(model_dir, "wer/whisper-large-v3/")
    if not os.path.exists(model_path):
        logging.error(f"Whisper model not found at {model_path}.")
        return None

    # Imported lazily so the missing-checkpoint path needs no transformers.
    import transformers

    # Suppress transformers logging
    transformers.logging.set_verbosity_error()

    logging.info(f"Loading Whisper model on {device}...")
    use_half = "cuda" in str(device)
    return transformers.pipeline(
        "automatic-speech-recognition",
        model=model_path,
        chunk_length_s=30,
        dtype=torch.float16 if use_half else torch.float32,
        device=device,
    )
| 155 |
+
|
| 156 |
+
def load_paraformer_model(model_dir, device):
    """Build a Paraformer (funasr) model for Chinese ASR from a local checkpoint.

    Returns None (after logging an error) when the checkpoint directory is
    missing. funasr is noisy at load time, so logging is disabled around
    model construction and restored afterwards.
    """
    model_path = os.path.join(model_dir, "wer/paraformer-zh/")
    if not os.path.exists(model_path):
        logging.error(f"Paraformer model not found at {model_path}.")
        return None

    logging.info(f"Loading Paraformer model on {device}...")

    # Remember the current global disable level so it can be restored exactly.
    saved_disable_level = logging.root.manager.disable
    logging.disable(logging.CRITICAL)

    try:
        from funasr import AutoModel

        model = AutoModel(
            model=model_path,
            device=str(device),
            disable_update=True,
            disable_pbar=True,
            verbose=False,
        )
    finally:
        logging.disable(saved_disable_level)

    return model
| 182 |
+
|
| 183 |
+
def _worker_setup(rank_queue):
    """Common worker setup: get rank, configure device and threads."""
    global worker_device

    # Keep CPU thread usage low; each process mainly feeds one GPU.
    torch.set_num_threads(2)

    try:
        gpu_rank = rank_queue.get(timeout=10)
    except Exception:
        raise RuntimeError("Failed to get GPU rank from queue.")

    assert torch.cuda.is_available(), "CUDA is required but not available."
    worker_device = torch.device(f"cuda:{gpu_rank}")
    torch.cuda.set_device(gpu_rank)

    logging.info(f"Initializing worker on device: {worker_device}")
| 200 |
+
|
| 201 |
+
def process_init(rank_queue, model_dir):
    """Initializer for Whisper worker processes.

    Binds the process to a GPU via ``_worker_setup`` and loads the Whisper
    pipeline into the module-global ``worker_pipe``. Any load failure is
    logged as critical and re-raised so the pool surfaces it.
    """
    global worker_pipe

    _worker_setup(rank_queue)

    try:
        worker_pipe = load_whisper_model(model_dir, worker_device)
        if worker_pipe is None:
            raise RuntimeError("Whisper model loading failed.")
    except Exception as err:
        logging.critical(f"Failed to load Whisper model on {worker_device}: {err}")
        raise err
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def process_init_paraformer(rank_queue, model_dir):
    """Initializer for Paraformer worker processes (Chinese evaluation).

    Binds the process to a GPU via ``_worker_setup`` and loads the
    Paraformer model into the module-global ``worker_paraformer``. Any
    load failure is logged as critical and re-raised so the pool
    surfaces it.
    """
    global worker_paraformer

    _worker_setup(rank_queue)

    try:
        worker_paraformer = load_paraformer_model(model_dir, worker_device)
        if worker_paraformer is None:
            raise RuntimeError("Paraformer model loading failed.")
    except Exception as err:
        logging.critical(f"Failed to load Paraformer model on {worker_device}: {err}")
        raise err
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def post_process(text: str, lang: str) -> str:
    """Clean and normalize a transcript for WER/CER computation.

    Args:
        text: The input transcript (hypothesis or reference).
        lang: Language id, or "unknown" to skip language-aware
            normalization.

    Returns:
        The normalized, lower-cased, stripped text. For CER-scored
        languages the text is exploded into space-separated characters so
        the downstream word-level scorer effectively computes CER.
    """
    if lang != "unknown":
        # Language-aware normalization (punctuation, casing, digits, ...).
        iso_639_3_code = mixed_id_to_iso_639_3_id[lang]
        text = text_normalize(
            text,
            iso_code=iso_639_3_code,
            lower_case=True,
            remove_numbers=False,
            remove_brackets=False,
        )

    if lang in ("zh", "yue"):
        # Convert traditional characters to simplified for a uniform form.
        text = zhconv.convert(text, "zh-cn")

    # Space handling for languages scored with CER (consistent with the
    # practice in the Minimax-Speech paper): zh, yue, ja, ko, th, arb, vi,
    # hi, el.
    if lang in ("zh", "yue", "ja"):
        # Spaces are not semantically meaningful: drop them, then split
        # the text into individual characters.
        text = " ".join(text.replace(" ", ""))
    elif lang in ("ko", "th", "arb", "vi", "hi", "el"):
        # Spaces are semantically meaningful: preserve them as "|"
        # markers, then split into individual characters.
        text = " ".join(text.replace(" ", "|"))

    return text.lower().strip()
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class SpeechEvalDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding 16 kHz waveforms plus reference text.

    Each input entry is a dict holding at least ``"wav_path"`` and
    ``"truth_text"``; items are shaped for the HF ASR pipeline
    (``array`` / ``sampling_rate``) with the reference carried alongside.
    """

    def __init__(self, data_list):
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        entry = self.data_list[index]
        audio = load_waveform(entry["wav_path"], sample_rate=16000, return_numpy=True)
        return {
            "array": audio,
            "sampling_rate": 16000,
            "truth_text": entry["truth_text"],
        }
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def run_eval_worker(data_chunk, language, batch_size):
    """Transcribe one chunk with the process-global Whisper pipeline and
    score every hypothesis against its reference.

    Args:
        data_chunk: List of item dicts (``wav_path``, ``truth_text``,
            optional ``lang_id`` / ``lang_name``).
        language: Language name passed to Whisper, or "unknown" to let
            the model infer it.
        batch_size: Pipeline batch size.

    Returns:
        A list of per-utterance metric dicts; an empty list when the
        pipeline is uninitialized or transcription fails.
    """
    global worker_pipe
    if worker_pipe is None:
        logging.error("Worker pipeline is not initialized!")
        return []

    collected = []
    try:
        generate_kwargs = {"task": "transcribe"}
        if language != "unknown":
            generate_kwargs["language"] = language

        # Stream the dataset through the pipeline; outputs arrive in
        # dataset order, so the index maps back to the reference item.
        outputs = worker_pipe(
            SpeechEvalDataset(data_chunk),
            generate_kwargs=generate_kwargs,
            batch_size=batch_size,
        )

        for idx, decoded in enumerate(outputs):
            ref_item = data_chunk[idx]
            metric = process_one(
                decoded["text"].strip(),
                ref_item["truth_text"],
                post_process,
                ref_item.get("lang_id"),
            )
            metric["wav_path"] = ref_item["wav_path"]
            metric["lang_name"] = ref_item.get("lang_name")
            collected.append(metric)

    except Exception:
        logging.error(
            f"Worker failed on chunk (Lang: {language}):\n{traceback.format_exc()}"
        )
        return []

    return collected
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def run_eval_worker_paraformer(data_chunk, batch_size):
    """Transcribe one Chinese chunk with the process-global Paraformer
    model and score every hypothesis against its reference.

    Args:
        data_chunk: List of item dicts (``wav_path``, ``truth_text``,
            optional ``lang_name``).
        batch_size: Number of files recognized per ``generate`` call.

    Returns:
        A list of per-utterance metric dicts; an empty list when the
        model is uninitialized or recognition fails.
    """
    global worker_paraformer
    if worker_paraformer is None:
        logging.error("Paraformer worker pipeline is not initialized!")
        return []

    collected = []
    try:
        all_paths = [entry["wav_path"] for entry in data_chunk]

        for start in range(0, len(all_paths), batch_size):
            recognized = worker_paraformer.generate(
                input=all_paths[start : start + batch_size],
                batch_size=batch_size,
                disable_pbar=True,
            )

            # Results come back in input order; start + offset indexes the
            # matching reference item.
            for offset, rec in enumerate(recognized):
                ref_item = data_chunk[start + offset]
                metric = process_one(
                    rec["text"], ref_item["truth_text"], post_process, "zh"
                )
                metric["wav_path"] = ref_item["wav_path"]
                metric["lang_name"] = ref_item.get("lang_name")
                collected.append(metric)

    except Exception:
        logging.error(f"Paraformer worker failed on chunk:\n{traceback.format_exc()}")
        return []

    return collected
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def main():
    """Entry point: ASR-based WER/CER evaluation of synthesized speech.

    Reads the test list, routes Chinese items to Paraformer and all other
    languages to Whisper, fans chunks out over per-GPU worker pools, then
    aggregates and logs per-language and overall error metrics.
    """
    parser = get_parser()
    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
        force=True,
    )

    # 1. Prepare Data
    logging.info("Reading test list...")
    data_by_lang = defaultdict(list)
    total_files = 0
    wav_root = Path(args.wav_path)

    samples = read_test_list(args.test_list)
    for s in samples:
        wav_path = str(wav_root / f"{s['id']}.{args.extension}")
        if not os.path.exists(wav_path):
            # Missing audio is skipped with a warning, not a failure.
            logging.warning(f"File missing: {wav_path}")
            continue

        lang_id = s.get("language_id") or "unknown"
        lang_name = s.get("language_name") or "unknown"

        item = {
            "wav_path": wav_path,
            "truth_text": s["text"],
            "lang_id": lang_id,
            "lang_name": lang_name,
        }
        # Optional single-language filter from the command line.
        if args.lang and s.get("language_id") != args.lang:
            continue

        data_by_lang[lang_name].append(item)
        total_files += 1

    logging.info(f"Total files: {total_files} in {len(data_by_lang)} languages.")

    # 2. Worker config
    num_gpus = torch.cuda.device_count()
    assert num_gpus > 0, "No GPU found. GPU is required."
    total_workers = num_gpus * args.nj_per_gpu

    # "spawn" is required so each worker can initialize CUDA cleanly.
    mp.set_start_method("spawn", force=True)
    manager = mp.Manager()

    # 3. Scheduling: Split data into Chinese (Paraformer) and non-Chinese (Whisper)
    zh_items = []
    non_zh_items = []
    for lang_name, items in data_by_lang.items():
        lang_id = items[0].get("lang_id", "") if items else ""
        if lang_name == "Chinese" or (lang_id and lang_id.startswith("zh")):
            zh_items.extend(items)
        else:
            non_zh_items.extend(items)

    chunk_size = args.chunk_size

    # Whisper chunks keep their language name so the worker can pass it
    # to the decoder; chunks never mix languages because items were
    # grouped by language above.
    whisper_tasks = []
    for i in range(0, len(non_zh_items), chunk_size):
        chunk = non_zh_items[i : i + chunk_size]
        lang_name = chunk[0].get("lang_name", "unknown")
        whisper_tasks.append({"chunk": chunk, "lang": lang_name})

    paraformer_tasks = []
    for i in range(0, len(zh_items), chunk_size):
        paraformer_tasks.append(zh_items[i : i + chunk_size])

    logging.info(
        f"Whisper tasks: {len(whisper_tasks)} chunks ({len(non_zh_items)} files). "
        f"Paraformer tasks: {len(paraformer_tasks)} chunks ({len(zh_items)} files). "
        f"Spawning {total_workers} workers per pool."
    )

    # 4. Execution — run Whisper and Paraformer pools sequentially
    results = []

    # 4a. Whisper pool for non-Chinese languages
    if whisper_tasks:
        # One rank entry per worker; workers pop a rank to pick their GPU.
        whisper_rank_queue = manager.Queue()
        for _ in range(args.nj_per_gpu):
            for rank in range(num_gpus):
                whisper_rank_queue.put(rank)

        with ProcessPoolExecutor(
            max_workers=total_workers,
            initializer=process_init,
            initargs=(whisper_rank_queue, args.model_dir),
        ) as executor:

            futures = []
            for task in whisper_tasks:
                futures.append(
                    executor.submit(
                        run_eval_worker, task["chunk"], task["lang"], args.batch_size
                    )
                )

            with tqdm(
                total=len(non_zh_items),
                desc="Whisper Eval",
                dynamic_ncols=True,
            ) as pbar:
                for future in as_completed(futures):
                    try:
                        chunk_metrics = future.result()
                        results.extend(chunk_metrics)
                        pbar.update(len(chunk_metrics))
                    except Exception as e:
                        logging.error(f"Whisper task failed: {e}")

    # 4b. Paraformer pool for Chinese
    if paraformer_tasks:
        para_rank_queue = manager.Queue()
        for _ in range(args.nj_per_gpu):
            for rank in range(num_gpus):
                para_rank_queue.put(rank)

        with ProcessPoolExecutor(
            max_workers=total_workers,
            initializer=process_init_paraformer,
            initargs=(para_rank_queue, args.model_dir),
        ) as executor:

            futures = []
            for chunk in paraformer_tasks:
                futures.append(
                    executor.submit(run_eval_worker_paraformer, chunk, args.batch_size)
                )

            with tqdm(
                total=len(zh_items),
                desc="Paraformer Eval",
                dynamic_ncols=True,
            ) as pbar:
                for future in as_completed(futures):
                    try:
                        chunk_metrics = future.result()
                        results.extend(chunk_metrics)
                        pbar.update(len(chunk_metrics))
                    except Exception as e:
                        logging.error(f"Paraformer task failed: {e}")

    # 5. Metrics Aggregation
    # NOTE(review): `wers` is appended to below but never read afterwards;
    # per-utterance WERs reach the output only through `fout`.
    wers, inses, deles, subses = [], [], [], []
    word_nums = 0

    # Store metrics per language
    lang_stats = {}

    fout = None
    if args.decode_path:
        # NOTE(review): os.path.dirname("") raises in makedirs — assumes
        # decode_path always carries a directory component; confirm.
        os.makedirs(os.path.dirname(args.decode_path), exist_ok=True)
        logging.info(f"Saving detailed WER results to: {args.decode_path}")
        fout = open(args.decode_path, "w", encoding="utf-8")

    for res in results:
        wers.append(float(res["wer"]))
        inses.append(float(res["insertions"]))
        deles.append(float(res["deletions"]))
        subses.append(float(res["substitutions"]))
        word_nums += res["word_num"]

        if fout:
            fout.write(
                f"{res['wav_path']}\t{res['wer']}\t{res['truth']}\t"
                f"{res['hypo']}\t{res['insertions']}\t{res['deletions']}\t"
                f"{res['substitutions']}\n"
            )
        lang_name = res["lang_name"]

        # Per language stats
        if lang_name not in lang_stats:
            lang_stats[lang_name] = {
                "inses": [],
                "deles": [],
                "subses": [],
                "word_nums": 0,
            }
        lang_stats[lang_name]["inses"].append(float(res["insertions"]))
        lang_stats[lang_name]["deles"].append(float(res["deletions"]))
        lang_stats[lang_name]["subses"].append(float(res["substitutions"]))
        lang_stats[lang_name]["word_nums"] += res["word_num"]

    print("-" * 50)
    # Log per-language stats
    per_lang_wers = []
    for lang in sorted(lang_stats.keys()):
        stats = lang_stats[lang]
        if stats["word_nums"] > 0:
            lang_wer = log_metrics(
                fout,
                f"[{lang}]",
                stats["inses"],
                stats["deles"],
                stats["subses"],
                stats["word_nums"],
                ndigits=3,
            )
            per_lang_wers.append(lang_wer)
    print("-" * 50)

    # Log Macro-average WER
    if len(per_lang_wers) > 1:
        macro_wer = np.mean(per_lang_wers)
        logging.info(
            f"Macro-average WER over {len(per_lang_wers)} languages: {macro_wer:.2f}%"
        )
        if fout:
            fout.write(
                f"Macro-average WER over {len(per_lang_wers)} languages: {macro_wer:.2f}%\n"
            )

    # Log overall stats
    if word_nums > 0:
        log_metrics(fout, "Overall", inses, deles, subses, word_nums)

    if fout:
        fout.close()
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
omnivoice/eval/wer/norm_config_module.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
This module defines the normalization configuration for WER evaluation.
Copied from https://github.com/facebookresearch/omnilingual-asr/blob/81f51e224ce9e74b02cc2a3eaf21b2d91d743455/workflows/dataprep/norm_config_module.py

At import time it builds ``norm_config``: a per-language dict of
punctuation/deletion/digit character classes and regex replacement
mappings consumed by the text normalizer. All language entries derive
from the shared "*" config; the language sections at the bottom extend
or override it.
"""

# type: ignore
import os
import re

colon = ":"
comma = ","
exclamation_mark = "!"
period = re.escape(".")
question_mark = re.escape("?")
semicolon = ";"

left_curly_bracket = "{"
right_curly_bracket = "}"
quotation_mark = '"'

basic_punc = (
    period
    + question_mark
    + comma
    + colon
    + exclamation_mark
    + left_curly_bracket
    + right_curly_bracket
)

# General punc unicode block (0x2000-0x206F)
zero_width_space = r"\u200B"
zero_width_nonjoiner = r"\u200C"
left_to_right_mark = r"\u200E"
right_to_left_mark = r"\u200F"
left_to_right_embedding = r"\u202A"
pop_directional_formatting = r"\u202C"

# Here are some commonly ill-typed versions of apostrophe
right_single_quotation_mark = r"\u2019"
left_single_quotation_mark = r"\u2018"

# Language specific definitions
# Spanish
inverted_exclamation_mark = r"\u00A1"
inverted_question_mark = r"\u00BF"


# Hindi
hindi_danda = "\u0964"

# Egyptian Arabic
# arabic_percent = r"\u066A"
arabic_comma = r"\u060C"
arabic_question_mark = r"\u061F"
arabic_semicolon = r"\u061B"
arabic_diacritics = r"\u064B-\u0652"


arabic_subscript_alef_and_inverted_damma = r"\u0656-\u0657"


# Chinese
full_stop = r"\u3002"
full_comma = r"\uFF0C"
full_exclamation_mark = r"\uFF01"
full_question_mark = r"\uFF1F"
full_semicolon = r"\uFF1B"
full_colon = r"\uFF1A"
full_parentheses = r"\uFF08\uFF09"
quotation_mark_horizontal = r"\u300C-\u300F"
# Fix: was r"\uFF41-\uFF44" (FULLWIDTH LATIN SMALL LETTER A..D), which
# would strip fullwidth letters as if they were punctuation. The intended
# range is the vertical presentation-form corner brackets, as in the
# upstream MMS normalization tables.
quotation_mark_vertical = r"\uFE41-\uFE44"
title_marks = r"\u3008-\u300B"
wavy_low_line = r"\uFE4F"
ellipsis = r"\u22EF"
enumeration_comma = r"\u3001"
hyphenation_point = r"\u2027"
forward_slash = r"\uFF0F"
wavy_dash = r"\uFF5E"
box_drawings_light_horizontal = r"\u2500"
fullwidth_low_line = r"\uFF3F"
chinese_punc = (
    full_stop
    + full_comma
    + full_exclamation_mark
    + full_question_mark
    + full_semicolon
    + full_colon
    + full_parentheses
    + quotation_mark_horizontal
    + quotation_mark_vertical
    + title_marks
    + wavy_low_line
    + ellipsis
    + enumeration_comma
    + hyphenation_point
    + forward_slash
    + wavy_dash
    + box_drawings_light_horizontal
    + fullwidth_low_line
)

# Armenian
armenian_apostrophe = r"\u055A"
emphasis_mark = r"\u055B"
# NOTE(review): this rebinds the ASCII `exclamation_mark` defined above.
# Harmless because `basic_punc` has already been built, but worth
# renaming if this module is ever touched again.
exclamation_mark = r"\u055C"
armenian_comma = r"\u055D"
armenian_question_mark = r"\u055E"
abbreviation_mark = r"\u055F"
armenian_full_stop = r"\u0589"
armenian_punc = (
    armenian_apostrophe
    + emphasis_mark
    + exclamation_mark
    + armenian_comma
    + armenian_question_mark
    + abbreviation_mark
    + armenian_full_stop
)

lesser_than_symbol = r"<"
greater_than_symbol = r">"

lesser_than_sign = r"\u003c"
greater_than_sign = r"\u003e"

# The textual HTML form "&nbsp", removed via the mapping below.
# NOTE(review): this copy of the file showed a bare space here — almost
# certainly an HTML-entity-decoding artifact in transit (a single-space
# pattern mapped to "" would delete every space). Restored to the
# upstream literal; confirm against the referenced source file.
nbsp_written_form = r"&nbsp"

# Quotation marks
left_double_quotes = r"\u201c"
right_double_quotes = r"\u201d"
left_double_angle = r"\u00ab"
right_double_angle = r"\u00bb"
left_single_angle = r"\u2039"
right_single_angle = r"\u203a"
low_double_quotes = r"\u201e"
low_single_quotes = r"\u201a"
high_double_quotes = r"\u201f"
high_single_quotes = r"\u201b"

all_punct_quotes = (
    left_double_quotes
    + right_double_quotes
    + left_double_angle
    + right_double_angle
    + left_single_angle
    + right_single_angle
    + low_double_quotes
    + low_single_quotes
    + high_double_quotes
    + high_single_quotes
    + right_single_quotation_mark
    + left_single_quotation_mark
)
# Character class used to rewrite ill-typed apostrophes between two
# non-space runs into a plain ASCII apostrophe.
mapping_quotes = (
    "["
    + high_single_quotes
    + right_single_quotation_mark
    + left_single_quotation_mark
    + "]"
)


# Digits

english_digits = r"\u0030-\u0039"
bengali_digits = r"\u09e6-\u09ef"
khmer_digits = r"\u17e0-\u17e9"
devanagari_digits = r"\u0966-\u096f"
oriya_digits = r"\u0b66-\u0b6f"
extended_arabic_indic_digits = r"\u06f0-\u06f9"
kayah_li_digits = r"\ua900-\ua909"
fullwidth_digits = r"\uff10-\uff19"
malayam_digits = r"\u0d66-\u0d6f"
myanmar_digits = r"\u1040-\u1049"
roman_numeral = r"\u2170-\u2179"
nominal_digit_shapes = r"\u206f"

# Load punctuations
# NOTE: file I/O happens at import time; punctuations.lst must sit next
# to this module.
with open(f"{os.path.dirname(__file__)}/punctuations.lst", "r") as punc_f:
    punc_list = [
        line
        for line in punc_f.readlines()
        if line.strip() and not line.strip().startswith("#")
    ]

punct_pattern = r""
for punc in punc_list:
    # the first character in the tab separated line is the punc to be removed
    punct_pattern += re.escape(punc.split("\t")[0])

shared_digits = (
    english_digits
    + bengali_digits
    + khmer_digits
    + devanagari_digits
    + oriya_digits
    + extended_arabic_indic_digits
    + kayah_li_digits
    + fullwidth_digits
    + malayam_digits
    + myanmar_digits
    + roman_numeral
    + nominal_digit_shapes
)

shared_punc_list = (
    basic_punc
    + all_punct_quotes
    + greater_than_sign
    + lesser_than_sign
    + inverted_question_mark
    + full_stop
    + semicolon
    + armenian_punc
    + inverted_exclamation_mark
    + arabic_comma
    + enumeration_comma
    + hindi_danda
    + quotation_mark
    + arabic_semicolon
    + arabic_question_mark
    + chinese_punc
    + punct_pattern
)

shared_mappping = {
    lesser_than_symbol: "",
    greater_than_symbol: "",
    nbsp_written_form: "",
    r"(\S+)" + mapping_quotes + r"(\S+)": r"\1'\2",
}

shared_deletion_list = (
    left_to_right_mark
    + zero_width_nonjoiner
    + arabic_subscript_alef_and_inverted_damma
    + zero_width_space
    + arabic_diacritics
    + pop_directional_formatting
    + right_to_left_mark
    + left_to_right_embedding
)

norm_config = {
    "*": {
        "lower_case": True,
        "punc_set": shared_punc_list,
        "del_set": shared_deletion_list,
        "mapping": shared_mappping,
        "digit_set": shared_digits,
        "unicode_norm": "NFKC",
        "rm_diacritics": False,
    }
}

# =============== Mongolian ===============#

norm_config["mon"] = norm_config["*"].copy()
# add soft hyphen to punc list to match with fleurs
# (+= on a str rebinds the copy's own entry; the "*" config is untouched)
norm_config["mon"]["del_set"] += r"\u00AD"

norm_config["khk"] = norm_config["mon"].copy()

# =============== Hebrew ===============#

norm_config["heb"] = norm_config["*"].copy()
# add "HEBREW POINT" symbols to match with fleurs
norm_config["heb"]["del_set"] += r"\u05B0-\u05BF\u05C0-\u05CF"

# =============== Thai ===============#

norm_config["tha"] = norm_config["*"].copy()
# add "Zero width joiner" symbols to match with fleurs
norm_config["tha"]["punc_set"] += r"\u200D"

# =============== Arabic ===============#
norm_config["ara"] = norm_config["*"].copy()
# Fix: .copy() is shallow, so the nested "mapping" dict was shared with
# "*" and every config derived from it — mutating it leaked the Arabic
# alef-wasla mapping into all languages. Copy the dict before mutating.
norm_config["ara"]["mapping"] = dict(norm_config["ara"]["mapping"])
norm_config["ara"]["mapping"]["ٱ"] = "ا"
norm_config["arb"] = norm_config["ara"].copy()

# =============== Javanese ===============#
norm_config["jav"] = norm_config["*"].copy()
norm_config["jav"]["rm_diacritics"] = True
|
omnivoice/eval/wer/punctuations.lst
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
7355 INVALID UNICODE 0x81
|
| 2 |
+
5265 INVALID UNICODE 0x90
|
| 3 |
+
75 INVALID UNICODE 0x8
|
| 4 |
+
31 INVALID UNICODE 0x8d
|
| 5 |
+
3 INVALID UNICODE 0x94
|
| 6 |
+
2 INVALID UNICODE 0x8f
|
| 7 |
+
2 INVALID UNICODE 0x1a
|
| 8 |
+
1 INVALID UNICODE 0x9d
|
| 9 |
+
1 INVALID UNICODE 0x93
|
| 10 |
+
1 INVALID UNICODE 0x92
|
| 11 |
+
8647 INVALID UNICODE 0xe295
|
| 12 |
+
6650 INVALID UNICODE 0xf21d
|
| 13 |
+
6234 INVALID UNICODE 0xf62d
|
| 14 |
+
4815 INVALID UNICODE 0xf173
|
| 15 |
+
4789 INVALID UNICODE 0xe514
|
| 16 |
+
4409 INVALID UNICODE 0xe293
|
| 17 |
+
3881 INVALID UNICODE 0xf523
|
| 18 |
+
3788 INVALID UNICODE 0xe233
|
| 19 |
+
2448 INVALID UNICODE 0xf50f
|
| 20 |
+
2177 INVALID UNICODE 0xe232
|
| 21 |
+
1955 INVALID UNICODE 0xea7b
|
| 22 |
+
1926 INVALID UNICODE 0xf172
|
| 23 |
+
973 INVALID UNICODE 0xe290
|
| 24 |
+
972 INVALID UNICODE 0xf519
|
| 25 |
+
661 INVALID UNICODE 0xe292
|
| 26 |
+
591 INVALID UNICODE 0xe328
|
| 27 |
+
509 INVALID UNICODE 0xe2fa
|
| 28 |
+
458 INVALID UNICODE 0xe234
|
| 29 |
+
446 INVALID UNICODE 0xe043
|
| 30 |
+
419 INVALID UNICODE 0xe040
|
| 31 |
+
399 INVALID UNICODE 0xe2fb
|
| 32 |
+
387 INVALID UNICODE 0xe32b
|
| 33 |
+
381 INVALID UNICODE 0xe236
|
| 34 |
+
374 INVALID UNICODE 0xf511
|
| 35 |
+
314 INVALID UNICODE 0xe517
|
| 36 |
+
296 INVALID UNICODE 0xe2fe
|
| 37 |
+
293 INVALID UNICODE 0xe492
|
| 38 |
+
291 INVALID UNICODE 0xf52d
|
| 39 |
+
289 INVALID UNICODE 0xe2fc
|
| 40 |
+
195 INVALID UNICODE 0xf521
|
| 41 |
+
190 INVALID UNICODE 0xe516
|
| 42 |
+
182 INVALID UNICODE 0xe041
|
| 43 |
+
178 INVALID UNICODE 0xf529
|
| 44 |
+
113 INVALID UNICODE 0xe2f9
|
| 45 |
+
87 INVALID UNICODE 0xe2d9
|
| 46 |
+
78 INVALID UNICODE 0xe32a
|
| 47 |
+
76 INVALID UNICODE 0xe291
|
| 48 |
+
74 INVALID UNICODE 0xe296
|
| 49 |
+
66 INVALID UNICODE 0xe518
|
| 50 |
+
52 INVALID UNICODE 0xe32c
|
| 51 |
+
46 INVALID UNICODE 0xe2db
|
| 52 |
+
41 INVALID UNICODE 0xe231
|
| 53 |
+
34 INVALID UNICODE 0xf522
|
| 54 |
+
33 INVALID UNICODE 0xf518
|
| 55 |
+
32 INVALID UNICODE 0xf513
|
| 56 |
+
27 INVALID UNICODE 0xe32d
|
| 57 |
+
25 INVALID UNICODE 0xe32e
|
| 58 |
+
23 INVALID UNICODE 0xe06b
|
| 59 |
+
15 INVALID UNICODE 0xea01
|
| 60 |
+
12 INVALID UNICODE 0xe294
|
| 61 |
+
11 INVALID UNICODE 0xe203
|
| 62 |
+
8 INVALID UNICODE 0xf218
|
| 63 |
+
7 INVALID UNICODE 0xe070
|
| 64 |
+
7 INVALID UNICODE 0xe013
|
| 65 |
+
5 INVALID UNICODE 0xe2de
|
| 66 |
+
4 INVALID UNICODE 0xe493
|
| 67 |
+
3 INVALID UNICODE 0xf7e8
|
| 68 |
+
3 INVALID UNICODE 0xf7d0
|
| 69 |
+
3 INVALID UNICODE 0xe313
|
| 70 |
+
2 INVALID UNICODE 0xe329
|
| 71 |
+
2 INVALID UNICODE 0xe06d
|
| 72 |
+
2 INVALID UNICODE 0xe003
|
| 73 |
+
1 INVALID UNICODE 0xf50e
|
| 74 |
+
1 INVALID UNICODE 0xf171
|
| 75 |
+
1 INVALID UNICODE 0xe01d
|
| 76 |
+
71 NOMINAL DIGIT SHAPES 0x206f
|
| 77 |
+
3 WORD JOINER 0x2060
|
| 78 |
+
― 126545 HORIZONTAL BAR 0x2015
|
| 79 |
+
־ 1028 HEBREW PUNCTUATION MAQAF 0x5be
|
| 80 |
+
) 98429 RIGHT PARENTHESIS 0x29
|
| 81 |
+
] 27108 RIGHT SQUARE BRACKET 0x5d
|
| 82 |
+
⌋ 1567 RIGHT FLOOR 0x230b
|
| 83 |
+
〕 97 RIGHT TORTOISE SHELL BRACKET 0x3015
|
| 84 |
+
】 36 RIGHT BLACK LENTICULAR BRACKET 0x3011
|
| 85 |
+
﴾ 14 ORNATE LEFT PARENTHESIS 0xfd3e
|
| 86 |
+
& 170517 AMPERSAND 0x26
|
| 87 |
+
། 106330 TIBETAN MARK SHAD 0xf0d
|
| 88 |
+
። 90203 ETHIOPIC FULL STOP 0x1362
|
| 89 |
+
፥ 60484 ETHIOPIC COLON 0x1365
|
| 90 |
+
༌ 60464 TIBETAN MARK DELIMITER TSHEG BSTAR 0xf0c
|
| 91 |
+
။ 51567 MYANMAR SIGN SECTION 0x104b
|
| 92 |
+
/ 46929 SOLIDUS 0x2f
|
| 93 |
+
၊ 38042 MYANMAR SIGN LITTLE SECTION 0x104a
|
| 94 |
+
· 37985 MIDDLE DOT 0xb7
|
| 95 |
+
‸ 36310 CARET 0x2038
|
| 96 |
+
* 34793 ASTERISK 0x2a
|
| 97 |
+
۔ 32432 ARABIC FULL STOP 0x6d4
|
| 98 |
+
፤ 31906 ETHIOPIC SEMICOLON 0x1364
|
| 99 |
+
၏ 21519 MYANMAR SYMBOL GENITIVE 0x104f
|
| 100 |
+
។ 20834 KHMER SIGN KHAN 0x17d4
|
| 101 |
+
꓾ 15773 LISU PUNCTUATION COMMA 0xa4fe
|
| 102 |
+
᙮ 13473 CANADIAN SYLLABICS FULL STOP 0x166e
|
| 103 |
+
꤯ 12892 KAYAH LI SIGN SHYA 0xa92f
|
| 104 |
+
⵰ 11478 TIFINAGH SEPARATOR MARK 0x2d70
|
| 105 |
+
꓿ 11118 LISU PUNCTUATION FULL STOP 0xa4ff
|
| 106 |
+
॥ 10763 DEVANAGARI DOUBLE DANDA 0x965
|
| 107 |
+
؞ 10403 ARABIC TRIPLE DOT PUNCTUATION MARK 0x61e
|
| 108 |
+
၍ 8936 MYANMAR SYMBOL COMPLETED 0x104d
|
| 109 |
+
· 8431 GREEK ANO TELEIA 0x387
|
| 110 |
+
† 7477 DAGGER 0x2020
|
| 111 |
+
၌ 6632 MYANMAR SYMBOL LOCATIVE 0x104c
|
| 112 |
+
፣ 5719 ETHIOPIC COMMA 0x1363
|
| 113 |
+
៖ 5528 KHMER SIGN CAMNUC PII KUUH 0x17d6
|
| 114 |
+
꤮ 4791 KAYAH LI SIGN CWI 0xa92e
|
| 115 |
+
※ 3439 REFERENCE MARK 0x203b
|
| 116 |
+
፦ 2727 ETHIOPIC PREFACE COLON 0x1366
|
| 117 |
+
• 1749 BULLET 0x2022
|
| 118 |
+
¶ 1507 PILCROW SIGN 0xb6
|
| 119 |
+
၎ 1386 MYANMAR SYMBOL AFOREMENTIONED 0x104e
|
| 120 |
+
﹖ 1224 SMALL QUESTION MARK 0xfe56
|
| 121 |
+
; 975 GREEK QUESTION MARK 0x37e
|
| 122 |
+
… 827 HORIZONTAL ELLIPSIS 0x2026
|
| 123 |
+
% 617 PERCENT SIGN 0x25
|
| 124 |
+
・ 468 KATAKANA MIDDLE DOT 0x30fb
|
| 125 |
+
༎ 306 TIBETAN MARK NYIS SHAD 0xf0e
|
| 126 |
+
‡ 140 DOUBLE DAGGER 0x2021
|
| 127 |
+
# 137 NUMBER SIGN 0x23
|
| 128 |
+
@ 125 COMMERCIAL AT 0x40
|
| 129 |
+
፡ 121 ETHIOPIC WORDSPACE 0x1361
|
| 130 |
+
៚ 55 KHMER SIGN KOOMUUT 0x17da
|
| 131 |
+
៕ 49 KHMER SIGN BARIYOOSAN 0x17d5
|
| 132 |
+
﹐ 10 SMALL COMMA 0xfe50
|
| 133 |
+
༅ 6 TIBETAN MARK CLOSING YIG MGO SGAB MA 0xf05
|
| 134 |
+
༄ 6 TIBETAN MARK INITIAL YIG MGO MDUN MA 0xf04
|
| 135 |
+
. 2 FULLWIDTH FULL STOP 0xff0e
|
| 136 |
+
﹗ 2 SMALL EXCLAMATION MARK 0xfe57
|
| 137 |
+
﹕ 2 SMALL COLON 0xfe55
|
| 138 |
+
‰ 2 PER MILLE SIGN 0x2030
|
| 139 |
+
・ 1 HALFWIDTH KATAKANA MIDDLE DOT 0xff65
|
| 140 |
+
( 98504 LEFT PARENTHESIS 0x28
|
| 141 |
+
[ 27245 LEFT SQUARE BRACKET 0x5b
|
| 142 |
+
⌊ 1567 LEFT FLOOR 0x230a
|
| 143 |
+
〔 95 LEFT TORTOISE SHELL BRACKET 0x3014
|
| 144 |
+
【 36 LEFT BLACK LENTICULAR BRACKET 0x3010
|
| 145 |
+
﴿ 14 ORNATE RIGHT PARENTHESIS 0xfd3f
|
| 146 |
+
_ 4851 LOW LINE 0x5f
|
| 147 |
+
$ 72 DOLLAR SIGN 0x24
|
| 148 |
+
€ 14 EURO SIGN 0x20ac
|
| 149 |
+
£ 2 POUND SIGN 0xa3
|
| 150 |
+
~ 27462 TILDE 0x7e
|
| 151 |
+
= 11450 EQUALS SIGN 0x3d
|
| 152 |
+
| 8430 VERTICAL LINE 0x7c
|
| 153 |
+
− 3971 MINUS SIGN 0x2212
|
| 154 |
+
≫ 1904 MUCH GREATER-THAN 0x226b
|
| 155 |
+
≪ 1903 MUCH LESS-THAN 0x226a
|
| 156 |
+
+ 1450 PLUS SIGN 0x2b
|
| 157 |
+
< 345 FULLWIDTH LESS-THAN SIGN 0xff1c
|
| 158 |
+
> 344 FULLWIDTH GREATER-THAN SIGN 0xff1e
|
| 159 |
+
¬ 5 NOT SIGN 0xac
|
| 160 |
+
× 4 MULTIPLICATION SIGN 0xd7
|
| 161 |
+
→ 2 RIGHTWARDS ARROW 0x2192
|
| 162 |
+
᙭ 537 CANADIAN SYLLABICS CHI SIGN 0x166d
|
| 163 |
+
° 499 DEGREE SIGN 0xb0
|
| 164 |
+
႟ 421 MYANMAR SYMBOL SHAN EXCLAMATION 0x109f
|
| 165 |
+
� 192 REPLACEMENT CHARACTER 0xfffd
|
| 166 |
+
⌟ 54 BOTTOM RIGHT CORNER 0x231f
|
| 167 |
+
⌞ 54 BOTTOM LEFT CORNER 0x231e
|
| 168 |
+
© 2 COPYRIGHT SIGN 0xa9
|
| 169 |
+
40 NARROW NO-BREAK SPACE 0x202f
|
| 170 |
+
1 SIX-PER-EM SPACE 0x2006
|
| 171 |
+
˜ 40261 SMALL TILDE 0x2dc
|
| 172 |
+
^ 6469 CIRCUMFLEX ACCENT 0x5e
|
| 173 |
+
¯ 20 MACRON 0xaf
|
| 174 |
+
ˇ 191442 CARON 0x2c7
|
| 175 |
+
ⁿ 38144 SUPERSCRIPT LATIN SMALL LETTER N 0x207f
|
| 176 |
+
ـ 9440 ARABIC TATWEEL 0x640
|
| 177 |
+
ๆ 6766 THAI CHARACTER MAIYAMOK 0xe46
|
| 178 |
+
ៗ 3310 KHMER SIGN LEK TOO 0x17d7
|
| 179 |
+
々 678 IDEOGRAPHIC ITERATION MARK 0x3005
|
| 180 |
+
ໆ 430 LAO KO LA 0xec6
|
| 181 |
+
ー 319 KATAKANA-HIRAGANA PROLONGED SOUND MARK 0x30fc
|
| 182 |
+
ⁱ 137 SUPERSCRIPT LATIN SMALL LETTER I 0x2071
|
| 183 |
+
৷ 11056 BENGALI CURRENCY NUMERATOR FOUR 0x9f7
|
| 184 |
+
⅓ 26 VULGAR FRACTION ONE THIRD 0x2153
|
| 185 |
+
½ 26 VULGAR FRACTION ONE HALF 0xbd
|
| 186 |
+
¼ 4 VULGAR FRACTION ONE QUARTER 0xbc
|
| 187 |
+
⅟ 1 FRACTION NUMERATOR ONE 0x215f
|
| 188 |
+
⁄ 57 FRACTION SLASH 0x2044
|
omnivoice/eval/wer/seedtts.py
ADDED
|
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""
|
| 19 |
+
Computes word error rate (WER) with Whisper-large-v3 for English and
|
| 20 |
+
Paraformer for Chinese. Intended to evaluate WERs on Seed-TTS test sets.
|
| 21 |
+
"""
|
| 22 |
+
import argparse
|
| 23 |
+
import logging
|
| 24 |
+
import multiprocessing as mp
|
| 25 |
+
import os
|
| 26 |
+
import string
|
| 27 |
+
import traceback
|
| 28 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
|
| 31 |
+
import numpy as np
|
| 32 |
+
import torch
|
| 33 |
+
import zhconv
|
| 34 |
+
from tqdm import tqdm
|
| 35 |
+
from zhon.hanzi import punctuation
|
| 36 |
+
|
| 37 |
+
from omnivoice.eval.utils import load_waveform
|
| 38 |
+
from omnivoice.eval.wer.common import process_one
|
| 39 |
+
from omnivoice.utils.data_utils import read_test_list
|
| 40 |
+
|
| 41 |
+
# --- Global variables for worker processes ---
|
| 42 |
+
worker_pipe = None
|
| 43 |
+
worker_device = None
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_parser():
|
| 47 |
+
parser = argparse.ArgumentParser(
|
| 48 |
+
description="Computes WER with Whisper/Paraformer.",
|
| 49 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
| 50 |
+
)
|
| 51 |
+
parser.add_argument(
|
| 52 |
+
"--wav-path",
|
| 53 |
+
type=str,
|
| 54 |
+
required=True,
|
| 55 |
+
help="Path to the directory containing speech files.",
|
| 56 |
+
)
|
| 57 |
+
parser.add_argument(
|
| 58 |
+
"--extension",
|
| 59 |
+
type=str,
|
| 60 |
+
default="wav",
|
| 61 |
+
help="Extension of the speech files. Default: wav",
|
| 62 |
+
)
|
| 63 |
+
parser.add_argument(
|
| 64 |
+
"--decode-path",
|
| 65 |
+
type=str,
|
| 66 |
+
default=None,
|
| 67 |
+
help="Path to the output file where WER information will be saved. "
|
| 68 |
+
"If not provided, results are only printed to console.",
|
| 69 |
+
)
|
| 70 |
+
parser.add_argument(
|
| 71 |
+
"--model-dir",
|
| 72 |
+
type=str,
|
| 73 |
+
required=True,
|
| 74 |
+
help="Local path of evaluation models repository. "
|
| 75 |
+
"Download from https://huggingface.co/k2-fsa/TTS_eval_models. "
|
| 76 |
+
"This script expects 'tts_eval_models/wer/whisper-large-v3/' for English "
|
| 77 |
+
"and 'tts_eval_models/wer/paraformer-zh/' for Chinese within this directory.",
|
| 78 |
+
)
|
| 79 |
+
parser.add_argument(
|
| 80 |
+
"--test-list",
|
| 81 |
+
type=str,
|
| 82 |
+
default="test.jsonl",
|
| 83 |
+
help="path of the JSONL test list. Each line is a JSON object "
|
| 84 |
+
"with fields: id, text, ref_audio, ref_text, language_id, language_name.",
|
| 85 |
+
)
|
| 86 |
+
parser.add_argument(
|
| 87 |
+
"--lang",
|
| 88 |
+
type=str,
|
| 89 |
+
choices=["zh", "en"],
|
| 90 |
+
required=True,
|
| 91 |
+
help="Language of the audio and transcripts for "
|
| 92 |
+
"decoding ('zh' for Chinese or 'en' for English).",
|
| 93 |
+
)
|
| 94 |
+
parser.add_argument(
|
| 95 |
+
"--batch-size",
|
| 96 |
+
type=int,
|
| 97 |
+
default=16,
|
| 98 |
+
help="Batch size for decoding with the Hugging Face pipeline.",
|
| 99 |
+
)
|
| 100 |
+
parser.add_argument(
|
| 101 |
+
"--nj-per-gpu", type=int, default=1, help="Number of workers per GPU."
|
| 102 |
+
)
|
| 103 |
+
return parser
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def load_whisper_model(model_dir, device):
|
| 107 |
+
model_path = os.path.join(model_dir, "wer/whisper-large-v3/")
|
| 108 |
+
if not os.path.exists(model_path):
|
| 109 |
+
logging.error(f"Whisper model not found at {model_path}.")
|
| 110 |
+
return None
|
| 111 |
+
|
| 112 |
+
logging.debug(f"Loading Whisper model on {device}...")
|
| 113 |
+
|
| 114 |
+
import transformers
|
| 115 |
+
|
| 116 |
+
# Suppress transformers logging
|
| 117 |
+
transformers.logging.set_verbosity_error()
|
| 118 |
+
|
| 119 |
+
pipe = transformers.pipeline(
|
| 120 |
+
"automatic-speech-recognition",
|
| 121 |
+
model=model_path,
|
| 122 |
+
dtype=torch.float16 if "cuda" in str(device) else torch.float32,
|
| 123 |
+
device=device,
|
| 124 |
+
)
|
| 125 |
+
return pipe
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def load_paraformer_model(model_dir, device):
|
| 129 |
+
model_path = os.path.join(model_dir, "wer/paraformer-zh/")
|
| 130 |
+
if not os.path.exists(model_path):
|
| 131 |
+
logging.error(f"Paraformer model not found at {model_path}.")
|
| 132 |
+
return None
|
| 133 |
+
|
| 134 |
+
logging.debug(f"Loading Paraformer model on {device}...")
|
| 135 |
+
|
| 136 |
+
previous_level = logging.root.manager.disable
|
| 137 |
+
logging.disable(logging.CRITICAL)
|
| 138 |
+
|
| 139 |
+
try:
|
| 140 |
+
from funasr import AutoModel
|
| 141 |
+
|
| 142 |
+
# FunASR AutoModel accepts "cuda:0" string or torch.device
|
| 143 |
+
model = AutoModel(
|
| 144 |
+
model=model_path,
|
| 145 |
+
device=str(device),
|
| 146 |
+
disable_update=True,
|
| 147 |
+
disable_pbar=True,
|
| 148 |
+
verbose=False,
|
| 149 |
+
)
|
| 150 |
+
finally:
|
| 151 |
+
logging.disable(previous_level)
|
| 152 |
+
|
| 153 |
+
return model
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def post_process(text: str, lang: str) -> str:
|
| 157 |
+
"""
|
| 158 |
+
Cleans and normalizes text for WER calculation.
|
| 159 |
+
Args:
|
| 160 |
+
text (str): The input text to be processed.
|
| 161 |
+
lang (str): The language of the input text.
|
| 162 |
+
|
| 163 |
+
Returns:
|
| 164 |
+
str: The cleaned and normalized text.
|
| 165 |
+
"""
|
| 166 |
+
punctuation_all = punctuation + string.punctuation
|
| 167 |
+
for x in punctuation_all:
|
| 168 |
+
if x == "'":
|
| 169 |
+
continue
|
| 170 |
+
text = text.replace(x, "")
|
| 171 |
+
|
| 172 |
+
text = text.replace(" ", " ")
|
| 173 |
+
|
| 174 |
+
if lang == "zh":
|
| 175 |
+
text = " ".join([x for x in text])
|
| 176 |
+
elif lang == "en":
|
| 177 |
+
text = text.lower()
|
| 178 |
+
else:
|
| 179 |
+
raise NotImplementedError
|
| 180 |
+
return text
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def process_init(rank_queue, model_dir, lang):
|
| 184 |
+
"""
|
| 185 |
+
Initializer for each worker process.
|
| 186 |
+
Loads model onto a specific GPU, once per process.
|
| 187 |
+
"""
|
| 188 |
+
global worker_pipe, worker_device
|
| 189 |
+
|
| 190 |
+
torch.set_num_threads(2)
|
| 191 |
+
|
| 192 |
+
try:
|
| 193 |
+
rank = rank_queue.get(timeout=10)
|
| 194 |
+
except Exception:
|
| 195 |
+
raise RuntimeError("Failed to get GPU rank from queue.")
|
| 196 |
+
|
| 197 |
+
assert torch.cuda.is_available(), "CUDA is required but not available."
|
| 198 |
+
worker_device = torch.device(f"cuda:{rank}")
|
| 199 |
+
torch.cuda.set_device(rank)
|
| 200 |
+
|
| 201 |
+
logging.info(f"Initializing worker on device: {worker_device}")
|
| 202 |
+
|
| 203 |
+
try:
|
| 204 |
+
if lang == "en":
|
| 205 |
+
worker_pipe = load_whisper_model(model_dir, worker_device)
|
| 206 |
+
elif lang == "zh":
|
| 207 |
+
worker_pipe = load_paraformer_model(model_dir, worker_device)
|
| 208 |
+
if worker_pipe is None:
|
| 209 |
+
raise RuntimeError("Model loading failed.")
|
| 210 |
+
except Exception as e:
|
| 211 |
+
logging.critical(f"Failed to load model on {worker_device}: {e}")
|
| 212 |
+
raise e
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def run_eval_worker(data_chunk, lang, batch_size):
|
| 216 |
+
"""
|
| 217 |
+
Worker function to process a chunk of data.
|
| 218 |
+
Uses the global worker_pipe initialized by process_init.
|
| 219 |
+
"""
|
| 220 |
+
global worker_pipe
|
| 221 |
+
if worker_pipe is None:
|
| 222 |
+
logging.error("Worker pipeline is not initialized!")
|
| 223 |
+
return []
|
| 224 |
+
|
| 225 |
+
metrics_buffer = []
|
| 226 |
+
try:
|
| 227 |
+
if lang == "en":
|
| 228 |
+
# Load waveforms as arrays, truncating to 30s
|
| 229 |
+
dataset = [
|
| 230 |
+
{
|
| 231 |
+
"array": load_waveform(
|
| 232 |
+
item["wav_path"], sample_rate=16000, return_numpy=True
|
| 233 |
+
)[: 16000 * 30],
|
| 234 |
+
"sampling_rate": 16000,
|
| 235 |
+
}
|
| 236 |
+
for item in data_chunk
|
| 237 |
+
]
|
| 238 |
+
generate_kwargs = {"language": "english", "task": "transcribe"}
|
| 239 |
+
|
| 240 |
+
iterator = worker_pipe(
|
| 241 |
+
dataset, generate_kwargs=generate_kwargs, batch_size=batch_size
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
for i, out in enumerate(iterator):
|
| 245 |
+
hypothesis = out["text"].strip()
|
| 246 |
+
ref_item = data_chunk[i]
|
| 247 |
+
truth = ref_item["truth_text"]
|
| 248 |
+
wav_path = ref_item["wav_path"]
|
| 249 |
+
|
| 250 |
+
m = process_one(hypothesis, truth, post_process, lang)
|
| 251 |
+
m["wav_path"] = wav_path
|
| 252 |
+
metrics_buffer.append(m)
|
| 253 |
+
|
| 254 |
+
elif lang == "zh":
|
| 255 |
+
wav_paths = [item["wav_path"] for item in data_chunk]
|
| 256 |
+
|
| 257 |
+
for i in range(0, len(wav_paths), batch_size):
|
| 258 |
+
batch_paths = wav_paths[i : i + batch_size]
|
| 259 |
+
res_batch = worker_pipe.generate(
|
| 260 |
+
input=batch_paths, batch_size=batch_size, disable_pbar=True
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
for j, res in enumerate(res_batch):
|
| 264 |
+
hypothesis = zhconv.convert(res["text"], "zh-cn")
|
| 265 |
+
ref_item = data_chunk[i + j]
|
| 266 |
+
truth = ref_item["truth_text"]
|
| 267 |
+
wav_path = ref_item["wav_path"]
|
| 268 |
+
|
| 269 |
+
m = process_one(hypothesis, truth, post_process, lang)
|
| 270 |
+
m["wav_path"] = wav_path
|
| 271 |
+
metrics_buffer.append(m)
|
| 272 |
+
|
| 273 |
+
except Exception:
|
| 274 |
+
logging.error(
|
| 275 |
+
f"Worker failed on chunk (Lang: {lang}):\n{traceback.format_exc()}"
|
| 276 |
+
)
|
| 277 |
+
return []
|
| 278 |
+
|
| 279 |
+
return metrics_buffer
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def main():
|
| 283 |
+
parser = get_parser()
|
| 284 |
+
args = parser.parse_args()
|
| 285 |
+
|
| 286 |
+
logging.basicConfig(
|
| 287 |
+
format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
|
| 288 |
+
level=logging.INFO,
|
| 289 |
+
force=True,
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
logging.info(f"Calculating WER for {args.wav_path}")
|
| 293 |
+
|
| 294 |
+
# 1. Prepare Data
|
| 295 |
+
logging.info("Reading test list...")
|
| 296 |
+
data_list = []
|
| 297 |
+
samples = read_test_list(args.test_list)
|
| 298 |
+
for s in samples:
|
| 299 |
+
wav_path = str(Path(args.wav_path) / f"{s['id']}.{args.extension}")
|
| 300 |
+
if not os.path.exists(wav_path):
|
| 301 |
+
logging.warning(f"File missing: {wav_path}")
|
| 302 |
+
continue
|
| 303 |
+
data_list.append({"wav_path": wav_path, "truth_text": s["text"]})
|
| 304 |
+
total_files = len(data_list)
|
| 305 |
+
logging.info(f"Total files: {total_files}.")
|
| 306 |
+
|
| 307 |
+
# 2. Worker config
|
| 308 |
+
num_gpus = torch.cuda.device_count()
|
| 309 |
+
assert num_gpus > 0, "No GPU found. GPU is required."
|
| 310 |
+
total_workers = num_gpus * args.nj_per_gpu
|
| 311 |
+
|
| 312 |
+
mp.set_start_method("spawn", force=True)
|
| 313 |
+
manager = mp.Manager()
|
| 314 |
+
rank_queue = manager.Queue()
|
| 315 |
+
|
| 316 |
+
for _ in range(args.nj_per_gpu):
|
| 317 |
+
for rank in range(num_gpus):
|
| 318 |
+
rank_queue.put(rank)
|
| 319 |
+
|
| 320 |
+
# 3. Scheduling: Split data into chunks for better load balancing
|
| 321 |
+
chunk_size = max(1, args.batch_size)
|
| 322 |
+
tasks = []
|
| 323 |
+
for i in range(0, total_files, chunk_size):
|
| 324 |
+
tasks.append(data_list[i : i + chunk_size])
|
| 325 |
+
|
| 326 |
+
logging.info(
|
| 327 |
+
f"Split data into {len(tasks)} chunks (size ~{chunk_size}). "
|
| 328 |
+
f"Spawning {total_workers} workers."
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
# 4. Execution
|
| 332 |
+
results = []
|
| 333 |
+
|
| 334 |
+
with ProcessPoolExecutor(
|
| 335 |
+
max_workers=total_workers,
|
| 336 |
+
initializer=process_init,
|
| 337 |
+
initargs=(rank_queue, args.model_dir, args.lang),
|
| 338 |
+
) as executor:
|
| 339 |
+
|
| 340 |
+
futures = []
|
| 341 |
+
for chunk in tasks:
|
| 342 |
+
futures.append(
|
| 343 |
+
executor.submit(run_eval_worker, chunk, args.lang, args.batch_size)
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
# Unified progress bar
|
| 347 |
+
with tqdm(total=total_files, desc="Eval Progress", dynamic_ncols=True) as pbar:
|
| 348 |
+
for future in as_completed(futures):
|
| 349 |
+
try:
|
| 350 |
+
chunk_metrics = future.result()
|
| 351 |
+
results.extend(chunk_metrics)
|
| 352 |
+
pbar.update(len(chunk_metrics))
|
| 353 |
+
except Exception as e:
|
| 354 |
+
logging.error(f"Task failed: {e}")
|
| 355 |
+
|
| 356 |
+
wers, inses, deles, subses = [], [], [], []
|
| 357 |
+
word_nums = 0
|
| 358 |
+
|
| 359 |
+
fout = None
|
| 360 |
+
if args.decode_path:
|
| 361 |
+
os.makedirs(os.path.dirname(args.decode_path), exist_ok=True)
|
| 362 |
+
fout = open(args.decode_path, "w", encoding="utf8")
|
| 363 |
+
logging.info(f"Saving detailed WER results to: {args.decode_path}")
|
| 364 |
+
fout.write(
|
| 365 |
+
"Name\tWER\tTruth\tHypothesis\tInsertions\tDeletions\tSubstitutions\n"
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
for res in results:
|
| 369 |
+
wers.append(float(res["wer"]))
|
| 370 |
+
inses.append(float(res["insertions"]))
|
| 371 |
+
deles.append(float(res["deletions"]))
|
| 372 |
+
subses.append(float(res["substitutions"]))
|
| 373 |
+
word_nums += res["word_num"]
|
| 374 |
+
|
| 375 |
+
if fout:
|
| 376 |
+
fout.write(
|
| 377 |
+
f"{res['wav_path']}\t{res['wer']}\t{res['truth']}\t"
|
| 378 |
+
f"{res['hypo']}\t{res['insertions']}\t{res['deletions']}\t"
|
| 379 |
+
f"{res['substitutions']}\n"
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
wer_avg = round(np.mean(wers) * 100, 2) if wers else float("nan")
|
| 383 |
+
wer_weighted = (
|
| 384 |
+
round(
|
| 385 |
+
(np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 2
|
| 386 |
+
)
|
| 387 |
+
if word_nums > 0
|
| 388 |
+
else float("nan")
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
inse_sum = np.sum(inses)
|
| 392 |
+
dele_sum = np.sum(deles)
|
| 393 |
+
subs_sum = np.sum(subses)
|
| 394 |
+
|
| 395 |
+
print("-" * 50)
|
| 396 |
+
logging.info(f"Processed {len(results)}/{total_files} files.")
|
| 397 |
+
seedtts_wer_info = f"Seed-TTS WER (Avg of WERs): {wer_avg}%"
|
| 398 |
+
wer_info = f"WER (Weighted): {wer_weighted}%"
|
| 399 |
+
detailed_info = (
|
| 400 |
+
f"Errors: {inse_sum} ins, {dele_sum} del, {subs_sum} sub / {word_nums} words"
|
| 401 |
+
)
|
| 402 |
+
logging.info(seedtts_wer_info)
|
| 403 |
+
logging.info(wer_info)
|
| 404 |
+
logging.info(detailed_info)
|
| 405 |
+
print("-" * 50)
|
| 406 |
+
|
| 407 |
+
if fout:
|
| 408 |
+
fout.write(seedtts_wer_info + "\n" + wer_info + "\n" + detailed_info + "\n")
|
| 409 |
+
fout.close()
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
if __name__ == "__main__":
|
| 413 |
+
main()
|
omnivoice/eval/wer/sensevoice.py
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""
|
| 19 |
+
Computes Character Error Rate (CER) for Cantonese (yue) using SenseVoiceSmall.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import argparse
|
| 23 |
+
import logging
|
| 24 |
+
import multiprocessing as mp
|
| 25 |
+
import os
|
| 26 |
+
import re
|
| 27 |
+
import traceback
|
| 28 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
|
| 31 |
+
import cn2an
|
| 32 |
+
import torch
|
| 33 |
+
import zhconv
|
| 34 |
+
from tqdm import tqdm
|
| 35 |
+
|
| 36 |
+
from omnivoice.eval.wer.common import log_metrics, process_one
|
| 37 |
+
from omnivoice.eval.wer.text_norm_omni import text_normalize
|
| 38 |
+
from omnivoice.utils.data_utils import read_test_list
|
| 39 |
+
|
| 40 |
+
# --- Global variables for worker processes ---
|
| 41 |
+
worker_sensevoice = None
|
| 42 |
+
worker_device = None
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_parser():
|
| 46 |
+
parser = argparse.ArgumentParser(
|
| 47 |
+
description="Computes CER for Cantonese using SenseVoiceSmall.",
|
| 48 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
parser.add_argument(
|
| 52 |
+
"--wav-path",
|
| 53 |
+
type=str,
|
| 54 |
+
required=True,
|
| 55 |
+
help="Path to the directory containing speech files.",
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
parser.add_argument(
|
| 59 |
+
"--extension",
|
| 60 |
+
type=str,
|
| 61 |
+
default="wav",
|
| 62 |
+
help="Extension of the speech files. Default: wav",
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
parser.add_argument(
|
| 66 |
+
"--decode-path",
|
| 67 |
+
type=str,
|
| 68 |
+
default=None,
|
| 69 |
+
help="Path to the output file where CER information will be saved. ",
|
| 70 |
+
)
|
| 71 |
+
parser.add_argument(
|
| 72 |
+
"--model-dir",
|
| 73 |
+
type=str,
|
| 74 |
+
required=True,
|
| 75 |
+
help="Local path of evaluation models repository. ",
|
| 76 |
+
)
|
| 77 |
+
parser.add_argument(
|
| 78 |
+
"--test-list",
|
| 79 |
+
type=str,
|
| 80 |
+
default="test.jsonl",
|
| 81 |
+
help="path of the JSONL test list.",
|
| 82 |
+
)
|
| 83 |
+
parser.add_argument(
|
| 84 |
+
"--batch-size",
|
| 85 |
+
type=int,
|
| 86 |
+
default=16,
|
| 87 |
+
help="Batch size for decoding.",
|
| 88 |
+
)
|
| 89 |
+
parser.add_argument(
|
| 90 |
+
"--nj-per-gpu", type=int, default=1, help="Number of workers per GPU."
|
| 91 |
+
)
|
| 92 |
+
parser.add_argument(
|
| 93 |
+
"--chunk-size",
|
| 94 |
+
type=int,
|
| 95 |
+
default=10,
|
| 96 |
+
help="Number of samples per task chunk sent to workers.",
|
| 97 |
+
)
|
| 98 |
+
return parser
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def load_sensevoice_model(model_dir, device):
|
| 102 |
+
model_path = os.path.join(model_dir, "wer/SenseVoiceSmall")
|
| 103 |
+
if not os.path.exists(model_path):
|
| 104 |
+
# Fallback if specific sensevoice spelling isn't found
|
| 105 |
+
logging.warning(
|
| 106 |
+
f"SenseVoiceSmall not found at {model_path}. "
|
| 107 |
+
f"Please ensure it is present in eval models."
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
logging.info(f"Loading SenseVoice model on {device}...")
|
| 111 |
+
|
| 112 |
+
previous_level = logging.root.manager.disable
|
| 113 |
+
logging.disable(logging.CRITICAL)
|
| 114 |
+
|
| 115 |
+
try:
|
| 116 |
+
from funasr import AutoModel
|
| 117 |
+
|
| 118 |
+
model = AutoModel(
|
| 119 |
+
model="iic/SenseVoiceSmall",
|
| 120 |
+
device=str(device),
|
| 121 |
+
disable_update=True,
|
| 122 |
+
disable_pbar=True,
|
| 123 |
+
verbose=False,
|
| 124 |
+
)
|
| 125 |
+
finally:
|
| 126 |
+
logging.disable(previous_level)
|
| 127 |
+
|
| 128 |
+
return model
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def _worker_setup(rank_queue):
|
| 132 |
+
global worker_device
|
| 133 |
+
|
| 134 |
+
torch.set_num_threads(2)
|
| 135 |
+
|
| 136 |
+
try:
|
| 137 |
+
rank = rank_queue.get(timeout=10)
|
| 138 |
+
except Exception:
|
| 139 |
+
raise RuntimeError("Failed to get GPU rank from queue.")
|
| 140 |
+
|
| 141 |
+
assert torch.cuda.is_available(), "CUDA is required but not available."
|
| 142 |
+
worker_device = torch.device(f"cuda:{rank}")
|
| 143 |
+
torch.cuda.set_device(rank)
|
| 144 |
+
|
| 145 |
+
logging.info(f"Initializing worker on device: {worker_device}")
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def process_init_sensevoice(rank_queue, model_dir):
|
| 149 |
+
global worker_sensevoice
|
| 150 |
+
|
| 151 |
+
_worker_setup(rank_queue)
|
| 152 |
+
|
| 153 |
+
try:
|
| 154 |
+
worker_sensevoice = load_sensevoice_model(model_dir, worker_device)
|
| 155 |
+
if worker_sensevoice is None:
|
| 156 |
+
raise RuntimeError("SenseVoice model loading failed.")
|
| 157 |
+
except Exception as e:
|
| 158 |
+
logging.critical(f"Failed to load SenseVoice model on {worker_device}: {e}")
|
| 159 |
+
raise e
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def post_process(text: str, lang: str) -> str:
|
| 163 |
+
"""
|
| 164 |
+
Cleans and normalizes text for calculation.
|
| 165 |
+
"""
|
| 166 |
+
assert lang == "yue", "this script is designed for Cantonese (yue) evaluation only."
|
| 167 |
+
text = text_normalize(
|
| 168 |
+
text,
|
| 169 |
+
iso_code="yue",
|
| 170 |
+
lower_case=True,
|
| 171 |
+
remove_numbers=False,
|
| 172 |
+
remove_brackets=False,
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
text = zhconv.convert(text, "zh-cn")
|
| 176 |
+
|
| 177 |
+
text = cn2an.transform(text, "an2cn")
|
| 178 |
+
|
| 179 |
+
text = text.replace(" ", "")
|
| 180 |
+
text = " ".join([x for x in text])
|
| 181 |
+
text = text.lower()
|
| 182 |
+
return text.strip()
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def run_eval_worker_sensevoice(data_chunk, batch_size):
|
| 186 |
+
global worker_sensevoice
|
| 187 |
+
if worker_sensevoice is None:
|
| 188 |
+
logging.error("SenseVoice worker pipeline is not initialized!")
|
| 189 |
+
return []
|
| 190 |
+
|
| 191 |
+
metrics_buffer = []
|
| 192 |
+
try:
|
| 193 |
+
wav_paths = [item["wav_path"] for item in data_chunk]
|
| 194 |
+
|
| 195 |
+
for i in range(0, len(wav_paths), batch_size):
|
| 196 |
+
batch_paths = wav_paths[i : i + batch_size]
|
| 197 |
+
|
| 198 |
+
# SenseVoice generate call, target lang mapped to yue
|
| 199 |
+
res_batch = worker_sensevoice.generate(
|
| 200 |
+
input=batch_paths,
|
| 201 |
+
batch_size=batch_size,
|
| 202 |
+
language="yue",
|
| 203 |
+
use_itn=False,
|
| 204 |
+
disable_pbar=True,
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
for j, res in enumerate(res_batch):
|
| 208 |
+
hypothesis = res["text"]
|
| 209 |
+
# SenseVoice may format output with language tags,
|
| 210 |
+
# cleaning basic tags if any
|
| 211 |
+
hypothesis = re.sub(r"<\|[^|]*\|>", "", hypothesis).strip()
|
| 212 |
+
|
| 213 |
+
ref_item = data_chunk[i + j]
|
| 214 |
+
truth = ref_item["truth_text"]
|
| 215 |
+
wav_path = ref_item["wav_path"]
|
| 216 |
+
lang_name = ref_item.get("lang_name")
|
| 217 |
+
|
| 218 |
+
m = process_one(hypothesis, truth, post_process, "yue")
|
| 219 |
+
m["wav_path"] = wav_path
|
| 220 |
+
m["lang_name"] = lang_name
|
| 221 |
+
metrics_buffer.append(m)
|
| 222 |
+
|
| 223 |
+
except Exception:
|
| 224 |
+
logging.error(f"SenseVoice worker failed on chunk:\n{traceback.format_exc()}")
|
| 225 |
+
return []
|
| 226 |
+
|
| 227 |
+
return metrics_buffer
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def main():
    """Entry point: filter the test list down to Cantonese items, fan the
    decoding work out over all GPUs with a process pool, then aggregate
    per-utterance metrics into a final CER report."""
    args = get_parser().parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
        force=True,
    )

    logging.info("Reading test list and filtering for Cantonese (yue)...")
    wav_root = Path(args.wav_path)
    yue_items = []
    for sample in read_test_list(args.test_list):
        if sample.get("language_id", "") != "yue":
            continue

        wav_path = str(wav_root / f"{sample['id']}.{args.extension}")
        if not os.path.exists(wav_path):
            logging.warning(f"File missing: {wav_path}")
            continue

        yue_items.append(
            {
                "wav_path": wav_path,
                "truth_text": sample["text"],
                "lang_id": "yue",
                "lang_name": sample.get("language_name", "Cantonese"),
            }
        )

    logging.info(f"Total Cantonese files found: {len(yue_items)}.")
    if not yue_items:
        logging.warning("No files to evaluate. Exiting.")
        return

    num_gpus = torch.cuda.device_count()
    assert num_gpus > 0, "No GPU found. GPU is required."
    total_workers = num_gpus * args.nj_per_gpu

    mp.set_start_method("spawn", force=True)
    manager = mp.Manager()

    # Fixed-size chunks of work; one pool future per chunk.
    tasks = [
        yue_items[i : i + args.chunk_size]
        for i in range(0, len(yue_items), args.chunk_size)
    ]

    # Each worker pops a GPU rank from this queue during initialization,
    # so nj_per_gpu workers end up pinned to every GPU.
    rank_queue = manager.Queue()
    for _ in range(args.nj_per_gpu):
        for rank in range(num_gpus):
            rank_queue.put(rank)

    results = []
    with ProcessPoolExecutor(
        max_workers=total_workers,
        initializer=process_init_sensevoice,
        initargs=(rank_queue, args.model_dir),
    ) as executor:
        futures = [
            executor.submit(run_eval_worker_sensevoice, chunk, args.batch_size)
            for chunk in tasks
        ]

        with tqdm(
            total=len(yue_items),
            desc="SenseVoice Eval (Cantonese)",
            dynamic_ncols=True,
        ) as pbar:
            for future in as_completed(futures):
                try:
                    chunk_metrics = future.result()
                    results.extend(chunk_metrics)
                    pbar.update(len(chunk_metrics))
                except Exception as e:
                    logging.error(f"Task failed: {e}")

    # Metrics aggregation.
    inses, deles, subses = [], [], []
    word_nums = 0

    fout = None
    if args.decode_path:
        os.makedirs(os.path.dirname(args.decode_path), exist_ok=True)
        logging.info(f"Saving detailed CER results to: {args.decode_path}")
        fout = open(args.decode_path, "w", encoding="utf-8")

    for res in results:
        inses.append(float(res["insertions"]))
        deles.append(float(res["deletions"]))
        subses.append(float(res["substitutions"]))
        word_nums += res["word_num"]

        if fout:
            fout.write(
                f"{res['wav_path']}\t{res['wer']}\t{res['truth']}\t"
                f"{res['hypo']}\t{res['insertions']}\t{res['deletions']}\t"
                f"{res['substitutions']}\n"
            )

    print("-" * 50)
    if word_nums > 0:
        log_metrics(fout, "[yue] Cantonese", inses, deles, subses, word_nums)

    if fout:
        fout.close()


if __name__ == "__main__":
    main()
|
omnivoice/eval/wer/text_norm_omni.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 3 |
+
# All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This source code is licensed under the BSD-style license found in the
|
| 6 |
+
# LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
"""
|
| 9 |
+
This module contains the text normalization function for WER evaluation.
|
| 10 |
+
Copied from https://github.com/facebookresearch/omnilingual-asr/blob/81f51e224ce9e74b02cc2a3eaf21b2d91d743455/workflows/dataprep/text_tools.py
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import re
|
| 14 |
+
import unicodedata
|
| 15 |
+
|
| 16 |
+
from unidecode import unidecode
|
| 17 |
+
|
| 18 |
+
import omnivoice.eval.wer.norm_config_module as norm_config_module
|
| 19 |
+
|
| 20 |
+
norm_config = norm_config_module.norm_config # type: ignore
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def text_normalize(
    text, iso_code, lower_case=True, remove_numbers=True, remove_brackets=False
):
    """Given a text, normalize it by changing to lower case, removing punctuations, removing words that only contain digits and removing extra spaces

    Args:
        text : The string to be normalized
        iso_code : ISO language code selecting a language-specific norm config
        lower_case : allow lower-casing (applied only if the config enables it)
        remove_numbers : Boolean flag to specify if words containing only digits should be removed
        remove_brackets : also remove every parenthesized span, not just ones
            containing digits

    Returns:
        normalized_text : the string after all normalization
    """
    # Merge the language-specific config on top of the "*" defaults into a
    # fresh dict. The previous implementation filled missing fields in place,
    # which mutated the shared module-level ``norm_config`` entries across
    # calls, and it never filled "rm_diacritics", so a language config lacking
    # that key raised KeyError. The merged copy fixes both.
    config = {**norm_config["*"], **norm_config.get(iso_code, {})}

    text = unicodedata.normalize(config["unicode_norm"], text)

    # Convert to lower case
    if config["lower_case"] and lower_case:
        text = text.lower()

    # brackets

    # always remove text inside brackets with numbers in them.
    # Usually corresponds to "(Sam 23:17)"
    text = re.sub(r"\([^\)]*\d[^\)]*\)", " ", text)
    if remove_brackets:
        text = re.sub(r"\([^\)]*\)", " ", text)

    # Apply mappings
    for old, new in config["mapping"].items():
        text = re.sub(old, new, text)

    # Replace punctuations with space
    punct_pattern = r"[" + config["punc_set"] + "]"
    normalized_text = re.sub(punct_pattern, " ", text)

    # Remove characters in the delete list
    delete_pattern = r"[" + config["del_set"] + "]"
    normalized_text = re.sub(delete_pattern, "", normalized_text)

    # Remove words containing only digits.
    # We check for 3 cases: a) text starts with a number, b) a number is
    # present somewhere in the middle of the text, c) the text ends with a
    # number. For each case a lookaround pattern verifies the digit run is
    # preceded and followed by whitespace before replacing it with a space;
    # lookarounds allow overlapping matches to be replaced.
    if remove_numbers:
        digits_pattern = "[" + config["digit_set"] + "]+"
        complete_digit_pattern = (
            r"^"
            + digits_pattern
            + r"(?=\s)|(?<=\s)"
            + digits_pattern
            + r"(?=\s)|(?<=\s)"
            + digits_pattern
            + "$"
        )
        normalized_text = re.sub(complete_digit_pattern, " ", normalized_text)

    if config["rm_diacritics"]:
        normalized_text = unidecode(normalized_text)

    # Collapse runs of whitespace and trim
    normalized_text = re.sub(r"\s+", " ", normalized_text).strip()

    return normalized_text
|
omnivoice/models/__init__.py
ADDED
|
File without changes
|
omnivoice/models/omnivoice.py
ADDED
|
@@ -0,0 +1,1502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""Core OmniVoice model implementation.
|
| 19 |
+
|
| 20 |
+
Defines the ``OmniVoice`` model class, generation config, and inference pipeline.
|
| 21 |
+
This is the main entry point for both inference and training:
|
| 22 |
+
|
| 23 |
+
- **Inference**: ``OmniVoice.from_pretrained()`` loads the model, then
|
| 24 |
+
``model.generate()`` supports voice cloning, voice design, and auto voice.
|
| 25 |
+
- **Training**: ``model.forward()`` computes the training loss; the model is
|
| 26 |
+
built and used by ``omnivoice.training.builder`` and ``omnivoice.training.trainer``.
|
| 27 |
+
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
import difflib
|
| 31 |
+
import logging
|
| 32 |
+
import math
|
| 33 |
+
import os
|
| 34 |
+
import re
|
| 35 |
+
from dataclasses import dataclass, fields
|
| 36 |
+
from functools import partial
|
| 37 |
+
from typing import Any, List, Optional, Union
|
| 38 |
+
|
| 39 |
+
import torch
|
| 40 |
+
import torch.nn as nn
|
| 41 |
+
import torch.nn.functional as F
|
| 42 |
+
import torchaudio
|
| 43 |
+
from torch.nn.attention.flex_attention import create_block_mask
|
| 44 |
+
from transformers import (
|
| 45 |
+
AutoFeatureExtractor,
|
| 46 |
+
AutoModel,
|
| 47 |
+
AutoTokenizer,
|
| 48 |
+
HiggsAudioV2TokenizerModel,
|
| 49 |
+
PretrainedConfig,
|
| 50 |
+
PreTrainedModel,
|
| 51 |
+
)
|
| 52 |
+
from transformers.modeling_outputs import ModelOutput
|
| 53 |
+
from transformers.models.auto import CONFIG_MAPPING, AutoConfig
|
| 54 |
+
|
| 55 |
+
from omnivoice.utils.audio import (
|
| 56 |
+
cross_fade_chunks,
|
| 57 |
+
fade_and_pad_audio,
|
| 58 |
+
load_audio,
|
| 59 |
+
remove_silence,
|
| 60 |
+
trim_long_audio,
|
| 61 |
+
)
|
| 62 |
+
from omnivoice.utils.duration import RuleDurationEstimator
|
| 63 |
+
from omnivoice.utils.lang_map import LANG_IDS, LANG_NAMES
|
| 64 |
+
from omnivoice.utils.text import add_punctuation, chunk_text_punctuation
|
| 65 |
+
from omnivoice.utils.voice_design import (
|
| 66 |
+
_INSTRUCT_ALL_VALID,
|
| 67 |
+
_INSTRUCT_EN_TO_ZH,
|
| 68 |
+
_INSTRUCT_MUTUALLY_EXCLUSIVE,
|
| 69 |
+
_INSTRUCT_VALID_EN,
|
| 70 |
+
_INSTRUCT_VALID_ZH,
|
| 71 |
+
_INSTRUCT_ZH_TO_EN,
|
| 72 |
+
_ZH_RE,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
logger = logging.getLogger(__name__)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# ---------------------------------------------------------------------------
|
| 79 |
+
# Dataclasses
|
| 80 |
+
# ---------------------------------------------------------------------------
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@dataclass
class VoiceClonePrompt:
    """Reference prompt for voice cloning: tokenized reference audio plus its
    transcript and RMS level."""

    # Tokenized reference audio; shape (C, T) — C presumably codebook layers,
    # T frames (confirm against the audio tokenizer output).
    ref_audio_tokens: torch.Tensor  # (C, T)
    # Transcript of the reference audio.
    ref_text: str
    # RMS level of the reference waveform (presumably used to match output
    # loudness — confirm in the generation pipeline).
    ref_rms: float
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@dataclass
class OmniVoiceGenerationConfig:
    """Decoding hyper-parameters for OmniVoice generation.

    Instances are usually built via :meth:`from_dict`, which tolerates extra
    keys so a larger kwargs dict can be passed through unchanged.
    """

    num_step: int = 32
    guidance_scale: float = 2.0
    t_shift: float = 0.1
    layer_penalty_factor: float = 5.0
    position_temperature: float = 5.0
    class_temperature: float = 0.0
    denoise: bool = True
    preprocess_prompt: bool = True
    postprocess_output: bool = True
    audio_chunk_duration: float = 15.0
    audio_chunk_threshold: float = 30.0

    @classmethod
    def from_dict(cls, kwargs_dict):
        """Construct a config from ``kwargs_dict``, ignoring unknown keys."""
        known = {f.name for f in fields(cls)}
        kept = {key: val for key, val in kwargs_dict.items() if key in known}
        return cls(**kept)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
@dataclass
class GenerationTask:
    """A batch of TTS generation requests, with helpers to partition the
    batch by target length and to slice out a sub-batch."""

    batch_size: int
    texts: List[str]
    target_lens: List[int]
    langs: List[Optional[str]]
    instructs: List[Optional[str]]
    ref_texts: List[Optional[str]]
    ref_audio_tokens: List[Optional[torch.Tensor]]
    ref_rms: List[Optional[float]]
    speed: Optional[List[float]] = None

    def get_indices(self, config: OmniVoiceGenerationConfig, frame_rate: int):
        """Split item indices into (short, long) around the chunking limit.

        The limit is ``config.audio_chunk_threshold`` (presumably seconds)
        converted to frames via ``frame_rate``.
        """
        limit = int(config.audio_chunk_threshold * frame_rate)
        short, long_ = [], []
        for idx, length in enumerate(self.target_lens):
            (short if length <= limit else long_).append(idx)
        return short, long_

    def slice_task(self, indices: List[int]):
        """Return a sub-task containing only ``indices``; None when empty."""
        if not indices:
            return None

        def pick(values):
            # Gather the per-item fields in the requested order.
            return [values[i] for i in indices]

        return GenerationTask(
            batch_size=len(indices),
            texts=pick(self.texts),
            target_lens=pick(self.target_lens),
            langs=pick(self.langs),
            instructs=pick(self.instructs),
            ref_texts=pick(self.ref_texts),
            ref_audio_tokens=pick(self.ref_audio_tokens),
            ref_rms=pick(self.ref_rms),
            speed=pick(self.speed) if self.speed else None,
        )
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
@dataclass
class OmniVoiceModelOutput(ModelOutput):
    """Output container for ``OmniVoice.forward`` (transformers ModelOutput)."""

    # Training loss; presumably None at pure inference — confirm in forward().
    loss: Optional[torch.Tensor] = None
    # Raw prediction logits; presumably None when not computed — confirm in forward().
    logits: Optional[torch.Tensor] = None
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# ---------------------------------------------------------------------------
|
| 152 |
+
# Config & Model
|
| 153 |
+
# ---------------------------------------------------------------------------
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class OmniVoiceConfig(PretrainedConfig):
    """HuggingFace-style configuration for OmniVoice.

    Holds a nested LLM backbone config (``llm_config``) plus audio-codebook
    settings (vocab size, mask id, number of codebooks, per-codebook weights).
    """

    model_type = "omnivoice"
    # Declare llm_config as a nested sub-config so transformers
    # (de)serializes it through AutoConfig.
    sub_configs = {"llm_config": AutoConfig}

    def __init__(
        self,
        audio_vocab_size: int = 1025,
        audio_mask_id: int = 1024,
        num_audio_codebook: int = 8,
        audio_codebook_weights: Optional[list[float]] = None,
        llm_config: Optional[Union[dict, PretrainedConfig]] = None,
        **kwargs,
    ):
        """Create the config.

        Args:
            audio_vocab_size: Token vocabulary size per audio codebook.
            audio_mask_id: Id of the mask token inside the audio vocabulary.
            num_audio_codebook: Number of audio codebook layers.
            audio_codebook_weights: Per-codebook weights; defaults to
                ``[8, 8, 6, 6, 4, 4, 2, 2]`` when None (normalized later by
                the model — presumably loss weighting; confirm in OmniVoice).
            llm_config: Nested backbone config, either a PretrainedConfig or
                a plain dict (e.g. loaded from config.json).
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.
        """
        # Re-hydrate a dict (e.g. from config.json) into the concrete
        # transformers config class keyed by its "model_type".
        if isinstance(llm_config, dict):
            llm_config = CONFIG_MAPPING[llm_config["model_type"]](**llm_config)

        # Set before super().__init__ so PretrainedConfig machinery sees the
        # sub-config during base-class initialization.
        self.llm_config = llm_config

        super().__init__(**kwargs)
        self.audio_vocab_size = audio_vocab_size
        self.audio_mask_id = audio_mask_id
        self.num_audio_codebook = num_audio_codebook
        if audio_codebook_weights is None:
            audio_codebook_weights = [8, 8, 6, 6, 4, 4, 2, 2]
        self.audio_codebook_weights = audio_codebook_weights
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class OmniVoice(PreTrainedModel):
|
| 185 |
+
_supports_flex_attn = True
|
| 186 |
+
_supports_flash_attn_2 = True
|
| 187 |
+
config_class = OmniVoiceConfig
|
| 188 |
+
|
| 189 |
+
    def __init__(self, config: OmniVoiceConfig, llm: Optional[PreTrainedModel] = None):
        """Build the model: LLM backbone plus audio embedding table and heads.

        Args:
            config: OmniVoice configuration (including nested ``llm_config``).
            llm: Optional pre-built backbone; when provided, config-based
                construction of the LLM is skipped.
        """
        super().__init__(config)

        if llm is not None:
            # If an LLM instance is provided, use it directly
            # (skipping config-based init).
            self.llm = llm
        else:
            # Otherwise, initialize the LLM from the config.
            self.llm = AutoModel.from_config(self.config.llm_config)

        # One shared embedding table covering all codebooks: layer k occupies
        # ids [k * audio_vocab_size, (k + 1) * audio_vocab_size).
        self.audio_embeddings = nn.Embedding(
            config.num_audio_codebook * config.audio_vocab_size,
            self.config.llm_config.hidden_size,
        )
        # Per-layer id offsets into the shared table (buffer: moves with the
        # model's device, excluded from gradients).
        self.register_buffer(
            "codebook_layer_offsets",
            torch.arange(config.num_audio_codebook) * config.audio_vocab_size,
        )

        # Single linear head producing logits for every codebook at once.
        self.audio_heads = nn.Linear(
            self.config.llm_config.hidden_size,
            config.num_audio_codebook * config.audio_vocab_size,
            bias=False,
        )

        # Codebook weights normalized to sum to 1.
        self.normalized_audio_codebook_weights = [
            w / sum(config.audio_codebook_weights)
            for w in config.audio_codebook_weights
        ]

        self.post_init()

        # Inference-only attributes (set by from_pretrained when not in train mode)
        self.text_tokenizer = None
        self.audio_tokenizer = None
        self.duration_estimator = None
        self.sampling_rate = None
        self._asr_pipe = None
|
| 228 |
+
|
| 229 |
+
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load OmniVoice plus (unless ``train=True``) its inference helpers.

        Extra kwargs consumed here (not forwarded to the base loader):
            train: When True, load only the core weights (tokenizers and
                helpers are left None for the training pipeline).
            load_asr: When True, also load a Whisper ASR pipeline.
            asr_model_name: HF model name used when ``load_asr`` is set.
        """
        train_mode = kwargs.pop("train", False)
        load_asr = kwargs.pop("load_asr", False)
        asr_model_name = kwargs.pop("asr_model_name", "openai/whisper-large-v3-turbo")

        # Suppress noisy INFO logs from transformers/huggingface_hub during loading
        _prev_disable = logging.root.manager.disable
        logging.disable(logging.INFO)

        try:
            model = super().from_pretrained(
                pretrained_model_name_or_path, *args, **kwargs
            )

            if not train_mode:
                # Resolve local path for audio tokenizer subdirectory
                if os.path.isdir(pretrained_model_name_or_path):
                    resolved_path = pretrained_model_name_or_path
                else:
                    # Hub id: download (or reuse cache) to get a local dir.
                    from huggingface_hub import snapshot_download

                    resolved_path = snapshot_download(pretrained_model_name_or_path)

                model.text_tokenizer = AutoTokenizer.from_pretrained(
                    pretrained_model_name_or_path
                )

                audio_tokenizer_path = os.path.join(resolved_path, "audio_tokenizer")

                if not os.path.isdir(audio_tokenizer_path):
                    # Fallback to the HuggingFace Hub path of transformers'
                    # HiggsAudioV2Tokenizer if the local subdirectory doesn't exist.
                    audio_tokenizer_path = "eustlb/higgs-audio-v2-tokenizer"

                # higgs-audio-v2-tokenizer does not support MPS (output channels > 65536)
                tokenizer_device = (
                    "cpu" if str(model.device).startswith("mps") else model.device
                )
                model.audio_tokenizer = HiggsAudioV2TokenizerModel.from_pretrained(
                    audio_tokenizer_path, device_map=tokenizer_device
                )
                model.feature_extractor = AutoFeatureExtractor.from_pretrained(
                    audio_tokenizer_path
                )

                model.sampling_rate = model.feature_extractor.sampling_rate

                model.duration_estimator = RuleDurationEstimator()

                if load_asr:
                    model.load_asr_model(model_name=asr_model_name)
        finally:
            # Always restore the previous logging-disable level.
            logging.disable(_prev_disable)

        return model
|
| 285 |
+
|
| 286 |
+
# -------------------------------------------------------------------
|
| 287 |
+
# ASR support (optional, for auto-transcription)
|
| 288 |
+
# -------------------------------------------------------------------
|
| 289 |
+
|
| 290 |
+
def load_asr_model(self, model_name: str = "openai/whisper-large-v3-turbo"):
|
| 291 |
+
"""Load a Whisper ASR model for reference audio transcription.
|
| 292 |
+
|
| 293 |
+
Args:
|
| 294 |
+
model_name: HuggingFace model name for the Whisper model.
|
| 295 |
+
"""
|
| 296 |
+
from transformers import pipeline as hf_pipeline
|
| 297 |
+
|
| 298 |
+
logger.info("Loading ASR model %s ...", model_name)
|
| 299 |
+
asr_dtype = (
|
| 300 |
+
torch.float16 if str(self.device).startswith("cuda") else torch.float32
|
| 301 |
+
)
|
| 302 |
+
self._asr_pipe = hf_pipeline(
|
| 303 |
+
"automatic-speech-recognition",
|
| 304 |
+
model=model_name,
|
| 305 |
+
dtype=asr_dtype,
|
| 306 |
+
device_map=self.device,
|
| 307 |
+
)
|
| 308 |
+
logger.info("ASR model loaded on %s.", self.device)
|
| 309 |
+
|
| 310 |
+
@torch.inference_mode()
|
| 311 |
+
def transcribe(
|
| 312 |
+
self,
|
| 313 |
+
audio: Union[str, tuple[torch.Tensor, int]],
|
| 314 |
+
) -> str:
|
| 315 |
+
"""Transcribe audio using the loaded Whisper ASR model.
|
| 316 |
+
|
| 317 |
+
Args:
|
| 318 |
+
audio: File path or (waveform, sample_rate) tuple.
|
| 319 |
+
|
| 320 |
+
Returns:
|
| 321 |
+
Transcribed text.
|
| 322 |
+
"""
|
| 323 |
+
if self._asr_pipe is None:
|
| 324 |
+
raise RuntimeError(
|
| 325 |
+
"ASR model is not loaded. Call model.load_asr_model() first."
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
if isinstance(audio, str):
|
| 329 |
+
return self._asr_pipe(audio)["text"].strip()
|
| 330 |
+
else:
|
| 331 |
+
waveform, sr = audio
|
| 332 |
+
if waveform.dim() == 1:
|
| 333 |
+
waveform = waveform.unsqueeze(0)
|
| 334 |
+
if waveform.size(0) > 1:
|
| 335 |
+
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
| 336 |
+
audio_input = {
|
| 337 |
+
"array": waveform.squeeze(0).cpu().numpy(),
|
| 338 |
+
"sampling_rate": sr,
|
| 339 |
+
}
|
| 340 |
+
return self._asr_pipe(audio_input)["text"].strip()
|
| 341 |
+
|
| 342 |
+
    def get_input_embeddings(self):
        """Return the text-token embedding module of the wrapped LLM."""
        return self.llm.get_input_embeddings()
|
| 344 |
+
|
| 345 |
+
    def set_input_embeddings(self, value):
        """Replace the text-token embedding module of the wrapped LLM."""
        self.llm.set_input_embeddings(value)
|
| 347 |
+
|
| 348 |
+
    def _prepare_embed_inputs(
        self, input_ids: torch.Tensor, audio_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Prepares embeddings from input_ids of shape (batch_size, layers, seq_length).
        Embedding shape is (batch_size, seq_length, hidden_size).

        Text positions are embedded from layer 0 of ``input_ids`` via the LLM's
        text embedding table; audio positions are embedded by summing one
        embedding per codebook layer, after offsetting each layer's IDs into
        its own slice of the shared audio embedding table.
        """
        # Layer 0 carries the text token IDs (all layers repeat the same text
        # IDs at text positions); embed it once with the LLM's table.
        text_embeds = self.get_input_embeddings()(input_ids[:, 0, :])

        # Apply shift to audio IDs based on codebook layer.
        # input_ids:              [Batch, C, Seq]
        # codebook_layer_offsets: viewed as [1, C, 1]
        # Non-audio positions are zeroed by the mask first so they map to each
        # layer's offset base (a harmless index) before being discarded below.
        shifted_ids = (
            input_ids * audio_mask.unsqueeze(1)
        ) + self.codebook_layer_offsets.view(1, -1, 1)

        # Sum the per-layer embeddings: [Batch, C, Seq, H] -> [Batch, Seq, H]
        audio_embeds = self.audio_embeddings(shifted_ids).sum(dim=1)

        # Select audio embeddings at audio positions, text embeddings elsewhere.
        return torch.where(audio_mask.unsqueeze(-1), audio_embeds, text_embeds)
|
| 369 |
+
|
| 370 |
+
    def forward(
        self,
        input_ids: torch.LongTensor,
        audio_mask: torch.Tensor,
        labels: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        document_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        """Run the backbone LLM over mixed text/audio inputs and project to
        per-codebook audio logits; optionally compute the training loss.

        Args:
            input_ids: (batch, num_codebooks, seq_len) token IDs; layer 0 also
                carries the text token IDs (see :meth:`_prepare_embed_inputs`).
            audio_mask: (batch, seq_len) bool mask; True marks audio positions.
            labels: Optional (batch, num_codebooks, seq_len) targets; positions
                set to -100 are excluded from the loss.
            attention_mask: Optional precomputed attention mask.
            document_ids: Optional per-position document IDs for packed
                sequences; used to build a block mask when attention_mask is
                None.
            position_ids: Optional explicit position IDs passed to the LLM.

        Returns:
            OmniVoiceModelOutput with ``loss`` (None when ``labels`` is None)
            and ``logits`` of shape (batch, num_codebooks, seq_len, vocab).
        """

        inputs_embeds = self._prepare_embed_inputs(input_ids, audio_mask)

        if attention_mask is None and document_ids is not None:
            # Packed-sequence training: build a FlexAttention block mask that
            # prevents attention across document boundaries.
            # NOTE(review): only document_ids[0] is used — assumes the packing
            # layout is identical for every item in the batch; confirm.
            attention_mask = create_block_mask(
                _get_packed_mask(
                    document_ids[0].to(inputs_embeds.device),
                ),
                B=None,
                H=None,
                Q_LEN=input_ids.size(-1),
                KV_LEN=input_ids.size(-1),
                _compile=True,
                device=inputs_embeds.device,
            )

        llm_outputs = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            position_ids=position_ids,
        )
        hidden_states = llm_outputs[0]

        loss = None

        # Shape: [B, S, C * Vocab]
        batch_size, seq_len, _ = hidden_states.shape
        logits_flat = self.audio_heads(hidden_states)
        # Shape: [B, S, C, Vocab] -> [B, C, S, Vocab]
        audio_logits = logits_flat.view(
            batch_size,
            seq_len,
            self.config.num_audio_codebook,
            self.config.audio_vocab_size,
        ).permute(0, 2, 1, 3)

        if labels is not None:

            # audio_logits.permute(0, 3, 1, 2):
            # [Batch, Layer, Seq, Vocab] -> [Batch, Vocab, Layer, Seq]
            # (cross_entropy expects the class dim second).
            # per_token_loss shape: [Batch, Layer, Seq]; -100 targets yield 0.
            per_token_loss = torch.nn.functional.cross_entropy(
                audio_logits.permute(0, 3, 1, 2),
                labels,
                reduction="none",
                ignore_index=-100,
            )
            # valid_mask shape: [Batch, Layer, Seq]
            valid_mask = (labels != -100).float()

            # Mean loss per codebook layer; clamp guards against division by
            # zero for a layer that has no valid targets in this batch.
            # layer_means shape: [num_layers]
            layer_means = (per_token_loss * valid_mask).sum(
                dim=(0, 2)
            ) / valid_mask.sum(dim=(0, 2)).clamp(min=1.0)

            # Weighted sum over layers using the configured codebook weights.
            weights = torch.tensor(
                self.normalized_audio_codebook_weights, device=audio_logits.device
            )
            loss = (layer_means * weights).sum()

        return OmniVoiceModelOutput(
            loss=loss,
            logits=audio_logits,
        )
|
| 444 |
+
|
| 445 |
+
    def supported_language_ids(self) -> set[str]:
        """Return the supported language IDs.

        Note: annotated as ``set[str]``; ``LANG_IDS`` is a module-level
        constant — presumably a set of short codes like ``"en"``; verify.
        """
        return LANG_IDS
|
| 448 |
+
|
| 449 |
+
    def supported_language_names(self) -> set[str]:
        """Return the supported language names.

        Note: annotated as ``set[str]``; ``LANG_NAMES`` is a module-level
        constant — presumably full names like ``"English"``; verify.
        """
        return LANG_NAMES
|
| 452 |
+
|
| 453 |
+
# -------------------------------------------------------------------
|
| 454 |
+
# Inference API
|
| 455 |
+
# -------------------------------------------------------------------
|
| 456 |
+
|
| 457 |
+
    @torch.inference_mode()
    def generate(
        self,
        text: Union[str, list[str]],
        language: Union[str, list[str], None] = None,
        ref_text: Union[str, list[str], None] = None,
        ref_audio: Union[
            str,
            list[str],
            tuple[torch.Tensor, int],
            list[tuple[torch.Tensor, int]],
            None,
        ] = None,
        voice_clone_prompt: Union[
            VoiceClonePrompt, list[VoiceClonePrompt], None
        ] = None,
        instruct: Union[str, list[str], None] = None,
        duration: Union[float, list[Optional[float]], None] = None,
        speed: Union[float, list[Optional[float]], None] = None,
        generation_config: Optional[OmniVoiceGenerationConfig] = None,
        **kwargs,
    ) -> list[torch.Tensor]:
        """Generate speech audio given text in various modes.

        Supports three modes:

        1. **Voice clone** — clone the voice style from the reference audio.
           Should provide ``voice_clone_prompt`` (from
           :meth:`create_voice_clone_prompt`) or ``ref_text`` + ``ref_audio``.
        2. **Voice design** — provide ``instruct`` text describing
           the desired voice style; no reference audio needed.
        3. **Auto** — provide neither; the model picks a voice itself.

        Args:
            text: Target text (single string or list for batch).
            language: Language name (e.g. ``"English"``) or code
                (e.g. ``"en"``). ``None`` for language-agnostic mode.
                Performance is slightly better if you specify the language.
            ref_text: Optional reference text for voice cloning mode.
            ref_audio: Optional reference audio for voice cloning mode.
                Can be a file path or a (waveform, sample_rate) tuple.
            voice_clone_prompt: Reusable prompt from :meth:`create_voice_clone_prompt`.
                If provided, it overrides ``ref_text`` and ``ref_audio``.
            instruct: Style instruction for voice design mode.
            duration: Fixed output duration in seconds. If a single float,
                applies to all items; if a list, one value per item.
                ``None`` (default) lets the model estimate duration from text.
                Overrides ``speed`` when both are provided.
            speed: Speaking speed factor. ``> 1.0`` for faster, ``< 1.0`` for
                slower. If a list, one value per item. ``None`` (default) uses
                the model's default estimation.
            generation_config: Explicit config object. If provided, takes
                precedence over ``**kwargs``.
            **kwargs: Generation config or its fields:
                denoise: Whether to prepend the ``<|denoise|>`` token.
                num_step: Number of iterative decoding steps.
                guidance_scale: Classifier-free guidance scale.
                t_shift: Time-step shift (smaller → emphasise low-SNR).
                postprocess_output: Post-process output (remove silence, fade-in/out, pad edges).
                layer_penalty_factor: Penalty encouraging earlier codebook
                    layers to unmask first.
                position_temperature: Temperature for position selection.
                class_temperature: Temperature for token sampling (0 = greedy).
                audio_chunk_duration: If > 0, split long text into chunks of
                    this duration (seconds) and generate chunk by chunk.
                audio_chunk_threshold: Only apply chunking if estimated audio
                    duration exceeds this threshold (seconds).
        Returns:
            ``audios`` a list of 2-D ``torch.Tensor``, with the shape (1, T) and sampling rate
            consistent with the model's audio tokenizer (usually 24000 Hz).
        """

        if self.audio_tokenizer is None or self.text_tokenizer is None:
            raise RuntimeError(
                "Model is not loaded with audio/text tokenizers. Make sure you "
                "loaded the model with OmniVoice.from_pretrained()."
            )
        # An explicit config object wins over loose keyword arguments.
        gen_config = (
            generation_config
            if generation_config is not None
            else OmniVoiceGenerationConfig.from_dict(kwargs)
        )

        self.eval()

        # Normalise all argument shapes into one batched GenerationTask.
        full_task = self._preprocess_all(
            text=text,
            language=language,
            ref_text=ref_text,
            ref_audio=ref_audio,
            voice_clone_prompt=voice_clone_prompt,
            instruct=instruct,
            preprocess_prompt=gen_config.preprocess_prompt,
            speed=speed,
            duration=duration,
        )

        # Partition the batch: short items decode in one pass, long items are
        # chunked (see _generate_chunked).
        short_idx, long_idx = full_task.get_indices(
            gen_config, self.audio_tokenizer.config.frame_rate
        )

        results = [None] * full_task.batch_size

        if short_idx:
            short_task = full_task.slice_task(short_idx)
            short_results = self._generate_iterative(short_task, gen_config)
            for idx, res in zip(short_idx, short_results):
                results[idx] = res

        if long_idx:
            long_task = full_task.slice_task(long_idx)
            long_results = self._generate_chunked(long_task, gen_config)
            for idx, res in zip(long_idx, long_results):
                results[idx] = res

        # Decode tokens to waveforms in the original input order.
        generated_audios = []
        for i in range(full_task.batch_size):
            assert results[i] is not None, f"Result {i} was not generated"
            generated_audios.append(
                self._decode_and_post_process(
                    results[i], full_task.ref_rms[i], gen_config  # type: ignore[arg-type]
                )
            )

        return generated_audios
|
| 582 |
+
|
| 583 |
+
    def create_voice_clone_prompt(
        self,
        ref_audio: Union[str, tuple[torch.Tensor, int]],
        ref_text: Optional[str] = None,
        preprocess_prompt: bool = True,
    ) -> VoiceClonePrompt:
        """Create a reusable voice clone prompt from reference audio.

        Args:
            ref_audio: File path (str) or ``(waveform, sample_rate)`` tuple.
                waveform should be a 1-D or 2-D torch.Tensor (channels x samples).
            ref_text: Transcript of the reference audio. If ``None``, the
                ASR model will be used to auto-transcribe (must call
                :meth:`load_asr_model` first).
            preprocess_prompt: If ``True`` (default), apply silence removal and
                trimming to the reference audio, add punctuation in the end
                of reference text (if not already)

        Returns:
            A :class:`VoiceClonePrompt` that can be passed to :meth:`generate`.

        Raises:
            RuntimeError: If the audio tokenizer is not loaded.
            ValueError: If the reference audio is empty after silence removal.
        """
        if self.audio_tokenizer is None:
            raise RuntimeError(
                "Audio tokenizer is not loaded. Make sure you loaded the model "
                "with OmniVoice.from_pretrained()."
            )

        # Load / normalise the reference waveform to mono at the model rate.
        if isinstance(ref_audio, str):
            ref_wav = load_audio(ref_audio, self.sampling_rate)
        else:
            waveform, sr = ref_audio
            if waveform.dim() == 1:
                waveform = waveform.unsqueeze(0)
            if waveform.size(0) > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)
            if sr != self.sampling_rate:
                waveform = torchaudio.functional.resample(
                    waveform, sr, self.sampling_rate
                )
            ref_wav = waveform

        # Boost very quiet references up to RMS 0.1 before encoding; the
        # original loudness (ref_rms) is kept so output can be scaled back.
        ref_rms = torch.sqrt(torch.mean(torch.square(ref_wav))).item()
        if 0 < ref_rms < 0.1:
            ref_wav = ref_wav * 0.1 / ref_rms

        if preprocess_prompt:
            # Trim long reference audio (>20s) by splitting at the largest silence gap.
            # Skip trimming when ref_text is user-provided, otherwise the
            # trimmed audio will no longer match the full transcript.
            if ref_text is None:
                ref_wav = trim_long_audio(ref_wav, self.sampling_rate)
            elif ref_wav.size(-1) / self.sampling_rate > 20.0:
                logger.warning(
                    "Reference audio is %.1fs long (>20s) and ref_text was "
                    "provided, so automatic trimming is skipped. A long reference "
                    "may cause slower generation and degraded quality.",
                    ref_wav.size(-1) / self.sampling_rate,
                )

            ref_wav = remove_silence(
                ref_wav,
                self.sampling_rate,
                mid_sil=200,
                lead_sil=100,
                trail_sil=200,
            )
            if ref_wav.size(-1) == 0:
                raise ValueError(
                    "Reference audio is empty after silence removal. "
                    "Try setting preprocess_prompt=False."
                )

        # Auto-transcribe if ref_text not provided
        if ref_text is None:
            if self._asr_pipe is None:
                logger.info("ASR model not loaded yet, loading on-the-fly ...")
                self.load_asr_model()
            ref_text = self.transcribe((ref_wav, self.sampling_rate))
            logger.debug("Auto-transcribed ref_text: %s", ref_text)

        # Clip the waveform to a whole number of codec frames so encode()
        # sees an exact multiple of hop_length samples.
        chunk_size = self.audio_tokenizer.config.hop_length
        clip_size = int(ref_wav.size(-1) % chunk_size)
        ref_wav = ref_wav[:, :-clip_size] if clip_size > 0 else ref_wav
        ref_audio_tokens = self.audio_tokenizer.encode(
            ref_wav.unsqueeze(0).to(self.audio_tokenizer.device),
        ).audio_codes.squeeze(
            0
        )  # (C, T)

        if preprocess_prompt:
            ref_text = add_punctuation(ref_text)

        return VoiceClonePrompt(
            ref_audio_tokens=ref_audio_tokens,
            ref_text=ref_text,
            ref_rms=ref_rms,
        )
|
| 680 |
+
|
| 681 |
+
def _decode_and_post_process(
|
| 682 |
+
self,
|
| 683 |
+
tokens: Union[torch.Tensor, List[torch.Tensor]],
|
| 684 |
+
rms: Union[float, None],
|
| 685 |
+
gen_config: OmniVoiceGenerationConfig,
|
| 686 |
+
) -> torch.Tensor:
|
| 687 |
+
"""
|
| 688 |
+
Args:
|
| 689 |
+
tokens: Audio tokens — either a single tensor of shape
|
| 690 |
+
(num_codebooks, seq_len) or a list of chunk tensors.
|
| 691 |
+
rms: RMS of the reference audio for volume adjustment.
|
| 692 |
+
gen_config: Generation config for post-processing options.
|
| 693 |
+
Returns:
|
| 694 |
+
Decoded and post-processed audio tensor of shape (1, T).
|
| 695 |
+
"""
|
| 696 |
+
tokenizer_device = self.audio_tokenizer.device
|
| 697 |
+
if isinstance(tokens, list):
|
| 698 |
+
chunk_audios = [
|
| 699 |
+
self.audio_tokenizer.decode(t.to(tokenizer_device).unsqueeze(0))
|
| 700 |
+
.audio_values[0]
|
| 701 |
+
.cpu()
|
| 702 |
+
for t in tokens
|
| 703 |
+
]
|
| 704 |
+
audio_waveform = cross_fade_chunks(chunk_audios, self.sampling_rate)
|
| 705 |
+
else:
|
| 706 |
+
audio_waveform = (
|
| 707 |
+
self.audio_tokenizer.decode(tokens.to(tokenizer_device).unsqueeze(0))
|
| 708 |
+
.audio_values[0]
|
| 709 |
+
.cpu()
|
| 710 |
+
)
|
| 711 |
+
|
| 712 |
+
return self._post_process_audio(
|
| 713 |
+
audio_waveform,
|
| 714 |
+
postprocess_output=gen_config.postprocess_output,
|
| 715 |
+
ref_rms=rms,
|
| 716 |
+
)
|
| 717 |
+
|
| 718 |
+
def _post_process_audio(
|
| 719 |
+
self,
|
| 720 |
+
generated_audio: torch.Tensor,
|
| 721 |
+
postprocess_output: bool,
|
| 722 |
+
ref_rms: Union[float, None],
|
| 723 |
+
) -> torch.Tensor:
|
| 724 |
+
"""Optionally remove long silences, adjust volume, and add edge padding.
|
| 725 |
+
|
| 726 |
+
Args:
|
| 727 |
+
generated_audio: Audio tensor of shape (1, T).
|
| 728 |
+
postprocess_output: If True, remove long silences and apply fade/pad.
|
| 729 |
+
ref_rms: RMS of the reference audio for volume normalisation.
|
| 730 |
+
Returns:
|
| 731 |
+
Processed audio tensor of shape (1, T).
|
| 732 |
+
"""
|
| 733 |
+
if postprocess_output:
|
| 734 |
+
generated_audio = remove_silence(
|
| 735 |
+
generated_audio,
|
| 736 |
+
self.sampling_rate,
|
| 737 |
+
mid_sil=500,
|
| 738 |
+
lead_sil=100,
|
| 739 |
+
trail_sil=100,
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
if ref_rms is not None and ref_rms < 0.1:
|
| 743 |
+
generated_audio = generated_audio * ref_rms / 0.1
|
| 744 |
+
elif ref_rms is None:
|
| 745 |
+
# No reference audio (voice design): peak-normalize to 0.5
|
| 746 |
+
# to avoid clipping while keeping a comfortable volume level.
|
| 747 |
+
peak = generated_audio.abs().max()
|
| 748 |
+
if peak > 1e-6:
|
| 749 |
+
generated_audio = generated_audio / peak * 0.5
|
| 750 |
+
|
| 751 |
+
generated_audio = fade_and_pad_audio(
|
| 752 |
+
generated_audio,
|
| 753 |
+
sample_rate=self.sampling_rate,
|
| 754 |
+
)
|
| 755 |
+
return generated_audio
|
| 756 |
+
|
| 757 |
+
    def _generate_chunked(
        self, task: GenerationTask, gen_config: OmniVoiceGenerationConfig
    ) -> List[List[torch.Tensor]]:
        """Generate long audio by splitting text into chunks and batching.

        Each item in the returned list corresponds to one input and contains
        a list of audio token tensors — one per text chunk.

        Args:
            task: A :class:`GenerationTask` with one or more items whose
                estimated audio exceeds ``audio_chunk_threshold``.
            gen_config: Generation config (``audio_chunk_duration`` controls
                chunk size).
        Returns:
            Per-item list of chunk token-tensor lists.
        """
        # Chunk each item's text. The per-character token rate converts the
        # configured audio chunk duration into a character budget per chunk.
        # NOTE(review): len(task.texts[i]) == 0 would raise ZeroDivisionError
        # here — presumably empty text is rejected upstream; confirm.
        all_chunks = []
        for i in range(task.batch_size):
            avg_tokens_per_char = task.target_lens[i] / len(task.texts[i])
            text_chunk_len = int(
                gen_config.audio_chunk_duration
                * self.audio_tokenizer.config.frame_rate
                / avg_tokens_per_char
            )
            chunks = chunk_text_punctuation(
                text=task.texts[i],
                chunk_len=text_chunk_len,
                min_chunk_len=3,
            )
            logger.debug(f"Item {i} chunked into {len(chunks)} pieces: {chunks}")
            all_chunks.append(chunks)

        has_ref = [t is not None for t in task.ref_audio_tokens]
        assert all(has_ref) or not any(has_ref), (
            "Chunked inference requires all items to either have or not have "
            "ref_audio. Mixed ref/non-ref is not supported."
        )

        max_num_chunks = max(len(c) for c in all_chunks)

        # chunk_results[item_idx] = list of generated token tensors per chunk
        chunk_results = [[] for _ in range(task.batch_size)]

        def _run_batch(indices, texts, ref_audios, ref_texts):
            # Generate one chunk for each item in `indices` as a single batch
            # and append the resulting token tensors to chunk_results.
            speed_list = task.speed
            target_lens = [
                self._estimate_target_tokens(
                    texts[j],
                    ref_texts[j],
                    ref_audios[j].size(-1) if ref_audios[j] is not None else None,
                    speed=speed_list[i] if speed_list else 1.0,
                )
                for j, i in enumerate(indices)
            ]
            sub_task = GenerationTask(
                batch_size=len(indices),
                texts=texts,
                target_lens=target_lens,
                langs=[task.langs[i] for i in indices],
                instructs=[task.instructs[i] for i in indices],
                ref_texts=ref_texts,
                ref_audio_tokens=ref_audios,
                ref_rms=[task.ref_rms[i] for i in indices],
                speed=[task.speed[i] for i in indices] if task.speed else None,
            )
            gen_tokens = self._generate_iterative(sub_task, gen_config)
            for j, idx in enumerate(indices):
                chunk_results[idx].append(gen_tokens[j])

        if all(has_ref):
            # All items have reference audio.
            # We still sequentially generate chunks within each item, but we
            # batch across items for the same chunk index. This allows to keep
            # the VRAM usage manageable while still benefiting from batching.
            for ci in range(max_num_chunks):
                indices = [i for i in range(task.batch_size) if ci < len(all_chunks[i])]
                if not indices:
                    continue
                _run_batch(
                    indices,
                    texts=[all_chunks[i][ci] for i in indices],
                    ref_audios=[task.ref_audio_tokens[i] for i in indices],
                    ref_texts=[task.ref_texts[i] for i in indices],
                )
        else:
            # No reference audio — generate chunk 0 for all items first,
            # then use chunk 0 output as reference for all subsequent chunks
            # so the voice stays consistent across chunks.
            indices_0 = [i for i in range(task.batch_size) if len(all_chunks[i]) > 0]
            _run_batch(
                indices_0,
                texts=[all_chunks[i][0] for i in indices_0],
                ref_audios=[None] * len(indices_0),
                ref_texts=[None] * len(indices_0),
            )
            first_chunk_map = {idx: chunk_results[idx][0] for idx in indices_0}

            # Batch all remaining chunks, using chunk 0 as fixed reference
            for ci in range(1, max_num_chunks):
                indices = [i for i in range(task.batch_size) if ci < len(all_chunks[i])]
                if not indices:
                    continue
                _run_batch(
                    indices,
                    texts=[all_chunks[i][ci] for i in indices],
                    ref_audios=[first_chunk_map[i] for i in indices],
                    ref_texts=[all_chunks[i][0] for i in indices],
                )

        return chunk_results
|
| 867 |
+
|
| 868 |
+
    def _preprocess_all(
        self,
        text: Union[str, list[str]],
        language: Union[str, list[str], None] = None,
        ref_text: Union[str, list[str], None] = None,
        ref_audio: Union[
            str,
            list[str],
            tuple[torch.Tensor, int],
            list[tuple[torch.Tensor, int]],
            None,
        ] = None,
        voice_clone_prompt: Union[
            VoiceClonePrompt, list[VoiceClonePrompt], None
        ] = None,
        instruct: Union[str, list[str], None] = None,
        preprocess_prompt: bool = True,
        speed: Union[float, list[Optional[float]], None] = None,
        duration: Union[float, list[Optional[float]], None] = None,
    ) -> GenerationTask:
        """Normalise all generate() arguments into a batched GenerationTask.

        Scalars are broadcast to per-item lists, languages/instructions are
        resolved, voice-clone prompts are built from raw references when
        needed, and target token counts are estimated (or fixed by
        ``duration``) per item.
        """

        if isinstance(text, str):
            text_list = [text]
        else:
            assert isinstance(
                text, list
            ), "text should be a string or a list of strings"
            text_list = text
        batch_size = len(text_list)

        language_list = self._ensure_list(language, batch_size)
        language_list = [_resolve_language(lang) for lang in language_list]
        instruct_list = self._ensure_list(instruct, batch_size)
        for i, s in enumerate(instruct_list):
            if s is None:
                continue
            # Resolve instruction aliases; pick the Chinese variant when the
            # target text contains CJK characters.
            use_zh = bool(text_list[i] and _ZH_RE.search(text_list[i]))
            instruct_list[i] = _resolve_instruct(s, use_zh=use_zh)

        if voice_clone_prompt is not None and (
            ref_text is not None or ref_audio is not None
        ):
            logger.warning(
                "Both voice_clone_prompt and ref_text/ref_audio are provided. "
                "ref_text/ref_audio will be ignored."
            )
        if voice_clone_prompt is None and ref_audio is not None:
            # If voice_clone_prompt is not provided, create it from
            # ref_audio (ref_text will be auto-transcribed if not given).
            # NOTE(review): the loop below iterates len(ref_text_list); with
            # auto_repeat=False a scalar/None ref_text yields a 1-element list
            # even when ref_audio has batch_size entries — presumably the two
            # are expected to have matching lengths; confirm.
            ref_text_list = self._ensure_list(ref_text, batch_size, auto_repeat=False)
            ref_audio_list = self._ensure_list(ref_audio, batch_size, auto_repeat=False)

            voice_clone_prompt = []
            for i in range(len(ref_text_list)):
                voice_clone_prompt.append(
                    self.create_voice_clone_prompt(
                        ref_audio=ref_audio_list[i],
                        ref_text=ref_text_list[i],
                        preprocess_prompt=preprocess_prompt,
                    )
                )

        voice_clone_prompt_list = self._ensure_list(voice_clone_prompt, batch_size)
        if voice_clone_prompt_list[0] is not None:
            ref_text_list = [vc.ref_text for vc in voice_clone_prompt_list]
            ref_audio_tokens_list = [
                vc.ref_audio_tokens for vc in voice_clone_prompt_list
            ]
            ref_rms_list = [vc.ref_rms for vc in voice_clone_prompt_list]
        else:
            ref_text_list = [None] * batch_size
            ref_audio_tokens_list = [None] * batch_size
            ref_rms_list = [None] * batch_size

        # Normalize speed/duration to per-item lists (may contain None).
        if speed is not None:
            if isinstance(speed, (int, float)):
                user_speed = [float(speed)] * batch_size
            else:
                user_speed = list(speed)
        else:
            user_speed = None

        if duration is not None:
            if isinstance(duration, (int, float)):
                durations = [float(duration)] * batch_size
            else:
                durations = list(duration)
        else:
            durations = None

        num_target_tokens_list = []
        for i in range(batch_size):
            # duration[i] overrides speed for estimation: use speed=1.0
            # to get the raw estimate, then override target_lens below.
            has_dur = durations is not None and durations[i] is not None
            item_speed = 1.0 if has_dur else (user_speed[i] if user_speed else 1.0)
            est = self._estimate_target_tokens(
                text_list[i],
                ref_text_list[i],
                ref_audio_tokens_list[i].size(-1)
                if ref_audio_tokens_list[i] is not None
                else None,
                speed=item_speed,
            )
            num_target_tokens_list.append(est)

        # Per-item duration overrides: set target_lens to exact frame count
        # and compute speed ratio so chunked generation scales proportionally.
        speed_list: Optional[List[float]] = None
        if durations is not None:
            frame_rate = self.audio_tokenizer.config.frame_rate
            speed_list = []
            for i in range(batch_size):
                if durations[i] is not None:
                    target_tokens = max(1, int(durations[i] * frame_rate))
                    est = num_target_tokens_list[i]
                    speed_list.append(est / target_tokens if target_tokens > 0 else 1.0)
                    num_target_tokens_list[i] = target_tokens
                else:
                    s = user_speed[i] if user_speed else None
                    speed_list.append(s if s is not None else 1.0)
        elif user_speed is not None:
            speed_list = [s if s is not None else 1.0 for s in user_speed]

        return GenerationTask(
            batch_size=batch_size,
            texts=text_list,
            target_lens=num_target_tokens_list,
            langs=language_list,
            instructs=instruct_list,
            ref_texts=ref_text_list,
            ref_audio_tokens=ref_audio_tokens_list,
            ref_rms=ref_rms_list,
            speed=speed_list,
        )
|
| 1004 |
+
|
| 1005 |
+
def _estimate_target_tokens(self, text, ref_text, num_ref_audio_tokens, speed=1.0):
|
| 1006 |
+
"""Estimate number of target audio tokens."""
|
| 1007 |
+
if num_ref_audio_tokens is None or ref_text is None or len(ref_text) == 0:
|
| 1008 |
+
# Fall back to a simple heuristic
|
| 1009 |
+
ref_text = "Nice to meet you."
|
| 1010 |
+
num_ref_audio_tokens = 25
|
| 1011 |
+
|
| 1012 |
+
est = self.duration_estimator.estimate_duration(
|
| 1013 |
+
text, ref_text, num_ref_audio_tokens
|
| 1014 |
+
)
|
| 1015 |
+
if speed > 0 and speed != 1.0:
|
| 1016 |
+
est = est / speed
|
| 1017 |
+
return max(1, int(est))
|
| 1018 |
+
|
| 1019 |
+
def _ensure_list(
|
| 1020 |
+
self, x: Union[Any, List[Any]], batch_size: int, auto_repeat: bool = True
|
| 1021 |
+
) -> List[Any]:
|
| 1022 |
+
x_list = x if isinstance(x, list) else [x]
|
| 1023 |
+
if len(x_list) not in (
|
| 1024 |
+
1,
|
| 1025 |
+
batch_size,
|
| 1026 |
+
):
|
| 1027 |
+
raise ValueError(
|
| 1028 |
+
f"should be either the number of the text or 1, but got {len(x_list)}"
|
| 1029 |
+
)
|
| 1030 |
+
if auto_repeat and len(x_list) == 1 and batch_size is not None:
|
| 1031 |
+
x_list = x_list * batch_size
|
| 1032 |
+
return x_list
|
| 1033 |
+
|
| 1034 |
+
    def _prepare_inference_inputs(
        self,
        text: str,
        num_target_tokens: int,
        ref_text: Optional[str] = None,
        ref_audio_tokens: Optional[torch.Tensor] = None,
        lang: Optional[str] = None,
        instruct: Optional[str] = None,
        denoise: bool = True,
    ):
        """Prepare input_ids and audio masks for inference.

        The sequence layout is:
        [style tokens][text tokens][optional ref audio tokens][masked target].

        Args:
            text: Target text to generate.
            num_target_tokens: Number of audio tokens to generate.
            ref_text: Optional reference text for voice cloning.
            ref_audio_tokens: Optional reference audio tokens for voice cloning.
                with shape (C, T).
            lang: Optional language ID.
            instruct: Optional style instruction for voice design.
            denoise: Whether to include the <|denoise|> token.

        Returns:
            Dict with "input_ids" of shape (1, C, L) and boolean "audio_mask"
            of shape (1, L) that is True over the ref-audio + target span.
        """

        # Build style tokens: <|denoise|> + <|lang_start|>...<|lang_end|>
        # + <|instruct_start|>...<|instruct_end|>
        # Absent lang/instruct are encoded as the literal string "None".
        style_text = ""
        if denoise:
            style_text += "<|denoise|>"
        lang_str = lang if lang else "None"
        instruct_str = instruct if instruct else "None"
        style_text += f"<|lang_start|>{lang_str}<|lang_end|>"
        style_text += f"<|instruct_start|>{instruct_str}<|instruct_end|>"

        # Text rows are replicated across all codebook layers so the tensor
        # keeps a uniform (1, C, N) shape.
        style_tokens = (
            self.text_tokenizer(style_text, return_tensors="pt")
            .input_ids.repeat(self.config.num_audio_codebook, 1)
            .unsqueeze(0)
        ).to(
            self.device
        )  # [1, C, N1]

        # Build text tokens (reference transcript fused with the target text).
        full_text = _combine_text(ref_text=ref_text, text=text)
        text_tokens = (
            self.text_tokenizer(
                f"<|text_start|>{full_text}<|text_end|>",
                return_tensors="pt",
            )
            .input_ids.repeat(self.config.num_audio_codebook, 1)
            .unsqueeze(0)
        ).to(
            self.device
        )  # [1, C, N2]

        # Target: all MASK — iterative decoding will progressively unmask.
        target_audio_tokens = torch.full(
            (1, self.config.num_audio_codebook, num_target_tokens),
            self.config.audio_mask_id,
            dtype=torch.long,
            device=self.device,
        )

        # Conditional input
        parts = [style_tokens, text_tokens]
        if ref_audio_tokens is not None:
            parts.append(ref_audio_tokens.unsqueeze(0).to(self.device))
        parts.append(target_audio_tokens)
        cond_input_ids = torch.cat(parts, dim=2)

        # Audio positions start where the (optional) reference audio begins.
        cond_total_length = cond_input_ids.shape[2]
        cond_audio_start_idx = cond_total_length - num_target_tokens
        if ref_audio_tokens is not None:
            cond_audio_start_idx -= ref_audio_tokens.size(-1)

        cond_audio_mask = torch.zeros(
            1, cond_total_length, dtype=torch.bool, device=self.device
        )
        cond_audio_mask[0, cond_audio_start_idx:] = True

        return {
            "input_ids": cond_input_ids,
            "audio_mask": cond_audio_mask,
        }
|
| 1116 |
+
|
| 1117 |
+
def _generate_iterative(
    self, task: GenerationTask, gen_config: OmniVoiceGenerationConfig
) -> List[torch.Tensor]:
    """N-step iterative unmasked decoding.

    All targets start fully masked; each step the model scores every still
    masked position and the top-k most confident predictions are committed,
    following a timestep schedule, until no MASK tokens remain.

    Args:
        task: A :class:`GenerationTask` containing batch texts, target
            lengths, languages, instructions, and optional reference data.
        gen_config: A :class:`OmniVoiceGenerationConfig` controlling
            decoding steps, guidance, temperatures, etc.
    Returns:
        List of generated audio token tensors of shape (C, T) (one per
        input text).
    """

    B = task.batch_size

    # One prepared (input_ids, audio_mask) dict per sample; lengths differ,
    # so everything below is padded to the longest conditional sequence.
    inputs_list = [
        self._prepare_inference_inputs(
            task.texts[i],
            task.target_lens[i],
            task.ref_texts[i],
            task.ref_audio_tokens[i],
            task.langs[i],
            task.instructs[i],
            gen_config.denoise,
        )
        for i in range(B)
    ]

    c_lens = [inp["input_ids"].size(2) for inp in inputs_list]
    max_c_len = max(c_lens)
    pad_id = self.config.audio_mask_id  # Or any other tokens

    # Rows [0, B) hold the full conditional context; rows [B, 2B) hold the
    # unconditional (target-only) context used for classifier-free guidance.
    batch_input_ids = torch.full(
        (2 * B, self.config.num_audio_codebook, max_c_len),
        pad_id,
        dtype=torch.long,
        device=self.device,
    )
    batch_audio_mask = torch.zeros(
        (2 * B, max_c_len), dtype=torch.bool, device=self.device
    )
    # Per-sample square attention windows; padding stays fully masked out.
    batch_attention_mask = torch.zeros(
        (2 * B, 1, max_c_len, max_c_len), dtype=torch.bool, device=self.device
    )

    for i, inp in enumerate(inputs_list):
        c_len, u_len = c_lens[i], task.target_lens[i]

        # Cond (0 ~ B-1)
        batch_input_ids[i, :, :c_len] = inp["input_ids"]
        batch_audio_mask[i, :c_len] = inp["audio_mask"]
        batch_attention_mask[i, :, :c_len, :c_len] = True

        # Uncond (B ~ 2B-1): only the trailing target span (all MASK) is kept.
        batch_input_ids[B + i, :, :u_len] = inp["input_ids"][..., -u_len:]
        batch_audio_mask[B + i, :u_len] = inp["audio_mask"][..., -u_len:]
        batch_attention_mask[B + i, :, :u_len, :u_len] = True

    # Working copy of the target span: starts fully masked, filled in-place.
    tokens = torch.full(
        (B, self.config.num_audio_codebook, max(task.target_lens)),
        self.config.audio_mask_id,
        dtype=torch.long,
        device=self.device,
    )

    # NOTE(review): num_step + 1 is passed here, and _get_time_steps itself
    # adds one more linspace point; only indices [0, num_step] are consumed
    # below — confirm the extra point is intentional.
    timesteps = _get_time_steps(
        t_start=0.0,
        t_end=1.0,
        num_step=gen_config.num_step + 1,
        t_shift=gen_config.t_shift,
    ).tolist()
    # Per-sample unmasking schedule: how many positions to commit each step,
    # proportional to the timestep increments; the last step flushes the rest.
    schedules = []
    for t_len in task.target_lens:
        total_mask = t_len * self.config.num_audio_codebook
        rem = total_mask
        sched = []
        for step in range(gen_config.num_step):
            num = (
                rem
                if step == gen_config.num_step - 1
                else min(
                    math.ceil(total_mask * (timesteps[step + 1] - timesteps[step])),
                    rem,
                )
            )
            sched.append(int(num))
            rem -= int(num)
        schedules.append(sched)

    # [1, C, 1] codebook indices, used to bias confidence against deeper layers.
    layer_ids = torch.arange(
        self.config.num_audio_codebook, device=self.device
    ).view(1, -1, 1)

    for step in range(gen_config.num_step):
        # One forward pass over cond + uncond rows.
        # NOTE(review): assumes self(...) returns an object whose .logits is
        # [2B, C, max_c_len, V] — consistent with the slicing below.
        batch_logits = self(
            input_ids=batch_input_ids,
            audio_mask=batch_audio_mask,
            attention_mask=batch_attention_mask,
        ).logits.to(torch.float32)

        for i in range(B):
            k = schedules[i][step]
            if k <= 0:
                continue

            c_len, t_len = c_lens[i], task.target_lens[i]

            # Extract real target Logits
            # [1, C, T, V]
            c_logits = batch_logits[i : i + 1, :, c_len - t_len : c_len, :]
            u_logits = batch_logits[B + i : B + i + 1, :, :t_len, :]

            pred_tokens, scores = self._predict_tokens_with_scoring(
                c_logits, u_logits, gen_config
            )

            # Penalize later codebooks so shallow layers unmask first.
            scores = scores - (layer_ids * gen_config.layer_penalty_factor)

            if gen_config.position_temperature > 0.0:
                scores = _gumbel_sample(scores, gen_config.position_temperature)

            # Already-committed positions must never be re-selected.
            sample_tokens = tokens[i : i + 1, :, :t_len]
            scores.masked_fill_(
                sample_tokens != self.config.audio_mask_id, -float("inf")
            )

            # Commit the k most confident predictions among masked positions.
            _, topk_idx = torch.topk(scores.flatten(), k)
            flat_tokens = sample_tokens.flatten()
            flat_tokens[topk_idx] = pred_tokens.flatten()[topk_idx]
            sample_tokens.copy_(flat_tokens.view_as(sample_tokens))

            # Update individual slices into batched structure
            tokens[i : i + 1, :, :t_len] = sample_tokens
            batch_input_ids[i : i + 1, :, c_len - t_len : c_len] = sample_tokens
            batch_input_ids[B + i : B + i + 1, :, :t_len] = sample_tokens

    return [tokens[i, :, : task.target_lens[i]] for i in range(B)]
|
| 1256 |
+
|
| 1257 |
+
def _predict_tokens_with_scoring(self, c_logits, u_logits, gen_config):
|
| 1258 |
+
if gen_config.guidance_scale != 0:
|
| 1259 |
+
c_log_probs = F.log_softmax(c_logits, dim=-1)
|
| 1260 |
+
u_log_probs = F.log_softmax(u_logits, dim=-1)
|
| 1261 |
+
log_probs = torch.log_softmax(
|
| 1262 |
+
c_log_probs + gen_config.guidance_scale * (c_log_probs - u_log_probs),
|
| 1263 |
+
dim=-1,
|
| 1264 |
+
)
|
| 1265 |
+
else:
|
| 1266 |
+
log_probs = F.log_softmax(c_logits, dim=-1)
|
| 1267 |
+
|
| 1268 |
+
log_probs[..., self.config.audio_mask_id] = -float("inf")
|
| 1269 |
+
|
| 1270 |
+
if gen_config.class_temperature > 0.0:
|
| 1271 |
+
filtered_probs = _filter_top_k(log_probs, ratio=0.1)
|
| 1272 |
+
pred_tokens = _gumbel_sample(
|
| 1273 |
+
filtered_probs, gen_config.class_temperature
|
| 1274 |
+
).argmax(dim=-1)
|
| 1275 |
+
else:
|
| 1276 |
+
pred_tokens = log_probs.argmax(dim=-1)
|
| 1277 |
+
|
| 1278 |
+
confidence_scores = log_probs.max(dim=-1)[0]
|
| 1279 |
+
|
| 1280 |
+
return pred_tokens, confidence_scores
|
| 1281 |
+
|
| 1282 |
+
|
| 1283 |
+
# ---------------------------------------------------------------------------
|
| 1284 |
+
# Standalone helpers
|
| 1285 |
+
# ---------------------------------------------------------------------------
|
| 1286 |
+
|
| 1287 |
+
|
| 1288 |
+
def _get_packed_mask(document_ids):
    """Return a mask callback closed over *document_ids*.

    The result has the ``(b, h, q_idx, kv_idx)`` signature expected by
    attention ``mask_mod`` callbacks.
    """
    return partial(_mask_mod_packed, document_ids)
|
| 1290 |
+
|
| 1291 |
+
|
| 1292 |
+
def _mask_mod_packed(document_ids, b, h, q_idx, kv_idx):
|
| 1293 |
+
# 1. Sequence Packing Logic: Tokens must belong to the same document.
|
| 1294 |
+
# Note: The doc_id for padding tokens is -1, which will automatically not match
|
| 1295 |
+
# (if handled correctly) or be ignored.
|
| 1296 |
+
same_doc = document_ids[q_idx] == document_ids[kv_idx]
|
| 1297 |
+
return same_doc
|
| 1298 |
+
|
| 1299 |
+
|
| 1300 |
+
def _resolve_language(language: Optional[str]) -> Union[str, None]:
    """Resolve a language ID or full language name to a canonical language ID.

    Returns ``None`` (language-agnostic mode) for ``None``, the string
    ``"none"`` (any case), or an unrecognized value; a warning is logged in
    the unrecognized case.
    """
    from omnivoice.utils.lang_map import LANG_IDS, LANG_NAME_TO_ID

    if language is None or language.lower() == "none":
        return None

    # Exact ID match (case-sensitive) takes priority.
    if language in LANG_IDS:
        return language

    # Otherwise try the full-name lookup, case-insensitively.
    mapped = LANG_NAME_TO_ID.get(language.lower())
    if mapped is not None:
        return mapped

    logger.warning(
        f"Language '{language}' is not recognized. "
        f"Please use a valid language ID (e.g., 'en', 'zh', 'ja', 'de') "
        f"or a full language name (e.g., 'English', 'Chinese', 'Japanese'). "
        f"See supported_language_ids() or supported_language_names() for details. "
        f"Falling back to None (language-agnostic mode)."
    )
    return None
|
| 1318 |
+
|
| 1319 |
+
|
| 1320 |
+
def _resolve_instruct(
    instruct: Optional[str], use_zh: bool = False
) -> Union[str, None]:
    """Validate and normalise a voice-design instruct string.

    Supported instruct items (case-insensitive for English):

    English (comma + space separated):
        gender: male, female
        age: child, teenager, young adult, middle-aged, elderly
        pitch: very low pitch, low pitch, moderate pitch,
            high pitch, very high pitch
        style: whisper
        accent: american accent, british accent, australian accent, ...

    Chinese (full-width comma separated):
        gender: 男, 女
        age: 儿童, 少年, 青年, 中年, 老年
        pitch: 极低音调, 低音调, 中音调, 高音调, 极高音调
        style: 耳语
        dialect: 河南话, 陕西话, 四川话, 贵州话, 云南话,
            桂林话, 济南话, 石家庄话, 甘肃话, 宁夏话,
            青岛话, 东北话

    Minor issues (auto-fixed):
        - Wrong separator (half-width comma in Chinese instruct or
          full-width comma in English instruct)
        - Leading / trailing commas

    Major issues (raise ``ValueError``):
        - Unsupported or misspelled instruct items
        - Suggestions are offered for close matches

    Args:
        instruct: Raw instruct string, or ``None``.
        use_zh: If True, normalise all items to Chinese (used when the
            synthesis text contains Chinese and no accent is specified).

    Returns:
        Normalised instruct string, or ``None``.

    Raises:
        ValueError: if any instruct item is unsupported or misspelled.
    """
    if instruct is None:
        return None

    instruct_str = instruct.strip()
    if not instruct_str:
        return None

    # Split on both half-width and full-width commas
    raw_items = re.split(r"\s*[,,]\s*", instruct_str)
    # Dropping empty fragments auto-fixes leading/trailing/double commas.
    raw_items = [x for x in raw_items if x]

    # Validate each item
    unknown = []
    normalised = []
    for raw in raw_items:
        n = raw.strip().lower()
        if n in _INSTRUCT_ALL_VALID:
            normalised.append(n)
        else:
            # Offer the closest supported item (if any) as a suggestion.
            sug = difflib.get_close_matches(n, _INSTRUCT_ALL_VALID, n=1, cutoff=0.6)
            unknown.append((raw, n, sug[0] if sug else None))

    if unknown:
        lines = []
        for raw, n, sug in unknown:
            if sug:
                lines.append(f" '{raw}' -> '{n}' (unsupported; did you mean '{sug}'?)")
            else:
                lines.append(f" '{raw}' -> '{n}' (unsupported)")
        err = (
            f"Unsupported instruct items found in {instruct_str}:\n"
            + "\n".join(lines)
            + "\n\nValid English items: "
            + ", ".join(sorted(_INSTRUCT_VALID_EN))
            + "\nValid Chinese items: "
            + ",".join(sorted(_INSTRUCT_VALID_ZH))
            + "\n\nTip: Use only English or only Chinese instructs. "
            "English instructs should use comma + space (e.g. "
            "'male, indian accent'),\nChinese instructs should use full-width "
            "comma (e.g. '男,河南话')."
        )
        raise ValueError(err)

    # --- Language consistency: dialect forces Chinese, accent forces English ---
    # All dialect items end with "话"; all accent items contain " accent".
    has_dialect = any(n.endswith("话") for n in normalised)
    has_accent = any(" accent" in n for n in normalised)

    if has_dialect and has_accent:
        raise ValueError(
            "Cannot mix Chinese dialect and English accent in a single instruct. "
            "Dialects are for Chinese speech, accents for English speech."
        )

    if has_dialect:
        use_zh = True
    elif has_accent:
        use_zh = False

    # --- Unify to single language ---
    # Items without a translation entry pass through unchanged.
    if use_zh:
        normalised = [_INSTRUCT_EN_TO_ZH.get(n, n) for n in normalised]
    else:
        normalised = [_INSTRUCT_ZH_TO_EN.get(n, n) for n in normalised]

    # --- Category conflict check ---
    conflicts = []
    for cat in _INSTRUCT_MUTUALLY_EXCLUSIVE:
        hits = [n for n in normalised if n in cat]
        if len(hits) > 1:
            conflicts.append(hits)
    if conflicts:
        parts = []
        for group in conflicts:
            parts.append(" vs ".join(f"'{x}'" for x in group))
        raise ValueError(
            "Conflicting instruct items within the same category: "
            + "; ".join(parts)
            + ". Each category (gender, age, pitch, style, accent, dialect) "
            "allows at most one item."
        )

    # Determine separator based on language
    # Any CJK character in the normalised items implies Chinese output.
    has_zh = any(any("\u4e00" <= c <= "\u9fff" for c in n) for n in normalised)
    separator = "," if has_zh else ", "

    return separator.join(normalised)
|
| 1450 |
+
|
| 1451 |
+
|
| 1452 |
+
def _filter_top_k(logits: torch.Tensor, ratio: float = 0.1) -> torch.Tensor:
|
| 1453 |
+
k = math.ceil(ratio * logits.shape[-1])
|
| 1454 |
+
val, ind = logits.topk(k, dim=-1)
|
| 1455 |
+
probs = torch.full_like(logits, float("-inf"))
|
| 1456 |
+
probs.scatter_(-1, ind, val)
|
| 1457 |
+
return probs
|
| 1458 |
+
|
| 1459 |
+
|
| 1460 |
+
def _gumbel_sample(logits: torch.Tensor, temperature: float) -> torch.Tensor:
|
| 1461 |
+
scaled_logits = logits / temperature
|
| 1462 |
+
u = torch.rand_like(scaled_logits)
|
| 1463 |
+
gumbel_noise = -torch.log(-torch.log(u + 1e-10) + 1e-10)
|
| 1464 |
+
return scaled_logits + gumbel_noise
|
| 1465 |
+
|
| 1466 |
+
|
| 1467 |
+
def _get_time_steps(
|
| 1468 |
+
t_start: float = 0.0,
|
| 1469 |
+
t_end: float = 1.0,
|
| 1470 |
+
num_step: int = 10,
|
| 1471 |
+
t_shift: float = 1.0,
|
| 1472 |
+
device: torch.device = torch.device("cpu"),
|
| 1473 |
+
) -> torch.Tensor:
|
| 1474 |
+
timesteps = torch.linspace(t_start, t_end, num_step + 1).to(device)
|
| 1475 |
+
timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps)
|
| 1476 |
+
return timesteps
|
| 1477 |
+
|
| 1478 |
+
|
| 1479 |
+
def _combine_text(text, ref_text: Optional[str] = None) -> str:
|
| 1480 |
+
|
| 1481 |
+
# combine with reference text if not None
|
| 1482 |
+
if ref_text:
|
| 1483 |
+
full_text = ref_text.strip() + " " + text.strip()
|
| 1484 |
+
else:
|
| 1485 |
+
full_text = text.strip()
|
| 1486 |
+
|
| 1487 |
+
# replace \n with .
|
| 1488 |
+
full_text = re.sub(r"[ \t]*\r?\n[\s]*", ".", full_text)
|
| 1489 |
+
|
| 1490 |
+
# remove spaces around chinese characters
|
| 1491 |
+
chinese_range = r"[\u4e00-\u9fff]"
|
| 1492 |
+
pattern = rf"(?<={chinese_range})\s+|\s+(?={chinese_range})"
|
| 1493 |
+
full_text = re.sub(pattern, "", full_text)
|
| 1494 |
+
return full_text
|
| 1495 |
+
|
| 1496 |
+
|
| 1497 |
+
# ---------------------------------------------------------------------------
|
| 1498 |
+
# Register with HuggingFace Auto classes
|
| 1499 |
+
# ---------------------------------------------------------------------------
|
| 1500 |
+
|
| 1501 |
+
# Make the model loadable via AutoConfig / AutoModel.from_pretrained
# using the "omnivoice" model type.
AutoConfig.register("omnivoice", OmniVoiceConfig)
AutoModel.register(OmniVoiceConfig, OmniVoice)
|
omnivoice/scripts/__init__.py
ADDED
|
File without changes
|
omnivoice/scripts/denoise_audio.py
ADDED
|
@@ -0,0 +1,1048 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""Denoise audio with Sidon and pack results into WebDataset shards.
|
| 19 |
+
|
| 20 |
+
Supports two input modes:
|
| 21 |
+
|
| 22 |
+
1. WebDataset manifest (data.lst):
|
| 23 |
+
python denoise_audio.py \
|
| 24 |
+
--input_manifest data.lst \
|
| 25 |
+
        --tar_output_pattern output/audios/shard_%06d.tar \
        --jsonl_output_pattern output/txts/shard_%06d.jsonl \
|
| 27 |
+
--feature_extractor_path sidon-v0.1/feature_extractor_cuda.pt \
|
| 28 |
+
--decoder_path sidon-v0.1/decoder_cuda.pt
|
| 29 |
+
|
| 30 |
+
2. Raw JSONL (each line: {"id": "...", "audio_path": "...", ...}):
|
| 31 |
+
python denoise_audio.py \
|
| 32 |
+
--input_jsonl data.jsonl \
|
| 33 |
+
        --tar_output_pattern output/audios/shard_%06d.tar \
        --jsonl_output_pattern output/txts/shard_%06d.jsonl \
|
| 35 |
+
--feature_extractor_path sidon-v0.1/feature_extractor_cuda.pt \
|
| 36 |
+
--decoder_path sidon-v0.1/decoder_cuda.pt
|
| 37 |
+
|
| 38 |
+
Output structure:
|
| 39 |
+
output_dir/
|
| 40 |
+
├── audios/ # WebDataset tar shards (.flac audio + .json metadata)
|
| 41 |
+
│ ├── shard_000000.tar
|
| 42 |
+
│ └── ...
|
| 43 |
+
├── txts/ # Per-shard JSONL metadata
|
| 44 |
+
│ ├── shard_000000.jsonl
|
| 45 |
+
│ └── ...
|
| 46 |
+
├── data.lst # Manifest: <tar_path> <jsonl_path> <sample_count> <total_duration>
|
| 47 |
+
└── errors.jsonl # Failed samples with error details
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
from __future__ import annotations
|
| 51 |
+
|
| 52 |
+
import argparse
|
| 53 |
+
import io
|
| 54 |
+
import json
|
| 55 |
+
import logging
|
| 56 |
+
import os
|
| 57 |
+
import pickle
|
| 58 |
+
import struct
|
| 59 |
+
import subprocess
|
| 60 |
+
import sys
|
| 61 |
+
import threading
|
| 62 |
+
from concurrent.futures import FIRST_COMPLETED, Future, wait
|
| 63 |
+
from dataclasses import dataclass
|
| 64 |
+
from pathlib import Path
|
| 65 |
+
from typing import Any, Dict, List, Optional, Sequence, Union
|
| 66 |
+
|
| 67 |
+
import numpy as np
|
| 68 |
+
import torch
|
| 69 |
+
import torchaudio
|
| 70 |
+
import webdataset as wds
|
| 71 |
+
from torch.utils.data import DataLoader
|
| 72 |
+
from tqdm.auto import tqdm
|
| 73 |
+
|
| 74 |
+
from omnivoice.data.batching import StreamLengthGroupDataset
|
| 75 |
+
from omnivoice.data.dataset import JsonlDatasetReader, WebDatasetReader
|
| 76 |
+
from omnivoice.utils.common import str2bool
|
| 77 |
+
|
| 78 |
+
# Sidon consumes 16 kHz audio and produces denoised audio at 48 kHz;
# the output is then resampled to --target_sample_rate.
SIDON_INPUT_SAMPLE_RATE = 16_000
SIDON_OUTPUT_SAMPLE_RATE = 48_000
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the denoising script.

    Exactly one of ``--input_manifest`` / ``--input_jsonl`` is expected
    (not enforced here; presumably validated by the caller — TODO confirm).
    The module docstring is reused as the program description.
    """
    parser = argparse.ArgumentParser(description=__doc__)

    # ── Input (mutually exclusive) ──
    parser.add_argument(
        "--input_manifest",
        default=None,
        help="WebDataset manifest (data.lst). Each line: "
        "<tar_path> <jsonl_path> <num_items> <duration>",
    )
    parser.add_argument(
        "--input_jsonl",
        default=None,
        help="Raw JSONL file. Each line: " '{"id": "...", "audio_path": "...", ...}',
    )

    # ── Output ──
    parser.add_argument(
        "--tar_output_pattern",
        default=None,
        help="Tar shard pattern, e.g. output/audios/shard_%%06d.tar",
    )
    parser.add_argument(
        "--jsonl_output_pattern",
        default=None,
        help="JSONL shard pattern, e.g. output/txts/shard_%%06d.jsonl",
    )
    parser.add_argument(
        "--samples_per_shard",
        type=int,
        default=1_000,
        help="Maximum records per output shard",
    )

    # ── Model ──
    parser.add_argument(
        "--feature_extractor_path",
        default=None,
        help="Path to feature_extractor_cuda.pt",
    )
    parser.add_argument(
        "--decoder_path",
        default=None,
        help="Path to decoder_cuda.pt",
    )
    parser.add_argument(
        "--target_sample_rate",
        type=int,
        default=24_000,
        help="Sample rate of the denoised output audio",
    )

    # ── Filtering ──
    parser.add_argument(
        "--min_length",
        type=float,
        default=0.0,
        help="Minimum audio duration in seconds",
    )
    parser.add_argument(
        "--max_length",
        type=float,
        default=80.0,
        help="Maximum audio duration in seconds",
    )

    # ── Batching ──
    parser.add_argument(
        "--batch_duration",
        type=float,
        default=200.0,
        help="Target batch duration in seconds for dynamic batching",
    )
    parser.add_argument(
        "--max_sample",
        type=int,
        default=32,
        help="Maximum samples per batch for dynamic batching",
    )

    # ── Distributed ──
    parser.add_argument(
        "--num_machines",
        type=int,
        default=1,
        help="Total number of machines for distributed runs",
    )
    parser.add_argument(
        "--machine_index",
        type=int,
        default=0,
        help="Zero-based machine index when distributing across multiple "
        "machines (e.g. 0, 1, ... num_machines-1)",
    )

    # ── Parallelism ──
    parser.add_argument(
        "--nj_per_gpu",
        type=int,
        default=1,
        help="Worker processes per GPU (default 1)",
    )
    parser.add_argument(
        "--loader_workers",
        type=int,
        default=16,
        help="PyTorch DataLoader worker threads",
    )

    # ── Data order (JSONL mode) ──
    parser.add_argument(
        "--shuffle",
        type=str2bool,
        default=True,
        help="Shuffle JSONL entries",
    )
    parser.add_argument(
        "--shuffle_seed",
        type=int,
        default=42,
        help="Seed for JSONL shuffle",
    )

    # ── Error handling ──
    parser.add_argument(
        "--skip_errors",
        action="store_true",
        help="Skip items that fail to denoise instead of aborting",
    )
    # Internal flag set when this script re-invokes itself as a worker
    # subprocess; hidden from --help via SUPPRESS.
    parser.add_argument(
        "--_subprocess_worker",
        action="store_true",
        help=argparse.SUPPRESS,
    )
    return parser
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
# ---------------------------------------------------------------------------
|
| 220 |
+
# Utilities
|
| 221 |
+
# ---------------------------------------------------------------------------
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def count_lines(path: str) -> int:
    """Count newline characters in *path* by scanning 1 MiB binary chunks."""
    total = 0
    with open(path, "rb") as fh:
        while True:
            chunk = fh.read(1 << 20)
            if not chunk:
                break
            total += chunk.count(b"\n")
    return total
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
PaddingStrategy = Union[bool, str]
|
| 234 |
+
ReturnType = Union[torch.Tensor, np.ndarray]
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def extract_seamless_m4t_features(
|
| 238 |
+
raw_speech: Union[torch.Tensor, List[float], List[torch.Tensor], List[List[float]]],
|
| 239 |
+
sampling_rate: int = 16000,
|
| 240 |
+
num_mel_bins: int = 80,
|
| 241 |
+
frame_length: int = 25,
|
| 242 |
+
frame_shift: int = 10,
|
| 243 |
+
preemphasis_coefficient: float = 0.97,
|
| 244 |
+
dither: float = 0.0,
|
| 245 |
+
window_type: str = "povey",
|
| 246 |
+
do_normalize_per_mel_bins: bool = True,
|
| 247 |
+
stride: int = 2,
|
| 248 |
+
padding: PaddingStrategy = "longest",
|
| 249 |
+
max_length: Optional[int] = None,
|
| 250 |
+
pad_to_multiple_of: Optional[int] = 2,
|
| 251 |
+
return_tensors: Optional[str] = "pt",
|
| 252 |
+
return_attention_mask: bool = True,
|
| 253 |
+
padding_value: float = 0.0,
|
| 254 |
+
device: torch.device = torch.device("cpu"),
|
| 255 |
+
) -> Dict[str, ReturnType]:
|
| 256 |
+
"""Extract SeamlessM4T features using Torch-only operators."""
|
| 257 |
+
if not isinstance(raw_speech, list):
|
| 258 |
+
raw_speech = [raw_speech]
|
| 259 |
+
|
| 260 |
+
processed_speech = [
|
| 261 |
+
torch.as_tensor(sample, dtype=torch.float32, device=device)
|
| 262 |
+
for sample in raw_speech
|
| 263 |
+
]
|
| 264 |
+
|
| 265 |
+
features: List[torch.Tensor] = []
|
| 266 |
+
for waveform in processed_speech:
|
| 267 |
+
if waveform.ndim > 1:
|
| 268 |
+
waveform = waveform[0]
|
| 269 |
+
waveform_tensor = waveform.unsqueeze(0)
|
| 270 |
+
feature = torchaudio.compliance.kaldi.fbank(
|
| 271 |
+
waveform=waveform_tensor,
|
| 272 |
+
sample_frequency=sampling_rate,
|
| 273 |
+
num_mel_bins=num_mel_bins,
|
| 274 |
+
frame_length=frame_length,
|
| 275 |
+
frame_shift=frame_shift,
|
| 276 |
+
dither=dither,
|
| 277 |
+
preemphasis_coefficient=preemphasis_coefficient,
|
| 278 |
+
remove_dc_offset=True,
|
| 279 |
+
window_type=window_type,
|
| 280 |
+
use_energy=False,
|
| 281 |
+
energy_floor=1.192092955078125e-07,
|
| 282 |
+
)
|
| 283 |
+
features.append(feature.squeeze(0))
|
| 284 |
+
|
| 285 |
+
if do_normalize_per_mel_bins:
|
| 286 |
+
normalised: List[torch.Tensor] = []
|
| 287 |
+
for feature in features:
|
| 288 |
+
mean = feature.mean(0, keepdim=True)
|
| 289 |
+
var = feature.var(0, keepdim=True)
|
| 290 |
+
normalised.append((feature - mean) / torch.sqrt(var + 1e-5))
|
| 291 |
+
features = normalised
|
| 292 |
+
|
| 293 |
+
def _pad_batch(
|
| 294 |
+
features: List[torch.Tensor],
|
| 295 |
+
padding_strategy: PaddingStrategy = "longest",
|
| 296 |
+
max_length: Optional[int] = None,
|
| 297 |
+
pad_to_multiple_of: Optional[int] = None,
|
| 298 |
+
padding_value: float = 0.0,
|
| 299 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 300 |
+
if padding_strategy == "longest":
|
| 301 |
+
target_length = max(f.shape[0] for f in features)
|
| 302 |
+
elif max_length is not None:
|
| 303 |
+
target_length = max_length
|
| 304 |
+
else:
|
| 305 |
+
raise ValueError(
|
| 306 |
+
"max_length must be provided when padding_strategy is not 'longest'"
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
if pad_to_multiple_of is not None:
|
| 310 |
+
target_length = (
|
| 311 |
+
(target_length + pad_to_multiple_of - 1)
|
| 312 |
+
// pad_to_multiple_of
|
| 313 |
+
* pad_to_multiple_of
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
batch_size = len(features)
|
| 317 |
+
feature_dim = features[0].shape[1]
|
| 318 |
+
device = features[0].device
|
| 319 |
+
|
| 320 |
+
padded_features = torch.full(
|
| 321 |
+
(batch_size, target_length, feature_dim),
|
| 322 |
+
padding_value,
|
| 323 |
+
dtype=torch.float32,
|
| 324 |
+
device=device,
|
| 325 |
+
)
|
| 326 |
+
attention_mask = torch.zeros(
|
| 327 |
+
(batch_size, target_length),
|
| 328 |
+
dtype=torch.int64,
|
| 329 |
+
device=device,
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
for index, feature_tensor in enumerate(features):
|
| 333 |
+
seq_len = feature_tensor.shape[0]
|
| 334 |
+
padded_features[index, :seq_len] = feature_tensor
|
| 335 |
+
attention_mask[index, :seq_len] = 1
|
| 336 |
+
|
| 337 |
+
return padded_features, attention_mask
|
| 338 |
+
|
| 339 |
+
input_features, attention_mask = _pad_batch(
|
| 340 |
+
features,
|
| 341 |
+
padding_strategy=padding,
|
| 342 |
+
max_length=max_length,
|
| 343 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
| 344 |
+
padding_value=padding_value,
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
batch_size, num_frames, num_channels = input_features.shape
|
| 348 |
+
new_num_frames = (num_frames // stride) * stride
|
| 349 |
+
input_features = input_features[:, :new_num_frames, :]
|
| 350 |
+
if return_attention_mask:
|
| 351 |
+
attention_mask = attention_mask[:, :new_num_frames]
|
| 352 |
+
|
| 353 |
+
input_features = input_features.reshape(
|
| 354 |
+
batch_size, new_num_frames // stride, num_channels * stride
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
output: Dict[str, ReturnType] = {"input_features": input_features}
|
| 358 |
+
if return_attention_mask:
|
| 359 |
+
output["attention_mask"] = attention_mask[:, 1::stride]
|
| 360 |
+
|
| 361 |
+
if return_tensors == "np":
|
| 362 |
+
for key, value in output.items():
|
| 363 |
+
output[key] = value.cpu().numpy() # type: ignore[assignment]
|
| 364 |
+
|
| 365 |
+
return output
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def serialise_flac(key: str, waveform: torch.Tensor, sample_rate: int) -> dict:
    """Encode *waveform* as 16-bit FLAC bytes for a WebDataset record.

    Returns a dict carrying the sample key under ``__key__`` and the
    encoded audio under ``flac``, ready for ``wds.TarWriter.write``.
    """
    audio = waveform.to(dtype=torch.float32).cpu()
    # torchaudio.save expects a (channels, samples) tensor.
    if audio.ndim == 1:
        audio = audio.unsqueeze(0)
    stream = io.BytesIO()
    torchaudio.save(stream, audio, sample_rate, format="flac", bits_per_sample=16)
    return {"__key__": key, "flac": stream.getvalue()}
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
def _normalise_value(value: Any) -> Any:
|
| 378 |
+
"""Convert tensors and NumPy scalars to serialisable Python objects."""
|
| 379 |
+
if isinstance(value, torch.Tensor):
|
| 380 |
+
if value.ndim == 0:
|
| 381 |
+
return value.item()
|
| 382 |
+
return value.cpu().tolist()
|
| 383 |
+
if isinstance(value, np.generic):
|
| 384 |
+
return value.item()
|
| 385 |
+
if isinstance(value, np.ndarray):
|
| 386 |
+
return value.tolist()
|
| 387 |
+
return value
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def _encode_metadata(metadata: dict[str, Any]) -> bytes:
    """Serialise *metadata* to UTF-8 JSON bytes, dropping ``None`` entries."""
    cleaned = {
        key: _normalise_value(value)
        for key, value in metadata.items()
        if value is not None
    }
    return json.dumps(cleaned, ensure_ascii=False).encode("utf-8")
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
# ---------------------------------------------------------------------------
|
| 400 |
+
# Denoising model
|
| 401 |
+
# ---------------------------------------------------------------------------
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
class SpeechDenoisingProcessor:
    """Run the TorchScripted feature extractor and decoder.

    Loads two ``torch.jit`` modules: a feature-encoder that maps fbank
    features to hidden states, and a decoder that synthesises a cleaned
    waveform from those hidden states.
    """

    def __init__(
        self,
        feature_extractor_path: str,
        decoder_path: str,
        device: str,
    ) -> None:
        self.device = torch.device(device)
        self.feature_extractor = torch.jit.load(
            feature_extractor_path, map_location=self.device
        )
        self.decoder = torch.jit.load(decoder_path, map_location=self.device)
        # Inference only: freeze dropout/batch-norm behaviour.
        self.feature_extractor.eval()
        self.decoder.eval()

    @torch.inference_mode()
    def process(self, waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Denoise a single waveform; convenience wrapper around process_batch."""
        return self.process_batch([waveform], [sample_rate])[0]

    @torch.inference_mode()
    def process_batch(
        self,
        waveforms: Sequence[torch.Tensor] | torch.Tensor,
        sample_rates: Optional[Sequence[int]] = None,
        expected_lengths: Optional[Sequence[int]] = None,
    ) -> List[torch.Tensor]:
        """Denoise a batch of waveforms.

        Args:
            waveforms: batched audio. NOTE(review): the F.pad call below
                requires a tensor, so in practice this receives the padded
                (batch, time) tensor built by CollateFunction - a plain list
                of tensors would fail there; confirm the declared type.
            sample_rates: input sample rate per item; only consulted when
                ``expected_lengths`` is None (it must be provided then).
            expected_lengths: desired output length in samples (at
                SIDON_OUTPUT_SAMPLE_RATE) per item; computed from durations
                when omitted.

        Returns:
            One 1-D cleaned waveform tensor per input item, trimmed or
            zero-padded to the expected length.
        """
        if expected_lengths is None:
            # Derive target output lengths from each item's duration.
            expected_lengths: list[int] = []
            for waveform, sample_rate in zip(waveforms, sample_rates):
                duration_seconds = waveform.shape[-1] / float(sample_rate)
                expected_lengths.append(
                    int(round(duration_seconds * SIDON_OUTPUT_SAMPLE_RATE))
                )
        # Append 24000 zero samples of right-padding before feature
        # extraction; presumably guard context for the model's receptive
        # field (the excess is trimmed via expected_lengths below) - TODO confirm.
        waveforms = torch.nn.functional.pad(waveforms, (0, 24000))

        features = extract_seamless_m4t_features(
            [x for x in waveforms],
            return_tensors="pt",
            padding_value=1.0,
            device=self.device,
        )
        feature_tensor = self.feature_extractor(
            features["input_features"].to(self.device)
        )["last_hidden_state"]
        # Decoder consumes (batch, dim, frames); move result back to CPU.
        restored_waveforms = self.decoder(feature_tensor.transpose(1, 2)).cpu()

        results: List[torch.Tensor] = []
        for sample_idx, sample in enumerate(restored_waveforms):
            restored_waveform = sample.view(-1)
            target_length = expected_lengths[sample_idx]
            current_length = restored_waveform.shape[-1]
            # Force the output to exactly the expected sample count:
            # zero-pad if short, truncate if long.
            if target_length > 0 and current_length != target_length:
                diff = target_length - current_length
                if diff > 0:
                    restored_waveform = torch.nn.functional.pad(
                        restored_waveform, (0, diff)
                    )
                elif diff < 0:
                    restored_waveform = restored_waveform[:target_length]
            results.append(restored_waveform.contiguous())

        return results
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
# ---------------------------------------------------------------------------
|
| 471 |
+
# Batch collation
|
| 472 |
+
# ---------------------------------------------------------------------------
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
class CollateFunction:
    """Collate a list of samples into a padded batch."""

    def __init__(
        self,
        sample_rate: int,
        skip_errors: bool,
    ) -> None:
        self.sample_rate = sample_rate
        self.skip_errors = skip_errors

    def __call__(self, samples: Sequence[dict[str, Any]]) -> CollatedBatch:
        """Build a CollatedBatch from raw dataset samples.

        Each sample provides its label dict under ``"label"`` (with an
        ``"id"`` field) and its audio tensor under ``"audio"``.
        """
        labels = [item["label"] for item in samples]
        audio = [item["audio"].squeeze(0) for item in samples]
        seconds = [item["audio"].size(-1) / self.sample_rate for item in samples]

        # Right-pad every waveform with zeros to the longest in the batch.
        padded = torch.nn.utils.rnn.pad_sequence(audio, batch_first=True)

        return CollatedBatch(
            keys=[label["id"] for label in labels],
            waveforms=padded,
            durations=seconds,
            metadata=labels,
        )
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
@dataclass
class CollatedBatch:
    """Batch payload returned by the DataLoader collate function."""

    # Sample identifiers (label["id"]), one entry per sample.
    keys: list[str]
    # NOTE(review): annotated list[torch.Tensor], but CollateFunction assigns
    # the single padded (batch, time) tensor from pad_sequence here - confirm
    # which type is intended.
    waveforms: list[torch.Tensor]
    # Per-sample audio duration in seconds (num_samples / sample_rate).
    durations: list[float]
    # Raw label dicts, parallel to `keys`.
    metadata: list[dict[str, Any]]

    @property
    def size(self) -> int:
        # Number of samples in this batch.
        return len(self.keys)
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
# ---------------------------------------------------------------------------
|
| 519 |
+
# Subprocess-based GPU worker pool
|
| 520 |
+
# ---------------------------------------------------------------------------
|
| 521 |
+
#
|
| 522 |
+
# Problem: PyTorch ≥2.8 caches CUDA device state at import time. Neither
|
| 523 |
+
# forkserver nor spawn lets us change CUDA_VISIBLE_DEVICES *before* the CUDA
|
| 524 |
+
# runtime captures the device list. The only reliable approach is to launch
|
| 525 |
+
# each worker as a **subprocess** with CUDA_VISIBLE_DEVICES set in the
|
| 526 |
+
# subprocess environment, guaranteeing it takes effect before `import torch`.
|
| 527 |
+
#
|
| 528 |
+
# Protocol (parent ↔ child, length-prefixed pickle over stdin/stdout):
|
| 529 |
+
# Parent → child: 4-byte LE uint32 length + pickle(CollatedBatch)
|
| 530 |
+
# Child → parent: 4-byte LE uint32 length + pickle(result dict)
|
| 531 |
+
# Shutdown signal: 4 zero bytes (length == 0)
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
def _subprocess_recv():
|
| 535 |
+
"""Read a length-prefixed pickled object from stdin. Returns None on shutdown."""
|
| 536 |
+
raw = sys.stdin.buffer.read(4)
|
| 537 |
+
if len(raw) < 4:
|
| 538 |
+
return None
|
| 539 |
+
(length,) = struct.unpack("<I", raw)
|
| 540 |
+
if length == 0:
|
| 541 |
+
return None
|
| 542 |
+
data = sys.stdin.buffer.read(length)
|
| 543 |
+
return pickle.loads(data)
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
def _subprocess_send(obj):
|
| 547 |
+
"""Send a pickled object with a 4-byte length prefix to stdout."""
|
| 548 |
+
data = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
|
| 549 |
+
sys.stdout.buffer.write(struct.pack("<I", len(data)))
|
| 550 |
+
sys.stdout.buffer.write(data)
|
| 551 |
+
sys.stdout.buffer.flush()
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
def subprocess_worker_main():
    """Entry point for a GPU worker subprocess.

    Expected environment: CUDA_VISIBLE_DEVICES already set by the parent.
    Receives initargs via stdin, then processes batches in a loop.
    Every inbound message carries a request id that is echoed back in the
    result so the parent can resolve the matching Future.
    """
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] [Worker PID %(process)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    # First framed message: (feature_extractor_path, decoder_path).
    initargs = _subprocess_recv()
    feature_extractor_path, decoder_path = initargs

    device = "cpu"
    if torch.cuda.is_available():
        # With CUDA_VISIBLE_DEVICES pinned by the parent, index 0 is the
        # single GPU this worker is allowed to see.
        torch.cuda.set_device(0)
        device = "cuda:0"
    else:
        logging.warning("CUDA not available in worker subprocess.")

    logging.info(
        f"Worker PID={os.getpid()}, "
        f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}, device={device}"
    )

    processor = SpeechDenoisingProcessor(
        feature_extractor_path=feature_extractor_path,
        decoder_path=decoder_path,
        device=device,
    )

    # Process batches until shutdown signal
    while True:
        msg = _subprocess_recv()
        if msg is None:
            break
        req_id = msg["_req_id"]
        batch = msg["_batch"]
        try:
            cleaned_waveforms = processor.process_batch(
                batch.waveforms,
                expected_lengths=[
                    round(d * SIDON_OUTPUT_SAMPLE_RATE) for d in batch.durations
                ],
            )
            # Move results to CPU before pickling them back to the parent.
            cleaned_cpu = [w.cpu() for w in cleaned_waveforms]
            result = {
                "_req_id": req_id,
                "status": "success",
                "keys": batch.keys,
                "results": cleaned_cpu,
                "metadata": batch.metadata,
                "size": batch.size,
            }
        except Exception as e:
            # Report the failure to the parent instead of dying; the parent
            # decides whether to skip or abort (see handle_result in main).
            result = {
                "_req_id": req_id,
                "status": "error",
                "keys": batch.keys,
                "error": str(e),
                "size": batch.size,
            }
        _subprocess_send(result)
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
class _GPUWorker:
    """Handle to a single GPU worker subprocess.

    The worker is launched with ``CUDA_VISIBLE_DEVICES`` pinned in its
    environment so the CUDA runtime in the child sees exactly one device.
    Communication uses a length-prefixed pickle protocol over stdin/stdout:
    a 4-byte little-endian length, then the pickled payload; a zero length
    is the shutdown signal.
    """

    def __init__(self, physical_gpu_id, feature_extractor_path, decoder_path):
        env = os.environ.copy()
        if physical_gpu_id is not None:
            env["CUDA_VISIBLE_DEVICES"] = str(physical_gpu_id)
        self.proc = subprocess.Popen(
            [
                sys.executable,
                "-m",
                "omnivoice.scripts.denoise_audio",
                "--_subprocess_worker",
            ],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            env=env,
        )
        # Send init args (the two TorchScript model paths) as the first frame.
        init_data = pickle.dumps(
            (feature_extractor_path, decoder_path), protocol=pickle.HIGHEST_PROTOCOL
        )
        self.proc.stdin.write(struct.pack("<I", len(init_data)))
        self.proc.stdin.write(init_data)
        self.proc.stdin.flush()
        # Serialises writes to the child's stdin across threads.
        self._lock = threading.Lock()

    def submit(self, batch_with_id):
        """Send a batch dict (containing _req_id + _batch) for processing."""
        with self._lock:
            data = pickle.dumps(batch_with_id, protocol=pickle.HIGHEST_PROTOCOL)
            self.proc.stdin.write(struct.pack("<I", len(data)))
            self.proc.stdin.write(data)
            self.proc.stdin.flush()

    def read_result(self):
        """Blocking read for one result; None on EOF or a zero-length frame."""
        raw = self.proc.stdout.read(4)
        if len(raw) < 4:
            return None
        (length,) = struct.unpack("<I", raw)
        if length == 0:
            return None
        data = self.proc.stdout.read(length)
        return pickle.loads(data)

    def shutdown(self):
        """Send shutdown signal and wait; kill the child if it does not exit.

        Fix: the previous version called ``proc.wait(timeout=30)`` bare, so a
        hung child raised ``subprocess.TimeoutExpired`` out of shutdown(),
        aborting the pool's shutdown loop and leaking the remaining worker
        processes. Now the timeout is caught and the child is killed/reaped.
        """
        try:
            with self._lock:
                self.proc.stdin.write(struct.pack("<I", 0))
                self.proc.stdin.flush()
        except Exception:
            # Child may already be gone (broken pipe); proceed to reap it.
            pass
        try:
            self.proc.wait(timeout=30)
        except subprocess.TimeoutExpired:
            logging.warning(
                "Worker PID %s did not exit within 30s; killing it", self.proc.pid
            )
            self.proc.kill()
            self.proc.wait()
|
| 673 |
+
|
| 674 |
+
|
| 675 |
+
class GPUWorkerPool:
    """Round-robin dispatcher over a pool of GPU worker subprocesses."""

    def __init__(self, pool_specs, feature_extractor_path, decoder_path):
        """
        Args:
            pool_specs: list of (physical_gpu_id, num_workers) tuples.
            feature_extractor_path: path to JIT feature extractor.
            decoder_path: path to JIT decoder.
        """
        self.workers: list[_GPUWorker] = [
            _GPUWorker(gpu_id, feature_extractor_path, decoder_path)
            for gpu_id, worker_count in pool_specs
            for _ in range(worker_count)
        ]
        self._rr = 0
        self._next_id = 0
        self._futures: dict[int, Future] = {}
        self._futures_lock = threading.Lock()
        # One daemon reader thread per worker resolves futures as results arrive.
        self._reader_threads = [
            threading.Thread(target=self._reader_loop, args=(w,), daemon=True)
            for w in self.workers
        ]
        for thread in self._reader_threads:
            thread.start()

    def _reader_loop(self, worker):
        """Pump results from one worker until its output stream closes."""
        while (result := worker.read_result()) is not None:
            req_id = result.pop("_req_id", None)
            with self._futures_lock:
                pending = self._futures.pop(req_id, None)
            if pending is not None:
                pending.set_result(result)

    def submit(self, batch) -> Future:
        """Queue *batch* on the next worker (round-robin); return its Future."""
        target = self.workers[self._rr % len(self.workers)]
        self._rr += 1
        with self._futures_lock:
            req_id = self._next_id
            self._next_id += 1
            fut = Future()
            self._futures[req_id] = fut
        target.submit({"_req_id": req_id, "_batch": batch})
        return fut

    def shutdown(self):
        """Stop every worker subprocess, then reap the reader threads."""
        for worker in self.workers:
            worker.shutdown()
        for thread in self._reader_threads:
            thread.join(timeout=5)
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
# ---------------------------------------------------------------------------
|
| 736 |
+
# Main
|
| 737 |
+
# ---------------------------------------------------------------------------
|
| 738 |
+
|
| 739 |
+
|
| 740 |
+
def main() -> None:
    """CLI driver: denoise a corpus and re-shard it as WebDataset output.

    Reads either a raw JSONL list or a WebDataset manifest, streams batches
    through a pool of GPU worker subprocesses, and writes cleaned audio as
    FLAC tar shards plus parallel metadata JSONL files and a data.lst
    manifest. NOTE(review): the exact CLI fields come from build_parser(),
    which is outside this chunk - argument semantics below are inferred
    from usage.
    """
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)
    parser = build_parser()
    args = parser.parse_args()

    # ── Subprocess worker mode ──
    # The parent re-invokes this module with --_subprocess_worker; in that
    # case run the worker loop and exit instead of the driver below.
    if args._subprocess_worker:
        subprocess_worker_main()
        return

    # Validate input arguments
    assert args.tar_output_pattern is not None, "--tar_output_pattern is required."
    assert args.jsonl_output_pattern is not None, "--jsonl_output_pattern is required."
    assert bool(args.input_manifest) != bool(
        args.input_jsonl
    ), "Exactly one of --input_manifest or --input_jsonl must be provided."

    if args.num_machines > 1:
        assert (
            0 <= args.machine_index < args.num_machines
        ), f"machine_index {args.machine_index} must be in [0, {args.num_machines})"

    # ── Build base dataset and count total samples ──
    if args.input_jsonl:
        logging.info(f"Input mode: raw JSONL ({args.input_jsonl})")
        total_samples = count_lines(args.input_jsonl)
        base_dataset = JsonlDatasetReader(
            args.input_jsonl,
            sample_rate=SIDON_INPUT_SAMPLE_RATE,
            shuffle=args.shuffle,
            shuffle_seed=args.shuffle_seed,
        )
        loader_workers = args.loader_workers
    else:
        logging.info(f"Input mode: WebDataset manifest ({args.input_manifest})")
        manifest_num_lines = count_lines(args.input_manifest)
        # One loader worker per shard at most.
        loader_workers = min(args.loader_workers, manifest_num_lines)
        total_samples = 0
        manifests = []
        with open(args.input_manifest, "r", encoding="utf-8") as f:
            for line_id, line in tqdm(
                enumerate(f),
                total=manifest_num_lines,
                desc="Calculating dataset length",
            ):
                # Manifest line format: "<tar> <jsonl> <num_items> <duration>".
                items = line.strip().split(" ")
                tar_path, jsonl_path, num_items, duration = (
                    items[0],
                    items[1],
                    int(items[2]),
                    float(items[3]),
                )
                assert os.path.exists(tar_path), f"File {tar_path} does not exist."
                assert os.path.exists(jsonl_path), f"File {jsonl_path} does not exist."
                assert jsonl_path.endswith(
                    ".jsonl"
                ), f"File {jsonl_path} is not a .jsonl file."
                # Multi-machine runs split shards round-robin by line index.
                if (
                    args.num_machines > 1
                    and line_id % args.num_machines != args.machine_index
                ):
                    continue
                total_samples += num_items
                manifests.append((tar_path, jsonl_path, num_items, duration))
        logging.info(
            f"Total shards: {manifest_num_lines}, "
            f"Shards for current index: {len(manifests)}"
        )
        base_dataset = WebDatasetReader(
            manifests=manifests,
            sample_rate=SIDON_INPUT_SAMPLE_RATE,
            evaluation=True,
        )

    # ── Dynamic batching + DataLoader ──
    batched_dataset = StreamLengthGroupDataset(
        dataset=base_dataset,
        batch_duration=args.batch_duration,
        max_sample=args.max_sample,
        min_length=args.min_length,
        max_length=args.max_length,
    )

    collate_fn = CollateFunction(
        skip_errors=args.skip_errors,
        sample_rate=SIDON_INPUT_SAMPLE_RATE,
    )

    dataloader = DataLoader(
        dataset=batched_dataset,
        batch_size=None,  # batching is done upstream by StreamLengthGroupDataset
        collate_fn=collate_fn,
        num_workers=loader_workers,
        prefetch_factor=10 if loader_workers > 0 else None,
        pin_memory=True,
        persistent_workers=loader_workers > 0,
    )

    # ── Multi-GPU process pool ──
    num_devices = torch.cuda.device_count()
    if num_devices == 0:
        logging.warning("No GPUs detected - using CPU for processing")
        num_processes = args.nj_per_gpu
    else:
        num_processes = num_devices * args.nj_per_gpu
    logging.info(
        f"GPU count: {num_devices}, Processes per GPU: {args.nj_per_gpu}, "
        f"Total processes: {num_processes}"
    )

    # Build a list of (physical_gpu_id, num_workers) for each pool.
    # When num_devices == 0 we use a single CPU pool.
    if num_devices == 0:
        pool_specs = [(None, num_processes)]
    else:
        pool_specs = [(gpu_id, args.nj_per_gpu) for gpu_id in range(num_devices)]

    # ── Output paths ──
    tar_output_pattern = str(Path(args.tar_output_pattern).expanduser())
    jsonl_output_pattern = str(Path(args.jsonl_output_pattern).expanduser())
    Path(tar_output_pattern).parent.mkdir(parents=True, exist_ok=True)
    Path(jsonl_output_pattern).parent.mkdir(parents=True, exist_ok=True)

    # errors.jsonl and data.lst live one directory above the shard directory.
    output_dir = Path(tar_output_pattern).parent.parent
    error_log_path = str(output_dir / "errors.jsonl")
    manifest_path = str(output_dir / "data.lst")

    # Dedicated logger that writes one raw JSON line per failed sample.
    error_logger = logging.getLogger("error_log")
    error_logger.setLevel(logging.ERROR)
    error_logger.handlers.clear()
    error_fh = logging.FileHandler(error_log_path, mode="w", encoding="utf-8")
    error_fh.setFormatter(logging.Formatter("%(message)s"))
    error_logger.addHandler(error_fh)

    # ── Progress and shard tracking ──
    processed_count = 0
    error_count = 0
    write_error_count = 0
    failed_ids = []
    shard_idx = 0
    shard_sample_count = 0
    shard_duration = 0.0
    samples_per_shard = args.samples_per_shard
    shard_manifest = {}
    target_sample_rate = args.target_sample_rate

    tar_writer = None
    jsonl_file = None

    def open_new_shard():
        """Close the current shard (recording its stats) and open the next."""
        nonlocal tar_writer, jsonl_file, shard_idx, shard_sample_count, shard_duration
        if tar_writer is not None:
            tar_writer.close()
        if jsonl_file is not None:
            jsonl_file.close()
        # Record the just-closed shard in the manifest map.
        if shard_idx > 0 and shard_sample_count > 0:
            prev_idx = shard_idx - 1
            shard_manifest[prev_idx] = (
                os.path.abspath(tar_output_pattern % prev_idx),
                os.path.abspath(jsonl_output_pattern % prev_idx),
                shard_sample_count,
                shard_duration,
            )
        tar_fname = tar_output_pattern % shard_idx
        jsonl_fname = jsonl_output_pattern % shard_idx
        tar_writer = wds.TarWriter(tar_fname)
        jsonl_file = open(jsonl_fname, "w", encoding="utf-8")
        shard_idx += 1
        shard_sample_count = 0
        shard_duration = 0.0

    def write_sample(key, waveform, metadata):
        """Resample, normalise and write one cleaned sample to the open shard."""
        nonlocal shard_sample_count, write_error_count, shard_duration
        assert tar_writer is not None and jsonl_file is not None
        try:
            if target_sample_rate != SIDON_OUTPUT_SAMPLE_RATE:
                waveform = torchaudio.functional.resample(
                    waveform,
                    orig_freq=SIDON_OUTPUT_SAMPLE_RATE,
                    new_freq=target_sample_rate,
                )
            # Peak-normalise to 0.6 full scale (epsilon guards silent audio).
            waveform = (waveform / (waveform.abs().max() + 1e-7)) * 0.6

            record = serialise_flac(key, waveform, target_sample_rate)
            jsonl_record = _encode_metadata(metadata)
            tar_writer.write(record)
            jsonl_file.write(jsonl_record.decode("utf-8") + "\n")
            shard_sample_count += 1
            shard_duration += metadata.get("audio_duration", 0.0)
        except Exception as exc:
            write_error_count += 1
            failed_ids.append(key)
            error_logger.error(
                json.dumps({"id": key, "reason": str(exc)}, ensure_ascii=False)
            )
            logging.error(f"Write failed for sample {key}: {exc}")

    def handle_result(result):
        """Route one worker result: write samples, or log/raise on failure."""
        nonlocal processed_count, error_count
        if result["status"] == "success":
            for key, cleaned, metadata in zip(
                result["keys"], result["results"], result["metadata"]
            ):
                # Roll over to a new shard when the current one is full.
                if tar_writer is None or shard_sample_count >= samples_per_shard:
                    open_new_shard()
                write_sample(key, cleaned, metadata)
                processed_count += 1
        else:
            error_count += result["size"]
            failed_ids.extend(result["keys"])
            for key in result["keys"]:
                error_logger.error(
                    json.dumps(
                        {"id": key, "reason": result["error"]},
                        ensure_ascii=False,
                    )
                )
            if not args.skip_errors:
                raise RuntimeError(
                    f"Batch starting with {result['keys'][0]} failed - terminating"
                )
            logging.warning(
                f"Skipping failed batch starting with {result['keys'][0]}: "
                f"{result['error']}"
            )

    # ── Main processing loop ──
    main_progress = tqdm(total=total_samples, desc="Denoising Audio")

    # Launch subprocess-based GPU workers. CUDA_VISIBLE_DEVICES is set in the
    # subprocess Popen environment so it takes effect before import torch.
    pool = GPUWorkerPool(pool_specs, args.feature_extractor_path, args.decoder_path)
    logging.info(f"Submitting tasks... ({num_processes} subprocess workers)")
    try:
        futures = set()
        # Back-pressure: keep at most 2 pending batches per worker.
        max_pending = num_processes * 2

        def drain_completed():
            """Block until at least one future finishes, then consume it."""
            nonlocal futures
            done, _ = wait(futures, return_when=FIRST_COMPLETED)
            for f in done:
                futures.discard(f)
                result = f.result()
                main_progress.update(result["size"])
                handle_result(result)
                main_progress.set_postfix(
                    OK=processed_count,
                    Err=error_count,
                )

        for batch in dataloader:
            if batch.size == 0:
                continue
            if len(futures) >= max_pending:
                drain_completed()
            futures.add(pool.submit(batch))

        logging.info("Processing remaining pending batches...")
        while futures:
            drain_completed()

    except Exception:
        logging.error("Critical error during processing", exc_info=True)
        raise
    finally:
        pool.shutdown()
        main_progress.close()
        if tar_writer is not None:
            tar_writer.close()
        if jsonl_file is not None:
            jsonl_file.close()
        # Record the final (still-open) shard in the manifest map.
        if shard_idx > 0 and shard_sample_count > 0:
            last_idx = shard_idx - 1
            shard_manifest[last_idx] = (
                os.path.abspath(tar_output_pattern % last_idx),
                os.path.abspath(jsonl_output_pattern % last_idx),
                shard_sample_count,
                shard_duration,
            )

    # ── Write manifest (data.lst) ──
    with open(manifest_path, "w", encoding="utf-8") as mf:
        for idx in sorted(shard_manifest.keys()):
            tar_path, jsonl_path, count, duration = shard_manifest[idx]
            mf.write(f"{tar_path} {jsonl_path} {count} {duration:.3f}\n")

    # ── Summary ──
    total_failed = error_count + write_error_count
    filtered_and_skipped = total_samples - processed_count - total_failed
    logging.info(
        f"Processing Complete - Successful: {processed_count}, Failed: {total_failed}, "
        f"Filtered/Skipped: {filtered_and_skipped}, Shards written: {shard_idx}"
    )
    logging.info(f"Manifest written to: {manifest_path} ({len(shard_manifest)} shards)")
    if total_failed > 0:
        logging.info(f"Error details: {error_log_path}")
    if failed_ids and args.skip_errors:
        logging.warning(
            f"Failed sample IDs (count: {len(failed_ids)}): {failed_ids[:100]}..."
        )
    if write_error_count > 0 and not args.skip_errors:
        raise RuntimeError(
            f"{write_error_count} samples failed to write - check logs for details"
        )
|
| 1045 |
+
|
| 1046 |
+
|
| 1047 |
+
# Script entry point: parse CLI arguments and run the denoising pipeline
# (or the worker loop when invoked with --_subprocess_worker).
if __name__ == "__main__":
    main()
|
omnivoice/scripts/extract_audio_tokens.py
ADDED
|
@@ -0,0 +1,625 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""
|
| 19 |
+
Extract audio tokens from audio data and pack them into WebDataset shards.
|
| 20 |
+
|
| 21 |
+
Supports two input modes:
|
| 22 |
+
|
| 23 |
+
1. WebDataset manifest (data.lst):
|
| 24 |
+
python extract_audio_tokens.py \
|
| 25 |
+
--input_manifest data.lst \
|
| 26 |
+
--tar_output_pattern output/audios/shard-%06d.tar \
|
| 27 |
+
--jsonl_output_pattern output/txts/shard-%06d.jsonl
|
| 28 |
+
|
| 29 |
+
2. Raw JSONL (each line: {"id": "...", "audio_path": "...", "text": "...", ...}):
|
| 30 |
+
python extract_audio_tokens.py \
|
| 31 |
+
--input_jsonl data.jsonl \
|
| 32 |
+
--tar_output_pattern output/audios/shard-%06d.tar \
|
| 33 |
+
--jsonl_output_pattern output/txts/shard-%06d.jsonl
|
| 34 |
+
|
| 35 |
+
Output structure:
|
| 36 |
+
output_dir/
|
| 37 |
+
├── audios/ # WebDataset tar shards (.npy audio tokens + .json metadata)
|
| 38 |
+
│ ├── shard_000000.tar
|
| 39 |
+
│ └── ...
|
| 40 |
+
├── txts/ # Per-shard JSONL metadata
|
| 41 |
+
│ ├── shard_000000.jsonl
|
| 42 |
+
│ └── ...
|
| 43 |
+
├── data.lst # Manifest: <tar_path> <jsonl_path> <sample_count> <total_duration>
|
| 44 |
+
└── errors.jsonl # Failed samples with error details
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
import argparse
|
| 48 |
+
import io
|
| 49 |
+
import json
|
| 50 |
+
import logging
|
| 51 |
+
import multiprocessing as mp
|
| 52 |
+
import os
|
| 53 |
+
import warnings
|
| 54 |
+
from concurrent.futures import FIRST_COMPLETED, ProcessPoolExecutor, wait
|
| 55 |
+
from pathlib import Path
|
| 56 |
+
from typing import Any
|
| 57 |
+
|
| 58 |
+
import numpy as np
|
| 59 |
+
import torch
|
| 60 |
+
import webdataset as wds
|
| 61 |
+
from torch.utils.data import DataLoader, IterableDataset
|
| 62 |
+
from tqdm.auto import tqdm
|
| 63 |
+
from transformers import AutoFeatureExtractor, HiggsAudioV2TokenizerModel
|
| 64 |
+
|
| 65 |
+
from omnivoice.data.dataset import JsonlDatasetReader, WebDatasetReader
|
| 66 |
+
from omnivoice.utils.common import str2bool
|
| 67 |
+
|
| 68 |
+
warnings.filterwarnings(
|
| 69 |
+
"ignore", category=FutureWarning, module="torch.nn.utils.weight_norm"
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
# The Higgs audio tokenizer consumes 24 kHz mono audio; all readers resample to this.
HIGGS_INPUT_SAMPLE_RATE = 24_000


# Global variables: Store tokenizer and device for each worker process.
# Populated once per spawned worker by process_init() so the (expensive)
# model load happens per process, not per sample.
worker_tokenizer = None
worker_feature_extractor = None
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def build_parser() -> argparse.ArgumentParser:
    """Build the CLI argument parser for audio-token extraction.

    Exactly one of ``--input_manifest`` / ``--input_jsonl`` must be given
    (enforced by the caller, not by argparse).
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--input_manifest",
        default=None,
        help="Path to input dataset manifest (data.lst).",
    )
    parser.add_argument(
        "--input_jsonl",
        default=None,
        help="Path to raw JSONL file (alternative to --input_manifest).",
    )
    parser.add_argument(
        "--tar_output_pattern",
        required=True,
        help="Tar shard pattern passed to WebDataset",
    )
    parser.add_argument(
        "--jsonl_output_pattern",
        required=True,
        help="Jsonl shard pattern passed to WebDataset",
    )
    parser.add_argument(
        "--samples_per_shard",
        type=int,
        default=1000,
        help="Maximum records per shard",
    )
    parser.add_argument(
        "--min_num_shards",
        type=int,
        default=32,
        help="Minimum number of output shards (use to ensure "
        "shard count >= num_gpu * num_workers)",
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default="eustlb/higgs-audio-v2-tokenizer",
        help="Path to audio tokenizer.",
    )
    parser.add_argument(
        "--skip_errors", action="store_true", help="Skip items that fail to process"
    )
    parser.add_argument(
        "--min_length",
        type=float,
        default=0.0,
        help="Minimum audio duration in seconds (e.g. 2.0)",
    )
    parser.add_argument(
        "--max_length",
        type=float,
        default=float("inf"),
        help="Maximum audio duration in seconds (e.g. 15.0)",
    )
    parser.add_argument(
        "--num_machines",
        type=int,
        default=1,
        help="Total number of machines for distributed runs",
    )
    parser.add_argument(
        "--machine_index",
        type=int,
        default=0,
        help="Zero-based machine index when distributing across multiple "
        "machines (e.g. 0, 1, ... num_machines-1)",
    )
    parser.add_argument(
        "--nj_per_gpu",
        type=int,
        default=3,
        help="Number of worker processes to spawn per GPU.",
    )
    parser.add_argument(
        "--loader_workers",
        type=int,
        default=24,
        help="Number of DataLoader workers for streaming IterableDataset.",
    )
    parser.add_argument(
        "--shuffle",
        type=str2bool,
        default=True,
        help="Shuffle data by default.",
    )
    # Every other flag uses underscores; accept "--shuffle_seed" for
    # consistency while keeping the original "--shuffle-seed" spelling.
    # Both map to dest "shuffle_seed", so args.shuffle_seed is unchanged.
    parser.add_argument(
        "--shuffle_seed",
        "--shuffle-seed",
        type=int,
        default=42,
        help="Random seed for shuffle (default: 42).",
    )
    return parser
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def count_lines(path):
    """Count newline characters in *path*, scanning in 1 MiB binary chunks.

    Note: a final line without a trailing newline is not counted.
    """
    total = 0
    with open(path, "rb") as fh:
        while chunk := fh.read(1 << 20):
            total += chunk.count(b"\n")
    return total
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def serialise_numpy(key: str, tokens: np.ndarray) -> dict:
    """Pack *tokens* as an in-memory ``.npy`` payload for a WebDataset record.

    Returns a record dict with "__key__" (sample key) and "npy" (raw bytes).
    """
    payload = io.BytesIO()
    np.save(payload, tokens)
    return {"__key__": key, "npy": payload.getvalue()}
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def process_init(rank_queue, tokenizer_path):
    """
    Initialization function for each worker process.
    Assigns a specific GPU to the process and loads the tokenizer.

    Args:
        rank_queue: shared queue pre-filled by main() with GPU indices
            (one entry per worker process; -1 means "use CPU").
        tokenizer_path: HF model id or local path for the audio tokenizer.

    Side effects: sets the module-level ``worker_tokenizer`` and
    ``worker_feature_extractor`` globals used by process_single_sample().
    """
    global worker_tokenizer, worker_feature_extractor

    # Configure worker process logging (force=True overrides any handler
    # inherited from the spawn context).
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d]"
        " [Worker %(process)d] %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    # Get assigned GPU rank; blocks until main() has queued one for us.
    rank = rank_queue.get()
    # Determine device (-1 sentinel or no CUDA -> CPU)
    if rank != -1 and torch.cuda.is_available():
        worker_device = torch.device(f"cuda:{rank}")
    else:
        worker_device = torch.device("cpu")

    logging.debug(f"Worker process initialized with device: {worker_device}")
    # Load tokenizer onto the specified device
    worker_feature_extractor = AutoFeatureExtractor.from_pretrained(tokenizer_path)
    worker_tokenizer = HiggsAudioV2TokenizerModel.from_pretrained(
        tokenizer_path, device_map=worker_device
    )
    logging.debug(f"Tokenizer loaded successfully on device {worker_device}")
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def process_single_sample(sample: dict[str, Any]) -> dict[str, Any]:
    """
    Single-sample processing function executed in worker processes.
    Skips invalid samples during streaming processing.

    Assumes ``sample`` carries an "audio" tensor of shape (1, T) at
    HIGGS_INPUT_SAMPLE_RATE and a "label" dict with at least an "id" key
    (inferred from usage here — confirm against the dataset readers).

    Returns a result dict: status ("success"/"error"), key, audio_tokens
    (int16 ndarray shaped (8, num_tokens) on success, else None), metadata
    (the sample's label dict with "num_tokens" added), error_msg.
    Never raises: all failures are converted into an "error" result.
    """
    try:
        audio_tensor = sample.get("audio", None)  # shape (1, T)
        if audio_tensor is None:
            raise ValueError("Sample missing 'audio' field")

        with torch.inference_mode():
            key = sample["label"]["id"]
            inputs = worker_feature_extractor(
                raw_audio=audio_tensor.squeeze(0).numpy(),
                sampling_rate=HIGGS_INPUT_SAMPLE_RATE,
                return_tensors="pt",
            ).to(worker_tokenizer.device)
            audio_tokens = worker_tokenizer.encode(
                inputs["input_values"],
            ).audio_codes.squeeze(0)

        # Tokenizer is expected to emit exactly 8 codebooks per frame.
        assert len(audio_tokens.shape) == 2
        assert audio_tokens.size(0) == 8

        num_tokens = audio_tokens.size(1)
        # NOTE: mutates the sample's label dict in place to record the length.
        metadata = sample["label"]
        metadata["num_tokens"] = num_tokens

        # Convert to numpy format for subsequent serialization (int16 to save space)
        audio_tokens_np = audio_tokens.to(torch.int16).cpu().numpy()

        return {
            "status": "success",
            "key": key,
            "audio_tokens": audio_tokens_np,
            "metadata": metadata,
            "error_msg": None,
        }
    except Exception as e:
        sample_id = sample.get("label", {}).get("id", "unknown")
        logging.error(f"Failed to process sample {sample_id}: {e}")
        return {
            "status": "error",
            "key": sample_id,
            "audio_tokens": None,
            "metadata": None,
            "error_msg": str(e),
        }
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def _normalise_value(value: Any) -> Any:
|
| 269 |
+
"""Convert tensors and NumPy scalars to serialisable Python objects."""
|
| 270 |
+
if isinstance(value, torch.Tensor):
|
| 271 |
+
if value.ndim == 0:
|
| 272 |
+
return value.item()
|
| 273 |
+
return value.cpu().tolist()
|
| 274 |
+
if isinstance(value, np.generic):
|
| 275 |
+
return value.item()
|
| 276 |
+
if isinstance(value, np.ndarray):
|
| 277 |
+
return value.tolist()
|
| 278 |
+
return value
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def _encode_metadata(metadata: dict[str, Any]) -> bytes:
    """Serialise *metadata* to UTF-8 JSON bytes.

    None-valued entries are dropped; tensor/NumPy values are converted to
    plain Python objects via _normalise_value().
    """
    payload = {
        key: _normalise_value(value)
        for key, value in metadata.items()
        if value is not None
    }
    return json.dumps(payload, ensure_ascii=False).encode("utf-8")
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
class StreamingLengthFilteredDataset(IterableDataset):
    """Lazily drop samples whose audio duration is outside [min_len, max_len].

    Wraps any iterable of sample dicts (each holding an "audio" tensor and a
    "label" dict) and yields only those whose duration — last tensor
    dimension divided by the sample rate — lies within the bounds.
    Malformed samples are logged and skipped rather than raising.
    """

    def __init__(
        self,
        base_iterable,
        min_len: float,
        max_len: float,
        sr: int,
    ):
        self.base_iterable = base_iterable
        self.min_len = min_len
        self.max_len = max_len
        self.sr = sr
        # Out-of-range count seen by *this* process only; not aggregated
        # across DataLoader worker processes.
        self.filtered_count = 0

    def __iter__(self):
        """Stream samples one by one and filter on the fly."""
        for item in self.base_iterable:
            try:
                seconds = item["audio"].size(-1) / self.sr
                in_range = self.min_len <= seconds <= self.max_len
                if in_range:
                    yield item
                else:
                    self.filtered_count += 1
                    logging.warning(
                        f"Filtered sample (duration out of range): "
                        f"{item['label']['id']} ({seconds:.2f}s)"
                    )
            except Exception as err:
                logging.warning(f"Skipped invalid sample during streaming: {err}")
                continue
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def main() -> None:
    """Stream audio samples, tokenize them in a process pool, and pack the
    token arrays into WebDataset tar shards plus per-shard JSONL metadata
    and a top-level ``data.lst`` manifest.

    Pipeline: (reader -> duration filter -> DataLoader) feeds a
    ProcessPoolExecutor of GPU-pinned tokenizer workers; completed results
    are written to rotating shards on the main process.
    """
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)
    parser = build_parser()
    args = parser.parse_args()
    # "spawn" so CUDA can be initialized safely inside worker processes.
    mp.set_start_method("spawn", force=True)

    # Validate input arguments
    assert bool(args.input_manifest) != bool(
        args.input_jsonl
    ), "Exactly one of --input_manifest or --input_jsonl must be provided."

    if args.num_machines > 1:
        assert (
            0 <= args.machine_index < args.num_machines
        ), f"machine_index {args.machine_index} must be in [0, {args.num_machines})"

    # Build base dataset and count total samples based on input mode
    if args.input_jsonl:
        logging.info(f"Input mode: raw JSONL ({args.input_jsonl})")
        total_samples = count_lines(args.input_jsonl)
        base_dataset = JsonlDatasetReader(
            args.input_jsonl,
            sample_rate=HIGGS_INPUT_SAMPLE_RATE,
            shuffle=args.shuffle,
            shuffle_seed=args.shuffle_seed,
        )
        loader_workers = args.loader_workers
    else:
        logging.info(f"Input mode: WebDataset manifest ({args.input_manifest})")
        manifest_num_lines = count_lines(args.input_manifest)
        # Each loader worker handles whole shards; more workers than shards is wasted.
        loader_workers = min(args.loader_workers, manifest_num_lines)
        total_samples = 0
        manifests = []
        with open(args.input_manifest, "r", encoding="utf-8") as f:
            for line_id, line in tqdm(
                enumerate(f),
                total=manifest_num_lines,
                desc="Calculating dataset length",
            ):
                # Manifest line format: <tar_path> <jsonl_path> <count> <duration>
                items = line.strip().split(" ")
                tar_path, jsonl_path, num_items, duration = (
                    items[0],
                    items[1],
                    int(items[2]),
                    float(items[3]),
                )
                assert os.path.exists(tar_path), f"File {tar_path} does not exist."
                assert os.path.exists(jsonl_path), f"File {jsonl_path} does not exist."
                assert jsonl_path.endswith(
                    ".jsonl"
                ), f"File {jsonl_path} is not a .jsonl file."
                # Round-robin shard assignment across machines.
                if (
                    args.num_machines > 1
                    and line_id % args.num_machines != args.machine_index
                ):
                    continue
                total_samples += num_items
                manifests.append((tar_path, jsonl_path, num_items, duration))
        logging.info(
            f"Total shards: {manifest_num_lines}, "
            f"Shards for current index: {len(manifests)}"
        )
        base_dataset = WebDatasetReader(
            manifests=manifests,
            sample_rate=HIGGS_INPUT_SAMPLE_RATE,
            evaluation=True,
        )

    # Adjust samples_per_shard if min_num_shards would be violated
    samples_per_shard = args.samples_per_shard
    if total_samples > 0:
        estimated_shards = max(
            1, (total_samples + samples_per_shard - 1) // samples_per_shard
        )
        if estimated_shards < args.min_num_shards:
            samples_per_shard = max(1, total_samples // args.min_num_shards)
            logging.info(
                f"Adjusted samples_per_shard from {args.samples_per_shard} to "
                f"{samples_per_shard} to meet min_num_shards={args.min_num_shards} "
                f"(total_samples={total_samples})"
            )

    # Apply length filter and create DataLoader
    filtered_dataset = StreamingLengthFilteredDataset(
        base_iterable=base_dataset,
        min_len=args.min_length,
        max_len=args.max_length,
        sr=HIGGS_INPUT_SAMPLE_RATE,
    )
    dataloader = DataLoader(
        dataset=filtered_dataset,
        batch_size=None,  # yield individual samples, no collation
        num_workers=loader_workers,
        persistent_workers=loader_workers > 0,
        pin_memory=False,
    )

    # Configure multi-GPU multi-process setup
    num_devices = torch.cuda.device_count()
    if num_devices == 0:
        logging.warning("No GPUs detected - using CPU for processing")
        num_processes = args.nj_per_gpu
    else:
        num_processes = num_devices * args.nj_per_gpu
    logging.info(
        f"GPU count: {num_devices}, Processes per GPU: {args.nj_per_gpu}, "
        f"Total processes: {num_processes}"
    )

    # Shared GPU rank queue for process assignment; each worker pops one
    # rank in process_init(). -1 entries signal "use CPU".
    manager = mp.Manager()
    rank_queue = manager.Queue()
    for rank in list(range(num_devices)) * args.nj_per_gpu:
        rank_queue.put(rank)
    if num_devices == 0:
        for _ in range(num_processes):
            rank_queue.put(-1)

    # Prepare output paths
    tar_output_pattern = str(Path(args.tar_output_pattern).expanduser())
    jsonl_output_pattern = str(Path(args.jsonl_output_pattern).expanduser())
    Path(tar_output_pattern).parent.mkdir(parents=True, exist_ok=True)
    Path(jsonl_output_pattern).parent.mkdir(parents=True, exist_ok=True)

    # Determine output directory from tar_output_pattern
    # (assumes patterns live in <output_dir>/audios/ and <output_dir>/txts/).
    output_dir = Path(tar_output_pattern).parent.parent
    error_log_path = str(output_dir / "errors.jsonl")
    manifest_path = str(output_dir / "data.lst")

    # Setup error logger (writes to errors.jsonl)
    error_logger = logging.getLogger("error_log")
    error_logger.setLevel(logging.ERROR)
    error_logger.handlers.clear()
    error_fh = logging.FileHandler(error_log_path, mode="w", encoding="utf-8")
    error_fh.setFormatter(logging.Formatter("%(message)s"))
    error_logger.addHandler(error_fh)

    # Progress and error tracking
    processed_count = 0
    error_count = 0
    write_error_count = 0
    failed_ids = []
    shard_idx = 0
    shard_sample_count = 0
    shard_duration = 0.0
    shard_manifest = {}  # shard_idx -> (tar_path, jsonl_path, count, duration)

    tar_writer = None
    jsonl_file = None

    def open_new_shard():
        # Close the current shard (if any), record it in the manifest, and
        # open the next tar + jsonl pair.
        nonlocal tar_writer, jsonl_file, shard_idx, shard_sample_count, shard_duration
        if tar_writer is not None:
            tar_writer.close()
        if jsonl_file is not None:
            jsonl_file.close()
        # Record manifest for the previous shard
        if shard_idx > 0 and shard_sample_count > 0:
            prev_idx = shard_idx - 1
            shard_manifest[prev_idx] = (
                os.path.abspath(tar_output_pattern % prev_idx),
                os.path.abspath(jsonl_output_pattern % prev_idx),
                shard_sample_count,
                shard_duration,
            )
        tar_fname = tar_output_pattern % shard_idx
        jsonl_fname = jsonl_output_pattern % shard_idx
        tar_writer = wds.TarWriter(tar_fname)
        jsonl_file = open(jsonl_fname, "w", encoding="utf-8")
        shard_idx += 1
        shard_sample_count = 0
        shard_duration = 0.0

    def write_sample(key, audio_tokens_np, metadata):
        # Append one token record + metadata line to the open shard.
        # Write failures are logged and counted, never raised here.
        nonlocal shard_sample_count, write_error_count, shard_duration
        assert tar_writer is not None and jsonl_file is not None
        try:
            token_record = serialise_numpy(key, audio_tokens_np)
            json_record = _encode_metadata(metadata)
            tar_writer.write(token_record)
            jsonl_file.write(json_record.decode("utf-8") + "\n")
            shard_sample_count += 1
            # NOTE(review): assumes metadata may carry "audio_duration";
            # missing keys contribute 0 to the shard's duration total.
            shard_duration += metadata.get("audio_duration", 0.0)
        except Exception as exc:
            write_error_count += 1
            failed_ids.append(key)
            error_logger.error(
                json.dumps({"id": key, "reason": str(exc)}, ensure_ascii=False)
            )
            logging.error(f"Write failed for sample {key}: {exc}")

    def handle_result(result):
        # Dispatch one worker result: write successes, log/record failures.
        nonlocal processed_count, error_count
        if result["status"] == "success":
            # Rotate shard if needed
            if tar_writer is None or shard_sample_count >= samples_per_shard:
                open_new_shard()
            write_sample(result["key"], result["audio_tokens"], result["metadata"])
            processed_count += 1
        else:
            error_count += 1
            failed_ids.append(result["key"])
            error_logger.error(
                json.dumps(
                    {"id": result["key"], "reason": result["error_msg"]},
                    ensure_ascii=False,
                )
            )
            if not args.skip_errors:
                raise RuntimeError(
                    f"Sample {result['key']} processing failed due "
                    f"to {result['error_msg']} - terminating"
                )
            logging.warning(
                f"Skipping failed sample {result['key']}: {result['error_msg']}"
            )

    main_progress = tqdm(total=total_samples, desc="Extracting Audio Tokens")

    try:
        with ProcessPoolExecutor(
            max_workers=num_processes,
            initializer=process_init,
            initargs=(rank_queue, args.tokenizer_path),
        ) as executor:
            logging.info(f"Submitting tasks... ({num_processes} workers)")
            futures = set()
            # Backpressure: cap in-flight tasks so samples (with audio
            # tensors) don't pile up in memory faster than workers drain them.
            max_pending = num_processes * 10

            def drain_completed():
                """Wait for at least one future to complete, process all done."""
                nonlocal futures
                done, _ = wait(futures, return_when=FIRST_COMPLETED)
                for f in done:
                    futures.discard(f)
                    result = f.result()
                    main_progress.update(1)
                    handle_result(result)
                    main_progress.set_postfix(
                        Samples=processed_count,
                        Errors=error_count,
                    )

            # Stream samples from DataLoader
            for sample in dataloader:
                if len(futures) >= max_pending:
                    drain_completed()

                future = executor.submit(process_single_sample, sample)
                futures.add(future)

            # Process remaining futures
            logging.info("Processing remaining pending samples...")
            while futures:
                drain_completed()

    except Exception:
        logging.error("Critical error during processing", exc_info=True)
        raise
    finally:
        main_progress.close()
        if tar_writer is not None:
            tar_writer.close()
        if jsonl_file is not None:
            jsonl_file.close()
        # Record the last shard in the manifest (open_new_shard only records
        # a shard when the *next* one opens, so the final shard is handled here).
        if shard_idx > 0 and shard_sample_count > 0:
            last_idx = shard_idx - 1
            shard_manifest[last_idx] = (
                os.path.abspath(tar_output_pattern % last_idx),
                os.path.abspath(jsonl_output_pattern % last_idx),
                shard_sample_count,
                shard_duration,
            )

    # Write manifest file (data.lst)
    with open(manifest_path, "w", encoding="utf-8") as mf:
        for idx in sorted(shard_manifest.keys()):
            tar_path, jsonl_path, count, duration = shard_manifest[idx]
            mf.write(f"{tar_path} {jsonl_path} {count} {duration:.3f}\n")

    # Output final statistics
    total_failed = error_count + write_error_count
    filtered_and_skipped = total_samples - processed_count - total_failed
    logging.info(
        f"Processing Complete - Successful: {processed_count}, Failed: {total_failed}, "
        f"Filtered/Skipped: {filtered_and_skipped}, Shards written: {shard_idx}"
    )
    logging.info(f"Manifest written to: {manifest_path} ({len(shard_manifest)} shards)")
    if total_failed > 0:
        logging.info(f"Error details: {error_log_path}")
        if failed_ids and args.skip_errors:
            logging.warning(
                f"Failed sample IDs (count: {len(failed_ids)}): {failed_ids[:100]}..."
            )
    if write_error_count > 0 and not args.skip_errors:
        raise RuntimeError(
            f"{write_error_count} samples failed to write - check logs for details"
        )
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
if __name__ == "__main__":
    # Script entry point: parse CLI args and run the extraction pipeline.
    main()
|
omnivoice/scripts/extract_audio_tokens_add_noise.py
ADDED
|
@@ -0,0 +1,825 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""
|
| 19 |
+
Extract audio tokens from audio data and pack them into WebDataset shards.
|
| 20 |
+
|
| 21 |
+
Extends ``extract_audio_tokens.py`` with optional noise and reverberation
|
| 22 |
+
augmentation on the prompt (reference) portion of the audio. Requires a
|
| 23 |
+
noise manifest and/or RIR manifest.
|
| 24 |
+
|
| 25 |
+
Supports two input modes:
|
| 26 |
+
|
| 27 |
+
1. WebDataset manifest (data.lst):
|
| 28 |
+
python extract_audio_tokens_add_noise.py \\
|
| 29 |
+
--input_manifest data.lst \\
|
| 30 |
+
--noise_manifest noise.lst \\
|
| 31 |
+
--tar_output_pattern output/audios/shard-%06d.tar \\
|
| 32 |
+
--jsonl_output_pattern output/txts/shard-%06d.jsonl
|
| 33 |
+
|
| 34 |
+
2. Raw JSONL (each line: {"id": "...", "audio_path": "...", "text": "...", ...}):
|
| 35 |
+
python extract_audio_tokens_add_noise.py \\
|
| 36 |
+
--input_jsonl data.jsonl \\
|
| 37 |
+
--noise_manifest noise.lst \\
|
| 38 |
+
--tar_output_pattern output/audios/shard-%06d.tar \\
|
| 39 |
+
--jsonl_output_pattern output/txts/shard-%06d.jsonl
|
| 40 |
+
|
| 41 |
+
Output structure:
|
| 42 |
+
output_dir/
|
| 43 |
+
├── audios/ # WebDataset tar shards (.npy audio tokens + .json metadata)
|
| 44 |
+
│ ├── shard_000000.tar
|
| 45 |
+
│ └── ...
|
| 46 |
+
├── txts/ # Per-shard JSONL metadata
|
| 47 |
+
│ ├── shard_000000.jsonl
|
| 48 |
+
│ └── ...
|
| 49 |
+
├── data.lst # Manifest: <tar_path> <jsonl_path> <sample_count> <total_duration>
|
| 50 |
+
└── errors.jsonl # Failed samples with error details
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
import argparse
|
| 54 |
+
import io
|
| 55 |
+
import json
|
| 56 |
+
import logging
|
| 57 |
+
import math
|
| 58 |
+
import multiprocessing as mp
|
| 59 |
+
import os
|
| 60 |
+
import random
|
| 61 |
+
import warnings
|
| 62 |
+
from concurrent.futures import FIRST_COMPLETED, ProcessPoolExecutor, wait
|
| 63 |
+
from pathlib import Path
|
| 64 |
+
from typing import Any
|
| 65 |
+
|
| 66 |
+
import numpy as np
|
| 67 |
+
import torch
|
| 68 |
+
import torch.nn.functional as F
|
| 69 |
+
import torchaudio
|
| 70 |
+
import webdataset as wds
|
| 71 |
+
from torch.utils.data import DataLoader, IterableDataset
|
| 72 |
+
from tqdm.auto import tqdm
|
| 73 |
+
from transformers import AutoFeatureExtractor, HiggsAudioV2TokenizerModel
|
| 74 |
+
|
| 75 |
+
from omnivoice.data.dataset import JsonlDatasetReader, WebDatasetReader
|
| 76 |
+
from omnivoice.utils.common import str2bool
|
| 77 |
+
|
| 78 |
+
warnings.filterwarnings(
|
| 79 |
+
"ignore", category=FutureWarning, module="torch.nn.utils.weight_norm"
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
HIGGS_INPUT_SAMPLE_RATE = 24_000
|
| 83 |
+
|
| 84 |
+
# Global variables: Store tokenizer and device for each worker process
|
| 85 |
+
worker_tokenizer = None
|
| 86 |
+
worker_feature_extractor = None
|
| 87 |
+
worker_noise_sampler = None
|
| 88 |
+
worker_rir_sampler = None
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the token-extraction script."""
    parser = argparse.ArgumentParser(description=__doc__)
    add = parser.add_argument

    # Input selection: main() requires exactly one of these two.
    add("--input_manifest", default=None,
        help="Path to input dataset manifest (data.lst).")
    add("--input_jsonl", default=None,
        help="Path to raw JSONL file (alternative to --input_manifest).")

    # Output sharding.
    add("--tar_output_pattern", required=True,
        help="Tar shard pattern passed to WebDataset")
    add("--jsonl_output_pattern", required=True,
        help="Jsonl shard pattern passed to WebDataset")
    add("--samples_per_shard", type=int, default=1000,
        help="Maximum records per shard")
    add("--min_num_shards", type=int, default=32,
        help="Minimum number of output shards (use to ensure "
             "shard count >= num_gpu * num_workers)")

    # Tokenizer, error policy and duration filtering.
    add("--tokenizer_path", type=str, default="eustlb/higgs-audio-v2-tokenizer",
        help="Path to audio tokenizer.")
    add("--skip_errors", action="store_true",
        help="Skip items that fail to process")
    add("--min_length", type=float, default=0.0,
        help="Minimum audio duration in seconds (e.g. 2.0)")
    add("--max_length", type=float, default=float("inf"),
        help="Maximum audio duration in seconds (e.g. 15.0)")

    # Distribution and parallelism.
    add("--num_machines", type=int, default=1,
        help="Total number of machines for distributed runs")
    add("--machine_index", type=int, default=0,
        help="Zero-based machine index when distributing across multiple "
             "machines (e.g. 0, 1, ... num_machines-1)")
    add("--nj_per_gpu", type=int, default=3,
        help="Number of worker processes to spawn per GPU.")
    add("--loader_workers", type=int, default=24,
        help="Number of DataLoader workers for streaming IterableDataset.")

    # Shuffling.
    add("--shuffle", type=str2bool, default=True,
        help="Shuffle data by default.")
    add("--shuffle-seed", type=int, default=42,
        help="Random seed for shuffle (default: 42).")

    # Optional prompt (reference audio) augmentation sources.
    add("--noise_manifest", default=None,
        help="Path to noise manifest (list of tar files). Enables prompt noise augmentation.")
    add("--rir_manifest", default=None,
        help="Path to RIR manifest (list of tar files). Enables prompt reverb augmentation.")
    return parser
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def count_lines(path):
    """Count newline characters in *path*, reading in 1 MiB binary chunks."""
    total = 0
    with open(path, "rb") as handle:
        while True:
            chunk = handle.read(1 << 20)
            if not chunk:
                break
            total += chunk.count(b"\n")
    return total
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def serialise_numpy(key: str, tokens: np.ndarray) -> dict:
    """Pack *tokens* into a WebDataset-style record keyed by *key*.

    The array is stored as raw .npy bytes under the "npy" field.
    """
    with io.BytesIO() as buffer:
        np.save(buffer, tokens)
        payload = buffer.getvalue()
    return {"__key__": key, "npy": payload}
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def _load_aug_audio(data, sample_rate=24000):
    """Simple audio loader for augmentation files.

    Decodes *data* (raw audio file bytes), downmixes to mono and resamples
    to *sample_rate* when needed. Returns a (1, T) waveform tensor.
    """
    buffer = io.BytesIO(data)
    try:
        wav, sr = torchaudio.load(buffer)
    finally:
        buffer.close()
    # Downmix multi-channel recordings by averaging channels.
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)
    # Resample only when the file's rate differs from the target.
    if sr != sample_rate:
        wav = torchaudio.functional.resample(wav, sr, sample_rate)
    return wav
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class SimpleWorkerSampler:
    """A lightweight infinite sampler for noise/RIR within a worker process."""

    def __init__(self, tar_paths, sample_rate=24000):
        # Build an endlessly repeating, lightly shuffled decode pipeline.
        pipeline = wds.WebDataset(
            tar_paths, shardshuffle=True, nodesplitter=None, workersplitter=None
        )
        pipeline = pipeline.decode()
        pipeline = pipeline.map(lambda raw: self._decode(raw, sample_rate))
        pipeline = pipeline.select(lambda wav: wav is not None)
        self.dataset = pipeline.shuffle(100).repeat()
        self.iterator = iter(self.dataset)

    def _decode(self, sample, sample_rate):
        # First matching audio extension wins; None marks undecodable records.
        found = next(
            (ext for ext in ["wav", "flac", "mp3"] if ext in sample), None
        )
        if found is None:
            return None
        return _load_aug_audio(sample[found], sample_rate)

    def sample_segment(self, target_len, allow_repeat=True):
        """Get a random segment of noise matching the target length."""
        try:
            audio = next(self.iterator)
        except StopIteration:
            # Defensive restart; .repeat() should normally prevent exhaustion.
            self.iterator = iter(self.dataset)
            audio = next(self.iterator)

        length = audio.size(-1)
        if length < target_len and allow_repeat:
            if length > 0:
                # Tile the clip until it covers the requested length.
                audio = audio.repeat(1, math.ceil(target_len / length))
            else:
                audio = F.pad(audio, (0, target_len), mode="constant")
            length = audio.size(-1)

        if length > target_len:
            # Crop a random window of exactly target_len samples.
            offset = random.randint(0, length - target_len)
            audio = audio[..., offset : offset + target_len]

        return audio
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def _convolve1d(signal: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
|
| 266 |
+
m = signal.size(-1)
|
| 267 |
+
n = kernel.size(-1)
|
| 268 |
+
padded_size = m + n - 1
|
| 269 |
+
f_signal = torch.fft.rfft(signal, n=padded_size)
|
| 270 |
+
f_kernel = torch.fft.rfft(kernel, n=padded_size)
|
| 271 |
+
f_result = f_signal * f_kernel
|
| 272 |
+
result = torch.fft.irfft(f_result, n=padded_size)
|
| 273 |
+
return result[:padded_size]
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def _apply_rir(audio, rir, mix_ratio=0.5):
    """Convolve *audio* (1, T) with a room impulse response and blend.

    The wet (reverberated) signal is aligned on the RIR's direct-path peak,
    energy-matched to the dry signal, then mixed with weight *mix_ratio*.
    """
    scale = 0.5**15
    n_samples = audio.shape[-1]
    impulse = rir[0, :] * scale
    wet = _convolve1d(audio[0], impulse)
    # Align on the strongest tap so the wet signal starts with the direct path.
    onset = torch.argmax(torch.abs(impulse))
    tail = onset + n_samples
    if tail > wet.shape[0]:
        wet = F.pad(wet[onset:], (0, tail - wet.shape[0]))
    else:
        wet = wet[onset:tail]
    # Match the wet signal's energy to the dry signal's.
    dry_power = torch.sum(audio[0] ** 2)
    wet_power = torch.sum(wet**2)
    if wet_power > 0:
        wet *= torch.sqrt(dry_power / wet_power)
    blended = (1 - mix_ratio) * audio[0] + mix_ratio * wet
    return blended.unsqueeze(0)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def process_init(rank_queue, tokenizer_path, noise_manifest=None, rir_manifest=None):
    """
    Initialization function for each worker process.
    Assigns a specific GPU to the process and loads the tokenizer.

    Args:
        rank_queue: Shared queue of GPU ranks; each worker pops one. A rank
            of -1 means "no GPU, use CPU" (main() enqueues -1 per worker
            when no GPUs are detected).
        tokenizer_path: HuggingFace path/ID of the audio tokenizer.
        noise_manifest: Optional manifest file whose first whitespace-separated
            field per line is a noise tar path; enables noise augmentation.
        rir_manifest: Same format as noise_manifest, for RIR tars; enables
            reverb augmentation.

    Side effects: populates the module-level worker_* globals used by
    process_single_sample() and _augment_prompt().
    """
    global worker_tokenizer, worker_feature_extractor, worker_noise_sampler, worker_rir_sampler

    # Configure worker process logging
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d]"
        " [Worker %(process)d] %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    # Get assigned GPU rank
    rank = rank_queue.get()
    # Determine device
    if rank != -1 and torch.cuda.is_available():
        worker_device = torch.device(f"cuda:{rank}")
    else:
        worker_device = torch.device("cpu")

    logging.debug(f"Worker process initialized with device: {worker_device}")
    # Load tokenizer onto the specified device
    worker_feature_extractor = AutoFeatureExtractor.from_pretrained(tokenizer_path)
    worker_tokenizer = HiggsAudioV2TokenizerModel.from_pretrained(
        tokenizer_path, device_map=worker_device
    )
    logging.debug(f"Tokenizer loaded successfully on device {worker_device}")

    # Initialize augmentation samplers (optional)
    # A failure here is non-fatal: the worker continues without that
    # augmentation source (sampler global stays None).
    if noise_manifest:
        try:
            with open(noise_manifest, "r") as f:
                tars = [l.strip().split()[0] for l in f if l.strip()]
            worker_noise_sampler = SimpleWorkerSampler(
                tars, sample_rate=HIGGS_INPUT_SAMPLE_RATE
            )
            logging.debug("Noise sampler initialized.")
        except Exception as e:
            logging.warning(f"Failed to load noise manifest: {e}")

    if rir_manifest:
        try:
            with open(rir_manifest, "r") as f:
                tars = [l.strip().split()[0] for l in f if l.strip()]
            worker_rir_sampler = SimpleWorkerSampler(
                tars, sample_rate=HIGGS_INPUT_SAMPLE_RATE
            )
            logging.debug("RIR sampler initialized.")
        except Exception as e:
            logging.warning(f"Failed to load RIR manifest: {e}")
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def _augment_prompt(audio_tensor: torch.Tensor) -> tuple[torch.Tensor, int]:
    """Apply noise/reverb augmentation to the front portion of audio.

    A random 10-30% prefix of the waveform is corrupted with additive noise
    (random SNR in [5, 15] dB) and, with 30% probability, reverberation.
    The waveform is peak-normalized before (to 0.6, leaving headroom for the
    added noise) and after (to 0.9).

    Args:
        audio_tensor: Waveform of shape (1, T) — assumes mono; TODO confirm.

    Returns:
        Tuple of (augmented audio, sample index where clean audio starts).
    """
    # Pre-normalization
    max_val = audio_tensor.abs().max() + 1e-7
    audio_tensor = (audio_tensor / max_val) * 0.6

    total_len = audio_tensor.size(-1)
    ratio = random.uniform(0.1, 0.3)
    split_idx = int(total_len * ratio)

    # Guard the degenerate case: for very short inputs split_idx can be 0,
    # in which case the original flow pulled a zero-length noise segment and
    # computed NaN RMS values (norm / 0**0.5). Skip augmentation entirely;
    # the normalization above/below still applies.
    if split_idx > 0:
        front_part = audio_tensor[:, :split_idx].clone()

        # Apply additive noise at a random SNR.
        if worker_noise_sampler is not None:
            noise = worker_noise_sampler.sample_segment(split_idx)
            snr_db = random.uniform(5, 15)
            sig_rms = front_part.norm(p=2) / (split_idx**0.5)
            noise_rms = noise.norm(p=2) / (split_idx**0.5)
            if noise_rms > 1e-9:
                snr = 10 ** (snr_db / 20)
                scale = sig_rms / (snr * noise_rms + 1e-8)
                front_part = front_part + noise * scale

        # Apply RIR (30% probability); failures are non-fatal.
        if worker_rir_sampler is not None and random.random() < 0.3:
            rir = worker_rir_sampler.sample_segment(split_idx, allow_repeat=False)
            reverb_amt = random.uniform(0.3, 1.0)
            try:
                front_part = _apply_rir(front_part, rir, reverb_amt)
            except Exception as e:
                logging.warning(f"RIR failed: {e}")

        # Merge the augmented prefix back into the waveform.
        if front_part.device != audio_tensor.device:
            front_part = front_part.to(audio_tensor.device)
        audio_tensor[:, :split_idx] = front_part

    # Post-normalization
    max_val = audio_tensor.abs().max() + 1e-7
    audio_tensor = (audio_tensor / max_val) * 0.9

    return audio_tensor, split_idx
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def process_single_sample(sample: dict[str, Any]) -> dict[str, Any]:
    """
    Single-sample processing function executed in worker processes.
    Skips invalid samples during streaming processing.

    Args:
        sample: Dict with an "audio" waveform tensor of shape (1, T) and a
            "label" dict containing at least "id".

    Returns:
        A result dict with "status" ("success"/"error"), "key" (sample id),
        "audio_tokens" (int16 ndarray of shape (8, num_tokens) on success),
        "metadata" (the sample's label dict, augmented in place), and
        "error_msg" (None on success).
    """
    try:
        audio_tensor = sample.get("audio", None)  # shape (1, T)
        if audio_tensor is None:
            raise ValueError("Sample missing 'audio' field")

        # Apply prompt augmentation if noise/rir samplers are available
        enable_aug = worker_noise_sampler is not None or worker_rir_sampler is not None
        clean_sample_idx = 0
        if enable_aug:
            audio_tensor, clean_sample_idx = _augment_prompt(audio_tensor)

        with torch.inference_mode():
            key = sample["label"]["id"]

            # Feature-extract the raw 24 kHz waveform, then encode it with
            # the Higgs audio tokenizer on this worker's device.
            inputs = worker_feature_extractor(
                raw_audio=audio_tensor.squeeze(0).numpy(),
                sampling_rate=HIGGS_INPUT_SAMPLE_RATE,
                return_tensors="pt",
            ).to(worker_tokenizer.device)
            audio_tokens = worker_tokenizer.encode(
                inputs["input_values"],
            ).audio_codes.squeeze(0)

        # Tokenizer is expected to emit 8 codebooks: (8, num_tokens).
        assert len(audio_tokens.shape) == 2
        assert audio_tokens.size(0) == 8

        num_tokens = audio_tokens.size(1)
        metadata = sample["label"]
        metadata["num_tokens"] = num_tokens

        if enable_aug:
            # Convert the clean-audio boundary from samples to token frames
            # using the tokenizer's hop length (ceil → first fully clean token).
            clean_token_idx = math.ceil(
                clean_sample_idx / worker_tokenizer.config.hop_length
            )
            metadata["clean_start_token_idx"] = clean_token_idx

        # Convert to numpy format for subsequent serialization (int16 to save space)
        audio_tokens_np = audio_tokens.to(torch.int16).cpu().numpy()

        return {
            "status": "success",
            "key": key,
            "audio_tokens": audio_tokens_np,
            "metadata": metadata,
            "error_msg": None,
        }
    except Exception as e:
        # Any failure is reported as an error record; the caller decides
        # whether to skip or abort (see --skip_errors handling in main()).
        sample_id = sample.get("label", {}).get("id", "unknown")
        logging.error(f"Failed to process sample {sample_id}: {e}")
        return {
            "status": "error",
            "key": sample_id,
            "audio_tokens": None,
            "metadata": None,
            "error_msg": str(e),
        }
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def _normalise_value(value: Any) -> Any:
|
| 459 |
+
"""Convert tensors and NumPy scalars to serialisable Python objects."""
|
| 460 |
+
if isinstance(value, torch.Tensor):
|
| 461 |
+
if value.ndim == 0:
|
| 462 |
+
return value.item()
|
| 463 |
+
return value.cpu().tolist()
|
| 464 |
+
if isinstance(value, np.generic):
|
| 465 |
+
return value.item()
|
| 466 |
+
if isinstance(value, np.ndarray):
|
| 467 |
+
return value.tolist()
|
| 468 |
+
return value
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def _encode_metadata(metadata: dict[str, Any]) -> bytes:
    """Serialise metadata to UTF-8 JSON bytes, dropping None-valued entries."""
    cleaned = {
        key: _normalise_value(value)
        for key, value in metadata.items()
        if value is not None
    }
    return json.dumps(cleaned, ensure_ascii=False).encode("utf-8")
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
class StreamingLengthFilteredDataset(IterableDataset):
    """Stream samples from a base iterable, dropping those whose audio
    duration falls outside [min_len, max_len] seconds.

    Duration is derived from the "audio" tensor length and sample rate ``sr``.
    """

    def __init__(
        self,
        base_iterable,
        min_len: float,
        max_len: float,
        sr: int,
    ):
        self.base_iterable = base_iterable
        self.min_len = min_len
        self.max_len = max_len
        self.sr = sr
        # Number of samples dropped for being out of the duration range.
        self.filtered_count = 0

    def __iter__(self):
        """Stream samples one by one and filter on the fly."""
        for item in self.base_iterable:
            try:
                seconds = item["audio"].size(-1) / self.sr
                if not (self.min_len <= seconds <= self.max_len):
                    self.filtered_count += 1
                    logging.warning(
                        f"Filtered sample (duration out of range): "
                        f"{item['label']['id']} ({seconds:.2f}s)"
                    )
                    continue
                yield item
            except Exception as exc:
                # Malformed samples (missing keys, bad tensors) are skipped.
                logging.warning(f"Skipped invalid sample during streaming: {exc}")
                continue
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def main() -> None:
|
| 513 |
+
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
| 514 |
+
logging.basicConfig(format=formatter, level=logging.INFO, force=True)
|
| 515 |
+
parser = build_parser()
|
| 516 |
+
args = parser.parse_args()
|
| 517 |
+
mp.set_start_method("spawn", force=True)
|
| 518 |
+
|
| 519 |
+
# Validate input arguments
|
| 520 |
+
assert bool(args.input_manifest) != bool(
|
| 521 |
+
args.input_jsonl
|
| 522 |
+
), "Exactly one of --input_manifest or --input_jsonl must be provided."
|
| 523 |
+
|
| 524 |
+
if args.num_machines > 1:
|
| 525 |
+
assert (
|
| 526 |
+
0 <= args.machine_index < args.num_machines
|
| 527 |
+
), f"machine_index {args.machine_index} must be in [0, {args.num_machines})"
|
| 528 |
+
|
| 529 |
+
# Build base dataset and count total samples based on input mode
|
| 530 |
+
if args.input_jsonl:
|
| 531 |
+
logging.info(f"Input mode: raw JSONL ({args.input_jsonl})")
|
| 532 |
+
total_samples = count_lines(args.input_jsonl)
|
| 533 |
+
base_dataset = JsonlDatasetReader(
|
| 534 |
+
args.input_jsonl,
|
| 535 |
+
sample_rate=HIGGS_INPUT_SAMPLE_RATE,
|
| 536 |
+
shuffle=args.shuffle,
|
| 537 |
+
shuffle_seed=args.shuffle_seed,
|
| 538 |
+
)
|
| 539 |
+
loader_workers = args.loader_workers
|
| 540 |
+
else:
|
| 541 |
+
logging.info(f"Input mode: WebDataset manifest ({args.input_manifest})")
|
| 542 |
+
manifest_num_lines = count_lines(args.input_manifest)
|
| 543 |
+
loader_workers = min(args.loader_workers, manifest_num_lines)
|
| 544 |
+
total_samples = 0
|
| 545 |
+
manifests = []
|
| 546 |
+
with open(args.input_manifest, "r", encoding="utf-8") as f:
|
| 547 |
+
for line_id, line in tqdm(
|
| 548 |
+
enumerate(f),
|
| 549 |
+
total=manifest_num_lines,
|
| 550 |
+
desc="Calculating dataset length",
|
| 551 |
+
):
|
| 552 |
+
items = line.strip().split(" ")
|
| 553 |
+
tar_path, jsonl_path, num_items, duration = (
|
| 554 |
+
items[0],
|
| 555 |
+
items[1],
|
| 556 |
+
int(items[2]),
|
| 557 |
+
float(items[3]),
|
| 558 |
+
)
|
| 559 |
+
assert os.path.exists(tar_path), f"File {tar_path} does not exist."
|
| 560 |
+
assert os.path.exists(jsonl_path), f"File {jsonl_path} does not exist."
|
| 561 |
+
assert jsonl_path.endswith(
|
| 562 |
+
".jsonl"
|
| 563 |
+
), f"File {jsonl_path} is not a .jsonl file."
|
| 564 |
+
if (
|
| 565 |
+
args.num_machines > 1
|
| 566 |
+
and line_id % args.num_machines != args.machine_index
|
| 567 |
+
):
|
| 568 |
+
continue
|
| 569 |
+
total_samples += num_items
|
| 570 |
+
manifests.append((tar_path, jsonl_path, num_items, duration))
|
| 571 |
+
logging.info(
|
| 572 |
+
f"Total shards: {manifest_num_lines}, "
|
| 573 |
+
f"Shards for current index: {len(manifests)}"
|
| 574 |
+
)
|
| 575 |
+
base_dataset = WebDatasetReader(
|
| 576 |
+
manifests=manifests,
|
| 577 |
+
sample_rate=HIGGS_INPUT_SAMPLE_RATE,
|
| 578 |
+
evaluation=True,
|
| 579 |
+
)
|
| 580 |
+
|
| 581 |
+
# Apply length filter and create DataLoader
|
| 582 |
+
filtered_dataset = StreamingLengthFilteredDataset(
|
| 583 |
+
base_iterable=base_dataset,
|
| 584 |
+
min_len=args.min_length,
|
| 585 |
+
max_len=args.max_length,
|
| 586 |
+
sr=HIGGS_INPUT_SAMPLE_RATE,
|
| 587 |
+
)
|
| 588 |
+
dataloader = DataLoader(
|
| 589 |
+
dataset=filtered_dataset,
|
| 590 |
+
batch_size=None,
|
| 591 |
+
num_workers=loader_workers,
|
| 592 |
+
persistent_workers=loader_workers > 0,
|
| 593 |
+
pin_memory=False,
|
| 594 |
+
)
|
| 595 |
+
|
| 596 |
+
# Adjust samples_per_shard if min_num_shards would be violated
|
| 597 |
+
samples_per_shard = args.samples_per_shard
|
| 598 |
+
if total_samples > 0:
|
| 599 |
+
estimated_shards = max(
|
| 600 |
+
1, (total_samples + samples_per_shard - 1) // samples_per_shard
|
| 601 |
+
)
|
| 602 |
+
if estimated_shards < args.min_num_shards:
|
| 603 |
+
samples_per_shard = max(1, total_samples // args.min_num_shards)
|
| 604 |
+
logging.info(
|
| 605 |
+
f"Adjusted samples_per_shard from {args.samples_per_shard} to "
|
| 606 |
+
f"{samples_per_shard} to meet min_num_shards={args.min_num_shards} "
|
| 607 |
+
f"(total_samples={total_samples})"
|
| 608 |
+
)
|
| 609 |
+
|
| 610 |
+
# Configure multi-GPU multi-process setup
|
| 611 |
+
num_devices = torch.cuda.device_count()
|
| 612 |
+
if num_devices == 0:
|
| 613 |
+
logging.warning("No GPUs detected - using CPU for processing")
|
| 614 |
+
num_processes = args.nj_per_gpu
|
| 615 |
+
else:
|
| 616 |
+
num_processes = num_devices * args.nj_per_gpu
|
| 617 |
+
logging.info(
|
| 618 |
+
f"GPU count: {num_devices}, Processes per GPU: {args.nj_per_gpu}, "
|
| 619 |
+
f"Total processes: {num_processes}"
|
| 620 |
+
)
|
| 621 |
+
if args.noise_manifest or args.rir_manifest:
|
| 622 |
+
logging.info(
|
| 623 |
+
f"Prompt augmentation enabled - "
|
| 624 |
+
f"noise: {args.noise_manifest or 'off'}, rir: {args.rir_manifest or 'off'}"
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
# Shared GPU rank queue for process assignment
|
| 628 |
+
manager = mp.Manager()
|
| 629 |
+
rank_queue = manager.Queue()
|
| 630 |
+
for rank in list(range(num_devices)) * args.nj_per_gpu:
|
| 631 |
+
rank_queue.put(rank)
|
| 632 |
+
if num_devices == 0:
|
| 633 |
+
for _ in range(num_processes):
|
| 634 |
+
rank_queue.put(-1)
|
| 635 |
+
|
| 636 |
+
# Prepare output paths
|
| 637 |
+
tar_output_pattern = str(Path(args.tar_output_pattern).expanduser())
|
| 638 |
+
jsonl_output_pattern = str(Path(args.jsonl_output_pattern).expanduser())
|
| 639 |
+
Path(tar_output_pattern).parent.mkdir(parents=True, exist_ok=True)
|
| 640 |
+
Path(jsonl_output_pattern).parent.mkdir(parents=True, exist_ok=True)
|
| 641 |
+
|
| 642 |
+
# Determine output directory from tar_output_pattern
|
| 643 |
+
output_dir = Path(tar_output_pattern).parent.parent
|
| 644 |
+
error_log_path = str(output_dir / "errors.jsonl")
|
| 645 |
+
manifest_path = str(output_dir / "data.lst")
|
| 646 |
+
|
| 647 |
+
# Setup error logger (writes to errors.jsonl)
|
| 648 |
+
error_logger = logging.getLogger("error_log")
|
| 649 |
+
error_logger.setLevel(logging.ERROR)
|
| 650 |
+
error_logger.handlers.clear()
|
| 651 |
+
error_fh = logging.FileHandler(error_log_path, mode="w", encoding="utf-8")
|
| 652 |
+
error_fh.setFormatter(logging.Formatter("%(message)s"))
|
| 653 |
+
error_logger.addHandler(error_fh)
|
| 654 |
+
|
| 655 |
+
# Progress and error tracking
|
| 656 |
+
processed_count = 0
|
| 657 |
+
error_count = 0
|
| 658 |
+
write_error_count = 0
|
| 659 |
+
failed_ids = []
|
| 660 |
+
shard_idx = 0
|
| 661 |
+
shard_sample_count = 0
|
| 662 |
+
shard_duration = 0.0
|
| 663 |
+
shard_manifest = {} # shard_idx -> (tar_path, jsonl_path, count, duration)
|
| 664 |
+
|
| 665 |
+
tar_writer = None
|
| 666 |
+
jsonl_file = None
|
| 667 |
+
|
| 668 |
+
def open_new_shard():
|
| 669 |
+
nonlocal tar_writer, jsonl_file, shard_idx, shard_sample_count, shard_duration
|
| 670 |
+
if tar_writer is not None:
|
| 671 |
+
tar_writer.close()
|
| 672 |
+
if jsonl_file is not None:
|
| 673 |
+
jsonl_file.close()
|
| 674 |
+
# Record manifest for the previous shard
|
| 675 |
+
if shard_idx > 0 and shard_sample_count > 0:
|
| 676 |
+
prev_idx = shard_idx - 1
|
| 677 |
+
shard_manifest[prev_idx] = (
|
| 678 |
+
os.path.abspath(tar_output_pattern % prev_idx),
|
| 679 |
+
os.path.abspath(jsonl_output_pattern % prev_idx),
|
| 680 |
+
shard_sample_count,
|
| 681 |
+
shard_duration,
|
| 682 |
+
)
|
| 683 |
+
tar_fname = tar_output_pattern % shard_idx
|
| 684 |
+
jsonl_fname = jsonl_output_pattern % shard_idx
|
| 685 |
+
tar_writer = wds.TarWriter(tar_fname)
|
| 686 |
+
jsonl_file = open(jsonl_fname, "w", encoding="utf-8")
|
| 687 |
+
shard_idx += 1
|
| 688 |
+
shard_sample_count = 0
|
| 689 |
+
shard_duration = 0.0
|
| 690 |
+
|
| 691 |
+
def write_sample(key, audio_tokens_np, metadata):
|
| 692 |
+
nonlocal shard_sample_count, write_error_count, shard_duration
|
| 693 |
+
assert tar_writer is not None and jsonl_file is not None
|
| 694 |
+
try:
|
| 695 |
+
token_record = serialise_numpy(key, audio_tokens_np)
|
| 696 |
+
json_record = _encode_metadata(metadata)
|
| 697 |
+
tar_writer.write(token_record)
|
| 698 |
+
jsonl_file.write(json_record.decode("utf-8") + "\n")
|
| 699 |
+
shard_sample_count += 1
|
| 700 |
+
shard_duration += metadata.get("audio_duration", 0.0)
|
| 701 |
+
except Exception as exc:
|
| 702 |
+
write_error_count += 1
|
| 703 |
+
failed_ids.append(key)
|
| 704 |
+
error_logger.error(
|
| 705 |
+
json.dumps({"id": key, "reason": str(exc)}, ensure_ascii=False)
|
| 706 |
+
)
|
| 707 |
+
logging.error(f"Write failed for sample {key}: {exc}")
|
| 708 |
+
|
| 709 |
+
def handle_result(result):
|
| 710 |
+
nonlocal processed_count, error_count
|
| 711 |
+
if result["status"] == "success":
|
| 712 |
+
# Rotate shard if needed
|
| 713 |
+
if tar_writer is None or shard_sample_count >= samples_per_shard:
|
| 714 |
+
open_new_shard()
|
| 715 |
+
write_sample(result["key"], result["audio_tokens"], result["metadata"])
|
| 716 |
+
processed_count += 1
|
| 717 |
+
else:
|
| 718 |
+
error_count += 1
|
| 719 |
+
failed_ids.append(result["key"])
|
| 720 |
+
error_logger.error(
|
| 721 |
+
json.dumps(
|
| 722 |
+
{"id": result["key"], "reason": result["error_msg"]},
|
| 723 |
+
ensure_ascii=False,
|
| 724 |
+
)
|
| 725 |
+
)
|
| 726 |
+
if not args.skip_errors:
|
| 727 |
+
raise RuntimeError(
|
| 728 |
+
f"Sample {result['key']} processing failed due "
|
| 729 |
+
f"to {result['error_msg']} - terminating"
|
| 730 |
+
)
|
| 731 |
+
logging.warning(
|
| 732 |
+
f"Skipping failed sample {result['key']}: {result['error_msg']}"
|
| 733 |
+
)
|
| 734 |
+
|
| 735 |
+
main_progress = tqdm(total=total_samples, desc="Extracting Audio Tokens")
|
| 736 |
+
|
| 737 |
+
try:
|
| 738 |
+
with ProcessPoolExecutor(
|
| 739 |
+
max_workers=num_processes,
|
| 740 |
+
initializer=process_init,
|
| 741 |
+
initargs=(
|
| 742 |
+
rank_queue,
|
| 743 |
+
args.tokenizer_path,
|
| 744 |
+
args.noise_manifest,
|
| 745 |
+
args.rir_manifest,
|
| 746 |
+
),
|
| 747 |
+
) as executor:
|
| 748 |
+
logging.info(f"Submitting tasks... ({num_processes} workers)")
|
| 749 |
+
futures = set()
|
| 750 |
+
max_pending = num_processes * 10
|
| 751 |
+
|
| 752 |
+
def drain_completed():
|
| 753 |
+
"""Wait for at least one future to complete, process all done."""
|
| 754 |
+
nonlocal futures
|
| 755 |
+
done, _ = wait(futures, return_when=FIRST_COMPLETED)
|
| 756 |
+
for f in done:
|
| 757 |
+
futures.discard(f)
|
| 758 |
+
result = f.result()
|
| 759 |
+
main_progress.update(1)
|
| 760 |
+
handle_result(result)
|
| 761 |
+
main_progress.set_postfix(
|
| 762 |
+
Samples=processed_count,
|
| 763 |
+
Errors=error_count,
|
| 764 |
+
)
|
| 765 |
+
|
| 766 |
+
# Stream samples from DataLoader
|
| 767 |
+
for sample in dataloader:
|
| 768 |
+
if len(futures) >= max_pending:
|
| 769 |
+
drain_completed()
|
| 770 |
+
|
| 771 |
+
future = executor.submit(process_single_sample, sample)
|
| 772 |
+
futures.add(future)
|
| 773 |
+
|
| 774 |
+
# Process remaining futures
|
| 775 |
+
logging.info("Processing remaining pending samples...")
|
| 776 |
+
while futures:
|
| 777 |
+
drain_completed()
|
| 778 |
+
|
| 779 |
+
except Exception:
|
| 780 |
+
logging.error("Critical error during processing", exc_info=True)
|
| 781 |
+
raise
|
| 782 |
+
finally:
|
| 783 |
+
main_progress.close()
|
| 784 |
+
if tar_writer is not None:
|
| 785 |
+
tar_writer.close()
|
| 786 |
+
if jsonl_file is not None:
|
| 787 |
+
jsonl_file.close()
|
| 788 |
+
# Record the last shard in the manifest
|
| 789 |
+
if shard_idx > 0 and shard_sample_count > 0:
|
| 790 |
+
last_idx = shard_idx - 1
|
| 791 |
+
shard_manifest[last_idx] = (
|
| 792 |
+
os.path.abspath(tar_output_pattern % last_idx),
|
| 793 |
+
os.path.abspath(jsonl_output_pattern % last_idx),
|
| 794 |
+
shard_sample_count,
|
| 795 |
+
shard_duration,
|
| 796 |
+
)
|
| 797 |
+
|
| 798 |
+
# Write manifest file (data.lst)
|
| 799 |
+
with open(manifest_path, "w", encoding="utf-8") as mf:
|
| 800 |
+
for idx in sorted(shard_manifest.keys()):
|
| 801 |
+
tar_path, jsonl_path, count, duration = shard_manifest[idx]
|
| 802 |
+
mf.write(f"{tar_path} {jsonl_path} {count} {duration:.3f}\n")
|
| 803 |
+
|
| 804 |
+
# Output final statistics
|
| 805 |
+
total_failed = error_count + write_error_count
|
| 806 |
+
filtered_and_skipped = total_samples - processed_count - total_failed
|
| 807 |
+
logging.info(
|
| 808 |
+
f"Processing Complete - Successful: {processed_count}, Failed: {total_failed}, "
|
| 809 |
+
f"Filtered/Skipped: {filtered_and_skipped}, Shards written: {shard_idx}"
|
| 810 |
+
)
|
| 811 |
+
logging.info(f"Manifest written to: {manifest_path} ({len(shard_manifest)} shards)")
|
| 812 |
+
if total_failed > 0:
|
| 813 |
+
logging.info(f"Error details: {error_log_path}")
|
| 814 |
+
if failed_ids and args.skip_errors:
|
| 815 |
+
logging.warning(
|
| 816 |
+
f"Failed sample IDs (count: {len(failed_ids)}): {failed_ids[:100]}..."
|
| 817 |
+
)
|
| 818 |
+
if write_error_count > 0 and not args.skip_errors:
|
| 819 |
+
raise RuntimeError(
|
| 820 |
+
f"{write_error_count} samples failed to write - check logs for details"
|
| 821 |
+
)
|
| 822 |
+
|
| 823 |
+
|
| 824 |
+
if __name__ == "__main__":
|
| 825 |
+
main()
|
omnivoice/scripts/jsonl_to_webdataset.py
ADDED
|
@@ -0,0 +1,439 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""
|
| 19 |
+
Pack a JSONL audio dataset into custom WebDataset shards
|
| 20 |
+
(paired .tar and .jsonl files).
|
| 21 |
+
|
| 22 |
+
Usage:
|
| 23 |
+
python jsonl_to_webdataset.py \
|
| 24 |
+
--input data.jsonl \
|
| 25 |
+
--output output_dir/ \
|
| 26 |
+
--workers 16 \
|
| 27 |
+
--threads 4 \
|
| 28 |
+
--shard-size 1000 \
|
| 29 |
+
--sr 24000
|
| 30 |
+
|
| 31 |
+
Input JSONL format (one JSON object per line):
|
| 32 |
+
{"id": "utt_001", "audio_path": "/data/wavs/001.wav", "text": "hello world", ...}
|
| 33 |
+
|
| 34 |
+
Required fields: "id", "audio_path", "text"
|
| 35 |
+
All other fields are preserved in the output metadata.
|
| 36 |
+
|
| 37 |
+
Output structure:
|
| 38 |
+
output_dir/
|
| 39 |
+
├── audios/ # WebDataset tar shards
|
| 40 |
+
│ ├── shard_000000.tar
|
| 41 |
+
│ ├── shard_000001.tar
|
| 42 |
+
│ └── ...
|
| 43 |
+
├── txts/ # Per-shard JSONL metadata (with audio_duration added)
|
| 44 |
+
│ ├── shard_000000.jsonl
|
| 45 |
+
│ ├── shard_000001.jsonl
|
| 46 |
+
│ └── ...
|
| 47 |
+
├── data.lst # Manifest: <tar_path> <jsonl_path> <sample_count> <total_duration>
|
| 48 |
+
└── errors.jsonl # Failed samples with error details
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
import argparse
|
| 52 |
+
import io
|
| 53 |
+
import json
|
| 54 |
+
import logging
|
| 55 |
+
import multiprocessing as mp
|
| 56 |
+
import os
|
| 57 |
+
import random
|
| 58 |
+
from concurrent.futures import (
|
| 59 |
+
FIRST_COMPLETED,
|
| 60 |
+
ProcessPoolExecutor,
|
| 61 |
+
ThreadPoolExecutor,
|
| 62 |
+
as_completed,
|
| 63 |
+
wait,
|
| 64 |
+
)
|
| 65 |
+
from itertools import islice
|
| 66 |
+
from pathlib import Path
|
| 67 |
+
|
| 68 |
+
import torchaudio
|
| 69 |
+
import webdataset as wds
|
| 70 |
+
from tqdm import tqdm
|
| 71 |
+
|
| 72 |
+
from omnivoice.utils.common import str2bool
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the packing script.

    All options mirror the parameters of ``pack_dataset``; defaults are
    chosen so the script can run with no arguments against ``data.jsonl``.
    """
    ap = argparse.ArgumentParser(
        description="Pack JSONL audio dataset into WebDataset shards."
    )
    # Input/output locations.
    ap.add_argument(
        "--input", type=str, default="data.jsonl", help="Path to input JSONL file"
    )
    ap.add_argument(
        "--output",
        type=str,
        default="emilia",
        help="Path to output directory",
    )
    # Parallelism knobs: processes handle shards, threads handle audio items.
    ap.add_argument(
        "--workers",
        type=int,
        default=16,
        help="Number of worker processes (default: 16)",
    )
    ap.add_argument(
        "--threads",
        type=int,
        default=4,
        help="Number of threads per worker process.",
    )
    # Sharding / audio parameters.
    ap.add_argument(
        "--shard-size",
        type=int,
        default=1000,
        help="Number of samples per shard (default: 1000)",
    )
    ap.add_argument(
        "--sr", type=int, default=24000, help="Target sample rate (default: 24000)"
    )
    # Shuffling (loads the whole JSONL into memory when enabled).
    ap.add_argument(
        "--shuffle",
        type=str2bool,
        default=True,
        help="Shuffle data by default.",
    )
    ap.add_argument(
        "--shuffle-seed",
        type=int,
        default=42,
        help="Random seed for shuffle (default: 42)",
    )
    # Duration filters, applied after decoding (see process_single_shard).
    ap.add_argument(
        "--min-duration",
        type=float,
        default=None,
        help="Filter out samples shorter than this (seconds).",
    )
    ap.add_argument(
        "--max-duration",
        type=float,
        default=None,
        help="Filter out samples >= this duration (seconds).",
    )
    return ap
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def read_jsonl(file_path):
    """Yield one parsed JSON object per non-empty line of *file_path*."""
    with open(file_path, "r", encoding="utf-8") as fin:
        for raw_line in fin:
            stripped = raw_line.strip()
            if not stripped:
                continue  # skip blank lines silently
            yield json.loads(stripped)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def chunked_reader(iterator, chunk_size):
    """Yield successive lists of up to *chunk_size* items from *iterator*.

    The final chunk may be shorter; nothing is yielded for an empty input.
    """
    source = iter(iterator)
    while True:
        block = list(islice(source, chunk_size))
        if not block:
            return
        yield block
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def process_audio_item(meta, target_sr):
    """Load, optionally resample, and FLAC-encode one audio sample.

    Returns ``{"ok": (wds_sample, meta)}`` on success or
    ``{"error": {...}}`` describing the failure. Mutates *meta* in place
    by adding ``"audio_duration"`` (seconds, measured at the original
    sample rate, before any resampling).
    """
    key = meta.get("id")
    audio_path = meta.get("audio_path")

    # Guard clause: both fields are mandatory for a usable sample.
    if not key or not audio_path:
        return {
            "error": {
                "id": key,
                "audio_path": audio_path,
                "reason": "missing id or audio_path",
            }
        }

    try:
        if not os.path.exists(audio_path):
            raise FileNotFoundError(f"{audio_path} not found")

        waveform, sr = torchaudio.load(audio_path)
        # Duration from the source rate, before resampling changes frame count.
        meta["audio_duration"] = waveform.shape[1] / sr

        if target_sr and sr != target_sr:
            waveform = torchaudio.functional.resample(waveform, sr, target_sr)
            sr = target_sr

        # Encode in-memory as 16-bit FLAC for the WebDataset tar entry.
        encoded = io.BytesIO()
        torchaudio.save(encoded, waveform, sr, format="flac", bits_per_sample=16)

        wds_sample = {
            "__key__": key,
            "flac": encoded.getvalue(),
        }
        return {"ok": (wds_sample, meta)}

    except Exception as e:
        # Any failure (I/O, decode, encode) is reported, never raised,
        # so one bad file cannot kill a whole shard.
        return {"error": {"id": key, "audio_path": audio_path, "reason": str(e)}}
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def process_single_shard(
    shard_idx,
    records,
    output_tar_pattern,
    output_jsonl_pattern,
    target_sr,
    num_threads=4,
    min_duration=None,
    max_duration=None,
):
    """Pack one shard: decode *records* in a thread pool and write a
    paired ``.tar`` (FLAC audio) and ``.jsonl`` (metadata) file.

    Runs inside a worker process. Samples are written in thread
    *completion* order, not input order, so tar entry order is
    nondeterministic. Duration filters are applied after decoding,
    using the real ``audio_duration`` computed by ``process_audio_item``.

    Returns:
        (shard_idx, processed_count, error_count, filtered_count,
         total_duration, errors) where ``errors`` is a list of
        per-sample error dicts for the parent to log.
    """
    tar_fname = output_tar_pattern % shard_idx
    jsonl_fname = output_jsonl_pattern % shard_idx

    processed_count = 0
    filtered_count = 0
    error_count = 0
    total_duration = 0.0
    errors = []

    with wds.TarWriter(tar_fname) as sink, open(
        jsonl_fname, "w", encoding="utf-8"
    ) as jsonl_f:

        with ThreadPoolExecutor(max_workers=num_threads) as thread_pool:
            # Submit everything up front; a shard is at most a few
            # thousand records, so the futures list stays small.
            futures = []

            for meta in records:
                f = thread_pool.submit(process_audio_item, meta, target_sr)
                futures.append(f)

            for f in as_completed(futures):
                result = f.result()

                if "error" in result:
                    error_count += 1
                    errors.append(result["error"])
                    continue

                sample, meta = result["ok"]
                dur = meta.get("audio_duration", 0.0)

                # Duration filtering (based on actual audio_duration computed above)
                if min_duration is not None and dur < min_duration:
                    filtered_count += 1
                    continue
                if max_duration is not None and dur >= max_duration:
                    filtered_count += 1
                    continue

                sink.write(sample)

                jsonl_f.write(json.dumps(meta, ensure_ascii=False) + "\n")

                total_duration += dur
                processed_count += 1

    # Clean up empty shard files
    if processed_count == 0:
        for p in (tar_fname, jsonl_fname):
            if os.path.exists(p):
                os.remove(p)

    return (
        shard_idx,
        processed_count,
        error_count,
        filtered_count,
        total_duration,
        errors,
    )
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def count_lines(path):
    """Count the number of lines in the file at *path*.

    Reads in 1 MiB binary chunks so memory stays bounded on large files.
    Unlike a raw newline count, a final line that is not newline-terminated
    is still counted, so the result matches the number of records
    ``read_jsonl`` would yield for a well-formed JSONL file (the original
    implementation undercounted such files by one, skewing the shard
    estimate and progress total).
    """
    total = 0
    tail = b""
    with open(path, "rb") as f:
        for buf in iter(lambda: f.read(1 << 20), b""):
            total += buf.count(b"\n")
            tail = buf
    # A trailing partial line (no final "\n") is still one line.
    if tail and not tail.endswith(b"\n"):
        total += 1
    return total
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def pack_dataset(
    input_jsonl,
    output_dir,
    samples_per_shard=5000,
    num_workers=16,
    target_sr=24000,
    threads_per_worker=4,
    shuffle=False,
    shuffle_seed=None,
    min_duration=None,
    max_duration=None,
):
    """Pack a JSONL audio manifest into WebDataset shards.

    Orchestrates a pool of worker processes, each packing one shard via
    ``process_single_shard`` (which itself uses ``threads_per_worker``
    threads for audio decoding). Writes:

    - ``<output_dir>/audios/shard-NNNNNN.tar`` — FLAC samples
    - ``<output_dir>/txts/shard-NNNNNN.jsonl`` — per-sample metadata
    - ``<output_dir>/data.lst`` — manifest: tar, jsonl, count, duration
    - ``<output_dir>/errors.jsonl`` — one JSON record per failed sample

    Args:
        input_jsonl: Path to the input JSONL (``id``/``audio_path``/...).
        output_dir: Destination directory; subdirs are created as needed.
        samples_per_shard: Records handed to each shard task.
        num_workers: Worker processes; up to ``2 * num_workers`` shard
            tasks are kept in flight to bound memory.
        target_sr: Resample target in Hz (0/None disables resampling).
        threads_per_worker: Decode threads inside each worker process.
        shuffle: If True, loads the ENTIRE input into memory to shuffle.
        shuffle_seed: Seed for the shuffle.
        min_duration / max_duration: Post-decode duration filter bounds.
    """
    input_path = Path(input_jsonl)
    output_dir = Path(output_dir)
    output_tar_dir = output_dir / "audios"
    output_tar_dir.mkdir(parents=True, exist_ok=True)
    output_jsonl_dir = output_dir / "txts"
    output_jsonl_dir.mkdir(parents=True, exist_ok=True)

    output_tar_pattern = str(output_tar_dir / "shard-%06d.tar")
    output_jsonl_pattern = str(output_jsonl_dir / "shard-%06d.jsonl")

    error_log_path = str(output_dir / "errors.jsonl")

    # Setup error logger
    # NOTE(review): propagation to the root logger is not disabled, so
    # error records may also appear on the console if the root logger has
    # handlers — confirm whether that is intended.
    error_logger = logging.getLogger("error_log")
    error_logger.setLevel(logging.ERROR)
    error_logger.handlers.clear()
    fh = logging.FileHandler(error_log_path, mode="w", encoding="utf-8")
    fh.setFormatter(logging.Formatter("%(message)s"))
    error_logger.addHandler(fh)

    shard_manifest = {}

    print(f"Reading input: {input_path}")
    print(f"Output dir: {output_dir}")
    print(f"Strategy: {num_workers} Processes x {threads_per_worker} Threads")

    if shuffle:
        # Shuffling requires materializing the whole manifest in memory.
        print("Load input dataset...")
        entries = list(read_jsonl(input_path))
        random.seed(shuffle_seed)
        random.shuffle(entries)
        print(f"Shuffled {len(entries)} entries (seed={shuffle_seed})")
        total_lines = len(entries)
        chunk_gen = chunked_reader(iter(entries), samples_per_shard)
    else:
        # Streaming path: count lines cheaply, then read lazily.
        print("Calculating total lines...")
        total_lines = count_lines(input_path)
        chunk_gen = chunked_reader(read_jsonl(input_path), samples_per_shard)

    if min_duration is not None or max_duration is not None:
        print(
            f"Duration filter: [{min_duration or 0:.2f}s"
            f", {max_duration or float('inf'):.1f}s) (applied after audio decoding)"
        )

    # Ceiling division: estimate only; filtering may leave some shards empty.
    total_shards_est = (total_lines + samples_per_shard - 1) // samples_per_shard
    print(f"Total samples: {total_lines}, Estimated shards: {total_shards_est}")

    with ProcessPoolExecutor(max_workers=num_workers) as executor:

        futures = set()

        shard_idx = 0
        total_processed = 0
        total_errors = 0
        total_filtered = 0

        pbar = tqdm(
            total=total_shards_est,
            desc="Shards Processed",
            unit="shard",
        )

        def submit_next_chunks(limit):
            """Pull up to `limit` chunks from generator, submit them."""
            nonlocal shard_idx
            submitted = 0
            for chunk in chunk_gen:
                f = executor.submit(
                    process_single_shard,
                    shard_idx,
                    chunk,
                    output_tar_pattern,
                    output_jsonl_pattern,
                    target_sr,
                    threads_per_worker,
                    min_duration,
                    max_duration,
                )
                futures.add(f)
                shard_idx += 1
                submitted += 1
                if submitted >= limit:
                    break

        # Prime the pipeline with 2x workers so processes never starve.
        submit_next_chunks(num_workers * 2)

        while futures:
            done, _ = wait(futures, return_when=FIRST_COMPLETED)

            for f in done:
                futures.remove(f)

                try:
                    s_idx, p_count, e_count, f_count, s_duration, errors = f.result()
                    total_processed += p_count
                    total_errors += e_count
                    total_filtered += f_count

                    # Write error log
                    for err in errors:
                        err["shard_idx"] = s_idx
                        error_logger.error(json.dumps(err, ensure_ascii=False))

                    # Empty shards are cleaned up by the worker and must
                    # not enter the manifest.
                    if p_count > 0:
                        tar_abs = os.path.abspath(output_tar_pattern % s_idx)
                        jsonl_abs = os.path.abspath(output_jsonl_pattern % s_idx)
                        shard_manifest[s_idx] = (
                            tar_abs,
                            jsonl_abs,
                            p_count,
                            s_duration,
                        )

                    pbar.set_postfix(
                        {
                            "Samples": total_processed,
                            "Filtered": total_filtered,
                            "Errors": total_errors,
                        }
                    )
                    pbar.update(1)
                except Exception as e:
                    # A crashed shard task is reported but does not stop
                    # the run; its samples are simply missing from output.
                    print(f"Shard task failed: {e}")

                # Backfill one new shard task per completed one.
                submit_next_chunks(1)

    pbar.close()

    # Write final manifest file (data.lst)
    manifest_path = str(output_dir / "data.lst")
    with open(manifest_path, "w", encoding="utf-8") as mf:
        for idx in sorted(shard_manifest.keys()):
            tar_path, jsonl_path, count, duration = shard_manifest[idx]
            mf.write(f"{tar_path} {jsonl_path} {count} {duration:.3f}\n")

    print(f"\nDone! Output saved to {output_dir}")
    print(f"Successfully packed: {total_processed}")
    print(f"Filtered by duration: {total_filtered}")
    print(f"Failed: {total_errors}")
    print(f"Manifest written to: {manifest_path} ({len(shard_manifest)} shards)")
    if total_errors > 0:
        print(f"Error details: {error_log_path}")
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
if __name__ == "__main__":
    # "spawn" start method — presumably chosen for fork-safety with the
    # threaded audio workers; confirm before changing.
    mp.set_start_method("spawn", force=True)

    args = build_parser().parse_args()
    pack_dataset(
        input_jsonl=args.input,
        output_dir=args.output,
        samples_per_shard=args.shard_size,
        num_workers=args.workers,
        target_sr=args.sr,
        threads_per_worker=args.threads,
        shuffle=args.shuffle,
        shuffle_seed=args.shuffle_seed,
        min_duration=args.min_duration,
        max_duration=args.max_duration,
    )
|
omnivoice/training/__init__.py
ADDED
|
File without changes
|
omnivoice/training/builder.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""Builders for constructing training components.
|
| 19 |
+
|
| 20 |
+
Provides factory functions to assemble the model, tokenizer, and data loaders
|
| 21 |
+
from a ``TrainingConfig``. Called by ``omnivoice.cli.train`` to set up training.
|
| 22 |
+
|
| 23 |
+
Key functions:
|
| 24 |
+
- ``build_model_and_tokenizer()``: Loads the model and text tokenizer.
|
| 25 |
+
- ``build_dataloaders()``: Builds packed train/eval data loaders
|
| 26 |
+
from a data config JSON.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
import logging
|
| 30 |
+
from functools import partial
|
| 31 |
+
from typing import Tuple
|
| 32 |
+
|
| 33 |
+
import torch
|
| 34 |
+
from torch.utils.data import DataLoader
|
| 35 |
+
from transformers import AutoConfig, AutoModel, AutoTokenizer
|
| 36 |
+
from transformers import logging as hf_logging
|
| 37 |
+
from transformers.trainer_utils import seed_worker
|
| 38 |
+
|
| 39 |
+
from omnivoice.data.batching import PackingIterableDataset
|
| 40 |
+
from omnivoice.data.collator import PackingDataCollator
|
| 41 |
+
from omnivoice.data.dataset import WebDatasetReader, prepare_data_manifests_from_json
|
| 42 |
+
from omnivoice.data.processor import OmniVoiceSampleProcessor
|
| 43 |
+
from omnivoice.models.omnivoice import OmniVoice, OmniVoiceConfig
|
| 44 |
+
from omnivoice.training.config import TrainingConfig
|
| 45 |
+
|
| 46 |
+
logger = logging.getLogger(__name__)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def build_model_and_tokenizer(
    config: TrainingConfig,
) -> Tuple[OmniVoice, AutoTokenizer]:
    """Load the text tokenizer and OmniVoice model for training.

    When ``config.init_from_checkpoint`` is set, both tokenizer and model
    weights come from that checkpoint; otherwise a fresh OmniVoice is
    assembled around the base LLM at ``config.llm_name_or_path``. Special
    tokens are added if missing and embeddings are resized to match.

    Returns:
        (model, tokenizer) ready for the trainer.
    """
    logger.info("Initializing Model & Tokenizer...")

    # 1. Tokenizer
    tokenizer_path = (
        config.init_from_checkpoint
        if config.init_from_checkpoint
        else config.llm_name_or_path
    )
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Task/control tokens used by the data processor's prompt format.
    new_tokens = [
        "<|denoise|>",
        "<|lang_start|>",
        "<|lang_end|>",
        "<|instruct_start|>",
        "<|instruct_end|>",
        "<|text_start|>",
        "<|text_end|>",
    ]

    # Only add tokens that are genuinely missing, so reloading a
    # checkpoint tokenizer does not grow the vocabulary again.
    tokens_to_add = [t for t in new_tokens if t not in tokenizer.get_vocab()]
    if tokens_to_add:
        tokenizer.add_special_tokens({"additional_special_tokens": tokens_to_add})

    # 2. Model: resume from checkpoint or build from a base LLM.
    if config.init_from_checkpoint:
        logger.info(f"Loading weights from {config.init_from_checkpoint}")
        model = OmniVoice.from_pretrained(
            config.init_from_checkpoint,
            attn_implementation="flex_attention",
            dtype=torch.float32,
            train=True,
        )
    else:
        llm_config = AutoConfig.from_pretrained(config.llm_name_or_path)

        ov_config = OmniVoiceConfig(
            audio_vocab_size=config.audio_vocab_size,
            audio_mask_id=config.audio_mask_id,
            num_audio_codebook=config.num_audio_codebook,
            audio_codebook_weights=config.audio_codebook_weights,
            llm_config=llm_config,
        )

        original_level = hf_logging.get_verbosity()
        hf_logging.set_verbosity_error()  # suppress expected lm_head.weight warnings

        llm = AutoModel.from_pretrained(
            config.llm_name_or_path,
            attn_implementation="flex_attention",
            dtype=torch.float32,
        )

        hf_logging.set_verbosity(original_level)
        model = OmniVoice(config=ov_config, llm=llm)

    # 3. Resize Embeddings (keeps model vocab in sync with added tokens)
    if len(tokenizer) != model.config.llm_config.vocab_size:
        model.llm.resize_token_embeddings(len(tokenizer))
        model.config.llm_config.vocab_size = len(tokenizer)

    # 4. Config IDs
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.eos_token_id = tokenizer.eos_token_id

    return model, tokenizer
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def build_dataloaders(
    config: TrainingConfig, tokenizer: AutoTokenizer
) -> Tuple[DataLoader, DataLoader]:
    """Setup Data Pipeline: Manifests -> WDS -> Packing -> Loaders.

    Builds the packed training DataLoader and, when the data config
    declares dev manifests, an evaluation DataLoader.

    Returns:
        (train_loader, eval_loader) — note eval_loader is ``None`` when
        no dev manifests are configured, despite the declared type.
    """
    logger.info("Initializing Data Readers...")

    processor = OmniVoiceSampleProcessor(
        text_tokenizer=tokenizer,
        num_channels=config.num_audio_codebook,
        audio_mask_id=config.audio_mask_id,
        prompt_ratio_range=config.prompt_ratio_range,
        mask_ratio_range=config.mask_ratio_range,
        drop_cond_ratio=config.drop_cond_ratio,
        language_ratio=config.language_ratio,
        use_pinyin_ratio=config.use_pinyin_ratio,
        instruct_ratio=config.instruct_ratio,
        only_instruct_ratio=config.only_instruct_ratio,
    )

    train_manifests, dev_manifests = prepare_data_manifests_from_json(
        config.data_config
    )
    raw_train_ds = WebDatasetReader(manifests=train_manifests, evaluation=False)

    # Packing yields pre-batched items sized by config.batch_tokens.
    train_dataset = PackingIterableDataset(raw_train_ds, processor, config.batch_tokens)

    collate_fn = PackingDataCollator(processor, config.batch_tokens)

    # Seed each loader worker deterministically per distributed rank.
    init_fn = partial(
        seed_worker,
        num_workers=config.num_workers,
        rank=torch.distributed.get_rank() if torch.distributed.is_initialized() else 0,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=None,  # Each item is a batch packed to the target batch_tokens
        num_workers=config.num_workers,
        collate_fn=collate_fn,
        worker_init_fn=init_fn,
        pin_memory=True,
        prefetch_factor=4,
    )

    eval_loader = None
    if dev_manifests:
        raw_dev_ds = WebDatasetReader(manifests=dev_manifests, evaluation=True)
        dev_dataset = PackingIterableDataset(raw_dev_ds, processor, config.batch_tokens)
        eval_loader = DataLoader(
            dev_dataset,
            batch_size=None,  # Each item is a batch packed to the target batch_tokens
            num_workers=1,
            collate_fn=collate_fn,
            pin_memory=True,
            prefetch_factor=2,
        )

    return train_loader, eval_loader
|
omnivoice/training/checkpoint.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""Checkpoint saving, resuming, and training logging.
|
| 19 |
+
|
| 20 |
+
Provides utilities for saving/loading training checkpoints and logging metrics
|
| 21 |
+
to console and trackers (TensorBoard/WandB). Used by ``OmniTrainer``.
|
| 22 |
+
|
| 23 |
+
Key components:
|
| 24 |
+
- ``TrainLogger``: Logs training metrics to console and Accelerate trackers.
|
| 25 |
+
- ``save_checkpoint()``: Saves model, optimizer, and scheduler state.
|
| 26 |
+
- ``load_checkpoint()``: Restores training state from a checkpoint directory.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
import logging
|
| 30 |
+
import os
|
| 31 |
+
import shutil
|
| 32 |
+
import time
|
| 33 |
+
from typing import Any, Dict, Optional
|
| 34 |
+
|
| 35 |
+
import torch
|
| 36 |
+
from accelerate import Accelerator
|
| 37 |
+
from tqdm.auto import tqdm
|
| 38 |
+
|
| 39 |
+
logger = logging.getLogger(__name__)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class TrainLogger:
    """Console and tracker (TensorBoard/WandB) logging for the training loop.

    A tqdm progress bar is created only on the main process; metric lines are
    written through ``tqdm.write`` so they do not corrupt the bar.
    """

    def __init__(self, accelerator: Accelerator, total_steps: int, logging_steps: int):
        self.accelerator = accelerator
        self.total_steps = total_steps
        self.logging_steps = logging_steps
        # Populated by start(); stays None on non-main processes.
        self.start_time = None
        self.progress_bar = None

    def start(self, start_step: int = 0):
        """Begin timing and, on the main process, open the progress bar."""
        self.start_time = time.time()
        if not self.accelerator.is_main_process:
            return
        self.progress_bar = tqdm(
            total=self.total_steps,
            initial=start_step,
            desc="Training",
            dynamic_ncols=True,
            disable=not self.accelerator.is_local_main_process,
        )

    def update(
        self, step: int, loss: Optional[float] = None, lr: Optional[float] = None
    ):
        """Advance the progress bar by one step and refresh its live postfix."""
        if not self.progress_bar:
            return
        self.progress_bar.update(1)

        # Render whichever real-time metrics were supplied.
        formats = {"loss": "{:.4f}", "lr": "{:.2e}"}
        values = {"loss": loss, "lr": lr}
        postfix = {
            key: formats[key].format(val)
            for key, val in values.items()
            if val is not None
        }
        if postfix:
            self.progress_bar.set_postfix(postfix)

    def log_metrics(self, step: int, metrics: Dict[str, Any]):
        """Push ``metrics`` to the trackers and echo one line on the console."""
        # Trackers (TensorBoard, etc.) receive the raw values.
        self.accelerator.log(metrics, step=step)

        if not self.accelerator.is_main_process:
            return

        def _fmt(value):
            # Floats get 4 decimals; values that would render as "0.0000"
            # while being nonzero switch to scientific notation instead.
            if isinstance(value, float):
                fixed = f"{value:.4f}"
                if fixed == "0.0000" and value != 0:
                    return f"{value:.2e}"
                return fixed
            return f"{value}"

        pieces = [f"{key}: {_fmt(val)}" for key, val in metrics.items()]
        msg = f"Step {step} | " + " | ".join(pieces)
        # tqdm.write keeps the progress bar intact; otherwise use the logger.
        if self.progress_bar:
            self.progress_bar.write(msg)
        else:
            logger.info(msg)

    def close(self):
        """Dispose of the progress bar, if one was created."""
        if self.progress_bar:
            self.progress_bar.close()
|
| 118 |
+
def save_checkpoint(
    accelerator: Accelerator,
    model: torch.nn.Module,
    tokenizer: Any,
    output_dir: str,
    step: int,
    keep_last_n: int = 3,
):
    """
    Persist the full training state for ``step`` under ``output_dir``.

    Writes the accelerator state (optimizer/scheduler/RNG/scaler), the model
    in HuggingFace format, and the tokenizer, then prunes old checkpoints so
    that at most ``keep_last_n`` remain (``keep_last_n <= 0`` disables
    pruning).
    """
    ckpt_dir = os.path.join(output_dir, f"checkpoint-{step}")

    # 1. Accelerator state: optimizer, scheduler, RNG, grad scaler.
    accelerator.save_state(ckpt_dir)

    # 2. Model in HF format (config.json + weights).
    accelerator.unwrap_model(model).save_pretrained(
        ckpt_dir,
        is_main_process=accelerator.is_main_process,
        save_function=accelerator.save,
    )

    # 3. Tokenizer files (main process only — plain file writes).
    if accelerator.is_main_process:
        tokenizer.save_pretrained(ckpt_dir)

    logger.info(f"Saved checkpoint to {ckpt_dir}")

    # 4. Rotate: keep only the newest ``keep_last_n`` checkpoint directories.
    if accelerator.is_main_process and keep_last_n > 0:
        def _is_ckpt(name):
            return name.startswith("checkpoint-") and os.path.isdir(
                os.path.join(output_dir, name)
            )

        existing = sorted(
            (name for name in os.listdir(output_dir) if _is_ckpt(name)),
            key=lambda name: int(name.split("-")[-1]),
        )
        # Everything except the last keep_last_n entries is stale.
        for stale in existing[:-keep_last_n]:
            shutil.rmtree(os.path.join(output_dir, stale))
            logger.info(f"Removed old checkpoint {stale}")
| 166 |
+
|
| 167 |
+
def load_checkpoint(accelerator: Accelerator, checkpoint_path: str):
    """
    Restore optimizer/scheduler/model state from ``checkpoint_path``.

    Returns the step number parsed from the directory name
    (``.../checkpoint-<step>``), or 0 when no step can be parsed.
    """
    logger.info(f"Resuming from {checkpoint_path}")
    accelerator.load_state(checkpoint_path)

    # The step is encoded as the numeric suffix of the directory name.
    leaf = os.path.basename(os.path.normpath(checkpoint_path))
    try:
        return int(leaf.split("-")[-1])
    except ValueError:
        return 0
|
omnivoice/training/config.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""Training configuration dataclass.
|
| 19 |
+
|
| 20 |
+
Defines ``TrainingConfig``, a dataclass that holds all hyperparameters and paths
|
| 21 |
+
for training. Loaded from a JSON config file via ``TrainingConfig.from_json()``
|
| 22 |
+
in ``omnivoice.cli.train``.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import json
|
| 26 |
+
from dataclasses import asdict, dataclass, field
|
| 27 |
+
from typing import List, Optional, Tuple
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
|
| 31 |
+
class TrainingConfig:
|
| 32 |
+
# Key Paths
|
| 33 |
+
output_dir: Optional[str] = None
|
| 34 |
+
data_config: Optional[str] = None
|
| 35 |
+
|
| 36 |
+
# Model Specific
|
| 37 |
+
llm_name_or_path: str = "Qwen/Qwen3-0.6B"
|
| 38 |
+
audio_vocab_size: int = 1025 # valid vocab size + 1 (mask token)
|
| 39 |
+
audio_mask_id: int = 1024 # 1024 is the 1025-th token
|
| 40 |
+
num_audio_codebook: int = 8
|
| 41 |
+
|
| 42 |
+
# Model Training Specific
|
| 43 |
+
audio_codebook_weights: List[float | int] = field(
|
| 44 |
+
default_factory=lambda: [8, 8, 6, 6, 4, 4, 2, 2]
|
| 45 |
+
)
|
| 46 |
+
drop_cond_ratio: float = 0.1
|
| 47 |
+
prompt_ratio_range: Tuple[float, float] = field(default_factory=lambda: (0.0, 0.3))
|
| 48 |
+
mask_ratio_range: Tuple[float, float] = field(default_factory=lambda: (0.0, 1.0))
|
| 49 |
+
language_ratio: float = 0.8
|
| 50 |
+
use_pinyin_ratio: float = 0.3
|
| 51 |
+
instruct_ratio: float = 1.0
|
| 52 |
+
only_instruct_ratio: float = 0.5
|
| 53 |
+
|
| 54 |
+
# Init settings
|
| 55 |
+
resume_from_checkpoint: Optional[str] = None
|
| 56 |
+
init_from_checkpoint: Optional[str] = None
|
| 57 |
+
|
| 58 |
+
# Training Hyperparams
|
| 59 |
+
learning_rate: float = 1e-4
|
| 60 |
+
weight_decay: float = 0.01
|
| 61 |
+
max_grad_norm: float = 1.0
|
| 62 |
+
steps: int = 300000
|
| 63 |
+
seed: int = 42
|
| 64 |
+
lr_scheduler_type: str = "cosine"
|
| 65 |
+
warmup_type: str = "ratio"
|
| 66 |
+
warmup_ratio: float = 0.03
|
| 67 |
+
warmup_steps: int = 2000
|
| 68 |
+
|
| 69 |
+
# Data
|
| 70 |
+
batch_tokens: int = 8192
|
| 71 |
+
gradient_accumulation_steps: int = 1
|
| 72 |
+
num_workers: int = 8
|
| 73 |
+
|
| 74 |
+
# System
|
| 75 |
+
mixed_precision: str = "bf16"
|
| 76 |
+
allow_tf32: bool = True
|
| 77 |
+
use_deepspeed: bool = False
|
| 78 |
+
deepspeed_config: Optional[str] = None
|
| 79 |
+
|
| 80 |
+
# Logging
|
| 81 |
+
logging_steps: int = 100
|
| 82 |
+
eval_steps: int = 1000
|
| 83 |
+
save_steps: int = 10000
|
| 84 |
+
keep_last_n_checkpoints: int = -1
|
| 85 |
+
|
| 86 |
+
@classmethod
|
| 87 |
+
def from_json(cls, json_path: str):
|
| 88 |
+
with open(json_path, "r") as f:
|
| 89 |
+
cfg_dict = json.load(f)
|
| 90 |
+
valid_keys = cls.__annotations__.keys()
|
| 91 |
+
filtered_dict = {k: v for k, v in cfg_dict.items() if k in valid_keys}
|
| 92 |
+
instance = cls(**filtered_dict)
|
| 93 |
+
return instance
|
| 94 |
+
|
| 95 |
+
def save_to_json(self, json_path: str):
|
| 96 |
+
data = asdict(self)
|
| 97 |
+
with open(json_path, "w") as f:
|
| 98 |
+
json.dump(data, f, indent=4)
|
omnivoice/training/trainer.py
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""Training loop for OmniVoice.
|
| 19 |
+
|
| 20 |
+
Wraps the HuggingFace Accelerate training loop with checkpoint saving/resuming,
|
| 21 |
+
evaluation, gradient accumulation, and learning rate scheduling.
|
| 22 |
+
Launched via ``omnivoice.cli.train``.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import logging
|
| 26 |
+
import math
|
| 27 |
+
import os
|
| 28 |
+
import sys
|
| 29 |
+
import time
|
| 30 |
+
from datetime import timedelta
|
| 31 |
+
from typing import Any, Optional
|
| 32 |
+
|
| 33 |
+
import torch
|
| 34 |
+
from accelerate import Accelerator, DistributedDataParallelKwargs
|
| 35 |
+
from accelerate.utils import DeepSpeedPlugin, InitProcessGroupKwargs, set_seed
|
| 36 |
+
from torch.utils.data import DataLoader
|
| 37 |
+
from transformers import (
|
| 38 |
+
get_cosine_schedule_with_warmup,
|
| 39 |
+
get_constant_schedule_with_warmup,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
from omnivoice.training.checkpoint import TrainLogger, load_checkpoint
|
| 43 |
+
from omnivoice.training.checkpoint import save_checkpoint as engine_save_checkpoint
|
| 44 |
+
|
| 45 |
+
logger = logging.getLogger(__name__)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class OmniTrainer:
    """Step-based training loop built on HuggingFace Accelerate.

    Owns the Accelerator (including optional DeepSpeed), the optimizer and LR
    scheduler, periodic evaluation/checkpointing, and console/tracker logging.
    Training runs for ``config.steps`` optimizer steps, cycling through the
    dataloader across epochs as needed.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        config: Any,  # TrainingConfig
        train_dataloader: DataLoader,
        eval_dataloader: Optional[DataLoader] = None,
        tokenizer: Optional[Any] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        lr_scheduler: Optional[Any] = None,
    ):
        """Set up accelerator, optimizer/scheduler, and wrap the model.

        Args:
            model: the trainable module; must return an object with ``.loss``.
            config: a TrainingConfig-like object with the fields used below.
            train_dataloader: yields keyword-arg batches for ``model(**batch)``.
            eval_dataloader: optional; enables periodic evaluation.
            tokenizer: saved alongside each checkpoint.
            optimizer / lr_scheduler: pass both to override the defaults built
                by ``create_optimizer_and_scheduler``.
        """
        self.config = config
        self.model = model
        self.tokenizer = tokenizer
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader

        # 1. Initialize Accelerator
        self.accelerator = self._init_accelerator()

        # 2. Setup Optimizer & Scheduler if not provided
        if optimizer is None:
            self.optimizer, self.lr_scheduler = self.create_optimizer_and_scheduler()
        else:
            self.optimizer = optimizer
            self.lr_scheduler = lr_scheduler

        # 3. DeepSpeed Hack (Batch Size fix)
        # Batching is token-based upstream, so the per-GPU micro batch size
        # DeepSpeed insists on is forced to 1 here.
        if self.accelerator.distributed_type == "DEEPSPEED":
            self.accelerator.state.deepspeed_plugin.deepspeed_config[
                "train_micro_batch_size_per_gpu"
            ] = 1

        # 4. Prepare with Accelerator
        # NOTE(review): the dataloaders are deliberately NOT passed through
        # accelerator.prepare() — presumably sharding/device placement is
        # handled by the dataset/collator; confirm against the data pipeline.
        (self.model, self.optimizer, self.lr_scheduler,) = self.accelerator.prepare(
            self.model,
            self.optimizer,
            self.lr_scheduler,
        )

        self.global_step = 0  # number of completed optimizer steps
        self.epoch = 0  # number of dataloader exhaustions seen so far

    def _init_accelerator(self) -> Accelerator:
        """Initialize Accelerator, DeepSpeed, and Logging.

        Also seeds all RNGs and starts the TensorBoard tracker. Raises
        FileNotFoundError if DeepSpeed is enabled but its config is missing.
        """
        # TF32 setup
        if getattr(self.config, "allow_tf32", False):
            torch.set_float32_matmul_precision("high")

        # Init handlers
        # Long NCCL timeout: first-step compilation / data warmup can be slow.
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
        init_kwargs = InitProcessGroupKwargs(timeout=timedelta(minutes=60))

        # DeepSpeed setup
        deepspeed_plugin = None
        if self.config.use_deepspeed and self.config.deepspeed_config:
            if not os.path.exists(self.config.deepspeed_config):
                raise FileNotFoundError(
                    f"DeepSpeed config not found: {self.config.deepspeed_config}"
                )
            deepspeed_plugin = DeepSpeedPlugin(
                hf_ds_config=self.config.deepspeed_config,
                gradient_accumulation_steps=self.config.gradient_accumulation_steps,
                gradient_clipping=self.config.max_grad_norm,
            )

        accelerator = Accelerator(
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            mixed_precision=self.config.mixed_precision,
            log_with="tensorboard",
            project_dir=self.config.output_dir,
            # The scheduler is stepped manually once per optimizer step below.
            step_scheduler_with_optimizer=False,
            kwargs_handlers=[ddp_kwargs, init_kwargs],
            deepspeed_plugin=deepspeed_plugin,
            split_batches=False,
        )

        # Logging setup: main process logs to stdout + train.log at INFO;
        # other ranks only surface errors.
        if accelerator.is_main_process:
            os.makedirs(self.config.output_dir, exist_ok=True)
            # Try to save config if it has the method
            if hasattr(self.config, "save_to_json"):
                self.config.save_to_json(
                    os.path.join(self.config.output_dir, "initial_config.json")
                )

            logging.basicConfig(
                format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
                datefmt="%m/%d/%Y %H:%M:%S",
                level=logging.INFO,
                handlers=[
                    logging.StreamHandler(sys.stdout),
                    logging.FileHandler(
                        os.path.join(self.config.output_dir, "train.log")
                    ),
                ],
            )
        else:
            logging.basicConfig(level=logging.ERROR)

        logger.info(f"Loaded Config: {self.config}")
        set_seed(self.config.seed)
        accelerator.init_trackers("tensorboard")
        return accelerator

    def create_optimizer_and_scheduler(self):
        """Default AdamW + configurable LR Scheduler.

        Warmup length is ``steps * warmup_ratio`` when warmup_type == "ratio",
        otherwise the fixed ``warmup_steps``. Scheduler is constant-with-warmup
        for "constant", cosine-with-warmup for anything else.
        """
        optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
        )

        if self.config.warmup_type == "ratio":
            final_warmup_steps = math.ceil(self.config.steps * self.config.warmup_ratio)
        else:
            final_warmup_steps = self.config.warmup_steps

        if self.config.lr_scheduler_type == "constant":
            lr_scheduler = get_constant_schedule_with_warmup(
                optimizer=optimizer,
                num_warmup_steps=final_warmup_steps,
            )
        else:
            lr_scheduler = get_cosine_schedule_with_warmup(
                optimizer=optimizer,
                num_warmup_steps=final_warmup_steps,
                num_training_steps=self.config.steps,
            )
        return optimizer, lr_scheduler

    def save_checkpoint(self, step):
        """Wrapper for engine save_checkpoint; also drops a config copy."""
        engine_save_checkpoint(
            self.accelerator,
            self.model,
            self.tokenizer,
            self.config.output_dir,
            step,
            self.config.keep_last_n_checkpoints,
        )
        # Save config copy for convenience
        if self.accelerator.is_main_process and hasattr(self.config, "save_to_json"):
            checkpoint_dir = os.path.join(self.config.output_dir, f"checkpoint-{step}")
            self.config.save_to_json(os.path.join(checkpoint_dir, "train_config.json"))

    def load_checkpoint(self, checkpoint_path):
        """Restore training state and set ``global_step`` from the dir name."""
        step = load_checkpoint(self.accelerator, checkpoint_path)
        self.global_step = step
        logger.info(f"Resumed from step {self.global_step}")
        return step

    def evaluate(self):
        """Evaluation loop.

        Runs the full eval dataloader on every rank, averages the per-rank
        mean losses, logs ``eval/loss``, and puts the model back in train
        mode. Returns the metrics dict (empty if no eval dataloader).
        """
        if self.eval_dataloader is None:
            return {}

        self.model.eval()
        logger.info(f"Running evaluation at step {self.global_step}...")

        local_loss_sum = torch.tensor(0.0, device=self.accelerator.device)
        eval_count = 0

        with torch.no_grad():
            for eval_batch in self.eval_dataloader:
                # NOTE(review): eval batches are fed as-is — assumes tensors
                # are already on the right device; confirm against collator.
                outputs = self.model(**eval_batch)
                local_loss_sum += outputs.loss.detach()
                eval_count += 1

        if eval_count > 0:
            local_mean = local_loss_sum / eval_count
        else:
            local_mean = torch.tensor(0.0, device=self.accelerator.device)

        # Average of per-rank means (equal weight per rank, not per batch).
        all_means = self.accelerator.gather(local_mean)
        final_eval_loss = all_means.mean().item()

        eval_metrics = {"eval/loss": final_eval_loss}
        self.accelerator.log(eval_metrics, step=self.global_step)
        logger.info(f"Eval Loss: {final_eval_loss:.4f}")

        self.accelerator.wait_for_everyone()
        self.model.train()
        return eval_metrics

    def train(self):
        """Main training loop.

        Runs until ``config.steps`` optimizer steps have completed, cycling
        the dataloader across epochs. ``global_step`` counts optimizer steps
        (i.e. after gradient accumulation), not micro-batches.
        """
        logger.info("Starting Training Loop...")

        # Resume if configured
        if self.config.resume_from_checkpoint:
            self.load_checkpoint(self.config.resume_from_checkpoint)

        # Handle IterableDataset Epochs
        if hasattr(self.train_dataloader.dataset, "set_epoch"):
            self.train_dataloader.dataset.set_epoch(self.epoch)

        # Logger
        train_logger = TrainLogger(
            self.accelerator, self.config.steps, self.config.logging_steps
        )
        train_logger.start(self.global_step)

        self.model.train()
        train_iterator = iter(self.train_dataloader)

        logging_start_time = time.time()
        logging_start_step = self.global_step
        # Running (un-reset) loss sum; interval averages are computed by
        # differencing against logging_loss_scalar below.
        tr_loss = torch.tensor(0.0).to(self.accelerator.device)
        logging_loss_scalar = 0.0

        while self.global_step < self.config.steps:
            try:
                batch = next(train_iterator)
            except StopIteration:
                # Dataloader exhausted: advance the epoch and restart it.
                self.epoch += 1
                logger.info(f"Epoch {self.epoch} starting. Resetting dataloader...")
                if hasattr(self.train_dataloader.dataset, "set_epoch"):
                    self.train_dataloader.dataset.set_epoch(self.epoch)

                train_iterator = iter(self.train_dataloader)
                batch = next(train_iterator)

            with self.accelerator.accumulate(self.model):
                outputs = self.model(**batch)
                loss = outputs.loss
                tr_loss += loss.detach()
                self.accelerator.backward(loss)

                # sync_gradients is True only on the final micro-batch of an
                # accumulation window — i.e. once per real optimizer step.
                if self.accelerator.sync_gradients:
                    # Clipping
                    grad_norm = 0.0
                    if self.config.max_grad_norm > 0:
                        grad_norm = self.accelerator.clip_grad_norm_(
                            self.model.parameters(), self.config.max_grad_norm
                        )
                        # clip_grad_norm_ may return None under some backends.
                        grad_norm = (
                            grad_norm.item() if grad_norm is not None else 0.0
                        )

                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad()
                    self.global_step += 1

                    # Logging (progress bar shows the last micro-batch loss)
                    current_lr = self.lr_scheduler.get_last_lr()[0]
                    train_logger.update(
                        step=self.global_step, loss=loss.item(), lr=current_lr
                    )

                    if self.global_step % self.config.logging_steps == 0:
                        elapsed = time.time() - logging_start_time
                        steps_per_sec = (
                            (self.global_step - logging_start_step) / elapsed
                            if elapsed > 0
                            else 0
                        )

                        # Cross-rank mean of the running loss sum; the diff
                        # with the previous snapshot is this interval's total.
                        tr_loss_scalar = self.accelerator.gather(tr_loss).mean().item()
                        current_interval_loss = tr_loss_scalar - logging_loss_scalar
                        avg_loss = current_interval_loss / (
                            self.config.logging_steps
                            * self.config.gradient_accumulation_steps
                        )
                        logging_loss_scalar = tr_loss_scalar

                        logs = {
                            "train/loss": avg_loss,
                            "train/learning_rate": current_lr,
                            "train/grad_norm": grad_norm,
                            "train/epoch": self.epoch,
                            "train/steps_per_sec": steps_per_sec,
                        }
                        train_logger.log_metrics(step=self.global_step, metrics=logs)

                        logging_start_time = time.time()
                        logging_start_step = self.global_step

                    # Evaluate
                    if (
                        self.eval_dataloader is not None
                        and self.global_step % self.config.eval_steps == 0
                    ):
                        self.evaluate()

                    # Save
                    if self.global_step % self.config.save_steps == 0:
                        self.save_checkpoint(self.global_step)

        # Final Save
        self.save_checkpoint(self.global_step)
        train_logger.close()
        self.accelerator.end_training()
|
omnivoice/utils/__init__.py
ADDED
|
File without changes
|
omnivoice/utils/audio.py
ADDED
|
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""Audio I/O and processing utilities.
|
| 19 |
+
|
| 20 |
+
Provides functions for loading, resampling, silence removal, chunking,
|
| 21 |
+
cross-fading, and format conversion. Used by ``OmniVoice.generate()`` during
|
| 22 |
+
inference post-processing.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import torch
|
| 27 |
+
import torchaudio
|
| 28 |
+
from pydub import AudioSegment
|
| 29 |
+
from pydub.silence import detect_leading_silence, detect_nonsilent, split_on_silence
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def load_audio(audio_path: str, sampling_rate: int):
    """
    Load a waveform with torchaudio (falling back to pydub/ffmpeg) and
    resample/downmix it to mono at the target rate.

    Parameters:
        audio_path: path of the audio file.
        sampling_rate: target sampling rate in Hz.

    Returns:
        Mono waveform at the target sampling rate as a float32
        PyTorch tensor of shape (1, T), samples normalized to [-1, 1].
    """
    try:
        waveform, prompt_sampling_rate = torchaudio.load(audio_path)
    except (RuntimeError, OSError):
        # Fallback via pydub+ffmpeg for formats torchaudio can't handle
        aseg = AudioSegment.from_file(audio_path)
        # Normalize by the file's actual sample width: 2**(8*width - 1).
        # The previous hard-coded 32768.0 was only correct for 16-bit audio.
        scale = float(1 << (8 * aseg.sample_width - 1))
        audio_data = np.array(aseg.get_array_of_samples()).astype(np.float32) / scale
        if aseg.channels == 1:
            waveform = torch.from_numpy(audio_data).unsqueeze(0)
        else:
            # Interleaved samples -> (C, T)
            waveform = torch.from_numpy(audio_data.reshape(-1, aseg.channels).T)
        prompt_sampling_rate = aseg.frame_rate

    if prompt_sampling_rate != sampling_rate:
        waveform = torchaudio.functional.resample(
            waveform,
            orig_freq=prompt_sampling_rate,
            new_freq=sampling_rate,
        )
    if waveform.shape[0] > 1:
        # Downmix multi-channel audio to mono by averaging channels.
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    return waveform
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def remove_silence(
    audio: torch.Tensor,
    sampling_rate: int,
    mid_sil: int = 300,
    lead_sil: int = 100,
    trail_sil: int = 300,
):
    """
    Strip long silences from a waveform.

    Middle silences longer than ``mid_sil`` ms are collapsed (disabled when
    ``mid_sil <= 0``); leading/trailing silences are trimmed down to at most
    ``lead_sil`` / ``trail_sil`` ms.

    Parameters:
        audio: PyTorch tensor with shape (C, T).
        sampling_rate: sampling rate of the audio in Hz.
        mid_sil: minimum middle-silence duration (ms) to collapse;
            if mid_sil <= 0, no middle silence is removed.
        lead_sil: leading silence (ms) kept at the start.
        trail_sil: trailing silence (ms) kept at the end.

    Returns:
        PyTorch tensor with shape (C, T), where C is number of channels
        and T is number of audio samples.
    """
    segment = tensor_to_audiosegment(audio, sampling_rate)

    if mid_sil > 0:
        # Split at silences longer than mid_sil (threshold -50 dBFS) and
        # glue the voiced chunks back, keeping up to mid_sil ms around each.
        chunks = split_on_silence(
            segment,
            min_silence_len=mid_sil,
            silence_thresh=-50,
            keep_silence=mid_sil,
            seek_step=10,
        )
        segment = sum(chunks, AudioSegment.silent(duration=0))

    # Trim edge silences (same -50 dBFS threshold).
    segment = remove_silence_edges(segment, lead_sil, trail_sil, -50)

    return audiosegment_to_tensor(segment)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def remove_silence_edges(
    audio: AudioSegment,
    lead_sil: int = 100,
    trail_sil: int = 300,
    silence_threshold: float = -50,
):
    """Trim leading and trailing silence, keeping a short silence margin.

    Parameters:
        audio: an AudioSegment object.
        lead_sil: milliseconds of leading silence to keep.
        trail_sil: milliseconds of trailing silence to keep.
        silence_threshold: the threshold (dBFS) below which audio counts
            as silence.

    Returns:
        An AudioSegment object with edge silences trimmed.
    """
    # Remove heading silence, keeping up to lead_sil ms of it.
    start_idx = detect_leading_silence(audio, silence_threshold=silence_threshold)
    start_idx = max(0, start_idx - lead_sil)
    audio = audio[start_idx:]

    # Remove trailing silence: reverse, trim the (now leading) silence
    # keeping trail_sil ms, then reverse back.
    audio = audio.reverse()
    start_idx = detect_leading_silence(audio, silence_threshold=silence_threshold)
    start_idx = max(0, start_idx - trail_sil)
    audio = audio[start_idx:]
    audio = audio.reverse()

    return audio
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def audiosegment_to_tensor(aseg):
    """
    Convert a pydub.AudioSegment to a PyTorch audio tensor of shape (C, T).
    """
    # Raw interleaved samples, scaled from int16 range into [-1, 1].
    # NOTE(review): assumes 16-bit samples (sample_width == 2) — confirm upstream.
    samples = np.array(aseg.get_array_of_samples()).astype(np.float32) / 32768.0

    if aseg.channels == 1:
        # Mono: add a channel dimension, (T,) -> (1, T)
        return torch.from_numpy(samples).unsqueeze(0)

    # Multi-channel: de-interleave, (T*C,) -> (C, T)
    return torch.from_numpy(samples.reshape(-1, aseg.channels).T)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def tensor_to_audiosegment(tensor, sample_rate):
    """
    Convert a PyTorch audio tensor to pydub.AudioSegment

    Parameters:
        tensor: Tensor with shape (C, T); float values assumed in [-1, 1].
        sample_rate: Audio sample rate
    """
    assert isinstance(tensor, torch.Tensor)
    samples = tensor.cpu().numpy()

    # Quantize to int16, pydub's common sample format.
    samples = (samples * 32768.0).clip(-32768, 32767).astype(np.int16)

    # pydub expects interleaved frames for multi-channel audio
    # (e.g. left-right-left-right), so transpose (C, T) -> (T, C)
    # and flatten before serializing.
    if samples.shape[0] > 1:
        samples = samples.transpose(1, 0).flatten()

    return AudioSegment(
        data=samples.tobytes(),
        sample_width=2,
        frame_rate=sample_rate,
        channels=tensor.shape[0],
    )
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def fade_and_pad_audio(
    audio: torch.Tensor,
    pad_duration: float = 0.1,
    fade_duration: float = 0.1,
    sample_rate: int = 24000,
) -> torch.Tensor:
    """
    Apply a linear fade-in/fade-out to the audio and pad both ends with pure
    silence, preventing abrupt starts and ends (clicks/pops).

    Args:
        audio: PyTorch tensor of shape (C, T) containing audio data.
        pad_duration: Duration of pure silence added to each end (in seconds).
        fade_duration: Duration of the fade-in/out ramp (in seconds).
        sample_rate: Audio sampling rate.

    Returns:
        Processed sequence tensor with shape (C, T_new)
    """
    if audio.shape[-1] == 0:
        return audio

    n_fade = int(fade_duration * sample_rate)
    n_pad = int(pad_duration * sample_rate)
    out = audio.clone()

    if n_fade > 0:
        # Cap the ramp at half the signal length so fades never overlap.
        ramp_len = min(n_fade, out.shape[-1] // 2)
        if ramp_len > 0:
            ramp = torch.linspace(
                0, 1, ramp_len, device=out.device, dtype=out.dtype
            )[None, :]
            out[..., :ramp_len] = out[..., :ramp_len] * ramp
            # Reversed ramp == linspace(1, 0, ramp_len): the fade-out.
            out[..., -ramp_len:] = out[..., -ramp_len:] * ramp.flip(-1)

    if n_pad > 0:
        pad = torch.zeros(
            (out.shape[0], n_pad), dtype=out.dtype, device=out.device
        )
        out = torch.cat([pad, out, pad], dim=-1)

    return out
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def trim_long_audio(
    audio: torch.Tensor,
    sampling_rate: int,
    max_duration: float = 15.0,
    min_duration: float = 3.0,
    trim_threshold: float = 20.0,
) -> torch.Tensor:
    """Trim audio to <= max_duration by cutting at a silence boundary.

    Audio no longer than *trim_threshold* seconds is returned unchanged.

    Args:
        audio: Audio tensor of shape (C, T).
        sampling_rate: Audio sampling rate.
        max_duration: Maximum duration in seconds.
        min_duration: Minimum duration in seconds.
        trim_threshold: Only trim if audio is longer than this (seconds).

    Returns:
        Trimmed audio tensor.
    """
    if audio.size(-1) / sampling_rate <= trim_threshold:
        return audio

    segment = tensor_to_audiosegment(audio, sampling_rate)
    speech_regions = detect_nonsilent(
        segment, min_silence_len=100, silence_thresh=-40, seek_step=10
    )
    if not speech_regions:
        return audio

    limit_ms = int(max_duration * 1000)
    floor_ms = int(min_duration * 1000)

    # Pick the latest speech-region start that is still within the limit:
    # cutting there lands inside the silence gap before that region.
    cut_ms = 0
    for region_start, region_end in speech_regions:
        if limit_ms >= region_start > cut_ms:
            cut_ms = region_start
        if region_end > limit_ms:
            break

    # No suitable gap early enough — fall back to a hard cut at the limit.
    if cut_ms < floor_ms:
        cut_ms = min(limit_ms, len(segment))

    return audiosegment_to_tensor(segment[:cut_ms])
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def cross_fade_chunks(
    chunks: list[torch.Tensor],
    sample_rate: int,
    silence_duration: float = 0.3,
) -> torch.Tensor:
    """Join audio chunks with a short silence gap and fades at each boundary.

    Every boundary is shaped as fade-out tail -> silence buffer -> fade-in
    head, which avoids click artifacts from direct concatenation or
    overlapping mismatch.

    Args:
        chunks: List of audio tensors, each (C, T).
        sample_rate: Audio sample rate.
        silence_duration: Total silence gap duration in seconds.

    Returns:
        Merged audio tensor (C, T_total).
    """
    if len(chunks) == 1:
        return chunks[0]

    gap_total = int(silence_duration * sample_rate)
    ramp_len = gap_total // 3  # one third each: fade-out, silence, fade-in
    out = chunks[0].clone()

    for nxt in chunks[1:]:
        device, dtype = out.device, out.dtype

        # Fade out the tail of the audio accumulated so far.
        tail = min(ramp_len, out.size(-1))
        if tail > 0:
            out[..., -tail:] = out[..., -tail:] * torch.linspace(
                1, 0, tail, device=device, dtype=dtype
            )[None, :]

        # Fade in the head of the next chunk (on a copy, leaving the
        # caller's tensor untouched).
        faded = nxt.clone()
        head = min(ramp_len, faded.size(-1))
        if head > 0:
            faded[..., :head] = faded[..., :head] * torch.linspace(
                0, 1, head, device=device, dtype=dtype
            )[None, :]

        gap = torch.zeros(chunks[0].shape[0], ramp_len, device=device, dtype=dtype)
        out = torch.cat([out, gap, faded], dim=-1)

    return out
|
omnivoice/utils/common.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""Shared utility functions."""
|
| 19 |
+
|
| 20 |
+
import argparse
|
| 21 |
+
import random
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
import torch
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def str2bool(v):
    """Used in argparse.ArgumentParser.add_argument to indicate
    that a type is a bool type and user can enter

    - yes, true, t, y, 1, to represent True
    - no, false, f, n, 0, to represent False

    See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def fix_random_seed(random_seed: int):
    """
    Set the same random seed for the libraries and modules.
    Includes the ``random`` module, numpy's legacy global RNG,
    and torch's default CPU generator.

    Args:
        random_seed: the seed value to apply to all three RNGs.
    """
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.random.manual_seed(random_seed)
    # The original version also created and seeded a local
    # ``random.Random()`` instance ("ensure deterministic ID creation"),
    # but that instance was immediately discarded, so it had no effect
    # and has been removed as dead code.
|
omnivoice/utils/data_utils.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""Data utilities for batch inference and evaluation.
|
| 19 |
+
|
| 20 |
+
Provides ``read_test_list()`` to parse JSONL test list files used by
|
| 21 |
+
``omnivoice.cli.infer_batch`` and evaluation scripts.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import json
|
| 25 |
+
import logging
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def read_test_list(path):
    """Read a JSONL test list file.

    Each line should be a JSON object with fields:
    id, text, ref_audio, ref_text, language_id, language_name, duration, speed

    language_id, language_name, duration, and speed are optional (default to None).
    Blank lines are skipped; malformed lines are skipped with a warning.

    Returns a list of dicts.
    """
    fields = (
        "id",
        "text",
        "ref_audio",
        "ref_text",
        "language_id",
        "language_name",
        "duration",
        "speed",
    )
    entries = []
    with Path(path).open("r", encoding="utf-8") as fin:
        for line_no, line in enumerate(fin, 1):
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                logging.warning(f"Skipping malformed JSON at line {line_no}: {line}")
                continue
            # Missing optional fields come back as None via dict.get.
            entries.append({key: record.get(key) for key in fields})
    return entries
|
omnivoice/utils/duration.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""Text duration estimation for TTS generation.
|
| 19 |
+
|
| 20 |
+
Provides ``RuleDurationEstimator``, which estimates audio duration from text
|
| 21 |
+
using character phonetic weights across 600+ languages. Used by
|
| 22 |
+
``OmniVoice.generate()`` to determine output length when no duration is specified.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import bisect
|
| 26 |
+
import unicodedata
|
| 27 |
+
from functools import lru_cache
|
| 28 |
+
from typing import Optional
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class RuleDurationEstimator:
    """Estimate speech duration from text using per-character phonetic weights.

    Each character gets a weight relative to one Latin letter (~40-50 ms);
    the script type is resolved from the character's Unicode block.
    """

    def __init__(self):
        # ==========================================
        # 1. Phonetic Weights Table
        # ==========================================
        # The weight represents the relative speaking time compared to
        # a standard Latin letter.
        # Benchmark: 1.0 = One Latin Character (~40-50ms)
        self.weights = {
            # --- Logographic (1 char = full syllable/word) ---
            "cjk": 3.0,  # Chinese, Japanese Kanji, etc.
            # --- Syllabic / Blocks
            "hangul": 2.5,  # Korean Hangul
            "kana": 2.2,  # Japanese Hiragana/Katakana
            "ethiopic": 3.0,  # Amharic/Ge'ez
            "yi": 3.0,  # Yi script
            # --- Abugida (Consonant-Vowel complexes) ---
            "indic": 1.8,  # Hindi, Bengali, Tamil, etc.
            "thai_lao": 1.5,  # Thai, Lao
            "khmer_myanmar": 1.8,  # Khmer, Myanmar
            # --- Abjad (Consonant-heavy) ---
            "arabic": 1.5,  # Arabic, Persian, Urdu
            "hebrew": 1.5,  # Hebrew
            # --- Alphabet (Segmental) ---
            "latin": 1.0,  # English, Spanish, French, Vietnamese, etc. (Baseline)
            "cyrillic": 1.0,  # Russian, Ukrainian
            "greek": 1.0,  # Greek
            "armenian": 1.0,  # Armenian
            "georgian": 1.0,  # Georgian
            # --- Symbols & Misc ---
            "punctuation": 0.5,  # Pause capability
            "space": 0.2,  # Word boundary/Breath (0.05 / 0.22)
            "digit": 3.5,  # Numbers
            "mark": 0.0,  # Diacritics/Accents (Silent modifiers)
            "default": 1.0,  # Fallback for unknown scripts
        }

        # ==========================================
        # 2. Unicode Range Mapping
        # ==========================================
        # Format: (End_Codepoint, Type_Key)
        # Used for fast binary search (bisect).
        self.ranges = [
            (0x02AF, "latin"),  # Latin (Basic, Supplement, Ext, IPA)
            (0x03FF, "greek"),  # Greek & Coptic
            (0x052F, "cyrillic"),  # Cyrillic
            (0x058F, "armenian"),  # Armenian
            (0x05FF, "hebrew"),  # Hebrew
            (0x077F, "arabic"),  # Arabic, Syriac, Arabic Supplement
            (0x089F, "arabic"),  # Arabic Extended-B (+ Syriac Supp)
            (0x08FF, "arabic"),  # Arabic Extended-A
            (0x097F, "indic"),  # Devanagari
            (0x09FF, "indic"),  # Bengali
            (0x0A7F, "indic"),  # Gurmukhi
            (0x0AFF, "indic"),  # Gujarati
            (0x0B7F, "indic"),  # Oriya
            (0x0BFF, "indic"),  # Tamil
            (0x0C7F, "indic"),  # Telugu
            (0x0CFF, "indic"),  # Kannada
            (0x0D7F, "indic"),  # Malayalam
            (0x0DFF, "indic"),  # Sinhala
            (0x0EFF, "thai_lao"),  # Thai & Lao
            (0x0FFF, "indic"),  # Tibetan (Abugida)
            (0x109F, "khmer_myanmar"),  # Myanmar
            (0x10FF, "georgian"),  # Georgian
            (0x11FF, "hangul"),  # Hangul Jamo
            (0x137F, "ethiopic"),  # Ethiopic
            (0x139F, "ethiopic"),  # Ethiopic Supplement
            (0x13FF, "default"),  # Cherokee
            (0x167F, "default"),  # Canadian Aboriginal Syllabics
            (0x169F, "default"),  # Ogham
            (0x16FF, "default"),  # Runic
            (0x171F, "default"),  # Tagalog (Baybayin)
            (0x173F, "default"),  # Hanunoo
            (0x175F, "default"),  # Buhid
            (0x177F, "default"),  # Tagbanwa
            (0x17FF, "khmer_myanmar"),  # Khmer
            (0x18AF, "default"),  # Mongolian
            (0x18FF, "default"),  # Canadian Aboriginal Syllabics Ext
            (0x194F, "indic"),  # Limbu
            (0x19DF, "indic"),  # Tai Le & New Tai Lue
            (0x19FF, "khmer_myanmar"),  # Khmer Symbols
            (0x1A1F, "indic"),  # Buginese
            (0x1AAF, "indic"),  # Tai Tham
            (0x1B7F, "indic"),  # Balinese
            (0x1BBF, "indic"),  # Sundanese
            (0x1BFF, "indic"),  # Batak
            (0x1C4F, "indic"),  # Lepcha
            (0x1C7F, "indic"),  # Ol Chiki (Santali)
            (0x1C8F, "cyrillic"),  # Cyrillic Extended-C
            (0x1CBF, "georgian"),  # Georgian Extended
            (0x1CCF, "indic"),  # Sundanese Supplement
            (0x1CFF, "indic"),  # Vedic Extensions
            (0x1D7F, "latin"),  # Phonetic Extensions
            (0x1DBF, "latin"),  # Phonetic Extensions Supplement
            (0x1DFF, "default"),  # Combining Diacritical Marks Supplement
            (0x1EFF, "latin"),  # Latin Extended Additional (Vietnamese)
            (0x309F, "kana"),  # Hiragana
            (0x30FF, "kana"),  # Katakana
            (0x312F, "cjk"),  # Bopomofo (Pinyin)
            (0x318F, "hangul"),  # Hangul Compatibility Jamo
            (0x9FFF, "cjk"),  # CJK Unified Ideographs (Main)
            (0xA4CF, "yi"),  # Yi Syllables
            (0xA4FF, "default"),  # Lisu
            (0xA63F, "default"),  # Vai
            (0xA69F, "cyrillic"),  # Cyrillic Extended-B
            (0xA6FF, "default"),  # Bamum
            (0xA7FF, "latin"),  # Latin Extended-D
            (0xA82F, "indic"),  # Syloti Nagri
            (0xA87F, "default"),  # Phags-pa
            (0xA8DF, "indic"),  # Saurashtra
            (0xA8FF, "indic"),  # Devanagari Extended
            (0xA92F, "indic"),  # Kayah Li
            (0xA95F, "indic"),  # Rejang
            (0xA97F, "hangul"),  # Hangul Jamo Extended-A
            (0xA9DF, "indic"),  # Javanese
            (0xA9FF, "khmer_myanmar"),  # Myanmar Extended-B
            (0xAA5F, "indic"),  # Cham
            (0xAA7F, "khmer_myanmar"),  # Myanmar Extended-A
            (0xAADF, "indic"),  # Tai Viet
            (0xAAFF, "indic"),  # Meetei Mayek Extensions
            (0xAB2F, "ethiopic"),  # Ethiopic Extended-A
            (0xAB6F, "latin"),  # Latin Extended-E
            (0xABBF, "default"),  # Cherokee Supplement
            (0xABFF, "indic"),  # Meetei Mayek
            (0xD7AF, "hangul"),  # Hangul Syllables
            (0xFAFF, "cjk"),  # CJK Compatibility
            (0xFDFF, "arabic"),  # Arabic Presentation Forms-A
            (0xFE6F, "default"),  # Variation Selectors
            (0xFEFF, "arabic"),  # Arabic Presentation Forms-B
            (0xFFEF, "latin"),  # Fullwidth Latin
        ]
        self.breakpoints = [r[0] for r in self.ranges]

        # Per-instance memo for character weights. The original used
        # ``functools.lru_cache`` on the bound method, which keys on ``self``
        # and keeps the instance alive for the cache's lifetime
        # (flake8-bugbear B019); a plain dict avoids both problems.
        self._char_weight_cache = {}

    def _get_char_weight(self, char):
        """Determines the weight of a single character (memoized)."""
        cached = self._char_weight_cache.get(char)
        if cached is not None:
            return cached
        weight = self._compute_char_weight(char)
        self._char_weight_cache[char] = weight
        return weight

    def _compute_char_weight(self, char):
        """Uncached weight lookup for a single character."""
        code = ord(char)
        # Fast path: plain ASCII letters and space.
        if (65 <= code <= 90) or (97 <= code <= 122):
            return self.weights["latin"]
        if code == 32:
            return self.weights["space"]

        # Ignore arabic Tatweel (a purely typographic elongation mark).
        if code == 0x0640:
            return self.weights["mark"]

        category = unicodedata.category(char)

        if category.startswith("M"):
            return self.weights["mark"]

        if category.startswith("P") or category.startswith("S"):
            return self.weights["punctuation"]

        if category.startswith("Z"):
            return self.weights["space"]

        if category.startswith("N"):
            return self.weights["digit"]

        # 3. Binary search for the Unicode block (punctuation has already
        # been filtered out above, so the ranges only see letters).
        idx = bisect.bisect_left(self.breakpoints, code)
        if idx < len(self.ranges):
            script_type = self.ranges[idx][1]
            return self.weights.get(script_type, self.weights["default"])

        # 4. Handle upper planes (CJK Ext B/C/D, Historic scripts)
        if code > 0x20000:
            return self.weights["cjk"]

        return self.weights["default"]

    def calculate_total_weight(self, text):
        """Sums up the normalized weights for a string."""
        return sum(self._get_char_weight(c) for c in text)

    def estimate_duration(
        self,
        target_text: str,
        ref_text: str,
        ref_duration: float,
        low_threshold: Optional[float] = 50,
        boost_strength: float = 3,
    ) -> float:
        """Estimate how long *target_text* takes, scaled from a reference.

        Args:
            target_text (str): The text for which we want to estimate the duration.
            ref_text (str): The reference text that was used to measure
                the ref_duration.
            ref_duration (float): The actual duration it took
                to speak the ref_text.
            low_threshold (float): The minimum duration threshold below which the
                estimation will be considered unreliable.
            boost_strength (float): Controls the power-curve boost for short durations.
                Higher values boost small durations more aggressively.
                1 = no boost (linear), 2 = sqrt-like

        Returns:
            float: The estimated duration for the target_text based
            on the ref_text and ref_duration.
        """
        # Degenerate reference: no usable speaking-rate information.
        if ref_duration <= 0 or not ref_text:
            return 0.0

        ref_weight = self.calculate_total_weight(ref_text)
        if ref_weight == 0:
            return 0.0

        # weight-per-second of the reference speaker.
        speed_factor = ref_weight / ref_duration
        target_weight = self.calculate_total_weight(target_text)

        estimated_duration = target_weight / speed_factor
        if low_threshold is not None and estimated_duration < low_threshold:
            # Power-curve boost: short estimates are pulled up toward
            # low_threshold (alpha < 1 inflates values below the threshold).
            alpha = 1.0 / boost_strength
            return low_threshold * (estimated_duration / low_threshold) ** alpha
        else:
            return estimated_duration
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
# ==========================================
|
| 253 |
+
# Example Usage
|
| 254 |
+
# ==========================================
|
| 255 |
+
if __name__ == "__main__":
    # Quick demo: scale one reference utterance onto several scripts.
    estimator = RuleDurationEstimator()

    ref_txt = "Hello, world."
    ref_dur = 1.5

    test_cases = [
        ("Hindi (With complex marks)", "नमस्ते दुनिया"),
        ("Arabic (With vowels)", "مَرْحَبًا بِالْعَالَم"),
        ("Vietnamese (Lots of diacritics)", "Chào thế giới"),
        ("Chinese", "你好,世界!"),
        ("Mixed Emoji", "Hello 🌍! This is fun 🎉"),
    ]

    separator = "-" * 30

    print("--- Reference ---")
    print(f"Reference Text: '{ref_txt}'")
    print(f"Reference Duration: {ref_dur}s")
    print(separator)

    for lang, txt in test_cases:
        weight = estimator.calculate_total_weight(txt)
        est_time = estimator.estimate_duration(txt, ref_txt, ref_dur)

        print(f"[{lang}]")
        print(f"Text: {txt}")
        print(f"Total Weight: {weight:.2f}")
        print(f"Estimated Duration: {est_time:.2f} s")
        print(separator)
|
omnivoice/utils/lang_map.py
ADDED
|
@@ -0,0 +1,698 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""Language name to ISO 639-3 code mapping.
|
| 19 |
+
|
| 20 |
+
Auto-generated from ``docs/lang_id_name_map.tsv``. Provides ``LANG_NAME_TO_ID``
|
| 21 |
+
(for resolving language names to codes) and ``LANG_IDS`` (the set of supported
|
| 22 |
+
ISO 639-3 codes). Used by ``OmniVoice.generate()`` to resolve user-provided
|
| 23 |
+
language names.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
# Auto-generated from docs/lang_id_name_map.tsv
|
| 27 |
+
# Maps lowercase language name -> language ID code
|
| 28 |
+
|
| 29 |
+
LANG_NAME_TO_ID = {
|
| 30 |
+
"abadi": "kbt",
|
| 31 |
+
"abkhazian": "ab",
|
| 32 |
+
"abron": "abr",
|
| 33 |
+
"abua": "abn",
|
| 34 |
+
"adamawa fulfulde": "fub",
|
| 35 |
+
"adyghe": "ady",
|
| 36 |
+
"afade": "aal",
|
| 37 |
+
"afrikaans": "af",
|
| 38 |
+
"agwagwune": "yay",
|
| 39 |
+
"aja (benin)": "ajg",
|
| 40 |
+
"akebu": "keu",
|
| 41 |
+
"alago": "ala",
|
| 42 |
+
"albanian": "sq",
|
| 43 |
+
"algerian arabic": "arq",
|
| 44 |
+
"algerian saharan arabic": "aao",
|
| 45 |
+
"ambo-pasco quechua": "qva",
|
| 46 |
+
"ambonese malay": "abs",
|
| 47 |
+
"amdo tibetan": "adx",
|
| 48 |
+
"amharic": "am",
|
| 49 |
+
"anaang": "anw",
|
| 50 |
+
"angika": "anp",
|
| 51 |
+
"antankarana malagasy": "xmv",
|
| 52 |
+
"aragonese": "an",
|
| 53 |
+
"arbëreshë albanian": "aae",
|
| 54 |
+
"arequipa-la unión quechua": "qxu",
|
| 55 |
+
"armenian": "hy",
|
| 56 |
+
"ashe": "ahs",
|
| 57 |
+
"ashéninka perené": "prq",
|
| 58 |
+
"askopan": "eiv",
|
| 59 |
+
"assamese": "as",
|
| 60 |
+
"asturian": "ast",
|
| 61 |
+
"atayal": "tay",
|
| 62 |
+
"awak": "awo",
|
| 63 |
+
"ayacucho quechua": "quy",
|
| 64 |
+
"azerbaijani": "az",
|
| 65 |
+
"baatonum": "bba",
|
| 66 |
+
"bacama": "bcy",
|
| 67 |
+
"bade": "bde",
|
| 68 |
+
"bafia": "ksf",
|
| 69 |
+
"bafut": "bfd",
|
| 70 |
+
"bagirmi fulfulde": "fui",
|
| 71 |
+
"bago-kusuntu": "bqg",
|
| 72 |
+
"baharna arabic": "abv",
|
| 73 |
+
"bakoko": "bkh",
|
| 74 |
+
"balanta-ganja": "bjt",
|
| 75 |
+
"balti": "bft",
|
| 76 |
+
"bamenyam": "bce",
|
| 77 |
+
"bamun": "bax",
|
| 78 |
+
"bangwinji": "bsj",
|
| 79 |
+
"banjar": "bjn",
|
| 80 |
+
"bankon": "abb",
|
| 81 |
+
"baoulé": "bci",
|
| 82 |
+
"bara malagasy": "bhr",
|
| 83 |
+
"barok": "bjk",
|
| 84 |
+
"basa (cameroon)": "bas",
|
| 85 |
+
"basa (nigeria)": "bzw",
|
| 86 |
+
"bashkir": "ba",
|
| 87 |
+
"basque": "eu",
|
| 88 |
+
"batak mandailing": "btm",
|
| 89 |
+
"batanga": "bnm",
|
| 90 |
+
"bateri": "btv",
|
| 91 |
+
"bats": "bbl",
|
| 92 |
+
"bayot": "bda",
|
| 93 |
+
"bebele": "beb",
|
| 94 |
+
"belarusian": "be",
|
| 95 |
+
"bengali": "bn",
|
| 96 |
+
"betawi": "bew",
|
| 97 |
+
"bhili": "bhb",
|
| 98 |
+
"bhojpuri": "bho",
|
| 99 |
+
"bilur": "bxf",
|
| 100 |
+
"bima": "bhp",
|
| 101 |
+
"bodo": "brx",
|
| 102 |
+
"boghom": "bux",
|
| 103 |
+
"bokyi": "bky",
|
| 104 |
+
"bomu": "bmq",
|
| 105 |
+
"bondei": "bou",
|
| 106 |
+
"borgu fulfulde": "fue",
|
| 107 |
+
"bosnian": "bs",
|
| 108 |
+
"brahui": "brh",
|
| 109 |
+
"braj": "bra",
|
| 110 |
+
"breton": "br",
|
| 111 |
+
"buduma": "bdm",
|
| 112 |
+
"buginese": "bug",
|
| 113 |
+
"bukharic": "bhh",
|
| 114 |
+
"bulgarian": "bg",
|
| 115 |
+
"bulu (cameroon)": "bum",
|
| 116 |
+
"bundeli": "bns",
|
| 117 |
+
"bunun": "bnn",
|
| 118 |
+
"bura-pabir": "bwr",
|
| 119 |
+
"burak": "bys",
|
| 120 |
+
"burmese": "my",
|
| 121 |
+
"burushaski": "bsk",
|
| 122 |
+
"cacaloxtepec mixtec": "miu",
|
| 123 |
+
"cajatambo north lima quechua": "qvl",
|
| 124 |
+
"cakfem-mushere": "cky",
|
| 125 |
+
"cameroon pidgin": "wes",
|
| 126 |
+
"campidanese sardinian": "sro",
|
| 127 |
+
"cantonese": "yue",
|
| 128 |
+
"catalan": "ca",
|
| 129 |
+
"cebuano": "ceb",
|
| 130 |
+
"cen": "cen",
|
| 131 |
+
"central kurdish": "ckb",
|
| 132 |
+
"central nahuatl": "nhn",
|
| 133 |
+
"central pame": "pbs",
|
| 134 |
+
"central pashto": "pst",
|
| 135 |
+
"central puebla nahuatl": "ncx",
|
| 136 |
+
"central tarahumara": "tar",
|
| 137 |
+
"central yupik": "esu",
|
| 138 |
+
"central-eastern niger fulfulde": "fuq",
|
| 139 |
+
"chadian arabic": "shu",
|
| 140 |
+
"chichewa": "ny",
|
| 141 |
+
"chichicapan zapotec": "zpv",
|
| 142 |
+
"chiga": "cgg",
|
| 143 |
+
"chimalapa zoque": "zoh",
|
| 144 |
+
"chimborazo highland quichua": "qug",
|
| 145 |
+
"chinese": "zh",
|
| 146 |
+
"chiquián ancash quechua": "qxa",
|
| 147 |
+
"chitwania tharu": "the",
|
| 148 |
+
"chokwe": "cjk",
|
| 149 |
+
"chuvash": "cv",
|
| 150 |
+
"cibak": "ckl",
|
| 151 |
+
"coastal konjo": "kjc",
|
| 152 |
+
"copainalá zoque": "zoc",
|
| 153 |
+
"cornish": "kw",
|
| 154 |
+
"corongo ancash quechua": "qwa",
|
| 155 |
+
"croatian": "hr",
|
| 156 |
+
"cross river mbembe": "mfn",
|
| 157 |
+
"cuyamecalco mixtec": "xtu",
|
| 158 |
+
"czech": "cs",
|
| 159 |
+
"dadiya": "dbd",
|
| 160 |
+
"dagbani": "dag",
|
| 161 |
+
"dameli": "dml",
|
| 162 |
+
"danish": "da",
|
| 163 |
+
"dargwa": "dar",
|
| 164 |
+
"dazaga": "dzg",
|
| 165 |
+
"deccan": "dcc",
|
| 166 |
+
"degema": "deg",
|
| 167 |
+
"dera (nigeria)": "kna",
|
| 168 |
+
"dghwede": "dgh",
|
| 169 |
+
"dhatki": "mki",
|
| 170 |
+
"dhivehi": "dv",
|
| 171 |
+
"dhofari arabic": "adf",
|
| 172 |
+
"dijim-bwilim": "cfa",
|
| 173 |
+
"dogri": "dgo",
|
| 174 |
+
"domaaki": "dmk",
|
| 175 |
+
"dotyali": "dty",
|
| 176 |
+
"duala": "dua",
|
| 177 |
+
"dutch": "nl",
|
| 178 |
+
"dũya": "ldb",
|
| 179 |
+
"dyula": "dyu",
|
| 180 |
+
"eastern balochi": "bgp",
|
| 181 |
+
"eastern bolivian guaraní": "gui",
|
| 182 |
+
"eastern egyptian bedawi arabic": "avl",
|
| 183 |
+
"eastern krahn": "kqo",
|
| 184 |
+
"eastern mari": "mhr",
|
| 185 |
+
"eastern yiddish": "ydd",
|
| 186 |
+
"ebrié": "ebr",
|
| 187 |
+
"eggon": "ego",
|
| 188 |
+
"egyptian arabic": "arz",
|
| 189 |
+
"ejagham": "etu",
|
| 190 |
+
"eleme": "elm",
|
| 191 |
+
"eloyi": "afo",
|
| 192 |
+
"embu": "ebu",
|
| 193 |
+
"english": "en",
|
| 194 |
+
"erzya": "myv",
|
| 195 |
+
"esan": "ish",
|
| 196 |
+
"esperanto": "eo",
|
| 197 |
+
"estonian": "et",
|
| 198 |
+
"eton (cameroon)": "eto",
|
| 199 |
+
"ewondo": "ewo",
|
| 200 |
+
"extremaduran": "ext",
|
| 201 |
+
"fang (equatorial guinea)": "fan",
|
| 202 |
+
"fanti": "fat",
|
| 203 |
+
"farefare": "gur",
|
| 204 |
+
"fe'fe'": "fmp",
|
| 205 |
+
"filipino": "fil",
|
| 206 |
+
"filomena mata-coahuitlán totonac": "tlp",
|
| 207 |
+
"finnish": "fi",
|
| 208 |
+
"fipa": "fip",
|
| 209 |
+
"french": "fr",
|
| 210 |
+
"fulah": "ff",
|
| 211 |
+
"galician": "gl",
|
| 212 |
+
"gambian wolof": "wof",
|
| 213 |
+
"ganda": "lg",
|
| 214 |
+
"garhwali": "gbm",
|
| 215 |
+
"gawar-bati": "gwt",
|
| 216 |
+
"gawri": "gwc",
|
| 217 |
+
"gbagyi": "gbr",
|
| 218 |
+
"gbari": "gby",
|
| 219 |
+
"geji": "gyz",
|
| 220 |
+
"gen": "gej",
|
| 221 |
+
"georgian": "ka",
|
| 222 |
+
"german": "de",
|
| 223 |
+
"geser-gorom": "ges",
|
| 224 |
+
"gheg albanian": "aln",
|
| 225 |
+
"ghomálá'": "bbj",
|
| 226 |
+
"gidar": "gid",
|
| 227 |
+
"glavda": "glw",
|
| 228 |
+
"goan konkani": "gom",
|
| 229 |
+
"goaria": "gig",
|
| 230 |
+
"goemai": "ank",
|
| 231 |
+
"gola": "gol",
|
| 232 |
+
"greek": "el",
|
| 233 |
+
"guarani": "gn",
|
| 234 |
+
"guduf-gava": "gdf",
|
| 235 |
+
"guerrero amuzgo": "amu",
|
| 236 |
+
"gujarati": "gu",
|
| 237 |
+
"gujari": "gju",
|
| 238 |
+
"gulf arabic": "afb",
|
| 239 |
+
"gurgula": "ggg",
|
| 240 |
+
"gusii": "guz",
|
| 241 |
+
"gusilay": "gsl",
|
| 242 |
+
"gweno": "gwe",
|
| 243 |
+
"güilá zapotec": "ztu",
|
| 244 |
+
"hadothi": "hoj",
|
| 245 |
+
"hahon": "hah",
|
| 246 |
+
"haitian": "ht",
|
| 247 |
+
"hakha chin": "cnh",
|
| 248 |
+
"hakö": "hao",
|
| 249 |
+
"halia": "hla",
|
| 250 |
+
"hausa": "ha",
|
| 251 |
+
"hawaiian": "haw",
|
| 252 |
+
"hazaragi": "haz",
|
| 253 |
+
"hebrew": "he",
|
| 254 |
+
"hemba": "hem",
|
| 255 |
+
"herero": "hz",
|
| 256 |
+
"highland konjo": "kjk",
|
| 257 |
+
"hijazi arabic": "acw",
|
| 258 |
+
"hindi": "hi",
|
| 259 |
+
"huarijio": "var",
|
| 260 |
+
"huautla mazatec": "mau",
|
| 261 |
+
"huaxcaleca nahuatl": "nhq",
|
| 262 |
+
"huba": "hbb",
|
| 263 |
+
"huitepec mixtec": "mxs",
|
| 264 |
+
"hula": "hul",
|
| 265 |
+
"hungarian": "hu",
|
| 266 |
+
"hunjara-kaina ke": "hkk",
|
| 267 |
+
"hwana": "hwo",
|
| 268 |
+
"ibibio": "ibb",
|
| 269 |
+
"icelandic": "is",
|
| 270 |
+
"idakho-isukha-tiriki": "ida",
|
| 271 |
+
"idoma": "idu",
|
| 272 |
+
"igbo": "ig",
|
| 273 |
+
"igo": "ahl",
|
| 274 |
+
"ikposo": "kpo",
|
| 275 |
+
"ikwere": "ikw",
|
| 276 |
+
"imbabura highland quichua": "qvi",
|
| 277 |
+
"indonesian": "id",
|
| 278 |
+
"indus kohistani": "mvy",
|
| 279 |
+
"interlingua (international auxiliary language association)": "ia",
|
| 280 |
+
"inupiaq": "ik",
|
| 281 |
+
"irish": "ga",
|
| 282 |
+
"iron ossetic": "os",
|
| 283 |
+
"isekiri": "its",
|
| 284 |
+
"isoko": "iso",
|
| 285 |
+
"italian": "it",
|
| 286 |
+
"ito": "itw",
|
| 287 |
+
"itzá": "itz",
|
| 288 |
+
"ixtayutla mixtec": "vmj",
|
| 289 |
+
"izon": "ijc",
|
| 290 |
+
"jambi malay": "jax",
|
| 291 |
+
"japanese": "ja",
|
| 292 |
+
"jaqaru": "jqr",
|
| 293 |
+
"jauja wanca quechua": "qxw",
|
| 294 |
+
"jaunsari": "jns",
|
| 295 |
+
"javanese": "jv",
|
| 296 |
+
"jiba": "juo",
|
| 297 |
+
"jju": "kaj",
|
| 298 |
+
"judeo-moroccan arabic": "aju",
|
| 299 |
+
"juxtlahuaca mixtec": "vmc",
|
| 300 |
+
"kabardian": "kbd",
|
| 301 |
+
"kabras": "lkb",
|
| 302 |
+
"kabuverdianu": "kea",
|
| 303 |
+
"kabyle": "kab",
|
| 304 |
+
"kachi koli": "gjk",
|
| 305 |
+
"kairak": "ckr",
|
| 306 |
+
"kalabari": "ijn",
|
| 307 |
+
"kalasha": "kls",
|
| 308 |
+
"kalenjin": "kln",
|
| 309 |
+
"kalkoti": "xka",
|
| 310 |
+
"kamba": "kam",
|
| 311 |
+
"kamo": "kcq",
|
| 312 |
+
"kanauji": "bjj",
|
| 313 |
+
"kanembu": "kbl",
|
| 314 |
+
"kannada": "kn",
|
| 315 |
+
"karekare": "kai",
|
| 316 |
+
"kashmiri": "ks",
|
| 317 |
+
"kathoriya tharu": "tkt",
|
| 318 |
+
"kati": "bsh",
|
| 319 |
+
"kazakh": "kk",
|
| 320 |
+
"keiyo": "eyo",
|
| 321 |
+
"khams tibetan": "khg",
|
| 322 |
+
"khana": "ogo",
|
| 323 |
+
"khetrani": "xhe",
|
| 324 |
+
"khmer": "km",
|
| 325 |
+
"khowar": "khw",
|
| 326 |
+
"kinga": "zga",
|
| 327 |
+
"kinnauri": "kfk",
|
| 328 |
+
"kinyarwanda": "rw",
|
| 329 |
+
"kirghiz": "ky",
|
| 330 |
+
"kirya-konzəl": "fkk",
|
| 331 |
+
"kochila tharu": "thq",
|
| 332 |
+
"kohistani shina": "plk",
|
| 333 |
+
"kohumono": "bcs",
|
| 334 |
+
"kok borok": "trp",
|
| 335 |
+
"kol (papua new guinea)": "kol",
|
| 336 |
+
"kom (cameroon)": "bkm",
|
| 337 |
+
"koma": "kmy",
|
| 338 |
+
"konkani": "knn",
|
| 339 |
+
"konzo": "koo",
|
| 340 |
+
"korean": "ko",
|
| 341 |
+
"korwa": "kfp",
|
| 342 |
+
"kota (india)": "kfe",
|
| 343 |
+
"koti": "eko",
|
| 344 |
+
"kuanua": "ksd",
|
| 345 |
+
"kuanyama": "kj",
|
| 346 |
+
"kui (india)": "uki",
|
| 347 |
+
"kulung (nigeria)": "bbu",
|
| 348 |
+
"kuot": "kto",
|
| 349 |
+
"kushi": "kuh",
|
| 350 |
+
"kwambi": "kwm",
|
| 351 |
+
"kwasio": "nmg",
|
| 352 |
+
"lala-roba": "lla",
|
| 353 |
+
"lamang": "hia",
|
| 354 |
+
"lao": "lo",
|
| 355 |
+
"larike-wakasihu": "alo",
|
| 356 |
+
"lasi": "lss",
|
| 357 |
+
"latgalian": "ltg",
|
| 358 |
+
"latvian": "lv",
|
| 359 |
+
"levantine arabic": "apc",
|
| 360 |
+
"liana-seti": "ste",
|
| 361 |
+
"liberia kpelle": "xpe",
|
| 362 |
+
"liberian english": "lir",
|
| 363 |
+
"libyan arabic": "ayl",
|
| 364 |
+
"ligurian": "lij",
|
| 365 |
+
"lijili": "mgi",
|
| 366 |
+
"lingala": "ln",
|
| 367 |
+
"lithuanian": "lt",
|
| 368 |
+
"loarki": "lrk",
|
| 369 |
+
"logooli": "rag",
|
| 370 |
+
"logudorese sardinian": "src",
|
| 371 |
+
"loja highland quichua": "qvj",
|
| 372 |
+
"loloda": "loa",
|
| 373 |
+
"longuda": "lnu",
|
| 374 |
+
"loxicha zapotec": "ztp",
|
| 375 |
+
"luba-lulua": "lua",
|
| 376 |
+
"luo": "luo",
|
| 377 |
+
"lushai": "lus",
|
| 378 |
+
"luxembourgish": "lb",
|
| 379 |
+
"maasina fulfulde": "ffm",
|
| 380 |
+
"maba (chad)": "mde",
|
| 381 |
+
"macedo-romanian": "rup",
|
| 382 |
+
"macedonian": "mk",
|
| 383 |
+
"mada (cameroon)": "mxu",
|
| 384 |
+
"mafa": "maf",
|
| 385 |
+
"maithili": "mai",
|
| 386 |
+
"malay": "ms",
|
| 387 |
+
"malayalam": "ml",
|
| 388 |
+
"mali": "gcc",
|
| 389 |
+
"malinaltepec me'phaa": "tcf",
|
| 390 |
+
"maltese": "mt",
|
| 391 |
+
"mandara": "tbf",
|
| 392 |
+
"mandjak": "mfv",
|
| 393 |
+
"manggarai": "mqy",
|
| 394 |
+
"manipuri": "mni",
|
| 395 |
+
"mansoanka": "msw",
|
| 396 |
+
"manx": "gv",
|
| 397 |
+
"maori": "mi",
|
| 398 |
+
"marathi": "mr",
|
| 399 |
+
"marghi central": "mrt",
|
| 400 |
+
"marghi south": "mfm",
|
| 401 |
+
"maria (india)": "mrr",
|
| 402 |
+
"marwari (pakistan)": "mve",
|
| 403 |
+
"masana": "mcn",
|
| 404 |
+
"masikoro malagasy": "msh",
|
| 405 |
+
"matsés": "mcf",
|
| 406 |
+
"mazaltepec zapotec": "zpy",
|
| 407 |
+
"mazatlán mazatec": "vmz",
|
| 408 |
+
"mazatlán mixe": "mzl",
|
| 409 |
+
"mbe": "mfo",
|
| 410 |
+
"mbo (cameroon)": "mbo",
|
| 411 |
+
"mbum": "mdd",
|
| 412 |
+
"medumba": "byv",
|
| 413 |
+
"mekeo": "mek",
|
| 414 |
+
"meru": "mer",
|
| 415 |
+
"mesopotamian arabic": "acm",
|
| 416 |
+
"mewari": "mtr",
|
| 417 |
+
"min nan chinese": "nan",
|
| 418 |
+
"mingrelian": "xmf",
|
| 419 |
+
"mitlatongo mixtec": "vmm",
|
| 420 |
+
"miya": "mkf",
|
| 421 |
+
"mokpwe": "bri",
|
| 422 |
+
"moksha": "mdf",
|
| 423 |
+
"mom jango": "ver",
|
| 424 |
+
"mongolian": "mn",
|
| 425 |
+
"moroccan arabic": "ary",
|
| 426 |
+
"motu": "meu",
|
| 427 |
+
"mpiemo": "mcx",
|
| 428 |
+
"mpumpong": "mgg",
|
| 429 |
+
"mundang": "mua",
|
| 430 |
+
"mungaka": "mhk",
|
| 431 |
+
"musey": "mse",
|
| 432 |
+
"musgu": "mug",
|
| 433 |
+
"musi": "mui",
|
| 434 |
+
"naba": "mne",
|
| 435 |
+
"najdi arabic": "ars",
|
| 436 |
+
"nalik": "nal",
|
| 437 |
+
"nawdm": "nmz",
|
| 438 |
+
"ndonga": "ng",
|
| 439 |
+
"neapolitan": "nap",
|
| 440 |
+
"nepali": "npi",
|
| 441 |
+
"ngamo": "nbh",
|
| 442 |
+
"ngas": "anc",
|
| 443 |
+
"ngiemboon": "nnh",
|
| 444 |
+
"ngizim": "ngi",
|
| 445 |
+
"ngomba": "jgo",
|
| 446 |
+
"ngombale": "nla",
|
| 447 |
+
"nigerian fulfulde": "fuv",
|
| 448 |
+
"nigerian pidgin": "pcm",
|
| 449 |
+
"nimadi": "noe",
|
| 450 |
+
"nobiin": "fia",
|
| 451 |
+
"north mesopotamian arabic": "ayp",
|
| 452 |
+
"north moluccan malay": "max",
|
| 453 |
+
"northern betsimisaraka malagasy": "bmm",
|
| 454 |
+
"northern hindko": "hno",
|
| 455 |
+
"northern kurdish": "kmr",
|
| 456 |
+
"northern pame": "pmq",
|
| 457 |
+
"northern pashto": "pbu",
|
| 458 |
+
"northern uzbek": "uzn",
|
| 459 |
+
"northwest gbaya": "gya",
|
| 460 |
+
"norwegian": "no",
|
| 461 |
+
"norwegian bokmål": "nb",
|
| 462 |
+
"norwegian nynorsk": "nn",
|
| 463 |
+
"notsi": "ncf",
|
| 464 |
+
"nyankpa": "yes",
|
| 465 |
+
"nyungwe": "nyu",
|
| 466 |
+
"nzanyi": "nja",
|
| 467 |
+
"nüpode huitoto": "hux",
|
| 468 |
+
"occitan": "oc",
|
| 469 |
+
"od": "odk",
|
| 470 |
+
"odia": "ory",
|
| 471 |
+
"odual": "odu",
|
| 472 |
+
"omani arabic": "acx",
|
| 473 |
+
"orizaba nahuatl": "nlv",
|
| 474 |
+
"orma": "orc",
|
| 475 |
+
"ormuri": "oru",
|
| 476 |
+
"oromo": "om",
|
| 477 |
+
"pahari-potwari": "phr",
|
| 478 |
+
"paiwan": "pwn",
|
| 479 |
+
"panjabi": "pa",
|
| 480 |
+
"papuan malay": "pmy",
|
| 481 |
+
"parkari koli": "kvx",
|
| 482 |
+
"pedi": "nso",
|
| 483 |
+
"pero": "pip",
|
| 484 |
+
"persian": "fa",
|
| 485 |
+
"petats": "pex",
|
| 486 |
+
"phalura": "phl",
|
| 487 |
+
"piemontese": "pms",
|
| 488 |
+
"piya-kwonci": "piy",
|
| 489 |
+
"plateau malagasy": "plt",
|
| 490 |
+
"polish": "pl",
|
| 491 |
+
"poqomam": "poc",
|
| 492 |
+
"portuguese": "pt",
|
| 493 |
+
"pulaar": "fuc",
|
| 494 |
+
"pular": "fuf",
|
| 495 |
+
"puno quechua": "qxp",
|
| 496 |
+
"pushto": "ps",
|
| 497 |
+
"pökoot": "pko",
|
| 498 |
+
"qaqet": "byx",
|
| 499 |
+
"quiotepec chinantec": "chq",
|
| 500 |
+
"rana tharu": "thr",
|
| 501 |
+
"rangi": "lag",
|
| 502 |
+
"rapoisi": "kyx",
|
| 503 |
+
"ratahan": "rth",
|
| 504 |
+
"rayón zoque": "zor",
|
| 505 |
+
"romanian": "ro",
|
| 506 |
+
"romansh": "rm",
|
| 507 |
+
"rombo": "rof",
|
| 508 |
+
"rotokas": "roo",
|
| 509 |
+
"rukai": "dru",
|
| 510 |
+
"russian": "ru",
|
| 511 |
+
"sacapulteco": "quv",
|
| 512 |
+
"saidi arabic": "aec",
|
| 513 |
+
"sakalava malagasy": "skg",
|
| 514 |
+
"sakizaya": "szy",
|
| 515 |
+
"saleman": "sau",
|
| 516 |
+
"samba daka": "ccg",
|
| 517 |
+
"samba leko": "ndi",
|
| 518 |
+
"san felipe otlaltepec popoloca": "pow",
|
| 519 |
+
"san francisco del mar huave": "hue",
|
| 520 |
+
"san juan atzingo popoloca": "poe",
|
| 521 |
+
"san martín itunyoso triqui": "trq",
|
| 522 |
+
"san miguel el grande mixtec": "mig",
|
| 523 |
+
"sansi": "ssi",
|
| 524 |
+
"sanskrit": "sa",
|
| 525 |
+
"santa ana de tusi pasco quechua": "qxt",
|
| 526 |
+
"santa catarina albarradas zapotec": "ztn",
|
| 527 |
+
"santali": "sat",
|
| 528 |
+
"santiago del estero quichua": "qus",
|
| 529 |
+
"saposa": "sps",
|
| 530 |
+
"saraiki": "skr",
|
| 531 |
+
"sardinian": "sc",
|
| 532 |
+
"saya": "say",
|
| 533 |
+
"sediq": "trv",
|
| 534 |
+
"serbian": "sr",
|
| 535 |
+
"seri": "sei",
|
| 536 |
+
"shina": "scl",
|
| 537 |
+
"shona": "sn",
|
| 538 |
+
"siar-lak": "sjr",
|
| 539 |
+
"sibe": "nco",
|
| 540 |
+
"sicilian": "scn",
|
| 541 |
+
"sihuas ancash quechua": "qws",
|
| 542 |
+
"sikkimese": "sip",
|
| 543 |
+
"sinaugoro": "snc",
|
| 544 |
+
"sindhi": "sd",
|
| 545 |
+
"sindhi bhil": "sbn",
|
| 546 |
+
"sinhala": "si",
|
| 547 |
+
"sinicahua mixtec": "xti",
|
| 548 |
+
"sipacapense": "qum",
|
| 549 |
+
"siwai": "siw",
|
| 550 |
+
"slovak": "sk",
|
| 551 |
+
"slovenian": "sl",
|
| 552 |
+
"solos": "sol",
|
| 553 |
+
"somali": "so",
|
| 554 |
+
"soninke": "snk",
|
| 555 |
+
"south giziga": "giz",
|
| 556 |
+
"south ucayali ashéninka": "cpy",
|
| 557 |
+
"southeastern nochixtlán mixtec": "mxy",
|
| 558 |
+
"southern betsimisaraka malagasy": "bzc",
|
| 559 |
+
"southern pashto": "pbt",
|
| 560 |
+
"southern pastaza quechua": "qup",
|
| 561 |
+
"soyaltepec mazatec": "vmp",
|
| 562 |
+
"spanish": "es",
|
| 563 |
+
"standard arabic": "arb",
|
| 564 |
+
"standard moroccan tamazight": "zgh",
|
| 565 |
+
"sudanese arabic": "apd",
|
| 566 |
+
"sulka": "sua",
|
| 567 |
+
"svan": "sva",
|
| 568 |
+
"swahili": "sw",
|
| 569 |
+
"swedish": "sv",
|
| 570 |
+
"tae'": "rob",
|
| 571 |
+
"tahaggart tamahaq": "thv",
|
| 572 |
+
"taita": "dav",
|
| 573 |
+
"tajik": "tg",
|
| 574 |
+
"tamil": "ta",
|
| 575 |
+
"tandroy-mahafaly malagasy": "tdx",
|
| 576 |
+
"tangale": "tan",
|
| 577 |
+
"tanosy malagasy": "txy",
|
| 578 |
+
"tarok": "yer",
|
| 579 |
+
"tatar": "tt",
|
| 580 |
+
"tedaga": "tuq",
|
| 581 |
+
"telugu": "te",
|
| 582 |
+
"tem": "kdh",
|
| 583 |
+
"teop": "tio",
|
| 584 |
+
"tepeuxila cuicatec": "cux",
|
| 585 |
+
"tepinapa chinantec": "cte",
|
| 586 |
+
"tera": "ttr",
|
| 587 |
+
"terei": "buo",
|
| 588 |
+
"termanu": "twu",
|
| 589 |
+
"tesaka malagasy": "tkg",
|
| 590 |
+
"tetelcingo nahuatl": "nhg",
|
| 591 |
+
"teutila cuicatec": "cut",
|
| 592 |
+
"thai": "th",
|
| 593 |
+
"tibetan": "bo",
|
| 594 |
+
"tidaá mixtec": "mtx",
|
| 595 |
+
"tidore": "tvo",
|
| 596 |
+
"tigak": "tgc",
|
| 597 |
+
"tigre": "tig",
|
| 598 |
+
"tigrinya": "ti",
|
| 599 |
+
"tilquiapan zapotec": "zts",
|
| 600 |
+
"tinputz": "tpz",
|
| 601 |
+
"tlacoapa me'phaa": "tpl",
|
| 602 |
+
"tlacoatzintepec chinantec": "ctl",
|
| 603 |
+
"tlingit": "tli",
|
| 604 |
+
"toki pona": "tok",
|
| 605 |
+
"tomoip": "tqp",
|
| 606 |
+
"tondano": "tdn",
|
| 607 |
+
"tonsea": "txs",
|
| 608 |
+
"tooro": "ttj",
|
| 609 |
+
"torau": "ttu",
|
| 610 |
+
"torwali": "trw",
|
| 611 |
+
"tsimihety malagasy": "xmw",
|
| 612 |
+
"tsotso": "lto",
|
| 613 |
+
"tswana": "tn",
|
| 614 |
+
"tugen": "tuy",
|
| 615 |
+
"tuki": "bag",
|
| 616 |
+
"tula": "tul",
|
| 617 |
+
"tulu": "tcy",
|
| 618 |
+
"tunen": "tvu",
|
| 619 |
+
"tungag": "lcm",
|
| 620 |
+
"tunisian arabic": "aeb",
|
| 621 |
+
"tupuri": "tui",
|
| 622 |
+
"turkana": "tuv",
|
| 623 |
+
"turkish": "tr",
|
| 624 |
+
"turkmen": "tk",
|
| 625 |
+
"tututepec mixtec": "mtu",
|
| 626 |
+
"twi": "tw",
|
| 627 |
+
"ubaghara": "byc",
|
| 628 |
+
"uighur": "ug",
|
| 629 |
+
"ukrainian": "uk",
|
| 630 |
+
"umbundu": "umb",
|
| 631 |
+
"upper sorbian": "hsb",
|
| 632 |
+
"urdu": "ur",
|
| 633 |
+
"ushojo": "ush",
|
| 634 |
+
"uzbek": "uz",
|
| 635 |
+
"vai": "vai",
|
| 636 |
+
"vietnamese": "vi",
|
| 637 |
+
"votic": "vot",
|
| 638 |
+
"võro": "vro",
|
| 639 |
+
"waci gbe": "wci",
|
| 640 |
+
"wadiyara koli": "kxp",
|
| 641 |
+
"waja": "wja",
|
| 642 |
+
"wakhi": "wbl",
|
| 643 |
+
"wanga": "lwg",
|
| 644 |
+
"wapan": "juk",
|
| 645 |
+
"warji": "wji",
|
| 646 |
+
"welsh": "cy",
|
| 647 |
+
"wemale": "weo",
|
| 648 |
+
"western frisian": "fy",
|
| 649 |
+
"western highland purepecha": "pua",
|
| 650 |
+
"western juxtlahuaca mixtec": "jmx",
|
| 651 |
+
"western maninkakan": "mlq",
|
| 652 |
+
"western mari": "mrj",
|
| 653 |
+
"western niger fulfulde": "fuh",
|
| 654 |
+
"western panjabi": "pnb",
|
| 655 |
+
"wolof": "wo",
|
| 656 |
+
"wuzlam": "udl",
|
| 657 |
+
"xanaguía zapotec": "ztg",
|
| 658 |
+
"xhosa": "xh",
|
| 659 |
+
"yace": "ekr",
|
| 660 |
+
"yakut": "sah",
|
| 661 |
+
"yalahatan": "jal",
|
| 662 |
+
"yanahuanca pasco quechua": "qur",
|
| 663 |
+
"yangben": "yav",
|
| 664 |
+
"yaqui": "yaq",
|
| 665 |
+
"yauyos quechua": "qux",
|
| 666 |
+
"yekhee": "ets",
|
| 667 |
+
"yiddish": "yi",
|
| 668 |
+
"yidgha": "ydg",
|
| 669 |
+
"yoruba": "yo",
|
| 670 |
+
"yutanduchi mixtec": "mab",
|
| 671 |
+
"zacatlán-ahuacatlán-tepetzintla nahuatl": "nhi",
|
| 672 |
+
"zarma": "dje",
|
| 673 |
+
"zaza": "zza",
|
| 674 |
+
"zulu": "zu",
|
| 675 |
+
"ömie": "aom",
|
| 676 |
+
}
|
| 677 |
+
|
| 678 |
+
# All supported language names (lowercase), i.e. the keys of the mapping.
LANG_NAMES = set(LANG_NAME_TO_ID.keys())
# All supported language ID codes, i.e. the values of the mapping.
# NOTE(review): despite the module docstring saying "ISO 639-3", the values
# visibly mix two-letter (639-1, e.g. "en", "zh") and three-letter codes.
LANG_IDS = set(LANG_NAME_TO_ID.values())
|
| 680 |
+
|
| 681 |
+
# Names whose canonical casing (per the source TSV) differs from what
# str.title() produces — apostrophes, non-ASCII letters, and small
# connecting words such as "de"/"del" that must stay lowercase.
_TITLE_EXCEPTIONS = {
    "fe'fe'": "Fe'fe'",
    "dũya": "Dũya",
    "santiago del estero quichua": "Santiago del Estero Quichua",
    "santa ana de tusi pasco quechua": "Santa Ana de Tusi Pasco Quechua",
    "malinaltepec me'phaa": "Malinaltepec Me'phaa",
    "tlacoapa me'phaa": "Tlacoapa Me'phaa",
}


def lang_display_name(name: str) -> str:
    """Return a display-friendly form of a lowercase language name.

    The exception table is consulted first; any name not listed there is
    rendered with ``str.title()``.
    """
    if name in _TITLE_EXCEPTIONS:
        return _TITLE_EXCEPTIONS[name]
    return name.title()
|
omnivoice/utils/text.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""Text processing utilities for TTS inference.
|
| 19 |
+
|
| 20 |
+
Provides:
|
| 21 |
+
- ``chunk_text_punctuation()``: Splits long text into model-friendly chunks at
|
| 22 |
+
sentence boundaries, with abbreviation-aware punctuation splitting.
|
| 23 |
+
- ``add_punctuation()``: Appends missing end punctuation (Chinese or English).
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from typing import List, Optional
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Marks at which chunk_text_punctuation() may end a sentence
# (ASCII forms plus their fullwidth CJK counterparts).
SPLIT_PUNCTUATION = set(".,;:!?。,;:!?")
# Closing quotes/brackets: when one of these would start a new sentence it
# is re-attached to the end of the previous sentence instead.
CLOSING_MARKS = set("\"'""')]》》>」】")

# Characters accepted by add_punctuation() as a valid sentence-final mark.
# NOTE(review): the "……" entry is two characters long and can never equal a
# single trailing character, so it is effectively inert; "…" already covers
# the single-ellipsis case.
END_PUNCTUATION = {
    ";",
    ":",
    ",",
    ".",
    "!",
    "?",
    "…",
    ")",
    "]",
    "}",
    '"',
    "'",
    "”",
    "’",
    ";",
    ":",
    ",",
    "。",
    "!",
    "?",
    "、",
    "……",
    ")",
    "】",
    "”",
    "’",
}


# Period-terminated abbreviations that must NOT end a sentence when
# chunk_text_punctuation() encounters their trailing ".".
# NOTE(review): membership is tested case-sensitively against the last
# whitespace-separated word, which is why "vs." and "Vs." appear separately.
ABBREVIATIONS = {
    "Mr.",
    "Mrs.",
    "Ms.",
    "Dr.",
    "Prof.",
    "Sr.",
    "Jr.",
    "Rev.",
    "Fr.",
    "Hon.",
    "Pres.",
    "Gov.",
    "Capt.",
    "Gen.",
    "Sen.",
    "Rep.",
    "Col.",
    "Maj.",
    "Lt.",
    "Cmdr.",
    "Sgt.",
    "Cpl.",
    "Co.",
    "Corp.",
    "Inc.",
    "Ltd.",
    "Est.",
    "Dept.",
    "St.",
    "Ave.",
    "Blvd.",
    "Rd.",
    "Mt.",
    "Ft.",
    "No.",
    "Jan.",
    "Feb.",
    "Mar.",
    "Apr.",
    "Aug.",
    "Sep.",
    "Sept.",
    "Oct.",
    "Nov.",
    "Dec.",
    "i.e.",
    "e.g.",
    "vs.",
    "Vs.",
    "Etc.",
    "approx.",
    "fig.",
    "def.",
}
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _period_ends_abbreviation(sentence: str, abbreviations) -> bool:
    """Return True if *sentence* (ending in '.') should NOT be split here.

    True when the last whitespace-separated word is a known abbreviation,
    or is a period-terminated proper prefix of a dotted abbreviation (the
    "i." of "i.e.", the "e." of "e.g."), so the split would cut the
    abbreviation in half.
    """
    sentence = sentence.strip()
    if not sentence:
        return False
    last_word = sentence.split()[-1]
    if last_word in abbreviations:
        return True
    # Prefix check fixes splitting inside dotted abbreviations such as
    # "i.e." whose first period previously triggered a bogus split.
    return any(a != last_word and a.startswith(last_word) for a in abbreviations)


def chunk_text_punctuation(
    text: str,
    chunk_len: int,
    min_chunk_len: Optional[int] = None,
    *,
    split_punctuation: Optional[set] = None,
    closing_marks: Optional[set] = None,
    abbreviations: Optional[set] = None,
) -> List[str]:
    """Split *text* into chunks of at most ~``chunk_len`` characters.

    Text is first split into sentences at punctuation (skipping splits
    after common abbreviations such as "Mr." or "i.e."), then sentences
    are greedily merged into chunks no longer than ``chunk_len``.  A
    single sentence longer than ``chunk_len`` becomes its own oversized
    chunk.

    Args:
        text: Input text to split.
        chunk_len: Target maximum chunk length in characters.
        min_chunk_len: If given, chunks shorter than this are merged into
            a neighbouring chunk (the next chunk for the first one, the
            previous chunk otherwise).
        split_punctuation: Override for the module-level
            ``SPLIT_PUNCTUATION`` table (keyword-only; defaults to it).
        closing_marks: Override for ``CLOSING_MARKS``.
        abbreviations: Override for ``ABBREVIATIONS``.

    Returns:
        List of non-empty, stripped chunk strings.
    """
    if split_punctuation is None:
        split_punctuation = SPLIT_PUNCTUATION
    if closing_marks is None:
        closing_marks = CLOSING_MARKS
    if abbreviations is None:
        abbreviations = ABBREVIATIONS

    # 1. Split the text into sentences at punctuation positions.
    sentences: List[List[str]] = []
    current_sentence: List[str] = []
    for token in text:
        # Punctuation that would begin a new sentence belongs to the end
        # of the previous sentence instead (closing quote, stray comma).
        if (
            not current_sentence
            and sentences
            and (token in split_punctuation or token in closing_marks)
        ):
            sentences[-1].append(token)
            continue
        current_sentence.append(token)
        if token not in split_punctuation:
            continue
        if token == "." and _period_ends_abbreviation(
            "".join(current_sentence), abbreviations
        ):
            continue  # abbreviation period: keep the sentence open
        sentences.append(current_sentence)
        current_sentence = []
    # Whatever trails the last punctuation mark is also a sentence.
    if current_sentence:
        sentences.append(current_sentence)

    # 2. Greedily merge consecutive sentences into chunks of <= chunk_len.
    merged_chunks: List[List[str]] = []
    current_chunk: List[str] = []
    for sentence in sentences:
        if len(current_chunk) + len(sentence) <= chunk_len:
            current_chunk.extend(sentence)
        else:
            if current_chunk:
                merged_chunks.append(current_chunk)
            current_chunk = sentence
    if current_chunk:
        merged_chunks.append(current_chunk)

    # 3. Post-process: merge undersized chunks into a neighbour.  A short
    # first chunk absorbs the second chunk; any other short chunk is
    # appended to the chunk before it.
    if min_chunk_len is not None:
        first_is_short = bool(merged_chunks) and len(merged_chunks[0]) < min_chunk_len
        final_chunks: List[List[str]] = []
        for i, chunk in enumerate(merged_chunks):
            if i == 1 and first_is_short:
                final_chunks[-1].extend(chunk)
            elif len(chunk) >= min_chunk_len or not final_chunks:
                final_chunks.append(chunk)
            else:
                final_chunks[-1].extend(chunk)
    else:
        final_chunks = merged_chunks

    return [s for s in ("".join(c).strip() for c in final_chunks) if s]
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def add_punctuation(text: str):
    """Ensure *text* ends with sentence-final punctuation.

    The input is stripped of surrounding whitespace first; an empty
    string is returned unchanged. When the stripped text does not
    already end with a character from ``END_PUNCTUATION``, a full stop
    is appended: the CJK "。" when the text contains any character in
    the CJK Unified Ideographs range, the ASCII "." otherwise.
    """
    stripped = text.strip()

    if not stripped:
        return stripped

    if stripped[-1] in END_PUNCTUATION:
        return stripped

    has_cjk = any("\u4e00" <= ch <= "\u9fff" for ch in stripped)
    return stripped + ("。" if has_cjk else ".")
|
omnivoice/utils/voice_design.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Voice-design instruct constants for TTS inference.

Defines speaker attribute tags (gender, age, pitch, accent, dialect) and
translation/validation utilities between English and Chinese. Used by
``OmniVoice.generate()`` for voice design mode.
"""

import re

# Matches any CJK Unified Ideograph; used to split tags into EN vs ZH.
_ZH_RE = re.compile(r'[\u4e00-\u9fff]')

# Each entry is one attribute category whose members are mutually exclusive.
# dict entries map {english: chinese}; the two plain sets at the end are the
# accents (English-only) and dialects (Chinese-only), which have no
# counterpart in the other language.
_INSTRUCT_CATEGORIES = [
    {"male": "男", "female": "女"},
    {"child": "儿童", "teenager": "少年", "young adult": "青年",
     "middle-aged": "中年", "elderly": "老年"},
    {"very low pitch": "极低音调", "low pitch": "低音调",
     "moderate pitch": "中音调", "high pitch": "高音调",
     "very high pitch": "极高音调"},
    {"whisper": "耳语"},
    # Accent (English-only, no Chinese counterpart)
    {"american accent", "british accent", "australian accent",
     "chinese accent", "canadian accent", "indian accent",
     "korean accent", "portuguese accent", "russian accent", "japanese accent"},
    # Dialect (Chinese-only, no English counterpart)
    {"河南话", "陕西话", "四川话", "贵州话", "云南话", "桂林话",
     "济南话", "石家庄话", "甘肃话", "宁夏话", "青岛话", "东北话"},
]

# English -> Chinese tag translation, flattened over all dict categories.
_INSTRUCT_EN_TO_ZH = {
    en: zh
    for group in _INSTRUCT_CATEGORIES
    if isinstance(group, dict)
    for en, zh in group.items()
}
# Reverse translation (the category data contains no duplicate values).
_INSTRUCT_ZH_TO_EN = {zh: en for en, zh in _INSTRUCT_EN_TO_ZH.items()}

# One set per category, containing every spelling (EN and ZH) of its tags;
# tags within the same set must not be combined in a single instruct.
_INSTRUCT_MUTUALLY_EXCLUSIVE = [
    set(group) | set(group.values()) if isinstance(group, dict) else set(group)
    for group in _INSTRUCT_CATEGORIES
]

# Union of every recognized tag in either language.
_INSTRUCT_ALL_VALID = (
    set(_INSTRUCT_EN_TO_ZH)
    | set(_INSTRUCT_ZH_TO_EN)
    | _INSTRUCT_MUTUALLY_EXCLUSIVE[-2]  # accents
    | _INSTRUCT_MUTUALLY_EXCLUSIVE[-1]  # dialects
)

# Partition the valid tags by script: tags with no CJK character are English.
_INSTRUCT_VALID_EN = frozenset(
    tag for tag in _INSTRUCT_ALL_VALID if _ZH_RE.search(tag) is None
)
_INSTRUCT_VALID_ZH = frozenset(_INSTRUCT_ALL_VALID) - _INSTRUCT_VALID_EN
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--extra-index-url https://download.pytorch.org/whl/cu128
|
| 2 |
+
torch==2.8.0
|
| 3 |
+
torchaudio==2.8.0
|
| 4 |
+
transformers==5.3
|
| 5 |
+
accelerate
|
| 6 |
+
pydub
|
| 7 |
+
soundfile
|
| 8 |
+
numpy
|
| 9 |
+
gradio
|
| 10 |
+
hf_transfer
|