ground-zero / src /training /metrics.py
jefffffff9
Initial commit: Sahel-Agri Voice AI
76db545
Raw
History Blame Contribute Delete
1.34 kB
"""
WER and CER computation for Seq2SeqTrainer eval loop.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Callable
import numpy as np
if TYPE_CHECKING:
from transformers import EvalPrediction, WhisperProcessor
def make_compute_metrics(processor: "WhisperProcessor") -> Callable[["EvalPrediction"], dict[str, float]]:
    """
    Build a ``compute_metrics`` callback compatible with HuggingFace Seq2SeqTrainer.

    The returned callable decodes predicted and reference token ids with
    ``processor`` and scores the resulting text with ``jiwer``, reporting
    Word Error Rate (WER) and Character Error Rate (CER).

    Args:
        processor: Object exposing ``batch_decode(ids, skip_special_tokens=...)``
            and ``tokenizer.pad_token_id``.

    Returns:
        A function that takes an ``EvalPrediction`` and returns
        ``{"wer": <float>, "cer": <float>}``, each rounded to 4 decimal places.
    """
    # Imported here rather than at module level, so jiwer is only required
    # once a metrics function is actually built.
    import jiwer

    # Sentinel used to mask positions that must not be decoded.
    ignore_index = -100

    def compute_metrics(pred: "EvalPrediction") -> dict[str, float]:
        """Decode predictions and labels, then compute rounded WER and CER."""
        pred_ids = pred.predictions
        # `predictions` may be a tuple of arrays; the token ids are expected
        # to be the first element in that case.
        if isinstance(pred_ids, tuple):
            pred_ids = pred_ids[0]

        pad_token_id = processor.tokenizer.pad_token_id

        # Replace -100 with the pad token id so decoding doesn't fail.
        # Labels carry -100 as the loss mask; predictions can also contain it
        # when batches of different lengths are padded during concatenation.
        pred_ids = np.where(pred_ids != ignore_index, pred_ids, pad_token_id)
        label_ids = np.where(pred.label_ids != ignore_index, pred.label_ids, pad_token_id)

        pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
        label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

        # Normalize whitespace: collapse runs and strip leading/trailing blanks
        # so spacing differences are not counted as errors.
        pred_str = [" ".join(s.split()) for s in pred_str]
        label_str = [" ".join(s.split()) for s in label_str]

        # jiwer expects the reference first, then the hypothesis.
        wer = jiwer.wer(label_str, pred_str)
        cer = jiwer.cer(label_str, pred_str)
        return {"wer": round(wer, 4), "cer": round(cer, 4)}

    return compute_metrics