Spaces:
Sleeping
Sleeping
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import re | |
| import typing as tp | |
| from collections import Counter, deque | |
| from dataclasses import dataclass | |
| from bitarray import bitarray, util | |
| from fairseq.data import Dictionary | |
# The memory-mapped dataset loader can only address whole bytes, so every
# encoded sentence must occupy a whole number of blocks. Sentences whose
# encoded length is not a multiple of BLOCKSIZE (one byte, counted in bits)
# are padded to fit (see HuffmanCoder._pad).
BLOCKSIZE = 8
class HuffmanCoder:
    """
    Encode lists of tokens to bit-padded bytes, and decode them back, using a
    Huffman tree.

    The symbol -> leaf-node table is computed once from ``root`` when the coder
    is constructed.
    """

    def __init__(
        self, root: "HuffmanNode", bos="<s>", pad="<pad>", eos="</s>", unk="<unk>"
    ):
        self.root = root
        self.table = root.code_table()
        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos

    def _pad(self, a: "bitarray") -> "bitarray":
        """
        bitpadding, 1 then 0.

        If the array is already a multiple of blocksize, we add a full block.
        """
        pad_len = BLOCKSIZE - (len(a) % BLOCKSIZE) - 1
        padding = bitarray("1" + "0" * pad_len)
        return a + padding

    def _unpad(self, a: "bitarray") -> "bitarray":
        """
        remove the bitpadding.

        There will be a set of 0s preceded by a 1 at the end of the bitarray,
        we remove that.
        """
        # index of the right-most 1; slicing up to it drops the 1 as well as
        # the 0s that follow it
        remove_cnt = util.rindex(a, 1)
        return a[:remove_cnt]

    def encode(self, iter: tp.List[str]) -> bytes:
        """
        encode a list of tokens and return bytes. We use bitpadding to make
        sure the encoded bits fit in bytes.

        Tokens missing from the code table are encoded as the unk word.

        Raises:
            Exception: if a token is unknown and there is no usable unk word
                (``unk_word`` is None or is itself missing from the table).
        """
        a = bitarray()
        for token in iter:
            code = self.get_code(token)
            if code is None:
                if self.unk_word is None:
                    raise Exception(f"unknown token {token} cannot be encoded.")
                code = self.get_code(self.unk_word)
                if code is None:
                    # previously surfaced as an opaque TypeError (bitarray + None)
                    raise Exception(
                        f"unknown token {token} cannot be encoded: "
                        f"{self.unk_word} is not in the code table."
                    )
            # extend in place instead of rebuilding the whole array per token
            a.extend(code)
        return self._pad(a).tobytes()

    def decode(self, bits: bytes) -> tp.Iterator["HuffmanNode"]:
        """
        take bitpadded bytes and decode it to a set of leaves. You can then use
        each node to find the symbol/id
        """
        a = bitarray()
        a.frombytes(bits)
        return self.root.decode(self._unpad(a))

    def get_code(self, symbol: str) -> tp.Optional["bitarray"]:
        """Return the code of ``symbol``, or None if it is not in the table."""
        node = self.get_node(symbol)
        return None if node is None else node.code

    def get_node(self, symbol: str) -> tp.Optional["HuffmanNode"]:
        """Return the leaf node of ``symbol``, or None if it is not in the table."""
        return self.table.get(symbol)

    @classmethod
    def from_file(
        cls,
        filename: str,
        bos="<s>",
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
    ) -> "HuffmanCoder":
        """Build a coder from a symbol/count file read by HuffmanCodeBuilder.from_file."""
        builder = HuffmanCodeBuilder.from_file(filename)
        return builder.build_code(bos=bos, pad=pad, eos=eos, unk=unk)

    def to_file(self, filename, sep="\t"):
        """Write one ``symbol<sep>count`` line per table entry, ordered by node id."""
        nodes = sorted(self.table.values(), key=lambda n: n.id)
        with open(filename, "w", encoding="utf-8") as output:
            for n in nodes:
                output.write(f"{n.symbol}{sep}{n.count}\n")

    def __iter__(self):
        """Iterate over the leaf nodes of the code table."""
        yield from self.table.values()

    def merge(self, other_coder: "HuffmanCoder") -> "HuffmanCoder":
        """
        Return a new coder built from the summed symbol counts of both coders.

        The new code is built with build_code's default special tokens.
        """
        builder = HuffmanCodeBuilder()
        for n in self:
            builder.increment(n.symbol, n.count)
        for n in other_coder:
            builder.increment(n.symbol, n.count)
        return builder.build_code()

    def __eq__(self, other: object) -> bool:
        """Two coders are equal when their code tables are equal."""
        if not isinstance(other, HuffmanCoder):
            # let Python try the reflected comparison / fall back to identity
            return NotImplemented
        return self.table == other.table

    def __len__(self) -> int:
        return len(self.table)

    def __contains__(self, sym: str) -> bool:
        return sym in self.table

    def to_dictionary(self) -> "Dictionary":
        """Build a finalized Dictionary holding every symbol with its count."""
        # __init__ stores the special tokens as *_word attributes
        dictionary = Dictionary(
            bos=self.bos_word,
            unk=self.unk_word,
            pad=self.pad_word,
            eos=self.eos_word,
        )
        for n in self:
            dictionary.add_symbol(n.symbol, n=n.count)
        dictionary.finalize()
        return dictionary
@dataclass
class HuffmanNode:
    """
    a node in a Huffman tree

    Leaves carry a ``symbol``; internal nodes carry ``left``/``right``
    children. ``code`` is filled in on leaves by ``code_table``.
    """

    id: int
    count: int
    symbol: tp.Optional[str] = None
    left: tp.Optional["HuffmanNode"] = None
    right: tp.Optional["HuffmanNode"] = None
    code: tp.Optional["bitarray"] = None

    def is_leaf(self) -> bool:
        """Return True when the node has no children."""
        return self.left is None and self.right is None

    def code_table(
        self, prefix: tp.Optional["bitarray"] = None
    ) -> tp.Dict[str, "HuffmanNode"]:
        """
        Assign a code to every leaf below this node and return the
        symbol -> leaf mapping.

        ``prefix`` is the path from the root to this node; a step to the right
        child appends a 0 and a step to the left child appends a 1.
        """
        defaulted_prefix = prefix if prefix is not None else bitarray()
        if self.is_leaf():
            self.code = (
                defaulted_prefix if len(defaulted_prefix) > 0 else bitarray("0")
            )  # leaf could be the root if there is only one symbol
            return {self.symbol: self}

        codes_right = self.right.code_table(defaulted_prefix + bitarray([0]))
        codes_left = self.left.code_table(defaulted_prefix + bitarray([1]))
        return {**codes_left, **codes_right}

    def decode(self, bits: "bitarray") -> tp.Iterator["HuffmanNode"]:
        """
        Walk the tree along ``bits`` and yield each leaf that is reached,
        restarting from this node after every leaf.

        Raises:
            Exception: if a bit leads to a missing child, or if the bits end
                in the middle of a code.
        """
        if self.is_leaf():
            # single-symbol tree: code_table gives the lone leaf the code "0",
            # so every 0 bit stands for one occurrence of that symbol
            for bit in bits:
                if bit != 0:
                    raise Exception("fell off a leaf")
                yield self
            return

        current_node = self
        for bit in bits:
            if bit == 0:  # go right
                current_node = current_node.right
            else:  # go left
                current_node = current_node.left
            if current_node is None:
                # we shouldn't be on a leaf here
                raise Exception("fell off a leaf")
            if current_node.is_leaf():
                yield current_node
                current_node = self
        # identity check: dataclass equality would compare whole subtrees
        if current_node is not self:
            raise Exception("couldn't decode all the bits")
class HuffmanCodeBuilder:
    """
    build a dictionary with occurrence count and then build the Huffman code for it.
    """

    def __init__(self):
        self.symbols = Counter()

    def add_symbols(self, *syms) -> None:
        """Count one occurrence of each given symbol."""
        self.symbols.update(syms)

    def increment(self, symbol: str, cnt: int) -> None:
        """Add ``cnt`` occurrences of ``symbol``."""
        self.symbols[symbol] += cnt

    @classmethod
    def from_file(cls, filename):
        """
        Read whitespace-separated ``symbol count`` lines into a new builder.

        Blank lines are skipped.
        """
        c = cls()
        with open(filename, "r", encoding="utf-8") as lines:
            for line in lines:
                if not line.strip():
                    # e.g. a trailing empty line; nothing to count
                    continue
                split = re.split(r"[\s]+", line)
                c.increment(split[0], int(split[1]))
        return c

    def to_file(self, filename, sep="\t"):
        """Write one ``symbol<sep>count`` line per symbol, most common first."""
        with open(filename, "w", encoding="utf-8") as output:
            for (tok, cnt) in self.symbols.most_common():
                output.write(f"{tok}{sep}{cnt}\n")

    def _smallest(self, q1: deque, q2: deque) -> "HuffmanNode":
        """
        Pop and return the node with the lower count among the right ends of
        the two queues; on a tie (or when ``q1`` is empty) the node comes
        from ``q2``.
        """
        if len(q1) == 0:
            return q2.pop()

        if len(q2) == 0:
            return q1.pop()

        if q1[-1].count < q2[-1].count:
            return q1.pop()

        return q2.pop()

    def __add__(self, c: "HuffmanCodeBuilder") -> "HuffmanCodeBuilder":
        """Return a new builder whose counts are the sum of both builders'."""
        new_c = self.symbols + c.symbols
        new_b = HuffmanCodeBuilder()
        new_b.symbols = new_c
        return new_b

    def build_code(
        self,
        bos="<s>",
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
    ) -> "HuffmanCoder":
        """
        Build the Huffman tree for the counted symbols and wrap it in a coder.

        Special tokens that have no count yet are added to this builder with
        a count of one before the tree is built.
        """
        assert len(self.symbols) > 0, "cannot build code from empty list of symbols"

        for special in (bos, pad, eos, unk):
            if self.symbols[special] == 0:
                self.add_symbols(special)

        node_id = 0
        leaves_queue = deque(
            [
                HuffmanNode(symbol=symbol, count=count, id=idx)
                for idx, (symbol, count) in enumerate(self.symbols.most_common())
            ]
        )  # left are the most common, right are the least common

        if len(leaves_queue) == 1:
            root = leaves_queue.pop()
            root.id = 0
            # forward the special tokens here too, as on the main path below
            return HuffmanCoder(root, bos=bos, pad=pad, eos=eos, unk=unk)

        nodes_queue = deque()

        while len(leaves_queue) > 0 or len(nodes_queue) != 1:
            # get the lowest two nodes at the head of each queue
            node1 = self._smallest(leaves_queue, nodes_queue)
            node2 = self._smallest(leaves_queue, nodes_queue)

            # add new node
            nodes_queue.appendleft(
                HuffmanNode(
                    count=node1.count + node2.count, left=node1, right=node2, id=node_id
                )
            )
            node_id += 1

        # we are left with the root
        return HuffmanCoder(nodes_queue.pop(), bos=bos, pad=pad, eos=eos, unk=unk)