#!/usr/bin/env python3
"""
=============================================================
Sinhala ASR Model Comparison Test
=============================================================
Tests multiple ASR models on existing audio segments from your
pipeline output. Helps you pick the best ASR model before
re-running the full pipeline.

Tests:
  1. Facebook MMS (facebook/mms-1b-all) — CTC, 1B params
  2. Lingalingeswaran/whisper-small-sinhala_v3 — Whisper fine-tune, 242M params
  3. Lingalingeswaran/whisper-small-sinhala — Whisper fine-tune, 242M params
  4. openai/whisper-large-v3 (forced Sinhala) — baseline, 1.5B params

Further Whisper checkpoints selectable with --models are listed in
AVAILABLE_MODELS below.

Runs ONE model at a time to fit in 6GB VRAM.

Usage:
  # Test on existing pipeline segments (default: first 5)
  python scripts/test_asr_models.py --segments-dir pipeline_output/segments

  # Test on specific WAV files
  python scripts/test_asr_models.py --wav-files audio1.wav audio2.wav

  # Test only specific models
  python scripts/test_asr_models.py --models mms whisper-si-v3

  # Test more segments
  python scripts/test_asr_models.py --max-segments 20

  # CPU only (slower but works without GPU)
  python scripts/test_asr_models.py --cpu

Requirements:
  pip install transformers torch torchaudio soundfile numpy
=============================================================
"""

import os
import sys
import gc
import json
import argparse
import time
import warnings
from pathlib import Path
from typing import Callable, List, Dict, Optional, Tuple

warnings.filterwarnings("ignore")

import torch
import numpy as np

try:
    import torchaudio
except ImportError:
    # torchaudio is only needed when a file has to be resampled; files that
    # are already 16 kHz can be processed without it.
    torchaudio = None

# A transcriber takes a 16 kHz mono float32 waveform and returns the raw text.
Transcriber = Callable[[np.ndarray], str]


# ============================================================
# GPU MEMORY MANAGEMENT
# ============================================================

def free_gpu() -> None:
    """Run the garbage collector and release cached accelerator memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    elif torch.backends.mps.is_available():
        # torch.mps only exists on newer PyTorch builds, so probe for it.
        mps = getattr(torch, "mps", None)
        if mps is not None and hasattr(mps, "empty_cache"):
            mps.empty_cache()


def gpu_mem() -> str:
    """Return a short description of accelerator memory usage.

    Returns:
        "<used>/<total>GB" for CUDA device 0, "MPS" when only MPS is
        available, otherwise "CPU".
    """
    if torch.cuda.is_available():
        used = torch.cuda.memory_allocated() / 1e9
        total = torch.cuda.get_device_properties(0).total_memory / 1e9
        return f"{used:.1f}/{total:.1f}GB"
    if torch.backends.mps.is_available():
        return "MPS"
    return "CPU"


# ============================================================
# AUDIO LOADING
# ============================================================

def load_audio_16k(wav_path: str) -> np.ndarray:
    """Load WAV and resample to 16kHz mono float32 numpy array.

    Multi-channel audio is averaged down to mono. Degenerate input (a scalar
    or a file with zero frames) yields one second of silence.

    Args:
        wav_path: Path of the audio file to read with soundfile.

    Returns:
        1-D float32 waveform at 16 kHz.

    Raises:
        RuntimeError: If the file needs resampling but torchaudio is missing.
    """
    import soundfile as sf

    data, sr = sf.read(wav_path, dtype="float32")
    waveform = torch.from_numpy(data)
    if waveform.ndim == 2:
        waveform = waveform.mean(dim=-1)  # stereo to mono
    # sf.read returns shape (0,) for an empty file, so ndim alone does not
    # detect it; check the element count as well.
    if waveform.ndim == 0 or waveform.numel() == 0:
        return np.zeros(16000, dtype=np.float32)
    if sr != 16000:
        if torchaudio is None:
            raise RuntimeError(
                f"torchaudio is required to resample {sr}Hz audio to 16000Hz"
            )
        resampler = torchaudio.transforms.Resample(sr, 16000)
        waveform = resampler(waveform.unsqueeze(0)).squeeze(0)
    return waveform.numpy()


# ============================================================
# SHARED HELPERS
# ============================================================

def _placeholder_results(model_label: str, wav_paths: List[str], text: str) -> List[Dict]:
    """Build one zero-time result row per file carrying the marker *text*."""
    return [
        {"model": model_label, "file": os.path.basename(p), "text": text, "time": 0}
        for p in wav_paths
    ]


def _transcribe_files(wav_paths: List[str], model_label: str,
                      transcribe: Transcriber) -> List[Dict]:
    """Transcribe every file, recording text and timing or a per-file error."""
    results = []
    for wav_path in wav_paths:
        name = os.path.basename(wav_path)
        start = time.time()
        try:
            text = transcribe(load_audio_16k(wav_path)).strip()
        except Exception as e:
            # One bad file must not stop the comparison; record and move on.
            results.append({
                "model": model_label,
                "file": name,
                "text": f"[ERROR: {e}]",
                "time": 0,
            })
            print(f" [{name}] ERROR: {e}")
            continue
        elapsed = time.time() - start
        results.append({
            "model": model_label,
            "file": name,
            "text": text,
            "time": round(elapsed, 2),
        })
        print(f" [{name}] ({elapsed:.1f}s) -> {text[:100]}")
    return results


def _run_model(wav_paths: List[str], model_label: str,
               load_transcriber: Callable[[], Transcriber]) -> List[Dict]:
    """Load a model, run it on all files, and always release it afterwards.

    A failure while loading (missing repo, download error, out of memory, ...)
    is reported as "[LOAD_ERROR: ...]" rows instead of aborting the whole run.
    """
    transcribe = None
    try:
        try:
            transcribe = load_transcriber()
        except Exception as e:
            print(f" ERROR loading {model_label}: {e}")
            return _placeholder_results(model_label, wav_paths, f"[LOAD_ERROR: {e}]")
        print(f" Model loaded. VRAM: {gpu_mem()}")
        return _transcribe_files(wav_paths, model_label, transcribe)
    finally:
        # The transcriber closure holds the only references to the model.
        transcribe = None
        free_gpu()
        print(f" Model unloaded. VRAM: {gpu_mem()}")


# ============================================================
# MODEL 1: Facebook MMS (CTC)
# ============================================================

def _load_mms_processor(processor_cls, model_id: str) -> Tuple[Optional[object], Optional[str]]:
    """Load the MMS processor for Sinhala, trying alternate language codes.

    Returns:
        (processor, language_code), or (None, None) when no code works.

    Raises:
        Exception: Re-raises the original error when it does not look like a
            language-lookup failure.
    """
    target_lang = "sin"
    print(f" Loading processor for {model_id}...")
    try:
        return processor_cls.from_pretrained(model_id, target_lang=target_lang), target_lang
    except Exception as e:
        message = str(e).lower()
        # NOTE: "sin" is a broad substring match; kept as in the original heuristic.
        if not any(hint in message for hint in ("sin", "not found", "does not exist")):
            raise
        print(f" WARNING: Sinhala ('sin') failed: {e}")

    # Try alternative codes
    for alt_code in ("snh", "sinh", "si"):
        try:
            print(f" Trying alternate code '{alt_code}'...")
            processor = processor_cls.from_pretrained(model_id, target_lang=alt_code)
        except Exception:
            continue
        print(f" Found working code: '{alt_code}'")
        return processor, alt_code
    return None, None


def test_mms(wav_paths: List[str], device: str) -> List[Dict]:
    """Test facebook/mms-1b-all with Sinhala adapter.

    Args:
        wav_paths: Audio files to transcribe.
        device: Torch device string ("cuda", "mps" or "cpu").

    Returns:
        One result dict per file with keys "model", "file", "text", "time".
        On failure "text" holds a marker: "[IMPORT_ERROR]", "[UNSUPPORTED]",
        "[LOAD_ERROR: ...]" or "[ERROR: ...]".
    """
    print("\n" + "=" * 60)
    print("MODEL 1: facebook/mms-1b-all (MMS CTC)")
    print("=" * 60)

    model_label = "mms-1b-all"
    model_id = "facebook/mms-1b-all"

    try:
        from transformers import Wav2Vec2ForCTC, AutoProcessor
    except ImportError:
        print(" transformers not installed!")
        return _placeholder_results(model_label, wav_paths, "[IMPORT_ERROR]")

    # First check if Sinhala is supported
    try:
        processor, target_lang = _load_mms_processor(AutoProcessor, model_id)
    except Exception as e:
        print(f" ERROR loading processor for {model_id}: {e}")
        return _placeholder_results(model_label, wav_paths, f"[LOAD_ERROR: {e}]")
    if processor is None:
        print(" No Sinhala adapter found in MMS. Skipping.")
        return _placeholder_results(model_label, wav_paths, "[UNSUPPORTED]")

    use_dtype = torch.float16 if device == "cuda" else torch.float32

    def load() -> Transcriber:
        print(" Loading model (fp16)...")
        model = Wav2Vec2ForCTC.from_pretrained(
            model_id,
            target_lang=target_lang,
            ignore_mismatched_sizes=True,
            torch_dtype=use_dtype,
        ).to(device).eval()

        def transcribe(audio: np.ndarray) -> str:
            inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}
            if use_dtype == torch.float16:
                # Only float inputs are cast; other tensors keep their dtype.
                inputs = {
                    k: v.half() if v.dtype == torch.float32 else v
                    for k, v in inputs.items()
                }
            with torch.no_grad():
                logits = model(**inputs).logits
            predicted_ids = torch.argmax(logits, dim=-1)[0]
            return processor.decode(predicted_ids)

        return transcribe

    return _run_model(wav_paths, model_label, load)


# The file name and the "test_" prefix match pytest's discovery patterns;
# these are script helpers, not pytest tests.
test_mms.__test__ = False


# ============================================================
# MODEL 2/3: Whisper Sinhala Fine-tunes (HuggingFace transformers)
# ============================================================

def _run_whisper(wav_paths: List[str], model_id: str, model_label: str, device: str,
                 gen_kwargs: Dict, loading_suffix: str = "") -> List[Dict]:
    """Load a HuggingFace Whisper checkpoint and transcribe all files with it."""
    try:
        from transformers import WhisperForConditionalGeneration, WhisperProcessor
    except ImportError:
        print(" transformers not installed!")
        return _placeholder_results(model_label, wav_paths, "[IMPORT_ERROR]")

    use_dtype = torch.float16 if device == "cuda" else torch.float32

    def load() -> Transcriber:
        print(f" Loading {model_id}{loading_suffix}...")
        processor = WhisperProcessor.from_pretrained(model_id)
        model = WhisperForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=use_dtype,
        ).to(device).eval()

        def transcribe(audio: np.ndarray) -> str:
            input_features = processor(
                audio, sampling_rate=16000, return_tensors="pt"
            ).input_features.to(device)
            if use_dtype == torch.float16:
                input_features = input_features.half()
            with torch.no_grad():
                predicted_ids = model.generate(input_features, **gen_kwargs)
            return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

        return transcribe

    return _run_model(wav_paths, model_label, load)


def test_whisper_finetune(wav_paths: List[str], model_id: str, model_label: str,
                          device: str) -> List[Dict]:
    """Test a HuggingFace Whisper fine-tune for Sinhala.

    Args:
        wav_paths: Audio files to transcribe.
        model_id: HuggingFace repository id of the checkpoint.
        model_label: Name stored in the "model" field of each result.
        device: Torch device string ("cuda", "mps" or "cpu").

    Returns:
        One result dict per file with keys "model", "file", "text", "time".
    """
    print(f"\n{'=' * 60}")
    print(f"MODEL: {model_label} ({model_id})")
    print(f"{'=' * 60}")

    # Force Sinhala language and transcribe task
    gen_kwargs = {
        "max_new_tokens": 225,
        "no_repeat_ngram_size": 3,
        "language": "si",
        "task": "transcribe",
    }
    return _run_whisper(wav_paths, model_id, model_label, device, gen_kwargs)


test_whisper_finetune.__test__ = False  # script helper, not a pytest test


# ============================================================
# MODEL 4: Whisper large-v3 (forced Sinhala)
# ============================================================

def test_whisper_large(wav_paths: List[str], device: str) -> List[Dict]:
    """Test openai/whisper-large-v3 with forced Sinhala language.

    Args:
        wav_paths: Audio files to transcribe.
        device: Torch device string ("cuda", "mps" or "cpu").

    Returns:
        One result dict per file with keys "model", "file", "text", "time";
        the "model" field is "whisper-large-v3".
    """
    print(f"\n{'=' * 60}")
    print(f"MODEL: openai/whisper-large-v3 (forced Sinhala)")
    print(f"{'=' * 60}")

    # language/task replace the deprecated forced_decoder_ids argument and
    # match how the fine-tuned checkpoints are driven.
    gen_kwargs = {
        "language": "si",
        "task": "transcribe",
        "max_new_tokens": 225,
        "no_repeat_ngram_size": 3,
        "temperature": 0.0,
        "do_sample": False,
    }
    return _run_whisper(
        wav_paths,
        "openai/whisper-large-v3",
        "whisper-large-v3",
        device,
        gen_kwargs,
        loading_suffix=" (fp16)",
    )


test_whisper_large.__test__ = False  # script helper, not a pytest test


# ============================================================
# FIND AUDIO SEGMENTS
# ============================================================

def find_segments(segments_dir: str, max_segments: int) -> List[str]:
    """Find WAV segments from pipeline output.

    Searches *segments_dir* recursively. When more than *max_segments* files
    exist, an evenly spaced sample over the sorted list is returned.
    Exits the process with status 1 if the directory is missing or holds no
    WAV files.

    Args:
        segments_dir: Directory to search.
        max_segments: Upper bound on the number of returned paths.

    Returns:
        Paths of the selected WAV files as strings.
    """
    seg_path = Path(segments_dir)
    if not seg_path.exists():
        print(f"Segments directory not found: {segments_dir}")
        print(f" Run the pipeline first, or provide --wav-files")
        sys.exit(1)

    wavs = sorted(seg_path.rglob("*.wav"))
    if not wavs:
        print(f"No WAV files found in {segments_dir}")
        sys.exit(1)

    # Pick a diverse sample: first, middle, last
    if len(wavs) > max_segments:
        indices = np.linspace(0, len(wavs) - 1, max_segments, dtype=int)
        wavs = [wavs[i] for i in indices]

    print(f"Found {len(wavs)} segments to test")
    import soundfile as sf  # imported late so the checks above work without it

    for w in wavs:
        info = sf.info(str(w))
        print(f" {w.name}: {info.duration:.1f}s, {info.samplerate}Hz")

    return [str(w) for w in wavs]


# ============================================================
# MAIN
# ============================================================

# CLI key -> (HuggingFace model id, label used in the results)
AVAILABLE_MODELS = {
    "mms": ("facebook/mms-1b-all", "MMS CTC"),
    "whisper-si-v3": ("Lingalingeswaran/whisper-small-sinhala_v3", "Whisper-Small-Sinhala-v3"),
    "whisper-si": ("Lingalingeswaran/whisper-small-sinhala", "Whisper-Small-Sinhala"),
    "whisper-hackvermin": ("hackvermin/whisper-small-sinhala", "Whisper-Small-Hackvermin"),
    "whisper-seniruk": ("seniruk/whisper-small-si", "Whisper-Small-Seniruk"),
    "whisper-large": ("openai/whisper-large-v3", "Whisper-Large-v3-ForcedSi"),
    "whisper-turbo": ("openai/whisper-large-v3-turbo", "Whisper-Turbo-ForcedSi"),
}

_VERDICT_TEXT = """
 1. Does the text contain REAL Sinhala words?
    (Not repetitive garbage like "..." patterns)

 2. Is the text coherent? Does it make sense as spoken Sinhala?
    (Historical content about Sri Lanka)

 3. Speed: faster is better for processing 293 videos.

 4. Hallucination: any model producing repetitive/nonsensical
    output should be rejected immediately.

 Pick the model that produces the most coherent Sinhala text,
 then update local_pipeline.py to use it.
"""


def _select_device(force_cpu: bool) -> str:
    """Return "cpu" when forced, else the best available of cuda/mps/cpu."""
    if force_cpu:
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _print_summary(all_results: List[Dict]) -> None:
    """Print a per-file table comparing every model's transcription."""
    print(f"\n\n{'=' * 80}")
    print(f"{'COMPARISON RESULTS':^80}")
    print(f"{'=' * 80}")

    # Group once by (file, model). Several rows can share a key when segments
    # in different sub-directories have the same base name; show all of them.
    grouped: Dict[Tuple[str, str], List[Dict]] = {}
    for r in all_results:
        grouped.setdefault((r["file"], r["model"]), []).append(r)

    files = sorted({r["file"] for r in all_results})
    models = sorted({r["model"] for r in all_results})

    for f in files:
        print(f"\n {f}")
        print(f" {'Model':<35} {'Time':>6} Transcription")
        print(f" {'-'*35} {'-'*6} {'-'*50}")
        for m in models:
            for r in grouped.get((f, m), ()):
                text_preview = r["text"][:80] + ("..." if len(r["text"]) > 80 else "")
                print(f" {m:<35} {r['time']:>5.1f}s {text_preview}")


def main():
    """Parse CLI options, run the selected models one by one, report results."""
    parser = argparse.ArgumentParser(description="Compare ASR models for Sinhala")
    parser.add_argument("--segments-dir", type=str, default="pipeline_output/segments",
                        help="Directory with WAV segments from pipeline")
    parser.add_argument("--wav-files", nargs="+", type=str, default=None,
                        help="Specific WAV files to test (overrides --segments-dir)")
    parser.add_argument("--max-segments", type=int, default=5,
                        help="Max segments to test (default: 5)")
    parser.add_argument("--models", nargs="+", type=str, default=None,
                        choices=list(AVAILABLE_MODELS.keys()),
                        help=f"Models to test (default: all). Options: {list(AVAILABLE_MODELS.keys())}")
    parser.add_argument("--cpu", action="store_true", help="Force CPU")
    parser.add_argument("--output", type=str, default="asr_comparison.json",
                        help="Output JSON file for results")
    args = parser.parse_args()

    # Reject nonsensical limits up front instead of failing inside numpy or
    # loading every model for zero files.
    if args.max_segments < 1:
        parser.error("--max-segments must be at least 1")

    device = _select_device(args.cpu)

    print(f"\n{'#' * 60}")
    print(f"# Sinhala ASR Model Comparison")
    print(f"# Device: {device}")
    if device == "cuda":
        print(f"# GPU: {torch.cuda.get_device_name(0)}")
        props = torch.cuda.get_device_properties(0)
        print(f"# VRAM: {props.total_memory / 1e9:.1f}GB")
    elif device == "mps":
        print(f"# GPU: Apple Silicon (MPS)")
    print(f"{'#' * 60}")

    # Find audio segments
    if args.wav_files:
        wav_paths = args.wav_files
    else:
        wav_paths = find_segments(args.segments_dir, args.max_segments)

    # Select models to test; drop repeated keys but keep the requested order.
    models_to_test = list(dict.fromkeys(args.models or AVAILABLE_MODELS))

    # whisper-large needs ~3.1GB fp16
    if device == "cuda":
        vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        if vram_gb < 5.0 and "whisper-large" in models_to_test and len(models_to_test) > 1:
            print(f"\n VRAM ({vram_gb:.1f}GB) may be tight for whisper-large-v3. Testing anyway...")

    all_results = []
    for model_key in models_to_test:
        free_gpu()
        if model_key == "mms":
            results = test_mms(wav_paths, device)
        elif model_key == "whisper-large":
            results = test_whisper_large(wav_paths, device)
        else:
            model_id, label = AVAILABLE_MODELS[model_key]
            results = test_whisper_finetune(wav_paths, model_id, label, device)
        all_results.extend(results)

    # ---- Summary ----
    _print_summary(all_results)

    # Save results
    output_path = Path(args.output)
    output_path.write_text(
        json.dumps(all_results, indent=2, ensure_ascii=False), encoding="utf-8"
    )
    print(f"\nResults saved to {output_path}")

    # ---- Verdict ----
    print(f"\n{'=' * 60}")
    print("WHAT TO LOOK FOR:")
    print("=" * 60)
    print(_VERDICT_TEXT)


if __name__ == "__main__":
    main()