|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Useful utilities for working with bitarrays. |
|
|
""" |
|
|
from __future__ import absolute_import |
|
|
|
|
|
import os |
|
|
import sys |
|
|
|
|
|
from bitarray import bitarray, bits2bytes, get_default_endian |
|
|
|
|
|
from bitarray._util import ( |
|
|
count_n, rindex, parity, count_and, count_or, count_xor, subset, |
|
|
serialize, ba2hex, _hex2ba, ba2base, _base2ba, vl_encode, _vl_decode, |
|
|
canonical_decode, _set_bato, |
|
|
) |
|
|
|
|
|
# Public API of this module.  Several names (rindex, count_n, parity, ...)
# are re-exported from bitarray._util rather than defined here.
__all__ = [
    'zeros', 'urandom', 'pprint', 'make_endian', 'rindex', 'strip',
    'count_n', 'parity', 'count_and', 'count_or', 'count_xor', 'subset',
    'ba2hex', 'hex2ba', 'ba2base', 'base2ba', 'ba2int', 'int2ba',
    'serialize', 'deserialize', 'vl_encode', 'vl_decode',
    'huffman_code', 'canonical_huffman', 'canonical_decode',
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hand the bitarray type object to the C extension module.
# NOTE(review): presumably _util uses it to create/type-check bitarray
# objects -- confirm against the _util C source.
_set_bato(bitarray)

# True when running under Python 2 (the comparison already yields a bool).
_is_py2 = sys.version_info[0] == 2
|
|
|
|
|
|
|
|
def zeros(__length, endian=None):
    """zeros(length, /, endian=None) -> bitarray

    Return a new bitarray of `length` bits, all set to 0.  `endian` may be
    'big' or 'little'; when omitted, the default endianness is used.

    Raises `TypeError` when `length` is not an integer.
    """
    int_types = (int, long) if _is_py2 else int
    if not isinstance(__length, int_types):
        raise TypeError("int expected, got '%s'" % type(__length).__name__)

    if endian is None:
        endian = get_default_endian()
    result = bitarray(__length, endian)
    # a freshly allocated bitarray's contents are not cleared, so zero it
    result.setall(0)
    return result
|
|
|
|
|
|
|
|
def urandom(__length, endian=None):
    """urandom(length, /, endian=None) -> bitarray

    Return a bitarray of `length` random bits, taken from `os.urandom`.
    `endian` defaults to the default endianness.
    """
    if endian is None:
        endian = get_default_endian()
    result = bitarray(0, endian)
    # draw enough whole bytes to cover `length` bits ...
    nbytes = bits2bytes(__length)
    result.frombytes(os.urandom(nbytes))
    # ... then drop the surplus bits of the last byte
    del result[__length:]
    return result
|
|
|
|
|
|
|
|
def pprint(__a, stream=None, group=8, indent=4, width=80):
    """pprint(bitarray, /, stream=None, group=8, indent=4, width=80)

    Prints the formatted representation of object on `stream` (which defaults
    to `sys.stdout`).  By default, elements are grouped in bytes (8 elements),
    and 8 bytes (64 elements) per line.
    Non-bitarray objects are printed by the standard library
    function `pprint.pprint()`.

    Raises `ValueError` if `group < 1`, `indent < 0` or `width <= indent`.
    """
    if stream is None:
        stream = sys.stdout

    if not isinstance(__a, bitarray):
        import pprint as _pprint
        _pprint.pprint(__a, stream=stream, indent=indent, width=width)
        return

    group = int(group)
    if group < 1:
        raise ValueError('group must be >= 1')
    indent = int(indent)
    if indent < 0:
        raise ValueError('indent must be >= 0')
    width = int(width)
    if width <= indent:
        raise ValueError('width must be > %d (indent)' % indent)

    # elements per line: as many whole groups (each plus one separator)
    # as fit into the width left after indentation
    gpl = (width - indent) // (group + 1)
    epl = group * gpl
    if epl == 0:
        # not even one group fits on a line; fill lines element-wise.
        # Clamp to 1 so that `i % epl` below can never divide by zero
        # (width == indent + 2 used to give epl == 0).
        epl = max(width - indent - 2, 1)
    type_name = type(__a).__name__

    # estimated one-line length: name, "('", bits, separators, "')"
    multiline = len(type_name) + 4 + len(__a) + len(__a) // group >= width
    if multiline:
        quotes = "'''"
    elif __a:
        quotes = "'"
    else:
        quotes = ""

    stream.write("%s(%s" % (type_name, quotes))
    for i, b in enumerate(__a):
        if multiline and i % epl == 0:
            stream.write('\n%s' % (indent * ' '))
        elif i and i % group == 0:
            # separate groups everywhere except at the start of a line;
            # in single-line mode a multiple of `epl` is NOT a line start,
            # so the separator must still be written there
            stream.write(' ')
        stream.write(str(b))

    if multiline:
        stream.write('\n')

    stream.write("%s)\n" % quotes)
    stream.flush()
|
|
|
|
|
|
|
|
def make_endian(__a, endian):
    """make_endian(bitarray, /, endian) -> bitarray

    If `bitarray` already has endianness `endian`, return it unchanged (the
    very same object).  Otherwise return a new bitarray with endianness
    `endian` holding the same elements.

    Raises `TypeError` when the argument is not a bitarray.
    """
    if isinstance(__a, bitarray):
        if __a.endian() == endian:
            return __a
        return bitarray(__a, endian)
    raise TypeError("bitarray expected, got '%s'" % type(__a).__name__)
|
|
|
|
|
|
|
|
def strip(__a, mode='right'):
    """strip(bitarray, /, mode='right') -> bitarray

    Return a new bitarray with zeros stripped from left, right or both ends.
    Allowed values for mode are the strings: `left`, `right`, `both`

    If no 1 bit is found, an empty bitarray (`bitarray[:0]`) is returned.
    Raises `TypeError` if the first argument is not a bitarray or `mode` is
    not a str, and `ValueError` for any other `mode` string.
    """
    if not isinstance(__a, bitarray):
        raise TypeError("bitarray expected, got '%s'" % type(__a).__name__)
    if not isinstance(mode, str):
        # report the offending `mode` type (previously the bitarray's
        # type name was reported here by mistake)
        raise TypeError("str expected for mode, got '%s'" %
                        type(mode).__name__)
    if mode not in ('left', 'right', 'both'):
        raise ValueError("mode must be 'left', 'right' or 'both', got %r" %
                         mode)

    first = 0
    if mode in ('left', 'both'):
        try:
            first = __a.index(1)
        except ValueError:  # no 1 bits at all
            return __a[:0]

    last = len(__a) - 1
    if mode in ('right', 'both'):
        try:
            # rindex() presumably searches for a 1 bit by default
            last = rindex(__a)
        except ValueError:
            return __a[:0]

    return __a[first:last + 1]
|
|
|
|
|
|
|
|
def hex2ba(__s, endian=None):
    """hex2ba(hexstr, /, endian=None) -> bitarray

    Return the bitarray for a hexadecimal string.  Any number of hex digits
    (odd counts included), upper or lower case, is accepted; each digit
    contributes 4 bits.  Text strings must be ASCII-encodable.
    """
    text_type = unicode if _is_py2 else str
    if isinstance(__s, text_type):
        __s = __s.encode('ascii')
    if not isinstance(__s, bytes):
        raise TypeError("str expected, got '%s'" % type(__s).__name__)

    if endian is None:
        endian = get_default_endian()
    result = bitarray(4 * len(__s), endian)
    _hex2ba(result, __s)
    return result
|
|
|
|
|
|
|
|
def base2ba(__n, __s, endian=None):
    """base2ba(n, asciistr, /, endian=None) -> bitarray

    Return the bitarray for a base-`n` ASCII representation.
    Allowed values for `n` are 2, 4, 8, 16, 32 and 64.
    For `n=16` (hexadecimal), `hex2ba()` will be much faster, as `base2ba()`
    does not take advantage of byte level operations.
    For `n=32` the RFC 4648 Base32 alphabet is used, and for `n=64` the
    standard base 64 alphabet is used.
    """
    text_type = unicode if _is_py2 else str
    if isinstance(__s, text_type):
        __s = __s.encode('ascii')
    if not isinstance(__s, bytes):
        raise TypeError("str expected, got '%s'" % type(__s).__name__)

    # called with only `n`, _base2ba gives the bit count per character
    # (it validates `n` as well)
    bits_per_char = _base2ba(__n)
    if endian is None:
        endian = get_default_endian()
    result = bitarray(bits_per_char * len(__s), endian)
    _base2ba(__n, result, __s)
    return result
|
|
|
|
|
|
|
|
def ba2int(__a, signed=False):
    """ba2int(bitarray, /, signed=False) -> int

    Convert the given (non-empty) bitarray to an integer, respecting its
    bit-endianness.  With `signed=True`, the bits are read as a two's
    complement number of `len(bitarray)` bits.

    Raises `TypeError` for non-bitarrays and `ValueError` for empty ones.
    """
    if not isinstance(__a, bitarray):
        raise TypeError("bitarray expected, got '%s'" % type(__a).__name__)
    length = len(__a)
    if length == 0:
        raise ValueError("non-empty bitarray expected")

    endian = __a.endian()
    little = endian == 'little'
    remainder = length % 8
    if remainder:
        # pad to whole bytes on the most-significant side
        pad = zeros(8 - remainder, endian)
        __a = __a + pad if little else pad + __a

    if _is_py2:
        big = bitarray(__a, 'big')
        if little:
            big.reverse()
        value = int(ba2hex(big), 16)
    else:
        value = int.from_bytes(__a.tobytes(), byteorder=endian)

    if signed and value >= 1 << (length - 1):
        value -= 1 << length
    return value
|
|
|
|
|
|
|
|
def int2ba(__i, length=None, endian=None, signed=False):
    """int2ba(int, /, length=None, endian=None, signed=False) -> bitarray

    Convert the given integer to a bitarray (with given endianness), with
    no leading (big-endian) / trailing (little-endian) zeros unless `length`
    is given, in which case the result has exactly `length` bits.
    `OverflowError` is raised when the integer does not fit in `length`
    bits.  `signed` selects two's complement and requires `length`.
    """
    int_types = (int, long) if _is_py2 else int
    if not isinstance(__i, int_types):
        raise TypeError("int expected, got '%s'" % type(__i).__name__)
    if length is not None:
        if not isinstance(length, int):
            raise TypeError("int expected for length")
        if length <= 0:
            raise ValueError("length must be > 0")
    if signed and length is None:
        raise TypeError("signed requires length")

    if __i == 0:
        # no set bits: a single 0 bit, or `length` zero bits
        return zeros(length or 1, endian)

    if signed:
        limit = 1 << (length - 1)
        if not -limit <= __i < limit:
            raise OverflowError("signed integer not in range(%d, %d), "
                                "got %d" % (-limit, limit, __i))
        if __i < 0:
            __i += 1 << length  # two's complement representation
    elif __i < 0:
        raise OverflowError("unsigned integer not positive, got %d" % __i)
    elif length and __i >= (1 << length):
        raise OverflowError("unsigned integer not in range(0, %d), "
                            "got %d" % (1 << length, __i))

    a = bitarray(0, get_default_endian() if endian is None else endian)
    little = a.endian() == 'little'
    if _is_py2:
        digits = hex(__i)[2:].rstrip('L')
        a.extend(hex2ba(digits, 'big'))
        if little:
            a.reverse()
    else:
        nbytes = bits2bytes(__i.bit_length())
        a.frombytes(__i.to_bytes(nbytes, byteorder=a.endian()))

    if length is None:
        # drop the zero padding on the most-significant side
        return strip(a, 'right' if little else 'left')

    size = len(a)
    if size > length:
        # range was checked above, so only zero padding is cut here
        a = a[:length] if little else a[-length:]
    elif size < length:
        pad = zeros(length - size, endian)
        a = a + pad if little else pad + a
    assert len(a) == length
    return a
|
|
|
|
|
|
|
|
def deserialize(__b):
    """deserialize(bytes, /) -> bitarray

    Return a bitarray given a bytes-like representation such as returned
    by `serialize()`.

    Raises `TypeError` for int input, and `ValueError` for empty input or an
    invalid header (first) byte.
    """
    # bytes(n) would build n NUL bytes, so reject ints explicitly
    if isinstance(__b, int):
        raise TypeError("cannot convert 'int' object to bytes")
    if not isinstance(__b, bytes):
        __b = bytes(__b)
    if len(__b) == 0:
        raise ValueError("non-empty bytes expected")

    # Header byte as an int on both Python 2 (indexing gives a 1-char str)
    # and Python 3 (indexing gives an int).  Formatting `__b[0]` with %02x
    # directly used to fail with TypeError on Python 2.
    head = bytearray(__b[:1])[0]
    if _is_py2 and (head >= 32 or head % 16 >= 8):
        # NOTE(review): presumably needed because the py2 constructor would
        # otherwise interpret such a str differently -- confirm
        raise ValueError('invalid header byte: 0x%02x' % head)
    try:
        return bitarray(__b)
    except TypeError:
        raise ValueError('invalid header byte: 0x%02x' % head)
|
|
|
|
|
|
|
|
def vl_decode(__stream, endian=None):
    """vl_decode(stream, /, endian=None) -> bitarray

    Decode a binary stream (an integer iterator, or bytes-like object) and
    return the decoded bitarray.  Only one bitarray is consumed; the rest of
    the stream is left untouched.  `StopIteration` is raised when no
    terminating byte is found.
    Use `vl_encode()` for encoding.
    """
    if endian is None:
        endian = get_default_endian()
    # NOTE(review): the initial 32 bits look like a scratch buffer that
    # _vl_decode resizes to the decoded length -- confirm in _util
    result = bitarray(32, endian)
    stream = iter(__stream)
    _vl_decode(stream, result)
    return result
|
|
|
|
|
|
|
|
|
|
|
def _huffman_tree(__freq_map): |
|
|
"""_huffman_tree(dict, /) -> Node |
|
|
|
|
|
Given a dict mapping symbols to their frequency, construct a Huffman tree |
|
|
and return its root node. |
|
|
""" |
|
|
from heapq import heappush, heappop |
|
|
|
|
|
class Node(object): |
|
|
""" |
|
|
A Node object will either have .symbol (leaf node) or |
|
|
both .child_0 and .child_1 (internal node) attributes. |
|
|
The .freq attributes will always be present. |
|
|
""" |
|
|
def __lt__(self, other): |
|
|
|
|
|
return self.freq < other.freq |
|
|
|
|
|
minheap = [] |
|
|
|
|
|
for sym, f in __freq_map.items(): |
|
|
nd = Node() |
|
|
nd.symbol = sym |
|
|
nd.freq = f |
|
|
heappush(minheap, nd) |
|
|
|
|
|
|
|
|
while len(minheap) > 1: |
|
|
|
|
|
child_0 = heappop(minheap) |
|
|
child_1 = heappop(minheap) |
|
|
|
|
|
parent = Node() |
|
|
parent.child_0 = child_0 |
|
|
parent.child_1 = child_1 |
|
|
parent.freq = child_0.freq + child_1.freq |
|
|
heappush(minheap, parent) |
|
|
|
|
|
|
|
|
return minheap[0] |
|
|
|
|
|
|
|
|
def huffman_code(__freq_map, endian=None):
    """huffman_code(dict, /, endian=None) -> dict

    Given a frequency map, a dictionary mapping symbols to their frequency,
    calculate the Huffman code, i.e. a dict mapping those symbols to
    bitarrays (with given endianness).  Note that the symbols are not limited
    to being strings.  Symbols may be any hashable object (such as `None`).

    Raises `TypeError` if the argument is not a dict and `ValueError` if it
    is empty.
    """
    if not isinstance(__freq_map, dict):
        raise TypeError("dict expected, got '%s'" % type(__freq_map).__name__)
    if endian is None:
        endian = get_default_endian()

    b0 = bitarray('0', endian)
    b1 = bitarray('1', endian)

    if len(__freq_map) < 2:
        if len(__freq_map) == 0:
            raise ValueError("cannot create Huffman code with no symbols")
        # A one-symbol tree is a single leaf, which would give the symbol an
        # empty code; use the one-bit code '0' instead.
        return {list(__freq_map)[0]: b0}

    result = {}
    # Depth-first walk with an explicit stack (child_0 before child_1, the
    # same order as a recursive walk) so that very deep, skewed trees cannot
    # exceed the recursion limit.  The empty root prefix is created per call
    # rather than once at definition time.
    stack = [(_huffman_tree(__freq_map), bitarray(0, endian))]
    while stack:
        nd, prefix = stack.pop()
        try:
            result[nd.symbol] = prefix
        except AttributeError:
            # internal node: push child_1 first so child_0 is visited first
            stack.append((nd.child_1, prefix + b1))
            stack.append((nd.child_0, prefix + b0))
    return result
|
|
|
|
|
|
|
|
def canonical_huffman(__freq_map):
    """canonical_huffman(dict, /) -> tuple

    Given a frequency map, a dictionary mapping symbols to their frequency,
    calculate the canonical Huffman code.  Returns a tuple containing:

    0. the canonical Huffman code as a dict mapping symbols to bitarrays
    1. a list containing the number of symbols of each code length
    2. a list of symbols in canonical order

    Symbols of equal code length are ordered by symbol value when the symbols
    are mutually comparable; otherwise (e.g. `None` mixed with strings) they
    keep the order in which they appear in the Huffman tree.

    Note: the two lists may be used as input for `canonical_decode()`.
    """
    if not isinstance(__freq_map, dict):
        raise TypeError("dict expected, got '%s'" % type(__freq_map).__name__)

    if len(__freq_map) < 2:
        if len(__freq_map) == 0:
            raise ValueError("cannot create Huffman code with no symbols")
        # single symbol: one code of length 1
        sym = list(__freq_map)[0]
        return {sym: bitarray('0', 'big')}, [0, 1], [sym]

    # Code length of each symbol = depth of its leaf; that is all a
    # canonical code needs.  An explicit stack (child_0 first, like the
    # recursive walk) avoids the recursion limit on deep, skewed trees.
    code_length = {}
    stack = [(_huffman_tree(__freq_map), 0)]
    while stack:
        nd, length = stack.pop()
        try:
            code_length[nd.symbol] = length
        except AttributeError:
            stack.append((nd.child_1, length + 1))
            stack.append((nd.child_0, length + 1))

    items = list(code_length.items())
    try:
        # canonical order: by code length, then by symbol
        table = sorted(items, key=lambda item: (item[1], item[0]))
    except TypeError:
        # Symbols of equal length cannot be compared on Python 3 (arbitrary
        # hashables are allowed).  Sort by length only; sorted() is stable,
        # so ties keep their tree traversal order.
        table = sorted(items, key=lambda item: item[1])

    maxbits = max(item[1] for item in table)
    codedict = {}
    count = (maxbits + 1) * [0]

    code = 0
    for i, (sym, length) in enumerate(table):
        codedict[sym] = int2ba(code, length, 'big')
        count[length] += 1
        if i + 1 < len(table):
            # next canonical code: increment, then shift left by the
            # increase in code length
            code = (code + 1) << (table[i + 1][1] - length)

    return codedict, count, [item[0] for item in table]
|
|
|