# File size: 7,016 Bytes | 72c0672
from typing import List, Tuple
import itertools
import math
def _pack_rank_ids(buf: List[int], rank_bitlength: int) -> List[int]:
per_b = 8 // rank_bitlength
mask = (1 << rank_bitlength) - 1
out_b = []
it = iter(buf)
while True:
chunk = list(itertools.islice(it, per_b))
if not chunk:
break
byte_val = 0
for p, idx in enumerate(chunk):
byte_val |= (idx & mask) << (p * rank_bitlength)
out_b.append(byte_val)
return out_b
def _unpack_rank_ids(payload: List[int], run_len: int, rank_bitlength: int):
mask = (1 << rank_bitlength) - 1
byte_iter = iter(payload)
cur_byte = next(byte_iter)
filled = 8
for _ in range(run_len):
if filled == 0:
cur_byte = next(byte_iter)
filled = 8
rank_id = cur_byte & mask
cur_byte >>= rank_bitlength
filled -= rank_bitlength
yield rank_id
class SimpleAdaptiveRankCodec:
    """Window codec mixing raw bytes, run-length events and packed rank runs.

    An encoded window is a flat list of ints built from three event kinds:

    * a raw byte value below ``raw_byte_offset`` (256);
    * ``[sentinel_rle, run_length, byte]`` for a repeated byte;
    * ``[sentinel_rank_run, count, *packed_bytes]`` for ``count`` consecutive
      positions whose rank is below ``top_k``, packed ``ranks_per_byte`` ids
      per byte.
    """

    def __init__(
        self,
        top_k: int = 4,
        tau: float = 0.5,
        min_run: int = 3,
        max_run: int = 255,
        sentinel_rle: int = 256,
        sentinel_rank_run: int = 257,
    ):
        """Configure the codec.

        Args:
            top_k: Ranks below this value are stored as packed rank ids;
                other positions fall back to a raw byte.
            tau: Minimum repeat probability a following token needs in order
                to extend a run-length run.
            min_run: Shortest run that is emitted as a run-length event.
            max_run: Upper clamp applied to run lengths by
                ``encoding_to_pseudo_bytes``; ``encode_window`` does not cap
                runs.
            sentinel_rle: Stream value that introduces a run-length event.
            sentinel_rank_run: Stream value that introduces a rank run.

        Raises:
            AssertionError: If the bit width derived from ``top_k`` is above
                8 or does not divide 8.
        """
        self.top_k = top_k
        self.tau = tau
        self.min_run = min_run
        self.max_run = max_run
        self.raw_byte_offset = 256
        # Bits needed to store any rank in [0, top_k); at least one bit.
        self.rank_bitlength = max(1, (top_k - 1).bit_length())
        assert self.rank_bitlength <= 8 and 8 % self.rank_bitlength == 0, (
            f"rank_bitlength must be between 1 and 8 and must divide 8, "
            f"got {self.rank_bitlength} (top_k: {top_k})"
        )
        self.ranks_per_byte = 8 // self.rank_bitlength
        self.sentinel_rle = sentinel_rle
        self.sentinel_rank_run = sentinel_rank_run

    def encode_window(
        self,
        tokens: List[int],
        repeat_probs: List[float],
        ranks: List[int],
    ) -> List[int]:
        """Encode one window into raw bytes and sentinel events.

        The first token is always emitted raw. From then on, at each position
        a run of identical tokens whose repeat probabilities are all at least
        ``tau`` is tried first; otherwise the position is stored as a rank id
        when its rank is below ``top_k``, or as a raw byte.

        Args:
            tokens: Byte values of the window (expected 0-255).
            repeat_probs: Per-position repeat probability, same length.
            ranks: Per-position rank of the true token, same length.

        Returns:
            The encoded stream; an empty list for an empty window. Run
            lengths and rank-run counts are stored as plain ints and are not
            limited to 255.

        Raises:
            AssertionError: If the three inputs differ in length.
        """
        assert len(tokens) == len(repeat_probs) == len(ranks)
        if not tokens:
            # Nothing to encode; avoids indexing tokens[0] below.
            return []
        rank_buf: List[int] = []
        out: List[int] = [tokens[0]]
        i, n = 1, len(tokens)

        def flush_rank_buf():
            # Emit the pending rank ids as a single rank-run event.
            if not rank_buf:
                return
            out.append(self.sentinel_rank_run)
            out.append(len(rank_buf))
            out.extend(_pack_rank_ids(rank_buf, self.rank_bitlength))
            rank_buf.clear()

        while i < n:
            tok = tokens[i]
            # --- RLE probe: extend while the next token repeats and its
            # repeat probability clears tau --------------------------------
            run = 1
            while (i + run < n and
                   tokens[i + run] == tok and
                   repeat_probs[i + run] >= self.tau):
                run += 1
            if run >= self.min_run:
                flush_rank_buf()
                out.extend([self.sentinel_rle, run, tok])
                i += run
                continue
            if ranks[i] < self.top_k:
                rank_buf.append(ranks[i])
            else:
                # the current token is not in top-K,
                # so we escape to a raw byte
                flush_rank_buf()
                out.append(tok)
            i += 1
        flush_rank_buf()
        return out

    def encoding_to_pseudo_bytes(self, enc: list[int]) -> list[int]:
        """Map an encoded stream to a flat pseudo-byte sequence.

        Raw bytes pass through unchanged. A run-length event becomes
        ``[2 * raw_byte_offset - run, byte]`` with ``run`` clamped to
        ``max_run``. A rank run becomes its packed bytes shifted up by
        ``raw_byte_offset``; the id count is dropped.

        Args:
            enc: A stream as produced by ``encode_window``.

        Returns:
            The pseudo-byte sequence.

        Raises:
            ValueError: If a value at or above ``raw_byte_offset`` is neither
                sentinel.
        """
        # NOTE: this function is not expected to be lossless, that is,
        # we cannot reconstruct the original encoding from the pseudo-bytes
        out: list[int] = []
        i = 0
        while i < len(enc):
            tok = enc[i]
            i += 1
            if tok < self.raw_byte_offset:
                out.append(tok)
            elif tok == self.sentinel_rle:
                run = enc[i]
                raw = enc[i + 1]
                i += 2
                run = min(run, self.max_run)
                # Run markers count down from 2 * raw_byte_offset (512):
                # a longer run maps to a smaller marker value.
                out.extend([self.raw_byte_offset + self.raw_byte_offset - run, raw])
            elif tok == self.sentinel_rank_run:
                length = enc[i]
                i += 1
                n_bytes = math.ceil(length / self.ranks_per_byte)
                for _ in range(n_bytes):
                    pb = enc[i] + self.raw_byte_offset
                    out.append(pb)
                    i += 1
            else:
                raise ValueError(f"unknown token {tok}")
        return out

    def pseudo_bytes_to_encoding(self, pb: list[int], original_encoding: list[int]) -> list[int]:
        """Placeholder for the inverse of ``encoding_to_pseudo_bytes``.

        Raises:
            NotImplementedError: Always.
        """
        # NOTE: we do not expect the encoding-to-pseudo-bytes conversion to be lossless,
        # so we need to pass the original encoding to reconstruct the original encoding
        # this function is just for sanity check
        raise NotImplementedError("Not implemented")

    def decode_window(
        self,
        stream: List[int],
        original_len: int,
        topk_symbols: List[List[int]],
    ) -> List[int]:
        """Decode a stream produced by ``encode_window``.

        `topk_symbols[pos][idx]` must give the byte value (0-255) that
        corresponds to rank `idx` at position `pos`, e.g. recomputed from
        the helper LM during decoding.

        Args:
            stream: The encoded window.
            original_len: Number of tokens to reconstruct; decoding stops
                once that many have been produced.
            topk_symbols: Per-position rank-to-byte lookup (see above).

        Returns:
            The decoded tokens, truncated to ``original_len``.

        Raises:
            ValueError: If a value at or above ``raw_byte_offset`` is neither
                sentinel.
        """
        out: List[int] = []
        # position in input stream
        i = 0
        # position in output tokens
        pos = 0
        while pos < original_len:
            tok = stream[i]
            i += 1
            if tok < self.raw_byte_offset:
                out.append(tok)
                pos += 1
            elif tok == self.sentinel_rle:
                run_len = stream[i]
                raw = stream[i + 1]
                i += 2
                out.extend([raw] * run_len)
                pos += run_len
            elif tok == self.sentinel_rank_run:
                run_len = stream[i]
                i += 1
                bytes_needed = math.ceil(run_len / self.ranks_per_byte)
                payload = stream[i: i + bytes_needed]
                i += bytes_needed
                for rank_id in _unpack_rank_ids(payload, run_len, self.rank_bitlength):
                    sym = topk_symbols[pos][rank_id]
                    out.append(sym)
                    pos += 1
            else:
                raise ValueError(f"Unknown sentinel {tok}")
        return out[:original_len]
if __name__ == "__main__":
    # Round-trip demo: encode a random window and check it decodes exactly.
    import random

    import torch

    # The demo data below is drawn from torch's RNG, so seeding only
    # `random` would leave every run different; seed both.
    random.seed(0)
    torch.manual_seed(0)
    T, K = 384, 13  # demonstrate non-power-of-two K
    tokens = torch.randint(0, 32, (T,)).tolist()
    repeat_probs = torch.rand(T).tolist()
    ranks = torch.randint(0, K + 5, (T,)).tolist()  # some ranks ≥K → raw
    # Clamp out-of-range ranks to K, which the encoder treats as "not in top-K".
    ranks = [r if r < K else K for r in ranks]
    # fake LM top-K table for demo: every rank at position t maps to the
    # true token, so any stored rank id decodes correctly
    topk = [[tokens[t]] * K for t in range(T)]
    codec = SimpleAdaptiveRankCodec(top_k=K, tau=0.00)
    enc = codec.encode_window(tokens, repeat_probs, ranks)
    dec = codec.decode_window(enc, T, topk)
    print(f"raw={T} encoded={len(enc)} ratio={len(enc)/T:.2f}")
    assert dec == tokens
    print("✓ window-enc-dec round-trip passes")
|