File size: 4,038 Bytes
db0d138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import re
from jiwer import wer, cer
import cn2an
from whisper_normalizer.english import EnglishTextNormalizer
from whisper_normalizer.basic import BasicTextNormalizer


class ASREvaluator:
    """Normalize ASR transcripts and score them with jiwer.

    Chinese text is scored by character error rate (CER) and English text by
    word error rate (WER). References and predictions go through the same
    normalizer before scoring, so formatting differences are not counted as
    recognition errors.
    """

    def __init__(self):
        # Whisper's English normalizer: handles spelling variants,
        # abbreviations/contractions and punctuation.
        self.en_normalizer = EnglishTextNormalizer()
        # Whisper's basic normalizer, used here for Chinese (non-English)
        # text: handles punctuation, letter case and redundant whitespace.
        self.zh_normalizer = BasicTextNormalizer()

    def clean_zh_text(self, text):
        """Normalize Chinese text for CER scoring.

        Arabic numerals are first rewritten as Chinese numerals (cn2an
        "an2cn" mode), then the text is passed through ``zh_normalizer`` and
        every whitespace character is removed.

        Args:
            text: Raw reference or predicted transcript.

        Returns:
            The normalized string, containing no whitespace.
        """
        # Only invoke cn2an when there is actually a digit to convert.
        if re.search(r"\d", text):
            text = cn2an.transform(text, "an2cn")
        # Punctuation / case / whitespace normalization. Note: this step does
        # NOT convert Traditional to Simplified Chinese.
        text = self.zh_normalizer(text)
        # Remove all whitespace, including spaces already present in the
        # source text, so they cannot affect the character-level comparison.
        return re.sub(r"\s+", "", text)

    def clean_en_text(self, text):
        """Normalize English text for WER scoring using ``en_normalizer``.

        Args:
            text: Raw reference or predicted transcript.

        Returns:
            The normalized string.
        """
        return self.en_normalizer(text)

    def _collect_pairs(self, data, clean):
        """Clean each item's reference/prediction and keep the usable pairs.

        Items whose cleaned reference is empty (or whitespace only) are
        skipped: there is nothing to measure errors against.

        Args:
            data: Iterable of dicts with "reference" and "predicts" keys.
            clean: Normalization function applied to both strings.

        Returns:
            ``(refs, preds)`` -- two aligned lists of cleaned strings.

        Raises:
            ValueError: if no item has a non-empty cleaned reference.
        """
        refs = []
        preds = []
        for item in data:
            ref_clean = clean(item["reference"])
            pred_clean = clean(item["predicts"])
            if ref_clean.strip():
                refs.append(ref_clean)
                preds.append(pred_clean)
        if not refs:
            # Without this guard jiwer is handed empty lists and fails with
            # an unhelpful arithmetic error.
            raise ValueError(
                "No item with a non-empty reference after normalization; "
                "cannot compute an error rate."
            )
        return refs, preds

    def compute_en_wer(self, data):
        """Compute the English word error rate (WER) over ``data``.

        Args:
            data: Iterable of result dicts; only "reference" and "predicts"
                are read. Example::

                    [{
                        "index": 1,
                        "audio_path": "xxx.wav",
                        "reference": "text",
                        "inference_time": 0.123,
                        "predicts": "test",
                    }]

        Returns:
            The WER returned by jiwer for all items whose normalized
            reference is non-empty.

        Raises:
            ValueError: if no item has a non-empty normalized reference.
        """
        refs, preds = self._collect_pairs(data, self.clean_en_text)
        return wer(refs, preds)

    def compute_zh_cer(self, data):
        """Compute the Chinese character error rate (CER) over ``data``.

        Args:
            data: Iterable of result dicts; only "reference" and "predicts"
                are read. Example::

                    [{
                        "index": 1,
                        "audio_path": "xxx.wav",
                        "reference": "text",
                        "inference_time": 0.123,
                        "predicts": "test",
                    }]

        Returns:
            The CER returned by jiwer for all items whose normalized
            reference is non-empty.

        Raises:
            ValueError: if no item has a non-empty normalized reference.
        """
        # For Chinese, computing WER over a sentence split into individual
        # characters is equivalent to CER; jiwer's ``cer`` does that directly.
        refs, preds = self._collect_pairs(data, self.clean_zh_text)
        return cer(refs, preds)


def compute(model_name, data_name, results, language):
    """Score one model's results on one dataset, print a summary, return it.

    Args:
        model_name: Model label, used only in the summary line.
        data_name: Dataset label, used only in the summary line.
        results: List of result dicts with at least "reference" and
            "predicts" keys.
        language: "zh" selects character error rate (CER); any other value
            falls back to English word error rate (WER).

    Returns:
        The formatted summary string that was printed.
    """
    evaluator = ASREvaluator()
    # Locals are named so they don't shadow jiwer's module-level
    # ``wer``/``cer`` imports.
    if language == "zh":
        metric_name = "CER"
        score = evaluator.compute_zh_cer(results)
    else:
        metric_name = "WER"
        score = evaluator.compute_en_wer(results)
    # NOTE: len(results) counts every item, including any skipped during
    # scoring because its normalized reference was empty.
    res = f"Model: {model_name}, Dataset: {data_name}, data {len(results)}, {metric_name}: {score:.2%}"
    print(res)
    return res


# ================= 使用示例 =================
if __name__ == "__main__":
    import json
    from pathlib import Path

    evaluator = ASREvaluator()
    # Quick sanity check of the Chinese normalization pipeline.
    print(evaluator.clean_zh_text("相比sota模型,我们的方法在小样本场景下召回率高出十二个百分点,且参数量仅为其三分之一。"))

    # English example (single report file):
    # result_file = Path("/Users/jeqin/work/code/TestTranslator/reports/whisper_libri_en.json")
    # with open(result_file, "r", encoding="utf-8") as f:
    #     data = json.load(f)
    # en_wer = evaluator.compute_en_wer(data)
    # print(f"{result_file.name} WER: {en_wer:.2%}")

    reports = Path("/Users/jeqin/work/code/TestTranslator/reports")
    # Sorted so the per-model summaries print in a stable order.
    for file in sorted(reports.glob("*wenet_net.json")):
        # The reports contain Chinese text: decode as UTF-8 explicitly instead
        # of relying on the platform's default encoding.
        with open(file, encoding="utf-8") as f:
            data = json.load(f)
        compute(model_name=file.name, data_name="wenet_net", results=data, language="zh")