enumerative-entropy-coding / enumerative_coding.py
osteele's picture
Initial commit
24c19d8 unverified
#!/usr/bin/env python3
"""
Correct implementation of enumerative entropy coding as described in Han et al. (2008).
This version is fully self-contained, embedding all necessary data into the stream.
"""
import numpy as np
from typing import List, Dict, Tuple, Optional
from collections import Counter
import math
class ExpGolombCoder:
    """Order-0 Exponential-Golomb coder for non-negative integers.

    A value ``n`` is written as ``m - 1`` zero bits followed by the ``m``-bit
    binary form of ``n + 1``. Codewords are plain strings of ``'0'``/``'1'``
    characters.
    """

    @staticmethod
    def encode(n: int) -> str:
        """Return the Exp-Golomb codeword for ``n``.

        Raises:
            ValueError: if ``n`` is negative.
        """
        if n < 0:
            raise ValueError("Exp-Golomb is for non-negative integers.")
        payload = bin(n + 1)[2:]
        # Prefix len(payload) - 1 zeros, i.e. pad to a total width of 2*len - 1.
        return payload.rjust(2 * len(payload) - 1, '0')

    @staticmethod
    def decode(bits: str, start_pos: int = 0) -> Tuple[int, int]:
        """Read one Exp-Golomb codeword from ``bits`` starting at ``start_pos``.

        Returns:
            ``(value, next_pos)``, where ``next_pos`` is the index just past
            the consumed codeword.

        Raises:
            ValueError: if the string ends before a complete codeword is read.
        """
        end = len(bits)
        cursor = start_pos
        # Skip the unary run of zeros; it tells us how long the payload is.
        while cursor < end and bits[cursor] == '0':
            cursor += 1
        if cursor >= end:
            raise ValueError("Incomplete exp-Golomb code: no '1' bit found.")
        width = cursor - start_pos + 1
        stop = cursor + width
        if stop > end:
            raise ValueError("Incomplete exp-Golomb code: not enough bits for value.")
        return int(bits[cursor:stop], 2) - 1, stop
class OptimizedBinomialTable:
    """Memoized binomial coefficients C(n, k) as exact Python integers.

    Values come from :func:`math.comb` (arbitrary precision, so no overflow)
    and are cached under the symmetric key ``(n, min(k, n - k))``, so C(n, k)
    and C(n, n - k) share one entry. Out-of-range ``k`` yields 0, and the
    trivial cases ``k == 0`` / ``k == n`` yield 1 without touching the cache.
    The cache is never evicted.
    """

    def __init__(self):
        self._cache = {}

    def get(self, n: int, k: int) -> int:
        """Return C(n, k), or 0 when ``k`` lies outside ``[0, n]``."""
        if not 0 <= k <= n:
            return 0
        # Exploit symmetry C(n, k) == C(n, n - k) to halve the key space.
        k = min(k, n - k)
        if k == 0:
            return 1
        key = (n, k)
        value = self._cache.get(key)
        if value is None:
            value = math.comb(n, k)
            self._cache[key] = value
        return value

    def __getitem__(self, n: int):
        """Return a row view so coefficients can be read as ``table[n][k]``."""
        return BinomialRow(self, n)
class BinomialRow:
    """Lightweight view of row ``n`` of a binomial table, enabling ``table[n][k]``."""

    def __init__(self, table: OptimizedBinomialTable, n: int):
        # Keep a reference to the owning table; no coefficients are copied here.
        self.table, self.n = table, n

    def __getitem__(self, k: int) -> int:
        """Return C(self.n, k) by delegating to the owning table's ``get``."""
        owner, row = self.table, self.n
        return owner.get(row, k)
class EnumerativeEncoder:
    """
    An enumerative entropy coder aligned with the algorithm described in
    "Entropy Coding Using Equiprobable Partitioning" by Han et al. (2008).

    This implementation is self-contained, writing all necessary information
    (length, alphabet, counts, and positions) into the output stream.

    Stream layout (every integer is Exp-Golomb coded; the bit string is
    MSB-first and zero-padded to a whole number of bytes):

    1. sequence length ``n``;
    2. alphabet size ``K``, then the ``K`` distinct symbols ordered from least
       to most frequent (ties keep first-occurrence order);
    3. the counts of the first ``K - 1`` symbols (the last count is implied as
       ``n`` minus their sum);
    4. for each of the first ``K - 1`` symbols: a 1-bit complement flag, then
       the combinatorial rank of that symbol's positions among the slots not
       yet claimed by earlier symbols (or of the complementary slots when the
       flag is set). The last symbol fills whatever slots remain.

    Symbols must be non-negative integers, since the Exp-Golomb coder rejects
    negative values.
    """

    def __init__(self):
        # Binomial-coefficient cache shared by _rank and _unrank.
        self.binom_table = OptimizedBinomialTable()

    def _rank(self, n: int, k: int, positions: List[int]) -> int:
        """Return the combinatorial-number-system rank of a k-subset of range(n).

        ``positions`` must be strictly increasing. The rank is
        ``sum(C(p_i, i + 1))`` and lies in ``[0, C(n, k))``. ``n`` and ``k``
        are not used in the computation; they only describe the subset.
        """
        return sum(
            self.binom_table.get(pos, i + 1) for i, pos in enumerate(positions)
        )

    def _unrank(self, n: int, k: int, index: int) -> List[int]:
        """Invert :meth:`_rank`: return the increasing k-subset of range(n) with rank ``index``.

        Positions are chosen greedily from the largest down. For slot ``i``, a
        binary search finds the largest ``p`` with ``C(p, i + 1) <= index``, and
        that term is then subtracted. ``index`` is expected to be in
        ``[0, C(n, k))``; callers validate this.
        """
        positions = []
        v_high = n - 1
        for i in range(k - 1, -1, -1):
            v_low = i
            # Binary search for the largest position p_i with C(p_i, i+1) <= index.
            # C(i, i+1) == 0, so v_low = i always satisfies the predicate.
            while v_low < v_high:
                mid = (v_low + v_high + 1) // 2
                if self.binom_table.get(mid, i + 1) <= index:
                    v_low = mid
                else:
                    v_high = mid - 1
            p_i = v_low
            positions.append(p_i)
            index -= self.binom_table.get(p_i, i + 1)
            # Positions strictly decrease as i decreases.
            v_high = p_i - 1
        positions.reverse()  # Built largest-first; return ascending.
        return positions

    def encode(self, data: List[int]) -> bytes:
        """Encode a sequence of non-negative integer symbols into bytes.

        Args:
            data: the symbol sequence. An empty sequence encodes to ``b""``.

        Returns:
            The self-describing encoded stream (see the class docstring).

        Raises:
            ValueError: if any symbol is negative (rejected by the Exp-Golomb coder).
        """
        if not data:
            return bytes()
        n = len(data)
        symbol_counts = Counter(data)
        # Encode symbols from least to most frequent, so the cheapest subsets
        # are ranked first and the most frequent symbol is implied at the end.
        sorted_symbols = sorted(symbol_counts, key=symbol_counts.__getitem__)
        K = len(sorted_symbols)

        # Collect codeword pieces and join once (avoids quadratic str +=).
        parts: List[str] = []
        # Step 1: sequence length n.
        parts.append(ExpGolombCoder.encode(n))
        # Step 2: alphabet size K and the alphabet itself.
        parts.append(ExpGolombCoder.encode(K))
        parts.extend(ExpGolombCoder.encode(symbol) for symbol in sorted_symbols)
        # Step 3: counts of the first K-1 symbols (the last one is implied).
        parts.extend(
            ExpGolombCoder.encode(symbol_counts[symbol]) for symbol in sorted_symbols[:-1]
        )

        # Step 4: positions of each of the first K-1 symbols among the still-free slots.
        available_indices = list(range(n))
        for symbol in sorted_symbols[:-1]:
            k = symbol_counts[symbol]  # always >= 1, since it comes from Counter
            current_n = len(available_indices)
            # One membership mask per symbol; both the subset and its complement
            # are derived from it in linear time.
            is_symbol = [data[idx] == symbol for idx in available_indices]
            # Rank whichever of the subset / its complement is smaller.
            use_complement = k > current_n / 2
            parts.append('1' if use_complement else '0')
            if use_complement:
                complement_positions = [j for j, hit in enumerate(is_symbol) if not hit]
                index = self._rank(current_n, current_n - k, complement_positions)
            else:
                symbol_positions = [j for j, hit in enumerate(is_symbol) if hit]
                index = self._rank(current_n, k, symbol_positions)
            parts.append(ExpGolombCoder.encode(index))
            # Slots claimed by this symbol are removed for the next one.
            available_indices = [
                idx for idx, hit in zip(available_indices, is_symbol) if not hit
            ]

        # Convert the bit string to bytes, zero-padding the final byte.
        bits = ''.join(parts)
        bits += '0' * ((-len(bits)) % 8)
        return bytes(int(bits[i:i + 8], 2) for i in range(0, len(bits), 8))

    def decode(self, encoded_bytes: bytes) -> List[int]:
        """Decode a stream produced by :meth:`encode` back into the symbol list.

        Args:
            encoded_bytes: the encoded stream. ``b""`` decodes to ``[]``.

        Returns:
            The original symbol sequence.

        Raises:
            ValueError: if the stream is truncated or internally inconsistent.
        """
        if not encoded_bytes:
            return []
        # Convert bytes to a bit string (MSB first, matching encode).
        bits = ''.join(format(byte, '08b') for byte in encoded_bytes)
        pos = 0

        # Step 1: sequence length n.
        n, pos = ExpGolombCoder.decode(bits, pos)
        # Step 2: alphabet size K and the alphabet itself.
        K, pos = ExpGolombCoder.decode(bits, pos)
        if K == 0:
            if n:
                raise ValueError("Corrupt stream: empty alphabet for a non-empty sequence.")
            return []
        sorted_symbols = []
        for _ in range(K):
            symbol, pos = ExpGolombCoder.decode(bits, pos)
            sorted_symbols.append(symbol)

        # Step 3: counts of the first K-1 symbols; the last count is implied.
        counts = {}
        decoded_count_sum = 0
        for symbol in sorted_symbols[:-1]:
            count, pos = ExpGolombCoder.decode(bits, pos)
            counts[symbol] = count
            decoded_count_sum += count
        last_symbol = sorted_symbols[-1]
        last_count = n - decoded_count_sum
        if last_count < 0:
            raise ValueError("Corrupt stream: symbol counts exceed sequence length.")
        counts[last_symbol] = last_count

        # Step 4: place each of the first K-1 symbols among the still-free slots.
        result: List[Optional[int]] = [None] * n
        available_indices = list(range(n))
        for symbol in sorted_symbols[:-1]:
            k = counts[symbol]
            if k == 0:
                continue
            current_n = len(available_indices)
            if k > current_n:
                raise ValueError("Corrupt stream: symbol count exceeds remaining slots.")
            if pos >= len(bits):
                raise ValueError("Incomplete stream: missing complement flag.")
            use_complement = bits[pos] == '1'
            pos += 1
            index, pos = ExpGolombCoder.decode(bits, pos)
            subset_size = current_n - k if use_complement else k
            # A valid rank is always < C(current_n, subset_size); anything larger
            # would make _unrank silently return a wrong subset.
            if index >= self.binom_table.get(current_n, subset_size):
                raise ValueError("Corrupt stream: position rank out of range.")
            subset = self._unrank(current_n, subset_size, index)
            if use_complement:
                excluded = set(subset)
                positions_in_available = [j for j in range(current_n) if j not in excluded]
            else:
                positions_in_available = subset
            # Map relative slots back to absolute positions, then drop them.
            chosen = set(positions_in_available)
            for rel_pos in positions_in_available:
                result[available_indices[rel_pos]] = symbol
            available_indices = [
                idx for j, idx in enumerate(available_indices) if j not in chosen
            ]

        # The last (most frequent) symbol fills every remaining slot.
        return [last_symbol if value is None else value for value in result]