File size: 2,285 Bytes
d596074
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#!/usr/bin/env python3
# Copyright      2024  Xiaomi Corp.        (authors: Fangjun Kuang)

from typing import List

from utils import TokenSpan, merge_tokens


def inefficient_merge_tokens(alignment: List[int], blank: int = 0) -> List[TokenSpan]:
    """Compute start and end frames of each token from the given alignment.

    Args:
      alignment:
        A list of token IDs.
      blank_id:
        ID of the blank.
    Returns:
      Return a list of TokenSpan.
    """
    ans = []
    last_token = None
    last_i = None

    #  import pdb

    #  pdb.set_trace()
    for i, token in enumerate(alignment):
        if token == blank:
            if last_token is None or last_token == token:
                continue

            # end of the last token
            span = TokenSpan(token=last_token, start=last_i, end=i)
            ans.append(span)
            last_token = None
            last_i = None
            continue

        # The current token is not a blank
        if last_token is None or last_token == blank:
            last_token = token
            last_i = i
            continue

        if last_token == token:
            continue

        # end of the last token and start of the current token
        span = TokenSpan(token=last_token, start=last_i, end=i)
        last_token = token
        last_i = i
        ans.append(span)

    if last_token is not None:
        assert last_i is not None, (last_i, last_token)
        span = TokenSpan(token=last_token, start=last_i, end=len(alignment))
        # Note for the last token, its end is larger than len(alignment)-1
        ans.append(span)

    return ans


def test_merge_tokens():
    data_list = [
        # 0  1  2  3  4  5  6  7  8  9  10 11 12 13 14
        [0, 1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3, 0],
        [0, 1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3],
        [1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3, 0],
        [1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3],
        [0, 1, 2, 3, 0],
        [1, 2, 3, 0],
        [0, 1, 2, 3],
        [1, 2, 3],
    ]

    for data in data_list:
        span1 = merge_tokens(data)
        span2 = inefficient_merge_tokens(data)
        assert span1 == span2, (data, span1, span2)


def main():
    test_merge_tokens()


if __name__ == "__main__":
    main()