| import unittest |
| from onmt.translate import GeneratorLM |
| import torch |
|
|
|
|
class TestGeneratorLM(unittest.TestCase):
    """Tests for ``GeneratorLM.split_src_to_prevent_padding``.

    Sources are built batch-first as ``(batch, seq_len, 1)`` integer tensors,
    matching the shapes asserted on the split results below.
    """

    # Shared fixture dimensions so every tensor in a test agrees on them.
    BATCH_SIZE = 6
    MAX_LEN = 5

    def test_split_src_to_prevent_padding_target_prefix_is_none_when_equal_size(
        self,
    ):
        """When every sequence has full length, nothing is split off."""
        src = torch.randint(0, 10, (self.BATCH_SIZE, self.MAX_LEN, 1))
        # One length per batch entry, so src_len matches src's batch dimension.
        src_len = self.MAX_LEN * torch.ones(self.BATCH_SIZE, dtype=torch.int)
        (
            src,
            src_len,
            target_prefix,
        ) = GeneratorLM.split_src_to_prevent_padding(src, src_len)
        self.assertIsNone(target_prefix)
        # With no split, the source and its lengths should come back as-is.
        self.assertTupleEqual(src.shape, (self.BATCH_SIZE, self.MAX_LEN, 1))
        self.assertTrue(
            src_len.equal(
                self.MAX_LEN * torch.ones(self.BATCH_SIZE, dtype=torch.int)
            )
        )

    def test_split_src_to_prevent_padding_target_prefix_is_ok_when_different_size(
        self,
    ):
        """A shorter sequence trims src; the remaining steps become the prefix."""
        src = torch.randint(0, 10, (self.BATCH_SIZE, self.MAX_LEN, 1))
        # Keep an independent copy: `src` is rebound to the split result below.
        original_src = src.clone()
        src_len = self.MAX_LEN * torch.ones(self.BATCH_SIZE, dtype=torch.int)
        new_length = 4
        src_len[1] = new_length
        (
            src,
            src_len,
            target_prefix,
        ) = GeneratorLM.split_src_to_prevent_padding(src, src_len)
        self.assertTupleEqual(src.shape, (self.BATCH_SIZE, new_length, 1))
        self.assertTupleEqual(
            target_prefix.shape,
            (self.BATCH_SIZE, self.MAX_LEN - new_length, 1),
        )
        self.assertTrue(
            src_len.equal(new_length * torch.ones(self.BATCH_SIZE, dtype=torch.int))
        )
        # Re-joining both pieces along the time axis must recover the input.
        self.assertTrue(torch.cat([src, target_prefix], dim=1).equal(original_src))
|
|