Spaces:
Sleeping
Sleeping
| """ | |
| WER and CER computation for Seq2SeqTrainer eval loop. | |
| """ | |
| from __future__ import annotations | |
| from typing import TYPE_CHECKING, Callable | |
| import numpy as np | |
| if TYPE_CHECKING: | |
| from transformers import EvalPrediction, WhisperProcessor | |
def make_compute_metrics(processor: "WhisperProcessor") -> Callable[["EvalPrediction"], dict[str, float]]:
    """
    Build a ``compute_metrics`` callback for HuggingFace ``Seq2SeqTrainer``.

    The returned function decodes predicted and reference token ids with
    ``processor`` and reports Word Error Rate (WER) and Character Error
    Rate (CER), each rounded to 4 decimal places.

    Args:
        processor: Object providing ``batch_decode`` and
            ``tokenizer.pad_token_id``; used to turn token ids into text.

    Returns:
        A callable mapping an ``EvalPrediction`` to
        ``{"wer": <float>, "cer": <float>}``.
    """
    # Imported lazily so that importing this module does not require jiwer.
    import jiwer

    def _decode(token_ids) -> list[str]:
        """Decode a batch of token ids into whitespace-normalized strings."""
        token_ids = np.asarray(token_ids)
        # -100 is a masking/padding sentinel, not a vocabulary id; swap it for
        # the pad token so decoding doesn't fail (the pad token is then
        # dropped by skip_special_tokens=True).
        token_ids = np.where(token_ids != -100, token_ids, processor.tokenizer.pad_token_id)
        texts = processor.batch_decode(token_ids, skip_special_tokens=True)
        # Collapse runs of whitespace and strip leading/trailing whitespace.
        return [" ".join(text.split()) for text in texts]

    def compute_metrics(pred: "EvalPrediction") -> dict[str, float]:
        """
        Compute WER and CER for one evaluation pass.

        Pairs whose reference text is empty after normalization are skipped,
        because error rates are undefined for an empty reference and jiwer
        releases that reject empty references raise on them.

        Raises:
            ValueError: If no pair with a non-empty reference remains.
        """
        pred_ids = pred.predictions
        # Some models return extra outputs alongside the generated ids; the
        # ids are assumed to be the first element in that case.
        if isinstance(pred_ids, tuple):
            pred_ids = pred_ids[0]

        # Predictions need the same -100 handling as labels: sequences of
        # different lengths gathered across batches are padded with -100.
        pred_str = _decode(pred_ids)
        label_str = _decode(pred.label_ids)

        scorable = [(ref, hyp) for ref, hyp in zip(label_str, pred_str) if ref]
        if not scorable:
            raise ValueError(
                "Cannot compute WER/CER: every reference transcript is empty after decoding."
            )
        references = [ref for ref, _ in scorable]
        hypotheses = [hyp for _, hyp in scorable]

        wer = jiwer.wer(references, hypotheses)
        cer = jiwer.cer(references, hypotheses)
        return {"wer": round(wer, 4), "cer": round(cer, 4)}

    return compute_metrics