|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import k2 |
|
|
import pytest |
|
|
import torch |
|
|
|
|
|
from icefall.env import get_env_info |
|
|
from icefall.utils import ( |
|
|
AttributeDict, |
|
|
add_eos, |
|
|
add_sos, |
|
|
encode_supervisions, |
|
|
get_texts, |
|
|
make_pad_mask, |
|
|
) |
|
|
|
|
|
|
|
|
@pytest.fixture
def sup():
    """Return a three-utterance supervision dict.

    ``num_frames`` is deliberately not in descending order, so tests can
    observe any reordering done by the code under test.
    """
    return dict(
        sequence_idx=torch.tensor([0, 1, 2]),
        start_frame=torch.tensor([1, 3, 9]),
        num_frames=torch.tensor([20, 30, 10]),
        text=["one", "two", "three"],
    )
|
|
|
|
|
|
|
|
def test_encode_supervisions(sup):
    """``encode_supervisions`` subsamples frame indices and orders rows by
    subsampled ``num_frames`` in descending order, reordering texts to match.
    """
    subsampling_factor = 4
    supervision_segments, texts = encode_supervisions(
        sup, subsampling_factor=subsampling_factor
    )

    # Each row is (sequence_idx, start_frame // factor, num_frames // factor);
    # rows are ordered by num_frames descending (30 > 20 > 10).
    expected = torch.tensor(
        [
            [1, 3 // subsampling_factor, 30 // subsampling_factor],
            [0, 1 // subsampling_factor, 20 // subsampling_factor],
            [2, 9 // subsampling_factor, 10 // subsampling_factor],
        ]
    )
    # torch.eq broadcasts, so check the shape explicitly; otherwise a result
    # with a compatible-but-wrong shape could still compare equal.
    assert supervision_segments.shape == expected.shape
    assert torch.all(torch.eq(supervision_segments, expected))

    # Texts must follow the same permutation as the segment rows.
    assert texts == ["two", "one", "three"]
|
|
|
|
|
|
|
|
def test_get_texts_ragged():
    """``get_texts`` on FSAs with ragged aux_labels.

    Labels 0 and -1 are dropped and the remaining aux labels of each FSA
    are concatenated in arc order.
    """
    first = k2.Fsa.from_str(
        """
        0 1 1 10
        1 2 2 20
        2 3 3 30
        3 4 -1 0
        4
        """
    )
    # One sublist of aux labels per arc (4 arcs), including an empty one.
    first.aux_labels = k2.RaggedTensor("[ [1 3 0 2] [] [4 0 1] [-1]]")

    second = k2.Fsa.from_str(
        """
        0 1 1 1
        1 2 2 2
        2 3 -1 0
        3
        """
    )
    second.aux_labels = k2.RaggedTensor("[[3 0 5 0 8] [0 9 7 0] [-1]]")

    batch = k2.Fsa.from_fsas([first, second])
    expected = [[1, 3, 2, 4, 1], [3, 5, 8, 9, 7]]
    assert get_texts(batch) == expected
|
|
|
|
|
|
|
|
def test_get_texts_regular():
    """``get_texts`` on FSAs with one regular aux label per arc.

    Aux labels 0 and -1 are dropped; the rest are kept in arc order.
    """
    # Arc format: src dst label aux_label score.
    first = k2.Fsa.from_str(
        """
        0 1 1 3 10
        1 2 2 0 20
        2 3 3 2 30
        3 4 -1 -1 0
        4
        """,
        num_aux_labels=1,
    )

    second = k2.Fsa.from_str(
        """
        0 1 1 10 1
        1 2 2 5 2
        2 3 -1 -1 0
        3
        """,
        num_aux_labels=1,
    )

    batch = k2.Fsa.from_fsas([first, second])
    expected = [[3, 2], [10, 5]]
    assert get_texts(batch) == expected
|
|
|
|
|
|
|
|
def test_attribute_dict():
    """``AttributeDict`` exposes dict items as attributes (get/set/del)."""
    s = AttributeDict({"a": 10, "b": 20})

    # Attribute access and item access read the same entries.
    assert s.a == 10
    assert s["b"] == 20

    # Attribute assignment is visible through item access.
    s.c = 100
    assert s["c"] == 100

    assert hasattr(s, "a")
    assert hasattr(s, "b")
    assert getattr(s, "a") == 10

    del s.a
    assert not hasattr(s, "a")
    assert "a" not in s

    # Use a new value so the assertion proves setattr actually updated "c".
    setattr(s, "c", 200)
    assert s.c == 200
    assert s["c"] == 200

    # Deleting a missing attribute must raise. The previous try/except only
    # printed, so the test would silently pass if no exception was raised.
    with pytest.raises(AttributeError):
        del s.a
|
|
|
|
|
|
|
|
def test_get_env_info():
    """Smoke test: collecting environment info must not raise."""
    print(get_env_info())
|
|
|
|
|
|
|
|
def test_makd_pad_mask():
    """``make_pad_mask`` marks positions at or beyond each length as True.

    The name keeps its historical spelling so existing test selections
    (e.g. ``-k test_makd_pad_mask``) keep working.
    """
    lengths = torch.tensor([1, 3, 2])
    mask = make_pad_mask(lengths)
    expected = torch.tensor(
        [
            [False, True, True],
            [False, False, False],
            [False, False, True],
        ]
    )
    # torch.eq broadcasts, so check the shape before comparing elements.
    assert mask.shape == expected.shape
    assert torch.all(torch.eq(mask, expected))

    # The number of non-padded positions in the *computed* mask must equal
    # the sum of the lengths. The original checked ``expected`` here, which
    # can never catch a bug in make_pad_mask.
    assert (~mask).sum() == lengths.sum()
|
|
|
|
|
|
|
|
def test_add_sos():
    """``add_sos`` prepends ``sos_id`` to every sublist of a ragged tensor."""
    sos_id = 100
    rows = [[1, 2], [3], [0]]
    result = add_sos(k2.RaggedTensor(rows), sos_id)
    expected = k2.RaggedTensor([[sos_id] + row for row in rows])
    assert str(result) == str(expected)
|
|
|
|
|
|
|
|
def test_add_eos():
    """``add_eos`` appends ``eos_id`` to every sublist, including empty ones."""
    eos_id = 30
    rows = [[1, 2], [3], [], [5, 8, 9]]
    result = add_eos(k2.RaggedTensor(rows), eos_id)
    expected = k2.RaggedTensor([row + [eos_id] for row in rows])
    assert str(result) == str(expected)
|
|
|