# Provenance: milo-asr / evaluate_common_voice.py (commit d38720b, "merge2", by pluttodk)
#!/usr/bin/env python
"""
Benchmark ASR models on Common Voice Danish dataset.
This script evaluates hvisketiske-v2 (Qwen3-ASR) and hviske-v3 (Whisper)
on the Mozilla Common Voice Danish test set for comparison.
IMPORTANT: Common Voice requires authentication and agreement to terms of use.
Before running this script:
1. Create a HuggingFace account at https://huggingface.co
2. Visit https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0
3. Agree to the dataset terms of use
4. Create an access token at https://huggingface.co/settings/tokens
5. Login via CLI: `huggingface-cli login`
Usage:
# After logging in:
python huggingface/evaluate_common_voice.py \
--hvisketiske-path ./outputs/hvisketiske-v2/checkpoint-23448 \
--max-samples 1000 \
--output-file ./results/common_voice_comparison.json
# Quick test with fewer samples:
python huggingface/evaluate_common_voice.py --max-samples 100
# Use specific token:
python huggingface/evaluate_common_voice.py --hf-token YOUR_TOKEN
"""
import argparse
import json
import sys
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
import soundfile as sf
from datasets import load_dataset
from jiwer import cer, wer
from tqdm import tqdm
# Add src to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from hvisketiske.evaluation.model_adapters import (
ASRModelAdapter,
HviskeV3Adapter,
Qwen3ASRAdapter,
TranscriptionResult,
)
from hvisketiske.evaluation.timing import AggregatedTimingStats
@dataclass
class CommonVoiceSample:
    """A single Common Voice sample."""
    # Filesystem path to the WAV file written to the cache directory.
    audio_path: str
    # Ground-truth transcription (the dataset's "sentence" field).
    reference: str
    # Clip length in seconds (len(audio_array) / sampling_rate).
    audio_duration: float
def load_common_voice_danish(
    split: str = "test",
    max_samples: Optional[int] = None,
    cache_dir: Optional[str] = None,
    hf_token: Optional[str] = None,
) -> List[CommonVoiceSample]:
    """Fetch the Danish split of Common Voice and materialize it as WAV files.

    Args:
        split: Which dataset split to download (test, validation, train).
        max_samples: Optional cap on the number of samples kept.
        cache_dir: Where extracted WAV files are written; a fresh temporary
            directory is created when omitted.
        hf_token: HuggingFace API token used to authenticate the download.

    Returns:
        One CommonVoiceSample per dataset row, with audio saved to disk.
    """
    print(f"Loading Common Voice Danish ({split} split)...")
    print("Note: This requires HuggingFace authentication and agreement to dataset terms.")
    print("Visit: https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0")
    print()

    try:
        dataset = load_dataset(
            "mozilla-foundation/common_voice_17_0",
            "da",
            split=split,
            trust_remote_code=True,
            token=hf_token,
        )
    except Exception as exc:
        # The gated dataset surfaces as an "empty" repo when the user has not
        # accepted the terms of use — print actionable guidance, then re-raise.
        message = str(exc)
        if "EmptyDatasetError" in message or "doesn't contain any data" in message:
            print("\n" + "=" * 70)
            print("ERROR: Cannot access Common Voice dataset.")
            print("=" * 70)
            print("\nThis dataset requires authentication. Please:")
            print("1. Visit https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0")
            print("2. Log in and agree to the terms of use")
            print("3. Run: huggingface-cli login")
            print("4. Or pass --hf-token YOUR_TOKEN to this script")
            print("=" * 70 + "\n")
        raise

    if max_samples:
        dataset = dataset.select(range(min(max_samples, len(dataset))))
    print(f"Loaded {len(dataset)} samples")

    # Audio arrives as in-memory arrays; persist each clip so the model
    # adapters can consume plain file paths.
    if cache_dir is None:
        cache_dir = tempfile.mkdtemp(prefix="cv_danish_")
    out_dir = Path(cache_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    print("Preparing audio files...")
    samples: List[CommonVoiceSample] = []
    for idx, record in enumerate(tqdm(dataset, desc="Preparing samples")):
        waveform = record["audio"]["array"]
        rate = record["audio"]["sampling_rate"]
        wav_file = out_dir / f"sample_{idx:06d}.wav"
        sf.write(str(wav_file), waveform, rate)
        samples.append(
            CommonVoiceSample(
                audio_path=str(wav_file),
                reference=record["sentence"],
                audio_duration=len(waveform) / rate,
            )
        )
    return samples
def normalize_text(text: str) -> str:
    """Lowercase *text* and collapse all runs of whitespace to single spaces."""
    return " ".join(text.lower().split())
def evaluate_model(
    model: ASRModelAdapter,
    samples: List[CommonVoiceSample],
    warmup_samples: int = 3,
) -> dict:
    """Run one adapter over every sample and collect accuracy + timing stats.

    Args:
        model: Model adapter to evaluate.
        samples: Samples to transcribe.
        warmup_samples: How many warmup transcriptions to run first.

    Returns:
        Dict with model identity, WER/CER, and throughput/latency figures.
    """
    print(f"\nEvaluating: {model.model_name}")
    print("Loading model...")
    model.load()

    # Warm up on the first clip so timings exclude one-off startup costs.
    if warmup_samples > 0 and samples:
        print(f"Running {warmup_samples} warmup iterations...")
        model.warmup(samples[0].audio_path, num_runs=warmup_samples)

    print(f"Transcribing {len(samples)} samples...")
    hypotheses: List[str] = []
    per_sample_times: List[float] = []
    audio_seconds = 0.0
    inference_seconds = 0.0
    for sample in tqdm(samples, desc=f"Evaluating {model.model_name[:30]}"):
        outcome = model.transcribe(sample.audio_path)
        hypotheses.append(outcome.text)
        per_sample_times.append(outcome.inference_time_seconds)
        audio_seconds += sample.audio_duration
        inference_seconds += outcome.inference_time_seconds

    # Compare normalized text so casing/whitespace don't inflate error rates.
    norm_hyps = [normalize_text(h) for h in hypotheses]
    norm_refs = [normalize_text(s.reference) for s in samples]

    timing_stats = AggregatedTimingStats(
        total_inference_time_seconds=inference_seconds,
        total_audio_duration_seconds=audio_seconds,
        num_samples=len(samples),
        individual_times=per_sample_times,
    )
    return {
        "model_name": model.model_name,
        "model_size": model.model_size_params,
        "accuracy": {
            "wer": wer(norm_refs, norm_hyps),
            "cer": cer(norm_refs, norm_hyps),
        },
        "performance": {
            "total_inference_time_seconds": timing_stats.total_inference_time_seconds,
            "total_audio_duration_seconds": timing_stats.total_audio_duration_seconds,
            "real_time_factor": timing_stats.real_time_factor,
            "throughput_samples_per_second": timing_stats.throughput_samples_per_second,
            "mean_time_per_sample_seconds": timing_stats.mean_time_per_sample,
            "std_time_per_sample_seconds": timing_stats.std_time_per_sample,
        },
        "num_samples": len(samples),
    }
def print_summary(results: dict) -> None:
    """Print a formatted accuracy and performance comparison to stdout.

    Args:
        results: Report dict with a "models" mapping of model name to the
            per-model result dicts produced by ``evaluate_model``.
    """
    print("\n" + "=" * 80)
    print("COMMON VOICE DANISH - ASR MODEL COMPARISON")
    print("=" * 80)
    # Fixed: was an f-string with no placeholders.
    print("Dataset: mozilla-foundation/common_voice_17_0 (Danish)")
    print(f"Number of models: {len(results['models'])}")
    # All models ran over the same sample set; read the count from any entry.
    sample_count = next(iter(results["models"].values()))["num_samples"]
    print(f"Samples evaluated: {sample_count}")

    # Accuracy comparison table, best (lowest WER) first.
    print("\n" + "-" * 80)
    print("ACCURACY METRICS (lower is better)")
    print("-" * 80)
    print(f"{'Model':<45} {'WER':>12} {'CER':>12}")
    print("-" * 80)
    # Iterate values() directly — the dict key was unused (PERF102).
    for result in sorted(
        results["models"].values(), key=lambda r: r["accuracy"]["wer"]
    ):
        print(
            f"{result['model_name'][:45]:<45} "
            f"{result['accuracy']['wer']:>11.2%} "
            f"{result['accuracy']['cer']:>11.2%}"
        )

    # Performance comparison table, fastest (lowest RTF) first.
    print("\n" + "-" * 80)
    print("PERFORMANCE METRICS (RTF < 1.0 = faster than real-time)")
    print("-" * 80)
    print(f"{'Model':<35} {'RTF':>8} {'Throughput':>12} {'Mean Time':>12}")
    print(f"{'':35} {'':>8} {'(samples/s)':>12} {'(s/sample)':>12}")
    print("-" * 80)
    for result in sorted(
        results["models"].values(), key=lambda r: r["performance"]["real_time_factor"]
    ):
        perf = result["performance"]
        print(
            f"{result['model_name'][:35]:<35} "
            f"{perf['real_time_factor']:>8.3f} "
            f"{perf['throughput_samples_per_second']:>12.2f} "
            f"{perf['mean_time_per_sample_seconds']:>12.3f}"
        )
    print("=" * 80)
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command line arguments.

    Args:
        argv: Optional explicit argument list. Defaults to ``None``, which
            makes argparse read ``sys.argv[1:]`` as before — the parameter is
            a backward-compatible addition that enables testing and
            programmatic invocation.

    Returns:
        Parsed argument namespace.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark ASR models on Common Voice Danish"
    )
    parser.add_argument(
        "--output-file",
        type=Path,
        default=Path("results/common_voice_comparison.json"),
        help="Path to save comparison report (JSON)",
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=None,
        help="Maximum samples to evaluate (for quick testing)",
    )
    parser.add_argument(
        "--warmup",
        type=int,
        default=3,
        help="Number of warmup iterations per model (default: 3)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda:0",
        help="Device for inference (default: cuda:0)",
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default=None,
        help="Directory to cache audio files",
    )
    parser.add_argument(
        "--hf-token",
        type=str,
        default=None,
        help="HuggingFace API token for authentication (or use huggingface-cli login)",
    )
    # Model selection flags: each model can be opted out independently.
    parser.add_argument(
        "--skip-hviske-v3",
        action="store_true",
        help="Skip hviske-v3-conversation model",
    )
    parser.add_argument(
        "--skip-hvisketiske",
        action="store_true",
        help="Skip hvisketiske-v2 model",
    )
    parser.add_argument(
        "--hvisketiske-path",
        type=str,
        default="./outputs/hvisketiske-v2/checkpoint-23448",
        help="Path to local hvisketiske checkpoint",
    )
    return parser.parse_args(argv)
def main() -> None:
    """Main entry point for Common Voice evaluation."""
    args = parse_args()

    # Download the test split (optionally truncated) and cache audio locally.
    samples = load_common_voice_danish(
        split="test",
        max_samples=args.max_samples,
        cache_dir=args.cache_dir,
        hf_token=args.hf_token,
    )

    # Assemble the adapters the user did not explicitly skip.
    adapters: List[ASRModelAdapter] = []
    if not args.skip_hviske_v3:
        adapters.append(
            HviskeV3Adapter(model_id="syvai/hviske-v3-conversation", device=args.device)
        )
    if not args.skip_hvisketiske:
        adapters.append(
            Qwen3ASRAdapter(model_path=args.hvisketiske_path, device=args.device)
        )
    if not adapters:
        print("Error: No models selected for evaluation")
        sys.exit(1)

    banner = "=" * 60
    print(banner)
    print("Common Voice Danish ASR Evaluation")
    print(banner)
    print("Dataset: mozilla-foundation/common_voice_17_0")
    print(f"Samples: {len(samples)}")
    print(f"Device: {args.device}")
    print(f"Warmup iterations: {args.warmup}")
    print(f"Models to evaluate: {len(adapters)}")
    for adapter in adapters:
        print(f" - {adapter.model_name} ({adapter.model_size_params})")
    print(banner)

    # Evaluate each adapter in turn; results are keyed by model name.
    results = {
        "dataset": "mozilla-foundation/common_voice_17_0",
        "models": {
            adapter.model_name: evaluate_model(
                adapter, samples, warmup_samples=args.warmup
            )
            for adapter in adapters
        },
    }

    print_summary(results)

    # Persist the comparison report as UTF-8 JSON.
    args.output_file.parent.mkdir(parents=True, exist_ok=True)
    with open(args.output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"\nResults saved to: {args.output_file}")
# Run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()