Thanh-Lam committed
Commit c3418e9 (0 parents)

Vietnamese Speaker Profiling with wav2vec2-base-vi-vlsp2020
.gitattributes ADDED
@@ -0,0 +1,2 @@
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Vietnamese Speaker Profiling
+ emoji: 📈
+ colorFrom: indigo
+ colorTo: yellow
+ sdk: gradio
+ sdk_version: 6.0.2
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,316 @@
1
+ """
2
+ Gradio Web Interface for Speaker Profiling
3
+
4
+ Usage:
5
+ python app.py
6
+ python app.py --config configs/infer.yaml --share
7
+ """
8
+
9
+ import os
10
+ import argparse
11
+ import tempfile
12
+ import time
13
+ import numpy as np
14
+ import torch
15
+ import librosa
16
+ import gradio as gr
17
+ from pathlib import Path
18
+
19
+ from src.models import MultiTaskSpeakerModel
20
+ from src.utils import (
21
+ setup_logging,
22
+ get_logger,
23
+ load_config,
24
+ get_device,
25
+ load_model_checkpoint,
26
+ preprocess_audio
27
+ )
28
+
29
+
30
+ class SpeakerProfilerApp:
31
+ """Gradio application for speaker profiling"""
32
+
33
+ def __init__(self, config_path: str):
34
+ self.logger = setup_logging(name="gradio_app")
35
+ self.config = load_config(config_path)
36
+ self.device = get_device(self.config['inference']['device'])
37
+
38
+ self.sampling_rate = self.config['audio']['sampling_rate']
39
+ self.max_duration = self.config['audio']['max_duration']
40
+
41
+ self.gender_labels = self.config['labels']['gender']
42
+ self.dialect_labels = self.config['labels']['dialect']
43
+
44
+ self._load_model()
45
+
46
+ def _load_model(self):
47
+ """Load model and feature extractor"""
48
+ from transformers import Wav2Vec2FeatureExtractor, WhisperFeatureExtractor
49
+
50
+ self.logger.info("Loading model...")
51
+
52
+ model_name = self.config['model']['name']
53
+ is_ecapa = 'ecapa' in model_name.lower() or 'speechbrain' in model_name.lower()
54
+
55
+ # Check if this is a Whisper/PhoWhisper model
56
+ self.is_whisper = 'whisper' in model_name.lower() or 'phowhisper' in model_name.lower()
57
+
58
+ if is_ecapa:
59
+ # ECAPA-TDNN: use Wav2Vec2 feature extractor for audio normalization
60
+ self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
61
+ "facebook/wav2vec2-base"
62
+ )
63
+ elif self.is_whisper:
64
+ # Whisper/PhoWhisper: use WhisperFeatureExtractor
65
+ self.feature_extractor = WhisperFeatureExtractor.from_pretrained(
66
+ model_name
67
+ )
68
+ else:
69
+ self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
70
+ self.config['model']['checkpoint']
71
+ )
72
+
73
+ self.model = MultiTaskSpeakerModel(model_name)
74
+ self.model = load_model_checkpoint(
75
+ self.model,
76
+ self.config['model']['checkpoint'],
77
+ str(self.device)
78
+ )
79
+
80
+ self.model.to(self.device)
81
+ self.model.eval()
82
+
83
+ self.logger.info(f"Model loaded on {self.device}")
84
+
85
+ def predict(self, audio_input):
86
+ """
87
+ Predict gender and dialect from audio
88
+
89
+ Args:
90
+ audio_input: Tuple of (sample_rate, audio_array) from Gradio
91
+
92
+ Returns:
93
+ Tuple of (gender_result, dialect_result, details)
94
+ """
95
+ if audio_input is None:
96
+ return "No audio", "No audio", "Please upload or record audio"
97
+
98
+ try:
99
+ sr, audio = audio_input
100
+
101
+ if len(audio.shape) > 1:
102
+ audio = audio.mean(axis=1)
103
+
104
+ audio = audio.astype(np.float32)
105
+ if audio.max() > 1.0:
106
+ audio = audio / 32768.0
107
+
108
+ if sr != self.sampling_rate:
109
+ audio = librosa.resample(audio, orig_sr=sr, target_sr=self.sampling_rate)
110
+
111
+ # Calculate original audio duration BEFORE preprocessing
112
+ audio_duration = len(audio) / self.sampling_rate
113
+
114
+ # Whisper requires 30 seconds of audio
115
+ if self.is_whisper:
116
+ max_duration = 30
117
+ else:
118
+ max_duration = self.max_duration
119
+
120
+ audio = preprocess_audio(
121
+ audio,
122
+ sampling_rate=self.sampling_rate,
123
+ max_duration=max_duration
124
+ )
125
+
126
+ # Whisper needs exactly 30 seconds - pad if necessary
127
+ if self.is_whisper:
128
+ target_len = self.sampling_rate * 30
129
+ if len(audio) < target_len:
130
+ audio = np.pad(audio, (0, target_len - len(audio)))
131
+
132
+ inputs = self.feature_extractor(
133
+ audio,
134
+ sampling_rate=self.sampling_rate,
135
+ return_tensors="pt",
136
+ padding=True
137
+ )
138
+
139
+ # Whisper uses 'input_features', WavLM/HuBERT/Wav2Vec2 use 'input_values'
140
+ if self.is_whisper:
141
+ input_values = inputs.input_features.to(self.device)
142
+ else:
143
+ input_values = inputs.input_values.to(self.device)
144
+
145
+ # Measure inference time
146
+ start_time = time.perf_counter()
147
+
148
+ with torch.no_grad():
149
+ outputs = self.model(input_values)
150
+ gender_logits = outputs['gender_logits']
151
+ dialect_logits = outputs['dialect_logits']
152
+
153
+ # Calculate inference time
154
+ infer_time = (time.perf_counter() - start_time) * 1000 # Convert to ms
155
+
156
+ gender_probs = torch.softmax(gender_logits, dim=-1).cpu().numpy()[0]
157
+ dialect_probs = torch.softmax(dialect_logits, dim=-1).cpu().numpy()[0]
158
+
159
+ gender_pred = int(np.argmax(gender_probs))
160
+ dialect_pred = int(np.argmax(dialect_probs))
161
+
162
+ gender_name = self.gender_labels[gender_pred]
163
+ dialect_name = self.dialect_labels[dialect_pred]
164
+
165
+ gender_conf = gender_probs[gender_pred] * 100
166
+ dialect_conf = dialect_probs[dialect_pred] * 100
167
+
168
+ gender_result = f"{gender_name} ({gender_conf:.1f}%)"
169
+ dialect_result = f"{dialect_name} ({dialect_conf:.1f}%)"
170
+
171
+ details = self._format_details(gender_probs, dialect_probs, infer_time, audio_duration)
172
+
173
+ self.logger.info(f"Prediction: Gender={gender_name}, Dialect={dialect_name} | Inference time: {infer_time:.2f}ms | Audio: {audio_duration:.2f}s")
174
+
175
+ return gender_result, dialect_result, details
176
+
177
+ except Exception as e:
178
+ self.logger.error(f"Prediction error: {e}")
179
+ return "Error", "Error", f"Error: {str(e)}"
180
+
181
+ def _format_details(self, gender_probs: np.ndarray, dialect_probs: np.ndarray, infer_time: float = None, audio_duration: float = None) -> str:
182
+ """Format detailed prediction results"""
183
+ # Gender label names
184
+ gender_names = ['Female', 'Male']
185
+ # Dialect label names
186
+ dialect_names = ['North', 'Central', 'South']
187
+
188
+ lines = []
189
+ lines.append("Gender Probabilities:")
190
+ for i, name in enumerate(gender_names):
191
+ lines.append(f" {name}: {gender_probs[i]*100:.2f}%")
192
+
193
+ lines.append("")
194
+ lines.append("Dialect Probabilities:")
195
+ for i, name in enumerate(dialect_names):
196
+ lines.append(f" {name}: {dialect_probs[i]*100:.2f}%")
197
+
198
+ lines.append("")
199
+ lines.append("─" * 30)
200
+
201
+ if audio_duration is not None:
202
+ lines.append(f"Audio Duration: {audio_duration:.2f} s")
203
+
204
+ if infer_time is not None:
205
+ lines.append(f"Inference Time: {infer_time:.2f} ms")
206
+
207
+ return "\n".join(lines)
208
+
209
+ def create_interface(self) -> gr.Blocks:
210
+ """Create Gradio interface"""
211
+
212
+ # Gradio < 4.0 doesn't support theme in Blocks
213
+ with gr.Blocks(title="Vietnamese Speaker Profiling") as demo:
214
+
215
+ gr.Markdown(
216
+ """
217
+ # Vietnamese Speaker Profiling
218
+
219
+ Identify gender and dialect from Vietnamese speech audio.
220
+
221
+ **Model:** Encoder + Attentive Pooling + LayerNorm + MultiHead Classifier
222
+
223
+ **Supported dialects:** North, Central, South
224
+ """
225
+ )
226
+
227
+ with gr.Row():
228
+ with gr.Column(scale=1):
229
+ audio_input = gr.Audio(
230
+ label="Input Audio",
231
+ type="numpy",
232
+ sources=["upload", "microphone"]
233
+ )
234
+
235
+ submit_btn = gr.Button("Analyze", variant="primary")
236
+ clear_btn = gr.Button("Clear")
237
+
238
+ with gr.Column(scale=1):
239
+ gender_output = gr.Textbox(
240
+ label="Gender",
241
+ interactive=False
242
+ )
243
+ dialect_output = gr.Textbox(
244
+ label="Dialect",
245
+ interactive=False
246
+ )
247
+ details_output = gr.Textbox(
248
+ label="Details",
249
+ lines=8,
250
+ interactive=False
251
+ )
252
+
253
+ gr.Markdown(
254
+ """
255
+ ---
256
+ **Notes:**
257
+ - Supported formats: WAV, MP3
258
+ - Recommended duration: 3-10 seconds
259
+ """
260
+ )
261
+
262
+ submit_btn.click(
263
+ fn=self.predict,
264
+ inputs=[audio_input],
265
+ outputs=[gender_output, dialect_output, details_output]
266
+ )
267
+
268
+ clear_btn.click(
269
+ fn=lambda: (None, "", "", ""),
270
+ inputs=[],
271
+ outputs=[audio_input, gender_output, dialect_output, details_output]
272
+ )
273
+
274
+ return demo
275
+
276
+
277
+ def main():
278
+ """Main function"""
279
+ parser = argparse.ArgumentParser(description="Speaker Profiling Web Interface")
280
+ parser.add_argument(
281
+ "--config",
282
+ type=str,
283
+ default="configs/infer.yaml",
284
+ help="Path to config file"
285
+ )
286
+ parser.add_argument(
287
+ "--share",
288
+ action="store_true",
289
+ help="Create public link"
290
+ )
291
+ parser.add_argument(
292
+ "--port",
293
+ type=int,
294
+ default=7860,
295
+ help="Port number (default: 7860)"
296
+ )
297
+ parser.add_argument(
298
+ "--server_name",
299
+ type=str,
300
+ default="0.0.0.0",
301
+ help="Server name (default: 0.0.0.0)"
302
+ )
303
+ args = parser.parse_args()
304
+
305
+ app = SpeakerProfilerApp(args.config)
306
+ demo = app.create_interface()
307
+
308
+ demo.launch(
309
+ server_name=args.server_name,
310
+ server_port=args.port,
311
+ share=args.share
312
+ )
313
+
314
+
315
+ if __name__ == "__main__":
316
+ main()
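For readers wiring this Space into their own code, the predictor above can also be driven without the Gradio UI. A minimal sketch, assuming this commit's repository layout, configs/infer.yaml, and the bundled checkpoint under model/vulehuubinh:

```python
# Sketch: call the predictor from app.py directly, bypassing the Gradio interface.
# The dummy waveform below is only a stand-in for a real recording.
import numpy as np
from app import SpeakerProfilerApp

profiler = SpeakerProfilerApp("configs/infer.yaml")

# predict() expects the Gradio audio tuple: (sample_rate, waveform as np.ndarray)
sr = 16000
waveform = np.zeros(sr * 3, dtype=np.float32)  # 3-second placeholder clip
gender, dialect, details = profiler.predict((sr, waveform))
print(gender, dialect)
print(details)
```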
configs/eval.yaml ADDED
@@ -0,0 +1,60 @@
1
+ # Evaluation Configuration
2
+ # Architecture: Encoder + Attentive Pooling + LayerNorm
3
+
4
+ # Model
5
+ model:
6
+ checkpoint: "output/speaker-profiling/best_model"
7
+ name: "microsoft/wavlm-base-plus"
8
+ head_hidden_dim: 256
9
+
10
+ # Audio Processing
11
+ audio:
12
+ sampling_rate: 16000
13
+ max_duration: 5
14
+
15
+ # Evaluation
16
+ evaluation:
17
+ batch_size: 32
18
+ dataloader_num_workers: 2
19
+
20
+ # Data Paths (relative to repo root)
21
+ data:
22
+ # === ViSpeech (CSV format) ===
23
+ clean_test_meta: "/home/ubuntu/DataScience/Voice_Pro_filling/vispeech_data/ViSpeech/metadata/clean_testset.csv"
24
+ clean_test_audio: "/home/ubuntu/DataScience/Voice_Pro_filling/vispeech_data/ViSpeech/clean_testset"
25
+ noisy_test_meta: "/home/ubuntu/DataScience/Voice_Pro_filling/vispeech_data/ViSpeech/metadata/noisy_testset.csv"
26
+ noisy_test_audio: "/home/ubuntu/DataScience/Voice_Pro_filling/vispeech_data/ViSpeech/noisy_testset"
27
+
28
+ # === ViMD (HuggingFace format) ===
29
+ vimd_path: "/kaggle/input/vimd-dataset"
30
+
31
+ # Output
32
+ output:
33
+ dir: "output/evaluation"
34
+ save_predictions: true
35
+ save_confusion_matrix: true
36
+
37
+ # Label Mappings
38
+ labels:
39
+ gender:
40
+ Male: 0
41
+ Female: 1
42
+ 0: 0
43
+ 1: 1
44
+ dialect:
45
+ North: 0
46
+ Central: 1
47
+ South: 2
48
+ region_to_dialect:
49
+ North: 0
50
+ Central: 1
51
+ South: 2
52
+
53
+ # Baseline Comparison (PACLIC 2024 - ResNet34)
54
+ baseline:
55
+ gender:
56
+ clean: 98.73
57
+ noisy: 98.14
58
+ dialect:
59
+ clean: 81.47
60
+ noisy: 74.80
configs/eval.yaml.example ADDED
@@ -0,0 +1,165 @@
+ # Evaluation Configuration
+ # Evaluate model on test sets from raw audio
+ # Copy this file to eval.yaml and update paths
+
+ # Model
+ model:
+   checkpoint: "path/to/best_model"
+   name: "microsoft/wavlm-base-plus"
+   head_hidden_dim: 256
+
+ # Audio Processing
+ audio:
+   sampling_rate: 16000
+   max_duration: 5
+
+ # Evaluation
+ evaluation:
+   batch_size: 32
+   dataloader_num_workers: 2
+
+ # Data Paths
+ data:
+   # === ViSpeech (CSV format) ===
+   clean_test_meta: "path/to/metadata/clean_testset.csv"
+   clean_test_audio: "path/to/clean_testset"
+   noisy_test_meta: "path/to/metadata/noisy_testset.csv"
+   noisy_test_audio: "path/to/noisy_testset"
+
+   # === ViMD (HuggingFace format) ===
+   vimd_path: "/path/to/vimd-dataset"
+
+ # Output
+ output:
+   dir: "output/evaluation"
+   save_predictions: true
+   save_confusion_matrix: true
+
+ # Label Mappings
+ labels:
+   gender:
+     Male: 0
+     Female: 1
+     0: 0
+     1: 1
+   dialect:
+     North: 0
+     Central: 1
+     South: 2
+   region_to_dialect:
+     North: 0
+     Central: 1
+     South: 2
+
+ # Baseline Comparison (PACLIC 2024 - ResNet34)
+ baseline:
+   gender:
+     clean: 98.73
+     noisy: 98.14
+   dialect:
+     clean: 81.47
+     noisy: 74.80
configs/finetune.yaml ADDED
@@ -0,0 +1,89 @@
1
+ # Model (for classification heads only - features are pre-extracted)
2
+ model:
3
+ name: "microsoft/wavlm-base-plus" # Used for hidden_size reference
4
+ num_genders: 2
5
+ num_dialects: 3
6
+ dropout: 0.1
7
+ head_hidden_dim: 256
8
+
9
+ # Audio processing
10
+ audio:
11
+ sampling_rate: 16000
12
+ max_duration: 5 # seconds
13
+
14
+ # Training
15
+ training:
16
+ batch_size: 32
17
+ learning_rate: 5e-5
18
+ num_epochs: 15
19
+ warmup_ratio: 0.125
20
+ weight_decay: 0.0125
21
+ gradient_clip: 0.5
22
+ lr_scheduler: "linear"
23
+ fp16: true
24
+ dataloader_num_workers: 4
25
+
26
+ # Data Augmentation
27
+ augmentation:
28
+ enabled: true
29
+ prob: 0.8
30
+
31
+ # Loss
32
+ loss:
33
+ dialect_weight: 3.0
34
+
35
+ # WandB Configuration
36
+ wandb:
37
+ enabled: true
38
+ api_key: "f05e29c3466ec288e97041e0e3d541c4087096a6"
39
+ project: "speaker-profiling"
40
+ run_name: null
41
+
42
+ # Dataset paths
43
+ # source: "vispeech" (CSV format) or "vimd" (HuggingFace format)
44
+ data:
45
+ source: "vispeech" # Options: vispeech, vimd
46
+
47
+ # === ViSpeech (CSV format) ===
48
+ vispeech_root: "/home/ubuntu/DataScience/Voice_Pro_filing/vispeech_data/ViSpeech"
49
+ train_meta: "/home/ubuntu/DataScience/Voice_Pro_filing/vispeech_data/ViSpeech/metadata/trainset.csv"
50
+ train_audio: "/home/ubuntu/DataScience/Voice_Pro_filing/vispeech_data/ViSpeech/trainset"
51
+ clean_test_meta: "/home/ubuntu/DataScience/Voice_Pro_filing/vispeech_data/ViSpeech/metadata/clean_testset.csv"
52
+ clean_test_audio: "/home/ubuntu/DataScience/Voice_Pro_filing/vispeech_data/ViSpeech/clean_testset"
53
+ noisy_test_meta: "/home/ubuntu/DataScience/Voice_Pro_filing/vispeech_data/ViSpeech/metadata/noisy_testset.csv"
54
+ noisy_test_audio: "/home/ubuntu/DataScience/Voice_Pro_filing/vispeech_data/ViSpeech/noisy_testset"
55
+ val_split: 0.15
56
+
57
+ # === ViMD (HuggingFace format) ===
58
+ vimd_path: "/kaggle/input/vimd-dataset"
59
+
60
+ # Output
61
+ output:
62
+ dir: "output/speaker-profiling"
63
+ save_total_limit: 3
64
+ metric_for_best_model: "dialect_acc"
65
+
66
+ # Early Stopping
67
+ early_stopping:
68
+ patience: 3
69
+ threshold: 0.0025
70
+
71
+ # Label Mappings
72
+ labels:
73
+ gender:
74
+ Male: 0
75
+ Female: 1
76
+ 0: 0 # Support int labels (ViMD)
77
+ 1: 1
78
+ dialect:
79
+ North: 0
80
+ Central: 1
81
+ South: 2
82
+ # ViMD uses 'region' column
83
+ region_to_dialect:
84
+ North: 0
85
+ Central: 1
86
+ South: 2
87
+
88
+ # Reproducibility
89
+ seed: 42
configs/finetune.yaml.example ADDED
@@ -0,0 +1,186 @@
+ # Finetune Configuration
+ # Full model finetuning from raw audio
+ # Supports: ViSpeech (CSV) and ViMD (HuggingFace)
+ # Copy this file to finetune.yaml and update paths
+
+ # Model
+ model:
+   name: "microsoft/wavlm-base-plus"
+   num_genders: 2
+   num_dialects: 3
+   dropout: 0.1
+   head_hidden_dim: 256
+
+ # Audio processing
+ audio:
+   sampling_rate: 16000
+   max_duration: 5  # seconds
+
+ # Training
+ training:
+   batch_size: 32
+   learning_rate: 5e-5
+   num_epochs: 15
+   warmup_ratio: 0.125
+   weight_decay: 0.0125
+   gradient_clip: 1.0
+   lr_scheduler: "linear"
+   fp16: true
+   dataloader_num_workers: 4
+
+ # Data Augmentation
+ augmentation:
+   enabled: true
+   prob: 0.8
+
+ # Loss
+ loss:
+   dialect_weight: 3.0
+
+ # MLflow Configuration
+ mlflow:
+   enabled: true
+   tracking_uri: "mlruns"
+   experiment_name: "speaker-profiling"
+   run_name: null
+   registered_model_name: null
+
+ # Dataset
+ # source: "vispeech" (CSV format) or "vimd" (HuggingFace format)
+ data:
+   source: "vispeech"  # Options: vispeech, vimd
+
+   # === ViSpeech (CSV format) ===
+   # Download ViSpeech: https://drive.google.com/file/d/1-BbOHf42o6eBje2WqQiiRKMtNxmZiRf9
+   vispeech_root: "/path/to/ViSpeech"
+   train_meta: "/path/to/ViSpeech/metadata/trainset.csv"
+   train_audio: "/path/to/ViSpeech/trainset"
+   clean_test_meta: "/path/to/ViSpeech/metadata/clean_testset.csv"
+   clean_test_audio: "/path/to/ViSpeech/clean_testset"
+   noisy_test_meta: "/path/to/ViSpeech/metadata/noisy_testset.csv"
+   noisy_test_audio: "/path/to/ViSpeech/noisy_testset"
+   val_split: 0.15
+
+   # === ViMD (HuggingFace format) ===
+   vimd_path: "/path/to/vimd-dataset"
+
+ # Output
+ output:
+   dir: "output/speaker-profiling"
+   save_total_limit: 3
+   metric_for_best_model: "dialect_acc"
+
+ # Early Stopping
+ early_stopping:
+   patience: 3
+   threshold: 0.0025
+
+ # Label Mappings
+ labels:
+   gender:
+     Male: 0
+     Female: 1
+     0: 0
+     1: 1
+   dialect:
+     North: 0
+     Central: 1
+     South: 2
+   region_to_dialect:
+     North: 0
+     Central: 1
+     South: 2
+
+ # Reproducibility
+ seed: 42
configs/infer.yaml ADDED
@@ -0,0 +1,40 @@
1
+ # Inference Configuration
2
+
3
+ # Model
4
+ model:
5
+ checkpoint: "model/vulehuubinh"
6
+ name: "nguyenvulebinh/wav2vec2-base-vi-vlsp2020"
7
+ head_hidden_dim: 256
8
+
9
+ # Audio Processing
10
+ audio:
11
+ sampling_rate: 16000
12
+ max_duration: 5
13
+
14
+ # Inference
15
+ inference:
16
+ batch_size: 1
17
+ device: "cuda"
18
+
19
+ # Input
20
+ input:
21
+ audio_path: null
22
+ audio_dir: null
23
+
24
+ # Output
25
+ output:
26
+ dir: "output/predictions"
27
+ save_results: true
28
+ format: "json"
29
+
30
+ # Label Mappings
31
+ # NOTE: Model was trained with Female=0, Male=1 (opposite of finetune.yaml order)
32
+ # This is because pandas .map() may have processed the labels in a different order
33
+ labels:
34
+ gender:
35
+ 0: "Female"
36
+ 1: "Male"
37
+ dialect:
38
+ 0: "North"
39
+ 1: "Central"
40
+ 2: "South"
configs/infer.yaml.example ADDED
@@ -0,0 +1,80 @@
+ # Inference Configuration
+ # Predict gender and dialect from audio
+ # Copy this file to infer.yaml and update paths
+
+ # Model
+ model:
+   checkpoint: "path/to/best_model"
+   name: "microsoft/wavlm-base-plus"
+   head_hidden_dim: 256
+
+ # Audio Processing
+ audio:
+   sampling_rate: 16000
+   max_duration: 5
+
+ # Inference
+ inference:
+   batch_size: 1
+   device: "cuda"
+
+ # Input
+ input:
+   audio_path: null
+   audio_dir: null
+
+ # Output
+ output:
+   dir: "output/predictions"
+   save_results: true
+   format: "json"
+
+ # Label Mappings
+ labels:
+   gender:
+     0: "Male"
+     1: "Female"
+   dialect:
+     0: "North"
+     1: "Central"
+     2: "South"
configs/train_ecapa.yaml ADDED
@@ -0,0 +1,90 @@
1
+ # Config for ECAPA-TDNN (SpeechBrain)
2
+ # Model: speechbrain/spkrec-ecapa-voxceleb
3
+
4
+ # Model
5
+ model:
6
+ name: "speechbrain/spkrec-ecapa-voxceleb"
7
+ num_genders: 2
8
+ num_dialects: 3
9
+ dropout: 0.1
10
+ head_hidden_dim: 128 # Smaller head for 192-dim embeddings
11
+
12
+ # Audio processing
13
+ audio:
14
+ sampling_rate: 16000
15
+ max_duration: 5 # seconds
16
+
17
+ # Training
18
+ training:
19
+ batch_size: 32
20
+ learning_rate: 1e-4 # Higher LR since only training heads
21
+ num_epochs: 15
22
+ warmup_ratio: 0.1
23
+ weight_decay: 0.01
24
+ gradient_clip: 1.0
25
+ lr_scheduler: "linear"
26
+ fp16: false # ECAPA-TDNN does not support fp16
27
+ dataloader_num_workers: 4
28
+
29
+ # Data Augmentation
30
+ augmentation:
31
+ enabled: true
32
+ prob: 0.8
33
+
34
+ # Loss
35
+ loss:
36
+ dialect_weight: 3.0
37
+
38
+ # WandB Configuration
39
+ wandb:
40
+ enabled: true
41
+ api_key: "f05e29c3466ec288e97041e0e3d541c4087096a6"
42
+ project: "vispeech-speaker-profiling"
43
+ run_name: "ecapa-tdnn"
44
+
45
+ # Dataset paths
46
+ data:
47
+ source: "vispeech" # Options: vispeech, vimd
48
+
49
+ # === ViSpeech (CSV format) ===
50
+ vispeech_root: "/home/ubuntu/DataScience/Voice_Pro_filing/vispeech_data/ViSpeech"
51
+ train_meta: "/home/ubuntu/DataScience/Voice_Pro_filing/vispeech_data/ViSpeech/metadata/trainset.csv"
52
+ train_audio: "/home/ubuntu/DataScience/Voice_Pro_filing/vispeech_data/ViSpeech/trainset"
53
+ clean_test_meta: "/home/ubuntu/DataScience/Voice_Pro_filing/vispeech_data/ViSpeech/metadata/clean_testset.csv"
54
+ clean_test_audio: "/home/ubuntu/DataScience/Voice_Pro_filing/vispeech_data/ViSpeech/clean_testset"
55
+ noisy_test_meta: "/home/ubuntu/DataScience/Voice_Pro_filing/vispeech_data/ViSpeech/metadata/noisy_testset.csv"
56
+ noisy_test_audio: "/home/ubuntu/DataScience/Voice_Pro_filing/vispeech_data/ViSpeech/noisy_testset"
57
+ val_split: 0.15
58
+
59
+ # === ViMD (HuggingFace format) ===
60
+ vimd_path: "/kaggle/input/vimd-dataset"
61
+
62
+ # Output
63
+ output:
64
+ dir: "output/ecapa-tdnn"
65
+ save_total_limit: 3
66
+ metric_for_best_model: "dialect_acc"
67
+
68
+ # Early Stopping
69
+ early_stopping:
70
+ patience: 3
71
+ threshold: 0.0025
72
+
73
+ # Label Mappings
74
+ labels:
75
+ gender:
76
+ Male: 0
77
+ Female: 1
78
+ 0: 0
79
+ 1: 1
80
+ dialect:
81
+ North: 0
82
+ Central: 1
83
+ South: 2
84
+ region_to_dialect:
85
+ North: 0
86
+ Central: 1
87
+ South: 2
88
+
89
+ # Reproducibility
90
+ seed: 42
model/vulehuubinh/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3a5b4a3417c2d783e44b7cfd701b083b979c076fde257fc0ea80c12fab5705ad
+ size 381595388
model/vulehuubinh/preprocessor_config.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "do_normalize": true,
+   "feature_extractor_type": "Wav2Vec2FeatureExtractor",
+   "feature_size": 1,
+   "padding_side": "right",
+   "padding_value": 0.0,
+   "processor_class": "Wav2Vec2ProcessorWithLM",
+   "return_attention_mask": false,
+   "sampling_rate": 16000
+ }
model/vulehuubinh/training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a059e8720c9e406f538f14e191d903e9efad04de1f27661fc918fefecbd6bea1
+ size 5176
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ # HuggingFace Spaces requirements
+ torch>=2.0.0
+ torchaudio>=2.0.0
+ transformers==4.44.0
+ librosa>=0.10.0
+ soundfile>=0.12.0
+ numpy<2.0
+ safetensors>=0.4.0
+ gradio>=4.0.0
+ pyyaml>=6.0
+ omegaconf
src/__init__.py ADDED
@@ -0,0 +1,42 @@
+ """
+ Speaker Profiling Source Package
+ """
+
+ from .models import (
+     AttentivePooling,
+     MultiTaskSpeakerModel,
+     MultiTaskSpeakerModelFromConfig
+ )
+
+ from .utils import (
+     setup_logging,
+     get_logger,
+     load_config,
+     set_seed,
+     load_audio,
+     preprocess_audio,
+     load_and_preprocess_audio,
+     load_model_checkpoint,
+     get_device,
+     count_parameters,
+     format_number
+ )
+
+ __all__ = [
+     # Models
+     'AttentivePooling',
+     'MultiTaskSpeakerModel',
+     'MultiTaskSpeakerModelFromConfig',
+     # Utils
+     'setup_logging',
+     'get_logger',
+     'load_config',
+     'set_seed',
+     'load_audio',
+     'preprocess_audio',
+     'load_and_preprocess_audio',
+     'load_model_checkpoint',
+     'get_device',
+     'count_parameters',
+     'format_number'
+ ]
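The helpers exported here are the ones app.py composes at startup. A small sketch of the same call pattern (signatures inferred from that usage; the config path is a placeholder):

```python
# Sketch: typical use of the exported utilities, mirroring app.py's startup sequence.
from src.utils import setup_logging, load_config, get_device

logger = setup_logging(name="example")
config = load_config("configs/infer.yaml")           # OmegaConf object
device = get_device(config['inference']['device'])   # e.g. "cuda" or "cpu", per the config
logger.info(f"Using device: {device}")
```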
src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (925 Bytes).
src/__pycache__/models.cpython-311.pyc ADDED
Binary file (28.2 kB).
src/__pycache__/utils.cpython-311.pyc ADDED
Binary file (11.1 kB).
src/models.py ADDED
@@ -0,0 +1,648 @@
1
+ """
2
+ Model Architecture for Speaker Profiling
3
+ Supports multiple encoders: WavLM, HuBERT, Wav2Vec2, Whisper, ECAPA-TDNN
4
+ Architecture: Encoder + Attentive Pooling + LayerNorm + Classification Heads
5
+ """
6
+
7
+ import logging
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from transformers import (
12
+ WavLMModel,
13
+ HubertModel,
14
+ Wav2Vec2Model,
15
+ WhisperModel,
16
+ AutoConfig
17
+ )
18
+
19
+ # SpeechBrain ECAPA-TDNN support - lazy import to avoid torchaudio issues
20
+ SPEECHBRAIN_AVAILABLE = None # Will be set on first use
21
+ EncoderClassifier = None # Will be imported lazily
22
+
23
+ def _check_speechbrain():
24
+ """Lazily check and import SpeechBrain"""
25
+ global SPEECHBRAIN_AVAILABLE, EncoderClassifier
26
+ if SPEECHBRAIN_AVAILABLE is None:
27
+ try:
28
+ from speechbrain.inference.speaker import EncoderClassifier as _EncoderClassifier
29
+ EncoderClassifier = _EncoderClassifier
30
+ SPEECHBRAIN_AVAILABLE = True
31
+ except (ImportError, AttributeError) as e:
32
+ SPEECHBRAIN_AVAILABLE = False
33
+ logger.warning(f"SpeechBrain not available: {e}")
34
+ return SPEECHBRAIN_AVAILABLE
35
+
36
+ logger = logging.getLogger("speaker_profiling")
37
+
38
+
39
+ # ECAPA-TDNN wrapper class for consistent interface
40
+ class ECAPATDNNEncoder(nn.Module):
41
+ """
42
+ Wrapper for SpeechBrain ECAPA-TDNN encoder.
43
+
44
+ ECAPA-TDNN outputs fixed-size embeddings (192 or 512 dim) instead of
45
+ frame-level features like WavLM/HuBERT. This wrapper handles the difference.
46
+
47
+ Supported models:
48
+ - speechbrain/spkrec-ecapa-voxceleb: 192-dim embeddings
49
+ - speechbrain/spkrec-xvect-voxceleb: 512-dim embeddings (x-vector)
50
+ """
51
+
52
+ def __init__(self, model_name: str = "speechbrain/spkrec-ecapa-voxceleb"):
53
+ super().__init__()
54
+
55
+ # Lazy import SpeechBrain
56
+ if not _check_speechbrain():
57
+ raise ImportError(
58
+ "SpeechBrain is required for ECAPA-TDNN. "
59
+ "Install with: pip install speechbrain"
60
+ )
61
+
62
+ self.model_name = model_name
63
+
64
+ # Detect if CUDA is available
65
+ device = "cuda" if torch.cuda.is_available() else "cpu"
66
+
67
+ self.encoder = EncoderClassifier.from_hparams(
68
+ source=model_name,
69
+ savedir=f"pretrained_models/{model_name.split('/')[-1]}",
70
+ run_opts={"device": device}
71
+ )
72
+
73
+ # Force float32 for all encoder parameters
74
+ self.encoder.mods.float()
75
+
76
+ # Determine embedding size
77
+ if "ecapa" in model_name.lower():
78
+ self.embedding_size = 192
79
+ elif "xvect" in model_name.lower():
80
+ self.embedding_size = 512
81
+ else:
82
+ self.embedding_size = 192 # default
83
+
84
+ # Config-like object for compatibility
85
+ class Config:
86
+ def __init__(self, hidden_size):
87
+ self.hidden_size = hidden_size
88
+
89
+ self.config = Config(self.embedding_size)
90
+
91
+ # Track current device
92
+ self._current_device = device
93
+
94
+ def forward(self, input_values: torch.Tensor, attention_mask: torch.Tensor = None):
95
+ """
96
+ Extract embeddings from audio.
97
+
98
+ Args:
99
+ input_values: Audio waveform [B, T]
100
+ attention_mask: Not used for ECAPA-TDNN
101
+
102
+ Returns:
103
+ Object with last_hidden_state attribute [B, 1, H]
104
+ """
105
+ # Get device from input
106
+ device = input_values.device
107
+
108
+ # Move encoder to same device as input if needed
109
+ if str(device) != str(self._current_device):
110
+ self.encoder.to(device)
111
+ self.encoder.mods.float() # Ensure float32 after move
112
+ self._current_device = device
113
+
114
+ # Ensure input is float32 and on correct device
115
+ input_values = input_values.float().to(device)
116
+
117
+ # SpeechBrain expects [B, T] audio at 16kHz
118
+ # encode_batch handles feature extraction internally
119
+ with torch.no_grad():
120
+ # Set encoder to eval mode to handle BatchNorm properly
121
+ self.encoder.eval()
122
+ embeddings = self.encoder.encode_batch(input_values) # [B, 1, H]
123
+
124
+ # Ensure output is float32
125
+ embeddings = embeddings.float()
126
+
127
+ # Return object compatible with HuggingFace models
128
+ class Output:
129
+ def __init__(self, hidden_state):
130
+ self.last_hidden_state = hidden_state
131
+
132
+ return Output(embeddings)
133
+
134
+
135
+ # Encoder registry - maps model type to class and hidden size
136
+ ENCODER_REGISTRY = {
137
+ # WavLM variants
138
+ "microsoft/wavlm-base": {"class": WavLMModel, "hidden_size": 768},
139
+ "microsoft/wavlm-base-plus": {"class": WavLMModel, "hidden_size": 768},
140
+ "microsoft/wavlm-large": {"class": WavLMModel, "hidden_size": 1024},
141
+
142
+ # HuBERT variants
143
+ "facebook/hubert-base-ls960": {"class": HubertModel, "hidden_size": 768},
144
+ "facebook/hubert-large-ls960-ft": {"class": HubertModel, "hidden_size": 1024},
145
+ "facebook/hubert-xlarge-ls960-ft": {"class": HubertModel, "hidden_size": 1280},
146
+
147
+ # Wav2Vec2 variants
148
+ "facebook/wav2vec2-base": {"class": Wav2Vec2Model, "hidden_size": 768},
149
+ "facebook/wav2vec2-base-960h": {"class": Wav2Vec2Model, "hidden_size": 768},
150
+ "facebook/wav2vec2-large": {"class": Wav2Vec2Model, "hidden_size": 1024},
151
+ "facebook/wav2vec2-large-960h": {"class": Wav2Vec2Model, "hidden_size": 1024},
152
+ "facebook/wav2vec2-xls-r-300m": {"class": Wav2Vec2Model, "hidden_size": 1024},
153
+
154
+ # Vietnamese Wav2Vec2 (VLSP2020)
155
+ "nguyenvulebinh/wav2vec2-base-vi-vlsp2020": {"class": Wav2Vec2Model, "hidden_size": 768},
156
+
157
+ # Whisper variants (encoder only)
158
+ "openai/whisper-tiny": {"class": WhisperModel, "hidden_size": 384, "is_whisper": True},
159
+ "openai/whisper-base": {"class": WhisperModel, "hidden_size": 512, "is_whisper": True},
160
+ "openai/whisper-small": {"class": WhisperModel, "hidden_size": 768, "is_whisper": True},
161
+ "openai/whisper-medium": {"class": WhisperModel, "hidden_size": 1024, "is_whisper": True},
162
+ "openai/whisper-large": {"class": WhisperModel, "hidden_size": 1280, "is_whisper": True},
163
+ "openai/whisper-large-v2": {"class": WhisperModel, "hidden_size": 1280, "is_whisper": True},
164
+ "openai/whisper-large-v3": {"class": WhisperModel, "hidden_size": 1280, "is_whisper": True},
165
+
166
+ # PhoWhisper - Vietnamese fine-tuned Whisper (VinAI)
167
+ "vinai/PhoWhisper-tiny": {"class": WhisperModel, "hidden_size": 384, "is_whisper": True},
168
+ "vinai/PhoWhisper-base": {"class": WhisperModel, "hidden_size": 512, "is_whisper": True},
169
+ "vinai/PhoWhisper-small": {"class": WhisperModel, "hidden_size": 768, "is_whisper": True},
170
+ "vinai/PhoWhisper-medium": {"class": WhisperModel, "hidden_size": 1024, "is_whisper": True},
171
+ "vinai/PhoWhisper-large": {"class": WhisperModel, "hidden_size": 1280, "is_whisper": True},
172
+
173
+ # ECAPA-TDNN (SpeechBrain)
174
+ "speechbrain/spkrec-ecapa-voxceleb": {
175
+ "class": ECAPATDNNEncoder,
176
+ "hidden_size": 192,
177
+ "is_ecapa": True
178
+ },
179
+ "speechbrain/spkrec-xvect-voxceleb": {
180
+ "class": ECAPATDNNEncoder,
181
+ "hidden_size": 512,
182
+ "is_ecapa": True
183
+ },
184
+ }
185
+
186
+
187
+ def get_encoder_info(model_name: str) -> dict:
188
+ """Get encoder class and hidden size for a model name"""
189
+ if model_name in ENCODER_REGISTRY:
190
+ return ENCODER_REGISTRY[model_name]
191
+
192
+ # Check for ECAPA-TDNN / SpeechBrain models
193
+ # Note: We don't check SPEECHBRAIN_AVAILABLE here - the actual import
194
+ # will happen lazily in ECAPATDNNEncoder.__init__() when the model is used
195
+ if 'ecapa' in model_name.lower() or 'speechbrain' in model_name.lower():
196
+ hidden_size = 512 if 'xvect' in model_name.lower() else 192
197
+ return {"class": ECAPATDNNEncoder, "hidden_size": hidden_size, "is_ecapa": True}
198
+
199
+ # Try to auto-detect from config
200
+ try:
201
+ config = AutoConfig.from_pretrained(model_name)
202
+ hidden_size = getattr(config, 'hidden_size', 768)
203
+
204
+ if 'wavlm' in model_name.lower():
205
+ return {"class": WavLMModel, "hidden_size": hidden_size}
206
+ elif 'hubert' in model_name.lower():
207
+ return {"class": HubertModel, "hidden_size": hidden_size}
208
+ elif 'wav2vec2' in model_name.lower():
209
+ return {"class": Wav2Vec2Model, "hidden_size": hidden_size}
210
+ elif 'whisper' in model_name.lower() or 'phowhisper' in model_name.lower():
211
+ return {"class": WhisperModel, "hidden_size": hidden_size, "is_whisper": True}
212
+ else:
213
+ # Default to Wav2Vec2 architecture
214
+ return {"class": Wav2Vec2Model, "hidden_size": hidden_size}
215
+ except Exception as e:
216
+ logger.warning(f"Could not auto-detect encoder for {model_name}: {e}")
217
+ return {"class": WavLMModel, "hidden_size": 768}
218
+
219
+
220
+ class AttentivePooling(nn.Module):
221
+ """
222
+ Attention-based pooling for temporal aggregation
223
+
224
+ Takes sequence of hidden states and produces a single vector
225
+ by computing attention weights and performing weighted sum.
226
+ """
227
+
228
+ def __init__(self, hidden_size: int):
229
+ super().__init__()
230
+ self.attention = nn.Sequential(
231
+ nn.Linear(hidden_size, hidden_size),
232
+ nn.Tanh(),
233
+ nn.Linear(hidden_size, 1, bias=False)
234
+ )
235
+
236
+ def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
237
+ """
238
+ Args:
239
+ x: Hidden states [B, T, H]
240
+ mask: Attention mask [B, T]
241
+
242
+ Returns:
243
+ pooled: Pooled representation [B, H]
244
+ attn_weights: Attention weights [B, T]
245
+ """
246
+ attn_weights = self.attention(x) # [B, T, 1]
247
+
248
+ if mask is not None:
249
+ mask = mask.unsqueeze(-1)
250
+ attn_weights = attn_weights.masked_fill(mask == 0, -1e9)
251
+
252
+ attn_weights = F.softmax(attn_weights, dim=1)
253
+ pooled = torch.sum(x * attn_weights, dim=1)
254
+
255
+ return pooled, attn_weights.squeeze(-1)
256
+
257
+
258
+ class MultiTaskSpeakerModel(nn.Module):
259
+ """
260
+ Multi-task model for gender and dialect classification
261
+
262
+ Architecture:
263
+ Audio -> Encoder (WavLM/HuBERT/Wav2Vec2/Whisper/ECAPA-TDNN) -> Last Hidden [B,T,H]
264
+ |
265
+ Attentive Pooling [B,H] (skipped for ECAPA-TDNN)
266
+ |
267
+ Layer Normalization
268
+ |
269
+ Dropout(0.1)
270
+ |
271
+ +---------------+---------------+
272
+ | |
273
+ Gender Head (2 layers) Dialect Head (3 layers)
274
+ | |
275
+ [B,2] [B,3]
276
+
277
+ Supported encoders:
278
+ - WavLM: microsoft/wavlm-base-plus, microsoft/wavlm-large
279
+ - HuBERT: facebook/hubert-base-ls960, facebook/hubert-large-ls960-ft
280
+ - Wav2Vec2: facebook/wav2vec2-base, facebook/wav2vec2-large-960h
281
+ - Whisper: openai/whisper-base, openai/whisper-small, openai/whisper-medium
282
+ - ECAPA-TDNN: speechbrain/spkrec-ecapa-voxceleb (192-dim embeddings)
283
+
284
+ Args:
285
+ model_name: Pretrained encoder model name or path
286
+ num_genders: Number of gender classes (default: 2)
287
+ num_dialects: Number of dialect classes (default: 3)
288
+ dropout: Dropout probability (default: 0.1)
289
+ head_hidden_dim: Hidden dimension for classification heads (default: 256)
290
+ freeze_encoder: Whether to freeze encoder (default: False)
291
+ dialect_loss_weight: Weight for dialect loss in multi-task learning (default: 3.0)
292
+ """
293
+
294
+ def __init__(
295
+ self,
296
+ model_name: str,
297
+ num_genders: int = 2,
298
+ num_dialects: int = 3,
299
+ dropout: float = 0.1,
300
+ head_hidden_dim: int = 256,
301
+ freeze_encoder: bool = False,
302
+ dialect_loss_weight: float = 3.0
303
+ ):
304
+ super().__init__()
305
+
306
+ self.model_name = model_name
307
+ self.dialect_loss_weight = dialect_loss_weight
308
+
309
+ # Get encoder info and load model
310
+ encoder_info = get_encoder_info(model_name)
311
+ encoder_class = encoder_info["class"]
312
+ self.is_whisper = encoder_info.get("is_whisper", False)
313
+ self.is_ecapa = encoder_info.get("is_ecapa", False)
314
+
315
+ logger.info(f"Loading encoder: {model_name}")
316
+ logger.info(f"Encoder class: {encoder_class.__name__}")
317
+
318
+ # Load pretrained encoder
319
+ if self.is_ecapa:
320
+ # ECAPA-TDNN uses different loading mechanism
321
+ self.encoder = encoder_class(model_name)
322
+ else:
323
+ self.encoder = encoder_class.from_pretrained(model_name)
324
+
325
+ hidden_size = self.encoder.config.hidden_size
326
+ self.hidden_size = hidden_size
327
+
328
+ logger.info(f"Hidden size: {hidden_size}")
329
+
330
+ # Optionally freeze encoder
331
+ if freeze_encoder:
332
+ for param in self.encoder.parameters():
333
+ param.requires_grad = False
334
+ logger.info("Encoder weights frozen")
335
+
336
+ # Pooling and normalization (ECAPA-TDNN already outputs pooled embeddings)
337
+ self.attentive_pooling = AttentivePooling(hidden_size)
338
+ self.layer_norm = nn.LayerNorm(hidden_size)
339
+ self.dropout = nn.Dropout(dropout)
340
+
341
+ # Gender classification head (2 layers)
342
+ self.gender_head = nn.Sequential(
343
+ nn.Linear(hidden_size, head_hidden_dim),
344
+ nn.ReLU(),
345
+ nn.Dropout(dropout),
346
+ nn.Linear(head_hidden_dim, num_genders)
347
+ )
348
+
349
+ # Dialect classification head (3 layers - deeper for harder task)
350
+ self.dialect_head = nn.Sequential(
351
+ nn.Linear(hidden_size, head_hidden_dim),
352
+ nn.ReLU(),
353
+ nn.Dropout(dropout),
354
+ nn.Linear(head_hidden_dim, head_hidden_dim // 2),
355
+ nn.ReLU(),
356
+ nn.Dropout(dropout),
357
+ nn.Linear(head_hidden_dim // 2, num_dialects)
358
+ )
359
+
360
+ def forward(
361
+ self,
362
+ input_values: torch.Tensor = None,
363
+ input_features: torch.Tensor = None,
364
+ attention_mask: torch.Tensor = None,
365
+ gender_labels: torch.Tensor = None,
366
+ dialect_labels: torch.Tensor = None
367
+ ):
368
+ """
369
+ Forward pass - supports both raw audio and pre-extracted features
370
+
371
+ Args:
372
+ input_values: Audio waveform [B, T] (for raw audio mode)
373
+ input_features: Pre-extracted features [B, T, H] or [B, 1, H] for ECAPA
374
+ attention_mask: Attention mask [B, T]
375
+ gender_labels: Gender labels [B] (optional, for training)
376
+ dialect_labels: Dialect labels [B] (optional, for training)
377
+
378
+ Returns:
379
+ dict with keys:
380
+ - loss: Combined loss (if labels provided)
381
+ - gender_logits: Gender predictions [B, num_genders]
382
+ - dialect_logits: Dialect predictions [B, num_dialects]
383
+ - attention_weights: Attention weights from pooling [B, T] (None for ECAPA)
384
+ """
385
+ # Get hidden states from either raw audio or pre-extracted features
386
+ if input_features is not None:
387
+ # Use pre-extracted features directly
388
+ hidden_states = input_features
389
+ elif input_values is not None:
390
+ # Extract features from encoder
391
+ hidden_states = self._encode(input_values, attention_mask)
392
+ else:
393
+ raise ValueError("Either input_values or input_features must be provided")
394
+
395
+ # Handle ECAPA-TDNN (outputs [B, 1, H] - already pooled embeddings)
396
+ if self.is_ecapa or hidden_states.shape[1] == 1:
397
+ # ECAPA-TDNN outputs already pooled embeddings
398
+ pooled = hidden_states.squeeze(1) # [B, H]
399
+ attn_weights = None
400
+ else:
401
+ # Create proper attention mask for hidden states (encoder downsamples audio)
402
+ # Hidden states have different sequence length than input audio
403
+ if attention_mask is not None and hidden_states.shape[1] != attention_mask.shape[1]:
404
+ # Create new mask based on hidden states length
405
+ batch_size, seq_len, _ = hidden_states.shape
406
+ pooled_mask = torch.ones(batch_size, seq_len, device=hidden_states.device)
407
+ else:
408
+ pooled_mask = attention_mask
409
+
410
+ # Attentive pooling
411
+ pooled, attn_weights = self.attentive_pooling(hidden_states, pooled_mask)
412
+
413
+ # Normalization and dropout
414
+ pooled = self.layer_norm(pooled)
415
+ pooled = self.dropout(pooled)
416
+
417
+ # Classification heads
418
+ gender_logits = self.gender_head(pooled)
419
+ dialect_logits = self.dialect_head(pooled)
420
+
421
+ # Compute loss if labels provided
422
+ loss = None
423
+ if gender_labels is not None and dialect_labels is not None:
424
+ loss_fct = nn.CrossEntropyLoss()
425
+ gender_loss = loss_fct(gender_logits, gender_labels)
426
+ dialect_loss = loss_fct(dialect_logits, dialect_labels)
427
+ loss = gender_loss + self.dialect_loss_weight * dialect_loss
428
+
429
+ return {
430
+ 'loss': loss,
431
+ 'gender_logits': gender_logits,
432
+ 'dialect_logits': dialect_logits,
433
+ 'attention_weights': attn_weights
434
+ }
435
+
436
+ def _encode(
437
+ self,
438
+ input_values: torch.Tensor,
439
+ attention_mask: torch.Tensor = None
440
+ ) -> torch.Tensor:
441
+ """
442
+ Extract hidden states from encoder
443
+
444
+ Args:
445
+ input_values: Audio waveform [B, T]
446
+ attention_mask: Attention mask [B, T]
447
+
448
+ Returns:
449
+ hidden_states: Hidden states [B, T, H] or [B, 1, H] for ECAPA-TDNN
450
+ """
451
+ if self.is_ecapa:
452
+ # ECAPA-TDNN outputs fixed-size embeddings [B, 1, H]
453
+ outputs = self.encoder(input_values, attention_mask)
454
+ hidden_states = outputs.last_hidden_state
455
+ elif self.is_whisper:
456
+ # Whisper uses encoder-decoder, we only use encoder
457
+ outputs = self.encoder.encoder(input_values)
458
+ hidden_states = outputs.last_hidden_state
459
+ else:
460
+ # WavLM, HuBERT, Wav2Vec2
461
+ outputs = self.encoder(input_values, attention_mask=attention_mask)
462
+ hidden_states = outputs.last_hidden_state
463
+
464
+ return hidden_states
465
+
466
+ def get_embeddings(
467
+ self,
468
+ input_values: torch.Tensor,
469
+ attention_mask: torch.Tensor = None
470
+ ) -> torch.Tensor:
471
+ """
472
+ Extract speaker embeddings (pooled representations)
473
+
474
+ Args:
475
+ input_values: Audio waveform [B, T]
476
+ attention_mask: Attention mask [B, T]
477
+
478
+ Returns:
479
+ embeddings: Speaker embeddings [B, H]
480
+ """
481
+ hidden_states = self._encode(input_values, attention_mask)
482
+
483
+ if self.is_ecapa or hidden_states.shape[1] == 1:
484
+ # ECAPA-TDNN already outputs pooled embeddings
485
+ pooled = hidden_states.squeeze(1)
486
+ else:
487
+ pooled, _ = self.attentive_pooling(hidden_states, attention_mask)
488
+
489
+ pooled = self.layer_norm(pooled)
490
+ return pooled
491
+
492
+
493
+ class MultiTaskSpeakerModelFromConfig(MultiTaskSpeakerModel):
494
+ """
495
+ Multi-task model initialized from OmegaConf config
496
+
497
+ Supports multiple encoders: WavLM, HuBERT, Wav2Vec2, Whisper
498
+ Use this for inference with raw audio input.
499
+
500
+ Usage:
501
+ config = OmegaConf.load('configs/finetune.yaml')
502
+ model = MultiTaskSpeakerModelFromConfig(config)
503
+ """
504
+
505
+ def __init__(self, config):
506
+ model_config = config['model']
507
+
508
+ super().__init__(
509
+ model_name=model_config['name'],
510
+ num_genders=model_config.get('num_genders', 2),
511
+ num_dialects=model_config.get('num_dialects', 3),
512
+ dropout=model_config.get('dropout', 0.1),
513
+ head_hidden_dim=model_config.get('head_hidden_dim', 256),
514
+ freeze_encoder=model_config.get('freeze_encoder', False),
515
+ dialect_loss_weight=config.get('loss', {}).get('dialect_weight', 3.0)
516
+ )
517
+
518
+ logger.info(f"Architecture: {model_config['name']} + Attentive Pooling + LayerNorm")
519
+ logger.info(f"Hidden size: {self.hidden_size}")
520
+ logger.info(f"Head hidden dim: {model_config.get('head_hidden_dim', 256)}")
521
+ logger.info(f"Dropout: {model_config.get('dropout', 0.1)}")
522
+
523
+
524
+ class ClassificationHeadModel(nn.Module):
525
+ """
526
+ Lightweight model with only classification heads (no encoder).
527
+
528
+ Use this for training with pre-extracted features to save memory.
529
+ Hidden_size depends on encoder: WavLM-base=768, WavLM-large=1024, etc.
530
+
531
+ Usage:
532
+ model = ClassificationHeadModel(config)
533
+ output = model(input_features=features, gender_labels=y_gender, dialect_labels=y_dialect)
534
+ """
535
+
536
+ def __init__(
537
+ self,
538
+ hidden_size: int = 768,
539
+ num_genders: int = 2,
540
+ num_dialects: int = 3,
541
+ dropout: float = 0.1,
542
+ head_hidden_dim: int = 256,
543
+ dialect_loss_weight: float = 3.0
544
+ ):
545
+ super().__init__()
546
+
547
+ self.hidden_size = hidden_size
548
+ self.dialect_loss_weight = dialect_loss_weight
549
+
550
+ # Pooling and normalization
551
+ self.attentive_pooling = AttentivePooling(hidden_size)
552
+ self.layer_norm = nn.LayerNorm(hidden_size)
553
+ self.dropout = nn.Dropout(dropout)
554
+
555
+ # Gender classification head (2 layers)
556
+ self.gender_head = nn.Sequential(
557
+ nn.Linear(hidden_size, head_hidden_dim),
558
+ nn.ReLU(),
559
+ nn.Dropout(dropout),
560
+ nn.Linear(head_hidden_dim, num_genders)
561
+ )
562
+
563
+ # Dialect classification head (3 layers - deeper for harder task)
564
+ self.dialect_head = nn.Sequential(
565
+ nn.Linear(hidden_size, head_hidden_dim),
566
+ nn.ReLU(),
567
+ nn.Dropout(dropout),
568
+ nn.Linear(head_hidden_dim, head_hidden_dim // 2),
569
+ nn.ReLU(),
570
+ nn.Dropout(dropout),
571
+ nn.Linear(head_hidden_dim // 2, num_dialects)
572
+ )
573
+
574
+ logger.info(f"ClassificationHeadModel initialized (hidden_size={hidden_size})")
575
+
576
+ def forward(
577
+ self,
578
+ input_features: torch.Tensor,
579
+ attention_mask: torch.Tensor = None,
580
+ gender_labels: torch.Tensor = None,
581
+ dialect_labels: torch.Tensor = None
582
+ ):
583
+ """
584
+ Forward pass for pre-extracted features
585
+
586
+ Args:
587
+ input_features: Pre-extracted WavLM features [B, T, H]
588
+ attention_mask: Attention mask [B, T]
589
+ gender_labels: Gender labels [B] (optional, for training)
590
+ dialect_labels: Dialect labels [B] (optional, for training)
591
+
592
+ Returns:
593
+ dict with keys:
594
+ - loss: Combined loss (if labels provided)
595
+ - gender_logits: Gender predictions [B, num_genders]
596
+ - dialect_logits: Dialect predictions [B, num_dialects]
597
+ - attention_weights: Attention weights from pooling [B, T]
598
+ """
599
+ # Attentive pooling
600
+ pooled, attn_weights = self.attentive_pooling(input_features, attention_mask)
601
+
602
+ # Normalization and dropout
603
+ pooled = self.layer_norm(pooled)
604
+ pooled = self.dropout(pooled)
605
+
606
+ # Classification heads
607
+ gender_logits = self.gender_head(pooled)
608
+ dialect_logits = self.dialect_head(pooled)
609
+
610
+ # Compute loss if labels provided
611
+ loss = None
612
+ if gender_labels is not None and dialect_labels is not None:
613
+ loss_fct = nn.CrossEntropyLoss()
614
+ gender_loss = loss_fct(gender_logits, gender_labels)
615
+ dialect_loss = loss_fct(dialect_logits, dialect_labels)
616
+ loss = gender_loss + self.dialect_loss_weight * dialect_loss
617
+
618
+ return {
619
+ 'loss': loss,
620
+ 'gender_logits': gender_logits,
621
+ 'dialect_logits': dialect_logits,
622
+ 'attention_weights': attn_weights
623
+ }
624
+
625
+
626
+ class ClassificationHeadModelFromConfig(ClassificationHeadModel):
627
+ """
628
+ Lightweight classification model initialized from OmegaConf config.
629
+
630
+ Use this for training with pre-extracted features.
631
+ """
632
+
633
+ def __init__(self, config):
634
+ model_config = config['model']
635
+
636
+ super().__init__(
637
+ hidden_size=model_config.get('hidden_size', 768), # WavLM base hidden size
638
+ num_genders=model_config.get('num_genders', 2),
639
+ num_dialects=model_config.get('num_dialects', 3),
640
+ dropout=model_config.get('dropout', 0.1),
641
+ head_hidden_dim=model_config.get('head_hidden_dim', 256),
642
+ dialect_loss_weight=config.get('loss', {}).get('dialect_weight', 3.0)
643
+ )
644
+
645
+ logger.info("Architecture: Attentive Pooling + LayerNorm + Classification Heads")
646
+ logger.info(f"Hidden size: {self.hidden_size}")
647
+ logger.info(f"Head hidden dim: {model_config.get('head_hidden_dim', 256)}")
648
+ logger.info(f"Dropout: {model_config.get('dropout', 0.1)}")
src/utils.py ADDED
@@ -0,0 +1,261 @@
1
+ """
2
+ Utility functions for Speaker Profiling
3
+ """
4
+
5
+ import os
6
+ import logging
7
+ import random
8
+ import numpy as np
9
+ import torch
10
+ import librosa
11
+ from pathlib import Path
12
+ from omegaconf import OmegaConf
13
+ from typing import Union, Optional, Tuple
14
+
15
+
16
+ def setup_logging(
17
+ name: str = "speaker_profiling",
18
+ level: int = logging.INFO,
19
+ log_file: Optional[str] = None
20
+ ) -> logging.Logger:
21
+ """
22
+ Set up logging configuration
23
+
24
+ Args:
25
+ name: Logger name
26
+ level: Logging level
27
+ log_file: Optional path to log file
28
+
29
+ Returns:
30
+ Configured logger instance
31
+ """
32
+ logger = logging.getLogger(name)
33
+ logger.setLevel(level)
34
+
35
+ if logger.handlers:
36
+ logger.handlers.clear()
37
+
38
+ formatter = logging.Formatter(
39
+ fmt="%(asctime)s | %(levelname)s | %(message)s",
40
+ datefmt="%Y-%m-%d %H:%M:%S"
41
+ )
42
+
43
+ console_handler = logging.StreamHandler()
44
+ console_handler.setLevel(level)
45
+ console_handler.setFormatter(formatter)
46
+ logger.addHandler(console_handler)
47
+
48
+ if log_file:
49
+ os.makedirs(os.path.dirname(log_file), exist_ok=True)
50
+ file_handler = logging.FileHandler(log_file, encoding='utf-8')
51
+ file_handler.setLevel(level)
52
+ file_handler.setFormatter(formatter)
53
+ logger.addHandler(file_handler)
54
+
55
+ return logger
56
+
57
+
58
+ def get_logger(name: str = "speaker_profiling") -> logging.Logger:
59
+ """Get existing logger or create new one"""
60
+ logger = logging.getLogger(name)
61
+ if not logger.handlers:
62
+ return setup_logging(name)
63
+ return logger
64
+
65
+
66
+ def load_config(config_path: str) -> OmegaConf:
67
+ """
68
+ Load configuration from yaml file
69
+
70
+ Args:
71
+ config_path: Path to yaml config file
72
+
73
+ Returns:
74
+ OmegaConf configuration object
75
+ """
76
+ if not os.path.exists(config_path):
77
+ raise FileNotFoundError(f"Config file not found: {config_path}")
78
+ return OmegaConf.load(config_path)
79
+
80
+
81
+ def set_seed(seed: int) -> None:
82
+ """
83
+ Set random seed for reproducibility
84
+
85
+ Args:
86
+ seed: Random seed value
87
+ """
88
+ random.seed(seed)
89
+ np.random.seed(seed)
90
+ torch.manual_seed(seed)
91
+ if torch.cuda.is_available():
92
+ torch.cuda.manual_seed_all(seed)
93
+ torch.backends.cudnn.deterministic = True
94
+ torch.backends.cudnn.benchmark = False
95
+
96
+
97
+ def load_audio(
98
+ audio_path: Union[str, Path],
99
+ sampling_rate: int = 16000,
100
+ mono: bool = True
101
+ ) -> Tuple[np.ndarray, int]:
102
+ """
103
+ Load audio file
104
+
105
+ Args:
106
+ audio_path: Path to audio file
107
+ sampling_rate: Target sampling rate
108
+ mono: Whether to convert to mono
109
+
110
+ Returns:
111
+ Tuple of (audio array, sampling rate)
112
+ """
113
+ audio, sr = librosa.load(audio_path, sr=sampling_rate, mono=mono)
114
+ return audio, sr
115
+
116
+
117
+ def preprocess_audio(
118
+ audio: np.ndarray,
119
+ sampling_rate: int = 16000,
120
+ max_duration: float = 10.0,
121
+ trim_db: int = 20,
122
+ normalize: bool = True,
123
+ center_crop: bool = True
124
+ ) -> np.ndarray:
125
+ """
126
+ Preprocess audio for model input
127
+
128
+ Args:
129
+ audio: Raw audio array
130
+ sampling_rate: Audio sampling rate
131
+ max_duration: Maximum duration in seconds
132
+ trim_db: Threshold in dB for silence trimming
133
+ normalize: Whether to normalize audio
134
+ center_crop: If True, center crop; else random crop (for training)
135
+
136
+ Returns:
137
+ Preprocessed audio array
138
+ """
139
+ max_length = int(sampling_rate * max_duration)
140
+
141
+ audio, _ = librosa.effects.trim(audio, top_db=trim_db)
142
+
143
+ if normalize:
144
+ audio = audio / (np.max(np.abs(audio)) + 1e-8)
145
+
146
+ if len(audio) < max_length:
147
+ audio = np.pad(audio, (0, max_length - len(audio)))
148
+ elif len(audio) > max_length:
149
+ if center_crop:
150
+ start = (len(audio) - max_length) // 2
151
+ else:
152
+ start = np.random.randint(0, len(audio) - max_length + 1)
153
+ audio = audio[start:start + max_length]
154
+
155
+ return audio
156
+
157
+
158
+ def load_and_preprocess_audio(
159
+ audio_path: Union[str, Path],
160
+ sampling_rate: int = 16000,
161
+ max_duration: float = 10.0,
162
+ trim_db: int = 20,
163
+ normalize: bool = True,
164
+ center_crop: bool = True
165
+ ) -> np.ndarray:
166
+ """
167
+ Load and preprocess an audio file in one step
168
+
169
+ Args:
170
+ audio_path: Path to audio file
171
+ sampling_rate: Target sampling rate
172
+ max_duration: Maximum duration in seconds
173
+ trim_db: Threshold in dB for silence trimming
174
+ normalize: Whether to normalize audio
175
+ center_crop: If True, center crop; else random crop
176
+
177
+ Returns:
178
+ Preprocessed audio array
179
+ """
180
+ audio, _ = load_audio(audio_path, sampling_rate)
181
+ return preprocess_audio(
182
+ audio,
183
+ sampling_rate,
184
+ max_duration,
185
+ trim_db,
186
+ normalize,
187
+ center_crop
188
+ )
189
+
190
+
191
+ def load_model_checkpoint(
192
+ model: torch.nn.Module,
193
+ checkpoint_path: str,
194
+ device: str = 'cpu'
195
+ ) -> torch.nn.Module:
196
+ """
197
+ Load model from checkpoint
198
+
199
+ Args:
200
+ model: PyTorch model instance
201
+ checkpoint_path: Path to checkpoint directory
202
+ device: Device to load model on
203
+
204
+ Returns:
205
+ Model with loaded weights
206
+ """
207
+ logger = get_logger()
208
+
209
+ safetensors_path = os.path.join(checkpoint_path, 'model.safetensors')
210
+ pytorch_path = os.path.join(checkpoint_path, 'pytorch_model.bin')
211
+
212
+ if os.path.exists(safetensors_path):
213
+ from safetensors.torch import load_file
214
+ state_dict = load_file(safetensors_path)
215
+ logger.info(f"Loading checkpoint from {safetensors_path}")
216
+ elif os.path.exists(pytorch_path):
217
+ state_dict = torch.load(pytorch_path, map_location=device)
218
+ logger.info(f"Loading checkpoint from {pytorch_path}")
219
+ else:
220
+ raise FileNotFoundError(
221
+ f"No checkpoint found in {checkpoint_path}. "
222
+ f"Expected 'model.safetensors' or 'pytorch_model.bin'"
223
+ )
224
+
225
+ model.load_state_dict(state_dict)
226
+ return model
227
+
228
+
229
+ def get_device(device_str: str = 'cuda') -> torch.device:
230
+ """
231
+ Get torch device, falling back to CPU if CUDA is not available
232
+
233
+ Args:
234
+ device_str: Desired device string ('cuda' or 'cpu')
235
+
236
+ Returns:
237
+ torch.device instance
238
+ """
239
+ if device_str == 'cuda' and torch.cuda.is_available():
240
+ return torch.device('cuda')
241
+ return torch.device('cpu')
242
+
243
+
244
+ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
245
+ """
246
+ Count model parameters
247
+
248
+ Args:
249
+ model: PyTorch model
250
+
251
+ Returns:
252
+ Tuple of (total_params, trainable_params)
253
+ """
254
+ total = sum(p.numel() for p in model.parameters())
255
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
256
+ return total, trainable
257
+
258
+
259
+ def format_number(num: int) -> str:
260
+ """Format large numbers with commas"""
261
+ return f"{num:,}"