---
language:
- multilingual
library_name: onnxruntime
pipeline_tag: audio-classification
tags:
- audio
- speech
- tts
- quality-classification
- wav2vec2
- onnx
license: apache-2.0
---

# TTS Suitability Classifier

ONNX audio classifier that estimates whether a speech segment is suitable for
TTS training.

The model is a binary classifier based on the 300M wav2vec2 encoder from
[facebook/omniASR](https://github.com/facebookresearch/omnilingual-asr).
The ONNX file is self-contained and does not require `fairseq2`, PyTorch, or the
original omnilingual-asr repository for inference.

## Labels

| Class | Label | Meaning |
|---:|---|---|
| 0 | `not_tts` | Audio is not suitable for TTS training |
| 1 | `tts` | Audio is suitable for TTS training |

`p_tts` is the softmax probability of class 1. The default decision threshold
is `0.5`. For dataset filtering, choose the threshold on a manually labeled
validation set.
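The relationship between `logits`, `p_tts`, and the decision threshold can be sketched in plain Python. This sketch assumes the two logits are ordered `[not_tts, tts]`, matching the class indices in the table above. `label_from_logits` is a hypothetical helper and is not part of `inference.py`.

```python
import math


def softmax(logits):
    # Subtract the largest logit before exponentiating for numerical stability.
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def label_from_logits(logits, threshold=0.5):
    # Assumes logits are ordered [not_tts, tts], i.e. class 0 then class 1.
    _, p_tts = softmax(logits)
    # Assumption: a score exactly equal to the threshold counts as `tts`.
    label = "tts" if p_tts >= threshold else "not_tts"
    return label, p_tts


label, p_tts = label_from_logits([-2.2, 1.5])
print(label, round(p_tts, 2))  # tts 0.98
```

Raising the threshold above `0.5` keeps fewer segments but makes each kept segment more likely to be `tts`.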

## Installation

```bash
pip install -r requirements.txt
```

For CUDA inference, replace `onnxruntime` with a compatible
`onnxruntime-gpu` build.

## Command-line inference

```bash
python inference.py sample.mp3
python inference.py /path/to/audio-directory --provider cpu
python inference.py sample.wav --provider cuda --cuda-device-id 0
```

Each result is printed as one JSON object:

```json
{
  "label": "tts",
  "predicted_class": 1,
  "p_not_tts": 0.02,
  "p_tts": 0.98,
  "logits": [-2.2, 1.5]
}
```
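When calling the script from another program, the printed results can be collected with the standard `json` module. The following is a sketch under a few assumptions: stdout contains only the JSON objects (pretty-printed or one per line), and the script is run from this directory with the `--provider` flag shown above. `parse_results` and `run_cli` are hypothetical helpers.

```python
import json
import subprocess
import sys


def parse_results(text):
    """Parse one or more JSON objects printed back to back."""
    decoder = json.JSONDecoder()
    results = []
    idx = 0
    while idx < len(text):
        # Skip whitespace (including newlines) between objects.
        while idx < len(text) and text[idx].isspace():
            idx += 1
        if idx >= len(text):
            break
        # raw_decode returns the object and the index where it ended.
        obj, idx = decoder.raw_decode(text, idx)
        results.append(obj)
    return results


def run_cli(path, provider="cpu"):
    # Assumes inference.py is in the current directory.
    proc = subprocess.run(
        [sys.executable, "inference.py", path, "--provider", provider],
        capture_output=True,
        text=True,
        check=True,
    )
    return parse_results(proc.stdout)
```

The example object above has no file path field, so for directory input, check how `inference.py` reports which file each result belongs to before relying on output order.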

## Python API

```python
from inference import TTSSuitabilityClassifier

classifier = TTSSuitabilityClassifier(provider="auto")
result = classifier.predict("sample.mp3")

print(result["label"])
print(result["p_tts"])
```
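For dataset filtering, `predict` can be applied to every audio file in a folder. This is a sketch: `select_tts_files` is a hypothetical helper, the suffix list mirrors the formats listed under Input preprocessing, and it assumes `predict` returns a dict with a `p_tts` key as in the example output. Passing the predict function in as an argument keeps the filtering logic independent of how the model is loaded.

```python
from pathlib import Path

# Formats listed under "Input preprocessing" (assumed to be the supported set).
AUDIO_SUFFIXES = {".wav", ".flac", ".mp3", ".ogg", ".m4a"}


def select_tts_files(root, predict, threshold=0.5):
    """Return (path, p_tts) pairs for audio files whose p_tts reaches threshold."""
    kept = []
    for path in sorted(Path(root).rglob("*")):
        if not path.is_file() or path.suffix.lower() not in AUDIO_SUFFIXES:
            continue
        result = predict(str(path))  # e.g. classifier.predict
        if result["p_tts"] >= threshold:
            kept.append((path, result["p_tts"]))
    return kept


# Usage with the bundled classifier (requires model.onnx; 0.8 is only an example):
# from inference import TTSSuitabilityClassifier
# classifier = TTSSuitabilityClassifier(provider="auto")
# kept = select_tts_files("/path/to/audio-directory", classifier.predict, threshold=0.8)
```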

## Input preprocessing

The included inference code applies the same preprocessing as the training and
export recipe:

1. Decode WAV, FLAC, MP3, OGG, or M4A.
2. Mix channels to mono.
3. Resample to 16 kHz.
4. Apply waveform layer normalization.
5. Split long audio into 10-second chunks.
6. Average chunk logits and apply softmax.

The ONNX input is a float32 tensor named `waveforms` with shape
`[batch_size, num_frames]`. The output is `logits` with shape
`[batch_size, 2]`. Both input axes are dynamic; ONNX opset 17 is used.
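Calling `model.onnx` directly with ONNX Runtime, without `inference.py`, could look like the sketch below. It relies only on the input and output names and shapes documented above. `chunk_logits`, `classify`, and `load_session` are hypothetical helpers, and `chunks` are assumed to be already preprocessed 1-D float32 arrays at 16 kHz.

```python
import numpy as np


def chunk_logits(session, chunks):
    """Run each chunk through the model and return logits of shape [num_chunks, 2]."""
    rows = []
    for chunk in chunks:
        # Input "waveforms": float32 tensor of shape [batch_size, num_frames].
        batch = np.asarray(chunk, dtype=np.float32)[np.newaxis, :]
        (logits,) = session.run(["logits"], {"waveforms": batch})
        rows.append(logits[0])
    return np.stack(rows)


def classify(session, chunks):
    # Step 6: average chunk logits, then softmax over [not_tts, tts].
    mean_logits = chunk_logits(session, chunks).mean(axis=0)
    exps = np.exp(mean_logits - mean_logits.max())
    probs = exps / exps.sum()
    return {
        "p_not_tts": float(probs[0]),
        "p_tts": float(probs[1]),
        "logits": mean_logits.tolist(),
    }


def load_session(path="model.onnx"):
    # Imported lazily so the helpers above can be used without onnxruntime installed.
    import onnxruntime as ort

    return ort.InferenceSession(path, providers=["CPUExecutionProvider"])
```

Each chunk is run as its own batch of one. The final chunk is usually shorter than the others, and because the graph takes no padding mask, zero-padding it to batch it with full-length chunks would presumably change its logits.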

## Files

- `model.onnx`: self-contained FP32 ONNX model.
- `inference.py`: standalone ONNX Runtime inference.
- `requirements.txt`: CPU inference dependencies.

## Upload to Hugging Face

Create an empty model repository, then run from this directory:

```bash
hf upload-large-folder <username>/<repo-name> . --repo-type model
```

`model.onnx` is configured for Git LFS in `.gitattributes`.

## Training and export

The released model corresponds to training checkpoint step 94,000. It was
exported using the repository recipes:

- `workflows/recipes/wav2vec2/binary_classification/export_onnx.py`
- `workflows/recipes/wav2vec2/binary_classification/run_onnx.py`

Architecture: `wav2vec2_asr 300m`  
Sample rate: `16000 Hz`  
Training maximum audio length: `160000` samples (10 s at 16 kHz)  
Classes: `not_tts`, `tts`

## Limitations

- The score measures similarity to the training definition of TTS-suitable
  audio; it is not a general-purpose MOS score.
- Music, noise, clipping, overlapping speakers, and unusual recording
  conditions may affect predictions.
- Probabilities are not guaranteed to be calibrated.
- Validate the threshold on data from the intended domain before filtering a
  large dataset.
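One way to choose the threshold on a manually labeled validation set is to take the lowest `p_tts` cutoff that reaches a target precision for the `tts` class. This keeps as much usable audio as possible at that precision. The helper name `pick_threshold`, the default target of `0.95`, and the tiny example data are illustrative assumptions, not values from the training recipe.

```python
def pick_threshold(scores, labels, min_precision=0.95):
    """Return the lowest p_tts threshold whose tts precision reaches min_precision.

    scores: p_tts values from the classifier.
    labels: manual annotations, 1 for tts and 0 for not_tts.
    Returns (threshold, precision, recall), or None if no threshold qualifies.
    """
    total_positive = sum(labels)
    for threshold in sorted(set(scores)):
        # Segments that would be kept at this threshold.
        kept = [label for score, label in zip(scores, labels) if score >= threshold]
        true_positive = sum(kept)
        precision = true_positive / len(kept)
        recall = true_positive / total_positive if total_positive else 0.0
        # Recall can only fall as the threshold rises, so the first qualifying
        # threshold keeps the most tts-suitable audio at this precision.
        if precision >= min_precision:
            return threshold, precision, recall
    return None


# Example with a tiny hypothetical validation set:
scores = [0.95, 0.9, 0.8, 0.6, 0.4, 0.2]
labels = [1, 1, 0, 1, 0, 0]
print(pick_threshold(scores, labels, min_precision=0.7))  # (0.6, 0.75, 1.0)
```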

## License

Apache 2.0. The base architecture and code originate from the
omnilingual-asr project.

## Contact

- Email: kborodin.research@gmail.com
- Telegram: [@korallll_ai](https://t.me/korallll_ai)