Spaces:
Running on Zero
Running on Zero
File size: 7,183 Bytes
5b8133e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 | """
Arithmetic coder for neural text compression.
Uses high-precision integer arithmetic (32-bit range) with proper
renormalization and underflow handling. The encoder and decoder are
perfectly symmetric — given the same sequence of CDFs, the decoder
recovers the exact symbol sequence the encoder consumed.
"""
class ArithmeticEncoder:
"""Encodes symbols into a compressed bitstream using arithmetic coding.
Bits are packed into a bytearray on the fly instead of stored as
individual Python ints, cutting memory from O(n_bits * 28 bytes) to
O(n_bits / 8).
"""
PRECISION = 32
FULL = 1 << PRECISION # 2^32
HALF = 1 << (PRECISION - 1) # 2^31
QUARTER = 1 << (PRECISION - 2) # 2^30
MAX_RANGE = FULL - 1 # 0xFFFFFFFF
def __init__(self):
self.low = 0
self.high = self.MAX_RANGE
self.pending_bits = 0
self._buf = bytearray()
self._cur_byte = 0
self._bits_in_cur = 0
self._total_bits = 0
def _write_bit(self, bit: int):
"""Pack a single bit into the output bytearray."""
self._cur_byte = (self._cur_byte << 1) | bit
self._bits_in_cur += 1
self._total_bits += 1
if self._bits_in_cur == 8:
self._buf.append(self._cur_byte)
self._cur_byte = 0
self._bits_in_cur = 0
def _output_bit(self, bit: int):
self._write_bit(bit)
# Flush pending bits (opposite of the bit just emitted)
for _ in range(self.pending_bits):
self._write_bit(1 - bit)
self.pending_bits = 0
def encode_symbol(self, cdf, symbol_index: int):
"""Encode a single symbol given its CDF.
Args:
cdf: Cumulative distribution function. Supports both list[int]
and torch.Tensor (indexed with []). Length = num_symbols + 1.
cdf[0] = 0, cdf[-1] = total.
symbol_index: Index of the symbol to encode (0-based).
"""
total = int(cdf[-1])
rng = self.high - self.low + 1
sym_lo = int(cdf[symbol_index])
sym_hi = int(cdf[symbol_index + 1])
# Narrow the interval
self.high = self.low + (rng * sym_hi) // total - 1
self.low = self.low + (rng * sym_lo) // total
# Renormalize
while True:
if self.high < self.HALF:
# Both in lower half — output 0
self._output_bit(0)
self.low = self.low << 1
self.high = (self.high << 1) | 1
elif self.low >= self.HALF:
# Both in upper half — output 1
self._output_bit(1)
self.low = (self.low - self.HALF) << 1
self.high = ((self.high - self.HALF) << 1) | 1
elif self.low >= self.QUARTER and self.high < 3 * self.QUARTER:
# Underflow / near-convergence
self.pending_bits += 1
self.low = (self.low - self.QUARTER) << 1
self.high = ((self.high - self.QUARTER) << 1) | 1
else:
break
# Keep values in range
self.low &= self.MAX_RANGE
self.high &= self.MAX_RANGE
def finish(self) -> bytes:
"""Finalize encoding and return compressed data as bytes."""
# Flush remaining state
self.pending_bits += 1
if self.low < self.QUARTER:
self._output_bit(0)
else:
self._output_bit(1)
# Pad to byte boundary
while self._bits_in_cur != 0:
self._write_bit(0)
return bytes(self._buf)
def get_bit_count(self) -> int:
"""Return number of bits written so far (approximate)."""
return self._total_bits + self.pending_bits
class ArithmeticDecoder:
"""Decodes symbols from a compressed bitstream using arithmetic coding.
Reads bits lazily from the compressed bytes instead of expanding
every byte into 8 Python ints upfront.
"""
PRECISION = 32
FULL = 1 << PRECISION
HALF = 1 << (PRECISION - 1)
QUARTER = 1 << (PRECISION - 2)
MAX_RANGE = FULL - 1
def __init__(self, data: bytes):
self._data = data
self._byte_pos = 0
self._bit_buf = 0
self._bits_left = 0
self.low = 0
self.high = self.MAX_RANGE
# Read initial value
self.value = 0
for _ in range(self.PRECISION):
self.value = (self.value << 1) | self._read_bit()
def _read_bit(self) -> int:
if self._bits_left == 0:
if self._byte_pos < len(self._data):
self._bit_buf = self._data[self._byte_pos]
self._byte_pos += 1
self._bits_left = 8
else:
return 0 # Implicit trailing zeros
self._bits_left -= 1
return (self._bit_buf >> self._bits_left) & 1
def decode_symbol(self, cdf) -> int:
"""Decode a single symbol given its CDF.
Args:
cdf: Same CDF format as encoder. Supports both list[int] and
torch.Tensor. Length = num_symbols + 1, cdf[0] = 0,
cdf[-1] = total.
Returns:
The symbol index (0-based).
"""
total = int(cdf[-1])
rng = self.high - self.low + 1
# Find which symbol the current value falls into
scaled_value = ((self.value - self.low + 1) * total - 1) // rng
# Binary search for the symbol
num_symbols = len(cdf) - 1
lo, hi = 0, num_symbols - 1
while lo <= hi:
mid = (lo + hi) // 2
if int(cdf[mid + 1]) <= scaled_value:
lo = mid + 1
else:
hi = mid - 1
symbol = lo
sym_lo = int(cdf[symbol])
sym_hi = int(cdf[symbol + 1])
# Update range (must match encoder exactly)
self.high = self.low + (rng * sym_hi) // total - 1
self.low = self.low + (rng * sym_lo) // total
# Renormalize (must match encoder exactly)
while True:
if self.high < self.HALF:
self.low = self.low << 1
self.high = (self.high << 1) | 1
self.value = (self.value << 1) | self._read_bit()
elif self.low >= self.HALF:
self.low = (self.low - self.HALF) << 1
self.high = ((self.high - self.HALF) << 1) | 1
self.value = ((self.value - self.HALF) << 1) | self._read_bit()
elif self.low >= self.QUARTER and self.high < 3 * self.QUARTER:
self.low = (self.low - self.QUARTER) << 1
self.high = ((self.high - self.QUARTER) << 1) | 1
self.value = ((self.value - self.QUARTER) << 1) | self._read_bit()
else:
break
self.low &= self.MAX_RANGE
self.high &= self.MAX_RANGE
self.value &= self.MAX_RANGE
return symbol
|