|
|
|
|
|
|
|
|
|
|
|
from typing import List |
|
|
|
|
|
from utils import TokenSpan, merge_tokens |
|
|
|
|
|
|
|
|
def inefficient_merge_tokens(alignment: List[int], blank: int = 0) -> List[TokenSpan]:
    """Compute start and end frames of each token from the given alignment.

    This is a straightforward frame-by-frame implementation; in this file it
    serves as the reference that ``test_merge_tokens`` compares
    ``merge_tokens`` against.

    Consecutive frames carrying the same non-blank token are merged into one
    span. A span is closed by the first frame holding a blank or a different
    non-blank token, so the same token on both sides of a blank produces two
    separate spans.

    Args:
      alignment:
        A list of token IDs, one per frame.
      blank:
        ID of the blank. Blank frames never produce a span.

    Returns:
      Return a list of TokenSpan in frame order. ``start`` is the index of the
      span's first frame and ``end`` is exclusive (one past its last frame).
    """
    ans = []

    # Token of the currently open span, or None when no span is open.
    # It is never equal to `blank`: blank frames only ever close a span.
    last_token = None

    # Frame index at which the open span started, or None when none is open.
    last_i = None

    for i, token in enumerate(alignment):
        if token == blank:
            if last_token is None:
                # Blank with no open span: nothing to close.
                continue

            # A blank closes the open span.
            ans.append(TokenSpan(token=last_token, start=last_i, end=i))
            last_token = None
            last_i = None
            continue

        if last_token is None:
            # First non-blank frame at the start or after a blank: open a span.
            last_token = token
            last_i = i
            continue

        if last_token == token:
            # Repeated token: extend the open span.
            continue

        # A different non-blank token closes the open span and opens a new one.
        ans.append(TokenSpan(token=last_token, start=last_i, end=i))
        last_token = token
        last_i = i

    if last_token is not None:
        # Close a span that runs to the end of the alignment.
        assert last_i is not None, (last_i, last_token)
        ans.append(TokenSpan(token=last_token, start=last_i, end=len(alignment)))

    return ans
|
|
|
|
|
|
|
|
def test_merge_tokens():
    """Check that ``merge_tokens`` agrees with ``inefficient_merge_tokens``.

    Both functions are called with the default blank (0) on every alignment
    below, and their returned span lists must compare equal.
    """
    data_list = [
        # Mixed alignments: leading/trailing blanks, repeated tokens, and the
        # same token on both sides of a blank run (2, 0, 0, 0, 2).
        [0, 1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3, 0],
        [0, 1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3],
        [1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3, 0],
        [1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3],
        # Every token distinct, with and without surrounding blanks.
        [0, 1, 2, 3, 0],
        [1, 2, 3, 0],
        [0, 1, 2, 3],
        [1, 2, 3],
        # Edge cases: a single frame, one token repeated across the whole
        # alignment, blank-only alignments, and a token reappearing after a
        # single blank frame (must yield two spans, not one).
        [1],
        [1, 1, 1],
        [0],
        [0, 0, 0],
        [1, 0, 1],
    ]

    for data in data_list:
        span1 = merge_tokens(data)
        span2 = inefficient_merge_tokens(data)
        assert span1 == span2, (data, span1, span2)
|
|
|
|
|
|
|
|
def main():
    """Script entry point: run the merge_tokens consistency check."""
    test_merge_tokens()
|
|
|
|
|
|
|
|
# Run the self-check only when this file is executed directly, not on import.
if __name__ == "__main__":


    main()
|
|
|