# tools/wer.py
from __future__ import annotations

import string
from typing import List, Tuple

from jiwer import process_words
from zhon.hanzi import punctuation as zh_punctuation

# Chinese punctuation + ASCII punctuation + '-'
_PUNCTUATION_ALL = zh_punctuation + string.punctuation + "-"


def _normalize_pair(gt: str, gen: str, lang: str) -> Tuple[str, str]:
    """Normalize a (reference, hypothesis) pair before WER scoring."""
    gt = "" if gt is None else str(gt)
    gen = "" if gen is None else str(gen)

    # Strip punctuation, keeping "'" (English contractions) and "-" (turned
    # into a word separator below, so "abc-def" scores like "abc def").
    for x in _PUNCTUATION_ALL:
        if x in ("'", "-"):
            continue
        gt = gt.replace(x, "")
        gen = gen.replace(x, "")

    # Unify whitespace and hyphens: hyphens become spaces, and any run of
    # whitespace (including full-width spaces) collapses to a single space.
    gt = " ".join(gt.replace("-", " ").split())
    gen = " ".join(gen.replace("-", " ").split())

    if lang == "zh":
        # Score Chinese at the character level: each character is one token.
        gt = " ".join(ch for ch in gt if not ch.isspace())
        gen = " ".join(ch for ch in gen if not ch.isspace())
    elif lang == "en":
        gt = gt.lower()
        gen = gen.lower()
    else:
        raise NotImplementedError("lang must be 'zh' or 'en'")
    return gt, gen


def compute_wers(gt_texts: List[str], gen_texts: List[str], lang: str = "zh") -> List[float]:
    """Return one WER per (reference, hypothesis) pair."""
    if len(gt_texts) != len(gen_texts):
        raise ValueError(f"Length mismatch: {len(gt_texts)} != {len(gen_texts)}")
    wers: List[float] = []
    for gt_raw, gen_raw in zip(gt_texts, gen_texts):
        gt_norm, gen_norm = _normalize_pair(gt_raw, gen_raw, lang=lang)
        measures = process_words(reference=gt_norm, hypothesis=gen_norm)
        wers.append(float(measures.wer))
    return wers


if __name__ == "__main__":
    gt = ["你好世界啊", "今天天气不对", "abc-def"]
    gen = ["你好,世界!", "今天 天气 不错", "abc def"]
    print(compute_wers(gt, gen, lang="zh"))  # character-level error rates for zh
    print(compute_wers(["Hello World"], ["hello, world!"], lang="en"))
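
# --- Hedged usage sketch (illustrative, not part of the module's original API) ---
# Averaging the per-utterance scores from compute_wers() is not the same as the
# corpus-level WER most benchmarks report, because long references should weigh
# more than short ones. jiwer.process_words also accepts lists of sentences and
# pools the edit counts across them; the helper below is a minimal sketch of that
# usage, and its name (compute_corpus_wer) is an assumption, not an existing API.
def compute_corpus_wer(gt_texts: List[str], gen_texts: List[str], lang: str = "zh") -> float:
    if len(gt_texts) != len(gen_texts):
        raise ValueError(f"Length mismatch: {len(gt_texts)} != {len(gen_texts)}")
    refs: List[str] = []
    hyps: List[str] = []
    for gt_raw, gen_raw in zip(gt_texts, gen_texts):
        gt_norm, gen_norm = _normalize_pair(gt_raw, gen_raw, lang=lang)
        refs.append(gt_norm)
        hyps.append(gen_norm)
    # Edit counts are summed over all pairs before dividing by the total number
    # of reference tokens, giving a length-weighted corpus WER.
    return float(process_words(reference=refs, hypothesis=hyps).wer)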