ASR_AGENT_ / analysis /align.py
unknown
Add my local files
b9196ed
from __future__ import annotations
from typing import List, Optional, Tuple
from jiwer import wer as jiwer_wer
from core.schemas import AlignOp, AlignResult
from .normalize import normalize_text_zh
def _split_words(text: str) -> List[str]:
# 简单:按空格切。中文如果没有空格,会作为一个整体。
# 你可以后续接入 jieba 或自定义 tokenizer,把中文按字/词切开。
return text.split() if text else []
def _split_chars(text: str) -> List[str]:
return list(text) if text else []
def _levenshtein_ops(ref: List[str], hyp: List[str]) -> List[AlignOp]:
# 经典DP求最短编辑路径,并回溯得到 S/I/D/OK
n, m = len(ref), len(hyp)
dp = [[0] * (m + 1) for _ in range(n + 1)]
bt = [[None] * (m + 1) for _ in range(n + 1)] # backtrace
for i in range(n + 1):
dp[i][0] = i
bt[i][0] = "D" if i > 0 else None
for j in range(m + 1):
dp[0][j] = j
bt[0][j] = "I" if j > 0 else None
for i in range(1, n + 1):
for j in range(1, m + 1):
if ref[i - 1] == hyp[j - 1]:
dp[i][j] = dp[i - 1][j - 1]
bt[i][j] = "OK"
else:
# substitute, delete, insert
sub = dp[i - 1][j - 1] + 1
dele = dp[i - 1][j] + 1
ins = dp[i][j - 1] + 1
best = min(sub, dele, ins)
dp[i][j] = best
bt[i][j] = "S" if best == sub else ("D" if best == dele else "I")
# backtrace
ops: List[AlignOp] = []
i, j = n, m
while i > 0 or j > 0:
action = bt[i][j]
if action == "OK":
ops.append(AlignOp(op="OK", ref=ref[i - 1], hyp=hyp[j - 1], i_ref=i - 1, i_hyp=j - 1))
i -= 1
j -= 1
elif action == "S":
ops.append(AlignOp(op="S", ref=ref[i - 1], hyp=hyp[j - 1], i_ref=i - 1, i_hyp=j - 1))
i -= 1
j -= 1
elif action == "D":
ops.append(AlignOp(op="D", ref=ref[i - 1], hyp="", i_ref=i - 1, i_hyp=j))
i -= 1
elif action == "I":
ops.append(AlignOp(op="I", ref="", hyp=hyp[j - 1], i_ref=i, i_hyp=j - 1))
j -= 1
else:
# should not happen
break
ops.reverse()
return ops
def _error_rate(ops: List[AlignOp], ref_len: int) -> float:
    """Return (S + D + I) / ref_len for an alignment from ``_levenshtein_ops``.

    With an empty reference, the denominator is clamped to 1. The result is
    then the raw error count: 0.0 when the hypothesis is also empty, otherwise
    the number of inserted tokens. This avoids a division by zero without
    reporting a non-empty hypothesis against an empty reference as perfect.
    """
    errors = sum(1 for op in ops if op.op in ("S", "I", "D"))
    return float(errors) / max(ref_len, 1)


def align_one(utt_id: str, ref_text: Optional[str], hyp_text: str) -> AlignResult:
    """Normalize one reference/hypothesis pair and align it at word and char level.

    Both texts are passed through ``normalize_text_zh``. When ``ref_text`` is
    None (no reference available), no alignment is done: ``wer`` and ``cer``
    are None, and ``ops_word`` / ``ops_char`` are empty lists.

    Word level: tokens come from ``_split_words`` (a plain whitespace split),
    so unspaced Chinese forms one token per whitespace-delimited run.
    Char level (usually more meaningful for Chinese): every whitespace
    character is removed and each remaining character is one token.

    WER and CER are both computed from the alignments that are returned
    (``ops_word`` / ``ops_char``) as (S + D + I) / number of reference tokens.
    An empty normalized reference is handled by ``_error_rate`` (denominator
    clamped to 1) instead of raising or yielding a spurious 0.0.

    Args:
        utt_id: Utterance identifier, copied into the result.
        ref_text: Raw reference transcript, or None if unavailable.
        hyp_text: Raw ASR hypothesis.

    Returns:
        An AlignResult carrying raw and normalized texts, WER/CER and the
        word- and char-level edit operations.
    """
    norm_ref = normalize_text_zh(ref_text) if ref_text is not None else None
    norm_hyp = normalize_text_zh(hyp_text)

    ops_word: List[AlignOp] = []
    ops_char: List[AlignOp] = []
    wer_value: Optional[float] = None
    cer_value: Optional[float] = None

    if norm_ref is not None:
        # Word level. WER is derived from ops_word itself, not from a separate
        # jiwer call, so the rate always agrees with the reported S/D/I ops.
        # It also cannot fail on an empty reference (jiwer 2.x/3.0 raise
        # ValueError for empty reference strings).
        ref_w = _split_words(norm_ref)
        hyp_w = _split_words(norm_hyp)
        ops_word = _levenshtein_ops(ref_w, hyp_w)
        wer_value = _error_rate(ops_word, len(ref_w))

        # Char level. Drop all whitespace (not only ASCII spaces) so tabs or
        # full-width spaces left by normalization are not scored as characters.
        ref_c = _split_chars("".join(norm_ref.split()))
        hyp_c = _split_chars("".join(norm_hyp.split()))
        ops_char = _levenshtein_ops(ref_c, hyp_c)
        cer_value = _error_rate(ops_char, len(ref_c))

    return AlignResult(
        utt_id=utt_id,
        ref_text=ref_text,
        hyp_text=hyp_text,
        norm_ref=norm_ref,
        norm_hyp=norm_hyp,
        wer=wer_value,
        cer=cer_value,
        ops_word=ops_word,
        ops_char=ops_char,
    )