# add asr test (author: yujuanqin, commit db0d138)
import re
from jiwer import wer, cer
import cn2an
from whisper_normalizer.english import EnglishTextNormalizer
from whisper_normalizer.basic import BasicTextNormalizer
class ASREvaluator:
    """Normalize ASR transcripts and score them with jiwer (WER/CER)."""

    def __init__(self):
        # Official English normalizer: handles spelling, abbreviations,
        # punctuation (Whisper's normalization rules).
        self.en_normalizer = EnglishTextNormalizer()
        # Official basic normalizer: mainly for non-English text;
        # handles punctuation, case, and extra whitespace.
        self.zh_normalizer = BasicTextNormalizer()

    def clean_zh_text(self, text):
        """Normalize Chinese text for character-level scoring.

        Converts Arabic numerals to Chinese numerals, applies the basic
        normalizer, and strips all whitespace.
        """
        # Convert Arabic digits to Chinese numerals so that e.g. "12"
        # can match a reference written as "十二".
        if re.search(r"\d", text):
            text = cn2an.transform(text, "an2cn")
        # 1. Basic normalization (punctuation removal, traditional->simplified, ...).
        text = self.zh_normalizer(text)
        # 2. Remove ALL whitespace so stray spaces in either text cannot
        #    perturb the per-character alignment used by CER.
        text = re.sub(r'\s+', '', text)
        return text

    def clean_en_text(self, text):
        """Normalize English text with the official Whisper normalizer."""
        return self.en_normalizer(text)

    def _collect_pairs(self, data, clean):
        """Return (refs, preds): cleaned text pairs, skipping empty references.

        ``clean`` is the normalization callable applied to both fields.
        """
        refs = []
        preds = []
        for item in data:
            ref_clean = clean(item["reference"])
            pred_clean = clean(item["predicts"])
            if ref_clean.strip():  # filter out empty reference texts
                refs.append(ref_clean)
                preds.append(pred_clean)
        return refs, preds

    def compute_en_wer(self, data):
        """Compute the English word error rate (WER).

        ``data`` is a list of dicts, each with at least the keys
        ``"reference"`` (ground truth) and ``"predicts"`` (hypothesis);
        other keys (``"index"``, ``"audio_path"``, ``"inference_time"``)
        are ignored.

        Returns 0.0 when no sample has a non-empty reference (jiwer
        raises on empty input).
        """
        refs, preds = self._collect_pairs(data, self.clean_en_text)
        if not refs:
            return 0.0
        return wer(refs, preds)

    def compute_zh_cer(self, data):
        """Compute the Chinese character error rate (CER).

        Same ``data`` schema as :meth:`compute_en_wer`.

        Returns 0.0 when no sample has a non-empty reference.
        """
        # Note: for Chinese, computing WER over individual characters is
        # equivalent to CER; jiwer.cer does this directly.
        refs, preds = self._collect_pairs(data, self.clean_zh_text)
        if not refs:
            return 0.0
        return cer(refs, preds)
def compute(model_name, data_name, results, language):
    """Evaluate ``results`` and print/return a one-line summary.

    Args:
        model_name: model identifier, used only in the report line.
        data_name: dataset identifier, used only in the report line.
        results: list of result dicts (see ASREvaluator.compute_en_wer
            for the expected keys).
        language: "zh" selects Chinese CER; any other value selects
            English WER.

    Returns:
        The formatted summary string.
    """
    evaluator = ASREvaluator()
    # Use neutral local names: `wer`/`cer` would shadow the functions
    # imported from jiwer at module level.
    if language == "zh":
        metric, score = "CER", evaluator.compute_zh_cer(results)
    else:
        metric, score = "WER", evaluator.compute_en_wer(results)
    res = f"Model: {model_name}, Dataset: {data_name}, data {len(results)}, {metric}: {score:.2%}"
    print(res)
    return res
# ================= Usage example =================
if __name__ == "__main__":
    import json
    from pathlib import Path

    evaluator = ASREvaluator()
    # Smoke-check the Chinese normalization pipeline on a mixed zh/en
    # sentence containing Arabic-style numerals in Chinese words.
    print(evaluator.clean_zh_text("相比sota模型,我们的方法在小样本场景下召回率高出十二个百分点,且参数量仅为其三分之一。"))

    # NOTE(review): hard-coded local path — parameterize (argparse/env)
    # before reusing outside the author's machine.
    reports = Path("/Users/jeqin/work/code/TestTranslator/reports")
    for file in reports.glob("*wenet_net.json"):
        # Result files are JSON; read as UTF-8 explicitly rather than
        # relying on the platform default encoding.
        with open(file, encoding="utf-8") as f:
            data = json.load(f)
        compute(model_name=file.name, data_name="wenet_net", results=data, language="zh")