Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import torchaudio | |
| import numpy as np | |
| import pandas as pd | |
| from pathlib import Path | |
| from typing import Tuple, Dict, List, Optional, Union | |
| from dataclasses import dataclass | |
| from collections import Counter | |
| import zipfile | |
| import shutil | |
| from speechbrain.pretrained import EncoderClassifier | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.metrics import classification_report, confusion_matrix | |
| from imblearn.over_sampling import RandomOverSampler | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
| import torch | |
| import torchaudio | |
| import soundfile as sf | |
@dataclass
class Config:
    """Configuration for the language identification pipeline.

    The class must be a dataclass: the fields below are declared with
    annotations and defaults, and ``__post_init__`` is only invoked by the
    dataclass-generated ``__init__``. Without the decorator the two mapping
    fields would stay ``None`` and keyword construction would fail.

    Attributes:
        target_sample_rate: Sample rate (Hz) audio is resampled to.
        embedding_dim: Nominal embedding size (not read by the code in this
            file).
        test_size: Fraction of samples held out when training the custom
            classifier.
        random_state: Seed shared by the split, the oversampler and the
            classifier.
        max_iter: Iteration cap for the logistic-regression classifier.
        label_map: Language name -> integer class id for the custom
            classifier. Filled with the Iban/Bukar Sadong/Malay mapping when
            left as ``None``.
        canonical_languages: Canonical names of the major languages. Filled
            with the default list when left as ``None``.
    """

    target_sample_rate: int = 16000
    embedding_dim: int = 256
    test_size: float = 0.2
    random_state: int = 42
    max_iter: int = 1000
    # Language mappings for custom classifier. ``None`` is a placeholder so
    # that every instance gets its own fresh dict/list in __post_init__
    # (a mutable literal cannot be used as a dataclass default).
    label_map: Optional[Dict[str, int]] = None
    canonical_languages: Optional[List[str]] = None

    def __post_init__(self):
        """Populate the default language mappings unless the caller set them."""
        if self.label_map is None:
            # Now includes malay in the custom classifier
            self.label_map = {"iban": 0, "bukar_sadong": 1, "malay": 2}
        if self.canonical_languages is None:
            self.canonical_languages = ["malay", "english", "mandarin", "tamil"]
class AudioProcessor:
    """Handles audio loading and preprocessing.

    Audio is returned as a mono ``float32`` tensor laid out as
    ``(1, time)`` at ``target_sr`` Hz.
    """

    def __init__(self, target_sr: int = 16000):
        """Store the sample rate (Hz) every loaded clip is resampled to."""
        self.target_sr = target_sr

    @staticmethod
    def _channels_first(samples) -> torch.Tensor:
        """Convert soundfile-style samples to a ``(channels, time)`` tensor.

        ``soundfile.read`` yields a 1-D ``(time,)`` array for mono files and a
        2-D ``(time, channels)`` array otherwise. A plain transpose leaves the
        1-D case untouched, which would later be mistaken for "many channels"
        and averaged down to a single sample, so mono input gets an explicit
        channel axis instead.
        """
        tensor = torch.as_tensor(samples)
        if tensor.dim() == 1:
            return tensor.unsqueeze(0)
        return tensor.T

    @staticmethod
    def _to_mono(signal: torch.Tensor) -> torch.Tensor:
        """Average the channel axis of a ``(channels, time)`` tensor.

        A 1-D input is treated as a single channel. Single-channel input is
        returned without modification.
        """
        if signal.dim() == 1:
            signal = signal.unsqueeze(0)
        if signal.shape[0] > 1:
            signal = signal.mean(dim=0, keepdim=True)
        return signal

    def load_audio(self, path: str) -> torch.Tensor:
        """Load an audio file as a mono float32 tensor at ``target_sr``.

        ``torchaudio`` is tried first; when it raises ``RuntimeError`` the
        file is read with ``soundfile`` instead.

        Args:
            path: Path of the audio file to read.

        Returns:
            Tensor of shape ``(1, time)`` and dtype ``torch.float32``.
        """
        try:
            signal, sr = torchaudio.load(path)
        except RuntimeError as e:
            print(f"[WARN] torchaudio failed: {e}. Falling back to soundfile.")
            samples, sr = sf.read(path, dtype="float32")
            signal = self._channels_first(samples)  # (channels, time)

        # Convert to mono
        signal = self._to_mono(signal)

        # Resample if needed
        if sr != self.target_sr:
            resampler = torchaudio.transforms.Resample(sr, self.target_sr)
            signal = resampler(signal)

        return signal.to(torch.float32)
class LanguageIdentifier:
    """Main language identification system.

    Two classifiers are combined:

    * the pretrained SpeechBrain VoxLingua107 model, used for the major
      languages (English, Mandarin, Tamil, Malay); and
    * a logistic-regression classifier trained on VoxLingua107 embeddings
      that separates Iban, Bukar Sadong and Malay.

    ``load_vox_model`` must be called before any embedding or prediction is
    requested, and ``train_custom_classifier`` before the custom classifier
    is consulted.
    """

    # VoxLingua107 short codes mapped to canonical language names.
    _VOX_TO_CANONICAL = {
        "ms": "malay",
        "en": "english",
        "zh": "mandarin",
        "ta": "tamil",
    }

    # Languages for which the VoxLingua107 verdict is considered.
    _MAJOR_LANGUAGES = ("english", "mandarin", "tamil", "malay")

    # (language, archive file name) pairs, in the order they are extracted.
    _TRAINING_ARCHIVES = (
        ("iban", "github_iban_filter_train.zip"),
        ("iban", "gkalaka_iban_filter_train.zip"),
        ("malay", "malay_train.zip"),
        ("bukar_sadong", "bukar_sadong_train.zip"),
    )

    def __init__(self, config: Optional["Config"] = None):
        """Create an identifier; models are loaded/trained separately.

        Args:
            config: Pipeline configuration. A default ``Config`` is created
                when omitted.
        """
        self.config = config or Config()
        self.audio_processor = AudioProcessor(self.config.target_sample_rate)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Initialize models
        self.vox_model = None
        self.custom_classifier = None
        self.label_encoder = None

    # ------------------------------------------------------------------
    # Model loading and low-level helpers
    # ------------------------------------------------------------------
    def load_vox_model(self, model_path: str = None):
        """Load SpeechBrain VoxLingua107 model.

        Args:
            model_path: Directory used to cache the downloaded model.
                Defaults to ``pretrained_models/lang-id-voxlingua107-ecapa``.
        """
        source = "speechbrain/lang-id-voxlingua107-ecapa"
        savedir = model_path or "pretrained_models/lang-id-voxlingua107-ecapa"
        self.vox_model = EncoderClassifier.from_hparams(
            source=source,
            savedir=savedir,
            # SpeechBrain documents the device run option as a string
            # ("cpu", "cuda", "cuda:0"), so do not hand it a torch.device.
            run_opts={"device": str(self.device)},
        )
        self.label_encoder = self.vox_model.hparams.label_encoder
        print(f"VoxLingua107 model loaded on {self.device}")

    def _require_vox_model(self) -> None:
        """Raise a clear error when the VoxLingua107 model is not loaded."""
        if self.vox_model is None:
            raise RuntimeError(
                "VoxLingua107 model is not loaded; call load_vox_model() first"
            )

    def _prepare_waveform(self, audio: Union[str, torch.Tensor]) -> torch.Tensor:
        """Return ``audio`` as a batched float32 tensor on ``self.device``.

        A string is treated as a file path and loaded through the audio
        processor; a 1-D tensor gets a leading batch dimension.
        """
        wav = self.audio_processor.load_audio(audio) if isinstance(audio, str) else audio
        # Ensure batch dimension
        if wav.dim() == 1:
            wav = wav.unsqueeze(0)
        return wav.to(self.device, dtype=torch.float32)

    @staticmethod
    def _to_probabilities(scores: torch.Tensor) -> torch.Tensor:
        """Turn a 1-D score vector into a probability distribution.

        A vector that already is a distribution (non-negative, summing to
        one) is returned unchanged. Anything else (log-probabilities or raw
        logits) is passed through a softmax, which for log-probabilities is
        the same as exponentiating and renormalising.
        """
        already_distribution = (
            bool((scores >= 0).all())
            and abs(float(scores.sum()) - 1.0) < 1e-3
        )
        if already_distribution:
            return scores
        return torch.softmax(scores, dim=0)

    def _decode_label(self, index: int):
        """Look up the label for a class index in the label encoder.

        ``ind2lab`` is tried first, with ``decode_ndim`` as a fallback for
        encoders where that lookup is unavailable.
        """
        try:
            return self.label_encoder.ind2lab[index]
        except (AttributeError, KeyError, IndexError, TypeError):
            return self.label_encoder.decode_ndim(index)

    def extract_embedding(self, audio: Union[str, torch.Tensor]) -> np.ndarray:
        """Extract embedding from audio using VoxLingua107.

        Args:
            audio: Path to an audio file, or an already loaded waveform.

        Returns:
            1-D numpy array holding the flattened embedding of the clip.

        Raises:
            RuntimeError: If the VoxLingua107 model has not been loaded.
        """
        self._require_vox_model()
        wav = self._prepare_waveform(audio)
        with torch.no_grad():
            embedding = self.vox_model.encode_batch(wav)
        if isinstance(embedding, tuple):
            embedding = embedding[0]
        # Flatten to 1D array (reshape also copes with non-contiguous output)
        embedding = embedding.reshape(embedding.size(0), -1).squeeze(0)
        return embedding.cpu().numpy()

    def normalize_language_label(self, raw_label: str) -> Optional[str]:
        """Map VoxLingua107 short codes to canonical language names.

        Returns ``None`` for codes outside the four major languages.
        """
        return self._VOX_TO_CANONICAL.get(raw_label.strip().lower())

    def extract_audio_files_from_zip(self, zip_path: str, extract_dir: str) -> List[Path]:
        """Extract and return list of audio files from a zip archive.

        The archive is unpacked into ``extract_dir/<archive stem>``; an
        existing directory of that name is removed first.

        Returns:
            Sorted paths of every ``.wav`` and ``.mp3`` file found.
        """
        temp_extract = Path(extract_dir) / Path(zip_path).stem
        if temp_extract.exists():
            shutil.rmtree(temp_extract)
        temp_extract.mkdir(parents=True)
        with zipfile.ZipFile(zip_path, 'r') as z:
            z.extractall(temp_extract)
        # Find all audio files
        audio_files = []
        for ext in ['*.wav', '*.mp3']:
            audio_files.extend(temp_extract.rglob(ext))
        return sorted(audio_files)

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    def _collect_training_files(self, drive_base: str, temp_dir: Path) -> Dict[str, List[Path]]:
        """Unpack every available training archive and group files by language."""
        data_root = f"{drive_base}/MyDrive/language_identification/training_data"
        # Key order fixes the order in which embeddings are later extracted.
        language_files = {"iban": [], "bukar_sadong": [], "malay": []}
        for lang, archive_name in self._TRAINING_ARCHIVES:
            zip_path = f"{data_root}/{archive_name}"
            print(f"\nProcessing {lang} archive {archive_name}...")
            if not os.path.exists(zip_path):
                # A missing archive is tolerated, but say so instead of
                # silently training on less data.
                print(f"[WARN] Training archive not found, skipping: {zip_path}")
                continue
            found = self.extract_audio_files_from_zip(zip_path, str(temp_dir))
            language_files[lang].extend(found)
            print(f"Found {len(found)} {lang} files")
        return language_files

    def _embed_training_files(self, language_files: Dict[str, List[Path]]) -> Tuple[np.ndarray, np.ndarray]:
        """Extract one embedding per file; return the ``(X, y)`` arrays.

        Files that fail to process are reported and skipped.

        Raises:
            ValueError: If no embedding at all could be extracted.
        """
        embeddings = []
        labels = []
        for lang, files in language_files.items():
            label_id = self.config.label_map[lang]
            print(f"\nExtracting embeddings for {lang}: {len(files)} files")
            for i, audio_file in enumerate(files):
                if i % 100 == 0:
                    print(f"Processing {lang}: {i}/{len(files)}")
                try:
                    emb = self.extract_embedding(str(audio_file))
                except Exception as e:  # best effort: one bad file must not abort training
                    print(f"Error processing {audio_file}: {e}")
                    continue
                embeddings.append(emb)
                labels.append(label_id)
        if not embeddings:
            raise ValueError("No training data collected")
        return np.array(embeddings), np.array(labels)

    def train_custom_classifier(self, drive_base: str = "/content/drive"):
        """Train custom classifier for Iban/Bukar Sadong/Malay.

        Training archives are read from
        ``<drive_base>/MyDrive/language_identification/training_data``,
        unpacked under ``/tmp/training_data`` (always removed afterwards),
        embedded with VoxLingua107, split into train/test sets, oversampled
        and fitted with logistic regression. A report on the held-out split
        is printed.

        Args:
            drive_base: Mount point of the drive holding the archives.

        Returns:
            The fitted classifier (also stored on ``self.custom_classifier``).

        Raises:
            RuntimeError: If the VoxLingua107 model has not been loaded.
            ValueError: If no training sample could be collected.
        """
        self._require_vox_model()
        print("Training custom 3-language classifier...")

        # Temporary extraction directory
        temp_dir = Path("/tmp/training_data")
        if temp_dir.exists():
            shutil.rmtree(temp_dir)
        temp_dir.mkdir(parents=True)
        try:
            language_files = self._collect_training_files(drive_base, temp_dir)
            X, y = self._embed_training_files(language_files)
        finally:
            # The audio is no longer needed once the embeddings exist, and
            # must not be left behind when extraction fails.
            shutil.rmtree(temp_dir, ignore_errors=True)

        # Class names/ids in id order, used for every report below.
        ordered = sorted(self.config.label_map.items(), key=lambda item: item[1])
        label_names = [name for name, _ in ordered]
        label_ids = [label_id for _, label_id in ordered]

        print(f"\nTotal samples collected:")
        for name, label_id in ordered:
            print(f"{name.replace('_', ' ').title()}: {np.sum(y == label_id)}")

        # Stratified split keeps the per-language proportions in both parts
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=self.config.test_size,
            stratify=y, random_state=self.config.random_state
        )
        print(f"\nTraining set distribution:")
        for name, label_id in ordered:
            print(f"{name}: {np.sum(y_train == label_id)}")

        # Oversample every class except the largest one up to the size of
        # the largest one, so the training set ends up fully balanced.
        ros = RandomOverSampler(
            sampling_strategy='not majority',
            random_state=self.config.random_state
        )
        X_train_balanced, y_train_balanced = ros.fit_resample(X_train, y_train)
        print(f"\nAfter oversampling:")
        for name, label_id in ordered:
            print(f"{name}: {np.sum(y_train_balanced == label_id)}")

        # Train classifier
        self.custom_classifier = LogisticRegression(
            max_iter=self.config.max_iter,
            random_state=self.config.random_state,
            class_weight='balanced'  # Additional balancing
        )
        self.custom_classifier.fit(X_train_balanced, y_train_balanced)

        # Evaluate. Passing ``labels`` keeps the report and the matrix aligned
        # with the class names even when a class is absent from the test split.
        y_pred = self.custom_classifier.predict(X_test)
        print("\n" + "="*60)
        print("Custom Classifier Performance:")
        print("="*60)
        print(classification_report(y_test, y_pred, labels=label_ids,
                                    target_names=label_names, zero_division=0))
        print("\nConfusion Matrix:")
        cm = confusion_matrix(y_test, y_pred, labels=label_ids)
        short_names = [name.split("_")[0].title() for name in label_names]
        print(" " + " ".join(short_names))
        for short_name, row in zip(short_names, cm):
            print(f"{short_name}  {row}")

        return self.custom_classifier

    # ------------------------------------------------------------------
    # Prediction
    # ------------------------------------------------------------------
    def predict_vox(self, audio: Union[str, torch.Tensor]) -> Tuple[str, float, List]:
        """Predict using VoxLingua107 for major languages.

        Args:
            audio: Path to an audio file, or an already loaded waveform.

        Returns:
            ``(language, probability, top_results)`` where ``language`` is the
            canonical name when the top class is one of the four major
            languages and the lower-cased VoxLingua107 code otherwise, and
            ``top_results`` lists up to five ``(label, probability)`` pairs.

        Raises:
            RuntimeError: If the VoxLingua107 model has not been loaded.
        """
        self._require_vox_model()
        wav = self._prepare_waveform(audio)

        # Get predictions
        with torch.no_grad():
            output = self.vox_model.classify_batch(wav)
        scores = output[0] if isinstance(output, tuple) else output
        scores = scores.squeeze(0).detach().cpu()

        # Convert to probabilities
        probs = self._to_probabilities(scores)

        # Get top prediction
        top_prob, top_idx = torch.max(probs, dim=0)
        top_prob = float(top_prob.item())

        # Decode label; keep only the code in front of the ":" separator
        raw_label = self._decode_label(int(top_idx))
        raw_label = raw_label.split(":")[0].strip().lower()

        # Get canonical name
        canonical = self.normalize_language_label(raw_label)

        # Get top-5 for debugging
        topk = torch.topk(probs, k=min(5, probs.shape[0]))
        top_results = [
            (self._decode_label(int(idx)), float(prob))
            for prob, idx in zip(topk.values.tolist(), topk.indices.tolist())
        ]
        return canonical if canonical else raw_label, top_prob, top_results

    def _custom_scores(self, audio: Union[str, torch.Tensor]) -> Dict[str, float]:
        """Return the custom classifier's probability for every language.

        Probabilities are matched to languages through the classifier's
        ``classes_`` attribute rather than by column position, so a class
        that was absent at training time is reported as ``0.0`` instead of
        shifting the other columns.

        Raises:
            RuntimeError: If the custom classifier has not been trained.
        """
        if self.custom_classifier is None:
            raise RuntimeError(
                "Custom classifier is not trained; call train_custom_classifier() first"
            )
        emb = self.extract_embedding(audio)
        proba = self.custom_classifier.predict_proba([emb])[0]
        ordered = sorted(self.config.label_map.items(), key=lambda item: item[1])
        name_of = {label_id: name for name, label_id in ordered}
        scores = {name: 0.0 for name, _ in ordered}
        for class_id, probability in zip(self.custom_classifier.classes_, proba):
            scores[name_of[int(class_id)]] = float(probability)
        return scores

    def predict_custom(self, audio: Union[str, torch.Tensor]) -> Tuple[str, float]:
        """Predict using custom Iban/Bukar Sadong/Malay classifier.

        Returns:
            ``(language, probability)`` of the most likely language.

        Raises:
            RuntimeError: If a required model is not loaded/trained.
        """
        scores = self._custom_scores(audio)
        best = max(scores, key=scores.get)
        return best, scores[best]

    def predict(self, audio: Union[str, torch.Tensor]) -> Dict:
        """Main prediction method combining both classifiers.

        Decision rules:

        1. VoxLingua107 predicts something other than a major language:
           the custom classifier decides.
        2. VoxLingua107 predicts Malay: the custom classifier is consulted
           too and the more confident of the two wins.
        3. VoxLingua107 predicts English, Mandarin or Tamil: that result is
           used directly.

        Args:
            audio: Path to an audio file, or an already loaded waveform.

        Returns:
            Dict with at least ``language``, ``confidence``, ``source`` and
            ``debug``; the first two rules additionally report ``reason`` and
            the competing prediction.
        """
        # Decode the file once; both classifiers work on the same waveform.
        if isinstance(audio, str):
            audio = self.audio_processor.load_audio(audio)

        # First, get VoxLingua107 prediction
        vox_lang, vox_score, top_results = self.predict_vox(audio)

        # Condition 1: If not a major language, pass to custom classifier
        if vox_lang not in self._MAJOR_LANGUAGES:
            custom_lang, custom_score = self.predict_custom(audio)
            return {
                'language': custom_lang,
                'confidence': custom_score,
                'source': 'custom_classifier',
                'reason': 'non_major_language',
                'vox_initial': {'language': vox_lang, 'confidence': vox_score},
                'debug': {'vox_top_5': top_results}
            }

        # Condition 2: If VoxLingua predicts Malay, compare with custom classifier
        if vox_lang == "malay":
            # One embedding + one predict_proba call serves both the
            # decision and the per-language score report.
            custom_scores = self._custom_scores(audio)
            custom_lang = max(custom_scores, key=custom_scores.get)
            custom_score = custom_scores[custom_lang]

            # Compare scores and take the higher confidence prediction
            if custom_score > vox_score:
                # Custom classifier has higher confidence
                return {
                    'language': custom_lang,
                    'confidence': custom_score,
                    'source': 'custom_classifier',
                    'reason': 'higher_confidence',
                    'vox_initial': {'language': vox_lang, 'confidence': vox_score},
                    'custom_scores': custom_scores,
                    'debug': {'vox_top_5': top_results}
                }
            # VoxLingua has higher confidence, keep Malay
            return {
                'language': 'malay',
                'confidence': vox_score,
                'source': 'voxlingua107',
                'reason': 'higher_confidence',
                'custom_comparison': {'language': custom_lang, 'confidence': custom_score},
                'custom_scores': custom_scores,
                'debug': {'top_5': top_results}
            }

        # For English, Mandarin, Tamil - use VoxLingua result directly
        return {
            'language': vox_lang,
            'confidence': vox_score,
            'source': 'voxlingua107',
            'debug': {'top_5': top_results}
        }
class Evaluator:
    """Evaluate performance on test datasets."""

    def __init__(self, identifier: "LanguageIdentifier"):
        """Wrap the identifier whose ``predict`` method is evaluated."""
        self.identifier = identifier

    @staticmethod
    def _format_prediction(file_name: str, pred: Dict, true_label: Optional[str]) -> str:
        """Build the one-line verbose report for a single prediction."""
        status = ""
        if true_label:
            status = "✓" if pred['language'] == true_label else "✗"

        # Build detailed output string
        output_str = f"{file_name:<30} → {pred['language']:<12} [{pred['confidence']:.3f}]"

        # Add source and reason if available
        output_str += f" via {pred['source']:<20}"
        if 'reason' in pred:
            output_str += f" (reason: {pred['reason']})"

        # Add comparison info if available
        if 'custom_comparison' in pred:
            comp = pred['custom_comparison']
            output_str += f" [vs {comp['language']}:{comp['confidence']:.3f}]"
        elif 'vox_initial' in pred:
            vox = pred['vox_initial']
            output_str += f" [vox:{vox['language']}:{vox['confidence']:.3f}]"
        return f"{output_str} {status}"

    def test_zip_file(self, zip_path: str, true_label: Optional[str] = None,
                      verbose: bool = True) -> Dict:
        """Test on a zip file containing audio files.

        The archive is unpacked into ``/tmp/test_<archive stem>``; that
        directory is removed again on every exit path, including the
        "no audio found" early return and extraction failures.

        Args:
            zip_path: Archive holding ``.wav`` / ``.mp3`` files.
            true_label: Expected language of every file; enables the
                accuracy computation and the per-file ✓/✗ marker.
            verbose: Print one line per processed file.

        Returns:
            ``{}`` when the archive holds no audio files; otherwise a dict
            with ``total``, ``results``, ``source_counts``,
            ``language_counts``, ``reason_counts`` and ``accuracy``
            (``None`` when ``true_label`` is not given).
        """
        # Extract files
        extract_dir = Path(f"/tmp/test_{Path(zip_path).stem}")
        if extract_dir.exists():
            shutil.rmtree(extract_dir)
        extract_dir.mkdir(parents=True)

        results = []
        source_counts = Counter()
        language_counts = Counter()
        reason_counts = Counter()
        try:
            with zipfile.ZipFile(zip_path, 'r') as z:
                z.extractall(extract_dir)

            # Find all audio files
            audio_files = sorted(
                [*extract_dir.rglob("*.wav"), *extract_dir.rglob("*.mp3")]
            )
            if not audio_files:
                print(f"No audio files found in {zip_path}")
                return {}

            for audio_file in audio_files:
                # Best effort: a file that cannot be processed is reported
                # and left out of the statistics.
                try:
                    pred = self.identifier.predict(str(audio_file))
                    results.append(pred)
                    source_counts[pred['source']] += 1
                    language_counts[pred['language']] += 1
                    if 'reason' in pred:
                        reason_counts[pred['reason']] += 1
                    if verbose:
                        print(self._format_prediction(audio_file.name, pred, true_label))
                except Exception as e:
                    print(f"Error processing {audio_file.name}: {e}")
        finally:
            # Cleanup
            shutil.rmtree(extract_dir, ignore_errors=True)

        # Calculate accuracy if true label provided
        accuracy = None
        if true_label:
            correct = sum(1 for r in results if r['language'] == true_label)
            accuracy = correct / len(results) if results else 0
            print(f"\nAccuracy for '{true_label}': {accuracy:.1%} ({correct}/{len(results)})")

        print(f"\nSource usage: {dict(source_counts)}")
        print(f"Language predictions: {dict(language_counts)}")
        if reason_counts:
            print(f"Decision reasons: {dict(reason_counts)}")

        return {
            'total': len(results),
            'results': results,
            'source_counts': dict(source_counts),
            'language_counts': dict(language_counts),
            'reason_counts': dict(reason_counts),
            'accuracy': accuracy
        }