""" Evaluation script for Whisper German ASR model Computes WER, CER, and other metrics on test data """ import torch from transformers import WhisperForConditionalGeneration, WhisperProcessor from datasets import load_from_disk import jiwer import librosa import numpy as np from pathlib import Path import json from tqdm import tqdm import argparse def normalize_text(text): """Normalize text for consistent evaluation""" import re text = text.lower() text = re.sub(r'[^\w\s]', '', text) # Remove punctuation text = ' '.join(text.split()) # Normalize whitespace return text def load_model(model_path): """Load fine-tuned Whisper model""" print(f"\nšŸ“¦ Loading model from: {model_path}") model_path = Path(model_path) # Check for checkpoint directories if model_path.is_dir(): checkpoints = list(model_path.glob('checkpoint-*')) if checkpoints: # Use the latest checkpoint latest = max(checkpoints, key=lambda p: int(p.name.split('-')[1])) model_path = latest print(f" Using checkpoint: {latest.name}") model = WhisperForConditionalGeneration.from_pretrained(model_path) processor = WhisperProcessor.from_pretrained("openai/whisper-small") # Set language conditioning model.config.forced_decoder_ids = processor.get_decoder_prompt_ids( language="german", task="transcribe" ) device = "cuda" if torch.cuda.is_available() else "cpu" model = model.to(device) model.eval() print(f"āœ“ Model loaded on {device}") print(f"āœ“ Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M") return model, processor, device def transcribe_audio(audio_array, sample_rate, model, processor, device): """Transcribe a single audio sample""" # Resample if needed if sample_rate != 16000: audio_array = librosa.resample( audio_array, orig_sr=sample_rate, target_sr=16000 ) # Process audio input_features = processor( audio_array, sampling_rate=16000, return_tensors="pt" ).input_features.to(device) # Generate transcription with torch.no_grad(): predicted_ids = model.generate( input_features, max_length=448, num_beams=5, early_stopping=True ) transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0] return transcription def evaluate_dataset(model, processor, device, dataset_path, split='test', max_samples=None): """Evaluate model on dataset""" print(f"\nšŸ“Š Evaluating on dataset: {dataset_path}") # Load dataset dataset = load_from_disk(dataset_path) # Handle different dataset formats if isinstance(dataset, dict): if split in dataset: dataset = dataset[split] elif 'test' in dataset: dataset = dataset['test'] elif 'validation' in dataset: dataset = dataset['validation'] else: # Use a portion of train as test dataset = dataset['train'].train_test_split(test_size=0.1, seed=42)['test'] if max_samples: dataset = dataset.select(range(min(max_samples, len(dataset)))) print(f" Evaluating on {len(dataset)} samples...") predictions = [] references = [] for sample in tqdm(dataset, desc="Transcribing"): # Get audio audio = sample['audio']['array'] sr = sample['audio']['sampling_rate'] # Transcribe pred = transcribe_audio(audio, sr, model, processor, device) ref = sample['transcription'] predictions.append(normalize_text(pred)) references.append(normalize_text(ref)) # Compute metrics wer = jiwer.wer(references, predictions) cer = jiwer.cer(references, predictions) # Word-level metrics wer_transform = jiwer.Compose([ jiwer.ToLowerCase(), jiwer.RemovePunctuation(), jiwer.RemoveMultipleSpaces(), jiwer.Strip(), ]) measures = jiwer.compute_measures( references, predictions, truth_transform=wer_transform, 
def evaluate_dataset(model, processor, device, dataset_path, split='test', max_samples=None):
    """Evaluate the model on a saved dataset."""
    print(f"\nšŸ“Š Evaluating on dataset: {dataset_path}")

    dataset = load_from_disk(dataset_path)

    # Handle both Dataset and DatasetDict layouts (DatasetDict subclasses dict)
    if isinstance(dataset, dict):
        if split in dataset:
            dataset = dataset[split]
        elif 'test' in dataset:
            dataset = dataset['test']
        elif 'validation' in dataset:
            dataset = dataset['validation']
        else:
            # No held-out split: carve a 10% test slice out of train
            dataset = dataset['train'].train_test_split(test_size=0.1, seed=42)['test']

    if max_samples:
        dataset = dataset.select(range(min(max_samples, len(dataset))))

    print(f"  Evaluating on {len(dataset)} samples...")

    predictions = []
    references = []

    for sample in tqdm(dataset, desc="Transcribing"):
        audio = sample['audio']['array']
        sr = sample['audio']['sampling_rate']

        pred = transcribe_audio(audio, sr, model, processor, device)
        ref = sample['transcription']

        predictions.append(normalize_text(pred))
        references.append(normalize_text(ref))

    # Corpus-level metrics
    wer = jiwer.wer(references, predictions)
    cer = jiwer.cer(references, predictions)

    # Word-level error breakdown. The transform is mostly redundant here (both
    # sides were already normalized above) but kept as a guard. Two caveats:
    # compute_measures requires jiwer < 3.0 (newer releases replace it with
    # jiwer.process_words), and custom transforms must end by reducing to
    # lists of words, hence the final ReduceToListOfListOfWords.
    wer_transform = jiwer.Compose([
        jiwer.ToLowerCase(),
        jiwer.RemovePunctuation(),
        jiwer.RemoveMultipleSpaces(),
        jiwer.Strip(),
        jiwer.ReduceToListOfListOfWords(),
    ])
    measures = jiwer.compute_measures(
        references, predictions,
        truth_transform=wer_transform,
        hypothesis_transform=wer_transform
    )

    results = {
        'wer': wer,
        'cer': cer,
        'num_samples': len(dataset),
        'substitutions': measures['substitutions'],
        'deletions': measures['deletions'],
        'insertions': measures['insertions'],
        'hits': measures['hits'],
    }
    return results, predictions, references


def print_results(results):
    """Print evaluation results."""
    print("\n" + "=" * 60)
    print("EVALUATION RESULTS")
    print("=" * 60)
    print("\nšŸ“Š Metrics:")
    print(f"  Word Error Rate (WER):      {results['wer']:.4f} ({results['wer'] * 100:.2f}%)")
    print(f"  Character Error Rate (CER): {results['cer']:.4f} ({results['cer'] * 100:.2f}%)")
    print("\nšŸ“ˆ Word-level Statistics:")
    print(f"  Correct (Hits): {results['hits']}")
    print(f"  Substitutions:  {results['substitutions']}")
    print(f"  Deletions:      {results['deletions']}")
    print(f"  Insertions:     {results['insertions']}")
    print(f"  Total samples:  {results['num_samples']}")
    print("=" * 60)


def save_results(results, predictions, references, output_file):
    """Save evaluation results and per-sample pairs to a JSON file."""
    output = {
        'metrics': results,
        'samples': [
            {'prediction': p, 'reference': r}
            for p, r in zip(predictions, references)
        ]
    }

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(output, f, indent=2, ensure_ascii=False)

    print(f"\nšŸ’¾ Results saved to: {output_file}")


def main():
    parser = argparse.ArgumentParser(description="Evaluate Whisper German ASR model")
    parser.add_argument('--model', type=str, default='./whisper_test_tuned',
                        help='Path to fine-tuned model')
    parser.add_argument('--dataset', type=str, default='./data/minds14_medium',
                        help='Path to dataset')
    parser.add_argument('--split', type=str, default='test',
                        help='Dataset split to evaluate (test/validation)')
    parser.add_argument('--max-samples', type=int, default=None,
                        help='Maximum number of samples to evaluate')
    parser.add_argument('--output', type=str, default='./evaluation_results.json',
                        help='Output file for results')
    args = parser.parse_args()

    model, processor, device = load_model(args.model)

    results, predictions, references = evaluate_dataset(
        model, processor, device, args.dataset,
        split=args.split, max_samples=args.max_samples
    )

    print_results(results)
    save_results(results, predictions, references, args.output)
    print("\nāœ… Evaluation complete!\n")


if __name__ == "__main__":
    main()
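
# Example invocations (the script filename below is illustrative; flags and
# defaults match the argparse definitions above):
#   python evaluate.py
#   python evaluate.py --model ./whisper_test_tuned --dataset ./data/minds14_medium \
#       --split test --max-samples 100 --output ./evaluation_results.json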