| |
| """ |
| ============================================================= |
| Sinhala ASR Model Comparison Test |
| ============================================================= |
| Tests multiple ASR models on existing audio segments from |
| your pipeline output. Helps you pick the best ASR model |
| before re-running the full pipeline. |
| |
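Tests (choose with --models; default: all):
    1. Facebook MMS (facebook/mms-1b-all) — CTC, 1B params
    2. Lingalingeswaran/whisper-small-sinhala_v3 — Whisper fine-tune, 242M params
    3. Lingalingeswaran/whisper-small-sinhala — Whisper fine-tune, 242M params
    4. openai/whisper-large-v3 (forced Sinhala) — baseline, 1.5B params
    5. hackvermin/whisper-small-sinhala — Whisper fine-tune
    6. seniruk/whisper-small-si — Whisper fine-tune
    7. openai/whisper-large-v3-turbo (forced Sinhala)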
| |
| Runs ONE model at a time to fit in 6GB VRAM. |
| |
| Usage: |
    # Test on existing pipeline segments (default: 5, spread evenly across the sorted files)
| python scripts/test_asr_models.py --segments-dir pipeline_output/segments |
| |
| # Test on specific WAV files |
| python scripts/test_asr_models.py --wav-files audio1.wav audio2.wav |
| |
| # Test only specific models |
| python scripts/test_asr_models.py --models mms whisper-si-v3 |
| |
| # Test more segments |
| python scripts/test_asr_models.py --max-segments 20 |
| |
| # CPU only (slower but works without GPU) |
| python scripts/test_asr_models.py --cpu |
| |
| Requirements: |
| pip install transformers torch torchaudio soundfile numpy |
| ============================================================= |
| """ |
|
|
| import os |
| import sys |
| import gc |
| import json |
| import argparse |
| import time |
| import warnings |
| from pathlib import Path |
| from typing import List, Dict, Optional, Tuple |
|
|
# Silence every warning category for the whole process (presumably so that
# library noise during model loading does not clutter the comparison output).
warnings.simplefilter("ignore")
|
|
| import torch |
| import torchaudio |
| import numpy as np |
|
|
|
|
| |
| |
| |
def free_gpu():
    """Release accelerator memory left behind by deleted models.

    Runs the Python garbage collector first, so tensors whose last reference
    was just dropped (e.g. by ``del model``) are actually freed. Then it asks
    the active backend (CUDA or MPS) to return its cached allocator blocks.
    On CPU-only machines only the garbage collection happens.
    """
    gc.collect()
    if torch.cuda.is_available():
        # Let queued kernels finish before the cache is released.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        # torch.mps.empty_cache() is missing in older PyTorch builds, so
        # probe for it instead of assuming it exists.
        mps_module = getattr(torch, "mps", None)
        if mps_module is not None and hasattr(mps_module, "empty_cache"):
            mps_module.empty_cache()
|
|
|
|
def gpu_mem():
    """Return a short memory-usage string for progress messages.

    On CUDA this is "<allocated>/<total>GB" (decimal gigabytes, total taken
    from device 0). On Apple Silicon it is the literal "MPS", since no figure
    is reported there. Otherwise it is "CPU".
    """
    if not torch.cuda.is_available():
        return "MPS" if torch.backends.mps.is_available() else "CPU"
    bytes_per_gb = 1e9
    allocated_gb = torch.cuda.memory_allocated() / bytes_per_gb
    capacity_gb = torch.cuda.get_device_properties(0).total_memory / bytes_per_gb
    return f"{allocated_gb:.1f}/{capacity_gb:.1f}GB"
|
|
|
|
| |
| |
| |
def load_audio_16k(wav_path: str) -> np.ndarray:
    """Load an audio file as a 16 kHz mono float32 NumPy array.

    Multi-channel audio is down-mixed by averaging the channels. A file with
    no samples yields one second of silence (16000 zeros), so the ASR
    processors always receive a non-empty signal.

    Args:
        wav_path: Path to any file ``soundfile`` can read.

    Returns:
        1-D float32 array sampled at 16 kHz.
    """
    import soundfile as sf

    data, sr = sf.read(wav_path, dtype='float32')
    waveform = torch.from_numpy(data)
    if waveform.ndim == 2:
        # soundfile returns (frames, channels) for multi-channel audio.
        waveform = waveform.mean(dim=-1)
    if waveform.numel() == 0:
        # soundfile never returns a 0-d array, so an `ndim == 0` test could
        # not catch empty files; check the element count instead.
        return np.zeros(16000, dtype=np.float32)
    if sr != 16000:
        # Functional form avoids building a Resample module on every call.
        waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
    return waveform.numpy()
|
|
|
|
| |
| |
| |
def test_mms(wav_paths: List[str], device: str) -> List[Dict]:
    """Test facebook/mms-1b-all with its Sinhala adapter.

    The processor is first loaded with ``target_lang="sin"``. If that fails
    with an error that looks like a missing adapter, the codes "snh", "sinh"
    and "si" are tried in turn. Any other load failure is reported in the
    returned rows instead of being raised, so one broken model cannot abort
    the whole comparison run. Once the model load has been attempted, the
    model is always deleted and accelerator memory released before returning.

    Args:
        wav_paths: Audio files to transcribe.
        device: torch device string; weights and inputs are fp16 only when
            this is "cuda".

    Returns:
        One dict per entry of ``wav_paths``, in the same order, with keys
        "model", "file" (basename), "text" and "time" (seconds). Failures are
        encoded in "text" as "[IMPORT_ERROR]", "[UNSUPPORTED]",
        "[LOAD_ERROR: ...]" or "[ERROR: ...]".
    """
    model_name = "mms-1b-all"
    model_id = "facebook/mms-1b-all"

    print("\n" + "=" * 60)
    print("MODEL 1: facebook/mms-1b-all (MMS CTC)")
    print("=" * 60)

    def placeholder_rows(text: str) -> List[Dict]:
        # Same row shape as a real transcription so main() can tabulate it.
        return [
            {"model": model_name, "file": os.path.basename(p), "text": text, "time": 0}
            for p in wav_paths
        ]

    try:
        from transformers import Wav2Vec2ForCTC, AutoProcessor
    except ImportError:
        print(" transformers not installed!")
        return placeholder_rows("[IMPORT_ERROR]")

    target_lang = "sin"
    print(f" Loading processor for {model_id}...")
    try:
        processor = AutoProcessor.from_pretrained(model_id, target_lang=target_lang)
    except Exception as e:
        err = str(e).lower()
        if not ("sin" in err or "not found" in err or "does not exist" in err):
            # Not a language-code problem (network, cache, ...): report it
            # instead of crashing the remaining model comparisons.
            print(f" Failed to load processor for {model_id}: {e}")
            return placeholder_rows(f"[LOAD_ERROR: {e}]")
        print(f" WARNING: Sinhala ('sin') failed: {e}")
        processor = None
        for alt_code in ["snh", "sinh", "si"]:
            print(f" Trying alternate code '{alt_code}'...")
            try:
                processor = AutoProcessor.from_pretrained(model_id, target_lang=alt_code)
            except Exception:
                continue
            target_lang = alt_code
            print(f" Found working code: '{alt_code}'")
            break
        if processor is None:
            print(" No Sinhala adapter found in MMS. Skipping.")
            return placeholder_rows("[UNSUPPORTED]")

    use_dtype = torch.float16 if device == "cuda" else torch.float32
    model = None
    results: List[Dict] = []
    try:
        print(f" Loading model ({'fp16' if use_dtype == torch.float16 else 'fp32'})...")
        model = Wav2Vec2ForCTC.from_pretrained(
            model_id,
            target_lang=target_lang,
            ignore_mismatched_sizes=True,
            torch_dtype=use_dtype,
        ).to(device).eval()
        print(f" Model loaded. VRAM: {gpu_mem()}")
    except Exception as e:
        # e.g. out-of-memory or download failure: keep the run going.
        print(f" Failed to load model {model_id}: {e}")
        results = placeholder_rows(f"[LOAD_ERROR: {e}]")
    else:
        for wav_path in wav_paths:
            name = os.path.basename(wav_path)
            start = time.time()
            try:
                audio = load_audio_16k(wav_path)
                inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
                inputs = {k: v.to(device) for k, v in inputs.items()}
                if use_dtype == torch.float16:
                    # Cast only float tensors; integer tensors keep their dtype.
                    inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}

                with torch.no_grad():
                    logits = model(**inputs).logits

                # Greedy decoding: most likely token id per output frame.
                predicted_ids = torch.argmax(logits, dim=-1)[0]
                text = processor.decode(predicted_ids).strip()
                elapsed = time.time() - start

                results.append({
                    "model": model_name,
                    "file": name,
                    "text": text,
                    "time": round(elapsed, 2),
                })
                print(f" [{name}] ({elapsed:.1f}s) -> {text[:100]}")
            except Exception as e:
                results.append({
                    "model": model_name,
                    "file": name,
                    "text": f"[ERROR: {e}]",
                    "time": 0,
                })
                print(f" [{name}] ERROR: {e}")
    finally:
        # Runs on success, load failure, or interruption alike.
        del model, processor
        free_gpu()
        print(f" Model unloaded. VRAM: {gpu_mem()}")

    return results
|
|
|
|
| |
| |
| |
def test_whisper_finetune(wav_paths: List[str], model_id: str, model_label: str, device: str) -> List[Dict]:
    """Test a Hugging Face Whisper checkpoint, forcing Sinhala transcription.

    Decoding uses ``language="si"``, ``task="transcribe"``, at most 225 new
    tokens and ``no_repeat_ngram_size=3``. A failure to load the processor
    or model is reported in the returned rows instead of being raised, so one
    broken checkpoint cannot abort the whole comparison run. The model is
    always deleted and accelerator memory released before returning.

    Args:
        wav_paths: Audio files to transcribe.
        model_id: Hugging Face repo id to load.
        model_label: Value written to each row's "model" field.
        device: torch device string; weights and features are fp16 only when
            this is "cuda".

    Returns:
        One dict per entry of ``wav_paths``, in the same order, with keys
        "model", "file" (basename), "text" and "time" (seconds). Failures are
        encoded in "text" as "[IMPORT_ERROR]", "[LOAD_ERROR: ...]" or
        "[ERROR: ...]".
    """
    print(f"\n{'=' * 60}")
    print(f"MODEL: {model_label} ({model_id})")
    print(f"{'=' * 60}")

    def placeholder_rows(text: str) -> List[Dict]:
        # Same row shape as a real transcription so main() can tabulate it.
        return [
            {"model": model_label, "file": os.path.basename(p), "text": text, "time": 0}
            for p in wav_paths
        ]

    try:
        from transformers import WhisperForConditionalGeneration, WhisperProcessor
    except ImportError:
        print(" transformers not installed!")
        return placeholder_rows("[IMPORT_ERROR]")

    use_dtype = torch.float16 if device == "cuda" else torch.float32
    gen_kwargs = {
        "max_new_tokens": 225,
        "no_repeat_ngram_size": 3,
        "language": "si",
        "task": "transcribe",
    }

    processor = model = None
    results: List[Dict] = []
    try:
        print(f" Loading {model_id}...")
        processor = WhisperProcessor.from_pretrained(model_id)
        model = WhisperForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=use_dtype,
        ).to(device).eval()
        print(f" Model loaded. VRAM: {gpu_mem()}")
    except Exception as e:
        # e.g. out-of-memory or download failure: keep the run going.
        print(f" Failed to load {model_id}: {e}")
        results = placeholder_rows(f"[LOAD_ERROR: {e}]")
    else:
        for wav_path in wav_paths:
            name = os.path.basename(wav_path)
            start = time.time()
            try:
                audio = load_audio_16k(wav_path)
                input_features = processor(
                    audio, sampling_rate=16000, return_tensors="pt"
                ).input_features.to(device)
                if use_dtype == torch.float16:
                    input_features = input_features.half()

                with torch.no_grad():
                    predicted_ids = model.generate(input_features, **gen_kwargs)

                text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()
                elapsed = time.time() - start

                results.append({
                    "model": model_label,
                    "file": name,
                    "text": text,
                    "time": round(elapsed, 2),
                })
                print(f" [{name}] ({elapsed:.1f}s) -> {text[:100]}")
            except Exception as e:
                results.append({
                    "model": model_label,
                    "file": name,
                    "text": f"[ERROR: {e}]",
                    "time": 0,
                })
                print(f" [{name}] ERROR: {e}")
    finally:
        # Runs on success, load failure, or interruption alike.
        del model, processor
        free_gpu()
        print(f" Model unloaded. VRAM: {gpu_mem()}")

    return results
|
|
|
|
| |
| |
| |
def test_whisper_large(wav_paths: List[str], device: str) -> List[Dict]:
    """Test openai/whisper-large-v3 with forced Sinhala language.

    Sinhala is forced through the ``language="si"`` / ``task="transcribe"``
    generate arguments, the same mechanism ``test_whisper_finetune`` uses,
    rather than the older ``forced_decoder_ids`` mechanism. Decoding is greedy
    (``do_sample=False``) with at most 225 new tokens and
    ``no_repeat_ngram_size=3``. A failure to load the processor or model is
    reported in the returned rows instead of being raised. The model is always
    deleted and accelerator memory released before returning.

    Args:
        wav_paths: Audio files to transcribe.
        device: torch device string; weights and features are fp16 only when
            this is "cuda".

    Returns:
        One dict per entry of ``wav_paths``, in the same order, with keys
        "model" ("whisper-large-v3"), "file" (basename), "text" and "time"
        (seconds). Failures are encoded in "text" as "[IMPORT_ERROR]",
        "[LOAD_ERROR: ...]" or "[ERROR: ...]".
    """
    model_label = "whisper-large-v3"
    model_id = "openai/whisper-large-v3"

    print(f"\n{'=' * 60}")
    print(f"MODEL: openai/whisper-large-v3 (forced Sinhala)")
    print(f"{'=' * 60}")

    def placeholder_rows(text: str) -> List[Dict]:
        # Same row shape as a real transcription so main() can tabulate it.
        return [
            {"model": model_label, "file": os.path.basename(p), "text": text, "time": 0}
            for p in wav_paths
        ]

    try:
        from transformers import WhisperForConditionalGeneration, WhisperProcessor
    except ImportError:
        print(" transformers not installed!")
        return placeholder_rows("[IMPORT_ERROR]")

    use_dtype = torch.float16 if device == "cuda" else torch.float32
    gen_kwargs = {
        "language": "si",
        "task": "transcribe",
        "max_new_tokens": 225,
        "no_repeat_ngram_size": 3,
        "temperature": 0.0,
        "do_sample": False,
    }

    processor = model = None
    results: List[Dict] = []
    try:
        print(f" Loading {model_id} ({'fp16' if use_dtype == torch.float16 else 'fp32'})...")
        processor = WhisperProcessor.from_pretrained(model_id)
        model = WhisperForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=use_dtype,
        ).to(device).eval()
        print(f" Model loaded. VRAM: {gpu_mem()}")
    except Exception as e:
        # e.g. out-of-memory on small GPUs: keep the run going.
        print(f" Failed to load {model_id}: {e}")
        results = placeholder_rows(f"[LOAD_ERROR: {e}]")
    else:
        for wav_path in wav_paths:
            name = os.path.basename(wav_path)
            start = time.time()
            try:
                audio = load_audio_16k(wav_path)
                input_features = processor(
                    audio, sampling_rate=16000, return_tensors="pt"
                ).input_features.to(device)
                if use_dtype == torch.float16:
                    input_features = input_features.half()

                with torch.no_grad():
                    predicted_ids = model.generate(input_features, **gen_kwargs)

                text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()
                elapsed = time.time() - start

                results.append({
                    "model": model_label,
                    "file": name,
                    "text": text,
                    "time": round(elapsed, 2),
                })
                print(f" [{name}] ({elapsed:.1f}s) -> {text[:100]}")
            except Exception as e:
                results.append({
                    "model": model_label,
                    "file": name,
                    "text": f"[ERROR: {e}]",
                    "time": 0,
                })
                print(f" [{name}] ERROR: {e}")
    finally:
        # Runs on success, load failure, or interruption alike.
        del model, processor
        free_gpu()
        print(f" Model unloaded. VRAM: {gpu_mem()}")

    return results
|
|
|
|
| |
| |
| |
def find_segments(segments_dir: str, max_segments: int) -> List[str]:
    """Find WAV segments from pipeline output.

    ``*.wav`` files are searched recursively under ``segments_dir`` and sorted
    by path. When more than ``max_segments`` are found, an evenly spread
    subset is kept, always starting with the first file. The duration and
    sample rate of each chosen file are printed; files ``soundfile`` cannot
    read are listed as unreadable but still returned.

    Args:
        segments_dir: Directory to search.
        max_segments: Maximum number of files to return; must be >= 1.

    Returns:
        Paths of the chosen files as strings.

    Exits the process with status 1 if ``max_segments`` < 1, the directory
    does not exist, or it contains no WAV files.
    """
    if max_segments < 1:
        # np.linspace would return nothing for 0 and raise for negatives.
        print(f"--max-segments must be at least 1 (got {max_segments})")
        sys.exit(1)

    seg_path = Path(segments_dir)
    if not seg_path.exists():
        print(f"Segments directory not found: {segments_dir}")
        print(f" Run the pipeline first, or provide --wav-files")
        sys.exit(1)

    wavs = sorted(seg_path.rglob("*.wav"))
    if not wavs:
        print(f"No WAV files found in {segments_dir}")
        sys.exit(1)

    if len(wavs) > max_segments:
        # Spread the sample across the whole run instead of taking the head.
        indices = np.linspace(0, len(wavs) - 1, max_segments, dtype=int)
        wavs = [wavs[i] for i in indices]

    import soundfile as sf

    print(f"Found {len(wavs)} segments to test")
    for w in wavs:
        try:
            info = sf.info(str(w))
        except RuntimeError as e:
            # soundfile's read errors subclass RuntimeError; a corrupt file
            # should not abort the listing.
            print(f" {w.name}: unreadable ({e})")
            continue
        print(f" {w.name}: {info.duration:.1f}s, {info.samplerate}Hz")

    return [str(w) for w in wavs]
|
|
|
|
| |
| |
| |
# Selectable models: (CLI key, Hugging Face repo id, display label).
# main() sends "mms" and "whisper-large" to dedicated test functions and every
# other key through test_whisper_finetune. Row order is the default test order.
_MODEL_TABLE = (
    ("mms", "facebook/mms-1b-all", "MMS CTC"),
    ("whisper-si-v3", "Lingalingeswaran/whisper-small-sinhala_v3", "Whisper-Small-Sinhala-v3"),
    ("whisper-si", "Lingalingeswaran/whisper-small-sinhala", "Whisper-Small-Sinhala"),
    ("whisper-hackvermin", "hackvermin/whisper-small-sinhala", "Whisper-Small-Hackvermin"),
    ("whisper-seniruk", "seniruk/whisper-small-si", "Whisper-Small-Seniruk"),
    ("whisper-large", "openai/whisper-large-v3", "Whisper-Large-v3-ForcedSi"),
    ("whisper-turbo", "openai/whisper-large-v3-turbo", "Whisper-Turbo-ForcedSi"),
)
AVAILABLE_MODELS = {key: (repo_id, label) for key, repo_id, label in _MODEL_TABLE}
|
|
|
|
def _select_device(force_cpu: bool) -> str:
    """Return "cpu" if forced, otherwise the first available of cuda, mps, cpu."""
    if force_cpu:
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _result_key(row: Dict) -> str:
    """Identify the audio file a result row belongs to (full path when known)."""
    return row.get("path", row["file"])


def _save_results(all_results: List[Dict], output_path: Path) -> None:
    """Write ``all_results`` as UTF-8 JSON, creating the parent directory if needed."""
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(json.dumps(all_results, indent=2, ensure_ascii=False), encoding="utf-8")


def _print_comparison(all_results: List[Dict]) -> None:
    """Print one table per audio file with each model's time and transcript preview."""
    print(f"\n\n{'=' * 80}")
    print(f"{'COMPARISON RESULTS':^80}")
    print(f"{'=' * 80}")

    # First row per (file, model); files are keyed by full path because
    # recursively found segments can share a basename.
    first_row: Dict = {}
    for r in all_results:
        first_row.setdefault((_result_key(r), r["model"]), r)

    files = sorted({_result_key(r) for r in all_results})
    models = sorted({r["model"] for r in all_results})
    for f in files:
        print(f"\n {f}")
        print(f" {'Model':<35} {'Time':>6} Transcription")
        print(f" {'-'*35} {'-'*6} {'-'*50}")
        for m in models:
            r = first_row.get((f, m))
            if r is None:
                continue
            text_preview = r["text"][:80] + ("..." if len(r["text"]) > 80 else "")
            print(f" {m:<35} {r['time']:>5.1f}s {text_preview}")


def main():
    """CLI entry point: run each selected ASR model over the audio and compare.

    Models run one at a time. Results are written to ``--output`` after every
    model, so a crash or Ctrl-C part-way through keeps what already finished.
    A model whose test function raises is recorded as "[ERROR: ...]" rows and
    the run continues with the next model.
    """
    parser = argparse.ArgumentParser(description="Compare ASR models for Sinhala")
    parser.add_argument("--segments-dir", type=str, default="pipeline_output/segments",
                        help="Directory with WAV segments from pipeline")
    parser.add_argument("--wav-files", nargs="+", type=str, default=None,
                        help="Specific WAV files to test (overrides --segments-dir)")
    parser.add_argument("--max-segments", type=int, default=5,
                        help="Max segments to test (default: 5)")
    parser.add_argument("--models", nargs="+", type=str, default=None,
                        choices=list(AVAILABLE_MODELS.keys()),
                        help=f"Models to test (default: all). Options: {list(AVAILABLE_MODELS.keys())}")
    parser.add_argument("--cpu", action="store_true", help="Force CPU")
    parser.add_argument("--output", type=str, default="asr_comparison.json",
                        help="Output JSON file for results")
    args = parser.parse_args()

    device = _select_device(args.cpu)

    print(f"\n{'#' * 60}")
    print(f"# Sinhala ASR Model Comparison")
    print(f"# Device: {device}")
    if device == "cuda":
        print(f"# GPU: {torch.cuda.get_device_name(0)}")
        props = torch.cuda.get_device_properties(0)
        print(f"# VRAM: {props.total_memory / 1e9:.1f}GB")
    elif device == "mps":
        print(f"# GPU: Apple Silicon (MPS)")
    print(f"{'#' * 60}")

    if args.wav_files:
        wav_paths = args.wav_files
    else:
        wav_paths = find_segments(args.segments_dir, args.max_segments)

    # Drop duplicate --models entries (keeping order) so no model runs twice.
    models_to_test = list(dict.fromkeys(args.models or AVAILABLE_MODELS))

    if device == "cuda":
        vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        # Models are loaded one at a time, so this applies whether or not
        # other models are selected too.
        if vram_gb < 5.0 and "whisper-large" in models_to_test:
            print(f"\n VRAM ({vram_gb:.1f}GB) may be tight for whisper-large-v3. Testing anyway...")

    output_path = Path(args.output)
    all_results: List[Dict] = []

    for model_key in models_to_test:
        free_gpu()
        try:
            if model_key == "mms":
                results = test_mms(wav_paths, device)
            elif model_key == "whisper-large":
                results = test_whisper_large(wav_paths, device)
            else:
                model_id, label = AVAILABLE_MODELS[model_key]
                results = test_whisper_finetune(wav_paths, model_id, label, device)
        except Exception as e:
            # One failing model (OOM, download error, ...) must not discard
            # the results already collected for the others.
            print(f"\n Model '{model_key}' failed: {e}")
            failed_label = {"mms": "mms-1b-all", "whisper-large": "whisper-large-v3"}.get(
                model_key, AVAILABLE_MODELS[model_key][1]
            )
            results = [
                {"model": failed_label, "file": os.path.basename(p), "text": f"[ERROR: {e}]", "time": 0}
                for p in wav_paths
            ]
            free_gpu()

        # Each test function returns one row per input path, in order; record
        # the full path so same-named segments from different folders stay apart.
        for row, path in zip(results, wav_paths):
            row.setdefault("path", path)
        all_results.extend(results)
        _save_results(all_results, output_path)

    _print_comparison(all_results)

    _save_results(all_results, output_path)
    print(f"\nResults saved to {output_path}")

    print(f"\n{'=' * 60}")
    print("WHAT TO LOOK FOR:")
    print("=" * 60)
    print("""
1. Does the text contain REAL Sinhala words?
(Not repetitive garbage like "..." patterns)

2. Is the text coherent? Does it make sense as spoken Sinhala?
(Historical content about Sri Lanka)

3. Speed: faster is better for processing 293 videos.

4. Hallucination: any model producing repetitive/nonsensical output
should be rejected immediately.

Pick the model that produces the most coherent Sinhala text, then
update local_pipeline.py to use it.
""")
|
|
|
|
if __name__ == "__main__":
    # Script entry point. main() returns None, so the exit status is 0 unless
    # it raises or calls sys.exit itself.
    sys.exit(main())
|
|