|
|
import re |
|
|
from jiwer import wer, cer |
|
|
import cn2an |
|
|
from whisper_normalizer.english import EnglishTextNormalizer |
|
|
from whisper_normalizer.basic import BasicTextNormalizer |
|
|
|
|
|
|
|
|
class ASREvaluator:
    """Normalize ASR transcripts and score them with jiwer's WER/CER.

    English text goes through Whisper's full English normalizer; Chinese
    text goes through the language-agnostic basic normalizer plus a few
    Chinese-specific steps (digit conversion, whitespace removal).
    """

    def __init__(self):
        # Separate normalizers per language: Whisper's English rules for
        # English, the basic (language-agnostic) rules for Chinese.
        self.en_normalizer = EnglishTextNormalizer()
        self.zh_normalizer = BasicTextNormalizer()

    def clean_zh_text(self, text):
        """Chinese-specific cleanup.

        Converts Arabic digits to Chinese numerals (so hypotheses and
        references agree on number form), applies the basic normalizer,
        and removes all whitespace (CER is character-based, so spaces
        would be counted as characters).
        """
        if re.search(r"\d", text):
            text = cn2an.transform(text, "an2cn")
        normalized = self.zh_normalizer(text)
        return re.sub(r'\s+', '', normalized)

    def clean_en_text(self, text):
        """English normalization via the Whisper English normalizer."""
        return self.en_normalizer(text)

    def compute_en_wer(self, data):
        """Compute the English word error rate (WER).

        Each item in *data* is a dict shaped like:
            {
                "index": 1,
                "audio_path": "xxx.wav",
                "reference": "text",
                "inference_time": 0.123,
                "predicts": "test"
            }

        Items whose reference is empty after normalization are skipped —
        jiwer cannot score an empty reference.
        """
        refs, preds = [], []
        for sample in data:
            ref = self.clean_en_text(sample["reference"])
            pred = self.clean_en_text(sample["predicts"])
            if not ref.strip():
                continue
            refs.append(ref)
            preds.append(pred)
        return wer(refs, preds)

    def compute_zh_cer(self, data):
        """Compute the Chinese character error rate (CER).

        Each item in *data* is a dict shaped like:
            {
                "index": 1,
                "audio_path": "xxx.wav",
                "reference": "text",
                "inference_time": 0.123,
                "predicts": "test"
            }

        Items whose reference is empty after normalization are skipped —
        jiwer cannot score an empty reference.
        """
        refs, preds = [], []
        for sample in data:
            ref = self.clean_zh_text(sample["reference"])
            pred = self.clean_zh_text(sample["predicts"])
            if not ref.strip():
                continue
            refs.append(ref)
            preds.append(pred)
        return cer(refs, preds)
|
|
|
|
|
|
|
|
def compute(model_name, data_name, results, language):
    """Score an ASR result set and print/return a one-line report.

    Args:
        model_name: label for the model (report text only).
        data_name: label for the dataset (report text only).
        results: list of result dicts, each with "reference" and
            "predicts" keys (see ASREvaluator docstrings).
        language: "zh" selects Chinese CER; any other value selects
            English WER.

    Returns:
        The formatted report string (also printed to stdout).
    """
    evaluator = ASREvaluator()
    # Use a distinct local name: the originals assigned to `cer`/`wer`,
    # shadowing the jiwer functions imported at the top of the file.
    if language == "zh":
        score = evaluator.compute_zh_cer(results)
        metric = "CER"
    else:
        score = evaluator.compute_en_wer(results)
        metric = "WER"
    res = f"Model: {model_name}, Dataset: {data_name}, data {len(results)}, {metric}: {score:.2%}"
    print(res)
    return res
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import json
    from pathlib import Path

    evaluator = ASREvaluator()
    # Smoke-test the Chinese cleaner on a sentence mixing Latin text,
    # Arabic-style number words, and Chinese punctuation.
    print(evaluator.clean_zh_text("相比sota模型,我们的方法在小样本场景下召回率高出十二个百分点,且参数量仅为其三分之一。"))

    # Score every wenet_net report as Chinese (CER).
    reports = Path("/Users/jeqin/work/code/TestTranslator/reports")
    for file in reports.glob("*wenet_net.json"):
        # Explicit UTF-8: the reports contain Chinese text, and the
        # platform default encoding may not be UTF-8 (e.g. Windows).
        with open(file, encoding="utf-8") as f:
            data = json.load(f)
        compute(model_name=file.name, data_name="wenet_net", results=data, language="zh")