| --- |
| language: |
| - multilingual |
| library_name: onnxruntime |
| pipeline_tag: audio-classification |
| tags: |
| - audio |
| - speech |
| - tts |
| - quality-classification |
| - wav2vec2 |
| - onnx |
| license: apache-2.0 |
| --- |
| |
| # TTS Suitability Classifier |
|
|
| ONNX audio classifier that estimates whether a speech segment is suitable for |
| TTS training. |
|
|
| The model is a binary classifier built on the 300M-parameter wav2vec2 encoder from |
| [facebook/omniASR](https://github.com/facebookresearch/omnilingual-asr). |
| The ONNX file is self-contained and does not require `fairseq2`, PyTorch, or the |
| original omnilingual-asr repository for inference. |
|
|
| ## Labels |
|
|
| | Class | Label | Meaning | |
| |---:|---|---| |
| | 0 | `not_tts` | Audio is not suitable for TTS training | |
| | 1 | `tts` | Audio is suitable for TTS training | |
|
|
| `p_tts` is the softmax probability of class 1. The default decision threshold |
| is `0.5`. For dataset filtering, choose the threshold on a manually labeled |
| validation set. |
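
One way to choose a threshold is to sweep the observed validation scores and keep the lowest one that reaches a target precision, which retains as many clips as possible at that precision. The sketch below assumes you already have `p_tts` scores and manual 0/1 labels. The function name `pick_threshold` and the precision criterion are illustrative assumptions, not part of this repository.

```python
def pick_threshold(scores, labels, min_precision=0.95):
    """Return the lowest threshold whose validation precision is at least
    `min_precision`, or None if no threshold reaches it.

    scores: p_tts values in [0, 1]; labels: 1 for tts, 0 for not_tts.
    """
    for threshold in sorted(set(scores)):
        # Labels of the clips that would be kept at this threshold.
        kept = [label for score, label in zip(scores, labels) if score >= threshold]
        if not kept:
            continue
        precision = sum(kept) / len(kept)
        if precision >= min_precision:
            return threshold
    return None
```

A precision target suits dataset filtering, where a bad clip that slips through usually costs more than a good clip that is dropped. Other criteria, such as a recall target, fit the same loop.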
|
|
| ## Installation |
|
|
| ```bash |
| pip install -r requirements.txt |
| ``` |
|
|
| For CUDA inference, replace `onnxruntime` with a compatible |
| `onnxruntime-gpu` build. |
|
|
| ## Command-line inference |
|
|
| ```bash |
| python inference.py sample.mp3 |
| python inference.py /path/to/audio-directory --provider cpu |
| python inference.py sample.wav --provider cuda --cuda-device-id 0 |
| ``` |
|
|
| Each result is printed as one JSON object: |
|
|
| ```json |
| { |
| "label": "tts", |
| "predicted_class": 1, |
| "p_not_tts": 0.02, |
| "p_tts": 0.98, |
| "logits": [-2.2, 1.5] |
| } |
| ``` |
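
The probabilities follow from the logits by a two-class softmax. A minimal standard-library sketch using the example values above (the helper name `softmax2` is illustrative, not something `inference.py` is known to export):

```python
import math


def softmax2(logits):
    """Numerically stable softmax over a pair of logits."""
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]


# Logits from the JSON example: [not_tts, tts].
p_not_tts, p_tts = softmax2([-2.2, 1.5])
label = "tts" if p_tts >= 0.5 else "not_tts"
```

Rounded to two decimals, these values match the `p_not_tts` and `p_tts` fields shown in the JSON example.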
|
|
| ## Python API |
|
|
| ```python |
| from inference import TTSSuitabilityClassifier |
| |
| classifier = TTSSuitabilityClassifier(provider="auto") |
| result = classifier.predict("sample.mp3") |
| |
| print(result["label"]) |
| print(result["p_tts"]) |
| ``` |
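
For dataset filtering, the same `predict` call can be wrapped in a small loop. The sketch below relies only on the `predict(path)` call and the `p_tts` key shown above. The helper name `select_tts_files` is hypothetical, and it assumes `predict` accepts one file path at a time.

```python
def select_tts_files(classifier, paths, threshold=0.5):
    """Return the paths whose `p_tts` meets the threshold.

    `classifier` is any object with a `predict(path)` method that returns a
    dict containing `p_tts`, such as `TTSSuitabilityClassifier`.
    """
    selected = []
    for path in paths:
        result = classifier.predict(path)
        if result["p_tts"] >= threshold:
            selected.append(path)
    return selected
```

With the classifier constructed as in the example above, `select_tts_files(classifier, ["a.wav", "b.wav"], threshold=0.8)` would return the subset of those two files that score at least `0.8`.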
|
|
| ## Input preprocessing |
|
|
| The included inference code applies the same preprocessing as the training and |
| export recipe: |
|
|
| 1. Decode WAV, FLAC, MP3, OGG, or M4A. |
| 2. Mix channels to mono. |
| 3. Resample to 16 kHz. |
| 4. Apply waveform layer normalization. |
| 5. Split long audio into 10-second chunks. |
| 6. Average chunk logits and apply softmax. |
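
Steps 2 and 4 to 6 can be sketched with the standard library alone. This is an illustration of the listed steps, not the code in `inference.py`. Decoding and resampling are omitted because they need an audio library. The `eps` value, normalizing the whole waveform before chunking, and keeping a shorter final chunk are all assumptions, and all function names are hypothetical.

```python
import math

SAMPLE_RATE = 16_000
CHUNK_SAMPLES = 10 * SAMPLE_RATE  # 10-second chunks at 16 kHz


def to_mono(channels):
    """Step 2: average equal-length channels sample by sample."""
    return [sum(frame) / len(frame) for frame in zip(*channels)]


def layer_norm(waveform, eps=1e-5):
    """Step 4: scale a waveform to zero mean and unit variance (eps is assumed)."""
    mean = sum(waveform) / len(waveform)
    variance = sum((x - mean) ** 2 for x in waveform) / len(waveform)
    scale = math.sqrt(variance + eps)
    return [(x - mean) / scale for x in waveform]


def split_chunks(waveform, chunk_samples=CHUNK_SAMPLES):
    """Step 5: consecutive chunks; the last one may be shorter (assumed)."""
    return [waveform[i:i + chunk_samples]
            for i in range(0, len(waveform), chunk_samples)]


def aggregate(chunk_logits):
    """Step 6: average per-chunk logits, then apply softmax."""
    count = len(chunk_logits)
    mean_logits = [sum(column) / count for column in zip(*chunk_logits)]
    peak = max(mean_logits)
    exps = [math.exp(value - peak) for value in mean_logits]
    total = sum(exps)
    return [value / total for value in exps]
```

Averaging logits before the softmax, as step 6 states, gives a different result from averaging per-chunk probabilities, so the order matters when reproducing the scores.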
|
|
| The ONNX input is a float32 tensor named `waveforms` with shape |
| `[batch_size, num_frames]`. The output is `logits` with shape |
| `[batch_size, 2]`. Both input axes are dynamic; ONNX opset 17 is used. |
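
The tensor names above are enough to call the model directly through ONNX Runtime. The sketch below assumes the waveforms are already preprocessed and of equal length within a batch. The helper names `load_session` and `run_logits` are hypothetical, and `onnxruntime` is imported lazily so the helpers can be defined without it installed.

```python
import numpy as np


def load_session(model_path="model.onnx"):
    """Create a CPU inference session for the exported model."""
    import onnxruntime as ort  # imported here so this module loads without it

    return ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])


def run_logits(session, waveforms):
    """Run a [batch_size, num_frames] batch and return [batch_size, 2] logits."""
    batch = np.asarray(waveforms, dtype=np.float32)
    if batch.ndim != 2:
        raise ValueError("expected shape [batch_size, num_frames]")
    # Input and output names are the ones documented above.
    (logits,) = session.run(["logits"], {"waveforms": batch})
    return logits
```

For example, `run_logits(load_session(), batch)` with a `(1, 160000)` float32 array would score one 10-second chunk.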
|
|
| ## Files |
|
|
| - `model.onnx`: self-contained FP32 ONNX model. |
| - `inference.py`: standalone ONNX Runtime inference. |
| - `requirements.txt`: CPU inference dependencies. |
|
|
| ## Upload to Hugging Face |
|
|
| Create an empty model repository, then run from this directory: |
|
|
| ```bash |
| hf upload-large-folder <username>/<repo-name> . --repo-type model |
| ``` |
|
|
| `model.onnx` is configured for Git LFS in `.gitattributes`. |
|
|
| ## Training and export |
|
|
| The released model corresponds to training checkpoint step 94,000. It was |
| exported using the repository recipes: |
|
|
| - `workflows/recipes/wav2vec2/binary_classification/export_onnx.py` |
| - `workflows/recipes/wav2vec2/binary_classification/run_onnx.py` |
|
|
| - Architecture: `wav2vec2_asr 300m` |
| - Sample rate: `16000 Hz` |
| - Training maximum audio length: `160000` samples |
| - Classes: `not_tts`, `tts` |
|
|
| ## Limitations |
|
|
| - The score measures similarity to the training definition of TTS-suitable |
| audio; it is not a general-purpose mean opinion score (MOS). |
| - Music, noise, clipping, overlapping speakers, and unusual recording |
| conditions may affect predictions. |
| - Probabilities are not guaranteed to be calibrated. |
| - Validate the threshold on data from the intended domain before filtering a |
| large dataset. |
|
|
| ## License |
|
|
| Apache 2.0. The base architecture and code originate from the |
| omnilingual-asr project. |
|
|
| ## Contact |
|
|
| - Email: kborodin.research@gmail.com |
| - Telegram: [@korallll_ai](https://t.me/korallll_ai) |
|
|