| | |
| | |
| | |
| | |
| |
|
| |
|
| | import os |
| | import typing as tp |
| | import unittest |
| | from tempfile import TemporaryDirectory |
| |
|
| | from fairseq.binarizer import BinarizeSummary, FileBinarizer, VocabularyDatasetBinarizer |
| | from fairseq.data import Dictionary, indexed_dataset |
| | from tests.utils import make_data, sizes |
| |
|
| |
|
def build_vocab(data: tp.List[tp.List[str]]) -> Dictionary:
    """Build and finalize a :class:`Dictionary` over every token in *data*.

    Args:
        data: tokenized sentences (a list of token lists).

    Returns:
        A finalized ``Dictionary`` containing all distinct tokens.
    """
    vocab = Dictionary()
    # flatten the nested sentence structure and register each token
    for token in (tok for sentence in data for tok in sentence):
        vocab.add_symbol(token)
    vocab.finalize()
    return vocab
| |
|
| |
|
class TestBinarizer(unittest.TestCase):
    """Round-trip tests for VocabularyDatasetBinarizer / FileBinarizer."""

    def compare_ds_data(self, summary, data, prefix, impl, vocab):
        """Assert the binarized dataset at *prefix* matches the raw *data*.

        Checks the ``BinarizeSummary`` counters first, then reloads the
        dataset via ``indexed_dataset.make_dataset`` and verifies that the
        decoded sentences and the stored per-sequence sizes agree with the
        original token lists.
        """
        self.assertEqual(summary.num_seq, len(data))
        # generator expression: no need to materialize an intermediate list
        self.assertEqual(summary.num_tok, sum(len(s) for s in data))

        dataset = indexed_dataset.make_dataset(prefix, impl)

        self.assertEqual(len(dataset), len(data))
        decoded = [vocab.string(dataset[i]).split() for i in range(len(dataset))]

        self.assertEqual(decoded, data)
        data_sizes = [i.item() for i in dataset.sizes]
        self.assertEqual(data_sizes, sizes(data))

    def test_can_binarize_line(self):
        """A single sentence binarizes to ``len(sentence) + 1`` tokens."""
        data = make_data(length=1)
        vocab = build_vocab(data)

        binarizer = VocabularyDatasetBinarizer(
            vocab,
        )

        sentence = data[0]
        summary = BinarizeSummary()

        tensor = binarizer.binarize_line(
            " ".join(sentence),
            summary,
        )

        # +1 presumably accounts for an EOS token appended by default —
        # consistent with the other tests passing append_eos=False explicitly.
        self.assertEqual(len(tensor), len(sentence) + 1)

        self.assertEqual(summary.num_tok, len(sentence) + 1)
        self.assertEqual(summary.num_seq, 1)

    def test_can_binarize_file_chunk(self):
        """Binarizing a whole file as one chunk reproduces the raw data."""
        with TemporaryDirectory() as dirname:
            raw_file = os.path.join(dirname, "raw1")
            prefix = os.path.join(dirname, "test1")
            impl = "mmap"

            data = make_data(out_file=raw_file)
            vocab = build_vocab(data)

            binarizer = VocabularyDatasetBinarizer(
                vocab,
                append_eos=False,
            )

            # offset_end=-1: process from offset_start to the end of the file
            summary = FileBinarizer._binarize_chunk_and_finalize(
                binarizer,
                raw_file,
                offset_start=0,
                offset_end=-1,
                output_prefix=prefix,
                dataset_impl=impl,
                vocab_size=len(vocab),
            )

            self.compare_ds_data(summary, data, prefix, impl, vocab)

    def test_can_multiprocess(self):
        """multiprocess_dataset produces correct output with 1 and 3 workers."""
        with TemporaryDirectory() as dirname:
            raw_file = os.path.join(dirname, "raw1")
            prefix = os.path.join(dirname, "test1")
            impl = "mmap"
            data = make_data(out_file=raw_file)
            vocab = build_vocab(data)
            binarizer = VocabularyDatasetBinarizer(
                vocab,
                append_eos=False,
            )
            # single-worker baseline
            summary = FileBinarizer.multiprocess_dataset(
                raw_file,
                impl,
                binarizer,
                output_prefix=prefix,
                vocab_size=len(vocab),
                num_workers=1,
            )

            self.compare_ds_data(summary, data, prefix, impl, vocab)

            # same input, several workers: output must be identical
            prefix_multi = os.path.join(dirname, "test2")
            summary = FileBinarizer.multiprocess_dataset(
                raw_file,
                impl,
                binarizer,
                output_prefix=prefix_multi,
                vocab_size=len(vocab),
                num_workers=3,
            )

            self.compare_ds_data(summary, data, prefix_multi, impl, vocab)
| |
|