sinhala-tts / scripts /test_asr_models.py
outlawmold's picture
Merge origin/macos-apple-silicon into main
7a31795
#!/usr/bin/env python3
"""
=============================================================
Sinhala ASR Model Comparison Test
=============================================================
Tests multiple ASR models on existing audio segments from
your pipeline output. Helps you pick the best ASR model
before re-running the full pipeline.
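Tests (the leading key is the value accepted by --models):
1. mms — Facebook MMS (facebook/mms-1b-all) — CTC, 1B params
2. whisper-si-v3 — Lingalingeswaran/whisper-small-sinhala_v3 — Whisper fine-tune, 242M params
3. whisper-si — Lingalingeswaran/whisper-small-sinhala — Whisper fine-tune, 242M params
4. whisper-hackvermin — hackvermin/whisper-small-sinhala — Whisper fine-tune
5. whisper-seniruk — seniruk/whisper-small-si — Whisper fine-tune
6. whisper-large — openai/whisper-large-v3 (forced Sinhala) — baseline, 1.5B params
7. whisper-turbo — openai/whisper-large-v3-turbo (forced Sinhala)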
Runs ONE model at a time to fit in 6GB VRAM.
Usage:
# Test on existing pipeline segments (default: first 5)
python scripts/test_asr_models.py --segments-dir pipeline_output/segments
# Test on specific WAV files
python scripts/test_asr_models.py --wav-files audio1.wav audio2.wav
# Test only specific models
python scripts/test_asr_models.py --models mms whisper-si-v3
# Test more segments
python scripts/test_asr_models.py --max-segments 20
# CPU only (slower but works without GPU)
python scripts/test_asr_models.py --cpu
Requirements:
pip install transformers torch torchaudio soundfile numpy
=============================================================
"""
import os
import sys
import gc
import json
import argparse
import time
import warnings
from pathlib import Path
from typing import List, Dict, Optional, Tuple
warnings.filterwarnings("ignore")
import torch
import torchaudio
import numpy as np
# ============================================================
# GPU MEMORY MANAGEMENT
# ============================================================
def free_gpu():
    """Release Python garbage and any cached accelerator memory.

    Runs a garbage-collection pass first so that tensors whose last reference
    was just dropped are actually destroyed, then asks the active backend to
    hand its cached blocks back:

    * CUDA: empty the caching allocator and synchronize.
    * MPS (Apple Silicon): empty the MPS allocator cache when this PyTorch
      build exposes ``torch.mps.empty_cache``; older builds without it are
      left alone.
    * CPU only: nothing beyond the garbage collection.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    elif torch.backends.mps.is_available():
        # Without this the MPS allocator keeps the previous model's blocks
        # cached, which defeats loading one model at a time.
        mps = getattr(torch, "mps", None)
        if mps is not None and hasattr(mps, "empty_cache"):
            mps.empty_cache()
def gpu_mem():
    """Describe accelerator memory usage as a short string for log lines.

    Returns:
        ``"<allocated>/<total>GB"`` (one decimal each, device 0) when CUDA is
        available, ``"MPS"`` when only the MPS backend is available, and
        ``"CPU"`` otherwise.
    """
    if not torch.cuda.is_available():
        return "MPS" if torch.backends.mps.is_available() else "CPU"

    allocated_gb = torch.cuda.memory_allocated() / 1e9
    capacity_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    return f"{allocated_gb:.1f}/{capacity_gb:.1f}GB"
# ============================================================
# AUDIO LOADING
# ============================================================
def load_audio_16k(wav_path: str) -> np.ndarray:
    """Load an audio file as a 16 kHz mono float32 NumPy array.

    Args:
        wav_path: Path of the audio file, read with ``soundfile``.

    Returns:
        A 1-D float32 array sampled at 16 kHz. Multi-channel input is
        averaged across channels. If the file holds no samples at all,
        16000 zeros (one second of silence) are returned so downstream
        feature extractors always receive a non-empty signal.

    Raises:
        Whatever ``soundfile.read`` raises for a missing or undecodable file.
    """
    import soundfile as sf

    data, sr = sf.read(wav_path, dtype='float32')
    waveform = torch.from_numpy(data)
    if waveform.ndim == 2:
        waveform = waveform.mean(dim=-1)  # (frames, channels) -> mono
    # A zero-frame file arrives as an empty 1-D array, not a 0-d scalar,
    # so both shapes must be checked to trigger the silence fallback.
    if waveform.ndim == 0 or waveform.numel() == 0:
        return np.zeros(16000, dtype=np.float32)
    if sr != 16000:
        waveform = torchaudio.transforms.Resample(sr, 16000)(waveform.unsqueeze(0)).squeeze(0)
    return waveform.numpy()
# ============================================================
# MODEL 1: Facebook MMS (CTC)
# ============================================================
def test_mms(wav_paths: List[str], device: str) -> List[Dict]:
"""Test facebook/mms-1b-all with Sinhala adapter."""
print("\n" + "=" * 60)
print("MODEL 1: facebook/mms-1b-all (MMS CTC)")
print("=" * 60)
results = []
try:
from transformers import Wav2Vec2ForCTC, AutoProcessor
model_id = "facebook/mms-1b-all"
target_lang = "sin"
# First check if Sinhala is supported
print(f" Loading processor for {model_id}...")
try:
processor = AutoProcessor.from_pretrained(model_id, target_lang=target_lang)
except Exception as e:
if "sin" in str(e).lower() or "not found" in str(e).lower() or "does not exist" in str(e).lower():
print(f" WARNING: Sinhala ('sin') failed: {e}")
# Try alternative codes
for alt_code in ["snh", "sinh", "si"]:
try:
print(f" Trying alternate code '{alt_code}'...")
processor = AutoProcessor.from_pretrained(model_id, target_lang=alt_code)
target_lang = alt_code
print(f" Found working code: '{alt_code}'")
break
except Exception:
continue
else:
print(f" No Sinhala adapter found in MMS. Skipping.")
return [{"model": "mms-1b-all", "file": os.path.basename(p), "text": "[UNSUPPORTED]", "time": 0} for p in wav_paths]
else:
raise
print(f" Loading model (fp16)...")
use_dtype = torch.float16 if device == "cuda" else torch.float32
model = Wav2Vec2ForCTC.from_pretrained(
model_id,
target_lang=target_lang,
ignore_mismatched_sizes=True,
torch_dtype=use_dtype,
).to(device).eval()
print(f" Model loaded. VRAM: {gpu_mem()}")
for wav_path in wav_paths:
start = time.time()
try:
audio = load_audio_16k(wav_path)
inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
if use_dtype == torch.float16:
inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}
with torch.no_grad():
logits = model(**inputs).logits
predicted_ids = torch.argmax(logits, dim=-1)[0]
text = processor.decode(predicted_ids)
elapsed = time.time() - start
results.append({
"model": "mms-1b-all",
"file": os.path.basename(wav_path),
"text": text.strip(),
"time": round(elapsed, 2),
})
print(f" [{os.path.basename(wav_path)}] ({elapsed:.1f}s) -> {text.strip()[:100]}")
except Exception as e:
results.append({
"model": "mms-1b-all",
"file": os.path.basename(wav_path),
"text": f"[ERROR: {e}]",
"time": 0,
})
print(f" [{os.path.basename(wav_path)}] ERROR: {e}")
del model, processor
free_gpu()
print(f" Model unloaded. VRAM: {gpu_mem()}")
except ImportError:
print(" transformers not installed!")
return [{"model": "mms-1b-all", "file": os.path.basename(p), "text": "[IMPORT_ERROR]", "time": 0} for p in wav_paths]
return results
# ============================================================
# MODEL 2/3: Whisper Sinhala Fine-tunes (HuggingFace transformers)
# ============================================================
def test_whisper_finetune(wav_paths: List[str], model_id: str, model_label: str, device: str) -> List[Dict]:
"""Test a HuggingFace Whisper fine-tune for Sinhala."""
print(f"\n{'=' * 60}")
print(f"MODEL: {model_label} ({model_id})")
print(f"{'=' * 60}")
results = []
try:
from transformers import WhisperForConditionalGeneration, WhisperProcessor
print(f" Loading {model_id}...")
use_dtype = torch.float16 if device == "cuda" else torch.float32
processor = WhisperProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=use_dtype,
).to(device).eval()
print(f" Model loaded. VRAM: {gpu_mem()}")
# Force Sinhala language and transcribe task
gen_kwargs = {
"max_new_tokens": 225,
"no_repeat_ngram_size": 3,
"language": "si",
"task": "transcribe",
}
for wav_path in wav_paths:
start = time.time()
try:
audio = load_audio_16k(wav_path)
input_features = processor(
audio, sampling_rate=16000, return_tensors="pt"
).input_features.to(device)
if use_dtype == torch.float16:
input_features = input_features.half()
with torch.no_grad():
predicted_ids = model.generate(
input_features,
**gen_kwargs
)
text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
elapsed = time.time() - start
results.append({
"model": model_label,
"file": os.path.basename(wav_path),
"text": text.strip(),
"time": round(elapsed, 2),
})
print(f" [{os.path.basename(wav_path)}] ({elapsed:.1f}s) -> {text.strip()[:100]}")
except Exception as e:
results.append({
"model": model_label,
"file": os.path.basename(wav_path),
"text": f"[ERROR: {e}]",
"time": 0,
})
print(f" [{os.path.basename(wav_path)}] ERROR: {e}")
del model, processor
free_gpu()
print(f" Model unloaded. VRAM: {gpu_mem()}")
except ImportError:
print(" transformers not installed!")
return [{"model": model_label, "file": os.path.basename(p), "text": "[IMPORT_ERROR]", "time": 0} for p in wav_paths]
return results
# ============================================================
# MODEL 4: Whisper large-v3 (forced Sinhala)
# ============================================================
def test_whisper_large(wav_paths: List[str], device: str) -> List[Dict]:
"""Test openai/whisper-large-v3 with forced Sinhala language."""
print(f"\n{'=' * 60}")
print(f"MODEL: openai/whisper-large-v3 (forced Sinhala)")
print(f"{'=' * 60}")
results = []
try:
from transformers import WhisperForConditionalGeneration, WhisperProcessor
model_id = "openai/whisper-large-v3"
print(f" Loading {model_id} (fp16)...")
use_dtype = torch.float16 if device == "cuda" else torch.float32
processor = WhisperProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=use_dtype,
).to(device).eval()
print(f" Model loaded. VRAM: {gpu_mem()}")
forced_decoder_ids = processor.get_decoder_prompt_ids(language="si", task="transcribe")
for wav_path in wav_paths:
start = time.time()
try:
audio = load_audio_16k(wav_path)
input_features = processor(
audio, sampling_rate=16000, return_tensors="pt"
).input_features.to(device)
if use_dtype == torch.float16:
input_features = input_features.half()
with torch.no_grad():
predicted_ids = model.generate(
input_features,
forced_decoder_ids=forced_decoder_ids,
max_new_tokens=225,
no_repeat_ngram_size=3,
temperature=0.0,
do_sample=False,
)
text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
elapsed = time.time() - start
results.append({
"model": "whisper-large-v3",
"file": os.path.basename(wav_path),
"text": text.strip(),
"time": round(elapsed, 2),
})
print(f" [{os.path.basename(wav_path)}] ({elapsed:.1f}s) -> {text.strip()[:100]}")
except Exception as e:
results.append({
"model": "whisper-large-v3",
"file": os.path.basename(wav_path),
"text": f"[ERROR: {e}]",
"time": 0,
})
print(f" [{os.path.basename(wav_path)}] ERROR: {e}")
del model, processor
free_gpu()
print(f" Model unloaded. VRAM: {gpu_mem()}")
except ImportError:
print(" transformers not installed!")
return [{"model": "whisper-large-v3", "file": os.path.basename(p), "text": "[IMPORT_ERROR]", "time": 0} for p in wav_paths]
return results
# ============================================================
# FIND AUDIO SEGMENTS
# ============================================================
def find_segments(segments_dir: str, max_segments: int) -> List[str]:
    """Pick up to ``max_segments`` WAV files from a pipeline output directory.

    The directory is searched recursively for ``*.wav`` and the matches are
    sorted. When there are more matches than ``max_segments``, an evenly
    spaced subset is chosen (always including the first and last file when
    ``max_segments`` >= 2) so the sample spans the whole recording set.

    Each chosen file's duration and sample rate are printed when they can be
    read. That listing is informational only: a missing ``soundfile``
    package or an unreadable file never aborts discovery.

    Args:
        segments_dir: Directory to search.
        max_segments: Maximum number of files to return; must be >= 1.

    Returns:
        The selected file paths as strings, in sorted order.

    Raises:
        SystemExit: With status 1 if ``max_segments`` < 1, the directory does
            not exist, or it contains no WAV files.
    """
    if max_segments < 1:
        # Zero would test nothing silently; a negative count crashes linspace.
        print(f"--max-segments must be at least 1 (got {max_segments})")
        sys.exit(1)

    seg_path = Path(segments_dir)
    if not seg_path.exists():
        print(f"Segments directory not found: {segments_dir}")
        print(f" Run the pipeline first, or provide --wav-files")
        sys.exit(1)

    wavs = sorted(seg_path.rglob("*.wav"))
    if not wavs:
        print(f"No WAV files found in {segments_dir}")
        sys.exit(1)

    # Pick a diverse sample: first, middle, last
    if len(wavs) > max_segments:
        indices = np.linspace(0, len(wavs) - 1, max_segments, dtype=int)
        wavs = [wavs[i] for i in indices]

    try:
        import soundfile as sf
    except ImportError:
        sf = None

    print(f"Found {len(wavs)} segments to test")
    for w in wavs:
        if sf is None:
            print(f" {w.name}")
            continue
        try:
            info = sf.info(str(w))
        except (RuntimeError, OSError) as exc:
            # Keep the file in the list; the per-file error handling of the
            # model runners will report it again with details.
            print(f" {w.name}: [unreadable: {exc}]")
            continue
        print(f" {w.name}: {info.duration:.1f}s, {info.samplerate}Hz")
    return [str(w) for w in wavs]
# ============================================================
# MAIN
# ============================================================
# Registry of selectable models: CLI key -> (Hugging Face model id, display label).
# The keys are the values accepted by the --models option. main() unpacks the
# tuple only for the entries it routes to test_whisper_finetune(); "mms" and
# "whisper-large" are dispatched to their own dedicated test functions.
AVAILABLE_MODELS = {
    "mms": ("facebook/mms-1b-all", "MMS CTC"),
    "whisper-si-v3": ("Lingalingeswaran/whisper-small-sinhala_v3", "Whisper-Small-Sinhala-v3"),
    "whisper-si": ("Lingalingeswaran/whisper-small-sinhala", "Whisper-Small-Sinhala"),
    "whisper-hackvermin": ("hackvermin/whisper-small-sinhala", "Whisper-Small-Hackvermin"),
    "whisper-seniruk": ("seniruk/whisper-small-si", "Whisper-Small-Seniruk"),
    "whisper-large": ("openai/whisper-large-v3", "Whisper-Large-v3-ForcedSi"),
    "whisper-turbo": ("openai/whisper-large-v3-turbo", "Whisper-Turbo-ForcedSi"),
}
def main():
    """Command-line entry point: run the selected ASR models and compare them.

    Steps:
      1. Parse the CLI options and pick the device (``--cpu`` wins, then
         CUDA, then MPS, then CPU).
      2. Collect the audio files, either from ``--wav-files`` or by sampling
         ``--segments-dir`` with ``find_segments``.
      3. Run each requested model in turn, freeing accelerator memory before
         each one. A model that raises is recorded as ``[MODEL_ERROR: ...]``
         rows and the run continues, so one broken checkpoint cannot discard
         the results already gathered.
      4. Write all rows to the ``--output`` JSON file, then print a per-file
         comparison table and a short checklist for judging the output.
    """
    parser = argparse.ArgumentParser(description="Compare ASR models for Sinhala")
    parser.add_argument("--segments-dir", type=str, default="pipeline_output/segments",
                        help="Directory with WAV segments from pipeline")
    parser.add_argument("--wav-files", nargs="+", type=str, default=None,
                        help="Specific WAV files to test (overrides --segments-dir)")
    parser.add_argument("--max-segments", type=int, default=5,
                        help="Max segments to test (default: 5)")
    parser.add_argument("--models", nargs="+", type=str, default=None,
                        choices=list(AVAILABLE_MODELS.keys()),
                        help=f"Models to test (default: all). Options: {list(AVAILABLE_MODELS.keys())}")
    parser.add_argument("--cpu", action="store_true", help="Force CPU")
    parser.add_argument("--output", type=str, default="asr_comparison.json",
                        help="Output JSON file for results")
    args = parser.parse_args()

    if args.cpu:
        device = "cpu"
    elif torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    print(f"\n{'#' * 60}")
    print(f"# Sinhala ASR Model Comparison")
    print(f"# Device: {device}")
    if device == "cuda":
        print(f"# GPU: {torch.cuda.get_device_name(0)}")
        props = torch.cuda.get_device_properties(0)
        print(f"# VRAM: {props.total_memory / 1e9:.1f}GB")
    elif device == "mps":
        print(f"# GPU: Apple Silicon (MPS)")
    print(f"{'#' * 60}")

    # Find audio segments
    if args.wav_files:
        wav_paths = args.wav_files
    else:
        wav_paths = find_segments(args.segments_dir, args.max_segments)

    # Select models to test; drop repeated keys (order preserved) so the same
    # checkpoint is never loaded twice.
    models_to_test = list(dict.fromkeys(args.models or AVAILABLE_MODELS))

    # whisper-large needs ~3.1GB fp16
    if device == "cuda":
        vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        if vram_gb < 5.0 and "whisper-large" in models_to_test and len(models_to_test) > 1:
            print(f"\n VRAM ({vram_gb:.1f}GB) may be tight for whisper-large-v3. Testing anyway...")

    # Names the dedicated test functions put in the "model" field; every other
    # key reports under its display label from AVAILABLE_MODELS.
    dedicated_names = {"mms": "mms-1b-all", "whisper-large": "whisper-large-v3"}

    all_results = []
    for model_key in models_to_test:
        free_gpu()
        try:
            if model_key == "mms":
                results = test_mms(wav_paths, device)
            elif model_key == "whisper-large":
                results = test_whisper_large(wav_paths, device)
            else:
                model_id, label = AVAILABLE_MODELS[model_key]
                results = test_whisper_finetune(wav_paths, model_id, label, device)
        except Exception as exc:
            # Download, out-of-memory or checkpoint errors surface here; keep
            # going so the other models' results are still produced and saved.
            failed_name = dedicated_names.get(model_key, AVAILABLE_MODELS[model_key][1])
            print(f"\n Model '{model_key}' failed and was skipped: {exc}")
            results = [
                {
                    "model": failed_name,
                    "file": os.path.basename(p),
                    "text": f"[MODEL_ERROR: {exc}]",
                    "time": 0,
                }
                for p in wav_paths
            ]
        all_results.extend(results)

    # Save results before printing the summary so a console that cannot
    # display the transcriptions does not cost the collected data.
    output_path = Path(args.output)
    output_path.write_text(json.dumps(all_results, indent=2, ensure_ascii=False), encoding="utf-8")

    # ---- Summary ----
    print(f"\n\n{'=' * 80}")
    print(f"{'COMPARISON RESULTS':^80}")
    print(f"{'=' * 80}")

    # Group by file; the first row seen for a (file, model) pair is the one shown.
    first_row = {}
    for r in all_results:
        first_row.setdefault((r["file"], r["model"]), r)
    files = sorted({f for f, _ in first_row})
    models = sorted({m for _, m in first_row})

    for f in files:
        print(f"\n {f}")
        print(f" {'Model':<35} {'Time':>6} Transcription")
        print(f" {'-'*35} {'-'*6} {'-'*50}")
        for m in models:
            r = first_row.get((f, m))
            if r is None:
                continue
            text_preview = r["text"][:80] + ("..." if len(r["text"]) > 80 else "")
            print(f" {m:<35} {r['time']:>5.1f}s {text_preview}")

    print(f"\nResults saved to {output_path}")

    # ---- Verdict ----
    print(f"\n{'=' * 60}")
    print("WHAT TO LOOK FOR:")
    print("=" * 60)
    print("""
1. Does the text contain REAL Sinhala words?
(Not repetitive garbage like "..." patterns)
2. Is the text coherent? Does it make sense as spoken Sinhala?
(Historical content about Sri Lanka)
3. Speed: faster is better for processing 293 videos.
4. Hallucination: any model producing repetitive/nonsensical output
should be rejected immediately.
Pick the model that produces the most coherent Sinhala text, then
update local_pipeline.py to use it.
""")


if __name__ == "__main__":
    main()