Spaces:
Running
on
Zero
Running
on
Zero
| # tools/wer.py | |
| from __future__ import annotations | |
| from typing import List, Tuple | |
| import string | |
| from jiwer import process_words | |
| from zhon.hanzi import punctuation as zh_punctuation | |
# Punctuation characters considered during normalization: Chinese punctuation
# (from zhon) + ASCII punctuation + '-'.
# NOTE: string.punctuation already contains '-', so the explicit '-' appended
# here is redundant but harmless.
_PUNCTUATION_ALL = zh_punctuation + string.punctuation + "-"
def _normalize_pair(gt: str, gen: str, lang: str) -> Tuple[str, str]:
    """Normalize a (reference, hypothesis) text pair for WER scoring.

    Both texts go through the same steps:

    1. ``None`` becomes ``""``; anything else is converted with ``str()``.
    2. Every character in ``_PUNCTUATION_ALL`` is removed, except the
       apostrophe (kept so contractions such as "don't" survive) and the
       ASCII hyphen, which is turned into a space so that "abc-def" and
       "abc def" normalize to the same words.
    3. Runs of whitespace are collapsed to a single space and the ends are
       trimmed.
    4. ``lang == "zh"``: every remaining non-whitespace character becomes its
       own space-separated token. ``lang == "en"``: the text is lower-cased.

    Args:
        gt: Reference (ground-truth) text; ``None`` is treated as empty.
        gen: Hypothesis (generated / recognized) text; ``None`` is treated
            as empty.
        lang: ``"zh"`` for character-level tokens, ``"en"`` for word-level.

    Returns:
        ``(normalized_gt, normalized_gen)``.

    Raises:
        NotImplementedError: If ``lang`` is neither ``"zh"`` nor ``"en"``.
    """
    if lang not in ("zh", "en"):
        raise NotImplementedError("lang must be 'zh' or 'en'")

    # One translation table handles all single-character rewrites in a single
    # pass: delete punctuation, keep the apostrophe, hyphen -> space.
    # The hyphen entry is set last on purpose: "-" is part of the punctuation
    # set, and deleting it first would glue "abc-def" into "abcdef".
    table = {ord(ch): None for ch in _PUNCTUATION_ALL if ch != "'"}
    table[ord("-")] = " "

    def _clean(text):
        text = "" if text is None else str(text)
        # split()/join collapses any whitespace run (including what the
        # hyphen replacement may have introduced) and trims both ends.
        text = " ".join(text.translate(table).split())
        if lang == "zh":
            # Treat each character as one token; spaces are separators only.
            return " ".join(text.replace(" ", ""))
        return text.lower()

    return _clean(gt), _clean(gen)
def compute_wers(gt_texts: List[str], gen_texts: List[str], lang: str = "zh") -> List[float]:
    """Compute one word error rate per (reference, hypothesis) pair.

    Each pair is normalized with ``_normalize_pair`` and then scored with
    jiwer's ``process_words``; the ``wer`` field of the result is collected.

    Args:
        gt_texts: Reference texts.
        gen_texts: Hypothesis texts, aligned index-by-index with ``gt_texts``.
        lang: Language code forwarded to ``_normalize_pair``.

    Returns:
        A list of WER values, in the same order as the inputs.

    Raises:
        ValueError: If the two lists differ in length.
    """
    n_ref, n_hyp = len(gt_texts), len(gen_texts)
    if n_ref != n_hyp:
        raise ValueError(f"Length mismatch: {n_ref} != {n_hyp}")

    def _score(reference, hypothesis):
        # Normalize both sides identically before handing them to jiwer.
        ref_norm, hyp_norm = _normalize_pair(reference, hypothesis, lang=lang)
        result = process_words(reference=ref_norm, hypothesis=hyp_norm)
        return float(result.wer)

    return [_score(ref, hyp) for ref, hyp in zip(gt_texts, gen_texts)]
if __name__ == "__main__":
    # Manual smoke test: print the per-pair WER lists for a small Chinese
    # batch and a single English pair.
    demo_cases = (
        (
            ["你好世界啊", "今天天气不对", "abc-def"],
            ["你好,世界!", "今天 天气 不错", "abc def"],
            "zh",
        ),
        (
            ["Hello World"],
            ["hello, world!"],
            "en",
        ),
    )
    for references, hypotheses, language in demo_cases:
        print(compute_wers(references, hypotheses, lang=language))