from __future__ import annotations

from typing import List, Optional

from jiwer import wer as jiwer_wer

from core.schemas import AlignOp, AlignResult

from .language_utils import (
    choose_primary_level,
    detect_lang_type,
    split_chars_no_space,
    split_word_like,
)
from .normalize import normalize_text_zh


def _levenshtein_ops(ref: List[str], hyp: List[str]) -> List[AlignOp]:
    """Align ``hyp`` against ``ref`` and return the edit script.

    A full edit-distance table is filled in together with a parallel table
    recording which move produced each cell; the script is then recovered by
    walking the move table back from the bottom-right corner.

    Args:
        ref: Reference tokens.
        hyp: Hypothesis tokens.

    Returns:
        ``AlignOp`` items in left-to-right order. ``op`` is one of ``"OK"``
        (match), ``"S"`` (substitution), ``"D"`` (deletion: reference token
        with no hypothesis counterpart, ``hyp=""``) or ``"I"`` (insertion:
        hypothesis token with no reference counterpart, ``ref=""``).
        ``i_ref`` / ``i_hyp`` are 0-based positions in the inputs.
    """
    rows, cols = len(ref) + 1, len(hyp) + 1
    cost = [[0] * cols for _ in range(rows)]
    move: List[List[Optional[str]]] = [[None] * cols for _ in range(rows)]

    # Borders: an empty hypothesis prefix can only be reached by deletions,
    # an empty reference prefix only by insertions. Cell (0, 0) keeps None.
    for r in range(1, rows):
        cost[r][0] = r
        move[r][0] = "D"
    for c in range(1, cols):
        cost[0][c] = c
        move[0][c] = "I"

    for r, ref_tok in enumerate(ref, start=1):
        above, here, here_move = cost[r - 1], cost[r], move[r]
        for c, hyp_tok in enumerate(hyp, start=1):
            if ref_tok == hyp_tok:
                # Equal tokens always take the diagonal at no extra cost.
                here[c] = above[c - 1]
                here_move[c] = "OK"
                continue

            sub_cost = above[c - 1] + 1
            del_cost = above[c] + 1
            ins_cost = here[c - 1] + 1
            # Ties are resolved in the order substitution, deletion, insertion.
            if sub_cost <= del_cost and sub_cost <= ins_cost:
                here[c] = sub_cost
                here_move[c] = "S"
            elif del_cost <= ins_cost:
                here[c] = del_cost
                here_move[c] = "D"
            else:
                here[c] = ins_cost
                here_move[c] = "I"

    # Backtrace from the bottom-right corner; steps are collected in reverse.
    trail: List[AlignOp] = []
    r, c = len(ref), len(hyp)
    while r > 0 or c > 0:
        step = move[r][c]
        if step in ("OK", "S"):
            r -= 1
            c -= 1
            trail.append(AlignOp(op=step, ref=ref[r], hyp=hyp[c], i_ref=r, i_hyp=c))
        elif step == "D":
            r -= 1
            trail.append(AlignOp(op="D", ref=ref[r], hyp="", i_ref=r, i_hyp=c))
        elif step == "I":
            c -= 1
            trail.append(AlignOp(op="I", ref="", hyp=hyp[c], i_ref=r, i_hyp=c))
        else:
            # No recorded move for this cell: stop rather than loop forever.
            break

    return trail[::-1]


def _rate_from_ops(ops: List[AlignOp], ref_len: int) -> Optional[float]:
    """Return the error rate implied by an edit script.

    Args:
        ops: Alignment operations; those whose ``op`` is ``"S"``, ``"I"`` or
            ``"D"`` count as errors.
        ref_len: Number of reference tokens used as the denominator.

    Returns:
        ``errors / ref_len``, or ``0.0`` when ``ref_len`` is zero.
    """
    if not ref_len:
        # NOTE(review): an empty reference yields 0.0 even if ``ops`` contains
        # insertions -- confirm this is the intended convention.
        return 0.0
    mistakes = sum(item.op in ("S", "I", "D") for item in ops)
    return mistakes / ref_len


def align_one(utt_id: str, ref_text: Optional[str], hyp_text: str) -> AlignResult:
    """Align one utterance's hypothesis against its reference and score it.

    Both texts are passed through ``normalize_text_zh``. When a reference is
    available, word-level and character-level alignments are computed along
    with WER and CER; without a reference the op lists stay empty and both
    metrics stay ``None``.

    Args:
        utt_id: Utterance identifier, copied into the result.
        ref_text: Reference transcript, or ``None`` if there is none.
        hyp_text: Hypothesis transcript.

    Returns:
        An ``AlignResult`` holding the raw and normalised texts, the detected
        language type, the primary level with its metric name and value, both
        metrics, and both edit scripts.
    """
    if ref_text is None:
        norm_ref = None
    else:
        norm_ref = normalize_text_zh(ref_text)
    norm_hyp = normalize_text_zh(hyp_text)

    # Language detection uses the reference when it is present and non-empty,
    # otherwise the hypothesis.
    lang_type = detect_lang_type(norm_ref or norm_hyp)
    primary_level = choose_primary_level(lang_type)

    ops_word: List[AlignOp] = []
    ops_char: List[AlignOp] = []
    wer_value: Optional[float] = None
    cer_value: Optional[float] = None

    if norm_ref is not None:
        ref_words = split_word_like(norm_ref)
        hyp_words = split_word_like(norm_hyp)
        ops_word = _levenshtein_ops(ref_words, hyp_words)

        # For "en", jiwer is tried first; any exception it raises, and every
        # other language type, falls back to the rate from our own alignment.
        use_own_rate = lang_type != "en"
        if not use_own_rate:
            try:
                wer_value = float(jiwer_wer(" ".join(ref_words), " ".join(hyp_words)))
            except Exception:
                use_own_rate = True
        if use_own_rate:
            wer_value = _rate_from_ops(ops_word, len(ref_words))

        ref_chars = split_chars_no_space(norm_ref)
        hyp_chars = split_chars_no_space(norm_hyp)
        ops_char = _levenshtein_ops(ref_chars, hyp_chars)
        cer_value = _rate_from_ops(ops_char, len(ref_chars))

    if primary_level == "word":
        primary_metric_name, primary_metric_value = "wer", wer_value
    else:
        primary_metric_name, primary_metric_value = "cer", cer_value

    return AlignResult(
        utt_id=utt_id,
        ref_text=ref_text,
        hyp_text=hyp_text,
        norm_ref=norm_ref,
        norm_hyp=norm_hyp,
        lang_type=lang_type,
        primary_level=primary_level,
        primary_metric_name=primary_metric_name,
        primary_metric_value=primary_metric_value,
        wer=wer_value,
        cer=cer_value,
        ops_word=ops_word,
        ops_char=ops_char,
    )