File size: 1,348 Bytes
d596074
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Copyright      2024  Xiaomi Corp.        (authors: Fangjun Kuang)

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class TokenSpan:
    """A run of consecutive frames that were all aligned to the same token.

    Attributes:
      token:
        ID of the token.
      start:
        Index of the first frame of this token in the output log_prob.
      end:
        End frame of this token in the output log_prob. When the span is
        produced by ``merge_tokens`` in this file, ``end`` is exclusive,
        i.e. the token covers frames ``start, ..., end - 1``.
    """

    token: int
    start: int
    end: int


# See also
# https://github.com/pytorch/audio/blob/main/src/torchaudio/functional/_alignment.py#L96
# We use torchaudio as a reference while implementing this function
def merge_tokens(alignment: List[int], blank: int = 0) -> List[TokenSpan]:
    """Compute start and end frames of each token from the given alignment.

    Consecutive frames carrying the same token ID are merged into a single
    span. Spans whose token equals ``blank`` are dropped from the result.

    Args:
      alignment:
        A list of token IDs, one entry per frame.
      blank:
        ID of the blank.
    Returns:
      Return a list of TokenSpan, ordered by start frame. For each span,
      ``start`` is the first frame of the run and ``end`` is one past its
      last frame (i.e., exclusive). An empty alignment yields an empty list.
    """
    num_frames = len(alignment)
    if num_frames == 0:
        return []

    alignment_tensor = torch.tensor(alignment, dtype=torch.int32)

    # A new run starts at every frame whose token differs from the previous
    # frame. Comparing neighbours directly avoids padding the sequence with a
    # sentinel value, which would hide a run whose token ID happens to equal
    # the sentinel at either end of the alignment.
    changed = alignment_tensor[1:] != alignment_tensor[:-1]

    # changed[i] compares frames i and i + 1, so the new run begins at i + 1.
    # flatten() (rather than squeeze()) always yields a 1-D result, so
    # tolist() returns a list whatever the number of change points is.
    change_points = (torch.nonzero(changed).flatten() + 1).tolist()

    # Frame 0 always opens the first run and num_frames closes the last one.
    boundaries = [0] + change_points + [num_frames]

    return [
        TokenSpan(token=alignment[start], start=start, end=end)
        for start, end in zip(boundaries[:-1], boundaries[1:])
        if alignment[start] != blank
    ]