Obiang commited on
Commit
62cb0ac
·
1 Parent(s): 367817f

first commit

Browse files
.gitignore ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python virtual environment
2
+ venv/
3
+ env/
4
+ ENV/
5
+ .venv
6
+
7
+ # Python cache
8
+ __pycache__/
9
+ *.py[cod]
10
+ *$py.class
11
+ *.so
12
+ .Python
13
+
14
+ # Distribution / packaging
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+
31
+ # PyInstaller
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Jupyter Notebook
36
+ .ipynb_checkpoints
37
+
38
+ # IPython
39
+ profile_default/
40
+ ipython_config.py
41
+
42
+ # pyenv
43
+ .python-version
44
+
45
+ # Environments
46
+ .env
47
+ .venv
48
+ env/
49
+ venv/
50
+ ENV/
51
+ env.bak/
52
+ venv.bak/
53
+
54
+ # IDEs
55
+ .vscode/
56
+ .idea/
57
+ *.swp
58
+ *.swo
59
+ *~
60
+
61
+ # OS files
62
+ .DS_Store
63
+ Thumbs.db
64
+
65
+ # Gradio cache
66
+ flagged/
67
+ gradio_cached_examples/
68
+
69
+ # SpeechBrain
70
+ pretrained_model/
71
+ whubert_checkpoint/
72
+ results/
73
+ save/
74
+
75
+ # Model checkpoints (if you don't want to track them)
76
+ # Uncomment the line below if checkpoints are too large for git
77
+ # CKPT*/
78
+
79
+ # Logs
80
+ *.log
81
+ logs/
82
+
83
+ # Temporary files
84
+ *.tmp
85
+ temp/
86
+ tmp/
87
+ PITCH/
CKPT+2025-10-20+08-19-07+00/CKPT.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # yamllint disable
2
+ CER: 17.354776764282285
3
+ end-of-epoch: true
4
+ unixtime: 1760948347.932281
CKPT+2025-10-20+08-19-07+00/brain.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a006b4494d46ac9b0ea15eab28666eafb68e0a68bcb8c15c07edd35285bd0e5
3
+ size 50
CKPT+2025-10-20+08-19-07+00/counter.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:108c995b953c8a35561103e2014cf828eb654a99e310f87fab94c2f4b7d2a04f
3
+ size 2
CKPT+2025-10-20+08-19-07+00/lr_annealing.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a55d7b1344e7db29061cdf1888822daaf1597e3df4e4060c6deca9bf690a829e
3
+ size 1931
CKPT+2025-10-20+08-19-07+00/lr_annealing_wav2vec.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:12bc3c2adbcc92643e4bad84a8619e7300948050cfc2e86423f9bfd9c2c31090
3
+ size 1979
CKPT+2025-10-20+08-19-07+00/model.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:892fa9f449acd39c6a85a1d456f05d050ef030efc3dbc5a64dbf1984a3e26800
3
+ size 38091995
CKPT+2025-10-20+08-19-07+00/optimizer.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eac563e0988030eb01a32dded244a7ce8dc93d1806bac41e6a9259173a0d51fc
3
+ size 76194782
CKPT+2025-10-20+08-19-07+00/optimizer_wav2vec.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:deea19e921cbc0876edc13bccea586285ee87c9a2fe91b7c2f81b27f456c8ca3
3
+ size 2025
CKPT+2025-10-20+08-19-07+00/tokenizer.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:304a02a791e8409f8f7065d83fca6c755fafacbc698e82cd7c9e3df1cb4f254d
3
+ size 144
CKPT+2025-10-20+08-19-07+00/wav2vec2.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3017b9c7e9e90167daa376ee7a010d10b4fea49bf4a2d3ba72eb13bb469093fc
3
+ size 377574002
README.md CHANGED
@@ -1,14 +1,147 @@
1
  ---
2
- title: Pro TeVA
3
- emoji: 😻
4
- colorFrom: pink
5
- colorTo: pink
6
  sdk: gradio
7
  sdk_version: 5.49.1
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
- short_description: 'ProTeVa: AI-powered tone recognition for Yoruba language.'
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: ProTeVa Yoruba Tone Recognition
3
+ emoji: 🎵
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: gradio
7
  sdk_version: 5.49.1
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
+ short_description: 'ProTeVa: AI-powered tone recognition for Yoruba language with word boundary detection.'
12
  ---
13
 
14
+ # ProTeVa: Yoruba Tone Recognition
15
+
16
+ This Space demonstrates **ProTeVa** (Prototype-based Tone Variant Autoencoder), a neural model for recognizing tone patterns in Yoruba language with intelligent word boundary detection.
17
+
18
+ ## Features
19
+
20
+ - 🎤 **Record or Upload**: Use your microphone or upload audio files
21
+ - 🎯 **Tone Detection**: Automatically detects 3 Yoruba tones (High, Low, Mid)
22
+ - 🔍 **Word Boundaries**: Intelligent space detection using acoustic features
23
+ - 📊 **F0 Visualization**: Shows fundamental frequency contours
24
+ - 🎨 **Interactive UI**: Real-time predictions with visual feedback
25
+
26
+ ## Yoruba Tones
27
+
28
+ Yoruba is a tonal language with three contrastive tones:
29
+
30
+ 1. **High Tone (H)** (◌́) - Example: ágbó (elder)
31
+ 2. **Low Tone (B)** (◌̀) - Example: àgbò (ram)
32
+ 3. **Mid Tone (M)** (◌) - Example: agbo (medicine)
33
+
34
+ ## Model Architecture
35
+
36
+ - **Feature Extractor**: HuBERT (Orange/SSA-HuBERT-base-60k)
37
+ - **Encoder**: 2-layer Bidirectional GRU (512 hidden units)
38
+ - **Decoder**: VanillaNN (2 blocks, 512 neurons)
39
+ - **Prototype Layer**: 10 learnable tone prototypes
40
+ - **F0 Reconstruction**: TorchYIN pitch estimation
41
+ - **Output**: CTC-based sequence prediction
42
+ - **Space Detection**: Multi-method acoustic boundary detection
43
+
44
+ ## Space Detection
45
+
46
+ ProTeVa uses intelligent post-processing to detect word boundaries:
47
+
48
+ ### Detection Methods
49
+
50
+ 1. **Silence Detection**: Identifies pauses in speech using F0 analysis
51
+ 2. **F0 Drop Detection**: Detects pitch resets typical of word boundaries
52
+ 3. **Combined Method** (default): Fuses multiple acoustic cues for robust detection
53
+
54
+ ### Configuration
55
+
56
+ The model's behavior can be customized via `config.py`:
57
+
58
+ ```python
59
+ ENABLE_SPACE_DETECTION = True
60
+ SPACE_DETECTION_METHOD = "combined" # 'silence', 'f0_drop', 'duration', 'combined'
61
+ SILENCE_THRESHOLD = 0.15 # seconds
62
+ F0_DROP_THRESHOLD = 0.20 # 20% pitch drop
63
+ ```
64
+
65
+ ## Training Details
66
+
67
+ - **Dataset**: Yoruba speech corpus
68
+ - **Sample Rate**: 16kHz
69
+ - **Loss Functions**:
70
+ - CTC loss for tone sequence
71
+ - MSE loss for F0 reconstruction
72
+ - Prototype regularization (R₁ + R₂)
73
+ - **Training Duration**: 65 epochs
74
+ - **Best CER**: 17.35%
75
+
76
+ ## Label Encoding
77
+
78
+ Based on the trained model's tokenizer:
79
+
80
+ - **0**: Blank (CTC blank token)
81
+ - **1**: High Tone (H)
82
+ - **2**: Low Tone (B)
83
+ - **3**: Mid Tone (M)
84
+ - **4**: Space (post-processing detection)
85
+
86
+ ## Usage
87
+
88
+ 1. Click on the microphone icon to record or upload an audio file
89
+ 2. Speak clearly in Yoruba
90
+ 3. Click "🔍 Predict Tones"
91
+ 4. View predicted tone sequence, word boundaries, and F0 contour
92
+
93
+ ### Tips for Best Results
94
+
95
+ - Speak clearly with natural prosody
96
+ - Keep recordings under 10 seconds
97
+ - Avoid background noise
98
+ - Pause slightly between words for better boundary detection
99
+
100
+ ## Technical Implementation
101
+
102
+ ### Files Structure
103
+
104
+ ```
105
+ .
106
+ ├── config.py # Central configuration
107
+ ├── app.py # Gradio UI
108
+ ├── custom_interface.py # SpeechBrain interface + space detection
109
+ ├── modules.py # Custom PyTorch modules
110
+ ├── inference.yaml # Model configuration
111
+ ├── requirements.txt # Dependencies
112
+ └── CKPT+*/ # Model checkpoints
113
+ ```
114
+
115
+ ### Key Components
116
+
117
+ - **F0Extractor**: TorchYIN-based pitch estimation
118
+ - **PrototypeLayer**: Learnable tone pattern prototypes
119
+ - **PitchDecoderLayer**: F0 reconstruction decoder
120
+ - **Space Detection**: Acoustic-based word boundary detection
121
+
122
+ ## Citation
123
+
124
+ If you use this model in your research, please cite:
125
+
126
+ ```bibtex
127
+ @article{proteva2025,
128
+ title={ProTeVa: Prototype-based Tone Variant Autoencoder for Yoruba Tone Recognition},
129
+ author={Your Name},
130
+ year={2025},
131
+ note={Hugging Face Space}
132
+ }
133
+ ```
134
+
135
+ ## Acknowledgments
136
+
137
+ - Built with ❤️ using [SpeechBrain](https://speechbrain.github.io/) and [Gradio](https://gradio.app/)
138
+ - HuBERT model: [Orange/SSA-HuBERT-base-60k](https://huggingface.co/Orange/SSA-HuBERT-base-60k)
139
+ - F0 extraction: [TorchYIN](https://github.com/brentspell/torch-yin)
140
+
141
+ ## License
142
+
143
+ Apache 2.0
144
+
145
+ ## Contact
146
+
147
+ For questions or issues, please open an issue on the repository.
_docs/IMPLEMENTATION_SUMMARY.md ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ProTeVa Implementation Summary
2
+
3
+ ## ✅ Implementation Complete
4
+
5
+ All files have been created and configured for ProTeVa deployment with intelligent space detection.
6
+
7
+ ---
8
+
9
+ ## 📁 Created Files
10
+
11
+ ### Core Application Files
12
+
13
+ 1. **[config.py](config.py)** - Central configuration
14
+ - Checkpoint folder path: `CKPT+2025-10-20+08-19-07+00`
15
+ - Space detection settings (enabled by default)
16
+ - Tone label mappings (H=1, B=2, M=3)
17
+ - Visualization configurations
18
+ - Helper functions for validation
19
+
20
+ 2. **[app.py](app.py)** - Gradio UI application
21
+ - Interactive web interface
22
+ - Audio recording and upload
23
+ - Tone visualization with space markers
24
+ - F0 contour plotting
25
+ - Real-time statistics
26
+ - Imports configuration from `config.py`
27
+
28
+ 3. **[custom_interface.py](custom_interface.py)** - SpeechBrain interface
29
+ - Model loading and inference
30
+ - **Space detection implementation**:
31
+ - Silence-based detection
32
+ - F0 drop detection
33
+ - Duration-based detection
34
+ - Combined method (recommended)
35
+ - Post-processing for word boundaries
36
+
37
+ 4. **[modules.py](modules.py)** - Custom PyTorch modules
38
+ - `F0Extractor`: TorchYIN pitch estimation
39
+ - `PitchDecoderLayer`: F0 reconstruction
40
+ - `PrototypeLayer`: Learnable tone prototypes
41
+
42
+ 5. **[inference.yaml](inference.yaml)** - Model configuration
43
+ - Model architecture settings
44
+ - Checkpoint paths
45
+ - References `config.py` for folder name
46
+
47
+ 6. **[requirements.txt](requirements.txt)** - Python dependencies
48
+ - SpeechBrain, Torch, Gradio
49
+ - TorchYIN for F0 extraction
50
+ - Visualization libraries
51
+
52
+ 7. **[README.md](README.md)** - Hugging Face Space documentation
53
+ - Model description
54
+ - Space detection explanation
55
+ - Usage instructions
56
+ - Technical details
57
+
58
+ ---
59
+
60
+ ## 🎯 Key Features Implemented
61
+
62
+ ### 1. **Centralized Configuration**
63
+ All settings are managed through `config.py`:
64
+ - **Easy checkpoint updates**: Just change `CHECKPOINT_FOLDER`
65
+ - **Configurable space detection**: Enable/disable, choose method, tune thresholds
66
+ - **Single source of truth**: No scattered hardcoded values
67
+
68
+ ### 2. **Intelligent Space Detection**
69
+ Four detection methods implemented:
70
+
71
+ #### Method 1: Silence Detection
72
+ ```python
73
+ SPACE_DETECTION_METHOD = "silence"
74
+ ```
75
+ - Analyzes F0 for silent gaps
76
+ - Threshold: 0.15 seconds (configurable)
77
+
78
+ #### Method 2: F0 Drop Detection
79
+ ```python
80
+ SPACE_DETECTION_METHOD = "f0_drop"
81
+ ```
82
+ - Detects pitch resets at word boundaries
83
+ - Threshold: 20% drop (configurable)
84
+
85
+ #### Method 3: Duration-Based
86
+ ```python
87
+ SPACE_DETECTION_METHOD = "duration"
88
+ ```
89
+ - Simple heuristic (every N tones)
90
+ - Less accurate but fast
91
+
92
+ #### Method 4: Combined (Recommended)
93
+ ```python
94
+ SPACE_DETECTION_METHOD = "combined"
95
+ ```
96
+ - Fuses silence + F0 drop detection
97
+ - Best balance of precision and recall
98
+ - **Default setting**
99
+
100
+ ### 3. **Correct Tone Mappings**
101
+ Based on your `labelencoder.txt`:
102
+ - **Label 0**: Blank (CTC)
103
+ - **Label 1**: High Tone (H)
104
+ - **Label 2**: Low Tone (B)
105
+ - **Label 3**: Mid Tone (M)
106
+ - **Label 4**: Space (post-processing)
107
+
108
+ ### 4. **Enhanced Visualization**
109
+ - Tone sequence with color coding
110
+ - Space markers as vertical separators
111
+ - F0 contour plots
112
+ - Real-time statistics with word count
113
+
114
+ ---
115
+
116
+ ## 🚀 Quick Start
117
+
118
+ ### Update Configuration
119
+ Edit `config.py`:
120
+ ```python
121
+ # 1. Set your checkpoint folder
122
+ CHECKPOINT_FOLDER = "CKPT+2025-10-20+08-19-07+00"
123
+
124
+ # 2. Configure space detection
125
+ ENABLE_SPACE_DETECTION = True
126
+ SPACE_DETECTION_METHOD = "combined"
127
+
128
+ # 3. Fine-tune thresholds (optional)
129
+ SILENCE_THRESHOLD = 0.15 # seconds
130
+ F0_DROP_THRESHOLD = 0.20 # 20% pitch drop
131
+ ```
132
+
133
+ ### Local Testing
134
+ ```bash
135
+ # Install dependencies
136
+ pip install -r requirements.txt
137
+
138
+ # Run the app
139
+ python app.py
140
+
141
+ # Open browser
142
+ # http://localhost:7860
143
+ ```
144
+
145
+ ### Deploy to Hugging Face
146
+ ```bash
147
+ # Clone your Space
148
+ git clone https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
149
+ cd YOUR_SPACE_NAME
150
+
151
+ # Copy all files
152
+ cp /path/to/Pro-TeVA/*.py .
153
+ cp /path/to/Pro-TeVA/*.yaml .
154
+ cp /path/to/Pro-TeVA/*.txt .
155
+ cp /path/to/Pro-TeVA/README.md .
156
+ cp -r /path/to/Pro-TeVA/CKPT+2025-10-20+08-19-07+00 .
157
+
158
+ # Setup Git LFS for large files
159
+ git lfs install
160
+ git lfs track "*.ckpt"
161
+ git add .gitattributes
162
+
163
+ # Commit and push
164
+ git add .
165
+ git commit -m "Initial deployment with space detection"
166
+ git push
167
+ ```
168
+
169
+ ---
170
+
171
+ ## ⚙️ Configuration Options
172
+
173
+ ### Checkpoint Folder
174
+ ```python
175
+ # config.py
176
+ CHECKPOINT_FOLDER = "YOUR_CHECKPOINT_FOLDER_NAME"
177
+ ```
178
+
179
+ Also update in `inference.yaml`:
180
+ ```yaml
181
+ save_folder: ./YOUR_CHECKPOINT_FOLDER_NAME
182
+ ```
183
+
184
+ ### Space Detection Toggle
185
+ ```python
186
+ # Disable space detection completely
187
+ ENABLE_SPACE_DETECTION = False
188
+ ```
189
+
190
+ ### Detection Method
191
+ ```python
192
+ SPACE_DETECTION_METHOD = "combined" # Best (default)
193
+ # OR
194
+ SPACE_DETECTION_METHOD = "silence" # Pause-based only
195
+ # OR
196
+ SPACE_DETECTION_METHOD = "f0_drop" # Pitch-based only
197
+ # OR
198
+ SPACE_DETECTION_METHOD = "duration" # Simple heuristic
199
+ ```
200
+
201
+ ### Threshold Tuning
202
+ ```python
203
+ # If detecting too many spaces
204
+ SILENCE_THRESHOLD = 0.20 # Increase (more lenient)
205
+ F0_DROP_THRESHOLD = 0.30 # Increase (30% drop required)
206
+
207
+ # If detecting too few spaces
208
+ SILENCE_THRESHOLD = 0.10 # Decrease (more sensitive)
209
+ F0_DROP_THRESHOLD = 0.15 # Decrease (15% drop sufficient)
210
+ ```
211
+
212
+ ---
213
+
214
+ ## 📊 Model Information
215
+
216
+ - **Checkpoint**: `CKPT+2025-10-20+08-19-07+00/`
217
+ - **Best CER**: 17.35%
218
+ - **Training**: 65 epochs
219
+ - **Architecture**:
220
+ - HuBERT feature extractor (768-dim)
221
+ - 2-layer BiGRU encoder (512 units)
222
+ - 10 tone prototypes
223
+ - F0 reconstruction decoder
224
+ - CTC output layer (4 classes)
225
+
226
+ ---
227
+
228
+ ## 🔧 Troubleshooting
229
+
230
+ ### Issue: Space detection not working
231
+ **Solution**: Ensure F0 extraction is working properly. Check that `torchyin` is installed.
232
+
233
+ ### Issue: Too many/few spaces detected
234
+ **Solution**: Tune thresholds in `config.py` or try a different detection method.
235
+
236
+ ### Issue: Checkpoint not found
237
+ **Solution**: Update `CHECKPOINT_FOLDER` in `config.py` and `save_folder` in `inference.yaml`.
238
+
239
+ ### Issue: Model not loading
240
+ **Solution**: Run `config.validate_config()` to check for missing files.
241
+
242
+ ---
243
+
244
+ ## 📝 Next Steps
245
+
246
+ 1. **Test locally** to ensure everything works
247
+ 2. **Tune space detection** parameters based on your audio data
248
+ 3. **Deploy to Hugging Face** Spaces
249
+ 4. **Monitor performance** and adjust settings as needed
250
+ 5. **Update citation** in README.md with your information
251
+
252
+ ---
253
+
254
+ ## 🎉 Summary
255
+
256
+ You now have a complete ProTeVa deployment with:
257
+
258
+ ✅ Centralized configuration system
259
+ ✅ Intelligent word boundary detection
260
+ ✅ Four detection methods (combined recommended)
261
+ ✅ Correct tone label mappings
262
+ ✅ Enhanced visualizations
263
+ ✅ Easy-to-update checkpoint paths
264
+ ✅ Complete documentation
265
+ ✅ Ready for Hugging Face deployment
266
+
267
+ **Configuration file**: [config.py](config.py)
268
+ **Update checkpoint**: Change `CHECKPOINT_FOLDER` in config.py
269
+ **Toggle space detection**: Set `ENABLE_SPACE_DETECTION` True/False
270
+ **Choose method**: Set `SPACE_DETECTION_METHOD` to preferred option
271
+
272
+ ---
273
+
274
+ **Generated**: 2025-10-20
275
+ **Status**: Ready for deployment 🚀
_docs/VENV_SETUP.md ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Virtual Environment Setup
2
+
3
+ ## ✅ Virtual Environment Created
4
+
5
+ A virtual environment has been set up with all required dependencies installed.
6
+
7
+ ---
8
+
9
+ ## 📦 Installed Packages
10
+
11
+ ### Core Dependencies
12
+ - **speechbrain**: 1.0.0 (includes torch, torchaudio, numpy, scipy, transformers, huggingface_hub)
13
+ - **torch-yin**: 0.1.3 (F0 extraction)
14
+ - **gradio**: 5.49.1 (UI framework)
15
+ - **librosa**: 0.11.0 (audio processing)
16
+ - **soundfile**: 0.13.1 (audio I/O)
17
+ - **matplotlib**: 3.10.7 (visualization)
18
+
19
+ ### Automatically Included by SpeechBrain
20
+ - torch: 2.9.0
21
+ - torchaudio: 2.9.0
22
+ - numpy: 2.3.4
23
+ - scipy: 1.16.2
24
+ - sentencepiece: 0.2.1
25
+ - hyperpyyaml: 1.2.2
26
+ - transformers (via huggingface-hub)
27
+ - And all CUDA dependencies
28
+
29
+ ---
30
+
31
+ ## 🚀 Usage
32
+
33
+ ### Activate the Environment
34
+
35
+ ```bash
36
+ # Linux/Mac
37
+ source venv/bin/activate
38
+
39
+ # Windows
40
+ venv\Scripts\activate
41
+ ```
42
+
43
+ ### Deactivate the Environment
44
+
45
+ ```bash
46
+ deactivate
47
+ ```
48
+
49
+ ### Run the Application
50
+
51
+ ```bash
52
+ # Activate environment
53
+ source venv/bin/activate
54
+
55
+ # Run Gradio app
56
+ python app.py
57
+
58
+ # Open browser to http://localhost:7860
59
+ ```
60
+
61
+ ---
62
+
63
+ ## 📋 Installation from Scratch
64
+
65
+ If you need to recreate the environment on another machine:
66
+
67
+ ```bash
68
+ # Create virtual environment
69
+ python3 -m venv venv
70
+
71
+ # Activate
72
+ source venv/bin/activate
73
+
74
+ # Upgrade pip
75
+ pip install --upgrade pip
76
+
77
+ # Install all dependencies
78
+ pip install -r requirements.txt
79
+
80
+ # Verify installation
81
+ python -c "import config; config.validate_config()"
82
+ ```
83
+
84
+ ---
85
+
86
+ ## 🔍 Verification
87
+
88
+ ### Test Configuration
89
+
90
+ ```bash
91
+ source venv/bin/activate
92
+ python -c "import config; print('✓ Config loaded'); config.validate_config()"
93
+ ```
94
+
95
+ Expected output:
96
+ ```
97
+ ✓ Config loaded
98
+ ✅ Configuration validated successfully!
99
+ ```
100
+
101
+ ### Test Imports
102
+
103
+ ```bash
104
+ source venv/bin/activate
105
+ python -c "
106
+ import torch
107
+ import torchaudio
108
+ import speechbrain
109
+ import gradio
110
+ import librosa
111
+ import matplotlib
112
+ print('✅ All imports successful!')
113
+ print(f'PyTorch version: {torch.__version__}')
114
+ print(f'SpeechBrain version: {speechbrain.__version__}')
115
+ print(f'Gradio version: {gradio.__version__}')
116
+ "
117
+ ```
118
+
119
+ ---
120
+
121
+ ## 📝 Requirements.txt Optimization
122
+
123
+ The `requirements.txt` has been optimized to avoid redundancy:
124
+
125
+ ```txt
126
+ # Core dependencies
127
+ # SpeechBrain includes: torch, torchaudio, numpy, scipy, sentencepiece, hyperpyyaml, transformers, huggingface_hub
128
+ speechbrain==1.0.0
129
+
130
+ # F0 extraction with TorchYIN (note: package name is torch-yin, not torchyin)
131
+ torch-yin==0.1.3
132
+
133
+ # Gradio for UI
134
+ gradio>=4.0.0
135
+
136
+ # Audio processing (not included in speechbrain)
137
+ librosa
138
+ soundfile
139
+
140
+ # Visualization (not included in speechbrain)
141
+ matplotlib
142
+ ```
143
+
144
+ **Note**: Package name is `torch-yin` (with hyphen), not `torchyin`.
145
+
146
+ ---
147
+
148
+ ## 🔧 Common Issues
149
+
150
+ ### Issue: torch-yin not found
151
+
152
+ **Error**: `ERROR: Could not find a version that satisfies the requirement torchyin`
153
+
154
+ **Solution**: Use `torch-yin` (with hyphen) instead of `torchyin`:
155
+ ```bash
156
+ pip install torch-yin==0.1.3
157
+ ```
158
+
159
+ ### Issue: CUDA not available
160
+
161
+ If you get CUDA errors but don't have a GPU, update `config.py`:
162
+ ```python
163
+ DEVICE = "cpu"
164
+ ```
165
+
166
+ ### Issue: Checkpoint folder not found
167
+
168
+ Update the checkpoint folder path in `config.py`:
169
+ ```python
170
+ CHECKPOINT_FOLDER = "YOUR_CHECKPOINT_FOLDER_NAME"
171
+ ```
172
+
173
+ ---
174
+
175
+ ## 📊 Environment Size
176
+
177
+ - **Total packages**: ~150+ (including dependencies)
178
+ - **Disk space**: ~5-6 GB (mostly PyTorch + CUDA)
179
+ - **Main components**:
180
+ - PyTorch + CUDA: ~3-4 GB
181
+ - SpeechBrain + dependencies: ~1 GB
182
+ - Gradio + dependencies: ~500 MB
183
+ - Other packages: ~500 MB
184
+
185
+ ---
186
+
187
+ ## 🎯 Quick Commands
188
+
189
+ ```bash
190
+ # Activate and run
191
+ source venv/bin/activate && python app.py
192
+
193
+ # Test configuration
194
+ source venv/bin/activate && python -c "import config; config.validate_config()"
195
+
196
+ # Check installed packages
197
+ source venv/bin/activate && pip list
198
+
199
+ # Freeze current environment
200
+ source venv/bin/activate && pip freeze > requirements-full.txt
201
+
202
+ # Update a specific package
203
+ source venv/bin/activate && pip install --upgrade gradio
204
+ ```
205
+
206
+ ---
207
+
208
+ ## ✅ Ready to Deploy
209
+
210
+ Your environment is ready! You can now:
211
+
212
+ 1. **Test locally**: `python app.py`
213
+ 2. **Adjust config**: Edit `config.py` as needed
214
+ 3. **Deploy**: Push to Hugging Face Spaces
215
+
216
+ ---
217
+
218
+ **Created**: 2025-10-20
219
+ **Python Version**: 3.11
220
+ **Status**: ✅ Fully configured and tested
_docs/proteva_complete_deployment.md ADDED
@@ -0,0 +1,1441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ProTeVa Complete Deployment Guide
2
+ ## Yoruba Tone Recognition - Hugging Face Spaces Deployment
3
+
4
+ ---
5
+
6
+ ## 📋 Table of Contents
7
+
8
+ 1. [Deployment Overview](#deployment-overview)
9
+ 2. [Hugging Face Spaces Structure](#hugging-face-spaces-structure)
10
+ 3. [Deployment Flow](#deployment-flow)
11
+ 4. [File Contents](#file-contents)
12
+ - [config.py](#1-configpy)
13
+ - [app.py](#2-apppy)
14
+ - [custom_interface.py](#3-custom_interfacepy)
15
+ - [inference.yaml](#4-inferenceyaml)
16
+ - [modules.py](#5-modulespy)
17
+ - [requirements.txt](#6-requirementstxt)
18
+ - [README.md](#7-readmemd-for-hugging-face-space)
19
+ 5. [Testing & Troubleshooting](#testing--troubleshooting)
20
+
21
+ ---
22
+
23
+ ## Deployment Overview
24
+
25
+ **ProTeVa** is a tone recognition model for Yoruba language that:
26
+ - Accepts audio input (microphone or file upload)
27
+ - Predicts tone sequences (3 tones)
28
+ - Reconstructs F0 (fundamental frequency) contours
29
+ - Uses prototype-based learning for better generalization
30
+ - **Intelligently detects word boundaries** using acoustic features
31
+
32
+ **Yoruba Tones (based on labelencoder.txt):**
33
+ - **Label 0**: Blank (CTC blank token)
34
+ - **Label 1 (H)**: High Tone (◌́)
35
+ - **Label 2 (B)**: Low Tone (◌̀) - "Bas" in French
36
+ - **Label 3 (M)**: Mid Tone (◌)
37
+ - **Label 4**: Space/Word Boundary (post-processing detection)
38
+
39
+ ---
40
+
41
+ ## Hugging Face Spaces Structure
42
+
43
+ Your Hugging Face Space should have this exact structure:
44
+
45
+ ```
46
+ your-huggingface-space/
47
+
48
+ ├── app.py # Main Gradio application
49
+ ├── custom_interface.py # SpeechBrain inference interface
50
+ ├── config.py # Configuration file (paths, settings)
51
+ ├── inference.yaml # Model configuration
52
+ ├── modules.py # Custom PyTorch modules
53
+ ├── requirements.txt # Python dependencies
54
+ ├── README.md # Space documentation
55
+
56
+ └── CKPT+2025-10-20+08-19-07+00/ # Your checkpoint folder
57
+ ├── model.ckpt # All model weights (~500MB-2GB)
58
+ ├── wav2vec2.ckpt # HuBERT encoder (~300MB-1GB)
59
+ ├── tokenizer.ckpt # Label encoder (~1MB)
60
+ └── ... (other training files - optional)
61
+ ```
62
+
63
+ **Important Notes:**
64
+ - All `.py` files must be in the **root directory**
65
+ - Checkpoint folder can have any name (update `inference.yaml` accordingly)
66
+ - Use Git LFS for files larger than 10MB
67
+
68
+ ---
69
+
70
+ ## Deployment Flow
71
+
72
+ ### Step 1: Prepare Local Environment
73
+
74
+ ```bash
75
+ # Create deployment folder
76
+ mkdir proteva-deployment
77
+ cd proteva-deployment
78
+
79
+ # Create all required files (contents provided below)
80
+ # - app.py
81
+ # - custom_interface.py
82
+ # - config.py
83
+ # - inference.yaml
84
+ # - modules.py
85
+ # - requirements.txt
86
+ # - README.md
87
+ ```
88
+
89
+ ### Step 2: Copy Model Checkpoints
90
+
91
+ ```bash
92
+ # Copy your entire checkpoint folder
93
+ cp -r /path/to/your/CKPT+2025-10-20+04-14-23+00 ./
94
+
95
+ # OR copy only required files (to save space)
96
+ mkdir model_checkpoints
97
+ cp /path/to/CKPT+*/model.ckpt model_checkpoints/
98
+ cp /path/to/CKPT+*/wav2vec2.ckpt model_checkpoints/
99
+ cp /path/to/CKPT+*/tokenizer.ckpt model_checkpoints/
100
+ ```
101
+
102
+ ### Step 3: Update Configuration
103
+
104
+ Edit `config.py`:
105
+ ```python
106
+ # Update this line to match your checkpoint folder name
107
+ CHECKPOINT_FOLDER = "CKPT+2025-10-20+08-19-07+00"
108
+
109
+ # Configure space detection (optional)
110
+ ENABLE_SPACE_DETECTION = True # Set to False to disable
111
+ SPACE_DETECTION_METHOD = "combined" # Options: 'silence', 'f0_drop', 'duration', 'combined'
112
+ ```
113
+
114
+ **Note:** The checkpoint folder name in `inference.yaml` should match `config.py`.
115
+
116
+ ### Step 5: Test Locally
117
+
118
+ ```bash
119
+ # Install dependencies
120
+ pip install -r requirements.txt
121
+
122
+ # Run the app
123
+ python app.py
124
+
125
+ # Test in browser: http://localhost:7860
126
+ ```
127
+
128
+ **Testing checklist:**
129
+ - [ ] Model loads without errors
130
+ - [ ] Can record audio from microphone
131
+ - [ ] Can upload audio files
132
+ - [ ] Tone predictions appear
133
+ - [ ] F0 plot displays correctly
134
+ - [ ] No errors in console
135
+
136
+ ### Step 6: Create Hugging Face Space
137
+
138
+ 1. Go to https://huggingface.co/new-space
139
+ 2. Fill in details:
140
+ - **Space name**: `yoruba-tone-recognition` (or your choice)
141
+ - **License**: Apache 2.0
142
+ - **SDK**: **Gradio**
143
+ - **Hardware**: CPU basic (free) - can upgrade later
144
+ - **Visibility**: Public or Private
145
+ 3. Click "Create Space"
146
+
147
+ ### Step 7: Deploy Using Git
148
+
149
+ ```bash
150
+ # Clone your new Space
151
+ git clone https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
152
+ cd YOUR_SPACE_NAME
153
+
154
+ # Copy all files
155
+ cp -r /path/to/proteva-deployment/* ./
156
+
157
+ # Setup Git LFS for large files
158
+ git lfs install
159
+ git lfs track "*.ckpt"
160
+ git add .gitattributes
161
+
162
+ # Add all files
163
+ git add .
164
+
165
+ # Commit and push
166
+ git commit -m "Initial deployment of ProTeVa tone recognition"
167
+ git push
168
+ ```
169
+
170
+ ### Step 8: Monitor Build
171
+
172
+ 1. Go to your Space URL: `https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME`
173
+ 2. Check "Logs" tab for build progress
174
+ 3. Wait 2-5 minutes for build to complete
175
+ 4. Test the live app!
176
+
177
+ ---
178
+
179
+ ## File Contents
180
+
181
+ ### 1. `config.py`
182
+
183
+ **Purpose:** Central configuration file for paths, model settings, and space detection parameters.
184
+
185
+ **Key Features:**
186
+ - Centralized checkpoint folder path management
187
+ - Space detection configuration
188
+ - Tone label mappings
189
+ - Visualization settings
190
+ - Easy configuration updates
191
+
192
+ **Content:**
193
+
194
+ ```python
195
+ """
196
+ ProTeVa Configuration File
197
+ Central configuration for model paths and tone settings
198
+ """
199
+
200
+ import os
201
+
202
+ # ============ PATH CONFIGURATION ============
203
+
204
+ # Checkpoint folder name - UPDATE THIS when using a different checkpoint
205
+ CHECKPOINT_FOLDER = "CKPT+2025-10-20+08-19-07+00"
206
+
207
+ # Get the absolute path to the checkpoint folder
208
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
209
+ CHECKPOINT_PATH = os.path.join(BASE_DIR, CHECKPOINT_FOLDER)
210
+
211
+ # Model files
212
+ MODEL_CKPT = os.path.join(CHECKPOINT_PATH, "model.ckpt")
213
+ WAV2VEC2_CKPT = os.path.join(CHECKPOINT_PATH, "wav2vec2.ckpt")
214
+ TOKENIZER_CKPT = os.path.join(CHECKPOINT_PATH, "tokenizer.ckpt")
215
+
216
+ # ============ MODEL CONFIGURATION ============
217
+
218
+ # Audio settings
219
+ SAMPLE_RATE = 16000
220
+
221
+ # Model architecture
222
+ RNN_LAYERS = 2
223
+ RNN_NEURONS = 512
224
+ DNN_BLOCKS = 2
225
+ DNN_NEURONS = 512
226
+ N_PROTOTYPES = 10
227
+ EMB_DIM = 768
228
+
229
+ # ============ TONE CONFIGURATION ============
230
+
231
+ # Tone label mapping (from labelencoder.txt)
232
+ TONE_LABELS = {
233
+ 0: "BLANK", # CTC blank token
234
+ 1: "H", # High tone
235
+ 2: "B", # Low tone (Bas)
236
+ 3: "M" # Mid tone
237
+ }
238
+
239
+ # Output neurons (number of classes)
240
+ OUTPUT_NEURONS = 4 # blank, H, B, M
241
+
242
+ # CTC blank index
243
+ BLANK_INDEX = 0
244
+
245
+ # ============ SPACE/WORD BOUNDARY DETECTION ============
246
+
247
+ # Enable space detection between tones
248
+ ENABLE_SPACE_DETECTION = True
249
+
250
+ # Space detection method: 'silence', 'f0_drop', 'duration', or 'combined'
251
+ SPACE_DETECTION_METHOD = "combined"
252
+
253
+ # Silence threshold (in seconds)
254
+ SILENCE_THRESHOLD = 0.15
255
+
256
+ # F0 drop threshold (percentage)
257
+ F0_DROP_THRESHOLD = 0.20 # 20% drop
258
+
259
+ # Duration threshold (in seconds)
260
+ DURATION_THRESHOLD = 0.25
261
+
262
+ # Minimum confidence for space insertion
263
+ SPACE_CONFIDENCE_THRESHOLD = 0.6
264
+
265
+ # ============ VISUALIZATION CONFIGURATION ============
266
+
267
+ # Tone display information for UI
268
+ TONE_INFO = {
269
+ 1: {"name": "High Tone", "symbol": "◌́", "color": "#e74c3c", "label": "H"},
270
+ 2: {"name": "Low Tone", "symbol": "◌̀", "color": "#3498db", "label": "B"},
271
+ 3: {"name": "Mid Tone", "symbol": "◌", "color": "#2ecc71", "label": "M"},
272
+ 4: {"name": "Space", "symbol": " | ", "color": "#95a5a6", "label": "SPACE"}
273
+ }
274
+
275
+ # ============ DEPLOYMENT CONFIGURATION ============
276
+
277
+ # Device (cpu or cuda)
278
+ DEVICE = "cpu"
279
+
280
+ # Gradio server settings
281
+ GRADIO_SERVER_NAME = "0.0.0.0"
282
+ GRADIO_SERVER_PORT = 7860
283
+ GRADIO_SHARE = False
284
+
285
+ # Model save directory for SpeechBrain
286
+ PRETRAINED_MODEL_DIR = "./pretrained_model"
287
+
288
+ # ============ HELPER FUNCTIONS ============
289
+
290
+ def get_checkpoint_path():
291
+ """Get the checkpoint folder path"""
292
+ return CHECKPOINT_PATH
293
+
294
+ def get_tone_name(idx):
295
+ """Get the tone name from index"""
296
+ return TONE_LABELS.get(idx, f"Unknown({idx})")
297
+
298
+ def get_tone_info(idx):
299
+ """Get the tone display information"""
300
+ return TONE_INFO.get(idx, {
301
+ "name": f"Unknown({idx})",
302
+ "symbol": "?",
303
+ "color": "#95a5a6",
304
+ "label": f"UNK{idx}"
305
+ })
306
+
307
+ def validate_config():
308
+ """Validate that the configuration is correct"""
309
+ errors = []
310
+
311
+ if not os.path.exists(CHECKPOINT_PATH):
312
+ errors.append(f"Checkpoint folder not found: {CHECKPOINT_PATH}")
313
+
314
+ if not os.path.exists(MODEL_CKPT):
315
+ errors.append(f"Model checkpoint not found: {MODEL_CKPT}")
316
+ if not os.path.exists(WAV2VEC2_CKPT):
317
+ errors.append(f"Wav2Vec2 checkpoint not found: {WAV2VEC2_CKPT}")
318
+ if not os.path.exists(TOKENIZER_CKPT):
319
+ errors.append(f"Tokenizer checkpoint not found: {TOKENIZER_CKPT}")
320
+
321
+ if errors:
322
+ print("⚠️ Configuration Errors:")
323
+ for error in errors:
324
+ print(f" - {error}")
325
+ return False
326
+
327
+ print("✅ Configuration validated successfully!")
328
+ return True
329
+ ```
330
+
331
+ **⚠️ IMPORTANT:**
332
+ - Update `CHECKPOINT_FOLDER` to match your actual checkpoint folder name
333
+ - Configure `ENABLE_SPACE_DETECTION` and `SPACE_DETECTION_METHOD` based on your needs
334
+ - All other files will automatically use these settings
335
+
336
+ ---
337
+
338
+ ### 2. `app.py`
339
+
340
+ **Purpose:** Main Gradio application with UI and prediction logic.
341
+
342
+ **Content:**
343
+
344
+ ```python
345
+ """
346
+ Gradio App for ProTeVa Yoruba Tone Recognition
347
+ Hugging Face Spaces deployment
348
+ """
349
+
350
+ import gradio as gr
351
+ from speechbrain.inference.interfaces import foreign_class
352
+ import numpy as np
353
+ import matplotlib.pyplot as plt
354
+ import torch
355
+
356
+ # ============ CONFIGURATION ============
357
+
358
+ # Tone names for Yoruba (3 tones)
359
+ # Based on labelencoder.txt: H=1, B=2, M=3
360
+ TONE_INFO = {
361
+ 1: {"name": "High Tone", "symbol": "◌́", "color": "#e74c3c"},
362
+ 2: {"name": "Low Tone", "symbol": "◌̀", "color": "#3498db"},
363
+ 3: {"name": "Mid Tone", "symbol": "◌", "color": "#2ecc71"}
364
+ }
365
+
366
+ # ============ MODEL LOADING ============
367
+
368
+ print("Loading ProTeVa tone recognition model...")
369
+
370
+ try:
371
+ tone_recognizer = foreign_class(
372
+ source="./",
373
+ pymodule_file="custom_interface.py",
374
+ classname="ProTeVaToneRecognizer",
375
+ hparams_file="inference.yaml",
376
+ savedir="./pretrained_model"
377
+ )
378
+ print("✓ Model loaded successfully!")
379
+ except Exception as e:
380
+ print(f"✗ Error loading model: {e}")
381
+ tone_recognizer = None
382
+
383
+ # ============ HELPER FUNCTIONS ============
384
+
385
+ def format_tone_sequence(tone_indices, tone_names):
386
+ """Format tone sequence with colors and symbols"""
387
+ if not tone_indices:
388
+ return "No tones detected"
389
+
390
+ formatted = []
391
+ for idx, name in zip(tone_indices, tone_names):
392
+ if idx in TONE_INFO:
393
+ info = TONE_INFO[idx]
394
+ formatted.append(f"{info['name']} ({info['symbol']})")
395
+ else:
396
+ formatted.append(name)
397
+
398
+ return " → ".join(formatted)
399
+
400
+ def create_f0_plot(f0_contour):
401
+ """Create F0 contour plot"""
402
+ if f0_contour is None or len(f0_contour) == 0:
403
+ return None
404
+
405
+ # Convert to numpy
406
+ if isinstance(f0_contour, torch.Tensor):
407
+ f0_numpy = f0_contour.cpu().numpy().flatten()
408
+ else:
409
+ f0_numpy = np.array(f0_contour).flatten()
410
+
411
+ # Create plot
412
+ fig, ax = plt.subplots(figsize=(10, 4))
413
+ time = np.arange(len(f0_numpy)) / len(f0_numpy)
414
+ ax.plot(time, f0_numpy, linewidth=2, color='#3498db')
415
+ ax.set_xlabel('Normalized Time', fontsize=12)
416
+ ax.set_ylabel('F0 (Hz)', fontsize=12)
417
+ ax.set_title('Fundamental Frequency Contour', fontsize=14, fontweight='bold')
418
+ ax.grid(True, alpha=0.3)
419
+ plt.tight_layout()
420
+
421
+ return fig
422
+
423
+ def create_tone_visualization(tone_indices):
424
+ """Create visual representation of tone sequence"""
425
+ if not tone_indices:
426
+ return None
427
+
428
+ fig, ax = plt.subplots(figsize=(12, 3))
429
+
430
+ x_positions = np.arange(len(tone_indices))
431
+ colors = [TONE_INFO.get(idx, {}).get('color', '#95a5a6') for idx in tone_indices]
432
+
433
+ ax.bar(x_positions, [1] * len(tone_indices), color=colors, alpha=0.7,
434
+ edgecolor='black', linewidth=2)
435
+
436
+ for i, idx in enumerate(tone_indices):
437
+ if idx in TONE_INFO:
438
+ symbol = TONE_INFO[idx]['symbol']
439
+ ax.text(i, 0.5, symbol, ha='center', va='center',
440
+ fontsize=20, fontweight='bold')
441
+
442
+ ax.set_xlim(-0.5, len(tone_indices) - 0.5)
443
+ ax.set_ylim(0, 1.2)
444
+ ax.set_xticks(x_positions)
445
+ ax.set_xticklabels([f"T{i+1}" for i in range(len(tone_indices))])
446
+ ax.set_ylabel('Tone', fontsize=12)
447
+ ax.set_title('Tone Sequence Visualization', fontsize=14, fontweight='bold')
448
+ ax.set_yticks([])
449
+ plt.tight_layout()
450
+
451
+ return fig
452
+
453
+ # ============ PREDICTION FUNCTION ============
454
+
455
+ def predict_tone(audio_file):
456
+ """Main prediction function for Gradio interface"""
457
+ if tone_recognizer is None:
458
+ return "❌ Model not loaded. Please check configuration.", None, None, ""
459
+
460
+ if audio_file is None:
461
+ return "⚠️ Please provide an audio file", None, None, ""
462
+
463
+ try:
464
+ # Get predictions
465
+ tone_indices, tone_names, f0_contour = tone_recognizer.classify_file(audio_file)
466
+
467
+ # Format output
468
+ tone_text = format_tone_sequence(tone_indices, tone_names)
469
+
470
+ # Create visualizations
471
+ f0_plot = create_f0_plot(f0_contour)
472
+ tone_viz = create_tone_visualization(tone_indices)
473
+
474
+ # Create statistics
475
+ num_tones = len(tone_indices)
476
+
477
+ stats = f"""
478
+ 📊 **Prediction Statistics:**
479
+ - Total tones detected: {num_tones}
480
+ - Sequence length: {len(tone_indices)}
481
+
482
+ 🎵 **Tone Distribution:**
483
+ - High tones (H): {tone_indices.count(1)}
484
+ - Low tones (B): {tone_indices.count(2)}
485
+ - Mid tones (M): {tone_indices.count(3)}
486
+ """
487
+
488
+ return tone_text, f0_plot, tone_viz, stats
489
+
490
+ except Exception as e:
491
+ return f"❌ Error during prediction: {str(e)}", None, None, ""
492
+
493
+ # ============ GRADIO INTERFACE ============
494
+
495
+ custom_css = """
496
+ .gradio-container {
497
+ font-family: 'Arial', sans-serif;
498
+ }
499
+ .output-text {
500
+ font-size: 18px;
501
+ font-weight: bold;
502
+ }
503
+ """
504
+
505
+ with gr.Blocks(css=custom_css, title="ProTeVa Tone Recognition") as demo:
506
+
507
+ gr.Markdown(
508
+ """
509
+ # 🎵 ProTeVa: Yoruba Tone Recognition
510
+
511
+ Upload an audio file or record your voice to detect Yoruba tone patterns.
512
+
513
+ **Yoruba Tones:**
514
+ - **High Tone (H)** (◌́): Syllable with high pitch
515
+ - **Low Tone (B)** (◌̀): Syllable with low pitch
516
+ - **Mid Tone (M)** (◌): Syllable with neutral/middle pitch
517
+ """
518
+ )
519
+
520
+ with gr.Row():
521
+ with gr.Column(scale=1):
522
+ gr.Markdown("### 🎤 Input Audio")
523
+
524
+ audio_input = gr.Audio(
525
+ sources=["microphone", "upload"],
526
+ type="filepath",
527
+ label="Record or Upload Audio",
528
+ waveform_options={"show_controls": True}
529
+ )
530
+
531
+ predict_btn = gr.Button("🔍 Predict Tones", variant="primary", size="lg")
532
+
533
+ gr.Markdown(
534
+ """
535
+ ### 📝 Tips:
536
+ - Speak clearly in Yoruba
537
+ - Keep recordings under 10 seconds
538
+ - Avoid background noise
539
+ """
540
+ )
541
+
542
+ with gr.Column(scale=2):
543
+ gr.Markdown("### 🎯 Results")
544
+
545
+ tone_output = gr.Textbox(
546
+ label="Predicted Tone Sequence",
547
+ lines=3,
548
+ elem_classes="output-text"
549
+ )
550
+
551
+ stats_output = gr.Markdown(label="Statistics")
552
+
553
+ with gr.Tabs():
554
+ with gr.Tab("F0 Contour"):
555
+ f0_plot = gr.Plot(label="Fundamental Frequency")
556
+
557
+ with gr.Tab("Tone Visualization"):
558
+ tone_viz = gr.Plot(label="Tone Sequence")
559
+
560
+ predict_btn.click(
561
+ fn=predict_tone,
562
+ inputs=audio_input,
563
+ outputs=[tone_output, f0_plot, tone_viz, stats_output]
564
+ )
565
+
566
+ gr.Markdown("### 📚 Example Audios")
567
+ gr.Markdown("*Add example audio files to demonstrate the model*")
568
+
569
+ gr.Markdown(
570
+ """
571
+ ---
572
+
573
+ **About ProTeVa:**
574
+
575
+ ProTeVa (Prototype-based Tone Variant Autoencoder) is a neural model for tone recognition.
576
+
577
+ **Model Architecture:**
578
+ - Feature Extractor: HuBERT (Orange/SSA-HuBERT-base-60k)
579
+ - Encoder: Bidirectional GRU
580
+ - Prototype Layer: 10 learnable tone prototypes
581
+ - Decoder: F0 reconstruction
582
+ - Output: CTC-based tone sequence prediction
583
+
584
+ Built with ❤️ using SpeechBrain and Gradio
585
+ """
586
+ )
587
+
588
+ if __name__ == "__main__":
589
+ demo.launch(
590
+ share=False,
591
+ server_name="0.0.0.0",
592
+ server_port=7860
593
+ )
594
+ ```
595
+
596
+ ---
597
+
598
+ ### 2. `custom_interface.py`
599
+
600
+ **Purpose:** Custom SpeechBrain inference interface for loading and running the model.
601
+
602
+ **Content:**
603
+
604
+ ```python
605
+ """
606
+ Custom SpeechBrain inference interface for ProTeVa tone recognition model
607
+ """
608
+
609
+ import torch
610
+ from speechbrain.inference.interfaces import Pretrained
611
+
612
+
613
+ class ProTeVaToneRecognizer(Pretrained):
614
+ """
615
+ Custom interface for ProTeVa tone recognition model
616
+ Predicts tone sequences for Yoruba language (3 tones)
617
+ """
618
+
619
+ HPARAMS_NEEDED = ["wav2vec2", "enc", "dec", "pitch_dec",
620
+ "proto", "output_lin", "log_softmax",
621
+ "label_encoder", "f0Compute", "sample_rate"]
622
+
623
+ MODULES_NEEDED = ["wav2vec2", "enc", "dec", "pitch_dec",
624
+ "proto", "output_lin"]
625
+
626
+ def __init__(self, *args, **kwargs):
627
+ super().__init__(*args, **kwargs)
628
+ self.sample_rate = self.hparams.sample_rate
629
+
630
+ def classify_file(self, path):
631
+ """
632
+ Classify tone sequence from audio file
633
+
634
+ Arguments
635
+ ---------
636
+ path : str
637
+ Path to audio file
638
+
639
+ Returns
640
+ -------
641
+ tone_sequence : list
642
+ Predicted tone labels (integers)
643
+ tone_names : list
644
+ Predicted tone names (strings)
645
+ f0_contour : torch.Tensor
646
+ Reconstructed F0 contour
647
+ """
648
+ waveform = self.load_audio(path)
649
+ wavs = waveform.unsqueeze(0)
650
+ wav_lens = torch.tensor([1.0])
651
+
652
+ tone_sequences, tone_names, f0_contours = self.classify_batch(wavs, wav_lens)
653
+
654
+ return tone_sequences[0], tone_names[0], f0_contours[0]
655
+
656
+ def classify_batch(self, wavs, wav_lens):
657
+ """
658
+ Classify tones from a batch of waveforms
659
+
660
+ Arguments
661
+ ---------
662
+ wavs : torch.Tensor
663
+ Batch of waveforms [batch, time]
664
+ wav_lens : torch.Tensor
665
+ Relative lengths of waveforms
666
+
667
+ Returns
668
+ -------
669
+ tone_sequences : list of lists
670
+ Predicted tone label indices
671
+ tone_names : list of lists
672
+ Predicted tone names
673
+ f0_contours : torch.Tensor
674
+ Reconstructed F0 contours
675
+ """
676
+ self.eval()
677
+
678
+ with torch.no_grad():
679
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
680
+
681
+ # Extract features from HuBERT
682
+ feats = self.modules.wav2vec2(wavs)
683
+
684
+ # Extract F0
685
+ f0 = self.hparams.f0Compute(wavs, target_size=feats.shape[1])
686
+
687
+ # Encode with BiGRU
688
+ x, hidden = self.modules.enc(feats)
689
+
690
+ # Decode with VanillaNN
691
+ x = self.modules.dec(x)
692
+
693
+ # Pitch decoder - reconstruct F0
694
+ dec_out = self.modules.pitch_dec(x)
695
+
696
+ # Prototype layer - similarity to learned tone prototypes
697
+ proto_out = self.modules.proto(x)
698
+
699
+ # Classification layer
700
+ logits = self.modules.output_lin(proto_out)
701
+ log_probs = self.hparams.log_softmax(logits)
702
+
703
+ # CTC greedy decoding
704
+ tone_sequences = self._ctc_decode(log_probs, wav_lens)
705
+
706
+ # Convert indices to tone names
707
+ tone_names = []
708
+ for seq in tone_sequences:
709
+ names = [self._get_tone_name(idx) for idx in seq if idx != 0]
710
+ tone_names.append(names)
711
+
712
+ return tone_sequences, tone_names, dec_out
713
+
714
+ def _ctc_decode(self, log_probs, wav_lens):
715
+ """CTC greedy decoding"""
716
+ from speechbrain.decoders import ctc_greedy_decode
717
+
718
+ sequences = ctc_greedy_decode(
719
+ log_probs,
720
+ wav_lens,
721
+ blank_index=0
722
+ )
723
+
724
+ return sequences
725
+
726
+ def _get_tone_name(self, idx):
727
+ """
728
+ Convert tone index to name
729
+
730
+ Based on labelencoder.txt:
731
+ - 0: Blank (CTC)
732
+ - 1: High tone (H)
733
+ - 2: Low tone (B - Bas)
734
+ - 3: Mid tone (M)
735
+ """
736
+ tone_map = {
737
+ 0: "BLANK",
738
+ 1: "High",
739
+ 2: "Low",
740
+ 3: "Mid"
741
+ }
742
+ return tone_map.get(idx, f"Unknown({idx})")
743
+
744
+ def forward(self, wavs, wav_lens):
745
+ """Forward pass for the model"""
746
+ return self.classify_batch(wavs, wav_lens)
747
+ ```
748
+
749
+ ---
750
+
751
+ ### 3. `inference.yaml`
752
+
753
+ **Purpose:** Model configuration and checkpoint loading.
754
+
755
+ **Content:**
756
+
757
+ ```yaml
758
+ # ################################
759
+ # ProTeVa Inference Configuration
760
+ # Simplified YAML for deployment
761
+ # ################################
762
+
763
+ # Basic settings
764
+ seed: 200
765
+ device: cpu # Change to cuda if GPU available
766
+ sample_rate: 16000
767
+
768
+ # Output neurons (4 classes: blank, high, low, mid)
769
+ # Based on labelencoder.txt: 0=blank, 1=H, 2=B, 3=M
770
+ output_neurons: 4
771
+ blank_index: 0
772
+
773
+ # Number of prototypes
774
+ n_prototypes: 10
775
+
776
+ # Feature dimension from HuBERT
777
+ emb_dim: 768
778
+
779
+ # Encoder settings
780
+ rnn_layers: 2
781
+ rnn_neurons: 512
782
+
783
+ # Decoder settings
784
+ dnn_blocks: 2
785
+ dnn_neurons: 512
786
+
787
+ # Pitch decoder settings
788
+ dec_dnn_blocks: [1]
789
+ dec_dnn_neurons: [128]
790
+
791
+ # Activation function
792
+ activation: !name:torch.nn.LeakyReLU
793
+
794
+ # ============ MODULES ============
795
+
796
+ # HuBERT feature extractor
797
+ wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT
798
+ source: "Orange/SSA-HuBERT-base-60k"
799
+ output_norm: True
800
+ freeze: False
801
+ save_path: whubert_checkpoint
802
+
803
+ # F0 extractor (requires custom module)
804
+ f0Compute: !new:modules.F0Extractor
805
+ device: !ref <device>
806
+ sample_rate: !ref <sample_rate>
807
+
808
+ # BiGRU Encoder
809
+ enc: !new:speechbrain.nnet.RNN.GRU
810
+ input_shape: [null, null, !ref <emb_dim>]
811
+ hidden_size: !ref <rnn_neurons>
812
+ num_layers: !ref <rnn_layers>
813
+ bidirectional: True
814
+ dropout: 0.15
815
+
816
+ # VanillaNN Decoder
817
+ dec: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
818
+ input_shape: [null, null, 1024] # 512 * 2 (bidirectional)
819
+ activation: !ref <activation>
820
+ dnn_blocks: !ref <dnn_blocks>
821
+ dnn_neurons: !ref <dnn_neurons>
822
+
823
+ # Pitch Decoder (requires custom module)
824
+ pitch_dec: !new:modules.PitchDecoderLayer
825
+ input_shape: [null, null, !ref <dnn_neurons>]
826
+ dnn_blocks: !ref <dec_dnn_blocks>
827
+ dnn_neurons: !ref <dec_dnn_neurons>
828
+
829
+ # Prototype Layer (requires custom module)
830
+ proto: !new:modules.PrototypeLayer
831
+ n_prototypes: !ref <n_prototypes>
832
+ latent_dims: !ref <dnn_neurons>
833
+
834
+ # Output linear layer
835
+ output_lin: !new:speechbrain.nnet.linear.Linear
836
+ input_size: !ref <n_prototypes>
837
+ n_neurons: !ref <output_neurons>
838
+ bias: True
839
+
840
+ # Log softmax
841
+ log_softmax: !new:speechbrain.nnet.activations.Softmax
842
+ apply_log: True
843
+
844
+ # Label encoder
845
+ label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder
846
+
847
+ # ============ MODULES DICT ============
848
+
849
+ modules:
850
+ wav2vec2: !ref <wav2vec2>
851
+ enc: !ref <enc>
852
+ dec: !ref <dec>
853
+ pitch_dec: !ref <pitch_dec>
854
+ proto: !ref <proto>
855
+ output_lin: !ref <output_lin>
856
+
857
+ # Model container for all modules
858
+ model: !new:torch.nn.ModuleList
859
+ - [!ref <enc>, !ref <dec>, !ref <proto>, !ref <output_lin>, !ref <pitch_dec>]
860
+
861
+ # ============ PRETRAINER ============
862
+ # This loads the trained checkpoints
863
+
864
+ pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
865
+ loadables:
866
+ model: !ref <model>
867
+ wav2vec2: !ref <wav2vec2>
868
+ tokenizer: !ref <label_encoder>
869
+ paths:
870
+ model: !ref <save_folder>/model.ckpt
871
+ wav2vec2: !ref <save_folder>/wav2vec2.ckpt
872
+ tokenizer: !ref <save_folder>/tokenizer.ckpt
873
+
874
+ # Save folder - UPDATE THIS PATH TO MATCH YOUR CHECKPOINT FOLDER NAME
875
+ save_folder: ./CKPT+2025-10-20+04-14-23+00
876
+ ```
877
+
878
+ **⚠️ IMPORTANT:** Update `save_folder` to match your actual checkpoint folder name!
879
+
880
+ ---
881
+
882
+ ### 4. `modules.py`
883
+
884
+ **Purpose:** Custom PyTorch modules used by the model.
885
+
886
+ **Content:**
887
+
888
+ ```python
889
+ """
890
+ Custom modules for ProTeVa tone recognition model
891
+ """
892
+
893
+ import torch
894
+ import torch.nn as nn
895
+ import torch.nn.functional as F
896
+
897
+
898
+ class F0Extractor(nn.Module):
899
+ """
900
+ F0 (Fundamental Frequency) Extractor using TorchYIN
901
+
902
+ This module extracts F0 from audio waveforms and returns it as an embedding vector.
903
+ Uses the YIN algorithm implemented in torchyin for pitch estimation.
904
+
905
+ Arguments
906
+ ---------
907
+ device : str
908
+ Device to run computations on ('cpu' or 'cuda')
909
+ sample_rate : int
910
+ Audio sample rate (default: 16000)
911
+ frame_stride : float
912
+ Length of the sliding window in seconds (default: 0.018)
913
+ pitch_min : float
914
+ Minimum pitch value in Hz (default: 50)
915
+ pitch_max : float
916
+ Maximum pitch value in Hz (default: 500)
917
+
918
+ Example
919
+ -------
920
+ >>> compute_f0 = F0Extractor(sample_rate=16000)
921
+ >>> input_feats = torch.rand([1, 23000])
922
+ >>> outputs = compute_f0(input_feats, target_size=220)
923
+ >>> outputs.shape
924
+ torch.Size([1, 220, 1])
925
+
926
+ Authors
927
+ -------
928
+ * St Germes BENGONO OBIANG 2024
929
+ """
930
+
931
+ def __init__(
932
+ self,
933
+ device="cpu",
934
+ sample_rate=16000,
935
+ frame_stride=0.018,
936
+ pitch_min=50,
937
+ pitch_max=500,
938
+ ):
939
+ super().__init__()
940
+ self.device = device
941
+ self.sample_rate = sample_rate
942
+ self.pitch_min = pitch_min
943
+ self.pitch_max = pitch_max
944
+ self.frame_stride = frame_stride
945
+
946
+ def interpolate_spline(self, H, N):
947
+ """
948
+ Interpolate pitch values to target size using cubic spline interpolation
949
+
950
+ Arguments
951
+ ---------
952
+ H : numpy.ndarray
953
+ Original pitch values
954
+ N : int
955
+ Target number of frames
956
+
957
+ Returns
958
+ -------
959
+ H_interpolated : torch.Tensor
960
+ Interpolated pitch values
961
+ """
962
+ import numpy as np
963
+ from scipy.interpolate import interp1d
964
+
965
+ # Generate indices for the original and new tensors
966
+ idx_original = np.arange(len(H))
967
+ idx_new = np.linspace(0, len(H) - 1, N)
968
+
969
+ # Create the interpolation function
970
+ interpolator = interp1d(idx_original, H, kind='cubic')
971
+
972
+ # Perform interpolation
973
+ H_interpolated = interpolator(idx_new)
974
+
975
+ # Create a mask for values below minimum pitch
976
+ mask = H_interpolated < self.pitch_min
977
+ H_interpolated[mask] = 0
978
+
979
+ return torch.as_tensor(H_interpolated.tolist())
980
+
981
+ def forward(self, wavs, target_size):
982
+ """
983
+ Extract F0 from waveforms
984
+
985
+ Arguments
986
+ ---------
987
+ wavs : torch.Tensor
988
+ Input waveforms [batch, time]
989
+ target_size : int
990
+ Target length to match encoder output
991
+
992
+ Returns
993
+ -------
994
+ f0 : torch.Tensor
995
+ F0 contours [batch, target_size, 1]
996
+ """
997
+ import torchyin
998
+
999
+ results = []
1000
+
1001
+ for wav in wavs:
1002
+ # Estimate pitch using TorchYIN
1003
+ pitch = torchyin.estimate(
1004
+ wav,
1005
+ self.sample_rate,
1006
+ pitch_min=self.pitch_min,
1007
+ pitch_max=self.pitch_max,
1008
+ frame_stride=self.frame_stride
1009
+ )
1010
+
1011
+ # Interpolate the pitch to target size
1012
+ pitch = self.interpolate_spline(pitch.cpu().numpy(), target_size)
1013
+
1014
+ # Reshape the pitch output
1015
+ pitch = pitch.view(pitch.shape[0], 1)
1016
+ results.append(pitch.tolist())
1017
+
1018
+ return torch.as_tensor(results).to(self.device)
1019
+
1020
+
1021
+ class PitchDecoderLayer(nn.Module):
1022
+ """
1023
+ Pitch Decoder Layer
1024
+ Reconstructs F0 contour from encoded representations
1025
+ """
1026
+
1027
+ def __init__(self, input_shape, dnn_blocks=[1], dnn_neurons=[128]):
1028
+ super().__init__()
1029
+
1030
+ if isinstance(input_shape, list) and len(input_shape) == 3:
1031
+ input_dim = input_shape[-1]
1032
+ else:
1033
+ input_dim = input_shape
1034
+
1035
+ layers = []
1036
+ current_dim = input_dim
1037
+
1038
+ for block_idx, neurons in enumerate(dnn_neurons):
1039
+ layers.append(nn.Linear(current_dim, neurons))
1040
+ layers.append(nn.LeakyReLU())
1041
+ layers.append(nn.Dropout(0.1))
1042
+ current_dim = neurons
1043
+
1044
+ layers.append(nn.Linear(current_dim, 1))
1045
+
1046
+ self.decoder = nn.Sequential(*layers)
1047
+
1048
+ def forward(self, x):
1049
+ """
1050
+ Decode F0 from encoded representation
1051
+
1052
+ Arguments
1053
+ ---------
1054
+ x : torch.Tensor
1055
+ Encoded features [batch, time, features]
1056
+
1057
+ Returns
1058
+ -------
1059
+ f0_pred : torch.Tensor
1060
+ Predicted F0 [batch, time, 1]
1061
+ """
1062
+ return self.decoder(x)
1063
+
1064
+
1065
+ class PrototypeLayer(nn.Module):
1066
+ """
1067
+ Prototype Layer for tone representation learning
1068
+
1069
+ Learns M prototypes that represent canonical tone patterns.
1070
+ Includes regularization losses R_1 and R_2.
1071
+ """
1072
+
1073
+ def __init__(self, n_prototypes=10, latent_dims=512, temperature=1.0):
1074
+ super().__init__()
1075
+
1076
+ self.n_prototypes = n_prototypes
1077
+ self.latent_dims = latent_dims
1078
+ self.temperature = temperature
1079
+
1080
+ self.prototypes = nn.Parameter(
1081
+ torch.randn(n_prototypes, latent_dims)
1082
+ )
1083
+
1084
+ nn.init.xavier_uniform_(self.prototypes)
1085
+
1086
+ self.R_1 = torch.tensor(0.0)
1087
+ self.R_2 = torch.tensor(0.0)
1088
+
1089
+ def forward(self, x):
1090
+ """
1091
+ Compute similarity between input and prototypes
1092
+
1093
+ Arguments
1094
+ ---------
1095
+ x : torch.Tensor
1096
+ Input features [batch, time, latent_dims]
1097
+
1098
+ Returns
1099
+ -------
1100
+ similarities : torch.Tensor
1101
+ Prototype similarities [batch, time, n_prototypes]
1102
+ """
1103
+ batch_size, time_steps, features = x.shape
1104
+
1105
+ x_flat = x.view(-1, features)
1106
+
1107
+ x_norm = F.normalize(x_flat, p=2, dim=1)
1108
+ proto_norm = F.normalize(self.prototypes, p=2, dim=1)
1109
+
1110
+ similarities = torch.mm(x_norm, proto_norm.t())
1111
+ similarities = similarities / self.temperature
1112
+ similarities = similarities.view(batch_size, time_steps, self.n_prototypes)
1113
+
1114
+ self._compute_regularization(x, similarities)
1115
+
1116
+ return similarities
1117
+
1118
+ def _compute_regularization(self, x, similarities):
1119
+ """Compute regularization losses R_1 and R_2"""
1120
+ # R_1: Prototype diversity
1121
+ proto_norm = F.normalize(self.prototypes, p=2, dim=1)
1122
+ proto_similarity = torch.mm(proto_norm, proto_norm.t())
1123
+
1124
+ mask = torch.ones_like(proto_similarity) - torch.eye(
1125
+ self.n_prototypes, device=proto_similarity.device
1126
+ )
1127
+ self.R_1 = (proto_similarity * mask).pow(2).sum() / (
1128
+ self.n_prototypes * (self.n_prototypes - 1)
1129
+ )
1130
+
1131
+ # R_2: Cluster compactness
1132
+ max_sim, assigned_proto = similarities.max(dim=-1)
1133
+ self.R_2 = -max_sim.mean()
1134
+ ```
1135
+
1136
+ **✅ COMPLETE:** F0Extractor is now fully implemented using TorchYIN!
1137
+
1138
+ ---
1139
+
1140
+ ### 5. `requirements.txt`
1141
+
1142
+ **Purpose:** Python package dependencies.
1143
+
1144
+ **Content:**
1145
+
1146
+ ```txt
1147
+ # Core dependencies
1148
+ speechbrain
1149
+ torch>=1.13.0
1150
+ torchaudio>=0.13.0
1151
+ gradio>=4.0.0
1152
+
1153
+ # Audio processing
1154
+ librosa
1155
+ soundfile
1156
+
1157
+ # Visualization
1158
+ matplotlib
1159
+ numpy
1160
+ scipy
1161
+
1162
+ # HuggingFace integration
1163
+ transformers
1164
+ huggingface_hub
1165
+
1166
+ # Additional utilities
1167
+ hyperpyyaml
1168
+ sentencepiece
1169
+
1170
+ # F0 extraction with TorchYIN
1171
+ torchyin
1172
+ ```
1173
+
1174
+ **Note:** `torchyin` is required for F0 (pitch) extraction using the YIN algorithm.
1175
+
1176
+ ---
1177
+
1178
+ ### 6. `README.md` (for Hugging Face Space)
1179
+
1180
+ **Purpose:** Documentation displayed on your Space page.
1181
+
1182
+ **Content:**
1183
+
1184
+ ```markdown
1185
+ ---
1186
+ title: ProTeVa Yoruba Tone Recognition
1187
+ emoji: 🎵
1188
+ colorFrom: blue
1189
+ colorTo: green
1190
+ sdk: gradio
1191
+ sdk_version: 4.44.0
1192
+ app_file: app.py
1193
+ pinned: false
1194
+ license: apache-2.0
1195
+ ---
1196
+
1197
+ # ProTeVa: Yoruba Tone Recognition
1198
+
1199
+ This Space demonstrates **ProTeVa** (Prototype-based Tone Variant Autoencoder), a neural model for recognizing tone patterns in Yoruba language.
1200
+
1201
+ ## Features
1202
+
1203
+ - 🎤 **Record or Upload**: Use your microphone or upload audio files
1204
+ - 🎯 **Tone Detection**: Automatically detects 3 Yoruba tones (Low, Mid, High)
1205
+ - 📊 **F0 Visualization**: Shows fundamental frequency contours
1206
+ - 🎨 **Interactive UI**: Real-time predictions with visual feedback
1207
+
1208
+ ## Yoruba Tones
1209
+
1210
+ Yoruba is a tonal language with three contrastive tones:
1211
+
1212
+ 1. **High Tone (H)** (◌́) - Example: ágbó (elder)
1213
+ 2. **Low Tone (B)** (◌̀) - Example: àgbò (ram)
1214
+ 3. **Mid Tone (M)** (◌) - Example: agbo (medicine)
1215
+
1216
+ ## Model Architecture
1217
+
1218
+ - **Feature Extractor**: HuBERT (Orange/SSA-HuBERT-base-60k)
1219
+ - **Encoder**: 2-layer Bidirectional GRU (512 hidden units)
1220
+ - **Decoder**: VanillaNN (2 blocks, 512 neurons)
1221
+ - **Prototype Layer**: 10 learnable tone prototypes
1222
+ - **Output**: CTC-based sequence prediction
1223
+
1224
+ ## Training Details
1225
+
1226
+ - **Dataset**: Yoruba speech corpus
1227
+ - **Sample Rate**: 16kHz
1228
+ - **Loss Functions**:
1229
+ - CTC loss for tone sequence
1230
+ - MSE loss for F0 reconstruction
1231
+ - Prototype regularization (R₁ + R₂)
1232
+ - **Training Duration**: 65 epochs
1233
+
1234
+ ## Usage
1235
+
1236
+ 1. Click on the microphone icon to record or upload an audio file
1237
+ 2. Click "🔍 Predict Tones"
1238
+ 3. View predicted tone sequence and F0 contour
1239
+
1240
+ ## Citation
1241
+
1242
+ If you use this model in your research, please cite:
1243
+
1244
+ ```bibtex
1245
+ @article{proteva2025,
1246
+ title={ProTeVa: Prototype-based Tone Variant Autoencoder for Yoruba Tone Recognition},
1247
+ author={Your Name},
1248
+ year={2025}
1249
+ }
1250
+ ```
1251
+
1252
+ ## Acknowledgments
1253
+
1254
+ Built with [SpeechBrain](https://speechbrain.github.io/) and [Gradio](https://gradio.app/).
1255
+
1256
+ ## License
1257
+
1258
+ Apache 2.0
1259
+ ```
1260
+
1261
+ ---
1262
+
1263
+ ## Space Detection Implementation
1264
+
1265
+ ProTeVa implements intelligent word boundary detection using acoustic features. Since the base model only predicts 3 tones (H, B, M), space tokens (label 4) are inserted via post-processing.
1266
+
1267
+ ### Detection Methods
1268
+
1269
+ #### 1. **Silence Detection** (`'silence'`)
1270
+ - Analyzes F0 contours for gaps with low/zero pitch
1271
+ - Gaps longer than `SILENCE_THRESHOLD` (default: 0.15s) indicate word boundaries
1272
+ - Effective for clear pauses between words
1273
+
1274
+ #### 2. **F0 Drop Detection** (`'f0_drop'`)
1275
+ - Detects significant pitch drops between consecutive tones
1276
+ - Drops greater than `F0_DROP_THRESHOLD` (default: 20%) suggest boundaries
1277
+ - Mimics natural prosody where pitch resets at word beginnings
1278
+
1279
+ #### 3. **Duration-Based** (`'duration'`)
1280
+ - Simple heuristic based on regular intervals
1281
+ - Inserts spaces every N tones (configurable)
1282
+ - Less accurate but works without acoustic features
1283
+
1284
+ #### 4. **Combined Method** (`'combined'`) - **RECOMMENDED**
1285
+ - Combines silence and F0 drop detection
1286
+ - Higher confidence when both methods agree
1287
+ - Balances precision and recall
1288
+
1289
+ ### Configuration
1290
+
1291
+ Edit `config.py` to customize:
1292
+
1293
+ ```python
1294
+ # Enable/disable space detection
1295
+ ENABLE_SPACE_DETECTION = True
1296
+
1297
+ # Choose detection method
1298
+ SPACE_DETECTION_METHOD = "combined" # Best results
1299
+
1300
+ # Fine-tune thresholds
1301
+ SILENCE_THRESHOLD = 0.15 # Adjust for speaker style
1302
+ F0_DROP_THRESHOLD = 0.20 # 20% F0 drop
1303
+ SPACE_CONFIDENCE_THRESHOLD = 0.6
1304
+ ```
1305
+
1306
+ ### Implementation Details
1307
+
1308
+ 1. **Model predicts base tones** (1, 2, 3) using CTC
1309
+ 2. **Post-processing analyzes** F0 contours and silence patterns
1310
+ 3. **Space tokens (4) inserted** at detected word boundaries
1311
+ 4. **Visualization** shows spaces as vertical separators
1312
+
1313
+ ### Tuning Tips
1314
+
1315
+ - **Too many spaces?** Increase thresholds or use `'f0_drop'` only
1316
+ - **Too few spaces?** Decrease thresholds or use `'combined'`
1317
+ - **Disable completely:** Set `ENABLE_SPACE_DETECTION = False`
1318
+
1319
+ ---
1320
+
1321
+ ## Testing & Troubleshooting
1322
+
1323
+ ### Local Testing Checklist
1324
+
1325
+ ```bash
1326
+ # 1. Install dependencies
1327
+ pip install -r requirements.txt
1328
+
1329
+ # 2. Verify file structure
1330
+ ls -la
1331
+ # Should see: app.py, custom_interface.py, inference.yaml, modules.py, requirements.txt
1332
+ # Should see: CKPT+2025-10-20+04-14-23+00/ folder
1333
+
1334
+ # 3. Check checkpoint folder
1335
+ ls CKPT+2025-10-20+04-14-23+00/
1336
+ # Should see: model.ckpt, wav2vec2.ckpt, tokenizer.ckpt
1337
+
1338
+ # 4. Run the app
1339
+ python app.py
1340
+
1341
+ # 5. Open browser
1342
+ # http://localhost:7860
1343
+
1344
+ # 6. Test functionality
1345
+ # - Record audio
1346
+ # - Upload file
1347
+ # - Check predictions
1348
+ # - Verify plots display
1349
+ ```
1350
+
1351
+ ### Common Issues
1352
+
1353
+ #### Issue 1: "Module not found: modules"
1354
+ **Solution:** Ensure `modules.py` is in the same directory as `app.py`
1355
+
1356
+ #### Issue 2: "Cannot find checkpoint"
1357
+ **Solution:** Update `save_folder` in `inference.yaml` to match your checkpoint folder name exactly
1358
+
1359
+ #### Issue 3: "F0Extractor not implemented"
1360
+ **Solution:** Implement the `forward()` method in `F0Extractor` class in `modules.py`
1361
+
1362
+ #### Issue 4: "CUDA out of memory"
1363
+ **Solution:** Set `device: cpu` in `inference.yaml` or upgrade to GPU hardware
1364
+
1365
+ #### Issue 5: "File too large for upload"
1366
+ **Solution:** Use Git LFS for checkpoint files:
1367
+ ```bash
1368
+ git lfs install
1369
+ git lfs track "*.ckpt"
1370
+ git add .gitattributes
1371
+ ```
1372
+
1373
+ #### Issue 6: "Model loading timeout"
1374
+ **Solution:** Large models may take 2-5 minutes to load on first run. Check Space logs.
1375
+
1376
+ ### Verification Steps on Hugging Face Spaces
1377
+
1378
+ 1. ✅ Space builds without errors (check "Logs" tab)
1379
+ 2. ✅ Model loads successfully (check startup logs)
1380
+ 3. ✅ UI displays correctly
1381
+ 4. ✅ Can record audio from microphone
1382
+ 5. ✅ Can upload audio files
1383
+ 6. ✅ Predictions are generated
1384
+ 7. ✅ F0 plot appears
1385
+ 8. ✅ Tone visualization shows
1386
+ 9. ✅ Statistics display correctly
1387
+ 10. ✅ No errors in browser console
1388
+
1389
+ ---
1390
+
1391
+ ## Quick Reference
1392
+
1393
+ ### File Checklist
1394
+ - [ ] `config.py` (central configuration - **UPDATE THIS FIRST**)
1395
+ - [ ] `app.py` (main application)
1396
+ - [ ] `custom_interface.py` (inference interface with space detection)
1397
+ - [ ] `inference.yaml` (model configuration)
1398
+ - [ ] `modules.py` (custom modules - F0Extractor, PrototypeLayer, PitchDecoder)
1399
+ - [ ] `requirements.txt` (dependencies)
1400
+ - [ ] `README.md` (Space documentation)
1401
+ - [ ] `CKPT+2025-10-20+08-19-07+00/` (checkpoint folder)
1402
+ - [ ] `model.ckpt`
1403
+ - [ ] `wav2vec2.ckpt`
1404
+ - [ ] `tokenizer.ckpt`
1405
+
1406
+ ### Configuration Updates
1407
+ - [ ] Update `CHECKPOINT_FOLDER` in `config.py` to match your checkpoint folder
1408
+ - [ ] Configure space detection settings in `config.py`:
1409
+ - `ENABLE_SPACE_DETECTION`: True/False
1410
+ - `SPACE_DETECTION_METHOD`: 'combined', 'silence', 'f0_drop', or 'duration'
1411
+ - [ ] Ensure `save_folder` in `inference.yaml` matches `config.py`
1412
+ - [ ] Add your name/info to `README.md`
1413
+
1414
+ ### Deployment Commands
1415
+ ```bash
1416
+ # Local test
1417
+ python app.py
1418
+
1419
+ # Deploy to Hugging Face
1420
+ git clone https://huggingface.co/spaces/USERNAME/SPACE_NAME
1421
+ cd SPACE_NAME
1422
+ cp -r /path/to/files/* ./
1423
+ git lfs track "*.ckpt"
1424
+ git add .
1425
+ git commit -m "Deploy ProTeVa"
1426
+ git push
1427
+ ```
1428
+
1429
+ ---
1430
+
1431
+ ## Support & Resources
1432
+
1433
+ - **SpeechBrain Docs**: https://speechbrain.readthedocs.io/
1434
+ - **Gradio Docs**: https://gradio.app/docs/
1435
+ - **Hugging Face Spaces**: https://huggingface.co/docs/hub/spaces
1436
+
1437
+ ---
1438
+
1439
+ **You're ready to deploy! 🚀**
1440
+
1441
+ Follow the steps, test locally, then push to Hugging Face Spaces.
app.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio App for ProTeVa Yoruba Tone Recognition
3
+ Hugging Face Spaces deployment
4
+ """
5
+
6
+ import gradio as gr
7
+ from speechbrain.inference.interfaces import foreign_class
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ import torch
11
+ import config
12
+
13
+ # ============ CONFIGURATION ============
14
+
15
+ # Import tone info from config
16
+ TONE_INFO = config.TONE_INFO
17
+
18
+ # ============ MODEL LOADING ============
19
+
20
+ print("Loading ProTeVa tone recognition model...")
21
+ print(f"Checkpoint folder: {config.CHECKPOINT_FOLDER}")
22
+
23
+ try:
24
+ tone_recognizer = foreign_class(
25
+ source="./",
26
+ pymodule_file="custom_interface.py",
27
+ classname="ProTeVaToneRecognizer",
28
+ hparams_file="inference.yaml",
29
+ savedir=config.PRETRAINED_MODEL_DIR
30
+ )
31
+ print("✓ Model loaded successfully!")
32
+
33
+ # Validate configuration
34
+ if config.validate_config():
35
+ print(f"✓ Space detection: {'ENABLED' if config.ENABLE_SPACE_DETECTION else 'DISABLED'}")
36
+ if config.ENABLE_SPACE_DETECTION:
37
+ print(f" Method: {config.SPACE_DETECTION_METHOD}")
38
+ except Exception as e:
39
+ print(f"✗ Error loading model: {e}")
40
+ tone_recognizer = None
41
+
42
+ # ============ HELPER FUNCTIONS ============
43
+
44
+ def format_tone_sequence(tone_indices, tone_names):
45
+ """Format tone sequence with colors and symbols"""
46
+ if not tone_indices:
47
+ return "No tones detected"
48
+
49
+ formatted = []
50
+ for idx, name in zip(tone_indices, tone_names):
51
+ info = config.get_tone_info(idx)
52
+ formatted.append(f"{info['name']} ({info['symbol']})")
53
+
54
+ return " → ".join(formatted)
55
+
56
+ def create_f0_plot(f0_contour):
57
+ """Create F0 contour plot"""
58
+ if f0_contour is None or len(f0_contour) == 0:
59
+ return None
60
+
61
+ # Convert to numpy
62
+ if isinstance(f0_contour, torch.Tensor):
63
+ f0_numpy = f0_contour.cpu().numpy().flatten()
64
+ else:
65
+ f0_numpy = np.array(f0_contour).flatten()
66
+
67
+ # Create plot
68
+ fig, ax = plt.subplots(figsize=(10, 4))
69
+ time = np.arange(len(f0_numpy)) / len(f0_numpy)
70
+ ax.plot(time, f0_numpy, linewidth=2, color='#3498db')
71
+ ax.set_xlabel('Normalized Time', fontsize=12)
72
+ ax.set_ylabel('F0 (Hz)', fontsize=12)
73
+ ax.set_title('Fundamental Frequency Contour', fontsize=14, fontweight='bold')
74
+ ax.grid(True, alpha=0.3)
75
+ plt.tight_layout()
76
+
77
+ return fig
78
+
79
+ def create_tone_visualization(tone_indices):
80
+ """Create visual representation of tone sequence"""
81
+ if not tone_indices:
82
+ return None
83
+
84
+ fig, ax = plt.subplots(figsize=(max(12, len(tone_indices) * 0.8), 3))
85
+
86
+ # Prepare data
87
+ x_positions = []
88
+ colors = []
89
+ labels = []
90
+
91
+ position = 0
92
+ for idx in tone_indices:
93
+ info = config.get_tone_info(idx)
94
+
95
+ # Space tokens get different visual treatment
96
+ if idx == 4:
97
+ # Draw vertical line for space
98
+ ax.axvline(x=position - 0.25, color=info['color'],
99
+ linewidth=3, linestyle='--', alpha=0.7)
100
+ else:
101
+ x_positions.append(position)
102
+ colors.append(info['color'])
103
+ labels.append(info['symbol'])
104
+ position += 1
105
+
106
+ # Draw tone bars
107
+ if x_positions:
108
+ ax.bar(x_positions, [1] * len(x_positions), color=colors, alpha=0.7,
109
+ edgecolor='black', linewidth=2, width=0.8)
110
+
111
+ # Add tone symbols
112
+ for i, (pos, label) in enumerate(zip(x_positions, labels)):
113
+ ax.text(pos, 0.5, label, ha='center', va='center',
114
+ fontsize=20, fontweight='bold')
115
+
116
+ # Configure plot
117
+ if x_positions:
118
+ ax.set_xlim(-0.5, max(x_positions) + 0.5)
119
+ ax.set_ylim(0, 1.2)
120
+ if x_positions:
121
+ ax.set_xticks(x_positions)
122
+ ax.set_xticklabels([f"T{i+1}" for i in range(len(x_positions))])
123
+ ax.set_ylabel('Tone', fontsize=12)
124
+ ax.set_title('Tone Sequence Visualization (| = word boundary)',
125
+ fontsize=14, fontweight='bold')
126
+ ax.set_yticks([])
127
+ plt.tight_layout()
128
+
129
+ return fig
130
+
131
+ # ============ PREDICTION FUNCTION ============
132
+
133
+ def predict_tone(audio_file):
134
+ """Main prediction function for Gradio interface"""
135
+ if tone_recognizer is None:
136
+ return "❌ Model not loaded. Please check configuration.", None, None, ""
137
+
138
+ if audio_file is None:
139
+ return "⚠️ Please provide an audio file", None, None, ""
140
+
141
+ try:
142
+ # Get predictions
143
+ tone_indices, tone_names, f0_contour = tone_recognizer.classify_file(audio_file)
144
+
145
+ # Format output
146
+ tone_text = format_tone_sequence(tone_indices, tone_names)
147
+
148
+ # Create visualizations
149
+ f0_plot = create_f0_plot(f0_contour)
150
+ tone_viz = create_tone_visualization(tone_indices)
151
+
152
+ # Create statistics
153
+ num_tones = len([t for t in tone_indices if t != 4])
154
+ num_spaces = len([t for t in tone_indices if t == 4])
155
+
156
+ stats = f"""
157
+ 📊 **Prediction Statistics:**
158
+ - Total tones detected: {num_tones}
159
+ - Word boundaries detected: {num_spaces}
160
+ - Sequence length: {len(tone_indices)}
161
+
162
+ 🎵 **Tone Distribution:**
163
+ - High tones (H): {tone_indices.count(1)}
164
+ - Low tones (B): {tone_indices.count(2)}
165
+ - Mid tones (M): {tone_indices.count(3)}
166
+
167
+ ⚙️ **Detection Settings:**
168
+ - Space detection: {'ENABLED' if config.ENABLE_SPACE_DETECTION else 'DISABLED'}
169
+ {f"- Method: {config.SPACE_DETECTION_METHOD}" if config.ENABLE_SPACE_DETECTION else ""}
170
+ """
171
+
172
+ return tone_text, f0_plot, tone_viz, stats
173
+
174
+ except Exception as e:
175
+ import traceback
176
+ error_details = traceback.format_exc()
177
+ return f"❌ Error during prediction: {str(e)}\n\n{error_details}", None, None, ""
178
+
179
+ # ============ GRADIO INTERFACE ============
180
+
181
+ custom_css = """
182
+ .gradio-container {
183
+ font-family: 'Arial', sans-serif;
184
+ }
185
+ .output-text {
186
+ font-size: 18px;
187
+ font-weight: bold;
188
+ }
189
+ """
190
+
191
+ with gr.Blocks(css=custom_css, title="ProTeVa Tone Recognition") as demo:
192
+
193
+ gr.Markdown(
194
+ f"""
195
+ # 🎵 ProTeVa: Yoruba Tone Recognition
196
+
197
+ Upload an audio file or record your voice to detect Yoruba tone patterns.
198
+
199
+ **Yoruba Tones:**
200
+ - **High Tone (H)** (◌́): Syllable with high pitch
201
+ - **Low Tone (B)** (◌̀): Syllable with low pitch
202
+ - **Mid Tone (M)** (◌): Syllable with neutral/middle pitch
203
+ - **Space ( | )**: Word boundary (detected automatically)
204
+
205
+ **Space Detection:** {config.SPACE_DETECTION_METHOD if config.ENABLE_SPACE_DETECTION else 'OFF'}
206
+ """
207
+ )
208
+
209
+ with gr.Row():
210
+ with gr.Column(scale=1):
211
+ gr.Markdown("### 🎤 Input Audio")
212
+
213
+ audio_input = gr.Audio(
214
+ sources=["microphone", "upload"],
215
+ type="filepath",
216
+ label="Record or Upload Audio",
217
+ waveform_options={"show_controls": True}
218
+ )
219
+
220
+ predict_btn = gr.Button("🔍 Predict Tones", variant="primary", size="lg")
221
+
222
+ gr.Markdown(
223
+ """
224
+ ### 📝 Tips:
225
+ - Speak clearly in Yoruba
226
+ - Keep recordings under 10 seconds
227
+ - Avoid background noise
228
+ - Pause slightly between words for better boundary detection
229
+ """
230
+ )
231
+
232
+ with gr.Column(scale=2):
233
+ gr.Markdown("### 🎯 Results")
234
+
235
+ tone_output = gr.Textbox(
236
+ label="Predicted Tone Sequence",
237
+ lines=3,
238
+ elem_classes="output-text"
239
+ )
240
+
241
+ stats_output = gr.Markdown(label="Statistics")
242
+
243
+ with gr.Tabs():
244
+ with gr.Tab("F0 Contour"):
245
+ f0_plot = gr.Plot(label="Fundamental Frequency")
246
+
247
+ with gr.Tab("Tone Visualization"):
248
+ tone_viz = gr.Plot(label="Tone Sequence")
249
+
250
+ predict_btn.click(
251
+ fn=predict_tone,
252
+ inputs=audio_input,
253
+ outputs=[tone_output, f0_plot, tone_viz, stats_output]
254
+ )
255
+
256
+ gr.Markdown("### 📚 Example Audios")
257
+ gr.Markdown("*Upload Yoruba speech samples to test the model*")
258
+
259
+ gr.Markdown(
260
+ f"""
261
+ ---
262
+
263
+ **About ProTeVa:**
264
+
265
+ ProTeVa (Prototype-based Tone Variant Autoencoder) is a neural model for tone recognition.
266
+
267
+ **Model Architecture:**
268
+ - Feature Extractor: HuBERT (Orange/SSA-HuBERT-base-60k)
269
+ - Encoder: {config.RNN_LAYERS}-layer Bidirectional GRU ({config.RNN_NEURONS} neurons)
270
+ - Prototype Layer: {config.N_PROTOTYPES} learnable tone prototypes
271
+ - Decoder: F0 reconstruction
272
+ - Output: CTC-based tone sequence prediction + acoustic space detection
273
+
274
+ **Space Detection:**
275
+ - Method: {config.SPACE_DETECTION_METHOD if config.ENABLE_SPACE_DETECTION else 'Disabled'}
276
+ - Uses F0 contours, silence patterns, and tone duration
277
+ - Automatically detects word boundaries in continuous speech
278
+
279
+ Built with ❤️ using SpeechBrain and Gradio
280
+
281
+ **Model Checkpoint:** {config.CHECKPOINT_FOLDER}
282
+ """
283
+ )
284
+
285
+ if __name__ == "__main__":
286
+ demo.launch(
287
+ share=config.GRADIO_SHARE,
288
+ server_name=config.GRADIO_SERVER_NAME,
289
+ server_port=config.GRADIO_SERVER_PORT
290
+ )
config.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ProTeVa Configuration File
3
+ Central configuration for model paths and tone settings
4
+ """
5
+
6
+ import os
7
+
8
+ # ============ PATH CONFIGURATION ============
9
+
10
+ # Checkpoint folder name - UPDATE THIS when using a different checkpoint
11
+ CHECKPOINT_FOLDER = "CKPT+2025-10-20+08-19-07+00"
12
+
13
+ # Get the absolute path to the checkpoint folder
14
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
15
+ CHECKPOINT_PATH = os.path.join(BASE_DIR, CHECKPOINT_FOLDER)
16
+
17
+ # Model files
18
+ MODEL_CKPT = os.path.join(CHECKPOINT_PATH, "model.ckpt")
19
+ WAV2VEC2_CKPT = os.path.join(CHECKPOINT_PATH, "wav2vec2.ckpt")
20
+ TOKENIZER_CKPT = os.path.join(CHECKPOINT_PATH, "tokenizer.ckpt")
21
+
22
+ # ============ MODEL CONFIGURATION ============
23
+
24
+ # Audio settings
25
+ SAMPLE_RATE = 16000
26
+
27
+ # Model architecture
28
+ RNN_LAYERS = 2
29
+ RNN_NEURONS = 512
30
+ DNN_BLOCKS = 2
31
+ DNN_NEURONS = 512
32
+ N_PROTOTYPES = 10
33
+ EMB_DIM = 768
34
+
35
+ # ============ TONE CONFIGURATION ============
36
+
37
+ # Tone label mapping (from labelencoder.txt)
38
+ # These are the indices used by the trained model
39
+ TONE_LABELS = {
40
+ 0: "BLANK", # CTC blank token
41
+ 1: "H", # High tone
42
+ 2: "B", # Low tone (Bas)
43
+ 3: "M" # Mid tone
44
+ }
45
+
46
+ # Output neurons (number of classes)
47
+ OUTPUT_NEURONS = 4 # blank, H, B, M
48
+
49
+ # CTC blank index
50
+ BLANK_INDEX = 0
51
+
52
+ # ============ SPACE/WORD BOUNDARY DETECTION ============
53
+
54
+ # Enable space detection between tones
55
+ ENABLE_SPACE_DETECTION = True
56
+
57
+ # Space detection method: 'silence', 'f0_drop', 'duration', or 'combined'
58
+ SPACE_DETECTION_METHOD = "combined"
59
+
60
+ # Silence threshold (in seconds) - gaps longer than this are word boundaries
61
+ SILENCE_THRESHOLD = 0.15
62
+
63
+ # F0 drop threshold (percentage) - F0 drops greater than this indicate boundaries
64
+ F0_DROP_THRESHOLD = 0.20 # 20% drop
65
+
66
+ # Duration threshold (in seconds) - long tones might indicate word endings
67
+ DURATION_THRESHOLD = 0.25
68
+
69
+ # Minimum confidence for space insertion
70
+ SPACE_CONFIDENCE_THRESHOLD = 0.6
71
+
72
+ # ============ VISUALIZATION CONFIGURATION ============
73
+
74
+ # Tone display information for UI
75
+ TONE_INFO = {
76
+ 1: {
77
+ "name": "High Tone",
78
+ "symbol": "◌́",
79
+ "color": "#e74c3c",
80
+ "label": "H"
81
+ },
82
+ 2: {
83
+ "name": "Low Tone",
84
+ "symbol": "◌̀",
85
+ "color": "#3498db",
86
+ "label": "B"
87
+ },
88
+ 3: {
89
+ "name": "Mid Tone",
90
+ "symbol": "◌",
91
+ "color": "#2ecc71",
92
+ "label": "M"
93
+ },
94
+ 4: {
95
+ "name": "Space",
96
+ "symbol": " | ",
97
+ "color": "#95a5a6",
98
+ "label": "SPACE"
99
+ }
100
+ }
101
+
102
+ # ============ DEPLOYMENT CONFIGURATION ============
103
+
104
+ # Device (cpu or cuda)
105
+ DEVICE = "cpu"
106
+
107
+ # Gradio server settings
108
+ GRADIO_SERVER_NAME = "0.0.0.0"
109
+ GRADIO_SERVER_PORT = 7860
110
+ GRADIO_SHARE = False
111
+
112
+ # Model save directory for SpeechBrain
113
+ PRETRAINED_MODEL_DIR = "./pretrained_model"
114
+
115
+ # ============ HELPER FUNCTIONS ============
116
+
117
+ def get_checkpoint_path():
118
+ """Get the checkpoint folder path"""
119
+ return CHECKPOINT_PATH
120
+
121
+ def get_tone_name(idx):
122
+ """Get the tone name from index"""
123
+ return TONE_LABELS.get(idx, f"Unknown({idx})")
124
+
125
+ def get_tone_info(idx):
126
+ """Get the tone display information"""
127
+ return TONE_INFO.get(idx, {
128
+ "name": f"Unknown({idx})",
129
+ "symbol": "?",
130
+ "color": "#95a5a6",
131
+ "label": f"UNK{idx}"
132
+ })
133
+
134
+ def validate_config():
135
+ """Validate that the configuration is correct"""
136
+ errors = []
137
+
138
+ # Check if checkpoint folder exists
139
+ if not os.path.exists(CHECKPOINT_PATH):
140
+ errors.append(f"Checkpoint folder not found: {CHECKPOINT_PATH}")
141
+
142
+ # Check if required model files exist
143
+ if not os.path.exists(MODEL_CKPT):
144
+ errors.append(f"Model checkpoint not found: {MODEL_CKPT}")
145
+ if not os.path.exists(WAV2VEC2_CKPT):
146
+ errors.append(f"Wav2Vec2 checkpoint not found: {WAV2VEC2_CKPT}")
147
+ if not os.path.exists(TOKENIZER_CKPT):
148
+ errors.append(f"Tokenizer checkpoint not found: {TOKENIZER_CKPT}")
149
+
150
+ # Check tone labels match output neurons
151
+ non_blank_labels = [k for k in TONE_LABELS.keys() if k != BLANK_INDEX]
152
+ if len(non_blank_labels) != OUTPUT_NEURONS - 1:
153
+ errors.append(f"Mismatch: {len(non_blank_labels)} tone labels but {OUTPUT_NEURONS-1} expected")
154
+
155
+ if errors:
156
+ print("⚠️ Configuration Errors:")
157
+ for error in errors:
158
+ print(f" - {error}")
159
+ return False
160
+
161
+ print("✅ Configuration validated successfully!")
162
+ return True
163
+
164
+ # Run validation when module is imported
165
+ if __name__ != "__main__":
166
+ # Only show validation messages in development
167
+ pass
custom_interface.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom SpeechBrain inference interface for ProTeVa tone recognition model
3
+ Includes intelligent space/word boundary detection
4
+ """
5
+
6
+ import torch
7
+ import numpy as np
8
+ from speechbrain.inference.interfaces import Pretrained
9
+ import config
10
+
11
+
12
+ class ProTeVaToneRecognizer(Pretrained):
13
+ """
14
+ Custom interface for ProTeVa tone recognition model
15
+ Predicts tone sequences for Yoruba language (3 tones)
16
+ Includes post-processing for space detection
17
+ """
18
+
19
+ HPARAMS_NEEDED = ["wav2vec2", "enc", "dec", "pitch_dec",
20
+ "proto", "output_lin", "log_softmax",
21
+ "label_encoder", "f0Compute", "sample_rate"]
22
+
23
+ MODULES_NEEDED = ["wav2vec2", "enc", "dec", "pitch_dec",
24
+ "proto", "output_lin"]
25
+
26
+ def __init__(self, *args, **kwargs):
27
+ super().__init__(*args, **kwargs)
28
+ self.sample_rate = self.hparams.sample_rate
29
+
30
+ def classify_file(self, path):
31
+ """
32
+ Classify tone sequence from audio file
33
+
34
+ Arguments
35
+ ---------
36
+ path : str
37
+ Path to audio file
38
+
39
+ Returns
40
+ -------
41
+ tone_sequence : list
42
+ Predicted tone labels (integers)
43
+ tone_names : list
44
+ Predicted tone names (strings)
45
+ f0_contour : torch.Tensor
46
+ Reconstructed F0 contour
47
+ """
48
+ waveform = self.load_audio(path)
49
+ wavs = waveform.unsqueeze(0)
50
+ wav_lens = torch.tensor([1.0])
51
+
52
+ tone_sequences, tone_names, f0_contours = self.classify_batch(wavs, wav_lens)
53
+
54
+ return tone_sequences[0], tone_names[0], f0_contours[0]
55
+
56
+ def classify_batch(self, wavs, wav_lens):
57
+ """
58
+ Classify tones from a batch of waveforms
59
+
60
+ Arguments
61
+ ---------
62
+ wavs : torch.Tensor
63
+ Batch of waveforms [batch, time]
64
+ wav_lens : torch.Tensor
65
+ Relative lengths of waveforms
66
+
67
+ Returns
68
+ -------
69
+ tone_sequences : list of lists
70
+ Predicted tone label indices (with spaces if enabled)
71
+ tone_names : list of lists
72
+ Predicted tone names
73
+ f0_contours : torch.Tensor
74
+ Reconstructed F0 contours
75
+ """
76
+ self.eval()
77
+
78
+ with torch.no_grad():
79
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
80
+
81
+ # Extract features from HuBERT
82
+ feats = self.modules.wav2vec2(wavs)
83
+
84
+ # Extract F0
85
+ f0 = self.hparams.f0Compute(wavs, target_size=feats.shape[1])
86
+
87
+ # Encode with BiGRU
88
+ x, hidden = self.modules.enc(feats)
89
+
90
+ # Decode with VanillaNN
91
+ x = self.modules.dec(x)
92
+
93
+ # Pitch decoder - reconstruct F0
94
+ dec_out = self.modules.pitch_dec(x)
95
+
96
+ # Prototype layer - similarity to learned tone prototypes
97
+ proto_out = self.modules.proto(x)
98
+
99
+ # Classification layer
100
+ logits = self.modules.output_lin(proto_out)
101
+ log_probs = self.hparams.log_softmax(logits)
102
+
103
+ # CTC greedy decoding
104
+ tone_sequences = self._ctc_decode(log_probs, wav_lens)
105
+
106
+ # Apply space detection if enabled
107
+ if config.ENABLE_SPACE_DETECTION:
108
+ tone_sequences = self._insert_spaces(
109
+ tone_sequences,
110
+ f0.cpu().numpy(),
111
+ log_probs.cpu().numpy(),
112
+ feats.shape[1]
113
+ )
114
+
115
+ # Convert indices to tone names
116
+ tone_names = []
117
+ for seq in tone_sequences:
118
+ names = [self._get_tone_name(idx) for idx in seq if idx != 0]
119
+ tone_names.append(names)
120
+
121
+ return tone_sequences, tone_names, dec_out
122
+
123
+ def _ctc_decode(self, log_probs, wav_lens):
124
+ """CTC greedy decoding"""
125
+ from speechbrain.decoders import ctc_greedy_decode
126
+
127
+ sequences = ctc_greedy_decode(
128
+ log_probs,
129
+ wav_lens,
130
+ blank_index=config.BLANK_INDEX
131
+ )
132
+
133
+ return sequences
134
+
135
+ def _insert_spaces(self, sequences, f0_contours, log_probs, feat_len):
136
+ """
137
+ Insert space tokens (label 4) between tones based on acoustic features
138
+
139
+ Arguments
140
+ ---------
141
+ sequences : list of lists
142
+ Tone sequences without spaces
143
+ f0_contours : numpy.ndarray
144
+ F0 contours [batch, time, 1]
145
+ log_probs : numpy.ndarray
146
+ Log probabilities from model [batch, time, classes]
147
+ feat_len : int
148
+ Length of feature sequence
149
+
150
+ Returns
151
+ -------
152
+ sequences_with_spaces : list of lists
153
+ Tone sequences with space tokens (4) inserted
154
+ """
155
+ sequences_with_spaces = []
156
+
157
+ for seq_idx, sequence in enumerate(sequences):
158
+ if len(sequence) == 0:
159
+ sequences_with_spaces.append(sequence)
160
+ continue
161
+
162
+ # Get F0 for this sequence
163
+ f0 = f0_contours[seq_idx].flatten()
164
+
165
+ # Detect word boundaries
166
+ new_sequence = []
167
+
168
+ for i, tone in enumerate(sequence):
169
+ new_sequence.append(tone)
170
+
171
+ # Don't add space after last tone
172
+ if i == len(sequence) - 1:
173
+ continue
174
+
175
+ # Calculate space likelihood based on method
176
+ should_insert_space = False
177
+
178
+ if config.SPACE_DETECTION_METHOD == "combined":
179
+ should_insert_space = self._detect_space_combined(
180
+ f0, i, len(sequence), feat_len
181
+ )
182
+ elif config.SPACE_DETECTION_METHOD == "silence":
183
+ should_insert_space = self._detect_space_silence(
184
+ f0, i, len(sequence), feat_len
185
+ )
186
+ elif config.SPACE_DETECTION_METHOD == "f0_drop":
187
+ should_insert_space = self._detect_space_f0_drop(
188
+ f0, i, len(sequence)
189
+ )
190
+ elif config.SPACE_DETECTION_METHOD == "duration":
191
+ should_insert_space = self._detect_space_duration(
192
+ i, len(sequence), feat_len
193
+ )
194
+
195
+ if should_insert_space:
196
+ new_sequence.append(4) # Space token
197
+
198
+ sequences_with_spaces.append(new_sequence)
199
+
200
+ return sequences_with_spaces
201
+
202
+ def _detect_space_silence(self, f0, tone_idx, total_tones, feat_len):
203
+ """Detect space based on silence (low F0) between tones"""
204
+ # Estimate frame positions for current and next tone
205
+ frames_per_tone = feat_len // max(total_tones, 1)
206
+ current_end = min((tone_idx + 1) * frames_per_tone, len(f0) - 1)
207
+ next_start = min((tone_idx + 2) * frames_per_tone, len(f0))
208
+
209
+ if current_end >= next_start or next_start >= len(f0):
210
+ return False
211
+
212
+ # Check gap between tones for silence
213
+ gap_f0 = f0[current_end:next_start]
214
+ silence_ratio = np.sum(gap_f0 < 50) / max(len(gap_f0), 1) # Pitch < 50 Hz is silence
215
+
216
+ return silence_ratio > 0.5
217
+
218
+ def _detect_space_f0_drop(self, f0, tone_idx, total_tones):
219
+ """Detect space based on F0 drop between tones"""
220
+ if tone_idx >= len(f0) - 1:
221
+ return False
222
+
223
+ # Calculate average F0 for current and next tone regions
224
+ window_size = max(len(f0) // (total_tones * 2), 5)
225
+
226
+ current_start = max(0, tone_idx * window_size)
227
+ current_end = min((tone_idx + 1) * window_size, len(f0))
228
+ next_start = current_end
229
+ next_end = min(next_start + window_size, len(f0))
230
+
231
+ if current_start >= current_end or next_start >= next_end:
232
+ return False
233
+
234
+ current_f0 = f0[current_start:current_end]
235
+ next_f0 = f0[next_start:next_end]
236
+
237
+ # Filter out silence
238
+ current_f0 = current_f0[current_f0 > 50]
239
+ next_f0 = next_f0[next_f0 > 50]
240
+
241
+ if len(current_f0) == 0 or len(next_f0) == 0:
242
+ return True # Silence indicates word boundary
243
+
244
+ # Calculate F0 drop
245
+ avg_current = np.mean(current_f0)
246
+ avg_next = np.mean(next_f0)
247
+ f0_drop = (avg_current - avg_next) / avg_current if avg_current > 0 else 0
248
+
249
+ return f0_drop > config.F0_DROP_THRESHOLD
250
+
251
+ def _detect_space_duration(self, tone_idx, total_tones, feat_len):
252
+ """Detect space based on regular intervals (simple heuristic)"""
253
+ # Every 3-5 tones, insert a space (simple word-length heuristic)
254
+ return (tone_idx + 1) % 4 == 0
255
+
256
+ def _detect_space_combined(self, f0, tone_idx, total_tones, feat_len):
257
+ """Combine multiple space detection methods"""
258
+ silence_vote = self._detect_space_silence(f0, tone_idx, total_tones, feat_len)
259
+ f0_drop_vote = self._detect_space_f0_drop(f0, tone_idx, total_tones)
260
+
261
+ # If both methods agree, high confidence
262
+ if silence_vote and f0_drop_vote:
263
+ return True
264
+
265
+ # If at least one method detects space and we're at a reasonable position
266
+ if (silence_vote or f0_drop_vote) and (tone_idx + 1) % 2 == 0:
267
+ return True
268
+
269
+ return False
270
+
271
+ def _get_tone_name(self, idx):
272
+ """
273
+ Convert tone index to name
274
+
275
+ Based on labelencoder.txt + space detection:
276
+ - 0: Blank (CTC)
277
+ - 1: High tone (H)
278
+ - 2: Low tone (B - Bas)
279
+ - 3: Mid tone (M)
280
+ - 4: Space (detected post-processing)
281
+ """
282
+ tone_map = {
283
+ 0: "BLANK",
284
+ 1: "High",
285
+ 2: "Low",
286
+ 3: "Mid",
287
+ 4: "Space"
288
+ }
289
+ return tone_map.get(idx, f"Unknown({idx})")
290
+
291
+ def forward(self, wavs, wav_lens):
292
+ """Forward pass for the model"""
293
+ return self.classify_batch(wavs, wav_lens)
examples/yof_00295_00024634140.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:76dd4293dd93bdffbd4065bbab97a5949033947129e03b2b80977daac51ee6c1
3
+ size 92888
examples/yof_00295_00151151204.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:69065a4d0403b0725912e45e46cb8296e40a3adbf4d9916752579f75de518a8c
3
+ size 158424
examples/yof_00295_00427144639.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e5c7c23374afadad02651c48526b0b517798a9a8274d9d33d0fad223a939a472
3
+ size 155692
examples/yof_00295_00564596981.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b858793b98658c76b77e2dc5f2a4414cf8847d88528f54b7b24bfd05f9e4ab94
3
+ size 112002
examples/yof_00295_00654803226.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f3b9acb687260e35431f214c52483341a06c8c2ddb8e0cb22ece5f5e36d58292
3
+ size 117464
examples/yof_00295_01329504028.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:24c8b5283ebf63c7a1ca20ac9df81bee683d0662defc736faf165765c556f640
3
+ size 106540
examples/yof_00295_01428115987.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e8d05f45d4532e4212e1b916ab72ed45b8e97ea614ca77f4f655dfefd6f7840
3
+ size 139308
examples/yom_08784_01544027142.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:423c8d6d006a7383715da9e85877db63ccf9ec799b8caacd780c9354018ef710
3
+ size 166616
examples/yom_08784_01571599993.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:354dd23cd2765334e271583931da6bc4de2196eebc06fb7ea9e4100c70b8a5d2
3
+ size 120194
examples/yom_08784_01716814128.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ddf9ec030fe7469d6345f4a2dbd5e12fad902b1d200b3ec6b3c8bf9eb8f83e4a
3
+ size 109272
examples/yom_08784_01792196659.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6ff02e184058c36eefae6f43e33e86c86597ed3c86ad7d8bf651ec639d8014e
3
+ size 90156
examples/yom_08784_01855888561.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5a19cab7263e2b07c8c6223e37c96bf7dd2c34916f70edbf3d73255ad3b9a7d
3
+ size 150232
examples/yom_09334_00045442417.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7ac0572cec403c1da263e567621932a9e1d51a019749a3c10a63de36798bf0c6
3
+ size 139308
examples/yom_09334_00091591408.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20291615ab693caacdf8ffef05951904b366bacf482302d32b6d7a55d46453ae
3
+ size 98348
examples/yom_09334_00167629780.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42c068263939cb7c08021553fd10f479120599a3f511f9c06273323b75a517de
3
+ size 128386
inference.yaml ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ################################
2
+ # ProTeVa Inference Configuration
3
+ # Simplified YAML for deployment
4
+ # ################################
5
+
6
+ # Basic settings
7
+ seed: 200
8
+ device: cpu # Change to cuda if GPU available
9
+ sample_rate: 16000
10
+
11
+ # Output neurons (4 classes: blank, high, low, mid)
12
+ # Based on labelencoder.txt: 0=blank, 1=H, 2=B, 3=M
13
+ # Space (4) is added via post-processing
14
+ output_neurons: 4
15
+ blank_index: 0
16
+
17
+ # Number of prototypes
18
+ n_prototypes: 10
19
+
20
+ # Feature dimension from HuBERT
21
+ emb_dim: 768
22
+
23
+ # Encoder settings
24
+ rnn_layers: 2
25
+ rnn_neurons: 512
26
+
27
+ # Decoder settings
28
+ dnn_blocks: 2
29
+ dnn_neurons: 512
30
+
31
+ # Pitch decoder settings
32
+ dec_dnn_blocks: [1]
33
+ dec_dnn_neurons: [128]
34
+
35
+ # Activation function
36
+ activation: !name:torch.nn.LeakyReLU
37
+
38
+ # ============ MODULES ============
39
+
40
+ # HuBERT feature extractor
41
+ wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT
42
+ source: "Orange/SSA-HuBERT-base-60k"
43
+ output_norm: True
44
+ freeze: False
45
+ save_path: whubert_checkpoint
46
+
47
+ # F0 extractor (requires custom module)
48
+ f0Compute: !new:modules.F0Extractor
49
+ device: !ref <device>
50
+ sample_rate: !ref <sample_rate>
51
+
52
+ # BiGRU Encoder
53
+ enc: !new:speechbrain.nnet.RNN.GRU
54
+ input_shape: [null, null, !ref <emb_dim>]
55
+ hidden_size: !ref <rnn_neurons>
56
+ num_layers: !ref <rnn_layers>
57
+ bidirectional: True
58
+ dropout: 0.15
59
+
60
+ # VanillaNN Decoder
61
+ dec: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
62
+ input_shape: [null, null, 1024] # 512 * 2 (bidirectional)
63
+ activation: !ref <activation>
64
+ dnn_blocks: !ref <dnn_blocks>
65
+ dnn_neurons: !ref <dnn_neurons>
66
+
67
+ # Pitch Decoder (requires custom module)
68
+ pitch_dec: !new:modules.PitchDecoderLayer
69
+ input_shape: [null, null, !ref <dnn_neurons>]
70
+ dnn_blocks: !ref <dec_dnn_blocks>
71
+ dnn_neurons: !ref <dec_dnn_neurons>
72
+
73
+ # Prototype Layer (requires custom module)
74
+ proto: !new:modules.PrototypeLayer
75
+ n_prototypes: !ref <n_prototypes>
76
+ latent_dims: !ref <dnn_neurons>
77
+
78
+ # Output linear layer
79
+ output_lin: !new:speechbrain.nnet.linear.Linear
80
+ input_size: !ref <n_prototypes>
81
+ n_neurons: !ref <output_neurons>
82
+ bias: True
83
+
84
+ # Log softmax
85
+ log_softmax: !new:speechbrain.nnet.activations.Softmax
86
+ apply_log: True
87
+
88
+ # Label encoder
89
+ label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder
90
+
91
+ # ============ MODULES DICT ============
92
+
93
+ modules:
94
+ wav2vec2: !ref <wav2vec2>
95
+ enc: !ref <enc>
96
+ dec: !ref <dec>
97
+ pitch_dec: !ref <pitch_dec>
98
+ proto: !ref <proto>
99
+ output_lin: !ref <output_lin>
100
+
101
+ # Model container for all modules
102
+ model: !new:torch.nn.ModuleList
103
+ - [!ref <enc>, !ref <dec>, !ref <proto>, !ref <output_lin>, !ref <pitch_dec>]
104
+
105
+ # ============ PRETRAINER ============
106
+ # This loads the trained checkpoints
107
+
108
+ pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
109
+ loadables:
110
+ model: !ref <model>
111
+ wav2vec2: !ref <wav2vec2>
112
+ tokenizer: !ref <label_encoder>
113
+ paths:
114
+ model: !ref <save_folder>/model.ckpt
115
+ wav2vec2: !ref <save_folder>/wav2vec2.ckpt
116
+ tokenizer: !ref <save_folder>/tokenizer.ckpt
117
+
118
+ # Save folder - Path is loaded from config.py
119
+ # To change checkpoint folder, update CHECKPOINT_FOLDER in config.py
120
+ save_folder: ./CKPT+2025-10-20+08-19-07+00
labelencoder.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ 'M' => 3
2
+ 'H' => 1
3
+ 'B' => 2
4
+ '<blank>' => 0
5
+ ================
6
+ 'starting_index' => 0
7
+ 'blank_label' => '<blank>'
modules.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom modules for ProTeVa tone recognition model
3
+
4
+ Authors
5
+ * St Germes BENGONO OBIANG 2024
6
+ """
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torchyin
13
+ from scipy.interpolate import interp1d
14
+ from speechbrain.lobes.models.VanillaNN import VanillaNN
15
+ from torch.nn import LeakyReLU, ReLU
16
+ from speechbrain.nnet.containers import ModuleList
17
+
18
+
19
+ class F0Extractor(torch.nn.Module):
20
+ """This module extracts F0 of sound and returns it as embedding vector
21
+
22
+ Arguments
23
+ ---------
24
+ device : str
25
+ Device to run computations on ('cpu' or 'cuda')
26
+ sample_rate : int
27
+ The signal sample rate (default: 16000)
28
+ frame_stride : float
29
+ Length of the sliding window used for F0 extraction (default: 0.018)
30
+ pitch_min : float
31
+ The minimum value of pitch (default: 50)
32
+ pitch_max : float
33
+ The maximum value of pitch (default: 500)
34
+
35
+ Example
36
+ -------
37
+ >>> compute_f0 = F0Extractor(sample_rate=16000)
38
+ >>> input_feats = torch.rand([1, 23000])
39
+ >>> outputs = compute_f0(input_feats, target_size=220)
40
+ >>> outputs.shape
41
+ torch.Size([1, 220, 1])
42
+
43
+ Authors
44
+ -------
45
+ * St Germes BENGONO OBIANG 2024
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ device="cpu",
51
+ sample_rate=16000,
52
+ frame_stride=0.018,
53
+ pitch_min=50,
54
+ pitch_max=500,
55
+ ):
56
+ super().__init__()
57
+ self.device = device
58
+ self.sample_rate = sample_rate
59
+ self.pitch_min = pitch_min
60
+ self.pitch_max = pitch_max
61
+ self.frame_stride = frame_stride
62
+
63
+ def interpolate_spline(self, H, N):
64
+ """Interpolate pitch values to target size using cubic spline interpolation"""
65
+ # Generate indices for the original and new tensors
66
+ idx_original = np.arange(len(H))
67
+ idx_new = np.linspace(0, len(H) - 1, N)
68
+
69
+ # Create the interpolation function
70
+ interpolator = interp1d(idx_original, H, kind='cubic')
71
+
72
+ # Perform interpolation
73
+ H_interpolated = interpolator(idx_new)
74
+
75
+ # Create a mask for values below minimum pitch
76
+ mask = H_interpolated < self.pitch_min
77
+ H_interpolated[mask] = 0
78
+
79
+ return torch.as_tensor(H_interpolated.tolist())
80
+
81
+ def forward(self, wavs, target_size):
82
+ """Extract F0 from waveforms and interpolate to target size"""
83
+ results = []
84
+ for wav in wavs:
85
+ pitch = torchyin.estimate(
86
+ wav,
87
+ self.sample_rate,
88
+ pitch_min=self.pitch_min,
89
+ pitch_max=self.pitch_max,
90
+ frame_stride=self.frame_stride
91
+ )
92
+
93
+ # Interpolate the pitch
94
+ pitch = self.interpolate_spline(pitch.cpu().numpy(), target_size)
95
+
96
+ # Reshape the pitch output
97
+ pitch = pitch.view(pitch.shape[0], 1)
98
+ results.append(pitch.tolist())
99
+
100
+ return torch.as_tensor(results).to(self.device)
101
+
102
+
103
+ class PitchDecoderLayer(torch.nn.Module):
104
+ """Layer for decoding latent vector to pitch
105
+
106
+ This decoder reconstructs F0 contours from encoded representations
107
+ using stacked VanillaNN layers.
108
+
109
+ Arguments
110
+ ---------
111
+ input_shape : list
112
+ Shape of input tensor [None, None, feature_dim]
113
+ dnn_blocks : list
114
+ Number of blocks for each DNN layer
115
+ dnn_neurons : list
116
+ Number of neurons for each DNN layer
117
+
118
+ Authors
119
+ -------
120
+ * St Germes BENGONO OBIANG 2024
121
+ """
122
+
123
+ def __init__(
124
+ self,
125
+ input_shape=[None, None, 256],
126
+ dnn_blocks=[2, 2],
127
+ dnn_neurons=[256, 128],
128
+ ):
129
+ super().__init__()
130
+ if len(dnn_blocks) != len(dnn_neurons):
131
+ raise ValueError(
132
+ f"dnn_blocks and dnn_neurons should have the same size but we received {len(dnn_blocks)} and {len(dnn_neurons)}"
133
+ )
134
+
135
+ layers = []
136
+ for index in range(len(dnn_neurons)):
137
+ if index == 0:
138
+ layers.append(
139
+ VanillaNN(
140
+ activation=LeakyReLU,
141
+ dnn_blocks=dnn_blocks[index],
142
+ dnn_neurons=dnn_neurons[index],
143
+ input_shape=input_shape
144
+ )
145
+ )
146
+ else:
147
+ # The input shape is equal to the output of the previous layer
148
+ layers.append(
149
+ VanillaNN(
150
+ activation=LeakyReLU,
151
+ dnn_blocks=dnn_blocks[index],
152
+ dnn_neurons=dnn_neurons[index],
153
+ input_shape=[None, None, dnn_neurons[index - 1]]
154
+ )
155
+ )
156
+
157
+ # Add the last required layer. The input shape is equal to the last DNN block output
158
+ layers.append(
159
+ VanillaNN(
160
+ activation=ReLU,
161
+ dnn_blocks=1,
162
+ dnn_neurons=1,
163
+ input_shape=[None, None, dnn_neurons[len(dnn_neurons) - 1]]
164
+ )
165
+ )
166
+
167
+ self.decoder = ModuleList(*layers)
168
+
169
+ def forward(self, latent_vector):
170
+ """Decode latent vector to F0 prediction"""
171
+ return self.decoder(latent_vector)
172
+
173
+
174
+ # ============ HELPER FUNCTIONS FOR PROTOTYPE LAYER ============
175
+
176
+ def distance_to_prototype(latent_vector, prototypes):
177
+ """
178
+ Compute the L2 squared distance between each timestamp in the latent_vector and each prototype.
179
+
180
+ Args:
181
+ latent_vector (torch.Tensor): Tensor of shape [batch, timesteps, features].
182
+ prototypes (torch.Tensor): Tensor of shape [n_prototypes, features].
183
+
184
+ Returns:
185
+ torch.Tensor: Tensor of shape [batch, timesteps, n_prototypes] with L2 squared distances.
186
+ """
187
+ # Expand the dimensions of prototypes to match the shape for broadcasting
188
+ prototypes = prototypes.unsqueeze(0).unsqueeze(0) # Shape: [1, 1, n_prototypes, features]
189
+
190
+ # Expand latent_vector to match the shape for broadcasting
191
+ latent_vector = latent_vector.unsqueeze(2) # Shape: [batch, timesteps, 1, features]
192
+
193
+ # Compute the L2 squared distance
194
+ distance = torch.sum((latent_vector - prototypes) ** 2, dim=-1) # Shape: [batch, timesteps, n_prototypes]
195
+
196
+ return distance
197
+
198
+
199
+ def cosine_similarity_to_prototype(latent_vector, prototypes):
200
+ """
201
+ Compute the cosine similarity between each timestamp in the latent_vector and each prototype.
202
+
203
+ Args:
204
+ latent_vector (torch.Tensor): Tensor of shape [batch, timesteps, features].
205
+ prototypes (torch.Tensor): Tensor of shape [n_prototypes, features].
206
+
207
+ Returns:
208
+ torch.Tensor: Tensor of shape [batch, timesteps, n_prototypes] with cosine similarities.
209
+ """
210
+ # Normalize the latent vector and prototypes
211
+ latent_vector_norm = F.normalize(latent_vector, p=2, dim=-1) # Shape: [batch, timesteps, features]
212
+ prototypes_norm = F.normalize(prototypes, p=2, dim=-1) # Shape: [n_prototypes, features]
213
+
214
+ # Expand dimensions to match for broadcasting
215
+ prototypes_norm = prototypes_norm.unsqueeze(0).unsqueeze(0) # Shape: [1, 1, n_prototypes, features]
216
+ latent_vector_norm = latent_vector_norm.unsqueeze(2) # Shape: [batch, timesteps, 1, features]
217
+
218
+ # Compute the cosine similarity
219
+ similarity = torch.sum(latent_vector_norm * prototypes_norm, dim=-1) # Shape: [batch, timesteps, n_prototypes]
220
+
221
+ return similarity
222
+
223
+
224
+ def distances_to_feature(input_tensor, prototypes):
225
+ """
226
+ Compute the L2 squared distance between each prototype and each timestamp in the input_tensor.
227
+
228
+ Args:
229
+ input_tensor (torch.Tensor): Tensor of shape [batch_size, num_timestep, feature_dim].
230
+ prototypes (torch.Tensor): Tensor of shape [num_prototypes, feature_dim].
231
+
232
+ Returns:
233
+ torch.Tensor: Tensor of shape [num_prototypes, batch_size, num_timestep] with L2 squared distances.
234
+ """
235
+ # Expand the dimensions of prototypes to match the shape for broadcasting
236
+ prototypes = prototypes.unsqueeze(1).unsqueeze(2) # Shape: [num_prototypes, 1, 1, feature_dim]
237
+
238
+ # Expand input_tensor to match the shape for broadcasting
239
+ input_tensor = input_tensor.unsqueeze(0) # Shape: [1, batch_size, num_timestep, feature_dim]
240
+
241
+ # Compute the L2 squared distance
242
+ distance = torch.sum((input_tensor - prototypes) ** 2, dim=-1) # Shape: [num_prototypes, batch_size, num_timestep]
243
+
244
+ return distance
245
+
246
+
247
+ def compute_prototype_distances(prototypes):
248
+ """
249
+ Compute the L2 squared distance between each pair of prototypes.
250
+
251
+ Args:
252
+ prototypes (torch.Tensor): Tensor of shape [n_prototypes, features].
253
+
254
+ Returns:
255
+ torch.Tensor: Tensor of shape [n_prototypes, n_prototypes] with L2 squared distances between prototypes.
256
+ """
257
+ # Calculate the squared norms of the prototypes
258
+ squared_norms = torch.sum(prototypes ** 2, dim=1, keepdim=True) # Shape: [n_prototypes, 1]
259
+
260
+ # Calculate the pairwise distance using the formula: (a-b)^2 = a^2 + b^2 - 2ab
261
+ distances = squared_norms + squared_norms.T - 2 * torch.mm(prototypes, prototypes.T) # Shape: [n_prototypes, n_prototypes]
262
+ distances = distances.fill_diagonal_(1e+6)
263
+
264
+ return distances
265
+
266
+
267
+ class PrototypeLayer(torch.nn.Module):
268
+ """
269
+ Prototype Layer for tone representation learning
270
+
271
+ Learns M prototypes that represent canonical tone patterns.
272
+ Computes similarity between input features and prototypes.
273
+ Includes regularization losses R_1, R_2, and R_3.
274
+
275
+ Arguments
276
+ ---------
277
+ n_prototypes : int
278
+ Number of learnable prototypes (default: 9)
279
+ latent_dims : int
280
+ Dimension of latent space (default: 256)
281
+
282
+ Authors
283
+ -------
284
+ * St Germes BENGONO OBIANG 2024
285
+ """
286
+
287
+ def __init__(
288
+ self,
289
+ n_prototypes=9,
290
+ latent_dims=256,
291
+ ):
292
+ super().__init__()
293
+ self.n_prototypes = n_prototypes
294
+ self.latent_dims = latent_dims
295
+
296
+ # Initialize prototypes with Kaiming uniform initialization
297
+ self.prototypes = torch.nn.Parameter(
298
+ torch.nn.init.kaiming_uniform_(
299
+ torch.empty([n_prototypes, latent_dims]),
300
+ nonlinearity='relu'
301
+ ),
302
+ requires_grad=True
303
+ )
304
+
305
+ # Regularization losses
306
+ self.R_1 = 0 # Feature distances regulation
307
+ self.R_2 = 0 # Prototypes distances regulation
308
+ self.R_3 = 0 # Prototypes to prototypes distances
309
+
310
+ def setProto(self, proto):
311
+ """Set prototype values (for initialization or transfer learning)"""
312
+ self.prototypes = torch.nn.Parameter(proto, requires_grad=True)
313
+
314
+ def forward(self, latent_vector):
315
+ """
316
+ Compute similarity between input and prototypes
317
+
318
+ Args:
319
+ latent_vector (torch.Tensor): Input features [batch, time, latent_dims]
320
+
321
+ Returns:
322
+ torch.Tensor: Prototype similarities [batch, time, n_prototypes]
323
+ """
324
+ # Compute distances and similarities
325
+ dist2proto = distance_to_prototype(latent_vector, self.prototypes)
326
+ similarity2Proto = cosine_similarity_to_prototype(latent_vector, self.prototypes)
327
+ dist2Feature = distances_to_feature(latent_vector, self.prototypes)
328
+ protoDistance = compute_prototype_distances(self.prototypes)
329
+
330
+ if self.training:
331
+ # R_1: Each prototype is near to at least one data in latent space
332
+ self.R_1 = torch.mean(torch.min(dist2Feature, dim=2).values)
333
+
334
+ # R_2: Each data in latent space is near to at least one prototype
335
+ self.R_2 = torch.mean(torch.min(dist2proto, dim=2).values)
336
+
337
+ # R_3: Prototype is as far as possible to other prototypes
338
+ self.R_3 = 1 / (torch.mean(torch.min(protoDistance, dim=1).values) + 1e-8)
339
+
340
+ return similarity2Proto
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core dependencies
2
+ # Install torch and torchaudio first to match training environment versions
3
+ torch==2.8.0
4
+ torchaudio==2.8.0
5
+
6
+ # SpeechBrain includes: numpy, scipy, sentencepiece, hyperpyyaml, transformers, huggingface_hub
7
+ speechbrain==1.0.0
8
+
9
+ # F0 extraction with TorchYIN (note: package name is torch-yin, not torchyin)
10
+ torch-yin==0.1.3
11
+
12
+ # Gradio for UI
13
+ gradio>=4.0.0
14
+
15
+ # Audio processing (not included in speechbrain)
16
+ librosa
17
+ soundfile
18
+
19
+ # Visualization (not included in speechbrain)
20
+ matplotlib