from typing import Iterator, List, Tuple
import itertools
import math


def _pack_rank_ids(buf: List[int], rank_bitlength: int) -> List[int]:
    """Pack rank ids into bytes, lowest bit-slot first.

    Each byte holds ``8 // rank_bitlength`` ids; the id at slot ``p`` within
    a byte occupies bits ``[p * rank_bitlength, (p + 1) * rank_bitlength)``.
    Ids are masked to ``rank_bitlength`` bits. A final partial chunk leaves
    the unused high slots at zero.

    Args:
        buf: Rank ids to pack.
        rank_bitlength: Bits per id. Expected to divide 8 (the codec class
            asserts this; it is not re-checked here).

    Returns:
        The packed byte values, ``ceil(len(buf) / (8 // rank_bitlength))`` of
        them.
    """
    per_byte = 8 // rank_bitlength
    mask = (1 << rank_bitlength) - 1
    packed: List[int] = []
    it = iter(buf)
    while chunk := list(itertools.islice(it, per_byte)):
        byte_val = 0
        for slot, rank_id in enumerate(chunk):
            byte_val |= (rank_id & mask) << (slot * rank_bitlength)
        packed.append(byte_val)
    return packed


def _unpack_rank_ids(
    payload: List[int], run_len: int, rank_bitlength: int
) -> Iterator[int]:
    """Yield ``run_len`` rank ids from bytes produced by `_pack_rank_ids`.

    Args:
        payload: Packed byte values.
        run_len: Number of ids to yield.
        rank_bitlength: Bits per id. Expected to divide 8, otherwise the
            per-byte bit counter never reaches zero and no new byte is read.

    Yields:
        Rank ids in the order they were packed.

    Raises:
        ValueError: If ``payload`` runs out before ``run_len`` ids were read.
    """
    mask = (1 << rank_bitlength) - 1
    byte_iter = iter(payload)
    cur_byte = 0
    # Start "empty" so the first byte is only fetched when an id is actually
    # requested; a zero-length run with an empty payload then yields nothing.
    bits_left = 0
    for _ in range(run_len):
        if bits_left == 0:
            try:
                cur_byte = next(byte_iter)
            except StopIteration:
                # A bare StopIteration inside a generator would surface as an
                # opaque RuntimeError; report the real problem instead.
                raise ValueError(
                    f"rank payload too short for run of {run_len} ids"
                ) from None
            bits_left = 8
        yield cur_byte & mask
        cur_byte >>= rank_bitlength
        bits_left -= rank_bitlength


class SimpleAdaptiveRankCodec:
    """Window codec mixing raw bytes, run-length events and packed rank runs.

    The encoded stream is a list of ints made of three kinds of events:

    * a raw token value below ``raw_byte_offset`` (256);
    * ``[sentinel_rle, run, tok]`` -- ``tok`` repeated ``run`` times;
    * ``[sentinel_rank_run, count, b0, b1, ...]`` -- ``count`` rank ids packed
      ``ranks_per_byte`` to a byte (see `_pack_rank_ids`).
    """

    def __init__(
        self,
        top_k: int = 4,
        tau: float = 0.5,
        min_run: int = 3,
        max_run: int = 255,
        sentinel_rle: int = 256,
        sentinel_rank_run: int = 257,
    ):
        """Configure the codec.

        Args:
            top_k: Ranks strictly below this value are rank-coded; other
                positions are emitted as raw tokens.
            tau: Minimum ``repeat_probs`` value required for a position to
                extend a run-length run.
            min_run: Shortest run that is emitted as an RLE event.
            max_run: Upper clamp applied to run lengths in
                `encoding_to_pseudo_bytes`.
            sentinel_rle: Stream marker introducing an RLE event.
            sentinel_rank_run: Stream marker introducing a packed rank run.

        Raises:
            AssertionError: If the bit width derived from ``top_k`` exceeds 8
                or does not divide 8.
        """
        self.top_k = top_k
        self.tau = tau
        self.min_run = min_run
        self.max_run = max_run
        self.raw_byte_offset = 256
        # Bits needed to represent the largest rank id (top_k - 1), at least 1.
        self.rank_bitlength = max(1, (top_k - 1).bit_length())
        assert self.rank_bitlength <= 8 and 8 % self.rank_bitlength == 0, (
            f"rank_bitlength must be between 1 and 8 and must divide 8, got {self.rank_bitlength}"
            f" (top_k: {top_k})"
        )
        self.ranks_per_byte = 8 // self.rank_bitlength
        self.sentinel_rle = sentinel_rle
        self.sentinel_rank_run = sentinel_rank_run

    def encode_window(
        self,
        tokens: List[int],
        repeat_probs: List[float],
        ranks: List[int],
    ) -> List[int]:
        """Return a list of ints: raw bytes 0-255 and sentinel events ≥256.

        The first token is always emitted raw. For each later position the
        encoder first probes for a run of identical tokens; if none is long
        enough it either buffers the position's rank (when below ``top_k``)
        or escapes to the raw token.

        Args:
            tokens: Token values of the window.
            repeat_probs: Per-position value compared against ``tau`` when
                extending a run.
            ranks: Per-position rank of the token.

        Returns:
            The encoded event stream; ``[]`` for an empty window.

        Raises:
            AssertionError: If the three inputs differ in length.
        """
        assert len(tokens) == len(repeat_probs) == len(ranks)
        if not tokens:
            # Nothing to seed the stream with; decode_window(…, 0, …) of an
            # empty stream yields [] as well.
            return []

        rank_buf: List[int] = []
        out: List[int] = [tokens[0]]
        i, n = 1, len(tokens)

        def flush_rank_buf():
            """Emit the pending rank ids as one packed rank-run event."""
            if not rank_buf:
                return
            out.append(self.sentinel_rank_run)
            out.append(len(rank_buf))
            out.extend(_pack_rank_ids(rank_buf, self.rank_bitlength))
            rank_buf.clear()

        while i < n:
            tok = tokens[i]
            # --- RLE probe ------------------------------------------------
            # A run is extended over each following position that repeats
            # `tok` and whose repeat_prob is at least tau; the probability at
            # position i itself is not consulted.
            # NOTE(review): max_run is not enforced here, so `run` may exceed
            # it -- confirm whether that is intended.
            run = 1
            while (
                i + run < n
                and tokens[i + run] == tok
                and repeat_probs[i + run] >= self.tau
            ):
                run += 1
            if run >= self.min_run:
                flush_rank_buf()
                out.extend([self.sentinel_rle, run, tok])
                i += run
                continue

            if ranks[i] < self.top_k:
                rank_buf.append(ranks[i])
            else:
                # the current token is not in top-K,
                # so we escape to a raw byte
                flush_rank_buf()
                out.append(tok)
            i += 1

        flush_rank_buf()
        return out

    def encoding_to_pseudo_bytes(self, enc: list[int]) -> list[int]:
        """Flatten an encoded stream into "pseudo-bytes".

        Raw tokens pass through unchanged, an RLE event becomes
        ``[2 * raw_byte_offset - min(run, max_run), raw]`` and every packed
        rank byte is shifted up by ``raw_byte_offset``. Rank-run lengths are
        dropped.

        Args:
            enc: Stream produced by `encode_window`.

        Returns:
            The pseudo-byte sequence.

        Raises:
            ValueError: If an unrecognised sentinel is encountered.
        """
        # NOTE: this function is not expected to be lossless, that is,
        # we cannot reconstruct the original encoding from the pseudo-bytes
        out: list[int] = []
        i = 0
        while i < len(enc):
            tok = enc[i]
            i += 1
            if tok < self.raw_byte_offset:
                out.append(tok)
            elif tok == self.sentinel_rle:
                run = enc[i]
                raw = enc[i + 1]
                i += 2
                run = min(run, self.max_run)
                # run lengths are marked counting down from 512 towards 256
                out.extend([self.raw_byte_offset + self.raw_byte_offset - run, raw])
            elif tok == self.sentinel_rank_run:
                length = enc[i]
                i += 1
                n_bytes = math.ceil(length / self.ranks_per_byte)
                for _ in range(n_bytes):
                    out.append(enc[i] + self.raw_byte_offset)
                    i += 1
            else:
                raise ValueError(f"unknown token {tok}")
        return out

    def pseudo_bytes_to_encoding(self, pb: list[int], original_encoding: list[int]) -> list[int]:
        """Placeholder for the inverse of `encoding_to_pseudo_bytes`.

        Raises:
            NotImplementedError: Always.
        """
        # NOTE: we do not expect the encoding-to-pseudo-bytes conversion to be lossless,
        # so we need to pass the original encoding to reconstruct the original encoding
        # this function is just for sanity check
        raise NotImplementedError("Not implemented")

    def decode_window(
        self,
        stream: List[int],
        original_len: int,
        topk_symbols: List[List[int]],
    ) -> List[int]:
        """Decode a stream produced by `encode_window`.

        `topk_symbols[pos][idx]` must give the byte value (0-255) that
        corresponds to rank `idx` at position `pos`, e.g. recomputed from
        the helper LM during decoding.

        Args:
            stream: Encoded event stream.
            original_len: Number of tokens to reconstruct; decoding stops
                once this many have been produced.
            topk_symbols: Rank-to-symbol table per output position.

        Returns:
            The first ``original_len`` decoded tokens.

        Raises:
            ValueError: If an unrecognised sentinel is encountered or a rank
                payload is shorter than its declared run.
        """
        out: List[int] = []
        # position in input stream
        i = 0
        # position in output tokens
        pos = 0
        while pos < original_len:
            tok = stream[i]
            i += 1
            if tok < self.raw_byte_offset:
                out.append(tok)
                pos += 1
            elif tok == self.sentinel_rle:
                run_len = stream[i]
                raw = stream[i + 1]
                i += 2
                out.extend([raw] * run_len)
                pos += run_len
            elif tok == self.sentinel_rank_run:
                run_len = stream[i]
                i += 1
                bytes_needed = math.ceil(run_len / self.ranks_per_byte)
                payload = stream[i: i + bytes_needed]
                i += bytes_needed
                for rank_id in _unpack_rank_ids(payload, run_len, self.rank_bitlength):
                    out.append(topk_symbols[pos][rank_id])
                    pos += 1
            else:
                raise ValueError(f"Unknown sentinel {tok}")
        # An RLE event may overshoot original_len; trim the excess.
        return out[:original_len]


if __name__ == "__main__":
    import random

    import torch

    random.seed(0)
    # The demo data is drawn from torch, so torch's RNG must be seeded too
    # for the run to be reproducible.
    torch.manual_seed(0)
    T, K = 384, 13  # demonstrate non-power-of-two K
    tokens = torch.randint(0, 32, (T,)).tolist()
    repeat_probs = torch.rand(T).tolist()
    ranks = torch.randint(0, K + 5, (T,)).tolist()  # some ranks ≥K → raw
    ranks = [r if r < K else K for r in ranks]

    # fake LM top-K table for demo: identity mapping
    topk = [[tokens[t]] * K for t in range(T)]

    codec = SimpleAdaptiveRankCodec(top_k=K, tau=0.00)
    enc = codec.encode_window(tokens, repeat_probs, ranks)
    dec = codec.decode_window(enc, T, topk)

    print(f"raw={T} encoded={len(enc)} ratio={len(enc)/T:.2f}")
    assert dec == tokens
    print("✓ window-enc-dec round-trip passes")