# odg123's picture
# Upload icefall experiment results and logs
# d596074 verified
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
from dataclasses import dataclass
from typing import List
import torch
@dataclass
class TokenSpan:
    """A run of consecutive frames that share the same token.

    Attributes:
      token:
        ID of the token.
      start:
        Start frame of this token in the output log_prob.
      end:
        End frame of this token in the output log_prob. As filled in by
        merge_tokens(), this is the index of the first frame after the
        run, i.e., the span covers frames [start, end).
    """

    token: int
    start: int
    end: int
# See also
# https://github.com/pytorch/audio/blob/main/src/torchaudio/functional/_alignment.py#L96
# We use torchaudio as a reference while implementing this function
def merge_tokens(alignment: List[int], blank: int = 0) -> List[TokenSpan]:
    """Compute start and end frames of each token from the given alignment.

    Consecutive frames carrying the same token ID are merged into a single
    span. Spans whose token equals ``blank`` are discarded.

    Args:
      alignment:
        A list of token IDs, one entry per frame.
      blank:
        ID of the blank.
    Returns:
      Return a list of TokenSpan, ordered by start frame. For each span,
      ``start`` is the first frame of the run and ``end`` is the index of
      the first frame after the run (i.e., exclusive). An empty alignment,
      or one containing only blanks, yields an empty list.
    """
    ans = []
    num_frames = len(alignment)

    # Scan the alignment run by run. Boundaries are found by comparing
    # neighbouring frames directly, so no sentinel value is required and
    # every token ID (including -1) is handled the same way.
    start = 0
    while start < num_frames:
        token = alignment[start]

        # Advance `end` to the first frame whose token differs.
        end = start + 1
        while end < num_frames and alignment[end] == token:
            end += 1

        if token != blank:
            ans.append(TokenSpan(token=token, start=start, end=end))

        # The next run begins where this one stopped.
        start = end

    return ans