File size: 13,137 Bytes
72c0672 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 | # Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements an Arithmetic Encoder and Decoder."""
from typing import Any, Callable
import numpy as np
InputFn = Callable[[], int]
OutputFn = Callable[[int], None]
IOFn = InputFn | OutputFn
def _log_power_of_b(n: int, base: int) -> int:
"""Returns k assuming n = base ** k.
We manually implement this function to be faster than a np.log or math.log,
which doesn't assume n is an integer.
Args:
n: The integer of which we want the log.
base: The base of the log.
"""
log_n = 0
while n > 1:
n //= base
log_n += 1
return log_n
def _raise_post_terminate_exception(*args: Any, **kwargs: Any) -> None:
"""Dummy function that raises an error to ensure AC termination."""
del args, kwargs
raise ValueError(
"Arithmetic encoder was terminated. "
"Create a new instance for encoding more data. "
"Do NOT use an output function that writes to the same data sink "
"used by the output function of this instance. "
"This will corrupt the arithmetic code as decoding relies on detecting "
"when the compressed data stream is exhausted."
)
class _CoderBase:
"""Arithmetic coder (AC) base class."""
def __init__(self, base: int, precision: int, io_fn: IOFn):
"""Does initialization shared by AC encoder and decoder.
Args:
base: The arithmetic coder will output digits in {0, 1, ..., base - 1}.
precision: Precision for internal state; on the average this will waste
code space worth at most 1/log(base) * base ** -(precision - 2) digits
of output per coding step.
io_fn: Function to write digits to compressed stream/read digits from
compressed stream.
"""
assert base >= 2
assert precision >= 2
self._base: int = base
self._base_to_pm1: int = int(base ** (precision - 1))
self._base_to_pm2: int = int(base ** (precision - 2))
self._io_fn = io_fn
# NOTE: We represent the AC interval [0, 1) as rational numbers:
# [0, 1)
# ~ [self._low / base ** precision, (self._high + 1) / base ** precision)
# = [self._low / base ** precision, self._high / base ** precision],
# where the we represent the upper bound *INCLUSIVE*. This is a subtle
# detail required to make the integer arithmetic work correctly given that
# all involved integers have `precision` digits in base `base`.
self._low: int = 0
self._high: int = int(base**precision) - 1
self._num_carry_digits: int = 0
self._code: int = 0
def __str__(self) -> str:
"""Returns string describing internal state."""
if self._base > 16:
raise ValueError("`__str__` with `base` exceeding 16 not implmeneted.")
p = 1 + _log_power_of_b(self._base_to_pm1, base=self._base)
def _to_str(x: int) -> str:
"""Returns representation of `n` in base `self._base`."""
digits = [(x // self._base**i) % self._base for i in range(p)]
return f"{digits[-1]:x}<C:{self._num_carry_digits:d}>" + "".join(
f"{d:x}" for d in digits[-2::-1]
)
return (
f"[{_to_str(self._low)}, {_to_str(self._high)}) {_to_str(self._code)}"
)
def _get_intervals(self, pdf: np.ndarray) -> np.ndarray:
"""Partition the current AC interval according to the distribution `pdf`."""
if (pdf < 0).any():
raise ValueError(
"Some probabilities are negative. Please make sure that pdf[x] > 0."
)
# Compute CPDF s.t. cpdf[x] = sum_y<x Pr[y] for 0 <= x < alphabet_size and
# add a sentinel s.t. cpdf[alphabet_size] = 1.0. This partitions [0, 1) into
# non-overlapping intervals, the interval [cpdf[x], cpdf[x + 1]) represents
# symbol x. Since AC relies on an integer representation we rescale this
# into the current AC range `high - low + 1` and quantise, this yields the
# quantised CPDF `qcpdf`.
width = self._high - self._low + 1
qcpdf = (np.insert(pdf, 0, 0).cumsum() * width).astype(int)
if (qcpdf[1:] == qcpdf[:-1]).any():
raise ValueError(
"Some probabilities are 0 after quantisation. Please make sure that:"
" pdf[x] >= max(base ** -(precision - 2), np.dtype(x).eps) for any"
" symbol by either preprocessing `pdf` or by increasing `precision`."
)
if qcpdf[-1] > width:
raise ValueError(
"Cumulative sum of probabilities exceeds 1 after quantisation. "
"Please make sure that sum(pdf) <= 1.0 - eps, for a small eps > 0."
)
return self._low + qcpdf
def _remove_matching_digits(self, low_pre_split: int, encoding: bool) -> None:
"""Remove matching most significant digits from AC state [low, high).
This is the *FIRST* normalization step after encoding a symbol into the AC
state.
When encoding we write the most significant matching digits of the
integer representation of [low, high) to the output, widen the integer
representation of [low, high) including a (potential) queue of carry digits;
when decoding we drop the matching most significant digits of the integer
representation of [low, high), widen this interval and keep the current
slice of the arithmetic code word `self._code` in sync.
Args:
low_pre_split: Value of `self._low` before encoding a new symbol into the
AC state when `encoding` is True; abitrary, otherwise.
encoding: Are we encoding (i.e. normalise by writing data) or decoding
(i.e. normalise by reading data)?
"""
def _shift_left(x: int) -> int:
"""Shift `x` one digit left."""
return (x % self._base_to_pm1) * self._base
while self._low // self._base_to_pm1 == self._high // self._base_to_pm1:
if encoding:
low_msd = self._low // self._base_to_pm1
self._io_fn(low_msd)
# Note that carry digits will only be written in the first round of this
# loop.
carry_digit = (
self._base - 1 + low_msd - low_pre_split // self._base_to_pm1
) % self._base
assert carry_digit in {0, self._base - 1} or self._num_carry_digits == 0
while self._num_carry_digits > 0:
self._io_fn(carry_digit)
self._num_carry_digits -= 1
else:
self._code = _shift_left(self._code) + self._io_fn()
self._low = _shift_left(self._low)
self._high = _shift_left(self._high) + self._base - 1
def _remove_carry_digits(self, encoding: bool) -> None:
"""Remove and record 2nd most significant digits from AC state [low, high).
This is the *SECOND* normalization step after encoding a symbol into the AC
state [low, high).
If the AC state takes the form
low = x B-1 B-1 ... B-1 u ...
high = x+1 0 0 0 v ...
^__ prefix __^
where x, u and v are base-B digits then low and high can get arbitrarily (
well, by means of infinite precision arithmetics) without matching. Since we
work with finite precision arithmetics, we must make sure that this doesn't
occour and we guarantee sufficient of coding range (`high - low`). To end
this we detect the above situation and cut off the highlighted prefix above
to widen the integer representation of [low, high) and record the number of
prefix digits removed. When decoding we must similarly process the current
slice of the arithmetic code word `self._code` to keep it in sync.
Args:
encoding: Are we encoding (i.e. normalise by writing data) or decoding
(i.e. normalise by reading data)?
"""
def _shift_left_keeping_msd(x: int) -> int:
"""Shift `x` except MSD, which remains in place, one digit left."""
return x - (x % self._base_to_pm1) + (x % self._base_to_pm2) * self._base
while self._low // self._base_to_pm2 + 1 == self._high // self._base_to_pm2:
if encoding:
self._num_carry_digits += 1
else:
self._code = _shift_left_keeping_msd(self._code) + self._io_fn()
self._low = _shift_left_keeping_msd(self._low)
self._high = _shift_left_keeping_msd(self._high) + self._base - 1
def _process(self, pdf: np.ndarray, symbol: int | None) -> int:
"""Perform an AC encoding or decoding step and modify AC state in-place.
Args:
pdf: Probability distribution over input alphabet.
symbol: Letter to encode from {0, 1, ..., pdf.size - 1} when encoding or
`None` when decoding.
Returns:
y: `symbol` from above when encoding or decoded letter from {0, 1, ...,
pdf.size - 1}.
"""
encoding = symbol is not None
intervals = self._get_intervals(pdf)
if not encoding:
symbol = np.searchsorted(intervals, self._code, side="right") - 1
assert 0 <= symbol < pdf.size
low_pre_split = self._low
self._low, self._high = intervals[[symbol, symbol + 1]]
# Due to integer arithmetics the integer representation of [low, high) has
# an inclusive upper bound, so decrease high.
self._high -= 1
assert 0 <= self._low <= self._high < self._base_to_pm1 * self._base
# Normalize the AC state.
self._remove_matching_digits(low_pre_split=low_pre_split, encoding=encoding)
assert 0 <= self._low <= self._high < self._base_to_pm1 * self._base
assert encoding or self._low <= self._code <= self._high
assert self._low // self._base_to_pm1 != self._high // self._base_to_pm1
self._remove_carry_digits(encoding=encoding)
assert 0 <= self._low <= self._high < self._base_to_pm1 * self._base
assert encoding or self._low <= self._code <= self._high
assert self._high - self._low > self._base_to_pm2
return symbol
@classmethod
def p_min(cls, base: int, precision: int) -> float:
"""Get minimum probability supported by AC config."""
# The leading factor 2 is supposed to account for rounding errors and
# wouldn't be necessary given infinite float precision.
return 2.0 * base ** -(precision - 2)
class Encoder(_CoderBase):
"""Arithmetic encoder."""
def __init__(self, base: int, precision: int, output_fn: OutputFn):
"""Constructs arithmetic encoder.
Args:
base: The arithmetic coder will output digits in {0, 1, ..., base - 1}.
precision: Precision for internal state; on the average this will waste
code space worth at most 1/log(base) * base ** -(precision - 2) digits
of output per coding step.
output_fn: Function that writes a digit from {0, 1, ..., base - 1} to the
compressed output.
"""
super().__init__(base, precision, output_fn)
def encode(self, pdf: np.ndarray, symbol: int) -> None:
"""Encodes symbol `symbol` assuming coding distribution `pdf`."""
self._process(pdf, symbol)
def terminate(self) -> None:
"""Finalizes arithmetic code."""
# Write outstanding part of the arithmetic code plus one digit to uniquely
# determine a code within the interval of the final symbol coded.
self._io_fn(self._low // self._base_to_pm1)
for _ in range(self._num_carry_digits):
self._io_fn(self._base - 1)
self.encode = _raise_post_terminate_exception
self.terminate = _raise_post_terminate_exception
class Decoder(_CoderBase):
"""Arithmetic decoder."""
def __init__(self, base: int, precision: int, input_fn: InputFn):
"""Constructs arithmetic decoder.
Args:
base: The arithmetic coder will output digits in {0, 1, ..., base - 1}.
precision: Precision for internal state; on the average this will waste
code space worth at most 1/log(base) * base ** -(precision - 2) digits
of output per coding step.
input_fn: Function that reads a digit from {0, 1, ..., base - 1} from the
compressed input or returns `None` when the input is exhausted.
"""
# Add padding to ensure the AC state is well-defined when decoding the last
# symbol. Note that what exactly we do here depends on how encoder
# termination is implemented (see `Encoder.terminate`).
trailing_digits = (base - 1 for _ in range(precision - 1))
def _padded_input_fn() -> int:
"""Reads digit from input padding the arithmetic code."""
digit = input_fn()
if digit is None:
digit = next(trailing_digits)
assert 0 <= int(digit) < base
return digit
super().__init__(base, precision, _padded_input_fn)
for _ in range(precision):
self._code = self._code * base + _padded_input_fn()
def decode(self, pdf: np.ndarray) -> int:
return self._process(pdf, None)
|