# Source: GitHub commit 1ec923d ("Add application file") by Huakang Chen
# tools/wer.py
from __future__ import annotations
from typing import List, Tuple
import string
from jiwer import process_words
from zhon.hanzi import punctuation as zh_punctuation
# Full punctuation set to strip during normalization:
# Chinese punctuation (zhon) + ASCII punctuation + an explicit "-".
_PUNCTUATION_ALL = "".join((zh_punctuation, string.punctuation, "-"))
def _normalize_pair(gt: str, gen: str, lang: str) -> Tuple[str, str]:
    """Normalize a (reference, hypothesis) text pair for WER scoring.

    Pipeline: coerce None to "", strip punctuation (keeping "'" for
    English contractions and "-" for the hyphen step), turn hyphens and
    ideographic spaces into plain spaces, then tokenize per *lang*:
    per-character for "zh", lowercased text for "en".

    Args:
        gt: Reference (ground-truth) text; ``None`` is treated as "".
        gen: Hypothesis (generated) text; ``None`` is treated as "".
        lang: Either ``"zh"`` or ``"en"``.

    Returns:
        The normalized ``(gt, gen)`` pair as space-separated token strings.

    Raises:
        NotImplementedError: If *lang* is neither ``"zh"`` nor ``"en"``.
    """
    gt = "" if gt is None else str(gt)
    gen = "" if gen is None else str(gen)
    # Delete punctuation in one C-level pass (str.translate) instead of
    # N chained .replace() calls.  Keep "'" and "-": the original also
    # deleted "-" here, which made the later hyphen→space replacement
    # dead code ("abc-def" collapsed to "abcdef" instead of "abc def").
    strip_table = str.maketrans(
        "", "", "".join(c for c in _PUNCTUATION_ALL if c not in "'-")
    )
    gt = gt.translate(strip_table)
    gen = gen.translate(strip_table)
    # Normalize ideographic space (U+3000) — the original had a no-op
    # .replace(" ", " "), presumably a mangled U+3000 — and treat
    # hyphenated compounds as separate words.
    gt = gt.replace("\u3000", " ").replace("-", " ")
    gen = gen.replace("\u3000", " ").replace("-", " ")
    if lang == "zh":
        # Chinese is scored at character level: each character is a token.
        gt = " ".join(gt)
        gen = " ".join(gen)
    elif lang == "en":
        gt = gt.lower()
        gen = gen.lower()
    else:
        raise NotImplementedError("lang must be 'zh' or 'en'")
    return gt, gen
def compute_wers(gt_texts: List[str], gen_texts: List[str], lang: str = "zh") -> List[float]:
    """Compute a per-utterance WER between references and hypotheses.

    For ``lang="zh"`` the texts are tokenized per character (so the
    result is effectively a CER); for ``lang="en"`` per whitespace word.

    Args:
        gt_texts: Reference transcripts.
        gen_texts: Hypothesis transcripts, same length as ``gt_texts``.
        lang: ``"zh"`` (default) or ``"en"``; forwarded to the normalizer.

    Returns:
        One WER value per pair, in input order.

    Raises:
        ValueError: If the two lists differ in length.
        NotImplementedError: If *lang* is unsupported (from the normalizer).
    """
    if len(gt_texts) != len(gen_texts):
        raise ValueError(f"Length mismatch: {len(gt_texts)} != {len(gen_texts)}")
    wers: List[float] = []
    for gt_raw, gen_raw in zip(gt_texts, gen_texts):
        gt_norm, gen_norm = _normalize_pair(gt_raw, gen_raw, lang=lang)
        # jiwer raises on an empty reference; guard so an all-punctuation
        # ground-truth string doesn't crash the whole batch.  Convention:
        # both empty -> 0.0 (perfect), empty ref + non-empty hyp -> 1.0.
        if not gt_norm.strip():
            wers.append(0.0 if not gen_norm.strip() else 1.0)
            continue
        measures = process_words(reference=gt_norm, hypothesis=gen_norm)
        wers.append(float(measures.wer))
    return wers
if __name__ == "__main__":
    # Quick smoke test: Chinese character-level, then English word-level.
    zh_refs = ["你好世界啊", "今天天气不对", "abc-def"]
    zh_hyps = ["你好,世界!", "今天 天气 不错", "abc def"]
    print(compute_wers(zh_refs, zh_hyps, lang="zh"))
    en_scores = compute_wers(["Hello World"], ["hello, world!"], lang="en")
    print(en_scores)