"""
WER and CER computation for Seq2SeqTrainer eval loop.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Callable

import numpy as np

if TYPE_CHECKING:
    from transformers import EvalPrediction, WhisperProcessor

# Ignore index used to mask label positions out of the loss. HuggingFace's
# Trainer also uses it to pad variable-length prediction arrays when it
# concatenates them across eval batches. It is not a decodable token id.
_IGNORE_INDEX = -100


def make_compute_metrics(
    processor: "WhisperProcessor",
) -> Callable[["EvalPrediction"], dict[str, float]]:
    """
    Returns a compute_metrics function compatible with HuggingFace Seq2SeqTrainer.

    Computes Word Error Rate (WER) and Character Error Rate (CER) between the
    decoded predictions and the decoded reference labels.

    Args:
        processor: Object whose ``batch_decode`` turns token-id arrays into
            text. Its ``tokenizer.pad_token_id`` replaces the ``-100`` ignore
            index before decoding.

    Returns:
        A callable that takes an ``EvalPrediction`` and returns
        ``{"wer": float, "cer": float}``, each rounded to 4 decimal places.
        Samples whose reference is empty after whitespace normalization are
        excluded from both metrics, because jiwer rejects empty references.
        If no sample has a non-empty reference, both metrics are ``nan``.

    Raises:
        ImportError: If ``jiwer`` is not installed (raised when the factory is
            called, not at module import).
    """
    import jiwer  # lazy import: the module stays importable without jiwer

    def compute_metrics(pred: "EvalPrediction") -> dict[str, float]:
        pred_ids = pred.predictions
        # EvalPrediction.predictions may be a tuple when the model returns
        # several outputs. The generated token ids are the first element.
        if isinstance(pred_ids, tuple):
            pred_ids = pred_ids[0]
        label_ids = pred.label_ids

        # Replace the ignore index with the pad token in BOTH arrays so decoding
        # doesn't fail. Predictions can contain -100 too, because the eval loop
        # pads them with it. The pad token is a special token and is dropped by
        # skip_special_tokens.
        pad_token_id = processor.tokenizer.pad_token_id
        pred_ids = np.where(pred_ids != _IGNORE_INDEX, pred_ids, pad_token_id)
        label_ids = np.where(label_ids != _IGNORE_INDEX, label_ids, pad_token_id)

        pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
        label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

        # Normalize whitespace: collapse internal runs and strip both ends.
        pred_str = [" ".join(s.split()) for s in pred_str]
        label_str = [" ".join(s.split()) for s in label_str]

        # jiwer raises ValueError on empty reference strings, which would abort
        # the whole evaluation. Score only the samples that have a reference.
        references = [ref for ref in label_str if ref]
        hypotheses = [hyp for ref, hyp in zip(label_str, pred_str) if ref]
        if not references:
            return {"wer": float("nan"), "cer": float("nan")}

        wer = jiwer.wer(references, hypotheses)
        cer = jiwer.cer(references, hypotheses)
        return {"wer": round(wer, 4), "cer": round(cer, 4)}

    return compute_metrics