---
language:
- multilingual
library_name: onnxruntime
pipeline_tag: audio-classification
tags:
- audio
- speech
- tts
- quality-classification
- wav2vec2
- onnx
license: apache-2.0
---
# TTS Suitability Classifier
ONNX audio classifier that estimates whether a speech segment is suitable for
TTS training.
The model is a binary classifier based on the 300M wav2vec2 encoder from
[facebook/omniASR](https://github.com/facebookresearch/omnilingual-asr).
The ONNX file is self-contained and does not require `fairseq2`, PyTorch, or the
original omnilingual-asr repository for inference.
## Labels
| Class | Label | Meaning |
|---:|---|---|
| 0 | `not_tts` | Audio is not suitable for TTS training |
| 1 | `tts` | Audio is suitable for TTS training |
`p_tts` is the softmax probability of class 1. The default decision threshold
is `0.5`. For dataset filtering, choose the threshold on a manually labeled
validation set.
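
One way to choose such a threshold is to sweep the scores observed on the labeled validation set and keep the lowest one that reaches a target precision. The sketch below is only an illustration: the function name `pick_threshold`, the precision target, and the toy scores are assumptions, not part of this repository.

```python
def pick_threshold(p_tts, labels, min_precision=0.95):
    """Return (threshold, precision, recall) for the lowest threshold whose
    precision reaches min_precision, or None if no threshold does."""
    pairs = list(zip(p_tts, labels))
    total_pos = sum(1 for _, y in pairs if y)
    # Candidate thresholds: every distinct score seen in the validation set.
    for threshold in sorted(set(p_tts)):
        kept = [y for p, y in pairs if p >= threshold]
        true_pos = sum(1 for y in kept if y)
        precision = true_pos / len(kept)  # kept is never empty here
        if precision >= min_precision:
            recall = true_pos / total_pos if total_pos else 0.0
            return threshold, precision, recall
    return None

# Toy validation set (made-up numbers, for illustration only).
scores = [0.10, 0.40, 0.55, 0.70, 0.80, 0.95]
truth = [0, 0, 1, 0, 1, 1]
result = pick_threshold(scores, truth, min_precision=0.9)
print(result)
```

A precision target suits dataset filtering when a false accept costs more than a discarded clip; if recall matters more, the same sweep can rank thresholds by a different criterion.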
## Installation
```bash
pip install -r requirements.txt
```
For CUDA inference, replace `onnxruntime` with a compatible
`onnxruntime-gpu` build.
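
Whether CUDA is actually usable depends on which execution providers the installed ONNX Runtime build reports. The helper below is a hypothetical sketch of how a `cpu` / `cuda` / `auto` choice could map to a provider list; it is not taken from `inference.py`, whose selection logic may differ.

```python
CUDA = "CUDAExecutionProvider"
CPU = "CPUExecutionProvider"

def provider_list(available, provider="auto"):
    """Map a provider choice to an ONNX Runtime provider list (hypothetical helper)."""
    if provider == "cpu":
        return [CPU]
    if provider == "cuda":
        if CUDA not in available:
            raise RuntimeError(f"{CUDA} not available; is onnxruntime-gpu installed?")
        return [CUDA, CPU]  # keep CPU as a fallback
    # "auto": prefer CUDA when the build offers it, otherwise CPU only.
    return [CUDA, CPU] if CUDA in available else [CPU]

# In practice the list of available providers would come from:
#   import onnxruntime
#   available = onnxruntime.get_available_providers()
cpu_only = provider_list([CPU])
with_gpu = provider_list([CUDA, CPU])
```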
## Command-line inference
```bash
python inference.py sample.mp3
python inference.py /path/to/audio-directory --provider cpu
python inference.py sample.wav --provider cuda --cuda-device-id 0
```
Each result is printed as one JSON object:
```json
{
"label": "tts",
"predicted_class": 1,
"p_not_tts": 0.02,
"p_tts": 0.98,
"logits": [-2.2, 1.5]
}
```
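
The probability fields follow from `logits` by a softmax. As a quick check, the sketch below recomputes the sample values above; it assumes the label is obtained by comparing `p_tts` with the default `0.5` threshold.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a short list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [-2.2, 1.5]
p_not_tts, p_tts = softmax(logits)
predicted_class = int(p_tts >= 0.5)  # default threshold from the Labels section
label = ["not_tts", "tts"][predicted_class]
print(label, round(p_not_tts, 2), round(p_tts, 2))  # tts 0.02 0.98
```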
## Python API
```python
from inference import TTSSuitabilityClassifier
classifier = TTSSuitabilityClassifier(provider="auto")
result = classifier.predict("sample.mp3")
print(result["label"])
print(result["p_tts"])
```
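
For dataset filtering, the same `predict` call can be applied to a list of files. The following is a minimal sketch, assuming only that `predict` returns a dict with a `p_tts` key as shown above; `filter_for_tts`, the `clips` directory, and the stand-in scores are hypothetical.

```python
from pathlib import Path

def filter_for_tts(paths, predict, threshold=0.5):
    """Split paths into (kept, rejected) according to p_tts >= threshold."""
    kept, rejected = [], []
    for path in paths:
        result = predict(str(path))
        (kept if result["p_tts"] >= threshold else rejected).append(Path(path))
    return kept, rejected

# With the real classifier (requires model.onnx and inference.py from this repo):
#   from inference import TTSSuitabilityClassifier
#   classifier = TTSSuitabilityClassifier(provider="auto")
#   kept, rejected = filter_for_tts(Path("clips").glob("*.wav"), classifier.predict, 0.8)

# Stand-in predictor so the sketch runs without the model (made-up scores).
fake_scores = {"a.wav": 0.98, "b.wav": 0.31, "c.wav": 0.80}

def fake_predict(path):
    return {"p_tts": fake_scores[path]}

kept, rejected = filter_for_tts(fake_scores, fake_predict, threshold=0.8)
print([p.name for p in kept], [p.name for p in rejected])
```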
## Input preprocessing
The included inference code applies the same preprocessing as the training and
export recipe:
1. Decode WAV, FLAC, MP3, OGG, or M4A.
2. Mix channels to mono.
3. Resample to 16 kHz.
4. Apply waveform layer normalization.
5. Split long audio into 10-second chunks.
6. Average chunk logits and apply softmax.
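
Steps 2 and 4 to 6 can be sketched in NumPy as follows. This is an illustration rather than the code in `inference.py`: it assumes the audio is already decoded and resampled to 16 kHz, that normalization is applied to the whole waveform before splitting, that the final chunk may be shorter than 10 seconds, and that the normalization epsilon is `1e-5`.

```python
import numpy as np

SAMPLE_RATE = 16_000
CHUNK_SAMPLES = 10 * SAMPLE_RATE  # 10-second chunks -> 160000 samples

def to_mono(audio):
    """Average channels; accepts shape [num_frames] or [num_frames, channels]."""
    audio = np.asarray(audio, dtype=np.float32)
    return audio if audio.ndim == 1 else audio.mean(axis=1)

def layer_norm(waveform, eps=1e-5):
    """Normalize to zero mean and unit variance (eps is an assumed value)."""
    normed = (waveform - waveform.mean()) / np.sqrt(waveform.var() + eps)
    return normed.astype(np.float32)

def split_chunks(waveform, chunk_samples=CHUNK_SAMPLES):
    """Split into consecutive chunks; the last one may be shorter (assumption)."""
    return [waveform[i:i + chunk_samples]
            for i in range(0, len(waveform), chunk_samples)]

def aggregate(chunk_logits):
    """Average per-chunk logits, then apply a softmax."""
    mean_logits = np.mean(np.asarray(chunk_logits, dtype=np.float64), axis=0)
    exp = np.exp(mean_logits - mean_logits.max())
    return exp / exp.sum()

# 25 s of synthetic stereo audio standing in for a decoded 16 kHz file.
rng = np.random.default_rng(0)
stereo = rng.normal(size=(25 * SAMPLE_RATE, 2))
normalized = layer_norm(to_mono(stereo))
chunks = split_chunks(normalized)

# Made-up per-chunk logits; in practice these come from the ONNX model.
probs = aggregate([[-2.0, 1.0], [-1.0, 2.0], [0.0, 0.0]])
print([len(c) for c in chunks], probs)
```

Averaging logits before the softmax means a single confident chunk can outweigh several uncertain ones, which differs slightly from averaging per-chunk probabilities.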
The ONNX input is a float32 tensor named `waveforms` with shape
`[batch_size, num_frames]`. The output is `logits` with shape
`[batch_size, 2]`. Both input axes are dynamic; ONNX opset 17 is used.
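
To call the model without `inference.py`, the tensor names above are all that ONNX Runtime needs. The sketch below wraps the call in a hypothetical helper, `run_logits`, and exercises it with a stand-in session so it runs without the model file; the commented lines show the assumed usage with the real `model.onnx`.

```python
import numpy as np

def run_logits(session, waveforms):
    """Run one batch through an ONNX Runtime session; returns logits [batch_size, 2]."""
    batch = np.ascontiguousarray(waveforms, dtype=np.float32)
    if batch.ndim != 2:
        raise ValueError(f"expected [batch_size, num_frames], got shape {batch.shape}")
    (logits,) = session.run(["logits"], {"waveforms": batch})
    return logits

# With the released model (requires onnxruntime and model.onnx):
#   import onnxruntime as ort
#   session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
#   logits = run_logits(session, preprocessed_batch)  # preprocessed as described above

class FakeSession:
    """Stand-in with the same run() signature, so the sketch runs without the model."""
    def run(self, output_names, feeds):
        x = feeds["waveforms"]
        return [np.zeros((x.shape[0], 2), dtype=np.float32)]

out = run_logits(FakeSession(), np.zeros((3, 16000)))
print(out.shape)
```

Because `num_frames` is dynamic, clips of different lengths cannot share a batch without padding; running one chunk per call avoids deciding how to pad.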
## Files
- `model.onnx`: self-contained FP32 ONNX model.
- `inference.py`: standalone ONNX Runtime inference.
- `requirements.txt`: CPU inference dependencies.
## Upload to Hugging Face
Create an empty model repository, then run from this directory:
```bash
hf upload-large-folder <username>/<repo-name> . --repo-type model
```
`model.onnx` is configured for Git LFS in `.gitattributes`.
## Training and export
The released model corresponds to training checkpoint step 94,000. It was
exported using the repository recipes:
- `workflows/recipes/wav2vec2/binary_classification/export_onnx.py`
- `workflows/recipes/wav2vec2/binary_classification/run_onnx.py`

Model and training settings:

- Architecture: `wav2vec2_asr 300m`
- Sample rate: `16000 Hz`
- Training maximum audio length: `160000` samples (10 s at 16 kHz)
- Classes: `not_tts`, `tts`
## Limitations
- The score measures similarity to the training definition of TTS-suitable
audio; it is not a general-purpose MOS (mean opinion score) estimate.
- Music, noise, clipping, overlapping speakers, and unusual recording
conditions may affect predictions.
- Probabilities are not guaranteed to be calibrated.
- Validate the threshold on data from the intended domain before filtering a
large dataset.
## License
Apache 2.0. The base architecture and code originate from the
omnilingual-asr project.
## Contact
- Email: kborodin.research@gmail.com
- Telegram: [@korallll_ai](https://t.me/korallll_ai)