Spaces:
Running
Running
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import os | |
| import random | |
| import string | |
| import typing as tp | |
| import unittest | |
| from collections import Counter | |
| from tempfile import NamedTemporaryFile, TemporaryDirectory | |
| from fairseq.data import Dictionary, indexed_dataset | |
| from fairseq.data.huffman import ( | |
| HuffmanCodeBuilder, | |
| HuffmanCoder, | |
| HuffmanMMapIndexedDataset, | |
| HuffmanMMapIndexedDatasetBuilder, | |
| ) | |
# Symbol population used to generate random test sentences:
# all ASCII letters followed by all decimal digits.
POPULATION = string.ascii_letters + string.digits
def make_sentence() -> tp.List[str]:
    """Sample a random sentence of 10-50 symbols drawn from POPULATION.

    Later symbols in POPULATION get linearly increasing weights, so the
    symbol frequencies are skewed rather than uniform.
    """
    n_symbols = random.randint(10, 50)
    symbol_weights = range(1, len(POPULATION) + 1)
    return random.choices(POPULATION, symbol_weights, k=n_symbols)
def make_data(length=1000) -> tp.List[tp.List[str]]:
    """Generate `length` random sentences plus two coverage sentences.

    The two extra sentences contain every letter and every digit, so the
    full symbol population is guaranteed to appear at least once.
    """
    sentences = [make_sentence() for _ in range(length)]
    # Guarantee every symbol occurs, so a Huffman code exists for all of them.
    sentences.append(list(string.ascii_letters))
    sentences.append(list(string.digits))
    return sentences
def make_counts(data: tp.List[tp.List[str]]) -> Counter:
    """Count total occurrences of each symbol across all sentences."""
    counts: Counter = Counter()
    for sentence in data:
        counts.update(sentence)
    return counts
def make_code_builder(data: tp.List[tp.List[str]]) -> HuffmanCodeBuilder:
    """Return a HuffmanCodeBuilder fed with every symbol of every sentence."""
    code_builder = HuffmanCodeBuilder()
    for symbols in data:
        code_builder.add_symbols(*symbols)
    return code_builder
class TestCodeBuilder(unittest.TestCase):
    """Tests for HuffmanCodeBuilder: counting, merging and file round trips."""

    def test_code_builder_can_count(self):
        """The builder's symbol table equals a plain Counter of the data."""
        data = make_data()
        expected_counts = make_counts(data)
        builder = make_code_builder(data)
        self.assertEqual(builder.symbols, expected_counts)

    def test_code_builder_can_add(self):
        """Adding two builders sums their symbol counts."""
        data = make_data()
        expected_counts = make_counts(data)
        builder = make_code_builder(data)
        merged = builder + builder
        self.assertEqual(merged.symbols, expected_counts + expected_counts)

    def test_code_builder_can_io(self):
        """Writing a builder to disk and reading it back preserves its counts."""
        builder = make_code_builder(make_data())
        with NamedTemporaryFile() as tmp_fp:
            builder.to_file(tmp_fp.name)
            reloaded = HuffmanCodeBuilder.from_file(tmp_fp.name)
            self.assertEqual(builder.symbols, reloaded.symbols)
class TestCoder(unittest.TestCase):
    """Tests for HuffmanCoder: serialization and encode/decode round trips."""

    def test_coder_can_io(self):
        """Writing a coder to disk and reading it back yields an equal coder."""
        coder = make_code_builder(make_data()).build_code()
        with NamedTemporaryFile() as tmp_fp:
            coder.to_file(tmp_fp.name)
            reloaded = HuffmanCoder.from_file(tmp_fp.name)
            self.assertEqual(coder, reloaded)

    def test_coder_can_encode_decode(self):
        """decode(encode(x)) == x for both training data and fresh data."""
        data = make_data()
        coder = make_code_builder(data).build_code()

        def round_trip(sentences):
            # Encode each sentence, then decode and extract node symbols.
            encoded = [coder.encode(sentence) for sentence in sentences]
            return [[node.symbol for node in coder.decode(enc)] for enc in encoded]

        self.assertEqual(round_trip(data), data)
        # Fresh data draws from the same POPULATION, so the learned code
        # still covers every symbol it contains.
        unseen = make_data()
        self.assertEqual(round_trip(unseen), unseen)
def build_dataset(prefix, data, coder):
    """Write `data` as a Huffman-compressed indexed dataset rooted at `prefix`."""
    with HuffmanMMapIndexedDatasetBuilder(prefix, coder) as dataset_builder:
        for sentence in data:
            dataset_builder.add_item(sentence)
def sizes(data):
    """Return the length of every sentence in `data`, in order."""
    return list(map(len, data))
class TestHuffmanDataset(unittest.TestCase):
    """End-to-end tests for HuffmanMMapIndexedDataset: build, read, append."""

    def test_huffman_can_encode_decode(self):
        """A built dataset round-trips its sentences and reports their sizes."""
        data = make_data()
        coder = make_code_builder(data).build_code()
        with TemporaryDirectory() as dirname:
            prefix = os.path.join(dirname, "test1")
            build_dataset(prefix, data, coder)
            dataset = HuffmanMMapIndexedDataset(prefix)

            self.assertEqual(len(dataset), len(data))
            decoded = [list(dataset.get_symbols(i)) for i in range(len(dataset))]
            self.assertEqual(decoded, data)
            # dataset.sizes holds scalar items; unwrap with .item() to compare.
            self.assertEqual([s.item() for s in dataset.sizes], sizes(data))

    def test_huffman_compresses(self):
        """The Huffman data file is smaller than a plain mmap dataset of the same data."""
        data = make_data()
        coder = make_code_builder(data).build_code()
        with TemporaryDirectory() as dirname:
            prefix = os.path.join(dirname, "huffman")
            build_dataset(prefix, data, coder)

            # Build an uncompressed "mmap" dataset over the same sentences
            # to serve as the size baseline.
            prefix_mmap = os.path.join(dirname, "mmap")
            mmap_builder = indexed_dataset.make_builder(
                indexed_dataset.data_file_path(prefix_mmap),
                "mmap",
                vocab_size=len(POPULATION),
            )
            dictionary = Dictionary()
            for symbol in POPULATION:
                dictionary.add_symbol(symbol)
            dictionary.finalize()
            for sentence in data:
                mmap_builder.add_item(dictionary.encode_line(" ".join(sentence)))
            mmap_builder.finalize(indexed_dataset.index_file_path(prefix_mmap))

            huff_size = os.stat(indexed_dataset.data_file_path(prefix)).st_size
            mmap_size = os.stat(indexed_dataset.data_file_path(prefix_mmap)).st_size
            self.assertLess(huff_size, mmap_size)

    def test_huffman_can_append(self):
        """Appending two datasets concatenates their sentences and sizes."""
        data1 = make_data()
        coder = make_code_builder(data1).build_code()
        with TemporaryDirectory() as dirname:
            prefix1 = os.path.join(dirname, "test1")
            build_dataset(prefix1, data1, coder)

            # NOTE(review): data2 is encoded with the coder trained on data1;
            # make_data guarantees full symbol coverage, so every symbol of
            # data2 has a code.
            data2 = make_data()
            prefix2 = os.path.join(dirname, "test2")
            build_dataset(prefix2, data2, coder)

            prefix3 = os.path.join(dirname, "test3")
            with HuffmanMMapIndexedDatasetBuilder(prefix3, coder) as appender:
                appender.append(prefix1)
                appender.append(prefix2)

            dataset = HuffmanMMapIndexedDataset(prefix3)
            self.assertEqual(len(dataset), len(data1) + len(data2))

            split = len(data1)
            decoded = [list(dataset.get_symbols(i)) for i in range(len(dataset))]
            self.assertEqual(decoded[:split], data1)
            self.assertEqual(decoded[split:], data2)

            all_sizes = [s.item() for s in dataset.sizes]
            self.assertEqual(all_sizes[:split], sizes(data1))
            self.assertEqual(all_sizes[split:], sizes(data2))
| if __name__ == "__main__": | |
| unittest.main() | |