---
language:
- multilingual
library_name: onnxruntime
pipeline_tag: audio-classification
tags:
- audio
- speech
- tts
- quality-classification
- wav2vec2
- onnx
license: apache-2.0
---
# TTS Suitability Classifier
ONNX audio classifier that estimates whether a speech segment is suitable for
TTS training.

The model is a binary classifier based on the 300M wav2vec2 encoder from
[facebook/omniASR](https://github.com/facebookresearch/omnilingual-asr).
The ONNX file is self-contained and does not require `fairseq2`, PyTorch, or the
original omnilingual-asr repository for inference.
## Labels
| Class | Label | Meaning |
|---:|---|---|
| 0 | `not_tts` | Audio is not suitable for TTS training |
| 1 | `tts` | Audio is suitable for TTS training |

`p_tts` is the softmax probability of class 1. The default decision threshold
is `0.5`. For dataset filtering, choose the threshold on a manually labeled
validation set.
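
As a rough illustration of how `p_tts` and the threshold relate to the raw logits, the sketch below computes the two-class softmax by hand. The helper names `tts_probability` and `decide` are hypothetical, and treating `p_tts >= threshold` (rather than a strict `>`) as `tts` is an assumption; the bundled `inference.py` may handle ties differently.

```python
import math

def tts_probability(logits):
    """Softmax probability of class 1 (`tts`) for a [not_tts, tts] logit pair."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    return exps[1] / sum(exps)

def decide(logits, threshold=0.5):
    """Return (label, p_tts); the >= comparison is an assumption."""
    p_tts = tts_probability(logits)
    label = "tts" if p_tts >= threshold else "not_tts"
    return label, p_tts

label, p_tts = decide([-2.2, 1.5])
print(label, round(p_tts, 2))  # tts 0.98
```

Raising the threshold trades recall for precision: the same logits are labeled `not_tts` with `threshold=0.99`.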
## Installation
```bash
pip install -r requirements.txt
```
For CUDA inference, replace `onnxruntime` with a compatible
`onnxruntime-gpu` build.
## Command-line inference
```bash
python inference.py sample.mp3
python inference.py /path/to/audio-directory --provider cpu
python inference.py sample.wav --provider cuda --cuda-device-id 0
```
Each result is printed as one JSON object:
```json
{
  "label": "tts",
  "predicted_class": 1,
  "p_not_tts": 0.02,
  "p_tts": 0.98,
  "logits": [-2.2, 1.5]
}
```
## Python API
```python
from inference import TTSSuitabilityClassifier
classifier = TTSSuitabilityClassifier(provider="auto")
result = classifier.predict("sample.mp3")
print(result["label"])
print(result["p_tts"])
```
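
For dataset filtering, the same API can be wrapped in a small loop. This is a minimal sketch: `select_tts_files` is a hypothetical helper, not part of `inference.py`, and it assumes only that `predict` returns a dict containing `p_tts`, as in the example above.

```python
def select_tts_files(classifier, paths, threshold=0.5):
    """Return (path, p_tts) pairs whose p_tts meets the threshold."""
    kept = []
    for path in paths:
        result = classifier.predict(path)  # dict with "label" and "p_tts"
        if result["p_tts"] >= threshold:
            kept.append((path, result["p_tts"]))
    return kept

# Usage with the classifier from this repository (requires model.onnx):
# from inference import TTSSuitabilityClassifier
# classifier = TTSSuitabilityClassifier(provider="auto")
# kept = select_tts_files(classifier, ["a.wav", "b.wav"], threshold=0.8)
```

Passing the classifier in as an argument keeps the helper easy to test with a stand-in object.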
## Input preprocessing
The included inference code applies the same preprocessing as the training and
export recipe:
1. Decode WAV, FLAC, MP3, OGG, or M4A.
2. Mix channels to mono.
3. Resample to 16 kHz.
4. Apply waveform layer normalization.
5. Split long audio into 10-second chunks.
6. Average chunk logits and apply softmax.
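
Steps 2 and 4 to 6 can be sketched in NumPy as follows. This is an illustration rather than the code in `inference.py`: decoding and resampling are left out (the input is assumed to be at 16 kHz already), the `eps` value is an assumption, normalizing before splitting simply follows the order of the list, and how the real code treats a short final chunk is not stated here.

```python
import numpy as np

SAMPLE_RATE = 16_000
CHUNK_SAMPLES = 10 * SAMPLE_RATE  # 10-second chunks = 160000 samples

def to_mono(audio):
    """Average channels; accepts [num_frames] or [num_channels, num_frames]."""
    audio = np.asarray(audio, dtype=np.float32)
    return audio if audio.ndim == 1 else audio.mean(axis=0)

def layer_norm(waveform, eps=1e-5):
    """Zero-mean, unit-variance normalization over time (eps is an assumption)."""
    normed = (waveform - waveform.mean()) / np.sqrt(waveform.var() + eps)
    return normed.astype(np.float32)

def split_chunks(waveform, chunk_samples=CHUNK_SAMPLES):
    """Consecutive chunks; the last one may be shorter than chunk_samples."""
    return [waveform[i:i + chunk_samples]
            for i in range(0, len(waveform), chunk_samples)]

def aggregate(chunk_logits):
    """Average per-chunk logits, then apply a numerically stable softmax."""
    mean_logits = np.mean(np.asarray(chunk_logits, dtype=np.float64), axis=0)
    exps = np.exp(mean_logits - mean_logits.max())
    return exps / exps.sum()

# 25 seconds of synthetic stereo audio, for illustration only.
stereo = np.random.default_rng(0).normal(size=(2, 25 * SAMPLE_RATE))
chunks = split_chunks(layer_norm(to_mono(stereo)))
print([len(c) for c in chunks])  # [160000, 160000, 80000]
```

Averaging logits before the softmax, as step 6 describes, gives one probability pair per file regardless of how many chunks it was split into.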

The ONNX input is a float32 tensor named `waveforms` with shape
`[batch_size, num_frames]`. The output is `logits` with shape
`[batch_size, 2]`. Both input axes are dynamic; ONNX opset 17 is used.
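
For callers who want to bypass `inference.py`, the tensor names above map onto ONNX Runtime roughly as shown below. This is a sketch under assumptions: `load_session` and `run_logits` are hypothetical helper names, the CPU provider is chosen for simplicity, and the waveform is expected to be preprocessed already (mono, 16 kHz, normalized).

```python
import numpy as np

def load_session(model_path="model.onnx"):
    """Create a CPU ONNX Runtime session (imported lazily on first use)."""
    import onnxruntime as ort
    return ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

def run_logits(session, waveform):
    """Run one preprocessed mono waveform and return its logits, shape [2]."""
    batch = np.asarray(waveform, dtype=np.float32)[np.newaxis, :]  # [1, num_frames]
    (logits,) = session.run(["logits"], {"waveforms": batch})
    return logits[0]

# Usage (requires onnxruntime and model.onnx):
# session = load_session("model.onnx")
# logits = run_logits(session, normalized_waveform)
```

Because `batch_size` is dynamic, several equal-length chunks could also be stacked into one call instead of being run one at a time.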
## Files
- `model.onnx`: self-contained FP32 ONNX model.
- `inference.py`: standalone ONNX Runtime inference.
- `requirements.txt`: CPU inference dependencies.
## Upload to Hugging Face
Create an empty model repository, then run from this directory:
```bash
hf upload-large-folder <username>/<repo-name> . --repo-type model
```
`model.onnx` is configured for Git LFS in `.gitattributes`.
## Training and export
The released model corresponds to training checkpoint step 94,000. It was
exported using the repository recipes:
- `workflows/recipes/wav2vec2/binary_classification/export_onnx.py`
- `workflows/recipes/wav2vec2/binary_classification/run_onnx.py`

Model details:
- Architecture: `wav2vec2_asr 300m`
- Sample rate: `16000 Hz`
- Training maximum audio length: `160000` samples
- Classes: `not_tts`, `tts`
## Limitations
- The score measures similarity to the training definition of TTS-suitable
audio; it is not a general-purpose MOS score.
- Music, noise, clipping, overlapping speakers, and unusual recording
conditions may affect predictions.
- Probabilities are not guaranteed to be calibrated.
- Validate the threshold on data from the intended domain before filtering a
large dataset.
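
One possible way to validate the threshold is to sweep candidate values over a labeled validation set and keep the lowest one that reaches a target precision. The sketch below assumes that criterion; the helper names and the six scores and labels are made up for illustration.

```python
def precision_recall(p_tts, labels, threshold):
    """Precision and recall of `tts` when keeping clips with p_tts >= threshold."""
    kept = [y for p, y in zip(p_tts, labels) if p >= threshold]
    true_pos = sum(kept)
    positives = sum(labels)
    precision = true_pos / len(kept) if kept else 1.0
    recall = true_pos / positives if positives else 0.0
    return precision, recall

def pick_threshold(p_tts, labels, min_precision=0.95):
    """Lowest observed score whose precision meets the target, or None."""
    for threshold in sorted(set(p_tts)):
        precision, _ = precision_recall(p_tts, labels, threshold)
        if precision >= min_precision:
            return threshold
    return None

# Hypothetical validation scores and manual labels (1 = suitable for TTS).
scores = [0.10, 0.40, 0.55, 0.70, 0.90, 0.95]
labels = [0, 0, 1, 0, 1, 1]
print(pick_threshold(scores, labels, min_precision=0.95))  # 0.9
```

Taking the lowest qualifying threshold keeps as much data as the precision target allows; a stricter target discards more clips.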
## License
Apache 2.0. The base architecture and code originate from the
omnilingual-asr project.
## Contact
- Email: kborodin.research@gmail.com
- Telegram: [@korallll_ai](https://t.me/korallll_ai)