# TestTranslator / scripts / caculate_cer.py
# Author: yujuanqin
# Commit message: update scripts
# Commit: 42742c6
import json
from lib.utils import run_textdistance, clean_text_for_comparison_zh, highlight_diff
import re
# import Levenshtein
def calculate_distance(reference: str, hypothesis: str):
    """Compute the character-level edit distance between two transcripts.

    Both strings are first normalised with ``clean_text_for_comparison_zh``
    and then compared with ``run_textdistance``. Turning the distance into a
    character error rate (CER) is left to the caller:

        CER = (Substitutions + Deletions + Insertions) / characters in reference
            = edit distance / characters in the *cleaned* reference

    Args:
        reference: Ground-truth transcription.
        hypothesis: Transcription predicted by the ASR model.

    Returns:
        A ``(distance, diff)`` tuple. ``distance`` is the first value returned
        by ``run_textdistance`` on the cleaned texts (presumably the raw
        Levenshtein distance -- confirm in ``lib.utils``). ``diff`` is the
        ``highlight_diff`` rendering of the cleaned texts, or ``""`` when the
        distance is 0.
    """
    ref_clean = clean_text_for_comparison_zh(reference)
    hyp_clean = clean_text_for_comparison_zh(hypothesis)
    # run_textdistance returns a second value (presumably a normalised
    # distance); it is not needed here.
    distance, _ = run_textdistance(ref_clean, hyp_clean)
    # Only render a diff when the cleaned texts actually differ.
    diff = highlight_diff(ref_clean, hyp_clean, spliter="") if distance > 0 else ""
    return distance, diff
if __name__ == '__main__':
    # Aggregate the edit distance over every item in an ASR results file and
    # report the corpus-level CER (total distance / total reference length).
    import sys

    # Results file: optional first CLI argument, defaulting to the original
    # hard-coded path so existing invocations keep working.
    results_path = sys.argv[1] if len(sys.argv) > 1 else "csv/funasr_wenet_results.json"
    # Use a context manager so the file handle is closed after loading.
    with open(results_path, encoding="utf-8") as results_file:
        results_list = json.load(results_file)

    distance_sum = 0
    reference_sum = 0
    for count, item in enumerate(results_list, start=1):
        reference = item["reference"]
        hypothesis = item["inference_result"]
        # For Whisper output, uncomment to convert Arabic digits in the
        # hypothesis to Chinese numerals with cn2an (requires the cn2an package):
        # import cn2an
        # if re.search(r"\d", hypothesis):
        #     hypothesis = cn2an.transform(hypothesis, "an2cn")
        distance, diff = calculate_distance(reference, hypothesis)
        print(f"{count}. distance: {distance}")
        if distance > 0:
            print(f"Audio Path: {item['audio_path']}")
            print(f"Reference: {reference}")
            print(f"Hypothesis: {hypothesis}")
            print(f"Diff: {diff}")
        distance_sum += distance
        # The distance is measured on the cleaned texts, so the CER denominator
        # must count the characters of the cleaned reference as well; using the
        # raw length would include characters the cleaner strips out.
        reference_sum += len(clean_text_for_comparison_zh(reference))
    # Guard against an empty results file (or all-empty references).
    cer = distance_sum / reference_sum if reference_sum > 0 else 0
    print(f"Total Distance: {distance_sum}, Total Reference Length: {reference_sum}, CER: {cer:.4f}")