import re

import cn2an
from jiwer import cer, wer
from whisper_normalizer.basic import BasicTextNormalizer
from whisper_normalizer.english import EnglishTextNormalizer


class ASREvaluator:
    """Normalize ASR transcripts and score them: WER for English, CER for Chinese."""

    def __init__(self):
        # Official English normalizer: handles spelling variants, contractions,
        # and punctuation.
        self.en_normalizer = EnglishTextNormalizer()
        # Basic normalizer: intended for non-English text; strips punctuation,
        # lowercases, and collapses extra whitespace.
        self.zh_normalizer = BasicTextNormalizer()

    def clean_zh_text(self, text):
        """Normalize Chinese text for CER scoring.

        Converts Arabic numerals to Chinese numerals (so "12" vs "十二" is not
        counted as an error), applies the basic normalizer, and removes all
        whitespace so pre-existing spaces in the text do not affect CER.

        Args:
            text: raw Chinese transcript.

        Returns:
            The normalized, whitespace-free string.
        """
        if re.search(r"\d", text):
            text = cn2an.transform(text, "an2cn")
        # 1. Basic normalization (punctuation removal, case folding, etc.).
        text = self.zh_normalizer(text)
        # 2. Remove all whitespace (prevents spaces in the source text from
        #    interfering with character-level scoring).
        text = re.sub(r"\s+", "", text)
        return text

    def clean_en_text(self, text):
        """Normalize English text with Whisper's official English normalizer."""
        return self.en_normalizer(text)

    def compute_en_wer(self, data):
        """Compute English word error rate (WER).

        Args:
            data: list of result dicts, each with at least the keys
                "reference" and "predicts", e.g.
                [{"index": 1, "audio_path": "xxx.wav", "reference": "text",
                  "inference_time": 0.123, "predicts": "test"}]

        Returns:
            WER over all items whose normalized reference is non-empty.
        """
        refs = []
        preds = []
        for item in data:
            ref_clean = self.clean_en_text(item["reference"])
            pred_clean = self.clean_en_text(item["predicts"])
            if ref_clean.strip():  # drop items with an empty reference
                refs.append(ref_clean)
                preds.append(pred_clean)
        return wer(refs, preds)

    def compute_zh_cer(self, data):
        """Compute Chinese character error rate (CER).

        Accepts the same ``data`` format as :meth:`compute_en_wer`.

        Note: for Chinese, splitting sentences into individual characters and
        computing WER is equivalent to CER.

        Returns:
            CER over all items whose normalized reference is non-empty.
        """
        refs = []
        preds = []
        for item in data:
            ref_clean = self.clean_zh_text(item["reference"])
            pred_clean = self.clean_zh_text(item["predicts"])
            if ref_clean.strip():  # drop items with an empty reference
                refs.append(ref_clean)
                preds.append(pred_clean)
        return cer(refs, preds)


def compute(model_name, data_name, results, language):
    """Evaluate one result set, print a summary line, and return it.

    Args:
        model_name: model identifier used in the report line.
        data_name: dataset identifier used in the report line.
        results: list of result dicts (see ASREvaluator.compute_en_wer).
        language: "zh" selects Chinese CER; anything else selects English WER.

    Returns:
        The formatted summary string.
    """
    evaluator = ASREvaluator()
    # NOTE: use `score`, not `cer`/`wer`, to avoid shadowing the jiwer imports.
    if language == "zh":
        score = evaluator.compute_zh_cer(results)
        res = f"Model: {model_name}, Dataset: {data_name}, data {len(results)}, CER: {score:.2%}"
    else:
        score = evaluator.compute_en_wer(results)
        res = f"Model: {model_name}, Dataset: {data_name}, data {len(results)}, WER: {score:.2%}"
    print(res)
    return res


# ================= Usage example =================
if __name__ == "__main__":
    import json
    from pathlib import Path

    evaluator = ASREvaluator()
    print(evaluator.clean_zh_text("相比sota模型,我们的方法在小样本场景下召回率高出十二个百分点,且参数量仅为其三分之一。"))

    reports = Path("/Users/jeqin/work/code/TestTranslator/reports")
    for file in reports.glob("*wenet_net.json"):
        # Explicit encoding: report files are JSON and must be read as UTF-8
        # regardless of the platform's default locale encoding.
        with open(file, encoding="utf-8") as f:
            data = json.load(f)
        compute(model_name=file.name, data_name="wenet_net", results=data, language="zh")