asr-1/src/evaluate.py
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "torch>=2.0.0",
# "torchaudio>=2.0.0",
# "transformers>=4.36.0",
# "datasets>=2.14.0",
# "click>=8.0.0",
# "tqdm>=4.60.0",
# "jiwer>=3.0.0",
# ]
# ///
"""
Evaluation script for ASR-1 Vietnamese Speech Recognition.
Computes WER and CER on test datasets.
Usage:
uv run src/evaluate.py --model models/asr-1
uv run src/evaluate.py --model models/asr-1 --dataset vivos
uv run src/evaluate.py --model openai/whisper-large-v3 # baseline
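    uv run src/evaluate.py --model models/asr-1 --dataset both --num-samples 100  # quick smoke test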
"""
import sys
from pathlib import Path

import click
import torch
from dotenv import load_dotenv
from jiwer import cer, wer
from tqdm import tqdm
from transformers import WhisperForConditionalGeneration, WhisperProcessor

# Pick up HF_TOKEN and related settings from a local .env file, if present.
load_dotenv()

# Make the repo root importable so `src.data` resolves when this script is
# run directly (e.g. `uv run src/evaluate.py`).
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.data import load_common_voice, load_vivos
@click.command()
@click.option('--model', '-m', required=True, help='Model path or HuggingFace model ID')
@click.option('--dataset', type=click.Choice(['common_voice', 'vivos', 'both']), default='common_voice',
help='Evaluation dataset')
@click.option('--batch-size', default=8, type=int, help='Batch size')
@click.option('--cache-dir', default=None, help='Dataset cache directory')
@click.option('--num-samples', default=0, type=int, help='Number of samples to evaluate (0=all)')
def evaluate_model(model, dataset, batch_size, cache_dir, num_samples):
"""Evaluate ASR-1 model on Vietnamese speech benchmarks."""
device = "cuda" if torch.cuda.is_available() else "cpu"
click.echo(f"Using device: {device}")
click.echo("=" * 60)
click.echo("ASR-1: Evaluation")
click.echo("=" * 60)
# Load model
click.echo(f"\nLoading model: {model}")
processor = WhisperProcessor.from_pretrained(model)
asr_model = WhisperForConditionalGeneration.from_pretrained(model).to(device)
asr_model.eval()
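    # Optional tweak: for whisper-large checkpoints on CUDA, passing
    # torch_dtype=torch.float16 to from_pretrained roughly halves GPU memory.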
# Determine datasets to evaluate
datasets_to_eval = []
if dataset in ("common_voice", "both"):
datasets_to_eval.append(("Common Voice (test)", load_common_voice("test", cache_dir=cache_dir)))
if dataset in ("vivos", "both"):
datasets_to_eval.append(("VIVOS (test)", load_vivos("test", cache_dir=cache_dir)))
for ds_name, ds in datasets_to_eval:
if num_samples > 0:
ds = ds.select(range(min(num_samples, len(ds))))
click.echo(f"\nEvaluating on {ds_name} ({len(ds)} samples)...")
all_predictions = []
all_references = []
for i in tqdm(range(0, len(ds), batch_size)):
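            # Slicing an HF Dataset yields a dict of column name -> list of values.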
batch = ds[i:i + batch_size]
audios = batch["audio"]
# Extract features
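            # The loaders in src.data are expected to resample audio to 16 kHz,
            # matching Whisper's fixed input sampling rate.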
input_features = processor.feature_extractor(
[a["array"] for a in audios],
sampling_rate=16000,
return_tensors="pt",
).input_features.to(device)
# Generate
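            # language/task pin the decoder prompt to Vietnamese transcription
            # instead of auto-detection; decoding is greedy unless num_beams is set.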
with torch.no_grad():
predicted_ids = asr_model.generate(
input_features,
language="vi",
task="transcribe",
)
# Decode
predictions = processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)
references = batch["sentence"]
all_predictions.extend(predictions)
all_references.extend(references)
# Compute metrics
wer_score = 100 * wer(all_references, all_predictions)
cer_score = 100 * cer(all_references, all_predictions)
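        # Scores are computed on raw text; this script reports them unnormalized.
        # Normalizing first (e.g. jiwer's ToLowerCase/RemovePunctuation transforms)
        # usually lowers WER and makes numbers comparable across papers.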
click.echo(f"\n{ds_name} Results:")
click.echo(f" WER: {wer_score:.2f}%")
click.echo(f" CER: {cer_score:.2f}%")
click.echo(f" Samples: {len(all_references)}")
# Show some examples
click.echo(f"\nExamples:")
for j in range(min(5, len(all_predictions))):
click.echo(f" REF: {all_references[j]}")
click.echo(f" HYP: {all_predictions[j]}")
click.echo()
if __name__ == '__main__':
evaluate_model()