Spaces:
Running
Running
| from __future__ import annotations | |
| from typing import List, Optional, Tuple | |
| from jiwer import wer as jiwer_wer | |
| from core.schemas import AlignOp, AlignResult | |
| from .normalize import normalize_text_zh | |
| def _split_words(text: str) -> List[str]: | |
| # 简单:按空格切。中文如果没有空格,会作为一个整体。 | |
| # 你可以后续接入 jieba 或自定义 tokenizer,把中文按字/词切开。 | |
| return text.split() if text else [] | |
| def _split_chars(text: str) -> List[str]: | |
| return list(text) if text else [] | |
def _levenshtein_ops(ref: List[str], hyp: List[str]) -> List[AlignOp]:
    """Return a minimal edit path aligning *ref* to *hyp*.

    Classic Levenshtein dynamic program with a backtrace table; the result is
    a left-to-right list of AlignOp records with op in {"OK", "S", "D", "I"}.
    Ties are broken in favor of substitution, then deletion, then insertion.
    """
    rows, cols = len(ref), len(hyp)
    cost = [[0] * (cols + 1) for _ in range(rows + 1)]
    trace = [[None] * (cols + 1) for _ in range(rows + 1)]

    # Boundary rows: first column deletes every ref token, first row inserts
    # every hyp token; cell (0, 0) keeps cost 0 and no backtrace action.
    for r in range(1, rows + 1):
        cost[r][0] = r
        trace[r][0] = "D"
    for c in range(1, cols + 1):
        cost[0][c] = c
        trace[0][c] = "I"

    for r in range(1, rows + 1):
        for c in range(1, cols + 1):
            if ref[r - 1] == hyp[c - 1]:
                cost[r][c] = cost[r - 1][c - 1]
                trace[r][c] = "OK"
                continue
            sub_cost = cost[r - 1][c - 1] + 1
            del_cost = cost[r - 1][c] + 1
            ins_cost = cost[r][c - 1] + 1
            cheapest = min(sub_cost, del_cost, ins_cost)
            cost[r][c] = cheapest
            if cheapest == sub_cost:
                trace[r][c] = "S"
            elif cheapest == del_cost:
                trace[r][c] = "D"
            else:
                trace[r][c] = "I"

    # Walk back from the bottom-right corner, collecting ops in reverse order.
    path: List[AlignOp] = []
    r, c = rows, cols
    while r > 0 or c > 0:
        step = trace[r][c]
        if step in ("OK", "S"):
            path.append(AlignOp(op=step, ref=ref[r - 1], hyp=hyp[c - 1], i_ref=r - 1, i_hyp=c - 1))
            r -= 1
            c -= 1
        elif step == "D":
            path.append(AlignOp(op="D", ref=ref[r - 1], hyp="", i_ref=r - 1, i_hyp=c))
            r -= 1
        elif step == "I":
            path.append(AlignOp(op="I", ref="", hyp=hyp[c - 1], i_ref=r, i_hyp=c - 1))
            c -= 1
        else:
            # Defensive only: trace is fully populated whenever r > 0 or c > 0.
            break

    path.reverse()
    return path
def align_one(utt_id: str, ref_text: Optional[str], hyp_text: str) -> AlignResult:
    """Align one hypothesis against its (optional) reference text.

    Normalizes both sides with normalize_text_zh, then computes a word-level
    alignment plus WER and a character-level alignment plus CER (the char
    level is the meaningful one for Chinese). When ref_text is None, the
    op lists stay empty and both metrics stay None.

    Fixes over the naive version:
      - jiwer.wer raises on an empty reference, so an empty normalized ref is
        handled locally instead of crashing.
      - An empty reference with a non-empty hypothesis now scores 1.0 (fully
        wrong) instead of a misleading 0.0.
    """
    norm_ref = normalize_text_zh(ref_text) if ref_text is not None else None
    norm_hyp = normalize_text_zh(hyp_text)

    # Word-level alignment + WER.
    ops_word: List[AlignOp] = []
    wer_value: Optional[float] = None
    if norm_ref is not None:
        ref_w = _split_words(norm_ref)
        hyp_w = _split_words(norm_hyp)
        ops_word = _levenshtein_ops(ref_w, hyp_w)
        if ref_w:
            wer_value = float(jiwer_wer(norm_ref, norm_hyp))
        else:
            # jiwer raises ValueError on an empty reference; use the bounded
            # convention: 1.0 if anything was hypothesized, else 0.0.
            wer_value = 1.0 if hyp_w else 0.0

    # Char-level alignment + CER (better suited to Chinese); spaces are
    # stripped so tokenization does not affect the character metric.
    ops_char: List[AlignOp] = []
    cer_value: Optional[float] = None
    if norm_ref is not None:
        ref_c = _split_chars(norm_ref.replace(" ", ""))
        hyp_c = _split_chars(norm_hyp.replace(" ", ""))
        ops_char = _levenshtein_ops(ref_c, hyp_c)
        # CER = (S + I + D) / len(ref)
        if ref_c:
            err = sum(1 for op in ops_char if op.op in ("S", "I", "D"))
            cer_value = float(err / len(ref_c))
        else:
            # Empty reference: same bounded convention as WER above.
            cer_value = 1.0 if hyp_c else 0.0

    return AlignResult(
        utt_id=utt_id,
        ref_text=ref_text,
        hyp_text=hyp_text,
        norm_ref=norm_ref,
        norm_hyp=norm_hyp,
        wer=wer_value,
        cer=cer_value,
        ops_word=ops_word,
        ops_char=ops_char,
    )