mds04 committed on
Commit
a31da5a
·
verified ·
1 Parent(s): efe17d8

Create helper_classes.py

Browse files
Files changed (1) hide show
  1. helper_classes.py +451 -0
helper_classes.py ADDED
@@ -0,0 +1,451 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
@dataclass
class Config:
    """Configuration for the language identification pipeline.

    ``label_map`` and ``canonical_languages`` default to ``None`` and are
    filled in by ``__post_init__``, so each instance gets its own mutable
    containers. Values supplied explicitly by the caller are kept as given
    (previously they were silently overwritten with the defaults).
    """
    target_sample_rate: int = 16000  # Hz; passed to AudioProcessor by LanguageIdentifier
    embedding_dim: int = 256  # NOTE(review): not referenced by the code visible in this file
    test_size: float = 0.2  # fraction held out by train_test_split
    random_state: int = 42  # seed for the split, the oversampler and the classifier
    max_iter: int = 1000  # LogisticRegression iteration cap

    # Language mappings for custom classifier
    label_map: Optional[Dict[str, int]] = None
    canonical_languages: Optional[List[str]] = None

    def __post_init__(self):
        """Populate the language mappings that were left unset."""
        if self.label_map is None:
            # The custom classifier covers Iban, Bukar Sadong and Malay.
            self.label_map = {"iban": 0, "bukar_sadong": 1, "malay": 2}
        if self.canonical_languages is None:
            self.canonical_languages = ["malay", "english", "mandarin", "tamil"]
18
+
19
class AudioProcessor:
    """Loads audio files and normalises them for the downstream models."""

    def __init__(self, target_sr: int = 16000):
        # Sample rate (Hz) that every loaded waveform is converted to.
        self.target_sr = target_sr

    def load_audio(self, path: str) -> torch.Tensor:
        """Read ``path`` and return a mono float32 waveform at ``target_sr``.

        Multi-channel input is averaged into a single channel (the channel
        axis is kept), and the signal is resampled only when the file's
        sample rate differs from ``self.target_sr``.
        """
        waveform, source_sr = torchaudio.load(path)

        # Down-mix: average over the channel axis, keeping that axis.
        channel_count = waveform.size(0)
        if channel_count > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Bring the signal to the target rate when the file uses another one.
        if source_sr != self.target_sr:
            to_target_rate = torchaudio.transforms.Resample(
                orig_freq=source_sr,
                new_freq=self.target_sr,
            )
            waveform = to_target_rate(waveform)

        return waveform.to(dtype=torch.float32)
39
+
40
class LanguageIdentifier:
    """Main language identification system.

    A pretrained VoxLingua107 model makes the first prediction. Audio it does
    not label as English, Mandarin, Tamil or Malay is handed to a custom
    Iban / Bukar Sadong / Malay classifier trained on VoxLingua107
    embeddings. Audio labelled Malay is decided by whichever of the two
    models reports the higher confidence.
    """

    # Extensions treated as audio when scanning an extracted archive.
    _AUDIO_SUFFIXES = (".wav", ".mp3")

    def __init__(self, config: Optional["Config"] = None):
        self.config = config or Config()
        self.audio_processor = AudioProcessor(self.config.target_sample_rate)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Populated later by load_vox_model() / train_custom_classifier().
        self.vox_model = None
        self.custom_classifier = None
        self.label_encoder = None

    def load_vox_model(self, model_path: Optional[str] = None):
        """Load SpeechBrain VoxLingua107 model.

        Args:
            model_path: Directory used as the SpeechBrain ``savedir``; falls
                back to ``pretrained_models/lang-id-voxlingua107-ecapa``.
        """
        source = "speechbrain/lang-id-voxlingua107-ecapa"
        savedir = model_path or "pretrained_models/lang-id-voxlingua107-ecapa"

        self.vox_model = EncoderClassifier.from_hparams(
            source=source,
            savedir=savedir,
            # SpeechBrain's documented examples pass the device as a string.
            run_opts={"device": str(self.device)},
        )
        self.label_encoder = self.vox_model.hparams.label_encoder
        print(f"VoxLingua107 model loaded on {self.device}")

    def _require_vox_model(self):
        """Fail with a clear message when the VoxLingua107 model is missing."""
        if self.vox_model is None:
            raise RuntimeError(
                "VoxLingua107 model is not loaded; call load_vox_model() first"
            )

    def _prepare_waveform(self, audio: Union[str, torch.Tensor]) -> torch.Tensor:
        """Return ``audio`` as a batched float32 tensor on ``self.device``."""
        wav = self.audio_processor.load_audio(audio) if isinstance(audio, str) else audio

        # Ensure batch dimension
        if wav.dim() == 1:
            wav = wav.unsqueeze(0)

        return wav.to(self.device, dtype=torch.float32)

    def _decode_label(self, index: int) -> str:
        """Translate a class index into the label encoder's raw label."""
        try:
            return self.label_encoder.ind2lab[index]
        except Exception:
            # Fall back to the encoder's decoding API. Catching Exception
            # (not a bare except) lets KeyboardInterrupt/SystemExit through.
            return self.label_encoder.decode_ndim(index)

    def extract_embedding(self, audio: Union[str, torch.Tensor]) -> np.ndarray:
        """Extract embedding from audio using VoxLingua107.

        Args:
            audio: Path to an audio file, or an already loaded waveform.

        Returns:
            The encoder output flattened per batch item, with the leading
            batch axis squeezed away when the batch holds one item.

        Raises:
            RuntimeError: If the VoxLingua107 model has not been loaded.
        """
        self._require_vox_model()
        wav = self._prepare_waveform(audio)

        with torch.no_grad():
            embedding = self.vox_model.encode_batch(wav)
            if isinstance(embedding, tuple):
                embedding = embedding[0]
            # Flatten to 1D array
            embedding = embedding.reshape(embedding.size(0), -1).squeeze(0)
        return embedding.cpu().numpy()

    def normalize_language_label(self, raw_label: str) -> Optional[str]:
        """Map VoxLingua107 short codes to canonical language names.

        Returns ``None`` for any code other than ``ms``, ``en``, ``zh``
        or ``ta`` (matching is case- and whitespace-insensitive).
        """
        label_code = raw_label.strip().lower()

        # Direct mapping from VoxLingua codes to canonical names
        vox_to_canonical = {
            "ms": "malay",
            "en": "english",
            "zh": "mandarin",
            "ta": "tamil",
        }

        return vox_to_canonical.get(label_code)

    def extract_audio_files_from_zip(self, zip_path: str, extract_dir: str) -> List[Path]:
        """Extract and return list of audio files from a zip archive.

        The archive is unpacked into ``extract_dir/<zip stem>``, which is
        wiped first if it already exists. Files are matched on their
        extension case-insensitively and returned sorted.
        """
        temp_extract = Path(extract_dir) / Path(zip_path).stem
        if temp_extract.exists():
            shutil.rmtree(temp_extract)
        temp_extract.mkdir(parents=True)

        with zipfile.ZipFile(zip_path, 'r') as z:
            z.extractall(temp_extract)

        # Find all audio files
        return sorted(
            p for p in temp_extract.rglob("*")
            if p.is_file() and p.suffix.lower() in self._AUDIO_SUFFIXES
        )

    def _collect_training_files(self, drive_base: str, temp_dir: Path) -> Dict[str, List[Path]]:
        """Unpack the training archives and group their audio files by language."""
        data_root = f"{drive_base}/MyDrive/language_identification/training_data"
        language_files = {"iban": [], "bukar_sadong": [], "malay": []}

        # Process Iban data (from two sources)
        print("\nProcessing Iban data...")
        for zip_name in ("github_iban_filter_train.zip", "gkalaka_iban_filter_train.zip"):
            zip_path = f"{data_root}/{zip_name}"
            if not os.path.exists(zip_path):
                print(f"Warning: {zip_path} not found, skipping")
                continue
            print(f"Extracting {Path(zip_path).name}...")
            audio_files = self.extract_audio_files_from_zip(zip_path, str(temp_dir))
            language_files["iban"].extend(audio_files)
            print(f"Found {len(audio_files)} files")

        # Process Malay data
        print("\nProcessing Malay data...")
        malay_zip = f"{data_root}/malay_train.zip"
        if os.path.exists(malay_zip):
            audio_files = self.extract_audio_files_from_zip(malay_zip, str(temp_dir))
            language_files["malay"].extend(audio_files)
            print(f"Found {len(audio_files)} Malay files")
        else:
            print(f"Warning: {malay_zip} not found, skipping")

        # Process Bukar Sadong data
        print("\nProcessing Bukar Sadong data...")
        bukar_zip = f"{data_root}/bukar_sadong_train.zip"
        if os.path.exists(bukar_zip):
            audio_files = self.extract_audio_files_from_zip(bukar_zip, str(temp_dir))
            language_files["bukar_sadong"].extend(audio_files)
            print(f"Found {len(audio_files)} Bukar Sadong files")
        else:
            print(f"Warning: {bukar_zip} not found, skipping")

        return language_files

    def _embed_training_files(self, language_files: Dict[str, List[Path]]) -> Tuple[np.ndarray, np.ndarray]:
        """Embed every training file; files that fail to process are reported and skipped."""
        all_embeddings = []
        all_labels = []

        for lang, files in language_files.items():
            label = self.config.label_map[lang]
            print(f"\nExtracting embeddings for {lang}: {len(files)} files")
            for i, audio_file in enumerate(files):
                if i % 100 == 0:
                    print(f"Processing {lang}: {i}/{len(files)}")
                try:
                    emb = self.extract_embedding(str(audio_file))
                except Exception as e:
                    # Best effort: one unreadable file must not abort training.
                    print(f"Error processing {audio_file}: {e}")
                    continue
                all_embeddings.append(emb)
                all_labels.append(label)

        if not all_embeddings:
            raise ValueError("No training data collected")

        return np.array(all_embeddings), np.array(all_labels)

    def train_custom_classifier(self, drive_base: str = "/content/drive"):
        """Train custom classifier for Iban/Bukar Sadong/Malay.

        Args:
            drive_base: Mount point under which the training zips live at
                ``MyDrive/language_identification/training_data``.

        Returns:
            The fitted ``LogisticRegression`` (also stored on
            ``self.custom_classifier``).

        Raises:
            RuntimeError: If the VoxLingua107 model has not been loaded.
            ValueError: If no audio file could be embedded.
        """
        # Without the encoder every file would fail one by one; fail up front.
        self._require_vox_model()
        print("Training custom 3-language classifier...")

        class_names = ["iban", "bukar_sadong", "malay"]
        class_ids = [self.config.label_map[name] for name in class_names]

        # Temporary extraction directory
        temp_dir = Path("/tmp/training_data")
        if temp_dir.exists():
            shutil.rmtree(temp_dir)
        temp_dir.mkdir(parents=True)

        try:
            language_files = self._collect_training_files(drive_base, temp_dir)
            X, y = self._embed_training_files(language_files)
        finally:
            # Embeddings are in memory now; remove the extracted audio even
            # when extraction or embedding failed.
            shutil.rmtree(temp_dir, ignore_errors=True)

        print(f"\nTotal samples collected:")
        print(f"Iban: {np.sum(y == class_ids[0])}")
        print(f"Bukar Sadong: {np.sum(y == class_ids[1])}")
        print(f"Malay: {np.sum(y == class_ids[2])}")

        # Stratified split: each language contributes test_size of its samples
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=self.config.test_size,
            stratify=y, random_state=self.config.random_state
        )

        print(f"\nTraining set distribution:")
        for lang, class_id in zip(class_names, class_ids):
            print(f"{lang}: {np.sum(y_train == class_id)}")

        # Oversample every class except the largest so the training set is
        # balanced (the original author noted an imbalance of 48 vs 2895).
        ros = RandomOverSampler(
            sampling_strategy='not majority',
            random_state=self.config.random_state
        )
        X_train_balanced, y_train_balanced = ros.fit_resample(X_train, y_train)

        print(f"\nAfter oversampling:")
        for lang, class_id in zip(class_names, class_ids):
            print(f"{lang}: {np.sum(y_train_balanced == class_id)}")

        # Train classifier
        self.custom_classifier = LogisticRegression(
            max_iter=self.config.max_iter,
            random_state=self.config.random_state,
            class_weight='balanced'  # Additional balancing
        )
        self.custom_classifier.fit(X_train_balanced, y_train_balanced)

        # Evaluate. Passing labels= keeps the report and matrix three-class
        # even when a language is absent from the data.
        y_pred = self.custom_classifier.predict(X_test)
        print("\n" + "="*60)
        print("Custom Classifier Performance:")
        print("="*60)
        print(classification_report(y_test, y_pred, labels=class_ids,
                                    target_names=class_names))

        print("\nConfusion Matrix:")
        cm = confusion_matrix(y_test, y_pred, labels=class_ids)
        print(" Iban Bukar Malay")
        for label, row in zip(["Iban ", "Bukar ", "Malay "], cm):
            print(f"{label} {row}")

        return self.custom_classifier

    @torch.no_grad()
    def predict_vox(self, audio: Union[str, torch.Tensor]) -> Tuple[str, float, List]:
        """Predict using VoxLingua107 for major languages.

        Returns:
            ``(language, confidence, top_results)``. ``language`` is the
            canonical name when the top code is one of ms/en/zh/ta and the
            raw VoxLingua code otherwise. ``top_results`` lists up to five
            ``(raw label, probability)`` pairs.

        Raises:
            RuntimeError: If the VoxLingua107 model has not been loaded.
        """
        self._require_vox_model()
        wav = self._prepare_waveform(audio)

        # Get predictions
        output = self.vox_model.classify_batch(wav)
        logits = output[0] if isinstance(output, tuple) else output
        logits = logits.squeeze(0).detach().cpu()

        # Convert to probabilities. Scores whose maximum is <= 1.0 are
        # treated as log-probabilities and renormalised; anything else is
        # used as-is (heuristic carried over from the original code).
        if logits.max().item() <= 1.0:
            probs = logits.exp()
            probs = probs / probs.sum()
        else:
            probs = logits

        # Get top prediction
        top_prob, top_idx = torch.max(probs, dim=0)
        top_prob = float(top_prob.item())

        # Decode label; keep only the short code before any ":".
        raw_label = self._decode_label(int(top_idx))
        raw_label = raw_label.split(":")[0].strip().lower()

        # Get canonical name
        canonical = self.normalize_language_label(raw_label)

        # Get top-5 for debugging
        topk = torch.topk(probs, k=min(5, probs.shape[0]))
        top_results = [
            (self._decode_label(int(idx)), float(prob))
            for prob, idx in zip(topk.values.tolist(), topk.indices.tolist())
        ]

        return canonical if canonical else raw_label, top_prob, top_results

    def _custom_scores(self, audio: Union[str, torch.Tensor]) -> Dict[str, float]:
        """Return the custom classifier's probability for every language in ``label_map``.

        ``predict_proba`` columns follow ``classes_``, not the label values,
        so they are re-keyed through ``classes_``. A language the classifier
        never saw during training gets probability 0.0.
        """
        if self.custom_classifier is None:
            raise RuntimeError(
                "Custom classifier is not trained; call train_custom_classifier() first"
            )
        emb = self.extract_embedding(audio)
        proba = self.custom_classifier.predict_proba([emb])[0]
        by_label = {
            int(class_id): float(p)
            for class_id, p in zip(self.custom_classifier.classes_, proba)
        }
        return {
            name: by_label.get(class_id, 0.0)
            for name, class_id in self.config.label_map.items()
        }

    def predict_custom(self, audio: Union[str, torch.Tensor]) -> Tuple[str, float]:
        """Predict using custom Iban/Bukar Sadong/Malay classifier.

        Returns:
            ``(language, probability)`` of the most likely language.

        Raises:
            RuntimeError: If the classifier or the VoxLingua107 model is missing.
        """
        scores = self._custom_scores(audio)
        best = max(scores, key=scores.get)
        return best, scores[best]

    def predict(self, audio: Union[str, torch.Tensor]) -> Dict:
        """Main prediction method combining both classifiers.

        Returns a dict with at least ``language``, ``confidence``, ``source``
        and ``debug``. Results that involved the custom classifier also carry
        ``reason`` plus comparison details.
        """
        # Decode a file path once so both models share the same waveform.
        if isinstance(audio, str):
            audio = self.audio_processor.load_audio(audio)

        # First, get VoxLingua107 prediction
        vox_lang, vox_score, top_results = self.predict_vox(audio)

        # Check if VoxLingua predicted one of the 4 major languages
        major_languages = ("english", "mandarin", "tamil", "malay")

        # Condition 1: If not a major language, pass to custom classifier
        if vox_lang not in major_languages:
            custom_lang, custom_score = self.predict_custom(audio)
            return {
                'language': custom_lang,
                'confidence': custom_score,
                'source': 'custom_classifier',
                'reason': 'non_major_language',
                'vox_initial': {'language': vox_lang, 'confidence': vox_score},
                'debug': {'vox_top_5': top_results}
            }

        # Condition 2: If VoxLingua predicts Malay, compare with custom classifier
        if vox_lang == "malay":
            # One embedding + one predict_proba call serves both the decision
            # and the per-language scores reported below.
            custom_scores = self._custom_scores(audio)
            custom_lang = max(custom_scores, key=custom_scores.get)
            custom_score = custom_scores[custom_lang]

            # Compare scores and take the higher confidence prediction
            if custom_score > vox_score:
                return {
                    'language': custom_lang,
                    'confidence': custom_score,
                    'source': 'custom_classifier',
                    'reason': 'higher_confidence',
                    'vox_initial': {'language': vox_lang, 'confidence': vox_score},
                    'custom_scores': custom_scores,
                    'debug': {'vox_top_5': top_results}
                }

            # VoxLingua has higher (or equal) confidence, keep Malay.
            # NOTE: the debug key is 'top_5' here but 'vox_top_5' above; the
            # original key names are kept for existing callers.
            return {
                'language': 'malay',
                'confidence': vox_score,
                'source': 'voxlingua107',
                'reason': 'higher_confidence',
                'custom_comparison': {'language': custom_lang, 'confidence': custom_score},
                'custom_scores': custom_scores,
                'debug': {'top_5': top_results}
            }

        # For English, Mandarin, Tamil - use VoxLingua result directly
        return {
            'language': vox_lang,
            'confidence': vox_score,
            'source': 'voxlingua107',
            'debug': {'top_5': top_results}
        }
359
+
360
+
361
class Evaluator:
    """Evaluate performance on test datasets"""

    # Extensions treated as audio when scanning an extracted archive.
    _AUDIO_SUFFIXES = (".wav", ".mp3")

    def __init__(self, identifier: "LanguageIdentifier"):
        self.identifier = identifier

    @staticmethod
    def _format_prediction(audio_file: Path, pred: Dict, true_label: Optional[str]) -> str:
        """Build the one-line verbose report for a single prediction."""
        status = ""
        if true_label:
            status = "✓" if pred['language'] == true_label else "✗"

        # Build detailed output string
        output_str = f"{audio_file.name:<30} → {pred['language']:<12} [{pred['confidence']:.3f}]"

        # Add source and reason if available
        if 'reason' in pred:
            output_str += f" via {pred['source']:<20} (reason: {pred['reason']})"
        else:
            output_str += f" via {pred['source']:<20}"

        # Add comparison info if available
        if 'custom_comparison' in pred:
            comp = pred['custom_comparison']
            output_str += f" [vs {comp['language']}:{comp['confidence']:.3f}]"
        elif 'vox_initial' in pred:
            vox = pred['vox_initial']
            output_str += f" [vox:{vox['language']}:{vox['confidence']:.3f}]"

        return f"{output_str} {status}"

    def test_zip_file(self, zip_path: str, true_label: Optional[str] = None,
                      verbose: bool = True) -> Dict:
        """Test on a zip file containing audio files.

        The archive is unpacked into ``/tmp/test_<zip stem>``, every audio
        file is run through ``identifier.predict`` and the directory is
        removed again, including when no audio is found or an error occurs.

        Args:
            zip_path: Archive containing the audio files.
            true_label: Expected language for every file; enables the
                accuracy computation and the per-file check mark.
            verbose: Print one line per processed file.

        Returns:
            A dict with ``total``, ``results``, ``source_counts``,
            ``language_counts``, ``reason_counts`` and ``accuracy``
            (``None`` without ``true_label``), or ``{}`` when the archive
            holds no audio file. Files whose prediction raised are reported
            and left out of ``results``, so accuracy covers only the files
            that were processed successfully.
        """
        # Extract files
        extract_dir = Path(f"/tmp/test_{Path(zip_path).stem}")
        if extract_dir.exists():
            shutil.rmtree(extract_dir)
        extract_dir.mkdir(parents=True)

        results = []
        source_counts = Counter()
        language_counts = Counter()
        reason_counts = Counter()

        try:
            with zipfile.ZipFile(zip_path, 'r') as z:
                z.extractall(extract_dir)

            # Find all audio files (extension match is case-insensitive)
            audio_files = sorted(
                p for p in extract_dir.rglob("*")
                if p.is_file() and p.suffix.lower() in self._AUDIO_SUFFIXES
            )

            if not audio_files:
                print(f"No audio files found in {zip_path}")
                return {}

            for audio_file in audio_files:
                try:
                    pred = self.identifier.predict(str(audio_file))
                    results.append(pred)
                    source_counts[pred['source']] += 1
                    language_counts[pred['language']] += 1
                    if 'reason' in pred:
                        reason_counts[pred['reason']] += 1

                    if verbose:
                        print(self._format_prediction(audio_file, pred, true_label))

                except Exception as e:
                    # Best effort: report the file and carry on with the rest.
                    print(f"Error processing {audio_file.name}: {e}")
        finally:
            # Cleanup on every exit path, including the early return above.
            shutil.rmtree(extract_dir, ignore_errors=True)

        # Calculate accuracy if true label provided
        accuracy = None
        if true_label:
            correct = sum(1 for r in results if r['language'] == true_label)
            accuracy = correct / len(results) if results else 0
            print(f"\nAccuracy for '{true_label}': {accuracy:.1%} ({correct}/{len(results)})")

        print(f"\nSource usage: {dict(source_counts)}")
        print(f"Language predictions: {dict(language_counts)}")
        if reason_counts:
            print(f"Decision reasons: {dict(reason_counts)}")

        return {
            'total': len(results),
            'results': results,
            'source_counts': dict(source_counts),
            'language_counts': dict(language_counts),
            'reason_counts': dict(reason_counts),
            'accuracy': accuracy
        }