| | |
| | |
| | |
| | |
| | |
| |
|
| | """Arithmetic coder.""" |
| |
|
| | import io |
| | import math |
| | import random |
| | import typing as tp |
| | import torch |
| |
|
| | from ..binary import BitPacker, BitUnpacker |
| |
|
| |
|
def build_stable_quantized_cdf(
    pdf: torch.Tensor, total_range_bits: int, roundoff: float = 1e-8, min_range: int = 2, check: bool = True
) -> torch.Tensor:
    """Turn the given PDF into a quantized CDF that splits
    [0, 2 ** total_range_bits - 1] into chunks of size roughly proportional
    to the PDF.

    Args:
        pdf (torch.Tensor): probability distribution, shape should be `[N]`.
        total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
            during the coding process is `[0, 2 ** total_range_bits - 1]`.
        roundoff (float): will round the pdf down to that level to remove differences coming
            from e.g. evaluating the Language Model on different architectures.
            A falsy value (e.g. 0) disables the rounding.
        min_range (int): minimum range width. Should always be at least 2 for numerical
            stability. Use this to avoid pathological behavior if a value
            that is expected to be rare actually happens in real life.
        check (bool): if True, checks that nothing bad happened, can be deactivated for speed.

    Returns:
        torch.Tensor: integer (long) tensor of shape `[N]` holding the cumulative sums of the
        per-symbol range widths; symbol `i` owns `[cdf[i - 1], cdf[i] - 1]` (with `cdf[-1]`
        taken as 0 for `i == 0`).

    Raises:
        ValueError: if `min_range < 2`, or (when `check`) if some symbol ends up with a
            range narrower than `min_range`.
        AssertionError: if `min_range * N > 2 ** total_range_bits`, or (when `check`) if the
            last CDF entry exceeds `2 ** total_range_bits`.
    """
    # Validate up front: `min_range` feeds into `alpha` below, so an invalid value must not
    # surface as an unrelated assertion first.
    if min_range < 2:
        raise ValueError("min_range must be at least 2.")
    pdf = pdf.detach()
    if roundoff:
        # Snap probabilities down to a multiple of `roundoff` so tiny numerical noise does
        # not change the quantized CDF between encoder and decoder.
        pdf = (pdf / roundoff).floor() * roundoff

    total_range = 2 ** total_range_bits
    cardinality = len(pdf)
    # Mix with a uniform distribution: each symbol gets a guaranteed `min_range` plus a share
    # of the remaining `(1 - alpha) * total_range` proportional to its probability.
    alpha = min_range * cardinality / total_range
    assert alpha <= 1, "you must reduce min_range"
    ranges = (((1 - alpha) * total_range) * pdf).floor().long()
    ranges += min_range
    quantized_cdf = torch.cumsum(ranges, dim=-1)
    if check:
        assert quantized_cdf[-1] <= total_range, quantized_cdf[-1]
        # The CDF increments are exactly `ranges`; any one below `min_range` (e.g. from a
        # negative pdf entry) would make the coder unstable.
        if (ranges < min_range).any():
            raise ValueError("You must increase your total_range_bits.")
    return quantized_cdf
| |
|
| |
|
class ArithmeticCoder:
    """ArithmeticCoder,
    Let us take a distribution `p` over `N` symbols, and assume we have a stream
    of random variables `s_t` sampled from `p`. Let us assume that we have a budget
    of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
    corresponding to the range `[0, 2 ** B - 1]`. We can map each of those numbers to a single
    sequence `(s_t)` by doing the following:

    1) Initialize the current range to `[0, 2 ** B - 1]`.
    2) For each time step t, split the current range into contiguous chunks,
       one for each possible outcome, with size roughly proportional to `p`.
       For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
       would be `{[0, 2], [3, 3]}`.
    3) Select the chunk corresponding to `s_t`, and replace the current range with this.
    4) When done encoding all the values, just select any value remaining in the range.

    You will notice that this procedure can fail: for instance if at any point in time
    the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
    possible outcome. Intuitively, the more likely a value is, the less the range width
    will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
    coding scheme, likely outcomes would take fewer bits, and more of them can be coded
    with a fixed budget.

    In practice, we do not know `B` ahead of time, but we have a way to inject new bits
    when the current range decreases below a given limit (given by `total_range_bits`), without
    having to redo all the computations. If we mostly encode likely values, we will seldom
    need to inject new bits, but a single rare value can deplete our stock of entropy!

    In this explanation, we assumed that the distribution `p` was constant. In fact, the present
    code works for any sequence `(p_t)` possibly different for each timestep.
    We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
    the KL between the true distribution and `p_t`, the more efficient the coding will be.

    Args:
        fo (IO[bytes]): file-like object to which the bytes will be written.
        total_range_bits (int): the range `M` described above is `2 ** total_range_bits`.
            Any time the current range width falls below this limit, new bits will
            be injected to rescale the initial range. Must be at most 30.
        debug (bool): if True, record the `(low, high)` range after every `push` in
            `self._dbg` and `self._dbg2`. Off by default because these lists grow by one
            entry per encoded symbol.
    """

    def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24, *, debug: bool = False):
        assert total_range_bits <= 30
        self.total_range_bits = total_range_bits
        self.packer = BitPacker(bits=1, fo=fo)
        # Current range is [low, high], expressed without the prefix bits already written.
        self.low: int = 0
        self.high: int = 0
        # Index of the most significant bit of low/high not yet written to the stream.
        self.max_bit: int = -1
        self._debug = debug
        self._dbg: tp.List[tp.Any] = []
        self._dbg2: tp.List[tp.Any] = []

    @property
    def delta(self) -> int:
        """Return the current range width."""
        return self.high - self.low + 1

    def _flush_common_prefix(self):
        """Write out the most significant bits shared by `low` and `high`, removing them from both."""
        assert self.high >= self.low, (self.low, self.high)
        assert self.high < 2 ** (self.max_bit + 1)
        while self.max_bit >= 0:
            b1 = self.low >> self.max_bit
            b2 = self.high >> self.max_bit
            if b1 != b2:
                break
            self.low -= b1 << self.max_bit
            self.high -= b1 << self.max_bit
            assert self.high >= self.low, (self.high, self.low, self.max_bit)
            assert self.low >= 0
            self.max_bit -= 1
            self.packer.push(b1)

    def push(self, symbol: int, quantized_cdf: torch.Tensor):
        """Push the given symbol on the stream, flushing out bits
        if possible.

        Args:
            symbol (int): symbol to encode with the AC.
            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
                to build this from your pdf estimate.
        """
        # Rescale: append a 0 bit to `low` and a 1 bit to `high` until the range is at
        # least 2 ** total_range_bits wide.
        while self.delta < 2 ** self.total_range_bits:
            self.low *= 2
            self.high = self.high * 2 + 1
            self.max_bit += 1

        # Map the symbol's slice of [0, 2 ** total_range_bits) onto the current range.
        # ceil/floor keep the result inside the slice; ArithmeticDecoder.pull performs the
        # same float computation, so the two must stay in sync.
        range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
        range_high = quantized_cdf[symbol].item() - 1
        scale = self.delta / (2 ** self.total_range_bits)
        effective_low = int(math.ceil(range_low * scale))
        effective_high = int(math.floor(range_high * scale))
        assert self.low <= self.high
        self.high = self.low + effective_high
        self.low = self.low + effective_low
        assert self.low <= self.high, (effective_low, effective_high, range_low, range_high)
        if self._debug:
            self._dbg.append((self.low, self.high))
            self._dbg2.append((self.low, self.high))
        self._flush_common_prefix()
        assert self.low <= self.high
        assert self.max_bit >= -1
        assert self.max_bit <= 61, self.max_bit

    def flush(self):
        """Flush the remaining information to the stream.

        Writes every remaining bit of `low` (from `max_bit` down to bit 0), i.e. picks
        `low` as the value inside the final range, then flushes the bit packer.
        """
        while self.max_bit >= 0:
            b1 = (self.low >> self.max_bit) & 1
            self.packer.push(b1)
            self.max_bit -= 1
        self.packer.flush()
| |
|
| |
|
class ArithmeticDecoder:
    """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.

    Note that this must be called with **exactly** the same parameters and sequence
    of quantized cdf as the arithmetic encoder or the wrong values will be decoded.

    If the AC encoder current range is [L, H], with `L` and `H` sharing a common
    prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
    For instance, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
    `[b1 b2 b3 0 ... 0, b1 b2 b3 1 ... 1]`. Now this specific sub-range can only be obtained
    for a specific sequence of symbols and a binary-search allows us to decode those symbols.
    At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
    and we will need to read new bits from the stream and repeat the process.

    Args:
        fo (IO[bytes]): file-like object from which the bytes will be read.
        total_range_bits (int): must be the value used by the encoder.
        debug (bool): if True, record the `(low, high, current)` state after every `pull`,
            in `self._dbg` (before flushing the common prefix) and `self._dbg2` (after).
            Off by default because these lists grow by one entry per decoded symbol.
    """

    def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24, *, debug: bool = False):
        self.total_range_bits = total_range_bits
        # Current range [low, high], mirroring the encoder's state.
        self.low: int = 0
        self.high: int = 0
        # Value formed by the bits read so far, in the same prefix-stripped coordinates as low/high.
        self.current: int = 0
        self.max_bit: int = -1
        self.unpacker = BitUnpacker(bits=1, fo=fo)

        self._debug = debug
        self._dbg: tp.List[tp.Any] = []
        self._dbg2: tp.List[tp.Any] = []
        # State (low, high, current, max_bit) at the start of the last search.
        self._last: tp.Any = None

    @property
    def delta(self) -> int:
        """Return the current range width."""
        return self.high - self.low + 1

    def _flush_common_prefix(self):
        """Drop the most significant bits shared by `low` and `high` from low, high and current.

        Mirrors `ArithmeticCoder._flush_common_prefix`, which writes these bits to the stream.
        """
        while self.max_bit >= 0:
            b1 = self.low >> self.max_bit
            b2 = self.high >> self.max_bit
            if b1 != b2:
                break
            self.low -= b1 << self.max_bit
            self.high -= b1 << self.max_bit
            self.current -= b1 << self.max_bit
            assert self.high >= self.low
            assert self.low >= 0
            self.max_bit -= 1

    def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
        """Pull a symbol, reading as many bits from the stream as required.
        This returns `None` when the stream has been exhausted.

        Args:
            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
                to build this from your pdf estimate. This must be **exactly**
                the same cdf as the one used at encoding time.

        Returns:
            The decoded symbol index, or `None` if the stream ran out of bits.

        Raises:
            RuntimeError: if no symbol's sub-range contains the current value
                (presumably a cdf that differs from the one used by the encoder).
        """
        # Rescale exactly like the encoder, shifting one stream bit into `current` each time.
        while self.delta < 2 ** self.total_range_bits:
            bit = self.unpacker.pull()
            if bit is None:
                return None
            self.low *= 2
            self.high = self.high * 2 + 1
            self.current = self.current * 2 + bit
            self.max_bit += 1

        self._last = (self.low, self.high, self.current, self.max_bit)

        # Binary search for the symbol whose scaled sub-range contains `current`, using the
        # same float computation as `ArithmeticCoder.push`.
        scale = self.delta / (2 ** self.total_range_bits)
        low_idx, high_idx = 0, len(quantized_cdf) - 1
        while True:
            if high_idx < low_idx:
                raise RuntimeError("Binary search failed")
            mid = (low_idx + high_idx) // 2
            range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
            range_high = quantized_cdf[mid].item() - 1
            low = int(math.ceil(range_low * scale)) + self.low
            high = int(math.floor(range_high * scale)) + self.low
            if self.current < low:
                high_idx = mid - 1
            elif self.current > high:
                low_idx = mid + 1
            else:
                break

        symbol = mid
        self.low, self.high = low, high
        if self._debug:
            self._dbg.append((self.low, self.high, self.current))
        self._flush_common_prefix()
        if self._debug:
            self._dbg2.append((self.low, self.high, self.current))
        return symbol
| |
|
| |
|
def test():
    """Round-trip random symbol streams through `ArithmeticCoder` and `ArithmeticDecoder`.

    Raises:
        AssertionError: if any decoded symbol differs from the encoded one, or if the
            decoder does not report an exhausted stream at the end.
    """
    torch.manual_seed(1234)
    random.seed(1234)
    for _ in range(4):
        pdfs = []
        # At least one symbol: an empty pdf can be neither quantized nor sampled from.
        cardinality = random.randrange(1, 4000)
        steps = random.randrange(100, 500)
        fo = io.BytesIO()
        encoder = ArithmeticCoder(fo)
        symbols = []
        for _step in range(steps):
            pdf = torch.softmax(torch.randn(cardinality), dim=0)
            pdfs.append(pdf)
            q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
            symbol = torch.multinomial(pdf, 1).item()
            symbols.append(symbol)
            encoder.push(symbol, q_cdf)
        encoder.flush()

        fo.seek(0)
        decoder = ArithmeticDecoder(fo)
        for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
            q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
            decoded_symbol = decoder.pull(q_cdf)
            assert decoded_symbol == symbol, idx
        # Every bit has been consumed, so the next pull must report exhaustion.
        assert decoder.pull(torch.zeros(1)) is None


if __name__ == "__main__":
    test()
| |
|