| """ | |
| Evaluation script for Whisper German ASR model | |
| Computes WER, CER, and other metrics on test data | |
| """ | |

import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from datasets import load_from_disk
import jiwer
import librosa
import numpy as np
from pathlib import Path
import json
import re
from tqdm import tqdm
import argparse


def normalize_text(text):
    """Normalize text for consistent evaluation."""
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)  # Remove punctuation
    text = ' '.join(text.split())        # Normalize whitespace
    return text
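# Example (illustrative): normalize_text("Hallo, Welt! Wie geht's?") -> "hallo welt wie gehts"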


def load_model(model_path):
    """Load fine-tuned Whisper model."""
    print(f"\n📦 Loading model from: {model_path}")
    model_path = Path(model_path)

    # Check for checkpoint directories and use the latest one
    if model_path.is_dir():
        checkpoints = list(model_path.glob('checkpoint-*'))
        if checkpoints:
            latest = max(checkpoints, key=lambda p: int(p.name.split('-')[1]))
            model_path = latest
            print(f"   Using checkpoint: {latest.name}")

    model = WhisperForConditionalGeneration.from_pretrained(model_path)
    processor = WhisperProcessor.from_pretrained("openai/whisper-small")

    # Condition generation on German transcription
    model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
        language="german",
        task="transcribe"
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()

    print(f"✓ Model loaded on {device}")
    print(f"✓ Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M")
    return model, processor, device


def transcribe_audio(audio_array, sample_rate, model, processor, device):
    """Transcribe a single audio sample."""
    # Resample to Whisper's expected 16 kHz if needed
    if sample_rate != 16000:
        audio_array = librosa.resample(
            audio_array,
            orig_sr=sample_rate,
            target_sr=16000
        )

    # Extract log-mel input features
    input_features = processor(
        audio_array,
        sampling_rate=16000,
        return_tensors="pt"
    ).input_features.to(device)

    # Generate transcription with beam search
    with torch.no_grad():
        predicted_ids = model.generate(
            input_features,
            max_length=448,
            num_beams=5,
            early_stopping=True
        )

    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return transcription
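

# Note: the dataset loaded below is expected to provide an 'audio' column
# (a dict with 'array' and 'sampling_rate') and a 'transcription' column,
# matching the fields accessed in the loop inside evaluate_dataset().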
def evaluate_dataset(model, processor, device, dataset_path, split='test', max_samples=None):
    """Evaluate model on dataset."""
    print(f"\n📊 Evaluating on dataset: {dataset_path}")

    # Load dataset and handle different formats (DatasetDict vs. single split)
    dataset = load_from_disk(dataset_path)
    if isinstance(dataset, dict):
        if split in dataset:
            dataset = dataset[split]
        elif 'test' in dataset:
            dataset = dataset['test']
        elif 'validation' in dataset:
            dataset = dataset['validation']
        else:
            # Fall back to holding out a portion of train as test
            dataset = dataset['train'].train_test_split(test_size=0.1, seed=42)['test']

    if max_samples:
        dataset = dataset.select(range(min(max_samples, len(dataset))))

    print(f"   Evaluating on {len(dataset)} samples...")

    predictions = []
    references = []
    for sample in tqdm(dataset, desc="Transcribing"):
        audio = sample['audio']['array']
        sr = sample['audio']['sampling_rate']

        pred = transcribe_audio(audio, sr, model, processor, device)
        ref = sample['transcription']

        predictions.append(normalize_text(pred))
        references.append(normalize_text(ref))

    # Compute overall error rates
    wer = jiwer.wer(references, predictions)
    cer = jiwer.cer(references, predictions)

    # Word-level measures; the custom transform must end with
    # ReduceToListOfListOfWords so jiwer can split the strings into words
    wer_transform = jiwer.Compose([
        jiwer.ToLowerCase(),
        jiwer.RemovePunctuation(),
        jiwer.RemoveMultipleSpaces(),
        jiwer.Strip(),
        jiwer.ReduceToListOfListOfWords(),
    ])
    measures = jiwer.compute_measures(
        references,
        predictions,
        truth_transform=wer_transform,
        hypothesis_transform=wer_transform
    )

    results = {
        'wer': wer,
        'cer': cer,
        'num_samples': len(dataset),
        'substitutions': measures['substitutions'],
        'deletions': measures['deletions'],
        'insertions': measures['insertions'],
        'hits': measures['hits'],
    }
    return results, predictions, references


def print_results(results):
    """Print evaluation results."""
    print("\n" + "=" * 60)
    print("EVALUATION RESULTS")
    print("=" * 60)
    print(f"\n📊 Metrics:")
    print(f"   Word Error Rate (WER):      {results['wer']:.4f} ({results['wer']*100:.2f}%)")
    print(f"   Character Error Rate (CER): {results['cer']:.4f} ({results['cer']*100:.2f}%)")
    print(f"\n📝 Word-level Statistics:")
    print(f"   Correct (Hits): {results['hits']}")
    print(f"   Substitutions:  {results['substitutions']}")
    print(f"   Deletions:      {results['deletions']}")
    print(f"   Insertions:     {results['insertions']}")
    print(f"   Total samples:  {results['num_samples']}")
    print("=" * 60)


def save_results(results, predictions, references, output_file):
    """Save evaluation results to file."""
    output = {
        'metrics': results,
        'samples': [
            {'prediction': p, 'reference': r}
            for p, r in zip(predictions, references)
        ]
    }
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(output, f, indent=2, ensure_ascii=False)
    print(f"\n💾 Results saved to: {output_file}")


def main():
    parser = argparse.ArgumentParser(description="Evaluate Whisper German ASR model")
    parser.add_argument('--model', type=str, default='./whisper_test_tuned',
                        help='Path to fine-tuned model')
    parser.add_argument('--dataset', type=str, default='./data/minds14_medium',
                        help='Path to dataset')
    parser.add_argument('--split', type=str, default='test',
                        help='Dataset split to evaluate (test/validation)')
    parser.add_argument('--max-samples', type=int, default=None,
                        help='Maximum number of samples to evaluate')
    parser.add_argument('--output', type=str, default='./evaluation_results.json',
                        help='Output file for results')
    args = parser.parse_args()

    # Load model
    model, processor, device = load_model(args.model)

    # Evaluate
    results, predictions, references = evaluate_dataset(
        model, processor, device,
        args.dataset,
        split=args.split,
        max_samples=args.max_samples
    )

    # Print and save results
    print_results(results)
    save_results(results, predictions, references, args.output)

    print("\n✓ Evaluation complete!\n")


if __name__ == "__main__":
    main()