| # Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) | |
| from dataclasses import dataclass | |
| from typing import List | |
| import torch | |
@dataclass
class TokenSpan:
    """A run of identical, non-blank token IDs in a frame-level alignment.

    Instances are created by ``merge_tokens`` with keyword arguments, which
    requires the generated ``__init__`` from ``@dataclass``.
    """

    # ID of the token
    token: int

    # Start frame of this token in the output log_prob (inclusive)
    start: int

    # End frame of this token in the output log_prob (exclusive, i.e. one
    # past the last frame of the run, as produced by merge_tokens)
    end: int
| # See also | |
| # https://github.com/pytorch/audio/blob/main/src/torchaudio/functional/_alignment.py#L96 | |
| # We use torchaudio as a reference while implementing this function | |
| def merge_tokens(alignment: List[int], blank: int = 0) -> List[TokenSpan]: | |
| """Compute start and end frames of each token from the given alignment. | |
| Args: | |
| alignment: | |
| A list of token IDs. | |
| blank_id: | |
| ID of the blank. | |
| Returns: | |
| Return a list of TokenSpan. | |
| """ | |
| alignment_tensor = torch.tensor(alignment, dtype=torch.int32) | |
| diff = torch.diff( | |
| alignment_tensor, | |
| prepend=torch.tensor([-1]), | |
| append=torch.tensor([-1]), | |
| ) | |
| non_zero_indexes = torch.nonzero(diff != 0).squeeze().tolist() | |
| ans = [] | |
| for start, end in zip(non_zero_indexes[:-1], non_zero_indexes[1:]): | |
| token = alignment[start] | |
| if token == blank: | |
| continue | |
| span = TokenSpan(token=token, start=start, end=end) | |
| ans.append(span) | |
| return ans | |