ASR_AGENT_ / analysis /events.py
unknown
Update wer and cer
d7df0a5
from __future__ import annotations
import re
from typing import Dict, List
from core.schemas import AlignResult, ErrorEvent
# Heuristic character patterns used by classify_error(). They are applied with
# .search(), so one match anywhere in the ref+hyp text is enough to trigger.

# Any ASCII Latin letter (used together with RE_CJK to detect code-switching).
RE_HAS_LATIN = re.compile(r"[A-Za-z]")
# Any digit. On a str pattern, \d matches every Unicode decimal digit
# (e.g. full-width digits), not only ASCII 0-9.
RE_HAS_DIGIT = re.compile(r"\d")
# Chinese numeral characters (zero..nine, ten, hundred, thousand, 10^4, 10^8),
# plus the "point" character and the alternative "two" character.
# NOTE(review): characters such as 一 and 十 are very common inside ordinary,
# non-numeric words, so this pattern can over-trigger "number_or_time".
RE_NUM_ZH = re.compile(r"[零一二三四五六七八九十百千万亿点两]")
# Chinese time/date unit words (o'clock, minute, second, hour, day, month,
# year, day-of-month, week). NOTE(review): single characters like 分, 日 and
# 号 also occur in many non-time words.
RE_TIME_UNIT = re.compile(r"(点|分|秒|小时|天|月|年|号|日|周|星期)")
# Any character in the CJK Unified Ideographs block U+4E00..U+9FFF; the CJK
# extension blocks are not covered.
RE_CJK = re.compile(r"[\u4e00-\u9fff]")
def classify_error(ref: str, hyp: str, lang_type: str, level: str) -> str:
    """Return a coarse error category for one aligned ref/hyp token pair.

    The category depends on which sides are present:

    * both empty -> ``"other"``
    * only ``ref`` -> deletion (``"word_deletion"`` for English word level,
      otherwise ``"deletion"``)
    * only ``hyp`` -> insertion (``"word_insertion"`` for English word level,
      otherwise ``"insertion"``)
    * both present -> substitution, refined in priority order into
      ``"number_or_time"``, ``"code_switch"``, ``"word_substitution"``,
      ``"char_substitution"`` or plain ``"substitution"``.

    Args:
        ref: Reference token(s); empty/falsy when absent.
        hyp: Hypothesis token(s); empty/falsy when absent.
        lang_type: Language tag of the utterance, e.g. ``"en"``, ``"zh"``,
            ``"mixed"``.
        level: Alignment granularity, ``"word"`` or ``"char"``.

    Returns:
        The category name.
    """
    en_word = lang_type == "en" and level == "word"

    if not ref and not hyp:
        return "other"
    if not hyp:
        # Only the reference side is present.
        return "word_deletion" if en_word else "deletion"
    if not ref:
        # Only the hypothesis side is present.
        return "word_insertion" if en_word else "insertion"

    # Substitution: inspect the characters of both sides together.
    pair_text = f"{ref}{hyp}"
    number_patterns = (RE_HAS_DIGIT, RE_NUM_ZH, RE_TIME_UNIT)
    if any(pattern.search(pair_text) for pattern in number_patterns):
        return "number_or_time"
    if RE_HAS_LATIN.search(pair_text) and RE_CJK.search(pair_text):
        # Latin letters and CJK ideographs mixed across ref/hyp.
        return "code_switch"
    if en_word:
        return "word_substitution"
    if level == "char" and lang_type in {"zh", "mixed"}:
        return "char_substitution"
    return "substitution"
def extract_events(run_id: str, align: AlignResult, meta: Dict) -> List[ErrorEvent]:
    """Convert one utterance's alignment ops into a list of ErrorEvents.

    Word-level ops (``align.ops_word``) are processed first, then char-level
    ops (``align.ops_char``). Every op whose code is ``"S"``, ``"I"`` or
    ``"D"`` yields one ErrorEvent whose ``error_class`` is computed by
    ``classify_error``; ``"OK"`` ops and any other op codes yield nothing.
    A level whose op sequence is ``None`` contributes no events.

    Args:
        run_id: Copied onto every event.
        align: Alignment result for a single utterance.
        meta: Passed to every event. The same object is shared by all events;
            this function does not copy it.

    Returns:
        The events in processing order; empty when ``align.ref_text`` is
        ``None``.
    """
    events: List[ErrorEvent] = []
    # No reference text: emit no events at all.
    if align.ref_text is None:
        return events

    def from_ops(level: str, ops) -> None:
        """Append one ErrorEvent per S/I/D op in ``ops`` to ``events``."""
        # Treat a missing op sequence as empty instead of raising TypeError
        # on iteration (which would also drop the other level's events).
        if ops is None:
            return
        # ``pos`` is the op's index among the "OK"/"S"/"I"/"D" ops of this
        # level; ops with any other code are skipped without advancing it.
        pos = 0
        for op in ops:
            if op.op == "OK":
                pos += 1
                continue
            if op.op in ("S", "I", "D"):
                err_class = classify_error(op.ref, op.hyp, align.lang_type, level)
                events.append(
                    ErrorEvent(
                        run_id=run_id,
                        utt_id=align.utt_id,
                        op_type=op.op,
                        ref=op.ref,
                        hyp=op.hyp,
                        position=pos,
                        level=level,
                        lang_type=align.lang_type,
                        is_primary_level=(level == align.primary_level),
                        error_class=err_class,
                        meta=meta,
                    )
                )
                pos += 1

    from_ops("word", align.ops_word)
    from_ops("char", align.ops_char)
    return events