File size: 3,911 Bytes
b9196ed
d7df0a5
b9196ed
 
 
 
d7df0a5
0c22680
b9196ed
 
 
 
 
d7df0a5
b9196ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7df0a5
 
 
 
 
 
 
b9196ed
0c22680
 
d7df0a5
 
0c22680
 
 
b9196ed
d7df0a5
b9196ed
d7df0a5
 
b9196ed
d7df0a5
 
b9196ed
d7df0a5
 
 
 
 
 
 
b9196ed
d7df0a5
 
b9196ed
d7df0a5
 
 
 
b9196ed
 
 
 
 
 
 
d7df0a5
 
 
 
b9196ed
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from __future__ import annotations
from typing import List, Optional

from jiwer import wer as jiwer_wer

from core.schemas import AlignOp, AlignResult
from .language_utils import choose_primary_level, detect_lang_type, split_chars_no_space, split_word_like
from .normalize import normalize_text


def _levenshtein_ops(ref: List[str], hyp: List[str]) -> List[AlignOp]:
    n, m = len(ref), len(hyp)
    dp = [[0] * (m + 1) for _ in range(n + 1)]
    bt = [[None] * (m + 1) for _ in range(n + 1)]

    for i in range(n + 1):
        dp[i][0] = i
        bt[i][0] = "D" if i > 0 else None
    for j in range(m + 1):
        dp[0][j] = j
        bt[0][j] = "I" if j > 0 else None

    for i in range(1, n + 1):
        for j in range(1, m + 1):
            if ref[i - 1] == hyp[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]
                bt[i][j] = "OK"
            else:
                sub = dp[i - 1][j - 1] + 1
                dele = dp[i - 1][j] + 1
                ins = dp[i][j - 1] + 1
                best = min(sub, dele, ins)
                dp[i][j] = best
                bt[i][j] = "S" if best == sub else ("D" if best == dele else "I")

    ops: List[AlignOp] = []
    i, j = n, m
    while i > 0 or j > 0:
        action = bt[i][j]
        if action == "OK":
            ops.append(AlignOp(op="OK", ref=ref[i - 1], hyp=hyp[j - 1], i_ref=i - 1, i_hyp=j - 1))
            i -= 1
            j -= 1
        elif action == "S":
            ops.append(AlignOp(op="S", ref=ref[i - 1], hyp=hyp[j - 1], i_ref=i - 1, i_hyp=j - 1))
            i -= 1
            j -= 1
        elif action == "D":
            ops.append(AlignOp(op="D", ref=ref[i - 1], hyp="", i_ref=i - 1, i_hyp=j))
            i -= 1
        elif action == "I":
            ops.append(AlignOp(op="I", ref="", hyp=hyp[j - 1], i_ref=i, i_hyp=j - 1))
            j -= 1
        else:
            break

    ops.reverse()
    return ops


def _rate_from_ops(ops: List[AlignOp], ref_len: int) -> Optional[float]:
    if ref_len == 0:
        return 0.0
    err = sum(1 for op in ops if op.op in ("S", "I", "D"))
    return float(err / ref_len)


def align_one(utt_id: str, ref_text: Optional[str], hyp_text: str) -> AlignResult:
    """Align one utterance's hypothesis against its reference and score it.

    The language type is detected from the non-empty raw texts joined by a
    space, and it selects the primary scoring level. Both texts are
    normalized with that language as a hint. When a reference is present,
    word-level and character-level alignments are computed along with WER
    and CER; without a reference both alignments stay empty and both
    metrics stay ``None``.

    Args:
        utt_id: Identifier copied into the result.
        ref_text: Reference transcript, or ``None`` when unavailable.
        hyp_text: Hypothesis transcript.

    Returns:
        An ``AlignResult`` holding the raw and normalized texts, detected
        language type, primary level/metric, WER, CER and both op lists.
    """
    lang_probe = " ".join(text for text in (ref_text, hyp_text) if text)
    lang_type = detect_lang_type(lang_probe)
    primary_level = choose_primary_level(lang_type)

    norm_ref = None
    if ref_text is not None:
        norm_ref = normalize_text(ref_text, lang_hint=lang_type)
    norm_hyp = normalize_text(hyp_text, lang_hint=lang_type)

    ops_word: List[AlignOp] = []
    ops_char: List[AlignOp] = []
    wer_value: Optional[float] = None
    cer_value: Optional[float] = None

    if norm_ref is not None:
        # Word-level alignment and WER.
        ref_words = split_word_like(norm_ref)
        hyp_words = split_word_like(norm_hyp)
        ops_word = _levenshtein_ops(ref_words, hyp_words)

        # jiwer is only consulted for "en"; any failure there falls back to
        # the rate derived from our own alignment.
        scored_by_jiwer = lang_type == "en"
        if scored_by_jiwer:
            try:
                wer_value = float(jiwer_wer(" ".join(ref_words), " ".join(hyp_words)))
            except Exception:
                scored_by_jiwer = False
        if not scored_by_jiwer:
            wer_value = _rate_from_ops(ops_word, len(ref_words))

        # Character-level alignment and CER.
        ref_chars = split_chars_no_space(norm_ref)
        hyp_chars = split_chars_no_space(norm_hyp)
        ops_char = _levenshtein_ops(ref_chars, hyp_chars)
        cer_value = _rate_from_ops(ops_char, len(ref_chars))

    if primary_level == "word":
        primary_metric_name = "wer"
        primary_metric_value = wer_value
    else:
        primary_metric_name = "cer"
        primary_metric_value = cer_value

    return AlignResult(
        utt_id=utt_id,
        ref_text=ref_text,
        hyp_text=hyp_text,
        norm_ref=norm_ref,
        norm_hyp=norm_hyp,
        lang_type=lang_type,
        primary_level=primary_level,
        primary_metric_name=primary_metric_name,
        primary_metric_value=primary_metric_value,
        wer=wer_value,
        cer=cer_value,
        ops_word=ops_word,
        ops_char=ops_char,
    )