| | |
| | |
| | |
| | |
| |
|
| | import io |
| | import os |
| | import string |
| | import tempfile |
| | import unittest |
| |
|
| | import torch |
| | from fairseq import tokenizer |
| | from fairseq.data import Dictionary |
| |
|
| |
|
class TestDictionary(unittest.TestCase):
    """Unit tests for fairseq.data.Dictionary (build, finalize, save/load,
    overwrite semantics, and multi-worker file counting)."""

    def test_finalize(self):
        """finalize() must reorder symbols by descending count, and the
        reordered mapping must survive a save/load round trip."""
        txt = [
            "A B C D",
            "B C D",
            "C D",
            "D",
        ]
        # Before finalize: ids follow insertion order, A=4, B=5, C=6, D=7
        # (indices 0-3 are the special symbols; the trailing 2 is presumably
        # the eos index appended by encode_line — confirmed by ref tensors).
        ref_ids1 = list(
            map(
                torch.IntTensor,
                [
                    [4, 5, 6, 7, 2],
                    [5, 6, 7, 2],
                    [6, 7, 2],
                    [7, 2],
                ],
            )
        )
        # After finalize: ids follow descending count, D=4, C=5, B=6, A=7.
        ref_ids2 = list(
            map(
                torch.IntTensor,
                [
                    [7, 6, 5, 4, 2],
                    [6, 5, 4, 2],
                    [5, 4, 2],
                    [4, 2],
                ],
            )
        )

        # Build the dictionary by encoding each line with add_if_not_exist=True.
        d = Dictionary()
        for line in txt:
            d.encode_line(line, add_if_not_exist=True)

        def get_ids(dictionary):
            # Re-encode the corpus without mutating the dictionary.
            ids = []
            for line in txt:
                ids.append(dictionary.encode_line(line, add_if_not_exist=False))
            return ids

        def assertMatch(ids, ref_ids):
            # Element-wise tensor equality: same shape, zero mismatches.
            for toks, ref_toks in zip(ids, ref_ids):
                self.assertEqual(toks.size(), ref_toks.size())
                self.assertEqual(0, (toks != ref_toks).sum().item())

        ids = get_ids(d)
        assertMatch(ids, ref_ids1)

        # Check the finalized (count-sorted) dictionary.
        d.finalize()
        finalized_ids = get_ids(d)
        assertMatch(finalized_ids, ref_ids2)

        # Write to disk, reload, and check the mapping is unchanged.
        with tempfile.NamedTemporaryFile(mode="w") as tmp_dict:
            d.save(tmp_dict.name)
            d = Dictionary.load(tmp_dict.name)
            reload_ids = get_ids(d)
            assertMatch(reload_ids, ref_ids2)
            assertMatch(finalized_ids, reload_ids)

    def test_overwrite(self):
        """Entries tagged '#fairseq:overwrite' may duplicate existing symbols
        (e.g. models that redefine <unk>, <s> and </s>) without raising."""
        dict_file = io.StringIO(
            "<unk> 999 #fairseq:overwrite\n"
            "<s> 999 #fairseq:overwrite\n"
            "</s> 999 #fairseq:overwrite\n"
            ", 999\n"
            "▁de 999\n"
        )
        d = Dictionary()
        d.add_from_file(dict_file)
        self.assertEqual(d.index("<pad>"), 1)
        # An out-of-vocabulary word maps to the unk index (3).
        self.assertEqual(d.index("foo"), 3)
        # The overwriting entries are appended after the special symbols.
        self.assertEqual(d.index("<unk>"), 4)
        self.assertEqual(d.index("<s>"), 5)
        self.assertEqual(d.index("</s>"), 6)
        self.assertEqual(d.index(","), 7)
        self.assertEqual(d.index("▁de"), 8)

    def test_no_overwrite(self):
        """Duplicating special symbols WITHOUT the overwrite tag must raise a
        RuntimeError mentioning 'Duplicate'."""
        dict_file = io.StringIO(
            "<unk> 999\n" "<s> 999\n" "</s> 999\n" ", 999\n" "▁de 999\n"
        )
        d = Dictionary()
        with self.assertRaisesRegex(RuntimeError, "Duplicate"):
            d.add_from_file(dict_file)

    def test_space(self):
        """A bare space character is a valid dictionary symbol."""
        dict_file = io.StringIO(" 999\n" "a 999\n" "b 999\n")
        d = Dictionary()
        d.add_from_file(dict_file)
        self.assertEqual(d.index(" "), 4)
        self.assertEqual(d.index("a"), 5)
        self.assertEqual(d.index("b"), 6)

    def test_add_file_to_dict(self):
        """add_file_to_dictionary with multiple workers (10) must produce the
        exact per-symbol counts of a dummy corpus."""
        counts = {}
        num_lines = 100
        per_line = 10
        with tempfile.TemporaryDirectory("test_sampling") as data_dir:
            filename = os.path.join(data_dir, "dummy.txt")
            with open(filename, "w", encoding="utf-8") as data:
                # Each letter c appears per_line times on each of num_lines
                # lines; per_line grows by 5 per letter so counts differ.
                for c in string.ascii_letters:
                    line = f"{c} " * per_line
                    for _ in range(num_lines):
                        data.write(f"{line}\n")
                    counts[c] = per_line * num_lines
                    per_line += 5

            # Renamed from `dict` to `d` to avoid shadowing the builtin.
            d = Dictionary()
            Dictionary.add_file_to_dictionary(
                filename, d, tokenizer.tokenize_line, 10
            )
            d.finalize(threshold=0, nwords=-1, padding_factor=8)

            for c in string.ascii_letters:
                count = d.get_count(d.index(c))
                self.assertEqual(
                    counts[c], count, f"{c} count is {count} but should be {counts[c]}"
                )
| |
|
| |
|
# Allow running this test module directly: `python test_dictionary.py`.
if __name__ == "__main__":
    unittest.main()
| |
|